PercivalFletcher committed on
Commit
38eedd3
·
verified ·
1 Parent(s): 2f2cb69

Upload complex_parser.py

Browse files
Files changed (1) hide show
  1. complex_parser.py +91 -0
complex_parser.py ADDED
@@ -0,0 +1,91 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # file: complex_parser.py
2
+ import torch
3
+ import pandas as pd
4
+ from PIL import Image
5
+ import numpy as np
6
+ from transformers import AutoImageProcessor, TableTransformerForObjectDetection
7
+ import easyocr
8
+ from typing import List
9
+
10
# --- Configuration & Model Initialization ---
# Run on the GPU when PyTorch reports one, otherwise on the CPU.
DEVICE = "cuda" if torch.cuda.is_available() else "cpu"
print(f"Complex parser using device: {DEVICE}")

# Everything below is created a single time at import and shared by all
# calls, so the checkpoints are not reloaded per document element.
_STRUCTURE_CHECKPOINT = "microsoft/table-transformer-structure-recognition"
IMAGE_PROCESSOR = AutoImageProcessor.from_pretrained(_STRUCTURE_CHECKPOINT)
TABLE_STRUCTURE_MODEL = TableTransformerForObjectDetection.from_pretrained(_STRUCTURE_CHECKPOINT)
TABLE_STRUCTURE_MODEL = TABLE_STRUCTURE_MODEL.to(DEVICE)
OCR_READER = easyocr.Reader(['en'])
18
+
19
+ # --- Helper Functions for Model Processing ---
20
+
21
def _get_bounding_box(tensor_box):
    """Return the box coordinates as a plain list rounded to two decimals.

    ``tensor_box`` only needs to expose ``tolist()``; the resulting list of
    numbers can be handed to PIL (e.g. ``Image.crop``).
    """
    coordinates = tensor_box.tolist()
    rounded = []
    for value in coordinates:
        rounded.append(round(value, 2))
    return rounded
24
+
25
def _get_cell_coordinates_by_row(table_data):
    """Flatten ``table_data['rows']`` into one list of ``{'row', 'bbox'}`` dicts.

    Cells inside each row are ordered by the first coordinate of their
    ``'bbox'`` (left edge first); ``'row'`` is the index of the row the cell
    came from, and ``'bbox'`` is converted with ``_get_bounding_box``.
    """
    ordered_rows = []
    for row_cells in table_data['rows']:
        ordered_rows.append(sorted(row_cells, key=lambda cell: cell['bbox'][0]))

    flattened = []
    for row_index, row_cells in enumerate(ordered_rows):
        for cell in row_cells:
            flattened.append({'row': row_index, 'bbox': _get_bounding_box(cell['bbox'])})
    return flattened
29
+
30
def _apply_ocr_to_cells(image: Image.Image, cells: List[dict]) -> List[dict]:
    """Run OCR on every cell region and store the result under ``cell['text']``.

    Args:
        image: The full table image the cell boxes refer to.
        cells: Dicts carrying a ``'bbox'`` of ``(left, upper, right, lower)``
            pixel coordinates. The dicts are updated in place.

    Returns:
        The same ``cells`` list, each entry now holding a ``'text'`` string
        (empty when the box has no area or nothing was recognised).
    """
    for cell in cells:
        left, upper, right, lower = cell['bbox']

        # An inverted or zero-area box cannot be cropped/recognised; record an
        # empty cell instead of letting one bad detection abort the table.
        if right <= left or lower <= upper:
            cell['text'] = ''
            continue

        cell_image = image.crop((left, upper, right, lower))
        # crop() rounds to whole pixels, so a sliver can still collapse to 0 px.
        if 0 in cell_image.size:
            cell['text'] = ''
            continue

        # Hand easyocr a NumPy array, exactly as process_image_element does:
        # readtext does not accept a generic PIL image object.
        ocr_result = OCR_READER.readtext(np.array(cell_image), detail=0, paragraph=True)
        cell['text'] = ' '.join(ocr_result)
    return cells
37
+
38
+ # --- Main Public Functions ---
39
+
40
def process_image_element(image: Image.Image) -> str:
    """Return the OCR text of ``image`` wrapped in an ``[Image Content: ...]`` marker.

    The marker reads ``No text detected`` when OCR yields nothing and
    ``Error during processing`` when any exception is raised along the way;
    the exception itself is printed, never propagated.
    """
    print("--- Processing image element with OCR ---")
    try:
        # easyocr is given a NumPy array rather than the PIL object itself.
        pixels = np.array(image)
        fragments = OCR_READER.readtext(pixels, detail=0, paragraph=True)
        recognized = ' '.join(fragments)
        if not recognized:
            return "\n\n[Image Content: No text detected]\n\n"
        return f"\n\n[Image Content: {recognized}]\n\n"
    except Exception as e:
        print(f"Error during image OCR: {e}")
        return "\n\n[Image Content: Error during processing]\n\n"
52
+
53
def process_table_element(image: Image.Image, detection_threshold: float = 0.6) -> str:
    """Convert a table image to a markdown table using structure detection and OCR.

    The Table Transformer structure model detects row and column boxes; every
    row/column intersection is treated as one cell, each cell is OCR'd, and the
    texts are laid out row by row as a markdown table.

    Args:
        image: The table region as a PIL image.
        detection_threshold: Minimum detection score a row/column box needs to
            be kept.

    Returns:
        ``"\\n\\n[Table Content]:\\n<markdown>\\n\\n"`` on success. If no row or
        no column is detected, or any step raises, the plain-OCR result of
        ``process_image_element(image)`` is returned instead.
    """
    print("--- Processing table element with Table Transformer ---")
    try:
        # The detector works on 3-channel input; normalise L/RGBA/P images.
        rgb_image = image.convert("RGB")
        width, height = rgb_image.size

        # The processor returns a dict-like batch (pixel values plus mask). It
        # has to be moved and passed on as a whole; tuple-unpacking it would
        # only yield its key names.
        inputs = IMAGE_PROCESSOR(images=rgb_image, return_tensors="pt").to(DEVICE)
        with torch.no_grad():
            outputs = TABLE_STRUCTURE_MODEL(**inputs)

        # Turn raw logits/normalised boxes into scored pixel-space boxes.
        # target_sizes is (height, width), the reverse of PIL's size tuple.
        detections = IMAGE_PROCESSOR.post_process_object_detection(
            outputs,
            threshold=detection_threshold,
            target_sizes=[(height, width)],
        )[0]

        # Keep boxes inside the image so later crops are well defined.
        boxes = detections["boxes"].cpu()
        bounds = torch.tensor([width, height, width, height], dtype=boxes.dtype)
        boxes = torch.minimum(boxes.clamp(min=0), bounds)

        # Class names come from the checkpoint's own label map.
        id2label = TABLE_STRUCTURE_MODEL.config.id2label
        row_boxes, column_boxes = [], []
        for label_id, box in zip(detections["labels"].tolist(), boxes):
            label = id2label.get(label_id)
            if label == "table row":
                row_boxes.append(box)
            elif label == "table column":
                column_boxes.append(box)

        if not row_boxes or not column_boxes:
            return process_image_element(image)

        # Rows top-to-bottom, columns left-to-right, so the row index assigned
        # downstream matches the reading order.
        row_boxes.sort(key=lambda box: box[1].item())
        column_boxes.sort(key=lambda box: box[0].item())

        # A cell spans the column's horizontal extent and the row's vertical one.
        table_data = {
            'rows': [
                [
                    {'bbox': torch.stack([column[0], row[1], column[2], row[3]])}
                    for column in column_boxes
                ]
                for row in row_boxes
            ]
        }

        cells = _get_cell_coordinates_by_row(table_data)
        cells_with_text = _apply_ocr_to_cells(rgb_image, cells)

        df = pd.DataFrame(cells_with_text)
        if 'row' not in df.columns or 'text' not in df.columns:
            return "[Table Content: Could not form DataFrame]"

        # cumcount() numbers the cells within each row, giving the column index.
        table_pivot = df.pivot_table(
            index='row',
            columns=df.groupby('row').cumcount(),
            values='text',
            aggfunc='first',
        ).fillna('')
        markdown_table = table_pivot.to_markdown()

        return f"\n\n[Table Content]:\n{markdown_table}\n\n"
    except Exception as e:
        # Best effort: a failed table extraction degrades to plain OCR.
        print(f"Error during table processing: {e}")
        return process_image_element(image)
79
+
80
def _is_markdown_separator_row(line: str) -> bool:
    """Return True if ``line`` is a markdown header separator such as ``|:---|---:|``."""
    stripped = line.strip()
    if "|" not in stripped:
        return False
    for cell in stripped.strip("|").split("|"):
        dashes = cell.strip().strip(":")
        # Each cell must be one or more dashes, optionally wrapped in
        # alignment colons.
        if not dashes or set(dashes) != {"-"}:
            return False
    return True


def stitch_tables(table_markdowns: list[str]) -> str:
    """Stitch markdown tables from consecutive pages into one table.

    The first entry is kept as is (header included). From every later entry
    only the rows below its header separator are appended; blank lines are
    dropped so the result stays one contiguous markdown table. Entries without
    a recognisable separator row are skipped, since they are not tables.

    Args:
        table_markdowns: Markdown tables in page order.

    Returns:
        The combined table, or ``""`` for an empty input. Trailing whitespace
        of the first entry is moved to the end of the result.
    """
    if not table_markdowns:
        return ""

    first_table = table_markdowns[0]
    continuation_rows = []
    for page_markdown in table_markdowns[1:]:
        lines = page_markdown.splitlines()
        separator_index = next(
            (index for index, line in enumerate(lines) if _is_markdown_separator_row(line)),
            None,
        )
        if separator_index is None:
            continue
        continuation_rows.extend(
            line for line in lines[separator_index + 1:] if line.strip()
        )

    if not continuation_rows:
        return first_table

    # Append directly below the last row; padding that followed the first
    # table is re-attached after the stitched rows instead of splitting them.
    body = first_table.rstrip()
    trailer = first_table[len(body):]
    return "\n".join([body, *continuation_rows]) + trailer