from datetime import datetime
from miditok import Event, MIDILike
import os
import json
from time import perf_counter
from constants import DRUMS_BEAT_QUANTIZATION, NONE_DRUMS_BEAT_QUANTIZATION
from joblib import Parallel, delayed
from zipfile import ZipFile, ZIP_DEFLATED
from scipy.io.wavfile import write
import numpy as np
from pydub import AudioSegment
import shutil

""" Diverse utils"""


def index_has_substring(list, substring):
    """Return the index of the first item of *list* that contains *substring*.

    Args:
        - list: iterable of strings (anything supporting the ``in`` operator)
        - substring: value looked up in each item with ``in``
    Returns:
        - index of the first matching item, or -1 when nothing matches
    """
    # NOTE: the parameter name shadows the builtin `list`; it is kept so that
    # keyword callers (list=...) keep working.
    matches = (idx for idx, item in enumerate(list) if substring in item)
    return next(matches, -1)


# TODO: Make this singleton
def get_miditok():
    """Build a new MIDILike tokenizer with this project's settings.

    A fresh tokenizer is constructed on every call.

    Returns:
        - MIDILike(pitch_range=range(0, 127), beat_res={(0, 400): 8})
    """
    # NOTE(review): range(0, 127) stops at 126, so MIDI pitch 127 is not in the
    # vocabulary. Left as-is because changing it alters the token set.
    # The pitch range used to be range(21, 109).
    beat_resolution = {(0, 400): 8}
    return MIDILike(range(0, 127), beat_resolution)


def timeit(func):
    """Decorator that prints how long each call to *func* takes.

    The wrapped function's return value is passed through unchanged. Its
    metadata (``__name__``, ``__doc__``, ``__wrapped__``, ...) is preserved via
    functools.wraps, so decorated functions and methods stay introspectable.
    """
    import functools

    @functools.wraps(func)
    def wrapper(*args, **kwargs):
        start = perf_counter()
        result = func(*args, **kwargs)
        end = perf_counter()
        print(f"{func.__name__} took {end - start:.2f} seconds to run.")
        return result

    return wrapper


def chain(input, funcs, *params):
    """Pipe *input* through *funcs* in order, feeding each result to the next.

    Each function is first called as ``func(value, *params)``. If that call
    raises TypeError, it is retried as ``func(value)``.

    NOTE(review): the fallback also fires when the TypeError is raised inside
    ``func`` itself (not only on a signature mismatch). In that case ``func``
    runs twice and the original error is hidden.

    Args:
        - input: the initial value
        - funcs: iterable of callables applied left to right
        - *params: extra positional arguments offered to every function
    Returns:
        - the output of the last function, or *input* if *funcs* is empty
    """

    def apply_one(value, func):
        try:
            return func(value, *params)
        except TypeError:
            return func(value)

    current = input
    for step in funcs:
        current = apply_one(current, step)
    return current


def split_dots(value):
    """Parse a dot-separated string of integers, e.g. "0.4.8" -> [0, 4, 8]."""
    return [int(token) for token in value.split(".")]


def compute_list_average(l):
    """Return the arithmetic mean of the numbers in *l*.

    Accepts any iterable of numbers (list, tuple, generator, ...). Sized
    inputs are used directly; unsized iterables are materialized once.

    Raises:
        - ZeroDivisionError: if *l* is empty (same type as before, now with an
          explanatory message)
    """
    # Only copy when needed, so large sized inputs (e.g. arrays) aren't duplicated.
    values = l if hasattr(l, "__len__") else list(l)
    # len() rather than truthiness: truthiness is ambiguous for array-likes.
    if len(values) == 0:
        raise ZeroDivisionError("compute_list_average() of an empty sequence")
    return sum(values) / len(values)


def get_datetime():
    """Return the current local time as a compact "YYYYmmdd_HHMMSS" stamp."""
    now = datetime.now()
    return f"{now:%Y%m%d_%H%M%S}"


""" Encoding functions """


def int_dec_base_to_beat(beat_str):
    """
    Convert a miditok "integer.decimal.base" time string into beats.
    The middle field is a numerator over the last field,
    e.g. "0.4.8" -> 0 + 4/8 = 0.5 and "1.2.8" -> 1 + 2/8 = 1.25
    Args:
        - beat_str: "integer.decimal.base" (str, from miditok)
    Returns:
        - beats: float
    """
    whole, numerator, denominator = split_dots(beat_str)
    fraction = numerator / denominator
    return whole + fraction


def int_dec_base_to_delta(beat_str, instrument="drums"):
    """Convert a miditok time shift into a TIME_DELTA (Tristan's encoding scheme).

    Drums TIME_DELTA are quantized according to DRUMS_BEAT_QUANTIZATION.
    Other Instrument TIME_DELTA are quantized according to NONE_DRUMS_BEAT_QUANTIZATION.

    Args:
        - beat_str: "integer.decimal.base" (str, from miditok)
        - instrument: instrument name; compared case-insensitively to "drums"
    Returns:
        - time_delta: int, beats * quantization, truncated toward zero
          (not rounded)
    """
    beat_res = (
        DRUMS_BEAT_QUANTIZATION
        if instrument.lower() == "drums"
        else NONE_DRUMS_BEAT_QUANTIZATION
    )
    time_delta = int_dec_base_to_beat(beat_str) * beat_res
    # int() truncates toward zero, exactly like the former time_delta.__int__().
    return int(time_delta)


def get_text(event, instrument="drums"):
    """Render one miditok event as a midi-text token followed by a space.

    Args:
        - event: object with ``type`` and ``value`` attributes (a miditok Event)
        - instrument: only used for "Time-Shift" events, to pick the
          TIME_DELTA quantization via int_dec_base_to_delta
    Returns:
        - the token plus a trailing space, or "" for unsupported event types
    """
    kind = event.type

    # Tokens that depend only on the event type.
    fixed_tokens = {
        "Piece-Start": "PIECE_START ",
        "Track-Start": "TRACK_START ",
        "Track-End": "TRACK_END ",
        "Bar-Start": "BAR_START ",
        "Bar-End": "BAR_END ",
    }
    # Tokens rendered as NAME=<event.value>.
    valued_prefixes = {
        "Density": "DENSITY",
        "Note-On": "NOTE_ON",
        "Note-Off": "NOTE_OFF",
    }

    if kind in fixed_tokens:
        return fixed_tokens[kind]
    if kind in valued_prefixes:
        return f"{valued_prefixes[kind]}={event.value} "
    if kind == "Instrument":
        # Any spelling of "drums" is normalized to the upper-case DRUMS token.
        label = "DRUMS" if str(event.value).lower() == "drums" else event.value
        return f"INST={label} "
    if kind == "Time-Shift":
        return f"TIME_DELTA={int_dec_base_to_delta(event.value, instrument)} "
    return ""


""" Decoding functions """


def time_delta_to_beat(time_delta, instrument="drums"):
    """
    Convert a midi-text TIME_DELTA into beats (Tristan's encoding scheme).
    Args:
        - time_delta: TIME_DELTA; anything float() accepts (e.g. int or numeric str)
        - instrument: str; "drums" (any case) selects DRUMS_BEAT_QUANTIZATION,
          anything else NONE_DRUMS_BEAT_QUANTIZATION (both from constants.py)
    Returns:
        - beats: float
    """
    if instrument.lower() == "drums":
        steps_per_beat = DRUMS_BEAT_QUANTIZATION
    else:
        steps_per_beat = NONE_DRUMS_BEAT_QUANTIZATION
    return float(time_delta) / steps_per_beat


def beat_to_int_dec_base(beat, beat_res=8):
    """
    Convert a beat count into a miditok "integer.decimal.base" string.
    The fractional part is expressed in units of 1/beat_res and truncated,
    e.g. 0.5 -> "0.4.8", 1.25 -> "1.2.8".
    Args:
        - beat: number of beats (int or float)
        - beat_res: int, the base (subdivisions per beat)
    Returns:
        - "integer.decimal.base" (str)
    """
    ticks = beat * beat_res
    # divmod gives the same (floor quotient, remainder) pair as // and %.
    whole, remainder = divmod(ticks, beat_res)
    return f"{int(whole)}.{int(remainder)}.{beat_res}"


def time_delta_to_int_dec_base(time_delta, instrument="drums"):
    """Convert a midi-text TIME_DELTA into a miditok "integer.decimal.base" string.

    The TIME_DELTA is first converted to beats with the instrument-specific
    quantization (time_delta_to_beat). The beats are then expressed in
    beat_to_int_dec_base's default base of 8.

    The functions are composed directly. The previous generic chain() call
    also offered *instrument* to beat_to_int_dec_base as ``beat_res``. That
    only worked because the resulting TypeError triggered chain()'s retry.
    """
    beats = time_delta_to_beat(time_delta, instrument)
    return beat_to_int_dec_base(beats)


def get_event(text, value=None, instrument="drums"):
    """Convert one midi-text token into the equivalent miditok Event.

    Args:
        - text: token name, e.g. "NOTE_ON" or "TIME_DELTA"
        - value: token value (None for value-less tokens)
        - instrument: only used for "TIME_DELTA", to pick the quantization
          when converting back to a miditok time shift
    Returns:
        - a miditok Event, or None if *text* is not a recognised token
    """
    if text == "INST":
        # The midi-text DRUMS token maps back to miditok's "Drums" program name.
        return Event("Instrument", "Drums" if value == "DRUMS" else value)
    if text == "TIME_DELTA":
        return Event("Time-Shift", time_delta_to_int_dec_base(value, instrument))

    # Tokens whose value is forwarded unchanged.
    passthrough = {
        "PIECE_START": "Piece-Start",
        "TRACK_START": "Track-Start",
        "TRACK_END": "Track-End",
        "BAR_START": "Bar-Start",
        "BAR_END": "Bar-End",
        "TIME_SHIFT": "Time-Shift",
        "NOTE_ON": "Note-On",
        "NOTE_OFF": "Note-Off",
    }
    event_type = passthrough.get(text)
    if event_type is None:
        return None
    return Event(event_type, value)


""" File utils"""


def writeToFile(path, content):
    """Write *content* to *path*, creating any missing parent directories.

    Dicts (including subclasses such as OrderedDict) are serialized as JSON.
    Anything else is written as text via str().

    Args:
        - path: destination file path (str or os.PathLike)
        - content: dict to dump as JSON, or any object to write as text
    """
    parent = os.path.dirname(path)
    # dirname() is "" for a bare filename, and os.makedirs("") would raise.
    if parent:
        os.makedirs(parent, exist_ok=True)
    with open(path, "w") as f:
        if isinstance(content, dict):
            json.dump(content, f)
        else:
            f.write(content if isinstance(content, str) else str(content))


def readFromFile(path, isJSON=False):
    """Read the file at *path*.

    Args:
        - path: file path (str or os.PathLike)
        - isJSON: when True, parse the contents as JSON
    Returns:
        - the parsed JSON object if isJSON is True, otherwise the raw text
    """
    with open(path, "r") as handle:
        return json.load(handle) if isJSON else handle.read()


def get_files(directory, extension, recursive=False):
    """
    List the entries of *directory* whose name matches "*.<extension>".
    directory: pathlib.Path to search
    extension: extension without the leading dot, e.g. "txt"
    recursive: if True, also search sub-directories (rglob); otherwise only
        the top level (glob)
    Returns a list of Path objects in the order the glob yields them.
    """
    pattern = f"*.{extension}"
    search = directory.rglob if recursive else directory.glob
    return list(search(pattern))


def load_jsonl(filepath):
    """Load a JSON Lines file into a list with one parsed object per line.

    Blank (whitespace-only) lines are skipped instead of raising
    json.JSONDecodeError. An example is an extra empty line at the end of the
    file.
    """
    with open(filepath, "r") as f:
        return [json.loads(line) for line in f if line.strip()]


def write_mp3(waveform, output_path, bitrate="92k", sample_rate=44100):
    """
    Write a waveform to an mp3 file.

    The waveform is first written as a temporary .wav with the same stem as
    *output_path*. pydub converts that file to mp3, and the .wav is removed
    afterwards, even if the export fails. Any pre-existing file at that .wav
    path is overwritten and then deleted.

    output_path: Path object for the output mp3 file (must not end in .wav)
    waveform: numpy array of the waveform (cast to float32 before writing)
    bitrate: bitrate of the mp3 file (64k, 92k, 128k, 256k, 312k)
    sample_rate: sample rate in Hz written to the wav header
        (default 44100, the previously hard-coded value)

    Raises:
        ValueError: if output_path has a .wav suffix. The temporary wav would
            be the output file itself and would be deleted after export.
    """
    if output_path.suffix.lower() == ".wav":
        raise ValueError(f"write_mp3 output_path must not be a .wav file: {output_path}")
    wav_path = output_path.with_suffix(".wav")
    try:
        # write the wav file
        write(wav_path, sample_rate, waveform.astype(np.float32))
        # compress the wav file as mp3
        AudioSegment.from_wav(wav_path).export(output_path, format="mp3", bitrate=bitrate)
    finally:
        # remove the temporary wav file, even if writing or exporting failed
        wav_path.unlink(missing_ok=True)


def copy_file(input_file, output_dir):
    """Copy *input_file* into *output_dir*, keeping its file name.

    input_file: pathlib.Path of the file to copy
    output_dir: pathlib.Path of the destination directory (not created here)
    """
    shutil.copy(input_file, output_dir / input_file.name)


class FileCompressor:
    """Extract and re-create zip archives for every file in a directory, in parallel.

    unzip() extracts every *.zip found directly in ``input_directory`` into
    ``output_directory``. zip() compresses every *.txt found directly in
    ``output_directory`` into ``<stem>.zip`` in that same directory and
    deletes the text file. Work is distributed with joblib.Parallel.
    """

    def __init__(self, input_directory, output_directory, n_jobs=-1):
        """
        input_directory: directory holding the .zip files to extract
            (str or os.PathLike; stored as pathlib.Path)
        output_directory: directory where files are extracted and re-zipped
            (str or os.PathLike; stored as pathlib.Path)
        n_jobs: forwarded to joblib.Parallel (where -1 means all CPUs)
        """
        from pathlib import Path

        # The methods below rely on Path operations (glob, "/", stem, unlink),
        # so plain strings are normalized here. Path(Path) is a no-op.
        self.input_directory = Path(input_directory)
        self.output_directory = Path(output_directory)
        self.n_jobs = n_jobs

    # File compression and decompression
    def unzip_file(self, file):
        """Extract every member of the zip archive *file* into output_directory."""
        with ZipFile(file, "r") as zip_ref:
            zip_ref.extractall(self.output_directory)

    def zip_file(self, file):
        """Compress *file* into output_directory/<stem>.zip, then delete *file*.

        The source is deleted only after the archive has been closed. A failure
        while writing or finalizing the zip therefore never loses the original.
        """
        output_file = self.output_directory / (file.stem + ".zip")
        with ZipFile(output_file, "w") as zip_ref:
            zip_ref.write(file, arcname=file.name, compress_type=ZIP_DEFLATED)
        file.unlink()

    @timeit
    def unzip(self):
        """uncompress all zip files in input_directory into output_directory"""
        files = get_files(self.input_directory, extension="zip")
        Parallel(n_jobs=self.n_jobs)(delayed(self.unzip_file)(file) for file in files)

    @timeit
    def zip(self):
        """compress all text files in output_directory to new zip files and remove the text files"""
        files = get_files(self.output_directory, extension="txt")
        Parallel(n_jobs=self.n_jobs)(delayed(self.zip_file)(file) for file in files)