mirror of
https://github.com/AUTOMATIC1111/stable-diffusion-webui.git
synced 2024-06-07 21:20:49 +00:00
add embedding load and save from b64 json
This commit is contained in:
parent
fa0c5eb81b
commit
03694e1f99
@ -7,9 +7,11 @@ import tqdm
|
|||||||
import html
|
import html
|
||||||
import datetime
|
import datetime
|
||||||
|
|
||||||
from PIL import Image, PngImagePlugin
|
from PIL import Image,PngImagePlugin
|
||||||
|
from ..images import captionImge
|
||||||
|
import numpy as np
|
||||||
import base64
|
import base64
|
||||||
from io import BytesIO
|
import json
|
||||||
|
|
||||||
from modules import shared, devices, sd_hijack, processing, sd_models
|
from modules import shared, devices, sd_hijack, processing, sd_models
|
||||||
import modules.textual_inversion.dataset
|
import modules.textual_inversion.dataset
|
||||||
@ -87,9 +89,9 @@ class EmbeddingDatabase:
|
|||||||
|
|
||||||
if filename.upper().endswith('.PNG'):
|
if filename.upper().endswith('.PNG'):
|
||||||
embed_image = Image.open(path)
|
embed_image = Image.open(path)
|
||||||
if 'sd-embedding' in embed_image.text:
|
if 'sd-ti-embedding' in embed_image.text:
|
||||||
embeddingData = base64.b64decode(embed_image.text['sd-embedding'])
|
data = embeddingFromB64(embed_image.text['sd-ti-embedding'])
|
||||||
data = torch.load(BytesIO(embeddingData), map_location="cpu")
|
name = data.get('name',name)
|
||||||
else:
|
else:
|
||||||
data = torch.load(path, map_location="cpu")
|
data = torch.load(path, map_location="cpu")
|
||||||
|
|
||||||
@ -258,13 +260,23 @@ def train_embedding(embedding_name, learn_rate, data_root, log_directory, steps,
|
|||||||
|
|
||||||
if save_image_with_stored_embedding:
|
if save_image_with_stored_embedding:
|
||||||
info = PngImagePlugin.PngInfo()
|
info = PngImagePlugin.PngInfo()
|
||||||
info.add_text("sd-embedding", base64.b64encode(open(last_saved_file,'rb').read()))
|
data = torch.load(last_saved_file)
|
||||||
image.save(last_saved_image, "PNG", pnginfo=info)
|
info.add_text("sd-ti-embedding", embeddingToB64(data))
|
||||||
|
|
||||||
|
pre_lines = [((255, 207, 175),"<{}>".format(data.get('name','???')))]
|
||||||
|
|
||||||
|
caption_checkpoint_hash = data.get('sd_checkpoint','UNK')
|
||||||
|
caption_checkpoint_hash = caption_checkpoint_hash.upper() if caption_checkpoint_hash else 'UNK'
|
||||||
|
caption_stepcount = data.get('step',0)
|
||||||
|
caption_stepcount = caption_stepcount if caption_stepcount else 0
|
||||||
|
|
||||||
|
post_lines = [((240, 223, 175),"Trained against checkpoint [{}] for {} steps".format(caption_checkpoint_hash,
|
||||||
|
caption_stepcount))]
|
||||||
|
captioned_image = captionImge(image,prelines=pre_lines,postlines=post_lines)
|
||||||
|
captioned_image.save(last_saved_image, "PNG", pnginfo=info)
|
||||||
else:
|
else:
|
||||||
image.save(last_saved_image)
|
image.save(last_saved_image)
|
||||||
|
|
||||||
|
|
||||||
|
|
||||||
last_saved_image += f", prompt: {text}"
|
last_saved_image += f", prompt: {text}"
|
||||||
|
|
||||||
shared.state.job_no = embedding.step
|
shared.state.job_no = embedding.step
|
||||||
|
Loading…
Reference in New Issue
Block a user