diff --git a/modules/hypernetworks/hypernetwork.py b/modules/hypernetworks/hypernetwork.py index 450fecac1..6a9b1398a 100644 --- a/modules/hypernetworks/hypernetwork.py +++ b/modules/hypernetworks/hypernetwork.py @@ -402,10 +402,8 @@ def create_hypernetwork(name, enable_sizes, overwrite_old, layer_structure=None, shared.reload_hypernetworks() - return fn - -def train_hypernetwork(hypernetwork_name, learn_rate, batch_size, gradient_step, data_root, log_directory, training_width, training_height, steps, shuffle_tags, tag_drop_out, latent_sampling_method, create_image_every, save_hypernetwork_every, template_file, preview_from_txt2img, preview_prompt, preview_negative_prompt, preview_steps, preview_sampler_index, preview_cfg_scale, preview_seed, preview_width, preview_height): +def train_hypernetwork(hypernetwork_name, learn_rate, batch_size, gradient_step, data_root, log_directory, training_width, training_height, steps, clip_grad_mode, clip_grad_value, shuffle_tags, tag_drop_out, latent_sampling_method, create_image_every, save_hypernetwork_every, template_file, preview_from_txt2img, preview_prompt, preview_negative_prompt, preview_steps, preview_sampler_index, preview_cfg_scale, preview_seed, preview_width, preview_height): # images allows training previews to have infotext. Importing it at the top causes a circular import problem. 
from modules import images @@ -448,6 +446,10 @@ def train_hypernetwork(hypernetwork_name, learn_rate, batch_size, gradient_step, return hypernetwork, filename scheduler = LearnRateScheduler(learn_rate, steps, initial_step) + + clip_grad = torch.nn.utils.clip_grad_value_ if clip_grad_mode == "value" else torch.nn.utils.clip_grad_norm_ if clip_grad_mode == "norm" else None + if clip_grad: + clip_grad_sched = LearnRateScheduler(clip_grad_value, steps, initial_step, verbose=False) # dataset loading may take a while, so input validations and early returns should be done before this shared.state.textinfo = f"Preparing dataset from {html.escape(data_root)}..." @@ -466,7 +468,7 @@ def train_hypernetwork(hypernetwork_name, learn_rate, batch_size, gradient_step, shared.parallel_processing_allowed = False shared.sd_model.cond_stage_model.to(devices.cpu) shared.sd_model.first_stage_model.to(devices.cpu) - + weights = hypernetwork.weights() hypernetwork.train_mode() @@ -525,6 +527,9 @@ def train_hypernetwork(hypernetwork_name, learn_rate, batch_size, gradient_step, if shared.state.interrupted: break + if clip_grad: + clip_grad_sched.step(hypernetwork.step) + with devices.autocast(): x = batch.latent_sample.to(devices.device, non_blocking=pin_memory) if tag_drop_out != 0 or shuffle_tags: @@ -539,14 +544,14 @@ def train_hypernetwork(hypernetwork_name, learn_rate, batch_size, gradient_step, _loss_step += loss.item() scaler.scale(loss).backward() + # go back until we reach gradient accumulation steps if (j + 1) % gradient_step != 0: continue - # print(f"grad:{weights[0].grad.detach().cpu().abs().mean().item():.7f}") - # scaler.unscale_(optimizer) - # print(f"grad:{weights[0].grad.detach().cpu().abs().mean().item():.15f}") - # torch.nn.utils.clip_grad_norm_(weights, max_norm=1.0) - # print(f"grad:{weights[0].grad.detach().cpu().abs().mean().item():.15f}") + + if clip_grad: + clip_grad(weights, clip_grad_sched.learn_rate) + scaler.step(optimizer) scaler.update() hypernetwork.step += 
1 diff --git a/modules/textual_inversion/learn_schedule.py b/modules/textual_inversion/learn_schedule.py index dd0c0ad14..f63fc72ff 100644 --- a/modules/textual_inversion/learn_schedule.py +++ b/modules/textual_inversion/learn_schedule.py @@ -58,14 +58,19 @@ class LearnRateScheduler: self.finished = False - def apply(self, optimizer, step_number): + def step(self, step_number): if step_number < self.end_step: - return + return False try: (self.learn_rate, self.end_step) = next(self.schedules) - except Exception: + except StopIteration: self.finished = True + return False + return True + + def apply(self, optimizer, step_number): + if not self.step(step_number): return if self.verbose: diff --git a/modules/textual_inversion/textual_inversion.py b/modules/textual_inversion/textual_inversion.py index 2250e41b1..71e07bcc2 100644 --- a/modules/textual_inversion/textual_inversion.py +++ b/modules/textual_inversion/textual_inversion.py @@ -251,8 +251,7 @@ def validate_train_inputs(model_name, learn_rate, batch_size, gradient_step, dat if save_model_every or create_image_every: assert log_directory, "Log directory is empty" - -def train_embedding(embedding_name, learn_rate, batch_size, gradient_step, data_root, log_directory, training_width, training_height, steps, shuffle_tags, tag_drop_out, latent_sampling_method, create_image_every, save_embedding_every, template_file, save_image_with_stored_embedding, preview_from_txt2img, preview_prompt, preview_negative_prompt, preview_steps, preview_sampler_index, preview_cfg_scale, preview_seed, preview_width, preview_height): +def train_embedding(embedding_name, learn_rate, batch_size, gradient_step, data_root, log_directory, training_width, training_height, steps, clip_grad_mode, clip_grad_value, shuffle_tags, tag_drop_out, latent_sampling_method, create_image_every, save_embedding_every, template_file, save_image_with_stored_embedding, preview_from_txt2img, preview_prompt, preview_negative_prompt, preview_steps, 
preview_sampler_index, preview_cfg_scale, preview_seed, preview_width, preview_height): save_embedding_every = save_embedding_every or 0 create_image_every = create_image_every or 0 validate_train_inputs(embedding_name, learn_rate, batch_size, gradient_step, data_root, template_file, steps, save_embedding_every, create_image_every, log_directory, name="embedding") @@ -295,6 +294,11 @@ def train_embedding(embedding_name, learn_rate, batch_size, gradient_step, data_ return embedding, filename scheduler = LearnRateScheduler(learn_rate, steps, initial_step) + clip_grad = torch.nn.utils.clip_grad_value_ if clip_grad_mode == "value" else \ + torch.nn.utils.clip_grad_norm_ if clip_grad_mode == "norm" else \ + None + if clip_grad: + clip_grad_sched = LearnRateScheduler(clip_grad_value, steps, initial_step, verbose=False) # dataset loading may take a while, so input validations and early returns should be done before this shared.state.textinfo = f"Preparing dataset from {html.escape(data_root)}..." 
old_parallel_processing_allowed = shared.parallel_processing_allowed @@ -361,6 +365,9 @@ def train_embedding(embedding_name, learn_rate, batch_size, gradient_step, data_ if shared.state.interrupted: break + if clip_grad: + clip_grad_sched.step(embedding.step) + with devices.autocast(): x = batch.latent_sample.to(devices.device, non_blocking=pin_memory) c = shared.sd_model.cond_stage_model(batch.cond_text) @@ -382,6 +389,10 @@ def train_embedding(embedding_name, learn_rate, batch_size, gradient_step, data_ # go back until we reach gradient accumulation steps if (j + 1) % gradient_step != 0: continue + + if clip_grad: + clip_grad(embedding.vec, clip_grad_sched.learn_rate) + scaler.step(optimizer) scaler.update() embedding.step += 1 diff --git a/modules/ui.py b/modules/ui.py index 184af7ad5..72e7b7d21 100644 --- a/modules/ui.py +++ b/modules/ui.py @@ -1290,6 +1290,10 @@ def create_ui(): with gr.Row(): embedding_learn_rate = gr.Textbox(label='Embedding Learning rate', placeholder="Embedding Learning rate", value="0.005", elem_id="train_embedding_learn_rate") hypernetwork_learn_rate = gr.Textbox(label='Hypernetwork Learning rate', placeholder="Hypernetwork Learning rate", value="0.00001", elem_id="train_hypernetwork_learn_rate") + + with gr.Row(): + clip_grad_mode = gr.Dropdown(value="disabled", label="Gradient Clipping", choices=["disabled", "value", "norm"]) + clip_grad_value = gr.Textbox(placeholder="Gradient clip value", value="0.1", show_label=False) batch_size = gr.Number(label='Batch size', value=1, precision=0, elem_id="train_batch_size") gradient_step = gr.Number(label='Gradient accumulation steps', value=1, precision=0, elem_id="train_gradient_step") @@ -1402,6 +1406,8 @@ def create_ui(): training_width, training_height, steps, + clip_grad_mode, + clip_grad_value, shuffle_tags, tag_drop_out, latent_sampling_method, @@ -1431,6 +1437,8 @@ def create_ui(): training_width, training_height, steps, + clip_grad_mode, + clip_grad_value, shuffle_tags, tag_drop_out, 
latent_sampling_method,