lyclyc52 commited on
Commit
48ca1e2
·
1 Parent(s): ba9be79

Update: add original llava code

Browse files
This view is limited to 50 files because it contains too many changes.   See raw diff
Files changed (50) hide show
  1. llava/__init__.py +1 -0
  2. llava/constants.py +12 -0
  3. llava/conversation.py +554 -0
  4. llava/eval/evaluate_interleave.py +339 -0
  5. llava/eval/model_vqa.py +240 -0
  6. llava/mm_utils.py +381 -0
  7. llava/model/__init__.py +23 -0
  8. llava/model/apply_delta.py +47 -0
  9. llava/model/builder.py +250 -0
  10. llava/model/consolidate.py +30 -0
  11. llava/model/language_model/llava_gemma.py +122 -0
  12. llava/model/language_model/llava_llama.py +131 -0
  13. llava/model/language_model/llava_mistral.py +127 -0
  14. llava/model/language_model/llava_mixtral.py +122 -0
  15. llava/model/language_model/llava_mpt.py +105 -0
  16. llava/model/language_model/llava_qwen.py +128 -0
  17. llava/model/language_model/llava_qwen_moe.py +128 -0
  18. llava/model/llava_arch.py +390 -0
  19. llava/model/make_delta.py +52 -0
  20. llava/model/multimodal_encoder/builder.py +14 -0
  21. llava/model/multimodal_encoder/clip_encoder.py +114 -0
  22. llava/model/multimodal_encoder/siglip_encoder.py +620 -0
  23. llava/model/multimodal_projector/builder.py +65 -0
  24. llava/model/multimodal_projector/pooler_projector.py +33 -0
  25. llava/model/multimodal_resampler/builder.py +34 -0
  26. llava/model/multimodal_resampler/masked_drop.py +80 -0
  27. llava/model/multimodal_resampler/perceiver.py +155 -0
  28. llava/model/multimodal_resampler/qformer.py +1160 -0
  29. llava/model/multimodal_resampler/spatial_pool.py +45 -0
  30. llava/model/utils.py +20 -0
  31. llava/utils.py +134 -0
  32. llavavid/__init__.py +1 -0
  33. llavavid/constants.py +13 -0
  34. llavavid/conversation.py +406 -0
  35. llavavid/mm_utils.py +246 -0
  36. llavavid/model/__init__.py +6 -0
  37. llavavid/model/apply_delta.py +48 -0
  38. llavavid/model/builder.py +172 -0
  39. llavavid/model/consolidate.py +29 -0
  40. llavavid/model/language_model/llava_llama.py +137 -0
  41. llavavid/model/language_model/llava_mistral.py +158 -0
  42. llavavid/model/language_model/llava_mpt.py +97 -0
  43. llavavid/model/llava_arch.py +481 -0
  44. llavavid/model/make_delta.py +52 -0
  45. llavavid/model/multimodal_encoder/builder.py +11 -0
  46. llavavid/model/multimodal_encoder/clip_encoder.py +110 -0
  47. llavavid/model/multimodal_projector/builder.py +51 -0
  48. llavavid/model/multimodal_resampler/builder.py +25 -0
  49. llavavid/model/multimodal_resampler/spatial_pool.py +47 -0
  50. llavavid/model/utils.py +20 -0
llava/__init__.py ADDED
@@ -0,0 +1 @@
 
 
1
+ from .model import LlavaLlamaForCausalLM
llava/constants.py ADDED
@@ -0,0 +1,12 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ CONTROLLER_HEART_BEAT_EXPIRATION = 30
2
+ WORKER_HEART_BEAT_INTERVAL = 15
3
+
4
+ LOGDIR = "."
5
+
6
+ # Model Constants
7
+ IGNORE_INDEX = -100
8
+ IMAGE_TOKEN_INDEX = -200
9
+ DEFAULT_IMAGE_TOKEN = "<image>"
10
+ DEFAULT_IMAGE_PATCH_TOKEN = "<im_patch>"
11
+ DEFAULT_IM_START_TOKEN = "<im_start>"
12
+ DEFAULT_IM_END_TOKEN = "<im_end>"
llava/conversation.py ADDED
@@ -0,0 +1,554 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import dataclasses
2
+ from enum import auto, Enum
3
+ from typing import List, Any, Dict, Union, Tuple
4
+ import re
5
+ import base64
6
+ from io import BytesIO
7
+ from PIL import Image
8
+ from transformers import AutoTokenizer
9
+
10
+
11
+ class SeparatorStyle(Enum):
12
+ """Different separator style."""
13
+
14
+ SINGLE = auto()
15
+ TWO = auto()
16
+ MPT = auto()
17
+ PLAIN = auto()
18
+ CHATML = auto()
19
+ LLAMA_2 = auto()
20
+ LLAMA_3 = auto()
21
+ QWEN = auto()
22
+ GEMMA = auto()
23
+
24
+
25
+ @dataclasses.dataclass
26
+ class Conversation:
27
+ """A class that keeps all conversation history."""
28
+
29
+ system: str
30
+ roles: List[str]
31
+ messages: List[List[str]]
32
+ offset: int
33
+ sep_style: SeparatorStyle = SeparatorStyle.SINGLE
34
+ sep: str = "###"
35
+ sep2: str = None
36
+ version: str = "Unknown"
37
+
38
+ tokenizer_id: str = ""
39
+ tokenizer: Any = None
40
+ # Stop criteria (the default one is EOS token)
41
+ stop_str: Union[str, List[str]] = None
42
+ # Stops generation if meeting any token in this list
43
+ stop_token_ids: List[int] = None
44
+
45
+ skip_next: bool = False
46
+
47
+ def get_prompt(self):
48
+ messages = self.messages
49
+ if len(messages) > 0 and type(messages[0][1]) is tuple:
50
+ messages = self.messages.copy()
51
+ init_role, init_msg = messages[0].copy()
52
+ init_msg = init_msg[0]
53
+ if "mmtag" in self.version:
54
+ init_msg = init_msg.replace("<image>", "").strip()
55
+ messages[0] = (init_role, init_msg)
56
+ messages.insert(0, (self.roles[0], "<Image><image></Image>"))
57
+ messages.insert(1, (self.roles[1], "Received."))
58
+ elif not init_msg.startswith("<image>"):
59
+ init_msg = init_msg.replace("<image>", "").strip()
60
+ messages[0] = (init_role, "<image>\n" + init_msg)
61
+ else:
62
+ messages[0] = (init_role, init_msg)
63
+
64
+ if self.sep_style == SeparatorStyle.SINGLE:
65
+ ret = self.system + self.sep
66
+ for role, message in messages:
67
+ if message:
68
+ if type(message) is tuple:
69
+ message, _, _ = message
70
+ ret += role + ": " + message + self.sep
71
+ else:
72
+ ret += role + ":"
73
+
74
+ elif self.sep_style == SeparatorStyle.TWO:
75
+ seps = [self.sep, self.sep2]
76
+ ret = self.system + seps[0]
77
+ for i, (role, message) in enumerate(messages):
78
+ if message:
79
+ if type(message) is tuple:
80
+ message, _, _ = message
81
+ ret += role + ": " + message + seps[i % 2]
82
+ else:
83
+ ret += role + ":"
84
+
85
+ elif self.sep_style == SeparatorStyle.CHATML:
86
+ ret = "" if self.system == "" else self.system + self.sep + "\n"
87
+ for role, message in messages:
88
+ if message:
89
+ if type(message) is tuple:
90
+ message, images = message
91
+ message = "<image>" * len(images) + message
92
+ ret += role + "\n" + message + self.sep + "\n"
93
+ else:
94
+ ret += role + "\n"
95
+ return ret
96
+
97
+ elif self.sep_style == SeparatorStyle.LLAMA_3:
98
+ chat_template_messages = [{"role": "system", "content": self.system}]
99
+ for role, message in messages:
100
+ if message:
101
+ if type(message) is tuple:
102
+ message, images = message
103
+ message = "<image>" * len(images) + message
104
+ chat_template_messages.append({"role": role, "content": message})
105
+
106
+ # print(chat_template_messages)
107
+ return self.tokenizer.apply_chat_template(chat_template_messages, tokenize=False, add_generation_prompt=True)
108
+ # ret = "" if self.system == "" else self.system + self.sep + "\n"
109
+ # for role, message in messages:
110
+ # if message:
111
+ # if type(message) is tuple:
112
+ # message, images = message
113
+ # message = "<image>" * len(images) + message
114
+ # ret += role + "\n" + message + self.sep + "\n"
115
+ # else:
116
+ # ret += role + "\n"
117
+ # return ret
118
+
119
+ elif self.sep_style == SeparatorStyle.MPT:
120
+ ret = self.system + self.sep
121
+ for role, message in messages:
122
+ if message:
123
+ if type(message) is tuple:
124
+ message, _, _ = message
125
+ ret += role + message + self.sep
126
+ else:
127
+ ret += role
128
+
129
+ elif self.sep_style == SeparatorStyle.GEMMA:
130
+ ret = ""
131
+ for i, (role, message) in enumerate(messages):
132
+ assert role == self.roles[i % 2], "Conversation should alternate user/assistant/user/assistant/..."
133
+ if message:
134
+ if type(message) is tuple:
135
+ message, _, _ = message
136
+ ret += role + message + self.sep
137
+ else:
138
+ ret += role
139
+
140
+ elif self.sep_style == SeparatorStyle.LLAMA_2:
141
+ wrap_sys = lambda msg: f"<<SYS>>\n{msg}\n<</SYS>>\n\n" if len(msg) > 0 else msg
142
+ wrap_inst = lambda msg: f"[INST] {msg} [/INST]"
143
+ ret = ""
144
+
145
+ for i, (role, message) in enumerate(messages):
146
+ if i == 0:
147
+ assert message, "first message should not be none"
148
+ assert role == self.roles[0], "first message should come from user"
149
+ if message:
150
+ if type(message) is tuple:
151
+ message, _, _ = message
152
+ if i == 0:
153
+ message = wrap_sys(self.system) + message
154
+ if i % 2 == 0:
155
+ message = wrap_inst(message)
156
+ ret += self.sep + message
157
+ else:
158
+ ret += " " + message + " " + self.sep2
159
+ else:
160
+ ret += ""
161
+ ret = ret.lstrip(self.sep)
162
+
163
+ elif self.sep_style == SeparatorStyle.PLAIN:
164
+ seps = [self.sep, self.sep2]
165
+ ret = self.system
166
+ for i, (role, message) in enumerate(messages):
167
+ if message:
168
+ if type(message) is tuple:
169
+ message, _, _ = message
170
+ ret += message + seps[i % 2]
171
+ else:
172
+ ret += ""
173
+ else:
174
+ raise ValueError(f"Invalid style: {self.sep_style}")
175
+
176
+ return ret
177
+
178
+ def append_message(self, role, message):
179
+ self.messages.append([role, message])
180
+
181
+ def process_image(self, image, image_process_mode, return_pil=False, image_format="PNG"):
182
+ if image_process_mode == "Pad":
183
+
184
+ def expand2square(pil_img, background_color=(122, 116, 104)):
185
+ width, height = pil_img.size
186
+ if width == height:
187
+ return pil_img
188
+ elif width > height:
189
+ result = Image.new(pil_img.mode, (width, width), background_color)
190
+ result.paste(pil_img, (0, (width - height) // 2))
191
+ return result
192
+ else:
193
+ result = Image.new(pil_img.mode, (height, height), background_color)
194
+ result.paste(pil_img, ((height - width) // 2, 0))
195
+ return result
196
+
197
+ image = expand2square(image)
198
+ elif image_process_mode in ["Default", "Crop"]:
199
+ pass
200
+ elif image_process_mode == "Resize":
201
+ image = image.resize((336, 336))
202
+ else:
203
+ raise ValueError(f"Invalid image_process_mode: {image_process_mode}")
204
+
205
+ if type(image) is not Image.Image:
206
+ image = Image.open(image).convert("RGB")
207
+
208
+ max_hw, min_hw = max(image.size), min(image.size)
209
+ aspect_ratio = max_hw / min_hw
210
+ max_len, min_len = 1008, 672
211
+ shortest_edge = int(min(max_len / aspect_ratio, min_len, min_hw))
212
+ longest_edge = int(shortest_edge * aspect_ratio)
213
+ W, H = image.size
214
+ if H > W:
215
+ H, W = longest_edge, shortest_edge
216
+ else:
217
+ H, W = shortest_edge, longest_edge
218
+ image = image.resize((W, H))
219
+ if return_pil:
220
+ return image
221
+ else:
222
+ buffered = BytesIO()
223
+ image.save(buffered, format=image_format)
224
+ img_b64_str = base64.b64encode(buffered.getvalue()).decode()
225
+ return img_b64_str
226
+
227
+ def get_images(self, return_pil=False, return_path=False):
228
+ images = []
229
+ for i, (role, msg) in enumerate(self.messages[self.offset :]):
230
+ if i % 2 == 0:
231
+ if type(msg) is tuple:
232
+ msg, image, image_process_mode = msg
233
+ if type(image) != list:
234
+ image = [image]
235
+ for img in image:
236
+ if not return_path:
237
+ img = self.process_image(img, image_process_mode, return_pil=return_pil)
238
+ else:
239
+ images.append(img)
240
+ return images
241
+
242
+ def to_gradio_chatbot(self):
243
+ ret = []
244
+ for i, (role, msg) in enumerate(self.messages[self.offset :]):
245
+ if i % 2 == 0:
246
+ if type(msg) is tuple:
247
+ msg, image, image_process_mode = msg
248
+ if type(image) != list:
249
+ image = [image]
250
+ if len(image) == 1:
251
+ msg = "<image>\n" + msg.replace("<image>", "").strip()
252
+ else:
253
+ msg = re.sub(r"(<image>)\n(?=<image>)", r"\1 ", msg)
254
+ for img in image:
255
+ img_b64_str = self.process_image(img, "Default", return_pil=False, image_format="JPEG")
256
+ img_str = f'<img src="data:image/jpeg;base64,{img_b64_str}"/>'
257
+ msg = msg.replace("<image>", img_str, 1).strip()
258
+ if len(msg) > 0:
259
+ ret.append([msg, None])
260
+ else:
261
+ ret.append([msg, None])
262
+ else:
263
+ ret[-1][-1] = msg
264
+ return ret
265
+
266
+ def copy(self):
267
+ return Conversation(system=self.system, roles=self.roles, messages=[[x, y] for x, y in self.messages], offset=self.offset, sep_style=self.sep_style, sep=self.sep, sep2=self.sep2, version=self.version)
268
+
269
+ def dict(self):
270
+ if len(self.get_images()) > 0:
271
+ return {
272
+ "system": self.system,
273
+ "roles": self.roles,
274
+ "messages": [[x, y[0] if type(y) is tuple else y] for x, y in self.messages],
275
+ "offset": self.offset,
276
+ "sep": self.sep,
277
+ "sep2": self.sep2,
278
+ }
279
+ return {
280
+ "system": self.system,
281
+ "roles": self.roles,
282
+ "messages": self.messages,
283
+ "offset": self.offset,
284
+ "sep": self.sep,
285
+ "sep2": self.sep2,
286
+ }
287
+
288
+
289
+ conv_vicuna_v0 = Conversation(
290
+ system="A chat between a curious human and an artificial intelligence assistant. " "The assistant gives helpful, detailed, and polite answers to the human's questions.",
291
+ roles=("Human", "Assistant"),
292
+ messages=[
293
+ ["Human", "What are the key differences between renewable and non-renewable energy sources?"],
294
+ [
295
+ "Assistant",
296
+ "Renewable energy sources are those that can be replenished naturally in a relatively "
297
+ "short amount of time, such as solar, wind, hydro, geothermal, and biomass. "
298
+ "Non-renewable energy sources, on the other hand, are finite and will eventually be "
299
+ "depleted, such as coal, oil, and natural gas. Here are some key differences between "
300
+ "renewable and non-renewable energy sources:\n"
301
+ "1. Availability: Renewable energy sources are virtually inexhaustible, while non-renewable "
302
+ "energy sources are finite and will eventually run out.\n"
303
+ "2. Environmental impact: Renewable energy sources have a much lower environmental impact "
304
+ "than non-renewable sources, which can lead to air and water pollution, greenhouse gas emissions, "
305
+ "and other negative effects.\n"
306
+ "3. Cost: Renewable energy sources can be more expensive to initially set up, but they typically "
307
+ "have lower operational costs than non-renewable sources.\n"
308
+ "4. Reliability: Renewable energy sources are often more reliable and can be used in more remote "
309
+ "locations than non-renewable sources.\n"
310
+ "5. Flexibility: Renewable energy sources are often more flexible and can be adapted to different "
311
+ "situations and needs, while non-renewable sources are more rigid and inflexible.\n"
312
+ "6. Sustainability: Renewable energy sources are more sustainable over the long term, while "
313
+ "non-renewable sources are not, and their depletion can lead to economic and social instability.\n",
314
+ ],
315
+ ],
316
+ offset=2,
317
+ sep_style=SeparatorStyle.SINGLE,
318
+ sep="###",
319
+ )
320
+
321
+ conv_vicuna_v1 = Conversation(
322
+ system="A chat between a curious user and an artificial intelligence assistant. " "The assistant gives helpful, detailed, and polite answers to the user's questions.",
323
+ roles=("USER", "ASSISTANT"),
324
+ version="v1",
325
+ messages=[],
326
+ offset=0,
327
+ sep_style=SeparatorStyle.TWO,
328
+ sep=" ",
329
+ sep2="</s>",
330
+ )
331
+
332
+ conv_llama_2 = Conversation(
333
+ system="""You are a helpful, respectful and honest assistant. Always answer as helpfully as possible, while being safe. Your answers should not include any harmful, unethical, racist, sexist, toxic, dangerous, or illegal content. Please ensure that your responses are socially unbiased and positive in nature.
334
+
335
+ If a question does not make any sense, or is not factually coherent, explain why instead of answering something not correct. If you don't know the answer to a question, please don't share false information.""",
336
+ roles=("USER", "ASSISTANT"),
337
+ version="llama_v2",
338
+ messages=[],
339
+ offset=0,
340
+ sep_style=SeparatorStyle.LLAMA_2,
341
+ sep="<s>",
342
+ sep2="</s>",
343
+ )
344
+
345
+ conv_llava_llama_2 = Conversation(
346
+ system="You are a helpful language and vision assistant. " "You are able to understand the visual content that the user provides, " "and assist the user with a variety of tasks using natural language.",
347
+ roles=("USER", "ASSISTANT"),
348
+ version="llama_v2",
349
+ messages=[],
350
+ offset=0,
351
+ sep_style=SeparatorStyle.LLAMA_2,
352
+ sep="<s>",
353
+ sep2="</s>",
354
+ )
355
+
356
+ conv_llava_llama_3 = Conversation(
357
+ system="You are a helpful language and vision assistant. " "You are able to understand the visual content that the user provides, " "and assist the user with a variety of tasks using natural language.",
358
+ roles=("user", "assistant"),
359
+ version="llama_v3",
360
+ messages=[],
361
+ offset=0,
362
+ sep="<|eot_id|>",
363
+ sep_style=SeparatorStyle.LLAMA_3,
364
+ tokenizer_id="meta-llama/Meta-Llama-3-8B-Instruct",
365
+ tokenizer=AutoTokenizer.from_pretrained("meta-llama/Meta-Llama-3-8B-Instruct"),
366
+ stop_token_ids=[128009],
367
+ )
368
+
369
+ conv_mistral_instruct = Conversation(
370
+ system="",
371
+ roles=("USER", "ASSISTANT"),
372
+ version="llama_v2",
373
+ messages=[],
374
+ offset=0,
375
+ sep_style=SeparatorStyle.LLAMA_2,
376
+ sep="",
377
+ sep2="</s>",
378
+ )
379
+
380
+ conv_llava_llama_2_simple = Conversation(
381
+ system="Answer the questions about the visual content that the user provides.",
382
+ roles=("USER", "ASSISTANT"),
383
+ version="llama_v2",
384
+ messages=[],
385
+ offset=0,
386
+ sep_style=SeparatorStyle.LLAMA_2,
387
+ sep="<s>",
388
+ sep2="</s>",
389
+ )
390
+
391
+ conv_llava_llama_2_mmtag = Conversation(
392
+ system="Answer the questions about the visual content that the user provides." "The visual content will be provided with the following format: <Image>visual content</Image>.",
393
+ roles=("USER", "ASSISTANT"),
394
+ version="llama_v2_mmtag",
395
+ messages=[],
396
+ offset=0,
397
+ sep_style=SeparatorStyle.LLAMA_2,
398
+ sep="<s>",
399
+ sep2="</s>",
400
+ )
401
+
402
+ conv_mpt = Conversation(
403
+ system="""<|im_start|>system
404
+ A conversation between a user and an LLM-based AI assistant. The assistant gives helpful and honest answers.""",
405
+ roles=("<|im_start|>user\n", "<|im_start|>assistant\n"),
406
+ version="mpt",
407
+ messages=[],
408
+ offset=0,
409
+ sep_style=SeparatorStyle.MPT,
410
+ sep="<|im_end|>",
411
+ )
412
+
413
+ conv_qwen = Conversation(
414
+ system="""<|im_start|>system
415
+ You are a helpful assistant.""",
416
+ roles=("<|im_start|>user", "<|im_start|>assistant"),
417
+ version="qwen",
418
+ messages=[],
419
+ offset=0,
420
+ sep_style=SeparatorStyle.CHATML,
421
+ sep="<|im_end|>",
422
+ )
423
+
424
+ conv_gemma_instruct = Conversation(system="", roles=("<start_of_turn>user\n", "<start_of_turn>model\n"), version="gemma", messages=[], offset=0, sep_style=SeparatorStyle.GEMMA, sep="<end_of_turn>\n")
425
+
426
+ conv_llava_plain = Conversation(
427
+ system="",
428
+ roles=("", ""),
429
+ messages=[],
430
+ offset=0,
431
+ sep_style=SeparatorStyle.PLAIN,
432
+ sep="\n",
433
+ )
434
+
435
+ conv_llava_v0 = Conversation(
436
+ system="A chat between a curious human and an artificial intelligence assistant. " "The assistant gives helpful, detailed, and polite answers to the human's questions.",
437
+ roles=("Human", "Assistant"),
438
+ messages=[],
439
+ offset=0,
440
+ sep_style=SeparatorStyle.SINGLE,
441
+ sep="###",
442
+ )
443
+
444
+ conv_llava_v0_mmtag = Conversation(
445
+ system="A chat between a curious user and an artificial intelligence assistant. "
446
+ "The assistant is able to understand the visual content that the user provides, and assist the user with a variety of tasks using natural language."
447
+ "The visual content will be provided with the following format: <Image>visual content</Image>.",
448
+ roles=("Human", "Assistant"),
449
+ messages=[],
450
+ offset=0,
451
+ sep_style=SeparatorStyle.SINGLE,
452
+ sep="###",
453
+ version="v0_mmtag",
454
+ )
455
+
456
+ conv_llava_v1 = Conversation(
457
+ system="A chat between a curious human and an artificial intelligence assistant. " "The assistant gives helpful, detailed, and polite answers to the human's questions.",
458
+ roles=("USER", "ASSISTANT"),
459
+ version="v1",
460
+ messages=[],
461
+ offset=0,
462
+ sep_style=SeparatorStyle.TWO,
463
+ sep=" ",
464
+ sep2="</s>",
465
+ )
466
+
467
+ conv_llava_v1_mmtag = Conversation(
468
+ system="A chat between a curious user and an artificial intelligence assistant. "
469
+ "The assistant is able to understand the visual content that the user provides, and assist the user with a variety of tasks using natural language."
470
+ "The visual content will be provided with the following format: <Image>visual content</Image>.",
471
+ roles=("USER", "ASSISTANT"),
472
+ messages=[],
473
+ offset=0,
474
+ sep_style=SeparatorStyle.TWO,
475
+ sep=" ",
476
+ sep2="</s>",
477
+ version="v1_mmtag",
478
+ )
479
+
480
+ conv_mistral_orca = Conversation(
481
+ system="""<|im_start|>system
482
+ You are MistralOrca, a large language model trained by Alignment Lab AI. Write out your reasoning step-by-step to be sure you get the right answers!""",
483
+ roles=("<|im_start|>user\n", "<|im_start|>assistant\n"),
484
+ version="mpt",
485
+ messages=[],
486
+ offset=0,
487
+ sep_style=SeparatorStyle.MPT,
488
+ sep="<|im_end|>",
489
+ )
490
+
491
+ conv_mistral_zephyr = Conversation(
492
+ system="""<|system|>
493
+ You are a helpful AI assistant.""",
494
+ roles=("<|user|>\n", "<|assistant|>\n"),
495
+ version="mpt",
496
+ messages=[],
497
+ offset=0,
498
+ sep_style=SeparatorStyle.MPT,
499
+ sep="</s>",
500
+ )
501
+
502
+ conv_mistral_direct = Conversation(
503
+ system="""<|im_start|>system
504
+ Answer the questions.""",
505
+ roles=("<|im_start|>user\n", "<|im_start|>assistant\n"),
506
+ version="mpt",
507
+ messages=[],
508
+ offset=0,
509
+ sep_style=SeparatorStyle.MPT,
510
+ sep="<|im_end|>",
511
+ )
512
+
513
+ conv_chatml_direct = Conversation(
514
+ system="""<|im_start|>system
515
+ Answer the questions.""",
516
+ roles=("<|im_start|>user\n", "<|im_start|>assistant\n"),
517
+ version="mpt",
518
+ messages=[],
519
+ offset=0,
520
+ sep_style=SeparatorStyle.MPT,
521
+ sep="<|im_end|>",
522
+ )
523
+
524
+ default_conversation = conv_vicuna_v0
525
+ conv_templates = {
526
+ "default": conv_vicuna_v0,
527
+ "v0": conv_vicuna_v0,
528
+ "v1": conv_vicuna_v1,
529
+ "vicuna_v1": conv_vicuna_v1,
530
+ "llama_2": conv_llama_2,
531
+ "mistral_instruct": conv_mistral_instruct,
532
+ "mistral_orca": conv_mistral_orca,
533
+ "mistral_zephyr": conv_mistral_zephyr,
534
+ "mistral_direct": conv_mistral_direct,
535
+ "plain": conv_llava_plain,
536
+ "v0_plain": conv_llava_plain,
537
+ "chatml_direct": conv_chatml_direct,
538
+ "llava_v0": conv_llava_v0,
539
+ "llava_v0_mmtag": conv_llava_v0_mmtag,
540
+ "llava_v1": conv_llava_v1,
541
+ "llava_v1_mmtag": conv_llava_v1_mmtag,
542
+ "llava_llama_2": conv_llava_llama_2,
543
+ "llava_llama_3": conv_llava_llama_3,
544
+ "llava_llama_2_simple": conv_llava_llama_2_simple,
545
+ "llava_llama_2_mmtag": conv_llava_llama_2_mmtag,
546
+ "llava_mistral_instruct": conv_mistral_instruct,
547
+ "mpt": conv_mpt,
548
+ "qwen_1_5": conv_qwen,
549
+ "gemma_instruct": conv_gemma_instruct,
550
+ }
551
+
552
+
553
+ if __name__ == "__main__":
554
+ print(default_conversation.get_prompt())
llava/eval/evaluate_interleave.py ADDED
@@ -0,0 +1,339 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import re
2
+ from rouge import Rouge
3
+ import argparse
4
+ import os
5
+ import json
6
+ import numpy as np
7
+ from sklearn.feature_extraction.text import TfidfVectorizer
8
+ from sklearn.metrics.pairwise import cosine_similarity
9
+
10
+
11
+ spot_the_diff = ["Spot-the-Diff", "Birds-to-Words", "CLEVR-Change"]
12
+ image_edit_instruct = ["IEdit", "HQ-Edit", "MagicBrush"]
13
+ visual_story_telling = ["AESOP", "FlintstonesSV", "PororoSV", "VIST"]
14
+ visual_cloze = ["COMICS_Dialogue", "RecipeQA_VisualCloze"]
15
+ text_rich_vqa = ["WebQA", "TQA", "OCR-VQA", "DocVQA"]
16
+ multi_image_vqa = ["MIT-States_StateCoherence", "MIT-States_PropertyCoherence", "VISION", "RecipeQA_ImageCoherence"]
17
+
18
+ puzzle = ["RAVEN"]
19
+ nlrv2 = ["NLVR2_Mantis"]
20
+ qbench = ["QBench"]
21
+
22
+ class Eval:
23
+ def __init__(self):
24
+ self.periodStrip = re.compile("(?!<=\d)(\.)(?!\d)")
25
+ self.commaStrip = re.compile("(\d)(\,)(\d)")
26
+ self.punct = [
27
+ ";",
28
+ r"/",
29
+ "[",
30
+ "]",
31
+ '"',
32
+ "{",
33
+ "}",
34
+ "(",
35
+ ")",
36
+ "=",
37
+ "+",
38
+ "\\",
39
+ "_",
40
+ "-",
41
+ ">",
42
+ "<",
43
+ "@",
44
+ "`",
45
+ ",",
46
+ "?",
47
+ "!",
48
+ ]
49
+
50
+ def processPunctuation(self, inText):
51
+ outText = inText
52
+ for p in self.punct:
53
+ if (p + " " in inText or " " + p in inText) or (
54
+ re.search(self.commaStrip, inText) != None
55
+ ):
56
+ outText = outText.replace(p, "")
57
+ else:
58
+ outText = outText.replace(p, " ")
59
+ outText = self.periodStrip.sub("", outText, re.UNICODE)
60
+ return outText
61
+
62
+ def process(self, answer):
63
+ answer = answer.replace("\n", " ")
64
+ answer = answer.replace("\t", " ")
65
+ answer = answer.strip()
66
+ answer = self.processPunctuation(answer)
67
+ answer = answer.strip('\'')
68
+ answer = answer.strip('\"')
69
+ answer = answer.strip(')')
70
+ answer = answer.strip('(')
71
+ answer = answer.strip().lower()
72
+ return answer
73
+
74
+ def evaluate_rouge(self,preds):
75
+ rouge = Rouge()
76
+ acc = {'f': []}
77
+ eval_list = []
78
+ for i, res in enumerate(preds):
79
+ sample_id = res['sample_id']
80
+ # print(sample_id)
81
+ gt_ans = self.process(res["gt_response"])
82
+ pred_ans = self.process(res["pred_response"])
83
+ # assert gt_ans != ''
84
+
85
+ if gt_ans == '':
86
+ continue
87
+
88
+ if pred_ans == '':
89
+ s = 0
90
+ else:
91
+ if len(pred_ans) > 512:
92
+ pred_ans = pred_ans[0: 512]
93
+ s = rouge.get_scores(pred_ans, gt_ans)[0]['rouge-l']['f']
94
+ acc['f'].append(s)
95
+ eval_list.append({'id':str(sample_id),'score':str(round(s,3))})
96
+ results = {'Rouge-L f': np.mean(acc['f'])}
97
+ return results,eval_list
98
+
99
+
100
+ def judge_multi_choice(self,sample):
101
+ sample_id = sample['sample_id']
102
+ gt_ans = sample["gt_response"]
103
+ pred_ans = sample["pred_response"]
104
+
105
+ if ":" in pred_ans:
106
+ a_list = pred_ans.split(":")
107
+ a_list = [a.strip() for a in a_list ]
108
+ for a in a_list:
109
+ if len(a) == 1 and a[-1] in ["a", "b", "c", "d", "e", "f", "g", "h"]:
110
+ pred_ans = a
111
+
112
+ if pred_ans == gt_ans:
113
+ return 1
114
+ else:
115
+ return 0
116
+
117
+ def process_sample(self,sample):
118
+ sample["gt_response"] = self.process(sample["gt_response"])
119
+ sample["pred_response"] = self.process(sample["pred_response"])
120
+
121
+ def evaluate_multichoice(self, preditions):
122
+ correct = 0
123
+ eval_list = []
124
+ for i, sample in enumerate(preditions):
125
+ self.process_sample(sample)
126
+ score = self.judge_multi_choice(sample)
127
+ sample_id = sample['sample_id']
128
+ sample['result'] = score
129
+ eval_list.append({'id':str(sample_id),'score':str(score)})
130
+ correct+=score
131
+ return {'Accuracy':correct/len(preditions)},eval_list
132
+
133
+ def evaluate_multi_choice_image(self,preditions):
134
+ correct = 0
135
+ eval_list = []
136
+ for i,sample in enumerate(preditions):
137
+ gt_ans = self.process(sample["gt_response"])
138
+ pred_ans = self.process(sample["pred_response"])
139
+ sample_id = sample['sample_id']
140
+
141
+ if ":" in pred_ans:
142
+ a_list = pred_ans.split(":")
143
+ a_list = [a.strip() for a in a_list ]
144
+ for a in a_list:
145
+ if len(a) == 1 and a[-1] in ["a", "b", "c", "d", "e", "f", "g", "h"]:
146
+ pred_ans = a
147
+
148
+ if gt_ans == pred_ans:
149
+ score = 1
150
+ else:
151
+ score = 0
152
+ sample_id = sample['sample_id']
153
+ sample['result'] = score
154
+ eval_list.append({'id':str(sample_id),'score':str(score)})
155
+ correct+=score
156
+ return {'Accuracy':correct/len(preditions)},eval_list
157
+
158
+
159
+ if __name__ == "__main__":
160
+ parser = argparse.ArgumentParser()
161
+ parser.add_argument('--result-dir', type=str, required=True)
162
+
163
+ args = parser.parse_args()
164
+
165
+ result_file = os.path.join(args.result_dir, "result.jsonl")
166
+
167
+ if not os.path.exists(result_file):
168
+ print('No prediction file found')
169
+ exit(0)
170
+ with open(result_file, 'r') as f:
171
+ preds_all = [json.loads(line) for line in f]
172
+
173
+ preds_all_dict = dict()
174
+ for pred in preds_all:
175
+ if pred["dataset"] not in preds_all_dict:
176
+ preds_all_dict[pred["dataset"]] = list()
177
+ preds_all_dict[pred["dataset"]].append(pred)
178
+
179
+ image_choice_dataset_list = ["recipeqa-RecipeQA_VisualCloze", "RecipeQA_ImageCoherence", "COMICS_Panel"]
180
+ E = Eval()
181
+
182
+ eval_result_list = dict()
183
+ eval_result_list_detail = dict()
184
+
185
+ for dataset in preds_all_dict:
186
+
187
+ preds = preds_all_dict[dataset]
188
+ question_type = preds[0]["question_type"]
189
+
190
+ if question_type == 'open-ended':
191
+ eval_result, eval_list = E.evaluate_rouge(preds)
192
+
193
+ elif question_type == 'multi-choice' or dataset == 'nlrv2':
194
+ if dataset in image_choice_dataset_list:
195
+ eval_result, eval_list = E.evaluate_multi_choice_image(preds)
196
+ else:
197
+ eval_result, eval_list = E.evaluate_multichoice(preds)
198
+
199
+ else:
200
+ eval_result = 'Dataset not supported'
201
+ print('Dataset not supported')
202
+ exit(0)
203
+
204
+ print(dataset, end = ': ')
205
+ print(eval_result)
206
+
207
+ eval_result_list[dataset] = eval_result
208
+ eval_result_list_detail[dataset] = eval_list
209
+
210
+ os.makedirs(args.result_dir, exist_ok=True)
211
+ with open(os.path.join(args.result_dir, 'eval_dataset.json'), 'w') as f:
212
+ json.dump(eval_result_list, f, indent=4)
213
+
214
+ with open(os.path.join(args.result_dir,'eval_dataset_details.json'), 'w') as f:
215
+ json.dump(eval_result_list_detail, f, indent=4)
216
+
217
+
218
+ eval_cat_list = dict()
219
+ print()
220
+
221
+ # spot_the_diff
222
+ score = 0
223
+ count = 0
224
+ for dataset in eval_result_list:
225
+ if dataset in spot_the_diff:
226
+ count += 1
227
+ score += list(eval_result_list[dataset].values())[0]
228
+ if count > 0:
229
+ score /= count
230
+ eval_cat_list["spot_the_diff"] = score
231
+ print("spot_the_diff", end = ': ')
232
+ print('{:.2f}'.format(100 * score))
233
+
234
+ # image_edit_instruct
235
+ score = 0
236
+ count = 0
237
+ for dataset in eval_result_list:
238
+ if dataset in image_edit_instruct:
239
+ count += 1
240
+ score += list(eval_result_list[dataset].values())[0]
241
+ if count > 0:
242
+ score /= count
243
+ eval_cat_list["image_edit_instruct"] = score
244
+ print("image_edit_instruct", end = ': ')
245
+ print('{:.2f}'.format(100 * score))
246
+
247
+ # visual_story_telling
248
+ score = 0
249
+ count = 0
250
+ for dataset in eval_result_list:
251
+ if dataset in visual_story_telling:
252
+ count += 1
253
+ score += list(eval_result_list[dataset].values())[0]
254
+ if count > 0:
255
+ score /= count
256
+ eval_cat_list["visual_story_telling"] = score
257
+ print("visual_story_telling", end = ': ')
258
+ print('{:.2f}'.format(100 * score))
259
+
260
+ # visual_cloze
261
+ score = 0
262
+ count = 0
263
+ for dataset in eval_result_list:
264
+ if dataset in visual_cloze:
265
+ count += 1
266
+ score += list(eval_result_list[dataset].values())[0]
267
+ if count > 0:
268
+ score /= count
269
+ eval_cat_list["visual_cloze"] = score
270
+ print("visual_cloze", end = ': ')
271
+ print('{:.2f}'.format(100 * score))
272
+
273
+ # text_rich_vqa
274
+ score = 0
275
+ count = 0
276
+ for dataset in eval_result_list:
277
+ if dataset in text_rich_vqa:
278
+ count += 1
279
+ score += list(eval_result_list[dataset].values())[0]
280
+ if count > 0:
281
+ score /= count
282
+ eval_cat_list["text_rich_vqa"] = score
283
+ print("text_rich_vqa", end = ': ')
284
+ print('{:.2f}'.format(100 * score))
285
+
286
+ # multi_image_vqa
287
+ score = 0
288
+ count = 0
289
+ for dataset in eval_result_list:
290
+ if dataset in multi_image_vqa:
291
+ count += 1
292
+ score += list(eval_result_list[dataset].values())[0]
293
+ if count > 0:
294
+ score /= count
295
+ eval_cat_list["multi_image_vqa"] = score
296
+ print("multi_image_vqa", end = ': ')
297
+ print('{:.2f}'.format(100 * score))
298
+
299
+ # puzzle
300
+ score = 0
301
+ count = 0
302
+ for dataset in eval_result_list:
303
+ if dataset in puzzle:
304
+ count += 1
305
+ score += list(eval_result_list[dataset].values())[0]
306
+ if count > 0:
307
+ score /= count
308
+ eval_cat_list["puzzle"] = score
309
+ print("puzzle", end = ': ')
310
+ print('{:.2f}'.format(100 * score))
311
+
312
+ # nlrv2
313
+ score = 0
314
+ count = 0
315
+ for dataset in eval_result_list:
316
+ if dataset in nlrv2:
317
+ count += 1
318
+ score += list(eval_result_list[dataset].values())[0]
319
+ if count > 0:
320
+ score /= count
321
+ eval_cat_list["nlrv2"] = score
322
+ print("nlrv2", end = ': ')
323
+ print('{:.2f}'.format(100 * score))
324
+
325
+ # qbench
326
+ score = 0
327
+ count = 0
328
+ for dataset in eval_result_list:
329
+ if dataset in qbench:
330
+ count += 1
331
+ score += list(eval_result_list[dataset].values())[0]
332
+ if count > 0:
333
+ score /= count
334
+ eval_cat_list["qbench"] = score
335
+ print("qbench", end = ': ')
336
+ print('{:.2f}'.format(100 * score))
337
+
338
+ with open(os.path.join(args.result_dir,'eval_cat.json'), 'w') as f:
339
+ json.dump(eval_cat_list, f, indent=4)
llava/eval/model_vqa.py ADDED
@@ -0,0 +1,240 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import argparse
2
+ import torch
3
+ import os
4
+ import json
5
+ from tqdm import tqdm
6
+ import shortuuid
7
+
8
+ from llava.constants import IMAGE_TOKEN_INDEX, DEFAULT_IMAGE_TOKEN, DEFAULT_IM_START_TOKEN, DEFAULT_IM_END_TOKEN
9
+ from llava.conversation import conv_templates, SeparatorStyle
10
+ from llava.model.builder import load_pretrained_model
11
+ from llava.utils import disable_torch_init
12
+ from llava.mm_utils import tokenizer_image_token, get_model_name_from_path, KeywordsStoppingCriteria
13
+
14
+ from llava.constants import IGNORE_INDEX, DEFAULT_IMAGE_TOKEN, DEFAULT_IM_START_TOKEN, DEFAULT_IM_END_TOKEN, IMAGE_TOKEN_INDEX
15
+ from typing import Dict, Optional, Sequence, List
16
+ import transformers
17
+ import re
18
+
19
+ from PIL import Image
20
+ import math
21
+
22
+
23
+ def split_list(lst, n):
24
+ """Split a list into n (roughly) equal-sized chunks"""
25
+ chunk_size = math.ceil(len(lst) / n) # integer division
26
+ return [lst[i:i+chunk_size] for i in range(0, len(lst), chunk_size)]
27
+
28
+
29
+ def get_chunk(lst, n, k):
30
+ chunks = split_list(lst, n)
31
+ return chunks[k]
32
+
33
+ def preprocess_qwen(sources, tokenizer: transformers.PreTrainedTokenizer, has_image: bool = False, max_len=2048, system_message: str = "You are a helpful assistant.") -> Dict:
34
+ roles = {"human": "<|im_start|>user", "gpt": "<|im_start|>assistant"}
35
+
36
+ im_start, im_end = tokenizer.additional_special_tokens_ids
37
+ nl_tokens = tokenizer("\n").input_ids
38
+ _system = tokenizer("system").input_ids + nl_tokens
39
+ _user = tokenizer("user").input_ids + nl_tokens
40
+ _assistant = tokenizer("assistant").input_ids + nl_tokens
41
+
42
+ # Apply prompt templates
43
+ input_ids, targets = [], []
44
+
45
+ source = sources
46
+ if roles[source[0]["from"]] != roles["human"]:
47
+ source = source[1:]
48
+
49
+ input_id, target = [], []
50
+ system = [im_start] + _system + tokenizer(system_message).input_ids + [im_end] + nl_tokens
51
+ input_id += system
52
+ target += [im_start] + [IGNORE_INDEX] * (len(system) - 3) + [im_end] + nl_tokens
53
+ assert len(input_id) == len(target)
54
+ for j, sentence in enumerate(source):
55
+ role = roles[sentence["from"]]
56
+ if has_image and sentence["value"] is not None and "<image>" in sentence["value"]:
57
+ num_image = len(re.findall(DEFAULT_IMAGE_TOKEN, sentence["value"]))
58
+ texts = sentence["value"].split('<image>')
59
+ _input_id = tokenizer(role).input_ids + nl_tokens
60
+ for i,text in enumerate(texts):
61
+ _input_id += tokenizer(text).input_ids
62
+ if i<len(texts)-1:
63
+ _input_id += [IMAGE_TOKEN_INDEX] + nl_tokens
64
+ _input_id += [im_end] + nl_tokens
65
+ assert sum([i==IMAGE_TOKEN_INDEX for i in _input_id])==num_image
66
+ else:
67
+ if sentence["value"] is None:
68
+ _input_id = tokenizer(role).input_ids + nl_tokens
69
+ else:
70
+ _input_id = tokenizer(role).input_ids + nl_tokens + tokenizer(sentence["value"]).input_ids + [im_end] + nl_tokens
71
+ input_id += _input_id
72
+ if role == "<|im_start|>user":
73
+ _target = [im_start] + [IGNORE_INDEX] * (len(_input_id) - 3) + [im_end] + nl_tokens
74
+ elif role == "<|im_start|>assistant":
75
+ _target = [im_start] + [IGNORE_INDEX] * len(tokenizer(role).input_ids) + _input_id[len(tokenizer(role).input_ids) + 1 : -2] + [im_end] + nl_tokens
76
+ else:
77
+ raise NotImplementedError
78
+ target += _target
79
+
80
+ input_ids.append(input_id)
81
+ targets.append(target)
82
+ input_ids = torch.tensor(input_ids, dtype=torch.long)
83
+ targets = torch.tensor(targets, dtype=torch.long)
84
+ return input_ids
85
+
86
+ def eval_model(args):
87
+
88
+ # Model
89
+ disable_torch_init()
90
+ model_path = os.path.expanduser(args.model_path)
91
+ model_name = get_model_name_from_path(model_path)
92
+ tokenizer, model, image_processor, context_len = load_pretrained_model(model_path, args.model_base, model_name)
93
+
94
+ # Data
95
+ with open(os.path.expanduser(args.question_file)) as f:
96
+ questions = json.load(f)
97
+ questions = get_chunk(questions, args.num_chunks, args.chunk_idx)
98
+ answers_file = os.path.expanduser(args.answers_file)
99
+ os.makedirs(os.path.dirname(answers_file), exist_ok=True)
100
+ ans_file = open(answers_file, "w")
101
+
102
+ for line in tqdm(questions):
103
+ idx = line["sample_id"]
104
+ question_type = line["metadata"]["question_type"]
105
+ dataset_name = line["metadata"]["dataset"]
106
+ gt = line["conversations"][1]["value"]
107
+
108
+ image_files = line["image"]
109
+ qs = line["conversations"][0]["value"]
110
+ cur_prompt = args.extra_prompt + qs
111
+
112
+ args.conv_mode = "qwen_1_5"
113
+
114
+ conv = conv_templates[args.conv_mode].copy()
115
+ conv.append_message(conv.roles[0], qs)
116
+ conv.append_message(conv.roles[1], None)
117
+ prompt = conv.get_prompt()
118
+
119
+ input_ids = preprocess_qwen([line["conversations"][0],{'from': 'gpt','value': None}], tokenizer, has_image=True).cuda()
120
+ img_num = list(input_ids.squeeze()).count(IMAGE_TOKEN_INDEX)
121
+
122
+ image_tensors = []
123
+ for image_file in image_files:
124
+ image = Image.open(os.path.join(args.image_folder, image_file))
125
+ image_tensor = image_processor.preprocess(image, return_tensors='pt')['pixel_values']
126
+ image_tensors.append(image_tensor.half().cuda())
127
+ # image_tensors = torch.cat(image_tensors, dim=0)
128
+
129
+ stop_str = conv.sep if conv.sep_style != SeparatorStyle.TWO else conv.sep2
130
+ keywords = [stop_str]
131
+ stopping_criteria = KeywordsStoppingCriteria(keywords, tokenizer, input_ids)
132
+
133
+ with torch.inference_mode():
134
+ output_ids = model.generate(
135
+ input_ids,
136
+ images=image_tensors,
137
+ do_sample=True if args.temperature > 0 else False,
138
+ temperature=args.temperature,
139
+ top_p=args.top_p,
140
+ num_beams=args.num_beams,
141
+ # no_repeat_ngram_size=3,
142
+ max_new_tokens=1024,
143
+ use_cache=True)
144
+
145
+
146
+ outputs = tokenizer.batch_decode(output_ids, skip_special_tokens=True)[0]
147
+ outputs = outputs.strip()
148
+ if outputs.endswith(stop_str):
149
+ outputs = outputs[:-len(stop_str)]
150
+ outputs = outputs.strip()
151
+
152
+ ans_id = shortuuid.uuid()
153
+ ans_file.write(json.dumps({
154
+ "dataset": dataset_name,
155
+ "sample_id": idx,
156
+ "prompt": cur_prompt,
157
+ "pred_response": outputs,
158
+ "gt_response": gt,
159
+ "shortuuid": ans_id,
160
+ "model_id": model_name,
161
+ "question_type": question_type,
162
+ }) + "\n")
163
+ ans_file.flush()
164
+
165
+ if len(line["conversations"]) > 2:
166
+
167
+ for i in range(2, len(line["conversations"]), 2):
168
+ input_ids = torch.cat((input_ids, output_ids), dim=1)
169
+
170
+ gt = line["conversations"][i + 1]["value"]
171
+ qs = line["conversations"][i]["value"]
172
+ cur_prompt = args.extra_prompt + qs
173
+
174
+ args.conv_mode = "qwen_1_5"
175
+
176
+ conv = conv_templates[args.conv_mode].copy()
177
+ conv.append_message(conv.roles[0], qs)
178
+ conv.append_message(conv.roles[1], None)
179
+ prompt = conv.get_prompt()
180
+
181
+ input_ids_new = preprocess_qwen([line["conversations"][i],{'from': 'gpt','value': None}], tokenizer, has_image=True).cuda()
182
+ input_ids = torch.cat((input_ids, input_ids_new), dim=1)
183
+ img_num = list(input_ids_new.squeeze()).count(IMAGE_TOKEN_INDEX)
184
+
185
+ stop_str = conv.sep if conv.sep_style != SeparatorStyle.TWO else conv.sep2
186
+ keywords = [stop_str]
187
+ stopping_criteria = KeywordsStoppingCriteria(keywords, tokenizer, input_ids)
188
+
189
+ with torch.inference_mode():
190
+ output_ids = model.generate(
191
+ input_ids,
192
+ images=image_tensors,
193
+ do_sample=True if args.temperature > 0 else False,
194
+ temperature=args.temperature,
195
+ top_p=args.top_p,
196
+ num_beams=args.num_beams,
197
+ # no_repeat_ngram_size=3,
198
+ max_new_tokens=1024,
199
+ use_cache=True)
200
+
201
+ outputs = tokenizer.batch_decode(output_ids, skip_special_tokens=True)[0]
202
+ outputs = outputs.strip()
203
+ if outputs.endswith(stop_str):
204
+ outputs = outputs[:-len(stop_str)]
205
+ outputs = outputs.strip()
206
+
207
+ ans_id = shortuuid.uuid()
208
+ ans_file.write(json.dumps({
209
+ "dataset": dataset_name,
210
+ "sample_id": idx,
211
+ "prompt": cur_prompt,
212
+ "pred_response": outputs,
213
+ "gt_response": gt,
214
+ "shortuuid": ans_id,
215
+ "model_id": model_name,
216
+ "question_type": question_type,
217
+ }) + "\n")
218
+ ans_file.flush()
219
+
220
+
221
+ ans_file.close()
222
+
223
+ if __name__ == "__main__":
224
+ parser = argparse.ArgumentParser()
225
+ parser.add_argument("--model-path", type=str, default="facebook/opt-350m")
226
+ parser.add_argument("--model-base", type=str, default=None)
227
+ parser.add_argument("--image-folder", type=str, default="")
228
+ parser.add_argument("--extra-prompt", type=str, default="")
229
+ parser.add_argument("--question-file", type=str, default="tables/question.jsonl")
230
+ parser.add_argument("--answers-file", type=str, default="answer.jsonl")
231
+ parser.add_argument("--conv-mode", type=str, default="llava_v1")
232
+ parser.add_argument("--num-chunks", type=int, default=1)
233
+ parser.add_argument("--chunk-idx", type=int, default=0)
234
+ parser.add_argument("--temperature", type=float, default=0.2)
235
+ parser.add_argument("--top_p", type=float, default=None)
236
+ parser.add_argument("--num_beams", type=int, default=1)
237
+ parser.add_argument("--test_size", type=int, default=10000000)
238
+ args = parser.parse_args()
239
+
240
+ eval_model(args)
llava/mm_utils.py ADDED
@@ -0,0 +1,381 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ from PIL import Image
2
+ from io import BytesIO
3
+ import base64
4
+ import math
5
+ import ast
6
+
7
+ import torch
8
+ from transformers import StoppingCriteria
9
+ from llava.constants import IMAGE_TOKEN_INDEX
10
+
11
+
12
+ def resize_and_center_crop(image, shortest_edge_length):
13
+ # Calculate new dimensions and resize
14
+ aspect_ratio = float(image.width) / float(image.height)
15
+ if aspect_ratio > 1:
16
+ new_width = int(shortest_edge_length * aspect_ratio)
17
+ new_height = shortest_edge_length
18
+ else:
19
+ new_width = shortest_edge_length
20
+ new_height = int(shortest_edge_length / aspect_ratio)
21
+ resized_image = image.resize((new_width, new_height), Image.ANTIALIAS)
22
+
23
+ # Calculate the position and perform the center crop
24
+ left = (new_width - shortest_edge_length) / 2
25
+ top = (new_height - shortest_edge_length) / 2
26
+ right = (new_width + shortest_edge_length) / 2
27
+ bottom = (new_height + shortest_edge_length) / 2
28
+ cropped_image = resized_image.crop((left, top, right, bottom))
29
+
30
+ return cropped_image
31
+
32
+
33
+ def auto_pad_images(image, grid_params):
34
+ assert isinstance(image, Image.Image), "Input should be a Pillow Image"
35
+ assert len(grid_params) > 0, "Grid parameters should not be empty"
36
+
37
+ # Step 1: Calculate and find the closest aspect ratio
38
+ input_width, input_height = image.size
39
+ input_aspect_ratio = input_width / input_height
40
+ candidate_resolutions = [(w / h, w, h) for w in grid_params for h in grid_params]
41
+ closest_aspect_ratio = min(candidate_resolutions, key=lambda x: abs(input_aspect_ratio - x[0]))
42
+
43
+ candidate_resolutions = [(x[1], x[2]) for x in candidate_resolutions if abs(x[0] - closest_aspect_ratio[0]) < 1e-3]
44
+
45
+ target_resolution = min(candidate_resolutions, key=lambda res: abs(max(input_width, input_height) / max(res) - 1))
46
+
47
+ resize_width, resize_height = target_resolution
48
+ if input_width > input_height:
49
+ resize_height = int(resize_width / input_aspect_ratio)
50
+ else:
51
+ resize_width = int(resize_height * input_aspect_ratio)
52
+ resized_image = image.resize((resize_width, resize_height), Image.ANTIALIAS)
53
+
54
+ # Step 5: Pad the resized image if necessary to match the target resolution
55
+ pad_width = target_resolution[0] - resize_width
56
+ pad_height = target_resolution[1] - resize_height
57
+ padded_image = Image.new("RGB", target_resolution, color=(0, 0, 0))
58
+ padded_image.paste(resized_image, (pad_width // 2, pad_height // 2))
59
+
60
+ return padded_image
61
+
62
+
63
+ def extract_patches(image, patch_size, overlap_ratio):
64
+ assert isinstance(image, Image.Image), "Input should be a Pillow Image"
65
+ assert patch_size > 0, "Patch size should be greater than 0"
66
+ assert 0 <= overlap_ratio < 1, "Overlap ratio should be between 0 and 1"
67
+
68
+ W, H = image.size
69
+ patches = []
70
+
71
+ stride = int(patch_size * (1 - overlap_ratio))
72
+
73
+ num_patches_y = (H - patch_size) // stride + 1
74
+ num_patches_x = (W - patch_size) // stride + 1
75
+
76
+ y_start = (H - (num_patches_y - 1) * stride - patch_size) // 2
77
+ x_start = (W - (num_patches_x - 1) * stride - patch_size) // 2
78
+
79
+ for y in range(y_start, y_start + num_patches_y * stride, stride):
80
+ for x in range(x_start, x_start + num_patches_x * stride, stride):
81
+ patch = image.crop((x, y, x + patch_size, y + patch_size))
82
+ patches.append(patch)
83
+
84
+ return patches
85
+
86
+
87
+ def process_highres_image_crop_split(image, data_args, processor=None):
88
+ crop_resolution = data_args.image_crop_resolution
89
+ split_resolution = data_args.image_split_resolution
90
+ if processor is None:
91
+ processor = data_args.image_processor
92
+ image_crop = resize_and_center_crop(image, crop_resolution)
93
+ image_patches = extract_patches(image_crop, patch_size=split_resolution, overlap_ratio=0)
94
+ image_patches = [processor.preprocess(image_patch, return_tensors="pt")["pixel_values"][0] for image_patch in image_patches]
95
+ return torch.stack(image_patches, dim=0)
96
+
97
+
98
+ def process_highres_image(image, processor, grid_pinpoints):
99
+ grid_params = [int(x) for x in grid_pinpoints.split(",")]
100
+ width_height = max(image.size)
101
+ fit_grid_params = [x for x in grid_params if x >= width_height]
102
+ if len(fit_grid_params) == 0:
103
+ select_size = max(grid_params)
104
+ else:
105
+ select_size = min(fit_grid_params)
106
+ # FIXME: always select the 448
107
+ select_size = max(grid_params)
108
+ image_padded = expand2square(image, tuple(int(x * 255) for x in processor.image_mean))
109
+
110
+ # FIXME: this seems to be a bug that it always resizes instead of padding
111
+ image_original_resize = image.resize((processor.size["shortest_edge"], processor.size["shortest_edge"]))
112
+ image_padded = image_padded.resize((select_size, select_size))
113
+ image_patches = extract_patches(image_padded, patch_size=processor.size["shortest_edge"], overlap_ratio=0)
114
+ image_patches = [image_original_resize] + image_patches
115
+ image_patches = [processor.preprocess(image_patch, return_tensors="pt")["pixel_values"][0] for image_patch in image_patches]
116
+ return torch.stack(image_patches, dim=0)
117
+
118
+
119
+ def select_best_resolution(original_size, possible_resolutions):
120
+ """
121
+ Selects the best resolution from a list of possible resolutions based on the original size.
122
+
123
+ Args:
124
+ original_size (tuple): The original size of the image in the format (width, height).
125
+ possible_resolutions (list): A list of possible resolutions in the format [(width1, height1), (width2, height2), ...].
126
+
127
+ Returns:
128
+ tuple: The best fit resolution in the format (width, height).
129
+ """
130
+ original_width, original_height = original_size
131
+ best_fit = None
132
+ max_effective_resolution = 0
133
+ min_wasted_resolution = float("inf")
134
+
135
+ for width, height in possible_resolutions:
136
+ # Calculate the downscaled size to keep the aspect ratio
137
+ scale = min(width / original_width, height / original_height)
138
+ downscaled_width, downscaled_height = int(original_width * scale), int(original_height * scale)
139
+
140
+ # Calculate effective and wasted resolutions
141
+ effective_resolution = min(downscaled_width * downscaled_height, original_width * original_height)
142
+ wasted_resolution = (width * height) - effective_resolution
143
+
144
+ if effective_resolution > max_effective_resolution or (effective_resolution == max_effective_resolution and wasted_resolution < min_wasted_resolution):
145
+ max_effective_resolution = effective_resolution
146
+ min_wasted_resolution = wasted_resolution
147
+ best_fit = (width, height)
148
+
149
+ return best_fit
150
+
151
+
152
+ def resize_and_pad_image(image, target_resolution):
153
+ """
154
+ Resize and pad an image to a target resolution while maintaining aspect ratio.
155
+
156
+ Args:
157
+ image (PIL.Image.Image): The input image.
158
+ target_resolution (tuple): The target resolution (width, height) of the image.
159
+
160
+ Returns:
161
+ PIL.Image.Image: The resized and padded image.
162
+ """
163
+ original_width, original_height = image.size
164
+ target_width, target_height = target_resolution
165
+
166
+ # Determine which dimension (width or height) to fill
167
+ scale_w = target_width / original_width
168
+ scale_h = target_height / original_height
169
+
170
+ if scale_w < scale_h:
171
+ # Width will be filled completely
172
+ new_width = target_width
173
+ new_height = min(math.ceil(original_height * scale_w), target_height)
174
+ else:
175
+ # Height will be filled completely
176
+ new_height = target_height
177
+ new_width = min(math.ceil(original_width * scale_h), target_width)
178
+
179
+ # Resize the image
180
+ resized_image = image.resize((new_width, new_height))
181
+
182
+ # Create a new image with the target size and paste the resized image onto it
183
+ new_image = Image.new("RGB", (target_width, target_height), (0, 0, 0))
184
+ paste_x = (target_width - new_width) // 2
185
+ paste_y = (target_height - new_height) // 2
186
+ new_image.paste(resized_image, (paste_x, paste_y))
187
+
188
+ return new_image
189
+
190
+
191
+ def divide_to_patches(image, patch_size):
192
+ """
193
+ Divides an image into patches of a specified size.
194
+
195
+ Args:
196
+ image (PIL.Image.Image): The input image.
197
+ patch_size (int): The size of each patch.
198
+
199
+ Returns:
200
+ list: A list of PIL.Image.Image objects representing the patches.
201
+ """
202
+ patches = []
203
+ width, height = image.size
204
+ for i in range(0, height, patch_size):
205
+ for j in range(0, width, patch_size):
206
+ box = (j, i, j + patch_size, i + patch_size)
207
+ patch = image.crop(box)
208
+ patches.append(patch)
209
+
210
+ return patches
211
+
212
+
213
+ def get_anyres_image_grid_shape(image_size, grid_pinpoints, patch_size):
214
+ """
215
+ Calculate the shape of the image patch grid after the preprocessing for images of any resolution.
216
+
217
+ Args:
218
+ image_size (tuple): The size of the input image in the format (width, height).
219
+ grid_pinpoints (str): A string representation of a list of possible resolutions.
220
+ patch_size (int): The size of each image patch.
221
+
222
+ Returns:
223
+ tuple: The shape of the image patch grid in the format (width, height).
224
+ """
225
+ if isinstance(grid_pinpoints, str):
226
+ assert patch_size in [224, 336, 384, 448, 512], "patch_size should be in [224, 336, 384, 448, 512]"
227
+ grid_pinpoints = grid_pinpoints.replace(" ", "").replace("x", ",")[1:-1].split("),(")
228
+ grid_pinpoints = [[int(x) * patch_size for x in item.split(",")] for item in grid_pinpoints]
229
+
230
+ if type(grid_pinpoints) is list:
231
+ possible_resolutions = grid_pinpoints
232
+ else:
233
+ possible_resolutions = ast.literal_eval(grid_pinpoints)
234
+ width, height = select_best_resolution(image_size, possible_resolutions)
235
+ return width // patch_size, height // patch_size
236
+
237
+
238
+ def process_anyres_image(image, processor, grid_pinpoints):
239
+ """
240
+ Process an image with variable resolutions.
241
+
242
+ Args:
243
+ image (PIL.Image.Image): The input image to be processed.
244
+ processor: The image processor object.
245
+ grid_pinpoints (str): A string representation of a list of possible resolutions.
246
+
247
+ Returns:
248
+ torch.Tensor: A tensor containing the processed image patches.
249
+ """
250
+ # Convert grid_pinpoints from string to list
251
+ if isinstance(grid_pinpoints, str):
252
+ vis_encoder_size = processor.size[0]
253
+ assert vis_encoder_size in [224, 336, 384, 448, 512], "vis_encoder_size should be in [224, 336, 384, 448, 512]"
254
+ grid_pinpoints = grid_pinpoints.replace(" ", "").replace("x", ",")[1:-1].split("),(")
255
+ grid_pinpoints = [[int(x) * vis_encoder_size for x in item.split(",")] for item in grid_pinpoints]
256
+
257
+ if type(grid_pinpoints) is list:
258
+ possible_resolutions = grid_pinpoints
259
+ else:
260
+ possible_resolutions = ast.literal_eval(grid_pinpoints)
261
+ best_resolution = select_best_resolution(image.size, possible_resolutions)
262
+ image_padded = resize_and_pad_image(image, best_resolution)
263
+
264
+ patches = divide_to_patches(image_padded, processor.crop_size["height"])
265
+
266
+ # FIXME: this looks like a bug: the image is resized instead of padded.
267
+ # It is kept as is for consistency with the previous behavior.
268
+ # TODO: uncomment the padding variant below to ablate against it.
269
+ if isinstance(processor.size, dict):
270
+ shortest_edge = processor.size["shortest_edge"]
271
+ else:
272
+ shortest_edge = min(processor.size)
273
+ image_original_resize = image.resize((shortest_edge, shortest_edge))
274
+ # image_padded_square = expand2square(image, tuple(int(x*255) for x in processor.image_mean))
275
+ # image_original_resize = image_padded_square.resize((processor.size['shortest_edge'], processor.size['shortest_edge']))
276
+
277
+ image_patches = [image_original_resize] + patches
278
+ image_patches = [processor.preprocess(image_patch, return_tensors="pt")["pixel_values"][0] for image_patch in image_patches]
279
+ return torch.stack(image_patches, dim=0)
280
+
281
+
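Example (not part of the diff): process_anyres_image returns a stack whose first entry is the globally resized image and whose remaining entries are the crops of the padded best-resolution image. A rough usage sketch; the processor checkpoint and pinpoint list are illustrative assumptions:

from PIL import Image
from transformers import CLIPImageProcessor
from llava.mm_utils import process_anyres_image

processor = CLIPImageProcessor.from_pretrained("openai/clip-vit-large-patch14-336")
image = Image.new("RGB", (1000, 700))  # dummy image for illustration

# Passing the pinpoints as a list of (width, height) resolutions skips the
# string-parsing branch above.
pinpoints = [[336, 672], [672, 336], [672, 672], [1008, 336], [336, 1008]]
patches = process_anyres_image(image, processor, pinpoints)
print(patches.shape)  # (1 + num_crops, 3, 336, 336): base image first, then the crops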
282
+ def load_image_from_base64(image):
283
+ return Image.open(BytesIO(base64.b64decode(image)))
284
+
285
+
286
+ def expand2square(pil_img, background_color):
287
+ width, height = pil_img.size
288
+ if width == height:
289
+ return pil_img
290
+ elif width > height:
291
+ result = Image.new(pil_img.mode, (width, width), background_color)
292
+ result.paste(pil_img, (0, (width - height) // 2))
293
+ return result
294
+ else:
295
+ result = Image.new(pil_img.mode, (height, height), background_color)
296
+ result.paste(pil_img, ((height - width) // 2, 0))
297
+ return result
298
+
299
+
300
+ def process_images(images, image_processor, model_cfg):
301
+ image_aspect_ratio = getattr(model_cfg, "image_aspect_ratio", None)
302
+ new_images = []
303
+ if image_aspect_ratio == "highres":
304
+ for image in images:
305
+ image = process_highres_image(image, image_processor, model_cfg.image_grid_pinpoints)
306
+ new_images.append(image)
307
+ elif image_aspect_ratio == "anyres":
308
+ for image in images:
309
+ image = process_anyres_image(image, image_processor, model_cfg.image_grid_pinpoints)
310
+ new_images.append(image)
311
+ elif image_aspect_ratio == "crop_split":
312
+ for image in images:
313
+ image = process_highres_image_crop_split(image, model_cfg, image_processor)
314
+ new_images.append(image)
315
+ elif image_aspect_ratio == "pad":
316
+ for image in images:
317
+ image = expand2square(image, tuple(int(x * 255) for x in image_processor.image_mean))
318
+ image = image_processor.preprocess(image, return_tensors="pt")["pixel_values"][0]
319
+ new_images.append(image)
320
+ else:
321
+ return image_processor(images, return_tensors="pt")["pixel_values"]
322
+ if all(x.shape == new_images[0].shape for x in new_images):
323
+ new_images = torch.stack(new_images, dim=0)
324
+ return new_images
325
+
326
+
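Example (not part of the diff): the "pad" branch letterboxes each image to a square filled with the processor's mean color before preprocessing. A quick sketch; model_cfg is a stand-in namespace rather than a real config object:

from types import SimpleNamespace
from PIL import Image
from transformers import CLIPImageProcessor
from llava.mm_utils import process_images

processor = CLIPImageProcessor.from_pretrained("openai/clip-vit-large-patch14-336")
model_cfg = SimpleNamespace(image_aspect_ratio="pad")  # stand-in for model.config

images = [Image.new("RGB", (640, 480)), Image.new("RGB", (480, 640))]
pixel_values = process_images(images, processor, model_cfg)
print(pixel_values.shape)  # (2, 3, 336, 336): equal shapes are stacked into one tensor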
327
+ def tokenizer_image_token(prompt, tokenizer, image_token_index=IMAGE_TOKEN_INDEX, return_tensors=None):
328
+ prompt_chunks = [tokenizer(chunk).input_ids for chunk in prompt.split("<image>")]
329
+
330
+ def insert_separator(X, sep):
331
+ return [ele for sublist in zip(X, [sep] * len(X)) for ele in sublist][:-1]
332
+
333
+ input_ids = []
334
+ offset = 0
335
+ if len(prompt_chunks) > 0 and len(prompt_chunks[0]) > 0 and prompt_chunks[0][0] == tokenizer.bos_token_id:
336
+ offset = 1
337
+ input_ids.append(prompt_chunks[0][0])
338
+
339
+ for x in insert_separator(prompt_chunks, [image_token_index] * (offset + 1)):
340
+ input_ids.extend(x[offset:])
341
+
342
+ if return_tensors is not None:
343
+ if return_tensors == "pt":
344
+ return torch.tensor(input_ids, dtype=torch.long)
345
+ raise ValueError(f"Unsupported tensor type: {return_tensors}")
346
+ return input_ids
347
+
348
+
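Example (not part of the diff): the prompt is tokenized around every "<image>" placeholder and IMAGE_TOKEN_INDEX (-200) is inserted where the projected image features will later be substituted. A small sketch; the Vicuna tokenizer is only an example choice:

from transformers import AutoTokenizer
from llava.constants import IMAGE_TOKEN_INDEX
from llava.mm_utils import tokenizer_image_token

tokenizer = AutoTokenizer.from_pretrained("lmsys/vicuna-7b-v1.5", use_fast=False)
prompt = "USER: <image>\nWhat is shown in this image? ASSISTANT:"

input_ids = tokenizer_image_token(prompt, tokenizer, IMAGE_TOKEN_INDEX, return_tensors="pt")
# Exactly one -200 entry marks where prepare_inputs_labels_for_multimodal will
# splice in the vision features.
print((input_ids == IMAGE_TOKEN_INDEX).sum())  # tensor(1)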
349
+ def get_model_name_from_path(model_path):
350
+ model_path = model_path.strip("/")
351
+ model_paths = model_path.split("/")
352
+ if model_paths[-1].startswith("checkpoint-"):
353
+ return model_paths[-2] + "_" + model_paths[-1]
354
+ else:
355
+ return model_paths[-1]
356
+
357
+
358
+ class KeywordsStoppingCriteria(StoppingCriteria):
359
+ def __init__(self, keywords, tokenizer, input_ids):
360
+ self.keywords = keywords
361
+ self.keyword_ids = []
362
+ for keyword in keywords:
363
+ cur_keyword_ids = tokenizer(keyword).input_ids
364
+ if len(cur_keyword_ids) > 1 and cur_keyword_ids[0] == tokenizer.bos_token_id:
365
+ cur_keyword_ids = cur_keyword_ids[1:]
366
+ self.keyword_ids.append(torch.tensor(cur_keyword_ids))
367
+ self.tokenizer = tokenizer
368
+ self.start_len = input_ids.shape[1]
369
+
370
+ def __call__(self, output_ids: torch.LongTensor, scores: torch.FloatTensor, **kwargs) -> bool:
371
+ assert output_ids.shape[0] == 1, "Only batch size 1 is supported for now"  # TODO
372
+ offset = min(output_ids.shape[1] - self.start_len, 3)
373
+ self.keyword_ids = [keyword_id.to(output_ids.device) for keyword_id in self.keyword_ids]
374
+ for keyword_id in self.keyword_ids:
375
+ if (output_ids[0, -keyword_id.shape[0] :] == keyword_id).all():
376
+ return True
377
+ outputs = self.tokenizer.batch_decode(output_ids[:, -offset:], skip_special_tokens=True)[0]
378
+ for keyword in self.keywords:
379
+ if keyword in outputs:
380
+ return True
381
+ return False
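Example (not part of the diff): a rough sketch of how the stopping criterion is wired into generation; model, tokenizer, input_ids and image_tensor are assumed to come from load_pretrained_model, tokenizer_image_token and process_images above:

from transformers import StoppingCriteriaList
from llava.mm_utils import KeywordsStoppingCriteria

# input_ids: (1, seq_len) prompt tensor already on the model's device (assumed).
stopping_criteria = KeywordsStoppingCriteria(["</s>"], tokenizer, input_ids)
output_ids = model.generate(
    input_ids,
    images=image_tensor,  # preprocessed pixel values (assumed)
    max_new_tokens=512,
    stopping_criteria=StoppingCriteriaList([stopping_criteria]),
)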
llava/model/__init__.py ADDED
@@ -0,0 +1,23 @@
1
+ import os
2
+
3
+ AVAILABLE_MODELS = {
4
+ "llava_llama": "LlavaLlamaForCausalLM, LlavaConfig",
5
+ "llava_gemma": "LlavaGemmaForCausalLM, LlavaGemmaConfig",
6
+ "llava_qwen": "LlavaQwenForCausalLM, LlavaQwenConfig",
7
+ # "llava_qwen_moe": "LlavaQwenMoeForCausalLM, LlavaQwenMoeConfig",
8
+ "llava_mistral": "LlavaMistralForCausalLM, LlavaMistralConfig",
9
+ "llava_mixtral": "LlavaMixtralForCausalLM, LlavaMixtralConfig",
10
+ # Add other models as needed
11
+ }
12
+
13
+ for model_name, model_classes in AVAILABLE_MODELS.items():
14
+ from .language_model.llava_llama import LlavaLlamaForCausalLM, LlavaConfig  # imported unconditionally so the default llava_llama classes are always available
15
+ # print(f"import {model_classes} successfully")
16
+ try:
17
+ exec(f"from .language_model.{model_name} import {model_classes}")
18
+ print(f"import {model_classes} successfully")
19
+ except ImportError:
20
+ # import traceback
21
+ # traceback.print_exc()
22
+ print(f"Failed to import {model_classes} from llava.language_model.{model_name}")
23
+ pass
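The exec-based loop above can also be expressed with importlib; a minimal alternative sketch (not part of the original file) for readers who prefer to avoid exec:

import importlib

def try_import(model_name: str, class_names: str):
    """Best-effort import of the listed classes from llava.model.language_model.<model_name>."""
    try:
        module = importlib.import_module(f"llava.model.language_model.{model_name}")
        return [getattr(module, name.strip()) for name in class_names.split(",")]
    except (ImportError, AttributeError):
        print(f"Failed to import {class_names} from llava.model.language_model.{model_name}")
        return None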
llava/model/apply_delta.py ADDED
@@ -0,0 +1,47 @@
1
+ """
2
+ Usage:
3
+ python3 -m fastchat.model.apply_delta --base ~/model_weights/llama-7b --target ~/model_weights/vicuna-7b --delta lmsys/vicuna-7b-delta
4
+ """
5
+
6
+ import argparse
7
+
8
+ import torch
9
+ from tqdm import tqdm
10
+ from transformers import AutoTokenizer, AutoModelForCausalLM
11
+ from llava import LlavaLlamaForCausalLM
12
+
13
+
14
+ def apply_delta(base_model_path, target_model_path, delta_path):
15
+ print("Loading base model")
16
+ base = AutoModelForCausalLM.from_pretrained(base_model_path, torch_dtype=torch.float16, low_cpu_mem_usage=True)
17
+
18
+ print("Loading delta")
19
+ delta = LlavaLlamaForCausalLM.from_pretrained(delta_path, torch_dtype=torch.float16, low_cpu_mem_usage=True)
20
+ delta_tokenizer = AutoTokenizer.from_pretrained(delta_path)
21
+
22
+ print("Applying delta")
23
+ for name, param in tqdm(delta.state_dict().items(), desc="Applying delta"):
24
+ if name not in base.state_dict():
25
+ assert name in ["model.mm_projector.weight", "model.mm_projector.bias"], f"{name} not in base model"
26
+ continue
27
+ if param.data.shape == base.state_dict()[name].shape:
28
+ param.data += base.state_dict()[name]
29
+ else:
30
+ assert name in ["model.embed_tokens.weight", "lm_head.weight"], f"{name} dimension mismatch: {param.data.shape} vs {base.state_dict()[name].shape}"
31
+ bparam = base.state_dict()[name]
32
+ param.data[: bparam.shape[0], : bparam.shape[1]] += bparam
33
+
34
+ print("Saving target model")
35
+ delta.save_pretrained(target_model_path)
36
+ delta_tokenizer.save_pretrained(target_model_path)
37
+
38
+
39
+ if __name__ == "__main__":
40
+ parser = argparse.ArgumentParser()
41
+ parser.add_argument("--base-model-path", type=str, required=True)
42
+ parser.add_argument("--target-model-path", type=str, required=True)
43
+ parser.add_argument("--delta-path", type=str, required=True)
44
+
45
+ args = parser.parse_args()
46
+
47
+ apply_delta(args.base_model_path, args.target_model_path, args.delta_path)
llava/model/builder.py ADDED
@@ -0,0 +1,250 @@
1
+ # Copyright 2023 Haotian Liu
2
+ #
3
+ # Licensed under the Apache License, Version 2.0 (the "License");
4
+ # you may not use this file except in compliance with the License.
5
+ # You may obtain a copy of the License at
6
+ #
7
+ # http://www.apache.org/licenses/LICENSE-2.0
8
+ #
9
+ # Unless required by applicable law or agreed to in writing, software
10
+ # distributed under the License is distributed on an "AS IS" BASIS,
11
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12
+ # See the License for the specific language governing permissions and
13
+ # limitations under the License.
14
+
15
+
16
+ import os
17
+ import warnings
18
+ import shutil
19
+
20
+ from transformers import AutoTokenizer, AutoModelForCausalLM, AutoConfig, BitsAndBytesConfig
21
+ import torch
22
+ from llava.model import *
23
+ from llava.constants import DEFAULT_IMAGE_PATCH_TOKEN, DEFAULT_IM_START_TOKEN, DEFAULT_IM_END_TOKEN
24
+ from llava.utils import rank0_print
25
+
26
+
27
+ def load_pretrained_model(model_path, model_base, model_name, load_8bit=False, load_4bit=False, device_map="auto", attn_implementation="flash_attention_2", customized_config=None, **kwargs):
28
+ kwargs = {"device_map": device_map}
29
+
30
+ if load_8bit:
31
+ kwargs["load_in_8bit"] = True
32
+ elif load_4bit:
33
+ kwargs["load_in_4bit"] = True
34
+ kwargs["quantization_config"] = BitsAndBytesConfig(load_in_4bit=True, bnb_4bit_compute_dtype=torch.float16, bnb_4bit_use_double_quant=True, bnb_4bit_quant_type="nf4")
35
+ else:
36
+ kwargs["torch_dtype"] = torch.float16
37
+
38
+ if customized_config is not None:
39
+ kwargs["config"] = customized_config
40
+
41
+ if "llava" in model_name.lower():
42
+ # Load LLaVA model
43
+ if "lora" in model_name.lower() and model_base is None:
44
+ warnings.warn(
45
+ "There is `lora` in model name but no `model_base` is provided. If you are loading a LoRA model, please provide the `model_base` argument. Detailed instruction: https://github.com/haotian-liu/LLaVA#launch-a-model-worker-lora-weights-unmerged."
46
+ )
47
+ if "lora" in model_name.lower() and model_base is not None:
48
+ lora_cfg_pretrained = AutoConfig.from_pretrained(model_path)
49
+ tokenizer = AutoTokenizer.from_pretrained(model_base, use_fast=False)
50
+ rank0_print("Loading LLaVA from base model...")
51
+ if "mixtral" in model_name.lower():
52
+ from llava.model.language_model.llava_mixtral import LlavaMixtralConfig
53
+
54
+ lora_cfg_pretrained = LlavaMixtralConfig.from_pretrained(model_path)
55
+ tokenizer = AutoTokenizer.from_pretrained(model_base, use_fast=False)
56
+ model = LlavaMixtralForCausalLM.from_pretrained(model_base, low_cpu_mem_usage=True, config=lora_cfg_pretrained, attn_implementation=attn_implementation, **kwargs)
57
+ elif "mistral" in model_name.lower():
58
+ from llava.model.language_model.llava_mistral import LlavaMistralConfig
59
+
60
+ lora_cfg_pretrained = LlavaMistralConfig.from_pretrained(model_path)
61
+ tokenizer = AutoTokenizer.from_pretrained(model_base, use_fast=False)
62
+ model = LlavaMistralForCausalLM.from_pretrained(model_base, low_cpu_mem_usage=True, config=lora_cfg_pretrained, attn_implementation=attn_implementation, **kwargs)
63
+ elif "gemma" in model_name.lower():
64
+ from llava.model.language_model.llava_gemma import LlavaGemmaConfig
65
+
66
+ lora_cfg_pretrained = LlavaGemmaConfig.from_pretrained(model_path)
67
+ tokenizer = AutoTokenizer.from_pretrained(model_base, use_fast=False)
68
+ model = LlavaGemmaForCausalLM.from_pretrained(model_base, low_cpu_mem_usage=True, config=lora_cfg_pretrained, attn_implementation=attn_implementation, **kwargs)
69
+ else:
70
+ from llava.model.language_model.llava_llama import LlavaConfig
71
+
72
+ lora_cfg_pretrained = LlavaConfig.from_pretrained(model_path)
73
+ tokenizer = AutoTokenizer.from_pretrained(model_base, use_fast=False)
74
+ model = LlavaLlamaForCausalLM.from_pretrained(model_base, low_cpu_mem_usage=True, config=lora_cfg_pretrained, attn_implementation=attn_implementation, **kwargs)
75
+
76
+ token_num, token_dim = model.lm_head.out_features, model.lm_head.in_features
77
+ if model.lm_head.weight.shape[0] != token_num:
78
+ model.lm_head.weight = torch.nn.Parameter(torch.empty(token_num, token_dim, device=model.device, dtype=model.dtype))
79
+ model.model.embed_tokens.weight = torch.nn.Parameter(torch.empty(token_num, token_dim, device=model.device, dtype=model.dtype))
80
+
81
+ rank0_print("Loading additional LLaVA weights...")
82
+ if os.path.exists(os.path.join(model_path, "non_lora_trainables.bin")):
83
+ non_lora_trainables = torch.load(os.path.join(model_path, "non_lora_trainables.bin"), map_location="cpu")
84
+ else:
85
+ # this is probably from HF Hub
86
+ from huggingface_hub import hf_hub_download
87
+
88
+ def load_from_hf(repo_id, filename, subfolder=None):
89
+ cache_file = hf_hub_download(repo_id=repo_id, filename=filename, subfolder=subfolder)
90
+ return torch.load(cache_file, map_location="cpu")
91
+
92
+ non_lora_trainables = load_from_hf(model_path, "non_lora_trainables.bin")
93
+ non_lora_trainables = {(k[11:] if k.startswith("base_model.") else k): v for k, v in non_lora_trainables.items()}
94
+ if any(k.startswith("model.model.") for k in non_lora_trainables):
95
+ non_lora_trainables = {(k[6:] if k.startswith("model.") else k): v for k, v in non_lora_trainables.items()}
96
+ model.load_state_dict(non_lora_trainables, strict=False)
97
+
98
+ from peft import PeftModel
99
+
100
+ rank0_print("Loading LoRA weights...")
101
+ model = PeftModel.from_pretrained(model, model_path)
102
+ rank0_print("Merging LoRA weights...")
103
+ model = model.merge_and_unload()
104
+ rank0_print("Model is loaded...")
105
+ elif model_base is not None:
106
+ # this may be mm projector only
107
+ rank0_print(f"Loading LLaVA from base model {model_base}...")
108
+ if "mixtral" in model_name.lower():
109
+ tokenizer = AutoTokenizer.from_pretrained(model_base, use_fast=False)
110
+ cfg_pretrained = AutoConfig.from_pretrained(model_path)
111
+ model = LlavaMixtralForCausalLM.from_pretrained(model_base, low_cpu_mem_usage=True, config=cfg_pretrained, attn_implementation=attn_implementation, **kwargs)
112
+ elif "mistral" in model_name.lower() or "zephyr" in model_name.lower():
113
+ tokenizer = AutoTokenizer.from_pretrained(model_base, use_fast=False)
114
+ cfg_pretrained = AutoConfig.from_pretrained(model_path)
115
+ model = LlavaMistralForCausalLM.from_pretrained(model_base, low_cpu_mem_usage=True, config=cfg_pretrained, attn_implementation=attn_implementation, **kwargs)
116
+ elif "gemma" in model_name.lower():
117
+ tokenizer = AutoTokenizer.from_pretrained(model_base, use_fast=False)
118
+ cfg_pretrained = AutoConfig.from_pretrained(model_path)
119
+ model = LlavaGemmaForCausalLM.from_pretrained(model_base, low_cpu_mem_usage=True, config=cfg_pretrained, attn_implementation=attn_implementation, **kwargs)
120
+ elif (
121
+ "wizardlm-2" in model_name.lower()
122
+ and "vicuna" in model_name.lower()
123
+ or "llama" in model_name.lower()
124
+ or "yi" in model_name.lower()
125
+ or "nous-hermes" in model_name.lower()
126
+ or "llava-v1.6-34b" in model_name.lower()
127
+ or "llava-v1.5" in model_name.lower()
128
+ ):
129
+ from llava.model.language_model.llava_llama import LlavaConfig
130
+
131
+ tokenizer = AutoTokenizer.from_pretrained(model_path, use_fast=False)
132
+ if customized_config is None:
133
+ llava_cfg = LlavaConfig.from_pretrained(model_path)
134
+ if "v1.5" in model_name.lower():
135
+ llava_cfg.delay_load = True # a workaround for correctly loading v1.5 models
136
+ else:
137
+ llava_cfg = customized_config
138
+
139
+ tokenizer = AutoTokenizer.from_pretrained(model_base, use_fast=False)
140
+ llava_cfg = LlavaConfig.from_pretrained(model_path)
141
+ model = LlavaLlamaForCausalLM.from_pretrained(model_base, low_cpu_mem_usage=True, config=llava_cfg, **kwargs)
142
+ else:
143
+ raise ValueError(f"Model {model_name} not supported")
144
+
145
+ mm_projector_weights = torch.load(os.path.join(model_path, "mm_projector.bin"), map_location="cpu")
146
+ mm_projector_weights = {k: v.to(torch.float16) for k, v in mm_projector_weights.items()}
147
+ model.load_state_dict(mm_projector_weights, strict=False)
148
+ else:
149
+ rank0_print(f"Loaded LLaVA model: {model_path}")
150
+ if "mixtral" in model_name.lower():
151
+ tokenizer = AutoTokenizer.from_pretrained(model_path)
152
+ model = LlavaMixtralForCausalLM.from_pretrained(model_path, low_cpu_mem_usage=True, attn_implementation=attn_implementation, **kwargs)
153
+ elif "mistral" in model_name.lower() or "zephyr" in model_name.lower():
154
+ tokenizer = AutoTokenizer.from_pretrained(model_path)
155
+ model = LlavaMistralForCausalLM.from_pretrained(model_path, low_cpu_mem_usage=True, attn_implementation=attn_implementation, **kwargs)
156
+ elif (
157
+ "wizardlm-2" in model_name.lower()
158
+ and "vicuna" in model_name.lower()
159
+ or "llama" in model_name.lower()
160
+ or "yi" in model_name.lower()
161
+ or "nous-hermes" in model_name.lower()
162
+ or "llava-v1.6-34b" in model_name.lower()
163
+ or "llava-v1.5" in model_name.lower()
164
+ ):
165
+ from llava.model.language_model.llava_llama import LlavaConfig
166
+
167
+ tokenizer = AutoTokenizer.from_pretrained(model_path, use_fast=False)
168
+ if customized_config is None:
169
+ llava_cfg = LlavaConfig.from_pretrained(model_path)
170
+ if "v1.5" in model_name.lower():
171
+ llava_cfg.delay_load = True # a workaround for correctly loading v1.5 models
172
+ else:
173
+ llava_cfg = customized_config
174
+
175
+ model = LlavaLlamaForCausalLM.from_pretrained(model_path, low_cpu_mem_usage=True, attn_implementation=attn_implementation, config=llava_cfg, **kwargs)
176
+ elif "qwen" in model_name.lower():
177
+ tokenizer = AutoTokenizer.from_pretrained(model_path, use_fast=False)
178
+ model = LlavaQwenForCausalLM.from_pretrained(model_path, low_cpu_mem_usage=True, attn_implementation=attn_implementation, **kwargs)
179
+ elif "gemma" in model_name.lower():
180
+ tokenizer = AutoTokenizer.from_pretrained(model_path, use_fast=False)
181
+ cfg_pretrained = AutoConfig.from_pretrained(model_path)
182
+ model = LlavaGemmaForCausalLM.from_pretrained(model_path, low_cpu_mem_usage=True, config=cfg_pretrained, attn_implementation=attn_implementation, **kwargs)
183
+ else:
184
+ rank0_print("\n\n\nWarning: no matching LLaVA architecture found; falling back to llava_llama. If this is not intended, specify the architecture in model_name.\n\n\n")
185
+ try:
186
+ from llava.model.language_model.llava_llama import LlavaConfig
187
+
188
+ tokenizer = AutoTokenizer.from_pretrained(model_path, use_fast=False)
189
+ if customized_config is None:
190
+ llava_cfg = LlavaConfig.from_pretrained(model_path)
191
+ if "v1.5" in model_path.lower():
192
+ llava_cfg.delay_load = True # a workaround for correctly loading v1.5 models
193
+ else:
194
+ llava_cfg = customized_config
195
+ model = LlavaLlamaForCausalLM.from_pretrained(model_path, low_cpu_mem_usage=True, attn_implementation=attn_implementation, config=llava_cfg, **kwargs)
196
+ except Exception:
197
+ raise ValueError(f"Model {model_name} not supported")
198
+
199
+ else:
200
+ # Load language model
201
+ if model_base is not None:
202
+ # PEFT model
203
+ from peft import PeftModel
204
+
205
+ tokenizer = AutoTokenizer.from_pretrained(model_base, use_fast=False)
206
+ model = AutoModelForCausalLM.from_pretrained(model_base, torch_dtype=torch.float16, low_cpu_mem_usage=True, device_map="auto")
207
+ print(f"Loading LoRA weights from {model_path}")
208
+ model = PeftModel.from_pretrained(model, model_path)
209
+ print(f"Merging weights")
210
+ model = model.merge_and_unload()
211
+ print("Convert to FP16...")
212
+ model.to(torch.float16)
213
+ else:
214
+ use_fast = False
215
+ if "mpt" in model_name.lower().replace("prompt", ""):
216
+ tokenizer = AutoTokenizer.from_pretrained(model_path, use_fast=True)
217
+ model = AutoModelForCausalLM.from_pretrained(model_path, low_cpu_mem_usage=True, trust_remote_code=True, **kwargs)
218
+ else:
219
+ tokenizer = AutoTokenizer.from_pretrained(model_path, use_fast=False)
220
+ model = AutoModelForCausalLM.from_pretrained(model_path, low_cpu_mem_usage=True, **kwargs)
221
+
222
+ rank0_print(f"Model Class: {model.__class__.__name__}")
223
+ image_processor = None
224
+
225
+ if "llava" in model_name.lower():
226
+ mm_use_im_start_end = getattr(model.config, "mm_use_im_start_end", False)
227
+ mm_use_im_patch_token = getattr(model.config, "mm_use_im_patch_token", True)
228
+ if mm_use_im_patch_token:
229
+ tokenizer.add_tokens([DEFAULT_IMAGE_PATCH_TOKEN], special_tokens=True)
230
+ if mm_use_im_start_end:
231
+ tokenizer.add_tokens([DEFAULT_IM_START_TOKEN, DEFAULT_IM_END_TOKEN], special_tokens=True)
232
+ model.resize_token_embeddings(len(tokenizer))
233
+
234
+ vision_tower = model.get_vision_tower()
235
+ if not vision_tower.is_loaded:
236
+ vision_tower.load_model(device_map=device_map)
237
+ if device_map != "auto":
238
+ vision_tower.to(device="cuda", dtype=torch.float16)
239
+ image_processor = vision_tower.image_processor
240
+
241
+ if hasattr(model.config, "max_sequence_length"):
242
+ context_len = model.config.max_sequence_length
243
+ elif hasattr(model.config, "max_position_embeddings"):
244
+ context_len = model.config.max_position_embeddings
245
+ elif hasattr(model.config, "tokenizer_model_max_length"):
246
+ context_len = model.config.tokenizer_model_max_length
247
+ else:
248
+ context_len = 2048
249
+
250
+ return tokenizer, model, image_processor, context_len
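Example (not part of the diff): putting it together, a typical call looks like the sketch below; the checkpoint name is illustrative, and any path whose name contains "llava" follows the multimodal branch above:

from llava.mm_utils import get_model_name_from_path
from llava.model.builder import load_pretrained_model

model_path = "liuhaotian/llava-v1.5-7b"  # illustrative checkpoint
model_name = get_model_name_from_path(model_path)  # -> "llava-v1.5-7b"

tokenizer, model, image_processor, context_len = load_pretrained_model(
    model_path=model_path,
    model_base=None,
    model_name=model_name,
    attn_implementation="sdpa",  # the default "flash_attention_2" needs flash-attn installed
)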
llava/model/consolidate.py ADDED
@@ -0,0 +1,30 @@
1
+ """
2
+ Usage:
3
+ python3 -m llava.model.consolidate --src ~/model_weights/llava-7b --dst ~/model_weights/llava-7b_consolidate
4
+ """
5
+
6
+ import argparse
7
+
8
+ import torch
9
+ from transformers import AutoTokenizer, AutoModelForCausalLM
10
+ from llava.model import *
11
+ from llava.model.utils import auto_upgrade
12
+
13
+
14
+ def consolidate_ckpt(src_path, dst_path):
15
+ print("Loading model")
16
+ auto_upgrade(src_path)
17
+ src_model = AutoModelForCausalLM.from_pretrained(src_path, torch_dtype=torch.float16, low_cpu_mem_usage=True)
18
+ src_tokenizer = AutoTokenizer.from_pretrained(src_path, use_fast=False)
19
+ src_model.save_pretrained(dst_path)
20
+ src_tokenizer.save_pretrained(dst_path)
21
+
22
+
23
+ if __name__ == "__main__":
24
+ parser = argparse.ArgumentParser()
25
+ parser.add_argument("--src", type=str, required=True)
26
+ parser.add_argument("--dst", type=str, required=True)
27
+
28
+ args = parser.parse_args()
29
+
30
+ consolidate_ckpt(args.src, args.dst)
llava/model/language_model/llava_gemma.py ADDED
@@ -0,0 +1,122 @@
1
+ # Copyright 2024 Duc Q. Nguyen, Haotian Liu and Bo Li
2
+ #
3
+ # Licensed under the Apache License, Version 2.0 (the "License");
4
+ # you may not use this file except in compliance with the License.
5
+ # You may obtain a copy of the License at
6
+ #
7
+ # http://www.apache.org/licenses/LICENSE-2.0
8
+ #
9
+ # Unless required by applicable law or agreed to in writing, software
10
+ # distributed under the License is distributed on an "AS IS" BASIS,
11
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12
+ # See the License for the specific language governing permissions and
13
+ # limitations under the License.
14
+
15
+
16
+ from typing import List, Optional, Tuple, Union
17
+
18
+ import torch
19
+ import torch.nn as nn
20
+ from torch.nn import CrossEntropyLoss
21
+
22
+ from transformers import AutoConfig, AutoModelForCausalLM, GemmaConfig, GemmaModel, GemmaForCausalLM
23
+
24
+ from transformers.modeling_outputs import CausalLMOutputWithPast
25
+ from transformers.generation.utils import GenerateOutput
26
+
27
+ from ..llava_arch import LlavaMetaModel, LlavaMetaForCausalLM
28
+
29
+
30
+ class LlavaGemmaConfig(GemmaConfig):
31
+ model_type = "llava_gemma"
32
+
33
+
34
+ class LlavaGemmaModel(LlavaMetaModel, GemmaModel):
35
+ config_class = LlavaGemmaConfig
36
+
37
+ def __init__(self, config: GemmaConfig):
38
+ super(LlavaGemmaModel, self).__init__(config)
39
+
40
+
41
+ class LlavaGemmaForCausalLM(GemmaForCausalLM, LlavaMetaForCausalLM):
42
+ config_class = LlavaGemmaConfig
43
+
44
+ def __init__(self, config):
45
+ super(GemmaForCausalLM, self).__init__(config)
46
+ self.model = LlavaGemmaModel(config)
47
+
48
+ self.lm_head = nn.Linear(config.hidden_size, config.vocab_size, bias=False)
49
+
50
+ # Initialize weights and apply final processing
51
+ self.post_init()
52
+
53
+ def get_model(self):
54
+ return self.model
55
+
56
+ def forward(
57
+ self,
58
+ input_ids: torch.LongTensor = None,
59
+ attention_mask: Optional[torch.Tensor] = None,
60
+ position_ids: Optional[torch.LongTensor] = None,
61
+ past_key_values: Optional[List[torch.FloatTensor]] = None,
62
+ inputs_embeds: Optional[torch.FloatTensor] = None,
63
+ labels: Optional[torch.LongTensor] = None,
64
+ use_cache: Optional[bool] = None,
65
+ output_attentions: Optional[bool] = None,
66
+ output_hidden_states: Optional[bool] = None,
67
+ images: Optional[torch.FloatTensor] = None,
68
+ image_sizes: Optional[List[List[int]]] = None,
69
+ return_dict: Optional[bool] = None,
70
+ cache_position: Optional[torch.LongTensor] = None,
71
+ ) -> Union[Tuple, CausalLMOutputWithPast]:
72
+
73
+ if inputs_embeds is None:
74
+ (input_ids, position_ids, attention_mask, past_key_values, inputs_embeds, labels) = self.prepare_inputs_labels_for_multimodal(input_ids, position_ids, attention_mask, past_key_values, labels, images, image_sizes)
75
+
76
+ return super().forward(
77
+ input_ids=input_ids,
78
+ attention_mask=attention_mask,
79
+ position_ids=position_ids,
80
+ past_key_values=past_key_values,
81
+ inputs_embeds=inputs_embeds,
82
+ labels=labels,
83
+ use_cache=use_cache,
84
+ output_attentions=output_attentions,
85
+ output_hidden_states=output_hidden_states,
86
+ return_dict=return_dict,
87
+ cache_position=cache_position,
88
+ )
89
+
90
+ @torch.no_grad()
91
+ def generate(
92
+ self,
93
+ inputs: Optional[torch.Tensor] = None,
94
+ images: Optional[torch.Tensor] = None,
95
+ image_sizes: Optional[torch.Tensor] = None,
96
+ **kwargs,
97
+ ) -> Union[GenerateOutput, torch.LongTensor]:
98
+ position_ids = kwargs.pop("position_ids", None)
99
+ attention_mask = kwargs.pop("attention_mask", None)
100
+ if "inputs_embeds" in kwargs:
101
+ raise NotImplementedError("`inputs_embeds` is not supported")
102
+
103
+ if images is not None:
104
+ (inputs, position_ids, attention_mask, _, inputs_embeds, _) = self.prepare_inputs_labels_for_multimodal(inputs, position_ids, attention_mask, None, None, images, image_sizes=image_sizes)
105
+ else:
106
+ inputs_embeds = self.get_model().embed_tokens(inputs)
107
+
108
+ return super().generate(position_ids=position_ids, attention_mask=attention_mask, inputs_embeds=inputs_embeds, **kwargs)
109
+
110
+ def prepare_inputs_for_generation(self, input_ids, past_key_values=None, inputs_embeds=None, **kwargs):
111
+ images = kwargs.pop("images", None)
112
+ image_sizes = kwargs.pop("image_sizes", None)
113
+ inputs = super().prepare_inputs_for_generation(input_ids, past_key_values=past_key_values, inputs_embeds=inputs_embeds, **kwargs)
114
+ if images is not None:
115
+ inputs["images"] = images
116
+ if image_sizes is not None:
117
+ inputs["image_sizes"] = image_sizes
118
+ return inputs
119
+
120
+
121
+ AutoConfig.register("llava_gemma", LlavaGemmaConfig)
122
+ AutoModelForCausalLM.register(LlavaGemmaConfig, LlavaGemmaForCausalLM)
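The two register calls make the Auto classes aware of the new model_type, so a saved llava_gemma checkpoint resolves through the standard Auto API once this module has been imported. A minimal sketch (not part of the diff; the local path is a placeholder):

from transformers import AutoConfig, AutoModelForCausalLM
import llava.model.language_model.llava_gemma  # noqa: F401  (runs the register calls above)

config = AutoConfig.from_pretrained("/path/to/llava-gemma-checkpoint")  # placeholder path
assert config.model_type == "llava_gemma"
model = AutoModelForCausalLM.from_config(config)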
llava/model/language_model/llava_llama.py ADDED
@@ -0,0 +1,131 @@
1
+ # Copyright 2023 Haotian Liu
2
+ #
3
+ # Licensed under the Apache License, Version 2.0 (the "License");
4
+ # you may not use this file except in compliance with the License.
5
+ # You may obtain a copy of the License at
6
+ #
7
+ # http://www.apache.org/licenses/LICENSE-2.0
8
+ #
9
+ # Unless required by applicable law or agreed to in writing, software
10
+ # distributed under the License is distributed on an "AS IS" BASIS,
11
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12
+ # See the License for the specific language governing permissions and
13
+ # limitations under the License.
14
+
15
+
16
+ from typing import List, Optional, Tuple, Union
17
+
18
+ import torch
19
+ import torch.nn as nn
20
+
21
+ from transformers import AutoConfig, AutoModelForCausalLM, LlamaConfig
22
+
23
+ # , LlamaModel, LlamaForCausalLM, GenerationConfig
24
+ # from .modeling_llama import LlamaModel, LlamaForCausalLM
25
+ from transformers import LlamaModel, LlamaForCausalLM
26
+ from transformers.modeling_outputs import CausalLMOutputWithPast
27
+ from transformers.generation.utils import GenerateOutput
28
+
29
+ from llava.model.llava_arch import LlavaMetaModel, LlavaMetaForCausalLM
30
+
31
+
32
+ class LlavaConfig(LlamaConfig):
33
+ model_type = "llava_llama"
34
+ temperature: float = 0.0 # reset to 0.0, previously 0.9 for Vicuna
35
+ max_new_tokens: int = 1024
36
+ do_sample: bool = False
37
+ top_p: Optional[float] = None
38
+ rope_scaling: Optional[dict] = {}
39
+
40
+
41
+ class LlavaLlamaModel(LlavaMetaModel, LlamaModel):
42
+ config_class = LlavaConfig
43
+
44
+ def __init__(self, config: LlamaConfig):
45
+ super(LlavaLlamaModel, self).__init__(config)
46
+
47
+
48
+ class LlavaLlamaForCausalLM(LlamaForCausalLM, LlavaMetaForCausalLM):
49
+ config_class = LlavaConfig
50
+
51
+ def __init__(self, config):
52
+ LlamaForCausalLM.__init__(self, config)
53
+
54
+ # configure default generation settings
55
+ config.model_type = "llava_llama"
56
+ config.rope_scaling = None
57
+
58
+ self.model = LlavaLlamaModel(config)
59
+ self.lm_head = nn.Linear(config.hidden_size, config.vocab_size, bias=False)
60
+ # Initialize weights and apply final processing
61
+ self.post_init()
62
+
63
+ def get_model(self):
64
+ return self.model
65
+
66
+ def forward(
67
+ self,
68
+ input_ids: torch.LongTensor = None,
69
+ attention_mask: Optional[torch.Tensor] = None,
70
+ position_ids: Optional[torch.LongTensor] = None,
71
+ past_key_values: Optional[List[torch.FloatTensor]] = None,
72
+ inputs_embeds: Optional[torch.FloatTensor] = None,
73
+ labels: Optional[torch.LongTensor] = None,
74
+ use_cache: Optional[bool] = None,
75
+ output_attentions: Optional[bool] = None,
76
+ output_hidden_states: Optional[bool] = None,
77
+ images: Optional[torch.FloatTensor] = None,
78
+ image_sizes: Optional[List[List[int]]] = None,
79
+ return_dict: Optional[bool] = None,
80
+ cache_position=None,
81
+ ) -> Union[Tuple, CausalLMOutputWithPast]:
82
+
83
+ if inputs_embeds is None:
84
+ (input_ids, position_ids, attention_mask, past_key_values, inputs_embeds, labels) = self.prepare_inputs_labels_for_multimodal(input_ids, position_ids, attention_mask, past_key_values, labels, images, image_sizes)
85
+
86
+ return super().forward(
87
+ input_ids=input_ids,
88
+ attention_mask=attention_mask,
89
+ position_ids=position_ids,
90
+ past_key_values=past_key_values,
91
+ inputs_embeds=inputs_embeds,
92
+ labels=labels,
93
+ use_cache=use_cache,
94
+ output_attentions=output_attentions,
95
+ output_hidden_states=output_hidden_states,
96
+ return_dict=return_dict,
97
+ )
98
+
99
+ @torch.no_grad()
100
+ def generate(
101
+ self,
102
+ inputs: Optional[torch.Tensor] = None,
103
+ images: Optional[torch.Tensor] = None,
104
+ image_sizes: Optional[torch.Tensor] = None,
105
+ **kwargs,
106
+ ) -> Union[GenerateOutput, torch.LongTensor]:
107
+ position_ids = kwargs.pop("position_ids", None)
108
+ attention_mask = kwargs.pop("attention_mask", None)
109
+ if "inputs_embeds" in kwargs:
110
+ raise NotImplementedError("`inputs_embeds` is not supported")
111
+
112
+ if images is not None:
113
+ (inputs, position_ids, attention_mask, _, inputs_embeds, _) = self.prepare_inputs_labels_for_multimodal(inputs, position_ids, attention_mask, None, None, images, image_sizes=image_sizes)
114
+ else:
115
+ inputs_embeds = self.get_model().embed_tokens(inputs)
116
+
117
+ return super().generate(position_ids=position_ids, attention_mask=attention_mask, inputs_embeds=inputs_embeds, **kwargs)
118
+
119
+ def prepare_inputs_for_generation(self, input_ids, past_key_values=None, inputs_embeds=None, **kwargs):
120
+ images = kwargs.pop("images", None)
121
+ image_sizes = kwargs.pop("image_sizes", None)
122
+ inputs = super().prepare_inputs_for_generation(input_ids, past_key_values=past_key_values, inputs_embeds=inputs_embeds, **kwargs)
123
+ if images is not None:
124
+ inputs["images"] = images
125
+ if image_sizes is not None:
126
+ inputs["image_sizes"] = image_sizes
127
+ return inputs
128
+
129
+
130
+ AutoConfig.register("llava_llama", LlavaConfig)
131
+ AutoModelForCausalLM.register(LlavaConfig, LlavaLlamaForCausalLM)
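For a training-style call, pixel values are routed through prepare_inputs_labels_for_multimodal before the base LlamaForCausalLM forward runs. A hedged sketch (not part of the diff; model, input_ids, labels and images are assumed to have been prepared with the mm_utils helpers above):

# input_ids / labels: (1, seq_len) tensors in which IMAGE_TOKEN_INDEX marks the image slot (assumed).
# images: (1, 3, H, W) pixel values from process_images (assumed).
outputs = model(input_ids=input_ids, labels=labels, images=images)
loss = outputs.loss  # a standard CausalLMOutputWithPast from the underlying LlamaForCausalLM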
llava/model/language_model/llava_mistral.py ADDED
@@ -0,0 +1,127 @@
1
+ # Copyright 2023 Haotian Liu
2
+ #
3
+ # Licensed under the Apache License, Version 2.0 (the "License");
4
+ # you may not use this file except in compliance with the License.
5
+ # You may obtain a copy of the License at
6
+ #
7
+ # http://www.apache.org/licenses/LICENSE-2.0
8
+ #
9
+ # Unless required by applicable law or agreed to in writing, software
10
+ # distributed under the License is distributed on an "AS IS" BASIS,
11
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12
+ # See the License for the specific language governing permissions and
13
+ # limitations under the License.
14
+
15
+
16
+ from typing import List, Optional, Tuple, Union
17
+
18
+ import torch
19
+ import torch.nn as nn
20
+ from torch.nn import CrossEntropyLoss
21
+
22
+ from transformers import AutoConfig, AutoModelForCausalLM, MistralConfig, MistralModel, MistralForCausalLM, GenerationConfig
23
+
24
+ from transformers.modeling_outputs import CausalLMOutputWithPast
25
+ from transformers.generation.utils import GenerateOutput
26
+
27
+ from ..llava_arch import LlavaMetaModel, LlavaMetaForCausalLM
28
+
29
+
30
+ class LlavaMistralConfig(MistralConfig):
31
+ model_type = "llava_mistral"
32
+ temperature: float = 0.0 # reset to 0.0, previously 0.9 for Vicuna
33
+ max_new_tokens: int = 1024
34
+ do_sample: bool = False
35
+ top_p: Optional[float] = None
36
+
37
+
38
+ class LlavaMistralModel(LlavaMetaModel, MistralModel):
39
+ config_class = LlavaMistralConfig
40
+
41
+ def __init__(self, config: MistralConfig):
42
+ super(LlavaMistralModel, self).__init__(config)
43
+
44
+
45
+ class LlavaMistralForCausalLM(MistralForCausalLM, LlavaMetaForCausalLM):
46
+ config_class = LlavaMistralConfig
47
+
48
+ def __init__(self, config):
49
+ super(MistralForCausalLM, self).__init__(config)
50
+
51
+ config.model_type = "llava_mistral"
52
+ config.rope_scaling = None
53
+
54
+ self.model = LlavaMistralModel(config)
55
+ self.lm_head = nn.Linear(config.hidden_size, config.vocab_size, bias=False)
56
+ # Initialize weights and apply final processing
57
+ self.post_init()
58
+
59
+ def get_model(self):
60
+ return self.model
61
+
62
+ def forward(
63
+ self,
64
+ input_ids: torch.LongTensor = None,
65
+ attention_mask: Optional[torch.Tensor] = None,
66
+ position_ids: Optional[torch.LongTensor] = None,
67
+ past_key_values: Optional[List[torch.FloatTensor]] = None,
68
+ inputs_embeds: Optional[torch.FloatTensor] = None,
69
+ labels: Optional[torch.LongTensor] = None,
70
+ use_cache: Optional[bool] = None,
71
+ output_attentions: Optional[bool] = None,
72
+ output_hidden_states: Optional[bool] = None,
73
+ images: Optional[torch.FloatTensor] = None,
74
+ image_sizes: Optional[List[List[int]]] = None,
75
+ return_dict: Optional[bool] = None,
76
+ cache_position=None,
77
+ ) -> Union[Tuple, CausalLMOutputWithPast]:
78
+
79
+ if inputs_embeds is None:
80
+ (input_ids, position_ids, attention_mask, past_key_values, inputs_embeds, labels) = self.prepare_inputs_labels_for_multimodal(input_ids, position_ids, attention_mask, past_key_values, labels, images, image_sizes)
81
+
82
+ return super().forward(
83
+ input_ids=input_ids,
84
+ attention_mask=attention_mask,
85
+ position_ids=position_ids,
86
+ past_key_values=past_key_values,
87
+ inputs_embeds=inputs_embeds,
88
+ labels=labels,
89
+ use_cache=use_cache,
90
+ output_attentions=output_attentions,
91
+ output_hidden_states=output_hidden_states,
92
+ return_dict=return_dict,
93
+ )
94
+
95
+ @torch.no_grad()
96
+ def generate(
97
+ self,
98
+ inputs: Optional[torch.Tensor] = None,
99
+ images: Optional[torch.Tensor] = None,
100
+ image_sizes: Optional[torch.Tensor] = None,
101
+ **kwargs,
102
+ ) -> Union[GenerateOutput, torch.LongTensor]:
103
+ position_ids = kwargs.pop("position_ids", None)
104
+ attention_mask = kwargs.pop("attention_mask", None)
105
+ if "inputs_embeds" in kwargs:
106
+ raise NotImplementedError("`inputs_embeds` is not supported")
107
+
108
+ if images is not None:
109
+ (inputs, position_ids, attention_mask, _, inputs_embeds, _) = self.prepare_inputs_labels_for_multimodal(inputs, position_ids, attention_mask, None, None, images, image_sizes=image_sizes)
110
+ else:
111
+ inputs_embeds = self.get_model().embed_tokens(inputs)
112
+
113
+ return super().generate(position_ids=position_ids, attention_mask=attention_mask, inputs_embeds=inputs_embeds, **kwargs)
114
+
115
+ def prepare_inputs_for_generation(self, input_ids, past_key_values=None, inputs_embeds=None, **kwargs):
116
+ images = kwargs.pop("images", None)
117
+ image_sizes = kwargs.pop("image_sizes", None)
118
+ inputs = super().prepare_inputs_for_generation(input_ids, past_key_values=past_key_values, inputs_embeds=inputs_embeds, **kwargs)
119
+ if images is not None:
120
+ inputs["images"] = images
121
+ if image_sizes is not None:
122
+ inputs["image_sizes"] = image_sizes
123
+ return inputs
124
+
125
+
126
+ AutoConfig.register("llava_mistral", LlavaMistralConfig)
127
+ AutoModelForCausalLM.register(LlavaMistralConfig, LlavaMistralForCausalLM)
llava/model/language_model/llava_mixtral.py ADDED
@@ -0,0 +1,122 @@
1
+ # Copyright 2023 Haotian Liu
2
+ #
3
+ # Licensed under the Apache License, Version 2.0 (the "License");
4
+ # you may not use this file except in compliance with the License.
5
+ # You may obtain a copy of the License at
6
+ #
7
+ # http://www.apache.org/licenses/LICENSE-2.0
8
+ #
9
+ # Unless required by applicable law or agreed to in writing, software
10
+ # distributed under the License is distributed on an "AS IS" BASIS,
11
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12
+ # See the License for the specific language governing permissions and
13
+ # limitations under the License.
14
+
15
+
16
+ from typing import List, Optional, Tuple, Union
17
+
18
+ import torch
19
+ import torch.nn as nn
20
+ from torch.nn import CrossEntropyLoss
21
+
22
+ from transformers import AutoConfig, AutoModelForCausalLM, MixtralConfig, MixtralModel, MixtralForCausalLM, GenerationConfig
23
+
24
+ from transformers.modeling_outputs import CausalLMOutputWithPast
25
+ from transformers.generation.utils import GenerateOutput
26
+
27
+ from ..llava_arch import LlavaMetaModel, LlavaMetaForCausalLM
28
+
29
+
30
+ class LlavaMixtralConfig(MixtralConfig):
31
+ model_type = "llava_mixtral"
32
+
33
+
34
+ class LlavaMixtralModel(LlavaMetaModel, MixtralModel):
35
+ config_class = LlavaMixtralConfig
36
+
37
+ def __init__(self, config: MixtralConfig):
38
+ super(LlavaMixtralModel, self).__init__(config)
39
+
40
+
41
+ class LlavaMixtralForCausalLM(MixtralForCausalLM, LlavaMetaForCausalLM):
42
+ config_class = LlavaMixtralConfig
43
+
44
+ def __init__(self, config):
45
+ super(MixtralForCausalLM, self).__init__(config)
46
+
47
+ config.model_type = "llava_mixtral"
48
+ config.rope_scaling = None
49
+ self.model = LlavaMixtralModel(config)
50
+ self.lm_head = nn.Linear(config.hidden_size, config.vocab_size, bias=False)
51
+ # Initialize weights and apply final processing
52
+ self.post_init()
53
+
54
+ def get_model(self):
55
+ return self.model
56
+
57
+ def forward(
58
+ self,
59
+ input_ids: torch.LongTensor = None,
60
+ attention_mask: Optional[torch.Tensor] = None,
61
+ position_ids: Optional[torch.LongTensor] = None,
62
+ past_key_values: Optional[List[torch.FloatTensor]] = None,
63
+ inputs_embeds: Optional[torch.FloatTensor] = None,
64
+ labels: Optional[torch.LongTensor] = None,
65
+ use_cache: Optional[bool] = None,
66
+ output_attentions: Optional[bool] = None,
67
+ output_hidden_states: Optional[bool] = None,
68
+ images: Optional[torch.FloatTensor] = None,
69
+ image_sizes: Optional[List[List[int]]] = None,
70
+ return_dict: Optional[bool] = None,
71
+ cache_position=None,
72
+ ) -> Union[Tuple, CausalLMOutputWithPast]:
73
+
74
+ if inputs_embeds is None:
75
+ (input_ids, position_ids, attention_mask, past_key_values, inputs_embeds, labels) = self.prepare_inputs_labels_for_multimodal(input_ids, position_ids, attention_mask, past_key_values, labels, images, image_sizes)
76
+
77
+ return super().forward(
78
+ input_ids=input_ids,
79
+ attention_mask=attention_mask,
80
+ position_ids=position_ids,
81
+ past_key_values=past_key_values,
82
+ inputs_embeds=inputs_embeds,
83
+ labels=labels,
84
+ use_cache=use_cache,
85
+ output_attentions=output_attentions,
86
+ output_hidden_states=output_hidden_states,
87
+ return_dict=return_dict,
88
+ )
89
+
90
+ @torch.no_grad()
91
+ def generate(
92
+ self,
93
+ inputs: Optional[torch.Tensor] = None,
94
+ images: Optional[torch.Tensor] = None,
95
+ image_sizes: Optional[torch.Tensor] = None,
96
+ **kwargs,
97
+ ) -> Union[GenerateOutput, torch.LongTensor]:
98
+ position_ids = kwargs.pop("position_ids", None)
99
+ attention_mask = kwargs.pop("attention_mask", None)
100
+ if "inputs_embeds" in kwargs:
101
+ raise NotImplementedError("`inputs_embeds` is not supported")
102
+
103
+ if images is not None:
104
+ (inputs, position_ids, attention_mask, _, inputs_embeds, _) = self.prepare_inputs_labels_for_multimodal(inputs, position_ids, attention_mask, None, None, images, image_sizes=image_sizes)
105
+ else:
106
+ inputs_embeds = self.get_model().embed_tokens(inputs)
107
+
108
+ return super().generate(position_ids=position_ids, attention_mask=attention_mask, inputs_embeds=inputs_embeds, **kwargs)
109
+
110
+ def prepare_inputs_for_generation(self, input_ids, past_key_values=None, inputs_embeds=None, **kwargs):
111
+ images = kwargs.pop("images", None)
112
+ image_sizes = kwargs.pop("image_sizes", None)
113
+ inputs = super().prepare_inputs_for_generation(input_ids, past_key_values=past_key_values, inputs_embeds=inputs_embeds, **kwargs)
114
+ if images is not None:
115
+ inputs["images"] = images
116
+ if image_sizes is not None:
117
+ inputs["image_sizes"] = image_sizes
118
+ return inputs
119
+
120
+
121
+ AutoConfig.register("llava_mixtral", LlavaMixtralConfig)
122
+ AutoModelForCausalLM.register(LlavaMixtralConfig, LlavaMixtralForCausalLM)
llava/model/language_model/llava_mpt.py ADDED
@@ -0,0 +1,105 @@
1
+ # Copyright 2023 Haotian Liu
2
+ #
3
+ # Licensed under the Apache License, Version 2.0 (the "License");
4
+ # you may not use this file except in compliance with the License.
5
+ # You may obtain a copy of the License at
6
+ #
7
+ # http://www.apache.org/licenses/LICENSE-2.0
8
+ #
9
+ # Unless required by applicable law or agreed to in writing, software
10
+ # distributed under the License is distributed on an "AS IS" BASIS,
11
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12
+ # See the License for the specific language governing permissions and
13
+ # limitations under the License.
14
+
15
+
16
+ from typing import Optional, Tuple
17
+
18
+ import torch
19
+
20
+ from transformers import AutoConfig, AutoModelForCausalLM, MptConfig, MptForCausalLM, MptModel, GenerationConfig
21
+ from llava.model.llava_arch import LlavaMetaModel, LlavaMetaForCausalLM
22
+
23
+
24
+ class LlavaMptConfig(MptConfig):
25
+ model_type = "llava_mpt"
26
+
27
+
28
+ class LlavaMptModel(LlavaMetaModel, MptModel):
29
+ config_class = LlavaMptConfig
30
+
31
+ def __init__(self, config: MptConfig):
32
+ config.hidden_size = config.d_model
33
+ super(LlavaMptModel, self).__init__(config)
34
+
35
+ def embed_tokens(self, x):
36
+ return self.wte(x)
37
+
38
+
39
+ class LlavaMptForCausalLM(MptForCausalLM, LlavaMetaForCausalLM):
40
+ config_class = LlavaMptConfig
41
+ supports_gradient_checkpointing = True
42
+
43
+ def __init__(self, config):
44
+ super(MptForCausalLM, self).__init__(config)
45
+
46
+ config.model_type = "llava_mpt"
47
+ config.rope_scaling = None
48
+ self.generation_config = GenerationConfig(
49
+ temperature=0.0,
50
+ max_new_tokens=1024,
51
+ do_sample=False,
52
+ top_p=None,
53
+ )
54
+
55
+ self.transformer = LlavaMptModel(config)
56
+ self.lm_head = torch.nn.Linear(config.hidden_size, config.vocab_size, bias=False)
57
+
58
+ # Initialize weights and apply final processing
59
+ self.post_init()
60
+
61
+ def get_model(self):
62
+ return self.transformer
63
+
64
+ def _set_gradient_checkpointing(self, module, value=False):
65
+ if isinstance(module, LlavaMptModel):
66
+ module.gradient_checkpointing = value
67
+
68
+ def forward(
69
+ self,
70
+ input_ids: Optional[torch.LongTensor] = None,
71
+ past_key_values: Optional[Tuple[Tuple[torch.Tensor, torch.Tensor], ...]] = None,
72
+ attention_mask: Optional[torch.Tensor] = None,
73
+ inputs_embeds: Optional[torch.Tensor] = None,
74
+ labels: Optional[torch.Tensor] = None,
75
+ use_cache: Optional[bool] = None,
76
+ output_attentions: Optional[bool] = None,
77
+ output_hidden_states: Optional[bool] = None,
78
+ return_dict: Optional[bool] = None,
79
+ cache_position=None,
80
+ images=None,
81
+ ):
82
+
83
+ input_ids, attention_mask, past_key_values, inputs_embeds, labels = self.prepare_inputs_labels_for_multimodal(input_ids, attention_mask, past_key_values, labels, images)
84
+
85
+ return super().forward(
86
+ input_ids,
87
+ past_key_values=past_key_values,
88
+ attention_mask=attention_mask,
89
+ inputs_embeds=inputs_embeds,
90
+ labels=labels,
91
+ use_cache=use_cache,
92
+ output_attentions=output_attentions,
93
+ output_hidden_states=output_hidden_states,
94
+ return_dict=return_dict,
95
+ )
96
+
97
+ def prepare_inputs_for_generation(self, input_ids, past_key_values=None, inputs_embeds=None, **kwargs):
98
+ images = kwargs.pop("images", None)
99
+ _inputs = super().prepare_inputs_for_generation(input_ids, past_key_values=past_key_values, inputs_embeds=inputs_embeds, **kwargs)
100
+ _inputs["images"] = images
101
+ return _inputs
102
+
103
+
104
+ AutoConfig.register("llava_mpt", LlavaMptConfig)
105
+ AutoModelForCausalLM.register(LlavaMptConfig, LlavaMptForCausalLM)
llava/model/language_model/llava_qwen.py ADDED
@@ -0,0 +1,128 @@
1
+ # Copyright 2024 Hao Zhang
2
+ #
3
+ # Licensed under the Apache License, Version 2.0 (the "License");
4
+ # you may not use this file except in compliance with the License.
5
+ # You may obtain a copy of the License at
6
+ #
7
+ # http://www.apache.org/licenses/LICENSE-2.0
8
+ #
9
+ # Unless required by applicable law or agreed to in writing, software
10
+ # distributed under the License is distributed on an "AS IS" BASIS,
11
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12
+ # See the License for the specific language governing permissions and
13
+ # limitations under the License.
14
+
15
+
16
+ from typing import List, Optional, Tuple, Union, Dict
17
+ import torch
18
+ import torch.nn as nn
19
+ from torch.nn import CrossEntropyLoss
20
+
21
+ import transformers
22
+ from transformers import AutoConfig, AutoModelForCausalLM, LlamaConfig, LlamaModel, LlamaForCausalLM
23
+
24
+ from transformers.modeling_outputs import CausalLMOutputWithPast
25
+ from transformers.generation.utils import GenerateOutput
26
+
27
+ # from ...constants import IGNORE_INDEX, IMAGE_TOKEN_INDEX, DEFAULT_IMAGE_TOKEN, DEFAULT_IM_START_TOKEN, DEFAULT_IM_END_TOKEN
28
+ from llava.model.llava_arch import LlavaMetaModel, LlavaMetaForCausalLM
29
+ from transformers import Qwen2Config, Qwen2Model, Qwen2ForCausalLM
30
+
31
+ # from .qwen.modeling_qwen import QWenLMHeadModel, QWenModel
32
+ # from .qwen.configuration_qwen import QWenConfig
33
+
34
+
35
+ class LlavaQwenConfig(Qwen2Config):
36
+ model_type = "llava_qwen"
37
+
38
+
39
+ class LlavaQwenModel(LlavaMetaModel, Qwen2Model):
40
+ config_class = LlavaQwenConfig
41
+
42
+ def __init__(self, config: Qwen2Config):
43
+ super(LlavaQwenModel, self).__init__(config)
44
+
45
+
46
+ class LlavaQwenForCausalLM(Qwen2ForCausalLM, LlavaMetaForCausalLM):
47
+ config_class = LlavaQwenConfig
48
+
49
+ def __init__(self, config):
50
+ # super(Qwen2ForCausalLM, self).__init__(config)
51
+ Qwen2ForCausalLM.__init__(self, config)
52
+ config.model_type = "llava_qwen"
53
+ config.rope_scaling = None
54
+
55
+ self.model = LlavaQwenModel(config)
56
+ self.lm_head = nn.Linear(config.hidden_size, config.vocab_size, bias=False)
57
+ # Initialize weights and apply final processing
58
+ self.post_init()
59
+
60
+ def get_model(self):
61
+ return self.model
62
+
63
+ def forward(
64
+ self,
65
+ input_ids: torch.LongTensor = None,
66
+ attention_mask: Optional[torch.Tensor] = None,
67
+ position_ids: Optional[torch.LongTensor] = None,
68
+ past_key_values: Optional[List[torch.FloatTensor]] = None,
69
+ inputs_embeds: Optional[torch.FloatTensor] = None,
70
+ labels: Optional[torch.LongTensor] = None,
71
+ use_cache: Optional[bool] = None,
72
+ output_attentions: Optional[bool] = None,
73
+ output_hidden_states: Optional[bool] = None,
74
+ images: Optional[torch.FloatTensor] = None,
75
+ image_sizes: Optional[List[List[int]]] = None,
76
+ return_dict: Optional[bool] = None,
77
+ cache_position=None,
78
+ ) -> Union[Tuple, CausalLMOutputWithPast]:
79
+
80
+ if inputs_embeds is None:
81
+ (input_ids, position_ids, attention_mask, past_key_values, inputs_embeds, labels) = self.prepare_inputs_labels_for_multimodal(input_ids, position_ids, attention_mask, past_key_values, labels, images, image_sizes)
82
+
83
+ return super().forward(
84
+ input_ids=input_ids,
85
+ attention_mask=attention_mask,
86
+ position_ids=position_ids,
87
+ past_key_values=past_key_values,
88
+ inputs_embeds=inputs_embeds,
89
+ labels=labels,
90
+ use_cache=use_cache,
91
+ output_attentions=output_attentions,
92
+ output_hidden_states=output_hidden_states,
93
+ return_dict=return_dict,
94
+ )
95
+
96
+ @torch.no_grad()
97
+ def generate(
98
+ self,
99
+ inputs: Optional[torch.Tensor] = None,
100
+ images: Optional[torch.Tensor] = None,
101
+ image_sizes: Optional[torch.Tensor] = None,
102
+ **kwargs,
103
+ ) -> Union[GenerateOutput, torch.LongTensor]:
104
+ position_ids = kwargs.pop("position_ids", None)
105
+ attention_mask = kwargs.pop("attention_mask", None)
106
+ if "inputs_embeds" in kwargs:
107
+ raise NotImplementedError("`inputs_embeds` is not supported")
108
+
109
+ if images is not None:
110
+ (inputs, position_ids, attention_mask, _, inputs_embeds, _) = self.prepare_inputs_labels_for_multimodal(inputs, position_ids, attention_mask, None, None, images, image_sizes=image_sizes)
111
+ else:
112
+ inputs_embeds = self.get_model().embed_tokens(inputs)
113
+
114
+ return super().generate(position_ids=position_ids, attention_mask=attention_mask, inputs_embeds=inputs_embeds, **kwargs)
115
+
116
+ def prepare_inputs_for_generation(self, input_ids, past_key_values=None, inputs_embeds=None, **kwargs):
117
+ images = kwargs.pop("images", None)
118
+ image_sizes = kwargs.pop("image_sizes", None)
119
+ inputs = super().prepare_inputs_for_generation(input_ids, past_key_values=past_key_values, inputs_embeds=inputs_embeds, **kwargs)
120
+ if images is not None:
121
+ inputs["images"] = images
122
+ if image_sizes is not None:
123
+ inputs["image_sizes"] = image_sizes
124
+ return inputs
125
+
126
+
127
+ AutoConfig.register("llava_qwen", LlavaQwenConfig)
128
+ AutoModelForCausalLM.register(LlavaQwenConfig, LlavaQwenForCausalLM)
llava/model/language_model/llava_qwen_moe.py ADDED
@@ -0,0 +1,128 @@
+# Copyright 2024 Hao Zhang
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#     http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+
+from typing import List, Optional, Tuple, Union, Dict
+import torch
+import torch.nn as nn
+from torch.nn import CrossEntropyLoss
+
+import transformers
+from transformers import AutoConfig, AutoModelForCausalLM
+
+from transformers.modeling_outputs import CausalLMOutputWithPast
+from transformers.generation.utils import GenerateOutput
+
+# from ...constants import IGNORE_INDEX, IMAGE_TOKEN_INDEX, DEFAULT_IMAGE_TOKEN, DEFAULT_IM_START_TOKEN, DEFAULT_IM_END_TOKEN
+from llava.model.llava_arch import LlavaMetaModel, LlavaMetaForCausalLM
+from transformers import Qwen2MoeConfig, Qwen2MoeModel, Qwen2MoeForCausalLM
+
+# from .qwen.modeling_qwen import QWenLMHeadModel, QWenModel
+# from .qwen.configuration_qwen import QWenConfig
+
+
+class LlavaQwenMoeConfig(Qwen2MoeConfig):
+    model_type = "llava_qwen_moe"
+
+
+class LlavaQwenMoeModel(LlavaMetaModel, Qwen2MoeModel):
+    config_class = LlavaQwenMoeConfig
+
+    def __init__(self, config: Qwen2MoeConfig):
+        super(LlavaQwenMoeModel, self).__init__(config)
+
+
+class LlavaQwenMoeForCausalLM(Qwen2MoeForCausalLM, LlavaMetaForCausalLM):
+    config_class = LlavaQwenMoeConfig
+
+    def __init__(self, config):
+        # super(Qwen2MoeForCausalLM, self).__init__(config)
+        Qwen2MoeForCausalLM.__init__(self, config)
+        config.model_type = "llava_qwen_moe"
+        config.rope_scaling = None
+
+        self.model = LlavaQwenMoeModel(config)
+        self.lm_head = nn.Linear(config.hidden_size, config.vocab_size, bias=False)
+        # Initialize weights and apply final processing
+        self.post_init()
+
+    def get_model(self):
+        return self.model
+
+    def forward(
+        self,
+        input_ids: torch.LongTensor = None,
+        attention_mask: Optional[torch.Tensor] = None,
+        position_ids: Optional[torch.LongTensor] = None,
+        past_key_values: Optional[List[torch.FloatTensor]] = None,
+        inputs_embeds: Optional[torch.FloatTensor] = None,
+        labels: Optional[torch.LongTensor] = None,
+        use_cache: Optional[bool] = None,
+        output_attentions: Optional[bool] = None,
+        output_hidden_states: Optional[bool] = None,
+        images: Optional[torch.FloatTensor] = None,
+        image_sizes: Optional[List[List[int]]] = None,
+        return_dict: Optional[bool] = None,
+        cache_position=None,
+    ) -> Union[Tuple, CausalLMOutputWithPast]:
+
+        if inputs_embeds is None:
+            (input_ids, position_ids, attention_mask, past_key_values, inputs_embeds, labels) = self.prepare_inputs_labels_for_multimodal(input_ids, position_ids, attention_mask, past_key_values, labels, images, image_sizes)
+
+        return super().forward(
+            input_ids=input_ids,
+            attention_mask=attention_mask,
+            position_ids=position_ids,
+            past_key_values=past_key_values,
+            inputs_embeds=inputs_embeds,
+            labels=labels,
+            use_cache=use_cache,
+            output_attentions=output_attentions,
+            output_hidden_states=output_hidden_states,
+            return_dict=return_dict,
+        )
+
+    @torch.no_grad()
+    def generate(
+        self,
+        inputs: Optional[torch.Tensor] = None,
+        images: Optional[torch.Tensor] = None,
+        image_sizes: Optional[torch.Tensor] = None,
+        **kwargs,
+    ) -> Union[GenerateOutput, torch.LongTensor]:
+        position_ids = kwargs.pop("position_ids", None)
+        attention_mask = kwargs.pop("attention_mask", None)
+        if "inputs_embeds" in kwargs:
+            raise NotImplementedError("`inputs_embeds` is not supported")
+
+        if images is not None:
+            (inputs, position_ids, attention_mask, _, inputs_embeds, _) = self.prepare_inputs_labels_for_multimodal(inputs, position_ids, attention_mask, None, None, images, image_sizes=image_sizes)
+        else:
+            inputs_embeds = self.get_model().embed_tokens(inputs)
+
+        return super().generate(position_ids=position_ids, attention_mask=attention_mask, inputs_embeds=inputs_embeds, **kwargs)
+
+    def prepare_inputs_for_generation(self, input_ids, past_key_values=None, inputs_embeds=None, **kwargs):
+        images = kwargs.pop("images", None)
+        image_sizes = kwargs.pop("image_sizes", None)
+        inputs = super().prepare_inputs_for_generation(input_ids, past_key_values=past_key_values, inputs_embeds=inputs_embeds, **kwargs)
+        if images is not None:
+            inputs["images"] = images
+        if image_sizes is not None:
+            inputs["image_sizes"] = image_sizes
+        return inputs
+
+
+AutoConfig.register("llava_qwen_moe", LlavaQwenMoeConfig)
+AutoModelForCausalLM.register(LlavaQwenMoeConfig, LlavaQwenMoeForCausalLM)
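Illustrative sketch (not part of the committed files): the two `register()` calls above make the custom classes resolvable through the `transformers` Auto factories once this module has been imported; the checkpoint path is a placeholder.

from transformers import AutoConfig, AutoModelForCausalLM
import llava.model.language_model.llava_qwen_moe  # noqa: F401  (importing runs the registrations above)

cfg = AutoConfig.from_pretrained("path/to/llava-qwen-moe-checkpoint")
print(type(cfg).__name__)    # LlavaQwenMoeConfig
model = AutoModelForCausalLM.from_pretrained("path/to/llava-qwen-moe-checkpoint", torch_dtype="auto")
print(type(model).__name__)  # LlavaQwenMoeForCausalLM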
llava/model/llava_arch.py ADDED
@@ -0,0 +1,390 @@
+# Copyright 2023 Haotian Liu
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#     http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+
+from abc import ABC, abstractmethod
+
+import torch
+import torch.nn as nn
+
+from .multimodal_encoder.builder import build_vision_tower
+from .multimodal_resampler.builder import build_vision_resampler
+from .multimodal_projector.builder import build_vision_projector
+
+from llava.constants import IGNORE_INDEX, IMAGE_TOKEN_INDEX, DEFAULT_IMAGE_PATCH_TOKEN, DEFAULT_IM_START_TOKEN, DEFAULT_IM_END_TOKEN
+
+from llava.mm_utils import get_anyres_image_grid_shape
+from llava.utils import rank0_print
+
+
+class LlavaMetaModel:
+
+    def __init__(self, config):
+        super(LlavaMetaModel, self).__init__(config)
+
+        if hasattr(config, "mm_vision_tower"):
+            delay_load = getattr(config, "delay_load", False)
+            self.vision_tower = build_vision_tower(config, delay_load=delay_load)
+            self.vision_resampler = build_vision_resampler(config, vision_tower=self.vision_tower)
+            self.mm_projector = build_vision_projector(config, vision_cfg=self.vision_tower.config)
+
+            if "unpad" in getattr(config, "mm_patch_merge_type", ""):
+                self.image_newline = nn.Parameter(torch.empty(config.hidden_size, dtype=self.dtype))
+
+    def get_vision_tower(self):
+        vision_tower = getattr(self, "vision_tower", None)
+        if type(vision_tower) is list:
+            vision_tower = vision_tower[0]
+        return vision_tower
+
+    def initialize_vision_modules(self, model_args, fsdp=None):
+        vision_tower = model_args.vision_tower
+        mm_vision_select_layer = model_args.mm_vision_select_layer
+        mm_vision_select_feature = model_args.mm_vision_select_feature
+        pretrain_mm_mlp_adapter = model_args.pretrain_mm_mlp_adapter
+        mm_patch_merge_type = model_args.mm_patch_merge_type
+
+        self.config.mm_vision_tower = vision_tower
+        self.config.vision_tower_pretrained = getattr(model_args, "vision_tower_pretrained", "")
+
+        if self.get_vision_tower() is None:
+            vision_tower = build_vision_tower(model_args)
+            vision_resampler = build_vision_resampler(model_args, vision_tower=vision_tower)
+            for k, v in vision_resampler.config.items():
+                setattr(self.config, k, v)
+
+            if fsdp is not None and len(fsdp) > 0:
+                self.vision_tower = [vision_tower]
+                self.vision_resampler = [vision_resampler]
+            else:
+                self.vision_tower = vision_tower
+                self.vision_resampler = vision_resampler
+        else:
+            if fsdp is not None and len(fsdp) > 0:
+                vision_resampler = self.vision_resampler[0]
+                vision_tower = self.vision_tower[0]
+            else:
+                vision_resampler = self.vision_resampler
+                vision_tower = self.vision_tower
+            vision_tower.load_model()
+
+            # In case it is frozen by LoRA
+            for p in self.vision_resampler.parameters():
+                p.requires_grad = True
+
+        self.config.use_mm_proj = True
+        self.config.mm_projector_type = getattr(model_args, "mm_projector_type", "linear")
+        self.config.mm_hidden_size = getattr(vision_resampler, "hidden_size", vision_tower.hidden_size)
+        self.config.mm_vision_select_layer = mm_vision_select_layer
+        self.config.mm_vision_select_feature = mm_vision_select_feature
+        self.config.mm_patch_merge_type = mm_patch_merge_type
+
+        if getattr(self, "mm_projector", None) is None:
+            self.mm_projector = build_vision_projector(self.config, vision_cfg=vision_tower.config)
+
+            if "unpad" in mm_patch_merge_type:
+                embed_std = 1 / torch.sqrt(torch.tensor(self.config.hidden_size, dtype=self.dtype))
+                self.image_newline = nn.Parameter(torch.randn(self.config.hidden_size, dtype=self.dtype) * embed_std)
+        else:
+            # In case it is frozen by LoRA
+            for p in self.mm_projector.parameters():
+                p.requires_grad = True
+
+        if pretrain_mm_mlp_adapter is not None:
+            mm_projector_weights = torch.load(pretrain_mm_mlp_adapter, map_location="cpu")
+
+            def get_w(weights, keyword):
+                return {k.split(keyword + ".")[1]: v for k, v in weights.items() if keyword in k}
+
+            incompatible_keys = self.mm_projector.load_state_dict(get_w(mm_projector_weights, "mm_projector"))
+            rank0_print(f"Loaded mm projector weights from {pretrain_mm_mlp_adapter}. Incompatible keys: {incompatible_keys}")
+            incompatible_keys = self.vision_resampler.load_state_dict(get_w(mm_projector_weights, "vision_resampler"), strict=False)
+            rank0_print(f"Loaded vision resampler weights from {pretrain_mm_mlp_adapter}. Incompatible keys: {incompatible_keys}")
+
+
+def unpad_image(tensor, original_size):
+    """
+    Unpads a PyTorch tensor of a padded and resized image.
+
+    Args:
+        tensor (torch.Tensor): The image tensor, assumed to be in CxHxW format.
+        original_size (tuple): The original size of the image (width, height).
+
+    Returns:
+        torch.Tensor: The unpadded image tensor.
+    """
+    original_width, original_height = original_size
+    current_height, current_width = tensor.shape[1:]
+
+    # Compute aspect ratios
+    original_aspect_ratio = original_width / original_height
+    current_aspect_ratio = current_width / current_height
+
+    # Determine padding size and direction
+    if original_aspect_ratio > current_aspect_ratio:
+        # Padding was added to the height
+        scale_factor = current_width / original_width
+        new_height = int(original_height * scale_factor)
+        padding = (current_height - new_height) // 2
+        unpadded_tensor = tensor[:, padding : current_height - padding, :]
+    else:
+        # Padding was added to the width
+        scale_factor = current_height / original_height
+        new_width = int(original_width * scale_factor)
+        padding = (current_width - new_width) // 2
+        unpadded_tensor = tensor[:, :, padding : current_width - padding]
+
+    return unpadded_tensor
+
+
+class LlavaMetaForCausalLM(ABC):
+
+    @abstractmethod
+    def get_model(self):
+        pass
+
+    def get_vision_tower(self):
+        return self.get_model().get_vision_tower()
+
+    def encode_images(self, images):
+        image_features = self.get_model().get_vision_tower()(images)
+        image_features = self.get_model().vision_resampler(image_features, images=images)
+        image_features = self.get_model().mm_projector(image_features)
+        return image_features
+
+    def prepare_inputs_labels_for_multimodal(self, input_ids, position_ids, attention_mask, past_key_values, labels, images, image_sizes=None):
+        vision_tower = self.get_vision_tower()
+        if vision_tower is None or images is None or input_ids.shape[1] == 1:
+            return input_ids, position_ids, attention_mask, past_key_values, None, labels
+
+        if type(images) is list or images.ndim == 5:
+            if type(images) is list:
+                images = [x.unsqueeze(0) if x.ndim == 3 else x for x in images]
+            concat_images = torch.cat([image for image in images], dim=0)
+            image_features = self.encode_images(concat_images)
+            split_sizes = [image.shape[0] for image in images]
+            image_features = torch.split(image_features, split_sizes, dim=0)
+            mm_patch_merge_type = getattr(self.config, "mm_patch_merge_type", "flat")
+            image_aspect_ratio = getattr(self.config, "image_aspect_ratio", "square")
+            if mm_patch_merge_type == "flat":
+                image_features = [x.flatten(0, 1) for x in image_features]
+            elif mm_patch_merge_type.startswith("spatial"):
+                new_image_features = []
+                for image_idx, image_feature in enumerate(image_features):
+                    # FIXME: now assume the image is square, and split to 2x2 patches
+                    # num_patches = h * w, where h = w = sqrt(num_patches)
+                    # currently image_feature is a tensor of shape (4, num_patches, hidden_size)
+                    # we want to first unflatten it to (2, 2, h, w, hidden_size)
+
+                    if image_feature.shape[0] > 1:
+                        base_image_feature = image_feature[0]
+                        image_feature = image_feature[1:]
+                        height = width = self.get_vision_tower().num_patches_per_side
+                        assert height * width == base_image_feature.shape[0]
+                        if image_aspect_ratio == "anyres":
+                            if hasattr(self.get_vision_tower(), "image_size"):
+                                vision_tower_image_size = self.get_vision_tower().image_size
+                            else:
+                                raise ValueError("vision_tower_image_size is not found in the vision tower.")
+                            num_patch_width, num_patch_height = get_anyres_image_grid_shape(image_sizes[image_idx], self.config.image_grid_pinpoints, vision_tower_image_size)
+                            image_feature = image_feature.view(num_patch_height, num_patch_width, height, width, -1)
+                        else:
+                            image_feature = image_feature.view(2, 2, height, width, -1)
+                        if "maxpool2x2" in mm_patch_merge_type:
+                            image_feature = image_feature.permute(4, 0, 2, 1, 3).contiguous()
+                            image_feature = image_feature.flatten(1, 2).flatten(2, 3)
+                            image_feature = nn.functional.max_pool2d(image_feature, 2)
+                            image_feature = image_feature.flatten(1, 2).transpose(0, 1)
+                        elif "unpad" in mm_patch_merge_type:
+                            image_feature = image_feature.permute(4, 0, 2, 1, 3).contiguous()
+                            image_feature = image_feature.flatten(1, 2).flatten(2, 3)
+                            image_feature = unpad_image(image_feature, image_sizes[image_idx])
+                            image_feature = torch.cat((image_feature, self.model.image_newline[:, None, None].expand(*image_feature.shape[:-1], 1).to(image_feature.device)), dim=-1)
+                            image_feature = image_feature.flatten(1, 2).transpose(0, 1)
+                        else:
+                            image_feature = image_feature.permute(0, 2, 1, 3, 4).contiguous()
+                            image_feature = image_feature.flatten(0, 3)
+                        if "nobase" in mm_patch_merge_type:
+                            pass
+                        else:
+                            image_feature = torch.cat((base_image_feature, image_feature), dim=0)
+                    else:
+                        image_feature = image_feature[0]
+                        if "unpad" in mm_patch_merge_type:
+                            image_feature = torch.cat((image_feature, self.model.image_newline[None]), dim=0)
+                    new_image_features.append(image_feature)
+                image_features = new_image_features
+            else:
+                raise ValueError(f"Unexpected mm_patch_merge_type: {self.config.mm_patch_merge_type}")
+        else:
+            image_features = self.encode_images(images)
+
+        # TODO: image start / end is not implemented here to support pretraining.
+        if getattr(self.config, "tune_mm_mlp_adapter", False) and getattr(self.config, "mm_use_im_start_end", False):
+            raise NotImplementedError
+
+        # Let's just add dummy tensors if they do not exist,
+        # it is a headache to deal with None all the time.
+        # But it is not ideal, and if you have a better idea,
+        # please open an issue / submit a PR, thanks.
+        _labels = labels
+        _position_ids = position_ids
+        _attention_mask = attention_mask
+        if attention_mask is None:
+            attention_mask = torch.ones_like(input_ids, dtype=torch.bool)
+        else:
+            attention_mask = attention_mask.bool()
+        if position_ids is None:
+            position_ids = torch.arange(0, input_ids.shape[1], dtype=torch.long, device=input_ids.device)
+        if labels is None:
+            labels = torch.full_like(input_ids, IGNORE_INDEX)
+
+        # remove the padding using attention_mask -- FIXME
+        _input_ids = input_ids
+        input_ids = [cur_input_ids[cur_attention_mask] for cur_input_ids, cur_attention_mask in zip(input_ids, attention_mask)]
+        labels = [cur_labels[cur_attention_mask] for cur_labels, cur_attention_mask in zip(labels, attention_mask)]
+
+        new_input_embeds = []
+        new_labels = []
+        cur_image_idx = 0
+        for batch_idx, cur_input_ids in enumerate(input_ids):
+            num_images = (cur_input_ids == IMAGE_TOKEN_INDEX).sum()
+            if num_images == 0:
+                cur_image_features = image_features[cur_image_idx]
+                cur_input_embeds_1 = self.get_model().embed_tokens(cur_input_ids)
+                cur_input_embeds = torch.cat([cur_input_embeds_1, cur_image_features[0:0]], dim=0)
+                new_input_embeds.append(cur_input_embeds)
+                new_labels.append(labels[batch_idx])
+                cur_image_idx += 1
+                continue
+
+            image_token_indices = [-1] + torch.where(cur_input_ids == IMAGE_TOKEN_INDEX)[0].tolist() + [cur_input_ids.shape[0]]
+            cur_input_ids_noim = []
+            cur_labels = labels[batch_idx]
+            cur_labels_noim = []
+            for i in range(len(image_token_indices) - 1):
+                cur_input_ids_noim.append(cur_input_ids[image_token_indices[i] + 1 : image_token_indices[i + 1]])
+                cur_labels_noim.append(cur_labels[image_token_indices[i] + 1 : image_token_indices[i + 1]])
+
+            split_sizes = [x.shape[0] for x in cur_labels_noim]
+            cur_input_embeds = self.get_model().embed_tokens(torch.cat(cur_input_ids_noim))
+            cur_input_embeds_no_im = torch.split(cur_input_embeds, split_sizes, dim=0)
+            cur_new_input_embeds = []
+            cur_new_labels = []
+
+            for i in range(num_images + 1):
+                cur_new_input_embeds.append(cur_input_embeds_no_im[i])
+                cur_new_labels.append(cur_labels_noim[i])
+                if i < num_images:
+                    cur_image_features = image_features[cur_image_idx]
+                    cur_image_idx += 1
+                    cur_new_input_embeds.append(cur_image_features)
+                    cur_new_labels.append(torch.full((cur_image_features.shape[0],), IGNORE_INDEX, device=cur_labels.device, dtype=cur_labels.dtype))
+
+            cur_new_input_embeds = [x.to(self.device) for x in cur_new_input_embeds]
+
+            cur_new_input_embeds = torch.cat(cur_new_input_embeds)
+            cur_new_labels = torch.cat(cur_new_labels)
+
+            new_input_embeds.append(cur_new_input_embeds)
+            new_labels.append(cur_new_labels)
+
+        # Truncate sequences to max length as image embeddings can make the sequence longer
+        tokenizer_model_max_length = getattr(self.config, "tokenizer_model_max_length", None)
+        if tokenizer_model_max_length is not None:
+            new_input_embeds = [x[:tokenizer_model_max_length] for x in new_input_embeds]
+            new_labels = [x[:tokenizer_model_max_length] for x in new_labels]
+
+        # Combine them
+        max_len = max(x.shape[0] for x in new_input_embeds)
+        batch_size = len(new_input_embeds)
+
+        new_input_embeds_padded = []
+        new_labels_padded = torch.full((batch_size, max_len), IGNORE_INDEX, dtype=new_labels[0].dtype, device=new_labels[0].device)
+        attention_mask = torch.zeros((batch_size, max_len), dtype=attention_mask.dtype, device=attention_mask.device)
+        position_ids = torch.zeros((batch_size, max_len), dtype=position_ids.dtype, device=position_ids.device)
+
+        for i, (cur_new_embed, cur_new_labels) in enumerate(zip(new_input_embeds, new_labels)):
+            cur_len = cur_new_embed.shape[0]
+            if getattr(self.config, "tokenizer_padding_side", "right") == "left":
+                new_input_embeds_padded.append(torch.cat((torch.zeros((max_len - cur_len, cur_new_embed.shape[1]), dtype=cur_new_embed.dtype, device=cur_new_embed.device), cur_new_embed), dim=0))
+                if cur_len > 0:
+                    new_labels_padded[i, -cur_len:] = cur_new_labels
+                    attention_mask[i, -cur_len:] = True
+                    position_ids[i, -cur_len:] = torch.arange(0, cur_len, dtype=position_ids.dtype, device=position_ids.device)
+            else:
+                new_input_embeds_padded.append(torch.cat((cur_new_embed, torch.zeros((max_len - cur_len, cur_new_embed.shape[1]), dtype=cur_new_embed.dtype, device=cur_new_embed.device)), dim=0))
+                if cur_len > 0:
+                    new_labels_padded[i, :cur_len] = cur_new_labels
+                    attention_mask[i, :cur_len] = True
+                    position_ids[i, :cur_len] = torch.arange(0, cur_len, dtype=position_ids.dtype, device=position_ids.device)
+
+        new_input_embeds = torch.stack(new_input_embeds_padded, dim=0)
+
+        if _labels is None:
+            new_labels = None
+        else:
+            new_labels = new_labels_padded
+
+        if _attention_mask is None:
+            attention_mask = None
+        else:
+            attention_mask = attention_mask.to(dtype=_attention_mask.dtype)
+
+        if _position_ids is None:
+            position_ids = None
+
+        return None, position_ids, attention_mask, past_key_values, new_input_embeds, new_labels
+
+    def initialize_vision_tokenizer(self, model_args, tokenizer):
+        if model_args.mm_use_im_patch_token:
+            tokenizer.add_tokens([DEFAULT_IMAGE_PATCH_TOKEN], special_tokens=True)
+            self.resize_token_embeddings(len(tokenizer))
+
+        if model_args.mm_use_im_start_end:
+            num_new_tokens = tokenizer.add_tokens([DEFAULT_IM_START_TOKEN, DEFAULT_IM_END_TOKEN], special_tokens=True)
+            self.resize_token_embeddings(len(tokenizer))
+
+            if num_new_tokens > 0:
+                input_embeddings = self.get_input_embeddings().weight.data
+                output_embeddings = self.get_output_embeddings().weight.data
+
+                input_embeddings_avg = input_embeddings[:-num_new_tokens].mean(dim=0, keepdim=True)
+                output_embeddings_avg = output_embeddings[:-num_new_tokens].mean(dim=0, keepdim=True)
+
+                input_embeddings[-num_new_tokens:] = input_embeddings_avg
+                output_embeddings[-num_new_tokens:] = output_embeddings_avg
+
+            if model_args.tune_mm_mlp_adapter:
+                for p in self.get_input_embeddings().parameters():
+                    p.requires_grad = True
+                for p in self.get_output_embeddings().parameters():
+                    p.requires_grad = False
+
+            if model_args.pretrain_mm_mlp_adapter:
+                mm_projector_weights = torch.load(model_args.pretrain_mm_mlp_adapter, map_location="cpu")
+                embed_tokens_weight = mm_projector_weights["model.embed_tokens.weight"]
+                assert num_new_tokens == 2
+                if input_embeddings.shape == embed_tokens_weight.shape:
+                    input_embeddings[-num_new_tokens:] = embed_tokens_weight[-num_new_tokens:]
+                elif embed_tokens_weight.shape[0] == num_new_tokens:
+                    input_embeddings[-num_new_tokens:] = embed_tokens_weight
+                else:
+                    raise ValueError(f"Unexpected embed_tokens_weight shape. Pretrained: {embed_tokens_weight.shape}. Current: {input_embeddings.shape}. Number of new tokens: {num_new_tokens}.")
+        elif model_args.mm_use_im_patch_token:
+            if model_args.tune_mm_mlp_adapter:
+                for p in self.get_input_embeddings().parameters():
+                    p.requires_grad = False
+                for p in self.get_output_embeddings().parameters():
+                    p.requires_grad = False
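Worked example (illustrative, not part of the committed files) for `unpad_image`: a 1000x500 image placed on a 48x48 anyres feature grid was letterboxed vertically, so only 24 of the 48 rows carry content and the padded rows are trimmed.

import torch
from llava.model.llava_arch import unpad_image

features = torch.randn(4096, 48, 48)           # (C, H, W) feature grid; C is arbitrary here
trimmed = unpad_image(features, (1000, 500))   # original size given as (width, height)
print(trimmed.shape)                           # torch.Size([4096, 24, 48])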
llava/model/make_delta.py ADDED
@@ -0,0 +1,52 @@
+"""
+Usage:
+python3 -m llava.model.make_delta --base ~/model_weights/llama-7b --target ~/model_weights/llava-7b --delta ~/model_weights/llava-7b-delta --hub-repo-id liuhaotian/llava-7b-delta
+"""
+
+import argparse
+
+import torch
+from tqdm import tqdm
+from transformers import AutoTokenizer, AutoModelForCausalLM
+from llava.model.utils import auto_upgrade
+
+
+def make_delta(base_model_path, target_model_path, delta_path, hub_repo_id):
+    print("Loading base model")
+    base = AutoModelForCausalLM.from_pretrained(base_model_path, torch_dtype=torch.float16, low_cpu_mem_usage=True)
+
+    print("Loading target model")
+    auto_upgrade(target_model_path)
+    target = AutoModelForCausalLM.from_pretrained(target_model_path, torch_dtype=torch.float16, low_cpu_mem_usage=True)
+
+    print("Calculating delta")
+    for name, param in tqdm(target.state_dict().items(), desc="Calculating delta"):
+        if name not in base.state_dict():
+            assert name in ["model.mm_projector.weight", "model.mm_projector.bias"], f"{name} not in base model"
+            continue
+        if param.data.shape == base.state_dict()[name].shape:
+            param.data -= base.state_dict()[name]
+        else:
+            assert name in ["model.embed_tokens.weight", "lm_head.weight"], f"{name} dimension mismatch: {param.data.shape} vs {base.state_dict()[name].shape}"
+            bparam = base.state_dict()[name]
+            param.data[: bparam.shape[0], : bparam.shape[1]] -= bparam
+
+    print("Saving delta")
+    if hub_repo_id:
+        kwargs = {"push_to_hub": True, "repo_id": hub_repo_id}
+    else:
+        kwargs = {}
+    target.save_pretrained(delta_path, **kwargs)
+    target_tokenizer = AutoTokenizer.from_pretrained(target_model_path)
+    target_tokenizer.save_pretrained(delta_path, **kwargs)
+
+
+if __name__ == "__main__":
+    parser = argparse.ArgumentParser()
+    parser.add_argument("--base-model-path", type=str, required=True)
+    parser.add_argument("--target-model-path", type=str, required=True)
+    parser.add_argument("--delta-path", type=str, required=True)
+    parser.add_argument("--hub-repo-id", type=str, default=None)
+    args = parser.parse_args()
+
+    make_delta(args.base_model_path, args.target_model_path, args.delta_path, args.hub_repo_id)
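Conceptual inverse of `make_delta` (a sketch under stated assumptions, not the repo's own apply script): target weights are recovered by adding the base weights back onto the saved delta, mirroring the subtraction above. Paths are placeholders.

import torch
from transformers import AutoModelForCausalLM

base = AutoModelForCausalLM.from_pretrained("path/to/base", torch_dtype=torch.float16, low_cpu_mem_usage=True)
delta = AutoModelForCausalLM.from_pretrained("path/to/delta", torch_dtype=torch.float16, low_cpu_mem_usage=True)

base_state = base.state_dict()
for name, param in delta.state_dict().items():
    if name not in base_state:
        continue  # e.g. model.mm_projector.* exist only in the target/delta
    bparam = base_state[name]
    if param.data.shape == bparam.shape:
        param.data += bparam
    else:
        # embed_tokens / lm_head were extended with new rows; only the overlapping block was subtracted
        param.data[: bparam.shape[0], : bparam.shape[1]] += bparam

delta.save_pretrained("path/to/reconstructed-target")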
llava/model/multimodal_encoder/builder.py ADDED
@@ -0,0 +1,14 @@
+import os
+from .clip_encoder import CLIPVisionTower
+from .siglip_encoder import SigLipVisionTower
+
+
+def build_vision_tower(vision_tower_cfg, **kwargs):
+    vision_tower = getattr(vision_tower_cfg, 'mm_vision_tower', getattr(vision_tower_cfg, 'vision_tower', None))
+    is_absolute_path_exists = os.path.exists(vision_tower)
+    if is_absolute_path_exists or vision_tower.startswith("openai") or vision_tower.startswith("laion") or "ShareGPT4V" in vision_tower:
+        return CLIPVisionTower(vision_tower, args=vision_tower_cfg, **kwargs)
+    elif "siglip" in vision_tower:
+        return SigLipVisionTower(vision_tower, vision_tower_cfg=vision_tower_cfg, **kwargs)
+
+    raise ValueError(f'Unknown vision tower: {vision_tower}')
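Minimal dispatch sketch (illustrative, not part of the committed files): the builder routes on the tower name, so an `openai/...` CLIP name yields a `CLIPVisionTower`, while any name containing `siglip` yields a `SigLipVisionTower`. The config below is a bare namespace with placeholder values.

from types import SimpleNamespace
from llava.model.multimodal_encoder.builder import build_vision_tower

cfg = SimpleNamespace(
    mm_vision_tower="openai/clip-vit-large-patch14-336",  # matches the "openai" branch above
    mm_vision_select_layer=-2,
    mm_vision_select_feature="patch",
)
tower = build_vision_tower(cfg, delay_load=True)  # delay_load skips loading the full weights
print(type(tower).__name__)  # CLIPVisionTower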
llava/model/multimodal_encoder/clip_encoder.py ADDED
@@ -0,0 +1,114 @@
+import torch
+import torch.nn as nn
+
+from transformers import CLIPVisionModel, CLIPImageProcessor, CLIPVisionConfig
+
+
+class CLIPVisionTower(nn.Module):
+    def __init__(self, vision_tower, args, delay_load=False):
+        super().__init__()
+
+        self.is_loaded = False
+
+        self.vision_tower_name = vision_tower
+        self.select_layer = args.mm_vision_select_layer
+        self.select_feature = getattr(args, "mm_vision_select_feature", "patch")
+
+        if not delay_load:
+            self.load_model()
+        elif getattr(args, "unfreeze_mm_vision_tower", False):
+            # TODO: better detector is needed.
+            print(f"The checkpoint seems to contain `vision_tower` weights: `unfreeze_mm_vision_tower`: True.")
+            self.load_model()
+        else:
+            self.cfg_only = CLIPVisionConfig.from_pretrained(self.vision_tower_name)
+
+    def load_model(self, device_map=None):
+        if self.is_loaded:
+            print("{} is already loaded, `load_model` called again, skipping.".format(self.vision_tower_name))
+            return
+
+        # import pdb; pdb.set_trace()
+        self.image_processor = CLIPImageProcessor.from_pretrained(self.vision_tower_name)
+        self.vision_tower = CLIPVisionModel.from_pretrained(self.vision_tower_name, device_map=device_map)
+        self.vision_tower.requires_grad_(False)
+
+        self.is_loaded = True
+
+    def feature_select(self, image_forward_outs):
+        select_feature_type = self.select_feature
+
+        if self.select_feature in ["slicefour_patch", "slicefour_cls_patch"]:
+            select_every_k_layer = len(image_forward_outs.hidden_states) // 4
+            image_features = torch.cat([image_forward_outs.hidden_states[i] for i in range(select_every_k_layer + self.select_layer, len(image_forward_outs.hidden_states), select_every_k_layer)], dim=-1)
+            select_feature_type = select_feature_type.replace("slicefour_", "")
+        elif self.select_feature in ["slice_m25811_f6_patch", "slice_m25811_f6_cls_patch"]:
+            select_layers = [-2, -5, -8, -11, 6]
+            image_features = torch.cat([image_forward_outs.hidden_states[i] for i in select_layers], dim=-1)
+            select_feature_type = select_feature_type.replace("slice_m25811_f6_", "")
+        else:
+            image_features = image_forward_outs.hidden_states[self.select_layer]
+
+        if select_feature_type == "patch":
+            image_features = image_features[:, 1:]
+        elif select_feature_type == "cls_patch":
+            image_features = image_features
+        else:
+            raise ValueError(f"Unexpected select feature: {select_feature_type}")
+        return image_features
+
+    def forward(self, images):
+        if type(images) is list:
+            image_features = []
+            for image in images:
+                image_forward_out = self.vision_tower(image.to(device=self.device, dtype=self.dtype).unsqueeze(0), output_hidden_states=True)
+                image_feature = self.feature_select(image_forward_out).to(image.dtype)
+                image_features.append(image_feature)
+        else:
+            image_forward_outs = self.vision_tower(images.to(device=self.device, dtype=self.dtype), output_hidden_states=True)
+            image_features = self.feature_select(image_forward_outs).to(images.dtype)
+
+        return image_features
+
+    @property
+    def dummy_feature(self):
+        return torch.zeros(1, self.hidden_size, device=self.device, dtype=self.dtype)
+
+    @property
+    def dtype(self):
+        return self.vision_tower.dtype
+
+    @property
+    def device(self):
+        return self.vision_tower.device
+
+    @property
+    def config(self):
+        if self.is_loaded:
+            return self.vision_tower.config
+        else:
+            return self.cfg_only
+
+    @property
+    def hidden_size(self):
+        _hidden_size = self.config.hidden_size
+        if "slicefour" in self.select_feature:
+            _hidden_size *= 4
+        if "slice_m25811_f6" in self.select_feature:
+            _hidden_size *= 5
+        return _hidden_size
+
+    @property
+    def num_patches_per_side(self):
+        return self.config.image_size // self.config.patch_size
+
+    @property
+    def num_patches(self):
+        _num_patches = (self.config.image_size // self.config.patch_size) ** 2
+        if "cls_patch" in self.select_feature:
+            _num_patches += 1
+        return _num_patches
+
+    @property
+    def image_size(self):
+        return self.config.image_size
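Back-of-the-envelope check of the geometry the properties above encode, assuming the commonly used openai/clip-vit-large-patch14-336 tower (image_size 336, patch_size 14, hidden_size 1024); these numbers are illustrative, not read from this repo.

image_size, patch_size, hidden_size = 336, 14, 1024
num_patches_per_side = image_size // patch_size   # 24
num_patches = num_patches_per_side ** 2           # 576 for "patch", 577 for "cls_patch"
slicefour_hidden = hidden_size * 4                # 4096: four hidden layers concatenated on the last dim
slice_m25811_f6_hidden = hidden_size * 5          # 5120: five selected layers concatenated
print(num_patches_per_side, num_patches, slicefour_hidden, slice_m25811_f6_hidden)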
llava/model/multimodal_encoder/siglip_encoder.py ADDED
@@ -0,0 +1,620 @@
+"""
+# Adapted from https://huggingface.co/MILVLG/imp-v1-3b/blob/main/vision_encoder.py
+"""
+
+from typing import Optional, Tuple, Union, Dict
+from dataclasses import dataclass
+from functools import partial, reduce
+from PIL import Image
+import torch
+import torch.utils.checkpoint
+from torch import nn
+import os
+from transformers.image_processing_utils import BatchFeature, get_size_dict
+from transformers.image_transforms import (
+    convert_to_rgb,
+    normalize,
+    rescale,
+    resize,
+    to_channel_dimension_format,
+)
+from transformers.image_utils import (
+    ChannelDimension,
+    PILImageResampling,
+    to_numpy_array,
+)
+from transformers.activations import ACT2FN
+from transformers.modeling_outputs import BaseModelOutput, BaseModelOutputWithPooling
+from transformers.modeling_utils import PreTrainedModel
+from transformers import PretrainedConfig
+from transformers.utils import ModelOutput
+from llava.utils import rank0_print
+
+
+class SigLipImageProcessor:
+    def __init__(self, image_mean=(0.5, 0.5, 0.5), image_std=(0.5, 0.5, 0.5), size=(384, 384), crop_size: Dict[str, int] = None, resample=PILImageResampling.BICUBIC, rescale_factor=1 / 255, data_format=ChannelDimension.FIRST):
+        crop_size = crop_size if crop_size is not None else {"height": 384, "width": 384}
+        crop_size = get_size_dict(crop_size, default_to_square=True, param_name="crop_size")
+
+        self.image_mean = image_mean
+        self.image_std = image_std
+        self.size = size
+        self.resample = resample
+        self.rescale_factor = rescale_factor
+        self.data_format = data_format
+        self.crop_size = crop_size
+
+    def preprocess(self, images, return_tensors):
+        if isinstance(images, Image.Image):
+            images = [images]
+        else:
+            assert isinstance(images, list)
+
+        transforms = [
+            convert_to_rgb,
+            to_numpy_array,
+            partial(resize, size=self.size, resample=self.resample, data_format=self.data_format),
+            partial(rescale, scale=self.rescale_factor, data_format=self.data_format),
+            partial(normalize, mean=self.image_mean, std=self.image_std, data_format=self.data_format),
+            partial(to_channel_dimension_format, channel_dim=self.data_format, input_channel_dim=self.data_format),
+        ]
+
+        images = reduce(lambda x, f: [*map(f, x)], transforms, images)
+        data = {"pixel_values": images}
+
+        return BatchFeature(data=data, tensor_type=return_tensors)
+
+
+class SigLipVisionConfig(PretrainedConfig):
+    model_type = "siglip_vision_model"
+
+    def __init__(
+        self,
+        hidden_size=1152,
+        image_mean=(0.5, 0.5, 0.5),
+        intermediate_size=4304,
+        num_hidden_layers=27,
+        num_attention_heads=16,
+        num_channels=3,
+        image_size=384,
+        patch_size=14,
+        hidden_act="gelu_pytorch_tanh",
+        layer_norm_eps=1e-6,
+        attention_dropout=0.0,
+        **kwargs,
+    ):
+        super().__init__(**kwargs)
+
+        self.hidden_size = hidden_size
+        self.intermediate_size = intermediate_size
+        self.num_hidden_layers = num_hidden_layers
+        self.num_attention_heads = num_attention_heads
+        self.num_channels = num_channels
+        self.patch_size = patch_size
+        self.image_size = image_size
+        self.attention_dropout = attention_dropout
+        self.layer_norm_eps = layer_norm_eps
+        self.hidden_act = hidden_act
+        self.image_mean = image_mean
+
+    @classmethod
+    def from_pretrained(cls, pretrained_model_name_or_path: Union[str, os.PathLike], **kwargs) -> "PretrainedConfig":
+        cls._set_token_in_kwargs(kwargs)
+
+        config_dict, kwargs = cls.get_config_dict(pretrained_model_name_or_path, **kwargs)
+
+        # get the vision config dict if we are loading from SigLipConfig
+        if config_dict.get("model_type") == "siglip":
+            config_dict = config_dict["vision_config"]
+
+        if "model_type" in config_dict and hasattr(cls, "model_type") and config_dict["model_type"] != cls.model_type:
+            print(f"You are using a model of type {config_dict['model_type']} to instantiate a model of type " f"{cls.model_type}. This is not supported for all configurations of models and can yield errors.")
+
+        return cls.from_dict(config_dict, **kwargs)
+
+
+@dataclass
+# Copied from transformers.models.clip.modeling_clip.CLIPVisionModelOutput with CLIP->SigLip
+class SigLipVisionModelOutput(ModelOutput):
+    """
+    Base class for vision model's outputs that also contains image embeddings of the pooling of the last hidden states.
+
+    Args:
+        image_embeds (`torch.FloatTensor` of shape `(batch_size, output_dim)` *optional* returned when model is initialized with `with_projection=True`):
+            The image embeddings obtained by applying the projection layer to the pooler_output.
+        last_hidden_state (`torch.FloatTensor` of shape `(batch_size, sequence_length, hidden_size)`):
+            Sequence of hidden-states at the output of the last layer of the model.
+        hidden_states (`tuple(torch.FloatTensor)`, *optional*, returned when `output_hidden_states=True` is passed or when `config.output_hidden_states=True`):
+            Tuple of `torch.FloatTensor` (one for the output of the embeddings, if the model has an embedding layer, +
+            one for the output of each layer) of shape `(batch_size, sequence_length, hidden_size)`.
+
+            Hidden-states of the model at the output of each layer plus the optional initial embedding outputs.
+        attentions (`tuple(torch.FloatTensor)`, *optional*, returned when `output_attentions=True` is passed or when `config.output_attentions=True`):
+            Tuple of `torch.FloatTensor` (one for each layer) of shape `(batch_size, num_heads, sequence_length,
+            sequence_length)`.
+
+            Attentions weights after the attention softmax, used to compute the weighted average in the self-attention
+            heads.
+    """
+
+    image_embeds: Optional[torch.FloatTensor] = None
+    last_hidden_state: torch.FloatTensor = None
+    hidden_states: Optional[Tuple[torch.FloatTensor]] = None
+    attentions: Optional[Tuple[torch.FloatTensor]] = None
+
+
+class SigLipVisionEmbeddings(nn.Module):
+    def __init__(self, config: SigLipVisionConfig):
+        super().__init__()
+        self.config = config
+        self.embed_dim = config.hidden_size
+        self.image_size = config.image_size
+        self.patch_size = config.patch_size
+
+        self.patch_embedding = nn.Conv2d(
+            in_channels=config.num_channels,
+            out_channels=self.embed_dim,
+            kernel_size=self.patch_size,
+            stride=self.patch_size,
+            padding="valid",
+        )
+
+        self.num_patches = (self.image_size // self.patch_size) ** 2
+        self.num_positions = self.num_patches
+        self.position_embedding = nn.Embedding(self.num_positions, self.embed_dim)
+        self.register_buffer("position_ids", torch.arange(self.num_positions).expand((1, -1)), persistent=False)
+
+    def forward(self, pixel_values: torch.FloatTensor) -> torch.Tensor:
+        patch_embeds = self.patch_embedding(pixel_values)  # shape = [*, width, grid, grid]
+        embeddings = patch_embeds.flatten(2).transpose(1, 2)
+
+        embeddings = embeddings + self.position_embedding(self.position_ids)
+        return embeddings
+
+
+class SigLipAttention(nn.Module):
+    """Multi-headed attention from 'Attention Is All You Need' paper"""
+
+    # Copied from transformers.models.clip.modeling_clip.CLIPAttention.__init__
+    def __init__(self, config):
+        super().__init__()
+        self.config = config
+        self.embed_dim = config.hidden_size
+        self.num_heads = config.num_attention_heads
+        self.head_dim = self.embed_dim // self.num_heads
+        if self.head_dim * self.num_heads != self.embed_dim:
+            raise ValueError(f"embed_dim must be divisible by num_heads (got `embed_dim`: {self.embed_dim} and `num_heads`:" f" {self.num_heads}).")
+        self.scale = self.head_dim**-0.5
+        self.dropout = config.attention_dropout
+
+        self.k_proj = nn.Linear(self.embed_dim, self.embed_dim)
+        self.v_proj = nn.Linear(self.embed_dim, self.embed_dim)
+        self.q_proj = nn.Linear(self.embed_dim, self.embed_dim)
+        self.out_proj = nn.Linear(self.embed_dim, self.embed_dim)
+
+    def forward(
+        self,
+        hidden_states: torch.Tensor,
+        attention_mask: Optional[torch.Tensor] = None,
+        output_attentions: Optional[bool] = False,
+    ) -> Tuple[torch.Tensor, Optional[torch.Tensor], Optional[Tuple[torch.Tensor]]]:
+        """Input shape: Batch x Time x Channel"""
+
+        batch_size, q_len, _ = hidden_states.size()
+
+        query_states = self.q_proj(hidden_states)
+        key_states = self.k_proj(hidden_states)
+        value_states = self.v_proj(hidden_states)
+
+        query_states = query_states.view(batch_size, q_len, self.num_heads, self.head_dim).transpose(1, 2)
+        key_states = key_states.view(batch_size, q_len, self.num_heads, self.head_dim).transpose(1, 2)
+        value_states = value_states.view(batch_size, q_len, self.num_heads, self.head_dim).transpose(1, 2)
+
+        k_v_seq_len = key_states.shape[-2]
+        attn_weights = torch.matmul(query_states, key_states.transpose(2, 3)) * self.scale
+
+        if attn_weights.size() != (batch_size, self.num_heads, q_len, k_v_seq_len):
+            raise ValueError(f"Attention weights should be of size {(batch_size, self.num_heads, q_len, k_v_seq_len)}, but is" f" {attn_weights.size()}")
+
+        if attention_mask is not None:
+            if attention_mask.size() != (batch_size, 1, q_len, k_v_seq_len):
+                raise ValueError(f"Attention mask should be of size {(batch_size, 1, q_len, k_v_seq_len)}, but is {attention_mask.size()}")
+            attn_weights = attn_weights + attention_mask
+
+        # upcast attention to fp32
+        attn_weights = nn.functional.softmax(attn_weights, dim=-1, dtype=torch.float32).to(query_states.dtype)
+        attn_weights = nn.functional.dropout(attn_weights, p=self.dropout, training=self.training)
+        attn_output = torch.matmul(attn_weights, value_states)
+
+        if attn_output.size() != (batch_size, self.num_heads, q_len, self.head_dim):
+            raise ValueError(f"`attn_output` should be of size {(batch_size, self.num_heads, q_len, self.head_dim)}, but is" f" {attn_output.size()}")
+
+        attn_output = attn_output.transpose(1, 2).contiguous()
+        attn_output = attn_output.reshape(batch_size, q_len, self.embed_dim)
+
+        attn_output = self.out_proj(attn_output)
+
+        return attn_output, attn_weights
+
+
+# Copied from transformers.models.clip.modeling_clip.CLIPMLP with CLIP->SigLip
+class SigLipMLP(nn.Module):
+    def __init__(self, config):
+        super().__init__()
+        self.config = config
+        self.activation_fn = ACT2FN[config.hidden_act]
+        self.fc1 = nn.Linear(config.hidden_size, config.intermediate_size)
+        self.fc2 = nn.Linear(config.intermediate_size, config.hidden_size)
+
+    def forward(self, hidden_states: torch.Tensor) -> torch.Tensor:
+        hidden_states = self.fc1(hidden_states)
+        hidden_states = self.activation_fn(hidden_states)
+        hidden_states = self.fc2(hidden_states)
+        return hidden_states
+
+
+# Copied from transformers.models.clip.modeling_clip.CLIPEncoderLayer with CLIP->SigLip
+class SigLipEncoderLayer(nn.Module):
+    def __init__(self, config: SigLipVisionConfig):
+        super().__init__()
+        self.embed_dim = config.hidden_size
+        self.self_attn = SigLipAttention(config)
+        self.layer_norm1 = nn.LayerNorm(self.embed_dim, eps=config.layer_norm_eps)
+        self.mlp = SigLipMLP(config)
+        self.layer_norm2 = nn.LayerNorm(self.embed_dim, eps=config.layer_norm_eps)
+
+    # Ignore copy
+    def forward(
+        self,
+        hidden_states: torch.Tensor,
+        attention_mask: torch.Tensor,
+        output_attentions: Optional[bool] = False,
+    ) -> Tuple[torch.FloatTensor]:
+        """
+        Args:
+            hidden_states (`torch.FloatTensor`):
+                Input to the layer of shape `(batch, seq_len, embed_dim)`.
+            attention_mask (`torch.FloatTensor`):
+                Attention mask of shape `(batch, 1, q_len, k_v_seq_len)` where padding elements are indicated by very large negative values.
+            output_attentions (`bool`, *optional*, defaults to `False`):
+                Whether or not to return the attentions tensors of all attention layers. See `attentions` under
+                returned tensors for more detail.
+        """
+        residual = hidden_states
+
+        hidden_states = self.layer_norm1(hidden_states)
+        hidden_states, attn_weights = self.self_attn(
+            hidden_states=hidden_states,
+            attention_mask=attention_mask,
+            output_attentions=output_attentions,
+        )
+        hidden_states = residual + hidden_states
+
+        residual = hidden_states
+        hidden_states = self.layer_norm2(hidden_states)
+        hidden_states = self.mlp(hidden_states)
+        hidden_states = residual + hidden_states
+
+        outputs = (hidden_states,)
+
+        if output_attentions:
+            outputs += (attn_weights,)
+
+        return outputs
+
+
+class SigLipPreTrainedModel(PreTrainedModel):
+    """
+    An abstract class to handle weights initialization and a simple interface for downloading and loading pretrained
+    models.
+    """
+
+    config_class = SigLipVisionConfig
+    base_model_prefix = "siglip"
+    supports_gradient_checkpointing = True
+
+    def _init_weights(self, module):
+        """Initialize the weights"""
+        pass
+
+
+# Copied from transformers.models.clip.modeling_clip.CLIPEncoder with CLIP->SigLip
+class SigLipEncoder(nn.Module):
+    """
+    Transformer encoder consisting of `config.num_hidden_layers` self attention layers. Each layer is a
+    [`SigLipEncoderLayer`].
+
+    Args:
+        config: SigLipVisionConfig
+    """
+
+    def __init__(self, config: SigLipVisionConfig):
+        super().__init__()
+        self.config = config
+        self.layers = nn.ModuleList([SigLipEncoderLayer(config) for _ in range(config.num_hidden_layers)])
+        self.gradient_checkpointing = False
+
+    # Ignore copy
+    def forward(
+        self,
+        inputs_embeds,
+        attention_mask: Optional[torch.Tensor] = None,
+        output_attentions: Optional[bool] = None,
+        output_hidden_states: Optional[bool] = None,
+        return_dict: Optional[bool] = None,
+    ) -> Union[Tuple, BaseModelOutput]:
+        r"""
+        Args:
+            inputs_embeds (`torch.FloatTensor` of shape `(batch_size, sequence_length, hidden_size)`):
+                Optionally, instead of passing `input_ids` you can choose to directly pass an embedded representation.
+                This is useful if you want more control over how to convert `input_ids` indices into associated vectors
+                than the model's internal embedding lookup matrix.
+            attention_mask (`torch.Tensor` of shape `(batch_size, sequence_length)`, *optional*):
+                Mask to avoid performing attention on padding token indices. Mask values selected in `[0, 1]`:
+
+                - 1 for tokens that are **not masked**,
+                - 0 for tokens that are **masked**.
+
+                [What are attention masks?](../glossary#attention-mask)
+            output_attentions (`bool`, *optional*):
+                Whether or not to return the attentions tensors of all attention layers. See `attentions` under
+                returned tensors for more detail.
+            output_hidden_states (`bool`, *optional*):
+                Whether or not to return the hidden states of all layers. See `hidden_states` under returned tensors
+                for more detail.
+            return_dict (`bool`, *optional*):
+                Whether or not to return a [`~utils.ModelOutput`] instead of a plain tuple.
+        """
+        output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions
+        output_hidden_states = output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states
+        return_dict = return_dict if return_dict is not None else self.config.use_return_dict
+
+        encoder_states = () if output_hidden_states else None
+        all_attentions = () if output_attentions else None
+
+        hidden_states = inputs_embeds
+        for encoder_layer in self.layers:
+            if output_hidden_states:
+                encoder_states = encoder_states + (hidden_states,)
+            if self.gradient_checkpointing and self.training:
+                layer_outputs = self._gradient_checkpointing_func(
+                    encoder_layer.__call__,
+                    hidden_states,
+                    attention_mask,
+                    output_attentions,
+                )
+            else:
+                layer_outputs = encoder_layer(
+                    hidden_states,
+                    attention_mask,
+                    output_attentions=output_attentions,
+                )
+
+            hidden_states = layer_outputs[0]
+
+            if output_attentions:
+                all_attentions = all_attentions + (layer_outputs[1],)
+
+        if output_hidden_states:
+            encoder_states = encoder_states + (hidden_states,)
+
+        if not return_dict:
+            return tuple(v for v in [hidden_states, encoder_states, all_attentions] if v is not None)
+        return BaseModelOutput(last_hidden_state=hidden_states, hidden_states=encoder_states, attentions=all_attentions)
+
+
+class SigLipVisionTransformer(nn.Module):
+    def __init__(self, config: SigLipVisionConfig):
+        super().__init__()
+        self.config = config
+        embed_dim = config.hidden_size
+
+        self.embeddings = SigLipVisionEmbeddings(config)
+        self.encoder = SigLipEncoder(config)
+        self.post_layernorm = nn.LayerNorm(embed_dim, eps=config.layer_norm_eps)
+        self.head = SigLipMultiheadAttentionPoolingHead(config)
+
+    def forward(
+        self,
+        pixel_values,
+        output_attentions: Optional[bool] = None,
+        output_hidden_states: Optional[bool] = None,
+        return_dict: Optional[bool] = None,
+    ) -> Union[Tuple, BaseModelOutputWithPooling]:
+        r"""
+        Returns:
+
+        """
+        output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions
+        output_hidden_states = output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states
+        return_dict = return_dict if return_dict is not None else self.config.use_return_dict
+
+        hidden_states = self.embeddings(pixel_values)
+
+        encoder_outputs = self.encoder(
+            inputs_embeds=hidden_states,
+            output_attentions=output_attentions,
+            output_hidden_states=output_hidden_states,
+            return_dict=return_dict,
+        )
+
+        last_hidden_state = encoder_outputs[0]
+        last_hidden_state = self.post_layernorm(last_hidden_state)
+
+        pooled_output = self.head(last_hidden_state)
+
+        if not return_dict:
+            return (last_hidden_state, pooled_output) + encoder_outputs[1:]
+
+        return BaseModelOutputWithPooling(
+            last_hidden_state=last_hidden_state,
+            pooler_output=pooled_output,
+            hidden_states=encoder_outputs.hidden_states,
+            attentions=encoder_outputs.attentions,
+        )
+
+
+class SigLipMultiheadAttentionPoolingHead(nn.Module):
+    """Multihead Attention Pooling."""
+
+    def __init__(self, config: SigLipVisionConfig):
+        super().__init__()
+
+        self.probe = nn.Parameter(torch.randn(1, 1, config.hidden_size))
+        self.attention = torch.nn.MultiheadAttention(config.hidden_size, config.num_attention_heads, batch_first=True)
+        self.layernorm = nn.LayerNorm(config.hidden_size, eps=config.layer_norm_eps)
+        self.mlp = SigLipMLP(config)
+
+    def forward(self, hidden_state):
+        batch_size = hidden_state.shape[0]
+        probe = self.probe.repeat(batch_size, 1, 1)
+
+        hidden_state = self.attention(probe, hidden_state, hidden_state)[0]
+
+        residual = hidden_state
+        hidden_state = self.layernorm(hidden_state)
+        hidden_state = residual + self.mlp(hidden_state)
+
+        return hidden_state[:, 0]
+
+
+class SigLipVisionModel(SigLipPreTrainedModel):
+    config_class = SigLipVisionConfig
+    main_input_name = "pixel_values"
+    _no_split_modules = ["SigLipEncoderLayer"]
+
+    def __init__(self, config: SigLipVisionConfig):
+        super().__init__(config)
+
+        self.vision_model = SigLipVisionTransformer(config)
+
+        # Initialize weights and apply final processing
+        self.post_init()
+
+    def get_input_embeddings(self) -> nn.Module:
+        return self.vision_model.embeddings.patch_embedding
+
+    def forward(
+        self,
+        pixel_values,
+        output_attentions: Optional[bool] = None,
+        output_hidden_states: Optional[bool] = None,
+        return_dict: Optional[bool] = None,
+    ) -> Union[Tuple, BaseModelOutputWithPooling]:
+        r"""
+        Returns:
+
+        Examples:
+
+        ```python
+        >>> from PIL import Image
+        >>> import requests
+        >>> from transformers import AutoProcessor, SigLipVisionModel
+
+        >>> model = SigLipVisionModel.from_pretrained("google/siglip-base-patch16-224")
+        >>> processor = AutoProcessor.from_pretrained("google/siglip-base-patch16-224")
+
+        >>> url = "http://images.cocodataset.org/val2017/000000039769.jpg"
+        >>> image = Image.open(requests.get(url, stream=True).raw)
+
+        >>> inputs = processor(images=image, return_tensors="pt")
+
+        >>> outputs = model(**inputs)
+        >>> last_hidden_state = outputs.last_hidden_state
+        >>> pooled_output = outputs.pooler_output  # pooled features
+        ```"""
+        return_dict = return_dict if return_dict is not None else self.config.use_return_dict
+
+        return self.vision_model(
+            pixel_values=pixel_values.to(self.device),
+            output_attentions=output_attentions,
+            output_hidden_states=output_hidden_states,
+            return_dict=return_dict,
+        )
+
+
+class SigLipVisionTower(nn.Module):
+    def __init__(self, vision_tower, vision_tower_cfg, delay_load=False):
+        super().__init__()
+
+        self.is_loaded = False
+
+        self.config = SigLipVisionConfig()
+
+        self.vision_tower_name = vision_tower
+
+        self.image_processor = SigLipImageProcessor()
+
+        if not delay_load:
+            rank0_print(f"Loading vision tower: {vision_tower}")
+            self.load_model()
+        elif getattr(vision_tower_cfg, "unfreeze_mm_vision_tower", False):
+            # TODO: better detector is needed.
+            rank0_print(f"The checkpoint seems to contain `vision_tower` weights: `unfreeze_mm_vision_tower`: True.")
+            self.load_model()
+        elif hasattr(vision_tower_cfg, "mm_tunable_parts") and "mm_vision_tower" in vision_tower_cfg.mm_tunable_parts:
+            rank0_print(f"The checkpoint seems to contain `vision_tower` weights: `mm_tunable_parts` contains `mm_vision_tower`.")
+            self.load_model()
+        else:
+            self.cfg_only = self.config
+
+    def load_model(self, device_map=None):
+        if self.is_loaded:
+            return
+
+        self.vision_tower = SigLipVisionModel.from_pretrained(self.vision_tower_name, device_map=device_map)
+
+        del self.vision_tower.vision_model.encoder.layers[-1:]
+        self.vision_tower.vision_model.head = nn.Identity()
+        self.vision_tower.requires_grad_(False)
+        self.vision_tower.eval()
+
+        self.is_loaded = True
+
+    @torch.no_grad()
+    def forward(self, images):
+        if type(images) is list:
+            image_features = []
+            for image in images:
+                image_forward_out = self.vision_tower(image.to(device=self.device, dtype=self.dtype).unsqueeze(0), output_hidden_states=True)
+                image_feature = image_forward_out.hidden_states[-1].to(image.dtype)
+                assert image_feature.shape[-2] == 729
+                image_features.append(image_feature)
+        else:
+            images = images.to(device=self.device, dtype=self.dtype)
+            image_forward_outs = self.vision_tower(images, output_hidden_states=True)
+            image_features = image_forward_outs.hidden_states[-1].to(images.dtype)
+            assert image_features.shape[-2] == 729
+
+        return image_features
+
+    @property
+    def dummy_feature(self):
+        return torch.zeros(1, self.hidden_size, device=self.device, dtype=self.dtype)
+
+    @property
+    def dtype(self):
+        for p in self.vision_tower.parameters():
+            return p.dtype
+
+    @property
+    def device(self):
+        for p in self.vision_tower.parameters():
+            return p.device
+
+    @property
+    def hidden_size(self):
+        return self.config.hidden_size
+
+    @property
+    def num_patches(self):
+        return (self.config.image_size // self.config.patch_size) ** 2
+
+    @property
+    def num_patches_per_side(self):
+        return self.config.image_size // self.config.patch_size
+        # return self.model_config["vision_cfg"]["image_size"] // self.model_config["vision_cfg"]["patch_size"]
+
+    @property
+    def image_size(self):
+        return self.config.image_size
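A short illustrative sketch of why `forward()` above asserts 729 tokens: with the default `SigLipVisionConfig` (image_size 384, patch_size 14) the patch grid is 27x27 and there is no CLS token; `load_model()` additionally drops the final encoder layer and replaces the pooling head with `nn.Identity()`, so `hidden_states[-1]` holds the 729 patch embeddings from the last remaining layer.

from llava.model.multimodal_encoder.siglip_encoder import SigLipVisionConfig

cfg = SigLipVisionConfig()                        # image_size=384, patch_size=14
tokens = (cfg.image_size // cfg.patch_size) ** 2  # 27 * 27
print(tokens)                                     # 729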
llava/model/multimodal_projector/builder.py ADDED
@@ -0,0 +1,65 @@
1
+ import torch
2
+ import torch.nn as nn
3
+ import re
4
+
5
+ from .pooler_projector import PoolerProjector
6
+
7
+
8
+ class IdentityMap(nn.Module):
9
+ def __init__(self):
10
+ super().__init__()
11
+
12
+ def forward(self, x, *args, **kwargs):
13
+ return x
14
+
15
+ @property
16
+ def config(self):
17
+ return {"mm_projector_type": "identity"}
18
+
19
+
20
+ class SimpleResBlock(nn.Module):
21
+ def __init__(self, channels):
22
+ super().__init__()
23
+ self.pre_norm = nn.LayerNorm(channels)
24
+
25
+ self.proj = nn.Sequential(nn.Linear(channels, channels), nn.GELU(), nn.Linear(channels, channels))
26
+
27
+ def forward(self, x):
28
+ x = self.pre_norm(x)
29
+ return x + self.proj(x)
30
+
31
+
32
+ def build_vision_projector(config, delay_load=False, **kwargs):
33
+ projector_type = getattr(config, "mm_projector_type", "linear")
34
+
35
+ if projector_type == "linear":
36
+ return nn.Linear(config.mm_hidden_size, config.hidden_size)
37
+
38
+ if projector_type == "pooler":
39
+ return PoolerProjector(config, kwargs["vision_cfg"])
40
+
41
+ mlp_gelu_match = re.match(r"^mlp(\d+)x_gelu$", projector_type)
42
+ if mlp_gelu_match:
43
+ mlp_depth = int(mlp_gelu_match.group(1))
44
+ modules = [nn.Linear(config.mm_hidden_size, config.hidden_size)]
45
+ for _ in range(1, mlp_depth):
46
+ modules.append(nn.GELU())
47
+ modules.append(nn.Linear(config.hidden_size, config.hidden_size))
48
+ return nn.Sequential(*modules)
49
+
50
+ mlp_gelu_resnet_match = re.match(r"^mlp(\d+)x_res(\d+)x_gelu$", projector_type)
51
+ if mlp_gelu_resnet_match:
52
+ mlp_depth = int(mlp_gelu_resnet_match.group(1))
53
+ res_depth = int(mlp_gelu_resnet_match.group(2))
54
+ modules = [nn.Linear(config.mm_hidden_size, config.hidden_size)]
55
+ for _ in range(1, mlp_depth):
56
+ modules.append(nn.GELU())
57
+ modules.append(nn.Linear(config.hidden_size, config.hidden_size))
58
+ for _ in range(res_depth):
59
+ modules.append(SimpleResBlock(config.hidden_size))
60
+ return nn.Sequential(*modules)
61
+
62
+ if projector_type == "identity":
63
+ return IdentityMap()
64
+
65
+ raise ValueError(f"Unknown projector type: {projector_type}")
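A minimal sketch of building a projector from this factory, assuming hypothetical sizes (1152-dimensional vision features projected into a 4096-dimensional language model):

import types
import torch

cfg = types.SimpleNamespace(mm_projector_type="mlp2x_gelu", mm_hidden_size=1152, hidden_size=4096)
projector = build_vision_projector(cfg)   # Linear(1152, 4096) -> GELU -> Linear(4096, 4096)

image_features = torch.randn(2, 729, cfg.mm_hidden_size)
print(projector(image_features).shape)    # torch.Size([2, 729, 4096])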
llava/model/multimodal_projector/pooler_projector.py ADDED
@@ -0,0 +1,33 @@
1
+ import torch
2
+ import torch.nn as nn
3
+
4
+ import math
5
+
6
+ from transformers.models.clip.modeling_clip import CLIPVisionModel
7
+
8
+
9
+ class PoolerProjector(nn.Module):
10
+ def __init__(self, config, vision_cfg):
11
+ super().__init__()
12
+ self._config = config
13
+ self.hw = vision_cfg.image_size // vision_cfg.patch_size
14
+
15
+ self.conv_pool = nn.Conv2d(config.mm_hidden_size, config.hidden_size, kernel_size=2, stride=2)
16
+
17
+ self.proj = nn.Sequential(
18
+ nn.GELU(),
19
+ nn.Linear(config.hidden_size, config.hidden_size),
20
+ )
21
+
22
+ def forward(self, x, *args, **kwargs):
23
+ height = width = self.hw
24
+ assert height * width == x.shape[1]
25
+ x = x.view(x.shape[0], height, width, -1).permute(0, 3, 1, 2)
26
+ x = self.conv_pool(x)
27
+ x = x.flatten(2).transpose(1, 2)
28
+ x = self.proj(x)
29
+ return x
30
+
31
+ @property
32
+ def config(self):
33
+ return {"mm_projector_type": "pooler"}
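A shape walk-through of PoolerProjector under assumed sizes (a 336/14 CLIP-style grid of 24x24 tokens, 1024-dimensional features, 4096-dimensional LM hidden size); the 2x2 strided convolution halves each spatial side, so 576 tokens become 144:

import types
import torch

vision_cfg = types.SimpleNamespace(image_size=336, patch_size=14)   # hypothetical vision grid
cfg = types.SimpleNamespace(mm_hidden_size=1024, hidden_size=4096)  # hypothetical widths
proj = PoolerProjector(cfg, vision_cfg)

x = torch.randn(2, 24 * 24, 1024)   # (batch, tokens, mm_hidden_size)
print(proj(x).shape)                # torch.Size([2, 144, 4096]) -- 12 x 12 tokens after pooling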
llava/model/multimodal_resampler/builder.py ADDED
@@ -0,0 +1,34 @@
1
+ import torch
2
+
3
+ from .masked_drop import MaskedDrop
4
+ from .spatial_pool import SpatialPool
5
+ from .perceiver import PerceiverResampler
6
+ from .qformer import Qformer
7
+
8
+
9
+ class IdentityMap(torch.nn.Module):
10
+ def __init__(self):
11
+ super().__init__()
12
+
13
+ def forward(self, x, *args, **kwargs):
14
+ return x
15
+
16
+ @property
17
+ def config(self):
18
+ return {"mm_resampler_type": None}
19
+
20
+
21
+ def build_vision_resampler(model_args, delay_load=False, **kwargs):
22
+ resampler_type = getattr(model_args, "mm_resampler_type", None)
23
+ if resampler_type == "masked_drop":
24
+ return MaskedDrop(model_args)
25
+ elif resampler_type == "spatial_pool":
26
+ return SpatialPool(model_args, **kwargs)
27
+ elif resampler_type == "perceiver":
28
+ return PerceiverResampler(model_args, **kwargs)
29
+ elif resampler_type == "qformer":
30
+ return Qformer(model_args, **kwargs)
31
+ elif resampler_type is None:
32
+ return IdentityMap()
33
+
34
+ raise ValueError(f"Unknown resampler type: {resampler_type}")
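A minimal sketch of the resampler factory. When `mm_resampler_type` is unset the builder falls back to the pass-through IdentityMap:

import types
import torch

args = types.SimpleNamespace()            # hypothetical model args with no mm_resampler_type
resampler = build_vision_resampler(args)

features = torch.randn(2, 729, 1152)
assert resampler(features) is features    # IdentityMap returns its input unchanged
print(resampler.config)                   # {'mm_resampler_type': None}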
llava/model/multimodal_resampler/masked_drop.py ADDED
@@ -0,0 +1,80 @@
1
+ import torch
2
+ import torch.nn as nn
3
+
4
+ import random
5
+
6
+
7
+ class MaskedDrop(nn.Module):
8
+ def __init__(self, model_args):
9
+ super().__init__()
10
+
11
+ self.mode = model_args.mm_mask_drop_mode
12
+ self.skip_percentage = model_args.mm_mask_drop_skip_percentage
13
+ self.ratio = model_args.mm_mask_drop_ratio
14
+ self.ratio_upper = model_args.mm_mask_drop_ratio_upper
15
+ self.ratio_lower = model_args.mm_mask_drop_ratio_lower
16
+
17
+ def forward(self, image_features, *args, **kwargs):
18
+
19
+ if not self.training:
20
+ return image_features
21
+
22
+ if self.skip_percentage > random.random():
23
+ return image_features
24
+
25
+ masked_features = []
26
+
27
+ for image_feature in image_features:
28
+ num_tokens = image_feature.shape[0]
29
+ if self.mode == "fixed":
30
+ num_keep = int(num_tokens * self.ratio)
31
+ masked_features.append(self.random_masking(image_feature.unsqueeze(0), num_keep)[0][0])
32
+ elif self.mode == "range":
33
+ num_keep = int(num_tokens * random.uniform(self.ratio_lower, self.ratio_upper))
34
+ masked_features.append(self.random_masking(image_feature.unsqueeze(0), num_keep)[0])
35
+ elif self.mode == "cls_only":
36
+ masked_features.append(image_feature[0:1])
37
+ else:
38
+ raise ValueError(f"Unexpected masked drop mode: {self.mode}")
39
+
40
+ if self.mode not in ["range"] and (type(image_features) is not list or self.mode in ["cls_only"]):
41
+ masked_features = torch.stack(masked_features, dim=0)
42
+
43
+ return masked_features
44
+
45
+ @property
46
+ def config(self):
47
+ return {
48
+ "mm_resampler_type": "masked_drop",
49
+ "mm_mask_drop_mode": self.mode,
50
+ "mm_mask_drop_skip_percentage": self.skip_percentage,
51
+ "mm_mask_drop_ratio": self.ratio,
52
+ "mm_mask_drop_ratio_upper": self.ratio_upper,
53
+ "mm_mask_drop_ratio_lower": self.ratio_lower,
54
+ }
55
+
56
+ def random_masking(self, x, len_keep):
57
+ """
58
+ Perform per-sample random masking by per-sample shuffling.
59
+ Per-sample shuffling is done by argsort random noise.
60
+ x: [N, L, D], sequence
61
+ """
62
+ N, L, D = x.shape # batch, length, dim
63
+
64
+ noise = torch.rand(N, L, device=x.device) # noise in [0, 1]
65
+
66
+ # sort noise for each sample
67
+ ids_shuffle = torch.argsort(noise, dim=1) # ascend: small is keep, large is remove
68
+ ids_restore = torch.argsort(ids_shuffle, dim=1)
69
+
70
+ # keep the first subset
71
+ ids_keep = ids_shuffle[:, :len_keep]
72
+ x_masked = torch.gather(x, dim=1, index=ids_keep.unsqueeze(-1).repeat(1, 1, D))
73
+
74
+ # generate the binary mask: 0 is keep, 1 is remove
75
+ mask = torch.ones([N, L], device=x.device)
76
+ mask[:, :len_keep] = 0
77
+ # unshuffle to get the binary mask
78
+ mask = torch.gather(mask, dim=1, index=ids_restore)
79
+
80
+ return x_masked, mask, ids_restore
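A small sketch of MaskedDrop.random_masking with assumed sizes, showing the three returned tensors (kept subset, binary drop mask, and indices to undo the shuffle):

import types
import torch

# Hypothetical args: fixed mode keeping 50% of the visual tokens during training.
args = types.SimpleNamespace(
    mm_mask_drop_mode="fixed",
    mm_mask_drop_skip_percentage=0.0,
    mm_mask_drop_ratio=0.5,
    mm_mask_drop_ratio_upper=None,
    mm_mask_drop_ratio_lower=None,
)
drop = MaskedDrop(args)

x = torch.randn(1, 8, 16)                              # (N, L, D)
x_masked, mask, ids_restore = drop.random_masking(x, 4)
print(x_masked.shape, mask.shape, ids_restore.shape)   # (1, 4, 16) (1, 8) (1, 8)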
llava/model/multimodal_resampler/perceiver.py ADDED
@@ -0,0 +1,155 @@
1
+ """
2
+ Taken from https://github.com/lucidrains/flamingo-pytorch
3
+ """
4
+
5
+ import torch
6
+ from einops import rearrange, repeat
7
+
8
+ try:
9
+ from einops_exts import rearrange_many
10
+ except ImportError:  # einops_exts is optional; rearrange_many is only needed by PerceiverAttention
11
+ pass
12
+
13
+ from torch import einsum, nn
14
+
15
+
16
+ def exists(val):
17
+ return val is not None
18
+
19
+
20
+ def FeedForward(dim, mult=4):
21
+ inner_dim = int(dim * mult)
22
+ return nn.Sequential(
23
+ nn.LayerNorm(dim),
24
+ nn.Linear(dim, inner_dim, bias=False),
25
+ nn.GELU(),
26
+ nn.Linear(inner_dim, dim, bias=False),
27
+ )
28
+
29
+
30
+ class PerceiverAttention(nn.Module):
31
+ def __init__(self, *, dim, dim_head=64, heads=8):
32
+ super().__init__()
33
+ self.scale = dim_head**-0.5
34
+ self.heads = heads
35
+ inner_dim = dim_head * heads
36
+
37
+ self.norm_media = nn.LayerNorm(dim)
38
+ self.norm_latents = nn.LayerNorm(dim)
39
+
40
+ self.to_q = nn.Linear(dim, inner_dim, bias=False)
41
+ self.to_kv = nn.Linear(dim, inner_dim * 2, bias=False)
42
+ self.to_out = nn.Linear(inner_dim, dim, bias=False)
43
+
44
+ def forward(self, x, latents):
45
+ """
46
+ Args:
47
+ x (torch.Tensor): image features
48
+ shape (b, T, n1, D)
49
+ latents (torch.Tensor): latent features
50
+ shape (b, T, n2, D)
51
+ """
52
+ x = self.norm_media(x)
53
+ latents = self.norm_latents(latents)
54
+
55
+ h = self.heads
56
+
57
+ q = self.to_q(latents)
58
+ kv_input = torch.cat((x, latents), dim=-2)
59
+ k, v = self.to_kv(kv_input).chunk(2, dim=-1)
60
+ q, k, v = rearrange_many((q, k, v), "b t n (h d) -> b h t n d", h=h)
61
+ q = q * self.scale
62
+
63
+ # attention
64
+ sim = einsum("... i d, ... j d -> ... i j", q, k)
65
+ sim = sim - sim.amax(dim=-1, keepdim=True).detach()
66
+ attn = sim.softmax(dim=-1)
67
+
68
+ out = einsum("... i j, ... j d -> ... i d", attn, v)
69
+ out = rearrange(out, "b h t n d -> b t n (h d)", h=h)
70
+ return self.to_out(out)
71
+
72
+
73
+ class PerceiverResamplerModule(nn.Module):
74
+ def __init__(
75
+ self,
76
+ *,
77
+ dim,
78
+ depth=6,
79
+ dim_head=64,
80
+ heads=8,
81
+ num_latents=64,
82
+ max_num_media=None,
83
+ max_num_frames=None,
84
+ ff_mult=4,
85
+ ):
86
+ super().__init__()
87
+ self.latents = nn.Parameter(torch.randn(num_latents, dim))
88
+ self.frame_embs = nn.Parameter(torch.randn(max_num_frames, dim)) if exists(max_num_frames) else None
89
+ self.media_time_embs = nn.Parameter(torch.randn(max_num_media, 1, dim)) if exists(max_num_media) else None
90
+
91
+ self.layers = nn.ModuleList([])
92
+ for _ in range(depth):
93
+ self.layers.append(
94
+ nn.ModuleList(
95
+ [
96
+ PerceiverAttention(dim=dim, dim_head=dim_head, heads=heads),
97
+ FeedForward(dim=dim, mult=ff_mult) if ff_mult > 0 else nn.Identity(),
98
+ ]
99
+ )
100
+ )
101
+
102
+ self.norm = nn.LayerNorm(dim)
103
+
104
+ def forward(self, x):
105
+ """
106
+ Args:
107
+ x (torch.Tensor): image features
108
+ shape (b, T, F, v, D)
109
+ Returns:
110
+ shape (b, T, n, D) where n is self.num_latents
111
+ """
112
+ b, T, F, v = x.shape[:4]
113
+
114
+ # frame and media time embeddings
115
+ if exists(self.frame_embs):
116
+ frame_embs = repeat(self.frame_embs[:F], "F d -> b T F v d", b=b, T=T, v=v)
117
+ x = x + frame_embs
118
+ x = rearrange(x, "b T F v d -> b T (F v) d") # flatten the frame and spatial dimensions
119
+ if exists(self.media_time_embs):
120
+ x = x + self.media_time_embs[:T]
121
+
122
+ # blocks
123
+ latents = repeat(self.latents, "n d -> b T n d", b=b, T=T)
124
+ for attn, ff in self.layers:
125
+ latents = attn(x, latents) + latents
126
+ latents = ff(latents) + latents
127
+ return self.norm(latents)
128
+
129
+
130
+ class PerceiverResampler(nn.Module):
131
+ def __init__(self, model_args, vision_tower):
132
+ super().__init__()
133
+
134
+ self.depth = model_args.mm_perceiver_depth
135
+ self.num_latents = model_args.mm_perceiver_latents
136
+ self.ff_mult = model_args.mm_perceiver_ff_mult
137
+ self.pretrained = model_args.mm_perceiver_pretrained
138
+
139
+ self.perceiver = PerceiverResamplerModule(dim=vision_tower.hidden_size, depth=self.depth, num_latents=self.num_latents, ff_mult=self.ff_mult)
140
+
141
+ if self.pretrained is not None:
142
+ self.load_state_dict(torch.load(self.pretrained))
143
+
144
+ def forward(self, image_features, *args, **kwargs):
145
+ return self.perceiver(image_features[:, None, None]).squeeze(1)
146
+
147
+ @property
148
+ def config(self):
149
+ return {
150
+ "mm_resampler_type": "perceiver",
151
+ "mm_perceiver_depth": self.depth,
152
+ "mm_perceiver_latents": self.num_latents,
153
+ "mm_perceiver_ff_mult": self.ff_mult,
154
+ "mm_perceiver_pretrained": self.pretrained,
155
+ }
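A minimal sketch of the resampler module above (requires `einops` and `einops_exts`); the latents cross-attend to the visual tokens, so 729 patch features are compressed to `num_latents` outputs regardless of the input length. Sizes are illustrative assumptions:

import torch

module = PerceiverResamplerModule(dim=1152, depth=2, num_latents=32)  # hypothetical sizes

x = torch.randn(2, 1, 1, 729, 1152)   # (batch, media, frames, tokens, dim)
print(module(x).shape)                # torch.Size([2, 1, 32, 1152])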
llava/model/multimodal_resampler/qformer.py ADDED
@@ -0,0 +1,1160 @@
1
+ """
2
+ * Copyright (c) 2023, salesforce.com, inc.
3
+ * All rights reserved.
4
+ * SPDX-License-Identifier: BSD-3-Clause
5
+ * For full license text, see LICENSE.txt file in the repo root or https://opensource.org/licenses/BSD-3-Clause
6
+ * By Junnan Li
7
+ * Based on huggingface code base
8
+ * https://github.com/huggingface/transformers/blob/v4.15.0/src/transformers/models/bert
9
+ """
10
+
11
+ import math
12
+ import os
13
+ import warnings
14
+ from dataclasses import dataclass
15
+ from typing import Optional, Tuple, Dict, Any
16
+
17
+ import torch
18
+ from torch import Tensor, device, dtype, nn
19
+ import torch.utils.checkpoint
20
+ from torch import nn
21
+ from torch.nn import CrossEntropyLoss
22
+ import torch.nn.functional as F
23
+
24
+ from transformers.activations import ACT2FN
25
+ from transformers.file_utils import (
26
+ ModelOutput,
27
+ )
28
+ from transformers.modeling_outputs import (
29
+ BaseModelOutputWithPastAndCrossAttentions,
30
+ BaseModelOutputWithPoolingAndCrossAttentions,
31
+ CausalLMOutputWithCrossAttentions,
32
+ MaskedLMOutput,
33
+ MultipleChoiceModelOutput,
34
+ NextSentencePredictorOutput,
35
+ QuestionAnsweringModelOutput,
36
+ SequenceClassifierOutput,
37
+ TokenClassifierOutput,
38
+ )
39
+ from transformers.modeling_utils import (
40
+ PreTrainedModel,
41
+ apply_chunking_to_forward,
42
+ find_pruneable_heads_and_indices,
43
+ prune_linear_layer,
44
+ )
45
+ from transformers.utils import logging
46
+ from transformers.models.bert.configuration_bert import BertConfig
47
+
48
+ logger = logging.get_logger(__name__)
49
+
50
+
51
+ def disabled_train(self, mode=True):
52
+ """Overwrite model.train with this function to make sure train/eval mode
53
+ does not change anymore."""
54
+ return self
55
+
56
+
57
+ class BertEmbeddings(nn.Module):
58
+ """Construct the embeddings from word and position embeddings."""
59
+
60
+ def __init__(self, config):
61
+ super().__init__()
62
+ self.word_embeddings = nn.Embedding(config.vocab_size, config.hidden_size, padding_idx=config.pad_token_id)
63
+ self.position_embeddings = nn.Embedding(config.max_position_embeddings, config.hidden_size)
64
+
65
+ # self.LayerNorm is not snake-cased to stick with TensorFlow model variable name and be able to load
66
+ # any TensorFlow checkpoint file
67
+ self.LayerNorm = nn.LayerNorm(config.hidden_size, eps=config.layer_norm_eps)
68
+ self.dropout = nn.Dropout(config.hidden_dropout_prob)
69
+
70
+ # position_ids (1, len position emb) is contiguous in memory and exported when serialized
71
+ self.register_buffer("position_ids", torch.arange(config.max_position_embeddings).expand((1, -1)))
72
+ self.position_embedding_type = getattr(config, "position_embedding_type", "absolute")
73
+
74
+ self.config = config
75
+
76
+ def forward(
77
+ self,
78
+ input_ids=None,
79
+ position_ids=None,
80
+ query_embeds=None,
81
+ past_key_values_length=0,
82
+ ):
83
+ if input_ids is not None:
84
+ seq_length = input_ids.size()[1]
85
+ else:
86
+ seq_length = 0
87
+
88
+ if position_ids is None:
89
+ position_ids = self.position_ids[:, past_key_values_length : seq_length + past_key_values_length].clone()
90
+
91
+ if input_ids is not None:
92
+ embeddings = self.word_embeddings(input_ids)
93
+ if self.position_embedding_type == "absolute":
94
+ position_embeddings = self.position_embeddings(position_ids)
95
+ embeddings = embeddings + position_embeddings
96
+
97
+ if query_embeds is not None:
98
+ embeddings = torch.cat((query_embeds, embeddings), dim=1)
99
+ else:
100
+ embeddings = query_embeds
101
+
102
+ embeddings = self.LayerNorm(embeddings)
103
+ embeddings = self.dropout(embeddings)
104
+ return embeddings
105
+
106
+
107
+ class BertSelfAttention(nn.Module):
108
+ def __init__(self, config, is_cross_attention):
109
+ super().__init__()
110
+ self.config = config
111
+ if config.hidden_size % config.num_attention_heads != 0 and not hasattr(config, "embedding_size"):
112
+ raise ValueError("The hidden size (%d) is not a multiple of the number of attention " "heads (%d)" % (config.hidden_size, config.num_attention_heads))
113
+
114
+ self.num_attention_heads = config.num_attention_heads
115
+ self.attention_head_size = int(config.hidden_size / config.num_attention_heads)
116
+ self.all_head_size = self.num_attention_heads * self.attention_head_size
117
+
118
+ self.query = nn.Linear(config.hidden_size, self.all_head_size)
119
+ if is_cross_attention:
120
+ self.key = nn.Linear(config.encoder_width, self.all_head_size)
121
+ self.value = nn.Linear(config.encoder_width, self.all_head_size)
122
+ else:
123
+ self.key = nn.Linear(config.hidden_size, self.all_head_size)
124
+ self.value = nn.Linear(config.hidden_size, self.all_head_size)
125
+
126
+ self.dropout = nn.Dropout(config.attention_probs_dropout_prob)
127
+ self.position_embedding_type = getattr(config, "position_embedding_type", "absolute")
128
+ if self.position_embedding_type == "relative_key" or self.position_embedding_type == "relative_key_query":
129
+ self.max_position_embeddings = config.max_position_embeddings
130
+ self.distance_embedding = nn.Embedding(2 * config.max_position_embeddings - 1, self.attention_head_size)
131
+ self.save_attention = False
132
+
133
+ def save_attn_gradients(self, attn_gradients):
134
+ self.attn_gradients = attn_gradients
135
+
136
+ def get_attn_gradients(self):
137
+ return self.attn_gradients
138
+
139
+ def save_attention_map(self, attention_map):
140
+ self.attention_map = attention_map
141
+
142
+ def get_attention_map(self):
143
+ return self.attention_map
144
+
145
+ def transpose_for_scores(self, x):
146
+ new_x_shape = x.size()[:-1] + (
147
+ self.num_attention_heads,
148
+ self.attention_head_size,
149
+ )
150
+ x = x.view(*new_x_shape)
151
+ return x.permute(0, 2, 1, 3)
152
+
153
+ def forward(
154
+ self,
155
+ hidden_states,
156
+ attention_mask=None,
157
+ head_mask=None,
158
+ encoder_hidden_states=None,
159
+ encoder_attention_mask=None,
160
+ past_key_value=None,
161
+ output_attentions=False,
162
+ ):
163
+
164
+ # If this is instantiated as a cross-attention module, the keys
165
+ # and values come from an encoder; the attention mask needs to be
166
+ # such that the encoder's padding tokens are not attended to.
167
+ is_cross_attention = encoder_hidden_states is not None
168
+
169
+ if is_cross_attention:
170
+ key_layer = self.transpose_for_scores(self.key(encoder_hidden_states))
171
+ value_layer = self.transpose_for_scores(self.value(encoder_hidden_states))
172
+ attention_mask = encoder_attention_mask
173
+ elif past_key_value is not None:
174
+ key_layer = self.transpose_for_scores(self.key(hidden_states))
175
+ value_layer = self.transpose_for_scores(self.value(hidden_states))
176
+ key_layer = torch.cat([past_key_value[0], key_layer], dim=2)
177
+ value_layer = torch.cat([past_key_value[1], value_layer], dim=2)
178
+ else:
179
+ key_layer = self.transpose_for_scores(self.key(hidden_states))
180
+ value_layer = self.transpose_for_scores(self.value(hidden_states))
181
+
182
+ mixed_query_layer = self.query(hidden_states)
183
+
184
+ query_layer = self.transpose_for_scores(mixed_query_layer)
185
+
186
+ past_key_value = (key_layer, value_layer)
187
+
188
+ # Take the dot product between "query" and "key" to get the raw attention scores.
189
+ attention_scores = torch.matmul(query_layer, key_layer.transpose(-1, -2))
190
+
191
+ if self.position_embedding_type == "relative_key" or self.position_embedding_type == "relative_key_query":
192
+ seq_length = hidden_states.size()[1]
193
+ position_ids_l = torch.arange(seq_length, dtype=torch.long, device=hidden_states.device).view(-1, 1)
194
+ position_ids_r = torch.arange(seq_length, dtype=torch.long, device=hidden_states.device).view(1, -1)
195
+ distance = position_ids_l - position_ids_r
196
+ positional_embedding = self.distance_embedding(distance + self.max_position_embeddings - 1)
197
+ positional_embedding = positional_embedding.to(dtype=query_layer.dtype) # fp16 compatibility
198
+
199
+ if self.position_embedding_type == "relative_key":
200
+ relative_position_scores = torch.einsum("bhld,lrd->bhlr", query_layer, positional_embedding)
201
+ attention_scores = attention_scores + relative_position_scores
202
+ elif self.position_embedding_type == "relative_key_query":
203
+ relative_position_scores_query = torch.einsum("bhld,lrd->bhlr", query_layer, positional_embedding)
204
+ relative_position_scores_key = torch.einsum("bhrd,lrd->bhlr", key_layer, positional_embedding)
205
+ attention_scores = attention_scores + relative_position_scores_query + relative_position_scores_key
206
+
207
+ attention_scores = attention_scores / math.sqrt(self.attention_head_size)
208
+ if attention_mask is not None:
209
+ # Apply the attention mask is (precomputed for all layers in BertModel forward() function)
210
+ attention_scores = attention_scores + attention_mask
211
+
212
+ # Normalize the attention scores to probabilities.
213
+ attention_probs = nn.Softmax(dim=-1)(attention_scores)
214
+
215
+ if is_cross_attention and self.save_attention:
216
+ self.save_attention_map(attention_probs)
217
+ attention_probs.register_hook(self.save_attn_gradients)
218
+
219
+ # This is actually dropping out entire tokens to attend to, which might
220
+ # seem a bit unusual, but is taken from the original Transformer paper.
221
+ attention_probs_dropped = self.dropout(attention_probs)
222
+
223
+ # Mask heads if we want to
224
+ if head_mask is not None:
225
+ attention_probs_dropped = attention_probs_dropped * head_mask
226
+
227
+ context_layer = torch.matmul(attention_probs_dropped, value_layer)
228
+
229
+ context_layer = context_layer.permute(0, 2, 1, 3).contiguous()
230
+ new_context_layer_shape = context_layer.size()[:-2] + (self.all_head_size,)
231
+ context_layer = context_layer.view(*new_context_layer_shape)
232
+
233
+ outputs = (context_layer, attention_probs) if output_attentions else (context_layer,)
234
+
235
+ outputs = outputs + (past_key_value,)
236
+ return outputs
237
+
238
+
239
+ class BertSelfOutput(nn.Module):
240
+ def __init__(self, config):
241
+ super().__init__()
242
+ self.dense = nn.Linear(config.hidden_size, config.hidden_size)
243
+ self.LayerNorm = nn.LayerNorm(config.hidden_size, eps=config.layer_norm_eps)
244
+ self.dropout = nn.Dropout(config.hidden_dropout_prob)
245
+
246
+ def forward(self, hidden_states, input_tensor):
247
+ hidden_states = self.dense(hidden_states)
248
+ hidden_states = self.dropout(hidden_states)
249
+ hidden_states = self.LayerNorm(hidden_states + input_tensor)
250
+ return hidden_states
251
+
252
+
253
+ class BertAttention(nn.Module):
254
+ def __init__(self, config, is_cross_attention=False):
255
+ super().__init__()
256
+ self.self = BertSelfAttention(config, is_cross_attention)
257
+ self.output = BertSelfOutput(config)
258
+ self.pruned_heads = set()
259
+
260
+ def prune_heads(self, heads):
261
+ if len(heads) == 0:
262
+ return
263
+ heads, index = find_pruneable_heads_and_indices(
264
+ heads,
265
+ self.self.num_attention_heads,
266
+ self.self.attention_head_size,
267
+ self.pruned_heads,
268
+ )
269
+
270
+ # Prune linear layers
271
+ self.self.query = prune_linear_layer(self.self.query, index)
272
+ self.self.key = prune_linear_layer(self.self.key, index)
273
+ self.self.value = prune_linear_layer(self.self.value, index)
274
+ self.output.dense = prune_linear_layer(self.output.dense, index, dim=1)
275
+
276
+ # Update hyper params and store pruned heads
277
+ self.self.num_attention_heads = self.self.num_attention_heads - len(heads)
278
+ self.self.all_head_size = self.self.attention_head_size * self.self.num_attention_heads
279
+ self.pruned_heads = self.pruned_heads.union(heads)
280
+
281
+ def forward(
282
+ self,
283
+ hidden_states,
284
+ attention_mask=None,
285
+ head_mask=None,
286
+ encoder_hidden_states=None,
287
+ encoder_attention_mask=None,
288
+ past_key_value=None,
289
+ output_attentions=False,
290
+ ):
291
+ self_outputs = self.self(
292
+ hidden_states,
293
+ attention_mask,
294
+ head_mask,
295
+ encoder_hidden_states,
296
+ encoder_attention_mask,
297
+ past_key_value,
298
+ output_attentions,
299
+ )
300
+ attention_output = self.output(self_outputs[0], hidden_states)
301
+
302
+ outputs = (attention_output,) + self_outputs[1:] # add attentions if we output them
303
+ return outputs
304
+
305
+
306
+ class BertIntermediate(nn.Module):
307
+ def __init__(self, config):
308
+ super().__init__()
309
+ self.dense = nn.Linear(config.hidden_size, config.intermediate_size)
310
+ if isinstance(config.hidden_act, str):
311
+ self.intermediate_act_fn = ACT2FN[config.hidden_act]
312
+ else:
313
+ self.intermediate_act_fn = config.hidden_act
314
+
315
+ def forward(self, hidden_states):
316
+ hidden_states = self.dense(hidden_states)
317
+ hidden_states = self.intermediate_act_fn(hidden_states)
318
+ return hidden_states
319
+
320
+
321
+ class BertOutput(nn.Module):
322
+ def __init__(self, config):
323
+ super().__init__()
324
+ self.dense = nn.Linear(config.intermediate_size, config.hidden_size)
325
+ self.LayerNorm = nn.LayerNorm(config.hidden_size, eps=config.layer_norm_eps)
326
+ self.dropout = nn.Dropout(config.hidden_dropout_prob)
327
+
328
+ def forward(self, hidden_states, input_tensor):
329
+ hidden_states = self.dense(hidden_states)
330
+ hidden_states = self.dropout(hidden_states)
331
+ hidden_states = self.LayerNorm(hidden_states + input_tensor)
332
+ return hidden_states
333
+
334
+
335
+ class BertLayer(nn.Module):
336
+ def __init__(self, config, layer_num):
337
+ super().__init__()
338
+ self.config = config
339
+ self.chunk_size_feed_forward = config.chunk_size_feed_forward
340
+ self.seq_len_dim = 1
341
+ self.attention = BertAttention(config)
342
+ self.layer_num = layer_num
343
+ if self.config.add_cross_attention and layer_num % self.config.cross_attention_freq == 0:
344
+ self.crossattention = BertAttention(config, is_cross_attention=self.config.add_cross_attention)
345
+ self.has_cross_attention = True
346
+ else:
347
+ self.has_cross_attention = False
348
+ self.intermediate = BertIntermediate(config)
349
+ self.output = BertOutput(config)
350
+
351
+ self.intermediate_query = BertIntermediate(config)
352
+ self.output_query = BertOutput(config)
353
+
354
+ def forward(
355
+ self,
356
+ hidden_states,
357
+ attention_mask=None,
358
+ head_mask=None,
359
+ encoder_hidden_states=None,
360
+ encoder_attention_mask=None,
361
+ past_key_value=None,
362
+ output_attentions=False,
363
+ query_length=0,
364
+ ):
365
+ # decoder uni-directional self-attention cached key/values tuple is at positions 1,2
366
+ self_attn_past_key_value = past_key_value[:2] if past_key_value is not None else None
367
+ self_attention_outputs = self.attention(
368
+ hidden_states,
369
+ attention_mask,
370
+ head_mask,
371
+ output_attentions=output_attentions,
372
+ past_key_value=self_attn_past_key_value,
373
+ )
374
+ attention_output = self_attention_outputs[0]
375
+ outputs = self_attention_outputs[1:-1]
376
+
377
+ present_key_value = self_attention_outputs[-1]
378
+
379
+ if query_length > 0:
380
+ query_attention_output = attention_output[:, :query_length, :]
381
+
382
+ if self.has_cross_attention:
383
+ assert encoder_hidden_states is not None, "encoder_hidden_states must be given for cross-attention layers"
384
+ cross_attention_outputs = self.crossattention(
385
+ query_attention_output,
386
+ attention_mask,
387
+ head_mask,
388
+ encoder_hidden_states,
389
+ encoder_attention_mask,
390
+ output_attentions=output_attentions,
391
+ )
392
+ query_attention_output = cross_attention_outputs[0]
393
+ outputs = outputs + cross_attention_outputs[1:-1] # add cross attentions if we output attention weights
394
+
395
+ layer_output = apply_chunking_to_forward(
396
+ self.feed_forward_chunk_query,
397
+ self.chunk_size_feed_forward,
398
+ self.seq_len_dim,
399
+ query_attention_output,
400
+ )
401
+ if attention_output.shape[1] > query_length:
402
+ layer_output_text = apply_chunking_to_forward(
403
+ self.feed_forward_chunk,
404
+ self.chunk_size_feed_forward,
405
+ self.seq_len_dim,
406
+ attention_output[:, query_length:, :],
407
+ )
408
+ layer_output = torch.cat([layer_output, layer_output_text], dim=1)
409
+ else:
410
+ layer_output = apply_chunking_to_forward(
411
+ self.feed_forward_chunk,
412
+ self.chunk_size_feed_forward,
413
+ self.seq_len_dim,
414
+ attention_output,
415
+ )
416
+ outputs = (layer_output,) + outputs
417
+
418
+ outputs = outputs + (present_key_value,)
419
+
420
+ return outputs
421
+
422
+ def feed_forward_chunk(self, attention_output):
423
+ intermediate_output = self.intermediate(attention_output)
424
+ layer_output = self.output(intermediate_output, attention_output)
425
+ return layer_output
426
+
427
+ def feed_forward_chunk_query(self, attention_output):
428
+ intermediate_output = self.intermediate_query(attention_output)
429
+ layer_output = self.output_query(intermediate_output, attention_output)
430
+ return layer_output
431
+
432
+
433
+ class BertEncoder(nn.Module):
434
+ def __init__(self, config):
435
+ super().__init__()
436
+ self.config = config
437
+ self.layer = nn.ModuleList([BertLayer(config, i) for i in range(config.num_hidden_layers)])
438
+
439
+ def forward(
440
+ self,
441
+ hidden_states,
442
+ attention_mask=None,
443
+ head_mask=None,
444
+ encoder_hidden_states=None,
445
+ encoder_attention_mask=None,
446
+ past_key_values=None,
447
+ use_cache=None,
448
+ output_attentions=False,
449
+ output_hidden_states=False,
450
+ return_dict=True,
451
+ query_length=0,
452
+ ):
453
+ all_hidden_states = () if output_hidden_states else None
454
+ all_self_attentions = () if output_attentions else None
455
+ all_cross_attentions = () if output_attentions and self.config.add_cross_attention else None
456
+
457
+ next_decoder_cache = () if use_cache else None
458
+
459
+ for i in range(self.config.num_hidden_layers):
460
+ layer_module = self.layer[i]
461
+ if output_hidden_states:
462
+ all_hidden_states = all_hidden_states + (hidden_states,)
463
+
464
+ layer_head_mask = head_mask[i] if head_mask is not None else None
465
+ past_key_value = past_key_values[i] if past_key_values is not None else None
466
+
467
+ if getattr(self.config, "gradient_checkpointing", False) and self.training:
468
+
469
+ if use_cache:
470
+ logger.warn("`use_cache=True` is incompatible with gradient checkpointing. Setting `use_cache=False`...")
471
+ use_cache = False
472
+
473
+ def create_custom_forward(module):
474
+ def custom_forward(*inputs):
475
+ return module(*inputs, past_key_value, output_attentions, query_length)
476
+
477
+ return custom_forward
478
+
479
+ layer_outputs = torch.utils.checkpoint.checkpoint(
480
+ create_custom_forward(layer_module),
481
+ hidden_states,
482
+ attention_mask,
483
+ layer_head_mask,
484
+ encoder_hidden_states,
485
+ encoder_attention_mask,
486
+ )
487
+ else:
488
+ layer_outputs = layer_module(
489
+ hidden_states,
490
+ attention_mask,
491
+ layer_head_mask,
492
+ encoder_hidden_states,
493
+ encoder_attention_mask,
494
+ past_key_value,
495
+ output_attentions,
496
+ query_length,
497
+ )
498
+
499
+ hidden_states = layer_outputs[0]
500
+ if use_cache:
501
+ next_decoder_cache += (layer_outputs[-1],)
502
+ if output_attentions:
503
+ all_self_attentions = all_self_attentions + (layer_outputs[1],)
504
+ all_cross_attentions = all_cross_attentions + (layer_outputs[2],)
505
+
506
+ if output_hidden_states:
507
+ all_hidden_states = all_hidden_states + (hidden_states,)
508
+
509
+ if not return_dict:
510
+ return tuple(
511
+ v
512
+ for v in [
513
+ hidden_states,
514
+ next_decoder_cache,
515
+ all_hidden_states,
516
+ all_self_attentions,
517
+ all_cross_attentions,
518
+ ]
519
+ if v is not None
520
+ )
521
+ return BaseModelOutputWithPastAndCrossAttentions(
522
+ last_hidden_state=hidden_states,
523
+ past_key_values=next_decoder_cache,
524
+ hidden_states=all_hidden_states,
525
+ attentions=all_self_attentions,
526
+ cross_attentions=all_cross_attentions,
527
+ )
528
+
529
+
530
+ class BertPooler(nn.Module):
531
+ def __init__(self, config):
532
+ super().__init__()
533
+ self.dense = nn.Linear(config.hidden_size, config.hidden_size)
534
+ self.activation = nn.Tanh()
535
+
536
+ def forward(self, hidden_states):
537
+ # We "pool" the model by simply taking the hidden state corresponding
538
+ # to the first token.
539
+ first_token_tensor = hidden_states[:, 0]
540
+ pooled_output = self.dense(first_token_tensor)
541
+ pooled_output = self.activation(pooled_output)
542
+ return pooled_output
543
+
544
+
545
+ class BertPredictionHeadTransform(nn.Module):
546
+ def __init__(self, config):
547
+ super().__init__()
548
+ self.dense = nn.Linear(config.hidden_size, config.hidden_size)
549
+ if isinstance(config.hidden_act, str):
550
+ self.transform_act_fn = ACT2FN[config.hidden_act]
551
+ else:
552
+ self.transform_act_fn = config.hidden_act
553
+ self.LayerNorm = nn.LayerNorm(config.hidden_size, eps=config.layer_norm_eps)
554
+
555
+ def forward(self, hidden_states):
556
+ hidden_states = self.dense(hidden_states)
557
+ hidden_states = self.transform_act_fn(hidden_states)
558
+ hidden_states = self.LayerNorm(hidden_states)
559
+ return hidden_states
560
+
561
+
562
+ class BertLMPredictionHead(nn.Module):
563
+ def __init__(self, config):
564
+ super().__init__()
565
+ self.transform = BertPredictionHeadTransform(config)
566
+
567
+ # The output weights are the same as the input embeddings, but there is
568
+ # an output-only bias for each token.
569
+ self.decoder = nn.Linear(config.hidden_size, config.vocab_size, bias=False)
570
+
571
+ self.bias = nn.Parameter(torch.zeros(config.vocab_size))
572
+
573
+ # Need a link between the two variables so that the bias is correctly resized with `resize_token_embeddings`
574
+ self.decoder.bias = self.bias
575
+
576
+ def forward(self, hidden_states):
577
+ hidden_states = self.transform(hidden_states)
578
+ hidden_states = self.decoder(hidden_states)
579
+ return hidden_states
580
+
581
+
582
+ class BertOnlyMLMHead(nn.Module):
583
+ def __init__(self, config):
584
+ super().__init__()
585
+ self.predictions = BertLMPredictionHead(config)
586
+
587
+ def forward(self, sequence_output):
588
+ prediction_scores = self.predictions(sequence_output)
589
+ return prediction_scores
590
+
591
+
592
+ class BertPreTrainedModel(PreTrainedModel):
593
+ """
594
+ An abstract class to handle weights initialization and a simple interface for downloading and loading pretrained
595
+ models.
596
+ """
597
+
598
+ config_class = BertConfig
599
+ base_model_prefix = "bert"
600
+ _keys_to_ignore_on_load_missing = [r"position_ids"]
601
+
602
+ def _init_weights(self, module):
603
+ """Initialize the weights"""
604
+ if isinstance(module, (nn.Linear, nn.Embedding)):
605
+ # Slightly different from the TF version which uses truncated_normal for initialization
606
+ # cf https://github.com/pytorch/pytorch/pull/5617
607
+ module.weight.data.normal_(mean=0.0, std=self.config.initializer_range)
608
+ elif isinstance(module, nn.LayerNorm):
609
+ module.bias.data.zero_()
610
+ module.weight.data.fill_(1.0)
611
+ if isinstance(module, nn.Linear) and module.bias is not None:
612
+ module.bias.data.zero_()
613
+
614
+
615
+ class BertModel(BertPreTrainedModel):
616
+ """
617
+ The model can behave as an encoder (with only self-attention) as well as a decoder, in which case a layer of
618
+ cross-attention is added between the self-attention layers, following the architecture described in `Attention is
619
+ all you need <https://arxiv.org/abs/1706.03762>`__ by Ashish Vaswani, Noam Shazeer, Niki Parmar, Jakob Uszkoreit,
620
+ Llion Jones, Aidan N. Gomez, Lukasz Kaiser and Illia Polosukhin.
621
+ To be used as a decoder the model needs to be initialized with the :obj:`is_decoder` argument and :obj:`add_cross_attention` set to :obj:`True`; an :obj:`encoder_hidden_states` is then expected as an
622
+ input to the forward pass.
623
+ """
624
+
625
+ def __init__(self, config, add_pooling_layer=False):
626
+ super().__init__(config)
627
+ self.config = config
628
+
629
+ self.embeddings = BertEmbeddings(config)
630
+
631
+ self.encoder = BertEncoder(config)
632
+
633
+ self.pooler = BertPooler(config) if add_pooling_layer else None
634
+
635
+ self.init_weights()
636
+
637
+ def get_input_embeddings(self):
638
+ return self.embeddings.word_embeddings
639
+
640
+ def set_input_embeddings(self, value):
641
+ self.embeddings.word_embeddings = value
642
+
643
+ def _prune_heads(self, heads_to_prune):
644
+ """
645
+ Prunes heads of the model. heads_to_prune: dict of {layer_num: list of heads to prune in this layer} See base
646
+ class PreTrainedModel
647
+ """
648
+ for layer, heads in heads_to_prune.items():
649
+ self.encoder.layer[layer].attention.prune_heads(heads)
650
+
651
+ def get_extended_attention_mask(
652
+ self,
653
+ attention_mask: Tensor,
654
+ input_shape: Tuple[int],
655
+ device: device,
656
+ is_decoder: bool,
657
+ has_query: bool = False,
658
+ ) -> Tensor:
659
+ """
660
+ Makes broadcastable attention and causal masks so that future and masked tokens are ignored.
661
+
662
+ Arguments:
663
+ attention_mask (:obj:`torch.Tensor`):
664
+ Mask with ones indicating tokens to attend to, zeros for tokens to ignore.
665
+ input_shape (:obj:`Tuple[int]`):
666
+ The shape of the input to the model.
667
+ device: (:obj:`torch.device`):
668
+ The device of the input to the model.
669
+
670
+ Returns:
671
+ :obj:`torch.Tensor` The extended attention mask, with the same dtype as :obj:`attention_mask.dtype`.
672
+ """
673
+ # We can provide a self-attention mask of dimensions [batch_size, from_seq_length, to_seq_length]
674
+ # ourselves in which case we just need to make it broadcastable to all heads.
675
+ if attention_mask.dim() == 3:
676
+ extended_attention_mask = attention_mask[:, None, :, :]
677
+ elif attention_mask.dim() == 2:
678
+ # Provided a padding mask of dimensions [batch_size, seq_length]
679
+ # - if the model is a decoder, apply a causal mask in addition to the padding mask
680
+ # - if the model is an encoder, make the mask broadcastable to [batch_size, num_heads, seq_length, seq_length]
681
+ if is_decoder:
682
+ batch_size, seq_length = input_shape
683
+
684
+ seq_ids = torch.arange(seq_length, device=device)
685
+ causal_mask = seq_ids[None, None, :].repeat(batch_size, seq_length, 1) <= seq_ids[None, :, None]
686
+
687
+ # add a prefix ones mask to the causal mask
688
+ # causal and attention masks must have same type with pytorch version < 1.3
689
+ causal_mask = causal_mask.to(attention_mask.dtype)
690
+
691
+ if causal_mask.shape[1] < attention_mask.shape[1]:
692
+ prefix_seq_len = attention_mask.shape[1] - causal_mask.shape[1]
693
+ if has_query: # UniLM style attention mask
694
+ causal_mask = torch.cat(
695
+ [
696
+ torch.zeros(
697
+ (batch_size, prefix_seq_len, seq_length),
698
+ device=device,
699
+ dtype=causal_mask.dtype,
700
+ ),
701
+ causal_mask,
702
+ ],
703
+ axis=1,
704
+ )
705
+ causal_mask = torch.cat(
706
+ [
707
+ torch.ones(
708
+ (batch_size, causal_mask.shape[1], prefix_seq_len),
709
+ device=device,
710
+ dtype=causal_mask.dtype,
711
+ ),
712
+ causal_mask,
713
+ ],
714
+ axis=-1,
715
+ )
716
+ extended_attention_mask = causal_mask[:, None, :, :] * attention_mask[:, None, None, :]
717
+ else:
718
+ extended_attention_mask = attention_mask[:, None, None, :]
719
+ else:
720
+ raise ValueError("Wrong shape for input_ids (shape {}) or attention_mask (shape {})".format(input_shape, attention_mask.shape))
721
+
722
+ # Since attention_mask is 1.0 for positions we want to attend and 0.0 for
723
+ # masked positions, this operation will create a tensor which is 0.0 for
724
+ # positions we want to attend and -10000.0 for masked positions.
725
+ # Since we are adding it to the raw scores before the softmax, this is
726
+ # effectively the same as removing these entirely.
727
+ extended_attention_mask = extended_attention_mask.to(dtype=self.dtype) # fp16 compatibility
728
+ extended_attention_mask = (1.0 - extended_attention_mask) * -10000.0
729
+ return extended_attention_mask
730
+
731
+ def forward(
732
+ self,
733
+ input_ids=None,
734
+ attention_mask=None,
735
+ position_ids=None,
736
+ head_mask=None,
737
+ query_embeds=None,
738
+ encoder_hidden_states=None,
739
+ encoder_attention_mask=None,
740
+ past_key_values=None,
741
+ use_cache=None,
742
+ output_attentions=None,
743
+ output_hidden_states=None,
744
+ return_dict=None,
745
+ is_decoder=False,
746
+ ):
747
+ r"""
748
+ encoder_hidden_states (:obj:`torch.FloatTensor` of shape :obj:`(batch_size, sequence_length, hidden_size)`, `optional`):
749
+ Sequence of hidden-states at the output of the last layer of the encoder. Used in the cross-attention if
750
+ the model is configured as a decoder.
751
+ encoder_attention_mask (:obj:`torch.FloatTensor` of shape :obj:`(batch_size, sequence_length)`, `optional`):
752
+ Mask to avoid performing attention on the padding token indices of the encoder input. This mask is used in
753
+ the cross-attention if the model is configured as a decoder. Mask values selected in ``[0, 1]``:
754
+ - 1 for tokens that are **not masked**,
755
+ - 0 for tokens that are **masked**.
756
+ past_key_values (:obj:`tuple(tuple(torch.FloatTensor))` of length :obj:`config.n_layers` with each tuple having 4 tensors of shape :obj:`(batch_size, num_heads, sequence_length - 1, embed_size_per_head)`):
757
+ Contains precomputed key and value hidden states of the attention blocks. Can be used to speed up decoding.
758
+ If :obj:`past_key_values` are used, the user can optionally input only the last :obj:`decoder_input_ids`
759
+ (those that don't have their past key value states given to this model) of shape :obj:`(batch_size, 1)`
760
+ instead of all :obj:`decoder_input_ids` of shape :obj:`(batch_size, sequence_length)`.
761
+ use_cache (:obj:`bool`, `optional`):
762
+ If set to :obj:`True`, :obj:`past_key_values` key value states are returned and can be used to speed up
763
+ decoding (see :obj:`past_key_values`).
764
+ """
765
+ output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions
766
+ output_hidden_states = output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states
767
+ return_dict = return_dict if return_dict is not None else self.config.use_return_dict
768
+
769
+ # use_cache = use_cache if use_cache is not None else self.config.use_cache
770
+
771
+ if input_ids is None:
772
+ assert query_embeds is not None, "You have to specify query_embeds when input_ids is None"
773
+
774
+ # past_key_values_length
775
+ past_key_values_length = past_key_values[0][0].shape[2] - self.config.query_length if past_key_values is not None else 0
776
+
777
+ query_length = query_embeds.shape[1] if query_embeds is not None else 0
778
+
779
+ embedding_output = self.embeddings(
780
+ input_ids=input_ids,
781
+ position_ids=position_ids,
782
+ query_embeds=query_embeds,
783
+ past_key_values_length=past_key_values_length,
784
+ )
785
+
786
+ input_shape = embedding_output.size()[:-1]
787
+ batch_size, seq_length = input_shape
788
+ device = embedding_output.device
789
+
790
+ if attention_mask is None:
791
+ attention_mask = torch.ones(((batch_size, seq_length + past_key_values_length)), device=device)
792
+
793
+ # We can provide a self-attention mask of dimensions [batch_size, from_seq_length, to_seq_length]
794
+ # ourselves in which case we just need to make it broadcastable to all heads.
795
+ if is_decoder:
796
+ extended_attention_mask = self.get_extended_attention_mask(
797
+ attention_mask,
798
+ input_ids.shape,
799
+ device,
800
+ is_decoder,
801
+ has_query=(query_embeds is not None),
802
+ )
803
+ else:
804
+ extended_attention_mask = self.get_extended_attention_mask(attention_mask, input_shape, device, is_decoder)
805
+
806
+ # If a 2D or 3D attention mask is provided for the cross-attention
807
+ # we need to make broadcastable to [batch_size, num_heads, seq_length, seq_length]
808
+ if encoder_hidden_states is not None:
809
+ if type(encoder_hidden_states) == list:
810
+ encoder_batch_size, encoder_sequence_length, _ = encoder_hidden_states[0].size()
811
+ else:
812
+ (
813
+ encoder_batch_size,
814
+ encoder_sequence_length,
815
+ _,
816
+ ) = encoder_hidden_states.size()
817
+ encoder_hidden_shape = (encoder_batch_size, encoder_sequence_length)
818
+
819
+ if type(encoder_attention_mask) == list:
820
+ encoder_extended_attention_mask = [self.invert_attention_mask(mask) for mask in encoder_attention_mask]
821
+ elif encoder_attention_mask is None:
822
+ encoder_attention_mask = torch.ones(encoder_hidden_shape, device=device)
823
+ encoder_extended_attention_mask = self.invert_attention_mask(encoder_attention_mask)
824
+ else:
825
+ encoder_extended_attention_mask = self.invert_attention_mask(encoder_attention_mask)
826
+ else:
827
+ encoder_extended_attention_mask = None
828
+
829
+ # Prepare head mask if needed
830
+ # 1.0 in head_mask indicate we keep the head
831
+ # attention_probs has shape bsz x n_heads x N x N
832
+ # input head_mask has shape [num_heads] or [num_hidden_layers x num_heads]
833
+ # and head_mask is converted to shape [num_hidden_layers x batch x num_heads x seq_length x seq_length]
834
+ head_mask = self.get_head_mask(head_mask, self.config.num_hidden_layers)
835
+
836
+ encoder_outputs = self.encoder(
837
+ embedding_output,
838
+ attention_mask=extended_attention_mask,
839
+ head_mask=head_mask,
840
+ encoder_hidden_states=encoder_hidden_states,
841
+ encoder_attention_mask=encoder_extended_attention_mask,
842
+ past_key_values=past_key_values,
843
+ use_cache=use_cache,
844
+ output_attentions=output_attentions,
845
+ output_hidden_states=output_hidden_states,
846
+ return_dict=return_dict,
847
+ query_length=query_length,
848
+ )
849
+ sequence_output = encoder_outputs[0]
850
+ pooled_output = self.pooler(sequence_output) if self.pooler is not None else None
851
+
852
+ if not return_dict:
853
+ return (sequence_output, pooled_output) + encoder_outputs[1:]
854
+
855
+ return BaseModelOutputWithPoolingAndCrossAttentions(
856
+ last_hidden_state=sequence_output,
857
+ pooler_output=pooled_output,
858
+ past_key_values=encoder_outputs.past_key_values,
859
+ hidden_states=encoder_outputs.hidden_states,
860
+ attentions=encoder_outputs.attentions,
861
+ cross_attentions=encoder_outputs.cross_attentions,
862
+ )
863
+
864
+
865
+ class BertLMHeadModel(BertPreTrainedModel):
866
+
867
+ _keys_to_ignore_on_load_unexpected = [r"pooler"]
868
+ _keys_to_ignore_on_load_missing = [r"position_ids", r"predictions.decoder.bias"]
869
+
870
+ def __init__(self, config):
871
+ super().__init__(config)
872
+
873
+ self.bert = BertModel(config, add_pooling_layer=False)
874
+ self.cls = BertOnlyMLMHead(config)
875
+
876
+ self.init_weights()
877
+
878
+ def get_output_embeddings(self):
879
+ return self.cls.predictions.decoder
880
+
881
+ def set_output_embeddings(self, new_embeddings):
882
+ self.cls.predictions.decoder = new_embeddings
883
+
884
+ def forward(
885
+ self,
886
+ input_ids=None,
887
+ attention_mask=None,
888
+ position_ids=None,
889
+ head_mask=None,
890
+ query_embeds=None,
891
+ encoder_hidden_states=None,
892
+ encoder_attention_mask=None,
893
+ labels=None,
894
+ past_key_values=None,
895
+ use_cache=True,
896
+ output_attentions=None,
897
+ output_hidden_states=None,
898
+ return_dict=None,
899
+ return_logits=False,
900
+ is_decoder=True,
901
+ reduction="mean",
902
+ ):
903
+ r"""
904
+ encoder_hidden_states (:obj:`torch.FloatTensor` of shape :obj:`(batch_size, sequence_length, hidden_size)`, `optional`):
905
+ Sequence of hidden-states at the output of the last layer of the encoder. Used in the cross-attention if
906
+ the model is configured as a decoder.
907
+ encoder_attention_mask (:obj:`torch.FloatTensor` of shape :obj:`(batch_size, sequence_length)`, `optional`):
908
+ Mask to avoid performing attention on the padding token indices of the encoder input. This mask is used in
909
+ the cross-attention if the model is configured as a decoder. Mask values selected in ``[0, 1]``:
910
+ - 1 for tokens that are **not masked**,
911
+ - 0 for tokens that are **masked**.
912
+ labels (:obj:`torch.LongTensor` of shape :obj:`(batch_size, sequence_length)`, `optional`):
913
+ Labels for computing the left-to-right language modeling loss (next word prediction). Indices should be in
914
+ ``[-100, 0, ..., config.vocab_size]`` (see ``input_ids`` docstring) Tokens with indices set to ``-100`` are
915
+ ignored (masked), the loss is only computed for the tokens with labels in ``[0, ..., config.vocab_size]``
916
+ past_key_values (:obj:`tuple(tuple(torch.FloatTensor))` of length :obj:`config.n_layers` with each tuple having 4 tensors of shape :obj:`(batch_size, num_heads, sequence_length - 1, embed_size_per_head)`):
917
+ Contains precomputed key and value hidden states of the attention blocks. Can be used to speed up decoding.
918
+ If :obj:`past_key_values` are used, the user can optionally input only the last :obj:`decoder_input_ids`
919
+ (those that don't have their past key value states given to this model) of shape :obj:`(batch_size, 1)`
920
+ instead of all :obj:`decoder_input_ids` of shape :obj:`(batch_size, sequence_length)`.
921
+ use_cache (:obj:`bool`, `optional`):
922
+ If set to :obj:`True`, :obj:`past_key_values` key value states are returned and can be used to speed up
923
+ decoding (see :obj:`past_key_values`).
924
+ Returns:
925
+ Example::
926
+ >>> from transformers import BertTokenizer, BertLMHeadModel, BertConfig
927
+ >>> import torch
928
+ >>> tokenizer = BertTokenizer.from_pretrained('bert-base-cased')
929
+ >>> config = BertConfig.from_pretrained("bert-base-cased")
930
+ >>> model = BertLMHeadModel.from_pretrained('bert-base-cased', config=config)
931
+ >>> inputs = tokenizer("Hello, my dog is cute", return_tensors="pt")
932
+ >>> outputs = model(**inputs)
933
+ >>> prediction_logits = outputs.logits
934
+ """
935
+ return_dict = return_dict if return_dict is not None else self.config.use_return_dict
936
+ if labels is not None:
937
+ use_cache = False
938
+ if past_key_values is not None:
939
+ query_embeds = None
940
+
941
+ outputs = self.bert(
942
+ input_ids,
943
+ attention_mask=attention_mask,
944
+ position_ids=position_ids,
945
+ head_mask=head_mask,
946
+ query_embeds=query_embeds,
947
+ encoder_hidden_states=encoder_hidden_states,
948
+ encoder_attention_mask=encoder_attention_mask,
949
+ past_key_values=past_key_values,
950
+ use_cache=use_cache,
951
+ output_attentions=output_attentions,
952
+ output_hidden_states=output_hidden_states,
953
+ return_dict=return_dict,
954
+ is_decoder=is_decoder,
955
+ )
956
+
957
+ sequence_output = outputs[0]
958
+ if query_embeds is not None:
959
+ sequence_output = outputs[0][:, query_embeds.shape[1] :, :]
960
+
961
+ prediction_scores = self.cls(sequence_output)
962
+
963
+ if return_logits:
964
+ return prediction_scores[:, :-1, :].contiguous()
965
+
966
+ lm_loss = None
967
+ if labels is not None:
968
+ # we are doing next-token prediction; shift prediction scores and input ids by one
969
+ shifted_prediction_scores = prediction_scores[:, :-1, :].contiguous()
970
+ labels = labels[:, 1:].contiguous()
971
+ loss_fct = CrossEntropyLoss(reduction=reduction, label_smoothing=0.1)
972
+ lm_loss = loss_fct(
973
+ shifted_prediction_scores.view(-1, self.config.vocab_size),
974
+ labels.view(-1),
975
+ )
976
+ if reduction == "none":
977
+ lm_loss = lm_loss.view(prediction_scores.size(0), -1).sum(1)
978
+
979
+ if not return_dict:
980
+ output = (prediction_scores,) + outputs[2:]
981
+ return ((lm_loss,) + output) if lm_loss is not None else output
982
+
983
+ return CausalLMOutputWithCrossAttentions(
984
+ loss=lm_loss,
985
+ logits=prediction_scores,
986
+ past_key_values=outputs.past_key_values,
987
+ hidden_states=outputs.hidden_states,
988
+ attentions=outputs.attentions,
989
+ cross_attentions=outputs.cross_attentions,
990
+ )
991
+
992
+ def prepare_inputs_for_generation(self, input_ids, query_embeds, past=None, attention_mask=None, **model_kwargs):
993
+ # if model is used as a decoder in encoder-decoder model, the decoder attention mask is created on the fly
994
+ if attention_mask is None:
995
+ attention_mask = input_ids.new_ones(input_ids.shape)
996
+ query_mask = input_ids.new_ones(query_embeds.shape[:-1])
997
+ attention_mask = torch.cat([query_mask, attention_mask], dim=-1)
998
+
999
+ # cut decoder_input_ids if past is used
1000
+ if past is not None:
1001
+ input_ids = input_ids[:, -1:]
1002
+
1003
+ return {
1004
+ "input_ids": input_ids,
1005
+ "query_embeds": query_embeds,
1006
+ "attention_mask": attention_mask,
1007
+ "past_key_values": past,
1008
+ "encoder_hidden_states": model_kwargs.get("encoder_hidden_states", None),
1009
+ "encoder_attention_mask": model_kwargs.get("encoder_attention_mask", None),
1010
+ "is_decoder": True,
1011
+ }
1012
+
1013
+ def _reorder_cache(self, past, beam_idx):
1014
+ reordered_past = ()
1015
+ for layer_past in past:
1016
+ reordered_past += (tuple(past_state.index_select(0, beam_idx) for past_state in layer_past),)
1017
+ return reordered_past
1018
+
1019
+
1020
+ class BertForMaskedLM(BertPreTrainedModel):
1021
+
1022
+ _keys_to_ignore_on_load_unexpected = [r"pooler"]
1023
+ _keys_to_ignore_on_load_missing = [r"position_ids", r"predictions.decoder.bias"]
1024
+
1025
+ def __init__(self, config):
1026
+ super().__init__(config)
1027
+
1028
+ self.bert = BertModel(config, add_pooling_layer=False)
1029
+ self.cls = BertOnlyMLMHead(config)
1030
+
1031
+ self.init_weights()
1032
+
1033
+ def get_output_embeddings(self):
1034
+ return self.cls.predictions.decoder
1035
+
1036
+ def set_output_embeddings(self, new_embeddings):
1037
+ self.cls.predictions.decoder = new_embeddings
1038
+
1039
+ def forward(
1040
+ self,
1041
+ input_ids=None,
1042
+ attention_mask=None,
1043
+ position_ids=None,
1044
+ head_mask=None,
1045
+ query_embeds=None,
1046
+ encoder_hidden_states=None,
1047
+ encoder_attention_mask=None,
1048
+ labels=None,
1049
+ output_attentions=None,
1050
+ output_hidden_states=None,
1051
+ return_dict=None,
1052
+ return_logits=False,
1053
+ is_decoder=False,
1054
+ ):
1055
+ r"""
1056
+ labels (:obj:`torch.LongTensor` of shape :obj:`(batch_size, sequence_length)`, `optional`):
1057
+ Labels for computing the masked language modeling loss. Indices should be in ``[-100, 0, ...,
1058
+ config.vocab_size]`` (see ``input_ids`` docstring) Tokens with indices set to ``-100`` are ignored
1059
+ (masked), the loss is only computed for the tokens with labels in ``[0, ..., config.vocab_size]``
1060
+ """
1061
+
1062
+ return_dict = return_dict if return_dict is not None else self.config.use_return_dict
1063
+
1064
+ outputs = self.bert(
1065
+ input_ids,
1066
+ attention_mask=attention_mask,
1067
+ position_ids=position_ids,
1068
+ head_mask=head_mask,
1069
+ query_embeds=query_embeds,
1070
+ encoder_hidden_states=encoder_hidden_states,
1071
+ encoder_attention_mask=encoder_attention_mask,
1072
+ output_attentions=output_attentions,
1073
+ output_hidden_states=output_hidden_states,
1074
+ return_dict=return_dict,
1075
+ is_decoder=is_decoder,
1076
+ )
1077
+
1078
+ if query_embeds is not None:
1079
+ sequence_output = outputs[0][:, query_embeds.shape[1] :, :]
1080
+ prediction_scores = self.cls(sequence_output)
1081
+
1082
+ if return_logits:
1083
+ return prediction_scores
1084
+
1085
+ masked_lm_loss = None
1086
+ if labels is not None:
1087
+ loss_fct = CrossEntropyLoss() # -100 index = padding token
1088
+ masked_lm_loss = loss_fct(prediction_scores.view(-1, self.config.vocab_size), labels.view(-1))
1089
+
1090
+ if not return_dict:
1091
+ output = (prediction_scores,) + outputs[2:]
1092
+ return ((masked_lm_loss,) + output) if masked_lm_loss is not None else output
1093
+
1094
+ return MaskedLMOutput(
1095
+ loss=masked_lm_loss,
1096
+ logits=prediction_scores,
1097
+ hidden_states=outputs.hidden_states,
1098
+ attentions=outputs.attentions,
1099
+ )
1100
+
1101
+
1102
+ class Qformer(nn.Module):
+     def __init__(self, model_args, vision_tower):
+         super().__init__()
+
+         self.depth = model_args.mm_qformer_depth
+         self.num_latents = model_args.mm_qformer_latents
+         self.pretrained = model_args.mm_qformer_pretrained
+
+         self.Qformer, self.query_tokens, self.ln_vision = self.build_Qformer(vision_tower.hidden_size, self.depth, self.num_latents)
+
+         if self.pretrained is not None:
+             pretrained_dict = torch.load(self.pretrained, map_location="cpu")["model"]
+             pretrained_dict = {k: v for k, v in pretrained_dict.items() if not k.startswith("t5_proj")}
+             self.load_state_dict(pretrained_dict)
+
+     def build_Qformer(self, vision_width, cross_attention_freq, num_query_token):
+         encoder_config = BertConfig.from_pretrained("bert-base-uncased")
+         encoder_config.encoder_width = vision_width
+         # insert cross-attention layer every other block
+         encoder_config.add_cross_attention = True
+         encoder_config.cross_attention_freq = cross_attention_freq
+         encoder_config.query_length = num_query_token
+         Qformer = BertLMHeadModel(config=encoder_config)
+         query_tokens = nn.Parameter(torch.zeros(1, num_query_token, encoder_config.hidden_size))
+         query_tokens.data.normal_(mean=0.0, std=encoder_config.initializer_range)
+         Qformer.cls = None
+         Qformer.bert.embeddings.word_embeddings = None
+         Qformer.bert.embeddings.position_embeddings = None
+         for layer in Qformer.bert.encoder.layer:
+             layer.output = None
+             layer.intermediate = None
+         return Qformer, query_tokens, nn.LayerNorm(vision_width)
+
+     def forward(self, image_features, *args, **kwargs):
+         x = self.ln_vision(image_features)
+         image_atts = torch.ones(x.size()[:-1], dtype=torch.long).to(x.device)
+
+         query_tokens = self.query_tokens.expand(x.shape[0], -1, -1)
+         query_output = self.Qformer.bert(
+             query_embeds=query_tokens,
+             encoder_hidden_states=x,
+             encoder_attention_mask=image_atts,
+             return_dict=True,
+         )
+
+         return query_output.last_hidden_state
+
+     @property
+     def hidden_size(self):
+         return 768
+
+     @property
+     def config(self):
+         return {
+             "mm_resampler_type": "qformer",
+             "mm_qformer_depth": self.depth,
+             "mm_qformer_latents": self.num_latents,
+             "mm_qformer_pretrained": self.pretrained,
+         }
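Not part of the commit: a minimal sketch of how the Q-Former resampler above can be exercised on random features. The import path follows this repo's layout, the dummy `model_args`/`vision_tower` namespaces are invented stand-ins, and `bert-base-uncased` is downloaded on first use.

```python
# Illustrative only: exercises the Q-Former resampler on random features.
# The dummy args and shapes below are assumptions, not part of the commit.
from types import SimpleNamespace

import torch

from llava.model.multimodal_resampler.qformer import Qformer

model_args = SimpleNamespace(mm_qformer_depth=2, mm_qformer_latents=32, mm_qformer_pretrained=None)
vision_tower = SimpleNamespace(hidden_size=1024)   # e.g. CLIP ViT-L/14 feature width

resampler = Qformer(model_args, vision_tower)
image_features = torch.randn(2, 576, 1024)         # (batch, vision tokens, hidden)
compressed = resampler(image_features)
print(compressed.shape)                             # expected: torch.Size([2, 32, 768])
```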
llava/model/multimodal_resampler/spatial_pool.py ADDED
@@ -0,0 +1,45 @@
+ import torch
+ import torch.nn as nn
+ import math
+
+
+ class SpatialPool(nn.Module):
+     def __init__(self, model_args, vision_tower):
+         super().__init__()
+
+         self.mode = model_args.mm_spatial_pool_mode
+         self.stride = model_args.mm_spatial_pool_stride
+         self.out_channels = getattr(model_args, "mm_spatial_pool_out_channels", vision_tower.hidden_size)
+
+         if self.mode == "average":
+             self.pool = nn.AvgPool2d(kernel_size=self.stride, stride=self.stride)
+         elif self.mode == "max":
+             self.pool = nn.MaxPool2d(kernel_size=self.stride, stride=self.stride)
+         elif self.mode == "conv":
+             self.pool = nn.Conv2d(in_channels=vision_tower.hidden_size, out_channels=self.out_channels, kernel_size=self.stride, stride=self.stride)
+         else:
+             raise ValueError(f"Unknown pooling mode: {self.mode}.")
+
+     def forward(self, image_features, images, *args, **kwargs):
+         # infer the patch-grid size from the token count and the image aspect ratio
+         ori_W = int(math.sqrt(image_features.shape[1] * images.shape[3] // images.shape[2]))
+         ori_H = int(ori_W * images.shape[2] // images.shape[3])
+
+         B, _, F = image_features.shape
+
+         image_features_spatial = image_features.view(B, ori_H, ori_W, F).permute(0, 3, 1, 2)  # (B, C, H, W)
+         image_features_spatial_pool = self.pool(image_features_spatial)
+
+         return image_features_spatial_pool.flatten(2).transpose(1, 2).contiguous()
+
+     @property
+     def config(self):
+         return {
+             "mm_resampler_type": "spatial_pool",
+             "mm_spatial_pool_stride": self.stride,
+             "mm_spatial_pool_mode": self.mode,
+             "mm_spatial_pool_out_channels": self.out_channels,
+         }
+
+     @property
+     def hidden_size(self):
+         return self.out_channels
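Not part of the commit: a quick sanity check of the pooling arithmetic above. The stub `model_args`/`vision_tower` objects are invented for illustration; a 336x336 image with 14x14 CLIP patches gives a 24x24 token grid, which a stride-2 average pool reduces to 12x12.

```python
# Illustrative only; the stub objects below are assumptions.
from types import SimpleNamespace

import torch

from llava.model.multimodal_resampler.spatial_pool import SpatialPool

model_args = SimpleNamespace(mm_spatial_pool_mode="average", mm_spatial_pool_stride=2)
vision_tower = SimpleNamespace(hidden_size=1024)

pool = SpatialPool(model_args, vision_tower)
image_features = torch.randn(2, 576, 1024)   # 24x24 patch grid from a 336x336 image
images = torch.randn(2, 3, 336, 336)         # only the spatial size of `images` is used
print(pool(image_features, images).shape)    # expected: torch.Size([2, 144, 1024])
```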
llava/model/utils.py ADDED
@@ -0,0 +1,20 @@
+ from transformers import AutoConfig
+
+
+ def auto_upgrade(config):
+     cfg = AutoConfig.from_pretrained(config)
+     if "llava" in config and "llava" not in cfg.model_type:
+         assert cfg.model_type == "llama"
+         print("You are using newer LLaVA code base, while the checkpoint of v0 is from older code base.")
+         print("You must upgrade the checkpoint to the new code base (this can be done automatically).")
+         confirm = input("Please confirm that you want to upgrade the checkpoint. [Y/N]")
+         if confirm.lower() in ["y", "yes"]:
+             print("Upgrading checkpoint...")
+             assert len(cfg.architectures) == 1
+             setattr(cfg.__class__, "model_type", "llava")
+             cfg.architectures[0] = "LlavaLlamaForCausalLM"
+             cfg.save_pretrained(config)
+             print("Checkpoint upgraded.")
+         else:
+             print("Checkpoint upgrade aborted.")
+             exit(1)
llava/utils.py ADDED
@@ -0,0 +1,134 @@
+ import datetime
+ import logging
+ import logging.handlers
+ import os
+ import sys
+
+ import requests
+
+ from llava.constants import LOGDIR
+
+ server_error_msg = "**NETWORK ERROR DUE TO HIGH TRAFFIC. PLEASE REGENERATE OR REFRESH THIS PAGE.**"
+ moderation_msg = "YOUR INPUT VIOLATES OUR CONTENT MODERATION GUIDELINES. PLEASE TRY AGAIN."
+
+ handler = None
+
+ import torch.distributed as dist
+
+
+ def rank0_print(*args):
+     if dist.is_initialized():
+         if dist.get_rank() == 0:
+             print(f"Rank {dist.get_rank()}: ", *args)
+
+
+ def build_logger(logger_name, logger_filename):
+     global handler
+
+     formatter = logging.Formatter(
+         fmt="%(asctime)s | %(levelname)s | %(name)s | %(message)s",
+         datefmt="%Y-%m-%d %H:%M:%S",
+     )
+
+     # Set the format of root handlers
+     if not logging.getLogger().handlers:
+         logging.basicConfig(level=logging.INFO)
+     logging.getLogger().handlers[0].setFormatter(formatter)
+
+     # Redirect stdout and stderr to loggers
+     stdout_logger = logging.getLogger("stdout")
+     stdout_logger.setLevel(logging.INFO)
+     sl = StreamToLogger(stdout_logger, logging.INFO)
+     sys.stdout = sl
+
+     stderr_logger = logging.getLogger("stderr")
+     stderr_logger.setLevel(logging.ERROR)
+     sl = StreamToLogger(stderr_logger, logging.ERROR)
+     sys.stderr = sl
+
+     # Get logger
+     logger = logging.getLogger(logger_name)
+     logger.setLevel(logging.INFO)
+
+     # Add a file handler for all loggers
+     if handler is None:
+         os.makedirs(LOGDIR, exist_ok=True)
+         filename = os.path.join(LOGDIR, logger_filename)
+         handler = logging.handlers.TimedRotatingFileHandler(filename, when="D", utc=True)
+         handler.setFormatter(formatter)
+
+         for name, item in logging.root.manager.loggerDict.items():
+             if isinstance(item, logging.Logger):
+                 item.addHandler(handler)
+
+     return logger
+
+
+ class StreamToLogger(object):
+     """
+     Fake file-like stream object that redirects writes to a logger instance.
+     """
+
+     def __init__(self, logger, log_level=logging.INFO):
+         self.terminal = sys.stdout
+         self.logger = logger
+         self.log_level = log_level
+         self.linebuf = ""
+
+     def __getattr__(self, attr):
+         return getattr(self.terminal, attr)
+
+     def write(self, buf):
+         temp_linebuf = self.linebuf + buf
+         self.linebuf = ""
+         for line in temp_linebuf.splitlines(True):
+             # From the io.TextIOWrapper docs:
+             #   On output, if newline is None, any '\n' characters written
+             #   are translated to the system default line separator.
+             # By default sys.stdout.write() expects '\n' newlines and then
+             # translates them so this is still cross platform.
+             if line[-1] == "\n":
+                 self.logger.log(self.log_level, line.rstrip())
+             else:
+                 self.linebuf += line
+
+     def flush(self):
+         if self.linebuf != "":
+             self.logger.log(self.log_level, self.linebuf.rstrip())
+         self.linebuf = ""
+
+
+ def disable_torch_init():
+     """
+     Disable the redundant torch default initialization to accelerate model creation.
+     """
+     import torch
+
+     setattr(torch.nn.Linear, "reset_parameters", lambda self: None)
+     setattr(torch.nn.LayerNorm, "reset_parameters", lambda self: None)
+
+
+ def violates_moderation(text):
+     """
+     Check whether the text violates OpenAI moderation API.
+     """
+     url = "https://api.openai.com/v1/moderations"
+     headers = {"Content-Type": "application/json", "Authorization": "Bearer " + os.environ["OPENAI_API_KEY"]}
+     text = text.replace("\n", "")
+     data = "{" + '"input": ' + f'"{text}"' + "}"
+     data = data.encode("utf-8")
+     try:
+         ret = requests.post(url, headers=headers, data=data, timeout=5)
+         flagged = ret.json()["results"][0]["flagged"]
+     except requests.exceptions.RequestException as e:
+         flagged = False
+     except KeyError as e:
+         flagged = False
+
+     return flagged
+
+
+ def pretty_print_semaphore(semaphore):
+     if semaphore is None:
+         return "None"
+     return f"Semaphore(value={semaphore._value}, locked={semaphore.locked()})"
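Not part of the commit: a sketch of how these helpers are typically wired together (logger and file names are placeholders; note that `build_logger` also redirects `sys.stdout`/`sys.stderr`, so it has global side effects).

```python
# Illustrative only; logger/file names are placeholders.
from llava.utils import build_logger, disable_torch_init, rank0_print

logger = build_logger("model_worker", "model_worker.log")
logger.info("worker started")   # goes to the console handler and to ./model_worker.log (rotated daily)

disable_torch_init()            # skip redundant nn.Linear / nn.LayerNorm init before loading weights
rank0_print("printed only on rank 0 when torch.distributed is initialized")
```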
llavavid/__init__.py ADDED
@@ -0,0 +1 @@
+ from .model import LlavaLlamaForCausalLM
llavavid/constants.py ADDED
@@ -0,0 +1,13 @@
+ CONTROLLER_HEART_BEAT_EXPIRATION = 30
+ WORKER_HEART_BEAT_INTERVAL = 15
+
+ LOGDIR = "."
+
+ # Model Constants
+ IGNORE_INDEX = -100
+ IMAGE_TOKEN_INDEX = -200
+ DEFAULT_IMAGE_TOKEN = "<image>"
+ DEFAULT_IMAGE_PATCH_TOKEN = "<im_patch>"
+ DEFAULT_IM_START_TOKEN = "<im_start>"
+ DEFAULT_IM_END_TOKEN = "<im_end>"
+ IMAGE_PLACEHOLDER = "<image-placeholder>"
llavavid/conversation.py ADDED
@@ -0,0 +1,406 @@
1
+ import dataclasses
2
+ from enum import auto, Enum
3
+ from typing import List, Tuple
4
+
5
+
6
+ class SeparatorStyle(Enum):
7
+ """Different separator style."""
8
+ SINGLE = auto()
9
+ TWO = auto()
10
+ MPT = auto()
11
+ PLAIN = auto()
12
+ LLAMA_2 = auto()
13
+
14
+
15
+ @dataclasses.dataclass
16
+ class Conversation:
17
+ """A class that keeps all conversation history."""
18
+ system: str
19
+ roles: List[str]
20
+ messages: List[List[str]]
21
+ offset: int
22
+ sep_style: SeparatorStyle = SeparatorStyle.SINGLE
23
+ sep: str = "###"
24
+ sep2: str = None
25
+ version: str = "Unknown"
26
+
27
+ skip_next: bool = False
28
+
29
+ def get_prompt(self):
30
+ messages = self.messages
31
+ if len(messages) > 0 and type(messages[0][1]) is tuple:
32
+ messages = self.messages.copy()
33
+ init_role, init_msg = messages[0].copy()
34
+ init_msg = init_msg[0].replace("<image>", "").strip()
35
+ if 'mmtag' in self.version:
36
+ messages[0] = (init_role, init_msg)
37
+ messages.insert(0, (self.roles[0], "<Image><image></Image>"))
38
+ messages.insert(1, (self.roles[1], "Received."))
39
+ else:
40
+ messages[0] = (init_role, "<image>\n" + init_msg)
41
+
42
+ if self.sep_style == SeparatorStyle.SINGLE:
43
+ ret = self.system + self.sep
44
+ for role, message in messages:
45
+ if message:
46
+ if type(message) is tuple:
47
+ message, _, _ = message
48
+ ret += role + ": " + message + self.sep
49
+ else:
50
+ ret += role + ":"
51
+ elif self.sep_style == SeparatorStyle.TWO:
52
+ seps = [self.sep, self.sep2]
53
+ ret = self.system + seps[0]
54
+ for i, (role, message) in enumerate(messages):
55
+ if message:
56
+ if type(message) is tuple:
57
+ message, _, _ = message
58
+ ret += role + ": " + message + seps[i % 2]
59
+ else:
60
+ ret += role + ":"
61
+ elif self.sep_style == SeparatorStyle.MPT:
62
+ ret = self.system + self.sep
63
+ for role, message in messages:
64
+ if message:
65
+ if type(message) is tuple:
66
+ message, _, _ = message
67
+ ret += role + message + self.sep
68
+ else:
69
+ ret += role
70
+ elif self.sep_style == SeparatorStyle.LLAMA_2:
71
+ wrap_sys = lambda msg: f"<<SYS>>\n{msg}\n<</SYS>>\n\n" if len(msg) > 0 else msg
72
+ wrap_inst = lambda msg: f"[INST] {msg} [/INST]"
73
+ ret = ""
74
+
75
+ for i, (role, message) in enumerate(messages):
76
+ if i == 0:
77
+ assert message, "first message should not be none"
78
+ assert role == self.roles[0], "first message should come from user"
79
+ if message:
80
+ if type(message) is tuple:
81
+ message, _, _ = message
82
+ if i == 0: message = wrap_sys(self.system) + message
83
+ if i % 2 == 0:
84
+ message = wrap_inst(message)
85
+ ret += self.sep + message
86
+ else:
87
+ ret += " " + message + " " + self.sep2
88
+ else:
89
+ ret += ""
90
+ ret = ret.lstrip(self.sep)
91
+ elif self.sep_style == SeparatorStyle.PLAIN:
92
+ seps = [self.sep, self.sep2]
93
+ ret = self.system
94
+ for i, (role, message) in enumerate(messages):
95
+ if message:
96
+ if type(message) is tuple:
97
+ message, _, _ = message
98
+ ret += message + seps[i % 2]
99
+ else:
100
+ ret += ""
101
+ else:
102
+ raise ValueError(f"Invalid style: {self.sep_style}")
103
+
104
+ return ret
105
+
106
+ def append_message(self, role, message):
107
+ self.messages.append([role, message])
108
+
109
+ def get_images(self, return_pil=False):
110
+ images = []
111
+ for i, (role, msg) in enumerate(self.messages[self.offset:]):
112
+ if i % 2 == 0:
113
+ if type(msg) is tuple:
114
+ import base64
115
+ from io import BytesIO
116
+ from PIL import Image
117
+ msg, image, image_process_mode = msg
118
+ if image_process_mode == "Pad":
119
+ def expand2square(pil_img, background_color=(122, 116, 104)):
120
+ width, height = pil_img.size
121
+ if width == height:
122
+ return pil_img
123
+ elif width > height:
124
+ result = Image.new(pil_img.mode, (width, width), background_color)
125
+ result.paste(pil_img, (0, (width - height) // 2))
126
+ return result
127
+ else:
128
+ result = Image.new(pil_img.mode, (height, height), background_color)
129
+ result.paste(pil_img, ((height - width) // 2, 0))
130
+ return result
131
+ image = expand2square(image)
132
+ elif image_process_mode in ["Default", "Crop"]:
133
+ pass
134
+ elif image_process_mode == "Resize":
135
+ image = image.resize((336, 336))
136
+ else:
137
+ raise ValueError(f"Invalid image_process_mode: {image_process_mode}")
138
+ max_hw, min_hw = max(image.size), min(image.size)
139
+ aspect_ratio = max_hw / min_hw
140
+ max_len, min_len = 800, 400
141
+ shortest_edge = int(min(max_len / aspect_ratio, min_len, min_hw))
142
+ longest_edge = int(shortest_edge * aspect_ratio)
143
+ W, H = image.size
144
+ if longest_edge != max(image.size):
145
+ if H > W:
146
+ H, W = longest_edge, shortest_edge
147
+ else:
148
+ H, W = shortest_edge, longest_edge
149
+ image = image.resize((W, H))
150
+ if return_pil:
151
+ images.append(image)
152
+ else:
153
+ buffered = BytesIO()
154
+ image.save(buffered, format="PNG")
155
+ img_b64_str = base64.b64encode(buffered.getvalue()).decode()
156
+ images.append(img_b64_str)
157
+ return images
158
+
159
+ def to_gradio_chatbot(self):
160
+ ret = []
161
+ for i, (role, msg) in enumerate(self.messages[self.offset:]):
162
+ if i % 2 == 0:
163
+ if type(msg) is tuple:
164
+ import base64
165
+ from io import BytesIO
166
+ msg, image, image_process_mode = msg
167
+ max_hw, min_hw = max(image.size), min(image.size)
168
+ aspect_ratio = max_hw / min_hw
169
+ max_len, min_len = 800, 400
170
+ shortest_edge = int(min(max_len / aspect_ratio, min_len, min_hw))
171
+ longest_edge = int(shortest_edge * aspect_ratio)
172
+ W, H = image.size
173
+ if H > W:
174
+ H, W = longest_edge, shortest_edge
175
+ else:
176
+ H, W = shortest_edge, longest_edge
177
+ image = image.resize((W, H))
178
+ buffered = BytesIO()
179
+ image.save(buffered, format="JPEG")
180
+ img_b64_str = base64.b64encode(buffered.getvalue()).decode()
181
+ img_str = f'<img src="data:image/png;base64,{img_b64_str}" alt="user upload image" />'
182
+ msg = img_str + msg.replace('<image>', '').strip()
183
+ ret.append([msg, None])
184
+ else:
185
+ ret.append([msg, None])
186
+ else:
187
+ ret[-1][-1] = msg
188
+ return ret
189
+
190
+ def copy(self):
191
+ return Conversation(
192
+ system=self.system,
193
+ roles=self.roles,
194
+ messages=[[x, y] for x, y in self.messages],
195
+ offset=self.offset,
196
+ sep_style=self.sep_style,
197
+ sep=self.sep,
198
+ sep2=self.sep2,
199
+ version=self.version)
200
+
201
+ def dict(self):
202
+ if len(self.get_images()) > 0:
203
+ return {
204
+ "system": self.system,
205
+ "roles": self.roles,
206
+ "messages": [[x, y[0] if type(y) is tuple else y] for x, y in self.messages],
207
+ "offset": self.offset,
208
+ "sep": self.sep,
209
+ "sep2": self.sep2,
210
+ }
211
+ return {
212
+ "system": self.system,
213
+ "roles": self.roles,
214
+ "messages": self.messages,
215
+ "offset": self.offset,
216
+ "sep": self.sep,
217
+ "sep2": self.sep2,
218
+ }
219
+
220
+
221
+ conv_vicuna_v0 = Conversation(
222
+ system="A chat between a curious human and an artificial intelligence assistant. "
223
+ "The assistant gives helpful, detailed, and polite answers to the human's questions.",
224
+ roles=("Human", "Assistant"),
225
+ messages=(
226
+ ("Human", "What are the key differences between renewable and non-renewable energy sources?"),
227
+ ("Assistant",
228
+ "Renewable energy sources are those that can be replenished naturally in a relatively "
229
+ "short amount of time, such as solar, wind, hydro, geothermal, and biomass. "
230
+ "Non-renewable energy sources, on the other hand, are finite and will eventually be "
231
+ "depleted, such as coal, oil, and natural gas. Here are some key differences between "
232
+ "renewable and non-renewable energy sources:\n"
233
+ "1. Availability: Renewable energy sources are virtually inexhaustible, while non-renewable "
234
+ "energy sources are finite and will eventually run out.\n"
235
+ "2. Environmental impact: Renewable energy sources have a much lower environmental impact "
236
+ "than non-renewable sources, which can lead to air and water pollution, greenhouse gas emissions, "
237
+ "and other negative effects.\n"
238
+ "3. Cost: Renewable energy sources can be more expensive to initially set up, but they typically "
239
+ "have lower operational costs than non-renewable sources.\n"
240
+ "4. Reliability: Renewable energy sources are often more reliable and can be used in more remote "
241
+ "locations than non-renewable sources.\n"
242
+ "5. Flexibility: Renewable energy sources are often more flexible and can be adapted to different "
243
+ "situations and needs, while non-renewable sources are more rigid and inflexible.\n"
244
+ "6. Sustainability: Renewable energy sources are more sustainable over the long term, while "
245
+ "non-renewable sources are not, and their depletion can lead to economic and social instability.\n")
246
+ ),
247
+ offset=2,
248
+ sep_style=SeparatorStyle.SINGLE,
249
+ sep="###",
250
+ )
251
+
252
+ conv_vicuna_v1 = Conversation(
253
+ system="A chat between a curious user and an artificial intelligence assistant. "
254
+ "The assistant gives helpful, detailed, and polite answers to the user's questions.",
255
+ roles=("USER", "ASSISTANT"),
256
+ version="v1",
257
+ messages=(),
258
+ offset=0,
259
+ sep_style=SeparatorStyle.TWO,
260
+ sep=" ",
261
+ sep2="</s>",
262
+ )
263
+
264
+ conv_llama_2 = Conversation(
265
+ system="""You are a helpful, respectful and honest assistant. Always answer as helpfully as possible, while being safe. Your answers should not include any harmful, unethical, racist, sexist, toxic, dangerous, or illegal content. Please ensure that your responses are socially unbiased and positive in nature.
266
+
267
+ If a question does not make any sense, or is not factually coherent, explain why instead of answering something not correct. If you don't know the answer to a question, please don't share false information.""",
268
+ roles=("USER", "ASSISTANT"),
269
+ version="llama_v2",
270
+ messages=(),
271
+ offset=0,
272
+ sep_style=SeparatorStyle.LLAMA_2,
273
+ sep="<s>",
274
+ sep2="</s>",
275
+ )
276
+
277
+ conv_llava_llama_2 = Conversation(
278
+ system="You are a helpful language and vision assistant. "
279
+ "You are able to understand the visual content that the user provides, "
280
+ "and assist the user with a variety of tasks using natural language.",
281
+ roles=("USER", "ASSISTANT"),
282
+ version="llama_v2",
283
+ messages=(),
284
+ offset=0,
285
+ sep_style=SeparatorStyle.LLAMA_2,
286
+ sep="<s>",
287
+ sep2="</s>",
288
+ )
289
+
290
+ conv_mpt = Conversation(
291
+ system="""<|im_start|>system
292
+ A conversation between a user and an LLM-based AI assistant. The assistant gives helpful and honest answers.""",
293
+ roles=("<|im_start|>user\n", "<|im_start|>assistant\n"),
294
+ version="mpt",
295
+ messages=(),
296
+ offset=0,
297
+ sep_style=SeparatorStyle.MPT,
298
+ sep="<|im_end|>",
299
+ )
300
+
301
+ conv_llava_plain = Conversation(
302
+ system="",
303
+ roles=("", ""),
304
+ messages=(
305
+ ),
306
+ offset=0,
307
+ sep_style=SeparatorStyle.PLAIN,
308
+ sep="\n",
309
+ )
310
+
311
+ conv_llava_v0 = Conversation(
312
+ system="A chat between a curious human and an artificial intelligence assistant. "
313
+ "The assistant gives helpful, detailed, and polite answers to the human's questions.",
314
+ roles=("Human", "Assistant"),
315
+ messages=(
316
+ ),
317
+ offset=0,
318
+ sep_style=SeparatorStyle.SINGLE,
319
+ sep="###",
320
+ )
321
+
322
+ conv_llava_v0_mmtag = Conversation(
323
+ system="A chat between a curious user and an artificial intelligence assistant. "
324
+ "The assistant is able to understand the visual content that the user provides, and assist the user with a variety of tasks using natural language."
325
+ "The visual content will be provided with the following format: <Image>visual content</Image>.",
326
+ roles=("Human", "Assistant"),
327
+ messages=(
328
+ ),
329
+ offset=0,
330
+ sep_style=SeparatorStyle.SINGLE,
331
+ sep="###",
332
+ version="v0_mmtag",
333
+ )
334
+
335
+ conv_llava_v1 = Conversation(
336
+ system="A chat between a curious human and an artificial intelligence assistant. "
337
+ "The assistant gives helpful, detailed, and polite answers to the human's questions.",
338
+ roles=("USER", "ASSISTANT"),
339
+ version="v1",
340
+ messages=(),
341
+ offset=0,
342
+ sep_style=SeparatorStyle.TWO,
343
+ sep=" ",
344
+ sep2="</s>",
345
+ )
346
+
347
+ conv_llava_v1_mmtag = Conversation(
348
+ system="A chat between a curious user and an artificial intelligence assistant. "
349
+ "The assistant is able to understand the visual content that the user provides, and assist the user with a variety of tasks using natural language."
350
+ "The visual content will be provided with the following format: <Image>visual content</Image>.",
351
+ roles=("USER", "ASSISTANT"),
352
+ messages=(),
353
+ offset=0,
354
+ sep_style=SeparatorStyle.TWO,
355
+ sep=" ",
356
+ sep2="</s>",
357
+ version="v1_mmtag",
358
+ )
359
+
360
+ conv_mistral_instruct = Conversation(
361
+ system="",
362
+ roles=("USER", "ASSISTANT"),
363
+ version="llama_v2",
364
+ messages=(),
365
+ offset=0,
366
+ sep_style=SeparatorStyle.LLAMA_2,
367
+ sep="",
368
+ sep2="</s>",
369
+ )
370
+
371
+ conv_chatml_direct = Conversation(
372
+ system="""<|im_start|>system
373
+ Answer the questions.""",
374
+ roles=("<|im_start|>user\n", "<|im_start|>assistant\n"),
375
+ version="mpt",
376
+ messages=(),
377
+ offset=0,
378
+ sep_style=SeparatorStyle.MPT,
379
+ sep="<|im_end|>",
380
+ )
381
+
382
+ default_conversation = conv_vicuna_v1
383
+ conv_templates = {
384
+ "default": conv_vicuna_v0,
385
+ "v0": conv_vicuna_v0,
386
+ "v1": conv_vicuna_v1,
387
+ "vicuna_v1": conv_vicuna_v1,
388
+ "llama_2": conv_llama_2,
389
+ "mistral_instruct": conv_mistral_instruct,
390
+ "chatml_direct": conv_chatml_direct,
391
+ "mistral_direct": conv_chatml_direct,
392
+
393
+ "plain": conv_llava_plain,
394
+ "v0_plain": conv_llava_plain,
395
+ "llava_v0": conv_llava_v0,
396
+ "v0_mmtag": conv_llava_v0_mmtag,
397
+ "llava_v1": conv_llava_v1,
398
+ "v1_mmtag": conv_llava_v1_mmtag,
399
+ "llava_llama_2": conv_llava_llama_2,
400
+
401
+ "mpt": conv_mpt,
402
+ }
403
+
404
+
405
+ if __name__ == "__main__":
406
+ print(default_conversation.get_prompt())
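Not part of the commit: a minimal sketch of how the templates above are consumed, assuming the package is importable as `llavavid`.

```python
# Illustrative only: builds a Vicuna-v1-style prompt with one image placeholder.
from llavavid.conversation import conv_templates

conv = conv_templates["vicuna_v1"].copy()
conv.append_message(conv.roles[0], "<image>\nWhat is happening in this video?")
conv.append_message(conv.roles[1], None)   # leave the assistant slot empty for generation
prompt = conv.get_prompt()
print(prompt)
```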
llavavid/mm_utils.py ADDED
@@ -0,0 +1,246 @@
1
+ from PIL import Image
2
+ from io import BytesIO
3
+ import base64
4
+ import torch
5
+ import math
6
+ import ast
7
+
8
+ from transformers import StoppingCriteria
9
+ from llavavid.constants import IMAGE_TOKEN_INDEX
10
+
11
+
12
+ def select_best_resolution(original_size, possible_resolutions):
13
+ """
14
+ Selects the best resolution from a list of possible resolutions based on the original size.
15
+
16
+ Args:
17
+ original_size (tuple): The original size of the image in the format (width, height).
18
+ possible_resolutions (list): A list of possible resolutions in the format [(width1, height1), (width2, height2), ...].
19
+
20
+ Returns:
21
+ tuple: The best fit resolution in the format (width, height).
22
+ """
23
+ original_width, original_height = original_size
24
+ best_fit = None
25
+ max_effective_resolution = 0
26
+ min_wasted_resolution = float('inf')
27
+
28
+ for width, height in possible_resolutions:
29
+ scale = min(width / original_width, height / original_height)
30
+ downscaled_width, downscaled_height = int(original_width * scale), int(original_height * scale)
31
+ effective_resolution = min(downscaled_width * downscaled_height, original_width * original_height)
32
+ wasted_resolution = (width * height) - effective_resolution
33
+
34
+ if effective_resolution > max_effective_resolution or (effective_resolution == max_effective_resolution and wasted_resolution < min_wasted_resolution):
35
+ max_effective_resolution = effective_resolution
36
+ min_wasted_resolution = wasted_resolution
37
+ best_fit = (width, height)
38
+
39
+ return best_fit
40
+
41
+
42
+ def resize_and_pad_image(image, target_resolution):
43
+ """
44
+ Resize and pad an image to a target resolution while maintaining aspect ratio.
45
+
46
+ Args:
47
+ image (PIL.Image.Image): The input image.
48
+ target_resolution (tuple): The target resolution (width, height) of the image.
49
+
50
+ Returns:
51
+ PIL.Image.Image: The resized and padded image.
52
+ """
53
+ original_width, original_height = image.size
54
+ target_width, target_height = target_resolution
55
+
56
+ scale_w = target_width / original_width
57
+ scale_h = target_height / original_height
58
+
59
+ if scale_w < scale_h:
60
+ new_width = target_width
61
+ new_height = min(math.ceil(original_height * scale_w), target_height)
62
+ else:
63
+ new_height = target_height
64
+ new_width = min(math.ceil(original_width * scale_h), target_width)
65
+
66
+ # Resize the image
67
+ resized_image = image.resize((new_width, new_height))
68
+
69
+ new_image = Image.new('RGB', (target_width, target_height), (0, 0, 0))
70
+ paste_x = (target_width - new_width) // 2
71
+ paste_y = (target_height - new_height) // 2
72
+ new_image.paste(resized_image, (paste_x, paste_y))
73
+
74
+ return new_image
75
+
76
+
77
+ def divide_to_patches(image, patch_size):
78
+ """
79
+ Divides an image into patches of a specified size.
80
+
81
+ Args:
82
+ image (PIL.Image.Image): The input image.
83
+ patch_size (int): The size of each patch.
84
+
85
+ Returns:
86
+ list: A list of PIL.Image.Image objects representing the patches.
87
+ """
88
+ patches = []
89
+ width, height = image.size
90
+ for i in range(0, height, patch_size):
91
+ for j in range(0, width, patch_size):
92
+ box = (j, i, j + patch_size, i + patch_size)
93
+ patch = image.crop(box)
94
+ patches.append(patch)
95
+
96
+ return patches
97
+
98
+
99
+ def get_anyres_image_grid_shape(image_size, grid_pinpoints, patch_size):
100
+ """
101
+ Calculate the shape of the image patch grid after the preprocessing for images of any resolution.
102
+
103
+ Args:
104
+ image_size (tuple): The size of the input image in the format (width, height).
105
+ grid_pinpoints (str): A string representation of a list of possible resolutions.
106
+ patch_size (int): The size of each image patch.
107
+
108
+ Returns:
109
+ tuple: The shape of the image patch grid in the format (width, height).
110
+ """
111
+ if type(grid_pinpoints) is list:
112
+ possible_resolutions = grid_pinpoints
113
+ else:
114
+ possible_resolutions = ast.literal_eval(grid_pinpoints)
115
+ width, height = select_best_resolution(image_size, possible_resolutions)
116
+ return width // patch_size, height // patch_size
117
+
118
+
119
+ def process_anyres_image(image, processor, grid_pinpoints):
120
+ """
121
+ Process an image with variable resolutions.
122
+
123
+ Args:
124
+ image (PIL.Image.Image): The input image to be processed.
125
+ processor: The image processor object.
126
+ grid_pinpoints (str): A string representation of a list of possible resolutions.
127
+
128
+ Returns:
129
+ torch.Tensor: A tensor containing the processed image patches.
130
+ """
131
+ if type(grid_pinpoints) is list:
132
+ possible_resolutions = grid_pinpoints
133
+ else:
134
+ possible_resolutions = ast.literal_eval(grid_pinpoints)
135
+ best_resolution = select_best_resolution(image.size, possible_resolutions)
136
+ image_padded = resize_and_pad_image(image, best_resolution)
137
+
138
+ patches = divide_to_patches(image_padded, processor.crop_size['height'])
139
+
140
+ image_original_resize = image.resize((processor.size['shortest_edge'], processor.size['shortest_edge']))
141
+
142
+ image_patches = [image_original_resize] + patches
143
+ image_patches = [processor.preprocess(image_patch, return_tensors='pt')['pixel_values'][0]
144
+ for image_patch in image_patches]
145
+ return torch.stack(image_patches, dim=0)
146
+
147
+
148
+ def load_image_from_base64(image):
149
+ return Image.open(BytesIO(base64.b64decode(image)))
150
+
151
+
152
+ def expand2square(pil_img, background_color):
153
+ width, height = pil_img.size
154
+ if width == height:
155
+ return pil_img
156
+ elif width > height:
157
+ result = Image.new(pil_img.mode, (width, width), background_color)
158
+ result.paste(pil_img, (0, (width - height) // 2))
159
+ return result
160
+ else:
161
+ result = Image.new(pil_img.mode, (height, height), background_color)
162
+ result.paste(pil_img, ((height - width) // 2, 0))
163
+ return result
164
+
165
+
166
+ def process_images(images, image_processor, model_cfg):
167
+ image_aspect_ratio = getattr(model_cfg, "image_aspect_ratio", None)
168
+ new_images = []
169
+ if image_aspect_ratio == 'pad':
170
+ for image in images:
171
+ image = expand2square(image, tuple(int(x*255) for x in image_processor.image_mean))
172
+ image = image_processor.preprocess(image, return_tensors='pt')['pixel_values'][0]
173
+ new_images.append(image)
174
+ elif image_aspect_ratio == "anyres":
175
+ for image in images:
176
+ image = process_anyres_image(image, image_processor, model_cfg.image_grid_pinpoints)
177
+ new_images.append(image)
178
+ else:
179
+ return image_processor(images, return_tensors='pt')['pixel_values']
180
+ if all(x.shape == new_images[0].shape for x in new_images):
181
+ new_images = torch.stack(new_images, dim=0)
182
+ return new_images
183
+
184
+
185
+ def tokenizer_image_token(prompt, tokenizer, image_token_index=IMAGE_TOKEN_INDEX, return_tensors=None):
186
+ prompt_chunks = [tokenizer(chunk).input_ids for chunk in prompt.split('<image>')]
187
+
188
+ def insert_separator(X, sep):
189
+ return [ele for sublist in zip(X, [sep]*len(X)) for ele in sublist][:-1]
190
+
191
+ input_ids = []
192
+ offset = 0
193
+ if len(prompt_chunks) > 0 and len(prompt_chunks[0]) > 0 and prompt_chunks[0][0] == tokenizer.bos_token_id:
194
+ offset = 1
195
+ input_ids.append(prompt_chunks[0][0])
196
+
197
+ for x in insert_separator(prompt_chunks, [image_token_index] * (offset + 1)):
198
+ input_ids.extend(x[offset:])
199
+
200
+ if return_tensors is not None:
201
+ if return_tensors == 'pt':
202
+ return torch.tensor(input_ids, dtype=torch.long)
203
+ raise ValueError(f'Unsupported tensor type: {return_tensors}')
204
+ return input_ids
205
+
206
+
207
+ def get_model_name_from_path(model_path):
208
+ model_path = model_path.strip("/")
209
+ model_paths = model_path.split("/")
210
+ if model_paths[-1].startswith('checkpoint-'):
211
+ return model_paths[-2] + "_" + model_paths[-1]
212
+ else:
213
+ return model_paths[-1]
214
+
215
+ class KeywordsStoppingCriteria(StoppingCriteria):
216
+ def __init__(self, keywords, tokenizer, input_ids):
217
+ self.keywords = keywords
218
+ self.keyword_ids = []
219
+ self.max_keyword_len = 0
220
+ for keyword in keywords:
221
+ cur_keyword_ids = tokenizer(keyword).input_ids
222
+ if len(cur_keyword_ids) > 1 and cur_keyword_ids[0] == tokenizer.bos_token_id:
223
+ cur_keyword_ids = cur_keyword_ids[1:]
224
+ if len(cur_keyword_ids) > self.max_keyword_len:
225
+ self.max_keyword_len = len(cur_keyword_ids)
226
+ self.keyword_ids.append(torch.tensor(cur_keyword_ids))
227
+ self.tokenizer = tokenizer
228
+ self.start_len = input_ids.shape[1]
229
+
230
+ def call_for_batch(self, output_ids: torch.LongTensor, scores: torch.FloatTensor, **kwargs) -> bool:
231
+ offset = min(output_ids.shape[1] - self.start_len, self.max_keyword_len)
232
+ self.keyword_ids = [keyword_id.to(output_ids.device) for keyword_id in self.keyword_ids]
233
+ for keyword_id in self.keyword_ids:
234
+ if (output_ids[0, -keyword_id.shape[0]:] == keyword_id).all():
235
+ return True
236
+ outputs = self.tokenizer.batch_decode(output_ids[:, -offset:], skip_special_tokens=True)[0]
237
+ for keyword in self.keywords:
238
+ if keyword in outputs:
239
+ return True
240
+ return False
241
+
242
+ def __call__(self, output_ids: torch.LongTensor, scores: torch.FloatTensor, **kwargs) -> bool:
243
+ outputs = []
244
+ for i in range(output_ids.shape[0]):
245
+ outputs.append(self.call_for_batch(output_ids[i].unsqueeze(0), scores))
246
+ return all(outputs)
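Not part of the commit: the resolution helpers above are pure functions, so they can be sanity-checked with toy numbers (the `grid_pinpoints` list below is an illustrative choice, not a shipped configuration).

```python
# Illustrative only: toy resolutions to show what the anyres helpers return.
from llavavid.mm_utils import select_best_resolution, get_anyres_image_grid_shape

grid_pinpoints = [(336, 672), (672, 336), (672, 672), (1008, 336), (336, 1008)]

# A wide 800x400 image keeps the most pixels at 672x336 (a 2x1 grid of 336px patches).
print(select_best_resolution((800, 400), grid_pinpoints))            # (672, 336)
print(get_anyres_image_grid_shape((800, 400), grid_pinpoints, 336))  # (2, 1)
```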
llavavid/model/__init__.py ADDED
@@ -0,0 +1,6 @@
+ # try:
+ from .language_model.llava_llama import LlavaLlamaForCausalLM, LlavaConfig
+ from .language_model.llava_mpt import LlavaMptForCausalLM, LlavaMptConfig
+ from .language_model.llava_mistral import LlavaMistralForCausalLM, LlavaMistralConfig
+ # except:
+ #     pass
llavavid/model/apply_delta.py ADDED
@@ -0,0 +1,48 @@
1
+ """
2
+ Usage:
3
+ python3 -m fastchat.model.apply_delta --base ~/model_weights/llama-7b --target ~/model_weights/vicuna-7b --delta lmsys/vicuna-7b-delta
4
+ """
5
+ import argparse
6
+
7
+ import torch
8
+ from tqdm import tqdm
9
+ from transformers import AutoTokenizer, AutoModelForCausalLM
10
+ from llavavid import LlavaLlamaForCausalLM
11
+
12
+
13
+ def apply_delta(base_model_path, target_model_path, delta_path):
14
+ print("Loading base model")
15
+ base = AutoModelForCausalLM.from_pretrained(
16
+ base_model_path, torch_dtype=torch.float16, low_cpu_mem_usage=True)
17
+
18
+ print("Loading delta")
19
+ delta = LlavaLlamaForCausalLM.from_pretrained(delta_path, torch_dtype=torch.float16, low_cpu_mem_usage=True)
20
+ delta_tokenizer = AutoTokenizer.from_pretrained(delta_path)
21
+
22
+ print("Applying delta")
23
+ for name, param in tqdm(delta.state_dict().items(), desc="Applying delta"):
24
+ if name not in base.state_dict():
25
+ assert name in ['model.mm_projector.weight', 'model.mm_projector.bias'], f'{name} not in base model'
26
+ continue
27
+ if param.data.shape == base.state_dict()[name].shape:
28
+ param.data += base.state_dict()[name]
29
+ else:
30
+ assert name in ['model.embed_tokens.weight', 'lm_head.weight'], \
31
+ f'{name} dimension mismatch: {param.data.shape} vs {base.state_dict()[name].shape}'
32
+ bparam = base.state_dict()[name]
33
+ param.data[:bparam.shape[0], :bparam.shape[1]] += bparam
34
+
35
+ print("Saving target model")
36
+ delta.save_pretrained(target_model_path)
37
+ delta_tokenizer.save_pretrained(target_model_path)
38
+
39
+
40
+ if __name__ == "__main__":
41
+ parser = argparse.ArgumentParser()
42
+ parser.add_argument("--base-model-path", type=str, required=True)
43
+ parser.add_argument("--target-model-path", type=str, required=True)
44
+ parser.add_argument("--delta-path", type=str, required=True)
45
+
46
+ args = parser.parse_args()
47
+
48
+ apply_delta(args.base_model_path, args.target_model_path, args.delta_path)
llavavid/model/builder.py ADDED
@@ -0,0 +1,172 @@
1
+ # Copyright 2023 Haotian Liu
2
+ #
3
+ # Licensed under the Apache License, Version 2.0 (the "License");
4
+ # you may not use this file except in compliance with the License.
5
+ # You may obtain a copy of the License at
6
+ #
7
+ # http://www.apache.org/licenses/LICENSE-2.0
8
+ #
9
+ # Unless required by applicable law or agreed to in writing, software
10
+ # distributed under the License is distributed on an "AS IS" BASIS,
11
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12
+ # See the License for the specific language governing permissions and
13
+ # limitations under the License.
14
+
15
+
16
+ import os
17
+ import warnings
18
+ import shutil
19
+
20
+ from transformers import AutoTokenizer, AutoModelForCausalLM, AutoConfig, BitsAndBytesConfig
21
+ import torch
22
+ from llavavid.model import *
23
+ from llavavid.constants import DEFAULT_IMAGE_PATCH_TOKEN, DEFAULT_IM_START_TOKEN, DEFAULT_IM_END_TOKEN
24
+
25
+
26
+ def load_pretrained_model(model_path, model_base, model_name, load_8bit=False, load_4bit=False, device_map="auto", overwrite_config=None):
27
+ kwargs = {"device_map": device_map}
28
+
29
+ # import pdb;pdb.set_trace()
30
+ if load_8bit:
31
+ kwargs["load_in_8bit"] = True
32
+ elif load_4bit:
33
+ kwargs["load_in_4bit"] = True
34
+ kwargs["quantization_config"] = BitsAndBytesConfig(load_in_4bit=True, bnb_4bit_compute_dtype=torch.float16, bnb_4bit_use_double_quant=True, bnb_4bit_quant_type="nf4")
35
+ else:
36
+ kwargs["torch_dtype"] = torch.float16
37
+
38
+ if "llava" in model_name.lower():
39
+ # Load LLaVA model
40
+ if "lora" in model_name.lower() and model_base is None:
41
+ warnings.warn(
42
+ "There is `lora` in model name but no `model_base` is provided. If you are loading a LoRA model, please provide the `model_base` argument. Detailed instruction: https://github.com/haotian-liu/LLaVA#launch-a-model-worker-lora-weights-unmerged."
43
+ )
44
+ if "lora" in model_name.lower() and model_base is not None:
45
+ lora_cfg_pretrained = AutoConfig.from_pretrained(model_path)
46
+ tokenizer = AutoTokenizer.from_pretrained(model_base, use_fast=False)
47
+ print("Loading LLaVA from base model...")
48
+ if "mixtral" in model_name.lower():
49
+ model = LlavaMixtralForCausalLM.from_pretrained(model_base, low_cpu_mem_usage=True, config=lora_cfg_pretrained, **kwargs)
50
+ else:
51
+ model = LlavaLlamaForCausalLM.from_pretrained(model_base, low_cpu_mem_usage=True, config=lora_cfg_pretrained, **kwargs)
52
+ token_num, tokem_dim = model.lm_head.out_features, model.lm_head.in_features
53
+ if model.lm_head.weight.shape[0] != token_num:
54
+ model.lm_head.weight = torch.nn.Parameter(torch.empty(token_num, tokem_dim, device=model.device, dtype=model.dtype))
55
+ model.model.embed_tokens.weight = torch.nn.Parameter(torch.empty(token_num, tokem_dim, device=model.device, dtype=model.dtype))
56
+
57
+ print("Loading additional LLaVA weights...")
58
+ if os.path.exists(os.path.join(model_path, "non_lora_trainables.bin")):
59
+ non_lora_trainables = torch.load(os.path.join(model_path, "non_lora_trainables.bin"), map_location="cpu")
60
+ else:
61
+ # this is probably from HF Hub
62
+ from huggingface_hub import hf_hub_download
63
+
64
+ def load_from_hf(repo_id, filename, subfolder=None):
65
+ cache_file = hf_hub_download(repo_id=repo_id, filename=filename, subfolder=subfolder)
66
+ return torch.load(cache_file, map_location="cpu")
67
+
68
+ non_lora_trainables = load_from_hf(model_path, "non_lora_trainables.bin")
69
+ non_lora_trainables = {(k[11:] if k.startswith("base_model.") else k): v for k, v in non_lora_trainables.items()}
70
+ if any(k.startswith("model.model.") for k in non_lora_trainables):
71
+ non_lora_trainables = {(k[6:] if k.startswith("model.") else k): v for k, v in non_lora_trainables.items()}
72
+ model.load_state_dict(non_lora_trainables, strict=False)
73
+
74
+ from peft import PeftModel
75
+
76
+ print("Loading LoRA weights...")
77
+ model = PeftModel.from_pretrained(model, model_path)
78
+ print("Merging LoRA weights...")
79
+ model = model.merge_and_unload()
80
+ print("Model is loaded...")
81
+ elif model_base is not None:
82
+ # this may be mm projector only
83
+ print("Loading LLaVA from base model...")
84
+ if "mpt" in model_name.lower().replace("prompt", ""):
85
+ if not os.path.isfile(os.path.join(model_path, "configuration_mpt.py")):
86
+ shutil.copyfile(os.path.join(model_base, "configuration_mpt.py"), os.path.join(model_path, "configuration_mpt.py"))
87
+ tokenizer = AutoTokenizer.from_pretrained(model_base, use_fast=True)
88
+ cfg_pretrained = AutoConfig.from_pretrained(model_path, trust_remote_code=True)
89
+ model = LlavaMptForCausalLM.from_pretrained(model_base, low_cpu_mem_usage=True, config=cfg_pretrained, **kwargs)
90
+ else:
91
+ tokenizer = AutoTokenizer.from_pretrained(model_base, use_fast=False)
92
+ cfg_pretrained = AutoConfig.from_pretrained(model_path)
93
+ if overwrite_config is not None:
94
+ print(f"Overwriting config with {overwrite_config}")
95
+ for k, v in overwrite_config.items():
96
+ setattr(cfg_pretrained, k, v)
97
+ model = LlavaLlamaForCausalLM.from_pretrained(model_base, low_cpu_mem_usage=True, config=cfg_pretrained, **kwargs)
98
+
99
+ mm_projector_weights = torch.load(os.path.join(model_path, "mm_projector.bin"), map_location="cpu")
100
+ mm_projector_weights = {k: v.to(torch.float16) for k, v in mm_projector_weights.items()}
101
+ model.load_state_dict(mm_projector_weights, strict=False)
102
+ else:
103
+ if "mpt" in model_name.lower().replace("prompt", ""):
104
+ tokenizer = AutoTokenizer.from_pretrained(model_path, use_fast=True)
105
+ model = LlavaMptForCausalLM.from_pretrained(model_path, low_cpu_mem_usage=True, **kwargs)
106
+ elif "mixtral" in model_name.lower() and "vicuna" not in model_name.lower() and "mistral" not in model_name.lower():
107
+ tokenizer = AutoTokenizer.from_pretrained(model_path)
108
+ model = LlavaMixtralForCausalLM.from_pretrained(model_path, low_cpu_mem_usage=True, **kwargs)
109
+ elif "mistral" in model_name.lower() or "zephyr" in model_name.lower():
110
+ tokenizer = AutoTokenizer.from_pretrained(model_path)
111
+ cfg_pretrained = AutoConfig.from_pretrained(model_path)
112
+ # import pdb;pdb.set_trace()
113
+ if overwrite_config is not None:
114
+ print(f"Overwriting config with {overwrite_config}")
115
+ for k, v in overwrite_config.items():
116
+ setattr(cfg_pretrained, k, v)
117
+ model = LlavaMistralForCausalLM.from_pretrained(model_path, low_cpu_mem_usage=True, config=cfg_pretrained, **kwargs)
118
+ else:
119
+ tokenizer = AutoTokenizer.from_pretrained(model_path, use_fast=False)
120
+ cfg_pretrained = AutoConfig.from_pretrained(model_path)
121
+ if overwrite_config is not None:
122
+ print(f"Overwriting config with {overwrite_config}")
123
+ for k, v in overwrite_config.items():
124
+ setattr(cfg_pretrained, k, v)
125
+ model = LlavaLlamaForCausalLM.from_pretrained(model_path, low_cpu_mem_usage=True, config=cfg_pretrained, **kwargs)
126
+ else:
127
+ # Load language model
128
+ if model_base is not None:
129
+ # PEFT model
130
+ from peft import PeftModel
131
+
132
+ tokenizer = AutoTokenizer.from_pretrained(model_base, use_fast=False)
133
+ model = AutoModelForCausalLM.from_pretrained(model_base, torch_dtype=torch.float16, low_cpu_mem_usage=True, device_map="auto")
134
+ print(f"Loading LoRA weights from {model_path}")
135
+ model = PeftModel.from_pretrained(model, model_path)
136
+ print(f"Merging weights")
137
+ model = model.merge_and_unload()
138
+ print("Convert to FP16...")
139
+ model.to(torch.float16)
140
+ else:
141
+ use_fast = False
142
+ if "mpt" in model_name.lower().replace("prompt", ""):
143
+ tokenizer = AutoTokenizer.from_pretrained(model_path, use_fast=True)
144
+ model = AutoModelForCausalLM.from_pretrained(model_path, low_cpu_mem_usage=True, trust_remote_code=True, **kwargs)
145
+ else:
146
+ tokenizer = AutoTokenizer.from_pretrained(model_path, use_fast=False)
147
+ model = AutoModelForCausalLM.from_pretrained(model_path, low_cpu_mem_usage=True, **kwargs)
148
+
149
+ image_processor = None
150
+
151
+ assert "llava" in model_name.lower(), "Only LLaVA models are supported for video chatbot."
152
+ mm_use_im_start_end = getattr(model.config, "mm_use_im_start_end", False)
153
+ mm_use_im_patch_token = getattr(model.config, "mm_use_im_patch_token", True)
154
+ if mm_use_im_patch_token:
155
+ tokenizer.add_tokens([DEFAULT_IMAGE_PATCH_TOKEN], special_tokens=True)
156
+ if mm_use_im_start_end:
157
+ tokenizer.add_tokens([DEFAULT_IM_START_TOKEN, DEFAULT_IM_END_TOKEN], special_tokens=True)
158
+ model.resize_token_embeddings(len(tokenizer))
159
+
160
+ vision_tower = model.get_vision_tower()
161
+ if not vision_tower.is_loaded:
162
+ vision_tower.load_model(device_map=device_map)
163
+ if device_map != "auto":
164
+ vision_tower.to(device="cuda", dtype=torch.float16)
165
+ image_processor = vision_tower.image_processor
166
+
167
+ if hasattr(model.config, "max_sequence_length"):
168
+ context_len = model.config.max_sequence_length
169
+ else:
170
+ context_len = 2048
171
+
172
+ return tokenizer, model, image_processor, context_len
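Not part of the commit: a typical call pattern for the loader above. The model path is a placeholder, and the `overwrite_config` key shown is only an example of the hook; consult the checkpoint's own documentation for the exact settings it expects.

```python
# Illustrative only: the model path below is a placeholder, not a pinned checkpoint.
from llavavid.model.builder import load_pretrained_model
from llavavid.mm_utils import get_model_name_from_path

model_path = "path/or/hub-id/of-a-llava-video-checkpoint"
model_name = get_model_name_from_path(model_path)   # must contain "llava" for this builder
tokenizer, model, image_processor, context_len = load_pretrained_model(
    model_path,
    model_base=None,
    model_name=model_name,
    load_4bit=True,                                   # optional: NF4 quantization via bitsandbytes
    overwrite_config={"mm_spatial_pool_stride": 2},   # example use of the overwrite_config hook
)
```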
llavavid/model/consolidate.py ADDED
@@ -0,0 +1,29 @@
+ """
+ Usage:
+ python3 -m llavavid.model.consolidate --src ~/model_weights/llava-7b --dst ~/model_weights/llava-7b_consolidate
+ """
+ import argparse
+
+ import torch
+ from transformers import AutoTokenizer, AutoModelForCausalLM
+ from llavavid.model import *
+ from llavavid.model.utils import auto_upgrade
+
+
+ def consolidate_ckpt(src_path, dst_path):
+     print("Loading model")
+     auto_upgrade(src_path)
+     src_model = AutoModelForCausalLM.from_pretrained(src_path, torch_dtype=torch.float16, low_cpu_mem_usage=True)
+     src_tokenizer = AutoTokenizer.from_pretrained(src_path, use_fast=False)
+     src_model.save_pretrained(dst_path)
+     src_tokenizer.save_pretrained(dst_path)
+
+
+ if __name__ == "__main__":
+     parser = argparse.ArgumentParser()
+     parser.add_argument("--src", type=str, required=True)
+     parser.add_argument("--dst", type=str, required=True)
+
+     args = parser.parse_args()
+
+     consolidate_ckpt(args.src, args.dst)
llavavid/model/language_model/llava_llama.py ADDED
@@ -0,0 +1,137 @@
1
+ # Copyright 2023 Haotian Liu
2
+ #
3
+ # Licensed under the Apache License, Version 2.0 (the "License");
4
+ # you may not use this file except in compliance with the License.
5
+ # You may obtain a copy of the License at
6
+ #
7
+ # http://www.apache.org/licenses/LICENSE-2.0
8
+ #
9
+ # Unless required by applicable law or agreed to in writing, software
10
+ # distributed under the License is distributed on an "AS IS" BASIS,
11
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12
+ # See the License for the specific language governing permissions and
13
+ # limitations under the License.
14
+
15
+
16
+ from typing import List, Optional, Tuple, Union
17
+
18
+ import torch
19
+ import torch.nn as nn
20
+
21
+ from transformers import AutoConfig, AutoModelForCausalLM, LlamaConfig, LlamaModel, LlamaForCausalLM
22
+
23
+ from transformers.modeling_outputs import CausalLMOutputWithPast
24
+ from transformers.generation.utils import GenerateOutput
25
+
26
+ from llavavid.model.llava_arch import LlavaMetaModel, LlavaMetaForCausalLM
27
+
28
+ import torch.distributed as dist
29
+
30
+
31
+
32
+
33
+
34
+ class LlavaConfig(LlamaConfig):
35
+ model_type = "llava_llama"
36
+
37
+
38
+
39
+ class LlavaLlamaModel(LlavaMetaModel, LlamaModel):
40
+ config_class = LlavaConfig
41
+
42
+ def __init__(self, config: LlamaConfig):
43
+ super(LlavaLlamaModel, self).__init__(config)
44
+
45
+
46
+ class LlavaLlamaForCausalLM(LlamaForCausalLM, LlavaMetaForCausalLM):
47
+ config_class = LlavaConfig
48
+
49
+ def __init__(self, config):
50
+ # import pdb; pdb.set_trace()
51
+ LlamaForCausalLM.__init__(self, config)
52
+ self.model = LlavaLlamaModel(config)
53
+
54
+ self.lm_head = nn.Linear(config.hidden_size, config.vocab_size, bias=False)
55
+
56
+ # Initialize weights and apply final processing
57
+ self.post_init()
58
+
59
+ def get_model(self):
60
+ return self.model
61
+
62
+ def forward(
63
+ self,
64
+ input_ids: torch.LongTensor = None,
65
+ attention_mask: Optional[torch.Tensor] = None,
66
+ position_ids: Optional[torch.LongTensor] = None,
67
+ past_key_values: Optional[List[torch.FloatTensor]] = None,
68
+ inputs_embeds: Optional[torch.FloatTensor] = None,
69
+ labels: Optional[torch.LongTensor] = None,
70
+ use_cache: Optional[bool] = None,
71
+ output_attentions: Optional[bool] = None,
72
+ output_hidden_states: Optional[bool] = None,
73
+ images: Optional[torch.FloatTensor] = None,
74
+ prompts: Optional[List[str]] = None,
75
+ modalities: Optional[List[str]] = None,
76
+ image_sizes: Optional[List[List[int]]] = None,
77
+ return_dict: Optional[bool] = None,
78
+ cache_position: Optional[bool] = None,
79
+ ) -> Union[Tuple, CausalLMOutputWithPast]:
80
+
81
+ if inputs_embeds is None:
82
+ (input_ids, position_ids, attention_mask, past_key_values, inputs_embeds, labels) = self.prepare_inputs_labels_for_multimodal(
83
+ input_ids, position_ids, attention_mask, past_key_values, labels, images, modalities, image_sizes, prompts
84
+ )
85
+
86
+ # import pdb; pdb.set_trace()
87
+ return super().forward(
88
+ input_ids=input_ids,
89
+ attention_mask=attention_mask,
90
+ position_ids=position_ids,
91
+ past_key_values=past_key_values,
92
+ inputs_embeds=inputs_embeds,
93
+ labels=labels,
94
+ use_cache=use_cache,
95
+ output_attentions=output_attentions,
96
+ output_hidden_states=output_hidden_states,
97
+ return_dict=return_dict,
98
+ )
99
+
100
+ @torch.no_grad()
101
+ def generate(
102
+ self,
103
+ inputs: Optional[torch.Tensor] = None,
104
+ images: Optional[torch.Tensor] = None,
105
+ image_sizes: Optional[torch.Tensor] = None,
106
+ **kwargs,
107
+ ) -> Union[GenerateOutput, torch.LongTensor]:
108
+ modalities = kwargs.pop("modalities", None)
109
+ position_ids = kwargs.pop("position_ids", None)
110
+ attention_mask = kwargs.pop("attention_mask", None)
111
+ if "inputs_embeds" in kwargs:
112
+ raise NotImplementedError("`inputs_embeds` is not supported")
113
+
114
+ if images is not None:
115
+ (inputs, position_ids, attention_mask, _, inputs_embeds, _) = self.prepare_inputs_labels_for_multimodal(inputs, position_ids, attention_mask, None, None, images, modalities, image_sizes=image_sizes)
116
+ else:
117
+ inputs_embeds = self.get_model().embed_tokens(inputs)
118
+
119
+ # import pdb; pdb.set_trace()
120
+ return super().generate(position_ids=position_ids, attention_mask=attention_mask, inputs_embeds=inputs_embeds, **kwargs)
121
+
122
+ def prepare_inputs_for_generation(self, input_ids, past_key_values=None, inputs_embeds=None, **kwargs):
123
+ images = kwargs.pop("images", None)
124
+ image_sizes = kwargs.pop("image_sizes", None)
125
+ inputs = super().prepare_inputs_for_generation(input_ids, past_key_values=past_key_values, inputs_embeds=inputs_embeds, **kwargs)
126
+ if images is not None:
127
+ inputs["images"] = images
128
+ if image_sizes is not None:
129
+ inputs["image_sizes"] = image_sizes
130
+ return inputs
131
+
132
+
133
+ if LlavaConfig.model_type == "llava":
134
+ LlavaConfig.model_type = "llava_llama" # directly set to llava_dev to avoid conflict with HF's llava
135
+
136
+ AutoConfig.register("llava_llama", LlavaConfig)
137
+ AutoModelForCausalLM.register(LlavaConfig, LlavaLlamaForCausalLM)
llavavid/model/language_model/llava_mistral.py ADDED
@@ -0,0 +1,158 @@
1
+ # Copyright 2023 Haotian Liu
2
+ #
3
+ # Licensed under the Apache License, Version 2.0 (the "License");
4
+ # you may not use this file except in compliance with the License.
5
+ # You may obtain a copy of the License at
6
+ #
7
+ # http://www.apache.org/licenses/LICENSE-2.0
8
+ #
9
+ # Unless required by applicable law or agreed to in writing, software
10
+ # distributed under the License is distributed on an "AS IS" BASIS,
11
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12
+ # See the License for the specific language governing permissions and
13
+ # limitations under the License.
14
+
15
+
16
+ from typing import List, Optional, Tuple, Union
17
+
18
+ import torch
19
+ import torch.nn as nn
20
+ from torch.nn import CrossEntropyLoss
21
+
22
+ from transformers import AutoConfig, AutoModelForCausalLM, \
23
+ MistralConfig, MistralModel, MistralForCausalLM
24
+
25
+ from transformers.modeling_outputs import CausalLMOutputWithPast
26
+ from transformers.generation.utils import GenerateOutput
27
+
28
+ from llavavid.model.llava_arch import LlavaMetaModel, LlavaMetaForCausalLM
29
+
30
+
31
+ class LlavaMistralConfig(MistralConfig):
32
+ model_type = "llava_mistral"
33
+
34
+
35
+ class LlavaMistralModel(LlavaMetaModel, MistralModel):
36
+ config_class = LlavaMistralConfig
37
+
38
+ def __init__(self, config: MistralConfig):
39
+ super(LlavaMistralModel, self).__init__(config)
40
+
41
+
42
+ class LlavaMistralForCausalLM(MistralForCausalLM, LlavaMetaForCausalLM):
43
+ config_class = LlavaMistralConfig
44
+
45
+ def __init__(self, config):
46
+ super(MistralForCausalLM, self).__init__(config)
47
+ self.model = LlavaMistralModel(config)
48
+
49
+ self.lm_head = nn.Linear(config.hidden_size, config.vocab_size, bias=False)
50
+
51
+ # Initialize weights and apply final processing
52
+ self.post_init()
53
+
54
+ def get_model(self):
55
+ return self.model
56
+
57
+ def forward(
58
+ self,
59
+ input_ids: torch.LongTensor = None,
60
+ attention_mask: Optional[torch.Tensor] = None,
61
+ position_ids: Optional[torch.LongTensor] = None,
62
+ past_key_values: Optional[List[torch.FloatTensor]] = None,
63
+ inputs_embeds: Optional[torch.FloatTensor] = None,
64
+ labels: Optional[torch.LongTensor] = None,
65
+ use_cache: Optional[bool] = None,
66
+ output_attentions: Optional[bool] = None,
67
+ output_hidden_states: Optional[bool] = None,
68
+ images: Optional[torch.FloatTensor] = None,
69
+ image_sizes: Optional[List[List[int]]] = None,
70
+ return_dict: Optional[bool] = None,
71
+ ) -> Union[Tuple, CausalLMOutputWithPast]:
72
+
73
+ if inputs_embeds is None:
74
+ (
75
+ input_ids,
76
+ position_ids,
77
+ attention_mask,
78
+ past_key_values,
79
+ inputs_embeds,
80
+ labels
81
+ ) = self.prepare_inputs_labels_for_multimodal(
82
+ input_ids,
83
+ position_ids,
84
+ attention_mask,
85
+ past_key_values,
86
+ labels,
87
+ images,
88
+ image_sizes
89
+ )
90
+
91
+ return super().forward(
92
+ input_ids=input_ids,
93
+ attention_mask=attention_mask,
94
+ position_ids=position_ids,
95
+ past_key_values=past_key_values,
96
+ inputs_embeds=inputs_embeds,
97
+ labels=labels,
98
+ use_cache=use_cache,
99
+ output_attentions=output_attentions,
100
+ output_hidden_states=output_hidden_states,
101
+ return_dict=return_dict
102
+ )
103
+
104
+ @torch.no_grad()
105
+ def generate(
106
+ self,
107
+ inputs: Optional[torch.Tensor] = None,
108
+ images: Optional[torch.Tensor] = None,
109
+ image_sizes: Optional[torch.Tensor] = None,
110
+ **kwargs,
111
+ ) -> Union[GenerateOutput, torch.LongTensor]:
112
+ position_ids = kwargs.pop("position_ids", None)
113
+ attention_mask = kwargs.pop("attention_mask", None)
114
+ if "inputs_embeds" in kwargs:
115
+ raise NotImplementedError("`inputs_embeds` is not supported")
116
+
117
+ if images is not None:
118
+ (
119
+ inputs,
120
+ position_ids,
121
+ attention_mask,
122
+ _,
123
+ inputs_embeds,
124
+ _
125
+ ) = self.prepare_inputs_labels_for_multimodal(
126
+ inputs,
127
+ position_ids,
128
+ attention_mask,
129
+ None,
130
+ None,
131
+ images,
132
+ image_sizes=image_sizes
133
+ )
134
+ else:
135
+ inputs_embeds = self.get_model().embed_tokens(inputs)
136
+
137
+ return super().generate(
138
+ position_ids=position_ids,
139
+ attention_mask=attention_mask,
140
+ inputs_embeds=inputs_embeds,
141
+ **kwargs
142
+ )
143
+
144
+ def prepare_inputs_for_generation(self, input_ids, past_key_values=None,
145
+ inputs_embeds=None, **kwargs):
146
+ images = kwargs.pop("images", None)
147
+ image_sizes = kwargs.pop("image_sizes", None)
148
+ inputs = super().prepare_inputs_for_generation(
149
+ input_ids, past_key_values=past_key_values, inputs_embeds=inputs_embeds, **kwargs
150
+ )
151
+ if images is not None:
152
+ inputs['images'] = images
153
+ if image_sizes is not None:
154
+ inputs['image_sizes'] = image_sizes
155
+ return inputs
156
+
157
+ AutoConfig.register("llava_mistral", LlavaMistralConfig)
158
+ AutoModelForCausalLM.register(LlavaMistralConfig, LlavaMistralForCausalLM)
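A construction sketch for the Mistral variant with toy hyperparameters (purely illustrative; a real checkpoint supplies these in its config). It shows the wiring set up in __init__: a LlavaMistralModel behind a separate lm_head:

from llavavid.model.language_model.llava_mistral import LlavaMistralConfig, LlavaMistralForCausalLM

cfg = LlavaMistralConfig(vocab_size=1000, hidden_size=256, intermediate_size=512,
                         num_hidden_layers=2, num_attention_heads=4, num_key_value_heads=4)  # toy sizes
model = LlavaMistralForCausalLM(cfg)      # builds LlavaMistralModel + lm_head, then post_init()
print(type(model.get_model()).__name__)   # LlavaMistralModel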
llavavid/model/language_model/llava_mpt.py ADDED
@@ -0,0 +1,97 @@
1
+ # Copyright 2023 Haotian Liu
2
+ #
3
+ # Licensed under the Apache License, Version 2.0 (the "License");
4
+ # you may not use this file except in compliance with the License.
5
+ # You may obtain a copy of the License at
6
+ #
7
+ # http://www.apache.org/licenses/LICENSE-2.0
8
+ #
9
+ # Unless required by applicable law or agreed to in writing, software
10
+ # distributed under the License is distributed on an "AS IS" BASIS,
11
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12
+ # See the License for the specific language governing permissions and
13
+ # limitations under the License.
14
+
15
+
16
+ from typing import Optional, Tuple
17
+
18
+ import torch
19
+
20
+ from transformers import AutoConfig, AutoModelForCausalLM, \
21
+ MptConfig, MptForCausalLM, MptModel
22
+ from llavavid.model.llava_arch import LlavaMetaModel, LlavaMetaForCausalLM
23
+
24
+
25
+ class LlavaMptConfig(MptConfig):
26
+ model_type = "llava_mpt"
27
+
28
+
29
+ class LlavaMptModel(LlavaMetaModel, MptModel):
30
+ config_class = LlavaMptConfig
31
+
32
+ def __init__(self, config: MptConfig):
33
+ config.hidden_size = config.d_model
34
+ super(LlavaMptModel, self).__init__(config)
35
+
36
+ def embed_tokens(self, x):
37
+ return self.wte(x)
38
+
39
+
40
+ class LlavaMptForCausalLM(MptForCausalLM, LlavaMetaForCausalLM):
41
+ config_class = LlavaMptConfig
42
+ supports_gradient_checkpointing = True
43
+
44
+ def __init__(self, config):
45
+ super(MptForCausalLM, self).__init__(config)
46
+
47
+ self.transformer = LlavaMptModel(config)
48
+ self.lm_head = torch.nn.Linear(config.hidden_size, config.vocab_size, bias=False)
49
+
50
+ # Initialize weights and apply final processing
51
+ self.post_init()
52
+
53
+ def get_model(self):
54
+ return self.transformer
55
+
56
+ def _set_gradient_checkpointing(self, module, value=False):
57
+ if isinstance(module, LlavaMptModel):
58
+ module.gradient_checkpointing = value
59
+
60
+ def forward(
61
+ self,
62
+ input_ids: Optional[torch.LongTensor] = None,
63
+ past_key_values: Optional[Tuple[Tuple[torch.Tensor, torch.Tensor], ...]] = None,
64
+ attention_mask: Optional[torch.Tensor] = None,
65
+ inputs_embeds: Optional[torch.Tensor] = None,
66
+ labels: Optional[torch.Tensor] = None,
67
+ use_cache: Optional[bool] = None,
68
+ output_attentions: Optional[bool] = None,
69
+ output_hidden_states: Optional[bool] = None,
70
+ return_dict: Optional[bool] = None,
71
+ images=None):
72
+
73
+ input_ids, attention_mask, past_key_values, inputs_embeds, labels = self.prepare_inputs_labels_for_multimodal(input_ids, attention_mask, past_key_values, labels, images)
74
+
75
+ return super().forward(
76
+ input_ids,
77
+ past_key_values=past_key_values,
78
+ attention_mask=attention_mask,
79
+ inputs_embeds=inputs_embeds,
80
+ labels=labels,
81
+ use_cache=use_cache,
82
+ output_attentions=output_attentions,
83
+ output_hidden_states=output_hidden_states,
84
+ return_dict=return_dict,
85
+ )
86
+
87
+ def prepare_inputs_for_generation(self, input_ids, past_key_values=None, inputs_embeds=None, **kwargs):
88
+ images = kwargs.pop("images", None)
89
+ _inputs = super().prepare_inputs_for_generation(
90
+ input_ids, past_key_values=past_key_values, inputs_embeds=inputs_embeds, **kwargs
91
+ )
92
+ _inputs['images'] = images
93
+ return _inputs
94
+
95
+
96
+ AutoConfig.register("llava_mpt", LlavaMptConfig)
97
+ AutoModelForCausalLM.register(LlavaMptConfig, LlavaMptForCausalLM)
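The MPT variant keeps its backbone under self.transformer rather than self.model; a toy-size sketch (illustrative hyperparameters, assumed to be valid MptConfig fields):

from llavavid.model.language_model.llava_mpt import LlavaMptConfig, LlavaMptForCausalLM

cfg = LlavaMptConfig(d_model=256, n_heads=4, n_layers=2, vocab_size=1000)  # toy sizes, illustration only
model = LlavaMptForCausalLM(cfg)
print(type(model.get_model()).__name__)  # LlavaMptModel (returned from self.transformer)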
llavavid/model/llava_arch.py ADDED
@@ -0,0 +1,481 @@
1
+ # Copyright 2023 Haotian Liu
2
+ #
3
+ # Licensed under the Apache License, Version 2.0 (the "License");
4
+ # you may not use this file except in compliance with the License.
5
+ # You may obtain a copy of the License at
6
+ #
7
+ # http://www.apache.org/licenses/LICENSE-2.0
8
+ #
9
+ # Unless required by applicable law or agreed to in writing, software
10
+ # distributed under the License is distributed on an "AS IS" BASIS,
11
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12
+ # See the License for the specific language governing permissions and
13
+ # limitations under the License.
14
+
15
+
16
+ from abc import ABC, abstractmethod
17
+
18
+ import torch.distributed as dist
19
+
20
+ import torch
21
+ import torch.nn as nn
22
+
23
+ from .multimodal_encoder.builder import build_vision_tower
24
+ from .multimodal_resampler.builder import build_vision_resampler
25
+ from .multimodal_projector.builder import build_vision_projector
26
+
27
+ from llavavid.constants import IGNORE_INDEX, IMAGE_TOKEN_INDEX, DEFAULT_IMAGE_PATCH_TOKEN, DEFAULT_IM_START_TOKEN, DEFAULT_IM_END_TOKEN
28
+
29
+ from llavavid.mm_utils import get_anyres_image_grid_shape
30
+
31
+ import math
32
+
33
+
34
+ class LlavaMetaModel:
35
+
36
+ def __init__(self, config):
37
+ super(LlavaMetaModel, self).__init__(config)
38
+ # import pdb; pdb.set_trace()
39
+ if hasattr(config, "mm_vision_tower"):
40
+ # import pdb; pdb.set_trace()
41
+ self.vision_tower = build_vision_tower(config, delay_load=True)
42
+ self.vision_resampler = build_vision_resampler(config, vision_tower=self.vision_tower)
43
+ self.mm_projector = build_vision_projector(config, vision_cfg=self.vision_tower.config)
44
+ self.vision_resampler.mm_projector = self.mm_projector
45
+
46
+ if 'unpad' in getattr(config, 'mm_patch_merge_type', ''):
47
+ self.image_newline = nn.Parameter(
48
+ torch.empty(config.hidden_size, dtype=self.dtype)
49
+ )
50
+
51
+ def get_vision_tower(self):
52
+ vision_tower = getattr(self, "vision_tower", None)
53
+ if type(vision_tower) is list:
54
+ vision_tower = vision_tower[0]
55
+ return vision_tower
56
+
57
+ def initialize_vision_modules(self, model_args, fsdp=None):
58
+ vision_tower = model_args.vision_tower
59
+ mm_vision_select_layer = model_args.mm_vision_select_layer
60
+ mm_vision_select_feature = model_args.mm_vision_select_feature
61
+ pretrain_mm_mlp_adapter = model_args.pretrain_mm_mlp_adapter
62
+ mm_patch_merge_type = model_args.mm_patch_merge_type
63
+
64
+ self.config.mm_vision_tower = vision_tower
65
+ if self.get_vision_tower() is None:
66
+ vision_tower = build_vision_tower(model_args)
67
+ vision_resampler = build_vision_resampler(model_args, vision_tower=vision_tower)
68
+ for k, v in vision_resampler.config.items():
69
+ setattr(self.config, k, v)
70
+
71
+ if fsdp is not None and len(fsdp) > 0:
72
+ self.vision_tower = [vision_tower]
73
+ self.vision_resampler = [vision_resampler]
74
+ else:
75
+ self.vision_tower = vision_tower
76
+ self.vision_resampler = vision_resampler
77
+ else:
78
+ if fsdp is not None and len(fsdp) > 0:
79
+ vision_resampler = self.vision_resampler[0]
80
+ vision_tower = self.vision_tower[0]
81
+ else:
82
+ vision_resampler = self.vision_resampler
83
+ vision_tower = self.vision_tower
84
+ vision_tower.load_model()
85
+
86
+ # In case it is frozen by LoRA
87
+ for p in self.vision_resampler.parameters():
88
+ p.requires_grad = True
89
+
90
+ self.config.use_mm_proj = True
91
+ self.config.mm_projector_type = getattr(model_args, "mm_projector_type", "linear")
92
+ self.config.mm_hidden_size = getattr(vision_resampler, "hidden_size", vision_tower.hidden_size)
93
+ self.config.mm_vision_select_layer = mm_vision_select_layer
94
+ self.config.mm_vision_select_feature = mm_vision_select_feature
95
+ self.config.mm_patch_merge_type = mm_patch_merge_type
96
+
97
+ self.config.patchify_video_feature = getattr(model_args, "patchify_video_feature", False)
98
+
99
+ if getattr(self, "mm_projector", None) is None:
100
+ self.mm_projector = build_vision_projector(self.config, vision_cfg=vision_tower.config)
101
+
102
+ if "unpad" in mm_patch_merge_type:
103
+ embed_std = 1 / torch.sqrt(torch.tensor(self.config.hidden_size, dtype=self.dtype))
104
+ self.image_newline = nn.Parameter(torch.randn(self.config.hidden_size, dtype=self.dtype) * embed_std)
105
+ else:
106
+ # In case it is frozen by LoRA
107
+ for p in self.mm_projector.parameters():
108
+ p.requires_grad = True
109
+
110
+ if pretrain_mm_mlp_adapter is not None:
111
+ mm_projector_weights = torch.load(pretrain_mm_mlp_adapter, map_location="cpu")
112
+
113
+ def get_w(weights, keyword):
114
+ return {k.split(keyword + ".")[1]: v for k, v in weights.items() if keyword in k}
115
+
116
+ # import pdb; pdb.set_trace()
117
+ self.mm_projector.load_state_dict(get_w(mm_projector_weights, "mm_projector"))
118
+ incompatible_keys = self.vision_resampler.load_state_dict(get_w(mm_projector_weights, "vision_resampler"), strict=False)
119
+ print(incompatible_keys)
120
+
121
+ num_trainable_parameters = sum(p.numel() for p in self.vision_resampler.parameters() if p.requires_grad) / 1e6
122
+ print(f"Number of trainable parameters in vision resampler: {num_trainable_parameters}M")
123
+
124
+
125
+
126
+ def unpad_image(tensor, original_size):
127
+ """
128
+ Unpads a PyTorch tensor of a padded and resized image.
129
+
130
+ Args:
131
+ tensor (torch.Tensor): The image tensor, assumed to be in CxHxW format.
132
+ original_size (tuple): The original size of the image as (width, height).
133
+
134
+ Returns:
135
+ torch.Tensor: The unpadded image tensor.
136
+ """
137
+ original_width, original_height = original_size
138
+ current_height, current_width = tensor.shape[1:]
139
+
140
+ # Compute aspect ratios
141
+ original_aspect_ratio = original_width / original_height
142
+ current_aspect_ratio = current_width / current_height
143
+
144
+ # Determine padding size and direction
145
+ if original_aspect_ratio > current_aspect_ratio:
146
+ # Padding was added to the height
147
+ scale_factor = current_width / original_width
148
+ new_height = int(original_height * scale_factor)
149
+ padding = (current_height - new_height) // 2
150
+ unpadded_tensor = tensor[:, padding : current_height - padding, :]
151
+ else:
152
+ # Padding was added to the width
153
+ scale_factor = current_height / original_height
154
+ new_width = int(original_width * scale_factor)
155
+ padding = (current_width - new_width) // 2
156
+ unpadded_tensor = tensor[:, :, padding : current_width - padding]
157
+
158
+ return unpadded_tensor
159
+
160
+
161
+ class LlavaMetaForCausalLM(ABC):
162
+
163
+ @abstractmethod
164
+ def get_model(self):
165
+ pass
166
+
167
+ def get_vision_tower(self):
168
+ return self.get_model().get_vision_tower()
169
+
170
+ def encode_images(self, images, input_modality="image", prompts=None, image_counts=None, long_video=False):
171
+ image_features = self.get_model().get_vision_tower()(images)
172
+
173
+
174
+ if input_modality == "video":
175
+ image_features = self.get_model().vision_resampler(image_features, images=images)
176
+
177
+ image_features = self.get_model().mm_projector(image_features)
178
+
179
+ return image_features
180
+
181
+ def update_prompt(self, prompts=None):
182
+ self.prompts = prompts
183
+
184
+ def prepare_inputs_labels_for_multimodal(
185
+ self, input_ids, position_ids, attention_mask, past_key_values, labels,
186
+ images, modalities, image_sizes=None,prompts=None
187
+ ):
188
+ vision_tower = self.get_vision_tower()
189
+ if vision_tower is None or images is None or input_ids.shape[1] == 1:
190
+ return input_ids, position_ids, attention_mask, past_key_values, None, labels
191
+
192
+ # pre-process images for long video
193
+ if images[0].shape[-1] > 1000:
194
+ long_video = True
195
+ else:
196
+ long_video = False
197
+
198
+ if isinstance(modalities, str):
199
+ modalities = [modalities]
200
+
201
+ image_idx_in_batch = []
202
+ for _ in range(len(modalities)):
203
+ # if modalities[_] != "video":
204
+ if modalities[_] == "image":
205
+ image_idx_in_batch.append(_)
206
+ # import pdb; pdb.set_trace()
207
+ if type(images) is list or images.ndim == 5:
208
+ # do not reshape for long video
209
+
210
+ if not long_video:
211
+ images_list = []
212
+ for image in images:
213
+ if image.ndim == 4:
214
+ images_list.append(image)
215
+ else:
216
+ images_list.append(image.unsqueeze(0))
217
+
218
+ image_counts = [image.shape[0] for image in images_list]
219
+ try:
220
+ concat_images = torch.cat(images_list, dim=0)
221
+ except Exception as e:
222
+ print(e)
223
+ for _ in images_list:
224
+ print(_.shape)
225
+ import pdb
226
+
227
+ pdb.set_trace()
228
+
229
+ image_features_samplers = self.encode_images(concat_images, "video", prompts, image_counts, long_video=long_video)
230
+ image_features = self.encode_images(concat_images) # , prompts)#, image_counts, long_video=long_video)
231
+ split_sizes = [image.shape[0] for image in images_list]
232
+ image_features = torch.split(image_features, split_sizes, dim=0)
233
+
234
+ image_features_samplers = torch.split(image_features_samplers, split_sizes, dim=0)
235
+ image_features = [image_features[i] if i in image_idx_in_batch else image_features_samplers[i] for i in range(len(images_list))]
236
+
237
+ # import pdb; pdb.set_trace()
238
+ mm_patch_merge_type = getattr(self.config, "mm_patch_merge_type", "flat")
239
+ image_aspect_ratio = getattr(self.config, "image_aspect_ratio", "square")
240
+ if mm_patch_merge_type == "flat":
241
+ new_image_features = []
242
+ for image_idx, image_feature in enumerate(image_features):
243
+ new_image_features.append(image_feature.flatten(0, 1))
244
+ elif mm_patch_merge_type.startswith("spatial"):
245
+ new_image_features = []
246
+ for image_idx, image_feature in enumerate(image_features):
247
+ # FIXME: now assume the image is square, and split to 2x2 patches
248
+ # num_patches = h * w, where h = w = sqrt(num_patches)
249
+ # currently image_feature is a tensor of shape (4, num_patches, hidden_size)
250
+ # we want to first unflatten it to (2, 2, h, w, hidden_size)
251
+
252
+ if image_idx not in image_idx_in_batch:
253
+ new_image_features.append(image_feature.flatten(0, 1))
254
+ continue
255
+
256
+ if image_feature.shape[0] > 1:
257
+ base_image_feature = image_feature[0]
258
+ image_feature = image_feature[1:]
259
+ height = width = self.get_vision_tower().num_patches_per_side
260
+ # import pdb; pdb.set_trace()
261
+ assert height * width == base_image_feature.shape[0]
262
+ if image_aspect_ratio == "anyres":
263
+ num_patch_width, num_patch_height = get_anyres_image_grid_shape(image_sizes[image_idx], self.config.image_grid_pinpoints, self.get_vision_tower().config.image_size)
264
+ image_feature = image_feature.view(num_patch_height, num_patch_width, height, width, -1)
265
+ else:
266
+ image_feature = image_feature.view(2, 2, height, width, -1)
267
+ if "maxpool2x2" in mm_patch_merge_type:
268
+ image_feature = image_feature.permute(4, 0, 2, 1, 3).contiguous()
269
+ image_feature = image_feature.flatten(1, 2).flatten(2, 3)
270
+ image_feature = nn.functional.max_pool2d(image_feature, 2)
271
+ image_feature = image_feature.flatten(1, 2).transpose(0, 1)
272
+ elif "unpad" in mm_patch_merge_type:
273
+ # import pdb; pdb.set_trace()
274
+ #
275
+ image_feature = image_feature.permute(4, 0, 2, 1, 3).contiguous()
276
+ image_feature = image_feature.flatten(1, 2).flatten(2, 3)
277
+ image_feature = unpad_image(image_feature, image_sizes[image_idx])
278
+ image_feature = torch.cat((image_feature, self.model.image_newline[:, None, None].expand(*image_feature.shape[:-1], 1).to(image_feature.device)), dim=-1)
279
+ image_feature = image_feature.flatten(1, 2).transpose(0, 1)
280
+ else:
281
+ image_feature = image_feature.permute(0, 2, 1, 3, 4).contiguous()
282
+ image_feature = image_feature.flatten(0, 3)
283
+ if "nobase" in mm_patch_merge_type:
284
+ pass
285
+ else:
286
+ image_feature = torch.cat((base_image_feature, image_feature), dim=0)
287
+ else:
288
+ image_feature = image_feature[0]
289
+ if "unpad" in mm_patch_merge_type:
290
+ image_feature = torch.cat((image_feature, self.model.image_newline[None]), dim=0)
291
+ new_image_features.append(image_feature)
292
+ image_features = new_image_features
293
+ else:
294
+ raise ValueError(f"Unexpected mm_patch_merge_type: {self.config.mm_patch_merge_type}")
295
+ else:
296
+ image_features = self.encode_images(images)
297
+
298
+ # TODO: image start / end is not implemented here to support pretraining.
299
+ if getattr(self.config, "tune_mm_mlp_adapter", False) and getattr(self.config, "mm_use_im_start_end", False):
300
+ raise NotImplementedError
301
+
302
+ # Let's just add dummy tensors if they do not exist,
303
+ # it is a headache to deal with None all the time.
304
+ # But it is not ideal, and if you have a better idea,
305
+ # please open an issue / submit a PR, thanks.
306
+ _labels = labels
307
+ _position_ids = position_ids
308
+ _attention_mask = attention_mask
309
+ if attention_mask is None:
310
+ attention_mask = torch.ones_like(input_ids, dtype=torch.bool)
311
+ else:
312
+ attention_mask = attention_mask.bool()
313
+ if position_ids is None:
314
+ position_ids = torch.arange(0, input_ids.shape[1], dtype=torch.long, device=input_ids.device)
315
+ if labels is None:
316
+ labels = torch.full_like(input_ids, IGNORE_INDEX)
317
+
318
+ # remove the padding using attention_mask -- FIXME
319
+ _input_ids = input_ids
320
+ input_ids = [cur_input_ids[cur_attention_mask] for cur_input_ids, cur_attention_mask in zip(input_ids, attention_mask)]
321
+ labels = [cur_labels[cur_attention_mask] for cur_labels, cur_attention_mask in zip(labels, attention_mask)]
322
+
323
+ new_input_embeds = []
324
+ new_labels = []
325
+ cur_image_idx = 0
326
+ for batch_idx, cur_input_ids in enumerate(input_ids):
327
+ # import pdb; pdb.set_trace()
328
+ cur_labels = labels[batch_idx]
329
+
330
+ num_images = (cur_input_ids == IMAGE_TOKEN_INDEX).sum()
331
+ if num_images == 0:
332
+ cur_image_features = image_features[cur_image_idx]
333
+ if cur_image_features.ndim == 3:
334
+ cur_image_features = cur_image_features.squeeze(0)
335
+ cur_input_embeds_1 = self.get_model().embed_tokens(cur_input_ids)
336
+ try:
337
+ cur_input_embeds = torch.cat([cur_input_embeds_1, cur_image_features[0:0]], dim=0)
338
+ except:
339
+ import pdb
340
+ pdb.set_trace()
341
+ new_input_embeds.append(cur_input_embeds)
342
+ new_labels.append(labels[batch_idx])
343
+ cur_image_idx += 1
344
+ continue
345
+
346
+ image_token_indices = [-1] + torch.where(cur_input_ids == IMAGE_TOKEN_INDEX)[0].tolist() + [cur_input_ids.shape[0]]
347
+ cur_input_ids_noim = []
348
+ cur_labels_noim = []
349
+ for i in range(len(image_token_indices) - 1):
350
+ cur_input_ids_noim.append(cur_input_ids[image_token_indices[i] + 1 : image_token_indices[i + 1]])
351
+ cur_labels_noim.append(cur_labels[image_token_indices[i] + 1 : image_token_indices[i + 1]])
352
+ split_sizes = [x.shape[0] for x in cur_labels_noim]
353
+ cur_input_embeds = self.get_model().embed_tokens(torch.cat(cur_input_ids_noim))
354
+ cur_input_embeds_no_im = torch.split(cur_input_embeds, split_sizes, dim=0)
355
+ cur_new_input_embeds = []
356
+ cur_new_labels = []
357
+
358
+ # import pdb; pdb.set_trace()
359
+ for i in range(num_images + 1):
360
+ cur_new_input_embeds.append(cur_input_embeds_no_im[i])
361
+ cur_new_labels.append(cur_labels_noim[i])
362
+ if i < num_images:
363
+ # import pdb; pdb.set_trace()
364
+ cur_image_features = image_features[cur_image_idx]
365
+ if cur_image_features.ndim == 3:
366
+ hidden_size = cur_image_features.shape[-1]
367
+ cur_image_features = cur_image_features.reshape(-1, hidden_size)
368
+ cur_image_idx += 1
369
+ cur_new_input_embeds.append(cur_image_features)
370
+ cur_new_labels.append(torch.full((cur_image_features.shape[0],), IGNORE_INDEX, device=cur_labels.device, dtype=cur_labels.dtype))
371
+
372
+ cur_new_input_embeds = [x.to(self.device) for x in cur_new_input_embeds]
373
+
374
+ cur_new_input_embeds = torch.cat(cur_new_input_embeds)
375
+ # import pdb; pdb.set_trace()
376
+ cur_new_labels = torch.cat(cur_new_labels)
377
+
378
+ new_input_embeds.append(cur_new_input_embeds)
379
+ # import pdb; pdb.set_trace()
380
+ new_labels.append(cur_new_labels)
381
+
382
+ # Truncate sequences to max length as image embeddings can make the sequence longer
383
+ # import pdb; pdb.set_trace()
384
+ tokenizer_model_max_length = getattr(self.config, "tokenizer_model_max_length", None)
385
+ if tokenizer_model_max_length is not None:
386
+ new_input_embeds = [x[:tokenizer_model_max_length] for x in new_input_embeds]
387
+ new_labels = [x[:tokenizer_model_max_length] for x in new_labels]
388
+ # import pdb; pdb.set_trace()
389
+
390
+ # Combine them
391
+ max_len = max(x.shape[0] for x in new_input_embeds)
392
+ batch_size = len(new_input_embeds)
393
+
394
+ new_input_embeds_padded = []
395
+ new_labels_padded = torch.full((batch_size, max_len), IGNORE_INDEX, dtype=new_labels[0].dtype, device=new_labels[0].device)
396
+ attention_mask = torch.zeros((batch_size, max_len), dtype=attention_mask.dtype, device=attention_mask.device)
397
+ position_ids = torch.zeros((batch_size, max_len), dtype=position_ids.dtype, device=position_ids.device)
398
+
399
+ for i, (cur_new_embed, cur_new_labels) in enumerate(zip(new_input_embeds, new_labels)):
400
+ cur_len = cur_new_embed.shape[0]
401
+ if getattr(self.config, "tokenizer_padding_side", "right") == "left":
402
+ new_input_embeds_padded.append(torch.cat((torch.zeros((max_len - cur_len, cur_new_embed.shape[1]), dtype=cur_new_embed.dtype, device=cur_new_embed.device), cur_new_embed), dim=0))
403
+ if cur_len > 0:
404
+ new_labels_padded[i, -cur_len:] = cur_new_labels
405
+ attention_mask[i, -cur_len:] = True
406
+ position_ids[i, -cur_len:] = torch.arange(0, cur_len, dtype=position_ids.dtype, device=position_ids.device)
407
+ else:
408
+ new_input_embeds_padded.append(torch.cat((cur_new_embed, torch.zeros((max_len - cur_len, cur_new_embed.shape[1]), dtype=cur_new_embed.dtype, device=cur_new_embed.device)), dim=0))
409
+ if cur_len > 0:
410
+ new_labels_padded[i, :cur_len] = cur_new_labels
411
+ attention_mask[i, :cur_len] = True
412
+ position_ids[i, :cur_len] = torch.arange(0, cur_len, dtype=position_ids.dtype, device=position_ids.device)
413
+
414
+ new_input_embeds = torch.stack(new_input_embeds_padded, dim=0)
415
+
416
+ if _labels is None:
417
+ new_labels = None
418
+ else:
419
+ new_labels = new_labels_padded
420
+
421
+ if _attention_mask is None:
422
+ attention_mask = None
423
+ else:
424
+ attention_mask = attention_mask.to(dtype=_attention_mask.dtype)
425
+
426
+ if _position_ids is None:
427
+ position_ids = None
428
+
429
+ gpu_rank = new_input_embeds.device.index if new_input_embeds.is_cuda else None
430
+
431
+ # if dist.is_available() and dist.is_initialized():
432
+ # dist.barrier()
433
+ # print(f"{gpu_rank}\n")
434
+ # dist.barrier()
435
+ # print(f"{gpu_rank}_{modalities}_{new_input_embeds.shape}\n")
436
+ # print(f"all finished\n")
437
+ # import pdb; pdb.set_trace()
438
+
439
+ return None, position_ids, attention_mask, past_key_values, new_input_embeds, new_labels
440
+
441
+ def initialize_vision_tokenizer(self, model_args, tokenizer):
442
+ if model_args.mm_use_im_patch_token:
443
+ tokenizer.add_tokens([DEFAULT_IMAGE_PATCH_TOKEN], special_tokens=True)
444
+ self.resize_token_embeddings(len(tokenizer))
445
+
446
+ if model_args.mm_use_im_start_end:
447
+ num_new_tokens = tokenizer.add_tokens([DEFAULT_IM_START_TOKEN, DEFAULT_IM_END_TOKEN], special_tokens=True)
448
+ self.resize_token_embeddings(len(tokenizer))
449
+
450
+ if num_new_tokens > 0:
451
+ input_embeddings = self.get_input_embeddings().weight.data
452
+ output_embeddings = self.get_output_embeddings().weight.data
453
+
454
+ input_embeddings_avg = input_embeddings[:-num_new_tokens].mean(dim=0, keepdim=True)
455
+ output_embeddings_avg = output_embeddings[:-num_new_tokens].mean(dim=0, keepdim=True)
456
+
457
+ input_embeddings[-num_new_tokens:] = input_embeddings_avg
458
+ output_embeddings[-num_new_tokens:] = output_embeddings_avg
459
+
460
+ if model_args.tune_mm_mlp_adapter:
461
+ for p in self.get_input_embeddings().parameters():
462
+ p.requires_grad = True
463
+ for p in self.get_output_embeddings().parameters():
464
+ p.requires_grad = False
465
+
466
+ if model_args.pretrain_mm_mlp_adapter:
467
+ mm_projector_weights = torch.load(model_args.pretrain_mm_mlp_adapter, map_location="cpu")
468
+ embed_tokens_weight = mm_projector_weights["model.embed_tokens.weight"]
469
+ assert num_new_tokens == 2
470
+ if input_embeddings.shape == embed_tokens_weight.shape:
471
+ input_embeddings[-num_new_tokens:] = embed_tokens_weight[-num_new_tokens:]
472
+ elif embed_tokens_weight.shape[0] == num_new_tokens:
473
+ input_embeddings[-num_new_tokens:] = embed_tokens_weight
474
+ else:
475
+ raise ValueError(f"Unexpected embed_tokens_weight shape. Pretrained: {embed_tokens_weight.shape}. Current: {input_embeddings.shape}. Numer of new tokens: {num_new_tokens}.")
476
+ elif model_args.mm_use_im_patch_token:
477
+ if model_args.tune_mm_mlp_adapter:
478
+ for p in self.get_input_embeddings().parameters():
479
+ p.requires_grad = False
480
+ for p in self.get_output_embeddings().parameters():
481
+ p.requires_grad = False
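A toy check of unpad_image above (illustrative numbers, not from the diff): a 24x24 feature grid produced from a 2:1 landscape image that was padded to a square keeps only the central 12 rows; note the size tuple is unpacked as (width, height):

import torch
from llavavid.model.llava_arch import unpad_image

feat = torch.arange(3 * 24 * 24, dtype=torch.float32).view(3, 24, 24)  # C x H x W feature map
out = unpad_image(feat, original_size=(48, 24))                        # original image: 48 wide, 24 tall
print(out.shape)                                                       # torch.Size([3, 12, 24])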
llavavid/model/make_delta.py ADDED
@@ -0,0 +1,52 @@
1
+ """
2
+ Usage:
3
+ python3 -m llavavid.model.make_delta --base ~/model_weights/llama-7b --target ~/model_weights/llava-7b --delta ~/model_weights/llava-7b-delta --hub-repo-id liuhaotian/llava-7b-delta
4
+ """
5
+ import argparse
6
+
7
+ import torch
8
+ from tqdm import tqdm
9
+ from transformers import AutoTokenizer, AutoModelForCausalLM
10
+ from llavavid.model.utils import auto_upgrade
11
+
12
+
13
+ def make_delta(base_model_path, target_model_path, delta_path, hub_repo_id):
14
+ print("Loading base model")
15
+ base = AutoModelForCausalLM.from_pretrained(
16
+ base_model_path, torch_dtype=torch.float16, low_cpu_mem_usage=True)
17
+
18
+ print("Loading target model")
19
+ auto_upgrade(target_model_path)
20
+ target = AutoModelForCausalLM.from_pretrained(target_model_path, torch_dtype=torch.float16, low_cpu_mem_usage=True)
21
+
22
+ print("Calculating delta")
23
+ for name, param in tqdm(target.state_dict().items(), desc="Calculating delta"):
24
+ if name not in base.state_dict():
25
+ assert name in ['model.mm_projector.weight', 'model.mm_projector.bias'], f'{name} not in base model'
26
+ continue
27
+ if param.data.shape == base.state_dict()[name].shape:
28
+ param.data -= base.state_dict()[name]
29
+ else:
30
+ assert name in ['model.embed_tokens.weight', 'lm_head.weight'], f'{name} dimension mismatch: {param.data.shape} vs {base.state_dict()[name].shape}'
31
+ bparam = base.state_dict()[name]
32
+ param.data[:bparam.shape[0], :bparam.shape[1]] -= bparam
33
+
34
+ print("Saving delta")
35
+ if hub_repo_id:
36
+ kwargs = {"push_to_hub": True, "repo_id": hub_repo_id}
37
+ else:
38
+ kwargs = {}
39
+ target.save_pretrained(delta_path, **kwargs)
40
+ target_tokenizer = AutoTokenizer.from_pretrained(target_model_path)
41
+ target_tokenizer.save_pretrained(delta_path, **kwargs)
42
+
43
+
44
+ if __name__ == "__main__":
45
+ parser = argparse.ArgumentParser()
46
+ parser.add_argument("--base-model-path", type=str, required=True)
47
+ parser.add_argument("--target-model-path", type=str, required=True)
48
+ parser.add_argument("--delta-path", type=str, required=True)
49
+ parser.add_argument("--hub-repo-id", type=str, default=None)
50
+ args = parser.parse_args()
51
+
52
+ make_delta(args.base_model_path, args.target_model_path, args.delta_path, args.hub_repo_id)
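The same delta computation invoked programmatically rather than via the CLI (paths are placeholders; hub_repo_id=None keeps the result local):

from llavavid.model.make_delta import make_delta

make_delta(
    base_model_path="/path/to/llama-7b",       # original base LLM weights
    target_model_path="/path/to/llava-7b",     # fine-tuned multimodal weights
    delta_path="/path/to/llava-7b-delta",      # output directory for the weight delta
    hub_repo_id=None,                          # set a repo id to push_to_hub instead
)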
llavavid/model/multimodal_encoder/builder.py ADDED
@@ -0,0 +1,11 @@
1
+ import os
2
+ from llavavid.model.multimodal_encoder.clip_encoder import CLIPVisionTower
3
+
4
+
5
+ def build_vision_tower(vision_tower_cfg, **kwargs):
6
+ vision_tower = getattr(vision_tower_cfg, 'mm_vision_tower', getattr(vision_tower_cfg, 'vision_tower', None))
7
+ is_absolute_path_exists = os.path.exists(vision_tower)
8
+ if is_absolute_path_exists or vision_tower.startswith("openai") or vision_tower.startswith("laion") or "ShareGPT4V" in vision_tower:
9
+ return CLIPVisionTower(vision_tower, args=vision_tower_cfg, **kwargs)
10
+
11
+ raise ValueError(f'Unknown vision tower: {vision_tower}')
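A sketch of what the builder expects (any object exposing the mm_vision_* fields works; training code passes the parsed model arguments, and the tower name here is the usual CLIP choice, not mandated by the diff):

from types import SimpleNamespace
from llavavid.model.multimodal_encoder.builder import build_vision_tower

cfg = SimpleNamespace(
    mm_vision_tower="openai/clip-vit-large-patch14-336",  # matches the "openai" prefix branch
    mm_vision_select_layer=-2,
    mm_vision_select_feature="patch",
)
tower = build_vision_tower(cfg, delay_load=True)  # delay_load fetches only the config, not the weights
print(tower.num_patches_per_side)                 # 24 for a 336px input with 14px patches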
llavavid/model/multimodal_encoder/clip_encoder.py ADDED
@@ -0,0 +1,110 @@
1
+ import torch
2
+ import torch.nn as nn
3
+
4
+ from transformers import CLIPVisionModel, CLIPImageProcessor, CLIPVisionConfig
5
+
6
+
7
+ class CLIPVisionTower(nn.Module):
8
+ def __init__(self, vision_tower, args, delay_load=False):
9
+ super().__init__()
10
+
11
+ self.is_loaded = False
12
+
13
+ self.vision_tower_name = vision_tower
14
+ self.select_layer = args.mm_vision_select_layer
15
+ self.select_feature = getattr(args, "mm_vision_select_feature", "patch")
16
+
17
+ if not delay_load:
18
+ self.load_model()
19
+ elif getattr(args, "unfreeze_mm_vision_tower", False):
20
+ # TODO: better detector is needed.
21
+ print("`unfreeze_mm_vision_tower` is True, so the checkpoint is expected to contain the vision tower weights; loading the vision tower now.")
22
+ self.load_model()
23
+ else:
24
+ self.cfg_only = CLIPVisionConfig.from_pretrained(self.vision_tower_name)
25
+
26
+ def load_model(self, device_map=None):
27
+ if self.is_loaded:
28
+ print("{} is already loaded, `load_model` called again, skipping.".format(self.vision_tower_name))
29
+ return
30
+
31
+ # import pdb; pdb.set_trace()
32
+ self.image_processor = CLIPImageProcessor.from_pretrained(self.vision_tower_name)
33
+ self.vision_tower = CLIPVisionModel.from_pretrained(self.vision_tower_name, device_map=device_map)
34
+ self.vision_tower.requires_grad_(False)
35
+
36
+ self.is_loaded = True
37
+
38
+ def feature_select(self, image_forward_outs):
39
+ select_feature_type = self.select_feature
40
+
41
+ if self.select_feature in ["slicefour_patch", "slicefour_cls_patch"]:
42
+ select_every_k_layer = len(image_forward_outs.hidden_states) // 4
43
+ image_features = torch.cat([image_forward_outs.hidden_states[i] for i in range(select_every_k_layer + self.select_layer, len(image_forward_outs.hidden_states), select_every_k_layer)], dim=-1)
44
+ select_feature_type = select_feature_type.replace("slicefour_", "")
45
+ elif self.select_feature in ["slice_m25811_f6_patch", "slice_m25811_f6_cls_patch"]:
46
+ select_layers = [-2, -5, -8, -11, 6]
47
+ image_features = torch.cat([image_forward_outs.hidden_states[i] for i in select_layers], dim=-1)
48
+ select_feature_type = select_feature_type.replace("slice_m25811_f6_", "")
49
+ else:
50
+ image_features = image_forward_outs.hidden_states[self.select_layer]
51
+
52
+ if select_feature_type == "patch":
53
+ image_features = image_features[:, 1:]
54
+ elif select_feature_type == "cls_patch":
55
+ image_features = image_features
56
+ else:
57
+ raise ValueError(f"Unexpected select feature: {select_feature_type}")
58
+ return image_features
59
+
60
+ def forward(self, images):
61
+ if type(images) is list:
62
+ image_features = []
63
+ for image in images:
64
+ image_forward_out = self.vision_tower(image.to(device=self.device, dtype=self.dtype).unsqueeze(0), output_hidden_states=True)
65
+ image_feature = self.feature_select(image_forward_out).to(image.dtype)
66
+ image_features.append(image_feature)
67
+ else:
68
+ image_forward_outs = self.vision_tower(images.to(device=self.device, dtype=self.dtype), output_hidden_states=True)
69
+ image_features = self.feature_select(image_forward_outs).to(images.dtype)
70
+
71
+ return image_features
72
+
73
+ @property
74
+ def dummy_feature(self):
75
+ return torch.zeros(1, self.hidden_size, device=self.device, dtype=self.dtype)
76
+
77
+ @property
78
+ def dtype(self):
79
+ return self.vision_tower.dtype
80
+
81
+ @property
82
+ def device(self):
83
+ return self.vision_tower.device
84
+
85
+ @property
86
+ def config(self):
87
+ if self.is_loaded:
88
+ return self.vision_tower.config
89
+ else:
90
+ return self.cfg_only
91
+
92
+ @property
93
+ def hidden_size(self):
94
+ _hidden_size = self.config.hidden_size
95
+ if "slicefour" in self.select_feature:
96
+ _hidden_size *= 4
97
+ if "slice_m25811_f6" in self.select_feature:
98
+ _hidden_size *= 5
99
+ return _hidden_size
100
+
101
+ @property
102
+ def num_patches_per_side(self):
103
+ return self.config.image_size // self.config.patch_size
104
+
105
+ @property
106
+ def num_patches(self):
107
+ _num_patches = (self.config.image_size // self.config.patch_size) ** 2
108
+ if "cls_patch" in self.select_feature:
109
+ _num_patches += 1
110
+ return _num_patches
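A shape sketch for the default tower (assumes network access to download the OpenAI CLIP ViT-L/14-336 weights; batch size and random inputs are illustrative):

import torch
from types import SimpleNamespace
from llavavid.model.multimodal_encoder.clip_encoder import CLIPVisionTower

args = SimpleNamespace(mm_vision_select_layer=-2, mm_vision_select_feature="patch")
tower = CLIPVisionTower("openai/clip-vit-large-patch14-336", args=args)  # loads weights immediately
with torch.no_grad():
    feats = tower(torch.randn(2, 3, 336, 336))
print(feats.shape)  # torch.Size([2, 576, 1024]): 24x24 patch tokens, CLS token dropped by "patch"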
llavavid/model/multimodal_projector/builder.py ADDED
@@ -0,0 +1,51 @@
1
+ import torch
2
+ import torch.nn as nn
3
+ import re
4
+
5
+
6
+ class IdentityMap(nn.Module):
7
+ def __init__(self):
8
+ super().__init__()
9
+
10
+ def forward(self, x, *args, **kwargs):
11
+ return x
12
+
13
+ @property
14
+ def config(self):
15
+ return {"mm_projector_type": 'identity'}
16
+
17
+
18
+ class SimpleResBlock(nn.Module):
19
+ def __init__(self, channels):
20
+ super().__init__()
21
+ self.pre_norm = nn.LayerNorm(channels)
22
+
23
+ self.proj = nn.Sequential(
24
+ nn.Linear(channels, channels),
25
+ nn.GELU(),
26
+ nn.Linear(channels, channels)
27
+ )
28
+ def forward(self, x):
29
+ x = self.pre_norm(x)
30
+ return x + self.proj(x)
31
+
32
+
33
+ def build_vision_projector(config, delay_load=False, **kwargs):
34
+ projector_type = getattr(config, 'mm_projector_type', 'linear')
35
+
36
+ if projector_type == 'linear':
37
+ return nn.Linear(config.mm_hidden_size, config.hidden_size)
38
+
39
+ mlp_gelu_match = re.match(r'^mlp(\d+)x_gelu$', projector_type)
40
+ if mlp_gelu_match:
41
+ mlp_depth = int(mlp_gelu_match.group(1))
42
+ modules = [nn.Linear(config.mm_hidden_size, config.hidden_size)]
43
+ for _ in range(1, mlp_depth):
44
+ modules.append(nn.GELU())
45
+ modules.append(nn.Linear(config.hidden_size, config.hidden_size))
46
+ return nn.Sequential(*modules)
47
+
48
+ if projector_type == 'identity':
49
+ return IdentityMap()
50
+
51
+ raise ValueError(f'Unknown projector type: {projector_type}')
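A small sketch of the mlp2x_gelu branch (feature and hidden sizes are illustrative; in practice they come from the model config):

import torch
from types import SimpleNamespace
from llavavid.model.multimodal_projector.builder import build_vision_projector

cfg = SimpleNamespace(mm_projector_type="mlp2x_gelu", mm_hidden_size=1024, hidden_size=4096)
projector = build_vision_projector(cfg)            # Linear(1024, 4096) -> GELU -> Linear(4096, 4096)
print(projector(torch.randn(2, 576, 1024)).shape)  # torch.Size([2, 576, 4096])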
llavavid/model/multimodal_resampler/builder.py ADDED
@@ -0,0 +1,25 @@
1
+ import torch
2
+
3
+ from .spatial_pool import SpatialPool
4
+
5
+
6
+ class IdentityMap(torch.nn.Module):
7
+ def __init__(self):
8
+ super().__init__()
9
+
10
+ def forward(self, x, *args, **kwargs):
11
+ return x
12
+
13
+ @property
14
+ def config(self):
15
+ return {"mm_resampler_type": None}
16
+
17
+
18
+ def build_vision_resampler(model_args, delay_load=False, **kwargs):
19
+ resampler_type = getattr(model_args, "mm_resampler_type", None)
20
+ if resampler_type == "spatial_pool":
21
+ return SpatialPool(model_args, **kwargs)
22
+ elif resampler_type is None:
23
+ return IdentityMap()
24
+
25
+ raise ValueError(f"Unknown resampler type: {resampler_type}")
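A sketch of the two branches (SimpleNamespace stands in for the parsed model arguments):

from types import SimpleNamespace
from llavavid.model.multimodal_resampler.builder import build_vision_resampler

resampler = build_vision_resampler(SimpleNamespace())  # no mm_resampler_type -> pass-through IdentityMap
print(resampler.config)                                # {'mm_resampler_type': None}
# "spatial_pool" additionally requires a vision_tower kwarg and the mm_spatial_pool_* fields (next file).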
llavavid/model/multimodal_resampler/spatial_pool.py ADDED
@@ -0,0 +1,47 @@
1
+ import torch
2
+ import torch.nn as nn
3
+ import math
4
+
5
+
6
+ class SpatialPool(nn.Module):
7
+ def __init__(self, model_args, vision_tower):
8
+ super().__init__()
9
+
10
+ self.mode = model_args.mm_spatial_pool_mode
11
+ self.stride = model_args.mm_spatial_pool_stride
12
+ # import pdb; pdb.set_trace()
13
+ self.out_channels = getattr(model_args, "mm_spatial_pool_out_channels", vision_tower.hidden_size)
14
+
15
+ if self.mode == "average":
16
+ self.pool = nn.AvgPool2d(kernel_size=self.stride, stride=self.stride)
17
+ elif self.mode == "max":
18
+ self.pool = nn.MaxPool2d(kernel_size=self.stride, stride=self.stride)
19
+ elif self.mode == "conv":
20
+ self.pool = nn.Conv2d(in_channels=vision_tower.hidden_size, out_channels=self.out_channels, kernel_size=self.stride, stride=self.stride)
21
+ else:
22
+ raise ValueError(f"Unknown pooling mode: {self.mode}.")
23
+
24
+ def forward(self, image_features, images, *args, **kwargs):
25
+ # import pdb; pdb.set_trace()
26
+ ori_W = int(math.sqrt(image_features.shape[1] * images.shape[3] // images.shape[2]))
27
+ ori_H = int(ori_W * images.shape[2] // images.shape[3])
28
+
29
+ B, _, F = image_features.shape
30
+
31
+ image_features_spatial = image_features.view(B, ori_H, ori_H, F).permute(0, 3, 1, 2)
32
+ image_features_spatial_pool = self.pool(image_features_spatial)
33
+
34
+ return image_features_spatial_pool.flatten(2).transpose(1, 2).contiguous()
35
+
36
+ @property
37
+ def config(self):
38
+ return {
39
+ "mm_resampler_type": "spatial_pool",
40
+ "mm_spatial_pool_stride": self.stride,
41
+ "mm_spatial_pool_mode": self.mode,
42
+ "mm_spatial_pool_out_channels": self.out_channels,
43
+ }
44
+
45
+ @property
46
+ def hidden_size(self):
47
+ return self.out_channels
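A toy forward pass (illustrative shapes): average pooling with stride 2 reduces a 24x24 patch grid (576 tokens per frame) to 12x12 (144 tokens); the stand-in vision tower only needs a hidden_size:

import torch
from types import SimpleNamespace
from llavavid.model.multimodal_resampler.spatial_pool import SpatialPool

args = SimpleNamespace(mm_spatial_pool_mode="average", mm_spatial_pool_stride=2)
pool = SpatialPool(args, vision_tower=SimpleNamespace(hidden_size=1024))

image_features = torch.randn(8, 576, 1024)  # 8 frames x 576 patch tokens x 1024 channels
images = torch.randn(8, 3, 336, 336)        # matching square frames, used only to infer the grid shape
print(pool(image_features, images).shape)   # torch.Size([8, 144, 1024])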
llavavid/model/utils.py ADDED
@@ -0,0 +1,20 @@
1
+ from transformers import AutoConfig
2
+
3
+
4
+ def auto_upgrade(config):
5
+ cfg = AutoConfig.from_pretrained(config)
6
+ if 'llava' in config and 'llava' not in cfg.model_type:
7
+ assert cfg.model_type == 'llama'
8
+ print("You are using newer LLaVA code base, while the checkpoint of v0 is from older code base.")
9
+ print("You must upgrade the checkpoint to the new code base (this can be done automatically).")
10
+ confirm = input("Please confirm that you want to upgrade the checkpoint. [Y/N]")
11
+ if confirm.lower() in ["y", "yes"]:
12
+ print("Upgrading checkpoint...")
13
+ assert len(cfg.architectures) == 1
14
+ setattr(cfg.__class__, "model_type", "llava")
15
+ cfg.architectures[0] = 'LlavaLlamaForCausalLM'
16
+ cfg.save_pretrained(config)
17
+ print("Checkpoint upgraded.")
18
+ else:
19
+ print("Checkpoint upgrade aborted.")
20
+ exit(1)
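Usage sketch (the path is a placeholder): auto_upgrade takes a checkpoint directory and, if its config still carries the legacy v0 "llama" model type, interactively rewrites the config in place:

from llavavid.model.utils import auto_upgrade

auto_upgrade("/path/to/old-llava-v0-checkpoint")  # no-op when the config already uses a llava_* model type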