Merge pull request #1066 from moorehousew/master

Add support for checkpoint merging
This commit is contained in:
AUTOMATIC1111 2022-09-27 09:59:37 +03:00 committed by GitHub
commit a9dc307a21
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
3 changed files with 70 additions and 2 deletions

View File

@ -3,6 +3,8 @@ import os
import numpy as np import numpy as np
from PIL import Image from PIL import Image
import torch
from modules import processing, shared, images, devices from modules import processing, shared, images, devices
from modules.shared import opts from modules.shared import opts
import modules.gfpgan_model import modules.gfpgan_model
@ -135,3 +137,40 @@ def run_pnginfo(image):
info = f"<div><p>{message}<p></div>" info = f"<div><p>{message}<p></div>"
return '', geninfo, info return '', geninfo, info
def run_modelmerger(modelname_0, modelname_1, interp_method, interp_amount):
# Linear interpolation (https://en.wikipedia.org/wiki/Linear_interpolation)
def weighted_sum(theta0, theta1, alpha):
return ((1 - alpha) * theta0) + (alpha * theta1)
# Smoothstep (https://en.wikipedia.org/wiki/Smoothstep)
def sigmoid(theta0, theta1, alpha):
alpha = alpha * alpha * (3 - (2 * alpha))
return theta0 + ((theta1 - theta0) * alpha)
model_0 = torch.load('models/' + modelname_0 + '.ckpt')
model_1 = torch.load('models/' + modelname_1 + '.ckpt')
theta_0 = model_0['state_dict']
theta_1 = model_1['state_dict']
theta_func = weighted_sum
if interp_method == "Weighted Sum":
theta_func = weighted_sum
if interp_method == "Sigmoid":
theta_func = sigmoid
for key in theta_0.keys():
if 'model' in key and key in theta_1:
theta_0[key] = theta_func(theta_0[key], theta_1[key], interp_amount)
for key in theta_1.keys():
if 'model' in key and key not in theta_0:
theta_0[key] = theta_1[key]
output_modelname = 'models/' + modelname_0 + '-' + modelname_1 + '-merged.ckpt';
torch.save(model_0, output_modelname)
return "<p>Model saved to " + output_modelname + "</p>"

View File

@ -393,7 +393,7 @@ def setup_progressbar(progressbar, preview, id_part):
) )
def create_ui(txt2img, img2img, run_extras, run_pnginfo): def create_ui(txt2img, img2img, run_extras, run_pnginfo, run_modelmerger):
with gr.Blocks(analytics_enabled=False) as txt2img_interface: with gr.Blocks(analytics_enabled=False) as txt2img_interface:
txt2img_prompt, roll, txt2img_prompt_style, txt2img_negative_prompt, txt2img_prompt_style2, submit, _, txt2img_prompt_style_apply, txt2img_save_style, paste = create_toprow(is_img2img=False) txt2img_prompt, roll, txt2img_prompt_style, txt2img_negative_prompt, txt2img_prompt_style2, submit, _, txt2img_prompt_style_apply, txt2img_save_style, paste = create_toprow(is_img2img=False)
dummy_component = gr.Label(visible=False) dummy_component = gr.Label(visible=False)
@ -853,6 +853,33 @@ def create_ui(txt2img, img2img, run_extras, run_pnginfo):
outputs=[html, generation_info, html2], outputs=[html, generation_info, html2],
) )
with gr.Blocks() as modelmerger_interface:
with gr.Row().style(equal_height=False):
with gr.Column(variant='panel'):
gr.HTML(value="<p>A merger of the two checkpoints will be generated in your <b>/models</b> directory.</p>")
modelname_0 = gr.Textbox(elem_id="modelmerger_modelname_0", label="Model Name (to)")
modelname_1 = gr.Textbox(elem_id="modelmerger_modelname_1", label="Model Name (from)")
interp_method = gr.Radio(choices=["Weighted Sum", "Sigmoid"], value="Weighted Sum", label="Interpolation Method")
interp_amount = gr.Slider(minimum=0.0, maximum=1.0, step=0.05, label='Interpolation Amount', value=0.3)
submit = gr.Button(elem_id="modelmerger_merge", label="Merge", variant='primary')
with gr.Column(variant='panel'):
submit_result = gr.HTML(elem_id="modelmerger_result")
submit.click(
fn=run_modelmerger,
inputs=[
modelname_0,
modelname_1,
interp_method,
interp_amount
],
outputs=[
submit_result,
]
)
def create_setting_component(key): def create_setting_component(key):
def fun(): def fun():
return opts.data[key] if key in opts.data else opts.data_labels[key].default return opts.data[key] if key in opts.data else opts.data_labels[key].default
@ -950,6 +977,7 @@ def create_ui(txt2img, img2img, run_extras, run_pnginfo):
(img2img_interface, "img2img", "img2img"), (img2img_interface, "img2img", "img2img"),
(extras_interface, "Extras", "extras"), (extras_interface, "Extras", "extras"),
(pnginfo_interface, "PNG Info", "pnginfo"), (pnginfo_interface, "PNG Info", "pnginfo"),
(modelmerger_interface, "Checkpoint Merger", "modelmerger"),
(settings_interface, "Settings", "settings"), (settings_interface, "Settings", "settings"),
] ]

View File

@ -85,7 +85,8 @@ def webui():
txt2img=wrap_gradio_gpu_call(modules.txt2img.txt2img), txt2img=wrap_gradio_gpu_call(modules.txt2img.txt2img),
img2img=wrap_gradio_gpu_call(modules.img2img.img2img), img2img=wrap_gradio_gpu_call(modules.img2img.img2img),
run_extras=wrap_gradio_gpu_call(modules.extras.run_extras), run_extras=wrap_gradio_gpu_call(modules.extras.run_extras),
run_pnginfo=modules.extras.run_pnginfo run_pnginfo=modules.extras.run_pnginfo,
run_modelmerger=modules.extras.run_modelmerger
) )
demo.launch( demo.launch(