do not wait for shared.sd_model to load at startup

This commit is contained in:
AUTOMATIC 2023-05-02 09:08:00 +03:00
parent 696c338ee2
commit b1717c0a48
4 changed files with 76 additions and 35 deletions

View File

@ -2,6 +2,8 @@ import collections
import os.path import os.path
import sys import sys
import gc import gc
import threading
import torch import torch
import re import re
import safetensors.torch import safetensors.torch
@ -404,13 +406,39 @@ def repair_config(sd_config):
sd1_clip_weight = 'cond_stage_model.transformer.text_model.embeddings.token_embedding.weight' sd1_clip_weight = 'cond_stage_model.transformer.text_model.embeddings.token_embedding.weight'
sd2_clip_weight = 'cond_stage_model.model.transformer.resblocks.0.attn.in_proj_weight' sd2_clip_weight = 'cond_stage_model.model.transformer.resblocks.0.attn.in_proj_weight'
def load_model(checkpoint_info=None, already_loaded_state_dict=None, time_taken_to_load_state_dict=None):
class SdModelData:
    """Lazily-loaded holder for the shared Stable Diffusion model.

    The model is no longer loaded at program startup; the first access through
    get_sd_model() triggers load_model(), which assigns the loaded model back
    into this object's sd_model field.
    """

    def __init__(self):
        self.sd_model = None  # the loaded model, or None if not yet loaded / load failed
        self.lock = threading.Lock()  # serializes concurrent first-load attempts

    def get_sd_model(self):
        """Return the model, loading it on first access (thread-safe)."""
        if self.sd_model is None:
            with self.lock:
                # Re-check under the lock (double-checked locking): another
                # thread may have finished loading while we waited, and
                # loading twice would waste time and (GPU) memory.
                if self.sd_model is None:
                    try:
                        load_model()
                    except Exception as e:
                        errors.display(e, "loading stable diffusion model")
                        print("", file=sys.stderr)
                        print("Stable diffusion model failed to load", file=sys.stderr)
                        self.sd_model = None

        return self.sd_model

    def set_sd_model(self, v):
        """Replace the stored model (used by the shared.sd_model property setter)."""
        self.sd_model = v


model_data = SdModelData()
def load_model(checkpoint_info=None, already_loaded_state_dict=None):
from modules import lowvram, sd_hijack from modules import lowvram, sd_hijack
checkpoint_info = checkpoint_info or select_checkpoint() checkpoint_info = checkpoint_info or select_checkpoint()
if shared.sd_model: if model_data.sd_model:
sd_hijack.model_hijack.undo_hijack(shared.sd_model) sd_hijack.model_hijack.undo_hijack(model_data.sd_model)
shared.sd_model = None model_data.sd_model = None
gc.collect() gc.collect()
devices.torch_gc() devices.torch_gc()
@ -464,7 +492,7 @@ def load_model(checkpoint_info=None, already_loaded_state_dict=None, time_taken_
timer.record("hijack") timer.record("hijack")
sd_model.eval() sd_model.eval()
shared.sd_model = sd_model model_data.sd_model = sd_model
sd_hijack.model_hijack.embedding_db.load_textual_inversion_embeddings(force_reload=True) # Reload embeddings after model load as they may or may not fit the model sd_hijack.model_hijack.embedding_db.load_textual_inversion_embeddings(force_reload=True) # Reload embeddings after model load as they may or may not fit the model
@ -484,7 +512,7 @@ def reload_model_weights(sd_model=None, info=None):
checkpoint_info = info or select_checkpoint() checkpoint_info = info or select_checkpoint()
if not sd_model: if not sd_model:
sd_model = shared.sd_model sd_model = model_data.sd_model
if sd_model is None: # previous model load failed if sd_model is None: # previous model load failed
current_checkpoint_info = None current_checkpoint_info = None
@ -512,7 +540,7 @@ def reload_model_weights(sd_model=None, info=None):
del sd_model del sd_model
checkpoints_loaded.clear() checkpoints_loaded.clear()
load_model(checkpoint_info, already_loaded_state_dict=state_dict) load_model(checkpoint_info, already_loaded_state_dict=state_dict)
return shared.sd_model return model_data.sd_model
try: try:
load_model_weights(sd_model, checkpoint_info, state_dict, timer) load_model_weights(sd_model, checkpoint_info, state_dict, timer)
@ -535,17 +563,15 @@ def reload_model_weights(sd_model=None, info=None):
return sd_model return sd_model
def unload_model_weights(sd_model=None, info=None): def unload_model_weights(sd_model=None, info=None):
from modules import lowvram, devices, sd_hijack from modules import lowvram, devices, sd_hijack
timer = Timer() timer = Timer()
if shared.sd_model: if model_data.sd_model:
model_data.sd_model.to(devices.cpu)
# shared.sd_model.cond_stage_model.to(devices.cpu) sd_hijack.model_hijack.undo_hijack(model_data.sd_model)
# shared.sd_model.first_stage_model.to(devices.cpu) model_data.sd_model = None
shared.sd_model.to(devices.cpu)
sd_hijack.model_hijack.undo_hijack(shared.sd_model)
shared.sd_model = None
sd_model = None sd_model = None
gc.collect() gc.collect()
devices.torch_gc() devices.torch_gc()

View File

@ -16,6 +16,7 @@ import modules.styles
import modules.devices as devices import modules.devices as devices
from modules import localization, script_loading, errors, ui_components, shared_items, cmd_args from modules import localization, script_loading, errors, ui_components, shared_items, cmd_args
from modules.paths_internal import models_path, script_path, data_path, sd_configs_path, sd_default_config, sd_model_file, default_sd_model_file, extensions_dir, extensions_builtin_dir from modules.paths_internal import models_path, script_path, data_path, sd_configs_path, sd_default_config, sd_model_file, default_sd_model_file, extensions_dir, extensions_builtin_dir
from ldm.models.diffusion.ddpm import LatentDiffusion
demo = None demo = None
@ -600,13 +601,37 @@ class Options:
return value return value
opts = Options() opts = Options()
if os.path.exists(config_filename): if os.path.exists(config_filename):
opts.load(config_filename) opts.load(config_filename)
class Shared(sys.modules[__name__].__class__):
    """
    this class is here to provide sd_model field as a property, so that it can be created and loaded on demand rather than
    at program startup.
    """

    # NOTE(review): appears unused in this change — possibly a leftover; verify before removing.
    sd_model_val = None

    @property
    def sd_model(self):
        # Imported here rather than at module top level: sd_models itself
        # uses this module, so a top-level import would be circular.
        import modules.sd_models

        # Delegates to the lazy holder, which loads the model on first access.
        return modules.sd_models.model_data.get_sd_model()

    @sd_model.setter
    def sd_model(self, value):
        import modules.sd_models

        modules.sd_models.model_data.set_sd_model(value)

    sd_model: LatentDiffusion = None  # this var is here just for IDE's type checking; it cannot be accessed because the class field above will be accessed instead


# Swap this module's class so that `shared.sd_model` attribute access from other
# modules is routed through the property defined above.
sys.modules[__name__].__class__ = Shared
settings_components = None settings_components = None
"""assinged from ui.py, a mapping on setting anmes to gradio components repsponsible for those settings""" """assinged from ui.py, a mapping on setting names to gradio components repsponsible for those settings"""
latent_upscale_default_mode = "Latent" latent_upscale_default_mode = "Latent"
latent_upscale_modes = { latent_upscale_modes = {
@ -620,8 +645,6 @@ latent_upscale_modes = {
sd_upscalers = [] sd_upscalers = []
sd_model = None
clip_model = None clip_model = None
progress_print_out = sys.stdout progress_print_out = sys.stdout

View File

@ -828,7 +828,7 @@ def create_ui():
with FormGroup(): with FormGroup():
with FormRow(): with FormRow():
cfg_scale = gr.Slider(minimum=1.0, maximum=30.0, step=0.5, label='CFG Scale', value=7.0, elem_id="img2img_cfg_scale") cfg_scale = gr.Slider(minimum=1.0, maximum=30.0, step=0.5, label='CFG Scale', value=7.0, elem_id="img2img_cfg_scale")
image_cfg_scale = gr.Slider(minimum=0, maximum=3.0, step=0.05, label='Image CFG Scale', value=1.5, elem_id="img2img_image_cfg_scale", visible=shared.sd_model and shared.sd_model.cond_stage_key == "edit") image_cfg_scale = gr.Slider(minimum=0, maximum=3.0, step=0.05, label='Image CFG Scale', value=1.5, elem_id="img2img_image_cfg_scale", visible=False)
denoising_strength = gr.Slider(minimum=0.0, maximum=1.0, step=0.01, label='Denoising strength', value=0.75, elem_id="img2img_denoising_strength") denoising_strength = gr.Slider(minimum=0.0, maximum=1.0, step=0.01, label='Denoising strength', value=0.75, elem_id="img2img_denoising_strength")
elif category == "seed": elif category == "seed":
@ -1693,11 +1693,9 @@ def create_ui():
show_progress=info.refresh is not None, show_progress=info.refresh is not None,
) )
text_settings.change( update_image_cfg_scale_visibility = lambda: gr.update(visible=shared.sd_model and shared.sd_model.cond_stage_key == "edit")
fn=lambda: gr.update(visible=shared.sd_model and shared.sd_model.cond_stage_key == "edit"), text_settings.change(fn=update_image_cfg_scale_visibility, inputs=[], outputs=[image_cfg_scale])
inputs=[], demo.load(fn=update_image_cfg_scale_visibility, inputs=[], outputs=[image_cfg_scale])
outputs=[image_cfg_scale],
)
button_set_checkpoint = gr.Button('Change checkpoint', elem_id='change_checkpoint', visible=False) button_set_checkpoint = gr.Button('Change checkpoint', elem_id='change_checkpoint', visible=False)
button_set_checkpoint.click( button_set_checkpoint.click(

View File

@ -6,6 +6,8 @@ import signal
import re import re
import warnings import warnings
import json import json
from threading import Thread
from fastapi import FastAPI from fastapi import FastAPI
from fastapi.middleware.cors import CORSMiddleware from fastapi.middleware.cors import CORSMiddleware
from fastapi.middleware.gzip import GZipMiddleware from fastapi.middleware.gzip import GZipMiddleware
@ -191,18 +193,10 @@ def initialize():
modules.textual_inversion.textual_inversion.list_textual_inversion_templates() modules.textual_inversion.textual_inversion.list_textual_inversion_templates()
startup_timer.record("refresh textual inversion templates") startup_timer.record("refresh textual inversion templates")
try: # load model in parallel to other startup stuff
modules.sd_models.load_model() Thread(target=lambda: shared.sd_model).start()
except Exception as e:
errors.display(e, "loading stable diffusion model")
print("", file=sys.stderr)
print("Stable diffusion model failed to load, exiting", file=sys.stderr)
exit(1)
startup_timer.record("load SD checkpoint")
shared.opts.data["sd_model_checkpoint"] = shared.sd_model.sd_checkpoint_info.title shared.opts.onchange("sd_model_checkpoint", wrap_queued_call(lambda: modules.sd_models.reload_model_weights()), call=False)
shared.opts.onchange("sd_model_checkpoint", wrap_queued_call(lambda: modules.sd_models.reload_model_weights()))
shared.opts.onchange("sd_vae", wrap_queued_call(lambda: modules.sd_vae.reload_vae_weights()), call=False) shared.opts.onchange("sd_vae", wrap_queued_call(lambda: modules.sd_vae.reload_vae_weights()), call=False)
shared.opts.onchange("sd_vae_as_default", wrap_queued_call(lambda: modules.sd_vae.reload_vae_weights()), call=False) shared.opts.onchange("sd_vae_as_default", wrap_queued_call(lambda: modules.sd_vae.reload_vae_weights()), call=False)
shared.opts.onchange("temp_dir", ui_tempdir.on_tmpdir_changed) shared.opts.onchange("temp_dir", ui_tempdir.on_tmpdir_changed)