""" MedDiscover-HF: Hugging Face Spaces-ready Gradio app. - ZeroGPU-compatible (uses @spaces.GPU for heavy ops) - MedCPT embeddings + FAISS retrieval over uploaded PDFs - OSS generator model dropdown (gpt-oss-20b, gemma-3-12b-it, deepseek-vl2-small, granite vision/docling) """ import os import json import csv import uuid from datetime import datetime from pathlib import Path from threading import Thread from typing import List, Dict, Tuple import faiss import gradio as gr import numpy as np import spaces import torch from PyPDF2 import PdfReader from transformers import ( AutoModel, AutoTokenizer, AutoProcessor, TextIteratorStreamer, pipeline, ) # ---------------------------- # Paths and env configuration # ---------------------------- BASE_DIR = Path(__file__).parent DATA_DIR = Path(os.getenv("DATA_DIR") or (BASE_DIR / "data")) DATA_DIR.mkdir(parents=True, exist_ok=True) INDEX_PATH = DATA_DIR / "faiss_index.bin" META_PATH = DATA_DIR / "doc_metadata.json" LOGS_PATH = DATA_DIR / "logs.jsonl" HF_TOKEN = os.getenv("HF_TOKEN") # set in Space secrets if needed HF_HOME = os.getenv("HF_HOME", str(DATA_DIR / ".cache")) os.environ["HF_HOME"] = HF_HOME # Force CPU on stateless ZeroGPU environments to avoid CUDA init errors os.environ.setdefault("CUDA_VISIBLE_DEVICES", "") # ---------------------------- # Chunking / PDF utils # ---------------------------- CHUNK_SIZE = 500 OVERLAP = 50 def chunk_text(text: str, chunk_size: int = CHUNK_SIZE, overlap: int = OVERLAP) -> List[str]: words = text.split() chunks = [] start = 0 while start < len(words): end = start + chunk_size chunk = words[start:end] if not chunk: break chunks.append(" ".join(chunk)) start += (chunk_size - overlap) return chunks def extract_text_from_pdf(path: str) -> str: buff = [] try: reader = PdfReader(path) for page in reader.pages: text = page.extract_text() if text: buff.append(text) except Exception as e: # pragma: no cover return f"Error reading {path}: {e}" return "\n".join(buff) # ---------------------------- # Embeddings: MedCPT encoders # ---------------------------- MEDCPT_ARTICLE = "ncbi/MedCPT-Article-Encoder" MEDCPT_QUERY = "ncbi/MedCPT-Query-Encoder" MAX_ART_LEN = 512 MAX_QUERY_LEN = 64 EMBED_DIM = 768 _article_tok = None _article_model = None _query_tok = None _query_model = None DEVICE = "cuda" if torch.cuda.is_available() else "cpu" def load_medcpt(): global _article_tok, _article_model, _query_tok, _query_model if _article_model and _query_model: return _article_tok = AutoTokenizer.from_pretrained(MEDCPT_ARTICLE, use_auth_token=HF_TOKEN) _article_model = AutoModel.from_pretrained(MEDCPT_ARTICLE, use_auth_token=HF_TOKEN) _article_model.to(DEVICE).eval() _query_tok = AutoTokenizer.from_pretrained(MEDCPT_QUERY, use_auth_token=HF_TOKEN) _query_model = AutoModel.from_pretrained(MEDCPT_QUERY, use_auth_token=HF_TOKEN) _query_model.to(DEVICE).eval() @spaces.GPU() def embed_chunks(chunks: List[str]) -> np.ndarray: load_medcpt() all_vecs = [] with torch.no_grad(): for i in range(0, len(chunks), 8): batch = chunks[i : i + 8] enc = _article_tok( batch, truncation=True, padding=True, return_tensors="pt", max_length=MAX_ART_LEN, ).to(DEVICE) out = _article_model(**enc) vec = out.last_hidden_state[:, 0, :].cpu().numpy() all_vecs.append(vec) if not all_vecs: return np.array([]) return np.vstack(all_vecs) @spaces.GPU() def embed_query(query: str) -> np.ndarray: load_medcpt() with torch.no_grad(): enc = _query_tok( query, truncation=True, padding=True, return_tensors="pt", max_length=MAX_QUERY_LEN, ).to(DEVICE) out = _query_model(**enc) vec = out.last_hidden_state[:, 0, :].cpu().numpy() return vec # ---------------------------- # FAISS index helpers # ---------------------------- def build_index(embeddings: np.ndarray) -> faiss.IndexFlatIP: if embeddings.dtype != np.float32: embeddings = embeddings.astype(np.float32) index = faiss.IndexFlatIP(embeddings.shape[1]) index.add(embeddings) return index def save_index(index: faiss.IndexFlatIP, meta: List[Dict]): faiss.write_index(index, str(INDEX_PATH)) META_PATH.write_text(json.dumps(meta, indent=2), encoding="utf-8") def load_index() -> Tuple[faiss.IndexFlatIP, List[Dict]]: if not INDEX_PATH.exists() or not META_PATH.exists(): return None, None idx = faiss.read_index(str(INDEX_PATH)) meta = json.loads(META_PATH.read_text(encoding="utf-8")) return idx, meta def search(index: faiss.IndexFlatIP, meta: List[Dict], query_vec: np.ndarray, k: int) -> List[Dict]: if query_vec.dtype != np.float32: query_vec = query_vec.astype(np.float32) scores, inds = index.search(query_vec, k) candidates = [] for score, ind in zip(scores[0], inds[0]): if ind < 0 or ind >= len(meta): continue item = dict(meta[ind]) item["retrieval_score"] = float(score) candidates.append(item) return candidates # ---------------------------- # Model registry for generators # ---------------------------- class GeneratorWrapper: def __init__(self, name: str, load_fn, fallback=None, fallback_msg: str | None = None): self.name = name self._load_fn = load_fn self._pipe = None self._fallback = fallback self._fallback_msg = fallback_msg self._note = None def ensure(self): if self._pipe is None: try: self._pipe = self._load_fn() self._note = None except Exception as exc: print(f"[Generator:{self.name}] load failed: {exc}") if self._fallback: print(f"[Generator:{self.name}] falling back to {self._fallback.name}") self._pipe = self._fallback.ensure() self._note = self._fallback_msg or f"Falling back to {self._fallback.name}." else: raise return self._pipe def generate_stream(self, prompt: str, max_new_tokens: int, temperature: float, top_p: float): pipe = self.ensure() streamer = TextIteratorStreamer(pipe.tokenizer, skip_special_tokens=True, skip_prompt=True) inputs = pipe.tokenizer(prompt, return_tensors="pt") device = getattr(pipe.model, "device", torch.device("cpu")) inputs = {k: v.to(device) for k, v in inputs.items()} gen_kwargs = { "max_new_tokens": max_new_tokens, "do_sample": True, "temperature": temperature, "top_p": top_p, "streamer": streamer, "return_dict_in_generate": True, "output_scores": False, "use_cache": False, # avoid DynamicCache issues on Phi-3 CPU } def _run(): try: pipe.model.generate(**inputs, **gen_kwargs) except Exception as exc: if self._fallback: print(f"[Generator:{self.name}] generate failed: {exc}; falling back to {self._fallback.name}") self._pipe = self._fallback.ensure() note = self._fallback_msg or f"Falling back to {self._fallback.name}." if note: streamer.put(note + " ") fb_stream = self._fallback.generate_stream(prompt, max_new_tokens, temperature, top_p) for tok in fb_stream: streamer.put(tok) else: print(f"[Generator:{self.name}] generate failed: {exc}") streamer.end() Thread(target=_run, daemon=True).start() if self._note: yield self._note + " " self._note = None for token in streamer: yield token def load_gpt_oss(): raise RuntimeError("gpt-oss-20b is disabled on ZeroGPU (too large)") def load_tinyllama(): # CPU-friendly small chat model to keep ZeroGPU happy. return pipeline( "text-generation", model="TinyLlama/TinyLlama-1.1B-Chat-v1.0", device_map="cpu", torch_dtype=torch.float32, ) def load_phi3_mini(): pipe = pipeline( "text-generation", model="microsoft/Phi-3-mini-4k-instruct", device_map="cpu", torch_dtype=torch.float32, trust_remote_code=True, model_kwargs={ "use_cache": False, "attn_implementation": "eager", }, ) # Disable cache to avoid DynamicCache.seen_tokens errors on ZeroGPU/CPU. try: pipe.model.config.use_cache = False pipe.model.generation_config.use_cache = False pipe.model.generation_config.cache_implementation = "static" except Exception: pass return pipe _tiny_wrapper = GeneratorWrapper("tinyllama-1.1b-chat", load_tinyllama) _phi_wrapper = GeneratorWrapper( "phi-3-mini-4k", load_phi3_mini, fallback=_tiny_wrapper, fallback_msg="Phi-3-mini-4k unavailable on this Space (CUDA blocked); falling back to TinyLlama CPU.", ) GENERATORS = { "tinyllama-1.1b-chat": _tiny_wrapper, "phi-3-mini-4k": _phi_wrapper, } # ---------------------------- # Prompt formatting # ---------------------------- def format_prompt(query: str, contexts: List[Dict]) -> str: context_blocks = [] for i, c in enumerate(contexts): context_blocks.append( f"--- Context {i+1} (file={c.get('filename','N/A')} chunk={c.get('chunk_id','?')}) ---\n{c.get('text','')}" ) joined = "\n\n".join(context_blocks) if context_blocks else "None." prompt = ( "You are MedDiscover, a biomedical assistant. Use ONLY the provided context to answer concisely.\n" "If the context does not contain the answer, reply: 'Not found in provided documents.'\n\n" f"{joined}\n\nQuestion: {query}\nAnswer:" ) return prompt # ---------------------------- # Gradio callbacks # ---------------------------- def ensure_session_state(session_state): if not isinstance(session_state, dict): session_state = {} if not session_state.get("id"): session_state["id"] = str(uuid.uuid4()) if "records" not in session_state or not isinstance(session_state["records"], list): session_state["records"] = [] return session_state def append_log_record(record: Dict): LOGS_PATH.parent.mkdir(parents=True, exist_ok=True) with LOGS_PATH.open("a", encoding="utf-8") as f: f.write(json.dumps(record) + "\n") @spaces.GPU() def process_pdfs(files, progress=gr.Progress()): if not files: return "Upload PDFs first." texts = [] meta = [] doc_id = 0 for idx, f in enumerate(files): progress(((idx) / max(len(files), 1)), desc=f"Reading {Path(f.name).name}") text = extract_text_from_pdf(f.name) if not text or text.startswith("Error reading"): continue chunks = chunk_text(text) for cid, chunk in enumerate(chunks): meta.append({"doc_id": doc_id, "filename": Path(f.name).name, "chunk_id": cid, "text": chunk}) texts.append(chunk) doc_id += 1 if not texts: return "No text extracted." progress(0.7, desc=f"Embedding {len(texts)} chunks") embeds = embed_chunks(texts) if embeds.size == 0: return "Embedding failed." progress(0.85, desc="Building index") idx = build_index(embeds) save_index(idx, meta) progress(1.0, desc="Ready") return f"Processed {doc_id} PDFs. Index size={idx.ntotal}, dim={idx.d}. Saved to {DATA_DIR}." def handle_query(query, model_key, k, max_new_tokens, temperature, top_p, session_state): session_state = ensure_session_state(session_state) if not query or query.strip() == "": return "Enter a query", "No context", session_state idx, meta = load_index() if idx is None or meta is None: return "Index not ready. Process PDFs first.", "No context", session_state qvec = embed_query(query) cands = search(idx, meta, qvec, int(k)) prompt = format_prompt(query, cands) wrapper = GENERATORS[model_key] stream = wrapper.generate_stream(prompt, int(max_new_tokens), float(temperature), float(top_p)) answer_accum = "" for chunk in stream: answer_accum += chunk yield answer_accum, prompt, session_state record = { "timestamp": datetime.utcnow().isoformat() + "Z", "session_id": session_state["id"], "question": query, "answer": answer_accum, "context_chunks": cands, "model": model_key, "k": int(k), "temperature": float(temperature), "top_p": float(top_p), "max_new_tokens": int(max_new_tokens), } session_state["records"].append(record) append_log_record(record) yield answer_accum, prompt, session_state def export_session_json(session_state): session_state = ensure_session_state(session_state) records = session_state.get("records", []) if not records: return None, "No session records to export." out_path = DATA_DIR / f"session-{session_state['id']}.json" with out_path.open("w", encoding="utf-8") as f: json.dump(records, f, ensure_ascii=False, indent=2) return str(out_path), f"Exported {len(records)} records to JSON." def export_session_csv(session_state): session_state = ensure_session_state(session_state) records = session_state.get("records", []) if not records: return None, "No session records to export." out_path = DATA_DIR / f"session-{session_state['id']}.csv" fields = [ "timestamp", "session_id", "question", "answer", "model", "k", "temperature", "top_p", "max_new_tokens", "context", ] with out_path.open("w", encoding="utf-8", newline="") as f: writer = csv.DictWriter(f, fieldnames=fields) writer.writeheader() for rec in records: ctx = " ||| ".join([c.get("text", "") for c in rec.get("context_chunks", [])]) writer.writerow( { "timestamp": rec.get("timestamp", ""), "session_id": rec.get("session_id", ""), "question": rec.get("question", ""), "answer": rec.get("answer", ""), "model": rec.get("model", ""), "k": rec.get("k", ""), "temperature": rec.get("temperature", ""), "top_p": rec.get("top_p", ""), "max_new_tokens": rec.get("max_new_tokens", ""), "context": ctx, } ) return str(out_path), f"Exported {len(records)} records to CSV." def clear_session(session_state): session_state = ensure_session_state(session_state) session_state["records"] = [] session_state["id"] = str(uuid.uuid4()) return session_state, "Session cleared." # ---------------------------- # Gradio UI # ---------------------------- with gr.Blocks(title="MedDiscover") as demo: gr.Markdown("# 🩺 MedDiscover\nRetrieval Augmented Generation with Large Language Models for Biomedical Discovery Presented by,\nVatsal Patel, Elena Jolkver, Anne Schwerk\nIU International University of Applied Science, Germany") with gr.Row(): with gr.Column(scale=1): api_info = gr.Markdown("") pdfs = gr.File(label="Upload PDFs", file_types=[".pdf"], file_count="multiple") process_btn = gr.Button("Process PDFs (chunk/embed/index)", variant="primary") status = gr.Textbox(label="Status", interactive=False) model_dd = gr.Dropdown( label="Generator Model", choices=list(GENERATORS.keys()), value="tinyllama-1.1b-chat", interactive=True, ) k_slider = gr.Slider(1, 10, value=3, step=1, label="Top-k chunks") max_tokens = gr.Slider(20, 512, value=128, step=8, label="Max new tokens") temp = gr.Slider(0.1, 1.5, value=0.4, step=0.1, label="Temperature") top_p = gr.Slider(0.1, 1.0, value=0.9, step=0.05, label="Top-p") with gr.Group(): gr.Markdown("Session Logs") export_json_btn = gr.Button("Export session (JSON)") export_csv_btn = gr.Button("Export session (CSV)") clear_session_btn = gr.Button("Clear session") session_status = gr.Textbox(label="Session status", interactive=False) download_json = gr.File(label="JSON export", interactive=False) download_csv = gr.File(label="CSV export", interactive=False) with gr.Column(scale=2): query = gr.Textbox(label="Query", lines=3, placeholder="Ask about your documents...") answer = gr.Textbox(label="Answer (streaming)", lines=6) context_box = gr.Textbox(label="Context used in prompt", lines=14) go_btn = gr.Button("Ask", variant="primary") session_state = gr.State({"id": None, "records": []}) process_btn.click(fn=process_pdfs, inputs=pdfs, outputs=status, show_progress="full") go_btn.click( fn=handle_query, inputs=[query, model_dd, k_slider, max_tokens, temp, top_p, session_state], outputs=[answer, context_box, session_state], concurrency_limit=1, ) query.submit( fn=handle_query, inputs=[query, model_dd, k_slider, max_tokens, temp, top_p, session_state], outputs=[answer, context_box, session_state], concurrency_limit=1, ) export_json_btn.click(fn=export_session_json, inputs=session_state, outputs=[download_json, session_status]) export_csv_btn.click(fn=export_session_csv, inputs=session_state, outputs=[download_csv, session_status]) clear_session_btn.click(fn=clear_session, inputs=session_state, outputs=[session_state, session_status]) if __name__ == "__main__": demo.queue().launch()