duzx16
committed on
Commit
·
7fabe56
1
Parent(s):
efb7a1e
Fix use_cache=False
Browse files
- modeling_chatglm.py +5 -2
modeling_chatglm.py
CHANGED
|
@@ -897,6 +897,7 @@ class ChatGLMForConditionalGeneration(ChatGLMPreTrainedModel):
|
|
| 897 |
past_key_values: Optional[torch.Tensor] = None,
|
| 898 |
attention_mask: Optional[torch.Tensor] = None,
|
| 899 |
position_ids: Optional[torch.Tensor] = None,
|
|
|
|
| 900 |
is_first_forward: bool = True,
|
| 901 |
**kwargs
|
| 902 |
) -> dict:
|
|
@@ -904,7 +905,7 @@ class ChatGLMForConditionalGeneration(ChatGLMPreTrainedModel):
|
|
| 904 |
if position_ids is None:
|
| 905 |
position_ids = self.get_position_ids(input_ids, device=input_ids.device)
|
| 906 |
if not is_first_forward:
|
| 907 |
-
if
|
| 908 |
position_ids = position_ids[..., -1:]
|
| 909 |
input_ids = input_ids[:, -1:]
|
| 910 |
return {
|
|
@@ -912,7 +913,8 @@ class ChatGLMForConditionalGeneration(ChatGLMPreTrainedModel):
|
|
| 912 |
"past_key_values": past_key_values,
|
| 913 |
"position_ids": position_ids,
|
| 914 |
"attention_mask": attention_mask,
|
| 915 |
-
"return_last_logit": True
|
|
|
|
| 916 |
}
|
| 917 |
|
| 918 |
def forward(
|
|
@@ -1089,6 +1091,7 @@ class ChatGLMForConditionalGeneration(ChatGLMPreTrainedModel):
|
|
| 1089 |
generation_config = self.generation_config
|
| 1090 |
generation_config = copy.deepcopy(generation_config)
|
| 1091 |
model_kwargs = generation_config.update(**kwargs)
|
|
|
|
| 1092 |
bos_token_id, eos_token_id = generation_config.bos_token_id, generation_config.eos_token_id
|
| 1093 |
|
| 1094 |
if isinstance(eos_token_id, int):
|
|
|
|
| 897 |
past_key_values: Optional[torch.Tensor] = None,
|
| 898 |
attention_mask: Optional[torch.Tensor] = None,
|
| 899 |
position_ids: Optional[torch.Tensor] = None,
|
| 900 |
+
use_cache: Optional[bool] = None,
|
| 901 |
is_first_forward: bool = True,
|
| 902 |
**kwargs
|
| 903 |
) -> dict:
|
|
|
|
| 905 |
if position_ids is None:
|
| 906 |
position_ids = self.get_position_ids(input_ids, device=input_ids.device)
|
| 907 |
if not is_first_forward:
|
| 908 |
+
if past_key_values is not None:
|
| 909 |
position_ids = position_ids[..., -1:]
|
| 910 |
input_ids = input_ids[:, -1:]
|
| 911 |
return {
|
|
|
|
| 913 |
"past_key_values": past_key_values,
|
| 914 |
"position_ids": position_ids,
|
| 915 |
"attention_mask": attention_mask,
|
| 916 |
+
"return_last_logit": True,
|
| 917 |
+
"use_cache": use_cache
|
| 918 |
}
|
| 919 |
|
| 920 |
def forward(
|
|
|
|
| 1091 |
generation_config = self.generation_config
|
| 1092 |
generation_config = copy.deepcopy(generation_config)
|
| 1093 |
model_kwargs = generation_config.update(**kwargs)
|
| 1094 |
+
model_kwargs["use_cache"] = generation_config.use_cache
|
| 1095 |
bos_token_id, eos_token_id = generation_config.bos_token_id, generation_config.eos_token_id
|
| 1096 |
|
| 1097 |
if isinstance(eos_token_id, int):
|