AlexK-PL commited on
Commit
05073fc
·
1 Parent(s): fc52f44

Update melgan/utils/train.py

Browse files
Files changed (1) hide show
  1. melgan/utils/train.py +6 -6
melgan/utils/train.py CHANGED
@@ -14,8 +14,8 @@ from .validation import validate
14
 
15
 
16
  def train(args, pt_dir, chkpt_path, trainloader, valloader, writer, logger, hp, hp_str):
17
- model_g = Generator(hp.audio.n_mel_channels).cuda()
18
- model_d = MultiScaleDiscriminator().cuda()
19
 
20
  optim_g = torch.optim.Adam(model_g.parameters(),
21
  lr=hp.train.adam.lr, betas=(hp.train.adam.beta1, hp.train.adam.beta2))
@@ -62,10 +62,10 @@ def train(args, pt_dir, chkpt_path, trainloader, valloader, writer, logger, hp,
62
  trainloader.dataset.shuffle_mapping()
63
  loader = tqdm.tqdm(trainloader, desc='Loading train data')
64
  for (melG, audioG), (melD, audioD) in loader:
65
- melG = melG.cuda()
66
- audioG = audioG.cuda()
67
- melD = melD.cuda()
68
- audioD = audioD.cuda()
69
 
70
  # generator
71
  optim_g.zero_grad()
 
14
 
15
 
16
  def train(args, pt_dir, chkpt_path, trainloader, valloader, writer, logger, hp, hp_str):
17
+ model_g = Generator(hp.audio.n_mel_channels) # cuda()
18
+ model_d = MultiScaleDiscriminator() # cuda()
19
 
20
  optim_g = torch.optim.Adam(model_g.parameters(),
21
  lr=hp.train.adam.lr, betas=(hp.train.adam.beta1, hp.train.adam.beta2))
 
62
  trainloader.dataset.shuffle_mapping()
63
  loader = tqdm.tqdm(trainloader, desc='Loading train data')
64
  for (melG, audioG), (melD, audioD) in loader:
65
+ # melG = melG.cuda()
66
+ # audioG = audioG.cuda()
67
+ # melD = melD.cuda()
68
+ # audioD = audioD.cuda()
69
 
70
  # generator
71
  optim_g.zero_grad()