import io

import matplotlib.patches as patches
import matplotlib.pyplot as plt
import torch
from huggingface_hub import hf_hub_download
from PIL import Image
from torchvision.transforms import transforms
from transformers import Pipeline


class FormFieldDetectionPipeline(Pipeline):
    """Detect form fields in an image and draw them as red rectangles.

    The pipeline accepts an image path or a PIL image, runs the wrapped
    detection model on it, filters detections by score, and returns both the
    filtered detections and a rendered visualization.
    """

    def __init__(self, model, tokenizer=None, **kwargs):
        """Initialize the pipeline and set the default score threshold."""
        super().__init__(model=model, tokenizer=tokenizer, **kwargs)
        # Fallback used by ``postprocess`` when no threshold is passed.
        self.confidence_threshold = 0.8

    def _sanitize_parameters(self, **kwargs):
        """Split call-time kwargs into preprocess/forward/postprocess dicts.

        Only ``confidence_threshold`` is recognized; it is forwarded to
        ``postprocess``.
        """
        preprocess_kwargs = {}
        postprocess_kwargs = {}

        # Allow confidence threshold to be configured
        if "confidence_threshold" in kwargs:
            postprocess_kwargs["confidence_threshold"] = kwargs["confidence_threshold"]

        return preprocess_kwargs, {}, postprocess_kwargs

    def preprocess(self, image):
        """Load the input and convert it to a tensor.

        Args:
            image: Path to an image file (``str``) or a ``PIL.Image.Image``.

        Returns:
            A dict with ``"image_tensor"`` (output of ``ToTensor``) and
            ``"original_image"`` (the RGB PIL image).

        Raises:
            ValueError: If ``image`` is neither a ``str`` nor a PIL image.
        """
        if isinstance(image, str):
            # ``convert`` returns a fully loaded copy, so the file handle can
            # be released as soon as the ``with`` block exits.
            with Image.open(image) as opened:
                image = opened.convert("RGB")
        elif isinstance(image, Image.Image):
            # Path inputs are always converted to RGB; do the same for PIL
            # inputs so both branches hand the model the same channel layout.
            if image.mode != "RGB":
                image = image.convert("RGB")
        else:
            raise ValueError("Input must be an image path or PIL Image")

        image_tensor = transforms.ToTensor()(image)
        return {"image_tensor": image_tensor, "original_image": image}

    def _forward(self, model_inputs):
        """Run the model on the preprocessed tensor.

        Args:
            model_inputs: The dict produced by ``preprocess``.

        Returns:
            A dict with ``"predictions"`` (the first element of the model
            output) and ``"original_image"`` (the PIL image, passed through
            for visualization).
        """
        # The model must receive the tensor, not the PIL image.
        image_tensor = model_inputs["image_tensor"]

        with torch.no_grad():
            predictions = self.model([image_tensor])

        return {
            "predictions": predictions[0],
            "original_image": model_inputs["original_image"],
        }

    def postprocess(self, model_outputs, confidence_threshold=None):
        """Filter detections by score and render the field boxes.

        Args:
            model_outputs: The dict produced by ``_forward``.
            confidence_threshold: Detections with a score strictly greater
                than this value are kept. ``None`` falls back to
                ``self.confidence_threshold``.

        Returns:
            A dict with:
              * ``"image"``: PIL image of the input with a red rectangle
                around every kept detection whose label is odd.
              * ``"boxes"``: list of ``[x1, y1, x2, y2]`` for all kept
                detections (odd and even labels alike).
              * ``"labels"``: list of labels for all kept detections.
        """
        if confidence_threshold is None:
            confidence_threshold = self.confidence_threshold

        predictions = model_outputs["predictions"]
        original_image = model_outputs["original_image"]

        # Filter predictions by confidence
        keep = predictions["scores"] > confidence_threshold
        boxes = predictions["boxes"][keep]
        labels = predictions["labels"][keep]

        # Use an explicit figure so it can be closed afterwards; pyplot keeps
        # every open figure alive in a global registry, which would otherwise
        # grow by one figure per call.
        fig, ax = plt.subplots(figsize=(12, 8))
        try:
            ax.imshow(original_image)

            # Draw boxes for fields (odd-numbered labels)
            for box, label in zip(boxes, labels):
                if label % 2 == 1:  # Only odd numbered labels are fields
                    x1, y1, x2, y2 = box.tolist()
                    rect = patches.Rectangle(
                        (x1, y1),
                        x2 - x1,
                        y2 - y1,
                        linewidth=1,
                        edgecolor="r",
                        facecolor="none",
                    )
                    ax.add_patch(rect)

            ax.axis("off")

            # Convert plot to image
            buf = io.BytesIO()
            fig.savefig(buf, format="png", bbox_inches="tight", pad_inches=0)
        finally:
            plt.close(fig)

        buf.seek(0)
        output_image = Image.open(buf)
        # ``Image.open`` is lazy; decode now so the result is self-contained.
        output_image.load()

        return {
            "image": output_image,
            "boxes": boxes.tolist(),
            "labels": labels.tolist(),
        }


# Add this to your model's repo
def pipeline():
    """Build a ``FormFieldDetectionPipeline`` for the published detector."""
    # NOTE(review): the repo id is passed as a plain string here, while the
    # base ``Pipeline`` constructor is documented as taking a loaded model
    # object. Confirm this call works as intended, or load the weights first.
    return FormFieldDetectionPipeline(
        model="AaronNL/form-field-detector",
        task="object-detection"
    )