From 702a1e1cc70240f2adbcfb707a644a5a98b5443c Mon Sep 17 00:00:00 2001
From: superhero-7 <537093830@qq.com>
Date: Sat, 23 Sep 2023 17:51:41 +0800
Subject: [PATCH] support m18

Add support for AltDiffusion-m18 checkpoints: a new inference config with
a 1024-dim cross-attention context, an XLM-RoBERTa-Large text encoder
(modules/xlmr_m18.py) that projects the penultimate hidden layer to that
context, hijack handling for the new encoder, and automatic config
selection based on the width of cond_stage_model.transformation.weight.
---
 configs/alt-diffusion-m18-inference.yaml |  73 ++++++++++
 modules/sd_hijack.py                     |   6 +-
 modules/sd_models_config.py              |   6 +-
 modules/xlmr_m18.py                      | 164 +++++++++++++++++++++++
 4 files changed, 244 insertions(+), 5 deletions(-)
 create mode 100644 configs/alt-diffusion-m18-inference.yaml
 create mode 100644 modules/xlmr_m18.py

diff --git a/configs/alt-diffusion-m18-inference.yaml b/configs/alt-diffusion-m18-inference.yaml
new file mode 100644
index 000000000..41a031d55
--- /dev/null
+++ b/configs/alt-diffusion-m18-inference.yaml
@@ -0,0 +1,73 @@
+model:
+  base_learning_rate: 1.0e-04
+  target: ldm.models.diffusion.ddpm.LatentDiffusion
+  params:
+    linear_start: 0.00085
+    linear_end: 0.0120
+    num_timesteps_cond: 1
+    log_every_t: 200
+    timesteps: 1000
+    first_stage_key: "jpg"
+    cond_stage_key: "txt"
+    image_size: 64
+    channels: 4
+    cond_stage_trainable: false # Note: different from the one we trained before
+    conditioning_key: crossattn
+    monitor: val/loss_simple_ema
+    scale_factor: 0.18215
+    use_ema: False
+
+    scheduler_config: # 10000 warmup steps
+      target: ldm.lr_scheduler.LambdaLinearScheduler
+      params:
+        warm_up_steps: [ 10000 ]
+        cycle_lengths: [ 10000000000000 ] # incredibly large number to prevent corner cases
+        f_start: [ 1.e-6 ]
+        f_max: [ 1. ]
+        f_min: [ 1. ]
+
+    unet_config:
+      target: ldm.modules.diffusionmodules.openaimodel.UNetModel
+      params:
+        image_size: 32 # unused
+        in_channels: 4
+        out_channels: 4
+        model_channels: 320
+        attention_resolutions: [ 4, 2, 1 ]
+        num_res_blocks: 2
+        channel_mult: [ 1, 2, 4, 4 ]
+        num_head_channels: 64
+        use_spatial_transformer: True
+        use_linear_in_transformer: True
+        transformer_depth: 1
+        context_dim: 1024
+        use_checkpoint: True
+        legacy: False
+
+    first_stage_config:
+      target: ldm.models.autoencoder.AutoencoderKL
+      params:
+        embed_dim: 4
+        monitor: val/rec_loss
+        ddconfig:
+          double_z: true
+          z_channels: 4
+          resolution: 256
+          in_channels: 3
+          out_ch: 3
+          ch: 128
+          ch_mult:
+          - 1
+          - 2
+          - 4
+          - 4
+          num_res_blocks: 2
+          attn_resolutions: []
+          dropout: 0.0
+        lossconfig:
+          target: torch.nn.Identity
+
+    cond_stage_config:
+      target: modules.xlmr_m18.BertSeriesModelWithTransformation
+      params:
+        name: "XLMR-Large"
diff --git a/modules/sd_hijack.py b/modules/sd_hijack.py
index 592f00551..ae9b2a656 100644
--- a/modules/sd_hijack.py
+++ b/modules/sd_hijack.py
@@ -5,7 +5,7 @@ from types import MethodType
 from modules import devices, sd_hijack_optimizations, shared, script_callbacks, errors, sd_unet
 from modules.hypernetworks import hypernetwork
 from modules.shared import cmd_opts
-from modules import sd_hijack_clip, sd_hijack_open_clip, sd_hijack_unet, sd_hijack_xlmr, xlmr
+from modules import sd_hijack_clip, sd_hijack_open_clip, sd_hijack_unet, sd_hijack_xlmr, xlmr, xlmr_m18
 
 import ldm.modules.attention
 import ldm.modules.diffusionmodules.model
@@ -208,11 +208,10 @@ class StableDiffusionModelHijack:
         else:
             m.cond_stage_model = conditioner
 
-        if type(m.cond_stage_model) == xlmr.BertSeriesModelWithTransformation:
+        if type(m.cond_stage_model) == xlmr.BertSeriesModelWithTransformation or type(m.cond_stage_model) == xlmr_m18.BertSeriesModelWithTransformation:
             model_embeddings = m.cond_stage_model.roberta.embeddings
             model_embeddings.token_embedding = EmbeddingsWithFixes(model_embeddings.word_embeddings, self)
             m.cond_stage_model = sd_hijack_xlmr.FrozenXLMREmbedderWithCustomWords(m.cond_stage_model, self)
-
         elif type(m.cond_stage_model) == ldm.modules.encoders.modules.FrozenCLIPEmbedder:
             model_embeddings = m.cond_stage_model.transformer.text_model.embeddings
             model_embeddings.token_embedding = EmbeddingsWithFixes(model_embeddings.token_embedding, self)
@@ -258,7 +257,6 @@ class StableDiffusionModelHijack:
 
             if hasattr(m, 'cond_stage_model'):
                 delattr(m, 'cond_stage_model')
-
         elif type(m.cond_stage_model) == sd_hijack_xlmr.FrozenXLMREmbedderWithCustomWords:
             m.cond_stage_model = m.cond_stage_model.wrapped
 
diff --git a/modules/sd_models_config.py b/modules/sd_models_config.py
index 08dd03f19..9ba89dfc0 100644
--- a/modules/sd_models_config.py
+++ b/modules/sd_models_config.py
@@ -21,7 +21,7 @@ config_unopenclip = os.path.join(sd_repo_configs_path, "v2-1-stable-unclip-h-inf
 config_inpainting = os.path.join(sd_configs_path, "v1-inpainting-inference.yaml")
 config_instruct_pix2pix = os.path.join(sd_configs_path, "instruct-pix2pix.yaml")
 config_alt_diffusion = os.path.join(sd_configs_path, "alt-diffusion-inference.yaml")
-
+config_alt_diffusion_m18 = os.path.join(sd_configs_path, "alt-diffusion-m18-inference.yaml")
 
 def is_using_v_parameterization_for_sd2(state_dict):
     """
@@ -95,7 +95,11 @@
     if diffusion_model_input.shape[1] == 8:
         return config_instruct_pix2pix
 
+
+    # import pdb; pdb.set_trace()
     if sd.get('cond_stage_model.roberta.embeddings.word_embeddings.weight', None) is not None:
+        if sd.get('cond_stage_model.transformation.weight').size()[0] == 1024:
+            return config_alt_diffusion_m18
         return config_alt_diffusion
 
     return config_default
diff --git a/modules/xlmr_m18.py b/modules/xlmr_m18.py
new file mode 100644
index 000000000..18785692a
--- /dev/null
+++ b/modules/xlmr_m18.py
@@ -0,0 +1,164 @@
+from transformers import BertPreTrainedModel,BertModel,BertConfig
+import torch.nn as nn
+import torch
+from transformers.models.xlm_roberta.configuration_xlm_roberta import XLMRobertaConfig
+from transformers import XLMRobertaModel,XLMRobertaTokenizer
+from typing import Optional
+
+class BertSeriesConfig(BertConfig):
+    def __init__(self, vocab_size=30522, hidden_size=768, num_hidden_layers=12, num_attention_heads=12, intermediate_size=3072, hidden_act="gelu", hidden_dropout_prob=0.1, attention_probs_dropout_prob=0.1, max_position_embeddings=512, type_vocab_size=2, initializer_range=0.02, layer_norm_eps=1e-12, pad_token_id=0, position_embedding_type="absolute", use_cache=True, classifier_dropout=None,project_dim=512, pooler_fn="average",learn_encoder=False,model_type='bert',**kwargs):
+
+        super().__init__(vocab_size, hidden_size, num_hidden_layers, num_attention_heads, intermediate_size, hidden_act, hidden_dropout_prob, attention_probs_dropout_prob, max_position_embeddings, type_vocab_size, initializer_range, layer_norm_eps, pad_token_id, position_embedding_type, use_cache, classifier_dropout, **kwargs)
+        self.project_dim = project_dim
+        self.pooler_fn = pooler_fn
+        self.learn_encoder = learn_encoder
+
+class RobertaSeriesConfig(XLMRobertaConfig):
+    def __init__(self, pad_token_id=1, bos_token_id=0, eos_token_id=2,project_dim=512,pooler_fn='cls',learn_encoder=False, **kwargs):
+        super().__init__(pad_token_id=pad_token_id, bos_token_id=bos_token_id, eos_token_id=eos_token_id, **kwargs)
+        self.project_dim = project_dim
+        self.pooler_fn = pooler_fn
+        self.learn_encoder = learn_encoder
+
+
+class BertSeriesModelWithTransformation(BertPreTrainedModel):
+
+    _keys_to_ignore_on_load_unexpected = [r"pooler"]
+    _keys_to_ignore_on_load_missing = [r"position_ids", r"predictions.decoder.bias"]
+    config_class = BertSeriesConfig
+
+    def __init__(self, config=None, **kargs):
+        # modify initialization for autoloading
+        if config is None:
+            config = XLMRobertaConfig()
+            config.attention_probs_dropout_prob= 0.1
+            config.bos_token_id=0
+            config.eos_token_id=2
+            config.hidden_act='gelu'
+            config.hidden_dropout_prob=0.1
+            config.hidden_size=1024
+            config.initializer_range=0.02
+            config.intermediate_size=4096
+            config.layer_norm_eps=1e-05
+            config.max_position_embeddings=514
+
+            config.num_attention_heads=16
+            config.num_hidden_layers=24
+            config.output_past=True
+            config.pad_token_id=1
+            config.position_embedding_type= "absolute"
+
+            config.type_vocab_size= 1
+            config.use_cache=True
+            config.vocab_size= 250002
+            config.project_dim = 1024
+            config.learn_encoder = False
+        super().__init__(config)
+        self.roberta = XLMRobertaModel(config)
+        self.transformation = nn.Linear(config.hidden_size,config.project_dim)
+        # self.pre_LN=nn.LayerNorm(config.hidden_size, eps=config.layer_norm_eps)
+        self.tokenizer = XLMRobertaTokenizer.from_pretrained('xlm-roberta-large')
+        # self.pooler = lambda x: x[:,0]
+        # self.post_init()
+
+        self.has_pre_transformation = True
+        if self.has_pre_transformation:
+            self.transformation_pre = nn.Linear(config.hidden_size, config.project_dim)
+            self.pre_LN = nn.LayerNorm(config.hidden_size, eps=config.layer_norm_eps)
+        self.post_init()
+
+    def encode(self,c):
+        device = next(self.parameters()).device
+        text = self.tokenizer(c,
+                        truncation=True,
+                        max_length=77,
+                        return_length=False,
+                        return_overflowing_tokens=False,
+                        padding="max_length",
+                        return_tensors="pt")
+        text["input_ids"] = torch.tensor(text["input_ids"]).to(device)
+        text["attention_mask"] = torch.tensor(
+            text['attention_mask']).to(device)
+        features = self(**text)
+        return features['projection_state']
+
+    def forward(
+        self,
+        input_ids: Optional[torch.Tensor] = None,
+        attention_mask: Optional[torch.Tensor] = None,
+        token_type_ids: Optional[torch.Tensor] = None,
+        position_ids: Optional[torch.Tensor] = None,
+        head_mask: Optional[torch.Tensor] = None,
+        inputs_embeds: Optional[torch.Tensor] = None,
+        encoder_hidden_states: Optional[torch.Tensor] = None,
+        encoder_attention_mask: Optional[torch.Tensor] = None,
+        output_attentions: Optional[bool] = None,
+        return_dict: Optional[bool] = None,
+        output_hidden_states: Optional[bool] = None,
+    ) :
+        r"""
+        """
+
+        return_dict = return_dict if return_dict is not None else self.config.use_return_dict
+
+
+        outputs = self.roberta(
+            input_ids=input_ids,
+            attention_mask=attention_mask,
+            token_type_ids=token_type_ids,
+            position_ids=position_ids,
+            head_mask=head_mask,
+            inputs_embeds=inputs_embeds,
+            encoder_hidden_states=encoder_hidden_states,
+            encoder_attention_mask=encoder_attention_mask,
+            output_attentions=output_attentions,
+            output_hidden_states=True,
+            return_dict=return_dict,
+        )
+
+        # # last module outputs
+        # sequence_output = outputs[0]
+
+
+        # # project every module
+        # sequence_output_ln = self.pre_LN(sequence_output)
+
+        # # pooler
+        # pooler_output = self.pooler(sequence_output_ln)
+        # pooler_output = self.transformation(pooler_output)
+        # projection_state = self.transformation(outputs.last_hidden_state)
+
+        if self.has_pre_transformation:
+            sequence_output2 = outputs["hidden_states"][-2]
+            sequence_output2 = self.pre_LN(sequence_output2)
+            projection_state2 = self.transformation_pre(sequence_output2)
+
+            return {
+                "projection_state": projection_state2,
+                "last_hidden_state": outputs.last_hidden_state,
+                "hidden_states": outputs.hidden_states,
+                "attentions": outputs.attentions,
+            }
+        else:
+            projection_state = self.transformation(outputs.last_hidden_state)
+            return {
+                "projection_state": projection_state,
+                "last_hidden_state": outputs.last_hidden_state,
+                "hidden_states": outputs.hidden_states,
+                "attentions": outputs.attentions,
+            }
+
+
+        # return {
+        #     'pooler_output':pooler_output,
+        #     'last_hidden_state':outputs.last_hidden_state,
+        #     'hidden_states':outputs.hidden_states,
+        #     'attentions':outputs.attentions,
+        #     'projection_state':projection_state,
+        #     'sequence_out': sequence_output
+        # }
+
+
+class RobertaSeriesModelWithTransformation(BertSeriesModelWithTransformation):
+    base_model_prefix = 'roberta'
+    config_class= RobertaSeriesConfig
\ No newline at end of file
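
Note (not part of the patch): the sd_models_config.py hunk above routes a checkpoint to the new
m18 config by inspecting the text-encoder projection weight. Below is a minimal, self-contained
sketch of that rule, assuming the AltDiffusion state-dict layout shown in the diff; the helper
name guess_alt_diffusion_config and the checkpoint filename are illustrative assumptions, not
part of the webui codebase.

import torch


def guess_alt_diffusion_config(state_dict):
    """Pick an inference config for an AltDiffusion (XLM-R based) checkpoint."""
    # Both AltDiffusion variants ship an XLM-R text encoder, so this key marks them.
    if state_dict.get('cond_stage_model.roberta.embeddings.word_embeddings.weight') is None:
        return None  # not an AltDiffusion checkpoint

    # m18 projects the 1024-dim XLM-R Large hidden state to a 1024-dim cross-attention
    # context (project_dim in modules/xlmr_m18.py); the original AltDiffusion encoder
    # projects to 768, so the output width of the projection layer tells the two apart.
    if state_dict['cond_stage_model.transformation.weight'].size(0) == 1024:
        return 'configs/alt-diffusion-m18-inference.yaml'
    return 'configs/alt-diffusion-inference.yaml'


if __name__ == '__main__':
    # Hypothetical checkpoint path, for illustration only.
    checkpoint = torch.load('AltDiffusion-m18.ckpt', map_location='cpu')
    print(guess_alt_diffusion_config(checkpoint.get('state_dict', checkpoint)))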