# Upload metadata captured along with this file (not code):
# shreyasvaidya's picture
# Upload folder using huggingface_hub
# e7cd7fe verified
import functools
import os
from collections import Counter

import gradio as gr
import nltk
import numpy as np
import torch
import torch
from IndicTransToolkit import IndicProcessor
from PIL import Image
from transformers import (
    AutoModelForSeq2SeqLM,
    AutoTokenizer,
)

from IndicPhotoOCR.ocr import OCR  # Ensure OCR class is saved in a file named ocr.py
from IndicPhotoOCR.theme import Seafoam
# Torch device used by the OCR pipeline and by translate() (model + inputs).
DEVICE = "cpu"
# Initialize the OCR object for text detection and recognition. It uses the
# same DEVICE constant as translate(), so switching devices is a one-line edit.
ocr = OCR(device=DEVICE, verbose=False)
@functools.lru_cache(maxsize=2)
def _load_translation_model(model_name):
    """Load the tokenizer and seq2seq model for *model_name* once and reuse them.

    translate() only ever asks for two checkpoints (en->indic and indic->en),
    so this cache holds at most both of them. Repeated requests therefore reuse
    the already-loaded weights instead of reloading them on every call. The
    model is moved to DEVICE and switched to eval mode before being cached.
    """
    tokenizer = AutoTokenizer.from_pretrained(model_name, trust_remote_code=True)
    model = AutoModelForSeq2SeqLM.from_pretrained(model_name, trust_remote_code=True)
    model = model.to(DEVICE)
    model.eval()
    return tokenizer, model


def translate(given_str, lang):
    """Translate *given_str* with an IndicTrans2 model.

    Parameters:
    given_str (str): The text to translate.
    lang (str): Script name reported by the OCR. ``"english"`` selects
        English -> Hindi (eng_Latn -> hin_Deva); any other value selects
        Hindi -> English (hin_Deva -> eng_Latn).

    Returns:
    str: The translated text, or ``""`` when *given_str* is empty/blank.
    """
    # Nothing was recognised: skip the model call entirely.
    if not given_str or not given_str.strip():
        return ""

    if lang == "english":
        model_name = "ai4bharat/indictrans2-en-indic-1B"
        src_lang, tgt_lang = "eng_Latn", "hin_Deva"
    else:
        # NOTE(review): every non-English script is treated as Hindi
        # (hin_Deva). Other Indic scripts would need their own IndicTrans2
        # language codes; confirm the OCR's script names before mapping them.
        model_name = "ai4bharat/indictrans2-indic-en-1B"
        src_lang, tgt_lang = "hin_Deva", "eng_Latn"

    tokenizer, model = _load_translation_model(model_name)
    # The processor is created per call rather than cached: it is used for a
    # preprocess/postprocess pair below and may carry state between the two
    # (TODO confirm), so it is not shared across requests.
    ip = IndicProcessor(inference=True)

    batch = ip.preprocess_batch(
        [given_str],
        src_lang=src_lang,
        tgt_lang=tgt_lang,
    )
    inputs = tokenizer(
        batch,
        truncation=True,
        padding="longest",
        return_tensors="pt",
        return_attention_mask=True,
    ).to(DEVICE)

    with torch.no_grad():
        generated_tokens = model.generate(
            **inputs,
            use_cache=True,
            min_length=0,
            max_length=256,
            num_beams=5,
            num_return_sequences=1,
        )

    # Decode the generated tokens into text
    with tokenizer.as_target_tokenizer():
        decoded = tokenizer.batch_decode(
            generated_tokens.detach().cpu().tolist(),
            skip_special_tokens=True,
            clean_up_tokenization_spaces=True,
        )
    return ip.postprocess_batch(decoded, lang=tgt_lang)[0]
def _in_same_row(ref, other, alpha1):
    """Return True if the vertical centres differ by less than alpha1 * ref height."""
    return abs(other["yc"] - ref["yc"]) < ref["h"] * alpha1


def _horizontal_gap_ok(ref, other, alpha2):
    """Return True if the horizontal gap between the boxes is below alpha2 * ref height."""
    if ref["xc"] > other["xc"]:
        gap = ref["x1"] - other["x2"]  # other's centre lies left of ref's
    else:
        gap = other["x1"] - ref["x2"]  # other's centre lies right of (or level with) ref's
    return gap < ref["h"] * alpha2


def _vertical_neighbour(ref, other, beta1):
    """Return True if *other* is stacked next to *ref*: small vertical gap and x-extents overlap."""
    if ref["yc"] > other["yc"]:
        gap = ref["y1"] - other["y2"]  # other's centre has the smaller y
    else:
        gap = other["y1"] - ref["y2"]  # other's centre has the larger (or equal) y
    overlaps_x = (
        ref["x1"] < other["x1"] < ref["x2"]
        or ref["x1"] < other["x2"] < ref["x2"]
        or (other["x1"] < ref["x1"] and ref["x2"] < other["x2"])
    )
    return gap < ref["h"] * beta1 and overlaps_x


def _grow_group(seed, candidates, data, is_neighbour):
    """Return *seed* plus every word in *candidates* transitively connected to it.

    Each word joins the group at most once, so the expansion always
    terminates (a word is never compared with itself or re-added).
    """
    group = [seed]
    seen = {seed}
    frontier = [seed]
    while frontier:
        current = frontier.pop()
        for other in candidates:
            if other in seen:
                continue
            if is_neighbour(data[current], data[other]):
                seen.add(other)
                group.append(other)
                frontier.append(other)
    return group


def _paragraph_text(para_words, data, is_line_neighbour):
    """Split one paragraph's words into lines and join them into a single string."""
    lines = []
    pending = list(para_words)
    while pending:
        line = _grow_group(pending[0], pending, data, is_line_neighbour)
        in_line = set(line)
        pending = [w for w in pending if w not in in_line]
        # Words within a line read left to right (ascending centre x).
        lines.append(sorted(line, key=lambda w: data[w]["xc"]))
    # Lines are ordered by the top edge (y1) of their leftmost word.
    lines.sort(key=lambda words: data[words[0]]["y1"])
    return " ".join(" ".join(data[w]["txt"] for w in words) for words in lines)


def detect_para(bbox_dict, *, alpha1=0.2, alpha2=0.7, beta1=0.4):
    """Group recognised words into paragraphs and return their text.

    Parameters:
    bbox_dict (dict): Maps word ids to ``{"txt": ..., "bbox": [x1, y1, x2, y2]}``.
        Ids sharing the prefix before their first "_" are grouped together as
        one image. Each entry is updated in place with the derived keys
        x1, y1, x2, y2, xc, yc, w and h.
    alpha1 (float): Two words are on the same row when their centre-y values
        differ by less than alpha1 * height.
    alpha2 (float): Same-row words are adjacent when the horizontal gap is
        less than alpha2 * height.
    beta1 (float): Words on different rows belong to one paragraph when the
        vertical gap is less than beta1 * height and their x-extents overlap.
        All thresholds scale with the height of the word being expanded from.

    Returns:
    str: One paragraph per output line. Within a paragraph, lines are ordered
    by ascending y1 and words by ascending centre x, joined with single spaces.
    Empty input yields "".
    """
    for entry in bbox_dict.values():
        entry["x1"], entry["y1"], entry["x2"], entry["y2"] = entry["bbox"]
        entry["xc"] = (entry["x1"] + entry["x2"]) / 2
        entry["yc"] = (entry["y1"] + entry["y2"]) / 2
        entry["w"] = entry["x2"] - entry["x1"]
        entry["h"] = entry["y2"] - entry["y1"]

    def is_para_neighbour(ref, other):
        # Same row: only the horizontal gap matters. Otherwise look for a
        # vertically close word whose x-extent overlaps ref's.
        if _in_same_row(ref, other, alpha1):
            return _horizontal_gap_ok(ref, other, alpha2)
        return _vertical_neighbour(ref, other, beta1)

    def is_line_neighbour(ref, other):
        return _in_same_row(ref, other, alpha1) and _horizontal_gap_ok(ref, other, alpha2)

    paragraphs = []
    remaining = list(bbox_dict)
    while remaining:
        prefix = remaining[0].split("_")[0]
        image_words = [w for w in remaining if w.startswith(prefix)]
        unassigned = list(image_words)
        while unassigned:
            para_words = _grow_group(unassigned[0], unassigned, bbox_dict, is_para_neighbour)
            grouped = set(para_words)
            unassigned = [w for w in unassigned if w not in grouped]
            paragraphs.append(_paragraph_text(para_words, bbox_dict, is_line_neighbour))
        taken = set(image_words)
        remaining = [w for w in remaining if w not in taken]
    return "\n".join(paragraphs)
def process_image(image):
    """
    Detect, recognise and translate the scene text in an uploaded image.

    - Saves the image to ``input_image.jpg`` and detects text boxes in it.
    - Writes a copy with the detected boxes drawn to ``output_image.png``.
    - For each detected box, identifies the script, crops the region and
      recognises its text. Boxes whose script is not identified are skipped.
    - Groups the recognised words into paragraphs with ``detect_para`` and
      translates them with ``translate``, using the script identified most
      often across the boxes ("english" if none was identified).

    Parameters:
    image (PIL.Image): The input image to be processed.

    Returns:
    tuple: (annotated PIL.Image with bounding boxes, translated text as str).
    """
    # Save the input image so the path-based OCR helpers can read it.
    image_path = "input_image.jpg"
    image.save(image_path)

    # Detect bounding boxes and draw them into output_image.png.
    detections = ocr.detect(image_path)
    ocr.visualize_detection(image_path, detections, save_path="output_image.png")
    # Copy the annotated image into memory so its file handle is closed here
    # and a later overwrite of output_image.png cannot affect this result.
    with Image.open("output_image.png") as annotated:
        output_image = annotated.copy()

    recognized_texts = {}
    script_counts = Counter()
    with Image.open(image_path) as pil_image:
        for idx, bbox in enumerate(detections):
            # Identify the script and crop the image to this region.
            script_lang, cropped_path = ocr.crop_and_identify_script(pil_image, bbox)
            if not script_lang:
                continue
            # Axis-aligned box enclosing the detected points.
            xs = [point[0] for point in bbox]
            ys = [point[1] for point in bbox]
            recognized_texts[f"img_{idx}"] = {
                "txt": ocr.recognise(cropped_path, script_lang),
                "bbox": [min(xs), min(ys), max(xs), max(ys)],
            }
            script_counts[script_lang] += 1

    # Choose the translation direction from the dominant identified script
    # rather than from whichever box happened to be processed last.
    if script_counts:
        dominant_script = script_counts.most_common(1)[0][0]
    else:
        dominant_script = "english"
    translated = translate(detect_para(recognized_texts), dominant_script)
    return output_image, translated
# HTML header for the interface description: the IITJ logo on a white tile on
# the left and the Bhashini logo floated to the right, both 100x100 px.
interface_html = """
<div style="text-align: left; padding: 10px;">
<div style="background-color: white; padding: 10px; display: inline-block;">
<img src="https://iitj.ac.in/images/logo/Design-of-New-Logo-of-IITJ-2.png" alt="IITJ Logo" style="width: 100px; height: 100px;">
</div>
<img src="https://play-lh.googleusercontent.com/_FXSr4xmhPfBykmNJvKvC0GIAVJmOLhFl6RA5fobCjV-8zVSypxX8yb8ka6zu6-4TEft=w240-h480-rw" alt="Bhashini Logo" style="width: 100px; height: 100px; float: right;">
</div>
"""
# Centered text links to the GitHub repository and the project page; each
# opens in a new tab (target="_blank").
links_html = """
<div style="text-align: center; padding-top: 20px;">
<a href="https://github.com/Bhashini-IITJ/visualTranslation" target="_blank" style="margin-right: 20px; font-size: 18px; text-decoration: none;">
GitHub Repository
</a>
<a href="https://vl2g.github.io/projects/visTrans" target="_blank" style="font-size: 18px; text-decoration: none;">
Project Page
</a>
</div>
"""
# Custom CSS: enlarge the font of any textarea inside an element with the
# "custom-textbox" class (the class given to the translated-text output box).
custom_css = """
.custom-textbox textarea {
font-size: 20px !important;
}
"""
# Create an instance of the project's Seafoam theme for the interface.
seafoam = Seafoam()
# Example inputs: each inner list holds the value for the single image input.
# The paths are relative to the current working directory.
examples = [
["test_images/208.jpg"],
["test_images/1310.jpg"]
]
# Heading HTML prepended to the interface description (the Interface's own
# title is a separate plain-text string).
title = "<h1 style='text-align: center;'>Developed by IITJ</h1>"
# Set up the Gradio Interface with the defined function and customizations
demo = gr.Interface(
allow_flagging="never",
fn=process_image,
inputs=gr.Image(type="pil", image_mode="RGB"),
outputs=[
gr.Image(type="pil", label="Detected Bounding Boxes"),
gr.Textbox(label="Translated Text", elem_classes="custom-textbox")
],
title="IndicPhotoOCR - Indic Scene Text Recogniser Toolkit",
description=title+interface_html+links_html,
theme=seafoam,
css=custom_css,
examples=examples
)
# Server setup and launch configuration
# if __name__ == "__main__":
# server = "0.0.0.0" # IP address for server
# port = 7865 # Port to run the server on
# demo.launch(server_name=server, server_port=port)
demo.launch()