nisheeth commited on
Commit
6b680ee
·
verified ·
1 Parent(s): 6198599

Delete app.py

Browse files
Files changed (1) hide show
  1. app.py +0 -88
app.py DELETED
@@ -1,88 +0,0 @@
1
- import os
2
- import torch
3
- import gradio as gr
4
- import time
5
- from transformers import AutoTokenizer, AutoModelForSeq2SeqLM, pipeline
6
- from flores200_codes import flores_codes
7
-
8
-
9
- def load_models():
10
- # build model and tokenizer
11
- model_name_dict = {
12
- 'nllb-3.3B': 'facebook/nllb-200-3.3B',
13
- #'nllb-distilled-600M': 'facebook/nllb-200-distilled-600M',
14
- #'nllb-1.3B': 'facebook/nllb-200-1.3B',
15
- #'nllb-distilled-1.3B': 'facebook/nllb-200-distilled-1.3B',
16
- #'nllb-3.3B': 'facebook/nllb-200-3.3B',
17
- # 'nllb-distilled-600M': 'facebook/nllb-200-distilled-600M',
18
- }
19
-
20
- model_dict = {}
21
-
22
- for call_name, real_name in model_name_dict.items():
23
- print('\tLoading model: %s' % call_name)
24
- model = AutoModelForSeq2SeqLM.from_pretrained(real_name)
25
- tokenizer = AutoTokenizer.from_pretrained(real_name)
26
- model_dict[call_name+'_model'] = model
27
- model_dict[call_name+'_tokenizer'] = tokenizer
28
-
29
- return model_dict
30
-
31
-
32
- def translation(source, target, text):
33
- if len(model_dict) == 2:
34
- model_name = 'nllb-distilled-1.3B'
35
-
36
- start_time = time.time()
37
- source = flores_codes[source]
38
- target = flores_codes[target]
39
-
40
- model = model_dict[model_name + '_model']
41
- tokenizer = model_dict[model_name + '_tokenizer']
42
-
43
- translator = pipeline('translation', model=model, tokenizer=tokenizer, src_lang=source, tgt_lang=target)
44
- output = translator(text, max_length=400)
45
-
46
- end_time = time.time()
47
-
48
- output = output[0]['translation_text']
49
- result = {'inference_time': end_time - start_time,
50
- 'source': source,
51
- 'target': target,
52
- 'result': output}
53
- return result
54
-
55
-
56
- if __name__ == '__main__':
57
- print('\tinit models')
58
-
59
- global model_dict
60
-
61
- model_dict = load_models()
62
-
63
- # define gradio demo
64
- lang_codes = list(flores_codes.keys())
65
- #inputs = [gr.inputs.Radio(['nllb-distilled-600M', 'nllb-1.3B', 'nllb-distilled-1.3B'], label='NLLB Model'),
66
- inputs = [gr.inputs.Dropdown(lang_codes, default='English', label='Source'),
67
- gr.inputs.Dropdown(lang_codes, default='Korean', label='Target'),
68
- gr.inputs.Textbox(lines=5, label="Input text"),
69
- ]
70
-
71
- outputs = gr.outputs.JSON()
72
-
73
- title = "Machine Translation Demo"
74
-
75
- demo_status = "Demo is running on CPU"
76
- description = f"Details: https://github.com/facebookresearch/fairseq/tree/nllb. {demo_status}"
77
- examples = [
78
- ['English', 'Hindi', 'Hi. nice to meet you']
79
- ]
80
-
81
- gr.Interface(translation,
82
- inputs,
83
- outputs,
84
- title=title,
85
- description=description,
86
- ).launch()
87
-
88
-