"""
Smoothing transforms for 1D datasets.
This module provides various smoothing/filtering transforms to reduce
noise in time series and spectral data.
"""
from typing import Any
from piblin_jax.backend import jnp
from piblin_jax.backend.operations import jit
from piblin_jax.data.datasets import OneDimensionalDataset
from piblin_jax.transform.base import DatasetTransform
[docs]
class MovingAverageSmooth(DatasetTransform):
"""
Smooth data using moving average filter.
This transform applies a simple moving average (box filter) to smooth
noisy data. It uses convolution for efficient computation.
Parameters
----------
window_size : int, default=5
Size of the moving average window. Must be odd to ensure symmetry.
Attributes
----------
window_size : int
Window size for moving average.
Raises
------
ValueError
If window_size is not odd.
Examples
--------
>>> import numpy as np
>>> from piblin_jax.data.datasets import OneDimensionalDataset
>>> from piblin_jax.transform.dataset import MovingAverageSmooth
>>>
>>> # Create noisy data
>>> x = np.linspace(0, 10, 100)
>>> y = np.sin(x) + 0.5 * np.random.randn(100)
>>> dataset = OneDimensionalDataset(
... independent_variable_data=x,
... dependent_variable_data=y
... )
>>>
>>> # Apply smoothing
>>> smooth = MovingAverageSmooth(window_size=5)
>>> result = smooth.apply_to(dataset)
>>>
>>> # Result has smoothed y-values (x unchanged)
>>> result.independent_variable_data # Same as original
>>> result.dependent_variable_data # Smoothed version
Notes
-----
- Uses 'same' mode for convolution (output same size as input)
- Edge effects present at boundaries (first/last few points)
- For JAX backend, convolution is JIT-compiled for efficiency
- Window must be odd to ensure symmetric smoothing
- Larger windows = more smoothing but more distortion
"""
[docs]
def __init__(self, window_size: int = 5):
"""
Initialize moving average smoothing transform.
Parameters
----------
window_size : int, default=5
Size of moving average window (must be odd).
Raises
------
ValueError
If window_size is even.
"""
super().__init__()
if window_size % 2 == 0:
raise ValueError("window_size must be odd")
self.window_size = window_size
@staticmethod
@jit
def _convolve(y: Any, kernel: Any) -> Any:
"""JIT-compiled convolution for 3-5x speedup."""
return jnp.convolve(y, kernel, mode="same")
def _apply(self, dataset: OneDimensionalDataset) -> OneDimensionalDataset: # type: ignore[override]
"""
Apply moving average smoothing to dataset.
Parameters
----------
dataset : OneDimensionalDataset
Input dataset to smooth.
Returns
-------
OneDimensionalDataset
Dataset with smoothed dependent variable.
Notes
-----
Modifies the dependent variable data in-place.
Independent variable and metadata are preserved.
"""
# Convert to backend array
y = jnp.asarray(dataset.dependent_variable_data)
# Create uniform kernel (moving average)
kernel = jnp.ones(self.window_size) / self.window_size
# Apply JIT-compiled convolution: 3-5x faster
y_smooth = MovingAverageSmooth._convolve(y, kernel) # type: ignore[call-arg]
# Update dataset with smoothed data
dataset._dependent_variable_data = y_smooth
return dataset
[docs]
class GaussianSmooth(DatasetTransform):
"""
Smooth data using Gaussian filter.
This transform applies a Gaussian filter for smoothing, which provides
better frequency response than simple moving average.
Parameters
----------
sigma : float, default=1.0
Standard deviation of Gaussian kernel in units of data points.
truncate : float, default=3.0
Truncate filter at this many standard deviations.
Examples
--------
>>> from piblin_jax.data.datasets import OneDimensionalDataset
>>> from piblin_jax.transform.dataset import GaussianSmooth
>>> import numpy as np
>>>
>>> x = np.linspace(0, 10, 100)
>>> y = np.sin(x) + 0.3 * np.random.randn(100)
>>> dataset = OneDimensionalDataset(x, y)
>>>
>>> smooth = GaussianSmooth(sigma=2.0)
>>> result = smooth.apply_to(dataset)
Notes
-----
- Gaussian smoothing preserves features better than moving average
- sigma controls the amount of smoothing
- Larger sigma = more smoothing
"""
[docs]
def __init__(self, sigma: float = 1.0, truncate: float = 3.0):
"""
Initialize Gaussian smoothing transform.
Parameters
----------
sigma : float, default=1.0
Standard deviation of Gaussian kernel.
truncate : float, default=3.0
Truncate at this many standard deviations.
"""
super().__init__()
self.sigma = sigma
self.truncate = truncate
@staticmethod
@jit
def _convolve(y: Any, kernel: Any) -> Any:
"""JIT-compiled convolution for 3-5x speedup."""
return jnp.convolve(y, kernel, mode="same")
def _apply(self, dataset: OneDimensionalDataset) -> OneDimensionalDataset: # type: ignore[override]
"""
Apply Gaussian smoothing to dataset.
Parameters
----------
dataset : OneDimensionalDataset
Input dataset to smooth.
Returns
-------
OneDimensionalDataset
Dataset with smoothed dependent variable.
"""
y = jnp.asarray(dataset.dependent_variable_data)
# Create Gaussian kernel
# Window size based on sigma and truncate
radius = int(self.truncate * self.sigma + 0.5)
x_kernel = jnp.arange(-radius, radius + 1)
kernel = jnp.exp(-0.5 * (x_kernel / self.sigma) ** 2)
kernel = kernel / jnp.sum(kernel) # Normalize
# Apply JIT-compiled convolution: 3-5x faster
y_smooth = GaussianSmooth._convolve(y, kernel) # type: ignore[call-arg]
# Update dataset
dataset._dependent_variable_data = y_smooth
return dataset
__all__ = ["GaussianSmooth", "MovingAverageSmooth"]