# app.py (Final Production Version - Robust, Self-Correcting Agent)
import gradio as gr
import os
from PIL import Image
import warnings
# --- Suppress harmless warnings ---
# Suppress the LangSmith API key warning
os.environ["LANGCHAIN_TRACING_V2"] = "false"
# Suppress the specific Gradio UserWarning about chatbot type
warnings.filterwarnings("ignore", category=UserWarning, message="You have not specified a value for the `type` parameter.")
# LangChain and Agent Imports
from langchain_community.vectorstores import FAISS
from langchain_huggingface import HuggingFaceEmbeddings, HuggingFaceEndpoint, ChatHuggingFace
from langchain.text_splitter import RecursiveCharacterTextSplitter
from langchain.prompts import ChatPromptTemplate
from langchain_core.output_parsers import StrOutputParser
from langchain_community.document_loaders import PyMuPDFLoader
from langchain.agents import AgentExecutor, create_react_agent
from langchain_community.tools import DuckDuckGoSearchRun
from langchain import hub
# Unsloth for Vision Model
from unsloth import FastVisionModel
import torch
print("βœ… All libraries imported successfully.")
# --- 1. Global Setup: Models, Tools, and Prompts ---
LLM, VISION_MODEL, PROCESSOR, EMBEDDINGS = None, None, None, None
DOCUMENT_QA_CHAIN, GENERAL_AGENT_EXECUTOR = None, None
try:
    print("Initializing models and tools...")
    hf_token = os.environ.get("HF_TOKEN")
    if not hf_token:
        raise ValueError("HF_TOKEN secret not found in Space settings.")
    print("✅ HF_TOKEN secret found successfully.")
    # Shared LLM for all agents
    base_llm = HuggingFaceEndpoint(
        repo_id="HuggingFaceH4/zephyr-7b-beta",
        huggingfacehub_api_token=hf_token, max_new_tokens=1024, temperature=0.1
    )
    LLM = ChatHuggingFace(llm=base_llm)
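    # ChatHuggingFace wraps the raw endpoint so chat messages are formatted
    # with the model's own chat template before each request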
    # Shared Embeddings for RAG
    EMBEDDINGS = HuggingFaceEmbeddings(model_name="sentence-transformers/all-MiniLM-L6-v2")
    # Vision Model (for Diagnosis Agent)
    print("Loading Vision Model...")
    VISION_MODEL, PROCESSOR = FastVisionModel.from_pretrained(
        model_name="unsloth/gemma-3n-E2B-it-unsloth-bnb-4bit",
        max_seq_length=2048, load_in_4bit=True, dtype=None
    )
    FastVisionModel.for_inference(VISION_MODEL)
    VISION_MODEL.load_adapter("surfiniaburger/maize-health-diagnosis-adapter")
    print("✅ Vision model loaded.")
    # General Knowledge Tool (for General Agent)
    search_tool = DuckDuckGoSearchRun()
    # Create the General Agent (with web search and self-correction)
    react_prompt = hub.pull("hwchase17/react")
    tools = [search_tool]
    agent = create_react_agent(LLM, tools, react_prompt)
    GENERAL_AGENT_EXECUTOR = AgentExecutor(
        agent=agent,
        tools=tools,
        verbose=True,
        handle_parsing_errors=True  # Send malformed output back to the LLM so it can retry instead of crashing
    )
    print("✅ General Knowledge Agent created.")
    # Create the Document Q&A Chain
    doc_qa_prompt = ChatPromptTemplate.from_messages([
        ("system", "You are an expert AI assistant who answers questions based ONLY on the provided context from the user's document. If the answer is not in the context, clearly state that you cannot find the answer in the document."),
        ("human", "CONTEXT:\n{context}\n\nQUESTION:\n{question}")
    ])
    DOCUMENT_QA_CHAIN = (
        {
            # Join the retrieved chunks into plain text so the prompt sees clean context
            "context": lambda x: "\n\n".join(doc.page_content for doc in x["retriever"].invoke(x["question"])),
            "question": lambda x: x["question"],
        }
        | doc_qa_prompt | LLM | StrOutputParser()
    )
    print("✅ Document Q&A Chain created.")
except Exception as e:
    print(f"❌ CRITICAL ERROR during initialization: {e}")
# --- 2. Master Router Logic ---
ROUTER_PROMPT = ChatPromptTemplate.from_messages([
    ("system", "You are an expert at routing a user's request to the correct specialist agent. Respond with ONLY the name of the chosen agent: 'document_qa', 'plant_diagnosis', or 'general_knowledge'."),
    ("human", "Analyze the user's request. User Query: '{query}'. Document Uploaded: {doc_uploaded}. Image Uploaded: {image_uploaded}. If an image is uploaded, choose 'plant_diagnosis'. If the query is about a document and one is uploaded, choose 'document_qa'. Otherwise, choose 'general_knowledge'.")
])
router_chain = ROUTER_PROMPT | LLM | StrOutputParser()
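# Illustrative routing call; the exact string returned depends on the LLM,
# which is why master_agent_flow below matches it with substring checks:
#   router_chain.invoke({"query": "How do I store harvested maize?",
#                        "doc_uploaded": False, "image_uploaded": False})
#   # -> "general_knowledge"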
# --- 3. Gradio Application Logic ---
def process_document(file_path: str):
    try:
        loader = PyMuPDFLoader(file_path)
        documents = loader.load()
        text_splitter = RecursiveCharacterTextSplitter(chunk_size=1000, chunk_overlap=100)
        chunks = text_splitter.split_documents(documents)
        vector_store = FAISS.from_documents(chunks, EMBEDDINGS)
        return vector_store.as_retriever(search_kwargs={"k": 3})
    except Exception as e:
        raise gr.Error(f"Failed to process document: {e}")
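# Illustrative usage (hypothetical file name):
#   retriever = process_document("maize_guide.pdf")
#   retriever.invoke("nitrogen deficiency")  # -> the 3 most similar chunks (k=3 above)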
def diagnose_plant(image: Image.Image):
    image = image.convert("RGB")
    messages = [{"role": "user", "content": [{"type": "text", "text": "What is the condition of this maize plant?"}, {"type": "image", "image": image}]}]
    text_prompt = PROCESSOR.tokenizer.apply_chat_template(messages, tokenize=False, add_generation_prompt=True)
    inputs = PROCESSOR(text=text_prompt, images=image, return_tensors="pt").to(VISION_MODEL.device)
    with torch.inference_mode():
        outputs = VISION_MODEL.generate(**inputs, max_new_tokens=48, use_cache=True)
    response = PROCESSOR.batch_decode(outputs, skip_special_tokens=True)[0]
    # Gemma transcripts prefix the assistant turn with "model\n"; keep only the text after the last marker
    return response[response.rfind("model\n") + len("model\n"):].strip() if "model\n" in response else "Could not parse diagnosis."
def master_agent_flow(history, doc_retriever, image_input):
    user_query = history[-1][0]
    doc_uploaded = doc_retriever is not None
    image_uploaded = image_input is not None
    print("Routing query...")
    router_input = {"query": user_query, "doc_uploaded": doc_uploaded, "image_uploaded": image_uploaded}
    chosen_agent = router_chain.invoke(router_input)
    print(f"Chosen agent: {chosen_agent}")
    response = ""
    if "plant_diagnosis" in chosen_agent and image_uploaded:
        response = diagnose_plant(image_input)
    elif "document_qa" in chosen_agent and doc_uploaded:
        chain_input = {"question": user_query, "retriever": doc_retriever}
        response = DOCUMENT_QA_CHAIN.invoke(chain_input)
    elif "general_knowledge" in chosen_agent:
        result = GENERAL_AGENT_EXECUTOR.invoke({"input": user_query})
        response = result.get("output", "I couldn't find an answer.")
    else:  # Fallback logic
        response = "I'm not sure how to handle that. If you uploaded an image, please ask for a diagnosis. If you uploaded a document, please ask a question about it. Otherwise, I can search the web."
    history[-1] = (user_query, response)
    # Keep the document retriever so follow-up questions still work; clear the image and the textbox
    return history, doc_retriever, None, ""
def add_query_to_history(query, history):
    return history + [(query, None)]
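# Appends the user's turn with a None placeholder for the assistant reply,
# which master_agent_flow fills in afterwards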
# --- 4. Building the Gradio Interface ---
with gr.Blocks(theme=gr.themes.Soft(), css="footer {visibility: hidden}") as demo:
    doc_retriever_state = gr.State()
    image_state = gr.State()
    gr.Markdown("# 🧠 Enterprise Agricultural Assistant")
    gr.Markdown("I can diagnose plant diseases from images, answer questions about your uploaded documents, or look up general agricultural information.")
    # The chatbot uses Gradio's default tuple-format history (its `type` warning is suppressed at the top of this file)
    chatbot = gr.Chatbot(label="Conversation", height=500, value=[])
    with gr.Row():
        query_box = gr.Textbox(placeholder="Ask a question or describe the image...", scale=4, container=False)
        image_upload = gr.Image(type="pil", label="Upload Plant Image", scale=1)
        doc_upload = gr.UploadButton("📁 Upload Document", file_types=['.pdf', '.txt'], scale=1)
    def handle_doc_upload(file, chatbot_history):
        retriever = process_document(file.name)
        new_history = chatbot_history + [(None, f"Document '{os.path.basename(file.name)}' loaded successfully. You can now ask questions about it.")]
        return retriever, new_history
    def handle_image_upload(img, chatbot_history):
        new_history = chatbot_history + [(None, "Image loaded. Ask for a diagnosis or describe what you need.")]
        return img, new_history
    doc_upload.upload(handle_doc_upload, [doc_upload, chatbot], [doc_retriever_state, chatbot])
    image_upload.upload(handle_image_upload, [image_upload, chatbot], [image_state, chatbot])
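    # Two-step submit: append the user's turn first so it renders immediately,
    # then let .then() run the router and fill in the assistant's reply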
    query_box.submit(
        add_query_to_history, [query_box, chatbot], [chatbot]
    ).then(
        master_agent_flow, [chatbot, doc_retriever_state, image_state], [chatbot, doc_retriever_state, image_state, query_box]
    )
if __name__ == "__main__":
    demo.launch(debug=True)