subhankarg's picture
Upload folder using huggingface_hub
0558aa4 verified
# Copyright (c) 2025, NVIDIA CORPORATION. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import os
import sys
from pathlib import Path
# Add the local NeMo directory to Python path to use development version
nemo_root = Path(__file__).resolve().parents[3]
sys.path.insert(0, str(nemo_root))
import pytest
from omegaconf import DictConfig, OmegaConf
from pipecat.audio.vad.silero import VADParams
from nemo.agents.voice_agent.pipecat.services.nemo.diar import NeMoDiarInputParams
from nemo.agents.voice_agent.pipecat.services.nemo.stt import NeMoSTTInputParams
from nemo.agents.voice_agent.utils.config_manager import ConfigManager
@pytest.fixture
def voice_agent_server_base_path():
    """Return the path of ``examples/voice_agent/server`` inside the NeMo checkout.

    The repository root is taken to be ``parents[3]`` of this file's resolved
    path. Before building the result, the root is checked for the
    sub-directories listed in ``required``.

    Returns:
        str: ``<root>/examples/voice_agent/server`` joined with ``os.path.join``.

    Raises:
        ValueError: if any of the required sub-directories is absent from the
            candidate root.
    """
    repo_root = Path(__file__).resolve().parents[3]

    # Sanity-check that the candidate directory looks like a NeMo checkout.
    required = ["nemo", "tests", "examples", "requirements"]
    present = [entry.name for entry in repo_root.iterdir() if entry.is_dir()]
    missing = [name for name in required if name not in present]
    if missing:
        raise ValueError(
            f"{repo_root} is not a NeMo root path. Expected dirs: {required}, Found dirs: {present}"
        )

    return os.path.join(repo_root, "examples", "voice_agent", "server")
class TestDefaultConfigs:
    """Tests of ConfigManager against the configuration files in the voice agent server directory.

    Most tests build a fresh ConfigManager from the path supplied by the
    ``voice_agent_server_base_path`` fixture and then check that the resulting
    attributes exist and have the expected types or values. No configuration
    is overridden, so each test observes whatever the checked-in files define.
    """

    @pytest.mark.unit
    def test_constructor_with_valid_path(self, voice_agent_server_base_path):
        """Check the attributes populated by the constructor for a valid base path."""
        config_manager = ConfigManager(voice_agent_server_base_path)

        # The base path is stored unchanged and the audio constants have fixed values.
        assert config_manager._server_base_path == voice_agent_server_base_path
        assert config_manager.SAMPLE_RATE == 16000
        assert config_manager.RAW_AUDIO_FRAME_LEN_IN_SECS == 0.016
        assert isinstance(config_manager.vad_params, VADParams)
        assert isinstance(config_manager.stt_params, NeMoSTTInputParams)
        assert isinstance(config_manager.diar_params, NeMoDiarInputParams)

    @pytest.mark.unit
    def test_constructor_with_invalid_path(self):
        """Check that a non-existent base path raises FileNotFoundError."""
        with pytest.raises(FileNotFoundError):
            ConfigManager("/nonexistent/path")

    @pytest.mark.unit
    def test_load_model_registry_success(self, voice_agent_server_base_path):
        """Check that ``model_registry`` is set and has the LLM, TTS and STT sections."""
        config_manager = ConfigManager(voice_agent_server_base_path)

        assert config_manager.model_registry is not None
        assert "llm_models" in config_manager.model_registry
        assert "tts_models" in config_manager.model_registry
        assert "stt_models" in config_manager.model_registry

    @pytest.mark.unit
    def test_configure_stt_nemo_model(self, voice_agent_server_base_path):
        """Check that the configured STT model path names a FastConformer model."""
        config_manager = ConfigManager(voice_agent_server_base_path)

        assert "stt_en_fastconformer" in config_manager.STT_MODEL_PATH
        assert isinstance(config_manager.stt_params, NeMoSTTInputParams)

    @pytest.mark.unit
    def test_configure_stt_with_model_config(self, voice_agent_server_base_path):
        """Check that ``STT_MODEL_PATH`` is defined on the manager."""
        config_manager = ConfigManager(voice_agent_server_base_path)

        assert hasattr(config_manager, "STT_MODEL_PATH")

    @pytest.mark.unit
    def test_configure_diarization(self, voice_agent_server_base_path):
        """Check the presence and types of the diarization attributes."""
        config_manager = ConfigManager(voice_agent_server_base_path)

        assert hasattr(config_manager, "DIAR_MODEL") and isinstance(config_manager.DIAR_MODEL, str)
        assert hasattr(config_manager, "USE_DIAR") and isinstance(config_manager.USE_DIAR, bool)
        assert isinstance(config_manager.diar_params, NeMoDiarInputParams)

    @pytest.mark.unit
    def test_configure_turn_taking(self, voice_agent_server_base_path):
        """Check the presence and types of the turn-taking attributes."""
        config_manager = ConfigManager(voice_agent_server_base_path)

        assert hasattr(config_manager, "TURN_TAKING_BACKCHANNEL_PHRASES_PATH") and isinstance(
            config_manager.TURN_TAKING_BACKCHANNEL_PHRASES_PATH, str
        )
        assert hasattr(config_manager, "TURN_TAKING_MAX_BUFFER_SIZE") and isinstance(
            config_manager.TURN_TAKING_MAX_BUFFER_SIZE, int
        )
        assert hasattr(config_manager, "TURN_TAKING_BOT_STOP_DELAY") and isinstance(
            config_manager.TURN_TAKING_BOT_STOP_DELAY, float
        )

    @pytest.mark.unit
    def test_configure_turn_taking_backchannel_phrases(self, voice_agent_server_base_path):
        """Check that the backchannel phrases file exists and holds a list of strings."""
        config_manager = ConfigManager(voice_agent_server_base_path)

        # Only the file name of the configured path is used; the file is looked
        # up directly inside the server base directory.
        file_path = os.path.join(
            voice_agent_server_base_path, os.path.basename(config_manager.TURN_TAKING_BACKCHANNEL_PHRASES_PATH)
        )
        assert os.path.exists(file_path)

        with open(file_path, "r", encoding="utf-8") as f:
            loaded = OmegaConf.load(f)

        # Check the shape of the loaded document itself: asserting on the result
        # of list(...) would always pass, even for a mapping-shaped file.
        assert OmegaConf.is_list(loaded)
        backchannel_phrases = list(loaded)
        assert all(isinstance(item, str) for item in backchannel_phrases)

    @pytest.mark.unit
    def test_configure_llm_with_registry_model(self, voice_agent_server_base_path):
        """Check that ``SYSTEM_ROLE`` and ``SYSTEM_PROMPT`` are defined as strings."""
        config_manager = ConfigManager(voice_agent_server_base_path)

        assert hasattr(config_manager, "SYSTEM_ROLE") and isinstance(config_manager.SYSTEM_ROLE, str)
        assert hasattr(config_manager, "SYSTEM_PROMPT") and isinstance(config_manager.SYSTEM_PROMPT, str)

    @pytest.mark.unit
    def test_configure_llm_with_file_system_prompt(self, voice_agent_server_base_path):
        """Check that ``SYSTEM_PROMPT`` is defined as a string."""
        config_manager = ConfigManager(voice_agent_server_base_path)

        assert hasattr(config_manager, "SYSTEM_PROMPT") and isinstance(config_manager.SYSTEM_PROMPT, str)

    @pytest.mark.unit
    def test_configure_llm_reasoning_model(self, voice_agent_server_base_path):
        """Check that ``SYSTEM_ROLE`` is defined as a string.

        Only the checked-in configuration is exercised here; no reasoning-model
        specific setup is performed by this test.
        """
        config_manager = ConfigManager(voice_agent_server_base_path)

        assert hasattr(config_manager, "SYSTEM_ROLE") and isinstance(config_manager.SYSTEM_ROLE, str)

    @pytest.mark.unit
    def test_configure_llm_fallback_to_generic(self, voice_agent_server_base_path):
        """Check that ``SYSTEM_ROLE`` is defined as a string.

        Only the checked-in configuration is exercised here; no fallback
        scenario is set up by this test.
        """
        config_manager = ConfigManager(voice_agent_server_base_path)

        assert hasattr(config_manager, "SYSTEM_ROLE") and isinstance(config_manager.SYSTEM_ROLE, str)

    @pytest.mark.unit
    def test_configure_tts_nemo_model(self, voice_agent_server_base_path):
        """Check that the TTS model id and device attributes are defined."""
        config_manager = ConfigManager(voice_agent_server_base_path)

        assert hasattr(config_manager, "TTS_MAIN_MODEL_ID")
        assert hasattr(config_manager, "TTS_SUB_MODEL_ID")
        assert hasattr(config_manager, "TTS_DEVICE")

    @pytest.mark.unit
    def test_configure_tts_with_optional_params(self, voice_agent_server_base_path):
        """Check that the TTS think tokens and extra separators are lists of strings."""
        config_manager = ConfigManager(voice_agent_server_base_path)

        assert hasattr(config_manager, "TTS_THINK_TOKENS") and isinstance(config_manager.TTS_THINK_TOKENS, list)
        assert all(isinstance(item, str) for item in config_manager.TTS_THINK_TOKENS)
        assert hasattr(config_manager, "TTS_EXTRA_SEPARATOR") and isinstance(config_manager.TTS_EXTRA_SEPARATOR, list)
        assert all(isinstance(item, str) for item in config_manager.TTS_EXTRA_SEPARATOR)

    @pytest.mark.unit
    def test_get_server_config(self, voice_agent_server_base_path):
        """Check that ``get_server_config`` returns a DictConfig with an integer chunk count."""
        config_manager = ConfigManager(voice_agent_server_base_path)
        server_config = config_manager.get_server_config()

        assert isinstance(server_config, DictConfig)
        assert hasattr(server_config.transport, "audio_out_10ms_chunks")
        assert isinstance(server_config.transport.audio_out_10ms_chunks, int)

    @pytest.mark.unit
    def test_get_model_registry(self, voice_agent_server_base_path):
        """Check that ``get_model_registry`` returns a DictConfig with the LLM, TTS and STT sections."""
        config_manager = ConfigManager(voice_agent_server_base_path)
        model_registry = config_manager.get_model_registry()

        assert isinstance(model_registry, DictConfig)
        assert "llm_models" in model_registry
        assert "tts_models" in model_registry
        assert "stt_models" in model_registry

    @pytest.mark.unit
    def test_get_vad_params(self, voice_agent_server_base_path):
        """Check that every VAD parameter is a float within ``[0.0, 1.0]``."""
        config_manager = ConfigManager(voice_agent_server_base_path)
        vad_params = config_manager.get_vad_params()

        assert isinstance(vad_params, VADParams)
        assert isinstance(vad_params.confidence, float) and 0.0 <= vad_params.confidence <= 1.0
        assert isinstance(vad_params.start_secs, float) and 0.0 <= vad_params.start_secs <= 1.0
        assert isinstance(vad_params.stop_secs, float) and 0.0 <= vad_params.stop_secs <= 1.0
        assert isinstance(vad_params.min_volume, float) and 0.0 <= vad_params.min_volume <= 1.0

    @pytest.mark.unit
    def test_get_stt_params(self, voice_agent_server_base_path):
        """Check the types and ranges of the STT parameters."""
        config_manager = ConfigManager(voice_agent_server_base_path)
        stt_params = config_manager.get_stt_params()

        assert isinstance(stt_params, NeMoSTTInputParams)
        assert isinstance(stt_params.att_context_size, list)
        assert all(isinstance(item, int) for item in stt_params.att_context_size)
        assert isinstance(stt_params.frame_len_in_secs, float) and 0.0 <= stt_params.frame_len_in_secs <= 1.0
        assert (
            isinstance(stt_params.raw_audio_frame_len_in_secs, float)
            and 0.0 <= stt_params.raw_audio_frame_len_in_secs <= 1.0
        )

    @pytest.mark.unit
    def test_get_diar_params(self, voice_agent_server_base_path):
        """Check that the diarization frame length and threshold are floats."""
        config_manager = ConfigManager(voice_agent_server_base_path)
        diar_params = config_manager.get_diar_params()

        assert isinstance(diar_params, NeMoDiarInputParams)
        assert hasattr(diar_params, "frame_len_in_secs") and isinstance(diar_params.frame_len_in_secs, float)
        assert hasattr(diar_params, "threshold") and isinstance(diar_params.threshold, float)

    @pytest.mark.unit
    def test_transport_configuration(self, voice_agent_server_base_path):
        """Check that ``TRANSPORT_AUDIO_OUT_10MS_CHUNKS`` is defined as an integer."""
        config_manager = ConfigManager(voice_agent_server_base_path)

        assert hasattr(config_manager, "TRANSPORT_AUDIO_OUT_10MS_CHUNKS")
        # Reported as an assertion failure, like every other check in this class.
        assert isinstance(
            config_manager.TRANSPORT_AUDIO_OUT_10MS_CHUNKS, int
        ), f"TRANSPORT_AUDIO_OUT_10MS_CHUNKS is not an integer: {config_manager.TRANSPORT_AUDIO_OUT_10MS_CHUNKS}"