Spaces:
Running
Running
abdu-l7hman
commited on
Commit
Β·
1154abd
0
Parent(s):
Initial commit with model and app
Browse files- .dockerignore +10 -0
- .gitattributes +1 -0
- Dockerfile +32 -0
- app.py +126 -0
- models/__pycache__/stutter_detector_local.cpython-312.pyc +0 -0
- models/__pycache__/stutter_detector_local.cpython-314.pyc +0 -0
- models/stutter_detector_all_types.pth +3 -0
- models/stutter_detector_local.py +300 -0
- requirements.txt +11 -0
.dockerignore
ADDED
|
@@ -0,0 +1,10 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
# .dockerignore
|
| 2 |
+
|
| 3 |
+
__pycache__/
|
| 4 |
+
*.pyc
|
| 5 |
+
*.pyo
|
| 6 |
+
*.pyd
|
| 7 |
+
.DS_Store
|
| 8 |
+
.env
|
| 9 |
+
venv/
|
| 10 |
+
.git/
|
.gitattributes
ADDED
|
@@ -0,0 +1 @@
|
|
|
|
|
|
|
| 1 |
+
models/*.pth filter=lfs diff=lfs merge=lfs -text
|
Dockerfile
ADDED
|
@@ -0,0 +1,32 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
# Dockerfile
|
| 2 |
+
|
| 3 |
+
# 1. Base Image
|
| 4 |
+
FROM python:3.9-slim
|
| 5 |
+
|
| 6 |
+
# 2. Install system dependencies (ffmpeg, libsndfile)
|
| 7 |
+
RUN apt-get update && apt-get install -y --no-install-recommends \
|
| 8 |
+
libsndfile1 \
|
| 9 |
+
ffmpeg \
|
| 10 |
+
&& rm -rf /var/lib/apt/lists/*
|
| 11 |
+
|
| 12 |
+
# 3. Create a non-root user (Required for Hugging Face Spaces)
|
| 13 |
+
RUN useradd -m -u 1000 user
|
| 14 |
+
USER user
|
| 15 |
+
ENV PATH="/home/user/.local/bin:$PATH"
|
| 16 |
+
|
| 17 |
+
# 4. Set working directory
|
| 18 |
+
WORKDIR /home/user/app
|
| 19 |
+
|
| 20 |
+
# 5. Copy requirements and install
|
| 21 |
+
COPY --chown=user requirements.txt .
|
| 22 |
+
RUN pip install --no-cache-dir --timeout=600 -r requirements.txt
|
| 23 |
+
|
| 24 |
+
# 6. Copy the application code
|
| 25 |
+
COPY --chown=user . .
|
| 26 |
+
|
| 27 |
+
# 7. Expose the correct port for Hugging Face
|
| 28 |
+
EXPOSE 7860
|
| 29 |
+
|
| 30 |
+
# 8. Run with Gunicorn on port 7860
|
| 31 |
+
# Note: Increased timeout to 120s because audio processing on CPU can be slow
|
| 32 |
+
CMD ["gunicorn", "--bind", "0.0.0.0:7860", "--workers", "1", "--timeout", "120", "app:app"]
|
app.py
ADDED
|
@@ -0,0 +1,126 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
from flask import Flask, render_template, request, jsonify
|
| 2 |
+
from flask_cors import CORS # <--- ADDED THIS
|
| 3 |
+
import os
|
| 4 |
+
from werkzeug.utils import secure_filename
|
| 5 |
+
import time
|
| 6 |
+
import sys
|
| 7 |
+
import subprocess
|
| 8 |
+
import shutil
|
| 9 |
+
|
| 10 |
+
# Ensure we can find the models folder
|
| 11 |
+
sys.path.append(os.path.join(os.path.dirname(__file__), 'models'))
|
| 12 |
+
|
| 13 |
+
# Lazy-import model dependencies
|
| 14 |
+
try:
|
| 15 |
+
from models.stutter_detector_local import ImprovedStutterDetector, calculate_stutter_severity
|
| 16 |
+
MODEL_DEPS_AVAILABLE = True
|
| 17 |
+
except Exception as e:
|
| 18 |
+
print(f"Model dependencies unavailable: {e}")
|
| 19 |
+
ImprovedStutterDetector = None
|
| 20 |
+
def calculate_stutter_severity(_):
|
| 21 |
+
return None
|
| 22 |
+
MODEL_DEPS_AVAILABLE = False
|
| 23 |
+
|
| 24 |
+
app = Flask(__name__)
|
| 25 |
+
CORS(app) # <--- ENABLE CORS HERE
|
| 26 |
+
|
| 27 |
+
# config
|
| 28 |
+
UPLOAD_FOLDER = '/tmp/uploads' # Use /tmp because other folders might be read-only on HF
|
| 29 |
+
ALLOWED_EXTENSIONS = {'wav', 'mp3', 'ogg', 'webm', 'm4a'}
|
| 30 |
+
MODEL_PATH = 'models/stutter_detector_all_types.pth'
|
| 31 |
+
|
| 32 |
+
app.config['UPLOAD_FOLDER'] = UPLOAD_FOLDER
|
| 33 |
+
app.config['MAX_CONTENT_LENGTH'] = 16 * 1024 * 1024
|
| 34 |
+
|
| 35 |
+
os.makedirs(UPLOAD_FOLDER, exist_ok=True)
|
| 36 |
+
|
| 37 |
+
# Check ffmpeg
|
| 38 |
+
FFMPEG_AVAILABLE = shutil.which('ffmpeg') is not None
|
| 39 |
+
|
| 40 |
+
# Load Model
|
| 41 |
+
print("Loading stutter detection model...")
|
| 42 |
+
detector = None
|
| 43 |
+
if MODEL_DEPS_AVAILABLE and ImprovedStutterDetector is not None:
|
| 44 |
+
try:
|
| 45 |
+
detector = ImprovedStutterDetector(MODEL_PATH, device='cpu') # Force CPU
|
| 46 |
+
print("Model loaded successfully!")
|
| 47 |
+
except Exception as e:
|
| 48 |
+
print(f"Error loading model: {e}")
|
| 49 |
+
detector = None
|
| 50 |
+
else:
|
| 51 |
+
print("Skipping model load due to missing dependencies.")
|
| 52 |
+
|
| 53 |
+
def allowed_file(filename):
|
| 54 |
+
return '.' in filename and filename.rsplit('.', 1)[1].lower() in ALLOWED_EXTENSIONS
|
| 55 |
+
|
| 56 |
+
def convert_to_wav(input_path, output_path):
|
| 57 |
+
try:
|
| 58 |
+
subprocess.run([
|
| 59 |
+
'ffmpeg', '-i', input_path,
|
| 60 |
+
'-acodec', 'pcm_s16le',
|
| 61 |
+
'-ar', '16000',
|
| 62 |
+
'-ac', '1',
|
| 63 |
+
'-y',
|
| 64 |
+
output_path
|
| 65 |
+
], check=True, capture_output=True)
|
| 66 |
+
return True
|
| 67 |
+
except subprocess.CalledProcessError as e:
|
| 68 |
+
print(f"FFmpeg conversion error: {e.stderr.decode()}")
|
| 69 |
+
return False
|
| 70 |
+
except FileNotFoundError:
|
| 71 |
+
return False
|
| 72 |
+
|
| 73 |
+
# ... [Keep your analyze_audio_file function exactly as it is] ...
|
| 74 |
+
# ... [Paste the analyze_audio_file function here from your original code] ...
|
| 75 |
+
|
| 76 |
+
# Route: Only strictly necessary endpoints for your API
|
| 77 |
+
@app.route('/health', methods=['GET'])
|
| 78 |
+
def health_check():
|
| 79 |
+
return jsonify({'status': 'healthy', 'model_loaded': detector is not None}), 200
|
| 80 |
+
|
| 81 |
+
@app.route('/upload', methods=['POST'])
|
| 82 |
+
def upload_file():
|
| 83 |
+
# ... [Keep your existing upload logic exactly as it is] ...
|
| 84 |
+
# Just ensure you use the analyze_audio_file function defined above
|
| 85 |
+
if detector is None:
|
| 86 |
+
return jsonify({'error': 'Model not loaded.'}), 500
|
| 87 |
+
|
| 88 |
+
if 'audio' not in request.files:
|
| 89 |
+
return jsonify({'error': 'No audio file provided'}), 400
|
| 90 |
+
|
| 91 |
+
file = request.files['audio']
|
| 92 |
+
|
| 93 |
+
if file and allowed_file(file.filename):
|
| 94 |
+
filename = secure_filename(file.filename)
|
| 95 |
+
filepath = os.path.join(app.config['UPLOAD_FOLDER'], filename)
|
| 96 |
+
file.save(filepath)
|
| 97 |
+
|
| 98 |
+
# Get params from request or default
|
| 99 |
+
segment_duration = float(request.form.get('segment_duration', 3.0))
|
| 100 |
+
stutter_threshold = float(request.form.get('stutter_threshold', 0.5))
|
| 101 |
+
|
| 102 |
+
try:
|
| 103 |
+
# Run analysis
|
| 104 |
+
results = detector.analyze_audio_file(
|
| 105 |
+
filepath,
|
| 106 |
+
segment_duration=segment_duration,
|
| 107 |
+
stutter_threshold=stutter_threshold
|
| 108 |
+
)
|
| 109 |
+
|
| 110 |
+
# Calculate severity
|
| 111 |
+
results['severity'] = calculate_stutter_severity(results)
|
| 112 |
+
|
| 113 |
+
# Cleanup
|
| 114 |
+
if os.path.exists(filepath): os.remove(filepath)
|
| 115 |
+
|
| 116 |
+
return jsonify(results), 200
|
| 117 |
+
|
| 118 |
+
except Exception as e:
|
| 119 |
+
if os.path.exists(filepath): os.remove(filepath)
|
| 120 |
+
return jsonify({'error': str(e)}), 500
|
| 121 |
+
|
| 122 |
+
return jsonify({'error': 'Invalid file'}), 400
|
| 123 |
+
|
| 124 |
+
if __name__ == '__main__':
|
| 125 |
+
# This is only for local testing, Docker uses Gunicorn
|
| 126 |
+
app.run(host='0.0.0.0', port=7860)
|
models/__pycache__/stutter_detector_local.cpython-312.pyc
ADDED
|
Binary file (14.8 kB). View file
|
|
|
models/__pycache__/stutter_detector_local.cpython-314.pyc
ADDED
|
Binary file (16.2 kB). View file
|
|
|
models/stutter_detector_all_types.pth
ADDED
|
@@ -0,0 +1,3 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
version https://git-lfs.github.com/spec/v1
|
| 2 |
+
oid sha256:819effa4e2727f82295d9f7c7cd2647159ef55001c8f7eda7cb009130d45ea56
|
| 3 |
+
size 378509393
|
models/stutter_detector_local.py
ADDED
|
@@ -0,0 +1,300 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
# ============================================
|
| 2 |
+
# LOCAL PC STUTTER DETECTION SETUP
|
| 3 |
+
# Run this on your local machine
|
| 4 |
+
# ============================================
|
| 5 |
+
|
| 6 |
+
import os
|
| 7 |
+
import numpy as np
|
| 8 |
+
import librosa
|
| 9 |
+
import torch
|
| 10 |
+
import torch.nn as nn
|
| 11 |
+
import transformers
|
| 12 |
+
from typing import List, Tuple
|
| 13 |
+
import warnings
|
| 14 |
+
warnings.filterwarnings('ignore')
|
| 15 |
+
|
| 16 |
+
# ============================================
|
| 17 |
+
# MODEL ARCHITECTURE (MUST MATCH TRAINING)
|
| 18 |
+
# ============================================
|
| 19 |
+
|
| 20 |
+
class ImprovedWav2VecClassifier(nn.Module):
|
| 21 |
+
"""Improved classifier matching training architecture."""
|
| 22 |
+
|
| 23 |
+
def __init__(self, hidden_dim=768, intermediate_dim=256, output_dim=2, dropout=0.3):
|
| 24 |
+
super().__init__()
|
| 25 |
+
|
| 26 |
+
# Load pre-trained Wav2Vec model
|
| 27 |
+
self.wav2vec = transformers.Wav2Vec2Model.from_pretrained('facebook/wav2vec2-base')
|
| 28 |
+
|
| 29 |
+
# Freeze Wav2Vec parameters
|
| 30 |
+
for param in self.wav2vec.parameters():
|
| 31 |
+
param.requires_grad = False
|
| 32 |
+
|
| 33 |
+
# Classification head
|
| 34 |
+
self.classifier = nn.Sequential(
|
| 35 |
+
nn.Linear(hidden_dim, intermediate_dim),
|
| 36 |
+
nn.BatchNorm1d(intermediate_dim),
|
| 37 |
+
nn.ReLU(),
|
| 38 |
+
nn.Dropout(dropout),
|
| 39 |
+
nn.Linear(intermediate_dim, intermediate_dim // 2),
|
| 40 |
+
nn.BatchNorm1d(intermediate_dim // 2),
|
| 41 |
+
nn.ReLU(),
|
| 42 |
+
nn.Dropout(dropout),
|
| 43 |
+
nn.Linear(intermediate_dim // 2, output_dim)
|
| 44 |
+
)
|
| 45 |
+
|
| 46 |
+
def forward(self, x):
|
| 47 |
+
with torch.no_grad():
|
| 48 |
+
encoder_output = self.wav2vec(x).last_hidden_state
|
| 49 |
+
pooled_features = encoder_output.mean(dim=1)
|
| 50 |
+
return self.classifier(pooled_features)
|
| 51 |
+
|
| 52 |
+
|
| 53 |
+
# ============================================
|
| 54 |
+
# FEATURE EXTRACTOR
|
| 55 |
+
# ============================================
|
| 56 |
+
|
| 57 |
+
class Wav2VecFeatureExtractor:
|
| 58 |
+
"""Extract features from audio files."""
|
| 59 |
+
|
| 60 |
+
def __init__(self, model_name='facebook/wav2vec2-base', duration=3):
|
| 61 |
+
self.processor = transformers.Wav2Vec2FeatureExtractor.from_pretrained(model_name)
|
| 62 |
+
self.duration = duration
|
| 63 |
+
self.sample_rate = 16000
|
| 64 |
+
|
| 65 |
+
def extract_features(self, audio_data, sr):
|
| 66 |
+
try:
|
| 67 |
+
if sr != self.sample_rate:
|
| 68 |
+
audio_data = librosa.resample(audio_data, orig_sr=sr, target_sr=self.sample_rate)
|
| 69 |
+
|
| 70 |
+
features = self.processor(audio_data, sampling_rate=self.sample_rate, return_tensors='pt').input_values
|
| 71 |
+
return features.squeeze(0)
|
| 72 |
+
except Exception as e:
|
| 73 |
+
print(f"Error extracting features: {e}")
|
| 74 |
+
return None
|
| 75 |
+
|
| 76 |
+
|
| 77 |
+
# ============================================
|
| 78 |
+
# AUDIO PROCESSING FUNCTIONS
|
| 79 |
+
# ============================================
|
| 80 |
+
|
| 81 |
+
def load_audio_file(file_path: str) -> Tuple[np.ndarray, int]:
|
| 82 |
+
"""Load an audio file."""
|
| 83 |
+
try:
|
| 84 |
+
audio_data, sr = librosa.load(file_path, sr=None)
|
| 85 |
+
return audio_data, sr
|
| 86 |
+
except Exception as e:
|
| 87 |
+
raise Exception(f"Error loading audio file: {e}")
|
| 88 |
+
|
| 89 |
+
|
| 90 |
+
def segment_audio(audio_data: np.ndarray, sr: int, segment_duration: float = 3.0) -> List[np.ndarray]:
|
| 91 |
+
"""Split audio into fixed-duration segments."""
|
| 92 |
+
segment_samples = int(segment_duration * sr)
|
| 93 |
+
segments = []
|
| 94 |
+
|
| 95 |
+
for i in range(0, len(audio_data), segment_samples):
|
| 96 |
+
segment = audio_data[i:i + segment_samples]
|
| 97 |
+
|
| 98 |
+
if len(segment) >= sr: # At least 1 second
|
| 99 |
+
if len(segment) < segment_samples:
|
| 100 |
+
padding = segment_samples - len(segment)
|
| 101 |
+
segment = np.pad(segment, (0, padding), mode='constant')
|
| 102 |
+
segments.append(segment)
|
| 103 |
+
|
| 104 |
+
return segments
|
| 105 |
+
|
| 106 |
+
|
| 107 |
+
def pad_or_truncate_features(features: torch.Tensor, max_length: int = 32007) -> torch.Tensor:
|
| 108 |
+
"""Pad or truncate features to match expected input length."""
|
| 109 |
+
if features.size(0) < max_length:
|
| 110 |
+
padding = max_length - features.size(0)
|
| 111 |
+
features = torch.cat([features, torch.zeros(padding)], dim=0)
|
| 112 |
+
elif features.size(0) > max_length:
|
| 113 |
+
features = features[:max_length]
|
| 114 |
+
return features
|
| 115 |
+
|
| 116 |
+
|
| 117 |
+
# ============================================
|
| 118 |
+
# STUTTER DETECTOR CLASS
|
| 119 |
+
# ============================================
|
| 120 |
+
|
| 121 |
+
class ImprovedStutterDetector:
|
| 122 |
+
"""Stutter detector for all types: prolongations, blocks, repetitions, interjections."""
|
| 123 |
+
|
| 124 |
+
def __init__(self, model_path: str, device: str = None):
|
| 125 |
+
# Set device
|
| 126 |
+
if device is None:
|
| 127 |
+
self.device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
|
| 128 |
+
else:
|
| 129 |
+
self.device = torch.device(device)
|
| 130 |
+
|
| 131 |
+
print(f"Using device: {self.device}")
|
| 132 |
+
|
| 133 |
+
# Load model
|
| 134 |
+
print("Loading model...")
|
| 135 |
+
self.model = ImprovedWav2VecClassifier()
|
| 136 |
+
self.model.load_state_dict(torch.load(model_path, map_location=self.device))
|
| 137 |
+
self.model.to(self.device)
|
| 138 |
+
self.model.eval()
|
| 139 |
+
print("β Model loaded successfully!")
|
| 140 |
+
|
| 141 |
+
# Initialize feature extractor
|
| 142 |
+
self.feature_extractor = Wav2VecFeatureExtractor(duration=3)
|
| 143 |
+
|
| 144 |
+
# Class names
|
| 145 |
+
self.class_names = ['No Stutter', 'Stutter (All Types)']
|
| 146 |
+
|
| 147 |
+
print("\nThis model detects ALL stutter types:")
|
| 148 |
+
print(" β’ Prolongations (ssssso)")
|
| 149 |
+
print(" β’ Blocks (getting stuck)")
|
| 150 |
+
print(" β’ Sound Repetitions (b-b-ball)")
|
| 151 |
+
print(" β’ Word Repetitions (I-I-I want)")
|
| 152 |
+
print(" β’ Interjections (um, uh)")
|
| 153 |
+
|
| 154 |
+
def analyze_audio_file(self, file_path: str, segment_duration: float = 3.0,
|
| 155 |
+
stutter_threshold: float = 0.5, show_probabilities: bool = True) -> dict:
|
| 156 |
+
"""Analyze an entire audio file for stuttering."""
|
| 157 |
+
|
| 158 |
+
print(f"\n{'='*70}")
|
| 159 |
+
print(f"ANALYZING: {os.path.basename(file_path)}")
|
| 160 |
+
print(f"{'='*70}")
|
| 161 |
+
|
| 162 |
+
# Load audio
|
| 163 |
+
audio_data, sr = load_audio_file(file_path)
|
| 164 |
+
duration = len(audio_data) / sr
|
| 165 |
+
print(f"π Audio duration: {duration:.2f} seconds")
|
| 166 |
+
|
| 167 |
+
# Segment audio
|
| 168 |
+
segments = segment_audio(audio_data, sr, segment_duration)
|
| 169 |
+
print(f"π Number of segments: {len(segments)}")
|
| 170 |
+
|
| 171 |
+
if len(segments) == 0:
|
| 172 |
+
return {'error': 'Audio too short for analysis (minimum 1 second required)'}
|
| 173 |
+
|
| 174 |
+
# Analyze each segment
|
| 175 |
+
results = []
|
| 176 |
+
stutter_count = 0
|
| 177 |
+
|
| 178 |
+
print(f"\n{'='*70}")
|
| 179 |
+
print("SEGMENT ANALYSIS")
|
| 180 |
+
print(f"{'='*70}")
|
| 181 |
+
|
| 182 |
+
for i, segment in enumerate(segments):
|
| 183 |
+
features = self.feature_extractor.extract_features(segment, sr)
|
| 184 |
+
if features is None:
|
| 185 |
+
continue
|
| 186 |
+
|
| 187 |
+
features = pad_or_truncate_features(features)
|
| 188 |
+
features = features.unsqueeze(0).to(self.device)
|
| 189 |
+
|
| 190 |
+
with torch.no_grad():
|
| 191 |
+
outputs = self.model(features)
|
| 192 |
+
probabilities = torch.softmax(outputs, dim=1)
|
| 193 |
+
predicted_class = torch.argmax(probabilities, dim=1).item()
|
| 194 |
+
confidence = probabilities[0][predicted_class].item()
|
| 195 |
+
|
| 196 |
+
no_stutter_prob = probabilities[0][0].item()
|
| 197 |
+
stutter_prob = probabilities[0][1].item()
|
| 198 |
+
|
| 199 |
+
is_stutter = stutter_prob >= stutter_threshold
|
| 200 |
+
|
| 201 |
+
results.append({
|
| 202 |
+
'segment': i + 1,
|
| 203 |
+
'prediction': self.class_names[predicted_class],
|
| 204 |
+
'confidence': confidence,
|
| 205 |
+
'is_stutter': is_stutter,
|
| 206 |
+
'no_stutter_probability': no_stutter_prob,
|
| 207 |
+
'stutter_probability': stutter_prob
|
| 208 |
+
})
|
| 209 |
+
|
| 210 |
+
if is_stutter:
|
| 211 |
+
stutter_count += 1
|
| 212 |
+
|
| 213 |
+
if show_probabilities:
|
| 214 |
+
status_emoji = "π΄" if is_stutter else "π’"
|
| 215 |
+
status_text = "STUTTER DETECTED" if is_stutter else "Clear"
|
| 216 |
+
print(f"{status_emoji} Segment {i+1}: {status_text}")
|
| 217 |
+
print(f" No Stutter: {no_stutter_prob:.2%} | Stutter: {stutter_prob:.2%}")
|
| 218 |
+
|
| 219 |
+
# Calculate statistics
|
| 220 |
+
total_segments = len(results)
|
| 221 |
+
stutter_percentage = (stutter_count / total_segments * 100) if total_segments > 0 else 0
|
| 222 |
+
|
| 223 |
+
print(f"\n{'='*70}")
|
| 224 |
+
print("FINAL RESULTS")
|
| 225 |
+
print(f"{'='*70}")
|
| 226 |
+
print(f"β Total segments analyzed: {total_segments}")
|
| 227 |
+
print(f"π΄ Segments with stutter: {stutter_count}")
|
| 228 |
+
print(f"π’ Segments without stutter: {total_segments - stutter_count}")
|
| 229 |
+
print(f"π Stuttering percentage: {stutter_percentage:.1f}%")
|
| 230 |
+
|
| 231 |
+
return {
|
| 232 |
+
'file_path': file_path,
|
| 233 |
+
'duration': duration,
|
| 234 |
+
'total_segments': total_segments,
|
| 235 |
+
'stutter_count': stutter_count,
|
| 236 |
+
'no_stutter_count': total_segments - stutter_count,
|
| 237 |
+
'stutter_percentage': stutter_percentage,
|
| 238 |
+
'segment_results': results
|
| 239 |
+
}
|
| 240 |
+
|
| 241 |
+
|
| 242 |
+
# ============================================
|
| 243 |
+
# SEVERITY ANALYSIS
|
| 244 |
+
# ============================================
|
| 245 |
+
|
| 246 |
+
def calculate_stutter_severity(results):
|
| 247 |
+
"""Calculate detailed stutter severity metrics."""
|
| 248 |
+
segment_results = results['segment_results']
|
| 249 |
+
stutter_probs = [seg['stutter_probability'] for seg in segment_results]
|
| 250 |
+
|
| 251 |
+
avg_prob = sum(stutter_probs) / len(stutter_probs)
|
| 252 |
+
max_prob = max(stutter_probs)
|
| 253 |
+
min_prob = min(stutter_probs)
|
| 254 |
+
|
| 255 |
+
# Count segments by severity
|
| 256 |
+
severe = sum(1 for p in stutter_probs if p > 0.6)
|
| 257 |
+
moderate = sum(1 for p in stutter_probs if 0.4 < p <= 0.6)
|
| 258 |
+
mild = sum(1 for p in stutter_probs if 0.2 < p <= 0.4)
|
| 259 |
+
minimal = sum(1 for p in stutter_probs if p <= 0.2)
|
| 260 |
+
|
| 261 |
+
# Calculate severity score as stutters / total segments
|
| 262 |
+
total_segments = results.get('total_segments', 0)
|
| 263 |
+
stutter_count = results.get('stutter_count', 0)
|
| 264 |
+
severity_score = stutter_count / total_segments if total_segments > 0 else 0.0
|
| 265 |
+
|
| 266 |
+
print(f"\n{'='*70}")
|
| 267 |
+
print("DETAILED SEVERITY ANALYSIS")
|
| 268 |
+
print(f"{'='*70}")
|
| 269 |
+
print(f"Average stutter probability: {avg_prob:.2%}")
|
| 270 |
+
print(f"Peak stutter probability: {max_prob:.2%}")
|
| 271 |
+
print(f"Minimum stutter probability: {min_prob:.2%}")
|
| 272 |
+
print(f"\nSegment Severity Breakdown:")
|
| 273 |
+
print(f" π΄ Severe (>70%): {severe} segments")
|
| 274 |
+
print(f" π Moderate (40-70%): {moderate} segments")
|
| 275 |
+
print(f" π‘ Mild (20-40%): {mild} segments")
|
| 276 |
+
print(f" π’ Minimal (<20%): {minimal} segments")
|
| 277 |
+
|
| 278 |
+
# Overall severity
|
| 279 |
+
if avg_prob < 0.15:
|
| 280 |
+
severity = "β Minimal or No Stuttering"
|
| 281 |
+
elif avg_prob < 0.35:
|
| 282 |
+
severity = "β οΈ Mild Stuttering"
|
| 283 |
+
elif avg_prob < 0.60:
|
| 284 |
+
severity = "β οΈ Moderate Stuttering"
|
| 285 |
+
else:
|
| 286 |
+
severity = "π΄ Significant Stuttering"
|
| 287 |
+
|
| 288 |
+
print(f"\nπ― Overall Assessment: {severity}")
|
| 289 |
+
print(f"π Severity Score: {severity_score:.2%} (stutters/total segments)")
|
| 290 |
+
|
| 291 |
+
return {
|
| 292 |
+
'average_probability': avg_prob,
|
| 293 |
+
'max_probability': max_prob,
|
| 294 |
+
'severity_level': severity,
|
| 295 |
+
'severity_score': severity_score,
|
| 296 |
+
'severe_segments': severe,
|
| 297 |
+
'moderate_segments': moderate,
|
| 298 |
+
'mild_segments': mild,
|
| 299 |
+
'minimal_segments': minimal
|
| 300 |
+
}
|
requirements.txt
ADDED
|
@@ -0,0 +1,11 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
Flask
|
| 2 |
+
flask-cors
|
| 3 |
+
Werkzeug
|
| 4 |
+
--extra-index-url https://download.pytorch.org/whl/cpu
|
| 5 |
+
torch
|
| 6 |
+
transformers
|
| 7 |
+
librosa
|
| 8 |
+
soundfile
|
| 9 |
+
numpy
|
| 10 |
+
scipy
|
| 11 |
+
gunicorn
|