File size: 4,890 Bytes
d7c9e73
 
 
8277386
d7c9e73
 
c2e2aa2
 
 
 
 
 
 
 
 
d7c9e73
c2e2aa2
 
d7c9e73
a1a2a18
d7c9e73
 
ffec641
d4aa01a
ffec641
d4aa01a
ab37bbe
c2e2aa2
ab37bbe
 
 
 
 
 
40354df
 
c2e2aa2
220b4dd
d7c9e73
6ce82f5
 
d952a61
6ce82f5
d7c9e73
 
6ce82f5
d7c9e73
 
 
 
6ce82f5
 
 
 
 
 
 
2321bd0
 
 
 
 
 
 
 
 
 
d7c9e73
 
 
 
 
 
 
 
 
 
 
ffec641
2321bd0
d7c9e73
 
 
 
 
 
8277386
2321bd0
 
 
 
d7c9e73
2321bd0
 
8277386
d7c9e73
8277386
d7c9e73
 
 
8277386
d7c9e73
c2e2aa2
6ce82f5
d7c9e73
 
 
 
 
 
ffec641
 
 
 
 
c2e2aa2
6ce82f5
 
 
 
 
 
 
 
 
ab37bbe
2321bd0
 
6ce82f5
 
d952a61
6ce82f5
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
from functools import partial
import os
import re
import time
from xml.parsers.expat import model

# https://discuss.huggingface.co/t/issues-with-sadtalker-zerogpu-spaces-inquiry-about-community-grant/110625/10
if os.environ.get("SPACES_ZERO_GPU") is not None:
    import spaces
else:
    class spaces:
        """Local stand-in for the ``spaces`` package outside ZeroGPU.

        Only ``spaces.GPU`` is provided. It is a pass-through decorator, so
        functions decorated for ZeroGPU run unchanged on other hosts.
        """

        @staticmethod
        def GPU(func=None, **kwargs):
            """No-op replacement for ``spaces.GPU``.

            Works both as ``@spaces.GPU`` (``func`` is the decorated function)
            and as ``@spaces.GPU(...)`` (``func`` is ``None``; a decorator is
            returned). Any keyword options are accepted and ignored.

            Returns:
                A wrapper that simply calls the original function, with the
                original's name/docstring preserved via ``functools.wraps``.
            """
            from functools import wraps

            def decorator(fn):
                @wraps(fn)
                def wrapper(*args, **kw):
                    return fn(*args, **kw)

                return wrapper

            # Called with options (or empty parens): hand back the decorator.
            if func is None:
                return decorator
            return decorator(func)

from transformers import pipeline as hf_pipeline
import torch
import litellm

from tqdm import tqdm
import subprocess

# https://huggingface.co/spaces/zero-gpu-explorers/README/discussions/132
# subprocess.run("rm -rf /data-nvme/zerogpu-offload/*", env={}, shell=True)

# pipeline = hf_pipeline(
#     "text-generation",
#     model="meta-llama/Meta-Llama-3.1-8B-Instruct",
#     model_kwargs={"torch_dtype": 'bfloat16'},
#     device_map="auto",
# )


class ModelPrediction:
    """Send a question plus database schema to an LLM and parse its reply.

    :meth:`make_prediction` supports two tasks:

    * ``task='SP'``: the model is asked to translate the question into SQL,
      returned inside a ```sql``` fence.
    * any other ``task`` value: the model is asked to answer directly,
      wrapped in ``<answer></answer>`` tags.

    Each supported short model name maps to a callable (built by
    :meth:`_init_model_prediction`) that forwards the prompt to
    ``litellm.completion`` and returns a dict with ``"response"`` and
    ``"cost"`` keys.
    """

    # Ordered (substring, provider model id) pairs used to resolve a short
    # model name. Checked in order; the first substring contained in the
    # requested name wins.
    _MODEL_ALIASES = (
        ("gpt-3.5", "openai/gpt-3.5-turbo-0125"),
        ("gpt-4o-mini", "openai/gpt-4o-mini-2024-07-18"),
        (
            "DeepSeek-R1-Distill-Llama-70B",
            "together_ai/deepseek-ai/DeepSeek-R1-Distill-Llama-70B",
        ),
        ("llama-8", "together_ai/meta-llama/Meta-Llama-3.1-8B-Instruct-Turbo"),
        ("llama-70", "together_ai/meta-llama/Meta-Llama-3.1-70B-Instruct-Turbo"),
    )

    def __init__(self):
        # Short names accepted by make_prediction -> bound prediction callables.
        self.model_name2pred_func = {
            name: self._init_model_prediction(name)
            for name in (
                "gpt-3.5",
                "gpt-4o-mini",
                "llama-70",
                "DeepSeek-R1-Distill-Llama-70B",
                "llama-8",
            )
        }

        # Not used within this class; kept so external code reading them still works.
        self._model_name = None
        self._pipeline = None

        # Default prompt templates; both are filled via str.format with
        # ``question`` and ``db_schema``.
        self.base_prompt = (
            "Translate the following question in SQL code to be executed over the database to fetch the answer. Return the sql code in ```sql ```\n"
            " Question\n"
            "{question}\n"
            "Database Schema\n"
            "{db_schema}\n"
        )
        # NOTE(review): there is no newline between "</answer>" and
        # " Question", so they run together in the rendered prompt. Left
        # unchanged to keep model inputs stable.
        self.base_prompt_QA = (
            "Return the answer of the following question based on the provided database."
            " Return your answer as the result of a query executed over the database."
            " Namely, as a list of list where the first list represent the tuples and the second list the values in that tuple.\n"
            "Return the answer in answer tag as <answer> </answer>"
            " Question\n"
            "{question}\n"
            "Database Schema\n"
            "{db_schema}\n"
        )

    @staticmethod
    def _extract_answer_from_pred(pred: str) -> str:
        """Extract the final answer or SQL from a raw model response.

        Preference order:

        1. the last ``<answer>...</answer>`` span, with ```` ``` ```` /
           ```` ```sql ```` code-fence markers removed;
        2. otherwise the last ```` ```sql ... ``` ```` fenced block;
        3. otherwise the whole response unchanged.

        A ``None`` response yields ``""``.
        """
        if pred is None:
            return ""
        answers = re.findall(r"<answer>(.*?)</answer>", pred, re.DOTALL)
        if answers:
            # Remove only fence markers. Blindly deleting every "sql" substring
            # would corrupt answer values such as "mysql" or "postgresql".
            cleaned = re.sub(r"```(?:sql\b)?", "", answers[-1], flags=re.IGNORECASE)
            return cleaned.strip()
        sql_blocks = re.findall(r"```sql(.*?)```", pred, re.DOTALL)
        return sql_blocks[-1].strip() if sql_blocks else pred

    def make_prediction(self, question, db_schema, model_name, prompt=None, task='SP'):
        """Query ``model_name`` and return its parsed prediction.

        Args:
            question: Natural-language question to answer.
            db_schema: Schema description inserted into the prompt.
            model_name: One of the keys of ``self.model_name2pred_func``.
            prompt: Optional template with ``{question}`` and ``{db_schema}``
                placeholders; when falsy, the task's default template is used.
            task: ``'SP'`` selects the SQL prompt; anything else selects the
                direct-answer (QA) prompt.

        Returns:
            The dict produced by the model's prediction function, augmented
            with ``"response_parsed"`` (see ``_extract_answer_from_pred``) and
            ``"time"`` (elapsed seconds for the model call).

        Raises:
            ValueError: if ``model_name`` is not supported.
        """
        if model_name not in self.model_name2pred_func:
            supported = ", ".join(self.model_name2pred_func)
            raise ValueError(
                f"Model {model_name!r} not supported; supported models are: {supported}"
            )

        default_template = self.base_prompt if task == 'SP' else self.base_prompt_QA
        template = prompt or default_template
        filled_prompt = template.format(question=question, db_schema=db_schema)

        # perf_counter is monotonic, so the latency cannot go negative or
        # jump if the wall clock is adjusted during the request.
        start = time.perf_counter()
        prediction = self.model_name2pred_func[model_name](filled_prompt)
        elapsed = time.perf_counter() - start

        prediction["response_parsed"] = self._extract_answer_from_pred(
            prediction["response"]
        )
        prediction["time"] = elapsed
        return prediction

    def predict_with_api(self, prompt, model_name):
        """Send ``prompt`` as a single user message via ``litellm.completion``.

        Args:
            prompt: Fully formatted prompt text.
            model_name: Provider-qualified model id understood by LiteLLM.

        Returns:
            ``{"response": <message content>, "cost": <LiteLLM response cost
            or None>}``.
        """
        response = litellm.completion(
            model=model_name,
            messages=[{"role": "user", "content": prompt}],
            num_retries=2,
        )
        response_text = response["choices"][0]["message"]["content"]
        return {
            "response": response_text,
            # .get: a missing cost entry yields None instead of aborting an
            # otherwise successful prediction.
            "cost": response._hidden_params.get("response_cost"),
        }

    def _init_model_prediction(self, model_name):
        """Return ``predict_with_api`` bound to the provider id for ``model_name``.

        Raises:
            ValueError: if no alias substring matches ``model_name``.
        """
        for alias, provider_model in self._MODEL_ALIASES:
            if alias in model_name:
                return partial(self.predict_with_api, model_name=provider_model)
        raise ValueError("Model forbidden")