Update modeling_chatglm.py for transformers 4.49 compatibility

#89
Files changed (1)
  1. modeling_chatglm.py +13 -2
modeling_chatglm.py CHANGED
@@ -924,10 +924,17 @@ class ChatGLMForConditionalGeneration(ChatGLMPreTrainedModel):
             outputs: ModelOutput,
             model_kwargs: Dict[str, Any],
             is_encoder_decoder: bool = False,
+            num_new_tokens: int = 1,
     ) -> Dict[str, Any]:
         # update past_key_values
-        cache_name, cache = self._extract_past_from_model_output(outputs)
-        model_kwargs[cache_name] = cache
+        for possible_cache_name in ["past_key_values", "mems", "past_buckets_states", "cache_params"]:
+            if hasattr(outputs, possible_cache_name):
+                if possible_cache_name in ("past_buckets_states", "mems"):
+                    cache_name = "past_key_values"
+                else:
+                    cache_name = possible_cache_name
+                model_kwargs[cache_name] = getattr(outputs, possible_cache_name)
+                break
 
         # update attention mask
         if "attention_mask" in model_kwargs:
@@ -946,6 +953,10 @@ class ChatGLMForConditionalGeneration(ChatGLMPreTrainedModel):
         )
 
         model_kwargs["is_first_forward"] = False
+
+        if model_kwargs.get("use_cache", True) and "cache_position" in model_kwargs:
+            model_kwargs["cache_position"] = model_kwargs["cache_position"][-1:] + num_new_tokens
+
         return model_kwargs
 
     def prepare_inputs_for_generation(
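
For context: the patch replaces the call to the private `self._extract_past_from_model_output` helper, which newer transformers releases no longer provide (hence the breakage on 4.49), with an inline lookup over the cache attribute names that helper used to check. It also accepts the `num_new_tokens` argument and advances `cache_position`, which current `generate()` code passes and expects. A minimal smoke test of the patched file might look like the sketch below; it is illustrative only, and the repository id is a placeholder rather than a repo named in this PR.

# Minimal smoke-test sketch (not part of the PR).
# model_id is a placeholder for whichever ChatGLM repo carries this patched modeling_chatglm.py.
from transformers import AutoModel, AutoTokenizer

model_id = "THUDM/chatglm3-6b"  # placeholder repo id
tokenizer = AutoTokenizer.from_pretrained(model_id, trust_remote_code=True)
model = AutoModel.from_pretrained(model_id, trust_remote_code=True).eval()

inputs = tokenizer("Hello", return_tensors="pt")
# generate() calls _update_model_kwargs_for_generation after every forward pass;
# with this patch it no longer depends on the removed _extract_past_from_model_output helper.
output_ids = model.generate(**inputs, max_new_tokens=8)
print(tokenizer.decode(output_ids[0], skip_special_tokens=True))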