Fix data parsing in forward() method (#18)
- Fix data parsing in forward() method (482d9ffc5660ee228f84613f759e6cf29e053d3d)
Co-authored-by: Jasiek Kostecki <[email protected]>
- modeling_minicpmv.py +24 -0
 
    	
modeling_minicpmv.py CHANGED
@@ -203,6 +203,30 @@ class MiniCPMV(MiniCPMVPreTrainedModel):
 
 
     def forward(self, data, **kwargs):
+        if isinstance(data, torch.Tensor):
+            attention_mask = torch.ones_like(data, dtype=torch.bool)
+            kwargs = {'attention_mask': attention_mask}
+            return self.llm(
+                input_ids=data,
+                **kwargs
+            )
+
+        if data is None:
+            data = {
+                "input_ids": kwargs.pop("input_ids", None),
+                "pixel_values": kwargs.pop("pixel_values", None),
+                "image_bound": kwargs.pop("image_bound", None),
+                "tgt_sizes": kwargs.pop("tgt_sizes", None),
+                "position_ids": kwargs.pop("position_ids", None),
+            }
+        else:
+            kwargs.pop("input_ids", None)
+            kwargs.pop("pixel_values", None)
+            kwargs.pop("image_bound", None)
+            kwargs.pop("tgt_sizes", None)
+            kwargs.pop("position_ids", None)
+        kwargs.pop("inputs_embeds", None)
+
         vllm_embedding, vision_hidden_states = self.get_vllm_embedding(data)
 
         position_ids = data["position_ids"]
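The patch makes `forward()` tolerant of the different ways callers may supply inputs: a bare `input_ids` tensor is routed straight to the language model with an all-ones attention mask, a `None` `data` argument is rebuilt from keyword arguments, and otherwise any duplicated keys are stripped from `kwargs` so they are not forwarded twice. The sketch below reproduces that dispatch logic in isolation so the branches can be read outside the model; `ToyMiniCPMV` and `DummyLLM` are illustrative stand-ins, not classes from this repository.

```python
# Standalone sketch of the dispatch behavior added in this patch.
# ToyMiniCPMV / DummyLLM are hypothetical stand-ins, not repo classes.
import torch


class DummyLLM:
    def __call__(self, input_ids, attention_mask=None, **kwargs):
        return {"input_ids": input_ids, "attention_mask": attention_mask}


class ToyMiniCPMV:
    def __init__(self):
        self.llm = DummyLLM()

    def forward(self, data, **kwargs):
        # Raw tensor: treat it as input_ids and build a full attention mask.
        if isinstance(data, torch.Tensor):
            attention_mask = torch.ones_like(data, dtype=torch.bool)
            return self.llm(input_ids=data, attention_mask=attention_mask)

        keys = ("input_ids", "pixel_values", "image_bound",
                "tgt_sizes", "position_ids")
        if data is None:
            # data omitted: rebuild the dict the rest of forward() expects.
            data = {k: kwargs.pop(k, None) for k in keys}
        else:
            # data already packed: drop duplicate keys from kwargs.
            for k in keys:
                kwargs.pop(k, None)
        kwargs.pop("inputs_embeds", None)
        return data, kwargs


model = ToyMiniCPMV()
ids = torch.tensor([[1, 2, 3]])
print(model.forward(ids))                  # tensor path: handled by DummyLLM
print(model.forward(None, input_ids=ids))  # kwargs path: dict is rebuilt
```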