Upload complex_parser.py
complex_parser.py +91 -0
complex_parser.py
ADDED
@@ -0,0 +1,91 @@
# file: complex_parser.py
import torch
import pandas as pd
from PIL import Image
import numpy as np
from transformers import AutoImageProcessor, TableTransformerForObjectDetection
import easyocr
from typing import List

# --- Configuration & Model Initialization ---
DEVICE = "cuda" if torch.cuda.is_available() else "cpu"
print(f"Complex parser using device: {DEVICE}")

# Initialize models and reader once to save resources
TABLE_STRUCTURE_MODEL = TableTransformerForObjectDetection.from_pretrained("microsoft/table-transformer-structure-recognition").to(DEVICE)
IMAGE_PROCESSOR = AutoImageProcessor.from_pretrained("microsoft/table-transformer-structure-recognition")
OCR_READER = easyocr.Reader(['en'])

# --- Helper Functions for Model Processing ---
def _get_bounding_box(tensor_box):
    """Converts a tensor bounding box to a PIL-compatible format."""
    return [round(i, 2) for i in tensor_box.tolist()]

def _get_cell_coordinates_by_row(table_data):
    """Organizes cell coordinates by their row.

    The structure-recognition model detects whole rows and columns rather than
    individual cells, so each cell box is the intersection of one row box and
    one column box.
    """
    rows = sorted(table_data['rows'], key=lambda r: r['bbox'][1])        # top-to-bottom
    columns = sorted(table_data['columns'], key=lambda c: c['bbox'][0])  # left-to-right
    return [
        {'row': i, 'bbox': [col['bbox'][0], row['bbox'][1], col['bbox'][2], row['bbox'][3]]}
        for i, row in enumerate(rows)
        for col in columns
    ]

def _apply_ocr_to_cells(image: Image.Image, cells: List[dict]) -> List[dict]:
    """Applies OCR to each cell in the table."""
    for cell in cells:
        cell_image = image.crop(cell['bbox'])
        # easyocr expects a file path, bytes, or a NumPy array, not a PIL Image
        ocr_result = OCR_READER.readtext(np.array(cell_image), detail=0, paragraph=True)
        cell['text'] = ' '.join(ocr_result)
    return cells
# --- Main Public Functions ---

def process_image_element(image: Image.Image) -> str:
    """Processes an image element using OCR to extract text."""
    print("--- Processing image element with OCR ---")
    try:
        # Convert the PIL Image to a NumPy array before passing to easyocr
        image_np = np.array(image)
        ocr_result = OCR_READER.readtext(image_np, detail=0, paragraph=True)
        text = ' '.join(ocr_result)
        return f"\n\n[Image Content: {text}]\n\n" if text else "\n\n[Image Content: No text detected]\n\n"
    except Exception as e:
        print(f"Error during image OCR: {e}")
        return "\n\n[Image Content: Error during processing]\n\n"
def process_table_element(image: Image.Image) -> str:
    """Processes a table element using Table Transformer and OCR."""
    print("--- Processing table element with Table Transformer ---")
    try:
        encoding = IMAGE_PROCESSOR(image, return_tensors="pt")
        with torch.no_grad():
            outputs = TABLE_STRUCTURE_MODEL(encoding["pixel_values"].to(DEVICE))

        # Convert raw detections into labelled boxes in original image coordinates
        target_sizes = torch.tensor([image.size[::-1]])  # (height, width)
        results = IMAGE_PROCESSOR.post_process_object_detection(outputs, threshold=0.6, target_sizes=target_sizes)[0]

        # Collect detected row and column boxes by label name
        id2label = TABLE_STRUCTURE_MODEL.config.id2label
        table_data = {'rows': [], 'columns': []}
        for label, box in zip(results["labels"], results["boxes"]):
            name = id2label[label.item()]
            if name == "table row":
                table_data['rows'].append({'bbox': _get_bounding_box(box)})
            elif name == "table column":
                table_data['columns'].append({'bbox': _get_bounding_box(box)})

        if not table_data['rows'] or not table_data['columns']:
            # No recognizable table structure; fall back to plain OCR
            return process_image_element(image)

        cells = _get_cell_coordinates_by_row(table_data)
        cells_with_text = _apply_ocr_to_cells(image, cells)

        df = pd.DataFrame(cells_with_text)
        if 'row' not in df.columns or 'text' not in df.columns:
            return "[Table Content: Could not form DataFrame]"

        # One markdown column per cell position within each row
        table_pivot = df.pivot_table(index='row', columns=df.groupby('row').cumcount(), values='text', aggfunc='first').fillna('')
        markdown_table = table_pivot.to_markdown()

        return f"\n\n[Table Content]:\n{markdown_table}\n\n"
    except Exception as e:
        print(f"Error during table processing: {e}")
        return process_image_element(image)
def stitch_tables(table_markdowns: list[str]) -> str:
    """Stitches markdown tables from consecutive pages together."""
    if not table_markdowns:
        return ""
    full_table = table_markdowns[0]
    for i in range(1, len(table_markdowns)):
        lines = table_markdowns[i].split('\n')
        # Skip past the header and its '|---' separator row, then append only the data rows
        header_separator_index = next((j for j, line in enumerate(lines) if '|---' in line), -1)
        if header_separator_index != -1 and header_separator_index + 1 < len(lines):
            rows_to_append = '\n'.join(lines[header_separator_index + 1:])
            full_table += '\n' + rows_to_append
    return full_table
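The commit only defines the parsing primitives; it does not show how they are wired together. Below is a minimal sketch of how the three public functions might be called, assuming an upstream layout-detection step has already cropped figure and table regions out of a page as PIL images. The file names and the page_tables list are hypothetical, not part of this upload.

# Hypothetical usage of complex_parser (crops come from an upstream layout step)
from PIL import Image
import complex_parser

# A figure crop becomes a bracketed text placeholder
figure_crop = Image.open("figure_crop.png").convert("RGB")  # hypothetical file
print(complex_parser.process_image_element(figure_crop))

# Each table crop becomes a markdown table
page_tables = [
    Image.open("table_page1.png").convert("RGB"),  # hypothetical files
    Image.open("table_page2.png").convert("RGB"),
]
markdowns = [complex_parser.process_table_element(t) for t in page_tables]

# process_table_element wraps its markdown in a "[Table Content]" banner, so strip
# that wrapper before merging tables that continue across pages
raw_tables = [m.replace("[Table Content]:", "").strip() for m in markdowns]
print(complex_parser.stitch_tables(raw_tables))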