|
from typing import List, Dict, Literal, Union, Tuple |
|
import os |
|
import string |
|
import logging |
|
|
|
import torch |
|
import numpy as np |
|
from einops import rearrange, repeat |
|
|
|
# Module-level logger; sample_by_idx reports sample_rate reductions through it.
logger: logging.Logger = logging.getLogger(__name__)
|
|
|
|
|
def generate_tasks_of_dir(
    path: str,
    output_dir: str,
    exts: Tuple[str],
    same_dir_name: bool = False,
    **kwargs,
) -> List[Dict]:
    """Convert every matching video file under a directory tree into a task dict.

    Walks ``path`` recursively with ``os.walk`` and creates one task for each file
    whose lower-cased basename ends with one of ``exts``.

    Args:
        path (str): root directory to search.
        output_dir (str): directory under which output paths are built.
        exts (Tuple[str]): accepted suffixes, matched against the lower-cased
            basename (so they should be lower-case, e.g. ``(".mp4",)``).
        same_dir_name (bool, optional): if True, outputs go to
            ``output_dir/<name of the file's parent directory>``; otherwise
            directly to ``output_dir``. Defaults to False.
        **kwargs: extra key/value pairs added to every task; they override the
            built-in keys on collision.

    Returns:
        List[Dict]: one dict per file with keys ``video_path``, ``output_path``
        (``<output_dir>/<filename>.h5py``), ``output_dir``, ``filename`` (basename
        without its last extension), ``ext`` (last extension, without the dot),
        plus ``kwargs``.
    """
    tasks = []
    for rootdir, _dirs, files in os.walk(path):
        for basename in files:
            if not basename.lower().endswith(exts):
                continue
            video_path = os.path.join(rootdir, basename)
            # Split on the LAST dot only: "clip.v2.mp4" -> ("clip.v2", "mp4").
            # A plain split(".") fails to unpack for names with several dots.
            stem, dot, suffix = basename.rpartition(".")
            filename, ext = (stem, suffix) if dot else (basename, "")
            if same_dir_name:
                save_dir = os.path.join(output_dir, os.path.basename(rootdir))
            else:
                save_dir = output_dir
            save_path = os.path.join(save_dir, f"{filename}.h5py")
            task = {
                "video_path": video_path,
                "output_path": save_path,
                "output_dir": save_dir,
                "filename": filename,
                "ext": ext,
            }
            task.update(kwargs)
            tasks.append(task)
    return tasks
|
|
|
|
|
def sample_by_idx(
    T: int,
    n_sample: int,
    sample_rate: int,
    sample_start_idx: int = None,
    change_sample_rate: bool = False,
    seed: int = None,
    whether_random: bool = True,
    n_independent: int = 0,
) -> Tuple[List[int], int, Union[np.ndarray, None]]:
    """Sample ``n_sample`` evenly spaced indices from the candidates ``0 .. T-1``.

    The sampled window is ``[start, start + n_sample * sample_rate)`` with one
    index taken every ``sample_rate`` positions.

    Args:
        T (int): number of candidates (indices ``0 .. T-1``).
        n_sample (int): number of indices to sample.
        sample_rate (int): stride between sampled indices; must be >= 1.
        sample_start_idx (int, optional): first sampled index. When None it is
            chosen randomly (``whether_random=True``) or set to 0. Defaults to None.
        change_sample_rate (bool, optional): if True, lower ``sample_rate`` until
            ``T / sample_rate >= n_sample`` instead of raising. Defaults to False.
        seed (int, optional): if given and the start is chosen randomly, seeds
            NumPy's global RNG before choosing. Defaults to None.
        whether_random (bool, optional): choose the start randomly when
            ``sample_start_idx`` is None. Defaults to True.
        n_independent (int, optional): number of extra indices drawn (with
            replacement) from outside the sampled window; if the outside has
            fewer than ``n_independent`` positions, they are drawn from all
            non-sampled indices instead. Defaults to 0.

    Raises:
        ValueError: if ``T < n_sample``, ``sample_rate < 1``, or
            ``T / sample_rate < n_sample`` while ``change_sample_rate`` is False.

    Returns:
        Tuple[List[int], int, Union[np.ndarray, None]]: the sampled indices, the
        (possibly lowered) sample_rate, and the independent indices (None when
        ``n_independent == 0``).
    """
    if T < n_sample:
        raise ValueError(f"T({T}) < n_sample({n_sample})")
    if sample_rate < 1:
        # A zero/negative stride previously caused ZeroDivisionError or an
        # endless decrement loop below.
        raise ValueError(f"sample_rate({sample_rate}) must be >= 1")
    if T / sample_rate < n_sample:
        if not change_sample_rate:
            raise ValueError(
                f"T({T}) / sample_rate({sample_rate}) < n_sample({n_sample})"
            )
        # Terminates with sample_rate >= 1 because T >= n_sample was checked above.
        while T / sample_rate < n_sample:
            sample_rate -= 1
            logger.error(
                f"sample_rate{sample_rate+1} is too large, decrease to {sample_rate}"
            )

    if sample_start_idx is None:
        if whether_random:
            # The window [start, start + n_sample * sample_rate) must fit inside
            # range(T), so valid starts are 0 .. T - n_sample * sample_rate
            # INCLUSIVE. The +1 also keeps the candidate set non-empty when the
            # window exactly fills T (np.random.choice rejects an empty array).
            sample_start_idx_candidates = np.arange(T - n_sample * sample_rate + 1)
            if seed is not None:
                np.random.seed(seed)
            sample_start_idx = np.random.choice(sample_start_idx_candidates, 1)[0]
        else:
            sample_start_idx = 0
    sample_end_idx = sample_start_idx + sample_rate * n_sample
    sample = list(range(sample_start_idx, sample_end_idx, sample_rate))

    if n_independent == 0:
        n_independent_sample = None
    else:
        left_candidate = np.array(
            list(range(0, sample_start_idx)) + list(range(sample_end_idx, T))
        )
        if len(left_candidate) < n_independent:
            # Not enough room outside the window: use every index that was not
            # sampled (sorted for a deterministic candidate order).
            left_candidate = np.array(sorted(set(range(T)) - set(sample)))
        n_independent_sample = np.random.choice(left_candidate, n_independent)

    return sample, sample_rate, n_independent_sample
|
|
|
|
|
def sample_tensor_by_idx(
    tensor: Union[torch.Tensor, np.ndarray],
    n_sample: int,
    sample_rate: int,
    sample_start_idx: int = 0,
    change_sample_rate: bool = False,
    seed: int = None,
    dim: int = 0,
    return_type: Literal["numpy", "torch"] = "torch",
    whether_random: bool = True,
    n_independent: int = 0,
) -> Tuple[torch.Tensor, torch.Tensor, int, torch.Tensor, torch.Tensor]:
    """Sample ``n_sample`` evenly spaced slices of ``tensor`` along ``dim``.

    Index selection is delegated to ``sample_by_idx`` over ``tensor.shape[dim]``.

    Args:
        tensor (Union[torch.Tensor, np.ndarray]): input; NumPy arrays are wrapped
            with ``torch.from_numpy``.
        n_sample (int): number of slices to sample.
        sample_rate (int): stride between sampled positions.
        sample_start_idx (int, optional): first sampled position; pass None to let
            ``sample_by_idx`` choose it. Defaults to 0.
        change_sample_rate (bool, optional): allow ``sample_by_idx`` to lower
            ``sample_rate`` when it is too large. Defaults to False.
        seed (int, optional): forwarded to ``sample_by_idx``. Defaults to None.
        dim (int, optional): axis to sample along. Defaults to 0.
        return_type (Literal["numpy", "torch"], optional): "numpy" converts the
            sampled data (``sample`` and ``independent_sample``) to NumPy arrays.
            Defaults to "torch".
        whether_random (bool, optional): forwarded to ``sample_by_idx``; only has
            an effect when ``sample_start_idx`` is None. Defaults to True.
        n_independent (int, optional): number of extra positions sampled
            independently of the main window. Defaults to 0.

    Returns:
        Tuple: ``(sample, sample_idx, sample_rate, independent_sample,
        independent_sample_idx)``. Index tensors are ``torch.LongTensor``; the two
        independent entries are None when ``n_independent`` is 0.
    """
    if isinstance(tensor, np.ndarray):
        tensor = torch.from_numpy(tensor)
    T = tensor.shape[dim]
    sample_idx, sample_rate, independent_sample_idx = sample_by_idx(
        T,
        n_sample,
        sample_rate,
        sample_start_idx,
        change_sample_rate,
        seed,
        whether_random=whether_random,
        n_independent=n_independent,
    )
    sample_idx = torch.LongTensor(sample_idx)
    sample = torch.index_select(tensor, dim, sample_idx)
    if independent_sample_idx is not None:
        independent_sample_idx = torch.LongTensor(independent_sample_idx)
        independent_sample = torch.index_select(tensor, dim, independent_sample_idx)
    else:
        independent_sample = None
    if return_type == "numpy":
        sample = sample.cpu().numpy()
        # Keep both sampled outputs in the requested container type.
        if independent_sample is not None:
            independent_sample = independent_sample.cpu().numpy()
    return sample, sample_idx, sample_rate, independent_sample, independent_sample_idx
|
|
|
|
|
def concat_two_tensor(
    data1: torch.Tensor,
    data2: torch.Tensor,
    dim: int,
    method: Literal[
        "first_in_first_out", "first_in_last_out", "intertwine", "index"
    ] = "first_in_first_out",
    data1_index: torch.long = None,
    data2_index: torch.long = None,
    return_index: bool = False,
):
    """Concatenate two tensors along ``dim`` using the requested ordering.

    Args:
        data1 (torch.Tensor): the "first in" tensor.
        data2 (torch.Tensor): the "last in" tensor.
        dim (int): axis to concatenate along.
        method (str, optional): "first_in_first_out" puts data1 before data2;
            "first_in_last_out" puts data2 before data1; "index" scatters both
            to ``data1_index`` / ``data2_index`` via
            ``concat_two_tensor_with_index``; "intertwine" is not implemented.
            Defaults to "first_in_first_out".
        data1_index: target positions of data1 (only used by "index").
        data2_index: target positions of data2 (only used by "index").
        return_index (bool, optional): also return the positions of both
            inputs in the result. Defaults to False.

    Raises:
        NotImplementedError: for method "intertwine".
        ValueError: for an unknown method.

    Returns:
        The concatenated tensor, or ``(tensor, data1_index, data2_index)``.
        For the two FIFO/FILO methods, the index of whichever input comes first
        is a ``range`` and the other is a ``list``.
    """
    n1 = data1.shape[dim]
    n2 = data2.shape[dim]

    if method == "intertwine":
        raise NotImplementedError("intertwine")

    if method == "first_in_first_out":
        res = torch.concat([data1, data2], dim=dim)
        data1_index = range(n1)
        data2_index = list(range(n1, n1 + n2))
    elif method == "first_in_last_out":
        res = torch.concat([data2, data1], dim=dim)
        data2_index = range(n2)
        data1_index = list(range(n2, n2 + n1))
    elif method == "index":
        res = concat_two_tensor_with_index(
            data1=data1,
            data1_index=data1_index,
            data2=data2,
            data2_index=data2_index,
            dim=dim,
        )
    else:
        raise ValueError(
            "only support first_in_first_out, first_in_last_out, intertwine, index"
        )

    return (res, data1_index, data2_index) if return_index else res
|
|
|
|
|
def concat_two_tensor_with_index(
    data1: torch.Tensor,
    data1_index: torch.LongTensor,
    data2: torch.Tensor,
    data2_index: torch.LongTensor,
    dim: int,
) -> torch.Tensor:
    """Merge two tensors along ``dim`` by scattering each to the given positions.

    A zero tensor with ``data1.shape[dim] + data2.shape[dim]`` entries along
    ``dim`` (data1's dtype and device) is allocated; data1 is copied to
    ``data1_index`` and then data2 to ``data2_index`` with ``batch_index_copy``.

    Args:
        data1 (torch.Tensor): b1*c1*h1*w1*...
        data1_index (torch.LongTensor): N (or per-sample b*N), if dim=1, N=c1
        data2 (torch.Tensor): b2*c2*h2*w2*...
        data2_index (torch.LongTensor): M (or per-sample b*M), if dim=1, M=c2
        dim (int): axis to merge along.

    Returns:
        torch.Tensor: b*c*h*w*..., if dim=1, b=b1=b2, c=c1+c2, h=h1=h2, w=w1=w2,...
    """
    merged_shape = list(data1.shape)
    merged_shape[dim] += data2.shape[dim]
    merged = torch.zeros(merged_shape, device=data1.device, dtype=data1.dtype)
    # data2 is written second, so it wins on any position listed in both indices.
    for part, part_index in ((data1, data1_index), (data2, data2_index)):
        merged = batch_index_copy(merged, dim=dim, index=part_index, source=part)
    return merged
|
|
|
|
|
def repeat_index_to_target_size(
    index: torch.LongTensor, target_size: int
) -> torch.LongTensor:
    """Tile ``index`` along a leading batch axis so it has ``target_size`` rows.

    - 1-D ``index`` of shape (n,) becomes (target_size, n), every row equal.
    - 2-D ``index`` of shape (b, n) has each row repeated ``target_size // b``
      times consecutively, giving (target_size, n); ``target_size`` must be a
      multiple of b (checked with an ``assert``).
    - Any other rank is returned unchanged.
    """
    if index.ndim == 1:
        index = index.unsqueeze(0).expand(target_size, -1)
    if index.ndim == 2:
        remainder = target_size % index.shape[0]
        assert (
            remainder == 0
        ), f"target_size % index.shape[0] must be zero, but give {target_size % index.shape[0]}"
        # Row-wise repeat: [r0, r0, ..., r1, r1, ...] (each row kept contiguous).
        index = index.repeat_interleave(target_size // index.shape[0], dim=0)
    return index
|
|
|
|
|
def batch_concat_two_tensor_with_index(
    data1: torch.Tensor,
    data1_index: torch.LongTensor,
    data2: torch.Tensor,
    data2_index: torch.LongTensor,
    dim: int,
) -> torch.Tensor:
    """Alias of ``concat_two_tensor_with_index``; all arguments are forwarded unchanged."""
    merged = concat_two_tensor_with_index(
        data1=data1,
        data1_index=data1_index,
        data2=data2,
        data2_index=data2_index,
        dim=dim,
    )
    return merged
|
|
|
|
|
def interwine_two_tensor(
    data1: torch.Tensor,
    data2: torch.Tensor,
    dim: int,
    return_index: bool = False,
) -> torch.Tensor:
    """Interleave two tensors along ``dim``: data1[0], data2[0], data1[1], data2[1], ...

    Elements alternate, starting with ``data1``, while both inputs still have
    elements left. The remaining elements of the longer input are then appended
    in order. (Previously, a ``data1`` longer than ``len(data2) + 1`` produced
    out-of-range positions and crashed.)

    Args:
        data1 (torch.Tensor): first input; same shape as data2 except along ``dim``.
        data2 (torch.Tensor): second input.
        dim (int): axis to interleave along.
        return_index (bool, optional): also return the positions of data1 and
            data2 in the result. Defaults to False.

    Returns:
        The interleaved tensor (data1's dtype and device), or
        ``(tensor, data1_index, data2_index)`` with 1-D CPU ``LongTensor`` indices.
    """
    n1 = data1.shape[dim]
    n2 = data2.shape[dim]
    n_pair = min(n1, n2)
    # Alternating prefix occupies positions [0, 2 * n_pair); the tail of the
    # longer input fills [2 * n_pair, n1 + n2).
    data1_index = list(range(0, 2 * n_pair, 2)) + list(range(2 * n_pair, n1 + n_pair))
    data2_index = list(range(1, 2 * n_pair, 2)) + list(range(2 * n_pair, n2 + n_pair))
    data1_index = torch.LongTensor(data1_index)
    data2_index = torch.LongTensor(data2_index)

    target_shape = list(data1.shape)
    target_shape[dim] = n1 + n2
    target = torch.zeros(target_shape, device=data1.device, dtype=data1.dtype)
    # Move ``dim`` to the front so first-axis indexing can place each input.
    target = torch.swapaxes(target, 0, dim)
    target[data1_index, ...] = torch.swapaxes(data1, 0, dim)
    target[data2_index, ...] = torch.swapaxes(data2, 0, dim)
    target = torch.swapaxes(target, 0, dim)

    if return_index:
        return target, data1_index, data2_index
    return target
|
|
|
|
|
def split_index(
    indexs: torch.Tensor,
    n_first: int = None,
    n_last: int = None,
    method: Literal[
        "first_in_first_out", "first_in_last_out", "intertwine", "index", "random"
    ] = "first_in_first_out",
):
    """Split ``indexs`` into a "first" group of ``n_first`` and a "last" group of ``n_last``.

    Args:
        indexs (torch.Tensor): 1-D positions to split.
        n_first (int, optional): size of the first group; derived from n_last if None.
        n_last (int, optional): size of the last group; derived from n_first if None.
        method (str, optional):
            "first_in_first_out": first = leading ``n_first``, last = the rest;
            "first_in_last_out": last = leading ``n_last``, first = the rest;
            "random": random disjoint split via ``torch.randperm``.
            Defaults to "first_in_first_out".

    Raises:
        ValueError: if both sizes are None, or ``method`` is unknown.
        NotImplementedError: for "intertwine" and "index".

    Returns:
        (first_index, last_index)
    """
    n_total = len(indexs)
    if n_first is None and n_last is None:
        raise ValueError("at least one of n_first and n_last must be given")
    if n_first is None:
        n_first = n_total - n_last
    if n_last is None:
        n_last = n_total - n_first
    assert n_total == n_first + n_last, (
        f"n_first({n_first}) + n_last({n_last}) != len(indexs)({n_total})"
    )

    if method == "first_in_first_out":
        first_index = indexs[:n_first]
        last_index = indexs[n_first:]
    elif method == "first_in_last_out":
        first_index = indexs[n_last:]
        last_index = indexs[:n_last]
    elif method == "random":
        idx_ = torch.randperm(n_total)
        first_index = indexs[idx_[:n_first]]
        last_index = indexs[idx_[n_first:]]
    elif method in ("intertwine", "index"):
        # Listed in the signature but never implemented; previously "index"
        # fell through to an UnboundLocalError.
        raise NotImplementedError(method)
    else:
        raise ValueError(f"unsupported split method: {method}")
    return first_index, last_index
|
|
|
|
|
def split_tensor(
    tensor: torch.Tensor,
    dim: int,
    n_first=None,
    n_last=None,
    method: Literal[
        "first_in_first_out", "first_in_last_out", "intertwine", "index", "random"
    ] = "first_in_first_out",
    need_return_index: bool = False,
):
    """Split ``tensor`` along ``dim`` into a "first" and a "last" part.

    Positions ``0 .. tensor.shape[dim] - 1`` are partitioned by ``split_index``
    with ``method``, and each part is gathered with ``torch.index_select``.

    Args:
        tensor: tensor to split.
        dim: axis to split along.
        n_first: size of the first part; derived as ``shape[dim] - n_last`` if None.
        n_last: only used to derive ``n_first`` when that is None.
        method: partition strategy forwarded to ``split_index``.
        need_return_index: also return the position tensors of both parts.

    Returns:
        ``(first_tensor, last_tensor)`` or
        ``(first_tensor, last_tensor, first_index, last_index)``.
    """
    device = tensor.device
    length = tensor.shape[dim]
    if n_first is None:
        n_first = length - n_last
    positions = torch.arange(length, dtype=torch.long, device=device)
    first_index, last_index = split_index(
        indexs=positions, n_first=n_first, method=method
    )
    first_tensor = torch.index_select(tensor, dim=dim, index=first_index)
    last_tensor = torch.index_select(tensor, dim=dim, index=last_index)
    if not need_return_index:
        return (first_tensor, last_tensor)
    return (first_tensor, last_tensor, first_index, last_index)
|
|
|
|
|
|
|
def batch_index_select(
    tensor: torch.Tensor, index: torch.LongTensor, dim: int
) -> torch.Tensor:
    """Select entries along ``dim``, optionally with a different index per batch sample.

    Args:
        tensor (torch.Tensor): D1*D2*D3*D4..., axis 0 is the batch axis.
        index (torch.LongTensor): N (shared by every sample) or B*N with
            D1 % B == 0 (row b is used for D1 // B consecutive samples).
            N <= tensor.shape[dim].
        dim (int): axis of ``tensor`` to select along (negative allowed); must
            not be the batch axis when ``index`` is 2-D.

    Returns:
        torch.Tensor: D1*...*N*...

    Raises:
        IndexError: if ``dim`` is out of range (2-D index).
        ValueError: if ``dim`` is the batch axis with a 2-D index.
    """
    if index.ndim == 1:
        return torch.index_select(tensor, dim=dim, index=index)

    ndim = tensor.ndim
    if not -ndim <= dim < ndim:
        raise IndexError(f"dim {dim} out of range for a {ndim}-D tensor")
    # Normalize so ``dim - 1`` addresses the same axis in each per-sample slice;
    # a negative dim must NOT be shifted (e.g. -1 stays the last axis).
    dim = dim % ndim
    if dim == 0:
        raise ValueError("per-sample index selection cannot target the batch axis (dim=0)")
    index = repeat_index_to_target_size(index, tensor.shape[0])
    out = [
        torch.index_select(tensor[i], dim=dim - 1, index=index[i])
        for i in range(tensor.shape[0])
    ]
    return torch.stack(out)
|
|
|
|
|
def batch_index_copy(
    tensor: torch.Tensor, dim: int, index: torch.LongTensor, source: torch.Tensor
) -> torch.Tensor:
    """In-place ``index_copy_`` along ``dim``, optionally with per-sample indices.

    Args:
        tensor (torch.Tensor): b*c*h..., modified in place and returned.
        dim (int): axis to copy along (negative allowed); must not be the batch
            axis when ``index`` is 2-D.
        index (torch.LongTensor): d (shared) or b'*d with b % b' == 0.
        source (torch.Tensor):
            b*d*h*..., if dim=1
            b*c*d*..., if dim=2

    Returns:
        torch.Tensor: ``tensor`` itself, after the copy.

    Raises:
        IndexError: if ``dim`` is out of range (2-D index).
        ValueError: if ``dim`` is the batch axis with a 2-D index.
    """
    if index.ndim == 1:
        tensor.index_copy_(dim, index, source)
        return tensor

    ndim = tensor.ndim
    if not -ndim <= dim < ndim:
        raise IndexError(f"dim {dim} out of range for a {ndim}-D tensor")
    # Normalize so ``dim - 1`` addresses the same axis in each per-sample slice.
    dim = dim % ndim
    if dim == 0:
        raise ValueError("per-sample index copy cannot target the batch axis (dim=0)")
    index = repeat_index_to_target_size(index, tensor.shape[0])
    for b in range(tensor.shape[0]):
        # tensor[b] is a view, so the in-place copy writes straight into ``tensor``.
        tensor[b].index_copy_(dim - 1, index[b], source[b])
    return tensor
|
|
|
|
|
def batch_index_fill(
    tensor: torch.Tensor,
    dim: int,
    index: torch.LongTensor,
    value: Union[torch.Tensor, float],
) -> torch.Tensor:
    """In-place per-sample ``index_fill_`` along ``dim``.

    Args:
        tensor (torch.Tensor): b*c*h..., modified in place and returned.
        dim (int): axis to fill along (negative allowed); must not be the batch axis.
        index (torch.LongTensor): d (shared) or b'*d with b % b' == 0.
        value (Union[torch.Tensor, float]): scalar fill value, or a tensor whose
            entry ``value[b]`` is used for sample ``b``.

    Returns:
        torch.Tensor: ``tensor`` itself, after filling.

    Raises:
        IndexError: if ``dim`` is out of range.
        ValueError: if ``dim`` is the batch axis.
    """
    ndim = tensor.ndim
    if not -ndim <= dim < ndim:
        raise IndexError(f"dim {dim} out of range for a {ndim}-D tensor")
    # Normalize so ``dim - 1`` addresses the same axis in each per-sample slice.
    dim = dim % ndim
    if dim == 0:
        raise ValueError("per-sample index fill cannot target the batch axis (dim=0)")
    index = repeat_index_to_target_size(index, tensor.shape[0])
    for b in range(tensor.shape[0]):
        sub_value = value[b] if isinstance(value, torch.Tensor) else value
        # tensor[b] is a view, so the in-place fill writes straight into ``tensor``.
        tensor[b].index_fill_(dim - 1, index[b], sub_value)
    return tensor
|
|
|
|
|
def adaptive_instance_normalization(
    src: torch.Tensor,
    dst: torch.Tensor,
    eps: float = 1e-6,
):
    """Re-normalize ``src`` so its per-sample, per-channel statistics match ``dst``.

    Statistics are taken over every axis after the channel axis: (t, h, w) for
    5-D, (h, w) for 4-D, (t,) for 3-D input. The result is
    ``(src - mean(src)) / std(src) * std(dst) + mean(dst)``. Each std is the
    square root of the biased (correction=0) variance clamped to at least ``eps``.

    Args:
        src (torch.Tensor): b c t h w (or b c h w / b c t); floating point.
        dst (torch.Tensor): reference of the same rank. If its batch size differs
            from src's, it is aligned with ``align_repeat_tensor_single_dim``
            along axis 0.
        eps (float, optional): lower bound on the variances. Defaults to 1e-6.

    Returns:
        torch.Tensor: same shape as ``src``.

    Raises:
        ValueError: if ``src.ndim`` is not 3, 4 or 5.
    """
    ndim = src.ndim
    if ndim == 5:
        dim = (2, 3, 4)
    elif ndim == 4:
        dim = (2, 3)
    elif ndim == 3:
        dim = 2
    else:
        raise ValueError(f"only support ndim in [3,4,5], but given {ndim}")
    # torch.var_mean returns (var, mean) in that order.
    var, mean = torch.var_mean(src, dim=dim, keepdim=True, correction=0)
    std = torch.maximum(var, torch.zeros_like(var) + eps) ** 0.5
    if dst.shape[0] != src.shape[0]:
        # align_repeat_tensor_single_dim returns its input unchanged for equal
        # lengths, so only call it when the batch sizes actually differ.
        dst = align_repeat_tensor_single_dim(dst, src.shape[0], dim=0)
    # Unpack in (var, mean) order; the previous code swapped them, using dst's
    # variance as its mean and vice versa.
    var_acc, mean_acc = torch.var_mean(dst, dim=dim, keepdim=True, correction=0)
    std_acc = torch.maximum(var_acc, torch.zeros_like(var_acc) + eps) ** 0.5
    src = (((src - mean) / std) * std_acc) + mean_acc
    return src
|
|
|
|
|
def adaptive_instance_normalization_with_ref(
    src: torch.Tensor,
    dst: torch.Tensor,
    style_fidelity: float = 0.5,
    do_classifier_free_guidance: bool = True,
):
    """AdaIN of ``src`` towards ``dst``, with the first half of the batch blended back to ``src``.

    The batch is treated as two halves; ``uc_mask`` flags the first half
    (presumably the unconditional half under classifier-free guidance; confirm
    with callers). With ``adain = adaptive_instance_normalization(src, dst)``:

    - if ``do_classifier_free_guidance`` and ``style_fidelity > 0``:
      first half -> ``style_fidelity * src + (1 - style_fidelity) * adain``;
      second half -> ``adain``;
    - otherwise the result equals ``adain`` (up to floating-point rounding).

    Args:
        src (torch.Tensor): floating-point tensor whose batch size is assumed to
            be even (the mask covers ``2 * (src.shape[0] // 2)`` rows; TODO confirm).
        dst (torch.Tensor): AdaIN reference, same rank as ``src``.
        style_fidelity (float, optional): blend weight of the original ``src``
            in the first half. Defaults to 0.5.
        do_classifier_free_guidance (bool, optional): enable the blend. Defaults to True.

    Returns:
        torch.Tensor: same shape as ``src``.
    """
    batch_size = src.shape[0] // 2
    uc_mask = torch.Tensor([1] * batch_size + [0] * batch_size).type_as(src).bool()
    src_uc = adaptive_instance_normalization(src, dst)
    src_c = src_uc.clone()
    if do_classifier_free_guidance and style_fidelity > 0:
        # Restore the original values in the masked half before blending.
        src_c[uc_mask] = src[uc_mask]
    src = style_fidelity * src_c + (1.0 - style_fidelity) * src_uc
    return src
|
|
|
|
|
def batch_adain_conditioned_tensor(
    tensor: torch.Tensor,
    src_index: torch.LongTensor,
    dst_index: torch.LongTensor,
    keep_dim: bool = True,
    num_frames: int = None,
    dim: int = 2,
    style_fidelity: float = 0.5,
    do_classifier_free_guidance: bool = True,
    need_style_fidelity: bool = False,
):
    """AdaIN-normalize the ``src_index`` slices of ``tensor`` towards its ``dst_index`` slices.

    Args:
        tensor (torch.Tensor): b c t h w, or (b t) c h w when ``num_frames`` is given.
        src_index (torch.LongTensor): positions along ``dim`` to re-normalize
            (1-D or per-sample 2-D, as accepted by ``batch_index_select``).
        dst_index (torch.LongTensor): reference positions along ``dim``.
        keep_dim (bool, optional): if True, merge the normalized src slices and
            the untouched dst slices into one tensor with
            ``len(src_index) + len(dst_index)`` entries along ``dim``, placed at
            ``src_index`` / ``dst_index``. If False, return only the normalized
            src slices. Defaults to True.
        num_frames (int, optional): frames per sample, used to unflatten a 4-D
            input to b c t h w. Defaults to None.
        dim (int, optional): axis holding the frames. Defaults to 2.
        style_fidelity (float, optional): forwarded to
            ``adaptive_instance_normalization_with_ref``. Defaults to 0.5.
        do_classifier_free_guidance (bool, optional): forwarded to
            ``adaptive_instance_normalization_with_ref``. Defaults to True.
        need_style_fidelity (bool, optional): use
            ``adaptive_instance_normalization_with_ref`` instead of plain
            ``adaptive_instance_normalization``. Defaults to False.

    Returns:
        torch.Tensor: the result, flattened back to (b t) c h w for 4-D inputs.
    """
    ndim = tensor.ndim
    dtype = tensor.dtype
    frame_flattened = ndim == 4 and num_frames is not None
    if frame_flattened:
        tensor = rearrange(tensor, "(b t) c h w-> b c t h w ", t=num_frames)
    src = batch_index_select(tensor, dim=dim, index=src_index).contiguous()
    dst = batch_index_select(tensor, dim=dim, index=dst_index).contiguous()
    if need_style_fidelity:
        # need_style_fidelity only selects this branch; the callee does not
        # accept it (passing it raised TypeError).
        src = adaptive_instance_normalization_with_ref(
            src=src,
            dst=dst,
            style_fidelity=style_fidelity,
            do_classifier_free_guidance=do_classifier_free_guidance,
        )
    else:
        src = adaptive_instance_normalization(
            src=src,
            dst=dst,
        )
    if keep_dim:
        src = batch_concat_two_tensor_with_index(
            src.to(dtype=dtype),
            src_index,
            dst.to(dtype=dtype),
            dst_index,
            dim=dim,
        )

    if frame_flattened:
        # Flatten the RESULT back; the old code re-flattened the untouched
        # input ``tensor``, discarding the normalization for 4-D inputs.
        src = rearrange(src, "b c t h w ->(b t) c h w")
    return src
|
|
|
|
|
def align_repeat_tensor_single_dim(
    src: torch.Tensor,
    target_length: int,
    dim: int = 0,
    n_src_base_length: int = 1,
    src_base_index: List[int] = None,
) -> torch.Tensor:
    """Align the length of ``src`` along ``dim`` to ``target_length``.

    - Equal length: ``src`` itself is returned.
    - Longer than needed: the first ``target_length`` entries are kept.
    - Shorter, and ``target_length`` is a multiple of the current length: every
      entry is repeated in place (``repeat_interleave``).
    - Shorter otherwise: the entries at ``src_base_index`` (default: the first
      ``n_src_base_length``) are each repeated
      ``target_length // len(src_base_index)`` times.

    Args:
        src (torch.Tensor): input tensor.
        target_length (int): desired length along ``dim``.
        dim (int, optional): axis to align. Defaults to 0.
        n_src_base_length (int, optional): number of leading entries used as the
            repeat base when ``src_base_index`` is None. Defaults to 1.
        src_base_index (List[int], optional): explicit base positions. Defaults to None.

    Returns:
        torch.Tensor: the aligned tensor.
    """
    current = src.shape[dim]
    if target_length == current:
        return src
    if target_length < current:
        keep = torch.arange(target_length).to(device=src.device)
        return src.index_select(dim=dim, index=keep)
    if target_length % current == 0:
        return src.repeat_interleave(repeats=target_length // current, dim=dim)
    if src_base_index is None and n_src_base_length is not None:
        src_base_index = torch.arange(n_src_base_length)
    base = src.index_select(
        dim=dim,
        index=torch.LongTensor(src_base_index).to(device=src.device),
    )
    return base.repeat_interleave(
        repeats=target_length // len(src_base_index),
        dim=dim,
    )
|
|
|
|
|
def fuse_part_tensor(
    src: torch.Tensor,
    dst: torch.Tensor,
    overlap: int,
    weight: float = 0.5,
    skip_step: int = 0,
) -> torch.Tensor:
    """Blend the last ``overlap`` steps of ``src`` into ``dst`` along axis 2, in place.

    For positions ``skip_step .. skip_step + overlap - 1`` of ``dst``'s axis 2:
    ``dst = weight * src[:, :, -overlap:] + (1 - weight) * dst``.

    Args:
        src (torch.Tensor): b c t h w; its trailing ``overlap`` steps are used.
        dst (torch.Tensor): b c t h w; modified in place.
        overlap (int): number of overlapping steps; 0 leaves ``dst`` untouched.
        weight (float, optional): weight of the ``src`` part. Defaults to 0.5.
        skip_step (int, optional): offset in ``dst`` where the overlap starts.
            Defaults to 0.

    Returns:
        torch.Tensor: ``dst`` itself.
    """
    if overlap != 0:
        window = slice(skip_step, skip_step + overlap)
        blended = weight * src[:, :, -overlap:] + (1 - weight) * dst[:, :, window]
        dst[:, :, window] = blended
    return dst
|
|