add embedding load and save from b64 json

This commit is contained in:
DepFA 2022-10-09 21:58:14 +01:00 committed by GitHub
parent fa0c5eb81b
commit 03694e1f99
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23

View File

@ -8,8 +8,10 @@ import html
import datetime
from PIL import Image,PngImagePlugin
from ..images import captionImge
import numpy as np
import base64
from io import BytesIO
import json
from modules import shared, devices, sd_hijack, processing, sd_models
import modules.textual_inversion.dataset
@ -87,9 +89,9 @@ class EmbeddingDatabase:
if filename.upper().endswith('.PNG'):
embed_image = Image.open(path)
if 'sd-embedding' in embed_image.text:
embeddingData = base64.b64decode(embed_image.text['sd-embedding'])
data = torch.load(BytesIO(embeddingData), map_location="cpu")
if 'sd-ti-embedding' in embed_image.text:
data = embeddingFromB64(embed_image.text['sd-ti-embedding'])
name = data.get('name',name)
else:
data = torch.load(path, map_location="cpu")
@ -258,13 +260,23 @@ def train_embedding(embedding_name, learn_rate, data_root, log_directory, steps,
if save_image_with_stored_embedding:
info = PngImagePlugin.PngInfo()
info.add_text("sd-embedding", base64.b64encode(open(last_saved_file,'rb').read()))
image.save(last_saved_image, "PNG", pnginfo=info)
data = torch.load(last_saved_file)
info.add_text("sd-ti-embedding", embeddingToB64(data))
pre_lines = [((255, 207, 175),"<{}>".format(data.get('name','???')))]
caption_checkpoint_hash = data.get('sd_checkpoint','UNK')
caption_checkpoint_hash = caption_checkpoint_hash.upper() if caption_checkpoint_hash else 'UNK'
caption_stepcount = data.get('step',0)
caption_stepcount = caption_stepcount if caption_stepcount else 0
post_lines = [((240, 223, 175),"Trained against checkpoint [{}] for {} steps".format(caption_checkpoint_hash,
caption_stepcount))]
captioned_image = captionImge(image,prelines=pre_lines,postlines=post_lines)
captioned_image.save(last_saved_image, "PNG", pnginfo=info)
else:
image.save(last_saved_image)
last_saved_image += f", prompt: {text}"
shared.state.job_no = embedding.step