---
license: apache-2.0
tags:
- image-classification
- checkbox-detection
- computer-vision
- pytorch
datasets:
- wendys-llc/chkbx
metrics:
- accuracy
library_name: pytorch
---
# Checkbox State Classifier

This model classifies whether a checkbox is checked or unchecked.

## Model Details

- **Architecture**: EfficientNetV2-S (PyTorch)
- **Input Size**: 128x128 RGB images
- **Output**: Binary classification (unchecked: 0, checked: 1)
- **Validation Accuracy**: 97.1%
- **Training**: Mixed precision on an A100 GPU (sketched below)

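
The training loop itself is not published. For reference, here is a minimal sketch of what a mixed-precision training step looks like in PyTorch; the model, optimizer, criterion, and batch are placeholders, not the actual setup:

```python
import torch
from torch.cuda.amp import GradScaler, autocast

# Hypothetical mixed-precision training step -- the real training code for
# this model is not published; all names here are placeholders.
scaler = GradScaler()

def train_step(model, optimizer, criterion, images, targets):
    optimizer.zero_grad()
    with autocast():                # run the forward pass in float16 where safe
        loss = criterion(model(images), targets)
    scaler.scale(loss).backward()   # scale the loss to avoid float16 underflow
    scaler.step(optimizer)          # unscale gradients, then step the optimizer
    scaler.update()                 # adjust the loss scale for the next step
    return loss.item()
```
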
## Quick Start

```python
import torch
import torch.nn as nn
from PIL import Image
from torchvision import transforms
from torchvision.models import efficientnet_v2_s, EfficientNet_V2_S_Weights
from huggingface_hub import hf_hub_download

# Define model architecture
class EfficientNetV2Classifier(nn.Module):
    def __init__(self, num_classes=2, dropout_rate=0.3):
        super().__init__()
        self.backbone = efficientnet_v2_s(weights=EfficientNet_V2_S_Weights.IMAGENET1K_V1)
        num_features = self.backbone.classifier[1].in_features
        self.backbone.classifier = nn.Sequential(
            nn.Dropout(dropout_rate),
            nn.Linear(num_features, 512),
            nn.SiLU(inplace=True),
            nn.BatchNorm1d(512),
            nn.Dropout(dropout_rate),
            nn.Linear(512, 256),
            nn.SiLU(inplace=True),
            nn.BatchNorm1d(256),
            nn.Dropout(dropout_rate / 2),
            nn.Linear(256, num_classes)
        )

    def forward(self, x):
        return self.backbone(x)

# Download and load model
model_path = hf_hub_download(repo_id="wendys-llc/checkbox-classifier",
                             filename="checkbox_classifier.pth")
checkpoint = torch.load(model_path, map_location='cpu')

model = EfficientNetV2Classifier(num_classes=2)
model.load_state_dict(checkpoint['model_state_dict'])
model.eval()

# Image preprocessing
transform = transforms.Compose([
    transforms.Resize((128, 128)),
    transforms.ToTensor(),
    transforms.Normalize(mean=[0.485, 0.456, 0.406],
                         std=[0.229, 0.224, 0.225])
])

# Predict
def predict(image_path):
    image = Image.open(image_path).convert('RGB')
    input_tensor = transform(image).unsqueeze(0)

    with torch.no_grad():
        output = model(input_tensor)
        probabilities = torch.nn.functional.softmax(output, dim=1)

    predicted = torch.argmax(probabilities, dim=1).item()
    confidence = probabilities[0][predicted].item()

    labels = {0: "unchecked", 1: "checked"}
    return labels[predicted], confidence

# Example usage
result, conf = predict("checkbox.jpg")
print(f"Result: {result} (confidence: {conf:.1%})")
```
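
When classifying many crops, batching is faster than calling `predict` once per image. A sketch reusing the `model` and `transform` defined above; `paths` is a hypothetical list of image file paths:

```python
# Batched inference sketch; reuses `model` and `transform` defined above.
def predict_batch(paths, batch_size=32):
    labels = {0: "unchecked", 1: "checked"}
    results = []
    for i in range(0, len(paths), batch_size):
        # Stack the preprocessed crops into a single (B, 3, 128, 128) tensor
        batch = torch.stack([
            transform(Image.open(p).convert('RGB'))
            for p in paths[i:i + batch_size]
        ])
        with torch.no_grad():
            probs = torch.nn.functional.softmax(model(batch), dim=1)
        conf, pred = probs.max(dim=1)
        results.extend(
            (labels[p.item()], c.item()) for p, c in zip(pred, conf)
        )
    return results
```
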
## Training Dataset

Trained on the [wendys-llc/chkbx](https://huggingface.co/datasets/wendys-llc/chkbx) dataset, which contains ~6,000 annotated checkbox images.
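
The dataset can be pulled with the `datasets` library. A minimal sketch, assuming the standard loading API; check the dataset card for the actual split names and column schema:

```python
from datasets import load_dataset

# Split names and columns are assumptions -- see the dataset card.
ds = load_dataset("wendys-llc/chkbx")
print(ds)
```
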
## Limitations

- Trained specifically on UI checkboxes; may not work well on hand-drawn checkmarks
- Best performance on clear, high-contrast checkbox images
- Input images are resized to 128x128, so very small checkboxes may lose detail