File size: 8,984 Bytes
b4e7ec7
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
# app.py (Final Production Version - Robust, Self-Correcting Agent)
import gradio as gr
import os
from PIL import Image
import warnings

# --- Suppress harmless warnings ---
# Turn LangSmith tracing off so LangChain does not warn about a missing API key.
os.environ.update(LANGCHAIN_TRACING_V2="false")
# Ignore only Gradio's UserWarning about the chatbot `type` parameter being unset.
_GRADIO_CHATBOT_TYPE_WARNING = "You have not specified a value for the `type` parameter."
warnings.filterwarnings("ignore", message=_GRADIO_CHATBOT_TYPE_WARNING, category=UserWarning)


# LangChain and Agent Imports
from langchain_community.vectorstores import FAISS
from langchain_huggingface import HuggingFaceEmbeddings, HuggingFaceEndpoint, ChatHuggingFace
from langchain.text_splitter import RecursiveCharacterTextSplitter
from langchain.prompts import ChatPromptTemplate
from langchain_core.runnables import RunnablePassthrough
from langchain_core.output_parsers import StrOutputParser
from langchain_community.document_loaders import PyMuPDFLoader
from langgraph.graph import StateGraph, END
from langchain.agents import AgentExecutor, create_react_agent
from langchain_community.tools import DuckDuckGoSearchRun
from langchain import hub

# Unsloth for Vision Model
from unsloth import FastVisionModel
from transformers import AutoProcessor
import torch

print("✅ All libraries imported successfully.")

# --- 1. Global Setup: Models, Tools, and Prompts ---
# Shared, module-level singletons. Anything not reached because
# initialization failed keeps its None placeholder.
LLM, VISION_MODEL, PROCESSOR, EMBEDDINGS = None, None, None, None
DOCUMENT_QA_CHAIN, GENERAL_AGENT_EXECUTOR = None, None


def _format_docs(docs):
    """Join the page text of retrieved documents into one context string.

    Feeding the raw list of Document objects to the prompt would render their
    Python repr (metadata included) instead of just the text.
    """
    return "\n\n".join(doc.page_content for doc in docs)


try:
    print("Initializing models and tools...")
    hf_token = os.environ.get("HF_TOKEN")
    if not hf_token:
        raise ValueError("HF_TOKEN secret not found in Space settings.")
    print("✅ HF_TOKEN secret found successfully.")

    # Shared chat LLM used by the router, the general agent, and document Q&A.
    base_llm = HuggingFaceEndpoint(
        repo_id="HuggingFaceH4/zephyr-7b-beta",
        huggingfacehub_api_token=hf_token, max_new_tokens=1024, temperature=0.1
    )
    LLM = ChatHuggingFace(llm=base_llm)

    # Shared embeddings for RAG over uploaded documents.
    EMBEDDINGS = HuggingFaceEmbeddings(model_name="sentence-transformers/all-MiniLM-L6-v2")

    # Vision model (for the plant-diagnosis agent). Loaded in its own try so a
    # failure here does not prevent the text-based agents below from being built.
    print("Loading Vision Model...")
    try:
        VISION_MODEL, PROCESSOR = FastVisionModel.from_pretrained(
            model_name="unsloth/gemma-3n-E2B-it-unsloth-bnb-4bit",
            max_seq_length=2048, load_in_4bit=True, dtype=None
        )
        FastVisionModel.for_inference(VISION_MODEL)
        VISION_MODEL.load_adapter("surfiniaburger/maize-health-diagnosis-adapter")
        print("✅ Vision model loaded.")
    except Exception as e:
        # Don't keep a half-initialized model (e.g. base weights without the adapter).
        VISION_MODEL, PROCESSOR = None, None
        print(f"❌ Vision model failed to load; plant diagnosis is unavailable: {e}")

    # General knowledge agent: a ReAct agent whose only tool is web search.
    search_tool = DuckDuckGoSearchRun()
    react_prompt = hub.pull("hwchase17/react")
    tools = [search_tool]
    agent = create_react_agent(LLM, tools, react_prompt)
    GENERAL_AGENT_EXECUTOR = AgentExecutor(
        agent=agent,
        tools=tools,
        verbose=True,
        # Send output-parsing errors back to the LLM as an observation
        # instead of raising, so malformed ReAct steps can self-correct.
        handle_parsing_errors=True,
    )
    print("✅ General Knowledge Agent created.")

    # Document Q&A chain. Input: {"question": str, "retriever": <retriever>}.
    # The retriever is supplied per call because each upload builds its own index.
    doc_qa_prompt = ChatPromptTemplate.from_messages([
        ("system", "You are an expert AI assistant who answers questions based ONLY on the provided context from the user's document. If the answer is not in the context, clearly state that you cannot find the answer in the document."),
        ("human", "CONTEXT:\n{context}\n\nQUESTION:\n{question}")
    ])
    DOCUMENT_QA_CHAIN = (
        {
            "context": lambda x: _format_docs(x["retriever"].invoke(x["question"])),
            "question": lambda x: x["question"],
        }
        | doc_qa_prompt
        | LLM
        | StrOutputParser()
    )
    print("✅ Document Q&A Chain created.")

except Exception as e:
    # Best-effort startup: log the failure (with traceback) and continue; the
    # globals above that were not reached remain None.
    import traceback
    print(f"❌ CRITICAL ERROR during initialization: {e}")
    traceback.print_exc()

# --- 2. Master Router Logic ---
# Asks the shared LLM to name the specialist agent for a request; the reply is
# expected to be one of 'document_qa', 'plant_diagnosis', or 'general_knowledge'.
ROUTER_PROMPT = ChatPromptTemplate.from_messages([
    ("system", "You are an expert at routing a user's request to the correct specialist agent. Respond with ONLY the name of the chosen agent: 'document_qa', 'plant_diagnosis', or 'general_knowledge'."),
    ("human", "Analyze the user's request. User Query: '{query}'. Document Uploaded: {doc_uploaded}. Image Uploaded: {image_uploaded}. If an image is uploaded, choose 'plant_diagnosis'. If the query is about a document and one is uploaded, choose 'document_qa'. Otherwise, choose 'general_knowledge'.")
])
if LLM is None:
    # Fail fast with a clear message instead of LangChain's less helpful error
    # about composing a prompt with None.
    raise RuntimeError(
        "Cannot build the router: the LLM failed to initialize "
        "(see the initialization error printed above)."
    )
router_chain = ROUTER_PROMPT | LLM | StrOutputParser()

# --- 3. Gradio Application Logic ---
def process_document(file_path: str):
    """Load an uploaded PDF or plain-text file, chunk it, and index it in FAISS.

    Args:
        file_path: Path to the uploaded document (.pdf or .txt).

    Returns:
        A retriever over the document's chunks that returns the top 3 matches.

    Raises:
        gr.Error: If the file cannot be loaded or indexed, or has no text.
    """
    try:
        if file_path.lower().endswith(".txt"):
            # The upload button also accepts .txt files; read those as plain
            # text instead of handing them to the PDF loader.
            from langchain_community.document_loaders import TextLoader
            loader = TextLoader(file_path, encoding="utf-8")
        else:
            loader = PyMuPDFLoader(file_path)
        documents = loader.load()
        text_splitter = RecursiveCharacterTextSplitter(chunk_size=1000, chunk_overlap=100)
        chunks = text_splitter.split_documents(documents)
        if not chunks:
            # FAISS cannot build an index from zero chunks; report it clearly.
            raise ValueError("no extractable text was found in the document")
        vector_store = FAISS.from_documents(chunks, EMBEDDINGS)
        return vector_store.as_retriever(search_kwargs={"k": 3})
    except Exception as e:
        raise gr.Error(f"Failed to process document: {e}") from e

def diagnose_plant(image: Image.Image):
    """Ask the fine-tuned vision model about the condition of a maize plant.

    Args:
        image: The uploaded plant photo; converted to RGB before inference.

    Returns:
        The model's generated diagnosis text, or "Could not parse diagnosis."
        if the model produced no text.
    """
    image = image.convert("RGB")
    messages = [{
        "role": "user",
        "content": [
            {"type": "text", "text": "What is the condition of this maize plant?"},
            {"type": "image", "image": image},
        ],
    }]
    text_prompt = PROCESSOR.tokenizer.apply_chat_template(messages, tokenize=False, add_generation_prompt=True)
    inputs = PROCESSOR(text=text_prompt, images=image, return_tensors="pt").to(VISION_MODEL.device)
    prompt_length = inputs["input_ids"].shape[-1]
    with torch.inference_mode():
        outputs = VISION_MODEL.generate(**inputs, max_new_tokens=48, use_cache=True)
    # generate() returns prompt + continuation; decode only the new tokens.
    # Searching the full decoded text for the "model\n" turn marker instead
    # would cut the answer short whenever the diagnosis itself contains it.
    response = PROCESSOR.batch_decode(outputs[:, prompt_length:], skip_special_tokens=True)[0].strip()
    return response or "Could not parse diagnosis."

def master_agent_flow(history, doc_retriever, image_input):
    """Route the latest chat message to a specialist agent and record its reply.

    Args:
        history: Chat history as (user, assistant) pairs; the last entry holds
            the pending user query with no reply yet.
        doc_retriever: Retriever for the uploaded document, or None.
        image_input: The uploaded image, or None.

    Returns:
        A 4-tuple for the Gradio outputs: the updated history; the document
        retriever, kept so follow-up questions can use the same document;
        None, which clears the consumed image; and "", which clears the textbox.
    """
    user_query = history[-1][0]
    doc_uploaded = doc_retriever is not None
    image_uploaded = image_input is not None

    try:
        print("Routing query...")
        router_input = {"query": user_query, "doc_uploaded": doc_uploaded, "image_uploaded": image_uploaded}
        # Normalise so replies like "Plant_Diagnosis" or " document_qa\n" still match.
        chosen_agent = router_chain.invoke(router_input).strip().lower()
        print(f"Chosen agent: {chosen_agent}")

        if "plant_diagnosis" in chosen_agent and image_uploaded:
            response = diagnose_plant(image_input)
        elif "document_qa" in chosen_agent and doc_uploaded:
            chain_input = {"question": user_query, "retriever": doc_retriever}
            response = DOCUMENT_QA_CHAIN.invoke(chain_input)
        elif "general_knowledge" in chosen_agent:
            result = GENERAL_AGENT_EXECUTOR.invoke({"input": user_query})
            response = result.get("output", "I couldn't find an answer.")
        else:  # Unrecognised choice, or the chosen agent's input is missing.
            response = "I'm not sure how to handle that. If you uploaded an image, please ask for a diagnosis. If you uploaded a document, please ask a question about it. Otherwise, I can search the web."
    except Exception as e:
        # Answer in the chat instead of leaving this turn without a reply.
        print(f"❌ Error while handling query: {e}")
        response = f"Sorry, something went wrong while processing your request: {e}"

    history[-1] = (user_query, response)
    # Keep the document index for follow-up questions; the image is one-shot.
    return history, doc_retriever, None, ""

def add_query_to_history(query, history):
    """Return a new history list with the user's query appended as a pending turn.

    The input list is not modified; the new turn has no assistant reply yet (None).
    """
    pending_turn = (query, None)
    return [*history, pending_turn]

# --- 4. Building the Gradio Interface ---
with gr.Blocks(theme=gr.themes.Soft(), css="footer {visibility: hidden}") as demo:
    # Per-session state: the current document retriever and the pending image.
    doc_retriever_state = gr.State()
    image_state = gr.State()

    gr.Markdown("# 🧠 Enterprise Agricultural Assistant")
    gr.Markdown("I can diagnose plant diseases from images, answer questions about your uploaded documents, or look up general agricultural information.")

    # No `type=` is set, so the chatbot keeps Gradio's [user, bot] pair history
    # format, which is what every handler here reads and writes.
    chatbot = gr.Chatbot(label="Conversation", height=500, value=[])

    with gr.Row():
        query_box = gr.Textbox(placeholder="Ask a question or describe the image...", scale=4, container=False)
        image_upload = gr.Image(type="pil", label="Upload Plant Image", scale=1)
        doc_upload = gr.UploadButton("📁 Upload Document", file_types=['.pdf', '.txt'], scale=1)

    def handle_doc_upload(file, chatbot_history):
        """Index an uploaded document and announce it in the chat."""
        # Accept either a file-like object exposing `.name` or a plain path string.
        file_path = getattr(file, "name", file)
        retriever = process_document(file_path)
        new_history = chatbot_history + [[None, f"Document '{os.path.basename(file_path)}' loaded successfully. You can now ask questions about it."]]
        return retriever, new_history

    def handle_image_upload(img, chatbot_history):
        """Store the uploaded image in session state and acknowledge it."""
        new_history = chatbot_history + [[None, "Image loaded. Ask for a diagnosis or describe what you need."]]
        return img, new_history

    def handle_image_clear():
        """Forget the stored image once the user removes it from the widget."""
        return None

    doc_upload.upload(handle_doc_upload, [doc_upload, chatbot], [doc_retriever_state, chatbot])
    image_upload.upload(handle_image_upload, [image_upload, chatbot], [image_state, chatbot])
    # Without this, removing the picture would leave it in image_state and the
    # next query would still be routed as if an image were attached.
    image_upload.clear(handle_image_clear, None, image_state)

    query_box.submit(
        add_query_to_history, [query_box, chatbot], [chatbot]
    ).then(
        master_agent_flow, [chatbot, doc_retriever_state, image_state], [chatbot, doc_retriever_state, image_state, query_box]
    )

if __name__ == "__main__":
    # Start the Gradio app only when this file is run directly, not on import.
    launch_options = {"debug": True}
    demo.launch(**launch_options)