qatch-demo / prediction.py
simone-papicchio's picture
Fix prediction (#17)
ca89877 verified
from functools import partial
import os
import re
import time
from xml.parsers.expat import model
# https://discuss.huggingface.co/t/issues-with-sadtalker-zerogpu-spaces-inquiry-about-community-grant/110625/10
# On ZeroGPU Spaces the real `spaces` package is imported; everywhere else a
# minimal local stand-in keeps `@spaces.GPU` usable without the dependency.
if os.environ.get("SPACES_ZERO_GPU") is not None:
    import spaces
else:
    import functools

    class spaces:
        """Local stand-in exposing a pass-through ``GPU`` decorator."""

        @staticmethod
        def GPU(func=None, **_options):
            """Return ``func`` wrapped in a transparent pass-through.

            Works both as a bare decorator (``@spaces.GPU``) and in the
            parameterised form (``@spaces.GPU(duration=60)``); any options
            are accepted and ignored by this stand-in.
            """

            def decorate(target):
                # wraps() keeps the decorated callable's name and docstring.
                @functools.wraps(target)
                def wrapper(*args, **kwargs):
                    return target(*args, **kwargs)

                return wrapper

            if func is None:
                # Called with options only: hand back the actual decorator.
                return decorate
            return decorate(func)
from transformers import pipeline as hf_pipeline
import litellm
from tqdm import tqdm
class ModelPrediction:
    """Route text-to-SQL prompts to a supported model and parse the reply.

    Each supported short model name maps to a prediction callable that
    either calls a remote API through ``litellm`` or runs a local Hugging
    Face text-generation pipeline.
    """

    # (short name, backend model identifier, run through the local HF pipeline?)
    # Order matters: `_init_model_prediction` picks the first short name that
    # is a substring of the requested model name.
    _MODEL_ROUTES = (
        ("gpt-3.5", "openai/gpt-3.5-turbo-0125", False),
        ("gpt-4o-mini", "openai/gpt-4o-mini-2024-07-18", False),
        ("o1-mini", "openai/o1-mini-2024-09-12", False),
        ("QwQ", "together_ai/Qwen/QwQ-32B", False),
        (
            "DeepSeek-R1-Distill-Llama-70B",
            "together_ai/deepseek-ai/DeepSeek-R1-Distill-Llama-70B",
            False,
        ),
        ("llama-8", "meta-llama/Meta-Llama-3-8B-Instruct", True),
    )

    _ANSWER_TAG_RE = re.compile(r"<answer>(.*?)</answer>", re.DOTALL)
    _SQL_BLOCK_RE = re.compile(r"```sql(.*?)```", re.DOTALL | re.IGNORECASE)
    # Matches a bare fence or a fence followed by the `sql` language tag.
    _FENCE_MARKER_RE = re.compile(r"```(?:sql\b)?", re.IGNORECASE)

    def __init__(self):
        """Build the model-name -> prediction-callable table and the default prompt."""
        self.model_name2pred_func = {
            alias: self._init_model_prediction(alias)
            for alias, _, _ in self._MODEL_ROUTES
        }
        # The HF pipeline is created lazily; see the `pipeline` property.
        self._model_name = None
        self._pipeline = None
        self.base_prompt = (
            "Translate the following question in SQL code to be executed over the database to fetch the answer. Return the sql code in ```sql ```\n"
            " Question\n"
            "{question}\n"
            "Database Schema\n"
            "{db_schema}\n"
        )

    @property
    def pipeline(self):
        """Return the HF text-generation pipeline, creating it on first access."""
        if self._pipeline is None:
            self._pipeline = hf_pipeline(
                task="text-generation",
                model=self._model_name,
                device_map="auto",
            )
        return self._pipeline

    def _reset_pipeline(self, model_name):
        """Drop the cached pipeline when a different model is requested."""
        if self._model_name != model_name:
            self._model_name = model_name
            self._pipeline = None

    @staticmethod
    def _extract_answer_from_pred(pred: str) -> str:
        """Extract the SQL statement from a raw model response.

        The last ``<answer>...</answer>`` span wins; failing that, the last
        fenced ``sql`` code block; failing that, the response is returned
        unchanged. An empty or missing response yields ``""``.
        """
        if not pred:
            return ""

        tagged = ModelPrediction._ANSWER_TAG_RE.findall(pred)
        if tagged:
            # Strip only fence markers (optionally tagged `sql`). Removing
            # every "sql" substring would corrupt identifiers such as
            # `mysql_users`.
            return ModelPrediction._FENCE_MARKER_RE.sub("", tagged[-1]).strip()

        fenced = ModelPrediction._SQL_BLOCK_RE.findall(pred)
        return fenced[-1].strip() if fenced else pred

    def make_prediction(self, question, db_schema, model_name, prompt=None):
        """Run one prediction and return the model output with metadata.

        Args:
            question: Natural-language question; used to fill the default
                prompt template.
            db_schema: Database schema text; used to fill the default
                prompt template.
            model_name: One of the keys of ``model_name2pred_func``.
            prompt: Prompt to send as-is. When omitted or empty, the default
                ``base_prompt`` is filled with ``question`` and ``db_schema``.

        Returns:
            The dict produced by the prediction callable (``response``,
            ``cost``) extended with ``response_parsed`` and ``time`` (elapsed
            seconds).

        Raises:
            ValueError: If ``model_name`` is not supported.
        """
        if model_name not in self.model_name2pred_func:
            supported = ", ".join(self.model_name2pred_func)
            raise ValueError(
                f"Model not supported: {model_name!r}; "
                f"supported models are: {supported}"
            )

        if not prompt:
            # Only the built-in template is formatted. A caller-supplied
            # prompt is sent verbatim, since it may legitimately contain
            # braces.
            prompt = self.base_prompt.format(
                question=question, db_schema=db_schema
            )

        # perf_counter is monotonic, so the elapsed time cannot go negative
        # if the wall clock is adjusted during the call.
        started = time.perf_counter()
        prediction = self.model_name2pred_func[model_name](prompt)
        elapsed = time.perf_counter() - started

        prediction["response_parsed"] = self._extract_answer_from_pred(
            prediction["response"]
        )
        prediction["time"] = elapsed
        return prediction

    def predict_with_api(self, prompt, model_name):
        """Send ``prompt`` as a single user message through ``litellm``.

        Returns a dict with the reply text under ``response`` and the value
        litellm exposes as ``response_cost`` under ``cost``.
        """
        response = litellm.completion(
            model=model_name,
            messages=[{"role": "user", "content": prompt}],
            num_retries=2,
        )
        response_text = response["choices"][0]["message"]["content"]
        return {
            "response": response_text,
            "cost": response._hidden_params["response_cost"],
        }

    @spaces.GPU
    def predict_with_hf(self, prompt, model_name):
        """Generate a reply with the local HF pipeline; cost is always 0.0."""
        self._reset_pipeline(model_name)
        outputs = self.pipeline([{"role": "user", "content": prompt}])
        # The last message of the first result's conversation is read as the reply.
        response = outputs[0]["generated_text"][-1]["content"]
        return {"response": response, "cost": 0.0}

    def _init_model_prediction(self, model_name):
        """Return a callable ``f(prompt)`` bound to the backend for ``model_name``.

        Raises:
            ValueError: If no known short name is a substring of ``model_name``.
        """
        for alias, backend_name, use_hf in self._MODEL_ROUTES:
            if alias in model_name:
                predict_fun = (
                    self.predict_with_hf if use_hf else self.predict_with_api
                )
                return partial(predict_fun, model_name=backend_name)
        raise ValueError("Model forbidden")