ak0601 commited on
Commit
6362b0c
·
verified ·
1 Parent(s): 490fed8

Upload 9 files

Browse files
.gitattributes CHANGED
@@ -33,3 +33,5 @@ saved_model/**/* filter=lfs diff=lfs merge=lfs -text
33
  *.zip filter=lfs diff=lfs merge=lfs -text
34
  *.zst filter=lfs diff=lfs merge=lfs -text
35
  *tfevents* filter=lfs diff=lfs merge=lfs -text
 
 
 
33
  *.zip filter=lfs diff=lfs merge=lfs -text
34
  *.zst filter=lfs diff=lfs merge=lfs -text
35
  *tfevents* filter=lfs diff=lfs merge=lfs -text
36
+ chat_sessions.db filter=lfs diff=lfs merge=lfs -text
37
+ chroma_db/chroma.sqlite3 filter=lfs diff=lfs merge=lfs -text
app.py ADDED
@@ -0,0 +1,345 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ """
2
+ Law RAG Chatbot API using FastAPI, Langchain, Groq, and ChromaDB
3
+ """
4
+
5
+ import os
6
+ import time
7
+ import logging
8
+ from typing import List, Dict, Any, Optional
9
+ from fastapi import FastAPI, HTTPException, Depends, Header
10
+ from fastapi.middleware.cors import CORSMiddleware
11
+ from pydantic import BaseModel
12
+ import uvicorn
13
+
14
+ from config import *
15
+ from rag_system import RAGSystem
16
+ from session_manager import SessionManager
17
+
18
+ # Configure logging
19
+ logging.basicConfig(level=logging.INFO)
20
+ logger = logging.getLogger(__name__)
21
+
22
+ # Initialize FastAPI app
23
+ app = FastAPI(
24
+ title=API_TITLE,
25
+ version=API_VERSION,
26
+ description=API_DESCRIPTION
27
+ )
28
+
29
+ # Add CORS middleware
30
+ app.add_middleware(
31
+ CORSMiddleware,
32
+ allow_origins=["*"],
33
+ allow_credentials=True,
34
+ allow_methods=["*"],
35
+ allow_headers=["*"],
36
+ )
37
+
38
+ # Pydantic models for API requests/responses
39
+ class ChatRequest(BaseModel):
40
+ question: str
41
+ context_length: int = 5 # Increased default context length
42
+ session_id: Optional[str] = None
43
+
44
+ class ChatResponse(BaseModel):
45
+ answer: str
46
+ sources: List[Dict[str, Any]]
47
+ confidence: float
48
+ processing_time: float
49
+ question: str
50
+ session_id: str
51
+ chat_history_count: int
52
+
53
+ class SessionCreateRequest(BaseModel):
54
+ user_info: Optional[str] = None
55
+ metadata: Optional[Dict[str, Any]] = None
56
+
57
+ class SessionResponse(BaseModel):
58
+ session_id: str
59
+ created_at: str
60
+ user_info: str
61
+ metadata: Dict[str, Any]
62
+
63
+ class HealthResponse(BaseModel):
64
+ status: str
65
+ message: str
66
+ components: Dict[str, str]
67
+
68
+ # Global instances
69
+ rag_system: RAGSystem = None
70
+ session_manager: SessionManager = None
71
+
72
+ @app.on_event("startup")
73
+ async def startup_event():
74
+ """Initialize the RAG system and session manager on startup"""
75
+ global rag_system, session_manager
76
+ try:
77
+ logger.info("Initializing RAG system...")
78
+ rag_system = RAGSystem()
79
+ await rag_system.initialize()
80
+ logger.info("RAG system initialized successfully")
81
+
82
+ logger.info("Initializing session manager...")
83
+ session_manager = SessionManager()
84
+ logger.info("Session manager initialized successfully")
85
+
86
+ except Exception as e:
87
+ logger.error(f"Failed to initialize systems: {e}")
88
+ raise
89
+
90
+ @app.get("/", response_model=HealthResponse)
91
+ async def root():
92
+ """Root endpoint with health check"""
93
+ return HealthResponse(
94
+ status="healthy",
95
+ message="Law RAG Chatbot API is running",
96
+ components={
97
+ "api": "running",
98
+ "rag_system": "running" if rag_system else "not_initialized",
99
+ "session_manager": "running" if session_manager else "not_initialized",
100
+ "vector_db": "connected" if rag_system and rag_system.is_ready() else "disconnected"
101
+ }
102
+ )
103
+
104
+ @app.get("/health", response_model=HealthResponse)
105
+ async def health_check():
106
+ """Health check endpoint"""
107
+ if not rag_system:
108
+ raise HTTPException(status_code=503, detail="RAG system not initialized")
109
+
110
+ if not rag_system.is_ready():
111
+ raise HTTPException(status_code=503, detail="RAG system not ready")
112
+
113
+ if not session_manager:
114
+ raise HTTPException(status_code=503, detail="Session manager not initialized")
115
+
116
+ return HealthResponse(
117
+ status="healthy",
118
+ message="All systems operational",
119
+ components={
120
+ "api": "running",
121
+ "rag_system": "ready",
122
+ "session_manager": "ready",
123
+ "vector_db": "connected",
124
+ "embeddings": "ready",
125
+ "llm": "ready"
126
+ }
127
+ )
128
+
129
+ @app.post("/sessions", response_model=SessionResponse)
130
+ async def create_session(request: SessionCreateRequest):
131
+ """Create a new chat session"""
132
+ if not session_manager:
133
+ raise HTTPException(status_code=503, detail="Session manager not initialized")
134
+
135
+ try:
136
+ session_id = session_manager.create_session(
137
+ user_info=request.user_info,
138
+ metadata=request.metadata
139
+ )
140
+
141
+ session = session_manager.get_session(session_id)
142
+
143
+ return SessionResponse(
144
+ session_id=session_id,
145
+ created_at=session["created_at"].isoformat(),
146
+ user_info=session["user_info"],
147
+ metadata=session["metadata"]
148
+ )
149
+
150
+ except Exception as e:
151
+ logger.error(f"Error creating session: {e}")
152
+ raise HTTPException(status_code=500, detail=f"Failed to create session: {str(e)}")
153
+
154
+ @app.get("/sessions/{session_id}", response_model=Dict[str, Any])
155
+ async def get_session_info(session_id: str):
156
+ """Get session information and statistics"""
157
+ if not session_manager:
158
+ raise HTTPException(status_code=503, detail="Session manager not initialized")
159
+
160
+ try:
161
+ session_stats = session_manager.get_session_stats(session_id)
162
+
163
+ if not session_stats:
164
+ raise HTTPException(status_code=404, detail="Session not found")
165
+
166
+ return session_stats
167
+
168
+ except HTTPException:
169
+ raise
170
+ except Exception as e:
171
+ logger.error(f"Error getting session info: {e}")
172
+ raise HTTPException(status_code=500, detail=f"Failed to get session info: {str(e)}")
173
+
174
+ @app.get("/sessions/{session_id}/history")
175
+ async def get_chat_history(session_id: str, limit: int = 10):
176
+ """Get chat history for a session"""
177
+ if not session_manager:
178
+ raise HTTPException(status_code=503, detail="Session manager not initialized")
179
+
180
+ try:
181
+ history = session_manager.get_chat_history(session_id, limit)
182
+ return {
183
+ "session_id": session_id,
184
+ "history": history,
185
+ "total": len(history)
186
+ }
187
+
188
+ except Exception as e:
189
+ logger.error(f"Error getting chat history: {e}")
190
+ raise HTTPException(status_code=500, detail=f"Failed to get chat history: {str(e)}")
191
+
192
+ @app.post("/chat", response_model=ChatResponse)
193
+ async def chat(request: ChatRequest):
194
+ """Main chat endpoint for legal questions with session support"""
195
+ if not rag_system:
196
+ raise HTTPException(status_code=503, detail="RAG system not initialized")
197
+
198
+ if not rag_system.is_ready():
199
+ raise HTTPException(status_code=503, detail="RAG system not ready")
200
+
201
+ if not session_manager:
202
+ raise HTTPException(status_code=503, detail="Session manager not initialized")
203
+
204
+ try:
205
+ start_time = time.time()
206
+
207
+ # Handle session
208
+ if not request.session_id:
209
+ # Create new session if none provided
210
+ request.session_id = session_manager.create_session()
211
+ logger.info(f"Created new session: {request.session_id}")
212
+
213
+ # Verify session exists
214
+ session = session_manager.get_session(request.session_id)
215
+ if not session:
216
+ raise HTTPException(status_code=404, detail="Session not found")
217
+
218
+ # Get response from RAG system
219
+ response = await rag_system.get_response(
220
+ question=request.question,
221
+ context_length=request.context_length
222
+ )
223
+
224
+ processing_time = time.time() - start_time
225
+
226
+ # Store chat response in session
227
+ session_manager.store_chat_response(
228
+ session_id=request.session_id,
229
+ question=request.question,
230
+ answer=response["answer"],
231
+ sources=response["sources"],
232
+ confidence=response["confidence"],
233
+ processing_time=processing_time
234
+ )
235
+
236
+ # Get chat history count
237
+ chat_history = session_manager.get_chat_history(request.session_id, limit=1)
238
+ chat_history_count = len(chat_history)
239
+
240
+ return ChatResponse(
241
+ answer=response["answer"],
242
+ sources=response["sources"],
243
+ confidence=response["confidence"],
244
+ processing_time=processing_time,
245
+ question=request.question,
246
+ session_id=request.session_id,
247
+ chat_history_count=chat_history_count
248
+ )
249
+
250
+ except HTTPException:
251
+ raise
252
+ except Exception as e:
253
+ logger.error(f"Error processing chat request: {e}")
254
+ raise HTTPException(status_code=500, detail=f"Internal server error: {str(e)}")
255
+
256
+ @app.get("/search")
257
+ async def search(query: str, limit: int = 5, session_id: Optional[str] = None):
258
+ """Search for relevant legal documents with optional session tracking"""
259
+ if not rag_system:
260
+ raise HTTPException(status_code=503, detail="RAG system not initialized")
261
+
262
+ try:
263
+ results = await rag_system.search_documents(query, limit)
264
+
265
+ # Store search query if session provided
266
+ if session_id and session_manager:
267
+ session_manager.store_search_query(session_id, query, len(results))
268
+
269
+ return {
270
+ "query": query,
271
+ "results": results,
272
+ "total": len(results),
273
+ "session_id": session_id
274
+ }
275
+ except Exception as e:
276
+ logger.error(f"Error in search: {e}")
277
+ raise HTTPException(status_code=500, detail=f"Search failed: {str(e)}")
278
+
279
+ @app.get("/stats")
280
+ async def get_stats():
281
+ """Get system statistics"""
282
+ if not rag_system:
283
+ raise HTTPException(status_code=503, detail="RAG system not initialized")
284
+
285
+ try:
286
+ rag_stats = await rag_system.get_stats()
287
+
288
+ # Add session statistics
289
+ if session_manager:
290
+ # Get total sessions count (this would need to be implemented in session manager)
291
+ session_stats = {
292
+ "session_manager": "active",
293
+ "total_sessions": "available" # Could implement actual count
294
+ }
295
+ else:
296
+ session_stats = {"session_manager": "not_initialized"}
297
+
298
+ return {**rag_stats, **session_stats}
299
+
300
+ except Exception as e:
301
+ logger.error(f"Error getting stats: {e}")
302
+ raise HTTPException(status_code=500, detail=f"Failed to get stats: {str(e)}")
303
+
304
+ @app.post("/reindex")
305
+ async def reindex():
306
+ """Reindex the vector database"""
307
+ if not rag_system:
308
+ raise HTTPException(status_code=503, detail="RAG system not initialized")
309
+
310
+ try:
311
+ await rag_system.reindex()
312
+ return {"message": "Reindexing completed successfully"}
313
+ except Exception as e:
314
+ logger.error(f"Error in reindexing: {e}")
315
+ raise HTTPException(status_code=500, detail=f"Reindexing failed: {str(e)}")
316
+
317
+ @app.delete("/sessions/{session_id}")
318
+ async def delete_session(session_id: str):
319
+ """Delete a session and all its data"""
320
+ if not session_manager:
321
+ raise HTTPException(status_code=503, detail="Session manager not initialized")
322
+
323
+ try:
324
+ session_manager.delete_session(session_id)
325
+ return {"message": f"Session {session_id} deleted successfully"}
326
+
327
+ except Exception as e:
328
+ logger.error(f"Error deleting session: {e}")
329
+ raise HTTPException(status_code=500, detail=f"Failed to delete session: {str(e)}")
330
+
331
+ if __name__ == "__main__":
332
+ uvicorn.run(
333
+ "app:app",
334
+ host=HOST,
335
+ port=PORT,
336
+ reload=True,
337
+ log_level="info"
338
+ )
339
+
340
+
341
+
342
+
343
+
344
+
345
+
chat_sessions.db ADDED
@@ -0,0 +1,3 @@
 
 
 
 
1
+ version https://git-lfs.github.com/spec/v1
2
+ oid sha256:b5d3522a84d269cac0974a1b6dfb44c2a692d54d30cb525f492f409f5d357124
3
+ size 516096
chroma_db/ba961ef0-bc79-4ac2-97c0-6cb89b88a421/data_level0.bin ADDED
@@ -0,0 +1,3 @@
 
 
 
 
1
+ version https://git-lfs.github.com/spec/v1
2
+ oid sha256:fdf80b989177c39e67b75605633a79a36c15f05a0759a7604aad9de8e9456ae9
3
+ size 16760000
chroma_db/ba961ef0-bc79-4ac2-97c0-6cb89b88a421/header.bin ADDED
@@ -0,0 +1,3 @@
 
 
 
 
1
+ version https://git-lfs.github.com/spec/v1
2
+ oid sha256:d3224fa57da8a778a9fd40be4bfce5cf573b3641f3ee57850ad29a1f05ce9a61
3
+ size 100
chroma_db/ba961ef0-bc79-4ac2-97c0-6cb89b88a421/index_metadata.pickle ADDED
@@ -0,0 +1,3 @@
 
 
 
 
1
+ version https://git-lfs.github.com/spec/v1
2
+ oid sha256:b5bbfd4dec3468174919140b7ce1c5ddea2dd93f5f1b37828e0c85572c87e403
3
+ size 287384
chroma_db/ba961ef0-bc79-4ac2-97c0-6cb89b88a421/length.bin ADDED
@@ -0,0 +1,3 @@
 
 
 
 
1
+ version https://git-lfs.github.com/spec/v1
2
+ oid sha256:05281421a401d4dcb852084dcc03c528cd38abae62b80d715d78f8891643674a
3
+ size 40000
chroma_db/ba961ef0-bc79-4ac2-97c0-6cb89b88a421/link_lists.bin ADDED
@@ -0,0 +1,3 @@
 
 
 
 
1
+ version https://git-lfs.github.com/spec/v1
2
+ oid sha256:f136f8d41ef27e5e51ab3bd71d5a9d56676c691236e1dfdd612b90837ae20832
3
+ size 64104
chroma_db/chroma.sqlite3 ADDED
@@ -0,0 +1,3 @@
 
 
 
 
1
+ version https://git-lfs.github.com/spec/v1
2
+ oid sha256:127fdf8c5fc37542506a3523d855c71e628158fe51792a8d170542a9ab4cae4f
3
+ size 47648768
requirements.txt ADDED
@@ -0,0 +1,18 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ fastapi>=0.104.0
2
+ uvicorn>=0.24.0
3
+ langchain>=0.1.0
4
+ langchain-groq>=0.0.1
5
+ langchain-community>=0.0.10
6
+ langchain-core>=0.1.0
7
+ groq>=0.4.0
8
+ sentence-transformers>=2.2.0
9
+ chromadb>=0.4.0
10
+ datasets>=2.14.0
11
+ huggingface-hub>=0.16.0
12
+ pydantic>=2.0.0
13
+ python-multipart>=0.0.6
14
+ python-dotenv>=1.0.0
15
+ numpy>=1.24.0
16
+ pandas>=2.0.0
17
+ requests>=2.31.0
18
+ tiktoken>=0.5.0