deepdanbooru interrogator

This commit is contained in:
Greendayle 2022-10-05 20:50:10 +02:00
parent 1eb588cbf1
commit 59a2b9e5af
6 changed files with 91 additions and 6 deletions

60
modules/deepbooru.py Normal file
View File

@ -0,0 +1,60 @@
import os.path
from concurrent.futures import ProcessPoolExecutor
import numpy as np
import deepdanbooru as dd
import tensorflow as tf
def _load_tf_and_return_tags(pil_image, threshold):
this_folder = os.path.dirname(__file__)
model_path = os.path.join(this_folder, '..', 'models', 'deepbooru', 'deepdanbooru-v3-20211112-sgd-e28')
if not os.path.exists(model_path):
return "Download https://github.com/KichangKim/DeepDanbooru/releases/download/v3-20211112-sgd-e28/deepdanbooru-v3-20211112-sgd-e28.zip unpack and put into models/deepbooru"
tags = dd.project.load_tags_from_project(model_path)
model = dd.project.load_model_from_project(
model_path, compile_model=True
)
width = model.input_shape[2]
height = model.input_shape[1]
image = np.array(pil_image)
image = tf.image.resize(
image,
size=(height, width),
method=tf.image.ResizeMethod.AREA,
preserve_aspect_ratio=True,
)
image = image.numpy() # EagerTensor to np.array
image = dd.image.transform_and_pad_image(image, width, height)
image = image / 255.0
image_shape = image.shape
image = image.reshape((1, image_shape[0], image_shape[1], image_shape[2]))
y = model.predict(image)[0]
result_dict = {}
for i, tag in enumerate(tags):
result_dict[tag] = y[i]
result_tags_out = []
result_tags_print = []
for tag in tags:
if result_dict[tag] >= threshold:
result_tags_out.append(tag)
result_tags_print.append(f'{result_dict[tag]} {tag}')
print('\n'.join(sorted(result_tags_print, reverse=True)))
return ', '.join(result_tags_out)
def get_deepbooru_tags(pil_image, threshold=0.5):
with ProcessPoolExecutor() as executor:
f = executor.submit(_load_tf_and_return_tags, pil_image, threshold)
ret = f.result() # will rethrow any exceptions
return ret

View File

@ -23,6 +23,7 @@ import gradio.utils
import gradio.routes import gradio.routes
from modules import sd_hijack from modules import sd_hijack
from modules.deepbooru import get_deepbooru_tags
from modules.paths import script_path from modules.paths import script_path
from modules.shared import opts, cmd_opts from modules.shared import opts, cmd_opts
import modules.shared as shared import modules.shared as shared
@ -312,6 +313,11 @@ def interrogate(image):
return gr_show(True) if prompt is None else prompt return gr_show(True) if prompt is None else prompt
def interrogate_deepbooru(image):
prompt = get_deepbooru_tags(image)
return gr_show(True) if prompt is None else prompt
def create_seed_inputs(): def create_seed_inputs():
with gr.Row(): with gr.Row():
with gr.Box(): with gr.Box():
@ -439,15 +445,17 @@ def create_toprow(is_img2img):
outputs=[], outputs=[],
) )
with gr.Row(): with gr.Row(scale=1):
if is_img2img: if is_img2img:
interrogate = gr.Button('Interrogate', elem_id="interrogate") interrogate = gr.Button('Interrogate\nCLIP', elem_id="interrogate")
deepbooru = gr.Button('Interrogate\nDeepBooru', elem_id="deepbooru")
else: else:
interrogate = None interrogate = None
deepbooru = None
prompt_style_apply = gr.Button('Apply style', elem_id="style_apply") prompt_style_apply = gr.Button('Apply style', elem_id="style_apply")
save_style = gr.Button('Create style', elem_id="style_create") save_style = gr.Button('Create style', elem_id="style_create")
return prompt, roll, prompt_style, negative_prompt, prompt_style2, submit, interrogate, prompt_style_apply, save_style, paste, token_counter, token_button return prompt, roll, prompt_style, negative_prompt, prompt_style2, submit, interrogate, deepbooru, prompt_style_apply, save_style, paste, token_counter, token_button
def setup_progressbar(progressbar, preview, id_part, textinfo=None): def setup_progressbar(progressbar, preview, id_part, textinfo=None):
@ -476,7 +484,7 @@ def create_ui(wrap_gradio_gpu_call):
import modules.txt2img import modules.txt2img
with gr.Blocks(analytics_enabled=False) as txt2img_interface: with gr.Blocks(analytics_enabled=False) as txt2img_interface:
txt2img_prompt, roll, txt2img_prompt_style, txt2img_negative_prompt, txt2img_prompt_style2, submit, _, txt2img_prompt_style_apply, txt2img_save_style, paste, token_counter, token_button = create_toprow(is_img2img=False) txt2img_prompt, roll, txt2img_prompt_style, txt2img_negative_prompt, txt2img_prompt_style2, submit, _, _, txt2img_prompt_style_apply, txt2img_save_style, paste, token_counter, token_button = create_toprow(is_img2img=False)
dummy_component = gr.Label(visible=False) dummy_component = gr.Label(visible=False)
with gr.Row(elem_id='txt2img_progress_row'): with gr.Row(elem_id='txt2img_progress_row'):
@ -628,7 +636,7 @@ def create_ui(wrap_gradio_gpu_call):
token_button.click(fn=update_token_counter, inputs=[txt2img_prompt, steps], outputs=[token_counter]) token_button.click(fn=update_token_counter, inputs=[txt2img_prompt, steps], outputs=[token_counter])
with gr.Blocks(analytics_enabled=False) as img2img_interface: with gr.Blocks(analytics_enabled=False) as img2img_interface:
img2img_prompt, roll, img2img_prompt_style, img2img_negative_prompt, img2img_prompt_style2, submit, img2img_interrogate, img2img_prompt_style_apply, img2img_save_style, paste, token_counter, token_button = create_toprow(is_img2img=True) img2img_prompt, roll, img2img_prompt_style, img2img_negative_prompt, img2img_prompt_style2, submit, img2img_interrogate, img2img_deepbooru, img2img_prompt_style_apply, img2img_save_style, paste, token_counter, token_button = create_toprow(is_img2img=True)
with gr.Row(elem_id='img2img_progress_row'): with gr.Row(elem_id='img2img_progress_row'):
with gr.Column(scale=1): with gr.Column(scale=1):
@ -785,6 +793,12 @@ def create_ui(wrap_gradio_gpu_call):
outputs=[img2img_prompt], outputs=[img2img_prompt],
) )
img2img_deepbooru.click(
fn=interrogate_deepbooru,
inputs=[init_img],
outputs=[img2img_prompt],
)
save.click( save.click(
fn=wrap_gradio_call(save_files), fn=wrap_gradio_call(save_files),
_js="(x, y, z) => [x, y, selected_gallery_index()]", _js="(x, y, z) => [x, y, selected_gallery_index()]",

View File

@ -23,3 +23,6 @@ resize-right
torchdiffeq torchdiffeq
kornia kornia
lark lark
deepdanbooru
tensorflow
tensorflow-io

View File

@ -22,3 +22,6 @@ resize-right==0.0.2
torchdiffeq==0.2.3 torchdiffeq==0.2.3
kornia==0.6.7 kornia==0.6.7
lark==1.1.2 lark==1.1.2
git+https://github.com/KichangKim/DeepDanbooru.git@edf73df4cdaeea2cf00e9ac08bd8a9026b7a7b26#egg=deepdanbooru[tensorflow]
tensorflow==2.10.0
tensorflow-io==0.27.0

View File

@ -103,7 +103,12 @@
#style_apply, #style_create, #interrogate{ #style_apply, #style_create, #interrogate{
margin: 0.75em 0.25em 0.25em 0.25em; margin: 0.75em 0.25em 0.25em 0.25em;
min-width: 3em; min-width: 5em;
}
#style_apply, #style_create, #deepbooru{
margin: 0.75em 0.25em 0.25em 0.25em;
min-width: 5em;
} }
#style_pos_col, #style_neg_col{ #style_pos_col, #style_neg_col{