# Importing libraries

import functools
import inspect
import os
import shutil

import gradio as gr
import torch
from asteroid.models import ConvTasNet, DPRNNTasNet


# ------------------ #

class tester():
    """Run a separation model on a scratch copy of an audio file.

    The input file is copied to ``<test_dir>/test.wav`` and the model's
    ``separate`` method is invoked on that copy. The scratch directory is
    wiped before every run, so it only ever holds the current input and
    whatever files the model writes for it.
    """

    def __init__(self, model, test_dir='asset/test_subject/'):
        """
        Args:
            model: Object exposing ``separate(path, force_overwrite=..., resample=...)``
                (in this app, the model returned by ``load_model``).
            test_dir: Scratch directory for the input copy and the model's
                outputs. It is deleted and recreated on every run.
        """
        self.model = model
        # Test directory also receives the model's output files.
        self.test_dir = test_dir

    def prepare_test(self):
        """Recreate ``test_dir`` as an empty directory."""
        # Remove the previous input and its results so stale outputs are
        # never served for a new request.
        if os.path.exists(self.test_dir):
            shutil.rmtree(self.test_dir)
        # makedirs (unlike mkdir) also creates missing parents such as 'asset/'.
        os.makedirs(self.test_dir, exist_ok=True)

    def test(self, path):
        """Copy *path* into the scratch directory and run the model on it.

        Args:
            path: Path of the audio file to separate.
        """
        self.prepare_test()

        # os.path.join works whether or not test_dir ends with a separator.
        test_subject = os.path.join(self.test_dir, 'test.wav')
        shutil.copyfile(path, test_subject)

        self.model.separate(test_subject, force_overwrite=True, resample=True)

@functools.lru_cache(maxsize=4)
def load_model(path='asset/model/model_two.bin'):
    """Load the pickled separation model stored at *path*.

    The file holds a complete model object, not a ``state_dict``: callers
    invoke ``.separate(...)`` directly on the result. Results are cached, so
    repeated calls with the same *path* return the same object instead of
    re-reading the checkpoint on every request.

    Args:
        path: Location of the checkpoint written with ``torch.save(model, ...)``.

    Returns:
        The unpickled model object.
    """
    load_kwargs = {}
    # PyTorch 2.6 switched torch.load to weights_only=True by default, which
    # refuses to unpickle a whole model object. Opt out explicitly on versions
    # that know the keyword. Full unpickling can execute arbitrary code, so
    # only point this at trusted, locally shipped files.
    if 'weights_only' in inspect.signature(torch.load).parameters:
        load_kwargs['weights_only'] = False
    return torch.load(path, **load_kwargs)

def separator(original_audio, path):
    """Gradio click handler: separate the selected example recording.

    Args:
        original_audio: Current value of the "Original Audio" component. It
            is unused; it is kept because the button wires it in as an input.
        path: Example file chosen in the "Example Audio" radio, or ``None``
            when nothing has been selected yet.

    Returns:
        A list of three file paths: the copied mixture followed by the two
        estimated speaker signals. The order matches the button's output
        components.

    Raises:
        gr.Error: If no example has been selected.
    """
    if not path:
        # With no selection, copyfile(None, ...) would fail with an opaque
        # TypeError; show a readable message in the UI instead.
        raise gr.Error('Please select an example audio first.')

    runner = tester(load_model())
    # test() already calls prepare_test(), so no extra call is needed here.
    runner.test(path)

    # Derive the paths from the tester's scratch directory rather than
    # repeating the literal. The model's separate() is expected to write
    # test_est1.wav / test_est2.wav beside the input copy.
    out_dir = runner.test_dir
    return [
        os.path.join(out_dir, 'test.wav'),
        os.path.join(out_dir, 'test_est1.wav'),
        os.path.join(out_dir, 'test_est2.wav'),
    ]

# Build the UI: the mixture on top, the two separated speakers below, and an
# example picker plus a "Separate" button that runs `separator`.
demo = gr.Blocks(theme=gr.themes.Soft())
with demo:

    gr.Markdown('''
    <center>
        <h1>Speech Separation</h1>
        <div style="display:flex;align-items:center;justify-content:center;">
            <iframe src="https://streamable.com/e/uribry?autoplay=1&nocontrols=1" frameborder="0" allow="autoplay">
            </iframe>
        </div>
        <div></div>
        <p>
            It is a shareable demonstration window which can be used to view results on any device by setting the launch parameter 'share' to 'True'.
            It displays the original audio for a mixture of speakers, the audio separated by our model and the original individual speaker audio.
        </p>
    </center>
    ''')

    # Empty rows; they only add vertical space between sections.
    with gr.Row():
        pass
    with gr.Row():
        pass
    gr.Markdown('''
        <h2> Original Audio</h2>
    
    ''')
    with gr.Row():
        original_text = gr.Text("Original Speech signal ", label='Original Audio', interactive=False)
        original_audio = gr.Audio(label='Original Audio', interactive=False)

    with gr.Row():
        pass
    with gr.Row():
        pass
    gr.Markdown('''
        <h2> Separated Audio</h2>
    
    ''')

    with gr.Row():
        output_text1 = gr.Text("Separated Speech signal Speaker 1 ", label='Speaker 1', interactive=False)
        output_audio1 = gr.Audio(label='Speaker 1', interactive=False)
    with gr.Row():
        output_text2 = gr.Text("Separated Speech signal Speaker 2 ", label='Speaker 2', interactive=False)
        output_audio2 = gr.Audio(label='Speaker 2', interactive=False)

    # Must match the order of the list returned by `separator`:
    # mixture, speaker 1, speaker 2.
    outputs_audio = [original_audio, output_audio1, output_audio2]
    button = gr.Button("Separate")
    examples = [
        "asset/test/mix_clean/Audio0.wav",
        "asset/test/mix_clean/Audio1.wav",
        "asset/test/mix_clean/Audio2.wav",
        "asset/test/mix_clean/Audio3.wav"
    ]

    # gr.Radio replaces the deprecated gr.inputs.Radio (the gr.inputs
    # namespace was removed in Gradio 4).
    example_selector = gr.Radio(choices=examples, label="Example Audio")
    button.click(separator, inputs=[original_audio, example_selector], outputs=outputs_audio)

    gr.Markdown('''
    <center>
        <div style="display:flex;align-items:center;justify-content:center;">
            <a href="https://www.linkedin.com/in/dhruv73/" target="_blank">
                <img src="https://raw.githubusercontent.com/devicons/devicon/1119b9f84c0290e0f0b38982099a2bd027a48bf1/icons/linkedin/linkedin-original.svg" alt="LinkedIN: /dhruv_73" width="100" height="100"/>
            </a>
            &nbsp;&nbsp;&nbsp;&nbsp;&nbsp;&nbsp;&nbsp;
            <a href="https://github.com/DS-73" target="_blank">
                <img src="https://raw.githubusercontent.com/devicons/devicon/1119b9f84c0290e0f0b38982099a2bd027a48bf1/icons/github/github-original.svg" alt="Github: /DS-73" width="100" height="100"/>
            </a>
        </div>
    </center>
    ''')

# Pass share=True here to get a public link, as described in the header text.
demo.launch()



# ------------------ #