File size: 775 Bytes
bdb955e |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 |
from typing import Union
from pathlib import Path
import torch
from torchvision.models.segmentation import deeplabv3_resnet101
from SoccerNet.Evaluation.utils_calibration import SoccerPitch
class InferenceSegmentationModel:
    """Inference-only wrapper around a torchvision DeepLabV3-ResNet101 model.

    The network is built with ``len(SoccerPitch.lines_classes) + 1`` output
    classes (the extra class is presumably the background -- confirm against
    the training code) and initialised from a saved checkpoint.
    """

    def __init__(self, checkpoint: Union[str, Path], device) -> None:
        """Build the model, load its weights and switch it to eval mode.

        Args:
            checkpoint: Path to a file readable by ``torch.load`` whose
                top-level object has a ``"model"`` entry holding the state
                dict.
            device: Device the checkpoint is mapped to and the model is moved
                to (anything accepted by ``map_location`` and ``Module.to``).
        """
        self.device = device
        self.model = deeplabv3_resnet101(
            num_classes=len(SoccerPitch.lines_classes) + 1, aux_loss=True
        )
        # SECURITY NOTE: weights_only=False performs full unpickling, which can
        # execute arbitrary code embedded in the file. Only load checkpoints
        # that come from a trusted source.
        checkpoint_data = torch.load(
            checkpoint, map_location=self.device, weights_only=False
        )
        # strict=False: keys that are missing from, or unexpected in, the
        # checkpoint are ignored rather than raising an error.
        self.model.load_state_dict(checkpoint_data["model"], strict=False)
        self.model.to(self.device)
        self.model.eval()

    def inference(self, img_batch: torch.Tensor) -> torch.Tensor:
        """Return the arg-max class index map for a batch of images.

        The forward pass runs under ``torch.no_grad()``: the result is an
        integer index tensor that cannot be back-propagated through, so
        recording an autograd graph would only waste memory and time.

        Args:
            img_batch: Input batch passed unchanged to the model; it is not
                moved to ``self.device`` here, so the caller is expected to
                have placed it on the model's device already.

        Returns:
            The indices of the maximum along dimension 1 of the model's
            ``"out"`` tensor (the class dimension for torchvision
            segmentation models).
        """
        with torch.no_grad():
            logits = self.model(img_batch)["out"]
            return logits.argmax(1)
|