import os
import re
from threading import Thread

import torch
from PIL import Image
import gradio as gr
import spaces
from transformers import LlamaForCausalLM, AutoTokenizer, BitsAndBytesConfig, TextIteratorStreamer
from polyglot.detect import Detector
# NOTE(review): HF_TOKEN is read here but not passed to any from_pretrained call
# in this file; wire it in if the checkpoints ever become gated.
HF_TOKEN = os.environ.get("HF_TOKEN", None)
# Instruction-tuned checkpoint used for both the model and the tokenizer.
MODEL = "LLaMAX/LLaMAX3-8B-Alpaca"
# NOTE(review): not referenced anywhere else in this file.
RELATIVE_MODEL = "LLaMAX/LLaMAX3-8B"
# Page heading rendered with gr.Markdown; the page CSS centres <h1> elements.
# (Previously a plain "..." literal split across lines, which is a SyntaxError.)
TITLE = "<h1>LLaMAX3-8B-Translation</h1>"
# Load the weights with bitsandbytes 8-bit quantization.
quantization_config = BitsAndBytesConfig(load_in_8bit=True)
# Tokenizer and model come from the same checkpoint so their vocabularies match.
tokenizer = AutoTokenizer.from_pretrained(MODEL)
model = LlamaForCausalLM.from_pretrained(
    MODEL,
    quantization_config=quantization_config,
    device_map="auto",  # let transformers/accelerate choose device placement
    torch_dtype=torch.bfloat16,
)
def lang_detector(text):
    """Detect the language of *text* with polyglot and return its name.

    Used as the change handler of the source-text box; whatever this returns is
    written into the "Source Lang" box, so problems are reported as text rather
    than raised.

    Args:
        text: The text to inspect (``None`` is treated like empty input).

    Returns:
        The language name taken from the ``name: ...`` field of polyglot's
        detection result, ``"Input text too short"`` when *text* has fewer than
        5 characters, or ``"ERROR:<message>"`` if detection or parsing fails.
    """
    min_chars = 5
    if text is None or len(text) < min_chars:
        return "Input text too short"
    try:
        lang_info = str(Detector(text).language)
        # The detector's text form is expected to contain "name: <Language>";
        # only the first word of the name is captured by \w+.
        match = re.search(r"name: (\w+)", lang_info)
        if match is None:
            return f"ERROR:unexpected detector output {lang_info!r}"
        return match.group(1)
    except Exception as e:
        # Deliberate best-effort: surface the failure in the UI instead of crashing.
        return f"ERROR:{str(e)}"
def Prompt_template(query, src_language, trg_language):
    """Wrap *query* in an Alpaca-style instruction prompt requesting a translation.

    Args:
        query: The text to translate; placed in the "### Input" section.
        src_language: Source language name, inserted verbatim into the instruction.
        trg_language: Target language name, inserted verbatim into the instruction.

    Returns:
        The full prompt string, ending with "### Response:" so the model
        continues with the translation.
    """
    preamble = (
        'Below is an instruction that describes a task, paired with an input that provides further context. '
        'Write a response that appropriately completes the request.\n'
    )
    task = f'Translate the following sentences from {src_language} to {trg_language}.'
    return f'{preamble}### Instruction:\n{task}\n### Input:\n{query}\n### Response:'
# Unfinished: not called anywhere in this file yet; presumably meant to back the
# "Max tokens Per Chunk" slider by splitting long inputs.
def chunk_text():
    """Placeholder for future text chunking; currently does nothing and returns None."""
    return None
    
@spaces.GPU()
def translate(
    source_text: str,
    source_lang: str,
    target_lang: str,
    max_chunk: int,
    max_length: int,
    temperature: float):
    """Stream a translation of *source_text* from *source_lang* to *target_lang*.

    Args:
        source_text: The text to translate.
        source_lang: Source language name, inserted verbatim into the prompt.
        target_lang: Target language name, inserted verbatim into the prompt.
        max_chunk: Currently unused (chunking is not implemented yet).
        max_length: Upper bound on prompt + generated tokens, forwarded to
            ``model.generate(max_length=...)``.
        temperature: Sampling temperature; a value <= 0 selects greedy decoding.

    Yields:
        The translation accumulated so far, growing as tokens are streamed.
    """
    print(f'Text is - {source_text}')

    prompt = Prompt_template(source_text, source_lang, target_lang)
    # tokenizer(...) returns a BatchEncoding (a dict-like of tensors), not a tensor;
    # generate() needs the tensors themselves, so unpack them explicitly below.
    inputs = tokenizer(prompt, return_tensors="pt").to(model.device)

    streamer = TextIteratorStreamer(
        tokenizer,
        skip_prompt=True,
        skip_special_tokens=True,
        clean_up_tokenization_spaces=False,
    )

    # Sampling requires a strictly positive temperature. The UI slider allows 0,
    # which would make generate() raise inside the worker thread and leave the
    # streamer loop below waiting forever; treat it as greedy decoding instead.
    if temperature > 0:
        sampling_kwargs = dict(do_sample=True, temperature=temperature)
    else:
        sampling_kwargs = dict(do_sample=False)

    generate_kwargs = dict(
        input_ids=inputs["input_ids"],
        attention_mask=inputs.get("attention_mask"),
        streamer=streamer,
        max_length=int(max_length),
        **sampling_kwargs,
    )

    # generate() blocks until finished, so run it in a worker thread and consume
    # the streamer here to yield partial output to the UI.
    thread = Thread(target=model.generate, kwargs=generate_kwargs)
    thread.start()
    buffer = ""
    for new_text in streamer:
        buffer += new_text
        yield buffer
    thread.join()
# Page styling: centre <h1> headings inside a 10vh-tall box and hide the Gradio footer.
CSS = """
    h1 {
        text-align: center;
        display: block;
        height: 10vh;
        align-content: center;
    }
    footer {
        visibility: hidden;
    }
"""
# Markdown rendered above the source-text box.
DESCRIPTION = """
- LLaMAX is a language model with powerful multilingual capabilities without loss of instruction-following capabilities.
- The source language is auto-detected; enter your target language and country.
"""
# NOTE(review): this Chatbot is never placed in the layout below nor referenced
# elsewhere in this file; it looks like a leftover.
chatbot = gr.Chatbot(height=600)
with gr.Blocks(theme="soft", css=CSS) as demo:
    gr.Markdown(TITLE)
    with gr.Row():
        # Left column: languages and generation settings.
        with gr.Column(scale=1):
            source_lang = gr.Textbox(
                label="Source Lang(Auto-Detect)",
                value="English",
            )
            target_lang = gr.Textbox(
                label="Target Lang",
                value="Spanish",
            )
            max_chunk = gr.Slider(
                label="Max tokens Per Chunk",
                minimum=512,
                # Was 2046, which is unreachable from 512 in steps of 8.
                maximum=2048,
                value=1000,
                step=8,
            )
            max_length = gr.Slider(
                label="Context Window",
                minimum=512,
                maximum=8192,
                value=4096,
                step=8,
            )
            temperature = gr.Slider(
                label="Temperature",
                minimum=0,
                maximum=1,
                value=0.3,
                step=0.1,
            )
        # Right column: description, input text and streamed output.
        with gr.Column(scale=4):
            gr.Markdown(DESCRIPTION)
            source_text = gr.Textbox(
                label="Source Text",
                value=(
                    "How we live is so different from how we ought to live that he who studies "
                    "what ought to be done rather than what is done will learn the way to his downfall "
                    "rather than to his preservation."
                ),
                lines=10,
            )
            output_text = gr.Textbox(
                label="Output Text",
                lines=10,
            )
    with gr.Row():
        submit = gr.Button(value="Submit")
        clear = gr.ClearButton([source_text, output_text])

    # Re-detect the source language whenever the input text changes.
    source_text.change(lang_detector, source_text, source_lang)
    # Gradio passes inputs positionally, so this list must follow translate()'s
    # parameter order: (source_text, source_lang, target_lang, max_chunk, max_length, temperature).
    # Previously the first three were passed as (source_lang, target_lang, source_text).
    submit.click(
        fn=translate,
        inputs=[source_text, source_lang, target_lang, max_chunk, max_length, temperature],
        outputs=[output_text],
    )

if __name__ == "__main__":
    demo.launch()