prototype progress api

This commit is contained in:
evshiron 2022-10-26 22:33:45 +08:00
parent 99d728b5b1
commit fddb4883f4
2 changed files with 88 additions and 14 deletions

View File

@@ -1,8 +1,11 @@
import time
from modules.api.models import StableDiffusionTxt2ImgProcessingAPI, StableDiffusionImg2ImgProcessingAPI
from modules.processing import StableDiffusionProcessingTxt2Img, StableDiffusionProcessingImg2Img, process_images
from modules.sd_samplers import all_samplers
from modules.extras import run_pnginfo
import modules.shared as shared
from modules import devices
import uvicorn
from fastapi import Body, APIRouter, HTTPException
from fastapi.responses import JSONResponse
@@ -25,6 +28,37 @@ class ImageToImageResponse(BaseModel):
    parameters: Json
    info: Json
class ProgressResponse(BaseModel):
    progress: float
    eta_relative: float
    state: Json
# copied from wrap_gradio_gpu_call in webui.py,
# because the queue lock is acquired inside the API handlers
# and time_start needs to be set here;
# the original function has been split into two parts
def before_gpu_call():
    devices.torch_gc()
    shared.state.sampling_step = 0
    shared.state.job_count = -1
    shared.state.job_no = 0
    shared.state.job_timestamp = shared.state.get_job_timestamp()
    shared.state.current_latent = None
    shared.state.current_image = None
    shared.state.current_image_sampling_step = 0
    shared.state.skipped = False
    shared.state.interrupted = False
    shared.state.textinfo = None
    shared.state.time_start = time.time()
def after_gpu_call():
    shared.state.job = ""
    shared.state.job_count = 0
    devices.torch_gc()
class Api:
    def __init__(self, app, queue_lock):
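For orientation, a successful call to the new progress endpoint defined further down returns a body roughly of this shape (all values below are invented; the state field carries whatever the new State.js() helper in the second file serialises, and whether it arrives as a nested object or as a JSON-encoded string depends on how pydantic's Json type is rendered in the response):

{
    "progress": 0.385,
    "eta_relative": 47.9,
    "state": {
        "skipped": false,
        "interrupted": false,
        "job": "",
        "job_count": 1,
        "job_no": 0,
        "sampling_step": 10,
        "sampling_steps": 20
    }
}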
@@ -33,6 +67,7 @@ class Api:
        self.queue_lock = queue_lock
        self.app.add_api_route("/sdapi/v1/txt2img", self.text2imgapi, methods=["POST"])
        self.app.add_api_route("/sdapi/v1/img2img", self.img2imgapi, methods=["POST"])
        self.app.add_api_route("/sdapi/v1/progress", self.progressapi, methods=["GET"])

    def __base64_to_image(self, base64_string):
        # if has a comma, deal with prefix
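The route registered above can be smoke-tested with a plain GET once the server is running. A sketch, assuming the webui API is enabled and reachable at its default local address:

import requests

# hypothetical quick check of the new endpoint
print(requests.get("http://127.0.0.1:7860/sdapi/v1/progress").json())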
@@ -44,12 +79,12 @@ class Api:
    def text2imgapi(self, txt2imgreq: StableDiffusionTxt2ImgProcessingAPI):
        sampler_index = sampler_to_index(txt2imgreq.sampler_index)
        if sampler_index is None:
            raise HTTPException(status_code=404, detail="Sampler not found")
        populate = txt2imgreq.copy(update={ # Override __init__ params
            "sd_model": shared.sd_model,
            "sampler_index": sampler_index[0],
            "do_not_save_samples": True,
            "do_not_save_grid": True
@@ -57,9 +92,11 @@ class Api:
        )
        p = StableDiffusionProcessingTxt2Img(**vars(populate))
        # Override object param
        before_gpu_call()
        with self.queue_lock:
            processed = process_images(p)
        after_gpu_call()
        b64images = []
        for i in processed.images:
            buffer = io.BytesIO()
@@ -67,30 +104,30 @@ class Api:
            b64images.append(base64.b64encode(buffer.getvalue()))
        return TextToImageResponse(images=b64images, parameters=json.dumps(vars(txt2imgreq)), info=processed.js())

    def img2imgapi(self, img2imgreq: StableDiffusionImg2ImgProcessingAPI):
        sampler_index = sampler_to_index(img2imgreq.sampler_index)
        if sampler_index is None:
            raise HTTPException(status_code=404, detail="Sampler not found")
        init_images = img2imgreq.init_images
        if init_images is None:
            raise HTTPException(status_code=404, detail="Init image not found")
        mask = img2imgreq.mask
        if mask:
            mask = self.__base64_to_image(mask)
        populate = img2imgreq.copy(update={ # Override __init__ params
            "sd_model": shared.sd_model,
            "sampler_index": sampler_index[0],
            "do_not_save_samples": True,
            "do_not_save_grid": True,
            "mask": mask
            }
        )
@@ -103,9 +140,11 @@ class Api:
        p.init_images = imgs
        # Override object param
        before_gpu_call()
        with self.queue_lock:
            processed = process_images(p)
        after_gpu_call()
        b64images = []
        for i in processed.images:
            buffer = io.BytesIO()
@@ -118,6 +157,28 @@ class Api:
        return ImageToImageResponse(images=b64images, parameters=json.dumps(vars(img2imgreq)), info=processed.js())

    def progressapi(self):
        # copied from check_progress_call in ui.py
        if shared.state.job_count == 0:
            return ProgressResponse(progress=0, eta_relative=0, state=shared.state.js())
        # avoid division by zero
        progress = 0.01
        if shared.state.job_count > 0:
            progress += shared.state.job_no / shared.state.job_count
        if shared.state.sampling_steps > 0:
            progress += 1 / shared.state.job_count * shared.state.sampling_step / shared.state.sampling_steps
        time_since_start = time.time() - shared.state.time_start
        eta = time_since_start / progress
        eta_relative = eta - time_since_start
        progress = min(progress, 1)
        return ProgressResponse(progress=progress, eta_relative=eta_relative, state=shared.state.js())
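As a worked example of the arithmetic above (all numbers invented): with job_no = 1 of job_count = 4, sampling_step = 10 of sampling_steps = 20, and 30 seconds elapsed since time_start:

progress = 0.01 + 1/4 + (1/4) * (10/20)   # 0.01 + 0.25 + 0.125 = 0.385
eta = 30 / 0.385                          # ~77.9 s estimated total duration
eta_relative = eta - 30                   # ~47.9 s estimated time remaining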
    def extrasapi(self):
        raise NotImplementedError
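Putting it together, a minimal polling sketch (hypothetical client code, not part of this commit; it assumes the webui API is reachable at the default local address and that a generation request, e.g. a POST to /sdapi/v1/txt2img, is running from another thread or process):

import time
import requests

BASE_URL = "http://127.0.0.1:7860"  # assumed default local address

# poll the new endpoint once a second while a generation runs elsewhere;
# progress drops back to 0 once after_gpu_call() resets job_count
for _ in range(30):
    resp = requests.get(f"{BASE_URL}/sdapi/v1/progress").json()
    print(f"progress={resp['progress']:.2f}  eta_relative={resp['eta_relative']:.1f}s")
    time.sleep(1)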

View File

@@ -146,6 +146,19 @@ class State:
    def get_job_timestamp(self):
        return datetime.datetime.now().strftime("%Y%m%d%H%M%S") # shouldn't this return job_timestamp?
    def js(self):
        obj = {
            "skipped": self.skipped,
            "interrupted": self.interrupted,
            "job": self.job,
            "job_count": self.job_count,
            "job_no": self.job_no,
            "sampling_step": self.sampling_step,
            "sampling_steps": self.sampling_steps,
        }
        return json.dumps(obj)
state = State()
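A quick way to see what the new helper emits (a sketch, assuming an interactive session where the webui modules are importable):

import json
import modules.shared as shared

# js() returns a JSON string; loading it back yields the plain dict built above
print(json.loads(shared.state.js()))
# e.g. {'skipped': False, 'interrupted': False, 'job': '', 'job_count': 0,
#       'job_no': 0, 'sampling_step': 0, 'sampling_steps': 0}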