from fastapi import FastAPI, Request
from transformers import AutoTokenizer, AutoModelForSeq2SeqLM
from fastapi.middleware.cors import CORSMiddleware
import torch
import os
import yaml
import transformers
app = FastAPI()
app.add_middleware(
CORSMiddleware,
allow_origins=["*"], # Adjust this as needed
allow_credentials=True,
allow_methods=["*"],
allow_headers=["*"],
)
# Load the fine-tuned model and tokenizer. (The tokenizer is rebuilt further
# down from the base NLLB checkpoint so that custom language codes such as
# "ksl" can be registered.)
tokenizer = AutoTokenizer.from_pretrained("EzekielMW/Eksl_dataset")
model = AutoModelForSeq2SeqLM.from_pretrained("EzekielMW/Eksl_dataset")
# Where output files should be stored locally.
drive_folder = "./serverlogs"
os.makedirs(drive_folder, exist_ok=True)
# Large batch sizes generally give good results for translation
effective_train_batch_size = 480
train_batch_size = 6
eval_batch_size = train_batch_size
gradient_accumulation_steps = int(effective_train_batch_size / train_batch_size)
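# Quick sanity check (illustrative addition, not in the original script): the
# per-device batch size times the accumulation steps should reproduce the
# target effective batch size exactly (480 = 6 * 80).
assert train_batch_size * gradient_accumulation_steps == effective_train_batch_size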
# Everything in one yaml string, so that it can all be logged.
yaml_config = '''
training_args:
    output_dir: "{drive_folder}"
    eval_strategy: steps
    eval_steps: 100
    save_steps: 100
    gradient_accumulation_steps: {gradient_accumulation_steps}
    learning_rate: 3.0e-4  # Include decimal point to parse as float
    # optim: adafactor
    per_device_train_batch_size: {train_batch_size}
    per_device_eval_batch_size: {eval_batch_size}
    weight_decay: 0.01
    save_total_limit: 3
    max_steps: 500
    predict_with_generate: True
    fp16: True
    logging_dir: "{drive_folder}"
    load_best_model_at_end: True
    metric_for_best_model: loss
    seed: 123
    push_to_hub: False

max_input_length: 128
eval_pretrained_model: False
early_stopping_patience: 4
data_dir: .

# Use a 600M parameter model here, which is easier to train on a free Colab
# instance. Bigger models work better, however: results will be improved
# if able to train on nllb-200-1.3B instead.
model_checkpoint: facebook/nllb-200-distilled-600M

datasets:
    train:
        huggingface_load:
            # We will load two datasets here: English/KSL Gloss, and also SALT
            # Swahili/English, so that we can try out multi-way translation.
            - path: EzekielMW/Eksl_dataset
              split: train[:-1000]
            - path: sunbird/salt
              name: text-all
              split: train
        source:
            # This is a text translation only, no audio.
            type: text
            # The source text can be any of English, KSL or Swahili.
            language: [eng,ksl,swa]
            preprocessing:
                # The models are case sensitive, so if the training text is all
                # capitals, then it will only learn to translate capital letters
                # and won't understand lower case. Make everything lower case for now.
                - lower_case
                # We can also augment the spelling of the input text, which makes
                # the model more robust to spelling errors.
                - augment_characters
        target:
            type: text
            # The target text can be any of English, KSL or Swahili.
            language: [eng,ksl,swa]
            # The models are case sensitive: make everything lower case for now.
            preprocessing:
                - lower_case
        shuffle: True
        allow_same_src_and_tgt_language: False

    validation:
        huggingface_load:
            # Use the last 1000 KSL examples for validation.
            - path: EzekielMW/Eksl_dataset
              split: train[-1000:]
            # Add some Swahili validation text.
            - path: sunbird/salt
              name: text-all
              split: dev
        source:
            type: text
            language: [swa,ksl,eng]
            preprocessing:
                - lower_case
        target:
            type: text
            language: [swa,ksl,eng]
            preprocessing:
                - lower_case
        allow_same_src_and_tgt_language: False
'''
yaml_config = yaml_config.format(
drive_folder=drive_folder,
train_batch_size=train_batch_size,
eval_batch_size=eval_batch_size,
gradient_accumulation_steps=gradient_accumulation_steps,
)
config = yaml.safe_load(yaml_config)
training_settings = transformers.Seq2SeqTrainingArguments(
**config["training_args"])
# The pre-trained model that we use has support for some African languages, but
# we need to adapt the tokenizer to languages that it wasn't trained with,
# such as KSL. Here we reuse language-code tokens from other languages.
LANGUAGE_CODES = ["eng", "swa", "ksl"]
code_mapping = {
# Exact/close mapping
'eng': 'eng_Latn',
'swa': 'swh_Latn',
# Random mapping
'ksl': 'ace_Latn',
}
tokenizer = transformers.NllbTokenizer.from_pretrained(
config['model_checkpoint'],
src_lang='eng_Latn',
tgt_lang='eng_Latn')
offset = tokenizer.sp_model_size + tokenizer.fairseq_offset  # kept from the training setup; not used below
for code in LANGUAGE_CODES:
    i = tokenizer.convert_tokens_to_ids(code_mapping[code])
    tokenizer._added_tokens_encoder[code] = i
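# Illustrative check of the remapping above: each custom code should resolve to
# the id of the NLLB code it borrows (e.g. 'ksl' -> 'ace_Latn'). The
# DEBUG_TOKENIZER environment flag is a hypothetical name, not part of the app.
if os.environ.get("DEBUG_TOKENIZER"):
    for code in LANGUAGE_CODES:
        print(f"{code} -> id {tokenizer.convert_tokens_to_ids(code)}")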
# Define a translation function
def translate(text, source_language, target_language):
    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
    inputs = tokenizer(text.lower(), return_tensors="pt").to(device)
    # The NLLB tokenizer prefixes the input with a language token; overwrite it
    # with the (possibly custom) source language code.
    inputs['input_ids'][0][0] = tokenizer.convert_tokens_to_ids(source_language)
    translated_tokens = model.to(device).generate(
        **inputs,
        forced_bos_token_id=tokenizer.convert_tokens_to_ids(target_language),
        max_length=100,
        num_beams=5,
    )
    result = tokenizer.batch_decode(translated_tokens, skip_special_tokens=True)[0]
    # KSL gloss is written in upper case, so restore capitalisation for it.
    if target_language == 'ksl':
        result = result.upper()
    return result
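# Example of calling the function directly (the sample sentence and the
# TRANSLATE_SMOKE_TEST flag are illustrative additions, not part of the
# original app). Handy for quick local checks outside the API.
if os.environ.get("TRANSLATE_SMOKE_TEST"):
    print(translate("good morning", "eng", "ksl"))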
@app.post("/translate")
async def translate_text(request: Request):
    data = await request.json()
    text = data.get("text")
    source_language = data.get("source_language")
    target_language = data.get("target_language")
    translation = translate(text, source_language, target_language)
    return {"translation": translation}
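# Local development entrypoint (an assumption: a hosted Space normally starts
# the server itself, so this block is only for running `python app.py` locally,
# and assumes uvicorn is installed; port 7860 is the usual Spaces default).
# Example request once the server is running:
#   curl -X POST http://localhost:7860/translate \
#        -H "Content-Type: application/json" \
#        -d '{"text": "good morning", "source_language": "eng", "target_language": "ksl"}'
if __name__ == "__main__":
    import uvicorn
    uvicorn.run(app, host="0.0.0.0", port=7860)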