import torch
import torch.nn as nn
class TextEncoder(nn.Module):
    """Encodes a batch of token ids into contextual features."""

    def __init__(self, vocab_size, embed_dim, hidden_dim):
        super().__init__()
        self.embedding = nn.Embedding(vocab_size, embed_dim)
        # batch_first=True so inputs/outputs are (batch, seq, embed_dim);
        # hidden_dim sets the width of the per-layer feed-forward network.
        self.transformer = nn.TransformerEncoder(
            nn.TransformerEncoderLayer(
                embed_dim, nhead=8, dim_feedforward=hidden_dim, batch_first=True
            ),
            num_layers=6,
        )

    def forward(self, text):
        x = self.embedding(text)    # (batch, seq, embed_dim)
        return self.transformer(x)  # (batch, seq, embed_dim)
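# Note: padding positions are not masked above; for variable-length batches a
# src_key_padding_mask can be passed through to self.transformer(...) so that
# padded tokens do not contribute to attention (omitted here for brevity).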
class VideoGenerator(nn.Module):
    """Upsamples a latent code into a video tensor with 3D transposed convs."""

    def __init__(self, latent_dim, num_frames, frame_size):
        super().__init__()
        self.latent_dim = latent_dim
        self.num_frames = num_frames
        self.frame_size = frame_size
        self.generator = nn.Sequential(
            # Each block doubles depth (frames), height, and width.
            nn.ConvTranspose3d(latent_dim, 512, kernel_size=4, stride=2, padding=1),
            nn.BatchNorm3d(512),
            nn.ReLU(),
            nn.ConvTranspose3d(512, 256, kernel_size=4, stride=2, padding=1),
            nn.BatchNorm3d(256),
            nn.ReLU(),
            nn.ConvTranspose3d(256, 128, kernel_size=4, stride=2, padding=1),
            nn.BatchNorm3d(128),
            nn.ReLU(),
            nn.ConvTranspose3d(128, 3, kernel_size=4, stride=2, padding=1),
            nn.Tanh(),  # pixel values in [-1, 1]
        )

    def forward(self, z):
        # z: (batch, latent_dim, 1, 1, 1) -> (batch, 3, frames, H, W)
        return self.generator(z)
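# Shape check for the stack above: with kernel_size=4, stride=2, padding=1,
# each ConvTranspose3d doubles every spatial dimension (out = 2 * in), so a
# (latent_dim, 1, 1, 1) input grows 1 -> 2 -> 4 -> 8 -> 16 over the four
# blocks, yielding 16 frames of 16x16 RGB regardless of the num_frames and
# frame_size arguments; larger outputs would require additional blocks.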
class Text2VideoModel(nn.Module):
    def __init__(self, vocab_size, embed_dim, latent_dim, num_frames, frame_size):
        super().__init__()
        self.text_encoder = TextEncoder(vocab_size, embed_dim, hidden_dim=512)
        self.video_generator = VideoGenerator(latent_dim, num_frames, frame_size)
        # Map the pooled text feature to a single latent code. The generator's
        # transposed convolutions upsample the temporal axis themselves, so the
        # linear layer does not need to produce one latent per frame; doing so
        # and then reshaping with view(-1, latent_dim, 1, 1, 1) would silently
        # fold num_frames into the batch dimension.
        self.latent_mapper = nn.Linear(embed_dim, latent_dim)

    def forward(self, text):
        text_features = self.text_encoder(text)     # (batch, seq, embed_dim)
        pooled = text_features.mean(dim=1)          # (batch, embed_dim)
        latent_vector = self.latent_mapper(pooled)  # (batch, latent_dim)
        latent_video = latent_vector.view(-1, self.video_generator.latent_dim, 1, 1, 1)
        return self.video_generator(latent_video)   # (batch, 3, frames, H, W)
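# A minimal smoke test of the end-to-end pipeline. This is a sketch: the
# hyperparameters below (vocab_size=1000, embed_dim=256, latent_dim=64) are
# illustrative choices, not values prescribed by the model; num_frames and
# frame_size are set to 16 to match the generator's fixed 16x upsampling.
if __name__ == "__main__":
    model = Text2VideoModel(
        vocab_size=1000, embed_dim=256, latent_dim=64, num_frames=16, frame_size=16
    )
    tokens = torch.randint(0, 1000, (2, 12))  # batch of 2 sequences, 12 tokens each
    video = model(tokens)
    print(video.shape)  # torch.Size([2, 3, 16, 16, 16]) -> (batch, C, frames, H, W)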