In [None]:
import os
import requests
import zipfile
import io

def lesgooo(name_it, extract_dir='/kaggle/working/', chunk_size=1024 * 1024, timeout=60):
    """Download a zip archive over HTTP and extract it without writing the zip to disk.

    The archive is streamed into an in-memory buffer, so the whole zip must fit in
    RAM. Every member is then extracted into ``extract_dir``.

    Args:
        name_it: URL of the zip archive.
        extract_dir: destination directory for the extracted files
            (default keeps the original Kaggle working directory).
        chunk_size: number of bytes requested per streamed chunk.
        timeout: seconds requests waits to connect and between received bytes
            before giving up, so a stalled server cannot hang the notebook forever.

    Raises:
        requests.HTTPError: the server answered with a 4xx/5xx status. Without this
            check an HTML error page would be buffered and only fail later as
            zipfile.BadZipFile, hiding the real cause.
        zipfile.BadZipFile: the downloaded bytes are not a valid zip archive.
    """
    # The context manager releases the connection even if streaming fails midway.
    with requests.get(name_it, stream=True, timeout=timeout) as response:
        response.raise_for_status()
        zip_buffer = io.BytesIO()
        for chunk in response.iter_content(chunk_size=chunk_size):
            if chunk:  # skip empty chunks
                zip_buffer.write(chunk)

    # Rewind so ZipFile reads the archive from its first byte.
    zip_buffer.seek(0)
    with zipfile.ZipFile(zip_buffer, 'r') as zip_ref:
        zip_ref.extractall(extract_dir)

def download_and_unzip(index):
    """Download and extract the image archive, then the label archive, for one dataset part.

    The part number is zero-padded to two characters (1 -> "01") to form the archive
    names ``NN_images.zip`` and ``NN_labels.zip`` on the TU Darmstadt server.

    Args:
        index: part number; its ``str()`` form is zero-padded, so ints work directly.
    """
    base_url = "https://download.visinf.tu-darmstadt.de/data/from_games/data/"
    part = str(index).zfill(2)
    for archive_kind in ("images", "labels"):
        lesgooo(f"{base_url}{part}_{archive_kind}.zip")

# Fetch dataset parts 1 through 10 (range's upper bound is exclusive), reporting each one.
for part_number in range(1, 11):
    download_and_unzip(part_number)
    print(f"Part {part_number} done")

In [11]:
import os
from PIL import Image
from torch.utils.data import Dataset, DataLoader
from torchvision import transforms

class ImageLabelDataset(Dataset):
    """Paired (image, label) dataset read from two parallel directories.

    Both directories are listed and sorted by file name, and entries are paired by
    position: the i-th image goes with the i-th label. Only the counts are checked.
    NOTE(review): consider also verifying that paired file names match, to catch
    silent misalignment.

    Args:
        images_dir: directory containing the input images.
        labels_dir: directory containing the label images.
        transform: optional callable applied to the image. It is also applied to the
            label when ``label_transform`` is None, which is the original behaviour.
        label_transform: optional callable applied only to the label, e.g. a resize
            with nearest-neighbour interpolation and no normalization.
    """

    def __init__(self, images_dir, labels_dir, transform=None, label_transform=None):
        self.images_dir = images_dir
        self.labels_dir = labels_dir
        self.image_filenames = sorted(os.listdir(images_dir))
        self.label_filenames = sorted(os.listdir(labels_dir))
        self.transform = transform
        self.label_transform = label_transform

        assert len(self.image_filenames) == len(self.label_filenames), "Mismatch in number of images and labels"

    def __len__(self):
        return len(self.image_filenames)

    @staticmethod
    def _load_rgb(path):
        # Open inside a context manager so the file handle is closed right away
        # (important with many DataLoader workers). convert() loads the pixel data
        # into a new image before the file is closed.
        with Image.open(path) as img:
            return img.convert('RGB')

    def __getitem__(self, idx):
        image = self._load_rgb(os.path.join(self.images_dir, self.image_filenames[idx]))
        # Labels are also forced to RGB (switch to 'L' for single-channel labels).
        label = self._load_rgb(os.path.join(self.labels_dir, self.label_filenames[idx]))

        # Fall back to the image transform so existing callers keep identical behaviour.
        label_transform = self.label_transform if self.label_transform is not None else self.transform

        if self.transform:
            image = self.transform(image)
        if label_transform:
            label = label_transform(label)

        return image, label

resize_dim = 512  # images and labels are resized to resize_dim x resize_dim
batch_size = 12  # adjust according to your GPU memory capacity

# Resize -> float tensor in [0, 1] -> per-channel (x - 0.5) / 0.5, i.e. values in [-1, 1]
# (the same range as the model's tanh output).
# NOTE(review): this same transform is applied to the label images too. torchvision's
# Resize interpolates bilinearly by default, so if the labels are colour-coded class
# maps, region boundaries will get blended colours — confirm this is acceptable.
transform = transforms.Compose([
    transforms.Resize((resize_dim, resize_dim)),
    transforms.ToTensor(),
    transforms.Normalize(mean=[0.5, 0.5, 0.5], std=[0.5, 0.5, 0.5]),
])

# Training dataset: files in the two directories are paired by sorted file name.
images_dir = '/kaggle/working/images/'  # Replace with your image directory path
labels_dir = '/kaggle/working/labels/'  # Replace with your label directory path
dataset = ImageLabelDataset(images_dir, labels_dir, transform=transform)

# Never request more loader workers than there are CPUs. Extra workers only add
# process overhead, and PyTorch's DataLoader warns when num_workers exceeds the
# suggested maximum for the machine.
num_workers = min(8, os.cpu_count() or 1)
dataloader = DataLoader(dataset, batch_size=batch_size, shuffle=True, num_workers=num_workers)

In [2]:
import torch
import torch.nn as nn
import torch.nn.functional as F

class DepthwiseSeparableConv(nn.Module):
    """Factorized convolution: a per-channel spatial conv followed by a 1x1 channel mixer.

    Args:
        in_channels: channels of the input tensor.
        out_channels: channels produced by the pointwise (1x1) stage.
        kernel_size, stride, padding: forwarded to the depthwise nn.Conv2d.
    """

    def __init__(self, in_channels, out_channels, kernel_size, stride=1, padding=0):
        super().__init__()
        # groups=in_channels gives every input channel its own spatial filter.
        self.depthwise = nn.Conv2d(
            in_channels=in_channels,
            out_channels=in_channels,
            kernel_size=kernel_size,
            stride=stride,
            padding=padding,
            groups=in_channels,
        )
        # A 1x1 convolution then mixes information across channels.
        self.pointwise = nn.Conv2d(in_channels=in_channels, out_channels=out_channels, kernel_size=1)

    def forward(self, x):
        return self.pointwise(self.depthwise(x))

class LinearAttention(nn.Module):
    """Multi-head attention over spatial positions whose cost is linear in H*W.

    The keys are softmax-normalised over positions and combined with the values into
    one (dim_head x dim_head) context matrix per head. That matrix is then applied to
    the scaled queries, so no (H*W x H*W) attention map is ever formed.
    NOTE(review): some linear-attention formulations also softmax the queries over the
    feature axis; here they are only scaled.

    Args:
        dim: input/output channel count.
        heads: number of attention heads.
        dim_head: channels per head.
    """

    def __init__(self, dim, heads=4, dim_head=64):
        super().__init__()
        self.heads = heads
        self.scale = dim_head ** -0.5
        hidden = heads * dim_head
        self.to_qkv = nn.Conv2d(dim, 3 * hidden, kernel_size=1, bias=False)
        self.to_out = nn.Conv2d(hidden, dim, kernel_size=1)

    def forward(self, x):
        batch, _, height, width = x.shape
        positions = height * width
        # Split the 1x1 projection into q/k/v, each (batch, heads, dim_head, positions).
        queries, keys, values = (
            part.reshape(batch, self.heads, -1, positions)
            for part in self.to_qkv(x).chunk(3, dim=1)
        )
        queries = queries * self.scale
        keys = keys.softmax(dim=-1)
        # context[d, e] = sum_n keys[d, n] * values[e, n]
        context = keys @ values.transpose(-2, -1)
        # attended[e, n] = sum_d context[d, e] * queries[d, n]
        attended = context.transpose(-2, -1) @ queries
        return self.to_out(attended.reshape(batch, -1, height, width))

class ResidualBlock(nn.Module):
    """Two 3x3 depthwise-separable convs with instance norm around an identity skip.

    Computes ReLU(x + IN2(conv2(ReLU(IN1(conv1(x)))))). Stride 1 and padding 1
    keep the spatial size, and the channel count is unchanged.
    """

    def __init__(self, channels):
        super().__init__()
        self.conv1 = DepthwiseSeparableConv(channels, channels, kernel_size=3, padding=1)
        self.in1 = nn.InstanceNorm2d(channels)
        self.conv2 = DepthwiseSeparableConv(channels, channels, kernel_size=3, padding=1)
        self.in2 = nn.InstanceNorm2d(channels)

    def forward(self, x):
        hidden = F.relu(self.in1(self.conv1(x)))
        hidden = self.in2(self.conv2(hidden))
        # Add the skip before the final activation.
        return F.relu(hidden + x)

class EfficientRealformer(nn.Module):
    """Encoder -> linear-attention bottleneck -> decoder image translator with style injection.

    Pipeline:
        - The content image is downsampled 4x by the encoder.
        - It then passes through ``num_residuals`` (LinearAttention -> ResidualBlock)
          stages. After each stage, a per-channel vector computed from the ``style``
          image is added.
        - The result is decoded back up with two stride-2 transposed convolutions.
        - A final tanh puts the output in [-1, 1].

    NOTE(review): the decoder upsamples by exactly 4x, so the output only matches the
    content size when H and W are divisible by 4 — confirm inputs are resized to fit.
    NOTE(review): the attention output replaces the features (there is no skip around
    LinearAttention); the following ResidualBlock provides the only identity path.

    Args:
        input_channels: channels of the content and style images.
        output_channels: channels of the generated image.
        base_channels: width of the first encoder stage; the bottleneck uses 4x this.
        num_residuals: number of attention + residual stages.
    """

    def __init__(self, input_channels=3, output_channels=3, base_channels=64, num_residuals=6):
        super().__init__()
        c1, c2, c4 = base_channels, base_channels * 2, base_channels * 4

        # Encoder: full-resolution 7x7 stem, then two stride-2 stages (H/4 x W/4, c4 channels).
        self.encoder = nn.Sequential(
            *self._conv_norm_relu(input_channels, c1, 7, stride=1, padding=3),
            *self._conv_norm_relu(c1, c2, 3, stride=2, padding=1),
            *self._conv_norm_relu(c2, c4, 3, stride=2, padding=1),
        )

        # Bottleneck: each stage is attention followed by a residual conv block.
        self.transformer_blocks = nn.ModuleList(
            [nn.Sequential(LinearAttention(c4), ResidualBlock(c4)) for _ in range(num_residuals)]
        )

        # Decoder: two stride-2 transposed convs restore the resolution, then a 7x7 head.
        self.decoder = nn.Sequential(
            nn.ConvTranspose2d(c4, c2, 3, stride=2, padding=1, output_padding=1),
            nn.InstanceNorm2d(c2),
            nn.ReLU(inplace=True),
            nn.ConvTranspose2d(c2, c1, 3, stride=2, padding=1, output_padding=1),
            nn.InstanceNorm2d(c1),
            nn.ReLU(inplace=True),
            DepthwiseSeparableConv(c1, output_channels, 7, padding=3),
        )

        # Style encoder: three stride-2 convs, global average pool, linear -> (batch, c4).
        self.style_encoder = nn.Sequential(
            *self._conv_norm_relu(input_channels, c1, 3, stride=2, padding=1),
            *self._conv_norm_relu(c1, c2, 3, stride=2, padding=1),
            DepthwiseSeparableConv(c2, c4, 3, stride=2, padding=1),
            nn.AdaptiveAvgPool2d(1),
            nn.Flatten(),
            nn.Linear(c4, c4),
        )

    @staticmethod
    def _conv_norm_relu(in_channels, out_channels, kernel_size, stride, padding):
        # The three layers are unpacked as separate Sequential entries, which keeps
        # the module indices (and therefore the state_dict keys) unchanged.
        return [
            DepthwiseSeparableConv(in_channels, out_channels, kernel_size, stride=stride, padding=padding),
            nn.InstanceNorm2d(out_channels),
            nn.ReLU(inplace=True),
        ]

    def forward(self, content, style):
        features = self.encoder(content)
        # (batch, c4) -> (batch, c4, 1, 1) so the style vector broadcasts over all positions.
        style_bias = self.style_encoder(style)[:, :, None, None]
        for block in self.transformer_blocks:
            features = block(features) + style_bias
        return torch.tanh(self.decoder(features))

In [3]:
def total_variation_loss(x, reduction='sum'):
    """Anisotropic total variation: absolute differences between neighbouring pixels.

    Args:
        x: image batch. The indexing assumes a 4-D tensor whose last two axes are
            height and width (e.g. (N, C, H, W)).
        reduction:
            - 'sum' (default, the original behaviour) adds every absolute difference
              over the whole batch, so the value grows with batch size and resolution.
            - 'mean' averages the vertical and horizontal differences separately and
              adds the two means, which makes the value independent of those sizes.

    Returns:
        Scalar tensor.

    Raises:
        ValueError: ``reduction`` is neither 'sum' nor 'mean'.
    """
    vertical = torch.abs(x[:, :, :-1, :] - x[:, :, 1:, :])
    horizontal = torch.abs(x[:, :, :, :-1] - x[:, :, :, 1:])
    if reduction == 'sum':
        return vertical.sum() + horizontal.sum()
    if reduction == 'mean':
        def _mean_or_zero(diffs):
            # A size-1 spatial axis yields an empty tensor; count it as 0 rather than NaN.
            return diffs.mean() if diffs.numel() > 0 else diffs.sum()
        return _mean_or_zero(vertical) + _mean_or_zero(horizontal)
    raise ValueError(f"reduction must be 'sum' or 'mean', got {reduction!r}")

def combined_loss(output, target, tv_weight=0.0001):
    """L1 reconstruction loss plus a weighted total-variation smoothness penalty.

    Args:
        output: predicted images.
        target: ground-truth images with the same shape as ``output``.
        tv_weight: multiplier for the TV term (default 1e-4, the original constant).

    Returns:
        Scalar tensor: mean absolute error + tv_weight * total_variation_loss(output).

    NOTE(review): total_variation_loss sums over every pixel and every image in the
    batch, while the L1 term is a mean. The TV term's relative weight therefore grows
    with batch size and resolution; retune tv_weight if either changes.
    """
    # Functional form avoids constructing a new nn.L1Loss module on every call.
    l1_loss = F.l1_loss(output, target)
    tv_loss = total_variation_loss(output)
    return l1_loss + tv_weight * tv_loss

def psnr(img1, img2, max_val=1.0):
    """Peak signal-to-noise ratio between two tensors, in dB.

    PSNR = 10 * log10(max_val**2 / MSE), where the MSE is averaged over every element.
    A whole batch therefore yields a single value.

    Args:
        img1, img2: tensors of the same (or broadcastable) shape.
        max_val: peak value of the signal range. The default 1.0 fits data in [0, 1].
            For data normalized to [-1, 1], pass 2.0 or rescale the inputs to [0, 1].

    Returns:
        0-dim tensor, always; identical inputs give +inf as a tensor, so callers can
        uniformly use ``.item()``.
    """
    mse = torch.mean((img1 - img2) ** 2)
    # When mse == 0 the division gives inf and log10(inf) == inf, so no special case
    # is needed and the result stays a tensor on the inputs' device.
    return 10 * torch.log10(max_val ** 2 / mse)

In [None]:
import torch
import torch.nn as nn
from torch.utils.data import DataLoader
from tqdm import tqdm
import os

# Seed the global torch RNG so weight init and DataLoader shuffling are reproducible.
seed = 42
torch.manual_seed(seed)

# Instantiate the model
model = EfficientRealformer(input_channels=3, output_channels=3, base_channels=64, num_residuals=6)

# Move model to the appropriate device (GPU if available)
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

# Split each batch across all visible GPUs only when there is more than one.
# Hard-coding device_ids=[0, 1] fails on CPU-only or single-GPU machines.
if torch.cuda.device_count() > 1:
    model = nn.DataParallel(model)
model = model.to(device)

# Unwrapped module. Saving its state_dict avoids the "module." key prefix that
# DataParallel adds, so checkpoints load directly into a plain EfficientRealformer.
base_model = model.module if isinstance(model, nn.DataParallel) else model

optimizer = torch.optim.Adam(model.parameters(), lr=0.0001)
num_epochs = 20

# Path to save the best model
save_dir = 'saved_models'
os.makedirs(save_dir, exist_ok=True)  # Create directory if it doesn't exist
best_model_path = os.path.join(save_dir, 'best_model.pth')
best_loss = float('inf')

for epoch in range(num_epochs):
    model.train()
    running_loss = 0.0
    running_psnr = 0.0

    pbar = tqdm(dataloader, desc=f"Epoch {epoch+1}/{num_epochs}")
    # `inputs` rather than `input`, to avoid shadowing the builtin for the rest of the notebook.
    for inputs, target in pbar:
        inputs, target = inputs.to(device), target.to(device)

        optimizer.zero_grad()
        # NOTE(review): the ground-truth target is also passed as the style reference,
        # so the model sees a pooled summary of the answer during training but will
        # not have it at inference time — confirm this is intended.
        output = model(inputs, target)
        loss = combined_loss(output, target)
        loss.backward()
        optimizer.step()

        batch_loss = loss.item()
        running_loss += batch_loss

        # Metric only, so no autograd graph is needed. The tensors are normalized to
        # [-1, 1]; mapping them to [0, 1] matches psnr's default peak value of 1.0.
        # float() accepts both a 0-dim tensor and a plain float result.
        with torch.no_grad():
            current_psnr = float(psnr((output + 1) / 2, (target + 1) / 2))
        running_psnr += current_psnr

        pbar.set_postfix({
            'loss': batch_loss,
            'psnr': current_psnr,
        })

    epoch_loss = running_loss / len(dataloader)
    avg_psnr = running_psnr / len(dataloader)

    print(f"Epoch [{epoch+1}/{num_epochs}], Loss: {epoch_loss:.4f}, PSNR: {avg_psnr:.4f}")

    # "Best" is judged on training loss; there is no validation split in this notebook.
    if epoch_loss < best_loss:
        best_loss = epoch_loss
        torch.save(base_model.state_dict(), best_model_path)
        print(f"New best model saved at epoch {epoch+1} with loss {best_loss:.4f}")

print("Training complete")

In [10]:
import gc

# Free GPU memory between experiments. First collect unreachable Python objects, so
# any CUDA tensors they hold go back to PyTorch's caching allocator. Only then can
# empty_cache() release those cached blocks; calling it first would miss them.
n_collected = gc.collect()
torch.cuda.empty_cache()
# NOTE(review): an earlier run hit
#   RuntimeError: The size of tensor a (768) must match the size of tensor b (512) at non-singleton dimension 3
# which suggests an input whose width was not resized to 512 — confirm the transform.
n_collected

0

In [None]:
# Model size: element count of every parameter tensor (trainable or not), then the module tree.
param_sizes = [weight.numel() for weight in model.parameters()]
total_params = sum(param_sizes)
print(total_params)
print(model)