# app.py (Final Production Version - Robust, Self-Correcting Agent)
import gradio as gr
import os
from PIL import Image
import warnings

# --- Suppress harmless warnings ---
# Suppress the LangSmith API key warning
os.environ["LANGCHAIN_TRACING_V2"] = "false"
# Suppress the specific Gradio UserWarning about chatbot type
warnings.filterwarnings("ignore", category=UserWarning, message="You have not specified a value for the `type` parameter.")

# LangChain and Agent Imports
from langchain_community.vectorstores import FAISS
from langchain_huggingface import HuggingFaceEmbeddings, HuggingFaceEndpoint, ChatHuggingFace
from langchain.text_splitter import RecursiveCharacterTextSplitter
from langchain.prompts import ChatPromptTemplate
from langchain_core.output_parsers import StrOutputParser
from langchain_community.document_loaders import PyMuPDFLoader
from langchain.agents import AgentExecutor, create_react_agent
from langchain_community.tools import DuckDuckGoSearchRun
from langchain import hub

# Unsloth for Vision Model
from unsloth import FastVisionModel
import torch

print("✅ All libraries imported successfully.")
# --- 1. Global Setup: Models, Tools, and Prompts ---
LLM, VISION_MODEL, PROCESSOR, EMBEDDINGS = None, None, None, None
DOCUMENT_QA_CHAIN, GENERAL_AGENT_EXECUTOR = None, None

try:
    print("Initializing models and tools...")
    hf_token = os.environ.get("HF_TOKEN")
    if not hf_token:
        raise ValueError("HF_TOKEN secret not found in Space settings.")
    print("✅ HF_TOKEN secret found successfully.")

    # Shared LLM for all agents
    base_llm = HuggingFaceEndpoint(
        repo_id="HuggingFaceH4/zephyr-7b-beta",
        huggingfacehub_api_token=hf_token, max_new_tokens=1024, temperature=0.1
    )
    LLM = ChatHuggingFace(llm=base_llm)

    # Shared Embeddings for RAG
    EMBEDDINGS = HuggingFaceEmbeddings(model_name="sentence-transformers/all-MiniLM-L6-v2")

    # Vision Model (for Diagnosis Agent)
    print("Loading Vision Model...")
    VISION_MODEL, PROCESSOR = FastVisionModel.from_pretrained(
        model_name="unsloth/gemma-3n-E2B-it-unsloth-bnb-4bit",
        max_seq_length=2048, load_in_4bit=True, dtype=None
    )
    FastVisionModel.for_inference(VISION_MODEL)
    VISION_MODEL.load_adapter("surfiniaburger/maize-health-diagnosis-adapter")
    print("✅ Vision model loaded.")
    # General Knowledge Tool (for General Agent)
    search_tool = DuckDuckGoSearchRun()

    # Create the General Agent (with web search and self-correction)
    react_prompt = hub.pull("hwchase17/react")
    tools = [search_tool]
    agent = create_react_agent(LLM, tools, react_prompt)
    # *** FIX APPLIED HERE: Added handle_parsing_errors=True ***
    GENERAL_AGENT_EXECUTOR = AgentExecutor(
        agent=agent,
        tools=tools,
        verbose=True,
        handle_parsing_errors=True  # This makes the agent robust to formatting errors
    )
    print("✅ General Knowledge Agent created.")
    # Create the Document Q&A Chain
    doc_qa_prompt = ChatPromptTemplate.from_messages([
        ("system", "You are an expert AI assistant who answers questions based ONLY on the provided context from the user's document. If the answer is not in the context, clearly state that you cannot find the answer in the document."),
        ("human", "CONTEXT:\n{context}\n\nQUESTION:\n{question}")
    ])
    DOCUMENT_QA_CHAIN = (
        {
            # Join the retrieved chunks into plain text so the prompt receives readable
            # context instead of the repr() of Document objects.
            "context": (lambda x: "\n\n".join(doc.page_content for doc in x["retriever"].invoke(x["question"]))),
            "question": (lambda x: x["question"]),
        }
        | doc_qa_prompt | LLM | StrOutputParser()
    )
    print("✅ Document Q&A Chain created.")
except Exception as e:
    print(f"❌ CRITICAL ERROR during initialization: {e}")
# --- 2. Master Router Logic ---
ROUTER_PROMPT = ChatPromptTemplate.from_messages([
    ("system", "You are an expert at routing a user's request to the correct specialist agent. Respond with ONLY the name of the chosen agent: 'document_qa', 'plant_diagnosis', or 'general_knowledge'."),
    ("human", "Analyze the user's request. User Query: '{query}'. Document Uploaded: {doc_uploaded}. Image Uploaded: {image_uploaded}. If an image is uploaded, choose 'plant_diagnosis'. If the query is about a document and one is uploaded, choose 'document_qa'. Otherwise, choose 'general_knowledge'.")
])
router_chain = ROUTER_PROMPT | LLM | StrOutputParser()
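# The router returns a bare string such as "document_qa"; master_agent_flow below
# matches it with substring checks, so stray punctuation or whitespace from the LLM
# does not break the dispatch.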
# --- 3. Gradio Application Logic ---
def process_document(file_path: str):
    try:
        loader = PyMuPDFLoader(file_path)
        documents = loader.load()
        text_splitter = RecursiveCharacterTextSplitter(chunk_size=1000, chunk_overlap=100)
        chunks = text_splitter.split_documents(documents)
        vector_store = FAISS.from_documents(chunks, EMBEDDINGS)
        return vector_store.as_retriever(search_kwargs={"k": 3})
    except Exception as e:
        raise gr.Error(f"Failed to process document: {e}")
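# The returned retriever is held in a per-session gr.State (see the interface below),
# so each browser session gets its own FAISS index and nothing is shared across users.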
def diagnose_plant(image: Image.Image):
    image = image.convert("RGB")
    messages = [{"role": "user", "content": [{"type": "text", "text": "What is the condition of this maize plant?"}, {"type": "image", "image": image}]}]
    text_prompt = PROCESSOR.tokenizer.apply_chat_template(messages, tokenize=False, add_generation_prompt=True)
    inputs = PROCESSOR(text=text_prompt, images=image, return_tensors="pt").to(VISION_MODEL.device)
    with torch.inference_mode():
        outputs = VISION_MODEL.generate(**inputs, max_new_tokens=48, use_cache=True)
    response = PROCESSOR.batch_decode(outputs, skip_special_tokens=True)[0]
    # Gemma's chat template labels the assistant turn "model", so everything after the
    # last "model\n" marker is taken as the diagnosis text.
    return response[response.rfind("model\n") + len("model\n"):].strip() if "model\n" in response else "Could not parse diagnosis."
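# master_agent_flow drives one chat turn: route the query, dispatch to the matching
# specialist (vision diagnosis, document Q&A, or the web-search agent), then write the
# reply back into the chat history.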
def master_agent_flow(history, doc_retriever, image_input):
    user_query = history[-1][0]
    doc_uploaded = doc_retriever is not None
    image_uploaded = image_input is not None
    print("Routing query...")
    router_input = {"query": user_query, "doc_uploaded": doc_uploaded, "image_uploaded": image_uploaded}
    chosen_agent = router_chain.invoke(router_input)
    print(f"Chosen agent: {chosen_agent}")
    response = ""
    if "plant_diagnosis" in chosen_agent and image_uploaded:
        response = diagnose_plant(image_input)
    elif "document_qa" in chosen_agent and doc_uploaded:
        chain_input = {"question": user_query, "retriever": doc_retriever}
        response = DOCUMENT_QA_CHAIN.invoke(chain_input)
    elif "general_knowledge" in chosen_agent:
        result = GENERAL_AGENT_EXECUTOR.invoke({"input": user_query})
        response = result.get("output", "I couldn't find an answer.")
    else:  # Fallback logic
        response = "I'm not sure how to handle that. If you uploaded an image, please ask for a diagnosis. If you uploaded a document, please ask a question about it. Otherwise, I can search the web."
    history[-1] = (user_query, response)
    return history, None, None, ""  # Clear states and textbox

def add_query_to_history(query, history):
    return history + [(query, None)]
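# NOTE: chat history is kept in Gradio's default (user, assistant) tuple format
# throughout; the UserWarning this triggers in newer Gradio versions is silenced
# at the top of the file.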
# --- 4. Building the Gradio Interface ---
with gr.Blocks(theme=gr.themes.Soft(), css="footer {visibility: hidden}") as demo:
    doc_retriever_state = gr.State()
    image_state = gr.State()

    gr.Markdown("# 🌽 Enterprise Agricultural Assistant")
    gr.Markdown("I can diagnose plant diseases from images, answer questions about your uploaded documents, or look up general agricultural information.")

    # The chatbot uses the default tuple-style history, which the handlers below and
    # master_agent_flow rely on.
    chatbot = gr.Chatbot(label="Conversation", height=500, value=[])

    with gr.Row():
        query_box = gr.Textbox(placeholder="Ask a question or describe the image...", scale=4, container=False)
        image_upload = gr.Image(type="pil", label="Upload Plant Image", scale=1)
        doc_upload = gr.UploadButton("📄 Upload Document", file_types=['.pdf', '.txt'], scale=1)
    def handle_doc_upload(file, chatbot_history):
        retriever = process_document(file.name)
        new_history = chatbot_history + [(None, f"Document '{os.path.basename(file.name)}' loaded successfully. You can now ask questions about it.")]
        return retriever, new_history

    def handle_image_upload(img, chatbot_history):
        new_history = chatbot_history + [(None, "Image loaded. Ask for a diagnosis or describe what you need.")]
        return img, new_history

    doc_upload.upload(handle_doc_upload, [doc_upload, chatbot], [doc_retriever_state, chatbot])
    image_upload.upload(handle_image_upload, [image_upload, chatbot], [image_state, chatbot])

    query_box.submit(
        add_query_to_history, [query_box, chatbot], [chatbot]
    ).then(
        master_agent_flow, [chatbot, doc_retriever_state, image_state], [chatbot, doc_retriever_state, image_state, query_box]
    )

if __name__ == "__main__":
    demo.launch(debug=True)
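# Running locally (a sketch, assuming the same dependencies as the Space and a CUDA
# GPU for the 4-bit vision model):
#   export HF_TOKEN=<your Hugging Face token>
#   python app.py   # then open the local URL that Gradio prints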