Rohith1112 committed
Commit 9cd5c21 · verified · 1 Parent(s): 8b6818e

Update app.py

Files changed (1):
  app.py +282 -85
app.py CHANGED
@@ -1,99 +1,296 @@
- import os
- import numpy as np
- import torch
- from transformers import AutoTokenizer, AutoModelForCausalLM
- import gradio as gr
- import matplotlib.pyplot as plt
-
- # Model setup
- device = torch.device('cpu')  # Use 'cuda' if GPU is available
- dtype = torch.float32
- model_name_or_path = 'GoodBaiBai88/M3D-LaMed-Phi-3-4B'
- proj_out_num = 256
-
- # Load model and tokenizer
- model = AutoModelForCausalLM.from_pretrained(
-     model_name_or_path,
-     torch_dtype=torch.float32,
-     device_map='cpu',
-     trust_remote_code=True
- )
-
- tokenizer = AutoTokenizer.from_pretrained(
-     model_name_or_path,
-     model_max_length=512,
-     padding_side="right",
-     use_fast=False,
-     trust_remote_code=True
- )
-
- # Chat history storage
- chat_history = []
- current_image = None
-
- def extract_and_display_images(image_path):
-     npy_data = np.load(image_path)
-     if npy_data.ndim == 4 and npy_data.shape[1] == 32:
-         npy_data = npy_data[0]
-     elif npy_data.ndim != 3 or npy_data.shape[0] != 32:
-         return "Invalid .npy file format. Expected shape (1, 32, 256, 256) or (32, 256, 256)."
-
-     fig, axes = plt.subplots(4, 8, figsize=(12, 6))
-     for i, ax in enumerate(axes.flat):
-         ax.imshow(npy_data[i], cmap='gray')
-         ax.axis('off')
-
-     image_output = "extracted_images.png"
-     plt.savefig(image_output, bbox_inches='tight')
-     plt.close()
-     return image_output
-
- def process_image(question):
-     global current_image
-     if current_image is None:
-         return "Please upload an image first."
-
-     image_np = np.load(current_image)
-     image_tokens = "<im_patch>" * proj_out_num
-     input_txt = image_tokens + question
-     input_id = tokenizer(input_txt, return_tensors="pt")['input_ids'].to(device=device)
-
-     image_pt = torch.from_numpy(image_np).unsqueeze(0).to(dtype=dtype, device=device)
-     generation = model.generate(image_pt, input_id, max_new_tokens=256, do_sample=True, top_p=0.9, temperature=1.0)
-     generated_texts = tokenizer.batch_decode(generation, skip_special_tokens=True)
-     return generated_texts[0]
-
- def chat_interface(question):
-     global chat_history
-     response = process_image(question)
-     chat_history.append((question, response))  # Save dynamic chat history
-     return chat_history  # Return updated chat history
-
- def upload_image(image):
-     global current_image
-     current_image = image.name
-     extracted_image_path = extract_and_display_images(current_image)
-     return "Image uploaded and processed successfully!", extracted_image_path
-
- # Gradio UI
- with gr.Blocks(theme="soft") as chat_ui:
-     with gr.Row():
-         # Dynamic Chat Area
-         with gr.Column(scale=4):
-             gr.Markdown("### ICliniq AI-Powered Medical Image Analysis")
-             chat_list = gr.Chatbot(value=[], label="Chat History", type='messages', elem_id="chat-history")  # Dynamic chat with messages type
-             uploaded_image = gr.File(label="Upload .npy Image", type="filepath")
-             upload_status = gr.Textbox(label="Status", interactive=False)
-             extracted_image = gr.Image(label="Extracted Images")
-             question_input = gr.Textbox(label="Ask a question", placeholder="Ask something about the image...")
-             submit_button = gr.Button("Send")
-
-     # Upload and Processing Interactions
-     uploaded_image.upload(upload_image, uploaded_image, [upload_status, extracted_image])
-     submit_button.click(chat_interface, question_input, chat_list)
-     question_input.submit(chat_interface, question_input, chat_list)
-
- chat_ui.launch()
+ import os
+ import numpy as np
+ import torch
+ from transformers import AutoTokenizer, AutoModelForCausalLM
+ import gradio as gr
+ import matplotlib.pyplot as plt
+ from datetime import datetime
+ import json
+
+ # Model setup
+ device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
+ dtype = torch.float32
+ model_name_or_path = 'GoodBaiBai88/M3D-LaMed-Phi-3-4B'
+ proj_out_num = 256
+
+ # Create directory for saving chat histories
+ os.makedirs('chat_histories', exist_ok=True)
+
+ # Load model and tokenizer
+ print("Loading model and tokenizer...")
+ model = AutoModelForCausalLM.from_pretrained(
+     model_name_or_path,
+     torch_dtype=torch.float32,
+     device_map=device,
+     trust_remote_code=True
+ )
+
+ tokenizer = AutoTokenizer.from_pretrained(
+     model_name_or_path,
+     model_max_length=512,
+     padding_side="right",
+     use_fast=False,
+     trust_remote_code=True
+ )
+ print("Model loaded successfully!")
+
+ # Chat and image storage
+ chat_history = []
+ current_image = None
+ session_id = datetime.now().strftime("%Y%m%d_%H%M%S")
+ chat_metadata = {
+     "session_id": session_id,
+     "start_time": datetime.now().strftime("%Y-%m-%d %H:%M:%S"),
+     "image_info": None
+ }
+
+ def save_chat_history():
+     """Save the current chat history to a JSON file"""
+     if not chat_history:
+         return
+
+     filename = f"chat_histories/session_{session_id}.json"
+     data = {
+         "metadata": chat_metadata,
+         "conversation": [{"user": q, "assistant": a} for q, a in chat_history]
+     }
+
+     with open(filename, 'w', encoding='utf-8') as f:
+         json.dump(data, f, ensure_ascii=False, indent=2)
+
+     return filename
+
+ def extract_and_display_images(image_path):
+     """Process .npy file and create a visualization of the medical images"""
+     try:
+         npy_data = np.load(image_path)
+
+         # Handle different possible shapes of the .npy file
+         if npy_data.ndim == 4 and npy_data.shape[1] == 32:
+             npy_data = npy_data[0]  # Take the first volume if batched
+         elif npy_data.ndim != 3 or npy_data.shape[0] != 32:
+             return None, "Invalid .npy file format. Expected shape (1, 32, 256, 256) or (32, 256, 256)."
+
+         # Update metadata with image information
+         global chat_metadata
+         chat_metadata["image_info"] = {
+             "filename": os.path.basename(image_path),
+             "shape": npy_data.shape,
+             "processed_time": datetime.now().strftime("%Y-%m-%d %H:%M:%S")
+         }
+
+         # Normalize each slice to [0, 1] for better visualization
+         npy_data = npy_data.astype(np.float32)  # Cast so normalized values are not truncated on integer arrays
+         for i in range(npy_data.shape[0]):
+             slice_data = npy_data[i]
+             if slice_data.max() > slice_data.min():  # Skip constant slices to avoid division by zero
+                 npy_data[i] = (slice_data - slice_data.min()) / (slice_data.max() - slice_data.min())
+
+         # Create grid visualization
+         rows, cols = 4, 8
+         fig, axes = plt.subplots(rows, cols, figsize=(16, 8))
+         for i, ax in enumerate(axes.flat):
+             if i < npy_data.shape[0]:
+                 ax.imshow(npy_data[i], cmap='gray')
+                 ax.set_title(f"Slice {i+1}", fontsize=8)
+             ax.axis('off')
+
+         plt.tight_layout()
+         image_output = f"temp_images/extracted_{session_id}.png"
+         os.makedirs("temp_images", exist_ok=True)
+         plt.savefig(image_output, bbox_inches='tight', dpi=150)
+         plt.close()
+
+         return image_output, "Image processed successfully!"
+     except Exception as e:
+         return None, f"Error processing image: {str(e)}"
+
+ def process_image(question):
+     """Process a question about the current medical image using the AI model"""
+     global current_image
+
+     if current_image is None:
+         return "Please upload a medical image (.npy file) first."
+
+     try:
+         # Load the image data
+         image_np = np.load(current_image)
+
+         # Prepare input for the model
+         image_tokens = "<im_patch>" * proj_out_num
+         input_txt = image_tokens + question
+         input_id = tokenizer(input_txt, return_tensors="pt")['input_ids'].to(device=device)
+
+         # Convert image to tensor
+         image_pt = torch.from_numpy(image_np).unsqueeze(0).to(dtype=dtype, device=device)
+
+         # Generate response from model
+         generation = model.generate(
+             image_pt,
+             input_id,
+             max_new_tokens=256,
+             do_sample=True,
+             top_p=0.9,
+             temperature=0.8  # Slightly reduced for more consistent responses
+         )
+
+         # Decode the generated text
+         generated_text = tokenizer.batch_decode(generation, skip_special_tokens=True)[0]
+
+         # Remove the input prompt from the response if needed
+         if image_tokens in generated_text:
+             generated_text = generated_text.split(image_tokens)[-1]
+
+         return generated_text
+
+     except Exception as e:
+         return f"Error processing your question: {str(e)}"
+
+ def chat_interface(question):
+     """Handle the chat interface and maintain conversation history"""
+     global chat_history
+
+     if not question.strip():
+         return chat_history
+
+     # Process the question
+     response = process_image(question)
+
+     # Add to chat history
+     chat_history.append((question, response))
+
+     # Save chat history after every turn
+     save_chat_history()
+
+     # Return the updated chat history for display
+     return chat_history
+
+ def upload_image(image):
+     """Handle image upload and processing"""
+     global current_image
+
+     if image is None:
+         return "No file uploaded.", None
+
+     # Check that the file exists and is a .npy file
+     if not os.path.exists(image.name) or not image.name.lower().endswith('.npy'):
+         return "Please upload a valid .npy file.", None
+
+     # Set as current image
+     current_image = image.name
+
+     # Process and extract images
+     extracted_image_path, status_message = extract_and_display_images(current_image)
+
+     if extracted_image_path is None:
+         return status_message, None
+
+     return status_message, extracted_image_path
+
+ def clear_conversation():
+     """Clear the current conversation history"""
+     global chat_history
+     saved_file = save_chat_history()  # Save before clearing, or there is nothing left to write
+     chat_history = []
+     status = f"Conversation cleared. Previous conversation saved to {saved_file}" if saved_file else "Conversation cleared."
+     return [], status
+
+ # CSS for better UI
+ custom_css = """
+ .gradio-container {max-width: 1200px !important}
+ #chat-history {height: 400px; overflow-y: auto;}
+ #image-preview {border-radius: 10px; border: 1px solid #ddd;}
+ """
+
+ # Gradio UI
+ with gr.Blocks(theme=gr.themes.Soft(), css=custom_css) as chat_ui:
+     with gr.Row():
+         with gr.Column(scale=3):
+             gr.Markdown("# ICliniq AI-Powered Medical Image Analysis")
+             gr.Markdown("""
+             This system analyzes medical images in .npy format and answers your questions.
+
+             ## How to use:
+             1. Upload your medical image (.npy format)
+             2. Wait for the image to be processed
+             3. Ask questions about the image
+             """)
+
+     with gr.Row():
+         with gr.Column(scale=1):
+             uploaded_image = gr.File(
+                 label="Upload Medical Image (.npy format)",
+                 file_types=[".npy"],
+                 type="file"
+             )
+
+         with gr.Column(scale=1):
+             upload_status = gr.Textbox(
+                 label="Upload Status",
+                 interactive=False
+             )
+
+             extracted_image = gr.Image(
+                 label="Processed Image Preview",
+                 elem_id="image-preview"
+             )
+
+         with gr.Column(scale=4):
+             chat_list = gr.Chatbot(
+                 value=[],
+                 label="Conversation",
+                 elem_id="chat-history",
+                 height=500
+             )
+
+             with gr.Row():
+                 question_input = gr.Textbox(
+                     label="Ask about the medical image",
+                     placeholder="What abnormalities do you see in this scan?",
+                     lines=2
+                 )
+
+             with gr.Row():
+                 clear_button = gr.Button("Clear Conversation", variant="secondary")
+                 submit_button = gr.Button("Send Question", variant="primary")
+
+             gr.Markdown("### System Status")
+             system_status = gr.Textbox(
+                 label="",
+                 value=f"Model loaded: {model_name_or_path}\nDevice: {device}\nSession ID: {session_id}",
+                 interactive=False
+             )
+
+     # Set up event handlers
+     uploaded_image.upload(
+         upload_image,
+         inputs=[uploaded_image],
+         outputs=[upload_status, extracted_image]
+     )
+
+     submit_button.click(
+         chat_interface,
+         inputs=[question_input],
+         outputs=[chat_list]
+     ).then(
+         lambda: "",  # Clear input after sending
+         outputs=question_input
+     )
+
+     question_input.submit(
+         chat_interface,
+         inputs=[question_input],
+         outputs=[chat_list]
+     ).then(
+         lambda: "",  # Clear input after sending
+         outputs=question_input
+     )
+
+     clear_button.click(
+         clear_conversation,
+         inputs=[],
+         outputs=[chat_list, system_status]
+     )
+
+ # Launch the interface
+ if __name__ == "__main__":
+     print("Starting ICliniq Medical Image Analysis System...")
+     chat_ui.launch(share=True)
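
Note on exercising this revision: the upload path validates that the .npy file has shape (1, 32, 256, 256) or (32, 256, 256). A minimal sketch for generating a compatible synthetic volume with NumPy follows; the filename dummy_volume.npy is a hypothetical choice for illustration, not part of the commit.

    # Create a synthetic 32-slice volume matching the shape check in
    # extract_and_display_images, then save it as .npy so it can be
    # uploaded through the Gradio file widget.
    # "dummy_volume.npy" is a hypothetical name, not part of the commit.
    import numpy as np

    volume = np.random.rand(32, 256, 256).astype(np.float32)
    np.save("dummy_volume.npy", volume)

Uploading the saved file should render the 4x8 slice grid and allow questions to be routed through model.generate.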