Spaces:
Runtime error
Runtime error
# app.py | |
# RAG app for chatting with research papers (optimized for Hugging Face Spaces) | |
import os, sys, subprocess, re, json, uuid, gc | |
from typing import List, Dict, Tuple | |
# ----------------------------- | |
# Auto-install deps if missing | |
# ----------------------------- | |
def ensure(pkg, pip_name=None):
    """Make sure module *pkg* is importable, pip-installing it if it is missing.

    Args:
        pkg: importable module name (e.g. ``"docx"``).
        pip_name: distribution name to pass to ``pip install`` when it differs
            from *pkg* (e.g. ``"python-docx"``); defaults to *pkg*.

    Raises:
        subprocess.CalledProcessError: if ``pip install`` exits non-zero.
    """
    import importlib

    try:
        importlib.import_module(pkg)
    except ImportError:
        subprocess.check_call([sys.executable, "-m", "pip", "install", pip_name or pkg])
        # Import finders cache directory contents; invalidate them so the
        # package installed just now is visible to later imports in this process.
        importlib.invalidate_caches()
# Third-party packages the app needs, as (import name, pip distribution name).
# A None distribution name means pip installs it under the module's own name.
for _import_name, _dist_name in (
    ("torch", None),
    ("transformers", None),
    ("accelerate", None),
    ("gradio", None),
    ("faiss", "faiss-cpu"),
    ("sentence_transformers", "sentence-transformers"),
    ("pypdf", None),
    ("docx", "python-docx"),
):
    ensure(_import_name, _dist_name)
del _import_name, _dist_name
import torch | |
from transformers import ( | |
AutoTokenizer, | |
AutoModelForCausalLM, | |
TextIteratorStreamer | |
) | |
from sentence_transformers import SentenceTransformer | |
import faiss, gradio as gr | |
from pypdf import PdfReader | |
# ----------------------------- | |
# Config | |
# ----------------------------- | |
DATA_DIR = "rag_data"
INDEX_PATH = os.path.join(DATA_DIR, "faiss.index")  # persisted FAISS index
DOCS_PATH = os.path.join(DATA_DIR, "docs.jsonl")  # chunk metadata, one JSON record per line
os.makedirs(DATA_DIR, exist_ok=True)

# Default models; each can be overridden through the EMB_MODEL_ID /
# LLM_MODEL_ID environment variables.
default_emb_model = "allenai/specter2_base"
default_llm_model = "microsoft/Phi-3-mini-4k-instruct"
EMB_MODEL_ID = os.getenv("EMB_MODEL_ID", default_emb_model)
LLM_MODEL_ID = os.getenv("LLM_MODEL_ID", default_llm_model)

# Use the GPU whenever torch can see one.
if torch.cuda.is_available():
    DEVICE = "cuda"
else:
    DEVICE = "cpu"
# ----------------------------- | |
# File loaders | |
# ----------------------------- | |
def read_txt(path):
    """Return the contents of the text file at *path*.

    The file is decoded as UTF-8; undecodable bytes are dropped rather than
    raising. The file handle is closed before returning.
    """
    with open(path, "r", encoding="utf-8", errors="ignore") as fh:
        return fh.read()
def read_pdf(path):
    """Return the text of every page of the PDF at *path*, joined by newlines.

    Pages for which extraction yields nothing contribute an empty string.
    """
    page_texts = []
    for page in PdfReader(path).pages:
        page_texts.append(page.extract_text() or "")
    return "\n".join(page_texts)
def read_docx(path):
    """Return the paragraph texts of the .docx file at *path*, one per line."""
    import docx

    document = docx.Document(path)
    return "\n".join(paragraph.text for paragraph in document.paragraphs)
def load_file(path):
    """Read *path* as text, choosing a reader by its (case-insensitive) extension.

    ``.pdf`` goes to read_pdf and ``.docx`` to read_docx; ``.txt``, ``.md`` and
    any other extension are read as plain text with read_txt.
    """
    _, extension = os.path.splitext(path)
    readers = {
        ".txt": read_txt,
        ".md": read_txt,
        ".pdf": read_pdf,
        ".docx": read_docx,
    }
    reader = readers.get(extension.lower(), read_txt)
    return reader(path)
# ----------------------------- | |
# Chunking | |
# ----------------------------- | |
def normalize_ws(s: str):
    """Collapse every run of whitespace in *s* into one space and trim both ends."""
    return re.sub(r"\s+", " ", s).strip()


def chunk_text(text, chunk_size=900, overlap=150):
    """Split *text* into overlapping fixed-width character windows.

    Whitespace is normalized first. Consecutive chunks start
    ``chunk_size - overlap`` characters apart, and chunking stops at the first
    window that reaches the end of the text, so no chunk is merely a suffix
    already contained in its predecessor.

    Args:
        text: raw document text.
        chunk_size: maximum characters per chunk; coerced to int because UI
            sliders may deliver floats.
        overlap: characters shared by consecutive chunks; coerced to int and
            must satisfy ``0 <= overlap < chunk_size``.

    Returns:
        List of chunk strings; empty if *text* has no non-whitespace content.

    Raises:
        ValueError: if ``chunk_size <= 0`` or *overlap* is outside
            ``[0, chunk_size)``. Previously such values either crashed with an
            opaque ``range()`` error or silently produced no chunks.
    """
    chunk_size, overlap = int(chunk_size), int(overlap)
    if chunk_size <= 0:
        raise ValueError(f"chunk_size must be positive, got {chunk_size}")
    if not 0 <= overlap < chunk_size:
        raise ValueError(
            f"overlap must be in [0, chunk_size); got overlap={overlap}, chunk_size={chunk_size}"
        )
    text = normalize_ws(text)
    step = chunk_size - overlap
    chunks = []
    start = 0
    while start < len(text):
        chunks.append(text[start:start + chunk_size])
        if start + chunk_size >= len(text):
            break  # this window already reaches the end of the text
        start += step
    return chunks
# ----------------------------- | |
# VectorStore | |
# ----------------------------- | |
class VectorStore:
    """Flat inner-product FAISS index plus per-row chunk metadata, persisted to disk.

    Index rows and ``self.meta`` records share one order: ``self.meta[i]``
    describes index row ``i``. The index is written to INDEX_PATH, and the
    metadata to DOCS_PATH as JSON lines with keys "id", "source" and "text".
    """

    def __init__(self, emb_model):
        """Attach *emb_model* and load the persisted store when it is usable.

        A persisted index whose dimension differs from the embedding model's
        (e.g. EMB_MODEL_ID changed), or whose row count differs from the number
        of metadata records, is discarded and its files removed, because
        searching it would return wrong records or raise IndexError.
        """
        self.emb_model = emb_model
        self.dim = emb_model.get_sentence_embedding_dimension()
        if not (os.path.exists(INDEX_PATH) and self._load()):
            # Start empty. clear() also deletes orphaned/stale files so newly
            # added rows cannot end up misaligned with leftover metadata.
            self.clear()

    def _load(self):
        """Load INDEX_PATH and DOCS_PATH into self; return False if they are unusable."""
        index = faiss.read_index(INDEX_PATH)
        meta = []
        if os.path.exists(DOCS_PATH):
            with open(DOCS_PATH, "r", encoding="utf-8") as fh:
                meta = [json.loads(line) for line in fh if line.strip()]
        if index.d != self.dim or index.ntotal != len(meta):
            print(
                f"[RAG] Discarding persisted index (dim={index.d}, rows={index.ntotal}, "
                f"records={len(meta)}; expected dim={self.dim})"
            )
            return False
        self.index, self.meta = index, meta
        return True

    def _embed(self, texts):
        """Encode *texts* as a float32 numpy array of L2-normalized embeddings."""
        embs = self.emb_model.encode(texts, convert_to_tensor=True, normalize_embeddings=True)
        # FAISS only accepts float32 input, whatever dtype the model runs in.
        return embs.float().cpu().numpy()

    def add(self, chunks, source):
        """Embed *chunks*, append them under *source*, and persist index and metadata.

        Returns:
            Number of chunks added (0 when *chunks* is empty).
        """
        if not chunks:
            return 0
        embs = self._embed(chunks)
        faiss.normalize_L2(embs)  # rows are already unit-length; kept as a safeguard
        self.index.add(embs)
        recs = [{"id": str(uuid.uuid4()), "source": source, "text": c} for c in chunks]
        self.meta.extend(recs)
        with open(DOCS_PATH, "a", encoding="utf-8") as f:
            f.write("".join(json.dumps(rec) + "\n" for rec in recs))
        faiss.write_index(self.index, INDEX_PATH)
        return len(chunks)

    def search(self, query, k=5):
        """Return up to *k* ``(score, record)`` pairs most similar to *query*.

        Scores are inner products of normalized vectors. *k* is coerced to int
        (UI sliders may send floats) and capped at the number of indexed rows;
        an empty index yields ``[]``.
        """
        k = min(int(k), self.index.ntotal)
        if k <= 0:
            return []
        q = self._embed([query])
        faiss.normalize_L2(q)
        scores, ids = self.index.search(q, k)
        return [(float(s), self.meta[i]) for s, i in zip(scores[0], ids[0]) if i != -1]

    def clear(self):
        """Reset to an empty index and delete both persisted files if present."""
        self.index = faiss.IndexFlatIP(self.dim)
        self.meta = []
        for path in (INDEX_PATH, DOCS_PATH):
            if os.path.exists(path):
                os.remove(path)
# ----------------------------- | |
# Load models | |
# ----------------------------- | |
print(f"[RAG] Loading embeddings: {EMB_MODEL_ID}")
EMB = SentenceTransformer(EMB_MODEL_ID, device=DEVICE)
VEC = VectorStore(EMB)

print(f"[RAG] Loading LLM: {LLM_MODEL_ID}")

# Executing Python code shipped inside a model repo is opt-in (TRUST_REMOTE_CODE=1).
# Phi-3 is implemented natively by current transformers releases, and its
# repo-side modeling code has been reported to break against newer transformers
# versions. NOTE(review): models lacking native support need TRUST_REMOTE_CODE=1.
TRUST_REMOTE_CODE = os.environ.get("TRUST_REMOTE_CODE", "").strip().lower() in ("1", "true", "yes")

bnb_config = None
if DEVICE == "cuda":
    import importlib.util

    # 4-bit loading requires the bitsandbytes package, which is not among the
    # auto-installed dependencies; fall back to an unquantized load instead of
    # failing inside from_pretrained.
    if importlib.util.find_spec("bitsandbytes") is not None:
        from transformers import BitsAndBytesConfig

        bnb_config = BitsAndBytesConfig(
            load_in_4bit=True,
            bnb_4bit_use_double_quant=True,
            bnb_4bit_quant_type="nf4",
        )
    else:
        print("[RAG] bitsandbytes not installed; loading the LLM without 4-bit quantization")

TOKENIZER = AutoTokenizer.from_pretrained(
    LLM_MODEL_ID, use_fast=True, trust_remote_code=TRUST_REMOTE_CODE
)
LLM = AutoModelForCausalLM.from_pretrained(
    LLM_MODEL_ID,
    device_map="auto",
    quantization_config=bnb_config,
    torch_dtype=torch.bfloat16 if torch.cuda.is_available() else torch.float32,
    low_cpu_mem_usage=True,
    trust_remote_code=TRUST_REMOTE_CODE,
)
# ----------------------------- | |
# Prompt + Generate | |
# ----------------------------- | |
SYSTEM_PROMPT = "You are a helpful assistant. Use the provided context from research papers to answer questions."


def build_prompt(query, history, retrieved):
    """Assemble the LLM prompt from retrieved context, recent chat history and the query.

    Args:
        query: the user's new question.
        history: prior ``(user, assistant)`` turn pairs as kept by the tuple-style
            gr.Chatbot (may be None); only the last three turns are used.
        retrieved: ``(score, record)`` pairs as returned by VectorStore.search;
            each record's "text" is included as a numbered context passage.

    Returns:
        Prompt text ending with the cue for the assistant's reply.
    """
    ctx = "\n\n".join(f"[{i}] {rec['text']}" for i, (_, rec) in enumerate(retrieved, start=1))
    recent = list(history or [])[-3:]
    # Every recent tokenizer has apply_chat_template, so check for an actual
    # template; without one, newer transformers versions raise.
    if getattr(TOKENIZER, "chat_template", None):
        messages = [{"role": "system", "content": SYSTEM_PROMPT + "\nContext:\n" + ctx}]
        for u, a in recent:
            messages.append({"role": "user", "content": u})
            messages.append({"role": "assistant", "content": a or ""})
        messages.append({"role": "user", "content": query})
        # add_generation_prompt appends the assistant-turn header; without it the
        # prompt ends after the user turn (some templates even append EOS there).
        return TOKENIZER.apply_chat_template(messages, tokenize=False, add_generation_prompt=True)
    # Fallback manual prompt for tokenizers without a chat template.
    hist = "".join(f"<user>{u}</user><assistant>{a}</assistant>" for u, a in recent)
    return f"<system>{SYSTEM_PROMPT}\nContext:\n{ctx}</system>{hist}<user>{query}</user><assistant>"
def generate_answer(prompt, temperature=0.3, max_new_tokens=512):
    """Stream the LLM's continuation of *prompt*.

    Generation runs on a background thread that feeds a TextIteratorStreamer;
    this generator yields the accumulated reply after every decoded piece.

    Args:
        prompt: full prompt text (e.g. from build_prompt).
        temperature: sampling temperature; ``<= 0`` selects greedy decoding.
        max_new_tokens: cap on generated tokens; coerced to int because UI
            sliders may send floats.

    Yields:
        The reply text generated so far.

    Raises:
        The exception raised by ``LLM.generate`` on the worker thread, if any,
        re-raised once the stream has ended.
    """
    import threading

    streamer = TextIteratorStreamer(TOKENIZER, skip_prompt=True, skip_special_tokens=True)
    # Some chat templates already emit the BOS token; letting the tokenizer add
    # special tokens again would duplicate it.
    bos = getattr(TOKENIZER, "bos_token", None)
    add_special = not (bos and prompt.startswith(bos))
    inputs = TOKENIZER([prompt], return_tensors="pt", add_special_tokens=add_special).to(LLM.device)

    temperature = float(temperature)
    do_sample = temperature > 0
    gen_kwargs = dict(
        **inputs,
        max_new_tokens=int(max_new_tokens),
        do_sample=do_sample,
        streamer=streamer,
    )
    if do_sample:
        # temperature only matters (and only avoids a warning) when sampling.
        gen_kwargs["temperature"] = temperature

    failures = []

    def _worker():
        try:
            LLM.generate(**gen_kwargs)
        except Exception as exc:  # surfaced to the caller after the stream ends
            failures.append(exc)
            # Unblock the consumer loop below, which would otherwise wait forever.
            streamer.end()

    worker = threading.Thread(target=_worker, daemon=True)
    worker.start()
    out = ""
    for piece in streamer:
        out += piece
        yield out
    worker.join()
    if failures:
        raise failures[0]
# ----------------------------- | |
# Gradio UI | |
# ----------------------------- | |
def ui_ingest(files, chunk_size, overlap):
    """Chunk, embed and index uploaded files for the "Index Documents" button.

    Args:
        files: value of the multi-file gr.File component (may be None); entries
            may be path strings or objects exposing the path as ``.name``.
        chunk_size: characters per chunk, forwarded to chunk_text.
        overlap: characters shared by consecutive chunks, forwarded to chunk_text.

    Returns:
        ``(status text, newline-joined paths of indexed files or the "β"
        placeholder, number of vectors now in the index)``. Files that fail to
        load or index are listed in the status text instead of aborting the batch.
    """
    chunk_size, overlap = int(chunk_size), int(overlap)  # sliders may send floats
    total = 0
    names = []
    failures = []
    for f in files or []:
        path = getattr(f, "name", f)
        try:
            chunks = chunk_text(load_file(path), chunk_size, overlap)
            n = VEC.add(chunks, os.path.basename(path))
        except Exception as exc:  # UI boundary: report this file, keep indexing the rest
            print(f"[RAG] Failed to index {path}: {exc}")
            failures.append(f"{os.path.basename(path)} ({exc})")
            continue
        total += n
        names.append(path)
    status = f"Added {total} chunks"
    if failures:
        status += "; failed: " + "; ".join(failures)
    return status, "\n".join(names) or "β", VEC.index.ntotal
def ui_clear():
    """Wipe the vector store and reset the status, file-list and count widgets."""
    VEC.clear()
    gc.collect()  # presumably to reclaim the discarded index promptly
    status_text, file_list, chunk_count = "Index cleared", "β", 0
    return status_text, file_list, chunk_count
def ui_chat(msg, history, top_k, temperature, max_tokens):
    """Stream an answer to *msg* grounded in the top-k retrieved chunks.

    Args:
        msg: the user's message (may be None or blank).
        history: current tuple-style chat history (may be None).
        top_k: number of chunks to retrieve (slider value; coerced to int).
        temperature: sampling temperature forwarded to generate_answer.
        max_tokens: generation cap forwarded to generate_answer (coerced to int).

    Yields:
        ``(updated history, "")`` pairs; the empty string clears the input box.
    """
    history = history or []
    if not msg or not msg.strip():
        # This function is a generator, so `return value` would be discarded;
        # yield the unchanged state explicitly instead.
        yield history, ""
        return
    retrieved = VEC.search(msg, int(top_k))
    prompt = build_prompt(msg, history, retrieved)
    reply = ""
    for partial in generate_answer(prompt, temperature, int(max_tokens)):
        reply = partial
        yield history + [(msg, reply)], ""
    yield history + [(msg, reply)], ""
# NOTE(review): the "ππ" in the title and the "β" placeholders look like
# mojibake (probably an emoji and an em dash) — confirm the intended text.
with gr.Blocks() as demo:
    gr.Markdown("# ππ Research Paper RAG Chat (Phi-3-mini + Specter2)")
    with gr.Row():
        with gr.Column(scale=2):
            chatbot = gr.Chatbot(height=500)
            msg = gr.Textbox(placeholder="Ask a question...")
            with gr.Row():
                send = gr.Button("Send", variant="primary")
                clearc = gr.Button("Clear Chat")
        with gr.Column():
            files = gr.File(
                label="Upload PDFs/DOCX/TXT",
                file_types=[".pdf", ".docx", ".txt", ".md"],
                file_count="multiple",
            )
            chunk_size = gr.Slider(200, 2000, 900, step=50, label="Chunk Size")
            overlap = gr.Slider(0, 400, 150, step=10, label="Overlap")
            ingest_btn = gr.Button("Index Documents")
            status = gr.Textbox(label="Status", value="β")
            added = gr.Textbox(label="Files", value="β")
            total = gr.Number(label="Total Chunks", value=VEC.index.ntotal)
            clear_idx = gr.Button("Clear Index", variant="stop")
            # `step` is keyword-only in gr.Slider; passing it as a fourth
            # positional argument raises TypeError when the UI is built.
            top_k = gr.Slider(1, 10, 5, step=1, label="Top-K")
            temperature = gr.Slider(0.0, 1.5, 0.3, step=0.1, label="Temperature")
            max_tokens = gr.Slider(64, 2048, 512, step=64, label="Max New Tokens")
    ingest_btn.click(ui_ingest, [files, chunk_size, overlap], [status, added, total])
    clear_idx.click(ui_clear, [], [status, added, total])
    send.click(ui_chat, [msg, chatbot, top_k, temperature, max_tokens], [chatbot, msg])
    msg.submit(ui_chat, [msg, chatbot, top_k, temperature, max_tokens], [chatbot, msg])
    clearc.click(lambda: ([], ""), [], [chatbot, msg])
if __name__ == "__main__":
    # Listen on all interfaces; the port comes from $PORT when set, else 7860.
    demo.queue().launch(
        server_name="0.0.0.0",
        server_port=int(os.environ.get("PORT", 7860)),
    )