Spaces:
Sleeping
Sleeping
from pytorch_lightning.loggers import WandbLogger | |
import pytorch_lightning as pl | |
import torch | |
from torch.utils.data import DataLoader | |
from datasets import GuitarFXDataset | |
from models import DiffusionGenerationModel, OpenUnmixModel | |
# Sample rate handed to GuitarFXDataset as ``sample_rate`` (presumably Hz).
SAMPLE_RATE = 22050
# Fraction of the dataset assigned to the training split; the remainder
# becomes the validation split.
TRAIN_SPLIT = 0.8
def main(
    root: str = "/Users/matthewrice/Developer/remfx/data/egfx",
    *,
    max_epochs: int = 10,
    batch_size: int = 2,
) -> None:
    """Train an ``OpenUnmixModel`` on the "Phaser" subset of GuitarFXDataset.

    The dataset is split randomly into training and validation subsets
    according to ``TRAIN_SPLIT``, wrapped in data loaders, and passed to a
    PyTorch Lightning trainer that logs to the "RemFX" Weights & Biases
    project.

    Args:
        root: Directory passed to ``GuitarFXDataset`` as its ``root``. The
            default is the original author's local path, kept so that a bare
            ``main()`` call behaves as before; pass your own location when
            running elsewhere.
        max_epochs: Value forwarded to ``pl.Trainer(max_epochs=...)``.
        batch_size: Batch size used for both the training and the
            validation ``DataLoader``.
    """
    wandb_logger = WandbLogger(project="RemFX", save_dir="./")
    trainer = pl.Trainer(logger=wandb_logger, max_epochs=max_epochs)

    guitfx = GuitarFXDataset(
        root=root,
        sample_rate=SAMPLE_RATE,
        effect_type=["Phaser"],
    )

    # Derive the validation size as the remainder so that the two sizes
    # always sum to len(guitfx), which random_split requires.
    train_size = int(TRAIN_SPLIT * len(guitfx))
    val_size = len(guitfx) - train_size
    train_dataset, val_dataset = torch.utils.data.random_split(
        guitfx, [train_size, val_size]
    )

    train = DataLoader(train_dataset, batch_size=batch_size)
    val = DataLoader(val_dataset, batch_size=batch_size)

    # NOTE: DiffusionGenerationModel (also imported by this file) is the
    # alternative model; construct it here instead to train that variant.
    model = OpenUnmixModel()
    trainer.fit(model=model, train_dataloaders=train, val_dataloaders=val)
# Run the training only when this file is executed as a script, not when it
# is imported as a module.
if __name__ == "__main__":
    main()