Dionyssos committed on
Commit 737edba · 1 Parent(s): a77742c

soundscapes

Files changed (3):
  1. app.py +496 -82
  2. audiocraft.py +724 -0
  3. requirements.txt +6 -4
app.py CHANGED
@@ -21,6 +21,11 @@ import nltk
21
  from num2words import num2words
22
  from num2word_greek.numbers2words import convert_numbers
23
  from audionar import VitsModel, VitsTokenizer
24
 
25
  nltk.download('punkt', download_dir='./')
26
  nltk.download('punkt_tab', download_dir='./')
@@ -443,97 +448,118 @@ language_names = ['Ancient greek',
443
 
444
 
445
  def audionar_tts(text=None,
446
- lang='romanian'):
447
 
448
  # https://huggingface.co/dkounadis/artificial-styletts2/blob/main/msinference.py
449
 
450
- lang = lang.lower()
451
-
452
- # https://huggingface.co/spaces/mms-meta/MMS
453
-
454
- if 'hun' in lang:
455
-
456
- lang_code = 'hun'
457
 
458
- elif any([i in lang for i in ['ser', 'bosn', 'herzegov', 'montenegr', 'macedon']]):
459
 
460
- # romani carpathian (has also Vlax) - cooler voice
461
- lang_code = 'rmc-script_latin'
462
 
463
- elif 'rom' in lang:
464
 
465
- lang_code = 'ron'
 
466
 
467
- elif 'ger' in lang or 'deu' in lang or 'allem' in lang:
468
 
469
- lang_code = 'deu'
470
 
471
- elif 'french' in lang:
 
472
 
473
- lang_code = 'fra'
 
 
474
 
475
- elif 'eng' in lang:
476
 
477
- lang_code = 'eng'
478
 
479
- elif 'ancient greek' in lang:
480
 
481
- lang_code = 'grc'
482
 
483
- else:
 
484
 
485
- lang_code = lang.split()[0].strip() # latin & future option
486
 
487
- # LATIN / GRC / CYRILLIC
 
488
 
489
- text = only_greek_or_only_latin(text, lang=lang_code) # assure gr-chars if lang=='grc' / latin if lang!='grc'
490
 
491
- # NUMERALS (^ in math expression found & substituted here before arriving to fix_vocals)
492
 
493
- text = transliterate_number(text, lang=lang_code)
494
 
495
- # PRONOUNC.
496
 
497
- text = fix_vocals(text, lang=lang_code)
 
 
498
 
499
- # VITS
500
-
501
- global cached_lang_code, cached_net_g, cached_tokenizer
502
-
503
- if 'cached_lang_code' not in globals() or cached_lang_code != lang_code:
504
- cached_lang_code = lang_code
505
- cached_net_g = VitsModel.from_pretrained(f'facebook/mms-tts-{lang_code}').eval().to(device)
506
- cached_tokenizer = VitsTokenizer.from_pretrained(f'facebook/mms-tts-{lang_code}')
507
-
508
- net_g = cached_net_g
509
- tokenizer = cached_tokenizer
510
-
511
- total_audio = []
512
 
513
- if not isinstance(text, list):
514
- text = textwrap.wrap(text, width=439)
515
-
516
- for _t in text:
517
- inputs = tokenizer(_t, return_tensors="pt")
518
- with torch.no_grad():
519
- x = net_g(input_ids=inputs.input_ids.to(device),
520
- attention_mask=inputs.attention_mask.to(device),
521
- lang_code=lang_code,
522
- )[0, :]
523
- total_audio.append(x)
524
 
525
- print(f'\n\n_______________________________ {_t} {x.shape=}')
526
 
527
- x = torch.cat(total_audio).cpu().numpy()
528
-
529
- tmp_file = f'_speech.wav'
530
-
531
- audiofile.write(tmp_file, x, 16000)
532
-
533
- return tmp_file
534
-
535
-
536
- # --
537
 
538
 
539
  device = 0 if torch.cuda.is_available() else "cpu"
@@ -838,7 +864,334 @@ def plot_expression(arousal, dominance, valence):
838
  # plt.show()
839
 
840
  # TTS
841
- VOICES = [f'wav/{vox}' for vox in os.listdir('wav')]
842
  _tts = StyleTTS2().to('cpu')
843
 
844
  def only_greek_or_only_latin(text, lang='grc'):
@@ -968,23 +1321,78 @@ def only_greek_or_only_latin(text, lang='grc'):
968
 
969
 
970
  def other_tts(text='Hallov worlds Far over the',
971
- ref_s='wav/af_ZA_google-nwu_0184.wav'):
972
 
973
- text = only_greek_or_only_latin(text, lang='eng')
974
 
975
- x = _tts.inference(text, ref_s=ref_s)[0:1, 0, :]
 
 
976
 
977
- x = torch.cat([.99 * x,
978
- .94 * x], 0).cpu().numpy() # Stereo
 
979
 
980
- # x /= np.abs(x).max() + 1e-7 ~ Volume normalisation @api.py:tts_multi_sentence() OR demo.py
 
 
981
 
982
- tmp_file = f'_speech.wav' # N x clients (cleanup vs tmp file / client)
 
983
 
984
- audiofile.write(tmp_file, x, 24000)
 
 
985
 
986
- return tmp_file
987
988
 
989
  def update_selected_voice(voice_filename):
990
  return 'wav/' + voice_filename + '.wav'
@@ -1035,11 +1443,19 @@ with gr.Blocks(theme='huggingface', css=css_buttons) as demo:
1035
  # Main input and output components
1036
  with gr.Row():
1037
  text_input = gr.Textbox(
1038
- label="Enter text for TTS:",
1039
  placeholder="Type your message here...",
1040
  lines=4,
1041
  value="Farover the misty mountains cold too dungeons deep and caverns old.",
1042
  )
1043
  generate_button = gr.Button("Generate Audio", variant="primary")
1044
 
1045
  output_audio = gr.Audio(label="TTS Output")
@@ -1066,7 +1482,7 @@ with gr.Blocks(theme='huggingface', css=css_buttons) as demo:
1066
 
1067
  generate_button.click(
1068
  fn=other_tts,
1069
- inputs=[text_input, selected_voice],
1070
  outputs=output_audio
1071
  )
1072
 
@@ -1108,11 +1524,9 @@ with gr.Blocks(theme='huggingface', css=css_buttons) as demo:
1108
  value='Η γρηγορη καφετι αλεπου πειδαει πανω απο τον τεμπελη σκυλο.',
1109
  label="Type text for TTS"
1110
  )
1111
- lang_dropdown = gr.Dropdown(
1112
- choices=language_names,
1113
- label="TTS language",
1114
- value="Ancient greek",
1115
- )
1116
 
1117
  # Create a button to trigger the TTS function
1118
  tts_button = gr.Button("Generate Audio")
@@ -1123,7 +1537,7 @@ with gr.Blocks(theme='huggingface', css=css_buttons) as demo:
1123
  # Link the button click event to the mms_tts function
1124
  tts_button.click(
1125
  fn=audionar_tts,
1126
- inputs=[text_input, lang_dropdown],
1127
  outputs=audio_output
1128
  )
1129
 
 
21
  from num2words import num2words
22
  from num2word_greek.numbers2words import convert_numbers
23
  from audionar import VitsModel, VitsTokenizer
24
+ from audiocraft import AudioGen
25
+
26
+
27
+
28
+ audiogen = AudioGen().eval().to('cpu')
29
 
30
  nltk.download('punkt', download_dir='./')
31
  nltk.download('punkt_tab', download_dir='./')
 
448
 
449
 
450
  def audionar_tts(text=None,
451
+ lang='romanian',
452
+ soundscape='',
453
+ cache_lim=24):
454
 
455
  # https://huggingface.co/dkounadis/artificial-styletts2/blob/main/msinference.py
456
 
457
 
458
+ lang_map = {
459
+ 'ancient greek': 'grc',
460
+ 'english': 'eng',
461
+ 'deutsch': 'deu',
462
+ 'french': 'fra',
463
+ 'hungarian': 'hun',
464
+ 'romanian': 'ron',
465
+ 'serbian (approx.)': 'rmc-script_latin',
466
+ }
467
+ lang_code = lang_map.get(lang.lower(), lang.lower().split()[0].strip())
468
 
469
+ global cached_lang_code, cached_net_g, cached_tokenizer
 
470
 
471
+ if 'cached_lang_code' not in globals() or cached_lang_code != lang_code:
472
+ cached_lang_code = lang_code
473
+ cached_net_g = VitsModel.from_pretrained(f'facebook/mms-tts-{lang_code}').eval()
474
+ cached_tokenizer = VitsTokenizer.from_pretrained(f'facebook/mms-tts-{lang_code}')
475
 
476
+ net_g = cached_net_g
477
+ tokenizer = cached_tokenizer
478
 
479
+ total_audio = []
480
+
481
+ final_audio = None
482
+ speech_audio = None
483
 
 
484
 
485
+ if text and text.strip():
486
+
487
 
488
+ text = only_greek_or_only_latin(text, lang=lang_code)
489
+ text = transliterate_number(text, lang=lang_code)
490
+ text = fix_vocals(text, lang=lang_code)
491
 
 
492
 
493
+ sentences = textwrap.wrap(text, width=439)
494
 
495
+ total_audio_parts = []
496
+ for sentence in sentences:
497
+ inputs = cached_tokenizer(sentence, return_tensors="pt")
498
+ with torch.no_grad():
499
+ audio_part = cached_net_g(
500
+ input_ids=inputs.input_ids.to(device),
501
+ attention_mask=inputs.attention_mask.to(device),
502
+ lang_code=lang_code,
503
+ )[0, :]
504
+ total_audio_parts.append(audio_part)
505
 
506
+ speech_audio = torch.cat(total_audio_parts).cpu().numpy()
507
 
508
+ # AudioGen
509
+ if soundscape and soundscape.strip():
510
 
 
511
 
512
+ speech_duration_secs = len(speech_audio) / 16000 if speech_audio is not None else 0
513
+ target_duration = max(speech_duration_secs + 0.74, 2.0)
514
 
 
515
 
516
+ background_audio = audiogen.generate(
517
+ soundscape,
518
+ duration=target_duration,
519
+ cache_lim=max(4, int(cache_lim)) # allow at least 10 A/R steps
520
+ ).numpy()
521
 
522
+ if speech_audio is not None:
523
 
524
+ len_speech = len(speech_audio)
525
+ len_background = len(background_audio)
526
+
527
+ if len_background > len_speech:
528
+ padding = np.zeros(len_background - len_speech,
529
+ dtype=np.float32)
530
+ speech_audio = np.concatenate([speech_audio, padding])
531
+ elif len_speech > len_background:
532
+ padding = np.zeros(len_speech - len_background,
533
+ dtype=np.float32)
534
+ background_audio = np.concatenate([background_audio, padding])
535
+
536
+
537
+ speech_audio_stereo = speech_audio[None, :]
538
+ background_audio_stereo = background_audio[None, :]
539
+
540
+
541
+ final_audio = np.concatenate([
542
+ 0.49 * speech_audio_stereo + 0.51 * background_audio_stereo,
543
+ 0.51 * background_audio_stereo + 0.49 * speech_audio_stereo
544
+ ], 0)
545
+ else:
546
+ final_audio = background_audio
547
+
548
+ # If no soundscape, use the speech audio as is.
549
+ elif speech_audio is not None:
550
+ final_audio = speech_audio
551
 
552
+ # If both inputs are empty, create a 2s silent audio file.
553
+ if final_audio is None:
554
+ final_audio = np.zeros(16000 * 2, dtype=np.float32)
555
 
556
+ wavfile = '_vits_.wav'
557
+ audiofile.write(wavfile, final_audio, 16000)
558
 
559
+ return wavfile
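
The `lang_map` dictionary above replaces the old chain of substring checks: the dropdown label is lower-cased and looked up, and any unknown label falls back to its first word, which is treated as an MMS language code and selects the `facebook/mms-tts-<code>` checkpoint. A standalone sketch of that lookup, not part of the commit (`mms_repo` is a name introduced here for illustration):

    lang_map = {
        'ancient greek': 'grc',
        'english': 'eng',
        'deutsch': 'deu',
        'french': 'fra',
        'hungarian': 'hun',
        'romanian': 'ron',
        'serbian (approx.)': 'rmc-script_latin',
    }

    def mms_repo(lang: str) -> str:
        # unknown labels fall back to their first word, used directly as the code
        code = lang_map.get(lang.lower(), lang.lower().split()[0].strip())
        return f'facebook/mms-tts-{code}'

    print(mms_repo('Ancient greek'))   # facebook/mms-tts-grc
    print(mms_repo('ron'))             # fallback path: facebook/mms-tts-ron
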
560
 
 
561
 
562
+ # -- EXPRESSION
563
 
564
 
565
  device = 0 if torch.cuda.is_available() else "cpu"
 
864
  # plt.show()
865
 
866
  # TTS
867
+ # VOICES = [f'wav/{vox}' for vox in os.listdir('wav')]
868
+ # add unidecode (to parse non-roman characters for the StyleTTS2
869
+ # # for the VITS it should better skip the unknown letters - dont use unidecode())
870
+ # at generation fill the state of "last tts"
871
+ # at record fill the state of "last record" and place in list of voice/langs for TTS
872
+ VOICES = ['jv_ID_google-gmu_04982.wav',
873
+ 'it_IT_mls_1595.wav',
874
+ 'en_US_vctk_p303.wav',
875
+ 'en_US_vctk_p306.wav',
876
+ 'it_IT_mls_8842.wav',
877
+ 'en_US_cmu_arctic_ksp.wav',
878
+ 'jv_ID_google-gmu_05970.wav',
879
+ 'en_US_vctk_p318.wav',
880
+ 'ha_NE_openbible.wav',
881
+ 'ne_NP_ne-google_0883.wav',
882
+ 'en_US_vctk_p280.wav',
883
+ 'bn_multi_1010.wav',
884
+ 'en_US_vctk_p259.wav',
885
+ 'it_IT_mls_844.wav',
886
+ 'en_US_vctk_p269.wav',
887
+ 'en_US_vctk_p285.wav',
888
+ 'de_DE_m-ailabs_angela_merkel.wav',
889
+ 'en_US_vctk_p316.wav',
890
+ 'en_US_vctk_p362.wav',
891
+ 'jv_ID_google-gmu_06207.wav',
892
+ 'tn_ZA_google-nwu_9061.wav',
893
+ 'fr_FR_tom.wav',
894
+ 'en_US_vctk_p233.wav',
895
+ 'it_IT_mls_4975.wav',
896
+ 'en_US_vctk_p236.wav',
897
+ 'bn_multi_01232.wav',
898
+ 'bn_multi_5958.wav',
899
+ 'it_IT_mls_9185.wav',
900
+ 'en_US_vctk_p248.wav',
901
+ 'en_US_vctk_p287.wav',
902
+ 'it_IT_mls_9772.wav',
903
+ 'te_IN_cmu-indic_sk.wav',
904
+ 'tn_ZA_google-nwu_8333.wav',
905
+ 'en_US_vctk_p260.wav',
906
+ 'en_US_vctk_p247.wav',
907
+ 'en_US_vctk_p329.wav',
908
+ 'en_US_cmu_arctic_fem.wav',
909
+ 'en_US_cmu_arctic_rms.wav',
910
+ 'en_US_vctk_p308.wav',
911
+ 'jv_ID_google-gmu_08736.wav',
912
+ 'en_US_vctk_p245.wav',
913
+ 'fr_FR_m-ailabs_nadine_eckert_boulet.wav',
914
+ 'jv_ID_google-gmu_03314.wav',
915
+ 'en_US_vctk_p239.wav',
916
+ 'jv_ID_google-gmu_05540.wav',
917
+ 'it_IT_mls_7440.wav',
918
+ 'en_US_vctk_p310.wav',
919
+ 'en_US_vctk_p237.wav',
920
+ 'en_US_hifi-tts_92.wav',
921
+ 'en_US_cmu_arctic_aew.wav',
922
+ 'ne_NP_ne-google_2099.wav',
923
+ 'en_US_vctk_p226.wav',
924
+ 'af_ZA_google-nwu_1919.wav',
925
+ 'jv_ID_google-gmu_03727.wav',
926
+ 'en_US_vctk_p317.wav',
927
+ 'tn_ZA_google-nwu_0378.wav',
928
+ 'nl_pmk.wav',
929
+ 'en_US_vctk_p286.wav',
930
+ 'tn_ZA_google-nwu_3342.wav',
931
+ # 'en_US_vctk_p343.wav',
932
+ 'de_DE_m-ailabs_ramona_deininger.wav',
933
+ 'jv_ID_google-gmu_03424.wav',
934
+ 'en_US_vctk_p341.wav',
935
+ 'jv_ID_google-gmu_03187.wav',
936
+ 'ne_NP_ne-google_3960.wav',
937
+ 'jv_ID_google-gmu_06080.wav',
938
+ 'ne_NP_ne-google_3997.wav',
939
+ # 'en_US_vctk_p267.wav',
940
+ 'en_US_vctk_p240.wav',
941
+ 'ne_NP_ne-google_5687.wav',
942
+ 'ne_NP_ne-google_9407.wav',
943
+ 'jv_ID_google-gmu_05667.wav',
944
+ 'jv_ID_google-gmu_01519.wav',
945
+ 'ne_NP_ne-google_7957.wav',
946
+ 'it_IT_mls_4705.wav',
947
+ 'ne_NP_ne-google_6329.wav',
948
+ 'it_IT_mls_1725.wav',
949
+ 'tn_ZA_google-nwu_8914.wav',
950
+ 'en_US_ljspeech.wav',
951
+ 'tn_ZA_google-nwu_4850.wav',
952
+ 'en_US_vctk_p238.wav',
953
+ 'en_US_vctk_p302.wav',
954
+ 'jv_ID_google-gmu_08178.wav',
955
+ 'en_US_vctk_p313.wav',
956
+ 'af_ZA_google-nwu_2418.wav',
957
+ 'bn_multi_00737.wav',
958
+ 'en_US_vctk_p275.wav', # y
959
+ 'af_ZA_google-nwu_0184.wav',
960
+ 'jv_ID_google-gmu_07638.wav',
961
+ 'ne_NP_ne-google_6587.wav',
962
+ 'ne_NP_ne-google_0258.wav',
963
+ 'en_US_vctk_p232.wav',
964
+ 'en_US_vctk_p336.wav',
965
+ 'jv_ID_google-gmu_09039.wav',
966
+ 'en_US_vctk_p312.wav',
967
+ 'af_ZA_google-nwu_8148.wav',
968
+ 'en_US_vctk_p326.wav',
969
+ 'en_US_vctk_p264.wav',
970
+ 'en_US_vctk_p295.wav',
971
+ # 'en_US_vctk_p298.wav',
972
+ 'es_ES_m-ailabs_victor_villarraza.wav',
973
+ 'pl_PL_m-ailabs_nina_brown.wav',
974
+ 'tn_ZA_google-nwu_9365.wav',
975
+ 'en_US_vctk_p294.wav',
976
+ 'jv_ID_google-gmu_00658.wav',
977
+ 'jv_ID_google-gmu_08305.wav',
978
+ 'en_US_vctk_p330.wav',
979
+ 'gu_IN_cmu-indic_cmu_indic_guj_dp.wav',
980
+ 'jv_ID_google-gmu_05219.wav',
981
+ 'en_US_vctk_p284.wav',
982
+ 'de_DE_m-ailabs_eva_k.wav',
983
+ # 'bn_multi_00779.wav',
984
+ 'en_UK_apope.wav',
985
+ 'en_US_vctk_p345.wav',
986
+ 'it_IT_mls_6744.wav',
987
+ 'en_US_vctk_p347.wav',
988
+ 'en_US_m-ailabs_mary_ann.wav',
989
+ 'en_US_m-ailabs_elliot_miller.wav',
990
+ 'en_US_vctk_p279.wav',
991
+ 'ru_RU_multi_nikolaev.wav',
992
+ 'bn_multi_4811.wav',
993
+ 'tn_ZA_google-nwu_7693.wav',
994
+ 'bn_multi_01701.wav',
995
+ 'en_US_vctk_p262.wav',
996
+ # 'en_US_vctk_p266.wav',
997
+ 'en_US_vctk_p243.wav',
998
+ 'en_US_vctk_p297.wav',
999
+ 'en_US_vctk_p278.wav',
1000
+ 'jv_ID_google-gmu_02059.wav',
1001
+ 'en_US_vctk_p231.wav',
1002
+ 'te_IN_cmu-indic_kpn.wav',
1003
+ 'en_US_vctk_p250.wav',
1004
+ 'it_IT_mls_4974.wav',
1005
+ 'en_US_cmu_arctic_awbrms.wav',
1006
+ # 'en_US_vctk_p263.wav',
1007
+ 'nl_femal.wav',
1008
+ 'tn_ZA_google-nwu_6116.wav',
1009
+ 'jv_ID_google-gmu_06383.wav',
1010
+ 'en_US_vctk_p225.wav',
1011
+ 'en_US_vctk_p228.wav',
1012
+ 'it_IT_mls_277.wav',
1013
+ 'tn_ZA_google-nwu_7866.wav',
1014
+ 'en_US_vctk_p300.wav',
1015
+ 'ne_NP_ne-google_0649.wav',
1016
+ 'es_ES_carlfm.wav',
1017
+ 'jv_ID_google-gmu_06510.wav',
1018
+ 'de_DE_m-ailabs_rebecca_braunert_plunkett.wav',
1019
+ 'en_US_vctk_p340.wav',
1020
+ 'en_US_cmu_arctic_gka.wav',
1021
+ 'ne_NP_ne-google_2027.wav',
1022
+ 'jv_ID_google-gmu_09724.wav',
1023
+ 'en_US_vctk_p361.wav',
1024
+ 'ne_NP_ne-google_6834.wav',
1025
+ 'jv_ID_google-gmu_02326.wav',
1026
+ 'fr_FR_m-ailabs_zeckou.wav',
1027
+ 'tn_ZA_google-nwu_1932.wav',
1028
+ # 'female-20-happy.wav',
1029
+ 'tn_ZA_google-nwu_1483.wav',
1030
+ 'de_DE_thorsten-emotion_amused.wav',
1031
+ 'ru_RU_multi_minaev.wav',
1032
+ 'sw_lanfrica.wav',
1033
+ 'en_US_vctk_p271.wav',
1034
+ 'tn_ZA_google-nwu_0441.wav',
1035
+ 'it_IT_mls_6001.wav',
1036
+ 'en_US_vctk_p305.wav',
1037
+ 'it_IT_mls_8828.wav',
1038
+ 'jv_ID_google-gmu_08002.wav',
1039
+ 'it_IT_mls_2033.wav',
1040
+ 'tn_ZA_google-nwu_3629.wav',
1041
+ 'it_IT_mls_6348.wav',
1042
+ 'en_US_cmu_arctic_axb.wav',
1043
+ 'it_IT_mls_8181.wav',
1044
+ 'en_US_vctk_p230.wav',
1045
+ 'af_ZA_google-nwu_7214.wav',
1046
+ 'nl_nathalie.wav',
1047
+ 'it_IT_mls_8207.wav',
1048
+ 'ko_KO_kss.wav',
1049
+ 'af_ZA_google-nwu_6590.wav',
1050
+ 'jv_ID_google-gmu_00264.wav',
1051
+ 'tn_ZA_google-nwu_6234.wav',
1052
+ 'jv_ID_google-gmu_05522.wav',
1053
+ 'en_US_cmu_arctic_lnh.wav',
1054
+ 'en_US_vctk_p272.wav',
1055
+ 'en_US_cmu_arctic_slp.wav',
1056
+ 'en_US_vctk_p299.wav',
1057
+ 'en_US_hifi-tts_9017.wav',
1058
+ 'it_IT_mls_4998.wav',
1059
+ 'it_IT_mls_6299.wav',
1060
+ 'en_US_cmu_arctic_rxr.wav',
1061
+ 'female-46-neutral.wav',
1062
+ 'jv_ID_google-gmu_01392.wav',
1063
+ 'tn_ZA_google-nwu_8512.wav',
1064
+ 'en_US_vctk_p244.wav',
1065
+ # 'bn_multi_3108.wav',
1066
+ # 'it_IT_mls_7405.wav',
1067
+ # 'bn_multi_3713.wav',
1068
+ # 'yo_openbible.wav',
1069
+ # 'jv_ID_google-gmu_01932.wav',
1070
+ 'en_US_vctk_p270.wav',
1071
+ 'tn_ZA_google-nwu_6459.wav',
1072
+ 'bn_multi_4046.wav',
1073
+ 'en_US_vctk_p288.wav',
1074
+ 'en_US_vctk_p251.wav',
1075
+ 'es_ES_m-ailabs_tux.wav',
1076
+ 'tn_ZA_google-nwu_6206.wav',
1077
+ 'bn_multi_9169.wav',
1078
+ # 'en_US_vctk_p293.wav',
1079
+ # 'en_US_vctk_p255.wav',
1080
+ 'af_ZA_google-nwu_8963.wav',
1081
+ # 'en_US_vctk_p265.wav',
1082
+ 'gu_IN_cmu-indic_cmu_indic_guj_ad.wav',
1083
+ 'jv_ID_google-gmu_07335.wav',
1084
+ 'en_US_vctk_p323.wav',
1085
+ 'en_US_vctk_p281.wav',
1086
+ 'en_US_cmu_arctic_bdl.wav',
1087
+ 'en_US_m-ailabs_judy_bieber.wav',
1088
+ 'it_IT_mls_10446.wav',
1089
+ 'en_US_vctk_p261.wav',
1090
+ 'en_US_vctk_p292.wav',
1091
+ 'te_IN_cmu-indic_ss.wav',
1092
+ 'en_US_vctk_p311.wav',
1093
+ 'it_IT_mls_12428.wav',
1094
+ 'en_US_cmu_arctic_aup.wav',
1095
+ 'jv_ID_google-gmu_04679.wav',
1096
+ 'it_IT_mls_4971.wav',
1097
+ 'en_US_cmu_arctic_ljm.wav',
1098
+ 'fa_haaniye.wav',
1099
+ 'en_US_vctk_p339.wav',
1100
+ 'tn_ZA_google-nwu_7896.wav',
1101
+ 'en_US_vctk_p253.wav',
1102
+ 'it_IT_mls_5421.wav',
1103
+ # 'ne_NP_ne-google_0546.wav',
1104
+ 'vi_VN_vais1000.wav',
1105
+ 'en_US_vctk_p229.wav',
1106
+ 'en_US_vctk_p254.wav',
1107
+ 'en_US_vctk_p258.wav',
1108
+ 'it_IT_mls_7936.wav',
1109
+ 'en_US_vctk_p301.wav',
1110
+ 'tn_ZA_google-nwu_0045.wav',
1111
+ 'it_IT_mls_659.wav',
1112
+ 'tn_ZA_google-nwu_7674.wav',
1113
+ 'it_IT_mls_12804.wav',
1114
+ 'el_GR_rapunzelina.wav',
1115
+ 'en_US_hifi-tts_6097.wav',
1116
+ 'en_US_vctk_p257.wav',
1117
+ 'jv_ID_google-gmu_07875.wav',
1118
+ 'it_IT_mls_1157.wav',
1119
+ 'it_IT_mls_643.wav',
1120
+ 'en_US_vctk_p304.wav',
1121
+ 'ru_RU_multi_hajdurova.wav',
1122
+ 'it_IT_mls_8461.wav',
1123
+ 'bn_multi_3958.wav',
1124
+ 'it_IT_mls_1989.wav',
1125
+ 'en_US_vctk_p249.wav',
1126
+ # 'bn_multi_0834.wav',
1127
+ 'en_US_vctk_p307.wav',
1128
+ 'es_ES_m-ailabs_karen_savage.wav',
1129
+ 'fr_FR_m-ailabs_bernard.wav',
1130
+ 'en_US_vctk_p252.wav',
1131
+ 'en_US_cmu_arctic_jmk.wav',
1132
+ 'en_US_vctk_p333.wav',
1133
+ 'tn_ZA_google-nwu_4506.wav',
1134
+ 'ne_NP_ne-google_0283.wav',
1135
+ 'de_DE_m-ailabs_karlsson.wav',
1136
+ 'en_US_cmu_arctic_awb.wav',
1137
+ 'en_US_vctk_p246.wav',
1138
+ 'en_US_cmu_arctic_clb.wav',
1139
+ 'en_US_vctk_p364.wav',
1140
+ 'nl_flemishguy.wav',
1141
+ 'en_US_vctk_p276.wav', # y
1142
+ # 'en_US_vctk_p274.wav',
1143
+ 'fr_FR_m-ailabs_gilles_g_le_blanc.wav',
1144
+ 'it_IT_mls_7444.wav',
1145
+ 'style_o22050.wav',
1146
+ 'en_US_vctk_s5.wav',
1147
+ 'en_US_vctk_p268.wav',
1148
+ 'it_IT_mls_6807.wav',
1149
+ 'it_IT_mls_2019.wav',
1150
+ 'male-60-angry.wav',
1151
+ 'af_ZA_google-nwu_8924.wav',
1152
+ 'en_US_vctk_p374.wav',
1153
+ 'en_US_vctk_p363.wav',
1154
+ 'it_IT_mls_644.wav',
1155
+ 'ne_NP_ne-google_3614.wav',
1156
+ 'en_US_vctk_p241.wav',
1157
+ 'ne_NP_ne-google_3154.wav',
1158
+ 'en_US_vctk_p234.wav',
1159
+ 'it_IT_mls_8384.wav',
1160
+ 'fr_FR_m-ailabs_ezwa.wav',
1161
+ 'it_IT_mls_5010.wav',
1162
+ 'en_US_vctk_p351.wav',
1163
+ 'en_US_cmu_arctic_eey.wav',
1164
+ 'jv_ID_google-gmu_04285.wav',
1165
+ 'jv_ID_google-gmu_06941.wav',
1166
+ 'hu_HU_diana-majlinger.wav',
1167
+ 'tn_ZA_google-nwu_2839.wav',
1168
+ 'bn_multi_03042.wav',
1169
+ 'tn_ZA_google-nwu_5628.wav',
1170
+ 'it_IT_mls_4649.wav',
1171
+ 'af_ZA_google-nwu_7130.wav',
1172
+ 'en_US_cmu_arctic_slt.wav',
1173
+ 'jv_ID_google-gmu_04175.wav',
1174
+ 'gu_IN_cmu-indic_cmu_indic_guj_kt.wav',
1175
+ 'jv_ID_google-gmu_00027.wav',
1176
+ 'jv_ID_google-gmu_02884.wav',
1177
+ 'en_US_vctk_p360.wav',
1178
+ 'en_US_vctk_p334.wav',
1179
+ 'male-27-sad.wav',
1180
+ 'tn_ZA_google-nwu_1498.wav',
1181
+ 'fi_FI_harri-tapani-ylilammi.wav',
1182
+ 'bn_multi_rm.wav',
1183
+ 'ne_NP_ne-google_2139.wav',
1184
+ 'pl_PL_m-ailabs_piotr_nater.wav',
1185
+ 'fr_FR_siwis.wav',
1186
+ 'nl_bart-de-leeuw.wav',
1187
+ 'jv_ID_google-gmu_04715.wav',
1188
+ 'en_US_vctk_p283.wav',
1189
+ 'en_US_vctk_p314.wav',
1190
+ 'en_US_vctk_p335.wav',
1191
+ 'jv_ID_google-gmu_07765.wav',
1192
+ 'en_US_vctk_p273.wav'
1193
+ ]
1194
+
1195
  _tts = StyleTTS2().to('cpu')
1196
 
1197
  def only_greek_or_only_latin(text, lang='grc'):
 
1321
 
1322
 
1323
  def other_tts(text='Hallov worlds Far over the',
1324
+ ref_s='wav/af_ZA_google-nwu_0184.wav',
1325
+ soundscape='birds fomig',
1326
+ cache_lim=64):
1327
+
1328
+ total_audio = []
1329
+
1330
+ final_audio = None
1331
+ speech_audio = None
1332
 
 
1333
 
1334
+ if text and text.strip():
1335
+
1336
+ text = only_greek_or_only_latin(text, lang='eng')
1337
 
1338
+ speech_audio = _tts.inference(text, ref_s=ref_s)[0, 0, :].numpy() # 24 kHz
1339
+
1340
+ if speech_audio.shape[0] > 10:
1341
 
1342
+ speech_audio = audresample.resample(signal=speech_audio.astype(np.float32),
1343
+ original_rate=24000,
1344
+ target_rate=16000)[0, :] # 16 kHz
1345
 
1346
+ # AudioGen
1347
+ if soundscape and soundscape.strip():
1348
 
1349
+
1350
+ speech_duration_secs = len(speech_audio) / 16000 if speech_audio is not None else 0
1351
+ target_duration = max(speech_duration_secs + 0.74, 2.0)
1352
 
 
1353
 
1354
+ background_audio = audiogen.generate(
1355
+ soundscape,
1356
+ duration=target_duration,
1357
+ cache_lim=max(4, int(cache_lim)) # allow at least 10 A/R steps
1358
+ ).numpy()
1359
+
1360
+ if speech_audio is not None:
1361
+
1362
+ len_speech = len(speech_audio)
1363
+ len_background = len(background_audio)
1364
+
1365
+ if len_background > len_speech:
1366
+ padding = np.zeros(len_background - len_speech,
1367
+ dtype=np.float32)
1368
+ speech_audio = np.concatenate([speech_audio, padding])
1369
+ elif len_speech > len_background:
1370
+ padding = np.zeros(len_speech - len_background,
1371
+ dtype=np.float32)
1372
+ background_audio = np.concatenate([background_audio, padding])
1373
+
1374
+ # Convert to 2D arrays for stereo blending
1375
+ speech_audio_stereo = speech_audio[None, :]
1376
+ background_audio_stereo = background_audio[None, :]
1377
+
1378
+
1379
+ final_audio = np.concatenate([
1380
+ 0.49 * speech_audio_stereo + 0.51 * background_audio_stereo,
1381
+ 0.51 * background_audio_stereo + 0.49 * speech_audio_stereo
1382
+ ],0)
1383
+ else:
1384
+ final_audio = background_audio
1385
+
1386
+ elif speech_audio is not None:
1387
+ final_audio = speech_audio
1388
+
1389
+ # If both inputs are empty, create a 2s silent audio file.
1390
+ if final_audio is None:
1391
+ final_audio = np.zeros(16000 * 2, dtype=np.float32)
1392
+ print('\n=============F I N A L\n', final_audio.shape, final_audio.dtype, final_audio.min(), np.isnan(final_audio).sum())
1393
+ wavfile = '_audionar_.wav'
1394
+ audiofile.write(wavfile, final_audio, 16000)
1395
+ return wavfile
1396
 
1397
  def update_selected_voice(voice_filename):
1398
  return 'wav/' + voice_filename + '.wav'
 
1443
  # Main input and output components
1444
  with gr.Row():
1445
  text_input = gr.Textbox(
1446
+ label="TYpe text for TTS:",
1447
  placeholder="Type your message here...",
1448
  lines=4,
1449
  value="Farover the misty mountains cold too dungeons deep and caverns old.",
1450
  )
1451
+ soundscape_input = gr.Textbox(lines=1,
1452
+ value="frogs",
1453
+ label="AudioGen Txt"
1454
+ )
1455
+ kv_input = gr.Number(
1456
+ label="kv Period",
1457
+ value=24,
1458
+ )
1459
  generate_button = gr.Button("Generate Audio", variant="primary")
1460
 
1461
  output_audio = gr.Audio(label="TTS Output")
 
1482
 
1483
  generate_button.click(
1484
  fn=other_tts,
1485
+ inputs=[text_input, selected_voice, soundscape_input, kv_input],
1486
  outputs=output_audio
1487
  )
1488
 
 
1524
  value='Η γρηγορη καφετι αλεπου πειδαει πανω απο τον τεμπελη σκυλο.',
1525
  label="Type text for TTS"
1526
  )
1527
+ lang_dropdown = gr.Dropdown(choices=language_names, label="TTS language", value="Ancient greek")
1528
+ soundscape_input = gr.Textbox(lines=1, value="dogs barg", label="AudioGen Txt")
1529
+ kv_input = gr.Number(label="kv Period", value=70)
 
 
1530
 
1531
  # Create a button to trigger the TTS function
1532
  tts_button = gr.Button("Generate Audio")
 
1537
  # Link the button click event to the mms_tts function
1538
  tts_button.click(
1539
  fn=audionar_tts,
1540
+ inputs=[text_input, lang_dropdown, soundscape_input, kv_input],
1541
  outputs=audio_output
1542
  )
1543
 
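Both updated handlers, `audionar_tts()` and `other_tts()`, now end with the same mixing step: the synthesized speech and the AudioGen background are zero-padded to a common length and blended into a two-channel array with 0.49/0.51 weights before `audiofile.write()` stores the result at 16 kHz. A minimal sketch of that step outside the commit, assuming NumPy float32 mono inputs at the same sample rate (`blend_speech_and_background` is a hypothetical helper name):

    import numpy as np

    def blend_speech_and_background(speech, background):
        # Pad the shorter signal with trailing zeros, then mix into a (2, N) array.
        n = max(len(speech), len(background))
        speech = np.pad(speech.astype(np.float32), (0, n - len(speech)))
        background = np.pad(background.astype(np.float32), (0, n - len(background)))
        left = 0.49 * speech + 0.51 * background
        right = 0.51 * background + 0.49 * speech   # same weights in reverse order
        return np.stack([left, right], axis=0)

    # dummy signals: 1 s of silence as speech, 0.5 s of ones as background
    print(blend_speech_and_background(np.zeros(16000), np.ones(8000)).shape)   # (2, 16000)

Because both channel formulas use the same 0.49/0.51 weights, the two channels are numerically identical, so the written file is effectively dual mono.
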
audiocraft.py ADDED
@@ -0,0 +1,724 @@
1
+ import torch
2
+ from torch import nn
3
+ import torch.nn.functional as F
4
+ from omegaconf import OmegaConf
5
+ import numpy as np
6
+ from huggingface_hub import hf_hub_download
7
+ import os
8
+ from torch.nn.utils import weight_norm
9
+ from transformers import T5EncoderModel, T5Tokenizer # type: ignore
10
+ from einops import rearrange
11
+
12
+ torch.backends.cuda.enable_mem_efficient_sdp(True)
13
+
14
+
15
+
16
+ N_REPEAT = 2 # num (virtual batch_size) clones of audio sounds
17
+
18
+ def _shift(x):
19
+ #print(x.shape, 'BATCH Independent SHIFT\n AudioGen')
20
+ for i, _slice in enumerate(x):
21
+ n = x.shape[2]
22
+ offset = np.random.randint(.24 * n, max(1, .74 * n)) # high should be above >= 0 TBD
23
+ print(offset)
24
+ x[i, :, :] = torch.roll(_slice, offset, dims=1) # _slice 2D
25
+ return x
26
+
27
+ class AudioGen(torch.nn.Module):
28
+
29
+ # https://huggingface.co/facebook/audiogen-medium
30
+
31
+ def __init__(self):
32
+
33
+ super().__init__()
34
+ _file_1 = hf_hub_download(
35
+ repo_id='facebook/audiogen-medium',
36
+ filename="compression_state_dict.bin",
37
+ cache_dir=os.environ.get('AUDIOCRAFT_CACHE_DIR', None),
38
+ library_name="audiocraft",
39
+ library_version= '1.3.0a1') # Found at __init__.py #audiocraft.__version__)
40
+ pkg = torch.load(_file_1, map_location='cpu')# kwargs = OmegaConf.create(pkg['xp.cfg'])
41
+ self.compression_model = EncodecModel()
42
+ self.compression_model.load_state_dict(pkg['best_state'], strict=False)
43
+ self.compression_model.eval() # ckpt has also unused encoder weights
44
+ self._chunk_len = 476
45
+ _file_2 = hf_hub_download(
46
+ repo_id='facebook/audiogen-medium',
47
+ filename="state_dict.bin",
48
+ cache_dir=os.environ.get('AUDIOCRAFT_CACHE_DIR', None),
49
+ library_name="audiocraft",
50
+ library_version= '1.3.0a1') # Found at __init__.py #audiocraft.__version__)
51
+ pkg = torch.load(_file_2, map_location='cpu')
52
+ cfg = OmegaConf.create(pkg['xp.cfg']) # CFG inside torch bin
53
+ _best = pkg['best_state']
54
+ _best['t5.output_proj.weight'] = _best.pop('condition_provider.conditioners.description.output_proj.weight')#.to(torch.float)
55
+ _best['t5.output_proj.bias'] = _best.pop('condition_provider.conditioners.description.output_proj.bias')#.to(torch.float)
56
+ self.lm = LMModel()
57
+ self.lm.load_state_dict(pkg['best_state'], strict=True)
58
+ self.lm.eval()
59
+
60
+
61
+ @torch.no_grad()
62
+ def generate(self,
63
+ prompt='dogs mewo',
64
+ duration=2.24, # seconds of audio
65
+ cache_lim=71, # flush the kv cache every cache_lim tokens
66
+ ):
67
+ torch.manual_seed(42) # https://github.com/facebookresearch/audiocraft/issues/111#issuecomment-1614732858
68
+ self.lm.cache_lim = cache_lim
69
+ self.lm.n_draw = int(.8 * duration) + 1 # different beam every 0.47 seconds of audio
70
+ with torch.autocast(device_type='cpu', dtype=torch.bfloat16):
71
+ gen_tokens = self.lm.generate(
72
+ text_condition=[prompt] * N_REPEAT + [''] * N_REPEAT,#['dogs', 'dogs...!', '', '']
73
+ max_tokens=int(.04 * duration / N_REPEAT * self.compression_model.frame_rate) + 12) # [bs, 4, 74*self.lm.n_draw]
74
+
75
+ # OOM if vocode all tokens
76
+ x = []
77
+
78
+
79
+ for i in range(7, gen_tokens.shape[2], self._chunk_len): # min soundscape 2s assures 10 tokens
80
+
81
+ decoded_chunk = self.compression_model.decode(gen_tokens[:, :, i-7:i+self._chunk_len])
82
+
83
+ x.append(decoded_chunk)
84
+
85
+ x = torch.cat(x, 2) # [bs, 1, 114000]
86
+
87
+ x = _shift(x) # clone() to have xN
88
+
89
+ return x.reshape(-1) #x / (x.abs().max() + 1e-7)
90
+
91
+
92
+ class EncodecModel(nn.Module):
93
+
94
+ def __init__(self):
95
+
96
+ super().__init__()
97
+ self.decoder = SEANetDecoder()
98
+ self.quantizer = ResidualVectorQuantizer()
99
+ self.frame_rate = 50
100
+
101
+
102
+ def decode(self, codes):
103
+ # B,K,T -> B,C,T
104
+ emb = self.quantizer.decode(codes)
105
+ return self.decoder(emb)
106
+
107
+
108
+ class StreamableLSTM(nn.Module):
109
+
110
+ def __init__(self,
111
+ dimension,
112
+ num_layers=2,
113
+ skip=True):
114
+
115
+ super().__init__()
116
+ self.skip = skip
117
+ self.lstm = nn.LSTM(dimension, dimension, num_layers)
118
+
119
+ def forward(self, x):
120
+ x = x.permute(2, 0, 1)
121
+ y, _ = self.lstm(x)
122
+ if self.skip:
123
+ y = y + x
124
+ y = y.permute(1, 2, 0)
125
+ return y
126
+
127
+
128
+
129
+ class SEANetResnetBlock(nn.Module):
130
+
131
+ def __init__(self,
132
+ dim,
133
+ kernel_sizes = [3, 1],
134
+ pad_mode = 'reflect',
135
+ compress = 2):
136
+
137
+ super().__init__()
138
+
139
+ hidden = dim // compress
140
+ block = []
141
+ for i, kernel_size in enumerate(kernel_sizes):
142
+ in_chs = dim if i == 0 else hidden
143
+ out_chs = dim if i == len(kernel_sizes) - 1 else hidden
144
+ block += [nn.ELU(),
145
+ StreamableConv1d(in_chs,
146
+ out_chs,
147
+ kernel_size=kernel_size,
148
+ pad_mode=pad_mode)]
149
+ self.block = nn.Sequential(*block)
150
+
151
+ def forward(self, x):
152
+ return x + self.block(x)
153
+
154
+
155
+
156
+
157
+
158
+ class SEANetDecoder(nn.Module):
159
+ # channels=1 dimension=128 n_filters=64 n_residual_layers=1 ratios=[8, 5, 4, 2]
160
+ # activation='ELU' activation_params={'alpha': 1.0}, final_activation=None
161
+ # final_activation_params=None norm='weight_norm'
162
+ # norm_params={} kernel_size=7 last_kernel_size=7 residual_kernel_size=3 dilation_base=2
163
+ # causal=False pad_mode='constant'
164
+ # true_skip=True compress=2 lstm=2 disable_norm_outer_blocks=0 trim_right_ratio=1.0
165
+
166
+ def __init__(self,
167
+ channels = 1,
168
+ dimension = 128,
169
+ n_filters = 64,
170
+ n_residual_layers = 1,
171
+ ratios = [8, 5, 4, 2],
172
+ kernel_size = 7,
173
+ last_kernel_size = 7,
174
+ residual_kernel_size = 3,
175
+ pad_mode = 'constant',
176
+ compress = 2,
177
+ lstm = 2):
178
+
179
+ super().__init__()
180
+
181
+
182
+ mult = int(2 ** len(ratios))
183
+ model = [
184
+ StreamableConv1d(dimension, mult * n_filters,
185
+ kernel_size,
186
+ pad_mode=pad_mode)
187
+ ]
188
+
189
+ if lstm:
190
+ print('\n\n\n\nLSTM IN SEANET\n\n\n\n')
191
+ model += [StreamableLSTM(mult * n_filters,
192
+ num_layers=lstm)]
193
+
194
+ # Upsample to raw audio scale
195
+ for i, ratio in enumerate(ratios):
196
+
197
+
198
+ model += [
199
+ nn.ELU(),
200
+ StreamableConvTranspose1d(mult * n_filters,
201
+ mult * n_filters // 2,
202
+ kernel_size=ratio * 2,
203
+ stride=ratio),
204
+ ]
205
+ # Add residual layers
206
+ for j in range(n_residual_layers):
207
+
208
+ model += [
209
+ SEANetResnetBlock(mult * n_filters // 2,
210
+ kernel_sizes=[residual_kernel_size, 1],
211
+ pad_mode=pad_mode,
212
+ compress=compress)]
213
+
214
+ mult //= 2
215
+
216
+ # Add final layers
217
+ model += [
218
+ nn.ELU(),
219
+ StreamableConv1d(n_filters,
220
+ channels,
221
+ last_kernel_size,
222
+ pad_mode=pad_mode)]
223
+ self.model=nn.Sequential(*model)
224
+
225
+ def forward(self, z):
226
+ return self.model(z)
227
+
228
+
229
+
230
+
231
+ def unpad1d(x, paddings):
232
+ padding_left, padding_right = paddings
233
+ end = x.shape[-1] - padding_right
234
+ return x[..., padding_left: end]
235
+
236
+
237
+ class NormConv1d(nn.Module):
238
+
239
+ def __init__(self, *args, **kwargs):
240
+ super().__init__()
241
+
242
+ self.conv = weight_norm(nn.Conv1d(*args, **kwargs)) # norm = weight_norm
243
+
244
+ def forward(self, x):
245
+ return self.conv(x)
246
+
247
+
248
+
249
+
250
+
251
+ class NormConvTranspose1d(nn.Module):
252
+
253
+ def __init__(self, *args, causal: bool = False, norm: str = 'none',
254
+ norm_kwargs = {}, **kwargs):
255
+ super().__init__()
256
+
257
+ self.convtr = weight_norm(nn.ConvTranspose1d(*args, **kwargs))
258
+
259
+ def forward(self, x):
260
+ return self.convtr(x)
261
+
262
+
263
+
264
+
265
+
266
+
267
+ class StreamableConv1d(nn.Module):
268
+
269
+ def __init__(self,
270
+ in_channels,
271
+ out_channels,
272
+ kernel_size,
273
+ stride=1,
274
+ groups=1,
275
+ bias=True,
276
+ pad_mode='reflect'):
277
+ super().__init__()
278
+ if (stride != 1) or (groups != 1):
279
+ raise ValueError
280
+ self.conv = NormConv1d(in_channels,
281
+ out_channels,
282
+ kernel_size,
283
+ stride,
284
+ groups=groups,
285
+ bias=bias)
286
+ self.pad_mode = pad_mode
287
+
288
+ def forward(self, x):
289
+ kernel_size = self.conv.conv.kernel_size[0]
290
+ kernel_size = (kernel_size - 1) * self.conv.conv.dilation[0] + 1
291
+ padding_total = kernel_size - self.conv.conv.stride[0]
292
+ padding_right = padding_total // 2
293
+ padding_left = padding_total - padding_right
294
+
295
+ # x = pad1d(x, (padding_left, padding_right), mode=self.pad_mode)
296
+ x = F.pad(x, (padding_left, padding_right), self.pad_mode)
297
+ return self.conv(x)
298
+
299
+
300
+ class StreamableConvTranspose1d(nn.Module):
301
+
302
+ def __init__(self, in_channels: int, out_channels: int,
303
+ kernel_size: int, stride: int = 1, causal: bool = False,
304
+ norm: str = 'none', trim_right_ratio: float = 1.,
305
+ norm_kwargs = {}):
306
+ super().__init__()
307
+ self.convtr = NormConvTranspose1d(in_channels,
308
+ out_channels,
309
+ kernel_size,
310
+ stride)
311
+
312
+
313
+ def forward(self, x):
314
+
315
+ padding_total = self.convtr.convtr.kernel_size[0] - self.convtr.convtr.stride[0]
316
+
317
+ y = self.convtr(x)
318
+
319
+ # Asymmetric padding required for odd strides
320
+ # print('\n \n\n\nn\n\n\nnANTICAUSAL T\n\n\n')
321
+ padding_right = padding_total // 2
322
+ padding_left = padding_total - padding_right
323
+
324
+ y = unpad1d(y, (padding_left, padding_right))
325
+ return y
326
+
327
+
328
+ # VQ
329
+
330
+ class EuclideanCodebook(nn.Module):
331
+ def __init__(self,
332
+ dim,
333
+ codebook_size):
334
+ super().__init__()
335
+ self.register_buffer("embed", torch.zeros(codebook_size, dim))
336
+
337
+
338
+
339
+
340
+ class VectorQuantization(nn.Module):
341
+
342
+ def __init__(self,
343
+ dim,
344
+ codebook_size):
345
+
346
+ super().__init__()
347
+ self._codebook = EuclideanCodebook(dim=dim,
348
+ codebook_size=codebook_size)
349
+
350
+ def decode(self, _ind):
351
+ return F.embedding(_ind, self._codebook.embed)
352
+
353
+
354
+ class ResidualVectorQuantization(nn.Module):
355
+
356
+ def __init__(self, *, num_quantizers, **kwargs):
357
+ super().__init__()
358
+ self.layers = nn.ModuleList(
359
+ [VectorQuantization(**kwargs) for _ in range(num_quantizers)]
360
+ )
361
+
362
+ def decode(self, _ind):
363
+ x = 0.0
364
+ for i, _code in enumerate(_ind):
365
+ x = x + self.layers[i].decode(_code)
366
+ return x.transpose(1, 2)
367
+
368
+
369
+ class ResidualVectorQuantizer(nn.Module):
370
+
371
+ # dimension=128 n_q=4 q_dropout=False bins=2048 decay=0.99 kmeans_init=True
372
+ # kmeans_iters=50 threshold_ema_dead_code=2
373
+ # orthogonal_reg_weight=0.0 orthogonal_reg_active_codes_only=False
374
+ # orthogonal_reg_max_codes=None
375
+
376
+ def __init__(
377
+ self,
378
+ dimension = 128,
379
+ n_q = 4,
380
+ bins = 2048
381
+ ):
382
+
383
+ super().__init__()
384
+ self.vq = ResidualVectorQuantization(dim=dimension,
385
+ codebook_size=bins,
386
+ num_quantizers=n_q)
387
+
388
+ def decode(self, codes):
389
+ # codes is [B, K, T], with T frames, K nb of codebooks, vq.decode expects [K, B, T].
390
+ return self.vq.decode(codes.transpose(0, 1))
391
+
392
+
393
+ class T5(nn.Module):
394
+
395
+ def __init__(self):
396
+
397
+ super().__init__()
398
+ self.output_proj = nn.Linear(1024, # t5-large
399
+ 1536) # lm hidden
400
+ self.t5_tokenizer = T5Tokenizer.from_pretrained('t5-large', legacy=True)
401
+ t5 = T5EncoderModel.from_pretrained('t5-large').train(mode=False)
402
+
403
+ # this makes sure that the t5 is not part
404
+ # of the saved checkpoint
405
+ self.__dict__['t5'] = t5.to('cpu')
406
+
407
+ def forward(self, prompt):
408
+ with torch.set_grad_enabled(False): #, torch.autocast(device_type='cpu', dtype=torch.float32):
409
+
410
+ bs = len(prompt) // 2
411
+ d = self.t5_tokenizer(prompt,
412
+ return_tensors='pt',
413
+ padding=True).to(self.output_proj.bias.device)
414
+ d['attention_mask'][bs:, :] = 0 # null condition t5 attn_mask should be zero
415
+
416
+ x = self.t5(input_ids=d['input_ids'],
417
+ attention_mask=d['attention_mask']).last_hidden_state # no kv
418
+ # Float 16
419
+ # > self.output_proj() is outside of autocast of t5 - however inside the autocast of lm thus computed in torch.float16
420
+ x = self.output_proj(x) # nn.Linear() - produces different result if there is no duplicate txt condition here
421
+ x[bs:, :, :] = 0 # venv/../site-packages/audiocraft/modules/conditioners.py -> tokenize()
422
+ return x
423
+
424
+
425
+ class LMModel(nn.Module):
426
+
427
+ def __init__(self,
428
+ n_q = 4,
429
+ card = 2048,
430
+ dim = 1536
431
+ ):
432
+ super().__init__()
433
+ self.cache_lim = -1
434
+ self.t5 = T5()
435
+ self.card = card # 2048
436
+ self.n_draw = 1 # draw > 1 tokens of different CFG scale
437
+ # batch size > 1 is slower from n_draw as calls transformer on larger batch
438
+ self.emb = nn.ModuleList([nn.Embedding(self.card + 1, dim) for _ in range(n_q)]) # EMBEDDING HAS 2049
439
+ self.transformer = StreamingTransformer()
440
+ self.out_norm = nn.LayerNorm(dim, eps=1e-5)
441
+ self.linears = nn.ModuleList([nn.Linear(dim, self.card, bias=False) for _ in range(n_q)]) # LINEAR DOESNT HAVE 2049
442
+
443
+ def forward(self,
444
+ sequence,
445
+ condition_tensors=None,
446
+ cache_position=None
447
+ ):
448
+
449
+ bs, n_q, time_frames = sequence.shape # [bs, 4, time]
450
+
451
+ input_ = sum([self.emb[k](sequence[:, k]) for k in range(n_q)])
452
+
453
+ out = self.transformer(torch.cat([input_, input_], 0), # duplicate null condition (bs x 2) for ClassifierFreeGuidance
454
+ cross_attention_src=condition_tensors,
455
+ cache_position=cache_position)
456
+
457
+ out = self.out_norm(out)
458
+
459
+ logits = torch.stack([self.linears[k](out) for k in range(n_q)], dim=1) # [2*bs, 4, 1, 2048]
460
+ logits = 3 * logits[:bs, :, :, :] - self._scale * logits[bs:, :, :, :] # [ bs, 4, n_draw, 2048]
461
+
462
+ #bs, n_q, n_draw, vocab = logits.shape
463
+ tokens = torch.multinomial(torch.softmax(logits.view(bs * self.n_draw * n_q, 2048), dim=1),
464
+ num_samples=1)
465
+ return tokens.view(bs, n_q, self.n_draw).transpose(1, 2)
466
+
467
+ @torch.no_grad()
468
+ def generate(self,
469
+ max_tokens=None,
470
+ text_condition=None
471
+ ):
472
+ x = self.t5(text_condition)
473
+ bs = x.shape[0] // 2 # has null conditions - bs*2*N_REPEAT applies in builders.py
474
+ self._scale = .3 * torch.rand(1, 1, self.n_draw, 1, device=x.device) + 1.94
475
+ cache_position = 0
476
+
477
+ out_codes = torch.full((bs,
478
+ self.n_draw,
479
+ 4,
480
+ 4 + 3 + max_tokens), # 4 + max_tokens + 4-1 to have sufficient to index the 1st antidiagonal of 4x4 + 4 xtra tokens
481
+ self.card,
482
+ dtype=torch.long,
483
+ device=x.device) # [bs, n_draw, 4, dur]
484
+
485
+ # A/R
486
+ for offset in range(0, max_tokens + 4 - 1): # max_tokens + n_q - 1
487
+
488
+ # extract diagonal via indexing out_codes[ [0, 1, 2, 3], [0, 1, 2, 3] ]
489
+ next_token = self.forward(out_codes[:, 0, [0, 1, 2, 3], torch.tensor([3, 2, 1, 0]) + offset][:, :, None], # index diagonal & expand to [bs, n_q, dur=1]
490
+ #gen_sequence[:, 0, :, offset-1:offset], # DIAGINDEXING for setting prediction of lm into gen_sequence THE GENSEQUENCE has to be un-delayed in the end [Because it has to be de-delayed for the vocoder then is actually only the lm input that requires to see the delay thus we could just feed by diaggather] so it matches gen_codes -1 a[[0, 1, 2, 3], torch.tensor([0, 1, 2, 3]) + 5] the gen_sequence is indexed by vertical column and fed to lm however the prediction of lm is place diagonally with delay to the gen_sequence
491
+ condition_tensors=x, # utilisation of the attention mask of txt condition ?
492
+ cache_position=cache_position) # [bs, n_draw, 4]
493
+
494
+ # Fill of next_token should be also placed on antidiagonal [not column]
495
+
496
+ # Do Not Overwrite 2048 of TRIU/TRIL = START/END => Do Not Fill them by Predicted Tokens
497
+ # 0-th antidiagonal should be full of card = [2048, 2048, 2048, 2048]
498
+ #
499
+ # [2048, 2048, 2048, 2048, 0, 1, 2, 3, 4, 5, 6, 2048, 2048, 2048],
500
+ # [2048, 2048, 2048, 2048, 2048, 0, 1, 2, 3, 4, 5, 6, 2048, 2048],
501
+ # [2048, 2048, 2048, 2048, 2048, 2048, 0, 1, 2, 3, 4, 5, 6, 2048],
502
+ # [2048, 2048, 2048, 2048, 2048, 2048, 2048, 0, 1, 2, 3, 4, 5, 6]]
503
+ # NO OVerWriting
504
+ if offset == 0:
505
+
506
+ next_token[:, :, 1:4] = 2048 # self.card - bottom 3 entries of the antidiagonal should remain 2048
507
+
508
+ elif offset == 1:
509
+
510
+ next_token[:, :, 2:4] = 2048 # bottom 2 entries of the antidiagonal should remain 2048
511
+
512
+ elif offset == 2:
513
+
514
+ next_token[:, :, 3:4] = 2048
515
+
516
+ elif offset == max_tokens:
517
+
518
+ next_token[:, :, 0:1] = 2048 # top 1 entry of the antidiagonal should stay to 2048
519
+
520
+ elif offset == (max_tokens + 1):
521
+
522
+ next_token[:, :, 0:2] = 2048
523
+
524
+ elif offset == (max_tokens + 2):
525
+
526
+ next_token[:, :, 0:3] = 2048
527
+
528
+ else: # offset 3,4,5,6,7...... max_tokens-1 # FILL Complete n_q = 4 ANTIDIAGONAL ENTRIES
529
+
530
+ pass #print('No delete anti-diag')
531
+
532
+ out_codes[:, :, [0, 1, 2, 3], torch.tensor([3, 2, 1, 0]) + offset + 1] = next_token
533
+ # Sink Attn
534
+ if (offset > 0) and (offset % self.cache_lim) == 0:
535
+ n_preserve = 4
536
+ self.transformer._flush(n_preserve=n_preserve)
537
+ cache_position = n_preserve
538
+ else:
539
+ cache_position += 1
540
+
541
+ # [bs, n_draw, 4, time+xtra] -> [bs, 4, n_draw, time] -> [bs, 4, time * n_draw]
542
+ out_codes = out_codes[:, :, :, 4:max_tokens+4].transpose(1, 2).reshape(bs, 4, self.n_draw * max_tokens)
543
+
544
+ # flush for next API call
545
+ self.transformer._flush()
546
+
547
+ return out_codes # SKIP THE 4 fill 2048
548
+
549
+
550
+
551
+
552
+ def create_sin_embedding(positions,
553
+ dim,
554
+ max_period=10000
555
+ ):
556
+ # assert dim % 2 == 0
557
+ half_dim = dim // 2
558
+ positions = positions.to(torch.float)
559
+ adim = torch.arange(half_dim, device=positions.device,
560
+ dtype=torch.float).view(1, 1, -1)
561
+ max_period_tensor = torch.full([],
562
+ max_period,
563
+ device=positions.device,
564
+ dtype=torch.float) # avoid sync point
565
+ phase = positions / (max_period_tensor ** (adim / (half_dim - 1)))
566
+ # OFFICIAL is torch.float32 HOWEVER self_attn.in_proj_weight = torch.float16
567
+ return torch.cat([torch.cos(phase), torch.sin(phase)], dim=-1)
568
+
569
+
570
+ class StreamingMultiheadAttention(nn.Module):
571
+
572
+ def __init__(self,
573
+ embed_dim,
574
+ num_heads,
575
+ cross_attention=False,
576
+ ):
577
+
578
+ super().__init__()
579
+
580
+ self.cross_attention = cross_attention
581
+ # if not self.cross_attention then it has kvcachingn
582
+ self.k_history = None
583
+ # cleanup history through LM inside GENERATION - Each 0,..,47 mha has different kv history
584
+ self.v_history = None
585
+ self.num_heads = num_heads
586
+ self.out_proj = nn.Linear(embed_dim, embed_dim, bias=False)
587
+ self.register_buffer('in_proj_weight', torch.ones((3 * embed_dim, embed_dim),
588
+ dtype=torch.float))
589
+
590
+ def forward(self,
591
+ query,
592
+ key=None,
593
+ value=None):
594
+ layout = "b h t d"
595
+ if self.cross_attention:
596
+
597
+ # Different queries, keys, values > split in_proj_weight
598
+
599
+ dim = self.in_proj_weight.shape[0] // 3
600
+
601
+ q = nn.functional.linear(query, self.in_proj_weight[:dim])
602
+ k = nn.functional.linear(key, self.in_proj_weight[dim: 2 * dim])
603
+ v = nn.functional.linear(value, self.in_proj_weight[2 * dim:])
604
+
605
+ q, k, v = [
606
+ rearrange(x, f"b t (h d) -> {layout}", h=self.num_heads) for x in [q, k, v]]
607
+
608
+ else:
609
+
610
+ # Here <else> = self_attention for audio with itself (above is cross attention txt)
611
+
612
+ # HISTORY - DIFFERENT FOR EACH TRANSF LAYER
613
+
614
+ # here we have different floating values from official
615
+ projected = nn.functional.linear(query, self.in_proj_weight, None)
616
+ # print(query.sum(), projected.sum() , self.in_proj_weight.sum(), 'Lc') # verified official AudioGen values
617
+ bound_layout = "b h p t d"
618
+ packed = rearrange(
619
+ projected, f"b t (p h d) -> {bound_layout}", p=3, h=self.num_heads)
620
+ q, k, v = packed.unbind(dim=2)
621
+ if self.k_history is not None:
622
+ # IF ctrl^c during live_demo the assigning of each of kv is non-atomic k!=v
623
+ # thus it will try to continue with incompatible k/v dims!
624
+ self.k_history = torch.cat([self.k_history, k], 2)
625
+ self.v_history = torch.cat([self.v_history, v], 2)
626
+ else:
627
+ self.k_history = k
628
+ self.v_history = v
629
+
630
+ # Assign Completed k / v to k / v
631
+
632
+ k = self.k_history
633
+ v = self.v_history
634
+
635
+ # -> kv CACHE ONLY APPLIES if not self.cross_attention
636
+
637
+ x = torch.nn.functional.scaled_dot_product_attention(
638
+ q, k, v, attn_mask=None, is_causal=False, dropout_p=0.0)
639
+
640
+ x = rearrange(x, f"{layout} -> b t (h d)", h=self.num_heads)
641
+ x = self.out_proj(x)
642
+ return x
643
+
644
+
645
+ class StreamingTransformerLayer(nn.Module):
646
+
647
+ def __init__(self,
648
+ d_model,
649
+ num_heads,
650
+ dim_feedforward):
651
+
652
+ super().__init__()
653
+
654
+ self.self_attn = StreamingMultiheadAttention(embed_dim=d_model,
655
+ num_heads=num_heads)
656
+ self.linear1 = nn.Linear(d_model, dim_feedforward, bias=False)
657
+ self.linear2 = nn.Linear(dim_feedforward, d_model, bias=False)
658
+ self.cross_attention = StreamingMultiheadAttention(embed_dim=d_model,
659
+ num_heads=num_heads,
660
+ cross_attention=True)
661
+ self.norm_cross = nn.LayerNorm(d_model, eps=1e-5)
662
+ self.norm1 = nn.LayerNorm(d_model, eps=1e-5)
663
+ self.norm2 = nn.LayerNorm(d_model, eps=1e-5)
664
+
665
+ def forward(self,
666
+ x,
667
+ cross_attention_src=None):
668
+ x = x + self.self_attn(self.norm1(x))
669
+ x = x + self.cross_attention(query=self.norm_cross(x),
670
+ key=cross_attention_src,
671
+ value=cross_attention_src) # txtcondition
672
+ x = x + self.linear2(F.gelu(self.linear1(self.norm2(x))))
673
+ return x
674
+
675
+
676
+ class StreamingTransformer(nn.Module):
677
+
678
+ def __init__(self,
679
+ d_model=1536,
680
+ num_heads=24,
681
+ num_layers=48,
682
+ dim_feedforward=6144):
683
+ super().__init__()
684
+
685
+ self.layers = nn.ModuleList(
686
+ [
687
+ StreamingTransformerLayer(d_model=d_model,
688
+ num_heads=num_heads,
689
+ dim_feedforward=dim_feedforward) for _ in range(num_layers)
690
+ ]
691
+ )
692
+
693
+ def forward(self,
694
+ x,
695
+ cache_position=None,
696
+ cross_attention_src=None):
697
+
698
+ x = x + create_sin_embedding(
699
+ torch.zeros(x.shape[0], 1, 1, device=x.device) + cache_position, 1536)
700
+
701
+ for lay in self.layers:
702
+ x = lay(x,
703
+ cross_attention_src=cross_attention_src)
704
+ return x
705
+
706
+ def _flush(self,
707
+ n_preserve=None):
708
+
709
+ for lay in self.layers:
710
+ if n_preserve is not None:
711
+ # cache position is difficult to choose to also preserve kv from end
712
+ lay.self_attn.k_history = lay.self_attn.k_history[:, :, :n_preserve, :]
713
+ lay.self_attn.v_history = lay.self_attn.v_history[:, :, :n_preserve, :]
714
+ else:
715
+ lay.self_attn.k_history = None
716
+ lay.self_attn.v_history = None
717
+
718
+
719
+ if __name__ == '__main__':
720
+
721
+ import audiofile
722
+ model = AudioGen().to('cpu')
723
+ x = model.generate(prompt='swims in lake frogs', duration=6.4).cpu().numpy()
724
+ audiofile.write('_sound_.wav', x, 16000)
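
The `cache_lim` value passed down from the new "kv Period" inputs controls how often `LMModel.generate()` trims the self-attention cache: every `cache_lim` autoregressive steps, `StreamingTransformer._flush(n_preserve=4)` truncates each layer's `k_history`/`v_history` to their first four positions and `cache_position` restarts from 4, which keeps memory bounded for long soundscapes. A rough, bookkeeping-only sketch of that schedule, not part of the commit (no model is run; `simulate_kv_cache` is a name made up here):

    def simulate_kv_cache(max_tokens, cache_lim, n_preserve=4):
        # Track the per-layer KV-cache length step by step under the flush rule above.
        cache_len, lengths = 0, []
        for offset in range(max_tokens):
            cache_len += 1                            # each forward() appends one K/V position
            if offset > 0 and offset % cache_lim == 0:
                cache_len = n_preserve                # _flush keeps only the first 4 positions
            lengths.append(cache_len)
        return lengths

    print(simulate_kv_cache(max_tokens=20, cache_lim=8))
    # [1, 2, 3, 4, 5, 6, 7, 8, 4, 5, 6, 7, 8, 9, 10, 11, 4, 5, 6, 7]
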
requirements.txt CHANGED
@@ -1,8 +1,6 @@
1
- torch
2
  nltk
3
- pydantic==2.10.6
4
  librosa
5
- transformers
6
  phonemizer
7
  audiofile
8
  matplotlib
@@ -11,4 +9,8 @@ num2words
11
  numpy<2.0.0
12
  gradio==5.27.0
13
  Numbers2Words-Greek
14
-
1
+ omegaconf
2
  nltk
 
3
  librosa
 
4
  phonemizer
5
  audiofile
6
  matplotlib
 
9
  numpy<2.0.0
10
  gradio==5.27.0
11
  Numbers2Words-Greek
12
+ einops
13
+ torch
14
+ pydantic==2.10.6
15
+ transformers==4.49.0
16
+ sentencepiece