from typing import List, Tuple

import gradio as gr
from transformers import AutoTokenizer, AutoModelForCausalLM
import note_seq
from matplotlib.figure import Figure
from numpy import ndarray
import torch

from constants import GM_INSTRUMENTS, SAMPLE_RATE
from string_to_notes import token_sequence_to_note_sequence
from model import get_model_and_tokenizer

import json

# NOTE(review): `device` is not referenced by the functions in this module;
# generation moves its inputs to `model.device` instead.
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

# Load the tokenizer and the model once, at import time.
model, tokenizer = get_model_and_tokenizer()

# Instrument names, looked up by name with `instruments.index(...)` below.
# Position 0 is the drums entry; any other position p corresponds to the
# model token value INST=p-1.
# JSON is UTF-8 by specification, so decode explicitly instead of relying on
# the platform's locale-dependent default encoding.
with open("instruments.json", "r", encoding="utf-8") as f:
    instruments = json.load(f)


def create_seed_string(genre: str = "OTHER", artist: str = "OTHER", instrument: str = "0") -> str:
    """
    Creates a seed string for generating a new piece.

    The special value "RANDOM" is allowed for ``genre`` and ``artist``; like
    any other value it is inserted into the seed verbatim.

    Args:
        genre (str, optional): The genre of the piece. Defaults to "OTHER".
        artist (str, optional): The artist style of the piece. Defaults to "OTHER".
        instrument (str, optional): The value of the first track's INST= token,
            either "DRUMS" or the string form of a GM_INSTRUMENTS index.
            Defaults to "0".

    Returns:
        str: The seed string, of the form
        "PIECE_START GENRE=<genre> ARTIST=<artist> TRACK_START INST=<instrument>".
    """
    # "RANDOM" needs no special-casing: formatting it verbatim already yields
    # "GENRE=RANDOM" / "ARTIST=RANDOM".
    return f"PIECE_START GENRE={genre} ARTIST={artist} TRACK_START INST={instrument}"


def get_instruments(text_sequence: str) -> List[str]:
    """
    Lists the instruments declared in a token sequence, in order of appearance.

    Each whitespace-separated token that starts with "INST=" contributes one
    entry: "INST=DRUMS" becomes "Drums", and "INST=<n>" becomes
    ``GM_INSTRUMENTS[int(n)]``. Duplicates are kept.

    Args:
        text_sequence (str): The text sequence.

    Returns:
        List[str]: The instrument names, one per "INST=" token.

    Raises:
        ValueError: If an INST= value is neither "DRUMS" nor an integer.
    """
    prefix = "INST="
    names: List[str] = []
    for token in text_sequence.split():
        if not token.startswith(prefix):
            continue
        value = token[len(prefix):]
        names.append("Drums" if value == "DRUMS" else GM_INSTRUMENTS[int(value)])
    return names


def change_last_instrument(
    text_sequence: str,
    instrument: str,
    temp: float = 0.75,
    qpm: int = 120,
) -> Tuple[ndarray, str, Figure, str, str, str]:
    """
    Swaps the instrument of the last track in a song string and re-renders it.

    Only the last whitespace-separated token containing "INST=" is rewritten;
    all other tokens are kept. The sequence is re-joined with single spaces.

    Args:
        text_sequence (str): The song string.
        instrument (str): Name of the new instrument. It must be an entry of the
            module-level ``instruments`` list, otherwise ``list.index`` raises ValueError.
        temp (float, optional): Not used by this function. Defaults to 0.75.
        qpm (int, optional): The quarter notes per minute. Defaults to 120.

    Returns:
        Tuple[ndarray, str, Figure, str, str, str]: The audio waveform, MIDI file name, plot figure,
                                                    instruments string, updated song string, and number of tokens string.
    """
    position = instruments.index(instrument)
    # List position 0 is the drums entry; any other position p maps to the
    # token value p - 1.
    new_value = "DRUMS" if position == 0 else str(position - 1)

    tokens = text_sequence.split()
    last_inst = next(
        (i for i in range(len(tokens) - 1, -1, -1) if "INST=" in tokens[i]),
        None,
    )
    if last_inst is not None:
        tokens[last_inst] = f"INST={new_value}"
    updated = " ".join(tokens)

    audio, midi_file, fig, instruments_str, num_tokens = get_outputs_from_string(updated, qpm)
    return audio, midi_file, fig, instruments_str, updated, num_tokens


def generate_new_instrument(seed: str, temp: float = 0.75) -> str:
    """
    Generates a new instrument sequence from a given seed and temperature.

    Sampling is repeated until the newly generated part (the tokens after the
    seed) contains "NOTE_ON".

    Args:
        seed (str): The seed string for the generation.
        temp (float, optional): The temperature for the generation, which controls the randomness. Defaults to 0.75.

    Returns:
        str: The decoded full sequence, i.e. the seed followed by the generated tokens.
    """
    # Everything here depends only on the seed, so compute it once instead of
    # on every retry.
    input_ids = tokenizer.encode(seed, return_tensors="pt").to(model.device)
    # Length of the seed in tokens, used to slice off the newly generated part.
    seed_length = input_ids.shape[1]
    # Generation stops at the first token of "TRACK_END".
    eos_token_id = tokenizer.encode("TRACK_END")[0]

    # NOTE(review): there is no retry limit; if the model never emits NOTE_ON
    # for this seed, this call does not return.
    while True:
        generated_ids = model.generate(
            input_ids,
            max_new_tokens=2048,
            do_sample=True,
            temperature=temp,
            eos_token_id=eos_token_id,
        )
        # Only accept samples that add at least one note beyond the seed.
        new_generated_sequence = tokenizer.decode(generated_ids[0][seed_length:])
        if "NOTE_ON" in new_generated_sequence:
            return tokenizer.decode(generated_ids[0])


def get_outputs_from_string(
    generated_sequence: str, qpm: int = 120
) -> Tuple[ndarray, str, Figure, str, str]:
    """
    Renders a token sequence into audio, a MIDI file, a piano-roll plot and summary text.

    Args:
        generated_sequence (str): The generated sequence of tokens.
        qpm (int, optional): The quarter notes per minute. Defaults to 120.

    Returns:
        Tuple[ndarray, str, Figure, str, str]: The audio waveform, MIDI file name, plot figure,
                                               instruments string (one "- name" line per track),
                                               and number of whitespace-separated tokens as a string.
    """
    # NOTE(review): fixed output path (spelling kept as-is); every call writes
    # the same file, so concurrent calls would overwrite each other's output.
    midi_path = "midi_ouput.mid"

    track_names = get_instruments(generated_sequence)
    instruments_str = "\n".join(f"- {name}" for name in track_names)
    num_tokens = str(len(generated_sequence.split()))

    note_sequence = token_sequence_to_note_sequence(generated_sequence, qpm=qpm)
    float_samples = note_seq.fluidsynth(note_sequence, sample_rate=SAMPLE_RATE)
    int16_data = note_seq.audio_io.float_samples_to_int16(float_samples)
    fig = note_seq.plot_sequence(note_sequence, show_figure=False)
    audio = gr.make_waveform((SAMPLE_RATE, int16_data))
    note_seq.note_sequence_to_midi_file(note_sequence, midi_path)

    return audio, midi_path, fig, instruments_str, num_tokens


def remove_last_instrument(
    text_sequence: str, qpm: int = 120
) -> Tuple[ndarray, str, Figure, str, str, str]:
    """
    Removes the last instrument from a song string and returns the various output formats.

    If the song has one track or none, there is nothing left to render after the
    removal. In that case a new track is generated with ``generate_song`` (default
    instrument) at the requested tempo.

    Args:
        text_sequence (str): The song string.
        qpm (int, optional): The quarter notes per minute. Defaults to 120.

    Returns:
        Tuple[ndarray, str, Figure, str, str, str]: The audio waveform, MIDI file name, plot figure,
                                                    instruments string, new song string, and number of tokens string.
    """
    # Splitting on 'TRACK_START' yields the header (text before the first
    # track) followed by one chunk per track.
    tracks = text_sequence.split("TRACK_START")
    # Keep all tracks except the last one, restoring the separators removed by split().
    new_song = "TRACK_START".join(tracks[:-1])

    if len(tracks) == 1:
        # No track at all: start from an empty sequence.
        return generate_song(text_sequence="", qpm=qpm)
    if len(tracks) == 2:
        # Only one track: keep the header and generate a fresh first track.
        return generate_song(text_sequence=new_song, qpm=qpm)

    audio, midi_file, fig, instruments_str, num_tokens = get_outputs_from_string(
        new_song, qpm
    )
    return audio, midi_file, fig, instruments_str, new_song, num_tokens
def regenerate_last_instrument(
    text_sequence: str, qpm: int = 120
) -> Tuple[ndarray, str, Figure, str, str, str]:
    """
    Regenerates the last instrument in a song string and returns the various output formats.

    The last track is dropped and a new one is sampled with the same instrument.
    If the song contains no "INST=" token, a new song is generated from an empty
    sequence instead of failing.

    Args:
        text_sequence (str): The song string.
        qpm (int, optional): The quarter notes per minute. Defaults to 120.

    Returns:
        Tuple[ndarray, str, Figure, str, str, str]: The audio waveform, MIDI file name, plot figure,
                                                    instruments string, new song string, and number of tokens string.
    """
    # Find the INST= value of the last track.
    last_instrument_id = None
    for token in reversed(text_sequence.split()):
        if "INST=" in token:
            last_instrument_id = token.split("=")[1]
            break

    if last_instrument_id is None:
        # No track to regenerate; mirror remove_last_instrument's handling of a
        # song without tracks by starting from an empty sequence.
        return generate_song(text_sequence="", qpm=qpm)

    if last_instrument_id == "DRUMS":
        instrument = "Drums"
    else:
        # Token value n sits at list position n + 1 because position 0 holds
        # the drums entry (so GM program 0, 'Acoustic Grand Piano', is at 1):
        # https://soundprogramming.net/file-formats/general-midi-instrument-list/
        instrument = instruments[int(last_instrument_id) + 1]

    # Drop the last track; join() restores the separators that split() removed.
    new_seed = "TRACK_START".join(text_sequence.split("TRACK_START")[:-1])

    return generate_song(instrument=instrument, text_sequence=new_seed, qpm=qpm)


def change_tempo(
    text_sequence: str, qpm: int
) -> Tuple[ndarray, str, Figure, str, str, str]:
    """
    Re-renders an unchanged song string at a new tempo.

    Args:
        text_sequence (str): The song string; it is returned unmodified.
        qpm (int): The new quarter notes per minute.

    Returns:
        Tuple[ndarray, str, Figure, str, str, str]: The audio waveform, MIDI file name, plot figure,
                                                    instruments string, text sequence, and number of tokens string.
    """
    rendered = get_outputs_from_string(text_sequence, qpm=qpm)
    audio, midi_file, fig, instruments_str, num_tokens = rendered
    return audio, midi_file, fig, instruments_str, text_sequence, num_tokens


def generate_song(
    genre: str = "OTHER",
    artist: str = "KATE_BUSH",
    instrument: str = "Acoustic Grand Piano",
    temp: float = 0.75,
    text_sequence: str = "",
    qpm: int = 120
) -> Tuple[ndarray, str, Figure, str, str, str]:
    """
    Generates a new track, either appended to an existing song or starting a new one.

    Args:
        genre (str, optional): The genre of the song; only used when ``text_sequence``
            is empty. Defaults to "OTHER".
        artist (str, optional): The artist style to inspire the song; only used when
            ``text_sequence`` is empty. Defaults to "KATE_BUSH".
        instrument (str, optional): Name of the instrument for the new track; must be an
            entry of the module-level ``instruments`` list. Defaults to "Acoustic Grand Piano".
        temp (float, optional): The temperature for the generation, which controls the randomness. Defaults to 0.75.
        text_sequence (str, optional): Existing song string to extend with a new track. When
            empty, a new song is seeded from ``genre`` and ``artist``. Defaults to "".
        qpm (int, optional): The quarter notes per minute. Defaults to 120.

    Returns:
        Tuple[ndarray, str, Figure, str, str, str]: The audio waveform, MIDI file name, plot figure,
                                                    instruments string, generated song string, and number of tokens string.

    Raises:
        ValueError: If ``instrument`` is not in the instrument list.
    """
    position = instruments.index(instrument)
    # List position 0 is the drums entry; any other position p maps to the
    # token value p - 1.
    inst_token = "DRUMS" if position == 0 else str(position - 1)

    if text_sequence == "":
        seed_string = create_seed_string(genre, artist, inst_token)
    else:
        seed_string = text_sequence + " TRACK_START INST=" + inst_token

    generated_sequence = generate_new_instrument(seed=seed_string, temp=temp)
    audio, midi_file, fig, instruments_str, num_tokens = get_outputs_from_string(
        generated_sequence, qpm
    )
    return audio, midi_file, fig, instruments_str, generated_sequence, num_tokens