"""Retrieval-augmented QA over PubMed abstracts.

Uses MedCPT query embeddings with a FAISS inner-product index for retrieval
and SmolLM2-1.7B-Instruct for answer generation, served through a Gradio UI.
"""

import json
import os
import tempfile

import faiss
import gradio as gr
import matplotlib.pyplot as plt
import numpy as np
import torch
from transformers import AutoModel, AutoTokenizer, pipeline


class MedicalRAG:
    def __init__(self, embed_path, pmids_path, content_path):
        self.device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')

        # Load data
        self.embeddings = np.load(embed_path)
        self.index = self._create_faiss_index(self.embeddings)
        self.pmids, self.content = self._load_json_files(pmids_path, content_path)

        # Setup models
        self.encoder, self.tokenizer = self._setup_encoder()
        self.generator = self._setup_generator()

    def _create_faiss_index(self, embeddings):
        index = faiss.IndexFlatIP(768)  # 768 is the embedding dimension
        index.add(embeddings)
        return index

    def _load_json_files(self, pmids_path, content_path):
        with open(pmids_path) as f:
            pmids = json.load(f)
        with open(content_path) as f:
            content = json.load(f)
        return pmids, content

    def _setup_encoder(self):
        model = AutoModel.from_pretrained("ncbi/MedCPT-Query-Encoder").to(self.device)
        tokenizer = AutoTokenizer.from_pretrained("ncbi/MedCPT-Query-Encoder")
        return model, tokenizer

    def _setup_generator(self):
        return pipeline(
            "text-generation",
            model="HuggingFaceTB/SmolLM2-1.7B-Instruct",
            device=self.device,
            torch_dtype=torch.float16 if self.device.type == 'cuda' else torch.float32
        )

    def encode_query(self, query):
        # MedCPT uses the [CLS] token embedding as the query representation
        with torch.no_grad():
            inputs = self.tokenizer([query], truncation=True, padding=True,
                                    return_tensors='pt', max_length=64).to(self.device)
            embeddings = self.encoder(**inputs).last_hidden_state[:, 0, :]
            return embeddings.cpu().numpy()

    def search_documents(self, query_embedding, k=8):
        scores, indices = self.index.search(query_embedding, k)
        return [(self.pmids[idx], float(score))
                for idx, score in zip(indices[0], scores[0])], indices[0]

    def get_document_content(self, pmid):
        doc = self.content.get(pmid, {})
        return {
            'title': doc.get('t', '').strip(),
            'date': doc.get('d', '').strip(),
            'abstract': doc.get('a', '').strip()
        }

    def visualize_embeddings(self, query_embed, relevant_indices, labels):
        plt.figure(figsize=(20, len(relevant_indices) + 1))

        # Prepare embeddings for visualization (query first, then retrieved docs)
        embeddings = np.vstack([query_embed[0], self.embeddings[relevant_indices]])
        normalized_embeddings = embeddings / np.max(np.abs(embeddings))

        # Draw each embedding as a horizontal heat strip, query on top
        for idx, (embedding, label) in enumerate(zip(normalized_embeddings, labels)):
            y_pos = len(labels) - 1 - idx
            plt.imshow(embedding.reshape(1, -1), aspect='auto',
                       extent=[0, 768, y_pos, y_pos + 0.8], cmap='inferno')

        # Keep all strips visible and label each row at the center of its strip
        plt.xlim(0, 768)
        plt.ylim(0, len(labels))
        plt.yticks([len(labels) - 1 - i + 0.4 for i in range(len(labels))], labels)
        plt.xlabel('Embedding Dimensions')
        plt.colorbar(label='Normalized Value')
        plt.title('Query and Retrieved Document Embeddings')

        # Save plot to a temp file for the Gradio Image component
        temp_path = os.path.join(tempfile.gettempdir(),
                                 f'embeddings_{hash(str(embeddings))}.png')
        plt.savefig(temp_path, bbox_inches='tight', dpi=150)
        plt.close()
        return temp_path

    def generate_answer(self, query, contexts):
        prompt = (
            "<|im_start|>system\n"
            "You are a helpful medical assistant. Answer questions based on the provided literature."
            "<|im_end|>\n<|im_start|>user\n"
            "Based on these medical articles, answer this question:\n\n"
            f"Question: {query}\n\n"
            f"Relevant Literature:\n{contexts}\n"
            "<|im_end|>\n<|im_start|>assistant\n"
        )
        response = self.generator(
            prompt,
            max_new_tokens=200,
            temperature=0.3,
            top_p=0.95,
            do_sample=True
        )
        return response[0]['generated_text'].split("<|im_start|>assistant")[-1].strip()

    def process_query(self, query):
        try:
            # Encode and search
            query_embed = self.encode_query(query)
            doc_matches, indices = self.search_documents(query_embed)

            # Prepare documents and labels, keeping only docs that have an abstract
            # (track their indices so the visualization rows stay aligned with labels)
            documents = []
            sources = []
            labels = ["Query"]
            kept_indices = []
            for (pmid, score), idx in zip(doc_matches, indices):
                doc = self.get_document_content(pmid)
                if doc['abstract']:
                    documents.append(f"Title: {doc['title']}\nAbstract: {doc['abstract']}")
                    sources.append(f"PMID: {pmid}, Score: {score:.3f}, "
                                   f"Link: https://pubmed.ncbi.nlm.nih.gov/{pmid}/")
                    labels.append(f"Doc {len(labels)}: {doc['title'][:30]}...")
                    kept_indices.append(idx)

            # Generate outputs
            visualization = self.visualize_embeddings(query_embed, kept_indices, labels)
            answer = self.generate_answer(query, "\n\n".join(documents[:3]))
            sources_text = "\n".join(sources)
            context = "\n\n".join(documents)

            return answer, sources_text, context, visualization
        except Exception as e:
            print(f"Error: {str(e)}")
            return str(e), "Error retrieving sources", "", None


def create_interface():
    rag = MedicalRAG(
        embed_path="embeds_chunk_36.npy",
        pmids_path="pmids_chunk_36.json",
        content_path="pubmed_chunk_36.json"
    )

    with gr.Blocks(title="Medical Literature QA") as interface:
        gr.Markdown("# Medical Literature Question Answering")

        with gr.Row():
            with gr.Column():
                query = gr.Textbox(lines=2, placeholder="Enter your medical question...",
                                   label="Question")
                submit = gr.Button("Submit", variant="primary")
                sources = gr.Textbox(label="Sources", lines=3)
                plot = gr.Image(label="Embedding Visualization")
            with gr.Column():
                answer = gr.Textbox(label="Answer", lines=5)
                context = gr.Textbox(label="Context", lines=6)

        with gr.Row():
            gr.Examples(
                examples=[
                    ["What are the latest treatments for diabetes?"],
                    ["How effective are COVID-19 vaccines?"],
                    ["What are common symptoms of the flu?"],
                    ["How can I maintain good heart health?"]
                ],
                inputs=query
            )

        submit.click(
            fn=rag.process_query,
            inputs=query,
            outputs=[answer, sources, context, plot]
        )

    return interface


if __name__ == "__main__":
    demo = create_interface()
    demo.launch(share=True)