from fastapi import FastAPI, Request
from transformers import AutoTokenizer, AutoModelForSeq2SeqLM
from fastapi.middleware.cors import CORSMiddleware
import torch
import os
import yaml
import transformers

app = FastAPI()
app.add_middleware(
    CORSMiddleware,
    allow_origins=["*"],  # Adjust this as needed
    allow_credentials=True,
    allow_methods=["*"],
    allow_headers=["*"],
)

# Load the fine-tuned model. Note that the tokenizer loaded here is replaced
# further below by an NllbTokenizer with the custom language codes added.
tokenizer = AutoTokenizer.from_pretrained("EzekielMW/Eksl_dataset")
model = AutoModelForSeq2SeqLM.from_pretrained("EzekielMW/Eksl_dataset")

# Where output files should be stored locally.
drive_folder = "./serverlogs"

if not os.path.exists(drive_folder):
    os.makedirs(drive_folder)


# Large batch sizes generally give good results for translation
effective_train_batch_size = 480
train_batch_size = 6
eval_batch_size = train_batch_size

gradient_accumulation_steps = int(effective_train_batch_size / train_batch_size)
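# With these values, int(480 / 6) = 80 batches are accumulated per optimiser step.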

# Everything in one yaml string, so that it can all be logged.
yaml_config = '''
training_args:
  output_dir: "{drive_folder}"
  eval_strategy: steps
  eval_steps: 100
  save_steps: 100
  gradient_accumulation_steps: {gradient_accumulation_steps}
  learning_rate: 3.0e-4  # Include decimal point to parse as float
  # optim: adafactor
  per_device_train_batch_size: {train_batch_size}
  per_device_eval_batch_size: {eval_batch_size}
  weight_decay: 0.01
  save_total_limit: 3
  max_steps: 500
  predict_with_generate: True
  fp16: True
  logging_dir: "{drive_folder}"
  load_best_model_at_end: True
  metric_for_best_model: loss
  seed: 123
  push_to_hub: False

max_input_length: 128
eval_pretrained_model: False
early_stopping_patience: 4
data_dir: .

# Use a 600M parameter model here, which is easier to train on a free Colab
# instance. Bigger models work better, however: results will be improved
# if able to train on nllb-200-1.3B instead.
model_checkpoint: facebook/nllb-200-distilled-600M

datasets:
  train:
    huggingface_load:
      # We will load two datasets here: English/KSL Gloss, and also SALT
      # Swahili/English, so that we can try out multi-way translation.

      - path: EzekielMW/Eksl_dataset
        split: train[:-1000]
      - path: sunbird/salt
        name: text-all
        split: train
    source:
      # This is a text translation only, no audio.
      type: text
      # The source text can be any of English, KSL or Swahili.
      language: [eng,ksl,swa]
      preprocessing:
        # The models are case sensitive, so if the training text is all
        # capitals, then it will only learn to translate capital letters and
        # won't understand lower case. Make everything lower case for now.
        - lower_case
        # We can also augment the spelling of the input text, which makes the
        # model more robust to spelling errors.
        - augment_characters
    target:
      type: text
      # The target text can be any of English, KSL or Swahili.
      language: [eng,ksl,swa]
      # The models are case sensitive: make everything lower case for now.
      preprocessing:
        - lower_case

    shuffle: True
    allow_same_src_and_tgt_language: False

  validation:
    huggingface_load:
      # Use the last 1000 of the KSL examples for validation.
      - path: EzekielMW/Eksl_dataset
        split: train[-1000:]
      # Add some Swahili validation text.
      - path: sunbird/salt
        name: text-all
        split: dev
    source:
      type: text
      language: [swa,ksl,eng]
      preprocessing:
        - lower_case
    target:
      type: text
      language: [swa,ksl,eng]
      preprocessing:
        - lower_case
    allow_same_src_and_tgt_language: False
'''

yaml_config = yaml_config.format(
    drive_folder=drive_folder,
    train_batch_size=train_batch_size,
    eval_batch_size=eval_batch_size,
    gradient_accumulation_steps=gradient_accumulation_steps,
)

config = yaml.safe_load(yaml_config)
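# config is now a nested dict, e.g. config["training_args"]["max_steps"] == 500
# and config["model_checkpoint"] == "facebook/nllb-200-distilled-600M".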

training_settings = transformers.Seq2SeqTrainingArguments(
    **config["training_args"])

# The pre-trained model that we use has support for some African languages, but
# we need to adapt the tokenizer to languages that it wasn't trained with,
# such as KSL. Here we reuse the token from a different language.
LANGUAGE_CODES = ["eng", "swa", "ksl"]

code_mapping = {
    # Exact/close mapping
    'eng': 'eng_Latn',
    'swa': 'swh_Latn',
    # Arbitrary mapping: KSL is not covered by NLLB, so reuse another
    # language's token.
    'ksl': 'ace_Latn',
}
tokenizer = transformers.NllbTokenizer.from_pretrained(
    config['model_checkpoint'],
    src_lang='eng_Latn',
    tgt_lang='eng_Latn')

offset = tokenizer.sp_model_size + tokenizer.fairseq_offset

for code in LANGUAGE_CODES:
    i = tokenizer.convert_tokens_to_ids(code_mapping[code])
    tokenizer._added_tokens_encoder[code] = i
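
# After this remapping, tokenizer.convert_tokens_to_ids('ksl') resolves to the
# id of 'ace_Latn', so 'eng', 'swa' and 'ksl' can be passed directly as
# language tags to translate() below.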

# Define a translation function
def translate(text, source_language, target_language):
    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
    inputs = tokenizer(text.lower(), return_tensors="pt").to(device)
    # Overwrite the first token (the source-language code added by the
    # tokenizer) with the requested source language.
    inputs['input_ids'][0][0] = tokenizer.convert_tokens_to_ids(source_language)
    translated_tokens = model.to(device).generate(
        **inputs,
        forced_bos_token_id=tokenizer.convert_tokens_to_ids(target_language),
        max_length=100,
        num_beams=5,
    )
    result = tokenizer.batch_decode(translated_tokens, skip_special_tokens=True)[0]

    if target_language == 'ksl':
        result = result.upper()

    return result
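
# Example: translate("good morning", "eng", "ksl") returns the model's KSL
# gloss for the input, upper-cased because the target language is 'ksl'.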

@app.post("/translate")
async def translate_text(request: Request):
    data = await request.json()
    text = data.get("text")
    source_language = data.get("source_language")
    target_language = data.get("target_language")
    
    translation = translate(text, source_language, target_language)
    return {"translation": translation}