Spaces:
Running
Running
app.py
CHANGED
@@ -4,7 +4,6 @@ from transformers import AutoTokenizer, AutoModelForCausalLM
|
|
4 |
import simple_slice_viewer as ssv
|
5 |
import SimpleITK as sikt
|
6 |
import gradio as gr
|
7 |
-
import matplotlib.pyplot as plt
|
8 |
|
9 |
device = torch.device('cpu') # Set to 'cuda' if using a GPU
|
10 |
dtype = torch.float32 # Data type for model processing
|
@@ -42,30 +41,12 @@ def process_image(image_path, question):
|
|
42 |
generation = model.generate(image_pt, input_id, max_new_tokens=256, do_sample=True, top_p=0.9, temperature=1.0)
|
43 |
generated_texts = tokenizer.batch_decode(generation, skip_special_tokens=True)
|
44 |
|
45 |
-
return generated_texts[0]
|
46 |
|
47 |
# Gradio Interface
|
48 |
def gradio_interface(image, question):
|
49 |
-
response
|
50 |
-
|
51 |
-
# Extract slices from the image
|
52 |
-
slices = []
|
53 |
-
for i in range(image_np.shape[0]): # Assuming the image is 3D
|
54 |
-
slices.append(image_np[i, :, :]) # Extract each slice
|
55 |
-
|
56 |
-
# Plot the slices and save them as images
|
57 |
-
fig, axes = plt.subplots(1, len(slices), figsize=(15, 5))
|
58 |
-
if len(slices) == 1:
|
59 |
-
axes = [axes]
|
60 |
-
|
61 |
-
for ax, slice_data in zip(axes, slices):
|
62 |
-
ax.imshow(slice_data, cmap='gray')
|
63 |
-
ax.axis('off')
|
64 |
-
|
65 |
-
plt.tight_layout()
|
66 |
-
plt.savefig('slices.png') # Save the slices as a PNG image
|
67 |
-
|
68 |
-
return response, 'slices.png'
|
69 |
|
70 |
# Gradio App
|
71 |
gr.Interface(
|
@@ -74,10 +55,7 @@ gr.Interface(
|
|
74 |
gr.File(label="Upload .npy Image", type="filepath"), # For uploading .npy image
|
75 |
gr.Textbox(label="Enter your question", placeholder="Ask something about the image..."),
|
76 |
],
|
77 |
-
outputs=
|
78 |
-
gr.Textbox(label="Model Response"),
|
79 |
-
gr.Image(label="Image Slices", type="filepath", image_mode='L')
|
80 |
-
],
|
81 |
title="Medical Image Analysis",
|
82 |
-
description="Upload a .npy image and ask a question to analyze it using the model.
|
83 |
).launch()
|
|
|
4 |
import simple_slice_viewer as ssv
|
5 |
import SimpleITK as sikt
|
6 |
import gradio as gr
|
|
|
7 |
|
8 |
device = torch.device('cpu') # Set to 'cuda' if using a GPU
|
9 |
dtype = torch.float32 # Data type for model processing
|
|
|
41 |
generation = model.generate(image_pt, input_id, max_new_tokens=256, do_sample=True, top_p=0.9, temperature=1.0)
|
42 |
generated_texts = tokenizer.batch_decode(generation, skip_special_tokens=True)
|
43 |
|
44 |
+
return generated_texts[0]
|
45 |
|
46 |
# Gradio Interface
|
47 |
def gradio_interface(image, question):
|
48 |
+
response = process_image(image.name, question)
|
49 |
+
return response
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
50 |
|
51 |
# Gradio App
|
52 |
gr.Interface(
|
|
|
55 |
gr.File(label="Upload .npy Image", type="filepath"), # For uploading .npy image
|
56 |
gr.Textbox(label="Enter your question", placeholder="Ask something about the image..."),
|
57 |
],
|
58 |
+
outputs=gr.Textbox(label="Model Response"),
|
|
|
|
|
|
|
59 |
title="Medical Image Analysis",
|
60 |
+
description="Upload a .npy image and ask a question to analyze it using the model."
|
61 |
).launch()
|