# Source provenance (scraped page residue): uploaded by LeonardoBerti,
# "Upload 51 files", commit 69524d0 (verified).
from models.mlplob import MLPLOB
from models.tlob import TLOB
from models.binctabl import BiN_CTABL
from models.deeplob import DeepLOB
from transformers import AutoModelForSeq2SeqLM
def pick_model(model_type: str, hidden_dim, num_layers, seq_size, num_features, num_heads=8, is_sin_emb=False, dataset_type=None):
    """Instantiate the model architecture named by ``model_type``.

    Args:
        model_type: One of ``"MLPLOB"``, ``"TLOB"``, ``"BINCTABL"`` or
            ``"DEEPLOB"`` (exact, case-sensitive match).
        hidden_dim: Forwarded to MLPLOB and TLOB only.
        num_layers: Forwarded to MLPLOB and TLOB only.
        seq_size: Forwarded to MLPLOB and TLOB; passed twice to BiN_CTABL.
        num_features: Forwarded to MLPLOB, TLOB and BiN_CTABL.
        num_heads: Forwarded to TLOB only.
        is_sin_emb: Forwarded to TLOB only.
        dataset_type: Forwarded to MLPLOB and TLOB only.

    Returns:
        A freshly constructed instance of the selected model class.

    Raises:
        ValueError: If ``model_type`` is not one of the supported names. The
            message names the rejected value and lists the accepted ones.
    """
    supported = ("MLPLOB", "TLOB", "BINCTABL", "DEEPLOB")
    if model_type == "MLPLOB":
        return MLPLOB(hidden_dim, num_layers, seq_size, num_features, dataset_type)
    elif model_type == "TLOB":
        return TLOB(hidden_dim, num_layers, seq_size, num_features, num_heads, is_sin_emb, dataset_type)
    elif model_type == "BINCTABL":
        # NOTE(review): hidden_dim, num_layers, num_heads, is_sin_emb and
        # dataset_type are ignored here. The literals 60, 120, 5, 3, 1 are
        # fixed architecture settings whose meaning is defined by
        # BiN_CTABL's constructor (not visible in this file). Confirm
        # before parameterizing them.
        return BiN_CTABL(60, num_features, seq_size, seq_size, 120, 5, 3, 1)
    elif model_type == "DEEPLOB":
        # DeepLOB is constructed with no arguments; every parameter of
        # pick_model is ignored for this architecture.
        return DeepLOB()
    else:
        # Name the rejected value and the valid choices so a typo or
        # misconfigured option is easy to diagnose.
        raise ValueError(
            f"Model not found: {model_type!r} "
            f"(expected one of: {', '.join(supported)})"
        )