# app.py
# RAG app for chatting with research papers (optimized for Hugging Face Spaces)
import os, sys, subprocess, re, json, uuid, gc
from typing import List, Dict, Tuple
# -----------------------------
# Auto-install deps if missing
# -----------------------------
def ensure(pkg, pip_name=None):
    """Import *pkg*, pip-installing it first if the import fails.

    Args:
        pkg: importable module name (e.g. ``"docx"``).
        pip_name: distribution name to install when it differs from *pkg*
            (e.g. ``"python-docx"``); defaults to *pkg*.

    Raises:
        subprocess.CalledProcessError: if the pip install fails.
    """
    try:
        __import__(pkg)
    except ImportError:
        import importlib

        subprocess.check_call([sys.executable, "-m", "pip", "install", pip_name or pkg])
        # Finders cache directory listings; invalidate them so a package that
        # was installed after interpreter start is visible to later imports.
        importlib.invalidate_caches()
# Third-party packages the app imports below, as
# (import name, pip distribution name or None when both names coincide).
for _module_name, _dist_name in (
    ("torch", None),
    ("transformers", None),
    ("accelerate", None),
    ("gradio", None),
    ("faiss", "faiss-cpu"),
    ("sentence_transformers", "sentence-transformers"),
    ("pypdf", None),
    ("docx", "python-docx"),
):
    ensure(_module_name, _dist_name)
del _module_name, _dist_name
import torch
from transformers import (
AutoTokenizer,
AutoModelForCausalLM,
TextIteratorStreamer
)
from sentence_transformers import SentenceTransformer
import faiss, gradio as gr
from pypdf import PdfReader
# -----------------------------
# Config
# -----------------------------
# On-disk location of the persisted FAISS index and its JSONL metadata file.
DATA_DIR = "rag_data"
INDEX_PATH = os.path.join(DATA_DIR, "faiss.index")
DOCS_PATH = os.path.join(DATA_DIR, "docs.jsonl")
os.makedirs(DATA_DIR, exist_ok=True)

# Default models; override with the EMB_MODEL_ID / LLM_MODEL_ID environment variables.
default_emb_model = "allenai/specter2_base"
default_llm_model = "microsoft/Phi-3-mini-4k-instruct"
EMB_MODEL_ID = os.getenv("EMB_MODEL_ID", default_emb_model)
LLM_MODEL_ID = os.getenv("LLM_MODEL_ID", default_llm_model)

# Use the GPU whenever torch can see one.
DEVICE = "cpu"
if torch.cuda.is_available():
    DEVICE = "cuda"
# -----------------------------
# File loaders
# -----------------------------
def read_txt(path):
    """Return the contents of *path* decoded as UTF-8, dropping undecodable bytes."""
    # Context manager closes the handle deterministically instead of leaking it.
    with open(path, "r", encoding="utf-8", errors="ignore") as f:
        return f.read()
def read_pdf(path):
    """Return the extracted text of every page of a PDF, pages joined by newlines.

    Pages whose ``extract_text()`` result is falsy contribute an empty string.
    """
    reader = PdfReader(path)
    page_texts = []
    for page in reader.pages:
        page_texts.append(page.extract_text() or "")
    return "\n".join(page_texts)
def read_docx(path):
    """Return the text of a .docx file, one paragraph per line."""
    import docx

    document = docx.Document(path)
    return "\n".join(para.text for para in document.paragraphs)
def load_file(path):
    """Extract plain text from *path*, choosing a reader by (case-insensitive) extension.

    ``.pdf`` and ``.docx`` use dedicated parsers. ``.txt``, ``.md`` and any
    unrecognized extension are read as UTF-8 text.
    """
    suffix = os.path.splitext(path)[1].lower()
    readers = {".pdf": read_pdf, ".docx": read_docx}
    return readers.get(suffix, read_txt)(path)
# -----------------------------
# Chunking
# -----------------------------
def normalize_ws(s: str) -> str:
    """Collapse every run of whitespace in *s* to one space and trim both ends."""
    return re.sub(r"\s+", " ", s).strip()


def chunk_text(text, chunk_size=900, overlap=150):
    """Split *text* into overlapping fixed-size character windows.

    Whitespace is normalized first. Consecutive chunks start
    ``chunk_size - overlap`` characters apart. Chunking stops at the first
    window that reaches the end of the text, so no chunk is merely a suffix of
    the previous one.

    Args:
        text: raw document text.
        chunk_size: window length in characters (floats are truncated to int).
        overlap: characters shared by consecutive windows (truncated to int).

    Returns:
        List of chunk strings; empty for empty or whitespace-only text.

    Raises:
        ValueError: if ``chunk_size <= 0`` or ``overlap`` is not in
            ``[0, chunk_size)``. Otherwise the stride would be non-positive and
            ``range`` would fail or silently yield nothing.
    """
    # UI sliders may deliver floats; range() and slicing need ints.
    chunk_size, overlap = int(chunk_size), int(overlap)
    if chunk_size <= 0:
        raise ValueError(f"chunk_size must be positive, got {chunk_size}")
    if not 0 <= overlap < chunk_size:
        raise ValueError(
            f"overlap must be in [0, chunk_size={chunk_size}), got {overlap}"
        )
    text = normalize_ws(text)
    step = chunk_size - overlap
    chunks = []
    for start in range(0, len(text), step):
        chunks.append(text[start:start + chunk_size])
        if start + chunk_size >= len(text):
            break  # this window already covers the tail
    return chunks
# -----------------------------
# VectorStore
# -----------------------------
class VectorStore:
    """Exact inner-product FAISS index over unit-normalized embeddings.

    Row ``i`` of ``self.index`` corresponds to ``self.meta[i]``, a dict with
    ``id``, ``source`` and ``text`` keys. The index is persisted to INDEX_PATH
    and the metadata to DOCS_PATH (one JSON record per line) on every ``add``.
    Because vectors are normalized, inner product equals cosine similarity.
    """

    def __init__(self, emb_model):
        self.emb_model = emb_model
        self.dim = emb_model.get_sentence_embedding_dimension()
        self.index = faiss.IndexFlatIP(self.dim)
        self.meta = []
        if os.path.exists(INDEX_PATH) or os.path.exists(DOCS_PATH):
            self._load()

    def _load(self):
        """Load the persisted index and metadata, resetting the store if they disagree."""
        index = faiss.read_index(INDEX_PATH) if os.path.exists(INDEX_PATH) else None
        meta = []
        if os.path.exists(DOCS_PATH):
            try:
                with open(DOCS_PATH, "r", encoding="utf-8") as f:
                    meta = [json.loads(line) for line in f if line.strip()]
            except json.JSONDecodeError:
                meta = None  # e.g. a partially written last line
        # A missing index, a row-count mismatch, or a different embedding
        # dimension (EMB_MODEL_ID changed) would make search return wrong
        # records or crash, so start over instead of serving bad results.
        if (
            index is None
            or meta is None
            or index.d != self.dim
            or index.ntotal != len(meta)
        ):
            print(
                f"[RAG] Persisted store in {DATA_DIR} is inconsistent with its "
                "metadata or the embedding model; starting with an empty index."
            )
            self.clear()
            return
        self.index = index
        self.meta = meta

    def _embed(self, texts):
        """Encode *texts* to a float32 (n, dim) array of unit-length vectors."""
        embs = self.emb_model.encode(
            texts, convert_to_numpy=True, normalize_embeddings=True
        )
        # FAISS only accepts float32 input.
        return embs.astype("float32", copy=False)

    def add(self, chunks, source):
        """Embed and index *chunks* from *source*, persist them, and return the count added."""
        if not chunks:
            return 0
        embs = self._embed(chunks)
        recs = [{"id": str(uuid.uuid4()), "source": source, "text": c} for c in chunks]
        self.index.add(embs)
        self.meta.extend(recs)
        with open(DOCS_PATH, "a", encoding="utf-8") as f:
            f.write("".join(json.dumps(rec) + "\n" for rec in recs))
        faiss.write_index(self.index, INDEX_PATH)
        return len(chunks)

    def search(self, query, k=5):
        """Return up to *k* ``(score, record)`` pairs most similar to *query*, best first."""
        # Slider values may be floats; FAISS needs an int k, and asking for more
        # than ntotal only yields -1 padding.
        k = min(int(k), self.index.ntotal)
        if k <= 0:
            return []
        q = self._embed([query])
        D, I = self.index.search(q, k)
        return [
            (float(d), self.meta[i])
            for d, i in zip(D[0], I[0])
            if 0 <= i < len(self.meta)
        ]

    def clear(self):
        """Drop all vectors and metadata, in memory and on disk."""
        self.index = faiss.IndexFlatIP(self.dim)
        self.meta = []
        if os.path.exists(INDEX_PATH):
            os.remove(INDEX_PATH)
        if os.path.exists(DOCS_PATH):
            os.remove(DOCS_PATH)
# -----------------------------
# Load models
# -----------------------------
print(f"[RAG] Loading embeddings: {EMB_MODEL_ID}")
EMB = SentenceTransformer(EMB_MODEL_ID, device=DEVICE)
VEC = VectorStore(EMB)
print(f"[RAG] Loading LLM: {LLM_MODEL_ID}")

# LLM weight/compute dtype. bfloat16 only where the GPU supports it (pre-Ampere
# cards such as the T4 do not), float16 on other GPUs, float32 on CPU.
if DEVICE == "cuda":
    LLM_DTYPE = torch.bfloat16 if torch.cuda.is_bf16_supported() else torch.float16
else:
    LLM_DTYPE = torch.float32

bnb_config = None
if DEVICE == "cuda":
    # 4-bit loading is implemented by the bitsandbytes package, which the
    # startup ensure() list does not include; install it only when a GPU
    # makes it usable.
    ensure("bitsandbytes")
    from transformers import BitsAndBytesConfig

    bnb_config = BitsAndBytesConfig(
        load_in_4bit=True,
        bnb_4bit_use_double_quant=True,
        bnb_4bit_quant_type="nf4",
        # Defaults to float32 otherwise, which bitsandbytes warns is slow.
        bnb_4bit_compute_dtype=LLM_DTYPE,
    )

# NOTE: trust_remote_code runs Python shipped with the model repo. LLM_MODEL_ID
# can be overridden via the environment, so only point it at trusted repos.
TOKENIZER = AutoTokenizer.from_pretrained(LLM_MODEL_ID, use_fast=True, trust_remote_code=True)
LLM = AutoModelForCausalLM.from_pretrained(
    LLM_MODEL_ID,
    device_map="auto",
    quantization_config=bnb_config,
    torch_dtype=LLM_DTYPE,
    low_cpu_mem_usage=True,
    trust_remote_code=True,
)
# -----------------------------
# Prompt + Generate
# -----------------------------
SYSTEM_PROMPT = "You are a helpful assistant. Use the provided context from research papers to answer questions."


def build_prompt(query, history, retrieved, tokenizer=None):
    """Build the LLM prompt from retrieved context, recent chat history and the new query.

    Args:
        query: the user's new question.
        history: list of ``(user, assistant)`` pairs (Gradio tuple-style chat
            history). Only the last three are used, and pairs with an empty or
            ``None`` side are skipped.
        retrieved: ``(score, record)`` pairs as returned by ``VectorStore.search``.
            Each record's ``"text"`` is inserted as a numbered ``[n]`` context
            passage.
        tokenizer: tokenizer used to render the chat template; defaults to the
            module-level ``TOKENIZER``.

    Returns:
        The prompt string, rendered with the tokenizer's chat template when it
        has one, otherwise a plain "User:/Assistant:" transcript.
    """
    tok = TOKENIZER if tokenizer is None else tokenizer
    ctx = "\n\n".join(f"[{i}] {m['text']}" for i, (_, m) in enumerate(retrieved, start=1))
    # Skip incomplete turns so the template never receives None content.
    turns = [(u, a) for u, a in (history or [])[-3:] if u and a]

    # Every modern tokenizer *has* apply_chat_template; what matters is whether a
    # template is actually configured.
    if getattr(tok, "chat_template", None):
        messages = [{"role": "system", "content": SYSTEM_PROMPT + "\nContext:\n" + ctx}]
        for u, a in turns:
            messages.append({"role": "user", "content": u})
            messages.append({"role": "assistant", "content": a})
        messages.append({"role": "user", "content": query})
        # add_generation_prompt appends the assistant header so the model
        # answers rather than continuing the user's turn.
        return tok.apply_chat_template(messages, tokenize=False, add_generation_prompt=True)

    # Fallback: label each turn so the model can tell speakers apart.
    hist = "".join(f"User: {u}\nAssistant: {a}\n" for u, a in turns)
    return f"{SYSTEM_PROMPT}\nContext:\n{ctx}\n\n{hist}User: {query}\nAssistant:"
@torch.inference_mode()
def generate_answer(prompt, temperature=0.3, max_new_tokens=512):
    """Stream an answer to *prompt*, yielding the accumulated text after each new piece.

    Generation runs ``LLM.generate`` in a worker thread feeding a
    ``TextIteratorStreamer``.

    Args:
        prompt: fully rendered prompt string.
        temperature: sampling temperature; ``<= 0`` means greedy decoding.
        max_new_tokens: generation length cap (floats are truncated to int).

    Raises:
        Exception: whatever ``LLM.generate`` raised, re-raised after the stream ends.
    """
    import threading

    streamer = TextIteratorStreamer(TOKENIZER, skip_prompt=True, skip_special_tokens=True)
    # Chat templates often emit BOS themselves; adding it again would double it.
    bos = TOKENIZER.bos_token
    add_special = not (bos and prompt.startswith(bos))
    inputs = TOKENIZER([prompt], return_tensors="pt", add_special_tokens=add_special).to(LLM.device)

    do_sample = float(temperature) > 0
    kwargs = dict(
        **inputs,
        max_new_tokens=int(max_new_tokens),
        do_sample=do_sample,
        streamer=streamer,
    )
    if do_sample:
        # Only meaningful (and only accepted without a warning) when sampling.
        kwargs["temperature"] = float(temperature)

    errors = []

    def _run():
        """Call generate, and end the stream on failure so the consumer loop cannot hang."""
        try:
            LLM.generate(**kwargs)
        except Exception as exc:
            errors.append(exc)
            streamer.end()

    t = threading.Thread(target=_run)
    t.start()
    out = ""
    for token in streamer:
        out += token
        yield out
    t.join()
    if errors:
        raise errors[0]
# -----------------------------
# Gradio UI
# -----------------------------
def ui_ingest(files, chunk_size, overlap):
    """Gradio handler: parse, chunk and index the uploaded files.

    Returns:
        ``(status message, newline-separated file names or "—", total chunks in index)``.
        Files that fail to parse are skipped and listed in the status message.
    """
    chunk_size = int(chunk_size)
    # The UI lets overlap (max 400) exceed chunk size (min 200); keep the
    # chunker's stride positive.
    overlap = max(0, min(int(overlap), chunk_size - 1))
    total = 0
    names, failures = [], []
    for f in files or []:
        # Gradio passes tempfile-like objects (or str subclasses) exposing .name.
        path = getattr(f, "name", f)
        try:
            text = load_file(path)
        except Exception as exc:  # UI boundary: report the bad file, keep going
            failures.append(f"{os.path.basename(path)}: {exc}")
            continue
        chunks = chunk_text(text, chunk_size, overlap)
        total += VEC.add(chunks, os.path.basename(path))
        names.append(path)
    status = f"Added {total} chunks"
    if failures:
        status += " | Failed: " + "; ".join(failures)
    return status, "\n".join(names) or "—", VEC.index.ntotal
def ui_clear():
    """Gradio handler: wipe the vector store and reset the status, file list and counter."""
    VEC.clear()
    # Prompt a collection now that the old index and metadata are unreferenced.
    gc.collect()
    return ("Index cleared", "—", 0)
def ui_chat(msg, history, top_k, temperature, max_tokens):
    """Gradio streaming handler: retrieve context for *msg* and stream the answer.

    Yields ``(chat history with the in-progress reply appended, "")``. The
    empty string clears the input textbox.
    """
    history = history or []
    if not msg or not msg.strip():
        # This is a generator: `return value` would discard the value, so the
        # unchanged state must be yielded explicitly.
        yield history, ""
        return
    # Slider values may arrive as floats.
    retrieved = VEC.search(msg, int(top_k))
    prompt = build_prompt(msg, history, retrieved)
    reply = ""
    for partial in generate_answer(prompt, temperature, int(max_tokens)):
        reply = partial
        yield history + [(msg, reply)], ""
    # Guarantees at least one update even if the model produced no tokens.
    yield history + [(msg, reply)], ""
# UI layout: chat on the left, document indexing and generation settings on the right.
with gr.Blocks() as demo:
    gr.Markdown("# 🔎📚 Research Paper RAG Chat (Phi-3-mini + Specter2)")
    with gr.Row():
        with gr.Column(scale=2):
            chatbot = gr.Chatbot(height=500)
            msg = gr.Textbox(placeholder="Ask a question...")
            with gr.Row():
                send = gr.Button("Send", variant="primary")
                clearc = gr.Button("Clear Chat")
        with gr.Column():
            files = gr.File(label="Upload PDFs/DOCX/TXT", file_types=[".pdf", ".docx", ".txt", ".md"], file_count="multiple")
            chunk_size = gr.Slider(200, 2000, 900, step=50, label="Chunk Size")
            overlap = gr.Slider(0, 400, 150, step=10, label="Overlap")
            ingest_btn = gr.Button("Index Documents")
            status = gr.Textbox(label="Status", value="—")
            added = gr.Textbox(label="Files", value="—")
            total = gr.Number(label="Total Chunks", value=VEC.index.ntotal)
            clear_idx = gr.Button("Clear Index", variant="stop")
            # `step` is keyword-only in gr.Slider(minimum, maximum, value, *, step=...);
            # passing it as a 4th positional argument raises TypeError at startup.
            top_k = gr.Slider(1, 10, 5, step=1, label="Top-K")
            temperature = gr.Slider(0.0, 1.5, 0.3, step=0.1, label="Temperature")
            max_tokens = gr.Slider(64, 2048, 512, step=64, label="Max New Tokens")
    ingest_btn.click(ui_ingest, [files, chunk_size, overlap], [status, added, total])
    clear_idx.click(ui_clear, [], [status, added, total])
    send.click(ui_chat, [msg, chatbot, top_k, temperature, max_tokens], [chatbot, msg])
    msg.submit(ui_chat, [msg, chatbot, top_k, temperature, max_tokens], [chatbot, msg])
    clearc.click(lambda: ([], ""), [], [chatbot, msg])
if __name__ == "__main__":
    # Honor a PORT environment variable when set; otherwise listen on 7860.
    port = int(os.getenv("PORT", 7860))
    demo.queue().launch(server_name="0.0.0.0", server_port=port)