Upload 36 files

- .gitattributes +3 -0
- inference/chat_vision_point.ipynb +0 -0
- inference/demo_assets/e393be9a47a24a7cae6142e13f5686d1_8192.npy +3 -0
- inference/demo_assets/image1.png +3 -0
- inference/demo_assets/image1_384.png +3 -0
- inference/demo_assets/image2.png +3 -0
- inference/forward_speed.ipynb +182 -0
- mm_models/__init__.py +2 -0
- mm_models/configuration_mm.py +23 -0
- mm_models/llms/__pycache__/llama_modal_moe.cpython-310.pyc +0 -0
- mm_models/llms/__pycache__/qwen_model_moe.cpython-310.pyc +0 -0
- mm_models/llms/qwen_model_moe.py +338 -0
- mm_models/modal_module/__init__.py +22 -0
- mm_models/modal_module/__pycache__/__init__.cpython-310.pyc +0 -0
- mm_models/modal_module/point/__pycache__/reconv2.cpython-310.pyc +0 -0
- mm_models/modal_module/point/recon/__pycache__/transformer.cpython-310.pyc +0 -0
- mm_models/modal_module/point/recon/reconv2_utils/AverageMeter.py +42 -0
- mm_models/modal_module/point/recon/reconv2_utils/__pycache__/knn.cpython-310.pyc +0 -0
- mm_models/modal_module/point/recon/reconv2_utils/__pycache__/logger.cpython-310.pyc +0 -0
- mm_models/modal_module/point/recon/reconv2_utils/__pycache__/misc.cpython-310.pyc +0 -0
- mm_models/modal_module/point/recon/reconv2_utils/checkpoint.py +129 -0
- mm_models/modal_module/point/recon/reconv2_utils/config.py +69 -0
- mm_models/modal_module/point/recon/reconv2_utils/data.py +109 -0
- mm_models/modal_module/point/recon/reconv2_utils/dist_utils.py +49 -0
- mm_models/modal_module/point/recon/reconv2_utils/knn.py +37 -0
- mm_models/modal_module/point/recon/reconv2_utils/logger.py +127 -0
- mm_models/modal_module/point/recon/reconv2_utils/misc.py +294 -0
- mm_models/modal_module/point/recon/reconv2_utils/parser.py +117 -0
- mm_models/modal_module/point/recon/reconv2_utils/randaugment.py +216 -0
- mm_models/modal_module/point/recon/reconv2_utils/registry.py +289 -0
- mm_models/modal_module/point/recon/reconv2_utils/transforms.py +78 -0
- mm_models/modal_module/point/recon/transformer.py +647 -0
- mm_models/modal_module/point/reconv2.py +266 -0
- mm_models/modal_module/vision/__pycache__/siglip.cpython-310.pyc +0 -0
- mm_models/modal_module/vision/siglip.py +122 -0
- mm_models/modeling_mm.py +259 -0
- utils.py +181 -0
.gitattributes
CHANGED
@@ -33,3 +33,6 @@ saved_model/**/* filter=lfs diff=lfs merge=lfs -text
 *.zip filter=lfs diff=lfs merge=lfs -text
 *.zst filter=lfs diff=lfs merge=lfs -text
 *tfevents* filter=lfs diff=lfs merge=lfs -text
+inference/demo_assets/image1_384.png filter=lfs diff=lfs merge=lfs -text
+inference/demo_assets/image1.png filter=lfs diff=lfs merge=lfs -text
+inference/demo_assets/image2.png filter=lfs diff=lfs merge=lfs -text
inference/chat_vision_point.ipynb
ADDED
The diff for this file is too large to render; see the raw diff.
inference/demo_assets/e393be9a47a24a7cae6142e13f5686d1_8192.npy
ADDED
@@ -0,0 +1,3 @@
version https://git-lfs.github.com/spec/v1
oid sha256:5e9bee56f432fce81bc1d356f54656064ac4926d02ff95b1b4f2e09d18ba79c7
size 196736
inference/demo_assets/image1.png
ADDED
Binary image file, stored with Git LFS.
inference/demo_assets/image1_384.png
ADDED
Binary image file, stored with Git LFS.
inference/demo_assets/image2.png
ADDED
Binary image file, stored with Git LFS.
inference/forward_speed.ipynb
ADDED
@@ -0,0 +1,182 @@
{
 "cells": [
  {
   "cell_type": "code",
   "execution_count": 1,
   "metadata": {},
   "outputs": [
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "/opt/conda/lib/python3.10/site-packages/tqdm/auto.py:21: TqdmWarning: IProgress not found. Please update jupyter and ipywidgets. See https://ipywidgets.readthedocs.io/en/stable/user_install.html\n",
      "  from .autonotebook import tqdm as notebook_tqdm\n"
     ]
    }
   ],
   "source": [
    "import sys\n",
    "import os\n",
    "import random\n",
    "import torch\n",
    "sys.path.append(\"../\")\n",
    "from mm_models import AllSparkForCausalLM\n",
    "from transformers import AutoImageProcessor, AutoTokenizer\n",
    "from PIL import Image\n",
    "import numpy as np\n",
    "from fvcore.nn import FlopCountAnalysis\n",
    "from plyfile import PlyData\n",
    "import plotly.graph_objects as go\n",
    "from mm_datasets.data_utils import point_preprocess, load_pts, process_pts\n",
    "from utils import SYSTEM_PROMPT\n",
    "\n",
    "system_prompt = SYSTEM_PROMPT\n",
    "\n",
    "\n",
    "def show_pointcloud(data, background=None):\n",
    "    points = data[:, :3]\n",
    "    colors = data[:, 3:6]\n",
    "\n",
    "    if colors is not None:\n",
    "        # * if colors in range(0-1)\n",
    "        if np.max(colors) <= 1:\n",
    "            color_data = np.multiply(colors, 255).astype(int)  # Convert float values (0-1) to integers (0-255)\n",
    "        # * if colors in range(0-255)\n",
    "        elif np.max(colors) <= 255:\n",
    "            color_data = colors.astype(int)\n",
    "    else:\n",
    "        color_data = np.zeros_like(points).astype(int)  # Default to black color if RGB information is not available\n",
    "    colors = color_data.astype(np.float32) / 255  # model input is (0-1)\n",
    "\n",
    "    color_strings = ['rgb({},{},{})'.format(r, g, b) for r, g, b in color_data]\n",
    "\n",
    "    fig = go.Figure(\n",
    "        data=[\n",
    "            go.Scatter3d(\n",
    "                x=points[:, 0], y=points[:, 1], z=points[:, 2],\n",
    "                mode='markers',\n",
    "                marker=dict(\n",
    "                    size=1.2,\n",
    "                    color=color_strings,  # Use the list of RGB strings for the marker colors\n",
    "                )\n",
    "            )\n",
    "        ],\n",
    "        layout=dict(\n",
    "            scene=dict(\n",
    "                xaxis=dict(visible=False),\n",
    "                yaxis=dict(visible=False),\n",
    "                zaxis=dict(visible=False)\n",
    "            ),\n",
    "            paper_bgcolor='rgb(50,50,50)' if background is None else background  # Set the background color to dark gray 50, 50, 50\n",
    "        ),\n",
    "    )\n",
    "    fig.show()"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "Loading checkpoint shards: 100%|██████████| 4/4 [00:06<00:00, 1.73s/it]\n"
     ]
    }
   ],
   "source": [
    "model_path = \"[path/to/model]\"\n",
    "\n",
    "tokenizer = AutoTokenizer.from_pretrained(model_path)\n",
    "model = AllSparkForCausalLM.from_pretrained(model_path, torch_dtype=torch.bfloat16).cuda()\n",
    "img_processor = AutoImageProcessor.from_pretrained(model_path)\n",
    "modal_place_token = dict()\n",
    "for modal_cfg in model.config.modal_configs:\n",
    "    modal_place_token[modal_cfg['modal_tag']] = modal_cfg['modal_placeholder_token']"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "Unsupported operator aten::_convolution_mode encountered 1 time(s)\n",
      "Unsupported operator aten::embedding encountered 3 time(s)\n",
      "Unsupported operator aten::add encountered 224 time(s)\n",
      "Unsupported operator aten::mul encountered 342 time(s)\n",
      "Unsupported operator aten::softmax encountered 26 time(s)\n",
      "Unsupported operator aten::gelu encountered 28 time(s)\n",
      "Unsupported operator aten::pad encountered 1 time(s)\n",
      "Unsupported operator aten::mul_ encountered 1 time(s)\n",
      "Unsupported operator aten::ones_like encountered 1 time(s)\n",
      "Unsupported operator aten::sub encountered 1 time(s)\n",
      "Unsupported operator aten::cos encountered 1 time(s)\n",
      "Unsupported operator aten::sin encountered 1 time(s)\n",
      "Unsupported operator aten::pow encountered 57 time(s)\n",
      "Unsupported operator aten::mean encountered 57 time(s)\n",
      "Unsupported operator aten::rsqrt encountered 57 time(s)\n",
      "Unsupported operator aten::neg encountered 56 time(s)\n",
      "Unsupported operator prim::PythonOp.FlashAttnFunc encountered 28 time(s)\n",
      "Unsupported operator aten::silu encountered 28 time(s)\n",
      "Unsupported operator aten::cross_entropy_loss encountered 1 time(s)\n",
      "The following submodules of the model were never called during the trace of the graph. They may be unused, or they were accessed by direct calls to .forward() or via other python methods. In the latter case they will have zeros for statistics, though their statistics will still contribute to their parent calling module.\n",
      "llm.model.layers.0.self_attn.rotary_emb, llm.model.layers.1.self_attn.rotary_emb, llm.model.layers.10.self_attn.rotary_emb, llm.model.layers.11.self_attn.rotary_emb, llm.model.layers.12.self_attn.rotary_emb, llm.model.layers.13.self_attn.rotary_emb, llm.model.layers.14.self_attn.rotary_emb, llm.model.layers.15.self_attn.rotary_emb, llm.model.layers.16.self_attn.rotary_emb, llm.model.layers.17.self_attn.rotary_emb, llm.model.layers.18.self_attn.rotary_emb, llm.model.layers.19.self_attn.rotary_emb, llm.model.layers.2.self_attn.rotary_emb, llm.model.layers.20.self_attn.rotary_emb, llm.model.layers.21.self_attn.rotary_emb, llm.model.layers.22.self_attn.rotary_emb, llm.model.layers.23.self_attn.rotary_emb, llm.model.layers.24.self_attn.rotary_emb, llm.model.layers.25.self_attn.rotary_emb, llm.model.layers.26.self_attn.rotary_emb, llm.model.layers.27.self_attn.rotary_emb, llm.model.layers.3.self_attn.rotary_emb, llm.model.layers.4.self_attn.rotary_emb, llm.model.layers.5.self_attn.rotary_emb, llm.model.layers.6.self_attn.rotary_emb, llm.model.layers.7.self_attn.rotary_emb, llm.model.layers.8.self_attn.rotary_emb, llm.model.layers.9.self_attn.rotary_emb, modal_encoders.vision.vision_model.encoder.layers.26, modal_encoders.vision.vision_model.encoder.layers.26.layer_norm1, modal_encoders.vision.vision_model.encoder.layers.26.layer_norm2, modal_encoders.vision.vision_model.encoder.layers.26.mlp, modal_encoders.vision.vision_model.encoder.layers.26.mlp.activation_fn, modal_encoders.vision.vision_model.encoder.layers.26.mlp.fc1, modal_encoders.vision.vision_model.encoder.layers.26.mlp.fc2, modal_encoders.vision.vision_model.encoder.layers.26.self_attn, modal_encoders.vision.vision_model.encoder.layers.26.self_attn.k_proj, modal_encoders.vision.vision_model.encoder.layers.26.self_attn.out_proj, modal_encoders.vision.vision_model.encoder.layers.26.self_attn.q_proj, modal_encoders.vision.vision_model.encoder.layers.26.self_attn.v_proj, modal_encoders.vision.vision_model.head, modal_encoders.vision.vision_model.head.attention, modal_encoders.vision.vision_model.head.attention.out_proj, modal_encoders.vision.vision_model.head.layernorm, modal_encoders.vision.vision_model.head.mlp, modal_encoders.vision.vision_model.head.mlp.activation_fn, modal_encoders.vision.vision_model.head.mlp.fc1, modal_encoders.vision.vision_model.head.mlp.fc2, modal_encoders.vision.vision_model.post_layernorm\n"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "2781.7917 GFLOPs\n"
     ]
    }
   ],
   "source": [
    "image_path = \"demo_assets/image1_384.png\"\n",
    "img = Image.open(image_path).convert(\"RGB\")\n",
    "\n",
    "img = img_processor(images=img, return_tensors=\"pt\").pixel_values.to(\"cuda\").squeeze().to(model.dtype)\n",
    "\n",
    "question_2 = modal_place_token['vision'] + \"\\nDescribe this image.\"\n",
    "\n",
    "modal_inputs = [('vision', img)]\n",
    "\n",
    "messages = [\n",
    "    {\"role\": \"system\", \"content\": system_prompt},\n",
    "    {\"role\": \"user\", \"content\": \"The 3D object is a football, specifically a soccer ball, which is used in the sport of soccer or football. The ball is designed with a series of black and white panels that form a spherical shape, making it easy to kick and control during gameplay. The design also allows for the ball to spin and bounce when in motion, adding a strategic element to the sport.\\n\" + question_2},\n",
    "]\n",
    "inputs = tokenizer.apply_chat_template(messages, tokenize=True, add_generation_prompt=True, return_tensors='pt').to(model.device)\n",
    "\n",
    "flops = FlopCountAnalysis(model, (inputs, [modal_inputs]))\n",
    "print(f\"{flops.total()/1e9:.4f} GFLOPs\")"
   ]
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "base",
   "language": "python",
   "name": "python3"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.10.14"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 2
}
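Note: as a companion to the FLOP count above, a minimal sketch of timing the same forward call with CUDA events. The `model(inputs, [modal_inputs])` call shape mirrors the `FlopCountAnalysis` arguments in the notebook; the warm-up and iteration counts are arbitrary choices, not from this commit.

import torch

n_warmup, n_iters = 5, 20
start = torch.cuda.Event(enable_timing=True)
end = torch.cuda.Event(enable_timing=True)
with torch.no_grad():
    for _ in range(n_warmup):          # warm up kernels and the allocator
        model(inputs, [modal_inputs])
    start.record()
    for _ in range(n_iters):
        model(inputs, [modal_inputs])
    end.record()
    torch.cuda.synchronize()           # wait for the recorded events
print(f"{start.elapsed_time(end) / n_iters:.2f} ms per forward")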
mm_models/__init__.py
ADDED
@@ -0,0 +1,2 @@
from .configuration_mm import AllSparkConfig
from .modeling_mm import AllSparkForCausalLM
mm_models/configuration_mm.py
ADDED
@@ -0,0 +1,23 @@
from transformers import PretrainedConfig
from typing import Optional, List, Dict


class AllSparkConfig(PretrainedConfig):

    model_type = "allspark"

    def __init__(self,
                 llm_name_or_path: str = None,
                 modal_configs: Optional[List[Dict]] = None,
                 initializer_range: float = 0.02,
                 ignore_index: int = -100,
                 tokenizer_padding_side: str = "right",
                 add_moe: bool = True,
                 **kwargs):
        self.llm_name_or_path = llm_name_or_path
        self.modal_configs = modal_configs
        self.initializer_range = initializer_range
        self.ignore_index = ignore_index
        self.tokenizer_padding_side = tokenizer_padding_side
        self.add_moe = add_moe
        super().__init__(**kwargs)
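Note: a minimal sketch of constructing this config. The `modal_tag` / `modal_placeholder_token` keys are the ones read elsewhere in this commit (see forward_speed.ipynb); the base-LLM path and the placeholder strings here are hypothetical.

config = AllSparkConfig(
    llm_name_or_path="Qwen/Qwen2-7B-Instruct",  # hypothetical base LLM
    modal_configs=[
        {"modal_tag": "vision", "modal_placeholder_token": "<image>"},  # placeholder tokens are assumptions
        {"modal_tag": "point", "modal_placeholder_token": "<point>"},
    ],
    add_moe=True,  # attach per-modality MLP experts in each decoder layer
)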
mm_models/llms/__pycache__/llama_modal_moe.cpython-310.pyc
ADDED
Binary file (9.32 kB).
mm_models/llms/__pycache__/qwen_model_moe.cpython-310.pyc
ADDED
Binary file (8.53 kB).
mm_models/llms/qwen_model_moe.py
ADDED
@@ -0,0 +1,338 @@
import torch
import torch.nn as nn
from transformers import Qwen2ForCausalLM, Qwen2Model
from transformers.cache_utils import Cache, DynamicCache
from transformers.models.qwen2.modeling_qwen2 import Qwen2DecoderLayer, Qwen2MLP
from typing import Optional, Tuple, List, Union
from transformers.modeling_outputs import CausalLMOutputWithPast, BaseModelOutputWithPast
from utils import rank0_print
import copy


class Qwen2DecoderLayerMoE(Qwen2DecoderLayer):

    def __init__(self, config, layer_idx, modal_tags, add_moe):
        config._attn_implementation = "flash_attention_2"
        super().__init__(config, layer_idx)
        self.modal_tags = modal_tags

        if modal_tags is not None and add_moe:
            self.modal_moes = nn.ModuleDict()
            for tag in modal_tags:
                self.modal_moes[tag] = Qwen2MLP(config)

    def forward(
        self,
        hidden_states: torch.Tensor,
        modal_token_idx_matrix,
        modal_idx_mapping,
        attention_mask: Optional[torch.Tensor] = None,
        position_ids: Optional[torch.LongTensor] = None,
        past_key_value: Optional[Tuple[torch.Tensor]] = None,
        output_attentions: Optional[bool] = False,
        use_cache: Optional[bool] = False,
        cache_position: Optional[torch.LongTensor] = None,
        position_embeddings: Optional[Tuple[torch.Tensor, torch.Tensor]] = None,  # will become mandatory in v4.46
        **kwargs,
    ) -> Tuple[torch.FloatTensor, Optional[Tuple[torch.FloatTensor, torch.FloatTensor]]]:

        residual = hidden_states

        hidden_states = self.input_layernorm(hidden_states)

        # Self Attention
        hidden_states, self_attn_weights, present_key_value = self.self_attn(
            hidden_states=hidden_states,
            attention_mask=attention_mask,
            position_ids=position_ids,
            past_key_value=past_key_value,
            output_attentions=output_attentions,
            use_cache=use_cache,
            cache_position=cache_position,
            position_embeddings=position_embeddings,
        )
        hidden_states = residual + hidden_states

        residual = hidden_states
        hidden_states = self.post_attention_layernorm(hidden_states)
        if modal_token_idx_matrix is not None and hidden_states.shape[1] > 1 and hasattr(self, 'modal_moes'):
            batch_size, sequence_length, hidden_dim = hidden_states.shape
            hidden_states = hidden_states.view(-1, hidden_dim)
            final_hidden_states = torch.zeros(
                (batch_size * sequence_length, hidden_dim), dtype=hidden_states.dtype, device=hidden_states.device
            )
            for modal_tag, modal_idx in modal_idx_mapping.items():
                mask = modal_token_idx_matrix == modal_idx
                if not torch.any(mask):
                    continue
                mask = mask.view(-1)
                assert mask.shape[0] == hidden_states.shape[0]
                if modal_tag == 'text':
                    final_hidden_states[mask] = self.mlp(hidden_states[mask])
                else:
                    final_hidden_states[mask] = self.modal_moes[modal_tag](hidden_states[mask])
            hidden_states = final_hidden_states.view(batch_size, sequence_length, hidden_dim)
        else:
            hidden_states = self.mlp(hidden_states)
        hidden_states = residual + hidden_states

        outputs = (hidden_states,)

        if output_attentions:
            outputs += (self_attn_weights,)

        if use_cache:
            outputs += (present_key_value,)

        return outputs


class Qwen2ModelMoE(Qwen2Model):

    def __init__(self, config, modal_tags=None, add_moe=True):
        super().__init__(config)

        self.modal_tags = modal_tags

        self.layers = nn.ModuleList(
            [Qwen2DecoderLayerMoE(config, layer_idx, modal_tags, add_moe=add_moe)
             for layer_idx in range(config.num_hidden_layers)]
        )

    def forward(
        self,
        input_ids: torch.LongTensor = None,
        attention_mask: Optional[torch.Tensor] = None,
        position_ids: Optional[torch.LongTensor] = None,
        past_key_values: Optional[List[torch.FloatTensor]] = None,
        inputs_embeds: Optional[torch.FloatTensor] = None,
        use_cache: Optional[bool] = None,
        output_attentions: Optional[bool] = None,
        output_hidden_states: Optional[bool] = None,
        return_dict: Optional[bool] = None,
        cache_position: Optional[torch.LongTensor] = None,
        modal_tag_pos_list: Optional[List[List[Tuple[str, int, int]]]] = None,  # batch, modal_num, (tag, start, end)
    ) -> Union[Tuple, BaseModelOutputWithPast]:
        output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions
        output_hidden_states = (
            output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states
        )
        use_cache = use_cache if use_cache is not None else self.config.use_cache

        return_dict = return_dict if return_dict is not None else self.config.use_return_dict

        if (input_ids is None) ^ (inputs_embeds is not None):
            raise ValueError("You must specify exactly one of input_ids or inputs_embeds")

        if self.gradient_checkpointing and self.training:
            if use_cache:
                rank0_print(
                    "`use_cache=True` is incompatible with gradient checkpointing. Setting `use_cache=False`..."
                )
                use_cache = False

        # kept for BC (non `Cache` `past_key_values` inputs)
        return_legacy_cache = False
        if use_cache and not isinstance(past_key_values, Cache):
            return_legacy_cache = True
            if past_key_values is None:
                past_key_values = DynamicCache()
            else:
                past_key_values = DynamicCache.from_legacy_cache(past_key_values)
                rank0_print(
                    "We detected that you are passing `past_key_values` as a tuple of tuples. This is deprecated and "
                    "will be removed in v4.47. Please convert your cache or use an appropriate `Cache` class "
                    "(https://huggingface.co/docs/transformers/kv_cache#legacy-cache-format)"
                )

        if inputs_embeds is None:
            inputs_embeds = self.embed_tokens(input_ids)

        if cache_position is None:
            past_seen_tokens = past_key_values.get_seq_length() if past_key_values is not None else 0
            cache_position = torch.arange(
                past_seen_tokens, past_seen_tokens + inputs_embeds.shape[1], device=inputs_embeds.device
            )
        if position_ids is None:
            position_ids = cache_position.unsqueeze(0)

        causal_mask = self._update_causal_mask(
            attention_mask, inputs_embeds, cache_position, past_key_values, output_attentions
        )

        hidden_states = inputs_embeds

        # create position embeddings to be shared across the decoder layers
        position_embeddings = self.rotary_emb(hidden_states, position_ids)

        # decoder layers
        all_hidden_states = () if output_hidden_states else None
        all_self_attns = () if output_attentions else None
        next_decoder_cache = None

        modal_token_idx_matrix = torch.zeros(hidden_states.shape[:2], dtype=torch.int8)
        modal_idx_mapping = {"text": 0}
        if self.modal_tags and modal_tag_pos_list:
            for modal_idx, tag in enumerate(self.modal_tags):
                modal_idx_mapping[tag] = modal_idx + 1
            for sample_id, single_sample_mtp in enumerate(modal_tag_pos_list):
                for tag, spos, epos in single_sample_mtp:
                    modal_token_idx_matrix[sample_id, spos:epos+1] = modal_idx_mapping[tag]

        for decoder_layer in self.layers:
            if output_hidden_states:
                all_hidden_states += (hidden_states,)

            if self.gradient_checkpointing and self.training:
                layer_outputs = self._gradient_checkpointing_func(
                    decoder_layer.__call__,
                    hidden_states,
                    modal_token_idx_matrix,
                    modal_idx_mapping,
                    causal_mask,
                    position_ids,
                    past_key_values,
                    output_attentions,
                    use_cache,
                    cache_position,
                    position_embeddings,
                )
            else:
                layer_outputs = decoder_layer(
                    hidden_states,
                    modal_token_idx_matrix=modal_token_idx_matrix,
                    modal_idx_mapping=modal_idx_mapping,
                    attention_mask=causal_mask,
                    position_ids=position_ids,
                    past_key_value=past_key_values,
                    output_attentions=output_attentions,
                    use_cache=use_cache,
                    cache_position=cache_position,
                    position_embeddings=position_embeddings,
                )

            hidden_states = layer_outputs[0]

            if use_cache:
                next_decoder_cache = layer_outputs[2 if output_attentions else 1]

            if output_attentions:
                all_self_attns += (layer_outputs[1],)

        hidden_states = self.norm(hidden_states)

        # add hidden states from the last decoder layer
        if output_hidden_states:
            all_hidden_states += (hidden_states,)

        next_cache = next_decoder_cache if use_cache else None
        if return_legacy_cache:
            next_cache = next_cache.to_legacy_cache()

        if not return_dict:
            return tuple(v for v in [hidden_states, next_cache, all_hidden_states, all_self_attns] if v is not None)
        return BaseModelOutputWithPast(
            last_hidden_state=hidden_states,
            past_key_values=next_cache,
            hidden_states=all_hidden_states,
            attentions=all_self_attns,
        )


class Qwen2ForCausalLMMoE(Qwen2ForCausalLM):

    def __init__(self, config, modal_tags=None, add_moe=True):
        super().__init__(config)
        self.model = Qwen2ModelMoE(config, modal_tags, add_moe=add_moe)

    def forward(
        self,
        input_ids: torch.LongTensor = None,
        attention_mask: Optional[torch.Tensor] = None,
        position_ids: Optional[torch.LongTensor] = None,
        past_key_values: Optional[List[torch.FloatTensor]] = None,
        inputs_embeds: Optional[torch.FloatTensor] = None,
        labels: Optional[torch.LongTensor] = None,
        use_cache: Optional[bool] = None,
        output_attentions: Optional[bool] = None,
        output_hidden_states: Optional[bool] = None,
        return_dict: Optional[bool] = None,
        cache_position: Optional[torch.LongTensor] = None,
        num_logits_to_keep: int = 0,
        modal_tag_pos_list: Optional[List[List[Tuple[str, int, int]]]] = None,  # batch, modal_num, (tag, start, end)
        **loss_kwargs,
    ) -> Union[Tuple, CausalLMOutputWithPast]:

        output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions
        output_hidden_states = (
            output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states
        )
        return_dict = return_dict if return_dict is not None else self.config.use_return_dict

        # decoder outputs consists of (dec_features, layer_state, dec_hidden, dec_attn)
        outputs = self.model(
            input_ids=input_ids,
            attention_mask=attention_mask,
            position_ids=position_ids,
            past_key_values=past_key_values,
            inputs_embeds=inputs_embeds,
            use_cache=use_cache,
            output_attentions=output_attentions,
            output_hidden_states=output_hidden_states,
            return_dict=return_dict,
            cache_position=cache_position,
            modal_tag_pos_list=modal_tag_pos_list,
        )

        hidden_states = outputs[0]
        # Only compute necessary logits, and do not upcast them to float if we are not computing the loss
        logits = self.lm_head(hidden_states[:, -num_logits_to_keep:, :])

        loss = None
        if labels is not None:
            loss = self.loss_function(logits, labels, self.vocab_size, **loss_kwargs)

        if not return_dict:
            output = (logits,) + outputs[1:]
            return (loss,) + output if loss is not None else output

        return CausalLMOutputWithPast(
            loss=loss,
            logits=logits,
            past_key_values=outputs.past_key_values,
            hidden_states=outputs.hidden_states,
            attentions=outputs.attentions,
        )

    def prepare_inputs_for_generation(self,
                                      input_ids: torch.LongTensor,
                                      past_key_values: Optional[Cache] = None,
                                      attention_mask: Optional[torch.LongTensor] = None,
                                      inputs_embeds: Optional[torch.FloatTensor] = None,
                                      cache_position: Optional[torch.LongTensor] = None,
                                      **kwargs,):
        model_inputs = super().prepare_inputs_for_generation(input_ids,
                                                             past_key_values,
                                                             attention_mask,
                                                             inputs_embeds,
                                                             cache_position,
                                                             **kwargs)
        model_inputs.update(
            {
                "modal_tag_pos_list": kwargs.get("modal_tag_pos_list", None),
            }
        )
        return model_inputs

    def init_modal_moe_params(self, target_tag, src_tag):
        for i, decoder_layer in enumerate(self.model.layers):
            if hasattr(decoder_layer, "modal_moes"):
                if src_tag == "text":
                    rank0_print(f"Initializing layer{i} {target_tag} moe params for text")
                    mlp_module = decoder_layer.mlp
                    decoder_layer.modal_moes[target_tag].load_state_dict(copy.deepcopy(mlp_module.state_dict()))
                else:
                    rank0_print(f"Initializing layer{i} {target_tag} moe params for {src_tag}")
                    assert src_tag in decoder_layer.modal_moes, f"src_tag {src_tag} not found in decoder_layer.modal_moes"
                    src_module = decoder_layer.modal_moes[src_tag]
                    decoder_layer.modal_moes[target_tag].load_state_dict(copy.deepcopy(src_module.state_dict()))
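Note: a minimal sketch of the `modal_tag_pos_list` argument, following the annotation above (one list per batch sample; each entry is a `(tag, start, end)` span with an inclusive end). The token positions here are hypothetical.

# Batch of two samples; positions are made-up token indices.
# Sample 0: an image occupies tokens 5..733; everything else is text.
# Sample 1: a point cloud occupies tokens 5..517.
modal_tag_pos_list = [
    [("vision", 5, 733)],
    [("point", 5, 517)],
]
# Inside Qwen2ModelMoE.forward these spans fill modal_token_idx_matrix, so each
# decoder layer routes those positions through the matching expert MLP in
# modal_moes[tag], while the remaining positions go through the shared text MLP.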
mm_models/modal_module/__init__.py
ADDED
@@ -0,0 +1,22 @@
from .vision.siglip import build_vision_encoder, build_vision_projector, VISION_SIGLIP_MODAL_CFG
from .point.reconv2 import build_point_encoder, build_point_projector, POINT_RECON2_MODAL_CFG
# NOTE: import custom modal encoder and projector here


MODAL_CFG_MAPPING = {
    'vision': VISION_SIGLIP_MODAL_CFG,
    'point': POINT_RECON2_MODAL_CFG
    # NOTE: add other modalities here
}

MODAL_ENCODERS_MAPPING = {
    'vision': build_vision_encoder,
    'point': build_point_encoder
    # NOTE: add other modalities here
}

MODAL_PROJECTORS_MAPPING = {
    'vision': build_vision_projector,
    'point': build_point_projector
    # NOTE: add other modalities here
}
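Note: a minimal sketch of what the NOTE markers above suggest for registering an extra modality. The module path, builder names, and config constant below are hypothetical and would have to match a real encoder implementation.

# Hypothetical audio modality, mirroring the vision/point entries above.
from .audio.whisper_enc import build_audio_encoder, build_audio_projector, AUDIO_WHISPER_MODAL_CFG  # hypothetical module

MODAL_CFG_MAPPING['audio'] = AUDIO_WHISPER_MODAL_CFG
MODAL_ENCODERS_MAPPING['audio'] = build_audio_encoder
MODAL_PROJECTORS_MAPPING['audio'] = build_audio_projector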
mm_models/modal_module/__pycache__/__init__.cpython-310.pyc
ADDED
Binary file (546 Bytes).
mm_models/modal_module/point/__pycache__/reconv2.cpython-310.pyc
ADDED
Binary file (7.89 kB).
mm_models/modal_module/point/recon/__pycache__/transformer.cpython-310.pyc
ADDED
Binary file (17.8 kB).
mm_models/modal_module/point/recon/reconv2_utils/AverageMeter.py
ADDED
@@ -0,0 +1,42 @@
class AverageMeter(object):
    def __init__(self, items=None):
        self.items = items
        self.n_items = 1 if items is None else len(items)
        self.reset()

    def reset(self):
        self._val = [0] * self.n_items
        self._sum = [0] * self.n_items
        self._count = [0] * self.n_items

    def update(self, values):
        if type(values).__name__ == 'list':
            for idx, v in enumerate(values):
                self._val[idx] = v
                self._sum[idx] += v
                self._count[idx] += 1
        else:
            self._val[0] = values
            self._sum[0] += values
            self._count[0] += 1

    def val(self, idx=None):
        if idx is None:
            return self._val[0] if self.items is None else [self._val[i] for i in range(self.n_items)]
        else:
            return self._val[idx]

    def count(self, idx=None):
        if idx is None:
            return self._count[0] if self.items is None else [self._count[i] for i in range(self.n_items)]
        else:
            return self._count[idx]

    def avg(self, idx=None):
        if idx is None:
            return self._sum[0] / self._count[0] if self.items is None else [
                self._sum[i] / self._count[i] for i in range(self.n_items)
            ]
        else:
            return self._sum[idx] / self._count[idx]
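Note: a quick usage sketch for `AverageMeter`, tracking two named quantities; the metric names are illustrative.

meter = AverageMeter(['loss', 'acc'])  # track two quantities per update
meter.update([0.9, 0.50])
meter.update([0.7, 0.60])
print(meter.val())   # last values:   [0.7, 0.6]
print(meter.avg())   # running means: [0.8, 0.55]
print(meter.avg(1))  # mean of 'acc' only: 0.55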
mm_models/modal_module/point/recon/reconv2_utils/__pycache__/knn.cpython-310.pyc
ADDED
Binary file (1.44 kB).
mm_models/modal_module/point/recon/reconv2_utils/__pycache__/logger.cpython-310.pyc
ADDED
Binary file (3.98 kB).
mm_models/modal_module/point/recon/reconv2_utils/__pycache__/misc.cpython-310.pyc
ADDED
Binary file (9.2 kB).
mm_models/modal_module/point/recon/reconv2_utils/checkpoint.py
ADDED
@@ -0,0 +1,129 @@
#!/usr/bin/env python3
# Copyright (c) Facebook, Inc. and its affiliates. All Rights Reserved.

from collections import defaultdict
import torch.nn as nn

from typing import Any
from typing import Optional, List, Dict, NamedTuple, Tuple, Iterable

from termcolor import colored

def get_missing_parameters_message(keys: List[str]) -> str:
    """
    Get a logging-friendly message to report parameter names (keys) that are in
    the model but not found in a checkpoint.
    Args:
        keys (list[str]): List of keys that were not found in the checkpoint.
    Returns:
        str: message.
    """
    groups = _group_checkpoint_keys(keys)
    msg = "Some model parameters or buffers are not found in the checkpoint:\n"
    msg += "\n".join(
        "  " + colored(k + _group_to_str(v), "blue") for k, v in groups.items()
    )
    return msg


def get_unexpected_parameters_message(keys: List[str]) -> str:
    """
    Get a logging-friendly message to report parameter names (keys) that are in
    the checkpoint but not found in the model.
    Args:
        keys (list[str]): List of keys that were not found in the model.
    Returns:
        str: message.
    """
    groups = _group_checkpoint_keys(keys)
    msg = "The checkpoint state_dict contains keys that are not used by the model:\n"
    msg += "\n".join(
        "  " + colored(k + _group_to_str(v), "magenta") for k, v in groups.items()
    )
    return msg


def _strip_prefix_if_present(state_dict: Dict[str, Any], prefix: str) -> None:
    """
    Strip the prefix in metadata, if any.
    Args:
        state_dict (OrderedDict): a state-dict to be loaded to the model.
        prefix (str): prefix.
    """
    keys = sorted(state_dict.keys())
    if not all(len(key) == 0 or key.startswith(prefix) for key in keys):
        return

    for key in keys:
        newkey = key[len(prefix):]
        state_dict[newkey] = state_dict.pop(key)

    # also strip the prefix in metadata, if any..
    try:
        metadata = state_dict._metadata  # pyre-ignore
    except AttributeError:
        pass
    else:
        for key in list(metadata.keys()):
            # for the metadata dict, the key can be:
            # '': for the DDP module, which we want to remove.
            # 'module': for the actual model.
            # 'module.xx.xx': for the rest.

            if len(key) == 0:
                continue
            newkey = key[len(prefix):]
            metadata[newkey] = metadata.pop(key)


def _group_checkpoint_keys(keys: List[str]) -> Dict[str, List[str]]:
    """
    Group keys based on common prefixes. A prefix is the string up to the final
    "." in each key.
    Args:
        keys (list[str]): list of parameter names, i.e. keys in the model
            checkpoint dict.
    Returns:
        dict[list]: keys with common prefixes are grouped into lists.
    """
    groups = defaultdict(list)
    for key in keys:
        pos = key.rfind(".")
        if pos >= 0:
            head, tail = key[:pos], [key[pos + 1:]]
        else:
            head, tail = key, []
        groups[head].extend(tail)
    return groups


def _group_to_str(group: List[str]) -> str:
    """
    Format a group of parameter name suffixes into a loggable string.
    Args:
        group (list[str]): list of parameter name suffixes.
    Returns:
        str: formatted string.
    """
    if len(group) == 0:
        return ""

    if len(group) == 1:
        return "." + group[0]

    return ".{" + ", ".join(group) + "}"


def _named_modules_with_dup(
    model: nn.Module, prefix: str = ""
) -> Iterable[Tuple[str, nn.Module]]:
    """
    The same as `model.named_modules()`, except that it includes
    duplicated modules that have more than one name.
    """
    yield prefix, model
    for name, module in model._modules.items():  # pyre-ignore
        if module is None:
            continue
        submodule_prefix = prefix + ("." if prefix else "") + name
        yield from _named_modules_with_dup(module, submodule_prefix)
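Note: a brief usage sketch, feeding the `missing_keys` from a non-strict `load_state_dict` call into the message helper above; the model and checkpoint are stand-ins.

import torch

model = nn.Linear(4, 2)                # stand-in model
ckpt = {"weight": torch.zeros(2, 4)}   # checkpoint that is missing the bias
incompat = model.load_state_dict(ckpt, strict=False)
if incompat.missing_keys:
    print(get_missing_parameters_message(incompat.missing_keys))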
mm_models/modal_module/point/recon/reconv2_utils/config.py
ADDED
@@ -0,0 +1,69 @@
import yaml
from easydict import EasyDict
import os
from .logger import print_log


def log_args_to_file(args, pre='args', logger=None):
    for key, val in args.__dict__.items():
        print_log(f'{pre}.{key} : {val}', logger=logger)


def log_config_to_file(cfg, pre='cfg', logger=None):
    for key, val in cfg.items():
        if isinstance(cfg[key], EasyDict):
            print_log(f'{pre}.{key} = edict()', logger=logger)
            log_config_to_file(cfg[key], pre=pre + '.' + key, logger=logger)
            continue
        print_log(f'{pre}.{key} : {val}', logger=logger)


def merge_new_config(config, new_config):
    for key, val in new_config.items():
        if not isinstance(val, dict):
            if key == '_base_':
                with open(new_config['_base_'], 'r') as f:
                    try:
                        val = yaml.load(f, Loader=yaml.FullLoader)
                    except:
                        val = yaml.load(f)
                config[key] = EasyDict()
                merge_new_config(config[key], val)
            else:
                config[key] = val
            continue
        if key not in config:
            config[key] = EasyDict()
        merge_new_config(config[key], val)
    return config


def cfg_from_yaml_file(cfg_file):
    config = EasyDict()
    with open(cfg_file, 'r') as f:
        try:
            new_config = yaml.load(f, Loader=yaml.FullLoader)
        except:
            new_config = yaml.load(f)
    merge_new_config(config=config, new_config=new_config)
    return config


def get_config(args, logger=None):
    if args.resume:
        cfg_path = os.path.join(args.experiment_path, 'config.yaml')
        if not os.path.exists(cfg_path):
            print_log("Failed to resume", logger=logger)
            raise FileNotFoundError()
        print_log(f'Resume yaml from {cfg_path}', logger=logger)
        args.config = cfg_path
    config = cfg_from_yaml_file(args.config)
    if not args.resume and args.local_rank == 0:
        save_experiment_config(args, config, logger)
    return config


def save_experiment_config(args, config, logger=None):
    config_path = os.path.join(args.experiment_path, 'config.yaml')
    os.system('cp %s %s' % (args.config, config_path))
    print_log(f'Copy the Config file from {args.config} to {config_path}', logger=logger)
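Note: a short sketch of the `_base_` inheritance this loader supports. When a YAML file names a base config, `merge_new_config` loads the base file and nests its contents under the `_base_` key while merging everything else recursively; the file names and keys here are illustrative.

# base.yaml:      model: {depth: 12}
# finetune.yaml:  _base_: base.yaml
#                 model: {lr: 0.0001}
cfg = cfg_from_yaml_file('finetune.yaml')
print(cfg._base_.model.depth)  # 12 -- the base config lands under cfg._base_
print(cfg.model.lr)            # 0.0001 -- keys from the child file merge normally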
mm_models/modal_module/point/recon/reconv2_utils/data.py
ADDED
@@ -0,0 +1,109 @@
import numpy as np


def random_rotate_z(pc):
    # random rotate around z axis
    theta = np.random.uniform(0, 2 * np.pi)
    R = np.array([[np.cos(theta), -np.sin(theta), 0],
                  [np.sin(theta), np.cos(theta), 0],
                  [0, 0, 1]])
    return np.matmul(pc, R)


def normalize_pc(pc):
    """ pc: NxC, return NxC """
    centroid = np.mean(pc, axis=0)
    pc = pc - centroid
    m = np.max(np.sqrt(np.sum(pc ** 2, axis=1)))
    if m < 1e-6:
        pc = np.zeros_like(pc)
    else:
        pc = pc / m
    return pc


def random_point_dropout(batch_pc, max_dropout_ratio=0.875):
    """ batch_pc: BxNx3 """
    for b in range(batch_pc.shape[0]):
        dropout_ratio = np.random.random() * max_dropout_ratio  # 0~0.875
        drop_idx = np.where(np.random.random((batch_pc.shape[1])) <= dropout_ratio)[0]
        if len(drop_idx) > 0:
            batch_pc[b, drop_idx, :] = batch_pc[b, 0, :]  # set to the first point
    return batch_pc


def random_scale_point_cloud(data, scale_low=0.8, scale_high=1.25):
    scales = np.random.uniform(scale_low, scale_high)
    data *= scales
    return data


def shift_point_cloud(batch_data, shift_range=0.1):
    """ Randomly shift point cloud. Shift is per point cloud.
        Input:
            BxNx3 array, original batch of point clouds
        Return:
            BxNx3 array, shifted batch of point clouds
    """
    B, N, C = batch_data.shape
    shifts = np.random.uniform(-shift_range, shift_range, (B, 3))
    for batch_index in range(B):
        batch_data[batch_index, :, :] += shifts[batch_index, :]
    return batch_data


def rotate_perturbation_point_cloud(batch_data, angle_sigma=0.06, angle_clip=0.18):
    """ Randomly perturb the point clouds by small rotations
        Input:
            BxNx3 array, original batch of point clouds
        Return:
            BxNx3 array, rotated batch of point clouds
    """
    rotated_data = np.zeros(batch_data.shape, dtype=np.float32)
    for k in range(batch_data.shape[0]):
        angles = np.clip(angle_sigma * np.random.randn(3), -angle_clip, angle_clip)
        Rx = np.array([[1, 0, 0],
                       [0, np.cos(angles[0]), -np.sin(angles[0])],
                       [0, np.sin(angles[0]), np.cos(angles[0])]])
        Ry = np.array([[np.cos(angles[1]), 0, np.sin(angles[1])],
                       [0, 1, 0],
                       [-np.sin(angles[1]), 0, np.cos(angles[1])]])
        Rz = np.array([[np.cos(angles[2]), -np.sin(angles[2]), 0],
                       [np.sin(angles[2]), np.cos(angles[2]), 0],
                       [0, 0, 1]])
        R = np.dot(Rz, np.dot(Ry, Rx))
        shape_pc = batch_data[k, ...]
        rotated_data[k, ...] = np.dot(shape_pc.reshape((-1, 3)), R)
    return rotated_data


def rotate_point_cloud(batch_data):
    """ Randomly rotate the point clouds to augment the dataset
        rotation is per shape based along up direction
        Input:
            BxNx3 array, original batch of point clouds
        Return:
            BxNx3 array, rotated batch of point clouds
    """
    rotated_data = np.zeros(batch_data.shape, dtype=np.float32)
    for k in range(batch_data.shape[0]):
        rotation_angle = np.random.uniform() * 2 * np.pi
        cosval = np.cos(rotation_angle)
        sinval = np.sin(rotation_angle)
        rotation_matrix = np.array([[cosval, 0, sinval],
                                    [0, 1, 0],
                                    [-sinval, 0, cosval]])
        shape_pc = batch_data[k, ...]
        rotated_data[k, ...] = np.dot(shape_pc.reshape((-1, 3)), rotation_matrix)
    return rotated_data


def augment_pc(data):
    # data = random_point_dropout(data[None, ...])
    data = random_scale_point_cloud(data[None, ...])
    data = shift_point_cloud(data)
    data = rotate_perturbation_point_cloud(data)
    data = rotate_point_cloud(data)
    data = data.squeeze()
    return data
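Note: a quick sketch of running the augmentation pipeline on a single Nx3 cloud; `augment_pc` adds and strips the batch dimension itself, and the random stand-in cloud is illustrative.

pc = np.random.rand(8192, 3).astype(np.float32)  # stand-in point cloud
pc = normalize_pc(pc)   # center and scale to the unit sphere
pc = augment_pc(pc)     # random scale, shift, small perturbation, rotation
print(pc.shape)         # (8192, 3)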
mm_models/modal_module/point/recon/reconv2_utils/dist_utils.py
ADDED
@@ -0,0 +1,49 @@
import torch
from torch import distributed as dist


def init_dist(local_rank, backend='nccl', **kwargs):
    torch.cuda.set_device(local_rank)
    dist.init_process_group(backend=backend, **kwargs)
    print(f'init distributed in rank {local_rank}')


def reduce_tensor(tensor, args):
    '''
    for acc kind, get the mean in each gpu
    '''
    rt = tensor.clone()
    torch.distributed.all_reduce(rt, op=torch.distributed.ReduceOp.SUM)
    rt /= args.world_size
    return rt


def gather_tensor(tensor, args):
    output_tensors = [tensor.clone() for _ in range(args.world_size)]
    torch.distributed.all_gather(output_tensors, tensor)
    concat = torch.cat(output_tensors, dim=0)
    return concat


def set_batch_size(args, config):
    if args.distributed:
        assert config.total_bs % args.world_size == 0
        if config.dataset.get('train'):
            config.dataset.train.others.bs = config.total_bs // args.world_size
        if config.dataset.get('extra_train'):
            config.dataset.extra_train.others.bs = config.total_bs // args.world_size
        if config.dataset.get('val'):
            config.dataset.val.others.bs = config.total_bs // args.world_size
        if config.dataset.get('test'):
            config.dataset.test.others.bs = config.total_bs // args.world_size
    else:
        if config.dataset.get('train'):
            config.dataset.train.others.bs = config.total_bs
        if config.dataset.get('extra_train'):
            config.dataset.extra_train.others.bs = config.total_bs
        if config.dataset.get('extra_val'):
            config.dataset.extra_val.others.bs = config.total_bs
        if config.dataset.get('val'):
            config.dataset.val.others.bs = config.total_bs
        if config.dataset.get('test'):
            config.dataset.test.others.bs = config.total_bs
mm_models/modal_module/point/recon/reconv2_utils/knn.py
ADDED
@@ -0,0 +1,37 @@
import torch


def square_distance(src, dst):
    """
    Calculate Euclid distance between each two points.
    src^T * dst = xn * xm + yn * ym + zn * zm;
    sum(src^2, dim=-1) = xn*xn + yn*yn + zn*zn;
    sum(dst^2, dim=-1) = xm*xm + ym*ym + zm*zm;
    dist = (xn - xm)^2 + (yn - ym)^2 + (zn - zm)^2
         = sum(src**2, dim=-1) + sum(dst**2, dim=-1) - 2 * src^T * dst
    Input:
        src: source points, [B, N, C]
        dst: target points, [B, M, C]
    Output:
        dist: per-point square distance, [B, N, M]
    """
    B, N, _ = src.shape
    _, M, _ = dst.shape
    dist = -2 * torch.matmul(src, dst.permute(0, 2, 1))
    dist += torch.sum(src ** 2, -1).view(B, N, 1)
    dist += torch.sum(dst ** 2, -1).view(B, 1, M)
    return dist


def knn_point(nsample, xyz, new_xyz):
    """
    Input:
        nsample: max sample number in local region
        xyz: all points, [B, N, C]
        new_xyz: query points, [B, S, C]
    Return:
        group_idx: grouped points index, [B, S, nsample]
    """
    sqrdists = square_distance(new_xyz, xyz)
    _, group_idx = torch.topk(sqrdists, nsample, dim=-1, largest=False, sorted=False)
    return group_idx
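Note: a small sketch of grouping each query point with its 16 nearest neighbours; the shapes follow the docstrings above, and the random tensors are stand-ins.

xyz = torch.rand(2, 1024, 3)     # B=2 clouds of 1024 points each
new_xyz = torch.rand(2, 128, 3)  # 128 query centers per cloud
idx = knn_point(16, xyz, new_xyz)
print(idx.shape)                 # torch.Size([2, 128, 16])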
mm_models/modal_module/point/recon/reconv2_utils/logger.py
ADDED
@@ -0,0 +1,127 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
import logging
|
2 |
+
import torch.distributed as dist
|
3 |
+
|
4 |
+
logger_initialized = {}
|
5 |
+
|
6 |
+
def get_root_logger(log_file=None, log_level=logging.INFO, name='main'):
|
7 |
+
"""Get root logger and add a keyword filter to it.
|
8 |
+
The logger will be initialized if it has not been initialized. By default a
|
9 |
+
StreamHandler will be added. If `log_file` is specified, a FileHandler will
|
10 |
+
also be added. The name of the root logger is the top-level package name,
|
11 |
+
e.g., "mmdet3d".
|
12 |
+
Args:
|
13 |
+
log_file (str, optional): File path of log. Defaults to None.
|
14 |
+
log_level (int, optional): The level of logger.
|
15 |
+
Defaults to logging.INFO.
|
16 |
+
name (str, optional): The name of the root logger, also used as a
|
17 |
+
filter keyword. Defaults to 'mmdet3d'.
|
18 |
+
Returns:
|
19 |
+
:obj:`logging.Logger`: The obtained logger
|
20 |
+
"""
|
21 |
+
logger = get_logger(name=name, log_file=log_file, log_level=log_level)
|
22 |
+
    # add a logging filter
    logging_filter = logging.Filter(name)
    # a LogRecord has no .find(); filter on the record's logger name instead
    logging_filter.filter = lambda record: record.name.find(name) != -1

    return logger


def get_logger(name, log_file=None, log_level=logging.INFO, file_mode='w'):
    """Initialize and get a logger by name.

    If the logger has not been initialized, this method will initialize the
    logger by adding one or two handlers; otherwise the initialized logger
    will be returned directly. During initialization, a StreamHandler is
    always added. If `log_file` is specified and the process rank is 0, a
    FileHandler will also be added.

    Args:
        name (str): Logger name.
        log_file (str | None): The log filename. If specified, a FileHandler
            will be added to the logger.
        log_level (int): The logger level. Note that only the process of
            rank 0 is affected; other processes set the level to ERROR and
            are thus silent most of the time.
        file_mode (str): The file mode used in opening the log file.
            Defaults to 'w'.

    Returns:
        logging.Logger: The expected logger.
    """
    logger = logging.getLogger(name)
    if name in logger_initialized:
        return logger
    # handle hierarchical names
    # e.g., logger "a" is initialized, then logger "a.b" will skip the
    # initialization since it is a child of "a".
    for logger_name in logger_initialized:
        if name.startswith(logger_name):
            return logger

    # handle duplicate logs to the console
    # Starting in 1.8.0, PyTorch DDP attaches a StreamHandler <stderr> (NOTSET)
    # to the root logger. As logger.propagate is True by default, this root
    # level handler causes logging messages from rank>0 processes to
    # unexpectedly show up on the console, creating much unwanted clutter.
    # To fix this issue, we set the root logger's StreamHandler, if any, to log
    # at the ERROR level.
    for handler in logger.root.handlers:
        if type(handler) is logging.StreamHandler:
            handler.setLevel(logging.ERROR)

    stream_handler = logging.StreamHandler()
    handlers = [stream_handler]

    if dist.is_available() and dist.is_initialized():
        rank = dist.get_rank()
    else:
        rank = 0

    # only rank 0 will add a FileHandler
    if rank == 0 and log_file is not None:
        # The stdlib FileHandler defaults to mode 'a'; `file_mode` exposes
        # this choice to the caller (this function defaults to 'w').
        file_handler = logging.FileHandler(log_file, file_mode)
        handlers.append(file_handler)

    formatter = logging.Formatter(
        '%(asctime)s - %(name)s - %(levelname)s - %(message)s')
    for handler in handlers:
        handler.setFormatter(formatter)
        handler.setLevel(log_level)
        logger.addHandler(handler)

    if rank == 0:
        logger.setLevel(log_level)
    else:
        logger.setLevel(logging.ERROR)

    logger_initialized[name] = True

    return logger


def print_log(msg, logger=None, level=logging.INFO):
    """Print a log message.

    Args:
        msg (str): The message to be logged.
        logger (logging.Logger | str | None): The logger to be used.
            Some special loggers are:
            - "silent": no message will be printed.
            - other str: the logger obtained with `get_root_logger(logger)`.
            - None: the `print()` method will be used to print log messages.
        level (int): Logging level. Only available when `logger` is a Logger
            object or "root".
    """
    if logger is None:
        print(msg)
    elif isinstance(logger, logging.Logger):
        logger.log(level, msg)
    elif logger == 'silent':
        pass
    elif isinstance(logger, str):
        _logger = get_logger(logger)
        _logger.log(level, msg)
    else:
        raise TypeError(
            'logger should be either a logging.Logger object, str, '
            f'"silent" or None, but got {type(logger)}')
mm_models/modal_module/point/recon/reconv2_utils/misc.py
ADDED
@@ -0,0 +1,294 @@
import numpy as np
import matplotlib.pyplot as plt
from mpl_toolkits.mplot3d import Axes3D
import random
import torch
import torch.nn as nn
import torch.nn.functional as F
import os
from collections import abc
# from pointnet2_ops import pointnet2_utils


# CUDA-kernel variant of farthest point sampling, kept for reference:
# def fps(data, number):
#     '''
#         data B N 3
#         number int
#     '''
#     fps_idx = pointnet2_utils.furthest_point_sample(data, number)
#     fps_data = pointnet2_utils.gather_operation(data.transpose(1, 2).contiguous(), fps_idx).transpose(1, 2).contiguous()
#     return fps_data


def index_points(points, idx):
    """
    Input:
        points: input points data, [B, N, C]
        idx: sample index data, [B, S]
    Return:
        new_points: indexed points data, [B, S, C]
    """
    device = points.device
    B = points.shape[0]
    view_shape = list(idx.shape)
    view_shape[1:] = [1] * (len(view_shape) - 1)
    repeat_shape = list(idx.shape)
    repeat_shape[0] = 1
    batch_indices = torch.arange(B, dtype=torch.long).to(device).view(view_shape).repeat(repeat_shape)
    new_points = points[batch_indices, idx, :]
    return new_points


def fps(point_data, npoint):
    """Farthest point sampling (pure PyTorch).
    Input:
        point_data: point cloud data, [B, N, 3+] (xyz in the first 3 channels)
        npoint: number of samples
    Return:
        sampled points, [B, npoint, C] (note: the points themselves, not indices)
    """
    xyz = point_data[:, :, :3]
    device = xyz.device
    B, N, C = xyz.shape
    centroids = torch.zeros(B, npoint, dtype=torch.long).to(device)
    distance = torch.ones(B, N).to(device) * 1e10
    farthest = torch.randint(0, N, (B,), dtype=torch.long).to(device)
    batch_indices = torch.arange(B, dtype=torch.long).to(device)
    for i in range(npoint):
        centroids[:, i] = farthest
        centroid = xyz[batch_indices, farthest, :].view(B, 1, 3)
        dist = torch.sum((xyz - centroid) ** 2, -1)
        distance = torch.min(distance, dist)
        farthest = torch.max(distance, -1)[1]
    return index_points(point_data, centroids)


def worker_init_fn(worker_id):
    np.random.seed(np.random.get_state()[1][0] + worker_id)


def build_lambda_sche(opti, config):
    if config.get('decay_step') is not None:
        lr_lbmd = lambda e: max(config.lr_decay ** (e / config.decay_step), config.lowest_decay)
        scheduler = torch.optim.lr_scheduler.LambdaLR(opti, lr_lbmd)
    else:
        raise NotImplementedError()
    return scheduler


def build_lambda_bnsche(model, config):
    if config.get('decay_step') is not None:
        bnm_lmbd = lambda e: max(config.bn_momentum * config.bn_decay ** (e / config.decay_step), config.lowest_decay)
        bnm_scheduler = BNMomentumScheduler(model, bnm_lmbd)
    else:
        raise NotImplementedError()
    return bnm_scheduler


def set_random_seed(seed, deterministic=False):
    """Set random seed.

    Args:
        seed (int): Seed to be used.
        deterministic (bool): Whether to set the deterministic option for
            the cuDNN backend, i.e., set `torch.backends.cudnn.deterministic`
            to True and `torch.backends.cudnn.benchmark` to False.
            Default: False.

    Speed-reproducibility tradeoff
    (https://pytorch.org/docs/stable/notes/randomness.html):
    deterministic=True  -> slower, more reproducible
    deterministic=False -> faster, less reproducible
    """
    random.seed(seed)
    np.random.seed(seed)
    torch.manual_seed(seed)
    torch.cuda.manual_seed_all(seed)
    if deterministic:
        torch.backends.cudnn.deterministic = True
        torch.backends.cudnn.benchmark = False


def is_seq_of(seq, expected_type, seq_type=None):
    """Check whether it is a sequence of some type.

    Args:
        seq (Sequence): The sequence to be checked.
        expected_type (type): Expected type of sequence items.
        seq_type (type, optional): Expected sequence type.
    Returns:
        bool: Whether the sequence is valid.
    """
    if seq_type is None:
        exp_seq_type = abc.Sequence
    else:
        assert isinstance(seq_type, type)
        exp_seq_type = seq_type
    if not isinstance(seq, exp_seq_type):
        return False
    for item in seq:
        if not isinstance(item, expected_type):
            return False
    return True


def set_bn_momentum_default(bn_momentum):
    def fn(m):
        if isinstance(m, (nn.BatchNorm1d, nn.BatchNorm2d, nn.BatchNorm3d)):
            m.momentum = bn_momentum

    return fn


class BNMomentumScheduler(object):

    def __init__(
            self, model, bn_lambda, last_epoch=-1,
            setter=set_bn_momentum_default
    ):
        if not isinstance(model, nn.Module):
            raise RuntimeError(
                "Class '{}' is not a PyTorch nn Module".format(
                    type(model).__name__
                )
            )

        self.model = model
        self.setter = setter
        self.lmbd = bn_lambda

        self.step(last_epoch + 1)
        self.last_epoch = last_epoch

    def step(self, epoch=None):
        if epoch is None:
            epoch = self.last_epoch + 1

        self.last_epoch = epoch
        self.model.apply(self.setter(self.lmbd(epoch)))

    def get_momentum(self, epoch=None):
        if epoch is None:
            epoch = self.last_epoch + 1
        return self.lmbd(epoch)


def seprate_point_cloud(xyz, num_points, crop, fixed_points=None, padding_zeros=False):
    '''
    Separate a point cloud: used to generate an incomplete point cloud with a
    set number of cropped points. (The function name keeps its original
    spelling because it is referenced by that name elsewhere.)
    '''
    _, n, c = xyz.shape

    assert n == num_points
    assert c == 3
    if crop == num_points:
        return xyz, None

    INPUT = []
    CROP = []
    for points in xyz:
        if isinstance(crop, list):
            num_crop = random.randint(crop[0], crop[1])
        else:
            num_crop = crop

        points = points.unsqueeze(0)

        if fixed_points is None:
            center = F.normalize(torch.randn(1, 1, 3), p=2, dim=-1).cuda()
        else:
            if isinstance(fixed_points, list):
                fixed_point = random.sample(fixed_points, 1)[0]
            else:
                fixed_point = fixed_points
            center = fixed_point.reshape(1, 1, 3).cuda()

        distance_matrix = torch.norm(center.unsqueeze(2) - points.unsqueeze(1), p=2, dim=-1)  # 1 1 2048

        idx = torch.argsort(distance_matrix, dim=-1, descending=False)[0, 0]  # 2048

        if padding_zeros:
            input_data = points.clone()
            input_data[0, idx[:num_crop]] = input_data[0, idx[:num_crop]] * 0
        else:
            input_data = points.clone()[0, idx[num_crop:]].unsqueeze(0)  # 1 N 3

        crop_data = points.clone()[0, idx[:num_crop]].unsqueeze(0)

        if isinstance(crop, list):
            INPUT.append(fps(input_data, 2048))
            CROP.append(fps(crop_data, 2048))
        else:
            INPUT.append(input_data)
            CROP.append(crop_data)

    input_data = torch.cat(INPUT, dim=0)  # B N 3
    crop_data = torch.cat(CROP, dim=0)  # B M 3

    return input_data.contiguous(), crop_data.contiguous()


def get_ptcloud_img(ptcloud, roll, pitch):
    fig = plt.figure(figsize=(8, 8))

    x, z, y = ptcloud.transpose(1, 0)
    # fig.gca(projection=...) was removed in matplotlib >= 3.6
    ax = fig.add_subplot(projection='3d')
    ax.set_adjustable('box')
    ax.axis('off')
    # ax.axis('scaled')
    ax.view_init(roll, pitch)
    pmax, pmin = np.max(ptcloud), np.min(ptcloud)  # avoid shadowing builtins max/min
    ax.set_xbound(pmin, pmax)
    ax.set_ybound(pmin, pmax)
    ax.set_zbound(pmin, pmax)
    ax.scatter(x, y, z, zdir='z', c=y, cmap='jet')

    fig.canvas.draw()
    # np.fromstring on binary buffers is deprecated; use np.frombuffer
    img = np.frombuffer(fig.canvas.tostring_rgb(), dtype=np.uint8)
    img = img.reshape(fig.canvas.get_width_height()[::-1] + (3,))
    return img


def visualize_KITTI(path, data_list, titles=['input', 'pred'], cmap=['bwr', 'autumn'], zdir='y',
                    xlim=(-1, 1), ylim=(-1, 1), zlim=(-1, 1)):
    fig = plt.figure(figsize=(6 * len(data_list), 6))
    cmax = data_list[-1][:, 0].max()

    for i in range(len(data_list)):
        data = data_list[i][:-2048] if i == 1 else data_list[i]
        color = data[:, 0] / cmax
        ax = fig.add_subplot(1, len(data_list), i + 1, projection='3d')
        ax.view_init(30, -120)
        b = ax.scatter(data[:, 0], data[:, 1], data[:, 2], zdir=zdir, c=color, vmin=-1, vmax=1, cmap=cmap[0], s=4,
                       linewidth=0.05, edgecolors='black')
        ax.set_title(titles[i])

        ax.set_axis_off()
        ax.set_xlim(xlim)
        ax.set_ylim(ylim)
        ax.set_zlim(zlim)
    plt.subplots_adjust(left=0, right=1, bottom=0, top=1, wspace=0.2, hspace=0)
    if not os.path.exists(path):
        os.makedirs(path)

    pic_path = path + '.png'
    fig.savefig(pic_path)

    np.save(os.path.join(path, 'input.npy'), data_list[0].numpy())
    np.save(os.path.join(path, 'pred.npy'), data_list[1].numpy())
    plt.close(fig)


def random_dropping(pc, e):
    up_num = max(64, 768 // (e // 50 + 1))
    random_num = torch.randint(1, up_num, (1, 1))[0, 0]
    pc = fps(pc, random_num)
    padding = torch.zeros(pc.size(0), 2048 - pc.size(1), 3).to(pc.device)
    pc = torch.cat([pc, padding], dim=1)
    return pc


def random_scale(partial, scale_range=[0.8, 1.2]):
    scale = torch.rand(1).cuda() * (scale_range[1] - scale_range[0]) + scale_range[0]
    return partial * scale
mm_models/modal_module/point/recon/reconv2_utils/parser.py
ADDED
@@ -0,0 +1,117 @@
import os
import argparse
from pathlib import Path


def get_args():
    parser = argparse.ArgumentParser()
    parser.add_argument(
        '--config',
        type=str,
        help='yaml config file')
    parser.add_argument('--distributed', action='store_true', default=False)
    parser.add_argument('--local-rank', type=int, default=0)
    parser.add_argument('--num_workers', type=int, default=8)
    # seed
    parser.add_argument('--seed', type=int, default=0, help='random seed')
    parser.add_argument(
        '--deterministic',
        action='store_true',
        help='whether to set deterministic options for CUDNN backend.')
    # bn
    parser.add_argument(
        '--sync_bn',
        action='store_true',
        default=False,
        help='whether to use sync bn')
    # some args
    parser.add_argument('--exp_name', type=str, default='default', help='experiment name')
    parser.add_argument('--start_ckpts', type=str, default=None, help='reload used ckpt path')
    parser.add_argument('--ckpts', type=str, default=None, help='test used ckpt path')
    parser.add_argument('--val_freq', type=int, default=1, help='test freq')
    parser.add_argument(
        '--vote',
        action='store_true',
        default=False,
        help='vote acc')
    parser.add_argument(
        '--resume',
        action='store_true',
        default=False,
        help='autoresume training (interrupted by accident)')
    parser.add_argument(
        '--svm',
        action='store_true',
        default=False,
        help='svm')
    parser.add_argument(
        '--zeroshot',
        action='store_true',
        default=False,
        help='zero-shot')
    parser.add_argument(
        '--test',
        action='store_true',
        default=False,
        help='test mode for certain ckpt')
    parser.add_argument(
        '--reconstruct',
        action='store_true',
        default=False,
        help='reconstruct pretraining stage')
    parser.add_argument(
        '--contrast',
        action='store_true',
        default=False,
        help='contrast pretraining stage')
    parser.add_argument(
        '--finetune_model',
        action='store_true',
        default=False,
        help='finetune modelnet with pretrained weight')
    parser.add_argument(
        '--way', type=int, default=-1)
    parser.add_argument(
        '--shot', type=int, default=-1)
    parser.add_argument(
        '--fold', type=int, default=-1)

    args = parser.parse_args()

    if args.test and args.resume:
        raise ValueError(
            '--test and --resume cannot both be active')

    if args.resume and args.start_ckpts is not None:
        raise ValueError(
            '--resume and --start_ckpts cannot both be active')

    if args.test and args.ckpts is None:
        raise ValueError(
            'ckpts should not be None in test mode')

    if args.finetune_model and args.ckpts is None:
        print(
            'training from scratch')

    if 'LOCAL_RANK' not in os.environ:
        os.environ['LOCAL_RANK'] = str(args.local_rank)

    if args.test:
        args.exp_name = 'test_' + args.exp_name
    args.experiment_path = os.path.join('./experiments', Path(args.config).stem, Path(args.config).parent.stem,
                                        args.exp_name)
    args.tfboard_path = os.path.join('./experiments', Path(args.config).stem, Path(args.config).parent.stem, 'TFBoard',
                                     args.exp_name)
    args.log_name = Path(args.config).stem
    create_experiment_dir(args)
    return args


def create_experiment_dir(args):
    if not os.path.exists(args.experiment_path):
        os.makedirs(args.experiment_path, exist_ok=True)
        print('Create experiment path successfully at %s' % args.experiment_path)
    if not os.path.exists(args.tfboard_path):
        os.makedirs(args.tfboard_path, exist_ok=True)
        print('Create TFBoard path successfully at %s' % args.tfboard_path)
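A sketch of driving get_args() from code rather than a shell (not part of the commit; the argv values and config path are placeholders, and note that the call creates ./experiments/... directories as a side effect):

import sys
from mm_models.modal_module.point.recon.reconv2_utils.parser import get_args

sys.argv = ['main.py', '--config', 'cfgs/pretrain.yaml',
            '--exp_name', 'demo', '--test', '--ckpts', 'ckpt.pth']
args = get_args()
print(args.experiment_path)  # ./experiments/pretrain/cfgs/test_demo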
mm_models/modal_module/point/recon/reconv2_utils/randaugment.py
ADDED
@@ -0,0 +1,216 @@
import logging
import random

import numpy as np
import PIL
import PIL.ImageOps
import PIL.ImageEnhance
import PIL.ImageDraw
from PIL import Image

logger = logging.getLogger(__name__)

PARAMETER_MAX = 10


def AutoContrast(img, **kwarg):
    return PIL.ImageOps.autocontrast(img)


def Brightness(img, v, max_v, bias=0):
    v = _float_parameter(v, max_v) + bias
    return PIL.ImageEnhance.Brightness(img).enhance(v)


def Color(img, v, max_v, bias=0):
    v = _float_parameter(v, max_v) + bias
    return PIL.ImageEnhance.Color(img).enhance(v)


def Contrast(img, v, max_v, bias=0):
    v = _float_parameter(v, max_v) + bias
    return PIL.ImageEnhance.Contrast(img).enhance(v)


def Cutout(img, v, max_v, bias=0):
    if v == 0:
        return img
    v = _float_parameter(v, max_v) + bias
    v = int(v * min(img.size))
    return CutoutAbs(img, v)


def CutoutAbs(img, v, **kwarg):
    w, h = img.size
    x0 = np.random.uniform(0, w)
    y0 = np.random.uniform(0, h)
    x0 = int(max(0, x0 - v / 2.))
    y0 = int(max(0, y0 - v / 2.))
    x1 = int(min(w, x0 + v))
    y1 = int(min(h, y0 + v))
    xy = (x0, y0, x1, y1)
    # gray
    color = (127, 127, 127)
    img = img.copy()
    PIL.ImageDraw.Draw(img).rectangle(xy, color)
    return img


def Equalize(img, **kwarg):
    return PIL.ImageOps.equalize(img)


def Identity(img, **kwarg):
    return img


def Invert(img, **kwarg):
    return PIL.ImageOps.invert(img)


def Posterize(img, v, max_v, bias=0):
    v = _int_parameter(v, max_v) + bias
    return PIL.ImageOps.posterize(img, v)


def Rotate(img, v, max_v, bias=0):
    v = _int_parameter(v, max_v) + bias
    if random.random() < 0.5:
        v = -v
    return img.rotate(v)


def Sharpness(img, v, max_v, bias=0):
    v = _float_parameter(v, max_v) + bias
    return PIL.ImageEnhance.Sharpness(img).enhance(v)


def ShearX(img, v, max_v, bias=0):
    v = _float_parameter(v, max_v) + bias
    if random.random() < 0.5:
        v = -v
    return img.transform(img.size, PIL.Image.AFFINE, (1, v, 0, 0, 1, 0))


def ShearY(img, v, max_v, bias=0):
    v = _float_parameter(v, max_v) + bias
    if random.random() < 0.5:
        v = -v
    return img.transform(img.size, PIL.Image.AFFINE, (1, 0, 0, v, 1, 0))


def Solarize(img, v, max_v, bias=0):
    v = _int_parameter(v, max_v) + bias
    return PIL.ImageOps.solarize(img, 256 - v)


def SolarizeAdd(img, v, max_v, bias=0, threshold=128):
    v = _int_parameter(v, max_v) + bias
    if random.random() < 0.5:
        v = -v
    img_np = np.array(img).astype(int)  # np.int was removed in NumPy 1.24
    img_np = img_np + v
    img_np = np.clip(img_np, 0, 255)
    img_np = img_np.astype(np.uint8)
    img = Image.fromarray(img_np)
    return PIL.ImageOps.solarize(img, threshold)


def TranslateX(img, v, max_v, bias=0):
    v = _float_parameter(v, max_v) + bias
    if random.random() < 0.5:
        v = -v
    v = int(v * img.size[0])
    return img.transform(img.size, PIL.Image.AFFINE, (1, 0, v, 0, 1, 0))


def TranslateY(img, v, max_v, bias=0):
    v = _float_parameter(v, max_v) + bias
    if random.random() < 0.5:
        v = -v
    v = int(v * img.size[1])
    return img.transform(img.size, PIL.Image.AFFINE, (1, 0, 0, 0, 1, v))


def _float_parameter(v, max_v):
    return float(v) * max_v / PARAMETER_MAX


def _int_parameter(v, max_v):
    return int(v * max_v / PARAMETER_MAX)


def fixmatch_augment_pool():
    # FixMatch paper
    augs = [(AutoContrast, None, None),
            (Brightness, 0.9, 0.05),
            (Color, 0.9, 0.05),
            (Contrast, 0.9, 0.05),
            (Equalize, None, None),
            (Identity, None, None),
            (Posterize, 4, 4),
            (Rotate, 30, 0),
            (Sharpness, 0.9, 0.05),
            (ShearX, 0.3, 0),
            (ShearY, 0.3, 0),
            (Solarize, 256, 0),
            (TranslateX, 0.3, 0),
            (TranslateY, 0.3, 0)]
    return augs


def my_augment_pool():
    # Test
    augs = [(AutoContrast, None, None),
            (Brightness, 1.8, 0.1),
            (Color, 1.8, 0.1),
            (Contrast, 1.8, 0.1),
            (Cutout, 0.2, 0),
            (Equalize, None, None),
            (Invert, None, None),
            (Posterize, 4, 4),
            (Rotate, 30, 0),
            (Sharpness, 1.8, 0.1),
            (ShearX, 0.3, 0),
            (ShearY, 0.3, 0),
            (Solarize, 256, 0),
            (SolarizeAdd, 110, 0),
            (TranslateX, 0.45, 0),
            (TranslateY, 0.45, 0)]
    return augs


class RandAugmentPC(object):
    def __init__(self, n, m):
        assert n >= 1
        assert 1 <= m <= 10
        self.n = n
        self.m = m
        self.augment_pool = my_augment_pool()

    def __call__(self, img):
        ops = random.choices(self.augment_pool, k=self.n)
        for op, max_v, bias in ops:
            prob = np.random.uniform(0.2, 0.8)
            if random.random() + prob >= 1:
                img = op(img, v=self.m, max_v=max_v, bias=bias)
        img = CutoutAbs(img, int(32 * 0.5))
        return img


class RandAugmentMC(object):
    def __init__(self, n, m):
        assert n >= 1
        assert 1 <= m <= 10
        self.n = n
        self.m = m
        self.augment_pool = fixmatch_augment_pool()

    def __call__(self, img):
        ops = random.choices(self.augment_pool, k=self.n)
        for op, max_v, bias in ops:
            v = np.random.randint(1, self.m)
            if random.random() < 0.5:
                img = op(img, v=v, max_v=max_v, bias=bias)
        img = CutoutAbs(img, int(32 * 0.5))
        return img
mm_models/modal_module/point/recon/reconv2_utils/registry.py
ADDED
@@ -0,0 +1,289 @@
import inspect
import warnings
from functools import partial

# this commit vendors config/misc next to this file, so import them
# relatively (the upstream path was `from ReConV2.utils import config`,
# and `misc` was used below without being imported at all)
from . import config
from . import misc


class Registry:
    """A registry to map strings to classes.

    Registered objects can be built from the registry.
    Example:
        >>> MODELS = Registry('models')
        >>> @MODELS.register_module()
        >>> class ResNet:
        >>>     pass
        >>> resnet = MODELS.build(dict(NAME='ResNet'))
    Please refer to https://mmcv.readthedocs.io/en/latest/registry.html for
    advanced usage.

    Args:
        name (str): Registry name.
        build_func (func, optional): Build function to construct an instance
            from the registry; :func:`build_from_cfg` is used if neither
            ``parent`` nor ``build_func`` is specified. If ``parent`` is
            specified and ``build_func`` is not given, ``build_func`` will be
            inherited from ``parent``. Default: None.
        parent (Registry, optional): Parent registry. A class registered in
            a children registry can be built from the parent. Default: None.
        scope (str, optional): The scope of the registry. It is the key to
            search for children registries. If not specified, the scope will
            be the name of the package where the class is defined, e.g.
            mmdet, mmcls, mmseg. Default: None.
    """

    def __init__(self, name, build_func=None, parent=None, scope=None):
        self._name = name
        self._module_dict = dict()
        self._children = dict()
        self._scope = self.infer_scope() if scope is None else scope

        # self.build_func will be set with the following priority:
        # 1. build_func
        # 2. parent.build_func
        # 3. build_from_cfg
        if build_func is None:
            if parent is not None:
                self.build_func = parent.build_func
            else:
                self.build_func = build_from_cfg
        else:
            self.build_func = build_func
        if parent is not None:
            assert isinstance(parent, Registry)
            parent._add_children(self)
            self.parent = parent
        else:
            self.parent = None

    def __len__(self):
        return len(self._module_dict)

    def __contains__(self, key):
        return self.get(key) is not None

    def __repr__(self):
        format_str = self.__class__.__name__ + \
                     f'(name={self._name}, ' \
                     f'items={self._module_dict})'
        return format_str

    @staticmethod
    def infer_scope():
        """Infer the scope of the registry.

        The name of the package where the registry is defined will be
        returned.
        Example:
            # in mmdet/models/backbone/resnet.py
            >>> MODELS = Registry('models')
            >>> @MODELS.register_module()
            >>> class ResNet:
            >>>     pass
            The scope of ``ResNet`` will be ``mmdet``.
        Returns:
            scope (str): The inferred scope name.
        """
        # inspect.stack() traces where this function is called; index 2
        # is the frame where `infer_scope()` is called
        filename = inspect.getmodule(inspect.stack()[2][0]).__name__
        split_filename = filename.split('.')
        return split_filename[0]

    @staticmethod
    def split_scope_key(key):
        """Split scope and key.

        The first scope will be split from the key.
        Examples:
            >>> Registry.split_scope_key('mmdet.ResNet')
            'mmdet', 'ResNet'
            >>> Registry.split_scope_key('ResNet')
            None, 'ResNet'
        Return:
            scope (str, None): The first scope.
            key (str): The remaining key.
        """
        split_index = key.find('.')
        if split_index != -1:
            return key[:split_index], key[split_index + 1:]
        else:
            return None, key

    @property
    def name(self):
        return self._name

    @property
    def scope(self):
        return self._scope

    @property
    def module_dict(self):
        return self._module_dict

    @property
    def children(self):
        return self._children

    def get(self, key):
        """Get the registry record.

        Args:
            key (str): The class name in string format.
        Returns:
            class: The corresponding class.
        """
        scope, real_key = self.split_scope_key(key)
        if scope is None or scope == self._scope:
            # get from self
            if real_key in self._module_dict:
                return self._module_dict[real_key]
        else:
            # get from self._children
            if scope in self._children:
                return self._children[scope].get(real_key)
            else:
                # go to the root registry
                parent = self.parent
                while parent.parent is not None:
                    parent = parent.parent
                return parent.get(key)

    def build(self, *args, **kwargs):
        return self.build_func(*args, **kwargs, registry=self)

    def _add_children(self, registry):
        """Add children for a registry.

        The ``registry`` will be added as a child based on its scope.
        The parent registry can build objects from a children registry.
        Example:
            >>> models = Registry('models')
            >>> mmdet_models = Registry('models', parent=models)
            >>> @mmdet_models.register_module()
            >>> class ResNet:
            >>>     pass
            >>> resnet = models.build(dict(NAME='mmdet.ResNet'))
        """

        assert isinstance(registry, Registry)
        assert registry.scope is not None
        assert registry.scope not in self.children, \
            f'scope {registry.scope} exists in {self.name} registry'
        self.children[registry.scope] = registry

    def _register_module(self, module_class, module_name=None, force=False):
        if not inspect.isclass(module_class):
            raise TypeError('module must be a class, '
                            f'but got {type(module_class)}')

        if module_name is None:
            module_name = module_class.__name__
        if isinstance(module_name, str):
            module_name = [module_name]
        for name in module_name:
            if not force and name in self._module_dict:
                raise KeyError(f'{name} is already registered '
                               f'in {self.name}')
            self._module_dict[name] = module_class

    def deprecated_register_module(self, cls=None, force=False):
        warnings.warn(
            'The old API of register_module(module, force=False) '
            'is deprecated and will be removed, please use the new API '
            'register_module(name=None, force=False, module=None) instead.')
        if cls is None:
            return partial(self.deprecated_register_module, force=force)
        self._register_module(cls, force=force)
        return cls

    def register_module(self, name=None, force=False, module=None):
        """Register a module.

        A record will be added to `self._module_dict`, whose key is the class
        name or the specified name, and whose value is the class itself.
        It can be used as a decorator or a normal function.
        Example:
            >>> backbones = Registry('backbone')
            >>> @backbones.register_module()
            >>> class ResNet:
            >>>     pass
            >>> backbones = Registry('backbone')
            >>> @backbones.register_module(name='mnet')
            >>> class MobileNet:
            >>>     pass
            >>> backbones = Registry('backbone')
            >>> class ResNet:
            >>>     pass
            >>> backbones.register_module(ResNet)
        Args:
            name (str | None): The module name to be registered. If not
                specified, the class name will be used.
            force (bool, optional): Whether to override an existing class
                with the same name. Default: False.
            module (type): Module class to be registered.
        """
        if not isinstance(force, bool):
            raise TypeError(f'force must be a boolean, but got {type(force)}')
        # NOTE: This is a workaround to stay compatible with the old API,
        # though it may introduce unexpected bugs.
        if isinstance(name, type):
            return self.deprecated_register_module(name, force=force)

        # raise the error ahead of time
        if not (name is None or isinstance(name, str) or misc.is_seq_of(name, str)):
            raise TypeError(
                'name must be either of None, an instance of str or a sequence'
                f' of str, but got {type(name)}')

        # use it as a normal method: x.register_module(module=SomeClass)
        if module is not None:
            self._register_module(
                module_class=module, module_name=name, force=force)
            return module

        # use it as a decorator: @x.register_module()
        def _register(cls):
            self._register_module(
                module_class=cls, module_name=name, force=force)
            return cls

        return _register


def build_from_cfg(cfg, registry, default_args=None):
    """Build a module from a config dict.

    Args:
        cfg (edict): Config dict. It should at least contain the key "NAME".
        registry (:obj:`Registry`): The registry to search the type from.
    Returns:
        object: The constructed object.
    """
    if not isinstance(cfg, dict):
        raise TypeError(f'cfg must be a dict, but got {type(cfg)}')
    if 'NAME' not in cfg:
        if default_args is None or 'NAME' not in default_args:
            raise KeyError(
                '`cfg` or `default_args` must contain the key "NAME", '
                f'but got {cfg}\n{default_args}')
    if not isinstance(registry, Registry):
        raise TypeError('registry must be an mmcv.Registry object, '
                        f'but got {type(registry)}')

    if not (isinstance(default_args, dict) or default_args is None):
        raise TypeError('default_args must be a dict or None, '
                        f'but got {type(default_args)}')

    if default_args is not None:
        cfg = config.merge_new_config(cfg, default_args)

    obj_type = cfg.get('NAME')

    if isinstance(obj_type, str):
        obj_cls = registry.get(obj_type)
        if obj_cls is None:
            raise KeyError(
                f'{obj_type} is not in the {registry.name} registry')
    elif inspect.isclass(obj_type):
        obj_cls = obj_type
    else:
        raise TypeError(
            f'type must be a str or valid type, but got {type(obj_type)}')
    try:
        return obj_cls(cfg)
    except Exception as e:
        # a plain TypeError does not include the class name
        raise type(e)(f'{obj_cls.__name__}: {e}')
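The Registry docstring above shows the decorator form; here is a short end-to-end sketch with build_from_cfg (not part of the commit; the class name and config keys are made up, and it assumes the module is run as a script so scope inference works):

from mm_models.modal_module.point.recon.reconv2_utils.registry import Registry

MODELS = Registry('models')

@MODELS.register_module()
class PointBackbone:
    def __init__(self, cfg):
        self.depth = cfg.get('depth', 12)

model = MODELS.build(dict(NAME='PointBackbone', depth=24))  # dispatches via build_from_cfg
print(model.depth)  # 24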
mm_models/modal_module/point/recon/reconv2_utils/transforms.py
ADDED
@@ -0,0 +1,78 @@
from PIL import Image
from torchvision import transforms

# vendored sibling module (the upstream import path was
# `ReConV2.utils.randaugment`, which is not shipped in this repo)
from .randaugment import RandAugmentMC

__all__ = ['get_transforms']


class ResizeImage():
    def __init__(self, size):
        if isinstance(size, int):
            self.size = (int(size), int(size))
        else:
            self.size = size

    def __call__(self, img):
        th, tw = self.size
        return img.resize((th, tw))


class PlaceCrop(object):
    """Crops the given PIL.Image at the particular index.

    Args:
        size (sequence or int): Desired output size of the crop. If size is
            an int instead of a sequence like (w, h), a square crop
            (size, size) is made.
    """

    def __init__(self, size, start_x, start_y):
        if isinstance(size, int):
            self.size = (int(size), int(size))
        else:
            self.size = size
        self.start_x = start_x
        self.start_y = start_y

    def __call__(self, img):
        """
        Args:
            img (PIL.Image): Image to be cropped.
        Returns:
            PIL.Image: Cropped image.
        """
        th, tw = self.size
        return img.crop((self.start_x, self.start_y, self.start_x + tw, self.start_y + th))


class ForceFlip(object):
    """Horizontally flip the given PIL.Image unconditionally (the original
    docstring claimed a probability of 0.5, but no randomness is involved)."""

    def __call__(self, img):
        """
        Args:
            img (PIL.Image): Image to be flipped.
        Returns:
            PIL.Image: Flipped image.
        """
        return img.transpose(Image.FLIP_LEFT_RIGHT)


def transform_train(resize_size=256, crop_size=224):
    normalize = transforms.Normalize(mean=[0.485, 0.456, 0.406],
                                     std=[0.229, 0.224, 0.225])
    return transforms.Compose([
        # ResizeImage(resize_size),
        # transforms.RandomHorizontalFlip(),
        # transforms.RandomResizedCrop(crop_size, scale=(0.64, 1.0), interpolation=transforms.InterpolationMode.BICUBIC),
        # RandAugmentMC(n=2, m=10),
        ResizeImage(crop_size),
        transforms.ToTensor(),
        normalize
    ])


def get_transforms(resize_size=256, crop_size=224):
    transforms = {
        'train': transform_train(resize_size, crop_size)
    }

    return transforms
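A sketch of the (train-only) pipeline returned by get_transforms above (not part of the commit; the synthetic input image is illustrative, and the import assumes the package resolves with the relative randaugment import noted above):

import numpy as np
from PIL import Image
from mm_models.modal_module.point.recon.reconv2_utils.transforms import get_transforms

t = get_transforms(crop_size=224)['train']   # resize -> ToTensor -> ImageNet normalize
img = Image.fromarray(np.zeros((256, 256, 3), dtype=np.uint8))
x = t(img)                                   # float tensor of shape (3, 224, 224)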
mm_models/modal_module/point/recon/transformer.py
ADDED
@@ -0,0 +1,647 @@
1 |
+
import math
|
2 |
+
import timm
|
3 |
+
import torch
|
4 |
+
import numpy as np
|
5 |
+
import torch.nn as nn
|
6 |
+
import torch.nn.functional as F
|
7 |
+
|
8 |
+
from .reconv2_utils import misc
|
9 |
+
from .reconv2_utils.logger import *
|
10 |
+
from .reconv2_utils.knn import knn_point
|
11 |
+
from timm.layers import Mlp, DropPath
|
12 |
+
from typing import Optional, List
|
13 |
+
|
14 |
+
|
15 |
+
class PatchEmbedding(nn.Module): # Embedding module
|
16 |
+
def __init__(self, embed_dim, input_channel=3, large=False):
|
17 |
+
super().__init__()
|
18 |
+
self.embed_dim = embed_dim
|
19 |
+
self.input_channel = input_channel
|
20 |
+
|
21 |
+
# embed_dim_list = [c * (embed_dim // 512 + 1) for c in [128, 256, 512]]
|
22 |
+
#
|
23 |
+
# self.first_conv = nn.Sequential(
|
24 |
+
# nn.Conv1d(self.input_channel, embed_dim_list[0], 1),
|
25 |
+
# nn.BatchNorm1d(embed_dim_list[0]),
|
26 |
+
# nn.ReLU(inplace=True),
|
27 |
+
# nn.Conv1d(embed_dim_list[0], embed_dim_list[1], 1)
|
28 |
+
# )
|
29 |
+
# self.second_conv = nn.Sequential(
|
30 |
+
# nn.Conv1d(embed_dim_list[2], embed_dim_list[2], 1),
|
31 |
+
# nn.BatchNorm1d(embed_dim_list[2]),
|
32 |
+
# nn.ReLU(inplace=True),
|
33 |
+
# nn.Conv1d(embed_dim_list[2], self.embed_dim, 1)
|
34 |
+
# )
|
35 |
+
|
36 |
+
if large:
|
37 |
+
self.first_conv = nn.Sequential(
|
38 |
+
nn.Conv1d(self.input_channel, 256, 1),
|
39 |
+
nn.BatchNorm1d(256),
|
40 |
+
nn.ReLU(inplace=True),
|
41 |
+
nn.Conv1d(256, 512, 1),
|
42 |
+
nn.BatchNorm1d(512),
|
43 |
+
nn.ReLU(inplace=True),
|
44 |
+
nn.Conv1d(512, 1024, 1)
|
45 |
+
)
|
46 |
+
self.second_conv = nn.Sequential(
|
47 |
+
nn.Conv1d(2048, 2048, 1),
|
48 |
+
nn.BatchNorm1d(2048),
|
49 |
+
nn.ReLU(inplace=True),
|
50 |
+
nn.Conv1d(2048, embed_dim, 1)
|
51 |
+
)
|
52 |
+
else:
|
53 |
+
self.first_conv = nn.Sequential(
|
54 |
+
nn.Conv1d(self.input_channel, 128, 1),
|
55 |
+
nn.BatchNorm1d(128),
|
56 |
+
nn.ReLU(inplace=True),
|
57 |
+
nn.Conv1d(128, 256, 1)
|
58 |
+
)
|
59 |
+
self.second_conv = nn.Sequential(
|
60 |
+
nn.Conv1d(512, 512, 1),
|
61 |
+
nn.BatchNorm1d(512),
|
62 |
+
nn.ReLU(inplace=True),
|
63 |
+
nn.Conv1d(512, embed_dim, 1)
|
64 |
+
)
|
65 |
+
|
66 |
+
def forward(self, point_groups):
|
67 |
+
'''
|
68 |
+
point_groups : B G N 3/6
|
69 |
+
-----------------
|
70 |
+
feature_global : B G C
|
71 |
+
'''
|
72 |
+
bs, g, n, _ = point_groups.shape
|
73 |
+
point_groups = point_groups.reshape(bs * g, n, self.input_channel)
|
74 |
+
# encoder
|
75 |
+
feature = self.first_conv(point_groups.transpose(2, 1))
|
76 |
+
feature_global = torch.max(feature, dim=2, keepdim=True)[0]
|
77 |
+
feature = torch.cat([feature_global.expand(-1, -1, n), feature], dim=1)
|
78 |
+
feature = self.second_conv(feature)
|
79 |
+
feature_global = torch.max(feature, dim=2, keepdim=False)[0]
|
80 |
+
return feature_global.reshape(bs, g, self.embed_dim)
|
81 |
+
|
82 |
+
|
83 |
+
class PositionEmbeddingCoordsSine(nn.Module):
|
84 |
+
"""Similar to transformer's position encoding, but generalizes it to
|
85 |
+
arbitrary dimensions and continuous coordinates.
|
86 |
+
|
87 |
+
Args:
|
88 |
+
n_dim: Number of input dimensions, e.g. 2 for image coordinates.
|
89 |
+
d_model: Number of dimensions to encode into
|
90 |
+
temperature:
|
91 |
+
scale:
|
92 |
+
"""
|
93 |
+
|
94 |
+
def __init__(self, n_dim: int = 1, d_model: int = 256, temperature=1.0, scale=None):
|
95 |
+
super().__init__()
|
96 |
+
|
97 |
+
self.n_dim = n_dim
|
98 |
+
self.num_pos_feats = d_model // n_dim // 2 * 2
|
99 |
+
self.temperature = temperature
|
100 |
+
self.padding = d_model - self.num_pos_feats * self.n_dim
|
101 |
+
|
102 |
+
if scale is None:
|
103 |
+
scale = 1.0
|
104 |
+
self.scale = scale * 2 * math.pi
|
105 |
+
|
106 |
+
def forward(self, xyz: torch.Tensor) -> torch.Tensor:
|
107 |
+
"""
|
108 |
+
Args:
|
109 |
+
xyz: Point positions (*, d_in)
|
110 |
+
|
111 |
+
Returns:
|
112 |
+
pos_emb (*, d_out)
|
113 |
+
"""
|
114 |
+
assert xyz.shape[-1] == self.n_dim
|
115 |
+
|
116 |
+
dim_t = torch.arange(self.num_pos_feats,
|
117 |
+
dtype=torch.float32, device=xyz.device)
|
118 |
+
dim_t = self.temperature ** (2 * torch.div(dim_t,
|
119 |
+
2, rounding_mode='trunc') / self.num_pos_feats)
|
120 |
+
|
121 |
+
xyz = xyz * self.scale
|
122 |
+
pos_divided = xyz.unsqueeze(-1) / dim_t
|
123 |
+
pos_sin = pos_divided[..., 0::2].sin()
|
124 |
+
pos_cos = pos_divided[..., 1::2].cos()
|
125 |
+
pos_emb = torch.stack([pos_sin, pos_cos], dim=-1).reshape(*xyz.shape[:-1], -1)
|
126 |
+
|
127 |
+
# Pad unused dimensions with zeros
|
128 |
+
pos_emb = F.pad(pos_emb, (0, self.padding))
|
129 |
+
return pos_emb
|
130 |
+
|
131 |
+
|
132 |
+
class Group(nn.Module): # FPS + KNN
|
133 |
+
def __init__(self, num_group, group_size):
|
134 |
+
super().__init__()
|
135 |
+
self.num_group = num_group
|
136 |
+
self.group_size = group_size
|
137 |
+
|
138 |
+
def forward(self, pts):
|
139 |
+
'''
|
140 |
+
input: B N 3/6
|
141 |
+
---------------------------
|
142 |
+
output: B G M 3/6
|
143 |
+
center : B G 3
|
144 |
+
'''
|
145 |
+
xyz = pts[:, :, :3]
|
146 |
+
c = pts.shape[2]
|
147 |
+
batch_size, num_points, _ = xyz.shape
|
148 |
+
# fps the centers out
|
149 |
+
xyz = xyz.float()
|
150 |
+
center = misc.fps(xyz.contiguous(), self.num_group) # B G 3
|
151 |
+
# knn to get the neighborhood
|
152 |
+
idx = knn_point(self.group_size, xyz, center)
|
153 |
+
assert idx.size(1) == self.num_group
|
154 |
+
assert idx.size(2) == self.group_size
|
155 |
+
idx_base = torch.arange(0, batch_size, device=xyz.device).view(-1, 1, 1) * num_points
|
156 |
+
idx = idx + idx_base
|
157 |
+
idx = idx.view(-1)
|
158 |
+
neighborhood = pts.view(batch_size * num_points, -1)[idx, :]
|
159 |
+
neighborhood = neighborhood.view(batch_size, self.num_group, self.group_size, c).contiguous()
|
160 |
+
# normalize
|
161 |
+
neighborhood[:, :, :, :3] = neighborhood[:, :, :, :3] - center.unsqueeze(2)
|
162 |
+
return neighborhood, center
|
163 |
+
|
164 |
+
|
165 |
+
class ZGroup(nn.Module):
|
166 |
+
def __init__(self, num_group, group_size):
|
167 |
+
super().__init__()
|
168 |
+
self.num_group = num_group
|
169 |
+
self.group_size = group_size
|
170 |
+
|
171 |
+
def simplied_morton_sorting(self, xyz, center):
|
172 |
+
"""
|
173 |
+
Simplifying the Morton code sorting to iterate and set the nearest patch to the last patch as the next patch, we found this to be more efficient.
|
174 |
+
"""
|
175 |
+
batch_size, num_points, _ = xyz.shape
|
176 |
+
distances_batch = torch.cdist(center, center)
|
177 |
+
distances_batch[:, torch.eye(self.num_group).bool()] = float("inf")
|
178 |
+
idx_base = torch.arange(
|
179 |
+
0, batch_size, device=xyz.device) * self.num_group
|
180 |
+
sorted_indices_list = [idx_base]
|
181 |
+
distances_batch = distances_batch.view(batch_size, self.num_group, self.num_group).transpose(
|
182 |
+
1, 2).contiguous().view(batch_size * self.num_group, self.num_group)
|
183 |
+
distances_batch[idx_base] = float("inf")
|
184 |
+
distances_batch = distances_batch.view(
|
185 |
+
batch_size, self.num_group, self.num_group).transpose(1, 2).contiguous()
|
186 |
+
for i in range(self.num_group - 1):
|
187 |
+
distances_batch = distances_batch.view(
|
188 |
+
batch_size * self.num_group, self.num_group)
|
189 |
+
distances_to_last_batch = distances_batch[sorted_indices_list[-1]]
|
190 |
+
closest_point_idx = torch.argmin(distances_to_last_batch, dim=-1)
|
191 |
+
closest_point_idx = closest_point_idx + idx_base
|
192 |
+
sorted_indices_list.append(closest_point_idx)
|
193 |
+
distances_batch = distances_batch.view(batch_size, self.num_group, self.num_group).transpose(
|
194 |
+
1, 2).contiguous().view(batch_size * self.num_group, self.num_group)
|
195 |
+
distances_batch[closest_point_idx] = float("inf")
|
196 |
+
distances_batch = distances_batch.view(
|
197 |
+
batch_size, self.num_group, self.num_group).transpose(1, 2).contiguous()
|
198 |
+
sorted_indices = torch.stack(sorted_indices_list, dim=-1)
|
199 |
+
sorted_indices = sorted_indices.view(-1)
|
200 |
+
return sorted_indices
|
201 |
+
|
202 |
+
def forward(self, pts):
|
203 |
+
"""
|
204 |
+
input: B N 3/6
|
205 |
+
---------------------------
|
206 |
+
output: B G M 3/6
|
207 |
+
center : B G 3
|
208 |
+
"""
|
209 |
+
xyz = pts[:, :, :3]
|
210 |
+
c = pts.shape[2]
|
211 |
+
batch_size, num_points, _ = xyz.shape
|
212 |
+
# fps the centers out
|
213 |
+
xyz = xyz.float()
|
214 |
+
center = misc.fps(xyz.contiguous(), self.num_group) # B G 3
|
215 |
+
# knn to get the neighborhood
|
216 |
+
idx = knn_point(self.group_size, xyz, center)
|
217 |
+
assert idx.size(1) == self.num_group
|
218 |
+
assert idx.size(2) == self.group_size
|
219 |
+
idx_base = torch.arange(0, batch_size, device=xyz.device).view(-1, 1, 1) * num_points
|
220 |
+
idx = idx + idx_base
|
221 |
+
idx = idx.view(-1)
|
222 |
+
neighborhood = pts.view(batch_size * num_points, -1)[idx, :]
|
223 |
+
neighborhood = neighborhood.view(batch_size, self.num_group, self.group_size, c).contiguous()
|
224 |
+
# normalize
|
225 |
+
neighborhood[:, :, :, :3] = neighborhood[:, :, :, :3] - center.unsqueeze(2)
|
226 |
+
|
227 |
+
# can utilize morton_sorting by choosing morton_sorting function
|
228 |
+
sorted_indices = self.simplied_morton_sorting(xyz, center)
|
229 |
+
|
230 |
+
neighborhood = neighborhood.view(
|
231 |
+
batch_size * self.num_group, self.group_size, c)[sorted_indices, :, :]
|
232 |
+
neighborhood = neighborhood.view(
|
233 |
+
batch_size, self.num_group, self.group_size, c).contiguous()
|
234 |
+
center = center.view(
|
235 |
+
batch_size * self.num_group, 3)[sorted_indices, :]
|
236 |
+
center = center.view(
|
237 |
+
batch_size, self.num_group, 3).contiguous()
|
238 |
+
|
239 |
+
return neighborhood, center
|
240 |
+
|
241 |
+
|
242 |
+
# Transformers
|
243 |
+
class Attention(nn.Module):
|
244 |
+
def __init__(
|
245 |
+
self,
|
246 |
+
dim: int,
|
247 |
+
num_heads: int = 8,
|
248 |
+
qkv_bias: bool = True,
|
249 |
+
qk_norm: bool = False,
|
250 |
+
attn_drop: float = 0.,
|
251 |
+
proj_drop: float = 0.,
|
252 |
+
norm_layer: nn.Module = nn.LayerNorm,
|
253 |
+
) -> None:
|
254 |
+
super().__init__()
|
255 |
+
assert dim % num_heads == 0, 'dim should be divisible by num_heads'
|
256 |
+
self.num_heads = num_heads
|
257 |
+
self.head_dim = dim // num_heads
|
258 |
+
self.scale = self.head_dim ** -0.5
|
259 |
+
|
260 |
+
self.qkv = nn.Linear(dim, dim * 3, bias=qkv_bias)
|
261 |
+
self.q_norm = norm_layer(self.head_dim) if qk_norm else nn.Identity()
|
262 |
+
self.k_norm = norm_layer(self.head_dim) if qk_norm else nn.Identity()
|
263 |
+
self.attn_drop = nn.Dropout(attn_drop)
|
264 |
+
self.proj = nn.Linear(dim, dim)
|
265 |
+
self.proj_drop = nn.Dropout(proj_drop)
|
266 |
+
|
267 |
+
def forward(self, x: torch.Tensor, mask: Optional[torch.Tensor] = None) -> torch.Tensor:
|
268 |
+
B, N, C = x.shape
|
269 |
+
qkv = self.qkv(x).reshape(B, N, 3, self.num_heads, self.head_dim).permute(2, 0, 3, 1, 4)
|
270 |
+
q, k, v = qkv.unbind(0)
|
271 |
+
q, k = self.q_norm(q), self.k_norm(k)
|
272 |
+
|
273 |
+
q = q * self.scale
|
274 |
+
attn = q @ k.transpose(-2, -1)
|
275 |
+
if mask is not None:
|
276 |
+
attn = attn.masked_fill(mask, float('-inf'))
|
277 |
+
attn = attn.softmax(dim=-1)
|
278 |
+
attn = self.attn_drop(attn)
|
279 |
+
x = attn @ v
|
280 |
+
|
281 |
+
x = x.transpose(1, 2).reshape(B, N, C)
|
282 |
+
x = self.proj(x)
|
283 |
+
x = self.proj_drop(x)
|
284 |
+
return x
|
285 |
+
|
286 |
+
|
287 |
+
class CrossAttention(nn.Module):
|
288 |
+
def __init__(
|
289 |
+
self,
|
290 |
+
dim: int,
|
291 |
+
num_heads: int = 8,
|
292 |
+
qkv_bias: bool = True,
|
293 |
+
qk_norm: bool = False,
|
294 |
+
attn_drop: float = 0.,
|
295 |
+
proj_drop: float = 0.,
|
296 |
+
norm_layer: nn.Module = nn.LayerNorm,
|
297 |
+
) -> None:
|
298 |
+
super().__init__()
|
299 |
+
assert dim % num_heads == 0, 'dim should be divisible by num_heads'
|
300 |
+
self.num_heads = num_heads
|
301 |
+
self.head_dim = dim // num_heads
|
302 |
+
self.scale = self.head_dim ** -0.5
|
303 |
+
|
304 |
+
self.q = nn.Linear(dim, dim, bias=qkv_bias)
|
305 |
+
self.kv = nn.Linear(dim, dim * 2, bias=qkv_bias)
|
306 |
+
self.q_norm = norm_layer(self.head_dim) if qk_norm else nn.Identity()
|
307 |
+
self.k_norm = norm_layer(self.head_dim) if qk_norm else nn.Identity()
|
308 |
+
self.attn_drop = nn.Dropout(attn_drop)
|
309 |
+
self.proj = nn.Linear(dim, dim)
|
310 |
+
self.proj_drop = nn.Dropout(proj_drop)
|
311 |
+
|
312 |
+
def forward(self, x: torch.Tensor, y: torch.Tensor, mask: Optional[torch.Tensor] = None) -> torch.Tensor:
|
313 |
+
B, N, C = y.shape
|
314 |
+
kv = self.kv(y).reshape(B, N, 2, self.num_heads, C // self.num_heads).permute(2, 0, 3, 1, 4)
|
315 |
+
k, v = kv.unbind(0)
|
316 |
+
|
317 |
+
B, N, C = x.shape
|
318 |
+
q = self.q(x).reshape(B, N, 1, self.num_heads, C // self.num_heads).permute(2, 0, 3, 1, 4)[0]
|
319 |
+
|
320 |
+
q, k = self.q_norm(q), self.k_norm(k)
|
321 |
+
q = q * self.scale
|
322 |
+
attn = q @ k.transpose(-2, -1)
|
323 |
+
if mask is not None:
|
324 |
+
attn = attn.masked_fill(mask, float('-inf'))
|
325 |
+
attn = attn.softmax(dim=-1)
|
326 |
+
attn = self.attn_drop(attn)
|
327 |
+
x = attn @ v
|
328 |
+
|
329 |
+
x = x.transpose(1, 2).reshape(B, N, C)
|
330 |
+
x = self.proj(x)
|
331 |
+
x = self.proj_drop(x)
|
332 |
+
return x
|
333 |
+
|
334 |
+
|
335 |
+
class LayerScale(nn.Module):
    def __init__(
            self,
            dim: int,
            init_values: float = 1e-5,
            inplace: bool = False,
    ) -> None:
        super().__init__()
        self.inplace = inplace
        self.gamma = nn.Parameter(init_values * torch.ones(dim))

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        return x.mul_(self.gamma) if self.inplace else x * self.gamma

class Block(nn.Module):
    def __init__(
            self,
            dim: int,
            num_heads: int,
            mlp_ratio: float = 4.,
            qkv_bias: bool = True,
            qk_norm: bool = False,
            proj_drop: float = 0.,
            attn_drop: float = 0.,
            init_values: Optional[float] = None,
            drop_path: float = 0.,
            act_layer: nn.Module = nn.GELU,
            norm_layer: nn.Module = nn.LayerNorm,
    ) -> None:
        super().__init__()
        self.norm1 = norm_layer(dim)
        self.attn = Attention(
            dim,
            num_heads=num_heads,
            qkv_bias=qkv_bias,
            qk_norm=qk_norm,
            attn_drop=attn_drop,
            proj_drop=proj_drop,
            norm_layer=norm_layer,
        )
        self.ls1 = LayerScale(dim, init_values=init_values) if init_values else nn.Identity()
        self.drop_path1 = DropPath(drop_path) if drop_path > 0. else nn.Identity()

        self.norm2 = norm_layer(dim)
        self.mlp = Mlp(
            in_features=dim,
            hidden_features=int(dim * mlp_ratio),
            act_layer=act_layer,
            drop=proj_drop,
        )
        self.ls2 = LayerScale(dim, init_values=init_values) if init_values else nn.Identity()
        self.drop_path2 = DropPath(drop_path) if drop_path > 0. else nn.Identity()

    def forward(self, x, attn_mask=None):
        x = x + self.drop_path1(self.ls1(self.attn(self.norm1(x), attn_mask)))
        x = x + self.drop_path2(self.ls2(self.mlp(self.norm2(x))))
        return x

class CrossBlock(nn.Module):
    def __init__(
            self,
            dim: int,
            num_heads: int,
            mlp_ratio: float = 4.,
            qkv_bias: bool = True,
            qk_norm: bool = False,
            proj_drop: float = 0.,
            attn_drop: float = 0.,
            init_values: Optional[float] = None,
            drop_path: float = 0.,
            act_layer: nn.Module = nn.GELU,
            norm_layer: nn.Module = nn.LayerNorm,
            stop_grad: bool = False
    ) -> None:
        super().__init__()
        self.norm1 = norm_layer(dim)
        self.attn = CrossAttention(
            dim,
            num_heads=num_heads,
            qkv_bias=qkv_bias,
            qk_norm=qk_norm,
            attn_drop=attn_drop,
            proj_drop=proj_drop,
            norm_layer=norm_layer,
        )
        self.ls1 = LayerScale(dim, init_values=init_values) if init_values else nn.Identity()
        self.drop_path1 = DropPath(drop_path) if drop_path > 0. else nn.Identity()

        self.norm2 = norm_layer(dim)
        self.mlp = Mlp(
            in_features=dim,
            hidden_features=int(dim * mlp_ratio),
            act_layer=act_layer,
            drop=proj_drop,
        )
        self.ls2 = LayerScale(dim, init_values=init_values) if init_values else nn.Identity()
        self.drop_path2 = DropPath(drop_path) if drop_path > 0. else nn.Identity()

        self.stop_grad = stop_grad

    def forward(self, x: torch.Tensor, y: torch.Tensor) -> torch.Tensor:
        if self.stop_grad:
            x = x + self.drop_path1(self.ls1(self.attn(self.norm1(x), self.norm1(y.detach()))))
        else:
            x = x + self.drop_path1(self.ls1(self.attn(self.norm1(x), self.norm1(y))))

        x = x + self.drop_path2(self.ls2(self.mlp(self.norm2(x))))
        return x

class ReConBlocks(nn.Module):
    def __init__(
            self,
            embed_dim: int = 768,
            depth: int = 12,
            num_heads: int = 12,
            mlp_ratio: float = 4.,
            qkv_bias: bool = True,
            qk_norm: bool = False,
            init_values: Optional[float] = None,
            proj_drop: float = 0.,
            attn_drop_rate: float = 0.,
            drop_path_rate: List = [],
            norm_layer: nn.Module = nn.LayerNorm,
            act_layer: nn.Module = nn.GELU,
            stop_grad: bool = False,
            pretrained_model_name: str = 'vit_base_patch32_clip_224.openai',
            every_layer_add_pos: bool = True,
    ):
        super().__init__()

        self.depth = depth
        self.stop_grad = stop_grad
        self.pretrained_model_name = pretrained_model_name
        self.every_layer_add_pos = every_layer_add_pos
        if 'dino' in self.pretrained_model_name:
            init_values = 1e-5
        if 'giant' in self.pretrained_model_name:
            mlp_ratio = 48 / 11
        self.local_blocks = nn.Sequential(*[
            Block(
                dim=embed_dim,
                num_heads=num_heads,
                mlp_ratio=mlp_ratio,
                qkv_bias=qkv_bias,
                qk_norm=qk_norm,
                init_values=init_values,
                proj_drop=proj_drop,
                attn_drop=attn_drop_rate,
                drop_path=drop_path_rate[i],
                norm_layer=norm_layer,
                act_layer=act_layer
            )
            for i in range(depth)])

        self.global_blocks = nn.Sequential(*[
            CrossBlock(
                dim=embed_dim,
                num_heads=num_heads,
                mlp_ratio=mlp_ratio,
                qkv_bias=qkv_bias,
                qk_norm=qk_norm,
                init_values=init_values,
                proj_drop=proj_drop,
                attn_drop=attn_drop_rate,
                drop_path=drop_path_rate[i],
                norm_layer=norm_layer,
                act_layer=act_layer,
                stop_grad=stop_grad
            )
            for i in range(depth)])

    def load_pretrained_timm_weights(self):
        model = timm.create_model(self.pretrained_model_name, pretrained=True)
        state_dict = model.blocks.state_dict()
        self.local_blocks.load_state_dict(state_dict, strict=True)

        cross_state_dict = {}
        for k, v in state_dict.items():
            if 'qkv' in k:
                cross_state_dict[k.replace('qkv', 'q')] = v[:int(v.shape[0] / 3)]
                cross_state_dict[k.replace('qkv', 'kv')] = v[int(v.shape[0] / 3):]
            else:
                cross_state_dict[k] = v
        self.global_blocks.load_state_dict(cross_state_dict, strict=True)

    def forward(self, x, pos, attn_mask=None, query=None):
        if self.every_layer_add_pos:
            for i in range(self.depth):
                x = self.local_blocks[i](x + pos, attn_mask)
                if query is not None:
                    query = self.global_blocks[i](query, x)
        else:
            x = x + pos
            for i in range(self.depth):
                x = self.local_blocks[i](x, attn_mask)
                if query is not None:
                    query = self.global_blocks[i](query, x)
        return x, query

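# Note on load_pretrained_timm_weights (illustrative): a fused qkv.weight of shape (3*dim, dim)
# from a pretrained timm ViT block is sliced row-wise to initialize CrossAttention:
#   q.weight  <- qkv.weight[:dim]     (first third)
#   kv.weight <- qkv.weight[dim:]     (remaining two thirds, shape (2*dim, dim))
# so the local (self-attention) and global (cross-attention) branches start from the same projections.
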
class GPTExtractor(nn.Module):
    def __init__(
            self,
            embed_dim: int = 768,
            num_heads: int = 12,
            depth: int = 12,
            group_size: int = 32,
            drop_path_rate: float = 0.0,
            stop_grad: bool = False,
            pretrained_model_name: str = 'vit_base_patch32_clip_224.openai',
    ):
        super(GPTExtractor, self).__init__()

        self.embed_dim = embed_dim
        self.group_size = group_size

        # start of sequence token
        self.sos = nn.Parameter(torch.zeros(1, 1, embed_dim))
        self.sos_pos = nn.Parameter(torch.zeros(1, 1, embed_dim))
        nn.init.normal_(self.sos)
        nn.init.normal_(self.sos_pos)

        drop_path_rate = [x.item() for x in torch.linspace(0, drop_path_rate, depth)]
        self.blocks = ReConBlocks(
            embed_dim=embed_dim,
            num_heads=num_heads,
            depth=depth,
            drop_path_rate=drop_path_rate,
            stop_grad=stop_grad,
            pretrained_model_name=pretrained_model_name,
        )

        self.ln_f1 = nn.LayerNorm(embed_dim)
        self.ln_f2 = nn.LayerNorm(embed_dim)

    def forward(self, x, pos, attn_mask, query):
        """
        Expects input of shape [batch, sequence len, embed dim].
        """
        batch, length, _ = x.shape

        # prepend sos token
        sos = self.sos.expand(batch, -1, -1)
        sos_pos = self.sos_pos.expand(batch, -1, -1)

        x = torch.cat([sos, x[:, :-1]], dim=1)
        pos = torch.cat([sos_pos, pos[:, :-1]], dim=1)

        # transformer
        x, query = self.blocks(x, pos, attn_mask, query)

        encoded_points = self.ln_f1(x)
        query = self.ln_f2(query)

        return encoded_points, query

class MAEExtractor(nn.Module):
    def __init__(
            self,
            embed_dim: int = 768,
            num_heads: int = 12,
            depth: int = 12,
            group_size: int = 32,
            drop_path_rate: float = 0.0,
            stop_grad: bool = False,
            pretrained_model_name: str = 'vit_base_patch32_clip_224.openai',
    ):
        super(MAEExtractor, self).__init__()

        self.embed_dim = embed_dim
        self.group_size = group_size

        drop_path_rate = [x.item() for x in torch.linspace(0, drop_path_rate, depth)]
        self.blocks = ReConBlocks(
            embed_dim=embed_dim,
            num_heads=num_heads,
            depth=depth,
            drop_path_rate=drop_path_rate,
            stop_grad=stop_grad,
            pretrained_model_name=pretrained_model_name,
        )

        self.ln_f1 = nn.LayerNorm(embed_dim)
        self.ln_f2 = nn.LayerNorm(embed_dim)

    def forward(self, x, pos, mask=None, query=None):
        """
        Expects input of shape [batch, sequence len, embed dim].
        """
        batch, length, C = x.shape
        if mask is not None:
            x_vis = x[~mask].reshape(batch, -1, C)
            pos_vis = pos[~mask].reshape(batch, -1, C)
        else:
            x_vis = x
            pos_vis = pos

        # transformer
        x_vis, query = self.blocks(x_vis, pos_vis, None, query)

        encoded_points = self.ln_f1(x_vis)
        query = self.ln_f2(query)

        return encoded_points, query
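A minimal smoke test for the blocks defined above (a sketch; it assumes the repo is importable and that Mlp/DropPath come from timm, as imported at the top of this file):

import torch
from mm_models.modal_module.point.recon.transformer import Block, CrossBlock

blk = Block(dim=768, num_heads=12)
tokens = torch.randn(2, 64, 768)
assert blk(tokens).shape == (2, 64, 768)           # self-attention keeps the token count

cross = CrossBlock(dim=768, num_heads=12)
query = torch.randn(2, 16, 768)                    # e.g. 13 image + 3 text global queries
assert cross(query, tokens).shape == (2, 16, 768)  # queries read from the local tokens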
mm_models/modal_module/point/reconv2.py
ADDED
@@ -0,0 +1,266 @@
import numpy as np
from easydict import EasyDict
import timm
import torch
import torch.nn as nn
import torch.nn.functional as F
from .recon.transformer import Group, ZGroup, PatchEmbedding, PositionEmbeddingCoordsSine, GPTExtractor, MAEExtractor


POINT_RECON2_MODAL_CFG = {
    'modal_tag': 'point',
    'modal_placeholder_token': "<|point_placeholder|>",
    'model_path': None,
    'group_size': 32,
    'num_group': 512,
    'mask_type': 'rand',
    'embed_dim': 1024,
    'depth': 24,
    'drop_path_rate': 0.1,
    'num_heads': 16,
    'with_color': True,
    'stop_grad': False,
    'large_embedding': False,
    'img_queries': 13,
    'text_queries': 3,
    'pretrained_model_name': 'eva_large_patch14_336.in22k_ft_in22k_in1k',
    'output_dim': 896
}

class MaskTransformer(nn.Module):
    def __init__(self, config):
        super(MaskTransformer, self).__init__()

        self.embed_dim = config.embed_dim
        self.num_group = config.num_group
        self.group_size = config.group_size
        self.with_color = config.with_color
        self.input_channel = 6 if self.with_color else 3
        self.img_queries = config.img_queries
        self.text_queries = config.text_queries
        self.global_query_num = self.img_queries + self.text_queries
        self.mask_type = config.mask_type
        self.stop_grad = config.stop_grad

        self.embed = PatchEmbedding(embed_dim=self.embed_dim, input_channel=self.input_channel,
                                    large=config.large_embedding)

        print(f'[ReCon] divide point cloud into G{config.num_group} x S{config.group_size} points ...')

        if self.mask_type == 'causal':
            self.group_divider = ZGroup(num_group=config.num_group, group_size=config.group_size)
            self.encoder = GPTExtractor(
                embed_dim=config.embed_dim,
                num_heads=config.num_heads,
                depth=config.depth,
                group_size=config.group_size,
                drop_path_rate=config.drop_path_rate,
                stop_grad=self.stop_grad,
                pretrained_model_name=config.pretrained_model_name,
            )
            self.pos_embed = PositionEmbeddingCoordsSine(3, self.embed_dim, 1.0)
        else:
            self.group_divider = Group(num_group=config.num_group, group_size=config.group_size)
            self.encoder = MAEExtractor(
                embed_dim=config.embed_dim,
                num_heads=config.num_heads,
                depth=config.depth,
                group_size=config.group_size,
                drop_path_rate=config.drop_path_rate,
                stop_grad=self.stop_grad,
                pretrained_model_name=config.pretrained_model_name,
            )
            self.pos_embed = nn.Sequential(
                nn.Linear(3, 128),
                nn.GELU(),
                nn.Linear(128, self.embed_dim)
            )

        self.norm = nn.LayerNorm(self.embed_dim)
        self.global_query = nn.Parameter(torch.zeros(1, self.global_query_num, self.embed_dim))
        self.apply(self._init_weights)

        self.num_group = config.num_group

    def _init_weights(self, m):
        if isinstance(m, nn.Linear):
            nn.init.normal_(m.weight, 0.02, 0.01)
            if isinstance(m, nn.Linear) and m.bias is not None:
                nn.init.constant_(m.bias, 0)
        elif isinstance(m, nn.BatchNorm1d):
            nn.init.constant_(m.bias, 0)
            nn.init.constant_(m.weight, 1.0)

    def inference(self, pts):
        with torch.no_grad():
            neighborhood, center = self.group_divider(pts)
            group_input_tokens = self.embed(neighborhood)  # B G C
            batch_size, seq_len, C = group_input_tokens.size()

            global_query = self.global_query.expand(batch_size, -1, -1)
            pos = self.pos_embed(center.to(group_input_tokens.dtype))

            mask = torch.full(
                (seq_len, seq_len), -float("Inf"), device=group_input_tokens.device, dtype=group_input_tokens.dtype
            ).to(torch.bool)
            if self.mask_type == 'causal':
                mask = torch.triu(mask, diagonal=1)
            else:
                mask = None

            local_features, global_features = self.encoder(
                group_input_tokens, pos, mask, global_query)

        return pos, local_features, global_features
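# Mask note (illustrative): for mask_type == 'causal' the boolean mask above is
# torch.triu(all_True, diagonal=1), e.g. for seq_len = 4:
#   [[F, T, T, T],
#    [F, F, T, T],
#    [F, F, F, T],
#    [F, F, F, F]]
# True entries are filled with -inf inside Attention, so each point group attends
# only to itself and earlier groups in the sequence produced by ZGroup.
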
class ReCon2(nn.Module):
    def __init__(self, config):
        super().__init__()
        self.config = config
        self.embed_dim = config.embed_dim
        self.with_color = config.with_color
        self.img_queries = config.img_queries
        self.text_queries = config.text_queries
        self.global_query_num = self.img_queries + self.text_queries
        self.input_channel = 6 if self.with_color else 3

        self.model = MaskTransformer(config)

        self.img_proj = nn.Linear(self.embed_dim, 1280)
        self.img_proj.apply(self._init_weights)
        self.text_proj = nn.Linear(self.embed_dim, 1280)
        self.text_proj.apply(self._init_weights)

    def _init_weights(self, m):
        if isinstance(m, nn.Linear):
            nn.init.normal_(m.weight, 0.02, 0.01)
            if isinstance(m, nn.Linear) and m.bias is not None:
                nn.init.constant_(m.bias, 0)
        elif isinstance(m, nn.BatchNorm1d):
            nn.init.constant_(m.bias, 0)
            nn.init.constant_(m.weight, 1.0)

    @property
    def device(self):
        return next(self.parameters()).device

    @property
    def dtype(self):
        return next(self.parameters()).dtype

class ReConv2PointEncoder(nn.Module):

    def __init__(self, config):
        super().__init__()
        self.config = config
        self.vision_tower = ReCon2(self.config)

    @torch.no_grad()
    def forward(self, pts):
        pts = torch.stack(pts, dim=0)
        pos_features, local_features, global_features = \
            self.vision_tower.model.inference(pts.to(device=self.device, dtype=self.dtype))
        local_features = local_features.to(pts.dtype)
        global_features = global_features.to(pts.dtype)

        return (pos_features, local_features, global_features)

    @property
    def dtype(self):
        return self.vision_tower.dtype

    @property
    def device(self):
        return self.vision_tower.device

class ReConProjector_MLP(nn.Module):
    def __init__(self, in_channels, out_channels, mlp_depth, prompt_token_num,
                 with_ape=True, with_local=True, with_global=True):
        super().__init__()

        self.in_channels = in_channels
        self.out_channels = out_channels
        self.mlp_depth = mlp_depth
        self.prompt_token_num = prompt_token_num
        self.with_ape = with_ape
        self.with_local = with_local
        self.with_global = with_global

        if prompt_token_num > 0:
            self.prompt1 = nn.Parameter(torch.zeros(1, prompt_token_num, out_channels))
            self.prompt2 = nn.Parameter(torch.zeros(1, prompt_token_num, out_channels))
            self.prompt3 = nn.Parameter(torch.zeros(1, prompt_token_num, out_channels))

        self.proj1 = self.set_proj()
        self.proj2 = self.set_proj()
        self.proj3 = self.set_proj()

    def set_proj(self):
        modules = [nn.Linear(self.in_channels, self.out_channels)]
        for i in range(1, self.mlp_depth):
            modules.append(nn.GELU())
            modules.append(nn.Linear(self.out_channels, self.out_channels))
            modules.append(nn.GELU())
            modules.append(nn.Linear(self.out_channels, self.out_channels))
        return nn.Sequential(*modules)

    def forward(self, proj_inps):
        pos_feat, local_feat, global_feat = proj_inps
        B = pos_feat.shape[0]
        pos_feat = self.proj1(pos_feat)
        local_feat = self.proj2(local_feat)
        global_feat = self.proj3(global_feat)

        if self.prompt_token_num > 0:
            pos_feat = torch.cat([self.prompt1.expand(B, -1, -1), pos_feat], dim=1)
            local_feat = torch.cat([self.prompt2.expand(B, -1, -1), local_feat], dim=1)
            global_feat = torch.cat([self.prompt3.expand(B, -1, -1), global_feat], dim=1)

        pts_feat = [feat for feat, flag in [(pos_feat, self.with_ape), (local_feat, self.with_local), (global_feat, self.with_global)] if flag]
        pts_feat = torch.cat(pts_feat, dim=1)

        pts_feat = torch.split(pts_feat, 1)
        pts_feat = [item.squeeze() for item in pts_feat]
        return pts_feat

def build_point_encoder(modal_cfg):
    assert modal_cfg['modal_tag'] == 'point', f"building point encoder with '{modal_cfg['modal_tag']}' tag is not supported"
    if "encoder_cfg" not in modal_cfg:  # init
        cfg = EasyDict(modal_cfg)
        model = ReConv2PointEncoder(cfg)
        print(f"loading point encoder from {modal_cfg['model_path']}")
        model_path = modal_cfg['model_path']
        model.vision_tower.load_state_dict(torch.load(model_path, map_location='cpu'), strict=True)
    else:
        cfg = EasyDict(modal_cfg["encoder_cfg"])
        model = ReConv2PointEncoder(cfg)
    return model


def build_point_projector(modal_cfg):
    assert modal_cfg['modal_tag'] == 'point', f"building point projector with '{modal_cfg['modal_tag']}' tag is not supported"
    if "encoder_cfg" in modal_cfg:
        proj_cfg = EasyDict(modal_cfg["encoder_cfg"])
    else:
        proj_cfg = EasyDict(modal_cfg)
    projector = ReConProjector_MLP(in_channels=proj_cfg.embed_dim,
                                   out_channels=proj_cfg.output_dim,
                                   mlp_depth=2, prompt_token_num=1)
    return projector


# if __name__ == '__main__':
#     encoder = build_point_encoder(POINT_RECON2_MODAL_CFG)
#     print(encoder)
#     data = torch.randn((1, 8192, 6))
#     pos_features, local_features, global_features = encoder(data)
#     print(pos_features.shape, local_features.shape, global_features.shape)
#     proj = build_point_projector(POINT_RECON2_MODAL_CFG)
#     point_token = proj((pos_features, local_features, global_features))
#     print(point_token.shape)
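For orientation, the shapes implied by POINT_RECON2_MODAL_CFG above (num_group=512, embed_dim=1024, 13 + 3 global queries) and the default projector (prompt_token_num=1, output_dim=896) work out as follows; this is a sketch of the commented test above, and running it needs a real checkpoint at model_path:

# encoder(data) with data of shape (1, 8192, 6):
#   pos_features:    (1, 512, 1024)  position embedding of the 512 group centers
#   local_features:  (1, 512, 1024)  per-group point tokens
#   global_features: (1, 16, 1024)   13 image queries + 3 text queries
# projector output per sample: (513 + 513 + 17, 896) = (1043, 896)
#   (each stream gains one learned prompt token before the three streams are concatenated)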
mm_models/modal_module/vision/__pycache__/siglip.cpython-310.pyc
ADDED
Binary file (3.92 kB).
mm_models/modal_module/vision/siglip.py
ADDED
@@ -0,0 +1,122 @@
from transformers import SiglipVisionModel, SiglipVisionConfig
import torch.nn as nn
import torch
import einops
import torch.nn.functional as F
import torch.nn.init as init


VISION_SIGLIP_MODAL_CFG = {
    'modal_tag': 'vision',
    'model_name_or_path': None,
    'modal_placeholder_token': "<|vision_placeholder|>",
    'proj_num_layers': 2,
    'proj_input_dim': 1152,
    'proj_output_dim': 896,  # for qwen2.5 0.5B
    'multi_grid': False
}

class SiglipVisionModelForMM(SiglipVisionModel):

    def forward(self, pixel_values, **kwargs):
        assert type(pixel_values) == list

        split_sizes = []
        temp = []
        for pixel_value in pixel_values:
            if pixel_value.dim() == 3:
                pixel_value = pixel_value.unsqueeze(0)
            temp.append(pixel_value)
            split_sizes.append(pixel_value.shape[0])
        pixel_values = torch.cat(temp, dim=0)  # (BG) 3 H W

        outputs = super().forward(pixel_values, output_hidden_states=True, **kwargs)
        hidden_states = outputs['hidden_states'][-2]

        return (hidden_states, split_sizes)

class VisionProjector_MLP(nn.Module):

    def __init__(self, proj_num_layers, proj_input_dim, proj_output_dim, multi_grid=False):
        super().__init__()
        _proj_input_dim = int(4*proj_input_dim)
        module = [nn.Linear(_proj_input_dim, proj_output_dim)]
        for _ in range(proj_num_layers - 1):
            module.append(nn.GELU())
            module.append(nn.Linear(proj_output_dim, proj_output_dim))
            module.append(nn.GELU())
            module.append(nn.Linear(proj_output_dim, proj_output_dim))
        self.module = nn.Sequential(*module)

        self.resample_pad_token = nn.Parameter(torch.randn((1, proj_input_dim)))
        init.kaiming_normal_(self.resample_pad_token)

        self.multi_grid = multi_grid
        if self.multi_grid:
            self.grid_sep = nn.Parameter(torch.randn((1, proj_output_dim)))
            init.kaiming_normal_(self.grid_sep)

    def forward(self, encoder_output):
        visual_tokens, split_sizes = encoder_output

        B, L, D = visual_tokens.shape

        # Pooling
        n_patch = int(L**0.5)

        visual_tokens = einops.rearrange(visual_tokens, "B (h w) D -> B D h w", h=n_patch, w=n_patch)
        if n_patch % 2 != 0:
            visual_tokens = F.pad(visual_tokens, (0, 1, 0, 1), value=0)
            visual_tokens = einops.rearrange(visual_tokens, "B D h w -> B h w D")
            visual_tokens[:, -1, -1, :] = self.resample_pad_token.expand(B, -1)
            n_patch += 1

        visual_tokens = visual_tokens.view(B, n_patch // 2, 2, n_patch // 2, 2, D)  # (B, n//2, 2, n//2, 2, D)
        visual_tokens = visual_tokens.permute(0, 1, 3, 2, 4, 5)  # (B, n//2, n//2, 2, 2, D)
        visual_tokens = visual_tokens.contiguous().view(B, n_patch // 2, n_patch // 2, D * 4)  # (B, n//2, n//2, D*4)

        visual_tokens = einops.rearrange(visual_tokens, "B h w D -> B (h w) D")

        visual_tokens = self.module(visual_tokens)

        # Grid
        if self.multi_grid:
            visual_tokens = torch.split(visual_tokens, split_sizes)  # B [G n D]
            visual_tokens_list = []
            for grid_visual_tokens in visual_tokens:
                grid_visual_tokens = torch.cat([grid_visual_tokens,
                                                self.grid_sep.repeat(grid_visual_tokens.shape[0], 1, 1)], dim=1)
                grid_visual_tokens = einops.rearrange(grid_visual_tokens, "G n D -> (G n) D")[:-1, :]
                visual_tokens_list.append(grid_visual_tokens)
        else:
            visual_tokens_list = torch.split(visual_tokens, 1)
            visual_tokens_list = [item.squeeze() for item in visual_tokens_list]

        return visual_tokens_list

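# Pooling note (illustrative): the view/permute/view above is a 2x2 space-to-depth merge.
# With 729 input tokens (a 27x27 grid, as in the test below), the grid is padded to 28x28,
# the new bottom-right cell is filled with resample_pad_token, and the merge yields
# 14x14 = 196 tokens of width 4*D, which the MLP then maps to proj_output_dim.
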
def build_vision_encoder(modal_cfg):
    assert modal_cfg['modal_tag'] == 'vision', f"building vision encoder with '{modal_cfg['modal_tag']}' tag is not supported"
    if "encoder_cfg" not in modal_cfg:  # from pretrained
        model = SiglipVisionModelForMM.from_pretrained(modal_cfg['model_name_or_path'])
    else:
        cfg = SiglipVisionConfig(**modal_cfg['encoder_cfg'])
        model = SiglipVisionModelForMM._from_config(cfg)
    return model


def build_vision_projector(modal_cfg):
    assert modal_cfg['modal_tag'] == 'vision', f"building vision projector with '{modal_cfg['modal_tag']}' tag is not supported"
    return VisionProjector_MLP(modal_cfg['proj_num_layers'], modal_cfg['proj_input_dim'], modal_cfg['proj_output_dim'],
                               modal_cfg['multi_grid'])

if __name__ == '__main__':
    projector = VisionProjector_MLP(2, 1152, 2048)
    inputs = torch.randn((5, 729, 1152))
    # forward expects a (visual_tokens, split_sizes) tuple and returns a list of per-sample tensors
    outputs = projector((inputs, [1] * 5))
    print([item.shape for item in outputs])
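For the multi_grid path, a minimal sketch (the batch below, one sample owning five grids, is an assumption for illustration):

projector = VisionProjector_MLP(2, 1152, 896, multi_grid=True)
tokens = torch.randn(5, 729, 1152)      # e.g. one global view plus four grid patches
outputs = projector((tokens, [5]))      # split_sizes: all five grids belong to one sample
print(outputs[0].shape)                 # torch.Size([984, 896]) = 5 * (196 + 1) - 1, grids joined by grid_sep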
mm_models/modeling_mm.py
ADDED
@@ -0,0 +1,259 @@
import torch
import torch.nn as nn
from transformers import PretrainedConfig, PreTrainedModel, Qwen2Config
from transformers.modeling_outputs import CausalLMOutputWithPast
from transformers.generation.utils import GenerateOutput
from .configuration_mm import AllSparkConfig
from .modal_module import MODAL_ENCODERS_MAPPING, MODAL_PROJECTORS_MAPPING
from .llms.qwen_model_moe import Qwen2ForCausalLMMoE
from typing import Optional, List, Union, Tuple
import torch.nn.init as init
from utils import rank0_print

class AllSparkPreTrainedModel(PreTrainedModel):

    config_class = AllSparkConfig
    base_model_prefix = "allspark"

    def _init_weights(self, module):
        """Initialize the weights"""
        std = (
            self.config.initializer_range
            if hasattr(self.config, "initializer_range")
            else 0.02
        )

        if isinstance(module, (nn.Linear, nn.Conv2d)):
            module.weight.data.normal_(mean=0.0, std=std)
            if module.bias is not None:
                module.bias.data.zero_()
        elif isinstance(module, nn.Embedding):
            module.weight.data.normal_(mean=0.0, std=std)
            if module.padding_idx is not None:
                module.weight.data[module.padding_idx].zero_()

    def prepare_multimodal_inputs(self, input_ids, modal_inputs, labels, attention_mask):
        if modal_inputs is None:
            return input_ids, None, labels, attention_mask, None

        modal_tensors = dict()
        for single_sample_modal_inputs in modal_inputs:
            for tag, modal_tensor in single_sample_modal_inputs:
                if tag in modal_tensors:
                    modal_tensors[tag].append(modal_tensor)
                else:
                    modal_tensors[tag] = [modal_tensor]
        for tag in modal_tensors:
            modal_tensors[tag] = self.modal_projectors[tag](self.modal_encoders[tag](modal_tensors[tag]))  # B [N D]
        for sample_id, single_sample_modal_inputs in enumerate(modal_inputs):
            for modal_id, (tag, _) in enumerate(single_sample_modal_inputs):
                modal_inputs[sample_id][modal_id] = (tag, modal_tensors[tag].pop(0).squeeze(0))
        for tag in modal_tensors:
            assert len(modal_tensors[tag]) == 0

        if attention_mask is None:
            attention_mask = torch.ones_like(input_ids, dtype=torch.bool)
        else:
            attention_mask = attention_mask.bool()
        if labels is None:
            labels = torch.full_like(input_ids, self.config.ignore_index)

        # input_ids: (batch_size, seq_len)
        # modal_inputs: (B, M, Tuple[str, Tensor])
        # labels: (batch_size, seq_len)
        assert input_ids.shape[0] == len(modal_inputs) == labels.shape[0], \
            f"Batch size mismatch: {input_ids.shape[0]} vs {len(modal_inputs)} vs {labels.shape[0]}. " \
            "If some sample has no modal inputs, please append an empty list to modal_inputs."

        input_ids = [cur_input_ids[cur_attention_mask] for cur_input_ids, cur_attention_mask in zip(input_ids, attention_mask)]
        labels = [cur_labels[cur_attention_mask] for cur_labels, cur_attention_mask in zip(labels, attention_mask)]

        modal_tag_pos_list = []
        new_input_embeds = []
        new_labels = []
        for single_sample_input_ids, single_sample_modal_inputs, single_sample_labels in zip(input_ids, modal_inputs, labels):
            tag_num = dict()
            cur_id = 0
            single_sample_embeds = []
            _labels = []
            single_sample_modal_tag_pos_list = []
            for modal_input in single_sample_modal_inputs:
                tag, modal_tensor = modal_input
                if tag not in self.modal_tags:
                    raise ValueError(f"Unknown modal tag: {tag}. Valid modal tags: {self.modal_tags}")
                if tag not in tag_num:
                    tag_num[tag] = 0

                for modal_config in self.config.modal_configs:
                    if modal_config["modal_tag"] == tag:
                        modal_placeholder_token_id = modal_config["modal_placeholder_token_id"]
                        break

                cur_modal_idx = torch.where(single_sample_input_ids == modal_placeholder_token_id)[0].tolist()[tag_num[tag]]

                single_sample_embeds.append(self.llm.get_input_embeddings()(single_sample_input_ids[cur_id:cur_modal_idx]))
                single_sample_embeds.append(self.modal_embeds[tag][0:1, :])  # start embed
                single_sample_embeds.append(modal_tensor)
                single_sample_embeds.append(self.modal_embeds[tag][1:2, :])  # end embed

                _labels.append(single_sample_labels[cur_id:cur_modal_idx])
                _labels.append(torch.full((modal_tensor.shape[0]+2,), self.config.ignore_index, device=single_sample_labels.device, dtype=single_sample_labels.dtype))

                single_sample_modal_tag_pos_list.append((tag, cur_modal_idx, cur_modal_idx+modal_tensor.shape[0]+1))

                cur_id = cur_modal_idx + 1  # advance past this placeholder (absolute index; `+=` would drift after the first modal input)
                tag_num[tag] += 1

            single_sample_embeds.append(self.llm.get_input_embeddings()(single_sample_input_ids[cur_id:]))
            _labels.append(single_sample_labels[cur_id:])

            new_input_embeds.append(torch.cat(single_sample_embeds, dim=0))
            new_labels.append(torch.cat(_labels, dim=0))
            modal_tag_pos_list.append(single_sample_modal_tag_pos_list)

        tokenizer_model_max_length = getattr(self.config, "tokenizer_model_max_length", None)
        if tokenizer_model_max_length is not None:
            new_input_embeds = [x[:tokenizer_model_max_length] for x in new_input_embeds]
            new_labels = [x[:tokenizer_model_max_length] for x in new_labels]

        max_len = max(x.shape[0] for x in new_input_embeds)
        batch_size = len(new_input_embeds)

        new_input_embeds_padded = []
        new_labels_padded = torch.full((batch_size, max_len), self.config.ignore_index, dtype=new_labels[0].dtype, device=new_labels[0].device)
        attention_mask = torch.zeros((batch_size, max_len), dtype=attention_mask.dtype, device=attention_mask.device)

        for i, (cur_new_embed, cur_new_labels) in enumerate(zip(new_input_embeds, new_labels)):
            cur_len = cur_new_embed.shape[0]
            if getattr(self.config, "tokenizer_padding_side", "right") == "left":
                new_input_embeds_padded.append(torch.cat((torch.zeros((max_len - cur_len, cur_new_embed.shape[1]), dtype=cur_new_embed.dtype, device=cur_new_embed.device), cur_new_embed), dim=0))
                if cur_len > 0:
                    new_labels_padded[i, -cur_len:] = cur_new_labels
                    attention_mask[i, -cur_len:] = True
            else:
                new_input_embeds_padded.append(torch.cat((cur_new_embed, torch.zeros((max_len - cur_len, cur_new_embed.shape[1]), dtype=cur_new_embed.dtype, device=cur_new_embed.device)), dim=0))
                if cur_len > 0:
                    new_labels_padded[i, :cur_len] = cur_new_labels
                    attention_mask[i, :cur_len] = True

        new_input_embeds = torch.stack(new_input_embeds_padded, dim=0)

        return None, new_input_embeds, new_labels_padded, attention_mask, modal_tag_pos_list

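# Expected `modal_inputs` layout (tags and tensors here are illustrative):
#   modal_inputs = [
#       [("vision", img_tensor), ("point", pts_tensor)],  # sample 0: two modal segments
#       [],                                               # sample 1: text-only (empty list required)
#   ]
# Each placeholder token in input_ids is consumed in order per tag and replaced by
# <start embed> + projected modal tokens + <end embed>; the matching label span is
# set to ignore_index so the loss only covers real text tokens.
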
class AllSparkForCausalLM(AllSparkPreTrainedModel):

    def __init__(self, config):
        super().__init__(config)

        if self.config.modal_configs is not None:
            self.modal_tags = []
            self.modal_encoders, self.modal_projectors = nn.ModuleDict(), nn.ModuleDict()
            for modal_config in self.config.modal_configs:
                modal_tag = modal_config['modal_tag']
                assert modal_tag not in self.modal_tags, f"Duplicate modal tag: {modal_tag}"
                self.modal_tags.append(modal_tag)
                self.modal_encoders[modal_tag] = MODAL_ENCODERS_MAPPING[modal_tag](modal_config)
                encoder_cfg = self.modal_encoders[modal_tag].config
                if isinstance(encoder_cfg, PretrainedConfig):
                    encoder_cfg = encoder_cfg.to_dict()
                modal_config['encoder_cfg'] = encoder_cfg
                self.modal_projectors[modal_tag] = MODAL_PROJECTORS_MAPPING[modal_tag](modal_config)
        else:
            self.modal_tags = None

        if hasattr(config, 'llm_config'):
            if "Qwen2" in config.llm_name_or_path:
                llm_config = Qwen2Config(**config.llm_config)
                self.llm = Qwen2ForCausalLMMoE._from_config(llm_config, modal_tags=self.modal_tags,
                                                            add_moe=self.config.add_moe)
            else:
                raise ValueError(config.llm_name_or_path)
        else:
            if "Qwen2" in config.llm_name_or_path:
                self.llm = Qwen2ForCausalLMMoE.from_pretrained(config.llm_name_or_path, modal_tags=self.modal_tags,
                                                               add_moe=self.config.add_moe)
            else:
                raise ValueError(config.llm_name_or_path)
        self.config.llm_config = self.llm.config
        self.config.hidden_size = self.llm.config.hidden_size

        if self.config.modal_configs is not None:
            self.modal_embeds = nn.ParameterDict()
            for modal_config in self.config.modal_configs:
                modal_tag = modal_config['modal_tag']
                self.modal_embeds[modal_tag] = torch.randn((2, self.config.hidden_size))  # start and end embeds
                init.kaiming_normal_(self.modal_embeds[modal_tag])

        self.post_init()

    def initialize_tokenizer_for_multimodal(self, tokenizer, new_tag):
        config = self.config
        if config.modal_configs is None:
            rank0_print("No modal configs provided, skipping multimodal tokenizer initialization.")
            return None

        for i, modal_config in enumerate(config.modal_configs):
            # only add new tokens for the new modal
            if modal_config['modal_tag'] != new_tag:
                continue
            modal_placeholder_token = modal_config['modal_placeholder_token']
            tokenizer.add_tokens([modal_placeholder_token], special_tokens=True)
            self.config.modal_configs[i]['modal_placeholder_token_id'] = tokenizer.convert_tokens_to_ids(modal_placeholder_token)

        self.llm.resize_token_embeddings(len(tokenizer), mean_resizing=False)

    def forward(
            self,
            input_ids: torch.LongTensor = None,
            modal_inputs: Optional[List[List[Tuple[str, torch.FloatTensor]]]] = None,  # B M (modal_tag, modal_input)
            attention_mask: Optional[torch.Tensor] = None,
            labels: Optional[torch.LongTensor] = None,
            **kwargs
    ) -> Union[Tuple, CausalLMOutputWithPast]:

        if modal_inputs is not None:
            input_ids, inputs_embeds, labels, attention_mask, modal_tag_pos_list = \
                self.prepare_multimodal_inputs(input_ids, modal_inputs, labels, attention_mask)

            return self.llm(input_ids=input_ids,
                            attention_mask=attention_mask,
                            inputs_embeds=inputs_embeds,
                            labels=labels,
                            modal_tag_pos_list=modal_tag_pos_list,
                            **kwargs)
        else:
            return self.llm(input_ids=input_ids,
                            attention_mask=attention_mask,
                            labels=labels,
                            **kwargs)

    @torch.no_grad()
    def generate(
            self,
            input_ids: torch.LongTensor = None,
            modal_inputs: Optional[List[List[Tuple[str, torch.FloatTensor]]]] = None,  # B M (modal_tag, modal_input)
            attention_mask: Optional[torch.Tensor] = None,
            **kwargs
    ) -> Union[GenerateOutput, torch.LongTensor]:

        if modal_inputs is not None:
            input_ids, inputs_embeds, labels, attention_mask, modal_tag_pos_list = \
                self.prepare_multimodal_inputs(input_ids, modal_inputs, None, attention_mask)

            return self.llm.generate(input_ids=input_ids, attention_mask=attention_mask,
                                     inputs_embeds=inputs_embeds, modal_tag_pos_list=modal_tag_pos_list, **kwargs)
        else:
            return self.llm.generate(input_ids=input_ids, attention_mask=attention_mask, **kwargs)
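A minimal calling sketch (the checkpoint path, tokenizer, and pixel_values below are placeholders, not shipped defaults):

from transformers import AutoTokenizer

model = AllSparkForCausalLM.from_pretrained("path/to/allspark-checkpoint")   # hypothetical path
tokenizer = AutoTokenizer.from_pretrained("path/to/allspark-checkpoint")

prompt = "Describe the image: <|vision_placeholder|>"
input_ids = tokenizer(prompt, return_tensors="pt").input_ids
modal_inputs = [[("vision", pixel_values)]]   # pixel_values: preprocessed image tensor(s)
out = model.generate(input_ids=input_ids, modal_inputs=modal_inputs, max_new_tokens=64)
print(tokenizer.decode(out[0], skip_special_tokens=True))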
utils.py
ADDED
@@ -0,0 +1,181 @@
import torch.distributed as dist
import ast
import re
import torch
from PIL import Image
import math


def model_params_summary(module, out_fn, verbose=True):
    out_fn("-"*30)
    out_fn(f"module name: {module.__class__}")
    if verbose:
        out_fn("-"*30)
        for n, p in module.named_parameters():
            out_fn(f"{n}: {'trainable' if p.requires_grad else 'freeze'}")
    out_fn("-"*30)
    out_fn(f"Total params: {sum(p.numel() for p in module.parameters())/1e6:.4f}M")
    out_fn(f"Trainable params: {sum(p.numel() for p in module.parameters() if p.requires_grad)/1e6:.4f}M")
    out_fn("-"*30)


def rank0_print(*args):
    if dist.is_initialized():
        if dist.get_rank() == 0:
            print(f"Rank {dist.get_rank()}: ", *args)
    else:
        print(*args)


LLM_DIM_MAPPING = {
    'Qwen2.5-0.5B': 896,
    'Qwen2.5-1.5B': 1536,
    'Qwen2.5-3B': 2048,
    'Qwen2.5-7B': 3584
}


SYSTEM_PROMPT = "You are a multimodal AI assistant named AllSparkv2 capable of understanding and generating content " +\
                "in various forms, including text and images. Your primary function is to provide useful and harmless " +\
                "information based on user input, assisting with problem-solving, information retrieval, and task completion. "

|
44 |
+
"""
|
45 |
+
Selects the best resolution from a list of possible resolutions based on the original size.
|
46 |
+
|
47 |
+
Args:
|
48 |
+
original_size (tuple): The original size of the image in the format (width, height).
|
49 |
+
possible_resolutions (list): A list of possible resolutions in the format [(width1, height1), (width2, height2), ...].
|
50 |
+
|
51 |
+
Returns:
|
52 |
+
tuple: The best fit resolution in the format (width, height).
|
53 |
+
"""
|
54 |
+
original_width, original_height = original_size
|
55 |
+
best_fit = None
|
56 |
+
max_effective_resolution = 0
|
57 |
+
min_wasted_resolution = float("inf")
|
58 |
+
|
59 |
+
for width, height in possible_resolutions:
|
60 |
+
# Calculate the downscaled size to keep the aspect ratio
|
61 |
+
scale = min(width / original_width, height / original_height)
|
62 |
+
downscaled_width, downscaled_height = int(original_width * scale), int(original_height * scale)
|
63 |
+
|
64 |
+
# Calculate effective and wasted resolutions
|
65 |
+
effective_resolution = min(downscaled_width * downscaled_height, original_width * original_height)
|
66 |
+
wasted_resolution = (width * height) - effective_resolution
|
67 |
+
|
68 |
+
if effective_resolution > max_effective_resolution or (effective_resolution == max_effective_resolution and wasted_resolution < min_wasted_resolution):
|
69 |
+
max_effective_resolution = effective_resolution
|
70 |
+
min_wasted_resolution = wasted_resolution
|
71 |
+
best_fit = (width, height)
|
72 |
+
|
73 |
+
return best_fit
|
74 |
+
|
75 |
+
|
76 |
+
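# Worked example (numbers chosen for illustration):
#   select_best_resolution((1000, 600), [(768, 768), (1536, 768)])
#   (768, 768):  scale = min(0.768, 1.28) = 0.768 -> downscaled (768, 460),  effective = 353280
#   (1536, 768): scale = min(1.536, 1.28) = 1.28  -> downscaled (1280, 768), effective = min(983040, 600000) = 600000
#   -> returns (1536, 768): the candidate that preserves the most source pixels wins.
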
def resize_and_pad_image(image, target_resolution):
    """
    Resize and pad an image to a target resolution while maintaining aspect ratio.

    Args:
        image (PIL.Image.Image): The input image.
        target_resolution (tuple): The target resolution (width, height) of the image.

    Returns:
        PIL.Image.Image: The resized and padded image.
    """
    original_width, original_height = image.size
    target_width, target_height = target_resolution

    # Determine which dimension (width or height) to fill
    scale_w = target_width / original_width
    scale_h = target_height / original_height

    if scale_w < scale_h:
        # Width will be filled completely
        new_width = target_width
        new_height = min(math.ceil(original_height * scale_w), target_height)
    else:
        # Height will be filled completely
        new_height = target_height
        new_width = min(math.ceil(original_width * scale_h), target_width)

    # Resize the image
    resized_image = image.resize((new_width, new_height))

    # Create a new image with the target size and paste the resized image onto it
    new_image = Image.new("RGB", (target_width, target_height), (0, 0, 0))
    paste_x = (target_width - new_width) // 2
    paste_y = (target_height - new_height) // 2
    new_image.paste(resized_image, (paste_x, paste_y))

    return new_image

def divide_to_patches(image, patch_size):
    """
    Divides an image into patches of a specified size.

    Args:
        image (PIL.Image.Image): The input image.
        patch_size (int): The size of each patch.

    Returns:
        list: A list of PIL.Image.Image objects representing the patches.
    """
    patches = []
    width, height = image.size
    for i in range(0, height, patch_size):
        for j in range(0, width, patch_size):
            box = (j, i, j + patch_size, i + patch_size)
            patch = image.crop(box)
            patches.append(patch)

    return patches

def process_anyres_image(image, processor, grid_pinpoints):
    """
    Process an image with variable resolutions.

    Args:
        image (PIL.Image.Image): The input image to be processed.
        processor: The image processor object.
        grid_pinpoints (str): A string representation of a list of possible resolutions.

    Returns:
        torch.Tensor: A tensor containing the processed image patches.
    """
    # Convert grid_pinpoints from string to list
    if isinstance(grid_pinpoints, str) and "x" in grid_pinpoints:
        patch_size = min(processor.size.values())
        assert patch_size in [224, 336, 384, 448, 512], "patch_size should be in [224, 336, 384, 448, 512]"
        # Use regex to extract the range from the input string
        matches = re.findall(r"\((\d+)x(\d+)\)", grid_pinpoints)
        range_start = tuple(map(int, matches[0]))
        range_end = tuple(map(int, matches[-1]))
        # Generate a matrix of tuples from (range_start[0], range_start[1]) to (range_end[0], range_end[1])
        grid_pinpoints = [(i, j) for i in range(range_start[0], range_end[0] + 1) for j in range(range_start[1], range_end[1] + 1)]
        # Multiply all elements by patch_size
        grid_pinpoints = [[dim * patch_size for dim in pair] for pair in grid_pinpoints]

    if type(grid_pinpoints) is list:
        possible_resolutions = grid_pinpoints
    else:
        possible_resolutions = ast.literal_eval(grid_pinpoints)
    best_resolution = select_best_resolution(image.size, possible_resolutions)
    image_padded = resize_and_pad_image(image, best_resolution)

    patches = divide_to_patches(image_padded, processor.size["height"])

    # FIXME: this seems to be a bug: it resizes instead of padding.
    # But to keep it consistent with previous behavior, it is kept as is.
    # TODO: uncomment below to ablate with the padding
    shortest_edge = min(processor.size.values())
    image_original_resize = image.resize((shortest_edge, shortest_edge))
    # image_padded_square = expand2square(image, tuple(int(x*255) for x in processor.image_mean))
    # image_original_resize = image_padded_square.resize((processor.size['shortest_edge'], processor.size['shortest_edge']))

    image_patches = [image_original_resize] + patches
    image_patches = [processor.preprocess(image_patch, return_tensors="pt")["pixel_values"][0] for image_patch in image_patches]
    return torch.stack(image_patches, dim=0)
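End-to-end, the anyres helpers compose as in this sketch (the SigLIP checkpoint and grid string are assumptions for illustration; any processor whose size values fall in the asserted list works):

from PIL import Image
from transformers import SiglipImageProcessor

from utils import process_anyres_image

processor = SiglipImageProcessor.from_pretrained("google/siglip-so400m-patch14-384")  # assumed checkpoint
image = Image.open("inference/demo_assets/image1.png").convert("RGB")

patches = process_anyres_image(image, processor, "(1x1),...,(2x2)")
print(patches.shape)  # (1 + num_grid_patches, 3, 384, 384): resized global view first, then grid patches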