import streamlit as st
import tensorflow as tf
import numpy as np
import librosa
import matplotlib.pyplot as plt
import librosa.display
import tempfile
import os

# Load the trained model
@st.cache_resource
def load_model():
    model_path = "sound_classification_model.h5"  # Replace with the path to your .h5 file
    model = tf.keras.models.load_model(model_path)
    return model

model = load_model()
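# Note: @st.cache_resource keeps the loaded model in memory across Streamlit reruns,
# so the .h5 file is deserialized only once rather than on every interaction.
# Optional sanity check (assumption: the model expects (None, 128, 128, 1) inputs,
# matching the spectrograms produced by preprocess_audio below):
# st.write(model.input_shape)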

# Map Class Labels
CLASS_LABELS = {
    0: 'Air Conditioner',
    1: 'Car Horn',
    2: 'Children Playing',
    3: 'Dog Bark',
    4: 'Drilling',
    5: 'Engine Idling',
    6: 'Gun Shot',
    7: 'Jackhammer',
    8: 'Siren',
    9: 'Street Music'
}
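# Assumption: these indices follow the UrbanSound8K classID ordering; the model must have
# been trained with the same label encoding for this mapping to be valid.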

# Preprocess audio into a spectrogram
def preprocess_audio(file_path, n_mels=128, fixed_time_steps=128):
    try:
        y, sr = librosa.load(file_path, sr=None)
        mel_spectrogram = librosa.feature.melspectrogram(y=y, sr=sr, n_mels=n_mels, fmax=sr / 2)
        log_spectrogram = librosa.power_to_db(mel_spectrogram, ref=np.max)
        # Normalize to the [-1, 0] range; guard against division by zero on silent clips
        max_abs = np.max(np.abs(log_spectrogram))
        if max_abs > 0:
            log_spectrogram = log_spectrogram / max_abs
        # Pad or truncate along the time axis to a fixed number of frames
        if log_spectrogram.shape[1] < fixed_time_steps:
            padding = fixed_time_steps - log_spectrogram.shape[1]
            log_spectrogram = np.pad(log_spectrogram, ((0, 0), (0, padding)), mode='constant')
        else:
            log_spectrogram = log_spectrogram[:, :fixed_time_steps]
        return np.expand_dims(log_spectrogram, axis=-1)  # Add channel dimension for CNNs
    except Exception as e:
        print(f"Error processing {file_path}: {e}")
        return None
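# Example (hypothetical input): a 4-second clip at 22.05 kHz yields roughly 170 mel frames
# with librosa's default hop length of 512; this is truncated to 128 frames, so
# preprocess_audio returns an array of shape (128, 128, 1).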

# Streamlit app UI
st.title("Audio Spectrogram Prediction")
st.write("Upload an audio file to generate a spectrogram and predict its class using your trained model.")

# File upload widget
uploaded_file = st.file_uploader("Choose an audio file", type=["wav", "mp3"])
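# Note: decoding support depends on librosa's audio backend (soundfile or audioread/ffmpeg);
# mp3 uploads may fail on minimal installs, so this is worth verifying in the deployment environment.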

if uploaded_file is not None:
    # Persist the upload to a temporary file so librosa can read it from a path;
    # keep the original extension so the audio backend detects the format correctly
    suffix = os.path.splitext(uploaded_file.name)[1] or ".wav"
    with tempfile.NamedTemporaryFile(delete=False, suffix=suffix) as temp_audio_file:
        temp_audio_file.write(uploaded_file.read())
        temp_audio_path = temp_audio_file.name

    # Preprocess the audio into a spectrogram
    st.write("Processing audio into a spectrogram...")
    spectrogram = preprocess_audio(temp_audio_path)
    
    if spectrogram is not None:
        # Display the spectrogram (values are normalized, so the colorbar is relative rather than dB)
        st.write("Generated Spectrogram:")
        fig, ax = plt.subplots(figsize=(10, 4))
        img = librosa.display.specshow(spectrogram[:, :, 0], sr=22050, x_axis='time', y_axis='mel', cmap='plasma', ax=ax)
        fig.colorbar(img, ax=ax, label='Normalized magnitude')
        ax.set_title('Mel-Spectrogram')
        fig.tight_layout()
        st.pyplot(fig)  # pass an explicit figure rather than the global pyplot state

        # Predict using the model
        st.write("Predicting...")
        spectrogram = np.expand_dims(spectrogram, axis=0)  # Add batch dimension
        predictions = model.predict(spectrogram)
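        # argmax picks the highest-scoring class; this assumes the model's final layer outputs
        # one score (e.g. a softmax probability) for each of the 10 classes above.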
        predicted_class_index = np.argmax(predictions, axis=-1)[0]
        predicted_class_label = CLASS_LABELS.get(predicted_class_index, "Unknown")

        # Display the results
        st.write("Prediction Results:")
        st.write(f"**Predicted Class:** {predicted_class_label} (Index: {predicted_class_index})")
        st.write(f"**Raw Model Output:** {predictions}")
    else:
        st.write("Failed to process the audio file. Please try again with a different file.")

    # Optional: Clean up temporary file
    os.remove(temp_audio_path)



# st.write("### Developer Team")
# developer_info = [
#     {"name": "Faheyra", "image_url": "https://italeemc.iium.edu.my/pluginfile.php/21200/user/icon/remui/f3?rev=40826", "title": "MLetops Engineer"},
#     {"name": "Adilah", "image_url": "https://italeemc.iium.edu.my/pluginfile.php/21229/user/icon/remui/f3?rev=43498", "title": "Ra-Sis-Chear"},
#     {"name": "Aida", "image_url": "https://italeemc.iium.edu.my/pluginfile.php/21236/user/icon/remui/f3?rev=43918", "title": "Ra-Sis-Chear"},
#     {"name": "Naufal", "image_url": "https://italeemc.iium.edu.my/pluginfile.php/21260/user/icon/remui/f3?rev=400622", "title": "Rizzichear"},
#     {"name": "Fadzwan", "image_url": "https://italeemc.iium.edu.my/pluginfile.php/21094/user/icon/remui/f3?rev=59457", "title": "Nasser"},
# ]

# # Dynamically create columns based on the number of developers
# num_devs = len(developer_info)
# cols = st.columns(num_devs)

# # Display the developer profiles
# for idx, dev in enumerate(developer_info):
#     col = cols[idx]
#     with col:
#         st.markdown(
#             f'<div style="display: flex; flex-direction: column; align-items: center;">'
#             f'<img src="{dev["image_url"]}" width="100" style="border-radius: 50%;">'
#             f'<p>{dev["name"]}<br>{dev["title"]}</p>'
#             f'<p></p>'
#             f'</div>',
#             unsafe_allow_html=True
#         )