# lyraLLaMA / lyra_llama/config.py
# (Hugging Face file-viewer residue removed; provenance: author yibolu,
#  commit dae7d0f, 979 bytes)
import dataclasses
from typing import Optional
@dataclasses.dataclass
class LyraLLaMAParam:
    """Model and runtime hyperparameters for the Lyra LLaMA engine.

    Defaults describe a 40-layer, 40-head x 128-dim model with a 13824
    FFN size -- a 13B-class LLaMA configuration, presumably; verify
    against the actual checkpoint.

    Raises:
        ValueError: if ``shared_contexts_ratio`` is outside [0.0, 1.0]
            (validated in ``__post_init__``).
    """

    num_heads: int = 40             # attention heads per layer
    size_per_head: int = 128        # dimension of each attention head
    inter_size: int = 13824         # FFN intermediate size
    num_layers: int = 40            # transformer layer count
    vocab_size: int = 39424         # tokenizer vocabulary size
    start_id: Optional[int] = 1     # start token id (1/2 match the conventional LLaMA BOS/EOS -- confirm)
    end_id: Optional[int] = 2       # end token id
    tensor_para_size: int = 1       # tensor-parallel degree
    pipeline_para_size: int = 1     # pipeline-parallel degree
    remove_padding: bool = True     # engine-side padding removal toggle
    shared_contexts_ratio: float = 1.0  # must lie in [0.0, 1.0]; validated below
    layernorm_eps: float = 1e-6     # layer-norm epsilon
    weights_data_type: str = "fp16" # on-disk/in-memory weight precision
    rotary_embedding: int = 128     # rotary embedding dimension
    use_gptj_residual: bool = False # GPT-J-style parallel residual toggle

    def __post_init__(self):
        # Fail fast at construction time on an out-of-range ratio.
        if not 0.0 <= self.shared_contexts_ratio <= 1.0:
            raise ValueError(
                # Bug fix: message previously said 'shared_context_ratio'
                # (missing an 's'), which did not match the field name.
                f'Got an invalid value of shared_contexts_ratio '
                f'{self.shared_contexts_ratio} - range: [0.0, 1.0]')

    def asdict(self):
        """Return all parameters as a plain dict via dataclasses.asdict."""
        return dataclasses.asdict(self)
# Module-level default parameter set (all dataclass defaults).
LYRA_LLAMA_PARAM = LyraLLaMAParam()
# Path to the prebuilt engine shared library; the filename suggests an
# SM80 (Ampere) / CUDA 11 build. NOTE(review): hard-coded absolute path --
# assumes the /app/LyraLLaMAPy container layout; confirm for other deployments.
LIB_SO_PATH = '/app/LyraLLaMAPy/ftlib/libth_transformer_sm80_cu11.so'