Rohith1112 commited on
Commit
cecc48c
·
verified ·
1 Parent(s): 0c5b5f7
Files changed (1) hide show
  1. app.py +5 -27
app.py CHANGED
@@ -4,7 +4,6 @@ from transformers import AutoTokenizer, AutoModelForCausalLM
4
  import simple_slice_viewer as ssv
5
  import SimpleITK as sikt
6
  import gradio as gr
7
- import matplotlib.pyplot as plt
8
 
9
  device = torch.device('cpu') # Set to 'cuda' if using a GPU
10
  dtype = torch.float32 # Data type for model processing
@@ -42,30 +41,12 @@ def process_image(image_path, question):
42
  generation = model.generate(image_pt, input_id, max_new_tokens=256, do_sample=True, top_p=0.9, temperature=1.0)
43
  generated_texts = tokenizer.batch_decode(generation, skip_special_tokens=True)
44
 
45
- return generated_texts[0], image_np
46
 
47
  # Gradio Interface
48
  def gradio_interface(image, question):
49
- response, image_np = process_image(image.name, question)
50
-
51
- # Extract slices from the image
52
- slices = []
53
- for i in range(image_np.shape[0]): # Assuming the image is 3D
54
- slices.append(image_np[i, :, :]) # Extract each slice
55
-
56
- # Plot the slices and save them as images
57
- fig, axes = plt.subplots(1, len(slices), figsize=(15, 5))
58
- if len(slices) == 1:
59
- axes = [axes]
60
-
61
- for ax, slice_data in zip(axes, slices):
62
- ax.imshow(slice_data, cmap='gray')
63
- ax.axis('off')
64
-
65
- plt.tight_layout()
66
- plt.savefig('slices.png') # Save the slices as a PNG image
67
-
68
- return response, 'slices.png'
69
 
70
  # Gradio App
71
  gr.Interface(
@@ -74,10 +55,7 @@ gr.Interface(
74
  gr.File(label="Upload .npy Image", type="filepath"), # For uploading .npy image
75
  gr.Textbox(label="Enter your question", placeholder="Ask something about the image..."),
76
  ],
77
- outputs=[
78
- gr.Textbox(label="Model Response"),
79
- gr.Image(label="Image Slices", type="filepath", image_mode='L')
80
- ],
81
  title="Medical Image Analysis",
82
- description="Upload a .npy image and ask a question to analyze it using the model. The image slices will be displayed."
83
  ).launch()
 
4
  import simple_slice_viewer as ssv
5
  import SimpleITK as sikt
6
  import gradio as gr
 
7
 
8
  device = torch.device('cpu') # Set to 'cuda' if using a GPU
9
  dtype = torch.float32 # Data type for model processing
 
41
  generation = model.generate(image_pt, input_id, max_new_tokens=256, do_sample=True, top_p=0.9, temperature=1.0)
42
  generated_texts = tokenizer.batch_decode(generation, skip_special_tokens=True)
43
 
44
+ return generated_texts[0]
45
 
46
  # Gradio Interface
47
  def gradio_interface(image, question):
48
+ response = process_image(image.name, question)
49
+ return response
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
50
 
51
  # Gradio App
52
  gr.Interface(
 
55
  gr.File(label="Upload .npy Image", type="filepath"), # For uploading .npy image
56
  gr.Textbox(label="Enter your question", placeholder="Ask something about the image..."),
57
  ],
58
+ outputs=gr.Textbox(label="Model Response"),
 
 
 
59
  title="Medical Image Analysis",
60
+ description="Upload a .npy image and ask a question to analyze it using the model."
61
  ).launch()