"""Training script for an image-to-video generation model.

Loads a video dataset, tokenizes images and clips, and trains an
encoder/decoder transformer conditioned on a per-example speed level.
"""
import argparse

import torch
import torch.optim as optim
from torch.nn import MSELoss
from torch.utils.data import DataLoader
from datasets import load_dataset
from tqdm import tqdm

from model import ImageToVideoModel
from de_en.tokenizer import VideoTokenizer


def prepare_datasets(dataset_name, batch_size, resolution):
    """Load and tokenize the dataset, returning train/validation DataLoaders."""
    dataset = load_dataset(dataset_name)

    # Build the tokenizer once rather than once per batched map() call.
    tokenizer = VideoTokenizer(resolution)

    def preprocess(examples):
        examples['image'] = [tokenizer.encode_image(img) for img in examples['image']]
        examples['video'] = [tokenizer.encode_video(vid) for vid in examples['video']]
        return examples

    dataset = dataset.map(preprocess, batched=True)
    dataset.set_format(type='torch', columns=['image', 'video'])

    # Assumes the dataset provides both 'train' and 'validation' splits.
    train_loader = DataLoader(dataset['train'], batch_size=batch_size, shuffle=True)
    val_loader = DataLoader(dataset['validation'], batch_size=batch_size)

    return train_loader, val_loader


def train_model(config):
    """Train the model, reporting average train/validation loss per epoch."""
    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

    model = ImageToVideoModel(
        encoder_config=config['encoder'],
        decoder_config=config['decoder'],
        transformer_config=config['transformer']
    ).to(device)

    train_loader, val_loader = prepare_datasets(
        config['dataset_name'],
        config['batch_size'],
        config['resolution']
    )

    optimizer = optim.AdamW(model.parameters(), lr=config['lr'])
    criterion = MSELoss()

    for epoch in range(config['epochs']):
        model.train()
        train_loss = 0.0

        for batch in tqdm(train_loader, desc=f"Epoch {epoch+1}"):
            images = batch['image'].to(device)
            videos = batch['video'].to(device)

            # Sample a random speed-conditioning level per example.
            speed_levels = torch.randint(0, 10, (images.size(0),)).to(device)

            optimizer.zero_grad()

            # Teacher forcing: condition on frames [:-1] to predict frames [1:].
            outputs = model(images, videos[:, :-1], speed_levels)
            loss = criterion(outputs, videos[:, 1:])
            loss.backward()
            optimizer.step()

            train_loss += loss.item()

        model.eval()
        val_loss = 0.0
        with torch.no_grad():
            for batch in val_loader:
                images = batch['image'].to(device)
                videos = batch['video'].to(device)
                speed_levels = torch.randint(0, 10, (images.size(0),)).to(device)

                outputs = model(images, videos[:, :-1], speed_levels)
                val_loss += criterion(outputs, videos[:, 1:]).item()

        print(f"Epoch {epoch+1}, Train Loss: {train_loss/len(train_loader):.4f}, Val Loss: {val_loss/len(val_loader):.4f}")

    torch.save(model.state_dict(), config['save_path'])


if __name__ == "__main__":
    parser = argparse.ArgumentParser()
    parser.add_argument("--dataset", type=str, default="ucf101")
    parser.add_argument("--batch_size", type=int, default=8)
    parser.add_argument("--epochs", type=int, default=10)
    parser.add_argument("--lr", type=float, default=1e-4)
    parser.add_argument("--resolution", type=int, default=128)
    parser.add_argument("--save_path", type=str, default="image_to_video_model.pth")
    args = parser.parse_args()

    config = {
        'dataset_name': args.dataset,
        'batch_size': args.batch_size,
        'epochs': args.epochs,
        'lr': args.lr,
        'resolution': args.resolution,
        'save_path': args.save_path,
        'encoder': {
            'in_channels': 3,
            'hidden_dims': [64, 128, 256, 512],
            'embed_dim': 512
        },
        'decoder': {
            'embed_dim': 512,
            'hidden_dims': [512, 256, 128, 64],
            'out_channels': 3
        },
        'transformer': {
            'd_model': 512,
            'nhead': 8,
            'num_encoder_layers': 3,
            'num_decoder_layers': 3,
            'dim_feedforward': 2048,
            'dropout': 0.1
        }
    }

    train_model(config)
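
# Example invocation, assuming this file is saved as train.py (hypothetical
# filename) and the local `model` and `de_en.tokenizer` modules are importable:
#
#   python train.py --dataset ucf101 --batch_size 8 --epochs 10 \
#       --lr 1e-4 --resolution 128 --save_path image_to_video_model.pth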