from fastapi import FastAPI, Request
from fastapi.middleware.cors import CORSMiddleware
from transformers import AutoTokenizer, AutoModelForSeq2SeqLM
import torch
import os
import yaml
import transformers

app = FastAPI()

app.add_middleware(
    CORSMiddleware,
    allow_origins=["*"],  # Adjust this as needed
    allow_credentials=True,
    allow_methods=["*"],
    allow_headers=["*"],
)
# Load the fine-tuned translation model. A matching tokenizer is loaded here,
# but it is re-created below as an NllbTokenizer with custom language codes.
tokenizer = AutoTokenizer.from_pretrained("EzekielMW/Eksl_dataset")
model = AutoModelForSeq2SeqLM.from_pretrained("EzekielMW/Eksl_dataset")
# Where should output files be stored locally.
drive_folder = "./serverlogs"

if not os.path.exists(drive_folder):
    os.makedirs(drive_folder)

# Large batch sizes generally give good results for translation.
effective_train_batch_size = 480
train_batch_size = 6
eval_batch_size = train_batch_size
gradient_accumulation_steps = int(effective_train_batch_size / train_batch_size)
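# With the values above, gradient_accumulation_steps = 480 / 6 = 80: each
# optimizer step accumulates gradients over 80 micro-batches of 6 examples,
# which gives the effective batch size of 480.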
# Everything in one yaml string, so that it can all be logged.
yaml_config = '''
training_args:
    output_dir: "{drive_folder}"
    eval_strategy: steps
    eval_steps: 100
    save_steps: 100
    gradient_accumulation_steps: {gradient_accumulation_steps}
    learning_rate: 3.0e-4  # Include decimal point to parse as float
    # optim: adafactor
    per_device_train_batch_size: {train_batch_size}
    per_device_eval_batch_size: {eval_batch_size}
    weight_decay: 0.01
    save_total_limit: 3
    max_steps: 500
    predict_with_generate: True
    fp16: True
    logging_dir: "{drive_folder}"
    load_best_model_at_end: True
    metric_for_best_model: loss
    seed: 123
    push_to_hub: False

max_input_length: 128
eval_pretrained_model: False
early_stopping_patience: 4
data_dir: .

# Use a 600M parameter model here, which is easier to train on a free Colab
# instance. Bigger models work better, however: results will be improved
# if able to train on nllb-200-1.3B instead.
model_checkpoint: facebook/nllb-200-distilled-600M

datasets:
    train:
        huggingface_load:
            # We will load two datasets here: English/KSL Gloss, and also SALT
            # Swahili/English, so that we can try out multi-way translation.
            - path: EzekielMW/Eksl_dataset
              split: train[:-1000]
            - path: sunbird/salt
              name: text-all
              split: train
        source:
            # This is a text translation only, no audio.
            type: text
            # The source text can be any of English, KSL or Swahili.
            language: [eng,ksl,swa]
            preprocessing:
                # The models are case sensitive, so if the training text is all
                # capitals, then it will only learn to translate capital letters
                # and won't understand lower case. Make everything lower case
                # for now.
                - lower_case
                # We can also augment the spelling of the input text, which
                # makes the model more robust to spelling errors.
                - augment_characters
        target:
            type: text
            # The target text can also be any of English, KSL or Swahili.
            language: [eng,ksl,swa]
            # The models are case sensitive: make everything lower case for now.
            preprocessing:
                - lower_case
        shuffle: True
        allow_same_src_and_tgt_language: False
    validation:
        huggingface_load:
            # Use the last 1000 of the KSL examples for validation.
            - path: EzekielMW/Eksl_dataset
              split: train[-1000:]
            # Add some Swahili validation text.
            - path: sunbird/salt
              name: text-all
              split: dev
        source:
            type: text
            language: [swa,ksl,eng]
            preprocessing:
                - lower_case
        target:
            type: text
            language: [swa,ksl,eng]
            preprocessing:
                - lower_case
        allow_same_src_and_tgt_language: False
'''
yaml_config = yaml_config.format(
    drive_folder=drive_folder,
    train_batch_size=train_batch_size,
    eval_batch_size=eval_batch_size,
    gradient_accumulation_steps=gradient_accumulation_steps,
)

config = yaml.safe_load(yaml_config)

training_settings = transformers.Seq2SeqTrainingArguments(
    **config["training_args"])
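# Note: training_settings is not used by the serving code below; it is built
# so that the full training configuration stays in one place. The API itself
# only needs config['model_checkpoint'] in order to rebuild the tokenizer with
# the custom language codes.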
# The pre-trained model that we use has support for some African languages, but
# we need to adapt the tokenizer to languages that it wasn't trained with,
# such as KSL. Here we reuse the token from a different language.
LANGUAGE_CODES = ["eng", "swa", "ksl"]

code_mapping = {
    # Exact/close mapping
    'eng': 'eng_Latn',
    'swa': 'swh_Latn',
    # Random mapping
    'ksl': 'ace_Latn',
}

tokenizer = transformers.NllbTokenizer.from_pretrained(
    config['model_checkpoint'],
    src_lang='eng_Latn',
    tgt_lang='eng_Latn')

offset = tokenizer.sp_model_size + tokenizer.fairseq_offset

for code in LANGUAGE_CODES:
    i = tokenizer.convert_tokens_to_ids(code_mapping[code])
    tokenizer._added_tokens_encoder[code] = i
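# After the remapping above, the custom codes resolve to the reused NLLB
# tokens, e.g. tokenizer.convert_tokens_to_ids('ksl') should return the same
# id as tokenizer.convert_tokens_to_ids('ace_Latn'). (Illustrative check only;
# the exact ids depend on the tokenizer version.)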
# Define a translation function.
def translate(text, source_language, target_language):
    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
    inputs = tokenizer(text.lower(), return_tensors="pt").to(device)

    # Overwrite the source language token with the custom code for the input
    # language, then force the decoder to start with the target language token.
    inputs['input_ids'][0][0] = tokenizer.convert_tokens_to_ids(source_language)
    translated_tokens = model.to(device).generate(
        **inputs,
        forced_bos_token_id=tokenizer.convert_tokens_to_ids(target_language),
        max_length=100,
        num_beams=5,
    )
    result = tokenizer.batch_decode(translated_tokens, skip_special_tokens=True)[0]

    # KSL gloss is conventionally written in upper case.
    if target_language == 'ksl':
        result = result.upper()

    return result
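# Example usage (illustrative only; the output depends on the fine-tuned
# weights):
#   translate("Good morning, how are you?", "eng", "ksl")  # -> upper-case KSL gloss
#   translate("habari yako", "swa", "eng")                 # -> lower-case English text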
# Expose the translation function as an HTTP endpoint. The route path used
# here is an assumption; adjust it to match whatever the client calls.
@app.post("/translate")
async def translate_text(request: Request):
    data = await request.json()
    text = data.get("text")
    source_language = data.get("source_language")
    target_language = data.get("target_language")
    translation = translate(text, source_language, target_language)
    return {"translation": translation}
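# A minimal sketch for running the server locally, assuming uvicorn is
# installed and port 7860 (the Hugging Face Spaces default) is acceptable;
# inside a Space the platform normally starts the server itself.
if __name__ == "__main__":
    import uvicorn
    uvicorn.run(app, host="0.0.0.0", port=7860)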