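"""Gradio chat UI for the M3D-LaMed-Phi-3-4B medical vision-language model.

Users upload a CT volume as a .npy file (shape (1, 32, 256, 256) or
(32, 256, 256)), get a slice-grid preview, and ask free-text questions
about the volume. Everything runs on CPU in float32.
"""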
import numpy as np
import torch
from transformers import AutoTokenizer, AutoModelForCausalLM
import gradio as gr
import matplotlib
matplotlib.use("Agg")  # headless backend: figures are saved to disk, never shown
import matplotlib.pyplot as plt

device = torch.device('cpu')
model_name_or_path = 'GoodBaiBai88/M3D-LaMed-Phi-3-4B'
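# trust_remote_code pulls in M3D-LaMed's custom multimodal model class from
# the Hub; a stock AutoModelForCausalLM cannot load this checkpoint otherwise.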

model = AutoModelForCausalLM.from_pretrained(
    model_name_or_path,
    torch_dtype=torch.float32,
    device_map='cpu',
    trust_remote_code=True
)
tokenizer = AutoTokenizer.from_pretrained(
    model_name_or_path,
    model_max_length=512,
    padding_side="right",
    use_fast=False,
    trust_remote_code=True
)

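# Module-level state shared across Gradio callbacks (fine for a single-user demo).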
chat_history = []
current_image = None

def extract_and_display_images(image_path):
    """Render the 32 slices of a volume as a 4x8 grayscale grid and return
    the path of the saved PNG, or None if the array has the wrong shape."""
    npy_data = np.load(image_path)
    if npy_data.ndim == 4 and npy_data.shape[1] == 32:
        npy_data = npy_data[0]  # drop the leading channel dim -> (32, 256, 256)
    elif npy_data.ndim != 3 or npy_data.shape[0] != 32:
        # An error string returned here would be handed to gr.Image as a file
        # path; signal failure with None and let the caller report it.
        return None
    
    fig, axes = plt.subplots(4, 8, figsize=(12, 6))
    for i, ax in enumerate(axes.flat):
        ax.imshow(npy_data[i], cmap='gray')
        ax.axis('off')
    
    output_path = "converted_image_preview.png"
    plt.savefig(output_path, bbox_inches='tight')
    plt.close()
    return output_path

def upload_image(image):
    global current_image
    if image is None:
        return "", None
    # gr.File(type="filepath") passes the callback a plain path string,
    # not a file object, so there is no .name attribute to read.
    current_image = image
    preview_path = extract_and_display_images(current_image)
    if preview_path is None:
        current_image = None
        return "Invalid .npy format. Expected (1, 32, 256, 256) or (32, 256, 256).", None
    return "Image uploaded successfully!", preview_path

def process_question(question):
    global current_image
    if current_image is None:
        return "Please upload an image first."
    
    image_np = np.load(current_image)
    if image_np.ndim == 3:
        image_np = image_np[np.newaxis]  # add channel dim -> (1, 32, 256, 256)
    
    # M3D-LaMed expects 256 <im_patch> placeholder tokens prepended to the
    # prompt; the model replaces them with projected vision features.
    image_tokens = "<im_patch>" * 256
    input_txt = image_tokens + question
    input_ids = tokenizer(input_txt, return_tensors="pt")['input_ids'].to(device=device)
    
    # Add a batch dim -> (1, 1, 32, 256, 256) to match the model's expected input.
    image_pt = torch.from_numpy(image_np).unsqueeze(0).to(dtype=torch.float32, device=device)
    generation = model.generate(image_pt, input_ids, max_new_tokens=256,
                                do_sample=True, top_p=0.9, temperature=1.0)
    generated_texts = tokenizer.batch_decode(generation, skip_special_tokens=True)
    return generated_texts[0]

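# Chat turns are stored as (user, assistant) tuples, the format gr.Chatbot renders.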
def chat_with_model(user_message):
    global chat_history
    if not user_message.strip():
        return chat_history
    response = process_question(user_message)
    chat_history.append((user_message, response))
    return chat_history

# Function to export chat history to a text file
def export_chat_history():
    history_text = ""
    for user_msg, model_reply in chat_history:
        history_text += f"User: {user_msg}\nAI: {model_reply}\n\n"
    with open("chat_history.txt", "w") as f:
        f.write(history_text)
    return "Chat history exported as chat_history.txt"

# UI
with gr.Blocks(css="""
body {
    background: #f5f5f5;
    font-family: 'Inter', sans-serif;
    color: #333333;
}

h1 {
    text-align: center;
    font-size: 2em;
    margin-bottom: 20px;
    color: #222;
}

.gr-box {
    background: #ffffff;
    padding: 20px;
    border-radius: 10px;
    box-shadow: 0px 4px 10px rgba(0, 0, 0, 0.1);
}

.gr-chatbot-container {
    overflow-y: auto;
    max-height: 500px;
    scroll-behavior: smooth;
}

.gr-chatbot-message {
    margin-bottom: 10px;
    padding: 8px;
    border-radius: 8px;
    background: #f5f5f5;
    animation: fadeIn 0.5s ease-out;
}

.gr-button {
    background-color: #4CAF50;
    color: white;
    border: none;
    padding: 8px 16px;
    border-radius: 6px;
    cursor: pointer;
}

.gr-button:hover {
    background-color: #45a049;
}

/* Matches the elem_id of the "+" upload button below. */
#upload_btn {
    background-color: #4CAF50;
    color: white;
    border-radius: 50%;
    width: 50px;
    height: 50px;
    font-size: 24px;
    display: flex;
    align-items: center;
    justify-content: center;
    cursor: pointer;
    border: none;
    margin-top: 10px;
}

#loading-spinner {
    display: none;
    text-align: center;
}

#loading-spinner img {
    width: 50px;
    height: 50px;
}

@keyframes fadeIn {
    0% { opacity: 0; }
    100% { opacity: 1; }
}
""") as app:
    gr.Markdown("# AI-Powered Medical Image Analysis System")

    with gr.Row():
        with gr.Column(scale=1):
            chatbot_ui = gr.Chatbot(value=[], label="Chat History")
        with gr.Column(scale=2):
            # Create the "+" button for uploading
            upload_button = gr.Button("+", elem_id="upload_btn", visible=True, interactive=True)
            upload_section = gr.File(label="Upload NPY Image", type="filepath", visible=False)
            upload_status = gr.Textbox(label="Status", interactive=False)
            preview_img = gr.Image(label="Image Preview", interactive=False)
            message_input = gr.Textbox(placeholder="Type your question here...", label="Your Message")
            send_button = gr.Button("Send")
            export_button = gr.Button("Export Chat History")
            loading_spinner = gr.HTML('<div id="loading-spinner"><img src="https://i.imgur.com/llf5Jjs.gif" alt="Loading..."></div>')

    # Reveal the file picker when the "+" button is clicked. Handlers return
    # gr.update(...) for their output components; calling .update() on a
    # component instance is deprecated and its return value is discarded.
    upload_button.click(lambda: gr.update(visible=True), None, upload_section)

    # Spinner markup with the display style overridden so it becomes visible.
    spinner_html = '<div id="loading-spinner" style="display:block"><img src="https://i.imgur.com/llf5Jjs.gif" alt="Loading..."></div>'

    # Show the spinner, run the upload, then hide the spinner again. Chaining
    # with .then() guarantees this ordering; independent event registrations
    # on the same component do not.
    upload_section.upload(lambda: spinner_html, None, loading_spinner).then(
        upload_image, upload_section, [upload_status, preview_img]).then(
        lambda: "", None, loading_spinner)

    # Same show/run/hide chain while the model answers; afterwards clear the
    # input box (gr.Textbox has no focus property to update).
    send_button.click(lambda: spinner_html, None, loading_spinner).then(
        chat_with_model, message_input, chatbot_ui).then(
        lambda: ("", gr.update(value="")), None, [loading_spinner, message_input])
    message_input.submit(chat_with_model, message_input, chatbot_ui).then(
        lambda: gr.update(value=""), None, message_input)

    # Export chat history and surface the confirmation in the status box.
    export_button.click(export_chat_history, None, upload_status)

app.launch()
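
# Quick smoke test without the UI (hypothetical path; assumes example.npy is a
# (1, 32, 256, 256) CT volume preprocessed the way M3D-LaMed expects):
#   current_image = "example.npy"
#   print(process_question("What organ is shown in this CT volume?"))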