ashishninehertz commited on
Commit
3a2241a
·
verified ·
1 Parent(s): df36133

Upload 9 files

Browse files
Files changed (10) hide show
  1. .gitattributes +1 -0
  2. .gitignore +16 -0
  3. __init__.py +0 -0
  4. app.py +7 -0
  5. chat_history.db +3 -0
  6. chatbot.py +446 -0
  7. database.py +20 -0
  8. main.py +65 -0
  9. models.py +8 -0
  10. requirements.txt +11 -0
.gitattributes CHANGED
@@ -33,3 +33,4 @@ saved_model/**/* filter=lfs diff=lfs merge=lfs -text
33
  *.zip filter=lfs diff=lfs merge=lfs -text
34
  *.zst filter=lfs diff=lfs merge=lfs -text
35
  *tfevents* filter=lfs diff=lfs merge=lfs -text
 
 
33
  *.zip filter=lfs diff=lfs merge=lfs -text
34
  *.zst filter=lfs diff=lfs merge=lfs -text
35
  *tfevents* filter=lfs diff=lfs merge=lfs -text
36
+ chat_history.db filter=lfs diff=lfs merge=lfs -text
.gitignore ADDED
@@ -0,0 +1,16 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # Ignore environment files
2
+ .env
3
+
4
+ # Ignore Python virtual environments
5
+ venv/
6
+ env/
7
+ __pycache__/
8
+ *.pyc
9
+
10
+ # Ignore logs, OS metadata, local databases, and IDE/config files
11
+ *.log
12
+ .DS_Store
13
+ *.sqlite3
14
+ *.db
15
+ .vscode/
16
+ .idea/
__init__.py ADDED
File without changes
app.py ADDED
@@ -0,0 +1,7 @@
 
 
 
 
 
 
 
 
1
+ import gradio as gr
2
+
3
def greet(name):
    """Return the greeting ``Hello <name>!!`` for the given name string."""
    pieces = ("Hello ", name, "!!")
    return "".join(pieces)
5
+
6
# Expose greet through a Gradio interface with one text input and one text output.
demo = gr.Interface(fn=greet, inputs="text", outputs="text")

# Launch the interface as soon as this module is executed.
demo.launch()
chat_history.db ADDED
@@ -0,0 +1,3 @@
 
 
 
 
1
+ version https://git-lfs.github.com/spec/v1
2
+ oid sha256:3485fa5937b9980b45b397f433587f3b5b74731d1a0372f166eeb287dd0a20ee
3
+ size 208896
chatbot.py ADDED
@@ -0,0 +1,446 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import uuid
2
+ import os
3
+ import io
4
+ import time
5
+ from functools import lru_cache
6
+ from dotenv import load_dotenv
7
+ from database import SessionLocal, ChatMessage
8
+ from qdrant_client import QdrantClient
9
+ from qdrant_client.models import (
10
+ PointStruct, Distance, VectorParams,
11
+ Filter, FieldCondition, MatchValue, PointIdsList
12
+ )
13
+ from sentence_transformers import SentenceTransformer
14
+ from groq import Groq
15
+ import pdfplumber
16
+ from tabulate import tabulate
17
+ import pytesseract
18
+ from PIL import Image, ImageEnhance, ImageFilter
19
+ import fitz # PyMuPDF
20
+ import torch
21
+ from transformers import Pix2StructProcessor, Pix2StructForConditionalGeneration
22
+ import warnings
23
+ warnings.filterwarnings("ignore", message="Could get FontBBox from font descriptor*")
24
+
25
# Load .env first so TESSERACT_CMD and the API credentials below can come from it.
load_dotenv()

# Configure the Tesseract binary used by pytesseract.
# TESSERACT_CMD (if set) wins; otherwise the default Windows install location
# is used, but only when that file actually exists. On other hosts nothing is
# overridden, so pytesseract keeps its own default command instead of pointing
# at a Windows path that cannot exist there.
_WINDOWS_TESSERACT = r"C:\Program Files\Tesseract-OCR\tesseract.exe"
_tesseract_cmd = os.getenv("TESSERACT_CMD")
if not _tesseract_cmd and os.path.isfile(_WINDOWS_TESSERACT):
    _tesseract_cmd = _WINDOWS_TESSERACT
if _tesseract_cmd:
    pytesseract.pytesseract.tesseract_cmd = _tesseract_cmd

# Initialize clients (credentials and endpoint are read from the environment)
client = Groq(api_key=os.getenv("GROQ_API_KEY"))
qdrant = QdrantClient(
    url=os.getenv("QDRANT_URL"),
    api_key=os.getenv("QDRANT_API_KEY"),
)

COLLECTION_NAME = "chatbot_sessions"      # Qdrant collection holding chat messages
PDF_COLLECTION_NAME = "pdf_documents"     # Qdrant collection holding PDF/chart chunks
MAX_HISTORY_LENGTH = 5                    # messages kept per session in Qdrant
SUMMARY_CACHE_SIZE = 100                  # lru_cache size for conversation summaries
embedder = SentenceTransformer("all-MiniLM-L6-v2")

# Initialize DePlot (moved to the GPU when one is available)
device = "cuda" if torch.cuda.is_available() else "cpu"
deplot_processor = Pix2StructProcessor.from_pretrained("google/deplot")
deplot_model = Pix2StructForConditionalGeneration.from_pretrained("google/deplot").to(device)
47
+
48
def create_collections():
    """Create the Qdrant collections used by the chatbot if they are missing.

    For every collection that is not already present, a 384-dimensional
    cosine-distance collection is created together with a keyword payload
    index on the field used for filtering (``session_id`` for chat messages,
    ``document_id`` for PDF chunks). Existing collections are left untouched.
    """
    existing_collections = {c.name for c in qdrant.get_collections().collections}

    # (collection name, payload field that receives a keyword index)
    wanted = (
        (COLLECTION_NAME, "session_id"),
        (PDF_COLLECTION_NAME, "document_id"),
    )

    for collection_name, indexed_field in wanted:
        if collection_name in existing_collections:
            continue

        # create_collection rather than the deprecated recreate_collection:
        # this branch only runs when the collection is absent, and a plain
        # create can never drop data if that assumption is ever wrong.
        qdrant.create_collection(
            collection_name=collection_name,
            vectors_config=VectorParams(size=384, distance=Distance.COSINE),
            timeout=1200,
        )
        qdrant.create_payload_index(
            collection_name=collection_name,
            field_name=indexed_field,
            field_schema="keyword",
        )
75
+
76
def generate_session_id():
    """Return a freshly generated random UUID (version 4) as a string."""
    session_uuid = uuid.uuid4()
    return str(session_uuid)
79
+
80
def store_message(session_id, role, message):
    """Persist one chat message in the SQL database and in Qdrant.

    Args:
        session_id: Session the message belongs to.
        role: Role stored with the message (callers in this file pass
            ``"user"`` or ``"assistant"``).
        message: Message text; it is also embedded for the vector store.

    After the upsert, up to 100 of the session's points are read back and
    everything except the ``MAX_HISTORY_LENGTH`` newest ones (by payload
    timestamp) is deleted from Qdrant. The SQL rows are never pruned.
    """
    db = SessionLocal()
    try:
        chat_record = ChatMessage(session_id=session_id, role=role, message=message)
        db.add(chat_record)
        db.commit()
        db.refresh(chat_record)
    finally:
        # Release the session even when the commit fails.
        db.close()

    # Store in vector database
    embedding = embedder.encode(message).tolist()
    point = PointStruct(
        # Integer arithmetic on purpose: ``% 1e12`` converts the 128-bit UUID
        # to a float first, which throws away most of its random bits and
        # makes ID collisions (silent overwrites on upsert) far more likely.
        id=uuid.uuid4().int % 10**12,
        vector=embedding,
        payload={
            "session_id": session_id,
            "role": role,
            "message": message,
            # Sub-second resolution so a user message and the assistant reply
            # stored right after it do not tie when sorted by timestamp.
            "timestamp": time.time(),
        },
    )
    qdrant.upsert(collection_name=COLLECTION_NAME, points=[point])

    # Clean up old messages
    existing = qdrant.scroll(
        collection_name=COLLECTION_NAME,
        scroll_filter=Filter(must=[
            FieldCondition(key="session_id", match=MatchValue(value=session_id))
        ]),
        limit=100,
        with_payload=True,
    )
    session_points = existing[0]
    if len(session_points) > MAX_HISTORY_LENGTH:
        oldest_first = sorted(session_points, key=lambda p: p.payload.get("timestamp", 0))
        old_ids = [p.id for p in oldest_first[:-MAX_HISTORY_LENGTH]]
        qdrant.delete(
            collection_name=COLLECTION_NAME,
            points_selector=PointIdsList(points=old_ids),
        )
119
+
120
@lru_cache(maxsize=SUMMARY_CACHE_SIZE)
def get_conversation_summary(session_id):
    """Return a short LLM-generated summary of a session's recent messages.

    The last 10 stored messages of the session are sent to the Groq chat
    model with a prompt asking for a 1-2 sentence summary.

    Results are memoized per ``session_id`` by ``lru_cache`` (including the
    fallback strings below), so callers must call
    ``get_conversation_summary.cache_clear()`` after storing new messages.

    Returns:
        The stripped summary text, ``"No previous conversation history"``
        when the session has no messages, or
        ``"Current conversation context unavailable"`` if the LLM call fails.
    """
    db = SessionLocal()
    try:
        # Only the 10 newest rows are summarized, so fetch just those
        # (newest first) instead of loading the whole session.
        newest_first = (
            db.query(ChatMessage)
            .filter(ChatMessage.session_id == session_id)
            .order_by(ChatMessage.id.desc())
            .limit(10)
            .all()
        )
    finally:
        # Release the session even when the query fails.
        db.close()

    if not newest_first:
        return "No previous conversation history"

    # Restore chronological order for the prompt.
    conversation = "\n".join(
        f"{msg.role}: {msg.message}" for msg in reversed(newest_first)
    )

    summary_prompt = (
        "Create a very concise summary (1-2 sentences max) focusing on:\n"
        "1. Main topic being discussed\n"
        "2. Any specific numbers/dates mentioned\n"
        "3. The most recent question\n\n"
        "Conversation:\n" + conversation
    )

    try:
        response = client.chat.completions.create(
            model="meta-llama/llama-4-scout-17b-16e-instruct",
            messages=[{"role": "user", "content": summary_prompt}],
            temperature=0.3,
            max_tokens=100,
        )
        return response.choices[0].message.content.strip()
    except Exception as e:
        # Best effort: the chat can continue without a summary.
        print(f"Summary generation failed: {e}")
        return "Current conversation context unavailable"
155
+
156
def get_session_history(session_id):
    """Fetch a session's stored messages from Qdrant, oldest first.

    At most ``MAX_HISTORY_LENGTH`` points are requested. They are ordered by
    their payload ``timestamp`` (missing timestamps sort as 0) and returned as
    ``{"role": ..., "content": ...}`` dicts.
    """
    session_filter = Filter(
        must=[FieldCondition(key="session_id", match=MatchValue(value=session_id))]
    )
    scroll_result = qdrant.scroll(
        collection_name=COLLECTION_NAME,
        scroll_filter=session_filter,
        limit=MAX_HISTORY_LENGTH,
        with_payload=True,
    )
    points = scroll_result[0]

    def _timestamp(point):
        return point.payload.get("timestamp", 0)

    history = []
    for point in sorted(points, key=_timestamp):
        payload = point.payload
        history.append({"role": payload["role"], "content": payload["message"]})
    return history
168
+
169
def extract_pdf_content(pdf_path):
    """Extract text, tables and embedded images from a PDF.

    Args:
        pdf_path: Path of the PDF file to read.

    Returns:
        A ``(full_text, images)`` tuple. ``full_text`` contains each page's
        text followed by that page's tables rendered as grids (each prefixed
        with ``TABLE:``). ``images`` is a list of PIL images cropped from
        300 DPI renders of the pages. Images whose crop fails are skipped
        after printing a message.
    """
    render_dpi = 300
    # pdfplumber reports image boxes in PDF points (72 per inch) while the
    # rendered page bitmap is measured in pixels at ``render_dpi``; without
    # this factor the crop would cut the wrong region of the page.
    points_to_pixels = render_dpi / 72

    text_parts = []
    images = []

    with pdfplumber.open(pdf_path) as pdf:
        for page in pdf.pages:
            page_text = page.extract_text()
            if page_text:
                text_parts.append(page_text + "\n\n")

            for table in page.extract_tables():
                formatted_table = tabulate(table, headers="firstrow", tablefmt="grid")
                text_parts.append(f"\n\nTABLE:\n{formatted_table}\n\n")

            if not page.images:
                continue

            # Render the page once and crop every embedded image out of it.
            page_bitmap = page.to_image(resolution=render_dpi).original
            for img in page.images:
                try:
                    bbox = tuple(
                        round(img[key] * points_to_pixels)
                        for key in ("x0", "top", "x1", "bottom")
                    )
                    images.append(page_bitmap.crop(bbox))
                except Exception as e:
                    # Best effort: one bad image must not abort the document.
                    print(f"Image extraction failed: {e}")

    return "".join(text_parts), images
196
+
197
def extract_chart_data(image: Image.Image) -> str:
    """Run OCR over a chart image and describe the result.

    The image is converted to grayscale, sharpened and contrast-enhanced
    (factor 2.0) before being passed to Tesseract with ``--psm 6``.

    Returns:
        ``"Chart contains: <text>"`` when OCR yields non-blank text, a generic
        size description when it does not, or a bracketed error note if any
        step raises.
    """
    try:
        prepared = image.convert("L").filter(ImageFilter.SHARPEN)
        prepared = ImageEnhance.Contrast(prepared).enhance(2.0)

        recognized = pytesseract.image_to_string(prepared, config="--psm 6").strip()
        if not recognized:
            width, height = prepared.size
            return f"Visual chart approximately {width}x{height} pixels with data points"
        return f"Chart contains: {recognized}"
    except Exception as e:
        return f"[Chart content could not be extracted: {str(e)}]"
214
+
215
def extract_charts_with_deplot(pdf_path: str, document_id: str, chunk_size: int = 500):
    """Run DePlot over every image embedded in a PDF and index the output.

    For each embedded image, DePlot is queried twice (data table and prose
    summary). The combined text is split into chunks, embedded, and upserted
    into the PDF collection. An image that fails is reported with ``print``
    and skipped; processing continues with the next one.

    Args:
        pdf_path: Path to PDF file
        document_id: Unique document identifier stored in each point's payload
        chunk_size: Maximum number of characters per stored text chunk

    Returns:
        A list with one dict per successfully processed image, holding
        ``page``, ``image_index`` (both 1-based), ``summary``, ``table_data``
        and ``num_chunks``.
    """
    results = []
    doc = fitz.open(pdf_path)
    try:
        for page_num in range(len(doc)):
            page = doc[page_num]
            image_list = page.get_images(full=True)

            for img_index, img in enumerate(image_list):
                try:
                    # Extract and process image
                    xref = img[0]
                    base_image = doc.extract_image(xref)
                    image_bytes = base_image["image"]
                    image = Image.open(io.BytesIO(image_bytes)).convert("RGB")

                    # Extract table data
                    text_table = "Extract all data from this chart in table format with clear headers."
                    inputs_table = deplot_processor(images=image, text=text_table, return_tensors="pt").to(device)
                    table_ids = deplot_model.generate(**inputs_table, max_new_tokens=512)
                    table_data = deplot_processor.decode(table_ids[0], skip_special_tokens=True)

                    # Generate summary
                    text_summary = ("Provide a comprehensive summary of this chart including: "
                                    "1. Chart title and type, 2. Key trends and patterns, "
                                    "3. Notable data points, 4. Overall conclusion.")
                    inputs_summary = deplot_processor(images=image, text=text_summary, return_tensors="pt").to(device)
                    summary_ids = deplot_model.generate(**inputs_summary, max_new_tokens=512)
                    chart_summary = deplot_processor.decode(summary_ids[0], skip_special_tokens=True)

                    # Create chunks
                    combined_content = f"CHART SUMMARY:\n{chart_summary}\n\nEXTRACTED DATA:\n{table_data}"
                    chunks = _chunk_lines(combined_content, chunk_size)

                    # Store chunks in vector database
                    points = []
                    for i, chunk in enumerate(chunks):
                        embedding = embedder.encode(chunk).tolist()
                        points.append(PointStruct(
                            # Integer modulo: ``% 1e12`` would go through a
                            # float and discard most of the UUID's random bits,
                            # making ID collisions (overwrites) likely.
                            id=uuid.uuid4().int % 10**12,
                            vector=embedding,
                            payload={
                                "document_id": document_id,
                                "page": page_num + 1,
                                "image_index": img_index + 1,
                                "type": "chart_chunk",
                                "chunk_index": i,
                                "total_chunks": len(chunks),
                                "content": chunk,
                                "full_summary": chart_summary,
                                "full_table": table_data,
                            },
                        ))

                    if points:
                        qdrant.upsert(collection_name=PDF_COLLECTION_NAME, points=points)

                    results.append({
                        "page": page_num + 1,
                        "image_index": img_index + 1,
                        "summary": chart_summary,
                        "table_data": table_data,
                        "num_chunks": len(chunks),
                    })

                except Exception as e:
                    # Best effort per image: report and move on.
                    print(f"❌ Error processing page {page_num+1} image {img_index+1}: {str(e)}")
                    continue
    finally:
        # The document was previously never closed, leaking the file handle.
        doc.close()

    return results


def _chunk_lines(text: str, chunk_size: int) -> list:
    """Greedily pack the non-blank lines of ``text`` into chunks of at most
    ``chunk_size`` characters; a single longer line becomes its own chunk."""
    chunks = []
    current_chunk = ""

    for line in text.split('\n'):
        if not line.strip():
            continue
        if len(current_chunk) + len(line) + 1 <= chunk_size:
            current_chunk += line + "\n"
        else:
            if current_chunk:
                chunks.append(current_chunk.strip())
            current_chunk = line + "\n"

    if current_chunk:
        chunks.append(current_chunk.strip())
    return chunks
309
+
310
def store_pdf_chunks(text: str, document_id: str):
    """Split PDF text into chunks, embed them and store them in Qdrant.

    Paragraphs (separated by blank lines) are accumulated into a chunk while
    the running length stays below 1000 characters. Each non-empty chunk is
    embedded and upserted into the PDF collection with ``document_id``, the
    chunk text (``content``) and ``source="pdf"`` in its payload.

    Args:
        text: Full text extracted from the PDF.
        document_id: Identifier stored with every chunk.
    """
    chunks = []
    current_chunk = ""

    for para in text.split('\n\n'):
        if len(current_chunk) + len(para) < 1000:
            current_chunk += para + "\n\n"
        else:
            chunks.append(current_chunk.strip())
            current_chunk = para + "\n\n"
    if current_chunk:
        chunks.append(current_chunk.strip())

    # Empty chunks appear when the text is blank or when the very first
    # paragraph is already 1000+ characters; storing them would only add
    # meaningless vectors to the index.
    chunks = [chunk for chunk in chunks if chunk]
    if not chunks:
        return

    points = [
        PointStruct(
            # Integer modulo: ``% 1e12`` would go through a float and discard
            # most of the UUID's random bits, making ID collisions likely.
            id=uuid.uuid4().int % 10**12,
            vector=embedder.encode(chunk).tolist(),
            payload={
                "document_id": document_id,
                "content": chunk,
                "source": "pdf",
            },
        )
        for chunk in chunks
    ]

    # Upsert in batches instead of issuing one request per chunk.
    batch_size = 64
    for start in range(0, len(points), batch_size):
        qdrant.upsert(
            collection_name=PDF_COLLECTION_NAME,
            points=points[start:start + batch_size],
        )
337
+
338
def process_pdf(pdf_path: str):
    """Extract a PDF's text, image OCR and chart data and index all of it.

    The page text, the OCR output of every extracted image and one chart
    description per image are concatenated and stored as text chunks under a
    document ID equal to the file's base name. Afterwards the PDF's embedded
    images are run through DePlot and the number of results is printed.
    """
    text, images = extract_pdf_content(pdf_path)
    ocr_fragments = []
    chart_summaries = []

    for number, image in enumerate(images, start=1):
        try:
            ocr_fragments.append(pytesseract.image_to_string(image))
            chart_summary = extract_chart_data(image)
            chart_summaries.append(f"Chart {number}: {chart_summary}")
        except Exception as e:
            print(f"Image processing failed: {e}")
            chart_summaries.append(f"Chart {number}: [Content not extracted]")

    sections = (
        "PDF TEXT CONTENT:\n" + text,
        "IMAGE TEXT CONTENT:\n" + "".join(ocr_fragments),
        "CHART SUMMARIES:\n" + "\n".join(chart_summaries),
    )
    full_text = "\n\n".join(sections)

    document_id = os.path.basename(pdf_path)
    store_pdf_chunks(full_text, document_id)

    # Process charts with DePlot
    deplot_results = extract_charts_with_deplot(pdf_path, document_id)
    chart_count = len(deplot_results)
    print(f"✅ DePlot processed {chart_count} charts")
365
+
366
def get_relevant_context(user_message: str, session_id: str):
    """Collect document and conversation context for a user message.

    The message is embedded and used to search the PDF collection (up to 10
    hits with a score of at least 0.4). The last three entries of the
    session history are used as conversation context.

    Returns:
        A ``(pdf_context, history_context)`` tuple of newline-joined strings.
    """
    query_vector = embedder.encode(user_message).tolist()

    # Search PDF content
    hits = qdrant.search(
        collection_name=PDF_COLLECTION_NAME,
        query_vector=query_vector,
        limit=10,
        score_threshold=0.4,
    )

    # Most recent turns of the conversation
    last_turns = get_session_history(session_id)[-3:]

    pdf_context = "\n".join(hit.payload.get("content", "") for hit in hits)
    history_context = "\n".join(turn["content"] for turn in last_turns)
    return pdf_context, history_context
386
+
387
def get_verified_context(session_id):
    """Return the session's recent messages that contain at least one digit.

    Only the 10 newest stored messages are inspected; the result keeps their
    newest-first order.

    Args:
        session_id: Session whose messages are searched.

    Returns:
        A list of ``ChatMessage`` rows whose text contains a digit.
    """
    db = SessionLocal()
    try:
        messages = (
            db.query(ChatMessage)
            .filter(ChatMessage.session_id == session_id)
            .order_by(ChatMessage.id.desc())
            .limit(10)
            .all()
        )
    finally:
        # Release the session even when the query fails.
        db.close()

    # ``message`` is a nullable column, so guard against None before scanning.
    return [
        msg for msg in messages
        if msg.message and any(char.isdigit() for char in msg.message)
    ]
396
+
397
def chat_with_session(session_id, user_message):
    """Answer a user message using session history and document context.

    Args:
        session_id: Session identifier; must parse as a UUID.
        user_message: The user's latest message.

    Returns:
        The model's reply. An explanatory string is returned instead when the
        session ID is not a valid UUID or when the LLM call fails; in both
        cases nothing is stored.
    """
    try:
        uuid.UUID(session_id)
    except (ValueError, AttributeError, TypeError):
        # ValueError: malformed string. AttributeError/TypeError: the value is
        # not a string at all (e.g. an int or None).
        return "❌ Invalid session ID format. Please generate a valid session."

    # Get all context sources
    conversation_summary = get_conversation_summary(session_id)
    pdf_context, _ = get_relevant_context(user_message, session_id)
    verified_contexts = get_verified_context(session_id)
    verified_text = "\n".join([msg.message for msg in verified_contexts])

    # Construct system prompt
    system_prompt = (
        "You are a context-aware assistant. Follow these rules strictly:\n"
        "1. CONVERSATION SUMMARY:\n" + conversation_summary + "\n\n"
        "2. Maintain context for follow-up questions\n"
        "3. DOCUMENT CONTEXT:\n" + (pdf_context if pdf_context else "None") + "\n\n"
        "4. VERIFIED NUMERICAL CONTEXT:\n" + (verified_text if verified_text else "None") + "\n\n"
        "5. Respond clearly and concisely to the latest user query while maintaining continuity.\n"
    )

    # Prepare messages for LLM: system prompt, last three stored turns, new message
    messages = [{"role": "system", "content": system_prompt}]
    messages.extend(get_session_history(session_id)[-3:])
    messages.append({"role": "user", "content": user_message})

    try:
        completion = client.chat.completions.create(
            model="meta-llama/llama-4-scout-17b-16e-instruct",
            messages=messages,
            temperature=0.7,
            max_tokens=1024,
            top_p=0.9,
        )
        reply = completion.choices[0].message.content
    except Exception as e:
        print(f"❌ LLM generation failed: {e}")
        return "Sorry, I couldn't generate a response at this time."

    # Store conversation
    store_message(session_id, "user", user_message)
    store_message(session_id, "assistant", reply)
    # New messages make cached summaries stale. lru_cache has no per-key
    # invalidation, so the whole cache is cleared.
    get_conversation_summary.cache_clear()

    return reply
444
+
445
# Initialize collections on startup.
# NOTE: this is a module-level call, so it runs whenever this module is imported.
create_collections()
database.py ADDED
@@ -0,0 +1,20 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ from sqlalchemy import create_engine, Column, Integer, String, Text, DateTime
2
+ from sqlalchemy.orm import declarative_base, sessionmaker
3
+ from datetime import datetime
4
+
5
# Declarative base class shared by the ORM models.
Base = declarative_base()


class ChatMessage(Base):
    """ORM model for one stored chat message (table ``chat_messages``)."""

    __tablename__ = "chat_messages"
    # Integer primary key (indexed).
    id = Column(Integer, primary_key=True, index=True)
    # Session the message belongs to (indexed for per-session lookups).
    session_id = Column(String, index=True)
    # Role of the message author as passed in by the code that stores it.
    role = Column(String)
    # Message text.
    message = Column(Text)
    # Defaults to the value of datetime.utcnow (naive UTC datetime).
    timestamp = Column(DateTime, default=datetime.utcnow)
14
+
15
# SQLite engine and session factory.
# The database file is ./chat_history.db relative to the working directory;
# check_same_thread=False is passed through to the SQLite driver.
engine = create_engine("sqlite:///./chat_history.db", connect_args={"check_same_thread": False})
SessionLocal = sessionmaker(autocommit=False, autoflush=False, bind=engine)

# Create tables for all models registered on Base (runs at import time).
Base.metadata.create_all(bind=engine)
main.py ADDED
@@ -0,0 +1,65 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ from fastapi import FastAPI, WebSocket, WebSocketDisconnect, UploadFile, File
2
+ from fastapi.middleware.cors import CORSMiddleware
3
+ from pydantic import BaseModel
4
+ from chatbot import generate_session_id, chat_with_session, process_pdf, create_collections
5
+
6
+ import os
7
+
8
# FastAPI application exposing the chatbot endpoints defined below.
app = FastAPI()

# Enable CORS for every origin, method and header.
app.add_middleware(
    CORSMiddleware,
    allow_origins=["*"],
    allow_methods=["*"],
    allow_headers=["*"],
)

# Ensure the upload directory exists (created at import time if missing).
os.makedirs("uploaded_files", exist_ok=True)
20
+
21
class AskRequest(BaseModel):
    """Request body of the ``/ask`` endpoint."""

    # Session identifier, forwarded unchanged to chat_with_session.
    session_id: str
    # The user's question, forwarded unchanged to chat_with_session.
    question: str
24
+
25
+ # Endpoint to generate new session ID
26
@app.get("/get_session")
def get_session():
    """Return a newly generated session ID as ``{"session_id": <id>}``."""
    new_session_id = generate_session_id()
    return {"session_id": new_session_id}
29
+
30
+ # Endpoint to ask questions with session ID
31
@app.post("/ask")
def ask(request: AskRequest):
    """Answer ``request.question`` within ``request.session_id``.

    Returns the chatbot's reply as ``{"answer": <reply>}``.
    """
    answer = chat_with_session(request.session_id, request.question)
    return {"answer": answer}
35
+
36
+ # WebSocket chat endpoint
37
@app.websocket("/ws/{session_id}")
async def websocket_chat(websocket: WebSocket, session_id: str):
    """Chat over a WebSocket: every received text message gets one reply.

    Args:
        websocket: The client connection.
        session_id: Session ID taken from the URL path and passed to
            ``chat_with_session`` for every message.

    The loop runs until the client disconnects, which is reported with
    ``print``.
    """
    # Local import so this handler does not need a new module-level import.
    from fastapi.concurrency import run_in_threadpool

    await websocket.accept()
    try:
        while True:
            question = await websocket.receive_text()
            # chat_with_session is synchronous and blocks on network, database
            # and model calls. Running it in the thread pool keeps the event
            # loop free to serve other connections in the meantime.
            reply = await run_in_threadpool(chat_with_session, session_id, question)
            await websocket.send_text(reply)
    except WebSocketDisconnect:
        print(f"❌ Client {session_id} disconnected")
47
+
48
+ # ✅ New endpoint to upload PDF
49
@app.post("/upload_pdf")
async def upload_pdf(file: UploadFile = File(...)):
    """Save an uploaded PDF under ``uploaded_files/`` and index its contents.

    Args:
        file: The uploaded file; its name must end in ``.pdf``
            (case-insensitive).

    Returns:
        ``{"message": ...}`` on success, or ``{"error": ...}`` when the file
        name is missing or does not end in ``.pdf``.
    """
    # Local import so this handler does not need a new module-level import.
    from fastapi.concurrency import run_in_threadpool

    # The client controls the file name. Keep only its final path component
    # so that names such as "../../x.pdf" cannot escape uploaded_files/.
    filename = os.path.basename((file.filename or "").replace("\\", "/"))
    if not filename.lower().endswith(".pdf"):
        return {"error": "Only PDF files are allowed."}

    file_path = f"uploaded_files/{filename}"
    with open(file_path, "wb") as f:
        f.write(await file.read())

    # Process PDF (text extraction, image OCR, embeddings, etc.).
    # process_pdf is synchronous and long-running, so it is executed in the
    # thread pool instead of blocking the event loop.
    await run_in_threadpool(process_pdf, file_path)

    return {"message": "PDF uploaded and processed successfully."}
62
+
63
@app.on_event("startup")
def startup_event():
    """Call create_collections() when the application starts."""
    create_collections()
models.py ADDED
@@ -0,0 +1,8 @@
 
 
 
 
 
 
 
 
 
1
+ from sqlalchemy import Column, Integer, String
2
+ from database import Base
3
+
4
class TopicSummary(Base):
    """ORM model for a stored summary text per session (table ``topic_summaries``)."""

    __tablename__ = "topic_summaries"
    # Integer primary key (indexed).
    id = Column(Integer, primary_key=True, index=True)
    # Session the summary belongs to (indexed).
    session_id = Column(String, index=True)
    # Summary text.
    summary = Column(String)
requirements.txt ADDED
@@ -0,0 +1,11 @@
 
 
 
 
 
 
 
 
 
 
 
 
1
+ fastapi
2
+ uvicorn
3
+ pytesseract
4
+ python-dotenv
5
+ pdf2image
6
+ opencv-python
7
+ PyMuPDF
8
+ groq
9
+ chromadb
10
+ sentence-transformers
11
+ qdrant-client