File size: 1,426 Bytes
6830eb0
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
from __future__ import annotations
from config.settings import settings
from langchain_openai import ChatOpenAI
from langchain_google_genai import ChatGoogleGenerativeAI

def get_model_identifier(llm) -> str:
    """Return a provider-prefixed identifier for the given LLM instance.

    Args:
        llm: A chat model instance, e.g. one returned by ``get_llm``.

    Returns:
        ``"openai-<model>"`` for a ``ChatOpenAI`` instance,
        ``"gemini-<model>"`` for a ``ChatGoogleGenerativeAI`` instance,
        and ``"unknown-model"`` for anything else.
    """
    if isinstance(llm, ChatOpenAI):
        return f"openai-{llm.model_name}"

    if isinstance(llm, ChatGoogleGenerativeAI):
        # Read the name from the instance rather than from settings, so a
        # model created with an explicit ``model_name`` override is not
        # reported as the configured default.
        model_name = getattr(llm, "model", None)
        if not model_name:
            return f"gemini-{settings.GEMINI_MODEL_NAME}"
        # NOTE(review): some langchain_google_genai releases appear to
        # normalise the name to "models/<name>"; strip that prefix so the
        # identifier keeps the plain "<name>" form - confirm against the
        # installed version.
        prefix = "models/"
        if model_name.startswith(prefix):
            model_name = model_name[len(prefix):]
        return f"gemini-{model_name}"

    return "unknown-model"

def get_llm(model_name: str | None = None):
    """Return a chat model instance for the configured provider.

    The provider is read from ``settings.MODEL_PROVIDER``. Both supported
    providers are created with ``temperature=0``.

    Args:
        model_name: Optional model name. When omitted (or empty), the
            provider's default from settings is used
            (``OPENAI_MODEL_NAME`` or ``GEMINI_MODEL_NAME``).

    Returns:
        A ``ChatOpenAI`` instance when the provider is ``"openai"``, or a
        ``ChatGoogleGenerativeAI`` instance when it is ``"google_gemini"``.

    Raises:
        ValueError: If the selected provider's API key is not set, or if
            the provider is not one of the supported values.
    """
    provider = settings.MODEL_PROVIDER

    if provider == "openai":
        model_name = model_name or settings.OPENAI_MODEL_NAME
        if not settings.OPENAI_API_KEY:
            raise ValueError("OPENAI_API_KEY is not set")
        return ChatOpenAI(
            model=model_name,
            openai_api_key=settings.OPENAI_API_KEY,
            temperature=0,
        )

    if provider == "google_gemini":
        model_name = model_name or settings.GEMINI_MODEL_NAME
        if not settings.GOOGLE_API_KEY:
            raise ValueError("GOOGLE_API_KEY is not set")
        return ChatGoogleGenerativeAI(
            model=model_name,
            # Pass the validated key explicitly, mirroring the OpenAI
            # branch, instead of relying on the client finding the key
            # in the process environment on its own.
            google_api_key=settings.GOOGLE_API_KEY,
            temperature=0,
            max_tokens=None,
            max_retries=2,
        )

    raise ValueError(f"Unknown model provider: {provider}")