feat(diffusers): don't set seed in params and respect device (#1010)

**Description**

Follow-up to #998: respect the device used to load the model, and do not
specify a seed in the parameters; instead, configure only the
generator, as described in
https://huggingface.co/docs/diffusers/using-diffusers/reusing_seeds

Signed-off-by: Ettore Di Giacinto <mudler@localai.io>
This commit is contained in:
Ettore Di Giacinto 2023-09-04 19:38:38 +02:00 committed by GitHub
parent dc307a1cc0
commit 605c319157
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23

View File

@@ -221,8 +221,9 @@ class BackendServicer(backend_pb2_grpc.BackendServicer):
modelFileBase = os.path.dirname(request.ModelFile) modelFileBase = os.path.dirname(request.ModelFile)
# modify LoraAdapter to be relative to modelFileBase # modify LoraAdapter to be relative to modelFileBase
request.LoraAdapter = os.path.join(modelFileBase, request.LoraAdapter) request.LoraAdapter = os.path.join(modelFileBase, request.LoraAdapter)
device = "cpu" if not request.CUDA else "cuda"
self.device = device
if request.LoraAdapter: if request.LoraAdapter:
device = "cpu" if not request.CUDA else "cuda"
# Check if its a local file and not a directory ( we load lora differently for a safetensor file ) # Check if its a local file and not a directory ( we load lora differently for a safetensor file )
if os.path.exists(request.LoraAdapter) and not os.path.isdir(request.LoraAdapter): if os.path.exists(request.LoraAdapter) and not os.path.isdir(request.LoraAdapter):
self.load_lora_weights(request.LoraAdapter, 1, device, torchType) self.load_lora_weights(request.LoraAdapter, 1, device, torchType)
@@ -300,7 +301,6 @@ class BackendServicer(backend_pb2_grpc.BackendServicer):
"width": request.width, "width": request.width,
"height": request.height, "height": request.height,
"num_inference_steps": request.step, "num_inference_steps": request.step,
"seed": request.seed,
} }
if request.src != "": if request.src != "":
@@ -321,7 +321,7 @@ class BackendServicer(backend_pb2_grpc.BackendServicer):
# Set seed # Set seed
if request.seed > 0: if request.seed > 0:
kwargs["generator"] = torch.Generator(device="cuda").manual_seed( kwargs["generator"] = torch.Generator(device=self.device).manual_seed(
request.seed request.seed
) )