import time
import uvicorn
from typing import Dict
from fastapi import FastAPI
from transformers import NllbTokenizer, AutoModelForSeq2SeqLM
app = FastAPI()
def fix_tokenizer(tokenizer, new_lang='quz_Latn'):
    """
    Add a new language token to the tokenizer vocabulary and update language mappings.
    """
    # First ensure we're working with an NLLB tokenizer
    if not hasattr(tokenizer, 'sp_model'):
        raise ValueError("This function expects an NLLB tokenizer")

    # Add the new language token if it's not already present
    if new_lang not in tokenizer.additional_special_tokens:
        tokenizer.add_special_tokens({
            'additional_special_tokens': [new_lang]
        })

    # Initialize lang_code_to_id if it doesn't exist
    if not hasattr(tokenizer, 'lang_code_to_id'):
        tokenizer.lang_code_to_id = {}

    # Add the new language to the lang_code_to_id mapping
    if new_lang not in tokenizer.lang_code_to_id:
        # Get the ID for the new language token
        new_lang_id = tokenizer.convert_tokens_to_ids(new_lang)
        tokenizer.lang_code_to_id[new_lang] = new_lang_id

    # Initialize id_to_lang_code if it doesn't exist
    if not hasattr(tokenizer, 'id_to_lang_code'):
        tokenizer.id_to_lang_code = {}

    # Update the reverse mapping
    tokenizer.id_to_lang_code[tokenizer.lang_code_to_id[new_lang]] = new_lang

    return tokenizer
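# Note (sketch, not executed here): fix_tokenizer only patches the tokenizer side.
# If 'quz_Latn' were a token the checkpoint had never seen, the model's embedding
# matrix would also need to grow to match the enlarged vocabulary, e.g.:
#
#     fix_tokenizer(tokenizer)
#     model.resize_token_embeddings(len(tokenizer))
#
# The checkpoint loaded below is assumed to already include this token, so no
# resize is done in this app.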
# Load the fine-tuned checkpoint and patch its NLLB tokenizer so it knows the
# added 'quz_Latn' language code
MODEL_URL = "pollitoconpapass/QnIA-translation-model"
model = AutoModelForSeq2SeqLM.from_pretrained(MODEL_URL)
tokenizer = NllbTokenizer.from_pretrained(MODEL_URL)
fix_tokenizer(tokenizer)
# === HEALTH CHECK ===
@app.get("/_health")
async def health_check():
    return {'status': 'ok'}
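# Quick liveness check (illustrative; assumes the app is served locally on port 7860):
#   curl http://localhost:7860/_health   ->   {"status": "ok"}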
# === TRANSLATION ===
@app.post("/qnia-translate")
async def translate(data: Dict, a=32, b=3, max_input_length=1024, num_beams=4):
    """
    Translate data['text'] from data['src_lang'] to data['tgt_lang'].
    `a` and `b` bound the output length: max_new_tokens = a + b * input_length.
    """
    start = time.time()

    text = data['text']
    src_lang = data['src_lang']
    tgt_lang = data['tgt_lang']

    # Tell the tokenizer which language tags to use for this request
    tokenizer.src_lang = src_lang
    tokenizer.tgt_lang = tgt_lang

    inputs = tokenizer(text, return_tensors='pt', padding=True, truncation=True, max_length=max_input_length)
    result = model.generate(
        **inputs.to(model.device),
        # Force the first generated token to be the target-language tag
        forced_bos_token_id=tokenizer.convert_tokens_to_ids(tgt_lang),
        max_new_tokens=int(a + b * inputs.input_ids.shape[1]),
        num_beams=num_beams,
    )
    translation = tokenizer.batch_decode(result, skip_special_tokens=True)[0]

    end = time.time()
    print(f"\nTime: {end - start}")
    return {'translation': translation}
if __name__ == "__main__":
    uvicorn.run(app, host="0.0.0.0", port=7860)
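# Example request (sketch): the body keys match the /qnia-translate handler above.
# 'quz_Latn' is the code registered by fix_tokenizer; the source code shown here is
# only illustrative, use whichever NLLB language code the checkpoint was trained on.
#
#   curl -X POST http://localhost:7860/qnia-translate \
#     -H "Content-Type: application/json" \
#     -d '{"text": "Hola, ¿cómo estás?", "src_lang": "spa_Latn", "tgt_lang": "quz_Latn"}'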