mgbam committed on
Commit
b8986a1
·
1 Parent(s): 879a34e

Add application file

Browse files
Files changed (7) hide show
  1. README.md +86 -5
  2. app.py +38 -0
  3. backend.py +42 -0
  4. mini_ladder.py +56 -0
  5. requirements.txt +9 -0
  6. retrieval.py +126 -0
  7. visualization.py +42 -0
README.md CHANGED
@@ -1,12 +1,93 @@
1
  ---
2
- title: Medic
3
- emoji: 🌖
4
- colorFrom: yellow
5
- colorTo: indigo
6
  sdk: streamlit
7
  sdk_version: 1.43.2
8
  app_file: app.py
9
  pinned: false
 
10
  ---
11
 
12
- Check out the configuration reference at https://huggingface.co/docs/hub/spaces-config-reference
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
  ---
2
+ title: Med
3
+ emoji: 🏆
4
+ colorFrom: blue
5
+ colorTo: purple
6
  sdk: streamlit
7
  sdk_version: 1.43.2
8
  app_file: app.py
9
  pinned: false
10
+ short_description: Medical field with next‑gen technology
11
  ---
12
 
13
+
14
+ # AI-Powered Medical Knowledge Graph Assistant (Mini-LADDER Demo)
15
+
16
+ This repository demonstrates a Streamlit application that:
17
+ - Retrieves PubMed abstracts via NCBI’s E-utilities.
18
+ - Indexes and retrieves relevant documents using ChromaDB.
19
+ - Generates biomedical answers using Microsoft BioGPT-Large-PubMedQA.
20
+ - Applies a two-stage self-improvement mechanism (inspired by Tufa Labs’ LADDER) that:
21
+ - Generates naive sub-questions.
22
+ - Produces an initial answer.
23
+ - Self-critiques and refines the answer.
24
+ - Visualizes key terms in a knowledge graph using PyVis.
25
+
26
+ ## Key Features
27
+
28
+ 1. **PubMed + Chroma**: Retrieves and indexes relevant abstracts.
29
+ 2. **BioGPT-Large-PubMedQA**: Generates an initial answer.
30
+ 3. **Mini-LADDER Approach**:
31
+ - **Sub-Question Decomposition**: Generates sub-questions from the main query.
32
+ - **Self-Critique & Refinement**: Uses a second pass to critique and refine the answer.
33
+ 4. **Interactive Knowledge Graph**: Displays a PyVis graph of the top documents and key terms.
34
+
35
+ ## Setup Instructions
36
+
37
+ 1. **Install Dependencies**
38
+
39
+ Create a virtual environment (optional) and install the required packages:
40
+ ```bash
41
+ pip install -r requirements.txt
42
+ ```
43
+ 2. **Set Environment Variables**
44
+ Set your PubMed API key:
45
+
46
+ ```bash
47
+ export PUBMED_API_KEY=<YOUR_NCBI_API_KEY>
48
+ ```
49
+ 3. **Run the App**
50
+
51
+ Launch the Streamlit app:
52
+
53
+ ```bash
54
+ streamlit run app.py
55
+ ```
56
+ 4. **Access the App**
57
+
58
+ Open your browser at http://localhost:8501.
59
+
60
+ Project Structure
61
+ Copy
62
+ .
63
+ ├── app.py
64
+ ├── backend.py
65
+ ├── mini_ladder.py
66
+ ├── retrieval.py
67
+ ├── visualization.py
68
+ ├── README.md
69
+ └── requirements.txt
70
+ ## About the Mini-LADDER Approach
71
+ Inspired by Tufa Labs’ LADDER (Learning through Autonomous Difficulty-Driven Example Recursion), this demo shows how a model might:
72
+
73
+ - Decompose a query into simpler sub-questions.
74
+ - Generate an initial answer using retrieval-augmented generation.
75
+ - Self-critique and refine the answer based on detected gaps.
+
76
+ Ultimately, this approach hints at how autonomous, recursive learning could be implemented.
77
+ Enjoy exploring potential extensions into code generation, theorem proving, or other domains!
78
+
79
+ yaml
80
+ Copy
81
+
82
+ ---
83
+
84
+ ## File: `requirements.txt`
85
+
86
+ ```txt
87
+ streamlit
88
+ pyvis
89
+ chromadb
90
+ transformers
91
+ sentence-transformers
92
+ torch
93
+ requests
app.py ADDED
@@ -0,0 +1,38 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import streamlit as st
2
+ import streamlit.components.v1 as components
3
+ from backend import process_medical_query, docs_cache
4
+ from visualization import create_medical_graph
5
+
6
def main():
    """Render the Streamlit page and run the query pipeline on submit.

    Shows a text input and a "Submit" button. When the button is pressed
    with a non-blank query, calls ``process_medical_query`` and displays the
    sub-questions, initial answer, critique, refined answer, and a
    knowledge graph built from the documents cached for that query.
    """
    st.title("AI-Powered Medical Knowledge Graph Assistant")
    st.markdown(
        "**Using BioGPT-Large-PubMedQA + PubMed + Chroma** for advanced retrieval-augmented generation."
    )

    user_query = st.text_input("Enter biomedical/medical query", "Malaria and cough treatment")
    if not st.button("Submit"):
        return

    # A blank query would be forwarded to the retrieval and generation
    # steps unchanged; stop early with a visible message instead.
    query = user_query.strip()
    if not query:
        st.warning("Please enter a query before submitting.")
        return

    with st.spinner("Generating answer..."):
        final_answer, sub_questions, initial_answer, critique = process_medical_query(query)

    st.subheader("Sub-Question Decomposition")
    st.write(sub_questions)

    st.subheader("Initial AI Answer")
    st.write(initial_answer)

    st.subheader("Self-Critique")
    st.write(critique)

    st.subheader("Refined AI Answer")
    st.write(final_answer)

    st.subheader("Knowledge Graph")
    # process_medical_query stores the retrieved documents under the exact
    # query string it was given, so look them up with the same key.
    docs = docs_cache.get(query, [])
    if docs:
        graph_html = create_medical_graph(query, docs)
        components.html(graph_html, height=600, scrolling=True)
    else:
        st.info("No documents to visualize.")
36
+
37
+ if __name__ == "__main__":
38
+ main()
backend.py ADDED
@@ -0,0 +1,42 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ from transformers import pipeline
2
+ from retrieval import get_relevant_pubmed_docs
3
+ from mini_ladder import generate_sub_questions, self_critique_and_refine
4
+
5
+ # Use Microsoft BioGPT-Large-PubMedQA for generation
6
+ MODEL_NAME = "microsoft/BioGPT-Large-PubMedQA"
7
+ qa_pipeline = pipeline("text-generation", model=MODEL_NAME)
8
+
9
+ # In-memory cache for documents (used for graph generation)
10
+ docs_cache = {}
11
+
12
def process_medical_query(query: str):
    """Answer a medical query with retrieval, generation, and refinement.

    Steps:
      1. Generate sub-questions.
      2. Retrieve relevant PubMed documents (also stored in ``docs_cache``
         under ``query``).
      3. Generate an initial answer from the retrieved context.
      4. Self-critique and refine the answer.

    Args:
        query: The user's question.

    Returns:
        A 4-tuple ``(final_answer, sub_questions, initial_answer, critique)``.
        When no documents are retrieved, ``final_answer`` is a fixed
        "no documents" message and the last two items are empty strings.
    """
    # Step 1: Generate sub-questions (naively)
    sub_questions = generate_sub_questions(query)

    # Step 2: Retrieve relevant documents via PubMed and Chroma
    relevant_docs = get_relevant_pubmed_docs(query)
    docs_cache[query] = relevant_docs

    if not relevant_docs:
        return ("No documents found for this query.", sub_questions, "", "")

    # Step 3: Generate an initial answer
    context_text = "\n\n".join(relevant_docs)
    prompt = f"Question: {query}\nContext: {context_text}\nAnswer:"
    # By default a text-generation pipeline returns prompt + continuation;
    # ask for the continuation only so the "answer" is not the whole context.
    initial_gen = qa_pipeline(
        prompt, max_new_tokens=100, truncation=True, return_full_text=False
    )
    initial_answer = ""
    if initial_gen and isinstance(initial_gen, list):
        initial_answer = initial_gen[0].get("generated_text", "")
        # Defensive: drop the prompt if the pipeline echoed it anyway.
        if initial_answer.startswith(prompt):
            initial_answer = initial_answer[len(prompt):]
        initial_answer = initial_answer.strip()
    if not initial_answer:
        initial_answer = "No answer found."

    # Step 4: Self-critique and refine the answer
    final_answer, critique = self_critique_and_refine(query, initial_answer, relevant_docs)

    return (final_answer, sub_questions, initial_answer, critique)
mini_ladder.py ADDED
@@ -0,0 +1,56 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ from transformers import pipeline
2
+
3
+ # A second pipeline for self-critique (using a lighter model for demonstration)
4
+ CRITIQUE_MODEL = "gpt2" # This can be replaced with another model as needed
5
+ critique_pipeline = pipeline("text-generation", model=CRITIQUE_MODEL)
6
+
7
def generate_sub_questions(main_query: str):
    """
    Build three fixed-template sub-questions (causes, medications,
    non-pharmacological approaches) around the given main query.
    """
    templates = (
        "1) What are common causes of {q}?",
        "2) Which medications are typically used for {q}?",
        "3) What are non-pharmacological approaches to {q}?",
    )
    sub_questions = []
    for template in templates:
        sub_questions.append(template.format(q=main_query))
    return sub_questions
16
+
17
# Words in the generated critique that signal the answer needs a second pass.
_REFINE_TRIGGERS = ("missing", "incomplete", "incorrect", "lacks")


def _new_text(outputs, prompt: str, fallback: str) -> str:
    """Return only the text a generation pipeline added after ``prompt``."""
    if not (outputs and isinstance(outputs, list)):
        return fallback
    text = outputs[0].get("generated_text", "")
    # Defensive: strip the prompt if the pipeline echoed it anyway.
    if text.startswith(prompt):
        text = text[len(prompt):]
    text = text.strip()
    return text or fallback


def self_critique_and_refine(query: str, initial_answer: str, docs: list):
    """
    Critiques the initial answer and refines it if the critique flags issues.

    Args:
        query: The user's question.
        initial_answer: The first-pass answer to critique.
        docs: Context documents (strings) offered to the refinement prompt.

    Returns:
        A tuple ``(refined_answer, critique_text)``. ``refined_answer`` is
        ``initial_answer`` unchanged when the critique contains none of the
        trigger words or when refinement produces no text.
    """
    # Step 1: Generate a critique using a critique prompt
    critique_prompt = (
        f"The following is an answer to the question '{query}'. "
        "Evaluate its correctness, clarity, and completeness. "
        "List any missing details or inaccuracies.\n\n"
        f"ANSWER:\n{initial_answer}\n\n"
        "CRITIQUE:"
    )
    # Request only the continuation. The prompt itself contains the word
    # "missing", so checking prompt + continuation would always trigger
    # refinement regardless of what the model wrote.
    critique_gen = critique_pipeline(
        critique_prompt, max_new_tokens=80, truncation=True, return_full_text=False
    )
    critique_text = _new_text(critique_gen, critique_prompt, "No critique generated.")

    # Step 2: If the critique suggests issues, refine the answer using the original QA pipeline.
    lowered = critique_text.lower()
    if not any(word in lowered for word in _REFINE_TRIGGERS):
        return initial_answer, critique_text

    refine_prompt = (
        f"Question: {query}\n"
        f"Current Answer: {initial_answer}\n"
        f"Critique: {critique_text}\n"
        "Refine the answer by adding missing or corrected information. "
        "Use the context below if needed:\n\n"
        + "\n\n".join(docs)
        + "\nREFINED ANSWER:"
    )
    # Import the qa_pipeline from backend to reuse it (local import to avoid circular dependencies)
    from backend import qa_pipeline
    refined_gen = qa_pipeline(
        refine_prompt, max_new_tokens=120, truncation=True, return_full_text=False
    )
    refined_answer = _new_text(refined_gen, refine_prompt, initial_answer)

    return refined_answer, critique_text
requirements.txt ADDED
@@ -0,0 +1,9 @@
 
 
 
 
 
 
 
 
 
 
1
+ uvicorn
2
+ sacremoses
3
+ streamlit
4
+ pyvis
5
+ chromadb
6
+ transformers
7
+ sentence-transformers
8
+ torch
9
+ requests
retrieval.py ADDED
@@ -0,0 +1,126 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import os
2
+ import requests
3
+ import torch
4
+ from typing import List
5
+ import chromadb
6
+ from chromadb.config import Settings
7
+ from transformers import AutoTokenizer, AutoModel
8
+
9
+ # Optional: Set your PubMed API key from environment variables
10
+ PUBMED_API_KEY = os.environ.get("PUBMED_API_KEY", "<YOUR_NCBI_API_KEY>")
11
+
12
+ #############################################
13
+ # 1) FETCH PUBMED ABSTRACTS
14
+ #############################################
15
def fetch_pubmed_abstracts(query: str, max_results: int = 5, timeout: float = 10.0) -> List[str]:
    """
    Fetches PubMed abstracts for the given query using NCBI's E-utilities.

    Runs one ``esearch`` request to get matching PMIDs, then one ``efetch``
    request per PMID to download the abstract as plain text.

    Args:
        query: Search term sent to PubMed.
        max_results: Maximum number of PMIDs requested from ``esearch``.
        timeout: Per-request timeout in seconds passed to ``requests.get``.

    Returns:
        A list of non-empty abstract texts; empty for a blank query or when
        the search returns no IDs.

    Raises:
        requests.HTTPError: If any E-utilities request returns an error status.
        requests.Timeout: If a request exceeds ``timeout``.
    """
    if not query or not query.strip():
        return []

    # Only send api_key when a real one is configured; the module default is
    # a placeholder string, not a usable key.
    base_params = {"db": "pubmed"}
    if PUBMED_API_KEY and PUBMED_API_KEY != "<YOUR_NCBI_API_KEY>":
        base_params["api_key"] = PUBMED_API_KEY

    search_url = "https://eutils.ncbi.nlm.nih.gov/entrez/eutils/esearch.fcgi"
    search_params = {
        **base_params,
        "term": query,
        "retmax": max_results,
        "retmode": "json",
    }
    search_resp = requests.get(search_url, params=search_params, timeout=timeout)
    search_resp.raise_for_status()
    # Use .get so a payload without "esearchresult" yields no IDs instead
    # of a KeyError.
    pmid_list = search_resp.json().get("esearchresult", {}).get("idlist", [])

    abstracts = []
    fetch_url = "https://eutils.ncbi.nlm.nih.gov/entrez/eutils/efetch.fcgi"
    for pmid in pmid_list:
        fetch_params = {
            **base_params,
            "id": pmid,
            "rettype": "abstract",
            "retmode": "text",
        }
        fetch_resp = requests.get(fetch_url, params=fetch_params, timeout=timeout)
        fetch_resp.raise_for_status()
        abstract_text = fetch_resp.text.strip()
        if abstract_text:
            abstracts.append(abstract_text)
    return abstracts
49
+
50
+ #############################################
51
+ # 2) CHROMA + EMBEDDINGS SETUP
52
+ #############################################
53
class EmbedFunction:
    """
    Wraps a Hugging Face embedding model to produce embeddings for a list of strings.

    Calling an instance tokenizes the texts as one padded batch, runs the
    model without gradients, and mean-pools the token vectors of each text
    into a single embedding.
    """

    def __init__(self, model_name: str):
        """Load the tokenizer and model for ``model_name`` and set eval mode."""
        self.tokenizer = AutoTokenizer.from_pretrained(model_name)
        self.model = AutoModel.from_pretrained(model_name)
        self.model.eval()

    def __call__(self, input: List[str]) -> List[List[float]]:
        """Embed ``input`` and return one list of floats per text.

        Returns an empty list for empty input. The parameter name ``input``
        is kept as-is because the instance is passed to Chroma as an
        embedding function.
        """
        if not input:
            return []
        tokenized = self.tokenizer(
            input,
            return_tensors="pt",
            padding=True,
            truncation=True,
            max_length=512,
        )
        with torch.no_grad():
            outputs = self.model(**tokenized)
        last_hidden = outputs.last_hidden_state

        # Average only over real tokens. A plain mean over the sequence axis
        # would include padding positions, making a text's embedding depend
        # on how long the other texts in the batch are.
        mask = tokenized["attention_mask"].unsqueeze(-1).to(last_hidden.dtype)
        summed = (last_hidden * mask).sum(dim=1)
        counts = mask.sum(dim=1).clamp(min=1e-9)  # avoid division by zero
        pooled = summed / counts
        return pooled.cpu().tolist()
78
+
79
+ EMBED_MODEL_NAME = "sentence-transformers/all-MiniLM-L6-v2"
80
+ embed_function = EmbedFunction(EMBED_MODEL_NAME)
81
+
82
+ client = chromadb.Client(
83
+ settings=Settings(
84
+ persist_directory="chromadb_data",
85
+ anonymized_telemetry=False
86
+ )
87
+ )
88
+
89
+ # Updated collection name for clarity.
90
+ collection = client.get_or_create_collection(
91
+ name="ai_medical_knowledge",
92
+ embedding_function=embed_function
93
+ )
94
+
95
def index_pubmed_docs(docs: List[str], prefix: str = "doc"):
    """
    Adds documents to the Chroma collection under IDs ``"{prefix}-{i}"``.

    ``i`` is the document's position in ``docs``; blank documents are
    skipped. Documents are written with a single ``upsert`` so that
    re-indexing the same prefix replaces earlier entries instead of
    colliding on existing IDs.
    """
    ids = []
    texts = []
    for i, doc in enumerate(docs):
        if doc.strip():
            ids.append(f"{prefix}-{i}")
            texts.append(doc)

    # Nothing to write: avoid calling the collection with empty lists.
    if not texts:
        return

    collection.upsert(documents=texts, ids=ids)
103
+
104
def query_similar_docs(query: str, top_k: int = 3) -> List[str]:
    """
    Query the Chroma collection with ``query`` and return the documents
    listed for it, requesting at most ``top_k`` results. Returns an empty
    list when the result or its "documents" entry is empty.
    """
    hits = collection.query(query_texts=[query], n_results=top_k)
    if not hits:
        return []
    documents = hits["documents"]
    if not documents:
        return []
    # One query text was sent, so the first entry holds its documents.
    return documents[0]
110
+
111
+ #############################################
112
+ # 3) MAIN RETRIEVAL PIPELINE
113
+ #############################################
114
def get_relevant_pubmed_docs(user_query: str) -> List[str]:
    """
    Run the full retrieval flow for ``user_query``:
      1. Fetch up to 5 PubMed abstracts.
      2. Index them in Chroma, using the query text as the ID prefix.
      3. Return the top 3 similar documents from the collection.
    Returns an empty list when no abstracts are fetched.
    """
    fetched = fetch_pubmed_abstracts(user_query, max_results=5)
    if fetched:
        index_pubmed_docs(fetched, prefix=user_query)
        return query_similar_docs(user_query, top_k=3)
    return []
visualization.py ADDED
@@ -0,0 +1,42 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import re
2
+ import tempfile
3
+ import os
4
+ from pyvis.network import Network
5
+
6
def extract_key_terms(text: str):
    """
    Naively pick out "key terms": every word of two or more ASCII letters
    that starts with an uppercase letter, in order of appearance
    (duplicates included).
    """
    capitalized_word = re.compile(r"\b[A-Z][a-zA-Z]+\b")
    terms = []
    for match in capitalized_word.finditer(text):
        terms.append(match.group(0))
    return terms
11
+
12
def create_medical_graph(query: str, docs: list) -> str:
    """
    Builds a PyVis network:
    - A central "QUERY" node.
    - A node for each retrieved document.
    - Sub-nodes for extracted key terms.
    Returns the full HTML of the generated graph.

    Args:
        query: The query text shown on the central node's label.
        docs: Document strings; each becomes an "Abstract N" node.

    Returns:
        The graph rendered as an HTML string.
    """
    net = Network(height="600px", width="100%", directed=False)
    net.add_node("QUERY", label=f"Query: {query}", color="red", shape="star")

    for i, doc in enumerate(docs):
        doc_id = f"Doc_{i}"
        net.add_node(doc_id, label=f"Abstract {i+1}", color="blue")
        net.add_edge("QUERY", doc_id)

        # De-duplicate terms per document; sort so node insertion order
        # (and therefore the generated HTML) is the same on every run.
        for term in sorted(set(extract_key_terms(doc))):
            term_id = f"{doc_id}_{term}"
            net.add_node(term_id, label=term, color="green")
            net.add_edge(doc_id, term_id)

    # Render straight to a string. This avoids the display-oriented show()
    # call and the temporary file that was left behind if rendering failed.
    return net.generate_html(notebook=False)