---
license: mit
tags:
- vision
- image-segmentation
- instance-segmentation
datasets:
- custom-germination-dataset
widget:
- src: https://example.com/path/to/germination-image1.jpg
  example_title: Germination Image 1
- src: https://example.com/path/to/germination-image2.jpg
  example_title: Germination Image 2
---

GermiNet: A MaskFormer Model for Germination Counting

GermiNet is a MaskFormer model fine-tuned on a custom germination dataset for instance segmentation (tiny-sized version, Swin backbone). It is based on the MaskFormer architecture introduced in the paper "Per-Pixel Classification is Not All You Need for Semantic Segmentation" and first released in the authors' official repository.

Disclaimer: This model card is written by [Your Name/Organization] with assistance from Grok (xAI) for the Hugging Face community.

Model Description

GermiNet is a MaskFormer-based instance segmentation model fine-tuned to detect and segment "normal" and "abnormal" seeds in germination images. It is fine-tuned from the facebook/maskformer-swin-tiny-coco pre-trained checkpoint, which uses a Swin-Tiny transformer backbone. For each query, the model predicts a mask and a label over three classes ("background," "normal," and "abnormal"), plus an additional "no object" class that MaskFormer handles internally. The model was trained on a small custom dataset as a proof of concept for automating germination counting in agricultural research.
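To make the query/"no object" scheme concrete, here is a minimal sketch that decodes one image's class logits, assuming an array shaped like `class_queries_logits[0]`, i.e. `(num_queries, num_classes + 1)`, with "no object" in the last column. The helper name `decode_queries`, the 0.5 score threshold, and the toy logits are illustrative assumptions, not part of GermiNet or transformers.

```python
import numpy as np

def decode_queries(class_logits, id2label, score_threshold=0.5):
    """Hypothetical helper: return (query_index, label, score) for kept queries.

    class_logits has shape (num_queries, num_classes + 1); the last column is
    MaskFormer's "no object" class, so queries assigned to it are dropped.
    """
    # Numerically stable softmax over all num_classes + 1 entries
    shifted = class_logits - class_logits.max(axis=-1, keepdims=True)
    probs = np.exp(shifted)
    probs /= probs.sum(axis=-1, keepdims=True)

    no_object = class_logits.shape[-1] - 1
    detections = []
    for query_index, query_probs in enumerate(probs):
        cls = int(query_probs.argmax())
        # Keep only confident predictions of a real class
        if cls != no_object and query_probs[cls] >= score_threshold:
            detections.append((query_index, id2label[cls], float(query_probs[cls])))
    return detections

# Toy logits for 3 queries over background / normal / abnormal / no object
id2label = {0: "background", 1: "normal", 2: "abnormal"}
logits = np.array([
    [0.0, 4.0, 0.0, 0.0],  # confidently "normal"
    [0.0, 0.0, 0.0, 5.0],  # "no object": dropped
    [0.0, 0.0, 3.0, 0.0],  # confidently "abnormal"
])
print(decode_queries(logits, id2label))
```

With the bias toward "no object" described under Limitations, most of GermiNet's queries would be expected to fall into the dropped branch.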


Intended Uses & Limitations

GermiNet is intended for instance segmentation tasks in agricultural research, specifically for detecting and segmenting "normal" and "abnormal" seeds in germination images. It can be used with tools like CVAT for automated annotation workflows.
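For germination counting, detections must be reduced to per-class counts. The sketch below assumes the `segments_info` format returned by the transformers image processor's `post_process_instance_segmentation`, where each segment is a dict with `label_id` and `score` keys; the `count_seeds` helper, the 0.5 minimum score, and the sample segments are hypothetical.

```python
from collections import Counter

def count_seeds(segments_info, id2label, min_score=0.5):
    """Hypothetical helper: count predicted instances per class.

    Skips "background" segments and segments scoring below min_score.
    """
    counts = Counter()
    for segment in segments_info:
        label = id2label[segment["label_id"]]
        if label != "background" and segment["score"] >= min_score:
            counts[label] += 1
    return dict(counts)

id2label = {0: "background", 1: "normal", 2: "abnormal"}
# Made-up segments for illustration only
segments_info = [
    {"id": 1, "label_id": 1, "score": 0.91},
    {"id": 2, "label_id": 1, "score": 0.76},
    {"id": 3, "label_id": 2, "score": 0.64},
    {"id": 4, "label_id": 2, "score": 0.31},  # below threshold, ignored
    {"id": 5, "label_id": 0, "score": 0.88},  # background, ignored
]
print(count_seeds(segments_info, id2label))  # {'normal': 2, 'abnormal': 1}
```

The threshold is a design choice: raising it trades missed seeds for fewer false positives, which matters for a model that is still underfitting.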

Limitations

  • The model was trained on a small dataset (18 images), which limits its generalization.
  • In local inference tests, the model is biased toward "no object" predictions: it produces few "normal" detections and no "abnormal" detections, which suggests underfitting.
  • Mask resolution is 56x56 (upscaled to 224x224 or higher for visualization), which may miss fine details.
  • The model requires further training with a larger dataset and more epochs for improved performance.
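To illustrate why a 56x56 mask can miss fine details, the sketch below upsamples a mask by a factor of 4 (56 to 224) with nearest-neighbor interpolation. Random values stand in for one query's mask logits, and the `upsample_mask_nearest` helper is a hypothetical name; smoother interpolation (e.g. bilinear on logits) changes edge appearance but likewise cannot recover detail finer than the 56x56 grid.

```python
import numpy as np

def upsample_mask_nearest(mask, scale):
    """Hypothetical helper: nearest-neighbor upsampling by an integer scale."""
    return np.repeat(np.repeat(mask, scale, axis=0), scale, axis=1)

rng = np.random.default_rng(0)
mask_logits = rng.normal(size=(56, 56))           # stand-in for one query's mask logits
binary_mask = (mask_logits > 0).astype(np.uint8)  # sigmoid(x) > 0.5 is equivalent to x > 0
upsampled = upsample_mask_nearest(binary_mask, 4)

print(upsampled.shape)  # (224, 224)
# Each 56x56 cell becomes a uniform 4x4 block, so no sub-cell detail is recovered
assert (upsampled[:4, :4] == binary_mask[0, 0]).all()
```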

See the Hugging Face model hub for other fine-tuned versions.

How to Use

Here’s how to use this model for instance segmentation:

import requests
import numpy as np
import torch
from PIL import Image
from transformers import AutoImageProcessor, MaskFormerForInstanceSegmentation

# Load GermiNet fine-tuned on the custom germination dataset
processor = AutoImageProcessor.from_pretrained("your-username/germi-net")
model = MaskFormerForInstanceSegmentation.from_pretrained("your-username/germi-net")

# Load an image (replace with your image URL or local path)
url = "https://example.com/path/to/germination-image.jpg"
image = Image.open(requests.get(url, stream=True).raw).convert("RGB")
# Alternatively, use a local image
# image = Image.open("path/to/your/image.jpg").convert("RGB")
inputs = processor(images=image, return_tensors="pt")

# Run inference
with torch.no_grad():
    outputs = model(**inputs)

# MaskFormer predicts one class distribution and one mask per query
class_queries_logits = outputs.class_queries_logits  # Shape: (batch_size, num_queries, num_classes + 1)
masks_queries_logits = outputs.masks_queries_logits  # Shape: (batch_size, num_queries, mask_height, mask_width), ~1/4 of the processed input size

# Post-process predictions for the first (only) image in the batch
predicted_classes = class_queries_logits[0].argmax(-1).cpu().numpy()  # (num_queries,)
mask_predictions = masks_queries_logits[0].sigmoid().cpu().numpy()   # (num_queries, mask_height, mask_width)
binary_masks = (mask_predictions > 0.5).astype(np.uint8)

# Map predictions to labels; the last index (num_classes) is the "no object" class
id2label = {0: "background", 1: "normal", 2: "abnormal", 3: "no object"}
predicted_labels = [id2label[int(cls)] for cls in predicted_classes]
print("Predicted labels:", predicted_labels)

# Optional: Visualize (requires matplotlib and opencv-python)
import matplotlib.pyplot as plt
import cv2

# Resize the image and masks to the same canvas. This assumes the processor resized
# the image without padding, so each mask covers the whole image.
visualization_size = (800, 800)  # (width, height), the order cv2.resize expects
resized_image = cv2.resize(np.array(image), visualization_size, interpolation=cv2.INTER_LINEAR)
resized_masks = np.stack([
    cv2.resize(mask, visualization_size, interpolation=cv2.INTER_NEAREST) for mask in binary_masks
])

# Color queries predicted as "normal" (green) or "abnormal" (red); skip the rest
colors = {"normal": (0, 255, 0), "abnormal": (255, 0, 0)}
overlay = resized_image.copy()
for mask, label in zip(resized_masks, predicted_labels):
    if label in colors:
        overlay[mask.astype(bool)] = colors[label]
blended = cv2.addWeighted(resized_image, 0.5, overlay, 0.5, 0)

plt.figure(figsize=(8, 8))
plt.imshow(blended)
plt.axis("off")
plt.title("GermiNet predictions")
plt.show()