Commit 6f8b6a0 (verified) by albertvillanova (HF staff) · 1 Parent(s): bec611f

Create tool.py

Files changed (1):
  1. tool.py +279 -0
tool.py ADDED
@@ -0,0 +1,279 @@
# coding=utf-8

# Copyright 2023 The HuggingFace Inc. team. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

# Source: https://github.com/huggingface/transformers/blob/v4.49.0/src/transformers/agents/translation.py

from transformers.models.auto import AutoModelForSeq2SeqLM, AutoTokenizer

from smolagents.tools import PipelineTool, Tool

LANGUAGE_CODES = {
    "Acehnese Arabic": "ace_Arab",
    "Acehnese Latin": "ace_Latn",
    "Mesopotamian Arabic": "acm_Arab",
    "Ta'izzi-Adeni Arabic": "acq_Arab",
    "Tunisian Arabic": "aeb_Arab",
    "Afrikaans": "afr_Latn",
    "South Levantine Arabic": "ajp_Arab",
    "Akan": "aka_Latn",
    "Amharic": "amh_Ethi",
    "North Levantine Arabic": "apc_Arab",
    "Modern Standard Arabic": "arb_Arab",
    "Modern Standard Arabic Romanized": "arb_Latn",
    "Najdi Arabic": "ars_Arab",
    "Moroccan Arabic": "ary_Arab",
    "Egyptian Arabic": "arz_Arab",
    "Assamese": "asm_Beng",
    "Asturian": "ast_Latn",
    "Awadhi": "awa_Deva",
    "Central Aymara": "ayr_Latn",
    "South Azerbaijani": "azb_Arab",
    "North Azerbaijani": "azj_Latn",
    "Bashkir": "bak_Cyrl",
    "Bambara": "bam_Latn",
    "Balinese": "ban_Latn",
    "Belarusian": "bel_Cyrl",
    "Bemba": "bem_Latn",
    "Bengali": "ben_Beng",
    "Bhojpuri": "bho_Deva",
    "Banjar Arabic": "bjn_Arab",
    "Banjar Latin": "bjn_Latn",
    "Standard Tibetan": "bod_Tibt",
    "Bosnian": "bos_Latn",
    "Buginese": "bug_Latn",
    "Bulgarian": "bul_Cyrl",
    "Catalan": "cat_Latn",
    "Cebuano": "ceb_Latn",
    "Czech": "ces_Latn",
    "Chokwe": "cjk_Latn",
    "Central Kurdish": "ckb_Arab",
    "Crimean Tatar": "crh_Latn",
    "Welsh": "cym_Latn",
    "Danish": "dan_Latn",
    "German": "deu_Latn",
    "Southwestern Dinka": "dik_Latn",
    "Dyula": "dyu_Latn",
    "Dzongkha": "dzo_Tibt",
    "Greek": "ell_Grek",
    "English": "eng_Latn",
    "Esperanto": "epo_Latn",
    "Estonian": "est_Latn",
    "Basque": "eus_Latn",
    "Ewe": "ewe_Latn",
    "Faroese": "fao_Latn",
    "Fijian": "fij_Latn",
    "Finnish": "fin_Latn",
    "Fon": "fon_Latn",
    "French": "fra_Latn",
    "Friulian": "fur_Latn",
    "Nigerian Fulfulde": "fuv_Latn",
    "Scottish Gaelic": "gla_Latn",
    "Irish": "gle_Latn",
    "Galician": "glg_Latn",
    "Guarani": "grn_Latn",
    "Gujarati": "guj_Gujr",
    "Haitian Creole": "hat_Latn",
    "Hausa": "hau_Latn",
    "Hebrew": "heb_Hebr",
    "Hindi": "hin_Deva",
    "Chhattisgarhi": "hne_Deva",
    "Croatian": "hrv_Latn",
    "Hungarian": "hun_Latn",
    "Armenian": "hye_Armn",
    "Igbo": "ibo_Latn",
    "Ilocano": "ilo_Latn",
    "Indonesian": "ind_Latn",
    "Icelandic": "isl_Latn",
    "Italian": "ita_Latn",
    "Javanese": "jav_Latn",
    "Japanese": "jpn_Jpan",
    "Kabyle": "kab_Latn",
    "Jingpho": "kac_Latn",
    "Kamba": "kam_Latn",
    "Kannada": "kan_Knda",
    "Kashmiri Arabic": "kas_Arab",
    "Kashmiri Devanagari": "kas_Deva",
    "Georgian": "kat_Geor",
    "Central Kanuri Arabic": "knc_Arab",
    "Central Kanuri Latin": "knc_Latn",
    "Kazakh": "kaz_Cyrl",
    "Kabiyè": "kbp_Latn",
    "Kabuverdianu": "kea_Latn",
    "Khmer": "khm_Khmr",
    "Kikuyu": "kik_Latn",
    "Kinyarwanda": "kin_Latn",
    "Kyrgyz": "kir_Cyrl",
    "Kimbundu": "kmb_Latn",
    "Northern Kurdish": "kmr_Latn",
    "Kikongo": "kon_Latn",
    "Korean": "kor_Hang",
    "Lao": "lao_Laoo",
    "Ligurian": "lij_Latn",
    "Limburgish": "lim_Latn",
    "Lingala": "lin_Latn",
    "Lithuanian": "lit_Latn",
    "Lombard": "lmo_Latn",
    "Latgalian": "ltg_Latn",
    "Luxembourgish": "ltz_Latn",
    "Luba-Kasai": "lua_Latn",
    "Ganda": "lug_Latn",
    "Luo": "luo_Latn",
    "Mizo": "lus_Latn",
    "Standard Latvian": "lvs_Latn",
    "Magahi": "mag_Deva",
    "Maithili": "mai_Deva",
    "Malayalam": "mal_Mlym",
    "Marathi": "mar_Deva",
    "Minangkabau Arabic": "min_Arab",
    "Minangkabau Latin": "min_Latn",
    "Macedonian": "mkd_Cyrl",
    "Plateau Malagasy": "plt_Latn",
    "Maltese": "mlt_Latn",
    "Meitei Bengali": "mni_Beng",
    "Halh Mongolian": "khk_Cyrl",
    "Mossi": "mos_Latn",
    "Maori": "mri_Latn",
    "Burmese": "mya_Mymr",
    "Dutch": "nld_Latn",
    "Norwegian Nynorsk": "nno_Latn",
    "Norwegian Bokmål": "nob_Latn",
    "Nepali": "npi_Deva",
    "Northern Sotho": "nso_Latn",
    "Nuer": "nus_Latn",
    "Nyanja": "nya_Latn",
    "Occitan": "oci_Latn",
    "West Central Oromo": "gaz_Latn",
    "Odia": "ory_Orya",
    "Pangasinan": "pag_Latn",
    "Eastern Panjabi": "pan_Guru",
    "Papiamento": "pap_Latn",
    "Western Persian": "pes_Arab",
    "Polish": "pol_Latn",
    "Portuguese": "por_Latn",
    "Dari": "prs_Arab",
    "Southern Pashto": "pbt_Arab",
    "Ayacucho Quechua": "quy_Latn",
    "Romanian": "ron_Latn",
    "Rundi": "run_Latn",
    "Russian": "rus_Cyrl",
    "Sango": "sag_Latn",
    "Sanskrit": "san_Deva",
    "Santali": "sat_Olck",
    "Sicilian": "scn_Latn",
    "Shan": "shn_Mymr",
    "Sinhala": "sin_Sinh",
    "Slovak": "slk_Latn",
    "Slovenian": "slv_Latn",
    "Samoan": "smo_Latn",
    "Shona": "sna_Latn",
    "Sindhi": "snd_Arab",
    "Somali": "som_Latn",
    "Southern Sotho": "sot_Latn",
    "Spanish": "spa_Latn",
    "Tosk Albanian": "als_Latn",
    "Sardinian": "srd_Latn",
    "Serbian": "srp_Cyrl",
    "Swati": "ssw_Latn",
    "Sundanese": "sun_Latn",
    "Swedish": "swe_Latn",
    "Swahili": "swh_Latn",
    "Silesian": "szl_Latn",
    "Tamil": "tam_Taml",
    "Tatar": "tat_Cyrl",
    "Telugu": "tel_Telu",
    "Tajik": "tgk_Cyrl",
    "Tagalog": "tgl_Latn",
    "Thai": "tha_Thai",
    "Tigrinya": "tir_Ethi",
    "Tamasheq Latin": "taq_Latn",
    "Tamasheq Tifinagh": "taq_Tfng",
    "Tok Pisin": "tpi_Latn",
    "Tswana": "tsn_Latn",
    "Tsonga": "tso_Latn",
    "Turkmen": "tuk_Latn",
    "Tumbuka": "tum_Latn",
    "Turkish": "tur_Latn",
    "Twi": "twi_Latn",
    "Central Atlas Tamazight": "tzm_Tfng",
    "Uyghur": "uig_Arab",
    "Ukrainian": "ukr_Cyrl",
    "Umbundu": "umb_Latn",
    "Urdu": "urd_Arab",
    "Northern Uzbek": "uzn_Latn",
    "Venetian": "vec_Latn",
    "Vietnamese": "vie_Latn",
    "Waray": "war_Latn",
    "Wolof": "wol_Latn",
    "Xhosa": "xho_Latn",
    "Eastern Yiddish": "ydd_Hebr",
    "Yoruba": "yor_Latn",
    "Yue Chinese": "yue_Hant",
    "Chinese Simplified": "zho_Hans",
    "Chinese Traditional": "zho_Hant",
    "Standard Malay": "zsm_Latn",
    "Zulu": "zul_Latn",
}


class TranslationTool(PipelineTool):
    """
    Example:

    ```py
    translator = TranslationTool()
    translator("This is a super nice API!", src_lang="English", tgt_lang="French")
    ```
    """

    lang_to_code = LANGUAGE_CODES
    default_checkpoint = "facebook/nllb-200-distilled-600M"
    description = (
        "This is a tool that translates text from one language to another. "
        f"Both `src_lang` and `tgt_lang` should belong to this list of languages: {list(lang_to_code.keys())}."
    )
    name = "translator"
    pre_processor_class = AutoTokenizer
    model_class = AutoModelForSeq2SeqLM

    inputs = {
        "text": {"type": "string", "description": "The text to translate"},
        "src_lang": {
            "type": "string",
            "description": "The language of the text to translate. Written in plain English, such as 'Romanian' or 'Albanian'",
        },
        "tgt_lang": {
            "type": "string",
            "description": "The language to translate the text into. Written in plain English, such as 'Romanian' or 'Albanian'",
        },
    }
    output_type = "string"

    def encode(self, text, src_lang, tgt_lang):
        if src_lang not in self.lang_to_code:
            raise ValueError(f"{src_lang} is not a supported language.")
        if tgt_lang not in self.lang_to_code:
            raise ValueError(f"{tgt_lang} is not a supported language.")
        src_lang = self.lang_to_code[src_lang]
        tgt_lang = self.lang_to_code[tgt_lang]
        return self.pre_processor._build_translation_inputs(
            text, return_tensors="pt", src_lang=src_lang, tgt_lang=tgt_lang
        )

    def decode(self, outputs):
        return self.post_processor.decode(outputs[0].tolist(), skip_special_tokens=True)

    def forward(self, inputs):
        return self.model.generate(**inputs)
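
For reference, a minimal usage sketch of the tool added in this commit is shown below. It assumes the file is saved locally as `tool.py` and that `smolagents`, `transformers`, `torch`, and `sentencepiece` are installed; the commented-out agent wiring (`CodeAgent`, `InferenceClientModel`) is an assumption about the current `smolagents` API and may need adapting to the installed version.

```py
# Minimal usage sketch (assumptions: tool.py is importable from the working
# directory; smolagents, transformers, torch and sentencepiece are installed).
from tool import TranslationTool

# PipelineTool is expected to download and load the default NLLB checkpoint
# (facebook/nllb-200-distilled-600M) lazily on first use.
translator = TranslationTool()

# src_lang / tgt_lang must match keys of LANGUAGE_CODES exactly.
print(translator("This is a super nice API!", src_lang="English", tgt_lang="French"))

# Hypothetical agent wiring (adjust class names to your smolagents version):
# from smolagents import CodeAgent, InferenceClientModel
# agent = CodeAgent(tools=[translator], model=InferenceClientModel())
# agent.run("Translate 'Good morning' into Japanese.")
```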