Repository: https://github.com/AUTOMATIC1111/stable-diffusion-webui
Merge pull request #3842 from R-N/gradient-clipping
Gradient clipping in train tab
Commit: 9092e1ca77
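The PR adds an optional gradient clipping step to the embedding and hypernetwork training loops, selectable from a new dropdown in the Train tab. For readers new to the two modes it exposes, here is a minimal, illustrative PyTorch sketch (not code from this repository) of the two standard utilities the diff relies on: clip_grad_value_ clamps each gradient element into a range, while clip_grad_norm_ rescales the whole gradient so its norm stays under a ceiling.

```python
import torch

# Toy parameter with an exaggerated gradient, just to show the two modes.
p = torch.nn.Parameter(torch.zeros(3))
p.grad = torch.tensor([0.5, -4.0, 8.0])

# "value" mode: every gradient element is clamped into [-clip_value, clip_value].
torch.nn.utils.clip_grad_value_([p], clip_value=1.0)
print(p.grad)  # tensor([ 0.5000, -1.0000,  1.0000])

# "norm" mode: the whole gradient vector is rescaled so its norm is at most max_norm.
p.grad = torch.tensor([0.5, -4.0, 8.0])
torch.nn.utils.clip_grad_norm_([p], max_norm=1.0)
print(p.grad.norm())  # ~1.0; the direction is preserved, only the magnitude shrinks
```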
modules/hypernetworks/hypernetwork.py

@@ -402,10 +402,8 @@ def create_hypernetwork(name, enable_sizes, overwrite_old, layer_structure=None,
     shared.reload_hypernetworks()

     return fn

-def train_hypernetwork(hypernetwork_name, learn_rate, batch_size, gradient_step, data_root, log_directory, training_width, training_height, steps, shuffle_tags, tag_drop_out, latent_sampling_method, create_image_every, save_hypernetwork_every, template_file, preview_from_txt2img, preview_prompt, preview_negative_prompt, preview_steps, preview_sampler_index, preview_cfg_scale, preview_seed, preview_width, preview_height):
+def train_hypernetwork(hypernetwork_name, learn_rate, batch_size, gradient_step, data_root, log_directory, training_width, training_height, steps, clip_grad_mode, clip_grad_value, shuffle_tags, tag_drop_out, latent_sampling_method, create_image_every, save_hypernetwork_every, template_file, preview_from_txt2img, preview_prompt, preview_negative_prompt, preview_steps, preview_sampler_index, preview_cfg_scale, preview_seed, preview_width, preview_height):
     # images allows training previews to have infotext. Importing it at the top causes a circular import problem.
     from modules import images

@@ -448,6 +446,10 @@ def train_hypernetwork(hypernetwork_name, learn_rate, batch_size, gradient_step,
         return hypernetwork, filename

     scheduler = LearnRateScheduler(learn_rate, steps, initial_step)

+    clip_grad = torch.nn.utils.clip_grad_value_ if clip_grad_mode == "value" else torch.nn.utils.clip_grad_norm_ if clip_grad_mode == "norm" else None
+    if clip_grad:
+        clip_grad_sched = LearnRateScheduler(clip_grad_value, steps, initial_step, verbose=False)
+
     # dataset loading may take a while, so input validations and early returns should be done before this
     shared.state.textinfo = f"Preparing dataset from {html.escape(data_root)}..."
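The lines added above set up the pattern both trainers follow: translate the dropdown string into one of the two torch clipping functions (or None for "disabled"), and run the clipping threshold through the same scheduler class that already handles learning-rate schedules, so clip_grad_value can presumably use the same "value:step" syntax as the learning rate. Below is a standalone sketch of that pattern in plain PyTorch; scheduled_value is a hypothetical stand-in for the webui's LearnRateScheduler, which is not importable outside the app.

```python
import torch

# Hypothetical stand-in for the webui's LearnRateScheduler, for illustration only:
# returns the threshold scheduled for the current step.
def scheduled_value(schedule, step):
    # schedule is a list of (value, last_step) pairs, e.g. [(0.1, 500), (0.05, 10_000)]
    for value, end_step in schedule:
        if step < end_step:
            return value
    return schedule[-1][0]

clip_grad_mode = "norm"                       # value of the new "Gradient Clipping" dropdown
clip_schedule = [(0.1, 500), (0.05, 10_000)]

# Same selection logic as the diff: a clipping callable, or None when disabled.
clip_grad = torch.nn.utils.clip_grad_value_ if clip_grad_mode == "value" else \
    torch.nn.utils.clip_grad_norm_ if clip_grad_mode == "norm" else \
    None

model = torch.nn.Linear(4, 4)
opt = torch.optim.SGD(model.parameters(), lr=1e-3)

for step in range(3):
    loss = model(torch.randn(2, 4)).pow(2).mean()
    loss.backward()
    if clip_grad:
        # The second positional argument is clip_value in "value" mode and max_norm in "norm" mode.
        clip_grad(model.parameters(), scheduled_value(clip_schedule, step))
    opt.step()
    opt.zero_grad()
```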
@@ -466,7 +468,7 @@ def train_hypernetwork(hypernetwork_name, learn_rate, batch_size, gradient_step,
     shared.parallel_processing_allowed = False
     shared.sd_model.cond_stage_model.to(devices.cpu)
     shared.sd_model.first_stage_model.to(devices.cpu)

     weights = hypernetwork.weights()
     hypernetwork.train_mode()

@@ -525,6 +527,9 @@ def train_hypernetwork(hypernetwork_name, learn_rate, batch_size, gradient_step,
             if shared.state.interrupted:
                 break

+            if clip_grad:
+                clip_grad_sched.step(hypernetwork.step)
+
             with devices.autocast():
                 x = batch.latent_sample.to(devices.device, non_blocking=pin_memory)
                 if tag_drop_out != 0 or shuffle_tags:
@@ -539,14 +544,14 @@ def train_hypernetwork(hypernetwork_name, learn_rate, batch_size, gradient_step,

                 _loss_step += loss.item()
             scaler.scale(loss).backward()

             # go back until we reach gradient accumulation steps
             if (j + 1) % gradient_step != 0:
                 continue
-            # print(f"grad:{weights[0].grad.detach().cpu().abs().mean().item():.7f}")
-            # scaler.unscale_(optimizer)
-            # print(f"grad:{weights[0].grad.detach().cpu().abs().mean().item():.15f}")
-            # torch.nn.utils.clip_grad_norm_(weights, max_norm=1.0)
-            # print(f"grad:{weights[0].grad.detach().cpu().abs().mean().item():.15f}")
+
+            if clip_grad:
+                clip_grad(weights, clip_grad_sched.learn_rate)
+
             scaler.step(optimizer)
             scaler.update()
             hypernetwork.step += 1

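One point worth flagging for anyone adapting this loop: the training runs under devices.autocast() with a GradScaler, and the removed comment block included a disabled scaler.unscale_(optimizer) call. The AMP recipe documented by PyTorch unscales gradients before clipping, so the threshold applies to true gradient magnitudes rather than scaled ones. A generic sketch of that recipe (not this repository's code; it requires a CUDA device) looks like this:

```python
import torch

model = torch.nn.Linear(8, 8).cuda()
opt = torch.optim.AdamW(model.parameters(), lr=1e-3)
scaler = torch.cuda.amp.GradScaler()

with torch.cuda.amp.autocast():
    loss = model(torch.randn(4, 8, device="cuda")).pow(2).mean()
scaler.scale(loss).backward()

scaler.unscale_(opt)                                      # bring grads back to their real scale
torch.nn.utils.clip_grad_norm_(model.parameters(), 1.0)   # clip the unscaled gradients
scaler.step(opt)                                          # skips the step if grads are inf/nan
scaler.update()
```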
modules/textual_inversion/learn_schedule.py

@@ -58,14 +58,19 @@ class LearnRateScheduler:

         self.finished = False

-    def apply(self, optimizer, step_number):
+    def step(self, step_number):
         if step_number < self.end_step:
-            return
+            return False

         try:
             (self.learn_rate, self.end_step) = next(self.schedules)
-        except Exception:
+        except StopIteration:
             self.finished = True
+            return False
+        return True
+
+    def apply(self, optimizer, step_number):
+        if not self.step(step_number):
             return

         if self.verbose:
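This refactor splits the old apply() in two: step() only advances the schedule and reports whether it moved to a new (value, end_step) pair, while apply() keeps its previous job of also writing the new learning rate into the optimizer. The gradient-clip scheduler in the training loops calls step() alone because it has no optimizer to update. The miniature below is an illustration of that split, not the webui class, which additionally parses schedule strings and prints a message when verbose.

```python
class MiniScheduler:
    """Walks a list of (value, end_step) pairs, mimicking the step()/apply() split above."""

    def __init__(self, pairs):
        self.schedules = iter(pairs)
        self.learn_rate, self.end_step = next(self.schedules)
        self.finished = False

    def step(self, step_number):
        # Advance to the next (value, end_step) pair once the current one expires.
        if step_number < self.end_step:
            return False
        try:
            (self.learn_rate, self.end_step) = next(self.schedules)
        except StopIteration:
            self.finished = True
            return False
        return True

    def apply(self, optimizer, step_number):
        # Original behaviour: when the schedule advances, also update the optimizer.
        if not self.step(step_number):
            return
        for param_group in optimizer.param_groups:
            param_group['lr'] = self.learn_rate


# Usage: the gradient-clip scheduler only needs the current value, so it calls step() directly.
clip_sched = MiniScheduler([(0.1, 200), (0.05, 1000)])
clip_sched.step(250)            # True: the schedule advanced past step 200
print(clip_sched.learn_rate)    # 0.05, passed to clip_grad(...) as the threshold
```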
modules/textual_inversion/textual_inversion.py

@@ -251,8 +251,7 @@ def validate_train_inputs(model_name, learn_rate, batch_size, gradient_step, dat
     if save_model_every or create_image_every:
         assert log_directory, "Log directory is empty"

-def train_embedding(embedding_name, learn_rate, batch_size, gradient_step, data_root, log_directory, training_width, training_height, steps, shuffle_tags, tag_drop_out, latent_sampling_method, create_image_every, save_embedding_every, template_file, save_image_with_stored_embedding, preview_from_txt2img, preview_prompt, preview_negative_prompt, preview_steps, preview_sampler_index, preview_cfg_scale, preview_seed, preview_width, preview_height):
+def train_embedding(embedding_name, learn_rate, batch_size, gradient_step, data_root, log_directory, training_width, training_height, steps, clip_grad_mode, clip_grad_value, shuffle_tags, tag_drop_out, latent_sampling_method, create_image_every, save_embedding_every, template_file, save_image_with_stored_embedding, preview_from_txt2img, preview_prompt, preview_negative_prompt, preview_steps, preview_sampler_index, preview_cfg_scale, preview_seed, preview_width, preview_height):
     save_embedding_every = save_embedding_every or 0
     create_image_every = create_image_every or 0
     validate_train_inputs(embedding_name, learn_rate, batch_size, gradient_step, data_root, template_file, steps, save_embedding_every, create_image_every, log_directory, name="embedding")
@@ -295,6 +294,11 @@ def train_embedding(embedding_name, learn_rate, batch_size, gradient_step, data_
         return embedding, filename

     scheduler = LearnRateScheduler(learn_rate, steps, initial_step)

+    clip_grad = torch.nn.utils.clip_grad_value_ if clip_grad_mode == "value" else \
+        torch.nn.utils.clip_grad_norm_ if clip_grad_mode == "norm" else \
+        None
+    if clip_grad:
+        clip_grad_sched = LearnRateScheduler(clip_grad_value, steps, initial_step, verbose=False)
     # dataset loading may take a while, so input validations and early returns should be done before this
     shared.state.textinfo = f"Preparing dataset from {html.escape(data_root)}..."
     old_parallel_processing_allowed = shared.parallel_processing_allowed
@@ -361,6 +365,9 @@ def train_embedding(embedding_name, learn_rate, batch_size, gradient_step, data_
             if shared.state.interrupted:
                 break

+            if clip_grad:
+                clip_grad_sched.step(embedding.step)
+
             with devices.autocast():
                 x = batch.latent_sample.to(devices.device, non_blocking=pin_memory)
                 c = shared.sd_model.cond_stage_model(batch.cond_text)
@@ -382,6 +389,10 @@ def train_embedding(embedding_name, learn_rate, batch_size, gradient_step, data_
             # go back until we reach gradient accumulation steps
             if (j + 1) % gradient_step != 0:
                 continue

+            if clip_grad:
+                clip_grad(embedding.vec, clip_grad_sched.learn_rate)
+
             scaler.step(optimizer)
             scaler.update()
             embedding.step += 1

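Note that the hypernetwork trainer clips weights, a list of tensors returned by hypernetwork.weights(), while the embedding trainer clips embedding.vec, a single tensor. Both calls work because torch's clipping utilities accept either a single tensor or an iterable of tensors, as this small illustrative snippet shows:

```python
import torch

single = torch.nn.Parameter(torch.zeros(4))
single.grad = torch.full((4,), 3.0)
torch.nn.utils.clip_grad_value_(single, clip_value=0.1)     # a lone tensor is fine
print(single.grad)                                          # all elements clamped to 0.1

many = [torch.nn.Parameter(torch.zeros(2)) for _ in range(3)]
for p in many:
    p.grad = torch.full((2,), 3.0)
torch.nn.utils.clip_grad_norm_(many, max_norm=1.0)          # so is a list of tensors
```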
modules/ui.py

@@ -1290,6 +1290,10 @@ def create_ui():
                 with gr.Row():
                     embedding_learn_rate = gr.Textbox(label='Embedding Learning rate', placeholder="Embedding Learning rate", value="0.005", elem_id="train_embedding_learn_rate")
                     hypernetwork_learn_rate = gr.Textbox(label='Hypernetwork Learning rate', placeholder="Hypernetwork Learning rate", value="0.00001", elem_id="train_hypernetwork_learn_rate")

+                with gr.Row():
+                    clip_grad_mode = gr.Dropdown(value="disabled", label="Gradient Clipping", choices=["disabled", "value", "norm"])
+                    clip_grad_value = gr.Textbox(placeholder="Gradient clip value", value="0.1", show_label=False)
+
                 batch_size = gr.Number(label='Batch size', value=1, precision=0, elem_id="train_batch_size")
                 gradient_step = gr.Number(label='Gradient accumulation steps', value=1, precision=0, elem_id="train_gradient_step")
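The two new components only take effect once their values reach the train functions: the next two hunks append them to the argument lists wired to the existing train buttons, in the same position the new clip_grad_mode and clip_grad_value parameters occupy in the signatures. The Gradio sketch below (hypothetical handler and names, not the webui's actual wiring) shows the general shape of that hookup; the component order in inputs must match the function's parameter order.

```python
import gradio as gr

def train(learn_rate, clip_grad_mode, clip_grad_value):
    # Values arrive as plain strings, exactly as typed or selected in the UI.
    return f"lr={learn_rate}, clipping={clip_grad_mode} @ {clip_grad_value}"

with gr.Blocks() as demo:
    learn_rate = gr.Textbox(value="0.005", label="Learning rate")
    with gr.Row():
        clip_grad_mode = gr.Dropdown(value="disabled", label="Gradient Clipping", choices=["disabled", "value", "norm"])
        clip_grad_value = gr.Textbox(placeholder="Gradient clip value", value="0.1", show_label=False)
    out = gr.Textbox(label="Status")
    # The component order in `inputs` must match the train function's parameters.
    gr.Button("Train").click(fn=train, inputs=[learn_rate, clip_grad_mode, clip_grad_value], outputs=[out])

demo.launch()
```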
@@ -1402,6 +1406,8 @@ def create_ui():
                     training_width,
                     training_height,
                     steps,
+                    clip_grad_mode,
+                    clip_grad_value,
                     shuffle_tags,
                     tag_drop_out,
                     latent_sampling_method,
@@ -1431,6 +1437,8 @@ def create_ui():
                     training_width,
                     training_height,
                     steps,
+                    clip_grad_mode,
+                    clip_grad_value,
                     shuffle_tags,
                     tag_drop_out,
                     latent_sampling_method,