import warnings
warnings.filterwarnings("ignore")
import random
import torchaudio
# from torch._six import string_classes
import collections
import re
import numpy as np
from transformers import AutoTokenizer, logging
# The model code is importable either as a top-level `models` package (when
# this file is run from its own directory) or relative to the enclosing
# package. Only an import failure should trigger the fallback; any other
# error raised while importing the model modules must surface unchanged.
try:
    from models.clap import CLAP
    from models.mapper import get_clapcap
except ImportError:
    from .models.clap import CLAP
    from .models.mapper import get_clapcap
import math
import torchaudio.transforms as T
import os
import torch
from importlib_resources import files
import argparse
import yaml
import sys
logging.set_verbosity_error()


class CLAPWrapper():
    """
    A class for interfacing with the CLAP and CLAPCap models.
    """

    def __init__(self, model_fp, config_root, version, use_cuda=False):
        """Resolve the config for ``version`` and load the matching model.

        Args:
            model_fp: Path of the checkpoint file handed to ``torch.load``.
            config_root: Directory that contains ``config_<version>.yml``.
            version: One of ``self.supported_versions``.
            use_cuda: When true (and CUDA is available) the model and the
                preprocessed inputs are moved to the GPU.

        Raises:
            ValueError: If ``version`` is not a supported version.
        """
        # Constants consulted by get_config_path and default_collate.
        self.supported_versions = ['2022', '2023', 'clapcap']
        self.np_str_obj_array_pattern = re.compile(r'[SaUO]')
        self.default_collate_err_msg_format = (
            "default_collate: batch must contain tensors, numpy arrays, numbers, "
            "dicts or lists; found {}")
        self.file_path = os.path.realpath(__file__)

        # The config path is validated first so an unsupported version fails
        # before any model loading is attempted.
        self.config_root = config_root
        self.config_as_str = self.get_config_path(version)
        self.model_fp = model_fp
        self.use_cuda = use_cuda
        self.version = version

        # Captioning checkpoints are stored on `self.clapcap`; plain CLAP
        # checkpoints on `self.clap`.
        if 'clapcap' not in self.version:
            self.clap, self.tokenizer, self.args = self.load_clap()
        else:
            self.clapcap, self.tokenizer, self.args = self.load_clapcap()
    
    def get_config_path(self, version):
        """Return the path of the YAML config file for ``version``.

        The path is ``<config_root>/config_<version>.yml``.

        Raises:
            ValueError: If ``version`` is not in ``self.supported_versions``.
        """
        if version not in self.supported_versions:
            raise ValueError(f"The specific version is not supported. The supported versions are {str(self.supported_versions)}")
        return f"{self.config_root}/config_{version}.yml"
    
    def read_config_as_args(self,config_path,args=None,is_config_str=False):
        return_dict = {}

        if config_path is not None:
            if is_config_str:
                yml_config = yaml.load(config_path, Loader=yaml.FullLoader)
            else:
                with open(config_path, "r") as f:
                    yml_config = yaml.load(f, Loader=yaml.FullLoader)

            if args != None:
                for k, v in yml_config.items():
                    if k in args.__dict__:
                        args.__dict__[k] = v
                    else:
                        sys.stderr.write("Ignored unknown parameter {} in yaml.\n".format(k))
            else:
                for k, v in yml_config.items():
                    return_dict[k] = v

        args = args if args != None else return_dict
        return argparse.Namespace(**args)

    def load_clap(self):
        r"""Build the CLAP model from the config file and load its weights.

        Also sets ``self.token_keys`` according to the configured text model.

        Returns:
            Tuple ``(clap, tokenizer, args)``: the model in eval mode (on GPU
            when requested and available), the text tokenizer, and the parsed
            config namespace.
        """
        args = self.read_config_as_args(self.config_as_str, is_config_str=False)
        text_model = args.text_model

        # Which tokenizer outputs get moved/reshaped in preprocess_text.
        if any(tag in text_model for tag in ('roberta', 'clip', 'gpt')):
            self.token_keys = ['input_ids', 'attention_mask']
        elif 'bert' in text_model:
            self.token_keys = ['input_ids', 'token_type_ids', 'attention_mask']

        model_kwargs = dict(
            audioenc_name=args.audioenc_name,
            sample_rate=args.sampling_rate,
            window_size=args.window_size,
            hop_size=args.hop_size,
            mel_bins=args.mel_bins,
            fmin=args.fmin,
            fmax=args.fmax,
            classes_num=args.num_classes,
            out_emb=args.out_emb,
            text_model=text_model,
            transformer_embed_dim=args.transformer_embed_dim,
            d_proj=args.d_proj,
        )
        clap = CLAP(**model_kwargs)

        # The checkpoint stores the weights under the 'model' key. They are
        # loaded on CPU first; a checkpoint saved from a still-wrapped DDP
        # model would need unwrapping before `load_state_dict`, see
        # https://discuss.pytorch.org/t/how-to-load-dataparallel-model-which-trained-using-multiple-gpus/146005
        checkpoint = torch.load(self.model_fp, map_location=torch.device('cpu'))
        clap.load_state_dict(checkpoint['model'])
        clap.eval()

        tokenizer = AutoTokenizer.from_pretrained(text_model)
        if 'gpt' in text_model:
            # A pad token is registered explicitly for GPT text models.
            tokenizer.add_special_tokens({'pad_token': '!'})

        if self.use_cuda and torch.cuda.is_available():
            clap = clap.cuda()

        return clap, tokenizer, args
    
    def load_clapcap(self):
        r"""Build the CLAPCap captioning model from the config file and load its weights.

        The config's ``text_model`` is used for the CLAP text encoder, after
        which ``args.text_model`` is overwritten with ``args.text_decoder`` so
        that the tokenizer matches the caption decoder. Also sets
        ``self.token_keys``.

        Returns:
            Tuple ``(clapcap, tokenizer, args)``: the captioning model in eval
            mode (on GPU when requested and available), the decoder tokenizer,
            and the adjusted config namespace.
        """
        args = self.read_config_as_args(self.config_as_str, is_config_str=False)
        args.prefix_dim = args.d_proj
        encoder_text_model = args.text_model
        args.text_model = args.text_decoder
        args.cross_attention = 'cross' in args.clapcap_model.lower()

        decoder_name = args.text_model
        if any(tag in decoder_name for tag in ('roberta', 'clip', 'gpt')):
            self.token_keys = ['input_ids', 'attention_mask']
        elif 'bert' in decoder_name:
            self.token_keys = ['input_ids', 'token_type_ids', 'attention_mask']

        model_kwargs = dict(
            audioenc_name=args.audioenc_name,
            sample_rate=args.sampling_rate,
            window_size=args.window_size,
            hop_size=args.hop_size,
            mel_bins=args.mel_bins,
            fmin=args.fmin,
            fmax=args.fmax,
            classes_num=args.num_classes,
            out_emb=args.out_emb,
            text_model=encoder_text_model,
            transformer_embed_dim=args.transformer_embed_dim,
            d_proj=args.d_proj,
        )
        clap = CLAP(**model_kwargs)

        # get_clapcap returns the captioning model class for the configured
        # variant; the two trailing True flags are passed positionally.
        clapcap_cls = get_clapcap(args.clapcap_model)
        clapcap = clapcap_cls(
            clap, args.text_decoder, args.prefix_length, args.prefix_length_clip,
            args.prefix_dim, args.num_layers, args.normalize_prefix,
            args.mapping_type, True, True)

        # Weights live under the 'model' key and are loaded on CPU first.
        checkpoint = torch.load(self.model_fp, map_location=torch.device('cpu'))
        clapcap.load_state_dict(checkpoint['model'])
        clapcap.eval()

        tokenizer = AutoTokenizer.from_pretrained(decoder_name)
        if 'gpt' in decoder_name:
            tokenizer.add_special_tokens({'pad_token': '!'})

        if self.use_cuda and torch.cuda.is_available():
            clapcap = clapcap.cuda()

        return clapcap, tokenizer, args

    def default_collate(self, batch):
        r"""Puts each data field into a tensor with outer dimension batch size"""
        elem = batch[0]
        elem_type = type(elem)
        if isinstance(elem, torch.Tensor):
            out = None
            if torch.utils.data.get_worker_info() is not None:
                # If we're in a background process, concatenate directly into a
                # shared memory tensor to avoid an extra copy
                numel = sum([x.numel() for x in batch])
                storage = elem.storage()._new_shared(numel)
                out = elem.new(storage)
            return torch.stack(batch, 0, out=out)
        elif elem_type.__module__ == 'numpy' and elem_type.__name__ != 'str_' \
                and elem_type.__name__ != 'string_':
            if elem_type.__name__ == 'ndarray' or elem_type.__name__ == 'memmap':
                # array of string classes and object
                if self.np_str_obj_array_pattern.search(elem.dtype.str) is not None:
                    raise TypeError(
                        self.default_collate_err_msg_format.format(elem.dtype))

                return self.default_collate([torch.as_tensor(b) for b in batch])
            elif elem.shape == ():  # scalars
                return torch.as_tensor(batch)
        elif isinstance(elem, float):
            return torch.tensor(batch, dtype=torch.float64)
        elif isinstance(elem, int):
            return torch.tensor(batch)
        # elif isinstance(elem, string_classes):
        #     return batch
        elif isinstance(elem, collections.abc.Mapping):
            return {key: self.default_collate([d[key] for d in batch]) for key in elem}
        elif isinstance(elem, tuple) and hasattr(elem, '_fields'):  # namedtuple
            return elem_type(*(self.default_collate(samples) for samples in zip(*batch)))
        elif isinstance(elem, collections.abc.Sequence):
            # check to make sure that the elements in batch have consistent size
            it = iter(batch)
            elem_size = len(next(it))
            if not all(len(elem) == elem_size for elem in it):
                raise RuntimeError(
                    'each element in list of batch should be of equal size')
            transposed = zip(*batch)
            return [self.default_collate(samples) for samples in transposed]

        raise TypeError(self.default_collate_err_msg_format.format(elem_type))
    
    def read_audio(self, audio_path, resample=False):
        r"""Load an audio file and optionally resample it to the model rate.

        Args:
            audio_path: Path handed to ``torchaudio.load``.
            resample: When true, resample to ``self.args.sampling_rate``.

        Returns:
            Tuple ``(audio_time_series, sample_rate)`` where ``sample_rate``
            is the rate of the returned waveform: the model rate when
            ``resample`` is true, otherwise the file's own rate.
        """
        audio_time_series, sample_rate = torchaudio.load(audio_path)

        if resample:
            resample_rate = self.args.sampling_rate
            # Skip the resampler when the file is already at the model rate.
            if resample_rate != sample_rate:
                resampler = T.Resample(sample_rate, resample_rate)
                audio_time_series = resampler(audio_time_series)
            # Report the rate of the waveform actually returned so callers
            # compute durations in the right units.
            sample_rate = resample_rate
        return audio_time_series, sample_rate

    def load_audio_into_tensor(self, audio_path, audio_duration, resample=False):
        r"""Load an audio file and return exactly ``audio_duration`` seconds of raw audio.

        Clips shorter than the target are repeated and truncated; longer
        clips are cropped at a random offset.

        Args:
            audio_path: Path of the audio file.
            audio_duration: Target duration in seconds.
            resample: When true, the audio is resampled to
                ``self.args.sampling_rate`` before cropping/padding.

        Returns:
            A 1-D float tensor with ``int(audio_duration * sample_rate)`` samples.
        """
        audio_time_series, sample_rate = self.read_audio(audio_path, resample=resample)
        if resample:
            # After resampling the waveform is at the model rate, so the
            # target length must be computed with that rate.
            sample_rate = self.args.sampling_rate
        # NOTE(review): reshape(-1) flattens every dimension of the loaded
        # tensor; presumably mono input is expected - confirm for multi-channel files.
        audio_time_series = audio_time_series.reshape(-1)

        target_len = int(audio_duration * sample_rate)
        num_samples = audio_time_series.shape[0]

        if target_len >= num_samples:
            # Clip is shorter than (or equal to) the target: tile it until it
            # is long enough, then drop the excess.
            repeat_factor = int(np.ceil(target_len / num_samples))
            audio_time_series = audio_time_series.repeat(repeat_factor)
            audio_time_series = audio_time_series[0:target_len]
        else:
            # Clip is longer than the target: take a random window.
            start_index = random.randrange(num_samples - target_len)
            audio_time_series = audio_time_series[start_index:start_index + target_len]
        return torch.FloatTensor(audio_time_series)
    
    # modified by Kong
    def load_audio_clip_into_tensor(self, audio_clip, audio_duration, resample=False):
        r"""Loads audio clip and returns raw audio."""
        # Randomly sample a segment of audio_duration from the clip or pad to match duration
        sample_rate = 44100
        audio_time_series = audio_clip.reshape(-1)

        # audio_time_series is shorter than predefined audio duration,
        # so audio_time_series is extended
        assert audio_duration * sample_rate >= audio_time_series.shape[0], \
            'dur * sr = {} should be larger than len = {}'.format(audio_duration * sample_rate, audio_time_series.shape[0])
        repeat_factor = int(np.ceil((audio_duration*sample_rate) /
                                    audio_time_series.shape[0]))
        # Repeat audio_time_series by repeat_factor to match audio_duration
        audio_time_series = audio_time_series.repeat(repeat_factor)
        # remove excess part of audio_time_series
        audio_time_series = audio_time_series[0:audio_duration*sample_rate]

        # return torch.FloatTensor(audio_time_series)
        return audio_time_series  # already on cuda device
    
    def preprocess_audio(self, audio_files, resample):
        r"""Load every file in ``audio_files`` and collate the raw audio into one batch.

        Each file is loaded with ``load_audio_into_tensor`` for
        ``self.args.duration`` seconds, reshaped to ``(1, -1)`` and moved to
        the GPU when CUDA is requested and available.
        """
        def _load(audio_file):
            waveform = self.load_audio_into_tensor(
                audio_file, self.args.duration, resample).reshape(1, -1)
            if self.use_cuda and torch.cuda.is_available():
                waveform = waveform.cuda()
            return waveform

        return self.default_collate([_load(audio_file) for audio_file in audio_files])
    
    # modified by Kong
    def preprocess_audio_clips(self, audio_clips, resample=False):
        r"""Load list of audio clips and return raw audio"""
        audio_tensors = []
        for audio_clip in audio_clips:
            audio_tensor = self.load_audio_clip_into_tensor(
                audio_clip, self.args.duration, resample=False)
            audio_tensor = audio_tensor.reshape(
                1, -1).cuda() if self.use_cuda and torch.cuda.is_available() else audio_tensor.reshape(1, -1)
            audio_tensors.append(audio_tensor)
        return self.default_collate(audio_tensors)

    def preprocess_text(self, text_queries):
        r"""Tokenize every query in ``text_queries`` and collate the results into one batch.

        For GPT text models ``' <|endoftext|>'`` is appended to each query.
        Every tokenizer output listed in ``self.token_keys`` is flattened and
        moved to the GPU when CUDA is requested and available.
        """
        encoded_batch = []
        for query in text_queries:
            text = query + ' <|endoftext|>' if 'gpt' in self.args.text_model else query
            encoded = self.tokenizer.encode_plus(
                text=text,
                add_special_tokens=True,
                max_length=self.args.text_len,
                padding='max_length',
                return_tensors="pt")
            for key in self.token_keys:
                flat = encoded[key].reshape(-1)
                if self.use_cuda and torch.cuda.is_available():
                    flat = flat.cuda()
                encoded[key] = flat
            encoded_batch.append(encoded)
        return self.default_collate(encoded_batch)

    def get_text_embeddings(self, class_labels):
        r"""Tokenize ``class_labels`` and return their text embeddings."""
        return self._get_text_embeddings(self.preprocess_text(class_labels))

    def get_audio_embeddings(self, audio_files, resample):
        r"""Load the files in ``audio_files`` and return their audio embeddings."""
        return self._get_audio_embeddings(self.preprocess_audio(audio_files, resample))
    
    # modified by Kong
    def get_audio_embeddings_from_clips(self, audio_clips, resample=False):
        r"""Load list of audio files and return a audio embeddings"""
        preprocessed_audio = self.preprocess_audio_clips(audio_clips, resample)
        return self._get_audio_embeddings(preprocessed_audio)

    def _get_text_embeddings(self, preprocessed_text):
        r"""Load preprocessed text and return text embeddings"""
        with torch.no_grad():
            return self.clap.caption_encoder(preprocessed_text)

    # modified by Kong
    def _get_audio_embeddings(self, preprocessed_audio):
        r"""Load preprocessed audio and return a audio embeddings"""
        with torch.no_grad():
            preprocessed_audio = preprocessed_audio.reshape(
                preprocessed_audio.shape[0], preprocessed_audio.shape[2])
            #Append [0] the audio emebdding, [1] has output class probabilities
            if 'clapcap' in self.version:
                return self.clapcap.clap(preprocessed_audio)[0]
            else:
                return self.clap.audio_encoder(preprocessed_audio)[0]

    def _generic_batch_inference(self, func, *args):
        r"""Process audio and/or text per batch.

        Generator that slices the first positional input into chunks of
        ``batch_size`` and yields ``func`` applied to each chunk, including
        the final (possibly shorter) one.

        Args:
            func: Callable applied to every batch.
            *args: Either ``(items, batch_size)`` or
                ``(audio_files, class_labels, batch_size)``. In the second
                form the text embeddings of ``class_labels`` are computed once
                and ``func`` is called as
                ``func(batch, class_labels, text_embeddings)``.

        Yields:
            The result of ``func`` for each batch, in order.

        Raises:
            ValueError: If ``batch_size`` is not positive.
        """
        items = args[0]
        batch_size = args[-1]
        if batch_size <= 0:
            raise ValueError("batch_size must be a positive integer")

        extra_inputs = []
        if len(args) == 3:
            # compute text_embeddings once for all the audio_files batches
            class_labels = args[1]
            extra_inputs = [class_labels, self.get_text_embeddings(class_labels)]

        # Every batch is yielded. A `return value` inside a generator only
        # ends the iteration, so the last batch would otherwise be lost.
        for start in range(0, len(items), batch_size):
            yield func(items[start:start + batch_size], *extra_inputs)

    def get_audio_embeddings_per_batch(self, audio_files, batch_size, resample=True):
        r"""Load audio files batch by batch and return a generator of audio embeddings.

        Args:
            audio_files: List of audio file paths.
            batch_size: Number of files per batch.
            resample: Forwarded to ``get_audio_embeddings`` for every batch.

        Returns:
            A generator producing the embeddings of each batch.
        """
        # get_audio_embeddings requires `resample`, but the batching helper
        # calls its callback with the batch only, so bind the flag here.
        def embed_batch(batch_files):
            return self.get_audio_embeddings(batch_files, resample)

        return self._generic_batch_inference(embed_batch, audio_files, batch_size)

    def get_text_embeddings_per_batch(self, class_labels, batch_size):
        r"""Tokenize ``class_labels`` batch by batch and return a generator of text embeddings."""
        embed_batch = self.get_text_embeddings
        return self._generic_batch_inference(embed_batch, class_labels, batch_size)

    def compute_similarity(self, audio_embeddings, text_embeddings):
        r"""Compute similarity between text and audio embeddings"""
        audio_embeddings = audio_embeddings/torch.norm(audio_embeddings, dim=-1, keepdim=True)
        text_embeddings = text_embeddings/torch.norm(text_embeddings, dim=-1, keepdim=True)
    
        logit_scale = self.clap.logit_scale.exp()
        similarity = logit_scale*text_embeddings @ audio_embeddings.T
        return similarity.T

    def classify_audio_files_per_batch(self, audio_files, class_labels, batch_size):
        r"""Compute classification probabilities for each audio recording in a batch and each class label.

        Delegates to ``_generic_batch_inference`` with
        ``self.classify_audio_files`` as the per-batch callback.
        """
        # NOTE(review): `classify_audio_files` is not defined in the visible
        # part of this class - confirm it exists before relying on this method.
        classify_batch = self.classify_audio_files
        return self._generic_batch_inference(
            classify_batch, audio_files, class_labels, batch_size)
    
    def generate_caption(self, audio_files, resample=True, beam_size: int = 5, entry_length=67, temperature=1.):
        r"""Generate one caption per audio file using beam search.

        Args:
            audio_files: List of audio file paths.
            resample: Forwarded to ``preprocess_audio``.
            beam_size: Number of beams kept during decoding.
            entry_length: Maximum number of decoding steps.
            temperature: Divisor applied to the decoder logits.

        Returns:
            List with the best-scoring caption (capitalized) for each file.
        """
        audio_tensors = self.preprocess_audio(audio_files, resample)
        captions = []

        with torch.no_grad():
            # Element [0] of the CLAP output is used as the caption prefix.
            prefix = self.clapcap.clap(audio_tensors.squeeze(1))[0]
            if self.args.normalize_prefix:
                prefix = prefix / prefix.norm(2, -1).reshape(-1, 1)

            # Project the prefix and split it into `prefix_length` vectors of
            # the decoder's token-embedding width.
            projected = self.clapcap.clap_project(prefix)
            embed_dim = self.clapcap.gpt.transformer.wte.weight.shape[1]
            prefix_embed = projected.view(-1, self.args.prefix_length, embed_dim)

            for idx in range(len(audio_tensors)):
                candidates = self._generate_beam(
                    embed=prefix_embed[idx].unsqueeze(0),
                    beam_size=beam_size,
                    entry_length=entry_length,
                    temperature=temperature)
                # Candidates come back sorted best-first; keep the top one.
                captions.append(candidates[0].capitalize())
        return captions
    
    def _generate_beam(self, beam_size: int = 5, prompt=None, embed=None,
                  entry_length=67, temperature=1., stop_token: str = ' <|endoftext|>'):
        r"""Generate captions by beam search decoding.

        Args:
            beam_size: Number of beams kept at every step.
            prompt: Text prompt, tokenized and embedded when ``embed`` is None.
            embed: Pre-computed prefix embedding used as the decoder input.
            entry_length: Maximum number of decoding steps.
            temperature: Divisor applied to the logits; non-positive values
                are treated as 1.0.
            stop_token: Text whose first encoded token id ends a beam.

        Returns:
            List of decoded texts, one per beam, sorted by length-normalized
            score in descending order.
        """
        # NOTE(review): when both `embed` and `prompt` are None the tokenizer
        # is called with None - presumably callers always pass one of them.
        self.clapcap.eval()
        # Only the first token id of the encoded stop token is used.
        stop_token_index = self.tokenizer.encode(stop_token)[0]
        tokens = None
        scores = None
        device = next(self.clapcap.parameters()).device
        # Per-beam bookkeeping: length of each beam and whether it has stopped.
        seq_lengths = torch.ones(beam_size, device=device)
        is_stopped = torch.zeros(beam_size, device=device, dtype=torch.bool)
        with torch.no_grad():
            if embed is not None:
                generated = embed
            else:
                if tokens is None:
                    tokens = torch.tensor(self.tokenizer.encode(prompt))
                    tokens = tokens.unsqueeze(0).to(device)
                    generated = self.clapcap.gpt.transformer.wte(tokens)
            for i in range(entry_length):
                outputs = self.clapcap.gpt(inputs_embeds=generated)
                logits = outputs.logits
                # Keep the logits of the last position only and turn them
                # into log-probabilities.
                logits = logits[:, -1, :] / (temperature if temperature > 0 else 1.0)
                logits = logits.softmax(-1).log()
                if scores is None:
                    # First step: seed the beams with the top `beam_size`
                    # tokens and replicate the input once per beam.
                    scores, next_tokens = logits.topk(beam_size, -1)
                    generated = generated.expand(beam_size, *generated.shape[1:])
                    next_tokens, scores = next_tokens.permute(1, 0), scores.squeeze(0)
                    if tokens is None:
                        tokens = next_tokens
                    else:
                        tokens = tokens.expand(beam_size, *tokens.shape[1:])
                        tokens = torch.cat((tokens, next_tokens), dim=1)
                else:
                    # Stopped beams may only continue with index 0 at zero
                    # cost, which leaves their accumulated score unchanged.
                    logits[is_stopped] = -float(np.inf)
                    logits[is_stopped, 0] = 0
                    scores_sum = scores[:, None] + logits
                    seq_lengths[~is_stopped] += 1
                    # Rank all (beam, token) candidates by length-normalized score.
                    scores_sum_average = scores_sum / seq_lengths[:, None]
                    scores_sum_average, next_tokens = scores_sum_average.view(-1).topk(beam_size, -1)
                    # Flattened index -> source beam (quotient) and token id
                    # (remainder) with respect to the last logits dimension.
                    next_tokens_source = next_tokens // scores_sum.shape[1]
                    seq_lengths = seq_lengths[next_tokens_source]
                    next_tokens = next_tokens % scores_sum.shape[1]
                    next_tokens = next_tokens.unsqueeze(1)
                    # Reorder all per-beam state to follow the selected beams.
                    tokens = tokens[next_tokens_source]
                    tokens = torch.cat((tokens, next_tokens), dim=1)
                    generated = generated[next_tokens_source]
                    # Undo the length normalization to store summed scores.
                    scores = scores_sum_average * seq_lengths
                    is_stopped = is_stopped[next_tokens_source]
                # Append the embedding of the chosen tokens to the decoder input.
                next_token_embed = self.clapcap.gpt.transformer.wte(next_tokens.squeeze()).view(generated.shape[0], 1, -1)
                generated = torch.cat((generated, next_token_embed), dim=1)
                # Adding the boolean masks marks beams that just emitted the
                # stop token as stopped (presumably an elementwise OR).
                is_stopped = is_stopped + next_tokens.eq(stop_token_index).squeeze()
                if is_stopped.all():
                    break
        # Final ranking uses length-normalized scores; each beam is decoded
        # only up to its recorded length.
        scores = scores / seq_lengths
        output_list = tokens.cpu().numpy()
        output_texts = [self.tokenizer.decode(output[:int(length)]) for output, length in zip(output_list, seq_lengths)]
        order = scores.argsort(descending=True)
        output_texts = [output_texts[i] for i in order]
        return output_texts