Spaces: Runtime error
Update app.py
app.py CHANGED

@@ -1,7 +1,7 @@
 # app.py
 # RAG app for chatting with research papers (optimized for Hugging Face Spaces)
 
-import os, sys, subprocess, re, json, uuid, gc
+import os, sys, subprocess, re, json, uuid, gc
 from typing import List, Dict, Tuple
 
 # -----------------------------

@@ -16,15 +16,18 @@ def ensure(pkg, pip_name=None):
 ensure("torch")
 ensure("transformers")
 ensure("accelerate")
-ensure("bitsandbytes")
-ensure("faiss", "faiss-cpu")
 ensure("gradio")
+ensure("faiss", "faiss-cpu")
 ensure("sentence_transformers", "sentence-transformers")
 ensure("pypdf")
 ensure("docx", "python-docx")
 
 import torch
-from transformers import
+from transformers import (
+    AutoTokenizer,
+    AutoModelForCausalLM,
+    TextIteratorStreamer
+)
 from sentence_transformers import SentenceTransformer
 import faiss, gradio as gr
 from pypdf import PdfReader

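This hunk drops the hard bitsandbytes requirement and moves faiss below gradio in the bootstrap. The ensure() helper itself is defined just above the hunk and is not shown; a typical shape for such an import-or-pip-install helper, as a sketch only (the file's exact body may differ):

import importlib, subprocess, sys

def ensure(pkg, pip_name=None):
    # Import the module if it is already installed; otherwise pip-install
    # it (under its PyPI name, which may differ from the import name) and retry.
    try:
        importlib.import_module(pkg)
    except ImportError:
        subprocess.check_call([sys.executable, "-m", "pip", "install", pip_name or pkg])
        importlib.import_module(pkg)
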
@@ -37,9 +40,9 @@ os.makedirs(DATA_DIR, exist_ok=True)
 INDEX_PATH = os.path.join(DATA_DIR, "faiss.index")
 DOCS_PATH = os.path.join(DATA_DIR, "docs.jsonl")
 
-# Models
+# Default Models
 default_emb_model = "allenai/specter2_base"
-default_llm_model = "
+default_llm_model = "microsoft/Phi-3-mini-4k-instruct"
 
 EMB_MODEL_ID = os.environ.get("EMB_MODEL_ID", default_emb_model)
 LLM_MODEL_ID = os.environ.get("LLM_MODEL_ID", default_llm_model)

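Because both IDs go through os.environ.get, a deployment can swap models without editing app.py; on Spaces the variables would be set in the Space settings. A local sketch (the MiniLM ID below is only an illustrative stand-in, not part of this commit):

import os

# Must be set before app.py runs, since the IDs are read at import time.
os.environ["EMB_MODEL_ID"] = "sentence-transformers/all-MiniLM-L6-v2"  # illustrative
os.environ["LLM_MODEL_ID"] = "microsoft/Phi-3-mini-4k-instruct"
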
@@ -107,11 +110,13 @@ class VectorStore:
         embs = self._embed(chunks)
         faiss.normalize_L2(embs)
         self.index.add(embs)
+        recs = []
         for c in chunks:
             rec = {"id": str(uuid.uuid4()), "source": source, "text": c}
             self.meta.append(rec)
-
-
+            recs.append(json.dumps(rec))
+        with open(DOCS_PATH, "a", encoding="utf-8") as f:
+            f.write("\n".join(recs) + "\n")
         faiss.write_index(self.index, INDEX_PATH)
         return len(chunks)
 

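The added lines append one JSON record per chunk to docs.jsonl next to the FAISS index, so an ingested corpus survives a Space restart. The matching restore path lives in VectorStore.__init__, outside this hunk; a minimal sketch of what that cold-start load looks like (function and variable names here are assumptions, not the file's):

import json, os
import faiss

def load_store(index_path, docs_path):
    # Rebuild the index and the per-chunk metadata written by the code above.
    index, meta = None, []
    if os.path.exists(index_path):
        index = faiss.read_index(index_path)
    if os.path.exists(docs_path):
        with open(docs_path, encoding="utf-8") as f:
            meta = [json.loads(line) for line in f if line.strip()]
    return index, meta
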
@@ -134,14 +139,22 @@ print(f"[RAG] Loading embeddings: {EMB_MODEL_ID}")
 EMB = SentenceTransformer(EMB_MODEL_ID, device=DEVICE)
 VEC = VectorStore(EMB)
 
-print(f"[RAG] Loading LLM
-bnb_config =
+print(f"[RAG] Loading LLM: {LLM_MODEL_ID}")
+bnb_config = None
+if DEVICE == "cuda":
+    from transformers import BitsAndBytesConfig
+    bnb_config = BitsAndBytesConfig(
+        load_in_4bit=True,
+        bnb_4bit_use_double_quant=True,
+        bnb_4bit_quant_type="nf4"
+    )
+
 TOKENIZER = AutoTokenizer.from_pretrained(LLM_MODEL_ID, use_fast=True, trust_remote_code=True)
 LLM = AutoModelForCausalLM.from_pretrained(
     LLM_MODEL_ID,
     device_map="auto",
     quantization_config=bnb_config,
-    torch_dtype=torch.bfloat16 if torch.cuda.is_available() else torch.
+    torch_dtype=torch.bfloat16 if torch.cuda.is_available() else torch.float32,
     low_cpu_mem_usage=True,
     trust_remote_code=True,
 )

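The rewrite makes 4-bit quantization conditional: bitsandbytes kernels need CUDA, so on a CPU Space bnb_config stays None and from_pretrained loads plain float32 weights, which is also why the torch_dtype fallback fix in the same hunk matters. A quick check of what the nf4 path saves on a GPU box (get_memory_footprint is a standard transformers model method; the snippet assumes LLM was loaded as above, and the figures are rough):

if torch.cuda.is_available():
    # Phi-3-mini (3.8B params) lands roughly in the 2-3 GB range in nf4,
    # versus roughly 7.6 GB in bfloat16.
    print(f"[RAG] LLM footprint: {LLM.get_memory_footprint() / 2**30:.2f} GiB")
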
@@ -153,14 +166,30 @@ SYSTEM_PROMPT = "You are a helpful assistant. Use the provided context from rese
 
 def build_prompt(query, history, retrieved):
     ctx = "\n\n".join([f"[{i+1}] {m['text']}" for i, (_, m) in enumerate(retrieved)])
-
-
+    # Try to use chat template if available
+    if hasattr(TOKENIZER, "apply_chat_template"):
+        messages = [{"role": "system", "content": SYSTEM_PROMPT + "\nContext:\n" + ctx}]
+        for u, a in history[-3:]:
+            messages.append({"role": "user", "content": u})
+            messages.append({"role": "assistant", "content": a})
+        messages.append({"role": "user", "content": query})
+        return TOKENIZER.apply_chat_template(messages, tokenize=False)
+    else:
+        # Fallback manual prompt
+        hist = "".join([f"<user>{u}</user><assistant>{a}</assistant>" for u, a in history[-3:]])
+        return f"<system>{SYSTEM_PROMPT}\nContext:\n{ctx}</system>{hist}<user>{query}</user><assistant>"
 
 @torch.inference_mode()
 def generate_answer(prompt, temperature=0.3, max_new_tokens=512):
     streamer = TextIteratorStreamer(TOKENIZER, skip_prompt=True, skip_special_tokens=True)
     inputs = TOKENIZER([prompt], return_tensors="pt").to(LLM.device)
-    kwargs = dict(
+    kwargs = dict(
+        **inputs,
+        max_new_tokens=max_new_tokens,
+        temperature=temperature,
+        do_sample=temperature > 0,
+        streamer=streamer
+    )
     import threading
     t = threading.Thread(target=LLM.generate, kwargs=kwargs)
     t.start()

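On recent transformers releases every tokenizer exposes apply_chat_template, so the hasattr branch mostly guards older installs. One thing worth noting: build_prompt calls the template without add_generation_prompt=True, and many templates (Phi-3's included) only append the assistant header when that flag is set. A small sketch of what the template produces; the output shape shown is approximate:

from transformers import AutoTokenizer

tok = AutoTokenizer.from_pretrained("microsoft/Phi-3-mini-4k-instruct", trust_remote_code=True)
messages = [
    {"role": "system", "content": "You are a helpful assistant.\nContext:\n[1] ..."},
    {"role": "user", "content": "What does the paper conclude?"},
]
# Renders the conversation to a single prompt string instead of token IDs.
print(tok.apply_chat_template(messages, tokenize=False, add_generation_prompt=True))
# -> <|system|> ... <|end|> <|user|> ... <|end|> <|assistant|>   (roughly)
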
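generate() blocks until decoding finishes, so the hunk runs it on a background thread; the lines below the hunk presumably drain the TextIteratorStreamer from the main thread. The same pattern in self-contained form, using a tiny test checkpoint so it runs anywhere (the model ID is just a lightweight stand-in):

import threading
from transformers import AutoModelForCausalLM, AutoTokenizer, TextIteratorStreamer

tok = AutoTokenizer.from_pretrained("sshleifer/tiny-gpt2")
model = AutoModelForCausalLM.from_pretrained("sshleifer/tiny-gpt2")

streamer = TextIteratorStreamer(tok, skip_prompt=True, skip_special_tokens=True)
inputs = tok(["Hello"], return_tensors="pt")
kwargs = dict(**inputs, max_new_tokens=20, do_sample=False, streamer=streamer)

# generate() runs in the background; the streamer yields decoded text pieces
# as they are produced, which is what lets a UI callback stream partial replies.
t = threading.Thread(target=model.generate, kwargs=kwargs)
t.start()
text = ""
for piece in streamer:
    text += piece
t.join()
print(text)
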
@@ -200,7 +229,7 @@ def ui_chat(msg, history, top_k, temperature, max_tokens):
     yield history + [(msg, reply)], ""
 
 with gr.Blocks() as demo:
-    gr.Markdown("# Research Paper RAG Chat (
+    gr.Markdown("# Research Paper RAG Chat (Phi-3-mini + Specter2)")
     with gr.Row():
         with gr.Column(scale=2):
             chatbot = gr.Chatbot(height=500)
|