import os

import numpy as np
import torch
from PIL import Image, ImageDraw
from transformers import AutoProcessor, AutoModelForTokenClassification
from ultralytics import YOLO
from vietocr.tool.config import Cfg
from vietocr.tool.predictor import Predictor

from utils import normalize_box, unnormalize_box, draw_output, create_df
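
# Pipeline: YOLOv8 detects text-line boxes, VietOCR transcribes each cropped
# line, and LayoutLMv3 classifies the resulting tokens for key information
# extraction (KIE). The `utils` helpers are assumed to convert boxes to and
# from LayoutLM's 0-1000 normalized coordinate space and to render/tabulate
# the predictions; their signatures are inferred from the call sites below.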


class Reciept_Analyzer:
    def __init__(self,
                 processor_pretrained='microsoft/layoutlmv3-base',
                 layoutlm_pretrained=os.path.join('models', 'checkpoint'),
                 yolo_pretrained=os.path.join('models', 'best.pt'),
                 vietocr_pretrained=os.path.join('models', 'vietocr', 'vgg_seq2seq.pth')):
print("Initializing processor") | |
if torch.cuda.is_available(): | |
print("Using GPU") | |
else: | |
print("No GPU detected, using CPU") | |
self.processor = AutoProcessor.from_pretrained( | |
processor_pretrained, apply_ocr=False) | |
print("Finished initializing processor") | |
print("Initializing LayoutLM model") | |
self.lalm_model = AutoModelForTokenClassification.from_pretrained( | |
layoutlm_pretrained) | |
print("Finished initializing LayoutLM model") | |
if yolo_pretrained is not None: | |
print("Initializing YOLO model") | |
self.yolo_model = YOLO(yolo_pretrained) | |
print("Finished initializing YOLO model") | |
print("Initializing VietOCR model") | |
config = Cfg.load_config_from_name('vgg_seq2seq') | |
config['weights'] = vietocr_pretrained | |
config['cnn']['pretrained']= False | |
config['device'] = 'cuda:0' if torch.cuda.is_available() else 'cpu' | |
self.vietocr = Predictor(config) | |
print("Finished initializing VietOCR model") | |

    def forward(self, img, output_path="output", is_save_cropped_img=False):
        """Run the full pipeline on an image path and return a dataframe of
        predicted texts and labels."""
        input_image = Image.open(img)
        os.makedirs(output_path, exist_ok=True)

        # detection with YOLOv8
        bboxes = self.yolov8_det(input_image)
        # sort boxes into reading order
        sorted_bboxes = self.sort_bboxes(bboxes)
        # draw bounding boxes for inspection
        image_draw = input_image.copy()
        self.draw_bbox(image_draw, sorted_bboxes, output_path)
        # crop detected regions
        cropped_images, normalized_boxes = self.get_cropped_images(
            input_image, sorted_bboxes, is_save_cropped_img, output_path)
        # recognition with VietOCR
        texts, mapping_bbox_texts = self.ocr(cropped_images, normalized_boxes)
        # KIE with LayoutLMv3
        pred_texts, pred_labels, boxes = self.kie(
            input_image, texts, normalized_boxes, mapping_bbox_texts, output_path)
        # assemble results into a dataframe
        return create_df(pred_texts, pred_labels)

    def yolov8_det(self, img):
        # single-image prediction; returns integer xyxy boxes
        return self.yolo_model.predict(source=img, conf=0.3, iou=0.1)[0].boxes.xyxy.int()

    def sort_bboxes(self, bboxes):
        bbox_list = []
        for box in bboxes:
            tlx, tly, brx, bry = map(int, box)
            bbox_list.append([tlx, tly, brx, bry])
        # sort top-to-bottom by top y, ties broken by bottom-right x
        bbox_list.sort(key=lambda x: (x[1], x[2]))
        return bbox_list

    def draw_bbox(self, image_draw, bboxes, output_path):
        draw = ImageDraw.Draw(image_draw)
        for box in bboxes:
            draw.rectangle(box, outline='red', width=2)
        out_file = os.path.join(output_path, 'bbox.jpg')
        image_draw.save(out_file)
        print(f"Exported image with bounding boxes to {out_file}")

    def get_cropped_images(self, input_image, bboxes, is_save_cropped=False, output_path="output"):
        normalized_boxes = []
        cropped_images = []
        if is_save_cropped:
            cropped_folder = os.path.join(output_path, "cropped")
            os.makedirs(cropped_folder, exist_ok=True)
        for i, box in enumerate(bboxes):
            tlx, tly, brx, bry = map(int, box)
            # normalize to LayoutLM's coordinate space for the KIE stage
            normalized_box = normalize_box(box, input_image.width, input_image.height)
            normalized_boxes.append(normalized_box)
            cropped_ = input_image.crop((tlx, tly, brx, bry))
            if is_save_cropped:
                cropped_.save(os.path.join(cropped_folder, f'cropped_{i}.jpg'))
            cropped_images.append(cropped_)
        return cropped_images, normalized_boxes

    def ocr(self, cropped_images, normalized_boxes):
        mapping_bbox_texts = {}
        texts = []
        for img, normalized_box in zip(cropped_images, normalized_boxes):
            result = self.vietocr.predict(img)
            text = result.strip().replace('\n', ' ')
            texts.append(text)
            # key the text by its normalized box so kie() can map token boxes
            # back to the recognized string
            mapping_bbox_texts[','.join(map(str, normalized_box))] = text
        return texts, mapping_bbox_texts

    def kie(self, img, texts, boxes, mapping_bbox_texts, output_path):
        encoding = self.processor(img, texts,
                                  boxes=boxes,
                                  return_offsets_mapping=True,
                                  return_tensors='pt',
                                  max_length=512,
                                  padding='max_length',
                                  truncation=True)
        offset_mapping = encoding.pop('offset_mapping')

        with torch.no_grad():
            outputs = self.lalm_model(**encoding)

        id2label = self.lalm_model.config.id2label
        logits = outputs.logits
        token_boxes = encoding.bbox.squeeze().tolist()
        offset_mapping = offset_mapping.squeeze().tolist()
        predictions = logits.argmax(-1).squeeze().tolist()

        # a token is a subword continuation if its offset does not start at 0
        is_subword = np.array(offset_mapping)[:, 0] != 0

        true_predictions = []
        true_boxes = []
        true_texts = []
        for idx in range(len(predictions)):
            # keep only the first subword of each word and skip padding/special
            # tokens, whose boxes are [0, 0, 0, 0]
            if not is_subword[idx] and token_boxes[idx] != [0, 0, 0, 0]:
                true_predictions.append(id2label[predictions[idx]])
                true_boxes.append(unnormalize_box(
                    token_boxes[idx], img.width, img.height))
                true_texts.append(mapping_bbox_texts.get(
                    ','.join(map(str, token_boxes[idx])), ''))

        if isinstance(output_path, str):
            os.makedirs(output_path, exist_ok=True)
            img_output = draw_output(
                image=img,
                true_predictions=true_predictions,
                true_boxes=true_boxes)
            img_output.save(os.path.join(output_path, 'result.jpg'))
            print(f"Exported result to {os.path.join(output_path, 'result.jpg')}")

        return true_texts, true_predictions, true_boxes
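

# Minimal usage sketch, assuming the checkpoints exist at the default paths
# configured above; 'receipt.jpg' is a hypothetical sample image, not a file
# shipped with the repo.
if __name__ == '__main__':
    analyzer = Reciept_Analyzer()
    df = analyzer.forward('receipt.jpg', output_path='output')
    print(df)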