import functools
import json

import torch
import torchvision.transforms as T
from torchvision.models import resnet18

# Per-channel statistics passed to T.Normalize in transform_image.
mean = (0.485, 0.456, 0.406)
std = (0.229, 0.224, 0.225)

# NOTE(review): these paths are resolved relative to the current working
# directory, not to this module; confirm callers always run from the
# directory that contains `utils/`.
_LABELS_PATH = 'utils/imagenet-simple-labels.json'
_WEIGHTS_PATH = 'utils/resnet18-weights.pth'


@functools.lru_cache(maxsize=1)
def _cached_labels():
    """Read the label file once and memoize its contents as a tuple.

    Indexing the result with an int class id (as class_id_to_label does)
    implies the JSON top level is an array.  The tuple is immutable, so the
    cached value cannot be modified by callers.
    """
    with open(_LABELS_PATH, encoding='utf-8') as f:
        return tuple(json.load(f))


def load_classes():
    """Return the IMAGENET class names as a list.

    The label file is read and parsed only on the first call.  Later calls
    reuse the cached data, so edits to the file are not picked up until the
    process restarts.  Each call returns a new list, which the caller may
    mutate freely.

    Raises:
        OSError: if the label file cannot be opened.
        json.JSONDecodeError: if the label file is not valid JSON.
    """
    return list(_cached_labels())


def class_id_to_label(i):
    """Return the class name for class index ``i``.

    Args:
        i: class index (any value usable as a sequence index).

    Returns:
        The label string stored at position ``i`` of the label file.

    Raises:
        IndexError: if ``i`` is outside the range of the label list.
    """
    # Reuse the memoized labels instead of re-reading the JSON on every call.
    return _cached_labels()[i]


def load_model():
    """Return a ResNet-18 in eval mode with weights loaded from disk.

    The state dict is mapped onto the CPU (``map_location='cpu'``), so this
    works on machines without a GPU.
    """
    model = resnet18()
    # NOTE: torch.load unpickles the file, which can execute arbitrary code.
    # Only point this at trusted weights; on torch>=1.13 consider passing
    # weights_only=True.
    state_dict = torch.load(_WEIGHTS_PATH, map_location='cpu')
    model.load_state_dict(state_dict)
    # Switch layers such as BatchNorm to inference behavior.
    model.eval()
    return model


def transform_image(img, *, resize=(224, 224), crop_size=100):
    """Preprocess a PIL image into a normalized tensor.

    Pipeline: ``T.Resize(resize)`` -> ``T.CenterCrop(crop_size)`` ->
    ``T.ToTensor()`` -> ``T.Normalize(mean, std)``.

    Args:
        img: PIL image.
        resize: size passed to ``T.Resize`` (default ``(224, 224)``).
        crop_size: size passed to ``T.CenterCrop`` (default ``100``).

    Returns:
        The transformed (C, H, W) tensor; no batch dimension is added.
    """
    # NOTE(review): center-cropping a 224x224 resize down to 100x100 discards
    # most of the image and differs from the usual torchvision ImageNet eval
    # pipeline (Resize(256) + CenterCrop(224)).  The default is kept for
    # backward compatibility; confirm it is intended.
    trnsfrms = T.Compose(
        [
            T.Resize(resize),
            T.CenterCrop(crop_size),
            T.ToTensor(),
            T.Normalize(mean, std),
        ]
    )
    return trnsfrms(img)