drlon committed on
Commit
86938ce
·
1 Parent(s): 3d2a5fe
Files changed (1) hide show
  1. app.py +16 -6
app.py CHANGED
@@ -101,13 +101,23 @@ def get_som_response(instruction, image_som):
101
  add_generation_prompt=True
102
  )
103
 
 
 
 
 
 
 
 
 
104
  inputs = magma_processor(images=[image_som], texts=prompt, return_tensors="pt")
105
- # with torch.autocast(device_type="cuda", dtype=torch.bfloat16):
106
- # inputs['pixel_values'] = inputs['pixel_values'].unsqueeze(0).to(torch.bfloat16) # Add .to(torch.bfloat16) here for explicit casting
107
- # inputs['image_sizes'] = inputs['image_sizes'].unsqueeze(0)
108
- # logger.warning(inputs['pixel_values'].dtype)
109
- # # inputs = inputs.to("cuda")
110
- inputs = inputs.to("cuda", dtype=torch.bfloat16)
 
 
111
 
112
  magam_model.generation_config.pad_token_id = magma_processor.tokenizer.pad_token_id
113
  with torch.inference_mode():
 
101
  add_generation_prompt=True
102
  )
103
 
104
+ # inputs = magma_processor(images=[image_som], texts=prompt, return_tensors="pt")
105
+ # # with torch.autocast(device_type="cuda", dtype=torch.bfloat16):
106
+ # # inputs['pixel_values'] = inputs['pixel_values'].unsqueeze(0).to(torch.bfloat16) # Add .to(torch.bfloat16) here for explicit casting
107
+ # # inputs['image_sizes'] = inputs['image_sizes'].unsqueeze(0)
108
+ # # logger.warning(inputs['pixel_values'].dtype)
109
+ # # # inputs = inputs.to("cuda")
110
+ # inputs = inputs.to("cuda", dtype=torch.bfloat16)
111
+
112
  inputs = magma_processor(images=[image_som], texts=prompt, return_tensors="pt")
113
+ inputs['pixel_values'] = inputs['pixel_values'].to("cuda", dtype=torch.bfloat16)
114
+ inputs['image_sizes'] = inputs['image_sizes'].unsqueeze(0)
115
+ inputs['image_sizes'] = inputs['image_sizes'].to("cuda")
116
+
117
+ # 处理其他可能的输入
118
+ for key in inputs:
119
+ if key not in ['pixel_values', 'image_sizes'] and torch.is_tensor(inputs[key]):
120
+ inputs[key] = inputs[key].to("cuda")
121
 
122
  magam_model.generation_config.pad_token_id = magma_processor.tokenizer.pad_token_id
123
  with torch.inference_mode():