add pixel data footer

This commit is contained in:
DepFA 2022-10-10 15:34:49 +01:00 committed by GitHub
parent ce2d7f7eac
commit 707a431100
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23

View File

@ -12,6 +12,7 @@ from ..images import captionImageOverlay
import numpy as np import numpy as np
import base64 import base64
import json import json
import zlib
from modules import shared, devices, sd_hijack, processing, sd_models from modules import shared, devices, sd_hijack, processing, sd_models
import modules.textual_inversion.dataset import modules.textual_inversion.dataset
@ -20,7 +21,7 @@ class EmbeddingEncoder(json.JSONEncoder):
def default(self, obj): def default(self, obj):
if isinstance(obj, torch.Tensor): if isinstance(obj, torch.Tensor):
return {'TORCHTENSOR':obj.cpu().detach().numpy().tolist()} return {'TORCHTENSOR':obj.cpu().detach().numpy().tolist()}
return json.JSONEncoder.default(self, o) return json.JSONEncoder.default(self, obj)
class EmbeddingDecoder(json.JSONDecoder): class EmbeddingDecoder(json.JSONDecoder):
def __init__(self, *args, **kwargs): def __init__(self, *args, **kwargs):
@ -38,6 +39,45 @@ def embeddingFromB64(data):
d = base64.b64decode(data) d = base64.b64decode(data)
return json.loads(d,cls=EmbeddingDecoder) return json.loads(d,cls=EmbeddingDecoder)
def appendImageDataFooter(image, data):
    """Return a copy of *image* with *data* stored as extra pixel rows below it.

    *data* is serialised to JSON with ``EmbeddingEncoder``, compressed with
    zlib (level 9) and written byte-for-byte into RGB pixels of rows that are
    as wide as *image*. The result is a new RGB image laid out as::

        [ original image ]
        [ one black separator row ]
        [ payload rows, zero-padded at the end ]

    The input image is not modified.
    """
    d = 3  # bytes per pixel (R, G, B)
    data_compressed = zlib.compress(json.dumps(data, cls=EmbeddingEncoder).encode(), level=9)
    payload = np.frombuffer(data_compressed, np.uint8)

    w = image.size[0]
    row_bytes = w * d
    # Round up to a whole number of pixel rows; no extra blank row is added
    # when the payload already fills its last row exactly.
    rows = -(-payload.size // row_bytes)

    # Fill a fresh zeroed buffer instead of resizing in place, which depends
    # on numpy's reference check succeeding.
    padded = np.zeros(rows * row_bytes, dtype=np.uint8)
    padded[:payload.size] = payload
    dnp = padded.reshape((rows, w, d))

    im = Image.fromarray(dnp, mode='RGB')
    # +1 leaves a black row between the picture and the payload.
    background = Image.new('RGB', (image.size[0], image.size[1] + im.size[1] + 1), (0, 0, 0))
    background.paste(image, (0, 0))
    background.paste(im, (0, image.size[1] + 1))
    return background
def crop_black(img, tol=0):
    """Crop *img* to the bounding box of its non-black pixels.

    A pixel counts as non-black when every channel is strictly greater than
    *tol*. *img* is a numpy array indexed ``[row, col]`` (greyscale) or
    ``[row, col, channel]``. The returned value is a slice of *img*.

    If no pixel exceeds *tol* there is nothing to keep, so an empty crop
    (zero rows and zero columns) is returned.
    """
    above = img > tol
    # Collapse the channel axis when there is one; a 2-D array is already a mask.
    mask = above.all(2) if above.ndim > 2 else above

    # argmax on an all-False mask returns 0, which would otherwise select the
    # whole image instead of nothing.
    if not mask.any():
        return img[0:0, 0:0]

    cols_hit = mask.any(0)
    rows_hit = mask.any(1)
    col_start = cols_hit.argmax()
    col_end = mask.shape[1] - cols_hit[::-1].argmax()
    row_start = rows_hit.argmax()
    row_end = mask.shape[0] - rows_hit[::-1].argmax()
    return img[row_start:row_end, col_start:col_end]
def extractImageDataFooter(image):
    """Recover data stored below an image as zlib-compressed pixel rows.

    The expected layout is: picture rows, at least one all-black separator
    row, then payload rows whose RGB bytes form a zlib stream (optionally
    followed by all-zero padding rows).

    Returns the payload decoded from JSON with ``EmbeddingDecoder``, or
    ``None`` (after printing a message) when no separator row or no valid
    zlib stream is found.
    """
    d = 3  # bytes per pixel (R, G, B)
    # Convert first so the fixed three-channel reshape also holds for
    # palette / alpha / greyscale inputs.
    pixels = np.array(image.convert('RGB'), dtype=np.uint8).reshape(image.size[1], image.size[0], d)

    row_has_data = pixels.reshape(pixels.shape[0], -1).any(axis=1)
    data_rows = np.flatnonzero(row_has_data)
    if data_rows.size == 0:
        print('Image data block not found.')
        return None

    # Drop only trailing all-zero padding rows. The final payload row is kept
    # even if it carries just one or two bytes, and no columns are cropped,
    # so the byte stream stays intact.
    last_data_row = data_rows[-1]
    blank_rows = np.flatnonzero(~row_has_data[:last_data_row])
    if blank_rows.size == 0:
        print('Image data block not found.')
        return None

    # The payload starts right after the last black row above it.
    separator_row = blank_rows[-1]
    dataBlock = pixels[separator_row + 1:last_data_row + 1].tobytes()

    try:
        # Zero padding after the end of the zlib stream is ignored by decompress.
        data = zlib.decompress(dataBlock)
    except zlib.error:
        # An ordinary picture that merely contains a black row ends up here.
        print('Image data block not found.')
        return None
    return json.loads(data, cls=EmbeddingDecoder)
class Embedding: class Embedding:
def __init__(self, vec, name, step=None): def __init__(self, vec, name, step=None):
self.vec = vec self.vec = vec
@ -113,6 +153,9 @@ class EmbeddingDatabase:
if 'sd-ti-embedding' in embed_image.text: if 'sd-ti-embedding' in embed_image.text:
data = embeddingFromB64(embed_image.text['sd-ti-embedding']) data = embeddingFromB64(embed_image.text['sd-ti-embedding'])
name = data.get('name',name) name = data.get('name',name)
else:
data = extractImageDataFooter(embed_image)
name = data.get('name',name)
else: else:
data = torch.load(path, map_location="cpu") data = torch.load(path, map_location="cpu")
@ -190,7 +233,7 @@ def create_embedding(name, num_vectors_per_token, init_text='*'):
return fn return fn
def train_embedding(embedding_name, learn_rate, data_root, log_directory, training_width, training_height, steps, num_repeats, create_image_every, save_embedding_every, template_file): def train_embedding(embedding_name, learn_rate, data_root, log_directory, training_width, training_height, steps, num_repeats, create_image_every, save_embedding_every, template_file, save_image_with_stored_embedding):
assert embedding_name, 'embedding not selected' assert embedding_name, 'embedding not selected'
shared.state.textinfo = "Initializing textual inversion training..." shared.state.textinfo = "Initializing textual inversion training..."
@ -308,6 +351,7 @@ def train_embedding(embedding_name, learn_rate, data_root, log_directory, traini
footer_right = '{}'.format(embedding.step) footer_right = '{}'.format(embedding.step)
captioned_image = captionImageOverlay(image,title,footer_left,footer_mid,footer_right) captioned_image = captionImageOverlay(image,title,footer_left,footer_mid,footer_right)
captioned_image = appendImageDataFooter(captioned_image,data)
captioned_image.save(last_saved_image_chunks, "PNG", pnginfo=info) captioned_image.save(last_saved_image_chunks, "PNG", pnginfo=info)