Commit 51fd668
Parent(s): 2503b95

Update asr_diarizer.py

asr_diarizer.py  CHANGED  (+82 -11)
@@ -2,8 +2,13 @@ from typing import List, Optional, Union
 
 import numpy as np
 import torch
+from torchaudio import functional as F
+
+import requests
+
 from pyannote.audio import Pipeline
 from transformers import pipeline
+from transformers.pipelines.audio_utils import ffmpeg_read
 
 
 class ASRDiarizationPipeline:
@@ -14,14 +19,16 @@ class ASRDiarizationPipeline:
     ):
         self.asr_pipeline = asr_pipeline
         self.diarization_pipeline = diarization_pipeline
+
+        self.sampling_rate = self.asr_pipeline.feature_extractor.sampling_rate
 
     @classmethod
     def from_pretrained(
         cls,
         asr_model: Optional[str] = "openai/whisper-small",
         diarizer_model: Optional[str] = "pyannote/speaker-diarization",
-        chunk_length_s: int = 30,
-        use_auth_token: Union[str, bool] = True,
+        chunk_length_s: Optional[int] = 30,
+        use_auth_token: Optional[Union[str, bool]] = True,
         **kwargs,
     ):
         asr_pipeline = pipeline(
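Because the constructor now caches `self.sampling_rate` from the ASR feature extractor, callers no longer pass a sampling rate at call time (see the `__call__` change in the next hunk). A minimal instantiation sketch, assuming the file is importable as a module named `asr_diarizer`:

from asr_diarizer import ASRDiarizationPipeline  # assumed module name, matching this file

pipe = ASRDiarizationPipeline.from_pretrained(
    asr_model="openai/whisper-small",
    diarizer_model="pyannote/speaker-diarization",
    use_auth_token=True,  # the pyannote model is gated and needs a Hugging Face token
)
print(pipe.sampling_rate)  # 16000 for Whisper checkpoints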
@@ -37,21 +44,42 @@ class ASRDiarizationPipeline:
     def __call__(
         self,
         inputs: Union[np.ndarray, List[np.ndarray]],
-        sampling_rate: int,
         group_by_speaker: bool = True,
         **kwargs,
     ):
-
-
-
-
+        """
+        Transcribe the audio sequence(s) given as inputs to text.
+
+        Args:
+            inputs (`np.ndarray` or `bytes` or `str` or `dict`):
+                The input is either:
+                    - `str`: the filename of the audio file; the file will be read at the correct sampling rate
+                      to get the waveform using *ffmpeg*. This requires *ffmpeg* to be installed on the system.
+                    - `bytes`: the content of an audio file, interpreted by *ffmpeg* in the
+                      same way.
+                    - (`np.ndarray` of shape (n,) of type `np.float32` or `np.float64`)
+                        Raw audio at the correct sampling rate (no further check will be done).
+                    - `dict` form can be used to pass raw audio sampled at an arbitrary `sampling_rate` and let this
+                      pipeline do the resampling. The dict must be in the format `{"sampling_rate": int, "raw":
+                      np.array}` with optionally a `"stride": (left: int, right: int)` that can ask the pipeline to
+                      treat the first `left` samples and last `right` samples to be ignored in decoding (but used at
+                      inference to provide more context to the model). Only use `stride` with CTC models.
+
+        Return:
+            `Dict`: A dictionary with the following keys:
+                - **text** (`str`) -- The recognized text.
+                - **chunks** (*optional*, `List[Dict]`)
+                    When using `return_timestamps`, the `chunks` will become a list containing all the various text
+                    chunks identified by the model, *e.g.* `[{"text": "hi ", "timestamp": (0.5, 0.9)}, {"text":
+                    "there", "timestamp": (1.0, 1.5)}]`. The original full text can roughly be recovered by doing
+                    `"".join(chunk["text"] for chunk in output["chunks"])`.
+        """
+        inputs, diarizer_inputs = self.preprocess(inputs)
 
-        diarizer_inputs = torch.from_numpy(inputs).float().unsqueeze(0)
         diarization = self.diarization_pipeline(
-            {"waveform": diarizer_inputs, "sample_rate": sampling_rate},
+            {"waveform": diarizer_inputs, "sample_rate": self.sampling_rate},
             **kwargs,
         )
-        del diarizer_inputs
 
         segments = diarization.for_json()["content"]
 
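The new docstring lists the accepted input types; the `dict` form is resampled internally by the `preprocess` method added at the end of this diff. A hedged calling sketch, using an illustrative 8 kHz array (the audio contents and rate are made up for the example):

import numpy as np

from asr_diarizer import ASRDiarizationPipeline  # assumed module name

pipe = ASRDiarizationPipeline.from_pretrained()  # defaults: whisper-small + pyannote diarization

# Illustrative input: 5 seconds of silence sampled at 8 kHz. preprocess() will
# resample it to pipe.sampling_rate before diarization and transcription.
raw = np.zeros(8_000 * 5, dtype=np.float32)
out = pipe({"raw": raw, "sampling_rate": 8_000}, group_by_speaker=True)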
@@ -78,7 +106,7 @@ class ASRDiarizationPipeline:
         )
 
         asr_out = self.asr_pipeline(
-            {"array": inputs, "sampling_rate": sampling_rate},
+            {"array": inputs, "sampling_rate": self.sampling_rate},
             return_timestamps=True,
             **kwargs,
         )
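For context, with `return_timestamps=True` the transformers ASR pipeline returns chunk-level timestamps, which the (unchanged) alignment code below matches against the diarization segments. Roughly, with illustrative values only:

# Approximate shape of asr_out from the transformers ASR pipeline when
# return_timestamps=True is set (text and timestamps are illustrative):
asr_out = {
    "text": " Hi there. Hello!",
    "chunks": [
        {"text": " Hi there.", "timestamp": (0.0, 1.4)},
        {"text": " Hello!", "timestamp": (1.6, 2.3)},
    ],
}
end_timestamps = [chunk["timestamp"][1] for chunk in asr_out["chunks"]]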
@@ -110,3 +138,46 @@ class ASRDiarizationPipeline:
             end_timestamps = end_timestamps[upto_idx + 1 :]
 
         return segmented_preds
+
+    def preprocess(self, inputs):
+        if isinstance(inputs, str):
+            if inputs.startswith("http://") or inputs.startswith("https://"):
+                # We need to actually check for a real protocol, otherwise it's impossible to use a local file
+                # like http_huggingface_co.png
+                inputs = requests.get(inputs).content
+            else:
+                with open(inputs, "rb") as f:
+                    inputs = f.read()
+
+        if isinstance(inputs, bytes):
+            inputs = ffmpeg_read(inputs, self.sampling_rate)
+
+        if isinstance(inputs, dict):
+            # Accepting `"array"` which is the key defined in `datasets` for better integration
+            if not ("sampling_rate" in inputs and ("raw" in inputs or "array" in inputs)):
+                raise ValueError(
+                    "When passing a dictionary to ASRDiarizePipeline, the dict needs to contain a "
+                    '"raw" key containing the numpy array representing the audio and a "sampling_rate" key, '
+                    "containing the sampling_rate associated with that array"
+                )
+
+            _inputs = inputs.pop("raw", None)
+            if _inputs is None:
+                # Remove path which will not be used from `datasets`.
+                inputs.pop("path", None)
+                _inputs = inputs.pop("array", None)
+            in_sampling_rate = inputs.pop("sampling_rate")
+            inputs = _inputs
+            if in_sampling_rate != self.sampling_rate:
+                inputs = F.resample(
+                    torch.from_numpy(inputs), in_sampling_rate, self.sampling_rate
+                ).numpy()
+
+        if not isinstance(inputs, np.ndarray):
+            raise ValueError(f"We expect a numpy ndarray as input, got `{type(inputs)}`")
+        if len(inputs.shape) != 1:
+            raise ValueError("We expect a single channel audio input for ASRDiarizePipeline")
+
+        diarizer_inputs = torch.from_numpy(inputs).float().unsqueeze(0)
+
+        return inputs, diarizer_inputs
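The new `preprocess` method centralizes loading (URL, path, bytes, dict or raw array), resampling, and shaping for pyannote. A rough standalone equivalent for the local-file case, assuming ffmpeg is installed and a 16 kHz target rate (Whisper's default); the function name is illustrative:

import torch
from transformers.pipelines.audio_utils import ffmpeg_read

def load_for_diarization(path: str, sampling_rate: int = 16_000):
    # Decode the file straight to the target rate, as preprocess() does for bytes input.
    with open(path, "rb") as f:
        inputs = ffmpeg_read(f.read(), sampling_rate)  # mono float32 array of shape (n,)
    # pyannote expects a (channels, samples) tensor; the ASR pipeline keeps the numpy array.
    diarizer_inputs = torch.from_numpy(inputs).float().unsqueeze(0)
    return inputs, diarizer_inputs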