morpheushoc committed on
Commit 0248af2 · verified · 1 Parent(s): dc963f2

Upload modeling_base.py with huggingface_hub

Files changed (1)
  1. modeling_base.py +189 -0
modeling_base.py ADDED
@@ -0,0 +1,189 @@
import io
import os
import warnings
import logging
import torch
import torch.utils.checkpoint
from torch import nn
from torch.nn import MSELoss

from torch.cuda.amp import autocast

from modeling_internvideo2_vit import pretrain_internvideo2_giant_patch14_224_clean
from modeling_qformer import build_qformer
from model_config import VideoChat2Config

logger = logging.getLogger(__name__)

from transformers import LlamaTokenizer, AutoTokenizer, AutoModel, AutoModelForCausalLM, AutoProcessor
from transformers import AutoConfig, PreTrainedModel

try:
    token = os.environ['HF_TOKEN']
except KeyError:
    warnings.warn("The HF_TOKEN was not found in the system variables. Please ensure that it is filled out correctly and that you have requested access to the model. If you haven't applied, please visit https://huggingface.co/mistralai/Mistral-7B-Instruct-v0.3 to request access.")
    token = None


def disabled_train(self, mode=True):
    """Overwrite model.train with this function to make sure train/eval mode
    does not change anymore."""
    return self


def freeze_module(module):
    for _, param in module.named_parameters():
        param.requires_grad = False
    module = module.eval()
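    # Overriding train() on the instance means later .train(mode) calls can no
    # longer switch the frozen module out of eval mode.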
    module.train = disabled_train
    return module


class BaseMLLM(PreTrainedModel):
    config_class = VideoChat2Config

    def __init__(self, config):
        self.model_config = config.model_config
        super().__init__(config)
        self.build_vision_encoder()
        if 'llm' in self.model_config:
            self.build_llm()
        if 'bridge' in self.model_config:
            self.build_bridge()
        if 'loss' in self.model_config:
            self.build_loss()
        # NOTE: keep this logging after the LLM has been frozen
        for n, p in self.named_parameters():
            if p.requires_grad:
                logger.info(f'{n} requires_grad')

    def build_vision_encoder(self):
        # Load the pretrained InternVideo2-1B encoder here; simplified since it receives no args.
        # Note that the InternVideo pretrained weights are not loaded at this point.
        if 'internvideo2' in self.model_config.vision_encoder.name.lower():
            encoder_name = self.model_config.vision_encoder.name
            logger.info(f"Build vision_encoder: {encoder_name}")
            if encoder_name == 'internvideo2-1B':
                self.vision_encoder = pretrain_internvideo2_giant_patch14_224_clean(self.model_config)
            else:
                raise ValueError(f"Not implemented: {encoder_name}")
        else:
            raise NotImplementedError(self.model_config.vision_encoder.name)

        if self.model_config.vision_encoder.vit_add_ln:
            self.vision_layernorm = nn.LayerNorm(self.model_config.vision_encoder.encoder_embed_dim, eps=1e-12)
        else:
            self.vision_layernorm = nn.Identity()

        self.freeze_vision_encoder = self.model_config.get("freeze_vision_encoder", False)

        if self.freeze_vision_encoder:
            logger.info("freeze vision encoder")
            freeze_module(self.vision_encoder)
            freeze_module(self.vision_layernorm)

    def build_bridge(self):
        # Q-Former to LM: 768 -> LM hidden size (e.g. 6656); NOTE: 768 is the
        # Q-Former dim, not the 1792-dim ViT features
        self.project_up = nn.Linear(768, self.lm.config.hidden_size)  # TODO: is a bias needed here?
        # LM to Q-Former: LM hidden size -> 768
        self.project_down = nn.Linear(self.lm.config.hidden_size, 768)

        if 'qformer' in self.model_config.bridge.name.lower():
            from transformers import BertTokenizer
            self.qformer_tokenizer = BertTokenizer.from_pretrained("bert-base-uncased", truncation_side="left")
            self.qformer_tokenizer.add_special_tokens({"bos_token": "[DEC]"})
            self.qformer_tokenizer.padding_side = "left"
            if self.model_config.bridge.name == 'qformer':
                self.qformer, self.query_tokens = build_qformer(
                    self.model_config.bridge.num_query_token, self.model_config.vision_encoder.encoder_embed_dim,
                    qformer_hidden_dropout_prob=self.model_config.bridge.qformer_hidden_dropout_prob,
                    qformer_attention_probs_dropout_prob=self.model_config.bridge.qformer_attention_probs_dropout_prob,
                    qformer_drop_path_rate=self.model_config.bridge.qformer_drop_path_rate,
                )
            self.qformer.resize_token_embeddings(len(self.qformer_tokenizer))
            self.qformer.cls = None
            self.extra_num_query_token = self.model_config.bridge.extra_num_query_token
            if self.model_config.bridge.extra_num_query_token > 0:
                logger.info(f"Add extra {self.model_config.bridge.extra_num_query_token} tokens in QFormer")
                self.extra_query_tokens = nn.Parameter(
                    torch.zeros(1, self.model_config.bridge.extra_num_query_token, self.query_tokens.shape[-1])
                )

        self.freeze_bridge = self.model_config.get("freeze_bridge", False)
        if self.freeze_bridge:
            logger.info("freeze bridge")
            freeze_module(self.qformer)
            self.query_tokens.requires_grad = False

    def build_llm(self):
        self.lm_name = self.model_config.llm.name
        if self.model_config.llm.name == 'mistral_7b':
            from transformers import AutoModelForCausalLM
            config = AutoConfig.from_pretrained(
                self.model_config.llm.pretrained_llm_path,
                torch_dtype=torch.bfloat16,
                token=token,
                # attn_implementation="flash_attention_2",
            )
            self.lm = AutoModelForCausalLM.from_config(config)
        elif self.model_config.llm.name == 'internlm_20b':
            from transformers import AutoModelForCausalLM
            self.lm = AutoModelForCausalLM.from_pretrained(
                self.model_config.llm.pretrained_llm_path,
                torch_dtype=torch.bfloat16,
                trust_remote_code=True,
            )
            self.lm.gradient_checkpointing = True
            self.lm._set_gradient_checkpointing()
        elif self.model_config.llm.name == 'internlm2_5_7b':
            from transformers import AutoModelForCausalLM
            self.lm = AutoModelForCausalLM.from_pretrained(
                self.model_config.llm.pretrained_llm_path,
                torch_dtype=torch.bfloat16,
                trust_remote_code=True,
                local_files_only=True,
            )
        else:
            raise NotImplementedError(self.model_config.llm.name)

        self.freeze_llm = self.model_config.get("freeze_llm", True)
        logger.info(f'freeze_llm: {self.freeze_llm}')
        if self.freeze_llm:
            logger.info("freeze llm")
            freeze_module(self.lm)

        if self.model_config.llm.use_lora:
            self.use_lora = True
            from peft import get_peft_model, LoraConfig, TaskType
            logger.info("Use lora")
            if self.model_config.llm.name == 'internlm_20b':
                peft_config = LoraConfig(
                    task_type=TaskType.CAUSAL_LM, inference_mode=False,
                    r=self.model_config.llm.lora_r, lora_alpha=self.model_config.llm.lora_alpha, lora_dropout=self.model_config.llm.lora_dropout,
                    target_modules=['wqkv', 'wo', 'w1', 'w2', 'w3', 'output'],
                )
            else:
                peft_config = LoraConfig(
                    task_type=TaskType.CAUSAL_LM, inference_mode=False,
                    r=self.model_config.llm.lora_r, lora_alpha=self.model_config.llm.lora_alpha, lora_dropout=self.model_config.llm.lora_dropout,
                    target_modules=["q_proj", "k_proj", "v_proj", "o_proj",
                                    "gate_proj", "up_proj", "down_proj", "lm_head"],
                )

            self.lm = get_peft_model(self.lm, peft_config)
            self.lm.enable_input_require_grads()
            self.lm.print_trainable_parameters()
        else:
            self.use_lora = False

    def build_loss(self):
        self.use_vision_regression_loss = self.model_config.loss.get("use_vision_regression_loss", False)
        if self.use_vision_regression_loss:
            self.image_loss_fct = MSELoss()

    @property
    def dtype(self):
        return self.lm.dtype

    @property
    def device(self):
        return self.lm.device
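
For context, a minimal loading sketch. It assumes this repository wires VideoChat2Config and a concrete BaseMLLM subclass into the config's auto_map (the repo id below is a placeholder), and that HF_TOKEN is set before the remote code is imported, since modeling_base.py reads it at import time:

    import os
    from transformers import AutoModel

    os.environ["HF_TOKEN"] = "<your token>"  # read at import time by modeling_base.py

    # "your-org/your-videochat2-repo" is a placeholder repo id
    model = AutoModel.from_pretrained(
        "your-org/your-videochat2-repo",
        trust_remote_code=True,  # required so the repo's modeling_*.py files are executed
    )
    print(model.dtype, model.device)  # both properties proxy to the underlying LM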