Gregniuki committed
Commit dbcc874 · verified · 1 Parent(s): cba8833

Delete cog.py

Files changed (1)
  1. cog.py +0 -180
cog.py DELETED
@@ -1,180 +0,0 @@
# Prediction interface for Cog ⚙️
# https://cog.run/python

from cog import BasePredictor, Input, Path

import os
import re
import torch
import torchaudio
import numpy as np
import tempfile
import gradio as gr  # provides gr.Info / gr.Warning / gr.Error used below
from einops import rearrange
from ema_pytorch import EMA
from vocos import Vocos
from pydub import AudioSegment
from model import CFM, UNetT, DiT, MMDiT
from cached_path import cached_path
from model.utils import (
    get_tokenizer,
    convert_char_to_pinyin,
    save_spectrogram,
)
from transformers import pipeline
import librosa

device = "cuda" if torch.cuda.is_available() else "mps" if torch.backends.mps.is_available() else "cpu"

target_sample_rate = 24000
n_mel_channels = 100
hop_length = 256
target_rms = 0.1
nfe_step = 32  # 16, 32
cfg_strength = 2.0
ode_method = 'euler'
sway_sampling_coef = -1.0
speed = 1.0
# fix_duration = 27  # None or float (duration in seconds)
fix_duration = None


class Predictor(BasePredictor):
    def load_model(self, exp_name, model_cls, model_cfg, ckpt_step):
        checkpoint = torch.load(str(cached_path(f"hf://SWivid/F5-TTS/{exp_name}/model_{ckpt_step}.pt")), map_location=device)
        vocab_char_map, vocab_size = get_tokenizer("Emilia_ZH_EN", "pinyin")
        model = CFM(
            transformer=model_cls(
                **model_cfg,
                text_num_embeds=vocab_size,
                mel_dim=n_mel_channels
            ),
            mel_spec_kwargs=dict(
                target_sample_rate=target_sample_rate,
                n_mel_channels=n_mel_channels,
                hop_length=hop_length,
            ),
            odeint_kwargs=dict(
                method=ode_method,
            ),
            vocab_char_map=vocab_char_map,
        ).to(device)

        ema_model = EMA(model, include_online_model=False).to(device)
        ema_model.load_state_dict(checkpoint['ema_model_state_dict'])
        ema_model.copy_params_from_ema_to_model()

        return ema_model, model

    def setup(self) -> None:
        """Load the model into memory to make running multiple predictions efficient"""
        # self.model = torch.load("./weights.pth")
        print("Loading Whisper model...")
        self.pipe = pipeline(
            "automatic-speech-recognition",
            model="openai/whisper-large-v3-turbo",
            torch_dtype=torch.float16,
            device=device,
        )
        print("Loading F5-TTS model...")

        F5TTS_model_cfg = dict(dim=1024, depth=22, heads=16, ff_mult=2, text_dim=512, conv_layers=4)
        self.F5TTS_ema_model, self.F5TTS_base_model = self.load_model("F5TTS_Base", DiT, F5TTS_model_cfg, 1200000)

    def predict(
        self,
        gen_text: str = Input(description="Text to generate"),
        ref_audio_orig: Path = Input(description="Reference audio"),
        ref_text: str = Input(description="Reference text; leave empty to transcribe the reference audio", default=""),
        remove_silence: bool = Input(description="Remove silences", default=True),
    ) -> Path:
        """Run a single prediction on the model"""
        model_choice = "F5-TTS"
        print(gen_text)
        if len(gen_text) > 200:
            raise gr.Error("Please keep your text under 200 chars.")
        gr.Info("Converting audio...")
        with tempfile.NamedTemporaryFile(delete=False, suffix=".wav") as f:
            aseg = AudioSegment.from_file(ref_audio_orig)
            audio_duration = len(aseg)
            if audio_duration > 15000:
                gr.Warning("Audio is over 15s, clipping to only first 15s.")
                aseg = aseg[:15000]
            aseg.export(f.name, format="wav")
            ref_audio = f.name
        ema_model = self.F5TTS_ema_model
        base_model = self.F5TTS_base_model

        if not ref_text.strip():
            gr.Info("No reference text provided, transcribing reference audio...")
            ref_text = self.pipe(
                ref_audio,
                chunk_length_s=30,
                batch_size=128,
                generate_kwargs={"task": "transcribe"},
                return_timestamps=False,
            )['text'].strip()
            gr.Info("Finished transcription")
        else:
            gr.Info("Using custom reference text...")
        audio, sr = torchaudio.load(ref_audio)

        rms = torch.sqrt(torch.mean(torch.square(audio)))
        if rms < target_rms:
            audio = audio * target_rms / rms
        if sr != target_sample_rate:
            resampler = torchaudio.transforms.Resample(sr, target_sample_rate)
            audio = resampler(audio)
        audio = audio.to(device)

        # Prepare the text
        text_list = [ref_text + gen_text]
        final_text_list = convert_char_to_pinyin(text_list)

        # Calculate duration: scale the generated length by the ratio of generated text
        # to reference text, counting Chinese pause punctuation as extra length
        ref_audio_len = audio.shape[-1] // hop_length
        # if fix_duration is not None:
        #     duration = int(fix_duration * target_sample_rate / hop_length)
        # else:
        zh_pause_punc = r"[。，、；：？！]"
        ref_text_len = len(ref_text) + len(re.findall(zh_pause_punc, ref_text))
        gen_text_len = len(gen_text) + len(re.findall(zh_pause_punc, gen_text))
        duration = ref_audio_len + int(ref_audio_len / ref_text_len * gen_text_len / speed)

        # inference
        gr.Info("Generating audio using F5-TTS")
        with torch.inference_mode():
            generated, _ = base_model.sample(
                cond=audio,
                text=final_text_list,
                duration=duration,
                steps=nfe_step,
                cfg_strength=cfg_strength,
                sway_sampling_coef=sway_sampling_coef,
            )

        generated = generated[:, ref_audio_len:, :]
        generated_mel_spec = rearrange(generated, '1 n d -> 1 d n')
        gr.Info("Running vocoder")
        vocos = Vocos.from_pretrained("charactr/vocos-mel-24khz")
        generated_wave = vocos.decode(generated_mel_spec.cpu())
        if rms < target_rms:
            generated_wave = generated_wave * rms / target_rms

        # wav -> numpy
        generated_wave = generated_wave.squeeze().cpu().numpy()

        if remove_silence:
            gr.Info("Removing audio silences... This may take a moment")
            non_silent_intervals = librosa.effects.split(generated_wave, top_db=30)
            non_silent_wave = np.array([])
            for interval in non_silent_intervals:
                start, end = interval
                non_silent_wave = np.concatenate([non_silent_wave, generated_wave[start:end]])
            generated_wave = non_silent_wave

        # write the generated waveform to a temporary wav file and return its path
        with tempfile.NamedTemporaryFile(suffix=".wav", delete=False) as tmp_wav:
            wav_path = tmp_wav.name
            # torchaudio.save expects a 2-D (channels, frames) tensor
            torchaudio.save(wav_path, torch.tensor(generated_wave, dtype=torch.float32).unsqueeze(0), target_sample_rate)

        return Path(wav_path)
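
For context, a minimal sketch of how the deleted predictor could have been exercised directly from Python (in a Cog deployment the same call would normally go through `cog predict -i gen_text="..." -i ref_audio_orig=@reference.wav`). The reference clip name and prompt text below are hypothetical, and running it requires the F5-TTS checkpoint plus a CUDA or MPS device:

# hypothetical local smoke test; assumes the Predictor class above is importable
predictor = Predictor()
predictor.setup()  # loads Whisper and the F5-TTS checkpoint

out_path = predictor.predict(
    gen_text="Hello, this is a quick synthesis test.",
    ref_audio_orig="reference.wav",  # hypothetical reference clip, ideally under 15 s
    ref_text="",                     # empty -> the reference clip is transcribed with Whisper
    remove_silence=True,
)
print("Generated audio written to:", out_path)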