Update modeling_ovis2_5.py
Browse files — modeling_ovis2_5.py (+48, −1)
modeling_ovis2_5.py
CHANGED
|
@@ -894,7 +894,54 @@ class Ovis2_5(OvisPreTrainedModel):
|
|
| 894 |
pixel_values=kwargs.pop('pixel_values', None),
|
| 895 |
grid_thws=kwargs.pop('grid_thws', None)
|
| 896 |
)
|
| 897 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 898 |
|
| 899 |
|
| 900 |
AutoConfig.register('siglip2_navit', Siglip2NavitConfig)
|
|
|
|
| 894 |
pixel_values=kwargs.pop('pixel_values', None),
|
| 895 |
grid_thws=kwargs.pop('grid_thws', None)
|
| 896 |
)
|
| 897 |
+
enable_thinking = kwargs.pop('enable_thinking', False)
|
| 898 |
+
enable_thinking_budget = kwargs.pop('enable_thinking_budget', False)
|
| 899 |
+
thinking_budget = kwargs.pop('thinking_budget', 1024)
|
| 900 |
+
|
| 901 |
+
if enable_thinking and enable_thinking_budget:
|
| 902 |
+
actual_max_new_tokens = kwargs['max_new_tokens']
|
| 903 |
+
kwargs['max_new_tokens'] = thinking_budget
|
| 904 |
+
generated_ids = self.llm.generate(inputs=None, inputs_embeds=inputs_embeds, attention_mask=attention_mask, **kwargs)
|
| 905 |
+
output_ids = generated_ids
|
| 906 |
+
output_ids_list = generated_ids[0]
|
| 907 |
+
|
| 908 |
+
# check if the generation has already finished (151645 is <|im_end|>)
|
| 909 |
+
if 151645 not in output_ids_list:
|
| 910 |
+
# check if the thinking process has finished (151668 is </think>)
|
| 911 |
+
# and prepare the second model input
|
| 912 |
+
if 151668 not in output_ids_list:
|
| 913 |
+
print("thinking budget is reached")
|
| 914 |
+
early_stopping_text = "\n\nConsidering the limited time by the user, I have to give the solution based on the thinking directly now.\n</think>\n\n"
|
| 915 |
+
early_stopping_ids = self.text_tokenizer(early_stopping_text, return_tensors="pt", return_attention_mask=False).input_ids.to(inputs.device)
|
| 916 |
+
input_ids_appendent = torch.cat([output_ids, early_stopping_ids], dim=-1)
|
| 917 |
+
kwargs['streamer'].put(early_stopping_ids) if 'streamer' in kwargs else None
|
| 918 |
+
else:
|
| 919 |
+
input_ids_appendent = output_ids
|
| 920 |
+
|
| 921 |
+
|
| 922 |
+
# second generation
|
| 923 |
+
new_inputs = torch.cat([inputs, input_ids_appendent], dim=-1)
|
| 924 |
+
attention_mask = torch.ne(new_inputs, self.text_tokenizer.pad_token_id).to(device=inputs.device)
|
| 925 |
+
inputs_embeds_appendent = self.merge_multimodal(
|
| 926 |
+
input_ids=input_ids_appendent,
|
| 927 |
+
pixel_values=None,
|
| 928 |
+
grid_thws=None
|
| 929 |
+
)
|
| 930 |
+
new_inputs_embeds = torch.cat([inputs_embeds, inputs_embeds_appendent], dim=-2)
|
| 931 |
+
|
| 932 |
+
kwargs['max_new_tokens'] = inputs_embeds.size(-2) + actual_max_new_tokens - new_inputs_embeds.size(-2)
|
| 933 |
+
generated_ids2 = self.llm.generate(inputs=None, inputs_embeds=new_inputs_embeds, attention_mask=attention_mask, **kwargs)
|
| 934 |
+
kwargs['streamer'].manual_end() if 'streamer' in kwargs else None
|
| 935 |
+
return torch.cat([input_ids_appendent, generated_ids2], dim=-1)
|
| 936 |
+
|
| 937 |
+
else:
|
| 938 |
+
kwargs['streamer'].manual_end() if 'streamer' in kwargs else None
|
| 939 |
+
return generated_ids
|
| 940 |
+
|
| 941 |
+
else:
|
| 942 |
+
generated_ids = self.llm.generate(inputs=None, inputs_embeds=inputs_embeds, attention_mask=attention_mask, **kwargs)
|
| 943 |
+
kwargs['streamer'].manual_end() if 'streamer' in kwargs else None
|
| 944 |
+
return generated_ids
|
| 945 |
|
| 946 |
|
| 947 |
AutoConfig.register('siglip2_navit', Siglip2NavitConfig)
|