# app.py
import gradio as gr
from unsloth import FastVisionModel
from PIL import Image
import torch
import os

# --- 1. Global Setup: Load Model and Processor ---
# This section runs only ONCE, when the application starts.
print("Performing initial model setup...")

# The base model to load
BASE_MODEL_NAME = "unsloth/gemma-3n-E2B-it-unsloth-bnb-4bit"

# !!! IMPORTANT: Load the adapter from the Hugging Face Hub !!!
# Replace this with the model repo ID you created in Part 1.
ADAPTER_PATH = "surfiniaburger/maize-health-diagnosis-adapter"

model = None
processor = None

try:
    print(f"Loading base model: {BASE_MODEL_NAME}")
    model, processor = FastVisionModel.from_pretrained(
        model_name=BASE_MODEL_NAME,
        max_seq_length=2048,
        load_in_4bit=True,
        dtype=None,  # Let Unsloth pick the best dtype for the hardware
    )
    FastVisionModel.for_inference(model)  # Switch to inference mode

    print(f"Loading adapter from Hub: {ADAPTER_PATH}")
    model.load_adapter(ADAPTER_PATH)  # Load the fine-tuned LoRA weights from the Hub

    print("\n✅ Model and adapter loaded successfully!")
except Exception as e:
    print(f"❌ Critical error during model loading: {e}")
    # Re-raise so the full traceback appears in the Space logs
    raise
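# Note: Unsloth can often load a fine-tuned LoRA repo in a single step, pulling
# in the base model recorded in the adapter's config. A hedged alternative to
# the two-step load above, shown here commented out:
#
#     model, processor = FastVisionModel.from_pretrained(
#         model_name=ADAPTER_PATH,
#         load_in_4bit=True,
#     )
#
# The explicit two-step load is kept so the base model is visible at a glance.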
# --- 2. Define the Core Prediction Function ---
def diagnose_maize_plant(uploaded_image: Image.Image) -> str:
    """Run the fine-tuned vision model on an uploaded image and return a diagnosis."""
    if model is None or processor is None or uploaded_image is None:
        return "Model is not loaded or no image was uploaded. Please check the Space logs for errors."

    image = uploaded_image.convert("RGB")

    # Build a chat-style prompt that pairs the question with the image
    messages = [
        {
            "role": "user",
            "content": [
                {"type": "text", "text": "What is the condition of this maize plant?"},
                {"type": "image", "image": image},
            ],
        }
    ]
    text_prompt = processor.tokenizer.apply_chat_template(
        messages,
        tokenize=False,
        add_generation_prompt=True,
    )
    inputs = processor(
        text=text_prompt,
        images=image,
        return_tensors="pt",
    ).to(model.device)

    with torch.inference_mode():
        outputs = model.generate(**inputs, max_new_tokens=128, use_cache=True)

    response = processor.batch_decode(outputs, skip_special_tokens=True)[0]

    # Gemma-style chat templates end the prompt with a "model" turn marker, so
    # the generated answer is everything after its last occurrence.
    prompt_marker = "model\n"
    answer_start_index = response.rfind(prompt_marker)
    if answer_start_index != -1:
        final_answer = response[answer_start_index + len(prompt_marker):].strip()
    else:
        final_answer = "Could not parse the model's response. Raw output: " + response
    return final_answer
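# --- Optional: local smoke test before deploying ---
# A minimal sketch for exercising the pipeline once from the command line.
# RUN_SMOKE_TEST and SMOKE_TEST_IMAGE are hypothetical environment variable
# names used only for this check; the Space itself never sets them.
if os.environ.get("RUN_SMOKE_TEST") == "1":
    test_path = os.environ.get("SMOKE_TEST_IMAGE", "sample_leaf.jpg")
    if os.path.exists(test_path):
        print(diagnose_maize_plant(Image.open(test_path)))
    else:
        print(f"Smoke-test image not found: {test_path}")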
# --- 3. Build and Launch the Gradio Interface ---
print("Building Gradio interface...")

demo = gr.Interface(
    fn=diagnose_maize_plant,
    inputs=gr.Image(type="pil", label="Upload Maize Plant Image"),
    outputs=gr.Textbox(label="Diagnosis", lines=4),
    title="🌽 Maize Health Diagnosis Assistant",
    description="Upload an image of a maize plant, and the AI will analyze its condition. This tool is powered by a fine-tuned Gemma-3N vision model.",
    article="Built with Unsloth and Gradio. Model fine-tuned on Kaggle.",
    allow_flagging="never",
)

print("Launching Gradio app...")
demo.launch()
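# --- Deployment note ---
# The Space also needs a requirements.txt next to app.py. A minimal sketch,
# assuming a GPU runtime (bitsandbytes 4-bit loading expects CUDA):
#
#     unsloth
#     gradio
#     torch
#     transformers
#     pillow
#
# Pin the same library versions used during fine-tuning to avoid
# base-model/adapter mismatches.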