import faiss
import torch
import numpy as np
from transformers import AutoTokenizer, AutoModel, pipeline
import json
import gradio as gr
import matplotlib.pyplot as plt
import tempfile
import os

class MedicalRAG:
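    """Retrieval-augmented QA over a chunk of PubMed abstracts.

    Retrieval: MedCPT query embeddings searched against a FAISS
    inner-product index of precomputed article embeddings.
    Generation: SmolLM2-1.7B-Instruct prompted with the top abstracts.
    """
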
    def __init__(self, embed_path, pmids_path, content_path):
        self.device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
        # Load data
        self.embeddings = np.load(embed_path)
        self.index = self._create_faiss_index(self.embeddings)
        self.pmids, self.content = self._load_json_files(pmids_path, content_path)
        # Setup models
        self.encoder, self.tokenizer = self._setup_encoder()
        self.generator = self._setup_generator()

    def _create_faiss_index(self, embeddings):
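        """Build an exact inner-product index; MedCPT is trained for
        dot-product similarity, so IndexFlatIP matches the encoder."""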
        # FAISS expects float32, C-contiguous input; infer the dimension
        # from the data rather than hardcoding 768
        embeddings = np.ascontiguousarray(embeddings, dtype=np.float32)
        index = faiss.IndexFlatIP(embeddings.shape[1])
        index.add(embeddings)
        return index

    def _load_json_files(self, pmids_path, content_path):
        with open(pmids_path) as f:
            pmids = json.load(f)
        with open(content_path) as f:
            content = json.load(f)
        return pmids, content

    def _setup_encoder(self):
        model = AutoModel.from_pretrained("ncbi/MedCPT-Query-Encoder").to(self.device)
        model.eval()  # inference only; disables dropout
        tokenizer = AutoTokenizer.from_pretrained("ncbi/MedCPT-Query-Encoder")
        return model, tokenizer

    def _setup_generator(self):
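        """Load SmolLM2-1.7B-Instruct for answer generation; fp16 halves
        memory on GPU, while CPU inference stays in fp32."""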
        return pipeline(
            "text-generation",
            model="HuggingFaceTB/SmolLM2-1.7B-Instruct",
            device=self.device,
            torch_dtype=torch.float16 if self.device.type == 'cuda' else torch.float32
        )

    def encode_query(self, query):
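        """Embed a query with MedCPT, taking the [CLS] token output.
        max_length=64 mirrors the query length used in MedCPT's own
        usage examples."""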
        with torch.no_grad():
            inputs = self.tokenizer([query], truncation=True, padding=True,
                                    return_tensors='pt', max_length=64).to(self.device)
            embeddings = self.encoder(**inputs).last_hidden_state[:, 0, :]
            return embeddings.cpu().numpy()

    def search_documents(self, query_embedding, k=8):
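        """Return the top-k (pmid, score) matches plus the raw index
        positions of the retrieved embeddings."""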
        scores, indices = self.index.search(query_embedding, k)
        matches = [(self.pmids[idx], float(score))
                   for idx, score in zip(indices[0], scores[0])]
        return matches, indices[0]

    def get_document_content(self, pmid):
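        """Look up a PMID in the content chunk, whose entries use short
        keys: 't' = title, 'd' = date, 'a' = abstract."""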
        doc = self.content.get(str(pmid), {})  # JSON object keys are always strings
        return {
            'title': doc.get('t', '').strip(),
            'date': doc.get('d', '').strip(),
            'abstract': doc.get('a', '').strip()
        }

    def visualize_embeddings(self, query_embed, relevant_indices, labels):
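        """Render the query and retrieved-document embeddings as one
        heatmap strip per vector and return the path of the saved PNG."""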
        plt.figure(figsize=(20, len(relevant_indices) + 1))

        # Stack the query embedding on top of the retrieved ones and
        # normalize globally so every strip shares one color scale
        embeddings = np.vstack([query_embed[0], self.embeddings[relevant_indices]])
        normalized_embeddings = embeddings / np.max(np.abs(embeddings))
        dim = embeddings.shape[1]

        # Draw one heatmap strip per embedding, query at the top; fixing
        # vmin/vmax keeps colors comparable across strips
        for idx, embedding in enumerate(normalized_embeddings):
            y_pos = len(labels) - 1 - idx
            plt.imshow(embedding.reshape(1, -1), aspect='auto',
                       extent=[0, dim, y_pos, y_pos + 0.8],
                       cmap='inferno', vmin=-1, vmax=1)

        # Place each label at the vertical center of its strip
        plt.yticks([len(labels) - 1 - i + 0.4 for i in range(len(labels))], labels)
        plt.xlim(0, dim)
        plt.ylim(0, len(labels))
        plt.xlabel('Embedding Dimensions')
        plt.colorbar(label='Normalized Value')
        plt.title('Query and Retrieved Document Embeddings')

        # Save to a unique temp file and hand the path to Gradio
        fd, temp_path = tempfile.mkstemp(suffix='.png', prefix='embeddings_')
        os.close(fd)
        plt.savefig(temp_path, bbox_inches='tight', dpi=150)
        plt.close()
        return temp_path

    def generate_answer(self, query, contexts):
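        """Answer the query with the generator, grounded in the given
        abstracts. The prompt is hand-built in the ChatML format that
        SmolLM2-Instruct expects; the final split strips the echoed
        prompt from the pipeline output."""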
        prompt = (
            "<|im_start|>system\n"
            "You are a helpful medical assistant. Answer questions based on the provided literature."
            "<|im_end|>\n<|im_start|>user\n"
            f"Based on these medical articles, answer this question:\n\n"
            f"Question: {query}\n\n"
            f"Relevant Literature:\n{contexts}\n"
            "<|im_end|>\n<|im_start|>assistant"
        )
        
        response = self.generator(
            prompt,
            max_new_tokens=200,
            temperature=0.3,
            top_p=0.95,
            do_sample=True
        )
        return response[0]['generated_text'].split("<|im_start|>assistant")[-1].strip()

    def process_query(self, query):
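        """Full pipeline for one query: encode, retrieve, visualize,
        generate; returns (answer, sources, context, plot_path) for the
        Gradio outputs."""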
        try:
            # Encode and search
            query_embed = self.encode_query(query)
            doc_matches, indices = self.search_documents(query_embed)

            # Keep only documents that actually have an abstract, tracking
            # their index positions so the visualization stays aligned
            # with the labels
            documents = []
            sources = []
            labels = ["Query"]
            kept_indices = []

            for (pmid, score), idx in zip(doc_matches, indices):
                doc = self.get_document_content(pmid)
                if doc['abstract']:
                    documents.append(f"Title: {doc['title']}\nAbstract: {doc['abstract']}")
                    sources.append(f"PMID: {pmid}, Score: {score:.3f}, Link: https://pubmed.ncbi.nlm.nih.gov/{pmid}/")
                    labels.append(f"Doc {len(labels)}: {doc['title'][:30]}...")
                    kept_indices.append(int(idx))

            # Generate outputs; only the top 3 abstracts go to the
            # generator to keep the prompt within its context window
            visualization = self.visualize_embeddings(query_embed, kept_indices, labels)
            answer = self.generate_answer(query, "\n\n".join(documents[:3]))
            sources_text = "\n".join(sources)
            context = "\n\n".join(documents)

            return answer, sources_text, context, visualization

        except Exception as e:
            print(f"Error: {str(e)}")
            return str(e), "Error retrieving sources", "", None


def create_interface():
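    """Build the Gradio UI and wire the query box to MedicalRAG.process_query.

    The chunk paths below come from the original script; point them at
    your own MedCPT embedding/PMID/content files.
    """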
    rag = MedicalRAG(
        embed_path="embeds_chunk_36.npy",
        pmids_path="pmids_chunk_36.json",
        content_path="pubmed_chunk_36.json"
    )
    
    with gr.Blocks(title="Medical Literature QA") as interface:
        gr.Markdown("# Medical Literature Question Answering")
        with gr.Row():
            with gr.Column():
                query = gr.Textbox(lines=2, placeholder="Enter your medical question...", label="Question")
                submit = gr.Button("Submit", variant="primary")
                sources = gr.Textbox(label="Sources", lines=3)
                plot = gr.Image(label="Embedding Visualization")
            with gr.Column():
                answer = gr.Textbox(label="Answer", lines=5)
                context = gr.Textbox(label="Context", lines=6)
        with gr.Row():
            gr.Examples(
                examples=[
                    ["What are the latest treatments for diabetes?"],
                    ["How effective are COVID-19 vaccines?"],
                    ["What are common symptoms of the flu?"],
                    ["How can I maintain good heart health?"]
                ],
                inputs=query
            )
        
        submit.click(
            fn=rag.process_query,
            inputs=query,
            outputs=[answer, sources, context, plot]
        )
    
    return interface

if __name__ == "__main__":
    demo = create_interface()
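    # share=True creates a temporary public Gradio link; drop it to serve locally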
    demo.launch(share=True)