|
""" |
|
This module contains the helper functions to get the word alignment mapping between two sentences. |
|
""" |
|
|
|
import torch |
|
import itertools |
|
import transformers |
|
from transformers import logging |
|
|
|
|
|
logging.set_verbosity_warning() |
|
logging.set_verbosity_error() |
|
|
|
|
|
def select_model(model_name): |
|
""" |
|
Select Model |
|
""" |
|
if model_name == "Google-mBERT (Base-Multilingual)": |
|
model_name="bert-base-multilingual-cased" |
|
elif model_name == "Neulab-AwesomeAlign (Bn-En-0.5M)": |
|
model_name="musfiqdehan/bn-en-word-aligner" |
|
elif model_name == "BUET-BanglaBERT (Large)": |
|
model_name="csebuetnlp/banglabert_large" |
|
elif model_name == "SagorSarker-BanglaBERT (Base)": |
|
model_name="sagorsarker/bangla-bert-base" |
|
elif model_name == "SentenceTransformers-LaBSE (Multilingual)": |
|
model_name="sentence-transformers/LaBSE" |
|
|
|
return model_name |
|
|
|
|
|
def get_alignment_mapping(source="", target="", model_name=""): |
|
""" |
|
Get Aligned Words |
|
""" |
|
model_name = select_model(model_name) |
|
|
|
model = transformers.BertModel.from_pretrained(model_name) |
|
tokenizer = transformers.BertTokenizer.from_pretrained(model_name) |
|
|
|
|
|
sent_src, sent_tgt = source.strip().split(), target.strip().split() |
|
|
|
token_src, token_tgt = [tokenizer.tokenize(word) for word in sent_src], [ |
|
tokenizer.tokenize(word) for word in sent_tgt] |
|
|
|
wid_src, wid_tgt = [tokenizer.convert_tokens_to_ids(x) for x in token_src], [ |
|
tokenizer.convert_tokens_to_ids(x) for x in token_tgt] |
|
|
|
ids_src, ids_tgt = tokenizer.prepare_for_model(list(itertools.chain(*wid_src)), return_tensors='pt', model_max_length=tokenizer.model_max_length, truncation=True)['input_ids'], tokenizer.prepare_for_model(list(itertools.chain(*wid_tgt)), return_tensors='pt', truncation=True, model_max_length=tokenizer.model_max_length)['input_ids'] |
|
sub2word_map_src = [] |
|
|
|
for i, word_list in enumerate(token_src): |
|
sub2word_map_src += [i for x in word_list] |
|
|
|
sub2word_map_tgt = [] |
|
|
|
for i, word_list in enumerate(token_tgt): |
|
sub2word_map_tgt += [i for x in word_list] |
|
|
|
|
|
align_layer = 8 |
|
|
|
threshold = 1e-3 |
|
|
|
model.eval() |
|
|
|
with torch.no_grad(): |
|
out_src = model(ids_src.unsqueeze(0), output_hidden_states=True)[ |
|
2][align_layer][0, 1:-1] |
|
out_tgt = model(ids_tgt.unsqueeze(0), output_hidden_states=True)[ |
|
2][align_layer][0, 1:-1] |
|
|
|
dot_prod = torch.matmul(out_src, out_tgt.transpose(-1, -2)) |
|
|
|
softmax_srctgt = torch.nn.Softmax(dim=-1)(dot_prod) |
|
softmax_tgtsrc = torch.nn.Softmax(dim=-2)(dot_prod) |
|
|
|
softmax_inter = (softmax_srctgt > threshold) * \ |
|
(softmax_tgtsrc > threshold) |
|
|
|
align_subwords = torch.nonzero(softmax_inter, as_tuple=False) |
|
|
|
align_words = set() |
|
|
|
for i, j in align_subwords: |
|
align_words.add((sub2word_map_src[i], sub2word_map_tgt[j])) |
|
|
|
return sent_src, sent_tgt, align_words |
|
|
|
|
|
|
|
def get_word_mapping(source="", target="", model_name=""): |
|
""" |
|
Get Word Aligned Mapping Words |
|
""" |
|
sent_src, sent_tgt, align_words = get_alignment_mapping( |
|
source=source, target=target, model_name=model_name) |
|
|
|
result = [] |
|
|
|
for i, j in sorted(align_words): |
|
result.append(f'bn:({sent_src[i]}) -> en:({sent_tgt[j]})') |
|
|
|
return result |
|
|
|
|
|
|
|
def get_word_index_mapping(source="", target="", model_name=""): |
|
""" |
|
Get Word Aligned Mapping Index |
|
""" |
|
sent_src, sent_tgt, align_words = get_alignment_mapping( |
|
source=source, target=target, model_name=model_name) |
|
|
|
result = [] |
|
|
|
for i, j in sorted(align_words): |
|
result.append(f'bn:({i}) -> en:({j})') |
|
|
|
return result |
|
|
|
|
|
def get_alignments_table( |
|
source="", |
|
target="", |
|
model_name=""): |
|
"""Get Spacy PoS Tags and return a Markdown table""" |
|
|
|
sent_src, sent_tgt, align_words = get_alignment_mapping( |
|
source=source, target=target, model_name=model_name |
|
) |
|
|
|
mapped_sent_src = [] |
|
|
|
html_table = ''' |
|
<table> |
|
<thead> |
|
<th>Source</th> |
|
<th>Target</th> |
|
</thead> |
|
''' |
|
|
|
for i, j in sorted(align_words): |
|
punc = r"""!()-[]{}।;:'"\,<>./?@#$%^&*_~""" |
|
if sent_src[i] in punc or sent_tgt[j] in punc: |
|
mapped_sent_src.append(sent_src[i]) |
|
|
|
html_table += f''' |
|
<tbody> |
|
<tr> |
|
<td> {sent_src[i]} </td> |
|
<td> {sent_tgt[j]} </td> |
|
</tr> |
|
''' |
|
else: |
|
mapped_sent_src.append(sent_src[i]) |
|
|
|
html_table += f''' |
|
<tr> |
|
<td> {sent_src[i]} </td> |
|
<td> {sent_tgt[j]} </td> |
|
</tr> |
|
''' |
|
|
|
unks = list(set(sent_src).difference(set(mapped_sent_src))) |
|
for word in unks: |
|
|
|
html_table += f''' |
|
<tr> |
|
<td> {word} </td> |
|
<td> N/A </td> |
|
</tr> |
|
''' |
|
|
|
html_table += ''' |
|
</tbody> |
|
</table> |
|
''' |
|
|
|
pos_accuracy = ((len(sent_src) - len(unks)) / len(sent_src)) |
|
pos_accuracy = f"{pos_accuracy:0.2%}" |
|
|
|
return html_table, pos_accuracy |
|
|