# 3. model_utils.py
# Model management (loading, prediction, and species information)
import logging
import threading

import torch
from PIL import Image
from transformers import ViTForImageClassification, ViTImageProcessor

from dataset_utils import DatasetHandler

class BugClassifier:
    """Classify insect images with a ViT model and look up species descriptions.

    Species descriptions are loaded in a background thread started by the
    constructor, so ``get_species_info`` returns its fallback message until
    that load has finished (or if it failed).
    """

    # Predictions whose top softmax probability (in percent) is below this
    # value are reported as "Unknown Insect".
    UNKNOWN_THRESHOLD = 30.0

    _logger = logging.getLogger(__name__)

    def __init__(self, model_path="google/vit-base-patch16-224",
                 processor_path="google/vit-base-patch16-224"):
        """Load the model and image processor, then start loading descriptions.

        Args:
            model_path: Model id or local path passed to
                ``ViTForImageClassification.from_pretrained``.
            processor_path: Model id or local path passed to
                ``ViTImageProcessor.from_pretrained``. Defaults to the
                checkpoint the processor was previously hard-coded to; it is
                independent of ``model_path`` unless overridden.
        """
        self.model = ViTForImageClassification.from_pretrained(model_path)
        self.model.eval()
        # Load the processor once here instead of on every predict() call.
        self.processor = ViTImageProcessor.from_pretrained(processor_path)
        self.labels = [
            "Seven-spotted Ladybug", "Monarch Butterfly", "Carpenter Ant",
            "Japanese Beetle", "Garden Spider", "Green Grasshopper",
            "Luna Moth", "Common Dragonfly", "Honey Bee", "Paper Wasp"
        ]
        self.species_descriptions = {}
        self._descriptions_thread = None
        self.load_species_descriptions()

    def load_species_descriptions(self):
        """Start a background thread that fills ``self.species_descriptions``.

        Returns:
            The started ``threading.Thread`` (also kept on
            ``self._descriptions_thread``); callers may ``join()`` it to wait
            until the descriptions are available.
        """
        def load():
            try:
                handler = DatasetHandler()
                descriptions = handler.load_descriptions(max_records=500)
            except Exception:
                # Best effort: keep the current (possibly empty) mapping so
                # get_species_info() keeps returning its fallback message.
                self._logger.exception("Failed to load species descriptions")
                return
            # Publish the finished mapping in a single assignment rather than
            # filling the dict in place; `or {}` guards against a None result.
            self.species_descriptions = descriptions or {}

        thread = threading.Thread(target=load)
        thread.start()
        self._descriptions_thread = thread
        return thread

    def predict(self, image):
        """Classify *image* and return ``(label, confidence_percent)``.

        Args:
            image: An image accepted by the ViT image processor (e.g. a PIL
                image). Objects with a ``mode`` other than ``"RGB"`` are
                converted with ``image.convert("RGB")`` first.

        Returns:
            ``(label, confidence)``, where confidence is the top softmax
            probability in percent. The label is ``"Unknown Insect"`` when the
            confidence is below ``UNKNOWN_THRESHOLD`` or the predicted class
            index has no entry in ``self.labels``. If preprocessing or
            inference raises, ``("Error Processing Image", 0.0)`` is returned
            and the exception is logged.
        """
        try:
            # RGBA / palette / grayscale images may otherwise reach the
            # processor with a channel count it does not expect.
            if getattr(image, "mode", "RGB") != "RGB":
                image = image.convert("RGB")
            inputs = self.processor(images=image, return_tensors="pt")
            with torch.no_grad():
                outputs = self.model(**inputs)
                probabilities = torch.nn.functional.softmax(outputs.logits, dim=-1)
                confidence, predicted_idx = probabilities.max(dim=-1)
            confidence = confidence.item() * 100
            index = predicted_idx.item()
        except Exception:
            self._logger.exception("Failed to classify image")
            return "Error Processing Image", 0.0

        if confidence < self.UNKNOWN_THRESHOLD:
            return "Unknown Insect", confidence
        # NOTE(review): the classifier head may have more outputs than
        # self.labels (e.g. a checkpoint not fine-tuned on these species).
        # Such indices used to raise IndexError and be reported as an error;
        # treat them as an unrecognised insect instead.
        if not 0 <= index < len(self.labels):
            return "Unknown Insect", confidence
        return self.labels[index], confidence

    def get_species_info(self, species):
        """Return the loaded description for *species*, or a fallback message.

        Because descriptions load asynchronously, the fallback is returned
        until the background load has completed.
        """
        return self.species_descriptions.get(
            species, "Information not available. Consider updating your dataset for this species."
        )