import torch
from transformers import AutoModelForCausalLM, AutoTokenizer
import time
import gradio as gr
from gradio import deploy

def generate_prompt(instruction, input=""):
    instruction = instruction.strip().replace('\r\n','\n').replace('\n\n','\n')
    input = input.strip().replace('\r\n','\n').replace('\n\n','\n')
    if input:
        return f"""Instruction: {instruction}

Input: {input}

Response:"""
    else:
        return f"""User: hi

Assistant: Hi. I am your assistant and I will provide expert full response in full details. Please feel free to ask any question and I will always answer it.

User: {instruction}

Assistant:"""

model_path = "models/rwkv-6-world-1b6/"  # Path to your local model directory

# Load the causal-LM checkpoint stored under model_path. trust_remote_code=True
# lets transformers execute modelling code shipped with the checkpoint; the
# loaded module is then cast to float32.
model = AutoModelForCausalLM.from_pretrained(model_path, trust_remote_code=True)
model = model.to(torch.float32)

# Load the tokenizer from the same checkpoint directory. Its vocabulary file(s)
# must be present there (the original note mentions vocab.json; confirm the
# filename against the checkpoint). trust_remote_code=True allows a tokenizer
# class shipped with the checkpoint to be used.
tokenizer = AutoTokenizer.from_pretrained(
    model_path,
    bos_token="</s>",
    # Was "</ s>" (stray space), almost certainly a typo; use the same marker
    # string as bos_token.
    eos_token="</s>",
    unk_token="<unk>",
    pad_token="<pad>",
    trust_remote_code=True,
    # Pad on the left so that, in padded batches, every row ends with real
    # tokens and generation continues directly after them.
    padding_side='left',
    # Keep decoded text exactly as produced (no removal of spaces before
    # punctuation). Set to True if you prefer the cleaned-up form.
    clean_up_tokenization_spaces=False,
)

# Stream a model reply token by token, stopping at EOS or at a stop sequence.
def generate_text(input_text, *, max_new_tokens=333, stop_sequence="\n\n"):
    """Generate a sampled reply to *input_text*, yielding it incrementally.

    The prompt is built with ``generate_prompt`` and extended one token per
    ``model.generate`` call. After each token the cumulative reply is
    yielded, so a Gradio text output can update as the reply grows. New text
    is also echoed to stdout for monitoring.

    Args:
        input_text: The user's message.
        max_new_tokens: Upper bound on the number of tokens to generate.
        stop_sequence: Generation stops as soon as this string appears in the
            reply, and the reply is truncated just before it. The default
            "\\n\\n" is the blank line that ``generate_prompt`` uses to
            separate turns. Pass "" or None to disable.

    Yields:
        The reply generated so far (cumulative text, not a delta).
    """
    prompt = generate_prompt(input_text)
    input_ids = tokenizer(prompt, return_tensors="pt").input_ids
    prompt_len = input_ids.shape[-1]

    # generation_config.eos_token_id may be None, a single id or a list of ids.
    eos = model.generation_config.eos_token_id
    if eos is None:
        stop_ids = set()
    elif isinstance(eos, int):
        stop_ids = {eos}
    else:
        stop_ids = set(eos)

    generated_text = ""

    for _ in range(max_new_tokens):
        # NOTE(review): each call re-processes the whole sequence (no state is
        # carried between calls), so total cost grows quadratically with the
        # reply length.
        output = model.generate(input_ids, max_new_tokens=1, do_sample=True, temperature=1.0, top_p=0.3, top_k=0)
        input_ids = output

        # With max_new_tokens=1, generate() cannot stop on EOS by itself.
        if output[0, -1].item() in stop_ids:
            break

        # Decode the whole reply rather than only the newest token. A
        # multi-byte character can be split across tokens, and decoding a lone
        # fragment yields U+FFFD (or raises, if the tokenizer decodes strictly).
        try:
            text = tokenizer.decode(output[0, prompt_len:], skip_special_tokens=True)
        except UnicodeDecodeError:
            continue  # incomplete character; wait for the next token
        if text.endswith("\ufffd"):
            continue

        if stop_sequence and stop_sequence in text:
            text = text[:text.index(stop_sequence)]
            print(text[len(generated_text):], end="", flush=True)
            yield text
            return

        print(text[len(generated_text):], end="", flush=True)  # console monitoring
        generated_text = text
        yield generated_text

# Gradio UI: one text box in, the reply out. generate_text is a generator, and
# each value it yields replaces the content of the output box.
iface = gr.Interface(
    generate_text,
    "text",
    "text",
    description="Enter your prompt below:",
    title="RWKV Chatbot",
    flagging_dir="gradio_flagged/",  # directory where flagged samples are stored
)

# Serve the app locally without creating a public share link. On Hugging Face
# Spaces the app file is typically executed as a script, so this same call
# serves the app there as well.
iface.launch(share=False)