import math
from transformers.configuration_utils import PretrainedConfig
from transformers.modeling_utils import PreTrainedModel
from torch import nn

#.............................................


class VitsConfig(PretrainedConfig):
    """
    Configuration class for a VITS text-to-speech model. It stores the
    hyperparameters of the text encoder, the normalizing flows, the duration
    predictor, and the HiFi-GAN-style waveform decoder, plus the discriminator
    settings used only during training. Instantiating it with no arguments
    yields the defaults below.
    """

    model_type = "vits"

    def __init__(
        self,
        vocab_size=38,
        hidden_size=192,
        num_hidden_layers=6,
        num_attention_heads=2,
        window_size=4,
        use_bias=True,
        ffn_dim=768,
        layerdrop=0.1,
        ffn_kernel_size=3,
        flow_size=192,
        spectrogram_bins=513,
        hidden_act="relu",
        hidden_dropout=0.1,
        attention_dropout=0.1,
        activation_dropout=0.1,
        initializer_range=0.02,
        layer_norm_eps=1e-5,
        use_stochastic_duration_prediction=True,
        num_speakers=1,
        speaker_embedding_size=0,
        upsample_initial_channel=512,
        upsample_rates=[8, 8, 2, 2],
        upsample_kernel_sizes=[16, 16, 4, 4],
        resblock_kernel_sizes=[3, 7, 11],
        resblock_dilation_sizes=[[1, 3, 5], [1, 3, 5], [1, 3, 5]],
        leaky_relu_slope=0.1,
        depth_separable_channels=2,
        depth_separable_num_layers=3,
        duration_predictor_flow_bins=10,
        duration_predictor_tail_bound=5.0,
        duration_predictor_kernel_size=3,
        duration_predictor_dropout=0.5,
        duration_predictor_num_flows=4,
        duration_predictor_filter_channels=256,
        prior_encoder_num_flows=4,
        prior_encoder_num_wavenet_layers=4,
        posterior_encoder_num_wavenet_layers=16,
        wavenet_kernel_size=5,
        wavenet_dilation_rate=1,
        wavenet_dropout=0.0,
        speaking_rate=1.0,
        noise_scale=0.667,
        noise_scale_duration=0.8,
        sampling_rate=16_000,
        discriminator_kernel_size=5,
        discriminator_stride=3,
        discriminator_periods=[2, 3, 5, 7, 11],
        discriminator_period_channels=[1, 32, 128, 512, 1024],
        discriminator_scale_channels=[1, 16, 64, 256, 1024],
        segment_size=8192,
        hop_length=256,
        **kwargs,
    ):
        # text encoder (transformer) and general model dimensions
        self.vocab_size = vocab_size
        self.hidden_size = hidden_size
        self.num_hidden_layers = num_hidden_layers
        self.num_attention_heads = num_attention_heads
        self.window_size = window_size
        self.use_bias = use_bias
        self.ffn_dim = ffn_dim
        self.layerdrop = layerdrop
        self.ffn_kernel_size = ffn_kernel_size
        self.flow_size = flow_size
        self.spectrogram_bins = spectrogram_bins
        self.hidden_act = hidden_act
        self.hidden_dropout = hidden_dropout
        self.attention_dropout = attention_dropout
        self.activation_dropout = activation_dropout
        self.initializer_range = initializer_range
        self.layer_norm_eps = layer_norm_eps
        self.use_stochastic_duration_prediction = use_stochastic_duration_prediction
        self.num_speakers = num_speakers
        self.speaker_embedding_size = speaker_embedding_size
        # HiFi-GAN-style waveform decoder
        self.upsample_initial_channel = upsample_initial_channel
        self.upsample_rates = upsample_rates
        self.upsample_kernel_sizes = upsample_kernel_sizes
        self.resblock_kernel_sizes = resblock_kernel_sizes
        self.resblock_dilation_sizes = resblock_dilation_sizes
        self.leaky_relu_slope = leaky_relu_slope
        # stochastic duration predictor (depth-separable conv flows)
        self.depth_separable_channels = depth_separable_channels
        self.depth_separable_num_layers = depth_separable_num_layers
        self.duration_predictor_flow_bins = duration_predictor_flow_bins
        self.duration_predictor_tail_bound = duration_predictor_tail_bound
        self.duration_predictor_kernel_size = duration_predictor_kernel_size
        self.duration_predictor_dropout = duration_predictor_dropout
        self.duration_predictor_num_flows = duration_predictor_num_flows
        self.duration_predictor_filter_channels = duration_predictor_filter_channels
        # prior and posterior encoders (flows over WaveNet-style blocks)
        self.prior_encoder_num_flows = prior_encoder_num_flows
        self.prior_encoder_num_wavenet_layers = prior_encoder_num_wavenet_layers
        self.posterior_encoder_num_wavenet_layers = posterior_encoder_num_wavenet_layers
        self.wavenet_kernel_size = wavenet_kernel_size
        self.wavenet_dilation_rate = wavenet_dilation_rate
        self.wavenet_dropout = wavenet_dropout
        # inference / synthesis settings
        self.speaking_rate = speaking_rate
        self.noise_scale = noise_scale
        self.noise_scale_duration = noise_scale_duration
        self.sampling_rate = sampling_rate

        # used for training
        self.discriminator_kernel_size = discriminator_kernel_size
        self.discriminator_stride = discriminator_stride
        self.discriminator_periods = discriminator_periods
        self.discriminator_period_channels = discriminator_period_channels
        self.discriminator_scale_channels = discriminator_scale_channels
        self.segment_size = segment_size
        self.hop_length = hop_length

        if len(upsample_kernel_sizes) != len(upsample_rates):
            raise ValueError(
                f"The length of `upsample_kernel_sizes` ({len(upsample_kernel_sizes)}) must match the length of "
                f"`upsample_rates` ({len(upsample_rates)})"
            )

        super().__init__(**kwargs)
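
# Usage sketch (illustrative, not part of the original module): the config can
# be built with defaults, selectively overridden, and round-tripped through the
# standard `PretrainedConfig` serialization. `"./vits-demo"` is a hypothetical
# path used only for this example.
#
#     config = VitsConfig(num_speakers=4, speaker_embedding_size=256)
#     config.save_pretrained("./vits-demo")                # writes config.json
#     reloaded = VitsConfig.from_pretrained("./vits-demo")
#     assert reloaded.num_speakers == 4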

#.............................................................................................

class VitsPreTrainedModel(PreTrainedModel):
    """
    An abstract class to handle weights initialization and a simple interface for downloading and loading pretrained
    models.
    """
    config_class = VitsConfig
    base_model_prefix = "vits"
    main_input_name = "input_ids"
    supports_gradient_checkpointing = True

    def _init_weights(self, module):
        """Initialize the weights"""
        if isinstance(module, (nn.Linear, nn.Conv2d)):
            module.weight.data.normal_(mean=0.0, std=self.config.initializer_range)
            if module.bias is not None:
                module.bias.data.zero_()
        elif isinstance(module, nn.LayerNorm):
            module.bias.data.zero_()
            module.weight.data.fill_(1.0)
        elif isinstance(module, nn.Conv1d):
            nn.init.kaiming_normal_(module.weight)
            if module.bias is not None:
                # PyTorch's default Conv1d bias bound: 1 / sqrt(fan_in),
                # where fan_in = (in_channels / groups) * kernel_size
                k = math.sqrt(module.groups / (module.in_channels * module.kernel_size[0]))
                nn.init.uniform_(module.bias, a=-k, b=k)
        elif isinstance(module, nn.Embedding):
            module.weight.data.normal_(mean=0.0, std=self.config.initializer_range)
            if module.padding_idx is not None:
                module.weight.data[module.padding_idx].zero_()


#.............................................................................................
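
# Minimal, illustrative sketch (not from the original file): concrete models
# subclass `VitsPreTrainedModel`, and calling `PreTrainedModel.post_init()` in
# their `__init__` applies `_init_weights` recursively to every submodule.
# `TinyVitsModel` is a hypothetical demo class, not a real transformers model.

if __name__ == "__main__":

    class TinyVitsModel(VitsPreTrainedModel):
        def __init__(self, config: VitsConfig):
            super().__init__(config)
            self.embed_tokens = nn.Embedding(config.vocab_size, config.hidden_size)
            self.proj = nn.Linear(config.hidden_size, config.hidden_size)
            # triggers _init_weights on all submodules
            self.post_init()

    model = TinyVitsModel(VitsConfig())
    # after init: Linear weights ~ N(0, initializer_range^2), biases zero
    print(model.proj.weight.std().item(), model.proj.bias.abs().max().item())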