Duplicate from Mathux/TMR
Co-authored-by: Mathis Petrovich <[email protected]>
- .gitattributes +34 -0
- README.md +13 -0
- amass-annotations/amass_to_babel.json +0 -0
- amass-annotations/humanml3d.json +0 -0
- app.py +313 -0
- load.py +53 -0
- model.py +128 -0
- requirements.txt +5 -0
.gitattributes
ADDED
@@ -0,0 +1,34 @@
*.7z filter=lfs diff=lfs merge=lfs -text
*.arrow filter=lfs diff=lfs merge=lfs -text
*.bin filter=lfs diff=lfs merge=lfs -text
*.bz2 filter=lfs diff=lfs merge=lfs -text
*.ckpt filter=lfs diff=lfs merge=lfs -text
*.ftz filter=lfs diff=lfs merge=lfs -text
*.gz filter=lfs diff=lfs merge=lfs -text
*.h5 filter=lfs diff=lfs merge=lfs -text
*.joblib filter=lfs diff=lfs merge=lfs -text
*.lfs.* filter=lfs diff=lfs merge=lfs -text
*.mlmodel filter=lfs diff=lfs merge=lfs -text
*.model filter=lfs diff=lfs merge=lfs -text
*.msgpack filter=lfs diff=lfs merge=lfs -text
*.npy filter=lfs diff=lfs merge=lfs -text
*.npz filter=lfs diff=lfs merge=lfs -text
*.onnx filter=lfs diff=lfs merge=lfs -text
*.ot filter=lfs diff=lfs merge=lfs -text
*.parquet filter=lfs diff=lfs merge=lfs -text
*.pb filter=lfs diff=lfs merge=lfs -text
*.pickle filter=lfs diff=lfs merge=lfs -text
*.pkl filter=lfs diff=lfs merge=lfs -text
*.pt filter=lfs diff=lfs merge=lfs -text
*.pth filter=lfs diff=lfs merge=lfs -text
*.rar filter=lfs diff=lfs merge=lfs -text
*.safetensors filter=lfs diff=lfs merge=lfs -text
saved_model/**/* filter=lfs diff=lfs merge=lfs -text
*.tar.* filter=lfs diff=lfs merge=lfs -text
*.tflite filter=lfs diff=lfs merge=lfs -text
*.tgz filter=lfs diff=lfs merge=lfs -text
*.wasm filter=lfs diff=lfs merge=lfs -text
*.xz filter=lfs diff=lfs merge=lfs -text
*.zip filter=lfs diff=lfs merge=lfs -text
*.zst filter=lfs diff=lfs merge=lfs -text
*tfevents* filter=lfs diff=lfs merge=lfs -text
README.md
ADDED
@@ -0,0 +1,13 @@
---
title: TMR
emoji: 🐨
colorFrom: yellow
colorTo: purple
sdk: gradio
sdk_version: 3.24.1
app_file: app.py
pinned: false
duplicated_from: Mathux/TMR
---

Check out the configuration reference at https://huggingface.co/docs/hub/spaces-config-reference
amass-annotations/amass_to_babel.json
ADDED
The diff for this file is too large to render.
amass-annotations/humanml3d.json
ADDED
The diff for this file is too large to render.
app.py
ADDED
@@ -0,0 +1,313 @@
from functools import partial
import os

import torch
import numpy as np
import gradio as gr
import gdown

from load import load_model, load_json
from load import load_unit_motion_embs_splits, load_keyids_splits


WEBSITE = """
<div class="embed_hidden">
<h1 style='text-align: center'>TMR: Text-to-Motion Retrieval Using Contrastive 3D Human Motion Synthesis </h1>

<h2 style='text-align: center'>
<a href="https://mathis.petrovich.fr" target="_blank"><nobr>Mathis Petrovich</nobr></a>
<a href="https://ps.is.mpg.de/~black" target="_blank"><nobr>Michael J. Black</nobr></a>
<a href="https://imagine.enpc.fr/~varolg" target="_blank"><nobr>Gül Varol</nobr></a>
</h2>

<h2 style='text-align: center'>
<nobr>arXiv 2023</nobr>
</h2>

<h3 style="text-align:center;">
<a target="_blank" href="https://arxiv.org/abs/2305.00976"> <button type="button" class="btn btn-primary btn-lg"> Paper </button></a>
<a target="_blank" href="https://github.com/Mathux/TMR"> <button type="button" class="btn btn-primary btn-lg"> Code </button></a>
<a target="_blank" href="https://mathis.petrovich.fr/tmr"> <button type="button" class="btn btn-primary btn-lg"> Webpage </button></a>
<a target="_blank" href="https://mathis.petrovich.fr/tmr/tmr.bib"> <button type="button" class="btn btn-primary btn-lg"> BibTex </button></a>
</h3>

<h3> Description </h3>
<p>
This space illustrates <a href='https://mathis.petrovich.fr/tmr/' target='_blank'><b>TMR</b></a>, a method for text-to-motion retrieval. Given a gallery of 3D human motions (which can be unseen during training) and a text query, the goal is to search for motions which are close to the text query.
</p>
</div>
"""

EXAMPLES = [
    "A person is walking slowly",
    "A person is walking in a circle",
    "A person is jumping rope",
    "Someone is doing a backflip",
    "A person is doing a moonwalk",
    "A person walks forward and then turns back",
    "Picking up an object",
    "A person is swimming in the sea",
    "A human is squatting",
    "Someone is jumping with one foot",
    "A person is chopping vegetables",
    "Someone walks backward",
    "Somebody is ascending a staircase",
    "A person is sitting down",
    "A person is taking the stairs",
    "Someone is doing jumping jacks",
    "The person walked forward and is picking up his toolbox",
    "The person angrily punching the air"
]

# Show closest text in the training


# css to make videos look nice
# var(--block-border-color);
CSS = """
.retrieved_video {
    position: relative;
    margin: 0;
    box-shadow: var(--block-shadow);
    border-width: var(--block-border-width);
    border-color: #000000;
    border-radius: var(--block-radius);
    background: var(--block-background-fill);
    width: 100%;
    line-height: var(--line-sm);
}

.contour_video {
    display: flex;
    flex-direction: column;
    justify-content: center;
    align-items: center;
    z-index: var(--layer-5);
    border-radius: var(--block-radius);
    background: var(--background-fill-primary);
    padding: 0 var(--size-6);
    max-height: var(--size-screen-h);
    overflow: hidden;
}
"""


DEFAULT_TEXT = "A person is "


def humanml3d_keyid_to_babel_rendered_url(h3d_index, amass_to_babel, keyid):
    # Don't show the mirrored version of HumanML3D
    if "M" in keyid:
        return None

    dico = h3d_index[keyid]
    path = dico["path"]

    # HumanAct12 motions are not rendered online
    # so we skip them for now
    if "humanact12" in path:
        return None

    # This motion is not rendered in BABEL
    # so we skip them for now
    if path not in amass_to_babel:
        return None

    babel_id = amass_to_babel[path].zfill(6)
    url = f"https://babel-renders.s3.eu-central-1.amazonaws.com/{babel_id}.mp4"

    # For the demo, we retrieve from the first annotation only
    ann = dico["annotations"][0]
    start = ann["start"]
    end = ann["end"]
    text = ann["text"]

    data = {
        "url": url,
        "start": start,
        "end": end,
        "text": text,
        "keyid": keyid,
        "babel_id": babel_id,
        "path": path
    }

    return data


def retrieve(model, keyid_to_url, all_unit_motion_embs, all_keyids, text, splits=["test"], nmax=8):
    unit_motion_embs = torch.cat([all_unit_motion_embs[s] for s in splits])
    keyids = np.concatenate([all_keyids[s] for s in splits])

    scores = model.compute_scores(text, unit_embs=unit_motion_embs)

    sorted_idxs = np.argsort(-scores)
    best_keyids = keyids[sorted_idxs]
    best_scores = scores[sorted_idxs]

    datas = []
    for keyid, score in zip(best_keyids, best_scores):
        if len(datas) == nmax:
            break

        data = keyid_to_url(keyid)
        if data is None:
            continue
        data["score"] = round(float(score), 2)
        datas.append(data)
    return datas


# HTML component
def get_video_html(data, video_id, width=700, height=700):
    url = data["url"]
    start = data["start"]
    end = data["end"]
    score = data["score"]
    text = data["text"]
    keyid = data["keyid"]
    babel_id = data["babel_id"]
    path = data["path"]

    trim = f"#t={start},{end}"
    title = f'''Score = {score}

Corresponding text: {text}

HumanML3D keyid: {keyid}

BABEL keyid: {babel_id}

AMASS path: {path}'''

    # class="wrap default svelte-gjihhp hide"
    # <div class="contour_video" style="position: absolute; padding: 10px;">
    # width="{width}" height="{height}"
    video_html = f'''
<video class="retrieved_video" width="{width}" height="{height}" preload="auto" muted playsinline onpause="this.load()"
autoplay loop disablepictureinpicture id="{video_id}" title="{title}">
  <source src="{url}{trim}" type="video/mp4">
  Your browser does not support the video tag.
</video>
'''
    return video_html


def retrieve_component(retrieve_function, text, splits_choice, nvids, n_component=24):
    if text == DEFAULT_TEXT or text == "" or text is None:
        return [None for _ in range(n_component)]

    # cannot produce more than n_component
    nvids = min(nvids, n_component)

    if "Unseen" in splits_choice:
        splits = ["test"]
    else:
        splits = ["train", "val", "test"]

    datas = retrieve_function(text, splits=splits, nmax=nvids)
    htmls = [get_video_html(data, idx) for idx, data in enumerate(datas)]
    # get n_component exactly if asked less
    # pad with dummy blocks
    htmls = htmls + [None for _ in range(max(0, n_component - nvids))]
    return htmls


if not os.path.exists("data"):
    gdown.download_folder("https://drive.google.com/drive/folders/1MgPFgHZ28AMd01M1tJ7YW_1-ut3-4j08",
                          use_cookies=False)


device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')

# LOADING
model = load_model(device)
splits = ["train", "val", "test"]
all_unit_motion_embs = load_unit_motion_embs_splits(splits, device)
all_keyids = load_keyids_splits(splits)

h3d_index = load_json("amass-annotations/humanml3d.json")
amass_to_babel = load_json("amass-annotations/amass_to_babel.json")

keyid_to_url = partial(humanml3d_keyid_to_babel_rendered_url, h3d_index, amass_to_babel)
retrieve_function = partial(retrieve, model, keyid_to_url, all_unit_motion_embs, all_keyids)

# DEMO
theme = gr.themes.Default(primary_hue="blue", secondary_hue="gray")
retrieve_and_show = partial(retrieve_component, retrieve_function)

with gr.Blocks(css=CSS, theme=theme) as demo:
    gr.Markdown(WEBSITE)
    videos = []

    with gr.Row():
        with gr.Column(scale=3):
            with gr.Column(scale=2):
                text = gr.Textbox(placeholder="Type the motion you want to search with a sentence",
                                  show_label=True, label="Text prompt", value=DEFAULT_TEXT)
            with gr.Column(scale=1):
                btn = gr.Button("Retrieve", variant='primary')
                clear = gr.Button("Clear", variant='secondary')

            with gr.Row():
                with gr.Column(scale=1):
                    splits_choice = gr.Radio(["All motions", "Unseen motions"], label="Gallery of motion",
                                             value="All motions",
                                             info="The motion gallery is coming from HumanML3D")

                with gr.Column(scale=1):
                    # nvideo_slider = gr.Slider(minimum=4, maximum=24, step=4, value=8, label="Number of videos")
                    nvideo_slider = gr.Radio([4, 8, 12, 16, 24], label="Videos",
                                             value=8,
                                             info="Number of videos to display")

        with gr.Column(scale=2):
            def retrieve_example(text, splits_choice, nvideo_slider):
                return retrieve_and_show(text, splits_choice, nvideo_slider)

            examples = gr.Examples(examples=[[x, None, None] for x in EXAMPLES],
                                   inputs=[text, splits_choice, nvideo_slider],
                                   examples_per_page=20,
                                   run_on_click=False, cache_examples=False,
                                   fn=retrieve_example, outputs=[])

    i = -1
    # should indent
    for _ in range(6):
        with gr.Row():
            for _ in range(4):
                i += 1
                video = gr.HTML()
                videos.append(video)

    # connect the examples to the output
    # a bit hacky
    examples.outputs = videos

    def load_example(example_id):
        processed_example = examples.non_none_processed_examples[example_id]
        return gr.utils.resolve_singleton(processed_example)

    examples.dataset.click(
        load_example,
        inputs=[examples.dataset],
        outputs=examples.inputs_with_examples,  # type: ignore
        show_progress=False,
        postprocess=False,
        queue=False,
    ).then(
        fn=retrieve_example,
        inputs=examples.inputs,
        outputs=videos
    )

    btn.click(fn=retrieve_and_show, inputs=[text, splits_choice, nvideo_slider], outputs=videos)
    text.submit(fn=retrieve_and_show, inputs=[text, splits_choice, nvideo_slider], outputs=videos)
    splits_choice.change(fn=retrieve_and_show, inputs=[text, splits_choice, nvideo_slider], outputs=videos)
    nvideo_slider.change(fn=retrieve_and_show, inputs=[text, splits_choice, nvideo_slider], outputs=videos)

    def clear_videos():
        return [None for x in range(24)] + [DEFAULT_TEXT]

    clear.click(fn=clear_videos, outputs=videos + [text])

demo.launch()
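A note on the `#t={start},{end}` suffix built in get_video_html above: it is a standard media-fragment URI, so the browser plays only the annotated segment of each BABEL render and no server-side trimming is needed. The snippet below is a self-contained illustration of that markup with placeholder values; it is not part of the Space.

# illustration only: placeholder URL and timestamps, not real retrieval output
start, end = 1.0, 4.5
url = "https://example.com/render.mp4"
html = f'<video muted autoplay loop><source src="{url}#t={start},{end}" type="video/mp4"></video>'
print(html)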
load.py
ADDED
@@ -0,0 +1,53 @@
import os
import orjson
import torch
import numpy as np
from model import TMR_textencoder

EMBS = "data/unit_motion_embs"


def load_json(path):
    with open(path, "rb") as ff:
        return orjson.loads(ff.read())


def load_keyids(split):
    path = os.path.join(EMBS, f"{split}.keyids")
    with open(path) as ff:
        keyids = np.array([x.strip() for x in ff.readlines()])
    return keyids


def load_keyids_splits(splits):
    return {
        split: load_keyids(split)
        for split in splits
    }


def load_unit_motion_embs(split, device):
    path = os.path.join(EMBS, f"{split}_motion_embs_unit.npy")
    tensor = torch.from_numpy(np.load(path)).to(device)
    return tensor


def load_unit_motion_embs_splits(splits, device):
    return {
        split: load_unit_motion_embs(split, device)
        for split in splits
    }


def load_model(device):
    text_params = {
        'latent_dim': 256, 'ff_size': 1024, 'num_layers': 6, 'num_heads': 4,
        'activation': 'gelu', 'modelpath': 'distilbert-base-uncased'
    }
    "unit_motion_embs"
    model = TMR_textencoder(**text_params)
    state_dict = torch.load("data/textencoder.pt", map_location=device)
    # load values for the transformer only
    model.load_state_dict(state_dict, strict=False)
    model = model.eval()
    return model
model.py
ADDED
@@ -0,0 +1,128 @@
from typing import List
import torch.nn as nn
import os

import torch
import numpy as np
from torch import Tensor
from transformers import AutoTokenizer, AutoModel
from transformers import logging
from torch.nn.functional import normalize


class PositionalEncoding(nn.Module):
    def __init__(self, d_model, max_len=5000):
        super().__init__()

        pe = torch.zeros(max_len, d_model)
        position = torch.arange(0, max_len, dtype=torch.float).unsqueeze(1)
        div_term = torch.exp(torch.arange(0, d_model, 2).float() * (-np.log(10000.0) / d_model))
        pe[:, 0::2] = torch.sin(position * div_term)
        pe[:, 1::2] = torch.cos(position * div_term)
        pe = pe.unsqueeze(0).transpose(0, 1)

        self.register_buffer('pe', pe, persistent=False)

    def forward(self, x):
        return x + self.pe[:x.shape[0], :]


class TMR_textencoder(nn.Module):
    def __init__(self, modelpath: str, latent_dim: int, ff_size: int,
                 num_layers: int, num_heads: int, activation: str, **kwargs) -> None:
        super().__init__()

        logging.set_verbosity_error()

        # Tokenizer
        os.environ["TOKENIZERS_PARALLELISM"] = "false"
        self.tokenizer = AutoTokenizer.from_pretrained(modelpath)

        # Text model
        self.text_model = AutoModel.from_pretrained(modelpath)
        # Then configure the model
        self.text_encoded_dim = self.text_model.config.hidden_size

        # Projection of the text-outputs into the latent space
        self.projection = nn.Sequential(
            nn.ReLU(),
            nn.Linear(self.text_encoded_dim, latent_dim)
        )

        self.mu_token = nn.Parameter(torch.randn(latent_dim))
        self.logvar_token = nn.Parameter(torch.randn(latent_dim))
        self.sequence_pos_encoding = PositionalEncoding(latent_dim)

        seq_trans_encoder_layer = nn.TransformerEncoderLayer(d_model=latent_dim,
                                                             nhead=num_heads,
                                                             dim_feedforward=ff_size,
                                                             dropout=0.0,
                                                             activation=activation)
        self.seqTransEncoder = nn.TransformerEncoder(
            seq_trans_encoder_layer,
            num_layers=num_layers
        )

    def get_last_hidden_state(self, texts: List[str],
                              return_mask: bool = False):
        encoded_inputs = self.tokenizer(texts, return_tensors="pt", padding=True)
        output = self.text_model(**encoded_inputs.to(self.text_model.device))
        if not return_mask:
            return output.last_hidden_state
        return output.last_hidden_state, encoded_inputs.attention_mask.to(dtype=bool)

    def forward(self, texts: List[str]) -> Tensor:
        text_encoded, mask = self.get_last_hidden_state(texts, return_mask=True)

        x = self.projection(text_encoded)
        bs, nframes, _ = x.shape
        # bs, nframes, totjoints, nfeats = x.shape
        # Switch sequence and batch_size because the input of
        # Pytorch Transformer is [Sequence, Batch size, ...]
        x = x.permute(1, 0, 2)  # now it is [nframes, bs, latent_dim]

        mu_token = torch.tile(self.mu_token, (bs,)).reshape(bs, -1)
        logvar_token = torch.tile(self.logvar_token, (bs,)).reshape(bs, -1)

        # adding the distribution tokens for all sequences
        xseq = torch.cat((mu_token[None], logvar_token[None], x), 0)

        # create a bigger mask, to allow attending to mu and logvar
        token_mask = torch.ones((bs, 2), dtype=bool, device=x.device)
        aug_mask = torch.cat((token_mask, mask), 1)

        # add positional encoding
        xseq = self.sequence_pos_encoding(xseq)
        final = self.seqTransEncoder(xseq, src_key_padding_mask=~aug_mask)

        # only mu for inference
        mu = final[0]
        return mu

    # compute score for retrieval
    def compute_scores(self, texts, unit_embs=None, embs=None):
        # not both empty
        assert not (unit_embs is None and embs is None)
        # not both filled
        assert not (unit_embs is not None and embs is not None)

        output_str = False
        # if one input, squeeze the output
        if isinstance(texts, str):
            texts = [texts]
            output_str = True

        # compute unit_embs from embs if not given
        if embs is not None:
            unit_embs = normalize(embs)

        with torch.no_grad():
            latent_unit_texts = normalize(self(texts))
            # compute cosine similarity between 0 and 1
            scores = (unit_embs @ latent_unit_texts.T).T / 2 + 0.5
        scores = scores.cpu().numpy()

        if output_str:
            scores = scores[0]

        return scores
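To make the scoring convention concrete: compute_scores rescales cosine similarity from [-1, 1] to [0, 1]. The sketch below is an illustration only, with randomly initialized projection and transformer weights rather than the pretrained data/textencoder.pt; it scores one query against random unit-norm "motion" embeddings simply to show the expected shapes and value range.

import torch
from torch.nn.functional import normalize
from model import TMR_textencoder

# same hyper-parameters as in load.py; weights here are untrained, so the scores are meaningless
model = TMR_textencoder(modelpath="distilbert-base-uncased", latent_dim=256, ff_size=1024,
                        num_layers=6, num_heads=4, activation="gelu").eval()

unit_motion_embs = normalize(torch.randn(10, 256))  # 10 fake motions on the unit sphere
scores = model.compute_scores("A person is walking", unit_embs=unit_motion_embs)
print(scores.shape)                 # (10,): a single string query squeezes the batch dim
print(scores.min(), scores.max())   # all values lie in [0, 1]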
requirements.txt
ADDED
@@ -0,0 +1,5 @@
torch
orjson
numpy
gdown
transformers