import numpy as np
import torch
from torch.nn import functional as F


###############################################################################
# Sequence filters
###############################################################################


def mean(signals, win_length=9):
    """Average filtering for signals containing nan values

    Arguments
        signals (torch.tensor (shape=(batch, time)))
            The signals to filter
        win_length
            The size of the analysis window

    Returns
        filtered (torch.tensor (shape=(batch, time)))
    """
    assert signals.dim() == 2, "Input tensor must have 2 dimensions (batch_size, width)"
    signals = signals.unsqueeze(1)

    # Build a validity mask and replace nan entries with zero
    mask = ~torch.isnan(signals)
    masked_x = torch.where(mask, signals, torch.zeros_like(signals))

    # Create a ones kernel with the same number of channels as the input tensor
    ones_kernel = torch.ones(signals.size(1), 1, win_length, device=signals.device)

    # Perform sum pooling
    sum_pooled = F.conv1d(
        masked_x,
        ones_kernel,
        stride=1,
        padding=win_length // 2,
    )

    # Count the non-masked (valid) elements in each pooling window
    valid_count = F.conv1d(
        mask.float(),
        ones_kernel,
        stride=1,
        padding=win_length // 2,
    )
    valid_count = valid_count.clamp(min=1)  # Avoid division by zero

    # Perform masked average pooling
    avg_pooled = sum_pooled / valid_count

    # Restore nan where a window contained no valid samples (their pooled sum
    # is zero); note this also maps exact-zero averages to nan
    avg_pooled[avg_pooled == 0] = float("nan")

    return avg_pooled.squeeze(1)
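

# The example below is an illustrative sketch (not part of the original API):
# it shows how `mean` averages over a window while skipping nan entries. The
# input values are arbitrary assumptions.
def _example_mean():
    contour = torch.tensor([[1., float("nan"), 3., 4., float("nan"), 6.]])
    # Each output frame is the average of the valid (non-nan) samples in its
    # 3-sample window
    return mean(contour, win_length=3)
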
def median(signals, win_length):
    """Median filtering for signals containing nan values

    Arguments
        signals (torch.tensor (shape=(batch, time)))
            The signals to filter
        win_length
            The size of the analysis window

    Returns
        filtered (torch.tensor (shape=(batch, time)))
    """
    assert signals.dim() == 2, "Input tensor must have 2 dimensions (batch_size, width)"
    signals = signals.unsqueeze(1)

    # Build a validity mask and replace nan entries with zero
    mask = ~torch.isnan(signals)
    masked_x = torch.where(mask, signals, torch.zeros_like(signals))
    padding = win_length // 2

    # Pad the signal and mask, then split them into overlapping windows
    x = F.pad(masked_x, (padding, padding), mode="reflect")
    mask = F.pad(mask.float(), (padding, padding), mode="constant", value=0)
    x = x.unfold(2, win_length, 1)
    mask = mask.unfold(2, win_length, 1)
    x = x.contiguous().view(x.size()[:3] + (-1,))
    mask = mask.contiguous().view(mask.size()[:3] + (-1,))

    # Replace masked (invalid) entries with inf so they sort to the end
    x_masked = torch.where(mask.bool(), x.double(), float("inf")).to(x)

    # Sort the masked tensor along the last dimension
    x_sorted, _ = torch.sort(x_masked, dim=-1)

    # Compute the count of non-masked (valid) values
    valid_count = mask.sum(dim=-1)

    # Calculate the index of the median value for each pooling window
    median_idx = ((valid_count - 1) // 2).clamp(min=0)

    # Gather the median values using the calculated indices
    median_pooled = x_sorted.gather(-1, median_idx.unsqueeze(-1).long()).squeeze(-1)

    # Fill infinite values (windows with no valid samples) with nans
    median_pooled[torch.isinf(median_pooled)] = float("nan")

    return median_pooled.squeeze(1)
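

# Illustrative sketch (not part of the original API): `median` suppresses
# isolated outliers while ignoring nan entries. The spike value below is an
# arbitrary assumption.
def _example_median():
    contour = torch.tensor([[1., 100., 2., float("nan"), 3.]])
    # With a 3-sample window, the spike at index 1 is replaced by the median
    # of its valid neighbors
    return median(contour, win_length=3)

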
###############################################################################
# Utilities
###############################################################################

def nanfilter(signals, win_length, filter_fn):
    """Filters a sequence, ignoring nan values

    Arguments
        signals (torch.tensor (shape=(batch, time)))
            The signals to filter
        win_length
            The size of the analysis window
        filter_fn (function)
            The function to use for filtering

    Returns
        filtered (torch.tensor (shape=(batch, time)))
    """
    # Output buffer
    filtered = torch.empty_like(signals)

    # Loop over frames
    for i in range(signals.size(1)):

        # Get analysis window bounds
        start = max(0, i - win_length // 2)
        end = min(signals.size(1), i + win_length // 2 + 1)

        # Apply filter to window
        filtered[:, i] = filter_fn(signals[:, start:end])

    return filtered
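

# Illustrative sketch (not part of the original API): `nanfilter` is the
# generic, loop-based counterpart to `mean` and `median` above; any reducer
# that maps (batch, window) -> (batch,) works as `filter_fn`, e.g. the
# `nanmean` and `nanmedian` helpers defined below.
def _example_nanfilter():
    contour = torch.tensor([[1., float("nan"), 3., 4., 5.]])
    return nanfilter(contour, win_length=3, filter_fn=nanmean)
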

def nanmean(signals):
    """Computes the mean, ignoring nans

    Arguments
        signals (torch.tensor [shape=(batch, time)])
            The signals to filter

    Returns
        filtered (torch.tensor [shape=(batch, time)])
    """
    signals = signals.clone()

    # Find nans
    nans = torch.isnan(signals)

    # Set nans to 0.
    signals[nans] = 0.

    # Compute average
    return signals.sum(dim=1) / (~nans).float().sum(dim=1)


def nanmedian(signals):
    """Computes the median, ignoring nans

    Arguments
        signals (torch.tensor [shape=(batch, time)])
            The signals to filter

    Returns
        filtered (torch.tensor [shape=(batch, time)])
    """
    # Find nans
    nans = torch.isnan(signals)

    # Compute median for each slice
    medians = [nanmedian1d(signal[~nan]) for signal, nan in zip(signals, nans)]

    # Stack results
    return torch.tensor(medians, dtype=signals.dtype, device=signals.device)


def nanmedian1d(signal):
    """Computes the median. If signal is empty, returns nan

    Arguments
        signal (torch.tensor [shape=(time,)])

    Returns
        median (torch.tensor [shape=(1,)])
    """
    return torch.median(signal) if signal.numel() else np.nan
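

# Minimal smoke test, assuming nothing beyond the functions defined above.
# The contour values are arbitrary; nan marks a missing (e.g. unvoiced) frame.
if __name__ == "__main__":
    contour = torch.tensor([[100., float("nan"), 110., 400., 120., float("nan")]])
    print("mean filtered:  ", mean(contour, win_length=3))
    print("median filtered:", median(contour, win_length=3))
    print("nanmean:        ", nanmean(contour))
    print("nanmedian:      ", nanmedian(contour))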