plant_disease_classification / plant_disease_classifier.py
ryanwang058
Fix model input
8cff122
raw
history blame
6.26 kB
import os
from PIL import Image
from torch.utils.data import Dataset, DataLoader
import torch
import torchvision.transforms as transforms
from torchvision.models import resnet50, vit_b_32
from transformers import LevitForImageClassification, logging
logging.set_verbosity_error()
class PlantDiseaseClassifier:
    """Classify plant-leaf images into one of the labels in ``class_names``.

    Builds one of three supported architectures (``"resnet"``, ``"levit"``
    or ``"vit"``), loads fine-tuned weights from a state-dict checkpoint,
    and offers single-image prediction plus accuracy over a directory tree.
    """

    # Position in this list is the class index produced by the models;
    # the order must match the order used when the checkpoints were trained.
    class_names = [
        'Apple___Apple_scab',
        'Apple___Black_rot',
        'Apple___Cedar_apple_rust',
        'Apple___healthy',
        'Blueberry___healthy',
        'Cherry_(including_sour)___Powdery_mildew',
        'Cherry_(including_sour)___healthy',
        'Corn_(maize)___Cercospora_leaf_spot Gray_leaf_spot',
        'Corn_(maize)___Common_rust_',
        'Corn_(maize)___Northern_Leaf_Blight',
        'Corn_(maize)___healthy',
        'Grape___Black_rot',
        'Grape___Esca_(Black_Measles)',
        'Grape___Leaf_blight_(Isariopsis_Leaf_Spot)',
        'Grape___healthy',
        'Orange___Haunglongbing_(Citrus_greening)',
        'Peach___Bacterial_spot',
        'Peach___healthy',
        'Pepper,_bell___Bacterial_spot',
        'Pepper,_bell___healthy',
        'Potato___Early_blight',
        'Potato___Late_blight',
        'Potato___healthy',
        'Raspberry___healthy',
        'Soybean___healthy',
        'Squash___Powdery_mildew',
        'Strawberry___Leaf_scorch',
        'Strawberry___healthy',
        'Tomato___Bacterial_spot',
        'Tomato___Early_blight',
        'Tomato___Late_blight',
        'Tomato___Leaf_Mold',
        'Tomato___Septoria_leaf_spot',
        'Tomato___Spider_mites Two-spotted_spider_mite',
        'Tomato___Target_Spot',
        'Tomato___Tomato_Yellow_Leaf_Curl_Virus',
        'Tomato___Tomato_mosaic_virus',
        'Tomato___healthy',
    ]

    def __init__(self, model_type, model_path, batch_size=32):
        """Load the requested model and prepare the preprocessing pipeline.

        Args:
            model_type: One of ``"resnet"``, ``"levit"`` or ``"vit"``.
            model_path: Path to a checkpoint holding the model's state dict.
            batch_size: Batch size used by ``calculate_accuracy``.

        Raises:
            ValueError: If ``model_type`` is not a supported architecture.
        """
        self.device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
        self.batch_size = batch_size
        self.model_type = model_type

        # Initialize and load the appropriate model, then switch to eval
        # mode so layers such as dropout/batch-norm behave deterministically.
        self.model = self._load_model(model_type, model_path)
        self.model = self.model.to(self.device)
        self.model.eval()

        # Data transformation pipeline: fixed 224x224 input, tensor
        # conversion, then per-channel normalisation with the mean/std
        # commonly used for ImageNet-pretrained backbones.
        self.data_transforms = transforms.Compose([
            transforms.Resize((224, 224)),
            transforms.ToTensor(),
            transforms.Normalize([0.485, 0.456, 0.406], [0.229, 0.224, 0.225]),
        ])

    def _read_state_dict(self, model_path):
        """Read a checkpoint from ``model_path`` onto ``self.device``."""
        # NOTE(review): torch.load unpickles the file, so only checkpoints
        # from a trusted source should be loaded. Newer PyTorch releases
        # offer ``weights_only=True``; it is not passed here to stay
        # compatible with older installs.
        return torch.load(model_path, map_location=self.device)

    def _load_model(self, model_type, model_path):
        """Build the architecture named by ``model_type`` and load its weights.

        Args:
            model_type: ``"resnet"``, ``"levit"`` or ``"vit"``.
            model_path: Path to the checkpoint for that architecture.

        Returns:
            The model with the checkpoint's weights loaded (not yet moved to
            the target device and not yet in eval mode).

        Raises:
            ValueError: If ``model_type`` is not supported.
        """
        num_classes = len(self.class_names)

        if model_type == "resnet":
            # The deprecated ``pretrained=False`` keyword is omitted: with no
            # argument torchvision builds a randomly initialised network on
            # both old and new releases, without a deprecation warning.
            model = resnet50()
            model.fc = torch.nn.Linear(model.fc.in_features, num_classes)
            model.load_state_dict(self._read_state_dict(model_path))
        elif model_type == "levit":
            model = LevitForImageClassification.from_pretrained(
                "facebook/levit-128S",
                num_labels=num_classes,
                ignore_mismatched_sizes=True,
            )
            state_dict = self._read_state_dict(model_path)
            # Drop the "classifier_distill*" entries before the strict load;
            # the remaining keys must match the model exactly.
            filtered_state_dict = {
                key: value
                for key, value in state_dict.items()
                if not key.startswith("classifier_distill")
            }
            model.load_state_dict(filtered_state_dict)
        elif model_type == "vit":
            # See the note on the resnet branch about ``pretrained``.
            model = vit_b_32(num_classes=num_classes)
            model.load_state_dict(self._read_state_dict(model_path))
        else:
            raise ValueError(f"Unsupported model type: {model_type}")
        return model

    def _extract_logits(self, outputs):
        """Return the raw logits tensor from a forward-pass result.

        The levit model returns an output object carrying ``.logits``; the
        other architectures return the logits tensor directly.
        """
        return outputs.logits if self.model_type in ["levit"] else outputs

    class _PlantDiseaseDataset(Dataset):
        """Images under ``directory_path/<class name>/`` paired with labels.

        Only sub-directories whose name appears in
        ``PlantDiseaseClassifier.class_names`` are scanned, and only files
        ending in ``.jpg`` or ``.png`` (case-insensitive) are kept.
        """

        _IMAGE_SUFFIXES = ('.jpg', '.png')

        def __init__(self, directory_path, transform=None):
            """Index every matching image file below ``directory_path``.

            Args:
                directory_path: Root directory with one folder per class.
                transform: Optional callable applied to each loaded image.
            """
            self.directory_path = directory_path
            self.transform = transform

            # Collect all images and their respective class labels
            self.image_files = []
            self.labels = []

            # Built once so each folder costs one dictionary lookup instead
            # of a linear ``list.index`` scan per image.
            label_by_name = {
                name: index
                for index, name in enumerate(PlantDiseaseClassifier.class_names)
            }

            # Listings are sorted so the sample order does not depend on the
            # arbitrary order in which the filesystem returns entries.
            for class_name in sorted(os.listdir(directory_path)):
                label = label_by_name.get(class_name)
                class_dir = os.path.join(directory_path, class_name)
                if label is None or not os.path.isdir(class_dir):
                    continue
                for img_file in sorted(os.listdir(class_dir)):
                    if img_file.lower().endswith(self._IMAGE_SUFFIXES):
                        self.image_files.append(os.path.join(class_dir, img_file))
                        self.labels.append(label)

        def __len__(self):
            """Return the number of indexed images."""
            return len(self.image_files)

        def __getitem__(self, idx):
            """Return ``(image, label)`` for sample ``idx``.

            The image is decoded as RGB and passed through ``transform``
            when one was supplied.
            """
            # ``convert`` returns a new, fully loaded image, so the file can
            # be closed as soon as the ``with`` block ends.
            with Image.open(self.image_files[idx]) as source:
                image = source.convert('RGB')
            label = self.labels[idx]
            if self.transform:
                image = self.transform(image)
            return image, label

    def calculate_accuracy(self, test_dir):
        """Return top-1 accuracy, as a percentage, over images in ``test_dir``.

        Args:
            test_dir: Root directory with one sub-folder per class name.

        Returns:
            Accuracy in the range 0-100, or ``0.0`` when no images are found.
        """
        dataset = self._PlantDiseaseDataset(test_dir, transform=self.data_transforms)
        dataloader = DataLoader(dataset, batch_size=self.batch_size, shuffle=False)

        correct = 0
        total = 0
        with torch.no_grad():
            for inputs, labels in dataloader:
                inputs = inputs.to(self.device)
                labels = labels.to(self.device)
                logits = self._extract_logits(self.model(inputs))
                _, preds = torch.max(logits, 1)
                correct += (preds == labels).sum().item()
                total += labels.size(0)

        # Guard against an empty dataset to avoid dividing by zero.
        accuracy = (correct / total) * 100 if total > 0 else 0.0
        return accuracy

    def predict(self, image):
        """Return the predicted class name for a single image.

        Args:
            image: An image object exposing ``mode`` and ``convert`` (the
                file imports PIL's ``Image`` for this purpose).

        Returns:
            The entry of ``class_names`` with the highest logit.
        """
        # Ensure the image is in RGB format if not already
        if image.mode != "RGB":
            image = image.convert("RGB")

        # Transform the image to match the model's input requirements and
        # add the leading batch dimension the model expects.
        transformed_image = self.data_transforms(image).unsqueeze(0)
        transformed_image = transformed_image.to(self.device)

        # Make prediction
        with torch.no_grad():
            logits = self._extract_logits(self.model(transformed_image))
            _, predicted_idx = torch.max(logits, 1)

        predicted_class = self.class_names[predicted_idx.item()]
        return predicted_class
def predict_image_with_all_models(image_path, classifiers):
    """Print each classifier's prediction for one image next to its true label.

    The true label is taken from the name of the directory that contains
    ``image_path``.

    Args:
        image_path: Path to the image file to classify.
        classifiers: Mapping of display name to classifier object. A
            classifier is called through ``predict(image)`` with the decoded
            RGB image; objects that expose ``predict_image(path)`` are
            called through that method with the path instead.
    """
    actual_disease = os.path.basename(os.path.dirname(image_path))
    print(f"Actual disease: {actual_disease}\n")

    # Decoded lazily and at most once, then shared by every classifier.
    image = None
    for model_name, classifier in classifiers.items():
        predict_from_path = getattr(classifier, "predict_image", None)
        if callable(predict_from_path):
            predicted_class = predict_from_path(image_path)
        else:
            # PlantDiseaseClassifier.predict takes an image object, not a
            # path, so the file has to be opened here.
            if image is None:
                with Image.open(image_path) as source:
                    # convert() returns a loaded copy that outlives the file.
                    image = source.convert("RGB")
            predicted_class = classifier.predict(image)
        print(f"Model: {model_name}, Predicted Class: {predicted_class}")