# Source: ICLINIQ / app.py (Rohith1112), commit b598f3e (verified), 2.02 kB
import numpy as np
import torch
from transformers import AutoTokenizer, AutoModelForCausalLM
import simple_slice_viewer as ssv
import SimpleITK as sikt
import gradio as gr
# Runtime configuration. `device` and `dtype` are the single source of truth
# for where the model lives and which precision it uses: the loader below and
# the tensors built in process_image() both read them, so switching to a GPU
# only requires editing these two lines.
device = torch.device('cpu')  # Set to 'cuda' if using a GPU
dtype = torch.float32  # Data type for model processing
model_name_or_path = 'GoodBaiBai88/M3D-LaMed-Phi-3-4B'
# Number of projection outputs required for the image; also the number of
# "<im_patch>" placeholder tokens prepended to every prompt.
proj_out_num = 256

# Load model and tokenizer.
# NOTE: trust_remote_code=True runs Python code shipped with the checkpoint
# repository; only point model_name_or_path at a source you trust.
model = AutoModelForCausalLM.from_pretrained(
    model_name_or_path,
    torch_dtype=dtype,
    # Previously hard-coded to 'cpu', which silently ignored `device` and
    # would put the model and its inputs on different devices once `device`
    # was changed. str(device) yields 'cpu' / 'cuda' / 'cuda:0'.
    device_map=str(device),
    trust_remote_code=True,
)
tokenizer = AutoTokenizer.from_pretrained(
    model_name_or_path,
    model_max_length=512,
    padding_side="right",
    use_fast=False,
    trust_remote_code=True,
)
def process_image(image_path, question):
    """Answer ``question`` about the image stored at ``image_path``.

    Args:
        image_path: Path to a ``.npy`` file readable by ``np.load``.
            NOTE(review): the expected array shape/value range is defined by
            the model checkpoint, not by this function — confirm against the
            model card.
        question: Free-text question appended after the image placeholders.

    Returns:
        The first decoded sequence produced by ``model.generate``, with
        special tokens stripped.
    """
    volume = np.load(image_path)

    # The prompt is `proj_out_num` image placeholder tokens followed directly
    # by the user's question.
    prompt = "<im_patch>" * proj_out_num + question
    encoded = tokenizer(prompt, return_tensors="pt")
    input_id = encoded['input_ids'].to(device=device)

    # Add a leading batch axis and move to the configured dtype/device.
    image_pt = torch.from_numpy(volume).unsqueeze(0)
    image_pt = image_pt.to(dtype=dtype, device=device)

    # Sampling is enabled, so repeated calls may return different answers.
    sampling_options = {
        "max_new_tokens": 256,
        "do_sample": True,
        "top_p": 0.9,
        "temperature": 1.0,
    }
    output_ids = model.generate(image_pt, input_id, **sampling_options)

    answers = tokenizer.batch_decode(output_ids, skip_special_tokens=True)
    return answers[0]
# Gradio Interface
def gradio_interface(image, question):
    """Gradio callback: run the model on the uploaded file and question.

    Args:
        image: Value delivered by the ``gr.File`` input. Depending on the
            Gradio version this is a filesystem path string or an object
            exposing the path as ``.name``; it is ``None`` when the user
            submits without uploading a file.
        question: Text entered by the user.

    Returns:
        The model's answer, or a short instruction message when no file was
        uploaded.
    """
    if image is None:
        # Submitting without a file used to crash with AttributeError on
        # `None.name`; tell the user what is missing instead.
        return "Please upload a .npy image before submitting a question."

    # Accept both a plain path string and a file-like object with `.name`.
    image_path = getattr(image, "name", image)
    return process_image(image_path, question)
# Gradio App
# Input widgets: the uploaded .npy file and the free-text question.
image_input = gr.File(label="Upload .npy Image", type="filepath")
question_input = gr.Textbox(
    label="Enter your question",
    placeholder="Ask something about the image...",
)
# Output widget showing the text returned by gradio_interface.
response_output = gr.Textbox(label="Model Response")

# Assemble the interface, then start serving it.
demo = gr.Interface(
    fn=gradio_interface,
    inputs=[image_input, question_input],
    outputs=response_output,
    title="Medical Image Analysis",
    description="Upload a .npy image and ask a question to analyze it using the model.",
)
demo.launch()