from flask import Flask, request, jsonify
from flask_cors import CORS
import os
from werkzeug.utils import secure_filename

# --- ML Imports ---
import torch
from torch import nn
import numpy as np
import librosa
import transformers
# ======================================================
# 1. CONFIGURATION
# ======================================================
app = Flask(__name__)
CORS(app)

UPLOAD_FOLDER = '/tmp/uploads'  # /tmp is writable on HF Spaces
ALLOWED_EXTENSIONS = {'wav', 'mp3', 'ogg', 'webm', 'm4a'}

# Update this to point at your trained checkpoint
MODEL_PATH = 'models/stutter_detector_attentive_augmented.pth'

app.config['UPLOAD_FOLDER'] = UPLOAD_FOLDER
app.config['MAX_CONTENT_LENGTH'] = 50 * 1024 * 1024  # 50 MB limit

os.makedirs(UPLOAD_FOLDER, exist_ok=True)
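# Note: when an upload exceeds MAX_CONTENT_LENGTH, Flask rejects the request
# with a 413 (Request Entity Too Large) before it reaches any handler below.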
# ======================================================
# 2. NEW MODEL ARCHITECTURE
# ======================================================
class AttentiveStatsPool(nn.Module):
    """Attentive statistics pooling: a learned softmax over time weights each
    frame, and the weighted mean (and optionally std) summarizes the sequence
    into a fixed-size vector."""

    def __init__(self, in_dim, use_std=True):
        super().__init__()
        self.use_std = use_std
        self.att = nn.Sequential(
            nn.Linear(in_dim, in_dim // 2),
            nn.Tanh(),
            nn.Linear(in_dim // 2, 1)
        )

    def forward(self, H, mask=None):
        # H: (batch, time, dim). `mask` is accepted for API compatibility
        # but not used in this implementation.
        alpha = torch.softmax(self.att(H), dim=1)
        mean = (alpha * H).sum(dim=1)
        if not self.use_std:
            return mean
        # The downstream classifier assumes use_std=True (output dim = 2 * in_dim)
        ex2 = (alpha * (H ** 2)).sum(dim=1)
        std = torch.sqrt(torch.clamp(ex2 - mean ** 2, min=1e-6))
        return torch.cat([mean, std], dim=-1)
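# Shape sanity check for the pooling layer (illustrative only; the batch size,
# frame count, and feature dim below are arbitrary):
#   pool = AttentiveStatsPool(in_dim=768)
#   H = torch.randn(2, 149, 768)   # (batch, frames, features)
#   pool(H).shape                  # -> torch.Size([2, 1536]), i.e. [mean | std]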
class Wav2VecAttentiveClassifier(nn.Module):
    def __init__(self, hidden_dim=768, output_dim=2):
        super().__init__()
        # Load the base wav2vec 2.0 encoder
        self.wav2vec = transformers.Wav2Vec2Model.from_pretrained("facebook/wav2vec2-base")
        # Freeze the wav2vec weights; only the pooling head and classifier train
        for p in self.wav2vec.parameters():
            p.requires_grad = False
        self.pool = AttentiveStatsPool(hidden_dim, use_std=True)
        self.classifier = nn.Sequential(
            nn.Linear(hidden_dim * 2, 256),
            nn.BatchNorm1d(256),
            nn.ReLU(),
            nn.Dropout(0.3),
            nn.Linear(256, 128),
            nn.BatchNorm1d(128),
            nn.ReLU(),
            nn.Dropout(0.3),
            nn.Linear(128, output_dim)
        )

    def forward(self, x):
        # x: raw waveform, shape (batch, samples) at 16 kHz
        H = self.wav2vec(x).last_hidden_state
        z = self.pool(H)
        return self.classifier(z)
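# For reference: wav2vec2-base downsamples raw audio by a factor of 320, so a
# 3-second clip at 16 kHz (48,000 samples) yields roughly 150 frames of
# 768-dimensional features for the pooling layer above.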
# ======================================================
# 3. INFERENCE HANDLER CLASS
# ======================================================
class StutterInferenceService:
    def __init__(self, model_path):
        self.device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
        print(f"Loading model using device: {self.device}")
        try:
            self.model = Wav2VecAttentiveClassifier()
            # Load the trained weights
            state_dict = torch.load(model_path, map_location=self.device)
            self.model.load_state_dict(state_dict)
            self.model.to(self.device)
            self.model.eval()
            print("Model loaded successfully.")
            self.loaded = True
        except Exception as e:
            print(f"CRITICAL ERROR LOADING MODEL: {e}")
            self.loaded = False
    def analyze(self, file_path):
        if not self.loaded:
            raise RuntimeError("Model not loaded")

        # 1. Load audio, resampled to 16 kHz (the rate wav2vec2 expects)
        audio, sr = librosa.load(file_path, sr=16000)

        # 2. Segment parameters: exactly 3 seconds at 16 kHz
        SEGMENT_LENGTH = 48000

        # 3. Handle short audio: zero-pad anything under 3 seconds
        if len(audio) < SEGMENT_LENGTH:
            padding = SEGMENT_LENGTH - len(audio)
            audio = np.pad(audio, (0, padding), 'constant')

        # 4. Slice the audio into non-overlapping 3 s chunks; any trailing
        #    remainder shorter than 3 s is dropped
        num_chunks = len(audio) // SEGMENT_LENGTH
        results = []
        stutter_count = 0

        for i in range(num_chunks):
            start = i * SEGMENT_LENGTH
            end = start + SEGMENT_LENGTH
            chunk = audio[start:end]

            # Preprocess for PyTorch; shape: (1, 48000)
            tensor_input = torch.tensor(chunk).float().unsqueeze(0).to(self.device)

            with torch.no_grad():
                logits = self.model(tensor_input)
                probs = torch.softmax(logits, dim=1)

            # Assuming class 1 = stutter, class 0 = fluent.
            # Check your training labels! Usually 1 is the positive class.
            stutter_prob = probs[0][1].item()
            prediction = torch.argmax(probs, dim=1).item()
            label = "Stutter" if prediction == 1 else "Fluent"

            if prediction == 1:
                stutter_count += 1

            results.append({
                "segment_id": i,
                "start_time": i * 3.0,
                "end_time": (i + 1) * 3.0,
                "probability": float(stutter_prob),
                "label": label
            })

        # 5. Overall severity: percentage of 3 s segments flagged as stutter
        total_segments = len(results) if len(results) > 0 else 1
        severity_score = (stutter_count / total_segments) * 100

        severity_label = "Normal"
        if severity_score > 10: severity_label = "Mild"
        if severity_score > 30: severity_label = "Moderate"
        if severity_score > 60: severity_label = "Severe"
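        # Worked example: 4 flagged segments out of 10 -> 40% -> "Moderate"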
        return {
            "overall_severity": severity_label,
            "stutter_percentage": round(severity_score, 2),
            "details": results
        }
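# Example of calling the service directly, outside Flask (illustrative;
# "sample.wav" is a placeholder filename):
#   svc = StutterInferenceService(MODEL_PATH)
#   report = svc.analyze("sample.wav")
#   print(report["overall_severity"], report["stutter_percentage"])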
# ======================================================
# 4. APP LOGIC
# ======================================================
# Initialize the service once, globally, so the model loads at startup
detector_service = StutterInferenceService(MODEL_PATH)


def allowed_file(filename):
    return '.' in filename and filename.rsplit('.', 1)[1].lower() in ALLOWED_EXTENSIONS


# Note: the original snippet omitted the route decorators; the paths
# '/health' and '/upload' below are assumed.
@app.route('/health', methods=['GET'])
def health_check():
    return jsonify({
        'status': 'healthy',
        'model_loaded': detector_service.loaded
    }), 200


@app.route('/upload', methods=['POST'])
def upload_file():
    if not detector_service.loaded:
        return jsonify({'error': 'Model failed to load on server start.'}), 500
    if 'audio' not in request.files:
        return jsonify({'error': 'No audio file provided'}), 400

    file = request.files['audio']
    if file and allowed_file(file.filename):
        filename = secure_filename(file.filename)
        filepath = os.path.join(app.config['UPLOAD_FOLDER'], filename)
        file.save(filepath)
        try:
            # Run the analysis using the inference service
            results = detector_service.analyze(filepath)
            return jsonify(results), 200
        except Exception as e:
            print(f"Error processing file: {e}")
            return jsonify({'error': str(e)}), 500
        finally:
            # Clean up the temporary upload either way
            if os.path.exists(filepath):
                os.remove(filepath)
    return jsonify({'error': 'Invalid file type'}), 400


if __name__ == '__main__':
    app.run(host='0.0.0.0', port=7860)
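# Quick smoke test from a shell (uses the routes registered above and a local
# file named sample.wav as a placeholder):
#   curl http://localhost:7860/health
#   curl -F "audio=@sample.wav" http://localhost:7860/upload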