import time
from typing import Dict

import torch
import uvicorn
from fastapi import FastAPI
from transformers import NllbTokenizer, AutoModelForSeq2SeqLM

app = FastAPI()

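# quz_Latn (Cusco Quechua) is not among the stock NLLB-200 language codes, so
# the tokenizer of this fine-tuned checkpoint has to be patched to recognize it.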
def fix_tokenizer(tokenizer, new_lang='quz_Latn'):
    """
    Add a new language token to the tokenizer vocabulary and update language mappings.
    """
    if not hasattr(tokenizer, 'sp_model'):
        raise ValueError("This function expects an NLLB tokenizer")

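    # Register the new language code as a special token; on recent transformers
    # versions, replace_additional_special_tokens=False keeps the stock NLLB
    # language codes in additional_special_tokens instead of overwriting them.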
    if new_lang not in tokenizer.additional_special_tokens:
        tokenizer.add_special_tokens(
            {'additional_special_tokens': [new_lang]},
            replace_additional_special_tokens=False,
        )

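    # Newer transformers releases (>= 4.34) no longer expose lang_code_to_id on
    # NllbTokenizer, so rebuild the mapping when it is missing.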
    if not hasattr(tokenizer, 'lang_code_to_id'):
        tokenizer.lang_code_to_id = {}

    if new_lang not in tokenizer.lang_code_to_id:
        new_lang_id = tokenizer.convert_tokens_to_ids(new_lang)
        tokenizer.lang_code_to_id[new_lang] = new_lang_id

    if not hasattr(tokenizer, 'id_to_lang_code'):
        tokenizer.id_to_lang_code = {}

    tokenizer.id_to_lang_code[tokenizer.lang_code_to_id[new_lang]] = new_lang

    return tokenizer

MODEL_URL = "pollitoconpapass/QnIA-translation-model"
model = AutoModelForSeq2SeqLM.from_pretrained(MODEL_URL)
tokenizer = NllbTokenizer.from_pretrained(MODEL_URL)
fix_tokenizer(tokenizer)
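# Note: if the new token were not already part of the fine-tuned checkpoint's
# vocabulary, the model's embedding matrix would also need resizing, e.g. with
# model.resize_token_embeddings(len(tokenizer)).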

@app.get("/_health")
async def health_check():
    return {'status': 'ok'}

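# a and b cap the generated length at a + b * input_length tokens; all four
# keyword arguments are exposed as optional query parameters.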
@app.post("/qnia-translate")
async def translate(data: Dict, a: int = 32, b: int = 3, max_input_length: int = 1024, num_beams: int = 4):
    start = time.time()

    text = data['text']
    src_lang = data['src_lang']
    tgt_lang = data['tgt_lang']

    tokenizer.src_lang = src_lang
    tokenizer.tgt_lang = tgt_lang
    inputs = tokenizer(text, return_tensors='pt', padding=True, truncation=True, max_length=max_input_length)

    # Inference only: disable gradient tracking to save memory.
    with torch.no_grad():
        result = model.generate(
            **inputs.to(model.device),
            forced_bos_token_id=tokenizer.convert_tokens_to_ids(tgt_lang),
            max_new_tokens=int(a + b * inputs.input_ids.shape[1]),
            num_beams=num_beams,
        )

    translation = tokenizer.batch_decode(result, skip_special_tokens=True)[0]

    end = time.time()
    print(f"\nTime: {end - start:.2f} s")

    return {'translation': translation}

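# Example request against a locally running server (payload values are illustrative):
#   curl -X POST "http://localhost:7860/qnia-translate" \
#     -H "Content-Type: application/json" \
#     -d '{"text": "Hola", "src_lang": "spa_Latn", "tgt_lang": "quz_Latn"}'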
if __name__ == "__main__":
    uvicorn.run(app, host="0.0.0.0", port=7860)