mirror of
https://github.com/mudler/LocalAI.git
synced 2024-06-07 19:40:48 +00:00
feat(diffusers): don't set seed in params and respect device (#1010)
**Description** Follow-up of #998 — respect the device used to load the model, and do not specify a seed in the generation parameters; instead, configure the generator as described in https://huggingface.co/docs/diffusers/using-diffusers/reusing_seeds

Signed-off-by: Ettore Di Giacinto <mudler@localai.io>
This commit is contained in:
parent
dc307a1cc0
commit
605c319157
@ -221,8 +221,9 @@ class BackendServicer(backend_pb2_grpc.BackendServicer):
|
|||||||
modelFileBase = os.path.dirname(request.ModelFile)
|
modelFileBase = os.path.dirname(request.ModelFile)
|
||||||
# modify LoraAdapter to be relative to modelFileBase
|
# modify LoraAdapter to be relative to modelFileBase
|
||||||
request.LoraAdapter = os.path.join(modelFileBase, request.LoraAdapter)
|
request.LoraAdapter = os.path.join(modelFileBase, request.LoraAdapter)
|
||||||
|
device = "cpu" if not request.CUDA else "cuda"
|
||||||
|
self.device = device
|
||||||
if request.LoraAdapter:
|
if request.LoraAdapter:
|
||||||
device = "cpu" if not request.CUDA else "cuda"
|
|
||||||
# Check if its a local file and not a directory ( we load lora differently for a safetensor file )
|
# Check if its a local file and not a directory ( we load lora differently for a safetensor file )
|
||||||
if os.path.exists(request.LoraAdapter) and not os.path.isdir(request.LoraAdapter):
|
if os.path.exists(request.LoraAdapter) and not os.path.isdir(request.LoraAdapter):
|
||||||
self.load_lora_weights(request.LoraAdapter, 1, device, torchType)
|
self.load_lora_weights(request.LoraAdapter, 1, device, torchType)
|
||||||
@ -300,7 +301,6 @@ class BackendServicer(backend_pb2_grpc.BackendServicer):
|
|||||||
"width": request.width,
|
"width": request.width,
|
||||||
"height": request.height,
|
"height": request.height,
|
||||||
"num_inference_steps": request.step,
|
"num_inference_steps": request.step,
|
||||||
"seed": request.seed,
|
|
||||||
}
|
}
|
||||||
|
|
||||||
if request.src != "":
|
if request.src != "":
|
||||||
@ -321,7 +321,7 @@ class BackendServicer(backend_pb2_grpc.BackendServicer):
|
|||||||
|
|
||||||
# Set seed
|
# Set seed
|
||||||
if request.seed > 0:
|
if request.seed > 0:
|
||||||
kwargs["generator"] = torch.Generator(device="cuda").manual_seed(
|
kwargs["generator"] = torch.Generator(device=self.device).manual_seed(
|
||||||
request.seed
|
request.seed
|
||||||
)
|
)
|
||||||
|
|
||||||
|
Loading…
Reference in New Issue
Block a user