diff --git a/modules/textual_inversion/image_embedding.py b/modules/textual_inversion/image_embedding.py new file mode 100644 index 000000000..898ce3b3b --- /dev/null +++ b/modules/textual_inversion/image_embedding.py @@ -0,0 +1,219 @@ +import base64 +import json +import numpy as np +import zlib +from PIL import Image, PngImagePlugin, ImageDraw, ImageFont +from fonts.ttf import Roboto +import torch + + +class EmbeddingEncoder(json.JSONEncoder): + def default(self, obj): + if isinstance(obj, torch.Tensor): + return {'TORCHTENSOR': obj.cpu().detach().numpy().tolist()} + return json.JSONEncoder.default(self, obj) + + +class EmbeddingDecoder(json.JSONDecoder): + def __init__(self, *args, **kwargs): + json.JSONDecoder.__init__(self, object_hook=self.object_hook, *args, **kwargs) + + def object_hook(self, d): + if 'TORCHTENSOR' in d: + return torch.from_numpy(np.array(d['TORCHTENSOR'])) + return d + + +def embedding_to_b64(data): + d = json.dumps(data, cls=EmbeddingEncoder) + return base64.b64encode(d.encode()) + + +def embedding_from_b64(data): + d = base64.b64decode(data) + return json.loads(d, cls=EmbeddingDecoder) + + +def lcg(m=2**32, a=1664525, c=1013904223, seed=0): + while True: + seed = (a * seed + c) % m + yield seed % 255 + + +def xor_block(block): + g = lcg() + randblock = np.array([next(g) for _ in range(np.product(block.shape))]).astype(np.uint8).reshape(block.shape) + return np.bitwise_xor(block.astype(np.uint8), randblock & 0x0F) + + +def style_block(block, sequence): + im = Image.new('RGB', (block.shape[1], block.shape[0])) + draw = ImageDraw.Draw(im) + i = 0 + for x in range(-6, im.size[0], 8): + for yi, y in enumerate(range(-6, im.size[1], 8)): + offset = 0 + if yi % 2 == 0: + offset = 4 + shade = sequence[i % len(sequence)] + i += 1 + draw.ellipse((x+offset, y, x+6+offset, y+6), fill=(shade, shade, shade)) + + fg = np.array(im).astype(np.uint8) & 0xF0 + + return block ^ fg + + +def insert_image_data_embed(image, data): + d = 3 + data_compressed = zlib.compress(json.dumps(data, cls=EmbeddingEncoder).encode(), level=9) + data_np_ = np.frombuffer(data_compressed, np.uint8).copy() + data_np_high = data_np_ >> 4 + data_np_low = data_np_ & 0x0F + + h = image.size[1] + next_size = data_np_low.shape[0] + (h-(data_np_low.shape[0] % h)) + next_size = next_size + ((h*d)-(next_size % (h*d))) + + data_np_low.resize(next_size) + data_np_low = data_np_low.reshape((h, -1, d)) + + data_np_high.resize(next_size) + data_np_high = data_np_high.reshape((h, -1, d)) + + edge_style = list(data['string_to_param'].values())[0].cpu().detach().numpy().tolist()[0][:1024] + edge_style = (np.abs(edge_style)/np.max(np.abs(edge_style))*255).astype(np.uint8) + + data_np_low = style_block(data_np_low, sequence=edge_style) + data_np_low = xor_block(data_np_low) + data_np_high = style_block(data_np_high, sequence=edge_style[::-1]) + data_np_high = xor_block(data_np_high) + + im_low = Image.fromarray(data_np_low, mode='RGB') + im_high = Image.fromarray(data_np_high, mode='RGB') + + background = Image.new('RGB', (image.size[0]+im_low.size[0]+im_high.size[0]+2, image.size[1]), (0, 0, 0)) + background.paste(im_low, (0, 0)) + background.paste(image, (im_low.size[0]+1, 0)) + background.paste(im_high, (im_low.size[0]+1+image.size[0]+1, 0)) + + return background + + +def crop_black(img, tol=0): + mask = (img > tol).all(2) + mask0, mask1 = mask.any(0), mask.any(1) + col_start, col_end = mask0.argmax(), mask.shape[1]-mask0[::-1].argmax() + row_start, row_end = mask1.argmax(), mask.shape[0]-mask1[::-1].argmax() + return img[row_start:row_end, col_start:col_end] + + +def extract_image_data_embed(image): + d = 3 + outarr = crop_black(np.array(image.convert('RGB').getdata()).reshape(image.size[1], image.size[0], d).astype(np.uint8)) & 0x0F + black_cols = np.where(np.sum(outarr, axis=(0, 2)) == 0) + if black_cols[0].shape[0] < 2: + print('No Image data blocks found.') + return None + + data_block_lower = outarr[:, :black_cols[0].min(), :].astype(np.uint8) + data_block_upper = outarr[:, black_cols[0].max()+1:, :].astype(np.uint8) + + data_block_lower = xor_block(data_block_lower) + data_block_upper = xor_block(data_block_upper) + + data_block = (data_block_upper << 4) | (data_block_lower) + data_block = data_block.flatten().tobytes() + + data = zlib.decompress(data_block) + return json.loads(data, cls=EmbeddingDecoder) + + +def caption_image_overlay(srcimage, title, footerLeft, footerMid, footerRight, textfont=None): + from math import cos + + image = srcimage.copy() + + if textfont is None: + try: + textfont = ImageFont.truetype(opts.font or Roboto, fontsize) + textfont = opts.font or Roboto + except Exception: + textfont = Roboto + + factor = 1.5 + gradient = Image.new('RGBA', (1, image.size[1]), color=(0, 0, 0, 0)) + for y in range(image.size[1]): + mag = 1-cos(y/image.size[1]*factor) + mag = max(mag, 1-cos((image.size[1]-y)/image.size[1]*factor*1.1)) + gradient.putpixel((0, y), (0, 0, 0, int(mag*255))) + image = Image.alpha_composite(image.convert('RGBA'), gradient.resize(image.size)) + + draw = ImageDraw.Draw(image) + fontsize = 32 + font = ImageFont.truetype(textfont, fontsize) + padding = 10 + + _, _, w, h = draw.textbbox((0, 0), title, font=font) + fontsize = min(int(fontsize * (((image.size[0]*0.75)-(padding*4))/w)), 72) + font = ImageFont.truetype(textfont, fontsize) + _, _, w, h = draw.textbbox((0, 0), title, font=font) + draw.text((padding, padding), title, anchor='lt', font=font, fill=(255, 255, 255, 230)) + + _, _, w, h = draw.textbbox((0, 0), footerLeft, font=font) + fontsize_left = min(int(fontsize * (((image.size[0]/3)-(padding))/w)), 72) + _, _, w, h = draw.textbbox((0, 0), footerMid, font=font) + fontsize_mid = min(int(fontsize * (((image.size[0]/3)-(padding))/w)), 72) + _, _, w, h = draw.textbbox((0, 0), footerRight, font=font) + fontsize_right = min(int(fontsize * (((image.size[0]/3)-(padding))/w)), 72) + + font = ImageFont.truetype(textfont, min(fontsize_left, fontsize_mid, fontsize_right)) + + draw.text((padding, image.size[1]-padding), footerLeft, anchor='ls', font=font, fill=(255, 255, 255, 230)) + draw.text((image.size[0]/2, image.size[1]-padding), footerMid, anchor='ms', font=font, fill=(255, 255, 255, 230)) + draw.text((image.size[0]-padding, image.size[1]-padding), footerRight, anchor='rs', font=font, fill=(255, 255, 255, 230)) + + return image + + +if __name__ == '__main__': + + testEmbed = Image.open('test_embedding.png') + data = extract_image_data_embed(testEmbed) + assert data is not None + + data = embedding_from_b64(testEmbed.text['sd-ti-embedding']) + assert data is not None + + image = Image.new('RGBA', (512, 512), (255, 255, 200, 255)) + cap_image = caption_image_overlay(image, 'title', 'footerLeft', 'footerMid', 'footerRight') + + test_embed = {'string_to_param': {'*': torch.from_numpy(np.random.random((2, 4096)))}} + + embedded_image = insert_image_data_embed(cap_image, test_embed) + + retrived_embed = extract_image_data_embed(embedded_image) + + assert str(retrived_embed) == str(test_embed) + + embedded_image2 = insert_image_data_embed(cap_image, retrived_embed) + + assert embedded_image == embedded_image2 + + g = lcg() + shared_random = np.array([next(g) for _ in range(100)]).astype(np.uint8).tolist() + + reference_random = [253, 242, 127, 44, 157, 27, 239, 133, 38, 79, 167, 4, 177, + 95, 130, 79, 78, 14, 52, 215, 220, 194, 126, 28, 240, 179, + 160, 153, 149, 50, 105, 14, 21, 218, 199, 18, 54, 198, 193, + 38, 128, 19, 53, 195, 124, 75, 205, 12, 6, 145, 0, 28, + 30, 148, 8, 45, 218, 171, 55, 249, 97, 166, 12, 35, 0, + 41, 221, 122, 215, 170, 31, 113, 186, 97, 119, 31, 23, 185, + 66, 140, 30, 41, 37, 63, 137, 109, 216, 55, 159, 145, 82, + 204, 86, 73, 222, 44, 198, 118, 240, 97] + + assert shared_random == reference_random + + hunna_kay_random_sum = sum(np.array([next(g) for _ in range(100000)]).astype(np.uint8).tolist()) + + assert 12731374 == hunna_kay_random_sum diff --git a/modules/textual_inversion/test_embedding.png b/modules/textual_inversion/test_embedding.png new file mode 100644 index 000000000..07e2d9afa Binary files /dev/null and b/modules/textual_inversion/test_embedding.png differ diff --git a/modules/textual_inversion/textual_inversion.py b/modules/textual_inversion/textual_inversion.py index 7717837da..c5153e4aa 100644 --- a/modules/textual_inversion/textual_inversion.py +++ b/modules/textual_inversion/textual_inversion.py @@ -7,11 +7,15 @@ import tqdm import html import datetime +from PIL import Image, PngImagePlugin from modules import shared, devices, sd_hijack, processing, sd_models import modules.textual_inversion.dataset from modules.textual_inversion.learn_schedule import LearnSchedule +from modules.textual_inversion.image_embedding import (embedding_to_b64, embedding_from_b64, + insert_image_data_embed, extract_image_data_embed, + caption_image_overlay) class Embedding: def __init__(self, vec, name, step=None): @@ -81,7 +85,18 @@ class EmbeddingDatabase: def process_file(path, filename): name = os.path.splitext(filename)[0] - data = torch.load(path, map_location="cpu") + data = [] + + if filename.upper().endswith('.PNG'): + embed_image = Image.open(path) + if 'sd-ti-embedding' in embed_image.text: + data = embedding_from_b64(embed_image.text['sd-ti-embedding']) + name = data.get('name', name) + else: + data = extract_image_data_embed(embed_image) + name = data.get('name', name) + else: + data = torch.load(path, map_location="cpu") # textual inversion embeddings if 'string_to_param' in data: @@ -157,7 +172,8 @@ def create_embedding(name, num_vectors_per_token, init_text='*'): return fn -def train_embedding(embedding_name, learn_rate, data_root, log_directory, training_width, training_height, steps, num_repeats, create_image_every, save_embedding_every, template_file, preview_image_prompt): + +def train_embedding(embedding_name, learn_rate, data_root, log_directory, training_width, training_height, steps, num_repeats, create_image_every, save_embedding_every, template_file, save_image_with_stored_embedding, preview_image_prompt): assert embedding_name, 'embedding not selected' shared.state.textinfo = "Initializing textual inversion training..." @@ -179,6 +195,12 @@ def train_embedding(embedding_name, learn_rate, data_root, log_directory, traini else: images_dir = None + if create_image_every > 0 and save_image_with_stored_embedding: + images_embeds_dir = os.path.join(log_directory, "image_embeddings") + os.makedirs(images_embeds_dir, exist_ok=True) + else: + images_embeds_dir = None + cond_model = shared.sd_model.cond_stage_model shared.state.textinfo = f"Preparing dataset from {html.escape(data_root)}..." @@ -262,6 +284,26 @@ def train_embedding(embedding_name, learn_rate, data_root, log_directory, traini image = processed.images[0] shared.state.current_image = image + + if save_image_with_stored_embedding and os.path.exists(last_saved_file): + + last_saved_image_chunks = os.path.join(images_embeds_dir, f'{embedding_name}-{embedding.step}.png') + + info = PngImagePlugin.PngInfo() + data = torch.load(last_saved_file) + info.add_text("sd-ti-embedding", embedding_to_b64(data)) + + title = "<{}>".format(data.get('name', '???')) + checkpoint = sd_models.select_checkpoint() + footer_left = checkpoint.model_name + footer_mid = '[{}]'.format(checkpoint.hash) + footer_right = '{}'.format(embedding.step) + + captioned_image = caption_image_overlay(image, title, footer_left, footer_mid, footer_right) + captioned_image = insert_image_data_embed(captioned_image, data) + + captioned_image.save(last_saved_image_chunks, "PNG", pnginfo=info) + image.save(last_saved_image) last_saved_image += f", prompt: {preview_text}" diff --git a/modules/ui.py b/modules/ui.py index 8cd12b518..a3364f76a 100644 --- a/modules/ui.py +++ b/modules/ui.py @@ -1109,6 +1109,7 @@ def create_ui(wrap_gradio_gpu_call): num_repeats = gr.Number(label='Number of repeats for a single input image per epoch', value=100, precision=0) create_image_every = gr.Number(label='Save an image to log directory every N steps, 0 to disable', value=500, precision=0) save_embedding_every = gr.Number(label='Save a copy of embedding to log directory every N steps, 0 to disable', value=500, precision=0) + save_image_with_stored_embedding = gr.Checkbox(label='Save images with embedding in PNG chunks', value=True) preview_image_prompt = gr.Textbox(label='Preview prompt', value="") with gr.Row(): @@ -1187,6 +1188,7 @@ def create_ui(wrap_gradio_gpu_call): create_image_every, save_embedding_every, template_file, + save_image_with_stored_embedding, preview_image_prompt, ], outputs=[