import torch
import torch.nn as nn
from transformers import PreTrainedModel, ResNetBackbone
from .configuration_conditional_unet import ConditionalUNetConfig

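# Decoder block: optionally upsamples the incoming feature map by 2x (nearest neighbor),
# concatenates it with an encoder skip connection and a spatially broadcast condition
# vector, then refines the result with two Conv-BatchNorm-ReLU layers.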
class UpSampleBlock(nn.Module):
    def __init__(self, in_channels, skip_channels, out_channels, condition_size):
        super().__init__()
        self.up = nn.Upsample(scale_factor=2, mode='nearest')
        self.conv = nn.Sequential(
            nn.Conv2d(in_channels + skip_channels + condition_size, out_channels, kernel_size=3, padding=1),
            nn.BatchNorm2d(out_channels),
            nn.ReLU(inplace=True),
            nn.Conv2d(out_channels, out_channels, kernel_size=3, padding=1),
            nn.BatchNorm2d(out_channels),
            nn.ReLU(inplace=True)
        )

    def forward(self, x, skip, condition, upsample=True):
        if upsample:
            x = self.up(x)
        b, _, h, w = x.size()
        # Expand condition to match spatial dimensions
        condition = condition.view(b, -1, 1, 1).expand(-1, -1, h, w)
        x = torch.cat([x, skip, condition], dim=1)
        x = self.conv(x)
        return x

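# Conditional U-Net that decodes the hidden states of a frozen, pretrained ResNet
# backbone from transformers. A per-sample condition vector of size num_labels is
# broadcast over the spatial dimensions and injected at every decoder scale, and once
# more before the final projection back to num_channels.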
class ConditionalUNet(PreTrainedModel):
    config_class = ConditionalUNetConfig

    def __init__(self, config):
        super().__init__(config)
        self.config = config

        self.encoder_rep = config.encoder_rep
        self.encoder = ResNetBackbone.from_pretrained(
            self.encoder_rep,
            return_dict=False,
            output_hidden_states=True
        )
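        # The backbone is a frozen feature extractor; only the decoder weights remain trainable.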
        self.encoder.eval()
        self.encoder.requires_grad_(False)

        self.num_labels = self.encoder.config.num_labels
        self.num_channels = self.encoder.config.num_channels

        self.config.num_labels = self.num_labels
        self.config.num_channels = self.num_channels

        hidden_sizes = self.encoder.config.hidden_sizes
        embedding_size = self.encoder.config.embedding_size

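        # Build one decoder block per encoder stage, walking from the deepest stage back
        # toward the stem: each block's skip comes from the next shallower hidden state,
        # with the stem embedding serving as the skip for the last block.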
        self.up_blocks = nn.ModuleList()
        num_stages = len(hidden_sizes)

        in_channels = hidden_sizes[-1]
        for i in range(num_stages - 1, -1, -1):
            skip_channels = hidden_sizes[i - 1] if i > 0 else embedding_size
            out_channels = skip_channels
            self.up_blocks.append(
                UpSampleBlock(
                    in_channels=in_channels,
                    skip_channels=skip_channels,
                    out_channels=out_channels,
                    condition_size=self.num_labels
                )
            )
            in_channels = out_channels

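        # After decoding, concatenate the condition once more and project down to the
        # original number of input channels.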
        self.final_conv = nn.Sequential(
            nn.Conv2d(in_channels + self.num_labels, in_channels, kernel_size=3, padding=1),
            nn.BatchNorm2d(in_channels),
            nn.ReLU(inplace=True),
            nn.Conv2d(in_channels, self.num_channels, kernel_size=1)
        )

    def forward(self, x, condition):
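        # With return_dict=False and output_hidden_states=True, the last element of the
        # encoder's output tuple is the hidden-state tuple: the stem (embedding) output
        # followed by one feature map per stage, shallow to deep. Reversing it puts the
        # deepest feature map first.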
        outputs = self.encoder(x)[-1]
        x_stages = outputs[::-1]
        x = x_stages[0]

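        # Decode from the deepest feature map upward; each block takes the next shallower
        # hidden state as its skip. The last block pairs with the stem output, which sits
        # at the same stride, so it skips the 2x upsample. Because the hidden states always
        # outnumber the decoder blocks by one, the zeros fallback below is never reached.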
        for i, up_block in enumerate(self.up_blocks):
            skip = x_stages[i + 1] if i + 1 < len(x_stages) else None
            upsample = i < len(self.up_blocks) - 1
            if skip is not None:
                x = up_block(x, skip, condition, upsample=upsample)
            else:
                x = up_block(x, torch.zeros_like(x), condition, upsample=upsample)

        x_upsampled = nn.functional.interpolate(x, scale_factor=4, mode='bilinear', align_corners=False)
        b, _, h, w = x_upsampled.size()
        condition_expanded = condition.view(b, -1, 1, 1).expand(-1, -1, h, w)
        final_input = torch.cat([x_upsampled, condition_expanded], dim=1)
        output = self.final_conv(final_input)

        return output
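
# --- Illustrative usage (not part of the original file) ---
# Minimal sketch, assuming ConditionalUNetConfig accepts the backbone checkpoint as an
# `encoder_rep` keyword (e.g. "microsoft/resnet-50"); the actual config signature may
# differ. Since this module uses a relative import, the sketch is left as a comment
# rather than a runnable __main__ block. Shapes assume a 224x224 RGB input.
#
#   config = ConditionalUNetConfig(encoder_rep="microsoft/resnet-50")  # assumed kwarg
#   model = ConditionalUNet(config)
#
#   x = torch.randn(1, 3, 224, 224)               # image batch
#   condition = torch.zeros(1, model.num_labels)  # one-hot (or soft) label vector
#   condition[0, 3] = 1.0
#
#   with torch.no_grad():
#       out = model(x, condition)                 # -> (1, config.num_channels, 224, 224)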