diff --git a/block_config.py b/block_config.py
new file mode 100644
index 0000000000000000000000000000000000000000..bd4ecaf4dced6a01dafab91c62ad556d311ba08c
--- /dev/null
+++ b/block_config.py
@@ -0,0 +1,118 @@
+import dataclasses
+import json
+import warnings
+from dataclasses import dataclass, MISSING
+from functools import partial
+from typing import Optional, Any
+
+
+@partial(dataclass, frozen=True, kw_only=True)
+class JsonComparable:
+ def to_json(self) -> str:
+ return json.dumps(dataclasses.asdict(self))
+
+ def __eq__(self, other: "JsonComparable") -> bool:
+ return self.to_json() == other.to_json()
+
+ def __hash__(self) -> int:
+ return hash(self.to_json())
+
+ def __lt__(self, other: "JsonComparable") -> bool:
+ return self.to_json() < other.to_json()
+
+
+@partial(dataclass, frozen=True, kw_only=True)
+class SubblockConfig(JsonComparable):
+ no_op: bool = False
+ replace_with_linear: bool = False
+ sparsify: Optional[list[str]] = None
+
+ def __post_init__(self):
+ assert not (self.no_op and self.replace_with_linear)
+
+ def _force_setattr(self, name: str, value: Any) -> None:
+ """
+ Set an attribute even in frozen dataclasses.
+ Use only inside __post_init__!
+ """
+ object.__setattr__(self, name, value)
+
+
+@partial(dataclass, frozen=True, kw_only=True)
+class AttentionConfig(SubblockConfig):
+ n_heads_in_group: Optional[int] = None
+ window_length: Optional[int] = None
+ num_sink_tokens: Optional[int] = None
+ use_prefill_window_in_sink_attention: bool = False
+ unshifted_sink: bool = False
+
+ def __post_init__(self):
+ super().__post_init__()
+ assert not (self.no_op and self.replace_with_linear)
+
+ if self.no_op or self.replace_with_linear:
+ for irrelevant_att in ["n_heads_in_group", "window_length", "num_sink_tokens"]:
+ self._force_setattr(irrelevant_att, None)
+ else:
+ assert self.n_heads_in_group is not None
+
+ if self.is_sink:
+ assert not (self.unshifted_sink and self.use_prefill_window_in_sink_attention), \
+ ("Unshifted sink uses its own kind of explicit masking, not standard window. "
+ "Set use_prefill_window_in_sink_attention to False.")
+ assert not (self.num_sink_tokens == 0 and not self.unshifted_sink), \
+ "Fake sink attention with 0 sink tokens is only supported with unshifted_sink=True"
+
+ @property
+ def prefill_sliding_window(self) -> Optional[int]:
+ if self.window_length is not None:
+ if not self.is_sink or self.use_prefill_window_in_sink_attention:
+ return self.window_length
+ return None
+
+ @property
+ def is_sliding(self) -> bool:
+ return self.prefill_sliding_window is not None
+
+ @property
+ def is_sink(self) -> bool:
+ return (
+ (self.window_length is not None)
+ and
+ (self.num_sink_tokens is not None)
+ )
+
+
+@partial(dataclass, frozen=True, kw_only=True)
+class FFNConfig(SubblockConfig):
+ ffn_mult: Optional[float] = None
+
+ def __post_init__(self):
+ super().__post_init__()
+ if self.no_op or self.replace_with_linear:
+ self._force_setattr("ffn_mult", None)
+ else:
+ assert self.ffn_mult is not None
+ self._force_setattr("ffn_mult", round(self.ffn_mult, 6))
+
+
+@partial(dataclass, frozen=True, kw_only=True)
+class BlockConfig(JsonComparable):
+ attention: AttentionConfig = MISSING
+ ffn: FFNConfig = MISSING
+
+ def __post_init__(self):
+ """
+ Init subblock dataclasses from dicts
+ """
+ for subblock_name in dataclasses.fields(self):
+ subblock_config = getattr(self, subblock_name.name)
+ if isinstance(subblock_config, dict):
+ subblock_fields = [field.name for field in dataclasses.fields(subblock_name.type)]
+ unsupported_fields = [field_name for field_name in subblock_config.keys()
+ if field_name not in subblock_fields]
+ if len(unsupported_fields) > 0:
+ warnings.warn(f"Removed unsupported fields {unsupported_fields} from {subblock_name.type.__name__}")
+ subblock_config = {k: v for k, v in subblock_config.items() if k not in unsupported_fields}
+ object.__setattr__(self, subblock_name.name,
+ subblock_name.type(**subblock_config)) # __setattr__ to overcome frozen=True
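A minimal usage sketch for the dataclasses above (illustrative values; it only assumes block_config.py is importable as a plain module): BlockConfig converts nested dicts into AttentionConfig/FFNConfig, warns about and drops unknown keys, and JsonComparable provides value-style equality and hashing via the JSON dump.

# Sketch: constructing a BlockConfig from dicts, as __post_init__ does for config.json entries.
from block_config import AttentionConfig, BlockConfig, FFNConfig

block = BlockConfig(
    attention={"n_heads_in_group": 8, "made_up_key": 1},  # "made_up_key" is dropped with a warning
    ffn={"ffn_mult": 5.25},
)
assert isinstance(block.attention, AttentionConfig) and isinstance(block.ffn, FFNConfig)

# Equality, hashing, and ordering all go through to_json().
same = BlockConfig(attention={"n_heads_in_group": 8}, ffn={"ffn_mult": 5.25})
assert block == same and hash(block) == hash(same)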
diff --git a/chat_template.jinja b/chat_template.jinja
new file mode 100644
index 0000000000000000000000000000000000000000..357e0a73e1ec7c614af2cfa78db4ba585dfa8eb9
--- /dev/null
+++ b/chat_template.jinja
@@ -0,0 +1,46 @@
+{%- if messages[0]["role"] == "system" -%}
+ {%- set system_message = messages[0]["content"] | trim -%}
+ {%- set messages = messages[1:] -%}
+{%- else -%}
+ {%- set system_message = "" -%}
+{%- endif -%}
+{%- if tools is not none -%}
+ {{- "<|begin_of_text|><|start_header_id|>system<|end_header_id|>" + "\n\n" + system_message -}}
+ {{- "\n\n" if system_message else "" -}}
+ {{- "[" -}}
+ {%- for t in tools -%}
+ {{- (t.function if t.function is defined else t) | tojson() -}}
+ {{- ", " if not loop.last else "" -}}
+ {%- endfor -%}
+ {{- "]" -}}
+ {{- "<|eot_id|>" -}}
+{%- else -%}
+ {{- "<|begin_of_text|><|start_header_id|>system<|end_header_id|>" + "\n\n" + system_message + "<|eot_id|>" -}}
+{%- endif -%}
+{%- for message in messages -%}
+ {%- if message["role"] == "user" -%}
+ {{- "<|start_header_id|>user<|end_header_id|>" + "\n\n" + message["content"] | trim + "<|eot_id|>" -}}
+ {%- elif message["role"] == "tool" -%}
+ {%- set tool_response = "[" + message["content"] | trim + "]" -%}
+ {{- "<|start_header_id|>user<|end_header_id|>" + "\n\n" + tool_response + "<|eot_id|>" -}}
+ {%- elif message["role"] == "assistant" and message.get("tool_calls") is not none -%}
+ {%- set tool_calls = message["tool_calls"] -%}
+ {{- "<|start_header_id|>assistant<|end_header_id|>" + "\n\n" + "[" -}}
+ {%- for tool_call in tool_calls -%}
+ {{- "{" + "\"name\": \"" + tool_call.function.name + "\", \"arguments\": " + tool_call.function.arguments | tojson + "}" -}}
+ {%- if not loop.last -%}
+ {{- ", " -}}
+ {%- else -%}
+ {{- "]" + "<|eot_id|>" -}}
+ {%- endif -%}
+ {%- endfor -%}
+ {%- elif message["role"] == "assistant" -%}
+ {{- "<|start_header_id|>assistant<|end_header_id|>" + "\n\n" -}}
+ {%- generation %}
+ {{- message["content"] | trim + "<|eot_id|>" -}}
+ {%- endgeneration %}
+ {%- endif -%}
+{%- endfor -%}
+{%- if add_generation_prompt -%}
+ {{- "<|start_header_id|>assistant<|end_header_id|>" + "\n\n" -}}
+{%- endif -%}
\ No newline at end of file
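A hedged sketch of how this template is normally rendered through transformers (the repo id below is a placeholder, not the real model name; the tool entry is a hypothetical function in the OpenAI-style schema the template expects):

# Sketch: rendering the chat template with tools and a generation prompt.
from transformers import AutoTokenizer

tokenizer = AutoTokenizer.from_pretrained("your-org/your-nemotron-nas-model", trust_remote_code=True)

messages = [
    {"role": "system", "content": "You are a helpful assistant."},
    {"role": "user", "content": "What is the weather in Paris?"},
]
tools = [{
    "type": "function",
    "function": {
        "name": "get_weather",  # hypothetical tool for illustration
        "description": "Get the current weather for a city",
        "parameters": {"type": "object", "properties": {"city": {"type": "string"}}},
    },
}]

prompt = tokenizer.apply_chat_template(
    messages, tools=tools, add_generation_prompt=True, tokenize=False
)
# The prompt starts with <|begin_of_text|><|start_header_id|>system<|end_header_id|>,
# lists the tool JSON inside the system turn, and ends with an open assistant header.
print(prompt)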
diff --git a/config.json b/config.json
new file mode 100644
index 0000000000000000000000000000000000000000..1af0197236073423ce1ac3a4d202bcef9060fd63
--- /dev/null
+++ b/config.json
@@ -0,0 +1,1484 @@
+{
+ "architectures": [
+ "DeciLMForCausalLM"
+ ],
+ "attention_bias": false,
+ "attention_dropout": 0.0,
+ "auto_map": {
+ "AutoConfig": "configuration_decilm.DeciLMConfig",
+ "AutoModelForCausalLM": "modeling_decilm.DeciLMForCausalLM"
+ },
+ "block_configs": [
+ {
+ "attention": {
+ "n_heads_in_group": 8,
+ "no_op": false,
+ "num_sink_tokens": null,
+ "replace_with_linear": false,
+ "sparsify": null,
+ "unshifted_sink": false,
+ "use_prefill_window_in_sink_attention": false,
+ "window_length": null
+ },
+ "ffn": {
+ "ffn_mult": 2.625,
+ "no_op": false,
+ "replace_with_linear": false,
+ "sparsify": null
+ }
+ },
+ {
+ "attention": {
+ "n_heads_in_group": 8,
+ "no_op": false,
+ "num_sink_tokens": null,
+ "replace_with_linear": false,
+ "sparsify": null,
+ "unshifted_sink": false,
+ "use_prefill_window_in_sink_attention": false,
+ "window_length": null
+ },
+ "ffn": {
+ "ffn_mult": 5.25,
+ "no_op": false,
+ "replace_with_linear": false,
+ "sparsify": null
+ }
+ },
+ {
+ "attention": {
+ "n_heads_in_group": 8,
+ "no_op": false,
+ "num_sink_tokens": null,
+ "replace_with_linear": false,
+ "sparsify": null,
+ "unshifted_sink": false,
+ "use_prefill_window_in_sink_attention": false,
+ "window_length": null
+ },
+ "ffn": {
+ "ffn_mult": 5.25,
+ "no_op": false,
+ "replace_with_linear": false,
+ "sparsify": null
+ }
+ },
+ {
+ "attention": {
+ "n_heads_in_group": 8,
+ "no_op": false,
+ "num_sink_tokens": null,
+ "replace_with_linear": false,
+ "sparsify": null,
+ "unshifted_sink": false,
+ "use_prefill_window_in_sink_attention": false,
+ "window_length": null
+ },
+ "ffn": {
+ "ffn_mult": 5.25,
+ "no_op": false,
+ "replace_with_linear": false,
+ "sparsify": null
+ }
+ },
+ {
+ "attention": {
+ "n_heads_in_group": 8,
+ "no_op": false,
+ "num_sink_tokens": null,
+ "replace_with_linear": false,
+ "sparsify": null,
+ "unshifted_sink": false,
+ "use_prefill_window_in_sink_attention": false,
+ "window_length": null
+ },
+ "ffn": {
+ "ffn_mult": 5.25,
+ "no_op": false,
+ "replace_with_linear": false,
+ "sparsify": null
+ }
+ },
+ {
+ "attention": {
+ "n_heads_in_group": 8,
+ "no_op": false,
+ "num_sink_tokens": null,
+ "replace_with_linear": false,
+ "sparsify": null,
+ "unshifted_sink": false,
+ "use_prefill_window_in_sink_attention": false,
+ "window_length": null
+ },
+ "ffn": {
+ "ffn_mult": 5.25,
+ "no_op": false,
+ "replace_with_linear": false,
+ "sparsify": null
+ }
+ },
+ {
+ "attention": {
+ "n_heads_in_group": null,
+ "no_op": true,
+ "num_sink_tokens": null,
+ "replace_with_linear": false,
+ "sparsify": null,
+ "unshifted_sink": false,
+ "use_prefill_window_in_sink_attention": false,
+ "window_length": null
+ },
+ "ffn": {
+ "ffn_mult": 2.625,
+ "no_op": false,
+ "replace_with_linear": false,
+ "sparsify": null
+ }
+ },
+ {
+ "attention": {
+ "n_heads_in_group": null,
+ "no_op": true,
+ "num_sink_tokens": null,
+ "replace_with_linear": false,
+ "sparsify": null,
+ "unshifted_sink": false,
+ "use_prefill_window_in_sink_attention": false,
+ "window_length": null
+ },
+ "ffn": {
+ "ffn_mult": 2.625,
+ "no_op": false,
+ "replace_with_linear": false,
+ "sparsify": null
+ }
+ },
+ {
+ "attention": {
+ "n_heads_in_group": 8,
+ "no_op": false,
+ "num_sink_tokens": null,
+ "replace_with_linear": false,
+ "sparsify": null,
+ "unshifted_sink": false,
+ "use_prefill_window_in_sink_attention": false,
+ "window_length": null
+ },
+ "ffn": {
+ "ffn_mult": 5.25,
+ "no_op": false,
+ "replace_with_linear": false,
+ "sparsify": null
+ }
+ },
+ {
+ "attention": {
+ "n_heads_in_group": 8,
+ "no_op": false,
+ "num_sink_tokens": null,
+ "replace_with_linear": false,
+ "sparsify": null,
+ "unshifted_sink": false,
+ "use_prefill_window_in_sink_attention": false,
+ "window_length": null
+ },
+ "ffn": {
+ "ffn_mult": 5.25,
+ "no_op": false,
+ "replace_with_linear": false,
+ "sparsify": null
+ }
+ },
+ {
+ "attention": {
+ "n_heads_in_group": 8,
+ "no_op": false,
+ "num_sink_tokens": null,
+ "replace_with_linear": false,
+ "sparsify": null,
+ "unshifted_sink": false,
+ "use_prefill_window_in_sink_attention": false,
+ "window_length": null
+ },
+ "ffn": {
+ "ffn_mult": 5.25,
+ "no_op": false,
+ "replace_with_linear": false,
+ "sparsify": null
+ }
+ },
+ {
+ "attention": {
+ "n_heads_in_group": null,
+ "no_op": true,
+ "num_sink_tokens": null,
+ "replace_with_linear": false,
+ "sparsify": null,
+ "unshifted_sink": false,
+ "use_prefill_window_in_sink_attention": false,
+ "window_length": null
+ },
+ "ffn": {
+ "ffn_mult": 3.28125,
+ "no_op": false,
+ "replace_with_linear": false,
+ "sparsify": null
+ }
+ },
+ {
+ "attention": {
+ "n_heads_in_group": 8,
+ "no_op": false,
+ "num_sink_tokens": null,
+ "replace_with_linear": false,
+ "sparsify": null,
+ "unshifted_sink": false,
+ "use_prefill_window_in_sink_attention": false,
+ "window_length": null
+ },
+ "ffn": {
+ "ffn_mult": 5.25,
+ "no_op": false,
+ "replace_with_linear": false,
+ "sparsify": null
+ }
+ },
+ {
+ "attention": {
+ "n_heads_in_group": 8,
+ "no_op": false,
+ "num_sink_tokens": null,
+ "replace_with_linear": false,
+ "sparsify": null,
+ "unshifted_sink": false,
+ "use_prefill_window_in_sink_attention": false,
+ "window_length": null
+ },
+ "ffn": {
+ "ffn_mult": 5.25,
+ "no_op": false,
+ "replace_with_linear": false,
+ "sparsify": null
+ }
+ },
+ {
+ "attention": {
+ "n_heads_in_group": 8,
+ "no_op": false,
+ "num_sink_tokens": null,
+ "replace_with_linear": false,
+ "sparsify": null,
+ "unshifted_sink": false,
+ "use_prefill_window_in_sink_attention": false,
+ "window_length": null
+ },
+ "ffn": {
+ "ffn_mult": 5.25,
+ "no_op": false,
+ "replace_with_linear": false,
+ "sparsify": null
+ }
+ },
+ {
+ "attention": {
+ "n_heads_in_group": 8,
+ "no_op": false,
+ "num_sink_tokens": null,
+ "replace_with_linear": false,
+ "sparsify": null,
+ "unshifted_sink": false,
+ "use_prefill_window_in_sink_attention": false,
+ "window_length": null
+ },
+ "ffn": {
+ "ffn_mult": 5.25,
+ "no_op": false,
+ "replace_with_linear": false,
+ "sparsify": null
+ }
+ },
+ {
+ "attention": {
+ "n_heads_in_group": 8,
+ "no_op": false,
+ "num_sink_tokens": null,
+ "replace_with_linear": false,
+ "sparsify": null,
+ "unshifted_sink": false,
+ "use_prefill_window_in_sink_attention": false,
+ "window_length": null
+ },
+ "ffn": {
+ "ffn_mult": 5.25,
+ "no_op": false,
+ "replace_with_linear": false,
+ "sparsify": null
+ }
+ },
+ {
+ "attention": {
+ "n_heads_in_group": 8,
+ "no_op": false,
+ "num_sink_tokens": null,
+ "replace_with_linear": false,
+ "sparsify": null,
+ "unshifted_sink": false,
+ "use_prefill_window_in_sink_attention": false,
+ "window_length": null
+ },
+ "ffn": {
+ "ffn_mult": 5.25,
+ "no_op": false,
+ "replace_with_linear": false,
+ "sparsify": null
+ }
+ },
+ {
+ "attention": {
+ "n_heads_in_group": 8,
+ "no_op": false,
+ "num_sink_tokens": null,
+ "replace_with_linear": false,
+ "sparsify": null,
+ "unshifted_sink": false,
+ "use_prefill_window_in_sink_attention": false,
+ "window_length": null
+ },
+ "ffn": {
+ "ffn_mult": 5.25,
+ "no_op": false,
+ "replace_with_linear": false,
+ "sparsify": null
+ }
+ },
+ {
+ "attention": {
+ "n_heads_in_group": 8,
+ "no_op": false,
+ "num_sink_tokens": null,
+ "replace_with_linear": false,
+ "sparsify": null,
+ "unshifted_sink": false,
+ "use_prefill_window_in_sink_attention": false,
+ "window_length": null
+ },
+ "ffn": {
+ "ffn_mult": 5.25,
+ "no_op": false,
+ "replace_with_linear": false,
+ "sparsify": null
+ }
+ },
+ {
+ "attention": {
+ "n_heads_in_group": 8,
+ "no_op": false,
+ "num_sink_tokens": null,
+ "replace_with_linear": false,
+ "sparsify": null,
+ "unshifted_sink": false,
+ "use_prefill_window_in_sink_attention": false,
+ "window_length": null
+ },
+ "ffn": {
+ "ffn_mult": 5.25,
+ "no_op": false,
+ "replace_with_linear": false,
+ "sparsify": null
+ }
+ },
+ {
+ "attention": {
+ "n_heads_in_group": 8,
+ "no_op": false,
+ "num_sink_tokens": null,
+ "replace_with_linear": false,
+ "sparsify": null,
+ "unshifted_sink": false,
+ "use_prefill_window_in_sink_attention": false,
+ "window_length": null
+ },
+ "ffn": {
+ "ffn_mult": 5.25,
+ "no_op": false,
+ "replace_with_linear": false,
+ "sparsify": null
+ }
+ },
+ {
+ "attention": {
+ "n_heads_in_group": 8,
+ "no_op": false,
+ "num_sink_tokens": null,
+ "replace_with_linear": false,
+ "sparsify": null,
+ "unshifted_sink": false,
+ "use_prefill_window_in_sink_attention": false,
+ "window_length": null
+ },
+ "ffn": {
+ "ffn_mult": 5.25,
+ "no_op": false,
+ "replace_with_linear": false,
+ "sparsify": null
+ }
+ },
+ {
+ "attention": {
+ "n_heads_in_group": 8,
+ "no_op": false,
+ "num_sink_tokens": null,
+ "replace_with_linear": false,
+ "sparsify": null,
+ "unshifted_sink": false,
+ "use_prefill_window_in_sink_attention": false,
+ "window_length": null
+ },
+ "ffn": {
+ "ffn_mult": 5.25,
+ "no_op": false,
+ "replace_with_linear": false,
+ "sparsify": null
+ }
+ },
+ {
+ "attention": {
+ "n_heads_in_group": 8,
+ "no_op": false,
+ "num_sink_tokens": null,
+ "replace_with_linear": false,
+ "sparsify": null,
+ "unshifted_sink": false,
+ "use_prefill_window_in_sink_attention": false,
+ "window_length": null
+ },
+ "ffn": {
+ "ffn_mult": 5.25,
+ "no_op": false,
+ "replace_with_linear": false,
+ "sparsify": null
+ }
+ },
+ {
+ "attention": {
+ "n_heads_in_group": 8,
+ "no_op": false,
+ "num_sink_tokens": null,
+ "replace_with_linear": false,
+ "sparsify": null,
+ "unshifted_sink": false,
+ "use_prefill_window_in_sink_attention": false,
+ "window_length": null
+ },
+ "ffn": {
+ "ffn_mult": 5.25,
+ "no_op": false,
+ "replace_with_linear": false,
+ "sparsify": null
+ }
+ },
+ {
+ "attention": {
+ "n_heads_in_group": 8,
+ "no_op": false,
+ "num_sink_tokens": null,
+ "replace_with_linear": false,
+ "sparsify": null,
+ "unshifted_sink": false,
+ "use_prefill_window_in_sink_attention": false,
+ "window_length": null
+ },
+ "ffn": {
+ "ffn_mult": 5.25,
+ "no_op": false,
+ "replace_with_linear": false,
+ "sparsify": null
+ }
+ },
+ {
+ "attention": {
+ "n_heads_in_group": 8,
+ "no_op": false,
+ "num_sink_tokens": null,
+ "replace_with_linear": false,
+ "sparsify": null,
+ "unshifted_sink": false,
+ "use_prefill_window_in_sink_attention": false,
+ "window_length": null
+ },
+ "ffn": {
+ "ffn_mult": 5.25,
+ "no_op": false,
+ "replace_with_linear": false,
+ "sparsify": null
+ }
+ },
+ {
+ "attention": {
+ "n_heads_in_group": 8,
+ "no_op": false,
+ "num_sink_tokens": null,
+ "replace_with_linear": false,
+ "sparsify": null,
+ "unshifted_sink": false,
+ "use_prefill_window_in_sink_attention": false,
+ "window_length": null
+ },
+ "ffn": {
+ "ffn_mult": 5.25,
+ "no_op": false,
+ "replace_with_linear": false,
+ "sparsify": null
+ }
+ },
+ {
+ "attention": {
+ "n_heads_in_group": 8,
+ "no_op": false,
+ "num_sink_tokens": null,
+ "replace_with_linear": false,
+ "sparsify": null,
+ "unshifted_sink": false,
+ "use_prefill_window_in_sink_attention": false,
+ "window_length": null
+ },
+ "ffn": {
+ "ffn_mult": 5.25,
+ "no_op": false,
+ "replace_with_linear": false,
+ "sparsify": null
+ }
+ },
+ {
+ "attention": {
+ "n_heads_in_group": 8,
+ "no_op": false,
+ "num_sink_tokens": null,
+ "replace_with_linear": false,
+ "sparsify": null,
+ "unshifted_sink": false,
+ "use_prefill_window_in_sink_attention": false,
+ "window_length": null
+ },
+ "ffn": {
+ "ffn_mult": 5.25,
+ "no_op": false,
+ "replace_with_linear": false,
+ "sparsify": null
+ }
+ },
+ {
+ "attention": {
+ "n_heads_in_group": 8,
+ "no_op": false,
+ "num_sink_tokens": null,
+ "replace_with_linear": false,
+ "sparsify": null,
+ "unshifted_sink": false,
+ "use_prefill_window_in_sink_attention": false,
+ "window_length": null
+ },
+ "ffn": {
+ "ffn_mult": 5.25,
+ "no_op": false,
+ "replace_with_linear": false,
+ "sparsify": null
+ }
+ },
+ {
+ "attention": {
+ "n_heads_in_group": 8,
+ "no_op": false,
+ "num_sink_tokens": null,
+ "replace_with_linear": false,
+ "sparsify": null,
+ "unshifted_sink": false,
+ "use_prefill_window_in_sink_attention": false,
+ "window_length": null
+ },
+ "ffn": {
+ "ffn_mult": 5.25,
+ "no_op": false,
+ "replace_with_linear": false,
+ "sparsify": null
+ }
+ },
+ {
+ "attention": {
+ "n_heads_in_group": 8,
+ "no_op": false,
+ "num_sink_tokens": null,
+ "replace_with_linear": false,
+ "sparsify": null,
+ "unshifted_sink": false,
+ "use_prefill_window_in_sink_attention": false,
+ "window_length": null
+ },
+ "ffn": {
+ "ffn_mult": 5.25,
+ "no_op": false,
+ "replace_with_linear": false,
+ "sparsify": null
+ }
+ },
+ {
+ "attention": {
+ "n_heads_in_group": 8,
+ "no_op": false,
+ "num_sink_tokens": null,
+ "replace_with_linear": false,
+ "sparsify": null,
+ "unshifted_sink": false,
+ "use_prefill_window_in_sink_attention": false,
+ "window_length": null
+ },
+ "ffn": {
+ "ffn_mult": 5.25,
+ "no_op": false,
+ "replace_with_linear": false,
+ "sparsify": null
+ }
+ },
+ {
+ "attention": {
+ "n_heads_in_group": 8,
+ "no_op": false,
+ "num_sink_tokens": null,
+ "replace_with_linear": false,
+ "sparsify": null,
+ "unshifted_sink": false,
+ "use_prefill_window_in_sink_attention": false,
+ "window_length": null
+ },
+ "ffn": {
+ "ffn_mult": 5.25,
+ "no_op": false,
+ "replace_with_linear": false,
+ "sparsify": null
+ }
+ },
+ {
+ "attention": {
+ "n_heads_in_group": 8,
+ "no_op": false,
+ "num_sink_tokens": null,
+ "replace_with_linear": false,
+ "sparsify": null,
+ "unshifted_sink": false,
+ "use_prefill_window_in_sink_attention": false,
+ "window_length": null
+ },
+ "ffn": {
+ "ffn_mult": 5.25,
+ "no_op": false,
+ "replace_with_linear": false,
+ "sparsify": null
+ }
+ },
+ {
+ "attention": {
+ "n_heads_in_group": 8,
+ "no_op": false,
+ "num_sink_tokens": null,
+ "replace_with_linear": false,
+ "sparsify": null,
+ "unshifted_sink": false,
+ "use_prefill_window_in_sink_attention": false,
+ "window_length": null
+ },
+ "ffn": {
+ "ffn_mult": 5.25,
+ "no_op": false,
+ "replace_with_linear": false,
+ "sparsify": null
+ }
+ },
+ {
+ "attention": {
+ "n_heads_in_group": 8,
+ "no_op": false,
+ "num_sink_tokens": null,
+ "replace_with_linear": false,
+ "sparsify": null,
+ "unshifted_sink": false,
+ "use_prefill_window_in_sink_attention": false,
+ "window_length": null
+ },
+ "ffn": {
+ "ffn_mult": 5.25,
+ "no_op": false,
+ "replace_with_linear": false,
+ "sparsify": null
+ }
+ },
+ {
+ "attention": {
+ "n_heads_in_group": 8,
+ "no_op": false,
+ "num_sink_tokens": null,
+ "replace_with_linear": false,
+ "sparsify": null,
+ "unshifted_sink": false,
+ "use_prefill_window_in_sink_attention": false,
+ "window_length": null
+ },
+ "ffn": {
+ "ffn_mult": 5.25,
+ "no_op": false,
+ "replace_with_linear": false,
+ "sparsify": null
+ }
+ },
+ {
+ "attention": {
+ "n_heads_in_group": 8,
+ "no_op": false,
+ "num_sink_tokens": null,
+ "replace_with_linear": false,
+ "sparsify": null,
+ "unshifted_sink": false,
+ "use_prefill_window_in_sink_attention": false,
+ "window_length": null
+ },
+ "ffn": {
+ "ffn_mult": 5.25,
+ "no_op": false,
+ "replace_with_linear": false,
+ "sparsify": null
+ }
+ },
+ {
+ "attention": {
+ "n_heads_in_group": 8,
+ "no_op": false,
+ "num_sink_tokens": null,
+ "replace_with_linear": false,
+ "sparsify": null,
+ "unshifted_sink": false,
+ "use_prefill_window_in_sink_attention": false,
+ "window_length": null
+ },
+ "ffn": {
+ "ffn_mult": 5.25,
+ "no_op": false,
+ "replace_with_linear": false,
+ "sparsify": null
+ }
+ },
+ {
+ "attention": {
+ "n_heads_in_group": null,
+ "no_op": true,
+ "num_sink_tokens": null,
+ "replace_with_linear": false,
+ "sparsify": null,
+ "unshifted_sink": false,
+ "use_prefill_window_in_sink_attention": false,
+ "window_length": null
+ },
+ "ffn": {
+ "ffn_mult": 1.3125,
+ "no_op": false,
+ "replace_with_linear": false,
+ "sparsify": null
+ }
+ },
+ {
+ "attention": {
+ "n_heads_in_group": null,
+ "no_op": true,
+ "num_sink_tokens": null,
+ "replace_with_linear": false,
+ "sparsify": null,
+ "unshifted_sink": false,
+ "use_prefill_window_in_sink_attention": false,
+ "window_length": null
+ },
+ "ffn": {
+ "ffn_mult": 2.625,
+ "no_op": false,
+ "replace_with_linear": false,
+ "sparsify": null
+ }
+ },
+ {
+ "attention": {
+ "n_heads_in_group": null,
+ "no_op": true,
+ "num_sink_tokens": null,
+ "replace_with_linear": false,
+ "sparsify": null,
+ "unshifted_sink": false,
+ "use_prefill_window_in_sink_attention": false,
+ "window_length": null
+ },
+ "ffn": {
+ "ffn_mult": 2.625,
+ "no_op": false,
+ "replace_with_linear": false,
+ "sparsify": null
+ }
+ },
+ {
+ "attention": {
+ "n_heads_in_group": null,
+ "no_op": true,
+ "num_sink_tokens": null,
+ "replace_with_linear": false,
+ "sparsify": null,
+ "unshifted_sink": false,
+ "use_prefill_window_in_sink_attention": false,
+ "window_length": null
+ },
+ "ffn": {
+ "ffn_mult": 1.3125,
+ "no_op": false,
+ "replace_with_linear": false,
+ "sparsify": null
+ }
+ },
+ {
+ "attention": {
+ "n_heads_in_group": null,
+ "no_op": true,
+ "num_sink_tokens": null,
+ "replace_with_linear": false,
+ "sparsify": null,
+ "unshifted_sink": false,
+ "use_prefill_window_in_sink_attention": false,
+ "window_length": null
+ },
+ "ffn": {
+ "ffn_mult": 5.25,
+ "no_op": false,
+ "replace_with_linear": false,
+ "sparsify": null
+ }
+ },
+ {
+ "attention": {
+ "n_heads_in_group": null,
+ "no_op": true,
+ "num_sink_tokens": null,
+ "replace_with_linear": false,
+ "sparsify": null,
+ "unshifted_sink": false,
+ "use_prefill_window_in_sink_attention": false,
+ "window_length": null
+ },
+ "ffn": {
+ "ffn_mult": 1.3125,
+ "no_op": false,
+ "replace_with_linear": false,
+ "sparsify": null
+ }
+ },
+ {
+ "attention": {
+ "n_heads_in_group": null,
+ "no_op": true,
+ "num_sink_tokens": null,
+ "replace_with_linear": false,
+ "sparsify": null,
+ "unshifted_sink": false,
+ "use_prefill_window_in_sink_attention": false,
+ "window_length": null
+ },
+ "ffn": {
+ "ffn_mult": 2.625,
+ "no_op": false,
+ "replace_with_linear": false,
+ "sparsify": null
+ }
+ },
+ {
+ "attention": {
+ "n_heads_in_group": null,
+ "no_op": true,
+ "num_sink_tokens": null,
+ "replace_with_linear": false,
+ "sparsify": null,
+ "unshifted_sink": false,
+ "use_prefill_window_in_sink_attention": false,
+ "window_length": null
+ },
+ "ffn": {
+ "ffn_mult": 1.3125,
+ "no_op": false,
+ "replace_with_linear": false,
+ "sparsify": null
+ }
+ },
+ {
+ "attention": {
+ "n_heads_in_group": null,
+ "no_op": true,
+ "num_sink_tokens": null,
+ "replace_with_linear": false,
+ "sparsify": null,
+ "unshifted_sink": false,
+ "use_prefill_window_in_sink_attention": false,
+ "window_length": null
+ },
+ "ffn": {
+ "ffn_mult": 1.3125,
+ "no_op": false,
+ "replace_with_linear": false,
+ "sparsify": null
+ }
+ },
+ {
+ "attention": {
+ "n_heads_in_group": null,
+ "no_op": true,
+ "num_sink_tokens": null,
+ "replace_with_linear": false,
+ "sparsify": null,
+ "unshifted_sink": false,
+ "use_prefill_window_in_sink_attention": false,
+ "window_length": null
+ },
+ "ffn": {
+ "ffn_mult": 1.3125,
+ "no_op": false,
+ "replace_with_linear": false,
+ "sparsify": null
+ }
+ },
+ {
+ "attention": {
+ "n_heads_in_group": 8,
+ "no_op": false,
+ "num_sink_tokens": null,
+ "replace_with_linear": false,
+ "sparsify": null,
+ "unshifted_sink": false,
+ "use_prefill_window_in_sink_attention": false,
+ "window_length": null
+ },
+ "ffn": {
+ "ffn_mult": 5.25,
+ "no_op": false,
+ "replace_with_linear": false,
+ "sparsify": null
+ }
+ },
+ {
+ "attention": {
+ "n_heads_in_group": null,
+ "no_op": true,
+ "num_sink_tokens": null,
+ "replace_with_linear": false,
+ "sparsify": null,
+ "unshifted_sink": false,
+ "use_prefill_window_in_sink_attention": false,
+ "window_length": null
+ },
+ "ffn": {
+ "ffn_mult": 1.3125,
+ "no_op": false,
+ "replace_with_linear": false,
+ "sparsify": null
+ }
+ },
+ {
+ "attention": {
+ "n_heads_in_group": null,
+ "no_op": true,
+ "num_sink_tokens": null,
+ "replace_with_linear": false,
+ "sparsify": null,
+ "unshifted_sink": false,
+ "use_prefill_window_in_sink_attention": false,
+ "window_length": null
+ },
+ "ffn": {
+ "ffn_mult": 1.0,
+ "no_op": false,
+ "replace_with_linear": false,
+ "sparsify": null
+ }
+ },
+ {
+ "attention": {
+ "n_heads_in_group": null,
+ "no_op": true,
+ "num_sink_tokens": null,
+ "replace_with_linear": false,
+ "sparsify": null,
+ "unshifted_sink": false,
+ "use_prefill_window_in_sink_attention": false,
+ "window_length": null
+ },
+ "ffn": {
+ "ffn_mult": 1.0,
+ "no_op": false,
+ "replace_with_linear": false,
+ "sparsify": null
+ }
+ },
+ {
+ "attention": {
+ "n_heads_in_group": null,
+ "no_op": true,
+ "num_sink_tokens": null,
+ "replace_with_linear": false,
+ "sparsify": null,
+ "unshifted_sink": false,
+ "use_prefill_window_in_sink_attention": false,
+ "window_length": null
+ },
+ "ffn": {
+ "ffn_mult": 1.3125,
+ "no_op": false,
+ "replace_with_linear": false,
+ "sparsify": null
+ }
+ },
+ {
+ "attention": {
+ "n_heads_in_group": null,
+ "no_op": true,
+ "num_sink_tokens": null,
+ "replace_with_linear": false,
+ "sparsify": null,
+ "unshifted_sink": false,
+ "use_prefill_window_in_sink_attention": false,
+ "window_length": null
+ },
+ "ffn": {
+ "ffn_mult": 1.0,
+ "no_op": false,
+ "replace_with_linear": false,
+ "sparsify": null
+ }
+ },
+ {
+ "attention": {
+ "n_heads_in_group": null,
+ "no_op": true,
+ "num_sink_tokens": null,
+ "replace_with_linear": false,
+ "sparsify": null,
+ "unshifted_sink": false,
+ "use_prefill_window_in_sink_attention": false,
+ "window_length": null
+ },
+ "ffn": {
+ "ffn_mult": 1.0,
+ "no_op": false,
+ "replace_with_linear": false,
+ "sparsify": null
+ }
+ },
+ {
+ "attention": {
+ "n_heads_in_group": null,
+ "no_op": true,
+ "num_sink_tokens": null,
+ "replace_with_linear": false,
+ "sparsify": null,
+ "unshifted_sink": false,
+ "use_prefill_window_in_sink_attention": false,
+ "window_length": null
+ },
+ "ffn": {
+ "ffn_mult": 1.0,
+ "no_op": false,
+ "replace_with_linear": false,
+ "sparsify": null
+ }
+ },
+ {
+ "attention": {
+ "n_heads_in_group": null,
+ "no_op": true,
+ "num_sink_tokens": null,
+ "replace_with_linear": false,
+ "sparsify": null,
+ "unshifted_sink": false,
+ "use_prefill_window_in_sink_attention": false,
+ "window_length": null
+ },
+ "ffn": {
+ "ffn_mult": 1.3125,
+ "no_op": false,
+ "replace_with_linear": false,
+ "sparsify": null
+ }
+ },
+ {
+ "attention": {
+ "n_heads_in_group": null,
+ "no_op": true,
+ "num_sink_tokens": null,
+ "replace_with_linear": false,
+ "sparsify": null,
+ "unshifted_sink": false,
+ "use_prefill_window_in_sink_attention": false,
+ "window_length": null
+ },
+ "ffn": {
+ "ffn_mult": 1.3125,
+ "no_op": false,
+ "replace_with_linear": false,
+ "sparsify": null
+ }
+ },
+ {
+ "attention": {
+ "n_heads_in_group": null,
+ "no_op": true,
+ "num_sink_tokens": null,
+ "replace_with_linear": false,
+ "sparsify": null,
+ "unshifted_sink": false,
+ "use_prefill_window_in_sink_attention": false,
+ "window_length": null
+ },
+ "ffn": {
+ "ffn_mult": 0.5,
+ "no_op": false,
+ "replace_with_linear": false,
+ "sparsify": null
+ }
+ },
+ {
+ "attention": {
+ "n_heads_in_group": null,
+ "no_op": true,
+ "num_sink_tokens": null,
+ "replace_with_linear": false,
+ "sparsify": null,
+ "unshifted_sink": false,
+ "use_prefill_window_in_sink_attention": false,
+ "window_length": null
+ },
+ "ffn": {
+ "ffn_mult": 0.5,
+ "no_op": false,
+ "replace_with_linear": false,
+ "sparsify": null
+ }
+ },
+ {
+ "attention": {
+ "n_heads_in_group": null,
+ "no_op": true,
+ "num_sink_tokens": null,
+ "replace_with_linear": false,
+ "sparsify": null,
+ "unshifted_sink": false,
+ "use_prefill_window_in_sink_attention": false,
+ "window_length": null
+ },
+ "ffn": {
+ "ffn_mult": 1.0,
+ "no_op": false,
+ "replace_with_linear": false,
+ "sparsify": null
+ }
+ },
+ {
+ "attention": {
+ "n_heads_in_group": null,
+ "no_op": true,
+ "num_sink_tokens": null,
+ "replace_with_linear": false,
+ "sparsify": null,
+ "unshifted_sink": false,
+ "use_prefill_window_in_sink_attention": false,
+ "window_length": null
+ },
+ "ffn": {
+ "ffn_mult": 1.0,
+ "no_op": false,
+ "replace_with_linear": false,
+ "sparsify": null
+ }
+ },
+ {
+ "attention": {
+ "n_heads_in_group": null,
+ "no_op": true,
+ "num_sink_tokens": null,
+ "replace_with_linear": false,
+ "sparsify": null,
+ "unshifted_sink": false,
+ "use_prefill_window_in_sink_attention": false,
+ "window_length": null
+ },
+ "ffn": {
+ "ffn_mult": 0.5,
+ "no_op": false,
+ "replace_with_linear": false,
+ "sparsify": null
+ }
+ },
+ {
+ "attention": {
+ "n_heads_in_group": null,
+ "no_op": true,
+ "num_sink_tokens": null,
+ "replace_with_linear": false,
+ "sparsify": null,
+ "unshifted_sink": false,
+ "use_prefill_window_in_sink_attention": false,
+ "window_length": null
+ },
+ "ffn": {
+ "ffn_mult": 0.5,
+ "no_op": false,
+ "replace_with_linear": false,
+ "sparsify": null
+ }
+ },
+ {
+ "attention": {
+ "n_heads_in_group": null,
+ "no_op": true,
+ "num_sink_tokens": null,
+ "replace_with_linear": false,
+ "sparsify": null,
+ "unshifted_sink": false,
+ "use_prefill_window_in_sink_attention": false,
+ "window_length": null
+ },
+ "ffn": {
+ "ffn_mult": 1.0,
+ "no_op": false,
+ "replace_with_linear": false,
+ "sparsify": null
+ }
+ },
+ {
+ "attention": {
+ "n_heads_in_group": null,
+ "no_op": true,
+ "num_sink_tokens": null,
+ "replace_with_linear": false,
+ "sparsify": null,
+ "unshifted_sink": false,
+ "use_prefill_window_in_sink_attention": false,
+ "window_length": null
+ },
+ "ffn": {
+ "ffn_mult": 0.5,
+ "no_op": false,
+ "replace_with_linear": false,
+ "sparsify": null
+ }
+ },
+ {
+ "attention": {
+ "n_heads_in_group": null,
+ "no_op": true,
+ "num_sink_tokens": null,
+ "replace_with_linear": false,
+ "sparsify": null,
+ "unshifted_sink": false,
+ "use_prefill_window_in_sink_attention": false,
+ "window_length": null
+ },
+ "ffn": {
+ "ffn_mult": 0.5,
+ "no_op": false,
+ "replace_with_linear": false,
+ "sparsify": null
+ }
+ },
+ {
+ "attention": {
+ "n_heads_in_group": 8,
+ "no_op": false,
+ "num_sink_tokens": null,
+ "replace_with_linear": false,
+ "sparsify": null,
+ "unshifted_sink": false,
+ "use_prefill_window_in_sink_attention": false,
+ "window_length": null
+ },
+ "ffn": {
+ "ffn_mult": 5.25,
+ "no_op": false,
+ "replace_with_linear": false,
+ "sparsify": null
+ }
+ },
+ {
+ "attention": {
+ "n_heads_in_group": 8,
+ "no_op": false,
+ "num_sink_tokens": null,
+ "replace_with_linear": false,
+ "sparsify": null,
+ "unshifted_sink": false,
+ "use_prefill_window_in_sink_attention": false,
+ "window_length": null
+ },
+ "ffn": {
+ "ffn_mult": 5.25,
+ "no_op": false,
+ "replace_with_linear": false,
+ "sparsify": null
+ }
+ },
+ {
+ "attention": {
+ "n_heads_in_group": 8,
+ "no_op": false,
+ "num_sink_tokens": null,
+ "replace_with_linear": false,
+ "sparsify": null,
+ "unshifted_sink": false,
+ "use_prefill_window_in_sink_attention": false,
+ "window_length": null
+ },
+ "ffn": {
+ "ffn_mult": 5.25,
+ "no_op": false,
+ "replace_with_linear": false,
+ "sparsify": null
+ }
+ },
+ {
+ "attention": {
+ "n_heads_in_group": 8,
+ "no_op": false,
+ "num_sink_tokens": null,
+ "replace_with_linear": false,
+ "sparsify": null,
+ "unshifted_sink": false,
+ "use_prefill_window_in_sink_attention": false,
+ "window_length": null
+ },
+ "ffn": {
+ "ffn_mult": 5.25,
+ "no_op": false,
+ "replace_with_linear": false,
+ "sparsify": null
+ }
+ },
+ {
+ "attention": {
+ "n_heads_in_group": 8,
+ "no_op": false,
+ "num_sink_tokens": null,
+ "replace_with_linear": false,
+ "sparsify": null,
+ "unshifted_sink": false,
+ "use_prefill_window_in_sink_attention": false,
+ "window_length": null
+ },
+ "ffn": {
+ "ffn_mult": 5.25,
+ "no_op": false,
+ "replace_with_linear": false,
+ "sparsify": null
+ }
+ },
+ {
+ "attention": {
+ "n_heads_in_group": 8,
+ "no_op": false,
+ "num_sink_tokens": null,
+ "replace_with_linear": false,
+ "sparsify": null,
+ "unshifted_sink": false,
+ "use_prefill_window_in_sink_attention": false,
+ "window_length": null
+ },
+ "ffn": {
+ "ffn_mult": 5.25,
+ "no_op": false,
+ "replace_with_linear": false,
+ "sparsify": null
+ }
+ },
+ {
+ "attention": {
+ "n_heads_in_group": 8,
+ "no_op": false,
+ "num_sink_tokens": null,
+ "replace_with_linear": false,
+ "sparsify": null,
+ "unshifted_sink": false,
+ "use_prefill_window_in_sink_attention": false,
+ "window_length": null
+ },
+ "ffn": {
+ "ffn_mult": 5.25,
+ "no_op": false,
+ "replace_with_linear": false,
+ "sparsify": null
+ }
+ },
+ {
+ "attention": {
+ "n_heads_in_group": 8,
+ "no_op": false,
+ "num_sink_tokens": null,
+ "replace_with_linear": false,
+ "sparsify": null,
+ "unshifted_sink": false,
+ "use_prefill_window_in_sink_attention": false,
+ "window_length": null
+ },
+ "ffn": {
+ "ffn_mult": 5.25,
+ "no_op": false,
+ "replace_with_linear": false,
+ "sparsify": null
+ }
+ },
+ {
+ "attention": {
+ "n_heads_in_group": 8,
+ "no_op": false,
+ "num_sink_tokens": null,
+ "replace_with_linear": false,
+ "sparsify": null,
+ "unshifted_sink": false,
+ "use_prefill_window_in_sink_attention": false,
+ "window_length": null
+ },
+ "ffn": {
+ "ffn_mult": 5.25,
+ "no_op": false,
+ "replace_with_linear": false,
+ "sparsify": null
+ }
+ }
+ ],
+ "bos_token_id": 128000,
+ "eos_token_id": [
+ 128001,
+ 128008,
+ 128009
+ ],
+ "hidden_act": "silu",
+ "hidden_size": 8192,
+ "initializer_range": 0.02,
+ "intermediate_size": null,
+ "max_position_embeddings": 131072,
+ "mlp_bias": false,
+ "model_type": "nemotron-nas",
+ "num_attention_heads": 64,
+ "num_hidden_layers": 80,
+ "num_key_value_heads": null,
+ "pretraining_tp": 1,
+ "rms_norm_eps": 1e-05,
+ "rope_scaling": {
+ "factor": 8.0,
+ "high_freq_factor": 4.0,
+ "low_freq_factor": 1.0,
+ "original_max_position_embeddings": 8192,
+ "rope_type": "llama3"
+ },
+ "rope_theta": 500000.0,
+ "tie_word_embeddings": false,
+ "torch_dtype": "float32",
+ "transformers_version": "4.53.3",
+ "use_cache": true,
+ "vocab_size": 128256
+}
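The per-block ffn_mult values above take the place of a single intermediate_size (which is therefore null), and the no_op attention entries skip the attention sub-block entirely. The authoritative mapping from ffn_mult to a concrete MLP width lives in modeling_decilm.py, which is outside this excerpt; as a hedged sketch, DeciLM-style repos usually derive it roughly as follows (the helper name and the rounding multiple are assumptions):

# Sketch (assumption): typical DeciLM-style expansion of ffn_mult into an MLP width.
def ffn_mult_to_intermediate_size(ffn_mult: float, hidden_size: int, multiple_of: int = 256) -> int:
    intermediate = int(2 * ffn_mult * hidden_size / 3)  # Llama-style 2/3 scaling
    return ((intermediate + multiple_of - 1) // multiple_of) * multiple_of  # round up to a multiple

# With hidden_size = 8192 as in this config, ffn_mult = 5.25 -> 28672 and 2.625 -> 14336.
print(ffn_mult_to_intermediate_size(5.25, 8192), ffn_mult_to_intermediate_size(2.625, 8192))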
diff --git a/configuration_decilm.py b/configuration_decilm.py
new file mode 100644
index 0000000000000000000000000000000000000000..e241b4025109b0b9ad34e0815bc34df834c600cb
--- /dev/null
+++ b/configuration_decilm.py
@@ -0,0 +1,65 @@
+# coding=utf-8
+# Copyright 2024 Nvidia Corporation. All rights reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+import dataclasses
+import warnings
+from typing import Dict, Any
+
+from transformers.utils import is_flash_attn_2_available
+
+from .block_config import BlockConfig
+from .transformers_4_44_2__configuration_llama import LlamaConfig
+from .transformers_4_44_2__modeling_rope_utils import \
+ rope_config_validation # fake import to make AutoConfig infer the dependency
+
+rope_config_validation # this line is here to make sure that auto-formatting doesn't remove the import
+
+
+class DeciLMConfig(LlamaConfig):
+ model_type = "nemotron-nas"
+
+ def __init__(
+ self,
+ block_configs: list[dict] | list[BlockConfig] = None,
+ **kwargs,
+ ):
+ attn_implementation = kwargs.pop("attn_implementation", None)
+ if attn_implementation is None and is_flash_attn_2_available():
+ attn_implementation = "flash_attention_2"
+
+ if block_configs is not None:
+ if isinstance(block_configs[0], dict):
+ block_configs = [BlockConfig(**conf) for conf in block_configs]
+
+ using_unshifted_sink = any([block_config.attention.unshifted_sink for block_config in block_configs])
+ if using_unshifted_sink and attn_implementation != "eager":
+ warnings.warn("Forcing attn_implementation='eager' since some attention layers use unshifted sink")
+ attn_implementation = "eager"
+
+ super().__init__(attn_implementation=attn_implementation, **kwargs)
+
+ self.intermediate_size = None
+ self.num_key_value_heads = None
+
+ if block_configs is not None:
+ assert len(block_configs) == self.num_hidden_layers
+
+ self.block_configs: list[BlockConfig] = block_configs
+
+ def to_dict(self) -> Dict[str, Any]:
+ self_dict = super().to_dict()
+ if self.block_configs is not None:
+ self_dict["block_configs"] = [dataclasses.asdict(conf) for conf in self.block_configs]
+ return self_dict
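A minimal sketch of how DeciLMConfig is reached through the auto_map entries in config.json (the repo id is a placeholder):

# Sketch: loading the custom config class with trust_remote_code.
from transformers import AutoConfig

config = AutoConfig.from_pretrained("your-org/your-nemotron-nas-model", trust_remote_code=True)
print(config.model_type)          # "nemotron-nas"
print(len(config.block_configs))  # one BlockConfig per layer (80 in this checkpoint)
print(config.block_configs[0].attention.n_heads_in_group)  # 8 for the first block
print(config.to_dict()["block_configs"][0]["ffn"])         # serialized back to plain dicts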
diff --git a/generation_config.json b/generation_config.json
new file mode 100644
index 0000000000000000000000000000000000000000..f67526f9f77f10f0d273e7ec51fd3aa0c196385f
--- /dev/null
+++ b/generation_config.json
@@ -0,0 +1,11 @@
+{
+ "_from_model_config": true,
+ "bos_token_id": 128000,
+ "do_sample": true,
+ "eos_token_id": [
+ 128001,
+ 128008,
+ 128009
+ ],
+ "transformers_version": "4.53.3"
+}
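The eos_token_id list stops generation on any of the three Llama-3-style terminators. A hedged end-to-end sketch (placeholder repo id; bfloat16 and device_map="auto" chosen for illustration):

# Sketch: generation with the settings above.
import torch
from transformers import AutoModelForCausalLM, AutoTokenizer

repo = "your-org/your-nemotron-nas-model"  # placeholder
tokenizer = AutoTokenizer.from_pretrained(repo, trust_remote_code=True)
model = AutoModelForCausalLM.from_pretrained(
    repo, torch_dtype=torch.bfloat16, device_map="auto", trust_remote_code=True
)

input_ids = tokenizer.apply_chat_template(
    [{"role": "user", "content": "Hello!"}], add_generation_prompt=True, return_tensors="pt"
).to(model.device)
output = model.generate(input_ids, max_new_tokens=64, do_sample=True)
print(tokenizer.decode(output[0][input_ids.shape[-1]:], skip_special_tokens=True))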
diff --git a/model-00001-of-00044.safetensors b/model-00001-of-00044.safetensors
new file mode 100644
index 0000000000000000000000000000000000000000..c5f17af85295b61c58757586ca8a3678f0d31a36
--- /dev/null
+++ b/model-00001-of-00044.safetensors
@@ -0,0 +1,3 @@
+version https://git-lfs.github.com/spec/v1
+oid sha256:947a6e0a58c53c5718ba92a3341b4faf40f944fd0f303c58934fbbbeea0a8ff8
+size 4806738752
diff --git a/model-00002-of-00044.safetensors b/model-00002-of-00044.safetensors
new file mode 100644
index 0000000000000000000000000000000000000000..70f5ce956f831753ffd5120b8d2635656d2da08c
--- /dev/null
+++ b/model-00002-of-00044.safetensors
@@ -0,0 +1,3 @@
+version https://git-lfs.github.com/spec/v1
+oid sha256:4abdb0931d2537e5b1c926ff4484cae467a2d3178db052d51d1302749a2aac4f
+size 4831938024
diff --git a/model-00003-of-00044.safetensors b/model-00003-of-00044.safetensors
new file mode 100644
index 0000000000000000000000000000000000000000..086fead5ab577144065643657d6bbe18c56435f7
--- /dev/null
+++ b/model-00003-of-00044.safetensors
@@ -0,0 +1,3 @@
+version https://git-lfs.github.com/spec/v1
+oid sha256:834e9b999f724ccc01a456bd3bdf97143acb29fe7a258d96447543eef1c94644
+size 4966156000
diff --git a/model-00004-of-00044.safetensors b/model-00004-of-00044.safetensors
new file mode 100644
index 0000000000000000000000000000000000000000..42d8b705dd0be50cc2694bebf73e7b4a53e85e94
--- /dev/null
+++ b/model-00004-of-00044.safetensors
@@ -0,0 +1,3 @@
+version https://git-lfs.github.com/spec/v1
+oid sha256:cc891b7c6cbd817d6a68c8b8a1f144f346765db9680f632b6d4465aeb34bc1d9
+size 4362142872
diff --git a/model-00005-of-00044.safetensors b/model-00005-of-00044.safetensors
new file mode 100644
index 0000000000000000000000000000000000000000..c13f58b72d9364bcd47c84c8d30d6d9974a03fc3
--- /dev/null
+++ b/model-00005-of-00044.safetensors
@@ -0,0 +1,3 @@
+version https://git-lfs.github.com/spec/v1
+oid sha256:131484a50c23bf397ca9f7ea14924625e44d17c747ae17b2a2d14338fbed9d18
+size 4831937920
diff --git a/model-00006-of-00044.safetensors b/model-00006-of-00044.safetensors
new file mode 100644
index 0000000000000000000000000000000000000000..c3d7ec79e8ca0595425f3587f720b6d7272cf7d9
--- /dev/null
+++ b/model-00006-of-00044.safetensors
@@ -0,0 +1,3 @@
+version https://git-lfs.github.com/spec/v1
+oid sha256:bf114a1c321c31854d2bcdb4eb2cf2afeac41f62042f75a8f4b1042f6222a5aa
+size 4831938144
diff --git a/model-00007-of-00044.safetensors b/model-00007-of-00044.safetensors
new file mode 100644
index 0000000000000000000000000000000000000000..5ef8b042e39ccea191b53e2e9ac00aeca0d2c3fe
--- /dev/null
+++ b/model-00007-of-00044.safetensors
@@ -0,0 +1,3 @@
+version https://git-lfs.github.com/spec/v1
+oid sha256:d04b4373792d844293f708f6528cf1afe3e8db9b379aad2eaecfe5c3cb7ee5a5
+size 4966188856
diff --git a/model-00008-of-00044.safetensors b/model-00008-of-00044.safetensors
new file mode 100644
index 0000000000000000000000000000000000000000..3e912422abe6e144b0fd97596aed5f0e6480d476
--- /dev/null
+++ b/model-00008-of-00044.safetensors
@@ -0,0 +1,3 @@
+version https://git-lfs.github.com/spec/v1
+oid sha256:a0681a00fea32b2da6edf61a6c5756e81bcfbcc0cf8f9ec064706b61f3c46a6d
+size 4915791128
diff --git a/model-00009-of-00044.safetensors b/model-00009-of-00044.safetensors
new file mode 100644
index 0000000000000000000000000000000000000000..7b7be0c6e0021631adff2be298608107665fe574
--- /dev/null
+++ b/model-00009-of-00044.safetensors
@@ -0,0 +1,3 @@
+version https://git-lfs.github.com/spec/v1
+oid sha256:e7d9bf4b19c4a46e85f7011ca3a74297b105facabdc4270cbfeaaedc32f5a099
+size 4630611344
diff --git a/model-00010-of-00044.safetensors b/model-00010-of-00044.safetensors
new file mode 100644
index 0000000000000000000000000000000000000000..6c0c08f54b7d61fdb2d20f2359849452ee74411a
--- /dev/null
+++ b/model-00010-of-00044.safetensors
@@ -0,0 +1,3 @@
+version https://git-lfs.github.com/spec/v1
+oid sha256:760736d6ff3be31f0837a08c77b6dd0b07716e39070a9e53081a562d45bb4cc9
+size 4362142880
diff --git a/model-00011-of-00044.safetensors b/model-00011-of-00044.safetensors
new file mode 100644
index 0000000000000000000000000000000000000000..03859c4fd4f5b2a67458df49dc256f50e231c56a
--- /dev/null
+++ b/model-00011-of-00044.safetensors
@@ -0,0 +1,3 @@
+version https://git-lfs.github.com/spec/v1
+oid sha256:5aa674d17abc5c72f877ce0ee4368215bc876d133af2ee92f194bc64927510c0
+size 4966188896
diff --git a/model-00012-of-00044.safetensors b/model-00012-of-00044.safetensors
new file mode 100644
index 0000000000000000000000000000000000000000..9ded0e57643ba98ed45ff88efed7f23f497ea459
--- /dev/null
+++ b/model-00012-of-00044.safetensors
@@ -0,0 +1,3 @@
+version https://git-lfs.github.com/spec/v1
+oid sha256:03e00d8591bb14ca4b89b57f6f0b98be1f699bcc7f5cbbfbc7a7589d9f365f34
+size 4362142888
diff --git a/model-00013-of-00044.safetensors b/model-00013-of-00044.safetensors
new file mode 100644
index 0000000000000000000000000000000000000000..6fc0378943e23c19b5e3b738a4099cf0f9fa231b
--- /dev/null
+++ b/model-00013-of-00044.safetensors
@@ -0,0 +1,3 @@
+version https://git-lfs.github.com/spec/v1
+oid sha256:ad931dd89f6550bb433caa5bb93eed80c40c2e7b8f51c98302da5f60edb4c5b3
+size 4362142880
diff --git a/model-00014-of-00044.safetensors b/model-00014-of-00044.safetensors
new file mode 100644
index 0000000000000000000000000000000000000000..446c263b51c7981183358adc242427186850a091
--- /dev/null
+++ b/model-00014-of-00044.safetensors
@@ -0,0 +1,3 @@
+version https://git-lfs.github.com/spec/v1
+oid sha256:a563497caeca286091d956821877c9b3d123f9930779a9f2b1a7be85093e1ec7
+size 4966188896
diff --git a/model-00015-of-00044.safetensors b/model-00015-of-00044.safetensors
new file mode 100644
index 0000000000000000000000000000000000000000..1095dcfaeb7e13534c9bdff3cfc89827a09b3f3b
--- /dev/null
+++ b/model-00015-of-00044.safetensors
@@ -0,0 +1,3 @@
+version https://git-lfs.github.com/spec/v1
+oid sha256:4856a9d1a7db05f3ffcf05a1be55ebe0ea0b57a21ec0ba191d873de58c6a0bdf
+size 4362142888
diff --git a/model-00016-of-00044.safetensors b/model-00016-of-00044.safetensors
new file mode 100644
index 0000000000000000000000000000000000000000..b8f136d833c03ef5523cd20e098160cad95463d8
--- /dev/null
+++ b/model-00016-of-00044.safetensors
@@ -0,0 +1,3 @@
+version https://git-lfs.github.com/spec/v1
+oid sha256:02f5d608de8c229fc9dd41a4b3003c69e26a9687d3876077353acc3ec1f0f9dc
+size 4362142880
diff --git a/model-00017-of-00044.safetensors b/model-00017-of-00044.safetensors
new file mode 100644
index 0000000000000000000000000000000000000000..a6f4f452ddad4d6ab039520c4c54b88480753c39
--- /dev/null
+++ b/model-00017-of-00044.safetensors
@@ -0,0 +1,3 @@
+version https://git-lfs.github.com/spec/v1
+oid sha256:2c69db3a3b324fda4fae4af859544b647a0acf33d9b9c241d4fc9b5c9ae8ed1e
+size 4966188896
diff --git a/model-00018-of-00044.safetensors b/model-00018-of-00044.safetensors
new file mode 100644
index 0000000000000000000000000000000000000000..f0932f1320e5d39e6563241fe85602de826c25d0
--- /dev/null
+++ b/model-00018-of-00044.safetensors
@@ -0,0 +1,3 @@
+version https://git-lfs.github.com/spec/v1
+oid sha256:081cf33c3540896d02569a2770d9402765773d92143c2efe1805314153ac9c5d
+size 4362142888
diff --git a/model-00019-of-00044.safetensors b/model-00019-of-00044.safetensors
new file mode 100644
index 0000000000000000000000000000000000000000..86f7f68e8d3208dc6633591fd56d7f802d8a3a51
--- /dev/null
+++ b/model-00019-of-00044.safetensors
@@ -0,0 +1,3 @@
+version https://git-lfs.github.com/spec/v1
+oid sha256:8638eff3c1192fb17a4858fd00c52f27193772045ffc22de670f2497f83e9cd9
+size 4362142880
diff --git a/model-00020-of-00044.safetensors b/model-00020-of-00044.safetensors
new file mode 100644
index 0000000000000000000000000000000000000000..5c384e823bec3e85008c82e692a80974f336c4dd
--- /dev/null
+++ b/model-00020-of-00044.safetensors
@@ -0,0 +1,3 @@
+version https://git-lfs.github.com/spec/v1
+oid sha256:adff119059ae002e787cd97d7b0167554a05f87a85881204b36b9ad559f22863
+size 4966188896
diff --git a/model-00021-of-00044.safetensors b/model-00021-of-00044.safetensors
new file mode 100644
index 0000000000000000000000000000000000000000..efdf911010eb89cbef73250db7d1a312b1dc89ec
--- /dev/null
+++ b/model-00021-of-00044.safetensors
@@ -0,0 +1,3 @@
+version https://git-lfs.github.com/spec/v1
+oid sha256:8d26d727ab454581e43d13a59d4ad8c3a99c6f970aab86b89c61e6f269a476ef
+size 4362142888
diff --git a/model-00022-of-00044.safetensors b/model-00022-of-00044.safetensors
new file mode 100644
index 0000000000000000000000000000000000000000..63eb865527daa43ddd64937058fd9361ee885f3c
--- /dev/null
+++ b/model-00022-of-00044.safetensors
@@ -0,0 +1,3 @@
+version https://git-lfs.github.com/spec/v1
+oid sha256:74c39c02eb45480c8e0882a3a4b9881af7e5d09d2d3f463c52f2a837c4b44e64
+size 4362142880
diff --git a/model-00023-of-00044.safetensors b/model-00023-of-00044.safetensors
new file mode 100644
index 0000000000000000000000000000000000000000..84999a661f036c6ba2cd19615cd55a8f51622e54
--- /dev/null
+++ b/model-00023-of-00044.safetensors
@@ -0,0 +1,3 @@
+version https://git-lfs.github.com/spec/v1
+oid sha256:04f3865584321ec5621f3f3ea55220d51a0238516d98793c87bd57c6e61041f6
+size 4966188896
diff --git a/model-00024-of-00044.safetensors b/model-00024-of-00044.safetensors
new file mode 100644
index 0000000000000000000000000000000000000000..adae97df173b36fb0c2801bcf80ab9ed3f6ac2f5
--- /dev/null
+++ b/model-00024-of-00044.safetensors
@@ -0,0 +1,3 @@
+version https://git-lfs.github.com/spec/v1
+oid sha256:cdafb4418cc5ffac5f87a46c34063357b77726b7fb7d6ff3995b8c737c0a9961
+size 4362142888
diff --git a/model-00025-of-00044.safetensors b/model-00025-of-00044.safetensors
new file mode 100644
index 0000000000000000000000000000000000000000..7b3650e759cc1fe6eb2d13cb16821620e3b9fa1a
--- /dev/null
+++ b/model-00025-of-00044.safetensors
@@ -0,0 +1,3 @@
+version https://git-lfs.github.com/spec/v1
+oid sha256:10279b256ccb6e649172e740dab73f9c6c9cdc91750701c66f65ca3ceedb6069
+size 4362142880
diff --git a/model-00026-of-00044.safetensors b/model-00026-of-00044.safetensors
new file mode 100644
index 0000000000000000000000000000000000000000..feae60018e1ca7443fd02ee0e1018c09429bd5e3
--- /dev/null
+++ b/model-00026-of-00044.safetensors
@@ -0,0 +1,3 @@
+version https://git-lfs.github.com/spec/v1
+oid sha256:f35464d9cfc2af265ea507386ce1b12db668b7771bade9c529be807519a38978
+size 4966188896
diff --git a/model-00027-of-00044.safetensors b/model-00027-of-00044.safetensors
new file mode 100644
index 0000000000000000000000000000000000000000..15045dd04cb87715c6ff4577af477c51724727fd
--- /dev/null
+++ b/model-00027-of-00044.safetensors
@@ -0,0 +1,3 @@
+version https://git-lfs.github.com/spec/v1
+oid sha256:63578d387cc9941bc7bff2024f937df2340d53a2151e7e5eb1ea06db44c4193c
+size 4362142888
diff --git a/model-00028-of-00044.safetensors b/model-00028-of-00044.safetensors
new file mode 100644
index 0000000000000000000000000000000000000000..e69a0cfaba521b9023f790bc155cf2e94e3d41b6
--- /dev/null
+++ b/model-00028-of-00044.safetensors
@@ -0,0 +1,3 @@
+version https://git-lfs.github.com/spec/v1
+oid sha256:88520d5fd37f9014c0d5301193bce090cce3f1d0b122abb354c34284d5b9f1cc
+size 4362142880
diff --git a/model-00029-of-00044.safetensors b/model-00029-of-00044.safetensors
new file mode 100644
index 0000000000000000000000000000000000000000..a152d5c689b9435d555586d12eb7a776d01c781e
--- /dev/null
+++ b/model-00029-of-00044.safetensors
@@ -0,0 +1,3 @@
+version https://git-lfs.github.com/spec/v1
+oid sha256:4ba270319ac7bf916210c48952580277f580367bdce0d194ba2ee933bca3cd2a
+size 4966188896
diff --git a/model-00030-of-00044.safetensors b/model-00030-of-00044.safetensors
new file mode 100644
index 0000000000000000000000000000000000000000..586a53e6b38c428c227d7eab8ca5da0004d61eca
--- /dev/null
+++ b/model-00030-of-00044.safetensors
@@ -0,0 +1,3 @@
+version https://git-lfs.github.com/spec/v1
+oid sha256:467c34b866a722a7c5a3102c6772390edc6e22db117656fb18bc960af6b8a61f
+size 4362142888
diff --git a/model-00031-of-00044.safetensors b/model-00031-of-00044.safetensors
new file mode 100644
index 0000000000000000000000000000000000000000..34371a48822a055f97181feae175c3a2ffa79c00
--- /dev/null
+++ b/model-00031-of-00044.safetensors
@@ -0,0 +1,3 @@
+version https://git-lfs.github.com/spec/v1
+oid sha256:011337d0209849f9773cf07df778d7c7cabc7054c2891f48ad95a1e76478add2
+size 4932601336
diff --git a/model-00032-of-00044.safetensors b/model-00032-of-00044.safetensors
new file mode 100644
index 0000000000000000000000000000000000000000..df35b92cbd235cc387ed2f301b88750261cce69b
--- /dev/null
+++ b/model-00032-of-00044.safetensors
@@ -0,0 +1,3 @@
+version https://git-lfs.github.com/spec/v1
+oid sha256:14809feeb5516733cfeab0eb8a2ebf7de1b3f01ef4431f332a6e487f611adef5
+size 4697753200
diff --git a/model-00033-of-00044.safetensors b/model-00033-of-00044.safetensors
new file mode 100644
index 0000000000000000000000000000000000000000..7b0bb9c05824f122d12e59885535e6604e83adad
--- /dev/null
+++ b/model-00033-of-00044.safetensors
@@ -0,0 +1,3 @@
+version https://git-lfs.github.com/spec/v1
+oid sha256:bd4b7dc7f5bdee3dc1727a6b277e47c62b08e5d63572151cecac90328c57712a
+size 4127361432
diff --git a/model-00034-of-00044.safetensors b/model-00034-of-00044.safetensors
new file mode 100644
index 0000000000000000000000000000000000000000..919a0aa437196b33f8e07e1b3e790151a0f1820c
--- /dev/null
+++ b/model-00034-of-00044.safetensors
@@ -0,0 +1,3 @@
+version https://git-lfs.github.com/spec/v1
+oid sha256:39db6314f988891b2bb183fd95b03ab34f280fade1d6bdde6e53746ba709b021
+size 4865525704
diff --git a/model-00035-of-00044.safetensors b/model-00035-of-00044.safetensors
new file mode 100644
index 0000000000000000000000000000000000000000..4013f4f0e9c98adbda30ed1143cab959bac09f5f
--- /dev/null
+++ b/model-00035-of-00044.safetensors
@@ -0,0 +1,3 @@
+version https://git-lfs.github.com/spec/v1
+oid sha256:4154e1d7710b15bdd69703a60e6a26737ad49e2b12e1edd56da23329c4afda07
+size 4832137288
diff --git a/model-00036-of-00044.safetensors b/model-00036-of-00044.safetensors
new file mode 100644
index 0000000000000000000000000000000000000000..dd125956768cb8c75e00d4e36327264102a4b48b
--- /dev/null
+++ b/model-00036-of-00044.safetensors
@@ -0,0 +1,3 @@
+version https://git-lfs.github.com/spec/v1
+oid sha256:a7f39b81b2f5cb535831b16771597651a6b85731f81d0bc528c8b8ec56604e97
+size 4513303944
diff --git a/model-00037-of-00044.safetensors b/model-00037-of-00044.safetensors
new file mode 100644
index 0000000000000000000000000000000000000000..72b354327018fd79dcc2bfe2a0128d24b773fed4
--- /dev/null
+++ b/model-00037-of-00044.safetensors
@@ -0,0 +1,3 @@
+version https://git-lfs.github.com/spec/v1
+oid sha256:aa78d7d7ec86ff135f9fdd52be9f4746f42cef54a7f53eb812e984c95f74bfb8
+size 4966188896
diff --git a/model-00038-of-00044.safetensors b/model-00038-of-00044.safetensors
new file mode 100644
index 0000000000000000000000000000000000000000..970724a7e06574fc050324d865247a56d91623b1
--- /dev/null
+++ b/model-00038-of-00044.safetensors
@@ -0,0 +1,3 @@
+version https://git-lfs.github.com/spec/v1
+oid sha256:c71612a0be36eaacb10051e6e4058b5a59913dc193d25fc07d4f9029e11997fa
+size 4362142888
diff --git a/model-00039-of-00044.safetensors b/model-00039-of-00044.safetensors
new file mode 100644
index 0000000000000000000000000000000000000000..1219c3ba4071301e29ebd6446e17782fe82c717f
--- /dev/null
+++ b/model-00039-of-00044.safetensors
@@ -0,0 +1,3 @@
+version https://git-lfs.github.com/spec/v1
+oid sha256:6199afdda299b805a404341f363e6870a46c425861dbe6ebe2f1e9aa257c5d75
+size 4362142880
diff --git a/model-00040-of-00044.safetensors b/model-00040-of-00044.safetensors
new file mode 100644
index 0000000000000000000000000000000000000000..b6c7a3e7cdaff6465923895d6c22299174adf371
--- /dev/null
+++ b/model-00040-of-00044.safetensors
@@ -0,0 +1,3 @@
+version https://git-lfs.github.com/spec/v1
+oid sha256:15633ace97f6fa29224022af4e8a66549344e0d55c8c24e577026fafb70fbce4
+size 4966188896
diff --git a/model-00041-of-00044.safetensors b/model-00041-of-00044.safetensors
new file mode 100644
index 0000000000000000000000000000000000000000..1a6992facb1d2c4241dd89b4cbb780a08753d6a6
--- /dev/null
+++ b/model-00041-of-00044.safetensors
@@ -0,0 +1,3 @@
+version https://git-lfs.github.com/spec/v1
+oid sha256:588b759a0eb33485e3b0f8caeae179531dc4cb6026f0ffcd4d05faf51f4c5b5c
+size 4362142888
diff --git a/model-00042-of-00044.safetensors b/model-00042-of-00044.safetensors
new file mode 100644
index 0000000000000000000000000000000000000000..f5ec94fec599bf2e608e8caed06088141f87fd11
--- /dev/null
+++ b/model-00042-of-00044.safetensors
@@ -0,0 +1,3 @@
+version https://git-lfs.github.com/spec/v1
+oid sha256:fae3c91f22829d818bee873f665ebdc028622f36a12d01c0d9bb2c11124dd985
+size 4362142880
diff --git a/model-00043-of-00044.safetensors b/model-00043-of-00044.safetensors
new file mode 100644
index 0000000000000000000000000000000000000000..6ba31213ed78a9fa7bf97632c6cd9ce6e2f6633d
--- /dev/null
+++ b/model-00043-of-00044.safetensors
@@ -0,0 +1,3 @@
+version https://git-lfs.github.com/spec/v1
+oid sha256:a616b8fe07e2461e464b6ab105c866f5ababf6473f324aa8e0f107501f9aae20
+size 939557104
diff --git a/model-00044-of-00044.safetensors b/model-00044-of-00044.safetensors
new file mode 100644
index 0000000000000000000000000000000000000000..8c655e305b18ad834a76ae33c5e28d1f6d791ce7
--- /dev/null
+++ b/model-00044-of-00044.safetensors
@@ -0,0 +1,3 @@
+version https://git-lfs.github.com/spec/v1
+oid sha256:8a59545fc83f5a4dc160bf2a985233907549d5a67bfa40dac0a8b05c564f34f3
+size 4202692736
diff --git a/model.safetensors.index.json b/model.safetensors.index.json
new file mode 100644
index 0000000000000000000000000000000000000000..7263e497a47e2e8665948374ce9f6d128e967336
--- /dev/null
+++ b/model.safetensors.index.json
@@ -0,0 +1,576 @@
+{
+ "metadata": {
+ "total_parameters": 49867145216,
+ "total_size": 199468580864
+ },
+ "weight_map": {
+ "lm_head.weight": "model-00044-of-00044.safetensors",
+ "model.embed_tokens.weight": "model-00001-of-00044.safetensors",
+ "model.layers.0.input_layernorm.weight": "model-00001-of-00044.safetensors",
+ "model.layers.0.mlp.down_proj.weight": "model-00002-of-00044.safetensors",
+ "model.layers.0.mlp.gate_proj.weight": "model-00002-of-00044.safetensors",
+ "model.layers.0.mlp.up_proj.weight": "model-00002-of-00044.safetensors",
+ "model.layers.0.post_attention_layernorm.weight": "model-00001-of-00044.safetensors",
+ "model.layers.0.self_attn.k_proj.weight": "model-00001-of-00044.safetensors",
+ "model.layers.0.self_attn.o_proj.weight": "model-00001-of-00044.safetensors",
+ "model.layers.0.self_attn.q_proj.weight": "model-00001-of-00044.safetensors",
+ "model.layers.0.self_attn.v_proj.weight": "model-00001-of-00044.safetensors",
+ "model.layers.1.input_layernorm.weight": "model-00002-of-00044.safetensors",
+ "model.layers.1.mlp.down_proj.weight": "model-00002-of-00044.safetensors",
+ "model.layers.1.mlp.gate_proj.weight": "model-00002-of-00044.safetensors",
+ "model.layers.1.mlp.up_proj.weight": "model-00002-of-00044.safetensors",
+ "model.layers.1.post_attention_layernorm.weight": "model-00002-of-00044.safetensors",
+ "model.layers.1.self_attn.k_proj.weight": "model-00002-of-00044.safetensors",
+ "model.layers.1.self_attn.o_proj.weight": "model-00002-of-00044.safetensors",
+ "model.layers.1.self_attn.q_proj.weight": "model-00002-of-00044.safetensors",
+ "model.layers.1.self_attn.v_proj.weight": "model-00002-of-00044.safetensors",
+ "model.layers.10.input_layernorm.weight": "model-00007-of-00044.safetensors",
+ "model.layers.10.mlp.down_proj.weight": "model-00008-of-00044.safetensors",
+ "model.layers.10.mlp.gate_proj.weight": "model-00008-of-00044.safetensors",
+ "model.layers.10.mlp.up_proj.weight": "model-00008-of-00044.safetensors",
+ "model.layers.10.post_attention_layernorm.weight": "model-00007-of-00044.safetensors",
+ "model.layers.10.self_attn.k_proj.weight": "model-00007-of-00044.safetensors",
+ "model.layers.10.self_attn.o_proj.weight": "model-00007-of-00044.safetensors",
+ "model.layers.10.self_attn.q_proj.weight": "model-00007-of-00044.safetensors",
+ "model.layers.10.self_attn.v_proj.weight": "model-00007-of-00044.safetensors",
+ "model.layers.11.mlp.down_proj.weight": "model-00008-of-00044.safetensors",
+ "model.layers.11.mlp.gate_proj.weight": "model-00008-of-00044.safetensors",
+ "model.layers.11.mlp.up_proj.weight": "model-00008-of-00044.safetensors",
+ "model.layers.11.post_attention_layernorm.weight": "model-00008-of-00044.safetensors",
+ "model.layers.12.input_layernorm.weight": "model-00008-of-00044.safetensors",
+ "model.layers.12.mlp.down_proj.weight": "model-00009-of-00044.safetensors",
+ "model.layers.12.mlp.gate_proj.weight": "model-00009-of-00044.safetensors",
+ "model.layers.12.mlp.up_proj.weight": "model-00009-of-00044.safetensors",
+ "model.layers.12.post_attention_layernorm.weight": "model-00009-of-00044.safetensors",
+ "model.layers.12.self_attn.k_proj.weight": "model-00008-of-00044.safetensors",
+ "model.layers.12.self_attn.o_proj.weight": "model-00009-of-00044.safetensors",
+ "model.layers.12.self_attn.q_proj.weight": "model-00008-of-00044.safetensors",
+ "model.layers.12.self_attn.v_proj.weight": "model-00008-of-00044.safetensors",
+ "model.layers.13.input_layernorm.weight": "model-00009-of-00044.safetensors",
+ "model.layers.13.mlp.down_proj.weight": "model-00010-of-00044.safetensors",
+ "model.layers.13.mlp.gate_proj.weight": "model-00009-of-00044.safetensors",
+ "model.layers.13.mlp.up_proj.weight": "model-00010-of-00044.safetensors",
+ "model.layers.13.post_attention_layernorm.weight": "model-00009-of-00044.safetensors",
+ "model.layers.13.self_attn.k_proj.weight": "model-00009-of-00044.safetensors",
+ "model.layers.13.self_attn.o_proj.weight": "model-00009-of-00044.safetensors",
+ "model.layers.13.self_attn.q_proj.weight": "model-00009-of-00044.safetensors",
+ "model.layers.13.self_attn.v_proj.weight": "model-00009-of-00044.safetensors",
+ "model.layers.14.input_layernorm.weight": "model-00010-of-00044.safetensors",
+ "model.layers.14.mlp.down_proj.weight": "model-00011-of-00044.safetensors",
+ "model.layers.14.mlp.gate_proj.weight": "model-00010-of-00044.safetensors",
+ "model.layers.14.mlp.up_proj.weight": "model-00010-of-00044.safetensors",
+ "model.layers.14.post_attention_layernorm.weight": "model-00010-of-00044.safetensors",
+ "model.layers.14.self_attn.k_proj.weight": "model-00010-of-00044.safetensors",
+ "model.layers.14.self_attn.o_proj.weight": "model-00010-of-00044.safetensors",
+ "model.layers.14.self_attn.q_proj.weight": "model-00010-of-00044.safetensors",
+ "model.layers.14.self_attn.v_proj.weight": "model-00010-of-00044.safetensors",
+ "model.layers.15.input_layernorm.weight": "model-00011-of-00044.safetensors",
+ "model.layers.15.mlp.down_proj.weight": "model-00011-of-00044.safetensors",
+ "model.layers.15.mlp.gate_proj.weight": "model-00011-of-00044.safetensors",
+ "model.layers.15.mlp.up_proj.weight": "model-00011-of-00044.safetensors",
+ "model.layers.15.post_attention_layernorm.weight": "model-00011-of-00044.safetensors",
+ "model.layers.15.self_attn.k_proj.weight": "model-00011-of-00044.safetensors",
+ "model.layers.15.self_attn.o_proj.weight": "model-00011-of-00044.safetensors",
+ "model.layers.15.self_attn.q_proj.weight": "model-00011-of-00044.safetensors",
+ "model.layers.15.self_attn.v_proj.weight": "model-00011-of-00044.safetensors",
+ "model.layers.16.input_layernorm.weight": "model-00011-of-00044.safetensors",
+ "model.layers.16.mlp.down_proj.weight": "model-00012-of-00044.safetensors",
+ "model.layers.16.mlp.gate_proj.weight": "model-00012-of-00044.safetensors",
+ "model.layers.16.mlp.up_proj.weight": "model-00012-of-00044.safetensors",
+ "model.layers.16.post_attention_layernorm.weight": "model-00011-of-00044.safetensors",
+ "model.layers.16.self_attn.k_proj.weight": "model-00011-of-00044.safetensors",
+ "model.layers.16.self_attn.o_proj.weight": "model-00011-of-00044.safetensors",
+ "model.layers.16.self_attn.q_proj.weight": "model-00011-of-00044.safetensors",
+ "model.layers.16.self_attn.v_proj.weight": "model-00011-of-00044.safetensors",
+ "model.layers.17.input_layernorm.weight": "model-00012-of-00044.safetensors",
+ "model.layers.17.mlp.down_proj.weight": "model-00013-of-00044.safetensors",
+ "model.layers.17.mlp.gate_proj.weight": "model-00012-of-00044.safetensors",
+ "model.layers.17.mlp.up_proj.weight": "model-00013-of-00044.safetensors",
+ "model.layers.17.post_attention_layernorm.weight": "model-00012-of-00044.safetensors",
+ "model.layers.17.self_attn.k_proj.weight": "model-00012-of-00044.safetensors",
+ "model.layers.17.self_attn.o_proj.weight": "model-00012-of-00044.safetensors",
+ "model.layers.17.self_attn.q_proj.weight": "model-00012-of-00044.safetensors",
+ "model.layers.17.self_attn.v_proj.weight": "model-00012-of-00044.safetensors",
+ "model.layers.18.input_layernorm.weight": "model-00013-of-00044.safetensors",
+ "model.layers.18.mlp.down_proj.weight": "model-00014-of-00044.safetensors",
+ "model.layers.18.mlp.gate_proj.weight": "model-00013-of-00044.safetensors",
+ "model.layers.18.mlp.up_proj.weight": "model-00013-of-00044.safetensors",
+ "model.layers.18.post_attention_layernorm.weight": "model-00013-of-00044.safetensors",
+ "model.layers.18.self_attn.k_proj.weight": "model-00013-of-00044.safetensors",
+ "model.layers.18.self_attn.o_proj.weight": "model-00013-of-00044.safetensors",
+ "model.layers.18.self_attn.q_proj.weight": "model-00013-of-00044.safetensors",
+ "model.layers.18.self_attn.v_proj.weight": "model-00013-of-00044.safetensors",
+ "model.layers.19.input_layernorm.weight": "model-00014-of-00044.safetensors",
+ "model.layers.19.mlp.down_proj.weight": "model-00014-of-00044.safetensors",
+ "model.layers.19.mlp.gate_proj.weight": "model-00014-of-00044.safetensors",
+ "model.layers.19.mlp.up_proj.weight": "model-00014-of-00044.safetensors",
+ "model.layers.19.post_attention_layernorm.weight": "model-00014-of-00044.safetensors",
+ "model.layers.19.self_attn.k_proj.weight": "model-00014-of-00044.safetensors",
+ "model.layers.19.self_attn.o_proj.weight": "model-00014-of-00044.safetensors",
+ "model.layers.19.self_attn.q_proj.weight": "model-00014-of-00044.safetensors",
+ "model.layers.19.self_attn.v_proj.weight": "model-00014-of-00044.safetensors",
+ "model.layers.2.input_layernorm.weight": "model-00002-of-00044.safetensors",
+ "model.layers.2.mlp.down_proj.weight": "model-00003-of-00044.safetensors",
+ "model.layers.2.mlp.gate_proj.weight": "model-00003-of-00044.safetensors",
+ "model.layers.2.mlp.up_proj.weight": "model-00003-of-00044.safetensors",
+ "model.layers.2.post_attention_layernorm.weight": "model-00003-of-00044.safetensors",
+ "model.layers.2.self_attn.k_proj.weight": "model-00003-of-00044.safetensors",
+ "model.layers.2.self_attn.o_proj.weight": "model-00003-of-00044.safetensors",
+ "model.layers.2.self_attn.q_proj.weight": "model-00003-of-00044.safetensors",
+ "model.layers.2.self_attn.v_proj.weight": "model-00003-of-00044.safetensors",
+ "model.layers.20.input_layernorm.weight": "model-00014-of-00044.safetensors",
+ "model.layers.20.mlp.down_proj.weight": "model-00015-of-00044.safetensors",
+ "model.layers.20.mlp.gate_proj.weight": "model-00015-of-00044.safetensors",
+ "model.layers.20.mlp.up_proj.weight": "model-00015-of-00044.safetensors",
+ "model.layers.20.post_attention_layernorm.weight": "model-00014-of-00044.safetensors",
+ "model.layers.20.self_attn.k_proj.weight": "model-00014-of-00044.safetensors",
+ "model.layers.20.self_attn.o_proj.weight": "model-00014-of-00044.safetensors",
+ "model.layers.20.self_attn.q_proj.weight": "model-00014-of-00044.safetensors",
+ "model.layers.20.self_attn.v_proj.weight": "model-00014-of-00044.safetensors",
+ "model.layers.21.input_layernorm.weight": "model-00015-of-00044.safetensors",
+ "model.layers.21.mlp.down_proj.weight": "model-00016-of-00044.safetensors",
+ "model.layers.21.mlp.gate_proj.weight": "model-00015-of-00044.safetensors",
+ "model.layers.21.mlp.up_proj.weight": "model-00016-of-00044.safetensors",
+ "model.layers.21.post_attention_layernorm.weight": "model-00015-of-00044.safetensors",
+ "model.layers.21.self_attn.k_proj.weight": "model-00015-of-00044.safetensors",
+ "model.layers.21.self_attn.o_proj.weight": "model-00015-of-00044.safetensors",
+ "model.layers.21.self_attn.q_proj.weight": "model-00015-of-00044.safetensors",
+ "model.layers.21.self_attn.v_proj.weight": "model-00015-of-00044.safetensors",
+ "model.layers.22.input_layernorm.weight": "model-00016-of-00044.safetensors",
+ "model.layers.22.mlp.down_proj.weight": "model-00017-of-00044.safetensors",
+ "model.layers.22.mlp.gate_proj.weight": "model-00016-of-00044.safetensors",
+ "model.layers.22.mlp.up_proj.weight": "model-00016-of-00044.safetensors",
+ "model.layers.22.post_attention_layernorm.weight": "model-00016-of-00044.safetensors",
+ "model.layers.22.self_attn.k_proj.weight": "model-00016-of-00044.safetensors",
+ "model.layers.22.self_attn.o_proj.weight": "model-00016-of-00044.safetensors",
+ "model.layers.22.self_attn.q_proj.weight": "model-00016-of-00044.safetensors",
+ "model.layers.22.self_attn.v_proj.weight": "model-00016-of-00044.safetensors",
+ "model.layers.23.input_layernorm.weight": "model-00017-of-00044.safetensors",
+ "model.layers.23.mlp.down_proj.weight": "model-00017-of-00044.safetensors",
+ "model.layers.23.mlp.gate_proj.weight": "model-00017-of-00044.safetensors",
+ "model.layers.23.mlp.up_proj.weight": "model-00017-of-00044.safetensors",
+ "model.layers.23.post_attention_layernorm.weight": "model-00017-of-00044.safetensors",
+ "model.layers.23.self_attn.k_proj.weight": "model-00017-of-00044.safetensors",
+ "model.layers.23.self_attn.o_proj.weight": "model-00017-of-00044.safetensors",
+ "model.layers.23.self_attn.q_proj.weight": "model-00017-of-00044.safetensors",
+ "model.layers.23.self_attn.v_proj.weight": "model-00017-of-00044.safetensors",
+ "model.layers.24.input_layernorm.weight": "model-00017-of-00044.safetensors",
+ "model.layers.24.mlp.down_proj.weight": "model-00018-of-00044.safetensors",
+ "model.layers.24.mlp.gate_proj.weight": "model-00018-of-00044.safetensors",
+ "model.layers.24.mlp.up_proj.weight": "model-00018-of-00044.safetensors",
+ "model.layers.24.post_attention_layernorm.weight": "model-00017-of-00044.safetensors",
+ "model.layers.24.self_attn.k_proj.weight": "model-00017-of-00044.safetensors",
+ "model.layers.24.self_attn.o_proj.weight": "model-00017-of-00044.safetensors",
+ "model.layers.24.self_attn.q_proj.weight": "model-00017-of-00044.safetensors",
+ "model.layers.24.self_attn.v_proj.weight": "model-00017-of-00044.safetensors",
+ "model.layers.25.input_layernorm.weight": "model-00018-of-00044.safetensors",
+ "model.layers.25.mlp.down_proj.weight": "model-00019-of-00044.safetensors",
+ "model.layers.25.mlp.gate_proj.weight": "model-00018-of-00044.safetensors",
+ "model.layers.25.mlp.up_proj.weight": "model-00019-of-00044.safetensors",
+ "model.layers.25.post_attention_layernorm.weight": "model-00018-of-00044.safetensors",
+ "model.layers.25.self_attn.k_proj.weight": "model-00018-of-00044.safetensors",
+ "model.layers.25.self_attn.o_proj.weight": "model-00018-of-00044.safetensors",
+ "model.layers.25.self_attn.q_proj.weight": "model-00018-of-00044.safetensors",
+ "model.layers.25.self_attn.v_proj.weight": "model-00018-of-00044.safetensors",
+ "model.layers.26.input_layernorm.weight": "model-00019-of-00044.safetensors",
+ "model.layers.26.mlp.down_proj.weight": "model-00020-of-00044.safetensors",
+ "model.layers.26.mlp.gate_proj.weight": "model-00019-of-00044.safetensors",
+ "model.layers.26.mlp.up_proj.weight": "model-00019-of-00044.safetensors",
+ "model.layers.26.post_attention_layernorm.weight": "model-00019-of-00044.safetensors",
+ "model.layers.26.self_attn.k_proj.weight": "model-00019-of-00044.safetensors",
+ "model.layers.26.self_attn.o_proj.weight": "model-00019-of-00044.safetensors",
+ "model.layers.26.self_attn.q_proj.weight": "model-00019-of-00044.safetensors",
+ "model.layers.26.self_attn.v_proj.weight": "model-00019-of-00044.safetensors",
+ "model.layers.27.input_layernorm.weight": "model-00020-of-00044.safetensors",
+ "model.layers.27.mlp.down_proj.weight": "model-00020-of-00044.safetensors",
+ "model.layers.27.mlp.gate_proj.weight": "model-00020-of-00044.safetensors",
+ "model.layers.27.mlp.up_proj.weight": "model-00020-of-00044.safetensors",
+ "model.layers.27.post_attention_layernorm.weight": "model-00020-of-00044.safetensors",
+ "model.layers.27.self_attn.k_proj.weight": "model-00020-of-00044.safetensors",
+ "model.layers.27.self_attn.o_proj.weight": "model-00020-of-00044.safetensors",
+ "model.layers.27.self_attn.q_proj.weight": "model-00020-of-00044.safetensors",
+ "model.layers.27.self_attn.v_proj.weight": "model-00020-of-00044.safetensors",
+ "model.layers.28.input_layernorm.weight": "model-00020-of-00044.safetensors",
+ "model.layers.28.mlp.down_proj.weight": "model-00021-of-00044.safetensors",
+ "model.layers.28.mlp.gate_proj.weight": "model-00021-of-00044.safetensors",
+ "model.layers.28.mlp.up_proj.weight": "model-00021-of-00044.safetensors",
+ "model.layers.28.post_attention_layernorm.weight": "model-00020-of-00044.safetensors",
+ "model.layers.28.self_attn.k_proj.weight": "model-00020-of-00044.safetensors",
+ "model.layers.28.self_attn.o_proj.weight": "model-00020-of-00044.safetensors",
+ "model.layers.28.self_attn.q_proj.weight": "model-00020-of-00044.safetensors",
+ "model.layers.28.self_attn.v_proj.weight": "model-00020-of-00044.safetensors",
+ "model.layers.29.input_layernorm.weight": "model-00021-of-00044.safetensors",
+ "model.layers.29.mlp.down_proj.weight": "model-00022-of-00044.safetensors",
+ "model.layers.29.mlp.gate_proj.weight": "model-00021-of-00044.safetensors",
+ "model.layers.29.mlp.up_proj.weight": "model-00022-of-00044.safetensors",
+ "model.layers.29.post_attention_layernorm.weight": "model-00021-of-00044.safetensors",
+ "model.layers.29.self_attn.k_proj.weight": "model-00021-of-00044.safetensors",
+ "model.layers.29.self_attn.o_proj.weight": "model-00021-of-00044.safetensors",
+ "model.layers.29.self_attn.q_proj.weight": "model-00021-of-00044.safetensors",
+ "model.layers.29.self_attn.v_proj.weight": "model-00021-of-00044.safetensors",
+ "model.layers.3.input_layernorm.weight": "model-00003-of-00044.safetensors",
+ "model.layers.3.mlp.down_proj.weight": "model-00004-of-00044.safetensors",
+ "model.layers.3.mlp.gate_proj.weight": "model-00003-of-00044.safetensors",
+ "model.layers.3.mlp.up_proj.weight": "model-00004-of-00044.safetensors",
+ "model.layers.3.post_attention_layernorm.weight": "model-00003-of-00044.safetensors",
+ "model.layers.3.self_attn.k_proj.weight": "model-00003-of-00044.safetensors",
+ "model.layers.3.self_attn.o_proj.weight": "model-00003-of-00044.safetensors",
+ "model.layers.3.self_attn.q_proj.weight": "model-00003-of-00044.safetensors",
+ "model.layers.3.self_attn.v_proj.weight": "model-00003-of-00044.safetensors",
+ "model.layers.30.input_layernorm.weight": "model-00022-of-00044.safetensors",
+ "model.layers.30.mlp.down_proj.weight": "model-00023-of-00044.safetensors",
+ "model.layers.30.mlp.gate_proj.weight": "model-00022-of-00044.safetensors",
+ "model.layers.30.mlp.up_proj.weight": "model-00022-of-00044.safetensors",
+ "model.layers.30.post_attention_layernorm.weight": "model-00022-of-00044.safetensors",
+ "model.layers.30.self_attn.k_proj.weight": "model-00022-of-00044.safetensors",
+ "model.layers.30.self_attn.o_proj.weight": "model-00022-of-00044.safetensors",
+ "model.layers.30.self_attn.q_proj.weight": "model-00022-of-00044.safetensors",
+ "model.layers.30.self_attn.v_proj.weight": "model-00022-of-00044.safetensors",
+ "model.layers.31.input_layernorm.weight": "model-00023-of-00044.safetensors",
+ "model.layers.31.mlp.down_proj.weight": "model-00023-of-00044.safetensors",
+ "model.layers.31.mlp.gate_proj.weight": "model-00023-of-00044.safetensors",
+ "model.layers.31.mlp.up_proj.weight": "model-00023-of-00044.safetensors",
+ "model.layers.31.post_attention_layernorm.weight": "model-00023-of-00044.safetensors",
+ "model.layers.31.self_attn.k_proj.weight": "model-00023-of-00044.safetensors",
+ "model.layers.31.self_attn.o_proj.weight": "model-00023-of-00044.safetensors",
+ "model.layers.31.self_attn.q_proj.weight": "model-00023-of-00044.safetensors",
+ "model.layers.31.self_attn.v_proj.weight": "model-00023-of-00044.safetensors",
+ "model.layers.32.input_layernorm.weight": "model-00023-of-00044.safetensors",
+ "model.layers.32.mlp.down_proj.weight": "model-00024-of-00044.safetensors",
+ "model.layers.32.mlp.gate_proj.weight": "model-00024-of-00044.safetensors",
+ "model.layers.32.mlp.up_proj.weight": "model-00024-of-00044.safetensors",
+ "model.layers.32.post_attention_layernorm.weight": "model-00023-of-00044.safetensors",
+ "model.layers.32.self_attn.k_proj.weight": "model-00023-of-00044.safetensors",
+ "model.layers.32.self_attn.o_proj.weight": "model-00023-of-00044.safetensors",
+ "model.layers.32.self_attn.q_proj.weight": "model-00023-of-00044.safetensors",
+ "model.layers.32.self_attn.v_proj.weight": "model-00023-of-00044.safetensors",
+ "model.layers.33.input_layernorm.weight": "model-00024-of-00044.safetensors",
+ "model.layers.33.mlp.down_proj.weight": "model-00025-of-00044.safetensors",
+ "model.layers.33.mlp.gate_proj.weight": "model-00024-of-00044.safetensors",
+ "model.layers.33.mlp.up_proj.weight": "model-00025-of-00044.safetensors",
+ "model.layers.33.post_attention_layernorm.weight": "model-00024-of-00044.safetensors",
+ "model.layers.33.self_attn.k_proj.weight": "model-00024-of-00044.safetensors",
+ "model.layers.33.self_attn.o_proj.weight": "model-00024-of-00044.safetensors",
+ "model.layers.33.self_attn.q_proj.weight": "model-00024-of-00044.safetensors",
+ "model.layers.33.self_attn.v_proj.weight": "model-00024-of-00044.safetensors",
+ "model.layers.34.input_layernorm.weight": "model-00025-of-00044.safetensors",
+ "model.layers.34.mlp.down_proj.weight": "model-00026-of-00044.safetensors",
+ "model.layers.34.mlp.gate_proj.weight": "model-00025-of-00044.safetensors",
+ "model.layers.34.mlp.up_proj.weight": "model-00025-of-00044.safetensors",
+ "model.layers.34.post_attention_layernorm.weight": "model-00025-of-00044.safetensors",
+ "model.layers.34.self_attn.k_proj.weight": "model-00025-of-00044.safetensors",
+ "model.layers.34.self_attn.o_proj.weight": "model-00025-of-00044.safetensors",
+ "model.layers.34.self_attn.q_proj.weight": "model-00025-of-00044.safetensors",
+ "model.layers.34.self_attn.v_proj.weight": "model-00025-of-00044.safetensors",
+ "model.layers.35.input_layernorm.weight": "model-00026-of-00044.safetensors",
+ "model.layers.35.mlp.down_proj.weight": "model-00026-of-00044.safetensors",
+ "model.layers.35.mlp.gate_proj.weight": "model-00026-of-00044.safetensors",
+ "model.layers.35.mlp.up_proj.weight": "model-00026-of-00044.safetensors",
+ "model.layers.35.post_attention_layernorm.weight": "model-00026-of-00044.safetensors",
+ "model.layers.35.self_attn.k_proj.weight": "model-00026-of-00044.safetensors",
+ "model.layers.35.self_attn.o_proj.weight": "model-00026-of-00044.safetensors",
+ "model.layers.35.self_attn.q_proj.weight": "model-00026-of-00044.safetensors",
+ "model.layers.35.self_attn.v_proj.weight": "model-00026-of-00044.safetensors",
+ "model.layers.36.input_layernorm.weight": "model-00026-of-00044.safetensors",
+ "model.layers.36.mlp.down_proj.weight": "model-00027-of-00044.safetensors",
+ "model.layers.36.mlp.gate_proj.weight": "model-00027-of-00044.safetensors",
+ "model.layers.36.mlp.up_proj.weight": "model-00027-of-00044.safetensors",
+ "model.layers.36.post_attention_layernorm.weight": "model-00026-of-00044.safetensors",
+ "model.layers.36.self_attn.k_proj.weight": "model-00026-of-00044.safetensors",
+ "model.layers.36.self_attn.o_proj.weight": "model-00026-of-00044.safetensors",
+ "model.layers.36.self_attn.q_proj.weight": "model-00026-of-00044.safetensors",
+ "model.layers.36.self_attn.v_proj.weight": "model-00026-of-00044.safetensors",
+ "model.layers.37.input_layernorm.weight": "model-00027-of-00044.safetensors",
+ "model.layers.37.mlp.down_proj.weight": "model-00028-of-00044.safetensors",
+ "model.layers.37.mlp.gate_proj.weight": "model-00027-of-00044.safetensors",
+ "model.layers.37.mlp.up_proj.weight": "model-00028-of-00044.safetensors",
+ "model.layers.37.post_attention_layernorm.weight": "model-00027-of-00044.safetensors",
+ "model.layers.37.self_attn.k_proj.weight": "model-00027-of-00044.safetensors",
+ "model.layers.37.self_attn.o_proj.weight": "model-00027-of-00044.safetensors",
+ "model.layers.37.self_attn.q_proj.weight": "model-00027-of-00044.safetensors",
+ "model.layers.37.self_attn.v_proj.weight": "model-00027-of-00044.safetensors",
+ "model.layers.38.input_layernorm.weight": "model-00028-of-00044.safetensors",
+ "model.layers.38.mlp.down_proj.weight": "model-00029-of-00044.safetensors",
+ "model.layers.38.mlp.gate_proj.weight": "model-00028-of-00044.safetensors",
+ "model.layers.38.mlp.up_proj.weight": "model-00028-of-00044.safetensors",
+ "model.layers.38.post_attention_layernorm.weight": "model-00028-of-00044.safetensors",
+ "model.layers.38.self_attn.k_proj.weight": "model-00028-of-00044.safetensors",
+ "model.layers.38.self_attn.o_proj.weight": "model-00028-of-00044.safetensors",
+ "model.layers.38.self_attn.q_proj.weight": "model-00028-of-00044.safetensors",
+ "model.layers.38.self_attn.v_proj.weight": "model-00028-of-00044.safetensors",
+ "model.layers.39.input_layernorm.weight": "model-00029-of-00044.safetensors",
+ "model.layers.39.mlp.down_proj.weight": "model-00029-of-00044.safetensors",
+ "model.layers.39.mlp.gate_proj.weight": "model-00029-of-00044.safetensors",
+ "model.layers.39.mlp.up_proj.weight": "model-00029-of-00044.safetensors",
+ "model.layers.39.post_attention_layernorm.weight": "model-00029-of-00044.safetensors",
+ "model.layers.39.self_attn.k_proj.weight": "model-00029-of-00044.safetensors",
+ "model.layers.39.self_attn.o_proj.weight": "model-00029-of-00044.safetensors",
+ "model.layers.39.self_attn.q_proj.weight": "model-00029-of-00044.safetensors",
+ "model.layers.39.self_attn.v_proj.weight": "model-00029-of-00044.safetensors",
+ "model.layers.4.input_layernorm.weight": "model-00004-of-00044.safetensors",
+ "model.layers.4.mlp.down_proj.weight": "model-00005-of-00044.safetensors",
+ "model.layers.4.mlp.gate_proj.weight": "model-00004-of-00044.safetensors",
+ "model.layers.4.mlp.up_proj.weight": "model-00004-of-00044.safetensors",
+ "model.layers.4.post_attention_layernorm.weight": "model-00004-of-00044.safetensors",
+ "model.layers.4.self_attn.k_proj.weight": "model-00004-of-00044.safetensors",
+ "model.layers.4.self_attn.o_proj.weight": "model-00004-of-00044.safetensors",
+ "model.layers.4.self_attn.q_proj.weight": "model-00004-of-00044.safetensors",
+ "model.layers.4.self_attn.v_proj.weight": "model-00004-of-00044.safetensors",
+ "model.layers.40.input_layernorm.weight": "model-00029-of-00044.safetensors",
+ "model.layers.40.mlp.down_proj.weight": "model-00030-of-00044.safetensors",
+ "model.layers.40.mlp.gate_proj.weight": "model-00030-of-00044.safetensors",
+ "model.layers.40.mlp.up_proj.weight": "model-00030-of-00044.safetensors",
+ "model.layers.40.post_attention_layernorm.weight": "model-00029-of-00044.safetensors",
+ "model.layers.40.self_attn.k_proj.weight": "model-00029-of-00044.safetensors",
+ "model.layers.40.self_attn.o_proj.weight": "model-00029-of-00044.safetensors",
+ "model.layers.40.self_attn.q_proj.weight": "model-00029-of-00044.safetensors",
+ "model.layers.40.self_attn.v_proj.weight": "model-00029-of-00044.safetensors",
+ "model.layers.41.input_layernorm.weight": "model-00030-of-00044.safetensors",
+ "model.layers.41.mlp.down_proj.weight": "model-00031-of-00044.safetensors",
+ "model.layers.41.mlp.gate_proj.weight": "model-00030-of-00044.safetensors",
+ "model.layers.41.mlp.up_proj.weight": "model-00031-of-00044.safetensors",
+ "model.layers.41.post_attention_layernorm.weight": "model-00030-of-00044.safetensors",
+ "model.layers.41.self_attn.k_proj.weight": "model-00030-of-00044.safetensors",
+ "model.layers.41.self_attn.o_proj.weight": "model-00030-of-00044.safetensors",
+ "model.layers.41.self_attn.q_proj.weight": "model-00030-of-00044.safetensors",
+ "model.layers.41.self_attn.v_proj.weight": "model-00030-of-00044.safetensors",
+ "model.layers.42.mlp.down_proj.weight": "model-00031-of-00044.safetensors",
+ "model.layers.42.mlp.gate_proj.weight": "model-00031-of-00044.safetensors",
+ "model.layers.42.mlp.up_proj.weight": "model-00031-of-00044.safetensors",
+ "model.layers.42.post_attention_layernorm.weight": "model-00031-of-00044.safetensors",
+ "model.layers.43.mlp.down_proj.weight": "model-00031-of-00044.safetensors",
+ "model.layers.43.mlp.gate_proj.weight": "model-00031-of-00044.safetensors",
+ "model.layers.43.mlp.up_proj.weight": "model-00031-of-00044.safetensors",
+ "model.layers.43.post_attention_layernorm.weight": "model-00031-of-00044.safetensors",
+ "model.layers.44.mlp.down_proj.weight": "model-00032-of-00044.safetensors",
+ "model.layers.44.mlp.gate_proj.weight": "model-00031-of-00044.safetensors",
+ "model.layers.44.mlp.up_proj.weight": "model-00031-of-00044.safetensors",
+ "model.layers.44.post_attention_layernorm.weight": "model-00031-of-00044.safetensors",
+ "model.layers.45.mlp.down_proj.weight": "model-00032-of-00044.safetensors",
+ "model.layers.45.mlp.gate_proj.weight": "model-00032-of-00044.safetensors",
+ "model.layers.45.mlp.up_proj.weight": "model-00032-of-00044.safetensors",
+ "model.layers.45.post_attention_layernorm.weight": "model-00032-of-00044.safetensors",
+ "model.layers.46.mlp.down_proj.weight": "model-00032-of-00044.safetensors",
+ "model.layers.46.mlp.gate_proj.weight": "model-00032-of-00044.safetensors",
+ "model.layers.46.mlp.up_proj.weight": "model-00032-of-00044.safetensors",
+ "model.layers.46.post_attention_layernorm.weight": "model-00032-of-00044.safetensors",
+ "model.layers.47.mlp.down_proj.weight": "model-00032-of-00044.safetensors",
+ "model.layers.47.mlp.gate_proj.weight": "model-00032-of-00044.safetensors",
+ "model.layers.47.mlp.up_proj.weight": "model-00032-of-00044.safetensors",
+ "model.layers.47.post_attention_layernorm.weight": "model-00032-of-00044.safetensors",
+ "model.layers.48.mlp.down_proj.weight": "model-00033-of-00044.safetensors",
+ "model.layers.48.mlp.gate_proj.weight": "model-00033-of-00044.safetensors",
+ "model.layers.48.mlp.up_proj.weight": "model-00033-of-00044.safetensors",
+ "model.layers.48.post_attention_layernorm.weight": "model-00032-of-00044.safetensors",
+ "model.layers.49.mlp.down_proj.weight": "model-00033-of-00044.safetensors",
+ "model.layers.49.mlp.gate_proj.weight": "model-00033-of-00044.safetensors",
+ "model.layers.49.mlp.up_proj.weight": "model-00033-of-00044.safetensors",
+ "model.layers.49.post_attention_layernorm.weight": "model-00033-of-00044.safetensors",
+ "model.layers.5.input_layernorm.weight": "model-00005-of-00044.safetensors",
+ "model.layers.5.mlp.down_proj.weight": "model-00005-of-00044.safetensors",
+ "model.layers.5.mlp.gate_proj.weight": "model-00005-of-00044.safetensors",
+ "model.layers.5.mlp.up_proj.weight": "model-00005-of-00044.safetensors",
+ "model.layers.5.post_attention_layernorm.weight": "model-00005-of-00044.safetensors",
+ "model.layers.5.self_attn.k_proj.weight": "model-00005-of-00044.safetensors",
+ "model.layers.5.self_attn.o_proj.weight": "model-00005-of-00044.safetensors",
+ "model.layers.5.self_attn.q_proj.weight": "model-00005-of-00044.safetensors",
+ "model.layers.5.self_attn.v_proj.weight": "model-00005-of-00044.safetensors",
+ "model.layers.50.mlp.down_proj.weight": "model-00033-of-00044.safetensors",
+ "model.layers.50.mlp.gate_proj.weight": "model-00033-of-00044.safetensors",
+ "model.layers.50.mlp.up_proj.weight": "model-00033-of-00044.safetensors",
+ "model.layers.50.post_attention_layernorm.weight": "model-00033-of-00044.safetensors",
+ "model.layers.51.mlp.down_proj.weight": "model-00033-of-00044.safetensors",
+ "model.layers.51.mlp.gate_proj.weight": "model-00033-of-00044.safetensors",
+ "model.layers.51.mlp.up_proj.weight": "model-00033-of-00044.safetensors",
+ "model.layers.51.post_attention_layernorm.weight": "model-00033-of-00044.safetensors",
+ "model.layers.52.input_layernorm.weight": "model-00033-of-00044.safetensors",
+ "model.layers.52.mlp.down_proj.weight": "model-00034-of-00044.safetensors",
+ "model.layers.52.mlp.gate_proj.weight": "model-00034-of-00044.safetensors",
+ "model.layers.52.mlp.up_proj.weight": "model-00034-of-00044.safetensors",
+ "model.layers.52.post_attention_layernorm.weight": "model-00033-of-00044.safetensors",
+ "model.layers.52.self_attn.k_proj.weight": "model-00033-of-00044.safetensors",
+ "model.layers.52.self_attn.o_proj.weight": "model-00033-of-00044.safetensors",
+ "model.layers.52.self_attn.q_proj.weight": "model-00033-of-00044.safetensors",
+ "model.layers.52.self_attn.v_proj.weight": "model-00033-of-00044.safetensors",
+ "model.layers.53.mlp.down_proj.weight": "model-00034-of-00044.safetensors",
+ "model.layers.53.mlp.gate_proj.weight": "model-00034-of-00044.safetensors",
+ "model.layers.53.mlp.up_proj.weight": "model-00034-of-00044.safetensors",
+ "model.layers.53.post_attention_layernorm.weight": "model-00034-of-00044.safetensors",
+ "model.layers.54.mlp.down_proj.weight": "model-00034-of-00044.safetensors",
+ "model.layers.54.mlp.gate_proj.weight": "model-00034-of-00044.safetensors",
+ "model.layers.54.mlp.up_proj.weight": "model-00034-of-00044.safetensors",
+ "model.layers.54.post_attention_layernorm.weight": "model-00034-of-00044.safetensors",
+ "model.layers.55.mlp.down_proj.weight": "model-00034-of-00044.safetensors",
+ "model.layers.55.mlp.gate_proj.weight": "model-00034-of-00044.safetensors",
+ "model.layers.55.mlp.up_proj.weight": "model-00034-of-00044.safetensors",
+ "model.layers.55.post_attention_layernorm.weight": "model-00034-of-00044.safetensors",
+ "model.layers.56.mlp.down_proj.weight": "model-00035-of-00044.safetensors",
+ "model.layers.56.mlp.gate_proj.weight": "model-00034-of-00044.safetensors",
+ "model.layers.56.mlp.up_proj.weight": "model-00035-of-00044.safetensors",
+ "model.layers.56.post_attention_layernorm.weight": "model-00034-of-00044.safetensors",
+ "model.layers.57.mlp.down_proj.weight": "model-00035-of-00044.safetensors",
+ "model.layers.57.mlp.gate_proj.weight": "model-00035-of-00044.safetensors",
+ "model.layers.57.mlp.up_proj.weight": "model-00035-of-00044.safetensors",
+ "model.layers.57.post_attention_layernorm.weight": "model-00035-of-00044.safetensors",
+ "model.layers.58.mlp.down_proj.weight": "model-00035-of-00044.safetensors",
+ "model.layers.58.mlp.gate_proj.weight": "model-00035-of-00044.safetensors",
+ "model.layers.58.mlp.up_proj.weight": "model-00035-of-00044.safetensors",
+ "model.layers.58.post_attention_layernorm.weight": "model-00035-of-00044.safetensors",
+ "model.layers.59.mlp.down_proj.weight": "model-00035-of-00044.safetensors",
+ "model.layers.59.mlp.gate_proj.weight": "model-00035-of-00044.safetensors",
+ "model.layers.59.mlp.up_proj.weight": "model-00035-of-00044.safetensors",
+ "model.layers.59.post_attention_layernorm.weight": "model-00035-of-00044.safetensors",
+ "model.layers.6.mlp.down_proj.weight": "model-00006-of-00044.safetensors",
+ "model.layers.6.mlp.gate_proj.weight": "model-00005-of-00044.safetensors",
+ "model.layers.6.mlp.up_proj.weight": "model-00006-of-00044.safetensors",
+ "model.layers.6.post_attention_layernorm.weight": "model-00005-of-00044.safetensors",
+ "model.layers.60.mlp.down_proj.weight": "model-00035-of-00044.safetensors",
+ "model.layers.60.mlp.gate_proj.weight": "model-00035-of-00044.safetensors",
+ "model.layers.60.mlp.up_proj.weight": "model-00035-of-00044.safetensors",
+ "model.layers.60.post_attention_layernorm.weight": "model-00035-of-00044.safetensors",
+ "model.layers.61.mlp.down_proj.weight": "model-00035-of-00044.safetensors",
+ "model.layers.61.mlp.gate_proj.weight": "model-00035-of-00044.safetensors",
+ "model.layers.61.mlp.up_proj.weight": "model-00035-of-00044.safetensors",
+ "model.layers.61.post_attention_layernorm.weight": "model-00035-of-00044.safetensors",
+ "model.layers.62.mlp.down_proj.weight": "model-00035-of-00044.safetensors",
+ "model.layers.62.mlp.gate_proj.weight": "model-00035-of-00044.safetensors",
+ "model.layers.62.mlp.up_proj.weight": "model-00035-of-00044.safetensors",
+ "model.layers.62.post_attention_layernorm.weight": "model-00035-of-00044.safetensors",
+ "model.layers.63.mlp.down_proj.weight": "model-00035-of-00044.safetensors",
+ "model.layers.63.mlp.gate_proj.weight": "model-00035-of-00044.safetensors",
+ "model.layers.63.mlp.up_proj.weight": "model-00035-of-00044.safetensors",
+ "model.layers.63.post_attention_layernorm.weight": "model-00035-of-00044.safetensors",
+ "model.layers.64.mlp.down_proj.weight": "model-00035-of-00044.safetensors",
+ "model.layers.64.mlp.gate_proj.weight": "model-00035-of-00044.safetensors",
+ "model.layers.64.mlp.up_proj.weight": "model-00035-of-00044.safetensors",
+ "model.layers.64.post_attention_layernorm.weight": "model-00035-of-00044.safetensors",
+ "model.layers.65.mlp.down_proj.weight": "model-00036-of-00044.safetensors",
+ "model.layers.65.mlp.gate_proj.weight": "model-00035-of-00044.safetensors",
+ "model.layers.65.mlp.up_proj.weight": "model-00036-of-00044.safetensors",
+ "model.layers.65.post_attention_layernorm.weight": "model-00035-of-00044.safetensors",
+ "model.layers.66.mlp.down_proj.weight": "model-00036-of-00044.safetensors",
+ "model.layers.66.mlp.gate_proj.weight": "model-00036-of-00044.safetensors",
+ "model.layers.66.mlp.up_proj.weight": "model-00036-of-00044.safetensors",
+ "model.layers.66.post_attention_layernorm.weight": "model-00036-of-00044.safetensors",
+ "model.layers.67.mlp.down_proj.weight": "model-00036-of-00044.safetensors",
+ "model.layers.67.mlp.gate_proj.weight": "model-00036-of-00044.safetensors",
+ "model.layers.67.mlp.up_proj.weight": "model-00036-of-00044.safetensors",
+ "model.layers.67.post_attention_layernorm.weight": "model-00036-of-00044.safetensors",
+ "model.layers.68.mlp.down_proj.weight": "model-00036-of-00044.safetensors",
+ "model.layers.68.mlp.gate_proj.weight": "model-00036-of-00044.safetensors",
+ "model.layers.68.mlp.up_proj.weight": "model-00036-of-00044.safetensors",
+ "model.layers.68.post_attention_layernorm.weight": "model-00036-of-00044.safetensors",
+ "model.layers.69.mlp.down_proj.weight": "model-00036-of-00044.safetensors",
+ "model.layers.69.mlp.gate_proj.weight": "model-00036-of-00044.safetensors",
+ "model.layers.69.mlp.up_proj.weight": "model-00036-of-00044.safetensors",
+ "model.layers.69.post_attention_layernorm.weight": "model-00036-of-00044.safetensors",
+ "model.layers.7.mlp.down_proj.weight": "model-00006-of-00044.safetensors",
+ "model.layers.7.mlp.gate_proj.weight": "model-00006-of-00044.safetensors",
+ "model.layers.7.mlp.up_proj.weight": "model-00006-of-00044.safetensors",
+ "model.layers.7.post_attention_layernorm.weight": "model-00006-of-00044.safetensors",
+ "model.layers.70.mlp.down_proj.weight": "model-00036-of-00044.safetensors",
+ "model.layers.70.mlp.gate_proj.weight": "model-00036-of-00044.safetensors",
+ "model.layers.70.mlp.up_proj.weight": "model-00036-of-00044.safetensors",
+ "model.layers.70.post_attention_layernorm.weight": "model-00036-of-00044.safetensors",
+ "model.layers.71.input_layernorm.weight": "model-00036-of-00044.safetensors",
+ "model.layers.71.mlp.down_proj.weight": "model-00037-of-00044.safetensors",
+ "model.layers.71.mlp.gate_proj.weight": "model-00036-of-00044.safetensors",
+ "model.layers.71.mlp.up_proj.weight": "model-00036-of-00044.safetensors",
+ "model.layers.71.post_attention_layernorm.weight": "model-00036-of-00044.safetensors",
+ "model.layers.71.self_attn.k_proj.weight": "model-00036-of-00044.safetensors",
+ "model.layers.71.self_attn.o_proj.weight": "model-00036-of-00044.safetensors",
+ "model.layers.71.self_attn.q_proj.weight": "model-00036-of-00044.safetensors",
+ "model.layers.71.self_attn.v_proj.weight": "model-00036-of-00044.safetensors",
+ "model.layers.72.input_layernorm.weight": "model-00037-of-00044.safetensors",
+ "model.layers.72.mlp.down_proj.weight": "model-00037-of-00044.safetensors",
+ "model.layers.72.mlp.gate_proj.weight": "model-00037-of-00044.safetensors",
+ "model.layers.72.mlp.up_proj.weight": "model-00037-of-00044.safetensors",
+ "model.layers.72.post_attention_layernorm.weight": "model-00037-of-00044.safetensors",
+ "model.layers.72.self_attn.k_proj.weight": "model-00037-of-00044.safetensors",
+ "model.layers.72.self_attn.o_proj.weight": "model-00037-of-00044.safetensors",
+ "model.layers.72.self_attn.q_proj.weight": "model-00037-of-00044.safetensors",
+ "model.layers.72.self_attn.v_proj.weight": "model-00037-of-00044.safetensors",
+ "model.layers.73.input_layernorm.weight": "model-00037-of-00044.safetensors",
+ "model.layers.73.mlp.down_proj.weight": "model-00038-of-00044.safetensors",
+ "model.layers.73.mlp.gate_proj.weight": "model-00038-of-00044.safetensors",
+ "model.layers.73.mlp.up_proj.weight": "model-00038-of-00044.safetensors",
+ "model.layers.73.post_attention_layernorm.weight": "model-00037-of-00044.safetensors",
+ "model.layers.73.self_attn.k_proj.weight": "model-00037-of-00044.safetensors",
+ "model.layers.73.self_attn.o_proj.weight": "model-00037-of-00044.safetensors",
+ "model.layers.73.self_attn.q_proj.weight": "model-00037-of-00044.safetensors",
+ "model.layers.73.self_attn.v_proj.weight": "model-00037-of-00044.safetensors",
+ "model.layers.74.input_layernorm.weight": "model-00038-of-00044.safetensors",
+ "model.layers.74.mlp.down_proj.weight": "model-00039-of-00044.safetensors",
+ "model.layers.74.mlp.gate_proj.weight": "model-00038-of-00044.safetensors",
+ "model.layers.74.mlp.up_proj.weight": "model-00039-of-00044.safetensors",
+ "model.layers.74.post_attention_layernorm.weight": "model-00038-of-00044.safetensors",
+ "model.layers.74.self_attn.k_proj.weight": "model-00038-of-00044.safetensors",
+ "model.layers.74.self_attn.o_proj.weight": "model-00038-of-00044.safetensors",
+ "model.layers.74.self_attn.q_proj.weight": "model-00038-of-00044.safetensors",
+ "model.layers.74.self_attn.v_proj.weight": "model-00038-of-00044.safetensors",
+ "model.layers.75.input_layernorm.weight": "model-00039-of-00044.safetensors",
+ "model.layers.75.mlp.down_proj.weight": "model-00040-of-00044.safetensors",
+ "model.layers.75.mlp.gate_proj.weight": "model-00039-of-00044.safetensors",
+ "model.layers.75.mlp.up_proj.weight": "model-00039-of-00044.safetensors",
+ "model.layers.75.post_attention_layernorm.weight": "model-00039-of-00044.safetensors",
+ "model.layers.75.self_attn.k_proj.weight": "model-00039-of-00044.safetensors",
+ "model.layers.75.self_attn.o_proj.weight": "model-00039-of-00044.safetensors",
+ "model.layers.75.self_attn.q_proj.weight": "model-00039-of-00044.safetensors",
+ "model.layers.75.self_attn.v_proj.weight": "model-00039-of-00044.safetensors",
+ "model.layers.76.input_layernorm.weight": "model-00040-of-00044.safetensors",
+ "model.layers.76.mlp.down_proj.weight": "model-00040-of-00044.safetensors",
+ "model.layers.76.mlp.gate_proj.weight": "model-00040-of-00044.safetensors",
+ "model.layers.76.mlp.up_proj.weight": "model-00040-of-00044.safetensors",
+ "model.layers.76.post_attention_layernorm.weight": "model-00040-of-00044.safetensors",
+ "model.layers.76.self_attn.k_proj.weight": "model-00040-of-00044.safetensors",
+ "model.layers.76.self_attn.o_proj.weight": "model-00040-of-00044.safetensors",
+ "model.layers.76.self_attn.q_proj.weight": "model-00040-of-00044.safetensors",
+ "model.layers.76.self_attn.v_proj.weight": "model-00040-of-00044.safetensors",
+ "model.layers.77.input_layernorm.weight": "model-00040-of-00044.safetensors",
+ "model.layers.77.mlp.down_proj.weight": "model-00041-of-00044.safetensors",
+ "model.layers.77.mlp.gate_proj.weight": "model-00041-of-00044.safetensors",
+ "model.layers.77.mlp.up_proj.weight": "model-00041-of-00044.safetensors",
+ "model.layers.77.post_attention_layernorm.weight": "model-00040-of-00044.safetensors",
+ "model.layers.77.self_attn.k_proj.weight": "model-00040-of-00044.safetensors",
+ "model.layers.77.self_attn.o_proj.weight": "model-00040-of-00044.safetensors",
+ "model.layers.77.self_attn.q_proj.weight": "model-00040-of-00044.safetensors",
+ "model.layers.77.self_attn.v_proj.weight": "model-00040-of-00044.safetensors",
+ "model.layers.78.input_layernorm.weight": "model-00041-of-00044.safetensors",
+ "model.layers.78.mlp.down_proj.weight": "model-00042-of-00044.safetensors",
+ "model.layers.78.mlp.gate_proj.weight": "model-00041-of-00044.safetensors",
+ "model.layers.78.mlp.up_proj.weight": "model-00042-of-00044.safetensors",
+ "model.layers.78.post_attention_layernorm.weight": "model-00041-of-00044.safetensors",
+ "model.layers.78.self_attn.k_proj.weight": "model-00041-of-00044.safetensors",
+ "model.layers.78.self_attn.o_proj.weight": "model-00041-of-00044.safetensors",
+ "model.layers.78.self_attn.q_proj.weight": "model-00041-of-00044.safetensors",
+ "model.layers.78.self_attn.v_proj.weight": "model-00041-of-00044.safetensors",
+ "model.layers.79.input_layernorm.weight": "model-00042-of-00044.safetensors",
+ "model.layers.79.mlp.down_proj.weight": "model-00043-of-00044.safetensors",
+ "model.layers.79.mlp.gate_proj.weight": "model-00042-of-00044.safetensors",
+ "model.layers.79.mlp.up_proj.weight": "model-00042-of-00044.safetensors",
+ "model.layers.79.post_attention_layernorm.weight": "model-00042-of-00044.safetensors",
+ "model.layers.79.self_attn.k_proj.weight": "model-00042-of-00044.safetensors",
+ "model.layers.79.self_attn.o_proj.weight": "model-00042-of-00044.safetensors",
+ "model.layers.79.self_attn.q_proj.weight": "model-00042-of-00044.safetensors",
+ "model.layers.79.self_attn.v_proj.weight": "model-00042-of-00044.safetensors",
+ "model.layers.8.input_layernorm.weight": "model-00006-of-00044.safetensors",
+ "model.layers.8.mlp.down_proj.weight": "model-00007-of-00044.safetensors",
+ "model.layers.8.mlp.gate_proj.weight": "model-00006-of-00044.safetensors",
+ "model.layers.8.mlp.up_proj.weight": "model-00006-of-00044.safetensors",
+ "model.layers.8.post_attention_layernorm.weight": "model-00006-of-00044.safetensors",
+ "model.layers.8.self_attn.k_proj.weight": "model-00006-of-00044.safetensors",
+ "model.layers.8.self_attn.o_proj.weight": "model-00006-of-00044.safetensors",
+ "model.layers.8.self_attn.q_proj.weight": "model-00006-of-00044.safetensors",
+ "model.layers.8.self_attn.v_proj.weight": "model-00006-of-00044.safetensors",
+ "model.layers.9.input_layernorm.weight": "model-00007-of-00044.safetensors",
+ "model.layers.9.mlp.down_proj.weight": "model-00007-of-00044.safetensors",
+ "model.layers.9.mlp.gate_proj.weight": "model-00007-of-00044.safetensors",
+ "model.layers.9.mlp.up_proj.weight": "model-00007-of-00044.safetensors",
+ "model.layers.9.post_attention_layernorm.weight": "model-00007-of-00044.safetensors",
+ "model.layers.9.self_attn.k_proj.weight": "model-00007-of-00044.safetensors",
+ "model.layers.9.self_attn.o_proj.weight": "model-00007-of-00044.safetensors",
+ "model.layers.9.self_attn.q_proj.weight": "model-00007-of-00044.safetensors",
+ "model.layers.9.self_attn.v_proj.weight": "model-00007-of-00044.safetensors",
+ "model.norm.weight": "model-00043-of-00044.safetensors"
+ }
+}
diff --git a/modeling_decilm.py b/modeling_decilm.py
new file mode 100644
index 0000000000000000000000000000000000000000..4a079b913ae9c774b75aea8fc595a1d95b0215f6
--- /dev/null
+++ b/modeling_decilm.py
@@ -0,0 +1,1684 @@
+# coding=utf-8
+# Copyright 2024 Nvidia Corporation, Google Inc, HuggingFace Inc, EleutherAI. All rights reserved.
+#
+# This code for Nvidia's model is based on the Llama modeling code by HuggingFace,
+# which is in turn based on EleutherAI's GPT-NeoX library and the GPT-NeoX and
+# OPT implementations in this library.
+# Sliding window code based on Gemma2 by Google.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+import math
+from typing import List, Optional, Tuple, Union
+
+import torch
+import torch.nn.functional as F
+import torch.utils.checkpoint
+from torch import nn
+from torch.nn import BCEWithLogitsLoss, CrossEntropyLoss, MSELoss
+from transformers import GenerationConfig
+from transformers.generation.utils import NEED_SETUP_CACHE_CLASSES_MAPPING, GenerationMixin, GenerateOutput
+from transformers.modeling_utils import PreTrainedModel
+from transformers.models.auto.modeling_auto import MODEL_FOR_CAUSAL_LM_MAPPING_NAMES
+from transformers.utils import (
+ add_start_docstrings,
+ add_start_docstrings_to_model_forward,
+ is_flash_attn_greater_or_equal_2_10,
+ logging,
+ replace_return_docstrings,
+)
+
+from .block_config import AttentionConfig, FFNConfig
+from .configuration_decilm import DeciLMConfig
+from .transformers_4_44_2__activations import ACT2FN
+from .transformers_4_44_2__cache_utils import Cache, StaticCache
+from .transformers_4_44_2__modeling_attn_mask_utils import AttentionMaskConverter
+from .transformers_4_44_2__modeling_flash_attention_utils_backward_compat import _flash_attention_forward
+from .transformers_4_44_2__modeling_outputs import (
+ BaseModelOutputWithPast,
+ CausalLMOutputWithPast,
+ QuestionAnsweringModelOutput,
+ SequenceClassifierOutputWithPast,
+ TokenClassifierOutput,
+)
+from .transformers_4_44_2__modeling_rope_utils import ROPE_INIT_FUNCTIONS
+from .transformers_4_44_2__pytorch_utils import ALL_LAYERNORM_LAYERS
+from .variable_cache import VariableCache
+
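+# Register the custom model type so that the transformers Auto classes (e.g. AutoModelForCausalLM)
+# can resolve DeciLMConfig.model_type to this module's DeciLMForCausalLM class.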
+MODEL_FOR_CAUSAL_LM_MAPPING_NAMES[DeciLMConfig.model_type] = "DeciLMForCausalLM"
+logger = logging.get_logger(__name__)
+
+_CONFIG_FOR_DOC = "DeciLMConfig"
+
+
+def _prepare_4d_causal_attention_mask_with_cache_position(
+ attention_mask: torch.Tensor,
+ sequence_length: int,
+ target_length: int,
+ dtype: torch.dtype,
+ device: torch.device,
+ min_dtype: float,
+ cache_position: torch.Tensor,
+ batch_size: int,
+):
+ """
+    Creates a causal 4D mask of shape `(batch_size, 1, query_length, key_value_length)` from a 2D mask of shape
+    `(batch_size, key_value_length)`. If the input `attention_mask` is already 4D, it is returned unchanged.
+
+ Args:
+ attention_mask (`torch.Tensor`):
+ A 2D attention mask of shape `(batch_size, key_value_length)` or a 4D attention mask of shape `(batch_size, 1, query_length, key_value_length)`.
+ sequence_length (`int`):
+ The sequence length being processed.
+ target_length (`int`):
+            The target length: when generating with a static cache, the mask should be as long as the static cache, to account for the zero padding (the part of the cache that is not yet filled).
+ dtype (`torch.dtype`):
+ The dtype to use for the 4D attention mask.
+ device (`torch.device`):
+ The device to place the 4D attention mask on.
+ min_dtype (`float`):
+ The minimum value representable with the dtype `dtype`.
+ cache_position (`torch.Tensor`):
+ Indices depicting the position of the input sequence tokens in the sequence.
+        batch_size (`int`):
+ Batch size.
+ """
+ if attention_mask is not None and attention_mask.dim() == 4:
+ # In this case we assume that the mask comes already in inverted form and requires no inversion or slicing.
+ causal_mask = attention_mask
+ else:
+ causal_mask = torch.full((sequence_length, target_length), fill_value=min_dtype, dtype=dtype, device=device)
+ if sequence_length != 1:
+ causal_mask = torch.triu(causal_mask, diagonal=1)
+ causal_mask *= torch.arange(target_length, device=device) > cache_position.reshape(-1, 1)
+ causal_mask = causal_mask[None, None, :, :].expand(batch_size, 1, -1, -1)
+ if attention_mask is not None:
+ causal_mask = causal_mask.clone() # copy to contiguous memory for in-place edit
+ mask_length = attention_mask.shape[-1]
+ padding_mask = causal_mask[:, :, :, :mask_length] + attention_mask[:, None, None, :]
+ padding_mask = padding_mask == 0
+ causal_mask[:, :, :, :mask_length] = causal_mask[:, :, :, :mask_length].masked_fill(
+ padding_mask, min_dtype
+ )
+
+ return causal_mask
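+
+# A minimal shape sketch for the helper above (illustrative only; the argument values are hypothetical):
+#   mask = _prepare_4d_causal_attention_mask_with_cache_position(
+#       attention_mask=torch.ones(2, 6), sequence_length=6, target_length=8,
+#       dtype=torch.float32, device=torch.device("cpu"),
+#       min_dtype=torch.finfo(torch.float32).min,
+#       cache_position=torch.arange(6), batch_size=2,
+#   )
+#   # mask.shape == (2, 1, 6, 8), i.e. (batch_size, 1, query_length, key_value_length)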
+
+
+class DeciLMRMSNorm(nn.Module):
+ def __init__(self, hidden_size, eps=1e-6):
+ """
+ DeciLMRMSNorm is equivalent to T5LayerNorm
+ """
+ super().__init__()
+ self.weight = nn.Parameter(torch.ones(hidden_size))
+ self.variance_epsilon = eps
+
+ def forward(self, hidden_states):
+ input_dtype = hidden_states.dtype
+ hidden_states = hidden_states.to(torch.float32)
+ variance = hidden_states.pow(2).mean(-1, keepdim=True)
+ hidden_states = hidden_states * torch.rsqrt(variance + self.variance_epsilon)
+ return self.weight * hidden_states.to(input_dtype)
+
+ def extra_repr(self):
+ return f"{tuple(self.weight.shape)}, eps={self.variance_epsilon}"
+
+
+ALL_LAYERNORM_LAYERS.append(DeciLMRMSNorm)
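+# DeciLMRMSNorm computes weight * x / sqrt(mean(x**2, dim=-1) + eps); appending it to
+# ALL_LAYERNORM_LAYERS lets transformers utilities treat it like the other norm layers.
+# A minimal sanity-check sketch (illustrative, not executed by this module):
+#   norm = DeciLMRMSNorm(hidden_size=8)
+#   y = norm(torch.randn(2, 4, 8))  # normalized over the last dim, y.shape == (2, 4, 8)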
+
+
+class DeciLMRotaryEmbedding(nn.Module):
+ def __init__(
+ self,
+ dim=None,
+ max_position_embeddings=2048,
+ base=10000,
+ device=None,
+ scaling_factor=1.0,
+ rope_type="default",
+ config: Optional[DeciLMConfig] = None,
+ ):
+ super().__init__()
+ # TODO (joao): remove the `if` below, only used for BC
+ self.rope_kwargs = {}
+ if config is None:
+ logger.warning_once(
+ "`DeciLMRotaryEmbedding` can now be fully parameterized by passing the model config through the "
+ "`config` argument. All other arguments will be removed in v4.45"
+ )
+ self.rope_kwargs = {
+ "rope_type": rope_type,
+ "factor": scaling_factor,
+ "dim": dim,
+ "base": base,
+ "max_position_embeddings": max_position_embeddings,
+ }
+ self.rope_type = rope_type
+ self.max_seq_len_cached = max_position_embeddings
+ self.original_max_seq_len = max_position_embeddings
+ else:
+ # BC: "rope_type" was originally "type"
+ if config.rope_scaling is not None:
+ self.rope_type = config.rope_scaling.get("rope_type", config.rope_scaling.get("type"))
+ else:
+ self.rope_type = "default"
+ self.max_seq_len_cached = config.max_position_embeddings
+ self.original_max_seq_len = config.max_position_embeddings
+
+ self.config = config
+ self.rope_init_fn = ROPE_INIT_FUNCTIONS[self.rope_type]
+
+ inv_freq, self.attention_scaling = self.rope_init_fn(self.config, device, **self.rope_kwargs)
+ self.register_buffer("inv_freq", inv_freq, persistent=False)
+ self.original_inv_freq = self.inv_freq
+
+ def _dynamic_frequency_update(self, position_ids, device):
+ """
+ dynamic RoPE layers should recompute `inv_freq` in the following situations:
+ 1 - growing beyond the cached sequence length (allow scaling)
+ 2 - the current sequence length is in the original scale (avoid losing precision with small sequences)
+ """
+ seq_len = torch.max(position_ids) + 1
+ if seq_len > self.max_seq_len_cached: # growth
+ inv_freq, self.attention_scaling = self.rope_init_fn(
+ self.config, device, seq_len=seq_len, **self.rope_kwargs
+ )
+ self.register_buffer("inv_freq", inv_freq, persistent=False) # TODO joao: may break with compilation
+ self.max_seq_len_cached = seq_len
+
+ if seq_len < self.original_max_seq_len and self.max_seq_len_cached > self.original_max_seq_len: # reset
+ self.register_buffer("inv_freq", self.original_inv_freq, persistent=False)
+ self.max_seq_len_cached = self.original_max_seq_len
+
+ @torch.no_grad()
+ def forward(self, x, position_ids):
+ if "dynamic" in self.rope_type:
+ self._dynamic_frequency_update(position_ids, device=x.device)
+
+ # Core RoPE block
+ inv_freq_expanded = self.inv_freq[None, :, None].float().expand(position_ids.shape[0], -1, 1)
+ position_ids_expanded = position_ids[:, None, :].float()
+ # Force float32 (see https://github.com/huggingface/transformers/pull/29285)
+ device_type = x.device.type
+ device_type = device_type if isinstance(device_type, str) and device_type != "mps" else "cpu"
+ with torch.autocast(device_type=device_type, enabled=False):
+ freqs = (inv_freq_expanded.float() @ position_ids_expanded.float()).transpose(1, 2)
+ emb = torch.cat((freqs, freqs), dim=-1)
+ cos = emb.cos()
+ sin = emb.sin()
+
+ # Advanced RoPE types (e.g. yarn) apply a post-processing scaling factor, equivalent to scaling attention
+ cos = cos * self.attention_scaling
+ sin = sin * self.attention_scaling
+
+ return cos.to(dtype=x.dtype), sin.to(dtype=x.dtype)
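+
+    # Shape note (a sketch of the forward above): with x of shape (batch, seq_len, hidden)
+    # and position_ids of shape (batch, seq_len), cos and sin typically come out as
+    # (batch, seq_len, head_dim) in x.dtype, ready for apply_rotary_pos_emb below.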
+
+
+class DeciLMLinearScalingRotaryEmbedding(DeciLMRotaryEmbedding):
+ """DeciLMRotaryEmbedding extended with linear scaling. Credits to the Reddit user /u/kaiokendev"""
+
+ def __init__(self, *args, **kwargs):
+ logger.warning_once(
+            "`DeciLMLinearScalingRotaryEmbedding` is deprecated and will be removed in v4.45. Please use "
+ "`DeciLMRotaryEmbedding`, which now also does linear scaling (simply pass the model config to __init__)."
+ )
+ kwargs["rope_type"] = "linear"
+ super().__init__(*args, **kwargs)
+
+
+class DeciLMDynamicNTKScalingRotaryEmbedding(DeciLMRotaryEmbedding):
+ """DeciLMRotaryEmbedding extended with Dynamic NTK scaling. Credits to the Reddit users /u/bloc97 and /u/emozilla"""
+
+ def __init__(self, *args, **kwargs):
+ logger.warning_once(
+            "`DeciLMDynamicNTKScalingRotaryEmbedding` is deprecated and will be removed in v4.45. Please use "
+ "`DeciLMRotaryEmbedding`, which now also does dynamic ntk scaling (simply pass the model config to "
+ "__init__)."
+ )
+ kwargs["rope_type"] = "dynamic"
+ super().__init__(*args, **kwargs)
+
+
+def rotate_half(x):
+ """Rotates half the hidden dims of the input."""
+ x1 = x[..., : x.shape[-1] // 2]
+ x2 = x[..., x.shape[-1] // 2:]
+ return torch.cat((-x2, x1), dim=-1)
+
+
+def apply_rotary_pos_emb(q, k, cos, sin, position_ids=None, unsqueeze_dim=1):
+ """Applies Rotary Position Embedding to the query and key tensors.
+
+ Args:
+ q (`torch.Tensor`): The query tensor.
+ k (`torch.Tensor`): The key tensor.
+ cos (`torch.Tensor`): The cosine part of the rotary embedding.
+ sin (`torch.Tensor`): The sine part of the rotary embedding.
+ position_ids (`torch.Tensor`, *optional*):
+ Deprecated and unused.
+ unsqueeze_dim (`int`, *optional*, defaults to 1):
+ The 'unsqueeze_dim' argument specifies the dimension along which to unsqueeze cos[position_ids] and
+ sin[position_ids] so that they can be properly broadcasted to the dimensions of q and k. For example, note
+ that cos[position_ids] and sin[position_ids] have the shape [batch_size, seq_len, head_dim]. Then, if q and
+ k have the shape [batch_size, heads, seq_len, head_dim], then setting unsqueeze_dim=1 makes
+ cos[position_ids] and sin[position_ids] broadcastable to the shapes of q and k. Similarly, if q and k have
+ the shape [batch_size, seq_len, heads, head_dim], then set unsqueeze_dim=2.
+ Returns:
+ `tuple(torch.Tensor)` comprising of the query and key tensors rotated using the Rotary Position Embedding.
+ """
+ cos = cos.unsqueeze(unsqueeze_dim)
+ sin = sin.unsqueeze(unsqueeze_dim)
+ q_embed = (q * cos) + (rotate_half(q) * sin)
+ k_embed = (k * cos) + (rotate_half(k) * sin)
+ return q_embed, k_embed
+
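+# Usage sketch (illustrative; shapes match the attention modules below):
+#   q: [batch, num_heads, seq_len, head_dim], k: [batch, num_kv_heads, seq_len, head_dim]
+#   cos, sin: [batch, seq_len, head_dim]
+#   q_rot, k_rot = apply_rotary_pos_emb(q, k, cos, sin)  # unsqueeze_dim=1 broadcasts over heads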
+
+class DeciLMMLP(nn.Module):
+ def __init__(self,
+ config: DeciLMConfig,
+ ffn_config: FFNConfig,
+ ):
+ super().__init__()
+ self.config = config
+ self.ffn_config = ffn_config
+ self.hidden_size = config.hidden_size
+ self.intermediate_size = _ffn_mult_to_intermediate_size(
+ ffn_config.ffn_mult, config.hidden_size) # DeciLM-specific code
+ self.gate_proj = nn.Linear(self.hidden_size, self.intermediate_size, bias=config.mlp_bias)
+ self.up_proj = nn.Linear(self.hidden_size, self.intermediate_size, bias=config.mlp_bias)
+ self.down_proj = nn.Linear(self.intermediate_size, self.hidden_size, bias=config.mlp_bias)
+ self.act_fn = ACT2FN[config.hidden_act]
+
+ if ffn_config.sparsify is not None:
+ self.register_full_backward_hook(sparsity_backward_hook)
+
+ def forward(self, x):
+ if self.config.pretraining_tp > 1:
+ slice = self.intermediate_size // self.config.pretraining_tp
+ gate_proj_slices = self.gate_proj.weight.split(slice, dim=0)
+ up_proj_slices = self.up_proj.weight.split(slice, dim=0)
+ down_proj_slices = self.down_proj.weight.split(slice, dim=1)
+
+ gate_proj = torch.cat(
+ [F.linear(x, gate_proj_slices[i]) for i in range(self.config.pretraining_tp)], dim=-1
+ )
+ up_proj = torch.cat([F.linear(x, up_proj_slices[i]) for i in range(self.config.pretraining_tp)], dim=-1)
+
+ intermediate_states = (self.act_fn(gate_proj) * up_proj).split(slice, dim=2)
+ down_proj = [
+ F.linear(intermediate_states[i], down_proj_slices[i]) for i in range(self.config.pretraining_tp)
+ ]
+ down_proj = sum(down_proj)
+ else:
+ down_proj = self.down_proj(self.act_fn(self.gate_proj(x)) * self.up_proj(x))
+
+ return down_proj
+
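+# Note (added for clarity): DeciLMMLP is the standard gated MLP, down_proj(act_fn(gate_proj(x)) * up_proj(x)),
+# except that its intermediate size is derived per block from ffn_config.ffn_mult via
+# _ffn_mult_to_intermediate_size (defined near the bottom of this file).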
+
+def repeat_kv(hidden_states: torch.Tensor, n_rep: int) -> torch.Tensor:
+ """
+ This is the equivalent of torch.repeat_interleave(x, dim=1, repeats=n_rep). The hidden states go from (batch,
+ num_key_value_heads, seqlen, head_dim) to (batch, num_attention_heads, seqlen, head_dim)
+ """
+ batch, num_key_value_heads, slen, head_dim = hidden_states.shape
+ if n_rep == 1:
+ return hidden_states
+ hidden_states = hidden_states[:, :, None, :, :].expand(batch, num_key_value_heads, n_rep, slen, head_dim)
+ return hidden_states.reshape(batch, num_key_value_heads * n_rep, slen, head_dim)
+
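+# Example (illustrative numbers): with num_attention_heads=32 and n_heads_in_group=4, the
+# attention below uses 8 KV heads, and repeat_kv expands key/value states from
+# [batch, 8, seq_len, head_dim] to [batch, 32, seq_len, head_dim] before the matmuls.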
+
+class DeciLMAttention(nn.Module):
+ """Multi-headed attention from 'Attention Is All You Need' paper"""
+
+ def __init__(self,
+ config: DeciLMConfig,
+ attention_config: AttentionConfig,
+ layer_idx: Optional[int] = None,
+ ):
+ super().__init__()
+ self.config = config
+ self.attention_config = attention_config
+ self.layer_idx = layer_idx
+ if layer_idx is None:
+ logger.warning_once(
+ f"Instantiating {self.__class__.__name__} without passing a `layer_idx` is not recommended and will "
+ "lead to errors during the forward call if caching is used. Please make sure to provide a `layer_idx` "
+ "when creating this class."
+ )
+
+ self.attention_dropout = config.attention_dropout
+ self.hidden_size = config.hidden_size
+ self.num_heads = config.num_attention_heads
+ self.head_dim = self.hidden_size // self.num_heads
+ self.num_key_value_groups = attention_config.n_heads_in_group # DeciLM-specific code
+ self.num_key_value_heads = self.num_heads // self.num_key_value_groups # DeciLM-specific code
+ self.max_position_embeddings = config.max_position_embeddings
+ self.rope_theta = config.rope_theta
+ self.is_causal = True
+
+ if (self.head_dim * self.num_heads) != self.hidden_size:
+ raise ValueError(
+ f"hidden_size must be divisible by num_heads (got `hidden_size`: {self.hidden_size}"
+ f" and `num_heads`: {self.num_heads})."
+ )
+
+ self.q_proj = nn.Linear(self.hidden_size, self.num_heads * self.head_dim, bias=config.attention_bias)
+ self.k_proj = nn.Linear(self.hidden_size, self.num_key_value_heads * self.head_dim, bias=config.attention_bias)
+ self.v_proj = nn.Linear(self.hidden_size, self.num_key_value_heads * self.head_dim, bias=config.attention_bias)
+ self.o_proj = nn.Linear(self.hidden_size, self.hidden_size, bias=config.attention_bias)
+
+ # TODO (joao): remove in v4.45 (RoPE is computed in the model, not in the decoder layers)
+ self.rotary_emb = DeciLMRotaryEmbedding(config=self.config)
+
+ if attention_config.sparsify is not None:
+ self.register_full_backward_hook(sparsity_backward_hook)
+
+ def forward(
+ self,
+ hidden_states: torch.Tensor,
+ attention_mask: Optional[torch.Tensor] = None,
+ position_ids: Optional[torch.LongTensor] = None,
+ past_key_value: Optional[Cache] = None,
+ output_attentions: bool = False,
+ use_cache: bool = False,
+ cache_position: Optional[torch.LongTensor] = None,
+ position_embeddings: Optional[Tuple[torch.Tensor, torch.Tensor]] = None, # will become mandatory in v4.45
+ **kwargs,
+ ) -> Tuple[torch.Tensor, Optional[torch.Tensor], Optional[Tuple[torch.Tensor]]]:
+ bsz, q_len, _ = hidden_states.size()
+ if self.config.pretraining_tp > 1:
+ key_value_slicing = (self.num_key_value_heads * self.head_dim) // self.config.pretraining_tp
+ query_slices = self.q_proj.weight.split(
+ (self.num_heads * self.head_dim) // self.config.pretraining_tp, dim=0
+ )
+ key_slices = self.k_proj.weight.split(key_value_slicing, dim=0)
+ value_slices = self.v_proj.weight.split(key_value_slicing, dim=0)
+
+ query_states = [F.linear(hidden_states, query_slices[i]) for i in range(self.config.pretraining_tp)]
+ query_states = torch.cat(query_states, dim=-1)
+
+ key_states = [F.linear(hidden_states, key_slices[i]) for i in range(self.config.pretraining_tp)]
+ key_states = torch.cat(key_states, dim=-1)
+
+ value_states = [F.linear(hidden_states, value_slices[i]) for i in range(self.config.pretraining_tp)]
+ value_states = torch.cat(value_states, dim=-1)
+
+ else:
+ query_states = self.q_proj(hidden_states)
+ key_states = self.k_proj(hidden_states)
+ value_states = self.v_proj(hidden_states)
+
+ query_states = query_states.view(bsz, q_len, self.num_heads, self.head_dim).transpose(1, 2)
+ key_states = key_states.view(bsz, q_len, self.num_key_value_heads, self.head_dim).transpose(1, 2)
+ value_states = value_states.view(bsz, q_len, self.num_key_value_heads, self.head_dim).transpose(1, 2)
+
+ if position_embeddings is None:
+ logger.warning_once(
+ "The attention layers in this model are transitioning from computing the RoPE embeddings internally "
+ "through `position_ids` (2D tensor with the indexes of the tokens), to using externally computed "
+ "`position_embeddings` (Tuple of tensors, containing cos and sin). In v4.45 `position_ids` will be "
+ "removed and `position_embeddings` will be mandatory."
+ )
+ cos, sin = self.rotary_emb(value_states, position_ids)
+ else:
+ cos, sin = position_embeddings
+ query_states, key_states = apply_rotary_pos_emb(query_states, key_states, cos, sin)
+
+ if past_key_value is not None:
+ # sin and cos are specific to RoPE models; cache_position needed for the static cache
+ cache_kwargs = {"sin": sin, "cos": cos, "cache_position": cache_position}
+ key_states, value_states = past_key_value.update(key_states, value_states, self.layer_idx, cache_kwargs)
+
+ key_states = repeat_kv(key_states, self.num_key_value_groups)
+ value_states = repeat_kv(value_states, self.num_key_value_groups)
+
+ attn_weights = torch.matmul(query_states, key_states.transpose(2, 3)) / math.sqrt(self.head_dim)
+
+ if attention_mask is not None: # no matter the length, we just slice it
+ causal_mask = attention_mask[:, :, :, : key_states.shape[-2]]
+ attn_weights = attn_weights + causal_mask
+
+ # upcast attention to fp32
+ attn_weights = nn.functional.softmax(attn_weights, dim=-1, dtype=torch.float32).to(query_states.dtype)
+ attn_weights = nn.functional.dropout(attn_weights, p=self.attention_dropout, training=self.training)
+ attn_output = torch.matmul(attn_weights, value_states)
+
+ if attn_output.size() != (bsz, self.num_heads, q_len, self.head_dim):
+ raise ValueError(
+ f"`attn_output` should be of size {(bsz, self.num_heads, q_len, self.head_dim)}, but is"
+ f" {attn_output.size()}"
+ )
+
+ attn_output = attn_output.transpose(1, 2).contiguous()
+
+ attn_output = attn_output.reshape(bsz, q_len, -1)
+
+ if self.config.pretraining_tp > 1:
+ attn_output = attn_output.split(self.hidden_size // self.config.pretraining_tp, dim=2)
+ o_proj_slices = self.o_proj.weight.split(self.hidden_size // self.config.pretraining_tp, dim=1)
+ attn_output = sum([F.linear(attn_output[i], o_proj_slices[i]) for i in range(self.config.pretraining_tp)])
+ else:
+ attn_output = self.o_proj(attn_output)
+
+ if not output_attentions:
+ attn_weights = None
+
+ return attn_output, attn_weights, past_key_value
+
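+# Note (added for clarity): the eager module above materializes the full attention matrix
+# (softmax in fp32), so it can return attention weights and honor the explicit unshifted-sink
+# masks built in DeciLMDecoderLayer; the flash-attention variant below never materializes the
+# weights and therefore forces output_attentions to False.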
+
+class DeciLMFlashAttention2(DeciLMAttention):
+ """
+ DeciLM flash attention module. This module inherits from `DeciLMAttention`, as the weights of the module stay
+ untouched. The only required change is in the forward pass, where it needs to correctly call the public API of
+ flash attention and deal with padding tokens in case the input contains any.
+ """
+
+ def __init__(self, *args, **kwargs):
+ super().__init__(*args, **kwargs)
+
+ # TODO: Should be removed once Flash Attention for RoCm is bumped to 2.1.
+ # flash_attn<2.1 generates a top-left aligned causal mask, while what is needed here is bottom-right alignment, which became the default in flash_attn>=2.1. This attribute is used to handle this difference. Reference: https://github.com/Dao-AILab/flash-attention/releases/tag/v2.1.0.
+ # Beware that with flash_attn<2.1, using q_seqlen != k_seqlen (except for the case q_seqlen == 1) produces a wrong mask (top-left).
+ self._flash_attn_uses_top_left_mask = not is_flash_attn_greater_or_equal_2_10()
+
+ self.sliding_window = self.attention_config.prefill_sliding_window
+
+ def forward(
+ self,
+ hidden_states: torch.Tensor,
+ attention_mask: Optional[torch.LongTensor] = None,
+ position_ids: Optional[torch.LongTensor] = None,
+ past_key_value: Optional[Cache] = None,
+ output_attentions: bool = False,
+ use_cache: bool = False,
+ cache_position: Optional[torch.LongTensor] = None,
+ position_embeddings: Optional[Tuple[torch.Tensor, torch.Tensor]] = None, # will become mandatory in v4.45
+ ) -> Tuple[torch.Tensor, Optional[torch.Tensor], Optional[Tuple[torch.Tensor]]]:
+ output_attentions = False
+
+ bsz, q_len, _ = hidden_states.size()
+
+ query_states = self.q_proj(hidden_states)
+ key_states = self.k_proj(hidden_states)
+ value_states = self.v_proj(hidden_states)
+
+ # Flash attention requires the input to have the shape
+ # batch_size x seq_length x num_heads x head_dim,
+ # so we temporarily move to batch_size x num_heads x seq_length x head_dim for RoPE and the
+ # KV cache update, then transpose back below.
+ query_states = query_states.view(bsz, q_len, self.num_heads, self.head_dim).transpose(1, 2)
+ key_states = key_states.view(bsz, q_len, self.num_key_value_heads, self.head_dim).transpose(1, 2)
+ value_states = value_states.view(bsz, q_len, self.num_key_value_heads, self.head_dim).transpose(1, 2)
+
+ if position_embeddings is None:
+ logger.warning_once(
+ "The attention layers in this model are transitioning from computing the RoPE embeddings internally "
+ "through `position_ids` (2D tensor with the indexes of the tokens), to using externally computed "
+ "`position_embeddings` (Tuple of tensors, containing cos and sin). In v4.45 `position_ids` will be "
+ "removed and `position_embeddings` will be mandatory."
+ )
+ cos, sin = self.rotary_emb(value_states, position_ids)
+ else:
+ cos, sin = position_embeddings
+ query_states, key_states = apply_rotary_pos_emb(query_states, key_states, cos, sin)
+
+ if past_key_value is not None:
+ # sin and cos are specific to RoPE models; cache_position needed for the static cache
+ cache_kwargs = {"sin": sin, "cos": cos, "cache_position": cache_position}
+ key_states, value_states = past_key_value.update(key_states, value_states, self.layer_idx, cache_kwargs)
+
+ # TODO: These transposes are quite inefficient but Flash Attention requires the layout [batch_size, sequence_length, num_heads, head_dim]. We would need to refactor the KV cache
+ # to be able to avoid many of these transpose/reshape/view.
+ query_states = query_states.transpose(1, 2)
+ key_states = key_states.transpose(1, 2)
+ value_states = value_states.transpose(1, 2)
+
+ dropout_rate = self.attention_dropout if self.training else 0.0
+
+ # In PEFT, the layer norms are usually cast to float32 for training stability,
+ # so the input hidden states get silently cast to float32. Hence, we need to
+ # cast them back to the correct dtype just to be sure everything works as expected.
+ # This might slow down training & inference, so it is recommended not to cast the LayerNorms
+ # to fp32. (DeciLMRMSNorm handles it correctly.)
+
+ input_dtype = query_states.dtype
+ if input_dtype == torch.float32:
+ if torch.is_autocast_enabled():
+ target_dtype = torch.get_autocast_gpu_dtype()
+ # Handle the case where the model is quantized
+ elif hasattr(self.config, "_pre_quantization_dtype"):
+ target_dtype = self.config._pre_quantization_dtype
+ else:
+ target_dtype = self.q_proj.weight.dtype
+
+ logger.warning_once(
+ f"The input hidden states seems to be silently casted in float32, this might be related to"
+ f" the fact you have upcasted embedding or layer norm layers in float32. We will cast back the input in"
+ f" {target_dtype}."
+ )
+
+ query_states = query_states.to(target_dtype)
+ key_states = key_states.to(target_dtype)
+ value_states = value_states.to(target_dtype)
+
+ attn_output = _flash_attention_forward(
+ query_states,
+ key_states,
+ value_states,
+ attention_mask,
+ q_len,
+ position_ids=position_ids,
+ dropout=dropout_rate,
+ sliding_window=self.sliding_window,
+ use_top_left_mask=self._flash_attn_uses_top_left_mask,
+ is_causal=self.is_causal,
+ )
+
+ attn_output = attn_output.reshape(bsz, q_len, -1).contiguous()
+ attn_output = self.o_proj(attn_output)
+
+ if not output_attentions:
+ attn_weights = None
+
+ return attn_output, attn_weights, past_key_value
+
+
+DECILM_ATTENTION_CLASSES = {
+ "eager": DeciLMAttention,
+ "flash_attention_2": DeciLMFlashAttention2,
+}
+
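+# Note (added for clarity): only "eager" and "flash_attention_2" are mapped here; SDPA is
+# disabled for DeciLM (see `_supports_sdpa = False` on DeciLMPreTrainedModel below).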
+
+class DeciLMDecoderLayer(nn.Module):
+ # DeciLM-specific code
+ def __init__(self, config: DeciLMConfig, layer_idx: int):
+ super().__init__()
+ self.config = config
+ self.hidden_size = config.hidden_size
+ self.block_config = config.block_configs[layer_idx]
+ self.attention_config = self.block_config.attention
+ self.ffn_config = self.block_config.ffn
+
+ if not self.attention_config.no_op:
+ self.input_layernorm = DeciLMRMSNorm(config.hidden_size, eps=config.rms_norm_eps)
+ if not self.attention_config.replace_with_linear:
+ self.self_attn = DECILM_ATTENTION_CLASSES[config._attn_implementation](
+ config=config, attention_config=self.attention_config, layer_idx=layer_idx)
+ else:
+ self.self_attn = DeciLMLinearAttention(config)
+
+ if not self.ffn_config.no_op:
+ self.post_attention_layernorm = DeciLMRMSNorm(config.hidden_size, eps=config.rms_norm_eps)
+ if not self.ffn_config.replace_with_linear:
+ self.mlp = DeciLMMLP(config, self.ffn_config)
+ else:
+ self.mlp = DeciLMLinearMLP(config)
+
+ self.is_sliding = self.attention_config.is_sliding
+ self.sliding_window = self.attention_config.prefill_sliding_window
+
+ def forward(
+ self,
+ hidden_states: torch.Tensor,
+ attention_mask: Optional[torch.Tensor] = None,
+ position_ids: Optional[torch.LongTensor] = None,
+ past_key_value: Optional[Cache] = None,
+ output_attentions: Optional[bool] = False,
+ use_cache: Optional[bool] = False,
+ cache_position: Optional[torch.LongTensor] = None,
+ position_embeddings: Optional[Tuple[torch.Tensor, torch.Tensor]] = None, # will become mandatory in v4.45
+ **kwargs,
+ ) -> Tuple[torch.FloatTensor, Optional[Tuple[torch.FloatTensor, torch.FloatTensor]]]:
+ """
+ Args:
+ hidden_states (`torch.FloatTensor`): input to the layer of shape `(batch, seq_len, embed_dim)`
+ attention_mask (`torch.FloatTensor`, *optional*):
+ attention mask of size `(batch_size, sequence_length)` if flash attention is used or `(batch_size, 1,
+ query_sequence_length, key_sequence_length)` if default attention is used.
+ output_attentions (`bool`, *optional*):
+ Whether or not to return the attentions tensors of all attention layers. See `attentions` under
+ returned tensors for more detail.
+ use_cache (`bool`, *optional*):
+ If set to `True`, `past_key_values` key value states are returned and can be used to speed up decoding
+ (see `past_key_values`).
+ past_key_value (`Tuple(torch.FloatTensor)`, *optional*): cached past key and value projection states
+ cache_position (`torch.LongTensor` of shape `(sequence_length)`, *optional*):
+ Indices depicting the position of the input sequence tokens in the sequence
+ position_embeddings (`Tuple[torch.FloatTensor, torch.FloatTensor]`, *optional*):
+ Tuple containing the cosine and sine positional embeddings of shape `(batch_size, seq_len, head_dim)`,
+ with `head_dim` being the embedding dimension of each attention head.
+ kwargs (`dict`, *optional*):
+ Arbitrary kwargs to be ignored, used for FSDP and other methods that inject code
+ into the model.
+ """
+ if self.attention_config.unshifted_sink and self.attention_config.is_sink:
+ attention_mask = self._unshifted_sink_mask(
+ attention_mask, hidden_states,
+ self.attention_config.window_length, self.attention_config.num_sink_tokens)
+ else:
+ attention_mask = self._gemma2_window_mask(attention_mask, hidden_states, past_key_value)
+
+ self_attn_weights = None
+ present_key_value = past_key_value
+ if self.attention_config.no_op:
+ pass
+ elif self.attention_config.replace_with_linear:
+ residual = hidden_states
+ hidden_states = self.input_layernorm(hidden_states)
+ hidden_states = self.self_attn(hidden_states)
+ hidden_states = residual + hidden_states
+ else:
+ residual = hidden_states
+ hidden_states = self.input_layernorm(hidden_states)
+ hidden_states, self_attn_weights, present_key_value = self.self_attn(
+ hidden_states=hidden_states,
+ attention_mask=attention_mask,
+ position_ids=position_ids,
+ past_key_value=past_key_value,
+ output_attentions=output_attentions,
+ use_cache=use_cache,
+ cache_position=cache_position,
+ position_embeddings=position_embeddings,
+ **kwargs,
+ )
+ hidden_states = residual + hidden_states
+
+ if not self.ffn_config.no_op:
+ residual = hidden_states
+ hidden_states = self.post_attention_layernorm(hidden_states)
+ hidden_states = self.mlp(hidden_states)
+ hidden_states = residual + hidden_states
+
+ outputs = (hidden_states,)
+
+ if output_attentions:
+ outputs += (self_attn_weights,)
+
+ if use_cache:
+ outputs += (present_key_value,)
+
+ return outputs
+
+ def _gemma2_window_mask(self,
+ attention_mask: Optional[torch.Tensor],
+ hidden_states: torch.Tensor,
+ past_key_value: Optional[VariableCache],
+ ) -> Optional[torch.Tensor]:
+ if self.is_sliding and attention_mask is not None: # efficient SDPA and no padding
+ # With flash attention, the attention mask is a 2D tensor
+ if self.config._attn_implementation == "flash_attention_2":
+ if past_key_value is not None: # when decoding
+ attention_mask = attention_mask[:, -self.sliding_window:]
+ else:
+ min_dtype = torch.finfo(hidden_states.dtype).min
+ sliding_window_mask = torch.tril(
+ torch.ones_like(attention_mask, dtype=torch.bool), diagonal=-self.sliding_window
+ )
+ attention_mask = torch.where(sliding_window_mask, min_dtype, attention_mask)
+ if attention_mask.shape[-1] <= 1: # when decoding
+ attention_mask = attention_mask[:, :, :, -self.sliding_window:]
+ return attention_mask
+
+ def _unshifted_sink_mask(self,
+ attention_mask: torch.Tensor,
+ hidden_states: torch.Tensor,
+ window_length: int,
+ num_sink_tokens: Optional[int],
+ ) -> torch.Tensor:
+ assert self.config._attn_implementation == "eager", "Unshifted sink is only supported in 'eager' mode."
+ assert attention_mask is not None, "The attention mask seems to not be prepared"
+
+ attention_mask = attention_mask.clone()
+ min_dtype = torch.finfo(hidden_states.dtype).min
+
+ if window_length == 0:
+ attention_mask = torch.full_like(attention_mask, fill_value=min_dtype)
+ else:
+ query_length = attention_mask.shape[-2]
+ is_decode = (query_length == 1)
+ if is_decode:
+ attention_mask[:, :, :, :-window_length] = min_dtype
+ else:
+ sliding_window_mask = torch.tril(
+ torch.ones_like(attention_mask, dtype=torch.bool), diagonal=-window_length
+ )
+ attention_mask = torch.where(sliding_window_mask, min_dtype, attention_mask)
+
+ attention_mask[:, :, :, :num_sink_tokens] = 0
+ return attention_mask
+
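+# Note on _unshifted_sink_mask above (added for clarity): each query may attend to the first
+# num_sink_tokens keys plus the most recent window_length keys (window_length == 0 keeps only
+# the sink tokens visible); this explicit additive masking is why it requires the eager
+# attention implementation.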
+
+DECILM_START_DOCSTRING = r"""
+ This model inherits from [`PreTrainedModel`]. Check the superclass documentation for the generic methods the
+ library implements for all its models (such as downloading or saving, resizing the input embeddings, pruning heads,
+ etc.)
+
+ This model is also a PyTorch [torch.nn.Module](https://pytorch.org/docs/stable/nn.html#torch.nn.Module) subclass.
+ Use it as a regular PyTorch Module and refer to the PyTorch documentation for all matters related to general usage
+ and behavior.
+
+ Parameters:
+ config ([`DeciLMConfig`]):
+ Model configuration class with all the parameters of the model. Initializing with a config file does not
+ load the weights associated with the model, only the configuration. Check out the
+ [`~PreTrainedModel.from_pretrained`] method to load the model weights.
+"""
+
+
+@add_start_docstrings(
+ "The bare DeciLM Model outputting raw hidden-states without any specific head on top.",
+ DECILM_START_DOCSTRING,
+)
+class DeciLMPreTrainedModel(PreTrainedModel):
+ config_class = DeciLMConfig
+ base_model_prefix = "model"
+ supports_gradient_checkpointing = True
+ _no_split_modules = ["DeciLMDecoderLayer"]
+ _skip_keys_device_placement = ["past_key_values"]
+ _supports_flash_attn_2 = True
+ _supports_sdpa = False
+ _supports_cache_class = True
+ _supports_quantized_cache = False
+ _supports_static_cache = True
+
+ def _init_weights(self, module):
+ std = self.config.initializer_range
+ if isinstance(module, nn.Linear):
+ module.weight.data.normal_(mean=0.0, std=std)
+ if module.bias is not None:
+ module.bias.data.zero_()
+ elif isinstance(module, nn.Embedding):
+ module.weight.data.normal_(mean=0.0, std=std)
+ if module.padding_idx is not None:
+ module.weight.data[module.padding_idx].zero_()
+
+ def _prepare_generation_config(
+ self,
+ generation_config: Optional[GenerationConfig],
+ *args,
+ **kwargs,
+ ) -> tuple[GenerationConfig, dict]:
+ # DeciLM-specific code
+ generation_config, model_kwargs = super()._prepare_generation_config(generation_config, *args, **kwargs)
+ generation_config.cache_implementation = "variable"
+ NEED_SETUP_CACHE_CLASSES_MAPPING["variable"] = VariableCache
+ return generation_config, model_kwargs
+
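+# Note (added for clarity): forcing cache_implementation="variable" makes generate() build a
+# VariableCache (registered in _prepare_generation_config above), which DeciLM requires instead
+# of the usual cache classes, presumably to accommodate the per-layer heterogeneity of its
+# blocks (different KV head counts, no-op or linear attention subblocks).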
+
+DECILM_INPUTS_DOCSTRING = r"""
+ Args:
+ input_ids (`torch.LongTensor` of shape `(batch_size, sequence_length)`):
+ Indices of input sequence tokens in the vocabulary. Padding will be ignored by default should you provide
+ it.
+
+ Indices can be obtained using [`AutoTokenizer`]. See [`PreTrainedTokenizer.encode`] and
+ [`PreTrainedTokenizer.__call__`] for details.
+
+ [What are input IDs?](../glossary#input-ids)
+ attention_mask (`torch.Tensor` of shape `(batch_size, sequence_length)`, *optional*):
+ Mask to avoid performing attention on padding token indices. Mask values selected in `[0, 1]`:
+
+ - 1 for tokens that are **not masked**,
+ - 0 for tokens that are **masked**.
+
+ [What are attention masks?](../glossary#attention-mask)
+
+ Indices can be obtained using [`AutoTokenizer`]. See [`PreTrainedTokenizer.encode`] and
+ [`PreTrainedTokenizer.__call__`] for details.
+
+ If `past_key_values` is used, optionally only the last `input_ids` have to be input (see
+ `past_key_values`).
+
+ If you want to change padding behavior, you should read [`modeling_opt._prepare_decoder_attention_mask`]
+ and modify to your needs. See diagram 1 in [the paper](https://arxiv.org/abs/1910.13461) for more
+ information on the default strategy.
+
+ - 1 indicates the head is **not masked**,
+ - 0 indicates the head is **masked**.
+ position_ids (`torch.LongTensor` of shape `(batch_size, sequence_length)`, *optional*):
+ Indices of positions of each input sequence tokens in the position embeddings. Selected in the range `[0,
+ config.n_positions - 1]`.
+
+ [What are position IDs?](../glossary#position-ids)
+ past_key_values (`VariableCache`, *optional*):
+ Pre-computed hidden-states (keys and values in the self-attention blocks and in the cross-attention
+ blocks) that can be used to speed up sequential decoding. This typically consists of the `past_key_values`
+ returned by the model at a previous stage of decoding, when `use_cache=True` or `config.use_cache=True`.
+
+ If passed to the forward function, past_key_values must be a VariableCache object (see imports).
+ For generation purposes, this is already handled inside model.generate().
+
+ If `past_key_values` are used, the user can optionally input only the last `input_ids` (those that don't
+ have their past key value states given to this model) of shape `(batch_size, 1)` instead of all `input_ids`
+ of shape `(batch_size, sequence_length)`.
+ inputs_embeds (`torch.FloatTensor` of shape `(batch_size, sequence_length, hidden_size)`, *optional*):
+ Optionally, instead of passing `input_ids` you can choose to directly pass an embedded representation. This
+ is useful if you want more control over how to convert `input_ids` indices into associated vectors than the
+ model's internal embedding lookup matrix.
+ use_cache (`bool`, *optional*):
+ If set to `True`, `past_key_values` key value states are returned and can be used to speed up decoding (see
+ `past_key_values`).
+ output_attentions (`bool`, *optional*):
+ Whether or not to return the attentions tensors of all attention layers. See `attentions` under returned
+ tensors for more detail.
+ output_hidden_states (`bool`, *optional*):
+ Whether or not to return the hidden states of all layers. See `hidden_states` under returned tensors for
+ more detail.
+ return_dict (`bool`, *optional*):
+ Whether or not to return a [`~utils.ModelOutput`] instead of a plain tuple.
+ cache_position (`torch.LongTensor` of shape `(sequence_length)`, *optional*):
+ Indices depicting the position of the input sequence tokens in the sequence. Unlike `position_ids`,
+ this tensor is not affected by padding. It is used to update the cache in the correct position and to infer
+ the complete sequence length.
+"""
+
+
+@add_start_docstrings(
+ "The bare DeciLM Model outputting raw hidden-states without any specific head on top.",
+ DECILM_START_DOCSTRING,
+)
+class DeciLMModel(DeciLMPreTrainedModel):
+ """
+ Transformer decoder consisting of *config.num_hidden_layers* layers. Each layer is a [`DeciLMDecoderLayer`]
+
+ Args:
+ config: DeciLMConfig
+ """
+
+ def __init__(self, config: DeciLMConfig):
+ super().__init__(config)
+ self.padding_idx = config.pad_token_id
+ self.vocab_size = config.vocab_size
+
+ self.embed_tokens = nn.Embedding(config.vocab_size, config.hidden_size, self.padding_idx)
+ self.layers = nn.ModuleList(
+ [DeciLMDecoderLayer(config, layer_idx) for layer_idx in range(config.num_hidden_layers)]
+ )
+ self.norm = DeciLMRMSNorm(config.hidden_size, eps=config.rms_norm_eps)
+ self.rotary_emb = DeciLMRotaryEmbedding(config=config)
+ self.gradient_checkpointing = False
+
+ # Initialize weights and apply final processing
+ self.post_init()
+
+ def get_input_embeddings(self):
+ return self.embed_tokens
+
+ def set_input_embeddings(self, value):
+ self.embed_tokens = value
+
+ @add_start_docstrings_to_model_forward(DECILM_INPUTS_DOCSTRING)
+ def forward(
+ self,
+ input_ids: torch.LongTensor = None,
+ attention_mask: Optional[torch.Tensor] = None,
+ position_ids: Optional[torch.LongTensor] = None,
+ past_key_values: Optional[Union[Cache, List[torch.FloatTensor]]] = None,
+ inputs_embeds: Optional[torch.FloatTensor] = None,
+ use_cache: Optional[bool] = None,
+ output_attentions: Optional[bool] = None,
+ output_hidden_states: Optional[bool] = None,
+ return_dict: Optional[bool] = None,
+ cache_position: Optional[torch.LongTensor] = None,
+ ) -> Union[Tuple, BaseModelOutputWithPast]:
+ output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions
+ output_hidden_states = (
+ output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states
+ )
+ use_cache = use_cache if use_cache is not None else self.config.use_cache
+ return_dict = return_dict if return_dict is not None else self.config.use_return_dict
+
+ if (input_ids is None) ^ (inputs_embeds is not None):
+ raise ValueError(
+ "You cannot specify both input_ids and inputs_embeds at the same time, and must specify either one"
+ )
+
+ if self.gradient_checkpointing and self.training and use_cache:
+ logger.warning_once(
+ "`use_cache=True` is incompatible with gradient checkpointing. Setting `use_cache=False`."
+ )
+ use_cache = False
+
+ if inputs_embeds is None:
+ inputs_embeds = self.embed_tokens(input_ids)
+
+ is_legacy_cache_format = (past_key_values is not None) and not isinstance(past_key_values, Cache)
+ if is_legacy_cache_format:
+ raise NotImplementedError("DeciLMModel does not support legacy cache format, please use a newer "
+ "transformers version or use VariableCache explicitly (see import in this file).")
+
+ if cache_position is None:
+ past_seen_tokens = past_key_values.get_seq_length() if past_key_values is not None else 0
+ cache_position = torch.arange(
+ past_seen_tokens, past_seen_tokens + inputs_embeds.shape[1], device=inputs_embeds.device
+ )
+ if position_ids is None:
+ position_ids = cache_position.unsqueeze(0)
+
+ causal_mask = self._update_causal_mask(
+ attention_mask, inputs_embeds, cache_position, past_key_values, output_attentions
+ )
+ hidden_states = inputs_embeds
+
+ # create position embeddings to be shared across the decoder layers
+ position_embeddings = self.rotary_emb(hidden_states, position_ids)
+
+ # decoder layers
+ all_hidden_states = () if output_hidden_states else None
+ all_self_attns = () if output_attentions else None
+ next_decoder_cache = None
+
+ for decoder_layer in self.layers:
+ if output_hidden_states:
+ all_hidden_states += (hidden_states,)
+
+ if self.gradient_checkpointing and self.training:
+ layer_outputs = self._gradient_checkpointing_func(
+ decoder_layer.__call__,
+ hidden_states,
+ causal_mask,
+ position_ids,
+ past_key_values,
+ output_attentions,
+ use_cache,
+ cache_position,
+ position_embeddings,
+ )
+ else:
+ layer_outputs = decoder_layer(
+ hidden_states,
+ attention_mask=causal_mask,
+ position_ids=position_ids,
+ past_key_value=past_key_values,
+ output_attentions=output_attentions,
+ use_cache=use_cache,
+ cache_position=cache_position,
+ position_embeddings=position_embeddings,
+ )
+
+ hidden_states = layer_outputs[0]
+
+ if use_cache:
+ next_decoder_cache = layer_outputs[2 if output_attentions else 1]
+
+ if output_attentions:
+ all_self_attns += (layer_outputs[1],)
+
+ hidden_states = self.norm(hidden_states)
+
+ # add hidden states from the last decoder layer
+ if output_hidden_states:
+ all_hidden_states += (hidden_states,)
+
+ next_cache = next_decoder_cache if use_cache else None
+
+ if not return_dict:
+ return tuple(v for v in [hidden_states, next_cache, all_hidden_states, all_self_attns] if v is not None)
+ return BaseModelOutputWithPast(
+ last_hidden_state=hidden_states,
+ past_key_values=next_cache,
+ hidden_states=all_hidden_states,
+ attentions=all_self_attns,
+ )
+
+ def _update_causal_mask(
+ self,
+ attention_mask: torch.Tensor,
+ input_tensor: torch.Tensor,
+ cache_position: torch.Tensor,
+ past_key_values: Cache,
+ output_attentions: bool,
+ ):
+ # TODO: As of torch==2.2.0, the `attention_mask` passed to the model in `generate` is 2D and of dynamic length even when the static
+ # KV cache is used. This is an issue for torch.compile which then recaptures cudagraphs at each decode steps due to the dynamic shapes.
+ # (`recording cudagraph tree for symint key 13`, etc.), which is VERY slow. A workaround is `@torch.compiler.disable`, but this prevents using
+ # `fullgraph=True`. See more context in https://github.com/huggingface/transformers/pull/29114
+
+ if self.config._attn_implementation == "flash_attention_2":
+ if attention_mask is not None and 0.0 in attention_mask:
+ return attention_mask
+ return None
+
+ # For SDPA, when possible, we will rely on its `is_causal` argument instead of its `attn_mask` argument, in
+ # order to dispatch on Flash Attention 2. This feature is not compatible with static cache, as SDPA will fail
+ # to infer the attention mask.
+ past_seen_tokens = past_key_values.get_seq_length() if past_key_values is not None else 0
+ assert not isinstance(past_key_values, StaticCache), "DeciLM does not support StaticCache"
+ using_static_cache = isinstance(past_key_values, StaticCache)
+
+ # When output attentions is True, sdpa implementation's forward method calls the eager implementation's forward
+ if self.config._attn_implementation == "sdpa" and not using_static_cache and not output_attentions:
+ if AttentionMaskConverter._ignore_causal_mask_sdpa(
+ attention_mask,
+ inputs_embeds=input_tensor,
+ past_key_values_length=past_seen_tokens,
+ is_training=self.training,
+ ) and all([not layer.is_sliding for layer in self.layers]):
+ return None
+
+ dtype, device = input_tensor.dtype, input_tensor.device
+ min_dtype = torch.finfo(dtype).min
+ sequence_length = input_tensor.shape[1]
+ if using_static_cache:
+ target_length = past_key_values.get_max_length()
+ else:
+ target_length = (
+ attention_mask.shape[-1]
+ if isinstance(attention_mask, torch.Tensor)
+ else past_seen_tokens + sequence_length + 1
+ )
+
+ # In case the provided `attention_mask` is 2D, we generate a 4D causal mask here.
+ causal_mask = _prepare_4d_causal_attention_mask_with_cache_position(
+ attention_mask,
+ sequence_length=sequence_length,
+ target_length=target_length,
+ dtype=dtype,
+ device=device,
+ min_dtype=min_dtype,
+ cache_position=cache_position,
+ batch_size=input_tensor.shape[0],
+ )
+
+ if (
+ self.config._attn_implementation == "sdpa"
+ and attention_mask is not None
+ and attention_mask.device.type == "cuda"
+ and not output_attentions
+ ):
+ # Attend to all tokens in fully masked rows in the causal_mask, for example the relevant first rows when
+ # using left padding. This is required by F.scaled_dot_product_attention memory-efficient attention path.
+ # Details: https://github.com/pytorch/pytorch/issues/110213
+ causal_mask = AttentionMaskConverter._unmask_unattended(causal_mask, min_dtype)
+
+ return causal_mask
+
+
+class DeciLMForCausalLM(DeciLMPreTrainedModel, GenerationMixin):
+ _tied_weights_keys = ["lm_head.weight"]
+
+ def __init__(self, config):
+ super().__init__(config)
+ self.model = DeciLMModel(config)
+ self.vocab_size = config.vocab_size
+ self.lm_head = nn.Linear(config.hidden_size, config.vocab_size, bias=False)
+
+ # Initialize weights and apply final processing
+ self.post_init()
+
+ def get_input_embeddings(self):
+ return self.model.embed_tokens
+
+ def set_input_embeddings(self, value):
+ self.model.embed_tokens = value
+
+ def get_output_embeddings(self):
+ return self.lm_head
+
+ def set_output_embeddings(self, new_embeddings):
+ self.lm_head = new_embeddings
+
+ def set_decoder(self, decoder):
+ self.model = decoder
+
+ def get_decoder(self):
+ return self.model
+
+ @add_start_docstrings_to_model_forward(DECILM_INPUTS_DOCSTRING)
+ @replace_return_docstrings(output_type=CausalLMOutputWithPast, config_class=_CONFIG_FOR_DOC)
+ def forward(
+ self,
+ input_ids: torch.LongTensor = None,
+ attention_mask: Optional[torch.Tensor] = None,
+ position_ids: Optional[torch.LongTensor] = None,
+ past_key_values: Optional[Union[Cache, List[torch.FloatTensor]]] = None,
+ inputs_embeds: Optional[torch.FloatTensor] = None,
+ labels: Optional[torch.LongTensor] = None,
+ use_cache: Optional[bool] = None,
+ output_attentions: Optional[bool] = None,
+ output_hidden_states: Optional[bool] = None,
+ return_dict: Optional[bool] = None,
+ cache_position: Optional[torch.LongTensor] = None,
+ ) -> Union[Tuple, CausalLMOutputWithPast]:
+ r"""
+ Args:
+ labels (`torch.LongTensor` of shape `(batch_size, sequence_length)`, *optional*):
+ Labels for computing the masked language modeling loss. Indices should either be in `[0, ...,
+ config.vocab_size]` or -100 (see `input_ids` docstring). Tokens with indices set to `-100` are ignored
+ (masked), the loss is only computed for the tokens with labels in `[0, ..., config.vocab_size]`.
+
+ Return:
+ """
+ output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions
+ output_hidden_states = (
+ output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states
+ )
+ return_dict = return_dict if return_dict is not None else self.config.use_return_dict
+
+ # decoder outputs consist of (dec_features, layer_state, dec_hidden, dec_attn)
+ outputs = self.model(
+ input_ids=input_ids,
+ attention_mask=attention_mask,
+ position_ids=position_ids,
+ past_key_values=past_key_values,
+ inputs_embeds=inputs_embeds,
+ use_cache=use_cache,
+ output_attentions=output_attentions,
+ output_hidden_states=output_hidden_states,
+ return_dict=return_dict,
+ cache_position=cache_position,
+ )
+
+ hidden_states = outputs[0]
+ if self.config.pretraining_tp > 1:
+ lm_head_slices = self.lm_head.weight.split(self.vocab_size // self.config.pretraining_tp, dim=0)
+ logits = [F.linear(hidden_states, lm_head_slices[i]) for i in range(self.config.pretraining_tp)]
+ logits = torch.cat(logits, dim=-1)
+ else:
+ logits = self.lm_head(hidden_states)
+ logits = logits.float()
+
+ loss = None
+ if labels is not None:
+ # Shift so that tokens < n predict n
+ shift_logits = logits[..., :-1, :].contiguous()
+ shift_labels = labels[..., 1:].contiguous()
+ # Flatten the tokens
+ loss_fct = CrossEntropyLoss()
+ shift_logits = shift_logits.view(-1, self.config.vocab_size)
+ shift_labels = shift_labels.view(-1)
+ # Enable model parallelism
+ shift_labels = shift_labels.to(shift_logits.device)
+ loss = loss_fct(shift_logits, shift_labels)
+
+ if not return_dict:
+ output = (logits,) + outputs[1:]
+ return (loss,) + output if loss is not None else output
+
+ return CausalLMOutputWithPast(
+ loss=loss,
+ logits=logits,
+ past_key_values=outputs.past_key_values,
+ hidden_states=outputs.hidden_states,
+ attentions=outputs.attentions,
+ )
+
+ def prepare_inputs_for_generation(
+ self,
+ input_ids,
+ past_key_values=None,
+ attention_mask=None,
+ inputs_embeds=None,
+ cache_position=None,
+ position_ids=None,
+ use_cache=True,
+ **kwargs,
+ ):
+ # If we have cache: let's slice `input_ids` through `cache_position`, to keep only the unprocessed tokens
+ # Exception 1: when passing input_embeds, input_ids may be missing entries
+ # Exception 2: some generation methods do special slicing of input_ids, so we don't need to do it here
+ if past_key_values is not None:
+ if inputs_embeds is not None: # Exception 1
+ input_ids = input_ids[:, -cache_position.shape[0]:]
+ elif input_ids.shape[1] != cache_position.shape[0]: # Default case (the "else", a no op, is Exception 2)
+ input_ids = input_ids[:, cache_position]
+
+ if attention_mask is not None and position_ids is None:
+ # create position_ids on the fly for batch generation
+ position_ids = attention_mask.long().cumsum(-1) - 1
+ position_ids.masked_fill_(attention_mask == 0, 1)
+ if past_key_values:
+ position_ids = position_ids[:, -input_ids.shape[1]:]
+
+ # This `clone` call is needed to avoid recapturing cuda graphs with `torch.compile`'s `mode="reduce-overhead"`, as otherwise the input `position_ids` would have varying strides during decoding. Here, simply using `.contiguous()` is not sufficient: in the batch size = 1 case, `position_ids` is already contiguous but with varying stride, which retriggers a capture.
+ position_ids = position_ids.clone(memory_format=torch.contiguous_format)
+
+ # if `inputs_embeds` are passed, we only want to use them in the 1st generation step
+ if inputs_embeds is not None and cache_position[0] == 0:
+ model_inputs = {"inputs_embeds": inputs_embeds, "input_ids": None}
+ else:
+ # The clone here is for the same reason as for `position_ids`.
+ model_inputs = {"input_ids": input_ids.clone(memory_format=torch.contiguous_format), "inputs_embeds": None}
+
+ assert not isinstance(past_key_values, StaticCache), "DeciLM does not support StaticCache"
+ if isinstance(past_key_values, StaticCache) and attention_mask.ndim == 2:
+ if model_inputs["inputs_embeds"] is not None:
+ batch_size, sequence_length, _ = model_inputs["inputs_embeds"].shape
+ device = model_inputs["inputs_embeds"].device
+ else:
+ batch_size, sequence_length = model_inputs["input_ids"].shape
+ device = model_inputs["input_ids"].device
+
+ dtype = self.lm_head.weight.dtype
+ min_dtype = torch.finfo(dtype).min
+
+ attention_mask = _prepare_4d_causal_attention_mask_with_cache_position(
+ attention_mask,
+ sequence_length=sequence_length,
+ target_length=past_key_values.get_max_length(),
+ dtype=dtype,
+ device=device,
+ min_dtype=min_dtype,
+ cache_position=cache_position,
+ batch_size=batch_size,
+ )
+
+ model_inputs.update(
+ {
+ "position_ids": position_ids,
+ "cache_position": cache_position,
+ "past_key_values": past_key_values,
+ "use_cache": use_cache,
+ "attention_mask": attention_mask,
+ }
+ )
+ return model_inputs
+
+ def _maybe_initialize_input_ids_for_generation(
+ self,
+ inputs: Optional[torch.Tensor] = None,
+ bos_token_id: Optional[torch.Tensor] = None,
+ model_kwargs: Optional[dict[str, torch.Tensor]] = None,
+ ) -> torch.LongTensor:
+ """
+ Patches an HF bug that creates a wrong cache length when only inputs_embeds are passed to the model.
+ """
+ input_ids = super()._maybe_initialize_input_ids_for_generation(
+ inputs=inputs, bos_token_id=bos_token_id, model_kwargs=model_kwargs)
+ if (
+ "inputs_embeds" in model_kwargs
+ and input_ids is not None
+ and input_ids.shape[1] == 0
+ ):
+ batch_size, input_sequence_length = model_kwargs["inputs_embeds"].shape[:2]
+ input_ids = torch.zeros((batch_size, input_sequence_length), dtype=torch.long, device=self.device)
+ return input_ids
+
+ def generate(
+ self,
+ inputs: Optional[torch.Tensor] = None,
+ *args,
+ **kwargs,
+ ) -> Union[GenerateOutput, torch.LongTensor]:
+ """
+ Patches an HF bug that creates a wrong cache length when only inputs_embeds are passed to the model.
+ """
+ only_passed_inputs_embeds = (
+ "inputs_embeds" in kwargs and
+ "input_ids" not in kwargs and
+ inputs is None
+ )
+ if only_passed_inputs_embeds:
+ input_sequence_length = kwargs["inputs_embeds"].shape[1]
+
+ generation_output = super().generate(inputs=inputs, *args, **kwargs)
+
+ if only_passed_inputs_embeds and isinstance(generation_output, torch.Tensor):
+ generation_output = generation_output[:, input_sequence_length:]
+
+ return generation_output
+
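+# Usage sketch (illustrative): calling model.generate(inputs_embeds=embeds, ...) without
+# input_ids returns only the newly generated token ids, because the placeholder prompt ids
+# created in _maybe_initialize_input_ids_for_generation are stripped from the output above.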
+
+@add_start_docstrings(
+ """
+ The DeciLM Model transformer with a sequence classification head on top (linear layer).
+
+ [`DeciLMForSequenceClassification`] uses the last token in order to do the classification, as other causal models
+ (e.g. GPT-2) do.
+
+ Since it does classification on the last token, it requires to know the position of the last token. If a
+ `pad_token_id` is defined in the configuration, it finds the last token that is not a padding token in each row. If
+ no `pad_token_id` is defined, it simply takes the last value in each row of the batch. Since it cannot guess the
+ padding tokens when `inputs_embeds` are passed instead of `input_ids`, it does the same (take the last value in
+ each row of the batch).
+ """,
+ DECILM_START_DOCSTRING,
+)
+class DeciLMForSequenceClassification(DeciLMPreTrainedModel):
+ def __init__(self, config):
+ super().__init__(config)
+ self.num_labels = config.num_labels
+ self.model = DeciLMModel(config)
+ self.score = nn.Linear(config.hidden_size, self.num_labels, bias=False)
+
+ # Initialize weights and apply final processing
+ self.post_init()
+
+ def get_input_embeddings(self):
+ return self.model.embed_tokens
+
+ def set_input_embeddings(self, value):
+ self.model.embed_tokens = value
+
+ @add_start_docstrings_to_model_forward(DECILM_INPUTS_DOCSTRING)
+ def forward(
+ self,
+ input_ids: Optional[torch.LongTensor] = None,
+ attention_mask: Optional[torch.Tensor] = None,
+ position_ids: Optional[torch.LongTensor] = None,
+ past_key_values: Optional[Union[Cache, List[torch.FloatTensor]]] = None,
+ inputs_embeds: Optional[torch.FloatTensor] = None,
+ labels: Optional[torch.LongTensor] = None,
+ use_cache: Optional[bool] = None,
+ output_attentions: Optional[bool] = None,
+ output_hidden_states: Optional[bool] = None,
+ return_dict: Optional[bool] = None,
+ ) -> Union[Tuple, SequenceClassifierOutputWithPast]:
+ r"""
+ labels (`torch.LongTensor` of shape `(batch_size,)`, *optional*):
+ Labels for computing the sequence classification/regression loss. Indices should be in `[0, ...,
+ config.num_labels - 1]`. If `config.num_labels == 1` a regression loss is computed (Mean-Square loss). If
+ `config.num_labels > 1` a classification loss is computed (Cross-Entropy).
+ """
+ return_dict = return_dict if return_dict is not None else self.config.use_return_dict
+
+ transformer_outputs = self.model(
+ input_ids,
+ attention_mask=attention_mask,
+ position_ids=position_ids,
+ past_key_values=past_key_values,
+ inputs_embeds=inputs_embeds,
+ use_cache=use_cache,
+ output_attentions=output_attentions,
+ output_hidden_states=output_hidden_states,
+ return_dict=return_dict,
+ )
+ hidden_states = transformer_outputs[0]
+ logits = self.score(hidden_states)
+
+ if input_ids is not None:
+ batch_size = input_ids.shape[0]
+ else:
+ batch_size = inputs_embeds.shape[0]
+
+ if self.config.pad_token_id is None and batch_size != 1:
+ raise ValueError("Cannot handle batch sizes > 1 if no padding token is defined.")
+ if self.config.pad_token_id is None:
+ sequence_lengths = -1
+ else:
+ if input_ids is not None:
+ # if no pad token found, use modulo instead of reverse indexing for ONNX compatibility
+ sequence_lengths = torch.eq(input_ids, self.config.pad_token_id).int().argmax(-1) - 1
+ sequence_lengths = sequence_lengths % input_ids.shape[-1]
+ sequence_lengths = sequence_lengths.to(logits.device)
+ else:
+ sequence_lengths = -1
+
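+ # Pool by taking the logits at the last non-padding token of each sequence (or simply the
+ # last position when no pad_token_id is defined or inputs_embeds were passed).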
+ pooled_logits = logits[torch.arange(batch_size, device=logits.device), sequence_lengths]
+
+ loss = None
+ if labels is not None:
+ labels = labels.to(logits.device)
+ if self.config.problem_type is None:
+ if self.num_labels == 1:
+ self.config.problem_type = "regression"
+ elif self.num_labels > 1 and (labels.dtype == torch.long or labels.dtype == torch.int):
+ self.config.problem_type = "single_label_classification"
+ else:
+ self.config.problem_type = "multi_label_classification"
+
+ if self.config.problem_type == "regression":
+ loss_fct = MSELoss()
+ if self.num_labels == 1:
+ loss = loss_fct(pooled_logits.squeeze(), labels.squeeze())
+ else:
+ loss = loss_fct(pooled_logits, labels)
+ elif self.config.problem_type == "single_label_classification":
+ loss_fct = CrossEntropyLoss()
+ loss = loss_fct(pooled_logits.view(-1, self.num_labels), labels.view(-1))
+ elif self.config.problem_type == "multi_label_classification":
+ loss_fct = BCEWithLogitsLoss()
+ loss = loss_fct(pooled_logits, labels)
+ if not return_dict:
+ output = (pooled_logits,) + transformer_outputs[1:]
+ return ((loss,) + output) if loss is not None else output
+
+ return SequenceClassifierOutputWithPast(
+ loss=loss,
+ logits=pooled_logits,
+ past_key_values=transformer_outputs.past_key_values,
+ hidden_states=transformer_outputs.hidden_states,
+ attentions=transformer_outputs.attentions,
+ )
+
+
+@add_start_docstrings(
+ """
+The DeciLM Model transformer with a span classification head on top for extractive question-answering tasks like
+SQuAD (a linear layer on top of the hidden-states output to compute `span start logits` and `span end logits`).
+ """,
+ DECILM_START_DOCSTRING,
+)
+class DeciLMForQuestionAnswering(DeciLMPreTrainedModel):
+ base_model_prefix = "transformer"
+
+ # Copied from transformers.models.bloom.modeling_bloom.BloomForQuestionAnswering.__init__ with Bloom->DeciLM
+ def __init__(self, config):
+ super().__init__(config)
+ self.transformer = DeciLMModel(config)
+ self.qa_outputs = nn.Linear(config.hidden_size, 2)
+
+ # Initialize weights and apply final processing
+ self.post_init()
+
+ def get_input_embeddings(self):
+ return self.transformer.embed_tokens
+
+ def set_input_embeddings(self, value):
+ self.transformer.embed_tokens = value
+
+ @add_start_docstrings_to_model_forward(DECILM_INPUTS_DOCSTRING)
+ def forward(
+ self,
+ input_ids: Optional[torch.LongTensor] = None,
+ attention_mask: Optional[torch.FloatTensor] = None,
+ position_ids: Optional[torch.LongTensor] = None,
+ past_key_values: Optional[Union[Cache, List[torch.FloatTensor]]] = None,
+ inputs_embeds: Optional[torch.FloatTensor] = None,
+ start_positions: Optional[torch.LongTensor] = None,
+ end_positions: Optional[torch.LongTensor] = None,
+ output_attentions: Optional[bool] = None,
+ output_hidden_states: Optional[bool] = None,
+ return_dict: Optional[bool] = None,
+ ) -> Union[Tuple, QuestionAnsweringModelOutput]:
+ r"""
+ start_positions (`torch.LongTensor` of shape `(batch_size,)`, *optional*):
+ Labels for position (index) of the start of the labelled span for computing the token classification loss.
+ Positions are clamped to the length of the sequence (`sequence_length`). Positions outside of the sequence
+ are not taken into account for computing the loss.
+ end_positions (`torch.LongTensor` of shape `(batch_size,)`, *optional*):
+ Labels for position (index) of the end of the labelled span for computing the token classification loss.
+ Positions are clamped to the length of the sequence (`sequence_length`). Positions outside of the sequence
+ are not taken into account for computing the loss.
+ """
+ return_dict = return_dict if return_dict is not None else self.config.use_return_dict
+
+ outputs = self.transformer(
+ input_ids,
+ attention_mask=attention_mask,
+ position_ids=position_ids,
+ past_key_values=past_key_values,
+ inputs_embeds=inputs_embeds,
+ output_attentions=output_attentions,
+ output_hidden_states=output_hidden_states,
+ return_dict=return_dict,
+ )
+
+ sequence_output = outputs[0]
+
+ logits = self.qa_outputs(sequence_output)
+ start_logits, end_logits = logits.split(1, dim=-1)
+ start_logits = start_logits.squeeze(-1).contiguous()
+ end_logits = end_logits.squeeze(-1).contiguous()
+
+ total_loss = None
+ if start_positions is not None and end_positions is not None:
+ # If we are on multi-GPU, splitting adds a dimension
+ if len(start_positions.size()) > 1:
+ start_positions = start_positions.squeeze(-1).to(start_logits.device)
+ if len(end_positions.size()) > 1:
+ end_positions = end_positions.squeeze(-1).to(end_logits.device)
+ # sometimes the start/end positions are outside our model inputs, we ignore these terms
+ ignored_index = start_logits.size(1)
+ start_positions = start_positions.clamp(0, ignored_index)
+ end_positions = end_positions.clamp(0, ignored_index)
+
+ loss_fct = CrossEntropyLoss(ignore_index=ignored_index)
+ start_loss = loss_fct(start_logits, start_positions)
+ end_loss = loss_fct(end_logits, end_positions)
+ total_loss = (start_loss + end_loss) / 2
+
+ if not return_dict:
+ output = (start_logits, end_logits) + outputs[2:]
+ return ((total_loss,) + output) if total_loss is not None else output
+
+ return QuestionAnsweringModelOutput(
+ loss=total_loss,
+ start_logits=start_logits,
+ end_logits=end_logits,
+ hidden_states=outputs.hidden_states,
+ attentions=outputs.attentions,
+ )
+
+
+@add_start_docstrings(
+ """
+ The DeciLM Model transformer with a token classification head on top (a linear layer on top of the hidden-states
+ output) e.g. for Named-Entity-Recognition (NER) tasks.
+ """,
+ DECILM_START_DOCSTRING,
+)
+class DeciLMForTokenClassification(DeciLMPreTrainedModel):
+ def __init__(self, config):
+ super().__init__(config)
+ self.num_labels = config.num_labels
+ self.model = DeciLMModel(config)
+ if getattr(config, "classifier_dropout", None) is not None:
+ classifier_dropout = config.classifier_dropout
+ elif getattr(config, "hidden_dropout", None) is not None:
+ classifier_dropout = config.hidden_dropout
+ else:
+ classifier_dropout = 0.1
+ self.dropout = nn.Dropout(classifier_dropout)
+ self.score = nn.Linear(config.hidden_size, config.num_labels)
+
+ # Initialize weights and apply final processing
+ self.post_init()
+
+ def get_input_embeddings(self):
+ return self.model.embed_tokens
+
+ def set_input_embeddings(self, value):
+ self.model.embed_tokens = value
+
+ @add_start_docstrings_to_model_forward(DECILM_INPUTS_DOCSTRING)
+ def forward(
+ self,
+ input_ids: Optional[torch.LongTensor] = None,
+ attention_mask: Optional[torch.Tensor] = None,
+ position_ids: Optional[torch.LongTensor] = None,
+ past_key_values: Optional[List[torch.FloatTensor]] = None,
+ inputs_embeds: Optional[torch.FloatTensor] = None,
+ labels: Optional[torch.LongTensor] = None,
+ use_cache: Optional[bool] = None,
+ output_attentions: Optional[bool] = None,
+ output_hidden_states: Optional[bool] = None,
+ return_dict: Optional[bool] = None,
+ ) -> Union[Tuple, TokenClassifierOutput]:
+ r"""
+ labels (`torch.LongTensor` of shape `(batch_size, sequence_length)`, *optional*):
+ Labels for computing the token classification loss. Indices should be in `[0, ...,
+ config.num_labels - 1]`.
+ """
+ return_dict = return_dict if return_dict is not None else self.config.use_return_dict
+
+ outputs = self.model(
+ input_ids,
+ attention_mask=attention_mask,
+ position_ids=position_ids,
+ past_key_values=past_key_values,
+ inputs_embeds=inputs_embeds,
+ use_cache=use_cache,
+ output_attentions=output_attentions,
+ output_hidden_states=output_hidden_states,
+ return_dict=return_dict,
+ )
+ sequence_output = outputs[0]
+ sequence_output = self.dropout(sequence_output)
+ logits = self.score(sequence_output)
+
+ loss = None
+ if labels is not None:
+ loss_fct = CrossEntropyLoss()
+ loss = loss_fct(logits.view(-1, self.num_labels), labels.view(-1))
+
+ if not return_dict:
+ output = (logits,) + outputs[2:]
+ return ((loss,) + output) if loss is not None else output
+
+ return TokenClassifierOutput(
+ loss=loss,
+ logits=logits,
+ hidden_states=outputs.hidden_states,
+ attentions=outputs.attentions,
+ )
+
+
+########################################################################
+# DeciLM-specific code
+########################################################################
+
+
+def _ffn_mult_to_intermediate_size(ffn_mult: float, n_embd: int) -> int:
+ # DeciLM-specific code
+ intermediate_size = int(2 * ffn_mult * n_embd / 3)
+ return _find_multiple(intermediate_size, 256)
+
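+# Worked example (illustrative numbers): ffn_mult=2.0 with hidden_size=4096 gives
+# int(2 * 2.0 * 4096 / 3) = 5461, which _find_multiple below rounds up to 5632, the next
+# multiple of 256.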
+
+def _find_multiple(n: int, k: int) -> int:
+ # DeciLM-specific code
+ if n % k == 0:
+ return n
+ return n + k - (n % k)
+
+
+class DeciLMLinearMLP(nn.Module):
+ # DeciLM-specific code
+ def __init__(self,
+ config: DeciLMConfig,
+ ):
+ super().__init__()
+ self.linear_mlp = nn.Linear(in_features=config.hidden_size,
+ out_features=config.hidden_size,
+ bias=False)
+
+ def forward(self, x: torch.Tensor) -> torch.Tensor:
+ return self.linear_mlp.forward(x)
+
+
+class DeciLMLinearAttention(nn.Module):
+ # DeciLM-specific code
+ def __init__(self,
+ config: DeciLMConfig,
+ ):
+ super().__init__()
+ self.linear_attn = nn.Linear(in_features=config.hidden_size,
+ out_features=config.hidden_size,
+ bias=False)
+
+ def forward(self, x: torch.Tensor) -> torch.Tensor:
+ return self.linear_attn.forward(x)
+
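+# Note (added for clarity): DeciLMLinearMLP and DeciLMLinearAttention implement the
+# `replace_with_linear` subblock option used by DeciLMDecoderLayer: the whole FFN or attention
+# subblock is replaced by a single bias-free hidden_size -> hidden_size projection.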
+
+def sparsity_backward_hook(*args, **kwargs):
+ raise NotImplementedError("No support for sparsity when training HF DeciLM (inference is ok though)")
diff --git a/transformers_4_44_2__activations.py b/transformers_4_44_2__activations.py
new file mode 100644
index 0000000000000000000000000000000000000000..ca0fb16d4236a7599f2523e9e45aa28fc3ac5e69
--- /dev/null
+++ b/transformers_4_44_2__activations.py
@@ -0,0 +1,239 @@
+# Copyright 2020 The HuggingFace Team. All rights reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+import math
+from collections import OrderedDict
+
+import torch
+from packaging import version
+from torch import Tensor, nn
+
+from transformers.utils import logging
+
+
+logger = logging.get_logger(__name__)
+
+
+class PytorchGELUTanh(nn.Module):
+ """
+ A fast C implementation of the tanh approximation of the GeLU activation function. See
+ https://arxiv.org/abs/1606.08415.
+
+ This implementation is equivalent to NewGELU and FastGELU but much faster. However, it is not an exact numerical
+ match due to rounding errors.
+ """
+
+ def __init__(self):
+ super().__init__()
+ if version.parse(torch.__version__) < version.parse("1.12.0"):
+ raise ImportError(
+ f"You are using torch=={torch.__version__}, but torch>=1.12.0 is required to use "
+ "PytorchGELUTanh. Please upgrade torch."
+ )
+
+ def forward(self, input: Tensor) -> Tensor:
+ return nn.functional.gelu(input, approximate="tanh")
+
+
+class NewGELUActivation(nn.Module):
+ """
+ Implementation of the GELU activation function currently in Google BERT repo (identical to OpenAI GPT). Also see
+ the Gaussian Error Linear Units paper: https://arxiv.org/abs/1606.08415
+ """
+
+ def forward(self, input: Tensor) -> Tensor:
+ return 0.5 * input * (1.0 + torch.tanh(math.sqrt(2.0 / math.pi) * (input + 0.044715 * torch.pow(input, 3.0))))
+
+
+class GELUActivation(nn.Module):
+ """
+ Original Implementation of the GELU activation function in Google BERT repo when initially created. For
+ information: OpenAI GPT's GELU is slightly different (and gives slightly different results): 0.5 * x * (1 +
+ torch.tanh(math.sqrt(2 / math.pi) * (x + 0.044715 * torch.pow(x, 3)))) This is now written in C in nn.functional
+ Also see the Gaussian Error Linear Units paper: https://arxiv.org/abs/1606.08415
+ """
+
+ def __init__(self, use_gelu_python: bool = False):
+ super().__init__()
+ if use_gelu_python:
+ self.act = self._gelu_python
+ else:
+ self.act = nn.functional.gelu
+
+ def _gelu_python(self, input: Tensor) -> Tensor:
+ return input * 0.5 * (1.0 + torch.erf(input / math.sqrt(2.0)))
+
+ def forward(self, input: Tensor) -> Tensor:
+ return self.act(input)
+
+
+class FastGELUActivation(nn.Module):
+ """
+ Applies GELU approximation that is slower than QuickGELU but more accurate. See: https://github.com/hendrycks/GELUs
+ """
+
+ def forward(self, input: Tensor) -> Tensor:
+ return 0.5 * input * (1.0 + torch.tanh(input * 0.7978845608 * (1.0 + 0.044715 * input * input)))
+
+
+class QuickGELUActivation(nn.Module):
+ """
+ Applies GELU approximation that is fast but somewhat inaccurate. See: https://github.com/hendrycks/GELUs
+ """
+
+ def forward(self, input: Tensor) -> Tensor:
+ return input * torch.sigmoid(1.702 * input)
+
+
+class ClippedGELUActivation(nn.Module):
+ """
+    Clip the range of possible GeLU outputs between [min, max]. This is especially useful for quantization purposes, as
+    it allows mapping negative values in the GeLU spectrum. For more information on this trick, please refer to
+ https://arxiv.org/abs/2004.09602.
+
+ Gaussian Error Linear Unit. Original Implementation of the gelu activation function in Google Bert repo when
+ initially created.
+
+ For information: OpenAI GPT's gelu is slightly different (and gives slightly different results): 0.5 * x * (1 +
+ torch.tanh(math.sqrt(2 / math.pi) * (x + 0.044715 * torch.pow(x, 3)))). See https://arxiv.org/abs/1606.08415
+ """
+
+ def __init__(self, min: float, max: float):
+ if min > max:
+ raise ValueError(f"min should be < max (got min: {min}, max: {max})")
+
+ super().__init__()
+ self.min = min
+ self.max = max
+
+ def forward(self, x: Tensor) -> Tensor:
+ return torch.clip(gelu(x), self.min, self.max)
+
+
+class AccurateGELUActivation(nn.Module):
+ """
+ Applies GELU approximation that is faster than default and more accurate than QuickGELU. See:
+ https://github.com/hendrycks/GELUs
+
+ Implemented along with MEGA (Moving Average Equipped Gated Attention)
+ """
+
+ def __init__(self):
+ super().__init__()
+ self.precomputed_constant = math.sqrt(2 / math.pi)
+
+ def forward(self, input: Tensor) -> Tensor:
+ return 0.5 * input * (1 + torch.tanh(self.precomputed_constant * (input + 0.044715 * torch.pow(input, 3))))
+
+
+class MishActivation(nn.Module):
+ """
+ See Mish: A Self-Regularized Non-Monotonic Activation Function (Misra., https://arxiv.org/abs/1908.08681). Also
+ visit the official repository for the paper: https://github.com/digantamisra98/Mish
+ """
+
+ def __init__(self):
+ super().__init__()
+ if version.parse(torch.__version__) < version.parse("1.9.0"):
+ self.act = self._mish_python
+ else:
+ self.act = nn.functional.mish
+
+ def _mish_python(self, input: Tensor) -> Tensor:
+ return input * torch.tanh(nn.functional.softplus(input))
+
+ def forward(self, input: Tensor) -> Tensor:
+ return self.act(input)
+
+
+class LinearActivation(nn.Module):
+ """
+ Applies the linear activation function, i.e. forwarding input directly to output.
+ """
+
+ def forward(self, input: Tensor) -> Tensor:
+ return input
+
+
+class LaplaceActivation(nn.Module):
+ """
+ Applies elementwise activation based on Laplace function, introduced in MEGA as an attention activation. See
+ https://arxiv.org/abs/2209.10655
+
+ Inspired by squared relu, but with bounded range and gradient for better stability
+ """
+
+ def forward(self, input, mu=0.707107, sigma=0.282095):
+ input = (input - mu).div(sigma * math.sqrt(2.0))
+ return 0.5 * (1.0 + torch.erf(input))
+
+
+class ReLUSquaredActivation(nn.Module):
+ """
+ Applies the relu^2 activation introduced in https://arxiv.org/abs/2109.08668v2
+ """
+
+ def forward(self, input):
+ relu_applied = nn.functional.relu(input)
+ squared = torch.square(relu_applied)
+ return squared
+
+
+class ClassInstantier(OrderedDict):
+ def __getitem__(self, key):
+ content = super().__getitem__(key)
+ cls, kwargs = content if isinstance(content, tuple) else (content, {})
+ return cls(**kwargs)
+
+
+ACT2CLS = {
+ "gelu": GELUActivation,
+ "gelu_10": (ClippedGELUActivation, {"min": -10, "max": 10}),
+ "gelu_fast": FastGELUActivation,
+ "gelu_new": NewGELUActivation,
+ "gelu_python": (GELUActivation, {"use_gelu_python": True}),
+ "gelu_pytorch_tanh": PytorchGELUTanh,
+ "gelu_accurate": AccurateGELUActivation,
+ "laplace": LaplaceActivation,
+ "leaky_relu": nn.LeakyReLU,
+ "linear": LinearActivation,
+ "mish": MishActivation,
+ "quick_gelu": QuickGELUActivation,
+ "relu": nn.ReLU,
+ "relu2": ReLUSquaredActivation,
+ "relu6": nn.ReLU6,
+ "sigmoid": nn.Sigmoid,
+ "silu": nn.SiLU,
+ "swish": nn.SiLU,
+ "tanh": nn.Tanh,
+}
+ACT2FN = ClassInstantier(ACT2CLS)
+
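+# Illustrative usage (comment added, not part of the original file): indexing ACT2FN
+# instantiates the mapped class on each access, applying any stored kwargs, e.g.
+#   ACT2FN["gelu_10"]  # -> ClippedGELUActivation(min=-10, max=10)
+#   ACT2FN["silu"]     # -> nn.SiLU()
+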
+
+def get_activation(activation_string):
+ if activation_string in ACT2FN:
+ return ACT2FN[activation_string]
+ else:
+ raise KeyError(f"function {activation_string} not found in ACT2FN mapping {list(ACT2FN.keys())}")
+
+
+# For backwards compatibility with: from activations import gelu_python
+gelu_python = get_activation("gelu_python")
+gelu_new = get_activation("gelu_new")
+gelu = get_activation("gelu")
+gelu_fast = get_activation("gelu_fast")
+quick_gelu = get_activation("quick_gelu")
+silu = get_activation("silu")
+mish = get_activation("mish")
+linear_act = get_activation("linear")
diff --git a/transformers_4_44_2__cache_utils.py b/transformers_4_44_2__cache_utils.py
new file mode 100644
index 0000000000000000000000000000000000000000..e1374a2b0ba924e2277ec415b40392efa4100150
--- /dev/null
+++ b/transformers_4_44_2__cache_utils.py
@@ -0,0 +1,1347 @@
+import copy
+import importlib.metadata
+import json
+import os
+from dataclasses import dataclass
+from typing import Any, Dict, List, Optional, Tuple, Union
+
+import torch
+from packaging import version
+
+from transformers.configuration_utils import PretrainedConfig
+from transformers.utils import is_torchdynamo_compiling, logging
+
+
+logger = logging.get_logger(__name__)
+
+
+class Cache(torch.nn.Module):
+ """
+ Base, abstract class for all caches. The actual data structure is specific to each subclass.
+ """
+
+ def __init__(self):
+ super().__init__()
+
+ def update(
+ self,
+ key_states: torch.Tensor,
+ value_states: torch.Tensor,
+ layer_idx: int,
+ cache_kwargs: Optional[Dict[str, Any]] = None,
+ ) -> Tuple[torch.Tensor, torch.Tensor]:
+ """
+ Updates the cache with the new `key_states` and `value_states` for the layer `layer_idx`.
+
+ Parameters:
+ key_states (`torch.Tensor`):
+ The new key states to cache.
+ value_states (`torch.Tensor`):
+ The new value states to cache.
+ layer_idx (`int`):
+ The index of the layer to cache the states for.
+ cache_kwargs (`Dict[str, Any]`, `optional`):
+ Additional arguments for the cache subclass. These are specific to each subclass and allow new types of
+ cache to be created.
+
+ Return:
+ A tuple containing the updated key and value states.
+ """
+ raise NotImplementedError("Make sure to implement `update` in a subclass.")
+
+ def get_seq_length(self, layer_idx: Optional[int] = 0) -> int:
+ """Returns the sequence length of the cached states. A layer index can be optionally passed."""
+ # TODO: deprecate this function in favor of `cache_position`
+ raise NotImplementedError("Make sure to implement `get_seq_length` in a subclass.")
+
+ def get_max_length(self) -> Optional[int]:
+ """Returns the maximum sequence length of the cached states, if there is any."""
+ raise NotImplementedError("Make sure to implement `get_max_length` in a subclass.")
+
+ def get_usable_length(self, new_seq_length: int, layer_idx: Optional[int] = 0) -> int:
+ """Given the sequence length of the new inputs, returns the usable length of the cache."""
+ # Cache without size limit -> all cache is usable
+        # Cache with size limit -> if the cache length plus the length of the new inputs is larger than the maximum
+        # cache length, we will need to evict part of the cache (and thus not all of the cache is usable)
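+        # Illustrative example (comment added, not part of the original file): with
+        # get_max_length() == 8, a cached length of 6 and 3 new tokens, 6 + 3 > 8,
+        # so only 8 - 3 = 5 cached positions remain usable.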
+ max_length = self.get_max_length()
+ previous_seq_length = self.get_seq_length(layer_idx)
+ if max_length is not None and previous_seq_length + new_seq_length > max_length:
+ return max_length - new_seq_length
+ return previous_seq_length
+
+ def reorder_cache(self, beam_idx: torch.LongTensor):
+ """Reorders the cache for beam search, given the selected beam indices."""
+ for layer_idx in range(len(self.key_cache)):
+ device = self.key_cache[layer_idx].device
+ self.key_cache[layer_idx] = self.key_cache[layer_idx].index_select(0, beam_idx.to(device))
+ device = self.value_cache[layer_idx].device
+ self.value_cache[layer_idx] = self.value_cache[layer_idx].index_select(0, beam_idx.to(device))
+
+ @property
+ def seen_tokens(self):
+ logger.warning_once(
+ "The `seen_tokens` attribute is deprecated and will be removed in v4.41. Use the `cache_position` "
+ "model input instead."
+ )
+ if hasattr(self, "_seen_tokens"):
+ return self._seen_tokens
+ else:
+ return None
+
+
+@dataclass
+class CacheConfig:
+ """
+ Base class for cache configs
+ """
+
+ cache_implementation: None
+
+ @classmethod
+ def from_dict(cls, config_dict, **kwargs):
+ """
+ Constructs a CacheConfig instance from a dictionary of parameters.
+ Args:
+ config_dict (Dict[str, Any]): Dictionary containing configuration parameters.
+ **kwargs: Additional keyword arguments to override dictionary values.
+
+ Returns:
+ CacheConfig: Instance of CacheConfig constructed from the dictionary.
+ """
+ config = cls(**config_dict)
+ to_remove = []
+ for key, value in kwargs.items():
+ if hasattr(config, key):
+ setattr(config, key, value)
+ to_remove.append(key)
+ for key in to_remove:
+ kwargs.pop(key, None)
+ return config
+
+ # Copied from transformers.utils.quantization_config.QuantizationConfigMixin.to_json_file
+ def to_json_file(self, json_file_path: Union[str, os.PathLike]):
+ """
+ Save this instance to a JSON file.
+
+ Args:
+ json_file_path (`str` or `os.PathLike`):
+ Path to the JSON file in which this configuration instance's parameters will be saved.
+ use_diff (`bool`, *optional*, defaults to `True`):
+ If set to `True`, only the difference between the config instance and the default
+ `QuantizationConfig()` is serialized to JSON file.
+ """
+ with open(json_file_path, "w", encoding="utf-8") as writer:
+ config_dict = self.to_dict()
+ json_string = json.dumps(config_dict, indent=2, sort_keys=True) + "\n"
+
+ writer.write(json_string)
+
+ # Copied from transformers.utils.quantization_config.QuantizationConfigMixin.to_dict
+ def to_dict(self) -> Dict[str, Any]:
+ """
+ Serializes this instance to a Python dictionary. Returns:
+ `Dict[str, Any]`: Dictionary of all the attributes that make up this configuration instance.
+ """
+ return copy.deepcopy(self.__dict__)
+
+ # Copied from transformers.utils.quantization_config.QuantizationConfigMixin.__iter__
+ def __iter__(self):
+ """allows `dict(obj)` for situations where obj may be a dict or QuantizationConfigMixin"""
+ for attr, value in copy.deepcopy(self.__dict__).items():
+ yield attr, value
+
+ # Copied from transformers.utils.quantization_config.QuantizationConfigMixin.__repr__
+ def __repr__(self):
+ return f"{self.__class__.__name__} {self.to_json_string()}"
+
+ def to_json_string(self):
+ """
+ Serializes this instance to a JSON formatted string.
+ Returns:
+ str: JSON formatted string representing the configuration instance.
+ """
+ return json.dumps(self.__dict__, indent=2) + "\n"
+
+ # Copied from transformers.utils.quantization_config.QuantizationConfigMixin.update
+ def update(self, **kwargs):
+ """
+ Updates attributes of this class instance with attributes from `kwargs` if they match existing attributes,
+ returning all the unused kwargs.
+
+ Args:
+ kwargs (`Dict[str, Any]`):
+ Dictionary of attributes to tentatively update this class.
+
+ Returns:
+ `Dict[str, Any]`: Dictionary containing all the key-value pairs that were not used to update the instance.
+ """
+ to_remove = []
+ for key, value in kwargs.items():
+ if hasattr(self, key):
+ setattr(self, key, value)
+ to_remove.append(key)
+
+ # Remove all the attributes that were updated, without modifying the input dict
+ unused_kwargs = {key: value for key, value in kwargs.items() if key not in to_remove}
+ return unused_kwargs
+
+
+class DynamicCache(Cache):
+ """
+ A cache that grows dynamically as more tokens are generated. This is the default for generative models.
+
+ It stores the Key and Value states as a list of tensors, one for each layer. The expected shape for each tensor is
+ `[batch_size, num_heads, seq_len, head_dim]`.
+
+ Example:
+
+ ```python
+ >>> from transformers import AutoTokenizer, AutoModelForCausalLM, DynamicCache
+
+ >>> model = AutoModelForCausalLM.from_pretrained("openai-community/gpt2")
+ >>> tokenizer = AutoTokenizer.from_pretrained("openai-community/gpt2")
+
+ >>> inputs = tokenizer(text="My name is GPT2", return_tensors="pt")
+
+ >>> # Prepare a cache class and pass it to model's forward
+ >>> past_key_values = DynamicCache()
+ >>> outputs = model(**inputs, past_key_values=past_key_values, use_cache=True)
+ >>> past_kv_length = outputs.past_key_values # access cache filled with key/values from generation
+ ```
+ """
+
+ def __init__(self) -> None:
+ super().__init__()
+ self.key_cache: List[torch.Tensor] = []
+ self.value_cache: List[torch.Tensor] = []
+ self._seen_tokens = 0 # Used in `generate` to keep tally of how many tokens the cache has seen
+
+ def __getitem__(self, layer_idx: int) -> List[Tuple[torch.Tensor]]:
+ """
+ Support for backwards-compatible `past_key_value` indexing, e.g. `past_key_value[0][0].shape[2]` to get the
+ sequence length.
+ """
+ if layer_idx < len(self):
+ return (self.key_cache[layer_idx], self.value_cache[layer_idx])
+ else:
+ raise KeyError(f"Cache only has {len(self)} layers, attempted to access layer with index {layer_idx}")
+
+ def __iter__(self):
+ """
+ Support for backwards-compatible `past_key_value` iteration, e.g. `for x in past_key_value:` to iterate over
+ keys and values
+ """
+ for layer_idx in range(len(self)):
+ yield (self.key_cache[layer_idx], self.value_cache[layer_idx])
+
+ def __len__(self):
+ """
+ Support for backwards-compatible `past_key_value` length, e.g. `len(past_key_value)`. This value corresponds
+ to the number of layers in the model.
+ """
+ return len(self.key_cache)
+
+ def update(
+ self,
+ key_states: torch.Tensor,
+ value_states: torch.Tensor,
+ layer_idx: int,
+ cache_kwargs: Optional[Dict[str, Any]] = None,
+ ) -> Tuple[torch.Tensor, torch.Tensor]:
+ """
+ Updates the cache with the new `key_states` and `value_states` for the layer `layer_idx`.
+
+ Parameters:
+ key_states (`torch.Tensor`):
+ The new key states to cache.
+ value_states (`torch.Tensor`):
+ The new value states to cache.
+ layer_idx (`int`):
+ The index of the layer to cache the states for.
+ cache_kwargs (`Dict[str, Any]`, `optional`):
+ Additional arguments for the cache subclass. No additional arguments are used in `DynamicCache`.
+
+ Return:
+ A tuple containing the updated key and value states.
+ """
+ # Update the number of seen tokens
+ if layer_idx == 0:
+ self._seen_tokens += key_states.shape[-2]
+
+ # Update the cache
+ if len(self.key_cache) <= layer_idx:
+ self.key_cache.append(key_states)
+ self.value_cache.append(value_states)
+ else:
+ self.key_cache[layer_idx] = torch.cat([self.key_cache[layer_idx], key_states], dim=-2)
+ self.value_cache[layer_idx] = torch.cat([self.value_cache[layer_idx], value_states], dim=-2)
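+            # Shape example (comment added, not part of the original file): a cached key of shape
+            # [batch, num_heads, 10, head_dim] concatenated with new keys of shape
+            # [batch, num_heads, 1, head_dim] yields [batch, num_heads, 11, head_dim].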
+
+ return self.key_cache[layer_idx], self.value_cache[layer_idx]
+
+ def get_seq_length(self, layer_idx: Optional[int] = 0) -> int:
+ """Returns the sequence length of the cached states. A layer index can be optionally passed."""
+ # TODO: deprecate this function in favor of `cache_position`
+ if len(self.key_cache) <= layer_idx:
+ return 0
+ return self.key_cache[layer_idx].shape[-2]
+
+ def get_max_length(self) -> Optional[int]:
+ """Returns the maximum sequence length of the cached states. DynamicCache does not have a maximum length."""
+ return None
+
+ def to_legacy_cache(self) -> Tuple[Tuple[torch.Tensor], Tuple[torch.Tensor]]:
+ """Converts the `DynamicCache` instance into the its equivalent in the legacy cache format. Used for
+ backward compatibility."""
+ legacy_cache = ()
+ for layer_idx in range(len(self)):
+ legacy_cache += ((self.key_cache[layer_idx], self.value_cache[layer_idx]),)
+ return legacy_cache
+
+ @classmethod
+ def from_legacy_cache(cls, past_key_values: Optional[Tuple[Tuple[torch.FloatTensor]]] = None) -> "DynamicCache":
+ """Converts a cache in the legacy cache format into an equivalent `DynamicCache`. Used for
+ backward compatibility."""
+ cache = cls()
+ if past_key_values is not None:
+ for layer_idx in range(len(past_key_values)):
+ key_states, value_states = past_key_values[layer_idx]
+ cache.update(key_states, value_states, layer_idx)
+ return cache
+
+ def crop(self, max_length: int):
+ """Crop the past key values up to a new `max_length` in terms of tokens. `max_length` can also be
+ negative to remove `max_length` tokens. This is used in assisted decoding and contrastive search."""
+ # In case it is negative
+ if max_length < 0:
+ max_length = self.get_seq_length() - abs(max_length)
+
+ if self.get_seq_length() <= max_length:
+ return
+
+ self._seen_tokens = max_length
+ for idx in range(len(self.key_cache)):
+ self.key_cache[idx] = self.key_cache[idx][..., :max_length, :]
+ self.value_cache[idx] = self.value_cache[idx][..., :max_length, :]
+
+ def batch_split(self, full_batch_size: int, split_size: int) -> List["DynamicCache"]:
+ """Split the current instance into a list of `DynamicCache` by the batch size. This will be used by
+ `_split_model_inputs()` in `generation.utils`"""
+ out = []
+ for i in range(0, full_batch_size, split_size):
+ current_split = DynamicCache()
+ current_split._seen_tokens = self._seen_tokens
+ current_split.key_cache = [tensor[i : i + split_size] for tensor in self.key_cache]
+ current_split.value_cache = [tensor[i : i + split_size] for tensor in self.value_cache]
+ out.append(current_split)
+ return out
+
+ @classmethod
+ def from_batch_splits(cls, splits: List["DynamicCache"]) -> "DynamicCache":
+ """This is the opposite of the above `batch_split()` method. This will be used by `stack_model_outputs` in
+ `generation.utils`"""
+ cache = cls()
+ for idx in range(len(splits[0])):
+ layer_keys = torch.cat([current.key_cache[idx] for current in splits], dim=0)
+ layer_values = torch.cat([current.value_cache[idx] for current in splits], dim=0)
+ cache.update(layer_keys, layer_values, idx)
+ return cache
+
+ def batch_repeat_interleave(self, repeats: int):
+ """Repeat the cache `repeats` times in the batch dimension. Used in contrastive search."""
+ for layer_idx in range(len(self)):
+ self.key_cache[layer_idx] = self.key_cache[layer_idx].repeat_interleave(repeats, dim=0)
+ self.value_cache[layer_idx] = self.value_cache[layer_idx].repeat_interleave(repeats, dim=0)
+
+ def batch_select_indices(self, indices: torch.Tensor):
+ """Only keep the `indices` in the batch dimension of the cache. Used in contrastive search."""
+ for layer_idx in range(len(self)):
+ self.key_cache[layer_idx] = self.key_cache[layer_idx][indices, ...]
+ self.value_cache[layer_idx] = self.value_cache[layer_idx][indices, ...]
+
+
+class OffloadedCache(DynamicCache):
+ """
+ A drop-in replacement for DynamicCache that conserves GPU memory at the expense of more CPU memory.
+ Useful for generating from models with very long context.
+
+ In addition to the default CUDA stream, where all forward() computations happen,
+ this class uses another stream, the prefetch stream, which it creates itself.
+ Since scheduling of operations on separate streams happens independently, this class uses
+ the prefetch stream to asynchronously prefetch the KV cache of layer k+1 when layer k is executing.
+ The movement of the layer k-1 cache to the CPU is handled by the default stream as a simple way to
+ ensure the eviction is scheduled after all computations on that cache are finished.
+ """
+
+ def __init__(self) -> None:
+ if not torch.cuda.is_available():
+ raise RuntimeError("OffloadedCache can only be used with a GPU")
+ super().__init__()
+ self.original_device = []
+ self.prefetch_stream = torch.cuda.Stream()
+ self.beam_idx = None # used to delay beam search operations
+
+ def prefetch_layer(self, layer_idx: int):
+ "Starts prefetching the next layer cache"
+ if layer_idx < len(self):
+ with torch.cuda.stream(self.prefetch_stream):
+ # Prefetch next layer tensors to GPU
+ device = self.original_device[layer_idx]
+ self.key_cache[layer_idx] = self.key_cache[layer_idx].to(device, non_blocking=True)
+ self.value_cache[layer_idx] = self.value_cache[layer_idx].to(device, non_blocking=True)
+
+ def evict_previous_layer(self, layer_idx: int):
+ "Moves the previous layer cache to the CPU"
+ if len(self) > 2:
+ # We do it on the default stream so it occurs after all earlier computations on these tensors are done
+ prev_layer_idx = (layer_idx - 1) % len(self)
+ self.key_cache[prev_layer_idx] = self.key_cache[prev_layer_idx].to("cpu", non_blocking=True)
+ self.value_cache[prev_layer_idx] = self.value_cache[prev_layer_idx].to("cpu", non_blocking=True)
+
+ def __getitem__(self, layer_idx: int) -> List[Tuple[torch.Tensor]]:
+ "Gets the cache for this layer to the device. Prefetches the next and evicts the previous layer."
+ if layer_idx < len(self):
+ # Evict the previous layer if necessary
+ torch.cuda.current_stream().synchronize()
+ self.evict_previous_layer(layer_idx)
+ # Load current layer cache to its original device if not already there
+ original_device = self.original_device[layer_idx]
+ self.prefetch_stream.synchronize()
+ key_tensor = self.key_cache[layer_idx]
+ value_tensor = self.value_cache[layer_idx]
+ # Now deal with beam search ops which were delayed
+ if self.beam_idx is not None:
+ self.beam_idx = self.beam_idx.to(original_device)
+ key_tensor = key_tensor.index_select(0, self.beam_idx)
+ value_tensor = value_tensor.index_select(0, self.beam_idx)
+ # Prefetch the next layer
+ self.prefetch_layer((layer_idx + 1) % len(self))
+ return (key_tensor, value_tensor)
+ else:
+ raise KeyError(f"Cache only has {len(self)} layers, attempted to access layer with index {layer_idx}")
+
+ def reorder_cache(self, beam_idx: torch.LongTensor):
+ """Saves the beam indices and reorders the cache when the tensor is back to its device."""
+ # We delay this operation until the tensors are back to their original
+ # device because performing torch.index_select on the CPU is very slow
+ del self.beam_idx
+ self.beam_idx = beam_idx.clone()
+
+ def update(
+ self,
+ key_states: torch.Tensor,
+ value_states: torch.Tensor,
+ layer_idx: int,
+ cache_kwargs: Optional[Dict[str, Any]] = None,
+ ) -> Tuple[torch.Tensor, torch.Tensor]:
+ """
+ Updates the cache with the new `key_states` and `value_states` for the layer `layer_idx`.
+ Parameters:
+ key_states (`torch.Tensor`):
+ The new key states to cache.
+ value_states (`torch.Tensor`):
+ The new value states to cache.
+ layer_idx (`int`):
+ The index of the layer to cache the states for.
+ cache_kwargs (`Dict[str, Any]`, `optional`):
+ Additional arguments for the cache subclass. No additional arguments are used in `OffloadedCache`.
+ Return:
+ A tuple containing the updated key and value states.
+ """
+ # Update the number of seen tokens
+ if layer_idx == 0:
+ self._seen_tokens += key_states.shape[-2]
+
+ # Update the cache
+ if len(self.key_cache) <= layer_idx:
+ self.key_cache.append(key_states)
+ self.value_cache.append(value_states)
+ self.original_device.append(key_states.device)
+ self.evict_previous_layer(layer_idx)
+ else:
+ key_tensor, value_tensor = self[layer_idx]
+ self.key_cache[layer_idx] = torch.cat([key_tensor, key_states], dim=-2)
+ self.value_cache[layer_idx] = torch.cat([value_tensor, value_states], dim=-2)
+
+ return self.key_cache[layer_idx], self.value_cache[layer_idx]
+
+ # According to https://docs.python.org/3/library/exceptions.html#NotImplementedError
+ # if a method is not supposed to be supported in a subclass we should set it to None
+ from_legacy_cache = None
+
+ to_legacy_cache = None
+
+
+class SinkCache(Cache):
+ """
+    A cache as described in the [Attention Sinks paper](https://arxiv.org/abs/2309.17453). It allows the model to
+ generate beyond the length of its context window, without losing fluency in the conversation. As it discards past
+ tokens, the model will lose the ability to generate tokens that depend on the context that was discarded.
+
+ It stores the Key and Value states as a list of tensors, one for each layer. The expected shape for each tensor is
+ `[batch_size, num_heads, seq_len, head_dim]`.
+
+ Parameters:
+ window_length (`int`):
+ The length of the context window.
+ num_sink_tokens (`int`):
+ The number of sink tokens. See the original paper for more information.
+
+ Example:
+
+ ```python
+ >>> from transformers import AutoTokenizer, AutoModelForCausalLM, SinkCache
+
+ >>> model = AutoModelForCausalLM.from_pretrained("openai-community/gpt2")
+ >>> tokenizer = AutoTokenizer.from_pretrained("openai-community/gpt2")
+
+ >>> inputs = tokenizer(text="My name is GPT2", return_tensors="pt")
+
+ >>> # Prepare a cache class and pass it to model's forward
+ >>> past_key_values = SinkCache(window_length=256, num_sink_tokens=4)
+ >>> outputs = model(**inputs, past_key_values=past_key_values, use_cache=True)
+ >>> past_kv_length = outputs.past_key_values # access cache filled with key/values from generation
+ ```
+ """
+
+ def __init__(self, window_length: int, num_sink_tokens: int) -> None:
+ super().__init__()
+ self.key_cache: List[torch.Tensor] = []
+ self.value_cache: List[torch.Tensor] = []
+ self.window_length = window_length
+ self.num_sink_tokens = num_sink_tokens
+ self.cos_sin_rerotation_cache = {}
+ self._cos_cache = None
+ self._sin_cache = None
+ self._seen_tokens = 0 # Used in `generate` to keep tally of how many tokens the cache has seen
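+        # (Comment added for clarity, not part of the original file.) Per layer, the cache keeps
+        # `num_sink_tokens` fixed "sink" positions plus a rolling window of the most recent tokens,
+        # so the total cached length is kept at roughly `window_length`.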
+
+ @staticmethod
+ def _rotate_half(x):
+ x1 = x[..., : x.shape[-1] // 2]
+ x2 = x[..., x.shape[-1] // 2 :]
+ return torch.cat((-x2, x1), dim=-1)
+
+ def _apply_key_rotary_pos_emb(
+ self, key_states: torch.Tensor, cos: torch.Tensor, sin: torch.Tensor
+ ) -> torch.Tensor:
+ rotated_key_states = (key_states * cos) + (self._rotate_half(key_states) * sin)
+ return rotated_key_states
+
+ def _get_rerotation_cos_sin(
+ self, key_states: torch.Tensor, cos: torch.Tensor, sin: torch.Tensor
+ ) -> Tuple[torch.Tensor, torch.Tensor]:
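+        # Note added for clarity (not part of the original file): by the angle-difference
+        # identities, rerotation_cos = cos(theta_original - theta_shifted) and
+        # rerotation_sin = sin(theta_shifted - theta_original), so applying them via
+        # _apply_key_rotary_pos_emb rotates a kept key from its original RoPE position
+        # back to the earlier position it occupies after the window is shifted.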
+ if key_states.shape[-2] not in self.cos_sin_rerotation_cache:
+ # Upcast to float32 temporarily for better accuracy
+ cos = cos.to(torch.float32)
+ sin = sin.to(torch.float32)
+
+ # Compute the cos and sin required for back- and forward-rotating to one position earlier in the sequence
+ original_cos = cos[self.num_sink_tokens + key_states.shape[-2] :]
+ shifted_cos = cos[self.num_sink_tokens : -key_states.shape[-2]]
+ original_sin = sin[self.num_sink_tokens + key_states.shape[-2] :]
+ shifted_sin = sin[self.num_sink_tokens : -key_states.shape[-2]]
+ rerotation_cos = original_cos * shifted_cos + original_sin * shifted_sin
+ rerotation_sin = -original_sin * shifted_cos + original_cos * shifted_sin
+
+ self.cos_sin_rerotation_cache[key_states.shape[-2]] = (
+ rerotation_cos.to(key_states.dtype).unsqueeze(0),
+ rerotation_sin.to(key_states.dtype).unsqueeze(0),
+ )
+ return self.cos_sin_rerotation_cache[key_states.shape[-2]]
+
+ def get_seq_length(self, layer_idx: Optional[int] = 0) -> int:
+ """Returns the sequence length of the cached states. A layer index can be optionally passed."""
+ # TODO: deprecate this function in favor of `cache_position`
+ # Workaround to make 'key_states.shape[-2] + past_key_value.get_seq_length(self.layer_idx)' <= window_length
+ if len(self.key_cache) <= layer_idx:
+ return 0
+ return self.key_cache[layer_idx].shape[-2]
+
+ def get_max_length(self) -> Optional[int]:
+ """Returns the maximum sequence length of the cached states."""
+ return self.window_length
+
+ def update(
+ self,
+ key_states: torch.Tensor,
+ value_states: torch.Tensor,
+ layer_idx: int,
+ cache_kwargs: Optional[Dict[str, Any]] = None,
+ ) -> Tuple[torch.Tensor, torch.Tensor]:
+ """
+ Updates the cache with the new `key_states` and `value_states` for the layer `layer_idx`.
+
+ Parameters:
+ key_states (`torch.Tensor`):
+ The new key states to cache.
+ value_states (`torch.Tensor`):
+ The new value states to cache.
+ layer_idx (`int`):
+ The index of the layer to cache the states for.
+ cache_kwargs (`Dict[str, Any]`, `optional`):
+ Additional arguments for the cache subclass. The following arguments can be used in `SinkCache`: `sin`,
+ `cos` and `partial_rotation_size`. These arguments are used with models using RoPE, to recompute the
+ rotation as the tokens are shifted.
+
+ Return:
+ A tuple containing the updated key and value states.
+ """
+ # Optional kwargs for `SinkCache` -- needed on models using RoPE. `partial_rotation_size` is used on models
+ # with partially rotated position embeddings, like Phi or Persimmon.
+ sin = cache_kwargs.get("sin")
+ cos = cache_kwargs.get("cos")
+ partial_rotation_size = cache_kwargs.get("partial_rotation_size")
+ using_rope = cos is not None and sin is not None
+
+ # Update the number of seen tokens
+ if layer_idx == 0:
+ self._seen_tokens += key_states.shape[-2]
+
+ # Update the sin/cos cache, which holds sin/cos values for all possible positions
+ if using_rope and layer_idx == 0:
+ # BC: some models still pass `sin`/`cos` with 2 dims. In those models, they are the full sin/cos. Remove
+ # after all RoPE models have a llama-like cache utilization.
+ if cos.dim() == 2:
+ self._cos_cache = cos
+ self._sin_cache = sin
+ else:
+ if self._cos_cache is None:
+ self._cos_cache = cos[0, ...]
+ self._sin_cache = sin[0, ...]
+ elif self._cos_cache.shape[0] < self.window_length:
+ self._cos_cache = torch.cat([self._cos_cache, cos[0, ...]], dim=0)
+ self._sin_cache = torch.cat([self._sin_cache, sin[0, ...]], dim=0)
+
+ # [bsz, num_heads, seq_len, head_dim]
+ if len(self.key_cache) <= layer_idx:
+ # Empty cache
+ self.key_cache.append(key_states)
+ self.value_cache.append(value_states)
+
+ elif key_states.shape[-2] + self.get_seq_length(layer_idx) < self.window_length:
+ # Growing cache
+ self.key_cache[layer_idx] = torch.cat([self.key_cache[layer_idx], key_states], dim=-2)
+ self.value_cache[layer_idx] = torch.cat([self.value_cache[layer_idx], value_states], dim=-2)
+
+ else:
+ # Shifting cache
+ keys_to_keep = self.key_cache[layer_idx][
+ :, :, -self.window_length + self.num_sink_tokens + key_states.shape[-2] :
+ ]
+
+ # On RoPE models, we need to recompute the Key rotation as the tokens are shifted
+ if using_rope:
+ rerotation_cos, rerotation_sin = self._get_rerotation_cos_sin(
+ key_states, self._cos_cache[: self.window_length], self._sin_cache[: self.window_length]
+ )
+ if partial_rotation_size is not None:
+ keys_to_keep, keys_pass = (
+ keys_to_keep[..., :partial_rotation_size],
+ keys_to_keep[..., partial_rotation_size:],
+ )
+ keys_to_keep = self._apply_key_rotary_pos_emb(keys_to_keep, rerotation_cos, rerotation_sin)
+ if partial_rotation_size is not None:
+ keys_to_keep = torch.cat((keys_to_keep, keys_pass), dim=-1)
+
+ # Concatenate sink tokens, shifted & rotated tokens (if needed), and new tokens
+ sink_keys = self.key_cache[layer_idx][:, :, : self.num_sink_tokens]
+ self.key_cache[layer_idx] = torch.cat([sink_keys, keys_to_keep, key_states], dim=-2)
+
+ sink_values = self.value_cache[layer_idx][:, :, : self.num_sink_tokens]
+ values_to_keep = self.value_cache[layer_idx][
+ :, :, -self.window_length + self.num_sink_tokens + value_states.shape[-2] :
+ ]
+ self.value_cache[layer_idx] = torch.cat([sink_values, values_to_keep, value_states], dim=-2)
+
+ return self.key_cache[layer_idx], self.value_cache[layer_idx]
+
+
+class StaticCache(Cache):
+ """
+ Static Cache class to be used with `torch.compile(model)` and `torch.export()`.
+
+ Parameters:
+ config (`PretrainedConfig`):
+ The configuration file defining the shape-related attributes required to initialize the static cache.
+ max_batch_size (`int`):
+ The maximum batch size with which the model will be used.
+ max_cache_len (`int`):
+ The maximum sequence length with which the model will be used.
+ device (`torch.device`):
+ The device on which the cache should be initialized. Should be the same as the layer.
+ dtype (*optional*, defaults to `torch.float32`):
+ The default `dtype` to use when initializing the layer.
+
+ Example:
+
+ ```python
+ >>> from transformers import AutoTokenizer, AutoModelForCausalLM, StaticCache
+
+ >>> model = AutoModelForCausalLM.from_pretrained("openai-community/gpt2")
+ >>> tokenizer = AutoTokenizer.from_pretrained("openai-community/gpt2")
+
+ >>> inputs = tokenizer(text="My name is GPT2", return_tensors="pt")
+
+ >>> # Prepare a cache class and pass it to model's forward
+ >>> # Leave empty space for 10 new tokens, which can be used when calling forward iteratively 10 times to generate
+ >>> max_generated_length = inputs.input_ids.shape[1] + 10
+ >>> past_key_values = StaticCache(config=model.config, max_batch_size=1, max_cache_len=max_generated_length, device=model.device, dtype=model.dtype)
+ >>> outputs = model(**inputs, past_key_values=past_key_values, use_cache=True)
+ >>> past_kv_length = outputs.past_key_values # access cache filled with key/values from generation
+ ```
+ """
+
+ def __init__(self, config: PretrainedConfig, max_batch_size: int, max_cache_len: int, device, dtype=None) -> None:
+ super().__init__()
+ self.max_batch_size = max_batch_size
+ self.max_cache_len = config.max_position_embeddings if max_cache_len is None else max_cache_len
+        # Some models define a custom `head_dim` != config.hidden_size // config.num_attention_heads
+ self.head_dim = (
+ config.head_dim if hasattr(config, "head_dim") else config.hidden_size // config.num_attention_heads
+ )
+
+ self.dtype = dtype if dtype is not None else torch.float32
+ self.num_key_value_heads = (
+ config.num_attention_heads if config.num_key_value_heads is None else config.num_key_value_heads
+ )
+
+ self.key_cache: List[torch.Tensor] = []
+ self.value_cache: List[torch.Tensor] = []
+ # Note: There will be significant perf decrease if switching to use 5D tensors instead.
+ cache_shape = (max_batch_size, self.num_key_value_heads, self.max_cache_len, self.head_dim)
+ for idx in range(config.num_hidden_layers):
+ new_layer_key_cache = torch.zeros(cache_shape, dtype=self.dtype, device=device)
+ new_layer_value_cache = torch.zeros(cache_shape, dtype=self.dtype, device=device)
+ # Notes:
+            # 1. `mark_static_address` is used to tag the cache as a fixed data pointer, preventing cuda graph
+ # breaks when updating the cache. It can't be used if the cache code is being compiled (but in that case
+ # it is not needed anyway)
+ # 2. `torch.export()` requires mutations to be registered as buffers.
+ if not is_torchdynamo_compiling():
+ self.register_buffer(f"key_cache_{idx}", torch.zeros(cache_shape, dtype=dtype, device=device))
+ self.register_buffer(f"value_cache_{idx}", torch.zeros(cache_shape, dtype=dtype, device=device))
+ new_layer_key_cache = getattr(self, f"key_cache_{idx}")
+ new_layer_value_cache = getattr(self, f"value_cache_{idx}")
+ torch._dynamo.mark_static_address(new_layer_key_cache)
+ torch._dynamo.mark_static_address(new_layer_value_cache)
+ self.key_cache.append(new_layer_key_cache)
+ self.value_cache.append(new_layer_value_cache)
+ self._seen_tokens = 0 # Used in `generate` to keep tally of how many tokens the cache has seen
+
+ def update(
+ self,
+ key_states: torch.Tensor,
+ value_states: torch.Tensor,
+ layer_idx: int,
+ cache_kwargs: Optional[Dict[str, Any]] = None,
+ ) -> Tuple[torch.Tensor, torch.Tensor]:
+ """
+ Updates the cache with the new `key_states` and `value_states` for the layer `layer_idx`.
+ It is VERY important to index using a tensor, otherwise you introduce a copy to the device.
+
+ Parameters:
+ key_states (`torch.Tensor`):
+ The new key states to cache.
+ value_states (`torch.Tensor`):
+ The new value states to cache.
+ layer_idx (`int`):
+ The index of the layer to cache the states for.
+ cache_kwargs (`Dict[str, Any]`, `optional`):
+ Additional arguments for the cache subclass. The `StaticCache` needs the `cache_position` input
+                to know where to write in the cache.
+
+ Return:
+ A tuple containing the updated key and value states.
+ """
+ # Update the number of seen tokens
+ if layer_idx == 0:
+ self._seen_tokens += key_states.shape[-2]
+
+ cache_position = cache_kwargs.get("cache_position")
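+        # (Comment added for clarity, not part of the original file.) `cache_position` is typically
+        # torch.arange(prompt_len) during prefill and a one-element tensor holding the current
+        # position during autoregressive decoding.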
+ self.key_cache[layer_idx] = self.key_cache[layer_idx].to(device=key_states.device)
+ self.value_cache[layer_idx] = self.value_cache[layer_idx].to(device=value_states.device)
+ k_out = self.key_cache[layer_idx]
+ v_out = self.value_cache[layer_idx]
+
+ if cache_position is None:
+ k_out.copy_(key_states)
+ v_out.copy_(value_states)
+ else:
+ # Note: here we use `tensor.index_copy_(dim, index, tensor)` that is equivalent to
+            # `tensor[:, :, index] = tensor`, but the first one is compile-friendly and explicitly performs an
+            # in-place operation, which avoids copies and uses less memory.
+ try:
+ k_out.index_copy_(2, cache_position, key_states)
+ v_out.index_copy_(2, cache_position, value_states)
+ except NotImplementedError:
+ # The operator 'aten::index_copy.out' is not currently implemented for the MPS device.
+ k_out[:, :, cache_position] = key_states
+ v_out[:, :, cache_position] = value_states
+
+ return k_out, v_out
+
+ def get_seq_length(self, layer_idx: Optional[int] = 0) -> int:
+ """Returns the sequence length of the cached states that were seen by the model."""
+ # Occupied cache == any slot in the 3rd dim (sequence length) holds a non-zero value. To save on compute, let's
+ # limit the check to the first batch member and head dimension.
+ # TODO: deprecate this function in favor of `cache_position`
+ # return (self.key_cache[layer_idx][0, 0].any(dim=-1)).sum()
+ return self._seen_tokens
+
+ def get_max_length(self) -> Optional[int]:
+ """Returns the maximum sequence length of the cached states."""
+ return self.max_cache_len
+
+    def reset(self):
+        """Resets the cache values while preserving the objects"""
+        self._seen_tokens = 0
+ for layer_idx in range(len(self.key_cache)):
+ # In-place ops prevent breaking the static address
+ self.key_cache[layer_idx].zero_()
+ self.value_cache[layer_idx].zero_()
+
+
+class SlidingWindowCache(StaticCache):
+ """
+ Sliding Window Cache class to be used with `torch.compile` for models like Mistral that support sliding window attention.
+    Every time we try to update the cache, we compute the `indices` based on `cache_position >= self.config.sliding_window - 1`.
+    If true (which means the cache cannot hold all the old key/value states and the new states together because of the sliding window constraint),
+    we need to do a cyclic shift based on `indices` to replace the oldest states with the new key/value states passed in.
+
+ The `to_shift` is only true once we are above sliding_window. Thus with `sliding_window==64`:
+
+ indices = (slicing + to_shift[-1].int()-1) % self.config.sliding_window
+ tensor([ 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15, 16, 17, 18,
+ 19, 20, 21, 22, 23, 24, 25, 26, 27, 28, 29, 30, 31, 32, 33, 34, 35, 36,
+ 37, 38, 39, 40, 41, 42, 43, 44, 45, 46, 47, 48, 49, 50, 51, 52, 53, 54,
+ 55, 56, 57, 58, 59, 60, 61, 62, 63, 0])
+
+ We overwrite the cache using these, then we always write at cache_position (clamped to `sliding_window`)
+
+ Parameters:
+ config (`PretrainedConfig`):
+ The configuration file defining the shape-related attributes required to initialize the static cache.
+ max_batch_size (`int`):
+ The maximum batch size with which the model will be used.
+ max_cache_len (`int`):
+ The maximum sequence length with which the model will be used.
+ device (`torch.device`):
+ The device on which the cache should be initialized. Should be the same as the layer.
+ dtype (*optional*, defaults to `torch.float32`):
+ The default `dtype` to use when initializing the layer.
+
+ Example:
+
+ ```python
+ >>> from transformers import AutoTokenizer, AutoModelForCausalLM, SlidingWindowCache
+
+ >>> model = AutoModelForCausalLM.from_pretrained("openai-community/gpt2")
+ >>> tokenizer = AutoTokenizer.from_pretrained("openai-community/gpt2")
+
+ >>> inputs = tokenizer(text="My name is GPT2", return_tensors="pt")
+
+ >>> # Prepare a cache class and pass it to model's forward
+ >>> # Leave empty space for 10 new tokens, which can be used when calling forward iteratively 10 times to generate
+ >>> max_generated_length = inputs.input_ids.shape[1] + 10
+ >>> past_key_values = SlidingWindowCache(config=model.config, max_batch_size=1, max_cache_len=max_generated_length, device=model.device, dtype=model.dtype)
+ >>> outputs = model(**inputs, past_key_values=past_key_values, use_cache=True)
+ >>> past_kv_length = outputs.past_key_values # access cache filled with key/values from generation
+ ```
+ """
+
+ def __init__(self, config: PretrainedConfig, max_batch_size: int, max_cache_len: int, device, dtype=None) -> None:
+ super().__init__(config, max_batch_size, max_cache_len, device, dtype)
+ if not hasattr(config, "sliding_window") or config.sliding_window is None:
+ raise ValueError(
+ "Setting `cache_implementation` to 'sliding_window' requires the model config supporting "
+ "sliding window attention, please check if there is a `sliding_window` field in the model "
+ "config and it's not set to None."
+ )
+ max_cache_len = min(config.sliding_window, max_cache_len)
+ super().__init__(
+ config=config, max_batch_size=max_batch_size, max_cache_len=max_cache_len, device=device, dtype=dtype
+ )
+
+ def update(
+ self,
+ key_states: torch.Tensor,
+ value_states: torch.Tensor,
+ layer_idx: int,
+ cache_kwargs: Optional[Dict[str, Any]] = None,
+ ) -> Tuple[torch.Tensor]:
+ cache_position = cache_kwargs.get("cache_position")
+ k_out = self.key_cache[layer_idx]
+ v_out = self.value_cache[layer_idx]
+
+ # assume this only happens in prefill phase when prompt length > sliding_window_size (= max_cache_len)
+ if cache_position.shape[0] > self.max_cache_len:
+ k_out = key_states[:, :, -self.max_cache_len :, :]
+ v_out = value_states[:, :, -self.max_cache_len :, :]
+ # Assumption: caches are all zeros at this point, `+=` is equivalent to `=` but compile-friendly
+ self.key_cache[layer_idx] += k_out
+ self.value_cache[layer_idx] += v_out
+ # we should return the whole states instead of k_out, v_out to take the whole prompt
+ # into consideration when building kv cache instead of just throwing away tokens outside of the window
+ return key_states, value_states
+
+ slicing = torch.ones(self.max_cache_len, dtype=torch.long, device=value_states.device).cumsum(0)
+ cache_position = cache_position.clamp(0, self.max_cache_len - 1)
+ to_shift = cache_position >= self.max_cache_len - 1
+ indices = (slicing + to_shift[-1].int() - 1) % self.max_cache_len
+
+ k_out = k_out[:, :, indices]
+ v_out = v_out[:, :, indices]
+
+ try:
+            cache_position = cache_position.to(device=k_out.device)
+ k_out.index_copy_(2, cache_position, key_states)
+ v_out.index_copy_(2, cache_position, value_states)
+ except NotImplementedError:
+ # The operator 'aten::index_copy.out' is not currently implemented for the MPS device.
+ k_out[:, :, cache_position] = key_states
+ v_out[:, :, cache_position] = value_states
+
+        # `_.zero()` followed by `+=` is equivalent to `=`, but compile-friendly (without graph breaks due to assignment)
+ self.key_cache[layer_idx].zero_()
+ self.value_cache[layer_idx].zero_()
+
+ self.key_cache[layer_idx] += k_out
+ self.value_cache[layer_idx] += v_out
+
+ return k_out, v_out
+
+ def get_max_length(self) -> Optional[int]:
+ # in theory there is no limit because the sliding window size is fixed no matter how long the sentence is
+ return None
+
+ def reset(self):
+ for layer_idx in range(len(self.key_cache)):
+ # In-place ops prevent breaking the static address
+ self.key_cache[layer_idx].zero_()
+ self.value_cache[layer_idx].zero_()
+
+
+class EncoderDecoderCache(Cache):
+ """
+ Base, abstract class for all encoder-decoder caches. Can be used to hold combinations of self-attention and
+ cross-attention caches.
+
+ Example:
+
+ ```python
+ >>> from transformers import AutoProcessor, AutoModelForCausalLM, DynamicCache, EncoderDecoderCache
+
+ >>> model = AutoModelForCausalLM.from_pretrained("openai/whisper-small")
+ >>> processor = AutoProcessor.from_pretrained("openai/whisper-small")
+
+ >>> inputs = processor(audio=YOUR-AUDIO, return_tensors="pt")
+
+ >>> # Prepare cache classes for encoder and decoder and pass it to model's forward
+ >>> self_attention_cache = DynamicCache()
+ >>> cross_attention_cache = DynamicCache()
+ >>> past_key_values = EncoderDecoderCache(self_attention_cache, cross_attention_cache)
+ >>> outputs = model(**inputs, past_key_values=past_key_values, use_cache=True)
+ >>> past_kv_length = outputs.past_key_values # access cache filled with key/values from generation
+ ```
+
+ """
+
+ def __init__(self, self_attention_cache: Cache, cross_attention_cache: Cache):
+ super().__init__()
+ self.self_attention_cache = self_attention_cache
+ self.cross_attention_cache = cross_attention_cache
+
+ self.is_updated = {}
+ for layer_idx in range(len(cross_attention_cache.key_cache)):
+ self.is_updated[layer_idx] = bool(cross_attention_cache.get_seq_length(layer_idx) > 0)
+
+ def __getitem__(self, layer_idx: int) -> List[Tuple[torch.Tensor]]:
+ """
+ Support for backwards-compatible `past_key_value` indexing, e.g. `past_key_value[0][0].shape[2]` to get the
+ sequence length.
+ """
+ if layer_idx < len(self):
+ return (
+ self.self_attention_cache.key_cache[layer_idx],
+ self.self_attention_cache.value_cache[layer_idx],
+ self.cross_attention_cache.key_cache[layer_idx],
+ self.cross_attention_cache.value_cache[layer_idx],
+ )
+ else:
+ raise KeyError(f"Cache only has {len(self)} layers, attempted to access layer with index {layer_idx}")
+
+ def __len__(self):
+ """
+ Support for backwards-compatible `past_key_value` length, e.g. `len(past_key_value)`. This value corresponds
+ to the number of layers in the model.
+ """
+ return len(self.self_attention_cache)
+
+ def to_legacy_cache(self) -> Tuple[Tuple[torch.Tensor], Tuple[torch.Tensor]]:
+ """Converts the `EncoderDecoderCache` instance into its equivalent in the legacy cache format."""
+ legacy_cache = ()
+ if len(self.cross_attention_cache) > 0:
+ for self_attn, cross_attn in zip(
+ self.self_attention_cache.to_legacy_cache(), self.cross_attention_cache.to_legacy_cache()
+ ):
+ legacy_cache += (self_attn + cross_attn,)
+ else:
+ legacy_cache = self.self_attention_cache.to_legacy_cache()
+ return legacy_cache
+
+ @classmethod
+ def from_legacy_cache(
+ cls, past_key_values: Optional[Tuple[Tuple[torch.FloatTensor]]] = None
+ ) -> "EncoderDecoderCache":
+ """Converts a cache in the legacy cache format into an equivalent `EncoderDecoderCache`."""
+ cache = cls(self_attention_cache=DynamicCache(), cross_attention_cache=DynamicCache())
+ if past_key_values is not None:
+ for layer_idx in range(len(past_key_values)):
+ key_states, value_states = past_key_values[layer_idx][:2]
+ cache.self_attention_cache.update(key_states, value_states, layer_idx)
+ if len(past_key_values[layer_idx]) > 2:
+ key_states, value_states = past_key_values[layer_idx][2:]
+ cache.cross_attention_cache.update(key_states, value_states, layer_idx)
+ cache.is_updated[layer_idx] = True
+ return cache
+
+ def get_seq_length(self, layer_idx: Optional[int] = 0) -> int:
+ """Returns the sequence length of the cached states. A layer index can be optionally passed."""
+ if len(self.self_attention_cache.key_cache) <= layer_idx:
+ return 0
+ return (self.self_attention_cache.key_cache[layer_idx][0, 0].any(dim=-1)).sum()
+
+ def reset(self):
+ if hasattr(self.self_attention_cache, "reset"):
+ self.self_attention_cache.reset()
+ if hasattr(self.cross_attention_cache, "reset"):
+ self.cross_attention_cache.reset()
+ elif not hasattr(self.self_attention_cache, "reset") and not hasattr(self.cross_attention_cache, "reset"):
+ raise ValueError(
+ "Neither self nor cross-attention cache have valid `.reset()` methods. `.reset()` should "
+ "only be called on compatible cache classes, such as `StaticCache` or `SlidingWindowCache`. "
+ f"Got {self.self_attention_cache.__str__()} for the self attention cache and "
+ f"{self.cross_attention_cache.__str__()} for the cross attention cache."
+ )
+ for layer_idx in self.is_updated:
+ self.is_updated[layer_idx] = False
+
+ def reorder_cache(self, beam_idx: torch.LongTensor):
+ """Reorders the cache for beam search, given the selected beam indices."""
+ self.self_attention_cache.reorder_cache(beam_idx)
+ self.cross_attention_cache.reorder_cache(beam_idx)
+
+ def check_dynamic_cache(self, method: str):
+ if not (
+ isinstance(self.self_attention_cache, DynamicCache)
+ and isinstance(self.cross_attention_cache, DynamicCache)
+ ):
+ raise ValueError(
+ f"`{method}` is only defined for dynamic cache, got {self.self_attention_cache.__str__()} for the self "
+ f"attention cache and {self.cross_attention_cache.__str__()} for the cross attention cache."
+ )
+
+ # TODO(gante, sanchit-gandhi): move following functionality into `.generate`
+ def crop(self, maximum_length: int):
+ """Crop the past key values up to a new `maximum_length` in terms of tokens. `maximum_length` can also be
+ negative to remove `maximum_length` tokens. This is used in assisted decoding and contrastive search."""
+ self.check_dynamic_cache(self.crop.__name__)
+ self.self_attention_cache.crop(maximum_length)
+
+ def batch_split(self, full_batch_size: int, split_size: int) -> "List[EncoderDecoderCache]":
+ """Split the current instance into a list of `DynamicCache` by the batch size. This will be used by
+ `_split_model_inputs()` in `generation.utils`"""
+ self.check_dynamic_cache(self.batch_split.__name__)
+ self_attention_cache = self.self_attention_cache.batch_split(full_batch_size, split_size)
+ cross_attention_cache = self.cross_attention_cache.batch_split(full_batch_size, split_size)
+
+ out = []
+ for self_attn, cross_attn in zip(self_attention_cache, cross_attention_cache):
+ out.append(EncoderDecoderCache(self_attn, cross_attn))
+ return out
+
+ @classmethod
+ def from_batch_splits(cls, splits: List["EncoderDecoderCache"]) -> "EncoderDecoderCache":
+ """This is the opposite of the above `batch_split()` method. This will be used by `stack_model_outputs` in
+ `generation.utils`"""
+ self_attention_cache = DynamicCache()
+ cross_attention_cache = DynamicCache()
+ for idx in range(len(splits[0])):
+ layer_keys = torch.cat([current.self_attention_cache.key_cache[idx] for current in splits], dim=0)
+ layer_values = torch.cat([current.self_attention_cache.value_cache[idx] for current in splits], dim=0)
+ self_attention_cache.update(layer_keys, layer_values, idx)
+
+ layer_keys = torch.cat([current.cross_attention_cache.key_cache[idx] for current in splits], dim=0)
+ layer_values = torch.cat([current.cross_attention_cache.value_cache[idx] for current in splits], dim=0)
+ cross_attention_cache.update(layer_keys, layer_values, idx)
+ return cls(self_attention_cache, cross_attention_cache)
+
+ def batch_repeat_interleave(self, repeats: int):
+ """Repeat the cache `repeats` times in the batch dimension. Used in contrastive search."""
+ self.check_dynamic_cache(self.batch_repeat_interleave.__name__)
+ self.self_attention_cache.batch_repeat_interleave(repeats)
+ self.cross_attention_cache.batch_repeat_interleave(repeats)
+
+ def batch_select_indices(self, indices: torch.Tensor):
+ """Only keep the `indices` in the batch dimension of the cache. Used in contrastive search."""
+ self.check_dynamic_cache(self.batch_select_indices.__name__)
+ self.self_attention_cache.batch_select_indices(indices)
+ self.cross_attention_cache.batch_select_indices(indices)
+
+
+class HybridCache(Cache):
+ """
+ Hybrid Cache class to be used with `torch.compile` for Gemma2 models that alternate between a local sliding window attention
+ and global attention in every other layer. Under the hood, Hybrid Cache leverages ["SlidingWindowCache"] for sliding window attention
+ and ["StaticCache"] for global attention. For more information, see the documentation of each subcomponeent cache class.
+
+ Parameters:
+        config (`PretrainedConfig`):
+ The configuration file defining the shape-related attributes required to initialize the static cache.
+ max_batch_size (`int`):
+ The maximum batch size with which the model will be used.
+ max_cache_len (`int`):
+ The maximum sequence length with which the model will be used.
+ device (`torch.device`, *optional*, defaults to `"cpu"`):
+ The device on which the cache should be initialized. Should be the same as the layer.
+ dtype (*optional*, defaults to `torch.float32`):
+ The default `dtype` to use when initializing the layer.
+
+ Example:
+
+ ```python
+ >>> from transformers import AutoTokenizer, AutoModelForCausalLM, HybridCache
+
+ >>> model = AutoModelForCausalLM.from_pretrained("google/gemma-2-9b")
+ >>> tokenizer = AutoTokenizer.from_pretrained("google/gemma-2-9b")
+
+ >>> inputs = tokenizer(text="My name is Gemma", return_tensors="pt")
+
+ >>> # Prepare a cache class and pass it to model's forward
+ >>> # Leave empty space for 10 new tokens, which can be used when calling forward iteratively 10 times to generate
+ >>> max_generated_length = inputs.input_ids.shape[1] + 10
+ >>> past_key_values = HybridCache(config=model.config, max_batch_size=1, max_cache_len=max_generated_length, device=model.device, dtype=model.dtype)
+ >>> outputs = model(**inputs, past_key_values=past_key_values, use_cache=True)
+ >>> past_kv_length = outputs.past_key_values # access cache filled with key/values from generation
+ ```
+ """
+
+ def __init__(self, config: PretrainedConfig, max_batch_size, max_cache_len, device="cpu", dtype=None) -> None:
+ super().__init__()
+ if not hasattr(config, "sliding_window") or config.sliding_window is None:
+ raise ValueError(
+ "Setting `cache_implementation` to 'sliding_window' requires the model config supporting "
+ "sliding window attention, please check if there is a `sliding_window` field in the model "
+ "config and it's not set to None."
+ )
+ self.max_cache_len = max_cache_len
+ self.max_batch_size = max_batch_size
+        # Some models define a custom `head_dim` != config.hidden_size // config.num_attention_heads
+ self.head_dim = (
+ config.head_dim if hasattr(config, "head_dim") else config.hidden_size // config.num_attention_heads
+ )
+
+ self.dtype = dtype if dtype is not None else torch.float32
+ self.num_key_value_heads = (
+ config.num_attention_heads if config.num_key_value_heads is None else config.num_key_value_heads
+ )
+ self.is_sliding = torch.tensor(
+ [not bool(i % 2) for i in range(config.num_hidden_layers)], dtype=torch.bool, device=device
+ )
+ self.key_cache: List[torch.Tensor] = []
+ self.value_cache: List[torch.Tensor] = []
+ global_cache_shape = (max_batch_size, self.num_key_value_heads, max_cache_len, self.head_dim)
+ sliding_cache_shape = (
+ max_batch_size,
+ self.num_key_value_heads,
+ min(config.sliding_window, max_cache_len),
+ self.head_dim,
+ )
+ for i in range(config.num_hidden_layers):
+            # Note: `mark_static_address` is used to tag the cache as a fixed data pointer, preventing cuda graph
+ # breaks when updating the cache.
+ cache_shape = global_cache_shape if not self.is_sliding[i] else sliding_cache_shape
+ new_layer_key_cache = torch.zeros(cache_shape, dtype=self.dtype, device=device)
+ new_layer_value_cache = torch.zeros(cache_shape, dtype=self.dtype, device=device)
+ torch._dynamo.mark_static_address(new_layer_key_cache)
+ torch._dynamo.mark_static_address(new_layer_value_cache)
+ self.key_cache.append(new_layer_key_cache)
+ self.value_cache.append(new_layer_value_cache)
+
+ def _sliding_update(self, cache_position, layer_idx, key_states, value_states, k_out, v_out, max_cache_len):
+ if cache_position.shape[0] > max_cache_len:
+ k_out = key_states[:, :, -max_cache_len:, :]
+ v_out = value_states[:, :, -max_cache_len:, :]
+ # Assumption: caches are all zeros at this point, `+=` is equivalent to `=` but compile-friendly
+ self.key_cache[layer_idx] += k_out
+ self.value_cache[layer_idx] += v_out
+ # we should return the whole states instead of k_out, v_out to take the whole prompt
+ # into consideration when building kv cache instead of just throwing away tokens outside of the window
+ return key_states, value_states
+
+ slicing = torch.ones(max_cache_len, dtype=torch.long, device=value_states.device).cumsum(0)
+ cache_position = cache_position.clamp(0, max_cache_len - 1)
+ to_shift = cache_position >= max_cache_len - 1
+ indices = (slicing + to_shift[-1].int() - 1) % max_cache_len
+ k_out = k_out[:, :, indices]
+ v_out = v_out[:, :, indices]
+
+ k_out[:, :, cache_position] = key_states
+ v_out[:, :, cache_position] = value_states
+        # `zero_()` followed by `+=` is equivalent to `=`, but compile-friendly (without graph breaks due to assignment)
+ self.key_cache[layer_idx].zero_()
+ self.value_cache[layer_idx].zero_()
+
+ self.key_cache[layer_idx] += k_out
+ self.value_cache[layer_idx] += v_out
+ return k_out, v_out
+
+ def _static_update(self, cache_position, layer_idx, key_states, value_states, k_out, v_out, max_cache_len):
+ k_out[:, :, cache_position] = key_states
+ v_out[:, :, cache_position] = value_states
+
+ self.key_cache[layer_idx] = k_out
+ self.value_cache[layer_idx] = v_out
+ return k_out, v_out
+
+ def update(
+ self,
+ key_states: torch.Tensor,
+ value_states: torch.Tensor,
+ layer_idx: int,
+ cache_kwargs: Optional[Dict[str, Any]] = None,
+    ) -> Tuple[torch.Tensor, torch.Tensor]:
+ cache_position = cache_kwargs.get("cache_position")
+ sliding_window = cache_kwargs.get("sliding_window")
+ self.key_cache[layer_idx] = self.key_cache[layer_idx].to(device=key_states.device)
+ self.value_cache[layer_idx] = self.value_cache[layer_idx].to(device=value_states.device)
+ k_out = self.key_cache[layer_idx]
+ v_out = self.value_cache[layer_idx]
+ if sliding_window:
+ update_fn = self._sliding_update
+ else:
+ update_fn = self._static_update
+
+ return update_fn(
+ cache_position,
+ layer_idx,
+ key_states,
+ value_states,
+ k_out,
+ v_out,
+ k_out.shape[2],
+ )
+
+ def get_max_length(self) -> Optional[int]:
+        # The static (global) layers cap the usable cache length at `max_cache_len`,
+        # even though the sliding-window layers themselves impose no such limit.
+ return self.max_cache_len
+
+ def get_seq_length(self, layer_idx: Optional[int] = 0):
+ return None
+
+ def reset(self):
+ """Resets the cache values while preserving the objects"""
+ for layer_idx in range(len(self.key_cache)):
+ # In-place ops prevent breaking the static address
+ self.key_cache[layer_idx].zero_()
+ self.value_cache[layer_idx].zero_()
+
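+
+# --- Illustrative sketch (not part of the upstream transformers API) ---------------------------------------
+# The ring-buffer rotation used by `HybridCache._sliding_update` above can be hard to follow, so the
+# hypothetical helper below reproduces only the index arithmetic for a single layer. Once the window is full
+# (`cache_position >= max_cache_len - 1`), every slot is shifted left by one and the newest token is written
+# into the last slot; before that point the mapping is the identity.
+# e.g. _demo_sliding_window_indices(4, newest_position=10) -> tensor([1, 2, 3, 0])
+#      _demo_sliding_window_indices(4, newest_position=1)  -> tensor([0, 1, 2, 3])
+def _demo_sliding_window_indices(max_cache_len: int, newest_position: int) -> torch.Tensor:
+    """Hypothetical helper: the slot permutation applied by `_sliding_update` once decoding is under way."""
+    slicing = torch.ones(max_cache_len, dtype=torch.long).cumsum(0)
+    to_shift = newest_position >= max_cache_len - 1
+    # When shifting, new slot i takes the value previously held at slot i + 1 (modulo the window size).
+    return (slicing + int(to_shift) - 1) % max_cache_len
+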
+
+class MambaCache:
+ """
+    Cache for Mamba models, which have no attention mechanism and therefore no key/value states.
+
+ Arguments:
+        config (`PretrainedConfig`):
+ The configuration file defining the shape-related attributes required to initialize the static cache.
+ max_batch_size (`int`):
+ The maximum batch size with which the model will be used.
+ dtype (*optional*, defaults to `torch.float16`):
+ The default `dtype` to use when initializing the layer.
+ device (`torch.device`, *optional*):
+ The device on which the cache should be initialized. Should be the same as the layer.
+
+ Attributes:
+ dtype: (`torch.dtype`):
+            The default `dtype` used to initialize the cache.
+ intermediate_size: (`int`):
+ Model's intermediate_size taken from config.
+ ssm_state_size: (`int`):
+ Model's state_size taken from config.
+        conv_kernel_size: (`int`):
+            Model's convolution kernel size taken from config.
+        conv_states: (`torch.Tensor`):
+            A tensor of shape `[layer_idx, batch_size, intermediate_size, conv_kernel_size]` that holds convolutional states.
+        ssm_states: (`torch.Tensor`):
+            A tensor of shape `[layer_idx, batch_size, intermediate_size, ssm_state_size]` that holds SSM states.
+
+ Example:
+
+ ```python
+ >>> from transformers import AutoTokenizer, MambaForCausalLM, MambaCache
+
+ >>> model = MambaForCausalLM.from_pretrained("state-spaces/mamba-130m-hf")
+ >>> tokenizer = AutoTokenizer.from_pretrained("state-spaces/mamba-130m-hf")
+
+ >>> inputs = tokenizer(text="My name is Mamba", return_tensors="pt")
+
+ >>> # Prepare a cache class and pass it to model's forward
+ >>> past_key_values = MambaCache(config=model.config, max_batch_size=1, device=model.device, dtype=model.dtype)
+ >>> outputs = model(**inputs, past_key_values=past_key_values, use_cache=True)
+ >>> past_kv = outputs.past_key_values
+ ```
+ """
+
+ def __init__(
+ self,
+ config: PretrainedConfig,
+ max_batch_size: int,
+ dtype: torch.dtype = torch.float16,
+ device: Optional[str] = None,
+ **kwargs,
+ ):
+ self.dtype = dtype
+ self.max_batch_size = max_batch_size
+ self.intermediate_size = config.intermediate_size
+ self.ssm_state_size = config.state_size
+ self.conv_kernel_size = config.conv_kernel
+
+ self.conv_states: torch.Tensor = torch.zeros(
+ config.num_hidden_layers,
+ self.max_batch_size,
+ self.intermediate_size,
+ self.conv_kernel_size,
+ device=device,
+ dtype=dtype,
+ )
+ self.ssm_states: torch.Tensor = torch.zeros(
+ config.num_hidden_layers,
+ self.max_batch_size,
+ self.intermediate_size,
+ self.ssm_state_size,
+ device=device,
+ dtype=dtype,
+ )
+
+ torch._dynamo.mark_static_address(self.conv_states)
+ torch._dynamo.mark_static_address(self.ssm_states)
+
+ def update_conv_state(
+ self, layer_idx: int, new_conv_state: torch.Tensor, cache_position: torch.LongTensor
+ ) -> torch.Tensor:
+ conv_state = self.conv_states[layer_idx]
+ cache_position = cache_position.clamp(0, self.conv_kernel_size - 1)
+
+ conv_state = conv_state.roll(shifts=-1, dims=-1)
+ conv_state[:, :, cache_position] = new_conv_state.to(conv_state.device)
+ self.conv_states[layer_idx].zero_()
+ self.conv_states[layer_idx] += conv_state
+ return self.conv_states[layer_idx]
+
+ def update_ssm_state(self, layer_idx: int, new_ssm_state: torch.Tensor):
+ self.ssm_states[layer_idx] = new_ssm_state.to(self.ssm_states.device)
+ return self.ssm_states[layer_idx]
+
+ def reset(self):
+ self.conv_states.zero_()
+ self.ssm_states.zero_()
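+
+
+# --- Illustrative sketch (not part of the upstream transformers API) ---------------------------------------
+# A hypothetical, minimal illustration of how `MambaCache.update_conv_state` maintains its rolling buffer in
+# the single-token decoding case: the kernel-sized window is shifted left by one position and the newest
+# convolution input is written into the last slot (the cache position is clamped to `conv_kernel_size - 1`).
+def _demo_conv_state_roll(conv_state: torch.Tensor, new_column: torch.Tensor) -> torch.Tensor:
+    """Hypothetical helper: `conv_state` is (batch, intermediate_size, conv_kernel_size),
+    `new_column` is (batch, intermediate_size)."""
+    rolled = conv_state.roll(shifts=-1, dims=-1)     # drop the oldest column, shift the rest left
+    rolled[:, :, -1] = new_column.to(rolled.device)  # write the newest column into the freed last slot
+    return rolled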
diff --git a/transformers_4_44_2__configuration_llama.py b/transformers_4_44_2__configuration_llama.py
new file mode 100644
index 0000000000000000000000000000000000000000..a6122d5b1e5b9d027aae3217ac55df9e2bcbba8f
--- /dev/null
+++ b/transformers_4_44_2__configuration_llama.py
@@ -0,0 +1,203 @@
+# coding=utf-8
+# Copyright 2022 EleutherAI and the HuggingFace Inc. team. All rights reserved.
+#
+# This code is based on EleutherAI's GPT-NeoX library and the GPT-NeoX
+# and OPT implementations in this library. It has been modified from its
+# original forms to accommodate minor architectural differences compared
+# to GPT-NeoX and OPT used by the Meta AI team that trained the model.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+"""LLaMA model configuration"""
+
+from transformers.configuration_utils import PretrainedConfig
+from .transformers_4_44_2__modeling_rope_utils import rope_config_validation
+
+
+class LlamaConfig(PretrainedConfig):
+ r"""
+ This is the configuration class to store the configuration of a [`LlamaModel`]. It is used to instantiate an LLaMA
+ model according to the specified arguments, defining the model architecture. Instantiating a configuration with the
+ defaults will yield a similar configuration to that of the LLaMA-7B.
+
+ Configuration objects inherit from [`PretrainedConfig`] and can be used to control the model outputs. Read the
+ documentation from [`PretrainedConfig`] for more information.
+
+
+ Args:
+ vocab_size (`int`, *optional*, defaults to 32000):
+ Vocabulary size of the LLaMA model. Defines the number of different tokens that can be represented by the
+ `inputs_ids` passed when calling [`LlamaModel`]
+ hidden_size (`int`, *optional*, defaults to 4096):
+ Dimension of the hidden representations.
+ intermediate_size (`int`, *optional*, defaults to 11008):
+ Dimension of the MLP representations.
+ num_hidden_layers (`int`, *optional*, defaults to 32):
+ Number of hidden layers in the Transformer decoder.
+ num_attention_heads (`int`, *optional*, defaults to 32):
+ Number of attention heads for each attention layer in the Transformer decoder.
+ num_key_value_heads (`int`, *optional*):
+ This is the number of key_value heads that should be used to implement Grouped Query Attention. If
+ `num_key_value_heads=num_attention_heads`, the model will use Multi Head Attention (MHA), if
+ `num_key_value_heads=1` the model will use Multi Query Attention (MQA) otherwise GQA is used. When
+ converting a multi-head checkpoint to a GQA checkpoint, each group key and value head should be constructed
+            by meanpooling all the original heads within that group. For more details, check out [this
+ paper](https://arxiv.org/pdf/2305.13245.pdf). If it is not specified, will default to
+ `num_attention_heads`.
+ hidden_act (`str` or `function`, *optional*, defaults to `"silu"`):
+ The non-linear activation function (function or string) in the decoder.
+ max_position_embeddings (`int`, *optional*, defaults to 2048):
+ The maximum sequence length that this model might ever be used with. Llama 1 supports up to 2048 tokens,
+ Llama 2 up to 4096, CodeLlama up to 16384.
+ initializer_range (`float`, *optional*, defaults to 0.02):
+ The standard deviation of the truncated_normal_initializer for initializing all weight matrices.
+ rms_norm_eps (`float`, *optional*, defaults to 1e-06):
+ The epsilon used by the rms normalization layers.
+ use_cache (`bool`, *optional*, defaults to `True`):
+ Whether or not the model should return the last key/values attentions (not used by all models). Only
+ relevant if `config.is_decoder=True`.
+ pad_token_id (`int`, *optional*):
+ Padding token id.
+ bos_token_id (`int`, *optional*, defaults to 1):
+ Beginning of stream token id.
+ eos_token_id (`int`, *optional*, defaults to 2):
+ End of stream token id.
+ pretraining_tp (`int`, *optional*, defaults to 1):
+ Experimental feature. Tensor parallelism rank used during pretraining. Please refer to [this
+ document](https://huggingface.co/docs/transformers/main/perf_train_gpu_many#tensor-parallelism) to
+ understand more about it. This value is necessary to ensure exact reproducibility of the pretraining
+ results. Please refer to [this issue](https://github.com/pytorch/pytorch/issues/76232).
+ tie_word_embeddings (`bool`, *optional*, defaults to `False`):
+ Whether to tie weight embeddings
+ rope_theta (`float`, *optional*, defaults to 10000.0):
+ The base period of the RoPE embeddings.
+ rope_scaling (`Dict`, *optional*):
+ Dictionary containing the scaling configuration for the RoPE embeddings. NOTE: if you apply new rope type
+ and you expect the model to work on longer `max_position_embeddings`, we recommend you to update this value
+ accordingly.
+ Expected contents:
+ `rope_type` (`str`):
+ The sub-variant of RoPE to use. Can be one of ['default', 'linear', 'dynamic', 'yarn', 'longrope',
+ 'llama3'], with 'default' being the original RoPE implementation.
+ `factor` (`float`, *optional*):
+ Used with all rope types except 'default'. The scaling factor to apply to the RoPE embeddings. In
+ most scaling types, a `factor` of x will enable the model to handle sequences of length x *
+ original maximum pre-trained length.
+ `original_max_position_embeddings` (`int`, *optional*):
+ Used with 'dynamic', 'longrope' and 'llama3'. The original max position embeddings used during
+ pretraining.
+ `attention_factor` (`float`, *optional*):
+ Used with 'yarn' and 'longrope'. The scaling factor to be applied on the attention
+ computation. If unspecified, it defaults to value recommended by the implementation, using the
+ `factor` field to infer the suggested value.
+ `beta_fast` (`float`, *optional*):
+ Only used with 'yarn'. Parameter to set the boundary for extrapolation (only) in the linear
+ ramp function. If unspecified, it defaults to 32.
+ `beta_slow` (`float`, *optional*):
+ Only used with 'yarn'. Parameter to set the boundary for interpolation (only) in the linear
+ ramp function. If unspecified, it defaults to 1.
+ `short_factor` (`List[float]`, *optional*):
+ Only used with 'longrope'. The scaling factor to be applied to short contexts (<
+ `original_max_position_embeddings`). Must be a list of numbers with the same length as the hidden
+ size divided by the number of attention heads divided by 2
+ `long_factor` (`List[float]`, *optional*):
+                    Only used with 'longrope'. The scaling factor to be applied to long contexts (>
+ `original_max_position_embeddings`). Must be a list of numbers with the same length as the hidden
+ size divided by the number of attention heads divided by 2
+ `low_freq_factor` (`float`, *optional*):
+ Only used with 'llama3'. Scaling factor applied to low frequency components of the RoPE
+ `high_freq_factor` (`float`, *optional*):
+ Only used with 'llama3'. Scaling factor applied to high frequency components of the RoPE
+ attention_bias (`bool`, *optional*, defaults to `False`):
+ Whether to use a bias in the query, key, value and output projection layers during self-attention.
+ attention_dropout (`float`, *optional*, defaults to 0.0):
+ The dropout ratio for the attention probabilities.
+ mlp_bias (`bool`, *optional*, defaults to `False`):
+ Whether to use a bias in up_proj, down_proj and gate_proj layers in the MLP layers.
+
+ ```python
+ >>> from transformers import LlamaModel, LlamaConfig
+
+ >>> # Initializing a LLaMA llama-7b style configuration
+ >>> configuration = LlamaConfig()
+
+ >>> # Initializing a model from the llama-7b style configuration
+ >>> model = LlamaModel(configuration)
+
+ >>> # Accessing the model configuration
+ >>> configuration = model.config
+ ```"""
+
+ model_type = "llama"
+ keys_to_ignore_at_inference = ["past_key_values"]
+
+ def __init__(
+ self,
+ vocab_size=32000,
+ hidden_size=4096,
+ intermediate_size=11008,
+ num_hidden_layers=32,
+ num_attention_heads=32,
+ num_key_value_heads=None,
+ hidden_act="silu",
+ max_position_embeddings=2048,
+ initializer_range=0.02,
+ rms_norm_eps=1e-6,
+ use_cache=True,
+ pad_token_id=None,
+ bos_token_id=1,
+ eos_token_id=2,
+ pretraining_tp=1,
+ tie_word_embeddings=False,
+ rope_theta=10000.0,
+ rope_scaling=None,
+ attention_bias=False,
+ attention_dropout=0.0,
+ mlp_bias=False,
+ **kwargs,
+ ):
+ self.vocab_size = vocab_size
+ self.max_position_embeddings = max_position_embeddings
+ self.hidden_size = hidden_size
+ self.intermediate_size = intermediate_size
+ self.num_hidden_layers = num_hidden_layers
+ self.num_attention_heads = num_attention_heads
+
+ # for backward compatibility
+ if num_key_value_heads is None:
+ num_key_value_heads = num_attention_heads
+
+ self.num_key_value_heads = num_key_value_heads
+ self.hidden_act = hidden_act
+ self.initializer_range = initializer_range
+ self.rms_norm_eps = rms_norm_eps
+ self.pretraining_tp = pretraining_tp
+ self.use_cache = use_cache
+ self.rope_theta = rope_theta
+ self.rope_scaling = rope_scaling
+ self.attention_bias = attention_bias
+ self.attention_dropout = attention_dropout
+ self.mlp_bias = mlp_bias
+
+ # Validate the correctness of rotary position embeddings parameters
+ # BC: if there is a 'type' field, move it to 'rope_type'.
+ if self.rope_scaling is not None and "type" in self.rope_scaling:
+ self.rope_scaling["rope_type"] = self.rope_scaling["type"]
+ rope_config_validation(self)
+
+ super().__init__(
+ pad_token_id=pad_token_id,
+ bos_token_id=bos_token_id,
+ eos_token_id=eos_token_id,
+ tie_word_embeddings=tie_word_embeddings,
+ **kwargs,
+ )
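+
+
+# --- Illustrative sketch (not part of the upstream transformers file) ---------------------------------------
+# A hypothetical example of the `rope_scaling` dictionary documented above. The legacy {"type": ...} key is
+# remapped to "rope_type" for backward compatibility before `rope_config_validation` runs; the new-style key
+# is used directly here.
+def _demo_rope_scaled_config() -> "LlamaConfig":
+    """Hypothetical helper: a small config that doubles the usable context length via linear RoPE scaling."""
+    return LlamaConfig(
+        hidden_size=512,
+        intermediate_size=1024,
+        num_hidden_layers=2,
+        num_attention_heads=8,
+        max_position_embeddings=4096,
+        rope_scaling={"rope_type": "linear", "factor": 2.0},
+    )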
diff --git a/transformers_4_44_2__modeling_attn_mask_utils.py b/transformers_4_44_2__modeling_attn_mask_utils.py
new file mode 100644
index 0000000000000000000000000000000000000000..cb0d443c8a8e24ed7c911e27db48ce251e7c99f0
--- /dev/null
+++ b/transformers_4_44_2__modeling_attn_mask_utils.py
@@ -0,0 +1,482 @@
+# Copyright 2023 The HuggingFace Team. All rights reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+from dataclasses import dataclass
+from typing import List, Optional, Tuple, Union
+
+import torch
+
+
+@dataclass
+class AttentionMaskConverter:
+ """
+ A utility attention mask class that allows one to:
+ - Create a causal 4d mask
+        - Create a causal 4d mask with sliding window
+ - Convert a 2d attention mask (batch_size, query_length) to a 4d attention mask (batch_size, 1, query_length,
+ key_value_length) that can be multiplied with attention scores
+
+ Examples:
+
+ ```python
+ >>> import torch
+ >>> from transformers.modeling_attn_mask_utils import AttentionMaskConverter
+
+ >>> converter = AttentionMaskConverter(True)
+ >>> converter.to_4d(torch.tensor([[0, 0, 0, 1, 1]]), 5, key_value_length=5, dtype=torch.float32)
+ tensor([[[[-3.4028e+38, -3.4028e+38, -3.4028e+38, -3.4028e+38, -3.4028e+38],
+ [-3.4028e+38, -3.4028e+38, -3.4028e+38, -3.4028e+38, -3.4028e+38],
+ [-3.4028e+38, -3.4028e+38, -3.4028e+38, -3.4028e+38, -3.4028e+38],
+ [-3.4028e+38, -3.4028e+38, -3.4028e+38, 0.0000e+00, -3.4028e+38],
+ [-3.4028e+38, -3.4028e+38, -3.4028e+38, 0.0000e+00, 0.0000e+00]]]])
+ ```
+
+ Parameters:
+ is_causal (`bool`):
+ Whether the attention mask should be a uni-directional (causal) or bi-directional mask.
+
+ sliding_window (`int`, *optional*):
+            Optionally, the sliding window masks can be created if `sliding_window` is set to a positive integer.
+ """
+
+ is_causal: bool
+ sliding_window: int
+
+ def __init__(self, is_causal: bool, sliding_window: Optional[int] = None):
+ self.is_causal = is_causal
+ self.sliding_window = sliding_window
+
+ if self.sliding_window is not None and self.sliding_window <= 0:
+ raise ValueError(
+ f"Make sure that when passing `sliding_window` that its value is a strictly positive integer, not `{self.sliding_window}`"
+ )
+
+ def to_causal_4d(
+ self,
+ batch_size: int,
+ query_length: int,
+ key_value_length: int,
+ dtype: torch.dtype,
+ device: Union[torch.device, "str"] = "cpu",
+ ) -> Optional[torch.Tensor]:
+ """
+ Creates a causal 4D mask of (bsz, head_dim=1, query_length, key_value_length) shape and adds large negative
+ bias to upper right hand triangular matrix (causal mask).
+ """
+ if not self.is_causal:
+ raise ValueError(f"Please use `to_causal_4d` only if {self.__class__} has `is_causal` set to True.")
+
+ # If shape is not cached, create a new causal mask and cache it
+ input_shape = (batch_size, query_length)
+ past_key_values_length = key_value_length - query_length
+
+ # create causal mask
+ # [bsz, seq_len] -> [bsz, 1, tgt_seq_len, src_seq_len]
+ causal_4d_mask = None
+ if input_shape[-1] > 1 or self.sliding_window is not None:
+ causal_4d_mask = self._make_causal_mask(
+ input_shape,
+ dtype,
+ device=device,
+ past_key_values_length=past_key_values_length,
+ sliding_window=self.sliding_window,
+ )
+
+ return causal_4d_mask
+
+ def to_4d(
+ self,
+ attention_mask_2d: torch.Tensor,
+ query_length: int,
+ dtype: torch.dtype,
+ key_value_length: Optional[int] = None,
+ ) -> torch.Tensor:
+ """
+ Converts 2D attention mask to 4D attention mask by expanding mask to (bsz, head_dim=1, query_length,
+ key_value_length) shape and by adding a large negative bias to not-attended positions. If attention_mask is
+ causal, a causal mask will be added.
+ """
+ input_shape = (attention_mask_2d.shape[0], query_length)
+
+ # create causal mask
+ # [bsz, seq_len] -> [bsz, 1, tgt_seq_len, src_seq_len]
+ causal_4d_mask = None
+ if (input_shape[-1] > 1 or self.sliding_window is not None) and self.is_causal:
+ if key_value_length is None:
+ raise ValueError(
+ "This attention mask converter is causal. Make sure to pass `key_value_length` to correctly create a causal mask."
+ )
+
+ past_key_values_length = key_value_length - query_length
+ causal_4d_mask = self._make_causal_mask(
+ input_shape,
+ dtype,
+ device=attention_mask_2d.device,
+ past_key_values_length=past_key_values_length,
+ sliding_window=self.sliding_window,
+ )
+ elif self.sliding_window is not None:
+ raise NotImplementedError("Sliding window is currently only implemented for causal masking")
+
+ # [bsz, seq_len] -> [bsz, 1, tgt_seq_len, src_seq_len]
+ expanded_attn_mask = self._expand_mask(attention_mask_2d, dtype, tgt_len=input_shape[-1]).to(
+ attention_mask_2d.device
+ )
+
+ if causal_4d_mask is not None:
+ expanded_attn_mask = causal_4d_mask.masked_fill(expanded_attn_mask.bool(), torch.finfo(dtype).min)
+
+ # expanded_attn_mask + causal_4d_mask can cause some overflow
+ expanded_4d_mask = expanded_attn_mask
+
+ return expanded_4d_mask
+
+ @staticmethod
+ def _make_causal_mask(
+ input_ids_shape: torch.Size,
+ dtype: torch.dtype,
+ device: torch.device,
+ past_key_values_length: int = 0,
+ sliding_window: Optional[int] = None,
+ ):
+ """
+        Make causal mask used for uni-directional (causal) self-attention.
+ """
+ bsz, tgt_len = input_ids_shape
+ mask = torch.full((tgt_len, tgt_len), torch.finfo(dtype).min, device=device)
+ mask_cond = torch.arange(mask.size(-1), device=device)
+ mask.masked_fill_(mask_cond < (mask_cond + 1).view(mask.size(-1), 1), 0)
+
+ mask = mask.to(dtype)
+
+ if past_key_values_length > 0:
+ mask = torch.cat([torch.zeros(tgt_len, past_key_values_length, dtype=dtype, device=device), mask], dim=-1)
+
+ # add lower triangular sliding window mask if necessary
+ if sliding_window is not None:
+ diagonal = past_key_values_length - sliding_window - 1
+
+ context_mask = torch.tril(torch.ones_like(mask, dtype=torch.bool), diagonal=diagonal)
+ mask.masked_fill_(context_mask, torch.finfo(dtype).min)
+
+ return mask[None, None, :, :].expand(bsz, 1, tgt_len, tgt_len + past_key_values_length)
+
+ @staticmethod
+ def _expand_mask(mask: torch.Tensor, dtype: torch.dtype, tgt_len: Optional[int] = None):
+ """
+ Expands attention_mask from `[bsz, seq_len]` to `[bsz, 1, tgt_seq_len, src_seq_len]`.
+ """
+ bsz, src_len = mask.size()
+ tgt_len = tgt_len if tgt_len is not None else src_len
+
+ expanded_mask = mask[:, None, None, :].expand(bsz, 1, tgt_len, src_len).to(dtype)
+
+ inverted_mask = 1.0 - expanded_mask
+
+ return inverted_mask.masked_fill(inverted_mask.to(torch.bool), torch.finfo(dtype).min)
+
+ @staticmethod
+ def _unmask_unattended(
+ expanded_mask: torch.FloatTensor,
+ min_dtype: float,
+ ):
+ # fmt: off
+ """
+ Attend to all tokens in masked rows from the expanded attention mask, for example the relevant first rows when
+ using left padding. This is required by F.scaled_dot_product_attention memory-efficient attention path.
+ Details: https://github.com/pytorch/pytorch/issues/110213
+
+ `expanded_mask` is [bsz, num_masks, tgt_seq_len, src_seq_len] or [bsz, tgt_seq_len, src_seq_len].
+ `attention_mask` is [bsz, src_seq_len].
+
+ The dimension num_masks of `expanded_mask` is most often 1, but it can also be the number of heads in the case of alibi attention bias.
+
+ For example, if `expanded_mask` is (e.g. here left-padding case)
+ ```
+ [[[[0, 0, 0],
+ [0, 0, 0],
+ [0, 0, 1]]],
+ [[[1, 0, 0],
+ [1, 1, 0],
+ [1, 1, 1]]],
+ [[[0, 0, 0],
+ [0, 1, 0],
+ [0, 1, 1]]]]
+ ```
+ then the modified `expanded_mask` will be
+ ```
+ [[[[1, 1, 1], <-- modified
+ [1, 1, 1], <-- modified
+ [0, 0, 1]]],
+ [[[1, 0, 0],
+ [1, 1, 0],
+ [1, 1, 1]]],
+ [[[1, 1, 1], <-- modified
+ [0, 1, 0],
+ [0, 1, 1]]]]
+ ```
+ """
+ # fmt: on
+ if expanded_mask.dtype == torch.bool:
+ raise ValueError(
+ "AttentionMaskConverter._unmask_unattended expects a float `expanded_mask`, got a BoolTensor."
+ )
+
+ return expanded_mask.mul(~torch.all(expanded_mask == min_dtype, dim=-1, keepdim=True))
+
+ @staticmethod
+ def _ignore_causal_mask_sdpa(
+ attention_mask: Optional[torch.Tensor],
+ inputs_embeds: torch.Tensor,
+ past_key_values_length: int,
+ sliding_window: Optional[int] = None,
+ is_training: bool = False,
+ ) -> bool:
+ """
+        Detects whether the optional user-specified attention_mask and the automatically created causal mask can be ignored when PyTorch's SDPA is used, relying instead on SDPA's `is_causal` argument.
+
+ In case no token is masked in the `attention_mask` argument, if `query_length == 1` or
+ `key_value_length == query_length`, we rather rely on SDPA `is_causal` argument to use causal/non-causal masks,
+ allowing to dispatch to the flash attention kernel (that can otherwise not be used if a custom `attn_mask` is passed).
+ """
+
+ _, query_length = inputs_embeds.shape[0], inputs_embeds.shape[1]
+ key_value_length = query_length + past_key_values_length
+
+ is_tracing = (
+ torch.jit.is_tracing()
+ or isinstance(inputs_embeds, torch.fx.Proxy)
+ or (hasattr(torch, "_dynamo") and torch._dynamo.is_compiling())
+ )
+
+ ignore_causal_mask = False
+
+ if attention_mask is None:
+            # TODO: When tracing with TorchDynamo with fullgraph=True, the model is recompiled depending on the input shape, thus SDPA's `is_causal` argument is rightfully updated (see https://gist.github.com/fxmarty/1313f39037fc1c112508989628c57363). However, when using `torch.export`
+            # or `torch.onnx.dynamo_export`, we must pass an example input, and `is_causal` behavior is hard-coded. If a user exports a model with q_len > 1, the exported model will hard-code `is_causal=True` which is in general wrong (see https://github.com/pytorch/pytorch/issues/108108).
+ # Thus, we only set `ignore_causal_mask = True` if the model is set to training.
+ #
+ # Besides, jit.trace can not handle the `q_len > 1` condition for `is_causal` ("TypeError: scaled_dot_product_attention(): argument 'is_causal' must be bool, not Tensor").
+ if (
+ (is_training or not is_tracing)
+ and (query_length == 1 or key_value_length == query_length)
+ and (sliding_window is None or key_value_length < sliding_window)
+ ):
+ ignore_causal_mask = True
+ elif sliding_window is None or key_value_length < sliding_window:
+ if len(attention_mask.shape) == 4:
+ return False
+ elif (is_training or not is_tracing) and torch.all(attention_mask == 1):
+ if query_length == 1 or key_value_length == query_length:
+ # For query_length == 1, causal attention and bi-directional attention are the same.
+ ignore_causal_mask = True
+
+ # Unfortunately, for query_length > 1 and key_value_length != query_length, we cannot generally ignore the attention mask, as SDPA causal mask generation
+ # may be wrong. We will set `is_causal=False` in SDPA and rely on Transformers attention_mask instead, hence not setting it to None here.
+ # Reference: https://github.com/pytorch/pytorch/issues/108108
+ # TODO: maybe revisit this with https://github.com/pytorch/pytorch/pull/114823 in PyTorch 2.3.
+
+ return ignore_causal_mask
+
+
+def _prepare_4d_causal_attention_mask(
+ attention_mask: Optional[torch.Tensor],
+ input_shape: Union[torch.Size, Tuple, List],
+ inputs_embeds: torch.Tensor,
+ past_key_values_length: int,
+ sliding_window: Optional[int] = None,
+):
+ """
+ Creates a causal 4D mask of shape `(batch_size, 1, query_length, key_value_length)` from a 2D mask of shape
+ `(batch_size, key_value_length)`
+
+ Args:
+ attention_mask (`torch.Tensor` or `None`):
+ A 2D attention mask of shape `(batch_size, key_value_length)`
+ input_shape (`tuple(int)` or `list(int)` or `torch.Size`):
+ The input shape should be a tuple that defines `(batch_size, query_length)`.
+ inputs_embeds (`torch.Tensor`):
+ The embedded inputs as a torch Tensor.
+ past_key_values_length (`int`):
+ The length of the key value cache.
+ sliding_window (`int`, *optional*):
+ If the model uses windowed attention, a sliding window should be passed.
+ """
+ attn_mask_converter = AttentionMaskConverter(is_causal=True, sliding_window=sliding_window)
+
+ key_value_length = input_shape[-1] + past_key_values_length
+
+ # 4d mask is passed through the layers
+ if attention_mask is not None and len(attention_mask.shape) == 2:
+ attention_mask = attn_mask_converter.to_4d(
+ attention_mask, input_shape[-1], key_value_length=key_value_length, dtype=inputs_embeds.dtype
+ )
+ elif attention_mask is not None and len(attention_mask.shape) == 4:
+ expected_shape = (input_shape[0], 1, input_shape[1], key_value_length)
+ if tuple(attention_mask.shape) != expected_shape:
+ raise ValueError(
+ f"Incorrect 4D attention_mask shape: {tuple(attention_mask.shape)}; expected: {expected_shape}."
+ )
+ else:
+ # if the 4D mask has correct shape - invert it and fill with negative infinity
+ inverted_mask = 1.0 - attention_mask
+ attention_mask = inverted_mask.masked_fill(
+ inverted_mask.to(torch.bool), torch.finfo(inputs_embeds.dtype).min
+ )
+ else:
+ attention_mask = attn_mask_converter.to_causal_4d(
+ input_shape[0], input_shape[-1], key_value_length, dtype=inputs_embeds.dtype, device=inputs_embeds.device
+ )
+
+ return attention_mask
+
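+
+# --- Illustrative sketch (not part of the upstream transformers API) ---------------------------------------
+# A hypothetical, minimal usage example for `_prepare_4d_causal_attention_mask` above: a 2D padding mask of
+# shape (batch_size, key_value_length) is expanded into the additive 4D causal mask consumed by the attention
+# layers (0.0 for visible positions, a large negative value for masked ones).
+def _demo_prepare_causal_mask() -> torch.Tensor:
+    """Hypothetical helper: left-padded batch of one sequence of 4 tokens with no past key/values."""
+    padding_mask = torch.tensor([[0, 1, 1, 1]])  # the first token is padding
+    inputs_embeds = torch.zeros(1, 4, 8)         # dummy (batch, seq_len, hidden) embeddings
+    return _prepare_4d_causal_attention_mask(
+        attention_mask=padding_mask,
+        input_shape=(1, 4),
+        inputs_embeds=inputs_embeds,
+        past_key_values_length=0,
+    )  # -> additive mask of shape (1, 1, 4, 4)
+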
+
+# Adapted from _prepare_4d_causal_attention_mask
+def _prepare_4d_causal_attention_mask_for_sdpa(
+ attention_mask: Optional[torch.Tensor],
+ input_shape: Union[torch.Size, Tuple, List],
+ inputs_embeds: torch.Tensor,
+ past_key_values_length: int,
+ sliding_window: Optional[int] = None,
+):
+ """
+ Prepares the correct `attn_mask` argument to be used by `torch.nn.functional.scaled_dot_product_attention`.
+
+ In case no token is masked in the `attention_mask` argument, we simply set it to `None` for the cases `query_length == 1` and
+ `key_value_length == query_length`, and rely instead on SDPA `is_causal` argument to use causal/non-causal masks,
+ allowing to dispatch to the flash attention kernel (that can otherwise not be used if a custom `attn_mask` is passed).
+ """
+ attn_mask_converter = AttentionMaskConverter(is_causal=True, sliding_window=sliding_window)
+
+ key_value_length = input_shape[-1] + past_key_values_length
+
+ # torch.jit.trace, symbolic_trace and torchdynamo with fullgraph=True are unable to capture the controlflow `is_causal=attention_mask is None and q_len > 1`
+ # used as an SDPA argument. We keep compatibility with these tracing tools by always using SDPA's `attn_mask` argument in case we are tracing.
+ # TODO: For dynamo, rather use a check on fullgraph=True once this is possible (https://github.com/pytorch/pytorch/pull/120400).
+ is_tracing = (
+ torch.jit.is_tracing()
+ or isinstance(inputs_embeds, torch.fx.Proxy)
+ or (hasattr(torch, "_dynamo") and torch._dynamo.is_compiling())
+ )
+
+ ignore_causal_mask = AttentionMaskConverter._ignore_causal_mask_sdpa(
+ attention_mask=attention_mask,
+ inputs_embeds=inputs_embeds,
+ past_key_values_length=past_key_values_length,
+ sliding_window=sliding_window,
+ )
+
+ if ignore_causal_mask:
+ expanded_4d_mask = None
+ elif attention_mask is None:
+ expanded_4d_mask = attn_mask_converter.to_causal_4d(
+ input_shape[0], input_shape[-1], key_value_length, dtype=inputs_embeds.dtype, device=inputs_embeds.device
+ )
+ else:
+ if attention_mask.dim() == 4:
+ # in this case we assume that the mask comes already in inverted form and requires no inversion or slicing
+ if attention_mask.max() != 0:
+ raise ValueError("Custom 4D attention mask should be passed in inverted form with max==0`")
+ expanded_4d_mask = attention_mask
+ else:
+ expanded_4d_mask = attn_mask_converter.to_4d(
+ attention_mask,
+ input_shape[-1],
+ dtype=inputs_embeds.dtype,
+ key_value_length=key_value_length,
+ )
+
+ # Attend to all tokens in masked rows from the causal_mask, for example the relevant first rows when
+ # using left padding. This is required by F.scaled_dot_product_attention memory-efficient attention path.
+ # Details: https://github.com/pytorch/pytorch/issues/110213
+ if not is_tracing and expanded_4d_mask.device.type == "cuda":
+ expanded_4d_mask = AttentionMaskConverter._unmask_unattended(
+ expanded_4d_mask, min_dtype=torch.finfo(inputs_embeds.dtype).min
+ )
+
+ return expanded_4d_mask
+
+
+def _prepare_4d_attention_mask(mask: torch.Tensor, dtype: torch.dtype, tgt_len: Optional[int] = None):
+ """
+ Creates a non-causal 4D mask of shape `(batch_size, 1, query_length, key_value_length)` from a 2D mask of shape
+ `(batch_size, key_value_length)`
+
+ Args:
+ mask (`torch.Tensor`):
+ A 2D attention mask of shape `(batch_size, key_value_length)`
+ dtype (`torch.dtype`):
+ The torch dtype the created mask shall have.
+ tgt_len (`int`):
+ The target length or query length the created mask shall have.
+ """
+ return AttentionMaskConverter._expand_mask(mask=mask, dtype=dtype, tgt_len=tgt_len)
+
+
+def _prepare_4d_attention_mask_for_sdpa(mask: torch.Tensor, dtype: torch.dtype, tgt_len: Optional[int] = None):
+ """
+ Creates a non-causal 4D mask of shape `(batch_size, 1, query_length, key_value_length)` from a 2D mask of shape
+ `(batch_size, key_value_length)`
+
+ Args:
+ mask (`torch.Tensor`):
+ A 2D attention mask of shape `(batch_size, key_value_length)`
+ dtype (`torch.dtype`):
+ The torch dtype the created mask shall have.
+ tgt_len (`int`):
+ The target length or query length the created mask shall have.
+ """
+ _, key_value_length = mask.shape
+ tgt_len = tgt_len if tgt_len is not None else key_value_length
+
+ is_tracing = (
+ torch.jit.is_tracing()
+ or isinstance(mask, torch.fx.Proxy)
+ or (hasattr(torch, "_dynamo") and torch._dynamo.is_compiling())
+ )
+
+ # torch.jit.trace, symbolic_trace and torchdynamo with fullgraph=True are unable to capture data-dependent controlflows.
+ if not is_tracing and torch.all(mask == 1):
+ return None
+ else:
+ return AttentionMaskConverter._expand_mask(mask=mask, dtype=dtype, tgt_len=tgt_len)
+
+
+def _create_4d_causal_attention_mask(
+ input_shape: Union[torch.Size, Tuple, List],
+ dtype: torch.dtype,
+ device: torch.device,
+ past_key_values_length: int = 0,
+ sliding_window: Optional[int] = None,
+) -> Optional[torch.Tensor]:
+ """
+ Creates a causal 4D mask of shape `(batch_size, 1, query_length, key_value_length)`
+
+ Args:
+ input_shape (`tuple(int)` or `list(int)` or `torch.Size`):
+ The input shape should be a tuple that defines `(batch_size, query_length)`.
+ dtype (`torch.dtype`):
+ The torch dtype the created mask shall have.
+        device (`torch.device`):
+ The torch device the created mask shall have.
+ sliding_window (`int`, *optional*):
+ If the model uses windowed attention, a sliding window should be passed.
+ """
+ attn_mask_converter = AttentionMaskConverter(is_causal=True, sliding_window=sliding_window)
+
+ key_value_length = past_key_values_length + input_shape[-1]
+ attention_mask = attn_mask_converter.to_causal_4d(
+ input_shape[0], input_shape[-1], key_value_length, dtype=dtype, device=device
+ )
+
+ return attention_mask
diff --git a/transformers_4_44_2__modeling_flash_attention_utils_backward_compat.py b/transformers_4_44_2__modeling_flash_attention_utils_backward_compat.py
new file mode 100644
index 0000000000000000000000000000000000000000..1bb66b746cf0e94f3fb4774f57e75d1afad3399c
--- /dev/null
+++ b/transformers_4_44_2__modeling_flash_attention_utils_backward_compat.py
@@ -0,0 +1,348 @@
+# coding=utf-8
+# Copyright 2024 The Fairseq Authors and the HuggingFace Inc. team. All rights reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+import inspect
+import os
+from typing import Optional, Tuple, Union
+
+
+import torch
+import torch.nn.functional as F
+
+from functools import lru_cache
+import importlib.metadata
+import importlib.util
+from packaging import version
+
+from transformers.utils import is_flash_attn_2_available
+
+
+if is_flash_attn_2_available():
+ try:
+ from flash_attn.bert_padding import index_first_axis, pad_input, unpad_input # noqa
+ from flash_attn import flash_attn_func, flash_attn_varlen_func
+ _flash_supports_window_size = "window_size" in list(inspect.signature(flash_attn_func).parameters)
+ except ImportError:
+ raise "Unable to import flash_attn"
+
+
+def _is_package_available(pkg_name: str, return_version: bool = False) -> Union[Tuple[bool, str], bool]:
+ # Check if the package spec exists and grab its version to avoid importing a local directory
+ package_exists = importlib.util.find_spec(pkg_name) is not None
+ package_version = "N/A"
+ if package_exists:
+ try:
+ # Primary method to get the package version
+ package_version = importlib.metadata.version(pkg_name)
+ except importlib.metadata.PackageNotFoundError:
+ # Fallback method: Only for "torch" and versions containing "dev"
+ if pkg_name == "torch":
+ try:
+ package = importlib.import_module(pkg_name)
+ temp_version = getattr(package, "__version__", "N/A")
+ # Check if the version contains "dev"
+ if "dev" in temp_version:
+ package_version = temp_version
+ package_exists = True
+ else:
+ package_exists = False
+ except ImportError:
+ # If the package can't be imported, it's not available
+ package_exists = False
+ else:
+ # For packages other than "torch", don't attempt the fallback and set as not available
+ package_exists = False
+ if return_version:
+ return package_exists, package_version
+ else:
+ return package_exists
+
+
+@lru_cache()
+def is_flash_attn_greater_or_equal(library_version: str):
+ if not _is_package_available("flash_attn"):
+ return False
+
+ return version.parse(importlib.metadata.version("flash_attn")) >= version.parse(library_version)
+
+
+def _get_unpad_data(attention_mask: torch.Tensor) -> Tuple[torch.Tensor, torch.Tensor, int]:
+ """
+ Retrieves indexing data required to repad unpadded (ragged) tensors.
+
+ Arguments:
+ attention_mask (`torch.Tensor`):
+ Boolean or int tensor of shape (batch_size, sequence_length), 1 means valid and 0 means not valid.
+
+ Return:
+ indices (`torch.Tensor`):
+ The indices of non-masked tokens from the flattened input sequence.
+ cu_seqlens (`torch.Tensor`):
+ The cumulative sequence lengths, used to index into ragged (unpadded) tensors. `cu_seqlens` shape is (batch_size + 1,).
+ max_seqlen_in_batch (`int`):
+ Maximum sequence length in batch.
+ """
+ seqlens_in_batch = attention_mask.sum(dim=-1, dtype=torch.int32)
+ indices = torch.nonzero(attention_mask.flatten(), as_tuple=False).flatten()
+ max_seqlen_in_batch = seqlens_in_batch.max().item()
+ cu_seqlens = F.pad(torch.cumsum(seqlens_in_batch, dim=0, dtype=torch.int32), (1, 0))
+ return (
+ indices,
+ cu_seqlens,
+ max_seqlen_in_batch,
+ )
+
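+
+# --- Illustrative sketch (not part of the upstream transformers API) ---------------------------------------
+# A hypothetical worked example of `_get_unpad_data` above, for a batch of two sequences in which the first
+# one carries a single padding token.
+def _demo_get_unpad_data():
+    """Hypothetical helper: shows the indexing data produced for a small padded batch."""
+    attention_mask = torch.tensor([[1, 1, 0],
+                                   [1, 1, 1]])
+    indices, cu_seqlens, max_seqlen = _get_unpad_data(attention_mask)
+    # indices    -> tensor([0, 1, 3, 4, 5])  positions of the real tokens in the flattened batch
+    # cu_seqlens -> tensor([0, 2, 5])        per-sequence offsets into the flattened (unpadded) tensors
+    # max_seqlen -> 3                        longest unpadded sequence in the batch
+    return indices, cu_seqlens, max_seqlen
+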
+
+def _upad_input(
+ query_layer: torch.Tensor,
+ key_layer: torch.Tensor,
+ value_layer: torch.Tensor,
+ attention_mask: torch.Tensor,
+ query_length: int,
+):
+ """
+ Unpads query, key, and values tensors, using a single dimension for all tokens even though they belong to different batches.
+
+ This function is used instead of `flash_attn.bert_padding.unpad_input` in order to avoid the recomputation of the same intermediary
+ tensors for query, key, value tensors.
+
+ Arguments:
+ query_layer (`torch.Tensor`):
+ Query state with padding. Shape: (batch_size, query_length, num_heads, head_dim).
+ key_layer (`torch.Tensor`):
+ Key state with padding. Shape: (batch_size, kv_seq_len, num_key_value_heads, head_dim).
+ value_layer (`torch.Tensor`):
+ Value state with padding. Shape: (batch_size, kv_seq_len, num_key_value_heads, head_dim).
+ attention_mask (`torch.Tensor`):
+ Boolean or int tensor of shape (batch_size, sequence_length), 1 means valid and 0 means not valid.
+ query_length (`int`):
+ Target length.
+
+ Return:
+ query_layer (`torch.Tensor`):
+ Query state without padding. Shape: (total_target_length, num_heads, head_dim).
+        key_layer (`torch.Tensor`):
+            Key state without padding. Shape: (total_source_length, num_key_value_heads, head_dim).
+        value_layer (`torch.Tensor`):
+            Value state without padding. Shape: (total_source_length, num_key_value_heads, head_dim).
+ indices_q (`torch.Tensor`):
+ The indices of non-masked tokens from the flattened input target sequence.
+ (cu_seqlens_q, cu_seqlens_k) (`Tuple[int]`):
+ The cumulative sequence lengths for the target (query) and source (key, value), used to index into ragged (unpadded) tensors. `cu_seqlens` shape is (batch_size + 1,).
+ (max_seqlen_in_batch_q, max_seqlen_in_batch_k) (`Tuple[int]`):
+ Maximum sequence length in batch (`max_seqlen_in_batch_q` for the target sequence i.e. query, `max_seqlen_in_batch_k` for the source sequence i.e. key/value).
+ """
+ indices_k, cu_seqlens_k, max_seqlen_in_batch_k = _get_unpad_data(attention_mask)
+ batch_size, kv_seq_len, num_key_value_heads, head_dim = key_layer.shape
+
+ key_layer = index_first_axis(key_layer.reshape(batch_size * kv_seq_len, num_key_value_heads, head_dim), indices_k)
+ value_layer = index_first_axis(
+ value_layer.reshape(batch_size * kv_seq_len, num_key_value_heads, head_dim), indices_k
+ )
+ if query_length == kv_seq_len:
+ query_layer = index_first_axis(query_layer.reshape(batch_size * kv_seq_len, -1, head_dim), indices_k)
+ cu_seqlens_q = cu_seqlens_k
+ max_seqlen_in_batch_q = max_seqlen_in_batch_k
+ indices_q = indices_k
+ elif query_length == 1:
+ max_seqlen_in_batch_q = 1
+ cu_seqlens_q = torch.arange(
+ batch_size + 1, dtype=torch.int32, device=query_layer.device
+ ) # There is a memcpy here, that is very bad.
+ indices_q = cu_seqlens_q[:-1]
+ query_layer = query_layer.squeeze(1)
+ else:
+ # The -q_len: slice assumes left padding.
+ attention_mask = attention_mask[:, -query_length:]
+ query_layer, indices_q, cu_seqlens_q, max_seqlen_in_batch_q = unpad_input(query_layer, attention_mask)
+
+ return (
+ query_layer,
+ key_layer,
+ value_layer,
+ indices_q,
+ (cu_seqlens_q, cu_seqlens_k),
+ (max_seqlen_in_batch_q, max_seqlen_in_batch_k),
+ )
+
+
+def prepare_fa2_from_position_ids(query, key, value, position_ids):
+ """
+ This function returns necessary arguments to call `flash_attn_varlen_func`.
+ All three query, key, value states will be flattened.
+    Cumulative lengths of each example in the batch will be extracted from position_ids.
+
+    NOTE: ideally, cumulative lengths should be prepared at the data collator stage
+
+ Arguments:
+ query (`torch.Tensor`):
+ Query state with padding. Shape: (batch_size, query_length, num_heads, head_dim).
+ key (`torch.Tensor`):
+ Key state with padding. Shape: (batch_size, kv_seq_len, num_key_value_heads, head_dim).
+ value (`torch.Tensor`):
+ Value state with padding. Shape: (batch_size, kv_seq_len, num_key_value_heads, head_dim).
+ position_ids (`torch.Tensor`):
+            Tensor of token position indices of shape (batch_size, sequence_length). A position value of 0 marks the start of a new (packed) sequence.
+
+ Return:
+ query (`torch.Tensor`):
+ Query state without padding. Shape: (total_target_length, num_heads, head_dim).
+        key (`torch.Tensor`):
+            Flattened key state. Shape: (total_source_length, num_key_value_heads, head_dim).
+        value (`torch.Tensor`):
+            Flattened value state. Shape: (total_source_length, num_key_value_heads, head_dim).
+ indices_q (`torch.Tensor`):
+ The indices of non-masked tokens from the flattened input target sequence.
+ (cu_seqlens_q, cu_seqlens_k) (`Tuple[int]`):
+ The cumulative sequence lengths for the target (query) and source (key, value), used to index into ragged (unpadded) tensors. `cu_seqlens` shape is (batch_size + 1,).
+ (max_seqlen_in_batch_q, max_seqlen_in_batch_k) (`Tuple[int]`):
+ Maximum sequence length in batch (`max_seqlen_in_batch_q` for the target sequence i.e. query, `max_seqlen_in_batch_k` for the source sequence i.e. key/value).
+ """
+ query = query.view(-1, query.size(-2), query.size(-1))
+ key = key.view(-1, key.size(-2), key.size(-1))
+ value = value.view(-1, value.size(-2), value.size(-1))
+ position_ids = position_ids.flatten()
+ indices_q = torch.arange(position_ids.size(0), device=position_ids.device, dtype=torch.int32)
+
+ cu_seq_lens = torch.cat(
+ (
+ indices_q[position_ids == 0],
+ torch.tensor(position_ids.size(), device=position_ids.device, dtype=torch.int32),
+ )
+ )
+
+ max_length = position_ids.max() + 1
+
+ return (query, key, value, indices_q, (cu_seq_lens, cu_seq_lens), (max_length, max_length))
+
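+
+# --- Illustrative sketch (not part of the upstream transformers API) ---------------------------------------
+# A hypothetical worked example for `prepare_fa2_from_position_ids` above: in a packed batch, every position
+# equal to 0 marks the start of a new sequence, and the cumulative lengths are derived from those boundaries.
+def _demo_packed_cu_seqlens() -> torch.Tensor:
+    """Hypothetical helper: two sequences of lengths 3 and 2 packed into a single row."""
+    position_ids = torch.tensor([[0, 1, 2, 0, 1]])
+    flat = position_ids.flatten()
+    starts = torch.arange(flat.size(0), dtype=torch.int32)[flat == 0]  # tensor([0, 3], dtype=torch.int32)
+    cu_seq_lens = torch.cat((starts, torch.tensor([flat.size(0)], dtype=torch.int32)))
+    # cu_seq_lens -> tensor([0, 3, 5], dtype=torch.int32); maximum sequence length -> flat.max() + 1 == 3
+    return cu_seq_lens
+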
+
+def _flash_attention_forward(
+ query_states: torch.Tensor,
+ key_states: torch.Tensor,
+ value_states: torch.Tensor,
+ attention_mask: torch.Tensor,
+ query_length: int,
+ is_causal: bool,
+ dropout: float = 0.0,
+ position_ids: Optional[torch.Tensor] = None,
+ softmax_scale: Optional[float] = None,
+ sliding_window: Optional[int] = None,
+ use_top_left_mask: bool = False,
+ softcap: Optional[float] = None,
+    deterministic: Optional[bool] = None,
+):
+ """
+    Calls the forward method of Flash Attention - if the input hidden states contain at least one padding token,
+    first unpads the input, then computes the attention scores and pads the final attention scores.
+
+ Args:
+ query_states (`torch.Tensor`):
+ Input query states to be passed to Flash Attention API
+ key_states (`torch.Tensor`):
+ Input key states to be passed to Flash Attention API
+ value_states (`torch.Tensor`):
+ Input value states to be passed to Flash Attention API
+ attention_mask (`torch.Tensor`):
+ The padding mask - corresponds to a tensor of size `(batch_size, seq_len)` where 0 stands for the
+ position of padding tokens and 1 for the position of non-padding tokens.
+ dropout (`float`):
+ Attention dropout
+ softmax_scale (`float`, *optional*):
+ The scaling of QK^T before applying softmax. Default to 1 / sqrt(head_dim)
+ use_top_left_mask (`bool`, defaults to `False`):
+ flash_attn<2.1 generates top-left aligned causal mask, while what is needed here is bottom-right alignement, that was made default for flash_attn>=2.1. This attribute is used to handle this difference.
+ softcap (`float`, *optional*):
+ Softcap for the attention logits, used e.g. in gemma2.
+ deterministic (`bool`, *optional*):
+ Determines if the deterministic option introduced in flash_attn>=2.4.1 is enabled.
+ """
+ if not use_top_left_mask:
+ causal = is_causal
+ else:
+ # TODO: Remove the `query_length != 1` check once Flash Attention for RoCm is bumped to 2.1. For details, please see the comment in transformers.models.llama.modeling_llama.LlamaFlashAttention2.__init__.
+ causal = is_causal and query_length != 1
+
+ # Assuming 4D tensors, key_states.shape[1] is the key/value sequence length (source length).
+ use_sliding_windows = (
+ _flash_supports_window_size and sliding_window is not None and key_states.shape[1] > sliding_window
+ )
+ flash_kwargs = {"window_size": (sliding_window, sliding_window)} if use_sliding_windows else {}
+
+ if is_flash_attn_greater_or_equal("2.4.1"):
+ if deterministic is None:
+ deterministic = os.environ.get("FLASH_ATTENTION_DETERMINISTIC", "0") == "1"
+ flash_kwargs["deterministic"] = deterministic
+
+ if softcap is not None:
+ flash_kwargs["softcap"] = softcap
+
+ # Contains at least one padding token in the sequence
+ if attention_mask is not None:
+ batch_size = query_states.shape[0]
+ query_states, key_states, value_states, indices_q, cu_seq_lens, max_seq_lens = _upad_input(
+ query_states, key_states, value_states, attention_mask, query_length
+ )
+ cu_seqlens_q, cu_seqlens_k = cu_seq_lens
+ max_seqlen_in_batch_q, max_seqlen_in_batch_k = max_seq_lens
+
+ attn_output_unpad = flash_attn_varlen_func(
+ query_states,
+ key_states,
+ value_states,
+ cu_seqlens_q=cu_seqlens_q,
+ cu_seqlens_k=cu_seqlens_k,
+ max_seqlen_q=max_seqlen_in_batch_q,
+ max_seqlen_k=max_seqlen_in_batch_k,
+ dropout_p=dropout,
+ softmax_scale=softmax_scale,
+ causal=causal,
+ **flash_kwargs,
+ )
+ attn_output = pad_input(attn_output_unpad, indices_q, batch_size, query_length)
+
+    # If position_ids is provided and the batch contains packed sequences (the flattened position_ids are not
+    # monotonically increasing) and we are in the pre-fill/training stage (query_length != 1), use
+    # `flash_attn_varlen_func` to prevent cross-example attention while still allowing a padding-free approach.
+ elif position_ids is not None and query_length != 1 and not (torch.diff(position_ids, dim=-1) >= 0).all():
+ batch_size = query_states.size(0)
+ query_states, key_states, value_states, indices_q, cu_seq_lens, max_seq_lens = prepare_fa2_from_position_ids(
+ query_states, key_states, value_states, position_ids
+ )
+
+ cu_seqlens_q, cu_seqlens_k = cu_seq_lens
+ max_seqlen_in_batch_q, max_seqlen_in_batch_k = max_seq_lens
+
+ attn_output = flash_attn_varlen_func(
+ query_states,
+ key_states,
+ value_states,
+ cu_seqlens_q=cu_seqlens_q,
+ cu_seqlens_k=cu_seqlens_k,
+ max_seqlen_q=max_seqlen_in_batch_q,
+ max_seqlen_k=max_seqlen_in_batch_k,
+ dropout_p=dropout,
+ softmax_scale=softmax_scale,
+ causal=causal,
+ **flash_kwargs,
+ )
+
+ attn_output = attn_output.view(batch_size, -1, attn_output.size(-2), attn_output.size(-1))
+
+ else:
+ attn_output = flash_attn_func(
+ query_states, key_states, value_states, dropout, softmax_scale=softmax_scale, causal=causal, **flash_kwargs
+ )
+
+ return attn_output
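+
+
+# --- Illustrative sketch (not part of the upstream transformers API) ---------------------------------------
+# A hypothetical helper mirroring the branch order of `_flash_attention_forward` above. It only reports which
+# Flash Attention kernel path a given call would take and does not require flash_attn itself to run.
+def _demo_fa2_dispatch_path(
+    attention_mask: Optional[torch.Tensor], position_ids: Optional[torch.Tensor], query_length: int
+) -> str:
+    """Hypothetical helper: label of the code path `_flash_attention_forward` would follow."""
+    if attention_mask is not None:
+        return "varlen: unpad / re-pad (the batch contains padding tokens)"
+    if (
+        position_ids is not None
+        and query_length != 1
+        and not (torch.diff(position_ids, dim=-1) >= 0).all()
+    ):
+        return "varlen: packed sequences derived from position_ids"
+    return "dense: flash_attn_func on the full (batch, seq_len) tensors"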
diff --git a/transformers_4_44_2__modeling_outputs.py b/transformers_4_44_2__modeling_outputs.py
new file mode 100644
index 0000000000000000000000000000000000000000..9ca74004a825e52ef01e882825ff71cfdef74a4b
--- /dev/null
+++ b/transformers_4_44_2__modeling_outputs.py
@@ -0,0 +1,1753 @@
+# Copyright 2020 The HuggingFace Team. All rights reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+import warnings
+from dataclasses import dataclass
+from typing import Optional, Tuple
+
+import torch
+
+from transformers.utils import ModelOutput
+
+
+@dataclass
+class BaseModelOutput(ModelOutput):
+ """
+ Base class for model's outputs, with potential hidden states and attentions.
+
+ Args:
+ last_hidden_state (`torch.FloatTensor` of shape `(batch_size, sequence_length, hidden_size)`):
+ Sequence of hidden-states at the output of the last layer of the model.
+ hidden_states (`tuple(torch.FloatTensor)`, *optional*, returned when `output_hidden_states=True` is passed or when `config.output_hidden_states=True`):
+ Tuple of `torch.FloatTensor` (one for the output of the embeddings, if the model has an embedding layer, +
+ one for the output of each layer) of shape `(batch_size, sequence_length, hidden_size)`.
+
+ Hidden-states of the model at the output of each layer plus the optional initial embedding outputs.
+ attentions (`tuple(torch.FloatTensor)`, *optional*, returned when `output_attentions=True` is passed or when `config.output_attentions=True`):
+ Tuple of `torch.FloatTensor` (one for each layer) of shape `(batch_size, num_heads, sequence_length,
+ sequence_length)`.
+
+            Attention weights after the attention softmax, used to compute the weighted average in the self-attention
+ heads.
+ """
+
+ last_hidden_state: torch.FloatTensor = None
+ hidden_states: Optional[Tuple[torch.FloatTensor, ...]] = None
+ attentions: Optional[Tuple[torch.FloatTensor, ...]] = None
+
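+
+# --- Illustrative sketch (not part of the upstream transformers API) ---------------------------------------
+# A hypothetical, minimal demonstration of how these `ModelOutput` dataclasses behave: fields can be read by
+# attribute, by key, or by index, and `to_tuple()` drops any fields left as `None`.
+def _demo_base_model_output() -> "BaseModelOutput":
+    """Hypothetical helper: attribute / key / index access on a `BaseModelOutput`."""
+    out = BaseModelOutput(last_hidden_state=torch.zeros(1, 3, 4))
+    assert out.last_hidden_state is out["last_hidden_state"] is out[0]
+    assert out.attentions is None and len(out.to_tuple()) == 1
+    return out
+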
+
+@dataclass
+class BaseModelOutputWithNoAttention(ModelOutput):
+ """
+ Base class for model's outputs, with potential hidden states.
+
+ Args:
+ last_hidden_state (`torch.FloatTensor` of shape `(batch_size, num_channels, height, width)`):
+ Sequence of hidden-states at the output of the last layer of the model.
+ hidden_states (`tuple(torch.FloatTensor)`, *optional*, returned when `output_hidden_states=True` is passed or when `config.output_hidden_states=True`):
+ Tuple of `torch.FloatTensor` (one for the output of the embeddings, if the model has an embedding layer, +
+ one for the output of each layer) of shape `(batch_size, num_channels, height, width)`.
+
+ Hidden-states of the model at the output of each layer plus the optional initial embedding outputs.
+ """
+
+ last_hidden_state: torch.FloatTensor = None
+ hidden_states: Optional[Tuple[torch.FloatTensor, ...]] = None
+
+
+@dataclass
+class BaseModelOutputWithPooling(ModelOutput):
+ """
+ Base class for model's outputs that also contains a pooling of the last hidden states.
+
+ Args:
+ last_hidden_state (`torch.FloatTensor` of shape `(batch_size, sequence_length, hidden_size)`):
+ Sequence of hidden-states at the output of the last layer of the model.
+ pooler_output (`torch.FloatTensor` of shape `(batch_size, hidden_size)`):
+ Last layer hidden-state of the first token of the sequence (classification token) after further processing
+ through the layers used for the auxiliary pretraining task. E.g. for BERT-family of models, this returns
+ the classification token after processing through a linear layer and a tanh activation function. The linear
+ layer weights are trained from the next sentence prediction (classification) objective during pretraining.
+ hidden_states (`tuple(torch.FloatTensor)`, *optional*, returned when `output_hidden_states=True` is passed or when `config.output_hidden_states=True`):
+ Tuple of `torch.FloatTensor` (one for the output of the embeddings, if the model has an embedding layer, +
+ one for the output of each layer) of shape `(batch_size, sequence_length, hidden_size)`.
+
+ Hidden-states of the model at the output of each layer plus the optional initial embedding outputs.
+ attentions (`tuple(torch.FloatTensor)`, *optional*, returned when `output_attentions=True` is passed or when `config.output_attentions=True`):
+ Tuple of `torch.FloatTensor` (one for each layer) of shape `(batch_size, num_heads, sequence_length,
+ sequence_length)`.
+
+            Attention weights after the attention softmax, used to compute the weighted average in the self-attention
+ heads.
+ """
+
+ last_hidden_state: torch.FloatTensor = None
+ pooler_output: torch.FloatTensor = None
+ hidden_states: Optional[Tuple[torch.FloatTensor, ...]] = None
+ attentions: Optional[Tuple[torch.FloatTensor, ...]] = None
+
+
+@dataclass
+class BaseModelOutputWithPoolingAndNoAttention(ModelOutput):
+ """
+ Base class for model's outputs that also contains a pooling of the last hidden states.
+
+ Args:
+ last_hidden_state (`torch.FloatTensor` of shape `(batch_size, num_channels, height, width)`):
+ Sequence of hidden-states at the output of the last layer of the model.
+ pooler_output (`torch.FloatTensor` of shape `(batch_size, hidden_size)`):
+ Last layer hidden-state after a pooling operation on the spatial dimensions.
+ hidden_states (`tuple(torch.FloatTensor)`, *optional*, returned when `output_hidden_states=True` is passed or when `config.output_hidden_states=True`):
+ Tuple of `torch.FloatTensor` (one for the output of the embeddings, if the model has an embedding layer, +
+ one for the output of each layer) of shape `(batch_size, num_channels, height, width)`.
+
+ Hidden-states of the model at the output of each layer plus the optional initial embedding outputs.
+ """
+
+ last_hidden_state: torch.FloatTensor = None
+ pooler_output: torch.FloatTensor = None
+ hidden_states: Optional[Tuple[torch.FloatTensor, ...]] = None
+
+
+@dataclass
+class BaseModelOutputWithPast(ModelOutput):
+ """
+ Base class for model's outputs that may also contain past key/values (to speed up sequential decoding).
+
+ Args:
+ last_hidden_state (`torch.FloatTensor` of shape `(batch_size, sequence_length, hidden_size)`):
+ Sequence of hidden-states at the output of the last layer of the model.
+
+ If `past_key_values` is used only the last hidden-state of the sequences of shape `(batch_size, 1,
+ hidden_size)` is output.
+ past_key_values (`tuple(tuple(torch.FloatTensor))`, *optional*, returned when `use_cache=True` is passed or when `config.use_cache=True`):
+ Tuple of `tuple(torch.FloatTensor)` of length `config.n_layers`, with each tuple having 2 tensors of shape
+ `(batch_size, num_heads, sequence_length, embed_size_per_head)` and optionally if
+ `config.is_encoder_decoder=True` 2 additional tensors of shape `(batch_size, num_heads,
+ encoder_sequence_length, embed_size_per_head)`.
+
+ Contains pre-computed hidden-states (key and values in the self-attention blocks and optionally if
+ `config.is_encoder_decoder=True` in the cross-attention blocks) that can be used (see `past_key_values`
+ input) to speed up sequential decoding.
+ hidden_states (`tuple(torch.FloatTensor)`, *optional*, returned when `output_hidden_states=True` is passed or when `config.output_hidden_states=True`):
+ Tuple of `torch.FloatTensor` (one for the output of the embeddings, if the model has an embedding layer, +
+ one for the output of each layer) of shape `(batch_size, sequence_length, hidden_size)`.
+
+ Hidden-states of the model at the output of each layer plus the optional initial embedding outputs.
+ attentions (`tuple(torch.FloatTensor)`, *optional*, returned when `output_attentions=True` is passed or when `config.output_attentions=True`):
+ Tuple of `torch.FloatTensor` (one for each layer) of shape `(batch_size, num_heads, sequence_length,
+ sequence_length)`.
+
+ Attention weights after the attention softmax, used to compute the weighted average in the self-attention
+ heads.
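+
+ Example (an illustrative sketch of the cache layout only; every size below is made up for demonstration):
+
+ ```python
+ >>> import torch
+
+ >>> n_layers, batch, heads, past_len, head_dim = 2, 1, 4, 5, 8
+ >>> past = tuple(
+ ...     (torch.randn(batch, heads, past_len, head_dim), torch.randn(batch, heads, past_len, head_dim))
+ ...     for _ in range(n_layers)
+ ... )
+ >>> output = BaseModelOutputWithPast(last_hidden_state=torch.randn(batch, 1, 32), past_key_values=past)
+ >>> len(output.past_key_values), len(output.past_key_values[0])  # one (key, value) pair per layer
+ (2, 2)
+ ```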
+ """
+
+ last_hidden_state: torch.FloatTensor = None
+ past_key_values: Optional[Tuple[Tuple[torch.FloatTensor]]] = None
+ hidden_states: Optional[Tuple[torch.FloatTensor, ...]] = None
+ attentions: Optional[Tuple[torch.FloatTensor, ...]] = None
+
+
+@dataclass
+class BaseModelOutputWithCrossAttentions(ModelOutput):
+ """
+ Base class for model's outputs, with potential hidden states and attentions.
+
+ Args:
+ last_hidden_state (`torch.FloatTensor` of shape `(batch_size, sequence_length, hidden_size)`):
+ Sequence of hidden-states at the output of the last layer of the model.
+ hidden_states (`tuple(torch.FloatTensor)`, *optional*, returned when `output_hidden_states=True` is passed or when `config.output_hidden_states=True`):
+ Tuple of `torch.FloatTensor` (one for the output of the embeddings, if the model has an embedding layer, +
+ one for the output of each layer) of shape `(batch_size, sequence_length, hidden_size)`.
+
+ Hidden-states of the model at the output of each layer plus the optional initial embedding outputs.
+ attentions (`tuple(torch.FloatTensor)`, *optional*, returned when `output_attentions=True` is passed or when `config.output_attentions=True`):
+ Tuple of `torch.FloatTensor` (one for each layer) of shape `(batch_size, num_heads, sequence_length,
+ sequence_length)`.
+
+ Attention weights after the attention softmax, used to compute the weighted average in the self-attention
+ heads.
+ cross_attentions (`tuple(torch.FloatTensor)`, *optional*, returned when `output_attentions=True` and `config.add_cross_attention=True` is passed or when `config.output_attentions=True`):
+ Tuple of `torch.FloatTensor` (one for each layer) of shape `(batch_size, num_heads, sequence_length,
+ sequence_length)`.
+
+ Attention weights of the decoder's cross-attention layer, after the attention softmax, used to compute the
+ weighted average in the cross-attention heads.
+ """
+
+ last_hidden_state: torch.FloatTensor = None
+ hidden_states: Optional[Tuple[torch.FloatTensor, ...]] = None
+ attentions: Optional[Tuple[torch.FloatTensor, ...]] = None
+ cross_attentions: Optional[Tuple[torch.FloatTensor, ...]] = None
+
+
+@dataclass
+class BaseModelOutputWithPoolingAndCrossAttentions(ModelOutput):
+ """
+ Base class for model's outputs that also contains a pooling of the last hidden states.
+
+ Args:
+ last_hidden_state (`torch.FloatTensor` of shape `(batch_size, sequence_length, hidden_size)`):
+ Sequence of hidden-states at the output of the last layer of the model.
+ pooler_output (`torch.FloatTensor` of shape `(batch_size, hidden_size)`):
+ Last layer hidden-state of the first token of the sequence (classification token) after further processing
+ through the layers used for the auxiliary pretraining task. E.g. for BERT-family of models, this returns
+ the classification token after processing through a linear layer and a tanh activation function. The linear
+ layer weights are trained from the next sentence prediction (classification) objective during pretraining.
+ hidden_states (`tuple(torch.FloatTensor)`, *optional*, returned when `output_hidden_states=True` is passed or when `config.output_hidden_states=True`):
+ Tuple of `torch.FloatTensor` (one for the output of the embeddings, if the model has an embedding layer, +
+ one for the output of each layer) of shape `(batch_size, sequence_length, hidden_size)`.
+
+ Hidden-states of the model at the output of each layer plus the optional initial embedding outputs.
+ attentions (`tuple(torch.FloatTensor)`, *optional*, returned when `output_attentions=True` is passed or when `config.output_attentions=True`):
+ Tuple of `torch.FloatTensor` (one for each layer) of shape `(batch_size, num_heads, sequence_length,
+ sequence_length)`.
+
+ Attention weights after the attention softmax, used to compute the weighted average in the self-attention
+ heads.
+ cross_attentions (`tuple(torch.FloatTensor)`, *optional*, returned when `output_attentions=True` and `config.add_cross_attention=True` is passed or when `config.output_attentions=True`):
+ Tuple of `torch.FloatTensor` (one for each layer) of shape `(batch_size, num_heads, sequence_length,
+ sequence_length)`.
+
+ Attention weights of the decoder's cross-attention layer, after the attention softmax, used to compute the
+ weighted average in the cross-attention heads.
+ past_key_values (`tuple(tuple(torch.FloatTensor))`, *optional*, returned when `use_cache=True` is passed or when `config.use_cache=True`):
+ Tuple of `tuple(torch.FloatTensor)` of length `config.n_layers`, with each tuple having 2 tensors of shape
+ `(batch_size, num_heads, sequence_length, embed_size_per_head)` and optionally if
+ `config.is_encoder_decoder=True` 2 additional tensors of shape `(batch_size, num_heads,
+ encoder_sequence_length, embed_size_per_head)`.
+
+ Contains pre-computed hidden-states (key and values in the self-attention blocks and optionally if
+ `config.is_encoder_decoder=True` in the cross-attention blocks) that can be used (see `past_key_values`
+ input) to speed up sequential decoding.
+ """
+
+ last_hidden_state: torch.FloatTensor = None
+ pooler_output: torch.FloatTensor = None
+ hidden_states: Optional[Tuple[torch.FloatTensor, ...]] = None
+ past_key_values: Optional[Tuple[Tuple[torch.FloatTensor]]] = None
+ attentions: Optional[Tuple[torch.FloatTensor, ...]] = None
+ cross_attentions: Optional[Tuple[torch.FloatTensor, ...]] = None
+
+
+@dataclass
+class BaseModelOutputWithPastAndCrossAttentions(ModelOutput):
+ """
+ Base class for model's outputs that may also contain past key/values (to speed up sequential decoding).
+
+ Args:
+ last_hidden_state (`torch.FloatTensor` of shape `(batch_size, sequence_length, hidden_size)`):
+ Sequence of hidden-states at the output of the last layer of the model.
+
+ If `past_key_values` is used only the last hidden-state of the sequences of shape `(batch_size, 1,
+ hidden_size)` is output.
+ past_key_values (`tuple(tuple(torch.FloatTensor))`, *optional*, returned when `use_cache=True` is passed or when `config.use_cache=True`):
+ Tuple of `tuple(torch.FloatTensor)` of length `config.n_layers`, with each tuple having 2 tensors of shape
+ `(batch_size, num_heads, sequence_length, embed_size_per_head)` and optionally if
+ `config.is_encoder_decoder=True` 2 additional tensors of shape `(batch_size, num_heads,
+ encoder_sequence_length, embed_size_per_head)`.
+
+ Contains pre-computed hidden-states (key and values in the self-attention blocks and optionally if
+ `config.is_encoder_decoder=True` in the cross-attention blocks) that can be used (see `past_key_values`
+ input) to speed up sequential decoding.
+ hidden_states (`tuple(torch.FloatTensor)`, *optional*, returned when `output_hidden_states=True` is passed or when `config.output_hidden_states=True`):
+ Tuple of `torch.FloatTensor` (one for the output of the embeddings, if the model has an embedding layer, +
+ one for the output of each layer) of shape `(batch_size, sequence_length, hidden_size)`.
+
+ Hidden-states of the model at the output of each layer plus the optional initial embedding outputs.
+ attentions (`tuple(torch.FloatTensor)`, *optional*, returned when `output_attentions=True` is passed or when `config.output_attentions=True`):
+ Tuple of `torch.FloatTensor` (one for each layer) of shape `(batch_size, num_heads, sequence_length,
+ sequence_length)`.
+
+ Attention weights after the attention softmax, used to compute the weighted average in the self-attention
+ heads.
+ cross_attentions (`tuple(torch.FloatTensor)`, *optional*, returned when `output_attentions=True` and `config.add_cross_attention=True` is passed or when `config.output_attentions=True`):
+ Tuple of `torch.FloatTensor` (one for each layer) of shape `(batch_size, num_heads, sequence_length,
+ sequence_length)`.
+
+ Attention weights of the decoder's cross-attention layer, after the attention softmax, used to compute the
+ weighted average in the cross-attention heads.
+ """
+
+ last_hidden_state: torch.FloatTensor = None
+ past_key_values: Optional[Tuple[Tuple[torch.FloatTensor]]] = None
+ hidden_states: Optional[Tuple[torch.FloatTensor, ...]] = None
+ attentions: Optional[Tuple[torch.FloatTensor, ...]] = None
+ cross_attentions: Optional[Tuple[torch.FloatTensor, ...]] = None
+
+
+@dataclass
+class MoECausalLMOutputWithPast(ModelOutput):
+ """
+ Base class for causal language model (or autoregressive) outputs, as well as the Mixture of Experts' router
+ hidden-state terms used to train a MoE model.
+
+ Args:
+ loss (`torch.FloatTensor` of shape `(1,)`, *optional*, returned when `labels` is provided):
+ Language modeling loss (for next-token prediction).
+ logits (`torch.FloatTensor` of shape `(batch_size, sequence_length, config.vocab_size)`):
+ Prediction scores of the language modeling head (scores for each vocabulary token before SoftMax).
+ past_key_values (`tuple(tuple(torch.FloatTensor))`, *optional*, returned when `use_cache=True` is passed or when `config.use_cache=True`):
+ Tuple of `tuple(torch.FloatTensor)` of length `config.n_layers`, with each tuple having 2 tensors of shape
+ `(batch_size, num_heads, sequence_length, embed_size_per_head)`.
+
+ Contains pre-computed hidden-states (key and values in the self-attention blocks) that can be used (see
+ `past_key_values` input) to speed up sequential decoding.
+ hidden_states (`tuple(torch.FloatTensor)`, *optional*, returned when `output_hidden_states=True` is passed or when `config.output_hidden_states=True`):
+ Tuple of `torch.FloatTensor` (one for the output of the embeddings, if the model has an embedding layer, +
+ one for the output of each layer) of shape `(batch_size, sequence_length, hidden_size)`.
+
+ Hidden-states of the model at the output of each layer plus the optional initial embedding outputs.
+ attentions (`tuple(torch.FloatTensor)`, *optional*, returned when `output_attentions=True` is passed or when `config.output_attentions=True`):
+ Tuple of `torch.FloatTensor` (one for each layer) of shape `(batch_size, num_heads, sequence_length,
+ sequence_length)`.
+
+ Attention weights after the attention softmax, used to compute the weighted average in the self-attention
+ heads.
+ z_loss (`torch.FloatTensor`, *optional*, returned when `labels` is provided):
+ z_loss for the sparse modules.
+ aux_loss (`torch.FloatTensor`, *optional*, returned when `labels` is provided):
+ aux_loss for the sparse modules.
+ router_logits (`tuple(torch.FloatTensor)`, *optional*, returned when `output_router_logits=True` is passed or when `config.add_router_probs=True`):
+ Tuple of `torch.FloatTensor` (one for each layer) of shape `(batch_size, sequence_length, num_experts)`.
+
+ Router logits of the encoder model, useful to compute the auxiliary loss and the z_loss for the sparse
+ modules.
+ """
+
+ loss: Optional[torch.FloatTensor] = None
+ logits: torch.FloatTensor = None
+ past_key_values: Optional[Tuple[Tuple[torch.FloatTensor]]] = None
+ hidden_states: Optional[Tuple[torch.FloatTensor, ...]] = None
+ attentions: Optional[Tuple[torch.FloatTensor, ...]] = None
+ z_loss: torch.FloatTensor = None
+ aux_loss: torch.FloatTensor = None
+ router_logits: Optional[Tuple[torch.FloatTensor]] = None
+
+
+@dataclass
+class MoEModelOutput(ModelOutput):
+ """
+ Base class for model's outputs, with potential hidden states and attentions.
+
+ Args:
+ last_hidden_state (`torch.FloatTensor` of shape `(batch_size, sequence_length, hidden_size)`):
+ Sequence of hidden-states at the output of the last layer of the model.
+ hidden_states (`tuple(torch.FloatTensor)`, *optional*, returned when `output_hidden_states=True` is passed or when `config.output_hidden_states=True`):
+ Tuple of `torch.FloatTensor` (one for the output of the embeddings, if the model has an embedding layer, +
+ one for the output of each layer) of shape `(batch_size, sequence_length, hidden_size)`.
+
+ Hidden-states of the model at the output of each layer plus the optional initial embedding outputs.
+ attentions (`tuple(torch.FloatTensor)`, *optional*, returned when `output_attentions=True` is passed or when `config.output_attentions=True`):
+ Tuple of `torch.FloatTensor` (one for each layer) of shape `(batch_size, num_heads, sequence_length,
+ sequence_length)`.
+
+ Attention weights after the attention softmax, used to compute the weighted average in the self-attention
+ heads.
+ router_probs (`tuple(torch.FloatTensor)`, *optional*, returned when `output_router_probs=True` and `config.add_router_probs=True` is passed or when `config.output_router_probs=True`):
+ Tuple of `torch.FloatTensor` (one for each layer) of shape `(batch_size, sequence_length, num_experts)`.
+
+ Raw router probabilities that are computed by MoE routers; these terms are used to compute the auxiliary
+ loss and the z_loss for Mixture of Experts models.
+ """
+
+ last_hidden_state: torch.FloatTensor = None
+ hidden_states: Optional[Tuple[torch.FloatTensor, ...]] = None
+ attentions: Optional[Tuple[torch.FloatTensor, ...]] = None
+ router_probs: Optional[Tuple[torch.FloatTensor]] = None
+
+
+@dataclass
+class MoeModelOutputWithPast(ModelOutput):
+ """
+ Base class for model's outputs, with potential hidden states and attentions.
+
+ Args:
+ last_hidden_state (`torch.FloatTensor` of shape `(batch_size, sequence_length, hidden_size)`):
+ Sequence of hidden-states at the output of the last layer of the model.
+ past_key_values (`tuple(tuple(torch.FloatTensor))`, *optional*, returned when `use_cache=True` is passed or when `config.use_cache=True`):
+ Tuple of `tuple(torch.FloatTensor)` of length `config.n_layers`, with each tuple having 2 tensors of shape
+ `(batch_size, num_heads, sequence_length, embed_size_per_head)` and optionally if
+ `config.is_encoder_decoder=True` 2 additional tensors of shape `(batch_size, num_heads,
+ encoder_sequence_length, embed_size_per_head)`.
+
+ Contains pre-computed hidden-states (key and values in the self-attention blocks and optionally if
+ `config.is_encoder_decoder=True` in the cross-attention blocks) that can be used (see `past_key_values`
+ input) to speed up sequential decoding.
+ hidden_states (`tuple(torch.FloatTensor)`, *optional*, returned when `output_hidden_states=True` is passed or when `config.output_hidden_states=True`):
+ Tuple of `torch.FloatTensor` (one for the output of the embeddings, if the model has an embedding layer, +
+ one for the output of each layer) of shape `(batch_size, sequence_length, hidden_size)`.
+
+ Hidden-states of the model at the output of each layer plus the optional initial embedding outputs.
+ attentions (`tuple(torch.FloatTensor)`, *optional*, returned when `output_attentions=True` is passed or when `config.output_attentions=True`):
+ Tuple of `torch.FloatTensor` (one for each layer) of shape `(batch_size, num_heads, sequence_length,
+ sequence_length)`.
+
+ Attention weights after the attention softmax, used to compute the weighted average in the self-attention
+ heads.
+ router_logits (`tuple(torch.FloatTensor)`, *optional*, returned when `output_router_probs=True` and `config.add_router_probs=True` is passed or when `config.output_router_probs=True`):
+ Tuple of `torch.FloatTensor` (one for each layer) of shape `(batch_size, sequence_length, num_experts)`.
+
+ Raw router logits (post-softmax) that are computed by MoE routers; these terms are used to compute the auxiliary
+ loss for Mixture of Experts models.
+ """
+
+ last_hidden_state: torch.FloatTensor = None
+ past_key_values: Optional[Tuple[Tuple[torch.FloatTensor]]] = None
+ hidden_states: Optional[Tuple[torch.FloatTensor, ...]] = None
+ attentions: Optional[Tuple[torch.FloatTensor, ...]] = None
+ router_logits: Optional[Tuple[torch.FloatTensor]] = None
+
+
+@dataclass
+class MoeCausalLMOutputWithPast(ModelOutput):
+ """
+ Base class for causal language model (or autoregressive) with mixture of experts outputs.
+
+ Args:
+ loss (`torch.FloatTensor` of shape `(1,)`, *optional*, returned when `labels` is provided):
+ Language modeling loss (for next-token prediction).
+
+ logits (`torch.FloatTensor` of shape `(batch_size, sequence_length, config.vocab_size)`):
+ Prediction scores of the language modeling head (scores for each vocabulary token before SoftMax).
+
+ aux_loss (`torch.FloatTensor`, *optional*, returned when `labels` is provided):
+ aux_loss for the sparse modules.
+
+ router_logits (`tuple(torch.FloatTensor)`, *optional*, returned when `output_router_probs=True` and `config.add_router_probs=True` is passed or when `config.output_router_probs=True`):
+ Tuple of `torch.FloatTensor` (one for each layer) of shape `(batch_size, sequence_length, num_experts)`.
+
+ Raw router logits (post-softmax) that are computed by MoE routers; these terms are used to compute the auxiliary
+ loss for Mixture of Experts models.
+
+ past_key_values (`tuple(tuple(torch.FloatTensor))`, *optional*, returned when `use_cache=True` is passed or when `config.use_cache=True`):
+ Tuple of `tuple(torch.FloatTensor)` of length `config.n_layers`, with each tuple having 2 tensors of shape
+ `(batch_size, num_heads, sequence_length, embed_size_per_head)`.
+
+ Contains pre-computed hidden-states (key and values in the self-attention blocks) that can be used (see
+ `past_key_values` input) to speed up sequential decoding.
+ hidden_states (`tuple(torch.FloatTensor)`, *optional*, returned when `output_hidden_states=True` is passed or when `config.output_hidden_states=True`):
+ Tuple of `torch.FloatTensor` (one for the output of the embeddings, if the model has an embedding layer, +
+ one for the output of each layer) of shape `(batch_size, sequence_length, hidden_size)`.
+
+ Hidden-states of the model at the output of each layer plus the optional initial embedding outputs.
+ attentions (`tuple(torch.FloatTensor)`, *optional*, returned when `output_attentions=True` is passed or when `config.output_attentions=True`):
+ Tuple of `torch.FloatTensor` (one for each layer) of shape `(batch_size, num_heads, sequence_length,
+ sequence_length)`.
+
+ Attention weights after the attention softmax, used to compute the weighted average in the self-attention
+ heads.
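+
+ Example (a sketch of how the router outputs are typically consumed; the shapes follow the documentation above,
+ and the softmax step is a generic illustration rather than the exact auxiliary-loss computation of any model):
+
+ ```python
+ >>> import torch
+
+ >>> num_layers, batch, seq_len, num_experts = 2, 1, 6, 4
+ >>> router_logits = tuple(torch.randn(batch, seq_len, num_experts) for _ in range(num_layers))
+ >>> output = MoeCausalLMOutputWithPast(logits=torch.randn(batch, seq_len, 32), router_logits=router_logits)
+ >>> probs = torch.softmax(output.router_logits[0], dim=-1)  # per-token expert probabilities for layer 0
+ >>> probs.shape
+ torch.Size([1, 6, 4])
+ ```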
+ """
+
+ loss: Optional[torch.FloatTensor] = None
+ aux_loss: Optional[torch.FloatTensor] = None
+ logits: torch.FloatTensor = None
+ past_key_values: Optional[Tuple[Tuple[torch.FloatTensor]]] = None
+ hidden_states: Optional[Tuple[torch.FloatTensor, ...]] = None
+ attentions: Optional[Tuple[torch.FloatTensor, ...]] = None
+ router_logits: Optional[Tuple[torch.FloatTensor]] = None
+
+
+@dataclass
+class MoEModelOutputWithPastAndCrossAttentions(ModelOutput):
+ """
+ Base class for model's outputs that may also contain past key/values (to speed up sequential decoding), as well as
+ the Mixture of Experts' router hidden-state terms used to train a MoE model.
+
+ Args:
+ last_hidden_state (`torch.FloatTensor` of shape `(batch_size, sequence_length, hidden_size)`):
+ Sequence of hidden-states at the output of the last layer of the model.
+
+ If `past_key_values` is used only the last hidden-state of the sequences of shape `(batch_size, 1,
+ hidden_size)` is output.
+ past_key_values (`tuple(tuple(torch.FloatTensor))`, *optional*, returned when `use_cache=True` is passed or when `config.use_cache=True`):
+ Tuple of `tuple(torch.FloatTensor)` of length `config.n_layers`, with each tuple having 2 tensors of shape
+ `(batch_size, num_heads, sequence_length, embed_size_per_head)` and optionally if
+ `config.is_encoder_decoder=True` 2 additional tensors of shape `(batch_size, num_heads,
+ encoder_sequence_length, embed_size_per_head)`.
+
+ Contains pre-computed hidden-states (key and values in the self-attention blocks and optionally if
+ `config.is_encoder_decoder=True` in the cross-attention blocks) that can be used (see `past_key_values`
+ input) to speed up sequential decoding.
+ hidden_states (`tuple(torch.FloatTensor)`, *optional*, returned when `output_hidden_states=True` is passed or when `config.output_hidden_states=True`):
+ Tuple of `torch.FloatTensor` (one for the output of the embeddings, if the model has an embedding layer, +
+ one for the output of each layer) of shape `(batch_size, sequence_length, hidden_size)`.
+
+ Hidden-states of the model at the output of each layer plus the optional initial embedding outputs.
+ attentions (`tuple(torch.FloatTensor)`, *optional*, returned when `output_attentions=True` is passed or when `config.output_attentions=True`):
+ Tuple of `torch.FloatTensor` (one for each layer) of shape `(batch_size, num_heads, sequence_length,
+ sequence_length)`.
+
+ Attention weights after the attention softmax, used to compute the weighted average in the self-attention
+ heads.
+ cross_attentions (`tuple(torch.FloatTensor)`, *optional*, returned when `output_attentions=True` and `config.add_cross_attention=True` is passed or when `config.output_attentions=True`):
+ Tuple of `torch.FloatTensor` (one for each layer) of shape `(batch_size, num_heads, sequence_length,
+ sequence_length)`.
+
+ Attention weights of the decoder's cross-attention layer, after the attention softmax, used to compute the
+ weighted average in the cross-attention heads.
+ router_probs (`tuple(torch.FloatTensor)`, *optional*, returned when `output_router_probs=True` and `config.add_router_probs=True` is passed or when `config.output_router_probs=True`):
+ Tuple of `torch.FloatTensor` (one for each layer) of shape `(batch_size, sequence_length, num_experts)`.
+
+ Raw router probabilities that are computed by MoE routers; these terms are used to compute the auxiliary
+ loss and the z_loss for Mixture of Experts models.
+ """
+
+ last_hidden_state: torch.FloatTensor = None
+ past_key_values: Optional[Tuple[Tuple[torch.FloatTensor]]] = None
+ hidden_states: Optional[Tuple[torch.FloatTensor, ...]] = None
+ attentions: Optional[Tuple[torch.FloatTensor, ...]] = None
+ cross_attentions: Optional[Tuple[torch.FloatTensor, ...]] = None
+ router_probs: Optional[Tuple[torch.FloatTensor]] = None
+
+
+@dataclass
+class Seq2SeqModelOutput(ModelOutput):
+ """
+ Base class for model encoder's outputs that also contains pre-computed hidden states that can speed up sequential
+ decoding.
+
+ Args:
+ last_hidden_state (`torch.FloatTensor` of shape `(batch_size, sequence_length, hidden_size)`):
+ Sequence of hidden-states at the output of the last layer of the decoder of the model.
+
+ If `past_key_values` is used only the last hidden-state of the sequences of shape `(batch_size, 1,
+ hidden_size)` is output.
+ past_key_values (`tuple(tuple(torch.FloatTensor))`, *optional*, returned when `use_cache=True` is passed or when `config.use_cache=True`):
+ Tuple of `tuple(torch.FloatTensor)` of length `config.n_layers`, with each tuple having 2 tensors of shape
+ `(batch_size, num_heads, sequence_length, embed_size_per_head)` and 2 additional tensors of shape
+ `(batch_size, num_heads, encoder_sequence_length, embed_size_per_head)`.
+
+ Contains pre-computed hidden-states (key and values in the self-attention blocks and in the cross-attention
+ blocks) that can be used (see `past_key_values` input) to speed up sequential decoding.
+ decoder_hidden_states (`tuple(torch.FloatTensor)`, *optional*, returned when `output_hidden_states=True` is passed or when `config.output_hidden_states=True`):
+ Tuple of `torch.FloatTensor` (one for the output of the embeddings, if the model has an embedding layer, +
+ one for the output of each layer) of shape `(batch_size, sequence_length, hidden_size)`.
+
+ Hidden-states of the decoder at the output of each layer plus the optional initial embedding outputs.
+ decoder_attentions (`tuple(torch.FloatTensor)`, *optional*, returned when `output_attentions=True` is passed or when `config.output_attentions=True`):
+ Tuple of `torch.FloatTensor` (one for each layer) of shape `(batch_size, num_heads, sequence_length,
+ sequence_length)`.
+
+ Attention weights of the decoder, after the attention softmax, used to compute the weighted average in the
+ self-attention heads.
+ cross_attentions (`tuple(torch.FloatTensor)`, *optional*, returned when `output_attentions=True` is passed or when `config.output_attentions=True`):
+ Tuple of `torch.FloatTensor` (one for each layer) of shape `(batch_size, num_heads, sequence_length,
+ sequence_length)`.
+
+ Attention weights of the decoder's cross-attention layer, after the attention softmax, used to compute the
+ weighted average in the cross-attention heads.
+ encoder_last_hidden_state (`torch.FloatTensor` of shape `(batch_size, sequence_length, hidden_size)`, *optional*):
+ Sequence of hidden-states at the output of the last layer of the encoder of the model.
+ encoder_hidden_states (`tuple(torch.FloatTensor)`, *optional*, returned when `output_hidden_states=True` is passed or when `config.output_hidden_states=True`):
+ Tuple of `torch.FloatTensor` (one for the output of the embeddings, if the model has an embedding layer, +
+ one for the output of each layer) of shape `(batch_size, sequence_length, hidden_size)`.
+
+ Hidden-states of the encoder at the output of each layer plus the optional initial embedding outputs.
+ encoder_attentions (`tuple(torch.FloatTensor)`, *optional*, returned when `output_attentions=True` is passed or when `config.output_attentions=True`):
+ Tuple of `torch.FloatTensor` (one for each layer) of shape `(batch_size, num_heads, sequence_length,
+ sequence_length)`.
+
+ Attention weights of the encoder, after the attention softmax, used to compute the weighted average in the
+ self-attention heads.
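+
+ Example (a minimal construction sketch with made-up shapes, highlighting the encoder/decoder split of the fields):
+
+ ```python
+ >>> import torch
+
+ >>> encoder_states = torch.randn(1, 9, 16)  # over the source sequence
+ >>> decoder_states = torch.randn(1, 3, 16)  # over the target prefix
+ >>> output = Seq2SeqModelOutput(last_hidden_state=decoder_states, encoder_last_hidden_state=encoder_states)
+ >>> output.last_hidden_state.shape, output.encoder_last_hidden_state.shape
+ (torch.Size([1, 3, 16]), torch.Size([1, 9, 16]))
+ ```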
+ """
+
+ last_hidden_state: torch.FloatTensor = None
+ past_key_values: Optional[Tuple[Tuple[torch.FloatTensor]]] = None
+ decoder_hidden_states: Optional[Tuple[torch.FloatTensor, ...]] = None
+ decoder_attentions: Optional[Tuple[torch.FloatTensor, ...]] = None
+ cross_attentions: Optional[Tuple[torch.FloatTensor, ...]] = None
+ encoder_last_hidden_state: Optional[torch.FloatTensor] = None
+ encoder_hidden_states: Optional[Tuple[torch.FloatTensor, ...]] = None
+ encoder_attentions: Optional[Tuple[torch.FloatTensor, ...]] = None
+
+
+@dataclass
+class Seq2SeqMoEModelOutput(ModelOutput):
+ """
+ Base class for model encoder's outputs that also contains : pre-computed hidden states that can speed up sequential
+ decoding.
+
+ Args:
+ last_hidden_state (`torch.FloatTensor` of shape `(batch_size, sequence_length, hidden_size)`):
+ Sequence of hidden-states at the output of the last layer of the decoder of the model.
+
+ If `past_key_values` is used only the last hidden-state of the sequences of shape `(batch_size, 1,
+ hidden_size)` is output.
+ past_key_values (`tuple(tuple(torch.FloatTensor))`, *optional*, returned when `use_cache=True` is passed or when `config.use_cache=True`):
+ Tuple of `tuple(torch.FloatTensor)` of length `config.n_layers`, with each tuple having 2 tensors of shape
+ `(batch_size, num_heads, sequence_length, embed_size_per_head)` and 2 additional tensors of shape
+ `(batch_size, num_heads, encoder_sequence_length, embed_size_per_head)`.
+
+ Contains pre-computed hidden-states (key and values in the self-attention blocks and in the cross-attention
+ blocks) that can be used (see `past_key_values` input) to speed up sequential decoding.
+ decoder_hidden_states (`tuple(torch.FloatTensor)`, *optional*, returned when `output_hidden_states=True` is passed or when `config.output_hidden_states=True`):
+ Tuple of `torch.FloatTensor` (one for the output of the embeddings, if the model has an embedding layer, +
+ one for the output of each layer) of shape `(batch_size, sequence_length, hidden_size)`.
+
+ Hidden-states of the decoder at the output of each layer plus the optional initial embedding outputs.
+ decoder_attentions (`tuple(torch.FloatTensor)`, *optional*, returned when `output_attentions=True` is passed or when `config.output_attentions=True`):
+ Tuple of `torch.FloatTensor` (one for each layer) of shape `(batch_size, num_heads, sequence_length,
+ sequence_length)`.
+
+ Attention weights of the decoder, after the attention softmax, used to compute the weighted average in the
+ self-attention heads.
+ decoder_router_logits (`tuple(torch.FloatTensor)`, *optional*, returned when `output_router_logits=True` is passed or when `config.add_router_probs=True`):
+ Tuple of `torch.FloatTensor` (one for each layer) of shape `(batch_size, sequence_length, num_experts)`.
+
+ Router logits of the decoder model, useful to compute the auxiliary loss for Mixture of Experts models.
+ cross_attentions (`tuple(torch.FloatTensor)`, *optional*, returned when `output_attentions=True` is passed or when `config.output_attentions=True`):
+ Tuple of `torch.FloatTensor` (one for each layer) of shape `(batch_size, num_heads, sequence_length,
+ sequence_length)`.
+
+ Attention weights of the decoder's cross-attention layer, after the attention softmax, used to compute the
+ weighted average in the cross-attention heads.
+ encoder_last_hidden_state (`torch.FloatTensor` of shape `(batch_size, sequence_length, hidden_size)`, *optional*):
+ Sequence of hidden-states at the output of the last layer of the encoder of the model.
+ encoder_hidden_states (`tuple(torch.FloatTensor)`, *optional*, returned when `output_hidden_states=True` is passed or when `config.output_hidden_states=True`):
+ Tuple of `torch.FloatTensor` (one for the output of the embeddings, if the model has an embedding layer, +
+ one for the output of each layer) of shape `(batch_size, sequence_length, hidden_size)`.
+
+ Hidden-states of the encoder at the output of each layer plus the optional initial embedding outputs.
+ encoder_attentions (`tuple(torch.FloatTensor)`, *optional*, returned when `output_attentions=True` is passed or when `config.output_attentions=True`):
+ Tuple of `torch.FloatTensor` (one for each layer) of shape `(batch_size, num_heads, sequence_length,
+ sequence_length)`.
+
+ Attention weights of the encoder, after the attention softmax, used to compute the weighted average in the
+ self-attention heads.
+ encoder_router_logits (`tuple(torch.FloatTensor)`, *optional*, returned when `output_router_logits=True` is passed or when `config.add_router_probs=True`):
+ Tuple of `torch.FloatTensor` (one for each layer) of shape `(batch_size, sequence_length, num_experts)`.
+
+ Router logits of the encoder model, useful to compute the auxiliary loss and the z_loss for the sparse
+ modules.
+ """
+
+ last_hidden_state: torch.FloatTensor = None
+ past_key_values: Optional[Tuple[Tuple[torch.FloatTensor]]] = None
+ decoder_hidden_states: Optional[Tuple[torch.FloatTensor, ...]] = None
+ decoder_attentions: Optional[Tuple[torch.FloatTensor, ...]] = None
+ decoder_router_logits: Optional[Tuple[torch.FloatTensor]] = None
+ cross_attentions: Optional[Tuple[torch.FloatTensor, ...]] = None
+ encoder_last_hidden_state: Optional[torch.FloatTensor] = None
+ encoder_hidden_states: Optional[Tuple[torch.FloatTensor, ...]] = None
+ encoder_attentions: Optional[Tuple[torch.FloatTensor, ...]] = None
+ encoder_router_logits: Optional[Tuple[torch.FloatTensor]] = None
+
+
+@dataclass
+class CausalLMOutput(ModelOutput):
+ """
+ Base class for causal language model (or autoregressive) outputs.
+
+ Args:
+ loss (`torch.FloatTensor` of shape `(1,)`, *optional*, returned when `labels` is provided):
+ Language modeling loss (for next-token prediction).
+ logits (`torch.FloatTensor` of shape `(batch_size, sequence_length, config.vocab_size)`):
+ Prediction scores of the language modeling head (scores for each vocabulary token before SoftMax).
+ hidden_states (`tuple(torch.FloatTensor)`, *optional*, returned when `output_hidden_states=True` is passed or when `config.output_hidden_states=True`):
+ Tuple of `torch.FloatTensor` (one for the output of the embeddings, if the model has an embedding layer, +
+ one for the output of each layer) of shape `(batch_size, sequence_length, hidden_size)`.
+
+ Hidden-states of the model at the output of each layer plus the optional initial embedding outputs.
+ attentions (`tuple(torch.FloatTensor)`, *optional*, returned when `output_attentions=True` is passed or when `config.output_attentions=True`):
+ Tuple of `torch.FloatTensor` (one for each layer) of shape `(batch_size, num_heads, sequence_length,
+ sequence_length)`.
+
+ Attention weights after the attention softmax, used to compute the weighted average in the self-attention
+ heads.
+ """
+
+ loss: Optional[torch.FloatTensor] = None
+ logits: torch.FloatTensor = None
+ hidden_states: Optional[Tuple[torch.FloatTensor, ...]] = None
+ attentions: Optional[Tuple[torch.FloatTensor, ...]] = None
+
+
+@dataclass
+class CausalLMOutputWithPast(ModelOutput):
+ """
+ Base class for causal language model (or autoregressive) outputs.
+
+ Args:
+ loss (`torch.FloatTensor` of shape `(1,)`, *optional*, returned when `labels` is provided):
+ Language modeling loss (for next-token prediction).
+ logits (`torch.FloatTensor` of shape `(batch_size, sequence_length, config.vocab_size)`):
+ Prediction scores of the language modeling head (scores for each vocabulary token before SoftMax).
+ past_key_values (`tuple(tuple(torch.FloatTensor))`, *optional*, returned when `use_cache=True` is passed or when `config.use_cache=True`):
+ Tuple of `tuple(torch.FloatTensor)` of length `config.n_layers`, with each tuple having 2 tensors of shape
+ `(batch_size, num_heads, sequence_length, embed_size_per_head)`.
+
+ Contains pre-computed hidden-states (key and values in the self-attention blocks) that can be used (see
+ `past_key_values` input) to speed up sequential decoding.
+ hidden_states (`tuple(torch.FloatTensor)`, *optional*, returned when `output_hidden_states=True` is passed or when `config.output_hidden_states=True`):
+ Tuple of `torch.FloatTensor` (one for the output of the embeddings, if the model has an embedding layer, +
+ one for the output of each layer) of shape `(batch_size, sequence_length, hidden_size)`.
+
+ Hidden-states of the model at the output of each layer plus the optional initial embedding outputs.
+ attentions (`tuple(torch.FloatTensor)`, *optional*, returned when `output_attentions=True` is passed or when `config.output_attentions=True`):
+ Tuple of `torch.FloatTensor` (one for each layer) of shape `(batch_size, num_heads, sequence_length,
+ sequence_length)`.
+
+ Attention weights after the attention softmax, used to compute the weighted average in the self-attention
+ heads.
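+
+ Example (an illustrative sketch; the output is built by hand with random logits, whereas in practice a model's
+ `forward` returns it):
+
+ ```python
+ >>> import torch
+
+ >>> logits = torch.randn(1, 5, 100)  # (batch_size, sequence_length, vocab_size)
+ >>> output = CausalLMOutputWithPast(logits=logits)
+ >>> next_token = output.logits[:, -1, :].argmax(dim=-1)  # greedy choice for the next token
+ >>> next_token.shape
+ torch.Size([1])
+ ```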
+ """
+
+ loss: Optional[torch.FloatTensor] = None
+ logits: torch.FloatTensor = None
+ past_key_values: Optional[Tuple[Tuple[torch.FloatTensor]]] = None
+ hidden_states: Optional[Tuple[torch.FloatTensor, ...]] = None
+ attentions: Optional[Tuple[torch.FloatTensor, ...]] = None
+
+
+@dataclass
+class CausalLMOutputWithCrossAttentions(ModelOutput):
+ """
+ Base class for causal language model (or autoregressive) outputs.
+
+ Args:
+ loss (`torch.FloatTensor` of shape `(1,)`, *optional*, returned when `labels` is provided):
+ Language modeling loss (for next-token prediction).
+ logits (`torch.FloatTensor` of shape `(batch_size, sequence_length, config.vocab_size)`):
+ Prediction scores of the language modeling head (scores for each vocabulary token before SoftMax).
+ hidden_states (`tuple(torch.FloatTensor)`, *optional*, returned when `output_hidden_states=True` is passed or when `config.output_hidden_states=True`):
+ Tuple of `torch.FloatTensor` (one for the output of the embeddings, if the model has an embedding layer, +
+ one for the output of each layer) of shape `(batch_size, sequence_length, hidden_size)`.
+
+ Hidden-states of the model at the output of each layer plus the optional initial embedding outputs.
+ attentions (`tuple(torch.FloatTensor)`, *optional*, returned when `output_attentions=True` is passed or when `config.output_attentions=True`):
+ Tuple of `torch.FloatTensor` (one for each layer) of shape `(batch_size, num_heads, sequence_length,
+ sequence_length)`.
+
+ Attention weights after the attention softmax, used to compute the weighted average in the self-attention
+ heads.
+ cross_attentions (`tuple(torch.FloatTensor)`, *optional*, returned when `output_attentions=True` is passed or when `config.output_attentions=True`):
+ Tuple of `torch.FloatTensor` (one for each layer) of shape `(batch_size, num_heads, sequence_length,
+ sequence_length)`.
+
+ Cross-attention weights after the attention softmax, used to compute the weighted average in the
+ cross-attention heads.
+ past_key_values (`tuple(tuple(torch.FloatTensor))`, *optional*, returned when `use_cache=True` is passed or when `config.use_cache=True`):
+ Tuple of `torch.FloatTensor` tuples of length `config.n_layers`, with each tuple containing the cached key,
+ value states of the self-attention and the cross-attention layers if the model is used in an encoder-decoder
+ setting. Only relevant if `config.is_decoder = True`.
+
+ Contains pre-computed hidden-states (key and values in the attention blocks) that can be used (see
+ `past_key_values` input) to speed up sequential decoding.
+ """
+
+ loss: Optional[torch.FloatTensor] = None
+ logits: torch.FloatTensor = None
+ past_key_values: Optional[Tuple[Tuple[torch.FloatTensor]]] = None
+ hidden_states: Optional[Tuple[torch.FloatTensor, ...]] = None
+ attentions: Optional[Tuple[torch.FloatTensor, ...]] = None
+ cross_attentions: Optional[Tuple[torch.FloatTensor, ...]] = None
+
+
+@dataclass
+class SequenceClassifierOutputWithPast(ModelOutput):
+ """
+ Base class for outputs of sentence classification models.
+
+ Args:
+ loss (`torch.FloatTensor` of shape `(1,)`, *optional*, returned when `labels` is provided):
+ Classification (or regression if config.num_labels==1) loss.
+ logits (`torch.FloatTensor` of shape `(batch_size, config.num_labels)`):
+ Classification (or regression if config.num_labels==1) scores (before SoftMax).
+ past_key_values (`tuple(tuple(torch.FloatTensor))`, *optional*, returned when `use_cache=True` is passed or when `config.use_cache=True`):
+ Tuple of `tuple(torch.FloatTensor)` of length `config.n_layers`, with each tuple having 2 tensors of shape
+ `(batch_size, num_heads, sequence_length, embed_size_per_head)`.
+
+ Contains pre-computed hidden-states (key and values in the self-attention blocks) that can be used (see
+ `past_key_values` input) to speed up sequential decoding.
+ hidden_states (`tuple(torch.FloatTensor)`, *optional*, returned when `output_hidden_states=True` is passed or when `config.output_hidden_states=True`):
+ Tuple of `torch.FloatTensor` (one for the output of the embeddings, if the model has an embedding layer, +
+ one for the output of each layer) of shape `(batch_size, sequence_length, hidden_size)`.
+
+ Hidden-states of the model at the output of each layer plus the optional initial embedding outputs.
+ attentions (`tuple(torch.FloatTensor)`, *optional*, returned when `output_attentions=True` is passed or when `config.output_attentions=True`):
+ Tuple of `torch.FloatTensor` (one for each layer) of shape `(batch_size, num_heads, sequence_length,
+ sequence_length)`.
+
+ Attention weights after the attention softmax, used to compute the weighted average in the self-attention
+ heads.
+ """
+
+ loss: Optional[torch.FloatTensor] = None
+ logits: torch.FloatTensor = None
+ past_key_values: Optional[Tuple[Tuple[torch.FloatTensor]]] = None
+ hidden_states: Optional[Tuple[torch.FloatTensor, ...]] = None
+ attentions: Optional[Tuple[torch.FloatTensor, ...]] = None
+
+
+@dataclass
+class MaskedLMOutput(ModelOutput):
+ """
+ Base class for masked language models outputs.
+
+ Args:
+ loss (`torch.FloatTensor` of shape `(1,)`, *optional*, returned when `labels` is provided):
+ Masked language modeling (MLM) loss.
+ logits (`torch.FloatTensor` of shape `(batch_size, sequence_length, config.vocab_size)`):
+ Prediction scores of the language modeling head (scores for each vocabulary token before SoftMax).
+ hidden_states (`tuple(torch.FloatTensor)`, *optional*, returned when `output_hidden_states=True` is passed or when `config.output_hidden_states=True`):
+ Tuple of `torch.FloatTensor` (one for the output of the embeddings, if the model has an embedding layer, +
+ one for the output of each layer) of shape `(batch_size, sequence_length, hidden_size)`.
+
+ Hidden-states of the model at the output of each layer plus the optional initial embedding outputs.
+ attentions (`tuple(torch.FloatTensor)`, *optional*, returned when `output_attentions=True` is passed or when `config.output_attentions=True`):
+ Tuple of `torch.FloatTensor` (one for each layer) of shape `(batch_size, num_heads, sequence_length,
+ sequence_length)`.
+
+ Attention weights after the attention softmax, used to compute the weighted average in the self-attention
+ heads.
+ """
+
+ loss: Optional[torch.FloatTensor] = None
+ logits: torch.FloatTensor = None
+ hidden_states: Optional[Tuple[torch.FloatTensor, ...]] = None
+ attentions: Optional[Tuple[torch.FloatTensor, ...]] = None
+
+
+@dataclass
+class Seq2SeqLMOutput(ModelOutput):
+ """
+ Base class for sequence-to-sequence language models outputs.
+
+ Args:
+ loss (`torch.FloatTensor` of shape `(1,)`, *optional*, returned when `labels` is provided):
+ Language modeling loss.
+ logits (`torch.FloatTensor` of shape `(batch_size, sequence_length, config.vocab_size)`):
+ Prediction scores of the language modeling head (scores for each vocabulary token before SoftMax).
+ past_key_values (`tuple(tuple(torch.FloatTensor))`, *optional*, returned when `use_cache=True` is passed or when `config.use_cache=True`):
+ Tuple of `tuple(torch.FloatTensor)` of length `config.n_layers`, with each tuple having 2 tensors of shape
+ `(batch_size, num_heads, sequence_length, embed_size_per_head)` and 2 additional tensors of shape
+ `(batch_size, num_heads, encoder_sequence_length, embed_size_per_head)`.
+
+ Contains pre-computed hidden-states (key and values in the self-attention blocks and in the cross-attention
+ blocks) that can be used (see `past_key_values` input) to speed up sequential decoding.
+ decoder_hidden_states (`tuple(torch.FloatTensor)`, *optional*, returned when `output_hidden_states=True` is passed or when `config.output_hidden_states=True`):
+ Tuple of `torch.FloatTensor` (one for the output of the embeddings, if the model has an embedding layer, +
+ one for the output of each layer) of shape `(batch_size, sequence_length, hidden_size)`.
+
+ Hidden-states of the decoder at the output of each layer plus the initial embedding outputs.
+ decoder_attentions (`tuple(torch.FloatTensor)`, *optional*, returned when `output_attentions=True` is passed or when `config.output_attentions=True`):
+ Tuple of `torch.FloatTensor` (one for each layer) of shape `(batch_size, num_heads, sequence_length,
+ sequence_length)`.
+
+ Attention weights of the decoder, after the attention softmax, used to compute the weighted average in the
+ self-attention heads.
+ cross_attentions (`tuple(torch.FloatTensor)`, *optional*, returned when `output_attentions=True` is passed or when `config.output_attentions=True`):
+ Tuple of `torch.FloatTensor` (one for each layer) of shape `(batch_size, num_heads, sequence_length,
+ sequence_length)`.
+
+ Attention weights of the decoder's cross-attention layer, after the attention softmax, used to compute the
+ weighted average in the cross-attention heads.
+ encoder_last_hidden_state (`torch.FloatTensor` of shape `(batch_size, sequence_length, hidden_size)`, *optional*):
+ Sequence of hidden-states at the output of the last layer of the encoder of the model.
+ encoder_hidden_states (`tuple(torch.FloatTensor)`, *optional*, returned when `output_hidden_states=True` is passed or when `config.output_hidden_states=True`):
+ Tuple of `torch.FloatTensor` (one for the output of the embeddings, if the model has an embedding layer, +
+ one for the output of each layer) of shape `(batch_size, sequence_length, hidden_size)`.
+
+ Hidden-states of the encoder at the output of each layer plus the initial embedding outputs.
+ encoder_attentions (`tuple(torch.FloatTensor)`, *optional*, returned when `output_attentions=True` is passed or when `config.output_attentions=True`):
+ Tuple of `torch.FloatTensor` (one for each layer) of shape `(batch_size, num_heads, sequence_length,
+ sequence_length)`.
+
+ Attention weights of the encoder, after the attention softmax, used to compute the weighted average in the
+ self-attention heads.
+ """
+
+ loss: Optional[torch.FloatTensor] = None
+ logits: torch.FloatTensor = None
+ past_key_values: Optional[Tuple[Tuple[torch.FloatTensor]]] = None
+ decoder_hidden_states: Optional[Tuple[torch.FloatTensor, ...]] = None
+ decoder_attentions: Optional[Tuple[torch.FloatTensor, ...]] = None
+ cross_attentions: Optional[Tuple[torch.FloatTensor, ...]] = None
+ encoder_last_hidden_state: Optional[torch.FloatTensor] = None
+ encoder_hidden_states: Optional[Tuple[torch.FloatTensor, ...]] = None
+ encoder_attentions: Optional[Tuple[torch.FloatTensor, ...]] = None
+
+
+@dataclass
+class Seq2SeqMoEOutput(ModelOutput):
+ """
+ Base class for sequence-to-sequence language models outputs.
+
+ Args:
+ loss (`torch.FloatTensor` of shape `(1,)`, *optional*, returned when `labels` is provided):
+ Language modeling loss.
+ logits (`torch.FloatTensor` of shape `(batch_size, sequence_length, config.vocab_size)`):
+ Prediction scores of the language modeling head (scores for each vocabulary token before SoftMax).
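+ encoder_z_loss (`torch.FloatTensor`, *optional*):
+ z_loss for the sparse modules of the encoder.
+ decoder_z_loss (`torch.FloatTensor`, *optional*):
+ z_loss for the sparse modules of the decoder.
+ encoder_aux_loss (`torch.FloatTensor`, *optional*):
+ aux_loss for the sparse modules of the encoder.
+ decoder_aux_loss (`torch.FloatTensor`, *optional*):
+ aux_loss for the sparse modules of the decoder.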
+ past_key_values (`tuple(tuple(torch.FloatTensor))`, *optional*, returned when `use_cache=True` is passed or when `config.use_cache=True`):
+ Tuple of `tuple(torch.FloatTensor)` of length `config.n_layers`, with each tuple having 2 tensors of shape
+ `(batch_size, num_heads, sequence_length, embed_size_per_head)` and 2 additional tensors of shape
+ `(batch_size, num_heads, encoder_sequence_length, embed_size_per_head)`.
+
+ Contains pre-computed hidden-states (key and values in the self-attention blocks and in the cross-attention
+ blocks) that can be used (see `past_key_values` input) to speed up sequential decoding.
+ decoder_hidden_states (`tuple(torch.FloatTensor)`, *optional*, returned when `output_hidden_states=True` is passed or when `config.output_hidden_states=True`):
+ Tuple of `torch.FloatTensor` (one for the output of the embeddings, if the model has an embedding layer, +
+ one for the output of each layer) of shape `(batch_size, sequence_length, hidden_size)`.
+
+ Hidden-states of the decoder at the output of each layer plus the initial embedding outputs.
+ decoder_attentions (`tuple(torch.FloatTensor)`, *optional*, returned when `output_attentions=True` is passed or when `config.output_attentions=True`):
+ Tuple of `torch.FloatTensor` (one for each layer) of shape `(batch_size, num_heads, sequence_length,
+ sequence_length)`.
+
+ Attention weights of the decoder, after the attention softmax, used to compute the weighted average in the
+ self-attention heads.
+ decoder_router_logits (`tuple(torch.FloatTensor)`, *optional*, returned when `output_router_logits=True` is passed or when `config.add_router_probs=True`):
+ Tuple of `torch.FloatTensor` (one for each layer) of shape `(batch_size, sequence_length, num_experts)`.
+
+ Router logits of the decoder model, useful to compute the auxiliary loss for Mixture of Experts models.
+ cross_attentions (`tuple(torch.FloatTensor)`, *optional*, returned when `output_attentions=True` is passed or when `config.output_attentions=True`):
+ Tuple of `torch.FloatTensor` (one for each layer) of shape `(batch_size, num_heads, sequence_length,
+ sequence_length)`.
+
+ Attention weights of the decoder's cross-attention layer, after the attention softmax, used to compute the
+ weighted average in the cross-attention heads.
+ encoder_last_hidden_state (`torch.FloatTensor` of shape `(batch_size, sequence_length, hidden_size)`, *optional*):
+ Sequence of hidden-states at the output of the last layer of the encoder of the model.
+ encoder_hidden_states (`tuple(torch.FloatTensor)`, *optional*, returned when `output_hidden_states=True` is passed or when `config.output_hidden_states=True`):
+ Tuple of `torch.FloatTensor` (one for the output of the embeddings, if the model has an embedding layer, +
+ one for the output of each layer) of shape `(batch_size, sequence_length, hidden_size)`.
+
+ Hidden-states of the encoder at the output of each layer plus the initial embedding outputs.
+ encoder_attentions (`tuple(torch.FloatTensor)`, *optional*, returned when `output_attentions=True` is passed or when `config.output_attentions=True`):
+ Tuple of `torch.FloatTensor` (one for each layer) of shape `(batch_size, num_heads, sequence_length,
+ sequence_length)`.
+
+ Attention weights of the encoder, after the attention softmax, used to compute the weighted average in the
+ self-attention heads.
+ encoder_router_logits (`tuple(torch.FloatTensor)`, *optional*, returned when `output_router_logits=True` is passed or when `config.add_router_probs=True`):
+ Tuple of `torch.FloatTensor` (one for each layer) of shape `(batch_size, sequence_length, num_experts)`.
+
+ Router logits of the encoder model, useful to compute the auxiliary loss and z_loss for Mixture of Experts
+ models.
+ """
+
+ loss: Optional[torch.FloatTensor] = None
+ logits: torch.FloatTensor = None
+ encoder_z_loss: torch.FloatTensor = None
+ decoder_z_loss: torch.FloatTensor = None
+ encoder_aux_loss: torch.FloatTensor = None
+ decoder_aux_loss: torch.FloatTensor = None
+ past_key_values: Optional[Tuple[Tuple[torch.FloatTensor]]] = None
+ decoder_hidden_states: Optional[Tuple[torch.FloatTensor, ...]] = None
+ decoder_attentions: Optional[Tuple[torch.FloatTensor, ...]] = None
+ decoder_router_logits: Optional[Tuple[torch.FloatTensor]] = None
+ cross_attentions: Optional[Tuple[torch.FloatTensor, ...]] = None
+ encoder_last_hidden_state: Optional[torch.FloatTensor] = None
+ encoder_hidden_states: Optional[Tuple[torch.FloatTensor, ...]] = None
+ encoder_attentions: Optional[Tuple[torch.FloatTensor, ...]] = None
+ encoder_router_logits: Optional[Tuple[torch.FloatTensor]] = None
+
+
+@dataclass
+class NextSentencePredictorOutput(ModelOutput):
+ """
+ Base class for outputs of models predicting if two sentences are consecutive or not.
+
+ Args:
+ loss (`torch.FloatTensor` of shape `(1,)`, *optional*, returned when `next_sentence_label` is provided):
+ Next sequence prediction (classification) loss.
+ logits (`torch.FloatTensor` of shape `(batch_size, 2)`):
+ Prediction scores of the next sequence prediction (classification) head (scores of True/False continuation
+ before SoftMax).
+ hidden_states (`tuple(torch.FloatTensor)`, *optional*, returned when `output_hidden_states=True` is passed or when `config.output_hidden_states=True`):
+ Tuple of `torch.FloatTensor` (one for the output of the embeddings, if the model has an embedding layer, +
+ one for the output of each layer) of shape `(batch_size, sequence_length, hidden_size)`.
+
+ Hidden-states of the model at the output of each layer plus the optional initial embedding outputs.
+ attentions (`tuple(torch.FloatTensor)`, *optional*, returned when `output_attentions=True` is passed or when `config.output_attentions=True`):
+ Tuple of `torch.FloatTensor` (one for each layer) of shape `(batch_size, num_heads, sequence_length,
+ sequence_length)`.
+
+ Attention weights after the attention softmax, used to compute the weighted average in the self-attention
+ heads.
+ """
+
+ loss: Optional[torch.FloatTensor] = None
+ logits: torch.FloatTensor = None
+ hidden_states: Optional[Tuple[torch.FloatTensor, ...]] = None
+ attentions: Optional[Tuple[torch.FloatTensor, ...]] = None
+
+
+@dataclass
+class SequenceClassifierOutput(ModelOutput):
+ """
+ Base class for outputs of sentence classification models.
+
+ Args:
+ loss (`torch.FloatTensor` of shape `(1,)`, *optional*, returned when `labels` is provided):
+ Classification (or regression if config.num_labels==1) loss.
+ logits (`torch.FloatTensor` of shape `(batch_size, config.num_labels)`):
+ Classification (or regression if config.num_labels==1) scores (before SoftMax).
+ hidden_states (`tuple(torch.FloatTensor)`, *optional*, returned when `output_hidden_states=True` is passed or when `config.output_hidden_states=True`):
+ Tuple of `torch.FloatTensor` (one for the output of the embeddings, if the model has an embedding layer, +
+ one for the output of each layer) of shape `(batch_size, sequence_length, hidden_size)`.
+
+ Hidden-states of the model at the output of each layer plus the optional initial embedding outputs.
+ attentions (`tuple(torch.FloatTensor)`, *optional*, returned when `output_attentions=True` is passed or when `config.output_attentions=True`):
+ Tuple of `torch.FloatTensor` (one for each layer) of shape `(batch_size, num_heads, sequence_length,
+ sequence_length)`.
+
+ Attention weights after the attention softmax, used to compute the weighted average in the self-attention
+ heads.
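+
+ Example (a minimal sketch with made-up numbers, showing how the fields are typically read):
+
+ ```python
+ >>> import torch
+
+ >>> logits = torch.tensor([[0.1, 1.2, -0.3]])  # (batch_size, config.num_labels)
+ >>> output = SequenceClassifierOutput(logits=logits)
+ >>> output.logits.argmax(dim=-1)  # predicted class index per example
+ tensor([1])
+ ```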
+ """
+
+ loss: Optional[torch.FloatTensor] = None
+ logits: torch.FloatTensor = None
+ hidden_states: Optional[Tuple[torch.FloatTensor, ...]] = None
+ attentions: Optional[Tuple[torch.FloatTensor, ...]] = None
+
+
+@dataclass
+class Seq2SeqSequenceClassifierOutput(ModelOutput):
+ """
+ Base class for outputs of sequence-to-sequence sentence classification models.
+
+ Args:
+ loss (`torch.FloatTensor` of shape `(1,)`, *optional*, returned when `label` is provided):
+ Classification (or regression if config.num_labels==1) loss.
+ logits (`torch.FloatTensor` of shape `(batch_size, config.num_labels)`):
+ Classification (or regression if config.num_labels==1) scores (before SoftMax).
+ past_key_values (`tuple(tuple(torch.FloatTensor))`, *optional*, returned when `use_cache=True` is passed or when `config.use_cache=True`):
+ Tuple of `tuple(torch.FloatTensor)` of length `config.n_layers`, with each tuple having 2 tensors of shape
+ `(batch_size, num_heads, sequence_length, embed_size_per_head)` and 2 additional tensors of shape
+ `(batch_size, num_heads, encoder_sequence_length, embed_size_per_head)`.
+
+ Contains pre-computed hidden-states (key and values in the self-attention blocks and in the cross-attention
+ blocks) that can be used (see `past_key_values` input) to speed up sequential decoding.
+ decoder_hidden_states (`tuple(torch.FloatTensor)`, *optional*, returned when `output_hidden_states=True` is passed or when `config.output_hidden_states=True`):
+ Tuple of `torch.FloatTensor` (one for the output of the embeddings, if the model has an embedding layer, +
+ one for the output of each layer) of shape `(batch_size, sequence_length, hidden_size)`.
+
+ Hidden-states of the decoder at the output of each layer plus the initial embedding outputs.
+ decoder_attentions (`tuple(torch.FloatTensor)`, *optional*, returned when `output_attentions=True` is passed or when `config.output_attentions=True`):
+ Tuple of `torch.FloatTensor` (one for each layer) of shape `(batch_size, num_heads, sequence_length,
+ sequence_length)`.
+
+ Attentions weights of the decoder, after the attention softmax, used to compute the weighted average in the
+ self-attention heads.
+ cross_attentions (`tuple(torch.FloatTensor)`, *optional*, returned when `output_attentions=True` is passed or when `config.output_attentions=True`):
+ Tuple of `torch.FloatTensor` (one for each layer) of shape `(batch_size, num_heads, sequence_length,
+ sequence_length)`.
+
+ Attentions weights of the decoder's cross-attention layer, after the attention softmax, used to compute the
+ weighted average in the cross-attention heads.
+ encoder_last_hidden_state (`torch.FloatTensor` of shape `(batch_size, sequence_length, hidden_size)`, *optional*):
+ Sequence of hidden-states at the output of the last layer of the encoder of the model.
+ encoder_hidden_states (`tuple(torch.FloatTensor)`, *optional*, returned when `output_hidden_states=True` is passed or when `config.output_hidden_states=True`):
+ Tuple of `torch.FloatTensor` (one for the output of the embeddings, if the model has an embedding layer, +
+ one for the output of each layer) of shape `(batch_size, sequence_length, hidden_size)`.
+
+ Hidden-states of the encoder at the output of each layer plus the initial embedding outputs.
+ encoder_attentions (`tuple(torch.FloatTensor)`, *optional*, returned when `output_attentions=True` is passed or when `config.output_attentions=True`):
+ Tuple of `torch.FloatTensor` (one for each layer) of shape `(batch_size, num_heads, sequence_length,
+ sequence_length)`.
+
+ Attentions weights of the encoder, after the attention softmax, used to compute the weighted average in the
+ self-attention heads.
+ """
+
+ loss: Optional[torch.FloatTensor] = None
+ logits: torch.FloatTensor = None
+ past_key_values: Optional[Tuple[Tuple[torch.FloatTensor]]] = None
+ decoder_hidden_states: Optional[Tuple[torch.FloatTensor, ...]] = None
+ decoder_attentions: Optional[Tuple[torch.FloatTensor, ...]] = None
+ cross_attentions: Optional[Tuple[torch.FloatTensor, ...]] = None
+ encoder_last_hidden_state: Optional[torch.FloatTensor] = None
+ encoder_hidden_states: Optional[Tuple[torch.FloatTensor, ...]] = None
+ encoder_attentions: Optional[Tuple[torch.FloatTensor, ...]] = None
+
+
+@dataclass
+class MultipleChoiceModelOutput(ModelOutput):
+ """
+ Base class for outputs of multiple choice models.
+
+ Args:
+ loss (`torch.FloatTensor` of shape `(1,)`, *optional*, returned when `labels` is provided):
+ Classification loss.
+ logits (`torch.FloatTensor` of shape `(batch_size, num_choices)`):
+ *num_choices* is the second dimension of the input tensors (see *input_ids* above).
+
+ Classification scores (before SoftMax).
+ hidden_states (`tuple(torch.FloatTensor)`, *optional*, returned when `output_hidden_states=True` is passed or when `config.output_hidden_states=True`):
+ Tuple of `torch.FloatTensor` (one for the output of the embeddings, if the model has an embedding layer, +
+ one for the output of each layer) of shape `(batch_size, sequence_length, hidden_size)`.
+
+ Hidden-states of the model at the output of each layer plus the optional initial embedding outputs.
+ attentions (`tuple(torch.FloatTensor)`, *optional*, returned when `output_attentions=True` is passed or when `config.output_attentions=True`):
+ Tuple of `torch.FloatTensor` (one for each layer) of shape `(batch_size, num_heads, sequence_length,
+ sequence_length)`.
+
+ Attentions weights after the attention softmax, used to compute the weighted average in the self-attention
+ heads.
+ """
+
+ loss: Optional[torch.FloatTensor] = None
+ logits: torch.FloatTensor = None
+ hidden_states: Optional[Tuple[torch.FloatTensor, ...]] = None
+ attentions: Optional[Tuple[torch.FloatTensor, ...]] = None
+
+
+@dataclass
+class TokenClassifierOutput(ModelOutput):
+ """
+ Base class for outputs of token classification models.
+
+ Args:
+ loss (`torch.FloatTensor` of shape `(1,)`, *optional*, returned when `labels` is provided):
+ Classification loss.
+ logits (`torch.FloatTensor` of shape `(batch_size, sequence_length, config.num_labels)`):
+ Classification scores (before SoftMax).
+ hidden_states (`tuple(torch.FloatTensor)`, *optional*, returned when `output_hidden_states=True` is passed or when `config.output_hidden_states=True`):
+ Tuple of `torch.FloatTensor` (one for the output of the embeddings, if the model has an embedding layer, +
+ one for the output of each layer) of shape `(batch_size, sequence_length, hidden_size)`.
+
+ Hidden-states of the model at the output of each layer plus the optional initial embedding outputs.
+ attentions (`tuple(torch.FloatTensor)`, *optional*, returned when `output_attentions=True` is passed or when `config.output_attentions=True`):
+ Tuple of `torch.FloatTensor` (one for each layer) of shape `(batch_size, num_heads, sequence_length,
+ sequence_length)`.
+
+ Attentions weights after the attention softmax, used to compute the weighted average in the self-attention
+ heads.
+ """
+
+ loss: Optional[torch.FloatTensor] = None
+ logits: torch.FloatTensor = None
+ hidden_states: Optional[Tuple[torch.FloatTensor, ...]] = None
+ attentions: Optional[Tuple[torch.FloatTensor, ...]] = None
+
+
+@dataclass
+class QuestionAnsweringModelOutput(ModelOutput):
+ """
+ Base class for outputs of question answering models.
+
+ Args:
+ loss (`torch.FloatTensor` of shape `(1,)`, *optional*, returned when `labels` is provided):
+ Total span extraction loss, computed as the sum of the cross-entropy losses for the start and end positions.
+ start_logits (`torch.FloatTensor` of shape `(batch_size, sequence_length)`):
+ Span-start scores (before SoftMax).
+ end_logits (`torch.FloatTensor` of shape `(batch_size, sequence_length)`):
+ Span-end scores (before SoftMax).
+ hidden_states (`tuple(torch.FloatTensor)`, *optional*, returned when `output_hidden_states=True` is passed or when `config.output_hidden_states=True`):
+ Tuple of `torch.FloatTensor` (one for the output of the embeddings, if the model has an embedding layer, +
+ one for the output of each layer) of shape `(batch_size, sequence_length, hidden_size)`.
+
+ Hidden-states of the model at the output of each layer plus the optional initial embedding outputs.
+ attentions (`tuple(torch.FloatTensor)`, *optional*, returned when `output_attentions=True` is passed or when `config.output_attentions=True`):
+ Tuple of `torch.FloatTensor` (one for each layer) of shape `(batch_size, num_heads, sequence_length,
+ sequence_length)`.
+
+ Attentions weights after the attention softmax, used to compute the weighted average in the self-attention
+ heads.
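+
+ Example (illustrative sketch, not part of the upstream docstring; the logits values are made up):
+
+ >>> start_logits = torch.tensor([[0.1, 2.0, 0.3, 0.1]])
+ >>> end_logits = torch.tensor([[0.1, 0.2, 0.3, 2.0]])
+ >>> output = QuestionAnsweringModelOutput(start_logits=start_logits, end_logits=end_logits)
+ >>> (output.start_logits.argmax(-1).item(), output.end_logits.argmax(-1).item())  # most likely span
+ (1, 3)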
+ """
+
+ loss: Optional[torch.FloatTensor] = None
+ start_logits: torch.FloatTensor = None
+ end_logits: torch.FloatTensor = None
+ hidden_states: Optional[Tuple[torch.FloatTensor, ...]] = None
+ attentions: Optional[Tuple[torch.FloatTensor, ...]] = None
+
+
+@dataclass
+class Seq2SeqQuestionAnsweringModelOutput(ModelOutput):
+ """
+ Base class for outputs of sequence-to-sequence question answering models.
+
+ Args:
+ loss (`torch.FloatTensor` of shape `(1,)`, *optional*, returned when `labels` is provided):
+ Total span extraction loss, computed as the sum of the cross-entropy losses for the start and end positions.
+ start_logits (`torch.FloatTensor` of shape `(batch_size, sequence_length)`):
+ Span-start scores (before SoftMax).
+ end_logits (`torch.FloatTensor` of shape `(batch_size, sequence_length)`):
+ Span-end scores (before SoftMax).
+ past_key_values (`tuple(tuple(torch.FloatTensor))`, *optional*, returned when `use_cache=True` is passed or when `config.use_cache=True`):
+ Tuple of `tuple(torch.FloatTensor)` of length `config.n_layers`, with each tuple having 2 tensors of shape
+ `(batch_size, num_heads, sequence_length, embed_size_per_head)` and 2 additional tensors of shape
+ `(batch_size, num_heads, encoder_sequence_length, embed_size_per_head)`.
+
+ Contains pre-computed hidden-states (key and values in the self-attention blocks and in the cross-attention
+ blocks) that can be used (see `past_key_values` input) to speed up sequential decoding.
+ decoder_hidden_states (`tuple(torch.FloatTensor)`, *optional*, returned when `output_hidden_states=True` is passed or when `config.output_hidden_states=True`):
+ Tuple of `torch.FloatTensor` (one for the output of the embeddings, if the model has an embedding layer, +
+ one for the output of each layer) of shape `(batch_size, sequence_length, hidden_size)`.
+
+ Hidden-states of the decoder at the output of each layer plus the initial embedding outputs.
+ decoder_attentions (`tuple(torch.FloatTensor)`, *optional*, returned when `output_attentions=True` is passed or when `config.output_attentions=True`):
+ Tuple of `torch.FloatTensor` (one for each layer) of shape `(batch_size, num_heads, sequence_length,
+ sequence_length)`.
+
+ Attentions weights of the decoder, after the attention softmax, used to compute the weighted average in the
+ self-attention heads.
+ cross_attentions (`tuple(torch.FloatTensor)`, *optional*, returned when `output_attentions=True` is passed or when `config.output_attentions=True`):
+ Tuple of `torch.FloatTensor` (one for each layer) of shape `(batch_size, num_heads, sequence_length,
+ sequence_length)`.
+
+ Attentions weights of the decoder's cross-attention layer, after the attention softmax, used to compute the
+ weighted average in the cross-attention heads.
+ encoder_last_hidden_state (`torch.FloatTensor` of shape `(batch_size, sequence_length, hidden_size)`, *optional*):
+ Sequence of hidden-states at the output of the last layer of the encoder of the model.
+ encoder_hidden_states (`tuple(torch.FloatTensor)`, *optional*, returned when `output_hidden_states=True` is passed or when `config.output_hidden_states=True`):
+ Tuple of `torch.FloatTensor` (one for the output of the embeddings, if the model has an embedding layer, +
+ one for the output of each layer) of shape `(batch_size, sequence_length, hidden_size)`.
+
+ Hidden-states of the encoder at the output of each layer plus the initial embedding outputs.
+ encoder_attentions (`tuple(torch.FloatTensor)`, *optional*, returned when `output_attentions=True` is passed or when `config.output_attentions=True`):
+ Tuple of `torch.FloatTensor` (one for each layer) of shape `(batch_size, num_heads, sequence_length,
+ sequence_length)`.
+
+ Attentions weights of the encoder, after the attention softmax, used to compute the weighted average in the
+ self-attention heads.
+ """
+
+ loss: Optional[torch.FloatTensor] = None
+ start_logits: torch.FloatTensor = None
+ end_logits: torch.FloatTensor = None
+ past_key_values: Optional[Tuple[Tuple[torch.FloatTensor]]] = None
+ decoder_hidden_states: Optional[Tuple[torch.FloatTensor, ...]] = None
+ decoder_attentions: Optional[Tuple[torch.FloatTensor, ...]] = None
+ cross_attentions: Optional[Tuple[torch.FloatTensor, ...]] = None
+ encoder_last_hidden_state: Optional[torch.FloatTensor] = None
+ encoder_hidden_states: Optional[Tuple[torch.FloatTensor, ...]] = None
+ encoder_attentions: Optional[Tuple[torch.FloatTensor, ...]] = None
+
+
+@dataclass
+class SemanticSegmenterOutput(ModelOutput):
+ """
+ Base class for outputs of semantic segmentation models.
+
+ Args:
+ loss (`torch.FloatTensor` of shape `(1,)`, *optional*, returned when `labels` is provided):
+ Classification (or regression if config.num_labels==1) loss.
+ logits (`torch.FloatTensor` of shape `(batch_size, config.num_labels, logits_height, logits_width)`):
+ Classification scores for each pixel.
+
+ The logits returned do not necessarily have the same size as the `pixel_values` passed as inputs. This is
+ to avoid doing two interpolations and losing some quality when a user needs to resize the logits to the
+ original image size as post-processing. You should always check your logits shape and resize as needed
+ (see the example below).
+
+ hidden_states (`tuple(torch.FloatTensor)`, *optional*, returned when `output_hidden_states=True` is passed or when `config.output_hidden_states=True`):
+ Tuple of `torch.FloatTensor` (one for the output of the embeddings, if the model has an embedding layer, +
+ one for the output of each layer) of shape `(batch_size, patch_size, hidden_size)`.
+
+ Hidden-states of the model at the output of each layer plus the optional initial embedding outputs.
+ attentions (`tuple(torch.FloatTensor)`, *optional*, returned when `output_attentions=True` is passed or when `config.output_attentions=True`):
+ Tuple of `torch.FloatTensor` (one for each layer) of shape `(batch_size, num_heads, patch_size,
+ sequence_length)`.
+
+ Attentions weights after the attention softmax, used to compute the weighted average in the self-attention
+ heads.
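+
+ Example (illustrative sketch of the resizing mentioned above, not part of the upstream docstring; the
+ shapes are made up):
+
+ >>> logits = torch.randn(1, 21, 128, 128)  # (batch_size, num_labels, logits_height, logits_width)
+ >>> upsampled = torch.nn.functional.interpolate(logits, size=(512, 512), mode="bilinear", align_corners=False)
+ >>> upsampled.shape
+ torch.Size([1, 21, 512, 512])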
+ """
+
+ loss: Optional[torch.FloatTensor] = None
+ logits: torch.FloatTensor = None
+ hidden_states: Optional[Tuple[torch.FloatTensor, ...]] = None
+ attentions: Optional[Tuple[torch.FloatTensor, ...]] = None
+
+
+@dataclass
+class ImageClassifierOutput(ModelOutput):
+ """
+ Base class for outputs of image classification models.
+
+ Args:
+ loss (`torch.FloatTensor` of shape `(1,)`, *optional*, returned when `labels` is provided):
+ Classification (or regression if config.num_labels==1) loss.
+ logits (`torch.FloatTensor` of shape `(batch_size, config.num_labels)`):
+ Classification (or regression if config.num_labels==1) scores (before SoftMax).
+ hidden_states (`tuple(torch.FloatTensor)`, *optional*, returned when `output_hidden_states=True` is passed or when `config.output_hidden_states=True`):
+ Tuple of `torch.FloatTensor` (one for the output of the embeddings, if the model has an embedding layer, +
+ one for the output of each stage) of shape `(batch_size, sequence_length, hidden_size)`. Hidden-states
+ (also called feature maps) of the model at the output of each stage.
+ attentions (`tuple(torch.FloatTensor)`, *optional*, returned when `output_attentions=True` is passed or when `config.output_attentions=True`):
+ Tuple of `torch.FloatTensor` (one for each layer) of shape `(batch_size, num_heads, patch_size,
+ sequence_length)`.
+
+ Attentions weights after the attention softmax, used to compute the weighted average in the self-attention
+ heads.
+ """
+
+ loss: Optional[torch.FloatTensor] = None
+ logits: torch.FloatTensor = None
+ hidden_states: Optional[Tuple[torch.FloatTensor, ...]] = None
+ attentions: Optional[Tuple[torch.FloatTensor, ...]] = None
+
+
+@dataclass
+class ImageClassifierOutputWithNoAttention(ModelOutput):
+ """
+ Base class for outputs of image classification models.
+
+ Args:
+ loss (`torch.FloatTensor` of shape `(1,)`, *optional*, returned when `labels` is provided):
+ Classification (or regression if config.num_labels==1) loss.
+ logits (`torch.FloatTensor` of shape `(batch_size, config.num_labels)`):
+ Classification (or regression if config.num_labels==1) scores (before SoftMax).
+ hidden_states (`tuple(torch.FloatTensor)`, *optional*, returned when `output_hidden_states=True` is passed or when `config.output_hidden_states=True`):
+ Tuple of `torch.FloatTensor` (one for the output of the embeddings, if the model has an embedding layer, +
+ one for the output of each stage) of shape `(batch_size, num_channels, height, width)`. Hidden-states (also
+ called feature maps) of the model at the output of each stage.
+ """
+
+ loss: Optional[torch.FloatTensor] = None
+ logits: torch.FloatTensor = None
+ hidden_states: Optional[Tuple[torch.FloatTensor, ...]] = None
+
+
+@dataclass
+class DepthEstimatorOutput(ModelOutput):
+ """
+ Base class for outputs of depth estimation models.
+
+ Args:
+ loss (`torch.FloatTensor` of shape `(1,)`, *optional*, returned when `labels` is provided):
+ Classification (or regression if config.num_labels==1) loss.
+ predicted_depth (`torch.FloatTensor` of shape `(batch_size, height, width)`):
+ Predicted depth for each pixel.
+
+ hidden_states (`tuple(torch.FloatTensor)`, *optional*, returned when `output_hidden_states=True` is passed or when `config.output_hidden_states=True`):
+ Tuple of `torch.FloatTensor` (one for the output of the embeddings, if the model has an embedding layer, +
+ one for the output of each layer) of shape `(batch_size, num_channels, height, width)`.
+
+ Hidden-states of the model at the output of each layer plus the optional initial embedding outputs.
+ attentions (`tuple(torch.FloatTensor)`, *optional*, returned when `output_attentions=True` is passed or when `config.output_attentions=True`):
+ Tuple of `torch.FloatTensor` (one for each layer) of shape `(batch_size, num_heads, patch_size,
+ sequence_length)`.
+
+ Attentions weights after the attention softmax, used to compute the weighted average in the self-attention
+ heads.
+ """
+
+ loss: Optional[torch.FloatTensor] = None
+ predicted_depth: torch.FloatTensor = None
+ hidden_states: Optional[Tuple[torch.FloatTensor, ...]] = None
+ attentions: Optional[Tuple[torch.FloatTensor, ...]] = None
+
+
+@dataclass
+class ImageSuperResolutionOutput(ModelOutput):
+ """
+ Base class for outputs of image super resolution models.
+
+ Args:
+ loss (`torch.FloatTensor` of shape `(1,)`, *optional*, returned when `labels` is provided):
+ Reconstruction loss.
+ reconstruction (`torch.FloatTensor` of shape `(batch_size, num_channels, height, width)`):
+ Reconstructed images, possibly upscaled.
+ hidden_states (`tuple(torch.FloatTensor)`, *optional*, returned when `output_hidden_states=True` is passed or when `config.output_hidden_states=True`):
+ Tuple of `torch.FloatTensor` (one for the output of the embeddings, if the model has an embedding layer, +
+ one for the output of each stage) of shape `(batch_size, sequence_length, hidden_size)`. Hidden-states
+ (also called feature maps) of the model at the output of each stage.
+ attentions (`tuple(torch.FloatTensor)`, *optional*, returned when `output_attentions=True` is passed or when `config.output_attentions=True`):
+ Tuple of `torch.FloatTensor` (one for each layer) of shape `(batch_size, num_heads, patch_size,
+ sequence_length)`.
+
+ Attentions weights after the attention softmax, used to compute the weighted average in the self-attention
+ heads.
+ """
+
+ loss: Optional[torch.FloatTensor] = None
+ reconstruction: torch.FloatTensor = None
+ hidden_states: Optional[Tuple[torch.FloatTensor, ...]] = None
+ attentions: Optional[Tuple[torch.FloatTensor, ...]] = None
+
+
+@dataclass
+class Wav2Vec2BaseModelOutput(ModelOutput):
+ """
+ Base class for models that have been trained with the Wav2Vec2 loss objective.
+
+ Args:
+ last_hidden_state (`torch.FloatTensor` of shape `(batch_size, sequence_length, hidden_size)`):
+ Sequence of hidden-states at the output of the last layer of the model.
+ extract_features (`torch.FloatTensor` of shape `(batch_size, sequence_length, conv_dim[-1])`):
+ Sequence of extracted feature vectors of the last convolutional layer of the model.
+ hidden_states (`tuple(torch.FloatTensor)`, *optional*, returned when `output_hidden_states=True` is passed or when `config.output_hidden_states=True`):
+ Tuple of `torch.FloatTensor` (one for the output of the embeddings + one for the output of each layer) of
+ shape `(batch_size, sequence_length, hidden_size)`.
+
+ Hidden-states of the model at the output of each layer plus the initial embedding outputs.
+ attentions (`tuple(torch.FloatTensor)`, *optional*, returned when `output_attentions=True` is passed or when `config.output_attentions=True`):
+ Tuple of `torch.FloatTensor` (one for each layer) of shape `(batch_size, num_heads, sequence_length,
+ sequence_length)`.
+
+ Attentions weights after the attention softmax, used to compute the weighted average in the self-attention
+ heads.
+ """
+
+ last_hidden_state: torch.FloatTensor = None
+ extract_features: torch.FloatTensor = None
+ hidden_states: Optional[Tuple[torch.FloatTensor, ...]] = None
+ attentions: Optional[Tuple[torch.FloatTensor, ...]] = None
+
+
+@dataclass
+class XVectorOutput(ModelOutput):
+ """
+ Output type of [`Wav2Vec2ForXVector`].
+
+ Args:
+ loss (`torch.FloatTensor` of shape `(1,)`, *optional*, returned when `labels` is provided):
+ Classification loss.
+ logits (`torch.FloatTensor` of shape `(batch_size, config.xvector_output_dim)`):
+ Classification hidden states before AMSoftmax.
+ embeddings (`torch.FloatTensor` of shape `(batch_size, config.xvector_output_dim)`):
+ Utterance embeddings used for vector similarity-based retrieval.
+ hidden_states (`tuple(torch.FloatTensor)`, *optional*, returned when `output_hidden_states=True` is passed or when `config.output_hidden_states=True`):
+ Tuple of `torch.FloatTensor` (one for the output of the embeddings + one for the output of each layer) of
+ shape `(batch_size, sequence_length, hidden_size)`.
+
+ Hidden-states of the model at the output of each layer plus the initial embedding outputs.
+ attentions (`tuple(torch.FloatTensor)`, *optional*, returned when `output_attentions=True` is passed or when `config.output_attentions=True`):
+ Tuple of `torch.FloatTensor` (one for each layer) of shape `(batch_size, num_heads, sequence_length,
+ sequence_length)`.
+
+ Attentions weights after the attention softmax, used to compute the weighted average in the self-attention
+ heads.
+ """
+
+ loss: Optional[torch.FloatTensor] = None
+ logits: torch.FloatTensor = None
+ embeddings: torch.FloatTensor = None
+ hidden_states: Optional[Tuple[torch.FloatTensor, ...]] = None
+ attentions: Optional[Tuple[torch.FloatTensor, ...]] = None
+
+
+@dataclass
+class BackboneOutput(ModelOutput):
+ """
+ Base class for outputs of backbones.
+
+ Args:
+ feature_maps (`tuple(torch.FloatTensor)` of shape `(batch_size, num_channels, height, width)`):
+ Feature maps of the stages.
+ hidden_states (`tuple(torch.FloatTensor)`, *optional*, returned when `output_hidden_states=True` is passed or when `config.output_hidden_states=True`):
+ Tuple of `torch.FloatTensor` (one for the output of the embeddings + one for the output of each layer) of
+ shape `(batch_size, sequence_length, hidden_size)` or `(batch_size, num_channels, height, width)`,
+ depending on the backbone.
+
+ Hidden-states of the model at the output of each stage plus the initial embedding outputs.
+ attentions (`tuple(torch.FloatTensor)`, *optional*, returned when `output_attentions=True` is passed or when `config.output_attentions=True`):
+ Tuple of `torch.FloatTensor` (one for each layer) of shape `(batch_size, num_heads, sequence_length,
+ sequence_length)`. Only applicable if the backbone uses attention.
+
+ Attentions weights after the attention softmax, used to compute the weighted average in the self-attention
+ heads.
+ """
+
+ feature_maps: Tuple[torch.FloatTensor] = None
+ hidden_states: Optional[Tuple[torch.FloatTensor, ...]] = None
+ attentions: Optional[Tuple[torch.FloatTensor, ...]] = None
+
+
+@dataclass
+class BaseModelOutputWithPoolingAndProjection(ModelOutput):
+ """
+ Base class for model's outputs that also contains a pooling of the last hidden states.
+
+ Args:
+ last_hidden_state (`torch.FloatTensor` of shape `(batch_size, sequence_length, hidden_size)`):
+ Sequence of hidden-states at the output of the last layer of the model.
+ pooler_output (`torch.FloatTensor` of shape `(batch_size, hidden_size)`):
+ Last layer hidden-state of the first token of the sequence (classification token) after further processing
+ through the layers used for the auxiliary pretraining task. E.g. for BERT-family of models, this returns
+ the classification token after processing through a linear layer and a tanh activation function. The linear
+ layer weights are trained from the next sentence prediction (classification) objective during pretraining.
+ hidden_states (`tuple(torch.FloatTensor)`, *optional*, returned when `output_hidden_states=True` is passed or when `config.output_hidden_states=True`):
+ Tuple of `torch.FloatTensor` (one for the output of the embeddings, if the model has an embedding layer, +
+ one for the output of each layer) of shape `(batch_size, sequence_length, hidden_size)`.
+
+ Hidden-states of the model at the output of each layer plus the optional initial embedding outputs.
+ attentions (`tuple(torch.FloatTensor)`, *optional*, returned when `output_attentions=True` is passed or when `config.output_attentions=True`):
+ Tuple of `torch.FloatTensor` (one for each layer) of shape `(batch_size, num_heads, sequence_length,
+ sequence_length)`.
+
+ Attentions weights after the attention softmax, used to compute the weighted average in the self-attention
+ heads.
+ projection_state (`tuple(torch.FloatTensor)`, returned when `output_attentions=True` is passed or when `config.output_attentions=True`):
+ Tuple of `torch.FloatTensor` of shape `(batch_size, config.project_dim)`.
+
+ Text embeddings before the projection layer, used to mimic the last hidden state of the teacher encoder.
+ """
+
+ last_hidden_state: torch.FloatTensor = None
+ pooler_output: torch.FloatTensor = None
+ hidden_states: Optional[Tuple[torch.FloatTensor, ...]] = None
+ attentions: Optional[Tuple[torch.FloatTensor, ...]] = None
+ projection_state: Optional[Tuple[torch.FloatTensor]] = None
+
+
+@dataclass
+class Seq2SeqSpectrogramOutput(ModelOutput):
+ """
+ Base class for sequence-to-sequence spectrogram outputs.
+
+ Args:
+ loss (`torch.FloatTensor` of shape `(1,)`, *optional*, returned when `labels` is provided):
+ Spectrogram generation loss.
+ spectrogram (`torch.FloatTensor` of shape `(batch_size, sequence_length, num_bins)`):
+ The predicted spectrogram.
+ past_key_values (`tuple(tuple(torch.FloatTensor))`, *optional*, returned when `use_cache=True` is passed or when `config.use_cache=True`):
+ Tuple of `tuple(torch.FloatTensor)` of length `config.n_layers`, with each tuple having 2 tensors of shape
+ `(batch_size, num_heads, sequence_length, embed_size_per_head)` and 2 additional tensors of shape
+ `(batch_size, num_heads, encoder_sequence_length, embed_size_per_head)`.
+
+ Contains pre-computed hidden-states (key and values in the self-attention blocks and in the cross-attention
+ blocks) that can be used (see `past_key_values` input) to speed up sequential decoding.
+ decoder_hidden_states (`tuple(torch.FloatTensor)`, *optional*, returned when `output_hidden_states=True` is passed or when `config.output_hidden_states=True`):
+ Tuple of `torch.FloatTensor` (one for the output of the embeddings, if the model has an embedding layer, +
+ one for the output of each layer) of shape `(batch_size, sequence_length, hidden_size)`.
+
+ Hidden-states of the decoder at the output of each layer plus the initial embedding outputs.
+ decoder_attentions (`tuple(torch.FloatTensor)`, *optional*, returned when `output_attentions=True` is passed or when `config.output_attentions=True`):
+ Tuple of `torch.FloatTensor` (one for each layer) of shape `(batch_size, num_heads, sequence_length,
+ sequence_length)`.
+
+ Attentions weights of the decoder, after the attention softmax, used to compute the weighted average in the
+ self-attention heads.
+ cross_attentions (`tuple(torch.FloatTensor)`, *optional*, returned when `output_attentions=True` is passed or when `config.output_attentions=True`):
+ Tuple of `torch.FloatTensor` (one for each layer) of shape `(batch_size, num_heads, sequence_length,
+ sequence_length)`.
+
+ Attentions weights of the decoder's cross-attention layer, after the attention softmax, used to compute the
+ weighted average in the cross-attention heads.
+ encoder_last_hidden_state (`torch.FloatTensor` of shape `(batch_size, sequence_length, hidden_size)`, *optional*):
+ Sequence of hidden-states at the output of the last layer of the encoder of the model.
+ encoder_hidden_states (`tuple(torch.FloatTensor)`, *optional*, returned when `output_hidden_states=True` is passed or when `config.output_hidden_states=True`):
+ Tuple of `torch.FloatTensor` (one for the output of the embeddings, if the model has an embedding layer, +
+ one for the output of each layer) of shape `(batch_size, sequence_length, hidden_size)`.
+
+ Hidden-states of the encoder at the output of each layer plus the initial embedding outputs.
+ encoder_attentions (`tuple(torch.FloatTensor)`, *optional*, returned when `output_attentions=True` is passed or when `config.output_attentions=True`):
+ Tuple of `torch.FloatTensor` (one for each layer) of shape `(batch_size, num_heads, sequence_length,
+ sequence_length)`.
+
+ Attentions weights of the encoder, after the attention softmax, used to compute the weighted average in the
+ self-attention heads.
+ """
+
+ loss: Optional[torch.FloatTensor] = None
+ spectrogram: torch.FloatTensor = None
+ past_key_values: Optional[Tuple[Tuple[torch.FloatTensor]]] = None
+ decoder_hidden_states: Optional[Tuple[torch.FloatTensor, ...]] = None
+ decoder_attentions: Optional[Tuple[torch.FloatTensor, ...]] = None
+ cross_attentions: Optional[Tuple[torch.FloatTensor, ...]] = None
+ encoder_last_hidden_state: Optional[torch.FloatTensor] = None
+ encoder_hidden_states: Optional[Tuple[torch.FloatTensor, ...]] = None
+ encoder_attentions: Optional[Tuple[torch.FloatTensor, ...]] = None
+
+
+@dataclass
+class Seq2SeqTSModelOutput(ModelOutput):
+ """
+ Base class for time series model's encoder outputs that also contains pre-computed hidden states that can speed up
+ sequential decoding.
+
+ Args:
+ last_hidden_state (`torch.FloatTensor` of shape `(batch_size, sequence_length, hidden_size)`):
+ Sequence of hidden-states at the output of the last layer of the decoder of the model.
+
+ If `past_key_values` is used, only the last hidden-state of the sequences of shape `(batch_size, 1,
+ hidden_size)` is output.
+ past_key_values (`tuple(tuple(torch.FloatTensor))`, *optional*, returned when `use_cache=True` is passed or when `config.use_cache=True`):
+ Tuple of `tuple(torch.FloatTensor)` of length `config.n_layers`, with each tuple having 2 tensors of shape
+ `(batch_size, num_heads, sequence_length, embed_size_per_head)` and 2 additional tensors of shape
+ `(batch_size, num_heads, encoder_sequence_length, embed_size_per_head)`.
+
+ Contains pre-computed hidden-states (key and values in the self-attention blocks and in the cross-attention
+ blocks) that can be used (see `past_key_values` input) to speed up sequential decoding.
+ decoder_hidden_states (`tuple(torch.FloatTensor)`, *optional*, returned when `output_hidden_states=True` is passed or when `config.output_hidden_states=True`):
+ Tuple of `torch.FloatTensor` (one for the output of the embeddings, if the model has an embedding layer, +
+ one for the output of each layer) of shape `(batch_size, sequence_length, hidden_size)`.
+
+ Hidden-states of the decoder at the output of each layer plus the optional initial embedding outputs.
+ decoder_attentions (`tuple(torch.FloatTensor)`, *optional*, returned when `output_attentions=True` is passed or when `config.output_attentions=True`):
+ Tuple of `torch.FloatTensor` (one for each layer) of shape `(batch_size, num_heads, sequence_length,
+ sequence_length)`.
+
+ Attentions weights of the decoder, after the attention softmax, used to compute the weighted average in the
+ self-attention heads.
+ cross_attentions (`tuple(torch.FloatTensor)`, *optional*, returned when `output_attentions=True` is passed or when `config.output_attentions=True`):
+ Tuple of `torch.FloatTensor` (one for each layer) of shape `(batch_size, num_heads, sequence_length,
+ sequence_length)`.
+
+ Attentions weights of the decoder's cross-attention layer, after the attention softmax, used to compute the
+ weighted average in the cross-attention heads.
+ encoder_last_hidden_state (`torch.FloatTensor` of shape `(batch_size, sequence_length, hidden_size)`, *optional*):
+ Sequence of hidden-states at the output of the last layer of the encoder of the model.
+ encoder_hidden_states (`tuple(torch.FloatTensor)`, *optional*, returned when `output_hidden_states=True` is passed or when `config.output_hidden_states=True`):
+ Tuple of `torch.FloatTensor` (one for the output of the embeddings, if the model has an embedding layer, +
+ one for the output of each layer) of shape `(batch_size, sequence_length, hidden_size)`.
+
+ Hidden-states of the encoder at the output of each layer plus the optional initial embedding outputs.
+ encoder_attentions (`tuple(torch.FloatTensor)`, *optional*, returned when `output_attentions=True` is passed or when `config.output_attentions=True`):
+ Tuple of `torch.FloatTensor` (one for each layer) of shape `(batch_size, num_heads, sequence_length,
+ sequence_length)`.
+
+ Attentions weights of the encoder, after the attention softmax, used to compute the weighted average in the
+ self-attention heads.
+ loc (`torch.FloatTensor` of shape `(batch_size,)` or `(batch_size, input_size)`, *optional*):
+ Shift values of each time series' context window, used to give the model inputs of the same
+ magnitude and then to shift back to the original magnitude.
+ scale (`torch.FloatTensor` of shape `(batch_size,)` or `(batch_size, input_size)`, *optional*):
+ Scaling values of each time series' context window, used to give the model inputs of the same
+ magnitude and then to rescale back to the original magnitude (see the example below).
+ static_features (`torch.FloatTensor` of shape `(batch_size, feature size)`, *optional*):
+ Static features of each time series in a batch, which are copied to the covariates at inference time.
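+
+ Example (illustrative sketch of how `loc` and `scale` are typically applied, not part of the upstream
+ docstring; the tensors are made up):
+
+ >>> normalized = torch.randn(4, 24)  # values in the model's normalized space
+ >>> loc, scale = torch.full((4, 1), 10.0), torch.full((4, 1), 2.0)
+ >>> (normalized * scale + loc).shape  # shift and rescale back to the original magnitude
+ torch.Size([4, 24])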
+ """
+
+ last_hidden_state: torch.FloatTensor = None
+ past_key_values: Optional[Tuple[Tuple[torch.FloatTensor]]] = None
+ decoder_hidden_states: Optional[Tuple[torch.FloatTensor, ...]] = None
+ decoder_attentions: Optional[Tuple[torch.FloatTensor, ...]] = None
+ cross_attentions: Optional[Tuple[torch.FloatTensor, ...]] = None
+ encoder_last_hidden_state: Optional[torch.FloatTensor] = None
+ encoder_hidden_states: Optional[Tuple[torch.FloatTensor, ...]] = None
+ encoder_attentions: Optional[Tuple[torch.FloatTensor, ...]] = None
+ loc: Optional[torch.FloatTensor] = None
+ scale: Optional[torch.FloatTensor] = None
+ static_features: Optional[torch.FloatTensor] = None
+
+
+@dataclass
+class Seq2SeqTSPredictionOutput(ModelOutput):
+ """
+ Base class for time series model's decoder outputs that also contain the loss as well as the parameters of the
+ chosen distribution.
+
+ Args:
+ loss (`torch.FloatTensor` of shape `(1,)`, *optional*, returned when `future_values` is provided):
+ Distributional loss.
+ params (`torch.FloatTensor` of shape `(batch_size, num_samples, num_params)`):
+ Parameters of the chosen distribution.
+ past_key_values (`tuple(tuple(torch.FloatTensor))`, *optional*, returned when `use_cache=True` is passed or when `config.use_cache=True`):
+ Tuple of `tuple(torch.FloatTensor)` of length `config.n_layers`, with each tuple having 2 tensors of shape
+ `(batch_size, num_heads, sequence_length, embed_size_per_head)` and 2 additional tensors of shape
+ `(batch_size, num_heads, encoder_sequence_length, embed_size_per_head)`.
+
+ Contains pre-computed hidden-states (key and values in the self-attention blocks and in the cross-attention
+ blocks) that can be used (see `past_key_values` input) to speed up sequential decoding.
+ decoder_hidden_states (`tuple(torch.FloatTensor)`, *optional*, returned when `output_hidden_states=True` is passed or when `config.output_hidden_states=True`):
+ Tuple of `torch.FloatTensor` (one for the output of the embeddings, if the model has an embedding layer, +
+ one for the output of each layer) of shape `(batch_size, sequence_length, hidden_size)`.
+
+ Hidden-states of the decoder at the output of each layer plus the initial embedding outputs.
+ decoder_attentions (`tuple(torch.FloatTensor)`, *optional*, returned when `output_attentions=True` is passed or when `config.output_attentions=True`):
+ Tuple of `torch.FloatTensor` (one for each layer) of shape `(batch_size, num_heads, sequence_length,
+ sequence_length)`.
+
+ Attentions weights of the decoder, after the attention softmax, used to compute the weighted average in the
+ self-attention heads.
+ cross_attentions (`tuple(torch.FloatTensor)`, *optional*, returned when `output_attentions=True` is passed or when `config.output_attentions=True`):
+ Tuple of `torch.FloatTensor` (one for each layer) of shape `(batch_size, num_heads, sequence_length,
+ sequence_length)`.
+
+ Attentions weights of the decoder's cross-attention layer, after the attention softmax, used to compute the
+ weighted average in the cross-attention heads.
+ encoder_last_hidden_state (`torch.FloatTensor` of shape `(batch_size, sequence_length, hidden_size)`, *optional*):
+ Sequence of hidden-states at the output of the last layer of the encoder of the model.
+ encoder_hidden_states (`tuple(torch.FloatTensor)`, *optional*, returned when `output_hidden_states=True` is passed or when `config.output_hidden_states=True`):
+ Tuple of `torch.FloatTensor` (one for the output of the embeddings, if the model has an embedding layer, +
+ one for the output of each layer) of shape `(batch_size, sequence_length, hidden_size)`.
+
+ Hidden-states of the encoder at the output of each layer plus the initial embedding outputs.
+ encoder_attentions (`tuple(torch.FloatTensor)`, *optional*, returned when `output_attentions=True` is passed or when `config.output_attentions=True`):
+ Tuple of `torch.FloatTensor` (one for each layer) of shape `(batch_size, num_heads, sequence_length,
+ sequence_length)`.
+
+ Attentions weights of the encoder, after the attention softmax, used to compute the weighted average in the
+ self-attention heads.
+ loc (`torch.FloatTensor` of shape `(batch_size,)` or `(batch_size, input_size)`, *optional*):
+ Shift values of each time series' context window, used to give the model inputs of the same
+ magnitude and then to shift back to the original magnitude.
+ scale (`torch.FloatTensor` of shape `(batch_size,)` or `(batch_size, input_size)`, *optional*):
+ Scaling values of each time series' context window, used to give the model inputs of the same
+ magnitude and then to rescale back to the original magnitude.
+ static_features (`torch.FloatTensor` of shape `(batch_size, feature size)`, *optional*):
+ Static features of each time series in a batch, which are copied to the covariates at inference time.
+ """
+
+ loss: Optional[torch.FloatTensor] = None
+ params: Optional[Tuple[torch.FloatTensor]] = None
+ past_key_values: Optional[Tuple[Tuple[torch.FloatTensor]]] = None
+ decoder_hidden_states: Optional[Tuple[torch.FloatTensor, ...]] = None
+ decoder_attentions: Optional[Tuple[torch.FloatTensor, ...]] = None
+ cross_attentions: Optional[Tuple[torch.FloatTensor, ...]] = None
+ encoder_last_hidden_state: Optional[torch.FloatTensor] = None
+ encoder_hidden_states: Optional[Tuple[torch.FloatTensor, ...]] = None
+ encoder_attentions: Optional[Tuple[torch.FloatTensor, ...]] = None
+ loc: Optional[torch.FloatTensor] = None
+ scale: Optional[torch.FloatTensor] = None
+ static_features: Optional[torch.FloatTensor] = None
+
+
+@dataclass
+class SampleTSPredictionOutput(ModelOutput):
+ """
+ Base class for time series model's predictions outputs that contains the sampled values from the chosen
+ distribution.
+
+ Args:
+ sequences (`torch.FloatTensor` of shape `(batch_size, num_samples, prediction_length)` or `(batch_size, num_samples, prediction_length, input_size)`):
+ Sampled values from the chosen distribution.
+ """
+
+ sequences: torch.FloatTensor = None
+
+
+@dataclass
+class MaskedImageModelingOutput(ModelOutput):
+ """
+ Base class for outputs of masked image completion / in-painting models.
+
+ Args:
+ loss (`torch.FloatTensor` of shape `(1,)`, *optional*, returned when `bool_masked_pos` is provided):
+ Reconstruction loss.
+ reconstruction (`torch.FloatTensor` of shape `(batch_size, num_channels, height, width)`):
+ Reconstructed / completed images.
+ hidden_states (`tuple(torch.FloatTensor)`, *optional*, returned when `output_hidden_states=True` is passed or
+ when `config.output_hidden_states=True`):
+ Tuple of `torch.FloatTensor` (one for the output of the embeddings, if the model has an embedding layer, +
+ one for the output of each stage) of shape `(batch_size, sequence_length, hidden_size)`. Hidden-states
+ (also called feature maps) of the model at the output of each stage.
+ attentions (`tuple(torch.FloatTensor)`, *optional*, returned when `output_attentions=True` is passed or when
+ `config.output_attentions=True`):
+ Tuple of `torch.FloatTensor` (one for each layer) of shape `(batch_size, num_heads, patch_size,
+ sequence_length)`. Attentions weights after the attention softmax, used to compute the weighted average in
+ the self-attention heads.
+ """
+
+ loss: Optional[torch.FloatTensor] = None
+ reconstruction: torch.FloatTensor = None
+ hidden_states: Optional[Tuple[torch.FloatTensor, ...]] = None
+ attentions: Optional[Tuple[torch.FloatTensor, ...]] = None
+
+ @property
+ def logits(self):
+ warnings.warn(
+ "logits attribute is deprecated and will be removed in version 5 of Transformers."
+ " Please use the reconstruction attribute to retrieve the final output instead.",
+ FutureWarning,
+ )
+ return self.reconstruction
diff --git a/transformers_4_44_2__modeling_rope_utils.py b/transformers_4_44_2__modeling_rope_utils.py
new file mode 100644
index 0000000000000000000000000000000000000000..6c3a3f0f8e30f94bf18a4b75bf173e1b879ac816
--- /dev/null
+++ b/transformers_4_44_2__modeling_rope_utils.py
@@ -0,0 +1,559 @@
+# Copyright 2024 The HuggingFace Team. All rights reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+import math
+from typing import Optional, Tuple
+
+from transformers.configuration_utils import PretrainedConfig
+from transformers.utils import is_torch_available, logging
+
+
+logger = logging.get_logger(__name__)
+
+
+if is_torch_available():
+ import torch
+
+
+def _compute_default_rope_parameters(
+ config: Optional[PretrainedConfig] = None,
+ device: Optional["torch.device"] = None,
+ seq_len: Optional[int] = None,
+ **rope_kwargs,
+) -> Tuple["torch.Tensor", float]:
+ """
+ Computes the inverse frequencies according to the original RoPE implementation
+ Args:
+ config ([`~transformers.PretrainedConfig`]):
+ The model configuration.
+ device (`torch.device`):
+ The device to use for initialization of the inverse frequencies.
+ seq_len (`int`, *optional*):
+ The current sequence length. Unused for this type of RoPE.
+ rope_kwargs (`Dict`, *optional*):
+ BC compatibility with the previous RoPE class instantiation, will be removed in v4.45.
+ Returns:
+ Tuple of (`torch.Tensor`, `float`), containing the inverse frequencies for the RoPE embeddings and the
+ post-processing scaling factor applied to the computed cos/sin (unused in this type of RoPE).
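+
+ Example (illustrative, not part of the upstream docstring; uses the BC `rope_kwargs` path with an
+ explicit `base` and `dim`, and assumes torch is available):
+
+ >>> inv_freq, attention_factor = _compute_default_rope_parameters(base=10000.0, dim=64)
+ >>> inv_freq.shape, attention_factor
+ (torch.Size([32]), 1.0)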
+ """
+ if config is not None and len(rope_kwargs) > 0:
+ raise ValueError(
+ "Unexpected arguments: `**rope_kwargs` and `config` are mutually exclusive in "
+ f"`_compute_default_rope_parameters`, got `rope_kwargs`={rope_kwargs} and `config`={config}"
+ )
+ if len(rope_kwargs) > 0:
+ base = rope_kwargs["base"]
+ dim = rope_kwargs["dim"]
+ elif config is not None:
+ base = config.rope_theta
+ partial_rotary_factor = config.partial_rotary_factor if hasattr(config, "partial_rotary_factor") else 1.0
+ head_dim = getattr(config, "head_dim", config.hidden_size // config.num_attention_heads)
+ dim = int(head_dim * partial_rotary_factor)
+
+ attention_factor = 1.0 # Unused in this type of RoPE
+
+ # Compute the inverse frequencies
+ inv_freq = 1.0 / (base ** (torch.arange(0, dim, 2, dtype=torch.int64).float().to(device) / dim))
+ return inv_freq, attention_factor
+
+
+def _compute_linear_scaling_rope_parameters(
+ config: Optional[PretrainedConfig] = None,
+ device: Optional["torch.device"] = None,
+ seq_len: Optional[int] = None,
+ **rope_kwargs,
+) -> Tuple["torch.Tensor", float]:
+ """
+ Computes the inverse frequencies with linear scaling. Credits to the Reddit user /u/kaiokendev
+ Args:
+ config ([`~transformers.PretrainedConfig`]):
+ The model configuration.
+ device (`torch.device`):
+ The device to use for initialization of the inverse frequencies.
+ seq_len (`int`, *optional*):
+ The current sequence length. Unused for this type of RoPE.
+ rope_kwargs (`Dict`, *optional*):
+ BC compatibility with the previous RoPE class instantiation, will be removed in v4.45.
+ Returns:
+ Tuple of (`torch.Tensor`, `float`), containing the inverse frequencies for the RoPE embeddings and the
+ post-processing scaling factor applied to the computed cos/sin (unused in this type of RoPE).
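+
+ Example (illustrative, not part of the upstream docstring; uses the BC `rope_kwargs` path to show that
+ linear scaling simply divides the default inverse frequencies by `factor`):
+
+ >>> default, _ = _compute_default_rope_parameters(base=10000.0, dim=64)
+ >>> scaled, _ = _compute_linear_scaling_rope_parameters(base=10000.0, dim=64, factor=2.0)
+ >>> torch.allclose(scaled, default / 2.0)
+ True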
+ """
+ if config is not None and len(rope_kwargs) > 0:
+ raise ValueError(
+ "Unexpected arguments: `**rope_kwargs` and `config` are mutually exclusive in "
+ f"`_compute_linear_scaling_rope_parameters`, got `rope_kwargs`={rope_kwargs} and `config`={config}"
+ )
+ if len(rope_kwargs) > 0:
+ factor = rope_kwargs["factor"]
+ elif config is not None:
+ factor = config.rope_scaling["factor"]
+
+ # Gets the default RoPE parameters
+ inv_freq, attention_factor = _compute_default_rope_parameters(config, device, seq_len, **rope_kwargs)
+
+ # Then applies linear scaling to the frequencies.
+ # NOTE: originally, scaling was applied to the position_ids. However, we get `embs = inv_freq @ position_ids`, so
+ # applying scaling to the inverse frequencies is equivalent.
+ inv_freq /= factor
+ return inv_freq, attention_factor
+
+
+def _compute_dynamic_ntk_parameters(
+ config: Optional[PretrainedConfig] = None,
+ device: Optional["torch.device"] = None,
+ seq_len: Optional[int] = None,
+ **rope_kwargs,
+) -> Tuple["torch.Tensor", float]:
+ """
+ Computes the inverse frequencies with NTK scaling. Credits to the Reddit users /u/bloc97 and /u/emozilla
+ Args:
+ config ([`~transformers.PretrainedConfig`]):
+ The model configuration.
+ device (`torch.device`):
+ The device to use for initialization of the inverse frequencies.
+ seq_len (`int`, *optional*):
+ The current sequence length, used to update the dynamic RoPE at inference time.
+ rope_kwargs (`Dict`, *optional*):
+ BC compatibility with the previous RoPE class instantiation, will be removed in v4.45.
+ Returns:
+ Tuple of (`torch.Tensor`, `float`), containing the inverse frequencies for the RoPE embeddings and the
+ post-processing scaling factor applied to the computed cos/sin (unused in this type of RoPE).
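+
+ Example (illustrative, not part of the upstream docstring; uses the BC `rope_kwargs` path to show that the
+ effective base, and hence the inverse frequencies, only change once `seq_len` exceeds
+ `max_position_embeddings`):
+
+ >>> short, _ = _compute_dynamic_ntk_parameters(base=10000.0, dim=64, max_position_embeddings=2048, factor=2.0, seq_len=1024)
+ >>> stretched, _ = _compute_dynamic_ntk_parameters(base=10000.0, dim=64, max_position_embeddings=2048, factor=2.0, seq_len=8192)
+ >>> bool((stretched <= short).all())  # lower frequencies once the context is stretched
+ True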
+ """
+ # TODO (joao): use the new `original_max_position_embeddings` from rope_scaling
+ if config is not None and len(rope_kwargs) > 0:
+ raise ValueError(
+ "Unexpected arguments: `**rope_kwargs` and `config` are mutually exclusive in "
+ f"`_compute_dynamic_ntk_parameters`, got `rope_kwargs`={rope_kwargs} and `config`={config}"
+ )
+ if len(rope_kwargs) > 0:
+ base = rope_kwargs["base"]
+ dim = rope_kwargs["dim"]
+ max_position_embeddings = rope_kwargs["max_position_embeddings"]
+ factor = rope_kwargs["factor"]
+ elif config is not None:
+ base = config.rope_theta
+ partial_rotary_factor = config.partial_rotary_factor if hasattr(config, "partial_rotary_factor") else 1.0
+ head_dim = getattr(config, "head_dim", config.hidden_size // config.num_attention_heads)
+ dim = int(head_dim * partial_rotary_factor)
+ max_position_embeddings = config.max_position_embeddings
+ factor = config.rope_scaling["factor"]
+
+ attention_factor = 1.0 # Unused in this type of RoPE
+
+ # seq_len: default to max_position_embeddings, e.g. at init time
+ seq_len = seq_len if seq_len is not None and seq_len > max_position_embeddings else max_position_embeddings
+
+ # Compute the inverse frequencies
+ base = base * ((factor * seq_len / max_position_embeddings) - (factor - 1)) ** (dim / (dim - 2))
+ inv_freq = 1.0 / (base ** (torch.arange(0, dim, 2, dtype=torch.int64).float().to(device) / dim))
+ return inv_freq, attention_factor
+
+
+def _compute_yarn_parameters(
+ config: PretrainedConfig, device: "torch.device", seq_len: Optional[int] = None, **rope_kwargs
+) -> Tuple["torch.Tensor", float]:
+ """
+ Computes the inverse frequencies with YaRN scaling (an NTK-aware, by-parts interpolation). Please refer to the
+ [original paper](https://arxiv.org/abs/2309.00071)
+ Args:
+ config ([`~transformers.PretrainedConfig`]):
+ The model configuration.
+ device (`torch.device`):
+ The device to use for initialization of the inverse frequencies.
+ seq_len (`int`, *optional*):
+ The current sequence length. Unused for this type of RoPE.
+ rope_kwargs (`Dict`, *optional*):
+ BC compatibility with the previous RoPE class instantiation, will be removed in v4.45.
+ Returns:
+ Tuple of (`torch.Tensor`, `float`), containing the inverse frequencies for the RoPE embeddings and the
+ post-processing scaling factor applied to the computed cos/sin.
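+
+ Example (illustrative, not part of the upstream docstring; builds a bare `PretrainedConfig` with made-up
+ values just to show which fields this function reads):
+
+ >>> from transformers import PretrainedConfig
+ >>> config = PretrainedConfig(rope_theta=10000.0, hidden_size=2048, num_attention_heads=16,
+ ... max_position_embeddings=4096, rope_scaling={"rope_type": "yarn", "factor": 4.0})
+ >>> inv_freq, attention_factor = _compute_yarn_parameters(config, device="cpu")
+ >>> round(attention_factor, 3)  # 0.1 * ln(4) + 1, as suggested in the paper
+ 1.139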
+ """
+ # No need to keep BC with yarn, unreleased when this new pattern was created.
+ if len(rope_kwargs) > 0:
+ raise ValueError(
+ f"Unexpected arguments: `**rope_kwargs` should be unset in `_compute_yarn_parameters`, got {rope_kwargs}"
+ )
+
+ base = config.rope_theta
+ partial_rotary_factor = config.partial_rotary_factor if hasattr(config, "partial_rotary_factor") else 1.0
+ head_dim = getattr(config, "head_dim", config.hidden_size // config.num_attention_heads)
+ dim = int(head_dim * partial_rotary_factor)
+ max_position_embeddings = config.max_position_embeddings
+ factor = config.rope_scaling["factor"]
+
+ # Sets the attention factor as suggested in the paper
+ attention_factor = config.rope_scaling.get("attention_factor")
+ if attention_factor is None:
+ attention_factor = 0.1 * math.log(factor) + 1.0
+
+ # Optional config options
+ # beta_fast/beta_slow: as suggested in the paper, default to 32/1 (correspondingly)
+ beta_fast = config.rope_scaling.get("beta_fast") or 32
+ beta_slow = config.rope_scaling.get("beta_slow") or 1
+
+ # Compute the inverse frequencies
+ def find_correction_dim(num_rotations, dim, base, max_position_embeddings):
+ """Inverse dimension formula to find the dimension based on the number of rotations"""
+ return (dim * math.log(max_position_embeddings / (num_rotations * 2 * math.pi))) / (2 * math.log(base))
+
+ def find_correction_range(low_rot, high_rot, dim, base, max_position_embeddings):
+ """Find dimension range bounds based on rotations"""
+ low = math.floor(find_correction_dim(low_rot, dim, base, max_position_embeddings))
+ high = math.ceil(find_correction_dim(high_rot, dim, base, max_position_embeddings))
+ return max(low, 0), min(high, dim - 1)
+
+ def linear_ramp_factor(min, max, dim):
+ if min == max:
+ max += 0.001 # Prevent singularity
+
+ linear_func = (torch.arange(dim, dtype=torch.float32) - min) / (max - min)
+ ramp_func = torch.clamp(linear_func, 0, 1)
+ return ramp_func
+
+ # Note on variable naming: "interpolation" comes from the original technique, where we interpolate the position IDs
+ # to expand the possible context length. In other words, interpolation = apply scaling factor.
+ pos_freqs = base ** (torch.arange(0, dim, 2).float().to(device) / dim)
+ inv_freq_extrapolation = 1.0 / pos_freqs
+ inv_freq_interpolation = 1.0 / (factor * pos_freqs)
+
+ low, high = find_correction_range(beta_fast, beta_slow, dim, base, max_position_embeddings)
+
+ # Get n-dimensional rotational scaling corrected for extrapolation
+ inv_freq_extrapolation_factor = 1 - linear_ramp_factor(low, high, dim // 2).float().to(device)
+ inv_freq = (
+ inv_freq_interpolation * (1 - inv_freq_extrapolation_factor)
+ + inv_freq_extrapolation * inv_freq_extrapolation_factor
+ )
+
+ return inv_freq, attention_factor
+
+
+def _compute_longrope_parameters(
+ config: PretrainedConfig, device: "torch.device", seq_len: Optional[int] = None, **rope_kwargs
+) -> Tuple["torch.Tensor", float]:
+ """
+ Computes the inverse frequencies with LongRoPE scaling. Please refer to the
+ [original implementation](https://github.com/microsoft/LongRoPE)
+ Args:
+ config ([`~transformers.PretrainedConfig`]):
+ The model configuration.
+ device (`torch.device`):
+ The device to use for initialization of the inverse frequencies.
+ seq_len (`int`, *optional*):
+ The current sequence length. Unused for this type of RoPE.
+ rope_kwargs (`Dict`, *optional*):
+ BC compatibility with the previous RoPE class instantiation, will be removed in v4.45.
+ Returns:
+ Tuple of (`torch.Tensor`, `float`), containing the inverse frequencies for the RoPE embeddings and the
+ post-processing scaling factor applied to the computed cos/sin.
+ """
+ # TODO (joao): use the new `original_max_position_embeddings` from rope_scaling
+ # No need to keep BC with longrope, unreleased when this new pattern was created.
+ if len(rope_kwargs) > 0:
+ raise ValueError(
+ "Unexpected arguments: `**rope_kwargs` should be unset in `_compute_longrope_parameters`, got "
+ f"{rope_kwargs}"
+ )
+
+ base = config.rope_theta
+ partial_rotary_factor = config.partial_rotary_factor if hasattr(config, "partial_rotary_factor") else 1.0
+ head_dim = getattr(config, "head_dim", config.hidden_size // config.num_attention_heads)
+ dim = int(head_dim * partial_rotary_factor)
+ long_factor = config.rope_scaling["long_factor"]
+ short_factor = config.rope_scaling["short_factor"]
+ factor = config.rope_scaling.get("factor")
+ attention_factor = config.rope_scaling.get("attention_factor")
+
+ # NOTE: Phi3 (and potentially other models) modify `max_position_embeddings` and have a
+ # `original_max_position_embeddings` field containing the pretrained value. They use the ratio between these two
+ # values to compute the default attention scaling factor, instead of using `factor`.
+ if hasattr(config, "original_max_position_embeddings"):
+ max_position_embeddings = config.original_max_position_embeddings
+ expanded_max_position_embeddings = config.max_position_embeddings
+ factor = expanded_max_position_embeddings / max_position_embeddings
+ else:
+ max_position_embeddings = config.max_position_embeddings
+ expanded_max_position_embeddings = max_position_embeddings * factor
+
+ # Sets the attention factor as suggested in the paper
+ if attention_factor is None:
+ if factor <= 1.0:
+ attention_factor = 1.0
+ else:
+ attention_factor = math.sqrt(1 + math.log(factor) / math.log(max_position_embeddings))
+
+ # Compute the inverse frequencies -- scaled based on the target sequence length
+ if expanded_max_position_embeddings > max_position_embeddings:
+ ext_factors = torch.tensor(long_factor, dtype=torch.float32, device=device)
+ else:
+ ext_factors = torch.tensor(short_factor, dtype=torch.float32, device=device)
+ inv_freq_shape = torch.arange(0, dim, 2, dtype=torch.int64, device=device).float() / dim
+ inv_freq = 1.0 / (ext_factors * base**inv_freq_shape)
+
+ return inv_freq, attention_factor
+
+
+def _compute_llama3_parameters(
+ config: PretrainedConfig, device: "torch.device", seq_len: Optional[int] = None, **rope_kwargs
+) -> Tuple["torch.Tensor", float]:
+ """
+ Computes the inverse frequencies for llama 3.1.
+
+ Args:
+ config ([`~transformers.PretrainedConfig`]):
+ The model configuration.
+ device (`torch.device`):
+ The device to use for initialization of the inverse frequencies.
+ seq_len (`int`, *optional*):
+ The current sequence length. Unused for this type of RoPE.
+ rope_kwargs (`Dict`, *optional*):
+ BC compatibility with the previous RoPE class instantiation, will be removed in v4.45.
+ Returns:
+ Tuple of (`torch.Tensor`, `float`), containing the inverse frequencies for the RoPE embeddings and the
+ post-processing scaling factor applied to the computed cos/sin.
+ """
+ # Gets the default RoPE parameters
+ inv_freq, attention_factor = _compute_default_rope_parameters(config, device, seq_len, **rope_kwargs)
+
+ factor = config.rope_scaling["factor"] # `8` in the original implementation
+ low_freq_factor = config.rope_scaling["low_freq_factor"] # `1` in the original implementation
+ high_freq_factor = config.rope_scaling["high_freq_factor"] # `4` in the original implementation
+ old_context_len = config.rope_scaling["original_max_position_embeddings"] # `8192` in the original implementation
+
+ low_freq_wavelen = old_context_len / low_freq_factor
+ high_freq_wavelen = old_context_len / high_freq_factor
+
+ wavelen = 2 * math.pi / inv_freq
+ # wavelen < high_freq_wavelen: do nothing
+ # wavelen > low_freq_wavelen: divide by factor
+ inv_freq_llama = torch.where(wavelen > low_freq_wavelen, inv_freq / factor, inv_freq)
+ # otherwise: interpolate between the two, using a smooth factor
+ smooth_factor = (old_context_len / wavelen - low_freq_factor) / (high_freq_factor - low_freq_factor)
+ smoothed_inv_freq = (1 - smooth_factor) * inv_freq_llama / factor + smooth_factor * inv_freq_llama
+ is_medium_freq = ~(wavelen < high_freq_wavelen) * ~(wavelen > low_freq_wavelen)
+ inv_freq_llama = torch.where(is_medium_freq, smoothed_inv_freq, inv_freq_llama)
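+ # Net effect: frequencies whose wavelength is shorter than high_freq_wavelen are kept as-is,
+ # those longer than low_freq_wavelen are divided by `factor`, and the band in between is a
+ # linear interpolation (weighted by `smooth_factor`) between the scaled and unscaled values.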
+
+ return inv_freq_llama, attention_factor
+
+
+# This maps the "rope_type" string field in rope config to the corresponding function to compute the RoPE parameters
+# from the model config. You can append new {'rope_type': callable} pairs to this dictionary to enable custom RoPE
+# parameterizations, as long as the callable has the same signature.
+ROPE_INIT_FUNCTIONS = {
+ "default": _compute_default_rope_parameters,
+ "linear": _compute_linear_scaling_rope_parameters,
+ "dynamic": _compute_dynamic_ntk_parameters,
+ "yarn": _compute_yarn_parameters,
+ "longrope": _compute_longrope_parameters,
+ "llama3": _compute_llama3_parameters,
+}
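+
+# Illustrative sketch only (not part of this module): a custom parameterization can be registered
+# by appending to `ROPE_INIT_FUNCTIONS`, provided the callable matches the signature of the
+# functions above. The "my_ntk" name is hypothetical and the constant rescaling below simply
+# mimics linear scaling.
+#
+#     def _compute_my_ntk_parameters(config, device, seq_len=None, **rope_kwargs):
+#         inv_freq, attention_factor = _compute_default_rope_parameters(config, device, seq_len, **rope_kwargs)
+#         return inv_freq / config.rope_scaling["factor"], attention_factor
+#
+#     ROPE_INIT_FUNCTIONS["my_ntk"] = _compute_my_ntk_parameters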
+
+
+def _check_received_keys(rope_type: str, received_keys: set, required_keys: set, optional_keys: Optional[set] = None):
+ """Compare the received keys in `config.rope_scaling` against the expected and optional keys"""
+ # BC: "rope_type" was originally "type" -- let's gracefully handle it
+ if "rope_type" not in received_keys and "type" in received_keys:
+ received_keys -= {"type"}
+ received_keys.add("rope_type")
+
+ missing_keys = required_keys - received_keys
+ if missing_keys:
+ raise KeyError(f"Missing required keys in `rope_scaling` for 'rope_type'='{rope_type}': {missing_keys}")
+
+ if optional_keys is not None:
+ unused_keys = received_keys - required_keys - optional_keys
+ else:
+ unused_keys = received_keys - required_keys
+ if unused_keys:
+ logger.warning(f"Unrecognized keys in `rope_scaling` for 'rope_type'='{rope_type}': {unused_keys}")
+
+
+def _validate_default_rope_parameters(config: PretrainedConfig):
+ rope_scaling = config.rope_scaling
+ rope_type = rope_scaling.get("rope_type", rope_scaling.get("type", None)) # BC: "rope_type" was originally "type"
+ required_keys = {"rope_type"}
+ received_keys = set(rope_scaling.keys())
+ _check_received_keys(rope_type, received_keys, required_keys)
+
+
+def _validate_linear_scaling_rope_parameters(config: PretrainedConfig):
+ rope_scaling = config.rope_scaling
+ rope_type = rope_scaling.get("rope_type", rope_scaling.get("type", None)) # BC: "rope_type" was originally "type"
+ required_keys = {"rope_type", "factor"}
+ received_keys = set(rope_scaling.keys())
+ _check_received_keys(rope_type, received_keys, required_keys)
+
+ factor = rope_scaling["factor"]
+ if factor is None or not isinstance(factor, float) or factor < 1.0:
+ logger.warning(f"`rope_scaling`'s factor field must be a float >= 1, got {factor}")
+
+
+def _validate_dynamic_scaling_rope_parameters(config: PretrainedConfig):
+ rope_scaling = config.rope_scaling
+ rope_type = rope_scaling.get("rope_type", rope_scaling.get("type", None)) # BC: "rope_type" was originally "type"
+ required_keys = {"rope_type", "factor"}
+ # TODO (joao): update logic for the inclusion of `original_max_position_embeddings`
+ optional_keys = {"original_max_position_embeddings"}
+ received_keys = set(rope_scaling.keys())
+ _check_received_keys(rope_type, received_keys, required_keys, optional_keys)
+
+ factor = rope_scaling["factor"]
+ if factor is None or not isinstance(factor, float) or factor < 1.0:
+ logger.warning(f"`rope_scaling`'s factor field must be a float >= 1, got {factor}")
+
+
+def _validate_yarn_parameters(config: PretrainedConfig):
+ rope_scaling = config.rope_scaling
+ rope_type = rope_scaling.get("rope_type", rope_scaling.get("type", None)) # BC: "rope_type" was originally "type"
+ required_keys = {"rope_type", "factor"}
+ optional_keys = {"attention_factor", "beta_fast", "beta_slow"}
+ received_keys = set(rope_scaling.keys())
+ _check_received_keys(rope_type, received_keys, required_keys, optional_keys)
+
+ factor = rope_scaling["factor"]
+ if factor is None or not isinstance(factor, float) or factor < 1.0:
+ logger.warning(f"`rope_scaling`'s factor field must be a float >= 1, got {factor}")
+
+ attention_factor = rope_scaling.get("attention_factor")
+ if attention_factor is not None and (not isinstance(attention_factor, float) or attention_factor < 0):
+ logger.warning(
+ f"`rope_scaling`'s attention_factor field must be a float greater than 0, got {attention_factor}"
+ )
+ beta_fast = rope_scaling.get("beta_fast")
+ if beta_fast is not None and not isinstance(beta_fast, float):
+ logger.warning(f"`rope_scaling`'s beta_fast field must be a float, got {beta_fast}")
+ beta_slow = rope_scaling.get("beta_slow")
+ if beta_slow is not None and not isinstance(beta_slow, float):
+ logger.warning(f"`rope_scaling`'s beta_slow field must be a float, got {beta_slow}")
+
+ if (beta_fast or 32) < (beta_slow or 1):
+ logger.warning(
+ f"`rope_scaling`'s beta_fast field must be greater than beta_slow, got beta_fast={beta_fast} "
+ f"(defaults to 32 if None) and beta_slow={beta_slow} (defaults to 1 if None)"
+ )
+
+
+def _validate_longrope_parameters(config: PretrainedConfig):
+ rope_scaling = config.rope_scaling
+ rope_type = rope_scaling.get("rope_type", rope_scaling.get("type", None)) # BC: "rope_type" was originally "type"
+ required_keys = {"rope_type", "short_factor", "long_factor"}
+ # TODO (joao): update logic for the inclusion of `original_max_position_embeddings`
+ optional_keys = {"attention_factor", "factor", "original_max_position_embeddings"}
+ received_keys = set(rope_scaling.keys())
+ _check_received_keys(rope_type, received_keys, required_keys, optional_keys)
+
+ partial_rotary_factor = config.partial_rotary_factor if hasattr(config, "partial_rotary_factor") else 1.0
+ head_dim = getattr(config, "head_dim", config.hidden_size // config.num_attention_heads)
+ dim = int(head_dim * partial_rotary_factor)
+
+ short_factor = rope_scaling.get("short_factor")
+ if not (isinstance(short_factor, list) and all(isinstance(x, (int, float)) for x in short_factor)):
+ logger.warning(f"`rope_scaling`'s short_factor field must be a list of numbers, got {short_factor}")
+ if not len(short_factor) == dim // 2:
+ logger.warning(f"`rope_scaling`'s short_factor field must have length {dim // 2}, got {len(short_factor)}")
+
+ long_factor = rope_scaling.get("long_factor")
+ if not (isinstance(long_factor, list) and all(isinstance(x, (int, float)) for x in long_factor)):
+ logger.warning(f"`rope_scaling`'s long_factor field must be a list of numbers, got {long_factor}")
+ if not len(long_factor) == dim // 2:
+ logger.warning(f"`rope_scaling`'s long_factor field must have length {dim // 2}, got {len(long_factor)}")
+
+ # Handle Phi3 divergence: prefer the use of `attention_factor` and/or `factor` over
+ # `original_max_position_embeddings` to compute internal variables. The latter lives outside `rope_scaling` and is
+ # unique to longrope (= undesirable)
+ if hasattr(config, "original_max_position_embeddings"):
+ logger.warning_once(
+ "This model has set an `original_max_position_embeddings` field, to be used together with "
+ "`max_position_embeddings` to determine a scaling factor. Please set the `factor` field of `rope_scaling` "
+ "with this ratio instead -- we recommend the use of this field over `original_max_position_embeddings`, "
+ "as it is compatible with most model architectures."
+ )
+ else:
+ factor = rope_scaling.get("factor")
+ if factor is None:
+ logger.warning("Missing required keys in `rope_scaling`: 'factor'")
+ elif not isinstance(factor, float) or factor < 1.0:
+ logger.warning(f"`rope_scaling`'s factor field must be a float >= 1, got {factor}")
+
+ attention_factor = rope_scaling.get("attention_factor")
+ if attention_factor is not None and (not isinstance(attention_factor, float) or attention_factor < 0):
+ logger.warning(
+ f"`rope_scaling`'s attention_factor field must be a float greater than 0, got {attention_factor}"
+ )
+
+
+def _validate_llama3_parameters(config: PretrainedConfig):
+ rope_scaling = config.rope_scaling
+ rope_type = rope_scaling.get("rope_type", rope_scaling.get("type", None)) # BC: "rope_type" was originally "type"
+ required_keys = {"rope_type", "factor", "original_max_position_embeddings", "low_freq_factor", "high_freq_factor"}
+ received_keys = set(rope_scaling.keys())
+ _check_received_keys(rope_type, received_keys, required_keys)
+
+ factor = rope_scaling["factor"]
+ if factor is None or not isinstance(factor, float) or factor < 1.0:
+ logger.warning(f"`rope_scaling`'s factor field must be a float >= 1, got {factor}")
+
+ low_freq_factor = rope_scaling["low_freq_factor"]
+ high_freq_factor = rope_scaling["high_freq_factor"]
+ if low_freq_factor is None or not isinstance(low_freq_factor, float):
+ logger.warning(f"`rope_scaling`'s low_freq_factor field must be a float, got {low_freq_factor}")
+ if high_freq_factor is None or not isinstance(high_freq_factor, float):
+ logger.warning(f"`rope_scaling`'s high_freq_factor field must be a float, got {high_freq_factor}")
+ if high_freq_factor <= low_freq_factor:
+ logger.warning(
+ "`rope_scaling`'s high_freq_factor field must be greater than low_freq_factor, got high_freq_factor="
+ f"{high_freq_factor} and low_freq_factor={low_freq_factor}"
+ )
+
+ original_max_position_embeddings = rope_scaling["original_max_position_embeddings"]
+ if original_max_position_embeddings is None or not isinstance(original_max_position_embeddings, int):
+ logger.warning(
+ "`rope_scaling`'s original_max_position_embeddings field must be an integer, got "
+ f"{original_max_position_embeddings}"
+ )
+ if original_max_position_embeddings >= config.max_position_embeddings:
+ logger.warning(
+ "`rope_scaling`'s original_max_position_embeddings field must be less than max_position_embeddings, got "
+ f"{original_max_position_embeddings} and max_position_embeddings={config.max_position_embeddings}"
+ )
+
+
+# Like `ROPE_INIT_FUNCTIONS`, this validation function mapping can be dynamically updated for custom RoPE types.
+ROPE_VALIDATION_FUNCTIONS = {
+ "default": _validate_default_rope_parameters,
+ "linear": _validate_linear_scaling_rope_parameters,
+ "dynamic": _validate_dynamic_scaling_rope_parameters,
+ "yarn": _validate_yarn_parameters,
+ "longrope": _validate_longrope_parameters,
+ "llama3": _validate_llama3_parameters,
+}
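+
+# Illustrative sketch only (hypothetical "my_ntk" type, mirroring the init-function example above):
+#
+#     def _validate_my_ntk_parameters(config):
+#         rope_scaling = config.rope_scaling
+#         rope_type = rope_scaling.get("rope_type", rope_scaling.get("type", None))
+#         _check_received_keys(rope_type, set(rope_scaling.keys()), {"rope_type", "factor"})
+#
+#     ROPE_VALIDATION_FUNCTIONS["my_ntk"] = _validate_my_ntk_parameters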
+
+
+def rope_config_validation(config: PretrainedConfig):
+ """
+ Validate the RoPE config arguments, given a `PretrainedConfig` object
+ """
+ rope_scaling = getattr(config, "rope_scaling", None) # not a default parameter in `PretrainedConfig`
+ if rope_scaling is None:
+ return
+
+ # BC: "rope_type" was originally "type"
+ rope_type = rope_scaling.get("rope_type", rope_scaling.get("type", "default"))
+ validation_fn = ROPE_VALIDATION_FUNCTIONS.get(rope_type)
+ if validation_fn is not None:
+ validation_fn(config)
+ else:
+ logger.warning(
+ f"Missing validation function mapping in `ROPE_VALIDATION_FUNCTIONS` for 'rope_type'='{rope_type}'"
+ )
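+
+
+# Illustrative usage sketch (the values below are hypothetical, matching the llama3 defaults noted
+# in `_compute_llama3_parameters`):
+#
+#     config.rope_scaling = {
+#         "rope_type": "llama3",
+#         "factor": 8.0,
+#         "low_freq_factor": 1.0,
+#         "high_freq_factor": 4.0,
+#         "original_max_position_embeddings": 8192,
+#     }
+#     rope_config_validation(config)  # raises KeyError on missing keys, logs warnings on malformed values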
diff --git a/transformers_4_44_2__pytorch_utils.py b/transformers_4_44_2__pytorch_utils.py
new file mode 100644
index 0000000000000000000000000000000000000000..7f0c181f013fcd83f06e52d531c48203c0762a88
--- /dev/null
+++ b/transformers_4_44_2__pytorch_utils.py
@@ -0,0 +1,17 @@
+# Copyright 2022 The HuggingFace Team. All rights reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+from torch import nn
+
+ALL_LAYERNORM_LAYERS = [nn.LayerNorm]
\ No newline at end of file
diff --git a/variable_cache.py b/variable_cache.py
new file mode 100644
index 0000000000000000000000000000000000000000..3164b81591fa9c8439f160fedd8f9814092f791e
--- /dev/null
+++ b/variable_cache.py
@@ -0,0 +1,139 @@
+# coding=utf-8
+# Copyright 2024 Nvidia Corporation. All rights reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+from copy import deepcopy
+from typing import Optional, Dict, Any, Tuple
+
+import torch
+from transformers.cache_utils import Cache # used to let GenerationMixin know that we use a Cache object
+
+from .configuration_decilm import DeciLMConfig
+from .transformers_4_44_2__cache_utils import Cache as Cache_4_44_2, SinkCache, StaticCache, SlidingWindowCache
+
+
+class VariableCache(Cache_4_44_2, Cache):
+ """
+ A Cache object that supports a different Cache implementation for every layer,
+ including layers without any kv-cache.
+ Implemented as a list of Cache objects, each of which represents a single-layer "model".
+ The default implementation for the layer caches is StaticCache.
+ The cache of each layer is allocated on the same GPU as the layer itself.
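+
+ Example (illustrative sketch; `model` and `tokenizer` are assumed to be an already-loaded
+ DeciLM checkpoint and its tokenizer, with `generate` accepting a Cache via `past_key_values`):
+
+ ```python
+ cache = VariableCache(config=model.config, max_batch_size=1, max_cache_len=4096, dtype=torch.bfloat16)
+ inputs = tokenizer("Hello", return_tensors="pt").to(model.device)
+ outputs = model.generate(**inputs, past_key_values=cache, max_new_tokens=16)
+ ```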
+ """
+
+ def __init__(
+ self,
+ *, # key-word only, no positional args allowed to avoid mix-ups with newer transformers versions
+ config: DeciLMConfig,
+ batch_size: Optional[int] = None,
+ max_cache_len: Optional[int] = None,
+ dtype: torch.dtype = torch.float32,
+ max_batch_size: Optional[int] = None,
+ **kwargs,
+ ) -> None:
+ Cache_4_44_2.__init__(self)
+
+ self.config = deepcopy(config)
+ self.max_batch_size = batch_size or max_batch_size
+ self.batch_size = self.max_batch_size
+ self.max_cache_len = config.max_position_embeddings if max_cache_len is None else max_cache_len
+ self.dtype = dtype
+
+ self.layer_caches: list[Cache_4_44_2 | None] = [None] * config.num_hidden_layers
+ self.layer_devices: list[torch.device | None] = [None] * config.num_hidden_layers
+
+ def update(
+ self,
+ key_states: torch.Tensor,
+ value_states: torch.Tensor,
+ layer_idx: int,
+ cache_kwargs: Optional[Dict[str, Any]] = None,
+ ) -> Tuple[torch.Tensor, torch.Tensor]:
+ if self.layer_caches[layer_idx] is None:
+ self.layer_devices[layer_idx] = key_states.device
+ self._init_layer_cache(layer_idx)
+
+ layer_cache = self.layer_caches[layer_idx]
+ assert layer_cache is not None, f"Trying to update the cache of a cache-less layer: {layer_idx=}"
+
+ k_out, v_out = layer_cache.update(key_states=key_states,
+ value_states=value_states,
+ layer_idx=0,
+ cache_kwargs=cache_kwargs)
+ seq_len = self.get_seq_length(layer_idx)
+ k_out = k_out[:, :, :seq_len, :]
+ v_out = v_out[:, :, :seq_len, :]
+ return k_out, v_out
+
+ def _init_layer_cache(self, layer_idx: int) -> None:
+ block_config = self.config.block_configs[layer_idx]
+ attention_config = block_config.attention
+
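+ # Attention-free layers (no-op attention, or attention replaced with a linear layer) get no kv-cache.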
+ if attention_config.no_op or attention_config.replace_with_linear:
+ return None
+
+ device = self.layer_devices[layer_idx]
+ assert device is not None, f"Trying to init layer cache for {layer_idx=} without device"
+
+ config = deepcopy(self.config)
+ config.num_hidden_layers = 1
+ config.num_key_value_heads = self.config.num_attention_heads // attention_config.n_heads_in_group
+
+ if attention_config.window_length is not None:
+ if not attention_config.is_sink:
+ config.sliding_window = attention_config.window_length
+ self.layer_caches[layer_idx] = SlidingWindowCache(config=config,
+ max_batch_size=self.max_batch_size,
+ max_cache_len=self.max_cache_len,
+ device=device,
+ dtype=self.dtype)
+ return
+ elif not attention_config.unshifted_sink:
+ self.layer_caches[layer_idx] = SinkCache(window_length=attention_config.window_length,
+ num_sink_tokens=attention_config.num_sink_tokens)
+ return
+
+ self.layer_caches[layer_idx] = StaticCache(config=config,
+ max_batch_size=self.max_batch_size,
+ max_cache_len=self.max_cache_len,
+ device=device,
+ dtype=self.dtype)
+
+ def _get_first_real_cache(self) -> Cache:
+ for layer_cache in self.layer_caches:
+ if layer_cache is not None:
+ return layer_cache
+ raise ValueError("No real cache found, all layer caches are None.")
+
+ def get_seq_length(self, layer_idx: Optional[int] = 0) -> int:
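+ # Layer 0 may be cache-less (e.g. a no-op attention block); fall back to the first layer that holds a real cache.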
+ if layer_idx == 0 and self.layer_caches[0] is None:
+ try:
+ layer_cache = self._get_first_real_cache()
+ except ValueError:
+ return 0
+ else:
+ layer_cache = self.layer_caches[layer_idx]
+ return layer_cache.get_seq_length()
+
+ def get_max_length(self) -> Optional[int]:
+ """Returns the maximum sequence length of the cached states."""
+ return self.max_cache_len
+
+ def reset(self):
+ for layer_idx in range(len(self.layer_caches)):
+ layer_cache = self.layer_caches[layer_idx]
+ if hasattr(layer_cache, "reset"):
+ layer_cache.reset()
+ else:
+ self._init_layer_cache(layer_idx)