import torch
from torchvision.models import resnet18
import torchvision.transforms as T
import json

# Per-channel normalization statistics handed to T.Normalize in
# transform_image (one value per image channel). These match the
# ImageNet statistics documented for torchvision's pretrained models.
mean = (
    0.485,
    0.456,
    0.406,
)
std = (
    0.229,
    0.224,
    0.225,
)

def load_classes():
    '''
    Load the class labels from 'utils/imagenet-simple-labels.json'.

    The path is relative to the current working directory. The file is
    decoded as UTF-8 explicitly so the result does not depend on the
    platform's default locale encoding.

    Returns the parsed JSON content (indexed by class id in
    class_id_to_label).
    Raises FileNotFoundError if the label file is missing and
    json.JSONDecodeError if it is not valid JSON.
    '''
    with open('utils/imagenet-simple-labels.json', encoding='utf-8') as f:
        labels = json.load(f)
    return labels

# Labels are loaded lazily on the first lookup and reused afterwards.
_labels_cache = None


def class_id_to_label(i):
    '''
    Input int: class index
    Returns class name

    The label file is read through load_classes() only on the first
    successful call; later calls index into the cached labels instead of
    re-opening and re-parsing the JSON file for every lookup.

    Raises IndexError if the index is out of range for the label list.
    Any error raised by load_classes() propagates, and nothing is cached
    in that case, so the next call retries the load.
    '''
    global _labels_cache
    if _labels_cache is None:
        _labels_cache = load_classes()
    return _labels_cache[i]

def load_model():
    '''
    Build a ResNet-18 and fill it with the weights stored in
    'utils/resnet18-weights.pth' (path relative to the working directory).

    The checkpoint is mapped onto the CPU while loading, and the model is
    switched to evaluation mode before it is returned.

    NOTE: torch.load may unpickle arbitrary objects depending on the
    installed torch version; only point this at a trusted checkpoint.
    '''
    net = resnet18()
    checkpoint = torch.load('utils/resnet18-weights.pth', map_location='cpu')
    net.load_state_dict(checkpoint)
    # eval() returns the module itself, so it can be returned directly.
    return net.eval()

def transform_image(img):
    '''
    Preprocess an image for the model.

    Input: PIL img
    Returns: tensor produced by resizing to 224x224, center-cropping to
    100x100, converting to a tensor and normalizing with the module-level
    mean and std.

    NOTE(review): cropping to 100 after resizing to 224 keeps only the
    central part of the image; the usual ImageNet recipe is resize 256 /
    center-crop 224. Confirm the 100-pixel crop is intentional.
    '''
    steps = [
        T.Resize((224, 224)),
        T.CenterCrop(100),
        T.ToTensor(),
        T.Normalize(mean, std),
    ]
    pipeline = T.Compose(steps)
    return pipeline(img)