import os
import warnings
import multiprocessing
from functools import partial
from stark_qa.tools.api_lib.claude import complete_text_claude
from stark_qa.tools.api_lib.gpt import get_gpt_output
from stark_qa.tools.api_lib.huggingface import complete_text_hf
from stark_qa.tools.api_lib.openai_emb import (
    get_openai_embedding,
    get_openai_embeddings,
    get_sentence_transformer_embeddings,
    get_contriever_embeddings,
)
from stark_qa.tools.api_lib.voyage_emb import get_voyage_embedding, get_voyage_embeddings


# Retry counts, sleep times (in seconds) between retries, and the default
# degree of parallelism; each can be overridden via an environment variable.
MAX_OPENAI_RETRY = int(os.getenv("MAX_OPENAI_RETRY", 5))
OPENAI_SLEEP_TIME = int(os.getenv("OPENAI_SLEEP_TIME", 60))
MAX_CLAUDE_RETRY = int(os.getenv("MAX_CLAUDE_RETRY", 10))
CLAUDE_SLEEP_TIME = int(os.getenv("CLAUDE_SLEEP_TIME", 0))
LLM_PARALLEL_NODES = int(os.getenv("LLM_PARALLEL_NODES", 5))

# Registered model names. Note: the text-embedding-* entries are embedding
# models; get_llm_output dispatches only gpt-4*, claude*, and huggingface/* models.
registered_text_completion_llms = {
    "gpt-4-1106-preview",
    "gpt-4-0125-preview",
    "gpt-4-turbo-preview",
    "gpt-4-turbo",
    "gpt-4-turbo-2024-04-09",
    "claude-2.1",
    "claude-3-opus-20240229", 
    "claude-3-sonnet-20240229", 
    "claude-3-haiku-20240307",
    "huggingface/codellama/CodeLlama-7b-hf",
    "text-embedding-3-small",
    "text-embedding-3-large",
    "text-embedding-ada-002"
}


def get_api_embedding(text, model_name, *args, **kwargs):
    """Embed a single text, dispatching to Voyage or OpenAI based on the model name."""
    if 'voyage' in model_name:
        return get_voyage_embedding(text, model_name, *args, **kwargs)
    elif 'text-embedding' in model_name:
        return get_openai_embedding(text, model_name, *args, **kwargs)
    else:
        raise ValueError(f"Model {model_name} not recognized.")


def get_api_embeddings(text, model_name, *args, **kwargs):
    """Embed a list of texts, dispatching to Voyage or OpenAI based on the model name."""
    if 'voyage' in model_name:
        return get_voyage_embeddings(text, model_name, *args, **kwargs)
    elif 'text-embedding' in model_name:
        return get_openai_embeddings(text, model_name, *args, **kwargs)
    else:
        raise ValueError(f"Model {model_name} not recognized.")


def parallel_func(func, n_max_nodes=LLM_PARALLEL_NODES):
    """
    A general function to call a function on a list of inputs in parallel.

    Args:
        func (callable): The function to apply.
        n_max_nodes (int): Maximum number of parallel processes.

    Returns:
        callable: A wrapper function that applies `func` in parallel.
    """
    def _parallel_func(inputs: list, **kwargs):
        partial_func = partial(func, **kwargs)
        processes = min(len(inputs), n_max_nodes)
        with multiprocessing.Pool(processes=processes) as pool:
            results = pool.map(partial_func, inputs)
        return results
    return _parallel_func
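
# Usage sketch: build a parallel variant of any single-input function; keyword
# arguments are shared across all calls (hypothetical example inputs):
#   embed_many = parallel_func(get_api_embedding, n_max_nodes=4)
#   results = embed_many(["text a", "text b"], model_name="text-embedding-3-small")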


def get_llm_output(message, 
                   model="gpt-4-0125-preview", 
                   max_tokens=2048, 
                   temperature=1, 
                   json_object=False):
    """
    A general function to complete a prompt using the specified model.

    Args:
        message (str or list): The input message or a list of message dicts.
        model (str): The model to use for completion.
        max_tokens (int): Maximum number of tokens to generate.
        temperature (float): Sampling temperature.
        json_object (bool): Whether to output in JSON format.

    Returns:
        str: The completed text generated by the model.

    Raises:
        ValueError: If the model is not recognized.
    """
    if model not in registered_text_completion_llms:
        warnings.warn(f"Model {model} is not registered. You may still be able to use it.")
    
    kwargs = {
        'message': message, 
        'model': model, 
        'max_tokens': max_tokens, 
        'temperature': temperature, 
        'json_object': json_object
    }
    
    if 'gpt-4' in model:
        kwargs.update({'max_retry': MAX_OPENAI_RETRY, 'sleep_time': OPENAI_SLEEP_TIME})
        return get_gpt_output(**kwargs)
    elif 'claude' in model:
        kwargs.update({'max_retry': MAX_CLAUDE_RETRY, 'sleep_time': CLAUDE_SLEEP_TIME})
        return complete_text_claude(**kwargs)
    elif 'huggingface' in model:
        return complete_text_hf(**kwargs)
    else:
        raise ValueError(f"Model {model} not recognized.")

# Parallel functions for text completion
complete_texts_claude = parallel_func(complete_text_claude)
complete_texts_hf = parallel_func(complete_text_hf)
get_gpt_outputs = parallel_func(get_gpt_output)
get_llm_outputs = parallel_func(get_llm_output)
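
# Usage sketch: complete several prompts in parallel with shared settings
# (hypothetical prompts; the model must be supported by get_llm_output):
#   answers = get_llm_outputs(["prompt one", "prompt two"],
#                             model="gpt-4-0125-preview", temperature=0)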