# NPRC24 / MiAlgo / classes/fc4/ModelFC4.py
# Artyom
# MiAlgo
# 82567db verified
import os
from typing import Union, Tuple
import torchvision.transforms.functional as F
from torch import Tensor
from torchvision.transforms import transforms
from auxiliary.settings import USE_CONFIDENCE_WEIGHTED_POOLING
from auxiliary.utils import correct, rescale, scale
from classes.core.Model import Model
from classes.fc4.FC4 import FC4
class ModelFC4(Model):
    """
    Model wrapper around the FC4 network.
    Instantiates an FC4 network on the device provided by the Model base class and exposes
    inference (predict) and a single training step (optimize).
    """

    def __init__(self):
        super().__init__()
        # NOTE(review): self._device is expected to be initialised by Model.__init__ — confirm in classes/core/Model.py
        self._network = FC4().to(self._device)

    def predict(self, img: Tensor, return_steps: bool = False) -> Union[Tensor, Tuple[Tensor, Tensor, Tensor]]:
        """
        Performs inference on the input image using the FC4 method.
        @param img: the image for which an illuminant colour has to be estimated
        @param return_steps: whether or not to also return the per-patch estimates and confidence weights. When this
        flag is set to True, confidence-weighted pooling (USE_CONFIDENCE_WEIGHTED_POOLING) must be active
        @return: the colour estimate as a Tensor. If "return_steps" is set to True, the tuple
        (colour estimate, per-patch colour estimates, confidence weights) is returned instead (used for visualizations)
        @raise ValueError: if "return_steps" is True while confidence-weighted pooling is disabled
        """
        if not USE_CONFIDENCE_WEIGHTED_POOLING:
            # Without confidence-weighted pooling the network yields only the final estimate, so the
            # intermediate steps cannot be returned; fail loudly instead of silently returning a lone Tensor
            # that a caller would then try to unpack as a 3-tuple.
            if return_steps:
                raise ValueError("return_steps=True requires USE_CONFIDENCE_WEIGHTED_POOLING to be enabled")
            return self._network(img)

        pred, rgb, confidence = self._network(img)
        if return_steps:
            return pred, rgb, confidence
        return pred

    def optimize(self, img: Tensor, label: Tensor) -> float:
        """
        Runs a single optimization step on one batch.
        Clears the gradients, predicts the illuminant for "img", computes the loss against "label" via
        get_loss, back-propagates it and updates the parameters with the optimizer.
        @param img: the input image(s)
        @param label: the ground-truth illuminant(s) the prediction is compared against
        @return: the loss value of this step as a Python float
        """
        # NOTE(review): self._optimizer is expected to be set up by the Model base class before training — confirm
        self._optimizer.zero_grad()
        pred = self.predict(img)
        loss = self.get_loss(pred, label)
        loss.backward()
        self._optimizer.step()
        return loss.item()