Spaces:
Runtime error
Runtime error
## Set Environment | |
import os | |
from pdf2image import convert_from_path | |
import cv2 | |
import base64 | |
import numpy as np | |
import numpy as np | |
from PIL import Image | |
import json | |
from anthropic import Anthropic, Client | |
import gradio as gr | |
## Set Environment
# NOTE(review): every os.system() call runs in its own short-lived shell, so
# the "activate" step below cannot change the interpreter that is already
# running this script. It is kept only to preserve the original start-up
# sequence -- confirm whether it is still wanted.
os.system('python -m venv env')
os.system('source env/bin/activate')
## Install poppler in os (pdf2image relies on the poppler command-line tools)
import os
os.system('apt-get update')
# -y answers apt's confirmation prompt; without it a non-interactive install
# stops at "Do you want to continue? [Y/n]" and poppler never gets installed.
os.system('sudo apt-get install -y poppler-utils')
## The rest of the app.py code follows
def get_base64_encorded_image(image_path):
    """Return the contents of a file as a base64 string.

    Args:
        image_path: Path of the file to read (opened in binary mode).

    Returns:
        The file's bytes, base64-encoded and decoded to a UTF-8 ``str``.
    """
    with open(image_path, "rb") as handle:
        return base64.b64encode(handle.read()).decode('utf-8')
## Process PDF: render each page to a PNG image
def convert_pdf_to_image(pdf_path):
    """Render a PDF to page images and save each page as a PNG file.

    Every page is rendered at 400 DPI with ``convert_from_path`` and written
    to the current working directory as ``page_<index>.png`` (zero-based).

    Args:
        pdf_path: Path of the PDF file to convert.

    Returns:
        The list of page images returned by ``convert_from_path``.
    """
    rendered_pages = convert_from_path(pdf_path, dpi=400)

    for i, rendered in enumerate(rendered_pages):
        rendered.save(f'page_{i}.png', 'PNG')

    page_count = len(rendered_pages)
    print(f"Converted {page_count} pages to images.")
    return rendered_pages
## Image-processing sub-step: de-stamp
def destamp_image(img_path, result_filepath="result_img_0.png"):
    """Binarise a page image so that only dark pixels remain black.

    The image is converted to HSV and every pixel inside the range
    ``[0, 0, 0]`` .. ``[180, 255, 120]`` (any hue, any saturation, value up
    to 120) is kept as black; all other pixels become white. The result is
    written to disk as a single-channel image.

    Args:
        img_path: Path of the image to read with ``cv2.imread``.
        result_filepath: Where the processed image is written. Defaults to
            the file name the function originally hard-coded.

    Returns:
        ``result_filepath``, the path of the written image.

    Raises:
        OSError: If ``cv2.imread`` cannot load ``img_path``.
    """
    bgr_img = cv2.imread(img_path)
    if bgr_img is None:
        # cv2.imread reports failure by returning None rather than raising;
        # fail here with a clear message instead of inside cvtColor.
        raise OSError(f"Unable to read image file: {img_path!r}")

    hsv_img = cv2.cvtColor(bgr_img, cv2.COLOR_BGR2HSV)

    # HSV range treated as "black character" pixels:
    #   H: 0-180, S: 0-255, V: 0-120
    lower_black = np.array([0, 0, 0])
    upper_black = np.array([180, 255, 120])
    mask = cv2.inRange(hsv_img, lower_black, upper_black)

    # inRange marks in-range (dark) pixels as 255; invert so they end up
    # black on a white background. Single-channel image.
    deRed_img = ~mask

    # Second thresholding pass (Otsu flag set, as in the original pipeline).
    _, threshold_img_2 = cv2.threshold(
        deRed_img, 120, 255, cv2.THRESH_BINARY | cv2.THRESH_OTSU
    )

    # Add an explicit channel axis: (rows, cols) -> (rows, cols, 1).
    new_shape = (threshold_img_2.shape[0], threshold_img_2.shape[1], 1)
    result_img = threshold_img_2.reshape(new_shape)
    print(f"result_img.shape: {result_img.shape}")

    cv2.imwrite(result_filepath, result_img)
    return result_filepath
def extract_image_table(image_path):
    """Run table extraction on one image and return the parsed JSON payload.

    Calls ``extract_table_info`` for the image, logs the response through
    ``check_response`` and hands the response to ``extract_json``.

    NOTE(review): the prompt in ``extract_table_info`` asks for ``<mark>``
    tags while ``extract_json`` searches for ``<json>`` tags -- confirm this
    path still produces data.

    Args:
        image_path: Path of the image to analyse.

    Returns:
        Whatever ``extract_json`` returns for the model response.
    """
    api_response = extract_table_info(image_path)
    check_response(api_response)
    json_data = extract_json(api_response)
    print(f"json_data: {json_data}")
    return json_data
## Extract table information
def extract_table_info(image_path):
    """Send an image to the Anthropic Messages API and return the response.

    The image is attached as a base64 PNG together with a fixed prompt that
    asks for the statement's content as a markdown table wrapped in
    ``<mark>`` tags. The API key is read from the ``ANTHROPIC_API_KEY``
    environment variable.

    Args:
        image_path: Path of the image file to send.

    Returns:
        The response object returned by ``client.messages.create``.
    """
    # The prompt lines are kept flush-left so the text sent to the model
    # carries no extra indentation.
    prompt_text = """
You are analyzing an Financial Statement in traditional Chinese.
Please extract all the information of the statement image, keep the context in Traditional Chinese without translation.
Extract information row by row, and cell by cell.
Keep document title, header, date, currency, section header, summary, footer, ... as part of the information.
OCR all the cells precisely with the best accuracy. Any Chinese character, if you can not make the best guess, please return "?". Do not ignore it.
Do not do any correction with the content of the cell related with "代碼", even it is not 100% correct from your experience. Keep as what it is.
Makd sure the length of the string of each cell is same as the image.
Save all the information as a markdown table.
Keep alignment of each column with the image.
Repsonse as below structure:
<mark>
...
...
...
</mark>
"""
    MODEL_NAME = "claude-3-5-sonnet-20240620"

    my_api_key = os.getenv('ANTHROPIC_API_KEY')
    client = Anthropic(api_key=my_api_key)

    image_block = {
        "type": "image",
        "source": {
            "type": "base64",
            "media_type": "image/png",
            "data": get_base64_encorded_image(image_path),
        },
    }
    text_block = {"type": "text", "text": prompt_text}
    message_list = [{"role": "user", "content": [image_block, text_block]}]

    response = client.messages.create(
        model=MODEL_NAME,
        max_tokens=3072,  # upper bound on the size of the reply
        messages=message_list,
        temperature=0.6,
        extra_headers={"anthropic-beta": "max-tokens-3-5-sonnet-2024-07-15"},
    )

    tokens = response.usage.output_tokens
    print(f"Generated Tokens: {tokens}")
    print(f"Response: {response}")
    return response
## Check response
def check_response(response):
    """Log a response's content and return the text of its first block.

    The original implementation extracted the text into a local variable
    and then always returned ``None``; the text is now returned so callers
    can tell success from failure. Callers that ignore the return value are
    unaffected.

    Args:
        response: Object exposing a ``content`` attribute. The text is read
            from ``content[0].text`` when ``content`` is a non-empty list.

    Returns:
        The first content block's text, or ``None`` when ``content`` is not
        a non-empty list.
    """
    content = response.content
    print(type(content))
    print(content)

    if not (isinstance(content, list) and content):
        print("Unexpected response format. Unable to extract text.")
        return None

    # The text is taken from the first element of the content list.
    return content[0].text
## Extract Markdown data
def extract_markdown(response):
    """Return the text enclosed by ``<mark>`` ... ``</mark>`` in a response.

    The original code added the tag length to the result of ``str.find``
    before checking it, which turned the ``-1`` "not found" sentinel into
    ``5`` and let a missing opening tag slip through. The sentinel is now
    checked first, and the closing tag is searched for after the opening
    tag.

    Args:
        response: Object whose ``content[0].text`` holds the model output.

    Returns:
        The text between the tags, or ``None`` when the tags are missing,
        out of order, or enclose nothing.
    """
    open_tag = "<mark>"
    close_tag = "</mark>"
    response_text = response.content[0].text

    tag_pos = response_text.find(open_tag)
    if tag_pos == -1:
        mark_start = -1
        mark_end = -1
    else:
        mark_start = tag_pos + len(open_tag)  # first character after the tag
        mark_end = response_text.find(close_tag, mark_start)
    print(f"mark_start: {mark_start}")
    print(f"mark_end: {mark_end}")

    if mark_start < 0 or mark_end <= mark_start:
        print("Could not find valid Markdown object in response.")
        return None

    mark_data = response_text[mark_start:mark_end]
    print(f"mark_data: {mark_data}")
    return mark_data
## Extract JSON data
def extract_json(response):
    """Parse the JSON enclosed by ``<json>`` ... ``</json>`` in a response.

    Two defects of the original are fixed: the ``-1`` "not found" result of
    ``str.find`` is checked before the tag length is added (previously a
    missing opening tag was never detected), and the error path reports the
    same slice that was parsed instead of dropping its first character.

    Args:
        response: Object whose ``content[0].text`` holds the model output.

    Returns:
        The decoded JSON value on success; a one-element ``set`` holding the
        raw text when it is not valid JSON; ``None`` when the tags are
        missing, out of order, or enclose nothing.
    """
    open_tag = "<json>"
    close_tag = "</json>"
    response_text = response.content[0].text

    tag_pos = response_text.find(open_tag)
    json_end = response_text.rfind(close_tag)  # last closing tag wins
    json_start = tag_pos + len(open_tag)

    if tag_pos == -1 or json_end <= json_start:
        print("Could not find valid JSON object in response.")
        return None

    payload = response_text[json_start:json_end]
    try:
        return json.loads(payload)
    except json.JSONDecodeError as e:
        print(f"Error decoding JSON: {e}")
        print(f"Problematic JSON string: {payload}")
        # NOTE(review): returning a set that wraps the raw text is kept from
        # the original for compatibility -- confirm this is what callers expect.
        return {payload}
## Convert JSON to DataFrame (not implemented)
## Convert to CSV (not implemented)
## Process PDF
def pipeline(pdf_path):
    """Run the full extraction flow for one uploaded PDF.

    Steps: render the PDF pages to PNG files, de-stamp ``page_0.png``, send
    the de-stamped image to ``extract_table_info``, log the response and
    pull the Markdown section out of it. Only the first page is processed.

    Args:
        pdf_path: Path of the PDF file to process.

    Returns:
        A 3-tuple ``(page_count, destamped_image_path, markdown_text)``.
    """
    rendered_pages = convert_pdf_to_image(pdf_path)
    print(f"pages: {rendered_pages}")

    # convert_pdf_to_image writes the first page to this fixed file name.
    destamped_path = destamp_image("page_0.png")

    api_response = extract_table_info(destamped_path)
    check_response(api_response)
    markdown_text = extract_markdown(api_response)

    return len(rendered_pages), destamped_path, markdown_text
## Gradio interface
# Texts and example inputs shown in the UI.
title = "Demo: Financial Statement(PDF) information Extraction - Traditional Chinese"
description = "Demo pdf, either editable or scanned image, information extraction for Traditional Chinese without OCR"
examples = [['text_pdf.pdf'], ['image_pdf.pdf']]

# Input component: the uploaded PDF, handed to pipeline() as a file path.
pdf_file = gr.File(label="Upload PDF", type="filepath")

# Declared but not passed to the interface below.
pages = gr.File(label="Pages", type="filepath")

# Output components, in the order pipeline() returns its values.
num_pages = gr.Number(label="Number of Pages")
destamp_img = gr.Image(type="numpy", label="De-stamped Image")
mark_data = gr.Markdown(label="Markdown Data")

app = gr.Interface(
    fn=pipeline,
    inputs=pdf_file,
    outputs=[num_pages, destamp_img, mark_data],
    title=title,
    description=description,
    examples=examples,
)
app.queue()
app.launch(debug=True, share=True)
#app.launch() |