import gradio as gr
from transformers import AutoTokenizer, VisionEncoderDecoderModel, TrOCRProcessor
from vllm import LLM, SamplingParams
from PIL import Image

# Model to serve; vLLM loads the weights itself below. The tokenizer is loaded here
# only for optional prompt inspection/pre-tokenization and is not required by vLLM.
model_name = "facebook/opt-125m"
tokenizer = AutoTokenizer.from_pretrained(model_name)

# Initialize vLLM for CPU inference.
# Note: device="cpu" requires a vLLM build with the CPU backend installed; on recent
# vLLM releases the backend is inferred from the install and the argument may be unnecessary.
vllm_model = LLM(model=model_name, tensor_parallel_size=1, device="cpu")
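# Optional smoke test (assumes the CPU backend loaded correctly):
#   print(vllm_model.generate(["Hello"], SamplingParams(max_tokens=8))[0].outputs[0].text)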

# Load the OCR model and its processor
ocr_model_name = "microsoft/trocr-small-handwritten"
ocr_model = VisionEncoderDecoderModel.from_pretrained(ocr_model_name)
ocr_processor = TrOCRProcessor.from_pretrained(ocr_model_name)
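# Note: TrOCR pairs a Transformer image encoder with an autoregressive text decoder;
# the processor bundles the image preprocessing and the tokenizer used for decoding.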

def generate_response(prompt, max_tokens, temperature, top_p):
    # Define sampling parameters (Gradio sliders return floats, so cast max_tokens to int)
    sampling_params = SamplingParams(
        max_tokens=int(max_tokens),
        temperature=temperature,
        top_p=top_p,
    )
    
    # Generate text with vLLM; generate() accepts a list of prompts plus sampling params
    outputs = vllm_model.generate([prompt], sampling_params)

    # Each RequestOutput holds the completions for one prompt; take the first completion's text
    generated_text = outputs[0].outputs[0].text
    return generated_text
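
# Example call (outside the Gradio UI), assuming the model loaded successfully:
#   print(generate_response("Once upon a time", max_tokens=50, temperature=0.7, top_p=0.9))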

def ocr_image(image_path):
    # Gradio passes None when no image has been uploaded
    if image_path is None:
        return "Please upload an image first."

    # Open the image from the file path and ensure 3-channel RGB
    image = Image.open(image_path).convert("RGB")

    # Preprocess the image into pixel values for the TrOCR encoder
    pixel_values = ocr_processor(images=image, return_tensors="pt").pixel_values

    # Run autoregressive decoding to produce output token IDs
    generated_ids = ocr_model.generate(pixel_values)

    # Decode the token IDs into text
    text = ocr_processor.batch_decode(generated_ids, skip_special_tokens=True)[0]
    return text
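
# Example call, assuming a handwriting sample saved locally as "sample.png" (hypothetical path):
#   print(ocr_image("sample.png"))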

# Gradio UI
with gr.Blocks() as demo:
    gr.Markdown("# 🚀 Hugging Face Integration with vLLM and OCR (CPU)")
    gr.Markdown("Upload an image to extract text using OCR or generate text using the vLLM integration.")

    with gr.Tab("Text Generation"):
        with gr.Row():
            with gr.Column():
                prompt_input = gr.Textbox(
                    label="Prompt",
                    placeholder="Enter your prompt here...",
                    lines=3,
                )
                max_tokens = gr.Slider(
                    label="Max Tokens",
                    minimum=10,
                    maximum=500,
                    value=100,
                    step=10,
                )
                temperature = gr.Slider(
                    label="Temperature",
                    minimum=0.1,
                    maximum=1.0,
                    value=0.7,
                    step=0.1,
                )
                top_p = gr.Slider(
                    label="Top P",
                    minimum=0.1,
                    maximum=1.0,
                    value=0.9,
                    step=0.1,
                )
                submit_button = gr.Button("Generate")
            
            with gr.Column():
                output_text = gr.Textbox(
                    label="Generated Text",
                    lines=10,
                    interactive=False,
                )
        
        submit_button.click(
            generate_response,
            inputs=[prompt_input, max_tokens, temperature, top_p],
            outputs=output_text,
        )

    with gr.Tab("OCR"):
        with gr.Row():
            with gr.Column():
                image_input = gr.Image(
                    label="Upload Image",
                    type="filepath",  # Corrected type
                    image_mode="RGB",
                )
                ocr_submit_button = gr.Button("Extract Text")
            
            with gr.Column():
                ocr_output = gr.Textbox(
                    label="Extracted Text",
                    lines=10,
                    interactive=False,
                )
        
        ocr_submit_button.click(
            ocr_image,
            inputs=[image_input],
            outputs=ocr_output,
        )

# Launch the app
demo.launch()
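
# Tip: demo.launch(server_name="0.0.0.0") exposes the app on the local network,
# and share=True requests a temporary public Gradio link.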