"""Configuration loader for the Mental Health Chatbot."""

import os
from pathlib import Path
from typing import Any, Dict

import yaml
from langchain_google_genai import ChatGoogleGenerativeAI


class RAGConfig:
    """Configuration for the RAG agent."""

    def __init__(self, config_dict: Dict[str, Any]):
        self.config_dict = config_dict
        self.llm = ChatGoogleGenerativeAI(
            model="gemini-1.5-flash",
            temperature=0.3,
            google_api_key=os.environ.get("GOOGLE_API_KEY"),
        )

        # Read the 'rag' section if present; otherwise every .get() below
        # falls back to the built-in defaults.
        rag_config = config_dict.get('rag', {})

        self.embedding_dim = rag_config.get('embedding_dim', 384)

        # The sentence-transformers model is loaded lazily via get_embedding_model().
        self.embedding_model = None
        self.embedding_model_name = "all-MiniLM-L6-v2"

        self.collection_name = rag_config.get('collection_name', 'mental_health_docs')
        self.chunk_size = rag_config.get('chunk_size', 256)
        self.chunk_overlap = rag_config.get('chunk_overlap', 32)
        self.reranker_model = rag_config.get('reranker_model', 'cross-encoder/ms-marco-MiniLM-L-6-v2')
        self.reranker_top_k = rag_config.get('reranker_top_k', 5)
        self.max_context_length = rag_config.get('max_context_length', 2048)
        self.include_sources = rag_config.get('include_sources', True)
        self.use_local = rag_config.get('use_local', True)
        self.url = rag_config.get('url', 'http://localhost:6333')
        self.distance_metric = rag_config.get('distance_metric', 'Cosine')
        self.min_retrieval_confidence = rag_config.get('min_retrieval_confidence', 0.85)
        self.processed_docs_dir = rag_config.get('processed_docs_dir', 'processed_docs')
        self.knowledge_dir = rag_config.get('knowledge_dir', 'knowledge')

        self.context_limit = 4

    def get_embedding_model(self):
        """Lazily load the embedding model on first use."""
        if self.embedding_model is None:
            try:
                from sentence_transformers import SentenceTransformer
                self.embedding_model = SentenceTransformer(self.embedding_model_name)
                print(f"✅ Embedding model loaded: {self.embedding_model_name}")
            except Exception as e:
                print(f"⚠️ Failed to load embedding model: {e}")
                return None
        return self.embedding_model
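
# A sketch of the 'rag' section that RAGConfig reads, e.g. in rag.yaml (merged by
# Config._load_yaml_configs below). Illustrative only: the values shown are the
# built-in defaults above, and every key is optional.
#
#   rag:
#     embedding_dim: 384
#     collection_name: mental_health_docs
#     chunk_size: 256
#     chunk_overlap: 32
#     reranker_model: cross-encoder/ms-marco-MiniLM-L-6-v2
#     reranker_top_k: 5
#     max_context_length: 2048
#     include_sources: true
#     use_local: true
#     url: http://localhost:6333
#     distance_metric: Cosine
#     min_retrieval_confidence: 0.85
#     processed_docs_dir: processed_docs
#     knowledge_dir: knowledge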


class ConversationConfig:
    """Configuration for the Conversation agent."""

    def __init__(self):
        self.llm = ChatGoogleGenerativeAI(
            model="gemini-1.5-flash",
            temperature=0.7,
            google_api_key=os.environ.get("GOOGLE_API_KEY"),
        )


class WebSearchConfig:
    """Configuration for the Web Search agent."""

    def __init__(self):
        self.context_limit = 4
        self.llm = ChatGoogleGenerativeAI(
            model="gemini-1.5-flash",
            temperature=0.5,
            google_api_key=os.environ.get("GOOGLE_API_KEY"),
        )
        self.tavily_api_key = os.environ.get("TAVILY_API_KEY", "tvly-your-api-key-here")


class AgentDecisionConfig:
    """Configuration for the agent-decision (routing) system."""

    def __init__(self):
        # Temperature 0 keeps routing decisions deterministic.
        self.llm = ChatGoogleGenerativeAI(
            model="gemini-2.0-flash",
            temperature=0,
            google_api_key=os.environ.get("GOOGLE_API_KEY"),
        )


class Config:
    """Main configuration class that loads settings from YAML files."""

    def __init__(self):
        # API keys are expected in the environment; the placeholders below are
        # used only when the variables are unset. Never hard-code real keys here.
        os.environ.setdefault("GOOGLE_API_KEY", "your-google-api-key-here")
        os.environ.setdefault("TAVILY_API_KEY", "tvly-your-api-key-here")

        self.config_dict = self._load_yaml_configs()

        self.rag = RAGConfig(self.config_dict)
        self.conversation = ConversationConfig()
        self.web_search = WebSearchConfig()
        self.agent_decision = AgentDecisionConfig()

        self.max_conversation_history = 20

    def _load_yaml_configs(self) -> Dict[str, Any]:
        """Load and merge all YAML configuration files next to this module."""
        config_dict: Dict[str, Any] = {}
        config_dir = Path(__file__).parent

        # Later files override keys already set by earlier ones.
        yaml_files = ['agents.yaml', 'rag.yaml', 'tasks.yaml']
        for yaml_file in yaml_files:
            file_path = config_dir / yaml_file
            if file_path.exists():
                with open(file_path, 'r') as f:
                    data = yaml.safe_load(f)
                if data:
                    config_dict.update(data)

        return config_dict

    def get_agent_config(self, agent_name: str) -> Dict[str, Any]:
        """Get the configuration for a specific agent."""
        return self.config_dict.get(agent_name, {})

    def get_task_config(self, task_name: str) -> Dict[str, Any]:
        """Get the configuration for a specific task."""
        return self.config_dict.get(task_name, {})
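

# Minimal usage sketch (assumption: GOOGLE_API_KEY is exported in the environment
# and any agents.yaml / rag.yaml / tasks.yaml files sit next to this module).
if __name__ == "__main__":
    config = Config()
    print(config.rag.collection_name)                       # 'mental_health_docs' by default
    print(config.rag.chunk_size, config.rag.chunk_overlap)  # 256 32 by default
    print(list(config.config_dict.keys()))                  # top-level sections found in the YAML files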