# LEAR / modeling_conditional_unet.py
# Origin: Hugging Face Hub upload by user wltjr1007
# (commit "Upload ConditionalUNet", hash dd983c1, verified)
import torch
import torch.nn as nn
from transformers import PreTrainedModel, ResNetBackbone
from .configuration_conditional_unet import ConditionalUNetConfig
class UpSampleBlock(nn.Module):
    """Single U-Net decoder stage.

    Optionally doubles the spatial resolution of the incoming feature map,
    concatenates it with an encoder skip connection and a spatially
    broadcast condition vector, then refines the result with two
    Conv-BN-ReLU layers.
    """

    def __init__(self, in_channels, skip_channels, out_channels, condition_size):
        super().__init__()
        self.up = nn.Upsample(scale_factor=2, mode='nearest')
        # Channels after concatenating features, skip, and condition maps.
        fused_channels = in_channels + skip_channels + condition_size
        refine = [
            nn.Conv2d(fused_channels, out_channels, kernel_size=3, padding=1),
            nn.BatchNorm2d(out_channels),
            nn.ReLU(inplace=True),
            nn.Conv2d(out_channels, out_channels, kernel_size=3, padding=1),
            nn.BatchNorm2d(out_channels),
            nn.ReLU(inplace=True),
        ]
        self.conv = nn.Sequential(*refine)

    def forward(self, x, skip, condition, upsample=True):
        """Fuse `x` with `skip` and `condition`; return the refined map.

        `condition` is a per-sample vector; it is reshaped to (B, C, 1, 1)
        and broadcast over the spatial dimensions before concatenation.
        When `upsample` is False, `x` must already match `skip` spatially.
        """
        if upsample:
            x = self.up(x)
        batch, _, height, width = x.size()
        # Tile the condition vector so it can be concatenated channel-wise.
        cond_map = condition.view(batch, -1, 1, 1).expand(-1, -1, height, width)
        fused = torch.cat([x, skip, cond_map], dim=1)
        return self.conv(fused)
class ConditionalUNet(PreTrainedModel):
    """Conditional U-Net decoder over a frozen pretrained ResNet backbone.

    The encoder's per-stage hidden states serve as skip connections; a
    per-sample condition vector (of size ``num_labels``) is broadcast and
    concatenated at every decoder stage and before the final projection.
    Output has ``num_channels`` channels at 4x the deepest feature map's
    resolution after the decoder (via a final bilinear upsample).
    """

    config_class = ConditionalUNetConfig

    def __init__(self, config):
        super().__init__(config)
        self.config = config
        self.encoder_rep = config.encoder_rep
        # Pretrained backbone; hidden states of every stage are needed as skips.
        self.encoder = ResNetBackbone.from_pretrained(
            self.encoder_rep,
            return_dict=False,
            output_hidden_states=True
        )
        # NOTE(review): a later model.train() call re-enables train mode on
        # the encoder, so BatchNorm running stats may still update even
        # though requires_grad_(False) freezes the weights — confirm intended.
        self.encoder.eval()
        self.encoder.requires_grad_(False)
        # Mirror encoder metadata onto this model's config.
        self.num_labels = self.encoder.config.num_labels
        self.num_channels = self.encoder.config.num_channels
        self.config.num_labels = self.num_labels
        self.config.num_channels = self.num_channels
        hidden_sizes = self.encoder.config.hidden_sizes
        embedding_size = self.encoder.config.embedding_size
        self.up_blocks = nn.ModuleList()
        # Expected skip channel count per decoder block, used to build a
        # correctly-shaped zero skip when the encoder provides none.
        self.skip_dims = []
        num_stages = len(hidden_sizes)
        in_channels = hidden_sizes[-1]
        # Build decoder blocks from the deepest stage up to the stem embedding.
        for i in range(num_stages - 1, -1, -1):
            skip_channels = hidden_sizes[i - 1] if i > 0 else embedding_size
            out_channels = skip_channels
            self.skip_dims.append(skip_channels)
            self.up_blocks.append(
                UpSampleBlock(
                    in_channels=in_channels,
                    skip_channels=skip_channels,
                    out_channels=out_channels,
                    condition_size=self.num_labels
                )
            )
            in_channels = out_channels
        self.final_conv = nn.Sequential(
            nn.Conv2d(in_channels + self.num_labels, in_channels, kernel_size=3, padding=1),
            nn.BatchNorm2d(in_channels),
            nn.ReLU(inplace=True),
            nn.Conv2d(in_channels, self.num_channels, kernel_size=1)
        )

    def forward(self, x, condition):
        """Decode `x` through the frozen encoder and conditional up blocks.

        Args:
            x: input image batch for the encoder.
            condition: per-sample condition vector, size ``num_labels``.

        Returns:
            Tensor with ``num_channels`` channels, upsampled 4x past the
            decoder output resolution.
        """
        # With return_dict=False and output_hidden_states=True, the last
        # element of the encoder output tuple is the hidden-states tuple
        # (embedding output followed by each stage's output).
        outputs = self.encoder(x)[-1]
        x_stages = outputs[::-1]  # deepest first
        x = x_stages[0]
        for i, up_block in enumerate(self.up_blocks):
            skip = x_stages[i + 1] if i + 1 < len(x_stages) else None
            upsample = i < len(self.up_blocks) - 1
            if skip is None:
                # FIX: the previous fallback used torch.zeros_like(x), whose
                # channel count (and pre-upsample spatial size) did not match
                # what the block's conv expects, crashing whenever this branch
                # ran with skip_channels != in_channels. Build the zero skip
                # with the recorded channel count and post-upsample size.
                b, _, h, w = x.size()
                if upsample:
                    h, w = 2 * h, 2 * w
                skip = x.new_zeros(b, self.skip_dims[i], h, w)
            x = up_block(x, skip, condition, upsample=upsample)
        # Recover the stem's 4x downsampling before the final projection.
        x_upsampled = nn.functional.interpolate(x, scale_factor=4, mode='bilinear', align_corners=False)
        b, _, h, w = x_upsampled.size()
        condition_expanded = condition.view(b, -1, 1, 1).expand(-1, -1, h, w)
        final_input = torch.cat([x_upsampled, condition_expanded], dim=1)
        output = self.final_conv(final_input)
        return output