# Source: SkillForge45 — "Create model.py" (commit 3797a20, verified)
import torch
import torch.nn as nn
from torch.nn import Transformer
from de_en.encoder import ImageEncoder
from de_en.decoder import VideoDecoder
class ImageToVideoModel(nn.Module):
    """Image-to-video generation model.

    Encodes a single conditioning image, adds a learned "speed" embedding
    to the encoded features, and runs a standard ``nn.Transformer`` whose
    output is decoded back into video frames.

    Args:
        encoder_config (dict): keyword arguments forwarded to ``ImageEncoder``.
        decoder_config (dict): keyword arguments forwarded to ``VideoDecoder``.
        transformer_config (dict): must provide ``d_model``, ``nhead``,
            ``num_encoder_layers``, ``num_decoder_layers``,
            ``dim_feedforward`` and ``dropout``.
    """

    # Number of discrete playback-speed levels the embedding table supports.
    NUM_SPEED_LEVELS = 10

    def __init__(self, encoder_config, decoder_config, transformer_config):
        super().__init__()
        self.encoder = ImageEncoder(**encoder_config)
        self.decoder = VideoDecoder(**decoder_config)
        self.transformer = Transformer(
            d_model=transformer_config['d_model'],
            nhead=transformer_config['nhead'],
            num_encoder_layers=transformer_config['num_encoder_layers'],
            num_decoder_layers=transformer_config['num_decoder_layers'],
            dim_feedforward=transformer_config['dim_feedforward'],
            dropout=transformer_config['dropout'],
        )
        # One learned d_model-sized vector per discrete speed level (0..9).
        self.speed_embedding = nn.Embedding(
            self.NUM_SPEED_LEVELS, transformer_config['d_model'])

    def forward(self, image, target_frames, speed_level):
        """Teacher-forced forward pass.

        Args:
            image: conditioning image batch; consumed by ``self.encoder``.
            target_frames: decoder input sequence (teacher forcing);
                dim 1 is treated as the time axis.
            speed_level: int or integer tensor indexing the speed embedding
                (valid range ``0 .. NUM_SPEED_LEVELS - 1``).

        Returns:
            Video frames produced by ``self.decoder`` from the transformer
            output.
        """
        # Encode the conditioning image into transformer memory features.
        image_features = self.encoder(image)

        # Inject speed information. torch.as_tensor accepts both a plain
        # int and an existing tensor, and places the index on the right
        # device for the embedding lookup.
        speed_idx = torch.as_tensor(
            speed_level, dtype=torch.long, device=image.device)
        speed_emb = self.speed_embedding(speed_idx).unsqueeze(0)
        # NOTE(review): assumes this broadcast matches the encoder's
        # output layout — confirm against ImageEncoder.
        image_features = image_features + speed_emb

        seq_len = target_frames.size(1)

        # Causal mask: each target position attends only to earlier ones.
        # Fix: build it on the input's device — the original created it on
        # CPU, which fails for CUDA inputs.
        tgt_mask = self._generate_square_subsequent_mask(
            seq_len, device=image.device)
        # All-zeros memory mask leaves cross-attention unrestricted
        # (a no-op); kept for parity with the original call.
        memory_mask = torch.zeros(
            seq_len, image_features.size(0), device=image.device)

        output = self.transformer(
            src=image_features.unsqueeze(1),
            tgt=target_frames,
            tgt_mask=tgt_mask,
            memory_mask=memory_mask,
        )
        # Project transformer outputs back to frame space.
        return self.decoder(output)

    def _generate_square_subsequent_mask(self, sz, device=None):
        """Return an (sz, sz) additive causal mask.

        Entries on/below the diagonal are 0; entries above are ``-inf``
        so attention cannot look at future positions.

        Args:
            sz: side length of the square mask.
            device: optional device to allocate the mask on (new,
                backward-compatible; defaults to CPU as before).
        """
        full = torch.full((sz, sz), float('-inf'), device=device)
        return torch.triu(full, diagonal=1)

    def generate(self, image, num_frames, speed_level=5):
        """Autoregressively generate ``num_frames`` frames from one image.

        Args:
            image: single conditioning image batch.
            num_frames: total frames to produce (frame 0 stays blank).
            speed_level: int or integer tensor speed index (default 5).

        Returns:
            Tensor of shape ``(1, num_frames, *image.shape[1:])``.
        """
        # Fix: remember and restore training mode instead of leaving the
        # module permanently in eval() as a side effect.
        was_training = self.training
        self.eval()
        try:
            with torch.no_grad():
                image_features = self.encoder(image)
                # Fix: nn.Embedding requires a tensor index; the original
                # passed the raw int default and crashed.
                speed_idx = torch.as_tensor(
                    speed_level, dtype=torch.long, device=image.device)
                speed_emb = self.speed_embedding(speed_idx).unsqueeze(0)
                image_features = image_features + speed_emb

                # Frame 0 is a blank seed; later frames are filled in-place.
                generated_frames = torch.zeros(
                    (1, num_frames, *image.shape[1:]), device=image.device)
                for i in range(1, num_frames):
                    # Fix: mask on the same device as the inputs.
                    tgt_mask = self._generate_square_subsequent_mask(
                        i, device=image.device)
                    output = self.transformer(
                        src=image_features.unsqueeze(1),
                        tgt=generated_frames[:, :i],
                        tgt_mask=tgt_mask,
                    )
                    # NOTE(review): with the default (seq-first) Transformer
                    # layout, dim 1 of `output` is batch, not time — confirm
                    # this slice is what VideoDecoder expects.
                    next_frame = self.decoder(output[:, -1:])
                    generated_frames[:, i] = next_frame.squeeze(1)
        finally:
            if was_training:
                self.train()
        return generated_frames