import streamlit as st
import torch
import torch.nn as nn
import torch.optim as optim
from torchvision import transforms
from torch.utils.data import DataLoader
from datasets import load_dataset
from huggingface_hub import HfApi, Repository
import os
import matplotlib.pyplot as plt

import utils

# Hugging Face Hub credentials (the token is read from the environment, never stored in source)
HF_TOKEN = os.environ.get("HF_TOKEN")
# Hub repositories: the trained model is pushed to the first, training data is read from the second.
MODEL_REPO_ID = "louiecerv/amer_sign_lang_data_augmentation"
DATASET_REPO_ID = "louiecerv/american_sign_language"

# Device configuration: use the GPU when CUDA is available, otherwise fall back to the CPU.
if torch.cuda.is_available():
    device = torch.device("cuda")
else:
    device = torch.device("cpu")
st.write(f"Device: {device}")

# Define the CNN model
class CNN(nn.Module):
    """Two-stage convolutional classifier for 28x28 single-channel images.

    Each stage is conv(3x3, padding 1) -> ReLU -> 2x2 max-pool, which halves
    the spatial size (28 -> 14 -> 7). The resulting 64x7x7 feature map is
    flattened and fed through a 128-unit hidden layer to 25 output logits
    (letters A-Y).

    Attribute names are intentionally fixed: they become the keys of the
    ``state_dict`` that the app saves as ``pytorch_model.bin``.
    """

    def __init__(self):
        super().__init__()
        # Stage 1: 1 -> 32 channels, 28x28 -> 14x14.
        self.conv1 = nn.Conv2d(in_channels=1, out_channels=32, kernel_size=3, padding=1)
        self.relu1 = nn.ReLU()
        self.pool1 = nn.MaxPool2d(kernel_size=2, stride=2)
        # Stage 2: 32 -> 64 channels, 14x14 -> 7x7.
        self.conv2 = nn.Conv2d(in_channels=32, out_channels=64, kernel_size=3, padding=1)
        self.relu2 = nn.ReLU()
        self.pool2 = nn.MaxPool2d(kernel_size=2, stride=2)
        # Classifier head: 64*7*7 features -> 128 hidden units -> 25 classes.
        self.flatten = nn.Flatten()
        self.fc = nn.Linear(in_features=64 * 7 * 7, out_features=128)
        self.relu3 = nn.ReLU()
        self.fc2 = nn.Linear(in_features=128, out_features=25)

    def forward(self, x):
        """Map a (N, 1, 28, 28) batch to (N, 25) unnormalised class scores."""
        feature_stages = (
            (self.conv1, self.relu1, self.pool1),
            (self.conv2, self.relu2, self.pool2),
        )
        for conv, activation, pool in feature_stages:
            x = pool(activation(conv(x)))
        hidden = self.relu3(self.fc(self.flatten(x)))
        return self.fc2(hidden)

# Create a model card
def create_model_card(path="model_repo/README.md", num_epochs=5, accuracy=None):
    """Write a Hugging Face model card (README.md) for the ASL CNN.

    The card begins with a YAML metadata block. The Hub only parses that
    block when ``---`` is the very first line of the file, so the template
    below is dedented and its leading newline removed before writing.

    Args:
        path: Destination file; its directory must already exist. Defaults
            to ``model_repo/README.md``.
        num_epochs: Number of training epochs reported in the card.
        accuracy: Validation accuracy in percent, or ``None`` when no
            measurement is supplied, in which case the card says so instead
            of quoting a number.
    """
    import textwrap  # only needed here; lets the template stay indented in source

    if accuracy is None:
        evaluation = "No validation accuracy was recorded for this upload."
    else:
        evaluation = f"The model achieved an accuracy of {accuracy:.2f}% on the validation set."

    # f-string: doubled braces in the usage snippet render as literal braces.
    model_card = f"""
    ---
    language: en
    tags:
    - image-classification
    - deep-learning
    - cnn
    license: apache-2.0
    datasets:
    - louiecerv/american_sign_language
    ---

    # American Sign Language Recognition CNN

    This model is a Convolutional Neural Network (CNN) designed to recognize American Sign Language (ASL) letters from images. It was trained on the `louiecerv/american_sign_language` dataset.

    ## Model Description

    The model consists of two convolutional layers followed by max-pooling layers, a flattening layer, and two fully connected layers. It is designed to classify images of ASL letters into 25 classes (A-Y).

    ## Intended Uses & Limitations

    This model is intended for educational purposes and as a demonstration of image classification using CNNs. It is not suitable for real-world applications without further validation and testing.

    ## How to Use

    ```python
    import torch
    from torchvision import transforms
    from PIL import Image

    # Load the model (the CNN class from the training script must be defined)
    model = CNN()
    model.load_state_dict(torch.load("path_to_model/pytorch_model.bin", map_location="cpu"))
    model.eval()

    # Preprocess the image
    transform = transforms.Compose([
        transforms.Grayscale(num_output_channels=1),
        transforms.Resize((28, 28)),
        transforms.ToTensor(),
        transforms.Normalize(mean=[0.5], std=[0.5])
    ])
    image = Image.open("path_to_image").convert("RGB")
    image = transform(image).unsqueeze(0)

    # Make a prediction (the output is a class index in 0-24)
    with torch.no_grad():
        output = model(image)
        _, predicted = torch.max(output.data, 1)
    print(f"Predicted ASL letter: {{predicted.item()}}")
    ```

    ## Training Data

    The model was trained on the `louiecerv/american_sign_language` dataset, which contains images of ASL letters.

    ## Training Procedure

    The model was trained using the Adam optimizer with a learning rate of 0.001 and a batch size of 64. The training process included {num_epochs} epochs.

    ## Evaluation Results

    {evaluation}
    """
    with open(path, "w", encoding="utf-8") as f:
        f.write(textwrap.dedent(model_card).lstrip("\n"))

# Streamlit app
def _make_collate_fn(transform):
    """Build a DataLoader ``collate_fn`` that turns dataset rows into tensors.

    Rows missing ``pixel_values`` or ``label`` are skipped, as are rows the
    transform rejects (deliberate best-effort: the error is printed, not
    raised). The returned function maps a list of rows to
    ``(images, labels)`` CPU tensors; both are empty when no row in the
    batch was usable.
    """
    def collate_fn(batch):
        images = []
        labels = []
        for item in batch:
            if "pixel_values" not in item or "label" not in item:
                continue
            # Normalize and the conv layers need floating point input; integer
            # pixel lists would otherwise make every row fail the transform.
            image = torch.as_tensor(item["pixel_values"], dtype=torch.float32)
            # conv1 takes one input channel; give 2-D (H, W) images that axis so
            # the stacked batch is (N, 1, H, W).
            # NOTE(review): assumes pixel_values is (H, W) or (1, H, W) at 28x28 - confirm
            # against the dataset schema, and whether values are 0-1 or 0-255.
            if image.dim() == 2:
                image = image.unsqueeze(0)
            try:
                image = transform(image)
            except Exception as e:  # best-effort: skip this row, keep the batch
                print(f"Error processing image: {e}")
                continue
            images.append(image)
            labels.append(item["label"])

        if not images:
            return torch.empty(0), torch.empty(0, dtype=torch.long)
        return torch.stack(images), torch.tensor(labels, dtype=torch.long)

    return collate_fn


def _train(model, loader, criterion, optimizer, num_epochs):
    """Train ``model`` for ``num_epochs`` epochs; return the number of optimizer steps taken."""
    model.train()
    steps_taken = 0
    for epoch in range(num_epochs):
        for i, (images, labels) in enumerate(loader):
            if images.numel() == 0:  # every row in this batch was skipped
                continue
            images, labels = images.to(device), labels.to(device)

            # Forward pass
            outputs = model(images)
            loss = criterion(outputs, labels)

            # Backward and optimize
            optimizer.zero_grad()
            loss.backward()
            optimizer.step()
            steps_taken += 1

            if (i + 1) % 100 == 0:
                st.write(f'Epoch [{epoch + 1}/{num_epochs}], Step [{i + 1}/{len(loader)}], Loss: {loss.item():.4f}')
    return steps_taken


def _evaluate(model, loader):
    """Return ``(correct, total)`` prediction counts of ``model`` over ``loader``."""
    model.eval()
    correct = 0
    total = 0
    with torch.no_grad():
        for images, labels in loader:
            if images.numel() == 0:  # every row in this batch was skipped
                continue
            images, labels = images.to(device), labels.to(device)
            predicted = model(images).argmax(dim=1)
            total += labels.size(0)
            correct += (predicted == labels).sum().item()
    return correct, total


def _push_to_hub(model):
    """Save the weights and model card under ``model_repo/`` and upload both to MODEL_REPO_ID."""
    local_dir = "model_repo"  # create_model_card() writes its README.md into this folder
    os.makedirs(local_dir, exist_ok=True)
    torch.save(model.state_dict(), os.path.join(local_dir, "pytorch_model.bin"))
    create_model_card()

    # HTTP-based upload via HfApi; the git-based Repository helper is deprecated.
    api = HfApi(token=HF_TOKEN)
    api.create_repo(repo_id=MODEL_REPO_ID, exist_ok=True)
    api.upload_folder(
        repo_id=MODEL_REPO_ID,
        folder_path=local_dir,
        allow_patterns=["pytorch_model.bin", "README.md"],
        commit_message="Trained model and model card",
    )


def main():
    """Streamlit entry point.

    Loads the ASL dataset, trains the CNN when the user clicks "Train Model",
    reports validation accuracy, and uploads the weights and model card to
    the Hugging Face Hub when HF_TOKEN is set. Nothing is uploaded if no
    training batch could be processed.
    """
    st.title("American Sign Language Recognition")

    # Load the dataset from Hugging Face Hub
    dataset = load_dataset(DATASET_REPO_ID)

    # Data loaders with preprocessing:
    transform = transforms.Compose([
        transforms.Normalize(mean=[0.5], std=[0.5])  # Adjust mean and std if needed
    ])
    collate_fn = _make_collate_fn(transform)
    train_loader = DataLoader(dataset["train"], batch_size=64, shuffle=True, collate_fn=collate_fn)
    val_loader = DataLoader(dataset["validation"], batch_size=64, collate_fn=collate_fn)

    # Model, loss, and optimizer
    model = CNN().to(device)
    criterion = nn.CrossEntropyLoss()
    optimizer = optim.Adam(model.parameters(), lr=0.001)

    num_epochs = st.slider("Number of Epochs", 1, 20, 5)  # Streamlit slider
    if not st.button("Train Model"):
        return

    steps_taken = _train(model, train_loader, criterion, optimizer, num_epochs)
    if steps_taken == 0:
        # Without this guard an untrained model would be evaluated and pushed.
        st.error("No training batches could be processed; check the dataset's "
                 "'pixel_values' format. Model not saved.")
        return

    # Validation
    correct, total = _evaluate(model, val_loader)
    if total > 0:
        accuracy = 100 * correct / total
        st.write(f'Accuracy of the model on the validation images: {accuracy:.2f}%')
    else:
        st.write("No validation images were processed.")

    # Save model to Hugging Face Hub
    if not HF_TOKEN:
        st.warning("HF_TOKEN environment variable not set. Model not saved.")
        return
    try:
        _push_to_hub(model)
    except Exception as e:  # app boundary: report Hub/network failures in the UI
        st.error(f"Failed to save model to {MODEL_REPO_ID}: {e}")
    else:
        st.write(f"Model and model card saved to {MODEL_REPO_ID}")


if __name__ == "__main__":
    main()