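"""Training script for ImageToVideoModel.

Loads a Hugging Face dataset, encodes images and videos with VideoTokenizer,
and trains the model with teacher forcing plus a per-sample speed-level
conditioning signal. Run `python train.py --help` for the available options.
"""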
import torch
from torch.utils.data import DataLoader
from datasets import load_dataset
from model import ImageToVideoModel
from de_en.tokenizer import VideoTokenizer
import torch.optim as optim
from torch.nn import MSELoss
from tqdm import tqdm
import argparse

def prepare_datasets(dataset_name, batch_size, resolution):
    dataset = load_dataset(dataset_name)
    # Build the tokenizer once rather than on every batched map call
    tokenizer = VideoTokenizer(resolution)

    # Preprocess: encode raw images and videos into model-ready tensors
    def preprocess(examples):
        examples['image'] = [tokenizer.encode_image(img) for img in examples['image']]
        examples['video'] = [tokenizer.encode_video(vid) for vid in examples['video']]
        return examples

    dataset = dataset.map(preprocess, batched=True)
    dataset.set_format(type='torch', columns=['image', 'video'])

    train_loader = DataLoader(dataset['train'], batch_size=batch_size, shuffle=True)
    val_loader = DataLoader(dataset['validation'], batch_size=batch_size)
    return train_loader, val_loader


def train_model(config):
    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

    # Initialize model
    model = ImageToVideoModel(
        encoder_config=config['encoder'],
        decoder_config=config['decoder'],
        transformer_config=config['transformer']
    ).to(device)

    # Load datasets
    train_loader, val_loader = prepare_datasets(
        config['dataset_name'],
        config['batch_size'],
        config['resolution']
    )

    # Optimizer and loss
    optimizer = optim.AdamW(model.parameters(), lr=config['lr'])
    criterion = MSELoss()
    # Training loop
    for epoch in range(config['epochs']):
        model.train()
        train_loss = 0.0
        for batch in tqdm(train_loader, desc=f"Epoch {epoch+1}"):
            images = batch['image'].to(device)
            videos = batch['video'].to(device)
            # Random speed level for each sample in the batch
            speed_levels = torch.randint(0, 10, (images.size(0),), device=device)

            optimizer.zero_grad()
            # Predict all frames at once (teacher forcing): condition on
            # frames [0, T-1) and regress the shifted target frames [1, T)
            outputs = model(images, videos[:, :-1], speed_levels)
            loss = criterion(outputs, videos[:, 1:])
            loss.backward()
            optimizer.step()
            train_loss += loss.item()
        # Validation
        model.eval()
        val_loss = 0.0
        with torch.no_grad():
            for batch in val_loader:
                images = batch['image'].to(device)
                videos = batch['video'].to(device)
                speed_levels = torch.randint(0, 10, (images.size(0),), device=device)
                outputs = model(images, videos[:, :-1], speed_levels)
                val_loss += criterion(outputs, videos[:, 1:]).item()

        print(f"Epoch {epoch+1}, "
              f"Train Loss: {train_loss/len(train_loader):.4f}, "
              f"Val Loss: {val_loss/len(val_loader):.4f}")
    # Save model weights
    torch.save(model.state_dict(), config['save_path'])

if __name__ == "__main__":
    parser = argparse.ArgumentParser()
    parser.add_argument("--dataset", type=str, default="ucf101")
    parser.add_argument("--batch_size", type=int, default=8)
    parser.add_argument("--epochs", type=int, default=10)
    parser.add_argument("--lr", type=float, default=1e-4)
    parser.add_argument("--resolution", type=int, default=128)
    parser.add_argument("--save_path", type=str, default="image_to_video_model.pth")
    args = parser.parse_args()

    config = {
        'dataset_name': args.dataset,
        'batch_size': args.batch_size,
        'epochs': args.epochs,
        'lr': args.lr,
        'resolution': args.resolution,
        'save_path': args.save_path,
        'encoder': {
            'in_channels': 3,
            'hidden_dims': [64, 128, 256, 512],
            'embed_dim': 512
        },
        'decoder': {
            'embed_dim': 512,
            'hidden_dims': [512, 256, 128, 64],
            'out_channels': 3
        },
        'transformer': {
            'd_model': 512,
            'nhead': 8,
            'num_encoder_layers': 3,
            'num_decoder_layers': 3,
            'dim_feedforward': 2048,
            'dropout': 0.1
        }
    }

    train_model(config)
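
# Example invocation, shown with the defaults above; this assumes `model.py`
# and `de_en/tokenizer.py` are importable and that the chosen dataset exposes
# 'train' and 'validation' splits via `datasets.load_dataset`:
#
#   python train.py --dataset ucf101 --batch_size 8 --epochs 10 \
#       --lr 1e-4 --resolution 128 --save_path image_to_video_model.pth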