zhaode committed on
Commit fd52dc4 · verified · 1 Parent(s): d26d2fa

Upload folder using huggingface_hub

This view is limited to 50 files because the commit contains too many changes. See the raw diff for the complete change set.
Files changed (50)
  1. .gitattributes +14 -11
  2. README.md +53 -0
  3. added_tokens.json +24 -0
  4. config.json +44 -0
  5. generation_config.json +6 -0
  6. llava/__init__.py +1 -0
  7. llava/__pycache__/__init__.cpython-312.pyc +0 -0
  8. llava/__pycache__/constants.cpython-312.pyc +0 -0
  9. llava/__pycache__/conversation.cpython-312.pyc +0 -0
  10. llava/__pycache__/mm_utils.cpython-312.pyc +0 -0
  11. llava/__pycache__/utils.cpython-312.pyc +0 -0
  12. llava/constants.py +13 -0
  13. llava/conversation.py +479 -0
  14. llava/mm_utils.py +250 -0
  15. llava/model/__init__.py +8 -0
  16. llava/model/__pycache__/__init__.cpython-312.pyc +0 -0
  17. llava/model/__pycache__/builder.cpython-312.pyc +0 -0
  18. llava/model/__pycache__/llava_arch.cpython-312.pyc +0 -0
  19. llava/model/apply_delta.py +48 -0
  20. llava/model/builder.py +181 -0
  21. llava/model/consolidate.py +29 -0
  22. llava/model/language_model/__pycache__/llava_llama.cpython-312.pyc +0 -0
  23. llava/model/language_model/__pycache__/llava_mistral.cpython-312.pyc +0 -0
  24. llava/model/language_model/__pycache__/llava_mpt.cpython-312.pyc +0 -0
  25. llava/model/language_model/__pycache__/llava_qwen.cpython-312.pyc +0 -0
  26. llava/model/language_model/llava_llama.py +159 -0
  27. llava/model/language_model/llava_mistral.py +158 -0
  28. llava/model/language_model/llava_mpt.py +97 -0
  29. llava/model/language_model/llava_qwen.py +160 -0
  30. llava/model/llava_arch.py +376 -0
  31. llava/model/make_delta.py +52 -0
  32. llava/model/multimodal_encoder/__pycache__/builder.cpython-312.pyc +0 -0
  33. llava/model/multimodal_encoder/__pycache__/clip_encoder.cpython-312.pyc +0 -0
  34. llava/model/multimodal_encoder/__pycache__/mobileclip_encoder.cpython-312.pyc +0 -0
  35. llava/model/multimodal_encoder/builder.py +19 -0
  36. llava/model/multimodal_encoder/clip_encoder.py +166 -0
  37. llava/model/multimodal_encoder/mobileclip/__init__.py +87 -0
  38. llava/model/multimodal_encoder/mobileclip/__pycache__/__init__.cpython-312.pyc +0 -0
  39. llava/model/multimodal_encoder/mobileclip/__pycache__/mci.cpython-312.pyc +0 -0
  40. llava/model/multimodal_encoder/mobileclip/configs/mobileclip_l.json +20 -0
  41. llava/model/multimodal_encoder/mobileclip/mci.py +1479 -0
  42. llava/model/multimodal_encoder/mobileclip_encoder.py +116 -0
  43. llava/model/multimodal_projector/__pycache__/builder.cpython-312.pyc +0 -0
  44. llava/model/multimodal_projector/builder.py +35 -0
  45. llava/model/utils.py +20 -0
  46. llava/serve/__init__.py +0 -0
  47. llava/serve/cli.py +126 -0
  48. llava/serve/controller.py +298 -0
  49. llava/serve/examples/extreme_ironing.jpg +0 -0
  50. llava/serve/examples/waterview.jpg +0 -0
.gitattributes CHANGED
@@ -1,35 +1,38 @@
  *.7z filter=lfs diff=lfs merge=lfs -text
  *.arrow filter=lfs diff=lfs merge=lfs -text
  *.bin filter=lfs diff=lfs merge=lfs -text
+ *.bin.* filter=lfs diff=lfs merge=lfs -text
  *.bz2 filter=lfs diff=lfs merge=lfs -text
- *.ckpt filter=lfs diff=lfs merge=lfs -text
  *.ftz filter=lfs diff=lfs merge=lfs -text
  *.gz filter=lfs diff=lfs merge=lfs -text
  *.h5 filter=lfs diff=lfs merge=lfs -text
  *.joblib filter=lfs diff=lfs merge=lfs -text
  *.lfs.* filter=lfs diff=lfs merge=lfs -text
- *.mlmodel filter=lfs diff=lfs merge=lfs -text
  *.model filter=lfs diff=lfs merge=lfs -text
  *.msgpack filter=lfs diff=lfs merge=lfs -text
- *.npy filter=lfs diff=lfs merge=lfs -text
- *.npz filter=lfs diff=lfs merge=lfs -text
  *.onnx filter=lfs diff=lfs merge=lfs -text
  *.ot filter=lfs diff=lfs merge=lfs -text
  *.parquet filter=lfs diff=lfs merge=lfs -text
  *.pb filter=lfs diff=lfs merge=lfs -text
- *.pickle filter=lfs diff=lfs merge=lfs -text
- *.pkl filter=lfs diff=lfs merge=lfs -text
  *.pt filter=lfs diff=lfs merge=lfs -text
  *.pth filter=lfs diff=lfs merge=lfs -text
  *.rar filter=lfs diff=lfs merge=lfs -text
- *.safetensors filter=lfs diff=lfs merge=lfs -text
  saved_model/**/* filter=lfs diff=lfs merge=lfs -text
  *.tar.* filter=lfs diff=lfs merge=lfs -text
- *.tar filter=lfs diff=lfs merge=lfs -text
  *.tflite filter=lfs diff=lfs merge=lfs -text
  *.tgz filter=lfs diff=lfs merge=lfs -text
- *.wasm filter=lfs diff=lfs merge=lfs -text
  *.xz filter=lfs diff=lfs merge=lfs -text
  *.zip filter=lfs diff=lfs merge=lfs -text
- *.zst filter=lfs diff=lfs merge=lfs -text
- *tfevents* filter=lfs diff=lfs merge=lfs -text
+ *.zstandard filter=lfs diff=lfs merge=lfs -text
+ *.tfevents* filter=lfs diff=lfs merge=lfs -text
+ *.db* filter=lfs diff=lfs merge=lfs -text
+ *.ark* filter=lfs diff=lfs merge=lfs -text
+ **/*ckpt*data* filter=lfs diff=lfs merge=lfs -text
+ **/*ckpt*.meta filter=lfs diff=lfs merge=lfs -text
+ **/*ckpt*.index filter=lfs diff=lfs merge=lfs -text
+ *.safetensors filter=lfs diff=lfs merge=lfs -text
+ *.ckpt filter=lfs diff=lfs merge=lfs -text
+ *.mnn filter=lfs diff=lfs merge=lfs -text
+ *.mnn.* filter=lfs diff=lfs merge=lfs -text
+ *.weight filter=lfs diff=lfs merge=lfs -text
+
README.md ADDED
@@ -0,0 +1,53 @@
+ ---
+ license: other
+ license_name: apple
+ license_link: https://github.com/apple/ml-fastvlm/blob/main/LICENSE
+ language:
+ - en
+ pipeline_tag: image-text-to-text
+ tags:
+ - multimodal
+ library_name: transformers
+ ---
+
+ # FastVLM-7B-Stage3
+
+ ## Introduction
+
+ This is FastVLM-7B-Stage3, a multimodal language model that understands visual content, can act agentically, follows long videos and captures their events, and generates structured outputs.
+
+ This model is exported from the GitHub repository [apple/ml-fastvlm](https://github.com/apple/ml-fastvlm).
+
+ Model weights: [llava-fastvithd_7b_stage3.zip](https://ml-site.cdn-apple.com/datasets/fastvlm/llava-fastvithd_7b_stage3.zip).
+
+
+ ### Usage
+ ```python
+ from transformers import AutoTokenizer, AutoModelForCausalLM
+
+ model_id = 'FastVLM-7B-Stage3'
+ tokenizer = AutoTokenizer.from_pretrained(model_id, trust_remote_code=True, use_fast=False)
+ model = AutoModelForCausalLM.from_pretrained(model_id, torch_dtype='auto', trust_remote_code=True)
+ ```
+
+ ### Export to MNN
+ ```bash
+ git clone https://github.com/alibaba/MNN
+ cd MNN/transformers/llm/export
+ python llmexport.py --path /path/to/FastVLM-7B-Stage3 --export mnn
+ ```
+
+
+ ## Citation
+
+ If you find our work helpful, please consider citing it:
+
+ ```
+ @InProceedings{fastvlm2025,
+   author    = {Pavan Kumar Anasosalu Vasu and Fartash Faghri and Chun-Liang Li and Cem Koc and Nate True and Albert Antony and Gokul Santhanam and James Gabriel and Peter Grasch and Oncel Tuzel and Hadi Pouransari},
+   title     = {FastVLM: Efficient Vision Encoding for Vision Language Models},
+   booktitle = {Proceedings of the IEEE/CVF Conference on Computer Vision and Pattern Recognition (CVPR)},
+   month     = {June},
+   year      = {2025},
+ }
+ ```
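
Below is a fuller inference sketch that ties the README's Usage snippet to the `llava` package added in this commit. It is a minimal, untested example: the checkpoint path, the example image, the `qwen_2` template choice, and the assumption that `LlavaQwen2ForCausalLM.generate` mirrors the `llava_llama` signature shown later in this diff are assumptions, not part of the upload.

```python
import torch
from PIL import Image

from llava.constants import IMAGE_TOKEN_INDEX, DEFAULT_IMAGE_TOKEN
from llava.conversation import conv_templates
from llava.mm_utils import get_model_name_from_path, process_images, tokenizer_image_token
from llava.model.builder import load_pretrained_model

# Assumption: a local copy of the checkpoint whose folder name contains "llava",
# e.g. the unzipped llava-fastvithd_7b_stage3.zip; builder.py keys off that substring.
model_path = "llava-fastvithd_7b_stage3"
tokenizer, model, image_processor, context_len = load_pretrained_model(
    model_path, model_base=None, model_name=get_model_name_from_path(model_path))

# Build a ChatML-style prompt with the default qwen_2 conversation template.
conv = conv_templates["qwen_2"].copy()
conv.append_message(conv.roles[0], DEFAULT_IMAGE_TOKEN + "\nDescribe this image.")
conv.append_message(conv.roles[1], None)
prompt = conv.get_prompt()

image = Image.open("llava/serve/examples/waterview.jpg").convert("RGB")  # bundled example image
image_tensor = process_images([image], image_processor, model.config).to(model.device, dtype=torch.float16)
input_ids = tokenizer_image_token(prompt, tokenizer, IMAGE_TOKEN_INDEX, return_tensors="pt").unsqueeze(0).to(model.device)

with torch.inference_mode():
    output_ids = model.generate(input_ids, images=image_tensor, image_sizes=[image.size], max_new_tokens=256)
print(tokenizer.batch_decode(output_ids, skip_special_tokens=True)[0])
```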
added_tokens.json ADDED
@@ -0,0 +1,24 @@
+ {
+   "</tool_call>": 151658,
+   "<tool_call>": 151657,
+   "<|box_end|>": 151649,
+   "<|box_start|>": 151648,
+   "<|endoftext|>": 151643,
+   "<|file_sep|>": 151664,
+   "<|fim_middle|>": 151660,
+   "<|fim_pad|>": 151662,
+   "<|fim_prefix|>": 151659,
+   "<|fim_suffix|>": 151661,
+   "<|im_end|>": 151645,
+   "<|im_start|>": 151644,
+   "<|image_pad|>": 151655,
+   "<|object_ref_end|>": 151647,
+   "<|object_ref_start|>": 151646,
+   "<|quad_end|>": 151651,
+   "<|quad_start|>": 151650,
+   "<|repo_name|>": 151663,
+   "<|video_pad|>": 151656,
+   "<|vision_end|>": 151653,
+   "<|vision_pad|>": 151654,
+   "<|vision_start|>": 151652
+ }
config.json ADDED
@@ -0,0 +1,44 @@
+ {
+   "architectures": [
+     "LlavaQwen2ForCausalLM"
+   ],
+   "attention_dropout": 0.0,
+   "bos_token_id": 151643,
+   "eos_token_id": 151645,
+   "freeze_mm_mlp_adapter": false,
+   "hidden_act": "silu",
+   "hidden_size": 3584,
+   "image_aspect_ratio": "pad",
+   "image_grid_pinpoints": null,
+   "initializer_range": 0.02,
+   "intermediate_size": 18944,
+   "max_position_embeddings": 32768,
+   "max_window_layers": 28,
+   "mm_hidden_size": 3072,
+   "mm_patch_merge_type": "flat",
+   "mm_projector_lr": null,
+   "mm_projector_type": "mlp2x_gelu",
+   "mm_use_im_patch_token": false,
+   "mm_use_im_start_end": false,
+   "mm_vision_select_feature": "patch",
+   "mm_vision_select_layer": -2,
+   "mm_vision_tower": "mobileclip_l_1024",
+   "model_type": "llava_qwen2",
+   "num_attention_heads": 28,
+   "num_hidden_layers": 28,
+   "num_key_value_heads": 4,
+   "rms_norm_eps": 1e-06,
+   "rope_theta": 1000000.0,
+   "sliding_window": 131072,
+   "tie_word_embeddings": false,
+   "tokenizer_model_max_length": 8192,
+   "tokenizer_padding_side": "right",
+   "torch_dtype": "bfloat16",
+   "transformers_version": "4.39.3",
+   "tune_mm_mlp_adapter": false,
+   "unfreeze_mm_vision_tower": true,
+   "use_cache": true,
+   "use_mm_proj": true,
+   "use_sliding_window": false,
+   "vocab_size": 152064
+ }
generation_config.json ADDED
@@ -0,0 +1,6 @@
+ {
+   "do_sample": true,
+   "temperature": null,
+   "top_p": null,
+   "transformers_version": "4.39.3"
+ }
llava/__init__.py ADDED
@@ -0,0 +1 @@
+ from .model import LlavaLlamaForCausalLM, LlavaQwen2ForCausalLM
llava/__pycache__/__init__.cpython-312.pyc ADDED
Binary file (256 Bytes). View file
 
llava/__pycache__/constants.cpython-312.pyc ADDED
Binary file (568 Bytes). View file
 
llava/__pycache__/conversation.cpython-312.pyc ADDED
Binary file (17.1 kB). View file
 
llava/__pycache__/mm_utils.cpython-312.pyc ADDED
Binary file (13 kB). View file
 
llava/__pycache__/utils.cpython-312.pyc ADDED
Binary file (6.64 kB). View file
 
llava/constants.py ADDED
@@ -0,0 +1,13 @@
+ CONTROLLER_HEART_BEAT_EXPIRATION = 30
+ WORKER_HEART_BEAT_INTERVAL = 15
+
+ LOGDIR = "."
+
+ # Model Constants
+ IGNORE_INDEX = -100
+ IMAGE_TOKEN_INDEX = -200
+ DEFAULT_IMAGE_TOKEN = "<image>"
+ DEFAULT_IMAGE_PATCH_TOKEN = "<im_patch>"
+ DEFAULT_IM_START_TOKEN = "<im_start>"
+ DEFAULT_IM_END_TOKEN = "<im_end>"
+ IMAGE_PLACEHOLDER = "<image-placeholder>"
llava/conversation.py ADDED
@@ -0,0 +1,479 @@
1
+ import dataclasses
2
+ from enum import auto, Enum
3
+ from typing import List, Tuple
4
+ import base64
5
+ from io import BytesIO
6
+ from PIL import Image
7
+
8
+
9
+ class SeparatorStyle(Enum):
10
+ """Different separator style."""
11
+ SINGLE = auto()
12
+ TWO = auto()
13
+ MPT = auto()
14
+ PLAIN = auto()
15
+ LLAMA_2 = auto()
16
+ QWEN_2 = auto() # fix: add qwen2
17
+ CHATML = auto()
18
+
19
+
20
+ @dataclasses.dataclass
21
+ class Conversation:
22
+ """A class that keeps all conversation history."""
23
+ system: str
24
+ roles: List[str]
25
+ messages: List[List[str]]
26
+ offset: int
27
+ sep_style: SeparatorStyle = SeparatorStyle.SINGLE
28
+ sep: str = "###"
29
+ sep2: str = None
30
+ version: str = "Unknown"
31
+
32
+ skip_next: bool = False
33
+
34
+ def get_prompt(self):
35
+ messages = self.messages
36
+ if len(messages) > 0 and type(messages[0][1]) is tuple:
37
+ messages = self.messages.copy()
38
+ init_role, init_msg = messages[0].copy()
39
+ init_msg = init_msg[0].replace("<image>", "").strip()
40
+ if 'mmtag' in self.version:
41
+ messages[0] = (init_role, init_msg)
42
+ messages.insert(0, (self.roles[0], "<Image><image></Image>"))
43
+ messages.insert(1, (self.roles[1], "Received."))
44
+ else:
45
+ messages[0] = (init_role, "<image>\n" + init_msg)
46
+
47
+ if self.sep_style == SeparatorStyle.SINGLE:
48
+ ret = self.system + self.sep
49
+ for role, message in messages:
50
+ if message:
51
+ if type(message) is tuple:
52
+ message, _, _ = message
53
+ ret += role + ": " + message + self.sep
54
+ else:
55
+ ret += role + ":"
56
+ # elif self.sep_style == SeparatorStyle.QWEN_2: # fix: add qwen2
57
+ # seps = [self.sep, self.sep2]
58
+ # ret = self.system + seps[0]
59
+ # ret = ""
60
+ # for i, (role, message) in enumerate(messages):
61
+ # if message:
62
+ # if type(message) is tuple:
63
+ # message, _, _ = message
64
+ # ret += role + ": " + message + seps[i % 2]
65
+ # else:
66
+ # ret += role + ":"
67
+ elif self.sep_style == SeparatorStyle.QWEN_2: # fix: add qwen2
68
+ ret = self.system + self.sep
69
+ for i, (role, message) in enumerate(messages):
70
+ if message:
71
+ if type(message) is tuple:
72
+ message, _, _ = message
73
+ ret += role + message + self.sep
74
+ else:
75
+ ret += role
76
+ elif self.sep_style == SeparatorStyle.CHATML:
77
+ ret = "" if self.system == "" else self.system + self.sep + "\n"
78
+ for role, message in messages:
79
+ if message:
80
+ if type(message) is tuple:
81
+ message, images = message
82
+ message = "<image>" * len(images) + message
83
+ ret += role + "\n" + message + self.sep + "\n"
84
+ else:
85
+ ret += role + "\n"
86
+ return ret
87
+ elif self.sep_style == SeparatorStyle.TWO:
88
+ seps = [self.sep, self.sep2]
89
+ ret = self.system + seps[0]
90
+ for i, (role, message) in enumerate(messages):
91
+ if message:
92
+ if type(message) is tuple:
93
+ message, _, _ = message
94
+ ret += role + ": " + message + seps[i % 2]
95
+ else:
96
+ ret += role + ":"
97
+ elif self.sep_style == SeparatorStyle.MPT:
98
+ ret = self.system + self.sep
99
+ for role, message in messages:
100
+ if message:
101
+ if type(message) is tuple:
102
+ message, _, _ = message
103
+ ret += role + message + self.sep
104
+ else:
105
+ ret += role
106
+ elif self.sep_style == SeparatorStyle.LLAMA_2:
107
+ def wrap_sys(msg): return f"<<SYS>>\n{msg}\n<</SYS>>\n\n" if len(msg) > 0 else msg
108
+ def wrap_inst(msg): return f"[INST] {msg} [/INST]"
109
+ ret = ""
110
+
111
+ for i, (role, message) in enumerate(messages):
112
+ if i == 0:
113
+ assert message, "first message should not be none"
114
+ assert role == self.roles[0], "first message should come from user"
115
+ if message:
116
+ if type(message) is tuple:
117
+ message, _, _ = message
118
+ if i == 0:
119
+ message = wrap_sys(self.system) + message
120
+ if i % 2 == 0:
121
+ message = wrap_inst(message)
122
+ ret += self.sep + message
123
+ else:
124
+ ret += " " + message + " " + self.sep2
125
+ else:
126
+ ret += ""
127
+ ret = ret.lstrip(self.sep)
128
+ elif self.sep_style == SeparatorStyle.PLAIN:
129
+ seps = [self.sep, self.sep2]
130
+ ret = self.system
131
+ for i, (role, message) in enumerate(messages):
132
+ if message:
133
+ if type(message) is tuple:
134
+ message, _, _ = message
135
+ ret += message + seps[i % 2]
136
+ else:
137
+ ret += ""
138
+ else:
139
+ raise ValueError(f"Invalid style: {self.sep_style}")
140
+
141
+ return ret
142
+
143
+ def append_message(self, role, message):
144
+ self.messages.append([role, message])
145
+
146
+ def process_image(self, image, image_process_mode, return_pil=False, image_format='PNG', max_len=1344, min_len=672):
147
+ if image_process_mode == "Pad":
148
+ def expand2square(pil_img, background_color=(122, 116, 104)):
149
+ width, height = pil_img.size
150
+ if width == height:
151
+ return pil_img
152
+ elif width > height:
153
+ result = Image.new(pil_img.mode, (width, width), background_color)
154
+ result.paste(pil_img, (0, (width - height) // 2))
155
+ return result
156
+ else:
157
+ result = Image.new(pil_img.mode, (height, height), background_color)
158
+ result.paste(pil_img, ((height - width) // 2, 0))
159
+ return result
160
+ image = expand2square(image)
161
+ elif image_process_mode in ["Default", "Crop"]:
162
+ pass
163
+ elif image_process_mode == "Resize":
164
+ image = image.resize((336, 336))
165
+ else:
166
+ raise ValueError(f"Invalid image_process_mode: {image_process_mode}")
167
+ if max(image.size) > max_len:
168
+ max_hw, min_hw = max(image.size), min(image.size)
169
+ aspect_ratio = max_hw / min_hw
170
+ shortest_edge = int(min(max_len / aspect_ratio, min_len, min_hw))
171
+ longest_edge = int(shortest_edge * aspect_ratio)
172
+ W, H = image.size
173
+ if H > W:
174
+ H, W = longest_edge, shortest_edge
175
+ else:
176
+ H, W = shortest_edge, longest_edge
177
+ image = image.resize((W, H))
178
+ if return_pil:
179
+ return image
180
+ else:
181
+ buffered = BytesIO()
182
+ image.save(buffered, format=image_format)
183
+ img_b64_str = base64.b64encode(buffered.getvalue()).decode()
184
+ return img_b64_str
185
+
186
+ def get_images(self, return_pil=False):
187
+ images = []
188
+ for i, (role, msg) in enumerate(self.messages[self.offset:]):
189
+ if i % 2 == 0:
190
+ if type(msg) is tuple:
191
+ msg, image, image_process_mode = msg
192
+ image = self.process_image(image, image_process_mode, return_pil=return_pil)
193
+ images.append(image)
194
+ return images
195
+
196
+ def to_gradio_chatbot(self):
197
+ ret = []
198
+ for i, (role, msg) in enumerate(self.messages[self.offset:]):
199
+ if i % 2 == 0:
200
+ if type(msg) is tuple:
201
+ msg, image, image_process_mode = msg
202
+ img_b64_str = self.process_image(
203
+ image, "Default", return_pil=False,
204
+ image_format='JPEG')
205
+ img_str = f'<img src="data:image/jpeg;base64,{img_b64_str}" alt="user upload image" />'
206
+ msg = img_str + msg.replace('<image>', '').strip()
207
+ ret.append([msg, None])
208
+ else:
209
+ ret.append([msg, None])
210
+ else:
211
+ ret[-1][-1] = msg
212
+ return ret
213
+
214
+ def copy(self):
215
+ return Conversation(
216
+ system=self.system,
217
+ roles=self.roles,
218
+ messages=[[x, y] for x, y in self.messages],
219
+ offset=self.offset,
220
+ sep_style=self.sep_style,
221
+ sep=self.sep,
222
+ sep2=self.sep2,
223
+ version=self.version)
224
+
225
+ def dict(self):
226
+ if len(self.get_images()) > 0:
227
+ return {
228
+ "system": self.system,
229
+ "roles": self.roles,
230
+ "messages": [[x, y[0] if type(y) is tuple else y] for x, y in self.messages],
231
+ "offset": self.offset,
232
+ "sep": self.sep,
233
+ "sep2": self.sep2,
234
+ }
235
+ return {
236
+ "system": self.system,
237
+ "roles": self.roles,
238
+ "messages": self.messages,
239
+ "offset": self.offset,
240
+ "sep": self.sep,
241
+ "sep2": self.sep2,
242
+ }
243
+
244
+
245
+ conv_vicuna_v0 = Conversation(
246
+ system="A chat between a curious human and an artificial intelligence assistant. "
247
+ "The assistant gives helpful, detailed, and polite answers to the human's questions.",
248
+ roles=("Human", "Assistant"),
249
+ messages=(
250
+ ("Human", "What are the key differences between renewable and non-renewable energy sources?"),
251
+ ("Assistant",
252
+ "Renewable energy sources are those that can be replenished naturally in a relatively "
253
+ "short amount of time, such as solar, wind, hydro, geothermal, and biomass. "
254
+ "Non-renewable energy sources, on the other hand, are finite and will eventually be "
255
+ "depleted, such as coal, oil, and natural gas. Here are some key differences between "
256
+ "renewable and non-renewable energy sources:\n"
257
+ "1. Availability: Renewable energy sources are virtually inexhaustible, while non-renewable "
258
+ "energy sources are finite and will eventually run out.\n"
259
+ "2. Environmental impact: Renewable energy sources have a much lower environmental impact "
260
+ "than non-renewable sources, which can lead to air and water pollution, greenhouse gas emissions, "
261
+ "and other negative effects.\n"
262
+ "3. Cost: Renewable energy sources can be more expensive to initially set up, but they typically "
263
+ "have lower operational costs than non-renewable sources.\n"
264
+ "4. Reliability: Renewable energy sources are often more reliable and can be used in more remote "
265
+ "locations than non-renewable sources.\n"
266
+ "5. Flexibility: Renewable energy sources are often more flexible and can be adapted to different "
267
+ "situations and needs, while non-renewable sources are more rigid and inflexible.\n"
268
+ "6. Sustainability: Renewable energy sources are more sustainable over the long term, while "
269
+ "non-renewable sources are not, and their depletion can lead to economic and social instability.\n")
270
+ ),
271
+ offset=2,
272
+ sep_style=SeparatorStyle.SINGLE,
273
+ sep="###",
274
+ )
275
+
276
+ conv_vicuna_v1 = Conversation(
277
+ system="A chat between a curious user and an artificial intelligence assistant. "
278
+ "The assistant gives helpful, detailed, and polite answers to the user's questions.",
279
+ roles=("USER", "ASSISTANT"),
280
+ version="v1",
281
+ messages=(),
282
+ offset=0,
283
+ sep_style=SeparatorStyle.TWO,
284
+ sep=" ",
285
+ sep2="</s>",
286
+ )
287
+
288
+ conv_llama_2 = Conversation(
289
+ system="""You are a helpful, respectful and honest assistant. Always answer as helpfully as possible, while being safe. Your answers should not include any harmful, unethical, racist, sexist, toxic, dangerous, or illegal content. Please ensure that your responses are socially unbiased and positive in nature.
290
+
291
+ If a question does not make any sense, or is not factually coherent, explain why instead of answering something not correct. If you don't know the answer to a question, please don't share false information.""",
292
+ roles=("USER", "ASSISTANT"),
293
+ version="llama_v2",
294
+ messages=(),
295
+ offset=0,
296
+ sep_style=SeparatorStyle.LLAMA_2,
297
+ sep="<s>",
298
+ sep2="</s>",
299
+ )
300
+
301
+ conv_llava_llama_2 = Conversation(
302
+ system="You are a helpful language and vision assistant. "
303
+ "You are able to understand the visual content that the user provides, "
304
+ "and assist the user with a variety of tasks using natural language.",
305
+ roles=("USER", "ASSISTANT"),
306
+ version="llama_v2",
307
+ messages=(),
308
+ offset=0,
309
+ sep_style=SeparatorStyle.LLAMA_2,
310
+ sep="<s>",
311
+ sep2="</s>",
312
+ )
313
+
314
+ conv_mpt = Conversation(
315
+ system="""<|im_start|>system
316
+ A conversation between a user and an LLM-based AI assistant. The assistant gives helpful and honest answers.""",
317
+ roles=("<|im_start|>user\n", "<|im_start|>assistant\n"),
318
+ version="mpt",
319
+ messages=(),
320
+ offset=0,
321
+ sep_style=SeparatorStyle.MPT,
322
+ sep="<|im_end|>",
323
+ )
324
+
325
+ conv_llava_plain = Conversation(
326
+ system="",
327
+ roles=("", ""),
328
+ messages=(
329
+ ),
330
+ offset=0,
331
+ sep_style=SeparatorStyle.PLAIN,
332
+ sep="\n",
333
+ )
334
+
335
+ conv_llava_v0 = Conversation(
336
+ system="A chat between a curious human and an artificial intelligence assistant. "
337
+ "The assistant gives helpful, detailed, and polite answers to the human's questions.",
338
+ roles=("Human", "Assistant"),
339
+ messages=(
340
+ ),
341
+ offset=0,
342
+ sep_style=SeparatorStyle.SINGLE,
343
+ sep="###",
344
+ )
345
+
346
+ conv_llava_v0_mmtag = Conversation(
347
+ system="A chat between a curious user and an artificial intelligence assistant. "
348
+ "The assistant is able to understand the visual content that the user provides, and assist the user with a variety of tasks using natural language."
349
+ "The visual content will be provided with the following format: <Image>visual content</Image>.",
350
+ roles=("Human", "Assistant"),
351
+ messages=(
352
+ ),
353
+ offset=0,
354
+ sep_style=SeparatorStyle.SINGLE,
355
+ sep="###",
356
+ version="v0_mmtag",
357
+ )
358
+
359
+ conv_llava_v1 = Conversation(
360
+ system="A chat between a curious human and an artificial intelligence assistant. "
361
+ "The assistant gives helpful, detailed, and polite answers to the human's questions.",
362
+ roles=("USER", "ASSISTANT"),
363
+ version="v1",
364
+ messages=(),
365
+ offset=0,
366
+ sep_style=SeparatorStyle.TWO,
367
+ sep=" ",
368
+ sep2="</s>",
369
+ )
370
+
371
+ conv_llava_v1_mmtag = Conversation(
372
+ system="A chat between a curious user and an artificial intelligence assistant. "
373
+ "The assistant is able to understand the visual content that the user provides, and assist the user with a variety of tasks using natural language."
374
+ "The visual content will be provided with the following format: <Image>visual content</Image>.",
375
+ roles=("USER", "ASSISTANT"),
376
+ messages=(),
377
+ offset=0,
378
+ sep_style=SeparatorStyle.TWO,
379
+ sep=" ",
380
+ sep2="</s>",
381
+ version="v1_mmtag",
382
+ )
383
+
384
+ conv_mistral_instruct = Conversation(
385
+ system="",
386
+ roles=("USER", "ASSISTANT"),
387
+ version="llama_v2",
388
+ messages=(),
389
+ offset=0,
390
+ sep_style=SeparatorStyle.LLAMA_2,
391
+ sep="",
392
+ sep2="</s>",
393
+ )
394
+
395
+ conv_chatml_direct = Conversation(
396
+ system="""<|im_start|>system
397
+ Answer the questions.""",
398
+ roles=("<|im_start|>user\n", "<|im_start|>assistant\n"),
399
+ version="mpt",
400
+ messages=(),
401
+ offset=0,
402
+ sep_style=SeparatorStyle.MPT,
403
+ sep="<|im_end|>",
404
+ )
405
+
406
+
407
+ conv_qwen_2 = Conversation(
408
+ system="<|im_start|>system\nYou are a helpful assistant.",
409
+ roles=("<|im_start|>user\n", "<|im_start|>assistant\n"),
410
+ version="qwen_v2",
411
+ messages=(),
412
+ offset=0,
413
+ sep_style=SeparatorStyle.QWEN_2,
414
+ sep="<|im_end|>\n",
415
+ )
416
+
417
+
418
+ # conv_qwen_2 = Conversation(
419
+ # system="",
420
+ # roles=("user", "assistant"),
421
+ # version="qwen_v2",
422
+ # messages=(),
423
+ # offset=0,
424
+ # sep_style=SeparatorStyle.QWEN_2,
425
+ # sep=" ",
426
+ # sep2="<|im_end|>",
427
+ # )
428
+
429
+
430
+ # fix: add qwen2
431
+ # conv_qwen_2 = Conversation(
432
+ # system="A chat between a curious user and an artificial intelligence assistant. "
433
+ # "The assistant gives helpful, detailed, and polite answers to the user's questions.",
434
+ # roles=("USER", "ASSISTANT"),
435
+ # version="qwen_v2",
436
+ # messages=(),
437
+ # offset=0,
438
+ # sep_style=SeparatorStyle.QWEN_2,
439
+ # sep=" ",
440
+ # sep2="<|endoftext|>",
441
+ # )
442
+
443
+ # conv_qwen_2 = Conversation(
444
+ # system="""<|im_start|>system
445
+ # You are a helpful assistant.""",
446
+ # roles=("<|im_start|>user", "<|im_start|>assistant"),
447
+ # version="qwen_v2",
448
+ # messages=[],
449
+ # offset=0,
450
+ # sep_style=SeparatorStyle.QWEN_2,
451
+ # sep="<|im_end|>",
452
+ # sep2="<|im_end|>",
453
+ # )
454
+
455
+ default_conversation = conv_qwen_2
456
+ conv_templates = {
457
+ "default": conv_qwen_2,
458
+ "v0": conv_vicuna_v0,
459
+ "v1": conv_vicuna_v1,
460
+ "vicuna_v1": conv_vicuna_v1,
461
+ "qwen_2": conv_qwen_2,
462
+ "llama_2": conv_llama_2,
463
+ "mistral_instruct": conv_mistral_instruct,
464
+ "chatml_direct": conv_chatml_direct,
465
+ "mistral_direct": conv_chatml_direct,
466
+
467
+ "plain": conv_llava_plain,
468
+ "v0_plain": conv_llava_plain,
469
+ "llava_v0": conv_llava_v0,
470
+ "v0_mmtag": conv_llava_v0_mmtag,
471
+ "llava_v1": conv_llava_v1,
472
+ "v1_mmtag": conv_llava_v1_mmtag,
473
+ "llava_llama_2": conv_llava_llama_2,
474
+
475
+ "mpt": conv_mpt,
476
+ }
477
+
478
+ if __name__ == "__main__":
479
+ print("conversation:", default_conversation.get_prompt())
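
For reference, a short sketch (an illustration, not part of the commit) of what the default `qwen_2` template renders to, following the `QWEN_2` branch of `get_prompt` above:

```python
from llava.conversation import conv_templates

conv = conv_templates["qwen_2"].copy()
conv.append_message(conv.roles[0], "<image>\nDescribe this image.")
conv.append_message(conv.roles[1], None)  # None leaves the assistant turn open for generation
print(conv.get_prompt())
# <|im_start|>system
# You are a helpful assistant.<|im_end|>
# <|im_start|>user
# <image>
# Describe this image.<|im_end|>
# <|im_start|>assistant
```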
llava/mm_utils.py ADDED
@@ -0,0 +1,250 @@
1
+ import PIL
2
+ from PIL import Image
3
+ PIL.Image.MAX_IMAGE_PIXELS=500000000
4
+ from io import BytesIO
5
+ import base64
6
+ import torch
7
+ import math
8
+ import ast
9
+
10
+ from transformers import StoppingCriteria
11
+ from llava.constants import IMAGE_TOKEN_INDEX
12
+
13
+
14
+ def select_best_resolution(original_size, possible_resolutions):
15
+ """
16
+ Selects the best resolution from a list of possible resolutions based on the original size.
17
+
18
+ Args:
19
+ original_size (tuple): The original size of the image in the format (width, height).
20
+ possible_resolutions (list): A list of possible resolutions in the format [(width1, height1), (width2, height2), ...].
21
+
22
+ Returns:
23
+ tuple: The best fit resolution in the format (width, height).
24
+ """
25
+ original_width, original_height = original_size
26
+ best_fit = None
27
+ max_effective_resolution = 0
28
+ min_wasted_resolution = float('inf')
29
+
30
+ for width, height in possible_resolutions:
31
+ scale = min(width / original_width, height / original_height)
32
+ downscaled_width, downscaled_height = int(original_width * scale), int(original_height * scale)
33
+ effective_resolution = min(downscaled_width * downscaled_height, original_width * original_height)
34
+ wasted_resolution = (width * height) - effective_resolution
35
+
36
+ if effective_resolution > max_effective_resolution or (effective_resolution == max_effective_resolution and wasted_resolution < min_wasted_resolution):
37
+ max_effective_resolution = effective_resolution
38
+ min_wasted_resolution = wasted_resolution
39
+ best_fit = (width, height)
40
+
41
+ return best_fit
42
+
43
+
44
+ def resize_and_pad_image(image, target_resolution):
45
+ """
46
+ Resize and pad an image to a target resolution while maintaining aspect ratio.
47
+
48
+ Args:
49
+ image (PIL.Image.Image): The input image.
50
+ target_resolution (tuple): The target resolution (width, height) of the image.
51
+
52
+ Returns:
53
+ PIL.Image.Image: The resized and padded image.
54
+ """
55
+ original_width, original_height = image.size
56
+ target_width, target_height = target_resolution
57
+
58
+ scale_w = target_width / original_width
59
+ scale_h = target_height / original_height
60
+
61
+ if scale_w < scale_h:
62
+ new_width = target_width
63
+ new_height = min(math.ceil(original_height * scale_w), target_height)
64
+ else:
65
+ new_height = target_height
66
+ new_width = min(math.ceil(original_width * scale_h), target_width)
67
+
68
+ # Resize the image
69
+ resized_image = image.resize((new_width, new_height))
70
+
71
+ new_image = Image.new('RGB', (target_width, target_height), (0, 0, 0))
72
+ paste_x = (target_width - new_width) // 2
73
+ paste_y = (target_height - new_height) // 2
74
+ new_image.paste(resized_image, (paste_x, paste_y))
75
+
76
+ return new_image
77
+
78
+
79
+ def divide_to_patches(image, patch_size):
80
+ """
81
+ Divides an image into patches of a specified size.
82
+
83
+ Args:
84
+ image (PIL.Image.Image): The input image.
85
+ patch_size (int): The size of each patch.
86
+
87
+ Returns:
88
+ list: A list of PIL.Image.Image objects representing the patches.
89
+ """
90
+ patches = []
91
+ width, height = image.size
92
+ for i in range(0, height, patch_size):
93
+ for j in range(0, width, patch_size):
94
+ box = (j, i, j + patch_size, i + patch_size)
95
+ patch = image.crop(box)
96
+ patches.append(patch)
97
+
98
+ return patches
99
+
100
+
101
+ def get_anyres_image_grid_shape(image_size, grid_pinpoints, patch_size):
102
+ """
103
+ Calculate the shape of the image patch grid after the preprocessing for images of any resolution.
104
+
105
+ Args:
106
+ image_size (tuple): The size of the input image in the format (width, height).
107
+ grid_pinpoints (str): A string representation of a list of possible resolutions.
108
+ patch_size (int): The size of each image patch.
109
+
110
+ Returns:
111
+ tuple: The shape of the image patch grid in the format (width, height).
112
+ """
113
+ if type(grid_pinpoints) is list:
114
+ possible_resolutions = grid_pinpoints
115
+ else:
116
+ possible_resolutions = ast.literal_eval(grid_pinpoints)
117
+ width, height = select_best_resolution(image_size, possible_resolutions)
118
+ return width // patch_size, height // patch_size
119
+
120
+
121
+ def process_anyres_image(image, processor, grid_pinpoints):
122
+ """
123
+ Process an image with variable resolutions.
124
+
125
+ Args:
126
+ image (PIL.Image.Image): The input image to be processed.
127
+ processor: The image processor object.
128
+ grid_pinpoints (str): A string representation of a list of possible resolutions.
129
+
130
+ Returns:
131
+ torch.Tensor: A tensor containing the processed image patches.
132
+ """
133
+ if type(grid_pinpoints) is list:
134
+ possible_resolutions = grid_pinpoints
135
+ else:
136
+ possible_resolutions = ast.literal_eval(grid_pinpoints)
137
+ best_resolution = select_best_resolution(image.size, possible_resolutions)
138
+ image_padded = resize_and_pad_image(image, best_resolution)
139
+
140
+ patches = divide_to_patches(image_padded, processor.crop_size['height'])
141
+
142
+ image_original_resize = image.resize((processor.size['shortest_edge'], processor.size['shortest_edge']))
143
+
144
+ image_patches = [image_original_resize] + patches
145
+ image_patches = [processor.preprocess(image_patch, return_tensors='pt')['pixel_values'][0]
146
+ for image_patch in image_patches]
147
+ return torch.stack(image_patches, dim=0)
148
+
149
+
150
+ def load_image_from_base64(image):
151
+ return Image.open(BytesIO(base64.b64decode(image)))
152
+
153
+
154
+ def expand2square(pil_img, background_color):
155
+ width, height = pil_img.size
156
+ if width == height:
157
+ return pil_img
158
+ elif width > height:
159
+ result = Image.new(pil_img.mode, (width, width), background_color)
160
+ result.paste(pil_img, (0, (width - height) // 2))
161
+ return result
162
+ else:
163
+ result = Image.new(pil_img.mode, (height, height), background_color)
164
+ result.paste(pil_img, ((height - width) // 2, 0))
165
+ return result
166
+
167
+
168
+ def process_images(images, image_processor, model_cfg):
169
+ image_aspect_ratio = getattr(model_cfg, "image_aspect_ratio", None)
170
+ new_images = []
171
+ if image_aspect_ratio == 'pad':
172
+ for image in images:
173
+ image = expand2square(image, tuple(int(x*255) for x in image_processor.image_mean))
174
+ image = image_processor.preprocess(image, return_tensors='pt')['pixel_values'][0]
175
+ new_images.append(image)
176
+ elif image_aspect_ratio == "anyres":
177
+ for image in images:
178
+ image = process_anyres_image(image, image_processor, model_cfg.image_grid_pinpoints)
179
+ new_images.append(image)
180
+ else:
181
+ return image_processor(images, return_tensors='pt')['pixel_values']
182
+ if all(x.shape == new_images[0].shape for x in new_images):
183
+ new_images = torch.stack(new_images, dim=0)
184
+ return new_images
185
+
186
+
187
+ def tokenizer_image_token(prompt, tokenizer, image_token_index=IMAGE_TOKEN_INDEX, return_tensors=None):
188
+ prompt_chunks = [tokenizer(chunk).input_ids for chunk in prompt.split('<image>')]
189
+
190
+ def insert_separator(X, sep):
191
+ return [ele for sublist in zip(X, [sep]*len(X)) for ele in sublist][:-1]
192
+
193
+ input_ids = []
194
+ offset = 0
195
+ if len(prompt_chunks) > 0 and len(prompt_chunks[0]) > 0 and prompt_chunks[0][0] == tokenizer.bos_token_id:
196
+ offset = 1
197
+ input_ids.append(prompt_chunks[0][0])
198
+
199
+ for x in insert_separator(prompt_chunks, [image_token_index] * (offset + 1)):
200
+ input_ids.extend(x[offset:])
201
+
202
+ if return_tensors is not None:
203
+ if return_tensors == 'pt':
204
+ return torch.tensor(input_ids, dtype=torch.long)
205
+ raise ValueError(f'Unsupported tensor type: {return_tensors}')
206
+ return input_ids
207
+
208
+
209
+ def get_model_name_from_path(model_path):
210
+ model_path = model_path.strip("/")
211
+ model_paths = model_path.split("/")
212
+ if model_paths[-1].startswith('checkpoint-'):
213
+ return model_paths[-2] + "_" + model_paths[-1]
214
+ else:
215
+ return model_paths[-1]
216
+
217
+
218
+ class KeywordsStoppingCriteria(StoppingCriteria):
219
+ def __init__(self, keywords, tokenizer, input_ids):
220
+ self.keywords = keywords
221
+ self.keyword_ids = []
222
+ self.max_keyword_len = 0
223
+ for keyword in keywords:
224
+ cur_keyword_ids = tokenizer(keyword).input_ids
225
+ if len(cur_keyword_ids) > 1 and cur_keyword_ids[0] == tokenizer.bos_token_id:
226
+ cur_keyword_ids = cur_keyword_ids[1:]
227
+ if len(cur_keyword_ids) > self.max_keyword_len:
228
+ self.max_keyword_len = len(cur_keyword_ids)
229
+ self.keyword_ids.append(torch.tensor(cur_keyword_ids))
230
+ self.tokenizer = tokenizer
231
+ self.start_len = input_ids.shape[1]
232
+
233
+ def call_for_batch(self, output_ids: torch.LongTensor, scores: torch.FloatTensor, **kwargs) -> bool:
234
+ offset = min(output_ids.shape[1] - self.start_len, self.max_keyword_len)
235
+ self.keyword_ids = [keyword_id.to(output_ids.device) for keyword_id in self.keyword_ids]
236
+ for keyword_id in self.keyword_ids:
237
+ truncated_output_ids = output_ids[0, -keyword_id.shape[0]:]
238
+ if torch.equal(truncated_output_ids, keyword_id):
239
+ return True
240
+ outputs = self.tokenizer.batch_decode(output_ids[:, -offset:], skip_special_tokens=True)[0]
241
+ for keyword in self.keywords:
242
+ if keyword in outputs:
243
+ return True
244
+ return False
245
+
246
+ def __call__(self, output_ids: torch.LongTensor, scores: torch.FloatTensor, **kwargs) -> bool:
247
+ outputs = []
248
+ for i in range(output_ids.shape[0]):
249
+ outputs.append(self.call_for_batch(output_ids[i].unsqueeze(0), scores))
250
+ return all(outputs)
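
A brief sketch (an illustration under stated assumptions, not part of the commit) of how `tokenizer_image_token` splices the image placeholder into the token stream; the tokenizer path is an assumption:

```python
from transformers import AutoTokenizer

from llava.constants import IMAGE_TOKEN_INDEX
from llava.mm_utils import tokenizer_image_token

tokenizer = AutoTokenizer.from_pretrained("FastVLM-7B-Stage3", use_fast=False)
# The prompt is split on "<image>", each chunk is tokenized separately, and the chunks
# are re-joined with the sentinel id -200 (IMAGE_TOKEN_INDEX), which the model later
# swaps for projected vision features.
ids = tokenizer_image_token("<image>\nDescribe this image.", tokenizer, IMAGE_TOKEN_INDEX, return_tensors="pt")
print(ids[:3])  # tensor([-200, ...]) -- the leading sentinel marks where image features go
```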
llava/model/__init__.py ADDED
@@ -0,0 +1,8 @@
+ # try:
+ from .language_model.llava_llama import LlavaLlamaForCausalLM, LlavaConfig
+ from .language_model.llava_mpt import LlavaMptForCausalLM, LlavaMptConfig
+ from .language_model.llava_mistral import LlavaMistralForCausalLM, LlavaMistralConfig
+ from .language_model.llava_qwen import LlavaQwen2ForCausalLM, LlavaConfig
+ # except:
+ #     pass
+
llava/model/__pycache__/__init__.cpython-312.pyc ADDED
Binary file (554 Bytes). View file
 
llava/model/__pycache__/builder.cpython-312.pyc ADDED
Binary file (9.66 kB). View file
 
llava/model/__pycache__/llava_arch.cpython-312.pyc ADDED
Binary file (20.3 kB). View file
 
llava/model/apply_delta.py ADDED
@@ -0,0 +1,48 @@
1
+ """
2
+ Usage:
3
+ python3 -m fastchat.model.apply_delta --base ~/model_weights/llama-7b --target ~/model_weights/vicuna-7b --delta lmsys/vicuna-7b-delta
4
+ """
5
+ import argparse
6
+
7
+ import torch
8
+ from tqdm import tqdm
9
+ from transformers import AutoTokenizer, AutoModelForCausalLM
10
+ from llava import LlavaLlamaForCausalLM
11
+
12
+
13
+ def apply_delta(base_model_path, target_model_path, delta_path):
14
+ print("Loading base model")
15
+ base = AutoModelForCausalLM.from_pretrained(
16
+ base_model_path, torch_dtype=torch.float16, low_cpu_mem_usage=True)
17
+
18
+ print("Loading delta")
19
+ delta = LlavaLlamaForCausalLM.from_pretrained(delta_path, torch_dtype=torch.float16, low_cpu_mem_usage=True)
20
+ delta_tokenizer = AutoTokenizer.from_pretrained(delta_path)
21
+
22
+ print("Applying delta")
23
+ for name, param in tqdm(delta.state_dict().items(), desc="Applying delta"):
24
+ if name not in base.state_dict():
25
+ assert name in ['model.mm_projector.weight', 'model.mm_projector.bias'], f'{name} not in base model'
26
+ continue
27
+ if param.data.shape == base.state_dict()[name].shape:
28
+ param.data += base.state_dict()[name]
29
+ else:
30
+ assert name in ['model.embed_tokens.weight', 'lm_head.weight'], \
31
+ f'{name} dimension mismatch: {param.data.shape} vs {base.state_dict()[name].shape}'
32
+ bparam = base.state_dict()[name]
33
+ param.data[:bparam.shape[0], :bparam.shape[1]] += bparam
34
+
35
+ print("Saving target model")
36
+ delta.save_pretrained(target_model_path)
37
+ delta_tokenizer.save_pretrained(target_model_path)
38
+
39
+
40
+ if __name__ == "__main__":
41
+ parser = argparse.ArgumentParser()
42
+ parser.add_argument("--base-model-path", type=str, required=True)
43
+ parser.add_argument("--target-model-path", type=str, required=True)
44
+ parser.add_argument("--delta-path", type=str, required=True)
45
+
46
+ args = parser.parse_args()
47
+
48
+ apply_delta(args.base_model_path, args.target_model_path, args.delta_path)
llava/model/builder.py ADDED
@@ -0,0 +1,181 @@
1
+ # Copyright 2023 Haotian Liu
2
+ #
3
+ # Licensed under the Apache License, Version 2.0 (the "License");
4
+ # you may not use this file except in compliance with the License.
5
+ # You may obtain a copy of the License at
6
+ #
7
+ # http://www.apache.org/licenses/LICENSE-2.0
8
+ #
9
+ # Unless required by applicable law or agreed to in writing, software
10
+ # distributed under the License is distributed on an "AS IS" BASIS,
11
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12
+ # See the License for the specific language governing permissions and
13
+ # limitations under the License.
14
+
15
+
16
+ import os
17
+ import warnings
18
+ import shutil
19
+
20
+ from transformers import AutoTokenizer, AutoModelForCausalLM, AutoConfig, BitsAndBytesConfig
21
+ import torch
22
+ from llava.model import *
23
+ from llava.constants import DEFAULT_IMAGE_PATCH_TOKEN, DEFAULT_IM_START_TOKEN, DEFAULT_IM_END_TOKEN
24
+
25
+
26
+ def load_pretrained_model(model_path, model_base, model_name, load_8bit=False, load_4bit=False, device_map="auto", device="cuda", use_flash_attn=False, **kwargs):
27
+ kwargs = {"device_map": device_map, **kwargs}
28
+
29
+ if device != "cuda":
30
+ kwargs['device_map'] = {"": device}
31
+
32
+ if load_8bit:
33
+ kwargs['load_in_8bit'] = True
34
+ elif load_4bit:
35
+ kwargs['load_in_4bit'] = True
36
+ kwargs['quantization_config'] = BitsAndBytesConfig(
37
+ load_in_4bit=True,
38
+ bnb_4bit_compute_dtype=torch.float16,
39
+ bnb_4bit_use_double_quant=True,
40
+ bnb_4bit_quant_type='nf4'
41
+ )
42
+ else:
43
+ kwargs['torch_dtype'] = torch.float16
44
+
45
+ if use_flash_attn:
46
+ kwargs['attn_implementation'] = 'flash_attention_2'
47
+
48
+ if 'llava' in model_name.lower():
49
+ # Load LLaVA model
50
+ if 'lora' in model_name.lower() and model_base is None:
51
+ warnings.warn('There is `lora` in model name but no `model_base` is provided. If you are loading a LoRA model, please provide the `model_base` argument. Detailed instruction: https://github.com/haotian-liu/LLaVA#launch-a-model-worker-lora-weights-unmerged.')
52
+ if 'lora' in model_name.lower() and model_base is not None:
53
+ from llava.model.language_model.llava_llama import LlavaConfig
54
+ lora_cfg_pretrained = LlavaConfig.from_pretrained(model_path)
55
+ tokenizer = AutoTokenizer.from_pretrained(model_base, use_fast=False)
56
+ print('Loading LLaVA from base model...')
57
+ model = LlavaLlamaForCausalLM.from_pretrained(model_base, low_cpu_mem_usage=True, config=lora_cfg_pretrained, **kwargs)
58
+ token_num, tokem_dim = model.lm_head.out_features, model.lm_head.in_features
59
+ if model.lm_head.weight.shape[0] != token_num:
60
+ model.lm_head.weight = torch.nn.Parameter(torch.empty(token_num, tokem_dim, device=model.device, dtype=model.dtype))
61
+ model.model.embed_tokens.weight = torch.nn.Parameter(torch.empty(token_num, tokem_dim, device=model.device, dtype=model.dtype))
62
+
63
+ print('Loading additional LLaVA weights...')
64
+ if os.path.exists(os.path.join(model_path, 'non_lora_trainables.bin')):
65
+ non_lora_trainables = torch.load(os.path.join(model_path, 'non_lora_trainables.bin'), map_location='cpu')
66
+ else:
67
+ # this is probably from HF Hub
68
+ from huggingface_hub import hf_hub_download
69
+
70
+ def load_from_hf(repo_id, filename, subfolder=None):
71
+ cache_file = hf_hub_download(
72
+ repo_id=repo_id,
73
+ filename=filename,
74
+ subfolder=subfolder)
75
+ return torch.load(cache_file, map_location='cpu')
76
+ non_lora_trainables = load_from_hf(model_path, 'non_lora_trainables.bin')
77
+ non_lora_trainables = {(k[11:] if k.startswith('base_model.') else k): v for k, v in non_lora_trainables.items()}
78
+ if any(k.startswith('model.model.') for k in non_lora_trainables):
79
+ non_lora_trainables = {(k[6:] if k.startswith('model.') else k): v for k, v in non_lora_trainables.items()}
80
+ model.load_state_dict(non_lora_trainables, strict=False)
81
+
82
+ from peft import PeftModel
83
+ print('Loading LoRA weights...')
84
+ model = PeftModel.from_pretrained(model, model_path)
85
+ print('Merging LoRA weights...')
86
+ model = model.merge_and_unload()
87
+ print('Model is loaded...')
88
+ elif model_base is not None:
89
+ # this may be mm projector only
90
+ print('Loading LLaVA from base model...')
91
+ if 'mpt' in model_name.lower():
92
+ if not os.path.isfile(os.path.join(model_path, 'configuration_mpt.py')):
93
+ shutil.copyfile(os.path.join(model_base, 'configuration_mpt.py'), os.path.join(model_path, 'configuration_mpt.py'))
94
+ tokenizer = AutoTokenizer.from_pretrained(model_base, use_fast=True)
95
+ cfg_pretrained = AutoConfig.from_pretrained(model_path, trust_remote_code=True)
96
+ model = LlavaMptForCausalLM.from_pretrained(model_base, low_cpu_mem_usage=True, config=cfg_pretrained, **kwargs)
97
+ else:
98
+ tokenizer = AutoTokenizer.from_pretrained(model_base, use_fast=False)
99
+ cfg_pretrained = AutoConfig.from_pretrained(model_path)
100
+ # model = LlavaLlamaForCausalLM.from_pretrained(model_base, low_cpu_mem_usage=True, config=cfg_pretrained, **kwargs)
101
+ model = LlavaQwen2ForCausalLM.from_pretrained(model_base, low_cpu_mem_usage=True, config=cfg_pretrained, **kwargs)
102
+
103
+ mm_projector_weights = torch.load(os.path.join(model_path, 'mm_projector.bin'), map_location='cpu')
104
+ mm_projector_weights = {k: v.to(torch.float16) for k, v in mm_projector_weights.items()}
105
+ model.load_state_dict(mm_projector_weights, strict=False)
106
+ else:
107
+ if 'mpt' in model_name.lower():
108
+ tokenizer = AutoTokenizer.from_pretrained(model_path, use_fast=True)
109
+ model = LlavaMptForCausalLM.from_pretrained(model_path, low_cpu_mem_usage=True, **kwargs)
110
+ elif 'mistral' in model_name.lower():
111
+ tokenizer = AutoTokenizer.from_pretrained(model_path)
112
+ model = LlavaMistralForCausalLM.from_pretrained(
113
+ model_path,
114
+ low_cpu_mem_usage=True,
115
+ **kwargs
116
+ )
117
+ elif 'dclm' in model_name.lower():
118
+ tokenizer = AutoTokenizer.from_pretrained(model_path)
119
+ model = LlavaOpenlmForCausalLM.from_pretrained(
120
+ model_path,
121
+ low_cpu_mem_usage=True,
122
+ **kwargs
123
+ )
124
+ else:
125
+ tokenizer = AutoTokenizer.from_pretrained(model_path, use_fast=False)
126
+ # model = LlavaLlamaForCausalLM.from_pretrained(
127
+ # model_path,
128
+ # low_cpu_mem_usage=True,
129
+ # **kwargs
130
+ # )
131
+ model = LlavaQwen2ForCausalLM.from_pretrained(
132
+ model_path,
133
+ low_cpu_mem_usage=True,
134
+ **kwargs
135
+ )
136
+ else:
137
+ # Load language model
138
+ if model_base is not None:
139
+ # PEFT model
140
+ from peft import PeftModel
141
+ tokenizer = AutoTokenizer.from_pretrained(model_base, use_fast=False)
142
+ model = AutoModelForCausalLM.from_pretrained(model_base, low_cpu_mem_usage=True, **kwargs)
143
+ print(f"Loading LoRA weights from {model_path}")
144
+ model = PeftModel.from_pretrained(model, model_path)
145
+ print(f"Merging weights")
146
+ model = model.merge_and_unload()
147
+ print('Convert to FP16...')
148
+ model.to(torch.float16)
149
+ else:
150
+ use_fast = False
151
+ if 'mpt' in model_name.lower():
152
+ tokenizer = AutoTokenizer.from_pretrained(model_path, use_fast=True)
153
+ model = AutoModelForCausalLM.from_pretrained(model_path, low_cpu_mem_usage=True, trust_remote_code=True, **kwargs)
154
+ else:
155
+ tokenizer = AutoTokenizer.from_pretrained(model_path, use_fast=False)
156
+ model = AutoModelForCausalLM.from_pretrained(model_path, low_cpu_mem_usage=True, **kwargs)
157
+
158
+ image_processor = None
159
+
160
+ if 'llava' in model_name.lower():
161
+ mm_use_im_start_end = getattr(model.config, "mm_use_im_start_end", False)
162
+ mm_use_im_patch_token = getattr(model.config, "mm_use_im_patch_token", True)
163
+ if mm_use_im_patch_token:
164
+ tokenizer.add_tokens([DEFAULT_IMAGE_PATCH_TOKEN], special_tokens=True)
165
+ if mm_use_im_start_end:
166
+ tokenizer.add_tokens([DEFAULT_IM_START_TOKEN, DEFAULT_IM_END_TOKEN], special_tokens=True)
167
+ model.resize_token_embeddings(len(tokenizer))
168
+
169
+ vision_tower = model.get_vision_tower()
170
+ if not vision_tower.is_loaded:
171
+ vision_tower.load_model(device_map=device_map)
172
+ if device_map != 'auto':
173
+ vision_tower.to(device=device_map, dtype=torch.float16)
174
+ image_processor = vision_tower.image_processor
175
+
176
+ if hasattr(model.config, "max_sequence_length"):
177
+ context_len = model.config.max_sequence_length
178
+ else:
179
+ context_len = 2048
180
+
181
+ return tokenizer, model, image_processor, context_len
llava/model/consolidate.py ADDED
@@ -0,0 +1,29 @@
1
+ """
2
+ Usage:
3
+ python3 -m llava.model.consolidate --src ~/model_weights/llava-7b --dst ~/model_weights/llava-7b_consolidate
4
+ """
5
+ import argparse
6
+
7
+ import torch
8
+ from transformers import AutoTokenizer, AutoModelForCausalLM
9
+ from llava.model import *
10
+ from llava.model.utils import auto_upgrade
11
+
12
+
13
+ def consolidate_ckpt(src_path, dst_path):
14
+ print("Loading model")
15
+ auto_upgrade(src_path)
16
+ src_model = AutoModelForCausalLM.from_pretrained(src_path, torch_dtype=torch.float16, low_cpu_mem_usage=True)
17
+ src_tokenizer = AutoTokenizer.from_pretrained(src_path, use_fast=False)
18
+ src_model.save_pretrained(dst_path)
19
+ src_tokenizer.save_pretrained(dst_path)
20
+
21
+
22
+ if __name__ == "__main__":
23
+ parser = argparse.ArgumentParser()
24
+ parser.add_argument("--src", type=str, required=True)
25
+ parser.add_argument("--dst", type=str, required=True)
26
+
27
+ args = parser.parse_args()
28
+
29
+ consolidate_ckpt(args.src, args.dst)
llava/model/language_model/__pycache__/llava_llama.cpython-312.pyc ADDED
Binary file (5.58 kB). View file
 
llava/model/language_model/__pycache__/llava_mistral.cpython-312.pyc ADDED
Binary file (5.52 kB). View file
 
llava/model/language_model/__pycache__/llava_mpt.cpython-312.pyc ADDED
Binary file (4.37 kB). View file
 
llava/model/language_model/__pycache__/llava_qwen.cpython-312.pyc ADDED
Binary file (5.51 kB). View file
 
llava/model/language_model/llava_llama.py ADDED
@@ -0,0 +1,159 @@
1
+ # Copyright 2023 Haotian Liu
2
+ #
3
+ # Licensed under the Apache License, Version 2.0 (the "License");
4
+ # you may not use this file except in compliance with the License.
5
+ # You may obtain a copy of the License at
6
+ #
7
+ # http://www.apache.org/licenses/LICENSE-2.0
8
+ #
9
+ # Unless required by applicable law or agreed to in writing, software
10
+ # distributed under the License is distributed on an "AS IS" BASIS,
11
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12
+ # See the License for the specific language governing permissions and
13
+ # limitations under the License.
14
+
15
+
16
+ from typing import List, Optional, Tuple, Union
17
+
18
+ import torch
19
+ import torch.nn as nn
20
+
21
+ from transformers import AutoConfig, AutoModelForCausalLM, \
22
+ LlamaConfig, LlamaModel, LlamaForCausalLM
23
+
24
+ from transformers.modeling_outputs import CausalLMOutputWithPast
25
+ from transformers.generation.utils import GenerateOutput
26
+
27
+ from ..llava_arch import LlavaMetaModel, LlavaMetaForCausalLM
28
+
29
+
30
+ class LlavaConfig(LlamaConfig):
31
+ model_type = "llava_llama"
32
+
33
+
34
+ class LlavaLlamaModel(LlavaMetaModel, LlamaModel):
35
+ config_class = LlavaConfig
36
+
37
+ def __init__(self, config: LlamaConfig):
38
+ super(LlavaLlamaModel, self).__init__(config)
39
+
40
+
41
+ class LlavaLlamaForCausalLM(LlamaForCausalLM, LlavaMetaForCausalLM):
42
+ config_class = LlavaConfig
43
+
44
+ def __init__(self, config):
45
+ super(LlamaForCausalLM, self).__init__(config)
46
+ self.model = LlavaLlamaModel(config)
47
+ self.pretraining_tp = config.pretraining_tp
48
+ self.vocab_size = config.vocab_size
49
+ self.lm_head = nn.Linear(config.hidden_size, config.vocab_size, bias=False)
50
+
51
+ # Initialize weights and apply final processing
52
+ self.post_init()
53
+
54
+ def get_model(self):
55
+ return self.model
56
+
57
+ def forward(
58
+ self,
59
+ input_ids: torch.LongTensor = None,
60
+ attention_mask: Optional[torch.Tensor] = None,
61
+ position_ids: Optional[torch.LongTensor] = None,
62
+ past_key_values: Optional[List[torch.FloatTensor]] = None,
63
+ inputs_embeds: Optional[torch.FloatTensor] = None,
64
+ labels: Optional[torch.LongTensor] = None,
65
+ use_cache: Optional[bool] = None,
66
+ output_attentions: Optional[bool] = None,
67
+ output_hidden_states: Optional[bool] = None,
68
+ images: Optional[torch.FloatTensor] = None,
69
+ image_sizes: Optional[List[List[int]]] = None,
70
+ return_dict: Optional[bool] = None,
71
+ cache_position=None,
72
+ ) -> Union[Tuple, CausalLMOutputWithPast]:
73
+
74
+ if inputs_embeds is None:
75
+ (
76
+ input_ids,
77
+ position_ids,
78
+ attention_mask,
79
+ past_key_values,
80
+ inputs_embeds,
81
+ labels
82
+ ) = self.prepare_inputs_labels_for_multimodal(
83
+ input_ids,
84
+ position_ids,
85
+ attention_mask,
86
+ past_key_values,
87
+ labels,
88
+ images,
89
+ image_sizes
90
+ )
91
+
92
+ return super().forward(
93
+ input_ids=input_ids,
94
+ attention_mask=attention_mask,
95
+ position_ids=position_ids,
96
+ past_key_values=past_key_values,
97
+ inputs_embeds=inputs_embeds,
98
+ labels=labels,
99
+ use_cache=use_cache,
100
+ output_attentions=output_attentions,
101
+ output_hidden_states=output_hidden_states,
102
+ return_dict=return_dict
103
+ )
104
+
105
+ @torch.no_grad()
106
+ def generate(
107
+ self,
108
+ inputs: Optional[torch.Tensor] = None,
109
+ images: Optional[torch.Tensor] = None,
110
+ image_sizes: Optional[torch.Tensor] = None,
111
+ **kwargs,
112
+ ) -> Union[GenerateOutput, torch.LongTensor]:
113
+ position_ids = kwargs.pop("position_ids", None)
114
+ attention_mask = kwargs.pop("attention_mask", None)
115
+ if "inputs_embeds" in kwargs:
116
+ raise NotImplementedError("`inputs_embeds` is not supported")
117
+
118
+ if images is not None:
119
+ (
120
+ inputs,
121
+ position_ids,
122
+ attention_mask,
123
+ _,
124
+ inputs_embeds,
125
+ _
126
+ ) = self.prepare_inputs_labels_for_multimodal(
127
+ inputs,
128
+ position_ids,
129
+ attention_mask,
130
+ None,
131
+ None,
132
+ images,
133
+ image_sizes=image_sizes
134
+ )
135
+ else:
136
+ inputs_embeds = self.get_model().embed_tokens(inputs)
137
+
138
+ return super().generate(
139
+ position_ids=position_ids,
140
+ attention_mask=attention_mask,
141
+ inputs_embeds=inputs_embeds,
142
+ **kwargs
143
+ )
144
+
145
+ def prepare_inputs_for_generation(self, input_ids, past_key_values=None,
146
+ inputs_embeds=None, **kwargs):
147
+ images = kwargs.pop("images", None)
148
+ image_sizes = kwargs.pop("image_sizes", None)
149
+ inputs = super().prepare_inputs_for_generation(
150
+ input_ids, past_key_values=past_key_values, inputs_embeds=inputs_embeds, **kwargs
151
+ )
152
+ if images is not None:
153
+ inputs['images'] = images
154
+ if image_sizes is not None:
155
+ inputs['image_sizes'] = image_sizes
156
+ return inputs
157
+
158
+ AutoConfig.register("llava_llama", LlavaConfig)
159
+ AutoModelForCausalLM.register(LlavaConfig, LlavaLlamaForCausalLM)
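The two register() calls above are what make the custom "llava_llama" model_type resolvable through the Transformers Auto classes. A minimal loading sketch, with a placeholder checkpoint path; the repo's own llava/model/builder.py provides the fuller loading path:

    from transformers import AutoConfig, AutoModelForCausalLM
    import llava.model.language_model.llava_llama  # importing this module executes the register() calls above

    ckpt = "/path/to/llava-llama-checkpoint"   # placeholder path
    config = AutoConfig.from_pretrained(ckpt)  # resolves to LlavaConfig (model_type "llava_llama")
    model = AutoModelForCausalLM.from_pretrained(ckpt, config=config)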
llava/model/language_model/llava_mistral.py ADDED
@@ -0,0 +1,158 @@
1
+ # Copyright 2023 Haotian Liu
2
+ #
3
+ # Licensed under the Apache License, Version 2.0 (the "License");
4
+ # you may not use this file except in compliance with the License.
5
+ # You may obtain a copy of the License at
6
+ #
7
+ # http://www.apache.org/licenses/LICENSE-2.0
8
+ #
9
+ # Unless required by applicable law or agreed to in writing, software
10
+ # distributed under the License is distributed on an "AS IS" BASIS,
11
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12
+ # See the License for the specific language governing permissions and
13
+ # limitations under the License.
14
+
15
+
16
+ from typing import List, Optional, Tuple, Union
17
+
18
+ import torch
19
+ import torch.nn as nn
20
+ from torch.nn import CrossEntropyLoss
21
+
22
+ from transformers import AutoConfig, AutoModelForCausalLM, \
23
+ MistralConfig, MistralModel, MistralForCausalLM
24
+
25
+ from transformers.modeling_outputs import CausalLMOutputWithPast
26
+ from transformers.generation.utils import GenerateOutput
27
+
28
+ from ..llava_arch import LlavaMetaModel, LlavaMetaForCausalLM
29
+
30
+
31
+ class LlavaMistralConfig(MistralConfig):
32
+ model_type = "llava_mistral"
33
+
34
+
35
+ class LlavaMistralModel(LlavaMetaModel, MistralModel):
36
+ config_class = LlavaMistralConfig
37
+
38
+ def __init__(self, config: MistralConfig):
39
+ super(LlavaMistralModel, self).__init__(config)
40
+
41
+
42
+ class LlavaMistralForCausalLM(MistralForCausalLM, LlavaMetaForCausalLM):
43
+ config_class = LlavaMistralConfig
44
+
45
+ def __init__(self, config):
46
+ super(MistralForCausalLM, self).__init__(config)
47
+ self.model = LlavaMistralModel(config)
48
+
49
+ self.lm_head = nn.Linear(config.hidden_size, config.vocab_size, bias=False)
50
+
51
+ # Initialize weights and apply final processing
52
+ self.post_init()
53
+
54
+ def get_model(self):
55
+ return self.model
56
+
57
+ def forward(
58
+ self,
59
+ input_ids: torch.LongTensor = None,
60
+ attention_mask: Optional[torch.Tensor] = None,
61
+ position_ids: Optional[torch.LongTensor] = None,
62
+ past_key_values: Optional[List[torch.FloatTensor]] = None,
63
+ inputs_embeds: Optional[torch.FloatTensor] = None,
64
+ labels: Optional[torch.LongTensor] = None,
65
+ use_cache: Optional[bool] = None,
66
+ output_attentions: Optional[bool] = None,
67
+ output_hidden_states: Optional[bool] = None,
68
+ images: Optional[torch.FloatTensor] = None,
69
+ image_sizes: Optional[List[List[int]]] = None,
70
+ return_dict: Optional[bool] = None,
71
+ ) -> Union[Tuple, CausalLMOutputWithPast]:
72
+
73
+ if inputs_embeds is None:
74
+ (
75
+ input_ids,
76
+ position_ids,
77
+ attention_mask,
78
+ past_key_values,
79
+ inputs_embeds,
80
+ labels
81
+ ) = self.prepare_inputs_labels_for_multimodal(
82
+ input_ids,
83
+ position_ids,
84
+ attention_mask,
85
+ past_key_values,
86
+ labels,
87
+ images,
88
+ image_sizes
89
+ )
90
+
91
+ return super().forward(
92
+ input_ids=input_ids,
93
+ attention_mask=attention_mask,
94
+ position_ids=position_ids,
95
+ past_key_values=past_key_values,
96
+ inputs_embeds=inputs_embeds,
97
+ labels=labels,
98
+ use_cache=use_cache,
99
+ output_attentions=output_attentions,
100
+ output_hidden_states=output_hidden_states,
101
+ return_dict=return_dict
102
+ )
103
+
104
+ @torch.no_grad()
105
+ def generate(
106
+ self,
107
+ inputs: Optional[torch.Tensor] = None,
108
+ images: Optional[torch.Tensor] = None,
109
+ image_sizes: Optional[torch.Tensor] = None,
110
+ **kwargs,
111
+ ) -> Union[GenerateOutput, torch.LongTensor]:
112
+ position_ids = kwargs.pop("position_ids", None)
113
+ attention_mask = kwargs.pop("attention_mask", None)
114
+ if "inputs_embeds" in kwargs:
115
+ raise NotImplementedError("`inputs_embeds` is not supported")
116
+
117
+ if images is not None:
118
+ (
119
+ inputs,
120
+ position_ids,
121
+ attention_mask,
122
+ _,
123
+ inputs_embeds,
124
+ _
125
+ ) = self.prepare_inputs_labels_for_multimodal(
126
+ inputs,
127
+ position_ids,
128
+ attention_mask,
129
+ None,
130
+ None,
131
+ images,
132
+ image_sizes=image_sizes
133
+ )
134
+ else:
135
+ inputs_embeds = self.get_model().embed_tokens(inputs)
136
+
137
+ return super().generate(
138
+ position_ids=position_ids,
139
+ attention_mask=attention_mask,
140
+ inputs_embeds=inputs_embeds,
141
+ **kwargs
142
+ )
143
+
144
+ def prepare_inputs_for_generation(self, input_ids, past_key_values=None,
145
+ inputs_embeds=None, **kwargs):
146
+ images = kwargs.pop("images", None)
147
+ image_sizes = kwargs.pop("image_sizes", None)
148
+ inputs = super().prepare_inputs_for_generation(
149
+ input_ids, past_key_values=past_key_values, inputs_embeds=inputs_embeds, **kwargs
150
+ )
151
+ if images is not None:
152
+ inputs['images'] = images
153
+ if image_sizes is not None:
154
+ inputs['image_sizes'] = image_sizes
155
+ return inputs
156
+
157
+ AutoConfig.register("llava_mistral", LlavaMistralConfig)
158
+ AutoModelForCausalLM.register(LlavaMistralConfig, LlavaMistralForCausalLM)
llava/model/language_model/llava_mpt.py ADDED
@@ -0,0 +1,97 @@
1
+ # Copyright 2023 Haotian Liu
2
+ #
3
+ # Licensed under the Apache License, Version 2.0 (the "License");
4
+ # you may not use this file except in compliance with the License.
5
+ # You may obtain a copy of the License at
6
+ #
7
+ # http://www.apache.org/licenses/LICENSE-2.0
8
+ #
9
+ # Unless required by applicable law or agreed to in writing, software
10
+ # distributed under the License is distributed on an "AS IS" BASIS,
11
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12
+ # See the License for the specific language governing permissions and
13
+ # limitations under the License.
14
+
15
+
16
+ from typing import Optional, Tuple
17
+
18
+ import torch
19
+
20
+ from transformers import AutoConfig, AutoModelForCausalLM, \
21
+ MptConfig, MptForCausalLM, MptModel
22
+ from llava.model.llava_arch import LlavaMetaModel, LlavaMetaForCausalLM
23
+
24
+
25
+ class LlavaMptConfig(MptConfig):
26
+ model_type = "llava_mpt"
27
+
28
+
29
+ class LlavaMptModel(LlavaMetaModel, MptModel):
30
+ config_class = LlavaMptConfig
31
+
32
+ def __init__(self, config: MptConfig):
33
+ config.hidden_size = config.d_model
34
+ super(LlavaMptModel, self).__init__(config)
35
+
36
+ def embed_tokens(self, x):
37
+ return self.wte(x)
38
+
39
+
40
+ class LlavaMptForCausalLM(MptForCausalLM, LlavaMetaForCausalLM):
41
+ config_class = LlavaMptConfig
42
+ supports_gradient_checkpointing = True
43
+
44
+ def __init__(self, config):
45
+ super(MptForCausalLM, self).__init__(config)
46
+
47
+ self.transformer = LlavaMptModel(config)
48
+ self.lm_head = torch.nn.Linear(config.hidden_size, config.vocab_size, bias=False)
49
+
50
+ # Initialize weights and apply final processing
51
+ self.post_init()
52
+
53
+ def get_model(self):
54
+ return self.transformer
55
+
56
+ def _set_gradient_checkpointing(self, module, value=False):
57
+ if isinstance(module, LlavaMptModel):
58
+ module.gradient_checkpointing = value
59
+
60
+ def forward(
61
+ self,
62
+ input_ids: Optional[torch.LongTensor] = None,
63
+ past_key_values: Optional[Tuple[Tuple[torch.Tensor, torch.Tensor], ...]] = None,
64
+ attention_mask: Optional[torch.Tensor] = None,
65
+ inputs_embeds: Optional[torch.Tensor] = None,
66
+ labels: Optional[torch.Tensor] = None,
67
+ use_cache: Optional[bool] = None,
68
+ output_attentions: Optional[bool] = None,
69
+ output_hidden_states: Optional[bool] = None,
70
+ return_dict: Optional[bool] = None,
71
+ images=None):
72
+
73
+ input_ids, attention_mask, past_key_values, inputs_embeds, labels = self.prepare_inputs_labels_for_multimodal(input_ids, attention_mask, past_key_values, labels, images)
74
+
75
+ return super().forward(
76
+ input_ids,
77
+ past_key_values=past_key_values,
78
+ attention_mask=attention_mask,
79
+ inputs_embeds=inputs_embeds,
80
+ labels=labels,
81
+ use_cache=use_cache,
82
+ output_attentions=output_attentions,
83
+ output_hidden_states=output_hidden_states,
84
+ return_dict=return_dict,
85
+ )
86
+
87
+ def prepare_inputs_for_generation(self, input_ids, past_key_values=None, inputs_embeds=None, **kwargs):
88
+ images = kwargs.pop("images", None)
89
+ _inputs = super().prepare_inputs_for_generation(
90
+ input_ids, past_key_values=past_key_values, inputs_embeds=inputs_embeds, **kwargs
91
+ )
92
+ _inputs['images'] = images
93
+ return _inputs
94
+
95
+
96
+ AutoConfig.register("llava_mpt", LlavaMptConfig)
97
+ AutoModelForCausalLM.register(LlavaMptConfig, LlavaMptForCausalLM)
llava/model/language_model/llava_qwen.py ADDED
@@ -0,0 +1,160 @@
1
+
2
+ # Copyright 2023 Haotian Liu
3
+ #
4
+ # Licensed under the Apache License, Version 2.0 (the "License");
5
+ # you may not use this file except in compliance with the License.
6
+ # You may obtain a copy of the License at
7
+ #
8
+ # http://www.apache.org/licenses/LICENSE-2.0
9
+ #
10
+ # Unless required by applicable law or agreed to in writing, software
11
+ # distributed under the License is distributed on an "AS IS" BASIS,
12
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13
+ # See the License for the specific language governing permissions and
14
+ # limitations under the License.
15
+
16
+
17
+ from typing import List, Optional, Tuple, Union
18
+
19
+ import torch
20
+ import torch.nn as nn
21
+
22
+ from transformers import AutoConfig, AutoModelForCausalLM, Qwen2Config, Qwen2Model, Qwen2ForCausalLM
23
+
24
+ from transformers.modeling_outputs import CausalLMOutputWithPast
25
+ from transformers.generation.utils import GenerateOutput
26
+
27
+ from ..llava_arch import LlavaMetaModel, LlavaMetaForCausalLM
28
+
29
+
30
+ class LlavaConfig(Qwen2Config):
31
+ model_type = "llava_qwen2"
32
+
33
+
34
+ class LlavaQwen2Model(LlavaMetaModel, Qwen2Model):
35
+ config_class = LlavaConfig
36
+
37
+ def __init__(self, config: Qwen2Config):
38
+ super(LlavaQwen2Model, self).__init__(config)
39
+
40
+
41
+ class LlavaQwen2ForCausalLM(Qwen2ForCausalLM, LlavaMetaForCausalLM):
42
+ config_class = LlavaConfig
43
+
44
+ def __init__(self, config):
45
+ super(Qwen2ForCausalLM, self).__init__(config)
46
+ self.model = LlavaQwen2Model(config)
47
+ # self.pretraining_tp = config.pretraining_tp
48
+ self.vocab_size = config.vocab_size
49
+ self.lm_head = nn.Linear(config.hidden_size, config.vocab_size, bias=False)
50
+
51
+ # Initialize weights and apply final processing
52
+ self.post_init()
53
+
54
+ def get_model(self):
55
+ return self.model
56
+
57
+ def forward(
58
+ self,
59
+ input_ids: torch.LongTensor = None,
60
+ attention_mask: Optional[torch.Tensor] = None,
61
+ position_ids: Optional[torch.LongTensor] = None,
62
+ past_key_values: Optional[List[torch.FloatTensor]] = None,
63
+ inputs_embeds: Optional[torch.FloatTensor] = None,
64
+ labels: Optional[torch.LongTensor] = None,
65
+ use_cache: Optional[bool] = None,
66
+ output_attentions: Optional[bool] = None,
67
+ output_hidden_states: Optional[bool] = None,
68
+ images: Optional[torch.FloatTensor] = None,
69
+ image_sizes: Optional[List[List[int]]] = None,
70
+ return_dict: Optional[bool] = None,
71
+ cache_position=None,
72
+ ) -> Union[Tuple, CausalLMOutputWithPast]:
73
+
74
+ if inputs_embeds is None:
75
+ (
76
+ input_ids,
77
+ position_ids,
78
+ attention_mask,
79
+ past_key_values,
80
+ inputs_embeds,
81
+ labels
82
+ ) = self.prepare_inputs_labels_for_multimodal(
83
+ input_ids,
84
+ position_ids,
85
+ attention_mask,
86
+ past_key_values,
87
+ labels,
88
+ images,
89
+ image_sizes
90
+ )
91
+
92
+ return super().forward(
93
+ input_ids=input_ids,
94
+ attention_mask=attention_mask,
95
+ position_ids=position_ids,
96
+ past_key_values=past_key_values,
97
+ inputs_embeds=inputs_embeds,
98
+ labels=labels,
99
+ use_cache=use_cache,
100
+ output_attentions=output_attentions,
101
+ output_hidden_states=output_hidden_states,
102
+ return_dict=return_dict
103
+ )
104
+
105
+ @torch.no_grad()
106
+ def generate(
107
+ self,
108
+ inputs: Optional[torch.Tensor] = None,
109
+ images: Optional[torch.Tensor] = None,
110
+ image_sizes: Optional[torch.Tensor] = None,
111
+ **kwargs,
112
+ ) -> Union[GenerateOutput, torch.LongTensor]:
113
+ position_ids = kwargs.pop("position_ids", None)
114
+ attention_mask = kwargs.pop("attention_mask", None)
115
+ if "inputs_embeds" in kwargs:
116
+ raise NotImplementedError("`inputs_embeds` is not supported")
117
+
118
+ if images is not None:
119
+ (
120
+ inputs,
121
+ position_ids,
122
+ attention_mask,
123
+ _,
124
+ inputs_embeds,
125
+ _
126
+ ) = self.prepare_inputs_labels_for_multimodal(
127
+ inputs,
128
+ position_ids,
129
+ attention_mask,
130
+ None,
131
+ None,
132
+ images,
133
+ image_sizes=image_sizes
134
+ )
135
+ else:
136
+ inputs_embeds = self.get_model().embed_tokens(inputs)
137
+
138
+ return super().generate(
139
+ position_ids=position_ids,
140
+ attention_mask=attention_mask,
141
+ inputs_embeds=inputs_embeds,
142
+ **kwargs
143
+ )
144
+
145
+ def prepare_inputs_for_generation(self, input_ids, past_key_values=None,
146
+ inputs_embeds=None, **kwargs):
147
+ images = kwargs.pop("images", None)
148
+ image_sizes = kwargs.pop("image_sizes", None)
149
+ inputs = super().prepare_inputs_for_generation(
150
+ input_ids, past_key_values=past_key_values, inputs_embeds=inputs_embeds, **kwargs
151
+ )
152
+ if images is not None:
153
+ inputs['images'] = images
154
+ if image_sizes is not None:
155
+ inputs['image_sizes'] = image_sizes
156
+ return inputs
157
+
158
+
159
+ AutoConfig.register("llava_qwen2", LlavaConfig)
160
+ AutoModelForCausalLM.register(LlavaConfig, LlavaQwen2ForCausalLM)
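When images are passed, generate() above never feeds token ids to the language model: prepare_inputs_labels_for_multimodal splices the projected image features into inputs_embeds and Qwen2's generate() runs on embeddings. A rough end-to-end sketch, assuming model, tokenizer, image_processor and a PIL image pil_image are already loaded, and using the helpers shipped in llava/constants.py and llava/mm_utils.py (conversation templating from llava/conversation.py is skipped for brevity):

    import torch
    from llava.constants import DEFAULT_IMAGE_TOKEN, IMAGE_TOKEN_INDEX
    from llava.mm_utils import tokenizer_image_token

    prompt = DEFAULT_IMAGE_TOKEN + "\nDescribe the image."
    input_ids = tokenizer_image_token(prompt, tokenizer, IMAGE_TOKEN_INDEX, return_tensors='pt').unsqueeze(0).to(model.device)
    pixel_values = image_processor(images=[pil_image], return_tensors='pt')['pixel_values'].to(model.device, model.dtype)

    with torch.inference_mode():
        output_ids = model.generate(
            inputs=input_ids,
            images=pixel_values,
            image_sizes=[pil_image.size],
            max_new_tokens=64,
        )
    print(tokenizer.decode(output_ids[0], skip_special_tokens=True))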
llava/model/llava_arch.py ADDED
@@ -0,0 +1,376 @@
1
+ # Copyright 2023 Haotian Liu
2
+ #
3
+ # Licensed under the Apache License, Version 2.0 (the "License");
4
+ # you may not use this file except in compliance with the License.
5
+ # You may obtain a copy of the License at
6
+ #
7
+ # http://www.apache.org/licenses/LICENSE-2.0
8
+ #
9
+ # Unless required by applicable law or agreed to in writing, software
10
+ # distributed under the License is distributed on an "AS IS" BASIS,
11
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12
+ # See the License for the specific language governing permissions and
13
+ # limitations under the License.
14
+
15
+
16
+ from abc import ABC, abstractmethod
17
+
18
+ import torch
19
+ import torch.nn as nn
20
+
21
+ from .multimodal_encoder.builder import build_vision_tower
22
+ from .multimodal_projector.builder import build_vision_projector
23
+
24
+ from llava.constants import IGNORE_INDEX, IMAGE_TOKEN_INDEX, DEFAULT_IMAGE_PATCH_TOKEN, DEFAULT_IM_START_TOKEN, DEFAULT_IM_END_TOKEN
25
+
26
+ from llava.mm_utils import get_anyres_image_grid_shape
27
+
28
+
29
+ class LlavaMetaModel:
30
+
31
+ def __init__(self, config):
32
+ super(LlavaMetaModel, self).__init__(config)
33
+
34
+ if hasattr(config, "mm_vision_tower"):
35
+ self.vision_tower = build_vision_tower(config, delay_load=True)
36
+ self.mm_projector = build_vision_projector(config)
37
+
38
+ if 'unpad' in getattr(config, 'mm_patch_merge_type', ''):
39
+ self.image_newline = nn.Parameter(
40
+ torch.empty(config.hidden_size, dtype=self.dtype)
41
+ )
42
+
43
+ def get_vision_tower(self):
44
+ vision_tower = getattr(self, 'vision_tower', None)
45
+ if type(vision_tower) is list:
46
+ vision_tower = vision_tower[0]
47
+ return vision_tower
48
+
49
+ def initialize_vision_modules(self, model_args, fsdp=None):
50
+ vision_tower = model_args.vision_tower
51
+ mm_vision_select_layer = model_args.mm_vision_select_layer
52
+ mm_vision_select_feature = model_args.mm_vision_select_feature
53
+ pretrain_mm_mlp_adapter = model_args.pretrain_mm_mlp_adapter
54
+ mm_patch_merge_type = model_args.mm_patch_merge_type
55
+
56
+ self.config.mm_vision_tower = vision_tower
57
+
58
+ if self.get_vision_tower() is None:
59
+ vision_tower = build_vision_tower(model_args)
60
+
61
+ if fsdp is not None and len(fsdp) > 0:
62
+ self.vision_tower = [vision_tower]
63
+ else:
64
+ self.vision_tower = vision_tower
65
+ else:
66
+ if fsdp is not None and len(fsdp) > 0:
67
+ vision_tower = self.vision_tower[0]
68
+ else:
69
+ vision_tower = self.vision_tower
70
+ vision_tower.load_model()
71
+
72
+ self.config.use_mm_proj = True
73
+ self.config.mm_projector_type = getattr(model_args, 'mm_projector_type', 'linear')
74
+ self.config.mm_hidden_size = vision_tower.hidden_size
75
+ self.config.mm_vision_select_layer = mm_vision_select_layer
76
+ self.config.mm_vision_select_feature = mm_vision_select_feature
77
+ self.config.mm_patch_merge_type = mm_patch_merge_type
78
+
79
+ if getattr(self, 'mm_projector', None) is None:
80
+ self.mm_projector = build_vision_projector(self.config)
81
+
82
+ if 'unpad' in mm_patch_merge_type:
83
+ embed_std = 1 / torch.sqrt(torch.tensor(self.config.hidden_size, dtype=self.dtype))
84
+ self.image_newline = nn.Parameter(
85
+ torch.randn(self.config.hidden_size, dtype=self.dtype) * embed_std
86
+ )
87
+ else:
88
+ # In case it is frozen by LoRA
89
+ for p in self.mm_projector.parameters():
90
+ p.requires_grad = True
91
+
92
+ if pretrain_mm_mlp_adapter is not None:
93
+ mm_projector_weights = torch.load(pretrain_mm_mlp_adapter, map_location='cpu')
94
+
95
+ def get_w(weights, keyword):
96
+ return {k.split(keyword + '.')[1]: v for k, v in weights.items() if keyword in k}
97
+
98
+ self.mm_projector.load_state_dict(get_w(mm_projector_weights, 'mm_projector'))
99
+
100
+
101
+ def unpad_image(tensor, original_size):
102
+ """
103
+ Unpads a PyTorch tensor of a padded and resized image.
104
+
105
+ Args:
106
+ tensor (torch.Tensor): The image tensor, assumed to be in CxHxW format.
107
+ original_size (tuple): The original size of PIL image (width, height).
108
+
109
+ Returns:
110
+ torch.Tensor: The unpadded image tensor.
111
+ """
112
+ original_width, original_height = original_size
113
+ current_height, current_width = tensor.shape[1:]
114
+
115
+ original_aspect_ratio = original_width / original_height
116
+ current_aspect_ratio = current_width / current_height
117
+
118
+ if original_aspect_ratio > current_aspect_ratio:
119
+ scale_factor = current_width / original_width
120
+ new_height = int(original_height * scale_factor)
121
+ padding = (current_height - new_height) // 2
122
+ unpadded_tensor = tensor[:, padding:current_height - padding, :]
123
+ else:
124
+ scale_factor = current_height / original_height
125
+ new_width = int(original_width * scale_factor)
126
+ padding = (current_width - new_width) // 2
127
+ unpadded_tensor = tensor[:, :, padding:current_width - padding]
128
+
129
+ return unpadded_tensor
130
+
131
+
132
+ class LlavaMetaForCausalLM(ABC):
133
+
134
+ @abstractmethod
135
+ def get_model(self):
136
+ pass
137
+
138
+ def get_vision_tower(self):
139
+ return self.get_model().get_vision_tower()
140
+
141
+ def encode_images(self, images):
142
+ image_features = self.get_model().get_vision_tower()(images)
143
+ image_features = self.get_model().mm_projector(image_features)
144
+ return image_features
145
+
146
+ def prepare_inputs_labels_for_multimodal(
147
+ self, input_ids, position_ids, attention_mask, past_key_values, labels,
148
+ images, image_sizes=None
149
+ ):
150
+ vision_tower = self.get_vision_tower()
151
+ if vision_tower is None or images is None or input_ids.shape[1] == 1:
152
+ return input_ids, position_ids, attention_mask, past_key_values, None, labels
153
+
154
+ if type(images) is list or images.ndim == 5:
155
+ if type(images) is list:
156
+ images = [x.unsqueeze(0) if x.ndim == 3 else x for x in images]
157
+ concat_images = torch.cat([image for image in images], dim=0)
158
+ image_features = self.encode_images(concat_images)
159
+ split_sizes = [image.shape[0] for image in images]
160
+ image_features = torch.split(image_features, split_sizes, dim=0)
161
+ mm_patch_merge_type = getattr(self.config, 'mm_patch_merge_type', 'flat')
162
+ image_aspect_ratio = getattr(self.config, 'image_aspect_ratio', 'square')
163
+ if mm_patch_merge_type == 'flat':
164
+ image_features = [x.flatten(0, 1) for x in image_features]
165
+ elif mm_patch_merge_type.startswith('spatial'):
166
+ new_image_features = []
167
+ for image_idx, image_feature in enumerate(image_features):
168
+ if image_feature.shape[0] > 1:
169
+ base_image_feature = image_feature[0]
170
+ image_feature = image_feature[1:]
171
+ height = width = self.get_vision_tower().num_patches_per_side
172
+ assert height * width == base_image_feature.shape[0]
173
+ if image_aspect_ratio == 'anyres':
174
+ if hasattr(self.get_vision_tower(), 's2_image_size'):
175
+ img_size = self.get_vision_tower().s2_image_size
176
+ elif isinstance(self.get_vision_tower().config, dict):
177
+ img_size = self.get_vision_tower().config["image_cfg"]["image_size"]
178
+ else:
179
+ img_size = self.get_vision_tower().config.image_size
180
+
181
+ num_patch_width, num_patch_height = get_anyres_image_grid_shape(image_sizes[image_idx], self.config.image_grid_pinpoints, img_size)
182
+ image_feature = image_feature.view(num_patch_height, num_patch_width, height, width, -1)
183
+ else:
184
+ raise NotImplementedError
185
+ if 'unpad' in mm_patch_merge_type:
186
+ image_feature = image_feature.permute(4, 0, 2, 1, 3).contiguous()
187
+ image_feature = image_feature.flatten(1, 2).flatten(2, 3)
188
+ image_feature = unpad_image(image_feature, image_sizes[image_idx])
189
+ image_feature = torch.cat((
190
+ image_feature,
191
+ self.model.image_newline[:, None, None].expand(*image_feature.shape[:-1], 1).to(image_feature.device)
192
+ ), dim=-1)
193
+ image_feature = image_feature.flatten(1, 2).transpose(0, 1)
194
+ else:
195
+ image_feature = image_feature.permute(0, 2, 1, 3, 4).contiguous()
196
+ image_feature = image_feature.flatten(0, 3)
197
+ image_feature = torch.cat((base_image_feature, image_feature), dim=0)
198
+ else:
199
+ image_feature = image_feature[0]
200
+ if 'unpad' in mm_patch_merge_type:
201
+ image_feature = torch.cat((
202
+ image_feature,
203
+ self.model.image_newline[None].to(image_feature.device)
204
+ ), dim=0)
205
+ new_image_features.append(image_feature)
206
+ image_features = new_image_features
207
+ else:
208
+ raise ValueError(f"Unexpected mm_patch_merge_type: {self.config.mm_patch_merge_type}")
209
+ else:
210
+ image_features = self.encode_images(images)
211
+
212
+ # TODO: image start / end is not implemented here to support pretraining.
213
+ if getattr(self.config, 'tune_mm_mlp_adapter', False) and getattr(self.config, 'mm_use_im_start_end', False):
214
+ raise NotImplementedError
215
+
216
+ # Let's just add dummy tensors if they do not exist,
217
+ # it is a headache to deal with None all the time.
218
+ # But it is not ideal, and if you have a better idea,
219
+ # please open an issue / submit a PR, thanks.
220
+ _labels = labels
221
+ _position_ids = position_ids
222
+ _attention_mask = attention_mask
223
+ if attention_mask is None:
224
+ attention_mask = torch.ones_like(input_ids, dtype=torch.bool)
225
+ else:
226
+ attention_mask = attention_mask.bool()
227
+ if position_ids is None:
228
+ position_ids = torch.arange(0, input_ids.shape[1], dtype=torch.long, device=input_ids.device)
229
+ if labels is None:
230
+ labels = torch.full_like(input_ids, IGNORE_INDEX)
231
+
232
+ # remove the padding using attention_mask -- FIXME
233
+ _input_ids = input_ids
234
+ input_ids = [cur_input_ids[cur_attention_mask] for cur_input_ids, cur_attention_mask in zip(input_ids, attention_mask)]
235
+ labels = [cur_labels[cur_attention_mask] for cur_labels, cur_attention_mask in zip(labels, attention_mask)]
236
+
237
+ new_input_embeds = []
238
+ new_labels = []
239
+ cur_image_idx = 0
240
+ for batch_idx, cur_input_ids in enumerate(input_ids):
241
+ num_images = (cur_input_ids == IMAGE_TOKEN_INDEX).sum()
242
+ if num_images == 0:
243
+ cur_image_features = image_features[cur_image_idx]
244
+ cur_input_embeds_1 = self.get_model().embed_tokens(cur_input_ids)
245
+ cur_input_embeds = torch.cat([cur_input_embeds_1, cur_image_features[0:0]], dim=0)
246
+ new_input_embeds.append(cur_input_embeds)
247
+ new_labels.append(labels[batch_idx])
248
+ cur_image_idx += 1
249
+ continue
250
+
251
+ image_token_indices = [-1] + torch.where(cur_input_ids == IMAGE_TOKEN_INDEX)[0].tolist() + [cur_input_ids.shape[0]]
252
+ cur_input_ids_noim = []
253
+ cur_labels = labels[batch_idx]
254
+ cur_labels_noim = []
255
+ for i in range(len(image_token_indices) - 1):
256
+ cur_input_ids_noim.append(cur_input_ids[image_token_indices[i]+1:image_token_indices[i+1]])
257
+ cur_labels_noim.append(cur_labels[image_token_indices[i]+1:image_token_indices[i+1]])
258
+ split_sizes = [x.shape[0] for x in cur_labels_noim]
259
+ cur_input_embeds = self.get_model().embed_tokens(torch.cat(cur_input_ids_noim))
260
+ cur_input_embeds_no_im = torch.split(cur_input_embeds, split_sizes, dim=0)
261
+ cur_new_input_embeds = []
262
+ cur_new_labels = []
263
+
264
+ for i in range(num_images + 1):
265
+ cur_new_input_embeds.append(cur_input_embeds_no_im[i])
266
+ cur_new_labels.append(cur_labels_noim[i])
267
+ if i < num_images:
268
+ cur_image_features = image_features[cur_image_idx]
269
+ cur_image_idx += 1
270
+ cur_new_input_embeds.append(cur_image_features)
271
+ cur_new_labels.append(torch.full((cur_image_features.shape[0],), IGNORE_INDEX, device=cur_labels.device, dtype=cur_labels.dtype))
272
+
273
+ cur_new_input_embeds = [x.to(self.device) for x in cur_new_input_embeds]
274
+
275
+ cur_new_input_embeds = torch.cat(cur_new_input_embeds)
276
+ cur_new_labels = torch.cat(cur_new_labels)
277
+
278
+ new_input_embeds.append(cur_new_input_embeds)
279
+ new_labels.append(cur_new_labels)
280
+
281
+ # Truncate sequences to max length as image embeddings can make the sequence longer
282
+ tokenizer_model_max_length = getattr(self.config, 'tokenizer_model_max_length', None)
283
+ if tokenizer_model_max_length is not None:
284
+ new_input_embeds = [x[:tokenizer_model_max_length] for x in new_input_embeds]
285
+ new_labels = [x[:tokenizer_model_max_length] for x in new_labels]
286
+
287
+ # Combine them
288
+ max_len = max(x.shape[0] for x in new_input_embeds)
289
+ batch_size = len(new_input_embeds)
290
+
291
+ new_input_embeds_padded = []
292
+ new_labels_padded = torch.full((batch_size, max_len), IGNORE_INDEX, dtype=new_labels[0].dtype, device=new_labels[0].device)
293
+ attention_mask = torch.zeros((batch_size, max_len), dtype=attention_mask.dtype, device=attention_mask.device)
294
+ position_ids = torch.zeros((batch_size, max_len), dtype=position_ids.dtype, device=position_ids.device)
295
+
296
+ for i, (cur_new_embed, cur_new_labels) in enumerate(zip(new_input_embeds, new_labels)):
297
+ cur_len = cur_new_embed.shape[0]
298
+ if getattr(self.config, 'tokenizer_padding_side', 'right') == "left":
299
+ new_input_embeds_padded.append(torch.cat((
300
+ torch.zeros((max_len - cur_len, cur_new_embed.shape[1]), dtype=cur_new_embed.dtype, device=cur_new_embed.device),
301
+ cur_new_embed
302
+ ), dim=0))
303
+ if cur_len > 0:
304
+ new_labels_padded[i, -cur_len:] = cur_new_labels
305
+ attention_mask[i, -cur_len:] = True
306
+ position_ids[i, -cur_len:] = torch.arange(0, cur_len, dtype=position_ids.dtype, device=position_ids.device)
307
+ else:
308
+ new_input_embeds_padded.append(torch.cat((
309
+ cur_new_embed,
310
+ torch.zeros((max_len - cur_len, cur_new_embed.shape[1]), dtype=cur_new_embed.dtype, device=cur_new_embed.device)
311
+ ), dim=0))
312
+ if cur_len > 0:
313
+ new_labels_padded[i, :cur_len] = cur_new_labels
314
+ attention_mask[i, :cur_len] = True
315
+ position_ids[i, :cur_len] = torch.arange(0, cur_len, dtype=position_ids.dtype, device=position_ids.device)
316
+
317
+ new_input_embeds = torch.stack(new_input_embeds_padded, dim=0)
318
+
319
+ if _labels is None:
320
+ new_labels = None
321
+ else:
322
+ new_labels = new_labels_padded
323
+
324
+ if _attention_mask is None:
325
+ attention_mask = None
326
+ else:
327
+ attention_mask = attention_mask.to(dtype=_attention_mask.dtype)
328
+
329
+ if _position_ids is None:
330
+ position_ids = None
331
+
332
+ return None, position_ids, attention_mask, past_key_values, new_input_embeds, new_labels
333
+
334
+ def initialize_vision_tokenizer(self, model_args, tokenizer):
335
+ if model_args.mm_use_im_patch_token:
336
+ tokenizer.add_tokens([DEFAULT_IMAGE_PATCH_TOKEN], special_tokens=True)
337
+ self.resize_token_embeddings(len(tokenizer))
338
+
339
+ if model_args.mm_use_im_start_end:
340
+ num_new_tokens = tokenizer.add_tokens([DEFAULT_IM_START_TOKEN, DEFAULT_IM_END_TOKEN], special_tokens=True)
341
+ self.resize_token_embeddings(len(tokenizer))
342
+
343
+ if num_new_tokens > 0:
344
+ input_embeddings = self.get_input_embeddings().weight.data
345
+ output_embeddings = self.get_output_embeddings().weight.data
346
+
347
+ input_embeddings_avg = input_embeddings[:-num_new_tokens].mean(
348
+ dim=0, keepdim=True)
349
+ output_embeddings_avg = output_embeddings[:-num_new_tokens].mean(
350
+ dim=0, keepdim=True)
351
+
352
+ input_embeddings[-num_new_tokens:] = input_embeddings_avg
353
+ output_embeddings[-num_new_tokens:] = output_embeddings_avg
354
+
355
+ if model_args.tune_mm_mlp_adapter:
356
+ for p in self.get_input_embeddings().parameters():
357
+ p.requires_grad = True
358
+ for p in self.get_output_embeddings().parameters():
359
+ p.requires_grad = False
360
+
361
+ if model_args.pretrain_mm_mlp_adapter:
362
+ mm_projector_weights = torch.load(model_args.pretrain_mm_mlp_adapter, map_location='cpu')
363
+ embed_tokens_weight = mm_projector_weights['model.embed_tokens.weight']
364
+ assert num_new_tokens == 2
365
+ if input_embeddings.shape == embed_tokens_weight.shape:
366
+ input_embeddings[-num_new_tokens:] = embed_tokens_weight[-num_new_tokens:]
367
+ elif embed_tokens_weight.shape[0] == num_new_tokens:
368
+ input_embeddings[-num_new_tokens:] = embed_tokens_weight
369
+ else:
370
+ raise ValueError(f"Unexpected embed_tokens_weight shape. Pretrained: {embed_tokens_weight.shape}. Current: {input_embeddings.shape}. Numer of new tokens: {num_new_tokens}.")
371
+ elif model_args.mm_use_im_patch_token:
372
+ if model_args.tune_mm_mlp_adapter:
373
+ for p in self.get_input_embeddings().parameters():
374
+ p.requires_grad = False
375
+ for p in self.get_output_embeddings().parameters():
376
+ p.requires_grad = False
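The unpad_image helper above removes the letterbox padding that the square resize introduces before anyres tiles are stitched back together. A tiny worked example with made-up sizes:

    import torch
    from llava.model.llava_arch import unpad_image

    feat = torch.zeros(8, 24, 24)   # C x H x W feature grid from a square crop
    orig = (640, 480)               # original image (width, height), wider than tall
    out = unpad_image(feat, orig)
    # scale = 24 / 640, new_height = int(480 * 24 / 640) = 18, padding = (24 - 18) // 2 = 3
    print(out.shape)                # torch.Size([8, 18, 24])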
llava/model/make_delta.py ADDED
@@ -0,0 +1,52 @@
+ """
+ Usage:
+ python3 -m llava.model.make_delta --base ~/model_weights/llama-7b --target ~/model_weights/llava-7b --delta ~/model_weights/llava-7b-delta --hub-repo-id liuhaotian/llava-7b-delta
+ """
+ import argparse
+
+ import torch
+ from tqdm import tqdm
+ from transformers import AutoTokenizer, AutoModelForCausalLM
+ from llava.model.utils import auto_upgrade
+
+
+ def make_delta(base_model_path, target_model_path, delta_path, hub_repo_id):
+     print("Loading base model")
+     base = AutoModelForCausalLM.from_pretrained(
+         base_model_path, torch_dtype=torch.float16, low_cpu_mem_usage=True)
+
+     print("Loading target model")
+     auto_upgrade(target_model_path)
+     target = AutoModelForCausalLM.from_pretrained(target_model_path, torch_dtype=torch.float16, low_cpu_mem_usage=True)
+
+     print("Calculating delta")
+     for name, param in tqdm(target.state_dict().items(), desc="Calculating delta"):
+         if name not in base.state_dict():
+             assert name in ['model.mm_projector.weight', 'model.mm_projector.bias'], f'{name} not in base model'
+             continue
+         if param.data.shape == base.state_dict()[name].shape:
+             param.data -= base.state_dict()[name]
+         else:
+             assert name in ['model.embed_tokens.weight', 'lm_head.weight'], f'{name} dimension mismatch: {param.data.shape} vs {base.state_dict()[name].shape}'
+             bparam = base.state_dict()[name]
+             param.data[:bparam.shape[0], :bparam.shape[1]] -= bparam
+
+     print("Saving delta")
+     if hub_repo_id:
+         kwargs = {"push_to_hub": True, "repo_id": hub_repo_id}
+     else:
+         kwargs = {}
+     target.save_pretrained(delta_path, **kwargs)
+     target_tokenizer = AutoTokenizer.from_pretrained(target_model_path)
+     target_tokenizer.save_pretrained(delta_path, **kwargs)
+
+
+ if __name__ == "__main__":
+     parser = argparse.ArgumentParser()
+     parser.add_argument("--base-model-path", type=str, required=True)
+     parser.add_argument("--target-model-path", type=str, required=True)
+     parser.add_argument("--delta-path", type=str, required=True)
+     parser.add_argument("--hub-repo-id", type=str, default=None)
+     args = parser.parse_args()
+
+     make_delta(args.base_model_path, args.target_model_path, args.delta_path, args.hub_repo_id)
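make_delta subtracts the base weights from the fine-tuned target so only the difference is distributed; tensors that exist only in the target (the mm_projector) are kept whole, and embed_tokens/lm_head are only subtracted over the rows present in the base model. The repo ships llava/model/apply_delta.py for the reverse step; the snippet below is only an illustration of that inverse arithmetic, not that script:

    import torch

    def apply_delta_to_state_dict(base_sd, delta_sd):
        # Illustrative inverse of make_delta: add base weights back onto the delta.
        merged = {}
        for name, dparam in delta_sd.items():
            if name not in base_sd:
                merged[name] = dparam                          # e.g. model.mm_projector.* exists only in the delta
            elif dparam.shape == base_sd[name].shape:
                merged[name] = dparam + base_sd[name]
            else:                                              # embed_tokens / lm_head grew with added tokens
                merged[name] = dparam.clone()
                bparam = base_sd[name]
                merged[name][:bparam.shape[0], :bparam.shape[1]] += bparam
        return merged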
llava/model/multimodal_encoder/__pycache__/builder.cpython-312.pyc ADDED
Binary file (1.29 kB).
 
llava/model/multimodal_encoder/__pycache__/clip_encoder.cpython-312.pyc ADDED
Binary file (11 kB).
 
llava/model/multimodal_encoder/__pycache__/mobileclip_encoder.cpython-312.pyc ADDED
Binary file (6.68 kB).
 
llava/model/multimodal_encoder/builder.py ADDED
@@ -0,0 +1,19 @@
+ import os
+ from .clip_encoder import CLIPVisionTower, CLIPVisionTowerS2
+ from .mobileclip_encoder import MobileCLIPVisionTower
+
+
+ def build_vision_tower(vision_tower_cfg, **kwargs):
+     vision_tower = getattr(vision_tower_cfg, 'mm_vision_tower', getattr(vision_tower_cfg, 'vision_tower', None))
+     is_absolute_path_exists = os.path.exists(vision_tower)
+     use_s2 = getattr(vision_tower_cfg, 's2', False)
+
+     if is_absolute_path_exists or vision_tower.startswith("openai") or vision_tower.startswith("laion") or "ShareGPT4V" in vision_tower:
+         if use_s2:
+             return CLIPVisionTowerS2(vision_tower, args=vision_tower_cfg, **kwargs)
+         else:
+             return CLIPVisionTower(vision_tower, args=vision_tower_cfg, **kwargs)
+     elif "mobileclip" in vision_tower.lower():
+         return MobileCLIPVisionTower(vision_tower, args=vision_tower_cfg, **kwargs)
+
+     raise ValueError(f'Unknown vision tower: {vision_tower}')
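build_vision_tower dispatches purely on the vision tower string, so any object exposing the expected attributes can drive it. A dispatch sketch; the checkpoint name is a commonly used CLIP checkpoint and instantiating the tower still fetches its config from the Hub, so treat this as illustrative only:

    from types import SimpleNamespace
    from llava.model.multimodal_encoder.builder import build_vision_tower

    cfg = SimpleNamespace(
        mm_vision_tower="openai/clip-vit-large-patch14-336",  # "openai..." prefix -> CLIPVisionTower branch
        mm_vision_select_layer=-2,
        mm_vision_select_feature="patch",
    )
    tower = build_vision_tower(cfg, delay_load=True)  # delay_load defers loading the vision weights
    print(type(tower).__name__)                       # CLIPVisionTower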
llava/model/multimodal_encoder/clip_encoder.py ADDED
@@ -0,0 +1,166 @@
1
+ import torch
2
+ import torch.nn as nn
3
+
4
+ from transformers import CLIPVisionModel, CLIPImageProcessor, CLIPVisionConfig
5
+
6
+
7
+ class CLIPVisionTower(nn.Module):
8
+ def __init__(self, vision_tower, args, delay_load=False):
9
+ super().__init__()
10
+
11
+ self.is_loaded = False
12
+
13
+ self.vision_tower_name = vision_tower
14
+ self.select_layer = args.mm_vision_select_layer
15
+ self.select_feature = getattr(args, 'mm_vision_select_feature', 'patch')
16
+ self.tune_vision_tower = getattr(args, 'unfreeze_mm_vision_tower', False)
17
+ self.input_image_size = getattr(args, 'input_image_size', None)
18
+
19
+ if self.tune_vision_tower:
20
+ print("CLIP Vision tower is set to tunable")
21
+
22
+ if not delay_load:
23
+ self.load_model()
24
+ elif getattr(args, 'unfreeze_mm_vision_tower', False):
25
+ self.load_model()
26
+ else:
27
+ self.cfg_only = CLIPVisionConfig.from_pretrained(self.vision_tower_name)
28
+ if self.input_image_size is not None:
29
+ self.cfg_only.image_size = self.input_image_size
30
+
31
+ def load_model(self, device_map=None):
32
+ if self.is_loaded:
33
+ print('{} is already loaded, `load_model` called again, skipping.'.format(self.vision_tower_name))
34
+ return
35
+
36
+ self.image_processor = CLIPImageProcessor.from_pretrained(self.vision_tower_name)
37
+ self.vision_tower = CLIPVisionModel.from_pretrained(self.vision_tower_name, device_map=device_map)
38
+ if not self.tune_vision_tower:
39
+ self.vision_tower.requires_grad_(False)
40
+
41
+ if self.input_image_size is not None:
42
+ print("Using input image size: {}".format(self.input_image_size))
43
+ self.image_processor.size['shortest_edge'] = self.input_image_size
44
+ self.image_processor.crop_size['height'] = self.image_processor.crop_size['width'] = self.input_image_size
45
+
46
+ self.is_loaded = True
47
+
48
+ def feature_select(self, image_forward_outs):
49
+ image_features = image_forward_outs.hidden_states[self.select_layer]
50
+ if self.select_feature == 'patch':
51
+ image_features = image_features[:, 1:]
52
+ elif self.select_feature == 'cls_patch':
53
+ image_features = image_features
54
+ else:
55
+ raise ValueError(f'Unexpected select feature: {self.select_feature}')
56
+ return image_features
57
+
58
+ def forward(self, images):
59
+ if self.tune_vision_tower:
60
+ return self.forward_images(images)
61
+ else:
62
+ with torch.no_grad():
63
+ return self.forward_images(images)
64
+
65
+ def forward_images(self, images):
66
+ if type(images) is list:
67
+ image_features = []
68
+ for image in images:
69
+ image_forward_out = self.vision_tower(image.to(device=self.device, dtype=self.dtype).unsqueeze(0), output_hidden_states=True)
70
+ image_feature = self.feature_select(image_forward_out).to(image.dtype)
71
+ image_features.append(image_feature)
72
+ else:
73
+ image_forward_outs = self.vision_tower(images.to(device=self.device, dtype=self.dtype), output_hidden_states=True)
74
+ image_features = self.feature_select(image_forward_outs).to(images.dtype)
75
+
76
+ return image_features
77
+
78
+ @property
79
+ def dummy_feature(self):
80
+ return torch.zeros(1, self.hidden_size, device=self.device, dtype=self.dtype)
81
+
82
+ @property
83
+ def dtype(self):
84
+ return self.vision_tower.dtype
85
+
86
+ @property
87
+ def device(self):
88
+ return self.vision_tower.device
89
+
90
+ @property
91
+ def config(self):
92
+ if self.is_loaded:
93
+ return self.vision_tower.config
94
+ else:
95
+ return self.cfg_only
96
+
97
+ @property
98
+ def hidden_size(self):
99
+ return self.config.hidden_size
100
+
101
+ @property
102
+ def num_patches_per_side(self):
103
+ return self.config.image_size // self.config.patch_size
104
+
105
+ @property
106
+ def num_patches(self):
107
+ return (self.config.image_size // self.config.patch_size) ** 2
108
+
109
+
110
+
111
+ class CLIPVisionTowerS2(CLIPVisionTower):
112
+ def __init__(self, vision_tower, args, delay_load=False):
113
+ self.s2_scales = getattr(args, 's2_scales', '336,672,1008')
114
+ self.s2_scales = list(map(int, self.s2_scales.split(',')))
115
+ self.s2_scales.sort()
116
+ self.s2_split_size = self.s2_scales[0]
117
+ self.s2_image_size = self.s2_scales[-1]
118
+
119
+ super().__init__(vision_tower, args, delay_load)
120
+
121
+ try:
122
+ from s2wrapper import forward as multiscale_forward
123
+ except ImportError:
124
+ raise ImportError('Package s2wrapper not found! Please install by running: \npip install git+https://github.com/bfshi/scaling_on_scales.git')
125
+ self.multiscale_forward = multiscale_forward
126
+
127
+ # change resize/crop size in preprocessing to the largest image size in s2_scale
128
+ if not delay_load or getattr(args, 'unfreeze_mm_vision_tower', False):
129
+ self.image_processor.size['shortest_edge'] = self.s2_image_size
130
+ self.image_processor.crop_size['height'] = self.image_processor.crop_size['width'] = self.s2_image_size
131
+
132
+ def load_model(self, device_map=None):
133
+ if self.is_loaded:
134
+ print('{} is already loaded, `load_model` called again, skipping.'.format(self.vision_tower_name))
135
+ return
136
+
137
+ self.image_processor = CLIPImageProcessor.from_pretrained(self.vision_tower_name)
138
+ self.vision_tower = CLIPVisionModel.from_pretrained(self.vision_tower_name, device_map=device_map)
139
+ self.vision_tower.requires_grad_(False)
140
+
141
+ self.image_processor.size['shortest_edge'] = self.s2_image_size
142
+ self.image_processor.crop_size['height'] = self.image_processor.crop_size['width'] = self.s2_image_size
143
+
144
+ self.is_loaded = True
145
+
146
+ @torch.no_grad()
147
+ def forward_feature(self, images):
148
+ image_forward_outs = self.vision_tower(images.to(device=self.device, dtype=self.dtype), output_hidden_states=True)
149
+ image_features = self.feature_select(image_forward_outs).to(images.dtype)
150
+ return image_features
151
+
152
+ @torch.no_grad()
153
+ def forward(self, images):
154
+ if type(images) is list:
155
+ image_features = []
156
+ for image in images:
157
+ image_feature = self.multiscale_forward(self.forward_feature, image.unsqueeze(0), img_sizes=self.s2_scales, max_split_size=self.s2_split_size)
158
+ image_features.append(image_feature)
159
+ else:
160
+ image_features = self.multiscale_forward(self.forward_feature, images, img_sizes=self.s2_scales, max_split_size=self.s2_split_size)
161
+
162
+ return image_features
163
+
164
+ @property
165
+ def hidden_size(self):
166
+ return self.config.hidden_size * len(self.s2_scales)
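For orientation, the patch arithmetic behind num_patches / num_patches_per_side with the common CLIP ViT-L/14-336 settings (image_size=336, patch_size=14) works out as follows; CLIPVisionTowerS2 additionally multiplies hidden_size by the number of scales:

    image_size, patch_size = 336, 14
    num_patches_per_side = image_size // patch_size   # 24
    num_patches = num_patches_per_side ** 2           # 576 visual tokens per image
    print(num_patches_per_side, num_patches)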
llava/model/multimodal_encoder/mobileclip/__init__.py ADDED
@@ -0,0 +1,87 @@
1
+ #
2
+ # For licensing see accompanying LICENSE file.
3
+ # Copyright (C) 2025 Apple Inc. All Rights Reserved.
4
+ #
5
+ import os
6
+ import json
7
+ from typing import Any
8
+
9
+ import torch.nn as nn
10
+ from timm.models import create_model
11
+
12
+ from .mci import GlobalPool2D
13
+
14
+
15
+ def load_model_config(
16
+ model_name: str,
17
+ ) -> Any:
18
+ # Strip suffixes to model name
19
+ model_name = "_".join(model_name.split("_")[0:2])
20
+
21
+ # Config files
22
+ root_dir = os.path.dirname(os.path.abspath(__file__))
23
+ configs_dir = os.path.join(root_dir, "configs")
24
+ model_cfg_file = os.path.join(configs_dir, model_name + ".json")
25
+
26
+ # Get config from yaml file
27
+ if not os.path.exists(model_cfg_file):
28
+ raise ValueError(f"Unsupported model name: {model_name}")
29
+ model_cfg = json.load(open(model_cfg_file, "r"))
30
+
31
+ return model_cfg
32
+
33
+
34
+ class MCi(nn.Module):
35
+ """
36
+ This class implements `MCi Models <https://arxiv.org/pdf/2311.17049.pdf>`_
37
+ """
38
+
39
+ def __init__(self, model_name: str, *args, **kwargs) -> None:
40
+ super().__init__()
41
+ self.projection_dim = None
42
+ if "projection_dim" in kwargs:
43
+ self.projection_dim = kwargs.get("projection_dim")
44
+
45
+ # Create model
46
+ self.model = create_model(model_name, projection_dim=self.projection_dim)
47
+
48
+ # Build out projection head.
49
+ if self.projection_dim is not None:
50
+ if hasattr(self.model, "head"):
51
+ self.model.head = MCi._update_image_classifier(
52
+ image_classifier=self.model.head, projection_dim=self.projection_dim
53
+ )
54
+
55
+ def forward(self, x: Any, *args, **kwargs) -> Any:
56
+ """A forward function of the model."""
57
+ x = self.model(x, *args, **kwargs)
58
+ return x
59
+
60
+ @staticmethod
61
+ def _get_in_feature_dimension(image_classifier: nn.Module) -> int:
62
+ """Return the input feature dimension to the image classification head."""
63
+ in_features = None
64
+ if isinstance(image_classifier, nn.Sequential):
65
+ # Classifier that uses nn.Sequential usually has global pooling and
66
+ # multiple linear layers. Find the first linear layer and get its
67
+ # in_features
68
+ for layer in image_classifier:
69
+ if isinstance(layer, nn.Linear):
70
+ in_features = layer.in_features
71
+ break
72
+ elif isinstance(image_classifier, nn.Linear):
73
+ in_features = image_classifier.in_features
74
+
75
+ if in_features is None:
76
+ raise NotImplementedError(
77
+ f"Cannot get input feature dimension of {image_classifier}."
78
+ )
79
+ return in_features
80
+
81
+ @staticmethod
82
+ def _update_image_classifier(
83
+ image_classifier: nn.Module, projection_dim: int, *args, **kwargs
84
+ ) -> nn.Module:
85
+ in_features = MCi._get_in_feature_dimension(image_classifier)
86
+ new_img_classifier = GlobalPool2D(in_dim=in_features, out_dim=projection_dim)
87
+ return new_img_classifier
llava/model/multimodal_encoder/mobileclip/__pycache__/__init__.cpython-312.pyc ADDED
Binary file (4.12 kB).
 
llava/model/multimodal_encoder/mobileclip/__pycache__/mci.cpython-312.pyc ADDED
Binary file (58.8 kB).
 
llava/model/multimodal_encoder/mobileclip/configs/mobileclip_l.json ADDED
@@ -0,0 +1,20 @@
+ {
+     "embed_dim": 768,
+     "image_cfg": {
+         "image_size": 1024,
+         "model_name": "fastvithd",
+         "embed_dim": 3072,
+         "patch_size": 64
+     },
+     "text_cfg": {
+         "context_length": 77,
+         "vocab_size": 49408,
+         "dim": 768,
+         "ffn_multiplier_per_layer": 4.0,
+         "n_heads_per_layer": 12,
+         "n_transformer_layers": 12,
+         "norm_layer": "layer_norm_fp32",
+         "causal_masking": false,
+         "model_name": "base"
+     }
+ }
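The image_cfg above fixes the token budget of the FastViTHD tower: a 1024-px input with an effective patch size of 64 yields a 16 x 16 grid, i.e. 256 visual tokens of dimension 3072 handed to the projector. A small sketch reading the config through the loader defined in mobileclip/__init__.py:

    from llava.model.multimodal_encoder.mobileclip import load_model_config

    cfg = load_model_config("mobileclip_l")  # resolves configs/mobileclip_l.json
    grid = cfg["image_cfg"]["image_size"] // cfg["image_cfg"]["patch_size"]
    print(grid, grid * grid, cfg["image_cfg"]["embed_dim"])  # 16 256 3072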
llava/model/multimodal_encoder/mobileclip/mci.py ADDED
@@ -0,0 +1,1479 @@
1
+ #
2
+ # For licensing see accompanying LICENSE file.
3
+ # Copyright (C) 2025 Apple Inc. All Rights Reserved.
4
+ #
5
+ import copy
6
+ from functools import partial
7
+ from typing import List, Tuple, Optional, Union, Dict
8
+
9
+ import torch
10
+ import torch.nn as nn
11
+ from torch import Tensor
12
+ import torch.nn.functional as F
13
+ from torch.nn.init import normal_
14
+
15
+ from timm.models import register_model
16
+ from timm.data import IMAGENET_DEFAULT_MEAN, IMAGENET_DEFAULT_STD
17
+ from timm.layers import DropPath, SqueezeExcite
18
+
19
+
20
+ def _cfg(url="", **kwargs):
21
+ return {
22
+ "url": url,
23
+ "num_classes": 1000,
24
+ "input_size": (3, 256, 256),
25
+ "pool_size": None,
26
+ "crop_pct": 0.95,
27
+ "interpolation": "bicubic",
28
+ "mean": IMAGENET_DEFAULT_MEAN,
29
+ "std": IMAGENET_DEFAULT_STD,
30
+ "classifier": "head",
31
+ **kwargs,
32
+ }
33
+
34
+
35
+ default_cfgs = {
36
+ "fastvit_t": _cfg(crop_pct=0.9),
37
+ "fastvit_s": _cfg(crop_pct=0.9),
38
+ "fastvit_m": _cfg(crop_pct=0.95),
39
+ }
40
+
41
+
42
+ class SEBlock(nn.Module):
43
+ """Squeeze and Excite module.
44
+
45
+ Pytorch implementation of `Squeeze-and-Excitation Networks` -
46
+ https://arxiv.org/pdf/1709.01507.pdf
47
+ """
48
+
49
+ def __init__(self, in_channels: int, rd_ratio: float = 0.0625) -> None:
50
+ """Construct a Squeeze and Excite Module.
51
+
52
+ Args:
53
+ in_channels: Number of input channels.
54
+ rd_ratio: Input channel reduction ratio.
55
+ """
56
+ super(SEBlock, self).__init__()
57
+ self.reduce = nn.Conv2d(
58
+ in_channels=in_channels,
59
+ out_channels=int(in_channels * rd_ratio),
60
+ kernel_size=1,
61
+ stride=1,
62
+ bias=True,
63
+ )
64
+ self.expand = nn.Conv2d(
65
+ in_channels=int(in_channels * rd_ratio),
66
+ out_channels=in_channels,
67
+ kernel_size=1,
68
+ stride=1,
69
+ bias=True,
70
+ )
71
+
72
+ def forward(self, inputs: torch.Tensor) -> torch.Tensor:
73
+ """Apply forward pass."""
74
+ b, c, h, w = inputs.size()
75
+ # print('################ h, w = ', h, w)
76
+ x = F.avg_pool2d(inputs, kernel_size=[16, 16])
77
+ x = self.reduce(x)
78
+ x = F.relu(x)
79
+ x = self.expand(x)
80
+ x = torch.sigmoid(x)
81
+ x = x.view(-1, c, 1, 1)
82
+ return inputs * x
83
+
84
+
85
+ class MobileOneBlock(nn.Module):
86
+ """MobileOne building block.
87
+
88
+ This block has a multi-branched architecture at train-time
89
+ and plain-CNN style architecture at inference time
90
+ For more details, please refer to our paper:
91
+ `An Improved One millisecond Mobile Backbone` -
92
+ https://arxiv.org/pdf/2206.04040.pdf
93
+ """
94
+
95
+ def __init__(
96
+ self,
97
+ in_channels: int,
98
+ out_channels: int,
99
+ kernel_size: int,
100
+ stride: int = 1,
101
+ padding: int = 0,
102
+ dilation: int = 1,
103
+ groups: int = 1,
104
+ inference_mode: bool = False,
105
+ use_se: bool = False,
106
+ use_act: bool = True,
107
+ use_scale_branch: bool = True,
108
+ num_conv_branches: int = 1,
109
+ activation: nn.Module = nn.GELU(),
110
+ ) -> None:
111
+ """Construct a MobileOneBlock module.
112
+
113
+ Args:
114
+ in_channels: Number of channels in the input.
115
+ out_channels: Number of channels produced by the block.
116
+ kernel_size: Size of the convolution kernel.
117
+ stride: Stride size.
118
+ padding: Zero-padding size.
119
+ dilation: Kernel dilation factor.
120
+ groups: Group number.
121
+ inference_mode: If True, instantiates model in inference mode.
122
+ use_se: Whether to use SE-ReLU activations.
123
+ use_act: Whether to use activation. Default: ``True``
124
+ use_scale_branch: Whether to use scale branch. Default: ``True``
125
+ num_conv_branches: Number of linear conv branches.
126
+ """
127
+ super(MobileOneBlock, self).__init__()
128
+ self.inference_mode = inference_mode
129
+ self.groups = groups
130
+ self.stride = stride
131
+ self.padding = padding
132
+ self.dilation = dilation
133
+ self.kernel_size = kernel_size
134
+ self.in_channels = in_channels
135
+ self.out_channels = out_channels
136
+ self.num_conv_branches = num_conv_branches
137
+
138
+ # Check if SE-ReLU is requested
139
+ if use_se:
140
+ self.se = SEBlock(out_channels)
141
+ else:
142
+ self.se = nn.Identity()
143
+
144
+ if use_act:
145
+ self.activation = activation
146
+ else:
147
+ self.activation = nn.Identity()
148
+
149
+ if inference_mode:
150
+ self.reparam_conv = nn.Conv2d(
151
+ in_channels=in_channels,
152
+ out_channels=out_channels,
153
+ kernel_size=kernel_size,
154
+ stride=stride,
155
+ padding=padding,
156
+ dilation=dilation,
157
+ groups=groups,
158
+ bias=True,
159
+ )
160
+ else:
161
+ # Re-parameterizable skip connection
162
+ # Fallback, sometimes batchnorm tensors
163
+ # do not get instantiated correctly on some processes
164
+ # when using deepspeed + accelerate
165
+ norm_layer = nn.BatchNorm2d(num_features=in_channels)
166
+ if norm_layer.weight.shape[0] == 0:
167
+ norm_layer.weight = nn.Parameter(torch.zeros(in_channels))
168
+ if norm_layer.bias.shape[0] == 0:
169
+ norm_layer.bias = nn.Parameter(torch.zeros(in_channels))
170
+
171
+ self.rbr_skip = (
172
+ norm_layer
173
+ if out_channels == in_channels and stride == 1
174
+ else None
175
+ )
176
+
177
+ # Re-parameterizable conv branches
178
+ if num_conv_branches > 0:
179
+ rbr_conv = list()
180
+ for _ in range(self.num_conv_branches):
181
+ rbr_conv.append(
182
+ self._conv_bn(kernel_size=kernel_size, padding=padding)
183
+ )
184
+ self.rbr_conv = nn.ModuleList(rbr_conv)
185
+ else:
186
+ self.rbr_conv = None
187
+
188
+ # Re-parameterizable scale branch
189
+ self.rbr_scale = None
190
+ if not isinstance(kernel_size, int):
191
+ kernel_size = kernel_size[0]
192
+ if (kernel_size > 1) and use_scale_branch:
193
+ self.rbr_scale = self._conv_bn(kernel_size=1, padding=0)
194
+
195
+ def forward(self, x: torch.Tensor) -> torch.Tensor:
196
+ """Apply forward pass."""
197
+ # Inference mode forward pass.
198
+ if self.inference_mode:
199
+ return self.activation(self.se(self.reparam_conv(x)))
200
+
201
+ # Multi-branched train-time forward pass.
202
+ # Skip branch output
203
+ identity_out = 0
204
+ if self.rbr_skip is not None:
205
+ identity_out = self.rbr_skip(x)
206
+
207
+ # Scale branch output
208
+ scale_out = 0
209
+ if self.rbr_scale is not None:
210
+ scale_out = self.rbr_scale(x)
211
+
212
+ # Other branches
213
+ out = scale_out + identity_out
214
+ if self.rbr_conv is not None:
215
+ for ix in range(self.num_conv_branches):
216
+ out += self.rbr_conv[ix](x)
217
+
218
+ return self.activation(self.se(out))
219
+
220
+ def reparameterize(self):
221
+ """Following works like `RepVGG: Making VGG-style ConvNets Great Again` -
222
+ https://arxiv.org/pdf/2101.03697.pdf. We re-parameterize multi-branched
223
+ architecture used at training time to obtain a plain CNN-like structure
224
+ for inference.
225
+ """
226
+ if self.inference_mode:
227
+ return
228
+ kernel, bias = self._get_kernel_bias()
229
+ self.reparam_conv = nn.Conv2d(
230
+ in_channels=self.in_channels,
231
+ out_channels=self.out_channels,
232
+ kernel_size=self.kernel_size,
233
+ stride=self.stride,
234
+ padding=self.padding,
235
+ dilation=self.dilation,
236
+ groups=self.groups,
237
+ bias=True,
238
+ )
239
+ self.reparam_conv.weight.data = kernel
240
+ self.reparam_conv.bias.data = bias
241
+
242
+ # Delete un-used branches
243
+ self.__delattr__("rbr_conv")
244
+ self.__delattr__("rbr_scale")
245
+ if hasattr(self, "rbr_skip"):
246
+ self.__delattr__("rbr_skip")
247
+
248
+ self.inference_mode = True
249
+
250
+ def _get_kernel_bias(self) -> Tuple[torch.Tensor, torch.Tensor]:
251
+ """Method to obtain re-parameterized kernel and bias.
252
+ Reference: https://github.com/DingXiaoH/RepVGG/blob/main/repvgg.py#L83
253
+
254
+ Returns:
255
+ Tuple of (kernel, bias) after fusing branches.
256
+ """
257
+ # get weights and bias of scale branch
258
+ kernel_scale = 0
259
+ bias_scale = 0
260
+ if self.rbr_scale is not None:
261
+ kernel_scale, bias_scale = self._fuse_bn_tensor(self.rbr_scale)
262
+ # Pad scale branch kernel to match conv branch kernel size.
263
+ pad = self.kernel_size // 2
264
+ kernel_scale = torch.nn.functional.pad(kernel_scale, [pad, pad, pad, pad])
265
+
266
+ # get weights and bias of skip branch
267
+ kernel_identity = 0
268
+ bias_identity = 0
269
+ if self.rbr_skip is not None:
270
+ kernel_identity, bias_identity = self._fuse_bn_tensor(self.rbr_skip)
271
+
272
+ # get weights and bias of conv branches
273
+ kernel_conv = 0
274
+ bias_conv = 0
275
+ if self.rbr_conv is not None:
276
+ for ix in range(self.num_conv_branches):
277
+ _kernel, _bias = self._fuse_bn_tensor(self.rbr_conv[ix])
278
+ kernel_conv += _kernel
279
+ bias_conv += _bias
280
+
281
+ kernel_final = kernel_conv + kernel_scale + kernel_identity
282
+ bias_final = bias_conv + bias_scale + bias_identity
283
+ return kernel_final, bias_final
284
+
285
+ def _fuse_bn_tensor(
286
+ self, branch: Union[nn.Sequential, nn.BatchNorm2d]
287
+ ) -> Tuple[torch.Tensor, torch.Tensor]:
288
+ """Method to fuse batchnorm layer with preceeding conv layer.
289
+ Reference: https://github.com/DingXiaoH/RepVGG/blob/main/repvgg.py#L95
290
+
291
+ Args:
292
+ branch: Sequence of ops to be fused.
293
+
294
+ Returns:
295
+ Tuple of (kernel, bias) after fusing batchnorm.
296
+ """
297
+ if isinstance(branch, nn.Sequential):
298
+ kernel = branch.conv.weight
299
+ running_mean = branch.bn.running_mean
300
+ running_var = branch.bn.running_var
301
+ gamma = branch.bn.weight
302
+ beta = branch.bn.bias
303
+ eps = branch.bn.eps
304
+ else:
305
+ assert isinstance(branch, nn.BatchNorm2d)
306
+ if not hasattr(self, "id_tensor"):
307
+ input_dim = self.in_channels // self.groups
308
+
309
+ kernel_size = self.kernel_size
310
+ if isinstance(self.kernel_size, int):
311
+ kernel_size = (self.kernel_size, self.kernel_size)
312
+
313
+ kernel_value = torch.zeros(
314
+ (self.in_channels, input_dim, kernel_size[0], kernel_size[1]),
315
+ dtype=branch.weight.dtype,
316
+ device=branch.weight.device,
317
+ )
318
+ for i in range(self.in_channels):
319
+ kernel_value[
320
+ i, i % input_dim, kernel_size[0] // 2, kernel_size[1] // 2
321
+ ] = 1
322
+ self.id_tensor = kernel_value
323
+ kernel = self.id_tensor
324
+ running_mean = branch.running_mean
325
+ running_var = branch.running_var
326
+ gamma = branch.weight
327
+ beta = branch.bias
328
+ eps = branch.eps
329
+ std = (running_var + eps).sqrt()
330
+ t = (gamma / std).reshape(-1, 1, 1, 1)
331
+ return kernel * t, beta - running_mean * gamma / std
332
+
333
+ def _conv_bn(self, kernel_size: int, padding: int) -> nn.Sequential:
334
+ """Helper method to construct conv-batchnorm layers.
335
+
336
+ Args:
337
+ kernel_size: Size of the convolution kernel.
338
+ padding: Zero-padding size.
339
+
340
+ Returns:
341
+ Conv-BN module.
342
+ """
343
+ # Fallback, sometimes batchnorm tensors
344
+ # do not get instantiated correctly on some processes
345
+ # when using deepspeed + accelerate
346
+ norm_layer = nn.BatchNorm2d(num_features=self.out_channels)
347
+ if norm_layer.weight.shape[0] == 0:
348
+ norm_layer.weight = nn.Parameter(torch.zeros(self.out_channels))
349
+ if norm_layer.bias.shape[0] == 0:
350
+ norm_layer.bias = nn.Parameter(torch.zeros(self.out_channels))
351
+
352
+ mod_list = nn.Sequential()
353
+ mod_list.add_module(
354
+ "conv",
355
+ nn.Conv2d(
356
+ in_channels=self.in_channels,
357
+ out_channels=self.out_channels,
358
+ kernel_size=kernel_size,
359
+ stride=self.stride,
360
+ padding=padding,
361
+ groups=self.groups,
362
+ bias=False,
363
+ ),
364
+ )
365
+ mod_list.add_module("bn", norm_layer)
366
+ return mod_list
367
+
368
+
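The three branches above (skip, scale, and conv) fold into the single `reparam_conv` at export time. An illustrative equivalence check, not part of the file itself (channel and input sizes are arbitrary; the block must be in eval() mode so that batch-norm statistics stay fixed):

import torch

block = MobileOneBlock(in_channels=8, out_channels=8, kernel_size=3,
                       stride=1, padding=1, groups=8, num_conv_branches=1)
block.eval()                   # freeze BN statistics so the fusion is exact

x = torch.randn(1, 8, 32, 32)
y_branched = block(x)          # skip + scale + conv branches
block.reparameterize()         # fold everything into block.reparam_conv
y_fused = block(x)

print(torch.allclose(y_branched, y_fused, atol=1e-5))   # expected: True, up to float error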
369
+ class ReparamLargeKernelConv(nn.Module):
370
+ """Building Block of RepLKNet
371
+
372
+ This class defines overparameterized large kernel conv block
373
+ introduced in `RepLKNet <https://arxiv.org/abs/2203.06717>`_
374
+
375
+ Reference: https://github.com/DingXiaoH/RepLKNet-pytorch
376
+ """
377
+
378
+ def __init__(
379
+ self,
380
+ in_channels: int,
381
+ out_channels: int,
382
+ kernel_size: int,
383
+ stride: int,
384
+ groups: int,
385
+ small_kernel: int,
386
+ inference_mode: bool = False,
387
+ use_se: bool = False,
388
+ activation: nn.Module = nn.GELU(),
389
+ ) -> None:
390
+ """Construct a ReparamLargeKernelConv module.
391
+
392
+ Args:
393
+ in_channels: Number of input channels.
394
+ out_channels: Number of output channels.
395
+ kernel_size: Kernel size of the large kernel conv branch.
396
+ stride: Stride size. Default: 1
397
+ groups: Group number. Default: 1
398
+ small_kernel: Kernel size of small kernel conv branch.
399
+ inference_mode: If True, instantiates model in inference mode. Default: ``False``
400
+ activation: Activation module. Default: ``nn.GELU``
401
+ """
402
+ super(ReparamLargeKernelConv, self).__init__()
403
+
404
+ self.stride = stride
405
+ self.groups = groups
406
+ self.in_channels = in_channels
407
+ self.out_channels = out_channels
408
+ self.activation = activation
409
+
410
+ self.kernel_size = kernel_size
411
+ self.small_kernel = small_kernel
412
+ self.padding = kernel_size // 2
413
+
414
+ # Check if SE is requested
415
+ if use_se:
416
+ self.se = SqueezeExcite(out_channels, rd_ratio=0.25)
417
+ else:
418
+ self.se = nn.Identity()
419
+
420
+ if inference_mode:
421
+ self.lkb_reparam = nn.Conv2d(
422
+ in_channels=in_channels,
423
+ out_channels=out_channels,
424
+ kernel_size=kernel_size,
425
+ stride=stride,
426
+ padding=self.padding,
427
+ dilation=1,
428
+ groups=groups,
429
+ bias=True,
430
+ )
431
+ else:
432
+ self.lkb_origin = self._conv_bn(
433
+ kernel_size=kernel_size, padding=self.padding
434
+ )
435
+ if small_kernel is not None:
436
+ assert (
437
+ small_kernel <= kernel_size
438
+ ), "The kernel size for re-param cannot be larger than the large kernel!"
439
+ self.small_conv = self._conv_bn(
440
+ kernel_size=small_kernel, padding=small_kernel // 2
441
+ )
442
+
443
+ def forward(self, x: torch.Tensor) -> torch.Tensor:
444
+ """Apply forward pass."""
445
+ if hasattr(self, "lkb_reparam"):
446
+ out = self.lkb_reparam(x)
447
+ else:
448
+ out = self.lkb_origin(x)
449
+ if hasattr(self, "small_conv"):
450
+ out += self.small_conv(x)
451
+
452
+ return self.activation(self.se(out))
453
+
454
+ def get_kernel_bias(self) -> Tuple[torch.Tensor, torch.Tensor]:
455
+ """Method to obtain re-parameterized kernel and bias.
456
+ Reference: https://github.com/DingXiaoH/RepLKNet-pytorch
457
+
458
+ Returns:
459
+ Tuple of (kernel, bias) after fusing branches.
460
+ """
461
+ eq_k, eq_b = self._fuse_bn(self.lkb_origin.conv, self.lkb_origin.bn)
462
+ if hasattr(self, "small_conv"):
463
+ small_k, small_b = self._fuse_bn(self.small_conv.conv, self.small_conv.bn)
464
+ eq_b += small_b
465
+ eq_k += nn.functional.pad(
466
+ small_k, [(self.kernel_size - self.small_kernel) // 2] * 4
467
+ )
468
+ return eq_k, eq_b
469
+
470
+ def reparameterize(self) -> None:
471
+ """
472
+ Following works like `RepVGG: Making VGG-style ConvNets Great Again` -
473
+ https://arxiv.org/pdf/2101.03697.pdf. We re-parameterize multi-branched
474
+ architecture used at training time to obtain a plain CNN-like structure
475
+ for inference.
476
+ """
477
+ eq_k, eq_b = self.get_kernel_bias()
478
+ self.lkb_reparam = nn.Conv2d(
479
+ in_channels=self.in_channels,
480
+ out_channels=self.out_channels,
481
+ kernel_size=self.kernel_size,
482
+ stride=self.stride,
483
+ padding=self.padding,
484
+ dilation=self.lkb_origin.conv.dilation,
485
+ groups=self.groups,
486
+ bias=True,
487
+ )
488
+
489
+ self.lkb_reparam.weight.data = eq_k
490
+ self.lkb_reparam.bias.data = eq_b
491
+ self.__delattr__("lkb_origin")
492
+ if hasattr(self, "small_conv"):
493
+ self.__delattr__("small_conv")
494
+
495
+ @staticmethod
496
+ def _fuse_bn(
497
+ conv: torch.Tensor, bn: nn.BatchNorm2d
498
+ ) -> Tuple[torch.Tensor, torch.Tensor]:
499
+ """Method to fuse batchnorm layer with conv layer.
500
+
501
+ Args:
502
+ conv: Convolutional kernel weights.
503
+ bn: Batchnorm 2d layer.
504
+
505
+ Returns:
506
+ Tuple of (kernel, bias) after fusing batchnorm.
507
+ """
508
+ kernel = conv.weight
509
+ running_mean = bn.running_mean
510
+ running_var = bn.running_var
511
+ gamma = bn.weight
512
+ beta = bn.bias
513
+ eps = bn.eps
514
+ std = (running_var + eps).sqrt()
515
+ t = (gamma / std).reshape(-1, 1, 1, 1)
516
+ return kernel * t, beta - running_mean * gamma / std
517
+
518
+ def _conv_bn(self, kernel_size: int, padding: int = 0) -> nn.Sequential:
519
+ """Helper method to construct conv-batchnorm layers.
520
+
521
+ Args:
522
+ kernel_size: Size of the convolution kernel.
523
+ padding: Zero-padding size.
524
+
525
+ Returns:
526
+ A nn.Sequential Conv-BN module.
527
+ """
528
+ # Fallback, sometimes batchnorm tensors
529
+ # do not get instantiated correctly on some processes
530
+ # when using deepspeed + accelerate
531
+ norm_layer = nn.BatchNorm2d(num_features=self.out_channels)
532
+ if norm_layer.weight.shape[0] == 0:
533
+ norm_layer.weight = nn.Parameter(torch.zeros(self.out_channels))
534
+ if norm_layer.bias.shape[0] == 0:
535
+ norm_layer.bias = nn.Parameter(torch.zeros(self.out_channels))
536
+
537
+ mod_list = nn.Sequential()
538
+ mod_list.add_module(
539
+ "conv",
540
+ nn.Conv2d(
541
+ in_channels=self.in_channels,
542
+ out_channels=self.out_channels,
543
+ kernel_size=kernel_size,
544
+ stride=self.stride,
545
+ padding=padding,
546
+ groups=self.groups,
547
+ bias=False,
548
+ ),
549
+ )
550
+ mod_list.add_module("bn", norm_layer)
551
+ return mod_list
552
+
553
+
554
+ def convolutional_stem(
555
+ in_channels: int, out_channels: int, inference_mode: bool = False, use_scale_branch: bool = True,
556
+ ) -> nn.Sequential:
557
+ """Build convolutional stem with MobileOne blocks.
558
+
559
+ Args:
560
+ in_channels: Number of input channels.
561
+ out_channels: Number of output channels.
562
+ inference_mode: Flag to instantiate model in inference mode. Default: ``False``
563
+
564
+ Returns:
565
+ nn.Sequential object with stem elements.
566
+ """
567
+ return nn.Sequential(
568
+ MobileOneBlock(
569
+ in_channels=in_channels,
570
+ out_channels=out_channels,
571
+ kernel_size=3,
572
+ stride=2,
573
+ padding=1,
574
+ groups=1,
575
+ inference_mode=inference_mode,
576
+ use_se=False,
577
+ num_conv_branches=1,
578
+ use_scale_branch=use_scale_branch
579
+ ),
580
+ MobileOneBlock(
581
+ in_channels=out_channels,
582
+ out_channels=out_channels,
583
+ kernel_size=3,
584
+ stride=2,
585
+ padding=1,
586
+ groups=out_channels,
587
+ inference_mode=inference_mode,
588
+ use_se=False,
589
+ num_conv_branches=1,
590
+ use_scale_branch=use_scale_branch
591
+ ),
592
+ MobileOneBlock(
593
+ in_channels=out_channels,
594
+ out_channels=out_channels,
595
+ kernel_size=1,
596
+ stride=1,
597
+ padding=0,
598
+ groups=1,
599
+ inference_mode=inference_mode,
600
+ use_se=False,
601
+ num_conv_branches=1,
602
+ use_scale_branch=use_scale_branch
603
+ ),
604
+ )
605
+
606
+
607
+ class LayerNormChannel(nn.Module):
608
+ """
609
+ LayerNorm only for Channel Dimension.
610
+ Input: tensor in shape [B, C, H, W]
611
+ """
612
+ def __init__(self, num_features, eps=1e-05) -> None:
613
+ super().__init__()
614
+ self.weight = nn.Parameter(torch.ones(num_features))
615
+ self.bias = nn.Parameter(torch.zeros(num_features))
616
+ self.eps = eps
617
+
618
+ def forward(self, x) -> torch.Tensor:
619
+ u = x.mean(1, keepdim=True)
620
+ s = (x - u).pow(2).mean(1, keepdim=True)
621
+ x = (x - u) / torch.sqrt(s + self.eps)
622
+ x = self.weight.unsqueeze(-1).unsqueeze(-1) * x \
623
+ + self.bias.unsqueeze(-1).unsqueeze(-1)
624
+ return x
625
+
626
+
627
+ class MHSA(nn.Module):
628
+ """Multi-headed Self Attention module.
629
+
630
+ Source modified from:
631
+ https://github.com/rwightman/pytorch-image-models/blob/master/timm/models/vision_transformer.py
632
+ """
633
+
634
+ def __init__(
635
+ self,
636
+ dim: int,
637
+ head_dim: int = 32,
638
+ qkv_bias: bool = False,
639
+ attn_drop: float = 0.0,
640
+ proj_drop: float = 0.0,
641
+ ) -> None:
642
+ """Build MHSA module that can handle 3D or 4D input tensors.
643
+
644
+ Args:
645
+ dim: Number of embedding dimensions.
646
+ head_dim: Number of hidden dimensions per head. Default: ``32``
647
+ qkv_bias: Use bias or not. Default: ``False``
648
+ attn_drop: Dropout rate for attention tensor.
649
+ proj_drop: Dropout rate for projection tensor.
650
+ """
651
+ super().__init__()
652
+ assert dim % head_dim == 0, "dim should be divisible by head_dim"
653
+ self.head_dim = head_dim
654
+ self.num_heads = dim // head_dim
655
+ self.scale = head_dim**-0.5
656
+
657
+ self.qkv = nn.Linear(dim, dim * 3, bias=qkv_bias)
658
+ self.attn_drop = nn.Dropout(attn_drop)
659
+ self.proj = nn.Linear(dim, dim)
660
+ self.proj_drop = nn.Dropout(proj_drop)
661
+
662
+ def forward(self, x: torch.Tensor) -> torch.Tensor:
663
+ shape = x.shape
664
+ B, C, H, W = shape
665
+ N = H * W
666
+ if len(shape) == 4:
667
+ x = torch.flatten(x, start_dim=2).transpose(-2, -1) # (B, N, C)
668
+ qkv = (
669
+ self.qkv(x)
670
+ .reshape(B, N, 3, self.num_heads, self.head_dim)
671
+ .permute(2, 0, 3, 1, 4)
672
+ )
673
+ q, k, v = qkv.unbind(0) # make torchscript happy (cannot use tensor as tuple)
674
+
675
+ # trick here to make q @ k.transpose(-2, -1) more stable (scale q first)
676
+ attn = (q * self.scale) @ k.transpose(-2, -1)
677
+ attn = attn.softmax(dim=-1)
678
+ attn = self.attn_drop(attn)
679
+
680
+ x = (attn @ v).transpose(1, 2).reshape(B, N, C)
681
+ x = self.proj(x)
682
+ x = self.proj_drop(x)
683
+ if len(shape) == 4:
684
+ x = x.transpose(-2, -1).reshape(B, C, H, W)
685
+
686
+ return x
687
+
688
+
689
+ class PatchEmbed(nn.Module):
690
+ """Convolutional patch embedding layer."""
691
+
692
+ def __init__(
693
+ self,
694
+ patch_size: int,
695
+ stride: int,
696
+ in_channels: int,
697
+ embed_dim: int,
698
+ inference_mode: bool = False,
699
+ use_se: bool = False,
700
+ ) -> None:
701
+ """Build patch embedding layer.
702
+
703
+ Args:
704
+ patch_size: Patch size for embedding computation.
705
+ stride: Stride for convolutional embedding layer.
706
+ in_channels: Number of channels of input tensor.
707
+ embed_dim: Number of embedding dimensions.
708
+ inference_mode: Flag to instantiate model in inference mode. Default: ``False``
709
+ use_se: If ``True`` SE block will be used.
710
+ """
711
+ super().__init__()
712
+ block = list()
713
+ block.append(
714
+ ReparamLargeKernelConv(
715
+ in_channels=in_channels,
716
+ out_channels=embed_dim,
717
+ kernel_size=patch_size,
718
+ stride=stride,
719
+ groups=in_channels,
720
+ small_kernel=3,
721
+ inference_mode=inference_mode,
722
+ use_se=use_se,
723
+ )
724
+ )
725
+ block.append(
726
+ MobileOneBlock(
727
+ in_channels=embed_dim,
728
+ out_channels=embed_dim,
729
+ kernel_size=1,
730
+ stride=1,
731
+ padding=0,
732
+ groups=1,
733
+ inference_mode=inference_mode,
734
+ use_se=False,
735
+ num_conv_branches=1,
736
+ )
737
+ )
738
+ self.proj = nn.Sequential(*block)
739
+
740
+ def forward(self, x: torch.Tensor) -> torch.Tensor:
741
+ x = self.proj(x)
742
+ return x
743
+
744
+
745
+ class RepMixer(nn.Module):
746
+ """Reparameterizable token mixer.
747
+
748
+ For more details, please refer to our paper:
749
+ `FastViT: A Fast Hybrid Vision Transformer using Structural Reparameterization <https://arxiv.org/pdf/2303.14189.pdf>`_
750
+ """
751
+
752
+ def __init__(
753
+ self,
754
+ dim,
755
+ kernel_size=3,
756
+ use_layer_scale=True,
757
+ layer_scale_init_value=1e-5,
758
+ inference_mode: bool = False,
759
+ ):
760
+ """Build RepMixer Module.
761
+
762
+ Args:
763
+ dim: Input feature map dimension. :math:`C_{in}` from an expected input of size :math:`(B, C_{in}, H, W)`.
764
+ kernel_size: Kernel size for spatial mixing. Default: 3
765
+ use_layer_scale: If True, learnable layer scale is used. Default: ``True``
766
+ layer_scale_init_value: Initial value for layer scale. Default: 1e-5
767
+ inference_mode: If True, instantiates model in inference mode. Default: ``False``
768
+ """
769
+ super().__init__()
770
+ self.dim = dim
771
+ self.kernel_size = kernel_size
772
+ self.inference_mode = inference_mode
773
+
774
+ if inference_mode:
775
+ self.reparam_conv = nn.Conv2d(
776
+ in_channels=self.dim,
777
+ out_channels=self.dim,
778
+ kernel_size=self.kernel_size,
779
+ stride=1,
780
+ padding=self.kernel_size // 2,
781
+ groups=self.dim,
782
+ bias=True,
783
+ )
784
+ else:
785
+ self.norm = MobileOneBlock(
786
+ dim,
787
+ dim,
788
+ kernel_size,
789
+ padding=kernel_size // 2,
790
+ groups=dim,
791
+ use_act=False,
792
+ use_scale_branch=False,
793
+ num_conv_branches=0,
794
+ )
795
+ self.mixer = MobileOneBlock(
796
+ dim,
797
+ dim,
798
+ kernel_size,
799
+ padding=kernel_size // 2,
800
+ groups=dim,
801
+ use_act=False,
802
+ )
803
+ self.use_layer_scale = use_layer_scale
804
+ if use_layer_scale:
805
+ self.layer_scale = nn.Parameter(
806
+ layer_scale_init_value * torch.ones((dim, 1, 1)), requires_grad=True
807
+ )
808
+
809
+ def forward(self, x: torch.Tensor) -> torch.Tensor:
810
+ if hasattr(self, "reparam_conv"):
811
+ x = self.reparam_conv(x)
812
+ return x
813
+ else:
814
+ if self.use_layer_scale:
815
+ x = x + self.layer_scale * (self.mixer(x) - self.norm(x))
816
+ else:
817
+ x = x + self.mixer(x) - self.norm(x)
818
+ return x
819
+
820
+ def reparameterize(self) -> None:
821
+ """Reparameterize mixer and norm into a single
822
+ convolutional layer for efficient inference.
823
+ """
824
+ if self.inference_mode:
825
+ return
826
+
827
+ self.mixer.reparameterize()
828
+ self.norm.reparameterize()
829
+
830
+ if self.use_layer_scale:
831
+ w = self.mixer.id_tensor + self.layer_scale.unsqueeze(-1) * (
832
+ self.mixer.reparam_conv.weight - self.norm.reparam_conv.weight
833
+ )
834
+ b = torch.squeeze(self.layer_scale) * (
835
+ self.mixer.reparam_conv.bias - self.norm.reparam_conv.bias
836
+ )
837
+ else:
838
+ w = (
839
+ self.mixer.id_tensor
840
+ + self.mixer.reparam_conv.weight
841
+ - self.norm.reparam_conv.weight
842
+ )
843
+ b = self.mixer.reparam_conv.bias - self.norm.reparam_conv.bias
844
+
845
+ self.reparam_conv = nn.Conv2d(
846
+ in_channels=self.dim,
847
+ out_channels=self.dim,
848
+ kernel_size=self.kernel_size,
849
+ stride=1,
850
+ padding=self.kernel_size // 2,
851
+ groups=self.dim,
852
+ bias=True,
853
+ )
854
+ self.reparam_conv.weight.data = w
855
+ self.reparam_conv.bias.data = b
856
+
857
+ self.__delattr__("mixer")
858
+ self.__delattr__("norm")
859
+ if self.use_layer_scale:
860
+ self.__delattr__("layer_scale")
861
+
862
+
863
+ class ConvFFN(nn.Module):
864
+ """Convolutional FFN Module."""
865
+
866
+ def __init__(
867
+ self,
868
+ in_channels: int,
869
+ hidden_channels: Optional[int] = None,
870
+ out_channels: Optional[int] = None,
871
+ act_layer: nn.Module = nn.GELU,
872
+ drop: float = 0.0,
873
+ ) -> None:
874
+ """Build convolutional FFN module.
875
+
876
+ Args:
877
+ in_channels: Number of input channels.
878
+ hidden_channels: Number of channels after expansion. Default: None
879
+ out_channels: Number of output channels. Default: None
880
+ act_layer: Activation layer. Default: ``GELU``
881
+ drop: Dropout rate. Default: ``0.0``.
882
+ """
883
+ super().__init__()
884
+ out_channels = out_channels or in_channels
885
+ hidden_channels = hidden_channels or in_channels
886
+ self.conv = nn.Sequential()
887
+ self.conv.add_module(
888
+ "conv",
889
+ nn.Conv2d(
890
+ in_channels=in_channels,
891
+ out_channels=out_channels,
892
+ kernel_size=7,
893
+ padding=3,
894
+ groups=in_channels,
895
+ bias=False,
896
+ ),
897
+ )
898
+
899
+ # Fallback, sometimes batchnorm tensors
900
+ # do not get instantiated correctly on some processes
901
+ # when using deepspeed + accelerate
902
+ norm_layer = nn.BatchNorm2d(num_features=out_channels)
903
+ if norm_layer.weight.shape[0] == 0:
904
+ norm_layer.weight = nn.Parameter(torch.zeros(out_channels))
905
+ if norm_layer.bias.shape[0] == 0:
906
+ norm_layer.bias = nn.Parameter(torch.zeros(out_channels))
907
+
908
+ self.conv.add_module("bn", norm_layer)
909
+ self.fc1 = nn.Conv2d(in_channels, hidden_channels, kernel_size=1)
910
+ self.act = act_layer()
911
+ self.fc2 = nn.Conv2d(hidden_channels, out_channels, kernel_size=1)
912
+ self.drop = nn.Dropout(drop)
913
+ self.apply(self._init_weights)
914
+
915
+ def _init_weights(self, m: nn.Module) -> None:
916
+ if isinstance(m, nn.Conv2d):
917
+ normal_(m.weight, std=0.02)
918
+ if m.bias is not None:
919
+ nn.init.constant_(m.bias, 0)
920
+
921
+ def forward(self, x: torch.Tensor) -> torch.Tensor:
922
+ x = self.conv(x)
923
+ x = self.fc1(x)
924
+ x = self.act(x)
925
+ x = self.drop(x)
926
+ x = self.fc2(x)
927
+ x = self.drop(x)
928
+ return x
929
+
930
+
931
+ class RepCPE(nn.Module):
932
+ """Implementation of conditional positional encoding.
933
+
934
+ For more details refer to paper:
935
+ `Conditional Positional Encodings for Vision Transformers <https://arxiv.org/pdf/2102.10882.pdf>`_
936
+
937
+ In our implementation, we can reparameterize this module to eliminate a skip connection.
938
+ """
939
+
940
+ def __init__(
941
+ self,
942
+ in_channels: int,
943
+ embed_dim: int = 768,
944
+ spatial_shape: Union[int, Tuple[int, int]] = (7, 7),
945
+ inference_mode=False,
946
+ ) -> None:
947
+ """Build reparameterizable conditional positional encoding
948
+
949
+ Args:
950
+ in_channels: Number of input channels.
951
+ embed_dim: Number of embedding dimensions. Default: 768
952
+ spatial_shape: Spatial shape of kernel for positional encoding. Default: (7, 7)
953
+ inference_mode: Flag to instantiate block in inference mode. Default: ``False``
954
+ """
955
+ super(RepCPE, self).__init__()
956
+ if isinstance(spatial_shape, int):
957
+ spatial_shape = tuple([spatial_shape] * 2)
958
+ assert isinstance(spatial_shape, Tuple), (
959
+ f'"spatial_shape" must be a sequence or int, '
960
+ f"got {type(spatial_shape)} instead."
961
+ )
962
+ assert len(spatial_shape) == 2, (
963
+ f'Length of "spatial_shape" should be 2, '
964
+ f"got {len(spatial_shape)} instead."
965
+ )
966
+
967
+ self.spatial_shape = spatial_shape
968
+ self.embed_dim = embed_dim
969
+ self.in_channels = in_channels
970
+ self.groups = embed_dim
971
+
972
+ if inference_mode:
973
+ self.reparam_conv = nn.Conv2d(
974
+ in_channels=self.in_channels,
975
+ out_channels=self.embed_dim,
976
+ kernel_size=self.spatial_shape,
977
+ stride=1,
978
+ padding=int(self.spatial_shape[0] // 2),
979
+ groups=self.embed_dim,
980
+ bias=True,
981
+ )
982
+ else:
983
+ self.pe = nn.Conv2d(
984
+ in_channels,
985
+ embed_dim,
986
+ spatial_shape,
987
+ 1,
988
+ int(spatial_shape[0] // 2),
989
+ bias=True,
990
+ groups=embed_dim,
991
+ )
992
+
993
+ def forward(self, x: torch.Tensor) -> torch.Tensor:
994
+ if hasattr(self, "reparam_conv"):
995
+ x = self.reparam_conv(x)
996
+ return x
997
+ else:
998
+ x = self.pe(x) + x
999
+ return x
1000
+
1001
+ def reparameterize(self) -> None:
1002
+ # Build equivalent Id tensor
1003
+ input_dim = self.in_channels // self.groups
1004
+ kernel_value = torch.zeros(
1005
+ (
1006
+ self.in_channels,
1007
+ input_dim,
1008
+ self.spatial_shape[0],
1009
+ self.spatial_shape[1],
1010
+ ),
1011
+ dtype=self.pe.weight.dtype,
1012
+ device=self.pe.weight.device,
1013
+ )
1014
+ for i in range(self.in_channels):
1015
+ kernel_value[
1016
+ i,
1017
+ i % input_dim,
1018
+ self.spatial_shape[0] // 2,
1019
+ self.spatial_shape[1] // 2,
1020
+ ] = 1
1021
+ id_tensor = kernel_value
1022
+
1023
+ # Reparameterize Id tensor and conv
1024
+ w_final = id_tensor + self.pe.weight
1025
+ b_final = self.pe.bias
1026
+
1027
+ # Introduce reparam conv
1028
+ self.reparam_conv = nn.Conv2d(
1029
+ in_channels=self.in_channels,
1030
+ out_channels=self.embed_dim,
1031
+ kernel_size=self.spatial_shape,
1032
+ stride=1,
1033
+ padding=int(self.spatial_shape[0] // 2),
1034
+ groups=self.embed_dim,
1035
+ bias=True,
1036
+ )
1037
+ self.reparam_conv.weight.data = w_final
1038
+ self.reparam_conv.bias.data = b_final
1039
+
1040
+ self.__delattr__("pe")
1041
+
1042
+
1043
+ class RepMixerBlock(nn.Module):
1044
+ """Implementation of Metaformer block with RepMixer as token mixer.
1045
+
1046
+ For more details on Metaformer structure, please refer to:
1047
+ `MetaFormer Is Actually What You Need for Vision <https://arxiv.org/pdf/2111.11418.pdf>`_
1048
+ """
1049
+
1050
+ def __init__(
1051
+ self,
1052
+ dim: int,
1053
+ kernel_size: int = 3,
1054
+ mlp_ratio: float = 4.0,
1055
+ act_layer: nn.Module = nn.GELU,
1056
+ drop: float = 0.0,
1057
+ drop_path: float = 0.0,
1058
+ use_layer_scale: bool = True,
1059
+ layer_scale_init_value: float = 1e-5,
1060
+ inference_mode: bool = False,
1061
+ ):
1062
+ """Build RepMixer Block.
1063
+
1064
+ Args:
1065
+ dim: Number of embedding dimensions.
1066
+ kernel_size: Kernel size for repmixer. Default: 3
1067
+ mlp_ratio: MLP expansion ratio. Default: 4.0
1068
+ act_layer: Activation layer. Default: ``nn.GELU``
1069
+ drop: Dropout rate. Default: 0.0
1070
+ drop_path: Drop path rate. Default: 0.0
1071
+ use_layer_scale: Flag to turn on layer scale. Default: ``True``
1072
+ layer_scale_init_value: Layer scale value at initialization. Default: 1e-5
1073
+ inference_mode: Flag to instantiate block in inference mode. Default: ``False``
1074
+ """
1075
+
1076
+ super().__init__()
1077
+
1078
+ self.token_mixer = RepMixer(
1079
+ dim,
1080
+ kernel_size=kernel_size,
1081
+ use_layer_scale=use_layer_scale,
1082
+ layer_scale_init_value=layer_scale_init_value,
1083
+ inference_mode=inference_mode,
1084
+ )
1085
+
1086
+ assert mlp_ratio > 0, "MLP ratio should be greater than 0, found: {}".format(
1087
+ mlp_ratio
1088
+ )
1089
+ mlp_hidden_dim = int(dim * mlp_ratio)
1090
+ self.convffn = ConvFFN(
1091
+ in_channels=dim,
1092
+ hidden_channels=mlp_hidden_dim,
1093
+ act_layer=act_layer,
1094
+ drop=drop,
1095
+ )
1096
+
1097
+ # Drop Path
1098
+ self.drop_path = DropPath(drop_path) if drop_path > 0.0 else nn.Identity()
1099
+
1100
+ # Layer Scale
1101
+ self.use_layer_scale = use_layer_scale
1102
+ if use_layer_scale:
1103
+ self.layer_scale = nn.Parameter(
1104
+ layer_scale_init_value * torch.ones((dim, 1, 1)), requires_grad=True
1105
+ )
1106
+
1107
+ def forward(self, x):
1108
+ if self.use_layer_scale:
1109
+ x = self.token_mixer(x)
1110
+ x = x + self.drop_path(self.layer_scale * self.convffn(x))
1111
+ else:
1112
+ x = self.token_mixer(x)
1113
+ x = x + self.drop_path(self.convffn(x))
1114
+ return x
1115
+
1116
+
1117
+ class AttentionBlock(nn.Module):
1118
+ """Implementation of metaformer block with MHSA as token mixer.
1119
+
1120
+ For more details on Metaformer structure, please refer to:
1121
+ `MetaFormer Is Actually What You Need for Vision <https://arxiv.org/pdf/2111.11418.pdf>`_
1122
+ """
1123
+
1124
+ def __init__(
1125
+ self,
1126
+ dim: int,
1127
+ mlp_ratio: float = 4.0,
1128
+ act_layer: nn.Module = nn.GELU,
1129
+ norm_layer: nn.Module = nn.BatchNorm2d,
1130
+ drop: float = 0.0,
1131
+ drop_path: float = 0.0,
1132
+ use_layer_scale: bool = True,
1133
+ layer_scale_init_value: float = 1e-5,
1134
+ ):
1135
+ """Build Attention Block.
1136
+
1137
+ Args:
1138
+ dim: Number of embedding dimensions.
1139
+ mlp_ratio: MLP expansion ratio. Default: 4.0
1140
+ act_layer: Activation layer. Default: ``nn.GELU``
1141
+ norm_layer: Normalization layer. Default: ``nn.BatchNorm2d``
1142
+ drop: Dropout rate. Default: 0.0
1143
+ drop_path: Drop path rate. Default: 0.0
1144
+ use_layer_scale: Flag to turn on layer scale. Default: ``True``
1145
+ layer_scale_init_value: Layer scale value at initialization. Default: 1e-5
1146
+ """
1147
+
1148
+ super().__init__()
1149
+
1150
+ # Fallback, sometimes batchnorm tensors
1151
+ # do not get instantiated correctly on some processes
1152
+ # when using deepspeed + accelerate
1153
+ norm_layer_ = norm_layer(num_features=dim)
1154
+ if norm_layer_.weight.shape[0] == 0:
1155
+ norm_layer_.weight = nn.Parameter(torch.zeros(dim))
1156
+ if norm_layer_.bias.shape[0] == 0:
1157
+ norm_layer_.bias = nn.Parameter(torch.zeros(dim))
1158
+
1159
+ self.norm = norm_layer_
1160
+ self.token_mixer = MHSA(dim=dim)
1161
+
1162
+ assert mlp_ratio > 0, "MLP ratio should be greater than 0, found: {}".format(
1163
+ mlp_ratio
1164
+ )
1165
+ mlp_hidden_dim = int(dim * mlp_ratio)
1166
+ self.convffn = ConvFFN(
1167
+ in_channels=dim,
1168
+ hidden_channels=mlp_hidden_dim,
1169
+ act_layer=act_layer,
1170
+ drop=drop,
1171
+ )
1172
+
1173
+ # Drop path
1174
+ self.drop_path = DropPath(drop_path) if drop_path > 0.0 else nn.Identity()
1175
+
1176
+ # Layer Scale
1177
+ self.use_layer_scale = use_layer_scale
1178
+ if use_layer_scale:
1179
+ self.layer_scale_1 = nn.Parameter(
1180
+ layer_scale_init_value * torch.ones((dim, 1, 1)), requires_grad=True
1181
+ )
1182
+ self.layer_scale_2 = nn.Parameter(
1183
+ layer_scale_init_value * torch.ones((dim, 1, 1)), requires_grad=True
1184
+ )
1185
+
1186
+ def forward(self, x):
1187
+ if self.use_layer_scale:
1188
+ x = x + self.drop_path(self.layer_scale_1 * self.token_mixer(self.norm(x)))
1189
+ x = x + self.drop_path(self.layer_scale_2 * self.convffn(x))
1190
+ else:
1191
+ x = x + self.drop_path(self.token_mixer(self.norm(x)))
1192
+ x = x + self.drop_path(self.convffn(x))
1193
+ return x
1194
+
1195
+
1196
+ def basic_blocks(
1197
+ dim: int,
1198
+ block_index: int,
1199
+ num_blocks: List[int],
1200
+ token_mixer_type: str,
1201
+ kernel_size: int = 3,
1202
+ mlp_ratio: float = 4.0,
1203
+ act_layer: nn.Module = nn.GELU,
1204
+ norm_layer: nn.Module = nn.BatchNorm2d,
1205
+ drop_rate: float = 0.0,
1206
+ drop_path_rate: float = 0.0,
1207
+ use_layer_scale: bool = True,
1208
+ layer_scale_init_value: float = 1e-5,
1209
+ inference_mode=False,
1210
+ ) -> nn.Sequential:
1211
+ """Build FastViT blocks within a stage.
1212
+
1213
+ Args:
1214
+ dim: Number of embedding dimensions.
1215
+ block_index: block index.
1216
+ num_blocks: List containing number of blocks per stage.
1217
+ token_mixer_type: Token mixer type.
1218
+ kernel_size: Kernel size for repmixer.
1219
+ mlp_ratio: MLP expansion ratio.
1220
+ act_layer: Activation layer.
1221
+ norm_layer: Normalization layer.
1222
+ drop_rate: Dropout rate.
1223
+ drop_path_rate: Drop path rate.
1224
+ use_layer_scale: Flag to turn on layer scale regularization.
1225
+ layer_scale_init_value: Layer scale value at initialization.
1226
+ inference_mode: Flag to instantiate block in inference mode.
1227
+
1228
+ Returns:
1229
+ nn.Sequential object of all the blocks within the stage.
1230
+ """
1231
+ blocks = []
1232
+ for block_idx in range(num_blocks[block_index]):
1233
+ block_dpr = (
1234
+ drop_path_rate
1235
+ * (block_idx + sum(num_blocks[:block_index]))
1236
+ / (sum(num_blocks) - 1)
1237
+ )
1238
+ if token_mixer_type == "repmixer":
1239
+ blocks.append(
1240
+ RepMixerBlock(
1241
+ dim,
1242
+ kernel_size=kernel_size,
1243
+ mlp_ratio=mlp_ratio,
1244
+ act_layer=act_layer,
1245
+ drop=drop_rate,
1246
+ drop_path=block_dpr,
1247
+ use_layer_scale=use_layer_scale,
1248
+ layer_scale_init_value=layer_scale_init_value,
1249
+ inference_mode=inference_mode,
1250
+ )
1251
+ )
1252
+ elif token_mixer_type == "attention":
1253
+ blocks.append(
1254
+ AttentionBlock(
1255
+ dim,
1256
+ mlp_ratio=mlp_ratio,
1257
+ act_layer=act_layer,
1258
+ norm_layer=norm_layer,
1259
+ drop=drop_rate,
1260
+ drop_path=block_dpr,
1261
+ use_layer_scale=use_layer_scale,
1262
+ layer_scale_init_value=layer_scale_init_value,
1263
+ )
1264
+ )
1265
+ else:
1266
+ raise ValueError(
1267
+ "Token mixer type: {} not supported".format(token_mixer_type)
1268
+ )
1269
+ blocks = nn.Sequential(*blocks)
1270
+ return blocks
1271
+
1272
+
1273
+ class GlobalPool2D(nn.Module):
1274
+ """This class implements global pooling with linear projection."""
1275
+
1276
+ def __init__(self, in_dim: int, out_dim: int, *args, **kwargs) -> None:
1277
+ super().__init__()
1278
+ scale = in_dim**-0.5
1279
+ self.proj = nn.Parameter(scale * torch.randn(size=(in_dim, out_dim)))
1280
+ self.in_dim = in_dim
1281
+ self.out_dim = out_dim
1282
+
1283
+ def pool(self, x) -> Tensor:
1284
+ if x.dim() == 4:
1285
+ dims = [-2, -1]
1286
+ elif x.dim() == 5:
1287
+ dims = [-3, -2, -1]
1288
+ x = torch.mean(x, dim=dims, keepdim=False)
1289
+ return x
1290
+
1291
+ def forward(self, x: Tensor, *args, **kwargs) -> Tensor:
1292
+ # x is of shape [batch, in_dim, in_height, in_width]
1293
+ assert (
1294
+ x.dim() == 4
1295
+ ), "Input should be 4-dimensional (Batch x in_dim x in_height x in_width). Got: {}".format(
1296
+ x.shape
1297
+ )
1298
+
1299
+ # [batch, in_dim, in_height, in_width] --> [batch, in_dim]
1300
+ x = self.pool(x)
1301
+ # [batch, in_dim] x [in_dim, out_dim] --> [batch, out_dim]
1302
+ x = x @ self.proj
1303
+ return x
1304
+
1305
+
1306
+ class FastViT(nn.Module):
1307
+ """
1308
+ This class implements `FastViT architecture <https://arxiv.org/pdf/2303.14189.pdf>`_
1309
+ """
1310
+
1311
+ def __init__(
1312
+ self,
1313
+ layers,
1314
+ token_mixers: Tuple[str, ...],
1315
+ embed_dims=None,
1316
+ mlp_ratios=None,
1317
+ downsamples=None,
1318
+ se_downsamples=None,
1319
+ repmixer_kernel_size=3,
1320
+ norm_layer: nn.Module = nn.BatchNorm2d,
1321
+ act_layer: nn.Module = nn.GELU,
1322
+ num_classes=1000,
1323
+ pos_embs=None,
1324
+ down_patch_size=7,
1325
+ down_stride=2,
1326
+ drop_rate=0.0,
1327
+ drop_path_rate=0.0,
1328
+ use_layer_scale=True,
1329
+ layer_scale_init_value=1e-5,
1330
+ init_cfg=None,
1331
+ pretrained=None,
1332
+ cls_ratio=2.0,
1333
+ inference_mode=False,
1334
+ stem_scale_branch=True,
1335
+ **kwargs,
1336
+ ) -> None:
1337
+
1338
+ super().__init__()
1339
+
1340
+ self.num_classes = num_classes
1341
+ if len(layers) == 4:
1342
+ self.out_indices = [0, 2, 4, 7]
1343
+ elif len(layers) == 5:
1344
+ self.out_indices = [0, 2, 4, 7, 10]
1345
+ else:
1346
+ raise NotImplementedError("FPN is not implemented for more than 5 stages.")
1347
+
1348
+ if pos_embs is None:
1349
+ pos_embs = [None] * len(layers)
1350
+
1351
+ if se_downsamples is None:
1352
+ se_downsamples = [False] * len(layers)
1353
+
1354
+ # Convolutional stem
1355
+ self.patch_embed = convolutional_stem(3, embed_dims[0], inference_mode,
1356
+ use_scale_branch=stem_scale_branch)
1357
+
1358
+ # Build the main stages of the network architecture
1359
+ network = []
1360
+ for i in range(len(layers)):
1361
+ # Add position embeddings if requested
1362
+ if pos_embs[i] is not None:
1363
+ network.append(
1364
+ pos_embs[i](
1365
+ embed_dims[i], embed_dims[i], inference_mode=inference_mode
1366
+ )
1367
+ )
1368
+ stage = basic_blocks(
1369
+ embed_dims[i],
1370
+ i,
1371
+ layers,
1372
+ token_mixer_type=token_mixers[i],
1373
+ kernel_size=repmixer_kernel_size,
1374
+ mlp_ratio=mlp_ratios[i],
1375
+ act_layer=act_layer,
1376
+ norm_layer=norm_layer,
1377
+ drop_rate=drop_rate,
1378
+ drop_path_rate=drop_path_rate,
1379
+ use_layer_scale=use_layer_scale,
1380
+ layer_scale_init_value=layer_scale_init_value,
1381
+ inference_mode=inference_mode,
1382
+ )
1383
+ network.append(stage)
1384
+ if i >= len(layers) - 1:
1385
+ break
1386
+
1387
+ # Patch merging/downsampling between stages.
1388
+ if downsamples[i] or embed_dims[i] != embed_dims[i + 1]:
1389
+ network.append(
1390
+ PatchEmbed(
1391
+ patch_size=down_patch_size,
1392
+ stride=down_stride,
1393
+ in_channels=embed_dims[i],
1394
+ embed_dim=embed_dims[i + 1],
1395
+ inference_mode=inference_mode,
1396
+ use_se=se_downsamples[i + 1],
1397
+ )
1398
+ )
1399
+ self.network = nn.ModuleList(network)
1400
+
1401
+ # Classifier head
1402
+ self.conv_exp = MobileOneBlock(
1403
+ in_channels=embed_dims[-1],
1404
+ out_channels=int(embed_dims[-1] * cls_ratio),
1405
+ kernel_size=3,
1406
+ stride=1,
1407
+ padding=1,
1408
+ groups=embed_dims[-1],
1409
+ inference_mode=inference_mode,
1410
+ use_se=True,
1411
+ num_conv_branches=1,
1412
+ )
1413
+ self.head = (
1414
+ nn.Linear(int(embed_dims[-1] * cls_ratio), num_classes)
1415
+ if num_classes > 0
1416
+ else nn.Identity()
1417
+ )
1418
+ self.apply(self.cls_init_weights)
1419
+ self.init_cfg = copy.deepcopy(init_cfg)
1420
+
1421
+ def cls_init_weights(self, m: nn.Module) -> None:
1422
+ """Init. for classification"""
1423
+ if isinstance(m, nn.Linear):
1424
+ normal_(m.weight, std=0.02)
1425
+ if isinstance(m, nn.Linear) and m.bias is not None:
1426
+ nn.init.constant_(m.bias, 0)
1427
+
1428
+ def forward_embeddings(self, x: torch.Tensor) -> torch.Tensor:
1429
+ x = self.patch_embed(x)
1430
+ return x
1431
+
1432
+ def forward_tokens(self, x: torch.Tensor, *args, **kwargs) -> torch.Tensor:
1433
+ for idx, block in enumerate(self.network):
1434
+ x = block(x)
1435
+ return x
1436
+
1437
+ def forward(self, x: torch.Tensor, *args, **kwargs) -> Union[Tensor, Dict[str, Tensor]]:
1438
+ # input embedding
1439
+ x = self.forward_embeddings(x)
1440
+ # through backbone
1441
+ x = self.forward_tokens(x)
1442
+ # for image classification/embedding
1443
+ x = self.conv_exp(x)
1444
+ cls_out = self.head(x)
1445
+
1446
+ out_dict = dict()
1447
+ if kwargs.get("return_image_embeddings", False):
1448
+ out_dict.update({"logits": cls_out})
1449
+ out_dict.update({"image_embeddings": x})
1450
+ return out_dict
1451
+ else:
1452
+ return cls_out
1453
+
1454
+
1455
+ @register_model
1456
+ def fastvithd(pretrained=False, **kwargs):
1457
+ """Instantiate FastViTHD model variant."""
1458
+ layers = [2, 12, 24, 4, 2]
1459
+ embed_dims = [96, 192, 384, 768, 1536]
1460
+ mlp_ratios = [4, 4, 4, 4, 4]
1461
+ downsamples = [True, True, True, True, True]
1462
+ pos_embs = [None, None, None, partial(RepCPE, spatial_shape=(7, 7)), partial(RepCPE, spatial_shape=(7, 7))]
1463
+ token_mixers = ("repmixer", "repmixer", "repmixer", "attention", "attention")
1464
+ model = FastViT(
1465
+ layers,
1466
+ token_mixers=token_mixers,
1467
+ embed_dims=embed_dims,
1468
+ pos_embs=pos_embs,
1469
+ mlp_ratios=mlp_ratios,
1470
+ downsamples=downsamples,
1471
+ norm_layer=LayerNormChannel,
1472
+ stem_scale_branch=False,
1473
+ inference_mode=True,
1474
+ **kwargs,
1475
+ )
1476
+ model.default_cfg = default_cfgs["fastvit_m"]
1477
+ if pretrained:
1478
+ raise ValueError("Functionality not implemented.")
1479
+ return model
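A minimal smoke test of the variant above, assuming this file is importable as `llava.model.multimodal_encoder.mobileclip.mci` (the module name is an assumption about this repository's layout, and the 1024-pixel input size is only illustrative). Because `fastvithd` passes `inference_mode=True`, every re-parameterizable block is constructed directly in its fused form; the `reparameterize()` hooks only matter for models built with `inference_mode=False`.

import torch
from llava.model.multimodal_encoder.mobileclip.mci import fastvithd  # assumed path

model = fastvithd(num_classes=0)   # num_classes=0 swaps the classifier head for Identity
model.eval()

with torch.no_grad():
    out = model(torch.randn(1, 3, 1024, 1024), return_image_embeddings=True)

# 4D feature map produced by conv_exp; e.g. torch.Size([1, 3072, 16, 16]) for this input size.
print(out["image_embeddings"].shape)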
llava/model/multimodal_encoder/mobileclip_encoder.py ADDED
@@ -0,0 +1,116 @@
1
+ #
2
+ # For licensing see accompanying LICENSE file.
3
+ # Copyright (C) 2025 Apple Inc. All Rights Reserved.
4
+ #
5
+ import torch
6
+ import torch.nn as nn
7
+ import torch.nn.functional as F
8
+
9
+ from transformers import CLIPImageProcessor
10
+ import llava.model.multimodal_encoder.mobileclip as mobileclip
11
+
12
+
13
+ class MobileCLIPVisionTower(nn.Module):
14
+ def __init__(self, vision_tower, args, delay_load=False):
15
+ super().__init__()
16
+
17
+ self.is_loaded = False
18
+ self.vision_tower_name = vision_tower
19
+ self.tune_vision_tower = getattr(args, 'unfreeze_mm_vision_tower', False)
20
+ self.input_image_size = int(vision_tower.split("_")[-1])
21
+
22
+ # Delay load is disabled for now
23
+ if not delay_load:
24
+ self.load_model()
25
+ elif getattr(args, 'unfreeze_mm_vision_tower', False):
26
+ self.load_model()
27
+ else:
28
+ model_cfg = mobileclip.load_model_config(self.vision_tower_name)
29
+ self.cfg_only = model_cfg
30
+
31
+ def load_model(self, device_map=None):
32
+ if self.is_loaded:
33
+ print('{} is already loaded, `load_model` called again, skipping.'.format(self.vision_tower_name))
34
+ return
35
+
36
+ # Load model config
37
+ model_cfg = mobileclip.load_model_config(self.vision_tower_name)
38
+
39
+ # Override default image resolution
40
+ model_cfg["image_cfg"]["image_size"] = self.input_image_size
41
+
42
+ self.cfg_only = model_cfg
43
+
44
+ # Build HF CLIPImageProcessor with MobileCLIP parameters
45
+ self.image_processor = CLIPImageProcessor(crop_size={"height": model_cfg["image_cfg"]["image_size"],
46
+ "width": model_cfg["image_cfg"]["image_size"]},
47
+ image_mean=[0.0, 0.0, 0.0],
48
+ image_std=[1.0, 1.0, 1.0],
49
+ size={"shortest_edge": model_cfg["image_cfg"]["image_size"]})
50
+
51
+ # Instantiate the image encoder
52
+ self.vision_tower = mobileclip.MCi(model_name=model_cfg["image_cfg"]["model_name"],
53
+ projection_dim=model_cfg["embed_dim"])
54
+
55
+ if not self.tune_vision_tower:
56
+ self.vision_tower.requires_grad_(False)
57
+
58
+ self.is_loaded = True
59
+
60
+ def feature_select(self, image_forward_outs):
61
+ # Features from penultimate layer
62
+ image_features = image_forward_outs["image_embeddings"]
63
+
64
+ # Reshape 4D tensor to 3D
65
+ B, C, H, W = image_features.shape
66
+ image_features = image_features.reshape(B, C, H*W)
67
+ image_features = image_features.transpose(1, 2)
68
+ return image_features
69
+
70
+ def forward(self, images):
71
+ if self.tune_vision_tower:
72
+ return self.forward_images(images)
73
+ else:
74
+ with torch.no_grad():
75
+ return self.forward_images(images)
76
+
77
+ def forward_images(self, images):
78
+ if type(images) is list:
79
+ image_features = []
80
+ for image in images:
81
+ image_forward_out = self.vision_tower(image.to(device=self.device, dtype=self.dtype).unsqueeze(0), return_image_embeddings=True)
82
+ image_feature = self.feature_select(image_forward_out).to(image.dtype)
83
+ image_features.append(image_feature)
84
+ else:
85
+ image_forward_outs = self.vision_tower(images.to(device=self.device, dtype=self.dtype), return_image_embeddings=True)
86
+ image_features = self.feature_select(image_forward_outs).to(images.dtype)
87
+
88
+ return image_features
89
+
90
+ @property
91
+ def dummy_feature(self):
92
+ return torch.zeros(1, self.hidden_size, device=self.device, dtype=self.dtype)
93
+
94
+ @property
95
+ def dtype(self):
96
+ return next(self.vision_tower.parameters()).dtype
97
+
98
+ @property
99
+ def device(self):
100
+ return next(self.vision_tower.parameters()).device
101
+
102
+ @property
103
+ def config(self):
104
+ return self.cfg_only
105
+
106
+ @property
107
+ def hidden_size(self):
108
+ return self.config["image_cfg"]["embed_dim"]
109
+
110
+ @property
111
+ def num_patches_per_side(self):
112
+ return self.config["image_cfg"]["image_size"] // self.config["image_cfg"]["patch_size"]
113
+
114
+ @property
115
+ def num_patches(self):
116
+ return (self.config["image_cfg"]["image_size"] // self.config["image_cfg"]["patch_size"]) ** 2
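A rough sketch of exercising this tower on its own. The tower name "mobileclip_l_1024" is an assumption about the bundled model configs (the trailing "_1024" is parsed as the input resolution), `args` only needs the single attribute read in `__init__`, and "example.jpg" is a placeholder path:

import torch
from types import SimpleNamespace
from PIL import Image
from llava.model.multimodal_encoder.mobileclip_encoder import MobileCLIPVisionTower

args = SimpleNamespace(unfreeze_mm_vision_tower=False)
tower = MobileCLIPVisionTower("mobileclip_l_1024", args)   # hypothetical tower name

image = Image.open("example.jpg").convert("RGB")
pixels = tower.image_processor(image, return_tensors="pt")["pixel_values"]

with torch.no_grad():
    features = tower(pixels)   # [1, num_patches, hidden_size] patch embeddings
print(features.shape)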
llava/model/multimodal_projector/__pycache__/builder.cpython-312.pyc ADDED
Binary file (2.26 kB).
llava/model/multimodal_projector/builder.py ADDED
@@ -0,0 +1,35 @@
1
+ import torch.nn as nn
2
+ import re
3
+
4
+
5
+ class IdentityMap(nn.Module):
6
+ def __init__(self):
7
+ super().__init__()
8
+
9
+ def forward(self, x, *args, **kwargs):
10
+ return x
11
+
12
+ @property
13
+ def config(self):
14
+ return {"mm_projector_type": 'identity'}
15
+
16
+
17
+ def build_vision_projector(config, delay_load=False, **kwargs):
18
+ projector_type = getattr(config, 'mm_projector_type', 'linear')
19
+
20
+ if projector_type == 'linear':
21
+ return nn.Linear(config.mm_hidden_size, config.hidden_size)
22
+
23
+ mlp_gelu_match = re.match(r'^mlp(\d+)x_gelu$', projector_type)
24
+ if mlp_gelu_match:
25
+ mlp_depth = int(mlp_gelu_match.group(1))
26
+ modules = [nn.Linear(config.mm_hidden_size, config.hidden_size)]
27
+ for _ in range(1, mlp_depth):
28
+ modules.append(nn.GELU())
29
+ modules.append(nn.Linear(config.hidden_size, config.hidden_size))
30
+ return nn.Sequential(*modules)
31
+
32
+ if projector_type == 'identity':
33
+ return IdentityMap()
34
+
35
+ raise ValueError(f'Unknown projector type: {projector_type}')
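As a quick illustration of the two code paths above, a toy config is enough to drive the factory; the sizes below are made up for the example, and the real values come from the model configuration:

from types import SimpleNamespace
from llava.model.multimodal_projector.builder import build_vision_projector

cfg = SimpleNamespace(mm_projector_type="mlp2x_gelu",
                      mm_hidden_size=3072,   # vision feature width (illustrative)
                      hidden_size=896)       # LLM embedding width (illustrative)

projector = build_vision_projector(cfg)
# mlp2x_gelu -> Sequential(Linear(3072, 896), GELU(), Linear(896, 896))
print(projector)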
llava/model/utils.py ADDED
@@ -0,0 +1,20 @@
1
+ from transformers import AutoConfig
2
+
3
+
4
+ def auto_upgrade(config):
5
+ cfg = AutoConfig.from_pretrained(config)
6
+ if 'llava' in config and 'llava' not in cfg.model_type:
7
+ assert cfg.model_type == 'llama'
8
+ print("You are using newer LLaVA code base, while the checkpoint of v0 is from older code base.")
9
+ print("You must upgrade the checkpoint to the new code base (this can be done automatically).")
10
+ confirm = input("Please confirm that you want to upgrade the checkpoint. [Y/N]")
11
+ if confirm.lower() in ["y", "yes"]:
12
+ print("Upgrading checkpoint...")
13
+ assert len(cfg.architectures) == 1
14
+ setattr(cfg.__class__, "model_type", "llava")
15
+ cfg.architectures[0] = 'LlavaLlamaForCausalLM'
16
+ cfg.save_pretrained(config)
17
+ print("Checkpoint upgraded.")
18
+ else:
19
+ print("Checkpoint upgrade aborted.")
20
+ exit(1)
llava/serve/__init__.py ADDED
File without changes
llava/serve/cli.py ADDED
@@ -0,0 +1,126 @@
1
+ import argparse
2
+ import torch
3
+
4
+ from llava.constants import IMAGE_TOKEN_INDEX, DEFAULT_IMAGE_TOKEN, DEFAULT_IM_START_TOKEN, DEFAULT_IM_END_TOKEN
5
+ from llava.conversation import conv_templates, SeparatorStyle
6
+ from llava.model.builder import load_pretrained_model
7
+ from llava.utils import disable_torch_init
8
+ from llava.mm_utils import process_images, tokenizer_image_token, get_model_name_from_path
9
+
10
+ from PIL import Image
11
+
12
+ import requests
13
+ from PIL import Image
14
+ from io import BytesIO
15
+ from transformers import TextStreamer
16
+
17
+
18
+ def load_image(image_file):
19
+ if image_file.startswith('http://') or image_file.startswith('https://'):
20
+ response = requests.get(image_file)
21
+ image = Image.open(BytesIO(response.content)).convert('RGB')
22
+ else:
23
+ image = Image.open(image_file).convert('RGB')
24
+ return image
25
+
26
+
27
+ def main(args):
28
+ # Model
29
+ disable_torch_init()
30
+
31
+ model_name = get_model_name_from_path(args.model_path)
32
+ tokenizer, model, image_processor, context_len = load_pretrained_model(args.model_path, args.model_base, model_name, args.load_8bit, args.load_4bit, device=args.device)
33
+
34
+ if "llama-2" in model_name.lower():
35
+ conv_mode = "llava_llama_2"
36
+ elif "mistral" in model_name.lower():
37
+ conv_mode = "mistral_instruct"
38
+ elif "v1.6-34b" in model_name.lower():
39
+ conv_mode = "chatml_direct"
40
+ elif "v1" in model_name.lower():
41
+ conv_mode = "llava_v1"
42
+ elif "mpt" in model_name.lower():
43
+ conv_mode = "mpt"
44
+ else:
45
+ conv_mode = "llava_v0"
46
+
47
+ if args.conv_mode is not None and conv_mode != args.conv_mode:
48
+ print('[WARNING] the auto inferred conversation mode is {}, while `--conv-mode` is {}, using {}'.format(conv_mode, args.conv_mode, args.conv_mode))
49
+ else:
50
+ args.conv_mode = conv_mode
51
+
52
+ conv = conv_templates[args.conv_mode].copy()
53
+ if "mpt" in model_name.lower():
54
+ roles = ('user', 'assistant')
55
+ else:
56
+ roles = conv.roles
57
+
58
+ image = load_image(args.image_file)
59
+ image_size = image.size
60
+ # Similar operation in model_worker.py
61
+ image_tensor = process_images([image], image_processor, model.config)
62
+ if type(image_tensor) is list:
63
+ image_tensor = [image.to(model.device, dtype=torch.float16) for image in image_tensor]
64
+ else:
65
+ image_tensor = image_tensor.to(model.device, dtype=torch.float16)
66
+
67
+ while True:
68
+ try:
69
+ inp = input(f"{roles[0]}: ")
70
+ except EOFError:
71
+ inp = ""
72
+ if not inp:
73
+ print("exit...")
74
+ break
75
+
76
+ print(f"{roles[1]}: ", end="")
77
+
78
+ if image is not None:
79
+ # first message
80
+ if model.config.mm_use_im_start_end:
81
+ inp = DEFAULT_IM_START_TOKEN + DEFAULT_IMAGE_TOKEN + DEFAULT_IM_END_TOKEN + '\n' + inp
82
+ else:
83
+ inp = DEFAULT_IMAGE_TOKEN + '\n' + inp
84
+ image = None
85
+
86
+ conv.append_message(conv.roles[0], inp)
87
+ conv.append_message(conv.roles[1], None)
88
+ prompt = conv.get_prompt()
89
+
90
+ input_ids = tokenizer_image_token(prompt, tokenizer, IMAGE_TOKEN_INDEX, return_tensors='pt').unsqueeze(0).to(model.device)
91
+ stop_str = conv.sep if conv.sep_style != SeparatorStyle.TWO else conv.sep2
92
+ keywords = [stop_str]
93
+ streamer = TextStreamer(tokenizer, skip_prompt=True, skip_special_tokens=True)
94
+
95
+ with torch.inference_mode():
96
+ output_ids = model.generate(
97
+ input_ids,
98
+ images=image_tensor,
99
+ image_sizes=[image_size],
100
+ do_sample=True if args.temperature > 0 else False,
101
+ temperature=args.temperature,
102
+ max_new_tokens=args.max_new_tokens,
103
+ streamer=streamer,
104
+ use_cache=True)
105
+
106
+ outputs = tokenizer.decode(output_ids[0]).strip()
107
+ conv.messages[-1][-1] = outputs
108
+
109
+ if args.debug:
110
+ print("\n", {"prompt": prompt, "outputs": outputs}, "\n")
111
+
112
+
113
+ if __name__ == "__main__":
114
+ parser = argparse.ArgumentParser()
115
+ parser.add_argument("--model-path", type=str, default="facebook/opt-350m")
116
+ parser.add_argument("--model-base", type=str, default=None)
117
+ parser.add_argument("--image-file", type=str, required=True)
118
+ parser.add_argument("--device", type=str, default="cuda")
119
+ parser.add_argument("--conv-mode", type=str, default=None)
120
+ parser.add_argument("--temperature", type=float, default=0.2)
121
+ parser.add_argument("--max-new-tokens", type=int, default=512)
122
+ parser.add_argument("--load-8bit", action="store_true")
123
+ parser.add_argument("--load-4bit", action="store_true")
124
+ parser.add_argument("--debug", action="store_true")
125
+ args = parser.parse_args()
126
+ main(args)
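The script is meant to be launched from the command line, but since `main()` only reads attributes off the parsed namespace, it can also be driven programmatically. A sketch with placeholder paths:

from types import SimpleNamespace
from llava.serve.cli import main

args = SimpleNamespace(
    model_path="path/to/checkpoint",   # placeholder
    model_base=None,
    image_file="path/to/image.jpg",    # placeholder
    device="cuda",
    conv_mode=None,
    temperature=0.2,
    max_new_tokens=512,
    load_8bit=False,
    load_4bit=False,
    debug=False,
)
main(args)   # starts the interactive prompt loop on stdin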
llava/serve/controller.py ADDED
@@ -0,0 +1,298 @@
1
+ """
2
+ A controller manages distributed workers.
3
+ It sends worker addresses to clients.
4
+ """
5
+ import argparse
6
+ import asyncio
7
+ import dataclasses
8
+ from enum import Enum, auto
9
+ import json
10
+ import logging
11
+ import time
12
+ from typing import List, Union
13
+ import threading
14
+
15
+ from fastapi import FastAPI, Request
16
+ from fastapi.responses import StreamingResponse
17
+ import numpy as np
18
+ import requests
19
+ import uvicorn
20
+
21
+ from llava.constants import CONTROLLER_HEART_BEAT_EXPIRATION
22
+ from llava.utils import build_logger, server_error_msg
23
+
24
+
25
+ logger = build_logger("controller", "controller.log")
26
+
27
+
28
+ class DispatchMethod(Enum):
29
+ LOTTERY = auto()
30
+ SHORTEST_QUEUE = auto()
31
+
32
+ @classmethod
33
+ def from_str(cls, name):
34
+ if name == "lottery":
35
+ return cls.LOTTERY
36
+ elif name == "shortest_queue":
37
+ return cls.SHORTEST_QUEUE
38
+ else:
39
+ raise ValueError(f"Invalid dispatch method")
40
+
41
+
42
+ @dataclasses.dataclass
43
+ class WorkerInfo:
44
+ model_names: List[str]
45
+ speed: int
46
+ queue_length: int
47
+ check_heart_beat: bool
48
+ last_heart_beat: str
49
+
50
+
51
+ def heart_beat_controller(controller):
52
+ while True:
53
+ time.sleep(CONTROLLER_HEART_BEAT_EXPIRATION)
54
+ controller.remove_stable_workers_by_expiration()
55
+
+
+ class Controller:
+     def __init__(self, dispatch_method: str):
+         # Dict[str -> WorkerInfo]
+         self.worker_info = {}
+         self.dispatch_method = DispatchMethod.from_str(dispatch_method)
+
+         self.heart_beat_thread = threading.Thread(
+             target=heart_beat_controller, args=(self,), daemon=True)
+         self.heart_beat_thread.start()
+
+         logger.info("Init controller")
+
+     def register_worker(self, worker_name: str, check_heart_beat: bool,
+                         worker_status: dict):
+         if worker_name not in self.worker_info:
+             logger.info(f"Register a new worker: {worker_name}")
+         else:
+             logger.info(f"Register an existing worker: {worker_name}")
+
+         if not worker_status:
+             worker_status = self.get_worker_status(worker_name)
+         if not worker_status:
+             return False
+
+         self.worker_info[worker_name] = WorkerInfo(
+             worker_status["model_names"], worker_status["speed"], worker_status["queue_length"],
+             check_heart_beat, time.time())
+
+         logger.info(f"Register done: {worker_name}, {worker_status}")
+         return True
+
+     def get_worker_status(self, worker_name: str):
+         try:
+             r = requests.post(worker_name + "/worker_get_status", timeout=5)
+         except requests.exceptions.RequestException as e:
+             logger.error(f"Get status fails: {worker_name}, {e}")
+             return None
+
+         if r.status_code != 200:
+             logger.error(f"Get status fails: {worker_name}, {r}")
+             return None
+
+         return r.json()
+
+     def remove_worker(self, worker_name: str):
+         del self.worker_info[worker_name]
+
+     def refresh_all_workers(self):
+         old_info = dict(self.worker_info)
+         self.worker_info = {}
+
+         for w_name, w_info in old_info.items():
+             if not self.register_worker(w_name, w_info.check_heart_beat, None):
+                 logger.info(f"Remove stale worker: {w_name}")
+
+     def list_models(self):
+         model_names = set()
+
+         for w_name, w_info in self.worker_info.items():
+             model_names.update(w_info.model_names)
+
+         return list(model_names)
+
+     def get_worker_address(self, model_name: str):
+         if self.dispatch_method == DispatchMethod.LOTTERY:
+             worker_names = []
+             worker_speeds = []
+             for w_name, w_info in self.worker_info.items():
+                 if model_name in w_info.model_names:
+                     worker_names.append(w_name)
+                     worker_speeds.append(w_info.speed)
+             worker_speeds = np.array(worker_speeds, dtype=np.float32)
+             norm = np.sum(worker_speeds)
+             if norm < 1e-4:
+                 return ""
+             worker_speeds = worker_speeds / norm
+             if True:  # Directly return address
+                 pt = np.random.choice(np.arange(len(worker_names)),
+                                       p=worker_speeds)
+                 worker_name = worker_names[pt]
+                 return worker_name
+
+             # Check status before returning
+             while True:
+                 pt = np.random.choice(np.arange(len(worker_names)),
+                                       p=worker_speeds)
+                 worker_name = worker_names[pt]
+
+                 if self.get_worker_status(worker_name):
+                     break
+                 else:
+                     self.remove_worker(worker_name)
+                     worker_speeds[pt] = 0
+                     norm = np.sum(worker_speeds)
+                     if norm < 1e-4:
+                         return ""
+                     worker_speeds = worker_speeds / norm
+                     continue
+             return worker_name
+         elif self.dispatch_method == DispatchMethod.SHORTEST_QUEUE:
+             worker_names = []
+             worker_qlen = []
+             for w_name, w_info in self.worker_info.items():
+                 if model_name in w_info.model_names:
+                     worker_names.append(w_name)
+                     worker_qlen.append(w_info.queue_length / w_info.speed)
+             if len(worker_names) == 0:
+                 return ""
+             min_index = np.argmin(worker_qlen)
+             w_name = worker_names[min_index]
+             self.worker_info[w_name].queue_length += 1
+             logger.info(f"names: {worker_names}, queue_lens: {worker_qlen}, ret: {w_name}")
+             return w_name
+         else:
+             raise ValueError(f"Invalid dispatch method: {self.dispatch_method}")
+
+     def receive_heart_beat(self, worker_name: str, queue_length: int):
+         if worker_name not in self.worker_info:
+             logger.info(f"Receive unknown heart beat. {worker_name}")
+             return False
+
+         self.worker_info[worker_name].queue_length = queue_length
+         self.worker_info[worker_name].last_heart_beat = time.time()
+         logger.info(f"Receive heart beat. {worker_name}")
+         return True
+
+     def remove_stale_workers_by_expiration(self):
+         expire = time.time() - CONTROLLER_HEART_BEAT_EXPIRATION
+         to_delete = []
+         for worker_name, w_info in self.worker_info.items():
+             if w_info.check_heart_beat and w_info.last_heart_beat < expire:
+                 to_delete.append(worker_name)
+
+         for worker_name in to_delete:
+             self.remove_worker(worker_name)
+
+     def worker_api_generate_stream(self, params):
+         worker_addr = self.get_worker_address(params["model"])
+         if not worker_addr:
+             logger.info(f"no worker: {params['model']}")
+             ret = {
+                 "text": server_error_msg,
+                 "error_code": 2,
+             }
+             yield json.dumps(ret).encode() + b"\0"
+             return
+
+         try:
+             response = requests.post(worker_addr + "/worker_generate_stream",
+                                      json=params, stream=True, timeout=5)
+             for chunk in response.iter_lines(decode_unicode=False, delimiter=b"\0"):
+                 if chunk:
+                     yield chunk + b"\0"
+         except requests.exceptions.RequestException as e:
+             logger.info(f"worker timeout: {worker_addr}")
+             ret = {
+                 "text": server_error_msg,
+                 "error_code": 3,
+             }
+             yield json.dumps(ret).encode() + b"\0"
+
+     # Let the controller act as a worker to achieve hierarchical
+     # management. This can be used to connect isolated sub networks.
+
+     def worker_api_get_status(self):
+         model_names = set()
+         speed = 0
+         queue_length = 0
+
+         for w_name in self.worker_info:
+             worker_status = self.get_worker_status(w_name)
+             if worker_status is not None:
+                 model_names.update(worker_status["model_names"])
+                 speed += worker_status["speed"]
+                 queue_length += worker_status["queue_length"]
+
+         return {
+             "model_names": list(model_names),
+             "speed": speed,
+             "queue_length": queue_length,
+         }
+
+
+ app = FastAPI()
+
+
+ @app.post("/register_worker")
+ async def register_worker(request: Request):
+     data = await request.json()
+     controller.register_worker(
+         data["worker_name"], data["check_heart_beat"],
+         data.get("worker_status", None))
+
+
+ @app.post("/refresh_all_workers")
+ async def refresh_all_workers():
+     models = controller.refresh_all_workers()
+
+
+ @app.post("/list_models")
+ async def list_models():
+     models = controller.list_models()
+     return {"models": models}
+
+
+ @app.post("/get_worker_address")
+ async def get_worker_address(request: Request):
+     data = await request.json()
+     addr = controller.get_worker_address(data["model"])
+     return {"address": addr}
+
+
+ @app.post("/receive_heart_beat")
+ async def receive_heart_beat(request: Request):
+     data = await request.json()
+     exist = controller.receive_heart_beat(
+         data["worker_name"], data["queue_length"])
+     return {"exist": exist}
+
+
+ @app.post("/worker_generate_stream")
+ async def worker_api_generate_stream(request: Request):
+     params = await request.json()
+     generator = controller.worker_api_generate_stream(params)
+     return StreamingResponse(generator)
+
+
+ @app.post("/worker_get_status")
+ async def worker_api_get_status(request: Request):
+     return controller.worker_api_get_status()
+
+
+ if __name__ == "__main__":
+     parser = argparse.ArgumentParser()
+     parser.add_argument("--host", type=str, default="localhost")
+     parser.add_argument("--port", type=int, default=21001)
+     parser.add_argument("--dispatch-method", type=str, choices=[
+         "lottery", "shortest_queue"], default="shortest_queue")
+     args = parser.parse_args()
+     logger.info(f"args: {args}")
+
+     controller = Controller(args.dispatch_method)
+     uvicorn.run(app, host=args.host, port=args.port, log_level="info")
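
The controller above exposes a small HTTP registration and heartbeat protocol for model workers: /register_worker, /receive_heart_beat, /list_models, /get_worker_address, plus the pass-through /worker_generate_stream and /worker_get_status. The snippet below is a rough sketch (not part of this commit) of how a worker-side script could drive those endpoints; the controller address matches the argparse defaults above, while the worker address and model name are hypothetical placeholders, and it assumes a controller started with `python -m llava.serve.controller` with the `llava` package importable.

```python
# Hand-written sketch of a worker-side client for the controller API above.
# Assumptions: controller at its default address; the worker address and model
# name are hypothetical; nothing here is shipped in this commit.
import time
import requests

CONTROLLER = "http://localhost:21001"   # defaults from controller.py __main__
WORKER = "http://localhost:40000"       # hypothetical worker address

# Register, supplying the status fields the controller stores in WorkerInfo.
# Because worker_status is provided, the controller does not call back into
# WORKER during registration.
requests.post(CONTROLLER + "/register_worker", json={
    "worker_name": WORKER,
    "check_heart_beat": True,
    "worker_status": {
        "model_names": ["llava-example-model"],  # hypothetical model name
        "speed": 1,
        "queue_length": 0,
    },
})

# Query the registry and ask for a dispatch decision.
models = requests.post(CONTROLLER + "/list_models").json()["models"]
if models:
    addr = requests.post(CONTROLLER + "/get_worker_address",
                         json={"model": models[0]}).json()["address"]
    print("models:", models, "dispatched to:", addr)

# Periodic heartbeat; workers registered with check_heart_beat=True are dropped
# once CONTROLLER_HEART_BEAT_EXPIRATION passes without one. The interval here
# is arbitrary and should stay well below that constant.
while True:
    requests.post(CONTROLLER + "/receive_heart_beat",
                  json={"worker_name": WORKER, "queue_length": 0})
    time.sleep(30)
```
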
llava/serve/examples/extreme_ironing.jpg ADDED
llava/serve/examples/waterview.jpg ADDED
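
The comment above `worker_api_get_status` notes that a controller can itself behave as a worker, which allows hierarchical deployments that bridge isolated sub-networks. A minimal sketch of that idea, with hypothetical addresses and assuming both hosts run `llava.serve.controller`:

```python
# Sketch only: register a child controller with a parent controller.
# PARENT and CHILD are hypothetical addresses; the child is assumed to already
# have its own workers registered.
import requests

PARENT = "http://parent-host:21001"  # hypothetical
CHILD = "http://child-host:21001"    # hypothetical

# With worker_status left as None, the parent fetches the status itself by
# POSTing to CHILD + "/worker_get_status", which the child serves via
# worker_api_get_status(). check_heart_beat=False because a child controller
# does not send heartbeats of its own.
requests.post(PARENT + "/register_worker", json={
    "worker_name": CHILD,
    "check_heart_beat": False,
    "worker_status": None,
})

# The parent now lists the union of the child's models and can stream through
# the child's /worker_generate_stream pass-through.
print(requests.post(PARENT + "/list_models").json())
```

Seen from the parent, the child is a single worker whose speed and queue length are the totals aggregated in `worker_api_get_status`.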