import os

import gradio as gr
import numpy as np
import torch
from transformers import AutoTokenizer, AutoModelForCausalLM

# Runtime configuration: inference runs on the CPU in float32 precision.
device = torch.device("cpu")
dtype = torch.float32
model_name_or_path = "GoodBaiBai88/M3D-LaMed-Phi-3-4B"
# How many "<im_patch>" placeholder tokens are prepended to every prompt.
# NOTE(review): presumably this must equal the model's projector output
# length — confirm against the model card before changing it.
proj_out_num = 256

# trust_remote_code=True lets transformers execute the Hub repository's own
# model / tokenizer implementation.
_model_kwargs = dict(torch_dtype=dtype, device_map="cpu", trust_remote_code=True)
model = AutoModelForCausalLM.from_pretrained(model_name_or_path, **_model_kwargs)

_tokenizer_kwargs = dict(
    model_max_length=512,
    padding_side="right",
    use_fast=False,
    trust_remote_code=True,
)
tokenizer = AutoTokenizer.from_pretrained(model_name_or_path, **_tokenizer_kwargs)

# Path of the most recently uploaded .npy file; None until an upload happens.
# NOTE(review): this is one module-level global, so all browser sessions share
# it — one user's upload replaces another's. Consider per-session gr.State.
uploaded_image = None

def process_image(question, history):
    """Ask the model *question* about the most recently uploaded volume.

    Args:
        question: Free-text prompt entered by the user.
        history: Chat history passed in by the UI; not read by this function.

    Returns:
        The first decoded generation (special tokens stripped), or a warning
        string if no image has been uploaded yet.
    """
    image_path = uploaded_image
    if image_path is None:
        return "⚠️ Please upload an image first!"

    volume = np.load(image_path)

    # Prompt = proj_out_num image placeholder tokens, then the user's text.
    prompt = "<im_patch>" * proj_out_num + question
    token_ids = tokenizer(prompt, return_tensors="pt")["input_ids"]
    token_ids = token_ids.to(device=device)

    # Add a leading batch dimension and cast to the model's dtype/device.
    volume_tensor = torch.from_numpy(volume)[None].to(dtype=dtype, device=device)

    output_ids = model.generate(
        volume_tensor,
        token_ids,
        max_new_tokens=256,
        do_sample=True,
        top_p=0.9,
        temperature=1.0,
    )
    decoded = tokenizer.batch_decode(output_ids, skip_special_tokens=True)
    return decoded[0]

def upload_image(image):
    """Remember the uploaded ``.npy`` file so later chat turns can use it.

    Args:
        image: Value produced by the upload component. With
            ``gr.File(type="filepath")`` this is a plain path string; older
            Gradio versions pass a tempfile-like object exposing ``.name``.
            ``None`` when the button is clicked with no file selected.

    Returns:
        A status message for display in the UI. When ``image`` is ``None``
        the previously stored path (if any) is left untouched.
    """
    global uploaded_image
    if image is None:
        return "⚠️ Please choose a .npy file before clicking Upload."

    # A str has no `.name`, so `image.name` crashed for type="filepath".
    # Path-like values are converted directly (Path.name would drop the
    # directory); anything else is treated as a file wrapper with `.name`.
    if isinstance(image, (str, os.PathLike)):
        path = os.fspath(image)
    else:
        path = image.name

    uploaded_image = path
    return f"✅ Image uploaded successfully: {path}"

# Chat Interface with File Upload


def _respond(question, history):
    """Run one chat turn and append it to the Chatbot's conversation.

    ``process_image`` returns only the answer string, whereas a ``gr.Chatbot``
    output must receive the whole conversation as a list of
    ``(user_message, bot_message)`` pairs, so the new turn is appended here.

    Returns:
        The updated history (for the Chatbot) and an empty string (to clear
        the input textbox).
    """
    history = list(history or [])
    answer = process_image(question, history)
    history.append((question, answer))
    return history, ""


with gr.Blocks(theme="soft") as chat_ui:
    with gr.Row():
        with gr.Column(scale=2):
            gr.Markdown("# 🏥 Medical Image Chatbot")
            uploaded_file = gr.File(label="Upload .npy Image", type="filepath")
            upload_button = gr.Button("Upload")
            status = gr.Markdown("")
            chat = gr.Chatbot(height=400)

        with gr.Column(scale=3):
            input_box = gr.Textbox(placeholder="Ask something about the image...")
            send_button = gr.Button("Send ✉️")

    # Handle image upload
    upload_button.click(upload_image, inputs=[uploaded_file], outputs=[status])

    # Handle chat interaction: the Send button and pressing Enter both submit.
    send_button.click(_respond, inputs=[input_box, chat], outputs=[chat, input_box])
    input_box.submit(_respond, inputs=[input_box, chat], outputs=[chat, input_box])

chat_ui.launch()