"""Gradio front-end for asking M3D-LaMed questions about a 3-D medical volume."""
import os
import tempfile

import numpy as np
import torch
from transformers import AutoTokenizer, AutoModelForCausalLM
import simple_slice_viewer as ssv  # unused below; kept so import-time behaviour is unchanged
import SimpleITK as sikt  # unused below; kept so import-time behaviour is unchanged
import gradio as gr
import matplotlib.pyplot as plt  # unused below; rendering uses the pyplot-free Figure API
from matplotlib.figure import Figure

device = torch.device('cpu')  # Set to 'cuda' if using a GPU
dtype = torch.float32  # Data type for model processing
model_name_or_path = 'GoodBaiBai88/M3D-LaMed-Phi-3-4B'
proj_out_num = 256  # Number of projection outputs required for the image

# Placeholder token used by the M3D-LaMed prompt format: per the model card the
# prompt starts with `proj_out_num` copies of it, one per projected image feature.
IMAGE_PATCH_TOKEN = '<im_patch>'

# Load model and tokenizer (honouring the `dtype` / `device` settings above).
model = AutoModelForCausalLM.from_pretrained(
    model_name_or_path,
    torch_dtype=dtype,
    device_map=str(device),
    trust_remote_code=True,
)
tokenizer = AutoTokenizer.from_pretrained(
    model_name_or_path,
    model_max_length=512,
    padding_side="right",
    use_fast=False,
    trust_remote_code=True,
)


def process_image(image_path, question):
    """Run the model on a saved volume and a free-text question.

    Args:
        image_path: Path to a ``.npy`` file. Per the M3D-LaMed model card the
            expected layout is ``(1, 32, 256, 256)`` with intensities min-max
            normalised to [0, 1]; a 3-D array is also accepted and is given a
            leading channel axis.
        question: Natural-language question about the image.

    Returns:
        ``(answer, volume)``: the decoded model response and the 4-D array
        that was fed to the model.

    Raises:
        ValueError: If the loaded array is not 3-D or 4-D.
    """
    volume = np.load(image_path)  # allow_pickle defaults to False, so uploads cannot run code
    if volume.ndim == 3:
        volume = volume[np.newaxis]  # (D, H, W) -> (1, D, H, W)
    if volume.ndim != 4:
        raise ValueError(
            f"Expected a 3-D or 4-D image array, got shape {volume.shape}."
        )

    # The image-patch placeholders are required: without them the prompt has
    # no positions for the image features.
    prompt = IMAGE_PATCH_TOKEN * proj_out_num + question
    input_ids = tokenizer(prompt, return_tensors="pt")['input_ids'].to(device=device)

    # Add the batch axis expected by the model.
    image_pt = torch.from_numpy(volume).unsqueeze(0).to(dtype=dtype, device=device)

    with torch.no_grad():
        generation = model.generate(
            image_pt,
            input_ids,
            max_new_tokens=256,
            do_sample=True,
            top_p=0.9,
            temperature=1.0,
        )
    generated_texts = tokenizer.batch_decode(generation, skip_special_tokens=True)
    return generated_texts[0], volume


def _resolve_upload_path(upload):
    """Return the filesystem path of a Gradio upload (str/PathLike or file object)."""
    # gr.File(type="filepath") passes a plain string; file-object uploads
    # expose their path as `.name`.
    if isinstance(upload, (str, os.PathLike)):
        return os.fspath(upload)
    return upload.name


def _volume_to_slices(volume):
    """Return the list of 2-D slices along the depth axis of *volume*.

    Raises:
        ValueError: If the array cannot be viewed as a stack of 2-D slices.
    """
    stack = np.asarray(volume)
    # Drop leading singleton axes (e.g. the channel axis of (1, D, H, W)) so
    # that iteration runs over depth instead of yielding 3-D sub-arrays.
    while stack.ndim > 3 and stack.shape[0] == 1:
        stack = stack[0]
    if stack.ndim == 2:
        return [stack]
    if stack.ndim != 3 or stack.shape[0] == 0:
        raise ValueError(
            f"Cannot display an array of shape {np.shape(volume)} as 2-D slices."
        )
    return list(stack)


def _render_slices(slices, max_cols=8):
    """Tile *slices* as a grayscale grid, save it to a temporary PNG, return its path."""
    n_cols = min(max_cols, len(slices))
    n_rows = -(-len(slices) // n_cols)  # ceiling division
    # A standalone Figure (instead of pyplot) holds no global state, so
    # concurrent requests cannot draw on each other's figures and no figure
    # is left open after the request.
    fig = Figure(figsize=(2.5 * n_cols, 2.5 * n_rows))
    axes = fig.subplots(n_rows, n_cols, squeeze=False).ravel()
    for ax in axes:
        ax.axis('off')
    for ax, slice_data in zip(axes, slices):
        ax.imshow(slice_data, cmap='gray')
    fig.tight_layout()
    # One file per request: a shared fixed filename could be overwritten by
    # another request before Gradio serves it.
    with tempfile.NamedTemporaryFile(suffix='.png', delete=False) as handle:
        fig.savefig(handle, format='png')
    return handle.name


# Gradio Interface
def gradio_interface(image, question):
    """Gradio callback: answer *question* about the uploaded volume and show its slices.

    Returns:
        ``(response_text, png_path)`` for the Textbox and Image outputs.
    """
    if image is None:
        raise gr.Error("Please upload a .npy image.")
    try:
        response, image_np = process_image(_resolve_upload_path(image), question)
        png_path = _render_slices(_volume_to_slices(image_np))
    except ValueError as err:
        # Show shape/format problems to the user rather than a generic error.
        raise gr.Error(str(err)) from err
    return response, png_path


# Gradio App
demo = gr.Interface(
    fn=gradio_interface,
    inputs=[
        # type="filepath" makes Gradio pass the upload as a str path.
        gr.File(label="Upload .npy Image", type="filepath", file_types=[".npy"]),
        gr.Textbox(label="Enter your question", placeholder="Ask something about the image..."),
    ],
    outputs=[
        gr.Textbox(label="Model Response"),
        gr.Image(label="Image Slices", type="filepath", image_mode='L'),
    ],
    title="Medical Image Analysis",
    description="Upload a .npy image and ask a question to analyze it using the model. The image slices will be displayed.",
)
demo.launch()