import torch
import torch.nn as nn
from torch.nn import Transformer

from de_en.encoder import ImageEncoder
from de_en.decoder import VideoDecoder


class ImageToVideoModel(nn.Module):
    """Image-conditioned video generator built around ``nn.Transformer``.

    Pipeline: an ``ImageEncoder`` turns the input image into features, a
    learned speed embedding is added to those features, the result is used as
    the transformer source, and a ``VideoDecoder`` maps the transformer
    output to video frames.

    Sequence tensors are handled batch-first (``(batch, seq, d_model)``),
    which is the layout ``forward`` and ``generate`` index into.
    """

    def __init__(self, encoder_config: dict, decoder_config: dict,
                 transformer_config: dict, num_speed_levels: int = 10):
        """Build the encoder, decoder, transformer and speed embedding.

        Args:
            encoder_config: Keyword arguments forwarded to ``ImageEncoder``.
            decoder_config: Keyword arguments forwarded to ``VideoDecoder``.
            transformer_config: Mapping with the keys ``d_model``, ``nhead``,
                ``num_encoder_layers``, ``num_decoder_layers``,
                ``dim_feedforward`` and ``dropout``.
            num_speed_levels: Number of distinct speed levels, i.e. the size
                of the speed embedding table. Valid ``speed_level`` values
                are ``0 .. num_speed_levels - 1``.
        """
        super().__init__()
        d_model = transformer_config['d_model']

        self.encoder = ImageEncoder(**encoder_config)
        self.decoder = VideoDecoder(**decoder_config)
        self.transformer = Transformer(
            d_model=d_model,
            nhead=transformer_config['nhead'],
            num_encoder_layers=transformer_config['num_encoder_layers'],
            num_decoder_layers=transformer_config['num_decoder_layers'],
            dim_feedforward=transformer_config['dim_feedforward'],
            dropout=transformer_config['dropout'],
            # forward()/generate() slice the sequence on dim 1, so the
            # transformer must use the batch-first layout as well. This flag
            # adds no parameters, so existing state dicts still load.
            batch_first=True,
        )
        self.speed_embedding = nn.Embedding(num_speed_levels, d_model)

    def _encode_condition(self, image: torch.Tensor, speed_level) -> torch.Tensor:
        """Encode ``image``, add the speed embedding, return the source.

        ``speed_level`` may be a Python int, a 0-dim tensor, or a tensor with
        one entry per batch element.
        """
        image_features = self.encoder(image)

        # nn.Embedding needs an integer tensor on the same device as its
        # weight; a plain int (generate()'s default) is converted here.
        speed_index = torch.as_tensor(
            speed_level,
            dtype=torch.long,
            device=self.speed_embedding.weight.device,
        )
        # Flattening yields a (1, d_model) or (batch, d_model) embedding that
        # broadcasts onto the features without raising their rank.
        speed_emb = self.speed_embedding(speed_index.reshape(-1))
        image_features = image_features + speed_emb

        # NOTE(review): assumes the encoder returns (batch, d_model) so that
        # this becomes a length-1 source sequence -- confirm with ImageEncoder.
        return image_features.unsqueeze(1)

    def forward(self, image: torch.Tensor, target_frames: torch.Tensor,
                speed_level) -> torch.Tensor:
        """Decode ``target_frames`` against the speed-conditioned image.

        Args:
            image: Batch of inputs for the image encoder.
            target_frames: Target sequence passed to the transformer as
                ``tgt``; dim 1 is the frame (sequence) dimension.
            speed_level: Speed index (int or integer tensor, scalar or one
                per batch element) looked up in the speed embedding.

        Returns:
            The output of the video decoder applied to the transformer output.
        """
        src = self._encode_condition(image, speed_level)

        seq_len = target_frames.size(1)
        # Causal mask, created on the target's device so it also works when
        # the model lives on an accelerator.
        tgt_mask = self._generate_square_subsequent_mask(
            seq_len, device=target_frames.device
        )

        # No memory_mask: every target position may attend to the whole
        # source, which is what an all-zero additive mask expressed.
        output = self.transformer(src=src, tgt=target_frames, tgt_mask=tgt_mask)

        return self.decoder(output)

    def _generate_square_subsequent_mask(self, sz: int, device=None) -> torch.Tensor:
        """Return an ``(sz, sz)`` additive causal attention mask.

        Entries above the main diagonal are ``-inf`` (future positions are
        hidden); all other entries are ``0``.

        Args:
            sz: Sequence length.
            device: Device to allocate the mask on; ``None`` uses the
                default device.
        """
        return torch.triu(
            torch.full((sz, sz), float('-inf'), device=device), diagonal=1
        )

    def generate(self, image: torch.Tensor, num_frames: int, speed_level=5) -> torch.Tensor:
        """Autoregressively generate ``num_frames`` frames without gradients.

        The model is switched to eval mode for the duration of the call and
        its previous training mode is restored afterwards.

        Args:
            image: Batch of inputs for the image encoder.
            num_frames: Number of frames in the returned sequence. Frame 0 is
                the all-zero start frame.
            speed_level: Speed index (int or integer tensor).

        Returns:
            Tensor of shape ``(batch, num_frames, *image.shape[1:])``.
        """
        was_training = self.training
        self.eval()
        try:
            with torch.no_grad():
                src = self._encode_condition(image, speed_level)
                # The source does not change between steps, so encode it once
                # instead of re-running the transformer encoder per frame.
                memory = self.transformer.encoder(src)

                # NOTE(review): this buffer is fed back as the transformer
                # target, which requires each frame to be a d_model-sized
                # vector -- confirm against ImageEncoder/VideoDecoder shapes.
                generated_frames = torch.zeros(
                    (image.size(0), num_frames, *image.shape[1:]),
                    device=image.device,
                )

                for i in range(1, num_frames):
                    tgt_mask = self._generate_square_subsequent_mask(
                        i, device=image.device
                    )
                    output = self.transformer.decoder(
                        generated_frames[:, :i], memory, tgt_mask=tgt_mask
                    )
                    # Only the newest position predicts the next frame.
                    next_frame = self.decoder(output[:, -1:])
                    generated_frames[:, i] = next_frame.squeeze(1)
        finally:
            # Do not leave a model that was being trained stuck in eval mode.
            self.train(was_training)

        return generated_frames