ShaoRun committed on
Commit b2a1d7a · verified · 1 Parent(s): 076e1ef

Upload 36 files

Files changed (37)
  1. .gitattributes +3 -0
  2. inference/chat_vision_point.ipynb +0 -0
  3. inference/demo_assets/e393be9a47a24a7cae6142e13f5686d1_8192.npy +3 -0
  4. inference/demo_assets/image1.png +3 -0
  5. inference/demo_assets/image1_384.png +3 -0
  6. inference/demo_assets/image2.png +3 -0
  7. inference/forward_speed.ipynb +182 -0
  8. mm_models/__init__.py +2 -0
  9. mm_models/configuration_mm.py +23 -0
  10. mm_models/llms/__pycache__/llama_modal_moe.cpython-310.pyc +0 -0
  11. mm_models/llms/__pycache__/qwen_model_moe.cpython-310.pyc +0 -0
  12. mm_models/llms/qwen_model_moe.py +338 -0
  13. mm_models/modal_module/__init__.py +22 -0
  14. mm_models/modal_module/__pycache__/__init__.cpython-310.pyc +0 -0
  15. mm_models/modal_module/point/__pycache__/reconv2.cpython-310.pyc +0 -0
  16. mm_models/modal_module/point/recon/__pycache__/transformer.cpython-310.pyc +0 -0
  17. mm_models/modal_module/point/recon/reconv2_utils/AverageMeter.py +42 -0
  18. mm_models/modal_module/point/recon/reconv2_utils/__pycache__/knn.cpython-310.pyc +0 -0
  19. mm_models/modal_module/point/recon/reconv2_utils/__pycache__/logger.cpython-310.pyc +0 -0
  20. mm_models/modal_module/point/recon/reconv2_utils/__pycache__/misc.cpython-310.pyc +0 -0
  21. mm_models/modal_module/point/recon/reconv2_utils/checkpoint.py +129 -0
  22. mm_models/modal_module/point/recon/reconv2_utils/config.py +69 -0
  23. mm_models/modal_module/point/recon/reconv2_utils/data.py +109 -0
  24. mm_models/modal_module/point/recon/reconv2_utils/dist_utils.py +49 -0
  25. mm_models/modal_module/point/recon/reconv2_utils/knn.py +37 -0
  26. mm_models/modal_module/point/recon/reconv2_utils/logger.py +127 -0
  27. mm_models/modal_module/point/recon/reconv2_utils/misc.py +294 -0
  28. mm_models/modal_module/point/recon/reconv2_utils/parser.py +117 -0
  29. mm_models/modal_module/point/recon/reconv2_utils/randaugment.py +216 -0
  30. mm_models/modal_module/point/recon/reconv2_utils/registry.py +289 -0
  31. mm_models/modal_module/point/recon/reconv2_utils/transforms.py +78 -0
  32. mm_models/modal_module/point/recon/transformer.py +647 -0
  33. mm_models/modal_module/point/reconv2.py +266 -0
  34. mm_models/modal_module/vision/__pycache__/siglip.cpython-310.pyc +0 -0
  35. mm_models/modal_module/vision/siglip.py +122 -0
  36. mm_models/modeling_mm.py +259 -0
  37. utils.py +181 -0
.gitattributes CHANGED
@@ -33,3 +33,6 @@ saved_model/**/* filter=lfs diff=lfs merge=lfs -text
  *.zip filter=lfs diff=lfs merge=lfs -text
  *.zst filter=lfs diff=lfs merge=lfs -text
  *tfevents* filter=lfs diff=lfs merge=lfs -text
+ inference/demo_assets/image1_384.png filter=lfs diff=lfs merge=lfs -text
+ inference/demo_assets/image1.png filter=lfs diff=lfs merge=lfs -text
+ inference/demo_assets/image2.png filter=lfs diff=lfs merge=lfs -text
inference/chat_vision_point.ipynb ADDED
The diff for this file is too large to render. See raw diff
 
inference/demo_assets/e393be9a47a24a7cae6142e13f5686d1_8192.npy ADDED
@@ -0,0 +1,3 @@
+ version https://git-lfs.github.com/spec/v1
+ oid sha256:5e9bee56f432fce81bc1d356f54656064ac4926d02ff95b1b4f2e09d18ba79c7
+ size 196736
inference/demo_assets/image1.png ADDED

Git LFS Details

  • SHA256: 69f615c932632dfba059d4704da9c530df30167e8ed9dab42d0cd09280b79876
  • Pointer size: 132 Bytes
  • Size of remote file: 4.94 MB
inference/demo_assets/image1_384.png ADDED

Git LFS Details

  • SHA256: 9a65eb53d31a7b031ce9f5a30355ffde812b1b434e2752ee15dcfa540af0783e
  • Pointer size: 131 Bytes
  • Size of remote file: 238 kB
inference/demo_assets/image2.png ADDED

Git LFS Details

  • SHA256: 240666bf35e80c5e17b3b5188cd99a04989002ca80da1e2edc1c613cb32481f1
  • Pointer size: 132 Bytes
  • Size of remote file: 1.13 MB
inference/forward_speed.ipynb ADDED
@@ -0,0 +1,182 @@
+ {
+ "cells": [
+ {
+ "cell_type": "code",
+ "execution_count": 1,
+ "metadata": {},
+ "outputs": [
+ {
+ "name": "stderr",
+ "output_type": "stream",
+ "text": [
+ "/opt/conda/lib/python3.10/site-packages/tqdm/auto.py:21: TqdmWarning: IProgress not found. Please update jupyter and ipywidgets. See https://ipywidgets.readthedocs.io/en/stable/user_install.html\n",
+ " from .autonotebook import tqdm as notebook_tqdm\n"
+ ]
+ }
+ ],
+ "source": [
+ "import sys\n",
+ "import os\n",
+ "import random\n",
+ "import torch\n",
+ "sys.path.append(\"../\")\n",
+ "from mm_models import AllSparkForCausalLM\n",
+ "from transformers import AutoImageProcessor, AutoTokenizer\n",
+ "from PIL import Image\n",
+ "import numpy as np\n",
+ "from fvcore.nn import FlopCountAnalysis\n",
+ "from plyfile import PlyData\n",
+ "import plotly.graph_objects as go\n",
+ "from mm_datasets.data_utils import point_preprocess, load_pts, process_pts\n",
+ "from utils import SYSTEM_PROMPT\n",
+ "\n",
+ "system_prompt = SYSTEM_PROMPT\n",
+ "\n",
+ "\n",
+ "def show_pointcloud(data, background=None):\n",
+ " points = data[:, :3]\n",
+ " colors = data[:, 3:6]\n",
+ "\n",
+ " if colors is not None:\n",
+ " # * if colors in range(0-1)\n",
+ " if np.max(colors) <= 1:\n",
+ " color_data = np.multiply(colors, 255).astype(int) # Convert float values (0-1) to integers (0-255)\n",
+ " # * if colors in range(0-255)\n",
+ " elif np.max(colors) <= 255:\n",
+ " color_data = colors.astype(int)\n",
+ " else:\n",
+ " color_data = np.zeros_like(points).astype(int) # Default to black color if RGB information is not available\n",
+ " colors = color_data.astype(np.float32) / 255 # model input is (0-1)\n",
+ "\n",
+ " color_strings = ['rgb({},{},{})'.format(r, g, b) for r, g, b in color_data]\n",
+ "\n",
+ " fig = go.Figure(\n",
+ " data=[\n",
+ " go.Scatter3d(\n",
+ " x=points[:, 0], y=points[:, 1], z=points[:, 2],\n",
+ " mode='markers',\n",
+ " marker=dict(\n",
+ " size=1.2,\n",
+ " color=color_strings, # Use the list of RGB strings for the marker colors\n",
+ " )\n",
+ " )\n",
+ " ],\n",
+ " layout=dict(\n",
+ " scene=dict(\n",
+ " xaxis=dict(visible=False),\n",
+ " yaxis=dict(visible=False),\n",
+ " zaxis=dict(visible=False)\n",
+ " ),\n",
+ " paper_bgcolor='rgb(50,50,50)' if background is None else background # Set the background color to dark gray 50, 50, 50\n",
+ " ),\n",
+ " )\n",
+ " fig.show()"
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": null,
+ "metadata": {},
+ "outputs": [
+ {
+ "name": "stderr",
+ "output_type": "stream",
+ "text": [
+ "Loading checkpoint shards: 100%|██████████| 4/4 [00:06<00:00, 1.73s/it]\n"
+ ]
+ }
+ ],
+ "source": [
+ "model_path = \"[path/to/model]\"\n",
+ "\n",
+ "tokenizer = AutoTokenizer.from_pretrained(model_path)\n",
+ "model = AllSparkForCausalLM.from_pretrained(model_path, torch_dtype=torch.bfloat16).cuda()\n",
+ "img_processor = AutoImageProcessor.from_pretrained(model_path)\n",
+ "modal_place_token = dict()\n",
+ "for modal_cfg in model.config.modal_configs:\n",
+ " modal_place_token[modal_cfg['modal_tag']] = modal_cfg['modal_placeholder_token']"
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": null,
+ "metadata": {},
+ "outputs": [
+ {
+ "name": "stderr",
+ "output_type": "stream",
+ "text": [
+ "Unsupported operator aten::_convolution_mode encountered 1 time(s)\n",
+ "Unsupported operator aten::embedding encountered 3 time(s)\n",
+ "Unsupported operator aten::add encountered 224 time(s)\n",
+ "Unsupported operator aten::mul encountered 342 time(s)\n",
+ "Unsupported operator aten::softmax encountered 26 time(s)\n",
+ "Unsupported operator aten::gelu encountered 28 time(s)\n",
+ "Unsupported operator aten::pad encountered 1 time(s)\n",
+ "Unsupported operator aten::mul_ encountered 1 time(s)\n",
+ "Unsupported operator aten::ones_like encountered 1 time(s)\n",
+ "Unsupported operator aten::sub encountered 1 time(s)\n",
+ "Unsupported operator aten::cos encountered 1 time(s)\n",
+ "Unsupported operator aten::sin encountered 1 time(s)\n",
+ "Unsupported operator aten::pow encountered 57 time(s)\n",
+ "Unsupported operator aten::mean encountered 57 time(s)\n",
+ "Unsupported operator aten::rsqrt encountered 57 time(s)\n",
+ "Unsupported operator aten::neg encountered 56 time(s)\n",
+ "Unsupported operator prim::PythonOp.FlashAttnFunc encountered 28 time(s)\n",
+ "Unsupported operator aten::silu encountered 28 time(s)\n",
+ "Unsupported operator aten::cross_entropy_loss encountered 1 time(s)\n",
+ "The following submodules of the model were never called during the trace of the graph. They may be unused, or they were accessed by direct calls to .forward() or via other python methods. In the latter case they will have zeros for statistics, though their statistics will still contribute to their parent calling module.\n",
+ "llm.model.layers.0.self_attn.rotary_emb, llm.model.layers.1.self_attn.rotary_emb, llm.model.layers.10.self_attn.rotary_emb, llm.model.layers.11.self_attn.rotary_emb, llm.model.layers.12.self_attn.rotary_emb, llm.model.layers.13.self_attn.rotary_emb, llm.model.layers.14.self_attn.rotary_emb, llm.model.layers.15.self_attn.rotary_emb, llm.model.layers.16.self_attn.rotary_emb, llm.model.layers.17.self_attn.rotary_emb, llm.model.layers.18.self_attn.rotary_emb, llm.model.layers.19.self_attn.rotary_emb, llm.model.layers.2.self_attn.rotary_emb, llm.model.layers.20.self_attn.rotary_emb, llm.model.layers.21.self_attn.rotary_emb, llm.model.layers.22.self_attn.rotary_emb, llm.model.layers.23.self_attn.rotary_emb, llm.model.layers.24.self_attn.rotary_emb, llm.model.layers.25.self_attn.rotary_emb, llm.model.layers.26.self_attn.rotary_emb, llm.model.layers.27.self_attn.rotary_emb, llm.model.layers.3.self_attn.rotary_emb, llm.model.layers.4.self_attn.rotary_emb, llm.model.layers.5.self_attn.rotary_emb, llm.model.layers.6.self_attn.rotary_emb, llm.model.layers.7.self_attn.rotary_emb, llm.model.layers.8.self_attn.rotary_emb, llm.model.layers.9.self_attn.rotary_emb, modal_encoders.vision.vision_model.encoder.layers.26, modal_encoders.vision.vision_model.encoder.layers.26.layer_norm1, modal_encoders.vision.vision_model.encoder.layers.26.layer_norm2, modal_encoders.vision.vision_model.encoder.layers.26.mlp, modal_encoders.vision.vision_model.encoder.layers.26.mlp.activation_fn, modal_encoders.vision.vision_model.encoder.layers.26.mlp.fc1, modal_encoders.vision.vision_model.encoder.layers.26.mlp.fc2, modal_encoders.vision.vision_model.encoder.layers.26.self_attn, modal_encoders.vision.vision_model.encoder.layers.26.self_attn.k_proj, modal_encoders.vision.vision_model.encoder.layers.26.self_attn.out_proj, modal_encoders.vision.vision_model.encoder.layers.26.self_attn.q_proj, modal_encoders.vision.vision_model.encoder.layers.26.self_attn.v_proj, modal_encoders.vision.vision_model.head, modal_encoders.vision.vision_model.head.attention, modal_encoders.vision.vision_model.head.attention.out_proj, modal_encoders.vision.vision_model.head.layernorm, modal_encoders.vision.vision_model.head.mlp, modal_encoders.vision.vision_model.head.mlp.activation_fn, modal_encoders.vision.vision_model.head.mlp.fc1, modal_encoders.vision.vision_model.head.mlp.fc2, modal_encoders.vision.vision_model.post_layernorm\n"
+ ]
+ },
+ {
+ "name": "stdout",
+ "output_type": "stream",
+ "text": [
+ "2781.7917 GFLOPs\n"
+ ]
+ }
+ ],
+ "source": [
+ "image_path = \"../demo_images/image1_384.png\"\n",
+ "img = Image.open(image_path).convert(\"RGB\")\n",
+ "\n",
+ "img = img_processor(images=img, return_tensors=\"pt\").pixel_values.to(\"cuda\").squeeze().to(model.dtype)\n",
+ "\n",
+ "question_2 = modal_place_token['vision'] + \"\\nDescribe this image.\"\n",
+ "\n",
+ "modal_inputs = [('vision', img)]\n",
+ "\n",
+ "messages = [\n",
+ " {\"role\": \"system\", \"content\": system_prompt},\n",
+ " {\"role\": \"user\", \"content\": \"The 3D object is a football, specifically a soccer ball, which is used in the sport of soccer or football. The ball is designed with a series of black and white panels that form a spherical shape, making it easy to kick and control during gameplay. The design also allows for the ball to spin and bounce when in motion, adding a strategic element to the sport.\\n\" + question_2},\n",
+ "]\n",
+ "inputs = tokenizer.apply_chat_template(messages, tokenize=True, add_generation_prompt=True, return_tensors='pt').to(model.device)\n",
+ "\n",
+ "flops = FlopCountAnalysis(model, (inputs, [modal_inputs]))\n",
+ "print(f\"{flops.total()/1e9:.4f} GFLOPs\")"
+ ]
+ }
+ ],
+ "metadata": {
+ "kernelspec": {
+ "display_name": "base",
+ "language": "python",
+ "name": "python3"
+ },
+ "language_info": {
+ "codemirror_mode": {
+ "name": "ipython",
+ "version": 3
+ },
+ "file_extension": ".py",
+ "mimetype": "text/x-python",
+ "name": "python",
+ "nbconvert_exporter": "python",
+ "pygments_lexer": "ipython3",
+ "version": "3.10.14"
+ }
+ },
+ "nbformat": 4,
+ "nbformat_minor": 2
+ }
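For reference, FlopCountAnalysis reports the FLOPs of a single forward pass and warns about operators it cannot count (the aten::* messages in the cell output above). A minimal sketch on a toy module, assuming only that fvcore is installed (not part of this commit):

# Toy fvcore FLOP count; the module and input shapes here are illustrative only.
import torch
import torch.nn as nn
from fvcore.nn import FlopCountAnalysis

toy = nn.Linear(512, 512)
x = torch.randn(1, 512)
flops = FlopCountAnalysis(toy, (x,))
print(f"{flops.total() / 1e9:.4f} GFLOPs")  # same reporting convention as the notebook cell above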
mm_models/__init__.py ADDED
@@ -0,0 +1,2 @@
+ from .configuration_mm import AllSparkConfig
+ from .modeling_mm import AllSparkForCausalLM
mm_models/configuration_mm.py ADDED
@@ -0,0 +1,23 @@
+ from transformers import PretrainedConfig
+ from typing import Optional, List, Dict
+
+
+ class AllSparkConfig(PretrainedConfig):
+
+     model_type = "allspark"
+
+     def __init__(self,
+                  llm_name_or_path: str = None,
+                  modal_configs: Optional[List[Dict]] = None,
+                  initializer_range: float = 0.02,
+                  ignore_index: int = -100,
+                  tokenizer_padding_side: str = "right",
+                  add_moe: bool = True,
+                  **kwargs):
+         self.llm_name_or_path = llm_name_or_path
+         self.modal_configs = modal_configs
+         self.initializer_range = initializer_range
+         self.ignore_index = ignore_index
+         self.tokenizer_padding_side = tokenizer_padding_side
+         self.add_moe = add_moe
+         super().__init__(**kwargs)
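A minimal sketch of constructing this config; the paths and placeholder tokens below are illustrative, but the modal_tag / modal_placeholder_token keys match what the inference notebook reads from config.modal_configs:

# Illustrative AllSparkConfig; the LLM path and placeholder tokens are made-up values.
from mm_models import AllSparkConfig

config = AllSparkConfig(
    llm_name_or_path="[path/to/llm]",
    modal_configs=[
        {"modal_tag": "vision", "modal_placeholder_token": "<image>"},
        {"modal_tag": "point", "modal_placeholder_token": "<point>"},
    ],
    add_moe=True,
)
print(config.model_type)  # "allspark"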
mm_models/llms/__pycache__/llama_modal_moe.cpython-310.pyc ADDED
Binary file (9.32 kB).
 
mm_models/llms/__pycache__/qwen_model_moe.cpython-310.pyc ADDED
Binary file (8.53 kB).
 
mm_models/llms/qwen_model_moe.py ADDED
@@ -0,0 +1,338 @@
+ import torch
+ import torch.nn as nn
+ from transformers import Qwen2ForCausalLM, Qwen2Model
+ from transformers.cache_utils import Cache, DynamicCache
+ from transformers.models.qwen2.modeling_qwen2 import Qwen2DecoderLayer, Qwen2MLP
+ from typing import Optional, Tuple, List, Union
+ from transformers.modeling_outputs import CausalLMOutputWithPast, BaseModelOutputWithPast
+ from utils import rank0_print
+ import copy
+
+
+ class Qwen2DecoderLayerMoE(Qwen2DecoderLayer):
+
+     def __init__(self, config, layer_idx, modal_tags, add_moe):
+         config._attn_implementation = "flash_attention_2"
+         super().__init__(config, layer_idx)
+         self.modal_tags = modal_tags
+
+         if modal_tags is not None and add_moe:
+             self.modal_moes = nn.ModuleDict()
+             for tag in modal_tags:
+                 self.modal_moes[tag] = Qwen2MLP(config)
+
+     def forward(
+         self,
+         hidden_states: torch.Tensor,
+         modal_token_idx_matrix,
+         modal_idx_mapping,
+         attention_mask: Optional[torch.Tensor] = None,
+         position_ids: Optional[torch.LongTensor] = None,
+         past_key_value: Optional[Tuple[torch.Tensor]] = None,
+         output_attentions: Optional[bool] = False,
+         use_cache: Optional[bool] = False,
+         cache_position: Optional[torch.LongTensor] = None,
+         position_embeddings: Optional[Tuple[torch.Tensor, torch.Tensor]] = None,  # will become mandatory in v4.46
+         **kwargs,
+     ) -> Tuple[torch.FloatTensor, Optional[Tuple[torch.FloatTensor, torch.FloatTensor]]]:
+
+         residual = hidden_states
+
+         hidden_states = self.input_layernorm(hidden_states)
+
+         # Self Attention
+         hidden_states, self_attn_weights, present_key_value = self.self_attn(
+             hidden_states=hidden_states,
+             attention_mask=attention_mask,
+             position_ids=position_ids,
+             past_key_value=past_key_value,
+             output_attentions=output_attentions,
+             use_cache=use_cache,
+             cache_position=cache_position,
+             position_embeddings=position_embeddings,
+         )
+         hidden_states = residual + hidden_states
+
+         residual = hidden_states
+         hidden_states = self.post_attention_layernorm(hidden_states)
+         if modal_token_idx_matrix is not None and hidden_states.shape[1] > 1 and hasattr(self, 'modal_moes'):
+             batch_size, sequence_length, hidden_dim = hidden_states.shape
+             hidden_states = hidden_states.view(-1, hidden_dim)
+             final_hidden_states = torch.zeros(
+                 (batch_size * sequence_length, hidden_dim), dtype=hidden_states.dtype, device=hidden_states.device
+             )
+             for modal_tag, modal_idx in modal_idx_mapping.items():
+                 mask = modal_token_idx_matrix == modal_idx
+                 if not torch.any(mask):
+                     continue
+                 mask = mask.view(-1)
+                 assert mask.shape[0] == hidden_states.shape[0]
+                 if modal_tag == 'text':
+                     final_hidden_states[mask] = self.mlp(hidden_states[mask])
+                 else:
+                     final_hidden_states[mask] = self.modal_moes[modal_tag](hidden_states[mask])
+             hidden_states = final_hidden_states.view(batch_size, sequence_length, hidden_dim)
+         else:
+             hidden_states = self.mlp(hidden_states)
+         hidden_states = residual + hidden_states
+
+         outputs = (hidden_states,)
+
+         if output_attentions:
+             outputs += (self_attn_weights,)
+
+         if use_cache:
+             outputs += (present_key_value,)
+
+         return outputs
+
+
+ class Qwen2ModelMoE(Qwen2Model):
+
+     def __init__(self, config, modal_tags=None, add_moe=True):
+         super().__init__(config)
+
+         self.modal_tags = modal_tags
+
+         self.layers = nn.ModuleList(
+             [Qwen2DecoderLayerMoE(config, layer_idx, modal_tags, add_moe=add_moe)
+              for layer_idx in range(config.num_hidden_layers)]
+         )
+
+     def forward(
+         self,
+         input_ids: torch.LongTensor = None,
+         attention_mask: Optional[torch.Tensor] = None,
+         position_ids: Optional[torch.LongTensor] = None,
+         past_key_values: Optional[List[torch.FloatTensor]] = None,
+         inputs_embeds: Optional[torch.FloatTensor] = None,
+         use_cache: Optional[bool] = None,
+         output_attentions: Optional[bool] = None,
+         output_hidden_states: Optional[bool] = None,
+         return_dict: Optional[bool] = None,
+         cache_position: Optional[torch.LongTensor] = None,
+         modal_tag_pos_list: Optional[List[List[Tuple[str, int, int]]]] = None,  # batch, modal_num, (tag, start, end)
+     ) -> Union[Tuple, BaseModelOutputWithPast]:
+         output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions
+         output_hidden_states = (
+             output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states
+         )
+         use_cache = use_cache if use_cache is not None else self.config.use_cache
+
+         return_dict = return_dict if return_dict is not None else self.config.use_return_dict
+
+         if (input_ids is None) ^ (inputs_embeds is not None):
+             raise ValueError("You must specify exactly one of input_ids or inputs_embeds")
+
+         if self.gradient_checkpointing and self.training:
+             if use_cache:
+                 rank0_print(
+                     "`use_cache=True` is incompatible with gradient checkpointing. Setting `use_cache=False`..."
+                 )
+                 use_cache = False
+
+         # kept for BC (non `Cache` `past_key_values` inputs)
+         return_legacy_cache = False
+         if use_cache and not isinstance(past_key_values, Cache):
+             return_legacy_cache = True
+             if past_key_values is None:
+                 past_key_values = DynamicCache()
+             else:
+                 past_key_values = DynamicCache.from_legacy_cache(past_key_values)
+                 rank0_print(
+                     "We detected that you are passing `past_key_values` as a tuple of tuples. This is deprecated and "
+                     "will be removed in v4.47. Please convert your cache or use an appropriate `Cache` class "
+                     "(https://huggingface.co/docs/transformers/kv_cache#legacy-cache-format)"
+                 )
+
+         if inputs_embeds is None:
+             inputs_embeds = self.embed_tokens(input_ids)
+
+         if cache_position is None:
+             past_seen_tokens = past_key_values.get_seq_length() if past_key_values is not None else 0
+             cache_position = torch.arange(
+                 past_seen_tokens, past_seen_tokens + inputs_embeds.shape[1], device=inputs_embeds.device
+             )
+         if position_ids is None:
+             position_ids = cache_position.unsqueeze(0)
+
+         causal_mask = self._update_causal_mask(
+             attention_mask, inputs_embeds, cache_position, past_key_values, output_attentions
+         )
+
+         hidden_states = inputs_embeds
+
+         # create position embeddings to be shared across the decoder layers
+         position_embeddings = self.rotary_emb(hidden_states, position_ids)
+
+         # decoder layers
+         all_hidden_states = () if output_hidden_states else None
+         all_self_attns = () if output_attentions else None
+         next_decoder_cache = None
+
+         modal_token_idx_matrix = torch.zeros(hidden_states.shape[:2], dtype=torch.int8)
+         modal_idx_mapping = {"text": 0}
+         if self.modal_tags and modal_tag_pos_list:
+             for modal_idx, tag in enumerate(self.modal_tags):
+                 modal_idx_mapping[tag] = modal_idx + 1
+             for sample_id, single_sample_mtp in enumerate(modal_tag_pos_list):
+                 for tag, spos, epos in single_sample_mtp:
+                     modal_token_idx_matrix[sample_id, spos:epos+1] = modal_idx_mapping[tag]
+
+         for decoder_layer in self.layers:
+             if output_hidden_states:
+                 all_hidden_states += (hidden_states,)
+
+             if self.gradient_checkpointing and self.training:
+                 layer_outputs = self._gradient_checkpointing_func(
+                     decoder_layer.__call__,
+                     hidden_states,
+                     modal_token_idx_matrix,
+                     modal_idx_mapping,
+                     causal_mask,
+                     position_ids,
+                     past_key_values,
+                     output_attentions,
+                     use_cache,
+                     cache_position,
+                     position_embeddings,
+                 )
+             else:
+                 layer_outputs = decoder_layer(
+                     hidden_states,
+                     modal_token_idx_matrix=modal_token_idx_matrix,
+                     modal_idx_mapping=modal_idx_mapping,
+                     attention_mask=causal_mask,
+                     position_ids=position_ids,
+                     past_key_value=past_key_values,
+                     output_attentions=output_attentions,
+                     use_cache=use_cache,
+                     cache_position=cache_position,
+                     position_embeddings=position_embeddings,
+                 )
+
+             hidden_states = layer_outputs[0]
+
+             if use_cache:
+                 next_decoder_cache = layer_outputs[2 if output_attentions else 1]
+
+             if output_attentions:
+                 all_self_attns += (layer_outputs[1],)
+
+         hidden_states = self.norm(hidden_states)
+
+         # add hidden states from the last decoder layer
+         if output_hidden_states:
+             all_hidden_states += (hidden_states,)
+
+         next_cache = next_decoder_cache if use_cache else None
+         if return_legacy_cache:
+             next_cache = next_cache.to_legacy_cache()
+
+         if not return_dict:
+             return tuple(v for v in [hidden_states, next_cache, all_hidden_states, all_self_attns] if v is not None)
+         return BaseModelOutputWithPast(
+             last_hidden_state=hidden_states,
+             past_key_values=next_cache,
+             hidden_states=all_hidden_states,
+             attentions=all_self_attns,
+         )
+
+
+ class Qwen2ForCausalLMMoE(Qwen2ForCausalLM):
+
+     def __init__(self, config, modal_tags=None, add_moe=True):
+         super().__init__(config)
+         self.model = Qwen2ModelMoE(config, modal_tags, add_moe=add_moe)
+
+     def forward(
+         self,
+         input_ids: torch.LongTensor = None,
+         attention_mask: Optional[torch.Tensor] = None,
+         position_ids: Optional[torch.LongTensor] = None,
+         past_key_values: Optional[List[torch.FloatTensor]] = None,
+         inputs_embeds: Optional[torch.FloatTensor] = None,
+         labels: Optional[torch.LongTensor] = None,
+         use_cache: Optional[bool] = None,
+         output_attentions: Optional[bool] = None,
+         output_hidden_states: Optional[bool] = None,
+         return_dict: Optional[bool] = None,
+         cache_position: Optional[torch.LongTensor] = None,
+         num_logits_to_keep: int = 0,
+         modal_tag_pos_list: Optional[List[List[Tuple[str, int, int]]]] = None,  # batch, modal_num, (tag, start, end)
+         **loss_kwargs,
+     ) -> Union[Tuple, CausalLMOutputWithPast]:
+
+         output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions
+         output_hidden_states = (
+             output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states
+         )
+         return_dict = return_dict if return_dict is not None else self.config.use_return_dict
+
+         # decoder outputs consists of (dec_features, layer_state, dec_hidden, dec_attn)
+         outputs = self.model(
+             input_ids=input_ids,
+             attention_mask=attention_mask,
+             position_ids=position_ids,
+             past_key_values=past_key_values,
+             inputs_embeds=inputs_embeds,
+             use_cache=use_cache,
+             output_attentions=output_attentions,
+             output_hidden_states=output_hidden_states,
+             return_dict=return_dict,
+             cache_position=cache_position,
+             modal_tag_pos_list=modal_tag_pos_list,
+         )
+
+         hidden_states = outputs[0]
+         # Only compute necessary logits, and do not upcast them to float if we are not computing the loss
+         logits = self.lm_head(hidden_states[:, -num_logits_to_keep:, :])
+
+         loss = None
+         if labels is not None:
+             loss = self.loss_function(logits, labels, self.vocab_size, **loss_kwargs)
+
+         if not return_dict:
+             output = (logits,) + outputs[1:]
+             return (loss,) + output if loss is not None else output
+
+         return CausalLMOutputWithPast(
+             loss=loss,
+             logits=logits,
+             past_key_values=outputs.past_key_values,
+             hidden_states=outputs.hidden_states,
+             attentions=outputs.attentions,
+         )
+
+     def prepare_inputs_for_generation(self,
+                                       input_ids: torch.LongTensor,
+                                       past_key_values: Optional[Cache] = None,
+                                       attention_mask: Optional[torch.LongTensor] = None,
+                                       inputs_embeds: Optional[torch.FloatTensor] = None,
+                                       cache_position: Optional[torch.LongTensor] = None,
+                                       **kwargs,):
+         model_inputs = super().prepare_inputs_for_generation(input_ids,
+                                                              past_key_values,
+                                                              attention_mask,
+                                                              inputs_embeds,
+                                                              cache_position,
+                                                              **kwargs)
+         model_inputs.update(
+             {
+                 "modal_tag_pos_list": kwargs.get("modal_tag_pos_list", None),
+             }
+         )
+         return model_inputs
+
+     def init_modal_moe_params(self, target_tag, src_tag):
+         for i, decoder_layer in enumerate(self.model.layers):
+             if hasattr(decoder_layer, "modal_moes"):
+                 if src_tag == "text":
+                     rank0_print(f"Initializing layer{i} {target_tag} moe params for text")
+                     mlp_module = decoder_layer.mlp
+                     decoder_layer.modal_moes[target_tag].load_state_dict(copy.deepcopy(mlp_module.state_dict()))
+                 else:
+                     rank0_print(f"Initializing layer{i} {target_tag} moe params for {src_tag}")
+                     assert src_tag in decoder_layer.modal_moes, f"src_tag {src_tag} not found in decoder_layer.modal_moes"
+                     src_module = decoder_layer.modal_moes[src_tag]
+                     decoder_layer.modal_moes[target_tag].load_state_dict(copy.deepcopy(src_module.state_dict()))
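The per-token expert routing above boils down to an integer matrix: 0 selects the shared text MLP, and each modality tag maps to its own MLP in modal_moes. A standalone sketch of how modal_tag_pos_list is turned into that matrix, with toy sizes and made-up positions:

# Sketch of the routing matrix built in Qwen2ModelMoE.forward (toy sizes, made-up positions).
import torch

modal_tags = ["vision", "point"]
modal_idx_mapping = {"text": 0}
for modal_idx, tag in enumerate(modal_tags):
    modal_idx_mapping[tag] = modal_idx + 1

# one sample: vision tokens at positions 2..5, point tokens at 8..10 (inclusive)
modal_tag_pos_list = [[("vision", 2, 5), ("point", 8, 10)]]
batch_size, seq_len = 1, 12
modal_token_idx_matrix = torch.zeros(batch_size, seq_len, dtype=torch.int8)
for sample_id, single_sample_mtp in enumerate(modal_tag_pos_list):
    for tag, spos, epos in single_sample_mtp:
        modal_token_idx_matrix[sample_id, spos:epos + 1] = modal_idx_mapping[tag]

print(modal_token_idx_matrix)  # tensor([[0, 0, 1, 1, 1, 1, 0, 0, 2, 2, 2, 0]], dtype=torch.int8)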
mm_models/modal_module/__init__.py ADDED
@@ -0,0 +1,22 @@
+ from .vision.siglip import build_vision_encoder, build_vision_projector, VISION_SIGLIP_MODAL_CFG
+ from .point.reconv2 import build_point_encoder, build_point_projector, POINT_RECON2_MODAL_CFG
+ # NOTE: import custom modal encoder and projector here
+
+
+ MODAL_CFG_MAPPING = {
+     'vision': VISION_SIGLIP_MODAL_CFG,
+     'point': POINT_RECON2_MODAL_CFG
+     # NOTE: add other modalities here
+ }
+
+ MODAL_ENCODERS_MAPPING = {
+     'vision': build_vision_encoder,
+     'point': build_point_encoder
+     # NOTE: add other modalities here
+ }
+
+ MODAL_PROJECTORS_MAPPING = {
+     'vision': build_vision_projector,
+     'point': build_point_projector
+     # NOTE: add other modalities here
+ }
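The NOTE comments mark the extension points. A sketch of registering a hypothetical extra modality; the audio module and its builders do not exist in this commit and are assumptions:

# Hypothetical third modality; .audio.encoder and its symbols are placeholders, not real files.
from .audio.encoder import build_audio_encoder, build_audio_projector, AUDIO_MODAL_CFG

MODAL_CFG_MAPPING['audio'] = AUDIO_MODAL_CFG
MODAL_ENCODERS_MAPPING['audio'] = build_audio_encoder
MODAL_PROJECTORS_MAPPING['audio'] = build_audio_projector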
mm_models/modal_module/__pycache__/__init__.cpython-310.pyc ADDED
Binary file (546 Bytes).
 
mm_models/modal_module/point/__pycache__/reconv2.cpython-310.pyc ADDED
Binary file (7.89 kB).
 
mm_models/modal_module/point/recon/__pycache__/transformer.cpython-310.pyc ADDED
Binary file (17.8 kB).
 
mm_models/modal_module/point/recon/reconv2_utils/AverageMeter.py ADDED
@@ -0,0 +1,42 @@
+
+ class AverageMeter(object):
+     def __init__(self, items=None):
+         self.items = items
+         self.n_items = 1 if items is None else len(items)
+         self.reset()
+
+     def reset(self):
+         self._val = [0] * self.n_items
+         self._sum = [0] * self.n_items
+         self._count = [0] * self.n_items
+
+     def update(self, values):
+         if type(values).__name__ == 'list':
+             for idx, v in enumerate(values):
+                 self._val[idx] = v
+                 self._sum[idx] += v
+                 self._count[idx] += 1
+         else:
+             self._val[0] = values
+             self._sum[0] += values
+             self._count[0] += 1
+
+     def val(self, idx=None):
+         if idx is None:
+             return self._val[0] if self.items is None else [self._val[i] for i in range(self.n_items)]
+         else:
+             return self._val[idx]
+
+     def count(self, idx=None):
+         if idx is None:
+             return self._count[0] if self.items is None else [self._count[i] for i in range(self.n_items)]
+         else:
+             return self._count[idx]
+
+     def avg(self, idx=None):
+         if idx is None:
+             return self._sum[0] / self._count[0] if self.items is None else [
+                 self._sum[i] / self._count[i] for i in range(self.n_items)
+             ]
+         else:
+             return self._sum[idx] / self._count[idx]
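A minimal usage sketch for AverageMeter, tracking two quantities with made-up values:

# AverageMeter usage sketch; the numbers are illustrative.
meter = AverageMeter(items=["loss", "acc"])
meter.update([1.0, 0.5])
meter.update([0.5, 0.75])
print(meter.val())     # [0.5, 0.75]   -> latest values
print(meter.avg())     # [0.75, 0.625] -> running means
print(meter.count(0))  # 2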
mm_models/modal_module/point/recon/reconv2_utils/__pycache__/knn.cpython-310.pyc ADDED
Binary file (1.44 kB).
 
mm_models/modal_module/point/recon/reconv2_utils/__pycache__/logger.cpython-310.pyc ADDED
Binary file (3.98 kB).
 
mm_models/modal_module/point/recon/reconv2_utils/__pycache__/misc.cpython-310.pyc ADDED
Binary file (9.2 kB).
 
mm_models/modal_module/point/recon/reconv2_utils/checkpoint.py ADDED
@@ -0,0 +1,129 @@
+ #!/usr/bin/env python3
+ # Copyright (c) Facebook, Inc. and its affiliates. All Rights Reserved.
+
+ from collections import defaultdict
+ import torch.nn as nn
+
+ from typing import Any
+ from typing import Optional, List, Dict, NamedTuple, Tuple, Iterable
+
+ from termcolor import colored
+
+ def get_missing_parameters_message(keys: List[str]) -> str:
+     """
+     Get a logging-friendly message to report parameter names (keys) that are in
+     the model but not found in a checkpoint.
+     Args:
+         keys (list[str]): List of keys that were not found in the checkpoint.
+     Returns:
+         str: message.
+     """
+     groups = _group_checkpoint_keys(keys)
+     msg = "Some model parameters or buffers are not found in the checkpoint:\n"
+     msg += "\n".join(
+         "  " + colored(k + _group_to_str(v), "blue") for k, v in groups.items()
+     )
+     return msg
+
+
+ def get_unexpected_parameters_message(keys: List[str]) -> str:
+     """
+     Get a logging-friendly message to report parameter names (keys) that are in
+     the checkpoint but not found in the model.
+     Args:
+         keys (list[str]): List of keys that were not found in the model.
+     Returns:
+         str: message.
+     """
+     groups = _group_checkpoint_keys(keys)
+     msg = "The checkpoint state_dict contains keys that are not used by the model:\n"
+     msg += "\n".join(
+         "  " + colored(k + _group_to_str(v), "magenta") for k, v in groups.items()
+     )
+     return msg
+
+
+ def _strip_prefix_if_present(state_dict: Dict[str, Any], prefix: str) -> None:
+     """
+     Strip the prefix in metadata, if any.
+     Args:
+         state_dict (OrderedDict): a state-dict to be loaded to the model.
+         prefix (str): prefix.
+     """
+     keys = sorted(state_dict.keys())
+     if not all(len(key) == 0 or key.startswith(prefix) for key in keys):
+         return
+
+     for key in keys:
+         newkey = key[len(prefix):]
+         state_dict[newkey] = state_dict.pop(key)
+
+     # also strip the prefix in metadata, if any..
+     try:
+         metadata = state_dict._metadata  # pyre-ignore
+     except AttributeError:
+         pass
+     else:
+         for key in list(metadata.keys()):
+             # for the metadata dict, the key can be:
+             # '': for the DDP module, which we want to remove.
+             # 'module': for the actual model.
+             # 'module.xx.xx': for the rest.
+
+             if len(key) == 0:
+                 continue
+             newkey = key[len(prefix):]
+             metadata[newkey] = metadata.pop(key)
+
+
+ def _group_checkpoint_keys(keys: List[str]) -> Dict[str, List[str]]:
+     """
+     Group keys based on common prefixes. A prefix is the string up to the final
+     "." in each key.
+     Args:
+         keys (list[str]): list of parameter names, i.e. keys in the model
+             checkpoint dict.
+     Returns:
+         dict[list]: keys with common prefixes are grouped into lists.
+     """
+     groups = defaultdict(list)
+     for key in keys:
+         pos = key.rfind(".")
+         if pos >= 0:
+             head, tail = key[:pos], [key[pos + 1:]]
+         else:
+             head, tail = key, []
+         groups[head].extend(tail)
+     return groups
+
+
+ def _group_to_str(group: List[str]) -> str:
+     """
+     Format a group of parameter name suffixes into a loggable string.
+     Args:
+         group (list[str]): list of parameter name suffixes.
+     Returns:
+         str: formated string.
+     """
+     if len(group) == 0:
+         return ""
+
+     if len(group) == 1:
+         return "." + group[0]
+
+     return ".{" + ", ".join(group) + "}"
+
+
+ def _named_modules_with_dup(
+     model: nn.Module, prefix: str = ""
+ ) -> Iterable[Tuple[str, nn.Module]]:
+     """
+     The same as `model.named_modules()`, except that it includes
+     duplicated modules that have more than one name.
+     """
+     yield prefix, model
+     for name, module in model._modules.items():  # pyre-ignore
+         if module is None:
+             continue
+         submodule_prefix = prefix + ("." if prefix else "") + name
+         yield from _named_modules_with_dup(module, submodule_prefix)
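A small sketch of the key-grouping helpers, with made-up parameter names:

# Sketch: grouping missing checkpoint keys into a readable (ANSI-colored) report.
missing = ["encoder.layer1.weight", "encoder.layer1.bias", "head.weight"]
print(get_missing_parameters_message(missing))
# Some model parameters or buffers are not found in the checkpoint:
#   encoder.layer1.{weight, bias}
#   head.weight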
mm_models/modal_module/point/recon/reconv2_utils/config.py ADDED
@@ -0,0 +1,69 @@
+ import yaml
+ from easydict import EasyDict
+ import os
+ from .logger import print_log
+
+
+ def log_args_to_file(args, pre='args', logger=None):
+     for key, val in args.__dict__.items():
+         print_log(f'{pre}.{key} : {val}', logger=logger)
+
+
+ def log_config_to_file(cfg, pre='cfg', logger=None):
+     for key, val in cfg.items():
+         if isinstance(cfg[key], EasyDict):
+             print_log(f'{pre}.{key} = edict()', logger=logger)
+             log_config_to_file(cfg[key], pre=pre + '.' + key, logger=logger)
+             continue
+         print_log(f'{pre}.{key} : {val}', logger=logger)
+
+
+ def merge_new_config(config, new_config):
+     for key, val in new_config.items():
+         if not isinstance(val, dict):
+             if key == '_base_':
+                 with open(new_config['_base_'], 'r') as f:
+                     try:
+                         val = yaml.load(f, Loader=yaml.FullLoader)
+                     except:
+                         val = yaml.load(f)
+                 config[key] = EasyDict()
+                 merge_new_config(config[key], val)
+             else:
+                 config[key] = val
+             continue
+         if key not in config:
+             config[key] = EasyDict()
+         merge_new_config(config[key], val)
+     return config
+
+
+ def cfg_from_yaml_file(cfg_file):
+     config = EasyDict()
+     with open(cfg_file, 'r') as f:
+         try:
+             new_config = yaml.load(f, Loader=yaml.FullLoader)
+         except:
+             new_config = yaml.load(f)
+     merge_new_config(config=config, new_config=new_config)
+     return config
+
+
+ def get_config(args, logger=None):
+     if args.resume:
+         cfg_path = os.path.join(args.experiment_path, 'config.yaml')
+         if not os.path.exists(cfg_path):
+             print_log("Failed to resume", logger=logger)
+             raise FileNotFoundError()
+         print_log(f'Resume yaml from {cfg_path}', logger=logger)
+         args.config = cfg_path
+     config = cfg_from_yaml_file(args.config)
+     if not args.resume and args.local_rank == 0:
+         save_experiment_config(args, config, logger)
+     return config
+
+
+ def save_experiment_config(args, config, logger=None):
+     config_path = os.path.join(args.experiment_path, 'config.yaml')
+     os.system('cp %s %s' % (args.config, config_path))
+     print_log(f'Copy the Config file from {args.config} to {config_path}', logger=logger)
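A sketch of how the _base_ mechanism behaves, assuming two hypothetical YAML files:

# Hypothetical YAML pair showing the merge done by cfg_from_yaml_file / merge_new_config.
#
# base.yaml:      model: {NAME: PointTransformer, depth: 12}
# finetune.yaml:  _base_: base.yaml
#                 model: {depth: 24}
#
config = cfg_from_yaml_file("finetune.yaml")
print(config.model.depth)        # 24 (non-_base_ keys are merged normally)
print(config._base_.model.NAME)  # PointTransformer (the base file is loaded under the _base_ key)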
mm_models/modal_module/point/recon/reconv2_utils/data.py ADDED
@@ -0,0 +1,109 @@
+ import numpy as np
+
+
+ def random_rotate_z(pc):
+     # random roate around z axis
+     theta = np.random.uniform(0, 2 * np.pi)
+     R = np.array([[np.cos(theta), -np.sin(theta), 0],
+                   [np.sin(theta), np.cos(theta), 0],
+                   [0, 0, 1]])
+     return np.matmul(pc, R)
+
+
+ def normalize_pc(pc):
+     """ pc: NxC, return NxC """
+     centroid = np.mean(pc, axis=0)
+     pc = pc - centroid
+     m = np.max(np.sqrt(np.sum(pc ** 2, axis=1)))
+     if m < 1e-6:
+         pc = np.zeros_like(pc)
+     else:
+         pc = pc / m
+     return pc
+
+
+ def random_point_dropout(batch_pc, max_dropout_ratio=0.875):
+     """ batch_pc: BxNx3 """
+     for b in range(batch_pc.shape[0]):
+         dropout_ratio = np.random.random() * max_dropout_ratio # 0~0.875
+         drop_idx = np.where(np.random.random((batch_pc.shape[1])) <= dropout_ratio)[0]
+         if len(drop_idx) > 0:
+             batch_pc[b, drop_idx, :] = batch_pc[b, 0, :] # set to the first point
+     return batch_pc
+
+
+ def random_scale_point_cloud(data, scale_low=0.8, scale_high=1.25):
+
+     scales = np.random.uniform(scale_low, scale_high)
+     data *= scales
+     return data
+
+
+ def shift_point_cloud(batch_data, shift_range=0.1):
+     """ Randomly shift point cloud. Shift is per point cloud.
+         Input:
+           BxNx3 array, original batch of point clouds
+         Return:
+           BxNx3 array, shifted batch of point clouds
+     """
+     B, N, C = batch_data.shape
+     shifts = np.random.uniform(-shift_range, shift_range, (B, 3))
+     for batch_index in range(B):
+         batch_data[batch_index, :, :] += shifts[batch_index, :]
+     return batch_data
+
+
+ def rotate_perturbation_point_cloud(batch_data, angle_sigma=0.06, angle_clip=0.18):
+     """ Randomly perturb the point clouds by small rotations
+         Input:
+           BxNx3 array, original batch of point clouds
+         Return:
+           BxNx3 array, rotated batch of point clouds
+     """
+     rotated_data = np.zeros(batch_data.shape, dtype=np.float32)
+     for k in range(batch_data.shape[0]):
+         angles = np.clip(angle_sigma * np.random.randn(3), -angle_clip, angle_clip)
+         Rx = np.array([[1, 0, 0],
+                        [0, np.cos(angles[0]), -np.sin(angles[0])],
+                        [0, np.sin(angles[0]), np.cos(angles[0])]])
+         Ry = np.array([[np.cos(angles[1]), 0, np.sin(angles[1])],
+                        [0, 1, 0],
+                        [-np.sin(angles[1]), 0, np.cos(angles[1])]])
+         Rz = np.array([[np.cos(angles[2]), -np.sin(angles[2]), 0],
+                        [np.sin(angles[2]), np.cos(angles[2]), 0],
+                        [0, 0, 1]])
+         R = np.dot(Rz, np.dot(Ry, Rx))
+         shape_pc = batch_data[k, ...]
+         rotated_data[k, ...] = np.dot(shape_pc.reshape((-1, 3)), R)
+     return rotated_data
+
+
+ def rotate_point_cloud(batch_data):
+     """ Randomly rotate the point clouds to augument the dataset
+         rotation is per shape based along up direction
+         Input:
+           BxNx3 array, original batch of point clouds
+         Return:
+           BxNx3 array, rotated batch of point clouds
+     """
+     rotated_data = np.zeros(batch_data.shape, dtype=np.float32)
+     for k in range(batch_data.shape[0]):
+         rotation_angle = np.random.uniform() * 2 * np.pi
+         cosval = np.cos(rotation_angle)
+         sinval = np.sin(rotation_angle)
+         rotation_matrix = np.array([[cosval, 0, sinval],
+                                     [0, 1, 0],
+                                     [-sinval, 0, cosval]])
+         shape_pc = batch_data[k, ...]
+         rotated_data[k, ...] = np.dot(shape_pc.reshape((-1, 3)), rotation_matrix)
+     return rotated_data
+
+
+ def augment_pc(data):
+     # data = random_point_dropout(data[None, ...])
+     data = random_scale_point_cloud(data[None, ...])
+     data = shift_point_cloud(data)
+     data = rotate_perturbation_point_cloud(data)
+     data = rotate_point_cloud(data)
+     data = data.squeeze()
+     return data
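A sketch chaining the helpers above on a random cloud (sizes are arbitrary):

# Normalize then augment a random N x 3 point cloud.
import numpy as np

pc = np.random.rand(8192, 3).astype(np.float32)
pc = normalize_pc(pc)     # center at the origin, scale into the unit sphere
pc = augment_pc(pc)       # random scale, shift, small perturbations, full rotation
pc = random_rotate_z(pc)  # optional extra rotation about the z axis
print(pc.shape)           # (8192, 3)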
mm_models/modal_module/point/recon/reconv2_utils/dist_utils.py ADDED
@@ -0,0 +1,49 @@
+ import torch
+ from torch import distributed as dist
+
+
+ def init_dist(local_rank, backend='nccl', **kwargs):
+     torch.cuda.set_device(local_rank)
+     dist.init_process_group(backend=backend, **kwargs)
+     print(f'init distributed in rank {local_rank}')
+
+
+ def reduce_tensor(tensor, args):
+     '''
+         for acc kind, get the mean in each gpu
+     '''
+     rt = tensor.clone()
+     torch.distributed.all_reduce(rt, op=torch.distributed.ReduceOp.SUM)
+     rt /= args.world_size
+     return rt
+
+
+ def gather_tensor(tensor, args):
+     output_tensors = [tensor.clone() for _ in range(args.world_size)]
+     torch.distributed.all_gather(output_tensors, tensor)
+     concat = torch.cat(output_tensors, dim=0)
+     return concat
+
+
+ def set_batch_size(args, config):
+     if args.distributed:
+         assert config.total_bs % args.world_size == 0
+         if config.dataset.get('train'):
+             config.dataset.train.others.bs = config.total_bs // args.world_size
+         if config.dataset.get('extra_train'):
+             config.dataset.extra_train.others.bs = config.total_bs // args.world_size
+         if config.dataset.get('val'):
+             config.dataset.val.others.bs = config.total_bs // args.world_size
+         if config.dataset.get('test'):
+             config.dataset.test.others.bs = config.total_bs // args.world_size
+     else:
+         if config.dataset.get('train'):
+             config.dataset.train.others.bs = config.total_bs
+         if config.dataset.get('extra_train'):
+             config.dataset.extra_train.others.bs = config.total_bs
+         if config.dataset.get('extra_val'):
+             config.dataset.extra_val.others.bs = config.total_bs
+         if config.dataset.get('val'):
+             config.dataset.val.others.bs = config.total_bs
+         if config.dataset.get('test'):
+             config.dataset.test.others.bs = config.total_bs
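A sketch of how set_batch_size splits total_bs across ranks, using toy EasyDict objects in place of the real args/config:

# Toy args/config; only the fields read by set_batch_size are filled in.
from easydict import EasyDict

args = EasyDict(distributed=True, world_size=4)
config = EasyDict(total_bs=32, dataset=EasyDict(train=EasyDict(others=EasyDict(bs=None))))
set_batch_size(args, config)
print(config.dataset.train.others.bs)  # 8  (32 split over 4 ranks)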
mm_models/modal_module/point/recon/reconv2_utils/knn.py ADDED
@@ -0,0 +1,37 @@
+ import torch
+
+
+ def square_distance(src, dst):
+     """
+     Calculate Euclid distance between each two points.
+     src^T * dst = xn * xm + yn * ym + zn * zm;
+     sum(src^2, dim=-1) = xn*xn + yn*yn + zn*zn;
+     sum(dst^2, dim=-1) = xm*xm + ym*ym + zm*zm;
+     dist = (xn - xm)^2 + (yn - ym)^2 + (zn - zm)^2
+          = sum(src**2,dim=-1)+sum(dst**2,dim=-1)-2*src^T*dst
+     Input:
+         src: source points, [B, N, C]
+         dst: target points, [B, M, C]
+     Output:
+         dist: per-point square distance, [B, N, M]
+     """
+     B, N, _ = src.shape
+     _, M, _ = dst.shape
+     dist = -2 * torch.matmul(src, dst.permute(0, 2, 1))
+     dist += torch.sum(src ** 2, -1).view(B, N, 1)
+     dist += torch.sum(dst ** 2, -1).view(B, 1, M)
+     return dist
+
+
+ def knn_point(nsample, xyz, new_xyz):
+     """
+     Input:
+         nsample: max sample number in local region
+         xyz: all points, [B, N, C]
+         new_xyz: query points, [B, S, C]
+     Return:
+         group_idx: grouped points index, [B, S, nsample]
+     """
+     sqrdists = square_distance(new_xyz, xyz)
+     _, group_idx = torch.topk(sqrdists, nsample, dim=-1, largest=False, sorted=False)
+     return group_idx
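A minimal usage sketch for knn_point on random tensors:

# k-nearest-neighbour indices for random query centers (shapes are arbitrary).
import torch

xyz = torch.rand(2, 1024, 3)     # B x N x C source points
new_xyz = torch.rand(2, 128, 3)  # B x S x C query points
idx = knn_point(16, xyz, new_xyz)
print(idx.shape)                 # torch.Size([2, 128, 16])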
mm_models/modal_module/point/recon/reconv2_utils/logger.py ADDED
@@ -0,0 +1,127 @@
+ import logging
+ import torch.distributed as dist
+
+ logger_initialized = {}
+
+ def get_root_logger(log_file=None, log_level=logging.INFO, name='main'):
+     """Get root logger and add a keyword filter to it.
+     The logger will be initialized if it has not been initialized. By default a
+     StreamHandler will be added. If `log_file` is specified, a FileHandler will
+     also be added. The name of the root logger is the top-level package name,
+     e.g., "mmdet3d".
+     Args:
+         log_file (str, optional): File path of log. Defaults to None.
+         log_level (int, optional): The level of logger.
+             Defaults to logging.INFO.
+         name (str, optional): The name of the root logger, also used as a
+             filter keyword. Defaults to 'mmdet3d'.
+     Returns:
+         :obj:`logging.Logger`: The obtained logger
+     """
+     logger = get_logger(name=name, log_file=log_file, log_level=log_level)
+     # add a logging filter
+     logging_filter = logging.Filter(name)
+     logging_filter.filter = lambda record: record.find(name) != -1
+
+     return logger
+
+
+ def get_logger(name, log_file=None, log_level=logging.INFO, file_mode='w'):
+     """Initialize and get a logger by name.
+     If the logger has not been initialized, this method will initialize the
+     logger by adding one or two handlers, otherwise the initialized logger will
+     be directly returned. During initialization, a StreamHandler will always be
+     added. If `log_file` is specified and the process rank is 0, a FileHandler
+     will also be added.
+     Args:
+         name (str): Logger name.
+         log_file (str | None): The log filename. If specified, a FileHandler
+             will be added to the logger.
+         log_level (int): The logger level. Note that only the process of
+             rank 0 is affected, and other processes will set the level to
+             "Error" thus be silent most of the time.
+         file_mode (str): The file mode used in opening log file.
+             Defaults to 'w'.
+     Returns:
+         logging.Logger: The expected logger.
+     """
+     logger = logging.getLogger(name)
+     if name in logger_initialized:
+         return logger
+     # handle hierarchical names
+     # e.g., logger "a" is initialized, then logger "a.b" will skip the
+     # initialization since it is a child of "a".
+     for logger_name in logger_initialized:
+         if name.startswith(logger_name):
+             return logger
+
+     # handle duplicate logs to the console
+     # Starting in 1.8.0, PyTorch DDP attaches a StreamHandler <stderr> (NOTSET)
+     # to the root logger. As logger.propagate is True by default, this root
+     # level handler causes logging messages from rank>0 processes to
+     # unexpectedly show up on the console, creating much unwanted clutter.
+     # To fix this issue, we set the root logger's StreamHandler, if any, to log
+     # at the ERROR level.
+     for handler in logger.root.handlers:
+         if type(handler) is logging.StreamHandler:
+             handler.setLevel(logging.ERROR)
+
+     stream_handler = logging.StreamHandler()
+     handlers = [stream_handler]
+
+     if dist.is_available() and dist.is_initialized():
+         rank = dist.get_rank()
+     else:
+         rank = 0
+
+     # only rank 0 will add a FileHandler
+     if rank == 0 and log_file is not None:
+         # Here, the default behaviour of the official logger is 'a'. Thus, we
+         # provide an interface to change the file mode to the default
+         # behaviour.
+         file_handler = logging.FileHandler(log_file, file_mode)
+         handlers.append(file_handler)
+
+     formatter = logging.Formatter(
+         '%(asctime)s - %(name)s - %(levelname)s - %(message)s')
+     for handler in handlers:
+         handler.setFormatter(formatter)
+         handler.setLevel(log_level)
+         logger.addHandler(handler)
+
+     if rank == 0:
+         logger.setLevel(log_level)
+     else:
+         logger.setLevel(logging.ERROR)
+
+     logger_initialized[name] = True
+
+
+     return logger
+
+
+ def print_log(msg, logger=None, level=logging.INFO):
+     """Print a log message.
+     Args:
+         msg (str): The message to be logged.
+         logger (logging.Logger | str | None): The logger to be used.
+             Some special loggers are:
+             - "silent": no message will be printed.
+             - other str: the logger obtained with `get_root_logger(logger)`.
+             - None: The `print()` method will be used to print log messages.
+         level (int): Logging level. Only available when `logger` is a Logger
+             object or "root".
+     """
+     if logger is None:
+         print(msg)
+     elif isinstance(logger, logging.Logger):
+         logger.log(level, msg)
+     elif logger == 'silent':
+         pass
+     elif isinstance(logger, str):
+         _logger = get_logger(logger)
+         _logger.log(level, msg)
+     else:
+         raise TypeError(
+             'logger should be either a logging.Logger object, str, '
+             f'"silent" or None, but got {type(logger)}')
mm_models/modal_module/point/recon/reconv2_utils/misc.py ADDED
@@ -0,0 +1,294 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import numpy as np
2
+ import matplotlib.pyplot as plt
3
+ from mpl_toolkits.mplot3d import Axes3D
4
+ import random
5
+ import torch
6
+ import torch.nn as nn
7
+ import torch.nn.functional as F
8
+ import os
9
+ from collections import abc
10
+ # from pointnet2_ops import pointnet2_utils
11
+
12
+
13
+ # def fps(data, number):
14
+ # '''
15
+ # data B N 3
16
+ # number int
17
+ # '''
18
+ # fps_idx = pointnet2_utils.furthest_point_sample(data, number)
19
+ # fps_data = pointnet2_utils.gather_operation(data.transpose(1, 2).contiguous(), fps_idx).transpose(1, 2).contiguous()
20
+ # return fps_data
21
+
22
+ def index_points(points, idx):
23
+ """
24
+ Input:
25
+ points: input points data, [B, N, C]
26
+ idx: sample index data, [B, S]
27
+ Return:
28
+ new_points:, indexed points data, [B, S, C]
29
+ """
30
+ device = points.device
31
+ B = points.shape[0]
32
+ view_shape = list(idx.shape)
33
+ view_shape[1:] = [1] * (len(view_shape) - 1)
34
+ repeat_shape = list(idx.shape)
35
+ repeat_shape[0] = 1
36
+ batch_indices = torch.arange(B, dtype=torch.long).to(device).view(view_shape).repeat(repeat_shape)
37
+ new_points = points[batch_indices, idx, :]
38
+ return new_points
39
+
40
+ def fps(point_data, npoint):
41
+ """
42
+ Input:
43
+ xyz: pointcloud data, [B, N, 3]
44
+ npoint: number of samples
45
+ Return:
46
+ centroids: sampled pointcloud index, [B, npoint]
47
+ """
48
+ xyz = point_data[:, :, :3]
49
+ device = xyz.device
50
+ B, N, C = xyz.shape
51
+ centroids = torch.zeros(B, npoint, dtype=torch.long).to(device)
52
+ distance = torch.ones(B, N).to(device) * 1e10
53
+ farthest = torch.randint(0, N, (B,), dtype=torch.long).to(device)
54
+ batch_indices = torch.arange(B, dtype=torch.long).to(device)
55
+ for i in range(npoint):
56
+ centroids[:, i] = farthest
57
+ centroid = xyz[batch_indices, farthest, :].view(B, 1, 3)
58
+ dist = torch.sum((xyz - centroid) ** 2, -1)
59
+ distance = torch.min(distance, dist)
60
+ farthest = torch.max(distance, -1)[1]
61
+ return index_points(point_data, centroids)
62
+
63
+
64
+ def worker_init_fn(worker_id):
65
+ np.random.seed(np.random.get_state()[1][0] + worker_id)
66
+
67
+
68
+ def build_lambda_sche(opti, config):
69
+ if config.get('decay_step') is not None:
70
+ lr_lbmd = lambda e: max(config.lr_decay ** (e / config.decay_step), config.lowest_decay)
71
+ scheduler = torch.optim.lr_scheduler.LambdaLR(opti, lr_lbmd)
72
+ else:
73
+ raise NotImplementedError()
74
+ return scheduler
75
+
76
+
77
+ def build_lambda_bnsche(model, config):
78
+ if config.get('decay_step') is not None:
79
+ bnm_lmbd = lambda e: max(config.bn_momentum * config.bn_decay ** (e / config.decay_step), config.lowest_decay)
80
+ bnm_scheduler = BNMomentumScheduler(model, bnm_lmbd)
81
+ else:
82
+ raise NotImplementedError()
83
+ return bnm_scheduler
84
+
85
+
86
+ def set_random_seed(seed, deterministic=False):
87
+ """Set random seed.
88
+ Args:
89
+ seed (int): Seed to be used.
90
+ deterministic (bool): Whether to set the deterministic option for
91
+ CUDNN backend, i.e., set `torch.backends.cudnn.deterministic`
92
+ to True and `torch.backends.cudnn.benchmark` to False.
93
+ Default: False.
94
+
95
+ # Speed-reproducibility tradeoff https://pytorch.org/docs/stable/notes/randomness.html
96
+ if cuda_deterministic: # slower, more reproducible
97
+ cudnn.deterministic = True
98
+ cudnn.benchmark = False
99
+ else: # faster, less reproducible
100
+ cudnn.deterministic = False
101
+ cudnn.benchmark = True
102
+
103
+ """
104
+ random.seed(seed)
105
+ np.random.seed(seed)
106
+ torch.manual_seed(seed)
107
+ torch.cuda.manual_seed_all(seed)
108
+ if deterministic:
109
+ torch.backends.cudnn.deterministic = True
110
+ torch.backends.cudnn.benchmark = False
111
+
112
+
113
+ def is_seq_of(seq, expected_type, seq_type=None):
114
+ """Check whether it is a sequence of some type.
115
+ Args:
116
+ seq (Sequence): The sequence to be checked.
117
+ expected_type (type): Expected type of sequence items.
118
+ seq_type (type, optional): Expected sequence type.
119
+ Returns:
120
+ bool: Whether the sequence is valid.
121
+ """
122
+ if seq_type is None:
123
+ exp_seq_type = abc.Sequence
124
+ else:
125
+ assert isinstance(seq_type, type)
126
+ exp_seq_type = seq_type
127
+ if not isinstance(seq, exp_seq_type):
128
+ return False
129
+ for item in seq:
130
+ if not isinstance(item, expected_type):
131
+ return False
132
+ return True
133
+
134
+
135
+ def set_bn_momentum_default(bn_momentum):
136
+ def fn(m):
137
+ if isinstance(m, (nn.BatchNorm1d, nn.BatchNorm2d, nn.BatchNorm3d)):
138
+ m.momentum = bn_momentum
139
+
140
+ return fn
141
+
142
+
143
+ class BNMomentumScheduler(object):
144
+
145
+ def __init__(
146
+ self, model, bn_lambda, last_epoch=-1,
147
+ setter=set_bn_momentum_default
148
+ ):
149
+ if not isinstance(model, nn.Module):
150
+ raise RuntimeError(
151
+ "Class '{}' is not a PyTorch nn Module".format(
152
+ type(model).__name__
153
+ )
154
+ )
155
+
156
+ self.model = model
157
+ self.setter = setter
158
+ self.lmbd = bn_lambda
159
+
160
+ self.step(last_epoch + 1)
161
+ self.last_epoch = last_epoch
162
+
163
+ def step(self, epoch=None):
164
+ if epoch is None:
165
+ epoch = self.last_epoch + 1
166
+
167
+ self.last_epoch = epoch
168
+ self.model.apply(self.setter(self.lmbd(epoch)))
169
+
170
+ def get_momentum(self, epoch=None):
171
+ if epoch is None:
172
+ epoch = self.last_epoch + 1
173
+ return self.lmbd(epoch)
174
+
175
+
176
+ def seprate_point_cloud(xyz, num_points, crop, fixed_points=None, padding_zeros=False):
177
+ '''
178
+ Separate point cloud: used to generate an incomplete point cloud with a given number of cropped points.
179
+ '''
180
+ _, n, c = xyz.shape
181
+
182
+ assert n == num_points
183
+ assert c == 3
184
+ if crop == num_points:
185
+ return xyz, None
186
+
187
+ INPUT = []
188
+ CROP = []
189
+ for points in xyz:
190
+ if isinstance(crop, list):
191
+ num_crop = random.randint(crop[0], crop[1])
192
+ else:
193
+ num_crop = crop
194
+
195
+ points = points.unsqueeze(0)
196
+
197
+ if fixed_points is None:
198
+ center = F.normalize(torch.randn(1, 1, 3), p=2, dim=-1).cuda()
199
+ else:
200
+ if isinstance(fixed_points, list):
201
+ fixed_point = random.sample(fixed_points, 1)[0]
202
+ else:
203
+ fixed_point = fixed_points
204
+ center = fixed_point.reshape(1, 1, 3).cuda()
205
+
206
+ distance_matrix = torch.norm(center.unsqueeze(2) - points.unsqueeze(1), p=2, dim=-1) # 1 1 2048
207
+
208
+ idx = torch.argsort(distance_matrix, dim=-1, descending=False)[0, 0] # 2048
209
+
210
+ if padding_zeros:
211
+ input_data = points.clone()
212
+ input_data[0, idx[:num_crop]] = input_data[0, idx[:num_crop]] * 0
213
+
214
+ else:
215
+ input_data = points.clone()[0, idx[num_crop:]].unsqueeze(0) # 1 N 3
216
+
217
+ crop_data = points.clone()[0, idx[:num_crop]].unsqueeze(0)
218
+
219
+ if isinstance(crop, list):
220
+ INPUT.append(fps(input_data, 2048))
221
+ CROP.append(fps(crop_data, 2048))
222
+ else:
223
+ INPUT.append(input_data)
224
+ CROP.append(crop_data)
225
+
226
+ input_data = torch.cat(INPUT, dim=0) # B N 3
227
+ crop_data = torch.cat(CROP, dim=0) # B M 3
228
+
229
+ return input_data.contiguous(), crop_data.contiguous()
230
+
231
+
232
+ def get_ptcloud_img(ptcloud, roll, pitch):
233
+ fig = plt.figure(figsize=(8, 8))
234
+
235
+ x, z, y = ptcloud.transpose(1, 0)
236
+ ax = fig.add_subplot(projection=Axes3D.name, adjustable='box')  # fig.gca(projection=...) was removed in Matplotlib 3.6
237
+ ax.axis('off')
238
+ # ax.axis('scaled')
239
+ ax.view_init(roll, pitch)
240
+ max, min = np.max(ptcloud), np.min(ptcloud)
241
+ ax.set_xbound(min, max)
242
+ ax.set_ybound(min, max)
243
+ ax.set_zbound(min, max)
244
+ ax.scatter(x, y, z, zdir='z', c=y, cmap='jet')
245
+
246
+ fig.canvas.draw()
247
+ img = np.frombuffer(fig.canvas.tostring_rgb(), dtype=np.uint8)  # np.fromstring is deprecated for binary buffers
248
+ img = img.reshape(fig.canvas.get_width_height()[::-1] + (3,))
249
+ return img
250
+
251
+
252
+ def visualize_KITTI(path, data_list, titles=['input', 'pred'], cmap=['bwr', 'autumn'], zdir='y',
253
+ xlim=(-1, 1), ylim=(-1, 1), zlim=(-1, 1)):
254
+ fig = plt.figure(figsize=(6 * len(data_list), 6))
255
+ cmax = data_list[-1][:, 0].max()
256
+
257
+ for i in range(len(data_list)):
258
+ data = data_list[i][:-2048] if i == 1 else data_list[i]
259
+ color = data[:, 0] / cmax
260
+ ax = fig.add_subplot(1, len(data_list), i + 1, projection='3d')
261
+ ax.view_init(30, -120)
262
+ b = ax.scatter(data[:, 0], data[:, 1], data[:, 2], zdir=zdir, c=color, vmin=-1, vmax=1, cmap=cmap[0], s=4,
263
+ linewidth=0.05, edgecolors='black')
264
+ ax.set_title(titles[i])
265
+
266
+ ax.set_axis_off()
267
+ ax.set_xlim(xlim)
268
+ ax.set_ylim(ylim)
269
+ ax.set_zlim(zlim)
270
+ plt.subplots_adjust(left=0, right=1, bottom=0, top=1, wspace=0.2, hspace=0)
271
+ if not os.path.exists(path):
272
+ os.makedirs(path)
273
+
274
+ pic_path = path + '.png'
275
+ fig.savefig(pic_path)
276
+
277
+ np.save(os.path.join(path, 'input.npy'), data_list[0].numpy())
278
+ np.save(os.path.join(path, 'pred.npy'), data_list[1].numpy())
279
+ plt.close(fig)
280
+
281
+
282
+ def random_dropping(pc, e):
283
+ up_num = max(64, 768 // (e // 50 + 1))
284
+ pc = pc
285
+ random_num = torch.randint(1, up_num, (1, 1))[0, 0]
286
+ pc = fps(pc, random_num)
287
+ padding = torch.zeros(pc.size(0), 2048 - pc.size(1), 3).to(pc.device)
288
+ pc = torch.cat([pc, padding], dim=1)
289
+ return pc
290
+
291
+
292
+ def random_scale(partial, scale_range=[0.8, 1.2]):
293
+ scale = torch.rand(1).cuda() * (scale_range[1] - scale_range[0]) + scale_range[0]
294
+ return partial * scale
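For orientation, a minimal usage sketch of the sampling helpers above. This is a sketch only: it assumes the functions of this misc module (including the earlier `index_points`) are in scope, and it runs the crop on the GPU because `seprate_point_cloud` moves its crop centers to CUDA.

    import torch

    # Minimal sketch: 4 clouds of 2048 xyz points on the GPU.
    pc = torch.rand(4, 2048, 3).cuda()

    pc_1024 = fps(pc, 1024)                                     # farthest point sampling -> [4, 1024, 3]
    partial, cropped = seprate_point_cloud(pc, 2048, crop=512)  # keep 1536 points, crop 512 around a random center
    print(pc_1024.shape, partial.shape, cropped.shape)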
mm_models/modal_module/point/recon/reconv2_utils/parser.py ADDED
@@ -0,0 +1,117 @@
1
+ import os
2
+ import argparse
3
+ from pathlib import Path
4
+
5
+
6
+ def get_args():
7
+ parser = argparse.ArgumentParser()
8
+ parser.add_argument(
9
+ '--config',
10
+ type=str,
11
+ help='yaml config file')
12
+ parser.add_argument('--distributed', action='store_true', default=False)
13
+ parser.add_argument('--local-rank', type=int, default=0)
14
+ parser.add_argument('--num_workers', type=int, default=8)
15
+ # seed
16
+ parser.add_argument('--seed', type=int, default=0, help='random seed')
17
+ parser.add_argument(
18
+ '--deterministic',
19
+ action='store_true',
20
+ help='whether to set deterministic options for CUDNN backend.')
21
+ # bn
22
+ parser.add_argument(
23
+ '--sync_bn',
24
+ action='store_true',
25
+ default=False,
26
+ help='whether to use sync bn')
27
+ # some args
28
+ parser.add_argument('--exp_name', type=str, default='default', help='experiment name')
29
+ parser.add_argument('--start_ckpts', type=str, default=None, help='reload used ckpt path')
30
+ parser.add_argument('--ckpts', type=str, default=None, help='test used ckpt path')
31
+ parser.add_argument('--val_freq', type=int, default=1, help='test freq')
32
+ parser.add_argument(
33
+ '--vote',
34
+ action='store_true',
35
+ default=False,
36
+ help='vote acc')
37
+ parser.add_argument(
38
+ '--resume',
39
+ action='store_true',
40
+ default=False,
41
+ help='autoresume training (interrupted by accident)')
42
+ parser.add_argument(
43
+ '--svm',
44
+ action='store_true',
45
+ default=False,
46
+ help='svm')
47
+ parser.add_argument(
48
+ '--zeroshot',
49
+ action='store_true',
50
+ default=False,
51
+ help='zero-shot')
52
+ parser.add_argument(
53
+ '--test',
54
+ action='store_true',
55
+ default=False,
56
+ help='test mode for certain ckpt')
57
+ parser.add_argument(
58
+ '--reconstruct',
59
+ action='store_true',
60
+ default=False,
61
+ help='reconstruct pretraining stage')
62
+ parser.add_argument(
63
+ '--contrast',
64
+ action='store_true',
65
+ default=False,
66
+ help='contrast pretraining stage')
67
+ parser.add_argument(
68
+ '--finetune_model',
69
+ action='store_true',
70
+ default=False,
71
+ help='finetune modelnet with pretrained weight')
72
+ parser.add_argument(
73
+ '--way', type=int, default=-1)
74
+ parser.add_argument(
75
+ '--shot', type=int, default=-1)
76
+ parser.add_argument(
77
+ '--fold', type=int, default=-1)
78
+
79
+ args = parser.parse_args()
80
+
81
+ if args.test and args.resume:
82
+ raise ValueError(
83
+ '--test and --resume cannot be both activate')
84
+
85
+ if args.resume and args.start_ckpts is not None:
86
+ raise ValueError(
87
+ '--resume and --start_ckpts cannot be both activate')
88
+
89
+ if args.test and args.ckpts is None:
90
+ raise ValueError(
91
+ 'ckpts should not be None in test mode')
92
+
93
+ if args.finetune_model and args.ckpts is None:
94
+ print(
95
+ 'training from scratch')
96
+
97
+ if 'LOCAL_RANK' not in os.environ:
98
+ os.environ['LOCAL_RANK'] = str(args.local_rank)
99
+
100
+ if args.test:
101
+ args.exp_name = 'test_' + args.exp_name
102
+ args.experiment_path = os.path.join('./experiments', Path(args.config).stem, Path(args.config).parent.stem,
103
+ args.exp_name)
104
+ args.tfboard_path = os.path.join('./experiments', Path(args.config).stem, Path(args.config).parent.stem, 'TFBoard',
105
+ args.exp_name)
106
+ args.log_name = Path(args.config).stem
107
+ create_experiment_dir(args)
108
+ return args
109
+
110
+
111
+ def create_experiment_dir(args):
112
+ if not os.path.exists(args.experiment_path):
113
+ os.makedirs(args.experiment_path, exist_ok=True)
114
+ print('Create experiment path successfully at %s' % args.experiment_path)
115
+ if not os.path.exists(args.tfboard_path):
116
+ os.makedirs(args.tfboard_path, exist_ok=True)
117
+ print('Create TFBoard path successfully at %s' % args.tfboard_path)
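A quick sketch of how this parser is typically driven. The config path below is a placeholder, not a file shipped with this repository.

    import sys

    # Hypothetical invocation: patch argv as a training script entry point would receive it.
    sys.argv = ['train.py', '--config', 'cfgs/pretrain.yaml', '--exp_name', 'debug']
    args = get_args()   # validates flag combinations, then creates ./experiments/... and the TFBoard dir
    print(args.experiment_path, args.tfboard_path, args.log_name)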
mm_models/modal_module/point/recon/reconv2_utils/randaugment.py ADDED
@@ -0,0 +1,216 @@
1
+ import logging
2
+ import random
3
+
4
+ import numpy as np
5
+ import PIL
6
+ import PIL.ImageOps
7
+ import PIL.ImageEnhance
8
+ import PIL.ImageDraw
9
+ from PIL import Image
10
+
11
+ logger = logging.getLogger(__name__)
12
+
13
+ PARAMETER_MAX = 10
14
+
15
+
16
+ def AutoContrast(img, **kwarg):
17
+ return PIL.ImageOps.autocontrast(img)
18
+
19
+
20
+ def Brightness(img, v, max_v, bias=0):
21
+ v = _float_parameter(v, max_v) + bias
22
+ return PIL.ImageEnhance.Brightness(img).enhance(v)
23
+
24
+
25
+ def Color(img, v, max_v, bias=0):
26
+ v = _float_parameter(v, max_v) + bias
27
+ return PIL.ImageEnhance.Color(img).enhance(v)
28
+
29
+
30
+ def Contrast(img, v, max_v, bias=0):
31
+ v = _float_parameter(v, max_v) + bias
32
+ return PIL.ImageEnhance.Contrast(img).enhance(v)
33
+
34
+
35
+ def Cutout(img, v, max_v, bias=0):
36
+ if v == 0:
37
+ return img
38
+ v = _float_parameter(v, max_v) + bias
39
+ v = int(v * min(img.size))
40
+ return CutoutAbs(img, v)
41
+
42
+
43
+ def CutoutAbs(img, v, **kwarg):
44
+ w, h = img.size
45
+ x0 = np.random.uniform(0, w)
46
+ y0 = np.random.uniform(0, h)
47
+ x0 = int(max(0, x0 - v / 2.))
48
+ y0 = int(max(0, y0 - v / 2.))
49
+ x1 = int(min(w, x0 + v))
50
+ y1 = int(min(h, y0 + v))
51
+ xy = (x0, y0, x1, y1)
52
+ # gray
53
+ color = (127, 127, 127)
54
+ img = img.copy()
55
+ PIL.ImageDraw.Draw(img).rectangle(xy, color)
56
+ return img
57
+
58
+
59
+ def Equalize(img, **kwarg):
60
+ return PIL.ImageOps.equalize(img)
61
+
62
+
63
+ def Identity(img, **kwarg):
64
+ return img
65
+
66
+
67
+ def Invert(img, **kwarg):
68
+ return PIL.ImageOps.invert(img)
69
+
70
+
71
+ def Posterize(img, v, max_v, bias=0):
72
+ v = _int_parameter(v, max_v) + bias
73
+ return PIL.ImageOps.posterize(img, v)
74
+
75
+
76
+ def Rotate(img, v, max_v, bias=0):
77
+ v = _int_parameter(v, max_v) + bias
78
+ if random.random() < 0.5:
79
+ v = -v
80
+ return img.rotate(v)
81
+
82
+
83
+ def Sharpness(img, v, max_v, bias=0):
84
+ v = _float_parameter(v, max_v) + bias
85
+ return PIL.ImageEnhance.Sharpness(img).enhance(v)
86
+
87
+
88
+ def ShearX(img, v, max_v, bias=0):
89
+ v = _float_parameter(v, max_v) + bias
90
+ if random.random() < 0.5:
91
+ v = -v
92
+ return img.transform(img.size, PIL.Image.AFFINE, (1, v, 0, 0, 1, 0))
93
+
94
+
95
+ def ShearY(img, v, max_v, bias=0):
96
+ v = _float_parameter(v, max_v) + bias
97
+ if random.random() < 0.5:
98
+ v = -v
99
+ return img.transform(img.size, PIL.Image.AFFINE, (1, 0, 0, v, 1, 0))
100
+
101
+
102
+ def Solarize(img, v, max_v, bias=0):
103
+ v = _int_parameter(v, max_v) + bias
104
+ return PIL.ImageOps.solarize(img, 256 - v)
105
+
106
+
107
+ def SolarizeAdd(img, v, max_v, bias=0, threshold=128):
108
+ v = _int_parameter(v, max_v) + bias
109
+ if random.random() < 0.5:
110
+ v = -v
111
+ img_np = np.array(img).astype(int)  # np.int was removed in NumPy 1.24
112
+ img_np = img_np + v
113
+ img_np = np.clip(img_np, 0, 255)
114
+ img_np = img_np.astype(np.uint8)
115
+ img = Image.fromarray(img_np)
116
+ return PIL.ImageOps.solarize(img, threshold)
117
+
118
+
119
+ def TranslateX(img, v, max_v, bias=0):
120
+ v = _float_parameter(v, max_v) + bias
121
+ if random.random() < 0.5:
122
+ v = -v
123
+ v = int(v * img.size[0])
124
+ return img.transform(img.size, PIL.Image.AFFINE, (1, 0, v, 0, 1, 0))
125
+
126
+
127
+ def TranslateY(img, v, max_v, bias=0):
128
+ v = _float_parameter(v, max_v) + bias
129
+ if random.random() < 0.5:
130
+ v = -v
131
+ v = int(v * img.size[1])
132
+ return img.transform(img.size, PIL.Image.AFFINE, (1, 0, 0, 0, 1, v))
133
+
134
+
135
+ def _float_parameter(v, max_v):
136
+ return float(v) * max_v / PARAMETER_MAX
137
+
138
+
139
+ def _int_parameter(v, max_v):
140
+ return int(v * max_v / PARAMETER_MAX)
141
+
142
+
143
+ def fixmatch_augment_pool():
144
+ # FixMatch paper
145
+ augs = [(AutoContrast, None, None),
146
+ (Brightness, 0.9, 0.05),
147
+ (Color, 0.9, 0.05),
148
+ (Contrast, 0.9, 0.05),
149
+ (Equalize, None, None),
150
+ (Identity, None, None),
151
+ (Posterize, 4, 4),
152
+ (Rotate, 30, 0),
153
+ (Sharpness, 0.9, 0.05),
154
+ (ShearX, 0.3, 0),
155
+ (ShearY, 0.3, 0),
156
+ (Solarize, 256, 0),
157
+ (TranslateX, 0.3, 0),
158
+ (TranslateY, 0.3, 0)]
159
+ return augs
160
+
161
+
162
+ def my_augment_pool():
163
+ # Test
164
+ augs = [(AutoContrast, None, None),
165
+ (Brightness, 1.8, 0.1),
166
+ (Color, 1.8, 0.1),
167
+ (Contrast, 1.8, 0.1),
168
+ (Cutout, 0.2, 0),
169
+ (Equalize, None, None),
170
+ (Invert, None, None),
171
+ (Posterize, 4, 4),
172
+ (Rotate, 30, 0),
173
+ (Sharpness, 1.8, 0.1),
174
+ (ShearX, 0.3, 0),
175
+ (ShearY, 0.3, 0),
176
+ (Solarize, 256, 0),
177
+ (SolarizeAdd, 110, 0),
178
+ (TranslateX, 0.45, 0),
179
+ (TranslateY, 0.45, 0)]
180
+ return augs
181
+
182
+
183
+ class RandAugmentPC(object):
184
+ def __init__(self, n, m):
185
+ assert n >= 1
186
+ assert 1 <= m <= 10
187
+ self.n = n
188
+ self.m = m
189
+ self.augment_pool = my_augment_pool()
190
+
191
+ def __call__(self, img):
192
+ ops = random.choices(self.augment_pool, k=self.n)
193
+ for op, max_v, bias in ops:
194
+ prob = np.random.uniform(0.2, 0.8)
195
+ if random.random() + prob >= 1:
196
+ img = op(img, v=self.m, max_v=max_v, bias=bias)
197
+ img = CutoutAbs(img, int(32*0.5))
198
+ return img
199
+
200
+
201
+ class RandAugmentMC(object):
202
+ def __init__(self, n, m):
203
+ assert n >= 1
204
+ assert 1 <= m <= 10
205
+ self.n = n
206
+ self.m = m
207
+ self.augment_pool = fixmatch_augment_pool()
208
+
209
+ def __call__(self, img):
210
+ ops = random.choices(self.augment_pool, k=self.n)
211
+ for op, max_v, bias in ops:
212
+ v = np.random.randint(1, self.m)
213
+ if random.random() < 0.5:
214
+ img = op(img, v=v, max_v=max_v, bias=bias)
215
+ img = CutoutAbs(img, int(32*0.5))
216
+ return img
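A small usage sketch of the augmentation wrappers above; 'photo.jpg' is a placeholder path.

    from PIL import Image

    img = Image.open('photo.jpg').convert('RGB')   # placeholder input image
    strong = RandAugmentMC(n=2, m=10)              # 2 random ops from the FixMatch pool, magnitude <= 10
    aug = strong(img)                              # ends with a fixed 16-pixel CutoutAbs
    aug.save('photo_aug.jpg')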
mm_models/modal_module/point/recon/reconv2_utils/registry.py ADDED
@@ -0,0 +1,289 @@
1
+ import inspect
2
+ import warnings
3
+ from functools import partial
4
+ from . import config, misc  # local copies in reconv2_utils; misc provides is_seq_of used below
5
+
6
+
7
+ class Registry:
8
+ """A registry to map strings to classes.
9
+ Registered object could be built from registry.
10
+ Example:
11
+ >>> MODELS = Registry('models')
12
+ >>> @MODELS.register_module()
13
+ >>> class ResNet:
14
+ >>> pass
15
+ >>> resnet = MODELS.build(dict(NAME='ResNet'))
16
+ Please refer to https://mmcv.readthedocs.io/en/latest/registry.html for
17
+ advanced usage.
18
+ Args:
19
+ name (str): Registry name.
20
+ build_func(func, optional): Build function to construct instance from
21
+ Registry; :func:`build_from_cfg` is used if neither ``parent`` nor
22
+ ``build_func`` is specified. If ``parent`` is specified and
23
+ ``build_func`` is not given, ``build_func`` will be inherited
24
+ from ``parent``. Default: None.
25
+ parent (Registry, optional): Parent registry. The class registered in
26
+ children registry could be built from parent. Default: None.
27
+ scope (str, optional): The scope of registry. It is the key to search
28
+ for children registry. If not specified, scope will be the name of
29
+ the package where class is defined, e.g. mmdet, mmcls, mmseg.
30
+ Default: None.
31
+ """
32
+
33
+ def __init__(self, name, build_func=None, parent=None, scope=None):
34
+ self._name = name
35
+ self._module_dict = dict()
36
+ self._children = dict()
37
+ self._scope = self.infer_scope() if scope is None else scope
38
+
39
+ # self.build_func will be set with the following priority:
40
+ # 1. build_func
41
+ # 2. parent.build_func
42
+ # 3. build_from_cfg
43
+ if build_func is None:
44
+ if parent is not None:
45
+ self.build_func = parent.build_func
46
+ else:
47
+ self.build_func = build_from_cfg
48
+ else:
49
+ self.build_func = build_func
50
+ if parent is not None:
51
+ assert isinstance(parent, Registry)
52
+ parent._add_children(self)
53
+ self.parent = parent
54
+ else:
55
+ self.parent = None
56
+
57
+ def __len__(self):
58
+ return len(self._module_dict)
59
+
60
+ def __contains__(self, key):
61
+ return self.get(key) is not None
62
+
63
+ def __repr__(self):
64
+ format_str = self.__class__.__name__ + \
65
+ f'(name={self._name}, ' \
66
+ f'items={self._module_dict})'
67
+ return format_str
68
+
69
+ @staticmethod
70
+ def infer_scope():
71
+ """Infer the scope of registry.
72
+ The name of the package where registry is defined will be returned.
73
+ Example:
74
+ # in mmdet/models/backbone/resnet.py
75
+ >>> MODELS = Registry('models')
76
+ >>> @MODELS.register_module()
77
+ >>> class ResNet:
78
+ >>> pass
79
+ The scope of ``ResNet`` will be ``mmdet``.
80
+ Returns:
81
+ scope (str): The inferred scope name.
82
+ """
83
+ # inspect.stack() traces where this function is called; index 2 is
84
+ # the frame in which `infer_scope()` is called
85
+ filename = inspect.getmodule(inspect.stack()[2][0]).__name__
86
+ split_filename = filename.split('.')
87
+ return split_filename[0]
88
+
89
+ @staticmethod
90
+ def split_scope_key(key):
91
+ """Split scope and key.
92
+ The first scope will be split from key.
93
+ Examples:
94
+ >>> Registry.split_scope_key('mmdet.ResNet')
95
+ 'mmdet', 'ResNet'
96
+ >>> Registry.split_scope_key('ResNet')
97
+ None, 'ResNet'
98
+ Return:
99
+ scope (str, None): The first scope.
100
+ key (str): The remaining key.
101
+ """
102
+ split_index = key.find('.')
103
+ if split_index != -1:
104
+ return key[:split_index], key[split_index + 1:]
105
+ else:
106
+ return None, key
107
+
108
+ @property
109
+ def name(self):
110
+ return self._name
111
+
112
+ @property
113
+ def scope(self):
114
+ return self._scope
115
+
116
+ @property
117
+ def module_dict(self):
118
+ return self._module_dict
119
+
120
+ @property
121
+ def children(self):
122
+ return self._children
123
+
124
+ def get(self, key):
125
+ """Get the registry record.
126
+ Args:
127
+ key (str): The class name in string format.
128
+ Returns:
129
+ class: The corresponding class.
130
+ """
131
+ scope, real_key = self.split_scope_key(key)
132
+ if scope is None or scope == self._scope:
133
+ # get from self
134
+ if real_key in self._module_dict:
135
+ return self._module_dict[real_key]
136
+ else:
137
+ # get from self._children
138
+ if scope in self._children:
139
+ return self._children[scope].get(real_key)
140
+ else:
141
+ # goto root
142
+ parent = self.parent
143
+ while parent.parent is not None:
144
+ parent = parent.parent
145
+ return parent.get(key)
146
+
147
+ def build(self, *args, **kwargs):
148
+ return self.build_func(*args, **kwargs, registry=self)
149
+
150
+ def _add_children(self, registry):
151
+ """Add children for a registry.
152
+ The ``registry`` will be added as children based on its scope.
153
+ The parent registry could build objects from children registry.
154
+ Example:
155
+ >>> models = Registry('models')
156
+ >>> mmdet_models = Registry('models', parent=models)
157
+ >>> @mmdet_models.register_module()
158
+ >>> class ResNet:
159
+ >>> pass
160
+ >>> resnet = models.build(dict(NAME='mmdet.ResNet'))
161
+ """
162
+
163
+ assert isinstance(registry, Registry)
164
+ assert registry.scope is not None
165
+ assert registry.scope not in self.children, \
166
+ f'scope {registry.scope} exists in {self.name} registry'
167
+ self.children[registry.scope] = registry
168
+
169
+ def _register_module(self, module_class, module_name=None, force=False):
170
+ if not inspect.isclass(module_class):
171
+ raise TypeError('module must be a class, '
172
+ f'but got {type(module_class)}')
173
+
174
+ if module_name is None:
175
+ module_name = module_class.__name__
176
+ if isinstance(module_name, str):
177
+ module_name = [module_name]
178
+ for name in module_name:
179
+ if not force and name in self._module_dict:
180
+ raise KeyError(f'{name} is already registered '
181
+ f'in {self.name}')
182
+ self._module_dict[name] = module_class
183
+
184
+ def deprecated_register_module(self, cls=None, force=False):
185
+ warnings.warn(
186
+ 'The old API of register_module(module, force=False) '
187
+ 'is deprecated and will be removed, please use the new API '
188
+ 'register_module(name=None, force=False, module=None) instead.')
189
+ if cls is None:
190
+ return partial(self.deprecated_register_module, force=force)
191
+ self._register_module(cls, force=force)
192
+ return cls
193
+
194
+ def register_module(self, name=None, force=False, module=None):
195
+ """Register a module.
196
+ A record will be added to `self._module_dict`, whose key is the class
197
+ name or the specified name, and value is the class itself.
198
+ It can be used as a decorator or a normal function.
199
+ Example:
200
+ >>> backbones = Registry('backbone')
201
+ >>> @backbones.register_module()
202
+ >>> class ResNet:
203
+ >>> pass
204
+ >>> backbones = Registry('backbone')
205
+ >>> @backbones.register_module(name='mnet')
206
+ >>> class MobileNet:
207
+ >>> pass
208
+ >>> backbones = Registry('backbone')
209
+ >>> class ResNet:
210
+ >>> pass
211
+ >>> backbones.register_module(ResNet)
212
+ Args:
213
+ name (str | None): The module name to be registered. If not
214
+ specified, the class name will be used.
215
+ force (bool, optional): Whether to override an existing class with
216
+ the same name. Default: False.
217
+ module (type): Module class to be registered.
218
+ """
219
+ if not isinstance(force, bool):
220
+ raise TypeError(f'force must be a boolean, but got {type(force)}')
221
+ # NOTE: This is a workaround to stay compatible with the old API,
222
+ # while it may introduce unexpected bugs.
223
+ if isinstance(name, type):
224
+ return self.deprecated_register_module(name, force=force)
225
+
226
+ # raise the error ahead of time
227
+ if not (name is None or isinstance(name, str) or misc.is_seq_of(name, str)):
228
+ raise TypeError(
229
+ 'name must be either of None, an instance of str or a sequence'
230
+ f' of str, but got {type(name)}')
231
+
232
+ # use it as a normal method: x.register_module(module=SomeClass)
233
+ if module is not None:
234
+ self._register_module(
235
+ module_class=module, module_name=name, force=force)
236
+ return module
237
+
238
+ # use it as a decorator: @x.register_module()
239
+ def _register(cls):
240
+ self._register_module(
241
+ module_class=cls, module_name=name, force=force)
242
+ return cls
243
+
244
+ return _register
245
+
246
+
247
+ def build_from_cfg(cfg, registry, default_args=None):
248
+ """Build a module from config dict.
249
+ Args:
250
+ cfg (edict): Config dict. It should at least contain the key "NAME".
251
+ registry (:obj:`Registry`): The registry to search the type from.
252
+ Returns:
253
+ object: The constructed object.
254
+ """
255
+ if not isinstance(cfg, dict):
256
+ raise TypeError(f'cfg must be a dict, but got {type(cfg)}')
257
+ if 'NAME' not in cfg:
258
+ if default_args is None or 'NAME' not in default_args:
259
+ raise KeyError(
260
+ '`cfg` or `default_args` must contain the key "NAME", '
261
+ f'but got {cfg}\n{default_args}')
262
+ if not isinstance(registry, Registry):
263
+ raise TypeError('registry must be an mmcv.Registry object, '
264
+ f'but got {type(registry)}')
265
+
266
+ if not (isinstance(default_args, dict) or default_args is None):
267
+ raise TypeError('default_args must be a dict or None, '
268
+ f'but got {type(default_args)}')
269
+
270
+ if default_args is not None:
271
+ cfg = config.merge_new_config(cfg, default_args)
272
+
273
+ obj_type = cfg.get('NAME')
274
+
275
+ if isinstance(obj_type, str):
276
+ obj_cls = registry.get(obj_type)
277
+ if obj_cls is None:
278
+ raise KeyError(
279
+ f'{obj_type} is not in the {registry.name} registry')
280
+ elif inspect.isclass(obj_type):
281
+ obj_cls = obj_type
282
+ else:
283
+ raise TypeError(
284
+ f'type must be a str or valid type, but got {type(obj_type)}')
285
+ try:
286
+ return obj_cls(cfg)
287
+ except Exception as e:
288
+ # Normal TypeError does not print class name.
289
+ raise type(e)(f'{obj_cls.__name__}: {e}')
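The register/build round trip described in the Registry docstring looks roughly like this; the `PointBackbone` class and its `embed_dim` field are illustrative only.

    MODELS = Registry('models')

    @MODELS.register_module()
    class PointBackbone:
        def __init__(self, cfg):
            self.embed_dim = cfg.get('embed_dim', 768)

    # build_from_cfg resolves the "NAME" key and calls the class with the cfg dict.
    backbone = MODELS.build(dict(NAME='PointBackbone', embed_dim=1024))
    print(backbone.embed_dim)   # 1024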
mm_models/modal_module/point/recon/reconv2_utils/transforms.py ADDED
@@ -0,0 +1,78 @@
1
+ from PIL import Image
2
+ from torchvision import transforms
3
+ from .randaugment import RandAugmentMC  # local copy in reconv2_utils
4
+ __all__ = ['get_transforms']
5
+
6
+
7
+ class ResizeImage():
8
+ def __init__(self, size):
9
+ if isinstance(size, int):
10
+ self.size = (int(size), int(size))
11
+ else:
12
+ self.size = size
13
+
14
+ def __call__(self, img):
15
+ th, tw = self.size
16
+ return img.resize((th, tw))
17
+
18
+
19
+ class PlaceCrop(object):
20
+ """Crops the given PIL.Image at the particular index.
21
+ Args:
22
+ size (sequence or int): Desired output size of the crop. If size is an
23
+ int instead of sequence like (w, h), a square crop (size, size) is
24
+ made.
25
+ """
26
+
27
+ def __init__(self, size, start_x, start_y):
28
+ if isinstance(size, int):
29
+ self.size = (int(size), int(size))
30
+ else:
31
+ self.size = size
32
+ self.start_x = start_x
33
+ self.start_y = start_y
34
+
35
+ def __call__(self, img):
36
+ """
37
+ Args:
38
+ img (PIL.Image): Image to be cropped.
39
+ Returns:
40
+ PIL.Image: Cropped image.
41
+ """
42
+ th, tw = self.size
43
+ return img.crop((self.start_x, self.start_y, self.start_x + tw, self.start_y + th))
44
+
45
+
46
+ class ForceFlip(object):
47
+ """Horizontally flip the given PIL.Image randomly with a probability of 0.5."""
48
+
49
+ def __call__(self, img):
50
+ """
51
+ Args:
52
+ img (PIL.Image): Image to be flipped.
53
+ Returns:
54
+ PIL.Image: Horizontally flipped image.
55
+ """
56
+ return img.transpose(Image.FLIP_LEFT_RIGHT)
57
+
58
+
59
+ def transform_train(resize_size=256, crop_size=224):
60
+ normalize = transforms.Normalize(mean=[0.485, 0.456, 0.406],
61
+ std=[0.229, 0.224, 0.225])
62
+ return transforms.Compose([
63
+ # ResizeImage(resize_size),
64
+ # transforms.RandomHorizontalFlip(),
65
+ # transforms.RandomResizedCrop(crop_size, scale=(0.64, 1.0), interpolation=transforms.InterpolationMode.BICUBIC),
66
+ # RandAugmentMC(n=2, m=10),
67
+ ResizeImage(crop_size),
68
+ transforms.ToTensor(),
69
+ normalize
70
+ ])
71
+
72
+
73
+ def get_transforms(resize_size=256, crop_size=224):
74
+ transforms = {
75
+ 'train': transform_train(resize_size, crop_size)
76
+ }
77
+
78
+ return transforms
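And a usage sketch for the transform factory above; 'photo.jpg' is again a placeholder path.

    from PIL import Image

    tfs = get_transforms(crop_size=224)
    img = Image.open('photo.jpg').convert('RGB')
    x = tfs['train'](img)          # resize to 224x224, ToTensor, ImageNet-mean/std normalize
    print(x.shape)                 # torch.Size([3, 224, 224])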
mm_models/modal_module/point/recon/transformer.py ADDED
@@ -0,0 +1,647 @@
1
+ import math
2
+ import timm
3
+ import torch
4
+ import numpy as np
5
+ import torch.nn as nn
6
+ import torch.nn.functional as F
7
+
8
+ from .reconv2_utils import misc
9
+ from .reconv2_utils.logger import *
10
+ from .reconv2_utils.knn import knn_point
11
+ from timm.layers import Mlp, DropPath
12
+ from typing import Optional, List
13
+
14
+
15
+ class PatchEmbedding(nn.Module): # Embedding module
16
+ def __init__(self, embed_dim, input_channel=3, large=False):
17
+ super().__init__()
18
+ self.embed_dim = embed_dim
19
+ self.input_channel = input_channel
20
+
21
+ # embed_dim_list = [c * (embed_dim // 512 + 1) for c in [128, 256, 512]]
22
+ #
23
+ # self.first_conv = nn.Sequential(
24
+ # nn.Conv1d(self.input_channel, embed_dim_list[0], 1),
25
+ # nn.BatchNorm1d(embed_dim_list[0]),
26
+ # nn.ReLU(inplace=True),
27
+ # nn.Conv1d(embed_dim_list[0], embed_dim_list[1], 1)
28
+ # )
29
+ # self.second_conv = nn.Sequential(
30
+ # nn.Conv1d(embed_dim_list[2], embed_dim_list[2], 1),
31
+ # nn.BatchNorm1d(embed_dim_list[2]),
32
+ # nn.ReLU(inplace=True),
33
+ # nn.Conv1d(embed_dim_list[2], self.embed_dim, 1)
34
+ # )
35
+
36
+ if large:
37
+ self.first_conv = nn.Sequential(
38
+ nn.Conv1d(self.input_channel, 256, 1),
39
+ nn.BatchNorm1d(256),
40
+ nn.ReLU(inplace=True),
41
+ nn.Conv1d(256, 512, 1),
42
+ nn.BatchNorm1d(512),
43
+ nn.ReLU(inplace=True),
44
+ nn.Conv1d(512, 1024, 1)
45
+ )
46
+ self.second_conv = nn.Sequential(
47
+ nn.Conv1d(2048, 2048, 1),
48
+ nn.BatchNorm1d(2048),
49
+ nn.ReLU(inplace=True),
50
+ nn.Conv1d(2048, embed_dim, 1)
51
+ )
52
+ else:
53
+ self.first_conv = nn.Sequential(
54
+ nn.Conv1d(self.input_channel, 128, 1),
55
+ nn.BatchNorm1d(128),
56
+ nn.ReLU(inplace=True),
57
+ nn.Conv1d(128, 256, 1)
58
+ )
59
+ self.second_conv = nn.Sequential(
60
+ nn.Conv1d(512, 512, 1),
61
+ nn.BatchNorm1d(512),
62
+ nn.ReLU(inplace=True),
63
+ nn.Conv1d(512, embed_dim, 1)
64
+ )
65
+
66
+ def forward(self, point_groups):
67
+ '''
68
+ point_groups : B G N 3/6
69
+ -----------------
70
+ feature_global : B G C
71
+ '''
72
+ bs, g, n, _ = point_groups.shape
73
+ point_groups = point_groups.reshape(bs * g, n, self.input_channel)
74
+ # encoder
75
+ feature = self.first_conv(point_groups.transpose(2, 1))
76
+ feature_global = torch.max(feature, dim=2, keepdim=True)[0]
77
+ feature = torch.cat([feature_global.expand(-1, -1, n), feature], dim=1)
78
+ feature = self.second_conv(feature)
79
+ feature_global = torch.max(feature, dim=2, keepdim=False)[0]
80
+ return feature_global.reshape(bs, g, self.embed_dim)
81
+
82
+
83
+ class PositionEmbeddingCoordsSine(nn.Module):
84
+ """Similar to transformer's position encoding, but generalizes it to
85
+ arbitrary dimensions and continuous coordinates.
86
+
87
+ Args:
88
+ n_dim: Number of input dimensions, e.g. 2 for image coordinates.
89
+ d_model: Number of dimensions to encode into
90
+ temperature:
91
+ scale:
92
+ """
93
+
94
+ def __init__(self, n_dim: int = 1, d_model: int = 256, temperature=1.0, scale=None):
95
+ super().__init__()
96
+
97
+ self.n_dim = n_dim
98
+ self.num_pos_feats = d_model // n_dim // 2 * 2
99
+ self.temperature = temperature
100
+ self.padding = d_model - self.num_pos_feats * self.n_dim
101
+
102
+ if scale is None:
103
+ scale = 1.0
104
+ self.scale = scale * 2 * math.pi
105
+
106
+ def forward(self, xyz: torch.Tensor) -> torch.Tensor:
107
+ """
108
+ Args:
109
+ xyz: Point positions (*, d_in)
110
+
111
+ Returns:
112
+ pos_emb (*, d_out)
113
+ """
114
+ assert xyz.shape[-1] == self.n_dim
115
+
116
+ dim_t = torch.arange(self.num_pos_feats,
117
+ dtype=torch.float32, device=xyz.device)
118
+ dim_t = self.temperature ** (2 * torch.div(dim_t,
119
+ 2, rounding_mode='trunc') / self.num_pos_feats)
120
+
121
+ xyz = xyz * self.scale
122
+ pos_divided = xyz.unsqueeze(-1) / dim_t
123
+ pos_sin = pos_divided[..., 0::2].sin()
124
+ pos_cos = pos_divided[..., 1::2].cos()
125
+ pos_emb = torch.stack([pos_sin, pos_cos], dim=-1).reshape(*xyz.shape[:-1], -1)
126
+
127
+ # Pad unused dimensions with zeros
128
+ pos_emb = F.pad(pos_emb, (0, self.padding))
129
+ return pos_emb
130
+
131
+
132
+ class Group(nn.Module): # FPS + KNN
133
+ def __init__(self, num_group, group_size):
134
+ super().__init__()
135
+ self.num_group = num_group
136
+ self.group_size = group_size
137
+
138
+ def forward(self, pts):
139
+ '''
140
+ input: B N 3/6
141
+ ---------------------------
142
+ output: B G M 3/6
143
+ center : B G 3
144
+ '''
145
+ xyz = pts[:, :, :3]
146
+ c = pts.shape[2]
147
+ batch_size, num_points, _ = xyz.shape
148
+ # fps the centers out
149
+ xyz = xyz.float()
150
+ center = misc.fps(xyz.contiguous(), self.num_group) # B G 3
151
+ # knn to get the neighborhood
152
+ idx = knn_point(self.group_size, xyz, center)
153
+ assert idx.size(1) == self.num_group
154
+ assert idx.size(2) == self.group_size
155
+ idx_base = torch.arange(0, batch_size, device=xyz.device).view(-1, 1, 1) * num_points
156
+ idx = idx + idx_base
157
+ idx = idx.view(-1)
158
+ neighborhood = pts.view(batch_size * num_points, -1)[idx, :]
159
+ neighborhood = neighborhood.view(batch_size, self.num_group, self.group_size, c).contiguous()
160
+ # normalize
161
+ neighborhood[:, :, :, :3] = neighborhood[:, :, :, :3] - center.unsqueeze(2)
162
+ return neighborhood, center
163
+
164
+
165
+ class ZGroup(nn.Module):
166
+ def __init__(self, num_group, group_size):
167
+ super().__init__()
168
+ self.num_group = num_group
169
+ self.group_size = group_size
170
+
171
+ def simplied_morton_sorting(self, xyz, center):
172
+ """
173
+ Simplified Morton-code sorting: iteratively take the patch nearest to the last selected patch as the next one; we found this to be more efficient.
174
+ """
175
+ batch_size, num_points, _ = xyz.shape
176
+ distances_batch = torch.cdist(center, center)
177
+ distances_batch[:, torch.eye(self.num_group).bool()] = float("inf")
178
+ idx_base = torch.arange(
179
+ 0, batch_size, device=xyz.device) * self.num_group
180
+ sorted_indices_list = [idx_base]
181
+ distances_batch = distances_batch.view(batch_size, self.num_group, self.num_group).transpose(
182
+ 1, 2).contiguous().view(batch_size * self.num_group, self.num_group)
183
+ distances_batch[idx_base] = float("inf")
184
+ distances_batch = distances_batch.view(
185
+ batch_size, self.num_group, self.num_group).transpose(1, 2).contiguous()
186
+ for i in range(self.num_group - 1):
187
+ distances_batch = distances_batch.view(
188
+ batch_size * self.num_group, self.num_group)
189
+ distances_to_last_batch = distances_batch[sorted_indices_list[-1]]
190
+ closest_point_idx = torch.argmin(distances_to_last_batch, dim=-1)
191
+ closest_point_idx = closest_point_idx + idx_base
192
+ sorted_indices_list.append(closest_point_idx)
193
+ distances_batch = distances_batch.view(batch_size, self.num_group, self.num_group).transpose(
194
+ 1, 2).contiguous().view(batch_size * self.num_group, self.num_group)
195
+ distances_batch[closest_point_idx] = float("inf")
196
+ distances_batch = distances_batch.view(
197
+ batch_size, self.num_group, self.num_group).transpose(1, 2).contiguous()
198
+ sorted_indices = torch.stack(sorted_indices_list, dim=-1)
199
+ sorted_indices = sorted_indices.view(-1)
200
+ return sorted_indices
201
+
202
+ def forward(self, pts):
203
+ """
204
+ input: B N 3/6
205
+ ---------------------------
206
+ output: B G M 3/6
207
+ center : B G 3
208
+ """
209
+ xyz = pts[:, :, :3]
210
+ c = pts.shape[2]
211
+ batch_size, num_points, _ = xyz.shape
212
+ # fps the centers out
213
+ xyz = xyz.float()
214
+ center = misc.fps(xyz.contiguous(), self.num_group) # B G 3
215
+ # knn to get the neighborhood
216
+ idx = knn_point(self.group_size, xyz, center)
217
+ assert idx.size(1) == self.num_group
218
+ assert idx.size(2) == self.group_size
219
+ idx_base = torch.arange(0, batch_size, device=xyz.device).view(-1, 1, 1) * num_points
220
+ idx = idx + idx_base
221
+ idx = idx.view(-1)
222
+ neighborhood = pts.view(batch_size * num_points, -1)[idx, :]
223
+ neighborhood = neighborhood.view(batch_size, self.num_group, self.group_size, c).contiguous()
224
+ # normalize
225
+ neighborhood[:, :, :, :3] = neighborhood[:, :, :, :3] - center.unsqueeze(2)
226
+
227
+ # can utilize morton_sorting by choosing morton_sorting function
228
+ sorted_indices = self.simplied_morton_sorting(xyz, center)
229
+
230
+ neighborhood = neighborhood.view(
231
+ batch_size * self.num_group, self.group_size, c)[sorted_indices, :, :]
232
+ neighborhood = neighborhood.view(
233
+ batch_size, self.num_group, self.group_size, c).contiguous()
234
+ center = center.view(
235
+ batch_size * self.num_group, 3)[sorted_indices, :]
236
+ center = center.view(
237
+ batch_size, self.num_group, 3).contiguous()
238
+
239
+ return neighborhood, center
240
+
241
+
242
+ # Transformers
243
+ class Attention(nn.Module):
244
+ def __init__(
245
+ self,
246
+ dim: int,
247
+ num_heads: int = 8,
248
+ qkv_bias: bool = True,
249
+ qk_norm: bool = False,
250
+ attn_drop: float = 0.,
251
+ proj_drop: float = 0.,
252
+ norm_layer: nn.Module = nn.LayerNorm,
253
+ ) -> None:
254
+ super().__init__()
255
+ assert dim % num_heads == 0, 'dim should be divisible by num_heads'
256
+ self.num_heads = num_heads
257
+ self.head_dim = dim // num_heads
258
+ self.scale = self.head_dim ** -0.5
259
+
260
+ self.qkv = nn.Linear(dim, dim * 3, bias=qkv_bias)
261
+ self.q_norm = norm_layer(self.head_dim) if qk_norm else nn.Identity()
262
+ self.k_norm = norm_layer(self.head_dim) if qk_norm else nn.Identity()
263
+ self.attn_drop = nn.Dropout(attn_drop)
264
+ self.proj = nn.Linear(dim, dim)
265
+ self.proj_drop = nn.Dropout(proj_drop)
266
+
267
+ def forward(self, x: torch.Tensor, mask: Optional[torch.Tensor] = None) -> torch.Tensor:
268
+ B, N, C = x.shape
269
+ qkv = self.qkv(x).reshape(B, N, 3, self.num_heads, self.head_dim).permute(2, 0, 3, 1, 4)
270
+ q, k, v = qkv.unbind(0)
271
+ q, k = self.q_norm(q), self.k_norm(k)
272
+
273
+ q = q * self.scale
274
+ attn = q @ k.transpose(-2, -1)
275
+ if mask is not None:
276
+ attn = attn.masked_fill(mask, float('-inf'))
277
+ attn = attn.softmax(dim=-1)
278
+ attn = self.attn_drop(attn)
279
+ x = attn @ v
280
+
281
+ x = x.transpose(1, 2).reshape(B, N, C)
282
+ x = self.proj(x)
283
+ x = self.proj_drop(x)
284
+ return x
285
+
286
+
287
+ class CrossAttention(nn.Module):
288
+ def __init__(
289
+ self,
290
+ dim: int,
291
+ num_heads: int = 8,
292
+ qkv_bias: bool = True,
293
+ qk_norm: bool = False,
294
+ attn_drop: float = 0.,
295
+ proj_drop: float = 0.,
296
+ norm_layer: nn.Module = nn.LayerNorm,
297
+ ) -> None:
298
+ super().__init__()
299
+ assert dim % num_heads == 0, 'dim should be divisible by num_heads'
300
+ self.num_heads = num_heads
301
+ self.head_dim = dim // num_heads
302
+ self.scale = self.head_dim ** -0.5
303
+
304
+ self.q = nn.Linear(dim, dim, bias=qkv_bias)
305
+ self.kv = nn.Linear(dim, dim * 2, bias=qkv_bias)
306
+ self.q_norm = norm_layer(self.head_dim) if qk_norm else nn.Identity()
307
+ self.k_norm = norm_layer(self.head_dim) if qk_norm else nn.Identity()
308
+ self.attn_drop = nn.Dropout(attn_drop)
309
+ self.proj = nn.Linear(dim, dim)
310
+ self.proj_drop = nn.Dropout(proj_drop)
311
+
312
+ def forward(self, x: torch.Tensor, y: torch.Tensor, mask: Optional[torch.Tensor] = None) -> torch.Tensor:
313
+ B, N, C = y.shape
314
+ kv = self.kv(y).reshape(B, N, 2, self.num_heads, C // self.num_heads).permute(2, 0, 3, 1, 4)
315
+ k, v = kv.unbind(0)
316
+
317
+ B, N, C = x.shape
318
+ q = self.q(x).reshape(B, N, 1, self.num_heads, C // self.num_heads).permute(2, 0, 3, 1, 4)[0]
319
+
320
+ q, k = self.q_norm(q), self.k_norm(k)
321
+ q = q * self.scale
322
+ attn = q @ k.transpose(-2, -1)
323
+ if mask is not None:
324
+ attn = attn.masked_fill(mask, float('-inf'))
325
+ attn = attn.softmax(dim=-1)
326
+ attn = self.attn_drop(attn)
327
+ x = attn @ v
328
+
329
+ x = x.transpose(1, 2).reshape(B, N, C)
330
+ x = self.proj(x)
331
+ x = self.proj_drop(x)
332
+ return x
333
+
334
+
335
+ class LayerScale(nn.Module):
336
+ def __init__(
337
+ self,
338
+ dim: int,
339
+ init_values: float = 1e-5,
340
+ inplace: bool = False,
341
+ ) -> None:
342
+ super().__init__()
343
+ self.inplace = inplace
344
+ self.gamma = nn.Parameter(init_values * torch.ones(dim))
345
+
346
+ def forward(self, x: torch.Tensor) -> torch.Tensor:
347
+ return x.mul_(self.gamma) if self.inplace else x * self.gamma
348
+
349
+
350
+ class Block(nn.Module):
351
+ def __init__(
352
+ self,
353
+ dim: int,
354
+ num_heads: int,
355
+ mlp_ratio: float = 4.,
356
+ qkv_bias: bool = True,
357
+ qk_norm: bool = False,
358
+ proj_drop: float = 0.,
359
+ attn_drop: float = 0.,
360
+ init_values: Optional[float] = None,
361
+ drop_path: float = 0.,
362
+ act_layer: nn.Module = nn.GELU,
363
+ norm_layer: nn.Module = nn.LayerNorm,
364
+ ) -> None:
365
+ super().__init__()
366
+ self.norm1 = norm_layer(dim)
367
+ self.attn = Attention(
368
+ dim,
369
+ num_heads=num_heads,
370
+ qkv_bias=qkv_bias,
371
+ qk_norm=qk_norm,
372
+ attn_drop=attn_drop,
373
+ proj_drop=proj_drop,
374
+ norm_layer=norm_layer,
375
+ )
376
+ self.ls1 = LayerScale(dim, init_values=init_values) if init_values else nn.Identity()
377
+ self.drop_path1 = DropPath(drop_path) if drop_path > 0. else nn.Identity()
378
+
379
+ self.norm2 = norm_layer(dim)
380
+ self.mlp = Mlp(
381
+ in_features=dim,
382
+ hidden_features=int(dim * mlp_ratio),
383
+ act_layer=act_layer,
384
+ drop=proj_drop,
385
+ )
386
+ self.ls2 = LayerScale(dim, init_values=init_values) if init_values else nn.Identity()
387
+ self.drop_path2 = DropPath(drop_path) if drop_path > 0. else nn.Identity()
388
+
389
+ def forward(self, x, attn_mask=None):
390
+ x = x + self.drop_path1(self.ls1(self.attn(self.norm1(x), attn_mask)))
391
+ x = x + self.drop_path2(self.ls2(self.mlp(self.norm2(x))))
392
+ return x
393
+
394
+
395
+ class CrossBlock(nn.Module):
396
+ def __init__(
397
+ self,
398
+ dim: int,
399
+ num_heads: int,
400
+ mlp_ratio: float = 4.,
401
+ qkv_bias: bool = True,
402
+ qk_norm: bool = False,
403
+ proj_drop: float = 0.,
404
+ attn_drop: float = 0.,
405
+ init_values: Optional[float] = None,
406
+ drop_path: float = 0.,
407
+ act_layer: nn.Module = nn.GELU,
408
+ norm_layer: nn.Module = nn.LayerNorm,
409
+ stop_grad: bool = False
410
+ ) -> None:
411
+ super().__init__()
412
+ self.norm1 = norm_layer(dim)
413
+ self.attn = CrossAttention(
414
+ dim,
415
+ num_heads=num_heads,
416
+ qkv_bias=qkv_bias,
417
+ qk_norm=qk_norm,
418
+ attn_drop=attn_drop,
419
+ proj_drop=proj_drop,
420
+ norm_layer=norm_layer,
421
+ )
422
+ self.ls1 = LayerScale(dim, init_values=init_values) if init_values else nn.Identity()
423
+ self.drop_path1 = DropPath(drop_path) if drop_path > 0. else nn.Identity()
424
+
425
+ self.norm2 = norm_layer(dim)
426
+ self.mlp = Mlp(
427
+ in_features=dim,
428
+ hidden_features=int(dim * mlp_ratio),
429
+ act_layer=act_layer,
430
+ drop=proj_drop,
431
+ )
432
+ self.ls2 = LayerScale(dim, init_values=init_values) if init_values else nn.Identity()
433
+ self.drop_path2 = DropPath(drop_path) if drop_path > 0. else nn.Identity()
434
+
435
+ self.stop_grad = stop_grad
436
+
437
+ def forward(self, x: torch.Tensor, y: torch.Tensor) -> torch.Tensor:
438
+ if self.stop_grad:
439
+ x = x + self.drop_path1(self.ls1(self.attn(self.norm1(x), self.norm1(y.detach()))))
440
+ else:
441
+ x = x + self.drop_path1(self.ls1(self.attn(self.norm1(x), self.norm1(y))))
442
+
443
+ x = x + self.drop_path2(self.ls2(self.mlp(self.norm2(x))))
444
+ return x
445
+
446
+
447
+ class ReConBlocks(nn.Module):
448
+ def __init__(
449
+ self,
450
+ embed_dim: int = 768,
451
+ depth: int = 12,
452
+ num_heads: int = 12,
453
+ mlp_ratio: float = 4.,
454
+ qkv_bias: bool = True,
455
+ qk_norm: bool = False,
456
+ init_values: Optional[float] = None,
457
+ proj_drop: float = 0.,
458
+ attn_drop_rate: float = 0.,
459
+ drop_path_rate: List = [],
460
+ norm_layer: nn.Module = nn.LayerNorm,
461
+ act_layer: nn.Module = nn.GELU,
462
+ stop_grad: bool = False,
463
+ pretrained_model_name: str = 'vit_base_patch32_clip_224.openai',
464
+ every_layer_add_pos: bool = True,
465
+ ):
466
+ super().__init__()
467
+
468
+ self.depth = depth
469
+ self.stop_grad = stop_grad
470
+ self.pretrained_model_name = pretrained_model_name
471
+ self.every_layer_add_pos = every_layer_add_pos
472
+ if 'dino' in self.pretrained_model_name:
473
+ init_values = 1e-5
474
+ if 'giant' in self.pretrained_model_name:
475
+ mlp_ratio = 48 / 11
476
+ self.local_blocks = nn.Sequential(*[
477
+ Block(
478
+ dim=embed_dim,
479
+ num_heads=num_heads,
480
+ mlp_ratio=mlp_ratio,
481
+ qkv_bias=qkv_bias,
482
+ qk_norm=qk_norm,
483
+ init_values=init_values,
484
+ proj_drop=proj_drop,
485
+ attn_drop=attn_drop_rate,
486
+ drop_path=drop_path_rate[i],
487
+ norm_layer=norm_layer,
488
+ act_layer=act_layer
489
+ )
490
+ for i in range(depth)])
491
+
492
+ self.global_blocks = nn.Sequential(*[
493
+ CrossBlock(
494
+ dim=embed_dim,
495
+ num_heads=num_heads,
496
+ mlp_ratio=mlp_ratio,
497
+ qkv_bias=qkv_bias,
498
+ qk_norm=qk_norm,
499
+ init_values=init_values,
500
+ proj_drop=proj_drop,
501
+ attn_drop=attn_drop_rate,
502
+ drop_path=drop_path_rate[i],
503
+ norm_layer=norm_layer,
504
+ act_layer=act_layer,
505
+ stop_grad=stop_grad
506
+ )
507
+ for i in range(depth)])
508
+
509
+ def load_pretrained_timm_weights(self):
510
+ model = timm.create_model(self.pretrained_model_name, pretrained=True)
511
+ state_dict = model.blocks.state_dict()
512
+ self.local_blocks.load_state_dict(state_dict, strict=True)
513
+
514
+ cross_state_dict = {}
515
+ for k, v in state_dict.items():
516
+ if 'qkv' in k:
517
+ cross_state_dict[k.replace('qkv', 'q')] = v[:int(v.shape[0] / 3)]
518
+ cross_state_dict[k.replace('qkv', 'kv')] = v[int(v.shape[0] / 3):]
519
+ else:
520
+ cross_state_dict[k] = v
521
+ self.global_blocks.load_state_dict(cross_state_dict, strict=True)
522
+
523
+ def forward(self, x, pos, attn_mask=None, query=None):
524
+ if self.every_layer_add_pos:
525
+ for i in range(self.depth):
526
+ x = self.local_blocks[i](x + pos, attn_mask)
527
+ if query is not None:
528
+ query = self.global_blocks[i](query, x)
529
+ else:
530
+ x = x + pos
531
+ for i in range(self.depth):
532
+ x = self.local_blocks[i](x, attn_mask)
533
+ if query is not None:
534
+ query = self.global_blocks[i](query, x)
535
+ return x, query
536
+
537
+
538
+ class GPTExtractor(nn.Module):
539
+ def __init__(
540
+ self,
541
+ embed_dim: int = 768,
542
+ num_heads: int = 12,
543
+ depth: int = 12,
544
+ group_size: int = 32,
545
+ drop_path_rate: float = 0.0,
546
+ stop_grad: bool = False,
547
+ pretrained_model_name: str = 'vit_base_patch32_clip_224.openai',
548
+ ):
549
+ super(GPTExtractor, self).__init__()
550
+
551
+ self.embed_dim = embed_dim
552
+ self.group_size = group_size
553
+
554
+ # start of sequence token
555
+ self.sos = nn.Parameter(torch.zeros(1, 1, embed_dim))
556
+ self.sos_pos = nn.Parameter(torch.zeros(1, 1, embed_dim))
557
+ nn.init.normal_(self.sos)
558
+ nn.init.normal_(self.sos_pos)
559
+
560
+ drop_path_rate = [x.item() for x in torch.linspace(0, drop_path_rate, depth)]
561
+ self.blocks = ReConBlocks(
562
+ embed_dim=embed_dim,
563
+ num_heads=num_heads,
564
+ depth=depth,
565
+ drop_path_rate=drop_path_rate,
566
+ stop_grad=stop_grad,
567
+ pretrained_model_name=pretrained_model_name,
568
+ )
569
+
570
+ self.ln_f1 = nn.LayerNorm(embed_dim)
571
+ self.ln_f2 = nn.LayerNorm(embed_dim)
572
+
573
+ def forward(self, x, pos, attn_mask, query):
574
+ """
575
+ Expects input of shape [batch, sequence_len, embed_dim].
576
+ """
577
+
578
+ batch, length, _ = x.shape
579
+
580
+ # prepend sos token
581
+ sos = self.sos.expand(batch, -1, -1)
582
+ sos_pos = self.sos_pos.expand(batch, -1, -1)
583
+
584
+ x = torch.cat([sos, x[:, :-1]], dim=1)
585
+ pos = torch.cat([sos_pos, pos[:, :-1]], dim=1)
586
+
587
+ # transformer
588
+ x, query = self.blocks(x, pos, attn_mask, query)
589
+
590
+ encoded_points = self.ln_f1(x)
591
+ query = self.ln_f2(query)
592
+
593
+ return encoded_points, query
594
+
595
+
596
+ class MAEExtractor(nn.Module):
597
+ def __init__(
598
+ self,
599
+ embed_dim: int = 768,
600
+ num_heads: int = 12,
601
+ depth: int = 12,
602
+ group_size: int = 32,
603
+ drop_path_rate: float = 0.0,
604
+ stop_grad: bool = False,
605
+ pretrained_model_name: str = 'vit_base_patch32_clip_224.openai',
606
+ ):
607
+ super(MAEExtractor, self).__init__()
608
+
609
+ self.embed_dim = embed_dim
610
+ self.group_size = group_size
611
+
612
+ drop_path_rate = [x.item() for x in torch.linspace(0, drop_path_rate, depth)]
613
+ self.blocks = ReConBlocks(
614
+ embed_dim=embed_dim,
615
+ num_heads=num_heads,
616
+ depth=depth,
617
+ drop_path_rate=drop_path_rate,
618
+ stop_grad=stop_grad,
619
+ pretrained_model_name=pretrained_model_name,
620
+ )
621
+
622
+ self.ln_f1 = nn.LayerNorm(embed_dim)
623
+ self.ln_f2 = nn.LayerNorm(embed_dim)
624
+
625
+ def forward(self, x, pos, mask=None, query=None):
626
+ """
627
+ Expects input of shape [batch, sequence_len, embed_dim].
628
+ """
629
+
630
+ batch, length, C = x.shape
631
+ if mask is not None:
632
+ x_vis = x[~mask].reshape(batch, -1, C)
633
+ pos_vis = pos[~mask].reshape(batch, -1, C)
634
+ else:
635
+ x_vis = x
636
+ pos_vis = pos
637
+
638
+ # transformer
639
+ x_vis, query = self.blocks(x_vis, pos_vis, None, query)
640
+
641
+ encoded_points = self.ln_f1(x_vis)
642
+ query = self.ln_f2(query)
643
+
644
+ return encoded_points, query
645
+
646
+
647
+
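To make the data flow through the modules above concrete, a hedged tokenization sketch (shapes are assumptions; `misc.fps` and `knn_point` come from the reconv2_utils package imported at the top of this file):

    import torch

    pts = torch.rand(2, 8192, 6)                               # B N 6: xyz + rgb
    grouper = Group(num_group=512, group_size=32)              # FPS centers + kNN patches
    embed = PatchEmbedding(embed_dim=1024, input_channel=6)
    pos_enc = PositionEmbeddingCoordsSine(n_dim=3, d_model=1024)

    neighborhood, center = grouper(pts)                        # [2, 512, 32, 6], [2, 512, 3]
    tokens = embed(neighborhood)                               # [2, 512, 1024]
    pos = pos_enc(center)                                      # [2, 512, 1024]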
mm_models/modal_module/point/reconv2.py ADDED
@@ -0,0 +1,266 @@
1
+ import numpy as np
2
+ from easydict import EasyDict
3
+ import timm
4
+ import torch
5
+ import torch.nn as nn
6
+ import torch.nn.functional as F
7
+ from .recon.transformer import Group, ZGroup, PatchEmbedding, PositionEmbeddingCoordsSine, GPTExtractor, MAEExtractor
8
+
9
+
10
+ POINT_RECON2_MODAL_CFG = {
11
+ 'modal_tag': 'point',
12
+ 'modal_placeholder_token': "<|point_placeholder|>",
13
+ 'model_path': None,
14
+ 'group_size': 32,
15
+ 'num_group': 512,
16
+ 'mask_type': 'rand',
17
+ 'embed_dim': 1024,
18
+ 'depth': 24,
19
+ 'drop_path_rate': 0.1,
20
+ 'num_heads': 16,
21
+ 'with_color': True,
22
+ 'stop_grad': False,
23
+ 'large_embedding': False,
24
+ 'img_queries': 13,
25
+ 'text_queries': 3,
26
+ 'pretrained_model_name': 'eva_large_patch14_336.in22k_ft_in22k_in1k',
27
+ 'output_dim': 896
28
+ }
29
+
30
+
31
+ class MaskTransformer(nn.Module):
32
+ def __init__(self, config):
33
+ super(MaskTransformer, self).__init__()
34
+
35
+ self.embed_dim = config.embed_dim
36
+ self.num_group = config.num_group
37
+ self.group_size = config.group_size
38
+ self.with_color = config.with_color
39
+ self.input_channel = 6 if self.with_color else 3
40
+ self.img_queries = config.img_queries
41
+ self.text_queries = config.text_queries
42
+ self.global_query_num = self.img_queries + self.text_queries
43
+ self.mask_type = config.mask_type
44
+ self.stop_grad = config.stop_grad
45
+
46
+ self.embed = PatchEmbedding(embed_dim=self.embed_dim, input_channel=self.input_channel,
47
+ large=config.large_embedding)
48
+
49
+         print(f'[ReCon] divide point cloud into G{config.num_group} x S{config.group_size} points ...')
+
+         if self.mask_type == 'causal':
+             self.group_divider = ZGroup(num_group=config.num_group, group_size=config.group_size)
+             self.encoder = GPTExtractor(
+                 embed_dim=config.embed_dim,
+                 num_heads=config.num_heads,
+                 depth=config.depth,
+                 group_size=config.group_size,
+                 drop_path_rate=config.drop_path_rate,
+                 stop_grad=self.stop_grad,
+                 pretrained_model_name=config.pretrained_model_name,
+             )
+             self.pos_embed = PositionEmbeddingCoordsSine(3, self.embed_dim, 1.0)
+         else:
+             self.group_divider = Group(num_group=config.num_group, group_size=config.group_size)
+             self.encoder = MAEExtractor(
+                 embed_dim=config.embed_dim,
+                 num_heads=config.num_heads,
+                 depth=config.depth,
+                 group_size=config.group_size,
+                 drop_path_rate=config.drop_path_rate,
+                 stop_grad=self.stop_grad,
+                 pretrained_model_name=config.pretrained_model_name,
+             )
+             self.pos_embed = nn.Sequential(
+                 nn.Linear(3, 128),
+                 nn.GELU(),
+                 nn.Linear(128, self.embed_dim)
+             )
+
+         self.norm = nn.LayerNorm(self.embed_dim)
+         self.global_query = nn.Parameter(torch.zeros(1, self.global_query_num, self.embed_dim))
+         self.apply(self._init_weights)
+
+         self.num_group = config.num_group
+
+     def _init_weights(self, m):
+         if isinstance(m, nn.Linear):
+             nn.init.normal_(m.weight, 0.02, 0.01)
+             if isinstance(m, nn.Linear) and m.bias is not None:
+                 nn.init.constant_(m.bias, 0)
+         elif isinstance(m, nn.BatchNorm1d):
+             nn.init.constant_(m.bias, 0)
+             nn.init.constant_(m.weight, 1.0)
+
+     def inference(self, pts):
+         with torch.no_grad():
+             neighborhood, center = self.group_divider(pts)
+             group_input_tokens = self.embed(neighborhood)  # B G C
+             batch_size, seq_len, C = group_input_tokens.size()
+
+             global_query = self.global_query.expand(batch_size, -1, -1)
+             pos = self.pos_embed(center.to(group_input_tokens.dtype))
+
+             mask = torch.full(
+                 (seq_len, seq_len), -float("Inf"), device=group_input_tokens.device, dtype=group_input_tokens.dtype
+             ).to(torch.bool)
+             if self.mask_type == 'causal':
+                 mask = torch.triu(mask, diagonal=1)
+             else:
+                 mask = None
+
+             local_features, global_features = self.encoder(
+                 group_input_tokens, pos, mask, global_query)
+
+             return pos, local_features, global_features
+
+
+ class ReCon2(nn.Module):
+     def __init__(self, config):
+         super().__init__()
+         self.config = config
+         self.embed_dim = config.embed_dim
+         self.with_color = config.with_color
+         self.img_queries = config.img_queries
+         self.text_queries = config.text_queries
+         self.global_query_num = self.img_queries + self.text_queries
+         self.input_channel = 6 if self.with_color else 3
+
+         self.model = MaskTransformer(config)
+
+         self.img_proj = nn.Linear(self.embed_dim, 1280)
+         self.img_proj.apply(self._init_weights)
+         self.text_proj = nn.Linear(self.embed_dim, 1280)
+         self.text_proj.apply(self._init_weights)
+
+     def _init_weights(self, m):
+         if isinstance(m, nn.Linear):
+             nn.init.normal_(m.weight, 0.02, 0.01)
+             if isinstance(m, nn.Linear) and m.bias is not None:
+                 nn.init.constant_(m.bias, 0)
+         elif isinstance(m, nn.BatchNorm1d):
+             nn.init.constant_(m.bias, 0)
+             nn.init.constant_(m.weight, 1.0)
+
+     @property
+     def device(self):
+         return next(self.parameters()).device
+
+     @property
+     def dtype(self):
+         return next(self.parameters()).dtype
+
+
+ class ReConv2PointEncoder(nn.Module):
+
+     def __init__(self, config):
+         super().__init__()
+         self.config = config
+         self.vision_tower = ReCon2(self.config)
+
+     @torch.no_grad()
+     def forward(self, pts):
+
+         pts = torch.stack(pts, dim=0)
+         pos_features, local_features, global_features = \
+             self.vision_tower.model.inference(pts.to(device=self.device, dtype=self.dtype))
+         local_features = local_features.to(pts.dtype)
+         global_features = global_features.to(pts.dtype)
+
+         return (pos_features, local_features, global_features)
+
+     @property
+     def dtype(self):
+         return self.vision_tower.dtype
+
+     @property
+     def device(self):
+         return self.vision_tower.device
+
+
+ class ReConProjector_MLP(nn.Module):
+     def __init__(self, in_channels, out_channels, mlp_depth, prompt_token_num,
+                  with_ape=True, with_local=True, with_global=True):
+         super().__init__()
+
+         self.in_channels = in_channels
+         self.out_channels = out_channels
+         self.mlp_depth = mlp_depth
+         self.prompt_token_num = prompt_token_num
+         self.with_ape = with_ape
+         self.with_local = with_local
+         self.with_global = with_global
+
+         if prompt_token_num > 0:
+             self.prompt1 = nn.Parameter(torch.zeros(1, prompt_token_num, out_channels))
+             self.prompt2 = nn.Parameter(torch.zeros(1, prompt_token_num, out_channels))
+             self.prompt3 = nn.Parameter(torch.zeros(1, prompt_token_num, out_channels))
+
+         self.proj1 = self.set_proj()
+         self.proj2 = self.set_proj()
+         self.proj3 = self.set_proj()
+
+     def set_proj(self):
+         modules = [nn.Linear(self.in_channels, self.out_channels)]
+         for i in range(1, self.mlp_depth):
+             modules.append(nn.GELU())
+             modules.append(nn.Linear(self.out_channels, self.out_channels))
+             modules.append(nn.GELU())
+             modules.append(nn.Linear(self.out_channels, self.out_channels))
+         return nn.Sequential(*modules)
+
+     def forward(self, proj_inps):
+         pos_feat, local_feat, global_feat = proj_inps
+         B = pos_feat.shape[0]
+         pos_feat = self.proj1(pos_feat)
+         local_feat = self.proj2(local_feat)
+         global_feat = self.proj3(global_feat)
+
+         if self.prompt_token_num > 0:
+             pos_feat = torch.cat([self.prompt1.expand(B, -1, -1), pos_feat], dim=1)
+             local_feat = torch.cat([self.prompt2.expand(B, -1, -1), local_feat], dim=1)
+             global_feat = torch.cat([self.prompt3.expand(B, -1, -1), global_feat], dim=1)
+
+         pts_feat = [feat for feat, flag in [(pos_feat, self.with_ape), (local_feat, self.with_local), (global_feat, self.with_global)] if flag]
+         pts_feat = torch.cat(pts_feat, dim=1)
+
+         pts_feat = torch.split(pts_feat, 1)
+         pts_feat = [item.squeeze() for item in pts_feat]
+         return pts_feat
+
+
+ def build_point_encoder(modal_cfg):
+     assert modal_cfg['modal_tag'] == 'point', f"building point encoder with '{modal_cfg['modal_tag']}' tag is not supported"
+     if "encoder_cfg" not in modal_cfg:  # init
+         cfg = EasyDict(modal_cfg)
+         model = ReConv2PointEncoder(cfg)
+         print(f"loading point encoder from {modal_cfg['model_path']}")
+         model_path = modal_cfg['model_path']
+         model.vision_tower.load_state_dict(torch.load(model_path, map_location='cpu'), strict=True)
+     else:
+         cfg = EasyDict(modal_cfg["encoder_cfg"])
+         model = ReConv2PointEncoder(cfg)
+     return model
+
+
+ def build_point_projector(modal_cfg):
+     assert modal_cfg['modal_tag'] == 'point', f"building point projector with '{modal_cfg['modal_tag']}' tag is not supported"
+     if "encoder_cfg" in modal_cfg:
+         proj_cfg = EasyDict(modal_cfg["encoder_cfg"])
+     else:
+         proj_cfg = EasyDict(modal_cfg)
+     projector = ReConProjector_MLP(in_channels=proj_cfg.embed_dim,
+                                    out_channels=proj_cfg.output_dim,
+                                    mlp_depth=2, prompt_token_num=1)
+     return projector
+
+
+ # if __name__ == '__main__':
+
+ #     encoder = build_point_encoder(POINT_RECON2_MODAL_CFG)
+ #     print(encoder)
+ #     data = torch.randn((1, 8192, 6))
+ #     pos_features, local_features, global_features = encoder(data)
+ #     print(pos_features.shape, local_features.shape, global_features.shape)
+ #     proj = build_point_projector(POINT_RECON2_MODAL_CFG)
+ #     point_token = proj((pos_features, local_features, global_features))
+ #     print(point_token.shape)
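For reference, a minimal shape check of the projector defined above (not part of the uploaded file). The feature sizes below are illustrative assumptions rather than values read from the released config: 512 point groups, a ReCon embedding width of 1024, 10 global query tokens, and an LLM hidden size of 896 (Qwen2.5-0.5B).

```python
import torch

# Hypothetical sizes, see the note above; ReConProjector_MLP is the class defined in reconv2.py.
proj = ReConProjector_MLP(in_channels=1024, out_channels=896, mlp_depth=2, prompt_token_num=1)

B = 2
pos_feat = torch.randn(B, 512, 1024)     # sine positional embeddings of the group centers
local_feat = torch.randn(B, 512, 1024)   # per-group local tokens from the ReCon encoder
global_feat = torch.randn(B, 10, 1024)   # pooled global query tokens

tokens = proj((pos_feat, local_feat, global_feat))
# One tensor per sample: (1+512) + (1+512) + (1+10) tokens, each projected to dim 896.
print(len(tokens), tokens[0].shape)      # 2, torch.Size([1037, 896])
```

Each of the three feature streams gets its own MLP and its own learnable prompt token before the streams are concatenated along the sequence dimension, so the point branch hands the LLM a single flat token sequence per sample.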
mm_models/modal_module/vision/__pycache__/siglip.cpython-310.pyc ADDED
Binary file (3.92 kB).
 
mm_models/modal_module/vision/siglip.py ADDED
@@ -0,0 +1,122 @@
+ from transformers import SiglipVisionModel, SiglipVisionConfig
+ import torch.nn as nn
+ import torch
+ import einops
+ import torch.nn.functional as F
+ import torch.nn.init as init
+
+
+ VISION_SIGLIP_MODAL_CFG = {
+     'modal_tag': 'vision',
+     'model_name_or_path': None,
+     'modal_placeholder_token': "<|vision_placeholder|>",
+     'proj_num_layers': 2,
+     'proj_input_dim': 1152,
+     'proj_output_dim': 896,  # for Qwen2.5-0.5B
+     'multi_grid': False
+ }
+
+
+ class SiglipVisionModelForMM(SiglipVisionModel):
+
+     def forward(self, pixel_values, **kwargs):
+
+         assert type(pixel_values) == list
+
+         split_sizes = []
+         temp = []
+         for pixel_value in pixel_values:
+             if pixel_value.dim() == 3:
+                 pixel_value = pixel_value.unsqueeze(0)
+             temp.append(pixel_value)
+             split_sizes.append(pixel_value.shape[0])
+         pixel_values = torch.cat(temp, dim=0)  # (BG) 3 H W
+
+         outputs = super().forward(pixel_values, output_hidden_states=True, **kwargs)
+         hidden_states = outputs['hidden_states'][-2]
+
+         return (hidden_states, split_sizes)
+
+
+ class VisionProjector_MLP(nn.Module):
+
+     def __init__(self, proj_num_layers, proj_input_dim, proj_output_dim, multi_grid=False):
+         super().__init__()
+         _proj_input_dim = int(4 * proj_input_dim)
+         module = [nn.Linear(_proj_input_dim, proj_output_dim)]
+         for _ in range(proj_num_layers - 1):
+             module.append(nn.GELU())
+             module.append(nn.Linear(proj_output_dim, proj_output_dim))
+             module.append(nn.GELU())
+             module.append(nn.Linear(proj_output_dim, proj_output_dim))
+         self.module = nn.Sequential(*module)
+
+         self.resample_pad_token = nn.Parameter(torch.randn((1, proj_input_dim)))
+         init.kaiming_normal_(self.resample_pad_token)
+
+         self.multi_grid = multi_grid
+         if self.multi_grid:
+             self.grid_sep = nn.Parameter(torch.randn((1, proj_output_dim)))
+             init.kaiming_normal_(self.grid_sep)
+
+     def forward(self, encoder_output):
+         visual_tokens, split_sizes = encoder_output
+
+         B, L, D = visual_tokens.shape
+
+         # Pooling: merge 2x2 patch neighborhoods into single tokens of dim 4*D.
+         n_patch = int(L ** 0.5)
+
+         visual_tokens = einops.rearrange(visual_tokens, "B (h w) D -> B D h w", h=n_patch, w=n_patch)
+         if n_patch % 2 != 0:
+             # Odd patch grids (e.g. 27x27 for SigLIP at 384px) are padded to an even size,
+             # and the padded corner is filled with a learned pad token.
+             visual_tokens = F.pad(visual_tokens, (0, 1, 0, 1), value=0)
+             visual_tokens = einops.rearrange(visual_tokens, "B D h w -> B h w D")
+             visual_tokens[:, -1, -1, :] = self.resample_pad_token.expand(B, -1)
+             n_patch += 1
+
+         visual_tokens = visual_tokens.view(B, n_patch // 2, 2, n_patch // 2, 2, D)  # (B, n//2, 2, n//2, 2, D)
+         visual_tokens = visual_tokens.permute(0, 1, 3, 2, 4, 5)  # (B, n//2, n//2, 2, 2, D)
+         visual_tokens = visual_tokens.contiguous().view(B, n_patch // 2, n_patch // 2, D * 4)  # (B, n//2, n//2, D*4)
+
+         visual_tokens = einops.rearrange(visual_tokens, "B h w D -> B (h w) D")
+
+         visual_tokens = self.module(visual_tokens)
+
+         # Grid: for multi-grid (anyres) inputs, join the per-grid tokens of each sample
+         # with a learned separator token.
+         if self.multi_grid:
+             visual_tokens = torch.split(visual_tokens, split_sizes)  # B [G n D]
+             visual_tokens_list = []
+             for grid_visual_tokens in visual_tokens:
+                 grid_visual_tokens = torch.cat([grid_visual_tokens,
+                                                 self.grid_sep.repeat(grid_visual_tokens.shape[0], 1, 1)], dim=1)
+                 grid_visual_tokens = einops.rearrange(grid_visual_tokens, "G n D -> (G n) D")[:-1, :]
+                 visual_tokens_list.append(grid_visual_tokens)
+         else:
+             visual_tokens_list = torch.split(visual_tokens, 1)
+             visual_tokens_list = [item.squeeze() for item in visual_tokens_list]
+
+         return visual_tokens_list
+
+
+ def build_vision_encoder(modal_cfg):
+     assert modal_cfg['modal_tag'] == 'vision', f"building vision encoder with '{modal_cfg['modal_tag']}' tag is not supported"
+     if "encoder_cfg" not in modal_cfg:  # from pretrained
+         model = SiglipVisionModelForMM.from_pretrained(modal_cfg['model_name_or_path'])
+     else:
+         cfg = SiglipVisionConfig(**modal_cfg['encoder_cfg'])
+         model = SiglipVisionModelForMM._from_config(cfg)
+     return model
+
+
+ def build_vision_projector(modal_cfg):
+     assert modal_cfg['modal_tag'] == 'vision', f"building vision projector with '{modal_cfg['modal_tag']}' tag is not supported"
+     return VisionProjector_MLP(modal_cfg['proj_num_layers'], modal_cfg['proj_input_dim'], modal_cfg['proj_output_dim'],
+                                modal_cfg['multi_grid'])
+
+
+ if __name__ == '__main__':
+     projector = VisionProjector_MLP(2, 1152, 2048)
+     inputs = torch.randn((5, 729, 1152))
+     # forward expects the (hidden_states, split_sizes) tuple produced by SiglipVisionModelForMM
+     outputs = projector((inputs, [1] * 5))
+     print(len(outputs), outputs[0].shape)
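A minimal shape sketch of the projector above (not part of the uploaded file). The numbers assume a SigLIP variant with a 27x27 patch grid and hidden size 1152, and an LLM width of 896; these are assumptions for illustration, not values taken from the released checkpoint.

```python
import torch

# Assumed sizes: one image -> 27*27 = 729 patch tokens of dim 1152 from the penultimate SigLIP layer.
projector = VisionProjector_MLP(proj_num_layers=2, proj_input_dim=1152, proj_output_dim=896)

hidden_states = torch.randn(1, 729, 1152)   # stand-in for SiglipVisionModelForMM hidden states
tokens = projector((hidden_states, [1]))    # split_sizes: one grid per sample

# 27x27 is padded to 28x28, 2x2 neighborhoods are merged (dim 4*1152), then projected:
# 14*14 = 196 visual tokens of dim 896 per image.
print(len(tokens), tokens[0].shape)         # 1, torch.Size([196, 896])
```

The 2x2 merge is the only token-reduction step, so the number of visual tokens handed to the LLM is roughly a quarter of the raw patch count.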
mm_models/modeling_mm.py ADDED
@@ -0,0 +1,259 @@
+ import torch
+ import torch.nn as nn
+ from transformers import PretrainedConfig, PreTrainedModel, Qwen2Config
+ from transformers.modeling_outputs import CausalLMOutputWithPast
+ from transformers.generation.utils import GenerateOutput
+ from .configuration_mm import AllSparkConfig
+ from .modal_module import MODAL_ENCODERS_MAPPING, MODAL_PROJECTORS_MAPPING
+ from .llms.qwen_model_moe import Qwen2ForCausalLMMoE
+ from typing import Optional, List, Union, Tuple
+ import torch.nn.init as init
+ from utils import rank0_print
+
+
+ class AllSparkPreTrainedModel(PreTrainedModel):
+
+     config_class = AllSparkConfig
+     base_model_prefix = "allspark"
+
+     def _init_weights(self, module):
+         """Initialize the weights"""
+         std = (
+             self.config.initializer_range
+             if hasattr(self.config, "initializer_range")
+             else 0.02
+         )
+
+         if isinstance(module, (nn.Linear, nn.Conv2d)):
+             module.weight.data.normal_(mean=0.0, std=std)
+             if module.bias is not None:
+                 module.bias.data.zero_()
+         elif isinstance(module, nn.Embedding):
+             module.weight.data.normal_(mean=0.0, std=std)
+             if module.padding_idx is not None:
+                 module.weight.data[module.padding_idx].zero_()
+
+     def prepare_multimodal_inputs(self, input_ids, modal_inputs, labels, attention_mask):
+         if modal_inputs is None:
+             return input_ids, None, labels, attention_mask, None
+
+         # Encode and project all modal inputs, grouped by modal tag.
+         modal_tensors = dict()
+         for single_sample_modal_inputs in modal_inputs:
+             for tag, modal_tensor in single_sample_modal_inputs:
+                 if tag in modal_tensors:
+                     modal_tensors[tag].append(modal_tensor)
+                 else:
+                     modal_tensors[tag] = [modal_tensor]
+         for tag in modal_tensors:
+             modal_tensors[tag] = self.modal_projectors[tag](self.modal_encoders[tag](modal_tensors[tag]))  # B [N D]
+         for sample_id, single_sample_modal_inputs in enumerate(modal_inputs):
+             for modal_id, (tag, _) in enumerate(single_sample_modal_inputs):
+                 modal_inputs[sample_id][modal_id] = (tag, modal_tensors[tag].pop(0).squeeze(0))
+         for tag in modal_tensors:
+             assert len(modal_tensors[tag]) == 0
+
+         if attention_mask is None:
+             attention_mask = torch.ones_like(input_ids, dtype=torch.bool)
+         else:
+             attention_mask = attention_mask.bool()
+         if labels is None:
+             labels = torch.full_like(input_ids, self.config.ignore_index)
+
+         # input_ids: (batch_size, seq_len)
+         # modal_inputs: (B, M, Tuple[str, Tensor])
+         # labels: (batch_size, seq_len)
+         assert input_ids.shape[0] == len(modal_inputs) == labels.shape[0], \
+             f"Batch size mismatch: {input_ids.shape[0]} vs {len(modal_inputs)} vs {labels.shape[0]}. " \
+             "If some sample has no modal inputs, please append an empty list to modal_inputs."
+
+         input_ids = [cur_input_ids[cur_attention_mask] for cur_input_ids, cur_attention_mask in zip(input_ids, attention_mask)]
+         labels = [cur_labels[cur_attention_mask] for cur_labels, cur_attention_mask in zip(labels, attention_mask)]
+
+         modal_tag_pos_list = []
+         new_input_embeds = []
+         new_labels = []
+         for single_sample_input_ids, single_sample_modal_inputs, single_sample_labels in zip(input_ids, modal_inputs, labels):
+             tag_num = dict()
+             cur_id = 0
+             single_sample_embeds = []
+             _labels = []
+             single_sample_modal_tag_pos_list = []
+             for modal_input in single_sample_modal_inputs:
+                 tag, modal_tensor = modal_input
+                 if tag not in self.modal_tags:
+                     raise ValueError(f"Unknown modal tag: {tag}. Valid modal tags: {self.modal_tags}")
+                 if tag not in tag_num:
+                     tag_num[tag] = 0
+
+                 for modal_config in self.config.modal_configs:
+                     if modal_config["modal_tag"] == tag:
+                         modal_placeholder_token_id = modal_config["modal_placeholder_token_id"]
+                         break
+
+                 cur_modal_idx = torch.where(single_sample_input_ids == modal_placeholder_token_id)[0].tolist()[tag_num[tag]]
+
+                 single_sample_embeds.append(self.llm.get_input_embeddings()(single_sample_input_ids[cur_id:cur_modal_idx]))
+                 single_sample_embeds.append(self.modal_embeds[tag][0:1, :])  # start embed
+                 single_sample_embeds.append(modal_tensor)
+                 single_sample_embeds.append(self.modal_embeds[tag][1:2, :])  # end embed
+
+                 _labels.append(single_sample_labels[cur_id:cur_modal_idx])
+                 _labels.append(torch.full((modal_tensor.shape[0] + 2,), self.config.ignore_index, device=single_sample_labels.device, dtype=single_sample_labels.dtype))
+
+                 single_sample_modal_tag_pos_list.append((tag, cur_modal_idx, cur_modal_idx + modal_tensor.shape[0] + 1))
+
+                 cur_id += cur_modal_idx + 1
+                 tag_num[tag] += 1
+
+             single_sample_embeds.append(self.llm.get_input_embeddings()(single_sample_input_ids[cur_id:]))
+             _labels.append(single_sample_labels[cur_id:])
+
+             new_input_embeds.append(torch.cat(single_sample_embeds, dim=0))
+             new_labels.append(torch.cat(_labels, dim=0))
+             modal_tag_pos_list.append(single_sample_modal_tag_pos_list)
+
+         tokenizer_model_max_length = getattr(self.config, "tokenizer_model_max_length", None)
+         if tokenizer_model_max_length is not None:
+             new_input_embeds = [x[:tokenizer_model_max_length] for x in new_input_embeds]
+             new_labels = [x[:tokenizer_model_max_length] for x in new_labels]
+
+         max_len = max(x.shape[0] for x in new_input_embeds)
+         batch_size = len(new_input_embeds)
+
+         new_input_embeds_padded = []
+         new_labels_padded = torch.full((batch_size, max_len), self.config.ignore_index, dtype=new_labels[0].dtype, device=new_labels[0].device)
+         attention_mask = torch.zeros((batch_size, max_len), dtype=attention_mask.dtype, device=attention_mask.device)
+
+         for i, (cur_new_embed, cur_new_labels) in enumerate(zip(new_input_embeds, new_labels)):
+             cur_len = cur_new_embed.shape[0]
+             if getattr(self.config, "tokenizer_padding_side", "right") == "left":
+                 new_input_embeds_padded.append(torch.cat((torch.zeros((max_len - cur_len, cur_new_embed.shape[1]), dtype=cur_new_embed.dtype, device=cur_new_embed.device), cur_new_embed), dim=0))
+                 if cur_len > 0:
+                     new_labels_padded[i, -cur_len:] = cur_new_labels
+                     attention_mask[i, -cur_len:] = True
+             else:
+                 new_input_embeds_padded.append(torch.cat((cur_new_embed, torch.zeros((max_len - cur_len, cur_new_embed.shape[1]), dtype=cur_new_embed.dtype, device=cur_new_embed.device)), dim=0))
+                 if cur_len > 0:
+                     new_labels_padded[i, :cur_len] = cur_new_labels
+                     attention_mask[i, :cur_len] = True
+
+         new_input_embeds = torch.stack(new_input_embeds_padded, dim=0)
+
+         return None, new_input_embeds, new_labels_padded, attention_mask, modal_tag_pos_list
+
+
+ class AllSparkForCausalLM(AllSparkPreTrainedModel):
+
+     def __init__(self, config):
+         super().__init__(config)
+
+         if self.config.modal_configs is not None:
+             self.modal_tags = []
+             self.modal_encoders, self.modal_projectors = nn.ModuleDict(), nn.ModuleDict()
+             for modal_config in self.config.modal_configs:
+                 modal_tag = modal_config['modal_tag']
+                 assert modal_tag not in self.modal_tags, f"Duplicate modal tag: {modal_tag}"
+                 self.modal_tags.append(modal_tag)
+                 self.modal_encoders[modal_tag] = MODAL_ENCODERS_MAPPING[modal_tag](modal_config)
+                 encoder_cfg = self.modal_encoders[modal_tag].config
+                 if isinstance(encoder_cfg, PretrainedConfig):
+                     encoder_cfg = encoder_cfg.to_dict()
+                 modal_config['encoder_cfg'] = encoder_cfg
+                 self.modal_projectors[modal_tag] = MODAL_PROJECTORS_MAPPING[modal_tag](modal_config)
+         else:
+             self.modal_tags = None
+
+         if hasattr(config, 'llm_config'):
+             if "Qwen2" in config.llm_name_or_path:
+                 llm_config = Qwen2Config(**config.llm_config)
+                 self.llm = Qwen2ForCausalLMMoE._from_config(llm_config, modal_tags=self.modal_tags,
+                                                             add_moe=self.config.add_moe)
+             else:
+                 raise ValueError(config.llm_name_or_path)
+         else:
+             if "Qwen2" in config.llm_name_or_path:
+                 self.llm = Qwen2ForCausalLMMoE.from_pretrained(config.llm_name_or_path, modal_tags=self.modal_tags,
+                                                                add_moe=self.config.add_moe)
+             else:
+                 raise ValueError(config.llm_name_or_path)
+         self.config.llm_config = self.llm.config
+         self.config.hidden_size = self.llm.config.hidden_size
+
+         if self.config.modal_configs is not None:
+             self.modal_embeds = nn.ParameterDict()
+             for modal_config in self.config.modal_configs:
+                 modal_tag = modal_config['modal_tag']
+                 self.modal_embeds[modal_tag] = torch.randn((2, self.config.hidden_size))  # start and end embeds
+                 init.kaiming_normal_(self.modal_embeds[modal_tag])
+
+         self.post_init()
+
+     def initialize_tokenizer_for_multimodal(self, tokenizer, new_tag):
+         config = self.config
+         if config.modal_configs is None:
+             rank0_print("No modal configs provided, skipping multimodal tokenizer initialization.")
+             return None
+
+         for i, modal_config in enumerate(config.modal_configs):
+             # only add new tokens for the new modal
+             if modal_config['modal_tag'] != new_tag:
+                 continue
+             modal_placeholder_token = modal_config['modal_placeholder_token']
+             tokenizer.add_tokens([modal_placeholder_token], special_tokens=True)
+             self.config.modal_configs[i]['modal_placeholder_token_id'] = tokenizer.convert_tokens_to_ids(modal_placeholder_token)
+
+         self.llm.resize_token_embeddings(len(tokenizer), mean_resizing=False)
+
+     def forward(
+         self,
+         input_ids: torch.LongTensor = None,
+         modal_inputs: Optional[List[List[Tuple[str, torch.FloatTensor]]]] = None,  # B M (modal_tag, modal_input)
+         attention_mask: Optional[torch.Tensor] = None,
+         labels: Optional[torch.LongTensor] = None,
+         **kwargs
+     ) -> Union[Tuple, CausalLMOutputWithPast]:
+
+         if modal_inputs is not None:
+             input_ids, inputs_embeds, labels, attention_mask, modal_tag_pos_list = \
+                 self.prepare_multimodal_inputs(input_ids, modal_inputs, labels, attention_mask)
+
+             return self.llm(input_ids=input_ids,
+                             attention_mask=attention_mask,
+                             inputs_embeds=inputs_embeds,
+                             labels=labels,
+                             modal_tag_pos_list=modal_tag_pos_list,
+                             **kwargs)
+         else:
+             return self.llm(input_ids=input_ids,
+                             attention_mask=attention_mask,
+                             labels=labels,
+                             **kwargs)
+
+     @torch.no_grad()
+     def generate(
+         self,
+         input_ids: torch.LongTensor = None,
+         modal_inputs: Optional[List[List[Tuple[str, torch.FloatTensor]]]] = None,  # B M (modal_tag, modal_input)
+         attention_mask: Optional[torch.Tensor] = None,
+         **kwargs
+     ) -> Union[GenerateOutput, torch.LongTensor]:
+
+         if modal_inputs is not None:
+             input_ids, inputs_embeds, labels, attention_mask, modal_tag_pos_list = \
+                 self.prepare_multimodal_inputs(input_ids, modal_inputs, None, attention_mask)
+
+             return self.llm.generate(input_ids=input_ids, attention_mask=attention_mask,
+                                      inputs_embeds=inputs_embeds, modal_tag_pos_list=modal_tag_pos_list, **kwargs)
+         else:
+             return self.llm.generate(input_ids=input_ids, attention_mask=attention_mask, **kwargs)
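The core of `prepare_multimodal_inputs` is placeholder splicing: the text before each placeholder token is embedded with the LLM's embedding table, the projected modal tokens are inserted between a learned start and end embedding, and the corresponding label positions are masked with the ignore index. A simplified, standalone illustration of that splicing step (single sample, single modality, no padding or label masking; toy sizes only, not taken from the repository):

```python
import torch

# Simplified sketch of the placeholder-splicing idea used in prepare_multimodal_inputs.
def splice_at_placeholder(input_ids, token_embedding, modal_tokens, placeholder_id, start_embed, end_embed):
    idx = torch.where(input_ids == placeholder_id)[0].item()   # position of the placeholder token
    return torch.cat([
        token_embedding(input_ids[:idx]),                      # text before the placeholder
        start_embed, modal_tokens, end_embed,                  # <start> [modal tokens] <end>
        token_embedding(input_ids[idx + 1:]),                  # text after the placeholder
    ], dim=0)

# Toy numbers: vocab 100, hidden size 8, placeholder id 99, 5 projected modal tokens.
emb = torch.nn.Embedding(100, 8)
ids = torch.tensor([1, 2, 99, 3])
spliced = splice_at_placeholder(ids, emb, torch.randn(5, 8), 99,
                                torch.randn(1, 8), torch.randn(1, 8))
print(spliced.shape)  # torch.Size([10, 8]): 3 text tokens + 2 boundary embeds + 5 modal tokens
```

The real method additionally records the (tag, start, end) spans in `modal_tag_pos_list`, which the MoE-augmented Qwen2 backbone uses to route tokens by modality, and then left- or right-pads the batch to a common length.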
utils.py ADDED
@@ -0,0 +1,181 @@
+ import torch.distributed as dist
+ import ast
+ import re
+ import torch
+ from PIL import Image
+ import math
+
+
+ def model_params_summary(module, out_fn, verbose=True):
+     out_fn("-" * 30)
+     out_fn(f"module name: {module.__class__}")
+     if verbose:
+         out_fn("-" * 30)
+         for n, p in module.named_parameters():
+             out_fn(f"{n}: {'trainable' if p.requires_grad else 'freeze'}")
+     out_fn("-" * 30)
+     out_fn(f"Total params: {sum(p.numel() for p in module.parameters()) / 1e6:.4f}M")
+     out_fn(f"Trainable params: {sum(p.numel() for p in module.parameters() if p.requires_grad) / 1e6:.4f}M")
+     out_fn("-" * 30)
+
+
+ def rank0_print(*args):
+     if dist.is_initialized():
+         if dist.get_rank() == 0:
+             print(f"Rank {dist.get_rank()}: ", *args)
+     else:
+         print(*args)
+
+
+ LLM_DIM_MAPPING = {
+     'Qwen2.5-0.5B': 896,
+     'Qwen2.5-1.5B': 1536,
+     'Qwen2.5-3B': 2048,
+     'Qwen2.5-7B': 3584
+ }
+
+
+ SYSTEM_PROMPT = "You are a multimodal AI assistant named AllSparkv2 capable of understanding and generating content " +\
+                 "in various forms, including text and images. Your primary function is to provide useful and harmless " +\
+                 "information based on user input, assisting with problem-solving, information retrieval, and task completion. "
+
+
+ def select_best_resolution(original_size, possible_resolutions):
+     """
+     Selects the best resolution from a list of possible resolutions based on the original size.
+
+     Args:
+         original_size (tuple): The original size of the image in the format (width, height).
+         possible_resolutions (list): A list of possible resolutions in the format [(width1, height1), (width2, height2), ...].
+
+     Returns:
+         tuple: The best fit resolution in the format (width, height).
+     """
+     original_width, original_height = original_size
+     best_fit = None
+     max_effective_resolution = 0
+     min_wasted_resolution = float("inf")
+
+     for width, height in possible_resolutions:
+         # Calculate the downscaled size to keep the aspect ratio
+         scale = min(width / original_width, height / original_height)
+         downscaled_width, downscaled_height = int(original_width * scale), int(original_height * scale)
+
+         # Calculate effective and wasted resolutions
+         effective_resolution = min(downscaled_width * downscaled_height, original_width * original_height)
+         wasted_resolution = (width * height) - effective_resolution
+
+         if effective_resolution > max_effective_resolution or (effective_resolution == max_effective_resolution and wasted_resolution < min_wasted_resolution):
+             max_effective_resolution = effective_resolution
+             min_wasted_resolution = wasted_resolution
+             best_fit = (width, height)
+
+     return best_fit
+
+
+ def resize_and_pad_image(image, target_resolution):
+     """
+     Resize and pad an image to a target resolution while maintaining aspect ratio.
+
+     Args:
+         image (PIL.Image.Image): The input image.
+         target_resolution (tuple): The target resolution (width, height) of the image.
+
+     Returns:
+         PIL.Image.Image: The resized and padded image.
+     """
+     original_width, original_height = image.size
+     target_width, target_height = target_resolution
+
+     # Determine which dimension (width or height) to fill
+     scale_w = target_width / original_width
+     scale_h = target_height / original_height
+
+     if scale_w < scale_h:
+         # Width will be filled completely
+         new_width = target_width
+         new_height = min(math.ceil(original_height * scale_w), target_height)
+     else:
+         # Height will be filled completely
+         new_height = target_height
+         new_width = min(math.ceil(original_width * scale_h), target_width)
+
+     # Resize the image
+     resized_image = image.resize((new_width, new_height))
+
+     # Create a new image with the target size and paste the resized image onto it
+     new_image = Image.new("RGB", (target_width, target_height), (0, 0, 0))
+     paste_x = (target_width - new_width) // 2
+     paste_y = (target_height - new_height) // 2
+     new_image.paste(resized_image, (paste_x, paste_y))
+
+     return new_image
+
+
+ def divide_to_patches(image, patch_size):
+     """
+     Divides an image into patches of a specified size.
+
+     Args:
+         image (PIL.Image.Image): The input image.
+         patch_size (int): The size of each patch.
+
+     Returns:
+         list: A list of PIL.Image.Image objects representing the patches.
+     """
+     patches = []
+     width, height = image.size
+     for i in range(0, height, patch_size):
+         for j in range(0, width, patch_size):
+             box = (j, i, j + patch_size, i + patch_size)
+             patch = image.crop(box)
+             patches.append(patch)
+
+     return patches
+
+
+ def process_anyres_image(image, processor, grid_pinpoints):
+     """
+     Process an image with variable resolutions.
+
+     Args:
+         image (PIL.Image.Image): The input image to be processed.
+         processor: The image processor object.
+         grid_pinpoints (str): A string representation of a list of possible resolutions.
+
+     Returns:
+         torch.Tensor: A tensor containing the processed image patches.
+     """
+     # Convert grid_pinpoints from string to list
+     if isinstance(grid_pinpoints, str) and "x" in grid_pinpoints:
+         patch_size = min(processor.size.values())
+         assert patch_size in [224, 336, 384, 448, 512], "patch_size should be in [224, 336, 384, 448, 512]"
+         # Use regex to extract the range from the input string
+         matches = re.findall(r"\((\d+)x(\d+)\)", grid_pinpoints)
+         range_start = tuple(map(int, matches[0]))
+         range_end = tuple(map(int, matches[-1]))
+         # Generate a matrix of tuples from (range_start[0], range_start[1]) to (range_end[0], range_end[1])
+         grid_pinpoints = [(i, j) for i in range(range_start[0], range_end[0] + 1) for j in range(range_start[1], range_end[1] + 1)]
+         # Multiply all elements by patch_size
+         grid_pinpoints = [[dim * patch_size for dim in pair] for pair in grid_pinpoints]
+
+     if type(grid_pinpoints) is list:
+         possible_resolutions = grid_pinpoints
+     else:
+         possible_resolutions = ast.literal_eval(grid_pinpoints)
+     best_resolution = select_best_resolution(image.size, possible_resolutions)
+     image_padded = resize_and_pad_image(image, best_resolution)
+
+     patches = divide_to_patches(image_padded, processor.size["height"])
+
+     # FIXME: this seems to be a bug in that it resizes instead of pads,
+     # but it is kept as is for consistency with previous behavior.
+     # TODO: uncomment below to ablate with the padding
+     shortest_edge = min(processor.size.values())
+     image_original_resize = image.resize((shortest_edge, shortest_edge))
+     # image_padded_square = expand2square(image, tuple(int(x*255) for x in processor.image_mean))
+     # image_original_resize = image_padded_square.resize((processor.size['shortest_edge'], processor.size['shortest_edge']))
+
+     image_patches = [image_original_resize] + patches
+     image_patches = [processor.preprocess(image_patch, return_tensors="pt")["pixel_values"][0] for image_patch in image_patches]
+     return torch.stack(image_patches, dim=0)
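To make the anyres logic concrete, here is a small arithmetic check of `select_best_resolution` (not part of the uploaded file). The image size and the "(1x1) to (3x3)" grid at a 384-pixel patch size are illustrative assumptions; `select_best_resolution` is the function from utils.py above.

```python
from utils import select_best_resolution

# Hypothetical 1000x750 image, candidate canvases from a (1x1)..(3x3) grid of 384px tiles.
grids = [(w * 384, h * 384) for w in range(1, 4) for h in range(1, 4)]
best = select_best_resolution((1000, 750), grids)

# (1152, 768) wins: scale = min(1152/1000, 768/750) = 1.024, so the full image area
# stays "effective" while relatively little of the canvas is wasted.
print(best)  # (1152, 768)
```

`process_anyres_image` then resizes and pads the image to that canvas, splits it into processor-sized tiles, and prepends a downsized global view, so the vision tower sees one low-resolution overview plus a grid of high-resolution crops.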