File size: 6,862 Bytes
8aca528
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
import os
import torch
import numpy as np
from ultralytics import YOLO
from transformers import AutoProcessor
from transformers import AutoModelForTokenClassification
from utils import normalize_box, unnormalize_box, draw_output, create_df
from PIL import Image, ImageDraw
from vietocr.tool.predictor import Predictor
from vietocr.tool.config import Cfg

class Reciept_Analyzer:
    def __init__(self,
                 processor_pretrained='microsoft/layoutlmv3-base',
                 layoutlm_pretrained=os.path.join('models', 'checkpoint'),
                 yolo_pretrained=os.path.join('models', 'best.pt'),
                 vietocr_pretrained=os.path.join('models', 'vietocr', 'vgg_seq2seq.pth')):
        """Load every component of the receipt pipeline.

        Components: a LayoutLMv3 processor and token-classification model
        (key-information extraction), a YOLO detector (text-line boxes, only
        when a checkpoint path is given) and a VietOCR predictor (recognition).
        """
        print("Initializing processor")
        has_cuda = torch.cuda.is_available()
        print("Using GPU" if has_cuda else "No GPU detected, using CPU")

        # OCR is handled separately by VietOCR, so the processor's built-in
        # OCR stage is disabled.
        self.processor = AutoProcessor.from_pretrained(
            processor_pretrained, apply_ocr=False)
        print("Finished initializing processor")

        print("Initializing LayoutLM model")
        self.lalm_model = AutoModelForTokenClassification.from_pretrained(
            layoutlm_pretrained)
        print("Finished initializing LayoutLM model")

        # NOTE(review): when yolo_pretrained is None, self.yolo_model is never
        # set, yet forward() calls it unconditionally — confirm callers always
        # pass a checkpoint.
        if yolo_pretrained is not None:
            print("Initializing YOLO model")
            self.yolo_model = YOLO(yolo_pretrained)
            print("Finished initializing YOLO model")

        print("Initializing VietOCR model")
        ocr_config = Cfg.load_config_from_name('vgg_seq2seq')
        ocr_config['weights'] = vietocr_pretrained
        ocr_config['cnn']['pretrained'] = False
        ocr_config['device'] = 'cuda:0' if has_cuda else 'cpu'
        self.vietocr = Predictor(ocr_config)
        print("Finished initializing VietOCR model")

    def forward(self, img, output_path="output", is_save_cropped_img=False):
        """Run the full pipeline on one receipt image.

        img is anything PIL.Image.open accepts (path or file object).
        Artifacts (bbox overlay, optional crops, KIE result image) are written
        under output_path. Returns the dataframe built by utils.create_df.
        """
        image = Image.open(img)

        # 1) detect text-line boxes with YOLOv8 and order them
        ordered_boxes = self.sort_bboxes(self.yolov8_det(image))

        # 2) save a visualization of the detections
        overlay = image.copy()
        self.draw_bbox(overlay, ordered_boxes, output_path)

        # 3) crop each line; boxes are normalized for LayoutLM along the way
        crops, norm_boxes = self.get_cropped_images(
            image, ordered_boxes, is_save_cropped_img, output_path)

        # 4) recognize the text on each crop
        words, box_to_text = self.ocr(crops, norm_boxes)

        # 5) classify each word with LayoutLMv3
        pred_words, pred_labels, _ = self.kie(
            image, words, norm_boxes, box_to_text, output_path)

        # 6) tabulate the predictions
        return create_df(pred_words, pred_labels)
        

    def yolov8_det(self, img):
        """Detect text regions with YOLOv8; returns integer xyxy boxes
        of the first (only) result."""
        results = self.yolo_model.predict(source=img, conf=0.3, iou=0.1)
        return results[0].boxes.xyxy.int()

    def sort_bboxes(self, bboxes):
        """Convert boxes to ``[tlx, tly, brx, bry]`` int lists and sort them
        by top edge, then bottom-right x (roughly top-to-bottom reading order).

        Fix: the original called ``bbox_list.sort`` inside the append loop,
        re-sorting the accumulated list on every iteration (O(n^2 log n)).
        Sorting once after all boxes are collected produces the identical
        result in O(n log n).
        """
        bbox_list = []
        for box in bboxes:
            tlx, tly, brx, bry = map(int, box)
            bbox_list.append([tlx, tly, brx, bry])
        # primary key: top y; secondary key: bottom-right x
        bbox_list.sort(key=lambda x: (x[1], x[2]))
        return bbox_list
    
    def draw_bbox(self, image_draw, bboxes, output_path):
        """Draw a red rectangle for every box on ``image_draw`` (mutated in
        place) and save the result as ``<output_path>/bbox.jpg``.

        Fix: create ``output_path`` first (``kie`` already does this with
        ``os.makedirs(..., exist_ok=True)``) so the save does not fail on a
        fresh run when the directory is missing.
        """
        os.makedirs(output_path, exist_ok=True)
        draw = ImageDraw.Draw(image_draw)
        for box in bboxes:
            draw.rectangle(box, outline='red', width=2)
        image_draw.save(os.path.join(output_path, 'bbox.jpg'))
        print(f"Exported image with bounding boxes to {os.path.join(output_path, 'bbox.jpg')}")

    def get_cropped_images(self, input_image, bboxes, is_save_cropped=False, output_path="output"):
        """Crop every detected box out of ``input_image``.

        Returns ``(cropped_images, normalized_boxes)`` where the boxes come
        from ``utils.normalize_box`` (presumably the 0-1000 LayoutLM scale —
        confirm in utils). When ``is_save_cropped`` is true, each crop is also
        written to ``<output_path>/cropped/cropped_<i>.jpg``.
        """
        crops = []
        norm_boxes = []

        save_dir = None
        if is_save_cropped:
            save_dir = os.path.join(output_path, "cropped")
            if not os.path.exists(save_dir):
                os.makedirs(save_dir)

        for idx, box in enumerate(bboxes):
            tlx, tly, brx, bry = map(int, box)
            norm_boxes.append(
                normalize_box(box, input_image.width, input_image.height))
            crop = input_image.crop((tlx, tly, brx, bry))
            if save_dir is not None:
                crop.save(os.path.join(save_dir, f'cropped_{idx}.jpg'))
            crops.append(crop)

        return crops, norm_boxes

    def ocr(self, cropped_images, normalized_boxes):
        """Recognize the text on each crop with VietOCR.

        Returns the list of recognized strings plus a dict mapping the
        comma-joined normalized box to its text, which ``kie`` uses later to
        line tokens back up with their source boxes.
        """
        texts = []
        mapping_bbox_texts = {}
        for crop, norm_box in zip(cropped_images, normalized_boxes):
            recognized = self.vietocr.predict(crop).strip().replace('\n', ' ')
            texts.append(recognized)
            key = ','.join(map(str, norm_box))
            mapping_bbox_texts[key] = recognized
        return texts, mapping_bbox_texts
    
    def kie(self, img, texts, boxes, mapping_bbox_texts, output_path):
        """Run LayoutLMv3 token classification over the OCR'd words.

        img: the full receipt image (PIL). texts/boxes: parallel lists of
        recognized words and their normalized boxes. mapping_bbox_texts maps
        a comma-joined normalized box to its text (built by ocr()).
        output_path: directory for the annotated result image, or a non-str
        to skip saving.

        Returns (true_texts, true_predictions, true_boxes): one entry per
        first-subword token that has a non-padding box, with labels decoded
        via id2label and boxes unnormalized back to pixel coordinates.
        """
        # Processor was created with apply_ocr=False, so words and boxes are
        # supplied explicitly here.
        encoding = self.processor(img, texts,
                                  boxes=boxes,
                                  return_offsets_mapping=True,
                                  return_tensors='pt',
                                  max_length=512,
                                  padding='max_length')
        # pop() removes offset_mapping before the forward pass; it is kept
        # locally to detect subword tokens below.
        offset_mapping = encoding.pop('offset_mapping')

        with torch.no_grad():
            outputs = self.lalm_model(**encoding)

        id2label = self.lalm_model.config.id2label
        logits = outputs.logits
        # squeeze() drops the batch dimension (single image per call)
        token_boxes = encoding.bbox.squeeze().tolist()
        offset_mapping = offset_mapping.squeeze().tolist()

        predictions = logits.argmax(-1).squeeze().tolist()
        # A token whose character offset does not start at 0 is a continuation
        # subword of the previous word; only first subwords are kept.
        is_subword = np.array(offset_mapping)[:, 0] != 0

        true_predictions = []
        true_boxes = []
        true_texts = []
        for idx in range(len(predictions)):
            # [0, 0, 0, 0] is the padding/special-token box — skip it.
            if not is_subword[idx] and token_boxes[idx] != [0, 0, 0, 0]:
                true_predictions.append(id2label[predictions[idx]])
                true_boxes.append(unnormalize_box(
                    token_boxes[idx], img.width, img.height))
                # Look the original text up by its normalized box; '' when the
                # box is not found in the OCR mapping.
                true_texts.append(mapping_bbox_texts.get(
                    ','.join(map(str, token_boxes[idx])), ''))

        # Save an annotated visualization only when a path string was given.
        if isinstance(output_path, str):
            os.makedirs(output_path, exist_ok=True)
            img_output = draw_output(
                image=img,
                true_predictions=true_predictions,
                true_boxes=true_boxes
            )
            img_output.save(os.path.join(output_path, 'result.jpg'))
            print(f"Exported result to {os.path.join(output_path, 'result.jpg')}")
        return true_texts, true_predictions, true_boxes