import collections.abc
from itertools import repeat
from pathlib import Path

import numpy as np
import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.nn.init as init
from diffusers.models.modeling_utils import ModelMixin
from einops import rearrange

from .embed_layers import PatchEmbed
def _ntuple(n):
    """
    Creates a helper function that converts inputs to tuples of a specified length.

    Functionality:
    - Converts iterable inputs (excluding strings) to tuples; single-element iterables
      are repeated to length n
    - Repeats scalar values n times to form an n-length tuple

    Useful for handling multi-dimensional parameters such as kernel sizes and strides.

    Args:
        n (int): Target length of the tuple

    Returns:
        function: A parser function that converts inputs to n-length tuples
    """
    def parse(x):
        if isinstance(x, collections.abc.Iterable) and not isinstance(x, str):
            x = tuple(x)
            if len(x) == 1:
                x = tuple(repeat(x[0], n))
            return x
        return tuple(repeat(x, n))
    return parse


# Create common tuple conversion functions
to_1tuple = _ntuple(1)
to_2tuple = _ntuple(2)
to_3tuple = _ntuple(3)
to_4tuple = _ntuple(4)
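
# Illustrative examples (derived from the parse logic above):
#   to_2tuple(3)      -> (3, 3)
#   to_2tuple([3, 5]) -> (3, 5)
#   to_2tuple([7])    -> (7, 7)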


class CameraNet(ModelMixin):
    """
    Camera state encoding network that processes camera parameters into feature embeddings.

    This network converts camera state information into feature representations suitable
    for video generation models through downsampling, convolutional encoding, and
    temporal dimension compression. Supports loading from pretrained weights.
    """

    def __init__(
        self,
        in_channels,
        downscale_coef,
        out_channels,
        patch_size,
        hidden_size,
    ):
        super().__init__()
        # Calculate initial channels: PixelUnshuffle moves spatial info to the channel
        # dimension, resulting in channels = in_channels * (downscale_coef ** 2)
        start_channels = in_channels * (downscale_coef ** 2)
        input_channels = [start_channels, start_channels // 2, start_channels // 4]
        self.input_channels = input_channels

        self.unshuffle = nn.PixelUnshuffle(downscale_coef)

        self.encode_first = nn.Sequential(
            nn.Conv2d(input_channels[0], input_channels[1], kernel_size=1, stride=1, padding=0),
            nn.GroupNorm(2, input_channels[1]),
            nn.ReLU(),
        )
        self._initialize_weights(self.encode_first)

        self.encode_second = nn.Sequential(
            nn.Conv2d(input_channels[1], input_channels[2], kernel_size=1, stride=1, padding=0),
            nn.GroupNorm(2, input_channels[2]),
            nn.ReLU(),
        )
        self._initialize_weights(self.encode_second)

        self.final_proj = nn.Conv2d(input_channels[2], out_channels, kernel_size=1)
        self.zeros_init_linear(self.final_proj)

        self.scale = nn.Parameter(torch.ones(1))
        self.camera_in = PatchEmbed(patch_size=patch_size, in_chans=out_channels, embed_dim=hidden_size)
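
    # Worked example of the channel schedule (assuming in_channels=6 and downscale_coef=8,
    # as in the test block at the bottom of this file):
    #   start_channels = 6 * 8**2 = 384, input_channels = [384, 192, 96]
    #   PixelUnshuffle(8): (b*f, 6, h, w) -> (b*f, 384, h/8, w/8)
    #   encode_first: 384 -> 192, encode_second: 192 -> 96, final_proj: 96 -> out_channels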
    def zeros_init_linear(self, linear: nn.Module):
        """
        Zero-initializes the weights and biases of a linear or convolutional layer.

        Args:
            linear (nn.Module): Linear or convolutional layer to initialize
        """
        if isinstance(linear, (nn.Linear, nn.Conv2d)):
            if hasattr(linear, "weight"):
                nn.init.zeros_(linear.weight)
            if getattr(linear, "bias", None) is not None:
                nn.init.zeros_(linear.bias)
    def _initialize_weights(self, block):
        """
        Initializes convolutional layer weights using He initialization,
        with biases initialized to zero.

        Args:
            block (nn.Sequential): Sequential block containing convolutional layers
        """
        for m in block:
            if isinstance(m, nn.Conv2d):
                # He initialization: std = sqrt(2 / fan_in), fan_in = k_h * k_w * in_channels
                n = m.kernel_size[0] * m.kernel_size[1] * m.in_channels
                init.normal_(m.weight, mean=0.0, std=np.sqrt(2.0 / n))
                if m.bias is not None:
                    init.zeros_(m.bias)
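
    # Example (for the first conv in encode_first, assuming in_channels=6 and downscale_coef=8
    # as in the test block below): n = 1 * 1 * 384, so std = sqrt(2 / 384) ≈ 0.072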
    def compress_time(self, x, num_frames):
        """
        Temporal dimension compression: reduces the number of frames using average pooling
        while preserving key temporal information.

        Handling logic:
        - Special frame counts (66 or 34): split into two segments, keep the first frame of
          each segment, then pool the remaining frames
        - Odd frame counts: keep the first frame, pool the remaining frames
        - Even frame counts: directly pool all frames

        Args:
            x (torch.Tensor): Input tensor with shape (b*f, c, h, w)
            num_frames (int): Number of frames in the temporal dimension

        Returns:
            torch.Tensor: Temporally compressed tensor with shape (b*f', c, h, w) where f' < f
        """
        # Reshape: (b*f, c, h, w) -> (b, f, c, h, w), then move frames to the last axis
        x = rearrange(x, '(b f) c h w -> b f c h w', f=num_frames)
        batch_size, frames, channels, height, width = x.shape
        x = rearrange(x, 'b f c h w -> (b h w) c f')
        # Handle special frame counts (66 or 34)
        if x.shape[-1] == 66 or x.shape[-1] == 34:
            x_len = x.shape[-1]
            # Process first segment: keep the first frame, pool the rest
            x_clip1 = x[..., :x_len // 2]
            x_clip1_first, x_clip1_rest = x_clip1[..., 0].unsqueeze(-1), x_clip1[..., 1:]
            x_clip1_rest = F.avg_pool1d(x_clip1_rest, kernel_size=2, stride=2)
            # Process second segment: keep the first frame, pool the rest
            x_clip2 = x[..., x_len // 2:x_len]
            x_clip2_first, x_clip2_rest = x_clip2[..., 0].unsqueeze(-1), x_clip2[..., 1:]
            x_clip2_rest = F.avg_pool1d(x_clip2_rest, kernel_size=2, stride=2)
            # Concatenate results from both segments
            x = torch.cat([x_clip1_first, x_clip1_rest, x_clip2_first, x_clip2_rest], dim=-1)
        # Odd frame count: keep the first frame, pool the rest
        elif x.shape[-1] % 2 == 1:
            x_first, x_rest = x[..., 0], x[..., 1:]
            if x_rest.shape[-1] > 0:
                x_rest = F.avg_pool1d(x_rest, kernel_size=2, stride=2)
            x = torch.cat([x_first[..., None], x_rest], dim=-1)
        # Even frame count: pool all frames directly
        else:
            x = F.avg_pool1d(x, kernel_size=2, stride=2)

        x = rearrange(x, '(b h w) c f -> (b f) c h w', b=batch_size, h=height, w=width)
        return x
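
    # Worked example of the compression schedule (derived from the branches above):
    #   33 frames (odd)     -> 1 kept + avg_pool(32) = 1 + 16 = 17 frames
    #   66 frames (special) -> two segments of 33, each 1 + 16 = 17 -> 34 frames
    #   34 frames (special) -> two segments of 17, each 1 + 8  = 9  -> 18 frames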
    def forward(
        self,
        camera_states: torch.Tensor,
    ):
        """
        Forward pass: encodes camera states into feature embeddings.

        Args:
            camera_states (torch.Tensor): Camera state tensor with shape
                (batch, frames, channels, height, width)

        Returns:
            torch.Tensor: Encoded feature embeddings after patch embedding and scaling
        """
        batch_size, num_frames, channels, height, width = camera_states.shape
        camera_states = rearrange(camera_states, 'b f c h w -> (b f) c h w')
        camera_states = self.unshuffle(camera_states)
        camera_states = self.encode_first(camera_states)
        camera_states = self.compress_time(camera_states, num_frames=num_frames)
        num_frames = camera_states.shape[0] // batch_size
        camera_states = self.encode_second(camera_states)
        camera_states = self.compress_time(camera_states, num_frames=num_frames)
        camera_states = self.final_proj(camera_states)
        camera_states = rearrange(camera_states, "(b f) c h w -> b c f h w", b=batch_size)
        camera_states = self.camera_in(camera_states)
        return camera_states * self.scale
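
    # Shape trace for the test input below (batch 1, 33 frames, 6 channels, 704x1280,
    # downscale_coef=8); the final embedding shape depends on the PatchEmbed implementation:
    #   (1, 33, 6, 704, 1280) -> unshuffle          -> (33, 384, 88, 160)
    #                         -> encode_first       -> (33, 192, 88, 160)
    #                         -> compress_time      -> (17, 192, 88, 160)
    #                         -> encode_second      -> (17,  96, 88, 160)
    #                         -> compress_time      -> ( 9,  96, 88, 160)
    #                         -> final_proj/reshape -> ( 1,  16,  9, 88, 160)
    #                         -> camera_in * scale  -> patch embeddings of size hidden_size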
    @classmethod
    def from_pretrained(cls, pretrained_model_path):
        """
        Loads the model from a pretrained weight file.

        Args:
            pretrained_model_path (str): Path to the pretrained weight file

        Returns:
            CameraNet: Model instance with loaded pretrained weights
        """
        if not Path(pretrained_model_path).exists():
            raise FileNotFoundError(f"There is no model file in {pretrained_model_path}")
        print(f"Loading CameraNet's pretrained weights from {pretrained_model_path}.")
        state_dict = torch.load(pretrained_model_path, map_location="cpu")
        # patch_size and hidden_size match the test configuration at the bottom of this file;
        # adjust them if the checkpoint was trained with a different backbone.
        model = cls(
            in_channels=6,
            downscale_coef=8,
            out_channels=16,
            patch_size=[1, 2, 2],
            hidden_size=3072,
        )
        model.load_state_dict(state_dict, strict=True)
        return model
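
# Illustrative usage (the checkpoint path below is hypothetical):
#   camera_net = CameraNet.from_pretrained("checkpoints/camera_net.pt")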

if __name__ == "__main__":
    # Test model initialization and forward pass
    model = CameraNet(
        in_channels=6,
        downscale_coef=8,
        out_channels=16,
        patch_size=[1, 2, 2],
        hidden_size=3072,
    )
    print("Model structure:")
    print(model)

    # Generate a test input (batch 1, 33 frames, 6 channels, 704x1280 resolution)
    num_frames = 33
    input_tensor = torch.randn(1, num_frames, 6, 704, 1280)

    # Forward pass
    output_tensor = model(input_tensor)

    # Print results
    print(f"Output shape: {output_tensor.shape}")  # Expected: torch.Size([1, ...])
    print("Output tensor example:")
    print(output_tensor)