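"""Gradio app: chat with GoodBaiBai88/M3D-LaMed-Phi-3-4B about 3D medical
image volumes (32 x 256 x 256 .npy files), with a slice-montage preview,
chat history, and a plain-text export. CPU-only, float32 inference."""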
import numpy as np
import torch
from transformers import AutoTokenizer, AutoModelForCausalLM
import gradio as gr
import matplotlib.pyplot as plt
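
# Load the model and tokenizer once at startup. trust_remote_code=True is
# needed because M3D-LaMed ships custom modeling code; float32 on CPU keeps
# the demo runnable without a GPU, at the cost of slow generation.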
device = torch.device('cpu')

model_name_or_path = 'GoodBaiBai88/M3D-LaMed-Phi-3-4B'
model = AutoModelForCausalLM.from_pretrained(
    model_name_or_path,
    torch_dtype=torch.float32,
    device_map='cpu',
    trust_remote_code=True
)
tokenizer = AutoTokenizer.from_pretrained(
    model_name_or_path,
    model_max_length=512,
    padding_side="right",
    use_fast=False,
    trust_remote_code=True
)

# Conversation state shared across the Gradio callbacks below.
chat_history = []
current_image = None
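
# Preview helper: render the 32 slices of an uploaded volume as a 4x8
# grayscale montage so the user can sanity-check the file before asking
# questions.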
def extract_and_display_images(image_path):
    """Save a 4x8 montage PNG of the volume's 32 slices; return its path."""
    npy_data = np.load(image_path)
    if npy_data.ndim == 4 and npy_data.shape[1] == 32:
        npy_data = npy_data[0]
    elif npy_data.ndim != 3 or npy_data.shape[0] != 32:
        # Unsupported layout; the caller reports the error to the user.
        return None

    fig, axes = plt.subplots(4, 8, figsize=(12, 6))
    for i, ax in enumerate(axes.flat):
        ax.imshow(npy_data[i], cmap='gray')
        ax.axis('off')

    output_path = "converted_image_preview.png"
    plt.savefig(output_path, bbox_inches='tight')
    plt.close(fig)
    return output_path

def upload_image(image):
    global current_image
    if image is None:
        return "", None
    # gr.File(type="filepath") passes a path string; older Gradio versions
    # pass a tempfile object with a .name attribute.
    current_image = image if isinstance(image, str) else image.name
    preview_path = extract_and_display_images(current_image)
    if preview_path is None:
        return "Invalid .npy format. Expected (1, 32, 256, 256) or (32, 256, 256).", None
    return "Image uploaded successfully!", preview_path
def process_question(question):
    global current_image
    if current_image is None:
        return "Please upload an image first."

    image_np = np.load(current_image)
    if image_np.ndim == 3:
        # Add the channel dimension expected by the vision encoder.
        image_np = image_np[np.newaxis, ...]

    image_tokens = "<im_patch>" * 256
    input_txt = image_tokens + question
    input_ids = tokenizer(input_txt, return_tensors="pt")['input_ids'].to(device=device)

    # Add the batch dimension -> (1, 1, 32, 256, 256).
    image_pt = torch.from_numpy(image_np).unsqueeze(0).to(dtype=torch.float32, device=device)

    generation = model.generate(image_pt, input_ids, max_new_tokens=256,
                                do_sample=True, top_p=0.9, temperature=1.0)
    generated_texts = tokenizer.batch_decode(generation, skip_special_tokens=True)
    return generated_texts[0]
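
# Chat callback: appends (user_message, bot_response) pairs in the tuple
# format that gr.Chatbot accepts.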
def chat_with_model(user_message):
    global chat_history
    if not user_message.strip():
        return chat_history
    response = process_question(user_message)
    chat_history.append((user_message, response))
    return chat_history

# Function to export chat history to a text file
def export_chat_history():
    history_text = ""
    for user_msg, model_reply in chat_history:
        history_text += f"User: {user_msg}\nAI: {model_reply}\n\n"
    with open("chat_history.txt", "w") as f:
        f.write(history_text)
    return "Chat history exported as chat_history.txt"

# UI
with gr.Blocks(css="""
body {
    background: #f5f5f5;
    font-family: 'Inter', sans-serif;
    color: #333333;
}
h1 {
    text-align: center;
    font-size: 2em;
    margin-bottom: 20px;
    color: #222;
}
.gr-box {
    background: #ffffff;
    padding: 20px;
    border-radius: 10px;
    box-shadow: 0px 4px 10px rgba(0, 0, 0, 0.1);
}
.gr-chatbot-container {
    overflow-y: auto;
    max-height: 500px;
    scroll-behavior: smooth;
}
.gr-chatbot-message {
    margin-bottom: 10px;
    padding: 8px;
    border-radius: 8px;
    background: #f5f5f5;
    animation: fadeIn 0.5s ease-out;
}
.gr-button {
    background-color: #4CAF50;
    color: white;
    border: none;
    padding: 8px 16px;
    border-radius: 6px;
    cursor: pointer;
}
.gr-button:hover {
    background-color: #45a049;
}
/* The round "+" upload button (matched by its elem_id, not a class). */
#upload_btn {
    background-color: #4CAF50;
    color: white;
    border-radius: 50%;
    width: 50px;
    height: 50px;
    font-size: 24px;
    display: flex;
    align-items: center;
    justify-content: center;
    cursor: pointer;
    border: none;
    margin-top: 10px;
}
#loading-spinner {
    display: none;
    text-align: center;
}
#loading-spinner img {
    width: 50px;
    height: 50px;
}
@keyframes fadeIn {
    0% { opacity: 0; }
    100% { opacity: 1; }
}
""") as app:
    gr.Markdown("# AI Powered Medical Image Analysis System")

    with gr.Row():
        with gr.Column(scale=1):
            chatbot_ui = gr.Chatbot(value=[], label="Chat History")
        with gr.Column(scale=2):
            # Create the "+" button for uploading
            upload_button = gr.Button("+", elem_id="upload_btn")
            upload_section = gr.File(label="Upload NPY Image", type="filepath", visible=False)
            upload_status = gr.Textbox(label="Status", interactive=False)
            preview_img = gr.Image(label="Image Preview", interactive=False)
            message_input = gr.Textbox(placeholder="Type your question here...", label="Your Message")
            send_button = gr.Button("Send")
            export_button = gr.Button("Export Chat History")
            loading_spinner = gr.HTML('<div id="loading-spinner"><img src="https://i.imgur.com/llf5Jjs.gif" alt="Loading..."></div>')

    # Handle file upload when the "+" button is clicked: reveal the uploader.
    # (Component.update() was removed in Gradio 4; handlers return gr.update().)
    upload_button.click(lambda: gr.update(visible=True), None, upload_section)

    # Show the loading spinner while an upload is processed, then render the preview.
    spinner_on = '<div id="loading-spinner" style="display:block"><img src="https://i.imgur.com/llf5Jjs.gif" alt="Loading..."></div>'
    upload_section.upload(lambda _: spinner_on, upload_section, loading_spinner)
    upload_section.upload(upload_image, upload_section, [upload_status, preview_img])
    upload_section.upload(lambda _: "", upload_section, loading_spinner)

    # Show the spinner while a question is processed, then clear it.
    send_button.click(lambda: spinner_on, None, loading_spinner)
    send_button.click(chat_with_model, message_input, chatbot_ui)
    send_button.click(lambda: "", None, loading_spinner)
    message_input.submit(chat_with_model, message_input, chatbot_ui)

    # Export chat history; surface the confirmation in the status box.
    export_button.click(export_chat_history, None, upload_status)

    # Clear the typing box after a message is sent (Gradio offers no
    # server-side way to re-focus a textbox).
    message_input.submit(lambda: "", None, message_input)
    send_button.click(lambda: "", None, message_input)

app.launch()
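
# To run locally (dependencies inferred from the imports above, not pinned):
#   pip install torch transformers gradio matplotlib numpy
#   python app.py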