abdu-l7hman committed
Commit 1154abd · 0 Parent(s)

Initial commit with model and app

.dockerignore ADDED
@@ -0,0 +1,10 @@
+ # .dockerignore
+
+ __pycache__/
+ *.pyc
+ *.pyo
+ *.pyd
+ .DS_Store
+ .env
+ venv/
+ .git/
.gitattributes ADDED
@@ -0,0 +1 @@
+ models/*.pth filter=lfs diff=lfs merge=lfs -text
Dockerfile ADDED
@@ -0,0 +1,32 @@
+ # Dockerfile
+
+ # 1. Base image
+ FROM python:3.9-slim
+
+ # 2. Install system dependencies (ffmpeg, libsndfile)
+ RUN apt-get update && apt-get install -y --no-install-recommends \
+     libsndfile1 \
+     ffmpeg \
+     && rm -rf /var/lib/apt/lists/*
+
+ # 3. Create a non-root user (required for Hugging Face Spaces)
+ RUN useradd -m -u 1000 user
+ USER user
+ ENV PATH="/home/user/.local/bin:$PATH"
+
+ # 4. Set the working directory
+ WORKDIR /home/user/app
+
+ # 5. Copy requirements and install dependencies
+ COPY --chown=user requirements.txt .
+ RUN pip install --no-cache-dir --timeout=600 -r requirements.txt
+
+ # 6. Copy the application code
+ COPY --chown=user . .
+
+ # 7. Expose the port Hugging Face Spaces expects
+ EXPOSE 7860
+
+ # 8. Run with Gunicorn on port 7860
+ # Note: timeout increased to 120 s because audio processing on CPU can be slow
+ CMD ["gunicorn", "--bind", "0.0.0.0:7860", "--workers", "1", "--timeout", "120", "app:app"]
app.py ADDED
@@ -0,0 +1,126 @@
+ from flask import Flask, render_template, request, jsonify
+ from flask_cors import CORS  # CORS support so browser clients can call the API
+ import os
+ from werkzeug.utils import secure_filename
+ import time
+ import sys
+ import subprocess
+ import shutil
+
+ # Ensure we can find the models folder
+ sys.path.append(os.path.join(os.path.dirname(__file__), 'models'))
+
+ # Lazy-import model dependencies
+ try:
+     from models.stutter_detector_local import ImprovedStutterDetector, calculate_stutter_severity
+     MODEL_DEPS_AVAILABLE = True
+ except Exception as e:
+     print(f"Model dependencies unavailable: {e}")
+     ImprovedStutterDetector = None
+     def calculate_stutter_severity(_):
+         return None
+     MODEL_DEPS_AVAILABLE = False
+
+ app = Flask(__name__)
+ CORS(app)  # enable CORS for all routes
+
+ # Config
+ UPLOAD_FOLDER = '/tmp/uploads'  # Use /tmp because other folders might be read-only on HF
+ ALLOWED_EXTENSIONS = {'wav', 'mp3', 'ogg', 'webm', 'm4a'}
+ MODEL_PATH = 'models/stutter_detector_all_types.pth'
+
+ app.config['UPLOAD_FOLDER'] = UPLOAD_FOLDER
+ app.config['MAX_CONTENT_LENGTH'] = 16 * 1024 * 1024  # 16 MB upload limit
+
+ os.makedirs(UPLOAD_FOLDER, exist_ok=True)
+
+ # Check for ffmpeg on PATH
+ FFMPEG_AVAILABLE = shutil.which('ffmpeg') is not None
+
+ # Load model
+ print("Loading stutter detection model...")
+ detector = None
+ if MODEL_DEPS_AVAILABLE and ImprovedStutterDetector is not None:
+     try:
+         detector = ImprovedStutterDetector(MODEL_PATH, device='cpu')  # Force CPU
+         print("Model loaded successfully!")
+     except Exception as e:
+         print(f"Error loading model: {e}")
+         detector = None
+ else:
+     print("Skipping model load due to missing dependencies.")
+
+ def allowed_file(filename):
+     return '.' in filename and filename.rsplit('.', 1)[1].lower() in ALLOWED_EXTENSIONS
+
+ def convert_to_wav(input_path, output_path):
+     """Convert any supported input to 16 kHz mono PCM WAV via ffmpeg."""
+     try:
+         subprocess.run([
+             'ffmpeg', '-i', input_path,
+             '-acodec', 'pcm_s16le',
+             '-ar', '16000',
+             '-ac', '1',
+             '-y',
+             output_path
+         ], check=True, capture_output=True)
+         return True
+     except subprocess.CalledProcessError as e:
+         print(f"FFmpeg conversion error: {e.stderr.decode()}")
+         return False
+     except FileNotFoundError:
+         return False
+
+ # The audio analysis itself is detector.analyze_audio_file, defined in
+ # models/stutter_detector_local.py.
+
+ # Routes: only the endpoints the API strictly needs
+ @app.route('/health', methods=['GET'])
+ def health_check():
+     return jsonify({'status': 'healthy', 'model_loaded': detector is not None}), 200
+
+ @app.route('/upload', methods=['POST'])
+ def upload_file():
+     if detector is None:
+         return jsonify({'error': 'Model not loaded.'}), 500
+
+     if 'audio' not in request.files:
+         return jsonify({'error': 'No audio file provided'}), 400
+
+     file = request.files['audio']
+
+     if file and allowed_file(file.filename):
+         filename = secure_filename(file.filename)
+         filepath = os.path.join(app.config['UPLOAD_FOLDER'], filename)
+         file.save(filepath)
+
+         # Get params from the request, falling back to defaults
+         segment_duration = float(request.form.get('segment_duration', 3.0))
+         stutter_threshold = float(request.form.get('stutter_threshold', 0.5))
+
+         try:
+             # Run analysis
+             results = detector.analyze_audio_file(
+                 filepath,
+                 segment_duration=segment_duration,
+                 stutter_threshold=stutter_threshold
+             )
+
+             # Calculate severity
+             results['severity'] = calculate_stutter_severity(results)
+
+             # Clean up the uploaded file
+             if os.path.exists(filepath): os.remove(filepath)
+
+             return jsonify(results), 200
+
+         except Exception as e:
+             if os.path.exists(filepath): os.remove(filepath)
+             return jsonify({'error': str(e)}), 500
+
+     return jsonify({'error': 'Invalid file'}), 400
+
+ if __name__ == '__main__':
+     # Local testing only; the Docker image runs Gunicorn instead
+     app.run(host='0.0.0.0', port=7860)
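For reference, /upload expects a multipart form with an audio file plus optional segment_duration and stutter_threshold fields. A client sketch, assuming the third-party requests library and a local sample.wav (both are assumptions; neither ships with this commit):

# upload_example.py (hypothetical client for the /upload endpoint).
import requests  # assumed installed; not in requirements.txt

with open("sample.wav", "rb") as f:  # sample.wav is a placeholder clip
    resp = requests.post(
        "http://localhost:7860/upload",
        files={"audio": ("sample.wav", f, "audio/wav")},
        data={"segment_duration": "3.0", "stutter_threshold": "0.5"},
        timeout=300,  # CPU inference on long clips can be slow
    )

resp.raise_for_status()
report = resp.json()
print(f"Stuttering: {report['stutter_percentage']:.1f}% "
      f"({report['stutter_count']}/{report['total_segments']} segments)")
print("Severity:", report["severity"]["severity_level"])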
models/__pycache__/stutter_detector_local.cpython-312.pyc ADDED
Binary file (14.8 kB).
 
models/__pycache__/stutter_detector_local.cpython-314.pyc ADDED
Binary file (16.2 kB).
 
models/stutter_detector_all_types.pth ADDED
@@ -0,0 +1,3 @@
+ version https://git-lfs.github.com/spec/v1
+ oid sha256:819effa4e2727f82295d9f7c7cd2647159ef55001c8f7eda7cb009130d45ea56
+ size 378509393
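This is a Git LFS pointer rather than the weights themselves; the ~379 MB checkpoint is fetched from LFS storage on clone or checkout. Per the LFS spec, the oid is the SHA-256 of the file content, so a download can be verified locally. A sketch (the script is not part of the commit):

# verify_weights.py (hypothetical) - check the fetched checkpoint against the pointer.
import hashlib

EXPECTED_OID = "819effa4e2727f82295d9f7c7cd2647159ef55001c8f7eda7cb009130d45ea56"

h = hashlib.sha256()
with open("models/stutter_detector_all_types.pth", "rb") as f:
    for chunk in iter(lambda: f.read(1 << 20), b""):  # read in 1 MiB chunks
        h.update(chunk)

assert h.hexdigest() == EXPECTED_OID, "checkpoint does not match the LFS pointer"
print("checkpoint OK")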
models/stutter_detector_local.py ADDED
@@ -0,0 +1,300 @@
+ # ============================================
+ # LOCAL PC STUTTER DETECTION SETUP
+ # Run this on your local machine
+ # ============================================
+
+ import os
+ import numpy as np
+ import librosa
+ import torch
+ import torch.nn as nn
+ import transformers
+ from typing import List, Tuple
+ import warnings
+ warnings.filterwarnings('ignore')
+
+ # ============================================
+ # MODEL ARCHITECTURE (MUST MATCH TRAINING)
+ # ============================================
+
+ class ImprovedWav2VecClassifier(nn.Module):
+     """Improved classifier matching the training architecture."""
+
+     def __init__(self, hidden_dim=768, intermediate_dim=256, output_dim=2, dropout=0.3):
+         super().__init__()
+
+         # Load the pre-trained Wav2Vec2 encoder
+         self.wav2vec = transformers.Wav2Vec2Model.from_pretrained('facebook/wav2vec2-base')
+
+         # Freeze Wav2Vec2 parameters; only the head is trained
+         for param in self.wav2vec.parameters():
+             param.requires_grad = False
+
+         # Classification head
+         self.classifier = nn.Sequential(
+             nn.Linear(hidden_dim, intermediate_dim),
+             nn.BatchNorm1d(intermediate_dim),
+             nn.ReLU(),
+             nn.Dropout(dropout),
+             nn.Linear(intermediate_dim, intermediate_dim // 2),
+             nn.BatchNorm1d(intermediate_dim // 2),
+             nn.ReLU(),
+             nn.Dropout(dropout),
+             nn.Linear(intermediate_dim // 2, output_dim)
+         )
+
+     def forward(self, x):
+         with torch.no_grad():
+             encoder_output = self.wav2vec(x).last_hidden_state
+         pooled_features = encoder_output.mean(dim=1)  # mean-pool over time
+         return self.classifier(pooled_features)
+
+
+ # ============================================
+ # FEATURE EXTRACTOR
+ # ============================================
+
+ class Wav2VecFeatureExtractor:
+     """Extract Wav2Vec2 input features from raw audio."""
+
+     def __init__(self, model_name='facebook/wav2vec2-base', duration=3):
+         self.processor = transformers.Wav2Vec2FeatureExtractor.from_pretrained(model_name)
+         self.duration = duration
+         self.sample_rate = 16000
+
+     def extract_features(self, audio_data, sr):
+         try:
+             if sr != self.sample_rate:
+                 audio_data = librosa.resample(audio_data, orig_sr=sr, target_sr=self.sample_rate)
+
+             features = self.processor(audio_data, sampling_rate=self.sample_rate, return_tensors='pt').input_values
+             return features.squeeze(0)
+         except Exception as e:
+             print(f"Error extracting features: {e}")
+             return None
+
+
+ # ============================================
+ # AUDIO PROCESSING FUNCTIONS
+ # ============================================
+
+ def load_audio_file(file_path: str) -> Tuple[np.ndarray, int]:
+     """Load an audio file at its native sample rate."""
+     try:
+         audio_data, sr = librosa.load(file_path, sr=None)
+         return audio_data, sr
+     except Exception as e:
+         raise Exception(f"Error loading audio file: {e}")
+
+
+ def segment_audio(audio_data: np.ndarray, sr: int, segment_duration: float = 3.0) -> List[np.ndarray]:
+     """Split audio into fixed-duration segments."""
+     segment_samples = int(segment_duration * sr)
+     segments = []
+
+     for i in range(0, len(audio_data), segment_samples):
+         segment = audio_data[i:i + segment_samples]
+
+         if len(segment) >= sr:  # Keep segments of at least 1 second
+             if len(segment) < segment_samples:
+                 padding = segment_samples - len(segment)
+                 segment = np.pad(segment, (0, padding), mode='constant')
+             segments.append(segment)
+
+     return segments
+
+
+ def pad_or_truncate_features(features: torch.Tensor, max_length: int = 32007) -> torch.Tensor:
+     """Pad or truncate features to match the expected input length."""
+     if features.size(0) < max_length:
+         padding = max_length - features.size(0)
+         features = torch.cat([features, torch.zeros(padding)], dim=0)
+     elif features.size(0) > max_length:
+         features = features[:max_length]
+     return features
+
+
+ # ============================================
+ # STUTTER DETECTOR CLASS
+ # ============================================
+
+ class ImprovedStutterDetector:
+     """Stutter detector for all types: prolongations, blocks, repetitions, interjections."""
+
+     def __init__(self, model_path: str, device: str = None):
+         # Set device
+         if device is None:
+             self.device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
+         else:
+             self.device = torch.device(device)
+
+         print(f"Using device: {self.device}")
+
+         # Load model weights
+         print("Loading model...")
+         self.model = ImprovedWav2VecClassifier()
+         self.model.load_state_dict(torch.load(model_path, map_location=self.device))
+         self.model.to(self.device)
+         self.model.eval()
+         print("✓ Model loaded successfully!")
+
+         # Initialize feature extractor
+         self.feature_extractor = Wav2VecFeatureExtractor(duration=3)
+
+         # Class names
+         self.class_names = ['No Stutter', 'Stutter (All Types)']
+
+         print("\nThis model detects ALL stutter types:")
+         print("  • Prolongations (ssssso)")
+         print("  • Blocks (getting stuck)")
+         print("  • Sound Repetitions (b-b-ball)")
+         print("  • Word Repetitions (I-I-I want)")
+         print("  • Interjections (um, uh)")
+
+     def analyze_audio_file(self, file_path: str, segment_duration: float = 3.0,
+                            stutter_threshold: float = 0.5, show_probabilities: bool = True) -> dict:
+         """Analyze an entire audio file for stuttering."""
+
+         print(f"\n{'='*70}")
+         print(f"ANALYZING: {os.path.basename(file_path)}")
+         print(f"{'='*70}")
+
+         # Load audio
+         audio_data, sr = load_audio_file(file_path)
+         duration = len(audio_data) / sr
+         print(f"📊 Audio duration: {duration:.2f} seconds")
+
+         # Segment audio
+         segments = segment_audio(audio_data, sr, segment_duration)
+         print(f"📊 Number of segments: {len(segments)}")
+
+         if len(segments) == 0:
+             return {'error': 'Audio too short for analysis (minimum 1 second required)'}
+
+         # Analyze each segment
+         results = []
+         stutter_count = 0
+
+         print(f"\n{'='*70}")
+         print("SEGMENT ANALYSIS")
+         print(f"{'='*70}")
+
+         for i, segment in enumerate(segments):
+             features = self.feature_extractor.extract_features(segment, sr)
+             if features is None:
+                 continue
+
+             features = pad_or_truncate_features(features)
+             features = features.unsqueeze(0).to(self.device)
+
+             with torch.no_grad():
+                 outputs = self.model(features)
+                 probabilities = torch.softmax(outputs, dim=1)
+                 predicted_class = torch.argmax(probabilities, dim=1).item()
+                 confidence = probabilities[0][predicted_class].item()
+
+             no_stutter_prob = probabilities[0][0].item()
+             stutter_prob = probabilities[0][1].item()
+
+             is_stutter = stutter_prob >= stutter_threshold
+
+             results.append({
+                 'segment': i + 1,
+                 'prediction': self.class_names[predicted_class],
+                 'confidence': confidence,
+                 'is_stutter': is_stutter,
+                 'no_stutter_probability': no_stutter_prob,
+                 'stutter_probability': stutter_prob
+             })
+
+             if is_stutter:
+                 stutter_count += 1
+
+             if show_probabilities:
+                 status_emoji = "🔴" if is_stutter else "🟢"
+                 status_text = "STUTTER DETECTED" if is_stutter else "Clear"
+                 print(f"{status_emoji} Segment {i+1}: {status_text}")
+                 print(f"   No Stutter: {no_stutter_prob:.2%} | Stutter: {stutter_prob:.2%}")
+
+         # Calculate statistics
+         total_segments = len(results)
+         stutter_percentage = (stutter_count / total_segments * 100) if total_segments > 0 else 0
+
+         print(f"\n{'='*70}")
+         print("FINAL RESULTS")
+         print(f"{'='*70}")
+         print(f"✓ Total segments analyzed: {total_segments}")
+         print(f"🔴 Segments with stutter: {stutter_count}")
+         print(f"🟢 Segments without stutter: {total_segments - stutter_count}")
+         print(f"📊 Stuttering percentage: {stutter_percentage:.1f}%")
+
+         return {
+             'file_path': file_path,
+             'duration': duration,
+             'total_segments': total_segments,
+             'stutter_count': stutter_count,
+             'no_stutter_count': total_segments - stutter_count,
+             'stutter_percentage': stutter_percentage,
+             'segment_results': results
+         }
+
+
+ # ============================================
+ # SEVERITY ANALYSIS
+ # ============================================
+
+ def calculate_stutter_severity(results):
+     """Calculate detailed stutter severity metrics."""
+     segment_results = results.get('segment_results', [])
+     if not segment_results:  # e.g. the analysis returned an error
+         return None
+
+     stutter_probs = [seg['stutter_probability'] for seg in segment_results]
+
+     avg_prob = sum(stutter_probs) / len(stutter_probs)
+     max_prob = max(stutter_probs)
+     min_prob = min(stutter_probs)
+
+     # Count segments by severity band
+     severe = sum(1 for p in stutter_probs if p > 0.6)
+     moderate = sum(1 for p in stutter_probs if 0.4 < p <= 0.6)
+     mild = sum(1 for p in stutter_probs if 0.2 < p <= 0.4)
+     minimal = sum(1 for p in stutter_probs if p <= 0.2)
+
+     # Severity score = stuttered segments / total segments
+     total_segments = results.get('total_segments', 0)
+     stutter_count = results.get('stutter_count', 0)
+     severity_score = stutter_count / total_segments if total_segments > 0 else 0.0
+
+     print(f"\n{'='*70}")
+     print("DETAILED SEVERITY ANALYSIS")
+     print(f"{'='*70}")
+     print(f"Average stutter probability: {avg_prob:.2%}")
+     print(f"Peak stutter probability: {max_prob:.2%}")
+     print(f"Minimum stutter probability: {min_prob:.2%}")
+     print("\nSegment Severity Breakdown:")
+     print(f"  🔴 Severe (>60%): {severe} segments")
+     print(f"  🟠 Moderate (40-60%): {moderate} segments")
+     print(f"  🟡 Mild (20-40%): {mild} segments")
+     print(f"  🟢 Minimal (<=20%): {minimal} segments")
+
+     # Overall severity
+     if avg_prob < 0.15:
+         severity = "✓ Minimal or No Stuttering"
+     elif avg_prob < 0.35:
+         severity = "⚠️ Mild Stuttering"
+     elif avg_prob < 0.60:
+         severity = "⚠️ Moderate Stuttering"
+     else:
+         severity = "🔴 Significant Stuttering"
+
+     print(f"\n🎯 Overall Assessment: {severity}")
+     print(f"📊 Severity Score: {severity_score:.2%} (stutters / total segments)")
+
+     return {
+         'average_probability': avg_prob,
+         'max_probability': max_prob,
+         'severity_level': severity,
+         'severity_score': severity_score,
+         'severe_segments': severe,
+         'moderate_segments': moderate,
+         'mild_segments': mild,
+         'minimal_segments': minimal
+     }
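The module also works standalone, without the Flask app. A minimal sketch, assuming the LFS checkpoint has been pulled and speech.wav is any local recording (the file name is a placeholder):

# run_local.py (hypothetical) - use the detector directly.
from models.stutter_detector_local import ImprovedStutterDetector, calculate_stutter_severity

detector = ImprovedStutterDetector('models/stutter_detector_all_types.pth', device='cpu')

results = detector.analyze_audio_file('speech.wav', segment_duration=3.0, stutter_threshold=0.5)
severity = calculate_stutter_severity(results)
if severity is not None:
    print(severity['severity_level'], f"score={severity['severity_score']:.2%}")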
requirements.txt ADDED
@@ -0,0 +1,11 @@
+ Flask
+ flask-cors
+ Werkzeug
+ --extra-index-url https://download.pytorch.org/whl/cpu
+ torch
+ transformers
+ librosa
+ soundfile
+ numpy
+ scipy
+ gunicorn
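The --extra-index-url line points pip at PyTorch's CPU-only wheel index, which keeps the image small and matches the device='cpu' load in app.py. A quick post-install sanity check, as a sketch:

# check_torch.py (hypothetical) - confirm the CPU-only build was installed.
import torch

print("torch version:", torch.__version__)
print("CUDA available:", torch.cuda.is_available())  # expected False with CPU wheels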