# app.py
# RAG app for chatting with research papers (optimized for Hugging Face Spaces)

import os, sys, subprocess, re, json, uuid, gc
from typing import List, Dict, Tuple

# -----------------------------
# Auto-install deps if missing
# -----------------------------
def ensure(pkg, pip_name=None):
    try:
        __import__(pkg)
    except ImportError:
        subprocess.check_call([sys.executable, "-m", "pip", "install", pip_name or pkg])

ensure("torch")
ensure("transformers")
ensure("accelerate")
ensure("gradio")
ensure("faiss", "faiss-cpu")
ensure("sentence_transformers", "sentence-transformers")
ensure("pypdf")
ensure("docx", "python-docx")

import torch
from transformers import (
    AutoTokenizer,
    AutoModelForCausalLM,
    TextIteratorStreamer,
)
from sentence_transformers import SentenceTransformer
import faiss
import gradio as gr
from pypdf import PdfReader

# -----------------------------
# Config
# -----------------------------
DATA_DIR = "rag_data"
os.makedirs(DATA_DIR, exist_ok=True)
INDEX_PATH = os.path.join(DATA_DIR, "faiss.index")
DOCS_PATH = os.path.join(DATA_DIR, "docs.jsonl")

# Default models (overridable via environment variables)
default_emb_model = "allenai/specter2_base"
default_llm_model = "microsoft/Phi-3-mini-4k-instruct"
EMB_MODEL_ID = os.environ.get("EMB_MODEL_ID", default_emb_model)
LLM_MODEL_ID = os.environ.get("LLM_MODEL_ID", default_llm_model)

DEVICE = "cuda" if torch.cuda.is_available() else "cpu"

# -----------------------------
# File loaders
# -----------------------------
def read_txt(path):
    with open(path, "r", encoding="utf-8", errors="ignore") as f:
        return f.read()

def read_pdf(path):
    r = PdfReader(path)
    return "\n".join([p.extract_text() or "" for p in r.pages])

def read_docx(path):
    import docx
    d = docx.Document(path)
    return "\n".join([p.text for p in d.paragraphs])

def load_file(path):
    ext = os.path.splitext(path)[1].lower()
    if ext in [".txt", ".md"]:
        return read_txt(path)
    if ext == ".pdf":
        return read_pdf(path)
    if ext == ".docx":
        return read_docx(path)
    return read_txt(path)

# -----------------------------
# Chunking
# -----------------------------
def normalize_ws(s: str):
    return re.sub(r"\s+", " ", s).strip()

def chunk_text(text, chunk_size=900, overlap=150):
    text = normalize_ws(text)
    chunk_size, overlap = int(chunk_size), int(overlap)
    step = max(1, chunk_size - overlap)  # guard against a non-positive stride
    chunks = []
    for i in range(0, len(text), step):
        chunks.append(text[i:i + chunk_size])
    return chunks

# -----------------------------
# VectorStore
# -----------------------------
class VectorStore:
    """Flat inner-product FAISS index with a JSONL sidecar for chunk metadata."""

    def __init__(self, emb_model):
        self.emb_model = emb_model
        self.dim = emb_model.get_sentence_embedding_dimension()
        if os.path.exists(INDEX_PATH) and os.path.exists(DOCS_PATH):
            self.index = faiss.read_index(INDEX_PATH)
            self.meta = [json.loads(l) for l in open(DOCS_PATH, "r", encoding="utf-8") if l.strip()]
        else:
            self.index = faiss.IndexFlatIP(self.dim)
            self.meta = []

    def _embed(self, texts):
        embs = self.emb_model.encode(texts, convert_to_tensor=True, normalize_embeddings=True)
        return embs.cpu().numpy()

    def add(self, chunks, source):
        if not chunks:
            return 0
        embs = self._embed(chunks)
        faiss.normalize_L2(embs)
        self.index.add(embs)
        recs = []
        for c in chunks:
            rec = {"id": str(uuid.uuid4()), "source": source, "text": c}
            self.meta.append(rec)
            recs.append(json.dumps(rec))
        with open(DOCS_PATH, "a", encoding="utf-8") as f:
            f.write("\n".join(recs) + "\n")
        faiss.write_index(self.index, INDEX_PATH)
        return len(chunks)

    def search(self, query, k=5):
        if self.index.ntotal == 0:
            return []
        q = self._embed([query])
        faiss.normalize_L2(q)
        D, I = self.index.search(q, int(k))
        return [(float(d), self.meta[i]) for d, i in zip(D[0], I[0]) if i != -1]

    def clear(self):
        self.index = faiss.IndexFlatIP(self.dim)
        self.meta = []
        if os.path.exists(INDEX_PATH):
            os.remove(INDEX_PATH)
        if os.path.exists(DOCS_PATH):
            os.remove(DOCS_PATH)
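# Illustrative usage of the store (a sketch only; "paper.pdf" is a hypothetical file,
# the real flow goes through ui_ingest/ui_chat below):
#   store = VectorStore(EMB)
#   store.add(chunk_text(load_file("paper.pdf")), "paper.pdf")
#   hits = store.search("What is the main contribution?", k=5)  # [(score, {"id", "source", "text"}), ...]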
# -----------------------------
# Load models
# -----------------------------
print(f"[RAG] Loading embeddings: {EMB_MODEL_ID}")
EMB = SentenceTransformer(EMB_MODEL_ID, device=DEVICE)
VEC = VectorStore(EMB)

print(f"[RAG] Loading LLM: {LLM_MODEL_ID}")
bnb_config = None
if DEVICE == "cuda":
    # 4-bit NF4 quantization keeps Phi-3-mini within a small GPU's memory budget
    ensure("bitsandbytes")
    from transformers import BitsAndBytesConfig
    bnb_config = BitsAndBytesConfig(
        load_in_4bit=True,
        bnb_4bit_use_double_quant=True,
        bnb_4bit_quant_type="nf4",
    )

TOKENIZER = AutoTokenizer.from_pretrained(LLM_MODEL_ID, use_fast=True, trust_remote_code=True)
LLM = AutoModelForCausalLM.from_pretrained(
    LLM_MODEL_ID,
    device_map="auto",
    quantization_config=bnb_config,
    torch_dtype=torch.bfloat16 if torch.cuda.is_available() else torch.float32,
    low_cpu_mem_usage=True,
    trust_remote_code=True,
)

# -----------------------------
# Prompt + Generate
# -----------------------------
SYSTEM_PROMPT = "You are a helpful assistant. Use the provided context from research papers to answer questions."

def build_prompt(query, history, retrieved):
    ctx = "\n\n".join([f"[{i+1}] {m['text']}" for i, (_, m) in enumerate(retrieved)])
    # Prefer the model's chat template when one is defined
    if getattr(TOKENIZER, "chat_template", None):
        messages = [{"role": "system", "content": SYSTEM_PROMPT + "\nContext:\n" + ctx}]
        for u, a in history[-3:]:
            messages.append({"role": "user", "content": u})
            messages.append({"role": "assistant", "content": a})
        messages.append({"role": "user", "content": query})
        return TOKENIZER.apply_chat_template(messages, tokenize=False, add_generation_prompt=True)
    # Fallback: plain-text prompt with explicit turn markers
    hist = "".join([f"\nUser: {u}\nAssistant: {a}" for u, a in history[-3:]])
    return f"{SYSTEM_PROMPT}\nContext:\n{ctx}{hist}\nUser: {query}\nAssistant:"

@torch.inference_mode()
def generate_answer(prompt, temperature=0.3, max_new_tokens=512):
    # Run generation in a background thread and stream tokens back to the caller
    streamer = TextIteratorStreamer(TOKENIZER, skip_prompt=True, skip_special_tokens=True)
    inputs = TOKENIZER([prompt], return_tensors="pt").to(LLM.device)
    kwargs = dict(
        **inputs,
        max_new_tokens=int(max_new_tokens),
        temperature=float(temperature),
        do_sample=temperature > 0,
        streamer=streamer,
    )
    import threading
    t = threading.Thread(target=LLM.generate, kwargs=kwargs)
    t.start()
    out = ""
    for token in streamer:
        out += token
        yield out
    t.join()
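# Illustrative call (a sketch of how ui_chat consumes the stream; the query string is hypothetical):
#   retrieved = VEC.search("Which datasets were used?", k=5)
#   for partial in generate_answer(build_prompt("Which datasets were used?", [], retrieved)):
#       print(partial)  # cumulative text generated so far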
# -----------------------------
# Gradio UI
# -----------------------------
def ui_ingest(files, chunk_size, overlap):
    total = 0
    names = []
    for f in files or []:
        path = f if isinstance(f, str) else f.name  # Gradio may pass file paths or file objects
        text = load_file(path)
        chunks = chunk_text(text, chunk_size, overlap)
        n = VEC.add(chunks, os.path.basename(path))
        total += n
        names.append(os.path.basename(path))
    return f"Added {total} chunks", "\n".join(names) or "—", VEC.index.ntotal

def ui_clear():
    VEC.clear()
    gc.collect()
    return "Index cleared", "—", 0

def ui_chat(msg, history, top_k, temperature, max_tokens):
    if not msg.strip():
        yield history, ""
        return
    retrieved = VEC.search(msg, int(top_k))
    prompt = build_prompt(msg, history, retrieved)
    reply = ""
    for partial in generate_answer(prompt, temperature, max_tokens):
        reply = partial
        yield history + [(msg, reply)], ""
    yield history + [(msg, reply)], ""

with gr.Blocks() as demo:
    gr.Markdown("# 🔎📚 Research Paper RAG Chat (Phi-3-mini + Specter2)")
    with gr.Row():
        with gr.Column(scale=2):
            chatbot = gr.Chatbot(height=500)
            msg = gr.Textbox(placeholder="Ask a question...")
            with gr.Row():
                send = gr.Button("Send", variant="primary")
                clearc = gr.Button("Clear Chat")
        with gr.Column():
            files = gr.File(
                label="Upload PDFs/DOCX/TXT",
                file_types=[".pdf", ".docx", ".txt", ".md"],
                file_count="multiple",
            )
            chunk_size = gr.Slider(200, 2000, 900, step=50, label="Chunk Size")
            overlap = gr.Slider(0, 400, 150, step=10, label="Overlap")
            ingest_btn = gr.Button("Index Documents")
            status = gr.Textbox(label="Status", value="—")
            added = gr.Textbox(label="Files", value="—")
            total = gr.Number(label="Total Chunks", value=VEC.index.ntotal)
            clear_idx = gr.Button("Clear Index", variant="stop")
            top_k = gr.Slider(1, 10, 5, step=1, label="Top-K")
            temperature = gr.Slider(0.0, 1.5, 0.3, step=0.1, label="Temperature")
            max_tokens = gr.Slider(64, 2048, 512, step=64, label="Max New Tokens")

    ingest_btn.click(ui_ingest, [files, chunk_size, overlap], [status, added, total])
    clear_idx.click(ui_clear, [], [status, added, total])
    send.click(ui_chat, [msg, chatbot, top_k, temperature, max_tokens], [chatbot, msg])
    msg.submit(ui_chat, [msg, chatbot, top_k, temperature, max_tokens], [chatbot, msg])
    clearc.click(lambda: ([], ""), [], [chatbot, msg])

if __name__ == "__main__":
    demo.queue().launch(server_name="0.0.0.0", server_port=int(os.getenv("PORT", 7860)))
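# Local run sketch (assumed; the model IDs below are only examples of the env-var overrides defined in Config):
#   EMB_MODEL_ID=sentence-transformers/all-MiniLM-L6-v2 LLM_MODEL_ID=Qwen/Qwen2.5-1.5B-Instruct python app.py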