zifei9 committed · verified
Commit 681024c · Parent(s): 3f419d7

Update modeling_llama.py

Files changed (1):
  1. modeling_llama.py +7 -2
modeling_llama.py CHANGED

@@ -961,12 +961,13 @@ class LlamaModel(LlamaPreTrainedModel):
             inputs_embeds = self.embed_tokens(input_ids)
 
         past_seen_tokens = 0
+        used_legacy_cache = False
         if use_cache:  # kept for BC (cache positions)
             if past_key_values is not None and not isinstance(
                 past_key_values, StaticCache
             ):
                 if not isinstance(past_key_values, DynamicCache):
-                    used_legacy_cache=True
+                    used_legacy_cache = True
                     past_key_values = DynamicCache.from_legacy_cache(past_key_values)
                 past_seen_tokens = past_key_values.get_seq_length()
 
@@ -1038,7 +1039,11 @@ class LlamaModel(LlamaPreTrainedModel):
 
         next_cache = None
         if use_cache:
-            next_cache = next_decoder_cache.to_legacy_cache() if used_legacy_cache else next_decoder_cache
+            next_cache = (
+                next_decoder_cache.to_legacy_cache()
+                if used_legacy_cache
+                else next_decoder_cache
+            )
         if not return_dict:
             return tuple(
                 v
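
Why this change: in the previous revision, `used_legacy_cache` was bound only inside the `if not isinstance(past_key_values, DynamicCache):` branch, yet it was read unconditionally in the `next_cache` expression at the end of `forward`. Any call that skipped that branch (e.g. `past_key_values=None`, or a `past_key_values` already wrapped in a `DynamicCache`) raised `UnboundLocalError`. The commit binds the flag to `False` up front; the reflowed conditional expression in the second hunk is cosmetic. Below is a minimal, self-contained sketch of the failure mode; the function names and the `tuple` check are illustrative stand-ins, not code from this file:

    # Hypothetical reduction of the control flow before and after the fix.
    def next_cache_before_fix(past_key_values, next_decoder_cache, use_cache=True):
        if use_cache and past_key_values is not None:
            if isinstance(past_key_values, tuple):  # stand-in for the legacy-format check
                used_legacy_cache = True  # the flag is only bound on this path
        if use_cache:
            # UnboundLocalError here whenever the branch above was skipped
            return next_decoder_cache.to_legacy_cache() if used_legacy_cache else next_decoder_cache
        return None

    def next_cache_after_fix(past_key_values, next_decoder_cache, use_cache=True):
        used_legacy_cache = False  # the fix: bind the flag before any branching
        if use_cache and past_key_values is not None:
            if isinstance(past_key_values, tuple):
                used_legacy_cache = True
        if use_cache:
            return (
                next_decoder_cache.to_legacy_cache()
                if used_legacy_cache
                else next_decoder_cache
            )
        return None

    # Example trigger, using transformers' real DynamicCache:
    from transformers.cache_utils import DynamicCache
    cache = DynamicCache()
    next_cache_after_fix(None, cache)    # returns the DynamicCache unchanged
    # next_cache_before_fix(None, cache) # would raise UnboundLocalError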