nisheeth committed on
Commit 6aba38a · verified · 1 Parent(s): 6b680ee

Create app.py

Files changed (1)
  1. app.py +86 -0
app.py ADDED
@@ -0,0 +1,86 @@
+ import os
+ import torch
+ import gradio as gr
+ import time
+ from transformers import AutoTokenizer, AutoModelForSeq2SeqLM, pipeline
+ from flores200_codes import flores_codes
+
+
+ def load_models():
+     # Build the model/tokenizer registry. Uncomment entries to serve the
+     # smaller NLLB-200 checkpoints instead of (or alongside) the 3.3B one.
+     model_name_dict = {
+         'nllb-3.3B': 'facebook/nllb-200-3.3B',
+         #'nllb-distilled-600M': 'facebook/nllb-200-distilled-600M',
+         #'nllb-1.3B': 'facebook/nllb-200-1.3B',
+         #'nllb-distilled-1.3B': 'facebook/nllb-200-distilled-1.3B',
+     }
+
+     model_dict = {}
+
+     for call_name, real_name in model_name_dict.items():
+         print('\tLoading model: %s' % call_name)
+         model = AutoModelForSeq2SeqLM.from_pretrained(real_name)
+         tokenizer = AutoTokenizer.from_pretrained(real_name)
+         model_dict[call_name + '_model'] = model
+         model_dict[call_name + '_tokenizer'] = tokenizer
+
+     return model_dict
+
+
+ def translation(source, target, text):
+     # Only one checkpoint is registered above; each checkpoint contributes
+     # a '<name>_model' and a '<name>_tokenizer' entry to model_dict.
+     model_name = 'nllb-3.3B'
+
+     start_time = time.time()
+     # Map the human-readable dropdown names to FLORES-200 language codes.
+     source = flores_codes[source]
+     target = flores_codes[target]
+
+     model = model_dict[model_name + '_model']
+     tokenizer = model_dict[model_name + '_tokenizer']
+
+     translator = pipeline('translation', model=model, tokenizer=tokenizer,
+                           src_lang=source, tgt_lang=target)
+     output = translator(text, max_length=400)
+
+     end_time = time.time()
+
+     output = output[0]['translation_text']
+     result = {'inference_time': end_time - start_time,
+               'source': source,
+               'target': target,
+               'result': output}
+     return result
+
+
+ if __name__ == '__main__':
+     print('\tinit models')
+
+     # model_dict is created at module level so translation() can read it.
+     model_dict = load_models()
+
+     # define gradio demo
+     lang_codes = list(flores_codes.keys())
+     #inputs = [gr.inputs.Radio(['nllb-distilled-600M', 'nllb-1.3B', 'nllb-distilled-1.3B'], label='NLLB Model'),
+     inputs = [gr.inputs.Dropdown(lang_codes, default='English', label='Source'),
+               gr.inputs.Dropdown(lang_codes, default='Hindi', label='Target'),
+               gr.inputs.Textbox(lines=5, label="Input text"),
+               ]
+
+     outputs = gr.outputs.JSON()
+
+     title = "Machine Translation Demo"
+
+     demo_status = "Demo is running on CPU"
+     description = f"Details: https://github.com/facebookresearch/fairseq/tree/nllb. {demo_status}"
+     examples = [
+         ['English', 'Hindi', 'Hi. nice to meet you']
+     ]
+
+     gr.Interface(translation,
+                  inputs,
+                  outputs,
+                  title=title,
+                  description=description,
+                  examples=examples,
+                  ).launch()
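
Note: app.py imports flores_codes from a sibling flores200_codes.py that is not part of this commit. A minimal, hypothetical sketch of what that module presumably provides (the real file in the Space would cover the full NLLB-200 language list) is a plain dict mapping the human-readable names shown in the dropdowns to FLORES-200 language codes:

# flores200_codes.py -- hypothetical sketch, not part of this commit.
# Keys are the names shown in the Gradio dropdowns; values are the
# FLORES-200 codes expected by the NLLB-200 tokenizer/pipeline.
flores_codes = {
    'English': 'eng_Latn',
    'Hindi': 'hin_Deva',
    'French': 'fra_Latn',
    'Spanish': 'spa_Latn',
}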