import os
from typing import Union, Tuple

import torchvision.transforms.functional as F
from torch import Tensor
from torchvision.transforms import transforms

from auxiliary.settings import USE_CONFIDENCE_WEIGHTED_POOLING
from auxiliary.utils import correct, rescale, scale
from classes.core.Model import Model
from classes.fc4.FC4 import FC4


class ModelFC4(Model):
    """Model wrapper that runs inference and optimization steps with the FC4 network."""

    def __init__(self):
        super().__init__()
        # NOTE(review): self._device (and self._optimizer used in optimize) are assumed
        # to be provided by the Model base class — confirm.
        self._network = FC4().to(self._device)

    def predict(self, img: Tensor, return_steps: bool = False) -> Union[Tensor, Tuple]:
        """
        Performs inference on the input image using the FC4 method.
        @param img: the image for which an illuminant colour has to be estimated
        @param return_steps: whether or not to also return the per-patch estimates and confidence weights.
        Setting this flag to True requires confidence-weighted pooling to be active
        (USE_CONFIDENCE_WEIGHTED_POOLING).
        @return: the colour estimate as a Tensor. If "return_steps" is set to True, a tuple
        (pred, rgb, confidence) is returned instead, i.e. the colour estimate together with the
        per-patch colour estimates and the confidence weights (used for visualizations)
        @raise ValueError: if "return_steps" is True but confidence-weighted pooling is disabled
        """
        # Fail fast instead of silently returning a single tensor where the caller expects a tuple.
        if return_steps and not USE_CONFIDENCE_WEIGHTED_POOLING:
            raise ValueError("return_steps=True requires USE_CONFIDENCE_WEIGHTED_POOLING to be enabled")

        if USE_CONFIDENCE_WEIGHTED_POOLING:
            # With confidence-weighted pooling the network returns three outputs.
            pred, rgb, confidence = self._network(img)
            if return_steps:
                return pred, rgb, confidence
            return pred

        # Without pooling, the network output is returned as-is.
        return self._network(img)

    def optimize(self, img: Tensor, label: Tensor) -> float:
        """
        Performs a single optimization step on one input.
        @param img: the input image
        @param label: the ground-truth illuminant for the image
        @return: the loss value for this step, as a Python float
        """
        self._optimizer.zero_grad()
        pred = self.predict(img)
        # get_loss is not defined in this class; presumably inherited from Model.
        loss = self.get_loss(pred, label)
        loss.backward()
        self._optimizer.step()
        return loss.item()