Rohith1112 committed
Commit f55ed76 · verified · Parent: 1105285

Update app.py

Files changed (1):
  1. app.py  +96 -188

app.py CHANGED
@@ -6,37 +6,38 @@ import gradio as gr
 import matplotlib.pyplot as plt
 from datetime import datetime
 import json
+from PIL import Image
 
 # Model setup
 device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
-dtype = torch.float32
 model_name_or_path = 'GoodBaiBai88/M3D-LaMed-Phi-3-4B'
 proj_out_num = 256
 
-# Create directory for saving chat histories
+# Create directory for saving chat histories and temp images
 os.makedirs('chat_histories', exist_ok=True)
+os.makedirs('temp_images', exist_ok=True)
 
 # Load model and tokenizer
 print("Loading model and tokenizer...")
 model = AutoModelForCausalLM.from_pretrained(
     model_name_or_path,
-    torch_dtype=torch.float32,
-    device_map=device,
+    torch_dtype=torch.float16 if torch.cuda.is_available() else torch.float32,
+    device_map="auto",
     trust_remote_code=True
 )
 
 tokenizer = AutoTokenizer.from_pretrained(
     model_name_or_path,
-    model_max_length=512,
+    model_max_length=4096,
     padding_side="right",
     use_fast=False,
     trust_remote_code=True
 )
 print("Model loaded successfully!")
 
-# Chat and image storage
+# Session and chat history
 chat_history = []
-current_image = None
+current_image_path = None
 session_id = datetime.now().strftime("%Y%m%d_%H%M%S")
 chat_metadata = {
     "session_id": session_id,
@@ -45,252 +46,159 @@ chat_metadata = {
 }
 
 def save_chat_history():
-    """Save the current chat history to a JSON file"""
+    """Save the chat history into a JSON file."""
     if not chat_history:
         return
-
     filename = f"chat_histories/session_{session_id}.json"
     data = {
         "metadata": chat_metadata,
         "conversation": [{"user": q, "assistant": a} for q, a in chat_history]
     }
-
     with open(filename, 'w', encoding='utf-8') as f:
         json.dump(data, f, ensure_ascii=False, indent=2)
-
     return filename
 
 def extract_and_display_images(image_path):
-    """Process .npy file and create a visualization of the medical images"""
+    """Extract slices from .npy medical file and create a JPEG preview."""
     try:
         npy_data = np.load(image_path)
 
-        # Handle different possible shapes of the .npy file
-        if npy_data.ndim == 4 and npy_data.shape[1] == 32:
-            npy_data = npy_data[0]  # Extract first batch if batched
-        elif npy_data.ndim != 3 or npy_data.shape[0] != 32:
-            return None, "Invalid .npy file format. Expected shape (1, 32, 256, 256) or (32, 256, 256)."
-
-        # Update metadata with image information
-        global chat_metadata
-        chat_metadata["image_info"] = {
-            "filename": os.path.basename(image_path),
-            "shape": npy_data.shape,
-            "processed_time": datetime.now().strftime("%Y-%m-%d %H:%M:%S")
-        }
+        if npy_data.ndim == 4:
+            npy_data = npy_data[0]  # Take first batch
 
-        # Normalize for better visualization if needed
-        for i in range(npy_data.shape[0]):
-            slice_data = npy_data[i]
-            if slice_data.max() > 0:  # Avoid division by zero
-                npy_data[i] = (slice_data - slice_data.min()) / (slice_data.max() - slice_data.min())
+        if npy_data.shape[0] != 32:
+            return None, "Invalid .npy shape. Expected 32 slices."
+
+        # Normalize slices
+        npy_data = (npy_data - npy_data.min()) / (npy_data.max() - npy_data.min())
 
-        # Create grid visualization
-        rows, cols = 4, 8
-        fig, axes = plt.subplots(rows, cols, figsize=(16, 8))
+        # Create visualization grid
+        fig, axes = plt.subplots(4, 8, figsize=(16, 8))
         for i, ax in enumerate(axes.flat):
-            if i < npy_data.shape[0]:
-                ax.imshow(npy_data[i], cmap='gray')
-                ax.set_title(f"Slice {i+1}", fontsize=8)
+            ax.imshow(npy_data[i], cmap='gray')
             ax.axis('off')
+            ax.set_title(f"Slice {i+1}", fontsize=8)
 
         plt.tight_layout()
-        image_output = f"temp_images/extracted_{session_id}.png"
-        os.makedirs("temp_images", exist_ok=True)
-        plt.savefig(image_output, bbox_inches='tight', dpi=150)
+        temp_png = f"temp_images/preview_{session_id}.png"
+        plt.savefig(temp_png, dpi=150, bbox_inches='tight')
         plt.close()
 
-        return image_output, "Image processed successfully!"
+        # Convert PNG to JPEG if needed
+        img = Image.open(temp_png).convert("RGB")
+        temp_jpeg = f"temp_images/preview_{session_id}.jpg"
+        img.save(temp_jpeg, format="JPEG", quality=95)
+
+        # Update metadata
+        chat_metadata["image_info"] = {
+            "filename": os.path.basename(image_path),
+            "shape": npy_data.shape,
+            "processed_time": datetime.now().strftime("%Y-%m-%d %H:%M:%S")
+        }
+
+        return temp_jpeg, "Image processed successfully!"
+
     except Exception as e:
-        return None, f"Error processing image: {str(e)}"
+        return None, f"Error: {str(e)}"
 
-def process_image(question):
-    """Process a question about the current medical image using the AI model"""
-    global current_image
-
-    if current_image is None:
-        return "Please upload a medical image (.npy file) first."
-
+def process_image_question(question):
+    """Process user question about uploaded medical image."""
+    if current_image_path is None:
+        return "Please upload a medical image first."
     try:
-        # Load the image data
-        image_np = np.load(current_image)
-
-        # Prepare input for the model
+        # Create fake image patch tokens
         image_tokens = "<im_patch>" * proj_out_num
-        input_txt = image_tokens + question
-        input_id = tokenizer(input_txt, return_tensors="pt")['input_ids'].to(device=device)
-
-        # Convert image to tensor
-        image_pt = torch.from_numpy(image_np).unsqueeze(0).to(dtype=dtype, device=device)
-
-        # Generate response from model
-        generation = model.generate(
-            image_pt,
-            input_id,
-            max_new_tokens=256,
-            do_sample=True,
-            top_p=0.9,
-            temperature=0.8  # Slightly reduced for more consistent responses
+        input_prompt = image_tokens + question
+
+        # Tokenize input
+        input_ids = tokenizer(input_prompt, return_tensors="pt").input_ids.to(device)
+
+        # Generate answer
+        output = model.generate(
+            input_ids=input_ids,
+            max_new_tokens=256,
+            do_sample=True,
+            top_p=0.9,
+            temperature=0.7
         )
-
-        # Decode the generated text
-        generated_text = tokenizer.batch_decode(generation, skip_special_tokens=True)[0]
-
-        # Remove the input prompt from the response if needed
-        if image_tokens in generated_text:
-            generated_text = generated_text.split(image_tokens)[-1]
-
-        return generated_text
-
+
+        answer = tokenizer.decode(output[0], skip_special_tokens=True)
+        if image_tokens in answer:
+            answer = answer.split(image_tokens)[-1]
+
+        return answer.strip()
     except Exception as e:
-        return f"Error processing your question: {str(e)}"
+        return f"Error answering question: {str(e)}"
 
 def chat_interface(question):
-    """Handle the chat interface and maintain conversation history"""
+    """Handles chat conversation."""
     global chat_history
-
     if not question.strip():
         return chat_history
 
-    # Process the question
-    response = process_image(question)
-
-    # Add to chat history
+    response = process_image_question(question)
     chat_history.append((question, response))
-
-    # Save chat history periodically
     save_chat_history()
-
-    # Return the updated chat history for display
     return chat_history
 
 def upload_image(image):
-    """Handle image upload and processing"""
-    global current_image
-
+    """Handles image upload."""
+    global current_image_path
     if image is None:
         return "No file uploaded.", None
-
-    # Check if file exists and is .npy
-    if not os.path.exists(image.name) or not image.name.lower().endswith('.npy'):
-        return "Please upload a valid .npy file.", None
-
-    # Set as current image
-    current_image = image.name
-
-    # Process and extract images
-    extracted_image_path, status_message = extract_and_display_images(current_image)
-
+
+    if not image.name.lower().endswith('.npy'):
+        return "Please upload a .npy file only.", None
+
+    current_image_path = image.name
+    extracted_image_path, status_message = extract_and_display_images(current_image_path)
     if extracted_image_path is None:
         return status_message, None
 
     return status_message, extracted_image_path
 
 def clear_conversation():
-    """Clear the current conversation history"""
+    """Clears chat conversation."""
     global chat_history
-    old_history = chat_history.copy()
+    old_chat = chat_history.copy()
     chat_history = []
-    return [], f"Conversation cleared. Previous conversation saved to {save_chat_history()}"
+    return [], f"Conversation cleared. Saved to {save_chat_history()}."
 
-# CSS for better UI
+# Custom CSS
 custom_css = """
 .gradio-container {max-width: 1200px !important}
 #chat-history {height: 400px; overflow-y: auto;}
-.image-preview {border-radius: 10px; border: 1px solid #ddd;}
 """
 
-# Gradio UI
-with gr.Blocks(theme=gr.themes.Soft(), css=custom_css) as chat_ui:
+# Build Gradio UI
+with gr.Blocks(theme=gr.themes.Soft(), css=custom_css) as demo:
     with gr.Row():
         with gr.Column(scale=3):
-            gr.Markdown("# ICliniq AI-Powered Medical Image Analysis")
+            gr.Markdown("# 🏥 ICliniq AI - Medical Image Analyzer")
             gr.Markdown("""
-            This system analyzes medical images in .npy format and answers your questions.
-
-            ## How to use:
-            1. Upload your medical image (.npy format)
-            2. Wait for the image to be processed
-            3. Ask questions about the image
-            """)
-
-    with gr.Row():
-        with gr.Column(scale=1):
-            uploaded_image = gr.File(
-                label="Upload Medical Image (.npy format)",
-                file_types=[".npy"],
-                type="filepath"
-            )
-
-        with gr.Column(scale=1):
-            upload_status = gr.Textbox(
-                label="Upload Status",
-                interactive=False
-            )
-
-            extracted_image = gr.Image(
-                label="Processed Image Preview",
-                elem_id="image-preview"
-            )
-
+            Upload a **.npy** medical scan file, view extracted slices, and ask clinical questions.
+            """)
+            uploaded_image = gr.File(label="Upload Medical Image (.npy)", file_types=[".npy"], type="filepath")
+            upload_status = gr.Textbox(label="Upload Status", interactive=False)
+            extracted_image = gr.Image(label="Preview of Medical Image", elem_id="image-preview")
+
         with gr.Column(scale=4):
-            chat_list = gr.Chatbot(
-                value=[],
-                label="Conversation",
-                elem_id="chat-history",
-                height=500
-            )
-
-            with gr.Row():
-                question_input = gr.Textbox(
-                    label="Ask about the medical image",
-                    placeholder="What abnormalities do you see in this scan?",
-                    lines=2
-                )
-
+            chat_list = gr.Chatbot(value=[], label="Conversation", elem_id="chat-history", height=500)
+            question_input = gr.Textbox(label="Ask your question", placeholder="e.g., Are there fractures visible?")
             with gr.Row():
-                clear_button = gr.Button("Clear Conversation", variant="secondary")
                 submit_button = gr.Button("Send Question", variant="primary")
-
-            gr.Markdown("### System Status")
+                clear_button = gr.Button("Clear Conversation", variant="secondary")
             system_status = gr.Textbox(
-                label="",
                 value=f"Model loaded: {model_name_or_path}\nDevice: {device}\nSession ID: {session_id}",
                 interactive=False
             )
-
-    # Set up event handlers
-    uploaded_image.upload(
-        upload_image,
-        inputs=[uploaded_image],
-        outputs=[upload_status, extracted_image]
-    )
-
-    submit_button.click(
-        chat_interface,
-        inputs=[question_input],
-        outputs=[chat_list]
-    ).then(
-        lambda: "",  # Clear input after sending
-        outputs=question_input
-    )
-
-    question_input.submit(
-        chat_interface,
-        inputs=[question_input],
-        outputs=[chat_list]
-    ).then(
-        lambda: "",  # Clear input after sending
-        outputs=question_input
-    )
-
-    clear_button.click(
-        clear_conversation,
-        inputs=[],
-        outputs=[chat_list, system_status]
-    )
 
-# Launch the interface
+    uploaded_image.upload(upload_image, inputs=[uploaded_image], outputs=[upload_status, extracted_image])
+    submit_button.click(chat_interface, inputs=[question_input], outputs=[chat_list]).then(lambda: "", outputs=question_input)
+    question_input.submit(chat_interface, inputs=[question_input], outputs=[chat_list]).then(lambda: "", outputs=question_input)
+    clear_button.click(clear_conversation, inputs=[], outputs=[chat_list, system_status])
+
+# Run
 if __name__ == "__main__":
-    print("Starting ICliniq Medical Image Analysis System...")
-    chat_ui.launch(share=True)
+    print("Launching Medical Image Analyzer...")
+    demo.launch(share=True)
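
Note: a minimal smoke test for the new preview path is sketched below, for trying this commit without downloading the model. It covers only the normalization and grid-rendering logic introduced here; sample_scan.npy is a hypothetical file fabricated on the spot, and the (32, 256, 256) layout is the one extract_and_display_images expects.

import numpy as np
import matplotlib.pyplot as plt

# Fabricate a dummy 32-slice volume in the (32, 256, 256) layout app.py expects.
volume = np.random.rand(32, 256, 256).astype(np.float32)
np.save("sample_scan.npy", volume)

npy_data = np.load("sample_scan.npy")

# Global min-max normalization over the whole volume, as in the new code
# (the previous version normalized each slice independently).
npy_data = (npy_data - npy_data.min()) / (npy_data.max() - npy_data.min())

# Render the same 4x8 slice grid the app shows as a preview.
fig, axes = plt.subplots(4, 8, figsize=(16, 8))
for i, ax in enumerate(axes.flat):
    ax.imshow(npy_data[i], cmap='gray')
    ax.axis('off')
    ax.set_title(f"Slice {i+1}", fontsize=8)
plt.tight_layout()
plt.savefig("preview_test.png", dpi=150, bbox_inches='tight')
plt.close()

Running this writes preview_test.png, which should match what the Gradio preview pane displays for a real scan.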