from functools import partial
import os
import re
import time
# https://discuss.huggingface.co/t/issues-with-sadtalker-zerogpu-spaces-inquiry-about-community-grant/110625/10
if os.environ.get("SPACES_ZERO_GPU") is not None:
    import spaces
else:
    # Fallback when not running on a ZeroGPU Space: a no-op stand-in for the
    # `spaces` module so `@spaces.GPU` can still be used as a decorator.
    class spaces:
        @staticmethod
        def GPU(func):
            def wrapper(*args, **kwargs):
                return func(*args, **kwargs)
            return wrapper

from transformers import pipeline as hf_pipeline
import torch
import litellm
from tqdm import tqdm
import subprocess
# https://huggingface.co/spaces/zero-gpu-explorers/README/discussions/132
# subprocess.run("rm -rf /data-nvme/zerogpu-offload/*", env={}, shell=True)
# pipeline = hf_pipeline(
# "text-generation",
# model="meta-llama/Meta-Llama-3.1-8B-Instruct",
# model_kwargs={"torch_dtype": 'bfloat16'},
# device_map="auto",
# )
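# If the local pipeline above were re-enabled, a prediction could be produced
# roughly as follows (a sketch; it assumes a transformers version whose
# text-generation pipeline accepts chat-style message lists):
#     out = pipeline([{"role": "user", "content": prompt}], max_new_tokens=512)
#     text = out[0]["generated_text"][-1]["content"]
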
class ModelPrediction:
    """Wraps API-based text-to-SQL (SP) and question-answering (QA) prediction."""

    def __init__(self):
        # Map each supported model alias to a ready-to-call prediction function.
        self.model_name2pred_func = {
            "gpt-3.5": self._init_model_prediction("gpt-3.5"),
            "gpt-4o-mini": self._init_model_prediction("gpt-4o-mini"),
            "llama-70": self._init_model_prediction("llama-70"),
            "DeepSeek-R1-Distill-Llama-70B": self._init_model_prediction(
                "DeepSeek-R1-Distill-Llama-70B"
            ),
            "llama-8": self._init_model_prediction("llama-8"),
        }
        self._model_name = None
        self._pipeline = None
        self.base_prompt = (
            "Translate the following question into SQL code to be executed over the database to fetch the answer."
            " Return the SQL code in ```sql ```\n"
            " Question\n"
            "{question}\n"
            "Database Schema\n"
            "{db_schema}\n"
        )
        self.base_prompt_QA = (
            "Return the answer to the following question based on the provided database."
            " Return your answer as the result of a query executed over the database,"
            " namely as a list of lists, where each inner list represents a tuple and contains that tuple's values.\n"
            "Return the answer in answer tags as <answer> </answer>"
            " Question\n"
            "{question}\n"
            "Database Schema\n"
            "{db_schema}\n"
        )

    @staticmethod
    def _extract_answer_from_pred(pred: str) -> str:
        # Extract everything between the last <answer> and </answer> tags;
        # fall back to the last ```sql ``` block, then to the raw prediction.
        matches = re.findall(r"<answer>(.*?)</answer>", pred, re.DOTALL)
        if matches:
            return matches[-1].replace("```", "").replace("sql", "").strip()
        else:
            matches = re.findall(r"```sql(.*?)```", pred, re.DOTALL)
            return matches[-1].strip() if matches else pred

    def make_prediction(self, question, db_schema, model_name, prompt=None, task="SP"):
        if model_name not in self.model_name2pred_func:
            raise ValueError(
                f"Model {model_name!r} is not supported; "
                f"supported models are {list(self.model_name2pred_func.keys())}"
            )
        # 'SP' (semantic parsing) asks for SQL; any other task asks for a direct answer.
        if task == "SP":
            prompt = prompt or self.base_prompt
        else:
            prompt = prompt or self.base_prompt_QA
        prompt = prompt.format(question=question, db_schema=db_schema)

        start_time = time.time()
        prediction = self.model_name2pred_func[model_name](prompt)
        end_time = time.time()

        prediction["response_parsed"] = self._extract_answer_from_pred(
            prediction["response"]
        )
        prediction["time"] = end_time - start_time
        return prediction

    def predict_with_api(self, prompt, model_name):  # -> dict[str, Any | float]
        response = litellm.completion(
            model=model_name,
            messages=[{"role": "user", "content": prompt}],
            num_retries=2,
        )
        response_text = response["choices"][0]["message"]["content"]
        return {
            "response": response_text,
            "cost": response._hidden_params["response_cost"],
        }

    def _init_model_prediction(self, model_name):
        # Map the short model alias to the provider-specific litellm model id.
        predict_fun = self.predict_with_api
        if "gpt-3.5" in model_name:
            model_name = "openai/gpt-3.5-turbo-0125"
        elif "gpt-4o-mini" in model_name:
            model_name = "openai/gpt-4o-mini-2024-07-18"
        elif "DeepSeek-R1-Distill-Llama-70B" in model_name:
            model_name = "together_ai/deepseek-ai/DeepSeek-R1-Distill-Llama-70B"
        elif "llama-8" in model_name:
            model_name = "together_ai/meta-llama/Meta-Llama-3.1-8B-Instruct-Turbo"
        elif "llama-70" in model_name:
            model_name = "together_ai/meta-llama/Meta-Llama-3.1-70B-Instruct-Turbo"
        else:
            raise ValueError(f"Unsupported model: {model_name}")
        return partial(predict_fun, model_name=model_name)
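
# Minimal usage sketch: assumes the relevant API keys (e.g. OPENAI_API_KEY or
# TOGETHERAI_API_KEY) are set in the environment; the question and schema
# strings below are purely illustrative.
if __name__ == "__main__":
    predictor = ModelPrediction()
    result = predictor.make_prediction(
        question="How many students are enrolled?",
        db_schema="CREATE TABLE students (id INTEGER, name TEXT);",
        model_name="gpt-4o-mini",
        task="SP",
    )
    print(result["response_parsed"], result["cost"], result["time"])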