mirror of
https://github.com/AUTOMATIC1111/stable-diffusion-webui.git
synced 2024-06-07 21:20:49 +00:00
change caption method
This commit is contained in:
parent
0ac3a07eec
commit
d6a599ef9b
@ -8,7 +8,7 @@ import html
|
||||
import datetime
|
||||
|
||||
from PIL import Image,PngImagePlugin
|
||||
from ..images import captionImge
|
||||
from ..images import captionImageOverlay
|
||||
import numpy as np
|
||||
import base64
|
||||
import json
|
||||
@ -212,6 +212,12 @@ def train_embedding(embedding_name, learn_rate, data_root, log_directory, steps,
|
||||
else:
|
||||
images_dir = None
|
||||
|
||||
if create_image_every > 0 and save_image_with_stored_embedding:
|
||||
images_embeds_dir = os.path.join(log_directory, "image_embeddings")
|
||||
os.makedirs(images_embeds_dir, exist_ok=True)
|
||||
else:
|
||||
images_embeds_dir = None
|
||||
|
||||
cond_model = shared.sd_model.cond_stage_model
|
||||
|
||||
shared.state.textinfo = f"Preparing dataset from {html.escape(data_root)}..."
|
||||
@ -279,19 +285,25 @@ def train_embedding(embedding_name, learn_rate, data_root, log_directory, steps,
|
||||
|
||||
shared.state.current_image = image
|
||||
|
||||
if save_image_with_stored_embedding:
|
||||
if save_image_with_stored_embedding and os.path.exists(last_saved_file):
|
||||
|
||||
last_saved_image_chunks = os.path.join(images_embeds_dir, f'{embedding_name}-{embedding.step}.png')
|
||||
|
||||
info = PngImagePlugin.PngInfo()
|
||||
data = torch.load(last_saved_file)
|
||||
info.add_text("sd-ti-embedding", embeddingToB64(data))
|
||||
|
||||
pre_lines = [((255, 207, 175),"<{}>".format(data.get('name','???')))]
|
||||
title = "<{}>".format(data.get('name','???'))
|
||||
checkpoint = sd_models.select_checkpoint()
|
||||
post_lines = [((240, 223, 175),"Trained against checkpoint [{}] for {} steps".format(checkpoint.hash,
|
||||
embedding.step))]
|
||||
captioned_image = captionImge(image,prelines=pre_lines,postlines=post_lines)
|
||||
captioned_image.save(last_saved_image, "PNG", pnginfo=info)
|
||||
else:
|
||||
image.save(last_saved_image)
|
||||
footer_left = checkpoint.model_name
|
||||
footer_mid = '[{}]'.format(checkpoint.hash)
|
||||
footer_right = '[{}]'.format(embedding.step)
|
||||
|
||||
captioned_image = captionImageOverlay(image,title,footer_left,footer_mid,footer_right)
|
||||
|
||||
captioned_image.save(last_saved_image_chunks, "PNG", pnginfo=info)
|
||||
|
||||
image.save(last_saved_image)
|
||||
|
||||
last_saved_image += f", prompt: {text}"
|
||||
|
||||
|
Loading…
Reference in New Issue
Block a user