# ICLINIQ / app.py
# Last commit 982538d ("Update app.py", verified) by Rohith1112; file size 6.19 kB.
# (Header text from the hosting page, kept as comments so the module parses.)
import os
import numpy as np
import torch
from transformers import AutoTokenizer, AutoModelForCausalLM
import gradio as gr
import matplotlib.pyplot as plt
# Everything runs on the CPU in full float32 precision.
device = torch.device("cpu")
model_name_or_path = "GoodBaiBai88/M3D-LaMed-Phi-3-4B"

# trust_remote_code=True lets transformers execute the Python code that lives
# in the model repository itself; it is required for both loads below.
model = AutoModelForCausalLM.from_pretrained(
    model_name_or_path,
    trust_remote_code=True,
    device_map="cpu",
    torch_dtype=torch.float32,
)
tokenizer = AutoTokenizer.from_pretrained(
    model_name_or_path,
    trust_remote_code=True,
    use_fast=False,
    padding_side="right",
    model_max_length=512,
)

# Module-level UI state: the (user, reply) transcript and the path of the
# most recently accepted .npy upload (None until one is accepted).
# NOTE(review): being module globals, these are shared by every browser
# session served by this process.
chat_history = []
current_image = None
def extract_and_display_images(image_path, output_path="converted_image_preview.png"):
    """Render the 32 slices of a ``.npy`` volume as a 4x8 grid saved to a PNG.

    Args:
        image_path: Path to a ``.npy`` file. Accepted layouts are a 3-D array
            whose first axis has length 32, or a 4-D array whose second axis
            has length 32 (only its first entry along axis 0 is shown).
        output_path: Where to write the preview PNG. Defaults to the fixed
            filename used previously, so existing callers are unaffected.

    Returns:
        ``output_path`` on success, or a human-readable error string when the
        array is not in one of the accepted layouts.

    Raises:
        Whatever ``np.load`` raises for missing or unreadable files.
    """
    npy_data = np.load(image_path)
    if npy_data.ndim == 4 and npy_data.shape[1] == 32:
        # Keep only the first entry of the leading axis -> (32, H, W).
        npy_data = npy_data[0]
    elif npy_data.ndim != 3 or npy_data.shape[0] != 32:
        return "Invalid .npy format. Expected (1, 32, 256, 256) or (32, 256, 256)."

    # Work on the explicit Figure instead of pyplot's implicit "current
    # figure", and always close it so a failed render cannot leak figures.
    fig, axes = plt.subplots(4, 8, figsize=(12, 6))
    try:
        # 4 x 8 = 32 axes, one per slice along the first array axis.
        for ax, slice_2d in zip(axes.flat, npy_data):
            ax.imshow(slice_2d, cmap='gray')
            ax.axis('off')
        fig.savefig(output_path, bbox_inches='tight')
    finally:
        plt.close(fig)
    return output_path
def upload_image(image):
    """Handle a new upload: render a preview and, if valid, make it current.

    Args:
        image: Value from the ``gr.File`` upload component. With
            ``type="filepath"`` this is a path string; objects exposing a
            ``.name`` attribute (temporary-file wrappers) are accepted too.

    Returns:
        ``(status_message, preview_path)``. ``preview_path`` is ``None`` when
        nothing was uploaded or the file could not be read or rendered.
    """
    global current_image
    if image is None:
        return "", None
    # A filepath-type File component passes a plain ``str``, which has no
    # ``.name``; fall back to the value itself in that case.
    image_path = getattr(image, "name", image)
    try:
        preview_path = extract_and_display_images(image_path)
    except (OSError, ValueError) as err:
        # np.load raises OSError for unreadable files and ValueError for
        # files that are not plain .npy arrays (allow_pickle=False).
        return f"Could not read .npy file: {err}", None
    if not os.path.isfile(preview_path):
        # The preview helper returns an error message instead of a path when
        # it rejects the array layout; don't hand that to gr.Image.
        return preview_path, None
    # Adopt the file only after validation, so questions are never asked
    # against a rejected volume.
    current_image = image_path
    return "Image uploaded successfully!", preview_path
def process_question(question):
    """Ask the model a question about the currently uploaded volume.

    Args:
        question: Free-text question typed by the user.

    Returns:
        The first decoded generation, or a prompt asking the user to upload
        an image when none has been accepted yet.
    """
    if current_image is None:
        return "Please upload an image first."
    image_np = np.load(current_image)
    # The preview accepts both (32, H, W) and (1, 32, H, W) arrays. Promote
    # the 3-D form so the batch axis added below yields (1, 1, 32, H, W) for
    # either layout, instead of a 4-D tensor for 3-D uploads.
    if image_np.ndim == 3:
        image_np = image_np[np.newaxis, ...]
    # 256 <im_patch> placeholders are prepended to the prompt; presumably this
    # must equal the number of image tokens the model projects — confirm
    # against the checkpoint's config before changing it.
    image_tokens = "<im_patch>" * 256
    input_txt = image_tokens + question
    input_ids = tokenizer(input_txt, return_tensors="pt")['input_ids'].to(device=device)
    image_pt = torch.from_numpy(image_np).unsqueeze(0).to(dtype=torch.float32, device=device)
    generation = model.generate(
        image_pt,
        input_ids,
        max_new_tokens=256,
        do_sample=True,
        top_p=0.9,
        temperature=1.0,
    )
    generated_texts = tokenizer.batch_decode(generation, skip_special_tokens=True)
    return generated_texts[0]
def chat_with_model(user_message):
    """Forward one user message to the model and record the exchange.

    Messages that are empty or only whitespace are ignored. The module-level
    ``chat_history`` list of ``(user_message, reply)`` pairs is returned in
    every case so the chatbot component can redraw it.

    NOTE(review): ``chat_history`` is a module global, so all browser
    sessions append to the same transcript.
    """
    if user_message.strip():
        reply = process_question(user_message)
        chat_history.append((user_message, reply))
    return chat_history
def export_chat_history(history=None, path="chat_history.txt"):
    """Write the chat transcript to a UTF-8 text file.

    Each exchange is written as a ``User: ...`` line and an ``AI: ...`` line,
    followed by a blank line.

    Args:
        history: Sequence of ``(user_message, model_reply)`` pairs. Defaults to
            the module-level ``chat_history`` that the UI appends to.
        path: Destination file. Defaults to ``chat_history.txt`` in the
            current working directory.

    Returns:
        A status message naming the file that was written.
    """
    if history is None:
        history = chat_history
    history_text = "".join(
        f"User: {user_msg}\nAI: {model_reply}\n\n" for user_msg, model_reply in history
    )
    # Explicit encoding: replies may contain non-ASCII text that the platform
    # default encoding (e.g. cp1252 on Windows) cannot represent.
    with open(path, "w", encoding="utf-8") as f:
        f.write(history_text)
    return f"Chat history exported as {path}"
# UI
# Markup for the loading indicator; its visibility is toggled through the
# gr.HTML component rather than through CSS.
_SPINNER_HTML = (
    '<div id="loading-spinner">'
    '<img src="https://i.imgur.com/llf5Jjs.gif" alt="Loading...">'
    '</div>'
)

# NOTE(review): the .gr-* class selectors look like Gradio 3.x class names;
# confirm they still match anything in the installed Gradio version.
_APP_CSS = """
body {
    background: #f5f5f5;
    font-family: 'Inter', sans-serif;
    color: #333333;
}
h1 {
    text-align: center;
    font-size: 2em;
    margin-bottom: 20px;
    color: #222;
}
.gr-box {
    background: #ffffff;
    padding: 20px;
    border-radius: 10px;
    box-shadow: 0px 4px 10px rgba(0, 0, 0, 0.1);
}
.gr-chatbot-container {
    overflow-y: auto;
    max-height: 500px;
    scroll-behavior: smooth;
}
.gr-chatbot-message {
    margin-bottom: 10px;
    padding: 8px;
    border-radius: 8px;
    background: #f5f5f5;
    animation: fadeIn 0.5s ease-out;
}
.gr-button {
    background-color: #4CAF50;
    color: white;
    border: none;
    padding: 8px 16px;
    border-radius: 6px;
    cursor: pointer;
}
.gr-button:hover {
    background-color: #45a049;
}
#upload_btn {
    background-color: #4CAF50;
    color: white;
    border-radius: 50%;
    width: 50px;
    height: 50px;
    font-size: 24px;
    display: flex;
    align-items: center;
    justify-content: center;
    cursor: pointer;
    border: none;
    margin-top: 10px;
}
#loading-spinner {
    text-align: center;
}
#loading-spinner img {
    width: 50px;
    height: 50px;
}
@keyframes fadeIn {
    0% { opacity: 0; }
    100% { opacity: 1; }
}
"""

# NOTE(review): chat_history / current_image are module globals, so every
# browser session shares one transcript and one image; per-session gr.State
# would isolate them.
with gr.Blocks(css=_APP_CSS) as app:
    gr.Markdown("# AI Powered Medical Image Analysis System")
    with gr.Row():
        with gr.Column(scale=1):
            chatbot_ui = gr.Chatbot(value=[], label="Chat History")
        with gr.Column(scale=2):
            # "+" button that reveals the initially hidden upload widget.
            upload_button = gr.Button("+", elem_id="upload_btn", visible=True, interactive=True)
            upload_section = gr.File(label="Upload NPY Image", type="filepath", visible=False)
            upload_status = gr.Textbox(label="Status", interactive=False)
            preview_img = gr.Image(label="Image Preview", interactive=False)
            message_input = gr.Textbox(placeholder="Type your question here...", label="Your Message")
            send_button = gr.Button("Send")
            export_button = gr.Button("Export Chat History")
            # Hidden until an upload or a question is being processed.
            loading_spinner = gr.HTML(_SPINNER_HTML, visible=False)

    # gr.update(...) is used instead of component.update(...), which no
    # longer exists in Gradio 4+ (implied here by gr.File(type="filepath")).
    upload_button.click(lambda: gr.update(visible=True), None, upload_section)

    # Steps chained with .then() run one after another, and a .then() step
    # runs even if the previous one raised, so the spinner is always hidden
    # once the work has finished.
    upload_section.upload(
        lambda: gr.update(visible=True), None, loading_spinner
    ).then(
        upload_image, upload_section, [upload_status, preview_img]
    ).then(
        lambda: gr.update(visible=False), None, loading_spinner
    )

    # The Send button and pressing Enter share one pipeline: show the
    # spinner, run the model, then hide the spinner and clear the input box.
    for trigger in (send_button.click, message_input.submit):
        trigger(
            lambda: gr.update(visible=True), None, loading_spinner
        ).then(
            chat_with_model, message_input, chatbot_ui
        ).then(
            lambda: (gr.update(visible=False), ""), None, [loading_spinner, message_input]
        )

    # Show the export confirmation in the status box instead of dropping it.
    export_button.click(export_chat_history, None, upload_status)

if __name__ == "__main__":
    app.launch()