	Update model to large sft
- app.py +53 -37
- tools/extract_model.py +0 -21
- tools/llama/build_dataset.py +0 -165
- tools/llama/generate.py +64 -8
- tools/llama/rebuild_tokenizer.py +0 -57
- tools/merge_asr_files.py +0 -55
- tools/vqgan/create_train_split.py +0 -54
- tools/vqgan/extract_vq.py +0 -213
- tools/whisper_asr.py +0 -113
    	
app.py
CHANGED

@@ -1,34 +1,26 @@
 import subprocess as sp
 import os
+from huggingface_hub import hf_hub_download
 
 # Download if not exists
 os.makedirs("checkpoints", exist_ok=True)
-
-if not os.path.exists("checkpoints/text2semantic-medium-v1-2k.pth"):
-    print("Downloading text2semantic-medium-v1-2k.pth")
-    sp.run(["wget", "-q", "-O", "checkpoints/text2semantic-medium-v1-2k.pth", os.environ["CKPT_SEMANTIC"]])
-
-if not os.path.exists("checkpoints/vq-gan-group-fsq-2x1024.pth"):
-    print("Downloading vq-gan-group-fsq-2x1024.pth")
-    sp.run(["wget", "-q", "-O", "checkpoints/vq-gan-group-fsq-2x1024.pth", os.environ["CKPT_VQGAN"]])
+hf_hub_download("fishaudio/fish-speech-1", "./checkpoints/fish-speech-1")
 
 print("All checkpoints downloaded")
 
 import html
+import os
+import threading
 from argparse import ArgumentParser
-from io import BytesIO
 from pathlib import Path
 
 import gradio as gr
 import librosa
-import spaces
 import torch
 from loguru import logger
-from torchaudio import functional as AF
 from transformers import AutoTokenizer
 
-from tools.llama.generate import generate_long
-from tools.llama.generate import load_model as load_llama_model
+from tools.llama.generate import launch_thread_safe_queue
 from tools.vqgan.inference import load_model as load_vqgan_model
 
 # Make einx happy
@@ -52,16 +44,30 @@ We are not responsible for any misuse of the model, please consider your local l
 
 TEXTBOX_PLACEHOLDER = """Put your text here. 在此处输入文本."""
 
+try:
+    import spaces
+
+    GPU_DECORATOR = spaces.GPU
+except ImportError:
+
+    def GPU_DECORATOR(func):
+        def wrapper(*args, **kwargs):
+            return func(*args, **kwargs)
+
+        return wrapper
+
 
 def build_html_error_message(error):
     return f"""
-    <div style="color: red; font-weight: bold;">
+    <div style="color: red; 
+    font-weight: bold;">
        {html.escape(error)}
    </div>
    """
 
 
-@spaces.GPU
+@GPU_DECORATOR
+@torch.inference_mode()
 def inference(
     text,
     enable_reference_audio,
@@ -73,13 +79,10 @@ def inference(
     top_p,
     repetition_penalty,
     temperature,
-    speaker…
+    speaker,
 ):
-    if len(reference_text) > 100:
-        return None, "Ref text is too long, please keep it under 100 characters."
-
     if args.max_gradio_length > 0 and len(text) > args.max_gradio_length:
-        return None, "Text is too long, please keep it under …
+        return None, f"Text is too long, please keep it under {args.max_gradio_length} characters."
 
     # Parse reference audio aka prompt
    prompt_tokens = None
@@ -103,11 +106,9 @@ def inference(
         prompt_tokens = vqgan_model.encode(audios, audio_lengths)[0][0]
 
     # LLAMA Inference
-    result = generate_long(
-        model=llama_model,
+    request = dict(
         tokenizer=llama_tokenizer,
         device=vqgan_model.device,
-        decode_one_token=decode_one_token,
         max_new_tokens=max_new_tokens,
         text=text,
         top_k=int(top_k) if top_k > 0 else None,
@@ -123,7 +124,18 @@ def inference(
         prompt_text=reference_text if enable_reference_audio else None,
     )
 
-    codes = next(result)
+    payload = dict(
+        event=threading.Event(),
+        request=request,
+    )
+    llama_queue.put(payload)
+
+    # Wait for the result
+    payload["event"].wait()
+    if payload["success"] is False:
+        raise payload["response"]
+
+    codes = payload["response"][0]
 
     # VQGAN Inference
     feature_lengths = torch.tensor([codes.shape[1]], device=vqgan_model.device)
@@ -151,9 +163,7 @@ def build_app():
         with gr.Row():
             with gr.Column(scale=3):
                 text = gr.Textbox(
-                    label="Input Text / 输入文本",
-                    placeholder=TEXTBOX_PLACEHOLDER,
-                    lines=15,
+                    label="Input Text / 输入文本", placeholder=TEXTBOX_PLACEHOLDER, lines=15
                 )
 
             with gr.Row():
@@ -198,11 +208,11 @@ def build_app():
                                         step=0.01,
                                     )
 
-                                    …
-
-
-
-                                    …
+                                    speaker = gr.Textbox(
+                                        label="Speaker / 说话人",
+                                        placeholder="Type name of the speaker / 输入说话人的名称",
+                                        lines=1,
+                                    )
 
                                 with gr.Tab(label="Reference Audio / 参考音频"):
                                     gr.Markdown(
@@ -248,7 +258,7 @@ def build_app():
             top_p,
             repetition_penalty,
             temperature,
-            …
+            speaker,
         ],
         [audio, error],
         concurrency_limit=1,
@@ -262,10 +272,10 @@ def parse_args():
     parser.add_argument(
         "--llama-checkpoint-path",
         type=Path,
-        default="checkpoints/text2semantic-medium-v1-2k.pth",
+        default="checkpoints/text2semantic-sft-large-v1-4k.pth",
     )
     parser.add_argument(
-        "--llama-config-name", type=str, default="…
+        "--llama-config-name", type=str, default="dual_ar_2_codebook_large"
     )
     parser.add_argument(
         "--vqgan-checkpoint-path",
@@ -278,7 +288,7 @@ def parse_args():
     parser.add_argument("--half", action="store_true")
     parser.add_argument("--max-length", type=int, default=2048)
     parser.add_argument("--compile", action="store_true")
-    parser.add_argument("--max-gradio-length", type=int, default=…
+    parser.add_argument("--max-gradio-length", type=int, default=0)
 
     return parser.parse_args()
 
@@ -288,9 +298,15 @@ if __name__ == "__main__":
 
     args.precision = torch.half if args.half else torch.bfloat16
     args.compile = True
+    args.max_gradio_length = 1024
+    args.tokenizer = "./checkpoints/fish-speech-1"
+    args.llama_checkpoint_path = "./checkpoints/text2semantic-sft-large-v1-4k.pth"
+    args.llama_config_name = "dual_ar_2_codebook_large"
+    args.vqgan_checkpoint_path = "./checkpoints/vq-gan-group-fsq-2x1024.pth"
+    args.vqgan_config_name = "vqgan_pretrain"
 
     logger.info("Loading Llama model...")
-    llama_model, decode_one_token = load_llama_model(
+    llama_queue = launch_thread_safe_queue(
         config_name=args.llama_config_name,
         checkpoint_path=args.llama_checkpoint_path,
         device=args.device,
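The key structural change in app.py is that `inference` no longer calls the Llama model directly: it packs the generation kwargs into `request`, pairs it with a `threading.Event`, and blocks until the worker behind `llama_queue` fills in `success` and `response`. A minimal self-contained sketch of the same round-trip pattern (`launch_worker` and `handle` are illustrative names, not from this repo):

import queue
import threading

def launch_worker(handle):
    # A single consumer thread owns the model; callers talk to it via a queue.
    q = queue.Queue()

    def worker():
        while True:
            item = q.get()
            try:
                item["success"] = True
                item["response"] = handle(**item["request"])
            except Exception as e:
                item["success"] = False
                item["response"] = e
            item["event"].set()  # wake the waiting caller

    threading.Thread(target=worker, daemon=True).start()
    return q

q = launch_worker(lambda text: text.upper())
payload = dict(event=threading.Event(), request=dict(text="hello"))
q.put(payload)
payload["event"].wait()  # block until the worker is done
assert payload["response"] == "HELLO"

Because one daemon thread owns the model, concurrent requests are serialized without explicit locking, which matches `concurrency_limit=1` in the Gradio wiring.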
    	
tools/extract_model.py
DELETED

@@ -1,21 +0,0 @@
-import click
-import torch
-from loguru import logger
-
-
-@click.command()
-@click.argument("model_path")
-@click.argument("output_path")
-def main(model_path, output_path):
-    if model_path == output_path:
-        logger.error("Model path and output path are the same")
-        return
-
-    logger.info(f"Loading model from {model_path}")
-    state_dict = torch.load(model_path, map_location="cpu")["state_dict"]
-    torch.save(state_dict, output_path)
-    logger.info(f"Model saved to {output_path}")
-
-
-if __name__ == "__main__":
-    main()
    	
tools/llama/build_dataset.py
DELETED

@@ -1,165 +0,0 @@
-import itertools
-import os
-import re
-from collections import defaultdict
-from functools import partial
-from multiprocessing import Pool
-from pathlib import Path
-
-import click
-import numpy as np
-from loguru import logger
-from tqdm import tqdm
-
-from fish_speech.datasets.protos.text_data_pb2 import Semantics, Sentence, TextData
-from fish_speech.datasets.protos.text_data_stream import pack_pb_stream
-from fish_speech.utils.file import load_filelist
-
-# To avoid CPU overload
-os.environ["MKL_NUM_THREADS"] = "1"
-os.environ["OMP_NUM_THREADS"] = "1"
-
-
-def task_generator_folder(root: Path, text_extension: str):
-    files = list(tqdm(Path(root).rglob("*.npy"), desc=f"Loading {root}"))
-    files = sorted(files)
-
-    grouped_files = defaultdict(list)
-    for file in tqdm(files, desc=f"Grouping {root}"):
-        p = str(file.parent)
-
-        try:
-            if isinstance(text_extension, str):
-                texts = [file.with_suffix(text_extension).read_text()]
-            else:
-                texts = [file.with_suffix(ext).read_text() for ext in text_extension]
-        except Exception as e:
-            logger.error(f"Failed to read text {file}: {e}")
-            continue
-
-        grouped_files[p].append((file, texts))
-
-    logger.info(
-        f"Found {len(grouped_files)} groups in {root}, {list(grouped_files.keys())[:5]}..."
-    )
-    for name, subset in grouped_files.items():
-        yield name, subset, "folder"
-
-
-def task_generator_filelist(filelist):
-    grouped_files = defaultdict(list)
-    for filename, speaker, _, text in load_filelist(filelist):
-        grouped_files[speaker].append((Path(filename), [text]))
-
-    logger.info(f"Found {len(grouped_files)} groups in {filelist}")
-    for speaker, values in grouped_files.items():
-        yield speaker, values, "filelist"
-
-
-def run_task(task):
-    name, subset, source = task
-
-    # Parse the files
-    sentences = []
-    for file in subset:
-        file, texts = file
-
-        np_file = file.with_suffix(".npy")
-        if np_file.exists() is False:
-            logger.warning(f"Can't find {np_file}")
-            continue
-
-        new_texts = []
-
-        for text in texts:
-            # Simple cleaning: replace { xxx } and < xxx > with space
-            text = re.sub(r"\{.*?\}", " ", text)
-            text = re.sub(r"<.*?>", " ", text)
-            text = re.sub(r"\s+", " ", text)
-            new_texts.append(text)
-
-        try:
-            semantics = np.load(np_file)
-        except Exception as e:
-            logger.error(f"Failed to parse {file}: {e}")
-            continue
-
-        if isinstance(semantics, np.ndarray):
-            semantics = semantics.tolist()
-
-        sentences.append(
-            Sentence(
-                texts=new_texts,
-                semantics=[Semantics(values=s) for s in semantics],
-            )
-        )
-
-    # Pack the sentences
-    return pack_pb_stream(
-        TextData(
-            source=source,
-            name=name,
-            sentences=sentences,
-        )
-    )
-
-
-@click.command()
-@click.option(
-    "--input",
-    type=click.Path(path_type=Path),
-    required=True,
-    help="A folder containing the dataset or a filelist",
-    multiple=True,
-)
-@click.option(
-    "--output", type=click.Path(path_type=Path), default="data/quantized-dataset-ft"
-)
-@click.option("--num-workers", type=int, default=16)
-@click.option("--text-extension", type=str, default=[".txt"], multiple=True)
-@click.option(
-    "--shard-size", type=int, default=10, help="The maximum size of each shard in mb"
-)
-def main(input, output, num_workers, text_extension, shard_size):
-    generator_fns = []
-
-    for f in input:
-        assert f.exists(), f"{f} not found"
-
-        if f.is_dir():
-            generator_fn = task_generator_folder(f, text_extension)
-        else:
-            generator_fn = task_generator_filelist(f)
-
-        generator_fns.append(generator_fn)
-
-    generator_fn = itertools.chain(*generator_fns)
-    output.mkdir(parents=True, exist_ok=True)
-
-    dataset_fp = None
-    tar_idx = 0
-    written_size = 0
-
-    with Pool(num_workers) as p:
-        for result in tqdm(p.imap_unordered(run_task, generator_fn)):
-            if dataset_fp is None:
-                dataset_fp = open(Path(output) / f"{tar_idx:08d}.protos", "wb")
-
-            dataset_fp.write(result)
-            written_size += len(result)
-
-            if written_size > shard_size * 1024 * 1024:
-                logger.info(f"Finished writing {tar_idx} shards to {output}")
-                dataset_fp.close()
-                dataset_fp = None
-                written_size = 0
-                tar_idx += 1
-
-    if dataset_fp is not None:
-        dataset_fp.close()
-
-    logger.info(f"Finished writing {tar_idx + 1} shards to {output}")
-
-
-if __name__ == "__main__":
-    main()
    	
tools/llama/generate.py
CHANGED

@@ -1,9 +1,12 @@
 import os
+import queue
+import threading
 import time
 from pathlib import Path
 from typing import Optional, Tuple, Union
 
 import click
+import hydra
 import numpy as np
 import torch
 import torch._dynamo.config
@@ -361,6 +364,7 @@ def encode_tokens(
 def load_model(
     config_name, checkpoint_path, device, precision, max_length, compile=False
 ):
+    hydra.core.global_hydra.GlobalHydra.instance().clear()
     with initialize(version_base="1.3", config_path="../../fish_speech/configs/model"):
         cfg = compose(
             config_name=config_name, overrides=[f"config.max_seq_len={max_length}"]
@@ -456,6 +460,7 @@ def generate_long(
     speaker: Optional[str] = None,
     prompt_text: Optional[str] = None,
     prompt_tokens: Optional[torch.Tensor] = None,
+    is_streaming: bool = False,
 ):
     model_size = sum(p.numel() for p in model.parameters() if p.requires_grad)
     im_end_id = tokenizer.convert_tokens_to_ids("<|im_end|>")
@@ -496,6 +501,10 @@ def generate_long(
         all_codes = []
         seg_idx = 0
 
+        if use_prompt:
+            seg_idx = 1
+            global_encoded.append(encoded[0])
+
         while seg_idx < len(encoded):
             logger.info(
                 f"Generating sentence {seg_idx + 1}/{len(encoded)} of sample {sample_idx + 1}/{num_samples}"
@@ -562,10 +571,7 @@ def generate_long(
             codes = y[1:, prompt_length:-2].clone()
 
             codes = codes - 2
-            if (codes < 0).any():
-                global_encoded.pop()
-                logger.warning(f"Negative code found: {codes}, retrying ...")
-                continue
+            assert (codes >= 0).all(), f"Negative code found"
 
             decoded = y[:, prompt_length:-1].clone()
             if decoded[0, -1] != im_end_id:  # <im_end>
@@ -576,13 +582,63 @@
 
             # But for global encoding, we should keep the <im_end> token
             global_encoded.append(decoded)
-            all_codes.append(codes)
+
+            if is_streaming:
+                assert (codes >= 0).all(), f"Negative code found: {codes}"
+                yield codes
+            else:
+                all_codes.append(codes)
+
             seg_idx += 1
 
-        codes = torch.cat(all_codes, dim=1)
-        assert (codes >= 0).all(), f"Negative code found: {codes}"
-        yield codes
+        if is_streaming:
+            # This indicates the end of the current sample
+            yield None
+        else:
+            all_codes = torch.cat(all_codes, dim=1)
+            assert (all_codes >= 0).all(), f"Negative code found: {codes}"
+            yield all_codes
+
+
+def launch_thread_safe_queue(
+    config_name,
+    checkpoint_path,
+    device,
+    precision,
+    max_length,
+    compile=False,
+):
+    input_queue = queue.Queue()
+
+    def worker():
+        model, decode_one_token = load_model(
+            config_name, checkpoint_path, device, precision, max_length, compile=compile
+        )
+
+        while True:
+            item = input_queue.get()
+            if item is None:
+                break
+
+            kwargs = item["request"]
+            event = item["event"]
+
+            try:
+                item["success"] = True
+                item["response"] = list(
+                    generate_long(
+                        model=model, decode_one_token=decode_one_token, **kwargs
+                    )
+                )
+            except Exception as e:
+                item["success"] = False
+                item["response"] = e
+
+            event.set()
+
+    threading.Thread(target=worker, daemon=True).start()
+
+    return input_queue
 
 
 @click.command()
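With this change `generate_long` is a generator in both modes: under `is_streaming=True` it yields one codes tensor per finished sentence plus a bare `None` as an end-of-sample sentinel, while the default path yields a single concatenated tensor per sample, which is why the worker wraps it in `list(...)` and app.py reads `payload["response"][0]`. A sketch of how a streaming consumer might drain it (`drain_streaming` and `process_sample` are hypothetical helpers, not part of this commit):

from tools.llama.generate import generate_long

def drain_streaming(model, decode_one_token, request, process_sample):
    # Collect per-sentence code tensors until the end-of-sample sentinel.
    segments = []
    for codes in generate_long(
        model=model, decode_one_token=decode_one_token, is_streaming=True, **request
    ):
        if codes is None:
            # generate_long yields None once per finished sample
            process_sample(segments)  # e.g. torch.cat(segments, dim=1) -> vocoder
            segments = []
        else:
            segments.append(codes)  # one codes tensor per finished sentence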
    	
tools/llama/rebuild_tokenizer.py
DELETED

@@ -1,57 +0,0 @@
-from tokenizers import Tokenizer, decoders, models, pre_tokenizers, processors, trainers
-from transformers import PreTrainedTokenizer, PreTrainedTokenizerFast
-
-# Initialize a tokenizer
-tokenizer = Tokenizer(models.BPE())
-
-# Customize pre-tokenization and decoding
-tokenizer.pre_tokenizer = pre_tokenizers.ByteLevel(add_prefix_space=False)
-tokenizer.decoder = decoders.ByteLevel()
-tokenizer.post_processor = processors.ByteLevel(trim_offsets=False)
-
-# Don't train the tokenizer
-trainer = trainers.BpeTrainer(
-    vocab_size=0,
-    min_frequency=2,
-    initial_alphabet=pre_tokenizers.ByteLevel.alphabet(),
-    special_tokens=[
-        "<|begin_of_sequence|>",
-        "<|end_of_sequence|>",
-        "<|im_start|>",
-        "<|im_sep|>",  # system, user, assistant, etc.
-        "<|im_end|>",
-        "<|semantic|>",  # audio features
-        "<|pad|>",
-    ],
-)
-
-# <|im_start|>user<|im_sep|>...<|im_end|>
-# <|im_start|>assistant<|im_sep|><|semantic|><|semantic|><|semantic|><|semantic|><|semantic|><|im_end|>
-tokenizer.train_from_iterator([], trainer=trainer)
-
-print(len(tokenizer.get_vocab()))
-x = tokenizer.encode(
-    "Hello, how are you? dfgnviadfjoiviouajeiodfjv 你好世界 🈶<|semantic|>"
-).ids
-print(x, len(x))
-print(tokenizer.decode(x, skip_special_tokens=True))
-
-
-tokenizer = PreTrainedTokenizerFast(
-    tokenizer_object=tokenizer,
-    pad_token="<|pad|>",
-    bos_token="<|begin_of_sequence|>",
-    eos_token="<|end_of_sequence|>",
-)
-
-# Try tokenizing a new sequence
-sequence = "All around, too, lay vast quantities of the costliest merchandise, and treasures were heaped in every cranny of the rocks, but all these things only added to the desolation of the scene. 测试中文, 你好世界 🈶<|semantic|>"
-encoded = tokenizer(sequence).input_ids
-
-print("Test encoding....")
-print(f"\tSentence: {sequence}")
-print(f"\tEncoded: {encoded}")
-print(f"\tDecoded: {tokenizer.batch_decode(encoded)}")
-print(f"\tDecoded: {tokenizer.decode(encoded)}")
-
-tokenizer.push_to_hub("fishaudio/fish-speech-1", private=True)
    	
        tools/merge_asr_files.py
    DELETED
    
    | @@ -1,55 +0,0 @@ | |
| 1 | 
            -
            import os
         | 
| 2 | 
            -
            from pathlib import Path
         | 
| 3 | 
            -
             | 
| 4 | 
            -
            from pydub import AudioSegment
         | 
| 5 | 
            -
            from tqdm import tqdm
         | 
| 6 | 
            -
             | 
| 7 | 
            -
            from fish_speech.utils.file import AUDIO_EXTENSIONS, list_files
         | 
| 8 | 
            -
             | 
| 9 | 
            -
             | 
| 10 | 
            -
            def merge_and_delete_files(save_dir, original_files):
         | 
-    save_path = Path(save_dir)
-    audio_slice_files = list_files(
-        path=save_dir, extensions=AUDIO_EXTENSIONS.union([".lab"]), recursive=True
-    )
-    audio_files = {}
-    label_files = {}
-    for file_path in tqdm(audio_slice_files, desc="Merging audio files"):
-        rel_path = Path(file_path).relative_to(save_path)
-        (save_path / rel_path.parent).mkdir(parents=True, exist_ok=True)
-        if file_path.suffix == ".wav":
-            prefix = rel_path.parent / file_path.stem.rsplit("-", 1)[0]
-            if prefix == rel_path.parent / file_path.stem:
-                continue
-            audio = AudioSegment.from_wav(file_path)
-            if prefix in audio_files:
-                audio_files[prefix] = audio_files[prefix] + audio
-            else:
-                audio_files[prefix] = audio
-
-        elif file_path.suffix == ".lab":
-            prefix = rel_path.parent / file_path.stem.rsplit("-", 1)[0]
-            if prefix == rel_path.parent / file_path.stem:
-                continue
-            with open(file_path, "r", encoding="utf-8") as f:
-                label = f.read()
-            if prefix in label_files:
-                label_files[prefix] = label_files[prefix] + ", " + label
-            else:
-                label_files[prefix] = label
-
-    for prefix, audio in audio_files.items():
-        output_audio_path = save_path / f"{prefix}.wav"
-        audio.export(output_audio_path, format="wav")
-
-    for prefix, label in label_files.items():
-        output_label_path = save_path / f"{prefix}.lab"
-        with open(output_label_path, "w", encoding="utf-8") as f:
-            f.write(label)
-
-    for file_path in original_files:
-        os.remove(file_path)
-
-
-if __name__ == "__main__":
-    merge_and_delete_files("/made/by/spicysama/laziman", [__file__])

tools/vqgan/create_train_split.py DELETED
@@ -1,54 +0,0 @@
-import math
-from pathlib import Path
-from random import Random
-
-import click
-from loguru import logger
-from tqdm import tqdm
-
-from fish_speech.utils.file import AUDIO_EXTENSIONS, list_files, load_filelist
-
-
-@click.command()
-@click.argument("root", type=click.Path(exists=True, path_type=Path))
-@click.option("--val-ratio", type=float, default=None)
-@click.option("--val-count", type=int, default=None)
-@click.option("--filelist", default=None, type=Path)
-def main(root, val_ratio, val_count, filelist):
-    if filelist:
-        files = [i[0] for i in load_filelist(filelist)]
-    else:
-        files = list_files(root, AUDIO_EXTENSIONS, recursive=True, sort=True)
-
-    logger.info(f"Found {len(files)} files")
-    files = [str(file.relative_to(root)) for file in tqdm(files)]
-
-    Random(42).shuffle(files)
-
-    if val_count is None and val_ratio is None:
-        logger.info("Validation ratio and count not specified, using min(20%, 100)")
-        val_size = min(100, math.ceil(len(files) * 0.2))
-    elif val_count is not None and val_ratio is not None:
-        logger.error("Cannot specify both val_count and val_ratio")
-        return
-    elif val_count is not None:
-        if val_count < 1 or val_count > len(files):
-            logger.error("val_count must be between 1 and number of files")
-            return
-        val_size = val_count
-    else:
-        val_size = math.ceil(len(files) * val_ratio)
-
-    logger.info(f"Using {val_size} files for validation")
-
-    with open(root / "vq_train_filelist.txt", "w", encoding="utf-8") as f:
-        f.write("\n".join(files[val_size:]))
-
-    with open(root / "vq_val_filelist.txt", "w", encoding="utf-8") as f:
-        f.write("\n".join(files[:val_size]))
-
-    logger.info("Done")
-
-
-if __name__ == "__main__":
-    main()
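
A quick check of the default split rule in the deleted script above, as a minimal sketch (the corpus size is illustrative): with neither --val-ratio nor --val-count given, validation takes 20% of the files, capped at 100.

    import math

    n_files = 1000                                 # illustrative corpus size
    val_size = min(100, math.ceil(n_files * 0.2))  # default rule: min(20%, 100)
    assert val_size == 100                         # 20% would be 200, capped at 100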

tools/vqgan/extract_vq.py DELETED
@@ -1,213 +0,0 @@
-import os
-import subprocess as sp
-import sys
-import time
-from datetime import timedelta
-from functools import lru_cache
-from pathlib import Path
-from random import Random
-
-import click
-import numpy as np
-import torch
-import torchaudio
-from hydra import compose, initialize
-from hydra.utils import instantiate
-from lightning import LightningModule
-from loguru import logger
-from omegaconf import OmegaConf
-
-from fish_speech.utils.file import AUDIO_EXTENSIONS, list_files, load_filelist
-
-# register eval resolver
-OmegaConf.register_new_resolver("eval", eval)
-# This file is used to extract VQ indices from audio files using the VQ-GAN encoder.
-# It's mainly used to generate the training data for the text2semantic model.
-
-
-RANK = int(os.environ.get("SLURM_PROCID", 0))
-WORLD_SIZE = int(os.environ.get("SLURM_NTASKS", 1))
-
-logger_format = (
-    "<green>{time:YYYY-MM-DD HH:mm:ss.SSS}</green> | "
-    "<level>{level: <8}</level> | "
-    "<cyan>{name}</cyan>:<cyan>{function}</cyan>:<cyan>{line}</cyan> | "
-    "{extra[rank]} - <level>{message}</level>"
-)
-logger.configure(extra={"rank": f"RANK: {RANK} / {WORLD_SIZE}"})
-logger.remove()
-logger.add(sys.stderr, format=logger_format)
-
-
-@lru_cache(maxsize=1)
-def get_model(
-    config_name: str = "vqgan_pretrain",
-    checkpoint_path: str = "checkpoints/vqgan/step_000380000.ckpt",
-):
-    with initialize(version_base="1.3", config_path="../../fish_speech/configs"):
-        cfg = compose(config_name=config_name)
-
-    model: LightningModule = instantiate(cfg.model)
-    state_dict = torch.load(
-        checkpoint_path,
-        map_location=model.device,
-    )
-    if "state_dict" in state_dict:
-        state_dict = state_dict["state_dict"]
-
-    model.load_state_dict(state_dict, strict=False)
-    model.eval()
-    model.cuda()
-
-    logger.info("Loaded model")
-    return model
-
-
-@torch.inference_mode()
-def process_batch(files: list[Path], model) -> float:
-    wavs = []
-    audio_lengths = []
-    new_files = []
-    max_length = total_time = 0
-
-    for file in files:
-        try:
-            wav, sr = torchaudio.load(
-                str(file), backend="sox"
-            )  # Need to install libsox-dev
-        except Exception as e:
-            logger.error(f"Error reading {file}: {e}")
-            continue
-
-        if wav.shape[0] > 1:
-            wav = wav.mean(dim=0, keepdim=True)
-
-        wav = torchaudio.functional.resample(wav.cuda(), sr, model.sampling_rate)[0]
-        total_time += len(wav) / model.sampling_rate
-        max_length = max(max_length, len(wav))
-
-        wavs.append(wav)
-        audio_lengths.append(len(wav))
-        new_files.append(file)
-
-    files = new_files
-
-    # Pad to max length
-    for i, wav in enumerate(wavs):
-        wavs[i] = torch.nn.functional.pad(wav, (0, max_length - len(wav)), "constant")
-
-    audios = torch.stack(wavs, dim=0)[:, None]
-    audio_lengths = torch.tensor(audio_lengths, device=model.device, dtype=torch.long)
-
-    # Calculate lengths
-    indices, feature_lengths = model.encode(audios, audio_lengths)
-
-    # Save to disk
-    outputs = indices.cpu().numpy()
-
-    for file, length, feature, audio_length in zip(
-        files, feature_lengths, outputs, audio_lengths
-    ):
-        feature = feature[:, :length]
-
-        # (T,)
-        with open(file.with_suffix(".npy"), "wb") as f:
-            np.save(f, feature)
-
-    return total_time
-
-
-@click.command()
-@click.argument("folder")
-@click.option("--num-workers", default=1)
-@click.option("--config-name", default="vqgan_pretrain")
-@click.option(
-    "--checkpoint-path",
-    default="checkpoints/vq-gan-group-fsq-8x1024-wn-20x768-30kh.pth",
-)
-@click.option("--batch-size", default=64)
-@click.option("--filelist", default=None, type=Path)
-def main(
-    folder: str,
-    num_workers: int,
-    config_name: str,
-    checkpoint_path: str,
-    batch_size: int,
-    filelist: Path,
-):
-    if num_workers > 1 and WORLD_SIZE != num_workers:
-        assert WORLD_SIZE == 1, "You should either use SLURM or this launcher, not both"
-
-        logger.info(f"Spawning {num_workers} workers")
-
-        visible_devices = os.environ.get("CUDA_VISIBLE_DEVICES", None)
-        if visible_devices is None:
-            visible_devices = list(range(torch.cuda.device_count()))
-        else:
-            visible_devices = visible_devices.split(",")
-
-        processes = []
-        for i in range(num_workers):
-            env = os.environ.copy()
-            env["CUDA_VISIBLE_DEVICES"] = str(visible_devices[i % len(visible_devices)])
-            env["SLURM_PROCID"] = str(i)
-            env["SLURM_NTASKS"] = str(num_workers)
-
-            processes.append(
-                sp.Popen(
-                    [sys.executable] + sys.argv.copy(),
-                    env=env,
-                )
-            )
-
-        for p in processes:
-            p.wait()
-
-        logger.info("All workers finished")
-        return
-
-    # This is a worker
-    logger.info("Starting worker")
-    if filelist:
-        files = [i[0] for i in load_filelist(filelist)]
-    else:
-        files = list_files(folder, AUDIO_EXTENSIONS, recursive=True, sort=False)
-
-    logger.info(f"Found {len(files)} files")
-    # files = [Path(f) for f in files if not Path(f).with_suffix(".npy").exists()]
-
-    total_files = len(files)
-    files = files[RANK::WORLD_SIZE]
-    logger.info(f"Processing {len(files)}/{total_files} files")
-
-    # Batch processing
-    total_time = 0
-    begin_time = time.time()
-    processed_files = 0
-    model = get_model(config_name, checkpoint_path)
-
-    for n_batch, idx in enumerate(range(0, len(files), batch_size)):
-        batch = files[idx : idx + batch_size]
-        batch_time = process_batch(batch, model)
-
-        total_time += batch_time
-        processed_files += len(batch)
-
-        if (n_batch + 1) % 10 == 0:
-            eta = (
-                (time.time() - begin_time)
-                / processed_files
-                * (len(files) - processed_files)
-            )
-            logger.info(
-                f"Processed {processed_files} files, {total_time / 3600:.2f} hours of audio, "
-                + f"ETA: {timedelta(seconds=round(eta))}"
-            )
-
-    logger.info(
-        f"Finished processing {len(files)} files, {total_time / 3600:.2f} hours of audio"
-    )
-
-
-if __name__ == "__main__":
-    main()
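
The files[RANK::WORLD_SIZE] stride in the deleted script above is how each worker claims its shard without any coordination; a minimal sketch with two workers (the file list is illustrative):

    files = ["a.wav", "b.wav", "c.wav", "d.wav", "e.wav"]  # illustrative list
    # RANK 0 of WORLD_SIZE 2 takes indices 0, 2, 4; RANK 1 takes 1, 3
    assert files[0::2] == ["a.wav", "c.wav", "e.wav"]
    assert files[1::2] == ["b.wav", "d.wav"]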

tools/whisper_asr.py DELETED
@@ -1,113 +0,0 @@
-"""
-Used to transcribe all audio files in one folder into another folder.
-e.g.
-Directory structure:
---pre_data_root
-----SP_1
-------01.wav
-------02.wav
-------......
-----SP_2
-------01.wav
-------02.wav
-------......
-Use
-python tools/whisper_asr.py --audio_dir pre_data_root/SP_1 --save_dir data/SP_1
-to transcribe the first speaker.
-
-Use
-python tools/whisper_asr.py --audio_dir pre_data_root/SP_2 --save_dir data/SP_2
-to transcribe the second speaker.
-
-Note: Be aware of your audio sample rate, which defaults to 44.1kHz.
-"""
-from pathlib import Path
-
-import click
-import librosa
-import soundfile as sf
-import whisper
-from loguru import logger
-from merge_asr_files import merge_and_delete_files
-from tqdm import tqdm
-
-from fish_speech.utils.file import AUDIO_EXTENSIONS, list_files
-
-
-@click.command()
-@click.option("--model-size", default="large", help="Size of the Whisper model")
-@click.option("--audio-dir", required=True, help="Directory containing audio files")
-@click.option(
-    "--save-dir", required=True, help="Directory to save processed audio files"
-)
-@click.option(
-    "--sample-rate",
-    default=None,
-    type=int,
-    help="Output sample rate, default to input sample rate",
-)
-@click.option("--device", default="cuda", help="Device to use")
-@click.option("--language", default="ZH", help="Language of the transcription")
-def main(model_size, audio_dir, save_dir, sample_rate, device, language):
-    logger.info("Loading / Downloading OpenAI Whisper model...")
-    model = whisper.load_model(
-        name=model_size,
-        device=device,
-        download_root=str(Path(".cache/whisper").resolve()),
-    )
-    logger.info("Model loaded.")
-
-    save_path = Path(save_dir)
-    save_path.mkdir(parents=True, exist_ok=True)
-    original_files = []
-    audio_files = list_files(
-        path=audio_dir, extensions=AUDIO_EXTENSIONS, recursive=True
-    )
-    for file_path in tqdm(audio_files, desc="Processing audio file"):
-        file_stem = file_path.stem
-        file_suffix = file_path.suffix
-
-        rel_path = Path(file_path).relative_to(audio_dir)
-        (save_path / rel_path.parent).mkdir(parents=True, exist_ok=True)
-
-        if (save_path / rel_path.parent / f"{rel_path.stem}.wav").exists() and (
-            save_path / rel_path.parent / f"{rel_path.stem}.lab"
-        ).exists():
-            continue
-
-        audio, sr = librosa.load(file_path, sr=sample_rate, mono=False)
-        transcription = model.transcribe(str(file_path), language=language)
-
-        for segment in transcription.get("segments", []):
-            id, text, start, end = (
-                segment["id"],
-                segment["text"],
-                segment["start"],
-                segment["end"],
-            )
-
-            extract = audio[..., int(start * sr) : int(end * sr)]
-            audio_save_path = (
-                save_path / rel_path.parent / f"{file_stem}-{id}{file_suffix}"
-            )
-            sf.write(
-                audio_save_path,
-                extract,
-                samplerate=sr,
-            )
-            original_files.append(audio_save_path)
-
-            transcript_save_path = save_path / rel_path.parent / f"{file_stem}-{id}.lab"
-            with open(
-                transcript_save_path,
-                "w",
-                encoding="utf-8",
-            ) as f:
-                f.write(text)
-            original_files.append(transcript_save_path)
-
-    merge_and_delete_files(save_dir, original_files)
-
-
-if __name__ == "__main__":
-    main()
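
The per-segment slicing in the deleted script above turns Whisper's second-based timestamps into sample indices on the loaded waveform; a minimal sketch, assuming a 44.1 kHz mono signal (all values are illustrative):

    import numpy as np

    sr = 44100                     # illustrative sample rate
    audio = np.zeros(10 * sr)      # 10 seconds of silence
    start, end = 1.25, 3.5         # a Whisper segment, in seconds
    extract = audio[..., int(start * sr) : int(end * sr)]
    assert extract.shape[-1] == int(end * sr) - int(start * sr)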

