Spaces:
Running
Running
import numpy as np | |
import torch | |
from transformers import AutoTokenizer, AutoModelForCausalLM | |
import gradio as gr | |
# Set device & model details.
device = torch.device('cpu')
dtype = torch.float32
model_name_or_path = 'GoodBaiBai88/M3D-LaMed-Phi-3-4B'
proj_out_num = 256  # Number of "<im_patch>" placeholders prepended to each prompt

# Load model & tokenizer once at import time so every request reuses them.
# NOTE(review): trust_remote_code=True lets transformers run Python code that
# ships with the checkpoint repository; consider pinning a `revision`.
model = AutoModelForCausalLM.from_pretrained(
    model_name_or_path,
    torch_dtype=dtype,
    device_map='cpu',
    trust_remote_code=True,
)
tokenizer = AutoTokenizer.from_pretrained(
    model_name_or_path,
    model_max_length=512,
    padding_side="right",
    use_fast=False,
    trust_remote_code=True,
)

# Image placeholder (to maintain session context).
# NOTE(review): this is a single module-level value, so it is shared by every
# visitor of the app; a per-session store (e.g. gr.State) would isolate users.
uploaded_image = None
def process_image(question, history):
    """Answer ``question`` about the uploaded image and return the updated chat.

    The function is wired to a ``gr.Chatbot`` output, which renders a list of
    ``[user_message, bot_message]`` pairs, so every outcome (including
    warnings) is returned as a new turn appended to the history rather than
    as a bare string.

    Args:
        question: Text the user typed into the input box.
        history: Current chat history as supplied by the chatbot component;
            ``None`` or an empty list means there are no previous turns.

    Returns:
        A new list holding the previous turns plus ``[question, reply]``.
        The list passed in as ``history`` is not modified.
    """
    # Work on a copy so the caller's list is never mutated in place.
    updated_history = list(history) if history else []

    if uploaded_image is None:
        updated_history.append([question, "⚠️ Please upload an image first!"])
        return updated_history

    # Load the .npy image. allow_pickle keeps its default (False), so a file
    # containing pickled objects is rejected instead of being unpickled.
    try:
        image_np = np.load(uploaded_image)
    except (OSError, ValueError, EOFError) as err:
        updated_history.append(
            [question, f"⚠️ Could not read the uploaded .npy file: {err}"]
        )
        return updated_history

    # The prompt is proj_out_num image placeholder tokens followed by the text.
    image_tokens = "<im_patch>" * proj_out_num
    input_txt = image_tokens + question
    input_id = tokenizer(input_txt, return_tensors="pt")['input_ids'].to(device=device)

    # Convert image to tensor and add a leading batch dimension.
    # NOTE(review): the expected array shape/value range is defined by the
    # model, not checked here — confirm against the model card.
    image_pt = torch.from_numpy(image_np).unsqueeze(0).to(dtype=dtype, device=device)

    # Generate response (sampling is enabled, so replies are not deterministic).
    generation = model.generate(
        image_pt,
        input_id,
        max_new_tokens=256,
        do_sample=True,
        top_p=0.9,
        temperature=1.0,
    )
    generated_texts = tokenizer.batch_decode(generation, skip_special_tokens=True)

    updated_history.append([question, generated_texts[0]])
    return updated_history
def upload_image(image):
    """Store the uploaded image path to be used in chat.

    Args:
        image: Value delivered by the ``gr.File`` component. With
            ``type="filepath"`` this is a plain path string; a file-like
            object exposing ``.name`` is also accepted. ``None`` means the
            button was clicked before any file was selected.

    Returns:
        A status message (Markdown text) describing the outcome.
    """
    global uploaded_image

    if image is None:
        # Nothing selected: keep any previously stored image untouched.
        return "⚠️ Please select a .npy file before clicking Upload."

    # A str has no `.name` attribute, so only dereference it on file objects.
    path = image if isinstance(image, str) else image.name
    uploaded_image = path
    return f"✅ Image uploaded successfully: {path}"
# Chat Interface with File Upload
# NOTE(review): the nesting below is reconstructed from statement order (the
# original indentation was lost); in particular the chatbot is placed in the
# first column because it is declared before the second column opens.
with gr.Blocks(theme="soft") as chat_ui:
    with gr.Row():
        with gr.Column(scale=2):
            gr.Markdown("# 🏥 Medical Image Chatbot")
            uploaded_file = gr.File(label="Upload .npy Image", type="filepath")
            upload_button = gr.Button("Upload")
            status = gr.Markdown("")
            chat = gr.Chatbot(height=400)
        with gr.Column(scale=3):
            input_box = gr.Textbox(placeholder="Ask something about the image...")
            send_button = gr.Button("Send ✉️")

    # Event listeners are registered while the Blocks context is still open.
    # Handle image upload: the status Markdown shows the returned message.
    upload_button.click(upload_image, inputs=[uploaded_file], outputs=[status])
    # Handle chat interaction: the chatbot value goes in as the history and
    # is replaced by whatever process_image returns.
    send_button.click(process_image, inputs=[input_box, chat], outputs=[chat])

chat_ui.launch()