|
import re |
|
import gradio as gr |
|
from transformers import AutoTokenizer, AutoModelForSeq2SeqLM |
|
|
|
# Matches any run of whitespace (spaces, tabs, newlines, ...). Written as a
# raw string so that ``\s`` reaches the regex engine untouched instead of
# being parsed as an (invalid) string escape.
_WHITESPACE_RUN = re.compile(r"\s+")


def WHITESPACE_HANDLER(k):
    """Return ``k`` stripped, with every run of whitespace collapsed to one space.

    Newlines are whitespace too, so a single ``\\s+`` substitution covers both
    the newline folding and the space squeezing.
    """
    return _WHITESPACE_RUN.sub(" ", k.strip())
|
|
|
|
|
# Hugging Face Hub identifier of the checkpoint used for summarization.
model_name = "csebuetnlp/mT5_m2m_crossSum"

# Tokenizer and seq2seq model are loaded once at import time and shared by
# every request. The slow (non-fast) tokenizer is requested explicitly.
tokenizer = AutoTokenizer.from_pretrained(model_name, use_fast=False)
model = AutoModelForSeq2SeqLM.from_pretrained(model_name)
|
|
|
def get_lang_id(lang):
    """Return the tokenizer id of the language token registered for ``lang``.

    The model config's ``task_specific_params["langid_map"]`` is indexed by
    language name; element ``[1]`` of the matching entry is the token that
    gets converted to an id. An unknown ``lang`` raises ``KeyError`` from the
    map lookup.
    """
    lang_token = model.config.task_specific_params["langid_map"][lang][1]
    return tokenizer._convert_token_to_id(lang_token)
|
|
|
|
|
def cross_lingual_summarization(source_language, target_language, article_text):
    """Summarize ``article_text`` with the summary written in ``target_language``.

    Args:
        source_language: Name of the article's language. It is only echoed
            back in the result; it does not influence generation here.
        target_language: Language name passed to ``get_lang_id``; the
            resulting id is used as the decoder start token.
        article_text: Raw article text. Whitespace is normalized before
            tokenization, and the input is padded/truncated to 512 tokens.

    Returns:
        A dict with the keys ``source_language``, ``target_language``,
        ``original_article`` (the unmodified input) and ``summary``.
    """
    cleaned_article = WHITESPACE_HANDLER(article_text)
    encoded = tokenizer(
        [cleaned_article],
        return_tensors="pt",
        padding="max_length",
        truncation=True,
        max_length=512,
    )
    article_ids = encoded["input_ids"]

    # Beam search, output capped at 84 tokens, with repeated bigrams blocked.
    generated = model.generate(
        input_ids=article_ids,
        decoder_start_token_id=get_lang_id(target_language),
        max_length=84,
        no_repeat_ngram_size=2,
        num_beams=4,
    )
    # A single article was submitted, so only the first sequence is relevant.
    best_sequence = generated[0]

    summary_text = tokenizer.decode(
        best_sequence,
        skip_special_tokens=True,
        clean_up_tokenization_spaces=False,
    )

    result = {}
    result['source_language'] = source_language
    result['target_language'] = target_language
    result['original_article'] = article_text
    result['summary'] = summary_text
    return result
|
|
|
|
|
# Language names offered in both dropdowns. Kept in one place so the source
# and target lists cannot drift apart.
_SUPPORTED_LANGUAGES = (
    'amharic', 'arabic', 'azerbaijani', 'bengali', 'burmese', 'chinese_simplified', 'chinese_traditional',
    'english', 'french', 'gujarati', 'hausa', 'hindi', 'igbo', 'indonesian', 'japanese', 'kirundi',
    'korean', 'kyrgyz', 'marathi', 'nepali', 'oromo', 'pashto', 'persian', 'pidgin', 'portuguese',
    'punjabi', 'russian', 'scottish_gaelic', 'serbian_cyrillic', 'serbian_latin', 'sinhala', 'somali',
    'spanish', 'swahili', 'tamil', 'telugu', 'thai', 'tigrinya', 'turkish', 'ukrainian', 'urdu', 'uzbek',
    'vietnamese', 'welsh', 'yoruba',
)


def _summarize_for_ui(source_language, target_language, article_text):
    """Adapt ``cross_lingual_summarization`` to the interface's two outputs.

    The summarizer returns a dict with string keys, while the interface maps
    return values to its output components by position. This wrapper picks
    the two fields that correspond to the 'Original Article' and 'Summary'
    textboxes and returns them as a tuple in that order.
    """
    result = cross_lingual_summarization(source_language, target_language, article_text)
    return result['original_article'], result['summary']


iface = gr.Interface(
    fn=_summarize_for_ui,
    inputs=[
        # Each dropdown gets its own list copy of the shared language names.
        gr.Dropdown(list(_SUPPORTED_LANGUAGES), label='Source Language'),
        gr.Dropdown(list(_SUPPORTED_LANGUAGES), label='Target Language'),
        gr.Textbox(label='Article Text'),
    ],
    outputs=[
        gr.Textbox(label='Original Article'),
        gr.Textbox(label='Summary'),
    ],
    live=False,
    title="Cross-Lingual Summarization",
)

# Start the web UI as soon as the script runs.
iface.launch(inline=False)
|
|