FastAPI / utils /model_func.py
Solar-Iz's picture
Upload 10 files
dbfc835
raw
history blame
647 Bytes
import functools
import json

import torch
from PIL import Image
from torchvision import models, transforms
def load_classes(path='utils/imagenet-simple-labels.json'):
    """Load the class labels from a JSON file.

    Args:
        path: Location of the JSON labels file. The default is resolved
            relative to the current working directory, so it only works when
            the process is started from the project root.

    Returns:
        The decoded JSON document. Presumably this is a list of label strings
        indexed by class id; confirm against the bundled file.
    """
    # Explicit UTF-8: without it the encoding is locale-dependent (e.g. cp1252
    # on Windows) and non-ASCII labels may be mis-decoded or fail to load.
    with open(path, encoding='utf-8') as f:
        return json.load(f)
@functools.lru_cache(maxsize=1)
def _cached_labels():
    """Return the labels from ``load_classes()``, loaded once and kept as a tuple.

    The tuple keeps the shared cached value immutable.
    """
    return tuple(load_classes())


def class_id_to_label(i):
    """Return the label for class index *i*.

    The labels file is read and parsed only on the first call. Later calls
    reuse the cached labels instead of re-reading the JSON file per lookup.

    Args:
        i: Class index. Any object usable as a sequence index works, such as
            an int or a value exposing ``__index__``.

    Returns:
        The label stored at position *i*.

    Raises:
        IndexError: If *i* is outside the range of loaded labels.
    """
    return _cached_labels()[i]
def load_model():
    """Build an ImageNet-pretrained MobileNetV2 and switch it to eval mode.

    On torchvision >= 0.13 this uses the ``weights=`` API. The legacy
    ``pretrained=True`` flag is deprecated there and maps to the same
    ``IMAGENET1K_V1`` weights. Older torchvision releases, which lack
    ``MobileNet_V2_Weights``, fall back to ``pretrained=True``.

    Returns:
        The MobileNetV2 model with ``eval()`` already applied.
    """
    weights_enum = getattr(models, "MobileNet_V2_Weights", None)
    if weights_enum is not None:
        # Same weights as the legacy pretrained=True, without the deprecation warning.
        model = models.mobilenet_v2(weights=weights_enum.IMAGENET1K_V1)
    else:
        model = models.mobilenet_v2(pretrained=True)
    # Inference mode for dropout / batch-norm layers.
    model.eval()
    return model
def transform_image(img):
    """Preprocess an image into a normalized 3x224x224 tensor.

    The pipeline resizes directly to 224x224, which does not preserve the
    aspect ratio. It then converts the image to a float tensor and normalizes
    it with the per-channel mean/std constants below.

    Args:
        img: A PIL image (or another input accepted by ``ToTensor``). PIL
            images in any mode other than RGB, such as RGBA, grayscale "L" or
            palette "P", are converted to RGB first.

    Returns:
        The preprocessed image tensor.
    """
    # Normalize uses 3-channel mean/std, so 1- or 4-channel images would fail
    # with a shape error. RGBA -> RGB drops the alpha channel.
    if isinstance(img, Image.Image) and img.mode != "RGB":
        img = img.convert("RGB")
    transform = transforms.Compose([
        transforms.Resize((224, 224)),
        transforms.ToTensor(),
        transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225]),
    ])
    return transform(img)