import os 

# Runtime dependency bootstrap. Skipping the install when `transformers` is
# already importable keeps re-runs of this script from repeating apt-get/pip.
# NOTE(review): declaring dependencies up front (e.g. a requirements file) is
# usually preferable to installing them from inside the app.
import importlib.util
import subprocess
import sys


def _ensure_transformers_installed():
    """Install the ``transformers`` package if it cannot be imported.

    pip is bootstrapped via apt-get only when the ``pip`` module itself is
    missing; those apt-get steps are best-effort and their failures are
    ignored. The pip install runs under the current interpreter
    (``sys.executable -m pip``) so the package is installed where this script
    imports it from.

    Raises:
        subprocess.CalledProcessError: if the pip install fails.
    """
    if importlib.util.find_spec("transformers") is not None:
        return
    if importlib.util.find_spec("pip") is None:
        # Make sure pip is available (best-effort; needs apt-get and root).
        for command in (["apt-get", "update"], ["apt-get", "install", "-y", "python3-pip"]):
            try:
                subprocess.run(command, check=False)
            except OSError:
                # apt-get is not present on this system; go straight to pip.
                break
    subprocess.run([sys.executable, "-m", "pip", "install", "transformers"], check=True)


_ensure_transformers_installed()

# In a notebook, restart the kernel at this point (if possible) so a freshly installed transformers package can be imported.
import transformers
from torch.utils.data import DataLoader


import streamlit as st
from datasets import load_dataset, Audio
from transformers import AutoModelForAudioClassification, AutoFeatureExtractor
import torch
import os
# Load the MInDS-14 dataset
dataset = load_dataset("PolyAI/minds14", "en-US", split="train", trust_remote_code=True)

# Size the classification head from the dataset's intent labels
# (`intent_class` is a ClassLabel in MInDS-14). Without `num_labels`,
# `from_pretrained` builds a head with the default number of labels (2), so
# any intent id >= 2 would be out of range for the loss computed during
# training. The id <-> name maps make the model's predictions readable.
intent_names = dataset.features["intent_class"].names
model = AutoModelForAudioClassification.from_pretrained(
    "facebook/wav2vec2-base",
    num_labels=len(intent_names),
    label2id={name: idx for idx, name in enumerate(intent_names)},
    id2label={idx: name for idx, name in enumerate(intent_names)},
)
feature_extractor = AutoFeatureExtractor.from_pretrained("facebook/wav2vec2-base")

# Resample audio to 16kHz, the rate preprocess_function declares to the
# feature extractor.
dataset = dataset.cast_column("audio", Audio(sampling_rate=16000))

# Preprocessing function
def preprocess_function(examples):
    """Convert a batch of decoded audio rows into wav2vec2 inputs.

    Intended for ``dataset.map(preprocess_function, batched=True)``.

    Args:
        examples: A batch of rows; ``examples["audio"]`` is a list of decoded
            audio entries, each exposing its samples under ``"array"``.

    Returns:
        Whatever ``feature_extractor`` returns for the batch (its
        ``input_values`` are what the training code later selects), with
        clips truncated to at most 100000 samples and padded to a common
        length within the batch.
    """
    # The audio column is cast to 16 kHz before mapping, so the same rate is
    # declared to the feature extractor here.
    extractor_options = {
        "sampling_rate": 16000,
        "padding": True,
        "max_length": 100000,
        "truncation": True,
    }
    waveforms = [clip["array"] for clip in examples["audio"]]
    return feature_extractor(waveforms, **extractor_options)


# Turn raw audio into model inputs and expose only the tensors the model
# consumes. `set_format` changes the dataset's output format in place and
# returns None, so its result must not be assigned back to `dataset`.
dataset = dataset.map(preprocess_function, batched=True)
dataset = dataset.rename_column("intent_class", "labels")
dataset.set_format(type="torch", columns=["input_values", "labels"])

# Create DataLoader
# NOTE: examples are visited in dataset order. preprocess_function pads per
# map() batch, so before enabling shuffle=True make sure every row's
# input_values has the same length.
batch_size = 4  # Adjust as needed
dataloader = DataLoader(dataset, batch_size=batch_size)

# Set device and move model
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
model.to(device)
# `from_pretrained` returns the model in evaluation mode (dropout disabled);
# switch to training mode before fine-tuning.
model.train()

# Training loop (example)
num_epochs = 2  # Keep small for testing on Spaces!
optimizer = torch.optim.AdamW(model.parameters(), lr=5e-5)

for epoch in range(num_epochs):
    for batch in dataloader:
        input_values = batch["input_values"].to(device)
        labels = batch["labels"].to(device)

        optimizer.zero_grad()
        # Passing `labels` makes the model return the classification loss.
        outputs = model(input_values, labels=labels)
        loss = outputs.loss
        loss.backward()
        optimizer.step()

        print(f"Epoch: {epoch+1}, Loss: {loss.item()}")

# ---------------------------------------------------------------------------
# Streamlit UI: rendered once the training loop above has finished.
# ---------------------------------------------------------------------------
st.title("Audio Classification with Minds14")
# TODO: surface more insightful results here (e.g. the final training loss).
st.write("Training complete!")

# Static HTML footer: three task cards plus a link to the Hugging Face course.
# unsafe_allow_html=True is needed for Streamlit to render the raw markup.
# NOTE(review): the class names look like Tailwind CSS utilities, and the
# cards link to #audio / #vision / #nlp anchors that this script does not
# appear to render; confirm the styling and links behave as intended.
_COURSE_LINKS_HTML = """
<div class="mt-4">
 <div class="w-full flex flex-col space-y-4 md:space-y-0 md:grid md:grid-cols-3 md:gap-y-4 md:gap-x-5">
    <a class="!no-underline border dark:border-gray-700 p-5 rounded-lg shadow hover:shadow-lg" href="#audio"> <div class="w-full text-center bg-gradient-to-r from-violet-300 via-sky-400 to-green-500 rounded-lg py-1.5 font-semibold mb-5 text-white text-lg leading-relaxed">Audio</div>
     <p class="text-gray-700">Resample an audio dataset and get it ready for a model to classify what type of banking issue a speaker is calling about.</p>
    </a>
    <a class="!no-underline border dark:border-gray-700 p-5 rounded-lg shadow hover:shadow-lg" href="#vision"> <div class="w-full text-center bg-gradient-to-r from-pink-400 via-purple-400 to-blue-500 rounded-lg py-1.5 font-semibold mb-5 text-white text-lg leading-relaxed">Vision</div>
     <p class="text-gray-700">Apply data augmentation to an image dataset and get it ready for a model to diagnose disease in bean plants.</p>
    </a>
     <a class="!no-underline border dark:border-gray-700 p-5 rounded-lg shadow hover:shadow-lg" href="#nlp"> <div class="w-full text-center bg-gradient-to-r from-orange-300 via-red-400 to-violet-500 rounded-lg py-1.5 font-semibold mb-5 text-white text-lg leading-relaxed">NLP</div>
     <p class="text-gray-700">Tokenize a dataset and get it ready for a model to determine whether a pair of sentences have the same meaning.</p>
    </a>
 </div>
</div>
<div class="mt-4">  </div>
<p>
Check out <a href="https://huggingface.co/course/chapter5/1?fw=pt">Chapter 5</a> of the Hugging Face course to learn more about other important topics such as loading remote or local datasets, tools for cleaning up a dataset, and creating your own dataset.
</p>
"""

st.markdown(_COURSE_LINKS_HTML, unsafe_allow_html=True)