Hematej committed on
Commit
c3641e4
·
verified ·
1 Parent(s): 266d41f

Upload 2 files

Files changed (2)
  1. app.py +152 -0
  2. requirements.txt +2 -0
app.py ADDED
@@ -0,0 +1,152 @@
+ import spaces
+ from transformers import AutoTokenizer, AutoModelForCausalLM, pipeline
+ import torch
+ from xcodec2.modeling_xcodec2 import XCodec2Model
+ import torchaudio
+ import gradio as gr
+
+ # ✅ Automatically detect whether to use GPU or CPU
+ device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
+
+ llasa_3b = 'srinivasbilla/llasa-3b'
+ tokenizer = AutoTokenizer.from_pretrained(llasa_3b)
+
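+ # Load the LLaSA 3B checkpoint; trust_remote_code lets its custom model code run.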
+ model = AutoModelForCausalLM.from_pretrained(
+     llasa_3b,
+     trust_remote_code=True,
+     device_map=device,
+     torch_dtype=torch.float16 if torch.cuda.is_available() else torch.float32  # ✅ float16 on GPU, float32 on CPU
+ )
+
+ model_path = "srinivasbilla/xcodec2"
+ Codec_model = XCodec2Model.from_pretrained(model_path)
+ Codec_model.eval().to(device)  # ✅ Move the codec to the active device
+
+ # ✅ Whisper ASR pipeline with automatic CPU/GPU selection
+ whisper_turbo_pipe = pipeline(
+     "automatic-speech-recognition",
+     model="openai/whisper-large-v3-turbo",
+     torch_dtype=torch.float16 if torch.cuda.is_available() else torch.float32,
+     device=device  # ✅ Automatically selects CPU/GPU
+ )
+
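+ # The LLM represents audio as tokens of the form <|s_12345|>; the helpers below
+ # convert between those token strings and the integer codec IDs XCodec2 consumes.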
+ def ids_to_speech_tokens(speech_ids):
+     return [f"<|s_{speech_id}|>" for speech_id in speech_ids]
+
+ def extract_speech_ids(speech_tokens_str):
+     speech_ids = []
+     for token_str in speech_tokens_str:
+         if token_str.startswith('<|s_') and token_str.endswith('|>'):
+             try:
+                 speech_ids.append(int(token_str[4:-2]))
+             except ValueError:
+                 print(f"Unexpected token: {token_str}")
+     return speech_ids
+
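+ # @spaces.GPU requests ZeroGPU hardware on Hugging Face Spaces for up to 60 seconds per call.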
+ @spaces.GPU(duration=60)
+ def infer(sample_audio_path, target_text, progress=gr.Progress()):
+     progress(0, 'Loading and trimming audio...')
+     waveform, sample_rate = torchaudio.load(sample_audio_path)
+     if len(waveform[0]) / sample_rate > 15:
+         gr.Warning("Trimming audio to the first 15 seconds.")
+         waveform = waveform[:, :sample_rate * 15]
+
+     # ✅ Convert stereo to mono dynamically
+     waveform_mono = waveform.mean(dim=0, keepdim=True) if waveform.size(0) > 1 else waveform
+
+     # Resample to the 16 kHz both Whisper and the codec expect, then move to the active device.
+     prompt_wav = torchaudio.transforms.Resample(orig_freq=sample_rate, new_freq=16000)(waveform_mono)
+     prompt_wav = prompt_wav.to(device)
+     prompt_text = whisper_turbo_pipe(prompt_wav[0].cpu().numpy(), generate_kwargs={"language": "en"})['text'].strip()  # ✅ Force English transcription
+     progress(0.5, 'Transcribed! Generating speech...')
+
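+     # Validate the target text before spending GPU time on generation.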
+     if len(target_text) == 0:
+         return None
+     elif len(target_text) > 300:
+         gr.Warning("Text is too long. Please keep it under 300 characters.")
+         target_text = target_text[:300]
+
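+     # Prepend the reference transcript so the model continues speaking in the cloned voice.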
+     input_text = f"{prompt_text} {target_text}"
+
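+     # Encode the reference audio into VQ codec tokens; these prime the speech generation.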
+     with torch.no_grad():
+         vq_code_prompt = Codec_model.encode_code(prompt_wav)
+         vq_code_prompt = vq_code_prompt[0, 0, :]
+         speech_ids_prefix = ids_to_speech_tokens(vq_code_prompt)
+
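+         # Wrap the combined text in the model's text-understanding markers.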
+         formatted_text = f"<|TEXT_UNDERSTANDING_START|>{input_text}<|TEXT_UNDERSTANDING_END|>"
+
+         chat = [
+             {"role": "user", "content": "Convert the text to speech:" + formatted_text},
+             {"role": "assistant", "content": "<|SPEECH_GENERATION_START|>" + ''.join(speech_ids_prefix)}
+         ]
+
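+         # continue_final_message=True leaves the assistant turn open so generate() extends the speech-token prefix.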
+         input_ids = tokenizer.apply_chat_template(
+             chat,
+             tokenize=True,
+             return_tensors='pt',
+             continue_final_message=True
+         ).to(device)
+
+         speech_end_id = tokenizer.convert_tokens_to_ids('<|SPEECH_GENERATION_END|>')
+
+         # convert_tokens_to_ids returns the unk id rather than None for unknown tokens, so check both.
+         if speech_end_id is None or speech_end_id == tokenizer.unk_token_id:
+             raise ValueError("Error: `<|SPEECH_GENERATION_END|>` token not found!")
+
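+         # Sample speech tokens until the model emits <|SPEECH_GENERATION_END|>.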
+         outputs = model.generate(
+             input_ids,
+             max_length=2048,
+             eos_token_id=speech_end_id,
+             do_sample=True,
+             top_p=1,
+             temperature=0.8
+         )
+
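+         # Keep the prompt's speech tokens (their audio is trimmed from the waveform below) and drop the EOS token.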
+         generated_ids = outputs[0][input_ids.shape[1] - len(speech_ids_prefix):-1]
+         speech_tokens = tokenizer.batch_decode(generated_ids, skip_special_tokens=True)
+         speech_tokens = extract_speech_ids(speech_tokens)
+
+         if not speech_tokens:
+             raise ValueError("Error: No valid speech tokens extracted!")
+
+         speech_tensor = torch.tensor(speech_tokens).unsqueeze(0).unsqueeze(0).to(device)
+
+         gen_wav = Codec_model.decode_code(speech_tensor)
+         gen_wav = gen_wav[:, :, prompt_wav.shape[1]:]  # Trim the reference prompt from the output waveform.
+
+     progress(1, 'Synthesized!')
+
+     return (16000, gen_wav[0, 0, :].cpu().numpy())
+
+ # ✅ Gradio UI setup
+ with gr.Blocks() as app_tts:
+     gr.Markdown("# Zero Shot Voice Clone TTS")
+     ref_audio_input = gr.Audio(label="Reference Audio", type="filepath")
+     gen_text_input = gr.Textbox(label="Text to Generate", lines=10)
+     generate_btn = gr.Button("Synthesize", variant="primary")
+     audio_output = gr.Audio(label="Synthesized Audio")
+
+     generate_btn.click(
+         infer,
+         inputs=[ref_audio_input, gen_text_input],
+         outputs=[audio_output],
+     )
+
+ with gr.Blocks() as app_credits:
+     gr.Markdown("""
+ # Credits
+ * [zhenye234](https://github.com/zhenye234) for the original [repo](https://github.com/zhenye234/LLaSA_training)
+ * [mrfakename](https://huggingface.co/mrfakename) for the [gradio demo code](https://huggingface.co/spaces/mrfakename/E2-F5-TTS)
+ """)
+
+ with gr.Blocks() as app:
+     gr.Markdown("""
+ # llasa 3b TTS
+ This is a web UI for the llasa 3b SOTA zero-shot voice cloning and TTS model.
+ The checkpoints support English and Chinese.
+ If you're having issues, try converting your reference audio to WAV or MP3, clipping it to 15 seconds, and shortening your prompt.
+ """)
+     gr.TabbedInterface([app_tts, app_credits], ["TTS", "Credits"])
+
+ app.launch(ssr_mode=False)
requirements.txt ADDED
@@ -0,0 +1,2 @@
+ gradio
+ xcodec2==0.1.3