John Ho committed on
Commit d673ad7 · 1 Parent(s): 15ef0c9

added different inference code for internvl3

Files changed (1)
  1. app.py +55 -30
app.py CHANGED
@@ -114,7 +114,7 @@ MODEL_ZOO = {
         use_flash_attention=False,
         apply_quantization=False,
     ),
-    "OpenGVLab/InternVL3-1B-hf": AutoModelForImageTextToText.from_pretrained(
+    "InternVL3-1B-hf": AutoModelForImageTextToText.from_pretrained(
         "OpenGVLab/InternVL3-1B-hf", device_map=DEVICE, torch_dtype=DTYPE
     ),
 }
@@ -123,7 +123,7 @@ PROCESSORS = {
     "qwen2.5-vl-7b-cam-motion-preview": load_processor("Qwen/Qwen2.5-VL-7B-Instruct"),
     "qwen2.5-vl-7b-instruct": load_processor("Qwen/Qwen2.5-VL-7B-Instruct"),
     "qwen2.5-vl-3b-instruct": load_processor("Qwen/Qwen2.5-VL-3B-Instruct"),
-    "OpenGVLab/InternVL3-1B-hf": load_processor("OpenGVLab/InternVL3-1B-hf"),
+    "InternVL3-1B-hf": load_processor("OpenGVLab/InternVL3-1B-hf"),
 }
 
 logger.debug("Models and Processors Loaded!")
@@ -162,38 +162,63 @@ def inference(
         }
     ]
 
-    text = processor.apply_chat_template(
-        messages, tokenize=False, add_generation_prompt=True
-    )
-    image_inputs, video_inputs, video_kwargs = process_vision_info(
-        messages, return_video_kwargs=True
-    )
+    # text = processor.apply_chat_template(
+    #     messages, tokenize=False, add_generation_prompt=True
+    # )
+    # image_inputs, video_inputs, video_kwargs = process_vision_info(
+    #     messages, return_video_kwargs=True
+    # )
 
     # This prevents PyTorch from building the computation graph for gradients,
     # saving a significant amount of memory for intermediate activations.
     with torch.no_grad():
-        inputs = processor(
-            text=[text],
-            images=image_inputs,
-            videos=video_inputs,
-            # fps=fps,
-            padding=True,
-            return_tensors="pt",
-            **video_kwargs,
-        )
-        inputs = inputs.to("cuda")
-
-        # Inference
-        generated_ids = model.generate(**inputs, max_new_tokens=max_tokens)
-        generated_ids_trimmed = [
-            out_ids[len(in_ids) :]
-            for in_ids, out_ids in zip(inputs.input_ids, generated_ids)
-        ]
-        output_text = processor.batch_decode(
-            generated_ids_trimmed,
-            skip_special_tokens=True,
-            clean_up_tokenization_spaces=False,
-        )
+        model_family = model_name.split("-")[0]
+        match model_family:
+            case "qwen2.5":
+                text = processor.apply_chat_template(
+                    messages, tokenize=False, add_generation_prompt=True
+                )
+                image_inputs, video_inputs, video_kwargs = process_vision_info(
+                    messages, return_video_kwargs=True
+                )
+                inputs = processor(
+                    text=[text],
+                    images=image_inputs,
+                    videos=video_inputs,
+                    # fps=fps,
+                    padding=True,
+                    return_tensors="pt",
+                    **video_kwargs,
+                )
+                inputs = inputs.to("cuda")
+
+                # Inference
+                generated_ids = model.generate(**inputs, max_new_tokens=max_tokens)
+                generated_ids_trimmed = [
+                    out_ids[len(in_ids) :]
+                    for in_ids, out_ids in zip(inputs.input_ids, generated_ids)
+                ]
+                output_text = processor.batch_decode(
+                    generated_ids_trimmed,
+                    skip_special_tokens=True,
+                    clean_up_tokenization_spaces=False,
+                )
+            case "InternVL3":
+                inputs = processor.apply_chat_template(
+                    messages,
+                    add_generation_prompt=True,
+                    tokenize=True,
+                    return_dict=True,
+                    return_tensors="pt",
+                    # num_frames = 8
+                ).to("cuda", dtype=DTYPE)
+
+                output = model.generate(**inputs, max_new_tokens=max_tokens)
+                output_text = processor.decode(
+                    output[0, inputs["input_ids"].shape[1] :], skip_special_tokens=True
+                )
+            case _:
+                raise ValueError(f"{model_name} is not currently supported")
     return output_text
 
 
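
Note on the renamed registry keys: the new dispatch in inference() derives a model family from everything before the first dash in the model name, so the MODEL_ZOO and PROCESSORS keys are shortened from the full hub id to make the InternVL3 branch reachable. A minimal sketch of that relationship (key names taken from the diff; the snippet itself is illustrative, not part of app.py):

# Illustrative only: how the key names in this commit map onto the match/case
# branches via model_name.split("-")[0], the same expression used in inference().
model_names = [
    "qwen2.5-vl-7b-instruct",     # -> "qwen2.5"   (Qwen branch)
    "InternVL3-1B-hf",            # -> "InternVL3" (new branch added here)
    "OpenGVLab/InternVL3-1B-hf",  # old key -> "OpenGVLab/InternVL3" (no branch matches)
]

for name in model_names:
    family = name.split("-")[0]
    print(f"{name!r} -> {family!r}")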
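
For reference, a self-contained sketch of the InternVL3 path the new case implements, built only from the calls visible in the diff (AutoModelForImageTextToText, processor.apply_chat_template with tokenize=True and return_dict=True, generate, decode). The example message, image URL, and max_new_tokens value are placeholders, not values from app.py:

# Hedged sketch of the InternVL3-1B-hf inference path added in this commit.
# The model id and the apply_chat_template/generate/decode flow come from the diff;
# the prompt, image URL, and generation length are made up for illustration.
import torch
from transformers import AutoModelForImageTextToText, AutoProcessor

model_id = "OpenGVLab/InternVL3-1B-hf"
processor = AutoProcessor.from_pretrained(model_id)
model = AutoModelForImageTextToText.from_pretrained(
    model_id, device_map="auto", torch_dtype=torch.bfloat16
)

messages = [
    {
        "role": "user",
        "content": [
            {"type": "image", "url": "https://example.com/frame.jpg"},  # placeholder image
            {"type": "text", "text": "Describe the camera motion in this shot."},
        ],
    }
]

with torch.no_grad():
    # Tokenized chat template returned as a dict of tensors, as in the new case branch.
    inputs = processor.apply_chat_template(
        messages,
        add_generation_prompt=True,
        tokenize=True,
        return_dict=True,
        return_tensors="pt",
    ).to(model.device, dtype=model.dtype)

    output = model.generate(**inputs, max_new_tokens=256)
    # Strip the prompt tokens before decoding, mirroring the slice in inference().
    text = processor.decode(
        output[0, inputs["input_ids"].shape[1]:], skip_special_tokens=True
    )

print(text)

Unlike the hard-coded .to("cuda", dtype=DTYPE) in the diff, the sketch moves inputs to model.device, so it also runs when device_map places the model somewhere other than a CUDA GPU.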