From 51f81efb02876d24c9e6d844e8c0cbd2384f6514 Mon Sep 17 00:00:00 2001 From: InvincibleDude <81354513+InvincibleDude@users.noreply.github.com> Date: Wed, 1 Mar 2023 21:30:20 +0300 Subject: [PATCH] Image processing changes Image processing changes --- modules/processing.py | 76 +++++++++++++++++++++---------------------- 1 file changed, 38 insertions(+), 38 deletions(-) diff --git a/modules/processing.py b/modules/processing.py index 43fac21c1..72dd3f6f5 100644 --- a/modules/processing.py +++ b/modules/processing.py @@ -528,7 +528,9 @@ def process_images_inner(p: StableDiffusionProcessing) -> Processed: p.all_negative_prompts = p.batch_size * p.n_iter * [shared.prompt_styles.apply_negative_styles_to_prompt(p.negative_prompt, p.styles)] if type(p) == StableDiffusionProcessingTxt2Img: - if p.enable_hr and p.hr_prompt != '': + if p.enable_hr and p.hr_prompt == '': + p.all_hr_prompts, p.all_hr_negative_prompts = p.all_prompts, p.all_negative_prompts + elif p.enable_hr and p.hr_prompt != '': if type(p.prompt) == list: p.all_hr_prompts = [shared.prompt_styles.apply_styles_to_prompt(x, p.styles) for x in p.hr_prompt] else: @@ -555,14 +557,6 @@ def process_images_inner(p: StableDiffusionProcessing) -> Processed: if os.path.exists(cmd_opts.embeddings_dir) and not p.do_not_reload_embeddings: model_hijack.embedding_db.load_textual_inversion_embeddings() - _, extra_network_data = extra_networks.parse_prompts(p.all_prompts[0:1]) - if type(p) == StableDiffusionProcessingTxt2Img: - if p.enable_hr and p.hr_prompt != '': - _, hr_extra_network_data = extra_networks.parse_prompts(p.all_hr_prompts[0:1]) - if p.all_hr_prompts != p.all_prompts: - extra_network_data.update(hr_extra_network_data) - - if p.scripts is not None: p.scripts.process(p) @@ -600,13 +594,6 @@ def process_images_inner(p: StableDiffusionProcessing) -> Processed: if shared.opts.live_previews_enable and opts.show_progress_type == "Approx NN": sd_vae_approx.model() - if not p.disable_extra_networks: - extra_networks.activate(p, extra_network_data) - - with open(os.path.join(paths.data_path, "params.txt"), "w", encoding="utf8") as file: - processed = Processed(p, [], p.seed, "") - file.write(processed.infotext(p, 0)) - if state.job_count == -1: state.job_count = p.n_iter @@ -623,9 +610,12 @@ def process_images_inner(p: StableDiffusionProcessing) -> Processed: negative_prompts = p.all_negative_prompts[n * p.batch_size:(n + 1) * p.batch_size] if type(p) == StableDiffusionProcessingTxt2Img: - if p.enable_hr and p.hr_prompt != '': - hr_prompts = p.all_hr_prompts[n * p.batch_size:(n + 1) * p.batch_size] - hr_negative_prompts = p.all_negative_prompts[n * p.batch_size:(n + 1) * p.batch_size] + if p.enable_hr: + if p.hr_prompt == '': + hr_prompts, hr_negative_prompts = prompts, negative_prompts + else: + hr_prompts = p.all_hr_prompts[n * p.batch_size:(n + 1) * p.batch_size] + hr_negative_prompts = p.all_negative_prompts[n * p.batch_size:(n + 1) * p.batch_size] seeds = p.all_seeds[n * p.batch_size:(n + 1) * p.batch_size] subseeds = p.all_subseeds[n * p.batch_size:(n + 1) * p.batch_size] @@ -633,19 +623,40 @@ def process_images_inner(p: StableDiffusionProcessing) -> Processed: if len(prompts) == 0: break - prompts, _ = extra_networks.parse_prompts(prompts) + prompts, extra_network_data = extra_networks.parse_prompts(prompts) + if type(p) == StableDiffusionProcessingTxt2Img: + if p.enable_hr and hr_prompts != prompts: + _, hr_extra_network_data = extra_networks.parse_prompts(hr_prompts) + extra_network_data.update(hr_extra_network_data) + + + + if not p.disable_extra_networks: + with devices.autocast(): + extra_networks.activate(p, extra_network_data) if p.scripts is not None: p.scripts.process_batch(p, batch_number=n, prompts=prompts, seeds=seeds, subseeds=subseeds) + # params.txt should be saved after scripts.process_batch, since the + # infotext could be modified by that callback + # Example: a wildcard processed by process_batch sets an extra model + # strength, which is saved as "Model Strength: 1.0" in the infotext + if n == 0: + with open(os.path.join(paths.data_path, "params.txt"), "w", encoding="utf8") as file: + processed = Processed(p, [], p.seed, "") + file.write(processed.infotext(p, 0)) + uc = get_conds_with_caching(prompt_parser.get_learned_conditioning, negative_prompts, p.steps, cached_uc) c = get_conds_with_caching(prompt_parser.get_multicond_learned_conditioning, prompts, p.steps, cached_c) + if type(p) == StableDiffusionProcessingTxt2Img: - if p.enable_hr and p.hr_prompt != '': - hr_uc = get_conds_with_caching(prompt_parser.get_learned_conditioning, hr_negative_prompts, p.steps, - cached_uc) - hr_c = get_conds_with_caching(prompt_parser.get_multicond_learned_conditioning, hr_prompts, p.steps, - cached_c) + if p.enable_hr: + if prompts != hr_prompts: + hr_uc = get_conds_with_caching(prompt_parser.get_learned_conditioning, hr_negative_prompts, p.steps, cached_uc) + hr_c = get_conds_with_caching(prompt_parser.get_multicond_learned_conditioning, hr_prompts, p.steps, cached_c) + else: + hr_uc, hr_c = uc, c if len(model_hijack.comments) > 0: for comment in model_hijack.comments: @@ -658,20 +669,9 @@ def process_images_inner(p: StableDiffusionProcessing) -> Processed: with devices.without_autocast() if devices.unet_needs_upcast else devices.autocast(): if type(p) == StableDiffusionProcessingTxt2Img: if p.enable_hr: - if p.hr_prompt != '': - samples_ddim = p.sample(conditioning=c, unconditional_conditioning=uc, hr_conditioning=hr_c, hr_unconditional_conditioning=hr_uc, seeds=seeds, subseeds=subseeds, subseed_strength=p.subseed_strength, prompts=prompts) - else: - samples_ddim = p.sample(conditioning=c, unconditional_conditioning=uc, hr_conditioning=c, - hr_unconditional_conditioning=uc, seeds=seeds, subseeds=subseeds, - subseed_strength=p.subseed_strength, prompts=prompts) - else: - samples_ddim = p.sample(conditioning=c, unconditional_conditioning=uc, seeds=seeds, subseeds=subseeds, - subseed_strength=p.subseed_strength, prompts=prompts) - + samples_ddim = p.sample(conditioning=c, unconditional_conditioning=uc, hr_conditioning=hr_c, hr_unconditional_conditioning=hr_uc, seeds=seeds, subseeds=subseeds, subseed_strength=p.subseed_strength, prompts=prompts) else: - samples_ddim = p.sample(conditioning=c, unconditional_conditioning=uc, seeds=seeds, - subseeds=subseeds, - subseed_strength=p.subseed_strength, prompts=prompts) + samples_ddim = p.sample(conditioning=c, unconditional_conditioning=uc, seeds=seeds, subseeds=subseeds, subseed_strength=p.subseed_strength, prompts=prompts) x_samples_ddim = [decode_first_stage(p.sd_model, samples_ddim[i:i+1].to(dtype=devices.dtype_vae))[0].cpu() for i in range(samples_ddim.size(0))] for x in x_samples_ddim: