import gradio as gr
from transformers import AutoTokenizer, AutoModelForCausalLM
from peft import PeftModel
from huggingface_hub import InferenceClient
import re
import torch

# Model and tokenizer loading (outside the respond function)
# Load the base Sarvam-1 model, attach the LoRA text-normalization adapter,
# and merge the adapter weights into the base model for plain inference.
# On any failure the three globals are set to None so respond() can report
# the problem instead of crashing at request time.
try:
    tokenizer = AutoTokenizer.from_pretrained("sarvamai/sarvam-1")
    base_model = AutoModelForCausalLM.from_pretrained("sarvamai/sarvam-1")
    # Adapter trained for Hindi entity normalization (dates/currencies/units).
    peft_model = PeftModel.from_pretrained(base_model, "KGSAGAR/Sarvam-1-text-normalization-3r")
    # Fold LoRA weights into the base model; returns a plain transformers model.
    peft_model = peft_model.merge_and_unload()
    print("Model loaded successfully!")
except Exception as e:
    # Broad catch is deliberate at this top-level boundary: the app should
    # still start and surface the error through the UI.
    print(f"Error loading model: {e}")
    tokenizer = None
    base_model = None
    peft_model = None
    


def respond(message, history, system_message, max_tokens, temperature, top_p):
    """Generate a normalized-text reply for the Gradio ChatInterface.

    Args:
        message: Latest user input (Hindi text to normalize).
        history: Chat history supplied by Gradio; deliberately unused —
            normalization is treated as a single-turn task.
        system_message: Task instructions prepended to the prompt.
        max_tokens: Maximum number of new tokens to generate.
        temperature: Sampling temperature.
        top_p: Nucleus-sampling probability mass.

    Returns:
        The text extracted from the model's <user>...</user> span, or an
        error/retry message.
    """
    global tokenizer, peft_model

    # Module-level loading may have failed; fail gracefully per request.
    if tokenizer is None or peft_model is None:
        return "Model loading failed. Please check the logs."

    # Single-turn prompt: instructions + tagged user input. History is not
    # replayed into the prompt on purpose.
    prompt = f"{system_message}\n<user> input:{message} output:"

    inputs = tokenizer(prompt, return_tensors="pt", truncation=True)

    try:
        # Inference only — no autograd graph needed.
        with torch.no_grad():
            outputs = peft_model.generate(
                **inputs,
                max_new_tokens=max_tokens,
                temperature=temperature,
                top_p=top_p,
                # Without do_sample=True, generate() runs greedy decoding and
                # silently ignores temperature/top_p — the UI sliders would
                # have no effect.
                do_sample=True,
            )
        # Decode the full sequence (prompt + completion); the extraction
        # step below pulls out the tagged span.
        generated_text = tokenizer.decode(outputs[0], skip_special_tokens=True)
    except Exception as e:
        return f"Generation error: {e}"

    def extract_user_content(text):
        """Return the text between <user> and </user> tags, stripped.

        Falls back to a retry notice when the model did not emit a closing
        tag (happens occasionally).
        """
        pattern = re.compile(r'<user>(.*?)</user>', re.IGNORECASE | re.DOTALL)
        match = pattern.search(text)
        if match:
            return match.group(1).strip()
        # NOTE(review): the trailing characters look like a mojibake'd emoji
        # (UTF-8 read as cp866); kept byte-identical — confirm intended text.
        return "Retry to get output, the model failed to generated required output(This occurs rarelyЁЯд╖тАНтЩВя╕П)"

    # Kept for Space debugging: raw generation and extracted result.
    print(generated_text)
    lines = extract_user_content(generated_text)
    print(lines)

    return lines
        
        

# Chat UI. Gradio calls respond(message, history, *additional_inputs), so the
# list below must supply one component per extra parameter, in order:
# system_message, max_tokens, temperature, top_p.
# NOTE(review): the Hindi example text below appears mojibake'd (UTF-8 bytes
# decoded as cp866); kept byte-identical — re-save the file as UTF-8 to fix.
demo = gr.ChatInterface(
    respond,
    additional_inputs=[
        gr.Textbox(value="Take the user input in Hindi language and normalize specific entities, Only including: Dates (any format) Currencies Scientific units, Here's an example input and output of the task <Example> Exampleinput :  2012тАУ13 рдореЗрдВ рд░рдХреНрд╖рд╛ рд╕реЗрд╡рд╛рдУрдВ рдХреЗ рд▓рд┐рдП 1,93,407 рдХрд░реЛрдбрд╝ рд░реБрдкрдП рдХрд╛ рдкреНрд░рд╛рд╡рдзрд╛рди рдХрд┐рдпрд╛ рдЧрдпрд╛ рдерд╛, рдЬрдмрдХрд┐ 2011тАУ2012 рдореЗрдВ рдпрд╣ рд░рд╛рд╢рд┐ 1,64,415 рдХрд░реЛрдЗрд╝ рдереА, Exampleoutput: рдЯреНрд╡реЗрдиреНрдЯреА рдЯреНрд╡реЗрд▓реНрд╡ рдлрд┐рдлреНрдЯреАрди рдореЗрдВ рд░рдХреНрд╖рд╛ рд╕реЗрд╡рд╛рдУрдВ рдХреЗ рд▓рд┐рдП рд╡рди рдХрд░реЛрдбрд╝ рдирд┐рдиреЗрдЯреА рдереНрд░реА рдерд╛рдЙрдЬреЗрдВрдб рдлреЛрд░ рд╣рдВрдбреНрд░реЗрдб рд╕реЗрд╡рди рдХрд░реЛрдбрд╝ рд░реБрдкрдП рдХрд╛ рдкреНрд░рд╛рд╡рдзрд╛рди рдХрд┐рдпрд╛ рдЧрдпрд╛ рдерд╛, рдЬрдмрдХрд┐ рдЯреНрд╡реЗрдиреНрдЯреА рдПрд▓реЗрд╡рди рдЯреНрд╡реЗрд▓реНрд╡ рдореЗрдВ рдпрд╣ рд░рд╛рд╢рд┐ рд╡рди рдХрд░реЛрдбрд╝ рд╕рд┐рдХреНрд╕реНрдЯреА рдлреЛрд░ рдерд╛рдЙрдЬреЗрдВрдб рдлреЛрд░ рд╣рдВрдбреНрд░реЗрдб рдлрд┐рдлреНрдЯреАрди рдХрд░реЛрдбрд╝ рдереА </Example>. Understand the task and Only provide the normalized output with atmost accuracy", label="System message"),
        gr.Slider(minimum=1, maximum=2048, value=128, step=1, label="Max new tokens"),
        gr.Slider(minimum=0.1, maximum=4.0, value=0.95, step=0.1, label="Temperature"),
        # BUGFIX: respond() takes a fourth additional input (top_p); without
        # this slider Gradio called respond() with a missing argument and
        # every chat turn raised a TypeError.
        gr.Slider(minimum=0.05, maximum=1.0, value=0.95, step=0.05, label="Top-p (nucleus sampling)"),
    ],
)

if __name__ == "__main__":
    demo.launch()