import json
import os

# Must run before any import that may load protobuf. protobuf selects its
# implementation from this variable when it is first imported, so setting it
# after such an import has no effect. (Presumably utils and/or gradio pull in
# protobuf — verify; placing it first is safe either way.)
os.environ["PROTOCOL_BUFFERS_PYTHON_IMPLEMENTATION"] = "python"

import gradio as gr  # noqa: E402  (must follow the env var above)

from utils import (  # noqa: E402  (must follow the env var above)
    generate_song,
    remove_last_instrument,
    regenerate_last_instrument,
    change_last_instrument,
    change_tempo,
)



# Genres offered in the genre dropdown.
genres = ["ROCK", "POP", "OTHER", "R&B/SOUL", "JAZZ", "ELECTRONIC", "RANDOM"]


def _load_json(path):
    """Return the parsed contents of the JSON file at *path*.

    The file is decoded explicitly as UTF-8 so that non-ASCII entries (e.g.
    accented artist names) load the same regardless of the platform locale.
    Paths are resolved relative to the current working directory.
    """
    with open(path, "r", encoding="utf-8") as fh:
        return json.load(fh)


# Artist names offered in the artist-style dropdown.
artist_names = _load_json("artist_names.json")

# Instrument names offered in the instrument dropdown.
instruments = _load_json("instruments.json")


demo = gr.Blocks()


def run():
    """Build the UI inside the module-level ``demo`` Blocks and launch it.

    Every action button calls one of the ``utils`` helpers and writes the
    same six components back (audio/video, MIDI file, plot, instrument list,
    token text sequence, token count). That output list is declared once and
    shared by all event handlers.

    The server listens on all interfaces (0.0.0.0) on port 7860.
    """
    with demo:
        gr.DuplicateButton(value="Duplicate Space for private use")
        with gr.Row():
            with gr.Column():
                temp = gr.Slider(
                    minimum=0, maximum=1, step=0.05, value=0.85, label="Temperature"
                )
                genre = gr.Dropdown(
                    choices=genres, value="POP", label="Select the genre"
                )
                # Guard against empty JSON lists so the UI still builds.
                artist = gr.Dropdown(
                    choices=artist_names,
                    value=artist_names[0] if artist_names else None,
                    label="Select the artist style",
                )
                instrument = gr.Dropdown(
                    choices=instruments,
                    value=instruments[0] if instruments else None,
                    label="Select the instrument to be generated",
                )
                with gr.Row():
                    # Labels contain UTF-8 emoji; keep this file saved as UTF-8.
                    btn_from_scratch = gr.Button("🧹 Start from scratch")
                    btn_continue = gr.Button("➡️ Generate New Track")
                    btn_remove_last = gr.Button("↩️ Remove last instrument")
                    btn_regenerate_last = gr.Button("🔄 Regenerate last instrument")
                    btn_change_last = gr.Button("🔄 Swap last instrument")
            with gr.Column():
                with gr.Group():
                    audio_output = gr.Video(show_share_button=True)
                    midi_file = gr.File()
                    with gr.Row():
                        qpm = gr.Slider(
                            minimum=60, maximum=140, step=10, value=120, label="Tempo"
                        )
                        btn_qpm = gr.Button("Change Tempo")
        with gr.Row():
            with gr.Column():
                plot_output = gr.Plot()
            with gr.Column():
                instruments_output = gr.Markdown("# List of generated instruments")
        with gr.Row():
            text_sequence = gr.Text()
            # Hidden box that no handler writes to; it stands in for
            # text_sequence when starting a song from scratch.
            empty_sequence = gr.Text(visible=False)
        with gr.Row():
            num_tokens = gr.Text(visible=False)

        # Every handler returns these six values, in this order.
        outputs = [
            audio_output,
            midi_file,
            plot_output,
            instruments_output,
            text_sequence,
            num_tokens,
        ]

        btn_from_scratch.click(
            fn=generate_song,
            inputs=[genre, artist, instrument, temp, empty_sequence, qpm],
            outputs=outputs,
        )
        btn_continue.click(
            fn=generate_song,
            inputs=[genre, artist, instrument, temp, text_sequence, qpm],
            outputs=outputs,
        )
        btn_remove_last.click(
            fn=remove_last_instrument,
            inputs=[text_sequence, qpm],
            outputs=outputs,
        )
        btn_regenerate_last.click(
            fn=regenerate_last_instrument,
            inputs=[text_sequence, qpm],
            outputs=outputs,
        )
        btn_change_last.click(
            fn=change_last_instrument,
            inputs=[text_sequence, instrument, temp, qpm],
            outputs=outputs,
        )
        btn_qpm.click(
            fn=change_tempo,
            inputs=[text_sequence, qpm],
            outputs=outputs,
        )

    demo.launch(server_name="0.0.0.0", server_port=7860)


if __name__ == "__main__":
    run()