|
from transformers import Pipeline |
|
import torch |
|
from PIL import Image |
|
import matplotlib.pyplot as plt |
|
import matplotlib.patches as patches |
|
from torchvision.transforms import transforms |
|
from huggingface_hub import hf_hub_download |
|
import io |
|
|
|
class FormFieldDetectionPipeline(Pipeline):
    """Detect form fields in an image and render the detections on top of it.

    ``preprocess`` loads the image and converts it to a tensor, ``_forward``
    runs ``self.model`` on that tensor, and ``postprocess`` keeps detections
    whose score exceeds a confidence threshold, draws boxes for the kept
    detections that have an odd label, and returns the rendered image together
    with the kept boxes and labels.
    """

    # Score cut-off used when no ``confidence_threshold`` is supplied.
    DEFAULT_CONFIDENCE_THRESHOLD = 0.8

    def __init__(self, model, tokenizer=None, **kwargs):
        """Create the pipeline.

        Args:
            model: Detection model, forwarded to ``Pipeline.__init__``.
            tokenizer: Forwarded to ``Pipeline.__init__``.
            **kwargs: Forwarded to ``Pipeline.__init__``. A
                ``confidence_threshold`` entry also becomes the instance
                default that ``postprocess`` falls back to.
        """
        super().__init__(model=model, tokenizer=tokenizer, **kwargs)
        # Keep the attribute in sync with a constructor-supplied threshold so
        # it is the single source of postprocess's default.
        self.confidence_threshold = kwargs.get(
            "confidence_threshold", self.DEFAULT_CONFIDENCE_THRESHOLD
        )

    def _sanitize_parameters(self, **kwargs):
        """Route call-time keyword arguments to the pipeline stages.

        Returns:
            A ``(preprocess_kwargs, forward_kwargs, postprocess_kwargs)``
            triple. Only ``confidence_threshold`` is recognised; it is routed
            to ``postprocess``.
        """
        preprocess_kwargs = {}
        forward_kwargs = {}
        postprocess_kwargs = {}
        if "confidence_threshold" in kwargs:
            postprocess_kwargs["confidence_threshold"] = kwargs["confidence_threshold"]
        return preprocess_kwargs, forward_kwargs, postprocess_kwargs

    def preprocess(self, image):
        """Load ``image`` as RGB and convert it to a tensor.

        Args:
            image: A file path (``str``) or a ``PIL.Image.Image``.

        Returns:
            Dict with ``"image_tensor"`` (output of ``transforms.ToTensor``)
            and ``"original_image"`` (the RGB PIL image, kept for drawing).

        Raises:
            ValueError: If ``image`` is neither a ``str`` nor a PIL image.
        """
        if isinstance(image, str):
            # The context manager closes the file handle; convert() returns a
            # fully loaded copy that stays valid after the file is closed.
            with Image.open(image) as opened:
                image = opened.convert("RGB")
        elif isinstance(image, Image.Image):
            # Normalise in-memory images the same way as files so the tensor
            # always has three channels (RGBA / L / P images would not).
            if image.mode != "RGB":
                image = image.convert("RGB")
        else:
            raise ValueError("Input must be an image path or PIL Image")

        image_tensor = transforms.ToTensor()(image)
        return {"image_tensor": image_tensor, "original_image": image}

    def _forward(self, model_inputs):
        """Run the detector on the preprocessed tensor.

        Returns:
            Dict with ``"predictions"`` (the model's result for the single
            input image) and the untouched ``"original_image"`` for drawing.
        """
        # Bug fix: the model must receive the tensor, not the PIL image.
        image_tensor = model_inputs["image_tensor"]
        with torch.no_grad():
            # The model is called with a list of images and indexed per image;
            # assumes torchvision-style detection output (one dict with
            # "boxes"/"labels"/"scores" per image) — TODO confirm.
            predictions = self.model([image_tensor])
        return {
            "predictions": predictions[0],
            "original_image": model_inputs["original_image"],
        }

    def postprocess(self, model_outputs, confidence_threshold=None):
        """Filter detections by score and render them on the original image.

        Args:
            model_outputs: Output of ``_forward``.
            confidence_threshold: Detections with a score strictly greater
                than this are kept. ``None`` uses ``self.confidence_threshold``
                (0.8 unless configured otherwise).

        Returns:
            Dict with ``"image"`` (a PIL image of the rendered plot with red
            boxes), plus ``"boxes"`` and ``"labels"`` as lists for all kept
            detections.
        """
        if confidence_threshold is None:
            confidence_threshold = getattr(
                self, "confidence_threshold", self.DEFAULT_CONFIDENCE_THRESHOLD
            )

        predictions = model_outputs["predictions"]
        original_image = model_outputs["original_image"]

        mask = predictions["scores"] > confidence_threshold
        boxes = predictions["boxes"][mask]
        labels = predictions["labels"][mask]

        # Use an explicit figure and always close it; otherwise every call
        # leaves an open pyplot figure behind and memory grows unbounded.
        fig, ax = plt.subplots(figsize=(12, 8))
        try:
            ax.imshow(original_image)
            for box, label in zip(boxes, labels):
                # Only odd label ids are drawn, but all kept detections are
                # returned. NOTE(review): presumably odd ids are the field
                # classes of interest — confirm against the model's label map.
                if int(label) % 2 == 1:
                    x1, y1, x2, y2 = box.tolist()
                    ax.add_patch(
                        patches.Rectangle(
                            (x1, y1), x2 - x1, y2 - y1,
                            linewidth=1, edgecolor="r", facecolor="none",
                        )
                    )
            ax.axis("off")

            buf = io.BytesIO()
            fig.savefig(buf, format="png", bbox_inches="tight", pad_inches=0)
        finally:
            plt.close(fig)

        buf.seek(0)
        output_image = Image.open(buf)
        # Decode now so the returned image does not depend on lazy reads.
        output_image.load()

        return {
            "image": output_image,
            "boxes": boxes.tolist(),
            "labels": labels.tolist(),
        }
|
|
|
|
|
def pipeline():
    """Return a ``FormFieldDetectionPipeline`` for the form-field detector.

    The pipeline is built for the ``"object-detection"`` task with the
    ``AaronNL/form-field-detector`` repository id as its model.

    NOTE(review): ``model`` receives a repository id string rather than a
    loaded model object, and ``hf_hub_download`` is imported but not called
    anywhere in the visible source. Confirm that the ``Pipeline`` base class
    accepts a string here, or download and load the weights before
    constructing the pipeline.
    """
    repo_id = "AaronNL/form-field-detector"
    return FormFieldDetectionPipeline(model=repo_id, task="object-detection")