import argparse

import torch
import torch.optim as optim
from datasets import load_dataset
from torch.nn import MSELoss
from torch.utils.data import DataLoader
from tqdm import tqdm

from model import ImageToVideoModel
from de_en.tokenizer import VideoTokenizer


def prepare_datasets(dataset_name, batch_size, resolution):
    dataset = load_dataset(dataset_name)

    # Build the tokenizer once, rather than on every batched map call.
    tokenizer = VideoTokenizer(resolution)

    # Preprocess function: tokenize images and videos into model inputs.
    def preprocess(examples):
        examples['image'] = [tokenizer.encode_image(img) for img in examples['image']]
        examples['video'] = [tokenizer.encode_video(vid) for vid in examples['video']]
        return examples

    dataset = dataset.map(preprocess, batched=True)
    dataset.set_format(type='torch', columns=['image', 'video'])

    train_loader = DataLoader(dataset['train'], batch_size=batch_size, shuffle=True)
    val_loader = DataLoader(dataset['validation'], batch_size=batch_size)
    return train_loader, val_loader


def train_model(config):
    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

    # Initialize model
    model = ImageToVideoModel(
        encoder_config=config['encoder'],
        decoder_config=config['decoder'],
        transformer_config=config['transformer']
    ).to(device)

    # Load datasets
    train_loader, val_loader = prepare_datasets(
        config['dataset_name'], config['batch_size'], config['resolution']
    )

    # Optimizer and loss
    optimizer = optim.AdamW(model.parameters(), lr=config['lr'])
    criterion = MSELoss()

    # Training loop
    for epoch in range(config['epochs']):
        model.train()
        train_loss = 0.0
        for batch in tqdm(train_loader, desc=f"Epoch {epoch+1}"):
            images = batch['image'].to(device)
            videos = batch['video'].to(device)

            # Random speed level for each sample in the batch
            speed_levels = torch.randint(0, 10, (images.size(0),)).to(device)

            optimizer.zero_grad()

            # Predict all frames at once (teacher forcing): the model is fed
            # frames [0, T-1) and trained to reproduce frames [1, T).
            outputs = model(images, videos[:, :-1], speed_levels)
            loss = criterion(outputs, videos[:, 1:])

            loss.backward()
            optimizer.step()
            train_loss += loss.item()

        # Validation
        model.eval()
        val_loss = 0.0
        with torch.no_grad():
            for batch in val_loader:
                images = batch['image'].to(device)
                videos = batch['video'].to(device)
                speed_levels = torch.randint(0, 10, (images.size(0),)).to(device)
                outputs = model(images, videos[:, :-1], speed_levels)
                val_loss += criterion(outputs, videos[:, 1:]).item()

        print(f"Epoch {epoch+1}, "
              f"Train Loss: {train_loss/len(train_loader):.4f}, "
              f"Val Loss: {val_loss/len(val_loader):.4f}")

    # Save model
    torch.save(model.state_dict(), config['save_path'])


if __name__ == "__main__":
    parser = argparse.ArgumentParser()
    parser.add_argument("--dataset", type=str, default="ucf101")
    parser.add_argument("--batch_size", type=int, default=8)
    parser.add_argument("--epochs", type=int, default=10)
    parser.add_argument("--lr", type=float, default=1e-4)
    parser.add_argument("--resolution", type=int, default=128)
    parser.add_argument("--save_path", type=str, default="image_to_video_model.pth")
    args = parser.parse_args()

    config = {
        'dataset_name': args.dataset,
        'batch_size': args.batch_size,
        'epochs': args.epochs,
        'lr': args.lr,
        'resolution': args.resolution,
        'save_path': args.save_path,
        'encoder': {
            'in_channels': 3,
            'hidden_dims': [64, 128, 256, 512],
            'embed_dim': 512
        },
        'decoder': {
            'embed_dim': 512,
            'hidden_dims': [512, 256, 128, 64],
            'out_channels': 3
        },
        'transformer': {
            'd_model': 512,
            'nhead': 8,
            'num_encoder_layers': 3,
            'num_decoder_layers': 3,
            'dim_feedforward': 2048,
            'dropout': 0.1
        }
    }

    train_model(config)
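
# A minimal usage sketch. The flag names below are the ones defined by the
# argparse setup above; the script filename "train.py" is hypothetical, and
# the dataset is assumed to be a Hugging Face dataset exposing 'image' and
# 'video' columns with 'train' and 'validation' splits:
#
#   python train.py --dataset ucf101 --batch_size 8 --epochs 10 \
#       --lr 1e-4 --resolution 128 --save_path image_to_video_model.pth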