AlexK-PL commited on
Commit
f95a0dc
·
1 Parent(s): 080259f

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +4 -4
app.py CHANGED
@@ -19,7 +19,7 @@ def init_models(hparams):
19
  model = load_model(hparams)
20
  checkpoint_path = "trained_models/checkpoint_78000.model"
21
  model.load_state_dict(torch.load(checkpoint_path)['state_dict'])
22
- model.to('cuda')
23
  _ = model.eval()
24
 
25
  # load pre trained MelGAN model for mel2audio:
@@ -28,15 +28,15 @@ def init_models(hparams):
28
  hp_melgan = load_hparam("melgan/config/default.yaml")
29
  vocoder_model = Generator(80)
30
  vocoder_model.load_state_dict(checkpoint['model_g'])
31
- vocoder_model = vocoder_model.to('cuda')
32
  vocoder_model.eval(inference=False)
33
 
34
  def synthesize(text):
35
  sequence = np.array(text_to_sequence(text, ['english_cleaners']))[None, :]
36
- sequence = torch.from_numpy(sequence).to(device='cuda', dtype=torch.int64)
37
 
38
  gst_head_scores = np.array([0.5, 0.15, 0.35]) # originally ([0.5, 0.15, 0.35])
39
- gst_scores = torch.from_numpy(gst_head_scores).cuda().float()
40
 
41
  mel_outputs, mel_outputs_postnet, _, alignments = model.inference(sequence, gst_scores)
42
 
 
19
  model = load_model(hparams)
20
  checkpoint_path = "trained_models/checkpoint_78000.model"
21
  model.load_state_dict(torch.load(checkpoint_path)['state_dict'])
22
+ # model.to('cuda')
23
  _ = model.eval()
24
 
25
  # load pre trained MelGAN model for mel2audio:
 
28
  hp_melgan = load_hparam("melgan/config/default.yaml")
29
  vocoder_model = Generator(80)
30
  vocoder_model.load_state_dict(checkpoint['model_g'])
31
+ # vocoder_model = vocoder_model.to('cuda')
32
  vocoder_model.eval(inference=False)
33
 
34
  def synthesize(text):
35
  sequence = np.array(text_to_sequence(text, ['english_cleaners']))[None, :]
36
+ sequence = torch.from_numpy(sequence).to(device='cpu', dtype=torch.int64)
37
 
38
  gst_head_scores = np.array([0.5, 0.15, 0.35]) # originally ([0.5, 0.15, 0.35])
39
+ gst_scores = torch.from_numpy(gst_head_scores).float()
40
 
41
  mel_outputs, mel_outputs_postnet, _, alignments = model.inference(sequence, gst_scores)
42