import gradio as gr
import tensorflow as tf
import librosa
import librosa.display
import numpy as np
import matplotlib.pyplot as plt
import sounddevice as sd
import soundfile as sf

# Load the pre-trained model
model = tf.keras.models.load_model("model.h5")
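# Note: if "model.h5" was saved with a custom loss or metric, loading may
# fail; for inference-only use, skipping compilation is a common workaround:
#   model = tf.keras.models.load_model("model.h5", compile=False)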

# Process an uploaded or recorded clip: segment it, classify each segment,
# and render an annotated waveform
def process_audio(audio_file, breath_in_time, breath_out_time):
    try:
        # breath_in_time / breath_out_time describe the intended breathing
        # cycle; segmentation below is driven by the audio itself

        # Load the audio file
        y, sr = librosa.load(audio_file, sr=16000)

        # Detect segments (e.g., using energy or silence)
        intervals = librosa.effects.split(y, top_db=20)
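        # librosa.effects.split returns [start, end] sample indices of
        # non-silent regions; frames more than 20 dB below the clip's peak
        # are treated as silence. Raising top_db merges segments; lowering
        # it splits more aggressively.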

        results = []

        plt.figure(figsize=(10, 4))
        librosa.display.waveshow(y, sr=sr, alpha=0.5)

        # Process each segment
        for i, (start, end) in enumerate(intervals):
            segment = y[start:end]
            duration = (end - start) / sr

            # Compute the amplitude (mean absolute value)
            amplitude = np.mean(np.abs(segment))

            # Extract MFCC features
            mfcc = librosa.feature.mfcc(y=segment, sr=sr, n_mfcc=13)
            mfcc = np.mean(mfcc, axis=1).reshape(1, -1)
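            # Averaging over time yields one 13-dim vector per segment,
            # reshaped to (1, 13); this assumes the model was trained on
            # the same kind of time-averaged MFCC features.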

            # Model prediction (class 0 assumed to be "Inhale")
            prediction = model.predict(mfcc, verbose=0)
            label_from_model = "Inhale" if np.argmax(prediction) == 0 else "Exhale"

            # The displayed label currently comes from a simple amplitude
            # heuristic rather than the model: louder segments are treated
            # as inhales. 0.05 is an empirical threshold.
            label = "Inhale" if amplitude > 0.05 else "Exhale"

            # Append results
            results.append({
                "Segment": i + 1,
                "Type": label,
                "Duration (s)": round(duration, 2),
                "Amplitude": round(amplitude, 4)
            })

            # Shade the segment on the waveform: red = inhale, blue = exhale
            plt.axvspan(start / sr, end / sr, color='red' if label == "Inhale" else 'blue', alpha=0.3)

        # Save the waveform with highlighted segments
        plt.title("Audio Waveform with Inhale/Exhale Segments")
        plt.xlabel("Time (s)")
        plt.ylabel("Amplitude")
        plt.savefig("waveform_highlighted.png")
        plt.close()

        # Format results as a table
        result_table = "Segment\tType\tDuration (s)\tAmplitude\n" + "\n".join(
            f"{row['Segment']}\t{row['Type']}\t{row['Duration (s)']}\t{row['Amplitude']}" for row in results
        )

        return result_table, "waveform_highlighted.png"

    except Exception as e:
        return f"Error: {str(e)}", None

# Function to record audio for a specified duration
def record_audio(duration):
    try:
        # Define the file name
        audio_file = "recorded_audio.wav"

        # Record audio
        print(f"Recording for {duration} seconds...")
        recording = sd.rec(int(duration * 16000), samplerate=16000, channels=1, dtype='float32')
        sd.wait()  # Wait until recording is finished
        sf.write(audio_file, recording, 16000)
        print("Recording complete!")

        return audio_file
    except Exception as e:
        return f"Error: {str(e)}"

# Function to animate the scaling circle (during recording)
def create_circle_animation(duration):
    # HTML and CSS to create the animated circle
    circle_animation = f"""
    <div id="circle-container" style="text-align:center; margin-top: 50px;">
        <div id="circle" style="width: 50px; height: 50px; border-radius: 50%; background-color: #3498db; animation: scale-up-down {duration}s infinite;"></div>
    </div>
    <style>
        @keyframes scale-up-down {{
            0% {{ transform: scale(1); }}
            50% {{ transform: scale(1.5); }}
            100% {{ transform: scale(1); }}
        }}
    </style>
    """
    return circle_animation
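
# Note: one animation cycle spans the full breath (in + out), so the circle
# grows for roughly the first half and shrinks for the second; per-phase
# timing would need separate keyframes driven by the two durations.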

# Define Gradio interface
with gr.Blocks() as demo:
    gr.Markdown("### Breathe Training Application")

    with gr.Row():
        breath_in_time = gr.Number(label="Breath In Time (seconds)", value=0, interactive=True)
        breath_out_time = gr.Number(label="Breath Out Time (seconds)", value=0, interactive=True)

    with gr.Row():
        audio_input = gr.Audio(type="filepath", label="Upload Audio (optional)")
        result_output = gr.Textbox(label="Prediction Results (Table)")
        waveform_output = gr.Image(label="Waveform with Highlighted Segments")
        circle_output = gr.HTML(label="Breath Circle Animation")

    with gr.Row():
        record_button = gr.Button("Record & Analyze")
        analyze_button = gr.Button("Analyze Uploaded Audio")

    def handle_record_and_analyze(breath_in, breath_out):
        # Note: Blocks functions return all outputs at once, so the circle
        # animation appears only after recording finishes; showing it live
        # would require a generator that yields intermediate updates.
        total_duration = breath_in + breath_out
        circle_animation = create_circle_animation(total_duration)
        audio_file = record_audio(total_duration)
        result_table, waveform_image = process_audio(audio_file, breath_in, breath_out)
        return circle_animation, result_table, waveform_image

    analyze_button.click(
        fn=process_audio,
        inputs=[audio_input, breath_in_time, breath_out_time],
        outputs=[result_output, waveform_output],
    )

    record_button.click(
        fn=handle_record_and_analyze,
        inputs=[breath_in_time, breath_out_time],
        outputs=[circle_output, result_output, waveform_output],
    )

# Run the Gradio app
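# By default launch() serves locally on http://127.0.0.1:7860; standard Gradio
# options allow LAN access or a temporary public link, e.g.:
#   demo.launch(server_name="0.0.0.0", server_port=7860)  # or share=True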
demo.launch()