from fastapi import APIRouter
from pydantic import BaseModel
from typing import Optional
from config import TEST_MODE, device, dtype, log
from fairseq2.data.text.text_tokenizer import TextTokenEncoder
from seamless_communication.inference import Translator
import spacy
import re
from datetime import datetime

router = APIRouter()

class TranslateInput(BaseModel):
    inputs: list[str]
    model: str
    src_lang: str
    dst_lang: str


class TranslateOutput(BaseModel):
    src_lang: str
    dst_lang: str
    translations: Optional[list[str]] = None
    error: Optional[str] = None
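
# Illustrative request/response for POST /t2tt (hypothetical values; language
# codes follow the three-letter SeamlessM4T convention, e.g. "spa", "eng"):
#   request:  {"inputs": ["Hola mundo."], "model": "facebook/seamless-m4t-v2-large",
#              "src_lang": "spa", "dst_lang": "eng"}
#   response: {"src_lang": "spa", "dst_lang": "eng",
#              "translations": ["Hello world."], "error": null}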


@router.post('/t2tt')
def t2tt(inputs: TranslateInput) -> TranslateOutput:
    start_time = datetime.now()
    fn = t2tt_mapping.get(inputs.model)
    if not fn:
        return TranslateOutput(
            src_lang=inputs.src_lang,
            dst_lang=inputs.dst_lang,
            error=f'No translation model found for {inputs.model}'
        )

    try:
        # The translation functions take (inputs, src_lang, dst_lang), so drop
        # the 'model' field before unpacking the request body.
        result = fn(**inputs.dict(exclude={'model'}))
        log({
            "task": "t2tt",
            "model": inputs.model,
            "start_time": start_time.isoformat(),
            "time_taken": (datetime.now() - start_time).total_seconds(),
            "inputs": inputs.inputs,
            "outputs": result.get("translations"),
            "parameters": {
                "src_lang": inputs.src_lang,
                "dst_lang": inputs.dst_lang,
            },
        })
        loaded_models_last_updated[inputs.model] = datetime.now()
        return TranslateOutput(**result)
    except Exception as e:
        return TranslateOutput(
            src_lang=inputs.src_lang,
            dst_lang=inputs.dst_lang,
            error=str(e)
        )

# Sentence segmentation pipelines: a Chinese-specific model for Mandarin and a
# multilingual fallback (both installed via `python -m spacy download ...`).
cmn_nlp = spacy.load("zh_core_web_sm")
xx_nlp = spacy.load("xx_sent_ud_sm")
# Strips "<unk>" tokens and the "⁇" placeholder glyph from decoded output.
unk_re = re.compile(r"\s?<unk>|\s?⁇")

def seamless_t2tt(inputs: list[str], src_lang: str, dst_lang: str = 'eng'):
    if TEST_MODE:
        return {
            "src_lang": src_lang,
            "dst_lang": dst_lang,
            "translations": None,
            "error": None
        }

    # Reuse a cached translator when available; otherwise load SeamlessM4T v2
    # (slow on first call) and cache it for subsequent requests.
    if 'facebook/seamless-m4t-v2-large' in loaded_models:
        translator = loaded_models['facebook/seamless-m4t-v2-large']
    else:
        translator = Translator(
            model_name_or_card="seamlessM4T_v2_large",
            vocoder_name_or_card="vocoder_v2",
            device=device,
            dtype=dtype,
            apply_mintox=False,
        )
        loaded_models['facebook/seamless-m4t-v2-large'] = translator

    def sent_tokenize(text: str, lang: str) -> list[str]:
        # Mandarin gets the dedicated Chinese pipeline; everything else falls
        # back to the multilingual sentence splitter.
        if lang == 'cmn':
            return [str(t) for t in cmn_nlp(text).sents]
        return [str(t) for t in xx_nlp(text).sents]


    def tokenize_and_translate(token_encoder: TextTokenEncoder, text: str, src_lang: str, dst_lang: str) -> str:
        # Split on blank lines into paragraphs, replace remaining newlines
        # with spaces, then segment each paragraph into sentences.
        lines = [sent_tokenize(line.replace("\n", " "), src_lang) for line in text.split('\n\n') if line]
        # Flatten the per-paragraph sentence lists, dropping empty strings.
        lines = [item for sublist in lines for item in sublist if item]

        # Tokenize each sentence, run batched T2TT inference, and strip any
        # <unk> markers from the decoded strings.
        input_tokens = translator.collate([token_encoder(line) for line in lines])
        translations = [
            unk_re.sub("", str(t))
            for t in translator.predict(
                input=input_tokens,
                task_str="T2TT",
                src_lang=src_lang,
                tgt_lang=dst_lang,
            )[0]
        ]
        return " ".join(translations)

    translations = None
    # Create the source-language encoder once and reuse it for every input.
    token_encoder = translator.text_tokenizer.create_encoder(
        task="translation", lang=src_lang, mode="source", device=translator.device
    )
    try:
        translations = [tokenize_and_translate(token_encoder, text, src_lang, dst_lang) for text in inputs]
    except Exception as e:
        print(f"Error translating text: {e}")

    return {
        "src_lang": src_lang,
        "dst_lang": dst_lang,
        "translations": translations,
        "error": None if translations else "Failed to translate text"
    }
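
# Direct (non-HTTP) usage, a sketch with hypothetical input and output:
#   seamless_t2tt(["你好，世界。"], src_lang="cmn", dst_lang="eng")
#   -> {"src_lang": "cmn", "dst_lang": "eng",
#       "translations": ["Hello, world."], "error": None}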


# Cache of loaded models and their last-use times; polled every X minutes to
# unload models that have sat idle.
loaded_models = {}
loaded_models_last_updated = {}

# Maps model identifiers to their translation functions.
t2tt_mapping = {
    'facebook/seamless-m4t-v2-large': seamless_t2tt,
}
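

# Minimal smoke test: a sketch that assumes TEST_MODE is enabled in config so
# no model weights are loaded. Mounts the router on a throwaway app and posts
# one request to /t2tt (TestClient requires the httpx package).
if __name__ == "__main__":
    from fastapi import FastAPI
    from fastapi.testclient import TestClient

    app = FastAPI()
    app.include_router(router)
    client = TestClient(app)
    response = client.post("/t2tt", json={
        "inputs": ["Bonjour le monde."],
        "model": "facebook/seamless-m4t-v2-large",
        "src_lang": "fra",
        "dst_lang": "eng",
    })
    print(response.status_code, response.json())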