Rohith1112 committed on
Commit 4ea26bf · verified · 1 Parent(s): f55ed76

Update app.py

Files changed (1):
  1. app.py +57 -163
app.py CHANGED
@@ -4,201 +4,95 @@ import torch
  from transformers import AutoTokenizer, AutoModelForCausalLM
  import gradio as gr
  import matplotlib.pyplot as plt
- from datetime import datetime
- import json
- from PIL import Image

  # Model setup
- device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
  model_name_or_path = 'GoodBaiBai88/M3D-LaMed-Phi-3-4B'
  proj_out_num = 256

- # Create directory for saving chat histories and temp images
- os.makedirs('chat_histories', exist_ok=True)
- os.makedirs('temp_images', exist_ok=True)
-
  # Load model and tokenizer
- print("Loading model and tokenizer...")
  model = AutoModelForCausalLM.from_pretrained(
      model_name_or_path,
-     torch_dtype=torch.float16 if torch.cuda.is_available() else torch.float32,
-     device_map="auto",
      trust_remote_code=True
  )

  tokenizer = AutoTokenizer.from_pretrained(
      model_name_or_path,
-     model_max_length=4096,
      padding_side="right",
      use_fast=False,
      trust_remote_code=True
  )
- print("Model loaded successfully!")

- # Session and chat history
  chat_history = []
- current_image_path = None
- session_id = datetime.now().strftime("%Y%m%d_%H%M%S")
- chat_metadata = {
-     "session_id": session_id,
-     "start_time": datetime.now().strftime("%Y-%m-%d %H:%M:%S"),
-     "image_info": None
- }
-
- def save_chat_history():
-     """Save the chat history into a JSON file."""
-     if not chat_history:
-         return
-     filename = f"chat_histories/session_{session_id}.json"
-     data = {
-         "metadata": chat_metadata,
-         "conversation": [{"user": q, "assistant": a} for q, a in chat_history]
-     }
-     with open(filename, 'w', encoding='utf-8') as f:
-         json.dump(data, f, ensure_ascii=False, indent=2)
-     return filename

  def extract_and_display_images(image_path):
-     """Extract slices from .npy medical file and create a JPEG preview."""
-     try:
-         npy_data = np.load(image_path)
-
-         if npy_data.ndim == 4:
-             npy_data = npy_data[0]  # Take first batch
-
-         if npy_data.shape[0] != 32:
-             return None, "Invalid .npy shape. Expected 32 slices."
-
-         # Normalize slices
-         npy_data = (npy_data - npy_data.min()) / (npy_data.max() - npy_data.min())
-
-         # Create visualization grid
-         fig, axes = plt.subplots(4, 8, figsize=(16, 8))
-         for i, ax in enumerate(axes.flat):
-             ax.imshow(npy_data[i], cmap='gray')
-             ax.axis('off')
-             ax.set_title(f"Slice {i+1}", fontsize=8)
-
-         plt.tight_layout()
-         temp_png = f"temp_images/preview_{session_id}.png"
-         plt.savefig(temp_png, dpi=150, bbox_inches='tight')
-         plt.close()
-
-         # Convert PNG to JPEG if needed
-         img = Image.open(temp_png).convert("RGB")
-         temp_jpeg = f"temp_images/preview_{session_id}.jpg"
-         img.save(temp_jpeg, format="JPEG", quality=95)
-
-         # Update metadata
-         chat_metadata["image_info"] = {
-             "filename": os.path.basename(image_path),
-             "shape": npy_data.shape,
-             "processed_time": datetime.now().strftime("%Y-%m-%d %H:%M:%S")
-         }
-
-         return temp_jpeg, "Image processed successfully!"
-
-     except Exception as e:
-         return None, f"Error: {str(e)}"
-
- def process_image_question(question):
-     """Process user question about uploaded medical image."""
-     if current_image_path is None:
-         return "Please upload a medical image first."
-     try:
-         # Create fake image patch tokens
-         image_tokens = "<im_patch>" * proj_out_num
-         input_prompt = image_tokens + question
-
-         # Tokenize input
-         input_ids = tokenizer(input_prompt, return_tensors="pt").input_ids.to(device)
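-         # (Only the placeholder <im_patch> text reaches the model below; the
-         # image volume itself is never passed to generate().)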
-
-         # Generate answer
-         output = model.generate(
-             input_ids=input_ids,
-             max_new_tokens=256,
-             do_sample=True,
-             top_p=0.9,
-             temperature=0.7
-         )
-
-         answer = tokenizer.decode(output[0], skip_special_tokens=True)
-         if image_tokens in answer:
-             answer = answer.split(image_tokens)[-1]
-
-         return answer.strip()
-     except Exception as e:
-         return f"Error answering question: {str(e)}"

  def chat_interface(question):
-     """Handles chat conversation."""
      global chat_history
-     if not question.strip():
-         return chat_history
-
-     response = process_image_question(question)
      chat_history.append((question, response))
-     save_chat_history()
      return chat_history

- def upload_image(image):
-     """Handles image upload."""
-     global current_image_path
-     if image is None:
-         return "No file uploaded.", None
-
-     if not image.name.lower().endswith('.npy'):
-         return "Please upload a .npy file only.", None
-
-     current_image_path = image.name
-     extracted_image_path, status_message = extract_and_display_images(current_image_path)
-     if extracted_image_path is None:
-         return status_message, None
-
-     return status_message, extracted_image_path
-
- def clear_conversation():
-     """Clears chat conversation."""
-     global chat_history
-     old_chat = chat_history.copy()
-     chat_history = []
-     return [], f"Conversation cleared. Saved to {save_chat_history()}."
-
- # Custom CSS
- custom_css = """
- .gradio-container {max-width: 1200px !important}
- #chat-history {height: 400px; overflow-y: auto;}
- """

- # Build Gradio UI
- with gr.Blocks(theme=gr.themes.Soft(), css=custom_css) as demo:
      with gr.Row():
-         with gr.Column(scale=3):
-             gr.Markdown("# 🏥 ICliniq AI - Medical Image Analyzer")
-             gr.Markdown("""
-             Upload a **.npy** medical scan file, view extracted slices, and ask clinical questions.
-             """)
-             uploaded_image = gr.File(label="Upload Medical Image (.npy)", file_types=[".npy"], type="filepath")
-             upload_status = gr.Textbox(label="Upload Status", interactive=False)
-             extracted_image = gr.Image(label="Preview of Medical Image", elem_id="image-preview")
-
          with gr.Column(scale=4):
-             chat_list = gr.Chatbot(value=[], label="Conversation", elem_id="chat-history", height=500)
-             question_input = gr.Textbox(label="Ask your question", placeholder="e.g., Are there fractures visible?")
-             with gr.Row():
-                 submit_button = gr.Button("Send Question", variant="primary")
-                 clear_button = gr.Button("Clear Conversation", variant="secondary")
-             system_status = gr.Textbox(
-                 value=f"Model loaded: {model_name_or_path}\nDevice: {device}\nSession ID: {session_id}",
-                 interactive=False
-             )
-
-     uploaded_image.upload(upload_image, inputs=[uploaded_image], outputs=[upload_status, extracted_image])
-     submit_button.click(chat_interface, inputs=[question_input], outputs=[chat_list]).then(lambda: "", outputs=question_input)
-     question_input.submit(chat_interface, inputs=[question_input], outputs=[chat_list]).then(lambda: "", outputs=question_input)
-     clear_button.click(clear_conversation, inputs=[], outputs=[chat_list, system_status])

- # Run
- if __name__ == "__main__":
-     print("Launching Medical Image Analyzer...")
-     demo.launch(share=True)
 
  from transformers import AutoTokenizer, AutoModelForCausalLM
  import gradio as gr
  import matplotlib.pyplot as plt

  # Model setup
+ device = torch.device('cpu')  # Use 'cuda' if GPU is available
+ dtype = torch.float32
  model_name_or_path = 'GoodBaiBai88/M3D-LaMed-Phi-3-4B'
  proj_out_num = 256

  # Load model and tokenizer
  model = AutoModelForCausalLM.from_pretrained(
      model_name_or_path,
+     torch_dtype=torch.float32,
+     device_map='cpu',
      trust_remote_code=True
  )

  tokenizer = AutoTokenizer.from_pretrained(
      model_name_or_path,
+     model_max_length=512,
      padding_side="right",
      use_fast=False,
      trust_remote_code=True
  )

+ # Chat history storage
  chat_history = []
+ current_image = None
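+ # Module-level globals shared by the Gradio callbacks below; adequate for a
+ # single-user demo, though concurrent sessions would overwrite each other's state.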

  def extract_and_display_images(image_path):
+     npy_data = np.load(image_path)
+     if npy_data.ndim == 4 and npy_data.shape[1] == 32:
+         npy_data = npy_data[0]
+     elif npy_data.ndim != 3 or npy_data.shape[0] != 32:
+         return "Invalid .npy file format. Expected shape (1, 32, 256, 256) or (32, 256, 256)."
+
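+     # Tile all 32 slices on a 4x8 grid so the whole volume can be checked at a glance.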
+     fig, axes = plt.subplots(4, 8, figsize=(12, 6))
+     for i, ax in enumerate(axes.flat):
+         ax.imshow(npy_data[i], cmap='gray')
+         ax.axis('off')
+
+     image_output = "extracted_images.png"
+     plt.savefig(image_output, bbox_inches='tight')
+     plt.close()
+     return image_output

+ def process_image(question):
+     global current_image
+     if current_image is None:
+         return "Please upload an image first."
+
+     image_np = np.load(current_image)
+     image_tokens = "<im_patch>" * proj_out_num
+     input_txt = image_tokens + question
+     input_id = tokenizer(input_txt, return_tensors="pt")['input_ids'].to(device=device)
+
+     image_pt = torch.from_numpy(image_np).unsqueeze(0).to(dtype=dtype, device=device)
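+     # M3D-LaMed's remote code defines generate() to take the image tensor as its
+     # first argument, ahead of input_ids (as in the upstream M3D usage example);
+     # the stock transformers generate() has no such parameter.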
+     generation = model.generate(image_pt, input_id, max_new_tokens=256, do_sample=True, top_p=0.9, temperature=1.0)
+     generated_texts = tokenizer.batch_decode(generation, skip_special_tokens=True)
+     return generated_texts[0]

  def chat_interface(question):
      global chat_history
+     response = process_image(question)
      chat_history.append((question, response))
      return chat_history

+ def upload_image(image):
+     global current_image
+     current_image = image.name
+     extracted_image_path = extract_and_display_images(current_image)
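+     # Note that on a shape mismatch extract_and_display_images() returns its error
+     # string instead of a PNG path, and that string is forwarded to the gr.Image output.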
+     return "Image uploaded and processed successfully!", extracted_image_path
+
+ # Gradio UI
+ with gr.Blocks(theme=gr.themes.Soft()) as chat_ui:
+     gr.Markdown("ICliniq AI-Powered Medical Image Analysis Workspace")
      with gr.Row():
+         with gr.Column(scale=1, min_width=200):
+             chat_list = gr.Chatbot(value=[], label="Chat History", elem_id="chat-history")
          with gr.Column(scale=4):
+             uploaded_image = gr.File(label="Upload .npy Image", type="filepath")
+             upload_status = gr.Textbox(label="Status", interactive=False)
+             extracted_image = gr.Image(label="Extracted Images")
+             question_input = gr.Textbox(label="Ask a question", placeholder="Ask something about the image...")
+             submit_button = gr.Button("Send")
+
+     uploaded_image.upload(upload_image, uploaded_image, [upload_status, extracted_image])
+     submit_button.click(chat_interface, question_input, chat_list)
+     question_input.submit(chat_interface, question_input, chat_list)

+ chat_ui.launch()
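
The rewritten app expects each upload to be a volume of shape (32, 256, 256), or (1, 32, 256, 256) with a leading batch axis. A minimal sketch, not part of the commit, of producing such a file from an arbitrary 3-D scan; the helper name to_m3d_npy, the synthetic input, and the scipy.ndimage.zoom resampling are illustrative assumptions:

import numpy as np
from scipy.ndimage import zoom  # assumption: scipy is available for resampling

def to_m3d_npy(scan, out_path="volume.npy"):
    """Resample a 3-D scan to (1, 32, 256, 256), min-max normalize, save as .npy."""
    target = (32, 256, 256)
    factors = [t / s for t, s in zip(target, scan.shape)]
    vol = zoom(scan.astype(np.float32), factors, order=1)      # linear interpolation
    vol = (vol - vol.min()) / (vol.max() - vol.min() + 1e-8)   # scale to [0, 1]
    np.save(out_path, vol[None])                               # add batch axis
    return out_path

# e.g. a synthetic 40x512x512 volume, just to produce a file the app accepts
to_m3d_npy(np.random.rand(40, 512, 512))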