From 75b67eebf21f72f5b693926476d9c3b12471f0d6 Mon Sep 17 00:00:00 2001 From: Sena <34237511+sena-nana@users.noreply.github.com> Date: Wed, 23 Nov 2022 17:43:58 +0800 Subject: [PATCH 1/2] Fix bare base64 not accept --- modules/api/api.py | 13 ++++++++++--- 1 file changed, 10 insertions(+), 3 deletions(-) diff --git a/modules/api/api.py b/modules/api/api.py index 7a567be38..648bd6a86 100644 --- a/modules/api/api.py +++ b/modules/api/api.py @@ -3,6 +3,7 @@ import io import time import uvicorn from threading import Lock +from io import BytesIO from gradio.processing_utils import encode_pil_to_base64, decode_base64_to_file, decode_base64_to_image from fastapi import APIRouter, Depends, FastAPI, HTTPException from fastapi.security import HTTPBasic, HTTPBasicCredentials @@ -13,7 +14,7 @@ from modules import sd_samplers, deepbooru from modules.api.models import * from modules.processing import StableDiffusionProcessingTxt2Img, StableDiffusionProcessingImg2Img, process_images from modules.extras import run_extras, run_pnginfo -from PIL import PngImagePlugin +from PIL import PngImagePlugin,Image from modules.sd_models import checkpoints_list from modules.realesrgan_model import get_realesrgan_models from typing import List @@ -133,7 +134,10 @@ class Api: mask = img2imgreq.mask if mask: - mask = decode_base64_to_image(mask) + if mask.startswith("data:image/"): + mask = decode_base64_to_image(mask) + else: + mask = Image.open(BytesIO(base64.b64decode(mask))) populate = img2imgreq.copy(update={ # Override __init__ params "sd_model": shared.sd_model, @@ -147,7 +151,10 @@ class Api: imgs = [] for img in init_images: - img = decode_base64_to_image(img) + if img.startswith("data:image/"): + img = decode_base64_to_image(img) + else: + img = Image.open(BytesIO(base64.b64decode(img))) imgs = [img] * p.batch_size p.init_images = imgs From fcd75bd8740855e0c7bc80c0e8a4e1033b76d007 Mon Sep 17 00:00:00 2001 From: Sena <34237511+sena-nana@users.noreply.github.com> Date: Thu, 24 Nov 2022 13:10:40 +0800 Subject: [PATCH 2/2] Fix other apis --- modules/api/api.py | 16 +++++++--------- 1 file changed, 7 insertions(+), 9 deletions(-) diff --git a/modules/api/api.py b/modules/api/api.py index 648bd6a86..efcedbba2 100644 --- a/modules/api/api.py +++ b/modules/api/api.py @@ -4,7 +4,7 @@ import time import uvicorn from threading import Lock from io import BytesIO -from gradio.processing_utils import encode_pil_to_base64, decode_base64_to_file, decode_base64_to_image +from gradio.processing_utils import decode_base64_to_file from fastapi import APIRouter, Depends, FastAPI, HTTPException from fastapi.security import HTTPBasic, HTTPBasicCredentials from secrets import compare_digest @@ -41,6 +41,10 @@ def setUpscalers(req: dict): reqDict.pop('upscaler_2') return reqDict +def decode_base64_to_image(encoding): + if encoding.startswith("data:image/"): + encoding = encoding.split(";")[1].split(",")[1] + return Image.open(BytesIO(base64.b64decode(encoding))) def encode_pil_to_base64(image): with io.BytesIO() as output_bytes: @@ -134,10 +138,7 @@ class Api: mask = img2imgreq.mask if mask: - if mask.startswith("data:image/"): - mask = decode_base64_to_image(mask) - else: - mask = Image.open(BytesIO(base64.b64decode(mask))) + mask = decode_base64_to_image(mask) populate = img2imgreq.copy(update={ # Override __init__ params "sd_model": shared.sd_model, @@ -151,10 +152,7 @@ class Api: imgs = [] for img in init_images: - if img.startswith("data:image/"): - img = decode_base64_to_image(img) - else: - img = Image.open(BytesIO(base64.b64decode(img))) + img = decode_base64_to_image(img) imgs = [img] * p.batch_size p.init_images = imgs