ASesYusuf1 committed on
Commit
865a9b2
·
verified ·
1 Parent(s): cb1817f

Update inference.py

Browse files
Files changed (1) hide show
  1. inference.py +139 -68
inference.py CHANGED
@@ -78,11 +78,10 @@ def run_folder(model, args, config, device, verbose: bool = False):
78
  instruments = prefer_target_instrument(config)[:]
79
  os.makedirs(args.store_dir, exist_ok=True)
80
 
81
- # Dosya sayısını ve progress için değişkenler
82
  total_files = len(mixture_paths)
83
  current_file = 0
84
 
85
- # Progress tracking
86
  for path in mixture_paths:
87
  try:
88
  # Dosya işleme başlangıcı
@@ -90,76 +89,148 @@ def run_folder(model, args, config, device, verbose: bool = False):
90
  print(f"Processing file {current_file}/{total_files}")
91
 
92
  mix, sr = librosa.load(path, sr=sample_rate, mono=False)
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
93
  except Exception as e:
94
  print(f'Cannot read track: {path}')
95
  print(f'Error message: {str(e)}')
96
  continue
97
 
98
- mix_orig = mix.copy()
99
- if 'normalize' in config.inference:
100
- if config.inference['normalize'] is True:
101
- mix, norm_params = normalize_audio(mix)
102
-
103
- waveforms_orig = demix(config, model, mix, device, model_type=args.model_type)
104
-
105
- if args.use_tta:
106
- waveforms_orig = apply_tta(config, model, mix, waveforms_orig, device, args.model_type)
107
-
108
- if args.demud_phaseremix_inst:
109
- print(f"Demudding track (phase remix - instrumental): {path}")
110
- instr = 'vocals' if 'vocals' in instruments else instruments[0]
111
- instruments.append('instrumental_phaseremix')
112
- if 'instrumental' not in instruments and 'Instrumental' not in instruments:
113
- mix_modified = mix_orig - 2*waveforms_orig[instr]
114
- mix_modified_ = mix_modified.copy()
115
-
116
- waveforms_modified = demix(config, model, mix_modified, device, model_type=args.model_type)
117
- if args.use_tta:
118
- waveforms_modified = apply_tta(config, model, mix_modified, waveforms_modified, device, args.model_type)
119
-
120
- waveforms_orig['instrumental_phaseremix'] = mix_orig + waveforms_modified[instr]
121
- else:
122
- mix_modified = 2*waveforms_orig[instr] - mix_orig
123
- mix_modified_ = mix_modified.copy()
124
-
125
- waveforms_modified = demix(config, model, mix_modified, device, model_type=args.model_type)
126
- if args.use_tta:
127
- waveforms_modified = apply_tta(config, model, mix_modified, waveforms_orig, device, args.model_type)
128
-
129
- waveforms_orig['instrumental_phaseremix'] = mix_orig + mix_modified_ - waveforms_modified[instr]
130
-
131
- if args.extract_instrumental:
132
- instr = 'vocals' if 'vocals' in instruments else instruments[0]
133
- waveforms_orig['instrumental'] = mix_orig - waveforms_orig[instr]
134
- if 'instrumental' not in instruments:
135
- instruments.append('instrumental')
136
-
137
- for instr in instruments:
138
- estimates = waveforms_orig[instr]
139
- if 'normalize' in config.inference:
140
- if config.inference['normalize'] is True:
141
- estimates = denormalize_audio(estimates, norm_params)
142
-
143
- # Dosya formatı ve PCM türü belirleme
144
- is_float = getattr(args, 'export_format', '').startswith('wav FLOAT')
145
- codec = 'flac' if getattr(args, 'flac_file', False) else 'wav'
146
-
147
- # Subtype belirleme
148
- if codec == 'flac':
149
- subtype = get_soundfile_subtype(args.pcm_type, is_float)
150
- else:
151
- subtype = get_soundfile_subtype('FLOAT', is_float)
152
-
153
- shortened_filename = shorten_filename(os.path.basename(path))
154
- output_filename = f"{shortened_filename}_{instr}.{codec}"
155
- output_path = os.path.join(args.store_dir, output_filename)
156
-
157
- sf.write(output_path, estimates.T, sr, subtype=subtype)
158
-
159
- # Progress yüzdesi hesaplama
160
- progress_percent = int((current_file / total_files) * 100)
161
- print(f"Progress: {progress_percent}%")
162
-
163
  print(f"Elapsed time: {time.time() - start_time:.2f} seconds.")
164
 
165
  def proc_folder(args):
@@ -239,4 +310,4 @@ def proc_folder(args):
239
 
240
 
241
  if __name__ == "__main__":
242
- proc_folder(None)
 
78
  instruments = prefer_target_instrument(config)[:]
79
  os.makedirs(args.store_dir, exist_ok=True)
80
 
81
+ # Progress tracking
82
  total_files = len(mixture_paths)
83
  current_file = 0
84
 
 
85
  for path in mixture_paths:
86
  try:
87
  # Dosya işleme başlangıcı
 
89
  print(f"Processing file {current_file}/{total_files}")
90
 
91
  mix, sr = librosa.load(path, sr=sample_rate, mono=False)
92
+ mix_orig = mix.copy()
93
+
94
+ if 'normalize' in config.inference:
95
+ if config.inference['normalize'] is True:
96
+ mix, norm_params = normalize_audio(mix)
97
+
98
+ # Toplam işlem sürelerini izlemek için başlangıç zamanı
99
+ total_duration = 0.0
100
+ total_steps = 100.0 # Toplam %100
101
+ current_progress = 0.0
102
+
103
+ # Model yükleme ve ilk ayrıştırma (%0 -> %30)
104
+ start_time_step = time.time()
105
+ waveforms_orig = demix(config, model, mix, device, model_type=args.model_type)
106
+ step_duration = time.time() - start_time_step
107
+ total_duration += step_duration
108
+ current_progress += 30.0 * (step_duration / total_duration) if total_duration > 0 else 30.0
109
+ print(f"Progress: {min(current_progress, 30.0):.1f}%")
110
+
111
+ if args.use_tta:
112
+ # TTA işlemi (%30 -> %50)
113
+ start_time_step = time.time()
114
+ waveforms_orig = apply_tta(config, model, mix, waveforms_orig, device, args.model_type)
115
+ step_duration = time.time() - start_time_step
116
+ total_duration += step_duration
117
+ progress_increment = 20.0 * (step_duration / total_duration) if total_duration > 0 else 20.0
118
+ for i in np.arange(0.1, progress_increment + 0.1, 0.1):
119
+ current_progress = min(30.0 + i, 50.0)
120
+ time.sleep(0.001) # Küçük bir gecikme, gerçek işlem için gereksiz olabilir
121
+ print(f"Progress: {current_progress:.1f}%")
122
+
123
+ if args.demud_phaseremix_inst:
124
+ print(f"Demudding track (phase remix - instrumental): {path}")
125
+ instr = 'vocals' if 'vocals' in instruments else instruments[0]
126
+ instruments.append('instrumental_phaseremix')
127
+ if 'instrumental' not in instruments and 'Instrumental' not in instruments:
128
+ mix_modified = mix_orig - 2*waveforms_orig[instr]
129
+ mix_modified_ = mix_modified.copy()
130
+
131
+ start_time_step = time.time()
132
+ waveforms_modified = demix(config, model, mix_modified, device, model_type=args.model_type)
133
+ step_duration = time.time() - start_time_step
134
+ total_duration += step_duration
135
+ progress_increment = 10.0 * (step_duration / total_duration) if total_duration > 0 else 10.0
136
+ for i in np.arange(0.1, progress_increment + 0.1, 0.1):
137
+ current_progress = min(50.0 + i, 60.0)
138
+ time.sleep(0.001)
139
+ print(f"Progress: {current_progress:.1f}%")
140
+
141
+ if args.use_tta:
142
+ start_time_step = time.time()
143
+ waveforms_modified = apply_tta(config, model, mix_modified, waveforms_modified, device, args.model_type)
144
+ step_duration = time.time() - start_time_step
145
+ total_duration += step_duration
146
+ progress_increment = 10.0 * (step_duration / total_duration) if total_duration > 0 else 10.0
147
+ for i in np.arange(0.1, progress_increment + 0.1, 0.1):
148
+ current_progress = min(60.0 + i, 70.0)
149
+ time.sleep(0.001)
150
+ print(f"Progress: {current_progress:.1f}%")
151
+
152
+ waveforms_orig['instrumental_phaseremix'] = mix_orig + waveforms_modified[instr]
153
+ else:
154
+ mix_modified = 2*waveforms_orig[instr] - mix_orig
155
+ mix_modified_ = mix_modified.copy()
156
+
157
+ start_time_step = time.time()
158
+ waveforms_modified = demix(config, model, mix_modified, device, model_type=args.model_type)
159
+ step_duration = time.time() - start_time_step
160
+ total_duration += step_duration
161
+ progress_increment = 10.0 * (step_duration / total_duration) if total_duration > 0 else 10.0
162
+ for i in np.arange(0.1, progress_increment + 0.1, 0.1):
163
+ current_progress = min(50.0 + i, 60.0)
164
+ time.sleep(0.001)
165
+ print(f"Progress: {current_progress:.1f}%")
166
+
167
+ if args.use_tta:
168
+ start_time_step = time.time()
169
+ waveforms_modified = apply_tta(config, model, mix_modified, waveforms_orig, device, args.model_type)
170
+ step_duration = time.time() - start_time_step
171
+ total_duration += step_duration
172
+ progress_increment = 10.0 * (step_duration / total_duration) if total_duration > 0 else 10.0
173
+ for i in np.arange(0.1, progress_increment + 0.1, 0.1):
174
+ current_progress = min(60.0 + i, 70.0)
175
+ time.sleep(0.001)
176
+ print(f"Progress: {current_progress:.1f}%")
177
+
178
+ waveforms_orig['instrumental_phaseremix'] = mix_orig + mix_modified_ - waveforms_modified[instr]
179
+ current_progress = 70.0
180
+
181
+ if args.extract_instrumental:
182
+ instr = 'vocals' if 'vocals' in instruments else instruments[0]
183
+ waveforms_orig['instrumental'] = mix_orig - waveforms_orig[instr]
184
+ if 'instrumental' not in instruments:
185
+ instruments.append('instrumental')
186
+
187
+ # Dosya yazma ve finalize (%70 -> %100)
188
+ start_time_step = time.time()
189
+ for instr in instruments:
190
+ estimates = waveforms_orig[instr]
191
+ if 'normalize' in config.inference:
192
+ if config.inference['normalize'] is True:
193
+ estimates = denormalize_audio(estimates, norm_params)
194
+
195
+ # Dosya formatı ve PCM türü belirleme
196
+ is_float = getattr(args, 'export_format', '').startswith('wav FLOAT')
197
+ codec = 'flac' if getattr(args, 'flac_file', False) else 'wav'
198
+
199
+ # Subtype belirleme
200
+ if codec == 'flac':
201
+ subtype = get_soundfile_subtype(args.pcm_type, is_float)
202
+ else:
203
+ subtype = get_soundfile_subtype('FLOAT', is_float)
204
+
205
+ shortened_filename = shorten_filename(os.path.basename(path))
206
+ output_filename = f"{shortened_filename}_{instr}.{codec}"
207
+ output_path = os.path.join(args.store_dir, output_filename)
208
+
209
+ sf.write(output_path, estimates.T, sr, subtype=subtype)
210
+ step_duration = time.time() - start_time_step
211
+ total_duration += step_duration
212
+ progress_increment = 20.0 * (step_duration / total_duration) if total_duration > 0 else 20.0
213
+ for i in np.arange(0.1, progress_increment + 0.1, 0.1):
214
+ current_progress = min(70.0 + i, 90.0)
215
+ time.sleep(0.001)
216
+ print(f"Progress: {current_progress:.1f}%")
217
+
218
+ # Finalize (%90 -> %100)
219
+ start_time_step = time.time()
220
+ time.sleep(0.1) # Finalize için küçük bir bekleme (gerçek işlem süresiyle değiştirilebilir)
221
+ step_duration = time.time() - start_time_step
222
+ total_duration += step_duration
223
+ progress_increment = 10.0 * (step_duration / total_duration) if total_duration > 0 else 10.0
224
+ for i in np.arange(0.1, progress_increment + 0.1, 0.1):
225
+ current_progress = min(90.0 + i, 100.0)
226
+ time.sleep(0.001)
227
+ print(f"Progress: {current_progress:.1f}%")
228
+
229
  except Exception as e:
230
  print(f'Cannot read track: {path}')
231
  print(f'Error message: {str(e)}')
232
  continue
233
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
234
  print(f"Elapsed time: {time.time() - start_time:.2f} seconds.")
235
 
236
  def proc_folder(args):
 
310
 
311
 
312
  if __name__ == "__main__":
313
+ proc_folder(None)