File size: 82,911 Bytes
8718761
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
258
259
260
261
262
263
264
265
266
267
268
269
270
271
272
273
274
275
276
277
278
279
280
281
282
283
284
285
286
287
288
289
290
291
292
293
294
295
296
297
298
299
300
301
302
303
304
305
306
307
308
309
310
311
312
313
314
315
316
317
318
319
320
321
322
323
324
325
326
327
328
329
330
331
332
333
334
335
336
337
338
339
340
341
342
343
344
345
346
347
348
349
350
351
352
353
354
355
356
357
358
359
360
361
362
363
364
365
366
367
368
369
370
371
372
373
374
375
376
377
378
379
380
381
382
383
384
385
386
387
388
389
390
391
392
393
394
395
396
397
398
399
400
401
402
403
404
405
406
407
408
409
410
411
412
413
414
415
416
417
418
419
420
421
422
423
424
425
426
427
428
429
430
431
432
433
434
435
436
437
438
439
440
441
442
443
444
445
446
447
448
449
450
451
452
453
454
455
456
457
458
459
460
461
462
463
464
465
466
467
468
469
470
471
472
473
474
475
476
477
478
479
480
481
482
483
484
485
486
487
488
489
490
491
492
493
494
495
496
497
498
499
500
501
502
503
504
505
506
507
508
509
510
511
512
513
514
515
516
517
518
519
520
521
522
523
524
525
526
527
528
529
530
531
532
533
534
535
536
537
538
539
540
541
542
543
544
545
546
547
548
549
550
551
552
553
554
555
556
557
558
559
560
561
562
563
564
565
566
567
568
569
570
571
572
573
574
575
576
577
578
579
580
581
582
583
584
585
586
587
588
589
590
591
592
593
594
595
596
597
598
599
600
601
602
603
604
605
606
607
608
609
610
611
612
613
614
615
616
617
618
619
620
621
622
623
624
625
626
627
628
629
630
631
632
633
634
635
636
637
638
639
640
641
642
643
644
645
646
647
648
649
650
651
652
653
654
655
656
657
658
659
660
661
662
663
664
665
666
667
668
669
670
671
672
673
674
675
676
677
678
679
680
681
682
683
684
685
686
687
688
689
690
691
692
693
694
695
696
697
698
699
700
701
702
703
704
705
706
707
708
709
710
711
712
713
714
715
716
717
718
719
720
721
722
723
724
725
726
727
728
729
730
731
732
733
734
735
736
737
738
739
740
741
742
743
744
745
746
747
748
749
750
751
752
753
754
755
756
757
758
759
760
761
762
763
764
765
766
767
768
769
770
771
772
773
774
775
776
777
778
779
780
781
782
783
784
785
786
787
788
789
790
791
792
793
794
795
796
797
798
799
800
801
802
803
804
805
806
807
808
809
810
811
812
813
814
815
816
817
818
819
820
821
822
823
824
825
826
827
828
829
830
831
832
833
834
835
836
837
838
839
840
841
842
843
844
845
846
847
848
849
850
851
852
853
854
855
856
857
858
859
860
861
862
863
864
865
866
867
868
869
870
871
872
873
874
875
876
877
878
879
880
881
882
883
884
885
886
887
888
889
890
891
892
893
894
895
896
897
898
899
900
901
902
903
904
905
906
907
908
909
910
911
912
913
914
915
916
917
918
919
920
921
922
923
924
925
926
927
928
929
930
931
932
933
934
935
936
937
938
939
940
941
942
943
944
945
946
947
948
949
950
951
952
953
954
955
956
957
958
959
960
961
962
963
964
965
966
967
968
969
970
971
972
973
974
975
976
977
978
979
980
981
982
983
984
985
986
987
988
989
990
991
992
993
994
995
996
997
998
999
1000
1001
1002
1003
1004
1005
1006
1007
1008
1009
1010
1011
1012
1013
1014
1015
1016
1017
1018
1019
1020
1021
1022
1023
1024
1025
1026
1027
1028
1029
1030
1031
1032
1033
1034
1035
1036
1037
1038
1039
1040
1041
1042
1043
1044
1045
1046
1047
1048
1049
1050
1051
1052
1053
1054
1055
1056
1057
1058
1059
1060
1061
1062
1063
1064
1065
1066
1067
1068
1069
1070
1071
1072
1073
1074
1075
1076
1077
1078
1079
1080
1081
1082
1083
1084
1085
1086
1087
1088
1089
1090
1091
1092
1093
1094
1095
1096
1097
1098
1099
1100
1101
1102
1103
1104
1105
1106
1107
1108
1109
1110
1111
1112
1113
1114
1115
1116
1117
1118
1119
1120
1121
1122
1123
1124
1125
1126
1127
1128
1129
1130
1131
1132
1133
1134
1135
1136
1137
1138
1139
1140
1141
1142
1143
1144
1145
1146
1147
1148
1149
1150
1151
1152
1153
1154
1155
1156
1157
1158
1159
1160
1161
1162
1163
1164
1165
1166
1167
1168
1169
1170
1171
1172
1173
1174
1175
1176
1177
1178
1179
1180
1181
1182
1183
1184
1185
1186
1187
1188
1189
1190
1191
1192
1193
1194
1195
1196
1197
1198
1199
1200
1201
1202
1203
1204
1205
1206
1207
1208
1209
1210
1211
1212
1213
1214
1215
1216
1217
1218
1219
1220
1221
1222
1223
1224
1225
1226
1227
1228
1229
1230
1231
1232
1233
1234
1235
1236
1237
1238
1239
1240
1241
1242
1243
1244
1245
1246
1247
1248
1249
1250
1251
1252
1253
1254
1255
1256
1257
1258
1259
1260
1261
1262
1263
1264
1265
1266
1267
1268
1269
1270
1271
1272
1273
1274
1275
1276
1277
1278
1279
1280
1281
1282
1283
1284
1285
1286
1287
1288
1289
1290
1291
1292
1293
1294
1295
1296
1297
1298
1299
1300
1301
1302
1303
1304
1305
1306
1307
1308
1309
1310
1311
1312
1313
1314
1315
1316
1317
1318
1319
1320
1321
1322
1323
1324
1325
1326
1327
1328
1329
1330
1331
1332
1333
1334
1335
1336
1337
1338
1339
1340
1341
1342
1343
1344
1345
1346
1347
1348
1349
1350
1351
1352
1353
1354
1355
1356
1357
1358
1359
1360
1361
1362
1363
1364
1365
1366
1367
1368
1369
1370
1371
1372
1373
1374
1375
1376
1377
1378
1379
1380
1381
1382
1383
1384
1385
1386
1387
1388
1389
1390
1391
1392
1393
1394
1395
1396
1397
1398
1399
1400
1401
1402
1403
1404
1405
1406
1407
1408
1409
1410
1411
1412
1413
1414
1415
1416
1417
1418
1419
1420
1421
1422
1423
1424
1425
1426
1427
1428
1429
1430
1431
1432
1433
1434
1435
1436
1437
1438
1439
1440
1441
1442
1443
1444
1445
1446
1447
1448
1449
1450
1451
1452
1453
1454
1455
1456
1457
1458
1459
1460
1461
1462
1463
1464
1465
1466
1467
1468
1469
1470
1471
1472
1473
1474
1475
1476
1477
1478
1479
1480
1481
1482
1483
1484
1485
1486
1487
1488
1489
1490
1491
1492
1493
1494
1495
1496
1497
1498
1499
1500
1501
1502
1503
1504
1505
1506
1507
1508
1509
1510
1511
1512
1513
1514
1515
1516
1517
1518
1519
1520
1521
1522
1523
1524
1525
1526
1527
1528
1529
1530
1531
1532
1533
1534
1535
1536
1537
1538
1539
1540
1541
1542
1543
1544
1545
1546
1547
1548
1549
1550
1551
1552
1553
1554
1555
1556
1557
1558
1559
1560
1561
1562
1563
1564
1565
1566
1567
1568
1569
1570
1571
1572
1573
1574
1575
1576
1577
1578
1579
1580
1581
1582
1583
1584
1585
1586
1587
1588
1589
1590
1591
1592
1593
1594
1595
1596
1597
1598
1599
1600
1601
1602
1603
1604
1605
1606
1607
1608
1609
1610
1611
1612
1613
1614
1615
1616
1617
1618
1619
1620
1621
1622
1623
1624
1625
1626
1627
1628
1629
1630
1631
1632
1633
1634
1635
1636
1637
1638
1639
1640
1641
1642
1643
1644
1645
1646
1647
1648
1649
1650
1651
1652
import warnings
import torch
import numpy as np
from typing import TYPE_CHECKING, List, Optional, Tuple, Union, Callable
from types import MethodType
from tqdm import tqdm

import whisper
from whisper.audio import (
    SAMPLE_RATE, N_FRAMES, HOP_LENGTH, N_SAMPLES, N_SAMPLES_PER_TOKEN, TOKENS_PER_SECOND, FRAMES_PER_SECOND, N_FFT,
    pad_or_trim, log_mel_spectrogram
)
from whisper.utils import exact_div
from whisper.tokenizer import LANGUAGES, TO_LANGUAGE_CODE
from whisper.decoding import DecodingOptions, DecodingResult

from .audio import prep_audio
from .decode import decode_stable
from .result import WhisperResult, Segment
from .timing import add_word_timestamps_stable
from .stabilization import get_vad_silence_func, wav2mask, mask2timing, timing2mask
from .non_whisper import transcribe_any
from .utils import isolate_useful_options, safe_print
from .whisper_compatibility import warn_compatibility_issues, get_tokenizer

if TYPE_CHECKING:
    from whisper.model import Whisper

__all__ = ['modify_model', 'load_model', 'load_faster_whisper']

warnings.filterwarnings('ignore', module='whisper', message='.*Triton.*', category=UserWarning)


# modified version of whisper.transcribe.transcribe
def transcribe_stable(
        model: "Whisper",
        audio: Union[str, np.ndarray, torch.Tensor, bytes],
        *,
        verbose: Optional[bool] = False,
        temperature: Union[float, Tuple[float, ...]] = (0.0, 0.2, 0.4, 0.6, 0.8, 1.0),
        compression_ratio_threshold: Optional[float] = 2.4,
        logprob_threshold: Optional[float] = -1.0,
        no_speech_threshold: Optional[float] = 0.6,
        condition_on_previous_text: bool = True,
        initial_prompt: Optional[str] = None,
        word_timestamps: bool = True,
        regroup: Union[bool, str] = True,
        ts_num: int = 0,
        ts_noise: float = 0.1,
        suppress_silence: bool = True,
        suppress_word_ts: bool = True,
        use_word_position: bool = True,
        q_levels: int = 20,
        k_size: int = 5,
        time_scale: float = None,
        demucs: Union[bool, torch.nn.Module] = False,
        demucs_output: str = None,
        demucs_options: dict = None,
        vad: bool = False,
        vad_threshold: float = 0.35,
        vad_onnx: bool = False,
        min_word_dur: float = 0.1,
        nonspeech_error: float = 0.3,
        only_voice_freq: bool = False,
        prepend_punctuations: str = "\"'“¿([{-",
        append_punctuations: str = "\"'.。,,!!??::”)]}、",
        mel_first: bool = False,
        split_callback: Callable = None,
        suppress_ts_tokens: bool = False,
        gap_padding: str = ' ...',
        only_ffmpeg: bool = False,
        max_instant_words: float = 0.5,
        avg_prob_threshold: Optional[float] = None,
        progress_callback: Callable = None,
        ignore_compatibility: bool = False,
        **decode_options) \
        -> WhisperResult:
    """
    Transcribe audio using Whisper.

    This is a modified version of :func:`whisper.transcribe.transcribe` with slightly different decoding logic while
    allowing additional preprocessing and postprocessing. The preprocessing performed on the audio includes: isolating
    voice / removing noise with Demucs and low/high-pass filter. The postprocessing performed on the transcription
    result includes: adjusting timestamps with VAD and custom regrouping segments based punctuation and speech gaps.

    Parameters
    ----------
    model : whisper.model.Whisper
        An instance of Whisper ASR model.
    audio : str or numpy.ndarray or torch.Tensor or bytes
        Path/URL to the audio file, the audio waveform, or bytes of audio file.
        If audio is :class:`numpy.ndarray` or :class:`torch.Tensor`, the audio must be already at sampled to 16kHz.
    verbose : bool or None, default False
        Whether to display the text being decoded to the console.
        Displays all the details if ``True``. Displays progressbar if ``False``. Display nothing if ``None``.
    temperature : float or iterable of float, default (0.0, 0.2, 0.4, 0.6, 0.8, 1.0)
        Temperature for sampling. It can be a tuple of temperatures, which will be successfully used
        upon failures according to either ``compression_ratio_threshold`` or ``logprob_threshold``.
    compression_ratio_threshold : float, default 2.4
        If the gzip compression ratio is above this value, treat as failed.
    logprob_threshold : float, default -1
        If the average log probability over sampled tokens is below this value, treat as failed
    no_speech_threshold : float, default 0.6
        If the no_speech probability is higher than this value AND the average log probability
        over sampled tokens is below ``logprob_threshold``, consider the segment as silent
    condition_on_previous_text : bool, default True
        If ``True``, the previous output of the model is provided as a prompt for the next window;
        disabling may make the text inconsistent across windows, but the model becomes less prone to
        getting stuck in a failure loop, such as repetition looping or timestamps going out of sync.
    initial_prompt : str, optional
        Text to provide as a prompt for the first window. This can be used to provide, or
        "prompt-engineer" a context for transcription, e.g. custom vocabularies or proper nouns
        to make it more likely to predict those word correctly.
    word_timestamps : bool, default True
        Extract word-level timestamps using the cross-attention pattern and dynamic time warping,
        and include the timestamps for each word in each segment.
        Disabling this will prevent segments from splitting/merging properly.
    regroup : bool or str, default True, meaning the default regroup algorithm
        String for customizing the regrouping algorithm. False disables regrouping.
        Ignored if ``word_timestamps = False``.
    ts_num : int, default 0, meaning disable this option
        Number of extra timestamp inferences to perform then use average of these extra timestamps.
        An experimental option that might hurt performance.
    ts_noise : float, default 0.1
        Percentage of noise to add to audio_features to perform inferences for ``ts_num``.
    suppress_silence : bool, default True
        Whether to enable timestamps adjustments based on the detected silence.
    suppress_word_ts : bool, default True
        Whether to adjust word timestamps based on the detected silence. Only enabled if ``suppress_silence = True``.
    use_word_position : bool, default True
        Whether to use position of the word in its segment to determine whether to keep end or start timestamps if
        adjustments are required. If it is the first word, keep end. Else if it is the last word, keep the start.
    q_levels : int, default 20
        Quantization levels for generating timestamp suppression mask; ignored if ``vad = true``.
        Acts as a threshold to marking sound as silent.
        Fewer levels will increase the threshold of volume at which to mark a sound as silent.
    k_size : int, default 5
        Kernel size for avg-pooling waveform to generate timestamp suppression mask; ignored if ``vad = true``.
        Recommend 5 or 3; higher sizes will reduce detection of silence.
    time_scale : float, optional
        Factor for scaling audio duration for inference.
        Greater than 1.0 'slows down' the audio, and less than 1.0 'speeds up' the audio. None is same as 1.0.
        A factor of 1.5 will stretch 10s audio to 15s for inference. This increases the effective resolution
        of the model but can increase word error rate.
    demucs : bool or torch.nn.Module, default False
        Whether to preprocess ``audio`` with Demucs to isolate vocals / remove noise. Set ``demucs`` to an instance of
        a Demucs model to avoid reloading the model for each run.
        Demucs must be installed to use. Official repo. https://github.com/facebookresearch/demucs.
    demucs_output : str, optional
        Path to save the vocals isolated by Demucs as WAV file. Ignored if ``demucs = False``.
        Demucs must be installed to use. Official repo. https://github.com/facebookresearch/demucs.
    demucs_options : dict, optional
        Options to use for :func:`stable_whisper.audio.demucs_audio`.
    vad : bool, default False
        Whether to use Silero VAD to generate timestamp suppression mask.
        Silero VAD requires PyTorch 1.12.0+. Official repo, https://github.com/snakers4/silero-vad.
    vad_threshold : float, default 0.35
        Threshold for detecting speech with Silero VAD. Low threshold reduces false positives for silence detection.
    vad_onnx : bool, default False
        Whether to use ONNX for Silero VAD.
    min_word_dur : float, default 0.1
        Shortest duration each word is allowed to reach for silence suppression.
    nonspeech_error : float, default 0.3
        Relative error of non-speech sections that appear in between a word for silence suppression.
    only_voice_freq : bool, default False
        Whether to only use sound between 200 - 5000 Hz, where majority of human speech are.
    prepend_punctuations : str, default '"\'“¿([{-)'
        Punctuations to prepend to next word.
    append_punctuations : str, default '.。,,!!??::”)]}、)'
        Punctuations to append to previous word.
    mel_first : bool, default False
        Process entire audio track into log-Mel spectrogram first instead in chunks.
        Used if odd behavior seen in stable-ts but not in whisper, but use significantly more memory for long audio.
    split_callback : Callable, optional
        Custom callback for grouping tokens up with their corresponding words.
        The callback must take two arguments, list of tokens and tokenizer.
        The callback returns a tuple with a list of words and a corresponding nested list of tokens.
    suppress_ts_tokens : bool, default False
        Whether to suppress timestamp tokens during inference for timestamps are detected at silent.
        Reduces hallucinations in some cases, but also prone to ignore disfluencies and repetitions.
        This option is ignored if ``suppress_silence = False``.
    gap_padding : str, default ' ...'
        Padding prepend to each segments for word timing alignment.
        Used to reduce the probability of model predicting timestamps earlier than the first utterance.
    only_ffmpeg : bool, default False
        Whether to use only FFmpeg (instead of not yt-dlp) for URls
    max_instant_words : float, default 0.5
        If percentage of instantaneous words in a segment exceed this amount, the segment is removed.
    avg_prob_threshold: float or None, default None
        Transcribe the gap after the previous word and if the average word proababiliy of a segment falls below this
        value, discard the segment. If ``None``, skip transcribing the gap to reduce chance of timestamps starting
        before the next utterance.
    progress_callback : Callable, optional
        A function that will be called when transcription progress is updated.
        The callback need two parameters.
        The first parameter is a float for seconds of the audio that has been transcribed.
        The second parameter is a float for total duration of audio in seconds.
    ignore_compatibility : bool, default False
        Whether to ignore warnings for compatibility issues with the detected Whisper version.
    decode_options
        Keyword arguments to construct class:`whisper.decode.DecodingOptions` instances.

    Returns
    -------
    stable_whisper.result.WhisperResult
        All timestamps, words, probabilities, and other data from the transcription of ``audio``.

    See Also
    --------
    stable_whisper.non_whisper.transcribe_any : Return :class:`stable_whisper.result.WhisperResult` containing all the
        data from transcribing audio with unmodified :func:`whisper.transcribe.transcribe` with preprocessing and
        postprocessing.
    stable_whisper.whisper_word_level.load_faster_whisper.faster_transcribe : Return
        :class:`stable_whisper.result.WhisperResult` containing all the data from transcribing audio with
        :meth:`faster_whisper.WhisperModel.transcribe` with preprocessing and postprocessing.

    Examples
    --------
    >>> import stable_whisper
    >>> model = stable_whisper.load_model('base')
    >>> result = model.transcribe('audio.mp3', vad=True)
    >>> result.to_srt_vtt('audio.srt')
    Saved: audio.srt
    """
    warn_compatibility_issues(whisper, ignore_compatibility, 'Or use transcribe_minimal().')
    dtype = torch.float16 if decode_options.get("fp16", True) and not getattr(model, 'dq', False) else torch.float32
    if model.device == torch.device("cpu"):
        if torch.cuda.is_available():
            warnings.warn("Performing inference on CPU when CUDA is available")
        if dtype == torch.float16:
            warnings.warn("FP16 is not supported on CPU; using FP32 instead")
            dtype = torch.float32

    if dtype == torch.float32:
        decode_options["fp16"] = False

    if 'max_initial_timestamp' not in decode_options:
        decode_options['max_initial_timestamp'] = None

    device = model.device

    if time_scale:
        warnings.warn('``time_scale`` is deprecated. It will not affect results.',
                      DeprecationWarning, stacklevel=2)
    if decode_options.pop('input_sr', None):
        warnings.warn('``input_sr`` is deprecated. '
                      '``audio`` of types numpy.ndarray and torch.Tensor inputs must be already at 16kHz. '
                      'To higher sample rates for ``audio`` use str or bytes.',
                      DeprecationWarning, stacklevel=2)
    if not demucs_options:
        demucs_options = {}
    if demucs_output:
        if 'save_path' not in demucs_options:
            demucs_options['save_path'] = demucs_output
        warnings.warn('``demucs_output`` is deprecated. Use ``demucs_options`` with ``save_path`` instead. '
                      'E.g. demucs_options=dict(save_path="demucs_output.mp3")',
                      DeprecationWarning, stacklevel=2)
    if 'device' not in demucs_options:
        demucs_options['device'] = device
    audio = prep_audio(
        audio,
        demucs=demucs,
        demucs_options=demucs_options,
        only_voice_freq=only_voice_freq,
        only_ffmpeg=only_ffmpeg,
        verbose=verbose
    )
    sample_padding = int(N_FFT // 2) + 1
    whole_mel = log_mel_spectrogram(audio, model.dims.n_mels, padding=sample_padding) if mel_first else None
    tokenizer = None
    language = None
    initial_prompt_tokens = []
    task = decode_options.get("task", "transcribe")

    def detect_language():
        nonlocal tokenizer
        if tokenizer is None:
            if decode_options.get("language", None) is None and model:
                if not model.is_multilingual:
                    decode_options["language"] = "en"
                else:
                    if verbose:
                        print("Detecting language using up to 30 seconds following first non-silent sample. "
                              "Use `--language` to specify the language")
                    timing_mask = None
                    if segment_silence_timing is not None:
                        timing_mask = np.logical_and(
                            segment_silence_timing[0] <= time_offset,
                            segment_silence_timing[1] >= time_offset
                        )
                    start_sample = (
                        None
                        if segment_silence_timing is None or not timing_mask.any() else
                        round(segment_silence_timing[1][timing_mask.nonzero()[0]][0] * SAMPLE_RATE)
                    )
                    if start_sample is None:
                        nonlocal mel_segment
                        curr_mel_segment = mel_segment
                    else:
                        if whole_mel is None:
                            curr_mel_segment = log_mel_spectrogram(
                                audio[..., start_sample:start_sample+N_SAMPLES],
                                model.dims.n_mels,
                                padding=sample_padding
                            )
                        else:
                            start_frame = int(start_sample/HOP_LENGTH)
                            curr_mel_segment = whole_mel[..., start_frame:start_frame+N_FRAMES]
                        curr_mel_segment = pad_or_trim(curr_mel_segment, N_FRAMES).to(device=device, dtype=dtype)
                    _, probs = model.detect_language(curr_mel_segment)
                    decode_options["language"] = max(probs, key=probs.get)
                    if verbose is not None:
                        detected_msg = f"Detected language: {LANGUAGES[decode_options['language']]}"
                        if tqdm_pbar.disable:
                            print(detected_msg)
                        else:
                            tqdm_pbar.write(detected_msg)

            nonlocal language
            language = decode_options["language"]
            tokenizer = get_tokenizer(model, language=language, task=task)

            if word_timestamps and task == "translate":
                warnings.warn("Word-level timestamps on translations may not be reliable.")

            if initial_prompt is not None:
                nonlocal initial_prompt_tokens
                initial_prompt_tokens = tokenizer.encode(" " + initial_prompt.strip())
                all_tokens.extend(initial_prompt_tokens)

    audio_features = None

    def decode_with_fallback(seg: torch.Tensor,
                             ts_token_mask: torch.Tensor = None) \
            -> DecodingResult:
        nonlocal audio_features
        temperatures = [temperature] if isinstance(temperature, (int, float)) else temperature
        decode_result = None

        for t in temperatures:
            kwargs = {**decode_options}
            if t > 0:
                # disable beam_size and patience when t > 0
                kwargs.pop("beam_size", None)
                kwargs.pop("patience", None)
            else:
                # disable best_of when t == 0
                kwargs.pop("best_of", None)

            options = DecodingOptions(**kwargs, temperature=t)
            decode_result, audio_features = decode_stable(model,
                                                          seg,
                                                          options,
                                                          ts_token_mask=ts_token_mask if suppress_ts_tokens else None,
                                                          audio_features=audio_features)

            needs_fallback = False
            if (
                    compression_ratio_threshold is not None
                    and decode_result.compression_ratio > compression_ratio_threshold
            ):
                needs_fallback = True  # too repetitive
            if (
                    logprob_threshold is not None
                    and decode_result.avg_logprob < logprob_threshold
            ):
                needs_fallback = True  # average log probability is too low
            if (
                no_speech_threshold is not None
                and decode_result.no_speech_prob > no_speech_threshold
            ):
                needs_fallback = False  # silence

            if not needs_fallback:
                break

        return decode_result

    seek_sample = 0  # samples
    input_stride = exact_div(
        N_FRAMES, model.dims.n_audio_ctx
    )  # mel frames per output token: 2
    time_precision = (
            input_stride * HOP_LENGTH / SAMPLE_RATE
    )  # time per output token: 0.02 (seconds)
    all_tokens = []
    all_segments = []
    prompt_reset_since = 0

    def new_segment(
            *, start: float, end: float, tokens: torch.Tensor, result: DecodingResult
    ):
        tokens = tokens.tolist()
        text_tokens = [token for token in tokens if token < tokenizer.eot]
        return {
            "seek": round(seek_sample / SAMPLE_RATE, 3),  # units in seconds
            "start": start,
            "end": end,
            "text": tokenizer.decode(text_tokens),
            "tokens": tokens,
            "temperature": result.temperature,
            "avg_logprob": result.avg_logprob,
            "compression_ratio": result.compression_ratio,
            "no_speech_prob": result.no_speech_prob,
        }

    punctuations = prepend_punctuations + append_punctuations

    total_samples = audio.shape[-1]
    total_duration = round(total_samples / SAMPLE_RATE, 2)
    n_samples_per_frame = exact_div(N_SAMPLES_PER_TOKEN * TOKENS_PER_SECOND, FRAMES_PER_SECOND)

    silent_timings = [[], []]
    silence_timing = None
    if suppress_silence and vad:
        silence_timing = get_vad_silence_func(onnx=vad_onnx, verbose=verbose)(audio, speech_threshold=vad_threshold)

    with tqdm(total=total_duration, unit='sec', disable=verbose is not False, desc=task.title()) as tqdm_pbar:

        def update_pbar():
            nonlocal audio_features
            audio_features = None
            seek_duration = min(total_duration, round(seek_sample / SAMPLE_RATE, 2))
            if not tqdm_pbar.disable:
                tqdm_pbar.update(seek_duration - tqdm_pbar.n)
            if progress_callback is not None:
                progress_callback(seek=seek_duration, total=total_duration)

        def update_seek():
            nonlocal seek_sample
            seek_sample += segment_samples

        def fast_forward():
            # fast-forward to the next segment boundary
            update_seek()
            update_pbar()

        while seek_sample < audio.shape[-1]:
            seek_sample_end = seek_sample + N_SAMPLES
            audio_segment = audio[seek_sample:seek_sample_end]
            time_offset = seek_sample / SAMPLE_RATE
            segment_samples = audio_segment.shape[-1]
            segment_duration = segment_samples / SAMPLE_RATE

            mel_segment = (
                log_mel_spectrogram(audio_segment, model.dims.n_mels, padding=sample_padding)
                if whole_mel is None else
                whole_mel[..., round(seek_sample / n_samples_per_frame): round(seek_sample_end / n_samples_per_frame)]
            )

            mel_segment = pad_or_trim(mel_segment, N_FRAMES).to(device=model.device, dtype=dtype)

            segment_silence_timing = None
            ts_token_mask = None
            if suppress_silence:
                if silence_timing is None:
                    ts_token_mask = wav2mask(audio_segment, q_levels=q_levels, k_size=k_size)
                    segment_silence_timing = mask2timing(ts_token_mask, time_offset=time_offset)
                else:
                    timing_indices = np.logical_and(
                        silence_timing[1] > time_offset,
                        silence_timing[0] < time_offset + segment_duration
                    )
                    segment_silence_timing = (silence_timing[0][timing_indices], silence_timing[1][timing_indices])

                    ts_token_mask = timing2mask(*segment_silence_timing, size=1501, time_offset=time_offset)

                    if mn := timing_indices.argmax():
                        silence_timing = (silence_timing[0][mn:], silence_timing[1][mn:])

                if ts_token_mask is not None:
                    if ts_token_mask.all():  # segment is silent
                        fast_forward()
                        continue
                    ts_token_mask = pad_or_trim(ts_token_mask, 1501)

            detect_language()
            decode_options["prompt"] = all_tokens[prompt_reset_since:]
            result: DecodingResult = decode_with_fallback(mel_segment, ts_token_mask=ts_token_mask)
            tokens = torch.tensor(result.tokens)

            if no_speech_threshold is not None:
                # no voice activity check
                should_skip = result.no_speech_prob > no_speech_threshold
                if logprob_threshold is not None and result.avg_logprob > logprob_threshold:
                    # don't skip if the logprob is high enough, despite the no_speech_prob
                    should_skip = False

                if should_skip:
                    fast_forward()
                    continue

            current_segments = []

            timestamp_tokens: torch.Tensor = tokens.ge(tokenizer.timestamp_begin)
            single_timestamp_ending = timestamp_tokens[-2:].tolist() == [False, True]

            consecutive = torch.where(timestamp_tokens[:-1] & timestamp_tokens[1:])[0]
            consecutive.add_(1)
            if len(consecutive) > 0:
                # if the output contains two consecutive timestamp tokens
                slices = consecutive.tolist()
                if single_timestamp_ending:
                    slices.append(len(tokens))

                last_slice = 0
                for current_slice in slices:
                    sliced_tokens = tokens[last_slice:current_slice]
                    start_timestamp_pos = (
                            sliced_tokens[0].item() - tokenizer.timestamp_begin
                    )
                    end_timestamp_pos = (
                            sliced_tokens[-1].item() - tokenizer.timestamp_begin
                    )
                    current_segments.append(
                        new_segment(
                            start=round(time_offset + start_timestamp_pos * time_precision, 3),
                            end=round(time_offset + min(end_timestamp_pos * time_precision, segment_duration), 3),
                            tokens=sliced_tokens,
                            result=result,
                        )
                    )
                    last_slice = current_slice

            else:
                duration = segment_duration
                timestamps = tokens[timestamp_tokens.nonzero().flatten()]
                if (
                        len(timestamps) > 0
                        and timestamps[-1].item() != tokenizer.timestamp_begin
                ):
                    # no consecutive timestamps but it has a timestamp; use the last one.
                    end_timestamp_pos = (
                            timestamps[-1].item() - tokenizer.timestamp_begin
                    )
                    duration = min(end_timestamp_pos * time_precision, segment_duration)
                else:
                    end_timestamp_pos = 0

                current_segments.append(
                    new_segment(
                        start=round(time_offset, 3),
                        end=round(time_offset + duration, 3),
                        tokens=tokens,
                        result=result,
                    )
                )

            # if a segment is instantaneous or does not contain text, remove it
            for i in reversed(range(len(current_segments))):
                seg = current_segments[i]
                if seg["start"] == seg["end"] or seg["text"].strip() in punctuations:
                    del current_segments[i]

            num_samples = (
                min(round(end_timestamp_pos * N_SAMPLES_PER_TOKEN), segment_samples)
                if end_timestamp_pos > 0 else
                segment_samples
            )

            if word_timestamps:
                add_word_timestamps_stable(
                    segments=current_segments,
                    model=model,
                    tokenizer=tokenizer,
                    mel=mel_segment,
                    num_samples=num_samples,
                    prepend_punctuations=prepend_punctuations,
                    append_punctuations=append_punctuations,
                    audio_features=audio_features,
                    ts_num=ts_num,
                    ts_noise=ts_noise,
                    split_callback=split_callback,
                    gap_padding=gap_padding
                )

                # if [max_instant_words] of the words in a segment are instantaneous, remove it
                for i in reversed(range(len(current_segments))):
                    zero_duration_percent = (
                        np.array(
                            [w['start'] == w['end'] for w in current_segments[i]['words']]
                        )
                        .astype(np.float16)
                        .mean()
                    )
                    if zero_duration_percent > max_instant_words:
                        del current_segments[i]

                if avg_prob_threshold and current_segments:
                    if (
                            single_timestamp_ending and
                            (np.mean([w['probability'] for s in current_segments for w in s['words']]) <
                             avg_prob_threshold)
                    ):
                        num_samples = segment_samples
                        current_segments = []
                    else:
                        num_samples = round((current_segments[-1]['words'][-1]['end']-time_offset) * SAMPLE_RATE)

            if len(current_segments) == 0:
                fast_forward()
                continue

            if segment_silence_timing is not None:
                silent_timings[0].extend(segment_silence_timing[0])
                silent_timings[1].extend(segment_silence_timing[1])
                for seg_i, segment in enumerate(current_segments):
                    segment = Segment(**segment).suppress_silence(
                            *segment_silence_timing,
                            min_word_dur=min_word_dur,
                            word_level=suppress_word_ts,
                            nonspeech_error=nonspeech_error,
                            use_word_position=use_word_position,
                        )
                    if verbose:
                        safe_print(segment.to_display_str())
                    current_segments[seg_i] = segment.to_dict()

            all_segments.extend(
                [
                    {"id": i, **segment}
                    for i, segment in enumerate(current_segments, start=len(all_segments))
                ]
            )
            all_tokens.extend(
                [token for segment in current_segments for token in segment["tokens"]]
            )
            if not single_timestamp_ending or avg_prob_threshold:
                segment_samples = num_samples

            if not condition_on_previous_text or result.temperature > 0.5:
                # do not feed the prompt tokens if a high temperature was used
                prompt_reset_since = len(all_tokens)

            fast_forward()

        # final update
        update_pbar()

    if model.device != torch.device('cpu'):
        torch.cuda.empty_cache()

    text = '' if tokenizer is None else tokenizer.decode(all_tokens[len(initial_prompt_tokens):])
    final_result = WhisperResult(dict(text=text,
                                      segments=all_segments,
                                      language=language,
                                      time_scale=time_scale))
    if word_timestamps and regroup:
        final_result.regroup(regroup)

    if time_scale is not None:
        final_result.rescale_time(1 / time_scale)

    if len(final_result.text) == 0:
        warnings.warn(f'Failed to {task} audio. Result contains no text. ')

    final_result.update_nonspeech_sections(*silent_timings)

    return final_result


def transcribe_minimal(
        model: "Whisper",
        audio: Union[str, np.ndarray, torch.Tensor, bytes],
        *,
        verbose: Optional[bool] = False,
        word_timestamps: bool = True,
        regroup: Union[bool, str] = True,
        suppress_silence: bool = True,
        suppress_word_ts: bool = True,
        use_word_position: bool = True,
        q_levels: int = 20,
        k_size: int = 5,
        demucs: bool = False,
        demucs_output: str = None,
        demucs_options: dict = None,
        vad: bool = False,
        vad_threshold: float = 0.35,
        vad_onnx: bool = False,
        min_word_dur: float = 0.1,
        nonspeech_error: float = 0.3,
        only_voice_freq: bool = False,
        only_ffmpeg: bool = False,
        **options) \
        -> WhisperResult:
    """
    Transcribe audio using Whisper.

    This is uses the original whisper transcribe function, :func:`whisper.transcribe.transcribe`, while still allowing
    additional preprocessing and postprocessing. The preprocessing performed on the audio includes: isolating voice /
    removing noise with Demucs and low/high-pass filter. The postprocessing performed on the transcription
    result includes: adjusting timestamps with VAD and custom regrouping segments based punctuation and speech gaps.

    Parameters
    ----------
    model : whisper.model.Whisper
        An instance of Whisper ASR model.
    audio : str or numpy.ndarray or torch.Tensor or bytes
        Path/URL to the audio file, the audio waveform, or bytes of audio file.
        If audio is ``numpy.ndarray`` or ``torch.Tensor``, the audio must be already at sampled to 16kHz.
    verbose : bool or None, default False
        Whether to display the text being decoded to the console.
        Displays all the details if ``True``. Displays progressbar if ``False``. Display nothing if ``None``.
    word_timestamps : bool, default True
        Extract word-level timestamps using the cross-attention pattern and dynamic time warping,
        and include the timestamps for each word in each segment.
        Disabling this will prevent segments from splitting/merging properly.
    regroup : bool or str, default True, meaning the default regroup algorithm
        String for customizing the regrouping algorithm. False disables regrouping.
        Ignored if ``word_timestamps = False``.
    suppress_silence : bool, default True
        Whether to enable timestamps adjustments based on the detected silence.
    suppress_word_ts : bool, default True
        Whether to adjust word timestamps based on the detected silence. Only enabled if ``suppress_silence = True``.
    use_word_position : bool, default True
        Whether to use position of the word in its segment to determine whether to keep end or start timestamps if
        adjustments are required. If it is the first word, keep end. Else if it is the last word, keep the start.
    q_levels : int, default 20
        Quantization levels for generating timestamp suppression mask; ignored if ``vad = true``.
        Acts as a threshold to marking sound as silent.
        Fewer levels will increase the threshold of volume at which to mark a sound as silent.
    k_size : int, default 5
        Kernel size for avg-pooling waveform to generate timestamp suppression mask; ignored if ``vad = true``.
        Recommend 5 or 3; higher sizes will reduce detection of silence.
    demucs : bool or torch.nn.Module, default False
        Whether to preprocess ``audio`` with Demucs to isolate vocals / remove noise. Set ``demucs`` to an instance of
        a Demucs model to avoid reloading the model for each run.
        Demucs must be installed to use. Official repo, https://github.com/facebookresearch/demucs.
    demucs_output : str, optional
        Path to save the vocals isolated by Demucs as WAV file. Ignored if ``demucs = False``.
        Demucs must be installed to use. Official repo, https://github.com/facebookresearch/demucs.
    demucs_options : dict, optional
        Options to use for :func:`stable_whisper.audio.demucs_audio`.
    vad : bool, default False
        Whether to use Silero VAD to generate timestamp suppression mask.
        Silero VAD requires PyTorch 1.12.0+. Official repo, https://github.com/snakers4/silero-vad.
    vad_threshold : float, default 0.35
        Threshold for detecting speech with Silero VAD. Low threshold reduces false positives for silence detection.
    vad_onnx : bool, default False
        Whether to use ONNX for Silero VAD.
    min_word_dur : float, default 0.1
        Shortest duration each word is allowed to reach for silence suppression.
    nonspeech_error : float, default 0.3
        Relative error of non-speech sections that appear in between a word for silence suppression.
    only_voice_freq : bool, default False
        Whether to only use sound between 200 - 5000 Hz, where majority of human speech are.
    only_ffmpeg : bool, default False
        Whether to use only FFmpeg (instead of not yt-dlp) for URls
    options
        Additional options used for :func:`whisper.transcribe.transcribe` and
        :func:`stable_whisper.non_whisper.transcribe_any`.
    Returns
    -------
    stable_whisper.result.WhisperResult
        All timestamps, words, probabilities, and other data from the transcription of ``audio``.

    Examples
    --------
    >>> import stable_whisper
    >>> model = stable_whisper.load_model('base')
    >>> result = model.transcribe_minimal('audio.mp3', vad=True)
    >>> result.to_srt_vtt('audio.srt')
    Saved: audio.srt
    """
    inference_kwargs = dict(
        model=model,
        audio=audio,
        word_timestamps=word_timestamps,
        verbose=verbose
    )
    extra_options = isolate_useful_options(options, transcribe_any, True)
    if demucs or only_voice_freq:
        if 'audio_type' not in extra_options:
            extra_options['audio_type'] = 'torch'
        if 'model_sr' not in extra_options:
            extra_options['model_sr'] = SAMPLE_RATE
    inference_kwargs.update(options)
    return transcribe_any(
        inference_func=whisper.transcribe,
        audio=audio,
        inference_kwargs=inference_kwargs,
        verbose=verbose,
        regroup=regroup,
        suppress_silence=suppress_silence, 
        suppress_word_ts=suppress_word_ts,
        q_levels=q_levels,
        k_size=k_size,
        demucs=demucs,
        demucs_output=demucs_output,
        demucs_options=demucs_options,
        vad=vad,
        vad_threshold=vad_threshold,
        vad_onnx=vad_onnx,
        min_word_dur=min_word_dur,
        nonspeech_error=nonspeech_error,
        use_word_position=use_word_position,
        only_voice_freq=only_voice_freq,
        only_ffmpeg=only_ffmpeg,
        force_order=True,
        **extra_options
    )


def load_faster_whisper(model_size_or_path: str, **model_init_options):
    """
    Load an instance of :class:`faster_whisper.WhisperModel`.

    Parameters
    ----------
    model_size_or_path : {'tiny', 'tiny.en', 'base', 'base.en', 'small', 'small.en', 'medium', 'medium.en', 'large-v1',
        'large-v2', 'large-v3', or 'large'}
        Size of the model.

    model_init_options
        Additional options to use for initialization of :class:`faster_whisper.WhisperModel`.

    Returns
    -------
    faster_whisper.WhisperModel
        A modified instance with :func:`stable_whisper.whisper_word_level.load_faster_whisper.faster_transcribe`
        assigned to :meth:`faster_whisper.WhisperModel.transcribe_stable`.
    """
    from faster_whisper import WhisperModel
    faster_model = WhisperModel(model_size_or_path, **model_init_options)

    def _inner_transcribe(model, audio, verbose, **faster_transcribe_options):
        if isinstance(audio, bytes):
            import io
            audio = io.BytesIO(audio)
        progress_callback = faster_transcribe_options.pop('progress_callback', None)
        segments, info = model.transcribe(audio, **faster_transcribe_options)
        language = LANGUAGES.get(info.language, info.language)
        if verbose is not None:
            print(f'Detected Language: {language}')
            print(f'Transcribing with faster-whisper ({model_size_or_path})...\r', end='')

        final_segments = []
        task = faster_transcribe_options.get('task', 'transcribe').title()
        total_duration = round(info.duration, 2)

        with tqdm(total=total_duration, unit='sec', disable=verbose is not False, desc=task) as tqdm_pbar:

            def update_pbar(seek):
                tqdm_pbar.update(seek - tqdm_pbar.n)
                if progress_callback is not None:
                    progress_callback(seek, total_duration)

            for segment in segments:
                segment = segment._asdict()
                if (words := segment.get('words')) is not None:
                    segment['words'] = [w._asdict() for w in words]
                else:
                    del segment['words']
                if verbose:
                    safe_print(Segment(**segment).to_display_str())
                final_segments.append(segment)
                update_pbar(segment["end"])
            update_pbar(tqdm_pbar.total)

        if verbose:
            print(f'Completed transcription with faster-whisper ({model_size_or_path}).')

        return dict(language=language, segments=final_segments)

    def faster_transcribe(
            model: WhisperModel,
            audio: Union[str, bytes, np.ndarray],
            *,
            word_timestamps: bool = True,
            verbose: Optional[bool] = False,
            regroup: Union[bool, str] = True,
            suppress_silence: bool = True,
            suppress_word_ts: bool = True,
            use_word_position: bool = True,
            q_levels: int = 20,
            k_size: int = 5,
            demucs: bool = False,
            demucs_output: str = None,
            demucs_options: dict = None,
            vad: bool = False,
            vad_threshold: float = 0.35,
            vad_onnx: bool = False,
            min_word_dur: float = 0.1,
            nonspeech_error: float = 0.3,
            only_voice_freq: bool = False,
            only_ffmpeg: bool = False,
            check_sorted: bool = True,
            progress_callback: Callable = None,
            **options
    ) -> WhisperResult:
        """
        Transcribe audio using faster-whisper (https://github.com/guillaumekln/faster-whisper).

        This is uses the transcribe method from faster-whisper, :meth:`faster_whisper.WhisperModel.transcribe`, while
        still allowing additional preprocessing and postprocessing. The preprocessing performed on the audio includes:
        isolating voice / removing noise with Demucs and low/high-pass filter. The postprocessing performed on the
        transcription result includes: adjusting timestamps with VAD and custom regrouping segments based punctuation
        and speech gaps.

        Parameters
        ----------
        model : faster_whisper.WhisperModel
            The faster-whisper ASR model instance.
        audio : str or numpy.ndarray or torch.Tensor or bytes
            Path/URL to the audio file, the audio waveform, or bytes of audio file.
            If audio is :class:`numpy.ndarray` or :class:`torch.Tensor`, the audio must be already at sampled to 16kHz.
        verbose : bool or None, default False
            Whether to display the text being decoded to the console.
            Displays all the details if ``True``. Displays progressbar if ``False``. Display nothing if ``None``.
        word_timestamps : bool, default True
            Extract word-level timestamps using the cross-attention pattern and dynamic time warping,
            and include the timestamps for each word in each segment.
            Disabling this will prevent segments from splitting/merging properly.
        regroup : bool or str, default True, meaning the default regroup algorithm
            String for customizing the regrouping algorithm. False disables regrouping.
            Ignored if ``word_timestamps = False``.
        suppress_silence : bool, default True
            Whether to enable timestamps adjustments based on the detected silence.
        suppress_word_ts : bool, default True
            Whether to adjust word timestamps based on the detected silence. Only enabled if ``suppress_silence = True``.
        use_word_position : bool, default True
            Whether to use position of the word in its segment to determine whether to keep end or start timestamps if
            adjustments are required. If it is the first word, keep end. Else if it is the last word, keep the start.
        q_levels : int, default 20
            Quantization levels for generating timestamp suppression mask; ignored if ``vad = true``.
            Acts as a threshold to marking sound as silent.
            Fewer levels will increase the threshold of volume at which to mark a sound as silent.
        k_size : int, default 5
            Kernel size for avg-pooling waveform to generate timestamp suppression mask; ignored if ``vad = true``.
            Recommend 5 or 3; higher sizes will reduce detection of silence.
        demucs : bool or torch.nn.Module, default False
            Whether to preprocess ``audio`` with Demucs to isolate vocals / remove noise. Set ``demucs`` to an instance
            of a Demucs model to avoid reloading the model for each run.
            Demucs must be installed to use. Official repo, https://github.com/facebookresearch/demucs.
        demucs_output : str, optional
            Path to save the vocals isolated by Demucs as WAV file. Ignored if ``demucs = False``.
            Demucs must be installed to use. Official repo, https://github.com/facebookresearch/demucs.
        demucs_options : dict, optional
            Options to use for :func:`stable_whisper.audio.demucs_audio`.
        vad : bool, default False
            Whether to use Silero VAD to generate timestamp suppression mask.
            Silero VAD requires PyTorch 1.12.0+. Official repo, https://github.com/snakers4/silero-vad.
        vad_threshold : float, default 0.35
            Threshold for detecting speech with Silero VAD. Low threshold reduces false positives for silence detection.
        vad_onnx : bool, default False
            Whether to use ONNX for Silero VAD.
        min_word_dur : float, default 0.1
            Shortest duration each word is allowed to reach for silence suppression.
        nonspeech_error : float, default 0.3
            Relative error of non-speech sections that appear in between a word for silence suppression.
        only_voice_freq : bool, default False
            Whether to only use sound between 200 - 5000 Hz, where majority of human speech are.
        only_ffmpeg : bool, default False
            Whether to use only FFmpeg (instead of not yt-dlp) for URls
        check_sorted : bool, default True
            Whether to raise an error when timestamps returned by faster-whipser are not in ascending order.
        progress_callback : Callable, optional
            A function that will be called when transcription progress is updated.
            The callback need two parameters.
            The first parameter is a float for seconds of the audio that has been transcribed.
            The second parameter is a float for total duration of audio in seconds.
        options
            Additional options used for :meth:`faster_whisper.WhisperModel.transcribe` and
            :func:`stable_whisper.non_whisper.transcribe_any`.

        Returns
        -------
        stable_whisper.result.WhisperResult
            All timestamps, words, probabilities, and other data from the transcription of ``audio``.

        Examples
        --------
        >>> import stable_whisper
        >>> model = stable_whisper.load_faster_whisper('base')
        >>> result = model.transcribe_stable('audio.mp3', vad=True)
        >>> result.to_srt_vtt('audio.srt')
        Saved: audio.srt
        """
        extra_options = isolate_useful_options(options, transcribe_any, pop=True)
        if demucs or only_voice_freq:
            if 'audio_type' not in extra_options:
                extra_options['audio_type'] = 'numpy'
            if 'model_sr' not in extra_options:
                extra_options['model_sr'] = SAMPLE_RATE
        faster_whisper_options = options
        faster_whisper_options['model'] = model
        faster_whisper_options['audio'] = audio
        faster_whisper_options['word_timestamps'] = word_timestamps
        faster_whisper_options['verbose'] = verbose
        faster_whisper_options['progress_callback'] = progress_callback
        if not demucs_options:
            demucs_options = {}
        if demucs_output:
            if 'save_path' not in demucs_options:
                demucs_options['save_path'] = demucs_output
            warnings.warn('``demucs_output`` is deprecated. Use ``demucs_options`` with ``save_path`` instead. '
                          'E.g. demucs_options=dict(save_path="demucs_output.mp3")',
                          DeprecationWarning, stacklevel=2)

        return transcribe_any(
            inference_func=_inner_transcribe,
            audio=audio,
            inference_kwargs=faster_whisper_options,
            verbose=verbose,
            regroup=regroup,
            suppress_silence=suppress_silence,
            suppress_word_ts=suppress_word_ts,
            q_levels=q_levels,
            k_size=k_size,
            demucs=demucs,
            demucs_options=demucs_options,
            vad=vad,
            vad_threshold=vad_threshold,
            vad_onnx=vad_onnx,
            min_word_dur=min_word_dur,
            nonspeech_error=nonspeech_error,
            use_word_position=use_word_position,
            only_voice_freq=only_voice_freq,
            only_ffmpeg=only_ffmpeg,
            force_order=True,
            check_sorted=check_sorted,
            **extra_options
        )

    faster_model.transcribe_stable = MethodType(faster_transcribe, faster_model)
    from .alignment import align
    faster_model.align = MethodType(align, faster_model)

    return faster_model


def modify_model(model: "Whisper"):
    """
    Modify an instance if :class:`whisper.model.Whisper`.

    The following are performed:
    -replace :meth:`whisper.model.Whisper.transcribe` with :func:`stable_whisper.whisper_word_level.transcribe_stable`
    -assign :meth:`whisper.model.transcribe_minimal` to :func:`stable_whisper.whisper_word_level.transcribe_minimal`
    -assign :meth:`whisper.model.Whisper.transcribe_original` to :meth:`whisper.model.Whisper.transcribe`
    -assign :meth:`whisper.model.Whisper.align` to :func:`stable_whisper.alignment.align`
    -assign :meth:`whisper.model.Whisper.locate` to :func:`stable_whisper.alignment.locate`
    """
    model.transcribe = MethodType(transcribe_stable, model)
    model.transcribe_minimal = MethodType(transcribe_minimal, model)
    model.transcribe_original = MethodType(whisper.transcribe, model)
    from .alignment import align, refine, locate
    model.align = MethodType(align, model)
    model.refine = MethodType(refine, model)
    model.locate = MethodType(locate, model)


# modified version of whisper.load_model
def load_model(name: str, device: Optional[Union[str, torch.device]] = None,
               download_root: str = None, in_memory: bool = False,
               cpu_preload: bool = True, dq: bool = False) -> "Whisper":
    """
    Load an instance if :class:`whisper.model.Whisper`.

    Parameters
    ----------
    name : {'tiny', 'tiny.en', 'base', 'base.en', 'small', 'small.en', 'medium', 'medium.en', 'large-v1',
        'large-v2', 'large-v3', or 'large'}
        One of the official model names listed by :func:`whisper.available_models`, or
        path to a model checkpoint containing the model dimensions and the model state_dict.
    device : str or torch.device, optional
        PyTorch device to put the model into.
    download_root : str, optional
        Path to download the model files; by default, it uses "~/.cache/whisper".
    in_memory : bool, default False
        Whether to preload the model weights into host memory.
    cpu_preload : bool, default True
        Load model into CPU memory first then move model to specified device
        to reduce GPU memory usage when loading model
    dq : bool, default False
        Whether to apply Dynamic Quantization to model to reduced memory usage and increase inference speed
        but at the cost of a slight decrease in accuracy. Only for CPU.

    Returns
    -------
    model : "Whisper"
        The Whisper ASR model instance.

    Notes
    -----
    The overhead from ``dq = True`` might make inference slower for models smaller than 'large'.
    """
    if device is None or dq:
        device = "cuda" if torch.cuda.is_available() and not dq else "cpu"
    if cpu_preload:
        model = whisper.load_model(name, device='cpu', download_root=download_root, in_memory=in_memory)
        cuda_index = None
        if isinstance(device, str) and device.startswith('cuda'):
            try:
                cuda_index = [] if device == 'cuda' else [int(device.split(':')[-1])]
            except ValueError:
                pass
        model = model.to(device=device) if cuda_index is None else model.cuda(*cuda_index)
    else:
        model = whisper.load_model(name, device=device, download_root=download_root, in_memory=in_memory)
    modify_model(model)
    if dq:
        from .quantization import ptdq_linear
        ptdq_linear(model)
    return model


# modified version of whisper.transcribe.cli
def cli():
    import argparse
    import os
    from os.path import splitext, split, isfile, join
    from whisper import available_models
    from whisper.utils import optional_int, optional_float
    from .utils import str_to_valid_type, get_func_parameters

    str2val = {"true": True, "false": False, "1": True, "0": False}

    def str2bool(string: str) -> bool:
        string = string.lower()
        if string in str2val:
            return str2val[string]
        raise ValueError(f"Expected one of {set(str2val.keys())}, got {string}")

    def valid_model_name(name):
        if name in available_models() or os.path.exists(name):
            return name
        raise ValueError(
            f"model should be one of {available_models()} or path to a model checkpoint"
        )

    def update_options_with_args(arg_key: str, options: Optional[dict] = None, pop: bool = False):
        extra_options = args.pop(arg_key) if pop else args.get(arg_key)
        if not extra_options:
            return
        extra_options = [kv.split('=', maxsplit=1) for kv in extra_options]
        missing_val = [kv[0] for kv in extra_options if len(kv) == 1]
        if missing_val:
            raise ValueError(f'Following expected values for the following custom options: {missing_val}')
        extra_options = dict((k, str_to_valid_type(v)) for k, v in extra_options)
        if options is None:
            return extra_options
        options.update(extra_options)

    OUTPUT_FORMATS_METHODS = {
        "srt": "to_srt_vtt",
        "ass": "to_ass",
        "json": "save_as_json",
        "vtt": "to_srt_vtt",
        "tsv": "to_tsv",
        "txt": "to_txt",
    }

    OUTPUT_FORMATS = set(OUTPUT_FORMATS_METHODS.keys())

    parser = argparse.ArgumentParser(formatter_class=argparse.ArgumentDefaultsHelpFormatter)
    parser.add_argument("inputs", nargs="+", type=str,
                        help="audio/video filepath/URL(s) to transcribe "
                             "or json file(s) to process into [output_format]")
    parser.add_argument("--output", "-o", action="extend", nargs="+", type=str,
                        help="output filepaths(s);"
                             "if not specified, auto-named output file(s) will be saved to "
                             "[output_dir] or current dir if not specified.")
    parser.add_argument("--model", '-m', default="base", type=valid_model_name,
                        help="name of the Whisper model to use")
    parser.add_argument("--model_dir", type=str, default=None,
                        help="the path to save model files; uses ~/.cache/whisper by default")
    parser.add_argument("--device", default="cuda" if torch.cuda.is_available() else "cpu",
                        help="device to use for PyTorch inference")
    parser.add_argument("--cpu_preload", type=str2bool, default=True,
                        help="load model into CPU memory first then move model to specified device; "
                             "this reduces GPU memory usage when loading model.")
    parser.add_argument("--output_dir", "-d", type=str,
                        help="directory to save the outputs;"
                             "if a path in [output] does not have parent, that output will be save to this directory")
    parser.add_argument("--output_format", "-f", type=str,
                        help="format of the output file(s); "
                             f"Supported Formats: {OUTPUT_FORMATS}; "
                             "use ',' to separate multiple formats")
    parser.add_argument("--verbose", '-v', type=int, default=1, choices=(0, 1, 2),
                        help="whether to display the text being decoded to the console; "
                             "if 2, display all the details; "
                             "if 1, display progressbar; "
                             "if 0, display nothing")

    parser.add_argument("--dynamic_quantization", "-dq", action='store_true',
                        help="whether to apply Dynamic Quantization to model "
                             "to reduced memory usage (~half less) and increase inference speed "
                             "at cost of slight decrease in accuracy; Only for CPU; "
                             "NOTE: overhead might make inference slower for models smaller than 'large'")

    parser.add_argument("--task", type=str, default="transcribe",
                        choices=["transcribe", "translate"],
                        help="whether to perform X->X speech recognition ('transcribe') "
                             "or X->English translation ('translate')")
    parser.add_argument("--language", '-l', type=str, default=None,
                        choices=sorted(LANGUAGES.keys()) + sorted([k.title() for k in TO_LANGUAGE_CODE.keys()]),
                        help="language spoken in the audio, specify None to perform language detection")

    parser.add_argument("--prepend_punctuations", '-pp', type=str, default="\"'“¿([{-",
                        help="Punctuations to prepend to next word")
    parser.add_argument("--append_punctuations", '-ap', type=str, default="\"'.。,,!!??::”)]}、",
                        help="Punctuations to append to previous word")

    parser.add_argument("--gap_padding", type=str, default=" ...",
                        help="padding prepend to each segments for word timing alignment;"
                             "used to reduce the probability of model predicting timestamps "
                             "earlier than the first utterance")

    parser.add_argument("--word_timestamps", type=str2bool, default=True,
                        help="extract word-level timestamps using the cross-attention pattern and dynamic time warping,"
                             "and include the timestamps for each word in each segment;"
                             "disabling this will prevent segments from splitting/merging properly.")

    parser.add_argument("--regroup", type=str, default="True",
                        help="whether to regroup all words into segments with more natural boundaries;"
                             "specify string for customizing the regrouping algorithm"
                             "ignored if [word_timestamps]=False.")

    parser.add_argument('--ts_num', type=int, default=0,
                        help="number of extra inferences to perform to find the mean timestamps")
    parser.add_argument('--ts_noise', type=float, default=0.1,
                        help="percentage of noise to add to audio_features to perform inferences for [ts_num]")

    parser.add_argument('--suppress_silence', type=str2bool, default=True,
                        help="whether to suppress timestamp where audio is silent at segment-level"
                             "and word-level if [suppress_word_ts]=True")
    parser.add_argument('--suppress_word_ts', type=str2bool, default=True,
                        help="whether to suppress timestamps where audio is silent at word-level; "
                             "ignored if [suppress_silence]=False")

    parser.add_argument('--suppress_ts_tokens', type=str2bool, default=False,
                        help="whether to use silence mask to suppress silent timestamp tokens during inference; "
                             "increases word accuracy in some cases, but tends reduce 'verbatimness' of the transcript"
                             "ignored if [suppress_silence]=False")

    parser.add_argument("--q_levels", type=int, default=20,
                        help="quantization levels for generating timestamp suppression mask; "
                             "acts as a threshold to marking sound as silent;"
                             "fewer levels will increase the threshold of volume at which to mark a sound as silent")

    parser.add_argument("--k_size", type=int, default=5,
                        help="Kernel size for average pooling waveform to generate suppression mask; "
                             "recommend 5 or 3; higher sizes will reduce detection of silence")

    parser.add_argument('--time_scale', type=float,
                        help="factor for scaling audio duration for inference;"
                             "greater than 1.0 'slows down' the audio; "
                             "less than 1.0 'speeds up' the audio; "
                             "1.0 is no scaling")

    parser.add_argument('--vad', type=str2bool, default=False,
                        help='whether to use Silero VAD to generate timestamp suppression mask; '
                             'Silero VAD requires PyTorch 1.12.0+;'
                             'Official repo: https://github.com/snakers4/silero-vad')
    parser.add_argument('--vad_threshold', type=float, default=0.35,
                        help='threshold for detecting speech with Silero VAD. (Default: 0.35); '
                             'low threshold reduces false positives for silence detection')
    parser.add_argument('--vad_onnx', type=str2bool, default=False,
                        help='whether to use ONNX for Silero VAD')

    parser.add_argument('--min_word_dur', type=float, default=0.1,
                        help="shortest duration each word is allowed to reach for silence suppression")
    parser.add_argument('--nonspeech_error', type=float, default=0.3,
                        help="relative error of non-speech sections that appear in between a word for "
                             "silence suppression.")

    parser.add_argument('--max_chars', type=int,
                        help="maximum number of character allowed in each segment")
    parser.add_argument('--max_words', type=int,
                        help="maximum number of words allowed in each segment")

    parser.add_argument('--demucs', type=str2bool, default=False,
                        help='whether to reprocess the audio track with Demucs to isolate vocals/remove noise; '
                             'Demucs official repo: https://github.com/facebookresearch/demucs')
    parser.add_argument('--demucs_output', action="extend", nargs="+", type=str,
                        help='path(s) to save the vocals isolated by Demucs as WAV file(s); '
                             'ignored if [demucs]=False')
    parser.add_argument('--only_voice_freq', '-ovf', action='store_true',
                        help='whether to only use sound between 200 - 5000 Hz, where majority of human speech are.')

    parser.add_argument('--strip', type=str2bool, default=True,
                        help="whether to remove spaces before and after text on each segment for output")

    parser.add_argument('--tag', type=str, action="extend", nargs="+",
                        help="a pair tags used to change the properties a word at its predicted time"
                             "SRT Default: '<font color=\"#00ff00\">', '</font>'"
                             "VTT Default: '<u>', '</u>'"
                             "ASS Default: '{\\1c&HFF00&}', '{\\r}'")
    parser.add_argument('--segment_level', type=str2bool, default=True,
                        help="whether to use segment-level timestamps in output")
    parser.add_argument('--word_level', type=str2bool, default=True,
                        help="whether to use word-level timestamps in output")

    parser.add_argument('--reverse_text', type=str2bool, default=False,
                        help="whether to reverse the order of words for each segment of text output")

    # ass output
    parser.add_argument('--font', type=str, default='Arial',
                        help="word font for ASS output(s)")
    parser.add_argument('--font_size', type=int, default=48,
                        help="word font size for ASS output(s)")
    parser.add_argument('--karaoke', type=str2bool, default=False,
                        help="whether to use progressive filling highlights for karaoke effect (only for ASS outputs)")

    parser.add_argument("--temperature", type=float, default=0,
                        help="temperature to use for sampling")
    parser.add_argument("--best_of", type=optional_int,
                        help="number of candidates when sampling with non-zero temperature")
    parser.add_argument("--beam_size", type=optional_int,
                        help="number of beams in beam search, only applicable when temperature is zero")
    parser.add_argument("--patience", type=float, default=None,
                        help="optional patience value to use in beam decoding, "
                             "as in https://arxiv.org/abs/2204.05424, "
                             "the default (1.0) is equivalent to conventional beam search")
    parser.add_argument("--length_penalty", type=float, default=None,
                        help="optional token length penalty coefficient (alpha) "
                             "as in https://arxiv.org/abs/1609.08144, uses simple length normalization by default")

    parser.add_argument("--suppress_tokens", type=str, default="-1",
                        help="comma-separated list of token ids to suppress during sampling; "
                             "'-1' will suppress most special characters except common punctuations")
    parser.add_argument("--initial_prompt", type=str, default=None,
                        help="optional text to provide as a prompt for the first window.")
    parser.add_argument("--condition_on_previous_text", type=str2bool, default=True,
                        help="if True, provide the previous output of the model as a prompt for the next window; "
                             "disabling may make the text inconsistent across windows, "
                             "but the model becomes less prone to getting stuck in a failure loop")
    parser.add_argument("--fp16", type=str2bool, default=True,
                        help="whether to perform inference in fp16; True by default")

    parser.add_argument("--temperature_increment_on_fallback", type=optional_float, default=0.2,
                        help="temperature to increase when falling back when the decoding fails to meet either of "
                             "the thresholds below")
    parser.add_argument("--compression_ratio_threshold", type=optional_float, default=2.4,
                        help="if the gzip compression ratio is higher than this value, treat the decoding as failed")
    parser.add_argument("--logprob_threshold", type=optional_float, default=-1.0,
                        help="if the average log probability is lower than this value, treat the decoding as failed")
    parser.add_argument("--no_speech_threshold", type=optional_float, default=0.6,
                        help="if the probability of the <|nospeech|> token is higher than this value AND the decoding "
                             "has failed due to `logprob_threshold`, consider the segment as silence")
    parser.add_argument("--threads", type=optional_int, default=0,
                        help="number of threads used by torch for CPU inference; "
                             "supercedes MKL_NUM_THREADS/OMP_NUM_THREADS")

    parser.add_argument('--mel_first', action='store_true',
                        help='process entire audio track into log-Mel spectrogram first instead in chunks')

    parser.add_argument('--only_ffmpeg', action='store_true',
                        help='whether to use only FFmpeg (and not yt-dlp) for URls')

    parser.add_argument('--overwrite', '-y', action='store_true',
                        help='overwrite all output files')

    parser.add_argument('--debug', action='store_true',
                        help='print all input/output pair(s) and all arguments used for transcribing/translating')

    parser.add_argument('--transcribe_method', '-tm', type=str, default='transcribe',
                        choices=('transcribe', 'transcribe_minimal'))

    parser.add_argument('--align', '-a', action="extend", nargs='+', type=str,
                        help='path(s) to TXT file(s) or JSON previous result(s)')

    parser.add_argument('--refine', '-r', action='store_true',
                        help='Refine timestamps to increase precision of timestamps')

    parser.add_argument('--locate', '-lc', action="extend", nargs='+', type=str,
                        help='words to locate in the audio(s); skips transcription and output')

    parser.add_argument('--refine_option', '-ro', action="extend", nargs='+', type=str,
                        help='Extra option(s) to use for refining timestamps; Replace True/False with 1/0; '
                             'E.g. --refine_option "steps=sese" --refine_options "rel_prob_decrease=0.05"')
    parser.add_argument('--demucs_option', '-do', action="extend", nargs='+', type=str,
                        help='Extra option(s) to use for demucs; Replace True/False with 1/0; '
                             'E.g. --demucs_option "shifts=3" --demucs_options "overlap=0.5"')
    parser.add_argument('--model_option', '-mo', action="extend", nargs='+', type=str,
                        help='Extra option(s) to use for loading model; Replace True/False with 1/0; '
                             'E.g. --model_option "download_root=./downloads"')
    parser.add_argument('--transcribe_option', '-to', action="extend", nargs='+', type=str,
                        help='Extra option(s) to use for transcribing/alignment/locating; Replace True/False with 1/0; '
                             'E.g. --transcribe_option "ignore_compatibility=1"')
    parser.add_argument('--save_option', '-so', action="extend", nargs='+', type=str,
                        help='Extra option(s) to use for text outputs; Replace True/False with 1/0; '
                             'E.g. --save_option "highlight_color=ffffff"')

    parser.add_argument('--faster_whisper', '-fw', action='store_true',
                        help='whether to use faster-whisper (https://github.com/guillaumekln/faster-whisper); '
                             'note: some features may not be available')

    args = parser.parse_args().__dict__
    debug = args.pop('debug')
    if not args['language'] and (args['align'] or args['locate']):
        raise ValueError('langauge is required for --align / --locate')

    is_faster_whisper = args.pop('faster_whisper')
    model_name: str = args.pop("model")
    model_dir: str = args.pop("model_dir")
    inputs: List[Union[str, torch.Tensor]] = args.pop("inputs")
    outputs: List[str] = args.pop("output")
    output_dir: str = args.pop("output_dir")
    output_format = args.pop("output_format")
    overwrite: bool = args.pop("overwrite")
    use_demucs = args['demucs'] or False
    demucs_outputs: List[Optional[str]] = args.pop("demucs_output")
    args['demucs_options'] = update_options_with_args('demucs_option', pop=True)
    regroup = args.pop('regroup')
    max_chars = args.pop('max_chars')
    max_words = args.pop('max_words')
    args['verbose'] = False if args['verbose'] == 1 else (True if args['verbose'] == 2 else None)
    show_curr_task = args['verbose'] is not None
    strings_to_locate = args.pop('locate')
    if dq := args.pop('dynamic_quantization', False):
        args['device'] = 'cpu'
    if args['reverse_text']:
        args['reverse_text'] = (args.get('prepend_punctuations'), args.get('append_punctuations'))

    if regroup:
        try:
            regroup = str2bool(regroup)
        except ValueError:
            pass
    curr_output_formats: List[str] = output_format.split(',') if output_format else []
    unsupported_formats = list(set(map(str.lower, curr_output_formats)) - OUTPUT_FORMATS)
    if outputs:
        unsupported_formats.extend(list(set(splitext(o)[-1].lower().strip('.') for o in outputs) - OUTPUT_FORMATS))
    if len(unsupported_formats) != 0:
        raise NotImplementedError(f'{unsupported_formats} are not supported. Supported formats: {OUTPUT_FORMATS}.')

    has_demucs_output = bool(demucs_outputs)
    if use_demucs and has_demucs_output and len(demucs_outputs) != len(inputs):
        raise NotImplementedError(f'[demucs_output] and [inputs] do not match in count. '
                                  f'Got {len(demucs_outputs)} and {len(inputs)}')

    if tag := args.get('tag'):
        assert tag == ['-1'] or len(tag) == 2, f'[tag] must be a pair of str but got {tag}'

    def make_parent(filepath: str):
        if parent := split(filepath)[0]:
            os.makedirs(parent, exist_ok=True)

    def is_json(file: str):
        return file.endswith(".json")

    def call_method_with_options(method, options: dict, include_first: bool = True):
        def val_to_str(val) -> str:
            if isinstance(val, (np.ndarray, torch.Tensor)):
                return f'{val.__class__}(shape:{list(val.shape)})'
            elif isinstance(val, str):
                return f'"{val}"'
            elif isinstance(val, bytes):
                return f'{type(val)}(len:{len(val)})'
            elif isinstance(val, torch.nn.Module):
                return str(type(val))
            return str(val)

        params = tuple(get_func_parameters(method))
        if debug:
            temp_options = {k: options.pop(k) for k in params if k in options}
            temp_options.update(options)
            options = temp_options
            options_str = ',\n'.join(
                f'    {k}={val_to_str(v)}'
                for k, v in options.items()
                if include_first or k != params[0]
            )
            if options_str:
                options_str = f'\n{options_str}\n'
            else:
                print(options, params)
            print(f'{method.__qualname__}({options_str})')
        return method(**options)

    if alignments := args['align']:
        if unsupported_align_fmts := \
                [_ext for p in alignments if (_ext := splitext(p)[-1].lower()) not in ('.json', '.txt')]:
            raise NotImplementedError(
                f'Unsupported format(s) for alignment: {unsupported_align_fmts}'
            )
        if len(inputs) != len(alignments):
            raise NotImplementedError(
                f'Got {len(inputs)} audio file(s) but specified {len(alignments)} file(s) to align.'
            )
    else:
        alignments = ['']*len(inputs)

    def finalize_outputs(input_file: str, _output: str = None, _alignment: str = None) -> List[str]:
        _curr_output_formats = curr_output_formats.copy()
        basename, ext = splitext(_output or input_file)
        ext = ext[1:]
        if _output:
            if ext.lower() in OUTPUT_FORMATS:
                _curr_output_formats.append(ext)
            else:
                basename = _output
        if not _curr_output_formats:
            _curr_output_formats = ["srt" if is_json(input_file) or is_json(_alignment) else "json"]
        _outputs = [f'{basename}.{ext}' for ext in set(_curr_output_formats)]
        if output_dir:
            _outputs = [join(output_dir, o) for o in _outputs]

        return _outputs

    if outputs:
        if len(outputs) != len(inputs):
            raise NotImplementedError(f'Got {len(inputs)} audio file(s) but specified {len(outputs)} output file(s).')
        final_outputs = [finalize_outputs(i, o, a) for i, o, a in zip(inputs, outputs, alignments)]
    else:
        if not output_dir:
            output_dir = '.'
        final_outputs = [finalize_outputs(i, _alignment=a) for i, a in zip(inputs, alignments)]

    if not overwrite:

        def cancel_overwrite():
            resp = input(f'{path} already exist, overwrite (y/n)? ').lower()
            if resp in ('y', 'n'):
                return resp == 'n'
            print(f'Expected "y" or "n", but got {resp}.')
            return True

        for paths in final_outputs:
            for path in paths:
                if isfile(path) and cancel_overwrite():
                    return

    if model_name.endswith(".en") and args["language"] not in {"en", "English"}:
        if args["language"] is not None:
            warnings.warn(f"{model_name} is an English-only model but receipted "
                          f"'{args['language']}'; using English instead.")
        args["language"] = "en"

    temperature = args.pop("temperature")
    increment = args.pop("temperature_increment_on_fallback")
    if increment is not None:
        temperature = tuple(np.arange(temperature, 1.0 + 1e-6, increment))
    else:
        temperature = [temperature]

    args['temperature'] = temperature

    threads = args.pop("threads")
    if threads > 0:
        torch.set_num_threads(threads)

    if debug:
        print('Input(s)  ->  Outputs(s)')
        for i, (input_audio, output_paths, alignment) in enumerate(zip(inputs, final_outputs, alignments)):
            dm_output = f' {demucs_outputs[i]} ->' if demucs_outputs else ''
            alignment = f' + "{alignment}"' if alignment else ''
            print(f'"{input_audio}"{alignment}  ->{dm_output}  {output_paths}')
        print('')

    if show_curr_task:
        model_from_str = '' if model_dir is None else f' from {model_dir}'
        model_loading_str = f'{"Faster-Whisper" if is_faster_whisper else "Whisper"} {model_name} model {model_from_str}'
        print(f'Loading {model_loading_str}\r', end='\n' if debug else '')
    else:
        model_loading_str = ''

    alignments = args['align']
    model = None

    def _load_model():
        nonlocal model
        if model is None:
            model_options = dict(
                name=model_name,
                model_size_or_path=model_name,
                device=args.get('device'),
                download_root=model_dir,
                dq=dq,
            )
            load_model_func = load_faster_whisper if is_faster_whisper else load_model
            model_options = isolate_useful_options(model_options, load_model_func)
            update_options_with_args('model_option', model_options)
            model = call_method_with_options(load_model_func, model_options)
            if model_loading_str:
                print(f'Loaded {model_loading_str}  ')
        return model

    for i, (input_audio, output_paths) in enumerate(zip(inputs, final_outputs)):
        skip_output = False
        if isinstance(input_audio, str) and is_json(input_audio):
            result = WhisperResult(input_audio)
        else:
            model = _load_model()
            args['regroup'] = False
            args['audio'] = input_audio
            if has_demucs_output:
                args['demucs_output'] = demucs_outputs[i]
            transcribe_method = args.get('transcribe_method')
            text = None
            if alignments and (text := alignments[i]):
                if text.endswith('.json'):
                    text = WhisperResult(text)
                else:
                    with open(text, 'r', encoding='utf-8') as f:
                        text = f.read()
                args['text'] = text
                transcribe_method = 'align'
            if is_faster_whisper and transcribe_method == 'transcribe':
                transcribe_method = 'transcribe_stable'
            if strings_to_locate and (text := strings_to_locate[i]):
                args['text'] = text
                transcribe_method = 'locate'
                skip_output = args['verbose'] = True
            transcribe_method = getattr(model, transcribe_method)
            transcribe_options = isolate_useful_options(args, transcribe_method)
            if not text:
                decoding_options = (
                    isolate_useful_options(args, model.transcribe if is_faster_whisper else DecodingOptions)
                )
                if is_faster_whisper:
                    if decoding_options['suppress_tokens']:
                        decoding_options['suppress_tokens'] = (
                            list(map(int, decoding_options['suppress_tokens'].split(',')))
                        )
                    for k in list(decoding_options.keys()):
                        if decoding_options[k] is None:
                            del decoding_options[k]
                transcribe_options.update(decoding_options)
            update_options_with_args('transcribe_option', transcribe_options)
            result: WhisperResult = call_method_with_options(transcribe_method, transcribe_options)

        if skip_output:
            continue

        if args['refine']:
            model = _load_model()
            refine_options = isolate_useful_options(args, model.refine)
            refine_options['result'] = result
            update_options_with_args('refine_option', refine_options)
            call_method_with_options(model.refine, refine_options)

        if args.get('word_timestamps'):
            if regroup:
                result.regroup(regroup, verbose=args['verbose'] or debug)
            if max_chars or max_words:
                result.split_by_length(max_chars=max_chars, max_words=max_words)

        for path in output_paths:
            make_parent(path)
            save_method = getattr(result, OUTPUT_FORMATS_METHODS[splitext(path)[-1][1:]])
            args['filepath'] = path
            args['path'] = path
            save_options = isolate_useful_options(args, save_method)
            update_options_with_args('save_option', save_options)
            call_method_with_options(save_method, save_options)


if __name__ == '__main__':
    cli()