Praga-6000 committed on
Commit 82d5a2f · verified · 1 Parent(s): 857c266

Update app.py

Files changed (1)
  1. app.py +82 -209
app.py CHANGED
@@ -1,219 +1,92 @@
- # app.py
- import os
- import io
- import json
- import requests
- from typing import List, Dict, Optional
- import numpy as np
- import faiss
- import pathlib
- import hashlib
- import time
-
  import gradio as gr
- import torch
- from transformers import AutoTokenizer, AutoModelForCausalLM
  from sentence_transformers import SentenceTransformer

- # PDF lib (fallback)
- import PyPDF2

- # ---------- CONFIG ----------
- EMBED_MODEL_NAME = "sentence-transformers/all-mpnet-base-v2" # CPU-friendly
- LLM_MODEL = "microsoft/Phi-3-mini-4k-instruct" # chosen model
- DATA_DIR = "/tmp/rag_data" # persistent within Space runtime
- os.makedirs(DATA_DIR, exist_ok=True)
-
- # ---------- DEVICE ----------
- device = "cuda" if torch.cuda.is_available() else "cpu"
-
- # ---------- LOAD MODELS ----------
- print("Loading embedding model...")
- embedder = SentenceTransformer(EMBED_MODEL_NAME)
- embed_dim = embedder.get_sentence_embedding_dimension()
- print(f"Embedding dim: {embed_dim}")
-
- print("Loading tokenizer and LLM (may take a while)...")
- tokenizer = AutoTokenizer.from_pretrained(LLM_MODEL, use_fast=True)
- model = AutoModelForCausalLM.from_pretrained(LLM_MODEL, trust_remote_code=True)
- model.to(device)
- model.eval()
-
- # ---------- UTILITIES ----------
- def sha256_text(s: str) -> str:
-     return hashlib.sha256(s.encode("utf-8")).hexdigest()
-
- def extract_text_from_pdf_url(url: str) -> Optional[str]:
-     try:
-         resp = requests.get(url, timeout=20)
-         resp.raise_for_status()
-         pdf_bytes = io.BytesIO(resp.content)
-         reader = PyPDF2.PdfReader(pdf_bytes)
-         text_parts = []
-         for p in reader.pages:
-             page_text = p.extract_text()
-             if page_text:
-                 text_parts.append(page_text)
-         if not text_parts:
-             return None
-         return "\n".join(text_parts)
-     except Exception as e:
-         print("PDF extraction error:", e)
-         return None
-
- def chunk_text_token_aware(text: str, max_tokens=800, overlap_tokens=128):
-     # approximate by splitting on sentences/words, then measuring token length with tokenizer
      words = text.split()
      chunks = []
-     i = 0
-     while i < len(words):
-         # grow until ~max_tokens
-         j = min(len(words), i + max_tokens)
-         chunk = " ".join(words[i:j])
-         # if too long by tokens, shrink
-         enc = tokenizer.encode(chunk, add_special_tokens=False)
-         if len(enc) > max_tokens:
-             # binary shrink loop
-             high = j
-             low = i
-             while high - low > 1:
-                 mid = (high + low) // 2
-                 c = " ".join(words[i:mid])
-                 if len(tokenizer.encode(c, add_special_tokens=False)) <= max_tokens:
-                     low = mid
-                 else:
-                     high = mid
-             chunk = " ".join(words[i:low])
-             j = low
          chunks.append(chunk)
-         # advance by chunk_size - overlap
-         i = max(i + max(1, len(tokenizer.encode(chunk, add_special_tokens=False)) - overlap_tokens), j)
      return chunks

- def build_or_load_index(paper_id: str, chunks: List[str]):
-     """
-     If index exists on disk for paper_id, load it. Otherwise build FAISS index from chunks.
-     Returns (index, chunks)
-     """
-     safe_id = sha256_text(paper_id)
-     index_path = os.path.join(DATA_DIR, f"{safe_id}.index")
-     meta_path = os.path.join(DATA_DIR, f"{safe_id}.chunks.json")
-     if os.path.exists(index_path) and os.path.exists(meta_path):
-         # load
-         print("Loading existing index:", index_path)
-         index = faiss.read_index(index_path)
-         with open(meta_path, "r", encoding="utf-8") as f:
-             stored_chunks = json.load(f)
-         return index, stored_chunks
-
-     # build embeddings
-     print("Encoding chunks:", len(chunks))
-     embeddings = embedder.encode(chunks, show_progress_bar=False, convert_to_numpy=True)
-     # normalize for cosine similarity (IndexFlatIP)
-     norms = np.linalg.norm(embeddings, axis=1, keepdims=True)
-     norms[norms==0] = 1e-10
-     embeddings = embeddings / norms
-
-     # create index
-     index = faiss.IndexFlatIP(embeddings.shape[1])
-     index.add(embeddings.astype('float32'))
-     # persist
-     faiss.write_index(index, index_path)
-     with open(meta_path, "w", encoding="utf-8") as f:
-         json.dump(chunks, f)
-     print("Index written:", index_path)
-     return index, chunks
-
- def retrieve_relevant(index, chunks, query, k=4):
      q_emb = embedder.encode([query], convert_to_numpy=True)
-     q_emb = q_emb / (np.linalg.norm(q_emb, axis=1, keepdims=True) + 1e-10)
-     D, I = index.search(q_emb.astype('float32'), k)
-     results = []
-     for idx in I[0]:
-         if idx < len(chunks):
-             results.append(chunks[idx])
-     return results
-
- def generate_answer(question: str, context_chunks: List[str], chat_history: List[Dict]):
-     # Build a safe prompt: limited context
-     context = "\n\n---\n\n".join(context_chunks)
-     # Keep last few messages
-     history_text = ""
-     for msg in (chat_history or [])[-6:]:
-         role = "User" if msg.get("role")=="user" else "Assistant"
-         history_text += f"{role}: {msg.get('content')}\n"
-     prompt = f"""You are a helpful research assistant. Use the provided paper content to answer the user's question concisely and cite which chunk the answer came from when relevant.
-
- Paper Context:
- {context}
-
- Conversation History:
- {history_text}
-
- User: {question}
-
- Assistant:"""
-     # tokenize & truncate if needed
-     inputs = tokenizer(prompt, return_tensors="pt", truncation=True, max_length=tokenizer.model_max_length).to(device)
-     gen = model.generate(
-         **inputs,
-         max_new_tokens=256,
-         temperature=0.0,
-         do_sample=False,
-         eos_token_id=tokenizer.eos_token_id,
-         pad_token_id=tokenizer.eos_token_id
-     )
-     out = tokenizer.decode(gen[0], skip_special_tokens=True)
-     # post-process to return assistant text only
-     if "Assistant:" in out:
-         out = out.split("Assistant:")[-1].strip()
-     return out
-
- # ---------- MAIN PROCESS ----------
- def process_paper_and_answer(paper_id, title, abstract, url, question, chat_history):
-     # derive unique id (paper_id or url)
-     pid = paper_id or url or title
-     if not pid:
-         pid = str(time.time())
-     # Try to load or extract text
-     full_text = None
-     if url and url.lower().endswith(".pdf"):
-         full_text = extract_text_from_pdf_url(url)
-     if not full_text:
-         full_text = abstract or title or "No content"
-     # chunk
-     chunks = chunk_text_token_aware(full_text, max_tokens=800, overlap_tokens=128)
-     # build or load index (persisted)
-     index, stored_chunks = build_or_load_index(pid, chunks)
-     # retrieve
-     relevant = retrieve_relevant(index, stored_chunks, question, k=4)
-     # generate
-     answer = generate_answer(question, relevant, chat_history)
-     return answer
-
- # ---------- GRADIO API ----------
- def chat_api(paper_id, paper_title, paper_abstract, paper_url, question, chat_history_json):
-     # chat_history_json might be None or a JSON string
-     chat_history = chat_history_json or []
-     try:
-         return process_paper_and_answer(paper_id, paper_title, paper_abstract, paper_url, question, chat_history)
-     except Exception as e:
-         print("Error:", e)
-         return "Sorry, an internal error occurred."
-
- iface = gr.Interface(
-     fn=chat_api,
-     inputs=[
-         gr.Textbox(label="Paper ID", lines=1),
-         gr.Textbox(label="Paper Title", lines=1),
-         gr.Textbox(label="Paper Abstract", lines=4),
-         gr.Textbox(label="Paper URL", lines=1),
-         gr.Textbox(label="Question", lines=2),
-         gr.JSON(label="Chat History")
-     ],
-     outputs=gr.Textbox(label="Answer"),
-     title="Paper Chat RAG (Space)",
-     description="Upload a paper URL (PDF) or paste abstract and ask questions."
- )
-
- app = iface.app # expose as API in Space

  import gradio as gr
+ import PyPDF2
  from sentence_transformers import SentenceTransformer
+ import faiss
+ import numpy as np
+ from transformers import pipeline

+ # Load models (lightweight for CPU)
+ embedder = SentenceTransformer("sentence-transformers/all-MiniLM-L6-v2")
+ qa_pipeline = pipeline("question-answering", model="deepset/roberta-base-squad2")
+
+ # Initialize FAISS index (for semantic search)
+ embedding_size = 384 # MiniLM-L6-v2 output dimension
+ index = faiss.IndexFlatL2(embedding_size)

+ # Storage for documents and embeddings
+ doc_chunks = []
+ doc_embeddings = None
+
+
+ def extract_text_from_pdf(file):
+     """Extract raw text from uploaded PDF."""
+     reader = PyPDF2.PdfReader(file)
+     text = ""
+     for page in reader.pages:
+         text += (page.extract_text() or "") + " "  # extract_text() may return None for image-only pages
+     return text
+
+
+ def chunk_text(text, chunk_size=300, overlap=50):
+     """Split text into overlapping chunks."""
      words = text.split()
      chunks = []
+     for i in range(0, len(words), chunk_size - overlap):
+         chunk = " ".join(words[i:i + chunk_size])
          chunks.append(chunk)
      return chunks

+
+ def build_index(pdf_file):
+     """Process PDF, create embeddings, and store in FAISS."""
+     global doc_chunks, doc_embeddings, index
+
+     # Extract + chunk
+     text = extract_text_from_pdf(pdf_file)
+     doc_chunks = chunk_text(text)
+
+     # Encode chunks
+     doc_embeddings = embedder.encode(doc_chunks, convert_to_numpy=True)
+
+     # Reset and add to FAISS
+     index = faiss.IndexFlatL2(embedding_size)
+     index.add(doc_embeddings)
+
+     return f"PDF processed! {len(doc_chunks)} chunks indexed."
+
+
+ def answer_question(query, top_k=3):
+     """Retrieve relevant chunks and answer user query."""
+     if doc_embeddings is None:
+         return "Please upload and process a PDF first."
+
+     # Embed question
      q_emb = embedder.encode([query], convert_to_numpy=True)
+     distances, indices = index.search(q_emb, top_k)
+
+     # Gather top chunks
+     context = " ".join([doc_chunks[i] for i in indices[0]])
+
+     # Run QA pipeline
+     result = qa_pipeline(question=query, context=context)
+     return result["answer"]
+
+
+ # Gradio UI
+ with gr.Blocks() as demo:
+     gr.Markdown("# 📚 PDF Q&A App\nUpload a PDF and ask questions about it!")
+
+     with gr.Row():
+         pdf_input = gr.File(label="Upload PDF", type="filepath")
+         process_btn = gr.Button("Process PDF")
+
+     status = gr.Textbox(label="Status", interactive=False)
+
+     with gr.Row():
+         question = gr.Textbox(label="Ask a Question")
+         answer = gr.Textbox(label="Answer", interactive=False)
+
+     process_btn.click(build_index, inputs=pdf_input, outputs=status)
+     question.submit(answer_question, inputs=question, outputs=answer)
+
+ demo.launch()