from functools import partial
import os
import re
import time

# https://discuss.huggingface.co/t/issues-with-sadtalker-zerogpu-spaces-inquiry-about-community-grant/110625/10
if os.environ.get("SPACES_ZERO_GPU") is not None:
    import spaces
else:
    # Fallback when running outside a ZeroGPU Space: a no-op stand-in for the spaces.GPU decorator
    class spaces:
        def GPU(func):
            def wrapper(*args, **kwargs):
                return func(*args, **kwargs)

            return wrapper

from transformers import pipeline as hf_pipeline
import litellm
from tqdm import tqdm
class ModelPrediction:
    def __init__(self):
        self.model_name2pred_func = {
            "gpt-3.5": self._init_model_prediction("gpt-3.5"),
            "gpt-4o-mini": self._init_model_prediction("gpt-4o-mini"),
            "o1-mini": self._init_model_prediction("o1-mini"),
            "QwQ": self._init_model_prediction("QwQ"),
            "DeepSeek-R1-Distill-Llama-70B": self._init_model_prediction(
                "DeepSeek-R1-Distill-Llama-70B"
            ),
            "llama-8": self._init_model_prediction("llama-8"),
        }
        self._model_name = None
        self._pipeline = None
        self.base_prompt = (
            "Translate the following question into SQL code to be executed over the database to fetch the answer. Return the SQL code in ```sql ```\n"
            "Question\n"
            "{question}\n"
            "Database Schema\n"
            "{db_schema}\n"
        )
    @property
    def pipeline(self):
        # Lazily build the HF text-generation pipeline for the currently selected model
        if self._pipeline is None:
            self._pipeline = hf_pipeline(
                task="text-generation",
                model=self._model_name,
                device_map="auto",
            )
        return self._pipeline
    def _reset_pipeline(self, model_name):
        # Drop the cached pipeline when switching to a different model
        if self._model_name != model_name:
            self._model_name = model_name
            self._pipeline = None
    @staticmethod
    def _extract_answer_from_pred(pred: str) -> str:
        # Extract everything between <answer> and </answer>; fall back to a ```sql ... ``` block
        matches = re.findall(r"<answer>(.*?)</answer>", pred, re.DOTALL)
        if matches:
            return matches[-1].replace("```", "").replace("sql", "").strip()
        else:
            matches = re.findall(r"```sql(.*?)```", pred, re.DOTALL)
            return matches[-1].strip() if matches else pred
    def make_prediction(self, question, db_schema, model_name, prompt=None):
        if model_name not in self.model_name2pred_func:
            raise ValueError(
                f"Model not supported; supported models are: "
                f"{list(self.model_name2pred_func.keys())}"
            )
        prompt = prompt or self.base_prompt
        prompt = prompt.format(question=question, db_schema=db_schema)
        start_time = time.time()
        prediction = self.model_name2pred_func[model_name](prompt)
        end_time = time.time()
        prediction["response_parsed"] = self._extract_answer_from_pred(
            prediction["response"]
        )
        prediction["time"] = end_time - start_time
        return prediction
    def predict_with_api(self, prompt, model_name):
        response = litellm.completion(
            model=model_name,
            messages=[{"role": "user", "content": prompt}],
            num_retries=2,
        )
        response_text = response["choices"][0]["message"]["content"]
        return {
            "response": response_text,
            "cost": response._hidden_params["response_cost"],
        }
    def predict_with_hf(self, prompt, model_name):
        # Run the prompt through the local HF pipeline; the last chat message holds the generated reply
        self._reset_pipeline(model_name)
        response = self.pipeline([{"role": "user", "content": prompt}])[0][
            "generated_text"
        ][-1]["content"]
        return {"response": response, "cost": 0.0}
    def _init_model_prediction(self, model_name):
        # Map the short model name to its provider-specific identifier and prediction backend
        predict_fun = self.predict_with_api
        if "gpt-3.5" in model_name:
            model_name = "openai/gpt-3.5-turbo-0125"
        elif "gpt-4o-mini" in model_name:
            model_name = "openai/gpt-4o-mini-2024-07-18"
        elif "o1-mini" in model_name:
            model_name = "openai/o1-mini-2024-09-12"
        elif "QwQ" in model_name:
            model_name = "together_ai/Qwen/QwQ-32B"
        elif "DeepSeek-R1-Distill-Llama-70B" in model_name:
            model_name = "together_ai/deepseek-ai/DeepSeek-R1-Distill-Llama-70B"
        elif "llama-8" in model_name:
            model_name = "meta-llama/Meta-Llama-3-8B-Instruct"
            predict_fun = self.predict_with_hf
        else:
            raise ValueError("Model not supported")

        return partial(predict_fun, model_name=model_name)
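

# A minimal usage sketch, not part of the original class: it assumes the relevant API keys
# (e.g. OPENAI_API_KEY for the OpenAI-backed models, TOGETHERAI_API_KEY for the together_ai
# ones) are set in the environment, and the question/schema below are illustrative
# placeholders only.
if __name__ == "__main__":
    predictor = ModelPrediction()
    result = predictor.make_prediction(
        question="How many singers do we have?",
        db_schema="CREATE TABLE singer (singer_id INTEGER PRIMARY KEY, name TEXT);",
        model_name="gpt-4o-mini",
    )
    # result holds the raw response, the parsed SQL, the API cost, and the wall-clock time
    print(result["response_parsed"], result["cost"], result["time"])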