Spaces:
Runtime error
Runtime error
File size: 2,758 Bytes
bb5a96d |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 |
from typing import Dict
from src.models import (
lcnn,
specrnet,
whisper_specrnet,
rawnet3,
whisper_lcnn,
meso_net,
whisper_meso_net
)
def get_model(model_name: str, config: Dict, device: str):
    """Construct the model registered under ``model_name``.

    Supported names: ``rawnet3``, ``lcnn``, ``specrnet``, ``mesonet``,
    ``whisper_lcnn``, ``whisper_specrnet``, ``whisper_mesonet``,
    ``whisper_frontend_lcnn``, ``whisper_frontend_specrnet`` and
    ``whisper_frontend_mesonet``.

    How ``config`` is consumed depends on the model:

    * ``rawnet3`` -- ``rawnet3.prepare_model()`` is called with no arguments;
      both ``config`` and ``device`` are ignored.
    * ``lcnn`` / ``specrnet`` -- every entry of ``config`` is forwarded as a
      keyword argument alongside ``device``.
    * all other names -- only specific keys are read via ``config.get`` with a
      per-model default (e.g. ``input_channels`` defaults to 1 for the
      single-front models and 2 for the ``whisper_frontend_*`` models;
      ``freeze_encoder`` defaults to True only for the MesoNet whisper
      variants).

    Args:
        model_name: Registry key selecting the model to build.
        config: Model hyper-parameters (see above for how they are used).
        device: Passed as the ``device`` keyword to every constructor except
            ``rawnet3.prepare_model``.

    Returns:
        Whatever the selected constructor returns.

    Raises:
        ValueError: If ``model_name`` matches none of the supported names.
    """
    # NOTE(review): rawnet3 receives neither config nor device here -- confirm
    # that prepare_model() handles configuration and device placement itself.

    # Each entry pairs a registry name with a zero-argument builder. Builders
    # are lambdas so that neither the model modules nor ``config`` are touched
    # unless that particular name is selected.
    registry = (
        ("rawnet3", lambda: rawnet3.prepare_model()),
        ("lcnn", lambda: lcnn.FrontendLCNN(device=device, **config)),
        (
            "specrnet",
            lambda: specrnet.FrontendSpecRNet(device=device, **config),
        ),
        (
            "mesonet",
            lambda: meso_net.FrontendMesoInception4(
                input_channels=config.get("input_channels", 1),
                fc1_dim=config.get("fc1_dim", 1024),
                frontend_algorithm=config.get("frontend_algorithm", "lfcc"),
                device=device,
            ),
        ),
        (
            "whisper_lcnn",
            lambda: whisper_lcnn.WhisperLCNN(
                input_channels=config.get("input_channels", 1),
                freeze_encoder=config.get("freeze_encoder", False),
                device=device,
            ),
        ),
        (
            "whisper_specrnet",
            lambda: whisper_specrnet.WhisperSpecRNet(
                input_channels=config.get("input_channels", 1),
                freeze_encoder=config.get("freeze_encoder", False),
                device=device,
            ),
        ),
        (
            "whisper_mesonet",
            lambda: whisper_meso_net.WhisperMesoNet(
                input_channels=config.get("input_channels", 1),
                freeze_encoder=config.get("freeze_encoder", True),
                fc1_dim=config.get("fc1_dim", 1024),
                device=device,
            ),
        ),
        (
            "whisper_frontend_lcnn",
            lambda: whisper_lcnn.WhisperMultiFrontLCNN(
                input_channels=config.get("input_channels", 2),
                freeze_encoder=config.get("freeze_encoder", False),
                frontend_algorithm=config.get("frontend_algorithm", "lfcc"),
                device=device,
            ),
        ),
        (
            "whisper_frontend_specrnet",
            lambda: whisper_specrnet.WhisperMultiFrontSpecRNet(
                input_channels=config.get("input_channels", 2),
                freeze_encoder=config.get("freeze_encoder", False),
                frontend_algorithm=config.get("frontend_algorithm", "lfcc"),
                device=device,
            ),
        ),
        (
            "whisper_frontend_mesonet",
            lambda: whisper_meso_net.WhisperMultiFrontMesoNet(
                input_channels=config.get("input_channels", 2),
                fc1_dim=config.get("fc1_dim", 1024),
                freeze_encoder=config.get("freeze_encoder", True),
                frontend_algorithm=config.get("frontend_algorithm", "lfcc"),
                device=device,
            ),
        ),
    )

    # A linear scan with ``==`` (instead of a dict lookup) keeps the exact
    # comparison semantics of an if/elif chain: names are tried in order, and
    # an unhashable ``model_name`` still ends in ValueError, not TypeError.
    for registered_name, build in registry:
        if model_name == registered_name:
            return build()

    raise ValueError(f"Model '{model_name}' not supported")
|