make compatible with transformers 4.49+
They completely removed the `_extract_past_from_model_output` method in transformers 4.49+.
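A quick way to confirm what this commit works around (a minimal check, not part of the change itself) is to probe `GenerationMixin` on the installed transformers version:

```python
# Minimal sketch: on transformers >= 4.49 the private helper is gone, so custom
# modeling code that still calls self._extract_past_from_model_output(...) fails
# with an AttributeError during generate().
import transformers
from transformers import GenerationMixin

print(transformers.__version__)
print(hasattr(GenerationMixin, "_extract_past_from_model_output"))  # False on 4.49+
```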
- modeling_chatglm.py (changed): +13 −7
```diff
@@ -1082,19 +1082,22 @@ class ChatGLMForConditionalGeneration(ChatGLMPreTrainedModel):
             outputs: ModelOutput,
             model_kwargs: Dict[str, Any],
             is_encoder_decoder: bool = False,
+            num_new_tokens: int = 1,
     ) -> Dict[str, Any]:
-        # update past_key_values
-        cache_name, cache = self._extract_past_from_model_output(outputs)
-        model_kwargs[cache_name] = cache
+        for possible_cache_name in ["past_key_values", "mems", "past_buckets_states", "cache_params"]:
+            if hasattr(outputs, possible_cache_name):
+                if possible_cache_name in ("past_buckets_states", "mems"):
+                    cache_name = "past_key_values"
+                else:
+                    cache_name = possible_cache_name
+                model_kwargs[cache_name] = getattr(outputs, possible_cache_name)
+                break
 
-        # update attention mask
         if "attention_mask" in model_kwargs:
             attention_mask = model_kwargs["attention_mask"]
             model_kwargs["attention_mask"] = torch.cat(
                 [attention_mask, attention_mask.new_ones((attention_mask.shape[0], 1))], dim=-1
             )
-
-        # update position ids
         if "position_ids" in model_kwargs:
             position_ids = model_kwargs["position_ids"]
             new_position_id = position_ids[..., -1:].clone()
@@ -1102,8 +1105,11 @@ class ChatGLMForConditionalGeneration(ChatGLMPreTrainedModel):
             model_kwargs["position_ids"] = torch.cat(
                 [position_ids, new_position_id], dim=-1
             )
-
         model_kwargs["is_first_forward"] = False
+
+        if model_kwargs.get("use_cache", True) and "cache_position" in model_kwargs:
+            model_kwargs["cache_position"] = model_kwargs["cache_position"][-1:] + num_new_tokens
+
         return model_kwargs
 
     def prepare_inputs_for_generation(
```
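With the cache handling inlined as above, `_update_model_kwargs_for_generation` no longer depends on the removed helper: it scans the model output for the usual cache attributes itself and advances `cache_position` when it is present. A rough smoke test, assuming the patched checkpoint is loaded with `trust_remote_code=True` (the repo id below is a placeholder, not taken from this commit):

```python
# Sketch of a generation smoke test against transformers >= 4.49.
# "<your-chatglm-repo>" is a placeholder; substitute the repo that ships this
# modeling_chatglm.py. device_map="auto" assumes accelerate is installed.
import torch
from transformers import AutoModelForCausalLM, AutoTokenizer

repo = "<your-chatglm-repo>"
tokenizer = AutoTokenizer.from_pretrained(repo, trust_remote_code=True)
model = AutoModelForCausalLM.from_pretrained(
    repo, trust_remote_code=True, torch_dtype=torch.float16, device_map="auto"
)

inputs = tokenizer("Hello", return_tensors="pt").to(model.device)
# generate() calls _update_model_kwargs_for_generation on every decoding step,
# which is the method this commit patches.
output_ids = model.generate(**inputs, max_new_tokens=32)
print(tokenizer.decode(output_ids[0], skip_special_tokens=True))
```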