from functools import partial
import os
import re
import time

# https://discuss.huggingface.co/t/issues-with-sadtalker-zerogpu-spaces-inquiry-about-community-grant/110625/10
if os.environ.get("SPACES_ZERO_GPU") is not None:
    import spaces
else:
    # Outside ZeroGPU Spaces, stand in for `spaces` with a no-op GPU decorator.
    class spaces:
        @staticmethod
        def GPU(func):
            def wrapper(*args, **kwargs):
                return func(*args, **kwargs)

            return wrapper


from transformers import pipeline as hf_pipeline
import torch
import litellm
from tqdm import tqdm
import subprocess

# https://huggingface.co/spaces/zero-gpu-explorers/README/discussions/132
# subprocess.run("rm -rf /data-nvme/zerogpu-offload/*", env={}, shell=True)

# pipeline = hf_pipeline(
#     "text-generation",
#     model="meta-llama/Meta-Llama-3.1-8B-Instruct",
#     model_kwargs={"torch_dtype": 'bfloat16'},
#     device_map="auto",
# )


class ModelPrediction:
    def __init__(self):
        # Map each supported model alias to a prediction function bound to the
        # provider-qualified litellm model id.
        self.model_name2pred_func = {
            "gpt-3.5": self._init_model_prediction("gpt-3.5"),
            "gpt-4o-mini": self._init_model_prediction("gpt-4o-mini"),
            "llama-70": self._init_model_prediction("llama-70"),
            "DeepSeek-R1-Distill-Llama-70B": self._init_model_prediction(
                "DeepSeek-R1-Distill-Llama-70B"
            ),
            "llama-8": self._init_model_prediction("llama-8"),
        }

        self._model_name = None
        self._pipeline = None
        # Default prompt for semantic parsing (SP): question -> SQL.
        self.base_prompt = (
            "Translate the following question into SQL code to be executed over the database to fetch the answer. Return the SQL code in ```sql ```\n"
            "Question\n"
            "{question}\n"
            "Database Schema\n"
            "{db_schema}\n"
        )
        # Default prompt for question answering (QA): question -> query result.
        self.base_prompt_QA = (
            "Return the answer to the following question based on the provided database."
            " Return your answer as the result of a query executed over the database,"
            " namely as a list of lists, where each inner list is one tuple and contains that tuple's values.\n"
            "Return the answer in an answer tag as <answer> </answer>\n"
            "Question\n"
            "{question}\n"
            "Database Schema\n"
            "{db_schema}\n"
        )

    @staticmethod
    def _extract_answer_from_pred(pred: str) -> str:
        # Extract with regex everything between <answer> and </answer>; if no
        # tag is present, fall back to the last ```sql ``` block, then to the
        # raw prediction.
        matches = re.findall(r"<answer>(.*?)</answer>", pred, re.DOTALL)
        if matches:
            return matches[-1].replace("```", "").replace("sql", "").strip()
        else:
            matches = re.findall(r"```sql(.*?)```", pred, re.DOTALL)
            return matches[-1].strip() if matches else pred

    def make_prediction(self, question, db_schema, model_name, prompt=None, task="SP"):
        if model_name not in self.model_name2pred_func:
            raise ValueError(
                "Model not supported",
                "supported models are",
                self.model_name2pred_func.keys(),
            )
        if task == "SP":
            prompt = prompt or self.base_prompt
        else:
            prompt = prompt or self.base_prompt_QA

        prompt = prompt.format(question=question, db_schema=db_schema)
        start_time = time.time()
        prediction = self.model_name2pred_func[model_name](prompt)
        end_time = time.time()
        prediction["response_parsed"] = self._extract_answer_from_pred(
            prediction["response"]
        )
        prediction["time"] = end_time - start_time
        return prediction

    def predict_with_api(self, prompt, model_name):  # -> dict[str, Any | float]
        response = litellm.completion(
            model=model_name,
            messages=[{"role": "user", "content": prompt}],
            num_retries=2,
        )
        response_text = response["choices"][0]["message"]["content"]
        return {
            "response": response_text,
            "cost": response._hidden_params["response_cost"],
        }

    def _init_model_prediction(self, model_name):
        # Resolve the short alias to the provider-qualified litellm model id
        # and return a prediction function bound to it.
        predict_fun = self.predict_with_api
        if "gpt-3.5" in model_name:
            model_name = "openai/gpt-3.5-turbo-0125"
        elif "gpt-4o-mini" in model_name:
            model_name = "openai/gpt-4o-mini-2024-07-18"
        elif "DeepSeek-R1-Distill-Llama-70B" in model_name:
            model_name = "together_ai/deepseek-ai/DeepSeek-R1-Distill-Llama-70B"
        elif "llama-8" in model_name:
            model_name = "together_ai/meta-llama/Meta-Llama-3.1-8B-Instruct-Turbo"
        elif "llama-70" in model_name:
            model_name = "together_ai/meta-llama/Meta-Llama-3.1-70B-Instruct-Turbo"
        else:
            raise ValueError("Model not supported", model_name)

        return partial(predict_fun, model_name=model_name)