Update starvector architecture file
Browse files
- starvector_arch.py +4 -3
starvector_arch.py
CHANGED
|
@@ -18,9 +18,10 @@ class StarVectorConfig(PretrainedConfig):
|
|
| 18 |
use_cache: bool = True,
|
| 19 |
num_attention_heads: int = 16,
|
| 20 |
num_hidden_layers: int = 24,
|
| 21 |
-
vocab_size: int =
|
| 22 |
-
hidden_size: int =
|
| 23 |
num_kv_heads: int = 4,
|
|
|
|
| 24 |
**kwargs,
|
| 25 |
):
|
| 26 |
self.starcoder_model_name = starcoder_model_name
|
|
@@ -36,7 +37,7 @@ class StarVectorConfig(PretrainedConfig):
|
|
| 36 |
self.vocab_size = vocab_size
|
| 37 |
self.hidden_size = hidden_size
|
| 38 |
self.num_kv_heads = num_kv_heads
|
| 39 |
-
|
| 40 |
super().__init__(**kwargs)
|
| 41 |
|
| 42 |
class StarVectorForCausalLM(PreTrainedModel):
|
|
|
|
| 18 |
use_cache: bool = True,
|
| 19 |
num_attention_heads: int = 16,
|
| 20 |
num_hidden_layers: int = 24,
|
| 21 |
+
vocab_size: int = 49152,
|
| 22 |
+
hidden_size: int = 2048,
|
| 23 |
num_kv_heads: int = 4,
|
| 24 |
+
torch_dtype: str = "bfloat16",
|
| 25 |
**kwargs,
|
| 26 |
):
|
| 27 |
self.starcoder_model_name = starcoder_model_name
|
|
|
|
| 37 |
self.vocab_size = vocab_size
|
| 38 |
self.hidden_size = hidden_size
|
| 39 |
self.num_kv_heads = num_kv_heads
|
| 40 |
+
self.torch_dtype = torch_dtype
|
| 41 |
super().__init__(**kwargs)
|
| 42 |
|
| 43 |
class StarVectorForCausalLM(PreTrainedModel):
|