|
from typing import Union |
|
|
|
import torch |
|
from torch import nn, Tensor |
|
from torch.nn.functional import normalize |
|
|
|
from auxiliary.settings import USE_CONFIDENCE_WEIGHTED_POOLING |
|
from classes.fc4.squeezenet.SqueezeNetLoader import SqueezeNetLoader |
|
|
|
""" |
|
FC4: Fully Convolutional Color Constancy with Confidence-weighted Pooling |
|
* Original code: https://github.com/yuanming-hu/fc4 |
|
* Paper: https://www.microsoft.com/en-us/research/publication/fully-convolutional-color-constancy-confidence-weighted-pooling/ |
|
""" |
|
|
|
|
|
class FC4(torch.nn.Module):
    """
    FC4 illuminant estimation network: a truncated SqueezeNet feature extractor followed by a small
    convolutional head whose per-location outputs are pooled into a single normalized RGB estimate
    """

    def __init__(self, squeezenet_version: float = 1.1):
        """
        Build the feature extractor and the convolutional head
        @param squeezenet_version: the SqueezeNet version handed over to SqueezeNetLoader
        """
        super().__init__()

        # Backbone: the first 12 layers of the first child module of the SqueezeNet loaded with pretrained=True
        pretrained_net = SqueezeNetLoader(squeezenet_version).load(pretrained=True)
        feature_layers = list(pretrained_net.children())[0]
        self.backbone = nn.Sequential(*feature_layers[:12])

        # The head emits 3 colour channels, plus a 4th channel used as confidence when
        # confidence-weighted pooling is enabled in the settings
        head_out_channels = 4 if USE_CONFIDENCE_WEIGHTED_POOLING else 3
        head_layers = [
            nn.MaxPool2d(kernel_size=2, stride=1, ceil_mode=True),
            nn.Conv2d(512, 64, kernel_size=6, stride=1, padding=3),
            nn.ReLU(inplace=True),
            nn.Dropout(p=0.5),
            nn.Conv2d(64, head_out_channels, kernel_size=1, stride=1),
            nn.ReLU(inplace=True),
        ]
        self.final_convs = nn.Sequential(*head_layers)

    @staticmethod
    def _sum_pool(maps: Tensor) -> Tensor:
        """
        Collapse the two dimensions that follow the channel dimension by summation, then normalize along dim 1
        @param maps: the feature maps to be pooled
        @return: the pooled and normalized estimate
        """
        # Summing over dim 2 twice removes the original dims 2 and 3 one after the other
        return normalize(maps.sum(dim=2).sum(dim=2), dim=1)

    def forward(self, x: Tensor) -> Union[tuple, Tensor]:
        """
        Predict the RGB colour of the illuminant of the given image
        @param x: the image whose illuminant colour must be estimated
        @return: the pooled colour estimate as a Tensor. When confidence-weighted pooling is enabled, a tuple
        (estimate, per-location colour estimates, confidence weights) is returned instead, the last two items
        being meant for visualizations
        """
        features = self.backbone(x)
        out = self.final_convs(features)

        # Plain summation pooling over all the output channels
        if not USE_CONFIDENCE_WEIGHTED_POOLING:
            return self._sum_pool(out)

        # Channels 0-2 hold the per-location colour estimates, normalized along the channel dimension
        rgb = normalize(out[:, :3], dim=1)

        # Channel 3 holds the confidence; the 3:4 slice keeps the channel dimension so it broadcasts against rgb
        confidence = out[:, 3:4]

        # Confidence-weighted pooling: weight each local estimate by its confidence before pooling
        pred = self._sum_pool(rgb * confidence)

        return pred, rgb, confidence
|
|