SkillForge45 committed
Commit 5c9efac · verified · 1 Parent(s): 3797a20

Create train.py

Files changed (1): train.py +127 -0
train.py ADDED
@@ -0,0 +1,127 @@
import argparse

import torch
import torch.optim as optim
from torch.nn import MSELoss
from torch.utils.data import DataLoader

from datasets import load_dataset
from tqdm import tqdm

from model import ImageToVideoModel
from de_en.tokenizer import VideoTokenizer

def prepare_datasets(dataset_name, batch_size, resolution):
    dataset = load_dataset(dataset_name)

    # Build the tokenizer once and reuse it across batches
    tokenizer = VideoTokenizer(resolution)

    # Preprocess function: encode raw images and videos into tensors
    def preprocess(examples):
        examples['image'] = [tokenizer.encode_image(img) for img in examples['image']]
        examples['video'] = [tokenizer.encode_video(vid) for vid in examples['video']]
        return examples

    dataset = dataset.map(preprocess, batched=True)
    dataset.set_format(type='torch', columns=['image', 'video'])

    train_loader = DataLoader(dataset['train'], batch_size=batch_size, shuffle=True)
    val_loader = DataLoader(dataset['validation'], batch_size=batch_size)

    return train_loader, val_loader

def train_model(config):
    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

    # Initialize model
    model = ImageToVideoModel(
        encoder_config=config['encoder'],
        decoder_config=config['decoder'],
        transformer_config=config['transformer']
    ).to(device)

    # Load datasets
    train_loader, val_loader = prepare_datasets(
        config['dataset_name'],
        config['batch_size'],
        config['resolution']
    )

    # Optimizer and loss
    optimizer = optim.AdamW(model.parameters(), lr=config['lr'])
    criterion = MSELoss()

    # Training loop
    for epoch in range(config['epochs']):
        model.train()
        train_loss = 0.0

        for batch in tqdm(train_loader, desc=f"Epoch {epoch+1}"):
            images = batch['image'].to(device)
            videos = batch['video'].to(device)

            # Random speed level for each sample in the batch
            speed_levels = torch.randint(0, 10, (images.size(0),)).to(device)

            optimizer.zero_grad()

            # Predict all frames at once (teacher forcing): the model sees
            # frames [0, T-1) and is trained to reconstruct frames [1, T)
            outputs = model(images, videos[:, :-1], speed_levels)

            loss = criterion(outputs, videos[:, 1:])
            loss.backward()
            optimizer.step()

            train_loss += loss.item()

        # Validation
        model.eval()
        val_loss = 0.0
        with torch.no_grad():
            for batch in val_loader:
                images = batch['image'].to(device)
                videos = batch['video'].to(device)
                speed_levels = torch.randint(0, 10, (images.size(0),)).to(device)

                outputs = model(images, videos[:, :-1], speed_levels)
                val_loss += criterion(outputs, videos[:, 1:]).item()

        print(f"Epoch {epoch+1}, Train Loss: {train_loss/len(train_loader):.4f}, Val Loss: {val_loss/len(val_loader):.4f}")

    # Save model
    torch.save(model.state_dict(), config['save_path'])

if __name__ == "__main__":
    parser = argparse.ArgumentParser()
    parser.add_argument("--dataset", type=str, default="ucf101")
    parser.add_argument("--batch_size", type=int, default=8)
    parser.add_argument("--epochs", type=int, default=10)
    parser.add_argument("--lr", type=float, default=1e-4)
    parser.add_argument("--resolution", type=int, default=128)
    parser.add_argument("--save_path", type=str, default="image_to_video_model.pth")
    args = parser.parse_args()

    config = {
        'dataset_name': args.dataset,
        'batch_size': args.batch_size,
        'epochs': args.epochs,
        'lr': args.lr,
        'resolution': args.resolution,
        'save_path': args.save_path,
        'encoder': {
            'in_channels': 3,
            'hidden_dims': [64, 128, 256, 512],
            'embed_dim': 512
        },
        'decoder': {
            'embed_dim': 512,
            'hidden_dims': [512, 256, 128, 64],
            'out_channels': 3
        },
        'transformer': {
            'd_model': 512,
            'nhead': 8,
            'num_encoder_layers': 3,
            'num_decoder_layers': 3,
            'dim_feedforward': 2048,
            'dropout': 0.1
        }
    }

    train_model(config)
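For reference, a minimal way to launch the new script from the repo root (this assumes model.py and de_en/tokenizer.py are importable as in the imports above, and that the chosen dataset exposes 'train' and 'validation' splits with 'image' and 'video' columns, which is what train.py expects):

python train.py --dataset ucf101 --batch_size 8 --epochs 10 --lr 1e-4 --resolution 128 --save_path image_to_video_model.pth

All flags are optional; the values shown are simply the argparse defaults defined in the script.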