import collections.abc
from itertools import repeat
from pathlib import Path

import numpy as np
import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.nn.init as init
from diffusers.models.modeling_utils import ModelMixin
from einops import rearrange

from .embed_layers import PatchEmbed

def _ntuple(n):
"""
Creates a helper function to convert inputs to tuples of specified length.
Functionality:
- Converts iterable inputs (excluding strings) to tuples, ensuring length n
- Repeats single values n times to form a tuple
Useful for handling multi-dimensional parameters like kernel sizes and strides.
Args:
n (int): Target length of the tuple
Returns:
function: A parser function that converts inputs to n-length tuples
"""
def parse(x):
if isinstance(x, collections.abc.Iterable) and not isinstance(x, str):
x = tuple(x)
if len(x) == 1:
x = tuple(repeat(x[0], n))
return x
return tuple(repeat(x, n))
return parse
# Create common tuple conversion functions
to_1tuple = _ntuple(1)
to_2tuple = _ntuple(2)
to_3tuple = _ntuple(3)
to_4tuple = _ntuple(4)
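# Illustrative examples of the helpers above:
#   to_2tuple(3)      -> (3, 3)
#   to_2tuple((4, 5)) -> (4, 5)
#   to_2tuple([7])    -> (7, 7)   # single-element iterables are broadcast to length n
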
class CameraNet(ModelMixin):
"""
Camera state encoding network that processes camera parameters into feature embeddings.
This network converts camera state information into suitable feature representations
for video generation models through downsampling, convolutional encoding, and
temporal dimension compression. Supports loading from pretrained weights.
"""
def __init__(
self,
in_channels,
downscale_coef,
out_channels,
patch_size,
hidden_size,
):
super().__init__()
# Calculate initial channels: PixelUnshuffle moves spatial info to channel dimension
# resulting in channels = in_channels * (downscale_coef^2)
start_channels = in_channels * (downscale_coef ** 2)
input_channels = [start_channels, start_channels // 2, start_channels // 4]
self.input_channels = input_channels
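        # Worked example with the configuration used in the __main__ test below:
        # in_channels=6, downscale_coef=8 -> start_channels = 6 * 8**2 = 384,
        # giving the channel progression [384, 192, 96].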
self.unshuffle = nn.PixelUnshuffle(downscale_coef)
self.encode_first = nn.Sequential(
nn.Conv2d(input_channels[0], input_channels[1], kernel_size=1, stride=1, padding=0),
nn.GroupNorm(2, input_channels[1]),
nn.ReLU(),
)
self._initialize_weights(self.encode_first)
self.encode_second = nn.Sequential(
nn.Conv2d(input_channels[1], input_channels[2], kernel_size=1, stride=1, padding=0),
nn.GroupNorm(2, input_channels[2]),
nn.ReLU(),
)
self._initialize_weights(self.encode_second)
self.final_proj = nn.Conv2d(input_channels[2], out_channels, kernel_size=1)
self.zeros_init_linear(self.final_proj)
self.scale = nn.Parameter(torch.ones(1))
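        # final_proj is zero-initialized via zeros_init_linear, so the camera branch initially
        # outputs zero; the learnable `scale` then modulates its contribution during training.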
self.camera_in = PatchEmbed(patch_size=patch_size, in_chans=out_channels, embed_dim=hidden_size)
def zeros_init_linear(self, linear: nn.Module):
"""
Zero-initializes weights and biases of linear or convolutional layers.
Args:
linear (nn.Module): Linear or convolutional layer to initialize
"""
        if isinstance(linear, (nn.Linear, nn.Conv2d)):
            nn.init.zeros_(linear.weight)
            if linear.bias is not None:
                nn.init.zeros_(linear.bias)
def _initialize_weights(self, block):
"""
Initializes convolutional layer weights using He initialization,
with biases initialized to zero.
Args:
block (nn.Sequential): Sequential block containing convolutional layers
"""
for m in block:
if isinstance(m, nn.Conv2d):
n = m.kernel_size[0] * m.kernel_size[1] * m.in_channels
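                # He (Kaiming) normal initialization: std = sqrt(2 / fan_in)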
init.normal_(m.weight, mean=0.0, std=np.sqrt(2.0 / n))
if m.bias is not None:
init.zeros_(m.bias)
def compress_time(self, x, num_frames):
"""
Temporal dimension compression: reduces number of frames using average pooling
while preserving key temporal information.
Handling logic:
- Special frame counts (66 or 34): split into two segments, keep first frame of each
segment then pool remaining frames
- Odd frame counts: keep first frame, pool remaining frames
- Even frame counts: directly pool all frames
Args:
x (torch.Tensor): Input tensor with shape (b*f, c, h, w)
num_frames (int): Number of frames in temporal dimension
Returns:
torch.Tensor: Temporally compressed tensor with shape (b*f', c, h, w) where f' < f
"""
# Reshape: (b*f, c, h, w) -> (b, f, c, h, w)
x = rearrange(x, '(b f) c h w -> b f c h w', f=num_frames)
batch_size, frames, channels, height, width = x.shape
x = rearrange(x, 'b f c h w -> (b h w) c f')
# Handle special frame counts (66 or 34)
if x.shape[-1] == 66 or x.shape[-1] == 34:
x_len = x.shape[-1]
# Process first segment: keep first frame, pool remaining
x_clip1 = x[...,:x_len//2]
x_clip1_first, x_clip1_rest = x_clip1[..., 0].unsqueeze(-1), x_clip1[..., 1:]
x_clip1_rest = F.avg_pool1d(x_clip1_rest, kernel_size=2, stride=2)
# Process second segment: keep first frame, pool remaining
x_clip2 = x[...,x_len//2:x_len]
x_clip2_first, x_clip2_rest = x_clip2[..., 0].unsqueeze(-1), x_clip2[..., 1:]
x_clip2_rest = F.avg_pool1d(x_clip2_rest, kernel_size=2, stride=2)
# Concatenate results from both segments
x = torch.cat([x_clip1_first, x_clip1_rest, x_clip2_first, x_clip2_rest], dim=-1)
elif x.shape[-1] % 2 == 1:
x_first, x_rest = x[..., 0], x[..., 1:]
if x_rest.shape[-1] > 0:
x_rest = F.avg_pool1d(x_rest, kernel_size=2, stride=2)
x = torch.cat([x_first[..., None], x_rest], dim=-1)
else:
x = F.avg_pool1d(x, kernel_size=2, stride=2)
x = rearrange(x, '(b h w) c f -> (b f) c h w', b=batch_size, h=height, w=width)
return x
def forward(
self,
camera_states: torch.Tensor,
):
"""
Forward pass: encodes camera states into feature embeddings.
Args:
camera_states (torch.Tensor): Camera state tensor with dimensions
(batch, frames, channels, height, width)
Returns:
torch.Tensor: Encoded feature embeddings after patch embedding and scaling
"""
batch_size, num_frames, channels, height, width = camera_states.shape
camera_states = rearrange(camera_states, 'b f c h w -> (b f) c h w')
camera_states = self.unshuffle(camera_states)
camera_states = self.encode_first(camera_states)
camera_states = self.compress_time(camera_states, num_frames=num_frames)
num_frames = camera_states.shape[0] // batch_size
camera_states = self.encode_second(camera_states)
camera_states = self.compress_time(camera_states, num_frames=num_frames)
camera_states = self.final_proj(camera_states)
camera_states = rearrange(camera_states, "(b f) c h w -> b c f h w", b=batch_size)
camera_states = self.camera_in(camera_states)
return camera_states * self.scale
    @classmethod
    def from_pretrained(cls, pretrained_model_path):
        """
        Loads the model from a pretrained weight file.
        Args:
            pretrained_model_path (str): Path to the pretrained weight file
        Returns:
            CameraNet: Model instance with the pretrained weights loaded
        """
        if not Path(pretrained_model_path).exists():
            raise FileNotFoundError(f"There is no model file at {pretrained_model_path}")
        state_dict = torch.load(pretrained_model_path, map_location="cpu")
        # Note: patch_size and hidden_size below mirror the test configuration at the end of
        # this file; adjust them if the checkpoint was trained with a different configuration.
        model = cls(
            in_channels=6,
            downscale_coef=8,
            out_channels=16,
            patch_size=[1, 2, 2],
            hidden_size=3072,
        )
        model.load_state_dict(state_dict, strict=True)
        print(f"Loaded CameraNet pretrained weights from {pretrained_model_path}.")
        return model
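
# Illustrative usage (the checkpoint path below is hypothetical):
#   camera_net = CameraNet.from_pretrained("checkpoints/camera_net.pt")
#   embeddings = camera_net(camera_states)  # camera_states: (b, f, 6, H, W)
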
if __name__ == "__main__":
# Test model initialization and forward pass
model = CameraNet(
in_channels=6,
downscale_coef=8,
out_channels=16,
patch_size=[1,2,2],
hidden_size=3072
)
print("Model structure:")
print(model)
# Generate test input (batch 1, 33 frames, 6 channels, 704x1280 resolution)
num_frames = 33
input_tensor = torch.randn(1, num_frames, 6, 704, 1280)
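    # With this configuration (downscale_coef=8), PixelUnshuffle produces (33, 384, 88, 160),
    # and compress_time reduces 33 -> 17 -> 9 frames before patch embedding.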
# Forward pass
output_tensor = model(input_tensor)
# Print results
print(f"Output shape: {output_tensor.shape}") # Expected: torch.Size([1, ...])
print("Output tensor example:")
print(output_tensor)