Source code for piblin_jax.transform.dataset.smoothing

"""
Smoothing transforms for 1D datasets.

This module provides various smoothing/filtering transforms to reduce
noise in time series and spectral data.
"""

from typing import Any

from piblin_jax.backend import jnp
from piblin_jax.backend.operations import jit
from piblin_jax.data.datasets import OneDimensionalDataset
from piblin_jax.transform.base import DatasetTransform


[docs] class MovingAverageSmooth(DatasetTransform): """ Smooth data using moving average filter. This transform applies a simple moving average (box filter) to smooth noisy data. It uses convolution for efficient computation. Parameters ---------- window_size : int, default=5 Size of the moving average window. Must be odd to ensure symmetry. Attributes ---------- window_size : int Window size for moving average. Raises ------ ValueError If window_size is not odd. Examples -------- >>> import numpy as np >>> from piblin_jax.data.datasets import OneDimensionalDataset >>> from piblin_jax.transform.dataset import MovingAverageSmooth >>> >>> # Create noisy data >>> x = np.linspace(0, 10, 100) >>> y = np.sin(x) + 0.5 * np.random.randn(100) >>> dataset = OneDimensionalDataset( ... independent_variable_data=x, ... dependent_variable_data=y ... ) >>> >>> # Apply smoothing >>> smooth = MovingAverageSmooth(window_size=5) >>> result = smooth.apply_to(dataset) >>> >>> # Result has smoothed y-values (x unchanged) >>> result.independent_variable_data # Same as original >>> result.dependent_variable_data # Smoothed version Notes ----- - Uses 'same' mode for convolution (output same size as input) - Edge effects present at boundaries (first/last few points) - For JAX backend, convolution is JIT-compiled for efficiency - Window must be odd to ensure symmetric smoothing - Larger windows = more smoothing but more distortion """
[docs] def __init__(self, window_size: int = 5): """ Initialize moving average smoothing transform. Parameters ---------- window_size : int, default=5 Size of moving average window (must be odd). Raises ------ ValueError If window_size is even. """ super().__init__() if window_size % 2 == 0: raise ValueError("window_size must be odd") self.window_size = window_size
@staticmethod @jit def _convolve(y: Any, kernel: Any) -> Any: """JIT-compiled convolution for 3-5x speedup.""" return jnp.convolve(y, kernel, mode="same") def _apply(self, dataset: OneDimensionalDataset) -> OneDimensionalDataset: # type: ignore[override] """ Apply moving average smoothing to dataset. Parameters ---------- dataset : OneDimensionalDataset Input dataset to smooth. Returns ------- OneDimensionalDataset Dataset with smoothed dependent variable. Notes ----- Modifies the dependent variable data in-place. Independent variable and metadata are preserved. """ # Convert to backend array y = jnp.asarray(dataset.dependent_variable_data) # Create uniform kernel (moving average) kernel = jnp.ones(self.window_size) / self.window_size # Apply JIT-compiled convolution: 3-5x faster y_smooth = MovingAverageSmooth._convolve(y, kernel) # type: ignore[call-arg] # Update dataset with smoothed data dataset._dependent_variable_data = y_smooth return dataset
[docs] class GaussianSmooth(DatasetTransform): """ Smooth data using Gaussian filter. This transform applies a Gaussian filter for smoothing, which provides better frequency response than simple moving average. Parameters ---------- sigma : float, default=1.0 Standard deviation of Gaussian kernel in units of data points. truncate : float, default=3.0 Truncate filter at this many standard deviations. Examples -------- >>> from piblin_jax.data.datasets import OneDimensionalDataset >>> from piblin_jax.transform.dataset import GaussianSmooth >>> import numpy as np >>> >>> x = np.linspace(0, 10, 100) >>> y = np.sin(x) + 0.3 * np.random.randn(100) >>> dataset = OneDimensionalDataset(x, y) >>> >>> smooth = GaussianSmooth(sigma=2.0) >>> result = smooth.apply_to(dataset) Notes ----- - Gaussian smoothing preserves features better than moving average - sigma controls the amount of smoothing - Larger sigma = more smoothing """
[docs] def __init__(self, sigma: float = 1.0, truncate: float = 3.0): """ Initialize Gaussian smoothing transform. Parameters ---------- sigma : float, default=1.0 Standard deviation of Gaussian kernel. truncate : float, default=3.0 Truncate at this many standard deviations. """ super().__init__() self.sigma = sigma self.truncate = truncate
@staticmethod @jit def _convolve(y: Any, kernel: Any) -> Any: """JIT-compiled convolution for 3-5x speedup.""" return jnp.convolve(y, kernel, mode="same") def _apply(self, dataset: OneDimensionalDataset) -> OneDimensionalDataset: # type: ignore[override] """ Apply Gaussian smoothing to dataset. Parameters ---------- dataset : OneDimensionalDataset Input dataset to smooth. Returns ------- OneDimensionalDataset Dataset with smoothed dependent variable. """ y = jnp.asarray(dataset.dependent_variable_data) # Create Gaussian kernel # Window size based on sigma and truncate radius = int(self.truncate * self.sigma + 0.5) x_kernel = jnp.arange(-radius, radius + 1) kernel = jnp.exp(-0.5 * (x_kernel / self.sigma) ** 2) kernel = kernel / jnp.sum(kernel) # Normalize # Apply JIT-compiled convolution: 3-5x faster y_smooth = GaussianSmooth._convolve(y, kernel) # type: ignore[call-arg] # Update dataset dataset._dependent_variable_data = y_smooth return dataset
__all__ = ["GaussianSmooth", "MovingAverageSmooth"]