Update modeling_llama.py
modeling_llama.py  (+7 -2)
@@ -961,12 +961,13 @@ class LlamaModel(LlamaPreTrainedModel):
             inputs_embeds = self.embed_tokens(input_ids)
 
         past_seen_tokens = 0
+        used_legacy_cache = False
         if use_cache:  # kept for BC (cache positions)
             if past_key_values is not None and not isinstance(
                 past_key_values, StaticCache
             ):
                 if not isinstance(past_key_values, DynamicCache):
-                    used_legacy_cache=True
+                    used_legacy_cache = True
                 past_key_values = DynamicCache.from_legacy_cache(past_key_values)
                 past_seen_tokens = past_key_values.get_seq_length()
 
@@ -1038,7 +1039,11 @@ class LlamaModel(LlamaPreTrainedModel):
 
         next_cache = None
         if use_cache:
-            next_cache = next_decoder_cache.to_legacy_cache() if used_legacy_cache else next_decoder_cache
+            next_cache = (
+                next_decoder_cache.to_legacy_cache()
+                if used_legacy_cache
+                else next_decoder_cache
+            )
         if not return_dict:
             return tuple(
                 v