Commit cc9c554 · big-query
Parent(s): 48b7a28
Files changed:
- app.py +112 -4
- bigquery_uploader.py +112 -0
- knowledge_base.py +215 -44
- local_database.py +57 -0
- requirements.txt +2 -1
app.py
CHANGED
@@ -19,11 +19,13 @@ os.environ["TORCH_COMPILE_DISABLE"] = "1" # Ensure torch compile is off

 # --- Step 1: Import Core Components from Modules ---
 from vision_model import load_vision_model
-from knowledge_base import
+from knowledge_base import KnowledgeBase
 from agent_setup import initialize_adk
 from google.genai import types
 from story_generator import create_story_prompt_from_pdf, generate_video_from_prompt
 from langchain_huggingface import HuggingFaceEndpoint
+from bigquery_uploader import upload_to_bigquery
+import local_database

 print("✅ All libraries imported successfully.")

@@ -32,7 +34,8 @@ print("✅ All libraries imported successfully.")

 print("Performing initial setup...")
 VISION_MODEL, PROCESSOR = load_vision_model()
+KB = KnowledgeBase()
+RETRIEVER = KB  # The retriever is now the KB itself

 # Initialize ADK components for Connected Mode
 adk_components = initialize_adk(VISION_MODEL, PROCESSOR, RETRIEVER)
@@ -62,7 +65,6 @@ else:

 def create_field_mode_ui():
     """Creates the Gradio UI for the offline Field Mode."""
-    # ... (This function remains unchanged) ...
     def get_diagnosis_and_remedy(uploaded_image: Image.Image) -> str:
         if uploaded_image is None:
             return "Please upload an image of a maize plant first."
@@ -81,7 +83,23 @@ def create_field_mode_ui():
         if "Could not parse" in diagnosis:
             return f"Sorry, I couldn't identify the condition from the image. Raw output: {diagnosis}"

+        if "Healthy" in diagnosis:
+            return """## Diagnosis Report
+
+**Condition Identified:**
+### Healthy Maize Plant
+
+---
+
+## Suggested Remedy
+
+The plant appears to be healthy. No remedy is required. Continue with good farming practices. You can find recipes for enjoying your healthy maize in the 'Knowledge Base' tab."""
+
+        results = KB.search(diagnosis)
+        if not results:
+            remedy = "No remedy found in the local knowledge base."
+        else:
+            remedy = results[0]['content']

         final_response = f"""
 ## Diagnosis Report
@@ -286,6 +304,79 @@ def create_story_mode_ui():
     )
     return demo

+def create_settings_ui():
+    """Creates the Gradio UI for the Settings tab."""
+    with gr.Blocks(theme=gr.themes.Soft(primary_hue="gray", secondary_hue="blue")) as demo:
+        gr.Markdown("# ⚙️ Settings & Data Management")
+        gr.Markdown("Manage application settings and data synchronization.")
+
+        with gr.Row():
+            with gr.Column():
+                sync_btn = gr.Button("☁️ Sync Local Data to BigQuery Cloud")
+                status_output = gr.Textbox(label="Sync Status", interactive=False, lines=5)
+
+        def sync_data_to_cloud():
+            yield "Attempting to sync local diagnosis data to BigQuery..."
+            try:
+                # Assuming your bigquery_uploader has a function that returns a summary
+                result_message = upload_to_bigquery()
+                yield f"Sync successful!\n{result_message}"
+            except Exception as e:
+                yield f"Sync failed!\nError: {e}"
+
+        sync_btn.click(
+            sync_data_to_cloud,
+            inputs=[],
+            outputs=[status_output]
+        )
+    return demo
+
+def create_kb_management_ui():
+    """Creates the Gradio UI for managing the knowledge base."""
+    with gr.Blocks(theme=gr.themes.Soft(primary_hue="indigo", secondary_hue="purple")) as demo:
+        gr.Markdown("# 📚 Knowledge Base Management")
+        gr.Markdown("Manage the local, encrypted knowledge base.")
+
+        with gr.Row():
+            with gr.Column():
+                gr.Markdown("### Rebuild Knowledge Base")
+                rebuild_btn = gr.Button("Rebuild from Source Files")
+                rebuild_status = gr.Textbox(label="Status", interactive=False)
+
+            with gr.Column():
+                gr.Markdown("### Add PDF to Knowledge Base")
+                pdf_input = gr.File(label="Upload PDF", file_types=[".pdf"])
+                ingest_btn = gr.Button("Ingest PDF")
+                ingest_status = gr.Textbox(label="Status", interactive=False)
+
+        def rebuild_kb():
+            yield "Rebuilding knowledge base..."
+            try:
+                docs = {}
+                for filename in os.listdir("knowledge_base_data"):
+                    if filename.endswith(".txt"):
+                        with open(os.path.join("knowledge_base_data", filename)) as f:
+                            docs[filename] = f.read()
+                KB.create_initial_index(docs)
+                yield "Knowledge base rebuilt successfully."
+            except Exception as e:
+                yield f"Error rebuilding knowledge base: {e}"
+
+        def ingest_pdf(pdf):
+            if pdf is None:
+                return "Please upload a PDF file."
+            yield "Ingesting PDF..."
+            try:
+                KB.ingest_pdf(pdf.name, os.path.basename(pdf.name))
+                yield f"Successfully ingested {os.path.basename(pdf.name)}."
+            except Exception as e:
+                yield f"Error ingesting PDF: {e}"
+
+        rebuild_btn.click(rebuild_kb, outputs=[rebuild_status])
+        ingest_btn.click(ingest_pdf, inputs=[pdf_input], outputs=[ingest_status])
+
+    return demo
+
 # --- Step 4: App Launcher ---

 def check_internet_connection(host="8.8.8.8", port=53, timeout=3):
@@ -299,6 +390,13 @@ def check_internet_connection(host="8.8.8.8", port=53, timeout=3):


 if __name__ == "__main__":
+    # Initialize local database
+    conn = local_database.get_db_connection()
+    if conn is not None:
+        local_database.init_db()
+        conn.close()
+    else:
+        print("❌ Could not create a connection to the local database.")
     field_mode_ui = create_field_mode_ui()
     interface_list = [field_mode_ui]
     tab_titles = ["Field Mode (Offline)"]
@@ -323,6 +421,16 @@ if __name__ == "__main__":
         document_analysis_ui = create_document_analysis_ui()
         interface_list.append(document_analysis_ui)
         tab_titles.append("Document Analysis")
+
+        # Add the Settings UI
+        settings_ui = create_settings_ui()
+        interface_list.append(settings_ui)
+        tab_titles.append("Settings")
+
+        # Add the Knowledge Base Management UI
+        kb_management_ui = create_kb_management_ui()
+        interface_list.append(kb_management_ui)
+        tab_titles.append("Knowledge Base")

     else:
         print("❌ No internet connection. Launching in Offline Mode only.")
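The launch step is outside these hunks, but `interface_list` and `tab_titles` suggest the tabs are assembled with Gradio's `TabbedInterface`. A minimal sketch of that assumption (not shown in the diff; the launch call is illustrative):

```python
# Sketch only: how the collected UIs and titles are presumably combined.
# gr.TabbedInterface is standard Gradio; launch options are placeholders.
demo = gr.TabbedInterface(interface_list, tab_names=tab_titles)
demo.launch()
```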
bigquery_uploader.py
ADDED
@@ -0,0 +1,112 @@
+# This file will contain the logic for uploading data to BigQuery.
+
+from google.cloud import bigquery
+from google.cloud.exceptions import NotFound
+import local_database
+
+PROJECT_ID = "gem-creation"
+DATASET_ID = "aura_mind_glow_data"
+TABLE_ID = "farm_analysis"
+
+def get_bigquery_client():
+    """Returns an authenticated BigQuery client."""
+    try:
+        client = bigquery.Client(project=PROJECT_ID)
+        print("Successfully authenticated with BigQuery.")
+        return client
+    except Exception as e:
+        print(f"Error authenticating with BigQuery: {e}")
+        return None
+
+def create_dataset_if_not_exists(client):
+    """Creates the BigQuery dataset if it doesn't exist."""
+    dataset_id = f"{PROJECT_ID}.{DATASET_ID}"
+    try:
+        client.get_dataset(dataset_id)  # Make an API request.
+        print(f"Dataset {dataset_id} already exists.")
+    except NotFound:
+        print(f"Dataset {dataset_id} is not found. Creating dataset...")
+        dataset = bigquery.Dataset(dataset_id)
+        dataset.location = "US"
+        dataset = client.create_dataset(dataset, timeout=30)  # Make an API request.
+        print(f"Created dataset {client.project}.{dataset.dataset_id}")
+
+
+def create_table_if_not_exists(client):
+    """Creates the BigQuery table if it doesn't exist."""
+    table_id = f"{PROJECT_ID}.{DATASET_ID}.{TABLE_ID}"
+    try:
+        client.get_table(table_id)  # Make an API request.
+        print(f"Table {table_id} already exists.")
+    except NotFound:
+        print(f"Table {table_id} is not found. Creating table...")
+        schema = [
+            bigquery.SchemaField("analysis_id", "STRING", mode="REQUIRED"),
+            bigquery.SchemaField("timestamp", "TIMESTAMP", mode="REQUIRED"),
+            bigquery.SchemaField("farmer_id", "STRING", mode="NULLABLE"),
+            bigquery.SchemaField("gps_latitude", "FLOAT", mode="NULLABLE"),
+            bigquery.SchemaField("gps_longitude", "FLOAT", mode="NULLABLE"),
+            bigquery.SchemaField("crop_type", "STRING", mode="NULLABLE"),
+            bigquery.SchemaField("crop_variety", "STRING", mode="NULLABLE"),
+            bigquery.SchemaField("ai_diagnosis", "STRING", mode="NULLABLE"),
+            bigquery.SchemaField("confidence_score", "FLOAT", mode="NULLABLE"),
+            bigquery.SchemaField("recommended_action", "STRING", mode="NULLABLE"),
+            bigquery.SchemaField("farmer_feedback", "STRING", mode="NULLABLE"),
+            bigquery.SchemaField("treatment_applied", "STRING", mode="NULLABLE"),
+            bigquery.SchemaField("outcome_image_id", "STRING", mode="NULLABLE"),
+        ]
+        table = bigquery.Table(table_id, schema=schema)
+        table = client.create_table(table)  # Make an API request.
+        print(f"Created table {table.project}.{table.dataset_id}.{table.table_id}")
+
+
+def upload_data_from_local_db():
+    """Uploads data from the local SQLite database to BigQuery."""
+    conn = local_database.create_connection()
+    if conn is None:
+        print("Could not connect to the local database.")
+        return
+
+    rows = local_database.get_all_analysis(conn)
+    if not rows:
+        print("No data to upload from the local database.")
+        conn.close()
+        return
+
+    client = get_bigquery_client()
+    if client is None:
+        conn.close()
+        return
+
+    create_dataset_if_not_exists(client)
+    create_table_if_not_exists(client)
+
+    table_id = f"{PROJECT_ID}.{DATASET_ID}.{TABLE_ID}"
+    # Convert rows to list of dictionaries
+    rows_to_insert = []
+    for row in rows:
+        rows_to_insert.append({
+            "analysis_id": row[0],
+            "timestamp": row[1],
+            "farmer_id": row[2],
+            "gps_latitude": row[3],
+            "gps_longitude": row[4],
+            "crop_type": row[5],
+            "crop_variety": row[6],
+            "ai_diagnosis": row[7],
+            "confidence_score": row[8],
+            "recommended_action": row[9],
+            "farmer_feedback": row[10],
+            "treatment_applied": row[11],
+            "outcome_image_id": row[12],
+        })
+
+    errors = client.insert_rows_json(table_id, rows_to_insert)  # Make an API request.
+    if errors == []:
+        print("New rows have been added.")
+        local_database.clear_all_analysis(conn)
+        print("Local database cleared.")
+    else:
+        print(f"Encountered errors while inserting rows: {errors}")
+
+    conn.close()
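Note that app.py imports `upload_to_bigquery` from this module, while the function defined here is `upload_data_from_local_db`; the `local_database.create_connection`, `get_all_analysis`, and `clear_all_analysis` calls also do not appear in the local_database.py added below. A thin wrapper in the spirit of what app.py's sync button expects might look like this sketch (hypothetical, not part of the commit):

```python
# Hypothetical bridge for the name app.py imports; returns a summary string
# for the Gradio "Sync Status" box instead of only printing.
def upload_to_bigquery() -> str:
    upload_data_from_local_db()
    return f"Upload routine finished for {PROJECT_ID}.{DATASET_ID}.{TABLE_ID} (see logs for details)."
```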
knowledge_base.py
CHANGED
@@ -1,48 +1,219 @@
 import os
+import sqlite3
+import faiss
+import numpy as np
+from sentence_transformers import SentenceTransformer
+import fitz  # PyMuPDF
+from PIL import Image
+import io
+from cryptography.hazmat.primitives.ciphers import Cipher, algorithms, modes
+from cryptography.hazmat.primitives import padding
+from cryptography.hazmat.backends import default_backend
 import config

+# --- Security ---
+SECRET_KEY = os.environ.get("AURA_MIND_SECRET_KEY", "a_default_secret_key_32_bytes_!!").encode()
+if len(SECRET_KEY) != 32:
+    raise ValueError("SECRET_KEY must be 32 bytes long for AES-256.")
+
+def encrypt_data(data: bytes) -> bytes:
+    iv = os.urandom(16)
+    padder = padding.PKCS7(algorithms.AES.block_size).padder()
+    padded_data = padder.update(data) + padder.finalize()
+    cipher = Cipher(algorithms.AES(SECRET_KEY), modes.CBC(iv), backend=default_backend())
+    encryptor = cipher.encryptor()
+    encrypted_data = encryptor.update(padded_data) + encryptor.finalize()
+    return iv + encrypted_data
+
+def decrypt_data(encrypted_data_with_iv: bytes) -> bytes:
+    iv = encrypted_data_with_iv[:16]
+    encrypted_data = encrypted_data_with_iv[16:]
+    cipher = Cipher(algorithms.AES(SECRET_KEY), modes.CBC(iv), backend=default_backend())
+    decryptor = cipher.decryptor()
+    padded_data = decryptor.update(encrypted_data) + decryptor.finalize()
+    unpadder = padding.PKCS7(algorithms.AES.block_size).unpadder()
+    data = unpadder.update(padded_data) + unpadder.finalize()
+    return data
+
+# --- KnowledgeBase Class ---
+class KnowledgeBase:
+    def __init__(self, db_file="auramind_local.db", index_file="auramind_faiss.index", model_name='clip-ViT-B-32'):
+        self.db_file = db_file
+        self.index_file = index_file
+        self.model = SentenceTransformer(model_name)
+        self.init_db()
+
+    def init_db(self):
+        conn = sqlite3.connect(self.db_file)
+        cursor = conn.cursor()
+        cursor.execute('''
+            CREATE TABLE IF NOT EXISTS documents (
+                id INTEGER PRIMARY KEY AUTOINCREMENT,
+                name TEXT NOT NULL UNIQUE
+            )
+        ''')
+        cursor.execute('''
+            CREATE TABLE IF NOT EXISTS chunks (
+                id INTEGER PRIMARY KEY AUTOINCREMENT,
+                doc_id INTEGER,
+                content_type TEXT NOT NULL, -- 'text' or 'image'
+                encrypted_content BLOB NOT NULL,
+                page_num INTEGER,
+                FOREIGN KEY (doc_id) REFERENCES documents (id)
+            )
+        ''')
+        conn.commit()
+        conn.close()
+
+    def get_db_connection(self):
+        conn = sqlite3.connect(self.db_file)
+        conn.row_factory = sqlite3.Row
+        return conn
+
+    def delete_database_and_index(self):
+        if os.path.exists(self.db_file):
+            os.remove(self.db_file)
+            print(f"Removed old database: {self.db_file}")
+        if os.path.exists(self.index_file):
+            os.remove(self.index_file)
+            print(f"Removed old index: {self.index_file}")
+
+    def create_initial_index(self, documents_dict):
+        print("Performing a clean rebuild of the knowledge base...")
+        self.delete_database_and_index()
+        self.init_db()
+
+        conn = self.get_db_connection()
+        cursor = conn.cursor()
+
+        all_chunks = []
+        all_embeddings = []
+
+        for name, content in documents_dict.items():
+            cursor.execute("INSERT INTO documents (name) VALUES (?)", (name,))
+            doc_id = cursor.lastrowid
+            chunk_text = content
+            all_chunks.append((doc_id, 'text', encrypt_data(chunk_text.encode('utf-8')), 1))
+            text_embedding = self.model.encode([chunk_text])
+            all_embeddings.append(text_embedding)
+
+        cursor.executemany(
+            "INSERT INTO chunks (doc_id, content_type, encrypted_content, page_num) VALUES (?, ?, ?, ?)",
+            all_chunks
+        )
+        conn.commit()
+        conn.close()
+
+        if not all_embeddings:
+            print("No content to index.")
+            return
+
+        embeddings_np = np.vstack(all_embeddings).astype('float32')
+        dimension = embeddings_np.shape[1]
+        index = faiss.IndexFlatL2(dimension)
+        index.add(embeddings_np)
+        faiss.write_index(index, self.index_file)
+        print(f"Initial encrypted index created with {len(all_chunks)} chunks.")
+
+    def ingest_pdf(self, file_path, file_name):
+        print(f"Starting ingestion for: {file_name}")
+        conn = self.get_db_connection()
+        cursor = conn.cursor()
+
+        try:
+            cursor.execute("INSERT INTO documents (name) VALUES (?)", (file_name,))
+            doc_id = cursor.lastrowid
+        except conn.IntegrityError:
+            print("Document already exists in DB. Skipping doc table insert.")
+            doc_id = cursor.execute("SELECT id FROM documents WHERE name=?", (file_name,)).fetchone()['id']
+
+        doc = fitz.open(file_path)
+        new_embeddings = []
+
+        if os.path.exists(self.index_file):
+            index = faiss.read_index(self.index_file)
+        else:
+            dimension = self.model.encode(["test"]).shape[1]
+            index = faiss.IndexFlatL2(dimension)
+
+        for page_num, page in enumerate(doc):
+            text = page.get_text()
+            if text.strip():
+                encrypted_text = encrypt_data(text.encode('utf-8'))
+                cursor.execute(
+                    "INSERT INTO chunks (doc_id, content_type, encrypted_content, page_num) VALUES (?, ?, ?, ?)",
+                    (doc_id, 'text', encrypted_text, page_num + 1)
+                )
+                text_embedding = self.model.encode([text])
+                new_embeddings.append(text_embedding)
+
+            image_list = page.get_images(full=True)
+            for img_index, img in enumerate(image_list):
+                xref = img[0]
+                base_image = doc.extract_image(xref)
+                image_bytes = base_image["image"]
+                encrypted_image = encrypt_data(image_bytes)
+                cursor.execute(
+                    "INSERT INTO chunks (doc_id, content_type, encrypted_content, page_num) VALUES (?, ?, ?, ?)",
+                    (doc_id, 'image', encrypted_image, page_num + 1)
+                )
+                pil_image = Image.open(io.BytesIO(image_bytes))
+                image_embedding = self.model.encode(pil_image)
+                new_embeddings.append(image_embedding.reshape(1, -1))
+
+        conn.commit()
+        conn.close()
+
+        if new_embeddings:
+            embeddings_np = np.vstack(new_embeddings).astype('float32')
+            index.add(embeddings_np)
+            faiss.write_index(index, self.index_file)
+            print(f"Successfully ingested {file_name} and added {len(new_embeddings)} new chunks.")
+        else:
+            print(f"No new content found to ingest in {file_name}.")
+
+    def search(self, query, k=1):
+        if not os.path.exists(self.index_file):
+            return []
+
+        index = faiss.read_index(self.index_file)
+        query_embedding = self.model.encode([query]).astype('float32')
+        distances, indices = index.search(query_embedding, k)
+
+        results = []
+        conn = self.get_db_connection()
+        for i, faiss_id in enumerate(indices[0]):
+            if faiss_id != -1:
+                sql_id = int(faiss_id) + 1
+                chunk_record = conn.execute('SELECT * FROM chunks WHERE id = ?', (sql_id,)).fetchone()
+                if chunk_record:
+                    content_type = chunk_record['content_type']
+                    decrypted_content_bytes = decrypt_data(chunk_record['encrypted_content'])
+                    if content_type == 'text':
+                        content = decrypted_content_bytes.decode('utf-8')
+                    elif content_type == 'image':
+                        content = Image.open(io.BytesIO(decrypted_content_bytes))
+                    results.append({
+                        'distance': distances[0][i],
+                        'content': content,
+                        'type': content_type,
+                        'page': chunk_record['page_num']
+                    })
+        conn.close()
+        return results
+
 def get_retriever():
-        print(f"⚠️ Building a new FAISS index from all files in {config.KNOWLEDGE_BASE_PATH}...")
-
-        documents = []
-        data_path = config.KNOWLEDGE_BASE_PATH
-        for file_name in os.listdir(data_path):
-            file_path = os.path.join(data_path, file_name)
-            if os.path.isfile(file_path) and file_name.endswith('.txt'):
-                print(f" - Loading {file_name}...")
-                loader = TextLoader(file_path)
-                documents.extend(loader.load())
-
-        text_splitter = RecursiveCharacterTextSplitter(chunk_size=500, chunk_overlap=50)
-        docs = text_splitter.split_documents(documents)
-
-        print("\n✨ Creating new FAISS index...")
-        db = FAISS.from_documents(docs, embeddings)
-        db.save_local(config.FAISS_INDEX_PATH)
-        print(f"✅ New FAISS index built and saved to {config.FAISS_INDEX_PATH}.")
-
-        retriever = db.as_retriever(search_kwargs={"k": 1})
-        print("✅ RAG knowledge base and retriever created successfully!")
-        return retriever
-    except Exception as e:
-        print(f"❌ CRITICAL ERROR during RAG setup: {e}")
-        return None
+    kb = KnowledgeBase()
+    # This is a placeholder to maintain compatibility with the existing code.
+    # The actual search will be done using kb.search()
+    class Retriever:
+        def __init__(self, kb):
+            self.kb = kb
+        def get_relevant_documents(self, query):
+            results = self.kb.search(query)
+            # Langchain retrievers expect a list of Document objects.
+            # We will return the content of the documents for now.
+            from langchain.schema import Document
+            return [Document(page_content=r['content']) if r['type'] == 'text' else r['content'] for r in results]
+
+    return Retriever(kb)
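A minimal usage sketch of the new class (default file names as above; the sample document text is illustrative only):

```python
from knowledge_base import KnowledgeBase

kb = KnowledgeBase()  # loads clip-ViT-B-32 and creates auramind_local.db with its tables if missing
kb.create_initial_index({
    "northern_leaf_blight.txt": "Northern leaf blight: rotate crops and apply a fungicide early."
})
hits = kb.search("Northern Leaf Blight", k=1)
if hits:
    print(hits[0]["type"], hits[0]["page"], hits[0]["content"])
```

`search()` maps a FAISS result position to a SQLite row with `sql_id = faiss_id + 1`, so it relies on chunks being inserted in the same order as their embeddings and on the autoincrement ids staying contiguous.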
local_database.py
ADDED
@@ -0,0 +1,57 @@
+import sqlite3
+import os
+
+DB_FILE = "auramind_local.db"
+INDEX_FILE = "auramind_faiss.index"
+
+def init_db():
+    """
+    Initializes a more robust database schema for multimodal data.
+    - 'documents' table tracks the source files.
+    - 'chunks' table stores the individual encrypted text/image chunks.
+    """
+    conn = sqlite3.connect(DB_FILE)
+    cursor = conn.cursor()
+
+    # Table to track the source documents (e.g., 'healthy_maize.txt', 'user_guide.pdf')
+    cursor.execute('''
+        CREATE TABLE IF NOT EXISTS documents (
+            id INTEGER PRIMARY KEY AUTOINCREMENT,
+            name TEXT NOT NULL UNIQUE
+        )
+    ''')
+
+    # Table to store each chunk of content (text or image)
+    # The faiss_id will correspond to the row number in the FAISS index
+    cursor.execute('''
+        CREATE TABLE IF NOT EXISTS chunks (
+            id INTEGER PRIMARY KEY AUTOINCREMENT,
+            doc_id INTEGER,
+            content_type TEXT NOT NULL, -- 'text' or 'image'
+            encrypted_content BLOB NOT NULL,
+            page_num INTEGER,
+            FOREIGN KEY (doc_id) REFERENCES documents (id)
+        )
+    ''')
+    conn.commit()
+    conn.close()
+
+def get_db_connection():
+    """Establishes a connection to the database."""
+    conn = sqlite3.connect(DB_FILE)
+    conn.row_factory = sqlite3.Row
+    return conn
+
+def check_if_indexed():
+    """Checks if the initial database and index file exist."""
+    # A basic check. A more robust check might query the db for content.
+    return os.path.exists(DB_FILE) and os.path.exists(INDEX_FILE)
+
+def delete_database_and_index():
+    """Deletes existing db and index files for a clean rebuild."""
+    if os.path.exists(DB_FILE):
+        os.remove(DB_FILE)
+        print(f"Removed old database: {DB_FILE}")
+    if os.path.exists(INDEX_FILE):
+        os.remove(INDEX_FILE)
+        print(f"Removed old index: {INDEX_FILE}")
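bigquery_uploader.py expects `create_connection`, `get_all_analysis`, and `clear_all_analysis` from this module, plus a local `farm_analysis` table; none of these appear in the file as added. A hedged sketch of what those helpers could look like, assuming a `farm_analysis` table whose column order mirrors the BigQuery schema in bigquery_uploader.py:

```python
# Hypothetical helpers expected by bigquery_uploader.py; not part of this commit.
def create_connection():
    """Alias for get_db_connection(), under the name bigquery_uploader.py uses."""
    return get_db_connection()

def get_all_analysis(conn):
    """Returns all locally stored analysis rows, ordered like the BigQuery schema."""
    return conn.execute("SELECT * FROM farm_analysis").fetchall()

def clear_all_analysis(conn):
    """Removes local rows after a successful cloud upload."""
    conn.execute("DELETE FROM farm_analysis")
    conn.commit()
```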
requirements.txt
CHANGED
@@ -17,4 +17,5 @@ duckduckgo-search
 langgraph
 google-genai
 google-adk
-pypdf
+pypdf
+google-cloud-bigquery