Spaces:
Sleeping
Sleeping
# 3. model_utils.py
# Model management (loading, prediction, and species information)
import logging
import threading

import torch
from PIL import Image
from transformers import ViTForImageClassification, ViTImageProcessor

from dataset_utils import DatasetHandler
class BugClassifier:
    """Classify insect images with a ViT model and look up species descriptions.

    The model and its image processor are loaded once at construction time.
    Species descriptions are loaded on a background thread, so lookups made
    before that thread finishes fall back to the "not available" message.
    """

    # Checkpoint that supplies the image preprocessing configuration.
    PROCESSOR_CHECKPOINT = "google/vit-base-patch16-224"
    # Predictions below this confidence (in percent) are reported as unknown.
    MIN_CONFIDENCE = 30

    def __init__(self, model_path="google/vit-base-patch16-224"):
        """Load the model and processor, then start loading descriptions.

        Args:
            model_path: Name or path passed to
                ``ViTForImageClassification.from_pretrained``.
        """
        self.model = ViTForImageClassification.from_pretrained(model_path)
        self.model.eval()
        # Loaded once here rather than on every predict() call.
        self.processor = ViTImageProcessor.from_pretrained(
            self.PROCESSOR_CHECKPOINT
        )
        # NOTE(review): class indices from the model are mapped positionally
        # onto this list. If the checkpoint emits more classes than there are
        # labels (the default checkpoint presumably does -- confirm), indices
        # past the end are reported as "Unknown Insect" by predict().
        self.labels = [
            "Seven-spotted Ladybug", "Monarch Butterfly", "Carpenter Ant",
            "Japanese Beetle", "Garden Spider", "Green Grasshopper",
            "Luna Moth", "Common Dragonfly", "Honey Bee", "Paper Wasp",
        ]
        self.species_descriptions = {}
        self.load_species_descriptions()

    def load_species_descriptions(self):
        """Start a background thread that fills ``species_descriptions``.

        Returns immediately; ``species_descriptions`` stays empty until the
        thread has assigned the result of ``DatasetHandler.load_descriptions``.
        """

        def load():
            handler = DatasetHandler()
            self.species_descriptions = handler.load_descriptions(
                max_records=500
            )

        thread = threading.Thread(target=load)
        thread.start()

    def predict(self, image):
        """Classify *image* and return ``(label, confidence)``.

        Args:
            image: Image passed to the processor as ``images=image``.

        Returns:
            A ``(label, confidence)`` tuple with confidence in percent.
            The label is ``"Unknown Insect"`` when the confidence is below
            ``MIN_CONFIDENCE`` or the predicted class index has no entry in
            ``self.labels``. If preprocessing or inference raises, the error
            is logged and ``("Error Processing Image", 0.0)`` is returned.
        """
        try:
            inputs = self.processor(images=image, return_tensors="pt")
            with torch.no_grad():
                outputs = self.model(**inputs)
            probabilities = torch.nn.functional.softmax(outputs.logits, dim=1)
            confidence, predicted_idx = probabilities.max(dim=1)
            confidence = confidence.item() * 100
            index = predicted_idx.item()
        except Exception:
            # Best-effort API: callers get a sentinel result, but the cause
            # is recorded instead of being silently discarded.
            logging.getLogger(__name__).exception("Failed to classify image")
            return "Error Processing Image", 0.0

        # A class index without a label is an unknown species, not an error.
        if confidence < self.MIN_CONFIDENCE or index >= len(self.labels):
            return "Unknown Insect", confidence
        return self.labels[index], confidence

    def get_species_info(self, species):
        """Return the stored description for *species*, or a fallback message.

        Args:
            species: Key looked up in ``species_descriptions``.
        """
        return self.species_descriptions.get(
            species,
            "Information not available. Consider updating your dataset for this species.",
        )