Commit faed9d7 (verified) · matsuap committed · 1 parent: 16d1477

Upload 12 files
.gitattributes CHANGED
@@ -33,3 +33,4 @@ saved_model/**/* filter=lfs diff=lfs merge=lfs -text
  *.zip filter=lfs diff=lfs merge=lfs -text
  *.zst filter=lfs diff=lfs merge=lfs -text
  *tfevents* filter=lfs diff=lfs merge=lfs -text
+ src/web/demo.png filter=lfs diff=lfs merge=lfs -text
Dockerfile ADDED
@@ -0,0 +1,15 @@
+ # Base image
+ FROM python:3.11-slim
+
+ # Set the working directory
+ WORKDIR /app
+
+ # Install the required packages
+ COPY requirements.txt .
+ RUN pip install --no-cache-dir -r requirements.txt
+
+ # Copy the application source code
+ COPY . .
+
+ # Command to start the server
+ CMD ["python", "./whisper_fastapi_online_server.py", "--host", "localhost", "--port", "8000", "--lan", "en", "--model", "tiny"]
requirements.txt ADDED
@@ -0,0 +1,5 @@
+ librosa
+ soundfile
+ fastapi
+ ffmpeg-python
+ faster-whisper
src/__init__.py ADDED
File without changes
src/diarization/diarization_online.py ADDED
@@ -0,0 +1,110 @@
1
+ from diart import SpeakerDiarization
2
+ from diart.inference import StreamingInference
3
+ from diart.sources import AudioSource
4
+ from rx.subject import Subject
5
+ import threading
6
+ import numpy as np
7
+ import asyncio
8
+
9
+ class WebSocketAudioSource(AudioSource):
10
+ """
11
+ Simple custom AudioSource that blocks in read()
12
+ until close() is called.
13
+ push_audio() is used to inject new PCM chunks.
14
+ """
15
+ def __init__(self, uri: str = "websocket", sample_rate: int = 16000):
16
+ super().__init__(uri, sample_rate)
17
+ self._close_event = threading.Event()
18
+ self._closed = False
19
+
20
+ def read(self):
21
+ self._close_event.wait()
22
+
23
+ def close(self):
24
+ if not self._closed:
25
+ self._closed = True
26
+ self.stream.on_completed()
27
+ self._close_event.set()
28
+
29
+ def push_audio(self, chunk: np.ndarray):
30
+ chunk = np.expand_dims(chunk, axis=0)
31
+ if not self._closed:
32
+ self.stream.on_next(chunk)
33
+
34
+
35
+ def create_pipeline(SAMPLE_RATE):
36
+ diar_pipeline = SpeakerDiarization()
37
+ ws_source = WebSocketAudioSource(uri="websocket_source", sample_rate=SAMPLE_RATE)
38
+ inference = StreamingInference(
39
+ pipeline=diar_pipeline,
40
+ source=ws_source,
41
+ do_plot=False,
42
+ show_progress=False,
43
+ )
44
+ return inference, ws_source
45
+
46
+
47
+ def init_diart(SAMPLE_RATE):
48
+ inference, ws_source = create_pipeline(SAMPLE_RATE)
49
+
50
+ def diar_hook(result):
51
+ """
52
+ Hook called each time Diart processes a chunk.
53
+ result is (annotation, audio).
54
+ Each speaker segment found is pushed onto l_speakers_queue.
55
+ """
56
+ global l_speakers
57
+ l_speakers = []
58
+ annotation, audio = result
59
+ for speaker in annotation._labels:
60
+ segments_beg = annotation._labels[speaker].segments_boundaries_[0]
61
+ segments_end = annotation._labels[speaker].segments_boundaries_[-1]
62
+ asyncio.create_task(
63
+ l_speakers_queue.put({"speaker": speaker, "beg": segments_beg, "end": segments_end})
64
+ )
65
+
66
+ l_speakers_queue = asyncio.Queue()
67
+ inference.attach_hooks(diar_hook)
68
+
69
+ # Launch Diart in a background thread
70
+ loop = asyncio.get_event_loop()
71
+ diar_future = loop.run_in_executor(None, inference)
72
+ return inference, l_speakers_queue, ws_source
73
+
74
+
75
+ class DiartDiarization():
76
+ def __init__(self, SAMPLE_RATE):
77
+ self.inference, self.l_speakers_queue, self.ws_source = init_diart(SAMPLE_RATE)
78
+ self.segment_speakers = []
79
+
80
+ async def diarize(self, pcm_array):
81
+ self.ws_source.push_audio(pcm_array)
82
+ self.segment_speakers = []
83
+ while not self.l_speakers_queue.empty():
84
+ self.segment_speakers.append(await self.l_speakers_queue.get())
85
+
86
+ def close(self):
87
+ self.ws_source.close()
88
+
89
+
90
+ def assign_speakers_to_chunks(self, chunks):
91
+ """
92
+ Go through each chunk and see which speaker(s) overlap
93
+ that chunk's time range in the Diart annotation.
94
+ Then store the speaker label(s) (or choose the most overlapping).
95
+ This modifies `chunks` in place and returns the same list with speakers assigned.
96
+ """
97
+ if not self.segment_speakers:
98
+ return chunks
99
+
100
+ for segment in self.segment_speakers:
101
+ seg_beg = segment["beg"]
102
+ seg_end = segment["end"]
103
+ speaker = segment["speaker"]
104
+ for ch in chunks:
105
+ if seg_end <= ch["beg"] or seg_beg >= ch["end"]:
106
+ continue
107
+ # We have overlap. Let's just pick the speaker (could be more precise in a more complex implementation)
108
+ ch["speaker"] = speaker
109
+
110
+ return chunks
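
A minimal driver sketch for DiartDiarization above, assuming 16 kHz float32 PCM and that diart and its models are installed; the chunk dicts and the demo() wrapper are illustrative only and not part of this upload:

import asyncio
import numpy as np

async def demo():
    # Construct inside a running event loop so init_diart can attach to it.
    diarizer = DiartDiarization(SAMPLE_RATE=16000)
    pcm = np.zeros(16000, dtype=np.float32)             # one second of silence as a stand-in
    await diarizer.diarize(pcm)                         # push audio and drain the speaker queue
    chunks = [{"beg": 0.0, "end": 1.0, "speaker": -1}]  # hypothetical transcript chunks
    print(diarizer.assign_speakers_to_chunks(chunks))
    diarizer.close()

# asyncio.run(demo())  # left commented: it downloads and loads the diarization models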
src/web/demo.png ADDED

Git LFS Details

  • SHA256: 1a2f01e0f79b60d0ed0c398726d04f6f99ad81b314e6f1885292f2d6c6c49b3c
  • Pointer size: 131 Bytes
  • Size of remote file: 178 kB
src/web/live_transcription.html ADDED
@@ -0,0 +1,269 @@
1
+ <!DOCTYPE html>
2
+ <html lang="en">
3
+ <head>
4
+ <meta charset="UTF-8"/>
5
+ <meta name="viewport" content="width=device-width, initial-scale=1.0"/>
6
+ <title>Audio Transcription</title>
7
+ <style>
8
+ body {
9
+ font-family: 'Inter', sans-serif;
10
+ margin: 20px;
11
+ text-align: center;
12
+ }
13
+ #recordButton {
14
+ width: 80px;
15
+ height: 80px;
16
+ font-size: 36px;
17
+ border: none;
18
+ border-radius: 50%;
19
+ background-color: white;
20
+ cursor: pointer;
21
+ box-shadow: 0 0px 10px rgba(0, 0, 0, 0.2);
22
+ transition: background-color 0.3s ease, transform 0.2s ease;
23
+ }
24
+ #recordButton.recording {
25
+ background-color: #ff4d4d;
26
+ color: white;
27
+ }
28
+ #recordButton:active {
29
+ transform: scale(0.95);
30
+ }
31
+ #status {
32
+ margin-top: 20px;
33
+ font-size: 16px;
34
+ color: #333;
35
+ }
36
+ .settings-container {
37
+ display: flex;
38
+ justify-content: center;
39
+ align-items: center;
40
+ gap: 15px;
41
+ margin-top: 20px;
42
+ }
43
+ .settings {
44
+ display: flex;
45
+ flex-direction: column;
46
+ align-items: flex-start;
47
+ gap: 5px;
48
+ }
49
+ #chunkSelector,
50
+ #websocketInput {
51
+ font-size: 16px;
52
+ padding: 5px;
53
+ border-radius: 5px;
54
+ border: 1px solid #ddd;
55
+ background-color: #f9f9f9;
56
+ }
57
+ #websocketInput {
58
+ width: 200px;
59
+ }
60
+ #chunkSelector:focus,
61
+ #websocketInput:focus {
62
+ outline: none;
63
+ border-color: #007bff;
64
+ }
65
+ label {
66
+ font-size: 14px;
67
+ }
68
+ /* Speaker-labeled transcript area */
69
+ #linesTranscript {
70
+ margin: 20px auto;
71
+ max-width: 600px;
72
+ text-align: left;
73
+ font-size: 16px;
74
+ }
75
+ #linesTranscript p {
76
+ margin: 5px 0;
77
+ }
78
+ #linesTranscript strong {
79
+ color: #333;
80
+ }
81
+ /* Grey buffer styling */
82
+ .buffer {
83
+ color: rgb(180, 180, 180);
84
+ font-style: italic;
85
+ margin-left: 4px;
86
+ }
87
+ </style>
88
+ </head>
89
+ <body>
90
+
91
+ <div class="settings-container">
92
+ <button id="recordButton">🎙️</button>
93
+ <div class="settings">
94
+ <div>
95
+ <label for="chunkSelector">Chunk size (ms):</label>
96
+ <select id="chunkSelector">
97
+ <option value="500" selected>500 ms</option>
98
+ <option value="1000">1000 ms</option>
99
+ <option value="2000">2000 ms</option>
100
+ <option value="3000">3000 ms</option>
101
+ <option value="4000">4000 ms</option>
102
+ <option value="5000">5000 ms</option>
103
+ </select>
104
+ </div>
105
+ <div>
106
+ <label for="websocketInput">WebSocket URL:</label>
107
+ <input id="websocketInput" type="text" value="ws://localhost:8000/asr" />
108
+ </div>
109
+ </div>
110
+ </div>
111
+
112
+ <p id="status"></p>
113
+
114
+ <!-- Speaker-labeled transcript -->
115
+ <div id="linesTranscript"></div>
116
+
117
+ <script>
118
+ let isRecording = false;
119
+ let websocket = null;
120
+ let recorder = null;
121
+ let chunkDuration = 500;
122
+ let websocketUrl = "ws://localhost:8000/asr";
123
+ let userClosing = false;
124
+
125
+ const statusText = document.getElementById("status");
126
+ const recordButton = document.getElementById("recordButton");
127
+ const chunkSelector = document.getElementById("chunkSelector");
128
+ const websocketInput = document.getElementById("websocketInput");
129
+ const linesTranscriptDiv = document.getElementById("linesTranscript");
130
+
131
+ chunkSelector.addEventListener("change", () => {
132
+ chunkDuration = parseInt(chunkSelector.value);
133
+ });
134
+
135
+ websocketInput.addEventListener("change", () => {
136
+ const urlValue = websocketInput.value.trim();
137
+ if (!urlValue.startsWith("ws://") && !urlValue.startsWith("wss://")) {
138
+ statusText.textContent = "Invalid WebSocket URL (must start with ws:// or wss://)";
139
+ return;
140
+ }
141
+ websocketUrl = urlValue;
142
+ statusText.textContent = "WebSocket URL updated. Ready to connect.";
143
+ });
144
+
145
+ function setupWebSocket() {
146
+ return new Promise((resolve, reject) => {
147
+ try {
148
+ websocket = new WebSocket(websocketUrl);
149
+ } catch (error) {
150
+ statusText.textContent = "Invalid WebSocket URL. Please check and try again.";
151
+ reject(error);
152
+ return;
153
+ }
154
+
155
+ websocket.onopen = () => {
156
+ statusText.textContent = "Connected to server.";
157
+ resolve();
158
+ };
159
+
160
+ websocket.onclose = () => {
161
+ if (userClosing) {
162
+ statusText.textContent = "WebSocket closed by user.";
163
+ } else {
164
+ statusText.textContent =
165
+ "Disconnected from the WebSocket server. (Check logs if model is loading.)";
166
+ }
167
+ userClosing = false;
168
+ };
169
+
170
+ websocket.onerror = () => {
171
+ statusText.textContent = "Error connecting to WebSocket.";
172
+ reject(new Error("Error connecting to WebSocket"));
173
+ };
174
+
175
+ // Handle messages from server
176
+ websocket.onmessage = (event) => {
177
+ const data = JSON.parse(event.data);
178
+ /*
179
+ The server might send:
180
+ {
181
+ "lines": [
182
+ {"speaker": 0, "text": "Hello."},
183
+ {"speaker": 1, "text": "Bonjour."},
184
+ ...
185
+ ],
186
+ "buffer": "..."
187
+ }
188
+ */
189
+ const { lines = [], buffer = "" } = data;
190
+ renderLinesWithBuffer(lines, buffer);
191
+ };
192
+ });
193
+ }
194
+
195
+ function renderLinesWithBuffer(lines, buffer) {
196
+ // Clears if no lines
197
+ if (!Array.isArray(lines) || lines.length === 0) {
198
+ linesTranscriptDiv.innerHTML = "";
199
+ return;
200
+ }
201
+ // Build the HTML
202
+ // The buffer is appended to the last line if it's non-empty
203
+ const linesHtml = lines.map((item, idx) => {
204
+ let textContent = item.text;
205
+ if (idx === lines.length - 1 && buffer) {
206
+ textContent += `<span class="buffer">${buffer}</span>`;
207
+ }
208
+ return `<p><strong>Speaker ${item.speaker}:</strong> ${textContent}</p>`;
209
+ }).join("");
210
+
211
+ linesTranscriptDiv.innerHTML = linesHtml;
212
+ }
213
+
214
+ async function startRecording() {
215
+ try {
216
+ const stream = await navigator.mediaDevices.getUserMedia({ audio: true });
217
+ recorder = new MediaRecorder(stream, { mimeType: "audio/webm" });
218
+ recorder.ondataavailable = (e) => {
219
+ if (websocket && websocket.readyState === WebSocket.OPEN) {
220
+ websocket.send(e.data);
221
+ }
222
+ };
223
+ recorder.start(chunkDuration);
224
+ isRecording = true;
225
+ updateUI();
226
+ } catch (err) {
227
+ statusText.textContent = "Error accessing microphone. Please allow microphone access.";
228
+ }
229
+ }
230
+
231
+ function stopRecording() {
232
+ userClosing = true;
233
+ if (recorder) {
234
+ recorder.stop();
235
+ recorder = null;
236
+ }
237
+ isRecording = false;
238
+
239
+ if (websocket) {
240
+ websocket.close();
241
+ websocket = null;
242
+ }
243
+
244
+ updateUI();
245
+ }
246
+
247
+ async function toggleRecording() {
248
+ if (!isRecording) {
249
+ linesTranscriptDiv.innerHTML = "";
250
+ try {
251
+ await setupWebSocket();
252
+ await startRecording();
253
+ } catch (err) {
254
+ statusText.textContent = "Could not connect to WebSocket or access mic. Aborted.";
255
+ }
256
+ } else {
257
+ stopRecording();
258
+ }
259
+ }
260
+
261
+ function updateUI() {
262
+ recordButton.classList.toggle("recording", isRecording);
263
+ statusText.textContent = isRecording ? "Recording..." : "Click to start transcription";
264
+ }
265
+
266
+ recordButton.addEventListener("click", toggleRecording);
267
+ </script>
268
+ </body>
269
+ </html>
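
The page expects the server to push JSON of the shape documented in the script above ({"lines": [...], "buffer": "..."}). The real endpoint lives in whisper_fastapi_online_server.py (referenced by the Dockerfile but not shown in this commit); below is only a minimal sketch of a FastAPI WebSocket handler that satisfies that message format:

from fastapi import FastAPI, WebSocket

app = FastAPI()

@app.websocket("/asr")
async def asr(websocket: WebSocket):
    await websocket.accept()
    while True:
        webm_chunk = await websocket.receive_bytes()  # WebM/Opus bytes from MediaRecorder
        # ... decode the audio and run the online ASR / diarization pipeline here ...
        await websocket.send_json({
            "lines": [{"speaker": 0, "text": "placeholder transcript"}],
            "buffer": "",
        })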
src/whisper/timestaped_words.py ADDED
@@ -0,0 +1,108 @@
1
+ from typing import List
2
+
3
+ class TimeStampedSegment:
4
+ """
5
+ Represents a segment of text with start and end timestamps.
6
+
7
+ Attributes:
8
+ start (float): The start time of the segment.
9
+ end (float): The end time of the segment.
10
+ text (str): The text of the segment.
11
+ """
12
+ def __init__(self, start: float, end: float, text: str):
13
+ self.start = start
14
+ self.end = end
15
+ self.text = text
16
+
17
+ def __str__(self):
18
+ return f'{self.start} - {self.end}: {self.text}'
19
+
20
+ def __repr__(self):
21
+ return self.__str__()
22
+
23
+ def shift(self, shift: float):
24
+ """
25
+ Shifts the segment by a given amount of time.
26
+
27
+ Args:
28
+ shift (float): The amount of time to shift the segment.
29
+
30
+ Returns:
31
+ TimeStampedSegment: A new segment shifted by the given amount of time.
32
+
33
+ Example:
34
+ >>> segment = TimeStampedSegment(0.0, 1.0, "Hello")
35
+ >>> segment.shift(1.0)
36
+ 1.0 - 2.0: Hello
37
+ """
38
+ return TimeStampedSegment(self.start + shift, self.end + shift, self.text)
39
+
40
+ def append_text(self, text: str):
41
+ """
42
+ Appends text to the segment.
43
+
44
+ Args:
45
+ text (str): The text to append.
46
+
47
+ Example:
48
+ >>> segment = TimeStampedSegment(0.0, 1.0, "Hello")
49
+ >>> segment.append_text("!")
50
+ >>> segment
51
+ 0.0 - 1.0: Hello!
52
+ """
53
+ self.text += text
54
+
55
+ def __eq__(self, other):
56
+ return self.start == other.start and self.end == other.end and self.text == other.text
57
+
58
+ def __add__(self, other):
59
+ if isinstance(other, (int, float)):
60
+ return self.shift(other)
61
+ elif isinstance(other, str):
62
+ return TimeStampedSegment(self.start, self.end, self.text + other)
63
+ else:
64
+ raise TypeError(f"unsupported operand type(s) for +: '{type(self).__name__}' and '{type(other).__name__}'")
65
+
66
+ class TimeStampedText:
67
+ """
68
+ Represents a collection of TimeStampedSegment instances.
69
+
70
+ Attributes:
71
+ segments (List[TimeStampedSegment]): The list of segments.
72
+ """
73
+ def __init__(self):
74
+ self.segments: List[TimeStampedSegment] = []
75
+
76
+ def add_segment(self, segment: TimeStampedSegment):
77
+ """
78
+ Adds a segment to the collection.
79
+
80
+ Args:
81
+ segment (TimeStampedSegment): The segment to add.
82
+
83
+ Example:
84
+ >>> tst = TimeStampedText()
85
+ >>> tst.add_segment(TimeStampedSegment(0.0, 1.0, "Hello"))
86
+ >>> tst.add_segment(TimeStampedSegment(1.0, 2.0, "world"))
87
+ >>> len(tst)
88
+ 2
89
+ """
90
+ self.segments.append(segment)
91
+
92
+ def __repr__(self):
93
+ return f"TimeStampedText(segments={self.segments})"
94
+
95
+ def __iter__(self):
96
+ return iter(self.segments)
97
+
98
+ def __getitem__(self, index):
99
+ return self.segments[index]
100
+
101
+ def __len__(self):
102
+ return len(self.segments)
103
+
104
+ # TODO: a function from_whisper_res()
105
+
106
+ if __name__ == "__main__":
107
+ import doctest
108
+ doctest.testmod(verbose=True)
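
For the from_whisper_res() TODO above, a minimal sketch assuming the input is the [(beg, end, "word"), ...] list produced by the ts_words() helpers in backends.py; the name from_ts_words is illustrative:

def from_ts_words(ts_words) -> TimeStampedText:
    # Build a TimeStampedText from (beg, end, text) tuples.
    tst = TimeStampedText()
    for beg, end, text in ts_words:
        tst.add_segment(TimeStampedSegment(beg, end, text))
    return tst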
src/whisper_streaming/backends.py ADDED
@@ -0,0 +1,368 @@
1
+ import sys
2
+ import logging
3
+
4
+ import io
5
+ import soundfile as sf
6
+ import math
7
+
8
+
9
+ logger = logging.getLogger(__name__)
10
+
11
+ class ASRBase:
12
+ sep = " " # join transcribe words with this character (" " for whisper_timestamped,
13
+ # "" for faster-whisper because it emits the spaces when needed)
14
+
15
+ def __init__(
16
+ self, lan, modelsize=None, cache_dir=None, model_dir=None, logfile=sys.stderr
17
+ ):
18
+ self.logfile = logfile
19
+
20
+ self.transcribe_kargs = {}
21
+ if lan == "auto":
22
+ self.original_language = None
23
+ else:
24
+ self.original_language = lan
25
+
26
+ self.model = self.load_model(modelsize, cache_dir, model_dir)
27
+
28
+ def load_model(self, modelsize, cache_dir):
29
+ raise NotImplementedError("must be implemented in the child class")
30
+
31
+ def transcribe(self, audio, init_prompt=""):
32
+ raise NotImplementedError("must be implemented in the child class")
33
+
34
+ def use_vad(self):
35
+ raise NotImplementedError("must be implemented in the child class")
36
+
37
+
38
+ class WhisperTimestampedASR(ASRBase):
39
+ """Uses the whisper_timestamped library as the backend. Initially, we tested the code on this backend. It worked, but was slower than faster-whisper.
40
+ On the other hand, the installation for GPU could be easier.
41
+ """
42
+
43
+ sep = " "
44
+
45
+ def load_model(self, modelsize=None, cache_dir=None, model_dir=None):
46
+ import whisper
47
+ import whisper_timestamped
48
+ from whisper_timestamped import transcribe_timestamped
49
+
50
+ self.transcribe_timestamped = transcribe_timestamped
51
+ if model_dir is not None:
52
+ logger.debug("ignoring model_dir, not implemented")
53
+ return whisper.load_model(modelsize, download_root=cache_dir)
54
+
55
+ def transcribe(self, audio, init_prompt=""):
56
+ result = self.transcribe_timestamped(
57
+ self.model,
58
+ audio,
59
+ language=self.original_language,
60
+ initial_prompt=init_prompt,
61
+ verbose=None,
62
+ condition_on_previous_text=True,
63
+ **self.transcribe_kargs,
64
+ )
65
+ return result
66
+
67
+ def ts_words(self, r):
68
+ # return: transcribe result object to [(beg,end,"word1"), ...]
69
+ o = []
70
+ for s in r["segments"]:
71
+ for w in s["words"]:
72
+ t = (w["start"], w["end"], w["text"])
73
+ o.append(t)
74
+ return o
75
+
76
+ def segments_end_ts(self, res):
77
+ return [s["end"] for s in res["segments"]]
78
+
79
+ def use_vad(self):
80
+ self.transcribe_kargs["vad"] = True
81
+
82
+ def set_translate_task(self):
83
+ self.transcribe_kargs["task"] = "translate"
84
+
85
+
86
+ class FasterWhisperASR(ASRBase):
87
+ """Uses the faster-whisper library as the backend. Works much faster, approximately 4x (in offline mode). For GPU, it requires an installation with a specific cuDNN version."""
88
+
89
+ sep = ""
90
+
91
+ def load_model(self, modelsize=None, cache_dir=None, model_dir=None):
92
+ from faster_whisper import WhisperModel
93
+
94
+ # logging.getLogger("faster_whisper").setLevel(logger.level)
95
+ if model_dir is not None:
96
+ logger.debug(
97
+ f"Loading whisper model from model_dir {model_dir}. modelsize and cache_dir parameters are not used."
98
+ )
99
+ model_size_or_path = model_dir
100
+ elif modelsize is not None:
101
+ model_size_or_path = modelsize
102
+ else:
103
+ raise ValueError("modelsize or model_dir parameter must be set")
104
+
105
+ # this worked fast and reliably on NVIDIA L40
106
+ model = WhisperModel(
107
+ model_size_or_path,
108
+ device="cuda",
109
+ compute_type="float16",
110
+ download_root=cache_dir,
111
+ )
112
+
113
+ # or run on GPU with INT8
114
+ # tested: the transcripts were different, probably worse than with FP16, and it was slightly (appx 20%) slower
115
+ # model = WhisperModel(model_size, device="cuda", compute_type="int8_float16")
116
+
117
+ # or run on CPU with INT8
118
+ # tested: works, but slow, appx 10-times than cuda FP16
119
+ # model = WhisperModel(modelsize, device="cpu", compute_type="int8") #, download_root="faster-disk-cache-dir/")
120
+ return model
121
+
122
+ def transcribe(self, audio, init_prompt=""):
123
+
124
+ # tested: beam_size=5 is faster and better than 1 (on one 200 second document from En ESIC, min chunk 0.01)
125
+ segments, info = self.model.transcribe(
126
+ audio,
127
+ language=self.original_language,
128
+ initial_prompt=init_prompt,
129
+ beam_size=5,
130
+ word_timestamps=True,
131
+ condition_on_previous_text=True,
132
+ **self.transcribe_kargs,
133
+ )
134
+ # print(info) # info contains language detection result
135
+
136
+ return list(segments)
137
+
138
+ def ts_words(self, segments):
139
+ o = []
140
+ for segment in segments:
141
+ for word in segment.words:
142
+ if segment.no_speech_prob > 0.9:
143
+ continue
144
+ # not stripping the spaces -- should not be merged with them!
145
+ w = word.word
146
+ t = (word.start, word.end, w)
147
+ o.append(t)
148
+ return o
149
+
150
+ def segments_end_ts(self, res):
151
+ return [s.end for s in res]
152
+
153
+ def use_vad(self):
154
+ self.transcribe_kargs["vad_filter"] = True
155
+
156
+ def set_translate_task(self):
157
+ self.transcribe_kargs["task"] = "translate"
158
+
159
+
160
+ class MLXWhisper(ASRBase):
161
+ """
162
+ Uses the MLX Whisper library as the backend, optimized for Apple Silicon.
163
+ Models available: https://huggingface.co/collections/mlx-community/whisper-663256f9964fbb1177db93dc
164
+ Significantly faster than faster-whisper (without CUDA) on Apple M1.
165
+ """
166
+
167
+ sep = ""  # In my experience, French also needs no separator space.
168
+
169
+ def load_model(self, modelsize=None, cache_dir=None, model_dir=None):
170
+ """
171
+ Loads the MLX-compatible Whisper model.
172
+
173
+ Args:
174
+ modelsize (str, optional): The size or name of the Whisper model to load.
175
+ If provided, it will be translated to an MLX-compatible model path using the `translate_model_name` method.
176
+ Example: "large-v3-turbo" -> "mlx-community/whisper-large-v3-turbo".
177
+ cache_dir (str, optional): Path to the directory for caching models.
178
+ **Note**: This is not supported by MLX Whisper and will be ignored.
179
+ model_dir (str, optional): Direct path to a custom model directory.
180
+ If specified, it overrides the `modelsize` parameter.
181
+ """
182
+ from mlx_whisper.transcribe import ModelHolder, transcribe
183
+ import mlx.core as mx
184
+
185
+ if model_dir is not None:
186
+ logger.debug(
187
+ f"Loading whisper model from model_dir {model_dir}. modelsize parameter is not used."
188
+ )
189
+ model_size_or_path = model_dir
190
+ elif modelsize is not None:
191
+ model_size_or_path = self.translate_model_name(modelsize)
192
+ logger.debug(
193
+ f"Loading whisper model {modelsize}. You use mlx whisper, so {model_size_or_path} will be used."
194
+ )
195
+
196
+ self.model_size_or_path = model_size_or_path
197
+
198
+ # In mlx_whisper.transcribe, dtype is defined as:
199
+ # dtype = mx.float16 if decode_options.get("fp16", True) else mx.float32
200
+ # Since we do not use decode_options in self.transcribe, we will set dtype to mx.float16
201
+ dtype = mx.float16
202
+ ModelHolder.get_model(model_size_or_path, dtype)
203
+ return transcribe
204
+
205
+ def translate_model_name(self, model_name):
206
+ """
207
+ Translates a given model name to its corresponding MLX-compatible model path.
208
+
209
+ Args:
210
+ model_name (str): The name of the model to translate.
211
+
212
+ Returns:
213
+ str: The MLX-compatible model path.
214
+ """
215
+ # Dictionary mapping model names to MLX-compatible paths
216
+ model_mapping = {
217
+ "tiny.en": "mlx-community/whisper-tiny.en-mlx",
218
+ "tiny": "mlx-community/whisper-tiny-mlx",
219
+ "base.en": "mlx-community/whisper-base.en-mlx",
220
+ "base": "mlx-community/whisper-base-mlx",
221
+ "small.en": "mlx-community/whisper-small.en-mlx",
222
+ "small": "mlx-community/whisper-small-mlx",
223
+ "medium.en": "mlx-community/whisper-medium.en-mlx",
224
+ "medium": "mlx-community/whisper-medium-mlx",
225
+ "large-v1": "mlx-community/whisper-large-v1-mlx",
226
+ "large-v2": "mlx-community/whisper-large-v2-mlx",
227
+ "large-v3": "mlx-community/whisper-large-v3-mlx",
228
+ "large-v3-turbo": "mlx-community/whisper-large-v3-turbo",
229
+ "large": "mlx-community/whisper-large-mlx",
230
+ }
231
+
232
+ # Retrieve the corresponding MLX model path
233
+ mlx_model_path = model_mapping.get(model_name)
234
+
235
+ if mlx_model_path:
236
+ return mlx_model_path
237
+ else:
238
+ raise ValueError(
239
+ f"Model name '{model_name}' is not recognized or not supported."
240
+ )
241
+
242
+ def transcribe(self, audio, init_prompt=""):
243
+ if self.transcribe_kargs:
244
+ logger.warning("Transcribe kwargs (vad, task) are not compatible with MLX Whisper and will be ignored.")
245
+ segments = self.model(
246
+ audio,
247
+ language=self.original_language,
248
+ initial_prompt=init_prompt,
249
+ word_timestamps=True,
250
+ condition_on_previous_text=True,
251
+ path_or_hf_repo=self.model_size_or_path,
252
+ )
253
+ return segments.get("segments", [])
254
+
255
+ def ts_words(self, segments):
256
+ """
257
+ Extracts timestamped words from transcription segments, skipping words with a high no-speech probability.
258
+ """
259
+ return [
260
+ (word["start"], word["end"], word["word"])
261
+ for segment in segments
262
+ for word in segment.get("words", [])
263
+ if segment.get("no_speech_prob", 0) <= 0.9
264
+ ]
265
+
266
+ def segments_end_ts(self, res):
267
+ return [s["end"] for s in res]
268
+
269
+ def use_vad(self):
270
+ self.transcribe_kargs["vad_filter"] = True
271
+
272
+ def set_translate_task(self):
273
+ self.transcribe_kargs["task"] = "translate"
274
+
275
+
276
+ class OpenaiApiASR(ASRBase):
277
+ """Uses OpenAI's Whisper API for audio transcription."""
278
+
279
+ def __init__(self, lan=None, temperature=0, logfile=sys.stderr):
280
+ self.logfile = logfile
281
+
282
+ self.modelname = "whisper-1"
283
+ self.original_language = (
284
+ None if lan == "auto" else lan
285
+ ) # ISO-639-1 language code
286
+ self.response_format = "verbose_json"
287
+ self.temperature = temperature
288
+
289
+ self.load_model()
290
+
291
+ self.use_vad_opt = False
292
+
293
+ # reset the task in set_translate_task
294
+ self.task = "transcribe"
295
+
296
+ def load_model(self, *args, **kwargs):
297
+ from openai import OpenAI
298
+
299
+ self.client = OpenAI()
300
+
301
+ self.transcribed_seconds = (
302
+ 0 # for logging how many seconds were processed by API, to know the cost
303
+ )
304
+
305
+ def ts_words(self, segments):
306
+ no_speech_segments = []
307
+ if self.use_vad_opt:
308
+ for segment in segments.segments:
309
+ # TODO: threshold can be set from outside
310
+ if segment["no_speech_prob"] > 0.8:
311
+ no_speech_segments.append(
312
+ (segment.get("start"), segment.get("end"))
313
+ )
314
+
315
+ o = []
316
+ for word in segments.words:
317
+ start = word.start
318
+ end = word.end
319
+ if any(s[0] <= start <= s[1] for s in no_speech_segments):
320
+ # print("Skipping word", word.get("word"), "because it's in a no-speech segment")
321
+ continue
322
+ o.append((start, end, word.word))
323
+ return o
324
+
325
+ def segments_end_ts(self, res):
326
+ return [s.end for s in res.words]
327
+
328
+ def transcribe(self, audio_data, prompt=None, *args, **kwargs):
329
+ # Write the audio data to a buffer
330
+ buffer = io.BytesIO()
331
+ buffer.name = "temp.wav"
332
+ sf.write(buffer, audio_data, samplerate=16000, format="WAV", subtype="PCM_16")
333
+ buffer.seek(0) # Reset buffer's position to the beginning
334
+
335
+ self.transcribed_seconds += math.ceil(
336
+ len(audio_data) / 16000
337
+ ) # it rounds up to the whole seconds
338
+
339
+ params = {
340
+ "model": self.modelname,
341
+ "file": buffer,
342
+ "response_format": self.response_format,
343
+ "temperature": self.temperature,
344
+ "timestamp_granularities": ["word", "segment"],
345
+ }
346
+ if self.task != "translate" and self.original_language:
347
+ params["language"] = self.original_language
348
+ if prompt:
349
+ params["prompt"] = prompt
350
+
351
+ if self.task == "translate":
352
+ proc = self.client.audio.translations
353
+ else:
354
+ proc = self.client.audio.transcriptions
355
+
356
+ # Process transcription/translation
357
+ transcript = proc.create(**params)
358
+ logger.debug(
359
+ f"OpenAI API processed accumulated {self.transcribed_seconds} seconds"
360
+ )
361
+
362
+ return transcript
363
+
364
+ def use_vad(self):
365
+ self.use_vad_opt = True
366
+
367
+ def set_translate_task(self):
368
+ self.task = "translate"
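
A short usage sketch of the backend interface, assuming faster-whisper is installed and a CUDA device is available (FasterWhisperASR above hard-codes device="cuda"); "sample.wav" is a placeholder path:

import librosa

asr = FasterWhisperASR(lan="en", modelsize="tiny")
asr.use_vad()  # optional: enables faster-whisper's built-in VAD filter

audio, sr = librosa.load("sample.wav", sr=16000)  # float32 mono at 16 kHz
res = asr.transcribe(audio)
for beg, end, word in asr.ts_words(res):
    print(f"{beg:.2f}-{end:.2f}\t{word}")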
src/whisper_streaming/online_asr.py ADDED
@@ -0,0 +1,513 @@
1
+ import sys
2
+ import numpy as np
3
+ import logging
4
+
5
+ logger = logging.getLogger(__name__)
6
+
7
+ class HypothesisBuffer:
8
+
9
+ def __init__(self, logfile=sys.stderr):
10
+ self.commited_in_buffer = []
11
+ self.buffer = []
12
+ self.new = []
13
+
14
+ self.last_commited_time = 0
15
+ self.last_commited_word = None
16
+
17
+ self.logfile = logfile
18
+
19
+ def insert(self, new, offset):
20
+ """
21
+ Compare self.commited_in_buffer and new. Insert only the words from new that extend commited_in_buffer, i.e. those that start roughly after last_commited_time and are new in content.
22
+ The new tail is added to self.new
23
+ """
24
+
25
+ new = [(a + offset, b + offset, t) for a, b, t in new]
26
+ self.new = [(a, b, t) for a, b, t in new if a > self.last_commited_time - 0.1]
27
+
28
+ if len(self.new) >= 1:
29
+ a, b, t = self.new[0]
30
+ if abs(a - self.last_commited_time) < 1:
31
+ if self.commited_in_buffer:
32
+ # it's going to search for 1, 2, ..., 5 consecutive words (n-grams) that are identical in commited and new. If they are, they're dropped.
33
+ cn = len(self.commited_in_buffer)
34
+ nn = len(self.new)
35
+ for i in range(1, min(min(cn, nn), 5) + 1): # 5 is the maximum
36
+ c = " ".join(
37
+ [self.commited_in_buffer[-j][2] for j in range(1, i + 1)][
38
+ ::-1
39
+ ]
40
+ )
41
+ tail = " ".join(self.new[j - 1][2] for j in range(1, i + 1))
42
+ if c == tail:
43
+ words = []
44
+ for j in range(i):
45
+ words.append(repr(self.new.pop(0)))
46
+ words_msg = " ".join(words)
47
+ logger.debug(f"removing last {i} words: {words_msg}")
48
+ break
49
+
50
+ def flush(self):
51
+ # returns commited chunk = the longest common prefix of 2 last inserts.
52
+
53
+ commit = []
54
+ while self.new:
55
+ na, nb, nt = self.new[0]
56
+
57
+ if len(self.buffer) == 0:
58
+ break
59
+
60
+ if nt == self.buffer[0][2]:
61
+ commit.append((na, nb, nt))
62
+ self.last_commited_word = nt
63
+ self.last_commited_time = nb
64
+ self.buffer.pop(0)
65
+ self.new.pop(0)
66
+ else:
67
+ break
68
+ self.buffer = self.new
69
+ self.new = []
70
+ self.commited_in_buffer.extend(commit)
71
+ return commit
72
+
73
+ def pop_commited(self, time):
74
+ "Remove (from the beginning) of commited_in_buffer all the words that are finished before `time`"
75
+ while self.commited_in_buffer and self.commited_in_buffer[0][1] <= time:
76
+ self.commited_in_buffer.pop(0)
77
+
78
+ def complete(self):
79
+ return self.buffer
80
+
81
+
82
+
83
+
84
+
85
+ class OnlineASRProcessor:
86
+
87
+ SAMPLING_RATE = 16000
88
+
89
+ def __init__(
90
+ self,
91
+ asr,
92
+ tokenize_method=None,
93
+ buffer_trimming=("segment", 15),
94
+ logfile=sys.stderr,
95
+ ):
96
+ """
97
+ Initialize OnlineASRProcessor.
98
+
99
+ Args:
100
+ asr: WhisperASR object
101
+ tokenize_method: Sentence tokenizer function for the target language.
102
+ Must be a function that takes a list of text as input like MosesSentenceSplitter.
103
+ Can be None if using "segment" buffer trimming option.
104
+ buffer_trimming: Tuple of (option, seconds) where:
105
+ - option: Either "sentence" or "segment"
106
+ - seconds: Number of seconds threshold for buffer trimming
107
+ Default is ("segment", 15)
108
+ logfile: File to store logs
109
+
110
+ """
111
+ self.asr = asr
112
+ self.tokenize = tokenize_method
113
+ self.logfile = logfile
114
+
115
+ self.init()
116
+
117
+ self.buffer_trimming_way, self.buffer_trimming_sec = buffer_trimming
118
+
119
+ if self.buffer_trimming_way not in ["sentence", "segment"]:
120
+ raise ValueError("buffer_trimming must be either 'sentence' or 'segment'")
121
+ if self.buffer_trimming_sec <= 0:
122
+ raise ValueError("buffer_trimming_sec must be positive")
123
+ elif self.buffer_trimming_sec > 30:
124
+ logger.warning(
125
+ f"buffer_trimming_sec is set to {self.buffer_trimming_sec}, which is very long. It may cause OOM."
126
+ )
127
+
128
+ def init(self, offset=None):
129
+ """run this when starting or restarting processing"""
130
+ self.audio_buffer = np.array([], dtype=np.float32)
131
+ self.transcript_buffer = HypothesisBuffer(logfile=self.logfile)
132
+ self.buffer_time_offset = 0
133
+ if offset is not None:
134
+ self.buffer_time_offset = offset
135
+ self.transcript_buffer.last_commited_time = self.buffer_time_offset
136
+ self.final_transcript = []
137
+ self.commited_not_final = []
138
+
139
+
140
+ def insert_audio_chunk(self, audio):
141
+ self.audio_buffer = np.append(self.audio_buffer, audio)
142
+
143
+ def prompt(self):
144
+ """Returns a tuple (prompt, context), where "prompt" is a 200-character suffix of committed text that lies inside the scrolled-away part of the audio buffer.
145
+ "context" is the committed text that is still inside the audio buffer. It is transcribed again and skipped; it is returned only for debugging and logging purposes.
146
+
147
+
148
+ """
149
+
150
+ if len(self.final_transcript) == 0:
151
+ prompt=""
152
+
153
+ if len(self.final_transcript) == 1:
154
+ prompt = self.final_transcript[0][2][-200:]
155
+
156
+ else:
157
+ prompt = self.concatenate_tsw(self.final_transcript)[2][-200:]
158
+ # TODO: this is not ideal as we concatenate each time the whole transcript
159
+
160
+ # k = max(0, len(self.final_transcript) - 1)
161
+ # while k > 1 and self.final_transcript[k - 1][1] > self.buffer_time_offset:
162
+ # k -= 1
163
+
164
+ # p = self.final_transcript[:k]
165
+
166
+
167
+ # p = [t for _, _, t in p]
168
+ # prompt = []
169
+ # l = 0
170
+ # while p and l < 200: # 200 characters prompt size
171
+ # x = p.pop(-1)
172
+ # l += len(x) + 1
173
+ # prompt.append(x)
174
+
175
+ non_prompt = self.concatenate_tsw(self.commited_not_final)[2]
176
+
177
+ logger.debug(f"PROMPT(previous): {prompt[:20]}…{prompt[-20:]} (length={len(prompt)}chars)")
178
+ logger.debug(f"CONTEXT: {non_prompt}")
179
+
180
+ return prompt, non_prompt
181
+
182
+
183
+ def process_iter(self):
184
+ """Runs on the current audio buffer.
185
+ Returns: a tuple (beg_timestamp, end_timestamp, "text"), or (None, None, "").
186
+ The non-empty text is the confirmed (committed) partial transcript.
187
+ """
188
+
189
+ prompt, non_prompt = self.prompt()
190
+
191
+ logger.debug(
192
+ f"transcribing {len(self.audio_buffer)/self.SAMPLING_RATE:2.2f} seconds from {self.buffer_time_offset:2.2f}"
193
+ )
194
+
195
+ ## Transcribe and format the result to [(beg,end,"word1"), ...]
196
+ res = self.asr.transcribe(self.audio_buffer, init_prompt=prompt)
197
+ tsw = self.asr.ts_words(res)
198
+
199
+
200
+ # insert into HypothesisBuffer, and get back the commited words
201
+ self.transcript_buffer.insert(tsw, self.buffer_time_offset)
202
+ commited_tsw = self.transcript_buffer.flush()
203
+
204
+ if len(commited_tsw) == 0:
205
+ return (None, None, "")
206
+
207
+
208
+ self.commited_not_final.extend(commited_tsw)
209
+
210
+
211
+ # Define `completed` and `the_rest` based on the buffer_trimming_way
212
+ # completed will be returned at the end of the function.
213
+ # completed is a transcribed text with (beg,end,"sentence ...") format.
214
+
215
+
216
+ completed = []
217
+ if self.buffer_trimming_way == "sentence":
218
+
219
+ sentences = self.words_to_sentences(self.commited_not_final)
220
+
221
+
222
+
223
+ if len(sentences) < 2:
224
+ logger.debug(f"[Sentence-segmentation] no full sentence segmented, do not commit anything.")
225
+
226
+
227
+
228
+
229
+ else:
230
+ identified_sentence= "\n - ".join([f"{s[0]*1000:.0f}-{s[1]*1000:.0f} {s[2]}" for s in sentences])
231
+ logger.debug(f"[Sentence-segmentation] identified sentences:\n - {identified_sentence}")
232
+
233
+ # assume last sentence is incomplete, which is not always true
234
+
235
+ # we will continue with audio processing at this timestamp
236
+ chunk_at = sentences[-2][1]
237
+
238
+ self.chunk_at(chunk_at)
239
+ # TODO: here paragraph breaks can be added
240
+ self.commited_not_final = sentences[-1:]
241
+
242
+ completed= sentences[:-1]
243
+
244
+
245
+
246
+
247
+
248
+ # break audio buffer anyway if it is too long
249
+
250
+ if len(self.audio_buffer) / self.SAMPLING_RATE > self.buffer_trimming_sec :
251
+
252
+ if self.buffer_trimming_way == "sentence":
253
+ logger.warning(f"Chunking segment after {self.buffer_trimming_sec} seconds,"
254
+ " even though no full sentence was found!"
255
+ )
256
+
257
+
258
+
259
+
260
+ completed = self.chunk_completed_segment()
261
+
262
+
263
+
264
+
265
+
266
+
267
+ if len(completed) == 0:
268
+ return (None, None, "")
269
+ else:
270
+ self.final_transcript.extend(completed) # add whole time stamped sentences / or words to commited list
271
+
272
+
273
+ completed_text_segment= self.concatenate_tsw(completed)
274
+
275
+ the_rest = self.concatenate_tsw(self.transcript_buffer.complete())
276
+ commited_but_not_final = self.concatenate_tsw(self.commited_not_final)
277
+ logger.debug(f"\n COMPLETE NOW: {completed_text_segment[2]}\n"
278
+ f" COMMITTED (but not Final): {commited_but_not_final[2]}\n"
279
+ f" INCOMPLETE: {the_rest[2]}"
280
+ )
281
+
282
+
283
+ return completed_text_segment
284
+
285
+
286
+ def chunk_completed_segment(self) -> list:
287
+
288
+
289
+ ts_words = self.commited_not_final
290
+
291
+ if len(ts_words) <= 1:
292
+ logger.debug(f"--- not enough segments to chunk (<=1 words)")
293
+ return []
294
+ else:
295
+
296
+ ends = [w[1] for w in ts_words]
297
+
298
+ t = ts_words[-1][1] # end of the last word
299
+ e = ends[-2]
300
+ while len(ends) > 2 and e > t:
301
+ ends.pop(-1)
302
+ e = ends[-2]
303
+
304
+ if e <= t:
305
+
306
+ self.chunk_at(e)
307
+
308
+ n_commited_words = len(ends)-1
309
+
310
+ words_to_commit = ts_words[:n_commited_words]
311
+ self.final_transcript.extend(words_to_commit)
312
+ self.commited_not_final = ts_words[n_commited_words:]
313
+
314
+ return words_to_commit
315
+
316
+
317
+
318
+ else:
319
+ logger.debug(f"--- last segment not within commited area")
320
+ return []
321
+
322
+
323
+ def chunk_at(self, time):
324
+ """trims the hypothesis and audio buffer at "time" """
325
+ logger.debug(f"chunking at {time:2.2f}s")
326
+
327
+ logger.debug(
328
+ f"len of audio buffer before chunking is: {len(self.audio_buffer)/self.SAMPLING_RATE:2.2f}s"
329
+ )
330
+
331
+
332
+ self.transcript_buffer.pop_commited(time)
333
+ cut_seconds = time - self.buffer_time_offset
334
+ self.audio_buffer = self.audio_buffer[int(cut_seconds * self.SAMPLING_RATE) :]
335
+ self.buffer_time_offset = time
336
+
337
+ logger.debug(
338
+ f"len of audio buffer is now: {len(self.audio_buffer)/self.SAMPLING_RATE:2.2f}s"
339
+ )
340
+
341
+ def words_to_sentences(self, words):
342
+ """Uses self.tokenize for sentence segmentation of words.
343
+ Returns: [(beg,end,"sentence 1"),...]
344
+ """
345
+
346
+
347
+ cwords = [w for w in words]
348
+ t = self.asr.sep.join(o[2] for o in cwords)
349
+ logger.debug(f"[Sentence-segmentation] Raw Text: {t}")
350
+
351
+ s = self.tokenize([t])
352
+ out = []
353
+ while s:
354
+ beg = None
355
+ end = None
356
+ sent = s.pop(0).strip()
357
+ fsent = sent
358
+ while cwords:
359
+ b, e, w = cwords.pop(0)
360
+ w = w.strip()
361
+ if beg is None and sent.startswith(w):
362
+ beg = b
363
+ if end is None and sent == w:
364
+ end = e
365
+ if beg is not None and end is not None:
366
+ out.append((beg, end, fsent))
367
+ break
368
+ sent = sent[len(w) :].strip()
369
+
370
+ return out
371
+
372
+ def finish(self):
373
+ """Flush the incomplete text when the whole processing ends.
374
+ Returns: the same format as self.process_iter()
375
+ """
376
+ o = self.transcript_buffer.complete()
377
+ f = self.concatenate_tsw(o)
378
+ if f[1] is not None:
379
+ logger.debug(f"last, noncommited: {f[0]*1000:.0f}-{f[1]*1000:.0f}: {f[2]}")
380
+ self.buffer_time_offset += len(self.audio_buffer) / 16000
381
+ return f
382
+
383
+ def concatenate_tsw(
384
+ self,
385
+ tsw,
386
+ sep=None,
387
+ offset=0,
388
+ ):
389
+ # concatenates the timestamped words or sentences into one sequence that is flushed in one line
390
+ # sents: [(beg1, end1, "sentence1"), ...] or [] if empty
391
+ # return: (beg1,end-of-last-sentence,"concatenation of sentences") or (None, None, "") if empty
392
+ if sep is None:
393
+ sep = self.asr.sep
394
+
395
+
396
+
397
+ t = sep.join(s[2] for s in tsw)
398
+ if len(tsw) == 0:
399
+ b = None
400
+ e = None
401
+ else:
402
+ b = offset + tsw[0][0]
403
+ e = offset + tsw[-1][1]
404
+ return (b, e, t)
405
+
406
+
407
+ class VACOnlineASRProcessor(OnlineASRProcessor):
408
+ """Wraps OnlineASRProcessor with VAC (Voice Activity Controller).
409
+
410
+ It works the same way as OnlineASRProcessor: it receives chunks of audio (e.g. 0.04 seconds),
411
+ it runs VAD and continuously detects whether there is speech or not.
412
+ When it detects the end of speech (non-voice for 500 ms), it makes OnlineASRProcessor end the utterance immediately.
413
+ """
414
+
415
+ # TODO: VACOnlineASRProcessor does not break after the chunk length is reached; this can lead to overflow!
416
+
417
+ def __init__(self, online_chunk_size, *a, **kw):
418
+ self.online_chunk_size = online_chunk_size
419
+
420
+ self.online = OnlineASRProcessor(*a, **kw)
421
+
422
+ # VAC:
423
+ import torch
424
+
425
+ model, _ = torch.hub.load(repo_or_dir="snakers4/silero-vad", model="silero_vad")
426
+ from src.whisper_streaming.silero_vad_iterator import FixedVADIterator
427
+
428
+ self.vac = FixedVADIterator(
429
+ model
430
+ ) # we use the default options there: 500ms silence, 100ms padding, etc.
431
+
432
+ self.logfile = self.online.logfile
433
+ self.init()
434
+
435
+ def init(self):
436
+ self.online.init()
437
+ self.vac.reset_states()
438
+ self.current_online_chunk_buffer_size = 0
439
+
440
+ self.is_currently_final = False
441
+
442
+ self.status = None # or "voice" or "nonvoice"
443
+ self.audio_buffer = np.array([], dtype=np.float32)
444
+ self.buffer_offset = 0 # in frames
445
+
446
+ def clear_buffer(self):
447
+ self.buffer_offset += len(self.audio_buffer)
448
+ self.audio_buffer = np.array([], dtype=np.float32)
449
+
450
+ def insert_audio_chunk(self, audio):
451
+ res = self.vac(audio)
452
+ self.audio_buffer = np.append(self.audio_buffer, audio)
453
+
454
+ if res is not None:
455
+ frame = list(res.values())[0] - self.buffer_offset
456
+ if "start" in res and "end" not in res:
457
+ self.status = "voice"
458
+ send_audio = self.audio_buffer[frame:]
459
+ self.online.init(
460
+ offset=(frame + self.buffer_offset) / self.SAMPLING_RATE
461
+ )
462
+ self.online.insert_audio_chunk(send_audio)
463
+ self.current_online_chunk_buffer_size += len(send_audio)
464
+ self.clear_buffer()
465
+ elif "end" in res and "start" not in res:
466
+ self.status = "nonvoice"
467
+ send_audio = self.audio_buffer[:frame]
468
+ self.online.insert_audio_chunk(send_audio)
469
+ self.current_online_chunk_buffer_size += len(send_audio)
470
+ self.is_currently_final = True
471
+ self.clear_buffer()
472
+ else:
473
+ beg = res["start"] - self.buffer_offset
474
+ end = res["end"] - self.buffer_offset
475
+ self.status = "nonvoice"
476
+ send_audio = self.audio_buffer[beg:end]
477
+ self.online.init(offset=(beg + self.buffer_offset) / self.SAMPLING_RATE)
478
+ self.online.insert_audio_chunk(send_audio)
479
+ self.current_online_chunk_buffer_size += len(send_audio)
480
+ self.is_currently_final = True
481
+ self.clear_buffer()
482
+ else:
483
+ if self.status == "voice":
484
+ self.online.insert_audio_chunk(self.audio_buffer)
485
+ self.current_online_chunk_buffer_size += len(self.audio_buffer)
486
+ self.clear_buffer()
487
+ else:
488
+ # We keep 1 second because VAD may later find start of voice in it.
489
+ # But we trim it to prevent OOM.
490
+ self.buffer_offset += max(
491
+ 0, len(self.audio_buffer) - self.SAMPLING_RATE
492
+ )
493
+ self.audio_buffer = self.audio_buffer[-self.SAMPLING_RATE :]
494
+
495
+ def process_iter(self):
496
+ if self.is_currently_final:
497
+ return self.finish()
498
+ elif (
499
+ self.current_online_chunk_buffer_size
500
+ > self.SAMPLING_RATE * self.online_chunk_size
501
+ ):
502
+ self.current_online_chunk_buffer_size = 0
503
+ ret = self.online.process_iter()
504
+ return ret
505
+ else:
506
+ logger.debug("no online update, only VAD")
507
+ return (None, None, "")
508
+
509
+ def finish(self):
510
+ ret = self.online.finish()
511
+ self.current_online_chunk_buffer_size = 0
512
+ self.is_currently_final = False
513
+ return ret
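
A streaming-loop sketch for OnlineASRProcessor, assuming asr is one of the backends from backends.py and audio_chunks yields 16 kHz float32 numpy arrays of roughly one second each; both names are placeholders:

online = OnlineASRProcessor(asr, buffer_trimming=("segment", 15))
for chunk in audio_chunks:
    online.insert_audio_chunk(chunk)
    beg, end, text = online.process_iter()
    if text:
        print(f"{beg:.2f}-{end:.2f} {text}")

beg, end, text = online.finish()  # flush whatever is still uncommitted
if text:
    print(f"{beg:.2f}-{end:.2f} {text}")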
src/whisper_streaming/silero_vad_iterator.py ADDED
@@ -0,0 +1,163 @@
1
+ import torch
2
+
3
+ # This is copied from silero-vad's utils_vad.py:
4
+ # https://github.com/snakers4/silero-vad/blob/f6b1294cb27590fb2452899df98fb234dfef1134/utils_vad.py#L340
5
+ # (except changed defaults)
6
+
7
+ # Their licence is MIT, same as ours: https://github.com/snakers4/silero-vad/blob/f6b1294cb27590fb2452899df98fb234dfef1134/LICENSE
8
+
9
+
10
+ class VADIterator:
11
+ def __init__(
12
+ self,
13
+ model,
14
+ threshold: float = 0.5,
15
+ sampling_rate: int = 16000,
16
+ min_silence_duration_ms: int = 500, # makes sense on one recording that I checked
17
+ speech_pad_ms: int = 100, # same
18
+ ):
19
+ """
20
+ Class for stream imitation
21
+
22
+ Parameters
23
+ ----------
24
+ model: preloaded .jit silero VAD model
25
+
26
+ threshold: float (default - 0.5)
27
+ Speech threshold. Silero VAD outputs speech probabilities for each audio chunk, probabilities ABOVE this value are considered as SPEECH.
28
+ It is better to tune this parameter for each dataset separately, but "lazy" 0.5 is pretty good for most datasets.
29
+
30
+ sampling_rate: int (default - 16000)
31
+ Currently silero VAD models support 8000 and 16000 sample rates
32
+
33
+ min_silence_duration_ms: int (default - 100 milliseconds)
34
+ In the end of each speech chunk wait for min_silence_duration_ms before separating it
35
+
36
+ speech_pad_ms: int (default - 100 milliseconds)
37
+ Final speech chunks are padded by speech_pad_ms each side
38
+ """
39
+
40
+ self.model = model
41
+ self.threshold = threshold
42
+ self.sampling_rate = sampling_rate
43
+
44
+ if sampling_rate not in [8000, 16000]:
45
+ raise ValueError(
46
+ "VADIterator does not support sampling rates other than [8000, 16000]"
47
+ )
48
+
49
+ self.min_silence_samples = sampling_rate * min_silence_duration_ms / 1000
50
+ self.speech_pad_samples = sampling_rate * speech_pad_ms / 1000
51
+ self.reset_states()
52
+
53
+ def reset_states(self):
54
+
55
+ self.model.reset_states()
56
+ self.triggered = False
57
+ self.temp_end = 0
58
+ self.current_sample = 0
59
+
60
+ def __call__(self, x, return_seconds=False):
61
+ """
62
+ x: torch.Tensor
63
+ audio chunk (see examples in repo)
64
+
65
+ return_seconds: bool (default - False)
66
+ whether return timestamps in seconds (default - samples)
67
+ """
68
+
69
+ if not torch.is_tensor(x):
70
+ try:
71
+ x = torch.Tensor(x)
72
+ except:
73
+ raise TypeError("Audio cannot be casted to tensor. Cast it manually")
74
+
75
+ window_size_samples = len(x[0]) if x.dim() == 2 else len(x)
76
+ self.current_sample += window_size_samples
77
+
78
+ speech_prob = self.model(x, self.sampling_rate).item()
79
+
80
+ if (speech_prob >= self.threshold) and self.temp_end:
81
+ self.temp_end = 0
82
+
83
+ if (speech_prob >= self.threshold) and not self.triggered:
84
+ self.triggered = True
85
+ speech_start = self.current_sample - self.speech_pad_samples
86
+ return {
87
+ "start": (
88
+ int(speech_start)
89
+ if not return_seconds
90
+ else round(speech_start / self.sampling_rate, 1)
91
+ )
92
+ }
93
+
94
+ if (speech_prob < self.threshold - 0.15) and self.triggered:
95
+ if not self.temp_end:
96
+ self.temp_end = self.current_sample
97
+ if self.current_sample - self.temp_end < self.min_silence_samples:
98
+ return None
99
+ else:
100
+ speech_end = self.temp_end + self.speech_pad_samples
101
+ self.temp_end = 0
102
+ self.triggered = False
103
+ return {
104
+ "end": (
105
+ int(speech_end)
106
+ if not return_seconds
107
+ else round(speech_end / self.sampling_rate, 1)
108
+ )
109
+ }
110
+
111
+ return None
112
+
113
+
114
+ #######################
115
+ # because Silero now requires exactly 512-sized audio chunks
116
+
117
+ import numpy as np
118
+
119
+
120
+ class FixedVADIterator(VADIterator):
121
+ """It fixes VADIterator by allowing it to process audio of any length, not only exactly 512 samples at once.
122
+ If audio to be processed at once is long and multiple voiced segments detected,
123
+ then __call__ returns the start of the first segment, and end (or middle, which means no end) of the last segment.
124
+ """
125
+
126
+ def reset_states(self):
127
+ super().reset_states()
128
+ self.buffer = np.array([], dtype=np.float32)
129
+
130
+ def __call__(self, x, return_seconds=False):
131
+ self.buffer = np.append(self.buffer, x)
132
+ ret = None
133
+ while len(self.buffer) >= 512:
134
+ r = super().__call__(self.buffer[:512], return_seconds=return_seconds)
135
+ self.buffer = self.buffer[512:]
136
+ if ret is None:
137
+ ret = r
138
+ elif r is not None:
139
+ if "end" in r:
140
+ ret["end"] = r["end"] # the latter end
141
+ if "start" in r and "end" in ret: # there is an earlier start.
142
+ # Remove end, merging this segment with the previous one.
143
+ del ret["end"]
144
+ return ret if ret != {} else None
145
+
146
+
147
+ if __name__ == "__main__":
148
+ # test/demonstrate the need for FixedVADIterator:
149
+
150
+ import torch
151
+
152
+ model, _ = torch.hub.load(repo_or_dir="snakers4/silero-vad", model="silero_vad")
153
+ vac = FixedVADIterator(model)
154
+ # vac = VADIterator(model) # the second case crashes with this
155
+
156
+ # this works: for both
157
+ audio_buffer = np.array([0] * (512), dtype=np.float32)
158
+ vac(audio_buffer)
159
+
160
+ # this crashes on the non FixedVADIterator with
161
+ # ops.prim.RaiseException("Input audio chunk is too short", "builtins.ValueError")
162
+ audio_buffer = np.array([0] * (512 - 1), dtype=np.float32)
163
+ vac(audio_buffer)
src/whisper_streaming/whisper_online.py ADDED
@@ -0,0 +1,235 @@
1
+ #!/usr/bin/env python3
2
+ import sys
3
+ import numpy as np
4
+ import librosa
5
+ from functools import lru_cache
6
+ import time
7
+ import logging
8
+ from .backends import FasterWhisperASR, MLXWhisper, WhisperTimestampedASR, OpenaiApiASR
9
+ from .online_asr import OnlineASRProcessor, VACOnlineASRProcessor
10
+
11
+ logger = logging.getLogger(__name__)
12
+
13
+
14
+
15
+ WHISPER_LANG_CODES = "af,am,ar,as,az,ba,be,bg,bn,bo,br,bs,ca,cs,cy,da,de,el,en,es,et,eu,fa,fi,fo,fr,gl,gu,ha,haw,he,hi,hr,ht,hu,hy,id,is,it,ja,jw,ka,kk,km,kn,ko,la,lb,ln,lo,lt,lv,mg,mi,mk,ml,mn,mr,ms,mt,my,ne,nl,nn,no,oc,pa,pl,ps,pt,ro,ru,sa,sd,si,sk,sl,sn,so,sq,sr,su,sv,sw,ta,te,tg,th,tk,tl,tr,tt,uk,ur,uz,vi,yi,yo,zh".split(
16
+ ","
17
+ )
18
+
19
+
20
+ def create_tokenizer(lan):
21
+ """returns an object that has split function that works like the one of MosesTokenizer"""
22
+
23
+ assert (
24
+ lan in WHISPER_LANG_CODES
25
+ ), "language must be Whisper's supported lang code: " + " ".join(WHISPER_LANG_CODES)
26
+
27
+ if lan == "uk":
28
+ import tokenize_uk
29
+
30
+ class UkrainianTokenizer:
31
+ def split(self, text):
32
+ return tokenize_uk.tokenize_sents(text)
33
+
34
+ return UkrainianTokenizer()
35
+
36
+ # supported by fast-mosestokenizer
37
+ if (
38
+ lan
39
+ in "as bn ca cs de el en es et fi fr ga gu hi hu is it kn lt lv ml mni mr nl or pa pl pt ro ru sk sl sv ta te yue zh".split()
40
+ ):
41
+ from mosestokenizer import MosesSentenceSplitter
42
+
43
+ return MosesSentenceSplitter(lan)
44
+
45
+ # the following languages are in Whisper, but not in wtpsplit:
46
+ if (
47
+ lan
48
+ in "as ba bo br bs fo haw hr ht jw lb ln lo mi nn oc sa sd sn so su sw tk tl tt".split()
49
+ ):
50
+ logger.debug(
51
+ f"{lan} code is not supported by wtpsplit. Going to use None lang_code option."
52
+ )
53
+ lan = None
54
+
55
+ from wtpsplit import WtP
56
+
57
+ # downloads the model from huggingface on the first use
58
+ wtp = WtP("wtp-canine-s-12l-no-adapters")
59
+
60
+ class WtPtok:
61
+ def split(self, sent):
62
+ return wtp.split(sent, lang_code=lan)
63
+
64
+ return WtPtok()
65
+
66
+
67
+ def add_shared_args(parser):
68
+ """shared args for simulation (this entry point) and server
69
+ parser: argparse.ArgumentParser object
70
+ """
71
+ parser.add_argument(
72
+ "--min-chunk-size",
73
+ type=float,
74
+ default=1.0,
75
+ help="Minimum audio chunk size in seconds. It waits up to this long before processing. If processing finishes sooner, it waits; otherwise it processes the whole segment received by that time.",
76
+ )
77
+ parser.add_argument(
78
+ "--model",
79
+ type=str,
80
+ # default="large-v3-turbo",
81
+ # choices="tiny.en,tiny,base.en,base,small.en,small,medium.en,medium,large-v1,large-v2,large-v3,large,large-v3-turbo".split(
82
+ # ","
83
+ # ),
84
+ help="Name or size of the Whisper model to use. The model is automatically downloaded from the model hub if it is not present in the model cache dir.",
85
+ )
86
+ parser.add_argument(
87
+ "--model_cache_dir",
88
+ type=str,
89
+ default=None,
90
+ help="Overriding the default model cache dir where models downloaded from the hub are saved",
91
+ )
92
+ parser.add_argument(
93
+ "--model_dir",
94
+ type=str,
95
+ default=None,
96
+ help="Dir where Whisper model.bin and other files are saved. This option overrides --model and --model_cache_dir parameter.",
97
+ )
98
+ parser.add_argument(
99
+ "--lan",
100
+ "--language",
101
+ type=str,
102
+ default="auto",
103
+ help="Source language code, e.g. en,de,cs, or 'auto' for language detection.",
104
+ )
105
+ parser.add_argument(
106
+ "--task",
107
+ type=str,
108
+ default="transcribe",
109
+ choices=["transcribe", "translate"],
110
+ help="Transcribe or translate.",
111
+ )
112
+ parser.add_argument(
113
+ "--backend",
114
+ type=str,
115
+ default="faster-whisper",
116
+ choices=["faster-whisper", "whisper_timestamped", "mlx-whisper", "openai-api"],
117
+ help="Load only this backend for Whisper processing.",
118
+ )
119
+ parser.add_argument(
120
+ "--vac",
121
+ action="store_true",
122
+ default=False,
123
+ help="Use VAC = voice activity controller. Recommended. Requires torch.",
124
+ )
125
+ parser.add_argument(
126
+ "--vac-chunk-size", type=float, default=0.04, help="VAC sample size in seconds."
127
+ )
128
+ parser.add_argument(
129
+ "--vad",
130
+ action="store_true",
131
+ default=False,
132
+ help="Use VAD = voice activity detection, with the default parameters.",
133
+ )
134
+ parser.add_argument(
135
+ "--buffer_trimming",
136
+ type=str,
137
+ default="segment",
138
+ choices=["sentence", "segment"],
139
+ help='Buffer trimming strategy -- trim completed sentences marked with punctuation mark and detected by sentence segmenter, or the completed segments returned by Whisper. Sentence segmenter must be installed for "sentence" option.',
140
+ )
141
+ parser.add_argument(
142
+ "--buffer_trimming_sec",
143
+ type=float,
144
+ default=15,
145
+ help="Buffer trimming length threshold in seconds. If buffer length is longer, trimming sentence/segment is triggered.",
146
+ )
147
+ parser.add_argument(
148
+ "-l",
149
+ "--log-level",
150
+ dest="log_level",
151
+ choices=["DEBUG", "INFO", "WARNING", "ERROR", "CRITICAL"],
152
+ help="Set the log level",
153
+ default="DEBUG",
154
+ )
155
+
156
+ def backend_factory(args):
157
+ backend = args.backend
158
+ if backend == "openai-api":
159
+ logger.debug("Using OpenAI API.")
160
+ asr = OpenaiApiASR(lan=args.lan)
161
+ else:
162
+ if backend == "faster-whisper":
163
+ asr_cls = FasterWhisperASR
164
+ elif backend == "mlx-whisper":
165
+ asr_cls = MLXWhisper
166
+ else:
167
+ asr_cls = WhisperTimestampedASR
168
+
169
+ # Only for FasterWhisperASR and WhisperTimestampedASR
170
+ size = args.model
171
+ t = time.time()
172
+ logger.info(f"Loading Whisper {size} model for {args.lan}...")
173
+ asr = asr_cls(
174
+ modelsize=size,
175
+ lan=args.lan,
176
+ cache_dir=args.model_cache_dir,
177
+ model_dir=args.model_dir,
178
+ )
179
+ e = time.time()
180
+ logger.info(f"done. It took {round(e-t,2)} seconds.")
181
+
182
+ # Apply common configurations
183
+ if getattr(args, "vad", False): # Checks if VAD argument is present and True
184
+ logger.info("Setting VAD filter")
185
+ asr.use_vad()
186
+
187
+ language = args.lan
188
+ if args.task == "translate":
189
+ asr.set_translate_task()
190
+ tgt_language = "en" # Whisper translates into English
191
+ else:
192
+ tgt_language = language # Whisper transcribes in this language
193
+
194
+ # Create the tokenizer
195
+ if args.buffer_trimming == "sentence":
196
+
197
+ tokenizer = create_tokenizer(tgt_language)
198
+ else:
199
+ tokenizer = None
200
+ return asr, tokenizer
201
+
202
+ def online_factory(args, asr, tokenizer, logfile=sys.stderr):
203
+ if args.vac:
204
+ online = VACOnlineASRProcessor(
205
+ args.min_chunk_size,
206
+ asr,
207
+ tokenizer,
208
+ logfile=logfile,
209
+ buffer_trimming=(args.buffer_trimming, args.buffer_trimming_sec),
210
+ )
211
+ else:
212
+ online = OnlineASRProcessor(
213
+ asr,
214
+ tokenizer,
215
+ logfile=logfile,
216
+ buffer_trimming=(args.buffer_trimming, args.buffer_trimming_sec),
217
+ )
218
+ return online
219
+
220
+ def asr_factory(args, logfile=sys.stderr):
221
+ """
222
+ Creates and configures an ASR and ASR Online instance based on the specified backend and arguments.
223
+ """
224
+ asr, tokenizer = backend_factory(args)
225
+ online = online_factory(args, asr, tokenizer, logfile=logfile)
226
+ return asr, online
227
+
228
+ def set_logging(args, logger, others=[]):
229
+ logging.basicConfig(format="%(levelname)s\t%(message)s")
230
+ logger.setLevel(args.log_level)
231
+
232
+ for other in others:
233
+ logging.getLogger(other).setLevel(args.log_level)
234
+
235
+
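
A minimal offline sketch (not part of this commit) of how the factories above might be driven from a standalone script. It assumes the module is importable as src.whisper_streaming.whisper_online, that a 16 kHz mono wav file path is supplied on the command line, and that librosa (already in requirements.txt) is used for loading; the chunk size and print format are illustrative only.

import argparse
import logging
import librosa
from src.whisper_streaming.whisper_online import add_shared_args, asr_factory, set_logging

logger = logging.getLogger(__name__)

parser = argparse.ArgumentParser(description="Offline smoke test for the streaming ASR factories")
parser.add_argument("audio_path", type=str, help="Path to a 16 kHz mono wav file")
add_shared_args(parser)          # adds --model, --lan, --backend, --vac, --buffer_trimming, ...
args = parser.parse_args()
set_logging(args, logger)

asr, online = asr_factory(args)  # loads the chosen backend and wraps it in an online processor

# Feed the file in roughly one-second chunks to mimic streaming input.
audio, _ = librosa.load(args.audio_path, sr=16000, mono=True)
chunk = 16000
for i in range(0, len(audio), chunk):
    online.insert_audio_chunk(audio[i:i + chunk])
    beg, end, text = online.process_iter()
    if text:
        print(f"{beg:.2f}-{end:.2f} {text}")
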
whisper_fastapi_online_server.py ADDED
@@ -0,0 +1,391 @@
1
+ import io
2
+ import requests
3
+ import argparse
4
+ import asyncio
5
+ import numpy as np
6
+ import ffmpeg
7
+ from time import time
8
+
9
+ from fastapi import FastAPI, WebSocket, WebSocketDisconnect
10
+ from fastapi.responses import HTMLResponse
11
+ from fastapi.middleware.cors import CORSMiddleware
12
+
13
+ from src.whisper_streaming.whisper_online import backend_factory, online_factory, add_shared_args
14
+
15
+
16
+ import logging
17
+ import logging.config
18
+ from transformers import pipeline
19
+
20
+ MODEL_NAME = 'Helsinki-NLP/opus-tatoeba-en-ja'
21
+ TRANSLATOR = pipeline('translation', model=MODEL_NAME, device='cuda')
22
+ TRANSLATOR('Warming up!')
23
+
24
+ API_KEY = '3c2b8b0f-4fa9-4eb7-b67d-7cae25546051:fx' # Specify your own DeepL API key
25
+
26
+ SOURCE_LANG = 'EN'
27
+ TARGET_LANG = 'JA'
28
+
29
+ def translator_wrapper(source_text, mode='deepl'):
30
+ if mode == 'deepl':
31
+ params = {
32
+ 'auth_key' : API_KEY,
33
+ 'text' : source_text,
34
+ 'source_lang' : SOURCE_LANG, # source language (translate from)
35
+ "target_lang": TARGET_LANG # target language (translate into)
36
+ }
37
+
38
+ # Send the translation request
39
+ try:
40
+ request = requests.post("https://api-free.deepl.com/v2/translate", data=params, timeout=5) # Note: the endpoint URL differs between the free and paid plans
41
+ result = request.json()['translations'][0]['text']
42
+ except requests.exceptions.Timeout:
43
+ result = "(timed out)"
44
+ return result
45
+
46
+ elif mode == 'marianmt':
47
+ return TRANSLATOR(source_text)[0]['translation_text']
48
+
49
+ elif mode == 'google':
50
+ import requests
51
+
52
+ # https://www.eyoucms.com/news/ziliao/other/29445.html
53
+ language_type = ""
54
+ target = 'ja-jp'
55
+ url = "https://translation.googleapis.com/language/translate/v2"
56
+ data = {
57
+ 'key':"AIzaSyCX0-Wdxl_rgvcZzklNjnqJ1W9YiKjcHUs", # Authentication: API key
58
+ 'source': language_type,
59
+ 'target': target,
60
+ 'q': source_text,
61
+ 'format': "text"
62
+ }
63
+ #headers = {'X-HTTP-Method-Override': 'GET'}
64
+ #response = requests.post(url, data=data, headers=headers)
65
+ response = requests.post(url, data)
66
+ # print(response.json())
67
+ print(response)
68
+ res = response.json()
69
+ print(res["data"]["translations"][0]["translatedText"])
70
+ result = res["data"]["translations"][0]["translatedText"]
71
+ print(result)
72
+ return result
73
+
74
+
75
+ def setup_logging():
76
+ logging_config = {
77
+ 'version': 1,
78
+ 'disable_existing_loggers': False,
79
+ 'formatters': {
80
+ 'standard': {
81
+ 'format': '%(asctime)s %(levelname)s [%(name)s]: %(message)s',
82
+ },
83
+ },
84
+ 'handlers': {
85
+ 'console': {
86
+ 'level': 'INFO',
87
+ 'class': 'logging.StreamHandler',
88
+ 'formatter': 'standard',
89
+ },
90
+ },
91
+ 'root': {
92
+ 'handlers': ['console'],
93
+ 'level': 'DEBUG',
94
+ },
95
+ 'loggers': {
96
+ 'uvicorn': {
97
+ 'handlers': ['console'],
98
+ 'level': 'INFO',
99
+ 'propagate': False,
100
+ },
101
+ 'uvicorn.error': {
102
+ 'level': 'INFO',
103
+ },
104
+ 'uvicorn.access': {
105
+ 'level': 'INFO',
106
+ },
107
+ 'src.whisper_streaming.online_asr': { # Add your specific module here
108
+ 'handlers': ['console'],
109
+ 'level': 'DEBUG',
110
+ 'propagate': False,
111
+ },
112
+ 'src.whisper_streaming.whisper_streaming': { # Add your specific module here
113
+ 'handlers': ['console'],
114
+ 'level': 'DEBUG',
115
+ 'propagate': False,
116
+ },
117
+ },
118
+ }
119
+
120
+ logging.config.dictConfig(logging_config)
121
+
122
+ setup_logging()
123
+ logger = logging.getLogger(__name__)
124
+
125
+
126
+
127
+
128
+
129
+
130
+ app = FastAPI()
131
+ app.add_middleware(
132
+ CORSMiddleware,
133
+ allow_origins=["*"],
134
+ allow_credentials=True,
135
+ allow_methods=["*"],
136
+ allow_headers=["*"],
137
+ )
138
+
139
+
140
+ parser = argparse.ArgumentParser(description="Whisper FastAPI Online Server")
141
+ parser.add_argument(
142
+ "--host",
143
+ type=str,
144
+ default="localhost",
145
+ help="The host address to bind the server to.",
146
+ )
147
+ parser.add_argument(
148
+ "--port", type=int, default=8000, help="The port number to bind the server to."
149
+ )
150
+ parser.add_argument(
151
+ "--warmup-file",
152
+ type=str,
153
+ dest="warmup_file",
154
+ help="The path to a speech audio wav file to warm up Whisper so that the very first chunk processing is fast. It can be e.g. https://github.com/ggerganov/whisper.cpp/raw/master/samples/jfk.wav .",
155
+ )
156
+
157
+ parser.add_argument(
158
+ "--diarization",
159
+ type=bool,
160
+ default=False,
161
+ help="Whether to enable speaker diarization.",
162
+ )
163
+
164
+
165
+ add_shared_args(parser)
166
+ args = parser.parse_args()
167
+ # args.model = 'medium'
168
+
169
+ asr, tokenizer = backend_factory(args)
170
+
171
+ if args.diarization:
172
+ from src.diarization.diarization_online import DiartDiarization
173
+
174
+
175
+ # Load demo HTML for the root endpoint
176
+ with open("src/web/live_transcription.html", "r", encoding="utf-8") as f:
177
+ html = f.read()
178
+
179
+
180
+ @app.get("/")
181
+ async def get():
182
+ return HTMLResponse(html)
183
+
184
+
185
+ SAMPLE_RATE = 16000
186
+ CHANNELS = 1
187
+ SAMPLES_PER_SEC = SAMPLE_RATE * int(args.min_chunk_size)
188
+ BYTES_PER_SAMPLE = 2 # s16le = 2 bytes per sample
189
+ BYTES_PER_SEC = SAMPLES_PER_SEC * BYTES_PER_SAMPLE
190
+
191
+
192
+ async def start_ffmpeg_decoder():
193
+ """
194
+ Start an FFmpeg process in async streaming mode that reads WebM from stdin
195
+ and outputs raw s16le PCM on stdout. Returns the process object.
196
+ """
197
+ process = (
198
+ ffmpeg.input("pipe:0", format="webm")
199
+ .output(
200
+ "pipe:1",
201
+ format="s16le",
202
+ acodec="pcm_s16le",
203
+ ac=CHANNELS,
204
+ ar=str(SAMPLE_RATE),
205
+ )
206
+ .run_async(pipe_stdin=True, pipe_stdout=True, pipe_stderr=True)
207
+ )
208
+ return process
209
+
210
+
211
+
212
+ @app.websocket("/asr")
213
+ async def websocket_endpoint(websocket: WebSocket):
214
+ await websocket.accept()
215
+ print("WebSocket connection opened.")
216
+
217
+ ffmpeg_process = await start_ffmpeg_decoder()
218
+ pcm_buffer = bytearray()
219
+ print("Loading online.")
220
+ online = online_factory(args, asr, tokenizer)
221
+ print("Online loaded.")
222
+
223
+ if args.diarization:
224
+ diarization = DiartDiarization(SAMPLE_RATE)
225
+
226
+ # Continuously read decoded PCM from ffmpeg stdout in a background task
227
+ async def ffmpeg_stdout_reader():
228
+ nonlocal pcm_buffer
229
+ loop = asyncio.get_event_loop()
230
+ full_transcription = ""
231
+ beg = time()
232
+
233
+ chunk_history = [] # Will store dicts: {beg, end, text, speaker}
234
+
235
+ buffers = [{'speaker': '0', 'text': '', 'translation': None}]
236
+ buffer_line = ''
237
+
238
+ while True:
239
+ print('in while')
240
+ try:
241
+ print('try in while')
242
+ elapsed_time = int(time() - beg)
243
+ beg = time()
244
+ print('before await loop.run_in_executor()')
245
+ chunk = await loop.run_in_executor(None, ffmpeg_process.stdout.read, 32000 * elapsed_time)
246
+
247
+ print('before if not chunk')
248
+ if not chunk: # The first chunk will be almost empty, FFmpeg is still starting up
249
+ chunk = await loop.run_in_executor(None, ffmpeg_process.stdout.read, 4096)
250
+ if not chunk: # FFmpeg might have closed
251
+ print("FFmpeg stdout closed.")
252
+ break
253
+
254
+ pcm_buffer.extend(chunk)
255
+
256
+ print('before if len(pcm_buffer)')
257
+ if len(pcm_buffer) >= BYTES_PER_SEC:
258
+ print('in if len(pcm_buffer)')
259
+ # Convert int16 -> float32
260
+ pcm_array = (np.frombuffer(pcm_buffer, dtype=np.int16).astype(np.float32) / 32768.0)
261
+ pcm_buffer = bytearray() # Initialize the PCM buffer
262
+ online.insert_audio_chunk(pcm_array)
263
+ beg_trans, end_trans, trans = online.process_iter()
264
+
265
+ if trans:
266
+ chunk_history.append({
267
+ "beg": beg_trans,
268
+ "end": end_trans,
269
+ "text": trans,
270
+ "speaker": "0"
271
+ })
272
+ full_transcription += trans
273
+
274
+ # ----------------
275
+ # Process buffer
276
+ # ----------------
277
+ if args.vac:
278
+ # We need to access the underlying online object to get the buffer
279
+ buffer = online.online.concatenate_tsw(online.online.transcript_buffer.buffer)[2]
280
+ else:
281
+ buffer = online.concatenate_tsw(online.transcript_buffer.buffer)[2]
282
+
283
+ if buffer in full_transcription: # With VAC, the buffer is not updated until the next chunk is processed
284
+ buffer = ""
285
+
286
+ buffer_line += buffer
287
+
288
+ punctuations = (',', '.', '?', '!', 'and', 'or', 'but', 'however')
289
+ if any(punctuation in buffer_line for punctuation in punctuations):
290
+ last_punctuation_index = max((buffer_line.rfind(p) + len(p) + 1) for p in punctuations if p in buffer_line)
291
+ extracted_text = buffer_line[:last_punctuation_index]
292
+ buffer_line = buffer_line[last_punctuation_index:]
293
+ buffers.append({'speaker': '0', 'text': extracted_text, 'translation': None})
294
+
295
+ # Translation loop
296
+ print('buffers for loop')
297
+ for i, buffer in enumerate(buffers):
298
+ print(i, buffer)
299
+ if buffer['translation'] is not None:
300
+ continue
301
+ if buffer['text'] == '':
302
+ continue
303
+
304
+ transcription = buffer['text']
305
+ buffers[i]['translation'] = translator_wrapper(transcription, mode='google')
306
+ buffers[i]['text'] += ('|' + buffers[i]['translation'])
307
+
308
+ # ----------------
309
+ # Process lines
310
+ # ----------------
311
+ print('Process lines')
312
+ lines = [{"speaker": "0", "text": ""}]
313
+
314
+ if args.diarization:
315
+ await diarization.diarize(pcm_array)
316
+ # diarization.assign_speakers_to_chunks(chunk_history)
317
+ chunk_history = diarization.assign_speakers_to_chunks(chunk_history)
318
+
319
+ for ch in chunk_history:
320
+ if args.diarization and ch["speaker"] and ch["speaker"][-1] != lines[-1]["speaker"]:
321
+ lines.append({"speaker": ch["speaker"], "text": ch['text']})
322
+
323
+ else:
324
+ lines.append({"speaker": ch["speaker"], "text": ch['text']})
325
+
326
+ for i, line in enumerate(lines):
327
+ if line['text'].strip() == '':
328
+ continue
329
+ # translation = translator(line['text'])[0]['translation_text']
330
+ # translation = translation.replace(' ', '')
331
+ # lines[i]['text'] = line['text'] + translation
332
+ lines[i]['text'] = line['text']
333
+
334
+ # translation = translator(buffer)[0]['translation_text']
335
+ # translation = translation.replace(' ', '')
336
+ # buffer += translation
337
+
338
+ print('Before making response')
339
+ response = {"lines": buffers, "buffer": ''}
340
+ await websocket.send_json(response)
341
+
342
+ except Exception as e:
343
+ print(f"Exception in ffmpeg_stdout_reader: {e}")
344
+ break
345
+
346
+ print("Exiting ffmpeg_stdout_reader...")
347
+
348
+ stdout_reader_task = asyncio.create_task(ffmpeg_stdout_reader())
349
+
350
+ try:
351
+ while True:
352
+ # Receive incoming WebM audio chunks from the client
353
+ message = await websocket.receive_bytes()
354
+ # Pass them to ffmpeg via stdin
355
+ ffmpeg_process.stdin.write(message)
356
+ ffmpeg_process.stdin.flush()
357
+
358
+ except WebSocketDisconnect:
359
+ print("WebSocket connection closed.")
360
+ except Exception as e:
361
+ print(f"Error in websocket loop: {e}")
362
+ finally:
363
+ # Clean up ffmpeg and the reader task
364
+ try:
365
+ ffmpeg_process.stdin.close()
366
+ except:
367
+ pass
368
+ stdout_reader_task.cancel()
369
+
370
+ try:
371
+ ffmpeg_process.stdout.close()
372
+ except:
373
+ pass
374
+
375
+ ffmpeg_process.wait()
376
+ del online
377
+
378
+ if args.diarization:
379
+ # Stop Diart
380
+ diarization.close()
381
+
382
+
383
+
384
+
385
+ if __name__ == "__main__":
386
+ import uvicorn
387
+
388
+ uvicorn.run(
389
+ "whisper_fastapi_online_server:app", host=args.host, port=args.port, reload=True,
390
+ log_level="info"
391
+ )
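
For reference, here is a minimal client sketch (not part of this commit) that exercises the /asr WebSocket endpoint. It assumes the third-party websockets package (not listed in requirements.txt) and a pre-recorded sample.webm file; the bundled src/web/live_transcription.html page served at the root endpoint is the intended browser client and streams recorded WebM chunks in the same way.

import asyncio
import json
import websockets  # assumed third-party dependency, installed separately

async def stream_webm(path: str, url: str = "ws://localhost:8000/asr"):
    async with websockets.connect(url) as ws:

        async def sender():
            # Push the WebM container in small chunks, roughly like a MediaRecorder would.
            with open(path, "rb") as f:
                while chunk := f.read(4096):
                    await ws.send(chunk)
                    await asyncio.sleep(0.1)

        async def receiver():
            # The server replies with JSON of the form {"lines": [...], "buffer": ""}.
            async for message in ws:
                response = json.loads(message)
                for line in response.get("lines", []):
                    if line.get("text"):
                        print(f"[{line.get('speaker', '0')}] {line['text']}")

        # Runs until the server closes the connection.
        await asyncio.gather(sender(), receiver())

if __name__ == "__main__":
    asyncio.run(stream_webm("sample.webm"))
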