Spaces:
Runtime error
Runtime error
| # Creates a directory in which to look up available agents | |
| import os | |
| from typing import List, Optional | |
| from src.simuleval_transcoder import SimulevalTranscoder | |
| import json | |
| import logging | |
| logger = logging.getLogger("socketio_server_pubsub") | |
# Target language codes offered for the default (non-override) agent.
# Kept as a list, in this exact order, because it is passed through as
# `target_langs` and serialized for clients.
# fmt: off
M4T_P0_LANGS = [
    "eng", "arb", "ben", "cat", "ces", "cmn",
    "cym", "dan", "deu", "est", "fin", "fra",
    "hin", "ind", "ita", "jpn", "kor", "mlt",
    "nld", "pes", "pol", "por", "ron", "rus",
    "slk", "spa", "swe", "swh", "tel", "tgl",
    "tha", "tur", "ukr", "urd", "uzn", "vie",
]
# fmt: on
class NoAvailableAgentException(Exception):
    """Raised when a lookup by name finds no registered agent."""
class AgentWithInfo:
    """Pair a built agent with the metadata describing what it supports.

    The constructor stores its arguments as public attributes of the same
    name; `get_capabilities_for_json` exposes the serializable subset.
    """

    def __init__(
        self,
        agent,
        name: str,
        modalities: List[str],
        target_langs: List[str],
        # Supported dynamic params are defined in StreamingTypes.ts
        dynamic_params: Optional[List[str]] = None,
        description="",
        has_expressive: Optional[bool] = None,
    ):
        """Store the agent and its descriptive metadata.

        Args:
            agent: The built agent object (opaque to this class).
            name: Identifier used to look the agent up.
            modalities: Supported modalities, e.g. ``["s2t", "s2s"]``.
            target_langs: Supported target language codes.
            dynamic_params: Names of supported dynamic params; defaults to
                a new empty list.
            description: Free-text description of the model.
            has_expressive: Whether the agent is the expressive variant, or
                None when unspecified.
        """
        self.agent = agent
        self.has_expressive = has_expressive
        self.name = name
        self.description = description
        self.modalities = modalities
        self.target_langs = target_langs
        # A `[]` default in the signature would be one list object shared by
        # every instance created without the argument; build a fresh one.
        self.dynamic_params = [] if dynamic_params is None else dynamic_params

    def get_capabilities_for_json(self):
        """Return the JSON-serializable capabilities dict for this agent.

        `agent` and `has_expressive` are intentionally not part of the
        returned mapping.
        """
        return {
            "name": self.name,
            "description": self.description,
            "modalities": self.modalities,
            "targetLangs": self.target_langs,
            "dynamicParams": self.dynamic_params,
        }

    @classmethod
    def load_from_json(cls, config: str):
        """Build one instance per model described in a JSON array string.

        Example input::

            [{"name": "s2s_m4t_emma-unity2_multidomain_v0.1",
              "description": "M4T model that supports simultaneous S2S and S2T",
              "modalities": ["s2t", "s2s"], "targetLangs": ["en"]},
             {"name": "s2s_m4t_expr-emma_v0.1",
              "description": "ES-EN expressive model that supports S2S and S2T",
              "modalities": ["s2t", "s2s"], "targetLangs": ["en"]}]

        Each entry must have ``name``, ``modalities`` and ``targetLangs``;
        ``description`` is optional and defaults to ``""``.

        Args:
            config: JSON text encoding a list of model entries.

        Returns:
            A list of instances, in the same order as the JSON array.

        Raises:
            json.JSONDecodeError: If `config` is not valid JSON.
            KeyError: If an entry lacks a required key.
        """
        model_configs = json.loads(config)
        return [
            cls(
                # The model name doubles as the identifier handed to the builder.
                agent=SimulevalTranscoder.build_agent(model_config["name"]),
                name=model_config["name"],
                modalities=model_config["modalities"],
                target_langs=model_config["targetLangs"],
                description=model_config.get("description", ""),
            )
            for model_config in model_configs
        ]
class SimulevalAgentDirectory:
    """Registry of available agents, populated once and queried by name."""

    # Available models. These are the directories where the models can be found, and also serve as an ID for the model.
    seamless_streaming_agent = "SeamlessStreaming"
    seamless_agent = "Seamless"

    def __init__(self):
        """Start with no agents and the one-time build not yet performed."""
        self.agents = []
        self.did_build_and_add_agents = False

    def add_agent(self, agent: "AgentWithInfo"):
        """Append `agent` to the registry (no de-duplication is done)."""
        self.agents.append(agent)

    def build_agent_if_available(self, model_id, config_name=None):
        """Build the agent for `model_id`, logging and re-raising on failure.

        Args:
            model_id: Model identifier passed to the agent builder.
            config_name: Optional config file name; forwarded only when not
                None so the builder's own default applies otherwise.

        Returns:
            Whatever `SimulevalTranscoder.build_agent` returns.

        Raises:
            Exception: Any error raised by the builder is logged and then
                re-raised unchanged.
        """
        build_kwargs = {}
        if config_name is not None:
            build_kwargs["config_name"] = config_name

        try:
            return SimulevalTranscoder.build_agent(model_id, **build_kwargs)
        except Exception as e:
            # Log before touching fairseq2: if that import fails, the real
            # cause must still be reported and must still be what propagates.
            logger.warning("Failed to build agent %s: %s", model_id, e)
            try:
                from fairseq2.assets.error import AssetError
            except ImportError:
                AssetError = None
            if AssetError is not None and isinstance(e, AssetError):
                logger.warning(
                    "Please download gated assets and set `gated_model_dir` in the config"
                )
            raise

    def build_and_add_agents(self, models_override=None):
        """Populate the registry once; later calls return immediately.

        Args:
            models_override: Optional JSON string describing the models to
                load (see `AgentWithInfo.load_from_json`). When None, a single
                default agent is built; the ``USE_EXPRESSIVE_MODEL``
                environment variable (``"1"`` to enable) selects the variant.
        """
        if self.did_build_and_add_agents:
            return

        if models_override is not None:
            for agent_info in AgentWithInfo.load_from_json(models_override):
                self.add_agent(agent_info)
        else:
            has_expressive = os.environ.get("USE_EXPRESSIVE_MODEL", "0") == "1"
            if has_expressive:
                logger.info("Building expressive model...")
                s2s_agent = self.build_agent_if_available(
                    SimulevalAgentDirectory.seamless_agent,
                    config_name="vad_s2st_sc_24khz_main.yaml",
                )
            else:
                logger.info("Building non-expressive model...")
                s2s_agent = self.build_agent_if_available(
                    SimulevalAgentDirectory.seamless_streaming_agent,
                    config_name="vad_s2st_sc_main.yaml",
                )

            if s2s_agent:
                # NOTE(review): both variants are registered under the
                # SeamlessStreaming name and description — confirm intended.
                self.add_agent(
                    AgentWithInfo(
                        agent=s2s_agent,
                        name=SimulevalAgentDirectory.seamless_streaming_agent,
                        modalities=["s2t", "s2s"],
                        target_langs=M4T_P0_LANGS,
                        dynamic_params=["expressive"],
                        description="multilingual expressive model that supports S2S and S2T",
                        has_expressive=has_expressive,
                    )
                )

        if not self.agents:
            logger.error(
                "No agents were loaded. This likely means you are missing the actual model files specified in simuleval_agent_directory."
            )

        self.did_build_and_add_agents = True

    def get_agent(self, name):
        """Return the first registered agent named `name`, or None."""
        return next((agent for agent in self.agents if agent.name == name), None)

    def get_agent_or_throw(self, name):
        """Return the agent named `name`.

        Raises:
            NoAvailableAgentException: If no registered agent has that name.
        """
        agent = self.get_agent(name)
        if agent is None:
            # One-element tuple so a tuple-valued `name` is formatted, not unpacked.
            raise NoAvailableAgentException("No agent found with name= %s" % (name,))
        return agent

    def get_agents_capabilities_list_for_json(self):
        """Return each agent's capabilities dict, in registration order."""
        return [agent.get_capabilities_for_json() for agent in self.agents]