NitinBot001 committed on
Commit
29d1672
·
verified ·
1 Parent(s): 8418a35

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +633 -600
app.py CHANGED
@@ -1,601 +1,634 @@
1
- import os
2
- import time
3
- import gradio as gr
4
- import uvicorn
5
- from fastapi import FastAPI, HTTPException, Depends, File, UploadFile
6
- from fastapi.security import HTTPBearer, HTTPAuthorizationCredentials
7
- from pydantic import BaseModel
8
- from typing import Optional, Dict, Any
9
- import threading
10
- import logging
11
- from langchain.document_loaders import TextLoader
12
- from langchain.text_splitter import RecursiveCharacterTextSplitter
13
- from langchain.vectorstores import FAISS
14
- from langchain.chains import RetrievalQA
15
- from langchain.prompts import PromptTemplate
16
- from langchain.callbacks.base import BaseCallbackHandler
17
- from langchain_google_genai import ChatGoogleGenerativeAI, GoogleGenerativeAIEmbeddings
18
- import tiktoken
19
-
20
- # Configure logging
21
- logging.basicConfig(level=logging.INFO)
22
- logger = logging.getLogger(__name__)
23
-
24
- # --- Configuration ---
25
- CHUNK_SIZE = 800
26
- CHUNK_OVERLAP = 100
27
- MAX_TOKENS = 512
28
- TEMPERATURE = 0.5
29
- RETRIEVAL_K = 5
30
-
31
- # --- Token Counting Setup ---
32
- try:
33
- tokenizer = tiktoken.get_encoding("cl100k_base")
34
- except:
35
- print("Tiktoken encoder 'cl100k_base' not found. Using basic split().")
36
- tokenizer = type('obj', (object,), {'encode': lambda x: x.split()})()
37
-
38
- def estimate_tokens(text):
39
- """Estimates token count for a given text."""
40
- return len(tokenizer.encode(text))
41
-
42
- # Custom Callback Handler to track LLM token usage
43
- class TokenUsageCallbackHandler(BaseCallbackHandler):
44
- """Callback handler to track token usage in LLM calls."""
45
- def __init__(self):
46
- super().__init__()
47
- self.reset_counters()
48
-
49
- def reset_counters(self):
50
- self.total_prompt_tokens = 0
51
- self.total_completion_tokens = 0
52
- self.total_llm_calls = 0
53
-
54
- def on_llm_end(self, response, **kwargs):
55
- """Collect token usage from the LLM response."""
56
- self.total_llm_calls += 1
57
- llm_output = response.llm_output
58
-
59
- if llm_output and 'usage_metadata' in llm_output:
60
- usage = llm_output['usage_metadata']
61
- prompt_tokens = usage.get('prompt_token_count', 0)
62
- completion_tokens = usage.get('candidates_token_count', 0)
63
-
64
- self.total_prompt_tokens += prompt_tokens
65
- self.total_completion_tokens += completion_tokens
66
-
67
- def get_total_tokens(self):
68
- """Returns the total prompt and completion tokens."""
69
- return {
70
- "total_prompt_tokens": self.total_prompt_tokens,
71
- "total_completion_tokens": self.total_completion_tokens,
72
- "total_llm_tokens": self.total_prompt_tokens + self.total_completion_tokens,
73
- "total_llm_calls": self.total_llm_calls
74
- }
75
-
76
- # --- Pydantic Models for API ---
77
- class InitializeRequest(BaseModel):
78
- api_key: str
79
- document_content: Optional[str] = None
80
-
81
- class QueryRequest(BaseModel):
82
- query: str
83
- api_key: str
84
-
85
- class InitializeResponse(BaseModel):
86
- success: bool
87
- message: str
88
- chunks: Optional[int] = None
89
- estimated_tokens: Optional[int] = None
90
-
91
- class QueryResponse(BaseModel):
92
- success: bool
93
- answer: str
94
- response_time: float
95
- query_tokens: int
96
- llm_tokens: Dict[str, int]
97
- session_stats: Dict[str, int]
98
-
99
- class StatsResponse(BaseModel):
100
- total_queries: int
101
- total_embedding_tokens: int
102
- total_llm_tokens: int
103
- total_llm_calls: int
104
- initialization_complete: bool
105
-
106
- # --- Global Variables ---
107
- class RAGSystem:
108
- def __init__(self):
109
- self.vector_store = None
110
- self.qa_chain = None
111
- self.token_callback_handler = TokenUsageCallbackHandler()
112
- self.session_stats = {
113
- "total_queries": 0,
114
- "total_embedding_tokens": 0,
115
- "initialization_complete": False
116
- }
117
- self.current_api_key = None
118
-
119
- # Global RAG system instance
120
- rag_system = RAGSystem()
121
-
122
- def initialize_rag_system(api_key, file_content=None):
123
- """Initialize the RAG system with API key and optional file content."""
124
- global rag_system
125
-
126
- try:
127
- # Set API key
128
- os.environ["GOOGLE_API_KEY"] = api_key
129
- rag_system.current_api_key = api_key
130
-
131
- # Initialize embeddings
132
- embeddings = GoogleGenerativeAIEmbeddings(
133
- model="models/embedding-001",
134
- google_api_key=api_key
135
- )
136
-
137
- # Initialize LLM
138
- llm = ChatGoogleGenerativeAI(
139
- model="gemini-1.5-flash",
140
- google_api_key=api_key,
141
- temperature=TEMPERATURE,
142
- max_tokens=MAX_TOKENS,
143
- callbacks=[rag_system.token_callback_handler],
144
- verbose=False
145
- )
146
-
147
- # Load or use default document
148
- if file_content:
149
- # Save uploaded file content
150
- with open("uploaded_document.txt", "w", encoding="utf-8") as f:
151
- f.write(file_content)
152
- loader = TextLoader("uploaded_document.txt")
153
- else:
154
- # Check if default maize_data.txt exists
155
- if os.path.exists("maize_data.txt"):
156
- loader = TextLoader("maize_data.txt")
157
- else:
158
- return "❌ No document found. Please upload a file or ensure maize_data.txt exists."
159
-
160
- # Load and split documents
161
- documents = loader.load()
162
- text_splitter = RecursiveCharacterTextSplitter(
163
- chunk_size=CHUNK_SIZE,
164
- chunk_overlap=CHUNK_OVERLAP
165
- )
166
- chunks = text_splitter.split_documents(documents)
167
-
168
- # Estimate embedding tokens
169
- initial_embedding_tokens = sum(estimate_tokens(chunk.page_content) for chunk in chunks)
170
- rag_system.session_stats["total_embedding_tokens"] = initial_embedding_tokens
171
-
172
- # Create vector store
173
- rag_system.vector_store = FAISS.from_documents(chunks, embeddings)
174
-
175
- # Create prompt template
176
- prompt_template = PromptTemplate(
177
- input_variables=["context", "question"],
178
- template="""
179
- You are an expert in maize agriculture. Use the following context ONLY to answer the question accurately and helpfully. If the context doesn't contain the answer, say "Based on the provided context, I cannot answer this question.".
180
-
181
- Context:
182
- {context}
183
-
184
- Question: {question}
185
-
186
- Answer:"""
187
- )
188
-
189
- # Set up QA chain
190
- rag_system.qa_chain = RetrievalQA.from_chain_type(
191
- llm=llm,
192
- chain_type="stuff",
193
- retriever=rag_system.vector_store.as_retriever(search_kwargs={"k": RETRIEVAL_K}),
194
- chain_type_kwargs={"prompt": prompt_template},
195
- callbacks=[rag_system.token_callback_handler],
196
- return_source_documents=True
197
- )
198
-
199
- rag_system.session_stats["initialization_complete"] = True
200
-
201
- return f"βœ… RAG system initialized successfully!\nπŸ“„ Document processed: {len(chunks)} chunks\nπŸ”’ Estimated embedding tokens: ~{initial_embedding_tokens}"
202
-
203
- except Exception as e:
204
- logger.error(f"Initialization failed: {str(e)}")
205
- return f"❌ Initialization failed: {str(e)}"
206
-
207
- def process_query(query, api_key):
208
- """Process a user query through the RAG system."""
209
- global rag_system
210
-
211
- if not api_key:
212
- return "❌ Please provide a Google API key first.", ""
213
-
214
- if not rag_system.qa_chain:
215
- return "❌ RAG system not initialized. Please initialize first.", ""
216
-
217
- if not query.strip():
218
- return "❌ Please enter a question.", ""
219
-
220
- try:
221
- # Estimate query embedding tokens
222
- query_tokens = estimate_tokens(query)
223
- rag_system.session_stats["total_embedding_tokens"] += query_tokens
224
- rag_system.session_stats["total_queries"] += 1
225
-
226
- # Process query
227
- start_time = time.time()
228
- result = rag_system.qa_chain({"query": query})
229
- end_time = time.time()
230
-
231
- # Get token usage
232
- llm_tokens = rag_system.token_callback_handler.get_total_tokens()
233
-
234
- # Format response
235
- answer = result['result']
236
-
237
- # Create stats summary
238
- stats = f"""
239
- πŸ“Š **Query Statistics:**
240
- - Response time: {end_time - start_time:.2f} seconds
241
- - Query tokens (estimated): ~{query_tokens}
242
- - LLM tokens (this query): Prompt: {llm_tokens['total_prompt_tokens']}, Completion: {llm_tokens['total_completion_tokens']}
243
-
244
- πŸ“ˆ **Session Statistics:**
245
- - Total queries: {rag_system.session_stats['total_queries']}
246
- - Total embedding tokens: ~{rag_system.session_stats['total_embedding_tokens']}
247
- - Total LLM calls: {llm_tokens['total_llm_calls']}
248
- - Total LLM tokens: {llm_tokens['total_llm_tokens']}
249
- """
250
-
251
- return answer, stats
252
-
253
- except Exception as e:
254
- logger.error(f"Error processing query: {str(e)}")
255
- return f"❌ Error processing query: {str(e)}", ""
256
-
257
- def upload_file_and_initialize(api_key, file):
258
- """Handle file upload and system initialization."""
259
- if not api_key:
260
- return "❌ Please provide a Google API key first."
261
-
262
- if file is None:
263
- return initialize_rag_system(api_key)
264
-
265
- try:
266
- # Read uploaded file
267
- file_content = file.decode('utf-8')
268
- return initialize_rag_system(api_key, file_content)
269
- except Exception as e:
270
- return f"❌ Error reading uploaded file: {str(e)}"
271
-
272
- def reset_session():
273
- """Reset the session statistics."""
274
- global rag_system
275
- rag_system.token_callback_handler.reset_counters()
276
- rag_system.session_stats = {
277
- "total_queries": 0,
278
- "total_embedding_tokens": 0,
279
- "initialization_complete": False
280
- }
281
- return "πŸ”„ Session statistics reset."
282
-
283
- # --- FastAPI Setup ---
284
- app = FastAPI(
285
- title="Maize RAG Q&A System API",
286
- description="API for the Maize Agriculture RAG Q&A System",
287
- version="1.0.0"
288
- )
289
-
290
- # Optional: Add API key authentication for API endpoints
291
- security = HTTPBearer(auto_error=False)
292
-
293
- async def get_api_key(credentials: HTTPAuthorizationCredentials = Depends(security)):
294
- """Extract API key from Authorization header (optional)"""
295
- if credentials:
296
- return credentials.credentials
297
- return None
298
-
299
- # --- API Endpoints ---
300
-
301
- @app.get("/")
302
- async def root():
303
- """Root endpoint"""
304
- return {"message": "Maize RAG Q&A System API", "status": "running"}
305
-
306
- @app.get("/health")
307
- async def health_check():
308
- """Health check endpoint"""
309
- return {
310
- "status": "healthy",
311
- "system_initialized": rag_system.session_stats["initialization_complete"]
312
- }
313
-
314
- @app.post("/initialize", response_model=InitializeResponse)
315
- async def initialize_system(request: InitializeRequest):
316
- """Initialize the RAG system"""
317
- try:
318
- result = initialize_rag_system(request.api_key, request.document_content)
319
-
320
- if "βœ…" in result:
321
- # Parse successful result
322
- lines = result.split('\n')
323
- chunks = None
324
- tokens = None
325
-
326
- for line in lines:
327
- if "chunks" in line:
328
- chunks = int(line.split(': ')[1].split(' ')[0])
329
- elif "tokens" in line:
330
- tokens = int(line.split('~')[1])
331
-
332
- return InitializeResponse(
333
- success=True,
334
- message=result,
335
- chunks=chunks,
336
- estimated_tokens=tokens
337
- )
338
- else:
339
- return InitializeResponse(
340
- success=False,
341
- message=result
342
- )
343
-
344
- except Exception as e:
345
- logger.error(f"API initialization error: {str(e)}")
346
- raise HTTPException(status_code=500, detail=str(e))
347
-
348
- @app.post("/query", response_model=QueryResponse)
349
- async def query_system(request: QueryRequest):
350
- """Query the RAG system"""
351
- try:
352
- if not rag_system.session_stats["initialization_complete"]:
353
- raise HTTPException(status_code=400, detail="System not initialized")
354
-
355
- # Estimate query embedding tokens
356
- query_tokens = estimate_tokens(request.query)
357
- rag_system.session_stats["total_embedding_tokens"] += query_tokens
358
- rag_system.session_stats["total_queries"] += 1
359
-
360
- # Process query
361
- start_time = time.time()
362
- result = rag_system.qa_chain({"query": request.query})
363
- end_time = time.time()
364
-
365
- # Get token usage
366
- llm_tokens = rag_system.token_callback_handler.get_total_tokens()
367
-
368
- response_time = end_time - start_time
369
-
370
- return QueryResponse(
371
- success=True,
372
- answer=result['result'],
373
- response_time=response_time,
374
- query_tokens=query_tokens,
375
- llm_tokens=llm_tokens,
376
- session_stats=rag_system.session_stats
377
- )
378
-
379
- except Exception as e:
380
- logger.error(f"API query error: {str(e)}")
381
- raise HTTPException(status_code=500, detail=str(e))
382
-
383
- @app.get("/stats", response_model=StatsResponse)
384
- async def get_stats():
385
- """Get current session statistics"""
386
- llm_tokens = rag_system.token_callback_handler.get_total_tokens()
387
-
388
- return StatsResponse(
389
- total_queries=rag_system.session_stats["total_queries"],
390
- total_embedding_tokens=rag_system.session_stats["total_embedding_tokens"],
391
- total_llm_tokens=llm_tokens["total_llm_tokens"],
392
- total_llm_calls=llm_tokens["total_llm_calls"],
393
- initialization_complete=rag_system.session_stats["initialization_complete"]
394
- )
395
-
396
- @app.post("/reset")
397
- async def reset_system():
398
- """Reset session statistics"""
399
- reset_session()
400
- return {"message": "Session reset successfully"}
401
-
402
- @app.post("/upload-document")
403
- async def upload_document(
404
- file: UploadFile = File(...),
405
- api_key: str = None
406
- ):
407
- """Upload a document and initialize the system"""
408
- try:
409
- if not api_key:
410
- raise HTTPException(status_code=400, detail="API key required")
411
-
412
- # Read uploaded file
413
- content = await file.read()
414
- file_content = content.decode('utf-8')
415
-
416
- # Initialize system with uploaded content
417
- result = initialize_rag_system(api_key, file_content)
418
-
419
- if "βœ…" in result:
420
- return {"success": True, "message": result}
421
- else:
422
- return {"success": False, "message": result}
423
-
424
- except Exception as e:
425
- logger.error(f"Document upload error: {str(e)}")
426
- raise HTTPException(status_code=500, detail=str(e))
427
-
428
- # Create Gradio interface
429
- def create_interface():
430
- with gr.Blocks(title="Maize RAG Q&A System", theme=gr.themes.Soft()) as demo:
431
- gr.Markdown("""
432
- # 🌽 Maize Agriculture RAG Q&A System
433
-
434
- This system uses Retrieval-Augmented Generation (RAG) to answer questions about maize agriculture.
435
- Upload your own document or use the default maize dataset.
436
- """)
437
-
438
- with gr.Row():
439
- with gr.Column(scale=2):
440
- api_key_input = gr.Textbox(
441
- label="πŸ”‘ Google API Key",
442
- placeholder="Enter your Google Generative AI API key",
443
- type="password",
444
- info="Get your API key from Google AI Studio"
445
- )
446
-
447
- with gr.Column(scale=1):
448
- reset_btn = gr.Button("πŸ”„ Reset Session", variant="secondary")
449
-
450
- with gr.Row():
451
- with gr.Column():
452
- file_upload = gr.File(
453
- label="πŸ“ Upload Document (Optional)",
454
- file_types=[".txt"],
455
- info="Upload a text file or use the default maize dataset"
456
- )
457
-
458
- init_btn = gr.Button("πŸš€ Initialize RAG System", variant="primary")
459
- init_output = gr.Textbox(
460
- label="πŸ“‹ Initialization Status",
461
- lines=3,
462
- interactive=False
463
- )
464
-
465
- gr.Markdown("## πŸ’¬ Ask Questions")
466
-
467
- with gr.Row():
468
- with gr.Column(scale=3):
469
- query_input = gr.Textbox(
470
- label="❓ Your Question",
471
- placeholder="Ask something about maize agriculture...",
472
- lines=2
473
- )
474
-
475
- # Sample questions
476
- sample_questions = [
477
- "What are the main pests affecting maize crops?",
478
- "How should maize be irrigated?",
479
- "What is the ideal soil type for maize?",
480
- "What are the nutritional requirements of maize?",
481
- "When is the best time to harvest maize?"
482
- ]
483
-
484
- gr.Examples(
485
- examples=sample_questions,
486
- inputs=query_input,
487
- label="πŸ’‘ Sample Questions"
488
- )
489
-
490
- with gr.Column(scale=1):
491
- submit_btn = gr.Button("πŸ” Ask", variant="primary")
492
-
493
- with gr.Row():
494
- with gr.Column(scale=2):
495
- answer_output = gr.Textbox(
496
- label="πŸ€– Answer",
497
- lines=6,
498
- interactive=False
499
- )
500
-
501
- with gr.Column(scale=1):
502
- stats_output = gr.Markdown(
503
- label="πŸ“Š Statistics",
504
- value="Statistics will appear here after queries."
505
- )
506
-
507
- # Event handlers
508
- init_btn.click(
509
- upload_file_and_initialize,
510
- inputs=[api_key_input, file_upload],
511
- outputs=init_output
512
- )
513
-
514
- submit_btn.click(
515
- process_query,
516
- inputs=[query_input, api_key_input],
517
- outputs=[answer_output, stats_output]
518
- )
519
-
520
- query_input.submit(
521
- process_query,
522
- inputs=[query_input, api_key_input],
523
- outputs=[answer_output, stats_output]
524
- )
525
-
526
- reset_btn.click(
527
- reset_session,
528
- outputs=init_output
529
- )
530
-
531
- gr.Markdown("""
532
- ## πŸ“ Instructions:
533
- 1. **Enter your Google API Key** (required)
534
- 2. **Upload a document** (optional - uses default maize dataset if not provided)
535
- 3. **Initialize the RAG system** by clicking "Initialize RAG System"
536
- 4. **Ask questions** about the document content
537
- 5. **View statistics** to monitor token usage and costs
538
-
539
- ## πŸ’° Cost Information:
540
- - **Gemini 1.5 Flash**: Input: $0.075/1M tokens, Output: $0.30/1M tokens
541
- - **Embedding Model**: $0.025/1M tokens
542
-
543
- Token usage is estimated and displayed for cost tracking.
544
- """)
545
-
546
- return demo
547
-
548
- # Create and launch the interface
549
- def run_gradio():
550
- """Run Gradio interface"""
551
- demo = create_interface()
552
- demo.launch(
553
- server_name="0.0.0.0",
554
- server_port=7860,
555
- show_error=True,
556
- quiet=True # Reduce Gradio logs in combined mode
557
- )
558
-
559
- def run_fastapi():
560
- """Run FastAPI server"""
561
- uvicorn.run(
562
- app,
563
- host="0.0.0.0",
564
- port=8000,
565
- log_level="info"
566
- )
567
-
568
- if __name__ == "__main__":
569
- import sys
570
-
571
- if len(sys.argv) > 1:
572
- mode = sys.argv[1]
573
-
574
- if mode == "api":
575
- # Run only FastAPI
576
- print("Starting FastAPI server on port 8000...")
577
- run_fastapi()
578
- elif mode == "gradio":
579
- # Run only Gradio
580
- print("Starting Gradio interface on port 7860...")
581
- run_gradio()
582
- elif mode == "both":
583
- # Run both servers
584
- print("Starting both FastAPI (port 8000) and Gradio (port 7860)...")
585
-
586
- # Start FastAPI in a separate thread
587
- fastapi_thread = threading.Thread(target=run_fastapi)
588
- fastapi_thread.daemon = True
589
- fastapi_thread.start()
590
-
591
- # Start Gradio in main thread
592
- time.sleep(2) # Give FastAPI time to start
593
- run_gradio()
594
- else:
595
- print("Usage: python app.py [api|gradio|both]")
596
- print("Default: gradio only")
597
- run_gradio()
598
- else:
599
- # Default: run only Gradio (for Hugging Face Spaces compatibility)
600
- print("Starting Gradio interface on port 7860...")
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
601
  run_gradio()
 
1
+ import os
2
+ import time
3
+ import gradio as gr
4
+ import uvicorn
5
+ from fastapi import FastAPI, HTTPException, Depends, File, UploadFile
6
+ from fastapi.security import HTTPBearer, HTTPAuthorizationCredentials
7
+ from pydantic import BaseModel
8
+ from typing import Optional, Dict, Any
9
+ import threading
10
+ import logging
11
+ from langchain_community.document_loaders import TextLoader
12
+ from langchain.text_splitter import RecursiveCharacterTextSplitter
13
+ from langchain_community.vectorstores import FAISS
14
+ from langchain.chains import RetrievalQA
15
+ from langchain.prompts import PromptTemplate
16
+ from langchain.callbacks.base import BaseCallbackHandler
17
+ from langchain_google_genai import ChatGoogleGenerativeAI, GoogleGenerativeAIEmbeddings
18
+ import tiktoken
19
+
20
+ # Configure logging
21
+ logging.basicConfig(level=logging.INFO)
22
+ logger = logging.getLogger(__name__)
23
+
24
+ # --- Configuration ---
25
+ CHUNK_SIZE = 800
26
+ CHUNK_OVERLAP = 100
27
+ MAX_TOKENS = 512
28
+ TEMPERATURE = 0.5
29
+ RETRIEVAL_K = 5
30
+
31
+ # --- Token Counting Setup ---
32
+ try:
33
+ tokenizer = tiktoken.get_encoding("cl100k_base")
34
+ except:
35
+ print("Tiktoken encoder 'cl100k_base' not found. Using basic split().")
36
+ tokenizer = type('obj', (object,), {'encode': lambda x: x.split()})()
37
+
38
+ def estimate_tokens(text):
39
+ """Estimates token count for a given text."""
40
+ return len(tokenizer.encode(text))
41
+
42
+ # Custom Callback Handler to track LLM token usage
43
+ class TokenUsageCallbackHandler(BaseCallbackHandler):
44
+ """Callback handler to track token usage in LLM calls."""
45
+ def __init__(self):
46
+ super().__init__()
47
+ self.reset_counters()
48
+
49
+ def reset_counters(self):
50
+ self.total_prompt_tokens = 0
51
+ self.total_completion_tokens = 0
52
+ self.total_llm_calls = 0
53
+
54
+ def on_llm_end(self, response, **kwargs):
55
+ """Collect token usage from the LLM response."""
56
+ self.total_llm_calls += 1
57
+ llm_output = response.llm_output
58
+
59
+ if llm_output and 'usage_metadata' in llm_output:
60
+ usage = llm_output['usage_metadata']
61
+ prompt_tokens = usage.get('prompt_token_count', 0)
62
+ completion_tokens = usage.get('candidates_token_count', 0)
63
+
64
+ self.total_prompt_tokens += prompt_tokens
65
+ self.total_completion_tokens += completion_tokens
66
+
67
+ def get_total_tokens(self):
68
+ """Returns the total prompt and completion tokens."""
69
+ return {
70
+ "total_prompt_tokens": self.total_prompt_tokens,
71
+ "total_completion_tokens": self.total_completion_tokens,
72
+ "total_llm_tokens": self.total_prompt_tokens + self.total_completion_tokens,
73
+ "total_llm_calls": self.total_llm_calls
74
+ }
75
+
76
+ # --- Pydantic Models for API ---
77
+ class InitializeRequest(BaseModel):
78
+ api_key: str
79
+ document_content: Optional[str] = None
80
+
81
+ class QueryRequest(BaseModel):
82
+ query: str
83
+ api_key: str
84
+
85
+ class InitializeResponse(BaseModel):
86
+ success: bool
87
+ message: str
88
+ chunks: Optional[int] = None
89
+ estimated_tokens: Optional[int] = None
90
+
91
+ class QueryResponse(BaseModel):
92
+ success: bool
93
+ answer: str
94
+ response_time: float
95
+ query_tokens: int
96
+ llm_tokens: Dict[str, int]
97
+ session_stats: Dict[str, int]
98
+
99
+ class StatsResponse(BaseModel):
100
+ total_queries: int
101
+ total_embedding_tokens: int
102
+ total_llm_tokens: int
103
+ total_llm_calls: int
104
+ initialization_complete: bool
105
+
106
+ # --- Global Variables ---
107
+ class RAGSystem:
108
+ def __init__(self):
109
+ self.vector_store = None
110
+ self.qa_chain = None
111
+ self.token_callback_handler = TokenUsageCallbackHandler()
112
+ self.session_stats = {
113
+ "total_queries": 0,
114
+ "total_embedding_tokens": 0,
115
+ "initialization_complete": False
116
+ }
117
+ self.current_api_key = None
118
+
119
+ # Global RAG system instance
120
+ rag_system = RAGSystem()
121
+
122
+ def initialize_rag_system(api_key, file_content=None):
123
+ """Initialize the RAG system with API key and optional file content."""
124
+ global rag_system
125
+
126
+ try:
127
+ # Set API key
128
+ os.environ["GOOGLE_API_KEY"] = api_key
129
+ rag_system.current_api_key = api_key
130
+
131
+ # Initialize embeddings
132
+ embeddings = GoogleGenerativeAIEmbeddings(
133
+ model="models/embedding-001",
134
+ google_api_key=api_key
135
+ )
136
+
137
+ # Initialize LLM
138
+ llm = ChatGoogleGenerativeAI(
139
+ model="gemini-1.5-flash",
140
+ google_api_key=api_key,
141
+ temperature=TEMPERATURE,
142
+ max_tokens=MAX_TOKENS,
143
+ callbacks=[rag_system.token_callback_handler],
144
+ verbose=False
145
+ )
146
+
147
+ # Load or use default document
148
+ if file_content:
149
+ # Save uploaded file content
150
+ with open("uploaded_document.txt", "w", encoding="utf-8") as f:
151
+ f.write(file_content)
152
+ loader = TextLoader("uploaded_document.txt")
153
+ else:
154
+ # Check if default maize_data.txt exists
155
+ if os.path.exists("maize_data.txt"):
156
+ loader = TextLoader("maize_data.txt")
157
+ else:
158
+ return "❌ No document found. Please upload a file or ensure maize_data.txt exists."
159
+
160
+ # Load and split documents
161
+ documents = loader.load()
162
+ text_splitter = RecursiveCharacterTextSplitter(
163
+ chunk_size=CHUNK_SIZE,
164
+ chunk_overlap=CHUNK_OVERLAP
165
+ )
166
+ chunks = text_splitter.split_documents(documents)
167
+
168
+ # Estimate embedding tokens
169
+ initial_embedding_tokens = sum(estimate_tokens(chunk.page_content) for chunk in chunks)
170
+ rag_system.session_stats["total_embedding_tokens"] = initial_embedding_tokens
171
+
172
+ # Create vector store
173
+ rag_system.vector_store = FAISS.from_documents(chunks, embeddings)
174
+
175
+ # Create prompt template
176
+ prompt_template = PromptTemplate(
177
+ input_variables=["context", "question"],
178
+ template="""
179
+ You are an expert in maize agriculture. Use the following context ONLY to answer the question accurately and helpfully. If the context doesn't contain the answer, say "Based on the provided context, I cannot answer this question.".
180
+
181
+ Context:
182
+ {context}
183
+
184
+ Question: {question}
185
+
186
+ Answer:"""
187
+ )
188
+
189
+ # Set up QA chain
190
+ rag_system.qa_chain = RetrievalQA.from_chain_type(
191
+ llm=llm,
192
+ chain_type="stuff",
193
+ retriever=rag_system.vector_store.as_retriever(search_kwargs={"k": RETRIEVAL_K}),
194
+ chain_type_kwargs={"prompt": prompt_template},
195
+ callbacks=[rag_system.token_callback_handler],
196
+ return_source_documents=True
197
+ )
198
+
199
+ rag_system.session_stats["initialization_complete"] = True
200
+
201
+ return f"βœ… RAG system initialized successfully!\nπŸ“„ Document processed: {len(chunks)} chunks\nπŸ”’ Estimated embedding tokens: ~{initial_embedding_tokens}"
202
+
203
+ except Exception as e:
204
+ logger.error(f"Initialization failed: {str(e)}")
205
+ return f"❌ Initialization failed: {str(e)}"
206
+
207
+ def process_query(query, api_key):
208
+ """Process a user query through the RAG system."""
209
+ global rag_system
210
+
211
+ if not api_key:
212
+ return "❌ Please provide a Google API key first.", ""
213
+
214
+ if not rag_system.qa_chain:
215
+ return "❌ RAG system not initialized. Please initialize first.", ""
216
+
217
+ if not query.strip():
218
+ return "❌ Please enter a question.", ""
219
+
220
+ try:
221
+ # Estimate query embedding tokens
222
+ query_tokens = estimate_tokens(query)
223
+ rag_system.session_stats["total_embedding_tokens"] += query_tokens
224
+ rag_system.session_stats["total_queries"] += 1
225
+
226
+ # Process query
227
+ start_time = time.time()
228
+ result = rag_system.qa_chain({"query": query})
229
+ end_time = time.time()
230
+
231
+ # Get token usage
232
+ llm_tokens = rag_system.token_callback_handler.get_total_tokens()
233
+
234
+ # Format response
235
+ answer = result['result']
236
+
237
+ # Create stats summary
238
+ stats = f"""
239
+ πŸ“Š **Query Statistics:**
240
+ - Response time: {end_time - start_time:.2f} seconds
241
+ - Query tokens (estimated): ~{query_tokens}
242
+ - LLM tokens (this query): Prompt: {llm_tokens['total_prompt_tokens']}, Completion: {llm_tokens['total_completion_tokens']}
243
+
244
+ πŸ“ˆ **Session Statistics:**
245
+ - Total queries: {rag_system.session_stats['total_queries']}
246
+ - Total embedding tokens: ~{rag_system.session_stats['total_embedding_tokens']}
247
+ - Total LLM calls: {llm_tokens['total_llm_calls']}
248
+ - Total LLM tokens: {llm_tokens['total_llm_tokens']}
249
+ """
250
+
251
+ return answer, stats
252
+
253
+ except Exception as e:
254
+ logger.error(f"Error processing query: {str(e)}")
255
+ return f"❌ Error processing query: {str(e)}", ""
256
+
257
+ def upload_file_and_initialize(api_key, file):
258
+ """Handle file upload and system initialization."""
259
+ if not api_key:
260
+ return "❌ Please provide a Google API key first."
261
+
262
+ if file is None:
263
+ return initialize_rag_system(api_key)
264
+
265
+ try:
266
+ # Handle different file object types based on Gradio version
267
+ if hasattr(file, 'name'):
268
+ # Newer Gradio versions - file has .name attribute
269
+ with open(file.name, 'r', encoding='utf-8') as f:
270
+ file_content = f.read()
271
+ elif isinstance(file, str):
272
+ # File path as string
273
+ with open(file, 'r', encoding='utf-8') as f:
274
+ file_content = f.read()
275
+ elif hasattr(file, 'read'):
276
+ # File-like object
277
+ file_content = file.read()
278
+ if isinstance(file_content, bytes):
279
+ file_content = file_content.decode('utf-8')
280
+ else:
281
+ # Fallback - try to read as bytes and decode
282
+ file_content = file.decode('utf-8') if isinstance(file, bytes) else str(file)
283
+
284
+ return initialize_rag_system(api_key, file_content)
285
+
286
+ except Exception as e:
287
+ logger.error(f"Error reading uploaded file: {str(e)}")
288
+ return f"❌ Error reading uploaded file: {str(e)}"
289
+
290
+ def reset_session():
291
+ """Reset the session statistics."""
292
+ global rag_system
293
+ rag_system.token_callback_handler.reset_counters()
294
+ rag_system.session_stats = {
295
+ "total_queries": 0,
296
+ "total_embedding_tokens": 0,
297
+ "initialization_complete": False
298
+ }
299
+ return "πŸ”„ Session statistics reset."
300
+
301
+ # --- FastAPI Setup ---
302
+ app = FastAPI(
303
+ title="Maize RAG Q&A System API",
304
+ description="API for the Maize Agriculture RAG Q&A System",
305
+ version="1.0.0"
306
+ )
307
+
308
+ # Optional: Add API key authentication for API endpoints
309
+ security = HTTPBearer(auto_error=False)
310
+
311
+ async def get_api_key(credentials: HTTPAuthorizationCredentials = Depends(security)):
312
+ """Extract API key from Authorization header (optional)"""
313
+ if credentials:
314
+ return credentials.credentials
315
+ return None
316
+
317
+ # --- API Endpoints ---
318
+
319
+ @app.get("/")
320
+ async def root():
321
+ """Root endpoint"""
322
+ return {"message": "Maize RAG Q&A System API", "status": "running"}
323
+
324
+ @app.get("/health")
325
+ async def health_check():
326
+ """Health check endpoint"""
327
+ return {
328
+ "status": "healthy",
329
+ "system_initialized": rag_system.session_stats["initialization_complete"]
330
+ }
331
+
332
+ @app.post("/initialize", response_model=InitializeResponse)
333
+ async def initialize_system(request: InitializeRequest):
334
+ """Initialize the RAG system"""
335
+ try:
336
+ result = initialize_rag_system(request.api_key, request.document_content)
337
+
338
+ if "βœ…" in result:
339
+ # Parse successful result
340
+ lines = result.split('\n')
341
+ chunks = None
342
+ tokens = None
343
+
344
+ for line in lines:
345
+ if "chunks" in line:
346
+ chunks = int(line.split(': ')[1].split(' ')[0])
347
+ elif "tokens" in line:
348
+ tokens = int(line.split('~')[1])
349
+
350
+ return InitializeResponse(
351
+ success=True,
352
+ message=result,
353
+ chunks=chunks,
354
+ estimated_tokens=tokens
355
+ )
356
+ else:
357
+ return InitializeResponse(
358
+ success=False,
359
+ message=result
360
+ )
361
+
362
+ except Exception as e:
363
+ logger.error(f"API initialization error: {str(e)}")
364
+ raise HTTPException(status_code=500, detail=str(e))
365
+
366
@app.post("/query", response_model=QueryResponse)
async def query_system(request: QueryRequest):
    """Answer a question with the RAG chain and report timing and token usage.

    The estimated token count of the query is added to the session's
    embedding-token total and the query counter is incremented before the
    chain is invoked.

    Raises:
        HTTPException: 400 if the system has not been initialized;
            500 if processing the query fails.
    """
    try:
        if not rag_system.session_stats["initialization_complete"]:
            raise HTTPException(status_code=400, detail="System not initialized")

        # Estimate query embedding tokens
        query_tokens = estimate_tokens(request.query)
        rag_system.session_stats["total_embedding_tokens"] += query_tokens
        rag_system.session_stats["total_queries"] += 1

        # Process query
        start_time = time.time()
        result = rag_system.qa_chain({"query": request.query})
        response_time = time.time() - start_time

        # Get token usage
        llm_tokens = rag_system.token_callback_handler.get_total_tokens()

        return QueryResponse(
            success=True,
            answer=result['result'],
            response_time=response_time,
            query_tokens=query_tokens,
            llm_tokens=llm_tokens,
            session_stats=rag_system.session_stats
        )

    except HTTPException:
        # Deliberate HTTP errors (e.g. the 400 above) must reach the client
        # unchanged; the generic handler below would rewrite them as 500.
        raise
    except Exception as e:
        logger.error(f"API query error: {str(e)}")
        raise HTTPException(status_code=500, detail=str(e)) from e
400
+
401
@app.get("/stats", response_model=StatsResponse)
async def get_stats():
    """Return query, embedding-token and LLM usage counters for the session."""
    usage = rag_system.token_callback_handler.get_total_tokens()
    session = rag_system.session_stats

    return StatsResponse(
        total_queries=session["total_queries"],
        total_embedding_tokens=session["total_embedding_tokens"],
        total_llm_tokens=usage["total_llm_tokens"],
        total_llm_calls=usage["total_llm_calls"],
        initialization_complete=session["initialization_complete"],
    )
413
+
414
@app.post("/reset")
async def reset_system():
    """Invoke reset_session() and acknowledge the reset."""
    reset_session()
    acknowledgement = {"message": "Session reset successfully"}
    return acknowledgement
419
+
420
@app.post("/upload-document")
async def upload_document(
    file: UploadFile = File(...),
    api_key: str = None
):
    """Initialize the RAG system from an uploaded UTF-8 text document.

    Returns a dict with ``success`` (whether the status message contains the
    success marker) and ``message`` (the status message itself).

    Raises:
        HTTPException: 400 if ``api_key`` is missing or the upload is not
            valid UTF-8; 500 if initialization fails unexpectedly.
    """
    try:
        if not api_key:
            raise HTTPException(status_code=400, detail="API key required")

        # Read uploaded file
        content = await file.read()
        try:
            file_content = content.decode('utf-8')
        except UnicodeDecodeError as e:
            # An undecodable upload is a client error, not a server fault.
            raise HTTPException(
                status_code=400,
                detail="Uploaded file must be UTF-8 encoded text"
            ) from e

        # Initialize system with uploaded content
        result = initialize_rag_system(api_key, file_content)

        # NOTE(review): the marker looks like a mis-encoded check mark; confirm
        # it matches what initialize_rag_system actually emits.
        return {"success": "βœ…" in result, "message": result}

    except HTTPException:
        # Keep intentional 4xx responses intact instead of rewrapping as 500.
        raise
    except Exception as e:
        logger.error(f"Document upload error: {str(e)}")
        raise HTTPException(status_code=500, detail=str(e)) from e
445
+
446
+ # Create Gradio interface with version compatibility
447
def create_interface():
    """Build and return the Gradio Blocks UI for the RAG Q&A system.

    The UI wires the initialize button to ``upload_file_and_initialize``, the
    ask button and the question box's submit event to ``process_query``, and
    the reset button to ``reset_session``.

    Returns:
        The constructed ``gr.Blocks`` instance (not yet launched).
    """
    # NOTE(review): the nesting of rows/columns below was reconstructed from a
    # source whose indentation was lost - confirm the layout against the app.
    import gradio as gr

    with gr.Blocks(title="Maize RAG Q&A System", theme=gr.themes.Soft()) as demo:
        gr.Markdown("""
        # 🌽 Maize Agriculture RAG Q&A System

        This system uses Retrieval-Augmented Generation (RAG) to answer questions about maize agriculture.
        Upload your own document or use the default maize dataset.
        """)

        with gr.Row():
            with gr.Column(scale=2):
                api_key_input = gr.Textbox(
                    label="πŸ”‘ Google API Key",
                    placeholder="Enter your Google Generative AI API key",
                    type="password"
                )
                gr.Markdown("Get your API key from Google AI Studio")

            with gr.Column(scale=1):
                reset_btn = gr.Button("πŸ”„ Reset Session", variant="secondary")

        with gr.Row():
            with gr.Column():
                file_upload = gr.File(
                    label="πŸ“ Upload Document (Optional)",
                    file_types=[".txt"]
                )
                gr.Markdown("Upload a text file or use the default maize dataset")

        init_btn = gr.Button("πŸš€ Initialize RAG System", variant="primary")
        init_output = gr.Textbox(
            label="πŸ“‹ Initialization Status",
            lines=3,
            interactive=False
        )

        gr.Markdown("## πŸ’¬ Ask Questions")

        with gr.Row():
            with gr.Column(scale=3):
                query_input = gr.Textbox(
                    label="❓ Your Question",
                    placeholder="Ask something about maize agriculture...",
                    lines=2
                )

                # Sample questions
                sample_questions = [
                    "What are the main pests affecting maize crops?",
                    "How should maize be irrigated?",
                    "What is the ideal soil type for maize?",
                    "What are the nutritional requirements of maize?",
                    "When is the best time to harvest maize?"
                ]

                # Use Examples component if available, otherwise just show as markdown
                try:
                    gr.Examples(
                        examples=sample_questions,
                        inputs=query_input,
                        label="πŸ’‘ Sample Questions"
                    )
                except Exception:
                    # Narrowed from a bare ``except`` so KeyboardInterrupt and
                    # SystemExit are no longer swallowed by the fallback.
                    gr.Markdown("πŸ’‘ **Sample Questions:**\n" +
                                "\n".join([f"- {q}" for q in sample_questions]))

            with gr.Column(scale=1):
                submit_btn = gr.Button("πŸ” Ask", variant="primary")

        with gr.Row():
            with gr.Column(scale=2):
                answer_output = gr.Textbox(
                    label="πŸ€– Answer",
                    lines=6,
                    interactive=False
                )

            with gr.Column(scale=1):
                stats_output = gr.Markdown(
                    value="πŸ“Š Statistics will appear here after queries."
                )

        # Event handlers
        init_btn.click(
            upload_file_and_initialize,
            inputs=[api_key_input, file_upload],
            outputs=init_output
        )

        submit_btn.click(
            process_query,
            inputs=[query_input, api_key_input],
            outputs=[answer_output, stats_output]
        )

        query_input.submit(
            process_query,
            inputs=[query_input, api_key_input],
            outputs=[answer_output, stats_output]
        )

        reset_btn.click(
            reset_session,
            outputs=init_output
        )

        gr.Markdown("""
        ## πŸ“ Instructions:
        1. **Enter your Google API Key** (required)
        2. **Upload a document** (optional - uses default maize dataset if not provided)
        3. **Initialize the RAG system** by clicking "Initialize RAG System"
        4. **Ask questions** about the document content
        5. **View statistics** to monitor token usage and costs

        ## πŸ’° Cost Information:
        - **Gemini 1.5 Flash**: Input: $0.075/1M tokens, Output: $0.30/1M tokens
        - **Embedding Model**: $0.025/1M tokens

        Token usage is estimated and displayed for cost tracking.

        ## πŸ”— API Access:
        This system also provides REST API endpoints:
        - **API Docs**: Add `/docs` to the URL for interactive API documentation
        - **Health Check**: `GET /health`
        - **Initialize**: `POST /initialize`
        - **Query**: `POST /query`
        """)

    return demo
580
+
581
+ # Create and launch the interface
582
def run_gradio():
    """Build the Gradio UI and serve it on 0.0.0.0:7860."""
    launch_options = {
        "server_name": "0.0.0.0",
        "server_port": 7860,
        "show_error": True,
        "quiet": True,  # Reduce Gradio logs in combined mode
    }
    create_interface().launch(**launch_options)
591
+
592
def run_fastapi():
    """Serve the FastAPI app with uvicorn on 0.0.0.0:8000."""
    server_options = {"host": "0.0.0.0", "port": 8000, "log_level": "info"}
    uvicorn.run(app, **server_options)
600
+
601
if __name__ == "__main__":
    import sys

    # With no argument the script behaves exactly like "gradio"
    # (for Hugging Face Spaces compatibility).
    cli_args = sys.argv[1:]
    selected_mode = cli_args[0] if cli_args else "gradio"

    if selected_mode == "api":
        # Run only FastAPI
        print("Starting FastAPI server on port 8000...")
        run_fastapi()
    elif selected_mode == "both":
        # Run both servers
        print("Starting both FastAPI (port 8000) and Gradio (port 7860)...")

        # FastAPI runs in a daemon thread; Gradio occupies the main thread
        api_worker = threading.Thread(target=run_fastapi, daemon=True)
        api_worker.start()

        time.sleep(2)  # Give FastAPI time to start
        run_gradio()
    elif selected_mode == "gradio":
        # Run only Gradio
        print("Starting Gradio interface on port 7860...")
        run_gradio()
    else:
        # Unrecognized mode: show usage, then fall back to Gradio
        print("Usage: python app.py [api|gradio|both]")
        print("Default: gradio only")
        run_gradio()