Update inference.py
Browse files- inference.py +139 -68
inference.py
CHANGED
@@ -78,11 +78,10 @@ def run_folder(model, args, config, device, verbose: bool = False):
|
|
78 |
instruments = prefer_target_instrument(config)[:]
|
79 |
os.makedirs(args.store_dir, exist_ok=True)
|
80 |
|
81 |
-
#
|
82 |
total_files = len(mixture_paths)
|
83 |
current_file = 0
|
84 |
|
85 |
-
# Progress tracking
|
86 |
for path in mixture_paths:
|
87 |
try:
|
88 |
# Start of file processing
|
@@ -90,76 +89,148 @@ def run_folder(model, args, config, device, verbose: bool = False):
|
|
90 |
print(f"Processing file {current_file}/{total_files}")
|
91 |
|
92 |
mix, sr = librosa.load(path, sr=sample_rate, mono=False)
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
93 |
except Exception as e:
|
94 |
print(f'Cannot read track: {path}')
|
95 |
print(f'Error message: {str(e)}')
|
96 |
continue
|
97 |
|
98 |
-
mix_orig = mix.copy()
|
99 |
-
if 'normalize' in config.inference:
|
100 |
-
if config.inference['normalize'] is True:
|
101 |
-
mix, norm_params = normalize_audio(mix)
|
102 |
-
|
103 |
-
waveforms_orig = demix(config, model, mix, device, model_type=args.model_type)
|
104 |
-
|
105 |
-
if args.use_tta:
|
106 |
-
waveforms_orig = apply_tta(config, model, mix, waveforms_orig, device, args.model_type)
|
107 |
-
|
108 |
-
if args.demud_phaseremix_inst:
|
109 |
-
print(f"Demudding track (phase remix - instrumental): {path}")
|
110 |
-
instr = 'vocals' if 'vocals' in instruments else instruments[0]
|
111 |
-
instruments.append('instrumental_phaseremix')
|
112 |
-
if 'instrumental' not in instruments and 'Instrumental' not in instruments:
|
113 |
-
mix_modified = mix_orig - 2*waveforms_orig[instr]
|
114 |
-
mix_modified_ = mix_modified.copy()
|
115 |
-
|
116 |
-
waveforms_modified = demix(config, model, mix_modified, device, model_type=args.model_type)
|
117 |
-
if args.use_tta:
|
118 |
-
waveforms_modified = apply_tta(config, model, mix_modified, waveforms_modified, device, args.model_type)
|
119 |
-
|
120 |
-
waveforms_orig['instrumental_phaseremix'] = mix_orig + waveforms_modified[instr]
|
121 |
-
else:
|
122 |
-
mix_modified = 2*waveforms_orig[instr] - mix_orig
|
123 |
-
mix_modified_ = mix_modified.copy()
|
124 |
-
|
125 |
-
waveforms_modified = demix(config, model, mix_modified, device, model_type=args.model_type)
|
126 |
-
if args.use_tta:
|
127 |
-
waveforms_modified = apply_tta(config, model, mix_modified, waveforms_orig, device, args.model_type)
|
128 |
-
|
129 |
-
waveforms_orig['instrumental_phaseremix'] = mix_orig + mix_modified_ - waveforms_modified[instr]
|
130 |
-
|
131 |
-
if args.extract_instrumental:
|
132 |
-
instr = 'vocals' if 'vocals' in instruments else instruments[0]
|
133 |
-
waveforms_orig['instrumental'] = mix_orig - waveforms_orig[instr]
|
134 |
-
if 'instrumental' not in instruments:
|
135 |
-
instruments.append('instrumental')
|
136 |
-
|
137 |
-
for instr in instruments:
|
138 |
-
estimates = waveforms_orig[instr]
|
139 |
-
if 'normalize' in config.inference:
|
140 |
-
if config.inference['normalize'] is True:
|
141 |
-
estimates = denormalize_audio(estimates, norm_params)
|
142 |
-
|
143 |
-
# Determine the file format and PCM type
|
144 |
-
is_float = getattr(args, 'export_format', '').startswith('wav FLOAT')
|
145 |
-
codec = 'flac' if getattr(args, 'flac_file', False) else 'wav'
|
146 |
-
|
147 |
-
# Determine the subtype
|
148 |
-
if codec == 'flac':
|
149 |
-
subtype = get_soundfile_subtype(args.pcm_type, is_float)
|
150 |
-
else:
|
151 |
-
subtype = get_soundfile_subtype('FLOAT', is_float)
|
152 |
-
|
153 |
-
shortened_filename = shorten_filename(os.path.basename(path))
|
154 |
-
output_filename = f"{shortened_filename}_{instr}.{codec}"
|
155 |
-
output_path = os.path.join(args.store_dir, output_filename)
|
156 |
-
|
157 |
-
sf.write(output_path, estimates.T, sr, subtype=subtype)
|
158 |
-
|
159 |
-
# Compute the progress percentage
|
160 |
-
progress_percent = int((current_file / total_files) * 100)
|
161 |
-
print(f"Progress: {progress_percent}%")
|
162 |
-
|
163 |
print(f"Elapsed time: {time.time() - start_time:.2f} seconds.")
|
164 |
|
165 |
def proc_folder(args):
|
@@ -239,4 +310,4 @@ def proc_folder(args):
|
|
239 |
|
240 |
|
241 |
if __name__ == "__main__":
|
242 |
-
proc_folder(None)
|
|
|
78 |
instruments = prefer_target_instrument(config)[:]
|
79 |
os.makedirs(args.store_dir, exist_ok=True)
|
80 |
|
81 |
+
# Progress tracking
|
82 |
total_files = len(mixture_paths)
|
83 |
current_file = 0
|
84 |
|
|
|
85 |
for path in mixture_paths:
|
86 |
try:
|
87 |
# Start of file processing
|
|
|
89 |
print(f"Processing file {current_file}/{total_files}")
|
90 |
|
91 |
mix, sr = librosa.load(path, sr=sample_rate, mono=False)
|
92 |
+
mix_orig = mix.copy()
|
93 |
+
|
94 |
+
if 'normalize' in config.inference:
|
95 |
+
if config.inference['normalize'] is True:
|
96 |
+
mix, norm_params = normalize_audio(mix)
|
97 |
+
|
98 |
+
# Starting point for tracking the total processing time
|
99 |
+
total_duration = 0.0
|
100 |
+
total_steps = 100.0 # Toplam %100
|
101 |
+
current_progress = 0.0
|
102 |
+
|
103 |
+
# Model loading and initial separation (0% -> 30%)
|
104 |
+
start_time_step = time.time()
|
105 |
+
waveforms_orig = demix(config, model, mix, device, model_type=args.model_type)
|
106 |
+
step_duration = time.time() - start_time_step
|
107 |
+
total_duration += step_duration
|
108 |
+
current_progress += 30.0 * (step_duration / total_duration) if total_duration > 0 else 30.0
|
109 |
+
print(f"Progress: {min(current_progress, 30.0):.1f}%")
|
110 |
+
|
111 |
+
if args.use_tta:
|
112 |
+
# TTA step (30% -> 50%)
|
113 |
+
start_time_step = time.time()
|
114 |
+
waveforms_orig = apply_tta(config, model, mix, waveforms_orig, device, args.model_type)
|
115 |
+
step_duration = time.time() - start_time_step
|
116 |
+
total_duration += step_duration
|
117 |
+
progress_increment = 20.0 * (step_duration / total_duration) if total_duration > 0 else 20.0
|
118 |
+
for i in np.arange(0.1, progress_increment + 0.1, 0.1):
|
119 |
+
current_progress = min(30.0 + i, 50.0)
|
120 |
+
time.sleep(0.001) # Küçük bir gecikme, gerçek işlem için gereksiz olabilir
|
121 |
+
print(f"Progress: {current_progress:.1f}%")
|
122 |
+
|
123 |
+
if args.demud_phaseremix_inst:
|
124 |
+
print(f"Demudding track (phase remix - instrumental): {path}")
|
125 |
+
instr = 'vocals' if 'vocals' in instruments else instruments[0]
|
126 |
+
instruments.append('instrumental_phaseremix')
|
127 |
+
if 'instrumental' not in instruments and 'Instrumental' not in instruments:
|
128 |
+
mix_modified = mix_orig - 2*waveforms_orig[instr]
|
129 |
+
mix_modified_ = mix_modified.copy()
|
130 |
+
|
131 |
+
start_time_step = time.time()
|
132 |
+
waveforms_modified = demix(config, model, mix_modified, device, model_type=args.model_type)
|
133 |
+
step_duration = time.time() - start_time_step
|
134 |
+
total_duration += step_duration
|
135 |
+
progress_increment = 10.0 * (step_duration / total_duration) if total_duration > 0 else 10.0
|
136 |
+
for i in np.arange(0.1, progress_increment + 0.1, 0.1):
|
137 |
+
current_progress = min(50.0 + i, 60.0)
|
138 |
+
time.sleep(0.001)
|
139 |
+
print(f"Progress: {current_progress:.1f}%")
|
140 |
+
|
141 |
+
if args.use_tta:
|
142 |
+
start_time_step = time.time()
|
143 |
+
waveforms_modified = apply_tta(config, model, mix_modified, waveforms_modified, device, args.model_type)
|
144 |
+
step_duration = time.time() - start_time_step
|
145 |
+
total_duration += step_duration
|
146 |
+
progress_increment = 10.0 * (step_duration / total_duration) if total_duration > 0 else 10.0
|
147 |
+
for i in np.arange(0.1, progress_increment + 0.1, 0.1):
|
148 |
+
current_progress = min(60.0 + i, 70.0)
|
149 |
+
time.sleep(0.001)
|
150 |
+
print(f"Progress: {current_progress:.1f}%")
|
151 |
+
|
152 |
+
waveforms_orig['instrumental_phaseremix'] = mix_orig + waveforms_modified[instr]
|
153 |
+
else:
|
154 |
+
mix_modified = 2*waveforms_orig[instr] - mix_orig
|
155 |
+
mix_modified_ = mix_modified.copy()
|
156 |
+
|
157 |
+
start_time_step = time.time()
|
158 |
+
waveforms_modified = demix(config, model, mix_modified, device, model_type=args.model_type)
|
159 |
+
step_duration = time.time() - start_time_step
|
160 |
+
total_duration += step_duration
|
161 |
+
progress_increment = 10.0 * (step_duration / total_duration) if total_duration > 0 else 10.0
|
162 |
+
for i in np.arange(0.1, progress_increment + 0.1, 0.1):
|
163 |
+
current_progress = min(50.0 + i, 60.0)
|
164 |
+
time.sleep(0.001)
|
165 |
+
print(f"Progress: {current_progress:.1f}%")
|
166 |
+
|
167 |
+
if args.use_tta:
|
168 |
+
start_time_step = time.time()
|
169 |
+
waveforms_modified = apply_tta(config, model, mix_modified, waveforms_orig, device, args.model_type)
|
170 |
+
step_duration = time.time() - start_time_step
|
171 |
+
total_duration += step_duration
|
172 |
+
progress_increment = 10.0 * (step_duration / total_duration) if total_duration > 0 else 10.0
|
173 |
+
for i in np.arange(0.1, progress_increment + 0.1, 0.1):
|
174 |
+
current_progress = min(60.0 + i, 70.0)
|
175 |
+
time.sleep(0.001)
|
176 |
+
print(f"Progress: {current_progress:.1f}%")
|
177 |
+
|
178 |
+
waveforms_orig['instrumental_phaseremix'] = mix_orig + mix_modified_ - waveforms_modified[instr]
|
179 |
+
current_progress = 70.0
|
180 |
+
|
181 |
+
if args.extract_instrumental:
|
182 |
+
instr = 'vocals' if 'vocals' in instruments else instruments[0]
|
183 |
+
waveforms_orig['instrumental'] = mix_orig - waveforms_orig[instr]
|
184 |
+
if 'instrumental' not in instruments:
|
185 |
+
instruments.append('instrumental')
|
186 |
+
|
187 |
+
# File writing and finalization (70% -> 100%)
|
188 |
+
start_time_step = time.time()
|
189 |
+
for instr in instruments:
|
190 |
+
estimates = waveforms_orig[instr]
|
191 |
+
if 'normalize' in config.inference:
|
192 |
+
if config.inference['normalize'] is True:
|
193 |
+
estimates = denormalize_audio(estimates, norm_params)
|
194 |
+
|
195 |
+
# Determine the file format and PCM type
|
196 |
+
is_float = getattr(args, 'export_format', '').startswith('wav FLOAT')
|
197 |
+
codec = 'flac' if getattr(args, 'flac_file', False) else 'wav'
|
198 |
+
|
199 |
+
# Determine the subtype
|
200 |
+
if codec == 'flac':
|
201 |
+
subtype = get_soundfile_subtype(args.pcm_type, is_float)
|
202 |
+
else:
|
203 |
+
subtype = get_soundfile_subtype('FLOAT', is_float)
|
204 |
+
|
205 |
+
shortened_filename = shorten_filename(os.path.basename(path))
|
206 |
+
output_filename = f"{shortened_filename}_{instr}.{codec}"
|
207 |
+
output_path = os.path.join(args.store_dir, output_filename)
|
208 |
+
|
209 |
+
sf.write(output_path, estimates.T, sr, subtype=subtype)
|
210 |
+
step_duration = time.time() - start_time_step
|
211 |
+
total_duration += step_duration
|
212 |
+
progress_increment = 20.0 * (step_duration / total_duration) if total_duration > 0 else 20.0
|
213 |
+
for i in np.arange(0.1, progress_increment + 0.1, 0.1):
|
214 |
+
current_progress = min(70.0 + i, 90.0)
|
215 |
+
time.sleep(0.001)
|
216 |
+
print(f"Progress: {current_progress:.1f}%")
|
217 |
+
|
218 |
+
# Finalize (%90 -> %100)
|
219 |
+
start_time_step = time.time()
|
220 |
+
time.sleep(0.1) # Finalize için küçük bir bekleme (gerçek işlem süresiyle değiştirilebilir)
|
221 |
+
step_duration = time.time() - start_time_step
|
222 |
+
total_duration += step_duration
|
223 |
+
progress_increment = 10.0 * (step_duration / total_duration) if total_duration > 0 else 10.0
|
224 |
+
for i in np.arange(0.1, progress_increment + 0.1, 0.1):
|
225 |
+
current_progress = min(90.0 + i, 100.0)
|
226 |
+
time.sleep(0.001)
|
227 |
+
print(f"Progress: {current_progress:.1f}%")
|
228 |
+
|
229 |
except Exception as e:
|
230 |
print(f'Cannot read track: {path}')
|
231 |
print(f'Error message: {str(e)}')
|
232 |
continue
|
233 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
234 |
print(f"Elapsed time: {time.time() - start_time:.2f} seconds.")
|
235 |
|
236 |
def proc_folder(args):
|
|
|
310 |
|
311 |
|
312 |
if __name__ == "__main__":
|
313 |
+
proc_folder(None)
|