mirror of
https://github.com/AUTOMATIC1111/stable-diffusion-webui.git
synced 2024-06-07 21:20:49 +00:00
Merge branch 'AUTOMATIC1111:master' into master
This commit is contained in:
commit
63391419c1
99
configs/instruct-pix2pix.yaml
Normal file
99
configs/instruct-pix2pix.yaml
Normal file
@ -0,0 +1,99 @@
|
|||||||
|
# File modified by authors of InstructPix2Pix from original (https://github.com/CompVis/stable-diffusion).
|
||||||
|
# See more details in LICENSE.
|
||||||
|
|
||||||
|
model:
|
||||||
|
base_learning_rate: 1.0e-04
|
||||||
|
target: modules.models.diffusion.ddpm_edit.LatentDiffusion
|
||||||
|
params:
|
||||||
|
linear_start: 0.00085
|
||||||
|
linear_end: 0.0120
|
||||||
|
num_timesteps_cond: 1
|
||||||
|
log_every_t: 200
|
||||||
|
timesteps: 1000
|
||||||
|
first_stage_key: edited
|
||||||
|
cond_stage_key: edit
|
||||||
|
# image_size: 64
|
||||||
|
# image_size: 32
|
||||||
|
image_size: 16
|
||||||
|
channels: 4
|
||||||
|
cond_stage_trainable: false # Note: different from the one we trained before
|
||||||
|
conditioning_key: hybrid
|
||||||
|
monitor: val/loss_simple_ema
|
||||||
|
scale_factor: 0.18215
|
||||||
|
use_ema: true
|
||||||
|
load_ema: true
|
||||||
|
|
||||||
|
scheduler_config: # 10000 warmup steps
|
||||||
|
target: ldm.lr_scheduler.LambdaLinearScheduler
|
||||||
|
params:
|
||||||
|
warm_up_steps: [ 0 ]
|
||||||
|
cycle_lengths: [ 10000000000000 ] # incredibly large number to prevent corner cases
|
||||||
|
f_start: [ 1.e-6 ]
|
||||||
|
f_max: [ 1. ]
|
||||||
|
f_min: [ 1. ]
|
||||||
|
|
||||||
|
unet_config:
|
||||||
|
target: ldm.modules.diffusionmodules.openaimodel.UNetModel
|
||||||
|
params:
|
||||||
|
image_size: 32 # unused
|
||||||
|
in_channels: 8
|
||||||
|
out_channels: 4
|
||||||
|
model_channels: 320
|
||||||
|
attention_resolutions: [ 4, 2, 1 ]
|
||||||
|
num_res_blocks: 2
|
||||||
|
channel_mult: [ 1, 2, 4, 4 ]
|
||||||
|
num_heads: 8
|
||||||
|
use_spatial_transformer: True
|
||||||
|
transformer_depth: 1
|
||||||
|
context_dim: 768
|
||||||
|
use_checkpoint: True
|
||||||
|
legacy: False
|
||||||
|
|
||||||
|
first_stage_config:
|
||||||
|
target: ldm.models.autoencoder.AutoencoderKL
|
||||||
|
params:
|
||||||
|
embed_dim: 4
|
||||||
|
monitor: val/rec_loss
|
||||||
|
ddconfig:
|
||||||
|
double_z: true
|
||||||
|
z_channels: 4
|
||||||
|
resolution: 256
|
||||||
|
in_channels: 3
|
||||||
|
out_ch: 3
|
||||||
|
ch: 128
|
||||||
|
ch_mult:
|
||||||
|
- 1
|
||||||
|
- 2
|
||||||
|
- 4
|
||||||
|
- 4
|
||||||
|
num_res_blocks: 2
|
||||||
|
attn_resolutions: []
|
||||||
|
dropout: 0.0
|
||||||
|
lossconfig:
|
||||||
|
target: torch.nn.Identity
|
||||||
|
|
||||||
|
cond_stage_config:
|
||||||
|
target: ldm.modules.encoders.modules.FrozenCLIPEmbedder
|
||||||
|
|
||||||
|
data:
|
||||||
|
target: main.DataModuleFromConfig
|
||||||
|
params:
|
||||||
|
batch_size: 128
|
||||||
|
num_workers: 1
|
||||||
|
wrap: false
|
||||||
|
validation:
|
||||||
|
target: edit_dataset.EditDataset
|
||||||
|
params:
|
||||||
|
path: data/clip-filtered-dataset
|
||||||
|
cache_dir: data/
|
||||||
|
cache_name: data_10k
|
||||||
|
split: val
|
||||||
|
min_text_sim: 0.2
|
||||||
|
min_image_sim: 0.75
|
||||||
|
min_direction_sim: 0.2
|
||||||
|
max_samples_per_prompt: 1
|
||||||
|
min_resize_res: 512
|
||||||
|
max_resize_res: 512
|
||||||
|
crop_res: 512
|
||||||
|
output_as_edit: False
|
||||||
|
real_input: True
|
@ -1,8 +1,7 @@
|
|||||||
model:
|
model:
|
||||||
base_learning_rate: 1.0e-4
|
base_learning_rate: 7.5e-05
|
||||||
target: ldm.models.diffusion.ddpm.LatentDiffusion
|
target: ldm.models.diffusion.ddpm.LatentInpaintDiffusion
|
||||||
params:
|
params:
|
||||||
parameterization: "v"
|
|
||||||
linear_start: 0.00085
|
linear_start: 0.00085
|
||||||
linear_end: 0.0120
|
linear_end: 0.0120
|
||||||
num_timesteps_cond: 1
|
num_timesteps_cond: 1
|
||||||
@ -12,29 +11,36 @@ model:
|
|||||||
cond_stage_key: "txt"
|
cond_stage_key: "txt"
|
||||||
image_size: 64
|
image_size: 64
|
||||||
channels: 4
|
channels: 4
|
||||||
cond_stage_trainable: false
|
cond_stage_trainable: false # Note: different from the one we trained before
|
||||||
conditioning_key: crossattn
|
conditioning_key: hybrid # important
|
||||||
monitor: val/loss_simple_ema
|
monitor: val/loss_simple_ema
|
||||||
scale_factor: 0.18215
|
scale_factor: 0.18215
|
||||||
use_ema: False # we set this to false because this is an inference only config
|
finetune_keys: null
|
||||||
|
|
||||||
|
scheduler_config: # 10000 warmup steps
|
||||||
|
target: ldm.lr_scheduler.LambdaLinearScheduler
|
||||||
|
params:
|
||||||
|
warm_up_steps: [ 2500 ] # NOTE for resuming. use 10000 if starting from scratch
|
||||||
|
cycle_lengths: [ 10000000000000 ] # incredibly large number to prevent corner cases
|
||||||
|
f_start: [ 1.e-6 ]
|
||||||
|
f_max: [ 1. ]
|
||||||
|
f_min: [ 1. ]
|
||||||
|
|
||||||
unet_config:
|
unet_config:
|
||||||
target: ldm.modules.diffusionmodules.openaimodel.UNetModel
|
target: ldm.modules.diffusionmodules.openaimodel.UNetModel
|
||||||
params:
|
params:
|
||||||
use_checkpoint: True
|
|
||||||
use_fp16: True
|
|
||||||
image_size: 32 # unused
|
image_size: 32 # unused
|
||||||
in_channels: 4
|
in_channels: 9 # 4 data + 4 downscaled image + 1 mask
|
||||||
out_channels: 4
|
out_channels: 4
|
||||||
model_channels: 320
|
model_channels: 320
|
||||||
attention_resolutions: [ 4, 2, 1 ]
|
attention_resolutions: [ 4, 2, 1 ]
|
||||||
num_res_blocks: 2
|
num_res_blocks: 2
|
||||||
channel_mult: [ 1, 2, 4, 4 ]
|
channel_mult: [ 1, 2, 4, 4 ]
|
||||||
num_head_channels: 64 # need to fix for flash-attn
|
num_heads: 8
|
||||||
use_spatial_transformer: True
|
use_spatial_transformer: True
|
||||||
use_linear_in_transformer: True
|
|
||||||
transformer_depth: 1
|
transformer_depth: 1
|
||||||
context_dim: 1024
|
context_dim: 768
|
||||||
|
use_checkpoint: True
|
||||||
legacy: False
|
legacy: False
|
||||||
|
|
||||||
first_stage_config:
|
first_stage_config:
|
||||||
@ -43,7 +49,6 @@ model:
|
|||||||
embed_dim: 4
|
embed_dim: 4
|
||||||
monitor: val/rec_loss
|
monitor: val/rec_loss
|
||||||
ddconfig:
|
ddconfig:
|
||||||
#attn_type: "vanilla-xformers"
|
|
||||||
double_z: true
|
double_z: true
|
||||||
z_channels: 4
|
z_channels: 4
|
||||||
resolution: 256
|
resolution: 256
|
||||||
@ -62,7 +67,4 @@ model:
|
|||||||
target: torch.nn.Identity
|
target: torch.nn.Identity
|
||||||
|
|
||||||
cond_stage_config:
|
cond_stage_config:
|
||||||
target: ldm.modules.encoders.modules.FrozenOpenCLIPEmbedder
|
target: ldm.modules.encoders.modules.FrozenCLIPEmbedder
|
||||||
params:
|
|
||||||
freeze: True
|
|
||||||
layer: "penultimate"
|
|
@ -18,7 +18,8 @@ from modules.textual_inversion.textual_inversion import create_embedding, train_
|
|||||||
from modules.textual_inversion.preprocess import preprocess
|
from modules.textual_inversion.preprocess import preprocess
|
||||||
from modules.hypernetworks.hypernetwork import create_hypernetwork, train_hypernetwork
|
from modules.hypernetworks.hypernetwork import create_hypernetwork, train_hypernetwork
|
||||||
from PIL import PngImagePlugin,Image
|
from PIL import PngImagePlugin,Image
|
||||||
from modules.sd_models import checkpoints_list, find_checkpoint_config
|
from modules.sd_models import checkpoints_list
|
||||||
|
from modules.sd_models_config import find_checkpoint_config_near_filename
|
||||||
from modules.realesrgan_model import get_realesrgan_models
|
from modules.realesrgan_model import get_realesrgan_models
|
||||||
from modules import devices
|
from modules import devices
|
||||||
from typing import List
|
from typing import List
|
||||||
@ -387,7 +388,7 @@ class Api:
|
|||||||
]
|
]
|
||||||
|
|
||||||
def get_sd_models(self):
|
def get_sd_models(self):
|
||||||
return [{"title": x.title, "model_name": x.model_name, "hash": x.shorthash, "sha256": x.sha256, "filename": x.filename, "config": find_checkpoint_config(x)} for x in checkpoints_list.values()]
|
return [{"title": x.title, "model_name": x.model_name, "hash": x.shorthash, "sha256": x.sha256, "filename": x.filename, "config": find_checkpoint_config_near_filename(x)} for x in checkpoints_list.values()]
|
||||||
|
|
||||||
def get_hypernetworks(self):
|
def get_hypernetworks(self):
|
||||||
return [{"name": name, "path": shared.hypernetworks[name]} for name in shared.hypernetworks]
|
return [{"name": name, "path": shared.hypernetworks[name]} for name in shared.hypernetworks]
|
||||||
|
@ -228,7 +228,7 @@ class SDModelItem(BaseModel):
|
|||||||
hash: Optional[str] = Field(title="Short hash")
|
hash: Optional[str] = Field(title="Short hash")
|
||||||
sha256: Optional[str] = Field(title="sha256 hash")
|
sha256: Optional[str] = Field(title="sha256 hash")
|
||||||
filename: str = Field(title="Filename")
|
filename: str = Field(title="Filename")
|
||||||
config: str = Field(title="Config file")
|
config: Optional[str] = Field(title="Config file")
|
||||||
|
|
||||||
class HypernetworkItem(BaseModel):
|
class HypernetworkItem(BaseModel):
|
||||||
name: str = Field(title="Name")
|
name: str = Field(title="Name")
|
||||||
|
@ -34,14 +34,18 @@ def get_cuda_device_string():
|
|||||||
return "cuda"
|
return "cuda"
|
||||||
|
|
||||||
|
|
||||||
def get_optimal_device():
|
def get_optimal_device_name():
|
||||||
if torch.cuda.is_available():
|
if torch.cuda.is_available():
|
||||||
return torch.device(get_cuda_device_string())
|
return get_cuda_device_string()
|
||||||
|
|
||||||
if has_mps():
|
if has_mps():
|
||||||
return torch.device("mps")
|
return "mps"
|
||||||
|
|
||||||
return cpu
|
return "cpu"
|
||||||
|
|
||||||
|
|
||||||
|
def get_optimal_device():
|
||||||
|
return torch.device(get_optimal_device_name())
|
||||||
|
|
||||||
|
|
||||||
def get_device_for(task):
|
def get_device_for(task):
|
||||||
@ -139,6 +143,8 @@ def test_for_nans(x, where):
|
|||||||
else:
|
else:
|
||||||
message = "A tensor with all NaNs was produced."
|
message = "A tensor with all NaNs was produced."
|
||||||
|
|
||||||
|
message += " Use --disable-nan-check commandline argument to disable this check."
|
||||||
|
|
||||||
raise NansException(message)
|
raise NansException(message)
|
||||||
|
|
||||||
|
|
||||||
|
@ -13,7 +13,7 @@ from skimage import exposure
|
|||||||
from typing import Any, Dict, List, Optional
|
from typing import Any, Dict, List, Optional
|
||||||
|
|
||||||
import modules.sd_hijack
|
import modules.sd_hijack
|
||||||
from modules import devices, prompt_parser, masking, sd_samplers, lowvram, generation_parameters_copypaste, script_callbacks, extra_networks, sd_vae_approx
|
from modules import devices, prompt_parser, masking, sd_samplers, lowvram, generation_parameters_copypaste, script_callbacks, extra_networks, sd_vae_approx, scripts
|
||||||
from modules.sd_hijack import model_hijack
|
from modules.sd_hijack import model_hijack
|
||||||
from modules.shared import opts, cmd_opts, state
|
from modules.shared import opts, cmd_opts, state
|
||||||
import modules.shared as shared
|
import modules.shared as shared
|
||||||
@ -172,7 +172,7 @@ class StableDiffusionProcessing:
|
|||||||
midas_in = torch.from_numpy(transformed["midas_in"][None, ...]).to(device=shared.device)
|
midas_in = torch.from_numpy(transformed["midas_in"][None, ...]).to(device=shared.device)
|
||||||
midas_in = repeat(midas_in, "1 ... -> n ...", n=self.batch_size)
|
midas_in = repeat(midas_in, "1 ... -> n ...", n=self.batch_size)
|
||||||
|
|
||||||
conditioning_image = self.sd_model.get_first_stage_encoding(self.sd_model.encode_first_stage(source_image.to(devices.dtype_unet) if devices.unet_needs_upcast else source_image))
|
conditioning_image = self.sd_model.get_first_stage_encoding(self.sd_model.encode_first_stage(source_image.to(devices.dtype_vae) if devices.unet_needs_upcast else source_image))
|
||||||
conditioning_image = conditioning_image.float() if devices.unet_needs_upcast else conditioning_image
|
conditioning_image = conditioning_image.float() if devices.unet_needs_upcast else conditioning_image
|
||||||
conditioning = torch.nn.functional.interpolate(
|
conditioning = torch.nn.functional.interpolate(
|
||||||
self.sd_model.depth_model(midas_in),
|
self.sd_model.depth_model(midas_in),
|
||||||
@ -185,6 +185,11 @@ class StableDiffusionProcessing:
|
|||||||
conditioning = 2. * (conditioning - depth_min) / (depth_max - depth_min) - 1.
|
conditioning = 2. * (conditioning - depth_min) / (depth_max - depth_min) - 1.
|
||||||
return conditioning
|
return conditioning
|
||||||
|
|
||||||
|
def edit_image_conditioning(self, source_image):
|
||||||
|
conditioning_image = self.sd_model.get_first_stage_encoding(self.sd_model.encode_first_stage(source_image))
|
||||||
|
|
||||||
|
return conditioning_image
|
||||||
|
|
||||||
def inpainting_image_conditioning(self, source_image, latent_image, image_mask=None):
|
def inpainting_image_conditioning(self, source_image, latent_image, image_mask=None):
|
||||||
self.is_using_inpainting_conditioning = True
|
self.is_using_inpainting_conditioning = True
|
||||||
|
|
||||||
@ -212,7 +217,7 @@ class StableDiffusionProcessing:
|
|||||||
)
|
)
|
||||||
|
|
||||||
# Encode the new masked image using first stage of network.
|
# Encode the new masked image using first stage of network.
|
||||||
conditioning_image = self.sd_model.get_first_stage_encoding(self.sd_model.encode_first_stage(conditioning_image.to(devices.dtype_unet) if devices.unet_needs_upcast else conditioning_image))
|
conditioning_image = self.sd_model.get_first_stage_encoding(self.sd_model.encode_first_stage(conditioning_image.to(devices.dtype_vae) if devices.unet_needs_upcast else conditioning_image))
|
||||||
|
|
||||||
# Create the concatenated conditioning tensor to be fed to `c_concat`
|
# Create the concatenated conditioning tensor to be fed to `c_concat`
|
||||||
conditioning_mask = torch.nn.functional.interpolate(conditioning_mask, size=latent_image.shape[-2:])
|
conditioning_mask = torch.nn.functional.interpolate(conditioning_mask, size=latent_image.shape[-2:])
|
||||||
@ -228,6 +233,9 @@ class StableDiffusionProcessing:
|
|||||||
if isinstance(self.sd_model, LatentDepth2ImageDiffusion):
|
if isinstance(self.sd_model, LatentDepth2ImageDiffusion):
|
||||||
return self.depth2img_image_conditioning(source_image.float() if devices.unet_needs_upcast else source_image)
|
return self.depth2img_image_conditioning(source_image.float() if devices.unet_needs_upcast else source_image)
|
||||||
|
|
||||||
|
if self.sd_model.cond_stage_key == "edit":
|
||||||
|
return self.edit_image_conditioning(source_image)
|
||||||
|
|
||||||
if self.sampler.conditioning_key in {'hybrid', 'concat'}:
|
if self.sampler.conditioning_key in {'hybrid', 'concat'}:
|
||||||
return self.inpainting_image_conditioning(source_image.float() if devices.unet_needs_upcast else source_image, latent_image, image_mask=image_mask)
|
return self.inpainting_image_conditioning(source_image.float() if devices.unet_needs_upcast else source_image, latent_image, image_mask=image_mask)
|
||||||
|
|
||||||
@ -409,7 +417,7 @@ def create_random_tensors(shape, seeds, subseeds=None, subseed_strength=0.0, see
|
|||||||
|
|
||||||
def decode_first_stage(model, x):
|
def decode_first_stage(model, x):
|
||||||
with devices.autocast(disable=x.dtype == devices.dtype_vae):
|
with devices.autocast(disable=x.dtype == devices.dtype_vae):
|
||||||
x = model.decode_first_stage(x)
|
x = model.decode_first_stage(x.to(devices.dtype_vae) if devices.unet_needs_upcast else x)
|
||||||
|
|
||||||
return x
|
return x
|
||||||
|
|
||||||
@ -650,6 +658,11 @@ def process_images_inner(p: StableDiffusionProcessing) -> Processed:
|
|||||||
|
|
||||||
image = Image.fromarray(x_sample)
|
image = Image.fromarray(x_sample)
|
||||||
|
|
||||||
|
if p.scripts is not None:
|
||||||
|
pp = scripts.PostprocessImageArgs(image)
|
||||||
|
p.scripts.postprocess_image(p, pp)
|
||||||
|
image = pp.image
|
||||||
|
|
||||||
if p.color_corrections is not None and i < len(p.color_corrections):
|
if p.color_corrections is not None and i < len(p.color_corrections):
|
||||||
if opts.save and not p.do_not_save_samples and opts.save_images_before_color_correction:
|
if opts.save and not p.do_not_save_samples and opts.save_images_before_color_correction:
|
||||||
image_without_cc = apply_overlay(image, p.paste_to, i, p.overlay_images)
|
image_without_cc = apply_overlay(image, p.paste_to, i, p.overlay_images)
|
||||||
@ -993,7 +1006,7 @@ class StableDiffusionProcessingImg2Img(StableDiffusionProcessing):
|
|||||||
|
|
||||||
image = torch.from_numpy(batch_images)
|
image = torch.from_numpy(batch_images)
|
||||||
image = 2. * image - 1.
|
image = 2. * image - 1.
|
||||||
image = image.to(device=shared.device, dtype=devices.dtype_unet if devices.unet_needs_upcast else None)
|
image = image.to(device=shared.device, dtype=devices.dtype_vae if devices.unet_needs_upcast else None)
|
||||||
|
|
||||||
self.init_latent = self.sd_model.get_first_stage_encoding(self.sd_model.encode_first_stage(image))
|
self.init_latent = self.sd_model.get_first_stage_encoding(self.sd_model.encode_first_stage(image))
|
||||||
|
|
||||||
|
@ -6,12 +6,16 @@ from collections import namedtuple
|
|||||||
|
|
||||||
import gradio as gr
|
import gradio as gr
|
||||||
|
|
||||||
from modules.processing import StableDiffusionProcessing
|
|
||||||
from modules import shared, paths, script_callbacks, extensions, script_loading, scripts_postprocessing
|
from modules import shared, paths, script_callbacks, extensions, script_loading, scripts_postprocessing
|
||||||
|
|
||||||
AlwaysVisible = object()
|
AlwaysVisible = object()
|
||||||
|
|
||||||
|
|
||||||
|
class PostprocessImageArgs:
|
||||||
|
def __init__(self, image):
|
||||||
|
self.image = image
|
||||||
|
|
||||||
|
|
||||||
class Script:
|
class Script:
|
||||||
filename = None
|
filename = None
|
||||||
args_from = None
|
args_from = None
|
||||||
@ -65,7 +69,7 @@ class Script:
|
|||||||
args contains all values returned by components from ui()
|
args contains all values returned by components from ui()
|
||||||
"""
|
"""
|
||||||
|
|
||||||
raise NotImplementedError()
|
pass
|
||||||
|
|
||||||
def process(self, p, *args):
|
def process(self, p, *args):
|
||||||
"""
|
"""
|
||||||
@ -100,6 +104,13 @@ class Script:
|
|||||||
|
|
||||||
pass
|
pass
|
||||||
|
|
||||||
|
def postprocess_image(self, p, pp: PostprocessImageArgs, *args):
|
||||||
|
"""
|
||||||
|
Called for every image after it has been generated.
|
||||||
|
"""
|
||||||
|
|
||||||
|
pass
|
||||||
|
|
||||||
def postprocess(self, p, processed, *args):
|
def postprocess(self, p, processed, *args):
|
||||||
"""
|
"""
|
||||||
This function is called after processing ends for AlwaysVisible scripts.
|
This function is called after processing ends for AlwaysVisible scripts.
|
||||||
@ -247,11 +258,15 @@ class ScriptRunner:
|
|||||||
self.infotext_fields = []
|
self.infotext_fields = []
|
||||||
|
|
||||||
def initialize_scripts(self, is_img2img):
|
def initialize_scripts(self, is_img2img):
|
||||||
|
from modules import scripts_auto_postprocessing
|
||||||
|
|
||||||
self.scripts.clear()
|
self.scripts.clear()
|
||||||
self.alwayson_scripts.clear()
|
self.alwayson_scripts.clear()
|
||||||
self.selectable_scripts.clear()
|
self.selectable_scripts.clear()
|
||||||
|
|
||||||
for script_class, path, basedir, script_module in scripts_data:
|
auto_processing_scripts = scripts_auto_postprocessing.create_auto_preprocessing_script_data()
|
||||||
|
|
||||||
|
for script_class, path, basedir, script_module in auto_processing_scripts + scripts_data:
|
||||||
script = script_class()
|
script = script_class()
|
||||||
script.filename = path
|
script.filename = path
|
||||||
script.is_txt2img = not is_img2img
|
script.is_txt2img = not is_img2img
|
||||||
@ -332,7 +347,7 @@ class ScriptRunner:
|
|||||||
|
|
||||||
return inputs
|
return inputs
|
||||||
|
|
||||||
def run(self, p: StableDiffusionProcessing, *args):
|
def run(self, p, *args):
|
||||||
script_index = args[0]
|
script_index = args[0]
|
||||||
|
|
||||||
if script_index == 0:
|
if script_index == 0:
|
||||||
@ -386,6 +401,15 @@ class ScriptRunner:
|
|||||||
print(f"Error running postprocess_batch: {script.filename}", file=sys.stderr)
|
print(f"Error running postprocess_batch: {script.filename}", file=sys.stderr)
|
||||||
print(traceback.format_exc(), file=sys.stderr)
|
print(traceback.format_exc(), file=sys.stderr)
|
||||||
|
|
||||||
|
def postprocess_image(self, p, pp: PostprocessImageArgs):
|
||||||
|
for script in self.alwayson_scripts:
|
||||||
|
try:
|
||||||
|
script_args = p.script_args[script.args_from:script.args_to]
|
||||||
|
script.postprocess_image(p, pp, *script_args)
|
||||||
|
except Exception:
|
||||||
|
print(f"Error running postprocess_batch: {script.filename}", file=sys.stderr)
|
||||||
|
print(traceback.format_exc(), file=sys.stderr)
|
||||||
|
|
||||||
def before_component(self, component, **kwargs):
|
def before_component(self, component, **kwargs):
|
||||||
for script in self.scripts:
|
for script in self.scripts:
|
||||||
try:
|
try:
|
||||||
|
42
modules/scripts_auto_postprocessing.py
Normal file
42
modules/scripts_auto_postprocessing.py
Normal file
@ -0,0 +1,42 @@
|
|||||||
|
from modules import scripts, scripts_postprocessing, shared
|
||||||
|
|
||||||
|
|
||||||
|
class ScriptPostprocessingForMainUI(scripts.Script):
|
||||||
|
def __init__(self, script_postproc):
|
||||||
|
self.script: scripts_postprocessing.ScriptPostprocessing = script_postproc
|
||||||
|
self.postprocessing_controls = None
|
||||||
|
|
||||||
|
def title(self):
|
||||||
|
return self.script.name
|
||||||
|
|
||||||
|
def show(self, is_img2img):
|
||||||
|
return scripts.AlwaysVisible
|
||||||
|
|
||||||
|
def ui(self, is_img2img):
|
||||||
|
self.postprocessing_controls = self.script.ui()
|
||||||
|
return self.postprocessing_controls.values()
|
||||||
|
|
||||||
|
def postprocess_image(self, p, script_pp, *args):
|
||||||
|
args_dict = {k: v for k, v in zip(self.postprocessing_controls, args)}
|
||||||
|
|
||||||
|
pp = scripts_postprocessing.PostprocessedImage(script_pp.image)
|
||||||
|
pp.info = {}
|
||||||
|
self.script.process(pp, **args_dict)
|
||||||
|
p.extra_generation_params.update(pp.info)
|
||||||
|
script_pp.image = pp.image
|
||||||
|
|
||||||
|
|
||||||
|
def create_auto_preprocessing_script_data():
|
||||||
|
from modules import scripts
|
||||||
|
|
||||||
|
res = []
|
||||||
|
|
||||||
|
for name in shared.opts.postprocessing_enable_in_main_ui:
|
||||||
|
script = next(iter([x for x in scripts.postprocessing_scripts_data if x.script_class.name == name]), None)
|
||||||
|
if script is None:
|
||||||
|
continue
|
||||||
|
|
||||||
|
constructor = lambda s=script: ScriptPostprocessingForMainUI(s.script_class())
|
||||||
|
res.append(scripts.ScriptClassData(script_class=constructor, path=script.path, basedir=script.basedir, module=script.module))
|
||||||
|
|
||||||
|
return res
|
@ -46,6 +46,8 @@ class ScriptPostprocessing:
|
|||||||
pass
|
pass
|
||||||
|
|
||||||
|
|
||||||
|
|
||||||
|
|
||||||
def wrap_call(func, filename, funcname, *args, default=None, **kwargs):
|
def wrap_call(func, filename, funcname, *args, default=None, **kwargs):
|
||||||
try:
|
try:
|
||||||
res = func(*args, **kwargs)
|
res = func(*args, **kwargs)
|
||||||
@ -68,6 +70,9 @@ class ScriptPostprocessingRunner:
|
|||||||
script: ScriptPostprocessing = script_class()
|
script: ScriptPostprocessing = script_class()
|
||||||
script.filename = path
|
script.filename = path
|
||||||
|
|
||||||
|
if script.name == "Simple Upscale":
|
||||||
|
continue
|
||||||
|
|
||||||
self.scripts.append(script)
|
self.scripts.append(script)
|
||||||
|
|
||||||
def create_script_ui(self, script, inputs):
|
def create_script_ui(self, script, inputs):
|
||||||
@ -87,12 +92,11 @@ class ScriptPostprocessingRunner:
|
|||||||
import modules.scripts
|
import modules.scripts
|
||||||
self.initialize_scripts(modules.scripts.postprocessing_scripts_data)
|
self.initialize_scripts(modules.scripts.postprocessing_scripts_data)
|
||||||
|
|
||||||
scripts_order = [x.lower().strip() for x in shared.opts.postprocessing_scipts_order.split(",")]
|
scripts_order = shared.opts.postprocessing_operation_order
|
||||||
|
|
||||||
def script_score(name):
|
def script_score(name):
|
||||||
name = name.lower()
|
|
||||||
for i, possible_match in enumerate(scripts_order):
|
for i, possible_match in enumerate(scripts_order):
|
||||||
if possible_match in name:
|
if possible_match == name:
|
||||||
return i
|
return i
|
||||||
|
|
||||||
return len(self.scripts)
|
return len(self.scripts)
|
||||||
@ -145,3 +149,4 @@ class ScriptPostprocessingRunner:
|
|||||||
def image_changed(self):
|
def image_changed(self):
|
||||||
for script in self.scripts_in_preferred_order():
|
for script in self.scripts_in_preferred_order():
|
||||||
script.image_changed()
|
script.image_changed()
|
||||||
|
|
||||||
|
@ -96,15 +96,6 @@ def p_sample_plms(self, x, c, t, index, repeat_noise=False, use_original_steps=F
|
|||||||
return x_prev, pred_x0, e_t
|
return x_prev, pred_x0, e_t
|
||||||
|
|
||||||
|
|
||||||
def should_hijack_inpainting(checkpoint_info):
|
|
||||||
from modules import sd_models
|
|
||||||
|
|
||||||
ckpt_basename = os.path.basename(checkpoint_info.filename).lower()
|
|
||||||
cfg_basename = os.path.basename(sd_models.find_checkpoint_config(checkpoint_info)).lower()
|
|
||||||
|
|
||||||
return "inpainting" in ckpt_basename and not "inpainting" in cfg_basename
|
|
||||||
|
|
||||||
|
|
||||||
def do_inpainting_hijack():
|
def do_inpainting_hijack():
|
||||||
# p_sample_plms is needed because PLMS can't work with dicts as conditionings
|
# p_sample_plms is needed because PLMS can't work with dicts as conditionings
|
||||||
|
|
||||||
|
@ -5,7 +5,7 @@ class CondFunc:
|
|||||||
self = super(CondFunc, cls).__new__(cls)
|
self = super(CondFunc, cls).__new__(cls)
|
||||||
if isinstance(orig_func, str):
|
if isinstance(orig_func, str):
|
||||||
func_path = orig_func.split('.')
|
func_path = orig_func.split('.')
|
||||||
for i in range(len(func_path)-2, -1, -1):
|
for i in range(len(func_path)-1, -1, -1):
|
||||||
try:
|
try:
|
||||||
resolved_obj = importlib.import_module('.'.join(func_path[:i]))
|
resolved_obj = importlib.import_module('.'.join(func_path[:i]))
|
||||||
break
|
break
|
||||||
|
@ -2,8 +2,6 @@ import collections
|
|||||||
import os.path
|
import os.path
|
||||||
import sys
|
import sys
|
||||||
import gc
|
import gc
|
||||||
import time
|
|
||||||
from collections import namedtuple
|
|
||||||
import torch
|
import torch
|
||||||
import re
|
import re
|
||||||
import safetensors.torch
|
import safetensors.torch
|
||||||
@ -14,10 +12,10 @@ import ldm.modules.midas as midas
|
|||||||
|
|
||||||
from ldm.util import instantiate_from_config
|
from ldm.util import instantiate_from_config
|
||||||
|
|
||||||
from modules import shared, modelloader, devices, script_callbacks, sd_vae, sd_disable_initialization, errors, hashes
|
from modules import shared, modelloader, devices, script_callbacks, sd_vae, sd_disable_initialization, errors, hashes, sd_models_config
|
||||||
from modules.paths import models_path
|
from modules.paths import models_path
|
||||||
from modules.sd_hijack_inpainting import do_inpainting_hijack, should_hijack_inpainting
|
from modules.sd_hijack_inpainting import do_inpainting_hijack
|
||||||
from modules.sd_hijack_ip2p import should_hijack_ip2p
|
from modules.timer import Timer
|
||||||
|
|
||||||
model_dir = "Stable-diffusion"
|
model_dir = "Stable-diffusion"
|
||||||
model_path = os.path.abspath(os.path.join(models_path, model_dir))
|
model_path = os.path.abspath(os.path.join(models_path, model_dir))
|
||||||
@ -99,17 +97,6 @@ def checkpoint_tiles():
|
|||||||
return sorted([x.title for x in checkpoints_list.values()], key=alphanumeric_key)
|
return sorted([x.title for x in checkpoints_list.values()], key=alphanumeric_key)
|
||||||
|
|
||||||
|
|
||||||
def find_checkpoint_config(info):
|
|
||||||
if info is None:
|
|
||||||
return shared.cmd_opts.config
|
|
||||||
|
|
||||||
config = os.path.splitext(info.filename)[0] + ".yaml"
|
|
||||||
if os.path.exists(config):
|
|
||||||
return config
|
|
||||||
|
|
||||||
return shared.cmd_opts.config
|
|
||||||
|
|
||||||
|
|
||||||
def list_models():
|
def list_models():
|
||||||
checkpoints_list.clear()
|
checkpoints_list.clear()
|
||||||
checkpoint_alisases.clear()
|
checkpoint_alisases.clear()
|
||||||
@ -215,9 +202,7 @@ def get_state_dict_from_checkpoint(pl_sd):
|
|||||||
def read_state_dict(checkpoint_file, print_global_state=False, map_location=None):
|
def read_state_dict(checkpoint_file, print_global_state=False, map_location=None):
|
||||||
_, extension = os.path.splitext(checkpoint_file)
|
_, extension = os.path.splitext(checkpoint_file)
|
||||||
if extension.lower() == ".safetensors":
|
if extension.lower() == ".safetensors":
|
||||||
device = map_location or shared.weight_load_location
|
device = map_location or shared.weight_load_location or devices.get_optimal_device_name()
|
||||||
if device is None:
|
|
||||||
device = devices.get_cuda_device_string() if torch.cuda.is_available() else "cpu"
|
|
||||||
pl_sd = safetensors.torch.load_file(checkpoint_file, device=device)
|
pl_sd = safetensors.torch.load_file(checkpoint_file, device=device)
|
||||||
else:
|
else:
|
||||||
pl_sd = torch.load(checkpoint_file, map_location=map_location or shared.weight_load_location)
|
pl_sd = torch.load(checkpoint_file, map_location=map_location or shared.weight_load_location)
|
||||||
@ -229,32 +214,44 @@ def read_state_dict(checkpoint_file, print_global_state=False, map_location=None
|
|||||||
return sd
|
return sd
|
||||||
|
|
||||||
|
|
||||||
def load_model_weights(model, checkpoint_info: CheckpointInfo):
|
def get_checkpoint_state_dict(checkpoint_info: CheckpointInfo, timer):
|
||||||
|
sd_model_hash = checkpoint_info.calculate_shorthash()
|
||||||
|
timer.record("calculate hash")
|
||||||
|
|
||||||
|
if checkpoint_info in checkpoints_loaded:
|
||||||
|
# use checkpoint cache
|
||||||
|
print(f"Loading weights [{sd_model_hash}] from cache")
|
||||||
|
return checkpoints_loaded[checkpoint_info]
|
||||||
|
|
||||||
|
print(f"Loading weights [{sd_model_hash}] from {checkpoint_info.filename}")
|
||||||
|
res = read_state_dict(checkpoint_info.filename)
|
||||||
|
timer.record("load weights from disk")
|
||||||
|
|
||||||
|
return res
|
||||||
|
|
||||||
|
|
||||||
|
def load_model_weights(model, checkpoint_info: CheckpointInfo, state_dict, timer):
|
||||||
title = checkpoint_info.title
|
title = checkpoint_info.title
|
||||||
sd_model_hash = checkpoint_info.calculate_shorthash()
|
sd_model_hash = checkpoint_info.calculate_shorthash()
|
||||||
|
timer.record("calculate hash")
|
||||||
|
|
||||||
if checkpoint_info.title != title:
|
if checkpoint_info.title != title:
|
||||||
shared.opts.data["sd_model_checkpoint"] = checkpoint_info.title
|
shared.opts.data["sd_model_checkpoint"] = checkpoint_info.title
|
||||||
|
|
||||||
cache_enabled = shared.opts.sd_checkpoint_cache > 0
|
if state_dict is None:
|
||||||
|
state_dict = get_checkpoint_state_dict(checkpoint_info, timer)
|
||||||
|
|
||||||
if cache_enabled and checkpoint_info in checkpoints_loaded:
|
model.load_state_dict(state_dict, strict=False)
|
||||||
# use checkpoint cache
|
del state_dict
|
||||||
print(f"Loading weights [{sd_model_hash}] from cache")
|
timer.record("apply weights to model")
|
||||||
model.load_state_dict(checkpoints_loaded[checkpoint_info])
|
|
||||||
else:
|
|
||||||
# load from file
|
|
||||||
print(f"Loading weights [{sd_model_hash}] from {checkpoint_info.filename}")
|
|
||||||
|
|
||||||
sd = read_state_dict(checkpoint_info.filename)
|
if shared.opts.sd_checkpoint_cache > 0:
|
||||||
model.load_state_dict(sd, strict=False)
|
|
||||||
del sd
|
|
||||||
|
|
||||||
if cache_enabled:
|
|
||||||
# cache newly loaded model
|
# cache newly loaded model
|
||||||
checkpoints_loaded[checkpoint_info] = model.state_dict().copy()
|
checkpoints_loaded[checkpoint_info] = model.state_dict().copy()
|
||||||
|
|
||||||
if shared.cmd_opts.opt_channelslast:
|
if shared.cmd_opts.opt_channelslast:
|
||||||
model.to(memory_format=torch.channels_last)
|
model.to(memory_format=torch.channels_last)
|
||||||
|
timer.record("apply channels_last")
|
||||||
|
|
||||||
if not shared.cmd_opts.no_half:
|
if not shared.cmd_opts.no_half:
|
||||||
vae = model.first_stage_model
|
vae = model.first_stage_model
|
||||||
@ -272,17 +269,19 @@ def load_model_weights(model, checkpoint_info: CheckpointInfo):
|
|||||||
if depth_model:
|
if depth_model:
|
||||||
model.depth_model = depth_model
|
model.depth_model = depth_model
|
||||||
|
|
||||||
|
timer.record("apply half()")
|
||||||
|
|
||||||
devices.dtype = torch.float32 if shared.cmd_opts.no_half else torch.float16
|
devices.dtype = torch.float32 if shared.cmd_opts.no_half else torch.float16
|
||||||
devices.dtype_vae = torch.float32 if shared.cmd_opts.no_half or shared.cmd_opts.no_half_vae else torch.float16
|
devices.dtype_vae = torch.float32 if shared.cmd_opts.no_half or shared.cmd_opts.no_half_vae else torch.float16
|
||||||
devices.dtype_unet = model.model.diffusion_model.dtype
|
devices.dtype_unet = model.model.diffusion_model.dtype
|
||||||
devices.unet_needs_upcast = shared.cmd_opts.upcast_sampling and devices.dtype == torch.float16 and devices.dtype_unet == torch.float16
|
devices.unet_needs_upcast = shared.cmd_opts.upcast_sampling and devices.dtype == torch.float16 and devices.dtype_unet == torch.float16
|
||||||
|
|
||||||
model.first_stage_model.to(devices.dtype_vae)
|
model.first_stage_model.to(devices.dtype_vae)
|
||||||
|
timer.record("apply dtype to VAE")
|
||||||
|
|
||||||
# clean up cache if limit is reached
|
# clean up cache if limit is reached
|
||||||
if cache_enabled:
|
while len(checkpoints_loaded) > shared.opts.sd_checkpoint_cache:
|
||||||
while len(checkpoints_loaded) > shared.opts.sd_checkpoint_cache + 1: # we need to count the current model
|
checkpoints_loaded.popitem(last=False)
|
||||||
checkpoints_loaded.popitem(last=False) # LRU
|
|
||||||
|
|
||||||
model.sd_model_hash = sd_model_hash
|
model.sd_model_hash = sd_model_hash
|
||||||
model.sd_model_checkpoint = checkpoint_info.filename
|
model.sd_model_checkpoint = checkpoint_info.filename
|
||||||
@ -295,6 +294,7 @@ def load_model_weights(model, checkpoint_info: CheckpointInfo):
|
|||||||
sd_vae.clear_loaded_vae()
|
sd_vae.clear_loaded_vae()
|
||||||
vae_file, vae_source = sd_vae.resolve_vae(checkpoint_info.filename)
|
vae_file, vae_source = sd_vae.resolve_vae(checkpoint_info.filename)
|
||||||
sd_vae.load_vae(model, vae_file, vae_source)
|
sd_vae.load_vae(model, vae_file, vae_source)
|
||||||
|
timer.record("load VAE")
|
||||||
|
|
||||||
|
|
||||||
def enable_midas_autodownload():
|
def enable_midas_autodownload():
|
||||||
@ -340,24 +340,20 @@ def enable_midas_autodownload():
|
|||||||
midas.api.load_model = load_model_wrapper
|
midas.api.load_model = load_model_wrapper
|
||||||
|
|
||||||
|
|
||||||
class Timer:
|
def repair_config(sd_config):
|
||||||
def __init__(self):
|
|
||||||
self.start = time.time()
|
|
||||||
|
|
||||||
def elapsed(self):
|
if not hasattr(sd_config.model.params, "use_ema"):
|
||||||
end = time.time()
|
sd_config.model.params.use_ema = False
|
||||||
res = end - self.start
|
|
||||||
self.start = end
|
if shared.cmd_opts.no_half:
|
||||||
return res
|
sd_config.model.params.unet_config.params.use_fp16 = False
|
||||||
|
elif shared.cmd_opts.upcast_sampling:
|
||||||
|
sd_config.model.params.unet_config.params.use_fp16 = True
|
||||||
|
|
||||||
|
|
||||||
def load_model(checkpoint_info=None):
|
def load_model(checkpoint_info=None, already_loaded_state_dict=None, time_taken_to_load_state_dict=None):
|
||||||
from modules import lowvram, sd_hijack
|
from modules import lowvram, sd_hijack
|
||||||
checkpoint_info = checkpoint_info or select_checkpoint()
|
checkpoint_info = checkpoint_info or select_checkpoint()
|
||||||
checkpoint_config = find_checkpoint_config(checkpoint_info)
|
|
||||||
|
|
||||||
if checkpoint_config != shared.cmd_opts.config:
|
|
||||||
print(f"Loading config from: {checkpoint_config}")
|
|
||||||
|
|
||||||
if shared.sd_model:
|
if shared.sd_model:
|
||||||
sd_hijack.model_hijack.undo_hijack(shared.sd_model)
|
sd_hijack.model_hijack.undo_hijack(shared.sd_model)
|
||||||
@ -365,38 +361,27 @@ def load_model(checkpoint_info=None):
|
|||||||
gc.collect()
|
gc.collect()
|
||||||
devices.torch_gc()
|
devices.torch_gc()
|
||||||
|
|
||||||
sd_config = OmegaConf.load(checkpoint_config)
|
|
||||||
|
|
||||||
if should_hijack_inpainting(checkpoint_info):
|
|
||||||
# Hardcoded config for now...
|
|
||||||
sd_config.model.target = "ldm.models.diffusion.ddpm.LatentInpaintDiffusion"
|
|
||||||
sd_config.model.params.conditioning_key = "hybrid"
|
|
||||||
sd_config.model.params.unet_config.params.in_channels = 9
|
|
||||||
sd_config.model.params.finetune_keys = None
|
|
||||||
|
|
||||||
if should_hijack_ip2p(checkpoint_info):
|
|
||||||
sd_config.model.target = "modules.models.diffusion.ddpm_edit.LatentDiffusion"
|
|
||||||
sd_config.model.params.conditioning_key = "hybrid"
|
|
||||||
sd_config.model.params.first_stage_key = "edited"
|
|
||||||
sd_config.model.params.cond_stage_key = "edit"
|
|
||||||
sd_config.model.params.image_size = 16
|
|
||||||
sd_config.model.params.unet_config.params.in_channels = 8
|
|
||||||
sd_config.model.params.unet_config.params.out_channels = 4
|
|
||||||
|
|
||||||
if not hasattr(sd_config.model.params, "use_ema"):
|
|
||||||
sd_config.model.params.use_ema = False
|
|
||||||
|
|
||||||
do_inpainting_hijack()
|
do_inpainting_hijack()
|
||||||
|
|
||||||
if shared.cmd_opts.no_half:
|
|
||||||
sd_config.model.params.unet_config.params.use_fp16 = False
|
|
||||||
elif shared.cmd_opts.upcast_sampling:
|
|
||||||
sd_config.model.params.unet_config.params.use_fp16 = True
|
|
||||||
|
|
||||||
timer = Timer()
|
timer = Timer()
|
||||||
|
|
||||||
sd_model = None
|
if already_loaded_state_dict is not None:
|
||||||
|
state_dict = already_loaded_state_dict
|
||||||
|
else:
|
||||||
|
state_dict = get_checkpoint_state_dict(checkpoint_info, timer)
|
||||||
|
|
||||||
|
checkpoint_config = sd_models_config.find_checkpoint_config(state_dict, checkpoint_info)
|
||||||
|
|
||||||
|
timer.record("find config")
|
||||||
|
|
||||||
|
sd_config = OmegaConf.load(checkpoint_config)
|
||||||
|
repair_config(sd_config)
|
||||||
|
|
||||||
|
timer.record("load config")
|
||||||
|
|
||||||
|
print(f"Creating model from config: {checkpoint_config}")
|
||||||
|
|
||||||
|
sd_model = None
|
||||||
try:
|
try:
|
||||||
with sd_disable_initialization.DisableInitialization():
|
with sd_disable_initialization.DisableInitialization():
|
||||||
sd_model = instantiate_from_config(sd_config.model)
|
sd_model = instantiate_from_config(sd_config.model)
|
||||||
@ -407,29 +392,35 @@ def load_model(checkpoint_info=None):
|
|||||||
print('Failed to create model quickly; will retry using slow method.', file=sys.stderr)
|
print('Failed to create model quickly; will retry using slow method.', file=sys.stderr)
|
||||||
sd_model = instantiate_from_config(sd_config.model)
|
sd_model = instantiate_from_config(sd_config.model)
|
||||||
|
|
||||||
elapsed_create = timer.elapsed()
|
sd_model.used_config = checkpoint_config
|
||||||
|
|
||||||
load_model_weights(sd_model, checkpoint_info)
|
timer.record("create model")
|
||||||
|
|
||||||
elapsed_load_weights = timer.elapsed()
|
load_model_weights(sd_model, checkpoint_info, state_dict, timer)
|
||||||
|
|
||||||
if shared.cmd_opts.lowvram or shared.cmd_opts.medvram:
|
if shared.cmd_opts.lowvram or shared.cmd_opts.medvram:
|
||||||
lowvram.setup_for_low_vram(sd_model, shared.cmd_opts.medvram)
|
lowvram.setup_for_low_vram(sd_model, shared.cmd_opts.medvram)
|
||||||
else:
|
else:
|
||||||
sd_model.to(shared.device)
|
sd_model.to(shared.device)
|
||||||
|
|
||||||
|
timer.record("move model to device")
|
||||||
|
|
||||||
sd_hijack.model_hijack.hijack(sd_model)
|
sd_hijack.model_hijack.hijack(sd_model)
|
||||||
|
|
||||||
|
timer.record("hijack")
|
||||||
|
|
||||||
sd_model.eval()
|
sd_model.eval()
|
||||||
shared.sd_model = sd_model
|
shared.sd_model = sd_model
|
||||||
|
|
||||||
sd_hijack.model_hijack.embedding_db.load_textual_inversion_embeddings(force_reload=True) # Reload embeddings after model load as they may or may not fit the model
|
sd_hijack.model_hijack.embedding_db.load_textual_inversion_embeddings(force_reload=True) # Reload embeddings after model load as they may or may not fit the model
|
||||||
|
|
||||||
|
timer.record("load textual inversion embeddings")
|
||||||
|
|
||||||
script_callbacks.model_loaded_callback(sd_model)
|
script_callbacks.model_loaded_callback(sd_model)
|
||||||
|
|
||||||
elapsed_the_rest = timer.elapsed()
|
timer.record("scripts callbacks")
|
||||||
|
|
||||||
print(f"Model loaded in {elapsed_create + elapsed_load_weights + elapsed_the_rest:.1f}s ({elapsed_create:.1f}s create model, {elapsed_load_weights:.1f}s load weights).")
|
print(f"Model loaded in {timer.summary()}.")
|
||||||
|
|
||||||
return sd_model
|
return sd_model
|
||||||
|
|
||||||
@ -440,6 +431,7 @@ def reload_model_weights(sd_model=None, info=None):
|
|||||||
|
|
||||||
if not sd_model:
|
if not sd_model:
|
||||||
sd_model = shared.sd_model
|
sd_model = shared.sd_model
|
||||||
|
|
||||||
if sd_model is None: # previous model load failed
|
if sd_model is None: # previous model load failed
|
||||||
current_checkpoint_info = None
|
current_checkpoint_info = None
|
||||||
else:
|
else:
|
||||||
@ -447,14 +439,6 @@ def reload_model_weights(sd_model=None, info=None):
|
|||||||
if sd_model.sd_model_checkpoint == checkpoint_info.filename:
|
if sd_model.sd_model_checkpoint == checkpoint_info.filename:
|
||||||
return
|
return
|
||||||
|
|
||||||
checkpoint_config = find_checkpoint_config(current_checkpoint_info)
|
|
||||||
|
|
||||||
if current_checkpoint_info is None or checkpoint_config != find_checkpoint_config(checkpoint_info) or should_hijack_inpainting(checkpoint_info) != should_hijack_inpainting(sd_model.sd_checkpoint_info) or should_hijack_ip2p(checkpoint_info) != should_hijack_ip2p(sd_model.sd_checkpoint_info):
|
|
||||||
del sd_model
|
|
||||||
checkpoints_loaded.clear()
|
|
||||||
load_model(checkpoint_info)
|
|
||||||
return shared.sd_model
|
|
||||||
|
|
||||||
if shared.cmd_opts.lowvram or shared.cmd_opts.medvram:
|
if shared.cmd_opts.lowvram or shared.cmd_opts.medvram:
|
||||||
lowvram.send_everything_to_cpu()
|
lowvram.send_everything_to_cpu()
|
||||||
else:
|
else:
|
||||||
@ -464,21 +448,35 @@ def reload_model_weights(sd_model=None, info=None):
|
|||||||
|
|
||||||
timer = Timer()
|
timer = Timer()
|
||||||
|
|
||||||
|
state_dict = get_checkpoint_state_dict(checkpoint_info, timer)
|
||||||
|
|
||||||
|
checkpoint_config = sd_models_config.find_checkpoint_config(state_dict, checkpoint_info)
|
||||||
|
|
||||||
|
timer.record("find config")
|
||||||
|
|
||||||
|
if sd_model is None or checkpoint_config != sd_model.used_config:
|
||||||
|
del sd_model
|
||||||
|
checkpoints_loaded.clear()
|
||||||
|
load_model(checkpoint_info, already_loaded_state_dict=state_dict, time_taken_to_load_state_dict=timer.records["load weights from disk"])
|
||||||
|
return shared.sd_model
|
||||||
|
|
||||||
try:
|
try:
|
||||||
load_model_weights(sd_model, checkpoint_info)
|
load_model_weights(sd_model, checkpoint_info, state_dict, timer)
|
||||||
except Exception as e:
|
except Exception as e:
|
||||||
print("Failed to load checkpoint, restoring previous")
|
print("Failed to load checkpoint, restoring previous")
|
||||||
load_model_weights(sd_model, current_checkpoint_info)
|
load_model_weights(sd_model, current_checkpoint_info, None, timer)
|
||||||
raise
|
raise
|
||||||
finally:
|
finally:
|
||||||
sd_hijack.model_hijack.hijack(sd_model)
|
sd_hijack.model_hijack.hijack(sd_model)
|
||||||
|
timer.record("hijack")
|
||||||
|
|
||||||
script_callbacks.model_loaded_callback(sd_model)
|
script_callbacks.model_loaded_callback(sd_model)
|
||||||
|
timer.record("script callbacks")
|
||||||
|
|
||||||
if not shared.cmd_opts.lowvram and not shared.cmd_opts.medvram:
|
if not shared.cmd_opts.lowvram and not shared.cmd_opts.medvram:
|
||||||
sd_model.to(devices.device)
|
sd_model.to(devices.device)
|
||||||
|
timer.record("move model to device")
|
||||||
|
|
||||||
elapsed = timer.elapsed()
|
print(f"Weights loaded in {timer.summary()}.")
|
||||||
|
|
||||||
print(f"Weights loaded in {elapsed:.1f}s.")
|
|
||||||
|
|
||||||
return sd_model
|
return sd_model
|
||||||
|
68
modules/sd_models_config.py
Normal file
68
modules/sd_models_config.py
Normal file
@ -0,0 +1,68 @@
|
|||||||
|
import re
|
||||||
|
import os
|
||||||
|
|
||||||
|
from modules import shared, paths
|
||||||
|
|
||||||
|
sd_configs_path = shared.sd_configs_path
|
||||||
|
sd_repo_configs_path = os.path.join(paths.paths['Stable Diffusion'], "configs", "stable-diffusion")
|
||||||
|
|
||||||
|
|
||||||
|
config_default = shared.sd_default_config
|
||||||
|
config_sd2 = os.path.join(sd_repo_configs_path, "v2-inference.yaml")
|
||||||
|
config_sd2v = os.path.join(sd_repo_configs_path, "v2-inference-v.yaml")
|
||||||
|
config_depth_model = os.path.join(sd_repo_configs_path, "v2-midas-inference.yaml")
|
||||||
|
config_inpainting = os.path.join(sd_configs_path, "v1-inpainting-inference.yaml")
|
||||||
|
config_instruct_pix2pix = os.path.join(sd_configs_path, "instruct-pix2pix.yaml")
|
||||||
|
config_alt_diffusion = os.path.join(sd_configs_path, "alt-diffusion-inference.yaml")
|
||||||
|
|
||||||
|
re_parametrization_v = re.compile(r'-v\b')
|
||||||
|
|
||||||
|
|
||||||
|
def guess_model_config_from_state_dict(sd, filename):
|
||||||
|
fn = os.path.basename(filename)
|
||||||
|
|
||||||
|
sd2_cond_proj_weight = sd.get('cond_stage_model.model.transformer.resblocks.0.attn.in_proj_weight', None)
|
||||||
|
diffusion_model_input = sd.get('model.diffusion_model.input_blocks.0.0.weight', None)
|
||||||
|
|
||||||
|
if sd.get('depth_model.model.pretrained.act_postprocess3.0.project.0.bias', None) is not None:
|
||||||
|
return config_depth_model
|
||||||
|
|
||||||
|
if sd2_cond_proj_weight is not None and sd2_cond_proj_weight.shape[1] == 1024:
|
||||||
|
if re.search(re_parametrization_v, fn) or "v2-1_768" in fn:
|
||||||
|
return config_sd2v
|
||||||
|
else:
|
||||||
|
return config_sd2
|
||||||
|
|
||||||
|
if diffusion_model_input is not None:
|
||||||
|
if diffusion_model_input.shape[1] == 9:
|
||||||
|
return config_inpainting
|
||||||
|
if diffusion_model_input.shape[1] == 8:
|
||||||
|
return config_instruct_pix2pix
|
||||||
|
|
||||||
|
if sd.get('cond_stage_model.roberta.embeddings.word_embeddings.weight', None) is not None:
|
||||||
|
return config_alt_diffusion
|
||||||
|
|
||||||
|
return config_default
|
||||||
|
|
||||||
|
|
||||||
|
def find_checkpoint_config(state_dict, info):
|
||||||
|
if info is None:
|
||||||
|
return guess_model_config_from_state_dict(state_dict, "")
|
||||||
|
|
||||||
|
config = find_checkpoint_config_near_filename(info)
|
||||||
|
if config is not None:
|
||||||
|
return config
|
||||||
|
|
||||||
|
return guess_model_config_from_state_dict(state_dict, info.filename)
|
||||||
|
|
||||||
|
|
||||||
|
def find_checkpoint_config_near_filename(info):
|
||||||
|
if info is None:
|
||||||
|
return None
|
||||||
|
|
||||||
|
config = os.path.splitext(info.filename)[0] + ".yaml"
|
||||||
|
if os.path.exists(config):
|
||||||
|
return config
|
||||||
|
|
||||||
|
return None
|
||||||
|
|
@ -454,7 +454,7 @@ class KDiffusionSampler:
|
|||||||
def initialize(self, p):
|
def initialize(self, p):
|
||||||
self.model_wrap_cfg.mask = p.mask if hasattr(p, 'mask') else None
|
self.model_wrap_cfg.mask = p.mask if hasattr(p, 'mask') else None
|
||||||
self.model_wrap_cfg.nmask = p.nmask if hasattr(p, 'nmask') else None
|
self.model_wrap_cfg.nmask = p.nmask if hasattr(p, 'nmask') else None
|
||||||
self.model_wrap.step = 0
|
self.model_wrap_cfg.step = 0
|
||||||
self.eta = p.eta or opts.eta_ancestral
|
self.eta = p.eta or opts.eta_ancestral
|
||||||
|
|
||||||
k_diffusion.sampling.torch = TorchHijack(self.sampler_noises if self.sampler_noises is not None else [])
|
k_diffusion.sampling.torch = TorchHijack(self.sampler_noises if self.sampler_noises is not None else [])
|
||||||
|
@ -13,13 +13,14 @@ import modules.interrogate
|
|||||||
import modules.memmon
|
import modules.memmon
|
||||||
import modules.styles
|
import modules.styles
|
||||||
import modules.devices as devices
|
import modules.devices as devices
|
||||||
from modules import localization, sd_vae, extensions, script_loading, errors, ui_components
|
from modules import localization, extensions, script_loading, errors, ui_components, shared_items
|
||||||
from modules.paths import models_path, script_path, sd_path
|
from modules.paths import models_path, script_path
|
||||||
|
|
||||||
|
|
||||||
demo = None
|
demo = None
|
||||||
|
|
||||||
sd_default_config = os.path.join(script_path, "configs/v1-inference.yaml")
|
sd_configs_path = os.path.join(script_path, "configs")
|
||||||
|
sd_default_config = os.path.join(sd_configs_path, "v1-inference.yaml")
|
||||||
sd_model_file = os.path.join(script_path, 'model.ckpt')
|
sd_model_file = os.path.join(script_path, 'model.ckpt')
|
||||||
default_sd_model_file = sd_model_file
|
default_sd_model_file = sd_model_file
|
||||||
|
|
||||||
@ -264,12 +265,6 @@ interrogator = modules.interrogate.InterrogateModels("interrogate")
|
|||||||
|
|
||||||
face_restorers = []
|
face_restorers = []
|
||||||
|
|
||||||
|
|
||||||
def realesrgan_models_names():
|
|
||||||
import modules.realesrgan_model
|
|
||||||
return [x.name for x in modules.realesrgan_model.get_realesrgan_models(None)]
|
|
||||||
|
|
||||||
|
|
||||||
class OptionInfo:
|
class OptionInfo:
|
||||||
def __init__(self, default=None, label="", component=None, component_args=None, onchange=None, section=None, refresh=None):
|
def __init__(self, default=None, label="", component=None, component_args=None, onchange=None, section=None, refresh=None):
|
||||||
self.default = default
|
self.default = default
|
||||||
@ -360,7 +355,7 @@ options_templates.update(options_section(('saving-to-dirs', "Saving to a directo
|
|||||||
options_templates.update(options_section(('upscaling', "Upscaling"), {
|
options_templates.update(options_section(('upscaling', "Upscaling"), {
|
||||||
"ESRGAN_tile": OptionInfo(192, "Tile size for ESRGAN upscalers. 0 = no tiling.", gr.Slider, {"minimum": 0, "maximum": 512, "step": 16}),
|
"ESRGAN_tile": OptionInfo(192, "Tile size for ESRGAN upscalers. 0 = no tiling.", gr.Slider, {"minimum": 0, "maximum": 512, "step": 16}),
|
||||||
"ESRGAN_tile_overlap": OptionInfo(8, "Tile overlap, in pixels for ESRGAN upscalers. Low values = visible seam.", gr.Slider, {"minimum": 0, "maximum": 48, "step": 1}),
|
"ESRGAN_tile_overlap": OptionInfo(8, "Tile overlap, in pixels for ESRGAN upscalers. Low values = visible seam.", gr.Slider, {"minimum": 0, "maximum": 48, "step": 1}),
|
||||||
"realesrgan_enabled_models": OptionInfo(["R-ESRGAN 4x+", "R-ESRGAN 4x+ Anime6B"], "Select which Real-ESRGAN models to show in the web UI. (Requires restart)", gr.CheckboxGroup, lambda: {"choices": realesrgan_models_names()}),
|
"realesrgan_enabled_models": OptionInfo(["R-ESRGAN 4x+", "R-ESRGAN 4x+ Anime6B"], "Select which Real-ESRGAN models to show in the web UI. (Requires restart)", gr.CheckboxGroup, lambda: {"choices": shared_items.realesrgan_models_names()}),
|
||||||
"upscaler_for_img2img": OptionInfo(None, "Upscaler for img2img", gr.Dropdown, lambda: {"choices": [x.name for x in sd_upscalers]}),
|
"upscaler_for_img2img": OptionInfo(None, "Upscaler for img2img", gr.Dropdown, lambda: {"choices": [x.name for x in sd_upscalers]}),
|
||||||
}))
|
}))
|
||||||
|
|
||||||
@ -397,7 +392,7 @@ options_templates.update(options_section(('sd', "Stable Diffusion"), {
|
|||||||
"sd_model_checkpoint": OptionInfo(None, "Stable Diffusion checkpoint", gr.Dropdown, lambda: {"choices": list_checkpoint_tiles()}, refresh=refresh_checkpoints),
|
"sd_model_checkpoint": OptionInfo(None, "Stable Diffusion checkpoint", gr.Dropdown, lambda: {"choices": list_checkpoint_tiles()}, refresh=refresh_checkpoints),
|
||||||
"sd_checkpoint_cache": OptionInfo(0, "Checkpoints to cache in RAM", gr.Slider, {"minimum": 0, "maximum": 10, "step": 1}),
|
"sd_checkpoint_cache": OptionInfo(0, "Checkpoints to cache in RAM", gr.Slider, {"minimum": 0, "maximum": 10, "step": 1}),
|
||||||
"sd_vae_checkpoint_cache": OptionInfo(0, "VAE Checkpoints to cache in RAM", gr.Slider, {"minimum": 0, "maximum": 10, "step": 1}),
|
"sd_vae_checkpoint_cache": OptionInfo(0, "VAE Checkpoints to cache in RAM", gr.Slider, {"minimum": 0, "maximum": 10, "step": 1}),
|
||||||
"sd_vae": OptionInfo("Automatic", "SD VAE", gr.Dropdown, lambda: {"choices": ["Automatic", "None"] + list(sd_vae.vae_dict)}, refresh=sd_vae.refresh_vae_list),
|
"sd_vae": OptionInfo("Automatic", "SD VAE", gr.Dropdown, lambda: {"choices": shared_items.sd_vae_items()}, refresh=shared_items.refresh_vae_list),
|
||||||
"sd_vae_as_default": OptionInfo(True, "Ignore selected VAE for stable diffusion checkpoints that have their own .vae.pt next to them"),
|
"sd_vae_as_default": OptionInfo(True, "Ignore selected VAE for stable diffusion checkpoints that have their own .vae.pt next to them"),
|
||||||
"inpainting_mask_weight": OptionInfo(1.0, "Inpainting conditioning mask strength", gr.Slider, {"minimum": 0.0, "maximum": 1.0, "step": 0.01}),
|
"inpainting_mask_weight": OptionInfo(1.0, "Inpainting conditioning mask strength", gr.Slider, {"minimum": 0.0, "maximum": 1.0, "step": 0.01}),
|
||||||
"initial_noise_multiplier": OptionInfo(1.0, "Noise multiplier for img2img", gr.Slider, {"minimum": 0.5, "maximum": 1.5, "step": 0.01}),
|
"initial_noise_multiplier": OptionInfo(1.0, "Noise multiplier for img2img", gr.Slider, {"minimum": 0.5, "maximum": 1.5, "step": 0.01}),
|
||||||
@ -483,7 +478,8 @@ options_templates.update(options_section(('sampler-params', "Sampler parameters"
|
|||||||
}))
|
}))
|
||||||
|
|
||||||
options_templates.update(options_section(('postprocessing', "Postprocessing"), {
|
options_templates.update(options_section(('postprocessing', "Postprocessing"), {
|
||||||
'postprocessing_scipts_order': OptionInfo("upscale, gfpgan, codeformer", "Postprocessing operation order"),
|
'postprocessing_enable_in_main_ui': OptionInfo([], "Enable postprocessing operations in txt2img and img2img tabs", ui_components.DropdownMulti, lambda: {"choices": [x.name for x in shared_items.postprocessing_scripts()]}),
|
||||||
|
'postprocessing_operation_order': OptionInfo([], "Postprocessing operation order", ui_components.DropdownMulti, lambda: {"choices": [x.name for x in shared_items.postprocessing_scripts()]}),
|
||||||
'upscaling_max_images_in_cache': OptionInfo(5, "Maximum number of images in upscaling cache", gr.Slider, {"minimum": 0, "maximum": 10, "step": 1}),
|
'upscaling_max_images_in_cache': OptionInfo(5, "Maximum number of images in upscaling cache", gr.Slider, {"minimum": 0, "maximum": 10, "step": 1}),
|
||||||
}))
|
}))
|
||||||
|
|
||||||
|
23
modules/shared_items.py
Normal file
23
modules/shared_items.py
Normal file
@ -0,0 +1,23 @@
|
|||||||
|
|
||||||
|
|
||||||
|
def realesrgan_models_names():
|
||||||
|
import modules.realesrgan_model
|
||||||
|
return [x.name for x in modules.realesrgan_model.get_realesrgan_models(None)]
|
||||||
|
|
||||||
|
|
||||||
|
def postprocessing_scripts():
|
||||||
|
import modules.scripts
|
||||||
|
|
||||||
|
return modules.scripts.scripts_postproc.scripts
|
||||||
|
|
||||||
|
|
||||||
|
def sd_vae_items():
|
||||||
|
import modules.sd_vae
|
||||||
|
|
||||||
|
return ["Automatic", "None"] + list(modules.sd_vae.vae_dict)
|
||||||
|
|
||||||
|
|
||||||
|
def refresh_vae_list():
|
||||||
|
import modules.sd_vae
|
||||||
|
|
||||||
|
return modules.sd_vae.refresh_vae_list
|
@ -194,7 +194,7 @@ class EmbeddingDatabase:
|
|||||||
if not os.path.isdir(embdir.path):
|
if not os.path.isdir(embdir.path):
|
||||||
return
|
return
|
||||||
|
|
||||||
for root, dirs, fns in os.walk(embdir.path):
|
for root, dirs, fns in os.walk(embdir.path, followlinks=True):
|
||||||
for fn in fns:
|
for fn in fns:
|
||||||
try:
|
try:
|
||||||
fullfn = os.path.join(root, fn)
|
fullfn = os.path.join(root, fn)
|
||||||
|
35
modules/timer.py
Normal file
35
modules/timer.py
Normal file
@ -0,0 +1,35 @@
|
|||||||
|
import time
|
||||||
|
|
||||||
|
|
||||||
|
class Timer:
|
||||||
|
def __init__(self):
|
||||||
|
self.start = time.time()
|
||||||
|
self.records = {}
|
||||||
|
self.total = 0
|
||||||
|
|
||||||
|
def elapsed(self):
|
||||||
|
end = time.time()
|
||||||
|
res = end - self.start
|
||||||
|
self.start = end
|
||||||
|
return res
|
||||||
|
|
||||||
|
def record(self, category, extra_time=0):
|
||||||
|
e = self.elapsed()
|
||||||
|
if category not in self.records:
|
||||||
|
self.records[category] = 0
|
||||||
|
|
||||||
|
self.records[category] += e + extra_time
|
||||||
|
self.total += e + extra_time
|
||||||
|
|
||||||
|
def summary(self):
|
||||||
|
res = f"{self.total:.1f}s"
|
||||||
|
|
||||||
|
additions = [x for x in self.records.items() if x[1] >= 0.1]
|
||||||
|
if not additions:
|
||||||
|
return res
|
||||||
|
|
||||||
|
res += " ("
|
||||||
|
res += ", ".join([f"{category}: {time_taken:.1f}s" for category, time_taken in additions])
|
||||||
|
res += ")"
|
||||||
|
|
||||||
|
return res
|
@ -48,3 +48,11 @@ class FormColorPicker(gr.ColorPicker, gr.components.FormComponent):
|
|||||||
def get_block_name(self):
|
def get_block_name(self):
|
||||||
return "colorpicker"
|
return "colorpicker"
|
||||||
|
|
||||||
|
|
||||||
|
class DropdownMulti(gr.Dropdown):
|
||||||
|
"""Same as gr.Dropdown but always multiselect"""
|
||||||
|
def __init__(self, **kwargs):
|
||||||
|
super().__init__(multiselect=True, **kwargs)
|
||||||
|
|
||||||
|
def get_block_name(self):
|
||||||
|
return "dropdown"
|
||||||
|
@ -104,3 +104,28 @@ class ScriptPostprocessingUpscale(scripts_postprocessing.ScriptPostprocessing):
|
|||||||
|
|
||||||
def image_changed(self):
|
def image_changed(self):
|
||||||
upscale_cache.clear()
|
upscale_cache.clear()
|
||||||
|
|
||||||
|
|
||||||
|
class ScriptPostprocessingUpscaleSimple(ScriptPostprocessingUpscale):
|
||||||
|
name = "Simple Upscale"
|
||||||
|
order = 900
|
||||||
|
|
||||||
|
def ui(self):
|
||||||
|
with FormRow():
|
||||||
|
upscaler_name = gr.Dropdown(label='Upscaler', choices=[x.name for x in shared.sd_upscalers], value=shared.sd_upscalers[0].name)
|
||||||
|
upscale_by = gr.Slider(minimum=0.05, maximum=8.0, step=0.05, label="Upscale by", value=2)
|
||||||
|
|
||||||
|
return {
|
||||||
|
"upscale_by": upscale_by,
|
||||||
|
"upscaler_name": upscaler_name,
|
||||||
|
}
|
||||||
|
|
||||||
|
def process(self, pp: scripts_postprocessing.PostprocessedImage, upscale_by=2.0, upscaler_name=None):
|
||||||
|
if upscaler_name is None or upscaler_name == "None":
|
||||||
|
return
|
||||||
|
|
||||||
|
upscaler1 = next(iter([x for x in shared.sd_upscalers if x.name == upscaler_name]), None)
|
||||||
|
assert upscaler1, f'could not find upscaler named {upscaler_name}'
|
||||||
|
|
||||||
|
pp.image = self.upscale(pp.image, pp.info, upscaler1, 0, upscale_by, 0, 0, False)
|
||||||
|
pp.info[f"Postprocess upscaler"] = upscaler1.name
|
||||||
|
@ -164,7 +164,7 @@
|
|||||||
min-height: 3.2em;
|
min-height: 3.2em;
|
||||||
}
|
}
|
||||||
|
|
||||||
#txt2img_styles ul, #img2img_styles ul{
|
ul.list-none{
|
||||||
max-height: 35em;
|
max-height: 35em;
|
||||||
z-index: 2000;
|
z-index: 2000;
|
||||||
}
|
}
|
||||||
@ -714,9 +714,6 @@ footer {
|
|||||||
white-space: nowrap;
|
white-space: nowrap;
|
||||||
min-width: auto;
|
min-width: auto;
|
||||||
}
|
}
|
||||||
#txt2img_hires_fix{
|
|
||||||
margin-left: -0.8em;
|
|
||||||
}
|
|
||||||
|
|
||||||
#img2img_copy_to_img2img, #img2img_copy_to_sketch, #img2img_copy_to_inpaint, #img2img_copy_to_inpaint_sketch{
|
#img2img_copy_to_img2img, #img2img_copy_to_sketch, #img2img_copy_to_inpaint, #img2img_copy_to_inpaint_sketch{
|
||||||
margin-left: 0em;
|
margin-left: 0em;
|
||||||
@ -744,7 +741,6 @@ footer {
|
|||||||
|
|
||||||
.dark .gr-compact{
|
.dark .gr-compact{
|
||||||
background-color: rgb(31 41 55 / var(--tw-bg-opacity));
|
background-color: rgb(31 41 55 / var(--tw-bg-opacity));
|
||||||
margin-left: 0.8em;
|
|
||||||
}
|
}
|
||||||
|
|
||||||
.gr-compact{
|
.gr-compact{
|
||||||
|
@ -10,7 +10,7 @@ then
|
|||||||
fi
|
fi
|
||||||
|
|
||||||
export install_dir="$HOME"
|
export install_dir="$HOME"
|
||||||
export COMMANDLINE_ARGS="--skip-torch-cuda-test --no-half --use-cpu interrogate"
|
export COMMANDLINE_ARGS="--skip-torch-cuda-test --upcast-sampling --use-cpu interrogate"
|
||||||
export TORCH_COMMAND="pip install torch==1.12.1 torchvision==0.13.1"
|
export TORCH_COMMAND="pip install torch==1.12.1 torchvision==0.13.1"
|
||||||
export K_DIFFUSION_REPO="https://github.com/brkirch/k-diffusion.git"
|
export K_DIFFUSION_REPO="https://github.com/brkirch/k-diffusion.git"
|
||||||
export K_DIFFUSION_COMMIT_HASH="51c9778f269cedb55a4d88c79c0246d35bdadb71"
|
export K_DIFFUSION_COMMIT_HASH="51c9778f269cedb55a4d88c79c0246d35bdadb71"
|
||||||
|
Loading…
Reference in New Issue
Block a user