File size: 3,373 Bytes
873d0cf
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
from openai import OpenAI
from langchain_openai import ChatOpenAI
from langchain_community.chat_models import ChatOllama
from langchain_google_genai import ChatGoogleGenerativeAI
from langchain_groq import ChatGroq

try:
    from .utils.db import (
        load_api_key,
        load_openai_url,
        load_model_settings,
        load_groq_api_key,
        load_google_api_key,
    )
    from .custom_callback import customcallback
    from .llm_settings import llm_settings
except ImportError:
    from utils.db import (
        load_api_key,
        load_openai_url,
        load_model_settings,
        load_groq_api_key,
        load_google_api_key,
    )
    from custom_callback import customcallback
    from llm_settings import llm_settings


# Module-level callback handler shared by every model that lists it in its
# "callbacks" argument (see get_model). It is built once at import time with
# strip_tokens=False and answer_prefix_tokens=["Answer"]; the exact meaning of
# these options is defined by customcallback. NOTE(review): confirm there.
the_callback = customcallback(
    answer_prefix_tokens=["Answer"],
    strip_tokens=False,
)


def get_model(high_context=False):
    """Instantiate the LangChain chat model selected in the stored settings.

    The model name comes from ``load_model_settings()``. Its ``"provider"``
    entry in ``llm_settings`` selects the chat class, and credentials come
    from the ``load_*`` settings loaders.

    Args:
        high_context: Honoured only for the ``"openai"`` provider when the
            stored OpenAI URL is ``"default"``. In that case the configured
            model is replaced by ``"gpt-4-turbo"``. It is ignored otherwise.

    Returns:
        One of the following instances:
        - ``ChatOpenAI`` for providers ``"openai"`` and ``"ollama"``.
        - ``ChatGoogleGenerativeAI`` for ``"google"``.
        - ``ChatGroq`` for ``"groq"``.

    Raises:
        KeyError: If the configured model is not a key of ``llm_settings``,
            or its provider is not one of the four supported values.
    """
    the_model = load_model_settings()
    the_api_key = load_api_key()
    the_groq_api_key = load_groq_api_key()
    the_google_api_key = load_google_api_key()
    the_openai_url = load_openai_url()

    # Raises KeyError(the_model) when the model is not configured.
    provider = llm_settings[the_model]["provider"]

    if provider == "openai":
        openai_args = {
            "model": the_model,
            "api_key": the_api_key,
            "max_retries": 15,
            "streaming": True,
            "callbacks": [the_callback],
        }
        if the_openai_url == "default":
            # The gpt-4-turbo override applies only to the stock endpoint.
            if high_context:
                openai_args["model"] = "gpt-4-turbo"
        else:
            # Custom (OpenAI-compatible) endpoint: keep the configured model.
            openai_args["base_url"] = the_openai_url
        return ChatOpenAI(**openai_args)

    if provider == "ollama":
        # Served through ChatOpenAI against the local Ollama /v1 endpoint.
        # The literal "ollama" api_key looks like a placeholder value.
        # NOTE(review): confirm the server ignores it.
        return ChatOpenAI(
            api_key="ollama",
            base_url="http://localhost:11434/v1",
            model=the_model,
        )

    if provider == "google":
        return ChatGoogleGenerativeAI(
            model=the_model,
            google_api_key=the_google_api_key,
        )

    if provider == "groq":
        return ChatGroq(
            temperature=0,
            # Presumably llm_settings names Groq models with a "-groq" suffix
            # that the Groq API does not use. TODO confirm.
            model_name=the_model.replace("-groq", ""),
            # Must be the Groq key. Passing the OpenAI URL here was a bug.
            groq_api_key=the_groq_api_key,
        )

    # Unsupported provider: fail the same way as an unknown model.
    raise KeyError(the_model)


def get_client():
    """Create a raw OpenAI SDK client from the stored credentials.

    Returns:
        An ``OpenAI`` client built with the stored API key. When the stored
        URL is anything other than the sentinel ``"default"``, it is passed
        as ``base_url``. Otherwise no ``base_url`` argument is given.
    """
    client_kwargs = {"api_key": load_api_key()}
    stored_url = load_openai_url()
    if stored_url != "default":
        client_kwargs["base_url"] = stored_url
    return OpenAI(**client_kwargs)