|
from models.mlplob import MLPLOB
|
|
from models.tlob import TLOB
|
|
from models.binctabl import BiN_CTABL
|
|
from models.deeplob import DeepLOB
|
|
from transformers import AutoModelForSeq2SeqLM
|
|
|
|
|
|
def pick_model(model_type, hidden_dim, num_layers, seq_size, num_features, num_heads=8, is_sin_emb=False, dataset_type=None):
    """Build the model that corresponds to ``model_type``.

    ``model_type`` is compared by exact, case-sensitive string equality
    against "MLPLOB", "TLOB", "BINCTABL" and "DEEPLOB".

    Args:
        model_type: Name selecting which model class to instantiate.
        hidden_dim: Forwarded to MLPLOB and TLOB; ignored by the others.
        num_layers: Forwarded to MLPLOB and TLOB; ignored by the others.
        seq_size: Forwarded to MLPLOB and TLOB, and passed twice to BiN_CTABL.
        num_features: Forwarded to MLPLOB, TLOB and BiN_CTABL.
        num_heads: Forwarded to TLOB only.
        is_sin_emb: Forwarded to TLOB only.
        dataset_type: Forwarded to MLPLOB and TLOB only.

    Returns:
        A freshly constructed instance of the selected model class.

    Raises:
        ValueError: If ``model_type`` is not one of the supported names.
    """
    if model_type == "MLPLOB":
        return MLPLOB(hidden_dim, num_layers, seq_size, num_features, dataset_type)

    if model_type == "TLOB":
        return TLOB(hidden_dim, num_layers, seq_size, num_features, num_heads, is_sin_emb, dataset_type)

    if model_type == "BINCTABL":
        # The remaining BiN_CTABL arguments are fixed constants here; only
        # num_features and seq_size come from the caller.
        # NOTE(review): the meaning of 60/120/5/3/1 is defined by BiN_CTABL's
        # constructor, which is not visible from this file.
        return BiN_CTABL(60, num_features, seq_size, seq_size, 120, 5, 3, 1)

    if model_type == "DEEPLOB":
        # DeepLOB is built with its own defaults; none of the sizing
        # arguments passed to this function are forwarded to it.
        return DeepLOB()

    # Keep the historical "Model not found" prefix so existing substring
    # matches still work, but report the offending value and valid options.
    raise ValueError(
        f"Model not found: {model_type!r}. "
        "Expected one of: MLPLOB, TLOB, BINCTABL, DEEPLOB"
    )