| """ | |
| MedDiscover-HF: Hugging Face Spaces-ready Gradio app. | |
| - ZeroGPU-compatible (uses @spaces.GPU for heavy ops) | |
| - MedCPT embeddings + FAISS retrieval over uploaded PDFs | |
| - OSS generator model dropdown (gpt-oss-20b, gemma-3-12b-it, deepseek-vl2-small, granite vision/docling) | |
| """ | |
import os
import json
import csv
import uuid
from datetime import datetime
from pathlib import Path
from threading import Thread
from typing import List, Dict, Tuple

import faiss
import gradio as gr
import numpy as np
import spaces
import torch
from PyPDF2 import PdfReader
from transformers import (
    AutoModel,
    AutoTokenizer,
    TextIteratorStreamer,
    pipeline,
)
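# Assumed Space dependencies (the requirements.txt is not shown here): gradio, spaces,
# torch, transformers, numpy, faiss-cpu, PyPDF2. This list is inferred from the imports
# above, not a confirmed manifest.
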
# ----------------------------
# Paths and env configuration
# ----------------------------
BASE_DIR = Path(__file__).parent
DATA_DIR = Path(os.getenv("DATA_DIR") or (BASE_DIR / "data"))
DATA_DIR.mkdir(parents=True, exist_ok=True)

INDEX_PATH = DATA_DIR / "faiss_index.bin"
META_PATH = DATA_DIR / "doc_metadata.json"
LOGS_PATH = DATA_DIR / "logs.jsonl"

HF_TOKEN = os.getenv("HF_TOKEN")  # set in Space secrets if needed
HF_HOME = os.getenv("HF_HOME", str(DATA_DIR / ".cache"))
os.environ["HF_HOME"] = HF_HOME

# Force CPU on stateless ZeroGPU environments to avoid CUDA init errors
os.environ.setdefault("CUDA_VISIBLE_DEVICES", "")

# ----------------------------
# Chunking / PDF utils
# ----------------------------
CHUNK_SIZE = 500
OVERLAP = 50

def chunk_text(text: str, chunk_size: int = CHUNK_SIZE, overlap: int = OVERLAP) -> List[str]:
    words = text.split()
    chunks = []
    start = 0
    while start < len(words):
        end = start + chunk_size
        chunk = words[start:end]
        if not chunk:
            break
        chunks.append(" ".join(chunk))
        start += (chunk_size - overlap)
    return chunks

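# With the defaults above (chunk_size=500, overlap=50) the window advances 450 words per
# step, so chunks cover words [0, 500), [450, 950), [900, 1400), ... and each consecutive
# pair shares 50 words of context.
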
def extract_text_from_pdf(path: str) -> str:
    buff = []
    try:
        reader = PdfReader(path)
        for page in reader.pages:
            text = page.extract_text()
            if text:
                buff.append(text)
    except Exception as e:  # pragma: no cover
        return f"Error reading {path}: {e}"
    return "\n".join(buff)

# ----------------------------
# Embeddings: MedCPT encoders
# ----------------------------
MEDCPT_ARTICLE = "ncbi/MedCPT-Article-Encoder"
MEDCPT_QUERY = "ncbi/MedCPT-Query-Encoder"
MAX_ART_LEN = 512
MAX_QUERY_LEN = 64
EMBED_DIM = 768

_article_tok = None
_article_model = None
_query_tok = None
_query_model = None

DEVICE = "cuda" if torch.cuda.is_available() else "cpu"


def load_medcpt():
    global _article_tok, _article_model, _query_tok, _query_model
    if _article_model is not None and _query_model is not None:
        return
    _article_tok = AutoTokenizer.from_pretrained(MEDCPT_ARTICLE, token=HF_TOKEN)
    _article_model = AutoModel.from_pretrained(MEDCPT_ARTICLE, token=HF_TOKEN)
    _article_model.to(DEVICE).eval()
    _query_tok = AutoTokenizer.from_pretrained(MEDCPT_QUERY, token=HF_TOKEN)
    _query_model = AutoModel.from_pretrained(MEDCPT_QUERY, token=HF_TOKEN)
    _query_model.to(DEVICE).eval()

def embed_chunks(chunks: List[str]) -> np.ndarray:
    load_medcpt()
    all_vecs = []
    with torch.no_grad():
        for i in range(0, len(chunks), 8):
            batch = chunks[i : i + 8]
            enc = _article_tok(
                batch,
                truncation=True,
                padding=True,
                return_tensors="pt",
                max_length=MAX_ART_LEN,
            ).to(DEVICE)
            out = _article_model(**enc)
            vec = out.last_hidden_state[:, 0, :].cpu().numpy()
            all_vecs.append(vec)
    if not all_vecs:
        return np.array([])
    return np.vstack(all_vecs)


def embed_query(query: str) -> np.ndarray:
    load_medcpt()
    with torch.no_grad():
        enc = _query_tok(
            query,
            truncation=True,
            padding=True,
            return_tensors="pt",
            max_length=MAX_QUERY_LEN,
        ).to(DEVICE)
        out = _query_model(**enc)
        vec = out.last_hidden_state[:, 0, :].cpu().numpy()
    return vec

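# Both encoders emit 768-d vectors; the [:, 0, :] slice takes the [CLS] embedding,
# consistent with the MedCPT model-card examples. MedCPT relevance scores are plain inner
# products of these embeddings, which is why build_index() below uses IndexFlatIP rather
# than an L2 index.
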
# ----------------------------
# FAISS index helpers
# ----------------------------
def build_index(embeddings: np.ndarray) -> faiss.IndexFlatIP:
    if embeddings.dtype != np.float32:
        embeddings = embeddings.astype(np.float32)
    index = faiss.IndexFlatIP(embeddings.shape[1])
    index.add(embeddings)
    return index


def save_index(index: faiss.IndexFlatIP, meta: List[Dict]):
    faiss.write_index(index, str(INDEX_PATH))
    META_PATH.write_text(json.dumps(meta, indent=2), encoding="utf-8")


def load_index() -> Tuple[faiss.IndexFlatIP, List[Dict]]:
    if not INDEX_PATH.exists() or not META_PATH.exists():
        return None, None
    idx = faiss.read_index(str(INDEX_PATH))
    meta = json.loads(META_PATH.read_text(encoding="utf-8"))
    return idx, meta


def search(index: faiss.IndexFlatIP, meta: List[Dict], query_vec: np.ndarray, k: int) -> List[Dict]:
    if query_vec.dtype != np.float32:
        query_vec = query_vec.astype(np.float32)
    scores, inds = index.search(query_vec, k)
    candidates = []
    for score, ind in zip(scores[0], inds[0]):
        if ind < 0 or ind >= len(meta):
            continue
        item = dict(meta[ind])
        item["retrieval_score"] = float(score)
        candidates.append(item)
    return candidates

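# Minimal retrieval round-trip (sketch; assumes process_pdfs() has already saved an index,
# and the query text is only illustrative):
#   idx, meta = load_index()
#   hits = search(idx, meta, embed_query("What biomarkers are reported?"), k=3)
#   for h in hits:
#       print(h["filename"], h["chunk_id"], round(h["retrieval_score"], 3))
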
# ----------------------------
# Model registry for generators
# ----------------------------
class GeneratorWrapper:
    def __init__(self, name: str, load_fn, fallback=None, fallback_msg: str | None = None):
        self.name = name
        self._load_fn = load_fn
        self._pipe = None
        self._fallback = fallback
        self._fallback_msg = fallback_msg
        self._note = None

    def ensure(self):
        if self._pipe is None:
            try:
                self._pipe = self._load_fn()
                self._note = None
            except Exception as exc:
                print(f"[Generator:{self.name}] load failed: {exc}")
                if self._fallback:
                    print(f"[Generator:{self.name}] falling back to {self._fallback.name}")
                    self._pipe = self._fallback.ensure()
                    self._note = self._fallback_msg or f"Falling back to {self._fallback.name}."
                else:
                    raise
        return self._pipe

    def generate_stream(self, prompt: str, max_new_tokens: int, temperature: float, top_p: float):
        pipe = self.ensure()
        streamer = TextIteratorStreamer(pipe.tokenizer, skip_special_tokens=True, skip_prompt=True)
        inputs = pipe.tokenizer(prompt, return_tensors="pt")
        device = getattr(pipe.model, "device", torch.device("cpu"))
        inputs = {k: v.to(device) for k, v in inputs.items()}
        gen_kwargs = {
            "max_new_tokens": max_new_tokens,
            "do_sample": True,
            "temperature": temperature,
            "top_p": top_p,
            "streamer": streamer,
            "return_dict_in_generate": True,
            "output_scores": False,
            "use_cache": False,  # avoid DynamicCache issues on Phi-3 CPU
        }

        def _run():
            try:
                pipe.model.generate(**inputs, **gen_kwargs)
            except Exception as exc:
                if self._fallback:
                    print(f"[Generator:{self.name}] generate failed: {exc}; falling back to {self._fallback.name}")
                    self._pipe = self._fallback.ensure()
                    note = self._fallback_msg or f"Falling back to {self._fallback.name}."
                    if note:
                        # on_finalized_text() pushes decoded text onto the streamer queue;
                        # put() expects token-id tensors and would fail on plain strings.
                        streamer.on_finalized_text(note + " ")
                    fb_stream = self._fallback.generate_stream(prompt, max_new_tokens, temperature, top_p)
                    for tok in fb_stream:
                        streamer.on_finalized_text(tok)
                else:
                    print(f"[Generator:{self.name}] generate failed: {exc}")
            streamer.end()

        Thread(target=_run, daemon=True).start()
        if self._note:
            yield self._note + " "
            self._note = None
        for token in streamer:
            yield token

def load_gpt_oss():
    raise RuntimeError("gpt-oss-20b is disabled on ZeroGPU (too large)")


def load_tinyllama():
    # CPU-friendly small chat model to keep ZeroGPU happy.
    return pipeline(
        "text-generation",
        model="TinyLlama/TinyLlama-1.1B-Chat-v1.0",
        device_map="cpu",
        torch_dtype=torch.float32,
    )


def load_phi3_mini():
    pipe = pipeline(
        "text-generation",
        model="microsoft/Phi-3-mini-4k-instruct",
        device_map="cpu",
        torch_dtype=torch.float32,
        trust_remote_code=True,
        model_kwargs={
            "use_cache": False,
            "attn_implementation": "eager",
        },
    )
    # Disable cache to avoid DynamicCache.seen_tokens errors on ZeroGPU/CPU.
    try:
        pipe.model.config.use_cache = False
        pipe.model.generation_config.use_cache = False
        pipe.model.generation_config.cache_implementation = "static"
    except Exception:
        pass
    return pipe


_tiny_wrapper = GeneratorWrapper("tinyllama-1.1b-chat", load_tinyllama)
_phi_wrapper = GeneratorWrapper(
    "phi-3-mini-4k",
    load_phi3_mini,
    fallback=_tiny_wrapper,
    fallback_msg="Phi-3-mini-4k unavailable on this Space (CUDA blocked); falling back to TinyLlama CPU.",
)

GENERATORS = {
    "tinyllama-1.1b-chat": _tiny_wrapper,
    "phi-3-mini-4k": _phi_wrapper,
}

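# Registering an additional generator only needs a loader plus an entry here. A sketch
# with a hypothetical, untested model id ("org/some-small-chat-model" is a placeholder):
#   GENERATORS["some-small-chat"] = GeneratorWrapper(
#       "some-small-chat",
#       lambda: pipeline("text-generation", model="org/some-small-chat-model",
#                        device_map="cpu", torch_dtype=torch.float32),
#       fallback=_tiny_wrapper,
#   )
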
# ----------------------------
# Prompt formatting
# ----------------------------
def format_prompt(query: str, contexts: List[Dict]) -> str:
    context_blocks = []
    for i, c in enumerate(contexts):
        context_blocks.append(
            f"--- Context {i+1} (file={c.get('filename','N/A')} chunk={c.get('chunk_id','?')}) ---\n{c.get('text','')}"
        )
    joined = "\n\n".join(context_blocks) if context_blocks else "None."
    prompt = (
        "You are MedDiscover, a biomedical assistant. Use ONLY the provided context to answer concisely.\n"
        "If the context does not contain the answer, reply: 'Not found in provided documents.'\n\n"
        f"{joined}\n\nQuestion: {query}\nAnswer:"
    )
    return prompt

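# Rendered prompt skeleton (context text elided):
#   You are MedDiscover, a biomedical assistant. Use ONLY the provided context ...
#   --- Context 1 (file=<filename> chunk=<id>) ---
#   <chunk text>
#
#   Question: <query>
#   Answer:
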
# ----------------------------
# Gradio callbacks
# ----------------------------
def ensure_session_state(session_state):
    if not isinstance(session_state, dict):
        session_state = {}
    if not session_state.get("id"):
        session_state["id"] = str(uuid.uuid4())
    if "records" not in session_state or not isinstance(session_state["records"], list):
        session_state["records"] = []
    return session_state


def append_log_record(record: Dict):
    LOGS_PATH.parent.mkdir(parents=True, exist_ok=True)
    with LOGS_PATH.open("a", encoding="utf-8") as f:
        f.write(json.dumps(record) + "\n")

def process_pdfs(files, progress=gr.Progress()):
    if not files:
        return "Upload PDFs first."
    texts = []
    meta = []
    doc_id = 0
    for idx, f in enumerate(files):
        progress(idx / max(len(files), 1), desc=f"Reading {Path(f.name).name}")
        text = extract_text_from_pdf(f.name)
        if not text or text.startswith("Error reading"):
            continue
        chunks = chunk_text(text)
        for cid, chunk in enumerate(chunks):
            meta.append({"doc_id": doc_id, "filename": Path(f.name).name, "chunk_id": cid, "text": chunk})
            texts.append(chunk)
        doc_id += 1
    if not texts:
        return "No text extracted."
    progress(0.7, desc=f"Embedding {len(texts)} chunks")
    embeds = embed_chunks(texts)
    if embeds.size == 0:
        return "Embedding failed."
    progress(0.85, desc="Building index")
    idx = build_index(embeds)
    save_index(idx, meta)
    progress(1.0, desc="Ready")
    return f"Processed {doc_id} PDFs. Index size={idx.ntotal}, dim={idx.d}. Saved to {DATA_DIR}."

def handle_query(query, model_key, k, max_new_tokens, temperature, top_p, session_state):
    # This is a generator: Gradio streams each yielded (answer, prompt, state) triple to
    # the UI, so early exits must yield before returning; a bare `return value` inside a
    # generator ends the stream without updating the outputs.
    session_state = ensure_session_state(session_state)
    if not query or query.strip() == "":
        yield "Enter a query", "No context", session_state
        return
    idx, meta = load_index()
    if idx is None or meta is None:
        yield "Index not ready. Process PDFs first.", "No context", session_state
        return
    qvec = embed_query(query)
    cands = search(idx, meta, qvec, int(k))
    prompt = format_prompt(query, cands)
    wrapper = GENERATORS[model_key]
    stream = wrapper.generate_stream(prompt, int(max_new_tokens), float(temperature), float(top_p))
    answer_accum = ""
    for chunk in stream:
        answer_accum += chunk
        yield answer_accum, prompt, session_state
    record = {
        "timestamp": datetime.utcnow().isoformat() + "Z",
        "session_id": session_state["id"],
        "question": query,
        "answer": answer_accum,
        "context_chunks": cands,
        "model": model_key,
        "k": int(k),
        "temperature": float(temperature),
        "top_p": float(top_p),
        "max_new_tokens": int(max_new_tokens),
    }
    session_state["records"].append(record)
    append_log_record(record)
    yield answer_accum, prompt, session_state

def export_session_json(session_state):
    session_state = ensure_session_state(session_state)
    records = session_state.get("records", [])
    if not records:
        return None, "No session records to export."
    out_path = DATA_DIR / f"session-{session_state['id']}.json"
    with out_path.open("w", encoding="utf-8") as f:
        json.dump(records, f, ensure_ascii=False, indent=2)
    return str(out_path), f"Exported {len(records)} records to JSON."


def export_session_csv(session_state):
    session_state = ensure_session_state(session_state)
    records = session_state.get("records", [])
    if not records:
        return None, "No session records to export."
    out_path = DATA_DIR / f"session-{session_state['id']}.csv"
    fields = [
        "timestamp",
        "session_id",
        "question",
        "answer",
        "model",
        "k",
        "temperature",
        "top_p",
        "max_new_tokens",
        "context",
    ]
    with out_path.open("w", encoding="utf-8", newline="") as f:
        writer = csv.DictWriter(f, fieldnames=fields)
        writer.writeheader()
        for rec in records:
            ctx = " ||| ".join([c.get("text", "") for c in rec.get("context_chunks", [])])
            writer.writerow(
                {
                    "timestamp": rec.get("timestamp", ""),
                    "session_id": rec.get("session_id", ""),
                    "question": rec.get("question", ""),
                    "answer": rec.get("answer", ""),
                    "model": rec.get("model", ""),
                    "k": rec.get("k", ""),
                    "temperature": rec.get("temperature", ""),
                    "top_p": rec.get("top_p", ""),
                    "max_new_tokens": rec.get("max_new_tokens", ""),
                    "context": ctx,
                }
            )
    return str(out_path), f"Exported {len(records)} records to CSV."


def clear_session(session_state):
    session_state = ensure_session_state(session_state)
    session_state["records"] = []
    session_state["id"] = str(uuid.uuid4())
    return session_state, "Session cleared."

# ----------------------------
# Gradio UI
# ----------------------------
with gr.Blocks(title="MedDiscover") as demo:
    gr.Markdown(
        "# 🩺 MedDiscover\n"
        "Retrieval-Augmented Generation with Large Language Models for Biomedical Discovery\n\n"
        "Presented by Vatsal Patel, Elena Jolkver, and Anne Schwerk\n"
        "IU International University of Applied Sciences, Germany"
    )
    with gr.Row():
        with gr.Column(scale=1):
            api_info = gr.Markdown("")
            pdfs = gr.File(label="Upload PDFs", file_types=[".pdf"], file_count="multiple")
            process_btn = gr.Button("Process PDFs (chunk/embed/index)", variant="primary")
            status = gr.Textbox(label="Status", interactive=False)
            model_dd = gr.Dropdown(
                label="Generator Model",
                choices=list(GENERATORS.keys()),
                value="tinyllama-1.1b-chat",
                interactive=True,
            )
            k_slider = gr.Slider(1, 10, value=3, step=1, label="Top-k chunks")
            max_tokens = gr.Slider(20, 512, value=128, step=8, label="Max new tokens")
            temp = gr.Slider(0.1, 1.5, value=0.4, step=0.1, label="Temperature")
            top_p = gr.Slider(0.1, 1.0, value=0.9, step=0.05, label="Top-p")
            with gr.Group():
                gr.Markdown("Session Logs")
                export_json_btn = gr.Button("Export session (JSON)")
                export_csv_btn = gr.Button("Export session (CSV)")
                clear_session_btn = gr.Button("Clear session")
                session_status = gr.Textbox(label="Session status", interactive=False)
                download_json = gr.File(label="JSON export", interactive=False)
                download_csv = gr.File(label="CSV export", interactive=False)
        with gr.Column(scale=2):
            query = gr.Textbox(label="Query", lines=3, placeholder="Ask about your documents...")
            answer = gr.Textbox(label="Answer (streaming)", lines=6)
            context_box = gr.Textbox(label="Context used in prompt", lines=14)
            go_btn = gr.Button("Ask", variant="primary")

    session_state = gr.State({"id": None, "records": []})

    process_btn.click(fn=process_pdfs, inputs=pdfs, outputs=status, show_progress="full")
    go_btn.click(
        fn=handle_query,
        inputs=[query, model_dd, k_slider, max_tokens, temp, top_p, session_state],
        outputs=[answer, context_box, session_state],
        concurrency_limit=1,
    )
    query.submit(
        fn=handle_query,
        inputs=[query, model_dd, k_slider, max_tokens, temp, top_p, session_state],
        outputs=[answer, context_box, session_state],
        concurrency_limit=1,
    )
    export_json_btn.click(fn=export_session_json, inputs=session_state, outputs=[download_json, session_status])
    export_csv_btn.click(fn=export_session_csv, inputs=session_state, outputs=[download_csv, session_status])
    clear_session_btn.click(fn=clear_session, inputs=session_state, outputs=[session_state, session_status])

if __name__ == "__main__":
    demo.queue().launch()
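# Local run: `python app.py` serves the UI on Gradio's default port (7860).
# demo.queue() together with concurrency_limit=1 on the query events keeps generation
# requests serialized, which matters on a CPU-only Space.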