mirror of
https://github.com/AUTOMATIC1111/stable-diffusion-webui.git
synced 2024-06-07 21:20:49 +00:00
Merge pull request #4056 from MarkovInequality/TI_optimizations
Allow TI training using 6GB VRAM when xformers is available
This commit is contained in:
commit
f071a1d25a
@ -288,11 +288,12 @@ options_templates.update(options_section(('system', "System"), {
|
|||||||
}))
|
}))
|
||||||
|
|
||||||
options_templates.update(options_section(('training', "Training"), {
|
options_templates.update(options_section(('training', "Training"), {
|
||||||
"unload_models_when_training": OptionInfo(False, "Move VAE and CLIP to RAM when training hypernetwork. Saves VRAM."),
|
"unload_models_when_training": OptionInfo(False, "Move VAE and CLIP to RAM when training if possible. Saves VRAM."),
|
||||||
"dataset_filename_word_regex": OptionInfo("", "Filename word regex"),
|
"dataset_filename_word_regex": OptionInfo("", "Filename word regex"),
|
||||||
"dataset_filename_join_string": OptionInfo(" ", "Filename join string"),
|
"dataset_filename_join_string": OptionInfo(" ", "Filename join string"),
|
||||||
"training_image_repeats_per_epoch": OptionInfo(1, "Number of repeats for a single input image per epoch; used only for displaying epoch number", gr.Number, {"precision": 0}),
|
"training_image_repeats_per_epoch": OptionInfo(1, "Number of repeats for a single input image per epoch; used only for displaying epoch number", gr.Number, {"precision": 0}),
|
||||||
"training_write_csv_every": OptionInfo(500, "Save an csv containing the loss to log directory every N steps, 0 to disable"),
|
"training_write_csv_every": OptionInfo(500, "Save an csv containing the loss to log directory every N steps, 0 to disable"),
|
||||||
|
"training_xattention_optimizations": OptionInfo(False, "Use cross attention optimizations while training"),
|
||||||
}))
|
}))
|
||||||
|
|
||||||
options_templates.update(options_section(('sd', "Stable Diffusion"), {
|
options_templates.update(options_section(('sd', "Stable Diffusion"), {
|
||||||
|
@ -235,6 +235,7 @@ def train_embedding(embedding_name, learn_rate, batch_size, data_root, log_direc
|
|||||||
filename = os.path.join(shared.cmd_opts.embeddings_dir, f'{embedding_name}.pt')
|
filename = os.path.join(shared.cmd_opts.embeddings_dir, f'{embedding_name}.pt')
|
||||||
|
|
||||||
log_directory = os.path.join(log_directory, datetime.datetime.now().strftime("%Y-%m-%d"), embedding_name)
|
log_directory = os.path.join(log_directory, datetime.datetime.now().strftime("%Y-%m-%d"), embedding_name)
|
||||||
|
unload = shared.opts.unload_models_when_training
|
||||||
|
|
||||||
if save_embedding_every > 0:
|
if save_embedding_every > 0:
|
||||||
embedding_dir = os.path.join(log_directory, "embeddings")
|
embedding_dir = os.path.join(log_directory, "embeddings")
|
||||||
@ -272,6 +273,8 @@ def train_embedding(embedding_name, learn_rate, batch_size, data_root, log_direc
|
|||||||
shared.state.textinfo = f"Preparing dataset from {html.escape(data_root)}..."
|
shared.state.textinfo = f"Preparing dataset from {html.escape(data_root)}..."
|
||||||
with torch.autocast("cuda"):
|
with torch.autocast("cuda"):
|
||||||
ds = modules.textual_inversion.dataset.PersonalizedBase(data_root=data_root, width=training_width, height=training_height, repeats=shared.opts.training_image_repeats_per_epoch, placeholder_token=embedding_name, model=shared.sd_model, device=devices.device, template_file=template_file, batch_size=batch_size)
|
ds = modules.textual_inversion.dataset.PersonalizedBase(data_root=data_root, width=training_width, height=training_height, repeats=shared.opts.training_image_repeats_per_epoch, placeholder_token=embedding_name, model=shared.sd_model, device=devices.device, template_file=template_file, batch_size=batch_size)
|
||||||
|
if unload:
|
||||||
|
shared.sd_model.first_stage_model.to(devices.cpu)
|
||||||
|
|
||||||
embedding.vec.requires_grad = True
|
embedding.vec.requires_grad = True
|
||||||
optimizer = torch.optim.AdamW([embedding.vec], lr=scheduler.learn_rate)
|
optimizer = torch.optim.AdamW([embedding.vec], lr=scheduler.learn_rate)
|
||||||
@ -328,6 +331,9 @@ def train_embedding(embedding_name, learn_rate, batch_size, data_root, log_direc
|
|||||||
if images_dir is not None and steps_done % create_image_every == 0:
|
if images_dir is not None and steps_done % create_image_every == 0:
|
||||||
forced_filename = f'{embedding_name}-{steps_done}'
|
forced_filename = f'{embedding_name}-{steps_done}'
|
||||||
last_saved_image = os.path.join(images_dir, forced_filename)
|
last_saved_image = os.path.join(images_dir, forced_filename)
|
||||||
|
|
||||||
|
shared.sd_model.first_stage_model.to(devices.device)
|
||||||
|
|
||||||
p = processing.StableDiffusionProcessingTxt2Img(
|
p = processing.StableDiffusionProcessingTxt2Img(
|
||||||
sd_model=shared.sd_model,
|
sd_model=shared.sd_model,
|
||||||
do_not_save_grid=True,
|
do_not_save_grid=True,
|
||||||
@ -355,6 +361,9 @@ def train_embedding(embedding_name, learn_rate, batch_size, data_root, log_direc
|
|||||||
processed = processing.process_images(p)
|
processed = processing.process_images(p)
|
||||||
image = processed.images[0]
|
image = processed.images[0]
|
||||||
|
|
||||||
|
if unload:
|
||||||
|
shared.sd_model.first_stage_model.to(devices.cpu)
|
||||||
|
|
||||||
shared.state.current_image = image
|
shared.state.current_image = image
|
||||||
|
|
||||||
if save_image_with_stored_embedding and os.path.exists(last_saved_file) and embedding_yet_to_be_embedded:
|
if save_image_with_stored_embedding and os.path.exists(last_saved_file) and embedding_yet_to_be_embedded:
|
||||||
@ -400,6 +409,7 @@ Last saved image: {html.escape(last_saved_image)}<br/>
|
|||||||
|
|
||||||
filename = os.path.join(shared.cmd_opts.embeddings_dir, f'{embedding_name}.pt')
|
filename = os.path.join(shared.cmd_opts.embeddings_dir, f'{embedding_name}.pt')
|
||||||
save_embedding(embedding, checkpoint, embedding_name, filename, remove_cached_checksum=True)
|
save_embedding(embedding, checkpoint, embedding_name, filename, remove_cached_checksum=True)
|
||||||
|
shared.sd_model.first_stage_model.to(devices.device)
|
||||||
|
|
||||||
return embedding, filename
|
return embedding, filename
|
||||||
|
|
||||||
|
@ -25,8 +25,10 @@ def train_embedding(*args):
|
|||||||
|
|
||||||
assert not shared.cmd_opts.lowvram, 'Training models with lowvram not possible'
|
assert not shared.cmd_opts.lowvram, 'Training models with lowvram not possible'
|
||||||
|
|
||||||
|
apply_optimizations = shared.opts.training_xattention_optimizations
|
||||||
try:
|
try:
|
||||||
sd_hijack.undo_optimizations()
|
if not apply_optimizations:
|
||||||
|
sd_hijack.undo_optimizations()
|
||||||
|
|
||||||
embedding, filename = modules.textual_inversion.textual_inversion.train_embedding(*args)
|
embedding, filename = modules.textual_inversion.textual_inversion.train_embedding(*args)
|
||||||
|
|
||||||
@ -38,5 +40,6 @@ Embedding saved to {html.escape(filename)}
|
|||||||
except Exception:
|
except Exception:
|
||||||
raise
|
raise
|
||||||
finally:
|
finally:
|
||||||
sd_hijack.apply_optimizations()
|
if not apply_optimizations:
|
||||||
|
sd_hijack.apply_optimizations()
|
||||||
|
|
||||||
|
Loading…
Reference in New Issue
Block a user