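"""Image-to-video generation: an ImageEncoder conditions a Transformer that
predicts frame latents, which a VideoDecoder maps back to pixel space.

Assumed interfaces (the de_en modules are defined elsewhere): ImageEncoder
returns (batch, d_model) features and VideoDecoder maps (batch, frames,
d_model) latents to (batch, frames, C, H, W) frames.
"""
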
import torch
import torch.nn as nn
from torch.nn import Transformer
from de_en.encoder import ImageEncoder
from de_en.decoder import VideoDecoder

class ImageToVideoModel(nn.Module):
    """Encode an image, condition a Transformer on it (plus a playback-speed
    embedding), and decode the predicted frame latents into video frames."""

    def __init__(self, encoder_config, decoder_config, transformer_config):
        super().__init__()
        self.encoder = ImageEncoder(**encoder_config)
        self.decoder = VideoDecoder(**decoder_config)
        self.d_model = transformer_config['d_model']
        self.transformer = Transformer(
            d_model=transformer_config['d_model'],
            nhead=transformer_config['nhead'],
            num_encoder_layers=transformer_config['num_encoder_layers'],
            num_decoder_layers=transformer_config['num_decoder_layers'],
            dim_feedforward=transformer_config['dim_feedforward'],
            dropout=transformer_config['dropout'],
            batch_first=True  # all tensors below are (batch, seq, feature)
        )
        # 10 discrete playback-speed levels embedded into the model dimension
        self.speed_embedding = nn.Embedding(10, transformer_config['d_model'])
        
    def forward(self, image, target_frames, speed_level):
        # Encode the conditioning image to (batch, d_model) features
        image_features = self.encoder(image)

        # Add speed information; speed_level is a (batch,) LongTensor
        speed_emb = self.speed_embedding(speed_level)
        image_features = image_features + speed_emb

        # Causal mask so each target frame attends only to earlier frames
        seq_len = target_frames.size(1)
        tgt_mask = self._generate_square_subsequent_mask(seq_len, image.device)

        # Transformer processing: the image contributes a single memory token,
        # and target_frames is expected as (batch, seq_len, d_model) latents.
        # No memory mask is needed: every frame may attend to the image.
        output = self.transformer(
            src=image_features.unsqueeze(1),
            tgt=target_frames,
            tgt_mask=tgt_mask
        )

        # Decode latents to video frames
        video_frames = self.decoder(output)
        return video_frames
    
    def _generate_square_subsequent_mask(self, sz, device):
        # Upper-triangular -inf mask: position i may attend only to j <= i
        return torch.triu(
            torch.full((sz, sz), float('-inf'), device=device), diagonal=1
        )

    def generate(self, image, num_frames, speed_level=5):
        self.eval()
        with torch.no_grad():
            batch_size = image.size(0)

            # Encode image and add the speed embedding (speed_level is an int,
            # so broadcast it across the batch as a LongTensor first)
            image_features = self.encoder(image)
            speed = torch.full(
                (batch_size,), speed_level, dtype=torch.long, device=image.device
            )
            image_features = image_features + self.speed_embedding(speed)
            memory = image_features.unsqueeze(1)

            # Autoregress in latent space: start from a blank latent and append
            # the transformer's prediction for the next frame at each step
            latents = torch.zeros(batch_size, 1, self.d_model, device=image.device)
            for i in range(1, num_frames):
                tgt_mask = self._generate_square_subsequent_mask(i, image.device)
                output = self.transformer(src=memory, tgt=latents, tgt_mask=tgt_mask)
                latents = torch.cat([latents, output[:, -1:]], dim=1)

            # Decode all frame latents to pixel space in one pass
            return self.decoder(latents)
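
if __name__ == "__main__":
    # Usage sketch. The empty encoder/decoder configs and the transformer
    # hyperparameters below are illustrative assumptions; the actual de_en
    # ImageEncoder/VideoDecoder signatures are defined elsewhere.
    transformer_config = {
        'd_model': 512, 'nhead': 8,
        'num_encoder_layers': 6, 'num_decoder_layers': 6,
        'dim_feedforward': 2048, 'dropout': 0.1,
    }
    model = ImageToVideoModel({}, {}, transformer_config)

    image = torch.randn(2, 3, 64, 64)          # (batch, C, H, W)
    target_latents = torch.randn(2, 16, 512)   # (batch, frames, d_model)
    speed = torch.randint(0, 10, (2,))         # one speed level per sample

    frames = model(image, target_latents, speed)                 # teacher-forced pass
    video = model.generate(image, num_frames=16, speed_level=5)  # inference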