Spaces:
Running
Running
import numpy as np | |
import torch | |
from transformers import AutoTokenizer, AutoModelForCausalLM | |
import simple_slice_viewer as ssv | |
import SimpleITK as sikt | |
import gradio as gr | |
# --- Runtime configuration ---------------------------------------------------
# Checkpoint identifier handed to the transformers from_pretrained() calls.
model_name_or_path = "GoodBaiBai88/M3D-LaMed-Phi-3-4B"
# How many "<im_patch>" placeholder tokens are prepended to each prompt
# (the number of projection outputs the image requires).
proj_out_num = 256
# Device and precision that model inputs are moved to. Change the device to
# "cuda" to use a GPU, and keep the model's device_map in agreement with it.
device = torch.device("cpu")
dtype = torch.float32
# Load model and tokenizer.
# Reuse the module-level ``dtype``/``device`` (rather than repeating literals)
# so the model is placed on the same device and precision that
# process_image() moves its inputs to; editing ``device`` then moves both.
model = AutoModelForCausalLM.from_pretrained(
    model_name_or_path,
    torch_dtype=dtype,
    # str(torch.device("cpu")) == "cpu", i.e. the same value as before.
    device_map=str(device),
    # Allows executing the modelling code published in the Hub repository.
    trust_remote_code=True,
)
tokenizer = AutoTokenizer.from_pretrained(
    model_name_or_path,
    model_max_length=512,
    padding_side="right",
    use_fast=False,
    trust_remote_code=True,
)
def process_image(
    image_path,
    question,
    *,
    max_new_tokens=256,
    do_sample=True,
    top_p=0.9,
    temperature=1.0,
):
    """Run the model on a ``.npy`` image and a free-text question.

    The prompt is ``proj_out_num`` copies of the ``<im_patch>`` placeholder
    token followed by *question*. The loaded array gets a leading batch
    dimension of size 1 and is passed to ``model.generate`` together with the
    tokenized prompt.

    Args:
        image_path: Path to a ``.npy`` file containing a single array.
        question: Question text appended after the image placeholder tokens.
        max_new_tokens: Forwarded to ``model.generate``.
        do_sample: Forwarded to ``model.generate``.
        top_p: Forwarded to ``model.generate``.
        temperature: Forwarded to ``model.generate``.
            The defaults reproduce the previously hard-coded sampling settings.

    Returns:
        The first sequence produced by ``tokenizer.batch_decode`` on the
        generated ids, with special tokens skipped.

    Raises:
        ValueError: If the file is not a single-array ``.npy`` file (for
            example an ``.npz`` archive) or would require unpickling.
    """
    # allow_pickle=False is NumPy's default; it is spelled out because the file
    # comes from a user upload and unpickling untrusted data is unsafe.
    image_np = np.load(image_path, allow_pickle=False)
    if not isinstance(image_np, np.ndarray):
        # np.load returns a lazily-read NpzFile for .npz archives; close it so
        # the file handle is not leaked, then reject it with a clear message.
        kind = type(image_np).__name__
        image_np.close()
        raise ValueError(
            f"Expected a single-array .npy file, got {kind} from {image_path!r}"
        )

    # One "<im_patch>" placeholder per projection output (see proj_out_num).
    input_txt = "<im_patch>" * proj_out_num + question
    input_id = tokenizer(input_txt, return_tensors="pt")["input_ids"].to(device=device)

    # unsqueeze(0) prepends a batch dimension of size 1.
    # NOTE(review): the model presumably expects a specific volume layout
    # (channel/depth/height/width) -- TODO validate the shape before generating.
    image_pt = torch.from_numpy(image_np).unsqueeze(0).to(dtype=dtype, device=device)

    # Inference only: no autograd graph is needed.
    with torch.no_grad():
        generation = model.generate(
            image_pt,
            input_id,
            max_new_tokens=max_new_tokens,
            do_sample=do_sample,
            top_p=top_p,
            temperature=temperature,
        )
    generated_texts = tokenizer.batch_decode(generation, skip_special_tokens=True)
    return generated_texts[0]
# Gradio Interface
def gradio_interface(image, question):
    """Gradio callback: answer *question* about the uploaded ``.npy`` image.

    Args:
        image: Value produced by the ``gr.File`` input. With
            ``type="filepath"`` Gradio passes a plain ``str`` path; file-object
            style values exposing ``.name`` are also accepted. ``None`` when
            nothing has been uploaded.
        question: Text typed into the question box.

    Returns:
        The model's answer from ``process_image``, or a short message asking
        the user to upload a file when none was provided.
    """
    if image is None:
        return "Please upload a .npy image first."
    # gr.File(type="filepath") yields a str, which has no ``.name`` attribute;
    # only dereference ``.name`` for file-object style values.
    image_path = image if isinstance(image, str) else image.name
    return process_image(image_path, question)
# Gradio App
# Bound to a module-level name so the app object can be imported (e.g. by
# Gradio's reload mode, which looks for ``demo``) without starting a server;
# the blocking launch() only runs when this file is executed as a script.
demo = gr.Interface(
    fn=gradio_interface,
    inputs=[
        # type="filepath" makes Gradio pass the callback a path string;
        # file_types limits the picker to the .npy format the label promises.
        gr.File(label="Upload .npy Image", type="filepath", file_types=[".npy"]),
        gr.Textbox(label="Enter your question", placeholder="Ask something about the image..."),
    ],
    outputs=gr.Textbox(label="Model Response"),
    title="Medical Image Analysis",
    description="Upload a .npy image and ask a question to analyze it using the model.",
)

if __name__ == "__main__":
    demo.launch()