# abdu-l7hman
# Update model to augmented version and add ffmpeg
# 8d3d1a5
import os
import shutil
import subprocess
import sys
import tempfile

from flask import Flask, render_template, request, jsonify
from flask_cors import CORS
from werkzeug.utils import secure_filename

# --- ML Imports ---
import torch
from torch import nn
import numpy as np
import librosa
import transformers
# ======================================================
# 1. CONFIGURATION
# ======================================================
# Uploads are staged under /tmp, which is writable on Hugging Face Spaces.
UPLOAD_FOLDER = '/tmp/uploads'
# File extensions (lower-case, no dot) accepted by the /upload endpoint.
ALLOWED_EXTENSIONS = {'wav', 'mp3', 'ogg', 'webm', 'm4a'}
# Trained checkpoint loaded at startup; point this at the new file whenever
# a different checkpoint is deployed.
MODEL_PATH = 'models/stutter_detector_attentive_augmented.pth'

app = Flask(__name__)
CORS(app)  # enable cross-origin requests on every route (flask_cors defaults)
app.config.update(
    UPLOAD_FOLDER=UPLOAD_FOLDER,
    MAX_CONTENT_LENGTH=50 * 1024 * 1024,  # reject request bodies larger than 50 MB
)
os.makedirs(UPLOAD_FOLDER, exist_ok=True)
# ======================================================
# 2. NEW MODEL ARCHITECTURE
# ======================================================
class AttentiveStatsPool(nn.Module):
    """Attentive statistics pooling over the time axis.

    A small MLP scores every frame. The scores are softmax-normalised over
    time, and the resulting weights are used to compute a weighted mean and,
    optionally, a weighted standard deviation of the frame features.

    Args:
        in_dim: Feature size ``D`` of each frame.
        use_std: If True (default), the output is ``[mean, std]`` of size
            ``2 * D``. If False, only the weighted mean (size ``D``) is returned.
    """

    def __init__(self, in_dim, use_std=True):
        super().__init__()
        self.use_std = use_std
        # Keep this submodule layout (att.0 / att.2) unchanged so saved
        # state_dicts keep loading.
        self.att = nn.Sequential(
            nn.Linear(in_dim, in_dim // 2),
            nn.Tanh(),
            nn.Linear(in_dim // 2, 1),
        )

    def forward(self, H, mask=None):
        """Pool frame features into one vector per batch item.

        Args:
            H: Frame features of shape ``(B, T, D)``.
            mask: Optional ``(B, T)`` tensor. Non-zero/True marks frames to
                keep; masked frames receive zero attention weight. Each row
                should keep at least one frame (a fully masked row falls back
                to uniform weights).

        Returns:
            Tensor of shape ``(B, 2 * D)`` when ``use_std`` is True, otherwise
            ``(B, D)``.
        """
        scores = self.att(H)  # (B, T, 1)
        if mask is not None:
            keep = mask.to(dtype=torch.bool, device=scores.device).unsqueeze(-1)
            # Use the most negative finite value rather than -inf, so a fully
            # masked row yields uniform weights instead of NaN.
            scores = scores.masked_fill(~keep, torch.finfo(scores.dtype).min)
        alpha = torch.softmax(scores, dim=1)
        mean = (alpha * H).sum(dim=1)
        if not self.use_std:
            return mean
        ex2 = (alpha * (H ** 2)).sum(dim=1)
        # Clamp guards against tiny negative variances caused by rounding.
        std = torch.sqrt(torch.clamp(ex2 - mean ** 2, min=1e-6))
        return torch.cat([mean, std], dim=-1)
class Wav2VecAttentiveClassifier(nn.Module):
    """Frozen wav2vec 2.0 encoder, attentive stats pooling, and an MLP head.

    Args:
        hidden_dim: Size of the encoder's frame embeddings (768 for
            ``facebook/wav2vec2-base``). The pooled vector is twice this size.
        output_dim: Number of output logits.
    """

    def __init__(self, hidden_dim=768, output_dim=2):
        super().__init__()
        # Pretrained encoder, used only as a fixed feature extractor.
        self.wav2vec = transformers.Wav2Vec2Model.from_pretrained("facebook/wav2vec2-base")
        self.wav2vec.requires_grad_(False)

        self.pool = AttentiveStatsPool(hidden_dim, use_std=True)

        # Two (Linear -> BatchNorm -> ReLU -> Dropout) stages, then the output
        # projection. Keep this construction order: the Sequential indices
        # (classifier.0 ... classifier.8) must match saved checkpoints.
        layers = []
        width_in = hidden_dim * 2
        for width_out in (256, 128):
            layers.extend([
                nn.Linear(width_in, width_out),
                nn.BatchNorm1d(width_out),
                nn.ReLU(),
                nn.Dropout(0.3),
            ])
            width_in = width_out
        layers.append(nn.Linear(width_in, output_dim))
        self.classifier = nn.Sequential(*layers)

    def forward(self, x):
        """Return class logits for raw waveforms ``x``.

        ``x`` is presumably ``(B, samples)`` at 16 kHz; the caller in this
        file passes ``(1, 48000)``.
        """
        frames = self.wav2vec(x).last_hidden_state
        return self.classifier(self.pool(frames))
# ======================================================
# 3. INFERENCE HANDLER CLASS
# ======================================================
class StutterInferenceService:
    """Load the stutter classifier once and run window-level inference on audio.

    Audio is loaded at 16 kHz, zero-padded to a whole number of consecutive
    3-second windows, and each window is labelled "Stutter" or "Fluent". The
    percentage of stuttered windows is mapped to an overall severity label.

    Attributes:
        device: torch device the model runs on.
        model: The loaded classifier, or None if loading failed.
        loaded: False when model construction or loading failed; ``analyze``
            then raises instead of running.
    """

    SAMPLE_RATE = 16000  # Hz; audio is resampled to this on load
    SEGMENT_SECONDS = 3.0
    SEGMENT_LENGTH = int(SAMPLE_RATE * SEGMENT_SECONDS)  # 48000 samples
    # Assumes class 1 = stutter, class 0 = fluent in the training labels -- TODO confirm.
    STUTTER_CLASS = 1
    # (exclusive lower bound on stutter percentage, label), most severe first.
    SEVERITY_LEVELS = ((60, "Severe"), (30, "Moderate"), (10, "Mild"))

    def __init__(self, model_path):
        """Build the classifier and load the weights at ``model_path``.

        Failures are logged rather than raised, so the server can still start
        and report ``loaded == False``.
        """
        self.device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
        self.model = None
        self.loaded = False
        print(f"Loading model using device: {self.device}")
        try:
            model = Wav2VecAttentiveClassifier()
            state_dict = torch.load(model_path, map_location=self.device)
            model.load_state_dict(state_dict)
            model.to(self.device)
            model.eval()
        except Exception as e:
            # Deliberately best-effort: keep the process alive.
            print(f"CRITICAL ERROR LOADING MODEL: {e}")
            return
        self.model = model
        print("Model loaded successfully.")
        self.loaded = True

    def analyze(self, file_path):
        """Load the audio file at ``file_path`` and analyze it.

        Returns:
            The same dict as ``analyze_waveform``.

        Raises:
            RuntimeError: If the model failed to load.
        """
        if not self.loaded:
            raise RuntimeError("Model not loaded")
        audio, _ = librosa.load(file_path, sr=self.SAMPLE_RATE)
        return self.analyze_waveform(audio)

    def analyze_waveform(self, audio):
        """Classify a mono waveform sampled at ``SAMPLE_RATE`` in 3-second windows.

        The waveform is zero-padded up to a whole number of windows. A clip
        shorter than 3 s therefore yields one window, and a trailing partial
        window is analyzed instead of being dropped. ``start_time`` and
        ``end_time`` are the nominal window bounds in seconds.

        Returns:
            dict with ``overall_severity`` (str), ``stutter_percentage``
            (float, 2 decimals) and ``details`` (one dict per window).

        Raises:
            RuntimeError: If the model failed to load.
            ValueError: If ``audio`` is not one-dimensional.
        """
        if not self.loaded:
            raise RuntimeError("Model not loaded")
        audio = np.asarray(audio, dtype=np.float32)
        if audio.ndim != 1:
            raise ValueError(f"expected a 1-D mono waveform, got shape {audio.shape}")

        results = []
        stutter_count = 0
        for i, segment in enumerate(self._split_segments(audio)):
            stutter_prob, is_stutter = self._classify_segment(segment)
            if is_stutter:
                stutter_count += 1
            results.append({
                "segment_id": i,
                "start_time": i * self.SEGMENT_SECONDS,
                "end_time": (i + 1) * self.SEGMENT_SECONDS,
                "probability": stutter_prob,
                "label": "Stutter" if is_stutter else "Fluent",
            })

        severity_score = (stutter_count / max(len(results), 1)) * 100
        return {
            "overall_severity": self._severity_label(severity_score),
            "stutter_percentage": round(severity_score, 2),
            "details": results,
        }

    @classmethod
    def _split_segments(cls, audio):
        """Zero-pad ``audio`` to whole windows; return them as rows of a 2-D array."""
        seg_len = cls.SEGMENT_LENGTH
        num_segments = max(1, -(-len(audio) // seg_len))  # ceil division, at least one window
        shortfall = num_segments * seg_len - len(audio)
        if shortfall:
            audio = np.pad(audio, (0, shortfall), mode='constant')
        return audio.reshape(num_segments, seg_len)

    def _classify_segment(self, segment):
        """Return ``(stutter_probability, predicted_as_stutter)`` for one window."""
        # Shape (1, SEGMENT_LENGTH): a batch containing a single window.
        tensor_input = torch.tensor(segment, dtype=torch.float32).unsqueeze(0).to(self.device)
        with torch.no_grad():
            probs = torch.softmax(self.model(tensor_input), dim=1)
        stutter_prob = float(probs[0, self.STUTTER_CLASS].item())
        is_stutter = int(torch.argmax(probs, dim=1).item()) == self.STUTTER_CLASS
        return stutter_prob, is_stutter

    @classmethod
    def _severity_label(cls, score):
        """Map a stutter percentage to "Severe"/"Moderate"/"Mild"/"Normal"."""
        for threshold, label in cls.SEVERITY_LEVELS:
            if score > threshold:
                return label
        return "Normal"
# ======================================================
# 4. APP LOGIC
# ======================================================
# A single process-wide inference service, created at import time so the
# model is loaded once before the first request. Load failures are recorded
# on detector_service.loaded rather than raised.
detector_service = StutterInferenceService(model_path=MODEL_PATH)
def allowed_file(filename, allowed_extensions=None):
    """Return True if ``filename`` ends in an allowed extension.

    The extension is the text after the last dot and is compared
    case-insensitively.

    Args:
        filename: Client-supplied filename. It may be None or empty, in which
            case the file is rejected instead of raising TypeError.
        allowed_extensions: Lower-case extensions without the dot. Defaults
            to the module-level ``ALLOWED_EXTENSIONS``.
    """
    if allowed_extensions is None:
        allowed_extensions = ALLOWED_EXTENSIONS
    if not filename or '.' not in filename:
        return False
    return filename.rsplit('.', 1)[1].lower() in allowed_extensions
@app.route('/health', methods=['GET'])
def health_check():
    """Liveness probe: always responds 200 and reports whether the model loaded."""
    payload = {
        'status': 'healthy',
        'model_loaded': detector_service.loaded,
    }
    return jsonify(payload), 200
@app.route('/upload', methods=['POST'])
def upload_file():
    """Analyze an audio file uploaded in the multipart form field ``audio``.

    Responses:
        200: The analysis dict returned by ``detector_service.analyze``.
        400: No ``audio`` field, or a disallowed file extension.
        500: The model failed to load at startup, or saving/analysis raised.
    """
    if not detector_service.loaded:
        return jsonify({'error': 'Model failed to load on server start.'}), 500
    if 'audio' not in request.files:
        return jsonify({'error': 'No audio file provided'}), 400
    file = request.files['audio']
    if not (file and allowed_file(file.filename)):
        return jsonify({'error': 'Invalid file type'}), 400

    # Stage the upload under a unique path. Using the client's filename let
    # concurrent requests that share a filename overwrite and delete each
    # other's files. The extension (already validated by allowed_file) is
    # kept because the audio decoder may rely on it to detect the format.
    extension = file.filename.rsplit('.', 1)[1].lower()
    fd, filepath = tempfile.mkstemp(
        prefix='upload_', suffix='.' + extension, dir=app.config['UPLOAD_FOLDER']
    )
    os.close(fd)
    try:
        file.save(filepath)
        results = detector_service.analyze(filepath)
        return jsonify(results), 200
    except Exception as e:
        print(f"Error processing file: {e}")
        return jsonify({'error': str(e)}), 500
    finally:
        # Always remove the staged file, whether saving/analysis succeeded or not.
        if os.path.exists(filepath):
            os.remove(filepath)
if __name__ == '__main__':
    # Listen on all interfaces; 7860 is the conventional Hugging Face Spaces port.
    run_options = {'host': '0.0.0.0', 'port': 7860}
    app.run(**run_options)