# NOTE(review): page-scrape artifact (hosting-page header and line-number
# gutter) commented out so the module parses; original text preserved:
# Spaces: Sleeping | File size: 4,252 Bytes | commit b805057
from fastapi import APIRouter
from pydantic import BaseModel
from typing import Optional
from config import TEST_MODE, device, dtype, log
from fairseq2.data.text.text_tokenizer import TextTokenEncoder
from seamless_communication.inference import Translator
import spacy
import re
from datetime import datetime
# Router mounted by the main application; exposes the /t2tt endpoint below.
router = APIRouter()
class TranslateInput(BaseModel):
    """Request body for the /t2tt endpoint."""
    inputs: list[str]  # texts to translate, one entry per document
    model: str  # key into t2tt_mapping, e.g. 'facebook/seamless-m4t-v2-large'
    src_lang: str  # source language code (e.g. 'cmn', 'eng')
    dst_lang: str  # target language code
class TranslateOutput(BaseModel):
    """Response body for /t2tt.

    On success `translations` holds one string per input and `error` is
    None; on failure `translations` is None and `error` carries a message.
    (In TEST_MODE the backend returns both as None.)
    """
    src_lang: str  # echoed back from the request
    dst_lang: str  # echoed back from the request
    translations: Optional[list[str]] = None
    error: Optional[str] = None
@router.post('/t2tt')
def t2tt(inputs: TranslateInput) -> TranslateOutput:
    """Translate each text in ``inputs.inputs`` from src_lang to dst_lang.

    Dispatches to the model-specific implementation registered in
    ``t2tt_mapping``. Failures are reported via the ``error`` field of the
    response rather than an HTTP error status.
    """
    start_time = datetime.now()
    fn = t2tt_mapping.get(inputs.model)
    if not fn:
        return TranslateOutput(
            src_lang=inputs.src_lang,
            dst_lang=inputs.dst_lang,
            # Fixed copy-paste bug: this endpoint translates text, it does
            # not compute sentence embeddings.
            error=f'No translation model found for {inputs.model}'
        )
    try:
        translations = fn(**inputs.dict())
        log({
            # Fixed copy-paste bug: log the task this endpoint performs.
            "task": "t2tt",
            "model": inputs.model,
            "start_time": start_time.isoformat(),
            "time_taken": (datetime.now() - start_time).total_seconds(),
            "inputs": inputs.inputs,
            "outputs": translations,
            "parameters": {
                "src_lang": inputs.src_lang,
                "dst_lang": inputs.dst_lang,
            },
        })
        # Record last-use time so idle models can be evicted externally.
        loaded_models_last_updated[inputs.model] = datetime.now()
        return TranslateOutput(**translations)
    except Exception as e:
        # Surface the failure to the caller instead of raising a 500.
        return TranslateOutput(
            src_lang=inputs.src_lang,
            dst_lang=inputs.dst_lang,
            error=str(e)
        )
# Sentence-segmentation pipelines, loaded once at import time (blocking).
cmn_nlp = spacy.load("zh_core_web_sm")  # Chinese-specific pipeline
xx_nlp = spacy.load("xx_sent_ud_sm")  # multilingual sentence splitter
# Matches stray "<unk>" tokens or the '⁇' glyph (optionally preceded by a
# space) that the decoder can emit; stripped from translations.
unk_re = re.compile(r"\s?<unk>|\s?⁇")
def seamless_t2tt(inputs: list[str], src_lang: str, dst_lang: str = 'eng'):
    """Translate each text in ``inputs`` with SeamlessM4T v2 (text-to-text).

    Each text is split into paragraphs, then sentences (spaCy), translated
    sentence-by-sentence, and re-joined with spaces. Returns a dict with the
    fields of TranslateOutput; ``error`` is set when translation failed.
    """
    if TEST_MODE:
        # Short-circuit for tests: no model is loaded, nothing is translated.
        return {
            "src_lang": src_lang,
            "dst_lang": dst_lang,
            "translations": None,
            "error": None
        }
    # Hoisted the triple-repeated model-id literal into one local.
    model_id = 'facebook/seamless-m4t-v2-large'
    # Load model (or reuse a previously loaded instance)
    if model_id in loaded_models:
        translator = loaded_models[model_id]
    else:
        translator = Translator(
            model_name_or_card="seamlessM4T_v2_large",
            vocoder_name_or_card="vocoder_v2",
            device=device,
            dtype=dtype,
            apply_mintox=False,
        )
        loaded_models[model_id] = translator
    def sent_tokenize(text, lang) -> list[str]:
        # Chinese needs its dedicated pipeline; everything else goes
        # through the multilingual sentence splitter.
        if lang == 'cmn':
            return [str(t) for t in cmn_nlp(text).sents]
        return [str(t) for t in xx_nlp(text).sents]
    def tokenize_and_translate(token_encoder: TextTokenEncoder, text: str, src_lang: str, dst_lang: str) -> str:
        # Convert text into paragraphs and replace new lines with spaces
        lines = [sent_tokenize(line.replace("\n", " "), src_lang) for line in text.split('\n\n') if line]
        lines = [item for sublist in lines for item in sublist if item]
        # Tokenize and translate
        input_tokens = translator.collate([token_encoder(line) for line in lines])
        translations = [
            # Strip <unk> artifacts the decoder sometimes emits.
            unk_re.sub("", str(t))
            for t in translator.predict(
                input=input_tokens,
                task_str="T2TT",
                src_lang=src_lang,
                tgt_lang=dst_lang,
            )[0]
        ]
        return " ".join(translations)
    translations = None
    error_detail = None
    token_encoder = translator.text_tokenizer.create_encoder(
        task="translation", lang=src_lang, mode="source", device=translator.device
    )
    try:
        translations = [tokenize_and_translate(token_encoder, text, src_lang, dst_lang) for text in inputs]
    except Exception as e:
        # Bug fix: the exception detail was previously only printed to
        # stdout while the caller got a generic message; keep the detail
        # so it reaches the API response.
        error_detail = str(e)
        print(f"Error translating text: {e}")
    error = None
    if not translations:
        error = ("Failed to translate text" if error_detail is None
                 else f"Failed to translate text: {error_detail}")
    return {
        "src_lang": src_lang,
        "dst_lang": dst_lang,
        "translations": translations,
        "error": error
    }
# Registries shared by the endpoint handlers above. The original comment
# here was truncated ("Polling every X minutes to"); presumably an external
# sweeper polls loaded_models_last_updated to evict idle models — the
# sweeper itself is not in this file. TODO confirm.
loaded_models = {}
loaded_models_last_updated = {}
# Maps the public model id to its text-to-text translation implementation.
# (Removed the stray trailing "|" extraction artifact that made this a
# syntax error.)
t2tt_mapping = {
    'facebook/seamless-m4t-v2-large': seamless_t2tt,
}