import torch
import torch.nn as nn
from transformers import PreTrainedModel, ResNetBackbone

from .configuration_conditional_unet import ConditionalUNetConfig


class UpSampleBlock(nn.Module):
    """Decoder block: upsample, concatenate the skip connection and the condition, then refine."""

    def __init__(self, in_channels, skip_channels, out_channels, condition_size):
        super().__init__()
        self.up = nn.Upsample(scale_factor=2, mode='nearest')
        self.conv = nn.Sequential(
            nn.Conv2d(in_channels + skip_channels + condition_size, out_channels, kernel_size=3, padding=1),
            nn.BatchNorm2d(out_channels),
            nn.ReLU(inplace=True),
            nn.Conv2d(out_channels, out_channels, kernel_size=3, padding=1),
            nn.BatchNorm2d(out_channels),
            nn.ReLU(inplace=True),
        )

    def forward(self, x, skip, condition, upsample=True):
        if upsample:
            x = self.up(x)
        b, _, h, w = x.size()
        # Expand condition to match spatial dimensions
        condition = condition.view(b, -1, 1, 1).expand(-1, -1, h, w)
        x = torch.cat([x, skip, condition], dim=1)
        x = self.conv(x)
        return x


class ConditionalUNet(PreTrainedModel):
    config_class = ConditionalUNetConfig

    def __init__(self, config):
        super().__init__(config)
        self.config = config

        # Frozen, pretrained ResNet backbone used only as a feature extractor
        self.encoder_rep = config.encoder_rep
        self.encoder = ResNetBackbone.from_pretrained(
            self.encoder_rep,
            return_dict=False,
            output_hidden_states=True,
        )
        self.encoder.eval()
        self.encoder.requires_grad_(False)

        # Mirror the encoder's label/channel counts into this model's config
        self.num_labels = self.encoder.config.num_labels
        self.num_channels = self.encoder.config.num_channels
        self.config.num_labels = self.num_labels
        self.config.num_channels = self.num_channels

        hidden_sizes = self.encoder.config.hidden_sizes
        embedding_size = self.encoder.config.embedding_size

        # Build decoder blocks from the deepest encoder stage back towards the stem;
        # each block's skip channels come from the next-shallower encoder output.
        self.up_blocks = nn.ModuleList()
        num_stages = len(hidden_sizes)
        in_channels = hidden_sizes[-1]
        for i in range(num_stages - 1, -1, -1):
            skip_channels = hidden_sizes[i - 1] if i > 0 else embedding_size
            out_channels = skip_channels
            self.up_blocks.append(
                UpSampleBlock(
                    in_channels=in_channels,
                    skip_channels=skip_channels,
                    out_channels=out_channels,
                    condition_size=self.num_labels,
                )
            )
            in_channels = out_channels

        self.final_conv = nn.Sequential(
            nn.Conv2d(in_channels + self.num_labels, in_channels, kernel_size=3, padding=1),
            nn.BatchNorm2d(in_channels),
            nn.ReLU(inplace=True),
            nn.Conv2d(in_channels, self.num_channels, kernel_size=1),
        )

    def forward(self, x, condition):
        # With return_dict=False and output_hidden_states=True, the last element of the
        # backbone output is the tuple of hidden states (embedding output + each stage).
        outputs = self.encoder(x)[-1]
        x_stages = outputs[::-1]  # deepest stage first

        x = x_stages[0]
        for i, up_block in enumerate(self.up_blocks):
            skip = x_stages[i + 1] if i + 1 < len(x_stages) else None
            # The last block's skip (the stem embedding) is already at the same resolution,
            # so it is not upsampled.
            upsample = i < len(self.up_blocks) - 1
            if skip is not None:
                x = up_block(x, skip, condition, upsample=upsample)
            else:
                x = up_block(x, torch.zeros_like(x), condition, upsample=upsample)

        # Recover the input resolution (the ResNet stem downsamples by a factor of 4)
        x_upsampled = nn.functional.interpolate(x, scale_factor=4, mode='bilinear', align_corners=False)

        b, _, h, w = x_upsampled.size()
        condition_expanded = condition.view(b, -1, 1, 1).expand(-1, -1, h, w)
        final_input = torch.cat([x_upsampled, condition_expanded], dim=1)

        output = self.final_conv(final_input)
        return output
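

# --- Usage sketch (assumptions flagged) ---
# A minimal example of how this model might be instantiated and called, runnable only when
# the module is importable as part of its package. It assumes ConditionalUNetConfig accepts
# an `encoder_rep` keyword naming a ResNet checkpoint (the public "microsoft/resnet-50" is
# used here purely for illustration), and that `condition` is a (batch, num_labels) tensor,
# e.g. a one-hot label vector.
if __name__ == "__main__":
    config = ConditionalUNetConfig(encoder_rep="microsoft/resnet-50")  # assumed constructor signature
    model = ConditionalUNet(config)
    model.eval()

    images = torch.randn(2, 3, 224, 224)             # RGB inputs at the encoder's usual resolution
    condition = torch.zeros(2, model.num_labels)     # conditioning vector, one row per sample
    condition[:, 0] = 1.0                            # e.g. one-hot on the first label

    with torch.no_grad():
        out = model(images, condition)               # expected shape: (2, num_channels, 224, 224)
    print(out.shape)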