import numpy as np import torch from torch.nn import functional as F ############################################################################### # Sequence filters ############################################################################### def mean(signals, win_length=9): """Averave filtering for signals containing nan values Arguments signals (torch.tensor (shape=(batch, time))) The signals to filter win_length The size of the analysis window Returns filtered (torch.tensor (shape=(batch, time))) """ assert signals.dim() == 2, "Input tensor must have 2 dimensions (batch_size, width)" signals = signals.unsqueeze(1) # Apply the mask by setting masked elements to zero, or make NaNs zero mask = ~torch.isnan(signals) masked_x = torch.where(mask, signals, torch.zeros_like(signals)) # Create a ones kernel with the same number of channels as the input tensor ones_kernel = torch.ones(signals.size(1), 1, win_length, device=signals.device) # Perform sum pooling sum_pooled = F.conv1d( masked_x, ones_kernel, stride=1, padding=win_length // 2, ) # Count the non-masked (valid) elements in each pooling window valid_count = F.conv1d( mask.float(), ones_kernel, stride=1, padding=win_length // 2, ) valid_count = valid_count.clamp(min=1) # Avoid division by zero # Perform masked average pooling avg_pooled = sum_pooled / valid_count # Fill zero values with NaNs avg_pooled[avg_pooled == 0] = float("nan") return avg_pooled.squeeze(1) def median(signals, win_length): """Median filtering for signals containing nan values Arguments signals (torch.tensor (shape=(batch, time))) The signals to filter win_length The size of the analysis window Returns filtered (torch.tensor (shape=(batch, time))) """ assert signals.dim() == 2, "Input tensor must have 2 dimensions (batch_size, width)" signals = signals.unsqueeze(1) mask = ~torch.isnan(signals) masked_x = torch.where(mask, signals, torch.zeros_like(signals)) padding = win_length // 2 x = F.pad(masked_x, (padding, padding), mode="reflect") mask = F.pad(mask.float(), (padding, padding), mode="constant", value=0) x = x.unfold(2, win_length, 1) mask = mask.unfold(2, win_length, 1) x = x.contiguous().view(x.size()[:3] + (-1,)) mask = mask.contiguous().view(mask.size()[:3] + (-1,)) # Combine the mask with the input tensor x_masked = torch.where(mask.bool(), x.double(), float("inf")).to(x) # Sort the masked tensor along the last dimension x_sorted, _ = torch.sort(x_masked, dim=-1) # Compute the count of non-masked (valid) values valid_count = mask.sum(dim=-1) # Calculate the index of the median value for each pooling window median_idx = ((valid_count - 1) // 2).clamp(min=0) # Gather the median values using the calculated indices median_pooled = x_sorted.gather(-1, median_idx.unsqueeze(-1).long()).squeeze(-1) # Fill infinite values with NaNs median_pooled[torch.isinf(median_pooled)] = float("nan") return median_pooled.squeeze(1) ############################################################################### # Utilities ############################################################################### def nanfilter(signals, win_length, filter_fn): """Filters a sequence, ignoring nan values Arguments signals (torch.tensor (shape=(batch, time))) The signals to filter win_length The size of the analysis window filter_fn (function) The function to use for filtering Returns filtered (torch.tensor (shape=(batch, time))) """ # Output buffer filtered = torch.empty_like(signals) # Loop over frames for i in range(signals.size(1)): # Get analysis window bounds start = max(0, i - win_length // 2) end = min(signals.size(1), i + win_length // 2 + 1) # Apply filter to window filtered[:, i] = filter_fn(signals[:, start:end]) return filtered def nanmean(signals): """Computes the mean, ignoring nans Arguments signals (torch.tensor [shape=(batch, time)]) The signals to filter Returns filtered (torch.tensor [shape=(batch, time)]) """ signals = signals.clone() # Find nans nans = torch.isnan(signals) # Set nans to 0. signals[nans] = 0. # Compute average return signals.sum(dim=1) / (~nans).float().sum(dim=1) def nanmedian(signals): """Computes the median, ignoring nans Arguments signals (torch.tensor [shape=(batch, time)]) The signals to filter Returns filtered (torch.tensor [shape=(batch, time)]) """ # Find nans nans = torch.isnan(signals) # Compute median for each slice medians = [nanmedian1d(signal[~nan]) for signal, nan in zip(signals, nans)] # Stack results return torch.tensor(medians, dtype=signals.dtype, device=signals.device) def nanmedian1d(signal): """Computes the median. If signal is empty, returns torch.nan Arguments signal (torch.tensor [shape=(time,)]) Returns median (torch.tensor [shape=(1,)]) """ return torch.median(signal) if signal.numel() else np.nan