Sanzana Lora committed on
Commit
314dfd2
·
verified ·
1 Parent(s): 9de6d41

Create app.py

Browse files
Files changed (1) hide show
  1. app.py +75 -0
app.py ADDED
@@ -0,0 +1,75 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import re
2
+ import gradio as gr
3
+ from transformers import AutoTokenizer, AutoModelForSeq2SeqLM
4
+
5
def _collapse_whitespace(k):
    """Return ``k`` stripped, with every run of whitespace replaced by one space.

    Newlines are whitespace, so a single ``\\s+`` pass also flattens line
    breaks; a separate newline substitution is not needed.
    """
    # Raw string: "\s" in a plain literal is an invalid escape sequence.
    return re.sub(r"\s+", " ", k.strip())


# Public name kept for existing callers.
WHITESPACE_HANDLER = _collapse_whitespace
6
+
7
# Checkpoint name shared by the mT5 model and its tokenizer.
model_name = "csebuetnlp/mT5_m2m_crossSum"

# Both objects are loaded once at import time and reused for every request.
model = AutoModelForSeq2SeqLM.from_pretrained(model_name)
tokenizer = AutoTokenizer.from_pretrained(model_name, use_fast=False)
11
+
12
def get_lang_id(lang):
    """Return the tokenizer id of the language token registered for ``lang``.

    The token is read from index 1 of the ``langid_map`` entry for ``lang``
    in ``model.config.task_specific_params``.
    NOTE(review): an unknown ``lang`` presumably raises ``KeyError`` from the
    map lookup -- confirm against the checkpoint's config.
    """
    langid_map = model.config.task_specific_params["langid_map"]
    language_token = langid_map[lang][1]
    return tokenizer._convert_token_to_id(language_token)
15
+
16
# Function for cross-lingual summarization
def cross_lingual_summarization(source_language, target_language, article_text):
    """Summarize ``article_text`` and return the summary with its inputs.

    Args:
        source_language: Language name of the article. It is only echoed back
            in the result; it is not passed to the tokenizer or the model.
        target_language: Language name used to pick the decoder start token
            via ``get_lang_id``.
        article_text: Raw article text; whitespace is normalized with
            ``WHITESPACE_HANDLER`` before tokenization.

    Returns:
        A dict with the keys ``source_language``, ``target_language``,
        ``original_article`` (the unmodified input text) and ``summary``.
    """
    cleaned_text = WHITESPACE_HANDLER(article_text)

    # Single-item batch, padded/truncated to exactly 512 tokens.
    encoded = tokenizer(
        [cleaned_text],
        return_tensors="pt",
        padding="max_length",
        truncation=True,
        max_length=512,
    )

    # Only the input ids are handed to generate().
    generated = model.generate(
        input_ids=encoded["input_ids"],
        decoder_start_token_id=get_lang_id(target_language),
        max_length=84,
        no_repeat_ngram_size=2,
        num_beams=4,
    )
    first_sequence = generated[0]

    summary = tokenizer.decode(
        first_sequence,
        skip_special_tokens=True,
        clean_up_tokenization_spaces=False,
    )

    result = {}
    result['source_language'] = source_language
    result['target_language'] = target_language
    result['original_article'] = article_text
    result['summary'] = summary
    return result
47
+
48
# Gradio Interface

# Language names offered in both dropdowns (identical list for source and target).
LANGUAGES = [
    'amharic', 'arabic', 'azerbaijani', 'bengali', 'burmese', 'chinese_simplified', 'chinese_traditional',
    'english', 'french', 'gujarati', 'hausa', 'hindi', 'igbo', 'indonesian', 'japanese', 'kirundi',
    'korean', 'kyrgyz', 'marathi', 'nepali', 'oromo', 'pashto', 'persian', 'pidgin', 'portuguese',
    'punjabi', 'russian', 'scottish_gaelic', 'serbian_cyrillic', 'serbian_latin', 'sinhala', 'somali',
    'spanish', 'swahili', 'tamil', 'telugu', 'thai', 'tigrinya', 'turkish', 'ukrainian', 'urdu', 'uzbek',
    'vietnamese', 'welsh', 'yoruba',
]


def _summarize_for_ui(source_language, target_language, article_text):
    """Adapt ``cross_lingual_summarization`` to the two output textboxes.

    The summarization function returns one dict, but the interface declares
    two output components, so Gradio needs two values in output order:
    the original article first, then the summary.
    """
    result = cross_lingual_summarization(source_language, target_language, article_text)
    return result['original_article'], result['summary']


iface = gr.Interface(
    fn=_summarize_for_ui,
    inputs=[
        # Each dropdown gets its own copy of the list.
        gr.Dropdown(list(LANGUAGES), label='Source Language'),
        gr.Dropdown(list(LANGUAGES), label='Target Language'),
        gr.Textbox(label='Article Text'),
    ],
    outputs=[
        gr.Textbox(label='Original Article'),
        gr.Textbox(label='Summary'),
    ],
    live=False,
    title="Cross-Lingual Summarization",
)

# Launch the Gradio app
iface.launch(inline=False)