# Reciept_Analyzer / models.py
# Huy0502's picture
# Upload 3 files
# 8aca528 verified
import os
import torch
import numpy as np
from ultralytics import YOLO
from transformers import AutoProcessor
from transformers import AutoModelForTokenClassification
from utils import normalize_box, unnormalize_box, draw_output, create_df
from PIL import Image, ImageDraw
from vietocr.tool.predictor import Predictor
from vietocr.tool.config import Cfg
class Reciept_Analyzer:
    """Receipt pipeline: detect text boxes, OCR each box, then label the text.

    The stages, in the order ``forward`` runs them, are:

    1. YOLO detection of text regions (``yolov8_det``).
    2. Sorting of the detected boxes (``sort_bboxes``).
    3. Cropping of each region and box normalization (``get_cropped_images``).
    4. VietOCR recognition of every crop (``ocr``).
    5. LayoutLMv3 token classification of the recognized text (``kie``).
    """

    def __init__(self,
                 processor_pretrained='microsoft/layoutlmv3-base',
                 layoutlm_pretrained=os.path.join('models', 'checkpoint'),
                 yolo_pretrained=os.path.join('models', 'best.pt'),
                 vietocr_pretrained=os.path.join(
                     'models', 'vietocr', 'vgg_seq2seq.pth'),
                 ):
        """Load the processor and the three models used by the pipeline.

        Args:
            processor_pretrained: Name or path passed to
                ``AutoProcessor.from_pretrained`` (loaded with
                ``apply_ocr=False`` because OCR is done by VietOCR).
            layoutlm_pretrained: Name or path of the token-classification
                checkpoint.
            yolo_pretrained: Path of the YOLO weights, or ``None`` to skip
                loading a detector. Without a detector ``yolov8_det`` (and
                therefore ``forward``) raises ``RuntimeError``.
            vietocr_pretrained: Path of the VietOCR ``vgg_seq2seq`` weights.
        """
        use_cuda = torch.cuda.is_available()

        print("Initializing processor")
        if use_cuda:
            print("Using GPU")
        else:
            print("No GPU detected, using CPU")
        # Remember the device so the LayoutLM model really runs where the
        # message above says it does.
        self.device = torch.device('cuda:0' if use_cuda else 'cpu')
        self.processor = AutoProcessor.from_pretrained(
            processor_pretrained, apply_ocr=False)
        print("Finished initializing processor")

        print("Initializing LayoutLM model")
        self.lalm_model = AutoModelForTokenClassification.from_pretrained(
            layoutlm_pretrained)
        self.lalm_model.to(self.device)
        print("Finished initializing LayoutLM model")

        # Always define the attribute so a missing detector is reported
        # explicitly instead of surfacing as an AttributeError later.
        self.yolo_model = None
        if yolo_pretrained is not None:
            print("Initializing YOLO model")
            self.yolo_model = YOLO(yolo_pretrained)
            print("Finished initializing YOLO model")

        print("Initializing VietOCR model")
        config = Cfg.load_config_from_name('vgg_seq2seq')
        config['weights'] = vietocr_pretrained
        config['cnn']['pretrained'] = False
        config['device'] = 'cuda:0' if use_cuda else 'cpu'
        self.vietocr = Predictor(config)
        print("Finished initializing VietOCR model")

    @staticmethod
    def _wants_export(output_path):
        """Return True when ``output_path`` names a directory to write into."""
        return isinstance(output_path, (str, os.PathLike))

    def forward(self, img, output_path="output", is_save_cropped_img=False):
        """Run the whole pipeline on one image.

        Args:
            img: Anything ``PIL.Image.open`` accepts (path or file object),
                or an already-loaded ``PIL.Image.Image``.
            output_path: Directory that receives ``bbox.jpg``, ``result.jpg``
                and, optionally, the crops. It is created when missing.
                Pass ``None`` to write nothing.
            is_save_cropped_img: Also save every cropped region under
                ``<output_path>/cropped``.

        Returns:
            Whatever ``create_df`` builds from the predicted texts and labels.

        Raises:
            RuntimeError: If no YOLO model was loaded.
        """
        # Work on an RGB copy: the exports are JPEG files (which cannot store
        # an alpha channel), and converting inside the ``with`` block lets the
        # underlying file be closed right away.
        if isinstance(img, Image.Image):
            input_image = img.convert('RGB')
        else:
            with Image.open(img) as opened:
                input_image = opened.convert('RGB')

        # detection with YOLOv8, then ordering of the boxes
        sorted_bboxes = self.sort_bboxes(self.yolov8_det(input_image))

        # draw the boxes on a copy so the image fed to the models stays clean
        if self._wants_export(output_path):
            self.draw_bbox(input_image.copy(), sorted_bboxes, output_path)

        # crop images
        cropped_images, normalized_boxes = self.get_cropped_images(
            input_image, sorted_bboxes, is_save_cropped_img, output_path)

        # recognition with VietOCR
        texts, mapping_bbox_texts = self.ocr(cropped_images, normalized_boxes)

        # KIE with LayoutLMv3
        pred_texts, pred_label, _boxes = self.kie(
            input_image, texts, normalized_boxes, mapping_bbox_texts,
            output_path)

        # create dataframe
        return create_df(pred_texts, pred_label)

    def yolov8_det(self, img, conf=0.3, iou=0.1):
        """Detect text regions and return their integer ``xyxy`` boxes.

        Args:
            img: Image handed to the YOLO model as ``source``.
            conf: Confidence threshold forwarded to ``predict``.
            iou: IoU threshold forwarded to ``predict``.

        Returns:
            The ``boxes.xyxy`` of the first prediction result, cast to int.

        Raises:
            RuntimeError: If no YOLO model was loaded.
        """
        detector = getattr(self, 'yolo_model', None)
        if detector is None:
            raise RuntimeError(
                "No YOLO model loaded; pass yolo_pretrained when creating "
                "Reciept_Analyzer")
        results = detector.predict(source=img, conf=conf, iou=iou)
        return results[0].boxes.xyxy.int()

    def sort_bboxes(self, bboxes):
        """Return the boxes as lists of ints, ordered by (top y, right x).

        Args:
            bboxes: Iterable of 4-element boxes ``(tlx, tly, brx, bry)``;
                every coordinate is converted with ``int``.

        Returns:
            A new list of ``[tlx, tly, brx, bry]`` lists.
        """
        # NOTE(review): the secondary key is the right edge (index 2), kept
        # as in the original; confirm the left edge was not intended.
        return sorted(
            ([int(value) for value in box] for box in bboxes),
            key=lambda b: (b[1], b[2]),
        )

    def draw_bbox(self, image_draw, bboxes, output_path):
        """Draw red rectangles on ``image_draw`` and export it as bbox.jpg.

        Args:
            image_draw: PIL image that is drawn on in place.
            bboxes: Boxes accepted by ``ImageDraw.rectangle``.
            output_path: Target directory (created when missing). When it is
                not a path, the image is drawn on but not saved.
        """
        draw = ImageDraw.Draw(image_draw)
        for box in bboxes:
            draw.rectangle(box, outline='red', width=2)

        if not self._wants_export(output_path):
            return
        # This is the first file written by the pipeline, so the directory
        # may not exist yet.
        os.makedirs(output_path, exist_ok=True)
        bbox_file = os.path.join(output_path, 'bbox.jpg')
        image_draw.save(bbox_file)
        print(f"Exported image with bounding boxes to {bbox_file}")

    def get_cropped_images(self, input_image, bboxes, is_save_cropped=False, output_path="output"):
        """Crop every box out of the image and normalize its coordinates.

        Args:
            input_image: PIL image to crop from.
            bboxes: Iterable of ``(tlx, tly, brx, bry)`` boxes.
            is_save_cropped: Save each crop as
                ``<output_path>/cropped/cropped_<index>.jpg``.
            output_path: Export directory; crops are not saved when it is
                not a path.

        Returns:
            ``(cropped_images, normalized_boxes)``, two lists aligned with
            ``bboxes``. The normalized boxes come from ``normalize_box``
            called with the image width and height.
        """
        save_crops = is_save_cropped and self._wants_export(output_path)
        if save_crops:
            cropped_folder = os.path.join(output_path, "cropped")
            os.makedirs(cropped_folder, exist_ok=True)

        normalized_boxes = []
        cropped_images = []
        for index, box in enumerate(bboxes):
            tlx, tly, brx, bry = map(int, box)
            normalized_boxes.append(
                normalize_box(box, input_image.width, input_image.height))
            crop = input_image.crop((tlx, tly, brx, bry))
            if save_crops:
                crop.save(os.path.join(cropped_folder, f'cropped_{index}.jpg'))
            cropped_images.append(crop)
        return cropped_images, normalized_boxes

    def ocr(self, cropped_images, normalized_boxes):
        """Recognize the text of every crop with VietOCR.

        Args:
            cropped_images: Images passed one by one to ``vietocr.predict``.
            normalized_boxes: Boxes aligned with ``cropped_images``.

        Returns:
            ``(texts, mapping_bbox_texts)``: the recognized strings in input
            order, and a dict from the comma-joined normalized box to its
            text. Boxes with identical normalized coordinates share one key,
            so the later text wins.
        """
        mapping_bbox_texts = {}
        texts = []
        for crop, normalized_box in zip(cropped_images, normalized_boxes):
            # Keep each region on a single line.
            text = self.vietocr.predict(crop).strip().replace('\n', ' ')
            texts.append(text)
            mapping_bbox_texts[','.join(map(str, normalized_box))] = text
        return texts, mapping_bbox_texts

    def kie(self, img, texts, boxes, mapping_bbox_texts, output_path):
        """Label every recognized text region with the LayoutLM model.

        Args:
            img: PIL image of the receipt.
            texts: Recognized strings, one per box.
            boxes: Normalized boxes aligned with ``texts``.
            mapping_bbox_texts: Dict from comma-joined normalized box to text,
                as produced by ``ocr``.
            output_path: Directory that receives ``result.jpg`` (created when
                missing); nothing is saved when it is not a path.

        Returns:
            ``(true_texts, true_predictions, true_boxes)``: for every kept
            token, the text looked up by its box (``''`` when unknown), the
            predicted label name, and the box from ``unnormalize_box``.
        """
        # truncation keeps the input within the 512 tokens requested below;
        # regions beyond that limit receive no prediction.
        encoding = self.processor(img, texts,
                                  boxes=boxes,
                                  return_offsets_mapping=True,
                                  return_tensors='pt',
                                  max_length=512,
                                  padding='max_length',
                                  truncation=True)
        # The model does not accept offsets; they are only used below to
        # find the first token of each region.
        offset_mapping = encoding.pop('offset_mapping')[0].tolist()
        encoding = encoding.to(self.lalm_model.device)

        with torch.no_grad():
            outputs = self.lalm_model(**encoding)

        id2label = self.lalm_model.config.id2label
        predictions = outputs.logits.argmax(-1)[0].tolist()
        token_boxes = encoding['bbox'][0].tolist()

        true_predictions = []
        true_boxes = []
        true_texts = []
        for label_id, token_box, (start, _end) in zip(
                predictions, token_boxes, offset_mapping):
            # Keep one prediction per region: skip tokens that do not start
            # at offset 0 and tokens carrying an all-zero box (presumably
            # special/padding tokens).
            if start != 0 or token_box == [0, 0, 0, 0]:
                continue
            true_predictions.append(id2label[label_id])
            true_boxes.append(
                unnormalize_box(token_box, img.width, img.height))
            true_texts.append(
                mapping_bbox_texts.get(','.join(map(str, token_box)), ''))

        if self._wants_export(output_path):
            os.makedirs(output_path, exist_ok=True)
            img_output = draw_output(
                image=img,
                true_predictions=true_predictions,
                true_boxes=true_boxes
            )
            result_file = os.path.join(output_path, 'result.jpg')
            img_output.save(result_file)
            print(f"Exported result to {result_file}")
        return true_texts, true_predictions, true_boxes