import os
import gc
import torch
import torch.nn as nn
import argparse
import gradio as gr
import time
from transformers import AutoTokenizer, LlamaForCausalLM
from utils import SteamGenerationMixin
import requests

# Access token handed to `from_pretrained(..., use_auth_token=...)`;
# taken from the "Zimix" environment variable and None when it is unset.
auth_token = os.environ.get("Zimix")
# Path component of the reporting endpoint, taken from the "api_url"
# environment variable (None when unset, which is then formatted as "None").
url_api = os.environ.get('api_url')
# Full address that `cc` sends each query/response pair to.
URL = f'http://120.234.0.81:8808/{url_api}'
def cc(q, r, timeout=5):
    """Report one query/response pair to the endpoint at ``URL`` (best effort).

    Sends an HTTP GET whose query parameters are ``query``, ``response`` and
    the current ``time.time()`` timestamp. The HTTP response is not inspected.

    Args:
        q: The query text to report.
        r: The response text to report.
        timeout: Seconds to wait for the server before giving up. Without a
            timeout ``requests`` waits indefinitely, which would stall the
            caller if the endpoint is unreachable.

    Returns:
        None. Failures are reported on stdout and never propagated.
    """
    try:
        requests.get(
            URL,
            params={'query': q, 'response': r, 'time': time.time()},
            timeout=timeout,
        )
    except Exception:
        # Reporting is deliberately best-effort, so any failure is swallowed.
        # `Exception` (rather than a bare `except`) keeps KeyboardInterrupt
        # and SystemExit propagating.
        print('推送失败-_- !')


class MindBot(object):
    """Chat bot built on a causal LM, with blocking and streaming generation.

    The model is loaded through ``SteamGenerationMixin.from_pretrained`` and
    the tokenizer through ``AutoTokenizer.from_pretrained``. ``self.history``
    stores ``(instruction, output)`` pairs for the blocking API
    (`common_generate`); the streaming API (`interaction`) works on the
    history list handed in by the caller instead.
    """

    # Sampling settings shared by `common_generate` and `interaction`.
    # NOTE(review): the hard-coded eos/bos/pad ids presumably match the
    # tokenizer of the checkpoint being served -- confirm when swapping models.
    _GENERATION_KWARGS = dict(
        max_new_tokens=1024,
        do_sample=True,
        top_p=0.85,
        temperature=0.8,
        repetition_penalty=1.,
        eos_token_id=2,
        bos_token_id=1,
        pad_token_id=0,
    )

    def __init__(self, model_path, tokenizer_path=None, if_int8=False):
        """Load the model and tokenizer.

        Args:
            model_path: Path or hub id passed to
                ``SteamGenerationMixin.from_pretrained``.
            tokenizer_path: Path or hub id passed to
                ``AutoTokenizer.from_pretrained``. Defaults to ``model_path``
                when omitted, so ``MindBot(model_path)`` works.
            if_int8: When true, load with ``load_in_8bit=True``; otherwise the
                model is converted with ``.half()``.
        """
        if tokenizer_path is None:
            tokenizer_path = model_path

        if if_int8:
            self.model = SteamGenerationMixin.from_pretrained(
                model_path,
                device_map='auto',
                load_in_8bit=True,
                use_auth_token=auth_token,
            ).eval()
        else:
            self.model = SteamGenerationMixin.from_pretrained(
                model_path,
                device_map='auto',
                use_auth_token=auth_token,
            ).half().eval()

        self.tokenizer = AutoTokenizer.from_pretrained(tokenizer_path, use_auth_token=auth_token)
        self.history = []

    def build_prompt(self, instruction, history, human='<human>', bot='<bot>'):
        """Render past turns plus the new instruction as one prompt string.

        Args:
            instruction: The new user message; surrounding whitespace is
                stripped.
            history: Iterable of ``(user_text, bot_text)`` pairs, oldest
                first. ``None`` is treated as an empty history.
            human: Speaker tag placed before each user message.
            bot: Speaker tag placed before each bot message.

        Returns:
            The prompt, ending with an open ``bot`` turn for the model to
            complete.
        """
        pieces = [
            f'{human}: {line[0].strip()}\n{bot}: {line[1]}\n'
            for line in (history or ())
        ]
        pieces.append(f'{human}: {instruction.strip()}\n{bot}: \n')
        return ''.join(pieces)

    def _encode_prompt(self, prompt, max_memory):
        """Tokenize ``prompt`` and keep at most its last ``max_memory`` tokens."""
        input_ids = self.tokenizer(prompt, return_tensors="pt").input_ids
        if input_ids.shape[1] > max_memory:
            # Keep the most recent context; the oldest tokens are dropped.
            input_ids = input_ids[:, -max_memory:]
        return input_ids

    def common_generate(self, instruction, clear_history=False, max_memory=1024):
        """Generate a complete reply in one call, using ``self.history``.

        Args:
            instruction: The user message.
            clear_history: When true, ``self.history`` is reset before the
                prompt is built.
            max_memory: Maximum number of prompt tokens fed to the model.

        Returns:
            The decoded reply, with every "Belle" replaced by "IDEA". The
            ``(instruction, reply)`` pair is also appended to
            ``self.history``.
        """
        if clear_history:
            self.history = []

        prompt = self.build_prompt(instruction, self.history)
        input_ids = self._encode_prompt(prompt, max_memory)
        prompt_len = input_ids.shape[1]

        generation_output = self.model.generate(
            input_ids.cuda(),
            **self._GENERATION_KWARGS
        )

        # Drop the prompt tokens so that only the newly generated part is decoded.
        s = generation_output[0][prompt_len:]
        output = self.tokenizer.decode(s, skip_special_tokens=True)
        output = output.replace("Belle", "IDEA")
        self.history.append((instruction, output))
        print('api history: ======> \n', self.history)

        return output

    def interaction(
        self,
        instruction,
        history,
        max_memory=1024
    ):
        """Stream a reply for the Gradio chat UI.

        This is a generator. Each item is a ``('', chat)`` pair: the empty
        string is meant for the input textbox and ``chat`` is the history
        including the reply generated so far (newlines shown as ``<br>``).

        Args:
            instruction: The user message.
            history: List of ``(user_text, bot_text)`` pairs. ``None`` is
                treated as an empty history. On success the finished turn is
                appended to this list.
            max_memory: Maximum number of prompt tokens fed to the model.

        Yields:
            ``('', chat)`` after every streaming step, then a final
            ``('', history)``. The final item guarantees the consumer receives
            a result even if nothing was streamed (for example when CUDA runs
            out of memory before the first token); in that case the partial
            turn is not part of ``history``.
        """
        if history is None:
            history = []

        prompt = self.build_prompt(instruction, history)
        input_ids = self._encode_prompt(prompt, max_memory)
        prompt_len = input_ids.shape[1]

        # stream generation method
        try:
            output = ''
            with torch.no_grad():
                for generation_output in self.model.stream_generate(
                    input_ids.cuda(),
                    **self._GENERATION_KWARGS
                ):
                    s = generation_output[0][prompt_len:]
                    output = self.tokenizer.decode(s, skip_special_tokens=True)
                    output = output.replace('\n', '<br>')
                    # Yield a fresh list so the consumer never observes a
                    # list that is mutated after it was handed over.
                    yield '', history + [(instruction, output)]
                history.append((instruction, output))
                print('input -----> \n', prompt)
                print('output -------> \n', output)
                print('history: ======> \n', history)
                cc(prompt, output)
        except torch.cuda.OutOfMemoryError:
            # Release as much memory as possible so later requests can succeed.
            gc.collect()
            torch.cuda.empty_cache()
            # NOTE(review): `empty_cache` is assumed to be provided by
            # SteamGenerationMixin -- confirm it exists on the loaded model.
            self.model.empty_cache()
        # A `return value` inside a generator is not seen by code that simply
        # iterates it, so the final state is yielded instead.
        yield "", history

    def new_chat_bot(self):
        """Build the Gradio Blocks UI and return it with queuing enabled.

        Both pressing Enter in the textbox and clicking "Submit" run
        `interaction`; "Clear" resets the chatbot component.

        Returns:
            The result of ``demo.queue(concurrency_count=5)``.
        """
        with gr.Blocks(title='IDEA Ziya', css=".gradio-container {max-width: 50% !important;} .bgcolor {color: white !important; background: #FFA500 !important;}") as demo:
            gr.Markdown("<center><h1>IDEA Ziya</h1></center>")
            gr.Markdown("<center>本页面基于hugging face支持的设备搭建 模型版本v1.1</center>")
            with gr.Row():
                chatbot = gr.Chatbot(label='Ziya').style(height=500)
            with gr.Row():
                msg = gr.Textbox(label="Input")
            with gr.Row():
                with gr.Column(scale=0.5):
                    clear = gr.Button("Clear")
                with gr.Column(scale=0.5):
                    submit = gr.Button("Submit", elem_classes='bgcolor')

            msg.submit(self.interaction, [msg, chatbot], [msg, chatbot])
            clear.click(lambda: None, None, chatbot, queue=False)
            submit.click(self.interaction, [msg, chatbot], [msg, chatbot])
        return demo.queue(concurrency_count=5)
        

if __name__ == '__main__':
    # Script entry point: parse the checkpoint locations, load the bot and
    # build the Gradio demo.
    parser = argparse.ArgumentParser()
    parser.add_argument(
        "--model_path",
        type=str,
        default="/cognitive_comp/songchao/checkpoints/global_step3200-hf"
    )
    parser.add_argument(
        "--tokenizer_path",
        type=str,
        default=None,
        help="Tokenizer location; defaults to --model_path when omitted."
    )
    args = parser.parse_args()

    # MindBot.__init__ declares tokenizer_path as a required positional
    # argument, so it must be passed explicitly; calling MindBot with only the
    # model path raises TypeError.
    tokenizer_path = args.tokenizer_path or args.model_path
    mind_bot = MindBot(args.model_path, tokenizer_path)
    # NOTE(review): the demo is built here but not launched within these
    # lines -- confirm that `demo.launch()` is called afterwards.
    demo = mind_bot.new_chat_bot()