import numpy as np
import torch
from transformers import AutoTokenizer, AutoModelForCausalLM
import gradio as gr
import matplotlib.pyplot as plt

device = torch.device('cpu')  # Set to 'cuda' if using a GPU
dtype = torch.float32  # Data type for model processing

model_name_or_path = 'GoodBaiBai88/M3D-LaMed-Phi-3-4B'
proj_out_num = 256  # Number of image tokens the vision projector produces

# Load model and tokenizer
model = AutoModelForCausalLM.from_pretrained(
    model_name_or_path,
    torch_dtype=dtype,  # Matches the dtype used for the image tensor below
    device_map='cpu',
    trust_remote_code=True
)
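
# A GPU variant, as a sketch only (assumes a CUDA device is available;
# float16 is a common choice but is an assumption, not tested here):
#   device = torch.device('cuda')
#   dtype = torch.float16
#   model = AutoModelForCausalLM.from_pretrained(
#       model_name_or_path, torch_dtype=dtype, device_map='cuda',
#       trust_remote_code=True)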

tokenizer = AutoTokenizer.from_pretrained(
    model_name_or_path,
    model_max_length=512,
    padding_side="right",
    use_fast=False,
    trust_remote_code=True
)
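
# Why proj_out_num matters: the model's vision branch projects each volume
# into proj_out_num (here 256) embedding vectors, and the prompt must contain
# one "<im_patch>" placeholder per vector so the model's remote code can
# splice the image features into the token sequence.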

def process_image(image_path, question):
    # Load the preprocessed volume (expected as a .npy array)
    image_np = np.load(image_path)

    # Prepend one "<im_patch>" placeholder per projected image token
    image_tokens = "<im_patch>" * proj_out_num
    input_txt = image_tokens + question
    input_id = tokenizer(input_txt, return_tensors="pt")['input_ids'].to(device=device)

    # Add a batch dimension and move the volume to the model's device/dtype
    image_pt = torch.from_numpy(image_np).unsqueeze(0).to(dtype=dtype, device=device)

    # Generate the model's answer to the question about the image
    generation = model.generate(image_pt, input_id, max_new_tokens=256,
                                do_sample=True, top_p=0.9, temperature=1.0)
    generated_texts = tokenizer.batch_decode(generation, skip_special_tokens=True)

    return generated_texts[0], image_np
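
# Standalone usage sketch (hypothetical file name and question), bypassing the UI:
#   answer, volume = process_image('example_ct.npy', 'What organ is in this image?')
#   print(answer)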

# Gradio Interface
def gradio_interface(image, question):
    # gr.File(type="filepath") passes a plain path string, not a file object
    response, image_np = process_image(image, question)

    # Drop any leading singleton channel dimension so axis 0 indexes depth
    # slices (assumes a 3D volume after squeezing)
    volume = np.squeeze(image_np)

    # Plot the slices in a grid and save the figure as a PNG
    num_slices = volume.shape[0]
    cols = min(num_slices, 8)
    rows = (num_slices + cols - 1) // cols
    fig, axes = plt.subplots(rows, cols, figsize=(2 * cols, 2 * rows))
    axes = np.atleast_1d(axes).ravel()

    for ax, slice_data in zip(axes, volume):
        ax.imshow(slice_data, cmap='gray')
    for ax in axes:
        ax.axis('off')  # Hide axes on used and unused subplots alike

    plt.tight_layout()
    plt.savefig('slices.png')
    plt.close(fig)  # Free the figure so repeated requests do not leak memory

    return response, 'slices.png'
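
# Note: the fixed 'slices.png' name is overwritten on every request; under
# concurrent use, a per-request temporary file would be safer, e.g. (sketch):
#   import tempfile
#   out_path = tempfile.NamedTemporaryFile(suffix='.png', delete=False).name
#   plt.savefig(out_path)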

# Gradio App
gr.Interface(
    fn=gradio_interface,
    inputs=[
        gr.File(label="Upload .npy Image", type="filepath"),  # For uploading .npy image
        gr.Textbox(label="Enter your question", placeholder="Ask something about the image..."),
    ],
    outputs=[
        gr.Textbox(label="Model Response"),
        gr.Image(label="Image Slices", type="filepath", image_mode='L')
    ],
    title="Medical Image Analysis",
    description="Upload a .npy image and ask a question to analyze it using the model. The image slices will be displayed."
).launch()
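
# To expose a temporary public link, Gradio's launch() accepts share=True; a sketch:
#   gr.Interface(...).launch(share=True)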