Spaces:
Sleeping
Sleeping
from __future__ import annotations | |
from config.settings import settings | |
from langchain_openai import ChatOpenAI | |
from langchain_google_genai import ChatGoogleGenerativeAI | |
def get_model_identifier(llm) -> str:
    """Return a provider-prefixed identifier for the given LLM instance.

    Args:
        llm: A chat model instance, typically one returned by ``get_llm``.

    Returns:
        ``"openai-<model>"`` for a ``ChatOpenAI`` instance,
        ``"gemini-<model>"`` for a ``ChatGoogleGenerativeAI`` instance, and
        ``"unknown-model"`` for anything else.
    """
    if isinstance(llm, ChatOpenAI):
        return f"openai-{llm.model_name}"

    if isinstance(llm, ChatGoogleGenerativeAI):
        # Read the name from the instance so that a ``model_name`` override
        # passed to ``get_llm`` is reflected in the identifier; fall back to
        # the configured default only if the instance does not expose one.
        gemini_name = getattr(llm, "model", None) or settings.GEMINI_MODEL_NAME
        # NOTE(review): some langchain-google-genai releases appear to store
        # the name as "models/<name>"; strip that prefix so the identifier
        # keeps the same shape as the configured name. Confirm against the
        # installed version.
        prefix = "models/"
        if gemini_name.startswith(prefix):
            gemini_name = gemini_name[len(prefix):]
        return f"gemini-{gemini_name}"

    return "unknown-model"
def get_llm(model_name: str | None = None):
    """Return a chat model instance for the configured provider.

    The provider is read from ``settings.MODEL_PROVIDER``. Both providers are
    created with ``temperature=0``.

    Args:
        model_name: Optional model name. When omitted (or falsy) the
            provider's default from ``settings`` is used
            (``OPENAI_MODEL_NAME`` or ``GEMINI_MODEL_NAME``).

    Returns:
        A ``ChatOpenAI`` instance when the provider is ``"openai"``, or a
        ``ChatGoogleGenerativeAI`` instance when it is ``"google_gemini"``.

    Raises:
        ValueError: If the selected provider's API key is not set, or if the
            provider name is not recognised.
    """
    provider = settings.MODEL_PROVIDER

    if provider == "openai":
        model_name = model_name or settings.OPENAI_MODEL_NAME
        if not settings.OPENAI_API_KEY:
            raise ValueError("OPENAI_API_KEY is not set")
        return ChatOpenAI(
            model=model_name,
            openai_api_key=settings.OPENAI_API_KEY,
            temperature=0,
        )

    if provider == "google_gemini":
        model_name = model_name or settings.GEMINI_MODEL_NAME
        if not settings.GOOGLE_API_KEY:
            raise ValueError("GOOGLE_API_KEY is not set")
        return ChatGoogleGenerativeAI(
            model=model_name,
            # Pass the validated key explicitly, mirroring the OpenAI branch,
            # instead of relying on GOOGLE_API_KEY also being exported in the
            # process environment.
            google_api_key=settings.GOOGLE_API_KEY,
            temperature=0,
            max_tokens=None,
            max_retries=2,
        )

    raise ValueError(f"Unknown model provider: {provider}")