# ICLINIQ / app.py
# Author: Rohith1112 · commit e19a001 ("update", verified) · 2.46 kB
import numpy as np
import torch
from transformers import AutoTokenizer, AutoModelForCausalLM
import gradio as gr
# ---- Runtime configuration --------------------------------------------------
# Inference runs on the CPU in full float32 precision.
device = torch.device("cpu")
dtype = torch.float32
# Checkpoint identifier handed to from_pretrained for both model and tokenizer.
model_name_or_path = "GoodBaiBai88/M3D-LaMed-Phi-3-4B"
# Number of projection outputs required; this is also how many "<im_patch>"
# placeholder tokens are prepended to every prompt.
proj_out_num = 256
# ---- Model & tokenizer ------------------------------------------------------
# NOTE(review): trust_remote_code=True executes Python code shipped with the
# checkpoint repository; consider pinning a revision if that repo is not
# under your control.
model = AutoModelForCausalLM.from_pretrained(
    model_name_or_path,
    torch_dtype=dtype,
    device_map='cpu',
    trust_remote_code=True,
)
tokenizer = AutoTokenizer.from_pretrained(
    model_name_or_path,
    model_max_length=512,
    padding_side="right",
    use_fast=False,
    trust_remote_code=True,
)

# ---- Upload state -----------------------------------------------------------
# Path of the most recently uploaded .npy file; stays None until upload_image
# stores one.
# NOTE(review): this is module-level (per-process) state, so every browser
# session served by this process shares the same uploaded image — confirm this
# is acceptable, or move it into per-session gr.State.
uploaded_image = None
def process_image(question, history):
    """Answer ``question`` about the uploaded image and append the turn to the chat.

    Args:
        question: Text typed by the user.
        history: Current value of the ``gr.Chatbot`` component: a list of
            ``[user_message, bot_message]`` pairs, or ``None``/empty on the
            first turn.

    Returns:
        The updated history list. ``gr.Chatbot`` renders a list of message
        pairs, so the reply is appended to the history rather than returned as
        a bare string (a bare string would replace the whole conversation).
        NOTE(review): assumes the Chatbot's default pair ("tuples") format;
        switch to role/content dicts if it is created with type="messages".
    """
    global uploaded_image
    # Work on a copy so the incoming list is never mutated in place.
    history = list(history or [])

    # Nothing to ask: leave the conversation unchanged.
    if not question or not question.strip():
        return history

    if uploaded_image is None:
        history.append([question, "⚠️ Please upload an image first!"])
        return history

    # Load the .npy image. np.load keeps its default allow_pickle=False, so a
    # user-supplied file cannot trigger unpickling; unreadable/missing files
    # surface as OSError/ValueError and are reported back in the chat.
    try:
        image_np = np.load(uploaded_image)
    except (OSError, ValueError) as err:
        history.append([question, f"⚠️ Could not read the uploaded file: {err}"])
        return history

    # The prompt starts with proj_out_num "<im_patch>" placeholder tokens.
    image_tokens = "<im_patch>" * proj_out_num
    input_txt = image_tokens + question
    input_id = tokenizer(input_txt, return_tensors="pt")['input_ids'].to(device=device)

    # Convert the array to a tensor and add a leading batch dimension.
    # NOTE(review): the required per-sample shape is dictated by the model's
    # vision encoder — confirm uploaded volumes match it.
    image_pt = torch.from_numpy(image_np).unsqueeze(0).to(dtype=dtype, device=device)

    # Generate response (sampled, so replies are not deterministic).
    generation = model.generate(
        image_pt, input_id, max_new_tokens=256, do_sample=True, top_p=0.9, temperature=1.0
    )
    generated_texts = tokenizer.batch_decode(generation, skip_special_tokens=True)
    history.append([question, generated_texts[0]])
    return history
def upload_image(image):
    """Store the uploaded file's path so later chat turns can use it.

    Args:
        image: Value delivered by the ``gr.File`` component. With
            ``type="filepath"`` this is a plain path string (or PathLike);
            other configurations deliver a file-like object exposing
            ``.name``. ``None`` when no file has been selected.

    Returns:
        A status message for the Markdown status component. The stored path
        is only updated when a ``.npy`` file was provided.
    """
    global uploaded_image
    import os

    if image is None:
        return "⚠️ Please choose a .npy file first."

    # gr.File(type="filepath") passes a str, which has no ``.name`` attribute;
    # only fall back to ``.name`` for tempfile-like objects.
    if isinstance(image, (str, os.PathLike)):
        path = os.fspath(image)
    else:
        path = image.name

    # process_image feeds the file to np.load and expects a single array.
    if not path.lower().endswith(".npy"):
        return f"⚠️ Expected a .npy file, got: {os.path.basename(path)}"

    uploaded_image = path
    return f"✅ Image uploaded successfully: {path}"
# ---- Chat interface with file upload ----------------------------------------
with gr.Blocks(theme="soft") as chat_ui:
    with gr.Row():
        with gr.Column(scale=2):
            gr.Markdown("# 🏥 Medical Image Chatbot")
            # Limit the picker to .npy files, the only format the chat handler loads.
            uploaded_file = gr.File(
                label="Upload .npy Image", type="filepath", file_types=[".npy"]
            )
            upload_button = gr.Button("Upload")
            status = gr.Markdown("")
            chat = gr.Chatbot(height=400)
        with gr.Column(scale=3):
            input_box = gr.Textbox(placeholder="Ask something about the image...")
            send_button = gr.Button("Send ✉️")

    # Handle image upload: remember the file path and show a status message.
    upload_button.click(upload_image, inputs=[uploaded_file], outputs=[status])

    # Handle chat interaction. Both the Send button and pressing Enter in the
    # textbox submit the question; `.then` clears the textbox only after the
    # reply has been appended to the chat.
    send_button.click(process_image, inputs=[input_box, chat], outputs=[chat]).then(
        lambda: "", inputs=None, outputs=[input_box]
    )
    input_box.submit(process_image, inputs=[input_box, chat], outputs=[chat]).then(
        lambda: "", inputs=None, outputs=[input_box]
    )

chat_ui.launch()