Merge pull request #10365 from Sakura-Luna/taesd-a

Add Tiny AE live preview
This commit is contained in:
AUTOMATIC1111 2023-05-17 09:26:50 +03:00 committed by GitHub
commit 9ac85b8b73
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
5 changed files with 124 additions and 3 deletions

View File

@ -158,5 +158,6 @@ Licenses for borrowed code can be found in `Settings -> Licenses` screen, and al
- Instruct pix2pix - Tim Brooks (star), Aleksander Holynski (star), Alexei A. Efros (no star) - https://github.com/timothybrooks/instruct-pix2pix
- Security advice - RyotaK
- UniPC sampler - Wenliang Zhao - https://github.com/wl-zhao/UniPC
- TAESD - Ollin Boer Bohan - https://github.com/madebyollin/taesd
- Initial Gradio script - posted on 4chan by an Anonymous user. Thank you Anonymous user.
- (You)

View File

@ -662,3 +662,29 @@ LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN
THE SOFTWARE.
</pre>
<h2><a href="https://github.com/madebyollin/taesd/blob/main/LICENSE">TAESD</a></h2>
<small>Tiny AutoEncoder for Stable Diffusion option for live previews</small>
<pre>
MIT License
Copyright (c) 2023 Ollin Boer Bohan
Permission is hereby granted, free of charge, to any person obtaining a copy
of this software and associated documentation files (the "Software"), to deal
in the Software without restriction, including without limitation the rights
to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
copies of the Software, and to permit persons to whom the Software is
furnished to do so, subject to the following conditions:
The above copyright notice and this permission notice shall be included in all
copies or substantial portions of the Software.
THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
SOFTWARE.
</pre>

View File

@ -2,7 +2,7 @@ from collections import namedtuple
import numpy as np
import torch
from PIL import Image
from modules import devices, processing, images, sd_vae_approx, sd_samplers
from modules import devices, processing, images, sd_vae_approx, sd_samplers, sd_vae_taesd
from modules.shared import opts, state
import modules.shared as shared
@ -22,10 +22,11 @@ def setup_img2img_steps(p, steps=None):
return steps, t_enc
approximation_indexes = {"Full": 0, "Approx NN": 1, "Approx cheap": 2}
approximation_indexes = {"Full": 0, "Approx NN": 1, "Approx cheap": 2, "TAESD": 3}
def single_sample_to_image(sample, approximation=None):
if approximation is None:
approximation = approximation_indexes.get(opts.show_progress_type, 0)
@ -33,12 +34,17 @@ def single_sample_to_image(sample, approximation=None):
x_sample = sd_vae_approx.cheap_approximation(sample)
elif approximation == 1:
x_sample = sd_vae_approx.model()(sample.to(devices.device, devices.dtype).unsqueeze(0))[0].detach()
elif approximation == 3:
x_sample = sd_vae_taesd.model()(sample.to(devices.device, devices.dtype).unsqueeze(0))[0].detach()
x_sample = sd_vae_taesd.TAESD.unscale_latents(x_sample) # returns value in [-2, 2]
x_sample = x_sample * 0.5
else:
x_sample = processing.decode_first_stage(shared.sd_model, sample.unsqueeze(0))[0]
x_sample = torch.clamp((x_sample + 1.0) / 2.0, min=0.0, max=1.0)
x_sample = 255. * np.moveaxis(x_sample.cpu().numpy(), 0, 2)
x_sample = x_sample.astype(np.uint8)
return Image.fromarray(x_sample)

88
modules/sd_vae_taesd.py Normal file
View File

@ -0,0 +1,88 @@
"""
Tiny AutoEncoder for Stable Diffusion
(DNN for encoding / decoding SD's latent space)
https://github.com/madebyollin/taesd
"""
import os
import torch
import torch.nn as nn
from modules import devices, paths_internal
sd_vae_taesd = None
def conv(n_in, n_out, **kwargs):
return nn.Conv2d(n_in, n_out, 3, padding=1, **kwargs)
class Clamp(nn.Module):
@staticmethod
def forward(x):
return torch.tanh(x / 3) * 3
class Block(nn.Module):
def __init__(self, n_in, n_out):
super().__init__()
self.conv = nn.Sequential(conv(n_in, n_out), nn.ReLU(), conv(n_out, n_out), nn.ReLU(), conv(n_out, n_out))
self.skip = nn.Conv2d(n_in, n_out, 1, bias=False) if n_in != n_out else nn.Identity()
self.fuse = nn.ReLU()
def forward(self, x):
return self.fuse(self.conv(x) + self.skip(x))
def decoder():
return nn.Sequential(
Clamp(), conv(4, 64), nn.ReLU(),
Block(64, 64), Block(64, 64), Block(64, 64), nn.Upsample(scale_factor=2), conv(64, 64, bias=False),
Block(64, 64), Block(64, 64), Block(64, 64), nn.Upsample(scale_factor=2), conv(64, 64, bias=False),
Block(64, 64), Block(64, 64), Block(64, 64), nn.Upsample(scale_factor=2), conv(64, 64, bias=False),
Block(64, 64), conv(64, 3),
)
class TAESD(nn.Module):
latent_magnitude = 2
latent_shift = 0.5
def __init__(self, decoder_path="taesd_decoder.pth"):
"""Initialize pretrained TAESD on the given device from the given checkpoints."""
super().__init__()
self.decoder = decoder()
self.decoder.load_state_dict(
torch.load(decoder_path, map_location='cpu' if devices.device.type != 'cuda' else None))
@staticmethod
def unscale_latents(x):
"""[0, 1] -> raw latents"""
return x.sub(TAESD.latent_shift).mul(2 * TAESD.latent_magnitude)
def download_model(model_path):
model_url = 'https://github.com/madebyollin/taesd/raw/main/taesd_decoder.pth'
if not os.path.exists(model_path):
os.makedirs(os.path.dirname(model_path), exist_ok=True)
print(f'Downloading TAESD decoder to: {model_path}')
torch.hub.download_url_to_file(model_url, model_path)
def model():
global sd_vae_taesd
if sd_vae_taesd is None:
model_path = os.path.join(paths_internal.models_path, "VAE-taesd", "taesd_decoder.pth")
download_model(model_path)
if os.path.exists(model_path):
sd_vae_taesd = TAESD(model_path)
sd_vae_taesd.eval()
sd_vae_taesd.to(devices.device, devices.dtype)
else:
raise FileNotFoundError('TAESD model not found')
return sd_vae_taesd.decoder

View File

@ -448,7 +448,7 @@ options_templates.update(options_section(('ui', "Live previews"), {
"live_previews_image_format": OptionInfo("png", "Live preview file format", gr.Radio, {"choices": ["jpeg", "png", "webp"]}),
"show_progress_grid": OptionInfo(True, "Show previews of all images generated in a batch as a grid"),
"show_progress_every_n_steps": OptionInfo(10, "Live preview display period", gr.Slider, {"minimum": -1, "maximum": 32, "step": 1}).info("in sampling steps - show new live preview image every N sampling steps; -1 = only show after completion of batch"),
"show_progress_type": OptionInfo("Approx NN", "Live preview method", gr.Radio, {"choices": ["Full", "Approx NN", "Approx cheap"]}).info("Full = slow but pretty; Approx NN = fast but low quality; Approx cheap = super fast but terrible otherwise"),
"show_progress_type": OptionInfo("Approx NN", "Live preview method", gr.Radio, {"choices": ["Full", "Approx NN", "Approx cheap", "TAESD"]}).info("Full = slow but pretty; Approx NN and TAESD = fast but low quality; Approx cheap = super fast but terrible otherwise"),
"live_preview_content": OptionInfo("Prompt", "Live preview subject", gr.Radio, {"choices": ["Combined", "Prompt", "Negative prompt"]}),
"live_preview_refresh_period": OptionInfo(1000, "Progressbar and preview update period").info("in milliseconds"),
}))