File size: 775 Bytes
bdb955e
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
from typing import Union
from pathlib import Path
import torch

from torchvision.models.segmentation import deeplabv3_resnet101
from SoccerNet.Evaluation.utils_calibration import SoccerPitch



class InferenceSegmentationModel:
    """Inference-only wrapper around a DeepLabV3-ResNet101 segmentation model.

    The network is built with ``len(SoccerPitch.lines_classes) + 1`` output
    classes (the extra class is presumably the background -- confirm against
    the training code), populated from a checkpoint, moved to ``device`` and
    put in evaluation mode.
    """

    def __init__(self, checkpoint: Union[str, Path], device) -> None:
        """Build the model and load its weights.

        Args:
            checkpoint: Path to a file readable by ``torch.load`` that holds a
                mapping with a ``"model"`` entry containing the state dict.
            device: Forwarded to ``torch.load`` as ``map_location`` and to
                ``Module.to``; also stored as ``self.device``.

        Raises:
            KeyError: If the loaded checkpoint has no ``"model"`` entry.
        """
        self.device = device
        self.model = deeplabv3_resnet101(
            num_classes=len(SoccerPitch.lines_classes) + 1, aux_loss=True
        )
        # SECURITY: weights_only=False performs full unpickling, which can
        # execute arbitrary code. Only load checkpoints from trusted sources.
        checkpoint_data = torch.load(checkpoint, map_location=self.device, weights_only=False)
        # strict=False: missing or unexpected keys are ignored without error,
        # so a mismatching checkpoint will not fail loudly here.
        self.model.load_state_dict(checkpoint_data["model"], strict=False)
        self.model.to(self.device)
        self.model.eval()

    def inference(self, img_batch) -> torch.Tensor:
        """Return the per-pixel arg-max class indices for a batch.

        The forward pass runs under ``torch.no_grad()`` so that no autograd
        graph is recorded; the result is an index tensor and therefore never
        differentiable, so callers observe the same values as before.

        Args:
            img_batch: Input passed to the model unchanged. It is not moved to
                ``self.device`` here -- presumably the caller has already
                placed it on the same device as the model.

        Returns:
            The ``"out"`` entry of the model output reduced with
            ``argmax`` over dimension 1 (the class dimension for torchvision
            segmentation models).
        """
        with torch.no_grad():
            logits = self.model(img_batch)["out"]
            return logits.argmax(1)