import streamlit as st
import warnings

warnings.simplefilter("ignore", UserWarning)

from uuid import uuid4
from laia.scripts.htr.decode_ctc import run as decode
from laia.common.arguments import CommonArgs, DataArgs, TrainerArgs, DecodeArgs
import sys
from tempfile import NamedTemporaryFile, mkdtemp
from pathlib import Path
from contextlib import redirect_stdout
import re
from PIL import Image
from bidi.algorithm import get_display
import multiprocessing
from ultralytics import YOLO
import cv2
import numpy as np
import pandas as pd
import logging
from typing import List, Optional
# Silence PyTorch Lightning logging
logging.getLogger("lightning.pytorch").setLevel(logging.ERROR)

# Load the YOLO segmentation model (text-line detector)
model = YOLO('model.pt')

images = Path(mkdtemp())
DEFAULT_HEIGHT = 128
TEXT_DIRECTION = "LTR"
NUM_WORKERS = multiprocessing.cpu_count()

# Regex patterns for parsing PyLaia decode output
IMAGE_ID_PATTERN = r"(?P<image_id>[-a-z0-9]{36})"
CONFIDENCE_PATTERN = r"(?P<confidence>[0-9.]+)"  # Per-line confidence score
TEXT_PATTERN = r"\s*(?P<text>.*)\s*"
LINE_PREDICTION = re.compile(rf"{IMAGE_ID_PATTERN} {CONFIDENCE_PATTERN} {TEXT_PATTERN}")
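# Illustrative example of a decode output line the pattern above is expected to
# match (the UUID, score, and text below are made up):
#   "1b4db7eb-4057-4f5e-91e0-36dec72071f5 0.93 In principio erat uerbum"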
def get_width(image, height=DEFAULT_HEIGHT):
    """Return the width that preserves the image's aspect ratio at the given height."""
    aspect_ratio = image.width / image.height
    return height * aspect_ratio
def simplify_polygons(polygons: List[np.ndarray], approx_level: float = 0.01) -> List[Optional[np.ndarray]]:
    """Simplify polygon contours using the Douglas-Peucker algorithm.

    Args:
        polygons: List of polygon contours
        approx_level: Approximation level as a fraction of the contour perimeter
            (0-1); higher values mean more aggressive simplification

    Returns:
        List of simplified polygons (or None for polygons with fewer than 4 points)
    """
    result = []
    for polygon in polygons:
        if len(polygon) < 4:
            result.append(None)
            continue

        perimeter = cv2.arcLength(polygon, True)
        approx = cv2.approxPolyDP(polygon, approx_level * perimeter, True)
        if len(approx) < 4:
            result.append(None)
            continue

        result.append(approx.squeeze())
    return result
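# Usage sketch for simplify_polygons (hypothetical square contour in OpenCV's
# (N, 1, 2) layout; real contours come from cv2.findContours):
#   square = np.array([[[0, 0]], [[100, 0]], [[100, 100]], [[0, 100]]], dtype=np.int32)
#   simplify_polygons([square])  # -> [array of 4 points]; a square is already minimal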
def predict(model_name, input_img):
    # NOTE: model_name is currently unused; the PyLaia model directory is fixed below.
    model_dir = 'catmus-medieval'
    temperature = 2.0
    batch_size = 1

    weights_path = f"{model_dir}/weights.ckpt"
    syms_path = f"{model_dir}/syms.txt"
    language_model_params = {"language_model_weight": 1.0}
    use_language_model = True

    if use_language_model:
        language_model_params.update({
            "language_model_path": f"{model_dir}/language_model.binary",
            "lexicon_path": f"{model_dir}/lexicon.txt",
            "tokens_path": f"{model_dir}/tokens.txt",
        })

    common_args = CommonArgs(
        checkpoint="weights.ckpt",
        train_path=f"{model_dir}",
        experiment_dirname="",
    )
    data_args = DataArgs(batch_size=batch_size, color_mode="L")
    trainer_args = TrainerArgs(progress_bar_refresh_rate=0)
    decode_args = DecodeArgs(
        include_img_ids=True,
        join_string="",
        convert_spaces=True,
        print_line_confidence_scores=True,
        print_word_confidence_scores=False,
        temperature=temperature,
        use_language_model=use_language_model,
        **language_model_params,
    )

    with NamedTemporaryFile() as pred_stdout, NamedTemporaryFile() as img_list:
        image_id = uuid4()
        # Resize to the model's expected line height, preserving aspect ratio
        input_img = input_img.resize((int(get_width(input_img)), DEFAULT_HEIGHT))
        input_img.save(f"{images}/{image_id}.jpg")
        # Write the list of image ids (one per line) for PyLaia's decoder
        Path(img_list.name).write_text("\n".join([str(image_id)]))

        with redirect_stdout(open(pred_stdout.name, mode="w")):
            decode(
                syms=str(syms_path),
                img_list=img_list.name,
                img_dirs=[str(images)],
                common=common_args,
                data=data_args,
                trainer=trainer_args,
                decode=decode_args,
                num_workers=1,
            )
            sys.stdout.flush()

        predictions = Path(pred_stdout.name).read_text().strip().splitlines()
        _, score, text = LINE_PREDICTION.match(predictions[0]).groups()
        if TEXT_DIRECTION == "RTL":
            return input_img, {"text": get_display(text), "score": score}
        else:
            return input_img, {"text": text, "score": score}
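# Usage sketch for predict (illustrative; requires the catmus-medieval/ model
# files listed above to be present, and "line.jpg" is a hypothetical input):
#   line_img, pred = predict('catmus-medieval', Image.open("line.jpg"))
#   print(pred["text"], pred["score"])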
def process_image(image):
    # Run YOLO inference; class 0 corresponds to text lines
    results = model(image, classes=0)
    img_cv2 = cv2.cvtColor(np.array(image), cv2.COLOR_RGB2BGR)
    masks = results[0].masks
    polygons = []
    texts = []

    if masks is not None:
        # Get mask data and original image dimensions
        masks = masks.data.cpu().numpy()
        img_height, img_width = img_cv2.shape[:2]

        # Get bounding boxes in xyxy format
        boxes = results[0].boxes.xyxy.cpu().numpy()

        # Sort lines top-to-bottom by the y-coordinate of the top-left corner
        sorted_indices = np.argsort(boxes[:, 1])
        masks = masks[sorted_indices]
        boxes = boxes[sorted_indices]

        for i, (mask, box) in enumerate(zip(masks, boxes)):
            # Scale the mask back to the original image size
            mask = cv2.resize(mask.squeeze(), (img_width, img_height), interpolation=cv2.INTER_LINEAR)
            mask = (mask > 0.5).astype(np.uint8) * 255  # Binarize

            # Convert the mask to polygon contours
            contours, _ = cv2.findContours(mask, cv2.RETR_EXTERNAL, cv2.CHAIN_APPROX_SIMPLE)

            if contours:
                # Keep only the largest contour per mask
                largest_contour = max(contours, key=cv2.contourArea)
                simplified_polygon = simplify_polygons([largest_contour])[0]

                if simplified_polygon is not None:
                    # Crop the line with its bounding box for text recognition
                    x1, y1, x2, y2 = map(int, box)
                    crop_img = img_cv2[y1:y2, x1:x2]
                    crop_pil = Image.fromarray(cv2.cvtColor(crop_img, cv2.COLOR_BGR2RGB))

                    # Recognize text with the PyLaia model
                    predicted = predict('pylaia-samaritan_v1', crop_pil)
                    texts.append(predicted[1]["text"])

                    # Convert the polygon to a list of points for display
                    poly_points = simplified_polygon.reshape(-1, 2).astype(int).tolist()
                    polygons.append(f"Line {i+1}: {poly_points}")

                    # Draw the polygon on the image
                    cv2.polylines(img_cv2, [simplified_polygon.reshape(-1, 1, 2).astype(int)],
                                  True, (0, 255, 0), 2)

    # Convert the image back to RGB for display in Streamlit
    img_result = cv2.cvtColor(img_cv2, cv2.COLOR_BGR2RGB)

    # Combine polygons and recognized texts into a DataFrame for the results table
    table_data = pd.DataFrame({"Polygons": polygons, "Recognized Text": texts})

    return Image.fromarray(img_result), table_data
def segment_and_recognize(image):
    segmented_image, table_data = process_image(image)
    return segmented_image, table_data
# Streamlit app layout
st.title("YOLOv11 Text Line Segmentation & PyLaia Text Recognition on CATMuS/medieval")

# File uploader
uploaded_image = st.file_uploader("Upload an image", type=["png", "jpg", "jpeg"])

# Process the image if uploaded
if uploaded_image is not None:
    image = Image.open(uploaded_image)

    if st.button("Segment and Recognize"):
        # Perform segmentation and recognition
        segmented_image, table_data = segment_and_recognize(image)

        # Display the segmented image
        st.image(segmented_image, caption="Segmented Image with Polygon Masks", use_container_width=True)

        # Display the table of polygons and recognized text
        st.table(table_data)
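# To run the app locally (assuming the dependencies are installed and model.pt
# plus the catmus-medieval/ directory sit next to this script):
#   streamlit run app.py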