import gradio as gr
import nltk
from PIL import Image
import os
from IndicPhotoOCR.ocr import OCR
from IndicPhotoOCR.theme import Seafoam
import numpy as np
import torch
from transformers import (
    AutoModelForSeq2SeqLM,
    AutoTokenizer,
)
from IndicTransToolkit import IndicProcessor

DEVICE = "cpu"

ocr = OCR(device="cpu", verbose=False)

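# Note: both the OCR engine above and the translation model below run on CPU here;
# DEVICE (and the OCR "device" argument) could be set to "cuda" instead when a GPU
# is available, assuming the installed builds support it.
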
def translate(given_str, lang):
    """Translate a string with IndicTrans2.

    If ``lang`` is "english", the English->Indic checkpoint translates the text
    from English to Hindi; otherwise the Indic->English checkpoint translates
    Hindi text to English.
    """
    # The tokenizer and model are loaded on every call, so the first
    # translations can be slow while the checkpoints are downloaded/cached.
    model_name = (
        "ai4bharat/indictrans2-en-indic-1B"
        if lang == "english"
        else "ai4bharat/indictrans2-indic-en-1B"
    )
    tokenizer = AutoTokenizer.from_pretrained(model_name, trust_remote_code=True)
    model = AutoModelForSeq2SeqLM.from_pretrained(model_name, trust_remote_code=True)

    ip = IndicProcessor(inference=True)

    model = model.to(DEVICE)
    model.eval()

    src_lang, tgt_lang = (
        ("eng_Latn", "hin_Deva") if lang == "english" else ("hin_Deva", "eng_Latn")
    )

    batch = ip.preprocess_batch(
        [given_str],
        src_lang=src_lang,
        tgt_lang=tgt_lang,
    )
    inputs = tokenizer(
        batch,
        truncation=True,
        padding="longest",
        return_tensors="pt",
        return_attention_mask=True,
    ).to(DEVICE)

    with torch.no_grad():
        generated_tokens = model.generate(
            **inputs,
            use_cache=True,
            min_length=0,
            max_length=256,
            num_beams=5,
            num_return_sequences=1,
        )

    with tokenizer.as_target_tokenizer():
        generated_tokens = tokenizer.batch_decode(
            generated_tokens.detach().cpu().tolist(),
            skip_special_tokens=True,
            clean_up_tokenization_spaces=True,
        )
    translation = ip.postprocess_batch(generated_tokens, lang=tgt_lang)[0]
    return translation

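# detect_para groups recognised word crops into lines and paragraphs with simple
# geometric heuristics: words whose vertical centres differ by less than
# alpha1 * word height are placed on the same line, horizontal gaps smaller than
# alpha2 * word height join neighbouring words, and vertical gaps smaller than
# beta1 * word height (with horizontal overlap) merge lines into one paragraph.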
def detect_para(bbox_dict):
    """Arrange recognised words into lines and paragraphs and return the text
    in reading order, one paragraph per line."""
    alpha1 = 0.2  # line threshold: max vertical-centre offset, as a fraction of word height
    alpha2 = 0.7  # word threshold: max horizontal gap, as a fraction of word height
    beta1 = 0.4  # paragraph threshold: max vertical gap between lines, as a fraction of word height
    data = bbox_dict
    word_crops = list(data.keys())
    # Derive corner, centre, width and height values for every word crop.
    for i in word_crops:
        data[i]["x1"], data[i]["y1"], data[i]["x2"], data[i]["y2"] = data[i]["bbox"]
        data[i]["xc"] = (data[i]["x1"] + data[i]["x2"]) / 2
        data[i]["yc"] = (data[i]["y1"] + data[i]["y2"]) / 2
        data[i]["w"] = data[i]["x2"] - data[i]["x1"]
        data[i]["h"] = data[i]["y2"] - data[i]["y1"]

    patch_info = {}
    while word_crops:
        img_name = word_crops[0].split("_")[0]
        word_crop_collection = [
            word_crop for word_crop in word_crops if word_crop.startswith(img_name)
        ]
        centroids = {}
        lines = []
        img_word_crops = word_crop_collection.copy()
        para = []
        while img_word_crops:
            clusters = []
            # Grow a paragraph from the first unassigned word crop.
            para_words_group = [img_word_crops[0]]
            added = [img_word_crops[0]]
            img_word_crops.remove(img_word_crops[0])

            while added:
                word_crop = added.pop()
                for i in range(len(img_word_crops)):
                    word_crop_ = img_word_crops[i]
                    if (
                        abs(data[word_crop_]["yc"] - data[word_crop]["yc"])
                        < data[word_crop]["h"] * alpha1
                    ):
                        # Same line: accept the word if the horizontal gap is small enough.
                        if data[word_crop]["xc"] > data[word_crop_]["xc"]:
                            if (
                                data[word_crop]["x1"] - data[word_crop_]["x2"]
                            ) < data[word_crop]["h"] * alpha2:
                                para_words_group.append(word_crop_)
                                added.append(word_crop_)
                        else:
                            if (
                                data[word_crop_]["x1"] - data[word_crop]["x2"]
                            ) < data[word_crop]["h"] * alpha2:
                                para_words_group.append(word_crop_)
                                added.append(word_crop_)
                    else:
                        # Different line: accept the word if the vertical gap is small
                        # and the two boxes overlap horizontally.
                        if data[word_crop]["yc"] > data[word_crop_]["yc"]:
                            if (
                                data[word_crop]["y1"] - data[word_crop_]["y2"]
                            ) < data[word_crop]["h"] * beta1 and (
                                (
                                    (data[word_crop_]["x1"] < data[word_crop]["x2"])
                                    and (data[word_crop_]["x1"] > data[word_crop]["x1"])
                                )
                                or (
                                    (data[word_crop_]["x2"] < data[word_crop]["x2"])
                                    and (data[word_crop_]["x2"] > data[word_crop]["x1"])
                                )
                                or (
                                    (data[word_crop]["x1"] > data[word_crop_]["x1"])
                                    and (data[word_crop]["x2"] < data[word_crop_]["x2"])
                                )
                            ):
                                para_words_group.append(word_crop_)
                                added.append(word_crop_)
                        else:
                            if (
                                data[word_crop_]["y1"] - data[word_crop]["y2"]
                            ) < data[word_crop]["h"] * beta1 and (
                                (
                                    (data[word_crop_]["x1"] < data[word_crop]["x2"])
                                    and (data[word_crop_]["x1"] > data[word_crop]["x1"])
                                )
                                or (
                                    (data[word_crop_]["x2"] < data[word_crop]["x2"])
                                    and (data[word_crop_]["x2"] > data[word_crop]["x1"])
                                )
                                or (
                                    (data[word_crop]["x1"] > data[word_crop_]["x1"])
                                    and (data[word_crop]["x2"] < data[word_crop_]["x2"])
                                )
                            ):
                                para_words_group.append(word_crop_)
                                added.append(word_crop_)
            img_word_crops = [p for p in img_word_crops if p not in para_words_group]

            # Split the paragraph's words into individual lines.
            while para_words_group:
                line_words_group = [para_words_group[0]]
                added = [para_words_group[0]]
                para_words_group.remove(para_words_group[0])

                while added:
                    word_crop = added.pop()
                    for i in range(len(para_words_group)):
                        word_crop_ = para_words_group[i]
                        if (
                            abs(data[word_crop_]["yc"] - data[word_crop]["yc"])
                            < data[word_crop]["h"] * alpha1
                        ):
                            if data[word_crop]["xc"] > data[word_crop_]["xc"]:
                                if (
                                    data[word_crop]["x1"] - data[word_crop_]["x2"]
                                ) < data[word_crop]["h"] * alpha2:
                                    line_words_group.append(word_crop_)
                                    added.append(word_crop_)
                            else:
                                if (
                                    data[word_crop_]["x1"] - data[word_crop]["x2"]
                                ) < data[word_crop]["h"] * alpha2:
                                    line_words_group.append(word_crop_)
                                    added.append(word_crop_)
                para_words_group = [
                    p for p in para_words_group if p not in line_words_group
                ]
                # Order the words in the line from left to right.
                xc = [data[word_crop]["xc"] for word_crop in line_words_group]
                idxs = np.argsort(xc)
                line_words_group = [line_words_group[i] for i in idxs]
                x1 = [data[word_crop]["x1"] for word_crop in line_words_group]
                x2 = [data[word_crop]["x2"] for word_crop in line_words_group]
                y1 = [data[word_crop]["y1"] for word_crop in line_words_group]
                y2 = [data[word_crop]["y2"] for word_crop in line_words_group]
                txt_line = [data[word_crop]["txt"] for word_crop in line_words_group]
                txt = " ".join(txt_line)
                x = [x1[0]]
                y1_ = [y1[0]]
                y2_ = [y2[0]]
                l = [len(txt_l) for txt_l in txt_line]
                for i in range(1, len(x1)):
                    x.append((x1[i] + x2[i - 1]) / 2)
                    y1_.append((y1[i] + y1[i - 1]) / 2)
                    y2_.append((y2[i] + y2[i - 1]) / 2)
                x.append(x2[-1])
                y1_.append(y1[-1])
                y2_.append(y2[-1])
                line_info = {
                    "x": x,
                    "y1": y1_,
                    "y2": y2_,
                    "l": l,
                    "txt": txt,
                    "word_crops": line_words_group,
                }
                clusters.append(line_info)
            # Order the lines from top to bottom and assemble the paragraph text.
            y_ = [clusters[i]["y1"][0] for i in range(len(clusters))]
            idxs = np.argsort(y_)
            clusters_ = [clusters[i] for i in idxs]
            txt = [clusters[i]["txt"] for i in idxs]
            l = [len(t) for t in txt]
            txt = " ".join(txt)
            para_info = {"lines": clusters_, "l": l, "txt": txt}
            para.append(para_info)

        for word_crop in word_crop_collection:
            word_crops.remove(word_crop)
    return "\n".join([para[i]["txt"] for i in range(len(para))])

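# Illustrative sketch of the data that flows through the pipeline below
# (keys follow the "img_{id}" pattern built in process_image; the values are made up):
#   recognized_texts = {
#       "img_0": {"txt": "नमस्ते", "bbox": [12, 30, 98, 60]},
#       "img_1": {"txt": "दुनिया", "bbox": [105, 31, 190, 62]},
#   }
#   detect_para(recognized_texts)       -> "नमस्ते दुनिया"
#   translate("नमस्ते दुनिया", "hindi")  -> English translation via IndicTrans2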
def process_image(image):
    """
    Processes the uploaded image for text detection, recognition, and translation.
    - Detects bounding boxes in the image
    - Draws bounding boxes on the image and identifies the script in each detected area
    - Recognizes text in each cropped region, reorders it into paragraphs, translates it,
      and returns the annotated image together with the translated text

    Parameters:
        image (PIL.Image): The input image to be processed.

    Returns:
        tuple: A PIL.Image with bounding boxes and a string of translated text.
    """
    # Save the uploaded image so the OCR pipeline can read it from disk.
    image_path = "input_image.jpg"
    image.save(image_path)

    # Detect word-level bounding boxes.
    detections = ocr.detect(image_path)

    # Draw the detections and load the annotated image.
    ocr.visualize_detection(image_path, detections, save_path="output_image.png")
    output_image = Image.open("output_image.png")

    # Recognise the text inside each detected region.
    recognized_texts = {}
    pil_image = Image.open(image_path)
    script_lang = "english"

    for idx, bbox in enumerate(detections):
        # Crop the region and identify its script.
        script_lang, cropped_path = ocr.crop_and_identify_script(pil_image, bbox)
        # Reduce the polygon to an axis-aligned bounding box.
        x1 = min([bbox[i][0] for i in range(len(bbox))])
        y1 = min([bbox[i][1] for i in range(len(bbox))])
        x2 = max([bbox[i][0] for i in range(len(bbox))])
        y2 = max([bbox[i][1] for i in range(len(bbox))])
        if script_lang:
            recognized_text = ocr.recognise(cropped_path, script_lang)
            recognized_texts[f"img_{idx}"] = {
                "txt": recognized_text,
                "bbox": [x1, y1, x2, y2],
            }

    # Reorder the words into paragraphs and translate them; the translation
    # direction is chosen from the script of the last processed region.
    translated = translate(detect_para(recognized_texts), script_lang)

    return output_image, translated

interface_html = """
<div style="text-align: left; padding: 10px;">
<div style="background-color: white; padding: 10px; display: inline-block;">
<img src="https://iitj.ac.in/images/logo/Design-of-New-Logo-of-IITJ-2.png" alt="IITJ Logo" style="width: 100px; height: 100px;">
</div>
<img src="https://play-lh.googleusercontent.com/_FXSr4xmhPfBykmNJvKvC0GIAVJmOLhFl6RA5fobCjV-8zVSypxX8yb8ka6zu6-4TEft=w240-h480-rw" alt="Bhashini Logo" style="width: 100px; height: 100px; float: right;">
</div>
"""

links_html = """
<div style="text-align: center; padding-top: 20px;">
<a href="https://github.com/Bhashini-IITJ/visualTranslation" target="_blank" style="margin-right: 20px; font-size: 18px; text-decoration: none;">
GitHub Repository
</a>
<a href="https://vl2g.github.io/projects/visTrans" target="_blank" style="font-size: 18px; text-decoration: none;">
Project Page
</a>
</div>
"""

custom_css = """
.custom-textbox textarea {
    font-size: 20px !important;
}
"""

seafoam = Seafoam()

examples = [
    ["test_images/208.jpg"],
    ["test_images/1310.jpg"],
]

title = "<h1 style='text-align: center;'>Developed by IITJ</h1>"

demo = gr.Interface(
    allow_flagging="never",
    fn=process_image,
    inputs=gr.Image(type="pil", image_mode="RGB"),
    outputs=[
        gr.Image(type="pil", label="Detected Bounding Boxes"),
        gr.Textbox(label="Translated Text", elem_classes="custom-textbox"),
    ],
    title="IndicPhotoOCR - Indic Scene Text Recogniser Toolkit",
    description=title + interface_html + links_html,
    theme=seafoam,
    css=custom_css,
    examples=examples,
)

demo.launch()
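# A temporary public link can be requested with demo.launch(share=True) if the
# app needs to be reachable from outside the local machine.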