Spaces:
Runtime error
Runtime error
| # Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved. | |
| # | |
| # Licensed under the Apache License, Version 2.0 (the "License"); | |
| # you may not use this file except in compliance with the License. | |
| # You may obtain a copy of the License at | |
| # | |
| # http://www.apache.org/licenses/LICENSE-2.0 | |
| # | |
| # Unless required by applicable law or agreed to in writing, software | |
| # distributed under the License is distributed on an "AS IS" BASIS, | |
| # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | |
| # See the License for the specific language governing permissions and | |
| # limitations under the License. | |
| import os | |
| import sys | |
| __dir__ = os.path.dirname(os.path.abspath(__file__)) | |
| sys.path.append(__dir__) | |
| sys.path.insert(0, os.path.abspath(os.path.join(__dir__, '../..'))) | |
| os.environ["FLAGS_allocator_strategy"] = 'auto_growth' | |
| import cv2 | |
| import json | |
| import numpy as np | |
| import time | |
| import tools.infer.utility as utility | |
| from ppocr.data import create_operators, transform | |
| from ppocr.postprocess import build_post_process | |
| from ppocr.utils.logging import get_logger | |
| from ppocr.utils.visual import draw_ser_results | |
| from ppocr.utils.utility import get_image_file_list, check_and_read | |
| from ppstructure.utility import parse_args | |
| from paddleocr import PaddleOCR | |
| logger = get_logger() | |
class SerPredictor(object):
    """Token-level SER predictor built on an exported Paddle inference model.

    OCR is run through a ``PaddleOCR`` engine inside the pre-processing
    pipeline (``VQATokenLabelEncode`` receives the engine). The resulting
    tokens and the resized, normalized image are fed to the 'ser' predictor
    created by ``utility.create_predictor``. The raw predictions are decoded
    with ``VQASerTokenLayoutLMPostProcess``.
    """

    def __init__(self, args):
        """Build the OCR engine, the pre/post-processing ops and the predictor.

        Args:
            args: parsed command-line namespace. The attributes read here are
                ``use_angle_cls``, ``det_model_dir``, ``rec_model_dir``,
                ``use_gpu``, ``kie_algorithm``, ``ser_dict_path`` and
                ``ocr_order_method``, plus whatever
                ``utility.create_predictor`` reads from it.
        """
        self.ocr_engine = PaddleOCR(
            use_angle_cls=args.use_angle_cls,
            det_model_dir=args.det_model_dir,
            rec_model_dir=args.rec_model_dir,
            show_log=False,
            use_gpu=args.use_gpu)

        # NOTE: the order of 'keep_keys' below defines the positional layout
        # of the list produced by ``transform``. ``__call__`` relies on it
        # (index 6 -> segment_offset_id, index 7 -> ocr_info).
        pre_process_list = [{
            'VQATokenLabelEncode': {
                'algorithm': args.kie_algorithm,
                'class_path': args.ser_dict_path,
                'contains_re': False,
                'ocr_engine': self.ocr_engine,
                'order_method': args.ocr_order_method,
            }
        }, {
            'VQATokenPad': {
                'max_seq_len': 512,
                'return_attention_mask': True
            }
        }, {
            'VQASerTokenChunk': {
                'max_seq_len': 512,
                'return_attention_mask': True
            }
        }, {
            'Resize': {
                'size': [224, 224]
            }
        }, {
            'NormalizeImage': {
                'std': [58.395, 57.12, 57.375],
                'mean': [123.675, 116.28, 103.53],
                'scale': '1',
                'order': 'hwc'
            }
        }, {
            'ToCHWImage': None
        }, {
            'KeepKeys': {
                'keep_keys': [
                    'input_ids', 'bbox', 'attention_mask', 'token_type_ids',
                    'image', 'labels', 'segment_offset_id', 'ocr_info',
                    'entities'
                ]
            }
        }]
        postprocess_params = {
            'name': 'VQASerTokenLayoutLMPostProcess',
            "class_path": args.ser_dict_path,
        }

        self.preprocess_op = create_operators(pre_process_list,
                                              {'infer_mode': True})
        self.postprocess_op = build_post_process(postprocess_params)
        self.predictor, self.input_tensor, self.output_tensors, self.config = \
            utility.create_predictor(args, 'ser', logger)

    def __call__(self, img):
        """Run SER inference on a single image.

        Args:
            img: image array stored as ``data['image']`` for the
                pre-processing pipeline. ``main`` passes an RGB numpy array
                (presumably HWC uint8; TODO confirm for other callers).

        Returns:
            A 3-tuple ``(post_result, data, elapse)``:

            - ``post_result`` is the output of the post-process op.
            - ``data`` is the list of batched model inputs.
            - ``elapse`` is the wall-clock time in seconds of inference plus
              post-processing (pre-processing is not timed).

            When pre-processing yields nothing usable, ``(None, None, 0)`` is
            returned, so callers can always unpack three values.
        """
        data = transform({'image': img}, self.preprocess_op)
        # Guard both a missing result and an empty first field. The previous
        # early exit returned a 2-tuple, which broke the 3-way unpacking done
        # by callers.
        if data is None or data[0] is None:
            return None, None, 0

        starttime = time.time()

        # Give every model input a leading batch dimension of size 1. Arrays
        # get a real axis; non-array entries are wrapped in a list.
        data = [
            np.expand_dims(item, axis=0) if isinstance(item, np.ndarray)
            else [item]
            for item in data
        ]

        # Indexed access (not zip) so a too-short ``data`` still fails loudly.
        for idx, input_tensor in enumerate(self.input_tensor):
            input_tensor.copy_from_cpu(data[idx])

        self.predictor.run()

        # Only the first output tensor is consumed by the post-processor.
        preds = self.output_tensors[0].copy_to_cpu()

        # data[6] / data[7] are 'segment_offset_id' / 'ocr_info' per the
        # KeepKeys order configured in __init__.
        post_result = self.postprocess_op(
            preds, segment_offset_ids=data[6], ocr_infos=data[7])
        elapse = time.time() - starttime
        return post_result, data, elapse
def main(args):
    """Run SER inference over every image found under ``args.image_dir``.

    For each readable image:

    - one line ``<image path>\\t<json>`` is written to
      ``<args.output>/infer.txt``;
    - a visualization is saved as ``<args.output>/<image basename>``.

    Images that cannot be read, or that produce no SER result, are logged and
    skipped. The per-image time is logged. After the loop, the average time
    over all images except the first is logged.

    Args:
        args: parsed command-line namespace. The attributes used here are
            ``image_dir``, ``output`` and ``vis_font_path``; the predictor
            reads more.
    """
    image_file_list = get_image_file_list(args.image_dir)
    ser_predictor = SerPredictor(args)
    count = 0
    total_time = 0

    os.makedirs(args.output, exist_ok=True)
    with open(
            os.path.join(args.output, 'infer.txt'), mode='w',
            encoding='utf-8') as f_w:
        for image_file in image_file_list:
            img, flag, _ = check_and_read(image_file)
            if not flag:
                img = cv2.imread(image_file)
                # cv2.imread returns None on failure. Only flip the channel
                # order once the read is known to have succeeded; slicing
                # None would raise TypeError before the check below.
                if img is not None:
                    img = img[:, :, ::-1]
            if img is None:
                logger.info("error in loading image:{}".format(image_file))
                continue

            result = ser_predictor(img)
            # result[0] is the post-processed output and result[-1] the
            # elapsed time. Indexing (rather than unpacking) also tolerates a
            # shorter tuple on the predictor's failure path.
            ser_res, elapse = result[0], result[-1]
            if ser_res is None:
                logger.info("no SER result for image:{}".format(image_file))
                continue
            # The predictor runs a batch of one; take its single entry.
            ser_res = ser_res[0]

            res_str = '{}\t{}\n'.format(
                image_file,
                json.dumps(
                    {
                        "ocr_info": ser_res,
                    }, ensure_ascii=False))
            f_w.write(res_str)

            img_res = draw_ser_results(
                image_file,
                ser_res,
                font_path=args.vis_font_path, )

            img_save_path = os.path.join(args.output,
                                         os.path.basename(image_file))
            cv2.imwrite(img_save_path, img_res)
            logger.info("save vis result to {}".format(img_save_path))

            # The first processed image is excluded from the total
            # (presumably as a warm-up run).
            if count > 0:
                total_time += elapse
            count += 1
            logger.info("Predict time of {}: {}".format(image_file, elapse))

    # total_time was previously accumulated but never reported.
    if count > 1:
        logger.info("Average predict time (excluding the first image): {}".
                    format(total_time / (count - 1)))
if __name__ == "__main__":
    # Script entry point: parse the CLI options and run batch inference.
    cli_args = parse_args()
    main(cli_args)