# TherapyNote/models/llm_provider.py
# Hosting-page metadata: abagherp -- "Upload folder using huggingface_hub"
# (commit 6830eb0, verified)
from __future__ import annotations
from config.settings import settings
from langchain_openai import ChatOpenAI
from langchain_google_genai import ChatGoogleGenerativeAI
def get_model_identifier(llm) -> str:
    """Return a provider-prefixed identifier string for an LLM instance.

    Args:
        llm: A chat model instance, normally one produced by ``get_llm``.

    Returns:
        ``"openai-<model>"`` for ``ChatOpenAI`` instances,
        ``"gemini-<model>"`` for ``ChatGoogleGenerativeAI`` instances, and
        ``"unknown-model"`` for anything else.
    """
    if isinstance(llm, ChatOpenAI):
        return f"openai-{llm.model_name}"

    if isinstance(llm, ChatGoogleGenerativeAI):
        # Read the model from the instance so that an override passed to
        # get_llm(model_name=...) is reflected, mirroring the OpenAI branch.
        # The configured default is only a fallback.
        gemini_name = getattr(llm, "model", None) or settings.GEMINI_MODEL_NAME
        # NOTE(review): the Google wrapper may normalise names to
        # "models/<name>"; strip that prefix so the identifier stays in the
        # same shape as the configured name -- confirm against the installed
        # langchain_google_genai version.
        prefix = "models/"
        if gemini_name.startswith(prefix):
            gemini_name = gemini_name[len(prefix):]
        return f"gemini-{gemini_name}"

    return "unknown-model"
def get_llm(model_name: str | None = None) -> ChatOpenAI | ChatGoogleGenerativeAI:
    """Return an LLM instance based on the configured provider.

    The provider is read from ``settings.MODEL_PROVIDER``. Both providers are
    created with ``temperature=0``.

    Args:
        model_name: Optional model name. When omitted (or falsy), the
            provider's default from settings is used
            (``OPENAI_MODEL_NAME`` or ``GEMINI_MODEL_NAME``).

    Returns:
        A ``ChatOpenAI`` instance when the provider is ``"openai"``, or a
        ``ChatGoogleGenerativeAI`` instance when it is ``"google_gemini"``.

    Raises:
        ValueError: If the selected provider's API key is not set in
            settings, or if the provider name is not recognised.
    """
    provider = settings.MODEL_PROVIDER

    if provider == "openai":
        model_name = model_name or settings.OPENAI_MODEL_NAME
        if not settings.OPENAI_API_KEY:
            raise ValueError("OPENAI_API_KEY is not set")
        return ChatOpenAI(
            model=model_name,
            openai_api_key=settings.OPENAI_API_KEY,
            temperature=0,
        )

    if provider == "google_gemini":
        model_name = model_name or settings.GEMINI_MODEL_NAME
        if not settings.GOOGLE_API_KEY:
            raise ValueError("GOOGLE_API_KEY is not set")
        return ChatGoogleGenerativeAI(
            model=model_name,
            # Pass the key that was just validated instead of relying on an
            # ambient environment variable, consistent with the OpenAI branch.
            google_api_key=settings.GOOGLE_API_KEY,
            temperature=0,
            max_tokens=None,
            max_retries=2,
        )

    raise ValueError(f"Unknown model provider: {provider}")