import torch
import gradio as gr
import torchvision.transforms as transforms
from neural_network import MNISTNetwork


# Preprocessing applied to each sketch before it reaches the network:
#   ToTensor  -> float tensor shaped (C, H, W); uint8 / PIL input is scaled
#                from [0, 255] to [0, 1] (per torchvision docs).
#   Normalize -> (x - 0.5) / 0.5 per channel, i.e. [0, 1] becomes [-1, 1].
# NOTE(review): this must match the preprocessing used when MNISTModel.pth
# was trained — confirm against the training script.
transform = transforms.Compose(
    [
        transforms.ToTensor(),
        transforms.Normalize(mean=(0.5,), std=(0.5,)),
    ]
)

# Load the trained weights once at import time and prepare the model for inference.
net = MNISTNetwork()
# The model is never moved off the CPU and its inputs are CPU tensors, so map the
# stored weights onto the CPU even if they were saved from a CUDA device.
net.load_state_dict(torch.load('MNISTModel.pth', map_location='cpu'))
# Switch to inference mode: disables training-only behaviour (dropout, batch-norm
# statistic updates) if MNISTNetwork has such layers; a no-op otherwise.
net.eval()
# Class names; predict() looks these up by the network's output index.
LABELS = [str(digit) for digit in range(10)]

def predict(drawing):
    """Classify a hand-drawn digit and return per-label probabilities.

    Args:
        drawing: The value produced by the Sketchpad input, or ``None`` when
            nothing has been drawn. Presumably a 28x28 grayscale array (the
            Sketchpad is configured with ``shape=(28, 28)``) — TODO confirm
            against the installed Gradio version.

    Returns:
        A prompt string when ``drawing`` is ``None``; otherwise a dict mapping
        each label in ``LABELS`` to its softmax probability, ordered from most
        to least likely (a format the Gradio ``"label"`` output accepts).
    """
    if drawing is None:
        return "Draw a number"

    # transform yields a single (C, H, W) image; add a leading batch axis so
    # the network receives a batch of one, which `logits[0]` below relies on.
    batch = transform(drawing).unsqueeze(0)

    with torch.no_grad():
        logits = net(batch)

    probabilities = torch.nn.functional.softmax(logits[0], dim=0)
    # Rank every class (k = number of labels) instead of a hard-coded count.
    values, indices = torch.topk(probabilities, len(LABELS))
    return {LABELS[i]: v for i, v in zip(indices.tolist(), values.tolist())}


# Drawing canvas whose output is sized to 28x28 via `shape`; presumably the
# input resolution MNISTNetwork expects — confirm against the model.
sketchpad_input = gr.Sketchpad(shape=(28, 28))

# Wire canvas -> predict -> label display. live=True makes Gradio re-run
# `predict` automatically whenever the input changes.
interface = gr.Interface(predict, sketchpad_input, "label", live=True)
interface.launch()