Mirror of https://github.com/AUTOMATIC1111/stable-diffusion-webui.git
--no-half-vae

commit 7349088d32 (parent a357823339)
modules/devices.py

@@ -36,6 +36,7 @@ errors.run(enable_tf32, "Enabling TF32")
 device = device_gfpgan = device_bsrgan = device_esrgan = device_scunet = device_codeformer = get_optimal_device()
 dtype = torch.float16
+dtype_vae = torch.float16
 
 
 def randn(seed, shape):
     # Pytorch currently doesn't handle setting randomness correctly when the metal backend is used.
@@ -59,9 +60,12 @@ def randn_without_seed(shape):
     return torch.randn(shape, device=device)
 
 
-def autocast():
+def autocast(disable=False):
     from modules import shared
 
+    if disable:
+        return contextlib.nullcontext()
+
     if dtype == torch.float32 or shared.cmd_opts.precision == "full":
         return contextlib.nullcontext()
 
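The new `disable` parameter gives call sites a way to opt out of mixed precision for a single block. A minimal sketch of the pattern, not the full webui function (the fallback `torch.autocast("cuda")` branch is simplified from the webui's own dtype checks):

import contextlib
import torch

def autocast(disable=False):
    # With disable=True the caller gets contextlib.nullcontext(), a
    # context manager that does nothing on enter/exit, so tensors keep
    # whatever dtype they already have.
    if disable:
        return contextlib.nullcontext()
    # Otherwise run the block under fp16 autocasting on CUDA.
    return torch.autocast("cuda")

`contextlib.nullcontext()` is the standard-library stand-in for a `with` block that should have no effect.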
modules/processing.py

@@ -259,6 +259,13 @@ def create_random_tensors(shape, seeds, subseeds=None, subseed_strength=0.0, see
     return x
 
 
+def decode_first_stage(model, x):
+    with devices.autocast(disable=x.dtype == devices.dtype_vae):
+        x = model.decode_first_stage(x)
+
+    return x
+
+
 def get_fixed_seed(seed):
     if seed is None or seed == '' or seed == -1:
         return int(random.randrange(4294967294))
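The `disable=x.dtype == devices.dtype_vae` condition is the crux: if the latents already match the VAE's dtype there is nothing for autocast to bridge, so it is turned off; if they differ (fp16 latents, fp32 VAE under --no-half-vae), autocast handles the cast. A small standalone illustration with hypothetical tensor values:

import torch

dtype_vae = torch.float32  # what --no-half-vae selects for the VAE

# Hypothetical fp16 latents coming out of the sampler.
x = torch.randn(1, 4, 64, 64, dtype=torch.float16)

# Dtypes differ, so disable is False and autocast stays on to bridge
# the fp16 latents into the fp32 first-stage model; if they matched,
# autocast would be skipped entirely.
disable = x.dtype == dtype_vae
print(disable)  # False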
@@ -400,7 +407,7 @@ def process_images(p: StableDiffusionProcessing) -> Processed:
 
             samples_ddim = samples_ddim.to(devices.dtype)
 
-            x_samples_ddim = p.sd_model.decode_first_stage(samples_ddim)
+            x_samples_ddim = decode_first_stage(p.sd_model, samples_ddim)
             x_samples_ddim = torch.clamp((x_samples_ddim + 1.0) / 2.0, min=0.0, max=1.0)
 
             del samples_ddim
@@ -533,7 +540,7 @@ class StableDiffusionProcessingTxt2Img(StableDiffusionProcessing):
         if self.scale_latent:
             samples = torch.nn.functional.interpolate(samples, size=(self.height // opt_f, self.width // opt_f), mode="bilinear")
         else:
-            decoded_samples = self.sd_model.decode_first_stage(samples)
+            decoded_samples = decode_first_stage(self.sd_model, samples)
 
             if opts.upscaler_for_img2img is None or opts.upscaler_for_img2img == "None":
                 decoded_samples = torch.nn.functional.interpolate(decoded_samples, size=(self.height, self.width), mode="bilinear")
modules/sd_models.py

@@ -149,6 +149,7 @@ def load_model_weights(model, checkpoint_info):
         model.half()
 
     devices.dtype = torch.float32 if shared.cmd_opts.no_half else torch.float16
+    devices.dtype_vae = torch.float32 if shared.cmd_opts.no_half or shared.cmd_opts.no_half_vae else torch.float16
 
     vae_file = os.path.splitext(checkpoint_file)[0] + ".vae.pt"
     if os.path.exists(vae_file):
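The two conditional expressions map the flags onto dtypes: --no-half forces the whole model (VAE included) to fp32, while --no-half-vae forces only the VAE. A sketch enumerating the combinations:

import torch

# Enumerate the flag combinations; only the default (both off)
# leaves the VAE in float16.
for no_half, no_half_vae in [(False, False), (False, True), (True, False), (True, True)]:
    dtype = torch.float32 if no_half else torch.float16
    dtype_vae = torch.float32 if no_half or no_half_vae else torch.float16
    print(f"no_half={no_half} no_half_vae={no_half_vae} -> model {dtype}, vae {dtype_vae}")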
@@ -158,6 +159,8 @@ def load_model_weights(model, checkpoint_info):
 
         model.first_stage_model.load_state_dict(vae_dict)
 
+    model.first_stage_model.to(devices.dtype_vae)
+
     model.sd_model_hash = sd_model_hash
     model.sd_model_checkpoint = checkpoint_file
     model.sd_checkpoint_info = checkpoint_info
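The cast runs after the optional .vae.pt weights are loaded, so it wins over the earlier model.half() regardless of which branch executed. A minimal sketch of that ordering, using a stand-in nn.Module rather than the real first-stage model:

import torch
import torch.nn as nn

vae = nn.Conv2d(4, 3, 1)   # stand-in for model.first_stage_model
vae.half()                 # model.half() earlier halved it with the rest
vae.to(torch.float32)      # the final cast restores fp32 when dtype_vae says so
print(next(vae.parameters()).dtype)  # torch.float32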
modules/sd_samplers.py

@@ -7,7 +7,7 @@ import inspect
 import k_diffusion.sampling
 import ldm.models.diffusion.ddim
 import ldm.models.diffusion.plms
-from modules import prompt_parser
+from modules import prompt_parser, devices, processing
 
 from modules.shared import opts, cmd_opts, state
 import modules.shared as shared
@@ -83,7 +83,7 @@ def setup_img2img_steps(p, steps=None):
 
 
 def sample_to_image(samples):
-    x_sample = shared.sd_model.decode_first_stage(samples[0:1].type(shared.sd_model.dtype))[0]
+    x_sample = processing.decode_first_stage(shared.sd_model, samples[0:1])[0]
     x_sample = torch.clamp((x_sample + 1.0) / 2.0, min=0.0, max=1.0)
     x_sample = 255. * np.moveaxis(x_sample.cpu().numpy(), 0, 2)
     x_sample = x_sample.astype(np.uint8)
modules/shared.py

@@ -25,6 +25,7 @@ parser.add_argument("--ckpt-dir", type=str, default=None, help="Path to director
 parser.add_argument("--gfpgan-dir", type=str, help="GFPGAN directory", default=('./src/gfpgan' if os.path.exists('./src/gfpgan') else './GFPGAN'))
 parser.add_argument("--gfpgan-model", type=str, help="GFPGAN model file name", default=None)
 parser.add_argument("--no-half", action='store_true', help="do not switch the model to 16-bit floats")
+parser.add_argument("--no-half-vae", action='store_true', help="do not switch the VAE model to 16-bit floats")
 parser.add_argument("--no-progressbar-hiding", action='store_true', help="do not hide progressbar in gradio UI (we hide it because it slows down ML if you have hardware acceleration in browser)")
 parser.add_argument("--max-batch-count", type=int, default=16, help="maximum batch count value for the UI")
 parser.add_argument("--embeddings-dir", type=str, default=os.path.join(script_path, 'embeddings'), help="embeddings directory for textual inversion (default: embeddings)")
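For reference, a standalone sketch of the new flag's argparse behavior (a throwaway parser, not the webui's):

import argparse

parser = argparse.ArgumentParser()
parser.add_argument("--no-half-vae", action='store_true',
                    help="do not switch the VAE model to 16-bit floats")

# store_true defaults to False; passing the flag flips it on.
print(parser.parse_args([]).no_half_vae)                 # False
print(parser.parse_args(["--no-half-vae"]).no_half_vae)  # True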