mirror of
https://github.com/AUTOMATIC1111/stable-diffusion-webui.git
synced 2024-06-07 21:20:49 +00:00
Add 'interrogate' and 'all' choices to --use-cpu
* Add 'interrogate' and 'all' choices to --use-cpu * Change type for --use-cpu argument to str.lower, so that choices are case insensitive
This commit is contained in:
parent
fdecb63685
commit
fdef8253a4
@ -34,7 +34,7 @@ def enable_tf32():
|
||||
|
||||
errors.run(enable_tf32, "Enabling TF32")
|
||||
|
||||
device = device_gfpgan = device_bsrgan = device_esrgan = device_scunet = device_codeformer = get_optimal_device()
|
||||
device = device_interrogate = device_gfpgan = device_bsrgan = device_esrgan = device_scunet = device_codeformer = get_optimal_device()
|
||||
dtype = torch.float16
|
||||
dtype_vae = torch.float16
|
||||
|
||||
|
@ -55,7 +55,7 @@ class InterrogateModels:
|
||||
|
||||
model, preprocess = clip.load(clip_model_name)
|
||||
model.eval()
|
||||
model = model.to(shared.device)
|
||||
model = model.to(devices.device_interrogate)
|
||||
|
||||
return model, preprocess
|
||||
|
||||
@ -65,14 +65,14 @@ class InterrogateModels:
|
||||
if not shared.cmd_opts.no_half:
|
||||
self.blip_model = self.blip_model.half()
|
||||
|
||||
self.blip_model = self.blip_model.to(shared.device)
|
||||
self.blip_model = self.blip_model.to(devices.device_interrogate)
|
||||
|
||||
if self.clip_model is None:
|
||||
self.clip_model, self.clip_preprocess = self.load_clip_model()
|
||||
if not shared.cmd_opts.no_half:
|
||||
self.clip_model = self.clip_model.half()
|
||||
|
||||
self.clip_model = self.clip_model.to(shared.device)
|
||||
self.clip_model = self.clip_model.to(devices.device_interrogate)
|
||||
|
||||
self.dtype = next(self.clip_model.parameters()).dtype
|
||||
|
||||
@ -99,11 +99,11 @@ class InterrogateModels:
|
||||
text_array = text_array[0:int(shared.opts.interrogate_clip_dict_limit)]
|
||||
|
||||
top_count = min(top_count, len(text_array))
|
||||
text_tokens = clip.tokenize([text for text in text_array], truncate=True).to(shared.device)
|
||||
text_tokens = clip.tokenize([text for text in text_array], truncate=True).to(devices.device_interrogate)
|
||||
text_features = self.clip_model.encode_text(text_tokens).type(self.dtype)
|
||||
text_features /= text_features.norm(dim=-1, keepdim=True)
|
||||
|
||||
similarity = torch.zeros((1, len(text_array))).to(shared.device)
|
||||
similarity = torch.zeros((1, len(text_array))).to(devices.device_interrogate)
|
||||
for i in range(image_features.shape[0]):
|
||||
similarity += (100.0 * image_features[i].unsqueeze(0) @ text_features.T).softmax(dim=-1)
|
||||
similarity /= image_features.shape[0]
|
||||
@ -116,7 +116,7 @@ class InterrogateModels:
|
||||
transforms.Resize((blip_image_eval_size, blip_image_eval_size), interpolation=InterpolationMode.BICUBIC),
|
||||
transforms.ToTensor(),
|
||||
transforms.Normalize((0.48145466, 0.4578275, 0.40821073), (0.26862954, 0.26130258, 0.27577711))
|
||||
])(pil_image).unsqueeze(0).type(self.dtype).to(shared.device)
|
||||
])(pil_image).unsqueeze(0).type(self.dtype).to(devices.device_interrogate)
|
||||
|
||||
with torch.no_grad():
|
||||
caption = self.blip_model.generate(gpu_image, sample=False, num_beams=shared.opts.interrogate_clip_num_beams, min_length=shared.opts.interrogate_clip_min_length, max_length=shared.opts.interrogate_clip_max_length)
|
||||
@ -140,7 +140,7 @@ class InterrogateModels:
|
||||
|
||||
res = caption
|
||||
|
||||
clip_image = self.clip_preprocess(pil_image).unsqueeze(0).type(self.dtype).to(shared.device)
|
||||
clip_image = self.clip_preprocess(pil_image).unsqueeze(0).type(self.dtype).to(devices.device_interrogate)
|
||||
|
||||
precision_scope = torch.autocast if shared.cmd_opts.precision == "autocast" else contextlib.nullcontext
|
||||
with torch.no_grad(), precision_scope("cuda"):
|
||||
|
@ -54,7 +54,7 @@ parser.add_argument("--opt-split-attention", action='store_true', help="force-en
|
||||
parser.add_argument("--opt-split-attention-invokeai", action='store_true', help="force-enables InvokeAI's cross-attention layer optimization. By default, it's on when cuda is unavailable.")
|
||||
parser.add_argument("--opt-split-attention-v1", action='store_true', help="enable older version of split attention optimization that does not consume all the VRAM it can find")
|
||||
parser.add_argument("--disable-opt-split-attention", action='store_true', help="force-disables cross-attention layer optimization")
|
||||
parser.add_argument("--use-cpu", nargs='+',choices=['SD', 'GFPGAN', 'BSRGAN', 'ESRGAN', 'SCUNet', 'CodeFormer'], help="use CPU as torch device for specified modules", default=[])
|
||||
parser.add_argument("--use-cpu", nargs='+',choices=['all', 'sd', 'interrogate', 'gfpgan', 'bsrgan', 'esrgan', 'scunet', 'codeformer'], help="use CPU as torch device for specified modules", default=[], type=str.lower)
|
||||
parser.add_argument("--listen", action='store_true', help="launch gradio with 0.0.0.0 as server name, allowing to respond to network requests")
|
||||
parser.add_argument("--port", type=int, help="launch gradio with given server port, you need root/admin rights for ports < 1024, defaults to 7860 if available", default=None)
|
||||
parser.add_argument("--show-negative-prompt", action='store_true', help="does not do anything", default=False)
|
||||
@ -76,8 +76,8 @@ parser.add_argument("--disable-safe-unpickle", action='store_true', help="disabl
|
||||
|
||||
cmd_opts = parser.parse_args()
|
||||
|
||||
devices.device, devices.device_gfpgan, devices.device_bsrgan, devices.device_esrgan, devices.device_scunet, devices.device_codeformer = \
|
||||
(devices.cpu if x in cmd_opts.use_cpu else devices.get_optimal_device() for x in ['SD', 'GFPGAN', 'BSRGAN', 'ESRGAN', 'SCUNet', 'CodeFormer'])
|
||||
devices.device, devices.device_interrogate, devices.device_gfpgan, devices.device_bsrgan, devices.device_esrgan, devices.device_scunet, devices.device_codeformer = \
|
||||
(devices.cpu if any(y in cmd_opts.use_cpu for y in [x, 'all']) else devices.get_optimal_device() for x in ['sd', 'interrogate', 'gfpgan', 'bsrgan', 'esrgan', 'scunet', 'codeformer'])
|
||||
|
||||
device = devices.device
|
||||
|
||||
|
Loading…
Reference in New Issue
Block a user