mirror of
https://github.com/AUTOMATIC1111/stable-diffusion-webui.git
synced 2024-06-07 21:20:49 +00:00
Add utility to inspect a model's parameters (to get dtype/device)
This commit is contained in:
parent
a84e842189
commit
5768afc776
@ -4,6 +4,7 @@ from functools import lru_cache
|
||||
|
||||
import torch
|
||||
from modules import errors, shared
|
||||
from modules.torch_utils import get_param
|
||||
|
||||
if sys.platform == "darwin":
|
||||
from modules import mac_specific
|
||||
@ -131,7 +132,7 @@ patch_module_list = [
|
||||
|
||||
|
||||
def manual_cast_forward(self, *args, **kwargs):
|
||||
org_dtype = next(self.parameters()).dtype
|
||||
org_dtype = get_param(self).dtype
|
||||
self.to(dtype)
|
||||
args = [arg.to(dtype) if isinstance(arg, torch.Tensor) else arg for arg in args]
|
||||
kwargs = {k: v.to(dtype) if isinstance(v, torch.Tensor) else v for k, v in kwargs.items()}
|
||||
|
@ -11,6 +11,7 @@ from torchvision import transforms
|
||||
from torchvision.transforms.functional import InterpolationMode
|
||||
|
||||
from modules import devices, paths, shared, lowvram, modelloader, errors
|
||||
from modules.torch_utils import get_param
|
||||
|
||||
blip_image_eval_size = 384
|
||||
clip_model_name = 'ViT-L/14'
|
||||
@ -131,7 +132,7 @@ class InterrogateModels:
|
||||
|
||||
self.clip_model = self.clip_model.to(devices.device_interrogate)
|
||||
|
||||
self.dtype = next(self.clip_model.parameters()).dtype
|
||||
self.dtype = get_param(self.clip_model).dtype
|
||||
|
||||
def send_clip_to_ram(self):
|
||||
if not shared.opts.interrogate_keep_models_in_memory:
|
||||
|
@ -6,6 +6,7 @@ import sgm.models.diffusion
|
||||
import sgm.modules.diffusionmodules.denoiser_scaling
|
||||
import sgm.modules.diffusionmodules.discretizer
|
||||
from modules import devices, shared, prompt_parser
|
||||
from modules.torch_utils import get_param
|
||||
|
||||
|
||||
def get_learned_conditioning(self: sgm.models.diffusion.DiffusionEngine, batch: prompt_parser.SdConditioning | list[str]):
|
||||
@ -90,7 +91,7 @@ sgm.modules.GeneralConditioner.get_target_prompt_token_count = get_target_prompt
|
||||
def extend_sdxl(model):
|
||||
"""this adds a bunch of parameters to make SDXL model look a bit more like SD1.5 to the rest of the codebase."""
|
||||
|
||||
dtype = next(model.model.diffusion_model.parameters()).dtype
|
||||
dtype = get_param(model.model.diffusion_model).dtype
|
||||
model.model.diffusion_model.dtype = dtype
|
||||
model.model.conditioning_key = 'crossattn'
|
||||
model.cond_stage_key = 'txt'
|
||||
|
17
modules/torch_utils.py
Normal file
17
modules/torch_utils.py
Normal file
@ -0,0 +1,17 @@
|
||||
from __future__ import annotations
|
||||
|
||||
import torch.nn
|
||||
|
||||
|
||||
def get_param(model) -> torch.nn.Parameter:
|
||||
"""
|
||||
Find the first parameter in a model or module.
|
||||
"""
|
||||
if hasattr(model, "model") and hasattr(model.model, "parameters"):
|
||||
# Unpeel a model descriptor to get at the actual Torch module.
|
||||
model = model.model
|
||||
|
||||
for param in model.parameters():
|
||||
return param
|
||||
|
||||
raise ValueError(f"No parameters found in model {model!r}")
|
@ -7,6 +7,7 @@ import tqdm
|
||||
from PIL import Image
|
||||
|
||||
from modules import images, shared
|
||||
from modules.torch_utils import get_param
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
@ -17,8 +18,8 @@ def upscale_without_tiling(model, img: Image.Image):
|
||||
img = np.ascontiguousarray(np.transpose(img, (2, 0, 1))) / 255
|
||||
img = torch.from_numpy(img).float()
|
||||
|
||||
model_weight = next(iter(model.model.parameters()))
|
||||
img = img.unsqueeze(0).to(device=model_weight.device, dtype=model_weight.dtype)
|
||||
param = get_param(model)
|
||||
img = img.unsqueeze(0).to(device=param.device, dtype=param.dtype)
|
||||
|
||||
with torch.no_grad():
|
||||
output = model(img)
|
||||
|
@ -5,6 +5,9 @@ from transformers.models.xlm_roberta.configuration_xlm_roberta import XLMRoberta
|
||||
from transformers import XLMRobertaModel,XLMRobertaTokenizer
|
||||
from typing import Optional
|
||||
|
||||
from modules.torch_utils import get_param
|
||||
|
||||
|
||||
class BertSeriesConfig(BertConfig):
|
||||
def __init__(self, vocab_size=30522, hidden_size=768, num_hidden_layers=12, num_attention_heads=12, intermediate_size=3072, hidden_act="gelu", hidden_dropout_prob=0.1, attention_probs_dropout_prob=0.1, max_position_embeddings=512, type_vocab_size=2, initializer_range=0.02, layer_norm_eps=1e-12, pad_token_id=0, position_embedding_type="absolute", use_cache=True, classifier_dropout=None,project_dim=512, pooler_fn="average",learn_encoder=False,model_type='bert',**kwargs):
|
||||
|
||||
@ -62,7 +65,7 @@ class BertSeriesModelWithTransformation(BertPreTrainedModel):
|
||||
self.post_init()
|
||||
|
||||
def encode(self,c):
|
||||
device = next(self.parameters()).device
|
||||
device = get_param(self).device
|
||||
text = self.tokenizer(c,
|
||||
truncation=True,
|
||||
max_length=77,
|
||||
|
@ -5,6 +5,9 @@ from transformers.models.xlm_roberta.configuration_xlm_roberta import XLMRoberta
|
||||
from transformers import XLMRobertaModel,XLMRobertaTokenizer
|
||||
from typing import Optional
|
||||
|
||||
from modules.torch_utils import get_param
|
||||
|
||||
|
||||
class BertSeriesConfig(BertConfig):
|
||||
def __init__(self, vocab_size=30522, hidden_size=768, num_hidden_layers=12, num_attention_heads=12, intermediate_size=3072, hidden_act="gelu", hidden_dropout_prob=0.1, attention_probs_dropout_prob=0.1, max_position_embeddings=512, type_vocab_size=2, initializer_range=0.02, layer_norm_eps=1e-12, pad_token_id=0, position_embedding_type="absolute", use_cache=True, classifier_dropout=None,project_dim=512, pooler_fn="average",learn_encoder=False,model_type='bert',**kwargs):
|
||||
|
||||
@ -68,7 +71,7 @@ class BertSeriesModelWithTransformation(BertPreTrainedModel):
|
||||
self.post_init()
|
||||
|
||||
def encode(self,c):
|
||||
device = next(self.parameters()).device
|
||||
device = get_param(self).device
|
||||
text = self.tokenizer(c,
|
||||
truncation=True,
|
||||
max_length=77,
|
||||
|
19
test/test_torch_utils.py
Normal file
19
test/test_torch_utils.py
Normal file
@ -0,0 +1,19 @@
|
||||
import types
|
||||
|
||||
import pytest
|
||||
import torch
|
||||
|
||||
from modules.torch_utils import get_param
|
||||
|
||||
|
||||
@pytest.mark.parametrize("wrapped", [True, False])
|
||||
def test_get_param(wrapped):
|
||||
mod = torch.nn.Linear(1, 1)
|
||||
cpu = torch.device("cpu")
|
||||
mod.to(dtype=torch.float16, device=cpu)
|
||||
if wrapped:
|
||||
# more or less how spandrel wraps a thing
|
||||
mod = types.SimpleNamespace(model=mod)
|
||||
p = get_param(mod)
|
||||
assert p.dtype == torch.float16
|
||||
assert p.device == cpu
|
Loading…
Reference in New Issue
Block a user