Rohith1112 commited on
Commit
0c5b5f7
·
verified ·
1 Parent(s): e483357
Files changed (1) hide show
  1. app.py +27 -5
app.py CHANGED
@@ -4,6 +4,7 @@ from transformers import AutoTokenizer, AutoModelForCausalLM
4
  import simple_slice_viewer as ssv
5
  import SimpleITK as sikt
6
  import gradio as gr
 
7
 
8
  device = torch.device('cpu') # Set to 'cuda' if using a GPU
9
  dtype = torch.float32 # Data type for model processing
@@ -41,12 +42,30 @@ def process_image(image_path, question):
41
  generation = model.generate(image_pt, input_id, max_new_tokens=256, do_sample=True, top_p=0.9, temperature=1.0)
42
  generated_texts = tokenizer.batch_decode(generation, skip_special_tokens=True)
43
 
44
- return generated_texts[0]
45
 
46
  # Gradio Interface
47
  def gradio_interface(image, question):
48
- response = process_image(image.name, question)
49
- return response
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
50
 
51
  # Gradio App
52
  gr.Interface(
@@ -55,7 +74,10 @@ gr.Interface(
55
  gr.File(label="Upload .npy Image", type="filepath"), # For uploading .npy image
56
  gr.Textbox(label="Enter your question", placeholder="Ask something about the image..."),
57
  ],
58
- outputs=gr.Textbox(label="Model Response"),
 
 
 
59
  title="Medical Image Analysis",
60
- description="Upload a .npy image and ask a question to analyze it using the model."
61
  ).launch()
 
4
  import simple_slice_viewer as ssv
5
  import SimpleITK as sikt
6
  import gradio as gr
7
+ import matplotlib.pyplot as plt
8
 
9
  device = torch.device('cpu') # Set to 'cuda' if using a GPU
10
  dtype = torch.float32 # Data type for model processing
 
42
  generation = model.generate(image_pt, input_id, max_new_tokens=256, do_sample=True, top_p=0.9, temperature=1.0)
43
  generated_texts = tokenizer.batch_decode(generation, skip_special_tokens=True)
44
 
45
+ return generated_texts[0], image_np
46
 
47
  # Gradio Interface
48
  def gradio_interface(image, question):
49
+ response, image_np = process_image(image.name, question)
50
+
51
+ # Extract slices from the image
52
+ slices = []
53
+ for i in range(image_np.shape[0]): # Assuming the image is 3D
54
+ slices.append(image_np[i, :, :]) # Extract each slice
55
+
56
+ # Plot the slices and save them as images
57
+ fig, axes = plt.subplots(1, len(slices), figsize=(15, 5))
58
+ if len(slices) == 1:
59
+ axes = [axes]
60
+
61
+ for ax, slice_data in zip(axes, slices):
62
+ ax.imshow(slice_data, cmap='gray')
63
+ ax.axis('off')
64
+
65
+ plt.tight_layout()
66
+ plt.savefig('slices.png') # Save the slices as a PNG image
67
+
68
+ return response, 'slices.png'
69
 
70
  # Gradio App
71
  gr.Interface(
 
74
  gr.File(label="Upload .npy Image", type="filepath"), # For uploading .npy image
75
  gr.Textbox(label="Enter your question", placeholder="Ask something about the image..."),
76
  ],
77
+ outputs=[
78
+ gr.Textbox(label="Model Response"),
79
+ gr.Image(label="Image Slices", type="filepath", image_mode='L')
80
+ ],
81
  title="Medical Image Analysis",
82
+ description="Upload a .npy image and ask a question to analyze it using the model. The image slices will be displayed."
83
  ).launch()