import os
import numpy as np
import torch
from transformers import AutoTokenizer, AutoModelForCausalLM
import gradio as gr
import matplotlib.pyplot as plt
from datasets import load_dataset
from evaluate import load # For evaluation metrics
# Model setup
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu') # Use GPU if available
dtype = torch.float32
model_name_or_path = 'GoodBaiBai88/M3D-LaMed-Phi-3-4B'
proj_out_num = 256  # Number of <im_patch> placeholder tokens prepended to each prompt (one per projected image patch)
# Load model and tokenizer
model = AutoModelForCausalLM.from_pretrained(
    model_name_or_path,
    torch_dtype=dtype,
    device_map=device.type,
    trust_remote_code=True
)
tokenizer = AutoTokenizer.from_pretrained(
    model_name_or_path,
    model_max_length=512,
    padding_side="right",
    use_fast=False,
    trust_remote_code=True
)
# Load the M3D-Cap dataset
dataset = load_dataset("GoodBaiBai88/M3D-Cap")
# Chat history storage
chat_history = []
current_image = None
def extract_and_display_images(image_path):
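    """Load a .npy volume of shape (1, 32, 256, 256) or (32, 256, 256) and save a 4x8 slice montage.

    Returns the path of the saved PNG, or an error message string.
    """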
    try:
        npy_data = np.load(image_path)
        if npy_data.ndim == 4 and npy_data.shape[1] == 32:
            npy_data = npy_data[0]
        elif npy_data.ndim != 3 or npy_data.shape[0] != 32:
            return "Invalid .npy file format. Expected shape (1, 32, 256, 256) or (32, 256, 256)."
        fig, axes = plt.subplots(4, 8, figsize=(12, 6))
        for i, ax in enumerate(axes.flat):
            ax.imshow(npy_data[i], cmap='gray')
            ax.axis('off')
        image_output = "extracted_images.png"
        plt.savefig(image_output, bbox_inches='tight')
        plt.close()
        return image_output
    except Exception as e:
        return f"Error processing image: {str(e)}"
def process_image(question):
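    """Prepend <im_patch> tokens to the question and generate an answer for the currently loaded volume."""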
    global current_image
    if current_image is None:
        return "Please upload an image first."
    try:
        image_np = np.load(current_image)
        image_tokens = "<im_patch>" * proj_out_num
        input_txt = image_tokens + question
        input_id = tokenizer(input_txt, return_tensors="pt")['input_ids'].to(device=device)
        image_pt = torch.from_numpy(image_np).unsqueeze(0).to(dtype=dtype, device=device)
        generation = model.generate(image_pt, input_id, max_new_tokens=256, do_sample=True, top_p=0.9, temperature=1.0)
        generated_texts = tokenizer.batch_decode(generation, skip_special_tokens=True)
        return generated_texts[0]
    except Exception as e:
        return f"Error generating response: {str(e)}"
def chat_interface(question):
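    """Gradio callback: answer `question` about the current image and append the turn to the chat history."""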
    global chat_history
    response = process_image(question)
    chat_history.append((question, response))
    return chat_history
def upload_image(image):
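    """Gradio callback: store the uploaded .npy path and return a status message plus a slice montage."""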
    global current_image
    # gr.File(type="filepath") passes the uploaded file's path as a string; fall back to .name for file objects
    current_image = image if isinstance(image, str) else image.name
    extracted_image_path = extract_and_display_images(current_image)
    return "Image uploaded and processed successfully!", extracted_image_path
def test_model_with_dataset():
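    """Evaluate the model on an M3D-Cap split with BLEU and ROUGE against the reference captions."""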
    # Load evaluation metrics
    bleu = load("bleu")
    rouge = load("rouge")
    # Initialize lists to store predictions and references
    predictions = []
    references = []
    # Iterate over the dataset
    for example in dataset['train']:  # Use 'train', 'validation', or 'test' split
        image_path = example['image']  # Assuming 'image' contains the path to the .npy file
        reference_caption = example['caption']  # Assuming 'caption' contains the ground-truth caption
        question = "Please describe the findings in this image."  # Generic prompt so the reference caption is not leaked into the input
        # Upload the image (upload_image accepts a file path string)
        upload_image(image_path)
        # Get the model's response
        response = process_image(question)
        # Store predictions and references
        predictions.append(response)
        references.append(reference_caption)
        # Print results for debugging
        print(f"Question: {question}")
        print(f"Model Response: {response}")
        print("---")
    # Compute evaluation metrics
    bleu_score = bleu.compute(predictions=predictions, references=references)
    rouge_score = rouge.compute(predictions=predictions, references=references)
    print(f"BLEU Score: {bleu_score}")
    print(f"ROUGE Score: {rouge_score}")
# Gradio UI
with gr.Blocks(theme=gr.themes.Soft()) as chat_ui:
    gr.Markdown("ICliniq AI-Powered Medical Image Analysis Workspace")
    with gr.Row():
        with gr.Column(scale=1, min_width=200):
            chat_list = gr.Chatbot(value=[], label="Chat History", elem_id="chat-history")
        with gr.Column(scale=4):
            uploaded_image = gr.File(label="Upload .npy Image", type="filepath")
            upload_status = gr.Textbox(label="Status", interactive=False)
            extracted_image = gr.Image(label="Extracted Images")
            question_input = gr.Textbox(label="Ask a question", placeholder="Ask something about the image...")
            submit_button = gr.Button("Send")

    uploaded_image.upload(upload_image, uploaded_image, [upload_status, extracted_image])
    submit_button.click(chat_interface, question_input, chat_list)
    question_input.submit(chat_interface, question_input, chat_list)
# Uncomment to test the model with the dataset
# test_model_with_dataset()
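# To run locally (a sketch, assuming this file is saved as app.py): `python app.py`, then open the
# printed Gradio URL, upload a (1, 32, 256, 256) or (32, 256, 256) .npy volume, and ask a question.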
chat_ui.launch()