Rolando committed
Commit · 8718761
1 Parent(s): e9ccfaf

Set it up

Browse files:
- .gitignore +129 -0
- LICENSE +21 -0
- README.md +1872 -0
- examples/non-whisper.ipynb +425 -0
- setup.py +40 -0
- silence_suppresion0.png +0 -0
- silence_suppresion1.png +0 -0
- stable_whisper/__init__.py +8 -0
- stable_whisper/__main__.py +3 -0
- stable_whisper/_version.py +1 -0
- stable_whisper/alignment.py +1265 -0
- stable_whisper/audio.py +288 -0
- stable_whisper/decode.py +109 -0
- stable_whisper/non_whisper.py +348 -0
- stable_whisper/quantization.py +40 -0
- stable_whisper/result.py +2281 -0
- stable_whisper/stabilization.py +424 -0
- stable_whisper/text_output.py +620 -0
- stable_whisper/timing.py +275 -0
- stable_whisper/utils.py +78 -0
- stable_whisper/video_output.py +111 -0
- stable_whisper/whisper_compatibility.py +73 -0
- stable_whisper/whisper_word_level.py +1651 -0
.gitignore
ADDED
@@ -0,0 +1,129 @@
# Byte-compiled / optimized / DLL files
__pycache__/
*.py[cod]
*$py.class

# C extensions
*.so

# Distribution / packaging
.Python
build/
develop-eggs/
dist/
downloads/
eggs/
.eggs/
lib/
lib64/
parts/
sdist/
var/
wheels/
pip-wheel-metadata/
share/python-wheels/
*.egg-info/
.installed.cfg
*.egg
MANIFEST

# PyInstaller
#  Usually these files are written by a python script from a template
#  before PyInstaller builds the exe, so as to inject date/other infos into it.
*.manifest
*.spec

# Installer logs
pip-log.txt
pip-delete-this-directory.txt

# Unit test / coverage reports
htmlcov/
.tox/
.nox/
.coverage
.coverage.*
.cache
nosetests.xml
coverage.xml
*.cover
*.py,cover
.hypothesis/
.pytest_cache/

# Translations
*.mo
*.pot

# Django stuff:
*.log
local_settings.py
db.sqlite3
db.sqlite3-journal

# Flask stuff:
instance/
.webassets-cache

# Scrapy stuff:
.scrapy

# Sphinx documentation
docs/_build/

# PyBuilder
target/

# Jupyter Notebook
.ipynb_checkpoints

# IPython
profile_default/
ipython_config.py

# pyenv
.python-version

# pipenv
#   According to pypa/pipenv#598, it is recommended to include Pipfile.lock in version control.
#   However, in case of collaboration, if having platform-specific dependencies or dependencies
#   having no cross-platform support, pipenv may install dependencies that don't work, or not
#   install all needed dependencies.
#Pipfile.lock

# PEP 582; used by e.g. github.com/David-OConnor/pyflow
__pypackages__/

# Celery stuff
celerybeat-schedule
celerybeat.pid

# SageMath parsed files
*.sage.py

# Environments
.env
.venv
env/
venv/
ENV/
env.bak/
venv.bak/

# Spyder project settings
.spyderproject
.spyproject

# Rope project settings
.ropeproject

# mkdocs documentation
/site

# mypy
.mypy_cache/
.dmypy.json
dmypy.json

# Pyre type checker
.pyre/
LICENSE
ADDED
@@ -0,0 +1,21 @@
MIT License

Copyright (c) 2022 jian

Permission is hereby granted, free of charge, to any person obtaining a copy
of this software and associated documentation files (the "Software"), to deal
in the Software without restriction, including without limitation the rights
to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
copies of the Software, and to permit persons to whom the Software is
furnished to do so, subject to the following conditions:

The above copyright notice and this permission notice shall be included in all
copies or substantial portions of the Software.

THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
SOFTWARE.
README.md
ADDED
@@ -0,0 +1,1872 @@
# Stabilizing Timestamps for Whisper

This library modifies [Whisper](https://github.com/openai/whisper) to produce more reliable timestamps and extends its functionality.

https://github.com/jianfch/stable-ts/assets/28970749/7adf0540-3620-4b2b-b2d4-e316906d6dfa

* [Setup](#setup)
* [Usage](#usage)
* [Transcribe](#transcribe)
* [Output](#output)
* [Alignment](#alignment)
* [Adjustments](#adjustments)
* [Refinement](#refinement)
* [Regrouping Words](#regrouping-words)
* [Editing](#editing)
* [Locating Words](#locating-words)
* [Silence Suppression](#silence-suppression)
* [Tips](#tips)
* [Visualizing Suppression](#visualizing-suppression)
* [Encode Comparison](#encode-comparison)
* [Use with any ASR](#any-asr)
* [Quick 1.X → 2.X Guide](#quick-1x--2x-guide)

## Setup
```
pip install -U stable-ts
```

To install the latest commit:
```
pip install -U git+https://github.com/jianfch/stable-ts.git
```

## Usage

### Transcribe

```python
import stable_whisper
model = stable_whisper.load_model('base')
result = model.transcribe('audio.mp3')
result.to_srt_vtt('audio.srt')
```
<details>
<summary>CLI</summary>

```commandline
stable-ts audio.mp3 -o audio.srt
```
</details>

Docstrings:
<details>
<summary>load_model()</summary>

Load an instance of :class:`whisper.model.Whisper`.

Parameters
----------
name : {'tiny', 'tiny.en', 'base', 'base.en', 'small', 'small.en', 'medium', 'medium.en', 'large-v1',
    'large-v2', 'large-v3', or 'large'}
    One of the official model names listed by :func:`whisper.available_models`, or
    path to a model checkpoint containing the model dimensions and the model state_dict.
device : str or torch.device, optional
    PyTorch device to put the model into.
download_root : str, optional
    Path to download the model files; by default, it uses "~/.cache/whisper".
in_memory : bool, default False
    Whether to preload the model weights into host memory.
cpu_preload : bool, default True
    Load model into CPU memory first then move model to specified device
    to reduce GPU memory usage when loading model.
dq : bool, default False
    Whether to apply Dynamic Quantization to the model to reduce memory usage and increase inference speed
    at the cost of a slight decrease in accuracy. Only for CPU.

Returns
-------
model : "Whisper"
    The Whisper ASR model instance.

Notes
-----
The overhead from ``dq = True`` might make inference slower for models smaller than 'large'.

</details>
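For illustration, a minimal sketch using the parameters documented above (assuming a CPU-only setup, since `dq` is CPU-only):

```python
import stable_whisper

# Load on CPU with dynamic quantization to reduce memory usage.
# Per the notes above, dq=True may be slower for models smaller than 'large'.
model = stable_whisper.load_model('base', device='cpu', dq=True)
```
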

<details>
<summary>transcribe()</summary>

Transcribe audio using Whisper.

This is a modified version of :func:`whisper.transcribe.transcribe` with slightly different decoding logic while
allowing additional preprocessing and postprocessing. The preprocessing performed on the audio includes: isolating
voice / removing noise with Demucs and low/high-pass filter. The postprocessing performed on the transcription
result includes: adjusting timestamps with VAD and custom regrouping of segments based on punctuation and speech gaps.

Parameters
----------
model : whisper.model.Whisper
    An instance of Whisper ASR model.
audio : str or numpy.ndarray or torch.Tensor or bytes
    Path/URL to the audio file, the audio waveform, or bytes of audio file.
    If audio is :class:`numpy.ndarray` or :class:`torch.Tensor`, the audio must be already sampled at 16kHz.
verbose : bool or None, default False
    Whether to display the text being decoded to the console.
    Displays all the details if ``True``. Displays progressbar if ``False``. Displays nothing if ``None``.
temperature : float or iterable of float, default (0.0, 0.2, 0.4, 0.6, 0.8, 1.0)
    Temperature for sampling. It can be a tuple of temperatures, which will be successively used
    upon failures according to either ``compression_ratio_threshold`` or ``logprob_threshold``.
compression_ratio_threshold : float, default 2.4
    If the gzip compression ratio is above this value, treat as failed.
logprob_threshold : float, default -1
    If the average log probability over sampled tokens is below this value, treat as failed.
no_speech_threshold : float, default 0.6
    If the no_speech probability is higher than this value AND the average log probability
    over sampled tokens is below ``logprob_threshold``, consider the segment as silent.
condition_on_previous_text : bool, default True
    If ``True``, the previous output of the model is provided as a prompt for the next window;
    disabling may make the text inconsistent across windows, but the model becomes less prone to
    getting stuck in a failure loop, such as repetition looping or timestamps going out of sync.
initial_prompt : str, optional
    Text to provide as a prompt for the first window. This can be used to provide, or
    "prompt-engineer" a context for transcription, e.g. custom vocabularies or proper nouns
    to make it more likely to predict those words correctly.
word_timestamps : bool, default True
    Extract word-level timestamps using the cross-attention pattern and dynamic time warping,
    and include the timestamps for each word in each segment.
    Disabling this will prevent segments from splitting/merging properly.
regroup : bool or str, default True, meaning the default regroup algorithm
    String for customizing the regrouping algorithm. False disables regrouping.
    Ignored if ``word_timestamps = False``.
ts_num : int, default 0, meaning disable this option
    Number of extra timestamp inferences to perform then use average of these extra timestamps.
    An experimental option that might hurt performance.
ts_noise : float, default 0.1
    Percentage of noise to add to audio_features to perform inferences for ``ts_num``.
suppress_silence : bool, default True
    Whether to enable timestamps adjustments based on the detected silence.
suppress_word_ts : bool, default True
    Whether to adjust word timestamps based on the detected silence. Only enabled if ``suppress_silence = True``.
use_word_position : bool, default True
    Whether to use position of the word in its segment to determine whether to keep end or start timestamps if
    adjustments are required. If it is the first word, keep end. Else if it is the last word, keep the start.
q_levels : int, default 20
    Quantization levels for generating timestamp suppression mask; ignored if ``vad = true``.
    Acts as a threshold to mark sound as silent.
    Fewer levels will increase the threshold of volume at which to mark a sound as silent.
k_size : int, default 5
    Kernel size for avg-pooling waveform to generate timestamp suppression mask; ignored if ``vad = true``.
    Recommend 5 or 3; higher sizes will reduce detection of silence.
time_scale : float, optional
    Factor for scaling audio duration for inference.
    Greater than 1.0 'slows down' the audio, and less than 1.0 'speeds up' the audio. None is same as 1.0.
    A factor of 1.5 will stretch 10s audio to 15s for inference. This increases the effective resolution
    of the model but can increase word error rate.
demucs : bool or torch.nn.Module, default False
    Whether to preprocess ``audio`` with Demucs to isolate vocals / remove noise. Set ``demucs`` to an instance of
    a Demucs model to avoid reloading the model for each run.
    Demucs must be installed to use. Official repo, https://github.com/facebookresearch/demucs.
demucs_output : str, optional
    Path to save the vocals isolated by Demucs as WAV file. Ignored if ``demucs = False``.
    Demucs must be installed to use. Official repo, https://github.com/facebookresearch/demucs.
demucs_options : dict, optional
    Options to use for :func:`stable_whisper.audio.demucs_audio`.
vad : bool, default False
    Whether to use Silero VAD to generate timestamp suppression mask.
    Silero VAD requires PyTorch 1.12.0+. Official repo, https://github.com/snakers4/silero-vad.
vad_threshold : float, default 0.35
    Threshold for detecting speech with Silero VAD. Low threshold reduces false positives for silence detection.
vad_onnx : bool, default False
    Whether to use ONNX for Silero VAD.
min_word_dur : float, default 0.1
    Shortest duration each word is allowed to reach for silence suppression.
nonspeech_error : float, default 0.3
    Relative error of non-speech sections that appear in between a word for silence suppression.
only_voice_freq : bool, default False
    Whether to only use sound between 200 - 5000 Hz, where the majority of human speech is.
prepend_punctuations : str, default '"\'“¿([{-)'
    Punctuations to prepend to next word.
append_punctuations : str, default '.。,,!!??::”)]}、)'
    Punctuations to append to previous word.
mel_first : bool, default False
    Process entire audio track into log-Mel spectrogram first instead of in chunks.
    Used if odd behavior is seen in stable-ts but not in whisper, but uses significantly more memory for long audio.
split_callback : Callable, optional
    Custom callback for grouping tokens up with their corresponding words.
    The callback must take two arguments, list of tokens and tokenizer.
    The callback returns a tuple with a list of words and a corresponding nested list of tokens.
suppress_ts_tokens : bool, default False
    Whether to suppress timestamp tokens during inference for timestamps that are detected at silence.
    Reduces hallucinations in some cases, but also prone to ignore disfluencies and repetitions.
    This option is ignored if ``suppress_silence = False``.
gap_padding : str, default ' ...'
    Padding prepended to each segment for word timing alignment.
    Used to reduce the probability of model predicting timestamps earlier than the first utterance.
only_ffmpeg : bool, default False
    Whether to use only FFmpeg (instead of yt-dlp) for URLs.
max_instant_words : float, default 0.5
    If percentage of instantaneous words in a segment exceeds this amount, the segment is removed.
avg_prob_threshold: float or None, default None
    Transcribe the gap after the previous word and if the average word probability of a segment falls below this
    value, discard the segment. If ``None``, skip transcribing the gap to reduce chance of timestamps starting
    before the next utterance.
progress_callback : Callable, optional
    A function that will be called when transcription progress is updated.
    The callback needs two parameters.
    The first parameter is a float for seconds of the audio that has been transcribed.
    The second parameter is a float for total duration of audio in seconds.
ignore_compatibility : bool, default False
    Whether to ignore warnings for compatibility issues with the detected Whisper version.
decode_options
    Keyword arguments to construct :class:`whisper.decode.DecodingOptions` instances.

Returns
-------
stable_whisper.result.WhisperResult
    All timestamps, words, probabilities, and other data from the transcription of ``audio``.

See Also
--------
stable_whisper.non_whisper.transcribe_any : Return :class:`stable_whisper.result.WhisperResult` containing all the
    data from transcribing audio with unmodified :func:`whisper.transcribe.transcribe` with preprocessing and
    postprocessing.
stable_whisper.whisper_word_level.load_faster_whisper.faster_transcribe : Return
    :class:`stable_whisper.result.WhisperResult` containing all the data from transcribing audio with
    :meth:`faster_whisper.WhisperModel.transcribe` with preprocessing and postprocessing.

Examples
--------
>>> import stable_whisper
>>> model = stable_whisper.load_model('base')
>>> result = model.transcribe('audio.mp3', vad=True)
>>> result.to_srt_vtt('audio.srt')
Saved: audio.srt

</details>
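As a sketch of how several of the options above combine (Demucs and Silero VAD must be installed for `demucs=True`/`vad=True`; the callback follows the `progress_callback` description):

```python
import stable_whisper

def on_progress(seconds_done: float, total_seconds: float):
    # progress_callback receives seconds transcribed and total audio duration
    print(f'{seconds_done:.1f}/{total_seconds:.1f} s')

model = stable_whisper.load_model('base')
result = model.transcribe(
    'audio.mp3',
    vad=True,                      # Silero VAD for the suppression mask
    demucs=True,                   # isolate vocals with Demucs before transcribing
    only_voice_freq=True,          # keep only 200-5000 Hz
    progress_callback=on_progress,
)
result.to_srt_vtt('audio.srt')
```
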

<details>
<summary>transcribe_minimal()</summary>

Transcribe audio using Whisper.

This uses the original whisper transcribe function, :func:`whisper.transcribe.transcribe`, while still allowing
additional preprocessing and postprocessing. The preprocessing performed on the audio includes: isolating voice /
removing noise with Demucs and low/high-pass filter. The postprocessing performed on the transcription
result includes: adjusting timestamps with VAD and custom regrouping of segments based on punctuation and speech gaps.

Parameters
----------
model : whisper.model.Whisper
    An instance of Whisper ASR model.
audio : str or numpy.ndarray or torch.Tensor or bytes
    Path/URL to the audio file, the audio waveform, or bytes of audio file.
    If audio is ``numpy.ndarray`` or ``torch.Tensor``, the audio must be already sampled at 16kHz.
verbose : bool or None, default False
    Whether to display the text being decoded to the console.
    Displays all the details if ``True``. Displays progressbar if ``False``. Displays nothing if ``None``.
word_timestamps : bool, default True
    Extract word-level timestamps using the cross-attention pattern and dynamic time warping,
    and include the timestamps for each word in each segment.
    Disabling this will prevent segments from splitting/merging properly.
regroup : bool or str, default True, meaning the default regroup algorithm
    String for customizing the regrouping algorithm. False disables regrouping.
    Ignored if ``word_timestamps = False``.
suppress_silence : bool, default True
    Whether to enable timestamps adjustments based on the detected silence.
suppress_word_ts : bool, default True
    Whether to adjust word timestamps based on the detected silence. Only enabled if ``suppress_silence = True``.
use_word_position : bool, default True
    Whether to use position of the word in its segment to determine whether to keep end or start timestamps if
    adjustments are required. If it is the first word, keep end. Else if it is the last word, keep the start.
q_levels : int, default 20
    Quantization levels for generating timestamp suppression mask; ignored if ``vad = true``.
    Acts as a threshold to mark sound as silent.
    Fewer levels will increase the threshold of volume at which to mark a sound as silent.
k_size : int, default 5
    Kernel size for avg-pooling waveform to generate timestamp suppression mask; ignored if ``vad = true``.
    Recommend 5 or 3; higher sizes will reduce detection of silence.
demucs : bool or torch.nn.Module, default False
    Whether to preprocess ``audio`` with Demucs to isolate vocals / remove noise. Set ``demucs`` to an instance of
    a Demucs model to avoid reloading the model for each run.
    Demucs must be installed to use. Official repo, https://github.com/facebookresearch/demucs.
demucs_output : str, optional
    Path to save the vocals isolated by Demucs as WAV file. Ignored if ``demucs = False``.
    Demucs must be installed to use. Official repo, https://github.com/facebookresearch/demucs.
demucs_options : dict, optional
    Options to use for :func:`stable_whisper.audio.demucs_audio`.
vad : bool, default False
    Whether to use Silero VAD to generate timestamp suppression mask.
    Silero VAD requires PyTorch 1.12.0+. Official repo, https://github.com/snakers4/silero-vad.
vad_threshold : float, default 0.35
    Threshold for detecting speech with Silero VAD. Low threshold reduces false positives for silence detection.
vad_onnx : bool, default False
    Whether to use ONNX for Silero VAD.
min_word_dur : float, default 0.1
    Shortest duration each word is allowed to reach for silence suppression.
nonspeech_error : float, default 0.3
    Relative error of non-speech sections that appear in between a word for silence suppression.
only_voice_freq : bool, default False
    Whether to only use sound between 200 - 5000 Hz, where the majority of human speech is.
only_ffmpeg : bool, default False
    Whether to use only FFmpeg (instead of yt-dlp) for URLs.
options
    Additional options used for :func:`whisper.transcribe.transcribe` and
    :func:`stable_whisper.non_whisper.transcribe_any`.

Returns
-------
stable_whisper.result.WhisperResult
    All timestamps, words, probabilities, and other data from the transcription of ``audio``.

Examples
--------
>>> import stable_whisper
>>> model = stable_whisper.load_model('base')
>>> result = model.transcribe_minimal('audio.mp3', vad=True)
>>> result.to_srt_vtt('audio.srt')
Saved: audio.srt

</details>

<br>
<details>
<summary>faster-whisper</summary>

Use with [faster-whisper](https://github.com/guillaumekln/faster-whisper):
```python
model = stable_whisper.load_faster_whisper('base')
result = model.transcribe_stable('audio.mp3')
```
```commandline
stable-ts audio.mp3 -o audio.srt -fw
```
Docstring:
<details>
<summary>load_faster_whisper()</summary>

Load an instance of :class:`faster_whisper.WhisperModel`.

Parameters
----------
model_size_or_path : {'tiny', 'tiny.en', 'base', 'base.en', 'small', 'small.en', 'medium', 'medium.en', 'large-v1',
    'large-v2', 'large-v3', or 'large'}
    Size of the model.

model_init_options
    Additional options to use for initialization of :class:`faster_whisper.WhisperModel`.

Returns
-------
faster_whisper.WhisperModel
    A modified instance with :func:`stable_whisper.whisper_word_level.load_faster_whisper.faster_transcribe`
    assigned to :meth:`faster_whisper.WhisperModel.transcribe_stable`.

</details>

<details>
<summary>transcribe_stable()</summary>

Transcribe audio using faster-whisper (https://github.com/guillaumekln/faster-whisper).

This uses the transcribe method from faster-whisper, :meth:`faster_whisper.WhisperModel.transcribe`, while
still allowing additional preprocessing and postprocessing. The preprocessing performed on the audio includes:
isolating voice / removing noise with Demucs and low/high-pass filter. The postprocessing performed on the
transcription result includes: adjusting timestamps with VAD and custom regrouping of segments based on punctuation
and speech gaps.

Parameters
----------
model : faster_whisper.WhisperModel
    The faster-whisper ASR model instance.
audio : str or numpy.ndarray or torch.Tensor or bytes
    Path/URL to the audio file, the audio waveform, or bytes of audio file.
    If audio is :class:`numpy.ndarray` or :class:`torch.Tensor`, the audio must be already sampled at 16kHz.
verbose : bool or None, default False
    Whether to display the text being decoded to the console.
    Displays all the details if ``True``. Displays progressbar if ``False``. Displays nothing if ``None``.
word_timestamps : bool, default True
    Extract word-level timestamps using the cross-attention pattern and dynamic time warping,
    and include the timestamps for each word in each segment.
    Disabling this will prevent segments from splitting/merging properly.
regroup : bool or str, default True, meaning the default regroup algorithm
    String for customizing the regrouping algorithm. False disables regrouping.
    Ignored if ``word_timestamps = False``.
suppress_silence : bool, default True
    Whether to enable timestamps adjustments based on the detected silence.
suppress_word_ts : bool, default True
    Whether to adjust word timestamps based on the detected silence. Only enabled if ``suppress_silence = True``.
use_word_position : bool, default True
    Whether to use position of the word in its segment to determine whether to keep end or start timestamps if
    adjustments are required. If it is the first word, keep end. Else if it is the last word, keep the start.
q_levels : int, default 20
    Quantization levels for generating timestamp suppression mask; ignored if ``vad = true``.
    Acts as a threshold to mark sound as silent.
    Fewer levels will increase the threshold of volume at which to mark a sound as silent.
k_size : int, default 5
    Kernel size for avg-pooling waveform to generate timestamp suppression mask; ignored if ``vad = true``.
    Recommend 5 or 3; higher sizes will reduce detection of silence.
demucs : bool or torch.nn.Module, default False
    Whether to preprocess ``audio`` with Demucs to isolate vocals / remove noise. Set ``demucs`` to an instance
    of a Demucs model to avoid reloading the model for each run.
    Demucs must be installed to use. Official repo, https://github.com/facebookresearch/demucs.
demucs_output : str, optional
    Path to save the vocals isolated by Demucs as WAV file. Ignored if ``demucs = False``.
    Demucs must be installed to use. Official repo, https://github.com/facebookresearch/demucs.
demucs_options : dict, optional
    Options to use for :func:`stable_whisper.audio.demucs_audio`.
vad : bool, default False
    Whether to use Silero VAD to generate timestamp suppression mask.
    Silero VAD requires PyTorch 1.12.0+. Official repo, https://github.com/snakers4/silero-vad.
vad_threshold : float, default 0.35
    Threshold for detecting speech with Silero VAD. Low threshold reduces false positives for silence detection.
vad_onnx : bool, default False
    Whether to use ONNX for Silero VAD.
min_word_dur : float, default 0.1
    Shortest duration each word is allowed to reach for silence suppression.
nonspeech_error : float, default 0.3
    Relative error of non-speech sections that appear in between a word for silence suppression.
only_voice_freq : bool, default False
    Whether to only use sound between 200 - 5000 Hz, where the majority of human speech is.
only_ffmpeg : bool, default False
    Whether to use only FFmpeg (instead of yt-dlp) for URLs.
check_sorted : bool, default True
    Whether to raise an error when timestamps returned by faster-whisper are not in ascending order.
progress_callback : Callable, optional
    A function that will be called when transcription progress is updated.
    The callback needs two parameters.
    The first parameter is a float for seconds of the audio that has been transcribed.
    The second parameter is a float for total duration of audio in seconds.
options
    Additional options used for :meth:`faster_whisper.WhisperModel.transcribe` and
    :func:`stable_whisper.non_whisper.transcribe_any`.

Returns
-------
stable_whisper.result.WhisperResult
    All timestamps, words, probabilities, and other data from the transcription of ``audio``.

Examples
--------
>>> import stable_whisper
>>> model = stable_whisper.load_faster_whisper('base')
>>> result = model.transcribe_stable('audio.mp3', vad=True)
>>> result.to_srt_vtt('audio.srt')
Saved: audio.srt

</details>

</details>

### Output
Stable-ts supports various text output formats.
```python
result.to_srt_vtt('audio.srt') #SRT
result.to_srt_vtt('audio.vtt') #VTT
result.to_ass('audio.ass') #ASS
result.to_tsv('audio.tsv') #TSV
```
Docstrings:
<details>
<summary>result_to_srt_vtt()</summary>

Generate SRT/VTT from ``result`` to display segment-level and/or word-level timestamps.

Parameters
----------
result : dict or list or stable_whisper.result.WhisperResult
    Result of transcription.
filepath : str, default None, meaning content will be returned as a ``str``
    Path to save file.
segment_level : bool, default True
    Whether to use segment-level timestamps in output.
word_level : bool, default True
    Whether to use word-level timestamps in output.
min_dur : float, default 0.2
    Minimum duration allowed for any word/segment before the word/segments are merged with adjacent word/segments.
tag: tuple of (str, str), default None, meaning ('<font color="#00ff00">', '</font>') if SRT else ('<u>', '</u>')
    Tag used to change the properties of a word at its timestamp.
vtt : bool, default None, meaning determined by extension of ``filepath`` or ``False`` if no valid extension.
    Whether to output VTT.
strip : bool, default True
    Whether to remove spaces before and after text on each segment for output.
reverse_text: bool or tuple, default False
    Whether to reverse the order of words for each segment or provide the ``prepend_punctuations`` and
    ``append_punctuations`` as tuple pair instead of ``True`` which is for the default punctuations.

Returns
-------
str
    String of the content if ``filepath`` is ``None``.

Notes
-----
``reverse_text`` will not fix RTL text not displaying tags properly, which is an issue with some video players. VLC
seems to not suffer from this issue.

Examples
--------
>>> import stable_whisper
>>> model = stable_whisper.load_model('base')
>>> result = model.transcribe('audio.mp3')
>>> result.to_srt_vtt('audio.srt')
Saved: audio.srt

</details>
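For example, a small sketch of word-level-only output with a custom tag, using the parameters described above:

```python
# Word-level-only SRT; `tag` wraps the word active at each timestamp.
result.to_srt_vtt('audio.srt', segment_level=False, word_level=True, tag=('<b>', '</b>'))
```
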

<details>
<summary>result_to_ass()</summary>

Generate Advanced SubStation Alpha (ASS) file from ``result`` to display segment-level and/or word-level timestamps.

Parameters
----------
result : dict or list or stable_whisper.result.WhisperResult
    Result of transcription.
filepath : str, default None, meaning content will be returned as a ``str``
    Path to save file.
segment_level : bool, default True
    Whether to use segment-level timestamps in output.
word_level : bool, default True
    Whether to use word-level timestamps in output.
min_dur : float, default 0.2
    Minimum duration allowed for any word/segment before the word/segments are merged with adjacent word/segments.
tag: tuple of (str, str) or int, default None, meaning use default highlighting
    Tag used to change the properties of a word at its timestamp. -1 for individual word highlight tag.
font : str, default `Arial`
    Word font.
font_size : int, default 48
    Word font size.
strip : bool, default True
    Whether to remove spaces before and after text on each segment for output.
highlight_color : str, default '00ff00'
    Hexadecimal of the color to use for default highlights as '<bb><gg><rr>'.
karaoke : bool, default False
    Whether to use progressive filling highlights (for karaoke effect).
reverse_text: bool or tuple, default False
    Whether to reverse the order of words for each segment or provide the ``prepend_punctuations`` and
    ``append_punctuations`` as tuple pair instead of ``True`` which is for the default punctuations.
kwargs:
    Format styles:
    'Name', 'Fontname', 'Fontsize', 'PrimaryColour', 'SecondaryColour', 'OutlineColour', 'BackColour', 'Bold',
    'Italic', 'Underline', 'StrikeOut', 'ScaleX', 'ScaleY', 'Spacing', 'Angle', 'BorderStyle', 'Outline',
    'Shadow', 'Alignment', 'MarginL', 'MarginR', 'MarginV', 'Encoding'

Returns
-------
str
    String of the content if ``filepath`` is ``None``.

Notes
-----
``reverse_text`` will not fix RTL text not displaying tags properly, which is an issue with some video players. VLC
seems to not suffer from this issue.

Examples
--------
>>> import stable_whisper
>>> model = stable_whisper.load_model('base')
>>> result = model.transcribe('audio.mp3')
>>> result.to_ass('audio.ass')
Saved: audio.ass

</details>
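A small sketch of the karaoke-style highlighting described above (the color is hexadecimal '<bb><gg><rr>'):

```python
# Progressive (karaoke) highlighting with a custom highlight color
result.to_ass('audio.ass', karaoke=True, highlight_color='ff00ff')
```
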
564 |
+
|
565 |
+
<details>
|
566 |
+
<summary>result_to_tsv()</summary>
|
567 |
+
|
568 |
+
Generate TSV from ``result`` to display segment-level and/or word-level timestamp.
|
569 |
+
|
570 |
+
Parameters
|
571 |
+
----------
|
572 |
+
result : dict or list or stable_whisper.result.WhisperResult
|
573 |
+
Result of transcription.
|
574 |
+
filepath : str, default None, meaning content will be returned as a ``str``
|
575 |
+
Path to save file.
|
576 |
+
segment_level : bool, default True
|
577 |
+
Whether to use segment-level timestamps in output.
|
578 |
+
word_level : bool, default True
|
579 |
+
Whether to use word-level timestamps in output.
|
580 |
+
min_dur : float, default 0.2
|
581 |
+
Minimum duration allowed for any word/segment before the word/segments are merged with adjacent word/segments.
|
582 |
+
strip : bool, default True
|
583 |
+
Whether to remove spaces before and after text on each segment for output.
|
584 |
+
reverse_text: bool or tuple, default False
|
585 |
+
Whether to reverse the order of words for each segment or provide the ``prepend_punctuations`` and
|
586 |
+
``append_punctuations`` as tuple pair instead of ``True`` which is for the default punctuations.
|
587 |
+
|
588 |
+
Returns
|
589 |
+
-------
|
590 |
+
str
|
591 |
+
String of the content if ``filepath`` is ``None``.
|
592 |
+
|
593 |
+
Notes
|
594 |
+
-----
|
595 |
+
``reverse_text`` will not fix RTL text not displaying tags properly which is an issue with some video player. VLC
|
596 |
+
seems to not suffer from this issue.
|
597 |
+
|
598 |
+
Examples
|
599 |
+
--------
|
600 |
+
>>> import stable_whisper
|
601 |
+
>>> model = stable_whisper.load_model('base')
|
602 |
+
>>> result = model.transcribe('audio.mp3')
|
603 |
+
>>> result.to_tsv('audio.tsv')
|
604 |
+
Saved: audio.tsv
|
605 |
+
|
606 |
+
</details>
|
607 |
+
|
608 |
+
<details>
|
609 |
+
<summary>result_to_txt()</summary>
|
610 |
+
|
611 |
+
Generate plain-text without timestamps from ``result``.
|
612 |
+
|
613 |
+
Parameters
|
614 |
+
----------
|
615 |
+
result : dict or list or stable_whisper.result.WhisperResult
|
616 |
+
Result of transcription.
|
617 |
+
filepath : str, default None, meaning content will be returned as a ``str``
|
618 |
+
Path to save file.
|
619 |
+
min_dur : float, default 0.2
|
620 |
+
Minimum duration allowed for any word/segment before the word/segments are merged with adjacent word/segments.
|
621 |
+
strip : bool, default True
|
622 |
+
Whether to remove spaces before and after text on each segment for output.
|
623 |
+
reverse_text: bool or tuple, default False
|
624 |
+
Whether to reverse the order of words for each segment or provide the ``prepend_punctuations`` and
|
625 |
+
``append_punctuations`` as tuple pair instead of ``True`` which is for the default punctuations.
|
626 |
+
|
627 |
+
Returns
|
628 |
+
-------
|
629 |
+
str
|
630 |
+
String of the content if ``filepath`` is ``None``.
|
631 |
+
|
632 |
+
Notes
|
633 |
+
-----
|
634 |
+
``reverse_text`` will not fix RTL text not displaying tags properly which is an issue with some video player. VLC
|
635 |
+
seems to not suffer from this issue.
|
636 |
+
|
637 |
+
Examples
|
638 |
+
--------
|
639 |
+
>>> import stable_whisper
|
640 |
+
>>> model = stable_whisper.load_model('base')
|
641 |
+
>>> result = model.transcribe('audio.mp3')
|
642 |
+
>>> result.to_txt('audio.txt')
|
643 |
+
Saved: audio.txt
|
644 |
+
|
645 |
+
</details>
|
646 |
+
|
647 |
+
<details>
|
648 |
+
<summary>save_as_json()</summary>
|
649 |
+
|
650 |
+
Save ``result`` as JSON file to ``path``.
|
651 |
+
|
652 |
+
Parameters
|
653 |
+
----------
|
654 |
+
result : dict or list or stable_whisper.result.WhisperResult
|
655 |
+
Result of transcription.
|
656 |
+
path : str
|
657 |
+
Path to save file.
|
658 |
+
ensure_ascii : bool, default False
|
659 |
+
Whether to escape non-ASCII characters.
|
660 |
+
|
661 |
+
Examples
|
662 |
+
--------
|
663 |
+
>>> import stable_whisper
|
664 |
+
>>> model = stable_whisper.load_model('base')
|
665 |
+
>>> result = model.transcribe('audio.mp3')
|
666 |
+
>>> result.save_as_json('audio.json')
|
667 |
+
Saved: audio.json
|
668 |
+
|
669 |
+
</details>
|
670 |
+
|
671 |
+
<br /><br />
|
672 |
+
There are word-level and segment-level timestamps. All output formats support them.
|
673 |
+
They also support will both levels simultaneously except TSV.
|
674 |
+
By default, `segment_level` and `word_level` are both `True` for all the formats that support both simultaneously.<br /><br />
|
675 |
+
Examples in VTT.
|
676 |
+
|
677 |
+
Default: `segment_level=True` + `word_level=True`
|
678 |
+
<details>
|
679 |
+
<summary>CLI</summary>
|
680 |
+
|
681 |
+
`--segment_level true` + `--word_level true`
|
682 |
+
|
683 |
+
</details>
|
684 |
+
|
685 |
+
```
|
686 |
+
00:00:07.760 --> 00:00:09.900
|
687 |
+
But<00:00:07.860> when<00:00:08.040> you<00:00:08.280> arrived<00:00:08.580> at<00:00:08.800> that<00:00:09.000> distant<00:00:09.400> world,
|
688 |
+
```
|
689 |
+
|
690 |
+
`segment_level=True` + `word_level=False`
|
691 |
+
```
|
692 |
+
00:00:07.760 --> 00:00:09.900
|
693 |
+
But when you arrived at that distant world,
|
694 |
+
```
|
695 |
+
|
696 |
+
`segment_level=False` + `word_level=True`
|
697 |
+
```
|
698 |
+
00:00:07.760 --> 00:00:07.860
|
699 |
+
But
|
700 |
+
|
701 |
+
00:00:07.860 --> 00:00:08.040
|
702 |
+
when
|
703 |
+
|
704 |
+
00:00:08.040 --> 00:00:08.280
|
705 |
+
you
|
706 |
+
|
707 |
+
00:00:08.280 --> 00:00:08.580
|
708 |
+
arrived
|
709 |
+
|
710 |
+
...
|
711 |
+
```
|
712 |
+
|
713 |
+
#### JSON
|
714 |
+
The result can also be saved as a JSON file to preserve all the data for future reprocessing.
|
715 |
+
This is useful for testing different sets of postprocessing arguments without the need to redo inference.
|
716 |
+
|
717 |
+
```python
|
718 |
+
result.save_as_json('audio.json')
|
719 |
+
```
|
720 |
+
<details>
|
721 |
+
<summary>CLI</summary>
|
722 |
+
|
723 |
+
```commandline
|
724 |
+
stable-ts audio.mp3 -o audio.json
|
725 |
+
```
|
726 |
+
|
727 |
+
</details>
|
728 |
+
|
729 |
+
Processing JSON file of the results into SRT.
|
730 |
+
```python
|
731 |
+
result = stable_whisper.WhisperResult('audio.json')
|
732 |
+
result.to_srt_vtt('audio.srt')
|
733 |
+
```
|
734 |
+
<details>
|
735 |
+
<summary>CLI</summary>
|
736 |
+
|
737 |
+
```commandline
|
738 |
+
stable-ts audio.json -o audio.srt
|
739 |
+
```
|
740 |
+
|
741 |
+
</details>
|
742 |
+
|
743 |
+
### Alignment
|
744 |
+
Audio can be aligned/synced with plain text on word-level.
|
745 |
+
```python
|
746 |
+
text = 'Machines thinking, breeding. You were to bear us a new, promised land.'
|
747 |
+
result = model.align('audio.mp3', text, language='en')
|
748 |
+
```
|
749 |
+
When the text is correct but the timestamps need more work,
|
750 |
+
`align()` is a faster alternative for testing various settings/models.
|
751 |
+
```python
|
752 |
+
new_result = model.align('audio.mp3', result, language='en')
|
753 |
+
```
|
754 |
+
<details>
|
755 |
+
<summary>CLI</summary>
|
756 |
+
|
757 |
+
```commandline
|
758 |
+
stable-ts audio.mp3 --align text.txt --language en
|
759 |
+
```
|
760 |
+
`--align` can also a JSON file of a result
|
761 |
+
|
762 |
+
</details>
|
763 |
+
|
764 |
+
Docstring:
|
765 |
+
<details>
|
766 |
+
<summary>align()</summary>
|
767 |
+
|
768 |
+
Align plain text or tokens with audio at word-level.
|
769 |
+
|
770 |
+
Since this is significantly faster than transcribing, it is a more efficient method for testing various settings
|
771 |
+
without re-transcribing. This is also useful for timing a more correct transcript than one that Whisper can produce.
|
772 |
+
|
773 |
+
Parameters
|
774 |
+
----------
|
775 |
+
model : "Whisper"
|
776 |
+
The Whisper ASR model modified instance
|
777 |
+
audio : str or numpy.ndarray or torch.Tensor or bytes
|
778 |
+
Path/URL to the audio file, the audio waveform, or bytes of audio file.
|
779 |
+
If audio is :class:`numpy.ndarray` or :class:`torch.Tensor`, the audio must be already at sampled to 16kHz.
|
780 |
+
text : str or list of int or stable_whisper.result.WhisperResult
|
781 |
+
String of plain-text, list of tokens, or instance of :class:`stable_whisper.result.WhisperResult`.
|
782 |
+
language : str, default None, uses ``language`` in ``text`` if it is a :class:`stable_whisper.result.WhisperResult`
|
783 |
+
Language of ``text``. Required if ``text`` does not contain ``language``.
|
784 |
+
remove_instant_words : bool, default False
|
785 |
+
Whether to truncate any words with zero duration.
|
786 |
+
token_step : int, default 100
|
787 |
+
Max number of tokens to align each pass. Use higher values to reduce chance of misalignment.
|
788 |
+
original_split : bool, default False
|
789 |
+
Whether to preserve the original segment groupings. Segments are spit by line break if ``text`` is plain-text.
|
790 |
+
max_word_dur : float or None, default 3.0
|
791 |
+
Global maximum word duration in seconds. Re-align words that exceed the global maximum word duration.
|
792 |
+
word_dur_factor : float or None, default 2.0
|
793 |
+
Factor to compute the Local maximum word duration, which is ``word_dur_factor`` * local medium word duration.
|
794 |
+
Words that need re-alignment, are re-algined with duration <= local/global maximum word duration.
|
795 |
+
nonspeech_skip : float or None, default 3.0
|
796 |
+
Skip non-speech sections that are equal or longer than this duration in seconds. Disable skipping if ``None``.
|
797 |
+
fast_mode : bool, default False
|
798 |
+
Whether to speed up alignment by re-alignment with local/global maximum word duration.
|
799 |
+
``True`` tends produce better timestamps when ``text`` is accurate and there are no large speechless gaps.
|
800 |
+
tokenizer : "Tokenizer", default None, meaning a new tokenizer is created according ``language`` and ``model``
|
801 |
+
A tokenizer to used tokenizer text and detokenize tokens.
|
802 |
+
verbose : bool or None, default False
|
803 |
+
Whether to display the text being decoded to the console.
|
804 |
+
Displays all the details if ``True``. Displays progressbar if ``False``. Display nothing if ``None``.
|
805 |
+
regroup : bool or str, default True, meaning the default regroup algorithm
|
806 |
+
String for customizing the regrouping algorithm. False disables regrouping.
|
807 |
+
Ignored if ``word_timestamps = False``.
|
808 |
+
suppress_silence : bool, default True
|
809 |
+
Whether to enable timestamps adjustments based on the detected silence.
|
810 |
+
suppress_word_ts : bool, default True
|
811 |
+
Whether to adjust word timestamps based on the detected silence. Only enabled if ``suppress_silence = True``.
|
812 |
+
use_word_position : bool, default True
|
813 |
+
Whether to use position of the word in its segment to determine whether to keep end or start timestamps if
|
814 |
+
adjustments are required. If it is the first word, keep end. Else if it is the last word, keep the start.
|
815 |
+
q_levels : int, default 20
|
816 |
+
Quantization levels for generating timestamp suppression mask; ignored if ``vad = true``.
|
817 |
+
Acts as a threshold to marking sound as silent.
|
818 |
+
Fewer levels will increase the threshold of volume at which to mark a sound as silent.
|
819 |
+
k_size : int, default 5
|
820 |
+
Kernel size for avg-pooling waveform to generate timestamp suppression mask; ignored if ``vad = true``.
|
821 |
+
Recommend 5 or 3; higher sizes will reduce detection of silence.
|
822 |
+
demucs : bool or torch.nn.Module, default False
|
823 |
+
Whether to preprocess ``audio`` with Demucs to isolate vocals / remove noise. Set ``demucs`` to an instance of
|
824 |
+
a Demucs model to avoid reloading the model for each run.
|
825 |
+
Demucs must be installed to use. Official repo, https://github.com/facebookresearch/demucs.
|
826 |
+
demucs_output : str, optional
|
827 |
+
Path to save the vocals isolated by Demucs as WAV file. Ignored if ``demucs = False``.
|
828 |
+
Demucs must be installed to use. Official repo, https://github.com/facebookresearch/demucs.
|
829 |
+
demucs_options : dict, optional
|
830 |
+
Options to use for :func:`stable_whisper.audio.demucs_audio`.
|
831 |
+
vad : bool, default False
|
832 |
+
Whether to use Silero VAD to generate timestamp suppression mask.
|
833 |
+
Silero VAD requires PyTorch 1.12.0+. Official repo, https://github.com/snakers4/silero-vad.
|
834 |
+
vad_threshold : float, default 0.35
|
835 |
+
Threshold for detecting speech with Silero VAD. Low threshold reduces false positives for silence detection.
|
836 |
+
vad_onnx : bool, default False
|
837 |
+
Whether to use ONNX for Silero VAD.
|
838 |
+
min_word_dur : float, default 0.1
|
839 |
+
Shortest duration each word is allowed to reach for silence suppression.
|
840 |
+
nonspeech_error : float, default 0.3
|
841 |
+
Relative error of non-speech sections that appear in between a word for silence suppression.
|
842 |
+
only_voice_freq : bool, default False
|
843 |
+
Whether to only use sound between 200 - 5000 Hz, where the majority of human speech is.
|
844 |
+
prepend_punctuations : str, default '"'“¿([{-)'
|
845 |
+
Punctuations to prepend to next word.
|
846 |
+
append_punctuations : str, default '.。,,!!??::”)]}、)'
|
847 |
+
Punctuations to append to previous word.
|
848 |
+
progress_callback : Callable, optional
|
849 |
+
A function that will be called when transcription progress is updated.
|
850 |
+
The callback needs two parameters.
|
851 |
+
The first parameter is a float for seconds of the audio that has been transcribed.
|
852 |
+
The second parameter is a float for total duration of audio in seconds.
|
853 |
+
ignore_compatibility : bool, default False
|
854 |
+
Whether to ignore warnings for compatibility issues with the detected Whisper version.
|
855 |
+
|
856 |
+
Returns
|
857 |
+
-------
|
858 |
+
stable_whisper.result.WhisperResult or None
|
859 |
+
All timestamps, words, probabilities, and other data from the alignment of ``audio``. Return None if alignment
|
860 |
+
fails and ``remove_instant_words = True``.
|
861 |
+
|
862 |
+
Notes
|
863 |
+
-----
|
864 |
+
If ``token_step`` is less than 1, ``token_step`` will be set to its maximum value, 442. This value is computed with
|
865 |
+
``whisper.model.Whisper.dims.n_text_ctx`` - 6.
|
866 |
+
|
867 |
+
If ``original_split = True`` and a line break is found in the middle of a word in ``text``, the split will occur after
|
868 |
+
that word.
|
869 |
+
|
870 |
+
``regroup`` is ignored if ``original_split = True``.
|
871 |
+
|
872 |
+
Examples
|
873 |
+
--------
|
874 |
+
>>> import stable_whisper
|
875 |
+
>>> model = stable_whisper.load_model('base')
|
876 |
+
>>> result = model.align('helloworld.mp3', 'Hello, World!', 'English')
|
877 |
+
>>> result.to_srt_vtt('helloworld.srt')
|
878 |
+
Saved 'helloworld.srt'
|
879 |
+
|
880 |
+
</details>
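The ``progress_callback`` parameter above can be used to monitor alignment progress. A minimal sketch (the audio path and transcript text are placeholders):
```python
import stable_whisper

def show_progress(seconds_processed: float, total_seconds: float):
    # both values are in seconds of audio
    print(f'aligned {seconds_processed:.1f}s of {total_seconds:.1f}s')

model = stable_whisper.load_model('base')
result = model.align('audio.mp3', 'transcript of the audio goes here', 'English',
                     progress_callback=show_progress)
```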
|
881 |
+
|
882 |
+
#### Adjustments
|
883 |
+
Timestamps are adjusted after the model predicts them.
|
884 |
+
When `suppress_silence=True` (default), `transcribe()`/`transcribe_minimal()`/`align()` adjust based on silence/non-speech.
|
885 |
+
The timestamps can be further adjusted based on another result with `adjust_by_result()`,
|
886 |
+
which acts as a logical AND operation for the timestamps of both results, further reducing duration of each word.
|
887 |
+
Note: both results are required to have word timestamps and matching words.
|
888 |
+
```python
|
889 |
+
# the adjustments are in-place for `result`
|
890 |
+
result.adjust_by_result(new_result)
|
891 |
+
```
|
892 |
+
Docstring:
|
893 |
+
<details>
|
894 |
+
<summary>adjust_by_result()</summary>
|
895 |
+
|
896 |
+
Minimize the duration of words using timestamps of another result.
|
897 |
+
|
898 |
+
Parameters
|
899 |
+
----------
|
900 |
+
other_result : "WhisperResult"
|
901 |
+
Timing data of the same words in a WhisperResult instance.
|
902 |
+
min_word_dur : float, default 0.1
|
903 |
+
Prevent changes to timestamps if the resultant word duration is less than ``min_word_dur``.
|
904 |
+
verbose : bool, default False
|
905 |
+
Whether to print out the timestamp changes.
|
906 |
+
|
907 |
+
</details>
|
908 |
+
|
909 |
+
### Refinement
|
910 |
+
Timestamps can be further improved with `refine()`.
|
911 |
+
This method iteratively mutes portions of the audio based on current timestamps
|
912 |
+
then computes the probabilities of the tokens.
|
913 |
+
Then by monitoring the fluctuation of the probabilities, it tries to find the most precise timestamps.
|
914 |
+
"Most precise" in this case means the latest start and earliest end for the word
|
915 |
+
such that it still meets the specified conditions.
|
916 |
+
```python
|
917 |
+
model.refine('audio.mp3', result)
|
918 |
+
```
|
919 |
+
<details>
|
920 |
+
<summary>CLI</summary>
|
921 |
+
|
922 |
+
```commandline
|
923 |
+
stable-ts audio.mp3 --refine -o audio.srt
|
924 |
+
```
|
925 |
+
Input can also be JSON file of a result.
|
926 |
+
```commandline
|
927 |
+
stable-ts result.json --refine -o audio.srt --refine_option "audio=audio.mp3"
|
928 |
+
```
|
929 |
+
|
930 |
+
</details>
|
931 |
+
|
932 |
+
Docstring:
|
933 |
+
<details>
|
934 |
+
<summary>refine()</summary>
|
935 |
+
|
936 |
+
Improve existing timestamps.
|
937 |
+
|
938 |
+
This function iteratively mutes portions of the audio and monitors token probabilities to find the most precise
|
939 |
+
timestamps. This "most precise" in this case means the latest start and earliest end of a word that maintains an
|
940 |
+
acceptable probability determined by the specified arguments.
|
941 |
+
|
942 |
+
This is useful for readjusting timestamps when they start too early or end too late.
|
943 |
+
|
944 |
+
Parameters
|
945 |
+
----------
|
946 |
+
model : "Whisper"
|
947 |
+
The modified instance of the Whisper ASR model.
|
948 |
+
audio : str or numpy.ndarray or torch.Tensor or bytes
|
949 |
+
Path/URL to the audio file, the audio waveform, or bytes of audio file.
|
950 |
+
If audio is :class:`numpy.ndarray` or :class:`torch.Tensor`, the audio must already be sampled at 16kHz.
|
951 |
+
result : stable_whisper.result.WhisperResult
|
952 |
+
All timestamps, words, probabilities, and other data from the transcription of ``audio``.
|
953 |
+
steps : str, default 'se'
|
954 |
+
Instructions for refinement. An 's' means refine start-timestamps. An 'e' means refine end-timestamps.
|
955 |
+
rel_prob_decrease : float, default 0.3
|
956 |
+
Maximum percent decrease in probability relative to original probability which is the probability from muting
|
957 |
+
according to the initial timestamps.
|
958 |
+
abs_prob_decrease : float, default 0.05
|
959 |
+
Maximum decrease in probability from original probability.
|
960 |
+
rel_rel_prob_decrease : float, optional
|
961 |
+
Maximum percent decrease in probability relative to previous probability which is the probability from previous
|
962 |
+
iteration of muting.
|
963 |
+
prob_threshold : float, default 0.5
|
964 |
+
Stop refining the timestamp if the probability of its token goes below this value.
|
965 |
+
rel_dur_change : float, default 0.5
|
966 |
+
Maximum percent change in duration of a word relative to its original duration.
|
967 |
+
abs_dur_change : float, optional
|
968 |
+
Maximum seconds a word is allowed to deviate from its original duration.
|
969 |
+
word_level : bool, default True
|
970 |
+
Whether to refine timestamps on word-level. If ``False``, only refine start/end timestamps of each segment.
|
971 |
+
precision : float, default 0.1
|
972 |
+
Precision of refined timestamps in seconds. The lowest precision is 0.02 second.
|
973 |
+
single_batch : bool, default False
|
974 |
+
Whether to process in only batch size of one to reduce memory usage.
|
975 |
+
inplace : bool, default True, meaning return a deepcopy of ``result``
|
976 |
+
Whether to alter timestamps in-place.
|
977 |
+
demucs : bool or torch.nn.Module, default False
|
978 |
+
Whether to preprocess ``audio`` with Demucs to isolate vocals / remove noise. Set ``demucs`` to an instance of
|
979 |
+
a Demucs model to avoid reloading the model for each run.
|
980 |
+
Demucs must be installed to use. Official repo, https://github.com/facebookresearch/demucs.
|
981 |
+
demucs_options : dict, optional
|
982 |
+
Options to use for :func:`stable_whisper.audio.demucs_audio`.
|
983 |
+
only_voice_freq : bool, default False
|
984 |
+
Whether to only use sound between 200 - 5000 Hz, where the majority of human speech is.
|
985 |
+
verbose : bool or None, default False
|
986 |
+
Whether to display the text being decoded to the console.
|
987 |
+
Displays all the details if ``True``. Displays progressbar if ``False``. Displays nothing if ``None``.
|
988 |
+
|
989 |
+
Returns
|
990 |
+
-------
|
991 |
+
stable_whisper.result.WhisperResult
|
992 |
+
All timestamps, words, probabilities, and other data from the refinement of ``text`` with ``audio``.
|
993 |
+
|
994 |
+
Notes
|
995 |
+
-----
|
996 |
+
The lower the ``precision``, the longer the processing time.
|
997 |
+
|
998 |
+
Examples
|
999 |
+
--------
|
1000 |
+
>>> import stable_whisper
|
1001 |
+
>>> model = stable_whisper.load_model('base')
|
1002 |
+
>>> result = model.transcribe('audio.mp3')
|
1003 |
+
>>> model.refine('audio.mp3', result)
|
1004 |
+
>>> result.to_srt_vtt('audio.srt')
|
1005 |
+
Saved 'audio.srt'
|
1006 |
+
|
1007 |
+
</details>
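A rough sketch of tuning the refinement (the values below are illustrative, not recommendations):
```python
import stable_whisper

model = stable_whisper.load_model('base')
result = model.transcribe('audio.mp3')

# refine only the start timestamps ('s'), at a coarser 0.5s precision,
# and only at segment-level to reduce processing time
model.refine('audio.mp3', result, steps='s', precision=0.5, word_level=False)
result.to_srt_vtt('audio.srt')
```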
|
1008 |
+
|
1009 |
+
|
1010 |
+
### Regrouping Words
|
1011 |
+
Stable-ts has a preset for regrouping words into different segments with more natural boundaries.
|
1012 |
+
This preset is enabled by `regroup=True` (default).
|
1013 |
+
But there are other built-in [regrouping methods](#regrouping-methods) that allow you to customize the regrouping algorithm.
|
1014 |
+
This preset is just a predefined combination of those methods.
|
1015 |
+
|
1016 |
+
https://github.com/jianfch/stable-ts/assets/28970749/7b6164a3-50e2-4368-8b75-853cb14045ec
|
1017 |
+
|
1018 |
+
```python
|
1019 |
+
# The following results are all functionally equivalent:
|
1020 |
+
result0 = model.transcribe('audio.mp3', regroup=True) # regroup is True by default
|
1021 |
+
result1 = model.transcribe('audio.mp3', regroup=False)
|
1022 |
+
(
|
1023 |
+
result1
|
1024 |
+
.clamp_max()
|
1025 |
+
.split_by_punctuation([('.', ' '), '。', '?', '?', (',', ' '), ','])
|
1026 |
+
.split_by_gap(.5)
|
1027 |
+
.merge_by_gap(.3, max_words=3)
|
1028 |
+
.split_by_punctuation([('.', ' '), '。', '?', '?'])
|
1029 |
+
)
|
1030 |
+
result2 = model.transcribe('audio.mp3', regroup='cm_sp=.* /。/?/?/,* /,_sg=.5_mg=.3+3_sp=.* /。/?/?')
|
1031 |
+
|
1032 |
+
# To undo all regrouping operations:
|
1033 |
+
result0.reset()
|
1034 |
+
```
|
1035 |
+
Any regrouping algorithm can be expressed as a string. Please feel free to share your strings [here](https://github.com/jianfch/stable-ts/discussions/162).
|
1036 |
+
#### Regrouping Methods
|
1037 |
+
<details>
|
1038 |
+
<summary>regroup()</summary>
|
1039 |
+
|
1040 |
+
Regroup (in-place) words into segments.
|
1041 |
+
|
1042 |
+
Parameters
|
1043 |
+
----------
|
1044 |
+
regroup_algo: str or bool, default 'da'
|
1045 |
+
String representation of a custom regrouping algorithm or ``True`` to use the default algorithm 'da'.
|
1046 |
+
verbose : bool, default False
|
1047 |
+
Whether to show all the methods and arguments parsed from ``regroup_algo``.
|
1048 |
+
only_show : bool, default False
|
1049 |
+
Whether to show all the methods and arguments parsed from ``regroup_algo`` without running the methods.
|
1050 |
+
|
1051 |
+
Returns
|
1052 |
+
-------
|
1053 |
+
stable_whisper.result.WhisperResult
|
1054 |
+
The current instance after the changes.
|
1055 |
+
|
1056 |
+
Notes
|
1057 |
+
-----
|
1058 |
+
Syntax for string representation of custom regrouping algorithm.
|
1059 |
+
Method keys:
|
1060 |
+
sg: split_by_gap
|
1061 |
+
sp: split_by_punctuation
|
1062 |
+
sl: split_by_length
|
1063 |
+
sd: split_by_duration
|
1064 |
+
mg: merge_by_gap
|
1065 |
+
mp: merge_by_punctuation
|
1066 |
+
ms: merge_all_segment
|
1067 |
+
cm: clamp_max
|
1068 |
+
l: lock
|
1069 |
+
us: unlock_all_segments
|
1070 |
+
da: default algorithm (cm_sp=.* /。/?/?/,* /,_sg=.5_mg=.3+3_sp=.* /。/?/?)
|
1071 |
+
rw: remove_word
|
1072 |
+
rs: remove_segment
|
1073 |
+
rp: remove_repetition
|
1074 |
+
rws: remove_words_by_str
|
1075 |
+
fg: fill_in_gaps
|
1076 |
+
Metacharacters:
|
1077 |
+
= separates a method key and its arguments (not used if no argument)
|
1078 |
+
_ separates method keys (after arguments if there are any)
|
1079 |
+
+ separates arguments for a method key
|
1080 |
+
/ separates an argument into list of strings
|
1081 |
+
* separates an item in list of strings into a nested list of strings
|
1082 |
+
Notes:
|
1083 |
+
-arguments are parsed positionally
|
1084 |
+
-if no argument is provided, the default ones will be used
|
1085 |
+
-use 1 or 0 to represent True or False
|
1086 |
+
Example 1:
|
1087 |
+
merge_by_gap(.2, 10, lock=True)
|
1088 |
+
mg=.2+10+++1
|
1089 |
+
Note: [lock] is the 5th argument, hence the 2 missing arguments in between, represented by the three + before 1
|
1090 |
+
Example 2:
|
1091 |
+
split_by_punctuation([('.', ' '), '。', '?', '?'], True)
|
1092 |
+
sp=.* /。/?/?+1
|
1093 |
+
Example 3:
|
1094 |
+
merge_all_segments().split_by_gap(.5).merge_by_gap(.15, 3)
|
1095 |
+
ms_sg=.5_mg=.15+3
|
1096 |
+
|
1097 |
+
</details>
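For example, the string form and the chained method calls below should be equivalent (the thresholds are arbitrary):
```python
# split on gaps > 0.5s, then merge segments separated by < 0.15s (up to 3 words each)
result.regroup('sg=.5_mg=.15+3')

# the same operations expressed as chained methods
result.reset()
result.split_by_gap(0.5).merge_by_gap(0.15, max_words=3)
```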
|
1098 |
+
|
1099 |
+
<details>
|
1100 |
+
<summary>split_by_gap()</summary>
|
1101 |
+
|
1102 |
+
Split (in-place) any segment where the gap between two of its words is greater than ``max_gap``.
|
1103 |
+
|
1104 |
+
Parameters
|
1105 |
+
----------
|
1106 |
+
max_gap : float, default 0.1
|
1107 |
+
Maximum second(s) allowed between two words of the same segment.
|
1108 |
+
lock : bool, default False
|
1109 |
+
Whether to prevent future splits/merges from altering changes made by this method.
|
1110 |
+
newline: bool, default False
|
1111 |
+
Whether to insert line break at the split points instead of splitting into separate segments.
|
1112 |
+
|
1113 |
+
Returns
|
1114 |
+
-------
|
1115 |
+
stable_whisper.result.WhisperResult
|
1116 |
+
The current instance after the changes.
|
1117 |
+
|
1118 |
+
</details>
|
1119 |
+
|
1120 |
+
<details>
|
1121 |
+
<summary>split_by_punctuation()</summary>
|
1122 |
+
|
1123 |
+
Split (in-place) segments at words that start/end with ``punctuation``.
|
1124 |
+
|
1125 |
+
Parameters
|
1126 |
+
----------
|
1127 |
+
punctuation : list of str or list of tuple of (str, str) or str
|
1128 |
+
Punctuation(s) to split segments by.
|
1129 |
+
lock : bool, default False
|
1130 |
+
Whether to prevent future splits/merges from altering changes made by this method.
|
1131 |
+
newline : bool, default False
|
1132 |
+
Whether to insert line break at the split points instead of splitting into separate segments.
|
1133 |
+
min_words : int, optional
|
1134 |
+
Split segments with words >= ``min_words``.
|
1135 |
+
min_chars : int, optional
|
1136 |
+
Split segments with characters >= ``min_chars``.
|
1137 |
+
min_dur : int, optional
|
1138 |
+
Split segments with duration (in seconds) >= ``min_dur``.
|
1139 |
+
|
1140 |
+
Returns
|
1141 |
+
-------
|
1142 |
+
stable_whisper.result.WhisperResult
|
1143 |
+
The current instance after the changes.
|
1144 |
+
|
1145 |
+
</details>
|
1146 |
+
|
1147 |
+
<details>
|
1148 |
+
<summary>split_by_length()</summary>
|
1149 |
+
|
1150 |
+
Split (in-place) any segment that exceeds ``max_chars`` or ``max_words`` into smaller segments.
|
1151 |
+
|
1152 |
+
Parameters
|
1153 |
+
----------
|
1154 |
+
max_chars : int, optional
|
1155 |
+
Maximum number of characters allowed in each segment.
|
1156 |
+
max_words : int, optional
|
1157 |
+
Maximum number of words allowed in each segment.
|
1158 |
+
even_split : bool, default True
|
1159 |
+
Whether to evenly split a segment in length if it exceeds ``max_chars`` or ``max_words``.
|
1160 |
+
force_len : bool, default False
|
1161 |
+
Whether to force a constant length for each segment except the last segment.
|
1162 |
+
This will ignore all previous non-locked segment boundaries.
|
1163 |
+
lock : bool, default False
|
1164 |
+
Whether to prevent future splits/merges from altering changes made by this method.
|
1165 |
+
include_lock: bool, default False
|
1166 |
+
Whether to include previous lock before splitting based on max_words, if ``even_split = False``.
|
1167 |
+
Splitting will be done after the first non-locked word > ``max_chars`` / ``max_words``.
|
1168 |
+
newline: bool, default False
|
1169 |
+
Whether to insert line break at the split points instead of splitting into separate segments.
|
1170 |
+
|
1171 |
+
Returns
|
1172 |
+
-------
|
1173 |
+
stable_whisper.result.WhisperResult
|
1174 |
+
The current instance after the changes.
|
1175 |
+
|
1176 |
+
Notes
|
1177 |
+
-----
|
1178 |
+
If ``even_split = True``, segments can still exceed ``max_chars`` and locked words will be ignored to avoid
|
1179 |
+
uneven splitting.
|
1180 |
+
|
1181 |
+
</details>
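A common use is capping segments at subtitle-friendly lengths; the limits below are arbitrary examples:
```python
# cap each segment at 42 characters or 12 words, splitting as evenly as possible
result.split_by_length(max_chars=42, max_words=12)

# or keep the segments intact and only insert line breaks at the split points
result.split_by_length(max_chars=42, newline=True)
```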
|
1182 |
+
|
1183 |
+
<details>
|
1184 |
+
<summary>split_by_duration()</summary>
|
1185 |
+
|
1186 |
+
Split (in-place) any segment that exceeds ``max_dur`` into smaller segments.
|
1187 |
+
|
1188 |
+
Parameters
|
1189 |
+
----------
|
1190 |
+
max_dur : float
|
1191 |
+
Maximum duration (in seconds) per segment.
|
1192 |
+
even_split : bool, default True
|
1193 |
+
Whether to evenly split a segment in length if it exceeds ``max_dur``.
|
1194 |
+
force_len : bool, default False
|
1195 |
+
Whether to force a constant length for each segment except the last segment.
|
1196 |
+
This will ignore all previous non-locked segment boundaries.
|
1197 |
+
lock : bool, default False
|
1198 |
+
Whether to prevent future splits/merges from altering changes made by this method.
|
1199 |
+
include_lock: bool, default False
|
1200 |
+
Whether to include previous lock before splitting based on max_words, if ``even_split = False``.
|
1201 |
+
Splitting will be done after the first non-locked word > ``max_dur``.
|
1202 |
+
newline: bool, default False
|
1203 |
+
Whether to insert line break at the split points instead of splitting into separate segments.
|
1204 |
+
|
1205 |
+
Returns
|
1206 |
+
-------
|
1207 |
+
stable_whisper.result.WhisperResult
|
1208 |
+
The current instance after the changes.
|
1209 |
+
|
1210 |
+
Notes
|
1211 |
+
-----
|
1212 |
+
If ``even_split = True``, segments can still exceed ``max_dur`` and locked words will be ignored to avoid
|
1213 |
+
uneven splitting.
|
1214 |
+
|
1215 |
+
</details>
|
1216 |
+
|
1217 |
+
<details>
|
1218 |
+
<summary>merge_by_gap()</summary>
|
1219 |
+
|
1220 |
+
Merge (in-place) any pair of adjacent segments if the gap between them <= ``min_gap``.
|
1221 |
+
|
1222 |
+
Parameters
|
1223 |
+
----------
|
1224 |
+
min_gap : float, default 0.1
|
1225 |
+
Minimum second(s) allowed between two segments.
|
1226 |
+
max_words : int, optional
|
1227 |
+
Maximum number of words allowed in each segment.
|
1228 |
+
max_chars : int, optional
|
1229 |
+
Maximum number of characters allowed in each segment.
|
1230 |
+
is_sum_max : bool, default False
|
1231 |
+
Whether ``max_words`` and ``max_chars`` are applied to the merged segment instead of the individual segments
|
1232 |
+
to be merged.
|
1233 |
+
lock : bool, default False
|
1234 |
+
Whether to prevent future splits/merges from altering changes made by this method.
|
1235 |
+
|
1236 |
+
Returns
|
1237 |
+
-------
|
1238 |
+
stable_whisper.result.WhisperResult
|
1239 |
+
The current instance after the changes.
|
1240 |
+
|
1241 |
+
</details>
|
1242 |
+
|
1243 |
+
<details>
|
1244 |
+
<summary>merge_by_punctuation()</summary>
|
1245 |
+
|
1246 |
+
Merge (in-place) any two segments that have specific punctuations in between.
|
1247 |
+
|
1248 |
+
Parameters
|
1249 |
+
----------
|
1250 |
+
punctuation : list of str or list of tuple of (str, str) or str
|
1251 |
+
Punctuation(s) to merge segments by.
|
1252 |
+
max_words : int, optional
|
1253 |
+
Maximum number of words allowed in each segment.
|
1254 |
+
max_chars : int, optional
|
1255 |
+
Maximum number of characters allowed in each segment.
|
1256 |
+
is_sum_max : bool, default False
|
1257 |
+
Whether ``max_words`` and ``max_chars`` are applied to the merged segment instead of the individual segments
|
1258 |
+
to be merged.
|
1259 |
+
lock : bool, default False
|
1260 |
+
Whether to prevent future splits/merges from altering changes made by this method.
|
1261 |
+
|
1262 |
+
Returns
|
1263 |
+
-------
|
1264 |
+
stable_whisper.result.WhisperResult
|
1265 |
+
The current instance after the changes.
|
1266 |
+
|
1267 |
+
</details>
|
1268 |
+
|
1269 |
+
<details>
|
1270 |
+
<summary>merge_all_segments()</summary>
|
1271 |
+
|
1272 |
+
Merge all segments into one segment.
|
1273 |
+
|
1274 |
+
Returns
|
1275 |
+
-------
|
1276 |
+
stable_whisper.result.WhisperResult
|
1277 |
+
The current instance after the changes.
|
1278 |
+
|
1279 |
+
</details>
|
1280 |
+
|
1281 |
+
<details>
|
1282 |
+
<summary>clamp_max()</summary>
|
1283 |
+
|
1284 |
+
Clamp all word durations above a certain value.
|
1285 |
+
|
1286 |
+
This is most effective when applied before and after other regroup operations.
|
1287 |
+
|
1288 |
+
Parameters
|
1289 |
+
----------
|
1290 |
+
medium_factor : float, default 2.5
|
1291 |
+
Clamp durations above (``medium_factor`` * medium duration) per segment.
|
1292 |
+
If ``medium_factor = None/0`` or the segment has fewer than 3 words, it will be ignored and only ``max_dur`` will be used.
|
1293 |
+
max_dur : float, optional
|
1294 |
+
Clamp durations above ``max_dur``.
|
1295 |
+
clip_start : bool or None, default None
|
1296 |
+
Whether to clamp the start of a word. If ``None``, clamp the start of the first word and the end of the last word per
|
1297 |
+
segment.
|
1298 |
+
verbose : bool, default False
|
1299 |
+
Whether to print out the timestamp changes.
|
1300 |
+
|
1301 |
+
Returns
|
1302 |
+
-------
|
1303 |
+
stable_whisper.result.WhisperResult
|
1304 |
+
The current instance after the changes.
|
1305 |
+
|
1306 |
+
</details>
|
1307 |
+
|
1308 |
+
<details>
|
1309 |
+
<summary>lock()</summary>
|
1310 |
+
|
1311 |
+
Lock words/segments with matching prefix/suffix to prevent splitting/merging.
|
1312 |
+
|
1313 |
+
Parameters
|
1314 |
+
----------
|
1315 |
+
startswith: str or list of str
|
1316 |
+
Prefixes to lock.
|
1317 |
+
endswith: str or list of str
|
1318 |
+
Suffixes to lock.
|
1319 |
+
right : bool, default True
|
1320 |
+
Whether to prevent splits/merges with the next word/segment.
|
1321 |
+
left : bool, default False
|
1322 |
+
Whether to prevent splits/merges with the previous word/segment.
|
1323 |
+
case_sensitive : bool, default False
|
1324 |
+
Whether to match the case of the prefixes/suffixes with the words/segments.
|
1325 |
+
strip : bool, default True
|
1326 |
+
Whether to ignore spaces before and after both words/segments and prefixes/suffixes.
|
1327 |
+
|
1328 |
+
Returns
|
1329 |
+
-------
|
1330 |
+
stable_whisper.result.WhisperResult
|
1331 |
+
The current instance after the changes.
|
1332 |
+
|
1333 |
+
</details>
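For example, to keep words that start with certain prefixes attached to the word that follows them (the prefixes are only illustrative):
```python
# prevent 'Mr.', 'Ms.', 'Dr.' from being separated from the next word by later splits/merges
result.lock(startswith=['Mr.', 'Ms.', 'Dr.'], right=True, left=False)
```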
|
1334 |
+
|
1335 |
+
### Editing
|
1336 |
+
The editing methods in stable-ts can be chained with [Regrouping Methods](#regrouping-methods) and used in `regroup()`.
|
1337 |
+
|
1338 |
+
Remove specific instances of words or segments:
|
1339 |
+
```python
|
1340 |
+
# Remove first word of the first segment:
|
1341 |
+
first_word = result[0][0]
|
1342 |
+
result.remove_word(first_word)
|
1343 |
+
# The following also does the same:
|
1344 |
+
del result[0][0]
|
1345 |
+
|
1346 |
+
# Remove the last segment:
|
1347 |
+
last_segment = result[-1]
|
1348 |
+
result.remove_segment(last_segment)
|
1349 |
+
# The following also does the same:
|
1350 |
+
del result[-1]
|
1351 |
+
```
|
1352 |
+
Docstrings:
|
1353 |
+
<details>
|
1354 |
+
<summary>remove_word()</summary>
|
1355 |
+
|
1356 |
+
Remove a word.
|
1357 |
+
|
1358 |
+
Parameters
|
1359 |
+
----------
|
1360 |
+
word : WordTiming or tuple of (int, int)
|
1361 |
+
Instance of :class:`stable_whisper.result.WordTiming` or tuple of (segment index, word index).
|
1362 |
+
reassign_ids : bool, default True
|
1363 |
+
Whether to reassign segment and word ids (indices) after removing ``word``.
|
1364 |
+
verbose : bool, default True
|
1365 |
+
Whether to print detail of the removed word.
|
1366 |
+
|
1367 |
+
Returns
|
1368 |
+
-------
|
1369 |
+
stable_whisper.result.WhisperResult
|
1370 |
+
The current instance after the changes.
|
1371 |
+
|
1372 |
+
</details>
|
1373 |
+
|
1374 |
+
<details>
|
1375 |
+
<summary>remove_segment()</summary>
|
1376 |
+
|
1377 |
+
Remove a segment.
|
1378 |
+
|
1379 |
+
Parameters
|
1380 |
+
----------
|
1381 |
+
segment : Segment or int
|
1382 |
+
Instance of :class:`stable_whisper.result.Segment` or segment index.
|
1383 |
+
reassign_ids : bool, default True
|
1384 |
+
Whether to reassign segment IDs (indices) after removing ``segment``.
|
1385 |
+
verbose : bool, default True
|
1386 |
+
Whether to print detail of the removed segment.
|
1387 |
+
|
1388 |
+
Returns
|
1389 |
+
-------
|
1390 |
+
stable_whisper.result.WhisperResult
|
1391 |
+
The current instance after the changes.
|
1392 |
+
|
1393 |
+
</details>
|
1394 |
+
|
1395 |
+
|
1396 |
+
Removing repetitions:
|
1397 |
+
```python
|
1398 |
+
# Example 1: "This is is is a test." -> "This is a test."
|
1399 |
+
# The following removes the last two " is":
|
1400 |
+
result.remove_repetition(1)
|
1401 |
+
|
1402 |
+
# Example 2: "This is is is a test this is a test." -> "This is a test."
|
1403 |
+
# The following removes the second " is" and third " is", then remove the last "this is a test"
|
1404 |
+
# The first parameter `max_words` is `4` because "this is a test" consists of 4 words
|
1405 |
+
result.remove_repetition(4)
|
1406 |
+
```
|
1407 |
+
Docstring:
|
1408 |
+
<details>
|
1409 |
+
<summary>remove_repetition()</summary>
|
1410 |
+
|
1411 |
+
Remove words that repeat consecutively.
|
1412 |
+
|
1413 |
+
Parameters
|
1414 |
+
----------
|
1415 |
+
max_words : int
|
1416 |
+
Maximum number of words to look for consecutively.
|
1417 |
+
case_sensitive : bool, default False
|
1418 |
+
Whether the case of words needs to match to be considered a repetition.
|
1419 |
+
strip : bool, default True
|
1420 |
+
Whether to ignore spaces before and after each word.
|
1421 |
+
ignore_punctuations : str, default '"',.?!'
|
1422 |
+
Ending punctuations to ignore.
|
1423 |
+
extend_duration: bool, default True
|
1424 |
+
Whether to extend the duration of the previous word to cover the duration of the repetition.
|
1425 |
+
verbose: bool, default True
|
1426 |
+
Whether to print detail of the removed repetitions.
|
1427 |
+
|
1428 |
+
Returns
|
1429 |
+
-------
|
1430 |
+
stable_whisper.result.WhisperResult
|
1431 |
+
The current instance after the changes.
|
1432 |
+
|
1433 |
+
</details>
|
1434 |
+
|
1435 |
+
Removing specific word(s) by string content:
|
1436 |
+
```python
|
1437 |
+
# Remove all " ok" from " ok ok this is a test."
|
1438 |
+
result.remove_words_by_str('ok')
|
1439 |
+
|
1440 |
+
# Remove all " ok" and " Um..." from " ok this is a test. Um..."
|
1441 |
+
result.remove_words_by_str(['ok', 'um'])
|
1442 |
+
```
|
1443 |
+
Docstring:
|
1444 |
+
<details>
|
1445 |
+
<summary>remove_words_by_str()</summary>
|
1446 |
+
|
1447 |
+
Remove words that match ``words``.
|
1448 |
+
|
1449 |
+
Parameters
|
1450 |
+
----------
|
1451 |
+
words : str or list of str or None
|
1452 |
+
A word or list of words to remove. ``None`` for all words to be passed into ``filters``.
|
1453 |
+
case_sensitive : bool, default False
|
1454 |
+
Whether the case of words needs to match to be considered a match.
|
1455 |
+
strip : bool, default True
|
1456 |
+
Whether to ignore spaces before and after each word.
|
1457 |
+
ignore_punctuations : str, default '"',.?!'
|
1458 |
+
Ending punctuations to ignore.
|
1459 |
+
min_prob : float, optional
|
1460 |
+
Acts as the first filter for the words that match ``words``. Words with probability < ``min_prob`` will
|
1461 |
+
be removed if ``filters`` is ``None``, else pass the words into ``filters``. Words without probability will
|
1462 |
+
be treated as having probability < ``min_prob``.
|
1463 |
+
filters : Callable, optional
|
1464 |
+
A function that takes an instance of :class:`stable_whisper.result.WordTiming` as its only argument.
|
1465 |
+
This function is a custom filter for the words that match ``words`` and were not caught by ``min_prob``.
|
1466 |
+
verbose:
|
1467 |
+
Whether to print detail of the removed words.
|
1468 |
+
|
1469 |
+
Returns
|
1470 |
+
-------
|
1471 |
+
stable_whisper.result.WhisperResult
|
1472 |
+
The current instance after the changes.
|
1473 |
+
|
1474 |
+
</details>
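A sketch combining ``min_prob`` with a custom filter; the threshold and filter logic are illustrative, and the exact interplay of ``min_prob`` and ``filters`` follows the docstring above:
```python
# remove occurrences of 'uh'/'um', using probability as the first check
# and a duration-based filter for the remaining matches
result.remove_words_by_str(
    ['uh', 'um'],
    min_prob=0.5,
    filters=lambda word: (word.end - word.start) < 0.2,
)
```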
|
1475 |
+
|
1476 |
+
Filling in segment gaps:
|
1477 |
+
```python
|
1478 |
+
# result0: [" How are you?"] [" I'm good."] [" Good!"]
|
1479 |
+
# result1: [" Hello!"] [" How are you?"] [" How about you?"] [" Good!"]
|
1480 |
+
result0.fill_in_gaps(result1)
|
1481 |
+
# After filling in the gaps in `result0` with contents in `result1`:
|
1482 |
+
# result0: [" Hello!"] [" How are you?"] [" I'm good."] [" How about you?"] [" Good!"]
|
1483 |
+
```
|
1484 |
+
Docstring:
|
1485 |
+
<details>
|
1486 |
+
<summary>fill_in_gaps()</summary>
|
1487 |
+
|
1488 |
+
Fill in segment gaps larger than ``min_gap`` with content from ``other_result`` at the times of gaps.
|
1489 |
+
|
1490 |
+
Parameters
|
1491 |
+
----------
|
1492 |
+
other_result : WhisperResult or str
|
1493 |
+
Another transcription result as an instance of :class:`stable_whisper.result.WhisperResult` or path to the
|
1494 |
+
JSON of the result.
|
1495 |
+
min_gap : float, default 0.1
|
1496 |
+
The minimum seconds of a gap between segments that must be exceeded to be filled in.
|
1497 |
+
case_sensitive : bool, default False
|
1498 |
+
Whether to consider the case of the first and last word of the gap to determine overlapping words to remove
|
1499 |
+
before filling in.
|
1500 |
+
strip : bool, default True
|
1501 |
+
Whether to ignore spaces before and after the first and last word of the gap to determine overlapping words
|
1502 |
+
to remove before filling in.
|
1503 |
+
ignore_punctuations : str, default '"',.?!'
|
1504 |
+
Ending punctuations to ignore in the first and last word of the gap to determine overlapping words to
|
1505 |
+
remove before filling in.
|
1506 |
+
verbose:
|
1507 |
+
Whether to print detail of the filled content.
|
1508 |
+
|
1509 |
+
Returns
|
1510 |
+
-------
|
1511 |
+
stable_whisper.result.WhisperResult
|
1512 |
+
The current instance after the changes.
|
1513 |
+
|
1514 |
+
</details>
|
1515 |
+
|
1516 |
+
### Locating Words
|
1517 |
+
There are two ways to locate words.
|
1518 |
+
The first way is by approximating the time at which the words are spoken
|
1519 |
+
then transcribing a few seconds around the approximated time.
|
1520 |
+
This is also the faster way to locate words.
|
1521 |
+
```python
|
1522 |
+
matches = model.locate('audio.mp3', 'are', language='en', count=0)
|
1523 |
+
for match in matches:
|
1524 |
+
print(match.to_display_str())
|
1525 |
+
# verbose=True does the same thing as this for-loop.
|
1526 |
+
```
|
1527 |
+
Docstring:
|
1528 |
+
<details>
|
1529 |
+
<summary>locate()</summary>
|
1530 |
+
|
1531 |
+
Locate when specific words are spoken in ``audio`` without fully transcribing.
|
1532 |
+
|
1533 |
+
This is useful for quickly finding at what time specific words or phrases are spoken in an audio. Since it
|
1534 |
+
does not need to transcribe the audio to approximate the time, it is significantly faster than transcribing then
|
1535 |
+
locating the word in the transcript.
|
1536 |
+
|
1537 |
+
It can also transcribe a few seconds around the approximated time to find out what was said around those words or
|
1538 |
+
confirm if the word was even spoken near that time.
|
1539 |
+
|
1540 |
+
Parameters
|
1541 |
+
----------
|
1542 |
+
model : whisper.model.Whisper
|
1543 |
+
An instance of Whisper ASR model.
|
1544 |
+
audio : str or numpy.ndarray or torch.Tensor or bytes
|
1545 |
+
Path/URL to the audio file, the audio waveform, or bytes of audio file.
|
1546 |
+
If audio is :class:`numpy.ndarray` or :class:`torch.Tensor`, the audio must already be sampled at 16kHz.
|
1547 |
+
text: str or list of int
|
1548 |
+
Words/phrase or list of tokens to search for in ``audio``.
|
1549 |
+
language : str
|
1550 |
+
Language of the ``text``.
|
1551 |
+
count : int, default 1, meaning stop search after 1 match
|
1552 |
+
Number of matches to find. Use 0 to look for all.
|
1553 |
+
duration_window : float or tuple of (float, float), default 3.0, same as (3.0, 3.0)
|
1554 |
+
Seconds before and after the end timestamp approximations to transcribe after mode 1.
|
1555 |
+
If tuple pair of values, then the 1st value will be seconds before the end and 2nd value will be seconds after.
|
1556 |
+
mode : int, default 0
|
1557 |
+
Mode of search.
|
1558 |
+
2, Approximates the end timestamp of ``text`` in the audio. This mode does not confirm whether ``text`` is
|
1559 |
+
spoken at the timestamp
|
1560 |
+
1, Completes mode 2 then transcribes audio within ``duration_window`` to confirm whether `text` is a match at
|
1561 |
+
the approximated timestamp by checking if ``text`` at that ``duration_window`` is within
|
1562 |
+
``probability_threshold`` or matching the string content if ``text`` with the transcribed text at the
|
1563 |
+
``duration_window``.
|
1564 |
+
0, Completes mode 1 then add word timestamps to the transcriptions of each match.
|
1565 |
+
Modes from fastest to slowest: 2, 1, 0
|
1566 |
+
start : float, optional, meaning it starts from 0s
|
1567 |
+
Seconds into the audio to start searching for ``text``.
|
1568 |
+
end : float, optional
|
1569 |
+
Seconds into the audio to stop searching for ``text``.
|
1570 |
+
probability_threshold : float, default 0.5
|
1571 |
+
Minimum probability of each token in ``text`` for it to be considered a match.
|
1572 |
+
eots : int, default 1
|
1573 |
+
Number of EOTs to reach before stopping transcription at mode 1. When transcription reaches an EOT, it usually
|
1574 |
+
means the end of the segment or audio. Once ``text`` is found in the ``duration_window``, the transcription
|
1575 |
+
will stop immediately upon reaching an EOT.
|
1576 |
+
max_token_per_seg : int, default 20
|
1577 |
+
Maximum number of tokens to transcribe in the ``duration_window`` before stopping.
|
1578 |
+
exact_token : bool, default False
|
1579 |
+
Whether to find a match based on the exact tokens that make up ``text``.
|
1580 |
+
case_sensitive : bool, default False
|
1581 |
+
Whether to consider the case of ``text`` when matching in string content.
|
1582 |
+
verbose : bool or None, default False
|
1583 |
+
Whether to display the text being decoded to the console.
|
1584 |
+
Displays all the details if ``True``. Displays progressbar if ``False``. Displays nothing if ``None``.
|
1585 |
+
initial_prompt : str, optional
|
1586 |
+
Text to provide as a prompt for the first window. This can be used to provide, or
|
1587 |
+
"prompt-engineer" a context for transcription, e.g. custom vocabularies or proper nouns
|
1588 |
+
to make it more likely to predict those words correctly.
|
1589 |
+
suppress_tokens : str or list of int, default '-1', meaning suppress special characters except common punctuations
|
1590 |
+
List of tokens to suppress.
|
1591 |
+
demucs : bool or torch.nn.Module, default False
|
1592 |
+
Whether to preprocess ``audio`` with Demucs to isolate vocals / remove noise. Set ``demucs`` to an instance of
|
1593 |
+
a Demucs model to avoid reloading the model for each run.
|
1594 |
+
Demucs must be installed to use. Official repo, https://github.com/facebookresearch/demucs.
|
1595 |
+
demucs_options : dict, optional
|
1596 |
+
Options to use for :func:`stable_whisper.audio.demucs_audio`.
|
1597 |
+
only_voice_freq : bool, default False
|
1598 |
+
Whether to only use sound between 200 - 5000 Hz, where the majority of human speech is.
|
1599 |
+
|
1600 |
+
Returns
|
1601 |
+
-------
|
1602 |
+
stable_whisper.result.Segment or list of dict or list of float
|
1603 |
+
Mode 0, list of instances of :class:`stable_whisper.result.Segment`.
|
1604 |
+
Mode 1, list of dictionaries with end timestamp approximation of matches and transcribed neighboring words.
|
1605 |
+
Mode 2, list of timestamps in seconds for each end timestamp approximation.
|
1606 |
+
|
1607 |
+
Notes
|
1608 |
+
-----
|
1609 |
+
For ``text``, the case and spacing matter as 'on', ' on', ' On' are different tokens, therefore choose the one that
|
1610 |
+
best suits the context (e.g. ' On' to look for it at the beginning of a sentence).
|
1611 |
+
|
1612 |
+
Use a sufficiently large first value of ``duration_window``, i.e. a value greater than the time it takes to speak ``text``.
|
1613 |
+
|
1614 |
+
If ``exact_token = False`` and the string content matches, then ``probability_threshold`` is not used.
|
1615 |
+
|
1616 |
+
Examples
|
1617 |
+
--------
|
1618 |
+
>>> import stable_whisper
|
1619 |
+
>>> model = stable_whisper.load_model('base')
|
1620 |
+
>>> matches = model.locate('audio.mp3', 'are', 'English', verbose=True)
|
1621 |
+
|
1622 |
+
Some words can sound the same but have different spellings. To increase the chance of finding such words, use
|
1623 |
+
``initial_prompt``.
|
1624 |
+
|
1625 |
+
>>> matches = model.locate('audio.mp3', ' Nickie', 'English', verbose=True, initial_prompt='Nickie')
|
1626 |
+
|
1627 |
+
</details>
|
1628 |
+
|
1629 |
+
<details>
|
1630 |
+
<summary>CLI</summary>
|
1631 |
+
|
1632 |
+
```
|
1633 |
+
stable-ts audio.mp3 --locate "are" --language en -to "count=0"
|
1634 |
+
```
|
1635 |
+
|
1636 |
+
</details>
|
1637 |
+
|
1638 |
+
The second way allows you to locate words with regular expressions,
|
1639 |
+
but it requires the audio to be fully transcribed first.
|
1640 |
+
```python
|
1641 |
+
result = model.transcribe('audio.mp3')
|
1642 |
+
# Find every sentence that contains "and"
|
1643 |
+
matches = result.find(r'[^.]+and[^.]+\.')
|
1644 |
+
# print all the matches if there are any
|
1645 |
+
for match in matches:
|
1646 |
+
print(f'match: {match.text_match}\n'
|
1647 |
+
f'text: {match.text}\n'
|
1648 |
+
f'start: {match.start}\n'
|
1649 |
+
f'end: {match.end}\n')
|
1650 |
+
|
1651 |
+
# Find the word before and after "and" in the matches
|
1652 |
+
matches = matches.find(r'\s\S+\sand\s\S+')
|
1653 |
+
for match in matches:
|
1654 |
+
print(f'match: {match.text_match}\n'
|
1655 |
+
f'text: {match.text}\n'
|
1656 |
+
f'start: {match.start}\n'
|
1657 |
+
f'end: {match.end}\n')
|
1658 |
+
```
|
1659 |
+
Docstring:
|
1660 |
+
<details>
|
1661 |
+
<summary>find()</summary>
|
1662 |
+
|
1663 |
+
Find segments/words and timestamps with regular expression.
|
1664 |
+
|
1665 |
+
Parameters
|
1666 |
+
----------
|
1667 |
+
pattern : str
|
1668 |
+
RegEx pattern to search for.
|
1669 |
+
word_level : bool, default True
|
1670 |
+
Whether to search at word-level.
|
1671 |
+
flags : optional
|
1672 |
+
RegEx flags.
|
1673 |
+
|
1674 |
+
Returns
|
1675 |
+
-------
|
1676 |
+
stable_whisper.result.WhisperResultMatches
|
1677 |
+
An instance of :class:`stable_whisper.result.WhisperResultMatches` with word/segment that match ``pattern``.
|
1678 |
+
|
1679 |
+
</details>
|
1680 |
+
|
1681 |
+
### Silence Suppression
|
1682 |
+
While the timestamps predicted by Whisper are generally accurate,
|
1683 |
+
it sometimes predicts the start of a word way before the word is spoken
|
1684 |
+
or the end of a word long after the word has been spoken.
|
1685 |
+
This is where "silence suppression" helps. It is enabled by default (`suppress_silence=True`).
|
1686 |
+
The idea is to adjust the timestamps based on the timestamps of non-speech portions of the audio.
|
1687 |
+

|
1688 |
+
*Note: In 1.X, "silence suppression" refers to the process of suppressing timestamp tokens of the silent portions during inference,
|
1689 |
+
but changed to post-inference timestamp adjustments in 2.X, which allows stable-ts to be used with other ASR models.
|
1690 |
+
The timestamp token suppression feature is disabled by default, but can still be enabled with `suppress_ts_tokens=True`.*
|
1691 |
+
|
1692 |
+
By default, stable-ts determines the non-speech timestamps based on
|
1693 |
+
how loud a section of the audio is relative to the neighboring sections.
|
1694 |
+
This method is most effective for cases where the speech is significantly louder than the background noise.
|
1695 |
+
The other method is to use [Silero VAD](https://github.com/snakers4/silero-vad) (enabled with `vad=True`).
|
1696 |
+
To visualize the differences between non-VAD and VAD, see [Visualizing Suppression](#visualizing-suppression).
|
1697 |
+
|
1698 |
+
Besides the parameters for non-speech detection sensitivity (see [Visualizing Suppression](#visualizing-suppression)),
|
1699 |
+
the following parameters are used to combat inaccurate non-speech detection.<br>
|
1700 |
+
`min_word_dur` is the shortest duration each word is allowed to have after the adjustments.<br>
|
1701 |
+
`nonspeech_error` is the relative error of the non-speech that appears in between a word.<br>
|
1702 |
+
`use_word_position` is whether to use the word's position in its segment to determine whether to keep the end or start timestamp.
|
1703 |
+
*Note: `nonspeech_error` was not available before 2.14.0; `use_word_position` was not available before 2.14.2;
|
1704 |
+
`min_word_dur` prevented any adjustments that resulted in word duration shorter than `min_word_dur`.*
|
1705 |
+
|
1706 |
+
For the following example, `min_word_dur=0.5` (default: 0.1) and `nonspeech_error=0.3` (default: 0.3).
|
1707 |
+

|
1708 |
+
`nonspeech_error=0.3` allows each non-speech section to be treated as 1.3 times its actual duration.
|
1709 |
+
Either from the start of the corresponding word to the end of the non-speech
|
1710 |
+
or from the start of the non-speech to the end of the corresponding word.
|
1711 |
+
In the case that both conditions are met, the shorter one is used.
|
1712 |
+
Or if both are equal, then the start of the non-speech to the end of the word is used.<br>
|
1713 |
+
The second non-speech from 1.375s to 1.75s is ignored for 'world.' because it failed both conditions.<br>
|
1714 |
+
The first word, 'Hello', satisfies only the former condition from 0s to 0.625s, thus the new start for 'Hello'
|
1715 |
+
would be 0.625s. However, `min_word_dur=0.5` requires the resultant duration to be at least 0.5s.
|
1716 |
+
As a result, the start of 'Hello' is changed to 0.375s instead of 0.625s.
|
1717 |
+
Furthermore, the default setting, `use_word_position=True`, also ensures the start is adjusted for the first word
|
1718 |
+
and the end is adjusted for the last word of the segment as long as one of the conditions is true.
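A sketch of passing these parameters to `transcribe()` (the values mirror the example above and are not recommendations):
```python
result = model.transcribe(
    'audio.mp3',
    suppress_silence=True,   # default; enables the adjustments described above
    vad=True,                # use Silero VAD for non-speech detection
    min_word_dur=0.5,        # no adjustment may shrink a word below 0.5s
    nonspeech_error=0.3,     # treat each non-speech section as up to 1.3x its duration
    use_word_position=True,  # default; account for word position within its segment
)
```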
|
1719 |
+
|
1720 |
+
### Tips
|
1721 |
+
- do not disable word timestamps with `word_timestamps=False` for reliable segment timestamps
|
1722 |
+
- use `vad=True` for more accurate non-speech detection
|
1723 |
+
- use `demucs=True` to isolate vocals with [Demucs](https://github.com/facebookresearch/demucs); it is also effective at isolating vocals even if there is no music
|
1724 |
+
- use `demucs=True` and `vad=True` for music
|
1725 |
+
- set the same seed for each transcription (e.g. `random.seed(0)`) when using `demucs=True` to produce deterministic outputs (see the sketch after this list)
|
1726 |
+
- to enable dynamic quantization for inference on CPU, use `--dq true` for CLI or `dq=True` for `stable_whisper.load_model`
|
1727 |
+
- use `encode_video_comparison()` to encode multiple transcripts into one video for synced comparison; see [Encode Comparison](#encode-comparison)
|
1728 |
+
- use `visualize_suppression()` to visualize the differences between non-VAD and VAD options; see [Visualizing Suppression](#visualizing-suppression)
|
1729 |
+
- [refinement](#refinement) can be an effective (but slow) alternative for polishing timestamps if silence suppression isn't effective
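A minimal sketch of the seeding and dynamic quantization tips (assuming the `device` argument is passed through to Whisper's loader):
```python
import random
import stable_whisper

# dynamic quantization for CPU inference
model = stable_whisper.load_model('base', device='cpu', dq=True)

# fix the seed so Demucs preprocessing produces deterministic output
random.seed(0)
result = model.transcribe('audio.mp3', demucs=True, vad=True)
```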
|
1730 |
+
|
1731 |
+
### Visualizing Suppression
|
1732 |
+
You can visualize which parts of the audio will likely be suppressed (i.e. marked as silent).
|
1733 |
+
Requires: [Pillow](https://github.com/python-pillow/Pillow) or [opencv-python](https://github.com/opencv/opencv-python).
|
1734 |
+
|
1735 |
+
#### Without VAD
|
1736 |
+
```python
|
1737 |
+
import stable_whisper
|
1738 |
+
# regions on the waveform colored red are where it will likely be suppressed and marked as silent
|
1739 |
+
# [q_levels]=20 and [k_size]=5 (default)
|
1740 |
+
stable_whisper.visualize_suppression('audio.mp3', 'image.png', q_levels=20, k_size=5)
|
1741 |
+
```
|
1742 |
+

|
1743 |
+
|
1744 |
+
#### With [Silero VAD](https://github.com/snakers4/silero-vad)
|
1745 |
+
```python
|
1746 |
+
# [vad_threshold]=0.35 (default)
|
1747 |
+
stable_whisper.visualize_suppression('audio.mp3', 'image.png', vad=True, vad_threshold=0.35)
|
1748 |
+
```
|
1749 |
+

|
1750 |
+
Docstring:
|
1751 |
+
<details>
|
1752 |
+
<summary>visualize_suppression()</summary>
|
1753 |
+
|
1754 |
+
Visualize regions on the waveform of ``audio`` detected as silent.
|
1755 |
+
|
1756 |
+
Regions on the waveform colored red are detected as silent.
|
1757 |
+
|
1758 |
+
Parameters
|
1759 |
+
----------
|
1760 |
+
audio : str or numpy.ndarray or torch.Tensor or bytes
|
1761 |
+
Path/URL to the audio file, the audio waveform, or bytes of audio file.
|
1762 |
+
If audio is ``numpy.ndarray`` or ``torch.Tensor``, the audio must be already at sampled to 16kHz.
|
1763 |
+
output : str, default None, meaning image will be shown directly via Pillow or opencv-python
|
1764 |
+
Path to save visualization.
|
1765 |
+
q_levels : int, default 20
|
1766 |
+
Quantization levels for generating timestamp suppression mask; ignored if ``vad = true``.
|
1767 |
+
Acts as a threshold for marking sound as silent.
|
1768 |
+
Fewer levels will increase the threshold of volume at which to mark a sound as silent.
|
1769 |
+
k_size : int, default 5
|
1770 |
+
Kernel size for avg-pooling waveform to generate timestamp suppression mask; ignored if ``vad = true``.
|
1771 |
+
Recommend 5 or 3; higher sizes will reduce detection of silence.
|
1772 |
+
vad : bool, default False
|
1773 |
+
Whether to use Silero VAD to generate timestamp suppression mask.
|
1774 |
+
Silero VAD requires PyTorch 1.12.0+. Official repo, https://github.com/snakers4/silero-vad.
|
1775 |
+
vad_threshold : float, default 0.35
|
1776 |
+
Threshold for detecting speech with Silero VAD. Low threshold reduces false positives for silence detection.
|
1777 |
+
max_width : int, default 1500
|
1778 |
+
Maximum width of visualization to avoid overly large image from long audio.
|
1779 |
+
Each pixel is equivalent to 1 token. Use -1 to visualize the entire audio track.
|
1780 |
+
height : int, default 200
|
1781 |
+
Height of visualization.
|
1782 |
+
|
1783 |
+
</details>
|
1784 |
+
|
1785 |
+
### Encode Comparison
|
1786 |
+
You can encode videos similar to the ones in the doc for comparing transcriptions of the same audio.
|
1787 |
+
```python
|
1788 |
+
stable_whisper.encode_video_comparison(
|
1789 |
+
'audio.mp3',
|
1790 |
+
['audio_sub1.srt', 'audio_sub2.srt'],
|
1791 |
+
output_videopath='audio.mp4',
|
1792 |
+
labels=['Example 1', 'Example 2']
|
1793 |
+
)
|
1794 |
+
```
|
1795 |
+
Docstring:
|
1796 |
+
<details>
|
1797 |
+
<summary>encode_video_comparison()</summary>
|
1798 |
+
|
1799 |
+
Encode multiple subtitle files into one video with the subtitles vertically stacked.
|
1800 |
+
|
1801 |
+
Parameters
|
1802 |
+
----------
|
1803 |
+
audiofile : str
|
1804 |
+
Path of audio file.
|
1805 |
+
subtitle_files : list of str
|
1806 |
+
List of paths of the subtitle files.
|
1807 |
+
output_videopath : str, optional
|
1808 |
+
Output video path.
|
1809 |
+
labels : list of str, default None, meaning use ``subtitle_files`` as labels
|
1810 |
+
List of labels for ``subtitle_files``.
|
1811 |
+
height : int, default 90
|
1812 |
+
Height for each subtitle section.
|
1813 |
+
width : int, default 720
|
1814 |
+
Width for each subtitle section.
|
1815 |
+
color : str, default 'black'
|
1816 |
+
Background color of the video.
|
1817 |
+
fontsize: int, default 70
|
1818 |
+
Font size for subtitles.
|
1819 |
+
border_color : str, default 'white'
|
1820 |
+
Border color for separating the sections of subtitle.
|
1821 |
+
label_color : str, default 'white'
|
1822 |
+
Color of labels.
|
1823 |
+
label_size : int, default 14
|
1824 |
+
Font size of labels.
|
1825 |
+
fps : int, default 25
|
1826 |
+
Frame-rate of the video.
|
1827 |
+
video_codec : str, optional
|
1828 |
+
Video codec of the video.
|
1829 |
+
audio_codec : str, optional
|
1830 |
+
Audio codec of the video.
|
1831 |
+
overwrite : bool, default False
|
1832 |
+
Whether to overwrite existing video files with the same path as the output video.
|
1833 |
+
only_cmd : bool, default False
|
1834 |
+
Whether to skip encoding and only return the full command generated from the specified options.
|
1835 |
+
verbose : bool, default True
|
1836 |
+
Whether to display ffmpeg processing info.
|
1837 |
+
|
1838 |
+
Returns
|
1839 |
+
-------
|
1840 |
+
str or None
|
1841 |
+
Encoding command as a string if ``only_cmd = True``.
|
1842 |
+
|
1843 |
+
</details>
|
1844 |
+
|
1845 |
+
#### Multiple Files with CLI
|
1846 |
+
Transcribe multiple audio files then process the results directly into SRT files.
|
1847 |
+
```commandline
|
1848 |
+
stable-ts audio1.mp3 audio2.mp3 audio3.mp3 -o audio1.srt audio2.srt audio3.srt
|
1849 |
+
```
|
1850 |
+
|
1851 |
+
### Any ASR
|
1852 |
+
You can use most of the features of Stable-ts to improve the results of any ASR model/API.
|
1853 |
+
[Just follow this notebook](https://github.com/jianfch/stable-ts/blob/main/examples/non-whisper.ipynb).
|
1854 |
+
|
1855 |
+
## Quick 1.X → 2.X Guide
|
1856 |
+
### What's new in 2.0.0?
|
1857 |
+
- updated to use Whisper's more reliable word-level timestamps method.
|
1858 |
+
- the more reliable word timestamps allow regrouping all words into segments with more natural boundaries.
|
1859 |
+
- can now suppress silence with [Silero VAD](https://github.com/snakers4/silero-vad) (requires PyTorch 1.12.0+)
|
1860 |
+
- non-VAD silence suppression is also more robust
|
1861 |
+
### Usage changes
|
1862 |
+
- `results_to_sentence_srt(result, 'audio.srt')` → `result.to_srt_vtt('audio.srt', word_level=False)`
|
1863 |
+
- `results_to_word_srt(result, 'audio.srt')` → `result.to_srt_vtt('output.srt', segment_level=False)`
|
1864 |
+
- `results_to_sentence_word_ass(result, 'audio.srt')` → `result.to_ass('output.ass')`
|
1865 |
+
- there's no need to stabilize segments after inference because they're already stabilized during inference
|
1866 |
+
- `transcribe()` returns a `WhisperResult` object which can be converted to `dict` with `.to_dict()`, e.g. `result.to_dict()`
|
1867 |
+
|
1868 |
+
## License
|
1869 |
+
This project is licensed under the MIT License - see the [LICENSE](LICENSE) file for details
|
1870 |
+
|
1871 |
+
## Acknowledgments
|
1872 |
+
Includes slight modification of the original work: [Whisper](https://github.com/openai/whisper)
|
examples/non-whisper.ipynb
ADDED
@@ -0,0 +1,425 @@
|
1 |
+
{
|
2 |
+
"cells": [
|
3 |
+
{
|
4 |
+
"cell_type": "markdown",
|
5 |
+
"id": "13dc05a3-de12-4d7a-a926-e99d6d97826e",
|
6 |
+
"metadata": {},
|
7 |
+
"source": [
|
8 |
+
"## Using Stable-ts with any ASR"
|
9 |
+
]
|
10 |
+
},
|
11 |
+
{
|
12 |
+
"cell_type": "code",
|
13 |
+
"execution_count": null,
|
14 |
+
"id": "5cfee322-ebca-4c23-87a4-a109a2f85203",
|
15 |
+
"metadata": {},
|
16 |
+
"outputs": [],
|
17 |
+
"source": [
|
18 |
+
"import stable_whisper\n",
|
19 |
+
"assert int(stable_whisper.__version__.replace('.', '')) >= 270, f\"Requires Stable-ts 2.7.0+. Current version is {stable_whisper.__version__}.\""
|
20 |
+
]
|
21 |
+
},
|
22 |
+
{
|
23 |
+
"cell_type": "markdown",
|
24 |
+
"id": "e6c2dab2-f4df-46f9-b2e8-94dd88522c7d",
|
25 |
+
"metadata": {},
|
26 |
+
"source": [
|
27 |
+
"<br />\n",
|
28 |
+
"\n",
|
29 |
+
"Stable-ts can be used for other ASR models or web APIs by wrapping them as a function then passing it as the first argument to `non_whisper.transcribe_any()`."
|
30 |
+
]
|
31 |
+
},
|
32 |
+
{
|
33 |
+
"cell_type": "code",
|
34 |
+
"execution_count": 2,
|
35 |
+
"id": "7d32fa9f-a54c-4996-97c3-3b360230d029",
|
36 |
+
"metadata": {
|
37 |
+
"tags": []
|
38 |
+
},
|
39 |
+
"outputs": [],
|
40 |
+
"source": [
|
41 |
+
"def inference(audio, **kwargs) -> dict:\n",
|
42 |
+
" # run model/API \n",
|
43 |
+
" # return data as a dictionary\n",
|
44 |
+
" data = {}\n",
|
45 |
+
" return data"
|
46 |
+
]
|
47 |
+
},
|
48 |
+
{
|
49 |
+
"cell_type": "markdown",
|
50 |
+
"id": "856ef1fd-f489-42af-a90c-97323fd05a6b",
|
51 |
+
"metadata": {},
|
52 |
+
"source": [
|
53 |
+
"The data returned by the function must be one of the following:\n",
|
54 |
+
"- an instance of `WhisperResult` containing the data\n",
|
55 |
+
"- a dictionary in an appropriate mapping\n",
|
56 |
+
"- a path of JSON file containing data in an appropriate mapping"
|
57 |
+
]
|
58 |
+
},
|
59 |
+
{
|
60 |
+
"cell_type": "markdown",
|
61 |
+
"id": "bbdebdad-af1d-4077-8e99-20e767a0fd91",
|
62 |
+
"metadata": {},
|
63 |
+
"source": [
|
64 |
+
"Here are the 3 types of mappings:"
|
65 |
+
]
|
66 |
+
},
|
67 |
+
{
|
68 |
+
"cell_type": "code",
|
69 |
+
"execution_count": 3,
|
70 |
+
"id": "06bc4ce7-5117-4674-8eb9-c343c13c18bc",
|
71 |
+
"metadata": {},
|
72 |
+
"outputs": [],
|
73 |
+
"source": [
|
74 |
+
"#1:\n",
|
75 |
+
"essential_mapping = [\n",
|
76 |
+
" [ # 1st Segment\n",
|
77 |
+
" {'word': ' And', 'start': 0.0, 'end': 1.28}, \n",
|
78 |
+
" {'word': ' when', 'start': 1.28, 'end': 1.52}, \n",
|
79 |
+
" {'word': ' no', 'start': 1.52, 'end': 2.26}, \n",
|
80 |
+
" {'word': ' ocean,', 'start': 2.26, 'end': 2.68},\n",
|
81 |
+
" {'word': ' mountain,', 'start': 3.28, 'end': 3.58}\n",
|
82 |
+
" ], \n",
|
83 |
+
" [ # 2nd Segment\n",
|
84 |
+
" {'word': ' or', 'start': 4.0, 'end': 4.08}, \n",
|
85 |
+
" {'word': ' sky', 'start': 4.08, 'end': 4.56}, \n",
|
86 |
+
" {'word': ' could', 'start': 4.56, 'end': 4.84}, \n",
|
87 |
+
" {'word': ' contain', 'start': 4.84, 'end': 5.26}, \n",
|
88 |
+
" {'word': ' us,', 'start': 5.26, 'end': 6.27},\n",
|
89 |
+
" {'word': ' our', 'start': 6.27, 'end': 6.58}, \n",
|
90 |
+
" {'word': ' gaze', 'start': 6.58, 'end': 6.98}, \n",
|
91 |
+
" {'word': ' hungered', 'start': 6.98, 'end': 7.88}, \n",
|
92 |
+
" {'word': ' starward.', 'start': 7.88, 'end': 8.64}\n",
|
93 |
+
" ]\n",
|
94 |
+
"]"
|
95 |
+
]
|
96 |
+
},
|
97 |
+
{
|
98 |
+
"cell_type": "markdown",
|
99 |
+
"id": "b53bd812-2838-4f47-ab5f-5e729801aaee",
|
100 |
+
"metadata": {},
|
101 |
+
"source": [
|
102 |
+
"<br />\n",
|
103 |
+
"\n",
|
104 |
+
"If word timings are not available they can be omitted, but operations that can be performed on this data will be limited."
|
105 |
+
]
|
106 |
+
},
|
107 |
+
{
|
108 |
+
"cell_type": "code",
|
109 |
+
"execution_count": 4,
|
110 |
+
"id": "8c6bf720-5bfd-4e79-90e7-7049a2ca1d3a",
|
111 |
+
"metadata": {},
|
112 |
+
"outputs": [],
|
113 |
+
"source": [
|
114 |
+
"#2:\n",
|
115 |
+
"no_word_mapping = [\n",
|
116 |
+
" {\n",
|
117 |
+
" 'start': 0.0, \n",
|
118 |
+
" 'end': 3.58, \n",
|
119 |
+
" 'text': ' And when no ocean, mountain,',\n",
|
120 |
+
" }, \n",
|
121 |
+
" {\n",
|
122 |
+
" 'start': 4.0, \n",
|
123 |
+
" 'end': 8.64, \n",
|
124 |
+
" 'text': ' or sky could contain us, our gaze hungered starward.', \n",
|
125 |
+
" }\n",
|
126 |
+
"]"
|
127 |
+
]
|
128 |
+
},
|
129 |
+
{
|
130 |
+
"cell_type": "markdown",
|
131 |
+
"id": "108e960f-8bd1-4d2a-92bf-cc8cb56f4615",
|
132 |
+
"metadata": {},
|
133 |
+
"source": [
|
134 |
+
"<br />\n",
|
135 |
+
"\n",
|
136 |
+
"Below is the full mapping for normal Stable-ts results. `None` takes the place of any omitted values except for `start`, `end`, and `text`/`word` which are required."
|
137 |
+
]
|
138 |
+
},
|
139 |
+
{
|
140 |
+
"cell_type": "code",
|
141 |
+
"execution_count": 5,
|
142 |
+
"id": "2969aad2-c8bf-4043-8015-669a3102e158",
|
143 |
+
"metadata": {},
|
144 |
+
"outputs": [],
|
145 |
+
"source": [
|
146 |
+
"#3:\n",
|
147 |
+
"full_mapping = {\n",
|
148 |
+
" 'language': 'en',\n",
|
149 |
+
" 'text': ' And when no ocean, mountain, or sky could contain us, our gaze hungered starward.', \n",
|
150 |
+
" 'segments': [\n",
|
151 |
+
" {\n",
|
152 |
+
" 'seek': 0.0, \n",
|
153 |
+
" 'start': 0.0, \n",
|
154 |
+
" 'end': 3.58, \n",
|
155 |
+
" 'text': ' And when no ocean, mountain,', \n",
|
156 |
+
" 'tokens': [400, 562, 572, 7810, 11, 6937, 11], \n",
|
157 |
+
" 'temperature': 0.0, \n",
|
158 |
+
" 'avg_logprob': -0.48702024376910663, \n",
|
159 |
+
" 'compression_ratio': 1.0657894736842106, \n",
|
160 |
+
" 'no_speech_prob': 0.3386174440383911, \n",
|
161 |
+
" 'id': 0, \n",
|
162 |
+
" 'words': [\n",
|
163 |
+
" {'word': ' And', 'start': 0.04, 'end': 1.28, 'probability': 0.6481522917747498, 'tokens': [400]}, \n",
|
164 |
+
" {'word': ' when', 'start': 1.28, 'end': 1.52, 'probability': 0.9869539141654968, 'tokens': [562]}, \n",
|
165 |
+
" {'word': ' no', 'start': 1.52, 'end': 2.26, 'probability': 0.57384192943573, 'tokens': [572]}, \n",
|
166 |
+
" {'word': ' ocean,', 'start': 2.26, 'end': 2.68, 'probability': 0.9484889507293701, 'tokens': [7810, 11]},\n",
|
167 |
+
" {'word': ' mountain,', 'start': 3.28, 'end': 3.58, 'probability': 0.9581122398376465, 'tokens': [6937, 11]}\n",
|
168 |
+
" ]\n",
|
169 |
+
" }, \n",
|
170 |
+
" {\n",
|
171 |
+
" 'seek': 0.0, \n",
|
172 |
+
" 'start': 4.0, \n",
|
173 |
+
" 'end': 8.64, \n",
|
174 |
+
" 'text': ' or sky could contain us, our gaze hungered starward.', \n",
|
175 |
+
" 'tokens': [420, 5443, 727, 5304, 505, 11, 527, 24294, 5753, 4073, 3543, 1007, 13], \n",
|
176 |
+
" 'temperature': 0.0, \n",
|
177 |
+
" 'avg_logprob': -0.48702024376910663, \n",
|
178 |
+
" 'compression_ratio': 1.0657894736842106, \n",
|
179 |
+
" 'no_speech_prob': 0.3386174440383911, \n",
|
180 |
+
" 'id': 1, \n",
|
181 |
+
" 'words': [\n",
|
182 |
+
" {'word': ' or', 'start': 4.0, 'end': 4.08, 'probability': 0.9937937259674072, 'tokens': [420]}, \n",
|
183 |
+
" {'word': ' sky', 'start': 4.08, 'end': 4.56, 'probability': 0.9950089454650879, 'tokens': [5443]}, \n",
|
184 |
+
" {'word': ' could', 'start': 4.56, 'end': 4.84, 'probability': 0.9915681481361389, 'tokens': [727]}, \n",
|
185 |
+
" {'word': ' contain', 'start': 4.84, 'end': 5.26, 'probability': 0.898974597454071, 'tokens': [5304]}, \n",
|
186 |
+
" {'word': ' us,', 'start': 5.26, 'end': 6.27, 'probability': 0.999351441860199, 'tokens': [505, 11]},\n",
|
187 |
+
" {'word': ' our', 'start': 6.27, 'end': 6.58, 'probability': 0.9634224772453308, 'tokens': [527]}, \n",
|
188 |
+
" {'word': ' gaze', 'start': 6.58, 'end': 6.98, 'probability': 0.8934874534606934, 'tokens': [24294]}, \n",
|
189 |
+
" {'word': ' hungered', 'start': 6.98, 'end': 7.88, 'probability': 0.7424876093864441, 'tokens': [5753, 4073]}, \n",
|
190 |
+
" {'word': ' starward.', 'start': 7.88, 'end': 8.64, 'probability': 0.464096799492836, 'tokens': [3543, 1007, 13]}\n",
|
191 |
+
" ]\n",
|
192 |
+
" }\n",
|
193 |
+
" ]\n",
|
194 |
+
"}"
|
195 |
+
]
|
196 |
+
},
|
197 |
+
{
|
198 |
+
"cell_type": "markdown",
|
199 |
+
"id": "49d136e4-0f7d-4dcf-84f9-efb6f0eda491",
|
200 |
+
"metadata": {},
|
201 |
+
"source": [
|
202 |
+
"<br />\n",
|
203 |
+
"\n",
|
204 |
+
"The function must also have `audio` as a parameter."
|
205 |
+
]
|
206 |
+
},
|
207 |
+
{
|
208 |
+
"cell_type": "code",
|
209 |
+
"execution_count": 6,
|
210 |
+
"id": "33f03286-69f9-4ae1-aec0-250fd92a8cb6",
|
211 |
+
"metadata": {
|
212 |
+
"tags": []
|
213 |
+
},
|
214 |
+
"outputs": [],
|
215 |
+
"source": [
|
216 |
+
"def inference(audio, **kwargs) -> dict:\n",
|
217 |
+
" # run model/API on the audio\n",
|
218 |
+
" # return data in a proper format\n",
|
219 |
+
" return essential_mapping"
|
220 |
+
]
|
221 |
+
},
|
222 |
+
{
|
223 |
+
"cell_type": "code",
|
224 |
+
"execution_count": 7,
|
225 |
+
"id": "d6710eb5-5386-42cf-b6e7-02a84b5fad40",
|
226 |
+
"metadata": {
|
227 |
+
"tags": []
|
228 |
+
},
|
229 |
+
"outputs": [],
|
230 |
+
"source": [
|
231 |
+
"result = stable_whisper.transcribe_any(inference, './demo.wav', vad=True)"
|
232 |
+
]
|
233 |
+
},
|
234 |
+
{
|
235 |
+
"cell_type": "code",
|
236 |
+
"execution_count": 8,
|
237 |
+
"id": "6d7f9de6-5c9b-4c73-808d-640b13efb051",
|
238 |
+
"metadata": {},
|
239 |
+
"outputs": [
|
240 |
+
{
|
241 |
+
"name": "stdout",
|
242 |
+
"output_type": "stream",
|
243 |
+
"text": [
|
244 |
+
"0\n",
|
245 |
+
"00:00:01,122 --> 00:00:02,680\n",
|
246 |
+
"And when no ocean,\n",
|
247 |
+
"\n",
|
248 |
+
"1\n",
|
249 |
+
"00:00:03,280 --> 00:00:03,580\n",
|
250 |
+
"mountain,\n",
|
251 |
+
"\n",
|
252 |
+
"2\n",
|
253 |
+
"00:00:04,000 --> 00:00:06,046\n",
|
254 |
+
"or sky could contain us,\n",
|
255 |
+
"\n",
|
256 |
+
"3\n",
|
257 |
+
"00:00:06,402 --> 00:00:08,640\n",
|
258 |
+
"our gaze hungered starward.\n"
|
259 |
+
]
|
260 |
+
}
|
261 |
+
],
|
262 |
+
"source": [
|
263 |
+
"print(result.to_srt_vtt(word_level=False))"
|
264 |
+
]
|
265 |
+
},
|
266 |
+
{
|
267 |
+
"cell_type": "code",
|
268 |
+
"execution_count": 9,
|
269 |
+
"id": "be5a45e8-1b25-4a70-9af6-94bc5379fc7d",
|
270 |
+
"metadata": {},
|
271 |
+
"outputs": [
|
272 |
+
{
|
273 |
+
"name": "stdout",
|
274 |
+
"output_type": "stream",
|
275 |
+
"text": [
|
276 |
+
"\n",
|
277 |
+
" Transcribe an audio file using any ASR system.\n",
|
278 |
+
"\n",
|
279 |
+
" Parameters\n",
|
280 |
+
" ----------\n",
|
281 |
+
" inference_func: Callable\n",
|
282 |
+
"        Function that runs ASR when provided with [audio] and returns data in the appropriate format.\n",
|
283 |
+
" For format examples: https://github.com/jianfch/stable-ts/blob/main/examples/non-whisper.ipynb\n",
|
284 |
+
"\n",
|
285 |
+
" audio: Union[str, np.ndarray, torch.Tensor, bytes]\n",
|
286 |
+
" The path/URL to the audio file, the audio waveform, or bytes of audio file.\n",
|
287 |
+
"\n",
|
288 |
+
" audio_type: str\n",
|
289 |
+
" The type that [audio] needs to be for [inference_func]. (Default: Same type as [audio])\n",
|
290 |
+
"\n",
|
291 |
+
" Types:\n",
|
292 |
+
" None (default)\n",
|
293 |
+
" same type as [audio]\n",
|
294 |
+
"\n",
|
295 |
+
" 'str'\n",
|
296 |
+
" a path to the file\n",
|
297 |
+
"                -if [audio] is a file and no audio preprocessing is done,\n",
|
298 |
+
" [audio] will be directly passed into [inference_func]\n",
|
299 |
+
" -if audio preprocessing is performed (from [demucs] and/or [only_voice_freq]),\n",
|
300 |
+
" the processed audio will be encoded into [temp_file] and then passed into [inference_func]\n",
|
301 |
+
"\n",
|
302 |
+
" 'byte'\n",
|
303 |
+
" bytes (used for APIs or to avoid writing any data to hard drive)\n",
|
304 |
+
"                -if [audio] is a file, the bytes of the file are used\n",
|
305 |
+
"                -if [audio] is a PyTorch tensor or NumPy array, the bytes of [audio] encoded into WAV format are used\n",
|
306 |
+
"\n",
|
307 |
+
" 'torch'\n",
|
308 |
+
" a PyTorch tensor containing the audio waveform, in float32 dtype, on CPU\n",
|
309 |
+
"\n",
|
310 |
+
" 'numpy'\n",
|
311 |
+
" a NumPy array containing the audio waveform, in float32 dtype\n",
|
312 |
+
"\n",
|
313 |
+
" input_sr: int\n",
|
314 |
+
" The sample rate of [audio]. (Default: Auto-detected if [audio] is str/bytes)\n",
|
315 |
+
"\n",
|
316 |
+
" model_sr: int\n",
|
317 |
+
" The sample rate to resample the audio into for [inference_func]. (Default: Same as [input_sr])\n",
|
318 |
+
"        Resampling is only performed when [model_sr] does not match the sample rate of the final audio due to:\n",
|
319 |
+
" -[input_sr] not matching\n",
|
320 |
+
" -sample rate changed due to audio preprocessing from [demucs]=True\n",
|
321 |
+
"\n",
|
322 |
+
" inference_kwargs: dict\n",
|
323 |
+
" Dictionary of arguments provided to [inference_func]. (Default: None)\n",
|
324 |
+
"\n",
|
325 |
+
" temp_file: str\n",
|
326 |
+
" Temporary path for the preprocessed audio when [audio_type]='str'. (Default: './_temp_stable-ts_audio_.wav')\n",
|
327 |
+
"\n",
|
328 |
+
" verbose: bool\n",
|
329 |
+
"        Whether to display the text being decoded to the console. If True, displays all the details.\n",
|
330 |
+
"        If False, displays a progress bar. If None, does not display anything. (Default: False)\n",
|
331 |
+
"\n",
|
332 |
+
" regroup: Union[bool, str]\n",
|
333 |
+
" Whether to regroup all words into segments with more natural boundaries. (Default: True)\n",
|
334 |
+
" Specify string for customizing the regrouping algorithm.\n",
|
335 |
+
" Ignored if [word_timestamps]=False.\n",
|
336 |
+
"\n",
|
337 |
+
" suppress_silence: bool\n",
|
338 |
+
" Whether to suppress timestamp where audio is silent at segment-level\n",
|
339 |
+
" and word-level if [suppress_word_ts]=True. (Default: True)\n",
|
340 |
+
"\n",
|
341 |
+
" suppress_word_ts: bool\n",
|
342 |
+
" Whether to suppress timestamps, if [suppress_silence]=True, where audio is silent at word-level. (Default: True)\n",
|
343 |
+
"\n",
|
344 |
+
" q_levels: int\n",
|
345 |
+
" Quantization levels for generating timestamp suppression mask; ignored if [vad]=true. (Default: 20)\n",
|
346 |
+
"        Acts as a threshold for marking sound as silent.\n",
|
347 |
+
" Fewer levels will increase the threshold of volume at which to mark a sound as silent.\n",
|
348 |
+
"\n",
|
349 |
+
" k_size: int\n",
|
350 |
+
" Kernel size for avg-pooling waveform to generate timestamp suppression mask; ignored if [vad]=true. (Default: 5)\n",
|
351 |
+
" Recommend 5 or 3; higher sizes will reduce detection of silence.\n",
|
352 |
+
"\n",
|
353 |
+
" demucs: bool\n",
|
354 |
+
" Whether to preprocess the audio track with Demucs to isolate vocals/remove noise. (Default: False)\n",
|
355 |
+
" Demucs must be installed to use. Official repo: https://github.com/facebookresearch/demucs\n",
|
356 |
+
"\n",
|
357 |
+
" demucs_device: str\n",
|
358 |
+
"        Device to use for demucs: 'cuda' or 'cpu'. (Default: 'cuda' if torch.cuda.is_available() else 'cpu')\n",
|
359 |
+
"\n",
|
360 |
+
" demucs_output: str\n",
|
361 |
+
" Path to save the vocals isolated by Demucs as WAV file. Ignored if [demucs]=False.\n",
|
362 |
+
" Demucs must be installed to use. Official repo: https://github.com/facebookresearch/demucs\n",
|
363 |
+
"\n",
|
364 |
+
" vad: bool\n",
|
365 |
+
" Whether to use Silero VAD to generate timestamp suppression mask. (Default: False)\n",
|
366 |
+
" Silero VAD requires PyTorch 1.12.0+. Official repo: https://github.com/snakers4/silero-vad\n",
|
367 |
+
"\n",
|
368 |
+
" vad_threshold: float\n",
|
369 |
+
" Threshold for detecting speech with Silero VAD. (Default: 0.35)\n",
|
370 |
+
" Low threshold reduces false positives for silence detection.\n",
|
371 |
+
"\n",
|
372 |
+
" vad_onnx: bool\n",
|
373 |
+
" Whether to use ONNX for Silero VAD. (Default: False)\n",
|
374 |
+
"\n",
|
375 |
+
" min_word_dur: float\n",
|
376 |
+
"        Only allow suppressing timestamps that result in word durations greater than this value. (Default: 0.1)\n",
|
377 |
+
"\n",
|
378 |
+
" only_voice_freq: bool\n",
|
379 |
+
"        Whether to only use sound between 200 - 5000 Hz, where the majority of human speech is. (Default: False)\n",
|
380 |
+
"\n",
|
381 |
+
" only_ffmpeg: bool\n",
|
382 |
+
"        Whether to use only FFmpeg (and not yt-dlp) for URLs. (Default: False)\n",
|
383 |
+
"\n",
|
384 |
+
" Returns\n",
|
385 |
+
" -------\n",
|
386 |
+
" An instance of WhisperResult.\n",
|
387 |
+
" \n"
|
388 |
+
]
|
389 |
+
}
|
390 |
+
],
|
391 |
+
"source": [
|
392 |
+
"print(stable_whisper.transcribe_any.__doc__)"
|
393 |
+
]
|
394 |
+
},
|
395 |
+
{
|
396 |
+
"cell_type": "code",
|
397 |
+
"execution_count": null,
|
398 |
+
"id": "a99ee627-6ab4-411d-ba27-d372d3647593",
|
399 |
+
"metadata": {},
|
400 |
+
"outputs": [],
|
401 |
+
"source": []
|
402 |
+
}
|
403 |
+
],
|
404 |
+
"metadata": {
|
405 |
+
"kernelspec": {
|
406 |
+
"display_name": "Python 3 (ipykernel)",
|
407 |
+
"language": "python",
|
408 |
+
"name": "python3"
|
409 |
+
},
|
410 |
+
"language_info": {
|
411 |
+
"codemirror_mode": {
|
412 |
+
"name": "ipython",
|
413 |
+
"version": 3
|
414 |
+
},
|
415 |
+
"file_extension": ".py",
|
416 |
+
"mimetype": "text/x-python",
|
417 |
+
"name": "python",
|
418 |
+
"nbconvert_exporter": "python",
|
419 |
+
"pygments_lexer": "ipython3",
|
420 |
+
"version": "3.8.15"
|
421 |
+
}
|
422 |
+
},
|
423 |
+
"nbformat": 4,
|
424 |
+
"nbformat_minor": 5
|
425 |
+
}
|
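As a quick reference outside Jupyter, here is a minimal sketch of wiring a third-party ASR into `transcribe_any`, following the mappings described in the notebook above; the `my_asr` function, its `language` keyword, and the file names are illustrative placeholders, not part of this commit.

```python
import stable_whisper

def my_asr(audio, language='en') -> list:
    # Placeholder for a real model/API call. The return value uses the
    # "essential mapping" from the notebook: a list of segments, each a list
    # of {'word', 'start', 'end'} dicts (word timings may be omitted).
    return [
        [
            {'word': ' Hello,', 'start': 0.0, 'end': 0.5},
            {'word': ' world.', 'start': 0.5, 'end': 1.0},
        ]
    ]

# transcribe_any calls my_asr on the (optionally preprocessed) audio, then
# applies Stable-ts post-processing such as silence suppression and regrouping.
result = stable_whisper.transcribe_any(
    my_asr,
    'demo.wav',                            # path/URL, waveform, or bytes
    inference_kwargs={'language': 'en'},   # forwarded to my_asr
    vad=True,                              # Silero VAD for the suppression mask
)
result.to_srt_vtt('demo.srt', word_level=False)
```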
setup.py
ADDED
@@ -0,0 +1,40 @@
|
1 |
+
import os
|
2 |
+
from setuptools import setup
|
3 |
+
|
4 |
+
|
5 |
+
def version() -> str:
|
6 |
+
with open(os.path.join(os.path.dirname(__file__), 'stable_whisper/_version.py')) as f:
|
7 |
+
return f.read().split('=')[-1].strip().strip('"').strip("'")
|
8 |
+
|
9 |
+
|
10 |
+
def read_me() -> str:
|
11 |
+
with open('README.md', 'r', encoding='utf-8') as f:
|
12 |
+
return f.read()
|
13 |
+
|
14 |
+
|
15 |
+
setup(
|
16 |
+
name="stable-ts",
|
17 |
+
version=version(),
|
18 |
+
description="Modifies OpenAI's Whisper to produce more reliable timestamps.",
|
19 |
+
long_description=read_me(),
|
20 |
+
long_description_content_type='text/markdown',
|
21 |
+
python_requires=">=3.8",
|
22 |
+
author="Jian",
|
23 |
+
url="https://github.com/jianfch/stable-ts",
|
24 |
+
license="MIT",
|
25 |
+
packages=['stable_whisper'],
|
26 |
+
install_requires=[
|
27 |
+
"numpy",
|
28 |
+
"torch",
|
29 |
+
"torchaudio",
|
30 |
+
"tqdm",
|
31 |
+
"more-itertools",
|
32 |
+
"transformers>=4.19.0",
|
33 |
+
"ffmpeg-python==0.2.0",
|
34 |
+
"openai-whisper==20231117"
|
35 |
+
],
|
36 |
+
entry_points={
|
37 |
+
"console_scripts": ["stable-ts=stable_whisper.whisper_word_level:cli"],
|
38 |
+
},
|
39 |
+
include_package_data=False
|
40 |
+
)
|
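For clarity, the string handling in the `version()` helper above reduces to the following sketch; the literal comes from `stable_whisper/_version.py`, which is added later in this same commit.

```python
# version() reads stable_whisper/_version.py and strips the assignment:
line = '__version__ = "2.14.3"'
version = line.split('=')[-1].strip().strip('"').strip("'")
assert version == "2.14.3"
```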
silence_suppresion0.png
ADDED
silence_suppresion1.png
ADDED
stable_whisper/__init__.py
ADDED
@@ -0,0 +1,8 @@
|
1 |
+
from .whisper_word_level import *
|
2 |
+
from .result import *
|
3 |
+
from .text_output import *
|
4 |
+
from .video_output import *
|
5 |
+
from .stabilization import visualize_suppression
|
6 |
+
from .non_whisper import transcribe_any
|
7 |
+
from ._version import __version__
|
8 |
+
from .whisper_compatibility import _required_whisper_ver, _COMPATIBLE_WHISPER_VERSIONS
|
stable_whisper/__main__.py
ADDED
@@ -0,0 +1,3 @@
|
1 |
+
from .whisper_word_level import cli
|
2 |
+
|
3 |
+
cli()
|
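Because `__main__.py` calls `cli()` at import time, `python -m stable_whisper` behaves like the `stable-ts` console script registered in setup.py. Below is a small sketch of the equivalent invocation from Python; `--help` is only an assumed example of a CLI argument.

```python
import runpy
import sys

# Mimic `python -m stable_whisper --help`; runpy executes the package's
# __main__ submodule, which in turn runs cli().
sys.argv = ['stable-ts', '--help']
runpy.run_module('stable_whisper', run_name='__main__')
```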
stable_whisper/_version.py
ADDED
@@ -0,0 +1 @@
|
1 |
+
__version__ = "2.14.3"
|
stable_whisper/alignment.py
ADDED
@@ -0,0 +1,1265 @@
|
1 |
+
import copy
|
2 |
+
import re
|
3 |
+
import warnings
|
4 |
+
|
5 |
+
import torch
|
6 |
+
import numpy as np
|
7 |
+
from tqdm import tqdm
|
8 |
+
from typing import TYPE_CHECKING, Union, List, Callable, Optional, Tuple
|
9 |
+
|
10 |
+
import whisper
|
11 |
+
from whisper.audio import (
|
12 |
+
SAMPLE_RATE, N_FRAMES, N_SAMPLES, N_FFT, pad_or_trim, log_mel_spectrogram, FRAMES_PER_SECOND, CHUNK_LENGTH
|
13 |
+
)
|
14 |
+
|
15 |
+
from .result import WhisperResult, Segment
|
16 |
+
from .timing import add_word_timestamps_stable, split_word_tokens
|
17 |
+
from .audio import prep_audio
|
18 |
+
from .utils import safe_print, format_timestamp
|
19 |
+
from .whisper_compatibility import warn_compatibility_issues, get_tokenizer
|
20 |
+
from .stabilization import get_vad_silence_func, wav2mask, mask2timing
|
21 |
+
|
22 |
+
if TYPE_CHECKING:
|
23 |
+
from whisper.model import Whisper
|
24 |
+
|
25 |
+
__all__ = ['align', 'refine', 'locate']
|
26 |
+
|
27 |
+
|
28 |
+
def align(
|
29 |
+
model: "Whisper",
|
30 |
+
audio: Union[str, np.ndarray, torch.Tensor, bytes],
|
31 |
+
text: Union[str, List[int], WhisperResult],
|
32 |
+
language: str = None,
|
33 |
+
*,
|
34 |
+
verbose: Optional[bool] = False,
|
35 |
+
regroup: bool = True,
|
36 |
+
suppress_silence: bool = True,
|
37 |
+
suppress_word_ts: bool = True,
|
38 |
+
use_word_position: bool = True,
|
39 |
+
min_word_dur: bool = 0.1,
|
40 |
+
nonspeech_error: float = 0.3,
|
41 |
+
q_levels: int = 20,
|
42 |
+
k_size: int = 5,
|
43 |
+
vad: bool = False,
|
44 |
+
vad_threshold: float = 0.35,
|
45 |
+
vad_onnx: bool = False,
|
46 |
+
demucs: Union[bool, torch.nn.Module] = False,
|
47 |
+
demucs_output: str = None,
|
48 |
+
demucs_options: dict = None,
|
49 |
+
only_voice_freq: bool = False,
|
50 |
+
prepend_punctuations: str = "\"'“¿([{-",
|
51 |
+
append_punctuations: str = "\"'.。,,!!??::”)]}、",
|
52 |
+
progress_callback: Callable = None,
|
53 |
+
ignore_compatibility: bool = False,
|
54 |
+
remove_instant_words: bool = False,
|
55 |
+
token_step: int = 100,
|
56 |
+
original_split: bool = False,
|
57 |
+
word_dur_factor: Optional[float] = 2.0,
|
58 |
+
max_word_dur: Optional[float] = 3.0,
|
59 |
+
nonspeech_skip: Optional[float] = 3.0,
|
60 |
+
fast_mode: bool = False,
|
61 |
+
tokenizer: "Tokenizer" = None
|
62 |
+
) -> Union[WhisperResult, None]:
|
63 |
+
"""
|
64 |
+
Align plain text or tokens with audio at word-level.
|
65 |
+
|
66 |
+
Since this is significantly faster than transcribing, it is a more efficient method for testing various settings
|
67 |
+
without re-transcribing. This is also useful for timing a more correct transcript than one that Whisper can produce.
|
68 |
+
|
69 |
+
Parameters
|
70 |
+
----------
|
71 |
+
model : "Whisper"
|
72 |
+
The modified Whisper ASR model instance
|
73 |
+
audio : str or numpy.ndarray or torch.Tensor or bytes
|
74 |
+
Path/URL to the audio file, the audio waveform, or bytes of audio file.
|
75 |
+
If audio is :class:`numpy.ndarray` or :class:`torch.Tensor`, the audio must already be sampled to 16kHz.
|
76 |
+
text : str or list of int or stable_whisper.result.WhisperResult
|
77 |
+
String of plain-text, list of tokens, or instance of :class:`stable_whisper.result.WhisperResult`.
|
78 |
+
language : str, default None, uses ``language`` in ``text`` if it is a :class:`stable_whisper.result.WhisperResult`
|
79 |
+
Language of ``text``. Required if ``text`` does not contain ``language``.
|
80 |
+
remove_instant_words : bool, default False
|
81 |
+
Whether to truncate any words with zero duration.
|
82 |
+
token_step : int, default 100
|
83 |
+
Max number of tokens to align in each pass. Use higher values to reduce the chance of misalignment.
|
84 |
+
original_split : bool, default False
|
85 |
+
Whether to preserve the original segment groupings. Segments are split by line break if ``text`` is plain-text.
|
86 |
+
max_word_dur : float or None, default 3.0
|
87 |
+
Global maximum word duration in seconds. Re-align words that exceed the global maximum word duration.
|
88 |
+
word_dur_factor : float or None, default 2.0
|
89 |
+
Factor to compute the local maximum word duration, which is ``word_dur_factor`` * local median word duration.
|
90 |
+
Words that need re-alignment are re-aligned with duration <= local/global maximum word duration.
|
91 |
+
nonspeech_skip : float or None, default 3.0
|
92 |
+
Skip non-speech sections that are equal to or longer than this duration in seconds. Disable skipping if ``None``.
|
93 |
+
fast_mode : bool, default False
|
94 |
+
Whether to speed up alignment by re-alignment with local/global maximum word duration.
|
95 |
+
``True`` tends to produce better timestamps when ``text`` is accurate and there are no large speechless gaps.
|
96 |
+
tokenizer : "Tokenizer", default None, meaning a new tokenizer is created according to ``language`` and ``model``
|
97 |
+
A tokenizer used to tokenize text and detokenize tokens.
|
98 |
+
verbose : bool or None, default False
|
99 |
+
Whether to display the text being decoded to the console.
|
100 |
+
Displays all the details if ``True``. Displays progressbar if ``False``. Display nothing if ``None``.
|
101 |
+
regroup : bool or str, default True, meaning the default regroup algorithm
|
102 |
+
String for customizing the regrouping algorithm. False disables regrouping.
|
103 |
+
Ignored if ``word_timestamps = False``.
|
104 |
+
suppress_silence : bool, default True
|
105 |
+
Whether to enable timestamps adjustments based on the detected silence.
|
106 |
+
suppress_word_ts : bool, default True
|
107 |
+
Whether to adjust word timestamps based on the detected silence. Only enabled if ``suppress_silence = True``.
|
108 |
+
use_word_position : bool, default True
|
109 |
+
Whether to use position of the word in its segment to determine whether to keep end or start timestamps if
|
110 |
+
adjustments are required. If it is the first word, keep the end. Else if it is the last word, keep the start.
|
111 |
+
q_levels : int, default 20
|
112 |
+
Quantization levels for generating timestamp suppression mask; ignored if ``vad = true``.
|
113 |
+
Acts as a threshold to marking sound as silent.
|
114 |
+
Fewer levels will increase the threshold of volume at which to mark a sound as silent.
|
115 |
+
k_size : int, default 5
|
116 |
+
Kernel size for avg-pooling waveform to generate timestamp suppression mask; ignored if ``vad = true``.
|
117 |
+
Recommend 5 or 3; higher sizes will reduce detection of silence.
|
118 |
+
demucs : bool or torch.nn.Module, default False
|
119 |
+
Whether to preprocess ``audio`` with Demucs to isolate vocals / remove noise. Set ``demucs`` to an instance of
|
120 |
+
a Demucs model to avoid reloading the model for each run.
|
121 |
+
Demucs must be installed to use. Official repo, https://github.com/facebookresearch/demucs.
|
122 |
+
demucs_output : str, optional
|
123 |
+
Path to save the vocals isolated by Demucs as WAV file. Ignored if ``demucs = False``.
|
124 |
+
Demucs must be installed to use. Official repo, https://github.com/facebookresearch/demucs.
|
125 |
+
demucs_options : dict, optional
|
126 |
+
Options to use for :func:`stable_whisper.audio.demucs_audio`.
|
127 |
+
vad : bool, default False
|
128 |
+
Whether to use Silero VAD to generate timestamp suppression mask.
|
129 |
+
Silero VAD requires PyTorch 1.12.0+. Official repo, https://github.com/snakers4/silero-vad.
|
130 |
+
vad_threshold : float, default 0.35
|
131 |
+
Threshold for detecting speech with Silero VAD. Low threshold reduces false positives for silence detection.
|
132 |
+
vad_onnx : bool, default False
|
133 |
+
Whether to use ONNX for Silero VAD.
|
134 |
+
min_word_dur : float, default 0.1
|
135 |
+
Shortest duration each word is allowed to reach for silence suppression.
|
136 |
+
nonspeech_error : float, default 0.3
|
137 |
+
Relative error of non-speech sections that appear in between a word for silence suppression.
|
138 |
+
only_voice_freq : bool, default False
|
139 |
+
Whether to only use sound between 200 - 5000 Hz, where the majority of human speech is.
|
140 |
+
prepend_punctuations : str, default '"'“¿([{-)'
|
141 |
+
Punctuations to prepend to next word.
|
142 |
+
append_punctuations : str, default '.。,,!!??::”)]}、)'
|
143 |
+
Punctuations to append to previous word.
|
144 |
+
progress_callback : Callable, optional
|
145 |
+
A function that will be called when transcription progress is updated.
|
146 |
+
The callback needs two parameters.
|
147 |
+
The first parameter is a float for seconds of the audio that has been transcribed.
|
148 |
+
The second parameter is a float for total duration of audio in seconds.
|
149 |
+
ignore_compatibility : bool, default False
|
150 |
+
Whether to ignore warnings for compatibility issues with the detected Whisper version.
|
151 |
+
|
152 |
+
Returns
|
153 |
+
-------
|
154 |
+
stable_whisper.result.WhisperResult or None
|
155 |
+
All timestamps, words, probabilities, and other data from the alignment of ``audio``. Return None if alignment
|
156 |
+
fails and ``remove_instant_words = True``.
|
157 |
+
|
158 |
+
Notes
|
159 |
+
-----
|
160 |
+
If ``token_step`` is less than 1, ``token_step`` will be set to its maximum value, 442. This value is computed with
|
161 |
+
``whisper.model.Whisper.dims.n_text_ctx`` - 6.
|
162 |
+
|
163 |
+
If ``original_split = True`` and a line break is found in the middle of a word in ``text``, the split will occur after
|
164 |
+
that word.
|
165 |
+
|
166 |
+
``regroup`` is ignored if ``original_split = True``.
|
167 |
+
|
168 |
+
Examples
|
169 |
+
--------
|
170 |
+
>>> import stable_whisper
|
171 |
+
>>> model = stable_whisper.load_model('base')
|
172 |
+
>>> result = model.align('helloworld.mp3', 'Hello, World!', 'English')
|
173 |
+
>>> result.to_srt_vtt('helloworld.srt')
|
174 |
+
Saved 'helloworld.srt'
|
175 |
+
"""
|
176 |
+
is_faster_model = model.__module__.startswith('faster_whisper.')
|
177 |
+
if demucs_options is None:
|
178 |
+
demucs_options = {}
|
179 |
+
if demucs_output:
|
180 |
+
if 'save_path' not in demucs_options:
|
181 |
+
demucs_options['save_path'] = demucs_output
|
182 |
+
warnings.warn('``demucs_output`` is deprecated. Use ``demucs_options`` with ``save_path`` instead. '
|
183 |
+
'E.g. demucs_options=dict(save_path="demucs_output.mp3")',
|
184 |
+
DeprecationWarning, stacklevel=2)
|
185 |
+
max_token_step = (model.max_length if is_faster_model else model.dims.n_text_ctx) - 6
|
186 |
+
if token_step < 1:
|
187 |
+
token_step = max_token_step
|
188 |
+
elif token_step > max_token_step:
|
189 |
+
raise ValueError(f'The max value for [token_step] is {max_token_step} but got {token_step}.')
|
190 |
+
|
191 |
+
warn_compatibility_issues(whisper, ignore_compatibility)
|
192 |
+
split_indices_by_char = []
|
193 |
+
if isinstance(text, WhisperResult):
|
194 |
+
if language is None:
|
195 |
+
language = text.language
|
196 |
+
if original_split and len(text.segments) > 1 and text.has_words:
|
197 |
+
split_indices_by_char = np.cumsum([sum(len(w.word) for w in seg.words) for seg in text.segments])
|
198 |
+
text = text.all_tokens() if text.has_words and all(w.tokens for w in text.all_words()) else text.text
|
199 |
+
elif isinstance(text, str):
|
200 |
+
if original_split and '\n' in text:
|
201 |
+
text_split = [line if line.startswith(' ') else ' '+line for line in text.splitlines()]
|
202 |
+
split_indices_by_char = np.cumsum([len(seg) for seg in text_split])
|
203 |
+
text = ''.join(re.sub(r'\s', ' ', seg) for seg in text_split)
|
204 |
+
else:
|
205 |
+
text = re.sub(r'\s', ' ', text)
|
206 |
+
if not text.startswith(' '):
|
207 |
+
text = ' ' + text
|
208 |
+
if language is None:
|
209 |
+
raise TypeError('expected argument for language')
|
210 |
+
if tokenizer is None:
|
211 |
+
tokenizer = get_tokenizer(model, is_faster_model=is_faster_model, language=language, task='transcribe')
|
212 |
+
tokens = tokenizer.encode(text) if isinstance(text, str) else text
|
213 |
+
tokens = [t for t in tokens if t < tokenizer.eot]
|
214 |
+
_, (words, word_tokens), _ = split_word_tokens([dict(tokens=tokens)], tokenizer)
|
215 |
+
|
216 |
+
audio = prep_audio(
|
217 |
+
audio,
|
218 |
+
demucs=demucs,
|
219 |
+
demucs_options=demucs_options,
|
220 |
+
only_voice_freq=only_voice_freq,
|
221 |
+
verbose=verbose
|
222 |
+
)
|
223 |
+
|
224 |
+
sample_padding = int(N_FFT // 2) + 1
|
225 |
+
seek_sample = 0
|
226 |
+
total_samples = audio.shape[-1]
|
227 |
+
total_duration = round(total_samples / SAMPLE_RATE, 2)
|
228 |
+
total_words = len(words)
|
229 |
+
|
230 |
+
if is_faster_model:
|
231 |
+
def timestamp_words():
|
232 |
+
temp_segment = dict(
|
233 |
+
seek=0,
|
234 |
+
start=0.0,
|
235 |
+
end=round(segment_samples / model.feature_extractor.sampling_rate, 3),
|
236 |
+
tokens=[t for wt in curr_word_tokens for t in wt],
|
237 |
+
)
|
238 |
+
features = model.feature_extractor(audio_segment.numpy())
|
239 |
+
encoder_output = model.encode(features[:, : model.feature_extractor.nb_max_frames])
|
240 |
+
|
241 |
+
model.add_word_timestamps(
|
242 |
+
segments=[temp_segment],
|
243 |
+
tokenizer=tokenizer,
|
244 |
+
encoder_output=encoder_output,
|
245 |
+
num_frames=round(segment_samples / model.feature_extractor.hop_length),
|
246 |
+
prepend_punctuations=prepend_punctuations,
|
247 |
+
append_punctuations=append_punctuations,
|
248 |
+
last_speech_timestamp=temp_segment['start'],
|
249 |
+
)
|
250 |
+
|
251 |
+
cumsum_lens = np.cumsum([len(w) for w in curr_words]).tolist()
|
252 |
+
final_cumsum_lens = np.cumsum([len(w['word']) for w in temp_segment['words']]).tolist()
|
253 |
+
|
254 |
+
assert not (set(final_cumsum_lens) - set(cumsum_lens)), 'word mismatch'
|
255 |
+
prev_l_idx = 0
|
256 |
+
for w_idx, cs_len in enumerate(final_cumsum_lens):
|
257 |
+
temp_segment['words'][w_idx]['start'] = round(temp_segment['words'][w_idx]['start'] + time_offset, 3)
|
258 |
+
temp_segment['words'][w_idx]['end'] = round(temp_segment['words'][w_idx]['end'] + time_offset, 3)
|
259 |
+
l_idx = cumsum_lens.index(cs_len)+1
|
260 |
+
temp_segment['words'][w_idx]['tokens'] = [t for wt in curr_word_tokens[prev_l_idx:l_idx] for t in wt]
|
261 |
+
prev_l_idx = l_idx
|
262 |
+
|
263 |
+
return temp_segment
|
264 |
+
|
265 |
+
else:
|
266 |
+
def timestamp_words():
|
267 |
+
temp_segment = dict(
|
268 |
+
seek=time_offset,
|
269 |
+
tokens=(curr_words, curr_word_tokens)
|
270 |
+
)
|
271 |
+
|
272 |
+
mel_segment = log_mel_spectrogram(audio_segment, model.dims.n_mels, padding=sample_padding)
|
273 |
+
mel_segment = pad_or_trim(mel_segment, N_FRAMES).to(device=model.device)
|
274 |
+
|
275 |
+
add_word_timestamps_stable(
|
276 |
+
segments=[temp_segment],
|
277 |
+
model=model,
|
278 |
+
tokenizer=tokenizer,
|
279 |
+
mel=mel_segment,
|
280 |
+
num_samples=segment_samples,
|
281 |
+
split_callback=(lambda x, _: x),
|
282 |
+
prepend_punctuations=prepend_punctuations,
|
283 |
+
append_punctuations=append_punctuations,
|
284 |
+
gap_padding=None
|
285 |
+
)
|
286 |
+
|
287 |
+
return temp_segment
|
288 |
+
|
289 |
+
def get_curr_words():
|
290 |
+
nonlocal words, word_tokens
|
291 |
+
curr_tk_count = 0
|
292 |
+
w, wt = [], []
|
293 |
+
for _ in range(len(words)):
|
294 |
+
tk_count = len(word_tokens[0])
|
295 |
+
if curr_tk_count + tk_count > token_step and w:
|
296 |
+
break
|
297 |
+
w.append(words.pop(0))
|
298 |
+
wt.append(word_tokens.pop(0))
|
299 |
+
curr_tk_count += tk_count
|
300 |
+
return w, wt
|
301 |
+
result = []
|
302 |
+
|
303 |
+
nonspeech_timings = [[], []]
|
304 |
+
nonspeech_vad_timings = None
|
305 |
+
if (suppress_silence or nonspeech_skip is not None) and vad:
|
306 |
+
nonspeech_vad_timings = (
|
307 |
+
get_vad_silence_func(onnx=vad_onnx, verbose=verbose)(audio, speech_threshold=vad_threshold)
|
308 |
+
)
|
309 |
+
if nonspeech_vad_timings is not None:
|
310 |
+
nonspeech_timings = nonspeech_vad_timings[0].copy(), nonspeech_vad_timings[1].copy()
|
311 |
+
|
312 |
+
with tqdm(total=total_duration, unit='sec', disable=verbose is not False, desc='Align') as tqdm_pbar:
|
313 |
+
|
314 |
+
def update_pbar(finish: bool = False):
|
315 |
+
tqdm_pbar.update((total_duration if finish else min(round(last_ts, 2), total_duration)) - tqdm_pbar.n)
|
316 |
+
if progress_callback is not None:
|
317 |
+
progress_callback(seek=tqdm_pbar.n, total=tqdm_pbar.total)
|
318 |
+
|
319 |
+
def redo_words(_idx: int = None):
|
320 |
+
nonlocal seg_words, seg_tokens, seg_words, words, word_tokens, curr_words, temp_word
|
321 |
+
if curr_words and temp_word is not None:
|
322 |
+
assert curr_words[0]['word'] == temp_word['word']
|
323 |
+
if curr_words[0]['probability'] >= temp_word['probability']:
|
324 |
+
temp_word = curr_words[0]
|
325 |
+
if _idx is None: # redo all
|
326 |
+
words = seg_words + words
|
327 |
+
word_tokens = seg_tokens + word_tokens
|
328 |
+
curr_words = []
|
329 |
+
elif _idx != len(seg_words): # redo from _idx
|
330 |
+
words = seg_words[_idx:] + words
|
331 |
+
word_tokens = seg_tokens[_idx:] + word_tokens
|
332 |
+
curr_words = curr_words[:_idx]
|
333 |
+
if curr_words:
|
334 |
+
if temp_word is not None:
|
335 |
+
curr_words[0] = temp_word
|
336 |
+
temp_word = None
|
337 |
+
words = seg_words[_idx-1:_idx] + words
|
338 |
+
word_tokens = seg_tokens[_idx-1:_idx] + word_tokens
|
339 |
+
temp_word = curr_words.pop(-1)
|
340 |
+
else:
|
341 |
+
if temp_word is not None:
|
342 |
+
curr_words[0] = temp_word
|
343 |
+
temp_word = None
|
344 |
+
|
345 |
+
n_samples = model.feature_extractor.n_samples if is_faster_model else N_SAMPLES
|
346 |
+
|
347 |
+
temp_word = None
|
348 |
+
|
349 |
+
while words and seek_sample < total_samples:
|
350 |
+
|
351 |
+
time_offset = seek_sample / SAMPLE_RATE
|
352 |
+
seek_sample_end = seek_sample + n_samples
|
353 |
+
audio_segment = audio[seek_sample:seek_sample_end]
|
354 |
+
segment_samples = audio_segment.shape[-1]
|
355 |
+
|
356 |
+
if nonspeech_skip is not None:
|
357 |
+
segment_nonspeech_timings = None
|
358 |
+
if not vad:
|
359 |
+
ts_token_mask = wav2mask(audio_segment, q_levels=q_levels, k_size=k_size)
|
360 |
+
segment_nonspeech_timings = mask2timing(ts_token_mask, time_offset=time_offset)
|
361 |
+
if segment_nonspeech_timings is not None:
|
362 |
+
nonspeech_timings[0].extend(segment_nonspeech_timings[0])
|
363 |
+
nonspeech_timings[1].extend(segment_nonspeech_timings[1])
|
364 |
+
elif nonspeech_vad_timings:
|
365 |
+
timing_indices = np.logical_and(
|
366 |
+
nonspeech_vad_timings[1] > time_offset,
|
367 |
+
nonspeech_vad_timings[0] < time_offset + 30.0
|
368 |
+
)
|
369 |
+
|
370 |
+
if timing_indices.any():
|
371 |
+
segment_nonspeech_timings = (
|
372 |
+
nonspeech_vad_timings[0][timing_indices], nonspeech_vad_timings[1][timing_indices]
|
373 |
+
)
|
374 |
+
else:
|
375 |
+
segment_nonspeech_timings = None
|
376 |
+
|
377 |
+
if mn := timing_indices.argmax():
|
378 |
+
nonspeech_vad_timings = (nonspeech_vad_timings[0][mn:], nonspeech_vad_timings[1][mn:])
|
379 |
+
|
380 |
+
if segment_nonspeech_timings is not None:
|
381 |
+
# segment has no detectable speech
|
382 |
+
if (
|
383 |
+
(segment_nonspeech_timings[0][0] <= time_offset + min_word_dur) and
|
384 |
+
(segment_nonspeech_timings[1][0] >= time_offset + segment_samples - min_word_dur)
|
385 |
+
):
|
386 |
+
seek_sample += segment_samples
|
387 |
+
continue
|
388 |
+
|
389 |
+
timing_indices = (segment_nonspeech_timings[1] - segment_nonspeech_timings[0]) >= nonspeech_skip
|
390 |
+
if any(timing_indices):
|
391 |
+
nonspeech_starts = segment_nonspeech_timings[0][timing_indices]
|
392 |
+
nonspeech_ends = segment_nonspeech_timings[1][timing_indices]
|
393 |
+
|
394 |
+
if round(time_offset, 3) >= nonspeech_starts[0]:
|
395 |
+
seek_sample = round(nonspeech_ends[0] * SAMPLE_RATE)
|
396 |
+
if seek_sample + (min_word_dur * SAMPLE_RATE) >= total_samples:
|
397 |
+
seek_sample = total_samples
|
398 |
+
continue
|
399 |
+
time_offset = seek_sample / SAMPLE_RATE
|
400 |
+
|
401 |
+
if len(nonspeech_starts) > 1:
|
402 |
+
seek_sample_end = (
|
403 |
+
seek_sample + round((nonspeech_starts[1] - nonspeech_ends[0]) * SAMPLE_RATE)
|
404 |
+
)
|
405 |
+
audio_segment = audio[seek_sample:seek_sample_end]
|
406 |
+
segment_samples = audio_segment.shape[-1]
|
407 |
+
|
408 |
+
curr_words, curr_word_tokens = get_curr_words()
|
409 |
+
|
410 |
+
segment = timestamp_words()
|
411 |
+
curr_words = segment['words']
|
412 |
+
seg_words = [w['word'] for w in curr_words]
|
413 |
+
seg_tokens = [w['tokens'] for w in curr_words]
|
414 |
+
durations = np.array([w['end'] - w['start'] for w in curr_words]).round(3)
|
415 |
+
nonzero_mask = durations > 0
|
416 |
+
nonzero_indices = np.flatnonzero(nonzero_mask)
|
417 |
+
if len(nonzero_indices):
|
418 |
+
redo_index = nonzero_indices[-1] + 1
|
419 |
+
if (
|
420 |
+
words and
|
421 |
+
redo_index > 1 and
|
422 |
+
curr_words[nonzero_indices[-1]]['end'] >= np.floor(time_offset + segment_samples / SAMPLE_RATE)
|
423 |
+
):
|
424 |
+
nonzero_mask[nonzero_indices[-1]] = False
|
425 |
+
nonzero_indices = nonzero_indices[:-1]
|
426 |
+
redo_index = nonzero_indices[-1] + 1
|
427 |
+
med_dur = np.median(durations[:redo_index])
|
428 |
+
|
429 |
+
if fast_mode:
|
430 |
+
new_start = None
|
431 |
+
global_max_dur = None
|
432 |
+
else:
|
433 |
+
local_max_dur = round(med_dur * word_dur_factor, 3) if word_dur_factor else None
|
434 |
+
if max_word_dur:
|
435 |
+
local_max_dur = min(local_max_dur, max_word_dur) if local_max_dur else max_word_dur
|
436 |
+
global_max_dur = max_word_dur
|
437 |
+
else:
|
438 |
+
global_max_dur = local_max_dur or None
|
439 |
+
if global_max_dur and med_dur > global_max_dur:
|
440 |
+
med_dur = global_max_dur
|
441 |
+
if (
|
442 |
+
local_max_dur and durations[nonzero_indices[0]] > global_max_dur
|
443 |
+
):
|
444 |
+
new_start = round(max(
|
445 |
+
curr_words[nonzero_indices[0]]['end'] - (med_dur * nonzero_indices[0] + local_max_dur),
|
446 |
+
curr_words[nonzero_indices[0]]['start']
|
447 |
+
), 3)
|
448 |
+
if new_start <= time_offset:
|
449 |
+
new_start = None
|
450 |
+
else:
|
451 |
+
new_start = None
|
452 |
+
if new_start is None:
|
453 |
+
if global_max_dur:
|
454 |
+
index_offset = nonzero_indices[0] + 1
|
455 |
+
redo_indices = \
|
456 |
+
np.flatnonzero(durations[index_offset:redo_index] > global_max_dur) + index_offset
|
457 |
+
if len(redo_indices):
|
458 |
+
redo_index = redo_indices[0]
|
459 |
+
last_ts = curr_words[redo_index - 1]['end']
|
460 |
+
redo_words(redo_index)
|
461 |
+
else:
|
462 |
+
last_ts = new_start
|
463 |
+
redo_words()
|
464 |
+
seek_sample = round(last_ts * SAMPLE_RATE)
|
465 |
+
else:
|
466 |
+
seek_sample += audio_segment.shape[-1]
|
467 |
+
last_ts = round(seek_sample / SAMPLE_RATE, 2)
|
468 |
+
redo_words()
|
469 |
+
|
470 |
+
update_pbar()
|
471 |
+
|
472 |
+
result.extend(curr_words)
|
473 |
+
|
474 |
+
if verbose:
|
475 |
+
line = '\n'.join(
|
476 |
+
f"[{format_timestamp(word['start'])}] -> "
|
477 |
+
f"[{format_timestamp(word['end'])}] \"{word['word']}\""
|
478 |
+
for word in curr_words
|
479 |
+
)
|
480 |
+
safe_print(line)
|
481 |
+
update_pbar(True)
|
482 |
+
|
483 |
+
if temp_word is not None:
|
484 |
+
result.append(temp_word)
|
485 |
+
if not result:
|
486 |
+
warnings.warn('Failed to align text.', stacklevel=2)
|
487 |
+
elif words:
|
488 |
+
warnings.warn(f'Failed to align the last {len(words)}/{total_words} words after '
|
489 |
+
f'{format_timestamp(result[-1]["end"])}.', stacklevel=2)
|
490 |
+
|
491 |
+
if words and not remove_instant_words:
|
492 |
+
result.extend(
|
493 |
+
[
|
494 |
+
dict(word=w, start=total_duration, end=total_duration, probability=0.0, tokens=wt)
|
495 |
+
for w, wt in zip(words, word_tokens)
|
496 |
+
]
|
497 |
+
)
|
498 |
+
|
499 |
+
if not result:
|
500 |
+
return
|
501 |
+
|
502 |
+
if len(split_indices_by_char):
|
503 |
+
word_lens = np.cumsum([[len(w['word']) for w in result]])
|
504 |
+
split_indices = [(word_lens >= i).nonzero()[0][0]+1 for i in split_indices_by_char]
|
505 |
+
result = WhisperResult([result[i:j] for i, j in zip([0]+split_indices[:-1], split_indices)])
|
506 |
+
else:
|
507 |
+
result = WhisperResult([result])
|
508 |
+
|
509 |
+
if suppress_silence:
|
510 |
+
result.suppress_silence(
|
511 |
+
*nonspeech_timings,
|
512 |
+
min_word_dur=min_word_dur,
|
513 |
+
word_level=suppress_word_ts,
|
514 |
+
nonspeech_error=nonspeech_error,
|
515 |
+
use_word_position=use_word_position
|
516 |
+
)
|
517 |
+
result.update_nonspeech_sections(*nonspeech_timings)
|
518 |
+
if not original_split:
|
519 |
+
result.regroup(regroup)
|
520 |
+
|
521 |
+
if fail_segs := len([None for s in result.segments if s.end-s.start <= 0]):
|
522 |
+
warnings.warn(f'{fail_segs}/{len(result.segments)} segments failed to align.', stacklevel=2)
|
523 |
+
|
524 |
+
return result
|
525 |
+
|
526 |
+
|
527 |
+
def refine(
|
528 |
+
model: "Whisper",
|
529 |
+
audio: Union[str, np.ndarray, torch.Tensor, bytes],
|
530 |
+
result: WhisperResult,
|
531 |
+
*,
|
532 |
+
steps: str = None,
|
533 |
+
rel_prob_decrease: float = .03,
|
534 |
+
abs_prob_decrease: float = .05,
|
535 |
+
rel_rel_prob_decrease: Optional[float] = None,
|
536 |
+
prob_threshold: float = .5,
|
537 |
+
rel_dur_change: Optional[float] = .5,
|
538 |
+
abs_dur_change: Optional[float] = None,
|
539 |
+
word_level: bool = True,
|
540 |
+
precision: float = None,
|
541 |
+
single_batch: bool = False,
|
542 |
+
inplace: bool = True,
|
543 |
+
demucs: Union[bool, torch.nn.Module] = False,
|
544 |
+
demucs_options: dict = None,
|
545 |
+
only_voice_freq: bool = False,
|
546 |
+
verbose: Optional[bool] = False
|
547 |
+
) -> WhisperResult:
|
548 |
+
"""
|
549 |
+
Improve existing timestamps.
|
550 |
+
|
551 |
+
This function iteratively mutes portions of the audio and monitors token probabilities to find the most precise
|
552 |
+
timestamps. "Most precise" in this case means the latest start and earliest end of a word that maintain an
|
553 |
+
acceptable probability determined by the specified arguments.
|
554 |
+
|
555 |
+
This is useful for readjusting timestamps that start too early or end too late.
|
556 |
+
|
557 |
+
Parameters
|
558 |
+
----------
|
559 |
+
model : "Whisper"
|
560 |
+
The modified Whisper ASR model instance
|
561 |
+
audio : str or numpy.ndarray or torch.Tensor or bytes
|
562 |
+
Path/URL to the audio file, the audio waveform, or bytes of audio file.
|
563 |
+
If audio is :class:`numpy.ndarray` or :class:`torch.Tensor`, the audio must already be sampled to 16kHz.
|
564 |
+
result : stable_whisper.result.WhisperResult
|
565 |
+
All timestamps, words, probabilities, and other data from the transcription of ``audio``.
|
566 |
+
steps : str, default 'se'
|
567 |
+
Instructions for refinement. An 's' means refine start-timestamps. An 'e' means refine end-timestamps.
|
568 |
+
rel_prob_decrease : float, default 0.03
|
569 |
+
Maximum percent decrease in probability relative to the original probability, which is the probability from muting
|
570 |
+
according to the initial timestamps.
|
571 |
+
abs_prob_decrease : float, default 0.05
|
572 |
+
Maximum decrease in probability from original probability.
|
573 |
+
rel_rel_prob_decrease : float, optional
|
574 |
+
Maximum percent decrease in probability relative to the previous probability, which is the probability from the previous
|
575 |
+
iteration of muting.
|
576 |
+
prob_threshold : float, default 0.5
|
577 |
+
Stop refining the timestamp if the probability of its token goes below this value.
|
578 |
+
rel_dur_change : float, default 0.5
|
579 |
+
Maximum percent change in duration of a word relative to its original duration.
|
580 |
+
abs_dur_change : float, optional
|
581 |
+
Maximum seconds a word is allowed deviate from its original duration.
|
582 |
+
word_level : bool, default True
|
583 |
+
Whether to refine timestamps on word-level. If ``False``, only refine start/end timestamps of each segment.
|
584 |
+
precision : float, default 0.1
|
585 |
+
Precision of refined timestamps in seconds. The lowest precision is 0.02 second.
|
586 |
+
single_batch : bool, default False
|
587 |
+
Whether to process with a batch size of only one to reduce memory usage.
|
588 |
+
inplace : bool, default True, meaning ``result`` is altered directly rather than a deepcopy
|
589 |
+
Whether to alter timestamps in-place.
|
590 |
+
demucs : bool or torch.nn.Module, default False
|
591 |
+
Whether to preprocess ``audio`` with Demucs to isolate vocals / remove noise. Set ``demucs`` to an instance of
|
592 |
+
a Demucs model to avoid reloading the model for each run.
|
593 |
+
Demucs must be installed to use. Official repo, https://github.com/facebookresearch/demucs.
|
594 |
+
demucs_options : dict, optional
|
595 |
+
Options to use for :func:`stable_whisper.audio.demucs_audio`.
|
596 |
+
only_voice_freq : bool, default False
|
597 |
+
Whether to only use sound between 200 - 5000 Hz, where the majority of human speech is.
|
598 |
+
verbose : bool or None, default False
|
599 |
+
Whether to display the text being decoded to the console.
|
600 |
+
Displays all the details if ``True``. Displays progressbar if ``False``. Display nothing if ``None``.
|
601 |
+
|
602 |
+
Returns
|
603 |
+
-------
|
604 |
+
stable_whisper.result.WhisperResult
|
605 |
+
All timestamps, words, probabilities, and other data from the refinement of ``text`` with ``audio``.
|
606 |
+
|
607 |
+
Notes
|
608 |
+
-----
|
609 |
+
The lower the ``precision``, the longer the processing time.
|
610 |
+
|
611 |
+
Examples
|
612 |
+
--------
|
613 |
+
>>> import stable_whisper
|
614 |
+
>>> model = stable_whisper.load_model('base')
|
615 |
+
>>> result = model.transcribe('audio.mp3')
|
616 |
+
>>> model.refine('audio.mp3', result)
|
617 |
+
>>> result.to_srt_vtt('audio.srt')
|
618 |
+
Saved 'audio.srt'
|
619 |
+
"""
|
620 |
+
if not steps:
|
621 |
+
steps = 'se'
|
622 |
+
if precision is None:
|
623 |
+
precision = 0.1
|
624 |
+
if invalid_steps := steps.replace('s', '').replace('e', ''):
|
625 |
+
raise ValueError(f'Invalid step(s): {", ".join(invalid_steps)}')
|
626 |
+
if not result.has_words:
|
627 |
+
raise NotImplementedError(f'Result must have word timestamps.')
|
628 |
+
|
629 |
+
if not inplace:
|
630 |
+
result = copy.deepcopy(result)
|
631 |
+
|
632 |
+
audio = prep_audio(
|
633 |
+
audio,
|
634 |
+
demucs=demucs,
|
635 |
+
demucs_options=demucs_options,
|
636 |
+
only_voice_freq=only_voice_freq,
|
637 |
+
verbose=verbose
|
638 |
+
)
|
639 |
+
max_inference_tokens = model.dims.n_text_ctx - 6
|
640 |
+
sample_padding = int(N_FFT // 2) + 1
|
641 |
+
frame_precision = max(round(precision * FRAMES_PER_SECOND), 2)
|
642 |
+
total_duration = round(audio.shape[-1] / SAMPLE_RATE, 3)
|
643 |
+
tokenizer = get_tokenizer(model, language=result.language, task='transcribe')
|
644 |
+
|
645 |
+
def ts_to_frames(timestamps: Union[np.ndarray, list]) -> np.ndarray:
|
646 |
+
if isinstance(timestamps, list):
|
647 |
+
timestamps = np.array(timestamps)
|
648 |
+
return (timestamps * FRAMES_PER_SECOND).round().astype(int)
|
649 |
+
|
650 |
+
def curr_segments():
|
651 |
+
all_words = result.all_words()
|
652 |
+
seg_edge_mask = np.array([
|
653 |
+
1 if _i == 0 else (2 if _i == len(seg.words)-1 else 0)
|
654 |
+
for seg in result.segments
|
655 |
+
for _i, w in enumerate(seg.words)
|
656 |
+
])
|
657 |
+
start_times = [
|
658 |
+
max(
|
659 |
+
0 if abs_dur_change is None else (w.start - abs_dur_change),
|
660 |
+
0 if rel_dur_change is None else (w.start - w.duration * rel_dur_change),
|
661 |
+
0 if i == 0 else max(all_words[i - 1].end, w.end - 14.5, 0)
|
662 |
+
)
|
663 |
+
for i, w in enumerate(all_words)
|
664 |
+
]
|
665 |
+
end_times = [
|
666 |
+
min(
|
667 |
+
total_duration if abs_dur_change is None else (w.end + abs_dur_change),
|
668 |
+
total_duration if rel_dur_change is None else (w.end + w.duration * rel_dur_change),
|
669 |
+
total_duration if i == len(all_words) else min(all_words[i].start, w.start + 14.5, total_duration)
|
670 |
+
)
|
671 |
+
for i, w in enumerate(all_words, 1)
|
672 |
+
]
|
673 |
+
start = start_times[0]
|
674 |
+
|
675 |
+
prev_i = 0
|
676 |
+
curr_words, curr_starts, curr_ends = [], [], []
|
677 |
+
|
678 |
+
for i, w in enumerate(all_words, 1):
|
679 |
+
if (
|
680 |
+
(end_times[0] - start > 30) or
|
681 |
+
(len(curr_words) + 1 > max_inference_tokens)
|
682 |
+
):
|
683 |
+
if curr_words:
|
684 |
+
yield curr_words, curr_starts, curr_ends, seg_edge_mask[prev_i:prev_i+len(curr_words)]
|
685 |
+
curr_words, curr_starts, curr_ends = [], [], []
|
686 |
+
start = start_times[0]
|
687 |
+
prev_i = i - 1
|
688 |
+
|
689 |
+
curr_words.append(w)
|
690 |
+
curr_starts.append(start_times.pop(0))
|
691 |
+
curr_ends.append(end_times.pop(0))
|
692 |
+
|
693 |
+
if i == len(all_words):
|
694 |
+
yield curr_words, curr_starts, curr_ends, seg_edge_mask[prev_i:prev_i+len(curr_words)]
|
695 |
+
|
696 |
+
def _refine(_step: str):
|
697 |
+
|
698 |
+
for words, min_starts, max_ends, edge_mask in curr_segments():
|
699 |
+
|
700 |
+
time_offset = min_starts[0]
|
701 |
+
start_sample = round(time_offset * SAMPLE_RATE)
|
702 |
+
end_sample = round(max_ends[-1] * SAMPLE_RATE)
|
703 |
+
audio_segment = audio[start_sample:end_sample + 1].unsqueeze(0)
|
704 |
+
|
705 |
+
max_starts = ts_to_frames(np.array([w.end for w in words]) - time_offset)
|
706 |
+
min_ends = ts_to_frames(np.array([w.start for w in words]) - time_offset)
|
707 |
+
min_starts = ts_to_frames(np.array(min_starts) - time_offset)
|
708 |
+
max_ends = ts_to_frames(np.array(max_ends) - time_offset)
|
709 |
+
|
710 |
+
mid_starts = min_starts + ((max_starts - min_starts) / 2).round().astype(int)
|
711 |
+
mid_ends = min_ends + ((max_ends - min_ends) / 2).round().astype(int)
|
712 |
+
|
713 |
+
text_tokens = [t for w in words for t in w.tokens if t < tokenizer.eot]
|
714 |
+
word_tokens = [[t for t in w.tokens if t < tokenizer.eot] for w in words]
|
715 |
+
orig_mel_segment = log_mel_spectrogram(audio_segment, model.dims.n_mels, padding=sample_padding)
|
716 |
+
orig_mel_segment = pad_or_trim(orig_mel_segment, N_FRAMES).to(device=model.device)
|
717 |
+
|
718 |
+
def get_prob():
|
719 |
+
|
720 |
+
tokens = torch.tensor(
|
721 |
+
[
|
722 |
+
*tokenizer.sot_sequence,
|
723 |
+
tokenizer.no_timestamps,
|
724 |
+
*text_tokens,
|
725 |
+
tokenizer.eot,
|
726 |
+
]
|
727 |
+
).to(model.device)
|
728 |
+
|
729 |
+
with torch.no_grad():
|
730 |
+
curr_mel_segment = mel_segment if prob_indices else orig_mel_segment
|
731 |
+
if single_batch:
|
732 |
+
logits = torch.cat(
|
733 |
+
[model(_mel.unsqueeze(0), tokens.unsqueeze(0)) for _mel in curr_mel_segment]
|
734 |
+
)
|
735 |
+
else:
|
736 |
+
logits = model(curr_mel_segment, tokens.unsqueeze(0))
|
737 |
+
|
738 |
+
sampled_logits = logits[:, len(tokenizer.sot_sequence):, : tokenizer.eot]
|
739 |
+
token_probs = sampled_logits.softmax(dim=-1)
|
740 |
+
|
741 |
+
text_token_probs = token_probs[:, np.arange(len(text_tokens)), text_tokens]
|
742 |
+
token_positions = token_probs[:, np.arange(len(text_tokens))]
|
743 |
+
if logits.shape[0] != 1 and prob_indices is not None:
|
744 |
+
indices1 = np.arange(len(prob_indices))
|
745 |
+
text_token_probs = text_token_probs[prob_indices, indices1]
|
746 |
+
token_positions = token_positions[prob_indices, indices1]
|
747 |
+
else:
|
748 |
+
text_token_probs.squeeze_(0)
|
749 |
+
|
750 |
+
text_token_probs = text_token_probs.tolist()
|
751 |
+
token_positions = \
|
752 |
+
(
|
753 |
+
token_positions.sort().indices == tokens[len(tokenizer.sot_sequence) + 1:-1][:, None]
|
754 |
+
).nonzero()[:, -1].tolist()
|
755 |
+
|
756 |
+
word_boundaries = np.pad(np.cumsum([len(t) for t in word_tokens]), (1, 0))
|
757 |
+
word_probabilities = np.array([
|
758 |
+
text_token_probs[j-1] if is_end_ts else text_token_probs[i]
|
759 |
+
for i, j in zip(word_boundaries[:-1], word_boundaries[1:])
|
760 |
+
])
|
761 |
+
token_positions = [
|
762 |
+
token_positions[j-1] if is_end_ts else token_positions[i]
|
763 |
+
for i, j in zip(word_boundaries[:-1], word_boundaries[1:])
|
764 |
+
]
|
765 |
+
|
766 |
+
return word_probabilities, token_positions
|
767 |
+
|
768 |
+
def update_ts():
|
769 |
+
if not is_finish[idx] or changes[idx, -1] == -1:
|
770 |
+
return
|
771 |
+
new_ts = round(time_offset + (changes[idx, -1] / FRAMES_PER_SECOND), 3)
|
772 |
+
if changes[idx, 0] and not changes[idx, 1]:
|
773 |
+
if is_end_ts:
|
774 |
+
if new_ts <= words[idx].end:
|
775 |
+
return
|
776 |
+
elif new_ts >= words[idx].start:
|
777 |
+
return
|
778 |
+
if not verbose:
|
779 |
+
return
|
780 |
+
curr_word = words[idx]
|
781 |
+
word_info = (f'[Word="{curr_word.word}"] '
|
782 |
+
f'[Segment ID: {curr_word.segment_id}] '
|
783 |
+
f'[Word ID: {curr_word.id}]')
|
784 |
+
if is_end_ts:
|
785 |
+
print(f'End: {words[idx].end} -> {new_ts} {word_info}')
|
786 |
+
words[idx].end = new_ts
|
787 |
+
else:
|
788 |
+
print(f'Start: {words[idx].start} -> {new_ts} {word_info}')
|
789 |
+
words[idx].start = new_ts
|
790 |
+
|
791 |
+
mel_segment = orig_mel_segment.clone().repeat_interleave(2, 0)
|
792 |
+
is_end_ts = _step == 'e'
|
793 |
+
|
794 |
+
prob_indices = []
|
795 |
+
is_finish = np.less([w.probability for w in words], prob_threshold)
|
796 |
+
is_finish = np.logical_or(is_finish, [w.duration == 0 for w in words])
|
797 |
+
if not word_level:
|
798 |
+
is_finish[edge_mask != (2 if is_end_ts else 1)] = True
|
799 |
+
for idx, _i in enumerate(max_starts if is_end_ts else min_ends):
|
800 |
+
row = idx % 2
|
801 |
+
prob_indices.extend([row] * len(words[idx].tokens))
|
802 |
+
if is_finish[idx]:
|
803 |
+
continue
|
804 |
+
if is_end_ts:
|
805 |
+
_p = mel_segment.shape[-1] if idx == len(words)-1 else mid_ends[idx+1]
|
806 |
+
mel_segment[row, :, _i:_p] = 0
|
807 |
+
else:
|
808 |
+
_p = 0 if idx == 0 else mid_starts[idx-1]
|
809 |
+
mel_segment[row, :, _p:_i] = 0
|
810 |
+
orig_probs, orig_tk_poss = get_prob()
|
811 |
+
changes = np.zeros((orig_probs.shape[-1], 3), dtype=int)
|
812 |
+
changes[:, -1] = -1
|
813 |
+
frame_indices = (mid_ends, max_starts) if is_end_ts else (min_ends, mid_starts)
|
814 |
+
for idx, (_s, _e) in enumerate(zip(*frame_indices)):
|
815 |
+
row = idx % 2
|
816 |
+
if is_finish[idx]:
|
817 |
+
continue
|
818 |
+
mel_segment[row, :, _s:_e] = 0
|
819 |
+
|
820 |
+
new_probs = prev_probs = orig_probs
|
821 |
+
while not np.all(is_finish):
|
822 |
+
probs, tk_poss = get_prob()
|
823 |
+
abs_diffs = orig_probs - probs
|
824 |
+
rel_diffs = abs_diffs / orig_probs
|
825 |
+
rel_change_diffs = (prev_probs - probs) / prev_probs
|
826 |
+
prev_probs = probs
|
827 |
+
for idx, (abs_diff, rel_diff, rel_change_diff, prob) \
|
828 |
+
in enumerate(zip(abs_diffs, rel_diffs, rel_change_diffs, probs)):
|
829 |
+
if is_finish[idx]:
|
830 |
+
continue
|
831 |
+
if is_end_ts:
|
832 |
+
curr_min, curr_max, curr_mid = min_ends[idx], max_ends[idx], mid_ends[idx]
|
833 |
+
else:
|
834 |
+
curr_min, curr_max, curr_mid = min_starts[idx], max_starts[idx], mid_starts[idx]
|
835 |
+
|
836 |
+
row = prob_indices[idx]
|
837 |
+
best_tks_changed = orig_tk_poss[idx] > tk_poss[idx]
|
838 |
+
failed_requirements = (
|
839 |
+
abs_diff > abs_prob_decrease or
|
840 |
+
rel_diff > rel_prob_decrease or
|
841 |
+
(rel_rel_prob_decrease is not None and rel_change_diff > rel_rel_prob_decrease) or
|
842 |
+
prob < prob_threshold or
|
843 |
+
best_tks_changed
|
844 |
+
)
|
845 |
+
|
846 |
+
if failed_requirements:
|
847 |
+
changes[idx][0] = 1
|
848 |
+
if is_end_ts:
|
849 |
+
curr_min = curr_mid
|
850 |
+
else:
|
851 |
+
curr_max = curr_mid
|
852 |
+
else:
|
853 |
+
changes[idx][1] = 1
|
854 |
+
if is_end_ts:
|
855 |
+
curr_max = curr_mid
|
856 |
+
else:
|
857 |
+
curr_min = curr_mid
|
858 |
+
|
859 |
+
if (new_mid_change := round((curr_max - curr_min) / 2)) < frame_precision:
|
860 |
+
is_finish[idx] = True
|
861 |
+
update_ts()
|
862 |
+
continue
|
863 |
+
|
864 |
+
new_mid = curr_min + new_mid_change
|
865 |
+
if failed_requirements:
|
866 |
+
if is_end_ts:
|
867 |
+
mel_segment[row, :, curr_min:new_mid] = orig_mel_segment[0, :, curr_min:new_mid]
|
868 |
+
else:
|
869 |
+
mel_segment[row, :, new_mid:curr_max] = orig_mel_segment[0, :, new_mid:curr_max]
|
870 |
+
|
871 |
+
else:
|
872 |
+
if is_end_ts:
|
873 |
+
mel_segment[row, :, new_mid:curr_max] = 0
|
874 |
+
else:
|
875 |
+
mel_segment[row, :, curr_min:new_mid] = 0
|
876 |
+
|
877 |
+
if is_end_ts:
|
878 |
+
min_ends[idx], max_ends[idx], mid_ends[idx] = curr_min, curr_max, new_mid
|
879 |
+
else:
|
880 |
+
min_starts[idx], max_starts[idx], mid_starts[idx] = curr_min, curr_max, new_mid
|
881 |
+
if not best_tks_changed:
|
882 |
+
changes[idx][-1] = new_mid
|
883 |
+
new_probs[idx] = prob
|
884 |
+
|
885 |
+
update_pbar(words[-1].end)
|
886 |
+
|
887 |
+
with tqdm(total=round(total_duration, 2), unit='sec', disable=verbose is not False, desc='Refine') as tqdm_pbar:
|
888 |
+
|
889 |
+
def update_pbar(last_ts: float):
|
890 |
+
nonlocal prev_ts
|
891 |
+
tqdm_pbar.update(round(((last_ts - prev_ts) / len(steps)), 2))
|
892 |
+
prev_ts = last_ts
|
893 |
+
|
894 |
+
for step_count, step in enumerate(steps, 1):
|
895 |
+
prev_ts = 0
|
896 |
+
_refine(step)
|
897 |
+
update_pbar(round(tqdm_pbar.total / len(step), 2))
|
898 |
+
tqdm_pbar.update(tqdm_pbar.total - tqdm_pbar.n)
|
899 |
+
|
900 |
+
result.update_all_segs_with_words()
|
901 |
+
|
902 |
+
return result
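A hedged usage sketch for the refinement routine above (not part of the source file): it assumes this function is exposed as ``model.refine()`` on a stable-ts model instance, as the package README suggests, and 'audio.mp3' is a placeholder path.

import stable_whisper

model = stable_whisper.load_model('base')
result = model.transcribe('audio.mp3')
model.refine('audio.mp3', result)        # adjusts word start/end timestamps in place
result.to_srt_vtt('audio_refined.srt')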
|
903 |
+
|
904 |
+
|
905 |
+
def locate(
|
906 |
+
model: "Whisper",
|
907 |
+
audio: Union[str, np.ndarray, torch.Tensor, bytes],
|
908 |
+
text: Union[str, List[int]],
|
909 |
+
language: str,
|
910 |
+
count: int = 1,
|
911 |
+
duration_window: Union[float, Tuple[float, float]] = 3.0,
|
912 |
+
*,
|
913 |
+
mode: int = 0,
|
914 |
+
start: float = None,
|
915 |
+
end: float = None,
|
916 |
+
probability_threshold: float = 0.5,
|
917 |
+
eots: int = 1,
|
918 |
+
max_token_per_seg: int = 20,
|
919 |
+
exact_token: bool = False,
|
920 |
+
case_sensitive: bool = False,
|
921 |
+
verbose: bool = False,
|
922 |
+
initial_prompt: str = None,
|
923 |
+
suppress_tokens: Union[str, List[int]] = '-1',
|
924 |
+
demucs: Union[bool, torch.nn.Module] = False,
|
925 |
+
demucs_options: dict = None,
|
926 |
+
only_voice_freq: bool = False,
|
927 |
+
) -> Union[List[Segment], List[dict]]:
|
928 |
+
"""
|
929 |
+
Locate when specific words are spoken in ``audio`` without fully transcribing.
|
930 |
+
|
931 |
+
This is useful for quickly finding at what time specific words or phrases are spoken in an audio file. Since it
|
932 |
+
does not need to transcribe the audio to approximate the time, it is significantly faster than transcribing and then
|
933 |
+
locating the word in the transcript.
|
934 |
+
|
935 |
+
It can also transcribe a few seconds around the approximated time to find out what was said around those words or
|
936 |
+
confirm if the word was even spoken near that time.
|
937 |
+
|
938 |
+
Parameters
|
939 |
+
----------
|
940 |
+
model : whisper.model.Whisper
|
941 |
+
An instance of Whisper ASR model.
|
942 |
+
audio : str or numpy.ndarray or torch.Tensor or bytes
|
943 |
+
Path/URL to the audio file, the audio waveform, or bytes of audio file.
|
944 |
+
If audio is :class:`numpy.ndarray` or :class:`torch.Tensor`, the audio must already be sampled at 16kHz.
|
945 |
+
text: str or list of int
|
946 |
+
Words/phrase or list of tokens to search for in ``audio``.
|
947 |
+
language : str
|
948 |
+
Language of the ``text``.
|
949 |
+
count : int, default 1, meaning stop search after 1 match
|
950 |
+
Number of matches to find. Use 0 to look for all.
|
951 |
+
duration_window : float or tuple of (float, float), default 3.0, same as (3.0, 3.0)
|
952 |
+
Seconds before and after the end-timestamp approximation to transcribe (modes 0 and 1).
|
953 |
+
If tuple pair of values, then the 1st value will be seconds before the end and 2nd value will be seconds after.
|
954 |
+
mode : int, default 0
|
955 |
+
Mode of search.
|
956 |
+
2, Approximates the end timestamp of ``text`` in the audio. This mode does not confirm whether ``text`` is
|
957 |
+
spoken at the timestamp
|
958 |
+
1, Completes mode 2 then transcribes audio within ``duration_window`` to confirm whether `text` is a match at
|
959 |
+
the approximated timestamp by checking if ``text`` at that ``duration_window`` is within
|
960 |
+
``probability_threshold`` or by matching the string content of ``text`` with the transcribed text at the
|
961 |
+
``duration_window``.
|
962 |
+
0, Completes mode 1 then adds word timestamps to the transcriptions of each match.
|
963 |
+
Modes from fastest to slowest: 2, 1, 0
|
964 |
+
start : float, optional, meaning it starts from 0s
|
965 |
+
Seconds into the audio to start searching for ``text``.
|
966 |
+
end : float, optional
|
967 |
+
Seconds into the audio to stop searching for ``text``.
|
968 |
+
probability_threshold : float, default 0.5
|
969 |
+
Minimum probability of each token in ``text`` for it to be considered a match.
|
970 |
+
eots : int, default 1
|
971 |
+
Number of EOTs to reach before stopping transcription in mode 1. When the transcription reaches an EOT, it usually
|
972 |
+
means the end of the segment or audio. Once ``text`` is found in the ``duration_window``, the transcription
|
973 |
+
will stop immediately upon reaching an EOT.
|
974 |
+
max_token_per_seg : int, default 20
|
975 |
+
Maximum number of tokens to transcribe in the ``duration_window`` before stopping.
|
976 |
+
exact_token : bool, default False
|
977 |
+
Whether to find a match based on the exact tokens that make up ``text``.
|
978 |
+
case_sensitive : bool, default False
|
979 |
+
Whether to consider the case of ``text`` when matching in string content.
|
980 |
+
verbose : bool or None, default False
|
981 |
+
Whether to display the text being decoded to the console.
|
982 |
+
Displays all the details if ``True``. Displays progressbar if ``False``. Displays nothing if ``None``.
|
983 |
+
initial_prompt : str, optional
|
984 |
+
Text to provide as a prompt for the first window. This can be used to provide, or
|
985 |
+
"prompt-engineer" a context for transcription, e.g. custom vocabularies or proper nouns
|
986 |
+
to make it more likely to predict those words correctly.
|
987 |
+
suppress_tokens : str or list of int, default '-1', meaning suppress special characters except common punctuations
|
988 |
+
List of tokens to suppress.
|
989 |
+
demucs : bool or torch.nn.Module, default False
|
990 |
+
Whether to preprocess ``audio`` with Demucs to isolate vocals / remove noise. Set ``demucs`` to an instance of
|
991 |
+
a Demucs model to avoid reloading the model for each run.
|
992 |
+
Demucs must be installed to use. Official repo, https://github.com/facebookresearch/demucs.
|
993 |
+
demucs_options : dict, optional
|
994 |
+
Options to use for :func:`stable_whisper.audio.demucs_audio`.
|
995 |
+
only_voice_freq : bool, default False
|
996 |
+
Whether to only use sound between 200 - 5000 Hz, where the majority of human speech lies.
|
997 |
+
|
998 |
+
Returns
|
999 |
+
-------
|
1000 |
+
stable_whisper.result.Segment or list of dict or list of float
|
1001 |
+
Mode 0, list of instances of :class:`stable_whisper.result.Segment`.
|
1002 |
+
Mode 1, list of dictionaries with end timestamp approximation of matches and transcribed neighboring words.
|
1003 |
+
Mode 2, list of timestamps in seconds for each end timestamp approximation.
|
1004 |
+
|
1005 |
+
Notes
|
1006 |
+
-----
|
1007 |
+
For ``text``, the case and spacing matter since 'on', ' on', ' On' are different tokens, therefore choose the one that
|
1008 |
+
best suits the context (e.g. ' On' to look for it at the beginning of a sentence).
|
1009 |
+
|
1010 |
+
Use a sufficiently large first value for ``duration_window``, i.e. a value greater than the time it takes to speak ``text``.
|
1011 |
+
|
1012 |
+
If ``exact_token = False`` and the string content matches, then ``probability_threshold`` is not used.
|
1013 |
+
|
1014 |
+
Examples
|
1015 |
+
--------
|
1016 |
+
>>> import stable_whisper
|
1017 |
+
>>> model = stable_whisper.load_model('base')
|
1018 |
+
>>> matches = model.locate('audio.mp3', 'are', 'English', verbose=True)
|
1019 |
+
|
1020 |
+
Some words can sound the same but have different spellings. To increase the chance of finding such words, use
|
1021 |
+
``initial_prompt``.
|
1022 |
+
|
1023 |
+
>>> matches = model.locate('audio.mp3', ' Nickie', 'English', verbose=True, initial_prompt='Nickie')
|
1024 |
+
"""
|
1025 |
+
from whisper.timing import median_filter
|
1026 |
+
from whisper.decoding import DecodingTask, DecodingOptions, SuppressTokens
|
1027 |
+
from .timing import split_word_tokens
|
1028 |
+
|
1029 |
+
sample_padding = int(N_FFT // 2) + 1
|
1030 |
+
sec_per_emb = model.dims.n_audio_ctx / CHUNK_LENGTH
|
1031 |
+
CHUNK_SAMPLES = round(CHUNK_LENGTH * SAMPLE_RATE)
|
1032 |
+
if isinstance(duration_window, (float, int)):
|
1033 |
+
duration_window = [duration_window] * 2
|
1034 |
+
window_sum = sum(duration_window)
|
1035 |
+
assert CHUNK_SAMPLES > window_sum, \
|
1036 |
+
f'Sum of [duration_window] must be less than {CHUNK_SAMPLES}, got {window_sum}'
|
1037 |
+
adjusted_chunk_size = CHUNK_SAMPLES - round(duration_window[0]*SAMPLE_RATE)
|
1038 |
+
if initial_prompt:
|
1039 |
+
initial_prompt = ' ' + initial_prompt.strip()
|
1040 |
+
task = DecodingTask(model, DecodingOptions(
|
1041 |
+
language=language, prompt=initial_prompt, suppress_tokens=suppress_tokens, without_timestamps=True,
|
1042 |
+
))
|
1043 |
+
tokenizer = task.tokenizer
|
1044 |
+
initial_tokens = list(task.initial_tokens)
|
1045 |
+
text_tokens, text = (tokenizer.encode(text), text) if isinstance(text, str) else (text, tokenizer.decode(text))
|
1046 |
+
if not exact_token and not case_sensitive:
|
1047 |
+
text = text.lower()
|
1048 |
+
|
1049 |
+
tk_suppress_masks = [
|
1050 |
+
[i for i in fil.suppress_tokens if i < tokenizer.eot]
|
1051 |
+
for fil in task.logit_filters if isinstance(fil, SuppressTokens)
|
1052 |
+
]
|
1053 |
+
|
1054 |
+
audio = prep_audio(
|
1055 |
+
audio,
|
1056 |
+
demucs=demucs,
|
1057 |
+
demucs_options=demucs_options,
|
1058 |
+
only_voice_freq=only_voice_freq,
|
1059 |
+
verbose=verbose
|
1060 |
+
)
|
1061 |
+
prev_target_end = None
|
1062 |
+
found = 0
|
1063 |
+
if end:
|
1064 |
+
audio = audio[:round(end * SAMPLE_RATE)]
|
1065 |
+
seek_sample = round(start * SAMPLE_RATE) if start else 0
|
1066 |
+
total_samples = audio.shape[-1]
|
1067 |
+
|
1068 |
+
def _locate():
|
1069 |
+
nonlocal seek_sample, found
|
1070 |
+
seek = round(seek_sample / SAMPLE_RATE, 3)
|
1071 |
+
audio_segment = audio[seek_sample: seek_sample + CHUNK_SAMPLES]
|
1072 |
+
mel_segment = log_mel_spectrogram(audio_segment, model.dims.n_mels, padding=sample_padding)
|
1073 |
+
mel_segment = pad_or_trim(mel_segment, N_FRAMES).to(device=model.device)
|
1074 |
+
|
1075 |
+
QKs = [None] * model.dims.n_text_layer
|
1076 |
+
hooks = [
|
1077 |
+
block.cross_attn.register_forward_hook(
|
1078 |
+
lambda _, ins, outs, index=i: QKs.__setitem__(index, outs[-1])
|
1079 |
+
)
|
1080 |
+
for i, block in enumerate(model.decoder.blocks)
|
1081 |
+
]
|
1082 |
+
tokens = torch.tensor([initial_tokens + text_tokens]).to(model.device)
|
1083 |
+
with torch.no_grad():
|
1084 |
+
audio_features = model.encoder(mel_segment.unsqueeze(0))
|
1085 |
+
model.decoder(tokens, audio_features)
|
1086 |
+
|
1087 |
+
for hook in hooks:
|
1088 |
+
hook.remove()
|
1089 |
+
|
1090 |
+
weights = torch.cat([QKs[_l][:, _h] for _l, _h in model.alignment_heads.indices().T], dim=0)
|
1091 |
+
weights = weights.softmax(dim=-1)
|
1092 |
+
std, mean = torch.std_mean(weights, dim=-2, keepdim=True, unbiased=False)
|
1093 |
+
weights = (weights - mean) / std
|
1094 |
+
weights = median_filter(weights, 7)
|
1095 |
+
|
1096 |
+
matrix = weights.mean(axis=0)
|
1097 |
+
target_end = round((matrix[-1].argmax()/sec_per_emb).item(), 3)
|
1098 |
+
found_msg = f'"{text}" ending at ~{format_timestamp(target_end+seek)}' if verbose else ''
|
1099 |
+
|
1100 |
+
if mode == 2:
|
1101 |
+
if found_msg:
|
1102 |
+
safe_print('Unconfirmed:' + found_msg)
|
1103 |
+
nonlocal prev_target_end
|
1104 |
+
found += 1
|
1105 |
+
if (
|
1106 |
+
(seek_sample + CHUNK_SAMPLES >= total_samples) or
|
1107 |
+
(count and found >= count) or
|
1108 |
+
(prev_target_end == target_end)
|
1109 |
+
):
|
1110 |
+
seek_sample = total_samples
|
1111 |
+
else:
|
1112 |
+
seek_sample += round(target_end * SAMPLE_RATE)
|
1113 |
+
prev_target_end = target_end
|
1114 |
+
return dict(tokens=[], target_end=target_end+seek)
|
1115 |
+
|
1116 |
+
curr_start = round(max(target_end - duration_window[0], 0.), 3)
|
1117 |
+
curr_end = round(target_end + duration_window[1], 3)
|
1118 |
+
start_frame = round(curr_start * FRAMES_PER_SECOND)
|
1119 |
+
end_frame = round(curr_end * FRAMES_PER_SECOND)
|
1120 |
+
mel_segment_section = pad_or_trim(mel_segment[..., start_frame:end_frame], N_FRAMES)
|
1121 |
+
temp_tokens = torch.tensor([initial_tokens]).to(model.device)
|
1122 |
+
|
1123 |
+
predictions = []
|
1124 |
+
|
1125 |
+
target_token_idx = 0
|
1126 |
+
not_end = True
|
1127 |
+
found_target = False
|
1128 |
+
curr_eots = 0
|
1129 |
+
temp_audio_features = model.encoder(mel_segment_section.unsqueeze(0))
|
1130 |
+
tokens_to_decode = []
|
1131 |
+
replace_found_tokens = []
|
1132 |
+
infer_tokens = [temp_tokens[0]]
|
1133 |
+
kv_cache, hooks = model.install_kv_cache_hooks()
|
1134 |
+
while not_end:
|
1135 |
+
with torch.no_grad():
|
1136 |
+
logits = model.decoder(temp_tokens, temp_audio_features, kv_cache=kv_cache)[0, -1, :tokenizer.eot+1]
|
1137 |
+
for tks in tk_suppress_masks:
|
1138 |
+
logits[tks] = -np.inf
|
1139 |
+
sorted_logits_idxs = logits.sort(dim=-1).indices[-2:]
|
1140 |
+
best_token = sorted_logits_idxs[-1]
|
1141 |
+
best_non_eot_token = sorted_logits_idxs[-2] if best_token == tokenizer.eot else best_token
|
1142 |
+
|
1143 |
+
logits = logits[:tokenizer.eot].softmax(dim=-1)
|
1144 |
+
if found_target:
|
1145 |
+
target_word_prob = is_match = None
|
1146 |
+
else:
|
1147 |
+
if exact_token:
|
1148 |
+
is_match = False
|
1149 |
+
else:
|
1150 |
+
tokens_to_decode.append(best_non_eot_token)
|
1151 |
+
temp_text = tokenizer.decode(tokens_to_decode)
|
1152 |
+
if not case_sensitive:
|
1153 |
+
temp_text = temp_text.lower()
|
1154 |
+
if is_match := temp_text.endswith(text):
|
1155 |
+
tokens_to_decode = []
|
1156 |
+
target_word_prob = logits[text_tokens[target_token_idx]].item()
|
1157 |
+
if (
|
1158 |
+
target_word_prob is not None and
|
1159 |
+
(
|
1160 |
+
target_word_prob >= probability_threshold or
|
1161 |
+
best_non_eot_token == text_tokens[target_token_idx] or
|
1162 |
+
is_match
|
1163 |
+
)
|
1164 |
+
):
|
1165 |
+
if is_match:
|
1166 |
+
best_token = best_non_eot_token
|
1167 |
+
token_prob = logits[best_token].item()
|
1168 |
+
found_target = True
|
1169 |
+
else:
|
1170 |
+
best_token[None] = text_tokens[target_token_idx]
|
1171 |
+
if len(replace_found_tokens) or best_non_eot_token != text_tokens[target_token_idx]:
|
1172 |
+
replace_found_tokens.append(best_non_eot_token)
|
1173 |
+
target_token_idx += 1
|
1174 |
+
if target_token_idx == len(text_tokens):
|
1175 |
+
found_target = True
|
1176 |
+
token_prob = target_word_prob
|
1177 |
+
if found_target:
|
1178 |
+
found += 1
|
1179 |
+
curr_eots = 0
|
1180 |
+
else:
|
1181 |
+
if not found_target:
|
1182 |
+
if len(replace_found_tokens):
|
1183 |
+
temp_tokens = torch.cat(infer_tokens)[None]
|
1184 |
+
temp_tokens = torch.cat(
|
1185 |
+
[temp_tokens[..., :-len(replace_found_tokens)],
|
1186 |
+
torch.stack(replace_found_tokens)[None]]
|
1187 |
+
)
|
1188 |
+
replace_found_tokens = []
|
1189 |
+
kv_cache.clear()
|
1190 |
+
target_token_idx = 0
|
1191 |
+
if best_token == tokenizer.eot:
|
1192 |
+
if curr_eots >= eots or found_target:
|
1193 |
+
not_end = False
|
1194 |
+
else:
|
1195 |
+
curr_eots += 1
|
1196 |
+
best_token = best_non_eot_token
|
1197 |
+
else:
|
1198 |
+
curr_eots = 0
|
1199 |
+
token_prob = None if best_token == tokenizer.eot else logits[best_token].item()
|
1200 |
+
|
1201 |
+
predictions.append(dict(token=best_token.item(), prob=token_prob))
|
1202 |
+
if len(predictions) > max_token_per_seg:
|
1203 |
+
not_end = False
|
1204 |
+
if not_end:
|
1205 |
+
infer_tokens.append(best_token[None])
|
1206 |
+
temp_tokens = best_token[None, None]
|
1207 |
+
kv_cache.clear()
|
1208 |
+
for hook in hooks:
|
1209 |
+
hook.remove()
|
1210 |
+
segment = None
|
1211 |
+
|
1212 |
+
if found_target:
|
1213 |
+
if found_msg:
|
1214 |
+
safe_print('Confirmed: ' + found_msg, tqdm_pbar.write)
|
1215 |
+
final_tokens = [p['token'] for p in predictions]
|
1216 |
+
if mode == 1:
|
1217 |
+
_, (ws, wts), _ = split_word_tokens([dict(tokens=final_tokens)], tokenizer)
|
1218 |
+
final_token_probs = [p['prob'] for p in predictions]
|
1219 |
+
wps = [float(np.mean([final_token_probs.pop(0) for _ in wt])) for wt in wts]
|
1220 |
+
words = [dict(word=w, tokens=wt, probability=wp) for w, wt, wp in zip(ws, wts, wps)]
|
1221 |
+
final_end = target_end+seek
|
1222 |
+
near_text = "".join(ws)
|
1223 |
+
segment = dict(end=final_end, text=text, duration_window_text=near_text, duration_window_word=words)
|
1224 |
+
if verbose:
|
1225 |
+
safe_print(f'Duration Window: "{near_text}"\n', tqdm_pbar.write)
|
1226 |
+
seek_sample += round(curr_end * SAMPLE_RATE)
|
1227 |
+
else:
|
1228 |
+
|
1229 |
+
segment = dict(
|
1230 |
+
seek=0,
|
1231 |
+
tokens=final_tokens
|
1232 |
+
)
|
1233 |
+
|
1234 |
+
add_word_timestamps_stable(
|
1235 |
+
segments=[segment],
|
1236 |
+
model=model,
|
1237 |
+
tokenizer=tokenizer,
|
1238 |
+
mel=mel_segment,
|
1239 |
+
num_samples=round(curr_end*SAMPLE_RATE),
|
1240 |
+
gap_padding=None
|
1241 |
+
)
|
1242 |
+
segment = Segment(0, 0, '', words=segment['words'])
|
1243 |
+
segment.update_seg_with_words()
|
1244 |
+
seek_sample += round(segment.words[-1].end * SAMPLE_RATE)
|
1245 |
+
segment.offset_time(seek)
|
1246 |
+
segment.seek = curr_start
|
1247 |
+
if verbose:
|
1248 |
+
safe_print(segment.to_display_str(), tqdm_pbar.write)
|
1249 |
+
|
1250 |
+
else:
|
1251 |
+
seek_sample += adjusted_chunk_size if audio_segment.shape[-1] == CHUNK_SAMPLES else audio_segment.shape[-1]
|
1252 |
+
|
1253 |
+
return segment
|
1254 |
+
|
1255 |
+
total_duration = round(total_samples / SAMPLE_RATE, 2)
|
1256 |
+
matches = []
|
1257 |
+
with tqdm(total=total_duration, unit='sec', disable=verbose is None, desc='Locate') as tqdm_pbar:
|
1258 |
+
while seek_sample < total_samples and (not count or found < count):
|
1259 |
+
if match := _locate():
|
1260 |
+
matches.append(match)
|
1261 |
+
tqdm_pbar.update(round(seek_sample/SAMPLE_RATE, 2) - tqdm_pbar.n)
|
1262 |
+
tqdm_pbar.update(tqdm_pbar.total - tqdm_pbar.n)
|
1263 |
+
if verbose and not matches:
|
1264 |
+
safe_print(f'Failed to locate "{text}".')
|
1265 |
+
return matches
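A hedged sketch of calling the function above in its two extreme modes, assuming it is bound to the model as ``model.locate()`` (mirroring the docstring examples); the audio path and search phrase are placeholders.

import stable_whisper

model = stable_whisper.load_model('base')
# mode 2: fastest, returns unconfirmed end-timestamp approximations for every occurrence
approximations = model.locate('audio.mp3', ' hello', 'English', mode=2, count=0)
# mode 0 (default): confirms each match and attaches word-level timestamps
matches = model.locate('audio.mp3', ' hello', 'English', verbose=True)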
|
stable_whisper/audio.py
ADDED
@@ -0,0 +1,288 @@
1 |
+
import subprocess
|
2 |
+
import warnings
|
3 |
+
import ffmpeg
|
4 |
+
import torch
|
5 |
+
import torchaudio
|
6 |
+
import numpy as np
|
7 |
+
from typing import Union, Optional
|
8 |
+
|
9 |
+
from whisper.audio import SAMPLE_RATE
|
10 |
+
|
11 |
+
|
12 |
+
def is_ytdlp_available():
|
13 |
+
return subprocess.run('yt-dlp -h', shell=True, capture_output=True).returncode == 0
|
14 |
+
|
15 |
+
|
16 |
+
def _load_file(file: Union[str, bytes], verbose: bool = False, only_ffmpeg: bool = False):
|
17 |
+
if isinstance(file, str) and '://' in file:
|
18 |
+
if is_ytdlp_available():
|
19 |
+
verbosity = ' -q' if verbose is None else (' --progress' if verbose else ' --progress -q')
|
20 |
+
p = subprocess.run(
|
21 |
+
f'yt-dlp "{file}" -f ba/w -I 1{verbosity} -o -',
|
22 |
+
shell=True,
|
23 |
+
stdout=subprocess.PIPE
|
24 |
+
)
|
25 |
+
if len(p.stdout) == 0:
|
26 |
+
raise RuntimeError(f'Failed to download media from "{file}" with yt-dlp')
|
27 |
+
return p.stdout
|
28 |
+
else:
|
29 |
+
warnings.warn('URL detected but yt-dlp not available. '
|
30 |
+
'To handle a greater variety of URLs (i.e. non-direct links), '
|
31 |
+
'install yt-dlp, \'pip install yt-dlp\' (repo: https://github.com/yt-dlp/yt-dlp).')
|
32 |
+
if not only_ffmpeg:
|
33 |
+
if is_ytdlp_available():
|
34 |
+
verbosity = ' -q' if verbose is None else (' --progress' if verbose else ' --progress -q')
|
35 |
+
p = subprocess.run(
|
36 |
+
f'yt-dlp "{file}" -f ba/w -I 1{verbosity} -o -',
|
37 |
+
shell=True,
|
38 |
+
stdout=subprocess.PIPE
|
39 |
+
)
|
40 |
+
if p.returncode != 0 or len(p.stdout) == 0:
|
41 |
+
raise RuntimeError(f'Failed to download media from "{file}" with yt-dlp')
|
42 |
+
return p.stdout
|
43 |
+
else:
|
44 |
+
warnings.warn('URL detected but yt-dlp not available. '
|
45 |
+
'To handle a greater variety of URLs (i.e. non-direct links), '
|
46 |
+
'install yt-dlp, \'pip install yt-dlp\' (repo: https://github.com/yt-dlp/yt-dlp).')
|
47 |
+
return file
|
48 |
+
|
49 |
+
|
50 |
+
# modified version of whisper.audio.load_audio
|
51 |
+
def load_audio(file: Union[str, bytes], sr: int = SAMPLE_RATE, verbose: bool = True, only_ffmpeg: bool = False):
|
52 |
+
"""
|
53 |
+
Open an audio file and read it as a mono waveform, resampling as necessary.
|
54 |
+
|
55 |
+
Parameters
|
56 |
+
----------
|
57 |
+
file : str or bytes
|
58 |
+
The audio file to open, bytes of file, or URL to audio/video.
|
59 |
+
sr : int, default ``whisper.audio.SAMPLE_RATE``
|
60 |
+
The sample rate to resample the audio if necessary.
|
61 |
+
verbose : bool, default True
|
62 |
+
Whether to print yt-dlp log.
|
63 |
+
only_ffmpeg : bool, default False
|
64 |
+
Whether to use only FFmpeg (instead of yt-dlp) for URLs.
|
65 |
+
|
66 |
+
Returns
|
67 |
+
-------
|
68 |
+
numpy.ndarray
|
69 |
+
An array containing the audio waveform in float32.
|
70 |
+
"""
|
71 |
+
file = _load_file(file, verbose=verbose, only_ffmpeg=only_ffmpeg)
|
72 |
+
if isinstance(file, bytes):
|
73 |
+
inp, file = file, 'pipe:'
|
74 |
+
else:
|
75 |
+
inp = None
|
76 |
+
try:
|
77 |
+
# This launches a subprocess to decode audio while down-mixing and resampling as necessary.
|
78 |
+
# Requires the ffmpeg CLI and `ffmpeg-python` package to be installed.
|
79 |
+
out, _ = (
|
80 |
+
ffmpeg.input(file, threads=0)
|
81 |
+
.output("-", format="s16le", acodec="pcm_s16le", ac=1, ar=sr)
|
82 |
+
.run(cmd=["ffmpeg", "-nostdin"], capture_stdout=True, capture_stderr=True, input=inp)
|
83 |
+
)
|
84 |
+
except ffmpeg.Error as e:
|
85 |
+
raise RuntimeError(f"Failed to load audio: {e.stderr.decode()}") from e
|
86 |
+
|
87 |
+
return np.frombuffer(out, np.int16).flatten().astype(np.float32) / 32768.0
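A small usage sketch of the loader above; the file name is a placeholder and the FFmpeg CLI must be on PATH.

from stable_whisper.audio import load_audio

waveform = load_audio('audio.mp3')      # float32 numpy array, mono, 16 kHz by default
print(waveform.shape, waveform.dtype)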
|
88 |
+
|
89 |
+
|
90 |
+
def voice_freq_filter(wf: (torch.Tensor, np.ndarray), sr: int,
|
91 |
+
upper_freq: int = None,
|
92 |
+
lower_freq: int = None) -> torch.Tensor:
|
93 |
+
if isinstance(wf, np.ndarray):
|
94 |
+
wf = torch.from_numpy(wf)
|
95 |
+
if upper_freq is None:
|
96 |
+
upper_freq = 5000
|
97 |
+
if lower_freq is None:
|
98 |
+
lower_freq = 200
|
99 |
+
assert upper_freq > lower_freq, f'upper_freq {upper_freq} must be greater than lower_freq {lower_freq}'
|
100 |
+
return torchaudio.functional.highpass_biquad(torchaudio.functional.lowpass_biquad(wf, sr, upper_freq),
|
101 |
+
sr,
|
102 |
+
lower_freq)
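The filter above is a low-pass at ``upper_freq`` followed by a high-pass at ``lower_freq``, i.e. roughly a 200-5000 Hz band-pass by default. A minimal sketch with synthetic input:

import torch
from stable_whisper.audio import voice_freq_filter

wf = torch.randn(16000)                        # one second of noise at 16 kHz
band_limited = voice_freq_filter(wf, 16000)    # keeps roughly the 200-5000 Hz band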
|
103 |
+
|
104 |
+
|
105 |
+
def is_demucs_available():
|
106 |
+
from importlib.util import find_spec
|
107 |
+
if find_spec('demucs') is None:
|
108 |
+
raise ModuleNotFoundError("Please install Demucs; "
|
109 |
+
"'pip install -U demucs' or "
|
110 |
+
"'pip install -U git+https://github.com/facebookresearch/demucs#egg=demucs'; "
|
111 |
+
"Official Demucs repo: https://github.com/facebookresearch/demucs")
|
112 |
+
|
113 |
+
|
114 |
+
def load_demucs_model():
|
115 |
+
is_demucs_available()
|
116 |
+
from demucs.pretrained import get_model_from_args
|
117 |
+
return get_model_from_args(type('args', (object,), dict(name='htdemucs', repo=None))).cpu().eval()
|
118 |
+
|
119 |
+
|
120 |
+
def demucs_audio(audio: (torch.Tensor, str),
|
121 |
+
input_sr: int = None,
|
122 |
+
output_sr: int = None,
|
123 |
+
model=None,
|
124 |
+
device=None,
|
125 |
+
verbose: bool = True,
|
126 |
+
track_name: str = None,
|
127 |
+
save_path: str = None,
|
128 |
+
**demucs_options) -> torch.Tensor:
|
129 |
+
"""
|
130 |
+
Isolates vocals / remove noise from ``audio`` with Demucs.
|
131 |
+
|
132 |
+
Official repo, https://github.com/facebookresearch/demucs.
|
133 |
+
"""
|
134 |
+
if model is None:
|
135 |
+
model = load_demucs_model()
|
136 |
+
else:
|
137 |
+
is_demucs_available()
|
138 |
+
from demucs.apply import apply_model
|
139 |
+
|
140 |
+
if track_name:
|
141 |
+
track_name = f'"{track_name}"'
|
142 |
+
|
143 |
+
if isinstance(audio, (str, bytes)):
|
144 |
+
if isinstance(audio, str) and not track_name:
|
145 |
+
track_name = f'"{audio}"'
|
146 |
+
audio = torch.from_numpy(load_audio(audio, model.samplerate))
|
147 |
+
elif input_sr != model.samplerate:
|
148 |
+
if input_sr is None:
|
149 |
+
raise ValueError('No [input_sr] specified for audio tensor.')
|
150 |
+
audio = torchaudio.functional.resample(audio,
|
151 |
+
orig_freq=input_sr,
|
152 |
+
new_freq=model.samplerate)
|
153 |
+
if not track_name:
|
154 |
+
track_name = 'audio track'
|
155 |
+
audio_dims = audio.dim()
|
156 |
+
if audio_dims == 1:
|
157 |
+
audio = audio[None, None].repeat_interleave(2, -2)
|
158 |
+
else:
|
159 |
+
if audio.shape[-2] == 1:
|
160 |
+
audio = audio.repeat_interleave(2, -2)
|
161 |
+
if audio_dims < 3:
|
162 |
+
audio = audio[None]
|
163 |
+
|
164 |
+
if 'mix' in demucs_options:
|
165 |
+
audio = demucs_options.pop('mix')
|
166 |
+
|
167 |
+
if device is None:
|
168 |
+
device = "cuda" if torch.cuda.is_available() else "cpu"
|
169 |
+
|
170 |
+
vocals_idx = model.sources.index('vocals')
|
171 |
+
if verbose:
|
172 |
+
print(f'Isolating vocals from {track_name}')
|
173 |
+
apply_kwarg = dict(
|
174 |
+
model=model,
|
175 |
+
mix=audio,
|
176 |
+
device=device,
|
177 |
+
split=True,
|
178 |
+
overlap=.25,
|
179 |
+
progress=verbose is not None,
|
180 |
+
)
|
181 |
+
apply_kwarg.update(demucs_options)
|
182 |
+
vocals = apply_model(
|
183 |
+
**apply_kwarg
|
184 |
+
)[0, vocals_idx].mean(0)
|
185 |
+
|
186 |
+
if device != 'cpu':
|
187 |
+
torch.cuda.empty_cache()
|
188 |
+
|
189 |
+
if output_sr is not None and model.samplerate != output_sr:
|
190 |
+
vocals = torchaudio.functional.resample(vocals,
|
191 |
+
orig_freq=model.samplerate,
|
192 |
+
new_freq=output_sr)
|
193 |
+
|
194 |
+
if save_path is not None:
|
195 |
+
if isinstance(save_path, str) and not save_path.lower().endswith('.wav'):
|
196 |
+
save_path += '.wav'
|
197 |
+
torchaudio.save(save_path, vocals[None], output_sr or model.samplerate)
|
198 |
+
print(f'Saved: {save_path}')
|
199 |
+
|
200 |
+
return vocals
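A hedged sketch of isolating vocals with the helper above (requires the demucs package to be installed; file names are placeholders). Reusing a preloaded model avoids reloading it on every call.

from stable_whisper.audio import load_demucs_model, demucs_audio

demucs_model = load_demucs_model()               # htdemucs, loaded once and reused
vocals = demucs_audio(
    'song.mp3',
    model=demucs_model,
    output_sr=16000,                             # resample the isolated vocals to 16 kHz
    save_path='vocals.wav',                      # also write the result to disk
)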
|
201 |
+
|
202 |
+
|
203 |
+
def get_samplerate(audiofile: (str, bytes)) -> (int, None):
|
204 |
+
import re
|
205 |
+
if isinstance(audiofile, str):
|
206 |
+
metadata = subprocess.run(f'ffmpeg -i {audiofile}', capture_output=True, shell=True).stderr.decode()
|
207 |
+
else:
|
208 |
+
p = subprocess.Popen(f'ffmpeg -i -', stderr=subprocess.PIPE, stdin=subprocess.PIPE, shell=True)
|
209 |
+
try:
|
210 |
+
p.stdin.write(audiofile)
|
211 |
+
except BrokenPipeError:
|
212 |
+
pass
|
213 |
+
finally:
|
214 |
+
metadata = p.communicate()[-1]
|
215 |
+
if metadata is not None:
|
216 |
+
metadata = metadata.decode()
|
217 |
+
sr = re.findall(r'\n.+Stream.+Audio.+\D+(\d+) Hz', metadata)
|
218 |
+
if sr:
|
219 |
+
return int(sr[0])
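For example (placeholder file name), the helper above shells out to FFmpeg and parses the reported audio-stream sample rate:

from stable_whisper.audio import get_samplerate

sr = get_samplerate('audio.mp3')    # e.g. 44100, or None if no audio stream is found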
|
220 |
+
|
221 |
+
|
222 |
+
def prep_audio(
|
223 |
+
audio: Union[str, np.ndarray, torch.Tensor, bytes],
|
224 |
+
demucs: Union[bool, torch.nn.Module] = False,
|
225 |
+
demucs_options: dict = None,
|
226 |
+
only_voice_freq: bool = False,
|
227 |
+
only_ffmpeg: bool = False,
|
228 |
+
verbose: Optional[bool] = False,
|
229 |
+
sr: int = None
|
230 |
+
) -> torch.Tensor:
|
231 |
+
"""
|
232 |
+
Converts input audio of many types into a mono waveform as a torch.Tensor.
|
233 |
+
|
234 |
+
Parameters
|
235 |
+
----------
|
236 |
+
audio : str or numpy.ndarray or torch.Tensor or bytes
|
237 |
+
Path/URL to the audio file, the audio waveform, or bytes of audio file.
|
238 |
+
If audio is :class:`numpy.ndarray` or :class:`torch.Tensor`, the audio must already be sampled at 16kHz.
|
239 |
+
demucs : bool or torch.nn.Module, default False
|
240 |
+
Whether to preprocess ``audio`` with Demucs to isolate vocals / remove noise. Set ``demucs`` to an instance of
|
241 |
+
a Demucs model to avoid reloading the model for each run.
|
242 |
+
Demucs must be installed to use. Official repo, https://github.com/facebookresearch/demucs.
|
243 |
+
demucs_options : dict, optional
|
244 |
+
Options to use for :func:`stable_whisper.audio.demucs_audio`.
|
245 |
+
only_voice_freq : bool, default False
|
246 |
+
Whether to only use sound between 200 - 5000 Hz, where the majority of human speech lies.
|
247 |
+
sr : int, default None, meaning ``whisper.audio.SAMPLE_RATE``, 16kHz
|
248 |
+
The sample rate of ``audio``.
|
249 |
+
verbose : bool, default False
|
250 |
+
Whether to print yt-dlp log.
|
251 |
+
only_ffmpeg: bool, default False
|
252 |
+
Whether to use only FFmpeg (instead of yt-dlp) for URLs.
|
253 |
+
|
254 |
+
Returns
|
255 |
+
-------
|
256 |
+
torch.Tensor
|
257 |
+
A mono waveform.
|
258 |
+
"""
|
259 |
+
if not sr:
|
260 |
+
sr = SAMPLE_RATE
|
261 |
+
if isinstance(audio, (str, bytes)):
|
262 |
+
if demucs:
|
263 |
+
demucs_kwargs = dict(
|
264 |
+
audio=audio,
|
265 |
+
output_sr=sr,
|
266 |
+
verbose=verbose,
|
267 |
+
)
|
268 |
+
demucs_kwargs.update(demucs_options or {})
|
269 |
+
audio = demucs_audio(**demucs_kwargs)
|
270 |
+
else:
|
271 |
+
audio = torch.from_numpy(load_audio(audio, sr=sr, verbose=verbose, only_ffmpeg=only_ffmpeg))
|
272 |
+
else:
|
273 |
+
if isinstance(audio, np.ndarray):
|
274 |
+
audio = torch.from_numpy(audio)
|
275 |
+
if demucs:
|
276 |
+
demucs_kwargs = dict(
|
277 |
+
audio=audio,
|
278 |
+
input_sr=sr,
|
279 |
+
output_sr=sr,
|
280 |
+
verbose=verbose,
|
281 |
+
)
|
282 |
+
demucs_kwargs.update(demucs_options or {})
|
283 |
+
audio = demucs_audio(**demucs_kwargs)
|
284 |
+
if only_voice_freq:
|
285 |
+
audio = voice_freq_filter(audio, sr)
|
286 |
+
|
287 |
+
return audio
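A short sketch of the combined pre-processing entry point above (placeholder file name):

import torch
from stable_whisper.audio import prep_audio

audio = prep_audio('audio.mp3', only_voice_freq=True)   # 16 kHz mono waveform, band-passed
assert isinstance(audio, torch.Tensor) and audio.ndim == 1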
|
288 |
+
|
stable_whisper/decode.py
ADDED
@@ -0,0 +1,109 @@
1 |
+
from typing import TYPE_CHECKING, List, Union
|
2 |
+
from dataclasses import replace
|
3 |
+
|
4 |
+
import torch
|
5 |
+
import numpy as np
|
6 |
+
|
7 |
+
from whisper.decoding import DecodingTask, DecodingOptions, DecodingResult
|
8 |
+
|
9 |
+
|
10 |
+
if TYPE_CHECKING:
|
11 |
+
from whisper.model import Whisper
|
12 |
+
|
13 |
+
|
14 |
+
def _suppress_ts(ts_logits: torch.Tensor, ts_token_mask: torch.Tensor = None):
|
15 |
+
if ts_token_mask is not None:
|
16 |
+
ts_logits[:, ts_token_mask] = -np.inf
|
17 |
+
|
18 |
+
|
19 |
+
# modified version of whisper.decoding.DecodingTask
|
20 |
+
class DecodingTaskStable(DecodingTask):
|
21 |
+
|
22 |
+
def __init__(self, *args, **kwargs):
|
23 |
+
self.ts_token_mask: torch.Tensor = kwargs.pop('ts_token_mask', None)
|
24 |
+
self.audio_features: torch.Tensor = kwargs.pop('audio_features', None)
|
25 |
+
super(DecodingTaskStable, self).__init__(*args, **kwargs)
|
26 |
+
|
27 |
+
def _get_audio_features(self, mel: torch.Tensor):
|
28 |
+
if self.audio_features is None:
|
29 |
+
audio_features = super()._get_audio_features(mel)
|
30 |
+
self.audio_features = audio_features.detach().clone()
|
31 |
+
return audio_features
|
32 |
+
return self.audio_features.clone()
|
33 |
+
|
34 |
+
# modified version of whisper.DecodingTask._main_loop
|
35 |
+
def _main_loop(self, audio_features: torch.Tensor, tokens: torch.Tensor):
|
36 |
+
n_batch = tokens.shape[0]
|
37 |
+
sum_logprobs: torch.Tensor = torch.zeros(n_batch, device=audio_features.device)
|
38 |
+
no_speech_probs = [np.nan] * n_batch
|
39 |
+
|
40 |
+
try:
|
41 |
+
for i in range(self.sample_len):
|
42 |
+
logits = self.inference.logits(tokens, audio_features)
|
43 |
+
|
44 |
+
if i == 0 and self.tokenizer.no_speech is not None: # save no_speech_probs
|
45 |
+
probs_at_sot = logits[:, self.sot_index].float().softmax(dim=-1)
|
46 |
+
no_speech_probs = probs_at_sot[:, self.tokenizer.no_speech].tolist()
|
47 |
+
|
48 |
+
# now we need to consider the logits at the last token only
|
49 |
+
logits = logits[:, -1]
|
50 |
+
|
51 |
+
# apply the logit filters, e.g. for suppressing or applying penalty to
|
52 |
+
for logit_filter in self.logit_filters:
|
53 |
+
logit_filter.apply(logits, tokens)
|
54 |
+
|
55 |
+
# suppress timestamp tokens where the audio is silent so that decoder ignores those timestamps
|
56 |
+
_suppress_ts(logits[:, self.tokenizer.timestamp_begin:], self.ts_token_mask)
|
57 |
+
|
58 |
+
logits.nan_to_num_(-np.inf)
|
59 |
+
# expand the tokens tensor with the selected next tokens
|
60 |
+
tokens, completed = self.decoder.update(tokens, logits, sum_logprobs)
|
61 |
+
|
62 |
+
if completed or tokens.shape[-1] > self.n_ctx:
|
63 |
+
break
|
64 |
+
finally:
|
65 |
+
self.inference.cleanup_caching()
|
66 |
+
|
67 |
+
return tokens, sum_logprobs, no_speech_probs
|
68 |
+
|
69 |
+
|
70 |
+
# modified version of whisper.decoding.decode
|
71 |
+
@torch.no_grad()
|
72 |
+
def decode_stable(model: "Whisper",
|
73 |
+
mel: torch.Tensor,
|
74 |
+
options: DecodingOptions = DecodingOptions(),
|
75 |
+
ts_token_mask: torch.Tensor = None,
|
76 |
+
audio_features: torch.Tensor = None,
|
77 |
+
**kwargs, ) -> \
|
78 |
+
Union[DecodingResult, List[DecodingResult], tuple]:
|
79 |
+
"""
|
80 |
+
Performs decoding of 30-second audio segment(s), provided as Mel spectrogram(s).
|
81 |
+
|
82 |
+
Parameters
|
83 |
+
----------
|
84 |
+
model : whisper.model.Whisper
|
85 |
+
An instance of Whisper ASR model.
|
86 |
+
mel : torch.Tensor,
|
87 |
+
A tensor containing the Mel spectrogram(s). ``mel.shape`` must be (80, 3000) or (*, 80, 3000).
|
88 |
+
options : whisper.decode.DecodingOptions, default whisper.decode.DecodingOptions()
|
89 |
+
A dataclass that contains all necessary options for decoding 30-second segments
|
90 |
+
ts_token_mask : torch.Tensor, optional
|
91 |
+
Mask for suppressing timestamp token(s) during decoding.
|
92 |
+
audio_features : torch.Tensor, optional
|
93 |
+
Reused ``audio_features`` from the encoder for fallback.
|
94 |
+
|
95 |
+
Returns
|
96 |
+
-------
|
97 |
+
whisper.decode.DecodingResult or list of whisper.decode.DecodingResult
|
98 |
+
The result(s) of decoding contained in ``whisper.decode.DecodingResult`` dataclass instance(s).
|
99 |
+
"""
|
100 |
+
if single := mel.ndim == 2:
|
101 |
+
mel = mel.unsqueeze(0)
|
102 |
+
|
103 |
+
if kwargs:
|
104 |
+
options = replace(options, **kwargs)
|
105 |
+
|
106 |
+
task = DecodingTaskStable(model, options, ts_token_mask=ts_token_mask, audio_features=audio_features)
|
107 |
+
result = task.run(mel)
|
108 |
+
|
109 |
+
return result[0] if single else result, task.audio_features
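A hedged sketch of calling ``decode_stable`` with a timestamp-suppression mask. The mask length of 1501 (one entry per timestamp token, 0.00-30.00 s in 0.02 s steps) is an assumption for the standard multilingual vocabulary, and the audio path is a placeholder.

import torch
import whisper
from whisper.audio import log_mel_spectrogram, pad_or_trim, N_FRAMES
from whisper.decoding import DecodingOptions
from stable_whisper.decode import decode_stable

model = whisper.load_model('base')
mel = log_mel_spectrogram(whisper.load_audio('audio.mp3'))
mel = pad_or_trim(mel, N_FRAMES).to(model.device)

ts_mask = torch.zeros(1501, dtype=torch.bool, device=model.device)
ts_mask[:250] = True    # e.g. forbid timestamps in the first ~5 seconds

result, audio_features = decode_stable(model, mel, DecodingOptions(language='en'), ts_token_mask=ts_mask)
print(result.text)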
|
stable_whisper/non_whisper.py
ADDED
@@ -0,0 +1,348 @@
1 |
+
import os
|
2 |
+
import warnings
|
3 |
+
import io
|
4 |
+
import torch
|
5 |
+
import torchaudio
|
6 |
+
import numpy as np
|
7 |
+
from typing import Union, Callable, Optional
|
8 |
+
|
9 |
+
from .audio import load_audio
|
10 |
+
from .result import WhisperResult
|
11 |
+
|
12 |
+
AUDIO_TYPES = ('str', 'byte', 'torch', 'numpy')
|
13 |
+
|
14 |
+
|
15 |
+
def transcribe_any(
|
16 |
+
inference_func: Callable,
|
17 |
+
audio: Union[str, np.ndarray, torch.Tensor, bytes],
|
18 |
+
audio_type: str = None,
|
19 |
+
input_sr: int = None,
|
20 |
+
model_sr: int = None,
|
21 |
+
inference_kwargs: dict = None,
|
22 |
+
temp_file: str = None,
|
23 |
+
verbose: Optional[bool] = False,
|
24 |
+
regroup: Union[bool, str] = True,
|
25 |
+
suppress_silence: bool = True,
|
26 |
+
suppress_word_ts: bool = True,
|
27 |
+
q_levels: int = 20,
|
28 |
+
k_size: int = 5,
|
29 |
+
demucs: bool = False,
|
30 |
+
demucs_device: str = None,
|
31 |
+
demucs_output: str = None,
|
32 |
+
demucs_options: dict = None,
|
33 |
+
vad: bool = False,
|
34 |
+
vad_threshold: float = 0.35,
|
35 |
+
vad_onnx: bool = False,
|
36 |
+
min_word_dur: float = 0.1,
|
37 |
+
nonspeech_error: float = 0.3,
|
38 |
+
use_word_position: bool = True,
|
39 |
+
only_voice_freq: bool = False,
|
40 |
+
only_ffmpeg: bool = False,
|
41 |
+
force_order: bool = False,
|
42 |
+
check_sorted: bool = True
|
43 |
+
) -> WhisperResult:
|
44 |
+
"""
|
45 |
+
Transcribe ``audio`` using any ASR system.
|
46 |
+
|
47 |
+
Parameters
|
48 |
+
----------
|
49 |
+
inference_func : Callable
|
50 |
+
Function that runs ASR when provided the [audio] and returns data in the appropriate format.
|
51 |
+
For format examples see, https://github.com/jianfch/stable-ts/blob/main/examples/non-whisper.ipynb.
|
52 |
+
audio : str or numpy.ndarray or torch.Tensor or bytes
|
53 |
+
Path/URL to the audio file, the audio waveform, or bytes of audio file.
|
54 |
+
audio_type : {'str', 'byte', 'torch', 'numpy', None}, default None, meaning same type as ``audio``
|
55 |
+
The type that ``audio`` needs to be for ``inference_func``.
|
56 |
+
'str' is a path to the file.
|
57 |
+
'byte' is bytes (used for APIs or to avoid writing any data to hard drive).
|
58 |
+
'torch' is an instance of :class:`torch.Tensor` containing the audio waveform, in float32 dtype, on CPU.
|
59 |
+
'numpy' is an instance of :class:`numpy.ndarray` containing the audio waveform, in float32 dtype.
|
60 |
+
input_sr : int, default None, meaning auto-detected if ``audio`` is ``str`` or ``bytes``
|
61 |
+
The sample rate of ``audio``.
|
62 |
+
model_sr : int, default None, meaning same sample rate as ``input_sr``
|
63 |
+
The sample rate to resample the audio into for ``inference_func``.
|
64 |
+
inference_kwargs : dict, optional
|
65 |
+
Dictionary of arguments to pass into ``inference_func``.
|
66 |
+
temp_file : str, default './_temp_stable-ts_audio_.wav'
|
67 |
+
Temporary path for the preprocessed audio when ``audio_type = 'str'``.
|
68 |
+
verbose : bool, default False
|
69 |
+
Whether to display all the details during transcription. If ``False``, displays progressbar. If ``None``, does
|
70 |
+
not display anything.
|
71 |
+
regroup: str or bool, default True
|
72 |
+
String representation of a custom regrouping algorithm or ``True`` to use the default algorithm 'da'. Only
|
73 |
+
applies if ``word_timestamps = False``.
|
74 |
+
suppress_silence : bool, default True
|
75 |
+
Whether to enable timestamps adjustments based on the detected silence.
|
76 |
+
suppress_word_ts : bool, default True
|
77 |
+
Whether to adjust word timestamps based on the detected silence. Only enabled if ``suppress_silence = True``.
|
78 |
+
q_levels : int, default 20
|
79 |
+
Quantization levels for generating timestamp suppression mask; ignored if ``vad = true``.
|
80 |
+
Acts as a threshold for marking sound as silent.
|
81 |
+
Fewer levels will increase the threshold of volume at which to mark a sound as silent.
|
82 |
+
k_size : int, default 5
|
83 |
+
Kernel size for avg-pooling waveform to generate timestamp suppression mask; ignored if ``vad = true``.
|
84 |
+
Recommend 5 or 3; higher sizes will reduce detection of silence.
|
85 |
+
demucs : bool or torch.nn.Module, default False
|
86 |
+
Whether to preprocess ``audio`` with Demucs to isolate vocals / remove noise. Set ``demucs`` to an instance of
|
87 |
+
a Demucs model to avoid reloading the model for each run.
|
88 |
+
Demucs must be installed to use. Official repo, https://github.com/facebookresearch/demucs.
|
89 |
+
demucs_output : str, optional
|
90 |
+
Path to save the vocals isolated by Demucs as WAV file. Ignored if ``demucs = False``.
|
91 |
+
Demucs must be installed to use. Official repo, https://github.com/facebookresearch/demucs.
|
92 |
+
demucs_options : dict, optional
|
93 |
+
Options to use for :func:`stable_whisper.audio.demucs_audio`.
|
94 |
+
demucs_device : str, default None, meaning 'cuda' if cuda is available with ``torch`` else 'cpu'
|
95 |
+
Device to use for demucs.
|
96 |
+
vad : bool, default False
|
97 |
+
Whether to use Silero VAD to generate timestamp suppression mask.
|
98 |
+
Silero VAD requires PyTorch 1.12.0+. Official repo, https://github.com/snakers4/silero-vad.
|
99 |
+
vad_threshold : float, default 0.35
|
100 |
+
Threshold for detecting speech with Silero VAD. Low threshold reduces false positives for silence detection.
|
101 |
+
vad_onnx : bool, default False
|
102 |
+
Whether to use ONNX for Silero VAD.
|
103 |
+
min_word_dur : float, default 0.1
|
104 |
+
Shortest duration each word is allowed to reach for silence suppression.
|
105 |
+
nonspeech_error : float, default 0.3
|
106 |
+
Relative error of non-speech sections that appear in between a word for silence suppression.
|
107 |
+
use_word_position : bool, default True
|
108 |
+
Whether to use position of the word in its segment to determine whether to keep end or start timestamps if
|
109 |
+
adjustments are required. If it is the first word, keep end. Else if it is the last word, keep the start.
|
110 |
+
only_voice_freq : bool, default False
|
111 |
+
Whether to only use sound between 200 - 5000 Hz, where the majority of human speech lies.
|
112 |
+
only_ffmpeg : bool, default False
|
113 |
+
Whether to use only FFmpeg (instead of yt-dlp) for URLs
|
114 |
+
force_order : bool, default False
|
115 |
+
Whether to use adjacent timestamps to replace timestamps that are out of order. Use this parameter only if
|
116 |
+
the words/segments returned by ``inference_func`` are expected to be in chronological order.
|
117 |
+
check_sorted : bool, default True
|
118 |
+
Whether to raise an error when timestamps returned by ``inference_func`` are not in ascending order.
|
119 |
+
|
120 |
+
Returns
|
121 |
+
-------
|
122 |
+
stable_whisper.result.WhisperResult
|
123 |
+
All timestamps, words, probabilities, and other data from the transcription of ``audio``.
|
124 |
+
|
125 |
+
Notes
|
126 |
+
-----
|
127 |
+
For ``audio_type = 'str'``:
|
128 |
+
If ``audio`` is a file and no audio preprocessing is set, ``audio`` will be directly passed into
|
129 |
+
``inference_func``.
|
130 |
+
If audio preprocessing is ``demucs`` or ``only_voice_freq``, the processed audio will be encoded into
|
131 |
+
``temp_file`` and then passed into ``inference_func``.
|
132 |
+
|
133 |
+
For ``audio_type = 'byte'``:
|
134 |
+
If ``audio`` is a file, the bytes of the file will be passed into ``inference_func``.
|
135 |
+
If ``audio`` is :class:`torch.Tensor` or :class:`numpy.ndarray`, the bytes of the ``audio`` will be encoded
|
136 |
+
into WAV format then passed into ``inference_func``.
|
137 |
+
|
138 |
+
Resampling is only performed on ``audio`` when ``model_sr`` does not match the sample rate of the ``audio`` before
|
139 |
+
passing into ``inference_func`` due to ``input_sr`` not matching ``model_sr``, or sample rate changes due to
|
140 |
+
audio preprocessing from ``demucs = True``.
|
141 |
+
"""
|
142 |
+
if demucs_options is None:
|
143 |
+
demucs_options = {}
|
144 |
+
if demucs_output:
|
145 |
+
if 'save_path' not in demucs_options:
|
146 |
+
demucs_options['save_path'] = demucs_output
|
147 |
+
warnings.warn('``demucs_output`` is deprecated. Use ``demucs_options`` with ``save_path`` instead. '
|
148 |
+
'E.g. demucs_options=dict(save_path="demucs_output.mp3")',
|
149 |
+
DeprecationWarning, stacklevel=2)
|
150 |
+
if demucs_device:
|
151 |
+
if 'device' not in demucs_options:
|
152 |
+
demucs_options['device'] = demucs_device
|
153 |
+
warnings.warn('``demucs_device`` is deprecated. Use ``demucs_options`` with ``device`` instead. '
|
154 |
+
'E.g. demucs_options=dict(device="cpu")',
|
155 |
+
DeprecationWarning, stacklevel=2)
|
156 |
+
|
157 |
+
if audio_type is not None and (audio_type := audio_type.lower()) not in AUDIO_TYPES:
|
158 |
+
raise NotImplementedError(f'[audio_type]={audio_type} is not supported. Types: {AUDIO_TYPES}')
|
159 |
+
|
160 |
+
if audio_type is None:
|
161 |
+
if isinstance(audio, str):
|
162 |
+
audio_type = 'str'
|
163 |
+
elif isinstance(audio, bytes):
|
164 |
+
audio_type = 'byte'
|
165 |
+
elif isinstance(audio, torch.Tensor):
|
166 |
+
audio_type = 'torch'
|
167 |
+
elif isinstance(audio, np.ndarray):
|
168 |
+
audio_type = 'numpy'
|
169 |
+
else:
|
170 |
+
raise TypeError(f'{type(audio)} is not supported for [audio].')
|
171 |
+
|
172 |
+
if (
|
173 |
+
input_sr is None and
|
174 |
+
isinstance(audio, (np.ndarray, torch.Tensor)) and
|
175 |
+
(demucs or only_voice_freq or suppress_silence or model_sr)
|
176 |
+
):
|
177 |
+
raise ValueError('[input_sr] is required when [audio] is a PyTorch tensor or NumPy array.')
|
178 |
+
|
179 |
+
if (
|
180 |
+
model_sr is None and
|
181 |
+
isinstance(audio, (str, bytes)) and
|
182 |
+
audio_type in ('torch', 'numpy')
|
183 |
+
):
|
184 |
+
raise ValueError('[model_sr] is required when [audio_type] is "torch" or "numpy".')
|
185 |
+
|
186 |
+
if isinstance(audio, str):
|
187 |
+
from .audio import _load_file
|
188 |
+
audio = _load_file(audio, verbose=verbose, only_ffmpeg=only_ffmpeg)
|
189 |
+
|
190 |
+
if inference_kwargs is None:
|
191 |
+
inference_kwargs = {}
|
192 |
+
|
193 |
+
temp_file = os.path.abspath(temp_file or './_temp_stable-ts_audio_.wav')
|
194 |
+
temp_audio_file = None
|
195 |
+
|
196 |
+
curr_sr = input_sr
|
197 |
+
|
198 |
+
if demucs:
|
199 |
+
if demucs is True:
|
200 |
+
from .audio import load_demucs_model
|
201 |
+
demucs_model = load_demucs_model()
|
202 |
+
else:
|
203 |
+
demucs_model = demucs
|
204 |
+
demucs = True
|
205 |
+
else:
|
206 |
+
demucs_model = None
|
207 |
+
|
208 |
+
def get_input_sr():
|
209 |
+
nonlocal input_sr
|
210 |
+
if not input_sr and isinstance(audio, (str, bytes)):
|
211 |
+
from .audio import get_samplerate
|
212 |
+
input_sr = get_samplerate(audio)
|
213 |
+
return input_sr
|
214 |
+
|
215 |
+
if only_voice_freq:
|
216 |
+
from .audio import voice_freq_filter
|
217 |
+
if demucs_model is None:
|
218 |
+
curr_sr = model_sr or get_input_sr()
|
219 |
+
else:
|
220 |
+
curr_sr = demucs_model.samplerate
|
221 |
+
if model_sr is None:
|
222 |
+
model_sr = get_input_sr()
|
223 |
+
audio = load_audio(audio, sr=curr_sr, verbose=verbose, only_ffmpeg=only_ffmpeg)
|
224 |
+
audio = voice_freq_filter(audio, curr_sr)
|
225 |
+
|
226 |
+
if demucs:
|
227 |
+
from .audio import demucs_audio
|
228 |
+
if demucs_device is None:
|
229 |
+
demucs_device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
|
230 |
+
demucs_kwargs = dict(
|
231 |
+
audio=audio,
|
232 |
+
input_sr=curr_sr,
|
233 |
+
model=demucs_model,
|
234 |
+
save_path=demucs_output,
|
235 |
+
device=demucs_device,
|
236 |
+
verbose=verbose
|
237 |
+
)
|
238 |
+
demucs_kwargs.update(demucs_options or {})
|
239 |
+
audio = demucs_audio(
|
240 |
+
**demucs_kwargs
|
241 |
+
)
|
242 |
+
curr_sr = demucs_model.samplerate
|
243 |
+
if demucs_output and audio_type == 'str':
|
244 |
+
audio = demucs_output
|
245 |
+
|
246 |
+
final_audio = audio
|
247 |
+
|
248 |
+
if model_sr is not None:
|
249 |
+
|
250 |
+
if curr_sr is None:
|
251 |
+
curr_sr = get_input_sr()
|
252 |
+
|
253 |
+
if curr_sr != model_sr:
|
254 |
+
if isinstance(final_audio, (str, bytes)):
|
255 |
+
final_audio = load_audio(
|
256 |
+
final_audio,
|
257 |
+
sr=model_sr,
|
258 |
+
verbose=verbose,
|
259 |
+
only_ffmpeg=only_ffmpeg
|
260 |
+
)
|
261 |
+
else:
|
262 |
+
if isinstance(final_audio, np.ndarray):
|
263 |
+
final_audio = torch.from_numpy(final_audio)
|
264 |
+
if isinstance(final_audio, torch.Tensor):
|
265 |
+
final_audio = torchaudio.functional.resample(
|
266 |
+
final_audio,
|
267 |
+
orig_freq=curr_sr,
|
268 |
+
new_freq=model_sr,
|
269 |
+
resampling_method="kaiser_window"
|
270 |
+
)
|
271 |
+
|
272 |
+
if audio_type in ('torch', 'numpy'):
|
273 |
+
|
274 |
+
if isinstance(final_audio, (str, bytes)):
|
275 |
+
final_audio = load_audio(
|
276 |
+
final_audio,
|
277 |
+
sr=model_sr,
|
278 |
+
verbose=verbose,
|
279 |
+
only_ffmpeg=only_ffmpeg
|
280 |
+
)
|
281 |
+
|
282 |
+
else:
|
283 |
+
if audio_type == 'torch':
|
284 |
+
if isinstance(final_audio, np.ndarray):
|
285 |
+
final_audio = torch.from_numpy(final_audio)
|
286 |
+
elif audio_type == 'numpy' and isinstance(final_audio, torch.Tensor):
|
287 |
+
final_audio = final_audio.cpu().numpy()
|
288 |
+
|
289 |
+
elif audio_type == 'str':
|
290 |
+
|
291 |
+
if isinstance(final_audio, (torch.Tensor, np.ndarray)):
|
292 |
+
if isinstance(final_audio, np.ndarray):
|
293 |
+
final_audio = torch.from_numpy(final_audio)
|
294 |
+
if final_audio.ndim < 2:
|
295 |
+
final_audio = final_audio[None]
|
296 |
+
torchaudio.save(temp_file, final_audio, model_sr)
|
297 |
+
final_audio = temp_audio_file = temp_file
|
298 |
+
|
299 |
+
elif isinstance(final_audio, bytes):
|
300 |
+
with open(temp_file, 'wb') as f:
|
301 |
+
f.write(final_audio)
|
302 |
+
final_audio = temp_audio_file = temp_file
|
303 |
+
|
304 |
+
else: # audio_type == 'byte'
|
305 |
+
|
306 |
+
if isinstance(final_audio, (torch.Tensor, np.ndarray)):
|
307 |
+
if isinstance(final_audio, np.ndarray):
|
308 |
+
final_audio = torch.from_numpy(final_audio)
|
309 |
+
if final_audio.ndim < 2:
|
310 |
+
final_audio = final_audio[None]
|
311 |
+
with io.BytesIO() as f:
|
312 |
+
torchaudio.save(f, final_audio, model_sr, format="wav")
|
313 |
+
f.seek(0)
|
314 |
+
final_audio = f.read()
|
315 |
+
|
316 |
+
elif isinstance(final_audio, str):
|
317 |
+
with open(final_audio, 'rb') as f:
|
318 |
+
final_audio = f.read()
|
319 |
+
|
320 |
+
inference_kwargs['audio'] = final_audio
|
321 |
+
|
322 |
+
result = None
|
323 |
+
try:
|
324 |
+
result = inference_func(**inference_kwargs)
|
325 |
+
if not isinstance(result, WhisperResult):
|
326 |
+
result = WhisperResult(result, force_order=force_order, check_sorted=check_sorted)
|
327 |
+
if suppress_silence:
|
328 |
+
result.adjust_by_silence(
|
329 |
+
audio, vad,
|
330 |
+
vad_onnx=vad_onnx, vad_threshold=vad_threshold,
|
331 |
+
q_levels=q_levels, k_size=k_size,
|
332 |
+
sample_rate=curr_sr, min_word_dur=min_word_dur,
|
333 |
+
word_level=suppress_word_ts, verbose=True,
|
334 |
+
nonspeech_error=nonspeech_error,
|
335 |
+
use_word_position=use_word_position
|
336 |
+
)
|
337 |
+
|
338 |
+
if result.has_words and regroup:
|
339 |
+
result.regroup(regroup)
|
340 |
+
|
341 |
+
finally:
|
342 |
+
if temp_audio_file is not None:
|
343 |
+
try:
|
344 |
+
os.unlink(temp_audio_file)
|
345 |
+
except Exception as e:
|
346 |
+
warnings.warn(f'Failed to remove temporary audio file {temp_audio_file}. {e}')
|
347 |
+
|
348 |
+
return result
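A hedged sketch of plugging a non-Whisper ASR into ``transcribe_any``. The ``inference`` function below is a stand-in that returns hard-coded segments in one of the accepted plain-dict formats (see the non-whisper notebook linked in the docstring), and the file names are placeholders.

import stable_whisper

def inference(audio, **kwargs):
    # a real implementation would run an ASR model on ``audio`` here
    return [
        dict(start=0.0, end=1.5, text=' Hello there.'),
        dict(start=1.5, end=3.0, text=' General Kenobi.'),
    ]

result = stable_whisper.transcribe_any(inference, 'audio.mp3')
result.to_srt_vtt('audio.srt')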
|
stable_whisper/quantization.py
ADDED
@@ -0,0 +1,40 @@
1 |
+
import torch
|
2 |
+
from torch import nn
|
3 |
+
from whisper.model import Linear, Conv1d, LayerNorm, Whisper
|
4 |
+
|
5 |
+
|
6 |
+
def replace_modules(model: nn.Module, only_linear: bool = False):
|
7 |
+
"""
|
8 |
+
Replace ``Linear``/``Conv1d``/``LayerNorm`` from :class:`whisper.model` with equivalent module in
|
9 |
+
:class:`torch.nn`.
|
10 |
+
"""
|
11 |
+
for m in model.__dict__.get('_modules', []):
|
12 |
+
module = model.__getattr__(m)
|
13 |
+
update = True
|
14 |
+
if isinstance(module, Linear):
|
15 |
+
model.__setattr__(m, nn.Linear(module.in_features, module.out_features,
|
16 |
+
bias=module.bias is not None))
|
17 |
+
elif not only_linear and isinstance(module, Conv1d):
|
18 |
+
model.__setattr__(m, nn.Conv1d(module.in_channels, module.out_channels,
|
19 |
+
kernel_size=module.kernel_size,
|
20 |
+
stride=module.stride,
|
21 |
+
padding=module.padding,
|
22 |
+
bias=module.bias is not None))
|
23 |
+
elif not only_linear and isinstance(module, LayerNorm):
|
24 |
+
model.__setattr__(m, nn.LayerNorm(module.normalized_shape[0]))
|
25 |
+
else:
|
26 |
+
update = False
|
27 |
+
replace_modules(module)
|
28 |
+
|
29 |
+
if update:
|
30 |
+
model.__getattr__(m).load_state_dict(module.state_dict())
|
31 |
+
|
32 |
+
|
33 |
+
def ptdq_linear(model: "Whisper"):
|
34 |
+
"""
|
35 |
+
Apply Dynamic Quantization to instance of :class:`whisper.model.Whisper`.
|
36 |
+
"""
|
37 |
+
model.cpu()
|
38 |
+
replace_modules(model, only_linear=True)
|
39 |
+
torch.quantization.quantize_dynamic(model, {nn.Linear}, dtype=torch.qint8, inplace=True)
|
40 |
+
setattr(model, 'dq', True)
|
stable_whisper/result.py
ADDED
@@ -0,0 +1,2281 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
import warnings
|
2 |
+
import re
|
3 |
+
import torch
|
4 |
+
import numpy as np
|
5 |
+
from typing import Union, List, Tuple, Optional, Callable
|
6 |
+
from dataclasses import dataclass
|
7 |
+
from copy import deepcopy
|
8 |
+
from itertools import chain
|
9 |
+
|
10 |
+
from .stabilization import suppress_silence, get_vad_silence_func, mask2timing, wav2mask
|
11 |
+
from .text_output import *
|
12 |
+
from .utils import str_to_valid_type, format_timestamp, UnsortedException
|
13 |
+
|
14 |
+
|
15 |
+
__all__ = ['WhisperResult', 'Segment']
|
16 |
+
|
17 |
+
|
18 |
+
def _combine_attr(obj: object, other_obj: object, attr: str):
|
19 |
+
if (val := getattr(obj, attr)) is not None:
|
20 |
+
other_val = getattr(other_obj, attr)
|
21 |
+
if isinstance(val, list):
|
22 |
+
if other_val is None:
|
23 |
+
setattr(obj, attr, None)
|
24 |
+
else:
|
25 |
+
val.extend(other_val)
|
26 |
+
else:
|
27 |
+
new_val = None if other_val is None else ((val + other_val) / 2)
|
28 |
+
setattr(obj, attr, new_val)
|
29 |
+
|
30 |
+
|
31 |
+
def _increment_attr(obj: object, attr: str, val: Union[int, float]):
|
32 |
+
if (curr_val := getattr(obj, attr, None)) is not None:
|
33 |
+
setattr(obj, attr, curr_val + val)
|
34 |
+
|
35 |
+
|
36 |
+
@dataclass
|
37 |
+
class WordTiming:
|
38 |
+
word: str
|
39 |
+
start: float
|
40 |
+
end: float
|
41 |
+
probability: float = None
|
42 |
+
tokens: List[int] = None
|
43 |
+
left_locked: bool = False
|
44 |
+
right_locked: bool = False
|
45 |
+
segment_id: Optional[int] = None
|
46 |
+
id: Optional[int] = None
|
47 |
+
|
48 |
+
def __len__(self):
|
49 |
+
return len(self.word)
|
50 |
+
|
51 |
+
def __add__(self, other: 'WordTiming'):
|
52 |
+
self_copy = deepcopy(self)
|
53 |
+
|
54 |
+
self_copy.start = min(self_copy.start, other.start)
|
55 |
+
self_copy.end = max(other.end, self_copy.end)
|
56 |
+
self_copy.word += other.word
|
57 |
+
self_copy.left_locked = self_copy.left_locked or other.left_locked
|
58 |
+
self_copy.right_locked = self_copy.right_locked or other.right_locked
|
59 |
+
_combine_attr(self_copy, other, 'probability')
|
60 |
+
_combine_attr(self_copy, other, 'tokens')
|
61 |
+
|
62 |
+
return self_copy
|
63 |
+
|
64 |
+
def __deepcopy__(self, memo=None):
|
65 |
+
return self.copy()
|
66 |
+
|
67 |
+
def copy(self):
|
68 |
+
return WordTiming(
|
69 |
+
word=self.word,
|
70 |
+
start=self.start,
|
71 |
+
end=self.end,
|
72 |
+
probability=self.probability,
|
73 |
+
tokens=None if self.tokens is None else self.tokens.copy(),
|
74 |
+
left_locked=self.left_locked,
|
75 |
+
right_locked=self.right_locked,
|
76 |
+
segment_id=self.segment_id,
|
77 |
+
id=self.id
|
78 |
+
)
|
79 |
+
|
80 |
+
@property
|
81 |
+
def duration(self):
|
82 |
+
return round(self.end - self.start, 3)
|
83 |
+
|
84 |
+
def round_all_timestamps(self):
|
85 |
+
self.start = round(self.start, 3)
|
86 |
+
self.end = round(self.end, 3)
|
87 |
+
|
88 |
+
def offset_time(self, offset_seconds: float):
|
89 |
+
self.start = round(self.start + offset_seconds, 3)
|
90 |
+
self.end = round(self.end + offset_seconds, 3)
|
91 |
+
|
92 |
+
def to_dict(self):
|
93 |
+
dict_ = deepcopy(self).__dict__
|
94 |
+
dict_.pop('left_locked')
|
95 |
+
dict_.pop('right_locked')
|
96 |
+
return dict_
|
97 |
+
|
98 |
+
def lock_left(self):
|
99 |
+
self.left_locked = True
|
100 |
+
|
101 |
+
def lock_right(self):
|
102 |
+
self.right_locked = True
|
103 |
+
|
104 |
+
def lock_both(self):
|
105 |
+
self.lock_left()
|
106 |
+
self.lock_right()
|
107 |
+
|
108 |
+
def unlock_both(self):
|
109 |
+
self.left_locked = False
|
110 |
+
self.right_locked = False
|
111 |
+
|
112 |
+
def suppress_silence(self,
|
113 |
+
silent_starts: np.ndarray,
|
114 |
+
silent_ends: np.ndarray,
|
115 |
+
min_word_dur: float = 0.1,
|
116 |
+
nonspeech_error: float = 0.3,
|
117 |
+
keep_end: Optional[bool] = True):
|
118 |
+
suppress_silence(self, silent_starts, silent_ends, min_word_dur, nonspeech_error, keep_end)
|
119 |
+
return self
|
120 |
+
|
121 |
+
def rescale_time(self, scale_factor: float):
|
122 |
+
self.start = round(self.start * scale_factor, 3)
|
123 |
+
self.end = round(self.end * scale_factor, 3)
|
124 |
+
|
125 |
+
def clamp_max(self, max_dur: float, clip_start: bool = False, verbose: bool = False):
|
126 |
+
if self.duration > max_dur:
|
127 |
+
if clip_start:
|
128 |
+
new_start = round(self.end - max_dur, 3)
|
129 |
+
if verbose:
|
130 |
+
print(f'Start: {self.start} -> {new_start}\nEnd: {self.end}\nText:"{self.word}"\n')
|
131 |
+
self.start = new_start
|
132 |
+
|
133 |
+
else:
|
134 |
+
new_end = round(self.start + max_dur, 3)
|
135 |
+
if verbose:
|
136 |
+
print(f'Start: {self.start}\nEnd: {self.end} -> {new_end}\nText:"{self.word}"\n')
|
137 |
+
self.end = new_end
|
138 |
+
|
139 |
+
def set_segment(self, segment: 'Segment'):
|
140 |
+
self._segment = segment
|
141 |
+
|
142 |
+
def get_segment(self) -> Union['Segment', None]:
|
143 |
+
"""
|
144 |
+
Return instance of :class:`stable_whisper.result.Segment` that this instance is a part of.
|
145 |
+
"""
|
146 |
+
return getattr(self, '_segment', None)
|
147 |
+
|
148 |
+
|
149 |
+
def _words_by_lock(words: List[WordTiming], only_text: bool = False, include_single: bool = False):
|
150 |
+
"""
|
151 |
+
Return a nested list of words such that each sublist contains words that are locked together.
|
152 |
+
"""
|
153 |
+
all_words = []
|
154 |
+
for word in words:
|
155 |
+
if len(all_words) == 0 or not (all_words[-1][-1].right_locked or word.left_locked):
|
156 |
+
all_words.append([word])
|
157 |
+
else:
|
158 |
+
all_words[-1].append(word)
|
159 |
+
if only_text:
|
160 |
+
all_words = list(map(lambda ws: list(map(lambda w: w.word, ws)), all_words))
|
161 |
+
if not include_single:
|
162 |
+
all_words = [ws for ws in all_words if len(ws) > 1]
|
163 |
+
return all_words
|
164 |
+
|
165 |
+
|
166 |
+
@dataclass
|
167 |
+
class Segment:
|
168 |
+
start: float
|
169 |
+
end: float
|
170 |
+
text: str
|
171 |
+
seek: float = None
|
172 |
+
tokens: List[int] = None
|
173 |
+
temperature: float = None
|
174 |
+
avg_logprob: float = None
|
175 |
+
compression_ratio: float = None
|
176 |
+
no_speech_prob: float = None
|
177 |
+
words: Union[List[WordTiming], List[dict]] = None
|
178 |
+
ori_has_words: bool = None
|
179 |
+
id: int = None
|
180 |
+
|
181 |
+
def __getitem__(self, index: int) -> WordTiming:
|
182 |
+
if self.words is None:
|
183 |
+
raise ValueError('segment contains no words')
|
184 |
+
return self.words[index]
|
185 |
+
|
186 |
+
def __delitem__(self, index: int):
|
187 |
+
if self.words is None:
|
188 |
+
raise ValueError('segment contains no words')
|
189 |
+
del self.words[index]
|
190 |
+
self.reassign_ids()
|
191 |
+
self.update_seg_with_words()
|
192 |
+
|
193 |
+
def __deepcopy__(self, memo=None):
|
194 |
+
return self.copy()
|
195 |
+
|
196 |
+
def copy(self, new_words: Optional[List[WordTiming]] = None):
|
197 |
+
if new_words is None:
|
198 |
+
words = None if self.words is None else [w.copy() for w in self.words]
|
199 |
+
else:
|
200 |
+
words = [w.copy() for w in new_words]
|
201 |
+
|
202 |
+
new_seg = Segment(
|
203 |
+
start=self.start,
|
204 |
+
end=self.end,
|
205 |
+
text=self.text,
|
206 |
+
seek=self.seek,
|
207 |
+
tokens=self.tokens,
|
208 |
+
temperature=self.temperature,
|
209 |
+
avg_logprob=self.avg_logprob,
|
210 |
+
compression_ratio=self.compression_ratio,
|
211 |
+
no_speech_prob=self.no_speech_prob,
|
212 |
+
words=words,
|
213 |
+
id=self.id
|
214 |
+
)
|
215 |
+
new_seg.update_seg_with_words()
|
216 |
+
return new_seg
|
217 |
+
|
218 |
+
def to_display_str(self, only_segment: bool = False):
|
219 |
+
line = f'[{format_timestamp(self.start)} --> {format_timestamp(self.end)}] "{self.text}"'
|
220 |
+
if self.has_words and not only_segment:
|
221 |
+
line += '\n' + '\n'.join(
|
222 |
+
f"-[{format_timestamp(w.start)}] -> [{format_timestamp(w.end)}] \"{w.word}\"" for w in self.words
|
223 |
+
) + '\n'
|
224 |
+
return line
|
225 |
+
|
226 |
+
@property
|
227 |
+
def has_words(self):
|
228 |
+
return bool(self.words)
|
229 |
+
|
230 |
+
@property
|
231 |
+
def duration(self):
|
232 |
+
return self.end - self.start
|
233 |
+
|
234 |
+
def word_count(self):
|
235 |
+
if self.has_words:
|
236 |
+
return len(self.words)
|
237 |
+
return -1
|
238 |
+
|
239 |
+
def char_count(self):
|
240 |
+
if self.has_words:
|
241 |
+
return sum(len(w) for w in self.words)
|
242 |
+
return len(self.text)
|
243 |
+
|
244 |
+
def __post_init__(self):
|
245 |
+
if self.has_words:
|
246 |
+
self.words: List[WordTiming] = \
|
247 |
+
[WordTiming(**word) if isinstance(word, dict) else word for word in self.words]
|
248 |
+
for w in self.words:
|
249 |
+
w.set_segment(self)
|
250 |
+
if self.ori_has_words is None:
|
251 |
+
self.ori_has_words = self.has_words
|
252 |
+
self.round_all_timestamps()
|
253 |
+
|
254 |
+
def __add__(self, other: 'Segment'):
|
255 |
+
self_copy = deepcopy(self)
|
256 |
+
|
257 |
+
self_copy.start = min(self_copy.start, other.start)
|
258 |
+
self_copy.end = max(other.end, self_copy.end)
|
259 |
+
self_copy.text += other.text
|
260 |
+
|
261 |
+
_combine_attr(self_copy, other, 'tokens')
|
262 |
+
_combine_attr(self_copy, other, 'temperature')
|
263 |
+
_combine_attr(self_copy, other, 'avg_logprob')
|
264 |
+
_combine_attr(self_copy, other, 'compression_ratio')
|
265 |
+
_combine_attr(self_copy, other, 'no_speech_prob')
|
266 |
+
if self_copy.has_words:
|
267 |
+
if other.has_words:
|
268 |
+
self_copy.words.extend(other.words)
|
269 |
+
else:
|
270 |
+
self_copy.words = None
|
271 |
+
|
272 |
+
return self_copy
|
273 |
+
|
274 |
+
def _word_operations(self, operation: str, *args, **kwargs):
|
275 |
+
if self.has_words:
|
276 |
+
for w in self.words:
|
277 |
+
getattr(w, operation)(*args, **kwargs)
|
278 |
+
|
279 |
+
def round_all_timestamps(self):
|
280 |
+
self.start = round(self.start, 3)
|
281 |
+
self.end = round(self.end, 3)
|
282 |
+
if self.has_words:
|
283 |
+
for word in self.words:
|
284 |
+
word.round_all_timestamps()
|
285 |
+
|
286 |
+
def offset_time(self, offset_seconds: float):
|
287 |
+
self.start = round(self.start + offset_seconds, 3)
|
288 |
+
self.end = round(self.end + offset_seconds, 3)
|
289 |
+
_increment_attr(self, 'seek', offset_seconds)
|
290 |
+
self._word_operations('offset_time', offset_seconds)
|
291 |
+
|
292 |
+
def add_words(self, index0: int, index1: int, inplace: bool = False):
|
293 |
+
if self.has_words:
|
294 |
+
new_word = self.words[index0] + self.words[index1]
|
295 |
+
if inplace:
|
296 |
+
i0, i1 = sorted([index0, index1])
|
297 |
+
self.words[i0] = new_word
|
298 |
+
del self.words[i1]
|
299 |
+
return new_word
|
300 |
+
|
301 |
+
def rescale_time(self, scale_factor: float):
|
302 |
+
self.start = round(self.start * scale_factor, 3)
|
303 |
+
self.end = round(self.end * scale_factor, 3)
|
304 |
+
if self.seek is not None:
|
305 |
+
self.seek = round(self.seek * scale_factor, 3)
|
306 |
+
self._word_operations('rescale_time', scale_factor)
|
307 |
+
self.update_seg_with_words()
|
308 |
+
|
309 |
+
def apply_min_dur(self, min_dur: float, inplace: bool = False):
|
310 |
+
"""
|
311 |
+
Merge any word with adjacent word if its duration is less than ``min_dur``.
|
312 |
+
"""
|
313 |
+
segment = self if inplace else deepcopy(self)
|
314 |
+
if not self.has_words:
|
315 |
+
return segment
|
316 |
+
max_i = len(segment.words) - 1
|
317 |
+
if max_i == 0:
|
318 |
+
return segment
|
319 |
+
for i in reversed(range(len(segment.words))):
|
320 |
+
if max_i == 0:
|
321 |
+
break
|
322 |
+
if segment.words[i].duration < min_dur:
|
323 |
+
if i == max_i:
|
324 |
+
segment.add_words(i-1, i, inplace=True)
|
325 |
+
elif i == 0:
|
326 |
+
segment.add_words(i, i+1, inplace=True)
|
327 |
+
else:
|
328 |
+
if segment.words[i+1].duration < segment.words[i-1].duration:
|
329 |
+
segment.add_words(i-1, i, inplace=True)
|
330 |
+
else:
|
331 |
+
segment.add_words(i, i+1, inplace=True)
|
332 |
+
max_i -= 1
|
333 |
+
return segment
|
334 |
+
|
335 |
+
def _to_reverse_text(
|
336 |
+
self,
|
337 |
+
prepend_punctuations: str = None,
|
338 |
+
append_punctuations: str = None
|
339 |
+
):
|
340 |
+
"""
|
341 |
+
Return a copy with words reversed order per segment.
|
342 |
+
"""
|
343 |
+
if prepend_punctuations is None:
|
344 |
+
prepend_punctuations = "\"'“¿([{-"
|
345 |
+
if prepend_punctuations and ' ' not in prepend_punctuations:
|
346 |
+
prepend_punctuations += ' '
|
347 |
+
if append_punctuations is None:
|
348 |
+
append_punctuations = "\"'.。,,!!??::”)]}、"
|
349 |
+
self_copy = deepcopy(self)
|
350 |
+
has_prepend = bool(prepend_punctuations)
|
351 |
+
has_append = bool(append_punctuations)
|
352 |
+
if has_prepend or has_append:
|
353 |
+
word_objs = (
|
354 |
+
self_copy.words
|
355 |
+
if self_copy.has_words else
|
356 |
+
[WordTiming(w, 0, 1, 0) for w in self_copy.text.split(' ')]
|
357 |
+
)
|
358 |
+
for word in word_objs:
|
359 |
+
new_append = ''
|
360 |
+
if has_prepend:
|
361 |
+
for _ in range(len(word)):
|
362 |
+
char = word.word[0]
|
363 |
+
if char in prepend_punctuations:
|
364 |
+
new_append += char
|
365 |
+
word.word = word.word[1:]
|
366 |
+
else:
|
367 |
+
break
|
368 |
+
new_prepend = ''
|
369 |
+
if has_append:
|
370 |
+
for _ in range(len(word)):
|
371 |
+
char = word.word[-1]
|
372 |
+
if char in append_punctuations:
|
373 |
+
new_prepend += char
|
374 |
+
word.word = word.word[:-1]
|
375 |
+
else:
|
376 |
+
break
|
377 |
+
word.word = f'{new_prepend}{word.word}{new_append[::-1]}'
|
378 |
+
self_copy.text = ''.join(w.word for w in reversed(word_objs))
|
379 |
+
|
380 |
+
return self_copy
|
381 |
+
|
382 |
+
def to_dict(self, reverse_text: Union[bool, tuple] = False):
|
383 |
+
if reverse_text:
|
384 |
+
seg_dict = (
|
385 |
+
(self._to_reverse_text(*reverse_text)
|
386 |
+
if isinstance(reverse_text, tuple) else
|
387 |
+
self._to_reverse_text()).__dict__
|
388 |
+
)
|
389 |
+
else:
|
390 |
+
seg_dict = deepcopy(self).__dict__
|
391 |
+
seg_dict.pop('ori_has_words')
|
392 |
+
if self.has_words:
|
393 |
+
seg_dict['words'] = [w.to_dict() for w in seg_dict['words']]
|
394 |
+
elif self.ori_has_words:
|
395 |
+
seg_dict['words'] = []
|
396 |
+
else:
|
397 |
+
seg_dict.pop('words')
|
398 |
+
if self.id is None:
|
399 |
+
seg_dict.pop('id')
|
400 |
+
if reverse_text:
|
401 |
+
seg_dict['reversed_text'] = True
|
402 |
+
return seg_dict
|
403 |
+
|
404 |
+
def words_by_lock(self, only_text: bool = True, include_single: bool = False):
|
405 |
+
return _words_by_lock(self.words, only_text=only_text, include_single=include_single)
|
406 |
+
|
407 |
+
@property
|
408 |
+
def left_locked(self):
|
409 |
+
if self.has_words:
|
410 |
+
return self.words[0].left_locked
|
411 |
+
return False
|
412 |
+
|
413 |
+
@property
|
414 |
+
def right_locked(self):
|
415 |
+
if self.has_words:
|
416 |
+
return self.words[-1].right_locked
|
417 |
+
return False
|
418 |
+
|
419 |
+
def lock_left(self):
|
420 |
+
if self.has_words:
|
421 |
+
self.words[0].lock_left()
|
422 |
+
|
423 |
+
def lock_right(self):
|
424 |
+
if self.has_words:
|
425 |
+
self.words[-1].lock_right()
|
426 |
+
|
427 |
+
def lock_both(self):
|
428 |
+
self.lock_left()
|
429 |
+
self.lock_right()
|
430 |
+
|
431 |
+
def unlock_all_words(self):
|
432 |
+
self._word_operations('unlock_both')
|
433 |
+
|
434 |
+
def reassign_ids(self):
|
435 |
+
if self.has_words:
|
436 |
+
for i, w in enumerate(self.words):
|
437 |
+
w.segment_id = self.id
|
438 |
+
w.id = i
|
439 |
+
|
440 |
+
def update_seg_with_words(self):
|
441 |
+
if self.has_words:
|
442 |
+
self.start = self.words[0].start
|
443 |
+
self.end = self.words[-1].end
|
444 |
+
self.text = ''.join(w.word for w in self.words)
|
445 |
+
self.tokens = (
|
446 |
+
None
|
447 |
+
if any(w.tokens is None for w in self.words) else
|
448 |
+
[t for w in self.words for t in w.tokens]
|
449 |
+
)
|
450 |
+
for w in self.words:
|
451 |
+
w.set_segment(self)
|
452 |
+
|
453 |
+
def suppress_silence(self,
|
454 |
+
silent_starts: np.ndarray,
|
455 |
+
silent_ends: np.ndarray,
|
456 |
+
min_word_dur: float = 0.1,
|
457 |
+
word_level: bool = True,
|
458 |
+
nonspeech_error: float = 0.3,
|
459 |
+
use_word_position: bool = True):
|
460 |
+
if self.has_words:
|
461 |
+
words = self.words if word_level or len(self.words) == 1 else [self.words[0], self.words[-1]]
|
462 |
+
for i, w in enumerate(words, 1):
|
463 |
+
if use_word_position:
|
464 |
+
keep_end = True if i == 1 else (False if i == len(words) else None)
|
465 |
+
else:
|
466 |
+
keep_end = None
|
467 |
+
w.suppress_silence(silent_starts, silent_ends, min_word_dur, nonspeech_error, keep_end)
|
468 |
+
self.update_seg_with_words()
|
469 |
+
else:
|
470 |
+
suppress_silence(self,
|
471 |
+
silent_starts,
|
472 |
+
silent_ends,
|
473 |
+
min_word_dur,
|
474 |
+
nonspeech_error)
|
475 |
+
|
476 |
+
return self
|
477 |
+
|
478 |
+
def get_locked_indices(self):
|
479 |
+
locked_indices = [i
|
480 |
+
for i, (left, right) in enumerate(zip(self.words[1:], self.words[:-1]))
|
481 |
+
if left.left_locked or right.right_locked]
|
482 |
+
return locked_indices
|
483 |
+
|
484 |
+
def get_gaps(self, as_ndarray=False):
|
485 |
+
if self.has_words:
|
486 |
+
s_ts = np.array([w.start for w in self.words])
|
487 |
+
e_ts = np.array([w.end for w in self.words])
|
488 |
+
gap = s_ts[1:] - e_ts[:-1]
|
489 |
+
return gap if as_ndarray else gap.tolist()
|
490 |
+
return []
|
491 |
+
|
492 |
+
def get_gap_indices(self, max_gap: float = 0.1): # for splitting
|
493 |
+
if not self.has_words or len(self.words) < 2:
|
494 |
+
return []
|
495 |
+
if max_gap is None:
|
496 |
+
max_gap = 0
|
497 |
+
indices = (self.get_gaps(True) > max_gap).nonzero()[0].tolist()
|
498 |
+
return sorted(set(indices) - set(self.get_locked_indices()))
|
499 |
+
|
500 |
+
def get_punctuation_indices(self, punctuation: Union[List[str], List[Tuple[str, str]], str]): # for splitting
|
501 |
+
if not self.has_words or len(self.words) < 2:
|
502 |
+
return []
|
503 |
+
if isinstance(punctuation, str):
|
504 |
+
punctuation = [punctuation]
|
505 |
+
indices = []
|
506 |
+
for p in punctuation:
|
507 |
+
if isinstance(p, str):
|
508 |
+
for i, s in enumerate(self.words[:-1]):
|
509 |
+
if s.word.endswith(p):
|
510 |
+
indices.append(i)
|
511 |
+
elif i != 0 and s.word.startswith(p):
|
512 |
+
indices.append(i-1)
|
513 |
+
else:
|
514 |
+
ending, beginning = p
|
515 |
+
indices.extend([i for i, (w0, w1) in enumerate(zip(self.words[:-1], self.words[1:]))
|
516 |
+
if w0.word.endswith(ending) and w1.word.startswith(beginning)])
|
517 |
+
|
518 |
+
return sorted(set(indices) - set(self.get_locked_indices()))
|
519 |
+
|
520 |
+
def get_length_indices(self, max_chars: int = None, max_words: int = None, even_split: bool = True,
|
521 |
+
include_lock: bool = False):
|
522 |
+
# for splitting
|
523 |
+
if not self.has_words or (max_chars is None and max_words is None):
|
524 |
+
return []
|
525 |
+
assert max_chars != 0 and max_words != 0, \
|
526 |
+
f'max_chars and max_words must be greater 0, but got {max_chars} and {max_words}'
|
527 |
+
if len(self.words) < 2:
|
528 |
+
return []
|
529 |
+
indices = []
|
530 |
+
if even_split:
|
531 |
+
char_count = -1 if max_chars is None else sum(map(len, self.words))
|
532 |
+
word_count = -1 if max_words is None else len(self.words)
|
533 |
+
exceed_chars = max_chars is not None and char_count > max_chars
|
534 |
+
exceed_words = max_words is not None and word_count > max_words
|
535 |
+
if exceed_chars:
|
536 |
+
splits = np.ceil(char_count / max_chars)
|
537 |
+
chars_per_split = char_count / splits
|
538 |
+
cum_char_count = np.cumsum([len(w.word) for w in self.words[:-1]])
|
539 |
+
indices = [
|
540 |
+
(np.abs(cum_char_count-(i*chars_per_split))).argmin()
|
541 |
+
for i in range(1, int(splits))
|
542 |
+
]
|
543 |
+
if max_words is not None:
|
544 |
+
exceed_words = any(j-i+1 > max_words for i, j in zip([0]+indices, indices+[len(self.words)]))
|
545 |
+
|
546 |
+
if exceed_words:
|
547 |
+
splits = np.ceil(word_count / max_words)
|
548 |
+
words_per_split = word_count / splits
|
549 |
+
cum_word_count = np.array(range(1, len(self.words)+1))
|
550 |
+
indices = [
|
551 |
+
np.abs(cum_word_count-(i*words_per_split)).argmin()
|
552 |
+
for i in range(1, int(splits))
|
553 |
+
]
|
554 |
+
|
555 |
+
else:
|
556 |
+
curr_words = 0
|
557 |
+
curr_chars = 0
|
558 |
+
locked_indices = []
|
559 |
+
if include_lock:
|
560 |
+
locked_indices = self.get_locked_indices()
|
561 |
+
for i, word in enumerate(self.words):
|
562 |
+
curr_words += 1
|
563 |
+
curr_chars += len(word)
|
564 |
+
if i != 0:
|
565 |
+
if (
|
566 |
+
max_chars is not None and curr_chars > max_chars
|
567 |
+
or
|
568 |
+
max_words is not None and curr_words > max_words
|
569 |
+
) and i-1 not in locked_indices:
|
570 |
+
indices.append(i-1)
|
571 |
+
curr_words = 1
|
572 |
+
curr_chars = len(word)
|
573 |
+
return indices
|
574 |
+
|
575 |
+
def get_duration_indices(self, max_dur: float, even_split: bool = True, include_lock: bool = False):
|
576 |
+
if not self.has_words or (total_duration := np.sum([w.duration for w in self.words])) <= max_dur:
|
577 |
+
return []
|
578 |
+
if even_split:
|
579 |
+
splits = np.ceil(total_duration / max_dur)
|
580 |
+
dur_per_split = total_duration / splits
|
581 |
+
cum_dur = np.cumsum([w.duration for w in self.words[:-1]])
|
582 |
+
indices = [
|
583 |
+
(np.abs(cum_dur - (i * dur_per_split))).argmin()
|
584 |
+
for i in range(1, int(splits))
|
585 |
+
]
|
586 |
+
else:
|
587 |
+
indices = []
|
588 |
+
curr_total_dur = 0.0
|
589 |
+
locked_indices = self.get_locked_indices() if include_lock else []
|
590 |
+
for i, word in enumerate(self.words):
|
591 |
+
curr_total_dur += word.duration
|
592 |
+
if i != 0:
|
593 |
+
if curr_total_dur > max_dur and i - 1 not in locked_indices:
|
594 |
+
indices.append(i - 1)
|
595 |
+
curr_total_dur = word.duration
|
596 |
+
return indices
|
597 |
+
|
598 |
+
def split(self, indices: List[int]):
|
599 |
+
if len(indices) == 0:
|
600 |
+
return []
|
601 |
+
if indices[-1] != len(self.words) - 1:
|
602 |
+
indices.append(len(self.words) - 1)
|
603 |
+
seg_copies = []
|
604 |
+
prev_i = 0
|
605 |
+
for i in indices:
|
606 |
+
i += 1
|
607 |
+
c = deepcopy(self)
|
608 |
+
c.words = c.words[prev_i:i]
|
609 |
+
c.update_seg_with_words()
|
610 |
+
seg_copies.append(c)
|
611 |
+
prev_i = i
|
612 |
+
return seg_copies
|
613 |
+
|
614 |
+
def set_result(self, result: 'WhisperResult'):
|
615 |
+
self._result = result
|
616 |
+
|
617 |
+
def get_result(self) -> Union['WhisperResult', None]:
|
618 |
+
"""
|
619 |
+
Return outer instance of :class:`stable_whisper.result.WhisperResult` that ``self`` is a part of.
|
620 |
+
"""
|
621 |
+
return getattr(self, '_result', None)
|
622 |
+
|
623 |
+
|
624 |
+
class WhisperResult:
|
625 |
+
|
626 |
+
def __init__(
|
627 |
+
self,
|
628 |
+
result: Union[str, dict, list],
|
629 |
+
force_order: bool = False,
|
630 |
+
check_sorted: Union[bool, str] = True,
|
631 |
+
show_unsorted: bool = True
|
632 |
+
):
|
633 |
+
result, self.path = self._standardize_result(result)
|
634 |
+
self.ori_dict = result.get('ori_dict') or result
|
635 |
+
self.language = self.ori_dict.get('language')
|
636 |
+
self._regroup_history = result.get('regroup_history', '')
|
637 |
+
self._nonspeech_sections = result.get('nonspeech_sections', [])
|
638 |
+
segments = deepcopy(result.get('segments', self.ori_dict.get('segments')))
|
639 |
+
self.segments: List[Segment] = [Segment(**s) for s in segments] if segments else []
|
640 |
+
self._forced_order = force_order
|
641 |
+
if self._forced_order:
|
642 |
+
self.force_order()
|
643 |
+
self.raise_for_unsorted(check_sorted, show_unsorted)
|
644 |
+
self.remove_no_word_segments(any(seg.has_words for seg in self.segments))
|
645 |
+
self.update_all_segs_with_words()
|
646 |
+
|
647 |
+
def __getitem__(self, index: int) -> Segment:
|
648 |
+
return self.segments[index]
|
649 |
+
|
650 |
+
def __delitem__(self, index: int):
|
651 |
+
del self.segments[index]
|
652 |
+
self.reassign_ids(True)
|
653 |
+
|
654 |
+
@staticmethod
|
655 |
+
def _standardize_result(result: Union[str, dict, list]):
|
656 |
+
path = None
|
657 |
+
if isinstance(result, str):
|
658 |
+
path = result
|
659 |
+
result = load_result(path)
|
660 |
+
if isinstance(result, list):
|
661 |
+
if isinstance(result[0], list):
|
662 |
+
if not isinstance(result[0][0], dict):
|
663 |
+
raise NotImplementedError(f'Got list of list of {type(result[0])} but expects list of list of dict')
|
664 |
+
result = dict(
|
665 |
+
segments=[
|
666 |
+
dict(
|
667 |
+
start=words[0]['start'],
|
668 |
+
end=words[-1]['end'],
|
669 |
+
text=''.join(w['word'] for w in words),
|
670 |
+
words=words
|
671 |
+
)
|
672 |
+
for words in result
|
673 |
+
]
|
674 |
+
)
|
675 |
+
|
676 |
+
elif isinstance(result[0], dict):
|
677 |
+
result = dict(segments=result)
|
678 |
+
else:
|
679 |
+
raise NotImplementedError(f'Got list of {type(result[0])} but expects list of list/dict')
|
680 |
+
return result, path
|
681 |
+
|
682 |
+
def force_order(self):
|
683 |
+
prev_ts_end = 0
|
684 |
+
timestamps = self.all_words_or_segments()
|
685 |
+
for i, ts in enumerate(timestamps, 1):
|
686 |
+
if ts.start < prev_ts_end:
|
687 |
+
ts.start = prev_ts_end
|
688 |
+
if ts.start > ts.end:
|
689 |
+
if prev_ts_end > ts.end:
|
690 |
+
warnings.warn('Multiple consecutive timestamps are out of order. Some parts will have no duration.')
|
691 |
+
ts.start = ts.end
|
692 |
+
for j in range(i-2, -1, -1):
|
693 |
+
if timestamps[j].end > ts.end:
|
694 |
+
timestamps[j].end = ts.end
|
695 |
+
if timestamps[j].start > ts.end:
|
696 |
+
timestamps[j].start = ts.end
|
697 |
+
else:
|
698 |
+
if ts.start != prev_ts_end:
|
699 |
+
ts.start = prev_ts_end
|
700 |
+
else:
|
701 |
+
ts.end = ts.start if i == len(timestamps) else timestamps[i].start
|
702 |
+
prev_ts_end = ts.end
|
703 |
+
if self.has_words:
|
704 |
+
self.update_all_segs_with_words()
|
705 |
+
|
706 |
+
def raise_for_unsorted(self, check_sorted: Union[bool, str] = True, show_unsorted: bool = True):
|
707 |
+
if check_sorted is False:
|
708 |
+
return
|
709 |
+
all_parts = self.all_words_or_segments()
|
710 |
+
has_words = self.has_words
|
711 |
+
timestamps = np.array(list(chain.from_iterable((p.start, p.end) for p in all_parts)))
|
712 |
+
if len(timestamps) > 1 and (unsorted_mask := timestamps[:-1] > timestamps[1:]).any():
|
713 |
+
if show_unsorted:
|
714 |
+
def get_part_info(idx):
|
715 |
+
curr_part = all_parts[idx]
|
716 |
+
seg_id = curr_part.segment_id if has_words else curr_part.id
|
717 |
+
word_id_str = f'Word ID: {curr_part.id}\n' if has_words else ''
|
718 |
+
return (
|
719 |
+
f'Segment ID: {seg_id}\n{word_id_str}'
|
720 |
+
f'Start: {curr_part.start}\nEnd: {curr_part.end}\n'
|
721 |
+
f'Text: "{curr_part.word if has_words else curr_part.text}"'
|
722 |
+
), curr_part.start, curr_part.end
|
723 |
+
|
724 |
+
for i, unsorted in enumerate(unsorted_mask, 2):
|
725 |
+
if unsorted:
|
726 |
+
word_id = i//2-1
|
727 |
+
part_info, start, end = get_part_info(word_id)
|
728 |
+
if i % 2 == 1:
|
729 |
+
next_info, next_start, _ = get_part_info(word_id+1)
|
730 |
+
part_info += f'\nConflict: end ({end}) > next start ({next_start})\n{next_info}'
|
731 |
+
else:
|
732 |
+
part_info += f'\nConflict: start ({start}) > end ({end})'
|
733 |
+
print(part_info, end='\n\n')
|
734 |
+
|
735 |
+
data = self.to_dict()
|
736 |
+
if check_sorted is True:
|
737 |
+
raise UnsortedException(data=data)
|
738 |
+
warnings.warn('Timestamps are not in ascending order. '
|
739 |
+
'If data is produced by Stable-ts, please submit an issue with the saved data.')
|
740 |
+
save_as_json(data, check_sorted)
|
741 |
+
|
742 |
+
def update_all_segs_with_words(self):
|
743 |
+
for seg in self.segments:
|
744 |
+
seg.update_seg_with_words()
|
745 |
+
seg.set_result(self)
|
746 |
+
|
747 |
+
def update_nonspeech_sections(self, silent_starts, silent_ends):
|
748 |
+
self._nonspeech_sections = [dict(start=s, end=e) for s, e in zip(silent_starts, silent_ends)]
|
749 |
+
|
750 |
+
def add_segments(self, index0: int, index1: int, inplace: bool = False, lock: bool = False):
|
751 |
+
new_seg = self.segments[index0] + self.segments[index1]
|
752 |
+
new_seg.update_seg_with_words()
|
753 |
+
if lock and self.segments[index0].has_words:
|
754 |
+
lock_idx = len(self.segments[index0].words)
|
755 |
+
new_seg.words[lock_idx - 1].lock_right()
|
756 |
+
if lock_idx < len(new_seg.words):
|
757 |
+
new_seg.words[lock_idx].lock_left()
|
758 |
+
if inplace:
|
759 |
+
i0, i1 = sorted([index0, index1])
|
760 |
+
self.segments[i0] = new_seg
|
761 |
+
del self.segments[i1]
|
762 |
+
return new_seg
|
763 |
+
|
764 |
+
def rescale_time(self, scale_factor: float):
|
765 |
+
for s in self.segments:
|
766 |
+
s.rescale_time(scale_factor)
|
767 |
+
|
768 |
+
def apply_min_dur(self, min_dur: float, inplace: bool = False):
|
769 |
+
"""
|
770 |
+
Merge any word/segment with adjacent word/segment if its duration is less than ``min_dur``.
|
771 |
+
"""
|
772 |
+
result = self if inplace else deepcopy(self)
|
773 |
+
max_i = len(result.segments) - 1
|
774 |
+
if max_i == 0:
|
775 |
+
return result
|
776 |
+
for i in reversed(range(len(result.segments))):
|
777 |
+
if max_i == 0:
|
778 |
+
break
|
779 |
+
if result.segments[i].duration < min_dur:
|
780 |
+
if i == max_i:
|
781 |
+
result.add_segments(i-1, i, inplace=True)
|
782 |
+
elif i == 0:
|
783 |
+
result.add_segments(i, i+1, inplace=True)
|
784 |
+
else:
|
785 |
+
if result.segments[i+1].duration < result.segments[i-1].duration:
|
786 |
+
result.add_segments(i-1, i, inplace=True)
|
787 |
+
else:
|
788 |
+
result.add_segments(i, i+1, inplace=True)
|
789 |
+
max_i -= 1
|
790 |
+
result.reassign_ids()
|
791 |
+
for s in result.segments:
|
792 |
+
s.apply_min_dur(min_dur, inplace=True)
|
793 |
+
return result
|
794 |
+
|
795 |
+
def offset_time(self, offset_seconds: float):
|
796 |
+
for s in self.segments:
|
797 |
+
s.offset_time(offset_seconds)
|
798 |
+
|
799 |
+
def suppress_silence(
|
800 |
+
self,
|
801 |
+
silent_starts: np.ndarray,
|
802 |
+
silent_ends: np.ndarray,
|
803 |
+
min_word_dur: float = 0.1,
|
804 |
+
word_level: bool = True,
|
805 |
+
nonspeech_error: float = 0.3,
|
806 |
+
use_word_position: bool = True
|
807 |
+
) -> "WhisperResult":
|
808 |
+
"""
|
809 |
+
Move any start/end timestamps in silence parts of audio to the boundaries of the silence.
|
810 |
+
|
811 |
+
Parameters
|
812 |
+
----------
|
813 |
+
silent_starts : numpy.ndarray
|
814 |
+
An array starting timestamps of silent sections of audio.
|
815 |
+
silent_ends : numpy.ndarray
|
816 |
+
An array ending timestamps of silent sections of audio.
|
817 |
+
min_word_dur : float, default 0.1
|
818 |
+
Shortest duration each word is allowed to reach for adjustments.
|
819 |
+
word_level : bool, default False
|
820 |
+
Whether to settings to word level timestamps.
|
821 |
+
nonspeech_error : float, default 0.3
|
822 |
+
Relative error of non-speech sections that appear in between a word for adjustments.
|
823 |
+
use_word_position : bool, default True
|
824 |
+
Whether to use position of the word in its segment to determine whether to keep end or start timestamps if
|
825 |
+
adjustments are required. If it is the first word, keep end. Else if it is the last word, keep the start.
|
826 |
+
|
827 |
+
Returns
|
828 |
+
-------
|
829 |
+
stable_whisper.result.WhisperResult
|
830 |
+
The current instance after the changes.
|
831 |
+
"""
|
832 |
+
for s in self.segments:
|
833 |
+
s.suppress_silence(
|
834 |
+
silent_starts,
|
835 |
+
silent_ends,
|
836 |
+
min_word_dur,
|
837 |
+
word_level=word_level,
|
838 |
+
nonspeech_error=nonspeech_error,
|
839 |
+
use_word_position=use_word_position
|
840 |
+
)
|
841 |
+
|
842 |
+
return self
|
843 |
+
|
844 |
+
def adjust_by_silence(
|
845 |
+
self,
|
846 |
+
audio: Union[torch.Tensor, np.ndarray, str, bytes],
|
847 |
+
vad: bool = False,
|
848 |
+
*,
|
849 |
+
verbose: (bool, None) = False,
|
850 |
+
sample_rate: int = None,
|
851 |
+
vad_onnx: bool = False,
|
852 |
+
vad_threshold: float = 0.35,
|
853 |
+
q_levels: int = 20,
|
854 |
+
k_size: int = 5,
|
855 |
+
min_word_dur: float = 0.1,
|
856 |
+
word_level: bool = True,
|
857 |
+
nonspeech_error: float = 0.3,
|
858 |
+
use_word_position: bool = True
|
859 |
+
|
860 |
+
) -> "WhisperResult":
|
861 |
+
"""
|
862 |
+
Adjust timestamps base detected speech gaps.
|
863 |
+
|
864 |
+
This is method combines :meth:`stable_whisper.result.WhisperResult.suppress_silence` with silence detection.
|
865 |
+
|
866 |
+
Parameters
|
867 |
+
----------
|
868 |
+
audio : str or numpy.ndarray or torch.Tensor or bytes
|
869 |
+
Path/URL to the audio file, the audio waveform, or bytes of audio file.
|
870 |
+
vad : bool, default False
|
871 |
+
Whether to use Silero VAD to generate timestamp suppression mask.
|
872 |
+
Silero VAD requires PyTorch 1.12.0+. Official repo, https://github.com/snakers4/silero-vad.
|
873 |
+
verbose : bool or None, default False
|
874 |
+
If ``False``, mute messages about hitting local caches. Note that the message about first download cannot be
|
875 |
+
muted. Only applies if ``vad = True``.
|
876 |
+
sample_rate : int, default None, meaning ``whisper.audio.SAMPLE_RATE``, 16kHZ
|
877 |
+
The sample rate of ``audio``.
|
878 |
+
vad_onnx : bool, default False
|
879 |
+
Whether to use ONNX for Silero VAD.
|
880 |
+
vad_threshold : float, default 0.35
|
881 |
+
Threshold for detecting speech with Silero VAD. Low threshold reduces false positives for silence detection.
|
882 |
+
q_levels : int, default 20
|
883 |
+
Quantization levels for generating timestamp suppression mask; ignored if ``vad = true``.
|
884 |
+
Acts as a threshold to marking sound as silent.
|
885 |
+
Fewer levels will increase the threshold of volume at which to mark a sound as silent.
|
886 |
+
k_size : int, default 5
|
887 |
+
Kernel size for avg-pooling waveform to generate timestamp suppression mask; ignored if ``vad = true``.
|
888 |
+
Recommend 5 or 3; higher sizes will reduce detection of silence.
|
889 |
+
min_word_dur : float, default 0.1
|
890 |
+
Shortest duration each word is allowed to reach from adjustments.
|
891 |
+
word_level : bool, default False
|
892 |
+
Whether to settings to word level timestamps.
|
893 |
+
nonspeech_error : float, default 0.3
|
894 |
+
Relative error of non-speech sections that appear in between a word for adjustments.
|
895 |
+
use_word_position : bool, default True
|
896 |
+
Whether to use position of the word in its segment to determine whether to keep end or start timestamps if
|
897 |
+
adjustments are required. If it is the first word, keep end. Else if it is the last word, keep the start.
|
898 |
+
|
899 |
+
Returns
|
900 |
+
-------
|
901 |
+
stable_whisper.result.WhisperResult
|
902 |
+
The current instance after the changes.
|
903 |
+
|
904 |
+
Notes
|
905 |
+
-----
|
906 |
+
This operation is already performed by :func:`stable_whisper.whisper_word_level.transcribe_stable` /
|
907 |
+
:func:`stable_whisper.whisper_word_level.transcribe_minimal`/
|
908 |
+
:func:`stable_whisper.non_whisper.transcribe_any` / :func:`stable_whisper.alignment.align`
|
909 |
+
if ``suppress_silence = True``.
|
910 |
+
"""
|
911 |
+
if vad:
|
912 |
+
silent_timings = get_vad_silence_func(
|
913 |
+
onnx=vad_onnx,
|
914 |
+
verbose=verbose
|
915 |
+
)(audio, speech_threshold=vad_threshold, sr=sample_rate)
|
916 |
+
else:
|
917 |
+
silent_timings = mask2timing(
|
918 |
+
wav2mask(audio, q_levels=q_levels, k_size=k_size, sr=sample_rate)
|
919 |
+
)
|
920 |
+
if silent_timings is None:
|
921 |
+
return self
|
922 |
+
self.suppress_silence(
|
923 |
+
*silent_timings,
|
924 |
+
min_word_dur=min_word_dur,
|
925 |
+
word_level=word_level,
|
926 |
+
nonspeech_error=nonspeech_error,
|
927 |
+
use_word_position=use_word_position
|
928 |
+
)
|
929 |
+
self.update_nonspeech_sections(*silent_timings)
|
930 |
+
return self
|
931 |
+
|
932 |
+
def adjust_by_result(
|
933 |
+
self,
|
934 |
+
other_result: "WhisperResult",
|
935 |
+
min_word_dur: float = 0.1,
|
936 |
+
verbose: bool = False
|
937 |
+
):
|
938 |
+
"""
|
939 |
+
Minimize the duration of words using timestamps of another result.
|
940 |
+
|
941 |
+
Parameters
|
942 |
+
----------
|
943 |
+
other_result : "WhisperResult"
|
944 |
+
Timing data of the same words in a WhisperResult instance.
|
945 |
+
min_word_dur : float, default 0.1
|
946 |
+
Prevent changes to timestamps if the resultant word duration is less than ``min_word_dur``.
|
947 |
+
verbose : bool, default False
|
948 |
+
Whether to print out the timestamp changes.
|
949 |
+
"""
|
950 |
+
if not (self.has_words and other_result.has_words):
|
951 |
+
raise NotImplementedError('This operation can only be performed on results with word timestamps')
|
952 |
+
assert [w.word for w in self.all_words()] == [w.word for w in other_result.all_words()], \
|
953 |
+
'The words in [other_result] do not match the current words.'
|
954 |
+
for word, other_word in zip(self.all_words(), other_result.all_words()):
|
955 |
+
if word.end > other_word.start:
|
956 |
+
new_start = max(word.start, other_word.start)
|
957 |
+
new_end = min(word.end, other_word.end)
|
958 |
+
if new_end - new_start >= min_word_dur:
|
959 |
+
line = ''
|
960 |
+
if word.start != new_start:
|
961 |
+
if verbose:
|
962 |
+
line += f'[Start:{word.start:.3f}->{new_start:.3f}] '
|
963 |
+
word.start = new_start
|
964 |
+
if word.end != new_end:
|
965 |
+
if verbose:
|
966 |
+
line += f'[End:{word.end:.3f}->{new_end:.3f}] '
|
967 |
+
word.end = new_end
|
968 |
+
if line:
|
969 |
+
print(f'{line}"{word.word}"')
|
970 |
+
self.update_all_segs_with_words()
|
971 |
+
|
972 |
+
def reassign_ids(self, only_segments: bool = False):
|
973 |
+
for i, s in enumerate(self.segments):
|
974 |
+
s.id = i
|
975 |
+
if not only_segments:
|
976 |
+
s.reassign_ids()
|
977 |
+
|
978 |
+
def remove_no_word_segments(self, ignore_ori=False):
|
979 |
+
for i in reversed(range(len(self.segments))):
|
980 |
+
if (ignore_ori or self.segments[i].ori_has_words) and not self.segments[i].has_words:
|
981 |
+
del self.segments[i]
|
982 |
+
self.reassign_ids()
|
983 |
+
|
984 |
+
def get_locked_indices(self):
|
985 |
+
locked_indices = [i
|
986 |
+
for i, (left, right) in enumerate(zip(self.segments[1:], self.segments[:-1]))
|
987 |
+
if left.left_locked or right.right_locked]
|
988 |
+
return locked_indices
|
989 |
+
|
990 |
+
def get_gaps(self, as_ndarray=False):
|
991 |
+
s_ts = np.array([s.start for s in self.segments])
|
992 |
+
e_ts = np.array([s.end for s in self.segments])
|
993 |
+
gap = s_ts[1:] - e_ts[:-1]
|
994 |
+
return gap if as_ndarray else gap.tolist()
|
995 |
+
|
996 |
+
def get_gap_indices(self, min_gap: float = 0.1): # for merging
|
997 |
+
if len(self.segments) < 2:
|
998 |
+
return []
|
999 |
+
if min_gap is None:
|
1000 |
+
min_gap = 0
|
1001 |
+
indices = (self.get_gaps(True) <= min_gap).nonzero()[0].tolist()
|
1002 |
+
return sorted(set(indices) - set(self.get_locked_indices()))
|
1003 |
+
|
1004 |
+
def get_punctuation_indices(self, punctuation: Union[List[str], List[Tuple[str, str]], str]): # for merging
|
1005 |
+
if len(self.segments) < 2:
|
1006 |
+
return []
|
1007 |
+
if isinstance(punctuation, str):
|
1008 |
+
punctuation = [punctuation]
|
1009 |
+
indices = []
|
1010 |
+
for p in punctuation:
|
1011 |
+
if isinstance(p, str):
|
1012 |
+
for i, s in enumerate(self.segments[:-1]):
|
1013 |
+
if s.text.endswith(p):
|
1014 |
+
indices.append(i)
|
1015 |
+
elif i != 0 and s.text.startswith(p):
|
1016 |
+
indices.append(i-1)
|
1017 |
+
else:
|
1018 |
+
ending, beginning = p
|
1019 |
+
indices.extend([i for i, (s0, s1) in enumerate(zip(self.segments[:-1], self.segments[1:]))
|
1020 |
+
if s0.text.endswith(ending) and s1.text.startswith(beginning)])
|
1021 |
+
|
1022 |
+
return sorted(set(indices) - set(self.get_locked_indices()))
|
1023 |
+
|
1024 |
+
def all_words(self):
|
1025 |
+
return list(chain.from_iterable(s.words for s in self.segments))
|
1026 |
+
|
1027 |
+
def all_words_or_segments(self):
|
1028 |
+
return self.all_words() if self.has_words else self.segments
|
1029 |
+
|
1030 |
+
def all_words_by_lock(self, only_text: bool = True, by_segment: bool = False, include_single: bool = False):
|
1031 |
+
if by_segment:
|
1032 |
+
return [
|
1033 |
+
segment.words_by_lock(only_text=only_text, include_single=include_single)
|
1034 |
+
for segment in self.segments
|
1035 |
+
]
|
1036 |
+
return _words_by_lock(self.all_words(), only_text=only_text, include_single=include_single)
|
1037 |
+
|
1038 |
+
def all_tokens(self):
|
1039 |
+
return list(chain.from_iterable(s.tokens for s in self.all_words()))
|
1040 |
+
|
1041 |
+
def to_dict(self):
|
1042 |
+
return dict(text=self.text,
|
1043 |
+
segments=self.segments_to_dicts(),
|
1044 |
+
language=self.language,
|
1045 |
+
ori_dict=self.ori_dict,
|
1046 |
+
regroup_history=self._regroup_history,
|
1047 |
+
nonspeech_sections=self._nonspeech_sections)
|
1048 |
+
|
1049 |
+
def segments_to_dicts(self, reverse_text: Union[bool, tuple] = False):
|
1050 |
+
return [s.to_dict(reverse_text=reverse_text) for s in self.segments]
|
1051 |
+
|
1052 |
+
def _split_segments(self, get_indices, args: list = None, *, lock: bool = False, newline: bool = False):
|
1053 |
+
if args is None:
|
1054 |
+
args = []
|
1055 |
+
no_words = False
|
1056 |
+
for i in reversed(range(0, len(self.segments))):
|
1057 |
+
no_words = no_words or not self.segments[i].has_words
|
1058 |
+
indices = sorted(set(get_indices(self.segments[i], *args)))
|
1059 |
+
if not indices:
|
1060 |
+
continue
|
1061 |
+
if newline:
|
1062 |
+
if indices[-1] == len(self.segments[i].words) - 1:
|
1063 |
+
del indices[-1]
|
1064 |
+
if not indices:
|
1065 |
+
continue
|
1066 |
+
|
1067 |
+
for word_idx in indices:
|
1068 |
+
if self.segments[i].words[word_idx].word.endswith('\n'):
|
1069 |
+
continue
|
1070 |
+
self.segments[i].words[word_idx].word += '\n'
|
1071 |
+
if lock:
|
1072 |
+
self.segments[i].words[word_idx].lock_right()
|
1073 |
+
if word_idx + 1 < len(self.segments[i].words):
|
1074 |
+
self.segments[i].words[word_idx+1].lock_left()
|
1075 |
+
self.segments[i].update_seg_with_words()
|
1076 |
+
else:
|
1077 |
+
new_segments = self.segments[i].split(indices)
|
1078 |
+
if lock:
|
1079 |
+
for s in new_segments:
|
1080 |
+
if s == new_segments[0]:
|
1081 |
+
s.lock_right()
|
1082 |
+
elif s == new_segments[-1]:
|
1083 |
+
s.lock_left()
|
1084 |
+
else:
|
1085 |
+
s.lock_both()
|
1086 |
+
del self.segments[i]
|
1087 |
+
for s in reversed(new_segments):
|
1088 |
+
self.segments.insert(i, s)
|
1089 |
+
if no_words:
|
1090 |
+
warnings.warn('Found segment(s) without word timings. These segment(s) cannot be split.')
|
1091 |
+
self.remove_no_word_segments()
|
1092 |
+
|
1093 |
+
def _merge_segments(self, indices: List[int],
|
1094 |
+
*, max_words: int = None, max_chars: int = None, is_sum_max: bool = False, lock: bool = False):
|
1095 |
+
if len(indices) == 0:
|
1096 |
+
return
|
1097 |
+
for i in reversed(indices):
|
1098 |
+
seg = self.segments[i]
|
1099 |
+
if (
|
1100 |
+
(
|
1101 |
+
max_words and
|
1102 |
+
seg.has_words and
|
1103 |
+
(
|
1104 |
+
(seg.word_count() + self.segments[i + 1].word_count() > max_words)
|
1105 |
+
if is_sum_max else
|
1106 |
+
(seg.word_count() > max_words and self.segments[i + 1].word_count() > max_words)
|
1107 |
+
)
|
1108 |
+
) or
|
1109 |
+
(
|
1110 |
+
max_chars and
|
1111 |
+
(
|
1112 |
+
(seg.char_count() + self.segments[i + 1].char_count() > max_chars)
|
1113 |
+
if is_sum_max else
|
1114 |
+
(seg.char_count() > max_chars and self.segments[i + 1].char_count() > max_chars)
|
1115 |
+
)
|
1116 |
+
)
|
1117 |
+
):
|
1118 |
+
continue
|
1119 |
+
self.add_segments(i, i + 1, inplace=True, lock=lock)
|
1120 |
+
self.remove_no_word_segments()
|
1121 |
+
|
1122 |
+
def get_content_by_time(
|
1123 |
+
self,
|
1124 |
+
time: Union[float, Tuple[float, float], dict],
|
1125 |
+
within: bool = False,
|
1126 |
+
segment_level: bool = False
|
1127 |
+
) -> Union[List[WordTiming], List[Segment]]:
|
1128 |
+
"""
|
1129 |
+
Return content in the ``time`` range.
|
1130 |
+
|
1131 |
+
Parameters
|
1132 |
+
----------
|
1133 |
+
time : float or tuple of (float, float) or dict
|
1134 |
+
Range of time to find content. For tuple of two floats, first value is the start time and second value is
|
1135 |
+
the end time. For a single float value, it is treated as both the start and end time.
|
1136 |
+
within : bool, default False
|
1137 |
+
Whether to only find content fully overlaps with ``time`` range.
|
1138 |
+
segment_level : bool, default False
|
1139 |
+
Whether to look only on the segment level and return instances of :class:`stable_whisper.result.Segment`
|
1140 |
+
instead of :class:`stable_whisper.result.WordTiming`.
|
1141 |
+
|
1142 |
+
Returns
|
1143 |
+
-------
|
1144 |
+
list of stable_whisper.result.WordTiming or list of stable_whisper.result.Segment
|
1145 |
+
List of contents in the ``time`` range. The contents are instances of
|
1146 |
+
:class:`stable_whisper.result.Segment` if ``segment_level = True`` else
|
1147 |
+
:class:`stable_whisper.result.WordTiming`.
|
1148 |
+
"""
|
1149 |
+
if not segment_level and not self.has_words:
|
1150 |
+
raise ValueError('Missing word timestamps in result. Use ``segment_level=True`` instead.')
|
1151 |
+
contents = self.segments if segment_level else self.all_words()
|
1152 |
+
if isinstance(time, (float, int)):
|
1153 |
+
time = [time, time]
|
1154 |
+
elif isinstance(time, dict):
|
1155 |
+
time = [time['start'], time['end']]
|
1156 |
+
start, end = time
|
1157 |
+
|
1158 |
+
if within:
|
1159 |
+
def is_in_range(c):
|
1160 |
+
return start <= c.start and end >= c.end
|
1161 |
+
else:
|
1162 |
+
def is_in_range(c):
|
1163 |
+
return start <= c.end and end >= c.start
|
1164 |
+
|
1165 |
+
return [c for c in contents if is_in_range(c)]
|
1166 |
+
|
1167 |
+
def split_by_gap(
|
1168 |
+
self,
|
1169 |
+
max_gap: float = 0.1,
|
1170 |
+
lock: bool = False,
|
1171 |
+
newline: bool = False
|
1172 |
+
) -> "WhisperResult":
|
1173 |
+
"""
|
1174 |
+
Split (in-place) any segment where the gap between two of its words is greater than ``max_gap``.
|
1175 |
+
|
1176 |
+
Parameters
|
1177 |
+
----------
|
1178 |
+
max_gap : float, default 0.1
|
1179 |
+
Maximum second(s) allowed between two words if the same segment.
|
1180 |
+
lock : bool, default False
|
1181 |
+
Whether to prevent future splits/merges from altering changes made by this method.
|
1182 |
+
newline: bool, default False
|
1183 |
+
Whether to insert line break at the split points instead of splitting into separate segments.
|
1184 |
+
|
1185 |
+
Returns
|
1186 |
+
-------
|
1187 |
+
stable_whisper.result.WhisperResult
|
1188 |
+
The current instance after the changes.
|
1189 |
+
"""
|
1190 |
+
self._split_segments(lambda x: x.get_gap_indices(max_gap), lock=lock, newline=newline)
|
1191 |
+
if self._regroup_history:
|
1192 |
+
self._regroup_history += '_'
|
1193 |
+
self._regroup_history += f'sg={max_gap}+{int(lock)}+{int(newline)}'
|
1194 |
+
return self
|
1195 |
+
|
1196 |
+
def merge_by_gap(
|
1197 |
+
self,
|
1198 |
+
min_gap: float = 0.1,
|
1199 |
+
max_words: int = None,
|
1200 |
+
max_chars: int = None,
|
1201 |
+
is_sum_max: bool = False,
|
1202 |
+
lock: bool = False
|
1203 |
+
) -> "WhisperResult":
|
1204 |
+
"""
|
1205 |
+
Merge (in-place) any pair of adjacent segments if the gap between them <= ``min_gap``.
|
1206 |
+
|
1207 |
+
Parameters
|
1208 |
+
----------
|
1209 |
+
min_gap : float, default 0.1
|
1210 |
+
Minimum second(s) allowed between two segments.
|
1211 |
+
max_words : int, optional
|
1212 |
+
Maximum number of words allowed in each segment.
|
1213 |
+
max_chars : int, optional
|
1214 |
+
Maximum number of characters allowed in each segment.
|
1215 |
+
is_sum_max : bool, default False
|
1216 |
+
Whether ``max_words`` and ``max_chars`` are applied to the merged segment instead of the individual segments
|
1217 |
+
to be merged.
|
1218 |
+
lock : bool, default False
|
1219 |
+
Whether to prevent future splits/merges from altering changes made by this method.
|
1220 |
+
|
1221 |
+
Returns
|
1222 |
+
-------
|
1223 |
+
stable_whisper.result.WhisperResult
|
1224 |
+
The current instance after the changes.
|
1225 |
+
"""
|
1226 |
+
indices = self.get_gap_indices(min_gap)
|
1227 |
+
self._merge_segments(indices,
|
1228 |
+
max_words=max_words, max_chars=max_chars, is_sum_max=is_sum_max, lock=lock)
|
1229 |
+
if self._regroup_history:
|
1230 |
+
self._regroup_history += '_'
|
1231 |
+
self._regroup_history += f'mg={min_gap}+{max_words or ""}+{max_chars or ""}+{int(is_sum_max)}+{int(lock)}'
|
1232 |
+
return self
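A sketch of how these two operations are commonly chained (the thresholds and the JSON path below are illustrative, not defaults):

from stable_whisper.result import WhisperResult

result = WhisperResult('audio.json')  # hypothetical saved result
# split wherever two words are separated by more than 0.5 s of silence,
# then re-merge neighbouring segments whose gap is at most 0.15 s, as long as
# the combined segment stays within 3 words (is_sum_max applies the cap to the merged segment)
result.split_by_gap(0.5).merge_by_gap(0.15, max_words=3, is_sum_max=True)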
|
1233 |
+
|
1234 |
+
def split_by_punctuation(
|
1235 |
+
self,
|
1236 |
+
punctuation: Union[List[str], List[Tuple[str, str]], str],
|
1237 |
+
lock: bool = False,
|
1238 |
+
newline: bool = False,
|
1239 |
+
min_words: Optional[int] = None,
|
1240 |
+
min_chars: Optional[int] = None,
|
1241 |
+
min_dur: Optional[int] = None
|
1242 |
+
) -> "WhisperResult":
|
1243 |
+
"""
|
1244 |
+
Split (in-place) segments at words that start/end with ``punctuation``.
|
1245 |
+
|
1246 |
+
Parameters
|
1247 |
+
----------
|
1248 |
+
punctuation : list of str or list of tuple of (str, str) or str
|
1249 |
+
Punctuation(s) to split segments by.
|
1250 |
+
lock : bool, default False
|
1251 |
+
Whether to prevent future splits/merges from altering changes made by this method.
|
1252 |
+
newline : bool, default False
|
1253 |
+
Whether to insert line break at the split points instead of splitting into separate segments.
|
1254 |
+
min_words : int, optional
|
1255 |
+
Split segments with words >= ``min_words``.
|
1256 |
+
min_chars : int, optional
|
1257 |
+
Split segments with characters >= ``min_chars``.
|
1258 |
+
min_dur : int, optional
|
1259 |
+
Split segments with duration (in seconds) >= ``min_dur``.
|
1260 |
+
|
1261 |
+
Returns
|
1262 |
+
-------
|
1263 |
+
stable_whisper.result.WhisperResult
|
1264 |
+
The current instance after the changes.
|
1265 |
+
"""
|
1266 |
+
def _over_max(x: Segment):
|
1267 |
+
return (
|
1268 |
+
(min_words and len(x.words) >= min_words) or
|
1269 |
+
(min_chars and x.char_count() >= min_chars) or
|
1270 |
+
(min_dur and x.duration >= min_dur)
|
1271 |
+
)
|
1272 |
+
|
1273 |
+
indices = set(s.id for s in self.segments if _over_max(s)) if any((min_words, min_chars, min_dur)) else None
|
1274 |
+
|
1275 |
+
def _get_indices(x: Segment):
|
1276 |
+
return x.get_punctuation_indices(punctuation) if indices is None or x.id in indices else []
|
1277 |
+
|
1278 |
+
self._split_segments(_get_indices, lock=lock, newline=newline)
|
1279 |
+
if self._regroup_history:
|
1280 |
+
self._regroup_history += '_'
|
1281 |
+
punct_str = '/'.join(p if isinstance(p, str) else '*'.join(p) for p in punctuation)
|
1282 |
+
self._regroup_history += f'sp={punct_str}+{int(lock)}+{int(newline)}'
|
1283 |
+
self._regroup_history += f'+{min_words or ""}+{min_chars or ""}+{min_dur or ""}'.rstrip('+')
|
1284 |
+
return self
|
1285 |
+
|
1286 |
+
def merge_by_punctuation(
|
1287 |
+
self,
|
1288 |
+
punctuation: Union[List[str], List[Tuple[str, str]], str],
|
1289 |
+
max_words: int = None,
|
1290 |
+
max_chars: int = None,
|
1291 |
+
is_sum_max: bool = False,
|
1292 |
+
lock: bool = False
|
1293 |
+
) -> "WhisperResult":
|
1294 |
+
"""
|
1295 |
+
Merge (in-place) any two segments that have the specified punctuation(s) in between.
|
1296 |
+
|
1297 |
+
Parameters
|
1298 |
+
----------
|
1299 |
+
punctuation : list of str or list of tuple of (str, str) or str
|
1300 |
+
Punctuation(s) to merge segments by.
|
1301 |
+
max_words : int, optional
|
1302 |
+
Maximum number of words allowed in each segment.
|
1303 |
+
max_chars : int, optional
|
1304 |
+
Maximum number of characters allowed in each segment.
|
1305 |
+
is_sum_max : bool, default False
|
1306 |
+
Whether ``max_words`` and ``max_chars`` are applied to the merged segment instead of the individual segments
|
1307 |
+
to be merged.
|
1308 |
+
lock : bool, default False
|
1309 |
+
Whether to prevent future splits/merges from altering changes made by this method.
|
1310 |
+
|
1311 |
+
Returns
|
1312 |
+
-------
|
1313 |
+
stable_whisper.result.WhisperResult
|
1314 |
+
The current instance after the changes.
|
1315 |
+
"""
|
1316 |
+
indices = self.get_punctuation_indices(punctuation)
|
1317 |
+
self._merge_segments(indices,
|
1318 |
+
max_words=max_words, max_chars=max_chars, is_sum_max=is_sum_max, lock=lock)
|
1319 |
+
if self._regroup_history:
|
1320 |
+
self._regroup_history += '_'
|
1321 |
+
punct_str = '/'.join(p if isinstance(p, str) else '*'.join(p) for p in punctuation)
|
1322 |
+
self._regroup_history += f'mp={punct_str}+{max_words or ""}+{max_chars or ""}+{int(is_sum_max)}+{int(lock)}'
|
1323 |
+
return self
|
1324 |
+
|
1325 |
+
def merge_all_segments(self) -> "WhisperResult":
|
1326 |
+
"""
|
1327 |
+
Merge all segments into one segment.
|
1328 |
+
|
1329 |
+
Returns
|
1330 |
+
-------
|
1331 |
+
stable_whisper.result.WhisperResult
|
1332 |
+
The current instance after the changes.
|
1333 |
+
"""
|
1334 |
+
if not self.segments:
|
1335 |
+
return self
|
1336 |
+
if self.has_words:
|
1337 |
+
self.segments[0].words = self.all_words()
|
1338 |
+
else:
|
1339 |
+
self.segments[0].text += ''.join(s.text for s in self.segments[1:])
|
1340 |
+
if all(s.tokens is not None for s in self.segments):
|
1341 |
+
self.segments[0].tokens += list(chain.from_iterable(s.tokens for s in self.segments[1:]))
|
1342 |
+
self.segments[0].end = self.segments[-1].end
|
1343 |
+
self.segments = [self.segments[0]]
|
1344 |
+
self.reassign_ids()
|
1345 |
+
self.update_all_segs_with_words()
|
1346 |
+
if self._regroup_history:
|
1347 |
+
self._regroup_history += '_'
|
1348 |
+
self._regroup_history += 'ms'
|
1349 |
+
return self
|
1350 |
+
|
1351 |
+
def split_by_length(
|
1352 |
+
self,
|
1353 |
+
max_chars: int = None,
|
1354 |
+
max_words: int = None,
|
1355 |
+
even_split: bool = True,
|
1356 |
+
force_len: bool = False,
|
1357 |
+
lock: bool = False,
|
1358 |
+
include_lock: bool = False,
|
1359 |
+
newline: bool = False
|
1360 |
+
) -> "WhisperResult":
|
1361 |
+
"""
|
1362 |
+
Split (in-place) any segment that exceeds ``max_chars`` or ``max_words`` into smaller segments.
|
1363 |
+
|
1364 |
+
Parameters
|
1365 |
+
----------
|
1366 |
+
max_chars : int, optional
|
1367 |
+
Maximum number of characters allowed in each segment.
|
1368 |
+
max_words : int, optional
|
1369 |
+
Maximum number of words allowed in each segment.
|
1370 |
+
even_split : bool, default True
|
1371 |
+
Whether to evenly split a segment in length if it exceeds ``max_chars`` or ``max_words``.
|
1372 |
+
force_len : bool, default False
|
1373 |
+
Whether to force a constant length for each segment except the last segment.
|
1374 |
+
This will ignore all previous non-locked segment boundaries.
|
1375 |
+
lock : bool, default False
|
1376 |
+
Whether to prevent future splits/merges from altering changes made by this method.
|
1377 |
+
include_lock: bool, default False
|
1378 |
+
Whether to include previous lock before splitting based on max_words, if ``even_split = False``.
|
1379 |
+
Splitting will be done after the first non-locked word > ``max_chars`` / ``max_words``.
|
1380 |
+
newline: bool, default False
|
1381 |
+
Whether to insert line break at the split points instead of splitting into separate segments.
|
1382 |
+
|
1383 |
+
Returns
|
1384 |
+
-------
|
1385 |
+
stable_whisper.result.WhisperResult
|
1386 |
+
The current instance after the changes.
|
1387 |
+
|
1388 |
+
Notes
|
1389 |
+
-----
|
1390 |
+
If ``even_split = True``, segments can still exceed ``max_chars`` and locked words will be ignored to avoid
|
1391 |
+
uneven splitting.
|
1392 |
+
"""
|
1393 |
+
if force_len:
|
1394 |
+
self.merge_all_segments()
|
1395 |
+
self._split_segments(
|
1396 |
+
lambda x: x.get_length_indices(
|
1397 |
+
max_chars=max_chars,
|
1398 |
+
max_words=max_words,
|
1399 |
+
even_split=even_split,
|
1400 |
+
include_lock=include_lock
|
1401 |
+
),
|
1402 |
+
lock=lock,
|
1403 |
+
newline=newline
|
1404 |
+
)
|
1405 |
+
if self._regroup_history:
|
1406 |
+
self._regroup_history += '_'
|
1407 |
+
self._regroup_history += (f'sl={max_chars or ""}+{max_words or ""}+{int(even_split)}+{int(force_len)}'
|
1408 |
+
f'+{int(lock)}+{int(include_lock)}+{int(newline)}')
|
1409 |
+
return self
|
1410 |
+
|
1411 |
+
def split_by_duration(
|
1412 |
+
self,
|
1413 |
+
max_dur: float,
|
1414 |
+
even_split: bool = True,
|
1415 |
+
force_len: bool = False,
|
1416 |
+
lock: bool = False,
|
1417 |
+
include_lock: bool = False,
|
1418 |
+
newline: bool = False
|
1419 |
+
) -> "WhisperResult":
|
1420 |
+
"""
|
1421 |
+
Split (in-place) any segment that exceeds ``max_dur`` into smaller segments.
|
1422 |
+
|
1423 |
+
Parameters
|
1424 |
+
----------
|
1425 |
+
max_dur : float
|
1426 |
+
Maximum duration (in seconds) per segment.
|
1427 |
+
even_split : bool, default True
|
1428 |
+
Whether to evenly split a segment in length if it exceeds ``max_dur``.
|
1429 |
+
force_len : bool, default False
|
1430 |
+
Whether to force a constant length for each segment except the last segment.
|
1431 |
+
This will ignore all previous non-locked segment boundaries.
|
1432 |
+
lock : bool, default False
|
1433 |
+
Whether to prevent future splits/merges from altering changes made by this method.
|
1434 |
+
include_lock: bool, default False
|
1435 |
+
Whether to include previous lock before splitting based on ``max_dur``, if ``even_split = False``.
|
1436 |
+
Splitting will be done after the first non-locked word > ``max_dur``.
|
1437 |
+
newline: bool, default False
|
1438 |
+
Whether to insert line break at the split points instead of splitting into separate segments.
|
1439 |
+
|
1440 |
+
Returns
|
1441 |
+
-------
|
1442 |
+
stable_whisper.result.WhisperResult
|
1443 |
+
The current instance after the changes.
|
1444 |
+
|
1445 |
+
Notes
|
1446 |
+
-----
|
1447 |
+
If ``even_split = True``, segments can still exceed ``max_dur`` and locked words will be ignored to avoid
|
1448 |
+
uneven splitting.
|
1449 |
+
"""
|
1450 |
+
if force_len:
|
1451 |
+
self.merge_all_segments()
|
1452 |
+
self._split_segments(
|
1453 |
+
lambda x: x.get_duration_indices(
|
1454 |
+
max_dur=max_dur,
|
1455 |
+
even_split=even_split,
|
1456 |
+
include_lock=include_lock
|
1457 |
+
),
|
1458 |
+
lock=lock,
|
1459 |
+
newline=newline
|
1460 |
+
)
|
1461 |
+
if self._regroup_history:
|
1462 |
+
self._regroup_history += '_'
|
1463 |
+
self._regroup_history += (f'sd={max_dur}+{int(even_split)}+{int(force_len)}'
|
1464 |
+
f'+{int(lock)}+{int(include_lock)}+{int(newline)}')
|
1465 |
+
return self
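As a sketch, both length- and duration-based splitting can be applied to fit subtitle constraints (the limits and the JSON path below are illustrative):

from stable_whisper.result import WhisperResult

result = WhisperResult('audio.json')  # hypothetical saved result
# keep every line within 42 characters; newline=True only inserts line
# breaks at the split points instead of producing new segments
result.split_by_length(max_chars=42, newline=True)
# additionally cap each segment at 6 seconds
result.split_by_duration(max_dur=6.0)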
|
1466 |
+
|
1467 |
+
def clamp_max(
|
1468 |
+
self,
|
1469 |
+
medium_factor: float = 2.5,
|
1470 |
+
max_dur: float = None,
|
1471 |
+
clip_start: Optional[bool] = None,
|
1472 |
+
verbose: bool = False
|
1473 |
+
) -> "WhisperResult":
|
1474 |
+
"""
|
1475 |
+
Clamp all word durations above a certain value.
|
1476 |
+
|
1477 |
+
This is most effective when applied before and after other regroup operations.
|
1478 |
+
|
1479 |
+
Parameters
|
1480 |
+
----------
|
1481 |
+
medium_factor : float, default 2.5
|
1482 |
+
Clamp durations above (``medium_factor`` * medium duration) per segment.
|
1483 |
+
If ``medium_factor = None/0`` or the segment has fewer than 3 words, it will be ignored and only ``max_dur`` will be used.
|
1484 |
+
max_dur : float, optional
|
1485 |
+
Clamp durations above ``max_dur``.
|
1486 |
+
clip_start : bool or None, default None
|
1487 |
+
Whether to clamp the start of a word. If ``None``, clamp the start of first word and end of last word per
|
1488 |
+
segment.
|
1489 |
+
verbose : bool, default False
|
1490 |
+
Whether to print out the timestamp changes.
|
1491 |
+
|
1492 |
+
Returns
|
1493 |
+
-------
|
1494 |
+
stable_whisper.result.WhisperResult
|
1495 |
+
The current instance after the changes.
|
1496 |
+
"""
|
1497 |
+
if not (medium_factor or max_dur):
|
1498 |
+
raise ValueError('At least one of following arguments requires non-zero value: medium_factor; max_dur')
|
1499 |
+
|
1500 |
+
if not self.has_words:
|
1501 |
+
warnings.warn('Cannot clamp due to missing/no word-timestamps')
|
1502 |
+
return self
|
1503 |
+
|
1504 |
+
for seg in self.segments:
|
1505 |
+
curr_max_dur = None
|
1506 |
+
if medium_factor and len(seg.words) > 2:
|
1507 |
+
durations = np.array([word.duration for word in seg.words])
|
1508 |
+
durations.sort()
|
1509 |
+
curr_max_dur = medium_factor * durations[len(durations)//2 + 1]
|
1510 |
+
|
1511 |
+
if max_dur and (not curr_max_dur or curr_max_dur > max_dur):
|
1512 |
+
curr_max_dur = max_dur
|
1513 |
+
|
1514 |
+
if not curr_max_dur:
|
1515 |
+
continue
|
1516 |
+
|
1517 |
+
if clip_start is None:
|
1518 |
+
seg.words[0].clamp_max(curr_max_dur, clip_start=True, verbose=verbose)
|
1519 |
+
seg.words[-1].clamp_max(curr_max_dur, clip_start=False, verbose=verbose)
|
1520 |
+
else:
|
1521 |
+
for i, word in enumerate(seg.words):
|
1522 |
+
word.clamp_max(curr_max_dur, clip_start=clip_start, verbose=verbose)
|
1523 |
+
|
1524 |
+
seg.update_seg_with_words()
|
1525 |
+
if self._regroup_history:
|
1526 |
+
self._regroup_history += '_'
|
1527 |
+
self._regroup_history += f'cm={medium_factor}+{max_dur or ""}+{clip_start or ""}+{int(verbose)}'
|
1528 |
+
return self
|
1529 |
+
|
1530 |
+
def lock(
|
1531 |
+
self,
|
1532 |
+
startswith: Union[str, List[str]] = None,
|
1533 |
+
endswith: Union[str, List[str]] = None,
|
1534 |
+
right: bool = True,
|
1535 |
+
left: bool = False,
|
1536 |
+
case_sensitive: bool = False,
|
1537 |
+
strip: bool = True
|
1538 |
+
) -> "WhisperResult":
|
1539 |
+
"""
|
1540 |
+
Lock words/segments with matching prefix/suffix to prevent splitting/merging.
|
1541 |
+
|
1542 |
+
Parameters
|
1543 |
+
----------
|
1544 |
+
startswith: str or list of str
|
1545 |
+
Prefixes to lock.
|
1546 |
+
endswith: str or list of str
|
1547 |
+
Suffixes to lock.
|
1548 |
+
right : bool, default True
|
1549 |
+
Whether to prevent splits/merges with the next word/segment.
|
1550 |
+
left : bool, default False
|
1551 |
+
Whether to prevent splits/merges with the previous word/segment.
|
1552 |
+
case_sensitive : bool, default False
|
1553 |
+
Whether to match the case of the prefixes/suffixes with the words/segments.
|
1554 |
+
strip : bool, default True
|
1555 |
+
Whether to ignore spaces before and after both words/segments and prefixes/suffixes.
|
1556 |
+
|
1557 |
+
Returns
|
1558 |
+
-------
|
1559 |
+
stable_whisper.result.WhisperResult
|
1560 |
+
The current instance after the changes.
|
1561 |
+
"""
|
1562 |
+
assert startswith or endswith, 'Must specify [startswith] or/and [endswith].'
|
1563 |
+
startswith = [] if startswith is None else ([startswith] if isinstance(startswith, str) else startswith)
|
1564 |
+
endswith = [] if endswith is None else ([endswith] if isinstance(endswith, str) else endswith)
|
1565 |
+
if not case_sensitive:
|
1566 |
+
startswith = [t.lower() for t in startswith]
|
1567 |
+
endswith = [t.lower() for t in endswith]
|
1568 |
+
if strip:
|
1569 |
+
startswith = [t.strip() for t in startswith]
|
1570 |
+
endswith = [t.strip() for t in endswith]
|
1571 |
+
for part in self.all_words_or_segments():
|
1572 |
+
text = part.word if hasattr(part, 'word') else part.text
|
1573 |
+
if not case_sensitive:
|
1574 |
+
text = text.lower()
|
1575 |
+
if strip:
|
1576 |
+
text = text.strip()
|
1577 |
+
for prefix in startswith:
|
1578 |
+
if text.startswith(prefix):
|
1579 |
+
if right:
|
1580 |
+
part.lock_right()
|
1581 |
+
if left:
|
1582 |
+
part.lock_left()
|
1583 |
+
for suffix in endswith:
|
1584 |
+
if text.endswith(suffix):
|
1585 |
+
if right:
|
1586 |
+
part.lock_right()
|
1587 |
+
if left:
|
1588 |
+
part.lock_left()
|
1589 |
+
if self._regroup_history:
|
1590 |
+
self._regroup_history += '_'
|
1591 |
+
startswith_str = (startswith if isinstance(startswith, str) else '/'.join(startswith)) if startswith else ""
|
1592 |
+
endswith_str = (endswith if isinstance(endswith, str) else '/'.join(endswith)) if endswith else ""
|
1593 |
+
self._regroup_history += (f'l={startswith_str}+{endswith_str}'
|
1594 |
+
f'+{int(right)}+{int(left)}+{int(case_sensitive)}+{int(strip)}')
|
1595 |
+
return self
|
1596 |
+
|
1597 |
+
def remove_word(
|
1598 |
+
self,
|
1599 |
+
word: Union[WordTiming, Tuple[int, int]],
|
1600 |
+
reassign_ids: bool = True,
|
1601 |
+
verbose: bool = True
|
1602 |
+
) -> 'WhisperResult':
|
1603 |
+
"""
|
1604 |
+
Remove a word.
|
1605 |
+
|
1606 |
+
Parameters
|
1607 |
+
----------
|
1608 |
+
word : WordTiming or tuple of (int, int)
|
1609 |
+
Instance of :class:`stable_whisper.result.WordTiming` or tuple of (segment index, word index).
|
1610 |
+
reassign_ids : bool, default True
|
1611 |
+
Whether to reassign segment and word ids (indices) after removing ``word``.
|
1612 |
+
verbose : bool, default True
|
1613 |
+
Whether to print detail of the removed word.
|
1614 |
+
|
1615 |
+
Returns
|
1616 |
+
-------
|
1617 |
+
stable_whisper.result.WhisperResult
|
1618 |
+
The current instance after the changes.
|
1619 |
+
"""
|
1620 |
+
if isinstance(word, WordTiming):
|
1621 |
+
if self[word.segment_id][word.id] is not word:
|
1622 |
+
self.reassign_ids()
|
1623 |
+
if self[word.segment_id][word.id] is not word:
|
1624 |
+
raise ValueError('word not in result')
|
1625 |
+
seg_id, word_id = word.segment_id, word.id
|
1626 |
+
else:
|
1627 |
+
seg_id, word_id = word
|
1628 |
+
if verbose:
|
1629 |
+
print(f'Removed: {self[seg_id][word_id].to_dict()}')
|
1630 |
+
del self.segments[seg_id].words[word_id]
|
1631 |
+
if not reassign_ids:
|
1632 |
+
return self
|
1633 |
+
if self[seg_id].has_words:
|
1634 |
+
self[seg_id].reassign_ids()
|
1635 |
+
else:
|
1636 |
+
self.remove_no_word_segments()
|
1637 |
+
return self
|
1638 |
+
|
1639 |
+
def remove_segment(
|
1640 |
+
self,
|
1641 |
+
segment: Union[Segment, int],
|
1642 |
+
reassign_ids: bool = True,
|
1643 |
+
verbose: bool = True
|
1644 |
+
) -> 'WhisperResult':
|
1645 |
+
"""
|
1646 |
+
Remove a segment.
|
1647 |
+
|
1648 |
+
Parameters
|
1649 |
+
----------
|
1650 |
+
segment : Segment or int
|
1651 |
+
Instance of :class:`stable_whisper.result.Segment` or segment index.
|
1652 |
+
reassign_ids : bool, default True
|
1653 |
+
Whether to reassign segment IDs (indices) after removing ``segment``.
|
1654 |
+
verbose : bool, default True
|
1655 |
+
Whether to print detail of the removed segment.
|
1656 |
+
|
1657 |
+
Returns
|
1658 |
+
-------
|
1659 |
+
stable_whisper.result.WhisperResult
|
1660 |
+
The current instance after the changes.
|
1661 |
+
"""
|
1662 |
+
if isinstance(segment, Segment):
|
1663 |
+
if self[segment.id] is not segment:
|
1664 |
+
self.reassign_ids()
|
1665 |
+
if self[segment.id] is not segment:
|
1666 |
+
raise ValueError('segment not in result')
|
1667 |
+
segment = segment.id
|
1668 |
+
if verbose:
|
1669 |
+
print(f'Removed: [id:{self[segment].id}] {self[segment].to_display_str(True)}')
|
1670 |
+
del self.segments[segment]
|
1671 |
+
if not reassign_ids:
|
1672 |
+
return self
|
1673 |
+
self.reassign_ids(True)
|
1674 |
+
return self
|
1675 |
+
|
1676 |
+
def remove_repetition(
|
1677 |
+
self,
|
1678 |
+
max_words: int = 1,
|
1679 |
+
case_sensitive: bool = False,
|
1680 |
+
strip: bool = True,
|
1681 |
+
ignore_punctuations: str = "\"',.?!",
|
1682 |
+
extend_duration: bool = True,
|
1683 |
+
verbose: bool = True
|
1684 |
+
) -> 'WhisperResult':
|
1685 |
+
"""
|
1686 |
+
Remove words that repeat consecutively.
|
1687 |
+
|
1688 |
+
Parameters
|
1689 |
+
----------
|
1690 |
+
max_words : int
|
1691 |
+
Maximum number of words to look for consecutively.
|
1692 |
+
case_sensitive : bool, default False
|
1693 |
+
Whether the case of words needs to match to be considered a repetition.
|
1694 |
+
strip : bool, default True
|
1695 |
+
Whether to ignore spaces before and after each word.
|
1696 |
+
ignore_punctuations : str, default '"',.?!'
|
1697 |
+
Ending punctuations to ignore.
|
1698 |
+
extend_duration: bool, default True
|
1699 |
+
Whether to extend the duration of the previous word to cover the duration of the repetition.
|
1700 |
+
verbose: bool, default True
|
1701 |
+
Whether to print detail of the removed repetitions.
|
1702 |
+
|
1703 |
+
Returns
|
1704 |
+
-------
|
1705 |
+
stable_whisper.result.WhisperResult
|
1706 |
+
The current instance after the changes.
|
1707 |
+
"""
|
1708 |
+
if not self.has_words:
|
1709 |
+
return self
|
1710 |
+
|
1711 |
+
for count in range(1, max_words + 1):
|
1712 |
+
all_words = self.all_words()
|
1713 |
+
if len(all_words) < 2:
|
1714 |
+
return self
|
1715 |
+
all_words_str = [w.word for w in all_words]
|
1716 |
+
if strip:
|
1717 |
+
all_words_str = [w.strip() for w in all_words_str]
|
1718 |
+
if ignore_punctuations:
|
1719 |
+
ptn = f'[{ignore_punctuations}]+$'
|
1720 |
+
all_words_str = [re.sub(ptn, '', w) for w in all_words_str]
|
1721 |
+
if not case_sensitive:
|
1722 |
+
all_words_str = [w.lower() for w in all_words_str]
|
1723 |
+
next_i = None
|
1724 |
+
changes = []
|
1725 |
+
for i in reversed(range(count*2, len(all_words_str)+1)):
|
1726 |
+
if next_i is not None:
|
1727 |
+
if next_i != i:
|
1728 |
+
continue
|
1729 |
+
else:
|
1730 |
+
next_i = None
|
1731 |
+
s = i - count
|
1732 |
+
if all_words_str[s - count:s] != all_words_str[s:i]:
|
1733 |
+
continue
|
1734 |
+
next_i = s
|
1735 |
+
if extend_duration:
|
1736 |
+
all_words[s-1].end = all_words[i-1].end
|
1737 |
+
temp_changes = []
|
1738 |
+
for j in reversed(range(s, i)):
|
1739 |
+
if verbose:
|
1740 |
+
temp_changes.append(f'- {all_words[j].to_dict()}')
|
1741 |
+
self.remove_word(all_words[j], False, verbose=False)
|
1742 |
+
if temp_changes:
|
1743 |
+
changes.append(
|
1744 |
+
f'Remove: [{format_timestamp(all_words[s].start)} -> {format_timestamp(all_words[i-1].end)}] '
|
1745 |
+
+ ''.join(_w.word for _w in all_words[s:i]) + '\n'
|
1746 |
+
+ '\n'.join(reversed(temp_changes)) + '\n'
|
1747 |
+
)
|
1748 |
+
for i0, i1 in zip(range(s - count, s), range(s, i)):
|
1749 |
+
if len(all_words[i0].word) < len(all_words[i1].word):
|
1750 |
+
all_words[i1].start = all_words[i0].start
|
1751 |
+
all_words[i1].end = all_words[i0].end
|
1752 |
+
_sid, _wid = all_words[i0].segment_id, all_words[i0].id
|
1753 |
+
self.segments[_sid].words[_wid] = all_words[i1]
|
1754 |
+
|
1755 |
+
if changes:
|
1756 |
+
print('\n'.join(reversed(changes)))
|
1757 |
+
|
1758 |
+
self.remove_no_word_segments()
|
1759 |
+
self.update_all_segs_with_words()
|
1760 |
+
|
1761 |
+
return self
|
1762 |
+
|
1763 |
+
def remove_words_by_str(
|
1764 |
+
self,
|
1765 |
+
words: Union[str, List[str], None],
|
1766 |
+
case_sensitive: bool = False,
|
1767 |
+
strip: bool = True,
|
1768 |
+
ignore_punctuations: str = "\"',.?!",
|
1769 |
+
min_prob: float = None,
|
1770 |
+
filters: Callable = None,
|
1771 |
+
verbose: bool = True
|
1772 |
+
) -> 'WhisperResult':
|
1773 |
+
"""
|
1774 |
+
Remove words that match ``words``.
|
1775 |
+
|
1776 |
+
Parameters
|
1777 |
+
----------
|
1778 |
+
words : str or list of str or None
|
1779 |
+
A word or list of words to remove. ``None`` for all words to be passed into ``filters``.
|
1780 |
+
case_sensitive : bool, default False
|
1781 |
+
Whether the case of words needs to match to be considered a match.
|
1782 |
+
strip : bool, default True
|
1783 |
+
Whether to ignore spaces before and after each word.
|
1784 |
+
ignore_punctuations : str, default '"',.?!'
|
1785 |
+
Ending punctuations to ignore.
|
1786 |
+
min_prob : float, optional
|
1787 |
+
Acts as the first filter for the words that match ``words``. Words with probability < ``min_prob`` will
|
1788 |
+
be removed if ``filters`` is ``None``, else pass the words into ``filters``. Words without probability will
|
1789 |
+
be treated as having probability < ``min_prob``.
|
1790 |
+
filters : Callable, optional
|
1791 |
+
A function that takes an instance of :class:`stable_whisper.result.WordTiming` as its only argument.
|
1792 |
+
This function is a custom filter for the words that match ``words`` and were not caught by ``min_prob``.
|
1793 |
+
verbose : bool, default True
|
1794 |
+
Whether to print detail of the removed words.
|
1795 |
+
|
1796 |
+
Returns
|
1797 |
+
-------
|
1798 |
+
stable_whisper.result.WhisperResult
|
1799 |
+
The current instance after the changes.
|
1800 |
+
"""
|
1801 |
+
if not self.has_words:
|
1802 |
+
return self
|
1803 |
+
if isinstance(words, str):
|
1804 |
+
words = [words]
|
1805 |
+
all_words = self.all_words()
|
1806 |
+
all_words_str = [w.word for w in all_words]
|
1807 |
+
if strip:
|
1808 |
+
all_words_str = [w.strip() for w in all_words_str]
|
1809 |
+
words = [w.strip() for w in words]
|
1810 |
+
if ignore_punctuations:
|
1811 |
+
ptn = f'[{ignore_punctuations}]+$'
|
1812 |
+
all_words_str = [re.sub(ptn, '', w) for w in all_words_str]
|
1813 |
+
words = [re.sub(ptn, '', w) for w in words]
|
1814 |
+
if not case_sensitive:
|
1815 |
+
all_words_str = [w.lower() for w in all_words_str]
|
1816 |
+
words = [w.lower() for w in words]
|
1817 |
+
|
1818 |
+
changes = []
|
1819 |
+
for i, w in reversed(list(enumerate(all_words_str))):
|
1820 |
+
if not (words is None or any(w == _w for _w in words)):
|
1821 |
+
continue
|
1822 |
+
if (
|
1823 |
+
(min_prob is None or all_words[i].probability is None or min_prob > all_words[i].probability) and
|
1824 |
+
(filters is None or filters(all_words[i]))
|
1825 |
+
):
|
1826 |
+
if verbose:
|
1827 |
+
changes.append(f'Removed: {all_words[i].to_dict()}')
|
1828 |
+
self.remove_word(all_words[i], False, verbose=False)
|
1829 |
+
if changes:
|
1830 |
+
print('\n'.join(reversed(changes)))
|
1831 |
+
self.remove_no_word_segments()
|
1832 |
+
self.update_all_segs_with_words()
|
1833 |
+
|
1834 |
+
return self
|
1835 |
+
|
1836 |
+
def fill_in_gaps(
|
1837 |
+
self,
|
1838 |
+
other_result: Union['WhisperResult', str],
|
1839 |
+
min_gap: float = 0.1,
|
1840 |
+
case_sensitive: bool = False,
|
1841 |
+
strip: bool = True,
|
1842 |
+
ignore_punctuations: str = "\"',.?!",
|
1843 |
+
verbose: bool = True
|
1844 |
+
) -> 'WhisperResult':
|
1845 |
+
"""
|
1846 |
+
Fill in segment gaps larger than ``min_gap`` with content from ``other_result`` at the times of gaps.
|
1847 |
+
|
1848 |
+
Parameters
|
1849 |
+
----------
|
1850 |
+
other_result : WhisperResult or str
|
1851 |
+
Another transcription result as an instance of :class:`stable_whisper.result.WhisperResult` or path to the
|
1852 |
+
JSON of the result.
|
1853 |
+
min_gap : float, default 0.1
|
1854 |
+
The minimum seconds of a gap between segments that must be exceeded to be filled in.
|
1855 |
+
case_sensitive : bool, default False
|
1856 |
+
Whether to consider the case of the first and last word of the gap to determine overlapping words to remove
|
1857 |
+
before filling in.
|
1858 |
+
strip : bool, default True
|
1859 |
+
Whether to ignore spaces before and after the first and last word of the gap to determine overlapping words
|
1860 |
+
to remove before filling in.
|
1861 |
+
ignore_punctuations : str, default '"',.?!'
|
1862 |
+
Ending punctuations to ignore in the first and last word of the gap to determine overlapping words to
|
1863 |
+
remove before filling in.
|
1864 |
+
verbose : bool, default True
|
1865 |
+
Whether to print detail of the filled content.
|
1866 |
+
|
1867 |
+
Returns
|
1868 |
+
-------
|
1869 |
+
stable_whisper.result.WhisperResult
|
1870 |
+
The current instance after the changes.
|
1871 |
+
"""
|
1872 |
+
if len(self.segments) < 2:
|
1873 |
+
return self
|
1874 |
+
if isinstance(other_result, str):
|
1875 |
+
other_result = WhisperResult(other_result)
|
1876 |
+
|
1877 |
+
if strip:
|
1878 |
+
def strip_space(w):
|
1879 |
+
return w.strip()
|
1880 |
+
else:
|
1881 |
+
def strip_space(w):
|
1882 |
+
return w
|
1883 |
+
|
1884 |
+
if ignore_punctuations:
|
1885 |
+
ptn = f'[{ignore_punctuations}]+$'
|
1886 |
+
|
1887 |
+
def strip_punctuations(w):
|
1888 |
+
return re.sub(ptn, '', strip_space(w))
|
1889 |
+
else:
|
1890 |
+
strip_punctuations = strip_space
|
1891 |
+
|
1892 |
+
if case_sensitive:
|
1893 |
+
strip = strip_punctuations
|
1894 |
+
else:
|
1895 |
+
def strip(w):
|
1896 |
+
return strip_punctuations(w).lower()
|
1897 |
+
|
1898 |
+
seg_pairs = list(enumerate(zip(self.segments[:-1], self.segments[1:])))
|
1899 |
+
seg_pairs.insert(0, (-1, (None, self.segments[0])))
|
1900 |
+
seg_pairs.append((seg_pairs[-1][0]+1, (self.segments[-1], None)))
|
1901 |
+
|
1902 |
+
changes = []
|
1903 |
+
for i, (seg0, seg1) in reversed(seg_pairs):
|
1904 |
+
first_word = None if seg0 is None else seg0.words[-1]
|
1905 |
+
last_word = None if seg1 is None else seg1.words[0]
|
1906 |
+
start = (other_result[0].start if first_word is None else first_word.end)
|
1907 |
+
end = other_result[-1].end if last_word is None else last_word.start
|
1908 |
+
if end - start <= min_gap:
|
1909 |
+
continue
|
1910 |
+
gap_words = other_result.get_content_by_time((start, end))
|
1911 |
+
if first_word is not None and gap_words and strip(first_word.word) == strip(gap_words[0].word):
|
1912 |
+
first_word.end = gap_words[0].end
|
1913 |
+
gap_words = gap_words[1:]
|
1914 |
+
if last_word is not None and gap_words and strip(last_word.word) == strip(gap_words[-1].word):
|
1915 |
+
last_word.start = gap_words[-1].start
|
1916 |
+
gap_words = gap_words[:-1]
|
1917 |
+
if not gap_words:
|
1918 |
+
continue
|
1919 |
+
if last_word is not None and last_word.start < gap_words[-1].end:
|
1920 |
+
last_word.start = gap_words[-1].end
|
1921 |
+
new_segments = [other_result[gap_words[0].segment_id].copy([])]
|
1922 |
+
for j, new_word in enumerate(gap_words):
|
1923 |
+
new_word = deepcopy(new_word)
|
1924 |
+
if j == 0 and first_word is not None and first_word.end > gap_words[0].start:
|
1925 |
+
new_word.start = first_word.end
|
1926 |
+
if new_segments[-1].id != new_word.segment_id:
|
1927 |
+
new_segments.append(other_result[new_word.segment_id].copy([]))
|
1928 |
+
new_segments[-1].words.append(new_word)
|
1929 |
+
if verbose:
|
1930 |
+
changes.append('\n'.join('Added: ' + s.to_display_str(True) for s in new_segments))
|
1931 |
+
self.segments = self.segments[:i+1] + new_segments + self.segments[i+1:]
|
1932 |
+
if changes:
|
1933 |
+
print('\n'.join(reversed(changes)))
|
1934 |
+
self.reassign_ids()
|
1935 |
+
self.update_all_segs_with_words()
|
1936 |
+
|
1937 |
+
return self
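A sketch of filling gaps from a second transcription (both file paths are hypothetical):

from stable_whisper.result import WhisperResult

result = WhisperResult('first_pass.json')   # primary result (hypothetical path)
# any silence longer than 0.3 s between segments of the primary result is
# filled with the words another result produced for that time range
result.fill_in_gaps('second_pass.json', min_gap=0.3)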
|
1938 |
+
|
1939 |
+
def regroup(
|
1940 |
+
self,
|
1941 |
+
regroup_algo: Union[str, bool] = None,
|
1942 |
+
verbose: bool = False,
|
1943 |
+
only_show: bool = False
|
1944 |
+
) -> "WhisperResult":
|
1945 |
+
"""
|
1946 |
+
Regroup (in-place) words into segments.
|
1947 |
+
|
1948 |
+
Parameters
|
1949 |
+
----------
|
1950 |
+
regroup_algo: str or bool, default 'da'
|
1951 |
+
String representation of a custom regrouping algorithm or ``True`` to use the default algorithm 'da'.
|
1952 |
+
verbose : bool, default False
|
1953 |
+
Whether to show all the methods and arguments parsed from ``regroup_algo``.
|
1954 |
+
only_show : bool, default False
|
1955 |
+
Whether to show all the methods and arguments parsed from ``regroup_algo`` without running the methods.
|
1956 |
+
|
1957 |
+
Returns
|
1958 |
+
-------
|
1959 |
+
stable_whisper.result.WhisperResult
|
1960 |
+
The current instance after the changes.
|
1961 |
+
|
1962 |
+
Notes
|
1963 |
+
-----
|
1964 |
+
Syntax for string representation of custom regrouping algorithm.
|
1965 |
+
Method keys:
|
1966 |
+
sg: split_by_gap
|
1967 |
+
sp: split_by_punctuation
|
1968 |
+
sl: split_by_length
|
1969 |
+
sd: split_by_duration
|
1970 |
+
mg: merge_by_gap
|
1971 |
+
mp: merge_by_punctuation
|
1972 |
+
ms: merge_all_segments
|
1973 |
+
cm: clamp_max
|
1974 |
+
l: lock
|
1975 |
+
us: unlock_all_segments
|
1976 |
+
da: default algorithm (cm_sp=.* /。/?/?/,* /,_sg=.5_mg=.3+3_sp=.* /。/?/?)
|
1977 |
+
rw: remove_word
|
1978 |
+
rs: remove_segment
|
1979 |
+
rp: remove_repetition
|
1980 |
+
rws: remove_words_by_str
|
1981 |
+
fg: fill_in_gaps
|
1982 |
+
Metacharacters:
|
1983 |
+
= separates a method key and its arguments (not used if no argument)
|
1984 |
+
_ separates method keys (after arguments if there are any)
|
1985 |
+
+ separates arguments for a method key
|
1986 |
+
/ separates an argument into list of strings
|
1987 |
+
* separates an item in list of strings into a nested list of strings
|
1988 |
+
Notes:
|
1989 |
+
-arguments are parsed positionally
|
1990 |
+
-if no argument is provided, the default ones will be used
|
1991 |
+
-use 1 or 0 to represent True or False
|
1992 |
+
Example 1:
|
1993 |
+
merge_by_gap(.2, 10, lock=True)
|
1994 |
+
mg=.2+10+++1
|
1995 |
+
Note: ``lock`` is the 5th argument, hence the three + before the 1 stand in for the 2 missing arguments in between
|
1996 |
+
Example 2:
|
1997 |
+
split_by_punctuation([('.', ' '), '。', '?', '?'], True)
|
1998 |
+
sp=.* /。/?/?+1
|
1999 |
+
Example 3:
|
2000 |
+
merge_all_segments().split_by_gap(.5).merge_by_gap(.15, 3)
|
2001 |
+
ms_sg=.5_mg=.15+3
|
2002 |
+
"""
|
2003 |
+
if regroup_algo is False:
|
2004 |
+
return self
|
2005 |
+
if regroup_algo is None or regroup_algo is True:
|
2006 |
+
regroup_algo = 'da'
|
2007 |
+
|
2008 |
+
for method, kwargs, msg in self.parse_regroup_algo(regroup_algo, include_str=verbose or only_show):
|
2009 |
+
if msg:
|
2010 |
+
print(msg)
|
2011 |
+
if not only_show:
|
2012 |
+
method(**kwargs)
|
2013 |
+
|
2014 |
+
return self
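As a sketch, Example 3 above can be run either as chained calls or as the equivalent algorithm string (assuming ``result`` is an existing ``WhisperResult`` instance):

# chained method calls
result.merge_all_segments().split_by_gap(.5).merge_by_gap(.15, 3)

# the same operations expressed as a regroup string; only_show=True prints
# the parsed methods and their arguments without running them
result.regroup('ms_sg=.5_mg=.15+3', only_show=True)
result.regroup('ms_sg=.5_mg=.15+3')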
|
2015 |
+
|
2016 |
+
def parse_regroup_algo(self, regroup_algo: str, include_str: bool = True) -> List[Tuple[Callable, dict, str]]:
|
2017 |
+
methods = dict(
|
2018 |
+
sg=self.split_by_gap,
|
2019 |
+
sp=self.split_by_punctuation,
|
2020 |
+
sl=self.split_by_length,
|
2021 |
+
sd=self.split_by_duration,
|
2022 |
+
mg=self.merge_by_gap,
|
2023 |
+
mp=self.merge_by_punctuation,
|
2024 |
+
ms=self.merge_all_segments,
|
2025 |
+
cm=self.clamp_max,
|
2026 |
+
us=self.unlock_all_segments,
|
2027 |
+
l=self.lock,
|
2028 |
+
rw=self.remove_word,
|
2029 |
+
rs=self.remove_segment,
|
2030 |
+
rp=self.remove_repetition,
|
2031 |
+
rws=self.remove_words_by_str,
|
2032 |
+
fg=self.fill_in_gaps,
|
2033 |
+
)
|
2034 |
+
if not regroup_algo:
|
2035 |
+
return []
|
2036 |
+
|
2037 |
+
calls = regroup_algo.split('_')
|
2038 |
+
if 'da' in calls:
|
2039 |
+
default_calls = 'cm_sp=.* /。/?/?/,* /,_sg=.5_mg=.3+3_sp=.* /。/?/?'.split('_')
|
2040 |
+
calls = chain.from_iterable(default_calls if method == 'da' else [method] for method in calls)
|
2041 |
+
operations = []
|
2042 |
+
for method in calls:
|
2043 |
+
method, args = method.split('=', maxsplit=1) if '=' in method else (method, '')
|
2044 |
+
if method not in methods:
|
2045 |
+
raise NotImplementedError(f'{method} is not one of the available methods: {tuple(methods.keys())}')
|
2046 |
+
args = [] if len(args) == 0 else list(map(str_to_valid_type, args.split('+')))
|
2047 |
+
kwargs = {k: v for k, v in zip(methods[method].__code__.co_varnames[1:], args) if v is not None}
|
2048 |
+
if include_str:
|
2049 |
+
kwargs_str = ', '.join(f'{k}="{v}"' if isinstance(v, str) else f'{k}={v}' for k, v in kwargs.items())
|
2050 |
+
op_str = f'{methods[method].__name__}({kwargs_str})'
|
2051 |
+
else:
|
2052 |
+
op_str = None
|
2053 |
+
operations.append((methods[method], kwargs, op_str))
|
2054 |
+
|
2055 |
+
return operations
|
2056 |
+
|
2057 |
+
def find(self, pattern: str, word_level=True, flags=None) -> "WhisperResultMatches":
|
2058 |
+
"""
|
2059 |
+
Find segments/words and timestamps with regular expression.
|
2060 |
+
|
2061 |
+
Parameters
|
2062 |
+
----------
|
2063 |
+
pattern : str
|
2064 |
+
RegEx pattern to search for.
|
2065 |
+
word_level : bool, default True
|
2066 |
+
Whether to search at word-level.
|
2067 |
+
flags : optional
|
2068 |
+
RegEx flags.
|
2069 |
+
|
2070 |
+
Returns
|
2071 |
+
-------
|
2072 |
+
stable_whisper.result.WhisperResultMatches
|
2073 |
+
An instance of :class:`stable_whisper.result.WhisperResultMatches` with word/segment that match ``pattern``.
|
2074 |
+
"""
|
2075 |
+
return WhisperResultMatches(self).find(pattern, word_level=word_level, flags=flags)
|
2076 |
+
|
2077 |
+
@property
|
2078 |
+
def text(self):
|
2079 |
+
return ''.join(s.text for s in self.segments)
|
2080 |
+
|
2081 |
+
@property
|
2082 |
+
def regroup_history(self):
|
2083 |
+
# same syntax as ``regroup_algo`` for :meth:`result.WhisperResult.regroup`
|
2084 |
+
return self._regroup_history
|
2085 |
+
|
2086 |
+
@property
|
2087 |
+
def nonspeech_sections(self):
|
2088 |
+
return self._nonspeech_sections
|
2089 |
+
|
2090 |
+
def show_regroup_history(self):
|
2091 |
+
"""
|
2092 |
+
Print details of all regrouping operations that have been performed on the data.
|
2093 |
+
"""
|
2094 |
+
if not self._regroup_history:
|
2095 |
+
print('Result has no history.')
|
2096 |
+
for *_, msg in self.parse_regroup_algo(self._regroup_history):
|
2097 |
+
print(f'.{msg}')
|
2098 |
+
|
2099 |
+
def __len__(self):
|
2100 |
+
return len(self.segments)
|
2101 |
+
|
2102 |
+
def unlock_all_segments(self):
|
2103 |
+
for s in self.segments:
|
2104 |
+
s.unlock_all_words()
|
2105 |
+
return self
|
2106 |
+
|
2107 |
+
def reset(self):
|
2108 |
+
"""
|
2109 |
+
Restore all values to those at initialization.
|
2110 |
+
"""
|
2111 |
+
self.language = self.ori_dict.get('language')
|
2112 |
+
self._regroup_history = ''
|
2113 |
+
segments = self.ori_dict.get('segments')
|
2114 |
+
self.segments: List[Segment] = [Segment(**s) for s in segments] if segments else []
|
2115 |
+
if self._forced_order:
|
2116 |
+
self.force_order()
|
2117 |
+
self.remove_no_word_segments(any(seg.has_words for seg in self.segments))
|
2118 |
+
self.update_all_segs_with_words()
|
2119 |
+
|
2120 |
+
@property
|
2121 |
+
def has_words(self):
|
2122 |
+
return all(seg.has_words for seg in self.segments)
|
2123 |
+
|
2124 |
+
to_srt_vtt = result_to_srt_vtt
|
2125 |
+
to_ass = result_to_ass
|
2126 |
+
to_tsv = result_to_tsv
|
2127 |
+
to_txt = result_to_txt
|
2128 |
+
save_as_json = save_as_json
|
2129 |
+
|
2130 |
+
|
2131 |
+
class SegmentMatch:
|
2132 |
+
|
2133 |
+
def __init__(
|
2134 |
+
self,
|
2135 |
+
segments: Union[List[Segment], Segment],
|
2136 |
+
_word_indices: List[List[int]] = None,
|
2137 |
+
_text_match: str = None
|
2138 |
+
):
|
2139 |
+
self.segments = [segments] if isinstance(segments, Segment) else segments
|
2140 |
+
self.word_indices = [] if _word_indices is None else _word_indices
|
2141 |
+
self.words = [self.segments[i].words[j] for i, indices in enumerate(self.word_indices) for j in indices]
|
2142 |
+
if len(self.words) != 0:
|
2143 |
+
self.text = ''.join(
|
2144 |
+
self.segments[i].words[j].word
|
2145 |
+
for i, indices in enumerate(self.word_indices)
|
2146 |
+
for j in indices
|
2147 |
+
)
|
2148 |
+
else:
|
2149 |
+
self.text = ''.join(seg.text for seg in self.segments)
|
2150 |
+
self.text_match = _text_match
|
2151 |
+
|
2152 |
+
@property
|
2153 |
+
def start(self):
|
2154 |
+
return (
|
2155 |
+
self.words[0].start
|
2156 |
+
if len(self.words) != 0 else
|
2157 |
+
(self.segments[0].start if len(self.segments) != 0 else None)
|
2158 |
+
)
|
2159 |
+
|
2160 |
+
@property
|
2161 |
+
def end(self):
|
2162 |
+
return (
|
2163 |
+
self.words[-1].end
|
2164 |
+
if len(self.words) != 0 else
|
2165 |
+
(self.segments[-1].end if len(self.segments) != 0 else None)
|
2166 |
+
)
|
2167 |
+
|
2168 |
+
def __len__(self):
|
2169 |
+
return len(self.segments)
|
2170 |
+
|
2171 |
+
def __repr__(self):
|
2172 |
+
return self.__dict__.__repr__()
|
2173 |
+
|
2174 |
+
def __str__(self):
|
2175 |
+
return self.__dict__.__str__()
|
2176 |
+
|
2177 |
+
|
2178 |
+
class WhisperResultMatches:
|
2179 |
+
"""
|
2180 |
+
RegEx matches for WhisperResults.
|
2181 |
+
"""
|
2182 |
+
# Use WhisperResult.find() instead of instantiating this class directly.
|
2183 |
+
def __init__(
|
2184 |
+
self,
|
2185 |
+
matches: Union[List[SegmentMatch], WhisperResult],
|
2186 |
+
_segment_indices: List[List[int]] = None
|
2187 |
+
):
|
2188 |
+
if isinstance(matches, WhisperResult):
|
2189 |
+
self.matches = list(map(SegmentMatch, matches.segments))
|
2190 |
+
self._segment_indices = [[i] for i in range(len(matches.segments))]
|
2191 |
+
else:
|
2192 |
+
self.matches = matches
|
2193 |
+
assert _segment_indices is not None
|
2194 |
+
assert len(self.matches) == len(_segment_indices)
|
2195 |
+
assert all(len(match.segments) == len(_segment_indices[i]) for i, match in enumerate(self.matches))
|
2196 |
+
self._segment_indices = _segment_indices
|
2197 |
+
|
2198 |
+
@property
|
2199 |
+
def segment_indices(self):
|
2200 |
+
return self._segment_indices
|
2201 |
+
|
2202 |
+
def _curr_seg_groups(self) -> List[List[Tuple[int, Segment]]]:
|
2203 |
+
seg_groups, curr_segs = [], []
|
2204 |
+
curr_max = -1
|
2205 |
+
for seg_indices, match in zip(self._segment_indices, self.matches):
|
2206 |
+
for i, seg in zip(sorted(seg_indices), match.segments):
|
2207 |
+
if i > curr_max:
|
2208 |
+
curr_segs.append((i, seg))
|
2209 |
+
if i - 1 != curr_max:
|
2210 |
+
seg_groups.append(curr_segs)
|
2211 |
+
curr_segs = []
|
2212 |
+
curr_max = i
|
2213 |
+
|
2214 |
+
if curr_segs:
|
2215 |
+
seg_groups.append(curr_segs)
|
2216 |
+
return seg_groups
|
2217 |
+
|
2218 |
+
def find(self, pattern: str, word_level=True, flags=None) -> "WhisperResultMatches":
|
2219 |
+
"""
|
2220 |
+
Find segments/words and timestamps with regular expression.
|
2221 |
+
|
2222 |
+
Parameters
|
2223 |
+
----------
|
2224 |
+
pattern : str
|
2225 |
+
RegEx pattern to search for.
|
2226 |
+
word_level : bool, default True
|
2227 |
+
Whether to search at word-level.
|
2228 |
+
flags : optional
|
2229 |
+
RegEx flags.
|
2230 |
+
|
2231 |
+
Returns
|
2232 |
+
-------
|
2233 |
+
stable_whisper.result.WhisperResultMatches
|
2234 |
+
An instance of :class:`stable_whisper.result.WhisperResultMatches` with word/segment that match ``pattern``.
|
2235 |
+
"""
|
2236 |
+
|
2237 |
+
seg_groups = self._curr_seg_groups()
|
2238 |
+
matches: List[SegmentMatch] = []
|
2239 |
+
match_seg_indices: List[List[int]] = []
|
2240 |
+
if word_level:
|
2241 |
+
if not all(all(seg.has_words for seg in match.segments) for match in self.matches):
|
2242 |
+
warnings.warn('Cannot perform word-level search with segment(s) missing word timestamps.')
|
2243 |
+
word_level = False
|
2244 |
+
|
2245 |
+
for segs in seg_groups:
|
2246 |
+
if word_level:
|
2247 |
+
idxs = list(chain.from_iterable(
|
2248 |
+
[(i, j)]*len(word.word) for (i, seg) in segs for j, word in enumerate(seg.words)
|
2249 |
+
))
|
2250 |
+
text = ''.join(word.word for (_, seg) in segs for word in seg.words)
|
2251 |
+
else:
|
2252 |
+
idxs = list(chain.from_iterable([(i, None)]*len(seg.text) for (i, seg) in segs))
|
2253 |
+
text = ''.join(seg.text for (_, seg) in segs)
|
2254 |
+
assert len(idxs) == len(text)
|
2255 |
+
for curr_match in re.finditer(pattern, text, flags=flags or 0):
|
2256 |
+
start, end = curr_match.span()
|
2257 |
+
curr_idxs = idxs[start: end]
|
2258 |
+
curr_seg_idxs = sorted(set(i[0] for i in curr_idxs))
|
2259 |
+
if word_level:
|
2260 |
+
curr_word_idxs = [
|
2261 |
+
sorted(set(j for i, j in curr_idxs if i == seg_idx))
|
2262 |
+
for seg_idx in curr_seg_idxs
|
2263 |
+
]
|
2264 |
+
else:
|
2265 |
+
curr_word_idxs = None
|
2266 |
+
matches.append(SegmentMatch(
|
2267 |
+
segments=[s for i, s in segs if i in curr_seg_idxs],
|
2268 |
+
_word_indices=curr_word_idxs,
|
2269 |
+
_text_match=curr_match.group()
|
2270 |
+
))
|
2271 |
+
match_seg_indices.append(curr_seg_idxs)
|
2272 |
+
return WhisperResultMatches(matches, match_seg_indices)
|
2273 |
+
|
2274 |
+
def __len__(self):
|
2275 |
+
return len(self.matches)
|
2276 |
+
|
2277 |
+
def __bool__(self):
|
2278 |
+
return self.__len__() != 0
|
2279 |
+
|
2280 |
+
def __getitem__(self, idx):
|
2281 |
+
return self.matches[idx]
|
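Putting the pieces above together, a hedged end-to-end sketch of searching a result with ``find`` (the pattern and the JSON path are illustrative):

from stable_whisper.result import WhisperResult

result = WhisperResult('audio.json')  # hypothetical saved result
matches = result.find(r'\bhello\b', word_level=True)
for m in matches.matches:
    # each SegmentMatch carries the matched text and its word-level timestamps
    print(m.start, m.end, m.text_match)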
stable_whisper/stabilization.py
ADDED
@@ -0,0 +1,424 @@
1 |
+
import warnings
|
2 |
+
from typing import List, Union, Tuple, Optional
|
3 |
+
from itertools import chain
|
4 |
+
|
5 |
+
import torch
|
6 |
+
import torch.nn.functional as F
|
7 |
+
import numpy as np
|
8 |
+
|
9 |
+
from whisper.audio import TOKENS_PER_SECOND, SAMPLE_RATE, N_SAMPLES_PER_TOKEN
|
10 |
+
|
11 |
+
|
12 |
+
NONVAD_SAMPLE_RATES = (16000,)
|
13 |
+
VAD_SAMPLE_RATES = (16000, 8000)
|
14 |
+
|
15 |
+
|
16 |
+
def is_ascending_sequence(
|
17 |
+
seq: List[Union[int, float]],
|
18 |
+
verbose=True
|
19 |
+
) -> bool:
|
20 |
+
"""
|
21 |
+
Check if a sequence of numbers is in ascending order.
|
22 |
+
"""
|
23 |
+
is_ascending = True
|
24 |
+
for idx, (i, j) in enumerate(zip(seq[:-1], seq[1:])):
|
25 |
+
if i > j:
|
26 |
+
is_ascending = False
|
27 |
+
if verbose:
|
28 |
+
print(f'[Index{idx}]:{i} > [Index{idx + 1}]:{j}')
|
29 |
+
else:
|
30 |
+
break
|
31 |
+
|
32 |
+
return is_ascending
|
33 |
+
|
34 |
+
|
35 |
+
def valid_ts(
|
36 |
+
ts: List[dict],
|
37 |
+
warn=True
|
38 |
+
) -> bool:
|
39 |
+
valid = is_ascending_sequence(list(chain.from_iterable([s['start'], s['end']] for s in ts)), False)
|
40 |
+
if warn and not valid:
|
41 |
+
warnings.warn(message='Found timestamp(s) jumping backwards in time. '
|
42 |
+
'Use word_timestamps=True to avoid the issue.')
|
43 |
+
return valid
|
44 |
+
|
45 |
+
|
46 |
+
def mask2timing(
|
47 |
+
silence_mask: (np.ndarray, torch.Tensor),
|
48 |
+
time_offset: float = 0.0,
|
49 |
+
) -> (Tuple[np.ndarray, np.ndarray], None):
|
50 |
+
if silence_mask is None or not silence_mask.any():
|
51 |
+
return
|
52 |
+
assert silence_mask.ndim == 1
|
53 |
+
if isinstance(silence_mask, torch.Tensor):
|
54 |
+
silences = silence_mask.cpu().numpy().copy()
|
55 |
+
elif isinstance(silence_mask, np.ndarray):
|
56 |
+
silences = silence_mask.copy()
|
57 |
+
else:
|
58 |
+
raise NotImplementedError(f'Expected torch.Tensor or numpy.ndarray, but got {type(silence_mask)}')
|
59 |
+
silences[0] = False
|
60 |
+
silences[-1] = False
|
61 |
+
silent_starts = np.logical_and(~silences[:-1], silences[1:]).nonzero()[0] / TOKENS_PER_SECOND
|
62 |
+
silent_ends = (np.logical_and(silences[:-1], ~silences[1:]).nonzero()[0] + 1) / TOKENS_PER_SECOND
|
63 |
+
if time_offset:
|
64 |
+
silent_starts += time_offset
|
65 |
+
silent_ends += time_offset
|
66 |
+
return silent_starts, silent_ends
|
67 |
+
|
68 |
+
|
69 |
+
def timing2mask(
|
70 |
+
silent_starts: np.ndarray,
|
71 |
+
silent_ends: np.ndarray,
|
72 |
+
size: int,
|
73 |
+
time_offset: float = None
|
74 |
+
) -> torch.Tensor:
|
75 |
+
assert len(silent_starts) == len(silent_ends)
|
76 |
+
ts_token_mask = torch.zeros(size, dtype=torch.bool)
|
77 |
+
if time_offset:
|
78 |
+
silent_starts = (silent_starts - time_offset).clip(min=0)
|
79 |
+
silent_ends = (silent_ends - time_offset).clip(min=0)
|
80 |
+
mask_i = (silent_starts * TOKENS_PER_SECOND).round().astype(np.int16)
|
81 |
+
mask_e = (silent_ends * TOKENS_PER_SECOND).round().astype(np.int16)
|
82 |
+
for mi, me in zip(mask_i, mask_e):
|
83 |
+
ts_token_mask[mi:me+1] = True
|
84 |
+
|
85 |
+
return ts_token_mask
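A small sketch of how these two helpers round-trip between second-based timings and the token-level mask (Whisper uses 50 timestamp tokens per second, so recovered values agree to within one token, i.e. 20 ms; the timings below are made up):

import numpy as np
from stable_whisper.stabilization import timing2mask, mask2timing

silent_starts = np.array([0.5, 2.0])   # hypothetical silent sections of a 5 s clip
silent_ends = np.array([1.0, 2.5])
mask = timing2mask(silent_starts, silent_ends, size=5 * 50)  # 5 s * 50 tokens/s
recovered_starts, recovered_ends = mask2timing(mask)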
|
86 |
+
|
87 |
+
|
88 |
+
def suppress_silence(
|
89 |
+
result_obj,
|
90 |
+
silent_starts: Union[np.ndarray, List[float]],
|
91 |
+
silent_ends: Union[np.ndarray, List[float]],
|
92 |
+
min_word_dur: float,
|
93 |
+
nonspeech_error: float = 0.3,
|
94 |
+
keep_end: Optional[bool] = True
|
95 |
+
):
|
96 |
+
assert len(silent_starts) == len(silent_ends)
|
97 |
+
if len(silent_starts) == 0 or (result_obj.end - result_obj.start) <= min_word_dur:
|
98 |
+
return
|
99 |
+
if isinstance(silent_starts, list):
|
100 |
+
silent_starts = np.array(silent_starts)
|
101 |
+
if isinstance(silent_ends, list):
|
102 |
+
silent_ends = np.array(silent_ends)
|
103 |
+
|
104 |
+
start_overlaps = np.all(
|
105 |
+
(silent_starts <= result_obj.start, result_obj.start < silent_ends, silent_ends <= result_obj.end),
|
106 |
+
axis=0
|
107 |
+
).nonzero()[0].tolist()
|
108 |
+
if start_overlaps:
|
109 |
+
new_start = silent_ends[start_overlaps[0]]
|
110 |
+
result_obj.start = min(new_start, round(result_obj.end - min_word_dur, 3))
|
111 |
+
if (result_obj.end - result_obj.start) <= min_word_dur:
|
112 |
+
return
|
113 |
+
|
114 |
+
end_overlaps = np.all(
|
115 |
+
(result_obj.start <= silent_starts, silent_starts < result_obj.end, result_obj.end <= silent_ends),
|
116 |
+
axis=0
|
117 |
+
).nonzero()[0].tolist()
|
118 |
+
if end_overlaps:
|
119 |
+
new_end = silent_starts[end_overlaps[0]]
|
120 |
+
result_obj.end = max(new_end, round(result_obj.start + min_word_dur, 3))
|
121 |
+
if (result_obj.end - result_obj.start) <= min_word_dur:
|
122 |
+
return
|
123 |
+
|
124 |
+
if nonspeech_error:
|
125 |
+
matches = np.logical_and(
|
126 |
+
result_obj.start <= silent_starts,
|
127 |
+
result_obj.end >= silent_ends,
|
128 |
+
).nonzero()[0].tolist()
|
129 |
+
if len(matches) == 0:
|
130 |
+
return
|
131 |
+
silence_start = np.min(silent_starts[matches])
|
132 |
+
silence_end = np.max(silent_ends[matches])
|
133 |
+
start_extra = silence_start - result_obj.start
|
134 |
+
end_extra = result_obj.end - silence_end
|
135 |
+
silent_duration = silence_end - silence_start
|
136 |
+
start_within_error = (start_extra / silent_duration) <= nonspeech_error
|
137 |
+
end_within_error = (end_extra / silent_duration) <= nonspeech_error
|
138 |
+
if keep_end is None:
|
139 |
+
keep_end = start_extra <= end_extra
|
140 |
+
within_error = start_within_error if keep_end else end_within_error
|
141 |
+
else:
|
142 |
+
within_error = start_within_error or end_within_error
|
143 |
+
|
144 |
+
if within_error:
|
145 |
+
if keep_end:
|
146 |
+
result_obj.start = min(silence_end, round(result_obj.end - min_word_dur, 3))
|
147 |
+
else:
|
148 |
+
result_obj.end = max(silence_start, round(result_obj.start + min_word_dur, 3))
|
149 |
+
|
150 |
+
|
151 |
+
def standardize_audio(
|
152 |
+
audio: Union[torch.Tensor, np.ndarray, str, bytes],
|
153 |
+
resample_sr: Tuple[Optional[int], Union[int, Tuple[int]]] = None
|
154 |
+
) -> torch.Tensor:
|
155 |
+
if isinstance(audio, (str, bytes)):
|
156 |
+
from .audio import load_audio
|
157 |
+
audio = load_audio(audio)
|
158 |
+
if isinstance(audio, np.ndarray):
|
159 |
+
audio = torch.from_numpy(audio)
|
160 |
+
audio = audio.float()
|
161 |
+
if resample_sr:
|
162 |
+
in_sr, out_sr = resample_sr
|
163 |
+
if in_sr:
|
164 |
+
if isinstance(out_sr, int):
|
165 |
+
out_sr = [out_sr]
|
166 |
+
if in_sr not in out_sr:
|
167 |
+
from torchaudio.functional import resample
|
168 |
+
audio = resample(audio, in_sr, out_sr[0])
|
169 |
+
|
170 |
+
return audio
|
171 |
+
|
172 |
+
|
173 |
+
def audio2loudness(
|
174 |
+
audio_tensor: torch.Tensor
|
175 |
+
) -> (torch.Tensor, None):
|
176 |
+
assert audio_tensor.dim() == 1, f'waveform must be 1D, but got {audio_tensor.dim()}D'
|
177 |
+
audio_tensor = audio_tensor.abs()
|
178 |
+
k = int(audio_tensor.numel() * 0.001)
|
179 |
+
if k:
|
180 |
+
top_values, _ = torch.topk(audio_tensor, k)
|
181 |
+
threshold = top_values[-1]
|
182 |
+
else:
|
183 |
+
threshold = audio_tensor.quantile(0.999, dim=-1)
|
184 |
+
if (token_count := round(audio_tensor.shape[-1] / N_SAMPLES_PER_TOKEN)+1) > 2:
|
185 |
+
if threshold < 1e-5:
|
186 |
+
return torch.zeros(token_count, dtype=audio_tensor.dtype, device=audio_tensor.device)
|
187 |
+
audio_tensor = audio_tensor / min(1., threshold * 1.75)
|
188 |
+
audio_tensor = F.interpolate(
|
189 |
+
audio_tensor[None, None],
|
190 |
+
size=token_count,
|
191 |
+
mode='linear',
|
192 |
+
align_corners=False
|
193 |
+
)[0, 0]
|
194 |
+
return audio_tensor
|
195 |
+
|
196 |
+
|
197 |
+
def visualize_mask(
|
198 |
+
loudness_tensor: torch.Tensor,
|
199 |
+
silence_mask: torch.Tensor = None,
|
200 |
+
width: int = 1500,
|
201 |
+
height: int = 200,
|
202 |
+
output: str = None,
|
203 |
+
):
|
204 |
+
no_silence = silence_mask is None or not silence_mask.any()
|
205 |
+
assert no_silence or silence_mask.shape[0] == loudness_tensor.shape[0]
|
206 |
+
if loudness_tensor.shape[0] < 2:
|
207 |
+
raise NotImplementedError(f'audio size, {loudness_tensor.shape[0]}, is too short to visualize')
|
208 |
+
else:
|
209 |
+
width = loudness_tensor.shape[0] if width == -1 else width
|
210 |
+
im = torch.zeros((height, width, 3), dtype=torch.uint8)
|
211 |
+
mid = round(height / 2)
|
212 |
+
for i, j in enumerate(loudness_tensor.tolist()):
|
213 |
+
j = round(abs(j) * mid)
|
214 |
+
if j == 0 or width <= i:
|
215 |
+
continue
|
216 |
+
im[mid - j:mid + 1, i] = 255
|
217 |
+
im[mid + 1:mid + j + 1, i] = 255
|
218 |
+
if not no_silence:
|
219 |
+
im[:, silence_mask[:width], 1:] = 0
|
220 |
+
im = im.cpu().numpy()
|
221 |
+
if output and not output.endswith('.png'):
|
222 |
+
output += '.png'
|
223 |
+
try:
|
224 |
+
from PIL import Image
|
225 |
+
except ModuleNotFoundError:
|
226 |
+
try:
|
227 |
+
import cv2
|
228 |
+
except ModuleNotFoundError:
|
229 |
+
raise ModuleNotFoundError('Failed to import "PIL" or "cv2" to visualize suppression mask. '
|
230 |
+
'Try "pip install Pillow" or "pip install opencv-python"')
|
231 |
+
else:
|
232 |
+
im = im[..., [2, 1, 0]]
|
233 |
+
if isinstance(output, str):
|
234 |
+
cv2.imwrite(output, im)
|
235 |
+
else:
|
236 |
+
cv2.imshow('image', im)
|
237 |
+
cv2.waitKey(0)
|
238 |
+
else:
|
239 |
+
im = Image.fromarray(im)
|
240 |
+
if isinstance(output, str):
|
241 |
+
im.save(output)
|
242 |
+
else:
|
243 |
+
im.show(im)
|
244 |
+
if output:
|
245 |
+
print(f'Save: {output}')
|
246 |
+
|
247 |
+
|
248 |
+
def wav2mask(
|
249 |
+
audio: (torch.Tensor, np.ndarray, str, bytes),
|
250 |
+
q_levels: int = 20,
|
251 |
+
k_size: int = 5,
|
252 |
+
sr: int = None
|
253 |
+
) -> (Tuple[torch.Tensor, Tuple[np.ndarray, np.ndarray]], None):
|
254 |
+
"""
|
255 |
+
Generate 1D mask from waveform for suppressing timestamp tokens.
|
256 |
+
"""
|
257 |
+
audio = standardize_audio(audio, (sr, NONVAD_SAMPLE_RATES))
|
258 |
+
loudness_tensor = audio2loudness(audio)
|
259 |
+
if loudness_tensor is None:
|
260 |
+
return
|
261 |
+
p = k_size // 2 if k_size else 0
|
262 |
+
if p and p < loudness_tensor.shape[-1]:
|
263 |
+
assert k_size % 2, f'kernel_size must be odd but got {k_size}'
|
264 |
+
mask = torch.avg_pool1d(
|
265 |
+
F.pad(
|
266 |
+
loudness_tensor[None],
|
267 |
+
(p, p),
|
268 |
+
'reflect'
|
269 |
+
),
|
270 |
+
kernel_size=k_size,
|
271 |
+
stride=1
|
272 |
+
)[0]
|
273 |
+
else:
|
274 |
+
mask = loudness_tensor.clone()
|
275 |
+
|
276 |
+
if q_levels:
|
277 |
+
mask = mask.mul(q_levels).round()
|
278 |
+
|
279 |
+
mask = mask.bool()
|
280 |
+
|
281 |
+
if not mask.any(): # entirely silent
|
282 |
+
return ~mask
|
283 |
+
temp_timings = mask2timing(mask)
|
284 |
+
s, e = temp_timings
|
285 |
+
se_mask = (e - s) > 0.1
|
286 |
+
s = s[se_mask]
|
287 |
+
e = e[se_mask]
|
288 |
+
mask = ~timing2mask(s, e, loudness_tensor.shape[-1])
|
289 |
+
|
290 |
+
if not mask.any(): # no silence
|
291 |
+
return
|
292 |
+
|
293 |
+
return mask
|
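A hedged usage sketch for wav2mask: it accepts a path, bytes, NumPy array, or Torch tensor and returns a boolean mask (one value per timestamp token) where True marks silence, or None when no silence is found. The import path and the assumption that a 16 kHz tensor needs no resampling are mine, not part of this diff.

import torch
from stable_whisper.stabilization import wav2mask

sr = 16000
noise = torch.randn(sr) * 0.5        # ~1 s of loud noise
silence = torch.zeros(sr)            # ~1 s of silence
mask = wav2mask(torch.cat([noise, silence, noise]), sr=sr)
if mask is not None:                 # None means no silence was detected
    print(mask.shape, int(mask.sum()), 'token positions marked silent')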
294 |
+
|
295 |
+
|
296 |
+
_model_cache = {}
|
297 |
+
|
298 |
+
|
299 |
+
def get_vad_silence_func(
|
300 |
+
onnx=False,
|
301 |
+
verbose: (bool, None) = False
|
302 |
+
):
|
303 |
+
if onnx in _model_cache:
|
304 |
+
model, get_ts = _model_cache[onnx]
|
305 |
+
else:
|
306 |
+
model, utils = torch.hub.load(repo_or_dir='snakers4/silero-vad:master',
|
307 |
+
model='silero_vad',
|
308 |
+
verbose=verbose,
|
309 |
+
onnx=onnx,
|
310 |
+
trust_repo=True)
|
311 |
+
get_ts = utils[0]
|
312 |
+
_model_cache[onnx] = (model, get_ts)
|
313 |
+
|
314 |
+
warnings.filterwarnings('ignore', message=r'operator \(\) profile_node.*', category=UserWarning)
|
315 |
+
|
316 |
+
def get_speech_timestamps(wav: torch.Tensor, threshold: float = .35):
|
317 |
+
return get_ts(wav, model, threshold, min_speech_duration_ms=100, min_silence_duration_ms=20)
|
318 |
+
|
319 |
+
def vad_silence_timing(
|
320 |
+
audio: (torch.Tensor, np.ndarray, str, bytes),
|
321 |
+
speech_threshold: float = .35,
|
322 |
+
sr: int = None
|
323 |
+
) -> (Tuple[np.ndarray, np.ndarray], None):
|
324 |
+
|
325 |
+
audio = standardize_audio(audio, (sr, VAD_SAMPLE_RATES))
|
326 |
+
|
327 |
+
total_duration = round(audio.shape[-1] / SAMPLE_RATE, 3)
|
328 |
+
if not total_duration:
|
329 |
+
return
|
330 |
+
ori_t = torch.get_num_threads()
|
331 |
+
if verbose is not None:
|
332 |
+
print('Predicting silence(s) with VAD...\r', end='')
|
333 |
+
torch.set_num_threads(1)  # VAD was optimized for single-threaded performance
|
334 |
+
speech_ts = get_speech_timestamps(audio, speech_threshold)
|
335 |
+
if verbose is not None:
|
336 |
+
print('Predicted silence(s) with VAD. ')
|
337 |
+
torch.set_num_threads(ori_t)
|
338 |
+
if len(speech_ts) == 0: # all silent
|
339 |
+
return np.array([0.0]), np.array([total_duration])
|
340 |
+
silent_starts = []
|
341 |
+
silent_ends = []
|
342 |
+
for ts in speech_ts:
|
343 |
+
start = round(ts['start'] / SAMPLE_RATE, 3)
|
344 |
+
end = round(ts['end'] / SAMPLE_RATE, 3)
|
345 |
+
if start != 0:
|
346 |
+
silent_ends.append(start)
|
347 |
+
if len(silent_starts) == 0:
|
348 |
+
silent_starts.append(0.0)
|
349 |
+
if end < total_duration:
|
350 |
+
silent_starts.append(end)
|
351 |
+
|
352 |
+
if len(silent_starts) == 0 and len(silent_ends) == 0:
|
353 |
+
return
|
354 |
+
|
355 |
+
if len(silent_starts) != 0 and (len(silent_ends) == 0 or silent_ends[-1] < silent_starts[-1]):
|
356 |
+
silent_ends.append(total_duration)
|
357 |
+
|
358 |
+
silent_starts = np.array(silent_starts)
|
359 |
+
silent_ends = np.array(silent_ends)
|
360 |
+
|
361 |
+
return silent_starts, silent_ends
|
362 |
+
|
363 |
+
return vad_silence_timing
|
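A hedged usage sketch for get_vad_silence_func: it loads (and caches) the Silero VAD model via torch.hub and returns a callable that maps audio to paired arrays of silent-start and silent-end times in seconds; 'audio.mp3' below is a placeholder path.

from stable_whisper.stabilization import get_vad_silence_func

vad_silence_timing = get_vad_silence_func(onnx=False, verbose=None)
timings = vad_silence_timing('audio.mp3', speech_threshold=0.35)
if timings is not None:
    silent_starts, silent_ends = timings
    print(list(zip(silent_starts.tolist(), silent_ends.tolist())))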
364 |
+
|
365 |
+
|
366 |
+
def visualize_suppression(
|
367 |
+
audio: Union[torch.Tensor, np.ndarray, str, bytes],
|
368 |
+
output: str = None,
|
369 |
+
q_levels: int = 20,
|
370 |
+
k_size: int = 5,
|
371 |
+
vad_threshold: float = 0.35,
|
372 |
+
vad: bool = False,
|
373 |
+
max_width: int = 1500,
|
374 |
+
height: int = 200
|
375 |
+
):
|
376 |
+
"""
|
377 |
+
Visualize regions on the waveform of ``audio`` detected as silent.
|
378 |
+
|
379 |
+
Regions on the waveform colored red are detected as silent.
|
380 |
+
|
381 |
+
Parameters
|
382 |
+
----------
|
383 |
+
audio : str or numpy.ndarray or torch.Tensor or bytes
|
384 |
+
Path/URL to the audio file, the audio waveform, or bytes of audio file.
|
385 |
+
If audio is ``numpy.ndarray`` or ``torch.Tensor``, the audio must already be sampled at 16kHz.
|
386 |
+
output : str, default None, meaning image will be shown directly via Pillow or opencv-python
|
387 |
+
Path to save visualization.
|
388 |
+
q_levels : int, default 20
|
389 |
+
Quantization levels for generating timestamp suppression mask; ignored if ``vad = True``.
|
390 |
+
Acts as a threshold for marking sound as silent.
|
391 |
+
Fewer levels will increase the threshold of volume at which to mark a sound as silent.
|
392 |
+
k_size : int, default 5
|
393 |
+
Kernel size for avg-pooling waveform to generate timestamp suppression mask; ignored if ``vad = True``.
|
394 |
+
Recommend 5 or 3; higher sizes will reduce detection of silence.
|
395 |
+
vad : bool, default False
|
396 |
+
Whether to use Silero VAD to generate timestamp suppression mask.
|
397 |
+
Silero VAD requires PyTorch 1.12.0+. Official repo, https://github.com/snakers4/silero-vad.
|
398 |
+
vad_threshold : float, default 0.35
|
399 |
+
Threshold for detecting speech with Silero VAD. Low threshold reduces false positives for silence detection.
|
400 |
+
max_width : int, default 1500
|
401 |
+
Maximum width of visualization to avoid overly large image from long audio.
|
402 |
+
Each pixel of width is equivalent to 1 token. Use -1 to visualize the entire audio track.
|
403 |
+
height : int, default 200
|
404 |
+
Height of visualization.
|
405 |
+
"""
|
406 |
+
max_n_samples = None if max_width == -1 else round(max_width * N_SAMPLES_PER_TOKEN)
|
407 |
+
|
408 |
+
audio = standardize_audio(audio)
|
409 |
+
if max_n_samples is None:
|
410 |
+
max_width = audio.shape[-1]
|
411 |
+
else:
|
412 |
+
audio = audio[:max_n_samples]
|
413 |
+
loudness_tensor = audio2loudness(audio)
|
414 |
+
if loudness_tensor is None:
|
415 |
+
raise NotImplementedError('Audio is too short and cannot be visualized.')
|
416 |
+
width = min(max_width, loudness_tensor.shape[-1])
|
417 |
+
|
418 |
+
if vad:
|
419 |
+
silence_timings = get_vad_silence_func()(audio, vad_threshold)
|
420 |
+
silence_mask = None if silence_timings is None else timing2mask(*silence_timings, size=loudness_tensor.shape[0])
|
421 |
+
else:
|
422 |
+
silence_mask = wav2mask(audio, q_levels=q_levels, k_size=k_size)
|
423 |
+
|
424 |
+
visualize_mask(loudness_tensor, silence_mask, width=width, height=height, output=output)
|
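A hedged usage sketch following the docstring above; the audio path and output filenames are placeholders.

from stable_whisper.stabilization import visualize_suppression

# mask from quantized, avg-pooled loudness (the default, non-VAD path):
visualize_suppression('audio.mp3', output='silence_suppression.png', q_levels=20, k_size=5)
# same audio, but let Silero VAD decide what counts as silence:
visualize_suppression('audio.mp3', output='silence_suppression_vad.png', vad=True, vad_threshold=0.35)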
stable_whisper/text_output.py
ADDED
@@ -0,0 +1,620 @@
1 |
+
import json
|
2 |
+
import os
|
3 |
+
import warnings
|
4 |
+
from typing import List, Tuple, Union, Callable
|
5 |
+
from itertools import chain
|
6 |
+
from .stabilization import valid_ts
|
7 |
+
|
8 |
+
__all__ = ['result_to_srt_vtt', 'result_to_ass', 'result_to_tsv', 'result_to_txt', 'save_as_json', 'load_result']
|
9 |
+
SUPPORTED_FORMATS = ('srt', 'vtt', 'ass', 'tsv', 'txt')
|
10 |
+
|
11 |
+
|
12 |
+
def _save_as_file(content: str, path: str):
|
13 |
+
with open(path, 'w', encoding='utf-8') as f:
|
14 |
+
f.write(content)
|
15 |
+
print(f'Saved: {os.path.abspath(path)}')
|
16 |
+
|
17 |
+
|
18 |
+
def _get_segments(result: (dict, list), min_dur: float, reverse_text: Union[bool, tuple] = False):
|
19 |
+
if isinstance(result, dict):
|
20 |
+
if reverse_text:
|
21 |
+
warnings.warn(f'[reverse_text]=True only applies to WhisperResult but result is {type(result)}')
|
22 |
+
return result.get('segments')
|
23 |
+
elif not isinstance(result, list) and callable(getattr(result, 'segments_to_dicts', None)):
|
24 |
+
return result.apply_min_dur(min_dur, inplace=False).segments_to_dicts(reverse_text=reverse_text)
|
25 |
+
return result
|
26 |
+
|
27 |
+
|
28 |
+
def finalize_text(text: str, strip: bool = True):
|
29 |
+
if not strip:
|
30 |
+
return text
|
31 |
+
return text.strip().replace('\n ', '\n')
|
32 |
+
|
33 |
+
|
34 |
+
def sec2hhmmss(seconds: (float, int)):
|
35 |
+
mm, ss = divmod(seconds, 60)
|
36 |
+
hh, mm = divmod(mm, 60)
|
37 |
+
return hh, mm, ss
|
38 |
+
|
39 |
+
|
40 |
+
def sec2milliseconds(seconds: (float, int)) -> int:
|
41 |
+
return round(seconds * 1000)
|
42 |
+
|
43 |
+
|
44 |
+
def sec2centiseconds(seconds: (float, int)) -> int:
|
45 |
+
return round(seconds * 100)
|
46 |
+
|
47 |
+
|
48 |
+
def sec2vtt(seconds: (float, int)) -> str:
|
49 |
+
hh, mm, ss = sec2hhmmss(seconds)
|
50 |
+
return f'{hh:0>2.0f}:{mm:0>2.0f}:{ss:0>6.3f}'
|
51 |
+
|
52 |
+
|
53 |
+
def sec2srt(seconds: (float, int)) -> str:
|
54 |
+
return sec2vtt(seconds).replace(".", ",")
|
55 |
+
|
56 |
+
|
57 |
+
def sec2ass(seconds: (float, int)) -> str:
|
58 |
+
hh, mm, ss = sec2hhmmss(seconds)
|
59 |
+
return f'{hh:0>1.0f}:{mm:0>2.0f}:{ss:0>2.2f}'
|
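A small worked check of the timestamp helpers above, with values computed from the code as written:

assert sec2vtt(3661.5) == '01:01:01.500'
assert sec2srt(3661.5) == '01:01:01,500'   # SRT uses ',' as the decimal marker
assert sec2milliseconds(3661.5) == 3661500
assert sec2centiseconds(3661.5) == 366150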
60 |
+
|
61 |
+
|
62 |
+
def segment2vttblock(segment: dict, strip=True) -> str:
|
63 |
+
return f'{sec2vtt(segment["start"])} --> {sec2vtt(segment["end"])}\n' \
|
64 |
+
f'{finalize_text(segment["text"], strip)}'
|
65 |
+
|
66 |
+
|
67 |
+
def segment2srtblock(segment: dict, idx: int, strip=True) -> str:
|
68 |
+
return f'{idx}\n{sec2srt(segment["start"])} --> {sec2srt(segment["end"])}\n' \
|
69 |
+
f'{finalize_text(segment["text"], strip)}'
|
70 |
+
|
71 |
+
|
72 |
+
def segment2assblock(segment: dict, idx: int, strip=True) -> str:
|
73 |
+
return f'Dialogue: {idx},{sec2ass(segment["start"])},{sec2ass(segment["end"])},Default,,0,0,0,,' \
|
74 |
+
f'{finalize_text(segment["text"], strip)}'
|
75 |
+
|
76 |
+
|
77 |
+
def segment2tsvblock(segment: dict, strip=True) -> str:
|
78 |
+
return f'{sec2milliseconds(segment["start"])}' \
|
79 |
+
f'\t{sec2milliseconds(segment["end"])}' \
|
80 |
+
f'\t{segment["text"].strip() if strip else segment["text"]}'
|
81 |
+
|
82 |
+
|
83 |
+
def words2segments(words: List[dict], tag: Tuple[str, str], reverse_text: bool = False) -> List[dict]:
|
84 |
+
def add_tag(idx: int):
|
85 |
+
return ''.join(
|
86 |
+
(
|
87 |
+
f" {tag[0]}{w['word'][1:]}{tag[1]}"
|
88 |
+
if w['word'].startswith(' ') else
|
89 |
+
f"{tag[0]}{w['word']}{tag[1]}"
|
90 |
+
)
|
91 |
+
if w['word'] not in ('', ' ') and idx_ == idx else
|
92 |
+
w['word']
|
93 |
+
for idx_, w in idx_filled_words
|
94 |
+
)
|
95 |
+
|
96 |
+
filled_words = []
|
97 |
+
for i, word in enumerate(words):
|
98 |
+
curr_end = round(word['end'], 3)
|
99 |
+
filled_words.append(dict(word=word['word'], start=round(word['start'], 3), end=curr_end))
|
100 |
+
if word != words[-1]:
|
101 |
+
next_start = round(words[i + 1]['start'], 3)
|
102 |
+
if next_start - curr_end != 0:
|
103 |
+
filled_words.append(dict(word='', start=curr_end, end=next_start))
|
104 |
+
idx_filled_words = list(enumerate(filled_words))
|
105 |
+
if reverse_text:
|
106 |
+
idx_filled_words = list(reversed(idx_filled_words))
|
107 |
+
|
108 |
+
segments = [dict(text=add_tag(i), start=filled_words[i]['start'], end=filled_words[i]['end'])
|
109 |
+
for i in range(len(filled_words))]
|
110 |
+
return segments
|
111 |
+
|
112 |
+
|
113 |
+
def to_word_level_segments(segments: List[dict], tag: Tuple[str, str]) -> List[dict]:
|
114 |
+
return list(
|
115 |
+
chain.from_iterable(
|
116 |
+
words2segments(s['words'], tag, reverse_text=s.get('reversed_text'))
|
117 |
+
for s in segments
|
118 |
+
)
|
119 |
+
)
|
120 |
+
|
121 |
+
|
122 |
+
def to_vtt_word_level_segments(segments: List[dict], tag: Tuple[str, str] = None) -> List[dict]:
|
123 |
+
def to_segment_string(segment: dict):
|
124 |
+
segment_string = ''
|
125 |
+
prev_end = 0
|
126 |
+
for i, word in enumerate(segment['words']):
|
127 |
+
if i != 0:
|
128 |
+
curr_start = word['start']
|
129 |
+
if prev_end == curr_start:
|
130 |
+
segment_string += f"<{sec2vtt(curr_start)}>"
|
131 |
+
else:
|
132 |
+
if segment_string.endswith(' '):
|
133 |
+
segment_string = segment_string[:-1]
|
134 |
+
elif segment['words'][i]['word'].startswith(' '):
|
135 |
+
segment['words'][i]['word'] = segment['words'][i]['word'][1:]
|
136 |
+
segment_string += f"<{sec2vtt(prev_end)}> <{sec2vtt(curr_start)}>"
|
137 |
+
segment_string += word['word']
|
138 |
+
prev_end = word['end']
|
139 |
+
return segment_string
|
140 |
+
|
141 |
+
return [
|
142 |
+
dict(
|
143 |
+
text=to_segment_string(s),
|
144 |
+
start=s['start'],
|
145 |
+
end=s['end']
|
146 |
+
)
|
147 |
+
for s in segments
|
148 |
+
]
|
149 |
+
|
150 |
+
|
151 |
+
def to_ass_word_level_segments(segments: List[dict], tag: Tuple[str, str], karaoke: bool = False) -> List[dict]:
|
152 |
+
|
153 |
+
def to_segment_string(segment: dict):
|
154 |
+
segment_string = ''
|
155 |
+
for i, word in enumerate(segment['words']):
|
156 |
+
curr_word, space = (word['word'][1:], " ") if word['word'].startswith(" ") else (word['word'], "")
|
157 |
+
segment_string += (
|
158 |
+
space +
|
159 |
+
r"{\k" +
|
160 |
+
("f" if karaoke else "") +
|
161 |
+
f"{sec2centiseconds(word['end']-word['start'])}" +
|
162 |
+
r"}" +
|
163 |
+
curr_word
|
164 |
+
)
|
165 |
+
return segment_string
|
166 |
+
|
167 |
+
return [
|
168 |
+
dict(
|
169 |
+
text=to_segment_string(s),
|
170 |
+
start=s['start'],
|
171 |
+
end=s['end']
|
172 |
+
)
|
173 |
+
for s in segments
|
174 |
+
]
|
175 |
+
|
176 |
+
|
177 |
+
def to_word_level(segments: List[dict]) -> List[dict]:
|
178 |
+
return [dict(text=w['word'], start=w['start'], end=w['end']) for s in segments for w in s['words']]
|
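A worked example of to_word_level, which flattens word timings into one pseudo-segment per word (the sample word timings are made up):

segments = [{'words': [
    {'word': ' Hello', 'start': 0.0, 'end': 0.4},
    {'word': ' world', 'start': 0.4, 'end': 0.9},
]}]
assert to_word_level(segments) == [
    {'text': ' Hello', 'start': 0.0, 'end': 0.4},
    {'text': ' world', 'start': 0.4, 'end': 0.9},
]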
179 |
+
|
180 |
+
|
181 |
+
def _confirm_word_level(segments: List[dict]) -> bool:
|
182 |
+
if not all(bool(s.get('words')) for s in segments):
|
183 |
+
warnings.warn('Result is missing word timestamps. Word-level timing cannot be exported. '
|
184 |
+
'Use "word_level=False" to avoid this warning')
|
185 |
+
return False
|
186 |
+
return True
|
187 |
+
|
188 |
+
|
189 |
+
def _preprocess_args(result: (dict, list),
|
190 |
+
segment_level: bool,
|
191 |
+
word_level: bool,
|
192 |
+
min_dur: float,
|
193 |
+
reverse_text: Union[bool, tuple] = False):
|
194 |
+
assert segment_level or word_level, '`segment_level` or `word_level` must be True'
|
195 |
+
segments = _get_segments(result, min_dur, reverse_text=reverse_text)
|
196 |
+
if word_level:
|
197 |
+
word_level = _confirm_word_level(segments)
|
198 |
+
return segments, segment_level, word_level
|
199 |
+
|
200 |
+
|
201 |
+
def result_to_any(result: (dict, list),
|
202 |
+
filepath: str = None,
|
203 |
+
filetype: str = None,
|
204 |
+
segments2blocks: Callable = None,
|
205 |
+
segment_level=True,
|
206 |
+
word_level=True,
|
207 |
+
min_dur: float = 0.02,
|
208 |
+
tag: Tuple[str, str] = None,
|
209 |
+
default_tag: Tuple[str, str] = None,
|
210 |
+
strip=True,
|
211 |
+
reverse_text: Union[bool, tuple] = False,
|
212 |
+
to_word_level_string_callback: Callable = None):
|
213 |
+
"""
|
214 |
+
Generate file from ``result`` to display segment-level and/or word-level timestamp.
|
215 |
+
|
216 |
+
Returns
|
217 |
+
-------
|
218 |
+
str
|
219 |
+
String of the content if ``filepath`` is ``None``.
|
220 |
+
"""
|
221 |
+
segments, segment_level, word_level = _preprocess_args(
|
222 |
+
result, segment_level, word_level, min_dur, reverse_text=reverse_text
|
223 |
+
)
|
224 |
+
|
225 |
+
if filetype is None:
|
226 |
+
filetype = os.path.splitext(filepath)[-1][1:] or 'srt'
|
227 |
+
if filetype.lower() not in SUPPORTED_FORMATS:
|
228 |
+
raise NotImplementedError(f'{filetype} not supported')
|
229 |
+
if filepath and not filepath.lower().endswith(f'.{filetype}'):
|
230 |
+
filepath += f'.{filetype}'
|
231 |
+
|
232 |
+
if word_level and segment_level:
|
233 |
+
if tag is None:
|
234 |
+
if default_tag is None:
|
235 |
+
tag = ('<font color="#00ff00">', '</font>') if filetype == 'srt' else ('<u>', '</u>')
|
236 |
+
else:
|
237 |
+
tag = default_tag
|
238 |
+
if to_word_level_string_callback is None:
|
239 |
+
to_word_level_string_callback = to_word_level_segments
|
240 |
+
segments = to_word_level_string_callback(segments, tag)
|
241 |
+
elif word_level:
|
242 |
+
segments = to_word_level(segments)
|
243 |
+
|
244 |
+
valid_ts(segments)
|
245 |
+
|
246 |
+
if segments2blocks is None:
|
247 |
+
sub_str = '\n\n'.join(segment2srtblock(s, i, strip=strip) for i, s in enumerate(segments))
|
248 |
+
else:
|
249 |
+
sub_str = segments2blocks(segments)
|
250 |
+
|
251 |
+
if filepath:
|
252 |
+
_save_as_file(sub_str, filepath)
|
253 |
+
else:
|
254 |
+
return sub_str
|
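A hedged sketch of using result_to_any with a custom segments2blocks callback, which is how the srt/vtt/ass/tsv/txt writers below are built on top of it. Here result is assumed to be a transcription result (dict or WhisperResult) already in scope, and the CSV-like layout is only an illustration:

def segments2csv(segments):
    return '\n'.join(f'{s["start"]:.3f},{s["end"]:.3f},{s["text"].strip()}' for s in segments)

csv_str = result_to_any(result, filetype='txt', segments2blocks=segments2csv,
                        segment_level=True, word_level=False)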
255 |
+
|
256 |
+
|
257 |
+
def result_to_srt_vtt(result: (dict, list),
|
258 |
+
filepath: str = None,
|
259 |
+
segment_level=True,
|
260 |
+
word_level=True,
|
261 |
+
min_dur: float = 0.02,
|
262 |
+
tag: Tuple[str, str] = None,
|
263 |
+
vtt: bool = None,
|
264 |
+
strip=True,
|
265 |
+
reverse_text: Union[bool, tuple] = False):
|
266 |
+
"""
|
267 |
+
Generate SRT/VTT from ``result`` to display segment-level and/or word-level timestamp.
|
268 |
+
|
269 |
+
Parameters
|
270 |
+
----------
|
271 |
+
result : dict or list or stable_whisper.result.WhisperResult
|
272 |
+
Result of transcription.
|
273 |
+
filepath : str, default None, meaning content will be returned as a ``str``
|
274 |
+
Path to save file.
|
275 |
+
segment_level : bool, default True
|
276 |
+
Whether to use segment-level timestamps in output.
|
277 |
+
word_level : bool, default True
|
278 |
+
Whether to use word-level timestamps in output.
|
279 |
+
min_dur : float, default 0.02
|
280 |
+
Minimum duration allowed for any word/segment before the word/segments are merged with adjacent word/segments.
|
281 |
+
tag: tuple of (str, str), default None, meaning ('<font color="#00ff00">', '</font>') if SRT else ('<u>', '</u>')
|
282 |
+
Tag used to change the properties of a word at its timestamp.
|
283 |
+
vtt : bool, default None, meaning determined by extension of ``filepath`` or ``False`` if no valid extension.
|
284 |
+
Whether to output VTT.
|
285 |
+
strip : bool, default True
|
286 |
+
Whether to remove spaces before and after text on each segment for output.
|
287 |
+
reverse_text: bool or tuple, default False
|
288 |
+
Whether to reverse the order of words for each segment or provide the ``prepend_punctuations`` and
|
289 |
+
``append_punctuations`` as tuple pair instead of ``True`` which is for the default punctuations.
|
290 |
+
|
291 |
+
Returns
|
292 |
+
-------
|
293 |
+
str
|
294 |
+
String of the content if ``filepath`` is ``None``.
|
295 |
+
|
296 |
+
Notes
|
297 |
+
-----
|
298 |
+
``reverse_text`` will not fix RTL text not displaying tags properly, which is an issue with some video players. VLC
|
299 |
+
seems to not suffer from this issue.
|
300 |
+
|
301 |
+
Examples
|
302 |
+
--------
|
303 |
+
>>> import stable_whisper
|
304 |
+
>>> model = stable_whisper.load_model('base')
|
305 |
+
>>> result = model.transcribe('audio.mp3')
|
306 |
+
>>> result.to_srt_vtt('audio.srt')
|
307 |
+
Saved: audio.srt
|
308 |
+
"""
|
309 |
+
is_srt = (filepath is None or not filepath.lower().endswith('.vtt')) if vtt is None else not vtt
|
310 |
+
if is_srt:
|
311 |
+
segments2blocks = None
|
312 |
+
to_word_level_string_callback = None
|
313 |
+
else:
|
314 |
+
def segments2blocks(segments):
|
315 |
+
return 'WEBVTT\n\n' + '\n\n'.join(segment2vttblock(s, strip=strip) for i, s in enumerate(segments))
|
316 |
+
to_word_level_string_callback = to_vtt_word_level_segments if tag is None else tag
|
317 |
+
|
318 |
+
return result_to_any(
|
319 |
+
result=result,
|
320 |
+
filepath=filepath,
|
321 |
+
filetype=('vtt', 'srt')[is_srt],
|
322 |
+
segments2blocks=segments2blocks,
|
323 |
+
segment_level=segment_level,
|
324 |
+
word_level=word_level,
|
325 |
+
min_dur=min_dur,
|
326 |
+
tag=tag,
|
327 |
+
strip=strip,
|
328 |
+
reverse_text=reverse_text,
|
329 |
+
to_word_level_string_callback=to_word_level_string_callback
|
330 |
+
)
|
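Hedged usage sketches continuing the docstring example above (result is a WhisperResult):

srt_str = result_to_srt_vtt(result, word_level=False)    # segment-level SRT, returned as a string
result_to_srt_vtt(result, 'audio.vtt')                   # word-level VTT; the .vtt extension selects VTT output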
331 |
+
|
332 |
+
|
333 |
+
def result_to_tsv(result: (dict, list),
|
334 |
+
filepath: str = None,
|
335 |
+
segment_level: bool = None,
|
336 |
+
word_level: bool = None,
|
337 |
+
min_dur: float = 0.02,
|
338 |
+
strip=True,
|
339 |
+
reverse_text: Union[bool, tuple] = False):
|
340 |
+
"""
|
341 |
+
Generate TSV from ``result`` to display segment-level and/or word-level timestamp.
|
342 |
+
|
343 |
+
Parameters
|
344 |
+
----------
|
345 |
+
result : dict or list or stable_whisper.result.WhisperResult
|
346 |
+
Result of transcription.
|
347 |
+
filepath : str, default None, meaning content will be returned as a ``str``
|
348 |
+
Path to save file.
|
349 |
+
segment_level : bool, default True
|
350 |
+
Whether to use segment-level timestamps in output.
|
351 |
+
word_level : bool, default True
|
352 |
+
Whether to use word-level timestamps in output.
|
353 |
+
min_dur : float, default 0.02
|
354 |
+
Minimum duration allowed for any word/segment before the word/segments are merged with adjacent word/segments.
|
355 |
+
strip : bool, default True
|
356 |
+
Whether to remove spaces before and after text on each segment for output.
|
357 |
+
reverse_text: bool or tuple, default False
|
358 |
+
Whether to reverse the order of words for each segment or provide the ``prepend_punctuations`` and
|
359 |
+
``append_punctuations`` as tuple pair instead of ``True`` which is for the default punctuations.
|
360 |
+
|
361 |
+
Returns
|
362 |
+
-------
|
363 |
+
str
|
364 |
+
String of the content if ``filepath`` is ``None``.
|
365 |
+
|
366 |
+
Notes
|
367 |
+
-----
|
368 |
+
``reverse_text`` will not fix RTL text not displaying tags properly, which is an issue with some video players. VLC
|
369 |
+
seems to not suffer from this issue.
|
370 |
+
|
371 |
+
Examples
|
372 |
+
--------
|
373 |
+
>>> import stable_whisper
|
374 |
+
>>> model = stable_whisper.load_model('base')
|
375 |
+
>>> result = model.transcribe('audio.mp3')
|
376 |
+
>>> result.to_tsv('audio.tsv')
|
377 |
+
Saved: audio.tsv
|
378 |
+
"""
|
379 |
+
if segment_level is None and word_level is None:
|
380 |
+
segment_level = True
|
381 |
+
assert word_level is not segment_level, '[word_level] and [segment_level] cannot be the same ' \
|
382 |
+
'since [tag] is not supported for this format'
|
383 |
+
|
384 |
+
def segments2blocks(segments):
|
385 |
+
return '\n\n'.join(segment2tsvblock(s, strip=strip) for i, s in enumerate(segments))
|
386 |
+
return result_to_any(
|
387 |
+
result=result,
|
388 |
+
filepath=filepath,
|
389 |
+
filetype='tsv',
|
390 |
+
segments2blocks=segments2blocks,
|
391 |
+
segment_level=segment_level,
|
392 |
+
word_level=word_level,
|
393 |
+
min_dur=min_dur,
|
394 |
+
strip=strip,
|
395 |
+
reverse_text=reverse_text
|
396 |
+
)
|
397 |
+
|
398 |
+
|
399 |
+
def result_to_ass(result: (dict, list),
|
400 |
+
filepath: str = None,
|
401 |
+
segment_level=True,
|
402 |
+
word_level=True,
|
403 |
+
min_dur: float = 0.02,
|
404 |
+
tag: Union[Tuple[str, str], int] = None,
|
405 |
+
font: str = None,
|
406 |
+
font_size: int = 24,
|
407 |
+
strip=True,
|
408 |
+
highlight_color: str = None,
|
409 |
+
karaoke=False,
|
410 |
+
reverse_text: Union[bool, tuple] = False,
|
411 |
+
**kwargs):
|
412 |
+
"""
|
413 |
+
Generate Advanced SubStation Alpha (ASS) file from ``result`` to display segment-level and/or word-level timestamp.
|
414 |
+
|
415 |
+
Parameters
|
416 |
+
----------
|
417 |
+
result : dict or list or stable_whisper.result.WhisperResult
|
418 |
+
Result of transcription.
|
419 |
+
filepath : str, default None, meaning content will be returned as a ``str``
|
420 |
+
Path to save file.
|
421 |
+
segment_level : bool, default True
|
422 |
+
Whether to use segment-level timestamps in output.
|
423 |
+
word_level : bool, default True
|
424 |
+
Whether to use word-level timestamps in output.
|
425 |
+
min_dur : float, default 0.02
|
426 |
+
Minimum duration allowed for any word/segment before the word/segments are merged with adjacent word/segments.
|
427 |
+
tag: tuple of (str, str) or int, default None, meaning use default highlighting
|
428 |
+
Tag used to change the properties of a word at its timestamp. -1 for individual word highlight tag.
|
429 |
+
font : str, default `Arial`
|
430 |
+
Word font.
|
431 |
+
font_size : int, default 24
|
432 |
+
Word font size.
|
433 |
+
strip : bool, default True
|
434 |
+
Whether to remove spaces before and after text on each segment for output.
|
435 |
+
highlight_color : str, default '00ff00'
|
436 |
+
Hexadecimal of the color use for default highlights as '<bb><gg><rr>'.
|
437 |
+
karaoke : bool, default False
|
438 |
+
Whether to use progressive filling highlights (for karaoke effect).
|
439 |
+
reverse_text: bool or tuple, default False
|
440 |
+
Whether to reverse the order of words for each segment or provide the ``prepend_punctuations`` and
|
441 |
+
``append_punctuations`` as tuple pair instead of ``True`` which is for the default punctuations.
|
442 |
+
kwargs:
|
443 |
+
Format styles:
|
444 |
+
'Name', 'Fontname', 'Fontsize', 'PrimaryColour', 'SecondaryColour', 'OutlineColour', 'BackColour', 'Bold',
|
445 |
+
'Italic', 'Underline', 'StrikeOut', 'ScaleX', 'ScaleY', 'Spacing', 'Angle', 'BorderStyle', 'Outline',
|
446 |
+
'Shadow', 'Alignment', 'MarginL', 'MarginR', 'MarginV', 'Encoding'
|
447 |
+
|
448 |
+
Returns
|
449 |
+
-------
|
450 |
+
str
|
451 |
+
String of the content if ``filepath`` is ``None``.
|
452 |
+
|
453 |
+
Notes
|
454 |
+
-----
|
455 |
+
``reverse_text`` will not fix RTL text not displaying tags properly, which is an issue with some video players. VLC
|
456 |
+
seems to not suffer from this issue.
|
457 |
+
|
458 |
+
Examples
|
459 |
+
--------
|
460 |
+
>>> import stable_whisper
|
461 |
+
>>> model = stable_whisper.load_model('base')
|
462 |
+
>>> result = model.transcribe('audio.mp3')
|
463 |
+
>>> result.to_ass('audio.ass')
|
464 |
+
Saved: audio.ass
|
465 |
+
"""
|
466 |
+
if tag == ['-1']: # CLI
|
467 |
+
tag = -1
|
468 |
+
if highlight_color is None:
|
469 |
+
highlight_color = '00ff00'
|
470 |
+
|
471 |
+
def segments2blocks(segments):
|
472 |
+
fmt_style_dict = {'Name': 'Default', 'Fontname': 'Arial', 'Fontsize': '48', 'PrimaryColour': '&Hffffff',
|
473 |
+
'SecondaryColour': '&Hffffff', 'OutlineColour': '&H0', 'BackColour': '&H0', 'Bold': '0',
|
474 |
+
'Italic': '0', 'Underline': '0', 'StrikeOut': '0', 'ScaleX': '100', 'ScaleY': '100',
|
475 |
+
'Spacing': '0', 'Angle': '0', 'BorderStyle': '1', 'Outline': '1', 'Shadow': '0',
|
476 |
+
'Alignment': '2', 'MarginL': '10', 'MarginR': '10', 'MarginV': '10', 'Encoding': '0'}
|
477 |
+
|
478 |
+
for k, v in filter(lambda x: 'colour' in x[0].lower() and not str(x[1]).startswith('&H'), kwargs.items()):
|
479 |
+
kwargs[k] = f'&H{kwargs[k]}'
|
480 |
+
|
481 |
+
fmt_style_dict.update((k, v) for k, v in kwargs.items() if k in fmt_style_dict)
|
482 |
+
|
483 |
+
if tag is None and 'PrimaryColour' not in kwargs:
|
484 |
+
fmt_style_dict['PrimaryColour'] = \
|
485 |
+
highlight_color if highlight_color.startswith('&H') else f'&H{highlight_color}'
|
486 |
+
|
487 |
+
if font:
|
488 |
+
fmt_style_dict.update(Fontname=font)
|
489 |
+
if font_size:
|
490 |
+
fmt_style_dict.update(Fontsize=font_size)
|
491 |
+
|
492 |
+
fmts = f'Format: {", ".join(map(str, fmt_style_dict.keys()))}'
|
493 |
+
|
494 |
+
styles = f'Style: {",".join(map(str, fmt_style_dict.values()))}'
|
495 |
+
|
496 |
+
sub_str = f'[Script Info]\nScriptType: v4.00+\nPlayResX: 384\nPlayResY: 288\nScaledBorderAndShadow: yes\n\n' \
|
497 |
+
f'[V4+ Styles]\n{fmts}\n{styles}\n\n' \
|
498 |
+
f'[Events]\nFormat: Layer, Start, End, Style, Name, MarginL, MarginR, MarginV, Effect, Text\n\n'
|
499 |
+
|
500 |
+
sub_str += '\n'.join(segment2assblock(s, i, strip=strip) for i, s in enumerate(segments))
|
501 |
+
|
502 |
+
return sub_str
|
503 |
+
|
504 |
+
if tag is not None and karaoke:
|
505 |
+
warnings.warn('[tag] is not supported for [karaoke]=True; [tag] will be ignored.')
|
506 |
+
|
507 |
+
return result_to_any(
|
508 |
+
result=result,
|
509 |
+
filepath=filepath,
|
510 |
+
filetype='ass',
|
511 |
+
segments2blocks=segments2blocks,
|
512 |
+
segment_level=segment_level,
|
513 |
+
word_level=word_level,
|
514 |
+
min_dur=min_dur,
|
515 |
+
tag=None if tag == -1 else tag,
|
516 |
+
default_tag=(r'{\1c' + f'{highlight_color}&' + '}', r'{\r}'),
|
517 |
+
strip=strip,
|
518 |
+
reverse_text=reverse_text,
|
519 |
+
to_word_level_string_callback=(
|
520 |
+
(lambda s, t: to_ass_word_level_segments(s, t, karaoke=karaoke))
|
521 |
+
if karaoke or (word_level and segment_level and tag is None)
|
522 |
+
else None
|
523 |
+
)
|
524 |
+
)
|
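Hedged usage sketches for the highlighting options described above (result is a WhisperResult; filenames are placeholders):

result_to_ass(result, 'audio.ass', highlight_color='ff00ff')                            # word highlight, custom color
result_to_ass(result, 'audio_karaoke.ass', karaoke=True, font='Arial', font_size=32)    # progressive karaoke fill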
525 |
+
|
526 |
+
|
527 |
+
def result_to_txt(
|
528 |
+
result: (dict, list),
|
529 |
+
filepath: str = None,
|
530 |
+
min_dur: float = 0.02,
|
531 |
+
strip=True,
|
532 |
+
reverse_text: Union[bool, tuple] = False
|
533 |
+
):
|
534 |
+
"""
|
535 |
+
Generate plain-text without timestamps from ``result``.
|
536 |
+
|
537 |
+
Parameters
|
538 |
+
----------
|
539 |
+
result : dict or list or stable_whisper.result.WhisperResult
|
540 |
+
Result of transcription.
|
541 |
+
filepath : str, default None, meaning content will be returned as a ``str``
|
542 |
+
Path to save file.
|
543 |
+
min_dur : float, default 0.02
|
544 |
+
Minimum duration allowed for any word/segment before the word/segments are merged with adjacent word/segments.
|
545 |
+
strip : bool, default True
|
546 |
+
Whether to remove spaces before and after text on each segment for output.
|
547 |
+
reverse_text: bool or tuple, default False
|
548 |
+
Whether to reverse the order of words for each segment or provide the ``prepend_punctuations`` and
|
549 |
+
``append_punctuations`` as tuple pair instead of ``True`` which is for the default punctuations.
|
550 |
+
|
551 |
+
Returns
|
552 |
+
-------
|
553 |
+
str
|
554 |
+
String of the content if ``filepath`` is ``None``.
|
555 |
+
|
556 |
+
Notes
|
557 |
+
-----
|
558 |
+
``reverse_text`` will not fix RTL text not displaying tags properly, which is an issue with some video players. VLC
|
559 |
+
seems to not suffer from this issue.
|
560 |
+
|
561 |
+
Examples
|
562 |
+
--------
|
563 |
+
>>> import stable_whisper
|
564 |
+
>>> model = stable_whisper.load_model('base')
|
565 |
+
>>> result = model.transcribe('audio.mp3')
|
566 |
+
>>> result.to_txt('audio.txt')
|
567 |
+
Saved: audio.txt
|
568 |
+
"""
|
569 |
+
|
570 |
+
def segments2blocks(segments: dict, _strip=True) -> str:
|
571 |
+
return '\n'.join(f'{segment["text"].strip() if _strip else segment["text"]}' for segment in segments)
|
572 |
+
|
573 |
+
return result_to_any(
|
574 |
+
result=result,
|
575 |
+
filepath=filepath,
|
576 |
+
filetype='txt',
|
577 |
+
segments2blocks=segments2blocks,
|
578 |
+
segment_level=True,
|
579 |
+
word_level=False,
|
580 |
+
min_dur=min_dur,
|
581 |
+
strip=strip,
|
582 |
+
reverse_text=reverse_text
|
583 |
+
)
|
584 |
+
|
585 |
+
|
586 |
+
def save_as_json(result: dict, path: str, ensure_ascii: bool = False, **kwargs):
|
587 |
+
"""
|
588 |
+
Save ``result`` as JSON file to ``path``.
|
589 |
+
|
590 |
+
Parameters
|
591 |
+
----------
|
592 |
+
result : dict or list or stable_whisper.result.WhisperResult
|
593 |
+
Result of transcription.
|
594 |
+
path : str
|
595 |
+
Path to save file.
|
596 |
+
ensure_ascii : bool, default False
|
597 |
+
Whether to escape non-ASCII characters.
|
598 |
+
|
599 |
+
Examples
|
600 |
+
--------
|
601 |
+
>>> import stable_whisper
|
602 |
+
>>> model = stable_whisper.load_model('base')
|
603 |
+
>>> result = model.transcribe('audio.mp3')
|
604 |
+
>>> result.save_as_json('audio.json')
|
605 |
+
Saved: audio.json
|
606 |
+
"""
|
607 |
+
if not isinstance(result, dict) and callable(getattr(result, 'to_dict')):
|
608 |
+
result = result.to_dict()
|
609 |
+
if not path.lower().endswith('.json'):
|
610 |
+
path += '.json'
|
611 |
+
result = json.dumps(result, allow_nan=True, ensure_ascii=ensure_ascii, **kwargs)
|
612 |
+
_save_as_file(result, path)
|
613 |
+
|
614 |
+
|
615 |
+
def load_result(json_path: str) -> dict:
|
616 |
+
"""
|
617 |
+
Return a ``dict`` of the contents in ``json_path``.
|
618 |
+
"""
|
619 |
+
with open(json_path, 'r', encoding='utf-8') as f:
|
620 |
+
return json.load(f)
|
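A hedged round-trip sketch: persist a result as JSON, reload it later as a plain dict, and export subtitles from that dict (result is a WhisperResult; filenames are placeholders):

save_as_json(result, 'audio.json')
reloaded = load_result('audio.json')          # plain dict
result_to_srt_vtt(reloaded, 'audio_reloaded.srt')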
stable_whisper/timing.py
ADDED
@@ -0,0 +1,275 @@
1 |
+
import string
|
2 |
+
import torch
|
3 |
+
import numpy as np
|
4 |
+
from typing import TYPE_CHECKING, List, Callable, Optional
|
5 |
+
from itertools import chain
|
6 |
+
from whisper.audio import TOKENS_PER_SECOND, N_SAMPLES_PER_TOKEN
|
7 |
+
from whisper.timing import WordTiming, median_filter, dtw, merge_punctuations
|
8 |
+
|
9 |
+
if TYPE_CHECKING:
|
10 |
+
from whisper.tokenizer import Tokenizer
|
11 |
+
from whisper.model import Whisper
|
12 |
+
|
13 |
+
|
14 |
+
# modified version of whisper.timing.find_alignment
|
15 |
+
def find_alignment_stable(
|
16 |
+
model: "Whisper",
|
17 |
+
tokenizer: "Tokenizer",
|
18 |
+
text_tokens: List[int],
|
19 |
+
mel: torch.Tensor,
|
20 |
+
num_samples: int,
|
21 |
+
*,
|
22 |
+
medfilt_width: int = 7,
|
23 |
+
qk_scale: float = 1.0,
|
24 |
+
ts_num: int = 0,
|
25 |
+
ts_noise: float = 0.1,
|
26 |
+
token_split=None,
|
27 |
+
audio_features: torch.Tensor = None
|
28 |
+
) -> List[WordTiming]:
|
29 |
+
tokens = torch.tensor(
|
30 |
+
[
|
31 |
+
*tokenizer.sot_sequence,
|
32 |
+
tokenizer.no_timestamps,
|
33 |
+
*text_tokens,
|
34 |
+
tokenizer.eot,
|
35 |
+
]
|
36 |
+
).to(model.device)
|
37 |
+
|
38 |
+
# install hooks on the cross attention layers to retrieve the attention weights
|
39 |
+
QKs = [None] * model.dims.n_text_layer
|
40 |
+
hooks = [
|
41 |
+
block.cross_attn.register_forward_hook(
|
42 |
+
lambda _, ins, outs, index=i: QKs.__setitem__(index, outs[-1])
|
43 |
+
)
|
44 |
+
for i, block in enumerate(model.decoder.blocks)
|
45 |
+
]
|
46 |
+
|
47 |
+
with torch.no_grad():
|
48 |
+
if audio_features is None:
|
49 |
+
audio_features = model.encoder(mel.unsqueeze(0))
|
50 |
+
if ts_num:
|
51 |
+
if ts_noise is None:
|
52 |
+
ts_noise = 0.1
|
53 |
+
extra_audio_features = audio_features.repeat_interleave(ts_num, 0)
|
54 |
+
torch.manual_seed(0)
|
55 |
+
audio_features = torch.cat([audio_features,
|
56 |
+
extra_audio_features *
|
57 |
+
(1 - (torch.rand_like(extra_audio_features) * ts_noise))],
|
58 |
+
dim=0)
|
59 |
+
logits = model.decoder(tokens.unsqueeze(0).repeat_interleave(audio_features.shape[0], 0),
|
60 |
+
audio_features)
|
61 |
+
else:
|
62 |
+
logits = model.decoder(tokens.unsqueeze(0), audio_features)
|
63 |
+
|
64 |
+
logits = logits[0]
|
65 |
+
sampled_logits = logits[len(tokenizer.sot_sequence):, : tokenizer.eot]
|
66 |
+
token_probs = sampled_logits.softmax(dim=-1)
|
67 |
+
text_token_probs = token_probs[np.arange(len(text_tokens)), text_tokens]
|
68 |
+
text_token_probs = text_token_probs.tolist()
|
69 |
+
|
70 |
+
for hook in hooks:
|
71 |
+
hook.remove()
|
72 |
+
|
73 |
+
# heads * tokens * frames
|
74 |
+
weights = torch.cat([QKs[_l][:, _h] for _l, _h in model.alignment_heads.indices().T], dim=0)
|
75 |
+
weights = weights[:, :, : round(num_samples / N_SAMPLES_PER_TOKEN)]
|
76 |
+
weights = (weights * qk_scale).softmax(dim=-1)
|
77 |
+
std, mean = torch.std_mean(weights, dim=-2, keepdim=True, unbiased=False)
|
78 |
+
weights = (weights - mean) / std
|
79 |
+
weights = median_filter(weights, medfilt_width)
|
80 |
+
|
81 |
+
matrix = weights.mean(axis=0)
|
82 |
+
matrix = matrix[len(tokenizer.sot_sequence): -1]
|
83 |
+
text_indices, time_indices = dtw(-matrix)
|
84 |
+
|
85 |
+
if token_split is None:
|
86 |
+
words, word_tokens = tokenizer.split_to_word_tokens(text_tokens + [tokenizer.eot])
|
87 |
+
else:
|
88 |
+
words, word_tokens = token_split
|
89 |
+
words.append(tokenizer.decode([tokenizer.eot]))
|
90 |
+
word_tokens.append([tokenizer.eot])
|
91 |
+
word_boundaries = np.pad(np.cumsum([len(t) for t in word_tokens[:-1]]), (1, 0))
|
92 |
+
|
93 |
+
jumps = np.pad(np.diff(text_indices), (1, 0), constant_values=1).astype(bool)
|
94 |
+
jump_times = time_indices[jumps].clip(min=0) / TOKENS_PER_SECOND
|
95 |
+
start_times = jump_times[word_boundaries[:-1]]
|
96 |
+
end_times = jump_times[word_boundaries[1:]]
|
97 |
+
word_probabilities = [
|
98 |
+
np.mean(text_token_probs[i:j])
|
99 |
+
for i, j in zip(word_boundaries[:-1], word_boundaries[1:])
|
100 |
+
]
|
101 |
+
|
102 |
+
return [
|
103 |
+
WordTiming(word, tokens, start, end, probability)
|
104 |
+
for word, tokens, start, end, probability in zip(
|
105 |
+
words, word_tokens, start_times, end_times, word_probabilities
|
106 |
+
)
|
107 |
+
]
|
108 |
+
|
109 |
+
|
110 |
+
def _split_tokens(tokens: List[int], tokenizer: "Tokenizer"):
|
111 |
+
split_by_space = getattr(tokenizer, 'language_code', tokenizer.language) not in {"zh", "ja", "th", "lo", "my"}
|
112 |
+
text = tokenizer.decode_with_timestamps(tokens)
|
113 |
+
words = []
|
114 |
+
word_tokens = []
|
115 |
+
curr_tokens = []
|
116 |
+
is_append = False
|
117 |
+
for token in tokens:
|
118 |
+
curr_tokens.append(token)
|
119 |
+
curr_text = tokenizer.decode(curr_tokens)
|
120 |
+
is_whole = token >= tokenizer.eot
|
121 |
+
if not is_whole:
|
122 |
+
is_whole = text[:len(curr_text)] == curr_text
|
123 |
+
if is_whole and split_by_space:
|
124 |
+
is_append = not (curr_text.startswith(" ") or curr_text.strip() in string.punctuation)
|
125 |
+
|
126 |
+
if is_whole:
|
127 |
+
if is_append and len(words) != 0:
|
128 |
+
words[-1] += curr_text
|
129 |
+
word_tokens[-1].extend(curr_tokens)
|
130 |
+
else:
|
131 |
+
words.append(curr_text)
|
132 |
+
word_tokens.append(curr_tokens)
|
133 |
+
text = text[len(curr_text):]
|
134 |
+
curr_tokens = []
|
135 |
+
|
136 |
+
if len(curr_tokens) != 0:
|
137 |
+
words.append(curr_text if len(text) == 0 else text)
|
138 |
+
word_tokens.append(curr_tokens)
|
139 |
+
elif len(text) != 0:
|
140 |
+
words[-1] += text
|
141 |
+
|
142 |
+
return words, word_tokens
|
143 |
+
|
144 |
+
|
145 |
+
def split_word_tokens(segments: List[dict],
|
146 |
+
tokenizer: "Tokenizer",
|
147 |
+
*,
|
148 |
+
padding: (str, int) = None,
|
149 |
+
split_callback: Callable = None):
|
150 |
+
if padding is not None:
|
151 |
+
if isinstance(padding, str):
|
152 |
+
padding = tokenizer.encode(padding)
|
153 |
+
else:
|
154 |
+
padding = [padding]
|
155 |
+
tokens = []
|
156 |
+
seg_indices = []
|
157 |
+
words = []
|
158 |
+
word_tokens = []
|
159 |
+
for i, s in enumerate(segments):
|
160 |
+
temp_word_tokens = [t for t in s['tokens'] if not isinstance(t, int) or t < tokenizer.eot]
|
161 |
+
curr_words, curr_word_tokens = (
|
162 |
+
_split_tokens(temp_word_tokens, tokenizer)
|
163 |
+
if split_callback is None else
|
164 |
+
split_callback(temp_word_tokens, tokenizer)
|
165 |
+
)
|
166 |
+
assert len(curr_words) == len(curr_word_tokens), \
|
167 |
+
f'word count and token group count do not match, {len(curr_words)} and {len(curr_word_tokens)}'
|
168 |
+
if (
|
169 |
+
padding is not None and
|
170 |
+
curr_word_tokens[0][0] != padding and
|
171 |
+
(len(tokens) == 0 or tokens[-1] != padding)
|
172 |
+
):
|
173 |
+
tokens.extend(padding)
|
174 |
+
words.append(None)
|
175 |
+
word_tokens.append(padding)
|
176 |
+
seg_indices.extend([i] * len(curr_words))
|
177 |
+
tokens.extend(list(chain.from_iterable(curr_word_tokens)))
|
178 |
+
words.extend(curr_words)
|
179 |
+
word_tokens.extend(curr_word_tokens)
|
180 |
+
|
181 |
+
return tokens, (words, word_tokens), seg_indices
|
182 |
+
|
183 |
+
|
184 |
+
def pop_empty_alignment(alignment: List[WordTiming]):
|
185 |
+
return list(reversed([alignment.pop(i) for i in reversed(range(len(alignment))) if alignment[i].word is None]))
|
186 |
+
|
187 |
+
|
188 |
+
# modified version of whisper.timing.add_word_timestamps
|
189 |
+
def add_word_timestamps_stable(
|
190 |
+
*,
|
191 |
+
segments: List[dict],
|
192 |
+
model: "Whisper",
|
193 |
+
tokenizer: "Tokenizer",
|
194 |
+
mel: torch.Tensor,
|
195 |
+
num_samples: int,
|
196 |
+
prepend_punctuations: str = "\"'“¿([{-",
|
197 |
+
append_punctuations: str = "\"'.。,,!!??::”)]}、",
|
198 |
+
audio_features: torch.Tensor = None,
|
199 |
+
ts_num: int = 0,
|
200 |
+
ts_noise: float = 0.1,
|
201 |
+
min_word_dur: float = 0.1,
|
202 |
+
split_callback: Callable = None,
|
203 |
+
gap_padding: Optional[str] = ' ...',
|
204 |
+
**kwargs,
|
205 |
+
):
|
206 |
+
if len(segments) == 0:
|
207 |
+
return
|
208 |
+
|
209 |
+
if min_word_dur is None:
|
210 |
+
min_word_dur = 0
|
211 |
+
|
212 |
+
if prepend_punctuations is None:
|
213 |
+
prepend_punctuations = "\"'“¿([{-"
|
214 |
+
|
215 |
+
if append_punctuations is None:
|
216 |
+
append_punctuations = "\"'.。,,!!??::”)]}、"
|
217 |
+
|
218 |
+
def align():
|
219 |
+
for seg in segments:
|
220 |
+
seg['words'] = []
|
221 |
+
|
222 |
+
text_tokens, token_split, seg_indices = split_word_tokens(segments, tokenizer,
|
223 |
+
padding=gap_padding, split_callback=split_callback)
|
224 |
+
|
225 |
+
alignment = find_alignment_stable(model, tokenizer, text_tokens, mel, num_samples,
|
226 |
+
**kwargs,
|
227 |
+
token_split=token_split,
|
228 |
+
audio_features=audio_features,
|
229 |
+
ts_num=ts_num,
|
230 |
+
ts_noise=ts_noise)
|
231 |
+
alt_beginning_alignment = pop_empty_alignment(alignment)
|
232 |
+
|
233 |
+
merge_punctuations(alignment, prepend_punctuations, append_punctuations)
|
234 |
+
|
235 |
+
time_offset = segments[0]["seek"]
|
236 |
+
|
237 |
+
assert len(alignment) == len(seg_indices)
|
238 |
+
assert (gap_padding is None or len(segments) == len(alt_beginning_alignment))
|
239 |
+
for i, timing in zip(seg_indices, alignment):
|
240 |
+
if len(timing.tokens) != 0:
|
241 |
+
start = timing.start
|
242 |
+
end = timing.end
|
243 |
+
if (
|
244 |
+
len(segments[i]['words']) == 0 and
|
245 |
+
((end - start) < min_word_dur) and
|
246 |
+
len(alt_beginning_alignment)
|
247 |
+
):
|
248 |
+
start = alt_beginning_alignment[i].start
|
249 |
+
segments[i]['words'].append(
|
250 |
+
dict(
|
251 |
+
word=timing.word,
|
252 |
+
start=round(time_offset + start, 3),
|
253 |
+
end=round(time_offset + end, 3),
|
254 |
+
probability=timing.probability,
|
255 |
+
tokens=timing.tokens
|
256 |
+
)
|
257 |
+
)
|
258 |
+
|
259 |
+
align()
|
260 |
+
if (
|
261 |
+
gap_padding is not None and
|
262 |
+
any(
|
263 |
+
(word['end'] - word['start']) < min_word_dur
|
264 |
+
for seg in segments
|
265 |
+
for word in seg['words']
|
266 |
+
)
|
267 |
+
):
|
268 |
+
gap_padding = None
|
269 |
+
align()
|
270 |
+
|
271 |
+
for segment in segments:
|
272 |
+
if len(words := segment["words"]) > 0:
|
273 |
+
# adjust the segment-level timestamps based on the word-level timestamps
|
274 |
+
segment["start"] = words[0]["start"]
|
275 |
+
segment["end"] = words[-1]["end"]
|
stable_whisper/utils.py
ADDED
@@ -0,0 +1,78 @@
1 |
+
import inspect
|
2 |
+
import sys
|
3 |
+
|
4 |
+
|
5 |
+
system_encoding = sys.getdefaultencoding()
|
6 |
+
|
7 |
+
if system_encoding != "utf-8":
|
8 |
+
|
9 |
+
def make_safe(string):
|
10 |
+
# replaces any character not representable using the system default encoding with an '?',
|
11 |
+
# avoiding UnicodeEncodeError (https://github.com/openai/whisper/discussions/729).
|
12 |
+
return string.encode(system_encoding, errors="replace").decode(system_encoding)
|
13 |
+
|
14 |
+
else:
|
15 |
+
|
16 |
+
def make_safe(string):
|
17 |
+
# utf-8 can encode any Unicode code point, so no need to do the round-trip encoding
|
18 |
+
return string
|
19 |
+
|
20 |
+
|
21 |
+
def str_to_valid_type(val: str):
|
22 |
+
if len(val) == 0:
|
23 |
+
return None
|
24 |
+
if '/' in val:
|
25 |
+
return [a.split('*') if '*' in a else a for a in val.split('/')]
|
26 |
+
try:
|
27 |
+
val = float(val) if '.' in val else int(val)
|
28 |
+
except ValueError:
|
29 |
+
pass
|
30 |
+
finally:
|
31 |
+
return val
|
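Worked examples of str_to_valid_type, which coerces CLI option strings into Python values:

assert str_to_valid_type('') is None
assert str_to_valid_type('5') == 5
assert str_to_valid_type('0.35') == 0.35
assert str_to_valid_type('a/b*c') == ['a', ['b', 'c']]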
32 |
+
|
33 |
+
|
34 |
+
def get_func_parameters(func):
|
35 |
+
return inspect.signature(func).parameters.keys()
|
36 |
+
|
37 |
+
|
38 |
+
def isolate_useful_options(options: dict, method, pop: bool = False) -> dict:
|
39 |
+
_get = dict.pop if pop else dict.get
|
40 |
+
return {k: _get(options, k) for k in get_func_parameters(method) if k in options}
|
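A worked example of isolate_useful_options: it keeps only the keys that match a callable's parameters and can optionally pop them from the source dict (demo is a throwaway function for illustration):

def demo(alpha, beta=1):
    return alpha + beta

options = {'alpha': 2, 'beta': 3, 'gamma': 4}
assert isolate_useful_options(options, demo) == {'alpha': 2, 'beta': 3}
assert isolate_useful_options(options, demo, pop=True) == {'alpha': 2, 'beta': 3}
assert options == {'gamma': 4}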
41 |
+
|
42 |
+
|
43 |
+
def safe_print(msg: str, _print=None):
|
44 |
+
if msg:
|
45 |
+
(_print or print)(make_safe(msg))
|
46 |
+
|
47 |
+
|
48 |
+
def format_timestamp(
|
49 |
+
seconds: float, always_include_hours: bool = False, decimal_marker: str = "."
|
50 |
+
):
|
51 |
+
assert seconds >= 0, "non-negative timestamp expected"
|
52 |
+
milliseconds = round(seconds * 1000.0)
|
53 |
+
|
54 |
+
hours = milliseconds // 3_600_000
|
55 |
+
milliseconds -= hours * 3_600_000
|
56 |
+
|
57 |
+
minutes = milliseconds // 60_000
|
58 |
+
milliseconds -= minutes * 60_000
|
59 |
+
|
60 |
+
seconds = milliseconds // 1_000
|
61 |
+
milliseconds -= seconds * 1_000
|
62 |
+
|
63 |
+
hours_marker = f"{hours:02d}:" if always_include_hours or hours > 0 else ""
|
64 |
+
return (
|
65 |
+
f"{hours_marker}{minutes:02d}:{seconds:02d}{decimal_marker}{milliseconds:03d}"
|
66 |
+
)
|
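Worked examples of format_timestamp:

assert format_timestamp(3661.5) == '01:01:01.500'
assert format_timestamp(1.5) == '00:01.500'
assert format_timestamp(1.5, always_include_hours=True, decimal_marker=',') == '00:00:01,500'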
67 |
+
|
68 |
+
|
69 |
+
class UnsortedException(Exception):
|
70 |
+
|
71 |
+
def __init__(self, message: str = None, data: dict = None):
|
72 |
+
if not message:
|
73 |
+
message = 'Timestamps are not in ascending order. If data is produced by Stable-ts, please submit an issue.'
|
74 |
+
super().__init__(message)
|
75 |
+
self.data = data
|
76 |
+
|
77 |
+
def get_data(self):
|
78 |
+
return self.data
|
stable_whisper/video_output.py
ADDED
@@ -0,0 +1,111 @@
1 |
+
import os
|
2 |
+
import subprocess as sp
|
3 |
+
import warnings
|
4 |
+
from typing import List
|
5 |
+
|
6 |
+
__all__ = ['encode_video_comparison']
|
7 |
+
|
8 |
+
|
9 |
+
def encode_video_comparison(
|
10 |
+
audiofile: str,
|
11 |
+
subtitle_files: List[str],
|
12 |
+
output_videopath: str = None,
|
13 |
+
*,
|
14 |
+
labels: List[str] = None,
|
15 |
+
height: int = 90,
|
16 |
+
width: int = 720,
|
17 |
+
color: str = 'black',
|
18 |
+
fontsize: int = 70,
|
19 |
+
border_color: str = 'white',
|
20 |
+
label_color: str = 'white',
|
21 |
+
label_size: int = 14,
|
22 |
+
fps: int = 25,
|
23 |
+
video_codec: str = None,
|
24 |
+
audio_codec: str = None,
|
25 |
+
overwrite=False,
|
26 |
+
only_cmd: bool = False,
|
27 |
+
verbose=True
|
28 |
+
) -> (str, None):
|
29 |
+
"""
|
30 |
+
Encode multiple subtitle files into one video with the subtitles vertically stacked.
|
31 |
+
|
32 |
+
Parameters
|
33 |
+
----------
|
34 |
+
audiofile : str
|
35 |
+
Path of audio file.
|
36 |
+
subtitle_files : list of str
|
37 |
+
List of paths for subtitle file.
|
38 |
+
output_videopath : str, optional
|
39 |
+
Output video path.
|
40 |
+
labels : list of str, default None, meaning use ``subtitle_files`` as labels
|
41 |
+
List of labels for ``subtitle_files``.
|
42 |
+
height : int, default 90
|
43 |
+
Height for each subtitle section.
|
44 |
+
width : int, default 720
|
45 |
+
Width for each subtitle section.
|
46 |
+
color : str, default 'black'
|
47 |
+
Background color of the video.
|
48 |
+
fontsize: int, default 70
|
49 |
+
Font size for subtitles.
|
50 |
+
border_color : str, default 'white'
|
51 |
+
Border color for separating the sections of subtitle.
|
52 |
+
label_color : str, default 'white'
|
53 |
+
Color of labels.
|
54 |
+
label_size : int, default 14
|
55 |
+
Font size of labels.
|
56 |
+
fps : int, default 25
|
57 |
+
Frame-rate of the video.
|
58 |
+
video_codec : str, optional
|
59 |
+
Video codec of the video.
|
60 |
+
audio_codec : str, optional
|
61 |
+
Audio codec of the video.
|
62 |
+
overwrite : bool, default False
|
63 |
+
Whether to overwrite existing video files with the same path as the output video.
|
64 |
+
only_cmd : bool, default False
|
65 |
+
Whether to skip encoding and only return the full command generated from the specified options.
|
66 |
+
verbose : bool, default True
|
67 |
+
Whether to display ffmpeg processing info.
|
68 |
+
|
69 |
+
Returns
|
70 |
+
-------
|
71 |
+
str or None
|
72 |
+
Encoding command as a string if ``only_cmd = True``.
|
73 |
+
"""
|
74 |
+
vc = '' if video_codec is None else f' -c:v {video_codec}'
|
75 |
+
ac = '' if audio_codec is None else f' -c:a {audio_codec}'
|
76 |
+
background = f'-f lavfi -i color=size={width}x{height}:rate={fps}:color={color}'
|
77 |
+
border = f'-f lavfi -i color=size={width}x3:rate={fps}:color={border_color}'
|
78 |
+
audio = f'-i "{audiofile}"'
|
79 |
+
cfilters0 = []
|
80 |
+
assert labels is None or len(labels) == len(subtitle_files)
|
81 |
+
for i, sub in enumerate(subtitle_files):
|
82 |
+
label = sub if labels is None else labels[i]
|
83 |
+
label = label.replace("'", '"')
|
84 |
+
fil = f"[0]drawtext=text='{label}':fontcolor={label_color}:fontsize={label_size}:x=10:y=10[a{i}]," \
|
85 |
+
f"[a{i}]subtitles='{sub}':force_style='Fontsize={fontsize}'[b{i}]"
|
86 |
+
cfilters0.append(fil)
|
87 |
+
cfilters1 = (
|
88 |
+
'[1]'.join(
|
89 |
+
f'[b{i}]' for i in range(len(cfilters0))
|
90 |
+
)
|
91 |
+
+
|
92 |
+
f'vstack=inputs={len(cfilters0) * 2 - 1}'
|
93 |
+
)
|
94 |
+
final_fil = ','.join(cfilters0) + f';{cfilters1}'
|
95 |
+
ow = '-y' if overwrite else '-n'
|
96 |
+
if output_videopath is None:
|
97 |
+
name = os.path.split(os.path.splitext(audiofile)[0])[1]
|
98 |
+
output_videopath = f'{name}_sub_comparison.mp4'
|
99 |
+
cmd = (f'ffmpeg {ow} {background} {border} {audio} '
|
100 |
+
f'-filter_complex "{final_fil}"{vc}{ac} -shortest "{output_videopath}"')
|
101 |
+
if only_cmd:
|
102 |
+
return cmd
|
103 |
+
if verbose:
|
104 |
+
print(cmd)
|
105 |
+
rc = sp.run(cmd, capture_output=not verbose).returncode
|
106 |
+
if rc == 0:
|
107 |
+
if verbose:
|
108 |
+
print(f'Encoded: {output_videopath}')
|
109 |
+
else:
|
110 |
+
warnings.warn(f'Failed to encode {output_videopath}')
|
111 |
+
|
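A hedged usage sketch for encode_video_comparison; with only_cmd=True it returns the ffmpeg command instead of running it, and all file names below are placeholders:

from stable_whisper.video_output import encode_video_comparison

cmd = encode_video_comparison(
    'audio.mp3',
    ['audio_word_level.srt', 'audio_segment_level.srt'],
    labels=['word-level', 'segment-level'],
    only_cmd=True,
)
print(cmd)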
stable_whisper/whisper_compatibility.py
ADDED
@@ -0,0 +1,73 @@
1 |
+
import warnings
|
2 |
+
import importlib.metadata
|
3 |
+
|
4 |
+
import whisper.tokenizer
|
5 |
+
|
6 |
+
from .utils import get_func_parameters
|
7 |
+
|
8 |
+
_COMPATIBLE_WHISPER_VERSIONS = (
|
9 |
+
'20230314',
|
10 |
+
'20230918',
|
11 |
+
'20231105',
|
12 |
+
'20231106',
|
13 |
+
'20231117',
|
14 |
+
)
|
15 |
+
_required_whisper_ver = _COMPATIBLE_WHISPER_VERSIONS[-1]
|
16 |
+
|
17 |
+
_TOKENIZER_PARAMS = get_func_parameters(whisper.tokenizer.get_tokenizer)
|
18 |
+
|
19 |
+
|
20 |
+
def warn_compatibility_issues(
|
21 |
+
whisper_module,
|
22 |
+
ignore: bool = False,
|
23 |
+
additional_msg: str = ''
|
24 |
+
):
|
25 |
+
compatibility_warning = ''
|
26 |
+
if not ignore:
|
27 |
+
if whisper_module.__version__ not in _COMPATIBLE_WHISPER_VERSIONS:
|
28 |
+
compatibility_warning += (f'Whisper {whisper_module.__version__} is installed. '
|
29 |
+
f'Versions confirmed to be compatible: {", ".join(_COMPATIBLE_WHISPER_VERSIONS)}\n')
|
30 |
+
_is_whisper_repo_version = bool(importlib.metadata.distribution('openai-whisper').read_text('direct_url.json'))
|
31 |
+
if _is_whisper_repo_version:
|
32 |
+
compatibility_warning += ('The detected version appears to be installed from the repository '
|
33 |
+
'which can have compatibility issues '
|
34 |
+
'due to multiple commits sharing the same version number. '
|
35 |
+
f'It is recommended to install version {_required_whisper_ver} from PyPI.\n')
|
36 |
+
|
37 |
+
if compatibility_warning:
|
38 |
+
compatibility_warning = (
|
39 |
+
'The installed version of Whisper might be incompatible.\n'
|
40 |
+
+ compatibility_warning +
|
41 |
+
'To prevent errors and performance issues, reinstall the correct version with: '
|
42 |
+
f'"pip install --upgrade --no-deps --force-reinstall openai-whisper=={_required_whisper_ver}".'
|
43 |
+
)
|
44 |
+
if additional_msg:
|
45 |
+
compatibility_warning += f' {additional_msg}'
|
46 |
+
warnings.warn(compatibility_warning)
|
47 |
+
|
48 |
+
|
49 |
+
def get_tokenizer(model=None, is_faster_model: bool = False, **kwargs):
|
50 |
+
"""
|
51 |
+
Backward compatible wrapper of :func:`whisper.tokenizer.get_tokenizer` and
|
52 |
+
:class:`faster_whisper.tokenizer.Tokenizer`.
|
53 |
+
"""
|
54 |
+
if is_faster_model:
|
55 |
+
import faster_whisper.tokenizer
|
56 |
+
tokenizer = faster_whisper.tokenizer.Tokenizer
|
57 |
+
params = get_func_parameters(tokenizer)
|
58 |
+
if model is not None and 'tokenizer' not in kwargs:
|
59 |
+
kwargs['tokenizer'] = model.hf_tokenizer
|
60 |
+
else:
|
61 |
+
tokenizer = whisper.tokenizer.get_tokenizer
|
62 |
+
params = _TOKENIZER_PARAMS
|
63 |
+
if model is not None and 'multilingual' not in kwargs:
|
64 |
+
kwargs['multilingual'] = \
|
65 |
+
(model.is_multilingual if hasattr(model, 'is_multilingual') else model.model.is_multilingual)
|
66 |
+
if 'num_languages' in params:
|
67 |
+
if hasattr(model, 'num_languages'):
|
68 |
+
kwargs['num_languages'] = \
|
69 |
+
(model.num_languages if hasattr(model, 'num_languages') else model.model.num_languages)
|
70 |
+
elif 'num_languages' in kwargs:
|
71 |
+
del kwargs['num_languages']
|
72 |
+
return tokenizer(**kwargs)
|
73 |
+
|
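A minimal sketch of how the two helpers above are typically used together, assuming an already-loaded vanilla Whisper model; this is a hypothetical usage example, not part of the file:

# Hypothetical usage sketch of warn_compatibility_issues() and get_tokenizer() (not part of the file above).
import whisper
from stable_whisper.whisper_compatibility import warn_compatibility_issues, get_tokenizer

model = whisper.load_model('base')
# Warn if the installed openai-whisper version is not in _COMPATIBLE_WHISPER_VERSIONS.
warn_compatibility_issues(whisper, ignore=False)
# Extra kwargs are forwarded to whisper.tokenizer.get_tokenizer();
# 'multilingual' is filled in automatically from the model.
tokenizer = get_tokenizer(model, language='en', task='transcribe')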
stable_whisper/whisper_word_level.py
ADDED
@@ -0,0 +1,1651 @@
import warnings
import torch
import numpy as np
from typing import TYPE_CHECKING, List, Optional, Tuple, Union, Callable
from types import MethodType
from tqdm import tqdm

import whisper
from whisper.audio import (
    SAMPLE_RATE, N_FRAMES, HOP_LENGTH, N_SAMPLES, N_SAMPLES_PER_TOKEN, TOKENS_PER_SECOND, FRAMES_PER_SECOND, N_FFT,
    pad_or_trim, log_mel_spectrogram
)
from whisper.utils import exact_div
from whisper.tokenizer import LANGUAGES, TO_LANGUAGE_CODE
from whisper.decoding import DecodingOptions, DecodingResult

from .audio import prep_audio
from .decode import decode_stable
from .result import WhisperResult, Segment
from .timing import add_word_timestamps_stable
from .stabilization import get_vad_silence_func, wav2mask, mask2timing, timing2mask
from .non_whisper import transcribe_any
from .utils import isolate_useful_options, safe_print
from .whisper_compatibility import warn_compatibility_issues, get_tokenizer

if TYPE_CHECKING:
    from whisper.model import Whisper

__all__ = ['modify_model', 'load_model', 'load_faster_whisper']

warnings.filterwarnings('ignore', module='whisper', message='.*Triton.*', category=UserWarning)


# modified version of whisper.transcribe.transcribe
def transcribe_stable(
        model: "Whisper",
        audio: Union[str, np.ndarray, torch.Tensor, bytes],
        *,
        verbose: Optional[bool] = False,
        temperature: Union[float, Tuple[float, ...]] = (0.0, 0.2, 0.4, 0.6, 0.8, 1.0),
        compression_ratio_threshold: Optional[float] = 2.4,
        logprob_threshold: Optional[float] = -1.0,
        no_speech_threshold: Optional[float] = 0.6,
        condition_on_previous_text: bool = True,
        initial_prompt: Optional[str] = None,
        word_timestamps: bool = True,
        regroup: Union[bool, str] = True,
        ts_num: int = 0,
        ts_noise: float = 0.1,
        suppress_silence: bool = True,
        suppress_word_ts: bool = True,
        use_word_position: bool = True,
        q_levels: int = 20,
        k_size: int = 5,
        time_scale: float = None,
        demucs: Union[bool, torch.nn.Module] = False,
        demucs_output: str = None,
        demucs_options: dict = None,
        vad: bool = False,
        vad_threshold: float = 0.35,
        vad_onnx: bool = False,
        min_word_dur: float = 0.1,
        nonspeech_error: float = 0.3,
        only_voice_freq: bool = False,
        prepend_punctuations: str = "\"'“¿([{-",
        append_punctuations: str = "\"'.。,,!!??::”)]}、",
        mel_first: bool = False,
        split_callback: Callable = None,
        suppress_ts_tokens: bool = False,
        gap_padding: str = ' ...',
        only_ffmpeg: bool = False,
        max_instant_words: float = 0.5,
        avg_prob_threshold: Optional[float] = None,
        progress_callback: Callable = None,
        ignore_compatibility: bool = False,
        **decode_options) \
        -> WhisperResult:
    """
    Transcribe audio using Whisper.

    This is a modified version of :func:`whisper.transcribe.transcribe` with slightly different decoding logic while
    allowing additional preprocessing and postprocessing. The preprocessing performed on the audio includes: isolating
    voice / removing noise with Demucs and low/high-pass filter. The postprocessing performed on the transcription
    result includes: adjusting timestamps with VAD and custom regrouping of segments based on punctuation and speech
    gaps.

    Parameters
    ----------
    model : whisper.model.Whisper
        An instance of Whisper ASR model.
    audio : str or numpy.ndarray or torch.Tensor or bytes
        Path/URL to the audio file, the audio waveform, or bytes of audio file.
        If audio is :class:`numpy.ndarray` or :class:`torch.Tensor`, the audio must already be sampled at 16kHz.
    verbose : bool or None, default False
        Whether to display the text being decoded to the console.
        Displays all the details if ``True``. Displays progressbar if ``False``. Displays nothing if ``None``.
    temperature : float or iterable of float, default (0.0, 0.2, 0.4, 0.6, 0.8, 1.0)
        Temperature for sampling. It can be a tuple of temperatures, which will be successively used
        upon failures according to either ``compression_ratio_threshold`` or ``logprob_threshold``.
    compression_ratio_threshold : float, default 2.4
        If the gzip compression ratio is above this value, treat as failed.
    logprob_threshold : float, default -1
        If the average log probability over sampled tokens is below this value, treat as failed.
    no_speech_threshold : float, default 0.6
        If the no_speech probability is higher than this value AND the average log probability
        over sampled tokens is below ``logprob_threshold``, consider the segment as silent.
    condition_on_previous_text : bool, default True
        If ``True``, the previous output of the model is provided as a prompt for the next window;
        disabling may make the text inconsistent across windows, but the model becomes less prone to
        getting stuck in a failure loop, such as repetition looping or timestamps going out of sync.
    initial_prompt : str, optional
        Text to provide as a prompt for the first window. This can be used to provide, or
        "prompt-engineer" a context for transcription, e.g. custom vocabularies or proper nouns
        to make it more likely to predict those words correctly.
    word_timestamps : bool, default True
        Extract word-level timestamps using the cross-attention pattern and dynamic time warping,
        and include the timestamps for each word in each segment.
        Disabling this will prevent segments from splitting/merging properly.
    regroup : bool or str, default True, meaning the default regroup algorithm
        String for customizing the regrouping algorithm. False disables regrouping.
        Ignored if ``word_timestamps = False``.
    ts_num : int, default 0, meaning disable this option
        Number of extra timestamp inferences to perform then use average of these extra timestamps.
        An experimental option that might hurt performance.
    ts_noise : float, default 0.1
        Percentage of noise to add to audio_features to perform inferences for ``ts_num``.
    suppress_silence : bool, default True
        Whether to enable timestamps adjustments based on the detected silence.
    suppress_word_ts : bool, default True
        Whether to adjust word timestamps based on the detected silence. Only enabled if ``suppress_silence = True``.
    use_word_position : bool, default True
        Whether to use position of the word in its segment to determine whether to keep end or start timestamps if
        adjustments are required. If it is the first word, keep end. Else if it is the last word, keep the start.
    q_levels : int, default 20
        Quantization levels for generating timestamp suppression mask; ignored if ``vad = True``.
        Acts as a threshold for marking sound as silent.
        Fewer levels will increase the threshold of volume at which to mark a sound as silent.
    k_size : int, default 5
        Kernel size for avg-pooling waveform to generate timestamp suppression mask; ignored if ``vad = True``.
        Recommend 5 or 3; higher sizes will reduce detection of silence.
    time_scale : float, optional
        Factor for scaling audio duration for inference.
        Greater than 1.0 'slows down' the audio, and less than 1.0 'speeds up' the audio. None is same as 1.0.
        A factor of 1.5 will stretch 10s audio to 15s for inference. This increases the effective resolution
        of the model but can increase word error rate.
    demucs : bool or torch.nn.Module, default False
        Whether to preprocess ``audio`` with Demucs to isolate vocals / remove noise. Set ``demucs`` to an instance of
        a Demucs model to avoid reloading the model for each run.
        Demucs must be installed to use. Official repo: https://github.com/facebookresearch/demucs.
    demucs_output : str, optional
        Path to save the vocals isolated by Demucs as WAV file. Ignored if ``demucs = False``.
        Demucs must be installed to use. Official repo: https://github.com/facebookresearch/demucs.
    demucs_options : dict, optional
        Options to use for :func:`stable_whisper.audio.demucs_audio`.
    vad : bool, default False
        Whether to use Silero VAD to generate timestamp suppression mask.
        Silero VAD requires PyTorch 1.12.0+. Official repo: https://github.com/snakers4/silero-vad.
    vad_threshold : float, default 0.35
        Threshold for detecting speech with Silero VAD. Low threshold reduces false positives for silence detection.
    vad_onnx : bool, default False
        Whether to use ONNX for Silero VAD.
    min_word_dur : float, default 0.1
        Shortest duration each word is allowed to reach for silence suppression.
    nonspeech_error : float, default 0.3
        Relative error of non-speech sections that appear in between a word for silence suppression.
    only_voice_freq : bool, default False
        Whether to only use sound between 200 - 5000 Hz, where the majority of human speech is.
    prepend_punctuations : str, default '"\'“¿([{-)'
        Punctuations to prepend to next word.
    append_punctuations : str, default '.。,,!!??::”)]}、)'
        Punctuations to append to previous word.
    mel_first : bool, default False
        Process entire audio track into log-Mel spectrogram first instead of in chunks.
        Used if odd behavior is seen in stable-ts but not in whisper, but uses significantly more memory for long audio.
    split_callback : Callable, optional
        Custom callback for grouping tokens up with their corresponding words.
        The callback must take two arguments, list of tokens and tokenizer.
        The callback returns a tuple with a list of words and a corresponding nested list of tokens.
    suppress_ts_tokens : bool, default False
        Whether to suppress timestamp tokens during inference for timestamps detected as silent.
        Reduces hallucinations in some cases, but also prone to ignore disfluencies and repetitions.
        This option is ignored if ``suppress_silence = False``.
    gap_padding : str, default ' ...'
        Padding prepended to each segment for word timing alignment.
        Used to reduce the probability of model predicting timestamps earlier than the first utterance.
    only_ffmpeg : bool, default False
        Whether to use only FFmpeg (instead of yt-dlp) for URLs.
    max_instant_words : float, default 0.5
        If the percentage of instantaneous words in a segment exceeds this amount, the segment is removed.
    avg_prob_threshold: float or None, default None
        Transcribe the gap after the previous word and if the average word probability of a segment falls below this
        value, discard the segment. If ``None``, skip transcribing the gap to reduce chance of timestamps starting
        before the next utterance.
    progress_callback : Callable, optional
        A function that will be called when transcription progress is updated.
        The callback needs two parameters.
        The first parameter is a float for seconds of the audio that has been transcribed.
        The second parameter is a float for total duration of audio in seconds.
    ignore_compatibility : bool, default False
        Whether to ignore warnings for compatibility issues with the detected Whisper version.
    decode_options
        Keyword arguments to construct :class:`whisper.decoding.DecodingOptions` instances.

    Returns
    -------
    stable_whisper.result.WhisperResult
        All timestamps, words, probabilities, and other data from the transcription of ``audio``.

    See Also
    --------
    stable_whisper.non_whisper.transcribe_any : Return :class:`stable_whisper.result.WhisperResult` containing all the
        data from transcribing audio with unmodified :func:`whisper.transcribe.transcribe` with preprocessing and
        postprocessing.
    stable_whisper.whisper_word_level.load_faster_whisper.faster_transcribe : Return
        :class:`stable_whisper.result.WhisperResult` containing all the data from transcribing audio with
        :meth:`faster_whisper.WhisperModel.transcribe` with preprocessing and postprocessing.

    Examples
    --------
    >>> import stable_whisper
    >>> model = stable_whisper.load_model('base')
    >>> result = model.transcribe('audio.mp3', vad=True)
    >>> result.to_srt_vtt('audio.srt')
    Saved: audio.srt
    """
    warn_compatibility_issues(whisper, ignore_compatibility, 'Or use transcribe_minimal().')
    dtype = torch.float16 if decode_options.get("fp16", True) and not getattr(model, 'dq', False) else torch.float32
    if model.device == torch.device("cpu"):
        if torch.cuda.is_available():
            warnings.warn("Performing inference on CPU when CUDA is available")
        if dtype == torch.float16:
            warnings.warn("FP16 is not supported on CPU; using FP32 instead")
            dtype = torch.float32

    if dtype == torch.float32:
        decode_options["fp16"] = False

    if 'max_initial_timestamp' not in decode_options:
        decode_options['max_initial_timestamp'] = None

    device = model.device

    if time_scale:
        warnings.warn('``time_scale`` is deprecated. It will not affect results.',
                      DeprecationWarning, stacklevel=2)
    if decode_options.pop('input_sr', None):
        warnings.warn('``input_sr`` is deprecated. '
                      '``audio`` of types numpy.ndarray and torch.Tensor must already be at 16kHz. '
                      'For higher sample rates, pass ``audio`` as str or bytes.',
                      DeprecationWarning, stacklevel=2)
    if not demucs_options:
        demucs_options = {}
    if demucs_output:
        if 'save_path' not in demucs_options:
            demucs_options['save_path'] = demucs_output
        warnings.warn('``demucs_output`` is deprecated. Use ``demucs_options`` with ``save_path`` instead. '
                      'E.g. demucs_options=dict(save_path="demucs_output.mp3")',
                      DeprecationWarning, stacklevel=2)
    if 'device' not in demucs_options:
        demucs_options['device'] = device
    audio = prep_audio(
        audio,
        demucs=demucs,
        demucs_options=demucs_options,
        only_voice_freq=only_voice_freq,
        only_ffmpeg=only_ffmpeg,
        verbose=verbose
    )
    sample_padding = int(N_FFT // 2) + 1
    whole_mel = log_mel_spectrogram(audio, model.dims.n_mels, padding=sample_padding) if mel_first else None
    tokenizer = None
    language = None
    initial_prompt_tokens = []
    task = decode_options.get("task", "transcribe")

    def detect_language():
        nonlocal tokenizer
        if tokenizer is None:
            if decode_options.get("language", None) is None and model:
                if not model.is_multilingual:
                    decode_options["language"] = "en"
                else:
                    if verbose:
                        print("Detecting language using up to 30 seconds following first non-silent sample. "
                              "Use `--language` to specify the language")
                    timing_mask = None
                    if segment_silence_timing is not None:
                        timing_mask = np.logical_and(
                            segment_silence_timing[0] <= time_offset,
                            segment_silence_timing[1] >= time_offset
                        )
                    start_sample = (
                        None
                        if segment_silence_timing is None or not timing_mask.any() else
                        round(segment_silence_timing[1][timing_mask.nonzero()[0]][0] * SAMPLE_RATE)
                    )
                    if start_sample is None:
                        nonlocal mel_segment
                        curr_mel_segment = mel_segment
                    else:
                        if whole_mel is None:
                            curr_mel_segment = log_mel_spectrogram(
                                audio[..., start_sample:start_sample+N_SAMPLES],
                                model.dims.n_mels,
                                padding=sample_padding
                            )
                        else:
                            start_frame = int(start_sample/HOP_LENGTH)
                            curr_mel_segment = whole_mel[..., start_frame:start_frame+N_FRAMES]
                        curr_mel_segment = pad_or_trim(curr_mel_segment, N_FRAMES).to(device=device, dtype=dtype)
                    _, probs = model.detect_language(curr_mel_segment)
                    decode_options["language"] = max(probs, key=probs.get)
                    if verbose is not None:
                        detected_msg = f"Detected language: {LANGUAGES[decode_options['language']]}"
                        if tqdm_pbar.disable:
                            print(detected_msg)
                        else:
                            tqdm_pbar.write(detected_msg)

            nonlocal language
            language = decode_options["language"]
            tokenizer = get_tokenizer(model, language=language, task=task)

            if word_timestamps and task == "translate":
                warnings.warn("Word-level timestamps on translations may not be reliable.")

            if initial_prompt is not None:
                nonlocal initial_prompt_tokens
                initial_prompt_tokens = tokenizer.encode(" " + initial_prompt.strip())
                all_tokens.extend(initial_prompt_tokens)

    audio_features = None

    def decode_with_fallback(seg: torch.Tensor,
                             ts_token_mask: torch.Tensor = None) \
            -> DecodingResult:
        nonlocal audio_features
        temperatures = [temperature] if isinstance(temperature, (int, float)) else temperature
        decode_result = None

        for t in temperatures:
            kwargs = {**decode_options}
            if t > 0:
                # disable beam_size and patience when t > 0
                kwargs.pop("beam_size", None)
                kwargs.pop("patience", None)
            else:
                # disable best_of when t == 0
                kwargs.pop("best_of", None)

            options = DecodingOptions(**kwargs, temperature=t)
            decode_result, audio_features = decode_stable(model,
                                                          seg,
                                                          options,
                                                          ts_token_mask=ts_token_mask if suppress_ts_tokens else None,
                                                          audio_features=audio_features)

            needs_fallback = False
            if (
                    compression_ratio_threshold is not None
                    and decode_result.compression_ratio > compression_ratio_threshold
            ):
                needs_fallback = True  # too repetitive
            if (
                    logprob_threshold is not None
                    and decode_result.avg_logprob < logprob_threshold
            ):
                needs_fallback = True  # average log probability is too low
            if (
                    no_speech_threshold is not None
                    and decode_result.no_speech_prob > no_speech_threshold
            ):
                needs_fallback = False  # silence

            if not needs_fallback:
                break

        return decode_result

    seek_sample = 0  # samples
    input_stride = exact_div(
        N_FRAMES, model.dims.n_audio_ctx
    )  # mel frames per output token: 2
    time_precision = (
        input_stride * HOP_LENGTH / SAMPLE_RATE
    )  # time per output token: 0.02 (seconds)
    all_tokens = []
    all_segments = []
    prompt_reset_since = 0

    def new_segment(
            *, start: float, end: float, tokens: torch.Tensor, result: DecodingResult
    ):
        tokens = tokens.tolist()
        text_tokens = [token for token in tokens if token < tokenizer.eot]
        return {
            "seek": round(seek_sample / SAMPLE_RATE, 3),  # units in seconds
            "start": start,
            "end": end,
            "text": tokenizer.decode(text_tokens),
            "tokens": tokens,
            "temperature": result.temperature,
            "avg_logprob": result.avg_logprob,
            "compression_ratio": result.compression_ratio,
            "no_speech_prob": result.no_speech_prob,
        }

    punctuations = prepend_punctuations + append_punctuations

    total_samples = audio.shape[-1]
    total_duration = round(total_samples / SAMPLE_RATE, 2)
    n_samples_per_frame = exact_div(N_SAMPLES_PER_TOKEN * TOKENS_PER_SECOND, FRAMES_PER_SECOND)

    silent_timings = [[], []]
    silence_timing = None
    if suppress_silence and vad:
        silence_timing = get_vad_silence_func(onnx=vad_onnx, verbose=verbose)(audio, speech_threshold=vad_threshold)

    with tqdm(total=total_duration, unit='sec', disable=verbose is not False, desc=task.title()) as tqdm_pbar:

        def update_pbar():
            nonlocal audio_features
            audio_features = None
            seek_duration = min(total_duration, round(seek_sample / SAMPLE_RATE, 2))
            if not tqdm_pbar.disable:
                tqdm_pbar.update(seek_duration - tqdm_pbar.n)
            if progress_callback is not None:
                progress_callback(seek=seek_duration, total=total_duration)

        def update_seek():
            nonlocal seek_sample
            seek_sample += segment_samples

        def fast_forward():
            # fast-forward to the next segment boundary
            update_seek()
            update_pbar()

        while seek_sample < audio.shape[-1]:
            seek_sample_end = seek_sample + N_SAMPLES
            audio_segment = audio[seek_sample:seek_sample_end]
            time_offset = seek_sample / SAMPLE_RATE
            segment_samples = audio_segment.shape[-1]
            segment_duration = segment_samples / SAMPLE_RATE

            mel_segment = (
                log_mel_spectrogram(audio_segment, model.dims.n_mels, padding=sample_padding)
                if whole_mel is None else
                whole_mel[..., round(seek_sample / n_samples_per_frame): round(seek_sample_end / n_samples_per_frame)]
            )

            mel_segment = pad_or_trim(mel_segment, N_FRAMES).to(device=model.device, dtype=dtype)

            segment_silence_timing = None
            ts_token_mask = None
            if suppress_silence:
                if silence_timing is None:
                    ts_token_mask = wav2mask(audio_segment, q_levels=q_levels, k_size=k_size)
                    segment_silence_timing = mask2timing(ts_token_mask, time_offset=time_offset)
                else:
                    timing_indices = np.logical_and(
                        silence_timing[1] > time_offset,
                        silence_timing[0] < time_offset + segment_duration
                    )
                    segment_silence_timing = (silence_timing[0][timing_indices], silence_timing[1][timing_indices])

                    ts_token_mask = timing2mask(*segment_silence_timing, size=1501, time_offset=time_offset)

                    if mn := timing_indices.argmax():
                        silence_timing = (silence_timing[0][mn:], silence_timing[1][mn:])

                if ts_token_mask is not None:
                    if ts_token_mask.all():  # segment is silent
                        fast_forward()
                        continue
                    ts_token_mask = pad_or_trim(ts_token_mask, 1501)

            detect_language()
            decode_options["prompt"] = all_tokens[prompt_reset_since:]
            result: DecodingResult = decode_with_fallback(mel_segment, ts_token_mask=ts_token_mask)
            tokens = torch.tensor(result.tokens)

            if no_speech_threshold is not None:
                # no voice activity check
                should_skip = result.no_speech_prob > no_speech_threshold
                if logprob_threshold is not None and result.avg_logprob > logprob_threshold:
                    # don't skip if the logprob is high enough, despite the no_speech_prob
                    should_skip = False

                if should_skip:
                    fast_forward()
                    continue

            current_segments = []

            timestamp_tokens: torch.Tensor = tokens.ge(tokenizer.timestamp_begin)
            single_timestamp_ending = timestamp_tokens[-2:].tolist() == [False, True]

            consecutive = torch.where(timestamp_tokens[:-1] & timestamp_tokens[1:])[0]
            consecutive.add_(1)
            if len(consecutive) > 0:
                # if the output contains two consecutive timestamp tokens
                slices = consecutive.tolist()
                if single_timestamp_ending:
                    slices.append(len(tokens))

                last_slice = 0
                for current_slice in slices:
                    sliced_tokens = tokens[last_slice:current_slice]
                    start_timestamp_pos = (
                        sliced_tokens[0].item() - tokenizer.timestamp_begin
                    )
                    end_timestamp_pos = (
                        sliced_tokens[-1].item() - tokenizer.timestamp_begin
                    )
                    current_segments.append(
                        new_segment(
                            start=round(time_offset + start_timestamp_pos * time_precision, 3),
                            end=round(time_offset + min(end_timestamp_pos * time_precision, segment_duration), 3),
                            tokens=sliced_tokens,
                            result=result,
                        )
                    )
                    last_slice = current_slice

            else:
                duration = segment_duration
                timestamps = tokens[timestamp_tokens.nonzero().flatten()]
                if (
                        len(timestamps) > 0
                        and timestamps[-1].item() != tokenizer.timestamp_begin
                ):
                    # no consecutive timestamps but it has a timestamp; use the last one.
                    end_timestamp_pos = (
                        timestamps[-1].item() - tokenizer.timestamp_begin
                    )
                    duration = min(end_timestamp_pos * time_precision, segment_duration)
                else:
                    end_timestamp_pos = 0

                current_segments.append(
                    new_segment(
                        start=round(time_offset, 3),
                        end=round(time_offset + duration, 3),
                        tokens=tokens,
                        result=result,
                    )
                )

            # if a segment is instantaneous or does not contain text, remove it
            for i in reversed(range(len(current_segments))):
                seg = current_segments[i]
                if seg["start"] == seg["end"] or seg["text"].strip() in punctuations:
                    del current_segments[i]

            num_samples = (
                min(round(end_timestamp_pos * N_SAMPLES_PER_TOKEN), segment_samples)
                if end_timestamp_pos > 0 else
                segment_samples
            )

            if word_timestamps:
                add_word_timestamps_stable(
                    segments=current_segments,
                    model=model,
                    tokenizer=tokenizer,
                    mel=mel_segment,
                    num_samples=num_samples,
                    prepend_punctuations=prepend_punctuations,
                    append_punctuations=append_punctuations,
                    audio_features=audio_features,
                    ts_num=ts_num,
                    ts_noise=ts_noise,
                    split_callback=split_callback,
                    gap_padding=gap_padding
                )

                # if [max_instant_words] of the words in a segment are instantaneous, remove it
                for i in reversed(range(len(current_segments))):
                    zero_duration_percent = (
                        np.array(
                            [w['start'] == w['end'] for w in current_segments[i]['words']]
                        )
                        .astype(np.float16)
                        .mean()
                    )
                    if zero_duration_percent > max_instant_words:
                        del current_segments[i]

                if avg_prob_threshold and current_segments:
                    if (
                            single_timestamp_ending and
                            (np.mean([w['probability'] for s in current_segments for w in s['words']]) <
                             avg_prob_threshold)
                    ):
                        num_samples = segment_samples
                        current_segments = []
                    else:
                        num_samples = round((current_segments[-1]['words'][-1]['end']-time_offset) * SAMPLE_RATE)

            if len(current_segments) == 0:
                fast_forward()
                continue

            if segment_silence_timing is not None:
                silent_timings[0].extend(segment_silence_timing[0])
                silent_timings[1].extend(segment_silence_timing[1])
                for seg_i, segment in enumerate(current_segments):
                    segment = Segment(**segment).suppress_silence(
                        *segment_silence_timing,
                        min_word_dur=min_word_dur,
                        word_level=suppress_word_ts,
                        nonspeech_error=nonspeech_error,
                        use_word_position=use_word_position,
                    )
                    if verbose:
                        safe_print(segment.to_display_str())
                    current_segments[seg_i] = segment.to_dict()

            all_segments.extend(
                [
                    {"id": i, **segment}
                    for i, segment in enumerate(current_segments, start=len(all_segments))
                ]
            )
            all_tokens.extend(
                [token for segment in current_segments for token in segment["tokens"]]
            )
            if not single_timestamp_ending or avg_prob_threshold:
                segment_samples = num_samples

            if not condition_on_previous_text or result.temperature > 0.5:
                # do not feed the prompt tokens if a high temperature was used
                prompt_reset_since = len(all_tokens)

            fast_forward()

        # final update
        update_pbar()

    if model.device != torch.device('cpu'):
        torch.cuda.empty_cache()

    text = '' if tokenizer is None else tokenizer.decode(all_tokens[len(initial_prompt_tokens):])
    final_result = WhisperResult(dict(text=text,
                                      segments=all_segments,
                                      language=language,
                                      time_scale=time_scale))
    if word_timestamps and regroup:
        final_result.regroup(regroup)

    if time_scale is not None:
        final_result.rescale_time(1 / time_scale)

    if len(final_result.text) == 0:
        warnings.warn(f'Failed to {task} audio. Result contains no text. ')

    final_result.update_nonspeech_sections(*silent_timings)

    return final_result

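For reference, a short usage sketch of the function above; it assumes the model has been patched with modify_model() (as load_model() below does), and the progress_callback signature follows the keyword arguments seek/total used by update_pbar() above. Hypothetical example, not part of the file:

# Hypothetical usage sketch of transcribe_stable() (not part of the file above).
import stable_whisper

def on_progress(seek, total):
    # seek = seconds transcribed so far, total = total audio duration in seconds
    print(f'{seek:.1f}/{total:.1f}s', end='\r')

model = stable_whisper.load_model('base')
# vad=True swaps the waveform-based silence mask for Silero VAD.
result = model.transcribe('audio.mp3', vad=True, progress_callback=on_progress)
result.to_srt_vtt('audio.srt')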
def transcribe_minimal(
        model: "Whisper",
        audio: Union[str, np.ndarray, torch.Tensor, bytes],
        *,
        verbose: Optional[bool] = False,
        word_timestamps: bool = True,
        regroup: Union[bool, str] = True,
        suppress_silence: bool = True,
        suppress_word_ts: bool = True,
        use_word_position: bool = True,
        q_levels: int = 20,
        k_size: int = 5,
        demucs: bool = False,
        demucs_output: str = None,
        demucs_options: dict = None,
        vad: bool = False,
        vad_threshold: float = 0.35,
        vad_onnx: bool = False,
        min_word_dur: float = 0.1,
        nonspeech_error: float = 0.3,
        only_voice_freq: bool = False,
        only_ffmpeg: bool = False,
        **options) \
        -> WhisperResult:
    """
    Transcribe audio using Whisper.

    This uses the original whisper transcribe function, :func:`whisper.transcribe.transcribe`, while still allowing
    additional preprocessing and postprocessing. The preprocessing performed on the audio includes: isolating voice /
    removing noise with Demucs and low/high-pass filter. The postprocessing performed on the transcription
    result includes: adjusting timestamps with VAD and custom regrouping of segments based on punctuation and speech
    gaps.

    Parameters
    ----------
    model : whisper.model.Whisper
        An instance of Whisper ASR model.
    audio : str or numpy.ndarray or torch.Tensor or bytes
        Path/URL to the audio file, the audio waveform, or bytes of audio file.
        If audio is ``numpy.ndarray`` or ``torch.Tensor``, the audio must already be sampled at 16kHz.
    verbose : bool or None, default False
        Whether to display the text being decoded to the console.
        Displays all the details if ``True``. Displays progressbar if ``False``. Displays nothing if ``None``.
    word_timestamps : bool, default True
        Extract word-level timestamps using the cross-attention pattern and dynamic time warping,
        and include the timestamps for each word in each segment.
        Disabling this will prevent segments from splitting/merging properly.
    regroup : bool or str, default True, meaning the default regroup algorithm
        String for customizing the regrouping algorithm. False disables regrouping.
        Ignored if ``word_timestamps = False``.
    suppress_silence : bool, default True
        Whether to enable timestamps adjustments based on the detected silence.
    suppress_word_ts : bool, default True
        Whether to adjust word timestamps based on the detected silence. Only enabled if ``suppress_silence = True``.
    use_word_position : bool, default True
        Whether to use position of the word in its segment to determine whether to keep end or start timestamps if
        adjustments are required. If it is the first word, keep end. Else if it is the last word, keep the start.
    q_levels : int, default 20
        Quantization levels for generating timestamp suppression mask; ignored if ``vad = True``.
        Acts as a threshold for marking sound as silent.
        Fewer levels will increase the threshold of volume at which to mark a sound as silent.
    k_size : int, default 5
        Kernel size for avg-pooling waveform to generate timestamp suppression mask; ignored if ``vad = True``.
        Recommend 5 or 3; higher sizes will reduce detection of silence.
    demucs : bool or torch.nn.Module, default False
        Whether to preprocess ``audio`` with Demucs to isolate vocals / remove noise. Set ``demucs`` to an instance of
        a Demucs model to avoid reloading the model for each run.
        Demucs must be installed to use. Official repo: https://github.com/facebookresearch/demucs.
    demucs_output : str, optional
        Path to save the vocals isolated by Demucs as WAV file. Ignored if ``demucs = False``.
        Demucs must be installed to use. Official repo: https://github.com/facebookresearch/demucs.
    demucs_options : dict, optional
        Options to use for :func:`stable_whisper.audio.demucs_audio`.
    vad : bool, default False
        Whether to use Silero VAD to generate timestamp suppression mask.
        Silero VAD requires PyTorch 1.12.0+. Official repo: https://github.com/snakers4/silero-vad.
    vad_threshold : float, default 0.35
        Threshold for detecting speech with Silero VAD. Low threshold reduces false positives for silence detection.
    vad_onnx : bool, default False
        Whether to use ONNX for Silero VAD.
    min_word_dur : float, default 0.1
        Shortest duration each word is allowed to reach for silence suppression.
    nonspeech_error : float, default 0.3
        Relative error of non-speech sections that appear in between a word for silence suppression.
    only_voice_freq : bool, default False
        Whether to only use sound between 200 - 5000 Hz, where the majority of human speech is.
    only_ffmpeg : bool, default False
        Whether to use only FFmpeg (instead of yt-dlp) for URLs.
    options
        Additional options used for :func:`whisper.transcribe.transcribe` and
        :func:`stable_whisper.non_whisper.transcribe_any`.

    Returns
    -------
    stable_whisper.result.WhisperResult
        All timestamps, words, probabilities, and other data from the transcription of ``audio``.

    Examples
    --------
    >>> import stable_whisper
    >>> model = stable_whisper.load_model('base')
    >>> result = model.transcribe_minimal('audio.mp3', vad=True)
    >>> result.to_srt_vtt('audio.srt')
    Saved: audio.srt
    """
    inference_kwargs = dict(
        model=model,
        audio=audio,
        word_timestamps=word_timestamps,
        verbose=verbose
    )
    extra_options = isolate_useful_options(options, transcribe_any, True)
    if demucs or only_voice_freq:
        if 'audio_type' not in extra_options:
            extra_options['audio_type'] = 'torch'
        if 'model_sr' not in extra_options:
            extra_options['model_sr'] = SAMPLE_RATE
    inference_kwargs.update(options)
    return transcribe_any(
        inference_func=whisper.transcribe,
        audio=audio,
        inference_kwargs=inference_kwargs,
        verbose=verbose,
        regroup=regroup,
        suppress_silence=suppress_silence,
        suppress_word_ts=suppress_word_ts,
        q_levels=q_levels,
        k_size=k_size,
        demucs=demucs,
        demucs_output=demucs_output,
        demucs_options=demucs_options,
        vad=vad,
        vad_threshold=vad_threshold,
        vad_onnx=vad_onnx,
        min_word_dur=min_word_dur,
        nonspeech_error=nonspeech_error,
        use_word_position=use_word_position,
        only_voice_freq=only_voice_freq,
        only_ffmpeg=only_ffmpeg,
        force_order=True,
        **extra_options
    )

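A short usage sketch of transcribe_minimal(), mirroring its docstring example; the language option shown is a standard whisper decode option that passes through **options, as an illustrative assumption rather than a parameter defined above. Hypothetical example, not part of the file:

# Hypothetical usage sketch of transcribe_minimal() (not part of the file above).
import stable_whisper

model = stable_whisper.load_model('base')
# transcribe_minimal() runs the unmodified whisper.transcribe() under the hood,
# so standard whisper options (e.g. language) pass through **options.
result = model.transcribe_minimal('audio.mp3', vad=True, language='en')
result.to_srt_vtt('audio_minimal.srt')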
def load_faster_whisper(model_size_or_path: str, **model_init_options):
    """
    Load an instance of :class:`faster_whisper.WhisperModel`.

    Parameters
    ----------
    model_size_or_path : {'tiny', 'tiny.en', 'base', 'base.en', 'small', 'small.en', 'medium', 'medium.en', 'large-v1',
        'large-v2', 'large-v3', or 'large'}
        Size of the model.

    model_init_options
        Additional options to use for initialization of :class:`faster_whisper.WhisperModel`.

    Returns
    -------
    faster_whisper.WhisperModel
        A modified instance with :func:`stable_whisper.whisper_word_level.load_faster_whisper.faster_transcribe`
        assigned to :meth:`faster_whisper.WhisperModel.transcribe_stable`.
    """
    from faster_whisper import WhisperModel
    faster_model = WhisperModel(model_size_or_path, **model_init_options)

    def _inner_transcribe(model, audio, verbose, **faster_transcribe_options):
        if isinstance(audio, bytes):
            import io
            audio = io.BytesIO(audio)
        progress_callback = faster_transcribe_options.pop('progress_callback', None)
        segments, info = model.transcribe(audio, **faster_transcribe_options)
        language = LANGUAGES.get(info.language, info.language)
        if verbose is not None:
            print(f'Detected Language: {language}')
            print(f'Transcribing with faster-whisper ({model_size_or_path})...\r', end='')

        final_segments = []
        task = faster_transcribe_options.get('task', 'transcribe').title()
        total_duration = round(info.duration, 2)

        with tqdm(total=total_duration, unit='sec', disable=verbose is not False, desc=task) as tqdm_pbar:

            def update_pbar(seek):
                tqdm_pbar.update(seek - tqdm_pbar.n)
                if progress_callback is not None:
                    progress_callback(seek, total_duration)

            for segment in segments:
                segment = segment._asdict()
                if (words := segment.get('words')) is not None:
                    segment['words'] = [w._asdict() for w in words]
                else:
                    del segment['words']
                if verbose:
                    safe_print(Segment(**segment).to_display_str())
                final_segments.append(segment)
                update_pbar(segment["end"])
            update_pbar(tqdm_pbar.total)

        if verbose:
            print(f'Completed transcription with faster-whisper ({model_size_or_path}).')

        return dict(language=language, segments=final_segments)

    def faster_transcribe(
            model: WhisperModel,
            audio: Union[str, bytes, np.ndarray],
            *,
            word_timestamps: bool = True,
            verbose: Optional[bool] = False,
            regroup: Union[bool, str] = True,
            suppress_silence: bool = True,
            suppress_word_ts: bool = True,
            use_word_position: bool = True,
            q_levels: int = 20,
            k_size: int = 5,
            demucs: bool = False,
            demucs_output: str = None,
            demucs_options: dict = None,
            vad: bool = False,
            vad_threshold: float = 0.35,
            vad_onnx: bool = False,
            min_word_dur: float = 0.1,
            nonspeech_error: float = 0.3,
            only_voice_freq: bool = False,
            only_ffmpeg: bool = False,
            check_sorted: bool = True,
            progress_callback: Callable = None,
            **options
    ) -> WhisperResult:
        """
        Transcribe audio using faster-whisper (https://github.com/guillaumekln/faster-whisper).

        This uses the transcribe method from faster-whisper, :meth:`faster_whisper.WhisperModel.transcribe`, while
        still allowing additional preprocessing and postprocessing. The preprocessing performed on the audio includes:
        isolating voice / removing noise with Demucs and low/high-pass filter. The postprocessing performed on the
        transcription result includes: adjusting timestamps with VAD and custom regrouping of segments based on
        punctuation and speech gaps.

        Parameters
        ----------
        model : faster_whisper.WhisperModel
            The faster-whisper ASR model instance.
        audio : str or numpy.ndarray or torch.Tensor or bytes
            Path/URL to the audio file, the audio waveform, or bytes of audio file.
            If audio is :class:`numpy.ndarray` or :class:`torch.Tensor`, the audio must already be sampled at 16kHz.
        verbose : bool or None, default False
            Whether to display the text being decoded to the console.
            Displays all the details if ``True``. Displays progressbar if ``False``. Displays nothing if ``None``.
        word_timestamps : bool, default True
            Extract word-level timestamps using the cross-attention pattern and dynamic time warping,
            and include the timestamps for each word in each segment.
            Disabling this will prevent segments from splitting/merging properly.
        regroup : bool or str, default True, meaning the default regroup algorithm
            String for customizing the regrouping algorithm. False disables regrouping.
            Ignored if ``word_timestamps = False``.
        suppress_silence : bool, default True
            Whether to enable timestamps adjustments based on the detected silence.
        suppress_word_ts : bool, default True
            Whether to adjust word timestamps based on the detected silence. Only enabled if ``suppress_silence = True``.
        use_word_position : bool, default True
            Whether to use position of the word in its segment to determine whether to keep end or start timestamps if
            adjustments are required. If it is the first word, keep end. Else if it is the last word, keep the start.
        q_levels : int, default 20
            Quantization levels for generating timestamp suppression mask; ignored if ``vad = True``.
            Acts as a threshold for marking sound as silent.
            Fewer levels will increase the threshold of volume at which to mark a sound as silent.
        k_size : int, default 5
            Kernel size for avg-pooling waveform to generate timestamp suppression mask; ignored if ``vad = True``.
            Recommend 5 or 3; higher sizes will reduce detection of silence.
        demucs : bool or torch.nn.Module, default False
            Whether to preprocess ``audio`` with Demucs to isolate vocals / remove noise. Set ``demucs`` to an instance
            of a Demucs model to avoid reloading the model for each run.
            Demucs must be installed to use. Official repo: https://github.com/facebookresearch/demucs.
        demucs_output : str, optional
            Path to save the vocals isolated by Demucs as WAV file. Ignored if ``demucs = False``.
            Demucs must be installed to use. Official repo: https://github.com/facebookresearch/demucs.
        demucs_options : dict, optional
            Options to use for :func:`stable_whisper.audio.demucs_audio`.
        vad : bool, default False
            Whether to use Silero VAD to generate timestamp suppression mask.
            Silero VAD requires PyTorch 1.12.0+. Official repo, https://github.com/snakers4/silero-vad.
        vad_threshold : float, default 0.35
            Threshold for detecting speech with Silero VAD. Low threshold reduces false positives for silence detection.
        vad_onnx : bool, default False
            Whether to use ONNX for Silero VAD.
        min_word_dur : float, default 0.1
            Shortest duration each word is allowed to reach for silence suppression.
        nonspeech_error : float, default 0.3
            Relative error of non-speech sections that appear in between a word for silence suppression.
        only_voice_freq : bool, default False
            Whether to only use sound between 200 - 5000 Hz, where the majority of human speech is.
        only_ffmpeg : bool, default False
            Whether to use only FFmpeg (instead of yt-dlp) for URLs.
        check_sorted : bool, default True
            Whether to raise an error when timestamps returned by faster-whisper are not in ascending order.
        progress_callback : Callable, optional
            A function that will be called when transcription progress is updated.
            The callback needs two parameters.
            The first parameter is a float for seconds of the audio that has been transcribed.
            The second parameter is a float for total duration of audio in seconds.
        options
            Additional options used for :meth:`faster_whisper.WhisperModel.transcribe` and
            :func:`stable_whisper.non_whisper.transcribe_any`.

        Returns
        -------
        stable_whisper.result.WhisperResult
            All timestamps, words, probabilities, and other data from the transcription of ``audio``.

        Examples
        --------
        >>> import stable_whisper
        >>> model = stable_whisper.load_faster_whisper('base')
        >>> result = model.transcribe_stable('audio.mp3', vad=True)
        >>> result.to_srt_vtt('audio.srt')
        Saved: audio.srt
        """
        extra_options = isolate_useful_options(options, transcribe_any, pop=True)
        if demucs or only_voice_freq:
            if 'audio_type' not in extra_options:
                extra_options['audio_type'] = 'numpy'
            if 'model_sr' not in extra_options:
                extra_options['model_sr'] = SAMPLE_RATE
        faster_whisper_options = options
        faster_whisper_options['model'] = model
        faster_whisper_options['audio'] = audio
        faster_whisper_options['word_timestamps'] = word_timestamps
        faster_whisper_options['verbose'] = verbose
        faster_whisper_options['progress_callback'] = progress_callback
        if not demucs_options:
            demucs_options = {}
        if demucs_output:
            if 'save_path' not in demucs_options:
                demucs_options['save_path'] = demucs_output
            warnings.warn('``demucs_output`` is deprecated. Use ``demucs_options`` with ``save_path`` instead. '
                          'E.g. demucs_options=dict(save_path="demucs_output.mp3")',
                          DeprecationWarning, stacklevel=2)

        return transcribe_any(
            inference_func=_inner_transcribe,
            audio=audio,
            inference_kwargs=faster_whisper_options,
            verbose=verbose,
            regroup=regroup,
            suppress_silence=suppress_silence,
            suppress_word_ts=suppress_word_ts,
            q_levels=q_levels,
            k_size=k_size,
            demucs=demucs,
            demucs_options=demucs_options,
            vad=vad,
            vad_threshold=vad_threshold,
            vad_onnx=vad_onnx,
            min_word_dur=min_word_dur,
            nonspeech_error=nonspeech_error,
            use_word_position=use_word_position,
            only_voice_freq=only_voice_freq,
            only_ffmpeg=only_ffmpeg,
            force_order=True,
            check_sorted=check_sorted,
            **extra_options
        )

    faster_model.transcribe_stable = MethodType(faster_transcribe, faster_model)
    from .alignment import align
    faster_model.align = MethodType(align, faster_model)

    return faster_model

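A short usage sketch of load_faster_whisper(), following its docstring example; compute_type is assumed to be one of faster-whisper's WhisperModel initialization options forwarded through **model_init_options. Hypothetical example, not part of the file:

# Hypothetical usage sketch of load_faster_whisper() (not part of the file above).
import stable_whisper

# Extra keyword arguments are forwarded to faster_whisper.WhisperModel(...).
model = stable_whisper.load_faster_whisper('base', compute_type='int8')
# The stable-ts transcription is attached as .transcribe_stable().
result = model.transcribe_stable('audio.mp3', vad=True)
result.to_srt_vtt('audio.srt')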
def modify_model(model: "Whisper"):
    """
    Modify an instance of :class:`whisper.model.Whisper`.

    The following are performed:
    -replace :meth:`whisper.model.Whisper.transcribe` with :func:`stable_whisper.whisper_word_level.transcribe_stable`
    -assign :meth:`whisper.model.Whisper.transcribe_minimal` to :func:`stable_whisper.whisper_word_level.transcribe_minimal`
    -assign :meth:`whisper.model.Whisper.transcribe_original` to :meth:`whisper.model.Whisper.transcribe`
    -assign :meth:`whisper.model.Whisper.align` to :func:`stable_whisper.alignment.align`
    -assign :meth:`whisper.model.Whisper.locate` to :func:`stable_whisper.alignment.locate`
    """
    model.transcribe = MethodType(transcribe_stable, model)
    model.transcribe_minimal = MethodType(transcribe_minimal, model)
    model.transcribe_original = MethodType(whisper.transcribe, model)
    from .alignment import align, refine, locate
    model.align = MethodType(align, model)
    model.refine = MethodType(refine, model)
    model.locate = MethodType(locate, model)

1052 |
+
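# Illustrative sketch (not part of the original file): typical use of modify_model() on a vanilla
# openai-whisper model; assumes modify_model is exported at the package level. 'audio.mp3' is a placeholder.
import whisper
import stable_whisper

model = whisper.load_model('base')
stable_whisper.modify_model(model)                        # patches the instance in place
result = model.transcribe('audio.mp3')                    # now stable-whisper's transcribe
plain_result = model.transcribe_original('audio.mp3')     # the unmodified whisper.transcribe is kept here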
# modified version of whisper.load_model
|
1053 |
+
def load_model(name: str, device: Optional[Union[str, torch.device]] = None,
|
1054 |
+
download_root: str = None, in_memory: bool = False,
|
1055 |
+
cpu_preload: bool = True, dq: bool = False) -> "Whisper":
|
1056 |
+
"""
|
1057 |
+
Load an instance of :class:`whisper.model.Whisper`.
|
1058 |
+
|
1059 |
+
Parameters
|
1060 |
+
----------
|
1061 |
+
name : {'tiny', 'tiny.en', 'base', 'base.en', 'small', 'small.en', 'medium', 'medium.en', 'large-v1',
|
1062 |
+
'large-v2', 'large-v3', or 'large'}
|
1063 |
+
One of the official model names listed by :func:`whisper.available_models`, or
|
1064 |
+
path to a model checkpoint containing the model dimensions and the model state_dict.
|
1065 |
+
device : str or torch.device, optional
|
1066 |
+
PyTorch device to put the model into.
|
1067 |
+
download_root : str, optional
|
1068 |
+
Path to download the model files; by default, it uses "~/.cache/whisper".
|
1069 |
+
in_memory : bool, default False
|
1070 |
+
Whether to preload the model weights into host memory.
|
1071 |
+
cpu_preload : bool, default True
|
1072 |
+
Load model into CPU memory first then move model to specified device
|
1073 |
+
to reduce GPU memory usage when loading model
|
1074 |
+
dq : bool, default False
|
1075 |
+
Whether to apply Dynamic Quantization to the model to reduce memory usage and increase inference speed
|
1076 |
+
but at the cost of a slight decrease in accuracy. Only for CPU.
|
1077 |
+
|
1078 |
+
Returns
|
1079 |
+
-------
|
1080 |
+
model : "Whisper"
|
1081 |
+
The Whisper ASR model instance.
|
1082 |
+
|
1083 |
+
Notes
|
1084 |
+
-----
|
1085 |
+
The overhead from ``dq = True`` might make inference slower for models smaller than 'large'.
|
1086 |
+
"""
|
1087 |
+
if device is None or dq:
|
1088 |
+
device = "cuda" if torch.cuda.is_available() and not dq else "cpu"
|
1089 |
+
if cpu_preload:
|
1090 |
+
model = whisper.load_model(name, device='cpu', download_root=download_root, in_memory=in_memory)
|
1091 |
+
cuda_index = None
|
1092 |
+
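# e.g. device='cuda:1' gives cuda_index=[1] -> model.cuda(1); a bare 'cuda' gives [] -> model.cuda()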
if isinstance(device, str) and device.startswith('cuda'):
|
1093 |
+
try:
|
1094 |
+
cuda_index = [] if device == 'cuda' else [int(device.split(':')[-1])]
|
1095 |
+
except ValueError:
|
1096 |
+
pass
|
1097 |
+
model = model.to(device=device) if cuda_index is None else model.cuda(*cuda_index)
|
1098 |
+
else:
|
1099 |
+
model = whisper.load_model(name, device=device, download_root=download_root, in_memory=in_memory)
|
1100 |
+
modify_model(model)
|
1101 |
+
if dq:
|
1102 |
+
from .quantization import ptdq_linear
|
1103 |
+
ptdq_linear(model)
|
1104 |
+
return model
|
1105 |
+
|
1106 |
+
|
1107 |
+
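# Illustrative sketch (not part of the original file): loading with dynamic quantization as documented
# above; dq=True keeps the model on CPU, and cpu_preload=True (the default) stages weights on CPU first.
import stable_whisper

model = stable_whisper.load_model('base', dq=True)
result = model.transcribe('audio.mp3')  # 'audio.mp3' is a placeholder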
# modified version of whisper.transcribe.cli
|
1108 |
+
def cli():
|
1109 |
+
import argparse
|
1110 |
+
import os
|
1111 |
+
from os.path import splitext, split, isfile, join
|
1112 |
+
from whisper import available_models
|
1113 |
+
from whisper.utils import optional_int, optional_float
|
1114 |
+
from .utils import str_to_valid_type, get_func_parameters
|
1115 |
+
|
1116 |
+
str2val = {"true": True, "false": False, "1": True, "0": False}
|
1117 |
+
|
1118 |
+
def str2bool(string: str) -> bool:
|
1119 |
+
string = string.lower()
|
1120 |
+
if string in str2val:
|
1121 |
+
return str2val[string]
|
1122 |
+
raise ValueError(f"Expected one of {set(str2val.keys())}, got {string}")
|
1123 |
+
|
1124 |
+
def valid_model_name(name):
|
1125 |
+
if name in available_models() or os.path.exists(name):
|
1126 |
+
return name
|
1127 |
+
raise ValueError(
|
1128 |
+
f"model should be one of {available_models()} or path to a model checkpoint"
|
1129 |
+
)
|
1130 |
+
|
1131 |
+
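# e.g. --save_option "highlight_color=ffffff" arrives here as ['highlight_color=ffffff'] and is parsed into {'highlight_color': 'ffffff'}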
def update_options_with_args(arg_key: str, options: Optional[dict] = None, pop: bool = False):
|
1132 |
+
extra_options = args.pop(arg_key) if pop else args.get(arg_key)
|
1133 |
+
if not extra_options:
|
1134 |
+
return
|
1135 |
+
extra_options = [kv.split('=', maxsplit=1) for kv in extra_options]
|
1136 |
+
missing_val = [kv[0] for kv in extra_options if len(kv) == 1]
|
1137 |
+
if missing_val:
|
1138 |
+
raise ValueError(f'Missing expected values for the following custom options: {missing_val}')
|
1139 |
+
extra_options = dict((k, str_to_valid_type(v)) for k, v in extra_options)
|
1140 |
+
if options is None:
|
1141 |
+
return extra_options
|
1142 |
+
options.update(extra_options)
|
1143 |
+
|
1144 |
+
OUTPUT_FORMATS_METHODS = {
|
1145 |
+
"srt": "to_srt_vtt",
|
1146 |
+
"ass": "to_ass",
|
1147 |
+
"json": "save_as_json",
|
1148 |
+
"vtt": "to_srt_vtt",
|
1149 |
+
"tsv": "to_tsv",
|
1150 |
+
"txt": "to_txt",
|
1151 |
+
}
|
1152 |
+
|
1153 |
+
OUTPUT_FORMATS = set(OUTPUT_FORMATS_METHODS.keys())
|
1154 |
+
|
1155 |
+
parser = argparse.ArgumentParser(formatter_class=argparse.ArgumentDefaultsHelpFormatter)
|
1156 |
+
parser.add_argument("inputs", nargs="+", type=str,
|
1157 |
+
help="audio/video filepath/URL(s) to transcribe "
|
1158 |
+
"or json file(s) to process into [output_format]")
|
1159 |
+
parser.add_argument("--output", "-o", action="extend", nargs="+", type=str,
|
1160 |
+
help="output filepaths(s);"
|
1161 |
+
"if not specified, auto-named output file(s) will be saved to "
|
1162 |
+
"[output_dir] or current dir if not specified.")
|
1163 |
+
parser.add_argument("--model", '-m', default="base", type=valid_model_name,
|
1164 |
+
help="name of the Whisper model to use")
|
1165 |
+
parser.add_argument("--model_dir", type=str, default=None,
|
1166 |
+
help="the path to save model files; uses ~/.cache/whisper by default")
|
1167 |
+
parser.add_argument("--device", default="cuda" if torch.cuda.is_available() else "cpu",
|
1168 |
+
help="device to use for PyTorch inference")
|
1169 |
+
parser.add_argument("--cpu_preload", type=str2bool, default=True,
|
1170 |
+
help="load model into CPU memory first then move model to specified device; "
|
1171 |
+
"this reduces GPU memory usage when loading model.")
|
1172 |
+
parser.add_argument("--output_dir", "-d", type=str,
|
1173 |
+
help="directory to save the outputs;"
|
1174 |
+
"if a path in [output] does not have parent, that output will be save to this directory")
|
1175 |
+
parser.add_argument("--output_format", "-f", type=str,
|
1176 |
+
help="format of the output file(s); "
|
1177 |
+
f"Supported Formats: {OUTPUT_FORMATS}; "
|
1178 |
+
"use ',' to separate multiple formats")
|
1179 |
+
parser.add_argument("--verbose", '-v', type=int, default=1, choices=(0, 1, 2),
|
1180 |
+
help="whether to display the text being decoded to the console; "
|
1181 |
+
"if 2, display all the details; "
|
1182 |
+
"if 1, display progressbar; "
|
1183 |
+
"if 0, display nothing")
|
1184 |
+
|
1185 |
+
parser.add_argument("--dynamic_quantization", "-dq", action='store_true',
|
1186 |
+
help="whether to apply Dynamic Quantization to model "
|
1187 |
+
"to reduced memory usage (~half less) and increase inference speed "
|
1188 |
+
"at cost of slight decrease in accuracy; Only for CPU; "
|
1189 |
+
"NOTE: overhead might make inference slower for models smaller than 'large'")
|
1190 |
+
|
1191 |
+
parser.add_argument("--task", type=str, default="transcribe",
|
1192 |
+
choices=["transcribe", "translate"],
|
1193 |
+
help="whether to perform X->X speech recognition ('transcribe') "
|
1194 |
+
"or X->English translation ('translate')")
|
1195 |
+
parser.add_argument("--language", '-l', type=str, default=None,
|
1196 |
+
choices=sorted(LANGUAGES.keys()) + sorted([k.title() for k in TO_LANGUAGE_CODE.keys()]),
|
1197 |
+
help="language spoken in the audio, specify None to perform language detection")
|
1198 |
+
|
1199 |
+
parser.add_argument("--prepend_punctuations", '-pp', type=str, default="\"'“¿([{-",
|
1200 |
+
help="Punctuations to prepend to next word")
|
1201 |
+
parser.add_argument("--append_punctuations", '-ap', type=str, default="\"'.。,,!!??::”)]}、",
|
1202 |
+
help="Punctuations to append to previous word")
|
1203 |
+
|
1204 |
+
parser.add_argument("--gap_padding", type=str, default=" ...",
|
1205 |
+
help="padding prepend to each segments for word timing alignment;"
|
1206 |
+
"used to reduce the probability of model predicting timestamps "
|
1207 |
+
"earlier than the first utterance")
|
1208 |
+
|
1209 |
+
parser.add_argument("--word_timestamps", type=str2bool, default=True,
|
1210 |
+
help="extract word-level timestamps using the cross-attention pattern and dynamic time warping,"
|
1211 |
+
"and include the timestamps for each word in each segment;"
|
1212 |
+
"disabling this will prevent segments from splitting/merging properly.")
|
1213 |
+
|
1214 |
+
parser.add_argument("--regroup", type=str, default="True",
|
1215 |
+
help="whether to regroup all words into segments with more natural boundaries;"
|
1216 |
+
"specify string for customizing the regrouping algorithm"
|
1217 |
+
"ignored if [word_timestamps]=False.")
|
1218 |
+
|
1219 |
+
parser.add_argument('--ts_num', type=int, default=0,
|
1220 |
+
help="number of extra inferences to perform to find the mean timestamps")
|
1221 |
+
parser.add_argument('--ts_noise', type=float, default=0.1,
|
1222 |
+
help="percentage of noise to add to audio_features to perform inferences for [ts_num]")
|
1223 |
+
|
1224 |
+
parser.add_argument('--suppress_silence', type=str2bool, default=True,
|
1225 |
+
help="whether to suppress timestamp where audio is silent at segment-level"
|
1226 |
+
"and word-level if [suppress_word_ts]=True")
|
1227 |
+
parser.add_argument('--suppress_word_ts', type=str2bool, default=True,
|
1228 |
+
help="whether to suppress timestamps where audio is silent at word-level; "
|
1229 |
+
"ignored if [suppress_silence]=False")
|
1230 |
+
|
1231 |
+
parser.add_argument('--suppress_ts_tokens', type=str2bool, default=False,
|
1232 |
+
help="whether to use silence mask to suppress silent timestamp tokens during inference; "
|
1233 |
+
"increases word accuracy in some cases, but tends reduce 'verbatimness' of the transcript"
|
1234 |
+
"ignored if [suppress_silence]=False")
|
1235 |
+
|
1236 |
+
parser.add_argument("--q_levels", type=int, default=20,
|
1237 |
+
help="quantization levels for generating timestamp suppression mask; "
|
1238 |
+
"acts as a threshold to marking sound as silent;"
|
1239 |
+
"fewer levels will increase the threshold of volume at which to mark a sound as silent")
|
1240 |
+
|
1241 |
+
parser.add_argument("--k_size", type=int, default=5,
|
1242 |
+
help="Kernel size for average pooling waveform to generate suppression mask; "
|
1243 |
+
"recommend 5 or 3; higher sizes will reduce detection of silence")
|
1244 |
+
|
1245 |
+
parser.add_argument('--time_scale', type=float,
|
1246 |
+
help="factor for scaling audio duration for inference;"
|
1247 |
+
"greater than 1.0 'slows down' the audio; "
|
1248 |
+
"less than 1.0 'speeds up' the audio; "
|
1249 |
+
"1.0 is no scaling")
|
1250 |
+
|
1251 |
+
parser.add_argument('--vad', type=str2bool, default=False,
|
1252 |
+
help='whether to use Silero VAD to generate timestamp suppression mask; '
|
1253 |
+
'Silero VAD requires PyTorch 1.12.0+; '
|
1254 |
+
'Official repo: https://github.com/snakers4/silero-vad')
|
1255 |
+
parser.add_argument('--vad_threshold', type=float, default=0.35,
|
1256 |
+
help='threshold for detecting speech with Silero VAD. (Default: 0.35); '
|
1257 |
+
'low threshold reduces false positives for silence detection')
|
1258 |
+
parser.add_argument('--vad_onnx', type=str2bool, default=False,
|
1259 |
+
help='whether to use ONNX for Silero VAD')
|
1260 |
+
|
1261 |
+
parser.add_argument('--min_word_dur', type=float, default=0.1,
|
1262 |
+
help="shortest duration each word is allowed to reach for silence suppression")
|
1263 |
+
parser.add_argument('--nonspeech_error', type=float, default=0.3,
|
1264 |
+
help="relative error of non-speech sections that appear in between a word for "
|
1265 |
+
"silence suppression.")
|
1266 |
+
|
1267 |
+
parser.add_argument('--max_chars', type=int,
|
1268 |
+
help="maximum number of character allowed in each segment")
|
1269 |
+
parser.add_argument('--max_words', type=int,
|
1270 |
+
help="maximum number of words allowed in each segment")
|
1271 |
+
|
1272 |
+
parser.add_argument('--demucs', type=str2bool, default=False,
|
1273 |
+
help='whether to reprocess the audio track with Demucs to isolate vocals/remove noise; '
|
1274 |
+
'Demucs official repo: https://github.com/facebookresearch/demucs')
|
1275 |
+
parser.add_argument('--demucs_output', action="extend", nargs="+", type=str,
|
1276 |
+
help='path(s) to save the vocals isolated by Demucs as WAV file(s); '
|
1277 |
+
'ignored if [demucs]=False')
|
1278 |
+
parser.add_argument('--only_voice_freq', '-ovf', action='store_true',
|
1279 |
+
help='whether to only use sound between 200 - 5000 Hz, where the majority of human speech is.')
|
1280 |
+
|
1281 |
+
parser.add_argument('--strip', type=str2bool, default=True,
|
1282 |
+
help="whether to remove spaces before and after text on each segment for output")
|
1283 |
+
|
1284 |
+
parser.add_argument('--tag', type=str, action="extend", nargs="+",
|
1285 |
+
help="a pair tags used to change the properties a word at its predicted time"
|
1286 |
+
"SRT Default: '<font color=\"#00ff00\">', '</font>'"
|
1287 |
+
"VTT Default: '<u>', '</u>'"
|
1288 |
+
"ASS Default: '{\\1c&HFF00&}', '{\\r}'")
|
1289 |
+
parser.add_argument('--segment_level', type=str2bool, default=True,
|
1290 |
+
help="whether to use segment-level timestamps in output")
|
1291 |
+
parser.add_argument('--word_level', type=str2bool, default=True,
|
1292 |
+
help="whether to use word-level timestamps in output")
|
1293 |
+
|
1294 |
+
parser.add_argument('--reverse_text', type=str2bool, default=False,
|
1295 |
+
help="whether to reverse the order of words for each segment of text output")
|
1296 |
+
|
1297 |
+
# ass output
|
1298 |
+
parser.add_argument('--font', type=str, default='Arial',
|
1299 |
+
help="word font for ASS output(s)")
|
1300 |
+
parser.add_argument('--font_size', type=int, default=48,
|
1301 |
+
help="word font size for ASS output(s)")
|
1302 |
+
parser.add_argument('--karaoke', type=str2bool, default=False,
|
1303 |
+
help="whether to use progressive filling highlights for karaoke effect (only for ASS outputs)")
|
1304 |
+
|
1305 |
+
parser.add_argument("--temperature", type=float, default=0,
|
1306 |
+
help="temperature to use for sampling")
|
1307 |
+
parser.add_argument("--best_of", type=optional_int,
|
1308 |
+
help="number of candidates when sampling with non-zero temperature")
|
1309 |
+
parser.add_argument("--beam_size", type=optional_int,
|
1310 |
+
help="number of beams in beam search, only applicable when temperature is zero")
|
1311 |
+
parser.add_argument("--patience", type=float, default=None,
|
1312 |
+
help="optional patience value to use in beam decoding, "
|
1313 |
+
"as in https://arxiv.org/abs/2204.05424, "
|
1314 |
+
"the default (1.0) is equivalent to conventional beam search")
|
1315 |
+
parser.add_argument("--length_penalty", type=float, default=None,
|
1316 |
+
help="optional token length penalty coefficient (alpha) "
|
1317 |
+
"as in https://arxiv.org/abs/1609.08144, uses simple length normalization by default")
|
1318 |
+
|
1319 |
+
parser.add_argument("--suppress_tokens", type=str, default="-1",
|
1320 |
+
help="comma-separated list of token ids to suppress during sampling; "
|
1321 |
+
"'-1' will suppress most special characters except common punctuations")
|
1322 |
+
parser.add_argument("--initial_prompt", type=str, default=None,
|
1323 |
+
help="optional text to provide as a prompt for the first window.")
|
1324 |
+
parser.add_argument("--condition_on_previous_text", type=str2bool, default=True,
|
1325 |
+
help="if True, provide the previous output of the model as a prompt for the next window; "
|
1326 |
+
"disabling may make the text inconsistent across windows, "
|
1327 |
+
"but the model becomes less prone to getting stuck in a failure loop")
|
1328 |
+
parser.add_argument("--fp16", type=str2bool, default=True,
|
1329 |
+
help="whether to perform inference in fp16; True by default")
|
1330 |
+
|
1331 |
+
parser.add_argument("--temperature_increment_on_fallback", type=optional_float, default=0.2,
|
1332 |
+
help="temperature to increase when falling back when the decoding fails to meet either of "
|
1333 |
+
"the thresholds below")
|
1334 |
+
parser.add_argument("--compression_ratio_threshold", type=optional_float, default=2.4,
|
1335 |
+
help="if the gzip compression ratio is higher than this value, treat the decoding as failed")
|
1336 |
+
parser.add_argument("--logprob_threshold", type=optional_float, default=-1.0,
|
1337 |
+
help="if the average log probability is lower than this value, treat the decoding as failed")
|
1338 |
+
parser.add_argument("--no_speech_threshold", type=optional_float, default=0.6,
|
1339 |
+
help="if the probability of the <|nospeech|> token is higher than this value AND the decoding "
|
1340 |
+
"has failed due to `logprob_threshold`, consider the segment as silence")
|
1341 |
+
parser.add_argument("--threads", type=optional_int, default=0,
|
1342 |
+
help="number of threads used by torch for CPU inference; "
|
1343 |
+
"supercedes MKL_NUM_THREADS/OMP_NUM_THREADS")
|
1344 |
+
|
1345 |
+
parser.add_argument('--mel_first', action='store_true',
|
1346 |
+
help='process the entire audio track into a log-Mel spectrogram first instead of in chunks')
|
1347 |
+
|
1348 |
+
parser.add_argument('--only_ffmpeg', action='store_true',
|
1349 |
+
help='whether to use only FFmpeg (and not yt-dlp) for URLs')
|
1350 |
+
|
1351 |
+
parser.add_argument('--overwrite', '-y', action='store_true',
|
1352 |
+
help='overwrite all output files')
|
1353 |
+
|
1354 |
+
parser.add_argument('--debug', action='store_true',
|
1355 |
+
help='print all input/output pair(s) and all arguments used for transcribing/translating')
|
1356 |
+
|
1357 |
+
parser.add_argument('--transcribe_method', '-tm', type=str, default='transcribe',
|
1358 |
+
choices=('transcribe', 'transcribe_minimal'))
|
1359 |
+
|
1360 |
+
parser.add_argument('--align', '-a', action="extend", nargs='+', type=str,
|
1361 |
+
help='path(s) to TXT file(s) or JSON previous result(s)')
|
1362 |
+
|
1363 |
+
parser.add_argument('--refine', '-r', action='store_true',
|
1364 |
+
help='Refine timestamps to increase their precision')
|
1365 |
+
|
1366 |
+
parser.add_argument('--locate', '-lc', action="extend", nargs='+', type=str,
|
1367 |
+
help='words to locate in the audio(s); skips transcription and output')
|
1368 |
+
|
1369 |
+
parser.add_argument('--refine_option', '-ro', action="extend", nargs='+', type=str,
|
1370 |
+
help='Extra option(s) to use for refining timestamps; Replace True/False with 1/0; '
|
1371 |
+
'E.g. --refine_option "steps=sese" --refine_option "rel_prob_decrease=0.05"')
|
1372 |
+
parser.add_argument('--demucs_option', '-do', action="extend", nargs='+', type=str,
|
1373 |
+
help='Extra option(s) to use for demucs; Replace True/False with 1/0; '
|
1374 |
+
'E.g. --demucs_option "shifts=3" --demucs_option "overlap=0.5"')
|
1375 |
+
parser.add_argument('--model_option', '-mo', action="extend", nargs='+', type=str,
|
1376 |
+
help='Extra option(s) to use for loading model; Replace True/False with 1/0; '
|
1377 |
+
'E.g. --model_option "download_root=./downloads"')
|
1378 |
+
parser.add_argument('--transcribe_option', '-to', action="extend", nargs='+', type=str,
|
1379 |
+
help='Extra option(s) to use for transcribing/alignment/locating; Replace True/False with 1/0; '
|
1380 |
+
'E.g. --transcribe_option "ignore_compatibility=1"')
|
1381 |
+
parser.add_argument('--save_option', '-so', action="extend", nargs='+', type=str,
|
1382 |
+
help='Extra option(s) to use for text outputs; Replace True/False with 1/0; '
|
1383 |
+
'E.g. --save_option "highlight_color=ffffff"')
|
1384 |
+
|
1385 |
+
parser.add_argument('--faster_whisper', '-fw', action='store_true',
|
1386 |
+
help='whether to use faster-whisper (https://github.com/guillaumekln/faster-whisper); '
|
1387 |
+
'note: some features may not be available')
|
1388 |
+
|
1389 |
+
args = parser.parse_args().__dict__
|
1390 |
+
debug = args.pop('debug')
|
1391 |
+
if not args['language'] and (args['align'] or args['locate']):
|
1392 |
+
raise ValueError('language is required for --align / --locate')
|
1393 |
+
|
1394 |
+
is_faster_whisper = args.pop('faster_whisper')
|
1395 |
+
model_name: str = args.pop("model")
|
1396 |
+
model_dir: str = args.pop("model_dir")
|
1397 |
+
inputs: List[Union[str, torch.Tensor]] = args.pop("inputs")
|
1398 |
+
outputs: List[str] = args.pop("output")
|
1399 |
+
output_dir: str = args.pop("output_dir")
|
1400 |
+
output_format = args.pop("output_format")
|
1401 |
+
overwrite: bool = args.pop("overwrite")
|
1402 |
+
use_demucs = args['demucs'] or False
|
1403 |
+
demucs_outputs: List[Optional[str]] = args.pop("demucs_output")
|
1404 |
+
args['demucs_options'] = update_options_with_args('demucs_option', pop=True)
|
1405 |
+
regroup = args.pop('regroup')
|
1406 |
+
max_chars = args.pop('max_chars')
|
1407 |
+
max_words = args.pop('max_words')
|
1408 |
+
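# map CLI verbosity onto the transcribe functions' convention: 1 -> False (progress bar only), 2 -> True (full details), 0 -> None (silent)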
args['verbose'] = False if args['verbose'] == 1 else (True if args['verbose'] == 2 else None)
|
1409 |
+
show_curr_task = args['verbose'] is not None
|
1410 |
+
strings_to_locate = args.pop('locate')
|
1411 |
+
if dq := args.pop('dynamic_quantization', False):
|
1412 |
+
args['device'] = 'cpu'
|
1413 |
+
if args['reverse_text']:
|
1414 |
+
args['reverse_text'] = (args.get('prepend_punctuations'), args.get('append_punctuations'))
|
1415 |
+
|
1416 |
+
if regroup:
|
1417 |
+
try:
|
1418 |
+
regroup = str2bool(regroup)
|
1419 |
+
except ValueError:
|
1420 |
+
pass
|
1421 |
+
curr_output_formats: List[str] = output_format.split(',') if output_format else []
|
1422 |
+
unsupported_formats = list(set(map(str.lower, curr_output_formats)) - OUTPUT_FORMATS)
|
1423 |
+
if outputs:
|
1424 |
+
unsupported_formats.extend(list(set(splitext(o)[-1].lower().strip('.') for o in outputs) - OUTPUT_FORMATS))
|
1425 |
+
if len(unsupported_formats) != 0:
|
1426 |
+
raise NotImplementedError(f'{unsupported_formats} are not supported. Supported formats: {OUTPUT_FORMATS}.')
|
1427 |
+
|
1428 |
+
has_demucs_output = bool(demucs_outputs)
|
1429 |
+
if use_demucs and has_demucs_output and len(demucs_outputs) != len(inputs):
|
1430 |
+
raise NotImplementedError(f'[demucs_output] and [inputs] do not match in count. '
|
1431 |
+
f'Got {len(demucs_outputs)} and {len(inputs)}')
|
1432 |
+
|
1433 |
+
if tag := args.get('tag'):
|
1434 |
+
assert tag == ['-1'] or len(tag) == 2, f'[tag] must be a pair of str but got {tag}'
|
1435 |
+
|
1436 |
+
def make_parent(filepath: str):
|
1437 |
+
if parent := split(filepath)[0]:
|
1438 |
+
os.makedirs(parent, exist_ok=True)
|
1439 |
+
|
1440 |
+
def is_json(file: str):
|
1441 |
+
return file.endswith(".json")
|
1442 |
+
|
1443 |
+
def call_method_with_options(method, options: dict, include_first: bool = True):
|
1444 |
+
def val_to_str(val) -> str:
|
1445 |
+
if isinstance(val, (np.ndarray, torch.Tensor)):
|
1446 |
+
return f'{val.__class__}(shape:{list(val.shape)})'
|
1447 |
+
elif isinstance(val, str):
|
1448 |
+
return f'"{val}"'
|
1449 |
+
elif isinstance(val, bytes):
|
1450 |
+
return f'{type(val)}(len:{len(val)})'
|
1451 |
+
elif isinstance(val, torch.nn.Module):
|
1452 |
+
return str(type(val))
|
1453 |
+
return str(val)
|
1454 |
+
|
1455 |
+
params = tuple(get_func_parameters(method))
|
1456 |
+
if debug:
|
1457 |
+
temp_options = {k: options.pop(k) for k in params if k in options}
|
1458 |
+
temp_options.update(options)
|
1459 |
+
options = temp_options
|
1460 |
+
options_str = ',\n'.join(
|
1461 |
+
f' {k}={val_to_str(v)}'
|
1462 |
+
for k, v in options.items()
|
1463 |
+
if include_first or k != params[0]
|
1464 |
+
)
|
1465 |
+
if options_str:
|
1466 |
+
options_str = f'\n{options_str}\n'
|
1467 |
+
else:
|
1468 |
+
print(options, params)
|
1469 |
+
print(f'{method.__qualname__}({options_str})')
|
1470 |
+
return method(**options)
|
1471 |
+
|
1472 |
+
if alignments := args['align']:
|
1473 |
+
if unsupported_align_fmts := \
|
1474 |
+
[_ext for p in alignments if (_ext := splitext(p)[-1].lower()) not in ('.json', '.txt')]:
|
1475 |
+
raise NotImplementedError(
|
1476 |
+
f'Unsupported format(s) for alignment: {unsupported_align_fmts}'
|
1477 |
+
)
|
1478 |
+
if len(inputs) != len(alignments):
|
1479 |
+
raise NotImplementedError(
|
1480 |
+
f'Got {len(inputs)} audio file(s) but specified {len(alignments)} file(s) to align.'
|
1481 |
+
)
|
1482 |
+
else:
|
1483 |
+
alignments = ['']*len(inputs)
|
1484 |
+
|
1485 |
+
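# e.g. 'audio.mp3' with --output_format srt,json typically resolves to ['audio.srt', 'audio.json'] (order not guaranteed), placed under [output_dir] if set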
def finalize_outputs(input_file: str, _output: str = None, _alignment: str = None) -> List[str]:
|
1486 |
+
_curr_output_formats = curr_output_formats.copy()
|
1487 |
+
basename, ext = splitext(_output or input_file)
|
1488 |
+
ext = ext[1:]
|
1489 |
+
if _output:
|
1490 |
+
if ext.lower() in OUTPUT_FORMATS:
|
1491 |
+
_curr_output_formats.append(ext)
|
1492 |
+
else:
|
1493 |
+
basename = _output
|
1494 |
+
if not _curr_output_formats:
|
1495 |
+
_curr_output_formats = ["srt" if is_json(input_file) or is_json(_alignment) else "json"]
|
1496 |
+
_outputs = [f'{basename}.{ext}' for ext in set(_curr_output_formats)]
|
1497 |
+
if output_dir:
|
1498 |
+
_outputs = [join(output_dir, o) for o in _outputs]
|
1499 |
+
|
1500 |
+
return _outputs
|
1501 |
+
|
1502 |
+
if outputs:
|
1503 |
+
if len(outputs) != len(inputs):
|
1504 |
+
raise NotImplementedError(f'Got {len(inputs)} audio file(s) but specified {len(outputs)} output file(s).')
|
1505 |
+
final_outputs = [finalize_outputs(i, o, a) for i, o, a in zip(inputs, outputs, alignments)]
|
1506 |
+
else:
|
1507 |
+
if not output_dir:
|
1508 |
+
output_dir = '.'
|
1509 |
+
final_outputs = [finalize_outputs(i, _alignment=a) for i, a in zip(inputs, alignments)]
|
1510 |
+
|
1511 |
+
if not overwrite:
|
1512 |
+
|
1513 |
+
def cancel_overwrite():
|
1514 |
+
resp = input(f'{path} already exists, overwrite (y/n)? ').lower()
|
1515 |
+
if resp in ('y', 'n'):
|
1516 |
+
return resp == 'n'
|
1517 |
+
print(f'Expected "y" or "n", but got {resp}.')
|
1518 |
+
return True
|
1519 |
+
|
1520 |
+
for paths in final_outputs:
|
1521 |
+
for path in paths:
|
1522 |
+
if isfile(path) and cancel_overwrite():
|
1523 |
+
return
|
1524 |
+
|
1525 |
+
if model_name.endswith(".en") and args["language"] not in {"en", "English"}:
|
1526 |
+
if args["language"] is not None:
|
1527 |
+
warnings.warn(f"{model_name} is an English-only model but receipted "
|
1528 |
+
f"'{args['language']}'; using English instead.")
|
1529 |
+
args["language"] = "en"
|
1530 |
+
|
1531 |
+
temperature = args.pop("temperature")
|
1532 |
+
increment = args.pop("temperature_increment_on_fallback")
|
1533 |
+
if increment is not None:
|
1534 |
+
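# e.g. the defaults (temperature=0, increment=0.2) yield fallback temperatures (0.0, 0.2, 0.4, 0.6, 0.8, 1.0)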
temperature = tuple(np.arange(temperature, 1.0 + 1e-6, increment))
|
1535 |
+
else:
|
1536 |
+
temperature = [temperature]
|
1537 |
+
|
1538 |
+
args['temperature'] = temperature
|
1539 |
+
|
1540 |
+
threads = args.pop("threads")
|
1541 |
+
if threads > 0:
|
1542 |
+
torch.set_num_threads(threads)
|
1543 |
+
|
1544 |
+
if debug:
|
1545 |
+
print('Input(s) -> Output(s)')
|
1546 |
+
for i, (input_audio, output_paths, alignment) in enumerate(zip(inputs, final_outputs, alignments)):
|
1547 |
+
dm_output = f' {demucs_outputs[i]} ->' if demucs_outputs else ''
|
1548 |
+
alignment = f' + "{alignment}"' if alignment else ''
|
1549 |
+
print(f'"{input_audio}"{alignment} ->{dm_output} {output_paths}')
|
1550 |
+
print('')
|
1551 |
+
|
1552 |
+
if show_curr_task:
|
1553 |
+
model_from_str = '' if model_dir is None else f' from {model_dir}'
|
1554 |
+
model_loading_str = f'{"Faster-Whisper" if is_faster_whisper else "Whisper"} {model_name} model {model_from_str}'
|
1555 |
+
print(f'Loading {model_loading_str}\r', end='\n' if debug else '')
|
1556 |
+
else:
|
1557 |
+
model_loading_str = ''
|
1558 |
+
|
1559 |
+
alignments = args['align']
|
1560 |
+
model = None
|
1561 |
+
|
1562 |
+
def _load_model():
|
1563 |
+
nonlocal model
|
1564 |
+
if model is None:
|
1565 |
+
model_options = dict(
|
1566 |
+
name=model_name,
|
1567 |
+
model_size_or_path=model_name,
|
1568 |
+
device=args.get('device'),
|
1569 |
+
download_root=model_dir,
|
1570 |
+
dq=dq,
|
1571 |
+
)
|
1572 |
+
load_model_func = load_faster_whisper if is_faster_whisper else load_model
|
1573 |
+
model_options = isolate_useful_options(model_options, load_model_func)
|
1574 |
+
update_options_with_args('model_option', model_options)
|
1575 |
+
model = call_method_with_options(load_model_func, model_options)
|
1576 |
+
if model_loading_str:
|
1577 |
+
print(f'Loaded {model_loading_str} ')
|
1578 |
+
return model
|
1579 |
+
|
1580 |
+
for i, (input_audio, output_paths) in enumerate(zip(inputs, final_outputs)):
|
1581 |
+
skip_output = False
|
1582 |
+
if isinstance(input_audio, str) and is_json(input_audio):
|
1583 |
+
result = WhisperResult(input_audio)
|
1584 |
+
else:
|
1585 |
+
model = _load_model()
|
1586 |
+
args['regroup'] = False
|
1587 |
+
args['audio'] = input_audio
|
1588 |
+
if has_demucs_output:
|
1589 |
+
args['demucs_output'] = demucs_outputs[i]
|
1590 |
+
transcribe_method = args.get('transcribe_method')
|
1591 |
+
text = None
|
1592 |
+
if alignments and (text := alignments[i]):
|
1593 |
+
if text.endswith('.json'):
|
1594 |
+
text = WhisperResult(text)
|
1595 |
+
else:
|
1596 |
+
with open(text, 'r', encoding='utf-8') as f:
|
1597 |
+
text = f.read()
|
1598 |
+
args['text'] = text
|
1599 |
+
transcribe_method = 'align'
|
1600 |
+
if is_faster_whisper and transcribe_method == 'transcribe':
|
1601 |
+
transcribe_method = 'transcribe_stable'
|
1602 |
+
if strings_to_locate and (text := strings_to_locate[i]):
|
1603 |
+
args['text'] = text
|
1604 |
+
transcribe_method = 'locate'
|
1605 |
+
skip_output = args['verbose'] = True
|
1606 |
+
transcribe_method = getattr(model, transcribe_method)
|
1607 |
+
transcribe_options = isolate_useful_options(args, transcribe_method)
|
1608 |
+
if not text:
|
1609 |
+
decoding_options = (
|
1610 |
+
isolate_useful_options(args, model.transcribe if is_faster_whisper else DecodingOptions)
|
1611 |
+
)
|
1612 |
+
if is_faster_whisper:
|
1613 |
+
if decoding_options['suppress_tokens']:
|
1614 |
+
decoding_options['suppress_tokens'] = (
|
1615 |
+
list(map(int, decoding_options['suppress_tokens'].split(',')))
|
1616 |
+
)
|
1617 |
+
for k in list(decoding_options.keys()):
|
1618 |
+
if decoding_options[k] is None:
|
1619 |
+
del decoding_options[k]
|
1620 |
+
transcribe_options.update(decoding_options)
|
1621 |
+
update_options_with_args('transcribe_option', transcribe_options)
|
1622 |
+
result: WhisperResult = call_method_with_options(transcribe_method, transcribe_options)
|
1623 |
+
|
1624 |
+
if skip_output:
|
1625 |
+
continue
|
1626 |
+
|
1627 |
+
if args['refine']:
|
1628 |
+
model = _load_model()
|
1629 |
+
refine_options = isolate_useful_options(args, model.refine)
|
1630 |
+
refine_options['result'] = result
|
1631 |
+
update_options_with_args('refine_option', refine_options)
|
1632 |
+
call_method_with_options(model.refine, refine_options)
|
1633 |
+
|
1634 |
+
if args.get('word_timestamps'):
|
1635 |
+
if regroup:
|
1636 |
+
result.regroup(regroup, verbose=args['verbose'] or debug)
|
1637 |
+
if max_chars or max_words:
|
1638 |
+
result.split_by_length(max_chars=max_chars, max_words=max_words)
|
1639 |
+
|
1640 |
+
for path in output_paths:
|
1641 |
+
make_parent(path)
|
1642 |
+
save_method = getattr(result, OUTPUT_FORMATS_METHODS[splitext(path)[-1][1:]])
|
1643 |
+
args['filepath'] = path
|
1644 |
+
args['path'] = path
|
1645 |
+
save_options = isolate_useful_options(args, save_method)
|
1646 |
+
update_options_with_args('save_option', save_options)
|
1647 |
+
call_method_with_options(save_method, save_options)
|
1648 |
+
|
1649 |
+
|
1650 |
+
if __name__ == '__main__':
|
1651 |
+
cli()
|
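# Illustrative sketch (not part of the original file): example invocations of the CLI defined above,
# assuming the package's __main__ dispatches to cli(); file names are placeholders.
#   python -m stable_whisper audio.mp3 -o audio.srt --model base --vad 1
#   python -m stable_whisper audio.mp3 --align transcript.txt --language en -o aligned.json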