Update modeling_chatglm.py for transformers 4.49 compatibility
#89
opened by sylwia-kuros
modeling_chatglm.py (+13 −2)
modeling_chatglm.py
CHANGED
@@ -924,10 +924,17 @@ class ChatGLMForConditionalGeneration(ChatGLMPreTrainedModel):

Old version, hunk 1 (lines 924–933):

    924          outputs: ModelOutput,
    925          model_kwargs: Dict[str, Any],
    926          is_encoder_decoder: bool = False,
    927      ) -> Dict[str, Any]:
    928          # update past_key_values
    929  -       [removed line — content not captured in this diff view]
    930  -       [removed line — content not captured in this diff view]
    931
    932          # update attention mask
    933          if "attention_mask" in model_kwargs:

@@ -946,6 +953,10 @@ class ChatGLMForConditionalGeneration(ChatGLMPreTrainedModel):

Old version, hunk 2 (lines 946–951):

    946          )
    947
    948          model_kwargs["is_first_forward"] = False
    949          return model_kwargs
    950
    951      def prepare_inputs_for_generation(
New version, hunk 1 (lines 924–940):

    924          outputs: ModelOutput,
    925          model_kwargs: Dict[str, Any],
    926          is_encoder_decoder: bool = False,
    927  +       num_new_tokens: int = 1,
    928      ) -> Dict[str, Any]:
    929          # update past_key_values
    930  +       for possible_cache_name in ["past_key_values", "mems", "past_buckets_states", "cache_params"]:
    931  +           if hasattr(outputs, possible_cache_name):
    932  +               if possible_cache_name in ("past_buckets_states", "mems"):
    933  +                   cache_name = "past_key_values"
    934  +               else:
    935  +                   cache_name = possible_cache_name
    936  +               model_kwargs[cache_name] = getattr(outputs, possible_cache_name)
    937  +               break
    938
    939          # update attention mask
    940          if "attention_mask" in model_kwargs:

New version, hunk 2 (lines 953–962):

    953          )
    954
    955          model_kwargs["is_first_forward"] = False
    956  +
    957  +       if model_kwargs.get("use_cache", True) and "cache_position" in model_kwargs:
    958  +           model_kwargs["cache_position"] = model_kwargs["cache_position"][-1:] + num_new_tokens
    959  +
    960          return model_kwargs
    961
    962      def prepare_inputs_for_generation(