Spaces:
Sleeping
Sleeping
File size: 1,506 Bytes
d769fbc |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 |
from wrappers import LSTMWrapper, XGBWrapper, CNNWrapper
import joblib
from tensorflow.keras.models import load_model
def load_model_by_type(model_path):
    """Load a saved model file and wrap it in the matching project wrapper.

    The wrapper is chosen from the file extension and, for ``.h5`` files,
    from an architecture tag in the path:

    * ``.h5`` whose path contains ``lstm_multi`` -> ``LSTMWrapper``
    * ``.h5`` whose path contains ``cnn_multi``  -> ``CNNWrapper``
    * ``.pkl``                                   -> ``XGBWrapper`` (via ``joblib.load``)

    Args:
        model_path: Location of the model file, as a ``pathlib.Path`` or a
            string (strings are converted to ``Path``).

    Returns:
        The wrapper instance around the loaded model.

    Raises:
        ValueError: If the extension is neither ``.h5`` nor ``.pkl``, or if
            an ``.h5`` path carries no recognised architecture tag.
    """
    # Local import: at module level this file only imports Path under __main__.
    from pathlib import Path

    path = Path(model_path)

    if path.suffix == ".h5":
        # Look at the file name before the full path so that a parent
        # directory (e.g. "runs_lstm_multi/cnn_multi_model.h5") cannot
        # override the model's own name; the full-path fallback keeps
        # layouts such as "lstm_multi/model.h5" working as before.
        for haystack in (path.name, str(path)):
            if "lstm_multi" in haystack:
                return LSTMWrapper(load_model(path))
            if "cnn_multi" in haystack:
                return CNNWrapper(load_model(path))
        # NOTE(review): encoder_from_model also knows lstm_binary_model.h5 and
        # cnn_binary_model.h5, which this branch rejects -- confirm whether
        # LSTMWrapper/CNNWrapper support binary models before adding them.
        raise ValueError(
            f"Unsupported model type: cannot infer architecture of {path}"
        )

    if path.suffix == ".pkl":
        # NOTE(review): every .pkl is treated as an XGBoost model, including
        # label-encoding pickles. joblib.load unpickles the file, so only
        # load trusted files here.
        return XGBWrapper(joblib.load(path))

    raise ValueError(f"Unsupported model type: {path}")
# Model file name -> file name of its label encoding.
_LABEL_ENCODING_BY_MODEL = {
    "cnn_multi_model.h5": "cnn_multi_label_encoding.pkl",
    "lstm_multi_model.h5": "lstm_multi_label_encoding.pkl",
    "pca_xgboost_multi_model.pkl": "pca_xgboost_multi_label_encoding.pkl",
    "cnn_binary_model.h5": "cnn_binary_label_encoding.pkl",
    "lstm_binary_model.h5": "lstm_binary_label_encoding.pkl",
    "pca_xgboost_binary_model.pkl": "pca_xgboost_binary_label_encoding.pkl",
}


def encoder_from_model(model_name):
    """Return the label-encoding file name paired with a model file name.

    Args:
        model_name: Bare model file name such as ``"lstm_multi_model.h5"``.
            The lookup is an exact key match, so full paths or ``Path``
            objects are not recognised.

    Returns:
        The paired label-encoding file name, e.g.
        ``"lstm_multi_label_encoding.pkl"``.

    Raises:
        ValueError: If ``model_name`` is not one of the known model files.
    """
    try:
        return _LABEL_ENCODING_BY_MODEL[model_name]
    except (KeyError, TypeError):
        # TypeError covers unhashable inputs, which the previous comparison
        # chain also rejected with ValueError.
        raise ValueError(f"Unsupported model name: {model_name!r}") from None
if __name__ == "__main__":
    # Manual smoke check: load the multi-class LSTM model from the
    # "models" directory three levels above this file.
    from pathlib import Path

    package_root = Path(__file__).parent.parent.parent
    model_path = package_root.joinpath("models", "lstm_multi_model.h5")
    load_model_by_type(model_path)