# 3. model_utils.py
# Model management (loading, prediction, and species information)
import logging
import threading

import torch
from PIL import Image
from transformers import ViTForImageClassification, ViTImageProcessor

from dataset_utils import DatasetHandler

logger = logging.getLogger(__name__)


class BugClassifier:
    """Classify insect images with a ViT model and look up species descriptions.

    The model and its image processor are loaded once at construction. Species
    descriptions are loaded in a background thread via
    ``DatasetHandler().load_descriptions(max_records=500)``. Until that thread
    finishes, ``get_species_info`` returns its fallback text.
    """

    def __init__(
        self,
        model_path="google/vit-base-patch16-224",
        processor_path="google/vit-base-patch16-224",
        confidence_threshold=30.0,
    ):
        """Load the model and processor and start loading species descriptions.

        Args:
            model_path: Hugging Face model id or local path for the classifier.
            processor_path: Model id or path for the image processor. It defaults
                to the checkpoint previously hard-coded in ``predict``, so
                existing behavior is unchanged.
            confidence_threshold: Minimum top-class confidence, in percent. Below
                this value ``predict`` reports "Unknown Insect".
        """
        self.model = ViTForImageClassification.from_pretrained(model_path)
        self.model.eval()  # inference only: disables dropout
        # Load the processor once here instead of on every predict() call.
        self.processor = ViTImageProcessor.from_pretrained(processor_path)
        self.confidence_threshold = confidence_threshold
        self.labels = [
            "Seven-spotted Ladybug",
            "Monarch Butterfly",
            "Carpenter Ant",
            "Japanese Beetle",
            "Garden Spider",
            "Green Grasshopper",
            "Luna Moth",
            "Common Dragonfly",
            "Honey Bee",
            "Paper Wasp",
        ]
        self.species_descriptions = {}
        # Keep the loader thread so callers can join() it if they need the data.
        self._descriptions_thread = self.load_species_descriptions()

    def load_species_descriptions(self):
        """Start a background thread that populates ``self.species_descriptions``.

        Returns:
            The started ``threading.Thread``, so callers can ``join()`` it.
            Earlier callers that ignored the return value are unaffected.
        """

        def load():
            try:
                handler = DatasetHandler()
                descriptions = handler.load_descriptions(max_records=500)
            except Exception:
                # Descriptions are optional: log the failure and keep the
                # fallback behavior of get_species_info().
                logger.exception("Failed to load species descriptions")
                return
            # Rebind in one step, only after loading has completed.
            self.species_descriptions = descriptions

        # Daemon thread: a best-effort prefetch should not keep the process alive.
        thread = threading.Thread(
            target=load, name="species-descriptions-loader", daemon=True
        )
        thread.start()
        return thread

    def _label_for(self, idx):
        """Map a model output index to a human-readable class name."""
        config = self.model.config
        if config.num_labels == len(self.labels):
            return self.labels[idx]
        # The classification head does not match the curated 10-species list
        # (e.g. a stock ImageNet checkpoint such as the default). In that case,
        # indexing self.labels would be meaningless or raise IndexError, so use
        # the model's own class names instead.
        return getattr(config, "id2label", {}).get(idx, "Unknown Insect")

    def predict(self, image):
        """Classify ``image`` and return ``(label, confidence_percent)``.

        Returns:
            ``("Unknown Insect", confidence)`` when the top-class confidence is
            below ``confidence_threshold``. ``("Error Processing Image", 0.0)``
            when preprocessing or inference fails; the exception is logged.
            Otherwise, the predicted label and its confidence in percent.
        """
        try:
            if isinstance(image, Image.Image) and image.mode != "RGB":
                # The processor normalizes with 3-channel mean/std, so RGBA,
                # grayscale, or palette images must be converted first.
                image = image.convert("RGB")
            inputs = self.processor(images=image, return_tensors="pt")
            with torch.no_grad():
                outputs = self.model(**inputs)
            probabilities = torch.nn.functional.softmax(outputs.logits, dim=-1)
            confidence, predicted_idx = probabilities.max(dim=-1)
            confidence = confidence.item() * 100
            if confidence < self.confidence_threshold:
                return "Unknown Insect", confidence
            return self._label_for(predicted_idx.item()), confidence
        except Exception:
            # Callers rely on predict() never raising; log instead of silently
            # swallowing the error.
            logger.exception("Failed to classify image")
            return "Error Processing Image", 0.0

    def get_species_info(self, species):
        """Return the stored description for ``species``, or a fallback message.

        NOTE(review): assumes ``load_descriptions`` returns a mapping keyed by
        the same label strings that ``predict`` returns. Confirm against
        ``DatasetHandler``.
        """
        return self.species_descriptions.get(
            species,
            "Information not available. Consider updating your dataset for this species.",
        )