from dataclasses import dataclass, field
from typing import Optional, List
import os

CKPT_NAME = "model.pt"
CKPT_LOCAL_DIR = "model_ckpts"
CKPT_PATH = os.path.join(CKPT_LOCAL_DIR, CKPT_NAME)
CKPT_REPO = "xcczach/mini-omni"
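

# Hedged sketch: how the checkpoint constants above are typically consumed.
# The helper below is hypothetical (not part of the original module);
# `hf_hub_download` and its repo_id/filename/local_dir parameters are the real
# huggingface_hub API.
def download_checkpoint() -> str:
    """Download CKPT_NAME from CKPT_REPO into CKPT_LOCAL_DIR if missing."""
    from huggingface_hub import hf_hub_download

    if not os.path.exists(CKPT_PATH):
        hf_hub_download(repo_id=CKPT_REPO, filename=CKPT_NAME, local_dir=CKPT_LOCAL_DIR)
    return CKPT_PATH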


@dataclass
class VocabConfig:
    text_vocabsize: int = 151936
    text_specialtokens: int = 64
    audio_vocabsize: int = 4096
    audio_specialtokens: int = 64
    total_vocabsize: int = 181120
    code_layer: int = 7

    padded_text_vocabsize: int = field(init=False)
    padded_audio_vocabsize: int = field(init=False)
    total_audio_vocabsize: int = field(init=False)

    eot: int = field(init=False)  # end of text token
    pad_t: int = field(init=False)  # padding text token
    input_t: int = field(init=False)  # input text token
    answer_t: int = field(init=False)  # answer text token
    asr: int = field(init=False)  # ASR token

    eoa: int = field(init=False)  # end of audio token
    pad_a: int = field(init=False)  # padding audio token
    input_a: int = field(init=False)  # input audio token
    answer_a: int = field(init=False)  # answer audio token
    split: int = field(init=False)  # split token

    def __post_init__(self):
        self.padded_text_vocabsize = self.text_vocabsize + self.text_specialtokens
        self.padded_audio_vocabsize = self.audio_vocabsize + self.audio_specialtokens
        self.total_audio_vocabsize = self.padded_audio_vocabsize * self.code_layer

        self.eot = self.text_vocabsize
        self.pad_t = self.text_vocabsize + 1
        self.input_t = self.text_vocabsize + 2
        self.answer_t = self.text_vocabsize + 3
        self.asr = self.text_vocabsize + 4

        self.eoa = self.audio_vocabsize
        self.pad_a = self.audio_vocabsize + 1
        self.input_a = self.audio_vocabsize + 2
        self.answer_a = self.audio_vocabsize + 3
        self.split = self.audio_vocabsize + 4
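

# Worked example with the VocabConfig defaults (a sanity check, not extra API):
#   padded_text_vocabsize  = 151936 + 64 = 152000
#   padded_audio_vocabsize = 4096 + 64   = 4160
#   total_audio_vocabsize  = 4160 * 7    = 29120
#   total_vocabsize        = 152000 + 29120 = 181120 (matches the default above)
# Special text ids follow the text vocab (eot=151936 ... asr=151940) and
# special audio ids follow the audio vocab (eoa=4096 ... split=4100).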


@dataclass
class TTSAdapterConfig:
    add_qkv_bias: Optional[bool] = True
    bias: bool = False
    gelu_approximate: Optional[str] = None
    head_size: Optional[int] = 64
    intermediate_size: Optional[int] = 4864
    lm_head_bias: bool = False
    mlp_class_name: str = "GptNeoxMLP"
    n_layer: int = 6
    n_head: int = 14
    n_embd: int = 896
    n_query_groups: Optional[int] = 2
    norm_class_name: str = "RMSNorm"
    norm_eps: float = 1e-6
    parallel_residual: bool = False
    rotary_percentage: float = 1.0
    shared_attention_norm: bool = False

    rope_n_elem: int = field(init=False)

    def __post_init__(self):
        self.rope_n_elem = int(self.rotary_percentage * self.head_size)
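

# With the defaults above: rope_n_elem = int(1.0 * 64) = 64, i.e. full rotary
# embedding, and grouped-query attention groups n_head / n_query_groups =
# 14 / 2 = 7 query heads per KV head (a worked example, not extra configuration).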


@dataclass
class ModelConfig:
    file: str = "model/slam_model_s2s.py:model_factory"
    llm_name: str = "qwen2-0.5b"
    llm_path: str = "Qwen/Qwen2-0.5B"
    llm_type: str = "decoder_only"
    llm_dim: int = 896
    encoder_name: Optional[str] = "whisper"
    encoder_ds_rate: int = 2
    encoder_path: Optional[str] = "small"
    encoder_dim: int = 768
    encoder_projector: str = "linear"
    encoder_projector_ds_rate: int = 5
    modal: str = "audio"
    normalize: Optional[bool] = field(
        default=False,
        metadata={"help": "whether the input is normalized; used for models such as wavlm"},
    )
    encoder_type: str = field(
        default="finetune",
        metadata={
            "help": "whether the model is only pretrained or also finetuned; used for models such as hubert"
        },
    )
    vocab_config: VocabConfig = field(default_factory=VocabConfig)
    codec_decode: bool = True
    codec_decoder_type: str = "SNAC"
    codec_decoder_path: Optional[str] = "hubertsiuzdak/snac_24khz"
    tts_adapter: bool = False
    tts_adapter_config: TTSAdapterConfig = field(default_factory=TTSAdapterConfig)
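

# Worked example with the defaults above, assuming the two rates compose
# multiplicatively: audio features are downsampled in time by
# encoder_ds_rate * encoder_projector_ds_rate = 2 * 5 = 10x before the LLM,
# and the linear projector maps encoder_dim=768 to llm_dim=896.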


@dataclass
class PeftConfig:
    peft_method: str = "lora"  # alternatives: None, llama_adapter, prefix
    r: int = 8
    lora_alpha: int = 32
    target_modules: List[str] = field(default_factory=lambda: ["q_proj", "v_proj"])
    bias: str = "none"
    task_type: str = "CAUSAL_LM"
    lora_dropout: float = 0.05
    inference_mode: bool = False
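

# Hedged sketch: mapping PeftConfig onto the `peft` library. `LoraConfig` and
# the keyword arguments below are the real peft API; the helper itself is
# hypothetical and assumes `peft_method == "lora"`.
def build_lora_config(cfg: PeftConfig):
    from peft import LoraConfig

    return LoraConfig(
        r=cfg.r,
        lora_alpha=cfg.lora_alpha,
        target_modules=cfg.target_modules,
        bias=cfg.bias,
        task_type=cfg.task_type,
        lora_dropout=cfg.lora_dropout,
        inference_mode=cfg.inference_mode,
    )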


@dataclass
class TrainConfig:
    model_name: str = "s2s"
    enable_ddp: bool = False
    enable_deepspeed: bool = False
    enable_fsdp: bool = False
    low_cpu_fsdp: bool = False
    run_validation: bool = True
    batch_size_training: int = 4
    batching_strategy: str = field(
        default="custom", metadata={"help": "alternative: padding"}
    )
    context_length: int = 4096
    gradient_accumulation_steps: int = 1
    num_epochs: int = 1
    num_workers_dataloader: int = 2
    warmup_steps: int = 1000
    total_steps: int = 100000
    validation_interval: int = 1000
    lr: float = 1e-4
    weight_decay: float = 0.0
    gamma: float = 0.85
    seed: int = 42
    use_fp16: bool = False
    mixed_precision: bool = True
    val_batch_size: int = 1

    use_peft: bool = False
    peft_config: PeftConfig = field(default_factory=PeftConfig)
    output_dir: str = "PATH/to/save/PEFT/model"
    freeze_layers: bool = False
    num_freeze_layers: int = 1
    quantization: bool = False
    one_gpu: bool = False
    save_model: bool = True
    dist_checkpoint_root_folder: str = (
        "PATH/to/save/FSDP/model"  # will be used if using FSDP
    )
    dist_checkpoint_folder: str = "fine-tuned"  # will be used if using FSDP
    save_optimizer: bool = False  # will be used if using FSDP
    use_fast_kernels: bool = (
        False  # Enable SDPA from PyTorch Accelerated Transformers, which makes use of Flash Attention and xFormers memory-efficient kernels
    )
    run_test_during_validation: bool = False
    run_test_during_validation_file: str = "test.wav"
    run_test_during_validation_prompt: str = "<|S2S|>"
    freeze_llm: bool = field(
        default=True,
        metadata={
            "help": "whether to freeze the LLM when finetuning; should be True when using PEFT"
        },
    )
    freeze_encoder: bool = True
    train_embed_only: bool = False
    train_audio_embed_only: bool = False
    task_type: str = "s2s"
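

# Hedged example: the freeze_llm help text implies a consistency check such as
# the following (hypothetical helper, not part of the original module).
def validate_train_config(cfg: TrainConfig) -> None:
    if cfg.use_peft and not cfg.freeze_llm:
        raise ValueError("freeze_llm should be True when finetuning with PEFT")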


@dataclass
class DataConfig:
    dataset: str = "speech_dataset_s2s"
    file: str = "examples/s2s/speech_dataset_s2s.py:get_speech_dataset"
    train_data_path: Optional[str] = None
    val_data_path: Optional[str] = None
    train_split: str = "train"
    test_split: str = "validation"
    prompt: Optional[str] = None
    data_path: Optional[str] = None
    max_words: Optional[int] = None
    max_mel: Optional[float] = None
    fix_length_audio: int = -1
    inference_mode: bool = True
    input_type: str = field(
        default="mel",
        metadata={"help": "use 'raw' when the input is a wav waveform, 'mel' for Whisper-style mel spectrograms"},
    )
    mel_size: int = field(
        default=80, metadata={"help": "80 for Whisper models up to large-v2, 128 for large-v3"}
    )
    normalize: Optional[bool] = field(
        default=False,
        metadata={"help": "whether the input is normalized; used for models such as wavlm"},
    )
    seed: int = 42
    manifest_format: str = field(
        default="datasets", metadata={"help": "alternative: jsonl"}
    )
    split_size: float = 0.1

    vocab_config: VocabConfig = field(default_factory=VocabConfig)
    load_from_cache_file: bool = False
    task_type: str = "s2s"
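

# Hedged sketch: how manifest_format="datasets" with split_size and seed is
# typically consumed. `train_test_split` is the real Hugging Face `datasets`
# API; the helper itself and the exact loading path are assumptions.
def split_dataset(ds, cfg: DataConfig):
    # Carve a validation split of `split_size` fraction out of the train data.
    splits = ds.train_test_split(test_size=cfg.split_size, seed=cfg.seed)
    return splits["train"], splits["test"]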


@dataclass
class DecodeConfig:
    do_sample: bool = False
    max_new_tokens: int = 300
    min_length: int = 10
    temperature: float = 1.0
    top_k: int = 50
    top_p: float = 0.9
    num_beams: int = 1
    num_return_sequences: int = 1
    num_samples: int = 1
    max_time: float = 0.0
    repetition_penalty: float = 1.0
    length_penalty: float = 1.0
    early_stopping: bool = False
    no_repeat_ngram_size: int = 0
    bad_words_ids: List[List[int]] = field(default_factory=list)
    num_beam_groups: int = 1
    diversity_penalty: float = 0.0
    task_type: str = "s2s"
    decode_text_only: bool = False
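

# Hedged sketch: DecodeConfig fields map naturally onto `model.generate`
# keyword arguments from Hugging Face transformers (the kwarg names below are
# the real generation API; the helper itself is hypothetical).
def to_generate_kwargs(cfg: DecodeConfig) -> dict:
    return {
        "do_sample": cfg.do_sample,
        "max_new_tokens": cfg.max_new_tokens,
        "min_length": cfg.min_length,
        "temperature": cfg.temperature,
        "top_k": cfg.top_k,
        "top_p": cfg.top_p,
        "num_beams": cfg.num_beams,
        "num_return_sequences": cfg.num_return_sequences,
        "repetition_penalty": cfg.repetition_penalty,
        "length_penalty": cfg.length_penalty,
        "early_stopping": cfg.early_stopping,
        "no_repeat_ngram_size": cfg.no_repeat_ngram_size,
        "num_beam_groups": cfg.num_beam_groups,
        "diversity_penalty": cfg.diversity_penalty,
    }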


@dataclass
class FSDPConfig:
    mixed_precision: bool = True
    use_fp16: bool = False
    sharding_strategy: str = (
        "NO_SHARD"  # alternative: "FULL_SHARD"; set NO_SHARD when using DDP
    )
    checkpoint_type: str = (
        "SHARDED_STATE_DICT"  # saves one file per rank and allows resizing the world size; alternative: "FULL_STATE_DICT"
    )
    fsdp_activation_checkpointing: bool = True
    fsdp_cpu_offload: bool = False
    pure_bf16: bool = False
    optimizer: str = "AdamW"
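

# Hedged sketch: resolving the sharding_strategy string to the PyTorch enum.
# `ShardingStrategy` and its FULL_SHARD/NO_SHARD members are the real FSDP API;
# the helper itself is hypothetical.
def resolve_sharding_strategy(cfg: FSDPConfig):
    from torch.distributed.fsdp import ShardingStrategy

    return ShardingStrategy[cfg.sharding_strategy]  # e.g. ShardingStrategy.NO_SHARD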


@dataclass
class LogConfig:
    use_wandb: bool = False
    wandb_dir: str = "/valleblob/v-wenxichen/exp/wandb_log"
    wandb_entity_name: str = "project_name"
    wandb_project_name: str = "project_name"
    wandb_exp_name: str = "exp_name"
    log_file: str = "/valleblob/v-wenxichen/exp/log/test.log"
    log_interval: int = 10
    online_output_dir: Optional[str] = None
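

# Hedged sketch: how the wandb fields above are typically consumed.
# `wandb.init` and its entity/project/name/dir parameters are the real wandb
# API; the helper itself is hypothetical.
def init_logging(cfg: LogConfig):
    if not cfg.use_wandb:
        return None
    import wandb

    return wandb.init(
        entity=cfg.wandb_entity_name,
        project=cfg.wandb_project_name,
        name=cfg.wandb_exp_name,
        dir=cfg.wandb_dir,
    )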


@dataclass
class InferenceConfig:
    dataset_config: DataConfig = field(default_factory=DataConfig)
    model_config: ModelConfig = field(default_factory=ModelConfig)
    train_config: TrainConfig = field(default_factory=TrainConfig)
    decode_config: DecodeConfig = field(default_factory=DecodeConfig)
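

if __name__ == "__main__":
    # Minimal usage sketch: build the default inference config and inspect a
    # few derived values (illustrative only).
    cfg = InferenceConfig()
    vocab = cfg.model_config.vocab_config
    print(vocab.total_vocabsize)  # 181120
    print(vocab.eot, vocab.eoa)  # 151936 4096
    print(cfg.decode_config.max_new_tokens)  # 300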