ASesYusuf1 committed
Commit fc049c4 · verified · 1 Parent(s): 4d1b48e

Update Apollo/inference.py

Files changed (1):
  1. Apollo/inference.py +37 -32
Apollo/inference.py CHANGED
@@ -8,28 +8,35 @@ import argparse
 import numpy as np
 import yaml
 from ml_collections import ConfigDict
+#from omegaconf import OmegaConf
+
 import warnings
 warnings.filterwarnings("ignore")
 
 def get_config(config_path):
     with open(config_path) as f:
+        #config = OmegaConf.load(config_path)
         config = ConfigDict(yaml.load(f, Loader=yaml.FullLoader))
     return config
 
-def load_audio(file_path, sr=44100):
-    audio, samplerate = librosa.load(file_path, mono=False, sr=sr)
+def load_audio(file_path):
+    audio, samplerate = librosa.load(file_path, mono=False, sr=44100)
     print(f'INPUT audio.shape = {audio.shape} | samplerate = {samplerate}')
+    #audio = dBgain(audio, -6)
     return torch.from_numpy(audio), samplerate
 
 def save_audio(file_path, audio, samplerate=44100):
+    #audio = dBgain(audio, +6)
     sf.write(file_path, audio.T, samplerate, subtype="PCM_16")
 
-def process_chunk(chunk, model, device):
-    chunk = chunk.unsqueeze(0).to(device)
+def process_chunk(chunk):
+    chunk = chunk.unsqueeze(0).cuda()
     with torch.no_grad():
         return model(chunk).squeeze(0).squeeze(0).cpu()
 
 def _getWindowingArray(window_size, fade_size):
+    # IMPORTANT NOTE :
+    # no fades here in the end, only removing the failed ending of the chunk
     fadein = torch.linspace(1, 1, fade_size)
     fadeout = torch.linspace(0, 0, fade_size)
     window = torch.ones(window_size)
@@ -42,26 +49,28 @@ def dBgain(audio, volume_gain_dB):
     gained_audio = audio * gain
     return gained_audio
 
-def main(input_wav, output_wav, ckpt_path, feature_dim, layer, sr, win, chunk_size, overlap):
-    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
-    print(f"Using device: {device}")
 
-    # Load the model
-    model = look2hear.models.BaseModel.from_pretrain(
-        ckpt_path, sr=sr, win=win, feature_dim=feature_dim, layer=layer
-    ).to(device)
+def main(input_wav, output_wav, ckpt_path):
+    os.environ['CUDA_VISIBLE_DEVICES'] = "0"
+
+    global model
+    feature_dim = config['model']['feature_dim']
+    sr = config['model']['sr']
+    win = config['model']['win']
+    layer = config['model']['layer']
+    model = look2hear.models.BaseModel.from_pretrain(ckpt_path, sr=sr, win=win, feature_dim=feature_dim, layer=layer).cuda()
 
-    test_data, samplerate = load_audio(input_wav, sr=sr)
+    test_data, samplerate = load_audio(input_wav)
 
     C = chunk_size * samplerate # chunk_size seconds to samples
     N = overlap
     step = C // N
-    fade_size = 3 * samplerate # 3 seconds
+    fade_size = 3 * 44100 # 3 seconds
     print(f"N = {N} | C = {C} | step = {step} | fade_size = {fade_size}")
 
     border = C - step
 
-    # Handle mono inputs correctly
+    # handle mono inputs correctly
     if len(test_data.shape) == 1:
         test_data = test_data.unsqueeze(0)
 
@@ -86,7 +95,7 @@ def main(input_wav, output_wav, ckpt_path, feature_dim, layer, sr, win, chunk_si
             else:
                part = torch.nn.functional.pad(input=part, pad=(0, C - length, 0, 0), mode='constant', value=0)
 
-        out = process_chunk(part, model, device)
+        out = process_chunk(part)
 
         window = windowingArray
         if i == 0: # First audio chunk, no fadein
@@ -122,24 +131,20 @@ if __name__ == "__main__":
     parser = argparse.ArgumentParser(description="Audio Inference Script")
     parser.add_argument("--in_wav", type=str, required=True, help="Path to input wav file")
     parser.add_argument("--out_wav", type=str, required=True, help="Path to output wav file")
-    parser.add_argument("--ckpt", type=str, required=True, help="Path to model checkpoint file")
-    parser.add_argument("--config", type=str, required=True, help="Path to model config file")
-    parser.add_argument("--chunk_size", type=int, default=10, help="Chunk size value in seconds")
-    parser.add_argument("--overlap", type=int, default=2, help="Overlap")
-    parser.add_argument("--feature_dim", type=int, default=256, help="Feature dimension")
-    parser.add_argument("--layer", type=int, default=6, help="Number of layers")
-    parser.add_argument("--sr", type=int, default=44100, help="Sample rate")
-    parser.add_argument("--win", type=int, default=20, help="Window size")
+    parser.add_argument("--ckpt", type=str, required=True, help="Path to model checkpoint file", default="model/pytorch_model.bin")
+    parser.add_argument("--config", type=str, help="Path to model config file", default="configs/apollo.yaml")
+    parser.add_argument("--chunk_size", type=int, help="chunk size value in seconds", default=10)
+    parser.add_argument("--overlap", type=int, help="Overlap", default=2)
     args = parser.parse_args()
 
+    ckpt_path = args.ckpt
+    chunk_size = args.chunk_size
+    overlap = args.overlap
     config = get_config(args.config)
     print(config['model'])
-    print(f'ckpt_path = {args.ckpt}')
-    print(f'chunk_size = {args.chunk_size}, overlap = {args.overlap}')
-    print(f'feature_dim = {args.feature_dim}, layer = {args.layer}, sr = {args.sr}, win = {args.win}')
-
-    main(
-        args.in_wav, args.out_wav, args.ckpt,
-        args.feature_dim, args.layer, args.sr, args.win,
-        args.chunk_size, args.overlap
-    )
+    print(f'ckpt_path = {ckpt_path}')
+    #print(f'config = {config}')
+    print(f'chunk_size = {chunk_size}, overlap = {overlap}')
+
+
+    main(args.in_wav, args.out_wav, ckpt_path)
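
After this change, main() no longer receives feature_dim, layer, sr, win, chunk_size or overlap as arguments: it reads the model hyperparameters from the YAML config and takes config, chunk_size and overlap from module-level globals set in the __main__ block, and it loads the model with .cuda(), so a GPU is assumed. As a rough sketch (not part of the commit; the import name `inference` and the file paths are placeholders taken from the new argparse defaults), driving the refactored entry point programmatically would mean setting those globals first:

    # Sketch only: assumes Apollo/inference.py is importable as `inference`
    # and that a CUDA device is available (the refactor hard-codes .cuda()).
    import inference

    inference.config = inference.get_config("configs/apollo.yaml")  # model hyperparameters
    inference.chunk_size = 10   # seconds per chunk (argparse default)
    inference.overlap = 2       # overlap factor (argparse default)
    inference.main("input.wav", "restored.wav", "model/pytorch_model.bin")

The equivalent command-line call is simply `python inference.py --in_wav input.wav --out_wav restored.wav --ckpt model/pytorch_model.bin`, relying on the new defaults for --config, --chunk_size and --overlap.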