Spaces:
Sleeping
Sleeping
File size: 1,426 Bytes
6830eb0 |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 |
from __future__ import annotations
from config.settings import settings
from langchain_openai import ChatOpenAI
from langchain_google_genai import ChatGoogleGenerativeAI
def get_model_identifier(llm) -> str:
    """Return a provider-prefixed identifier for the model behind *llm*.

    The model name is read from the LLM instance itself, so an LLM created
    with a non-default model (e.g. via ``get_llm(model_name=...)``) is
    reported with the model it actually uses rather than the configured
    default.

    Args:
        llm: A chat model instance, typically one returned by ``get_llm``.

    Returns:
        ``"openai-<model>"`` for ``ChatOpenAI``, ``"gemini-<model>"`` for
        ``ChatGoogleGenerativeAI``, and ``"unknown-model"`` for anything else.
    """
    if isinstance(llm, ChatOpenAI):
        return f"openai-{llm.model_name}"
    if isinstance(llm, ChatGoogleGenerativeAI):
        # Previously this always used settings.GEMINI_MODEL_NAME, which
        # mislabelled instances created with an overridden model. Read the
        # instance's own model, falling back to the configured name if the
        # attribute is missing or empty.
        model = getattr(llm, "model", None) or settings.GEMINI_MODEL_NAME
        # NOTE(review): some langchain-google-genai versions may normalise the
        # name to "models/<name>"; strip that so the identifier keeps the
        # "gemini-<name>" shape it had before.
        return f"gemini-{str(model).removeprefix('models/')}"
    return "unknown-model"
def get_llm(model_name: str | None = None) -> ChatOpenAI | ChatGoogleGenerativeAI:
    """Return a chat LLM for the provider configured in ``settings.MODEL_PROVIDER``.

    Args:
        model_name: Optional model override. When omitted (or empty), the
            provider's configured default is used: ``settings.OPENAI_MODEL_NAME``
            for ``"openai"`` or ``settings.GEMINI_MODEL_NAME`` for
            ``"google_gemini"``.

    Returns:
        A ``ChatOpenAI`` instance when the provider is ``"openai"``, or a
        ``ChatGoogleGenerativeAI`` instance when it is ``"google_gemini"``.
        Both are created with ``temperature=0``.

    Raises:
        ValueError: If the selected provider's API key is not set in
            ``settings``, or if ``settings.MODEL_PROVIDER`` is not one of the
            supported providers.
    """
    provider = settings.MODEL_PROVIDER

    if provider == "openai":
        model_name = model_name or settings.OPENAI_MODEL_NAME
        if not settings.OPENAI_API_KEY:
            raise ValueError("OPENAI_API_KEY is not set")
        return ChatOpenAI(
            model=model_name,
            openai_api_key=settings.OPENAI_API_KEY,
            temperature=0,
        )

    if provider == "google_gemini":
        model_name = model_name or settings.GEMINI_MODEL_NAME
        if not settings.GOOGLE_API_KEY:
            raise ValueError("GOOGLE_API_KEY is not set")
        return ChatGoogleGenerativeAI(
            model=model_name,
            # Pass the validated key explicitly, mirroring the OpenAI branch.
            # Without it the client falls back to reading the key from the
            # process environment, which may not contain it when settings
            # were loaded from another source (e.g. a .env file).
            google_api_key=settings.GOOGLE_API_KEY,
            temperature=0,
            max_tokens=None,
            max_retries=2,
        )

    raise ValueError(f"Unknown model provider: {provider}")