update
gpt-omni committed · commit 411819d · parent 9b186d7
app.py CHANGED
@@ -30,7 +30,7 @@ import soundfile as sf
 from litgpt.model import GPT, Config
 from lightning.fabric.utilities.load import _lazy_load as lazy_load
 from utils.snac_utils import layershift, reconscruct_snac, reconstruct_tensors, get_time_str
-from utils.snac_utils import get_snac
+from utils.snac_utils import get_snac
 import whisper
 from tqdm import tqdm
 from huggingface_hub import snapshot_download
@@ -80,19 +80,19 @@ if not os.path.exists(ckpt_dir):
 snacmodel = SNAC.from_pretrained("hubertsiuzdak/snac_24khz").eval().to(device)
 whispermodel = whisper.load_model("small").to(device)
 text_tokenizer = Tokenizer(ckpt_dir)
-fabric = L.Fabric(devices=1, strategy="auto")
+# fabric = L.Fabric(devices=1, strategy="auto")
 config = Config.from_file(ckpt_dir + "/model_config.yaml")
 config.post_adapter = False
 
 model = GPT(config, device=device)
 
-# model = fabric.setup(model)
 state_dict = lazy_load(ckpt_dir + "/lit_model.pth")
 model.load_state_dict(state_dict, strict=True)
 model = model.to(device)
 model.eval()
 
 
+@spaces.GPU
 def get_input_ids_whisper_ATBatch(mel, leng, whispermodel, device):
     with torch.no_grad():
         mel = mel.unsqueeze(0).to(device)
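The two changes in this hunk go together: on a ZeroGPU Space no GPU is attached at import time, so the persistent Lightning Fabric setup is commented out and GPU-touching functions are instead wrapped in @spaces.GPU, which borrows a GPU only for the duration of each decorated call. A minimal sketch of that pattern, assuming the huggingface spaces package is installed; the names below are illustrative, not from app.py:

import spaces
import torch

net = torch.nn.Linear(4, 4)  # placeholder module standing in for the GPT built above

@spaces.GPU  # ZeroGPU attaches a GPU only while this call is running
def forward_on_gpu(x: torch.Tensor) -> torch.Tensor:
    m = net.to("cuda")            # move weights to the device inside the decorated call
    return m(x.to("cuda")).cpu()  # return on CPU so callers stay device-agnostic

print(forward_on_gpu(torch.randn(2, 4)))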
@@ -128,6 +128,7 @@ def get_input_ids_whisper_ATBatch(mel, leng, whispermodel, device):
     return torch.stack([audio_feature, audio_feature]), stacked_inputids
 
 
+@spaces.GPU
 def next_token_batch(
     model: GPT,
     audio_features: torch.tensor,
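Context for this decorator: get_input_ids_whisper_ATBatch returns the whisper features stacked twice, and run_AT_batch_stream later sets a KV cache with batch_size=2, so next_token_batch appears to decode two parallel streams (audio tokens and text tokens) from the same input in one batched forward pass. A toy illustration of the stacking, with made-up shapes:

import torch

feat = torch.randn(50, 768)           # one audio feature sequence (illustrative shape)
batched = torch.stack([feat, feat])   # duplicate into a batch of two
assert batched.shape == (2, 50, 768)  # row 0 and row 1 can then decode different streams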
@@ -162,9 +163,19 @@ def load_audio(path):
     mel = whisper.log_mel_spectrogram(audio)
     return mel, int(duration_ms / 20) + 1
 
+
+@spaces.GPU
+def generate_audio_data(snac_tokens, snacmodel, device=None):
+    audio = reconstruct_tensors(snac_tokens, device)
+    with torch.inference_mode():
+        audio_hat = snacmodel.decode(audio)
+    audio_data = audio_hat.cpu().numpy().astype(np.float64) * 32768.0
+    audio_data = audio_data.astype(np.int16)
+    audio_data = audio_data.tobytes()
+    return audio_data
+
 
 # @torch.inference_mode()
-@spaces.GPU
 def run_AT_batch_stream(
     audio_path,
     stream_stride=4,
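The new generate_audio_data decodes SNAC tokens back to a waveform on the GPU, then converts the float samples to raw 16-bit PCM bytes by scaling by 32768. A standalone sketch of that conversion step (not from app.py); the clip guards the +1.0 edge case, where a straight multiply-and-cast can overflow int16:

import numpy as np

def float_to_pcm16(wave: np.ndarray) -> bytes:
    # wave holds float samples in [-1.0, 1.0]; int16 spans [-32768, 32767]
    scaled = np.clip(wave, -1.0, 1.0) * 32767.0
    return scaled.astype(np.int16).tobytes()

pcm = float_to_pcm16(np.sin(np.linspace(0.0, 2.0 * np.pi, 24000)))
print(len(pcm))  # 48000 bytes: two bytes per sample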
@@ -178,11 +189,10 @@ def run_AT_batch_stream(
 
     assert os.path.exists(audio_path), f"audio file {audio_path} not found"
 
-    # with self.fabric.init_tensor():
     model.set_kv_cache(batch_size=2)
 
     mel, leng = load_audio(audio_path)
-    audio_feature, input_ids = get_input_ids_whisper_ATBatch(mel, leng,
+    audio_feature, input_ids = get_input_ids_whisper_ATBatch(mel, leng, whispermodel, device)
     T = input_ids[0].size(1)
     device = input_ids[0].device
 
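model.set_kv_cache(batch_size=2) preallocates litgpt's per-layer key/value buffers so each decoding step appends one position instead of recomputing attention over the whole prefix. A toy cache in that spirit, purely illustrative rather than litgpt's implementation:

import torch

class ToyKVCache:
    def __init__(self, batch: int, heads: int, max_len: int, dim: int):
        self.k = torch.zeros(batch, heads, max_len, dim)  # preallocated key buffer
        self.v = torch.zeros(batch, heads, max_len, dim)  # preallocated value buffer
        self.pos = 0  # next free slot

    def append(self, k_new: torch.Tensor, v_new: torch.Tensor):
        # k_new, v_new: (batch, heads, 1, dim) for the current step
        self.k[:, :, self.pos : self.pos + 1] = k_new
        self.v[:, :, self.pos : self.pos + 1] = v_new
        self.pos += 1
        return self.k[:, :, : self.pos], self.v[:, :, : self.pos]

cache = ToyKVCache(batch=2, heads=8, max_len=512, dim=64)
k, v = cache.append(torch.randn(2, 8, 1, 64), torch.randn(2, 8, 1, 64))
print(k.shape)  # torch.Size([2, 8, 1, 64]) after one step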