mirror of
https://github.com/AUTOMATIC1111/stable-diffusion-webui.git
synced 2024-06-07 21:20:49 +00:00
Depth2img model support
This commit is contained in:
parent
44c46f0ed3
commit
1ed4f0e228
@ -135,6 +135,7 @@ The documentation was moved from this README over to the project's [wiki](https:
|
|||||||
- SwinIR - https://github.com/JingyunLiang/SwinIR
|
- SwinIR - https://github.com/JingyunLiang/SwinIR
|
||||||
- Swin2SR - https://github.com/mv-lab/swin2sr
|
- Swin2SR - https://github.com/mv-lab/swin2sr
|
||||||
- LDSR - https://github.com/Hafiidz/latent-diffusion
|
- LDSR - https://github.com/Hafiidz/latent-diffusion
|
||||||
|
- MiDaS - https://github.com/isl-org/MiDaS
|
||||||
- Ideas for optimizations - https://github.com/basujindal/stable-diffusion
|
- Ideas for optimizations - https://github.com/basujindal/stable-diffusion
|
||||||
- Cross Attention layer optimization - Doggettx - https://github.com/Doggettx/stable-diffusion, original idea for prompt editing.
|
- Cross Attention layer optimization - Doggettx - https://github.com/Doggettx/stable-diffusion, original idea for prompt editing.
|
||||||
- Cross Attention layer optimization - InvokeAI, lstein - https://github.com/invoke-ai/InvokeAI (originally http://github.com/lstein/stable-diffusion)
|
- Cross Attention layer optimization - InvokeAI, lstein - https://github.com/invoke-ai/InvokeAI (originally http://github.com/lstein/stable-diffusion)
|
||||||
|
@ -21,7 +21,10 @@ import modules.face_restoration
|
|||||||
import modules.images as images
|
import modules.images as images
|
||||||
import modules.styles
|
import modules.styles
|
||||||
import logging
|
import logging
|
||||||
|
from ldm.data.util import AddMiDaS
|
||||||
|
from ldm.models.diffusion.ddpm import LatentDepth2ImageDiffusion
|
||||||
|
|
||||||
|
from einops import repeat, rearrange
|
||||||
|
|
||||||
# some of those options should not be changed at all because they would break the model, so I removed them from options.
|
# some of those options should not be changed at all because they would break the model, so I removed them from options.
|
||||||
opt_C = 4
|
opt_C = 4
|
||||||
@ -150,11 +153,26 @@ class StableDiffusionProcessing():
|
|||||||
|
|
||||||
return image_conditioning
|
return image_conditioning
|
||||||
|
|
||||||
def img2img_image_conditioning(self, source_image, latent_image, image_mask = None):
|
def depth2img_image_conditioning(self, source_image):
|
||||||
if self.sampler.conditioning_key not in {'hybrid', 'concat'}:
|
# Use the AddMiDaS helper to Format our source image to suit the MiDaS model
|
||||||
# Dummy zero conditioning if we're not using inpainting model.
|
transformer = AddMiDaS(model_type="dpt_hybrid")
|
||||||
return latent_image.new_zeros(latent_image.shape[0], 5, 1, 1)
|
transformed = transformer({"jpg": rearrange(source_image[0], "c h w -> h w c")})
|
||||||
|
midas_in = torch.from_numpy(transformed["midas_in"][None, ...]).to(device=shared.device)
|
||||||
|
midas_in = repeat(midas_in, "1 ... -> n ...", n=self.batch_size)
|
||||||
|
|
||||||
|
conditioning_image = self.sd_model.get_first_stage_encoding(self.sd_model.encode_first_stage(source_image))
|
||||||
|
conditioning = torch.nn.functional.interpolate(
|
||||||
|
self.sd_model.depth_model(midas_in),
|
||||||
|
size=conditioning_image.shape[2:],
|
||||||
|
mode="bicubic",
|
||||||
|
align_corners=False,
|
||||||
|
)
|
||||||
|
|
||||||
|
(depth_min, depth_max) = torch.aminmax(conditioning)
|
||||||
|
conditioning = 2. * (conditioning - depth_min) / (depth_max - depth_min) - 1.
|
||||||
|
return conditioning
|
||||||
|
|
||||||
|
def inpainting_image_conditioning(self, source_image, latent_image, image_mask = None):
|
||||||
self.is_using_inpainting_conditioning = True
|
self.is_using_inpainting_conditioning = True
|
||||||
|
|
||||||
# Handle the different mask inputs
|
# Handle the different mask inputs
|
||||||
@ -191,6 +209,18 @@ class StableDiffusionProcessing():
|
|||||||
|
|
||||||
return image_conditioning
|
return image_conditioning
|
||||||
|
|
||||||
|
def img2img_image_conditioning(self, source_image, latent_image, image_mask=None):
|
||||||
|
# HACK: Using introspection as the Depth2Image model doesn't appear to uniquely
|
||||||
|
# identify itself with a field common to all models. The conditioning_key is also hybrid.
|
||||||
|
if isinstance(self.sd_model, LatentDepth2ImageDiffusion):
|
||||||
|
return self.depth2img_image_conditioning(source_image)
|
||||||
|
|
||||||
|
if self.sampler.conditioning_key in {'hybrid', 'concat'}:
|
||||||
|
return self.inpainting_image_conditioning(source_image, latent_image, image_mask=image_mask)
|
||||||
|
|
||||||
|
# Dummy zero conditioning if we're not using inpainting or depth model.
|
||||||
|
return latent_image.new_zeros(latent_image.shape[0], 5, 1, 1)
|
||||||
|
|
||||||
def init(self, all_prompts, all_seeds, all_subseeds):
|
def init(self, all_prompts, all_seeds, all_subseeds):
|
||||||
pass
|
pass
|
||||||
|
|
||||||
|
@ -7,6 +7,9 @@ import torch
|
|||||||
import re
|
import re
|
||||||
import safetensors.torch
|
import safetensors.torch
|
||||||
from omegaconf import OmegaConf
|
from omegaconf import OmegaConf
|
||||||
|
from os import mkdir
|
||||||
|
from urllib import request
|
||||||
|
import ldm.modules.midas as midas
|
||||||
|
|
||||||
from ldm.util import instantiate_from_config
|
from ldm.util import instantiate_from_config
|
||||||
|
|
||||||
@ -36,6 +39,7 @@ def setup_model():
|
|||||||
os.makedirs(model_path)
|
os.makedirs(model_path)
|
||||||
|
|
||||||
list_models()
|
list_models()
|
||||||
|
enable_midas_autodownload()
|
||||||
|
|
||||||
|
|
||||||
def checkpoint_tiles():
|
def checkpoint_tiles():
|
||||||
@ -227,6 +231,48 @@ def load_model_weights(model, checkpoint_info, vae_file="auto"):
|
|||||||
sd_vae.load_vae(model, vae_file)
|
sd_vae.load_vae(model, vae_file)
|
||||||
|
|
||||||
|
|
||||||
|
def enable_midas_autodownload():
|
||||||
|
"""
|
||||||
|
Gives the ldm.modules.midas.api.load_model function automatic downloading.
|
||||||
|
|
||||||
|
When the 512-depth-ema model, and other future models like it, is loaded,
|
||||||
|
it calls midas.api.load_model to load the associated midas depth model.
|
||||||
|
This function applies a wrapper to download the model to the correct
|
||||||
|
location automatically.
|
||||||
|
"""
|
||||||
|
|
||||||
|
midas_path = os.path.join(models_path, 'midas')
|
||||||
|
|
||||||
|
# stable-diffusion-stability-ai hard-codes the midas model path to
|
||||||
|
# a location that differs from where other scripts using this model look.
|
||||||
|
# HACK: Overriding the path here.
|
||||||
|
for k, v in midas.api.ISL_PATHS.items():
|
||||||
|
file_name = os.path.basename(v)
|
||||||
|
midas.api.ISL_PATHS[k] = os.path.join(midas_path, file_name)
|
||||||
|
|
||||||
|
midas_urls = {
|
||||||
|
"dpt_large": "https://github.com/intel-isl/DPT/releases/download/1_0/dpt_large-midas-2f21e586.pt",
|
||||||
|
"dpt_hybrid": "https://github.com/intel-isl/DPT/releases/download/1_0/dpt_hybrid-midas-501f0c75.pt",
|
||||||
|
"midas_v21": "https://github.com/AlexeyAB/MiDaS/releases/download/midas_dpt/midas_v21-f6b98070.pt",
|
||||||
|
"midas_v21_small": "https://github.com/AlexeyAB/MiDaS/releases/download/midas_dpt/midas_v21_small-70d6b9c8.pt",
|
||||||
|
}
|
||||||
|
|
||||||
|
midas.api.load_model_inner = midas.api.load_model
|
||||||
|
|
||||||
|
def load_model_wrapper(model_type):
|
||||||
|
path = midas.api.ISL_PATHS[model_type]
|
||||||
|
if not os.path.exists(path):
|
||||||
|
if not os.path.exists(midas_path):
|
||||||
|
mkdir(midas_path)
|
||||||
|
|
||||||
|
print(f"Downloading midas model weights for {model_type} to {path}")
|
||||||
|
request.urlretrieve(midas_urls[model_type], path)
|
||||||
|
print(f"{model_type} downloaded")
|
||||||
|
|
||||||
|
return midas.api.load_model_inner(model_type)
|
||||||
|
|
||||||
|
midas.api.load_model = load_model_wrapper
|
||||||
|
|
||||||
def load_model(checkpoint_info=None):
|
def load_model(checkpoint_info=None):
|
||||||
from modules import lowvram, sd_hijack
|
from modules import lowvram, sd_hijack
|
||||||
checkpoint_info = checkpoint_info or select_checkpoint()
|
checkpoint_info = checkpoint_info or select_checkpoint()
|
||||||
|
Loading…
Reference in New Issue
Block a user