# Bug-O-Scope / model_utils.py
# Author: dalybuilds
# Commit b04c0f2 (verified): "Update model_utils.py"
# 3. model_utils.py
# Model management (loading, prediction, and species information)
import logging
import threading

import torch
from PIL import Image
from transformers import ViTForImageClassification, ViTImageProcessor

from dataset_utils import DatasetHandler
class BugClassifier:
    """Classify insect images with a ViT model and look up species descriptions.

    Species descriptions are loaded on a background thread started by
    ``__init__``, so ``get_species_info`` returns its fallback message until
    that load has finished.
    """

    # Class-level logger so failures in predict() and in the background loader
    # are recorded instead of being silently discarded.
    _logger = logging.getLogger(__name__)

    def __init__(self, model_path="google/vit-base-patch16-224",
                 processor_path="google/vit-base-patch16-224"):
        """Load the model and image processor, then start loading descriptions.

        Args:
            model_path: Model id or local path passed to
                ``ViTForImageClassification.from_pretrained``.
            processor_path: Model id or local path passed to
                ``ViTImageProcessor.from_pretrained``. Defaults to the
                processor this class has always used.
        """
        self.model = ViTForImageClassification.from_pretrained(model_path)
        self.model.eval()
        # Load the processor once here; calling from_pretrained inside
        # predict() would repeat the load for every image.
        self.processor = ViTImageProcessor.from_pretrained(processor_path)
        # Display names for a 10-output classification head. They are only
        # used when the model's logits width matches (see _label_for).
        self.labels = [
            "Seven-spotted Ladybug", "Monarch Butterfly", "Carpenter Ant",
            "Japanese Beetle", "Garden Spider", "Green Grasshopper",
            "Luna Moth", "Common Dragonfly", "Honey Bee", "Paper Wasp",
        ]
        self.species_descriptions = {}
        self._descriptions_thread = None
        self.load_species_descriptions()

    def load_species_descriptions(self):
        """Start loading species descriptions on a background thread.

        Returns immediately. On success, ``self.species_descriptions`` is
        replaced with the result of
        ``DatasetHandler().load_descriptions(max_records=500)``. Failures are
        logged and leave the existing descriptions in place. The thread is a
        daemon so a slow load does not keep the interpreter alive at exit; it
        is kept on ``self._descriptions_thread`` so callers can ``join()`` it.
        """
        def load():
            try:
                descriptions = DatasetHandler().load_descriptions(max_records=500)
            except Exception:
                self._logger.exception("Failed to load species descriptions")
                return
            # NOTE(review): assumes load_descriptions returns a mapping of
            # species name -> description text; confirm in dataset_utils.
            if descriptions is not None:
                self.species_descriptions = descriptions

        self._descriptions_thread = threading.Thread(
            target=load, name="species-descriptions-loader", daemon=True
        )
        self._descriptions_thread.start()

    def _label_for(self, index, num_classes):
        """Return the display label for output ``index`` of a ``num_classes``-wide head.

        ``self.labels`` only describes a head with exactly ``len(self.labels)``
        outputs. For any other head (per its model card, the default
        google/vit-base-patch16-224 checkpoint has 1000 ImageNet classes),
        indexing ``self.labels`` would either mislabel or raise IndexError, so
        the model's own ``config.id2label`` mapping is used instead.
        """
        if num_classes == len(self.labels):
            return self.labels[index]
        config = getattr(self.model, "config", None)
        id2label = getattr(config, "id2label", None) or {}
        return id2label.get(index, "Unknown Insect")

    def predict(self, image, min_confidence=30.0):
        """Classify a single image.

        Args:
            image: Input accepted by the image processor (e.g. a PIL image).
                Only one image is supported, because the result is read back
                with ``.item()``.
            min_confidence: Confidence percentage below which the result is
                reported as "Unknown Insect". Defaults to the original 30.

        Returns:
            ``(label, confidence)``, where ``confidence`` is a percentage in
            [0, 100]. The result is ``("Unknown Insect", confidence)`` when
            ``confidence < min_confidence``, and
            ``("Error Processing Image", 0.0)`` if preprocessing or inference
            raises.
        """
        try:
            inputs = self.processor(images=image, return_tensors="pt")
            with torch.no_grad():
                outputs = self.model(**inputs)
            logits = outputs.logits
            probabilities = torch.nn.functional.softmax(logits, dim=-1)
            confidence, predicted_idx = probabilities.max(dim=-1)
            confidence = confidence.item() * 100
            label = self._label_for(predicted_idx.item(), logits.shape[-1])
        except Exception:
            # Best-effort for the caller: return a sentinel result, but keep
            # the traceback in the logs instead of discarding it.
            self._logger.exception("Failed to classify image")
            return "Error Processing Image", 0.0
        if confidence < min_confidence:
            return "Unknown Insect", confidence
        return label, confidence

    def get_species_info(self, species):
        """Return the stored description for ``species``, or a fallback message.

        The fallback is also returned while the background description load
        has not yet finished.
        """
        return self.species_descriptions.get(
            species, "Information not available. Consider updating your dataset for this species."
        )