abdu-l7hman committed
Commit 8d3d1a5 · 1 Parent(s): 0c5d983

Update model to augmented version and add ffmpeg

Files changed (3):
  1. app.py +174 -78
  2. models/stutter_detector_all_types.pth +2 -2
  3. requirements.txt +1 -1
app.py CHANGED
@@ -1,89 +1,196 @@
  from flask import Flask, render_template, request, jsonify
- from flask_cors import CORS # <--- ADDED THIS
+ from flask_cors import CORS
  import os
  from werkzeug.utils import secure_filename
- import time
- import sys
- import subprocess
  import shutil
+ import subprocess
+ import sys

- # Ensure we can find the models folder
- sys.path.append(os.path.join(os.path.dirname(__file__), 'models'))
-
- # Lazy-import model dependencies
- try:
-     from models.stutter_detector_local import ImprovedStutterDetector, calculate_stutter_severity
-     MODEL_DEPS_AVAILABLE = True
- except Exception as e:
-     print(f"Model dependencies unavailable: {e}")
-     ImprovedStutterDetector = None
-     def calculate_stutter_severity(_):
-         return None
-     MODEL_DEPS_AVAILABLE = False
+ # --- ML Imports ---
+ import torch
+ from torch import nn
+ import numpy as np
+ import librosa
+ import transformers

+ # ======================================================
+ # 1. CONFIGURATION
+ # ======================================================
  app = Flask(__name__)
- CORS(app) # <--- ENABLE CORS HERE
+ CORS(app)

- # config
- UPLOAD_FOLDER = '/tmp/uploads' # Use /tmp because other folders might be read-only on HF
+ UPLOAD_FOLDER = '/tmp/uploads' # /tmp for read-write permissions on HF Spaces
  ALLOWED_EXTENSIONS = {'wav', 'mp3', 'ogg', 'webm', 'm4a'}
- MODEL_PATH = 'models/stutter_detector_all_types.pth'
+ # UPDATE THIS TO YOUR NEW FILENAME
+ MODEL_PATH = 'models/stutter_detector_attentive_augmented.pth'

  app.config['UPLOAD_FOLDER'] = UPLOAD_FOLDER
- app.config['MAX_CONTENT_LENGTH'] = 16 * 1024 * 1024
+ app.config['MAX_CONTENT_LENGTH'] = 50 * 1024 * 1024 # 50MB limit

  os.makedirs(UPLOAD_FOLDER, exist_ok=True)

- # Check ffmpeg
- FFMPEG_AVAILABLE = shutil.which('ffmpeg') is not None
-
- # Load Model
- print("Loading stutter detection model...")
- detector = None
- if MODEL_DEPS_AVAILABLE and ImprovedStutterDetector is not None:
-     try:
-         detector = ImprovedStutterDetector(MODEL_PATH, device='cpu') # Force CPU
-         print("Model loaded successfully!")
-     except Exception as e:
-         print(f"Error loading model: {e}")
-         detector = None
- else:
-     print("Skipping model load due to missing dependencies.")
+ # ======================================================
+ # 2. NEW MODEL ARCHITECTURE
+ # ======================================================
+
+ class AttentiveStatsPool(nn.Module):
+     def __init__(self, in_dim, use_std=True):
+         super().__init__()
+         self.use_std = use_std
+         self.att = nn.Sequential(
+             nn.Linear(in_dim, in_dim // 2),
+             nn.Tanh(),
+             nn.Linear(in_dim // 2, 1)
+         )
+
+     def forward(self, H, mask=None):
+         alpha = torch.softmax(self.att(H), dim=1)
+         mean = (alpha * H).sum(dim=1)
+         ex2 = (alpha * (H ** 2)).sum(dim=1)
+         std = torch.sqrt(torch.clamp(ex2 - mean**2, min=1e-6))
+         return torch.cat([mean, std], dim=-1)
+
+ class Wav2VecAttentiveClassifier(nn.Module):
+     def __init__(self, hidden_dim=768, output_dim=2):
+         super().__init__()
+         # Load base wav2vec model
+         self.wav2vec = transformers.Wav2Vec2Model.from_pretrained("facebook/wav2vec2-base")
+
+         # Freeze wav2vec weights
+         for p in self.wav2vec.parameters():
+             p.requires_grad = False
+
+         self.pool = AttentiveStatsPool(hidden_dim, use_std=True)
+
+         self.classifier = nn.Sequential(
+             nn.Linear(hidden_dim * 2, 256),
+             nn.BatchNorm1d(256),
+             nn.ReLU(),
+             nn.Dropout(0.3),
+
+             nn.Linear(256, 128),
+             nn.BatchNorm1d(128),
+             nn.ReLU(),
+             nn.Dropout(0.3),
+
+             nn.Linear(128, output_dim)
+         )
+
+     def forward(self, x):
+         H = self.wav2vec(x).last_hidden_state
+         z = self.pool(H)
+         return self.classifier(z)
+
+ # ======================================================
+ # 3. INFERENCE HANDLER CLASS
+ # ======================================================
+
+ class StutterInferenceService:
+     def __init__(self, model_path):
+         self.device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
+         print(f"Loading model using device: {self.device}")
+
+         try:
+             self.model = Wav2VecAttentiveClassifier()
+             # Load weights
+             state_dict = torch.load(model_path, map_location=self.device)
+             self.model.load_state_dict(state_dict)
+             self.model.to(self.device)
+             self.model.eval()
+             print("Model loaded successfully.")
+             self.loaded = True
+         except Exception as e:
+             print(f"CRITICAL ERROR LOADING MODEL: {e}")
+             self.loaded = False
+
+     def analyze(self, file_path):
+         if not self.loaded:
+             raise Exception("Model not loaded")
+
+         # 1. Load Audio (Force 16kHz)
+         audio, sr = librosa.load(file_path, sr=16000)
+
+         # 2. Define segment parameters
+         SEGMENT_LENGTH = 48000 # 3 seconds exactly
+
+         # 3. Handle Short Audio (Pad if < 3 sec)
+         if len(audio) < SEGMENT_LENGTH:
+             padding = SEGMENT_LENGTH - len(audio)
+             audio = np.pad(audio, (0, padding), 'constant')
+
+         # 4. Process Logic (Sliding Window or Chunks)
+         # We will slice the audio into non-overlapping 3s chunks
+         num_chunks = len(audio) // SEGMENT_LENGTH
+         results = []
+
+         stutter_count = 0
+
+         for i in range(num_chunks):
+             start = i * SEGMENT_LENGTH
+             end = start + SEGMENT_LENGTH
+             chunk = audio[start:end]
+
+             # Preprocess for Pytorch
+             # Shape: (1, 48000)
+             tensor_input = torch.tensor(chunk).float().unsqueeze(0).to(self.device)
+
+             with torch.no_grad():
+                 logits = self.model(tensor_input)
+                 probs = torch.softmax(logits, dim=1)
+
+             # Assuming Class 1 = Stutter, Class 0 = Fluent
+             # Check your training labels! Usually 1 is the positive class.
+             stutter_prob = probs[0][1].item()
+             prediction = torch.argmax(probs, dim=1).item()
+
+             label = "Stutter" if prediction == 1 else "Fluent"
+             if prediction == 1:
+                 stutter_count += 1
+
+             results.append({
+                 "segment_id": i,
+                 "start_time": i * 3.0,
+                 "end_time": (i + 1) * 3.0,
+                 "probability": float(stutter_prob),
+                 "label": label
+             })
+
+         # 5. Calculate Overall Severity
+         total_segments = len(results) if len(results) > 0 else 1
+         severity_score = (stutter_count / total_segments) * 100
+
+         severity_label = "Normal"
+         if severity_score > 10: severity_label = "Mild"
+         if severity_score > 30: severity_label = "Moderate"
+         if severity_score > 60: severity_label = "Severe"
+
+         return {
+             "overall_severity": severity_label,
+             "stutter_percentage": round(severity_score, 2),
+             "details": results
+         }
+
+ # ======================================================
+ # 4. APP LOGIC
+ # ======================================================
+
+ # Initialize Service Global
+ detector_service = StutterInferenceService(MODEL_PATH)

  def allowed_file(filename):
      return '.' in filename and filename.rsplit('.', 1)[1].lower() in ALLOWED_EXTENSIONS

- def convert_to_wav(input_path, output_path):
-     try:
-         subprocess.run([
-             'ffmpeg', '-i', input_path,
-             '-acodec', 'pcm_s16le',
-             '-ar', '16000',
-             '-ac', '1',
-             '-y',
-             output_path
-         ], check=True, capture_output=True)
-         return True
-     except subprocess.CalledProcessError as e:
-         print(f"FFmpeg conversion error: {e.stderr.decode()}")
-         return False
-     except FileNotFoundError:
-         return False
-
- # ... [Keep your analyze_audio_file function exactly as it is] ...
- # ... [Paste the analyze_audio_file function here from your original code] ...
-
- # Route: Only strictly necessary endpoints for your API
  @app.route('/health', methods=['GET'])
  def health_check():
-     return jsonify({'status': 'healthy', 'model_loaded': detector is not None}), 200
+     return jsonify({
+         'status': 'healthy',
+         'model_loaded': detector_service.loaded
+     }), 200

  @app.route('/upload', methods=['POST'])
  def upload_file():
-     # ... [Keep your existing upload logic exactly as it is] ...
-     # Just ensure you use the analyze_audio_file function defined above
-     if detector is None:
-         return jsonify({'error': 'Model not loaded.'}), 500
+     if not detector_service.loaded:
+         return jsonify({'error': 'Model failed to load on server start.'}), 500

      if 'audio' not in request.files:
          return jsonify({'error': 'No audio file provided'}), 400
@@ -95,20 +202,9 @@ def upload_file():
          filepath = os.path.join(app.config['UPLOAD_FOLDER'], filename)
          file.save(filepath)

-         # Get params from request or default
-         segment_duration = float(request.form.get('segment_duration', 3.0))
-         stutter_threshold = float(request.form.get('stutter_threshold', 0.5))
-
          try:
-             # Run analysis
-             results = detector.analyze_audio_file(
-                 filepath,
-                 segment_duration=segment_duration,
-                 stutter_threshold=stutter_threshold
-             )
-
-             # Calculate severity
-             results['severity'] = calculate_stutter_severity(results)
+             # Run analysis using the new service
+             results = detector_service.analyze(filepath)

              # Cleanup
              if os.path.exists(filepath): os.remove(filepath)
@@ -116,11 +212,11 @@ def upload_file():
              return jsonify(results), 200

          except Exception as e:
+             print(f"Error processing file: {e}")
              if os.path.exists(filepath): os.remove(filepath)
              return jsonify({'error': str(e)}), 500

-     return jsonify({'error': 'Invalid file'}), 400
+     return jsonify({'error': 'Invalid file type'}), 400

  if __name__ == '__main__':
-     # This is only for local testing, Docker uses Gunicorn
      app.run(host='0.0.0.0', port=7860)
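A hypothetical illustration (not part of the commit) of the chunking arithmetic in the new StutterInferenceService.analyze: audio is resampled to 16 kHz, clips shorter than 3 s are zero-padded up to one segment, and any trailing remainder shorter than 3 s is not scored. The helper name chunk_bounds is invented for this sketch.

# Sketch only: mirrors the padding/slicing logic of analyze() above.
import numpy as np

SR = 16000
SEGMENT_LENGTH = 48000  # 3 seconds at 16 kHz

def chunk_bounds(duration_s):
    audio = np.zeros(int(duration_s * SR), dtype=np.float32)
    if len(audio) < SEGMENT_LENGTH:
        audio = np.pad(audio, (0, SEGMENT_LENGTH - len(audio)), 'constant')
    num_chunks = len(audio) // SEGMENT_LENGTH
    return [(i * 3.0, (i + 1) * 3.0) for i in range(num_chunks)]

print(chunk_bounds(1.2))  # [(0.0, 3.0)] -> short clip padded to one segment
print(chunk_bounds(7.5))  # [(0.0, 3.0), (3.0, 6.0)] -> trailing 1.5 s ignored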
models/stutter_detector_all_types.pth CHANGED
@@ -1,3 +1,3 @@
  version https://git-lfs.github.com/spec/v1
- oid sha256:819effa4e2727f82295d9f7c7cd2647159ef55001c8f7eda7cb009130d45ea56
- size 378509393
+ oid sha256:da5f2977f33bfac1112d588e3e5e187f0e600af13190eb9fb9bc0b3411eb6ada
+ size 380482371
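A hedged local check (not part of the commit) that an updated LFS checkpoint lines up with the Wav2VecAttentiveClassifier defined in app.py; the checkpoint path is an assumption (use whichever file MODEL_PATH points to), and importing from app runs its module-level setup as a side effect.

# Sketch only: inspect how the checkpoint's state dict matches the new architecture.
# app.py loads with the default strict=True; strict=False is used here just to report mismatches.
import torch
from app import Wav2VecAttentiveClassifier  # assumption: import-time service init is acceptable

model = Wav2VecAttentiveClassifier()
state_dict = torch.load("models/stutter_detector_all_types.pth", map_location="cpu")  # path is an assumption
missing, unexpected = model.load_state_dict(state_dict, strict=False)
print("missing keys:", missing)
print("unexpected keys:", unexpected)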
requirements.txt CHANGED
@@ -8,4 +8,4 @@ librosa
  soundfile
  numpy
  scipy
- gunicorn
+ gunicorn
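A hypothetical client-side sketch (not part of the commit) for exercising the two routes defined in app.py once the Space, or a local Flask/gunicorn process on port 7860, is running; BASE_URL and sample.wav are placeholders, and the requests package must be installed.

# Sketch only: calls the /health and /upload endpoints.
import requests

BASE_URL = "http://localhost:7860"  # placeholder: local run or the deployed Space URL

# /health reports whether the checkpoint loaded at startup.
print(requests.get(f"{BASE_URL}/health", timeout=10).json())

# /upload expects a multipart field named "audio" with an allowed extension.
with open("sample.wav", "rb") as f:
    resp = requests.post(f"{BASE_URL}/upload", files={"audio": f}, timeout=300)

print(resp.status_code)
print(resp.json())  # on success: overall_severity, stutter_percentage, details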