init
LICENSE
ADDED
@@ -0,0 +1,203 @@
Copyright 2023- The HuggingFace Inc. team and The OpenAI Authors. All rights reserved.

Apache License
Version 2.0, January 2004
http://www.apache.org/licenses/

TERMS AND CONDITIONS FOR USE, REPRODUCTION, AND DISTRIBUTION

1. Definitions.

"License" shall mean the terms and conditions for use, reproduction,
and distribution as defined by Sections 1 through 9 of this document.

"Licensor" shall mean the copyright owner or entity authorized by
the copyright owner that is granting the License.

"Legal Entity" shall mean the union of the acting entity and all
other entities that control, are controlled by, or are under common
control with that entity. For the purposes of this definition,
"control" means (i) the power, direct or indirect, to cause the
direction or management of such entity, whether by contract or
otherwise, or (ii) ownership of fifty percent (50%) or more of the
outstanding shares, or (iii) beneficial ownership of such entity.

"You" (or "Your") shall mean an individual or Legal Entity
exercising permissions granted by this License.

"Source" form shall mean the preferred form for making modifications,
including but not limited to software source code, documentation
source, and configuration files.

"Object" form shall mean any form resulting from mechanical
transformation or translation of a Source form, including but
not limited to compiled object code, generated documentation,
and conversions to other media types.

"Work" shall mean the work of authorship, whether in Source or
Object form, made available under the License, as indicated by a
copyright notice that is included in or attached to the work
(an example is provided in the Appendix below).

"Derivative Works" shall mean any work, whether in Source or Object
form, that is based on (or derived from) the Work and for which the
editorial revisions, annotations, elaborations, or other modifications
represent, as a whole, an original work of authorship. For the purposes
of this License, Derivative Works shall not include works that remain
separable from, or merely link (or bind by name) to the interfaces of,
the Work and Derivative Works thereof.

"Contribution" shall mean any work of authorship, including
the original version of the Work and any modifications or additions
to that Work or Derivative Works thereof, that is intentionally
submitted to Licensor for inclusion in the Work by the copyright owner
or by an individual or Legal Entity authorized to submit on behalf of
the copyright owner. For the purposes of this definition, "submitted"
means any form of electronic, verbal, or written communication sent
to the Licensor or its representatives, including but not limited to
communication on electronic mailing lists, source code control systems,
and issue tracking systems that are managed by, or on behalf of, the
Licensor for the purpose of discussing and improving the Work, but
excluding communication that is conspicuously marked or otherwise
designated in writing by the copyright owner as "Not a Contribution."

"Contributor" shall mean Licensor and any individual or Legal Entity
on behalf of whom a Contribution has been received by Licensor and
subsequently incorporated within the Work.

2. Grant of Copyright License. Subject to the terms and conditions of
this License, each Contributor hereby grants to You a perpetual,
worldwide, non-exclusive, no-charge, royalty-free, irrevocable
copyright license to reproduce, prepare Derivative Works of,
publicly display, publicly perform, sublicense, and distribute the
Work and such Derivative Works in Source or Object form.

3. Grant of Patent License. Subject to the terms and conditions of
this License, each Contributor hereby grants to You a perpetual,
worldwide, non-exclusive, no-charge, royalty-free, irrevocable
(except as stated in this section) patent license to make, have made,
use, offer to sell, sell, import, and otherwise transfer the Work,
where such license applies only to those patent claims licensable
by such Contributor that are necessarily infringed by their
Contribution(s) alone or by combination of their Contribution(s)
with the Work to which such Contribution(s) was submitted. If You
institute patent litigation against any entity (including a
cross-claim or counterclaim in a lawsuit) alleging that the Work
or a Contribution incorporated within the Work constitutes direct
or contributory patent infringement, then any patent licenses
granted to You under this License for that Work shall terminate
as of the date such litigation is filed.

4. Redistribution. You may reproduce and distribute copies of the
Work or Derivative Works thereof in any medium, with or without
modifications, and in Source or Object form, provided that You
meet the following conditions:

(a) You must give any other recipients of the Work or
Derivative Works a copy of this License; and

(b) You must cause any modified files to carry prominent notices
stating that You changed the files; and

(c) You must retain, in the Source form of any Derivative Works
that You distribute, all copyright, patent, trademark, and
attribution notices from the Source form of the Work,
excluding those notices that do not pertain to any part of
the Derivative Works; and

(d) If the Work includes a "NOTICE" text file as part of its
distribution, then any Derivative Works that You distribute must
include a readable copy of the attribution notices contained
within such NOTICE file, excluding those notices that do not
pertain to any part of the Derivative Works, in at least one
of the following places: within a NOTICE text file distributed
as part of the Derivative Works; within the Source form or
documentation, if provided along with the Derivative Works; or,
within a display generated by the Derivative Works, if and
wherever such third-party notices normally appear. The contents
of the NOTICE file are for informational purposes only and
do not modify the License. You may add Your own attribution
notices within Derivative Works that You distribute, alongside
or as an addendum to the NOTICE text from the Work, provided
that such additional attribution notices cannot be construed
as modifying the License.

You may add Your own copyright statement to Your modifications and
may provide additional or different license terms and conditions
for use, reproduction, or distribution of Your modifications, or
for any such Derivative Works as a whole, provided Your use,
reproduction, and distribution of the Work otherwise complies with
the conditions stated in this License.

5. Submission of Contributions. Unless You explicitly state otherwise,
any Contribution intentionally submitted for inclusion in the Work
by You to the Licensor shall be under the terms and conditions of
this License, without any additional terms or conditions.
Notwithstanding the above, nothing herein shall supersede or modify
the terms of any separate license agreement you may have executed
with Licensor regarding such Contributions.

6. Trademarks. This License does not grant permission to use the trade
names, trademarks, service marks, or product names of the Licensor,
except as required for reasonable and customary use in describing the
origin of the Work and reproducing the content of the NOTICE file.

7. Disclaimer of Warranty. Unless required by applicable law or
agreed to in writing, Licensor provides the Work (and each
Contributor provides its Contributions) on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or
implied, including, without limitation, any warranties or conditions
of TITLE, NON-INFRINGEMENT, MERCHANTABILITY, or FITNESS FOR A
PARTICULAR PURPOSE. You are solely responsible for determining the
appropriateness of using or redistributing the Work and assume any
risks associated with Your exercise of permissions under this License.

8. Limitation of Liability. In no event and under no legal theory,
whether in tort (including negligence), contract, or otherwise,
unless required by applicable law (such as deliberate and grossly
negligent acts) or agreed to in writing, shall any Contributor be
liable to You for damages, including any direct, indirect, special,
incidental, or consequential damages of any character arising as a
result of this License or out of the use or inability to use the
Work (including but not limited to damages for loss of goodwill,
work stoppage, computer failure or malfunction, or any and all
other commercial damages or losses), even if such Contributor
has been advised of the possibility of such damages.

9. Accepting Warranty or Additional Liability. While redistributing
the Work or Derivative Works thereof, You may choose to offer,
and charge a fee for, acceptance of support, warranty, indemnity,
or other liability obligations and/or rights consistent with this
License. However, in accepting such obligations, You may act only
on Your own behalf and on Your sole responsibility, not on behalf
of any other Contributor, and only if You agree to indemnify,
defend, and hold each Contributor harmless for any liability
incurred by, or claims asserted against, such Contributor by reason
of your accepting any such warranty or additional liability.

END OF TERMS AND CONDITIONS

APPENDIX: How to apply the Apache License to your work.

To apply the Apache License to your work, attach the following
boilerplate notice, with the fields enclosed by brackets "[]"
replaced with your own identifying information. (Don't include
the brackets!) The text should be enclosed in the appropriate
comment syntax for the file format. We also recommend that a
file or class name and description of purpose be included on the
same "printed page" as the copyright notice for easier
identification within third-party archives.

Copyright [yyyy] [name of copyright owner]

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
README.md
CHANGED
@@ -9,4 +9,471 @@ app_file: app.py
pinned: false
---

# Whisper JAX

This repository contains optimised JAX code for OpenAI's [Whisper Model](https://arxiv.org/abs/2212.04356), largely built
on the 🤗 Hugging Face Transformers Whisper implementation. Compared to OpenAI's PyTorch code, Whisper JAX runs over **70x**
faster, making it the fastest Whisper implementation available.

The JAX code is compatible with CPU, GPU and TPU, and can be run standalone (see [Pipeline Usage](#pipeline-usage)) or
as an inference endpoint (see [Creating an Endpoint](#creating-an-endpoint)).

For a quick-start guide to running Whisper JAX on a Cloud TPU, refer to the following Kaggle notebook, where we transcribe 30 mins of audio in approx 30 sec:

[Whisper JAX TPU Kaggle notebook](https://www.kaggle.com/code/sgandhi99/whisper-jax-tpu)

The Whisper JAX model is also running as a demo on the Hugging Face Hub:

[Whisper JAX demo](https://huggingface.co/spaces/sanchit-gandhi/whisper-jax)

## Installation

Whisper JAX was tested using Python 3.9 and JAX version 0.4.5. Installation assumes that you already have the latest
version of the JAX package installed on your device. You can do so using the official JAX installation guide: https://github.com/google/jax#installation

Once the appropriate version of JAX has been installed, Whisper JAX can be installed through pip:
```
pip install git+https://github.com/sanchit-gandhi/whisper-jax.git
```

To update the Whisper JAX package to the latest version, simply run:
```
pip install --upgrade --no-deps --force-reinstall git+https://github.com/sanchit-gandhi/whisper-jax.git
```
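
If you want to confirm that JAX can see your accelerator devices before installing Whisper JAX, the following optional sanity check (not part of the original instructions) uses only the standard JAX API:

```python
# Optional sanity check: confirm the JAX install and the devices it can see.
import jax

print(jax.__version__)     # e.g. 0.4.5, the version Whisper JAX was tested with
print(jax.devices())       # lists the available CPU/GPU/TPU devices
print(jax.device_count())  # number of devices data parallelism can spread over
```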

## Pipeline Usage

The recommended way of running Whisper JAX is through the [`FlaxWhisperPipline`](https://github.com/sanchit-gandhi/whisper-jax/blob/main/whisper_jax/pipeline.py#L57) abstraction class. This class handles all
the necessary pre- and post-processing, as well as wrapping the generate method for data parallelism across accelerator devices.

Whisper JAX makes use of JAX's [`pmap`](https://jax.readthedocs.io/en/latest/_autosummary/jax.pmap.html) function for data parallelism across GPU/TPU devices. This function is _Just In Time (JIT)_
compiled the first time it is called. Thereafter, the function will be _cached_, enabling it to be run in super-fast time:

```python
from whisper_jax import FlaxWhisperPipline

# instantiate pipeline
pipeline = FlaxWhisperPipline("openai/whisper-large-v2")

# JIT compile the forward call - slow, but we only do it once
text = pipeline("audio.mp3")

# use the cached function thereafter - super fast!!
text = pipeline("audio.mp3")
```

### Half-Precision

The model computation can be run in half-precision by passing the dtype argument when instantiating the pipeline. This will
speed up the computation quite considerably by storing intermediate tensors in half-precision. There is no change to the precision
of the model weights.

For most GPUs, the dtype should be set to `jnp.float16`. For A100 GPUs or TPUs, the dtype should be set to `jnp.bfloat16`:
```python
from whisper_jax import FlaxWhisperPipline
import jax.numpy as jnp

# instantiate pipeline in bfloat16
pipeline = FlaxWhisperPipline("openai/whisper-large-v2", dtype=jnp.bfloat16)
```

### Batching

Whisper JAX also provides the option of _batching_ a single audio input across accelerator devices. The audio is first
chunked into 30-second segments, and the chunks are then dispatched to the model to be transcribed in parallel. The resulting
transcriptions are stitched back together at the boundaries to give a single, uniform transcription. In practice, batching
provides a 10x speed-up compared to transcribing the audio samples sequentially, with a less than 1% penalty to the WER[^1], provided the batch size is selected large enough.

To enable batching, pass the `batch_size` parameter when you instantiate the pipeline:

```python
from whisper_jax import FlaxWhisperPipline

# instantiate pipeline with batching
pipeline = FlaxWhisperPipline("openai/whisper-large-v2", batch_size=16)
```
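
To make the chunking described above concrete, the sketch below reproduces the kind of arithmetic used to split a long audio array into overlapping 30-second chunks. It mirrors the constants used by the demo code in `app.py` (a stride of one sixth of the chunk length); it is an illustration of the idea, not the pipeline's exact internals:

```python
import numpy as np

# assumed values, mirroring the demo app in this repository
sampling_rate = 16000                  # Whisper's expected sampling rate
chunk_length_s = 30                    # each chunk covers 30 seconds of audio
stride_length_s = chunk_length_s / 6   # overlap retained on each side of a chunk

chunk_len = round(chunk_length_s * sampling_rate)
stride = round(stride_length_s * sampling_rate)
step = chunk_len - 2 * stride          # fresh audio contributed by each successive chunk

audio = np.zeros(10 * 60 * sampling_rate, dtype=np.float32)  # dummy 10-minute input
chunk_start_idx = np.arange(0, audio.shape[0], step)
print(f"{len(chunk_start_idx)} chunks of {chunk_length_s}s to batch and transcribe in parallel")
```

The chunks are transcribed in parallel, and the overlapping regions are used to stitch the individual transcriptions back into a single sequence.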

### Task

By default, the pipeline transcribes the audio file in the language it was spoken in. For speech translation, set the
`task` argument to `"translate"`:

```python
# translate
text = pipeline("audio.mp3", task="translate")
```

### Timestamps

The [`FlaxWhisperPipline`](https://github.com/sanchit-gandhi/whisper-jax/blob/main/whisper_jax/pipeline.py#L57) also supports timestamp prediction. Note that enabling timestamps will require a second JIT compilation of the
forward call, this time including the timestamp outputs:

```python
# transcribe and return timestamps
outputs = pipeline("audio.mp3", task="transcribe", return_timestamps=True)
text = outputs["text"]      # transcription
chunks = outputs["chunks"]  # transcription + timestamps
```
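
Each entry in `chunks` pairs a text segment with a `(start, end)` timestamp in seconds; this is the same structure the demo app below (`app.py`) reads via `chunk["text"]` and `chunk["timestamp"]`. A small illustrative example of formatting it (the values shown are made up):

```python
# Illustrative only: example chunk structure and how it can be pretty-printed.
example_chunks = [
    {"text": " Hello world.", "timestamp": (0.0, 2.5)},
    {"text": " This is Whisper JAX.", "timestamp": (2.5, 5.0)},
]

for chunk in example_chunks:
    start, end = chunk["timestamp"]
    print(f"[{start:.2f}s -> {end:.2f}s]{chunk['text']}")
```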

### Putting it all together

In the following code snippet, we instantiate the model in bfloat16 precision with batching enabled, and transcribe the audio file,
returning timestamp tokens:

```python
from whisper_jax import FlaxWhisperPipline
import jax.numpy as jnp

# instantiate pipeline with bfloat16 and enable batching
pipeline = FlaxWhisperPipline("openai/whisper-large-v2", dtype=jnp.bfloat16, batch_size=16)

# transcribe and return timestamps
outputs = pipeline("audio.mp3", task="transcribe", return_timestamps=True)
```

## Model Usage

The Whisper JAX model can be used on a more granular level in much the same way as the original Hugging Face
Transformers implementation. This requires the Whisper processor to be loaded separately to the model to handle the
pre- and post-processing, and the generate function to be wrapped using `pmap` by hand:

```python
import jax.numpy as jnp
from datasets import load_dataset
from flax.jax_utils import replicate
from flax.training.common_utils import shard
from jax import device_get, pmap
from transformers import WhisperProcessor

from whisper_jax import FlaxWhisperForConditionalGeneration

# load the processor and model
processor = WhisperProcessor.from_pretrained("openai/whisper-large-v2")
model, params = FlaxWhisperForConditionalGeneration.from_pretrained(
    "openai/whisper-large-v2", dtype=jnp.bfloat16, _do_init=False,
)

def generate_fn(input_features):
    pred_ids = model.generate(
        input_features, task="transcribe", return_timestamps=False, max_length=model.config.max_length, params=params,
    )
    return pred_ids.sequences

# pmap the generate function for data parallelism
p_generate = pmap(generate_fn, "input_features")
# replicate the parameters across devices
params = replicate(params)

# load a dummy sample from the LibriSpeech dataset
ds = load_dataset("hf-internal-testing/librispeech_asr_dummy", "clean", split="validation")
sample = ds[0]["audio"]

# pre-process: convert the audio array to log-mel input features
input_features = processor(sample["array"], sampling_rate=sample["sampling_rate"], return_tensors="np").input_features
# replicate the input features across devices for DP
input_features = shard(input_features)

# run the forward pass (JIT compiled the first time it is called)
pred_ids = p_generate(input_features)
output_ids = device_get(pred_ids.reshape(-1, model.config.max_length))

# post-process: convert token ids to text string
transcription = processor.batch_decode(output_ids, skip_special_tokens=True)
```

## Available Models and Languages

All Whisper models on the Hugging Face Hub with Flax weights are compatible with Whisper JAX. This includes, but is not limited to,
the official OpenAI Whisper checkpoints:

| Size     | Parameters | English-only                                         | Multilingual                                        |
|----------|------------|------------------------------------------------------|-----------------------------------------------------|
| tiny     | 39 M       | [✓](https://huggingface.co/openai/whisper-tiny.en)   | [✓](https://huggingface.co/openai/whisper-tiny)     |
| base     | 74 M       | [✓](https://huggingface.co/openai/whisper-base.en)   | [✓](https://huggingface.co/openai/whisper-base)     |
| small    | 244 M      | [✓](https://huggingface.co/openai/whisper-small.en)  | [✓](https://huggingface.co/openai/whisper-small)    |
| medium   | 769 M      | [✓](https://huggingface.co/openai/whisper-medium.en) | [✓](https://huggingface.co/openai/whisper-medium)   |
| large    | 1550 M     | x                                                    | [✓](https://huggingface.co/openai/whisper-large)    |
| large-v2 | 1550 M     | x                                                    | [✓](https://huggingface.co/openai/whisper-large-v2) |

Should you wish to use a fine-tuned Whisper checkpoint in Whisper JAX, you should first convert the PyTorch weights to Flax.
This is straightforward through use of the `from_pt` argument, which will convert the PyTorch state dict to a frozen Flax
parameter dictionary on the fly. You can then push the converted Flax weights to the Hub to be used directly in Flax
the next time they are required. Note that converting weights from PyTorch to Flax requires both PyTorch and Flax to be installed.

For example, to convert the fine-tuned checkpoint [`sanchit-gandhi/whisper-small-hi`](https://huggingface.co/sanchit-gandhi/whisper-small-hi) from the blog post [Fine-Tuning Whisper](https://huggingface.co/blog/fine-tune-whisper):
```python
from whisper_jax import FlaxWhisperForConditionalGeneration, FlaxWhisperPipline
import jax.numpy as jnp

checkpoint_id = "sanchit-gandhi/whisper-small-hi"
# convert PyTorch weights to Flax
model = FlaxWhisperForConditionalGeneration.from_pretrained(checkpoint_id, from_pt=True)
# push converted weights to the Hub
model.push_to_hub(checkpoint_id)

# now we can load the Flax weights directly as required
pipeline = FlaxWhisperPipline(checkpoint_id, dtype=jnp.bfloat16, batch_size=16)
```

## Advanced Usage

More advanced users may wish to explore different parallelisation techniques. The Whisper JAX code is
built on top of the [T5x codebase](https://github.com/google-research/t5x), meaning it can be run using model, activation, and data parallelism using the T5x
partitioning convention. To use T5x partitioning, the logical axis rules and number of model partitions must be defined.
For more details, the user is referred to the official T5x partitioning guide: https://github.com/google-research/t5x/blob/main/docs/usage/partitioning.md

### Pipeline

The following code snippet demonstrates how data parallelism can be achieved using the pipeline `shard_params` method in
an entirely equivalent way to `pmap`:

```python
from whisper_jax import FlaxWhisperPipline
import jax.numpy as jnp

# 2D parameter and activation partitioning for DP
logical_axis_rules_dp = (
    ("batch", "data"),
    ("mlp", None),
    ("heads", None),
    ("vocab", None),
    ("embed", None),
    ("embed", None),
    ("joined_kv", None),
    ("kv", None),
    ("length", None),
    ("num_mel", None),
    ("channels", None),
)

pipeline = FlaxWhisperPipline("openai/whisper-large-v2", dtype=jnp.bfloat16, batch_size=16)
pipeline.shard_params(num_mp_partitions=1, logical_axis_rules=logical_axis_rules_dp)
```

### Model

It is also possible to use the Whisper JAX model with T5x partitioning by defining a T5x inference state and T5x partitioner:

```python
import jax
import jax.numpy as jnp
from flax.core.frozen_dict import freeze
from jax.sharding import PartitionSpec as P

from whisper_jax import FlaxWhisperForConditionalGeneration, InferenceState, PjitPartitioner


# 2D parameter and activation partitioning for DP
logical_axis_rules_dp = [
    ("batch", "data"),
    ("mlp", None),
    ("heads", None),
    ("vocab", None),
    ("embed", None),
    ("embed", None),
    ("joined_kv", None),
    ("kv", None),
    ("length", None),
    ("num_mel", None),
    ("channels", None),
]

model, params = FlaxWhisperForConditionalGeneration.from_pretrained(
    "openai/whisper-large-v2",
    _do_init=False,
    dtype=jnp.bfloat16,
)


def init_fn():
    input_shape = (1, 80, 3000)

    input_features = jnp.zeros(input_shape, dtype="f4")
    input_features = input_features.at[(..., -1)].set(model.config.eos_token_id)

    decoder_input_ids = jnp.zeros((input_shape[0], 1), dtype="i4")
    decoder_attention_mask = jnp.ones_like(decoder_input_ids)

    batch_size, sequence_length = decoder_input_ids.shape
    decoder_position_ids = jnp.broadcast_to(jnp.arange(sequence_length)[None, :], (batch_size, sequence_length))

    rng = jax.random.PRNGKey(0)
    init_params = model.module.init(
        rng,
        input_features=input_features,
        decoder_input_ids=decoder_input_ids,
        decoder_attention_mask=decoder_attention_mask,
        decoder_position_ids=decoder_position_ids,
        return_dict=False,
    )
    return init_params


# Axis names metadata
param_axes = jax.eval_shape(init_fn)["params_axes"]

# Create InferenceState, since the partitioner expects it
state = InferenceState(
    step=jnp.array(0),
    params=freeze(model.params_shape_tree),
    params_axes=freeze(param_axes),
    flax_mutables=None,
    flax_mutables_axes=param_axes,
)

# Define the pjit partitioner with 1 model partition
partitioner = PjitPartitioner(
    num_partitions=1,
    logical_axis_rules=logical_axis_rules_dp,
)

mesh_axes = partitioner.get_mesh_axes(state)
params_spec = mesh_axes.params

p_shard_params = partitioner.partition(model.to_bf16, (params_spec,), params_spec)


def generate(params, input_features):
    output_ids = model.generate(input_features, params=params, max_length=model.config.max_length).sequences
    return output_ids


p_generate = partitioner.partition(
    generate,
    in_axis_resources=(params_spec, P("data")),
    out_axis_resources=P("data"),
)

# This will auto-magically run in mesh context
params = p_shard_params(freeze(params))

# you can now run the forward pass with:
# pred_ids = p_generate(input_features)
```

## Benchmarks

We compare Whisper JAX to the official [OpenAI implementation](https://github.com/openai/whisper) and the [🤗 Transformers
implementation](https://huggingface.co/docs/transformers/model_doc/whisper). We benchmark the models on audio samples of
increasing length and report the average inference time in seconds over 10 repeat runs. For all three systems, we pass a
pre-loaded audio file to the model and measure the time for the forward pass. Leaving the task of loading the audio file
to the systems adds an equal offset to all the benchmark times, so the actual time for loading **and** transcribing an
audio file will be higher than the reported numbers.

OpenAI and Transformers both run in PyTorch on GPU. Whisper JAX runs in JAX on GPU and TPU. OpenAI transcribes the audio
sequentially in the order it is spoken. Both Transformers and Whisper JAX use a batching algorithm, where chunks of audio
are batched together and transcribed in parallel (see section [Batching](#batching)).
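
As a rough illustration of this measurement protocol (the actual benchmark script is not included here), timing the cached forward call over repeat runs could look like the following sketch. It assumes the pipeline accepts a pre-loaded `{"array", "sampling_rate"}` dict, mirroring how the demo app feeds audio to the processor, and uses a placeholder `audio.mp3` file:

```python
import time

import jax.numpy as jnp
from transformers.pipelines.audio_utils import ffmpeg_read

from whisper_jax import FlaxWhisperPipline

pipeline = FlaxWhisperPipline("openai/whisper-large-v2", dtype=jnp.bfloat16, batch_size=16)

# pre-load the audio once so file loading is excluded from the measurement
with open("audio.mp3", "rb") as f:
    audio = ffmpeg_read(f.read(), sampling_rate=16000)
inputs = {"array": audio, "sampling_rate": 16000}

pipeline(inputs)  # first call triggers JIT compilation and is excluded from the timing

num_runs = 10
start = time.time()
for _ in range(num_runs):
    pipeline(inputs)
print(f"average inference time: {(time.time() - start) / num_runs:.2f} s")
```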

**Table 1:** Average inference time in seconds for audio files of increasing length. GPU device is a single A100 40GB GPU.
TPU device is a single TPU v4-8.

<div align="center">

|           | OpenAI  | Transformers | Whisper JAX | Whisper JAX |
|-----------|---------|--------------|-------------|-------------|
|           |         |              |             |             |
| Framework | PyTorch | PyTorch      | JAX         | JAX         |
| Backend   | GPU     | GPU          | GPU         | TPU         |
|           |         |              |             |             |
| 1 min     | 13.8    | 4.54         | 1.72        | 0.45        |
| 10 min    | 108.3   | 20.2         | 9.38        | 2.01        |
| 1 hour    | 1001.0  | 126.1        | 75.3        | 13.8        |
|           |         |              |             |             |

</div>

## Creating an Endpoint

The Whisper JAX model is running as a demo on the Hugging Face Hub:

[Whisper JAX demo](https://huggingface.co/spaces/sanchit-gandhi/whisper-jax)

However, at peak times there may be a queue of users that limits how quickly your audio input is transcribed. In this case,
you may benefit from running the model yourself, such that you have unrestricted access to the Whisper JAX model.

If you are just interested in running the model in a standalone Python script, refer to the Kaggle notebook Whisper JAX TPU:

[Whisper JAX TPU Kaggle notebook](https://www.kaggle.com/code/sgandhi99/whisper-jax-tpu)

Otherwise, we provide all the necessary code for creating an inference endpoint. To obtain this code, first clone the
repository on the GPU/TPU on which you want to host the endpoint:
```
git clone https://github.com/sanchit-gandhi/whisper-jax
```

And then install Whisper JAX from source, with the required additional endpoint dependencies:
```
cd whisper-jax
pip install -e .["endpoint"]
```

We recommend that you set up an endpoint in the same zone/region as the one you are based in. This reduces the communication
time between your local machine and the remote one, which can significantly reduce the overall request time.

The Python script [`fastapi_app.py`](app/fastapi_app.py) contains the code to launch a FastAPI app with the Whisper large-v2 model.
By default, it uses a batch size of 16 and bfloat16 half-precision. You should update these parameters depending on your
GPU/TPU device (as explained in the sections on [Half-precision](#half-precision) and [Batching](#batching)).

You can launch the FastAPI app through Uvicorn using the bash script [`launch_app.sh`](app/launch_app.sh):
```
bash launch_app.sh
```

This will open port 8000 for the FastAPI app. To direct network requests to the FastAPI app, we use ngrok to launch a
server on the corresponding port:
```
ngrok http --subdomain=whisper-jax 8000
```

We can now send JSON requests to our endpoint using ngrok. The function `transcribe_audio` loads an audio file, encodes it
in bytes, sends it to our endpoint, and returns the transcription:

```python
import base64
from transformers.pipelines.audio_utils import ffmpeg_read
import requests

API_URL = "https://whisper-jax.ngrok.io/generate/"  # make sure this URL matches your ngrok subdomain


def query(payload):
    """Send json payload to ngrok API URL and return response."""
    response = requests.post(API_URL, json=payload)
    return response.json(), response.status_code


def transcribe_audio(audio_file, task="transcribe", return_timestamps=False):
    with open(audio_file, "rb") as f:
        inputs = f.read()
    inputs = ffmpeg_read(inputs, sampling_rate=16000)
    # encode to bytes to make json compatible
    inputs = {"array": base64.b64encode(inputs.tobytes()).decode(), "sampling_rate": 16000}
    # format as a json payload and send query
    payload = {"inputs": inputs, "task": task, "return_timestamps": return_timestamps}
    data, status_code = query(payload)

    if status_code == 200:
        output = {"text": data["text"], "chunks": data.get("chunks", None)}
    else:
        output = data["detail"]

    return output


# transcribe an audio file using our endpoint
output = transcribe_audio("audio.mp3")
```

Note that this code snippet sends a base64 byte encoding of the audio file to the remote machine over [`requests`](https://requests.readthedocs.io).
In some cases, transferring the audio request from the local machine to the remote one can take longer than actually
transcribing it. Therefore, you may wish to explore more efficient methods of sending requests, such as parallel
requests/transcription (see the function `transcribe_chunked_audio` in [app.py](app/app.py)).
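
As a rough sketch of the parallel-request idea (a simplified illustration, not the `transcribe_chunked_audio` implementation itself), you could split the audio locally into non-overlapping 30-second chunks and post them concurrently, reusing the `query` helper and payload format from the snippet above:

```python
import base64
from concurrent.futures import ThreadPoolExecutor

from transformers.pipelines.audio_utils import ffmpeg_read

# Simplified sketch: assumes the `query` helper and payload format defined above,
# and that each chunk can be transcribed as an independent request.
CHUNK_LENGTH_S = 30
SAMPLING_RATE = 16000


def transcribe_in_parallel(audio_file, max_workers=4):
    with open(audio_file, "rb") as f:
        audio = ffmpeg_read(f.read(), sampling_rate=SAMPLING_RATE)

    # naive non-overlapping split into 30-second chunks
    chunk_len = CHUNK_LENGTH_S * SAMPLING_RATE
    chunks = [audio[i : i + chunk_len] for i in range(0, len(audio), chunk_len)]

    def send(chunk):
        inputs = {"array": base64.b64encode(chunk.tobytes()).decode(), "sampling_rate": SAMPLING_RATE}
        data, status_code = query({"inputs": inputs, "task": "transcribe", "return_timestamps": False})
        return data["text"] if status_code == 200 else ""

    with ThreadPoolExecutor(max_workers=max_workers) as executor:
        texts = executor.map(send, chunks)

    return " ".join(texts)
```

Note that a naive split like this loses the boundary stitching performed by the batched pipeline, so words that straddle a chunk boundary may be transcribed imperfectly.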

Finally, we can create a Gradio demo for the frontend, the code for which resides in [`app.py`](app/app.py). You can launch this
application by providing the ngrok subdomain:
```
API_URL=https://whisper-jax.ngrok.io/generate/ API_URL_FROM_FEATURES=https://whisper-jax.ngrok.io/generate_from_features/ python app.py
```

This will launch a Gradio demo with the same interface as the official Whisper JAX demo.

## Acknowledgements

* 🤗 Hugging Face Transformers for the base Whisper implementation, particularly to [andyehrenberg](https://github.com/andyehrenberg) for the [Flax Whisper PR](https://github.com/huggingface/transformers/pull/20479) and [ArthurZucker](https://github.com/ArthurZucker) for the batching algorithm
* Gradio for their easy-to-use package for building ML demos, and [pcuenca](https://github.com/pcuenca) for the help in hooking the demo up to the TPU
* Google's [TPU Research Cloud (TRC)](https://sites.research.google/trc/about/) programme for Cloud TPUs

[^1]: See WER results from Colab: https://colab.research.google.com/drive/1rS1L4YSJqKUH_3YxIQHBI982zso23wor?usp=sharing
app.py
ADDED
@@ -0,0 +1,251 @@
import base64
import math
import os
import time
from multiprocessing import Pool

import gradio as gr
import numpy as np
import pytube
import requests
from processing_whisper import WhisperPrePostProcessor
from transformers.models.whisper.tokenization_whisper import TO_LANGUAGE_CODE
from transformers.pipelines.audio_utils import ffmpeg_read


title = "Whisper JAX: The Fastest Whisper API ⚡️"

description = """Whisper JAX is an optimised implementation of the [Whisper model](https://huggingface.co/openai/whisper-large-v2) by OpenAI. It runs on JAX with a TPU v4-8 in the backend. Compared to PyTorch on an A100 GPU, it is over [**70x faster**](https://github.com/sanchit-gandhi/whisper-jax#benchmarks), making it the fastest Whisper API available.

Note that at peak times, you may find yourself in the queue for this demo. When you submit a request, your queue position will be shown in the top right-hand side of the demo pane. Once you reach the front of the queue, your audio file will be transcribed, with the progress displayed through a progress bar.

To skip the queue, you may wish to create your own inference endpoint, details for which can be found in the [Whisper JAX repository](https://github.com/sanchit-gandhi/whisper-jax#creating-an-endpoint).
"""

article = "Whisper large-v2 model by OpenAI. Backend running JAX on a TPU v4-8 through the generous support of the [TRC](https://sites.research.google/trc/about/) programme. Whisper JAX [code](https://github.com/sanchit-gandhi/whisper-jax) and Gradio demo by 🤗 Hugging Face."

API_URL = os.getenv("API_URL")
API_URL_FROM_FEATURES = os.getenv("API_URL_FROM_FEATURES")
language_names = sorted(TO_LANGUAGE_CODE.keys())
CHUNK_LENGTH_S = 30
BATCH_SIZE = 16
NUM_PROC = 16
FILE_LIMIT_MB = 1000


def query(payload):
    # send a raw audio payload to the FastAPI endpoint
    response = requests.post(API_URL, json=payload)
    return response.json(), response.status_code


def inference(inputs, task=None, return_timestamps=False):
    payload = {"inputs": inputs, "task": task, "return_timestamps": return_timestamps}

    data, status_code = query(payload)

    if status_code != 200:
        # error with our request - return the details to the user
        raise gr.Error(data["detail"])

    text = data["text"]
    timestamps = data.get("chunks")
    if timestamps is not None:
        timestamps = [
            f"[{format_timestamp(chunk['timestamp'][0])} -> {format_timestamp(chunk['timestamp'][1])}] {chunk['text']}"
            for chunk in timestamps
        ]
        text = "\n".join(str(feature) for feature in timestamps)
    return text


def chunked_query(payload):
    # send pre-computed log-mel features to the FastAPI endpoint
    response = requests.post(API_URL_FROM_FEATURES, json=payload)
    return response.json(), response.status_code


def forward(batch, task=None, return_timestamps=False):
    feature_shape = batch["input_features"].shape
    batch["input_features"] = base64.b64encode(batch["input_features"].tobytes()).decode()
    outputs, status_code = chunked_query(
        {"batch": batch, "task": task, "return_timestamps": return_timestamps, "feature_shape": feature_shape}
    )
    if status_code != 200:
        # error with our request - return the details to the user
        raise gr.Error(outputs["detail"])
    outputs["tokens"] = np.asarray(outputs["tokens"])
    return outputs


def identity(batch):
    return batch


# Copied from https://github.com/openai/whisper/blob/c09a7ae299c4c34c5839a76380ae407e7d785914/whisper/utils.py#L50
def format_timestamp(seconds: float, always_include_hours: bool = False, decimal_marker: str = "."):
    if seconds is not None:
        milliseconds = round(seconds * 1000.0)

        hours = milliseconds // 3_600_000
        milliseconds -= hours * 3_600_000

        minutes = milliseconds // 60_000
        milliseconds -= minutes * 60_000

        seconds = milliseconds // 1_000
        milliseconds -= seconds * 1_000

        hours_marker = f"{hours:02d}:" if always_include_hours or hours > 0 else ""
        return f"{hours_marker}{minutes:02d}:{seconds:02d}{decimal_marker}{milliseconds:03d}"
    else:
        # we have a malformed timestamp so just return it as is
        return seconds


if __name__ == "__main__":
    processor = WhisperPrePostProcessor.from_pretrained("openai/whisper-large-v2")
    stride_length_s = CHUNK_LENGTH_S / 6
    chunk_len = round(CHUNK_LENGTH_S * processor.feature_extractor.sampling_rate)
    stride_left = stride_right = round(stride_length_s * processor.feature_extractor.sampling_rate)
    step = chunk_len - stride_left - stride_right
    pool = Pool(NUM_PROC)

    def tqdm_generate(inputs: dict, task: str, return_timestamps: bool, progress: gr.Progress):
        inputs_len = inputs["array"].shape[0]
        all_chunk_start_idx = np.arange(0, inputs_len, step)
        num_samples = len(all_chunk_start_idx)
        num_batches = math.ceil(num_samples / BATCH_SIZE)
        dummy_batches = list(
            range(num_batches)
        )  # Gradio progress bar not compatible with generator, see https://github.com/gradio-app/gradio/issues/3841

        dataloader = processor.preprocess_batch(inputs, chunk_length_s=CHUNK_LENGTH_S, batch_size=BATCH_SIZE)
        progress(0, desc="Pre-processing audio file...")
        dataloader = pool.map(identity, dataloader)

        model_outputs = []
        start_time = time.time()
        # iterate over our chunked audio samples
        for batch, _ in zip(dataloader, progress.tqdm(dummy_batches, desc="Transcribing...")):
            model_outputs.append(forward(batch, task=task, return_timestamps=return_timestamps))
        runtime = time.time() - start_time

        post_processed = processor.postprocess(model_outputs, return_timestamps=return_timestamps)
        text = post_processed["text"]
        timestamps = post_processed.get("chunks")
        if timestamps is not None:
            timestamps = [
                f"[{format_timestamp(chunk['timestamp'][0])} -> {format_timestamp(chunk['timestamp'][1])}] {chunk['text']}"
                for chunk in timestamps
            ]
            text = "\n".join(str(feature) for feature in timestamps)
        return text, runtime

    def transcribe_chunked_audio(inputs, task, return_timestamps, progress=gr.Progress()):
        progress(0, desc="Loading audio file...")
        if inputs is None:
            raise gr.Error("No audio file submitted! Please upload an audio file before submitting your request.")
        file_size_mb = os.stat(inputs).st_size / (1024 * 1024)
        if file_size_mb > FILE_LIMIT_MB:
            raise gr.Error(
                f"File size exceeds file size limit. Got file of size {file_size_mb:.2f}MB for a limit of {FILE_LIMIT_MB}MB."
            )

        with open(inputs, "rb") as f:
            inputs = f.read()

        inputs = ffmpeg_read(inputs, processor.feature_extractor.sampling_rate)
        inputs = {"array": inputs, "sampling_rate": processor.feature_extractor.sampling_rate}
        text, runtime = tqdm_generate(inputs, task=task, return_timestamps=return_timestamps, progress=progress)
        return text, runtime

    def _return_yt_html_embed(yt_url):
        video_id = yt_url.split("?v=")[-1]
        HTML_str = (
            f'<center> <iframe width="500" height="320" src="https://www.youtube.com/embed/{video_id}"> </iframe>'
            " </center>"
        )
        return HTML_str

    def transcribe_youtube(yt_url, task, return_timestamps, progress=gr.Progress(), max_filesize=75.0):
        progress(0, desc="Loading audio file...")
        html_embed_str = _return_yt_html_embed(yt_url)
        try:
            yt = pytube.YouTube(yt_url)
            stream = yt.streams.filter(only_audio=True)[0]
        except KeyError:
            raise gr.Error("An error occurred while loading the YouTube video. Please try again.")

        if stream.filesize_mb > max_filesize:
            raise gr.Error(f"Maximum YouTube file size is {max_filesize}MB, got {stream.filesize_mb:.2f}MB.")

        stream.download(filename="audio.mp3")

        with open("audio.mp3", "rb") as f:
            inputs = f.read()

        inputs = ffmpeg_read(inputs, processor.feature_extractor.sampling_rate)
        inputs = {"array": inputs, "sampling_rate": processor.feature_extractor.sampling_rate}
        text, runtime = tqdm_generate(inputs, task=task, return_timestamps=return_timestamps, progress=progress)
        return html_embed_str, text, runtime

    microphone_chunked = gr.Interface(
        fn=transcribe_chunked_audio,
        inputs=[
            gr.inputs.Audio(source="microphone", optional=True, type="filepath"),
            gr.inputs.Radio(["transcribe", "translate"], label="Task", default="transcribe"),
            gr.inputs.Checkbox(default=False, label="Return timestamps"),
        ],
        outputs=[
            gr.outputs.Textbox(label="Transcription").style(show_copy_button=True),
            gr.outputs.Textbox(label="Transcription Time (s)"),
        ],
        allow_flagging="never",
        title=title,
        description=description,
        article=article,
    )

    audio_chunked = gr.Interface(
        fn=transcribe_chunked_audio,
        inputs=[
            gr.inputs.Audio(source="upload", optional=True, label="Audio file", type="filepath"),
            gr.inputs.Radio(["transcribe", "translate"], label="Task", default="transcribe"),
            gr.inputs.Checkbox(default=False, label="Return timestamps"),
        ],
        outputs=[
            gr.outputs.Textbox(label="Transcription").style(show_copy_button=True),
            gr.outputs.Textbox(label="Transcription Time (s)"),
        ],
        allow_flagging="never",
        title=title,
        description=description,
        article=article,
    )

    youtube = gr.Interface(
        fn=transcribe_youtube,
        inputs=[
            gr.inputs.Textbox(lines=1, placeholder="Paste the URL to a YouTube video here", label="YouTube URL"),
            gr.inputs.Radio(["transcribe", "translate"], label="Task", default="transcribe"),
            gr.inputs.Checkbox(default=False, label="Return timestamps"),
        ],
        outputs=[
            gr.outputs.HTML(label="Video"),
            gr.outputs.Textbox(label="Transcription").style(show_copy_button=True),
            gr.outputs.Textbox(label="Transcription Time (s)"),
        ],
        allow_flagging="never",
        title=title,
        examples=[["https://www.youtube.com/watch?v=m8u-18Q0s7I", "transcribe", False]],
        cache_examples=False,
        description=description,
        article=article,
    )

    demo = gr.Blocks()

    with demo:
        gr.TabbedInterface([microphone_chunked, audio_chunked, youtube], ["Microphone", "Audio File", "YouTube"])

    demo.queue(concurrency_count=3, max_size=5)
    demo.launch(show_api=False, max_threads=10)