Spaces:
Running
on
Zero
Running
on
Zero
| import torch | |
| import numpy as np | |
def run_func_in_parts(func, vid_emb, aud_emb, part_len, dim, device):
    """Apply ``func`` to matching slices of the two inputs and join the results.

    ``vid_emb`` and ``aud_emb`` are each split into pieces of at most
    ``part_len`` entries along dimension ``dim``. Every pair of pieces is
    moved to ``device`` and handed to ``func``; the per-piece outputs are
    then concatenated along dimension ``dim - 1``.

    Working piece by piece keeps peak memory down when the full inputs are
    too large to process on the GPU in one call.
    """
    vid_parts = vid_emb.split(part_len, dim=dim)
    aud_parts = aud_emb.split(part_len, dim=dim)
    # NOTE(review): results are joined on dim - 1 rather than dim, which
    # presumably reflects func removing one axis ahead of dim -- confirm
    # against the callers.
    results = [
        func(vid_part.to(device), aud_part.to(device))
        for vid_part, aud_part in zip(vid_parts, aud_parts)
    ]
    return torch.cat(results, dim - 1)
def logsoftmax_2d(logits):
    """Return the log-softmax of ``logits`` taken jointly over its last two dims.

    The result has the same shape as ``logits``.
    """
    # log_softmax normalises over a single dim only, so merge the last two
    # axes into one, normalise over it, then restore the original shape.
    full_shape = logits.shape
    merged = logits.reshape(*full_shape[:-2], -1)
    normalised = torch.nn.functional.log_softmax(merged, dim=-1)
    return normalised.reshape(full_shape)