alexwengg committed on
Commit 0709c81 · verified · 1 Parent(s): d35a43b

Delete main2.py

Files changed (1)
  1. main2.py +0 -289
main2.py DELETED
@@ -1,289 +0,0 @@
- #!/usr/bin/env python3
- """
- Optimal VAD Implementation using RNN Decoder + Fixed Classifier
-
- This uses the best combination discovered:
- - silero_rnn_decoder.mlmodel (proper output magnitudes)
- - correct_classifier_conv1d.mlpackage (fixed Conv1d)
- """
-
- import os
- import librosa
- import coremltools as ct
- import numpy as np
-
-
- class OptimalCoreMLVAD:
-     """
-     Optimal VAD using RNN Decoder + Fixed Classifier
-     """
-     def __init__(self):
-         """Initialize the VAD pipeline with optimal models"""
-         print("Loading Optimal CoreML models...")
-
-         # Load existing preprocessing models with explicit ANE preference
-         self.stft_model = ct.models.MLModel("silero_stft.mlmodel", compute_units=ct.ComputeUnit.ALL)
-         self.encoder_model = ct.models.MLModel("silero_encoder.mlmodel", compute_units=ct.ComputeUnit.ALL)
-
-         # Load OPTIMAL combination with ANE preference
-         self.rnn_model = ct.models.MLModel("silero_rnn_decoder.mlmodel", compute_units=ct.ComputeUnit.ALL)
-         self.classifier_model = ct.models.MLModel("correct_classifier_conv1d.mlpackage", compute_units=ct.ComputeUnit.ALL)
-
-         print("✅ Optimal models loaded:")
-         print(" - STFT: silero_stft.mlmodel")
-         print(" - Encoder: silero_encoder.mlmodel")
-         print(" - RNN: silero_rnn_decoder.mlmodel (🥇 BEST)")
-         print(" - Classifier: correct_classifier_conv1d.mlpackage (🔧 FIXED)")
-         print("🧠 All models configured for Neural Engine (ANE) acceleration")
-
-         # Initialize state for RNN Decoder (requires 3D states)
-         self.h_state = np.zeros((1, 1, 128), dtype=np.float32)
-         self.c_state = np.zeros((1, 1, 128), dtype=np.float32)
-
-         # Initialize feature buffer for temporal context
-         self.feature_buffer = []
-
-         print("✅ Optimal VAD loaded successfully!")
-
-     def reset_state(self):
-         """Reset the RNN state and feature buffer"""
-         self.h_state = np.zeros((1, 1, 128), dtype=np.float32)
-         self.c_state = np.zeros((1, 1, 128), dtype=np.float32)
-
-         if hasattr(self, 'feature_buffer'):
-             self.feature_buffer = []
-
-     def process_chunk(self, audio_chunk):
-         """Process audio chunk using optimal model combination"""
-         # Ensure correct shape
-         if audio_chunk.ndim == 1:
-             audio_chunk = audio_chunk.reshape(1, -1)
-
-         # STFT processing
-         stft_result = self.stft_model.predict({"audio_input": audio_chunk})
-         stft_output_key = list(stft_result.keys())[0]
-         stft_features = stft_result[stft_output_key]
-
-         # Temporal context management
-         if not hasattr(self, 'feature_buffer'):
-             self.feature_buffer = []
-
-         # Add current features to buffer
-         self.feature_buffer.append(stft_features)
-
-         # Keep only the last 4 frames for temporal context
-         if len(self.feature_buffer) > 4:
-             self.feature_buffer = self.feature_buffer[-4:]
-
-         # Pad with zeros if we have less than 4 frames
-         while len(self.feature_buffer) < 4:
-             self.feature_buffer.insert(0, np.zeros_like(stft_features))
-
-         # Concatenate along time dimension
-         stft_features = np.concatenate(self.feature_buffer, axis=-1)
-
-         # Encoder processing
-         encoder_result = self.encoder_model.predict({"stft_features": stft_features})
-         encoder_output_key = list(encoder_result.keys())[0]
-         encoder_features = encoder_result[encoder_output_key]
-
-         # Reshape encoder features for RNN
-         encoder_features = np.transpose(encoder_features, (0, 2, 1))  # (1, T, 64)
-
-         # Take only the last 4 timesteps
-         if encoder_features.shape[1] > 4:
-             encoder_features = encoder_features[:, -4:, :]
-         elif encoder_features.shape[1] < 4:
-             # Pad with zeros if needed
-             padding = 4 - encoder_features.shape[1]
-             pad_shape = (encoder_features.shape[0], padding, encoder_features.shape[2])
-             encoder_features = np.concatenate([np.zeros(pad_shape), encoder_features], axis=1)
-
-         # Ensure the feature dimension is 128 for RNN
-         if encoder_features.shape[2] != 128:
-             # Resize/pad to 128 dimensions
-             if encoder_features.shape[2] > 128:
-                 encoder_features = encoder_features[:, :, :128]
-             else:
-                 padding = 128 - encoder_features.shape[2]
-                 pad_shape = (encoder_features.shape[0], encoder_features.shape[1], padding)
-                 encoder_features = np.concatenate([encoder_features, np.zeros(pad_shape)], axis=2)
-
-         # RNN Decoder processing with proper state management
-         rnn_result = self.rnn_model.predict({
-             "encoder_features": encoder_features,
-             "h_in": self.h_state,
-             "c_in": self.c_state
-         })
-
-         # Extract RNN Decoder outputs properly
-         rnn_features = None
-         new_h_state = None
-         new_c_state = None
-
-         # RNN Decoder has specific output names - find them by shape
-         for key, value in rnn_result.items():
-             if len(value.shape) == 3 and value.shape[1] > 1:  # Sequence output
-                 rnn_features = value
-             elif len(value.shape) == 3 and value.shape == (1, 1, 128):  # State outputs
-                 if new_h_state is None:
-                     new_h_state = value
-                 else:
-                     new_c_state = value
-
-         # Update states for next chunk
-         if new_h_state is not None:
-             self.h_state = new_h_state
-         if new_c_state is not None:
-             self.c_state = new_c_state
-
-         # Ensure we have the sequence output
-         if rnn_features is None:
-             raise RuntimeError("Could not find RNN sequence output")
-
-         # Ensure correct shape for classifier (1, 4, 128)
-         if rnn_features.shape != (1, 4, 128):
-             if rnn_features.shape[1] != 4:
-                 if rnn_features.shape[1] > 4:
-                     rnn_features = rnn_features[:, -4:, :]
-                 else:
-                     last_timestep = rnn_features[:, -1:, :]
-                     padding_needed = 4 - rnn_features.shape[1]
-                     padding = np.repeat(last_timestep, padding_needed, axis=1)
-                     rnn_features = np.concatenate([rnn_features, padding], axis=1)
-
-             if rnn_features.shape[2] != 128:
-                 if rnn_features.shape[2] > 128:
-                     rnn_features = rnn_features[:, :, :128]
-                 else:
-                     padding = 128 - rnn_features.shape[2]
-                     pad_shape = (rnn_features.shape[0], rnn_features.shape[1], padding)
-                     rnn_features = np.concatenate([rnn_features, np.zeros(pad_shape)], axis=2)
-
-         # Classifier processing with fixed Conv1d model (clean output!)
-         classifier_result = self.classifier_model.predict({"rnn_features": rnn_features})
-         classifier_output_key = list(classifier_result.keys())[0]
-         vad_prob = float(classifier_result[classifier_output_key].squeeze())
-
-         return vad_prob
-
-
- def process_file(filename, vad, sample_rate=16000, chunk_size=512, threshold=0.5):
-     """Process audio file with VAD and display results"""
-     print(f"\n🎧 Processing: {filename}")
-
-     # Reset state for new file
-     vad.reset_state()
-
-     # Load audio
-     y, _ = librosa.load(filename, sr=sample_rate)
-     if y.ndim > 1:
-         y = librosa.to_mono(y)
-
-     num_chunks = len(y) // chunk_size
-     vad_scores = []
-
-     for i in range(num_chunks):
-         start = i * chunk_size
-         end = start + chunk_size
-         chunk = y[start:end]
-         if len(chunk) < chunk_size:
-             break  # Skip last short chunk
-
-         prob = vad.process_chunk(chunk.astype(np.float32))
-         vad_scores.append(prob)
-
-     # Average VAD probability across all chunks
-     avg_vad = np.mean(vad_scores) if vad_scores else 0.0
-     status = "🟢 Speech" if avg_vad >= threshold else "⚫️ Silence"
-
-     print(f"{os.path.basename(filename):<18} | Avg VAD: {avg_vad:.4f} | {status}")
-
-
- def test_optimal_vad():
-     """Test the optimal VAD implementation"""
-     print("🚀 Testing OPTIMAL VAD Implementation")
-     print("=" * 60)
-     print("🥇 Using BEST model combination:")
-     print(" - RNN: silero_rnn_decoder.mlmodel")
-     print(" - Classifier: correct_classifier_conv1d.mlpackage")
-     print()
-
-     vad = OptimalCoreMLVAD()
-
-     test_folder = "test"
-     if not os.path.exists(test_folder):
-         print(f"❌ Test folder '{test_folder}' not found!")
-         return
-
-     test_files = sorted(f for f in os.listdir(test_folder) if f.endswith(".mp3"))
-
-     if not test_files:
-         print(f"❌ No MP3 files found in '{test_folder}' folder!")
-         return
-
-     print(f"{'File':<18} | {'VAD Score':<9} | {'Result'}")
-     print("-" * 50)
-
-     human_scores = []
-     ambient_scores = []
-
-     for file in test_files:
-         full_path = os.path.join(test_folder, file)
-
-         # Capture the score for analysis
-         vad.reset_state()
-         y, _ = librosa.load(full_path, sr=16000)
-         if y.ndim > 1:
-             y = librosa.to_mono(y)
-
-         chunk_size = 512
-         num_chunks = min(10, len(y) // chunk_size)
-         vad_scores = []
-
-         for i in range(num_chunks):
-             start = i * chunk_size
-             end = start + chunk_size
-             chunk = y[start:end]
-             if len(chunk) < chunk_size:
-                 break
-             prob = vad.process_chunk(chunk.astype(np.float32))
-             vad_scores.append(prob)
-
-         avg_vad = np.mean(vad_scores) if vad_scores else 0.0
-
-         # Categorize for analysis
-         if "human" in file:
-             human_scores.append(avg_vad)
-         elif "ambient" in file:
-             ambient_scores.append(avg_vad)
-
-         # Display result
-         status = "🟢 Speech" if avg_vad >= 0.5 else "⚫️ Silence"
-         print(f"{os.path.basename(file):<18} | {avg_vad:.4f} | {status}")
-
-     # Analysis
-     if human_scores and ambient_scores:
-         human_avg = np.mean(human_scores)
-         ambient_avg = np.mean(ambient_scores)
-         separation = human_avg - ambient_avg
-
-         print(f"\n📊 PERFORMANCE ANALYSIS:")
-         print(f" 👤 Human average: {human_avg:.4f}")
-         print(f" 🌿 Ambient average: {ambient_avg:.4f}")
-         print(f" 📈 Separation: {separation:.4f}")
-
-         if separation > 0.05:
-             print(f" ✅ EXCELLENT: Strong separation")
-         elif separation > 0.01:
-             print(f" ✅ GOOD: Clear separation")
-         elif separation > 0:
-             print(f" ⚠️ WEAK: Small separation")
-         else:
-             print(f" ❌ POOR: No separation or inverted")
-
-     print("\n✅ Optimal VAD testing completed!")
-
-
- if __name__ == "__main__":
-     test_optimal_vad()
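
For reference, a minimal usage sketch of the deleted module, assuming main2.py is restored locally, the four CoreML model files sit in the working directory, and a hypothetical clip test/human_sample.mp3 exists:

from main2 import OptimalCoreMLVAD, process_file

# Loads the STFT, encoder, RNN decoder, and classifier CoreML models (hypothetical local setup)
vad = OptimalCoreMLVAD()

# Prints the average VAD probability and a Speech/Silence verdict for the clip
process_file("test/human_sample.mp3", vad, sample_rate=16000, chunk_size=512, threshold=0.5)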