#!/usr/bin/python3
# -*- coding: utf-8 -*-
import argparse
from collections import defaultdict
import json
import os
import platform
import re
import string
from typing import List

from project_settings import project_path

# Point the Hugging Face hub cache at the project-local "cache/huggingface/hub"
# directory. This assignment deliberately sits before the gradio/transformers
# imports below — presumably so the hub library sees it when it is first
# imported (verify before moving it below those imports).
os.environ["HUGGINGFACE_HUB_CACHE"] = (project_path / "cache/huggingface/hub").as_posix()

import gradio as gr
from threading import Thread
from transformers.models.gpt2.modeling_gpt2 import GPT2LMHeadModel
from transformers.models.bert.tokenization_bert import BertTokenizer
from transformers.generation.streamers import TextIteratorStreamer
import torch


def get_args():
    """Parse the demo's command-line options from ``sys.argv``.

    Returns:
        argparse.Namespace with ``max_new_tokens``, ``top_p``, ``temperature``,
        ``repetition_penalty``, ``device`` (defaults to "cuda" when a GPU is
        visible, else "cpu") and ``examples_json_file``.
    """
    cli = argparse.ArgumentParser()

    # Sampling options, registered in a fixed order: (flag, default, type).
    sampling_options = (
        ("--max_new_tokens", 512, int),
        ("--top_p", 0.9, float),
        ("--temperature", 0.35, float),
        ("--repetition_penalty", 1.0, float),
    )
    for flag, default_value, value_type in sampling_options:
        cli.add_argument(flag, default=default_value, type=value_type)

    default_device = "cpu"
    if torch.cuda.is_available():
        default_device = "cuda"
    cli.add_argument("--device", default=default_device, type=str)

    cli.add_argument("--examples_json_file", default="examples.json", type=str)
    return cli.parse_args()


def repl1(match):
    """``re.sub`` callback: glue capture groups 1 and 2, dropping what lay between them."""
    left, right = match.group(1, 2)
    return f"{left}{right}"


def repl2(match):
    """``re.sub`` callback: keep only capture group 1 of the match."""
    kept = match.group(1)
    return f"{kept}"


# A space survives only if the text emitted so far ends with one of these
# characters (ASCII letters, digits, punctuation — backslash included) ...
_SPACE_KEEPING_TAIL_CHARS = frozenset(string.ascii_letters + string.digits + string.punctuation)
# ... and the following piece starts with one of these (ASCII letters, digits).
_SPACE_KEEPING_HEAD_CHARS = frozenset(string.ascii_letters + string.digits)


def remove_space_between_cn_en(text):
    """Remove the single spaces in ``text`` except those separating ASCII words.

    ``text`` is split on single spaces. The empty pieces produced by leading
    or repeated spaces are discarded, and the rest are re-joined. A space is
    kept between two pieces only when the text emitted so far ends with an
    ASCII letter, digit or punctuation character *and* the next piece starts
    with an ASCII letter or digit. Everywhere else (e.g. around CJK
    characters) the pieces are glued together. Only the final emitted
    character is inspected, so no space is ever kept directly after a
    newline.

    Args:
        text: the string to clean up.

    Returns:
        The cleaned string. A trailing space in ``text`` is preserved, and a
        ``text`` that contains no space at all is returned unchanged.
    """
    pieces = text.split(" ")
    if len(pieces) < 2:
        return text

    parts = []
    for piece in pieces:
        if not piece:
            continue
        # parts[-1] is always a non-empty piece (a " " is only ever appended
        # right before a piece), so parts[-1][-1] is the last emitted char.
        if (parts
                and parts[-1][-1] in _SPACE_KEEPING_TAIL_CHARS
                and piece[0] in _SPACE_KEEPING_HEAD_CHARS):
            parts.append(" ")
        parts.append(piece)

    result = "".join(parts)
    if text.endswith(" "):
        result += " "
    return result


def main():
    """Build and launch the Gradio streaming demo for the GPT2 chat/generation models.

    Reads options from ``get_args()``, loads the example inputs from
    ``args.examples_json_file``, and serves ``fn_stream`` (defined below)
    through a queued ``gr.Interface``.

    NOTE(review): ``--max_new_tokens``, ``--top_p``, ``--temperature`` and
    ``--repetition_penalty`` are parsed by ``get_args`` but not used here; the
    slider defaults in the interface are hard-coded.
    """
    args = get_args()

    description = """
    ## GPT2 Chat
    """

    # example json
    with open(args.examples_json_file, "r", encoding="utf-8") as f:
        examples = json.load(f)

    if args.device == 'auto':
        device = 'cuda' if torch.cuda.is_available() else 'cpu'
    else:
        device = args.device

    input_text_box = gr.Text(label="text")
    output_text_box = gr.Text(lines=4, label="generated_content")

    # (tokenizer, model) pairs already loaded, keyed by model name. Without this
    # every request re-loaded the weights; note that every model requested so
    # far stays resident in memory.
    loaded_models = dict()

    def load_model(model_name: str):
        """Return the ``(tokenizer, model)`` pair for ``model_name``, loading it on first use."""
        if model_name not in loaded_models:
            tokenizer = BertTokenizer.from_pretrained(model_name)
            model = GPT2LMHeadModel.from_pretrained(model_name)
            # The input ids are moved to ``device`` in fn_stream, so the weights
            # must live there too (otherwise generation fails on CUDA).
            model = model.to(device).eval()
            loaded_models[model_name] = (tokenizer, model)
        return loaded_models[model_name]

    def fn_stream(text: str,
                  max_new_tokens: int = 200,
                  top_p: float = 0.85,
                  temperature: float = 0.35,
                  repetition_penalty: float = 1.2,
                  model_name: str = "qgyd2021/lip_service_4chan",
                  is_chat: bool = True,
                  ):
        """Stream the continuation that ``model_name`` generates for ``text``.

        The prompt is ``[CLS] + tokens(text)``, with a trailing ``[SEP]`` in
        chat mode. In chat mode generation also stops at ``[SEP]``.

        Yields:
            The cumulative, cleaned-up output after each streamed chunk.
            ``[UNK]``/``[CLS]`` markers are dropped and ``[SEP]`` is rendered
            as a newline.

        Raises:
            Whatever ``model.generate`` raised in the worker thread (e.g. for
            invalid sampling parameters). It is re-raised once the stream ends.
        """
        tokenizer, model = load_model(model_name)

        text_encoded = tokenizer(text, add_special_tokens=False)
        input_ids_ = text_encoded["input_ids"]

        input_ids = [tokenizer.cls_token_id]
        input_ids.extend(input_ids_)
        if is_chat:
            input_ids.append(tokenizer.sep_token_id)

        input_ids = torch.tensor([input_ids], dtype=torch.long)
        input_ids = input_ids.to(device)

        # skip_prompt=True: the streamer only yields generated text, so no tail
        # of the prompt can leak into the answer.
        streamer = TextIteratorStreamer(tokenizer=tokenizer, skip_prompt=True)

        generation_kwargs = dict(
            inputs=input_ids,
            # Gradio sliders may deliver floats; generate() expects an int here.
            max_new_tokens=int(max_new_tokens),
            do_sample=True,
            top_p=top_p,
            temperature=temperature,
            repetition_penalty=repetition_penalty,
            eos_token_id=tokenizer.sep_token_id if is_chat else None,
            pad_token_id=tokenizer.pad_token_id,
            streamer=streamer,
        )

        errors = []

        def generate():
            """Run generation; on failure record the error and unblock the streamer."""
            try:
                model.generate(**generation_kwargs)
            except Exception as e:
                errors.append(e)
                # Without the end signal the consumer loop below would wait on
                # the streamer's queue forever.
                streamer.end()

        thread = Thread(target=generate)
        thread.start()

        output: str = ""
        for output_ in streamer:
            output_ = output_.replace("[UNK] ", "")
            output_ = output_.replace("[UNK]", "")
            output_ = output_.replace("[CLS] ", "")
            output_ = output_.replace("[CLS]", "")

            output += output_
            if output.startswith("[SEP]"):
                output = output[5:]

            output = output.lstrip(" ,.!?")
            output = remove_space_between_cn_en(output)
            # Disabled alternatives to remove_space_between_cn_en (use repl1/repl2):
            # output = re.sub(r"([,。!?\u4e00-\u9fa5]) ([,。!?\u4e00-\u9fa5])", repl1, output)
            # output = re.sub(r"([,。!?\u4e00-\u9fa5]) ", repl2, output)

            output = output.replace("[SEP] ", "\n")
            output = output.replace("[SEP]", "\n")

            yield output

        thread.join()
        if errors:
            raise errors[0]

    # Local checkpoints on Windows, hub checkpoints everywhere else.
    model_name_choices = [
        "trained_models/lip_service_4chan",
        "trained_models/chinese_porn_novel"
    ] if platform.system() == "Windows" else \
        [
            "qgyd2021/lip_service_4chan", "qgyd2021/chinese_chitchat",
            "qgyd2021/chinese_porn_novel", "qgyd2021/few_shot_intent_gpt2_base",
            "qgyd2021/similar_question_generation",
        ]

    # The input components are passed positionally to fn_stream, in order.
    demo = gr.Interface(
        fn=fn_stream,
        inputs=[
            input_text_box,
            gr.Slider(minimum=0, maximum=512, value=512, step=1, label="max_new_tokens"),
            gr.Slider(minimum=0, maximum=1, value=0.85, step=0.01, label="top_p"),
            gr.Slider(minimum=0, maximum=1, value=0.35, step=0.01, label="temperature"),
            gr.Slider(minimum=0, maximum=2, value=1.2, step=0.01, label="repetition_penalty"),
            gr.Dropdown(choices=model_name_choices, value=model_name_choices[0], label="model_name"),
            gr.Checkbox(value=True, label="is_chat")
        ],
        outputs=[output_text_box],
        examples=examples,
        cache_examples=False,
        examples_per_page=50,
        title="GPT2 Chat",
        description=description,
    )
    demo.queue().launch()

    return


if __name__ == "__main__":
    # Launch the demo only when run as a script, never on import.
    main()