import os
from typing import Union, Tuple

import torchvision.transforms.functional as F
from torch import Tensor
from torchvision.transforms import transforms

from auxiliary.settings import USE_CONFIDENCE_WEIGHTED_POOLING
from auxiliary.utils import correct, rescale, scale
from classes.core.Model import Model
from classes.fc4.FC4 import FC4


class ModelFC4(Model):

    def __init__(self):
        super().__init__()
        # Instantiate the FC4 network and move it to the device selected by the base Model class
        self._network = FC4().to(self._device)

    def predict(self, img: Tensor, return_steps: bool = False) -> Union[Tensor, Tuple]:
        """
        Performs inference on the input image using the FC4 method.
        @param img: the image for which an illuminant colour has to be estimated
        @param return_steps: whether or not to also return the per-patch estimates and the confidence weights.
            When this flag is set to True, confidence-weighted pooling must be active
        @return: the colour estimate as a Tensor. If "return_steps" is set to True, the per-patch colour estimates
            and the confidence weights are also returned (used for visualizations)
        """
        if USE_CONFIDENCE_WEIGHTED_POOLING:
            # With confidence-weighted pooling, the network also returns the per-patch RGB estimates
            # and the confidence map used to weight them into the global prediction
            pred, rgb, confidence = self._network(img)
            if return_steps:
                return pred, rgb, confidence
            return pred
        return self._network(img)

    def optimize(self, img: Tensor, label: Tensor) -> float:
        # Single training step: forward pass, loss computation, backpropagation and parameter update
        self._optimizer.zero_grad()
        pred = self.predict(img)
        loss = self.get_loss(pred, label)
        loss.backward()
        self._optimizer.step()
        return loss.item()
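

# ------------------------------------------------------------------------------------------------
# Minimal usage sketch (illustrative only). The input resolution, the random tensors and the
# commented-out training call below are assumptions made for demonstration, not values or an API
# prescribed by this module.
# ------------------------------------------------------------------------------------------------
if __name__ == "__main__":
    import torch

    model = ModelFC4()

    # Inference on a batch with one RGB image in [0, 1]. With confidence-weighted pooling enabled
    # in auxiliary.settings, return_steps=True also yields the per-patch estimates and confidences
    img = torch.rand(1, 3, 512, 512)
    with torch.no_grad():
        if USE_CONFIDENCE_WEIGHTED_POOLING:
            pred, rgb, confidence = model.predict(img, return_steps=True)
        else:
            pred = model.predict(img)
    print("Estimated illuminant:", pred)

    # Training step: requires an optimizer to have been set up on the model first (the base Model
    # class is assumed to provide a method for this; adapt to its actual API)
    # label = torch.rand(1, 3)
    # loss = model.optimize(img, label)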