import numpy as np
import torch
from transformers import AutoTokenizer, AutoModelForCausalLM
import gradio as gr
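# Assumption: the app may run on a headless server; select a non-interactive
# Matplotlib backend so saving figures works without a display.
import matplotlib
matplotlib.use('Agg')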
import matplotlib.pyplot as plt
from datasets import load_dataset
from evaluate import load  # For evaluation metrics

# Model setup
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')  # Use GPU if available
dtype = torch.float32  # torch.bfloat16 is a faster option on GPUs that support it
model_name_or_path = 'GoodBaiBai88/M3D-LaMed-Phi-3-4B'
proj_out_num = 256  # number of <im_patch> visual tokens the projector emits per volume

# Load model and tokenizer
model = AutoModelForCausalLM.from_pretrained(
    model_name_or_path,
    torch_dtype=dtype,
    device_map=device.type,
    trust_remote_code=True
)
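# Inference only: disable dropout etc. (harmless if the remote code already does this)
model.eval()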

tokenizer = AutoTokenizer.from_pretrained(
    model_name_or_path,
    model_max_length=512,
    padding_side="right",
    use_fast=False,
    trust_remote_code=True
)

# Load the M3D-Cap dataset (only needed for test_model_with_dataset; downloads on first run)
dataset = load_dataset("GoodBaiBai88/M3D-Cap")

# Session state: chat history and path of the currently uploaded volume
chat_history = []
current_image = None
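# NOTE: module-level globals are shared by every session of the Gradio app;
# a multi-user deployment would be safer with per-session gr.State.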

def extract_and_display_images(image_path):
    """Render the volume's 32 axial slices as a 4x8 montage saved to a PNG."""
    try:
        npy_data = np.load(image_path)
        if npy_data.ndim == 4 and npy_data.shape[1] == 32:
            npy_data = npy_data[0]  # drop the leading axis: (1, 32, 256, 256) -> (32, 256, 256)
        elif npy_data.ndim != 3 or npy_data.shape[0] != 32:
            return "Invalid .npy file format. Expected shape (1, 32, 256, 256) or (32, 256, 256)."
        
        # 4 x 8 grid: one panel per axial slice (32 in total)
        fig, axes = plt.subplots(4, 8, figsize=(12, 6))
        for i, ax in enumerate(axes.flat):
            ax.imshow(npy_data[i], cmap='gray')
            ax.axis('off')
        
        image_output = "extracted_images.png"
        plt.savefig(image_output, bbox_inches='tight')
        plt.close(fig)
        return image_output
    except Exception as e:
        return f"Error processing image: {str(e)}"

def process_image(question):
    """Ask the model a question about the currently uploaded volume."""
    global current_image
    if current_image is None:
        return "Please upload an image first."
    
    try:
        image_np = np.load(current_image)
        # The prompt must start with exactly proj_out_num <im_patch> placeholders;
        # the model swaps in the projected visual tokens for them.
        image_tokens = "<im_patch>" * proj_out_num
        input_txt = image_tokens + question
        input_id = tokenizer(input_txt, return_tensors="pt")['input_ids'].to(device=device)
        
        image_pt = torch.from_numpy(image_np).unsqueeze(0).to(dtype=dtype, device=device)
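        # M3D-LaMed's remote code overrides generate() to take (images, input_ids),
        # a custom signature rather than the standard transformers one.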
        generation = model.generate(image_pt, input_id, max_new_tokens=256, do_sample=True, top_p=0.9, temperature=1.0)
        generated_texts = tokenizer.batch_decode(generation, skip_special_tokens=True)
        return generated_texts[0]
    except Exception as e:
        return f"Error generating response: {str(e)}"

def chat_interface(question):
    global chat_history
    response = process_image(question)
    chat_history.append((question, response))
    return chat_history

def upload_image(image_path):
    global current_image
    # gr.File(type="filepath") delivers the uploaded file's path as a plain string
    current_image = image_path
    extracted_image_path = extract_and_display_images(current_image)
    return "Image uploaded and processed successfully!", extracted_image_path

def test_model_with_dataset():
    """Evaluate the model against M3D-Cap captions with BLEU and ROUGE."""
    # Load evaluation metrics from the `evaluate` library (downloaded on first use)
    bleu = load("bleu")
    rouge = load("rouge")

    # Initialize lists to store predictions and references
    predictions = []
    references = []

    # Iterate over the dataset
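    # NOTE: the full train split is large; for a quick smoke test, iterate over
    # dataset['train'].select(range(50)) instead.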
    for example in dataset['train']:  # Use 'train', 'validation', or 'test' split
        image_path = example['image']  # Assuming 'image' holds the path to the .npy volume
        question = example['caption']  # Assuming 'caption' holds the reference text (used as both prompt and reference)

        # Register the image exactly as the UI upload handler would
        upload_image(image_path)

        # Get the model's response
        response = process_image(question)

        # Store predictions and references
        predictions.append(response)
        references.append(question)

        # Print results for debugging
        print(f"Question: {question}")
        print(f"Model Response: {response}")
        print("---")

    # Compute evaluation metrics
    bleu_score = bleu.compute(predictions=predictions, references=references)
    rouge_score = rouge.compute(predictions=predictions, references=references)

    print(f"BLEU Score: {bleu_score}")
    print(f"ROUGE Score: {rouge_score}")

# Gradio UI
with gr.Blocks(theme=gr.themes.Soft()) as chat_ui:
    gr.Markdown("## ICliniq AI-Powered Medical Image Analysis Workspace")
    with gr.Row():
        with gr.Column(scale=1, min_width=200):
            chat_list = gr.Chatbot(value=[], label="Chat History", elem_id="chat-history")
        with gr.Column(scale=4):
            uploaded_image = gr.File(label="Upload .npy Image", type="filepath")
            upload_status = gr.Textbox(label="Status", interactive=False)
            extracted_image = gr.Image(label="Extracted Images")
            question_input = gr.Textbox(label="Ask a question", placeholder="Ask something about the image...")
            submit_button = gr.Button("Send")
    
    uploaded_image.upload(upload_image, uploaded_image, [upload_status, extracted_image])
    submit_button.click(chat_interface, question_input, chat_list)
    question_input.submit(chat_interface, question_input, chat_list)

# Uncomment to test the model with the dataset
# test_model_with_dataset()

chat_ui.launch()  # pass share=True for a temporary public URL