Rohith1112 committed on
Commit 7fd41e1 · verified · 1 Parent(s): 58ce281

Update app.py

Files changed (1):
  app.py (+89 -210)
app.py CHANGED
@@ -1,222 +1,101 @@
- import os
- import numpy as np
- import torch
- from transformers import AutoTokenizer, AutoModelForCausalLM
- import gradio as gr
- import matplotlib.pyplot as plt
-
- device = torch.device('cpu')
- model_name_or_path = 'GoodBaiBai88/M3D-LaMed-Phi-3-4B'
-
- model = AutoModelForCausalLM.from_pretrained(
-     model_name_or_path,
-     torch_dtype=torch.float32,
-     device_map='cpu',
-     trust_remote_code=True
- )
- tokenizer = AutoTokenizer.from_pretrained(
-     model_name_or_path,
-     model_max_length=512,
-     padding_side="right",
-     use_fast=False,
-     trust_remote_code=True
- )
-
- chat_history = []
- current_image = None
-
- def extract_and_display_images(image_path):
-     npy_data = np.load(image_path)
-     if npy_data.ndim == 4 and npy_data.shape[1] == 32:
-         npy_data = npy_data[0]
-     elif npy_data.ndim != 3 or npy_data.shape[0] != 32:
-         return "Invalid .npy format. Expected (1, 32, 256, 256) or (32, 256, 256)."
-
-     fig, axes = plt.subplots(4, 8, figsize=(12, 6))
-     for i, ax in enumerate(axes.flat):
-         ax.imshow(npy_data[i], cmap='gray')
-         ax.axis('off')
-
-     output_path = "converted_image_preview.png"
-     plt.savefig(output_path, bbox_inches='tight')
-     plt.close()
-     return output_path
-
- def upload_image(image):
-     global current_image
-     if image is None:
-         return "", None
-     current_image = image.name
-     preview_path = extract_and_display_images(current_image)
-     return "Image uploaded successfully!", preview_path
-
- def process_question(question):
-     global current_image
-     if current_image is None:
-         return "Please upload an image first."
-
-     image_np = np.load(current_image)
-     image_tokens = "<im_patch>" * 256
-     input_txt = image_tokens + question
-     input_ids = tokenizer(input_txt, return_tensors="pt")['input_ids'].to(device=device)
-
-     image_pt = torch.from_numpy(image_np).unsqueeze(0).to(dtype=torch.float32, device=device)
-     generation = model.generate(image_pt, input_ids, max_new_tokens=256, do_sample=True, top_p=0.9, temperature=1.0)
-     generated_texts = tokenizer.batch_decode(generation, skip_special_tokens=True)
-     return generated_texts[0]
-
- def chat_with_model(user_message):
-     global chat_history
-     if not user_message.strip():
-         return chat_history
-     response = process_question(user_message)
-     chat_history.append((user_message, response))
-     return chat_history
-
- # Function to export chat history to a text file
- def export_chat_history():
-     history_text = ""
-     for user_msg, model_reply in chat_history:
-         history_text += f"User: {user_msg}\nAI: {model_reply}\n\n"
-     with open("chat_history.txt", "w") as f:
-         f.write(history_text)
-     return "Chat history exported as chat_history.txt"
-
- # UI
- with gr.Blocks(css="""
- body {
-     background: #f5f5f5;
-     font-family: 'Inter', sans-serif;
-     color: #333333;
- }
-
- h1 {
-     text-align: center;
-     font-size: 2em;
-     margin-bottom: 20px;
-     color: #222;
- }
-
- .gr-box {
-     background: #ffffff;
-     padding: 20px;
-     border-radius: 10px;
-     box-shadow: 0px 4px 10px rgba(0, 0, 0, 0.1);
- }
-
- .gr-chatbot-container {
-     overflow-y: auto;
-     max-height: 500px;
-     scroll-behavior: smooth;
- }
-
- .gr-chatbot-message {
-     margin-bottom: 10px;
-     padding: 8px;
-     border-radius: 8px;
-     background: #f5f5f5;
-     animation: fadeIn 0.5s ease-out;
- }
-
- .gr-button {
-     background-color: #4CAF50;
-     color: white;
-     border: none;
-     padding: 8px 16px;
-     border-radius: 6px;
-     cursor: pointer;
- }
-
- .gr-button:hover {
-     background-color: #45a049;
- }
-
- #message-box {
-     display: flex;
-     align-items: center;
-     position: relative;
-     transition: all 0.3s ease;
- }
-
- #message-input {
-     width: 100%;
-     padding: 10px;
-     border-radius: 6px;
-     border: 1px solid #ddd;
-     font-size: 14px;
-     margin-right: 40px; /* To give space for the icon */
- }
-
- #upload-icon {
-     position: absolute;
-     right: 10px;
-     cursor: pointer;
-     font-size: 24px;
-     color: #4CAF50;
-     animation: bounce 0.6s infinite alternate;
- }
-
- #upload-icon:hover {
-     color: #45a049;
-     transform: scale(1.1);
- }
-
- #loading-spinner {
-     display: none;
-     text-align: center;
- }
-
- #loading-spinner img {
-     width: 50px;
-     height: 50px;
- }
-
- @keyframes bounce {
-     0% { transform: translateY(0); }
-     100% { transform: translateY(-10px); }
- }
-
- @keyframes fadeIn {
-     0% { opacity: 0; }
-     100% { opacity: 1; }
- }
- """) as app:
-     gr.Markdown("# AI Powered Medical Image Analysis System")
-
-     with gr.Row():
-         with gr.Column(scale=1):
-             chatbot_ui = gr.Chatbot(value=[], label="Chat History")
-         with gr.Column(scale=2):
-             # Message input area with the '+' icon inside it
-             with gr.Box(elem_id="message-box"):
-                 message_input = gr.Textbox(placeholder="Type your question here...", label="Your Message", elem_id="message-input")
-                 upload_button = gr.HTML('<span id="upload-icon">+</span>')  # The + icon inside the message box
-
-             upload_section = gr.File(label="Upload NPY Image", type="filepath", visible=False)
-             upload_status = gr.Textbox(label="Status", interactive=False)
-             preview_img = gr.Image(label="Image Preview", interactive=False)
-             send_button = gr.Button("Send")
-             export_button = gr.Button("Export Chat History")
-             loading_spinner = gr.HTML('<div id="loading-spinner"><img src="https://i.imgur.com/llf5Jjs.gif" alt="Loading..."></div>')
-
-     # Handle click event for the '+' icon
-     upload_button.click(lambda: upload_section.update(visible=True), None, upload_section)
-
-     # Handle file upload
-     upload_section.upload(lambda *args: loading_spinner.update("<div id='loading-spinner'><img src='https://i.imgur.com/llf5Jjs.gif' alt='Loading...'></div>"), upload_section, None)
-     upload_section.upload(upload_image, upload_section, [upload_status, preview_img])
-
-     # Display loading spinner while processing question
-     send_button.click(lambda *args: loading_spinner.update("<div id='loading-spinner'><img src='https://i.imgur.com/llf5Jjs.gif' alt='Loading...'></div>"), None, None)
-     send_button.click(chat_with_model, message_input, chatbot_ui)
-     send_button.click(lambda *args: loading_spinner.update(''), None, None)
-     message_input.submit(chat_with_model, message_input, chatbot_ui)
-
-     # Export chat history functionality
-     export_button.click(export_chat_history)
-
-     # Auto-focus typing box and scroll to bottom after message sent
-     message_input.submit(lambda: gr.update(focus=True), None, message_input)
-     send_button.click(lambda: gr.update(focus=True), None, message_input)
-
- app.launch()
+ import os
+ import numpy as np
+ import torch
+ from transformers import AutoTokenizer, AutoModelForCausalLM
+ import gradio as gr
+ import matplotlib.pyplot as plt
+
+ # Model setup
+ device = torch.device('cpu')  # Use 'cuda' if GPU is available
+ dtype = torch.float32
+ model_name_or_path = 'GoodBaiBai88/M3D-LaMed-Phi-3-4B'
+ proj_out_num = 256
+
+ # Load model and tokenizer
+ model = AutoModelForCausalLM.from_pretrained(
+     model_name_or_path,
+     torch_dtype=torch.float32,
+     device_map='cpu',
+     trust_remote_code=True
+ )
+
+ tokenizer = AutoTokenizer.from_pretrained(
+     model_name_or_path,
+     model_max_length=512,
+     padding_side="right",
+     use_fast=False,
+     trust_remote_code=True
+ )
+
+ # Chat history storage
+ chat_history = []
+ current_image = None
+
+ def extract_and_display_images(image_path):
+     npy_data = np.load(image_path)
+     if npy_data.ndim == 4 and npy_data.shape[1] == 32:
+         npy_data = npy_data[0]
+     elif npy_data.ndim != 3 or npy_data.shape[0] != 32:
+         return "Invalid .npy file format. Expected shape (1, 32, 256, 256) or (32, 256, 256)."
+
+     fig, axes = plt.subplots(4, 8, figsize=(12, 6))
+     for i, ax in enumerate(axes.flat):
+         ax.imshow(npy_data[i], cmap='gray')
+         ax.axis('off')
+
+     image_output = "extracted_images.png"
+     plt.savefig(image_output, bbox_inches='tight')
+     plt.close()
+     return image_output
+
+
+ def process_image(question):
+     global current_image
+     if current_image is None:
+         return "Please upload an image first."
+
+     image_np = np.load(current_image)
+     image_tokens = "<im_patch>" * proj_out_num
+     input_txt = image_tokens + question
+     input_id = tokenizer(input_txt, return_tensors="pt")['input_ids'].to(device=device)
+
+     image_pt = torch.from_numpy(image_np).unsqueeze(0).to(dtype=dtype, device=device)
+     generation = model.generate(image_pt, input_id, max_new_tokens=256, do_sample=True, top_p=0.9, temperature=1.0)
+     generated_texts = tokenizer.batch_decode(generation, skip_special_tokens=True)
+     return generated_texts[0]
+
+
+ def chat_interface(question):
+     global chat_history
+     response = process_image(question)
+     chat_history.append((question, response))
+     return chat_history
+
+
+ def upload_image(image):
+     global current_image
+     current_image = image.name
+     extracted_image_path = extract_and_display_images(current_image)
+     return "Image uploaded and processed successfully!", extracted_image_path
+
+ # Gradio UI
+ with gr.Blocks(theme=gr.themes.Soft()) as chat_ui:
+     gr.Markdown("## ICliniq AI-Powered Medical Image Analysis Workspace")
+
+     with gr.Row():
+         with gr.Column(scale=1):
+             chat_list = gr.Chatbot(value=[], label="Chat History", elem_id="chat-history")
+
+         with gr.Column(scale=2):
+             uploaded_image = gr.File(label="Upload .npy Image", type="filepath")
+             upload_status = gr.Textbox(label="Status", interactive=False)
+             extracted_image = gr.Image(label="Extracted Images", type="filepath")
+             question_input = gr.Textbox(label="Ask a question", placeholder="Ask something about the image...", lines=2)
+             submit_button = gr.Button("Send")
+
+     # Upload and Processing Interactions
+     uploaded_image.upload(upload_image, uploaded_image, [upload_status, extracted_image])
+     submit_button.click(chat_interface, question_input, chat_list)
+     question_input.submit(chat_interface, question_input, chat_list)
+
+ chat_ui.launch()
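
A quick way to smoke-test the updated app is to feed it a synthetic volume matching the shape the code validates for. A minimal sketch, assuming a random volume is acceptable for a UI test (the file name and random contents are illustrative, not part of the commit):

import numpy as np

# extract_and_display_images accepts (1, 32, 256, 256) or (32, 256, 256);
# process_image casts the array to float32 before inference.
volume = np.random.rand(1, 32, 256, 256).astype(np.float32)
np.save("test_volume.npy", volume)

Uploading test_volume.npy through the gr.File control should produce the 4x8 slice preview and enable question answering against the model.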