import os
import time
from datetime import datetime
import logging
from pathlib import Path
import requests
import json

import numpy as np
import pandas as pd
import spacy
import litellm
from tqdm import tqdm
from transformers import (
    AutoConfig,
    AutoModelForCausalLM,
    AutoModelForTokenClassification,
    AutoProcessor,
    AutoTokenizer,
    Qwen2VLForConditionalGeneration,
    pipeline,
)
from peft import PeftModel
import torch
import cohere
from openai import OpenAI
from together import Together
import anthropic
import replicate
import vertexai
from vertexai.generative_models import GenerativeModel, Part, SafetySetting, FinishReason
from mistralai import Mistral
from qwen_vl_utils import process_vision_info


import src.backend.util as util
import src.envs as envs

litellm.set_verbose = True

# Set up basic configuration for logging
logging.basicConfig(level=logging.INFO,
                    format='%(asctime)s - %(levelname)s - %(message)s')

# Load spacy model for word tokenization
nlp = spacy.load("en_core_web_sm")

os.environ["HUGGINGFACE_API_KEY"] =  envs.TOKEN

class ModelLoadingException(Exception):
    """Exception raised for errors in loading a model.

    Attributes:
        model_id (str): The model identifier.
        revision (str): The model revision.
    """

    def __init__(self, model_id, revision, message="Error initializing model"):
        self.model_id = model_id
        self.revision = revision
        super().__init__(f"{message} id={model_id} revision={revision}")
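

# Illustrative use (a sketch with hypothetical values; nothing in this module
# raises ModelLoadingException directly):
#
#     try:
#         model = AutoModelForCausalLM.from_pretrained(model_id, revision=revision)
#     except Exception as e:
#         raise ModelLoadingException(model_id, revision) from e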


class SummaryGenerator:
    """A class to generate summaries using a causal language model.

    Attributes:
        model (str): huggingface/{model_id}
        api_base (str): https://api-inference.huggingface.co/models/{model_id}
        summaries_df (DataFrame): DataFrame to store generated summaries.
        revision (str): Model revision.
        avg_length (float): Average length of summaries.
        answer_rate (float): Rate of non-empty summaries.
    """

    def __init__(self, model_id, revision, device):
        """
        Initializes the SummaryGenerator with a model.

        Args:
            model_id (str): Identifier for the model.
            revision (str): Revision of the model.
            device (str): Device to run local models on (e.g. "cuda" or "cpu").
        """
        self.model_id = model_id
        self.model = f"huggingface/{model_id}"
        self.api_base = f"https://api-inference.huggingface.co/models/{model_id}"
        self.summaries_df = pd.DataFrame()
        self.revision = revision
        self.device = device
        self.avg_length = None
        self.answer_rate = None
        self.exceptions = None
        self.local_model = None
        self.local_pipeline = None

    def generate_summaries(self, df, save_path=None):
        """Generate summaries for a given DataFrame of source docs.

        Args:
            df (DataFrame): DataFrame containing source docs.
            save_path (str, optional): CSV path to load previously generated
                summaries from, or to save newly generated summaries to.

        Returns:
            summaries_df (DataFrame): Generated summaries by the model.
        """
        exceptions = []
        if (save_path is not None) and os.path.exists(save_path):
            self.summaries_df = pd.read_csv(save_path)
            print(f'Loaded generated summaries from {save_path}')
        else:
            source, summary, dataset = [], [], []
            print(f"Total: {df.shape[0]}")
            for index, row in tqdm(df.iterrows(), total=df.shape[0]):
                _source = row['text']
                _dataset = row['dataset']

                system_prompt = envs.SYSTEM_PROMPT
                user_prompt = f"{envs.USER_PROMPT}\nPassage:\n{_source}"
                _summary = None

                while not _summary:
                    try:
                        _summary = self.generate_summary(system_prompt, user_prompt)
                        break
                    except Exception as e:
                        if 'Rate limit reached' in str(e):
                            wait_time = 300
                            current_time = datetime.now().strftime('%H:%M:%S')
                            print(f"Rate limit hit at {current_time}. Waiting 5 minutes before retrying...")
                            time.sleep(wait_time)
                        elif 'is currently loading' in str(e):
                            wait_time = 200
                            print(f"Model is still loading; waiting {wait_time}s before retrying...")
                            time.sleep(wait_time)
                        elif '429' in str(e):  # Gemini models return HTTP 429 when quota is exhausted
                            wait_time = 60
                            print(f"Quota exhausted; waiting {wait_time}s before retrying...")
                            time.sleep(wait_time)
                        else:
                            print(f"Error at index {index}: {e}")
                            _summary = ""
                            exceptions.append(index)
                            break

                summary.append(_summary)
                source.append(_source)
                dataset.append(_dataset)

                # Sleep to prevent hitting rate limits too frequently
                time.sleep(1)

            self.summaries_df = pd.DataFrame(list(zip(source, summary, dataset)),
                                            columns=["source", "summary", "dataset"])

            if save_path is not None:
                print(f'Save summaries to {save_path}')
                fpath = Path(save_path)
                fpath.parent.mkdir(parents=True, exist_ok=True)
                self.summaries_df.to_csv(fpath, index=False)

        self.exceptions = exceptions
        self._compute_avg_length()
        self._compute_answer_rate()

        return self.summaries_df
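
    # Usage sketch (illustrative; the model id, path, and DataFrame below are
    # hypothetical placeholders):
    #
    #     gen = SummaryGenerator("openai/gpt-4o", revision="main", device="cuda")
    #     df = pd.DataFrame({"text": ["Some passage ..."], "dataset": ["demo"]})
    #     gen.generate_summaries(df, save_path="generation_results/demo.csv")
    #     print(gen.avg_length, gen.answer_rate)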
    
    def generate_summary(self, system_prompt: str, user_prompt: str):
        # Route the request by substring match on the model id: Together AI,
        # Replicate, a local HF pipeline, or (below) one of the hosted APIs.
        using_together_api = False
        together_ai_api_models = ['mixtral', 'dbrx', 'wizardlm', 'llama-3-', 'qwen2-72b-instruct', 'zero-one-ai', 'llama-3.2-']  # , 'mistralai'
        using_replicate_api = False
        replicate_api_models = ['snowflake', 'llama-3.1-405b']
        using_pipeline = False
        pipeline_models = ['llama-3.1', 'phi-3-mini', 'falcon-7b', 'phi-3.5', 'mistral-nemo', 'llama-3.3', 'phi-4']

        for replicate_api_model in replicate_api_models:
            if replicate_api_model in self.model_id.lower():
                using_replicate_api = True
                break

        if not using_replicate_api:
            for together_ai_api_model in together_ai_api_models:
                if together_ai_api_model in self.model_id.lower():
                    using_together_api = True
                    break

        if not using_replicate_api and not using_together_api:
            for pipeline_model in pipeline_models:
                if pipeline_model in self.model_id.lower():
                    using_pipeline = True
                    break

        if using_together_api:
            print('using together api')
            client = Together(api_key=os.environ.get('TOGETHER_API_KEY'))
            if 'llama-3.2-90b-vision' in self.model_id.lower() or 'llama-3.2-11b-vision' in self.model_id.lower():
                messages = [
                        {"role": "system","content": system_prompt},
                        {"role": "user","content": [{"type": "text","text": user_prompt}]}
                ]
            else:
                messages = [{"role": "system", "content": system_prompt},
                        {"role": "user", "content": user_prompt}]
            response = client.chat.completions.create(
                model=self.model_id,
                messages=messages,
                max_tokens=250,
                temperature=0,
            )
            result = response.choices[0].message.content
            print(result)
            return result

        # Using OpenAI API
        elif 'openai' in self.model_id.lower():
            client = OpenAI()
            response = client.chat.completions.create(
                model=self.model_id.replace('openai/', ''),
                messages=[{"role": "system", "content": system_prompt},
                          {"role": "user", "content": user_prompt}] if 'gpt' in self.model_id.lower()
                         else [{"role": "user", "content": system_prompt + '\n' + user_prompt}],  # o1 models take no system role
                temperature=0.0 if 'gpt' in self.model_id.lower() else 1.0,  # temperature is fixed at 1 for o1 models
                # max_completion_tokens=250 if 'gpt' in self.model_id.lower() else None,  # not compatible with o1 series models
            )
            result = response.choices[0].message.content
            print(result)
            return result

        # Using Grok API
        elif 'grok' in self.model_id.lower(): # xai
            XAI_API_KEY = os.getenv("XAI_API_KEY")
            client = OpenAI(
                api_key=XAI_API_KEY,
                base_url="https://api.x.ai/v1",
            )

            completion = client.chat.completions.create(
                model=self.model_id.split('/')[-1],
                messages=[
                    {"role": "system", "content": system_prompt},
                    {"role": "user", "content": user_prompt},
                ],
                temperature=0.0
            )
            result = completion.choices[0].message.content
            print(result)
            return result

        # Using Vertex AI API for Gemini models
        elif 'gemini' in self.model_id.lower():
            vertexai.init(project=os.getenv("GOOGLE_PROJECT_ID"), location="us-central1")
            model = GenerativeModel(
                self.model_id.lower().split('google/')[-1],
                system_instruction=[system_prompt]
            )
            generation_config = {
                "temperature": 0,
                "max_output_tokens": 500
            }
            safety_settings = [
                SafetySetting(
                    category=SafetySetting.HarmCategory.HARM_CATEGORY_HATE_SPEECH,
                    threshold=SafetySetting.HarmBlockThreshold.BLOCK_NONE
                ),
                SafetySetting(
                    category=SafetySetting.HarmCategory.HARM_CATEGORY_DANGEROUS_CONTENT,
                    threshold=SafetySetting.HarmBlockThreshold.BLOCK_NONE
                ),
                SafetySetting(
                    category=SafetySetting.HarmCategory.HARM_CATEGORY_SEXUALLY_EXPLICIT,
                    threshold=SafetySetting.HarmBlockThreshold.BLOCK_NONE
                ),
                SafetySetting(
                    category=SafetySetting.HarmCategory.HARM_CATEGORY_HARASSMENT,
                    threshold=SafetySetting.HarmBlockThreshold.BLOCK_NONE
                )
            ]
            response = model.generate_content(
                user_prompt,
                safety_settings=safety_settings,
                generation_config=generation_config
            )
            result = response.text
            print(result)
            return result
        
        # Using Replicate API
        elif using_replicate_api:
            print("using replicate")
            if 'snowflake' in self.model_id.lower():
                # named model_input to avoid shadowing the built-in input()
                model_input = {
                    "prompt": user_prompt,
                    "temperature": 0,
                    "max_new_tokens": 250,
                    "stop_sequences": "<|im_end|>",
                    "prompt_template": f"<|im_start|>system\n{system_prompt}<|im_end|>\n" + "<|im_start|>user\n{prompt}<|im_end|>\n\n<|im_start|>assistant\n",
                }
            else:
                model_input = {
                    "prompt": user_prompt,
                    "system_prompt": system_prompt,
                    "temperature": 0,
                    "max_new_tokens": 250
                }
            response = replicate.run(
                self.model_id,
                input=model_input
            )
            # Replicate may return the output as a list of string chunks
            if isinstance(response, list):
                response = ''.join(response)
            print(response)
            return response

        # Using Anthropic API for Claude models
        elif 'claude' in self.model_id.lower():
            print('using Anthropic API')
            client = anthropic.Anthropic()
            message = client.messages.create(
                model=self.model_id.split('/')[-1],
                max_tokens=1024,
                temperature=0,
                system=system_prompt,
                messages=[
                    {"role": "user", "content": user_prompt}
                ]
            )
            result = message.content[0].text
            print(result)
            return result

        # Using Cohere API
        elif 'command-r' in self.model_id.lower() or 'aya-expanse' in self.model_id.lower():
            co = cohere.ClientV2(os.getenv('COHERE_API_TOKEN'))
            response = co.chat(
                model=self.model_id.split('/')[-1],
                messages=[
                    {"role": "system", "content": system_prompt},
                    {"role": "user", "content": user_prompt}
                ],
                temperature=0,
            )
            result = response.message.content[0].text
            print(result)
            return result

        # Using MistralAI API
        elif 'mistral-large' in self.model_id.lower():
            api_key = os.environ["MISTRAL_API_KEY"]
            client = Mistral(api_key=api_key)

            messages = [
                {"role": "system", "content": system_prompt},
                {"role": "user", "content": user_prompt}
            ]

            # No streaming
            chat_response = client.chat.complete(
                model=self.model_id,
                messages=messages,
            )
            result = chat_response.choices[0].message.content
            print(result)
            return result

        # Using Deepseek API
        elif 'deepseek' in self.model_id.lower():
            client = OpenAI(api_key=os.getenv("DeepSeek_API_KEY"), base_url="https://api.deepseek.com")
            response = client.chat.completions.create(
                model=self.model_id.split('/')[-1],
                messages=[
                    {"role": "system", "content": system_prompt},
                    {"role": "user", "content": user_prompt},
                ],
                max_tokens=250,
                temperature=0,
                stream=False
            )
            result = response.choices[0].message.content
            print(result)
            return result
        
        # Using HF pipeline or local checkpoints
        elif self.local_model is None and self.local_pipeline is None:
            if using_pipeline:
                self.local_pipeline = pipeline(
                    "text-generation",
                    model=self.model_id,
                    tokenizer=AutoTokenizer.from_pretrained(self.model_id),
                    torch_dtype=torch.bfloat16 if 'llama-3.2' in self.model_id.lower() or 'llama-3.3' in self.model_id.lower() else "auto",
                    device_map="auto",
                    trust_remote_code=True
                )
            else:
                if 'ragamuffin' in self.model_id.lower():
                    # Ragamuffin checkpoints are read from local disk
                    self.tokenizer = AutoTokenizer.from_pretrained(os.path.join('/home/miaoran', self.model_id))
                else:
                    # OpenELM ships without a tokenizer, so reuse Llama-2's
                    self.tokenizer = AutoTokenizer.from_pretrained("meta-llama/Llama-2-7b-hf" if 'openelm' in self.model_id.lower() else self.model_id, trust_remote_code=True)
                print("Tokenizer loaded")

                if 'jamba' in self.model_id.lower():
                    self.local_model = AutoModelForCausalLM.from_pretrained(self.model_id,
                                             torch_dtype=torch.bfloat16,
                                             attn_implementation="flash_attention_2",
                                             device_map="auto",
                                             use_mamba_kernels=False)

                elif 'qwen2-vl' in self.model_id.lower():
                    self.local_model = Qwen2VLForConditionalGeneration.from_pretrained(
                        self.model_id, torch_dtype="auto", device_map="auto"
                    )
                    self.processor = AutoProcessor.from_pretrained(self.model_id)

                elif 'olmo' in self.model_id.lower():
                    self.local_model = AutoModelForCausalLM.from_pretrained(self.model_id)

                elif 'qwq-' in self.model_id.lower():
                    self.local_model = AutoModelForCausalLM.from_pretrained(self.model_id, torch_dtype="auto", device_map="auto")

                else:
                    self.local_model = AutoModelForCausalLM.from_pretrained(self.model_id, trust_remote_code=True, device_map="auto")
                print("Local model loaded")
                    
        # Using local model/pipeline
        if self.local_pipeline:
            print('Using Transformers pipeline')
            messages=[
                {"role": "system", "content": system_prompt}, 
                {"role": "user", "content": user_prompt}
            ]
            outputs = self.local_pipeline(
                messages,
                max_new_tokens=256,
                # return_full_text=False,
                do_sample=False
            )
            result = outputs[0]["generated_text"][-1]['content']
            print(result)
            return result

        elif self.local_model:  # no hosted API matched; run the local checkpoint
            print('Using local model')

            # Set appropriate prompt based on model document
            if 'gemma' in self.model_id.lower() or 'mistral-7b' in self.model_id.lower():
                messages=[
                    # gemma-1.1, mistral-7b does not accept system role
                    {"role": "user", "content": system_prompt + '\n' + user_prompt}
                ]
                prompt = self.tokenizer.apply_chat_template(messages, add_generation_prompt=True, tokenize=False)
            
            elif 'phi-2' in self.model_id.lower():
                prompt = system_prompt + '\n' + user_prompt

            elif 'intel' in self.model_id.lower():
                prompt = f"### System:\n{system_prompt}\n### User:\n{user_prompt}\n### Assistant:\n"

            elif 'qwen2-vl' in self.model_id.lower():
                messages = [
                    {
                        "role": "system",
                        "content": [{"type": "text", "text": system_prompt}]
                    },
                    {
                        "role": "user",
                        "content": [{"type": "text", "text": user_prompt}]
                    }
                ]
                # Qwen2-VL builds its prompt through the processor, not the tokenizer
                prompt = self.processor.apply_chat_template(messages, add_generation_prompt=True, tokenize=False)
            else:
                messages = [
                    {"role": "system", "content": system_prompt},
                    {"role": "user", "content": user_prompt}
                ]
                prompt = self.tokenizer.apply_chat_template(messages, add_generation_prompt=True, tokenize=False)
           
            # Tokenize inputs (Qwen2-VL goes through its processor instead)
            if 'olmo' in self.model_id.lower():
                input_ids = self.tokenizer([prompt], return_tensors='pt', return_token_type_ids=False)  # OLMo is loaded without device_map, so inputs stay on CPU
            elif 'qwen2-vl' in self.model_id.lower():
                input_ids = self.processor(text=[prompt], return_tensors="pt").to(self.device)
            elif 'qwq' in self.model_id.lower():
                input_ids = self.tokenizer([prompt], return_tensors="pt").to(self.device)
            else:
                input_ids = self.tokenizer(prompt, return_tensors="pt").to(self.device)
            
            # Generate outputs
            if 'granite' in self.model_id.lower():
                self.local_model.eval()
                outputs = self.local_model.generate(**input_ids, max_new_tokens=250)
            elif 'olmo' in self.model_id.lower():
                outputs = self.local_model.generate(**input_ids, max_new_tokens=250, do_sample=True, temperature=0.01)
            elif 'qwq' in self.model_id.lower():
                outputs = self.local_model.generate(**input_ids, max_new_tokens=512, do_sample=True, temperature=0.01)
            else:
                with torch.no_grad():
                    outputs = self.local_model.generate(**input_ids, do_sample=True, max_new_tokens=250, temperature=0.01)
            if 'glm' in self.model_id.lower() or 'ragamuffin' in self.model_id.lower() or 'granite' in self.model_id.lower():
                outputs = outputs[:, input_ids['input_ids'].shape[1]:]
            elif 'qwen2-vl' in self.model_id.lower() or 'qwen2.5' in self.model_id.lower() or 'qwq-' in self.model_id.lower():
                outputs = [
                    out_ids[len(in_ids) :] for in_ids, out_ids in zip(input_ids.input_ids, outputs)
                ]
            
            # Decode outputs
            if 'qwen2-vl' in self.model_id.lower():
                result = self.processor.batch_decode(
                    outputs, skip_special_tokens=True, clean_up_tokenization_spaces=False
                )[0]
            elif 'olmo' in self.model_id.lower() or 'qwq' in self.model_id.lower():
                result = self.tokenizer.batch_decode(outputs, skip_special_tokens=True)[0]
            else:
                result = self.tokenizer.decode(outputs[0], skip_special_tokens=True)

            if 'gemma-2' in self.model_id.lower():
                result = result.split(user_prompt + '\nmodel')[-1].strip()
            elif 'intel' in self.model_id.lower():
                result = result.split("### Assistant:\n")[-1]
            elif 'jamba' in self.model_id.lower():
                result = result.split(messages[-1]['content'])[-1].strip()
            elif 'qwen2-vl' in self.model_id.lower() or 'qwen2.5' in self.model_id.lower():
                pass
            elif 'olmo' in self.model_id.lower():
                result = result.split("<|assistant|>\n")[-1]
            else:
                result = result.replace(prompt.strip(), '')
            
            print(result)
            return result

    def _compute_avg_length(self):
        """
        Compute the average length (in words) of non-empty summaries using spaCy.
        """
        total_word_count = 0
        total_count = 0

        for summary in self.summaries_df['summary']:
            if util.is_summary_valid(summary):
                doc = nlp(summary)
                words = [token.text for token in doc if token.is_alpha]
                total_word_count += len(words)
                total_count += 1

        self.avg_length = 0 if total_count == 0 else total_word_count / total_count
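
        # E.g. nlp("The cat sat.") tokenizes to ["The", "cat", "sat", "."];
        # only the three alphabetic tokens count toward the average.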

    def _compute_answer_rate(self):
        """
        Compute the rate of non-empty summaries.
        """
        valid_count = sum(1 for summary in self.summaries_df['summary']
                            if util.is_summary_valid(summary))

        total_count = len(self.summaries_df)

        self.answer_rate = 0 if total_count == 0 else valid_count / total_count


class EvaluationModel:
    """A class to evaluate generated summaries.

    Attributes:
        model: The HHEM evaluation model (a token-classification head on a
            flan-t5-large backbone).
        scores (list): List of evaluation scores.
        factual_consistency_rate (float): Percentage of summaries judged factually consistent.
        hallucination_rate (float): Rate of hallucination in summaries.
    """

    def __init__(self, model_path, device):
        """
        Initializes the EvaluationModel.

        Args:
            model_path (str): Path or Hugging Face ID of the HHEM model.
            device (str): Device to run the model on.
        """
        config = AutoConfig.from_pretrained('google/flan-t5-large')
        self.model = AutoModelForTokenClassification.from_pretrained(model_path, config=config)
        self.device = device
        self.model.to(self.device)
        self.scores = []
        self.factual_consistency_rate = None
        self.hallucination_rate = None
    
    def predict(self, text_pairs):
        """Run HHEM predictions on (premise, hypothesis) pairs.

        All HHEM 2.1 settings, e.g., the prompt template, are hardcoded in this
        function.

        Args:
            text_pairs: list of tuples, each containing two strings
                (premise, hypothesis).

        Returns:
            list of float: for each pair, the probability that the hypothesis
            is consistent with the premise.
        """

        prompt = "<pad> Determine if the hypothesis is true given the premise?\n\nPremise: {text1}\n\nHypothesis: {text2}"
        
        tokenizer = AutoTokenizer.from_pretrained('t5-base')
        inputs = tokenizer(
            [prompt.format(text1=pair[0], text2=pair[1]) for pair in text_pairs], 
            return_tensors='pt', padding='longest').to(self.device)
        
        self.model.eval()
        with torch.no_grad():
            output = self.model(**inputs)
        logits = output.logits
        logits = logits[:,0,:] # get the logits on the first token
        logits = torch.softmax(logits, dim=-1)
        scores = [round(x, 5) for x in logits[:, 1].tolist()] # list of float
        return scores
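
    # Example call (hypothetical pair and score; values near 1 mean the
    # hypothesis is supported by the premise):
    #
    #     model.predict([("The capital of France is Paris.",
    #                     "Paris is in France.")])
    #     # -> e.g. [0.97]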

    def evaluate_hallucination(self, summaries_df):
        """
        Evaluate the hallucination rate in summaries. Updates the 'scores'
        attribute of the instance with the computed scores.

        Args:
            summaries_df (DataFrame): DataFrame containing source docs and summaries.

        Returns:
            tuple: (list of HEM scores, dict with 'source', 'summary', and
            'HEM scores' keys). Also updates the 'scores' attribute of the instance.
        """
        hem_scores = []
        sources = []
        summaries = []
        source_summary_pairs = util.create_pairs(summaries_df)

        for doc, summary in source_summary_pairs:
            if util.is_summary_valid(summary):
                try:
                    summary = util.normalize_summary(summary)
                    score = self.predict([(doc, summary)])[0]
                    hem_scores.append(score)
                    sources.append(doc)
                    summaries.append(summary)
                    if score < 0.5:  # print likely-hallucinated cases for inspection
                        print(score)
                        print(doc)
                        print('-'*20)
                        print(summary)
                        print('='*50)
                except Exception as e:
                    logging.error(f"Error while running HEM: {e}")
                    raise

        self.scores = hem_scores
        eval_results = {'source': sources, 'summary': summaries, 'HEM scores': hem_scores}
        return hem_scores, eval_results


    def compute_factual_consistency_rate(self, threshold=0.5):
        """
        Compute the factual consistency rate of the evaluated summaries based on
        the previously calculated scores. This method relies on the 'scores'
        attribute being populated, typically via the 'evaluate_hallucination' method.

        Args:
            threshold (float): Minimum score for a summary to count as factually
                consistent (default 0.5).

        Returns:
            float: Factual Consistency Rate. Also updates the 'factual_consistency_rate'
            and 'hallucination_rate' attributes of the instance.

        Raises:
            ValueError: If scores have not been calculated prior to calling this method.
        """
        if not self.scores:
            error_msg = "Scores not calculated. Call evaluate_hallucination() first."
            logging.error(error_msg)
            raise ValueError(error_msg)

        # Summaries scoring at or above the threshold count as factually consistent
        num_above_threshold = sum(score >= threshold for score in self.scores)
        num_total = len(self.scores)

        if not num_total:
            raise ValueError("No scores available to compute factual consistency rate.")

        self.factual_consistency_rate = (num_above_threshold / num_total) * 100
        self.hallucination_rate = 100 - self.factual_consistency_rate

        return self.factual_consistency_rate
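

# Minimal end-to-end sketch tying the two classes together. The model ids and
# file paths below are hypothetical placeholders, not values confirmed by this
# repo; the CSV is expected to provide 'text' and 'dataset' columns.
if __name__ == "__main__":
    source_df = pd.read_csv("data/source_docs.csv")  # hypothetical path
    generator = SummaryGenerator("openai/gpt-4o-mini", revision="main", device="cuda")
    summaries_df = generator.generate_summaries(source_df, save_path="generation_results/gpt-4o-mini.csv")

    evaluator = EvaluationModel("vectara/hallucination_evaluation_model", device="cuda")  # assumed HHEM checkpoint id
    scores, eval_results = evaluator.evaluate_hallucination(summaries_df)
    print(f"Factual consistency rate: {evaluator.compute_factual_consistency_rate():.2f}%")
    print(f"Hallucination rate: {evaluator.hallucination_rate:.2f}%")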