mirror of
https://github.com/AUTOMATIC1111/stable-diffusion-webui.git
synced 2024-06-07 21:20:49 +00:00
Merge pull request #294 from EliasOenal/master
Fixes for mps/Metal: use of seeds, img2img, CodeFormer
This commit is contained in:
commit
11e03b9abd
@ -47,6 +47,8 @@ def setup_codeformer():
|
|||||||
def __init__(self):
|
def __init__(self):
|
||||||
self.net = None
|
self.net = None
|
||||||
self.face_helper = None
|
self.face_helper = None
|
||||||
|
if shared.device.type == 'mps': # CodeFormer currently does not support mps backend
|
||||||
|
shared.device_codeformer = torch.device('cpu')
|
||||||
|
|
||||||
def create_models(self):
|
def create_models(self):
|
||||||
|
|
||||||
@ -54,13 +56,13 @@ def setup_codeformer():
|
|||||||
self.net.to(shared.device)
|
self.net.to(shared.device)
|
||||||
return self.net, self.face_helper
|
return self.net, self.face_helper
|
||||||
|
|
||||||
net = net_class(dim_embd=512, codebook_size=1024, n_head=8, n_layers=9, connect_list=['32', '64', '128', '256']).to(shared.device)
|
net = net_class(dim_embd=512, codebook_size=1024, n_head=8, n_layers=9, connect_list=['32', '64', '128', '256']).to(shared.device_codeformer)
|
||||||
ckpt_path = load_file_from_url(url=pretrain_model_url, model_dir=os.path.join(path, 'weights/CodeFormer'), progress=True)
|
ckpt_path = load_file_from_url(url=pretrain_model_url, model_dir=os.path.join(path, 'weights/CodeFormer'), progress=True)
|
||||||
checkpoint = torch.load(ckpt_path)['params_ema']
|
checkpoint = torch.load(ckpt_path)['params_ema']
|
||||||
net.load_state_dict(checkpoint)
|
net.load_state_dict(checkpoint)
|
||||||
net.eval()
|
net.eval()
|
||||||
|
|
||||||
face_helper = FaceRestoreHelper(1, face_size=512, crop_ratio=(1, 1), det_model='retinaface_resnet50', save_ext='png', use_parse=True, device=shared.device)
|
face_helper = FaceRestoreHelper(1, face_size=512, crop_ratio=(1, 1), det_model='retinaface_resnet50', save_ext='png', use_parse=True, device=shared.device_codeformer)
|
||||||
|
|
||||||
self.net = net
|
self.net = net
|
||||||
self.face_helper = face_helper
|
self.face_helper = face_helper
|
||||||
@ -82,7 +84,7 @@ def setup_codeformer():
|
|||||||
for idx, cropped_face in enumerate(self.face_helper.cropped_faces):
|
for idx, cropped_face in enumerate(self.face_helper.cropped_faces):
|
||||||
cropped_face_t = img2tensor(cropped_face / 255., bgr2rgb=True, float32=True)
|
cropped_face_t = img2tensor(cropped_face / 255., bgr2rgb=True, float32=True)
|
||||||
normalize(cropped_face_t, (0.5, 0.5, 0.5), (0.5, 0.5, 0.5), inplace=True)
|
normalize(cropped_face_t, (0.5, 0.5, 0.5), (0.5, 0.5, 0.5), inplace=True)
|
||||||
cropped_face_t = cropped_face_t.unsqueeze(0).to(shared.device)
|
cropped_face_t = cropped_face_t.unsqueeze(0).to(shared.device_codeformer)
|
||||||
|
|
||||||
try:
|
try:
|
||||||
with torch.no_grad():
|
with torch.no_grad():
|
||||||
|
@ -103,17 +103,32 @@ def create_random_tensors(shape, seeds, subseeds=None, subseed_strength=0.0, see
|
|||||||
for i, seed in enumerate(seeds):
|
for i, seed in enumerate(seeds):
|
||||||
noise_shape = shape if seed_resize_from_h <= 0 or seed_resize_from_w <= 0 else (shape[0], seed_resize_from_h//8, seed_resize_from_w//8)
|
noise_shape = shape if seed_resize_from_h <= 0 or seed_resize_from_w <= 0 else (shape[0], seed_resize_from_h//8, seed_resize_from_w//8)
|
||||||
|
|
||||||
|
# Pytorch currently doesn't handle seeting randomness correctly when the metal backend is used.
|
||||||
|
generator = torch
|
||||||
|
if shared.device.type == 'mps':
|
||||||
|
shared.device_seed_type = 'cpu'
|
||||||
|
generator = torch.Generator(device=shared.device_seed_type)
|
||||||
|
|
||||||
subnoise = None
|
subnoise = None
|
||||||
if subseeds is not None:
|
if subseeds is not None:
|
||||||
subseed = 0 if i >= len(subseeds) else subseeds[i]
|
subseed = 0 if i >= len(subseeds) else subseeds[i]
|
||||||
torch.manual_seed(subseed)
|
generator.manual_seed(subseed)
|
||||||
|
|
||||||
|
if shared.device.type != shared.device_seed_type:
|
||||||
|
subnoise = torch.randn(noise_shape, generator=generator, device=shared.device_seed_type).to(shared.device)
|
||||||
|
else:
|
||||||
subnoise = torch.randn(noise_shape, device=shared.device)
|
subnoise = torch.randn(noise_shape, device=shared.device)
|
||||||
|
|
||||||
# randn results depend on device; gpu and cpu get different results for same seed;
|
# randn results depend on device; gpu and cpu get different results for same seed;
|
||||||
# the way I see it, it's better to do this on CPU, so that everyone gets same result;
|
# the way I see it, it's better to do this on CPU, so that everyone gets same result;
|
||||||
# but the original script had it like this, so I do not dare change it for now because
|
# but the original script had it like this, so I do not dare change it for now because
|
||||||
# it will break everyone's seeds.
|
# it will break everyone's seeds.
|
||||||
torch.manual_seed(seed)
|
# When using the mps backend falling back to the cpu device is needed, since mps currently
|
||||||
|
# does not implement seeding properly.
|
||||||
|
generator.manual_seed(seed)
|
||||||
|
if shared.device.type != shared.device_seed_type:
|
||||||
|
noise = torch.randn(noise_shape, generator=generator, device=shared.device_seed_type).to(shared.device)
|
||||||
|
else:
|
||||||
noise = torch.randn(noise_shape, device=shared.device)
|
noise = torch.randn(noise_shape, device=shared.device)
|
||||||
|
|
||||||
if subnoise is not None:
|
if subnoise is not None:
|
||||||
@ -124,8 +139,10 @@ def create_random_tensors(shape, seeds, subseeds=None, subseed_strength=0.0, see
|
|||||||
#noise = torch.nn.functional.interpolate(noise.unsqueeze(1), size=shape[1:], mode="bilinear").squeeze()
|
#noise = torch.nn.functional.interpolate(noise.unsqueeze(1), size=shape[1:], mode="bilinear").squeeze()
|
||||||
# noise_shape = (64, 80)
|
# noise_shape = (64, 80)
|
||||||
# shape = (64, 72)
|
# shape = (64, 72)
|
||||||
|
generator.manual_seed(seed)
|
||||||
torch.manual_seed(seed)
|
if shared.device.type != shared.device_seed_type:
|
||||||
|
x = torch.randn(shape, generator=generator, device=shared.device_seed_type).to(shared.device)
|
||||||
|
else:
|
||||||
x = torch.randn(shape, device=shared.device)
|
x = torch.randn(shape, device=shared.device)
|
||||||
dx = (shape[2] - noise_shape[2]) // 2 # -4
|
dx = (shape[2] - noise_shape[2]) // 2 # -4
|
||||||
dy = (shape[1] - noise_shape[1]) // 2
|
dy = (shape[1] - noise_shape[1]) // 2
|
||||||
@ -465,7 +482,10 @@ class StableDiffusionProcessingImg2Img(StableDiffusionProcessing):
|
|||||||
if self.image_mask is not None:
|
if self.image_mask is not None:
|
||||||
init_mask = latent_mask
|
init_mask = latent_mask
|
||||||
latmask = init_mask.convert('RGB').resize((self.init_latent.shape[3], self.init_latent.shape[2]))
|
latmask = init_mask.convert('RGB').resize((self.init_latent.shape[3], self.init_latent.shape[2]))
|
||||||
latmask = np.moveaxis(np.array(latmask, dtype=np.float64), 2, 0) / 255
|
precision = np.float64
|
||||||
|
if shared.device.type == 'mps': # mps backend does not support float64
|
||||||
|
precision = np.float32
|
||||||
|
latmask = np.moveaxis(np.array(latmask, dtype=precision), 2, 0) / 255
|
||||||
latmask = latmask[0]
|
latmask = latmask[0]
|
||||||
latmask = np.around(latmask)
|
latmask = np.around(latmask)
|
||||||
latmask = np.tile(latmask[None], (4, 1, 1))
|
latmask = np.tile(latmask[None], (4, 1, 1))
|
||||||
|
@ -49,6 +49,8 @@ parser.add_argument("--opt-channelslast", action='store_true', help="change memo
|
|||||||
cmd_opts = parser.parse_args()
|
cmd_opts = parser.parse_args()
|
||||||
|
|
||||||
device = get_optimal_device()
|
device = get_optimal_device()
|
||||||
|
device_codeformer = device
|
||||||
|
device_seed_type = device
|
||||||
|
|
||||||
batch_cond_uncond = cmd_opts.always_batch_cond_uncond or not (cmd_opts.lowvram or cmd_opts.medvram)
|
batch_cond_uncond = cmd_opts.always_batch_cond_uncond or not (cmd_opts.lowvram or cmd_opts.medvram)
|
||||||
parallel_processing_allowed = not cmd_opts.lowvram and not cmd_opts.medvram
|
parallel_processing_allowed = not cmd_opts.lowvram and not cmd_opts.medvram
|
||||||
|
Loading…
Reference in New Issue
Block a user