"""
Interpolation transforms for 1D datasets.
This module provides interpolation transforms that resample datasets
to new x-values using various interpolation methods.
"""
from typing import Any
import numpy as np
from piblin_jax.backend import BACKEND, jnp
from piblin_jax.data.datasets import OneDimensionalDataset
from piblin_jax.transform.base import DatasetTransform
[docs]
class Interpolate1D(DatasetTransform):
"""
Interpolate 1D dataset to new x-values.
This transform resamples a 1D dataset to a new set of independent
variable values using linear interpolation. It supports both JAX
and NumPy backends with automatic fallback.
Parameters
----------
new_x : array-like
New independent variable values for interpolation.
method : str, default='linear'
Interpolation method. Currently only 'linear' is fully supported.
Future versions may support 'cubic', 'spline', etc.
Attributes
----------
new_x : ndarray
Target x-values for interpolation.
method : str
Interpolation method name.
Examples
--------
>>> import numpy as np
>>> from piblin_jax.data.datasets import OneDimensionalDataset
>>> from piblin_jax.transform.dataset import Interpolate1D
>>>
>>> # Create original dataset
>>> x = np.array([0, 1, 2, 3, 4])
>>> y = np.array([0, 1, 4, 9, 16]) # y = x^2
>>> dataset = OneDimensionalDataset(
... independent_variable_data=x,
... dependent_variable_data=y
... )
>>>
>>> # Interpolate to finer grid
>>> new_x = np.linspace(0, 4, 17) # 17 points from 0 to 4
>>> interp = Interpolate1D(new_x, method='linear')
>>> result = interp.apply_to(dataset)
>>>
>>> # Result has new x-values with interpolated y-values
>>> result.independent_variable_data.shape
(17,)
Notes
-----
- Linear interpolation is used for both JAX and NumPy backends
- For JAX backend, uses jnp.interp (compiled with JIT)
- For NumPy backend, uses np.interp
- Extrapolation uses constant values (edge values)
- Metadata (conditions, details) is preserved from original dataset
"""
[docs]
def __init__(self, new_x: Any, method: str = "linear"):
"""
Initialize interpolation transform.
Parameters
----------
new_x : array-like
New independent variable values.
method : str, default='linear'
Interpolation method ('linear' supported).
"""
super().__init__()
self.new_x = jnp.asarray(new_x)
self.method = method
if method not in ["linear"]:
raise ValueError(
f"Interpolation method '{method}' not supported. "
"Currently only 'linear' is implemented."
)
def _apply(self, dataset: OneDimensionalDataset) -> OneDimensionalDataset: # type: ignore[override]
"""
Apply interpolation to dataset.
Parameters
----------
dataset : OneDimensionalDataset
Input dataset to interpolate.
Returns
-------
OneDimensionalDataset
New dataset with interpolated values.
Notes
-----
Creates a new dataset instance with interpolated data while
preserving metadata from the original dataset.
"""
# Get original data
x = dataset.independent_variable_data
y = dataset.dependent_variable_data
# Convert to backend arrays
x = jnp.asarray(x)
y = jnp.asarray(y)
# Perform interpolation
if BACKEND == "jax":
# JAX interpolation (JIT-compiled)
new_y = jnp.interp(self.new_x, x, y)
else:
# NumPy fallback
new_y = np.interp(self.new_x, x, y)
# Create new dataset with interpolated data
# Preserve metadata from original dataset
return OneDimensionalDataset(
independent_variable_data=self.new_x,
dependent_variable_data=new_y,
conditions=dataset.conditions,
details=dataset.details,
)
__all__ = ["Interpolate1D"]