Added Translation Endpoint
Files changed:
- Dockerfile +35 -0
- README.md +3 -1
- app.py +1 -1
- config.py +14 -1
- logger.py +0 -4
- requirements.txt +1 -0
- seamless_requirements.txt +2 -0
- tasks/pose_estimation.py +0 -0
- tasks/sentence_embeddings.py +83 -0
- tasks/translation.py +135 -0
Dockerfile
CHANGED
@@ -1,6 +1,36 @@
 FROM pytorch/pytorch:2.1.2-cuda12.1-cudnn8-runtime
 ENV DEBIAN_FRONTEND=noninteractive
 
+RUN apt-get update && \
+    apt-get upgrade -y && \
+    apt-get install -y --no-install-recommends \
+    git \
+    git-lfs \
+    wget \
+    curl \
+    # python build dependencies \
+    build-essential \
+    libssl-dev \
+    zlib1g-dev \
+    libbz2-dev \
+    libreadline-dev \
+    libsqlite3-dev \
+    libncursesw5-dev \
+    xz-utils \
+    tk-dev \
+    libxml2-dev \
+    libxmlsec1-dev \
+    libffi-dev \
+    liblzma-dev \
+    # gradio dependencies \
+    ffmpeg
+
+# fairseq2 dependencies
+RUN apt-get install -y --no-install-recommends \
+    libsndfile-dev
+
+RUN apt-get clean && rm -rf /var/lib/apt/lists/*
+
 RUN useradd -m -u 1000 user
 USER user
 ENV HOME=/home/user \
@@ -15,5 +45,10 @@ RUN pip install -r ${HOME}/app/requirements.txt
 # RUN mkdir content
 # ADD --chown=user https://<SOME_ASSET_URL> content/<SOME_ASSET_NAME>
 
+# SeamlessCommunication requirements
+RUN pip install -r ${HOME}/app/seamless_requirements.txt && \
+    pip install fairseq2 --pre --extra-index-url https://fair.pkg.atmeta.com/fairseq2/pt2.1.0/cu121 && \
+    pip install ${HOME}/app/whl/seamless_communication-1.0.0-py3-none-any.whl
+
 # Start the FastAPI app on port 7860, the default port expected by Spaces
 CMD ["uvicorn", "app:app", "--host", "0.0.0.0", "--port", "7860"]
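With these changes the image builds and runs locally the same way Spaces runs it. As a quick sanity check, the root path serves the FastAPI docs (docs_url="/" in app.py). A minimal smoke-test sketch in Python, assuming the container's port 7860 is published to localhost:

import requests

# The app mounts the interactive docs at "/" (docs_url="/"), so a plain GET
# is enough to confirm the server came up inside the container.
resp = requests.get("http://localhost:7860/")
print(resp.status_code)  # expect 200 once startup completes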
README.md
CHANGED
@@ -22,4 +22,6 @@ Users should be able to call the task and get back in the standard format
     "model": "BAAI/bge-base-en-v1.5",
     "inputs": ["This is one text", "This is second text"],
     "parameters": {}
-}
+}
+
+TODO: Models are cached in volume directory
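The standard format above maps directly onto an HTTP POST. A minimal client sketch, assuming the Space is reachable on localhost:7860 and the /sentence-embeddings route defined in tasks/sentence_embeddings.py:

import requests

# Standard task payload, as documented in the README
payload = {
    "model": "BAAI/bge-base-en-v1.5",
    "inputs": ["This is one text", "This is second text"],
    "parameters": {},
}
resp = requests.post("http://localhost:7860/sentence-embeddings", json=payload)
print(resp.json())  # {'embeddings': [[...], [...]], 'error': None} on success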
app.py
CHANGED
@@ -1,6 +1,6 @@
 from fastapi.middleware.cors import CORSMiddleware
 from fastapi import FastAPI
-import sentence_embeddings
+from tasks import sentence_embeddings
 
 app = FastAPI(docs_url="/", redoc_url=None)
 
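Only the import lines of app.py appear in this diff; for the new /t2tt route to be served, the translation router has to be registered as well. A hedged sketch of that wiring, assuming the existing routers are mounted with include_router (the translation lines are an assumption, not shown in this commit):

from fastapi.middleware.cors import CORSMiddleware
from fastapi import FastAPI
from tasks import sentence_embeddings, translation  # translation import assumed

app = FastAPI(docs_url="/", redoc_url=None)
# Mount the task routers so their POST routes are exposed
app.include_router(sentence_embeddings.router)
app.include_router(translation.router)  # assumed: exposes POST /t2tt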
config.py
CHANGED
@@ -1,6 +1,19 @@
+import torch
 import os
 import dotenv
 
 dotenv.load_dotenv()
 
-TEST_MODE = (os.getenv('TEST_MODE', 'False') == "True")
+TEST_MODE = (os.getenv('TEST_MODE', 'False') == "True")
+
+if torch.cuda.is_available():
+    device = torch.device("cuda:0")
+    dtype = torch.float16
+else:
+    device = torch.device("cpu")
+    dtype = torch.float32
+
+from datetime import datetime
+
+def log(data: dict):
+    print(f"{datetime.now().isoformat()}: {data}")
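config.py now centralizes device/dtype selection and absorbs the log helper that logger.py used to provide. An illustrative snippet of how task modules consume it (the printed format follows the log implementation above):

from config import TEST_MODE, device, dtype, log

# On a CUDA host this reports cuda:0/torch.float16; on CPU, cpu/torch.float32
log({"device": str(device), "dtype": str(dtype), "test_mode": TEST_MODE})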
logger.py
DELETED
@@ -1,4 +0,0 @@
-from datetime import datetime
-
-def log(data: dict):
-    print(f"{datetime.now().isoformat()}: {data}")
requirements.txt
CHANGED
@@ -2,4 +2,5 @@ transformers
 torch
 fastapi
 uvicorn
+pydantic
 python-dotenv
seamless_requirements.txt
ADDED
@@ -0,0 +1,2 @@
+omegaconf==2.3.0
+fasttext==0.9.2
tasks/pose_estimation.py
ADDED
File without changes
tasks/sentence_embeddings.py
ADDED
@@ -0,0 +1,83 @@
+from typing import Optional
+from fastapi import APIRouter
+from pydantic import BaseModel
+from transformers import AutoTokenizer, AutoModel
+import torch
+from datetime import datetime
+from config import TEST_MODE, device, log
+
+router = APIRouter()
+
+class SentenceEmbeddingsInput(BaseModel):
+    inputs: list[str]
+    model: str
+    parameters: dict
+
+class SentenceEmbeddingsOutput(BaseModel):
+    embeddings: Optional[list[list[float]]] = None
+    error: Optional[str] = None
+
+@router.post('/sentence-embeddings')
+def sentence_embeddings(inputs: SentenceEmbeddingsInput):
+    start_time = datetime.now()
+    fn = sentence_embeddings_mapping.get(inputs.model)
+    if not fn:
+        return SentenceEmbeddingsOutput(
+            error=f'No sentence embeddings model found for {inputs.model}'
+        )
+
+    try:
+        embeddings = fn(inputs.inputs, inputs.parameters)
+
+        log({
+            "task": "sentence_embeddings",
+            "model": inputs.model,
+            "start_time": start_time.isoformat(),
+            "time_taken": (datetime.now() - start_time).total_seconds(),
+            "inputs": inputs.inputs,
+            "outputs": embeddings,
+            "parameters": inputs.parameters,
+        })
+        loaded_models_last_updated[inputs.model] = datetime.now()
+        return SentenceEmbeddingsOutput(
+            embeddings=embeddings
+        )
+    except Exception as e:
+        return SentenceEmbeddingsOutput(
+            error=str(e)
+        )
+
+def generic_sentence_embeddings(model_name: str):
+    global loaded_models
+
+    def process_texts(texts: list[str], parameters: dict):
+        if TEST_MODE:
+            return [[0.1, 0.2]] * len(texts)
+
+        if model_name in loaded_models:
+            tokenizer, model = loaded_models[model_name]
+        else:
+            tokenizer = AutoTokenizer.from_pretrained(model_name)
+            model = AutoModel.from_pretrained(model_name).to(device)
+            loaded_models[model_name] = (tokenizer, model)  # key by name, not the model object
+
+        # Tokenize sentences
+        encoded_input = tokenizer(texts, padding=True, truncation=True, return_tensors='pt').to(device)
+        with torch.no_grad():
+            model_output = model(**encoded_input)
+            sentence_embeddings = model_output[0][:, 0]  # CLS-token embedding
+
+        # normalize embeddings
+        sentence_embeddings = torch.nn.functional.normalize(sentence_embeddings, p=2, dim=1)
+        return sentence_embeddings.tolist()
+
+    return process_texts
+
+# Model cache; last-updated timestamps support unloading idle models later
+loaded_models = {}
+loaded_models_last_updated = {}
+
+sentence_embeddings_mapping = {
+    'BAAI/bge-base-en-v1.5': generic_sentence_embeddings('BAAI/bge-base-en-v1.5'),
+    'BAAI/bge-large-en-v1.5': generic_sentence_embeddings('BAAI/bge-large-en-v1.5'),
+}
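The route can be exercised in-process with FastAPI's TestClient; with TEST_MODE=True the handler short-circuits to stub vectors, so no model weights are downloaded. A minimal sketch, assuming the environment variable is set before config is imported:

import os
os.environ["TEST_MODE"] = "True"  # must be set before importing app/config

from fastapi.testclient import TestClient
from app import app

client = TestClient(app)
resp = client.post("/sentence-embeddings", json={
    "model": "BAAI/bge-base-en-v1.5",
    "inputs": ["hello world"],
    "parameters": {},
})
print(resp.json())  # {'embeddings': [[0.1, 0.2]], 'error': None}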
tasks/translation.py
ADDED
@@ -0,0 +1,135 @@
+from fastapi import APIRouter
+from pydantic import BaseModel
+from typing import Optional
+from config import TEST_MODE, device, dtype, log
+from fairseq2.data.text.text_tokenizer import TextTokenEncoder
+from seamless_communication.inference import Translator
+import spacy
+import re
+from datetime import datetime
+
+router = APIRouter()
+
+class TranslateInput(BaseModel):
+    inputs: list[str]
+    model: str
+    src_lang: str
+    dst_lang: str
+
+
+class TranslateOutput(BaseModel):
+    src_lang: str
+    dst_lang: str
+    translations: Optional[list[str]] = None
+    error: Optional[str] = None
+
+
+@router.post('/t2tt')
+def t2tt(inputs: TranslateInput) -> TranslateOutput:
+    start_time = datetime.now()
+    fn = t2tt_mapping.get(inputs.model)
+    if not fn:
+        return TranslateOutput(
+            src_lang=inputs.src_lang,
+            dst_lang=inputs.dst_lang,
+            error=f'No translation model found for {inputs.model}'
+        )
+
+    try:
+        # handlers take (inputs, src_lang, dst_lang), so drop the model field
+        translations = fn(**inputs.dict(exclude={'model'}))
+        log({
+            "task": "translation",
+            "model": inputs.model,
+            "start_time": start_time.isoformat(),
+            "time_taken": (datetime.now() - start_time).total_seconds(),
+            "inputs": inputs.inputs,
+            "outputs": translations,
+            "parameters": {
+                "src_lang": inputs.src_lang,
+                "dst_lang": inputs.dst_lang,
+            },
+        })
+        loaded_models_last_updated[inputs.model] = datetime.now()
+        return TranslateOutput(**translations)
+    except Exception as e:
+        return TranslateOutput(
+            src_lang=inputs.src_lang,
+            dst_lang=inputs.dst_lang,
+            error=str(e)
+        )
+
+cmn_nlp = spacy.load("zh_core_web_sm")
+xx_nlp = spacy.load("xx_sent_ud_sm")
+unk_re = re.compile(r"\s?<unk>|\s?⁇")
+
+def seamless_t2tt(inputs: list[str], src_lang: str, dst_lang: str = 'eng'):
+    if TEST_MODE:
+        return {
+            "src_lang": src_lang,
+            "dst_lang": dst_lang,
+            "translations": None,
+            "error": None
+        }
+
+    # Load model
+    if 'facebook/seamless-m4t-v2-large' in loaded_models:
+        translator = loaded_models['facebook/seamless-m4t-v2-large']
+    else:
+        translator = Translator(
+            model_name_or_card="seamlessM4T_v2_large",
+            vocoder_name_or_card="vocoder_v2",
+            device=device,
+            dtype=dtype,
+            apply_mintox=False,
+        )
+        loaded_models['facebook/seamless-m4t-v2-large'] = translator
+
+
+    def sent_tokenize(text, lang) -> list[str]:
+        if lang == 'cmn':
+            return [str(t) for t in cmn_nlp(text).sents]
+        return [str(t) for t in xx_nlp(text).sents]
+
+
+    def tokenize_and_translate(token_encoder: TextTokenEncoder, text: str, src_lang: str, dst_lang: str) -> str:
+        # Split text into paragraphs, then sentences; newlines within a paragraph become spaces
+        lines = [sent_tokenize(line.replace("\n", " "), src_lang) for line in text.split('\n\n') if line]
+        lines = [item for sublist in lines for item in sublist if item]
+
+        # Tokenize and translate
+        input_tokens = translator.collate([token_encoder(line) for line in lines])
+        translations = [
+            unk_re.sub("", str(t))
+            for t in translator.predict(
+                input=input_tokens,
+                task_str="T2TT",
+                src_lang=src_lang,
+                tgt_lang=dst_lang,
+            )[0]
+        ]
+        return " ".join(translations)
+
+    translations = None
+    token_encoder = translator.text_tokenizer.create_encoder(
+        task="translation", lang=src_lang, mode="source", device=translator.device
+    )
+    try:
+        translations = [tokenize_and_translate(token_encoder, text, src_lang, dst_lang) for text in inputs]
+    except Exception as e:
+        print(f"Error translating text: {e}")
+
+    return {
+        "src_lang": src_lang,
+        "dst_lang": dst_lang,
+        "translations": translations,
+        "error": None if translations else "Failed to translate text"
+    }
+
+
+# Model cache; last-updated timestamps support unloading idle models later
+loaded_models = {}
+loaded_models_last_updated = {}
+
+t2tt_mapping = {
+    'facebook/seamless-m4t-v2-large': seamless_t2tt,
+}
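Once the router is mounted, the endpoint takes the same envelope as the other tasks plus source and target language codes. A usage sketch, assuming a running Space on localhost:7860 and the SeamlessM4T language codes used in the module ('cmn', 'eng'):

import requests

payload = {
    "model": "facebook/seamless-m4t-v2-large",
    "inputs": ["你好，世界。\n\n这是第二段。"],
    "src_lang": "cmn",
    "dst_lang": "eng",
}
resp = requests.post("http://localhost:7860/t2tt", json=payload)
print(resp.json())  # {'src_lang': 'cmn', 'dst_lang': 'eng', 'translations': [...], 'error': None}

Note that the module loads spaCy's zh_core_web_sm and xx_sent_ud_sm pipelines at import time, so both need to be installed in the image.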