ctranslate2-4you committed · Commit 1a5be97 · verified · 1 Parent(s): ae8facc

make compatible with transformers 4.49+


They completely removed the `_extract_past_from_model_output` method in transformers 4.49+, so `_update_model_kwargs_for_generation` now has to pull the cache out of the model output itself.
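
For context, this is roughly what the removed helper looked like in pre-4.49 releases (a paraphrased sketch of the old `GenerationMixin` method, not a verbatim copy); the patch below inlines the same lookup:

```python
# Approximate shape of the removed GenerationMixin._extract_past_from_model_output
# (pre-4.49 transformers). Branch order and details are paraphrased, not verbatim.
def _extract_past_from_model_output(self, outputs):
    past_key_values = None
    cache_name = "past_key_values"
    # ModelOutput subclasses OrderedDict, so `in` checks the output's keys
    if "past_key_values" in outputs:
        past_key_values = outputs.past_key_values
    elif "mems" in outputs:
        past_key_values = outputs.mems
    elif "past_buckets_states" in outputs:
        past_key_values = outputs.past_buckets_states
    elif "cache_params" in outputs:
        past_key_values = outputs.cache_params
        cache_name = "cache_params"
    # returns the kwarg name to store the cache under, plus the cache itself
    return cache_name, past_key_values
```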

Files changed (1)
  1. modeling_chatglm.py +13 -7
modeling_chatglm.py CHANGED

```diff
@@ -1082,19 +1082,22 @@ class ChatGLMForConditionalGeneration(ChatGLMPreTrainedModel):
         outputs: ModelOutput,
         model_kwargs: Dict[str, Any],
         is_encoder_decoder: bool = False,
+        num_new_tokens: int = 1,
     ) -> Dict[str, Any]:
-        # update past_key_values
-        cache_name, cache = self._extract_past_from_model_output(outputs)
-        model_kwargs[cache_name] = cache
+        for possible_cache_name in ["past_key_values", "mems", "past_buckets_states", "cache_params"]:
+            if hasattr(outputs, possible_cache_name):
+                if possible_cache_name in ("past_buckets_states", "mems"):
+                    cache_name = "past_key_values"
+                else:
+                    cache_name = possible_cache_name
+                model_kwargs[cache_name] = getattr(outputs, possible_cache_name)
+                break
 
-        # update attention mask
         if "attention_mask" in model_kwargs:
             attention_mask = model_kwargs["attention_mask"]
             model_kwargs["attention_mask"] = torch.cat(
                 [attention_mask, attention_mask.new_ones((attention_mask.shape[0], 1))], dim=-1
             )
-
-        # update position ids
         if "position_ids" in model_kwargs:
             position_ids = model_kwargs["position_ids"]
             new_position_id = position_ids[..., -1:].clone()
@@ -1102,8 +1105,11 @@ class ChatGLMForConditionalGeneration(ChatGLMPreTrainedModel):
             model_kwargs["position_ids"] = torch.cat(
                 [position_ids, new_position_id], dim=-1
             )
-
         model_kwargs["is_first_forward"] = False
+
+        if model_kwargs.get("use_cache", True) and "cache_position" in model_kwargs:
+            model_kwargs["cache_position"] = model_kwargs["cache_position"][-1:] + num_new_tokens
+
         return model_kwargs
 
     def prepare_inputs_for_generation(
```
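
The trailing `cache_position` update mirrors what upstream `GenerationMixin._update_model_kwargs_for_generation` does, advancing the position index by `num_new_tokens` after each forward pass. A quick way to verify the fix is a short generation run on transformers >= 4.49; a minimal sketch, assuming a hypothetical checkpoint repo id (substitute the repo that actually ships this patched modeling_chatglm.py):

```python
# Smoke test: generation should no longer crash on the removed
# _extract_past_from_model_output under transformers >= 4.49.
from transformers import AutoModel, AutoTokenizer

repo_id = "THUDM/chatglm3-6b"  # illustrative; replace with the actual repo
tokenizer = AutoTokenizer.from_pretrained(repo_id, trust_remote_code=True)
# .float() keeps the test runnable on CPU, where half precision can fail
model = AutoModel.from_pretrained(repo_id, trust_remote_code=True).float().eval()

inputs = tokenizer("Hello", return_tensors="pt")
output_ids = model.generate(**inputs, max_new_tokens=8)
print(tokenizer.decode(output_ids[0], skip_special_tokens=True))
```

If the patch is applied correctly, this runs to completion instead of raising an `AttributeError` for the removed helper.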