# Source code for piblin_jax.transform.lambda_transform
"""
Lambda transforms and dynamic parameter transforms for piblin-jax.

This module provides:

- LambdaTransform: Wraps arbitrary functions as transforms
- DynamicTransform: Base class for data-driven parameter computation
- AutoScaleTransform: Automatic data scaling to target range
- AutoBaselineTransform: Automatic baseline correction from data

Lambda transforms allow users to create custom transforms from simple functions
without defining new classes. Dynamic transforms compute parameters from the
data itself, enabling adaptive processing.
"""
from collections.abc import Callable
from typing import Any
from piblin_jax.backend import jnp
from piblin_jax.backend.operations import jit
from piblin_jax.data.datasets import OneDimensionalDataset
from piblin_jax.transform.base import DatasetTransform
[docs]
class LambdaTransform(DatasetTransform):
    """
    Transform that wraps an arbitrary array function.

    Lets users build a custom transform from a plain function instead of
    defining a new class. The wrapped function either receives only the
    dependent variable, or both the independent and dependent variables.

    Parameters
    ----------
    func : Callable
        Function producing the new dependent variable.
        Signature: func(y_data: ndarray) -> ndarray
        Or: func(x_data: ndarray, y_data: ndarray) -> ndarray
    use_x : bool, default=False
        If True, pass both x and y to func.
        If False, pass only y to func.
    jit_compile : bool, default=True
        If True, attempt JIT compilation of the function.
        Falls back to the regular function if compilation fails, either
        when the function is wrapped or the first time it is called.
    lambda_func : Callable, optional
        Alias for func (for piblin compatibility). Takes precedence over
        func when both are given.

    Examples
    --------
    >>> from piblin_jax.transform import LambdaTransform
    >>> from piblin_jax.data.datasets import OneDimensionalDataset
    >>> import numpy as np
    >>>
    >>> # Simple function
    >>> transform = LambdaTransform(lambda y: y * 2.0)
    >>> dataset = OneDimensionalDataset(
    ...     independent_variable_data=np.array([1, 2, 3]),
    ...     dependent_variable_data=np.array([2, 4, 6])
    ... )
    >>> result = transform.apply_to(dataset)
    >>>
    >>> # Function using x and y
    >>> transform = LambdaTransform(
    ...     lambda x, y: y / x.max(),
    ...     use_x=True
    ... )
    >>>
    >>> # JAX-compatible function for JIT
    >>> import jax.numpy as jnp
    >>> transform = LambdaTransform(
    ...     lambda y: jnp.exp(y) * jnp.sin(y),
    ...     jit_compile=True
    ... )

    Notes
    -----
    - Functions should use jnp instead of np for JIT compilation
    - JIT compilation improves performance but requires JAX-compatible code
    - If compilation fails, the transform falls back to the uncompiled function
    - Only works with OneDimensionalDataset
    """

    def __init__(
        self,
        func: Callable[..., Any] | None = None,
        use_x: bool = False,
        jit_compile: bool = True,
        lambda_func: Callable[..., Any] | None = None,
    ):
        """
        Initialize lambda transform.

        Parameters
        ----------
        func : Callable, optional
            Function to wrap as transform (works on arrays)
        use_x : bool
            If True, pass both x and y to func
        jit_compile : bool
            If True, attempt JIT compilation
        lambda_func : Callable, optional
            Alias for func (for piblin compatibility); overrides func when
            both are supplied

        Raises
        ------
        TypeError
            If func is not callable
        ValueError
            If neither func nor lambda_func is provided

        Notes
        -----
        The function should operate on arrays, not dataset objects.
        For example: `lambda y: y * 2.0` not `lambda ds: ds.dependent_variable_data * 2.0`
        """
        super().__init__()

        # Accept either func or lambda_func (for piblin compatibility).
        # lambda_func deliberately wins when both are passed.
        if lambda_func is not None:
            func = lambda_func
        if func is None:
            raise ValueError("Either func or lambda_func must be provided")
        if not callable(func):
            raise TypeError("func must be callable")

        self.func = func
        self.use_x = use_x
        self.jit_compile = jit_compile

        # Callable actually invoked by _apply: the JIT wrapper when it could
        # be created, otherwise the plain function.
        self._compiled_func: Callable[..., Any]
        if jit_compile:
            try:
                self._compiled_func = jit(func)
            except Exception:
                # Wrapping itself failed; use the regular function.
                self._compiled_func = func
        else:
            self._compiled_func = func

    def _call_func(self, *arrays: Any) -> Any:
        """
        Invoke the wrapped function, falling back to the uncompiled version.

        NOTE(review): assuming ``jit`` behaves like ``jax.jit``, wrapping is
        lazy and tracing/compilation errors only surface on the first call,
        so the fallback promised in the class docs has to happen here rather
        than only in ``__init__`` -- confirm against the backend's ``jit``.

        Parameters
        ----------
        *arrays
            Positional array arguments forwarded to the function.

        Returns
        -------
        Any
            Whatever the wrapped function returns.

        Raises
        ------
        Exception
            Any error raised by the uncompiled function itself.
        """
        try:
            return self._compiled_func(*arrays)
        except Exception:
            # Broad on purpose: tracing failures surface as many exception
            # types. If we were already running the plain function there is
            # nothing to fall back to, so the error is genuine.
            if self._compiled_func is self.func:
                raise
            # Stop retrying the JIT path on later calls. If the plain
            # function fails as well, its error propagates to the caller.
            self._compiled_func = self.func
            return self.func(*arrays)

    def _apply(self, dataset: OneDimensionalDataset) -> OneDimensionalDataset:  # type: ignore[override]
        """
        Apply lambda function to dataset.

        Parameters
        ----------
        dataset : OneDimensionalDataset
            Dataset to transform

        Returns
        -------
        OneDimensionalDataset
            Transformed dataset (same object, modified in-place)

        Raises
        ------
        TypeError
            If dataset is not OneDimensionalDataset
        """
        if not isinstance(dataset, OneDimensionalDataset):
            raise TypeError("LambdaTransform only works with OneDimensionalDataset")

        # Work on the backend arrays directly.
        y_data = dataset._dependent_variable_data

        if self.use_x:
            x_data = dataset._independent_variable_data
            y_transformed = self._call_func(x_data, y_data)
        else:
            y_transformed = self._call_func(y_data)

        # Only the dependent variable is replaced; x is left untouched.
        dataset._dependent_variable_data = y_transformed
        return dataset
[docs]
class DynamicTransform(DatasetTransform):
    """
    Base class for transforms whose parameters come from the data.

    Applying a dynamic transform is a two-step process: parameters are first
    derived from the dataset, then the dataset is transformed using those
    parameters. This allows processing that adapts to each dataset.

    Subclasses must implement:

    - _compute_parameters(dataset): Extract parameters from data
    - _apply_with_parameters(dataset, params): Apply transformation

    Parameters
    ----------
    None

    Examples
    --------
    >>> from piblin_jax.transform.lambda_transform import DynamicTransform
    >>> from piblin_jax.backend import jnp
    >>>
    >>> class CustomDynamicTransform(DynamicTransform):
    ...     def _compute_parameters(self, dataset):
    ...         y_data = dataset._dependent_variable_data
    ...         return {'mean': jnp.mean(y_data)}
    ...
    ...     def _apply_with_parameters(self, dataset, params):
    ...         dataset._dependent_variable_data = (
    ...             dataset._dependent_variable_data - params['mean']
    ...         )
    ...         return dataset

    Notes
    -----
    - Parameters are recomputed on every application
    - ``_cached_params`` is initialised to None and not otherwise used by
      this class; subclasses may use it to implement caching
    - Intended for OneDimensionalDataset
    """

    def __init__(self) -> None:
        """Initialize dynamic transform with an empty parameter cache."""
        super().__init__()
        # Slot reserved for subclasses that want to cache parameters.
        self._cached_params: dict[str, Any] | None = None

    def _apply(self, dataset: OneDimensionalDataset) -> OneDimensionalDataset:  # type: ignore[override]
        """
        Derive parameters from the dataset and transform it with them.

        Parameters
        ----------
        dataset : OneDimensionalDataset
            Dataset to transform

        Returns
        -------
        OneDimensionalDataset
            Whatever ``_apply_with_parameters`` returns for this dataset
        """
        return self._apply_with_parameters(dataset, self._compute_parameters(dataset))

    def _compute_parameters(self, dataset: OneDimensionalDataset) -> dict[str, Any]:
        """
        Derive transformation parameters from a dataset.

        Abstract hook: subclasses must override it.

        Parameters
        ----------
        dataset : OneDimensionalDataset
            Dataset to analyze

        Returns
        -------
        dict
            Parameters extracted from data

        Raises
        ------
        NotImplementedError
            Always, in this base implementation
        """
        raise NotImplementedError("Subclasses must implement _compute_parameters")

    def _apply_with_parameters(
        self, dataset: OneDimensionalDataset, params: dict[str, Any]
    ) -> OneDimensionalDataset:
        """
        Transform a dataset using previously computed parameters.

        Abstract hook: subclasses must override it.

        Parameters
        ----------
        dataset : OneDimensionalDataset
            Dataset to transform
        params : dict
            Parameters produced by ``_compute_parameters``

        Returns
        -------
        OneDimensionalDataset
            Transformed dataset

        Raises
        ------
        NotImplementedError
            Always, in this base implementation
        """
        raise NotImplementedError("Subclasses must implement _apply_with_parameters")
[docs]
class AutoScaleTransform(DynamicTransform):
    """
    Linearly rescale the dependent variable onto a target range.

    The minimum and maximum of the data are measured on each application and
    mapped to ``target_min`` and ``target_max`` respectively. Typical targets
    are [0, 1] or [-1, 1].

    Parameters
    ----------
    target_min : float, default=0.0
        Target minimum value after scaling
    target_max : float, default=1.0
        Target maximum value after scaling

    Examples
    --------
    >>> from piblin_jax.transform import AutoScaleTransform
    >>> from piblin_jax.data.datasets import OneDimensionalDataset
    >>> import numpy as np
    >>>
    >>> # Scale to [0, 1]
    >>> transform = AutoScaleTransform()
    >>> dataset = OneDimensionalDataset(
    ...     independent_variable_data=np.array([1, 2, 3]),
    ...     dependent_variable_data=np.array([10, 20, 30])
    ... )
    >>> result = transform.apply_to(dataset)
    >>> # result.dependent_variable_data is now [0, 0.5, 1]
    >>>
    >>> # Scale to [-1, 1]
    >>> transform = AutoScaleTransform(target_min=-1.0, target_max=1.0)

    Notes
    -----
    - Constant data (min == max) uses scale 0 and offset target_min, so every
      value becomes target_min
    - The mapping is affine: y * scale + offset
    - Computed parameters: 'scale' (multiplicative factor) and 'offset' (additive shift)
    """

    def __init__(self, target_min: float = 0.0, target_max: float = 1.0):
        """
        Initialize auto-scale transform.

        Parameters
        ----------
        target_min : float
            Value the data minimum is mapped to
        target_max : float
            Value the data maximum is mapped to
        """
        super().__init__()
        self.target_min = target_min
        self.target_max = target_max

    def _compute_parameters(self, dataset: OneDimensionalDataset) -> dict[str, Any]:
        """
        Measure the data range and derive the affine map onto the target range.

        Parameters
        ----------
        dataset : OneDimensionalDataset
            Dataset to analyze

        Returns
        -------
        dict
            Dictionary with 'scale' and 'offset' parameters
        """
        values = dataset._dependent_variable_data
        lowest = jnp.min(values)
        highest = jnp.max(values)

        # Flat data has zero span; skip the division and pin everything to
        # target_min instead.
        if highest == lowest:
            return {"scale": 0.0, "offset": self.target_min}

        factor = (self.target_max - self.target_min) / (highest - lowest)
        shift = self.target_min - lowest * factor
        return {"scale": factor, "offset": shift}

    def _apply_with_parameters(
        self, dataset: OneDimensionalDataset, params: dict[str, Any]
    ) -> OneDimensionalDataset:
        """
        Rescale the dependent variable with the given scale and offset.

        Parameters
        ----------
        dataset : OneDimensionalDataset
            Dataset to transform
        params : dict
            Parameters with 'scale' and 'offset'

        Returns
        -------
        OneDimensionalDataset
            The same dataset object with its dependent variable replaced
        """
        dataset._dependent_variable_data = (
            dataset._dependent_variable_data * params["scale"] + params["offset"]
        )
        return dataset
[docs]
class AutoBaselineTransform(DynamicTransform):
    """
    Automatically subtract baseline computed from data.

    Computes baseline from first/last N points or minimum value,
    then subtracts it from all data. Useful for removing offsets
    and drift in experimental measurements.

    Parameters
    ----------
    n_points : int, default=10
        Number of points to use for baseline computation
        (only used for 'first' and 'last' methods, where it must be >= 1)
    method : str, default='first'
        Method for computing baseline:

        - 'first': Mean of first n_points
        - 'last': Mean of last n_points
        - 'min': Minimum value in data

    Examples
    --------
    >>> from piblin_jax.transform import AutoBaselineTransform
    >>> from piblin_jax.data.datasets import OneDimensionalDataset
    >>> import numpy as np
    >>>
    >>> # Subtract baseline from first 10 points
    >>> transform = AutoBaselineTransform(n_points=10, method='first')
    >>> dataset = OneDimensionalDataset(
    ...     independent_variable_data=np.arange(100),
    ...     dependent_variable_data=np.random.randn(100) + 5.0  # offset by 5
    ... )
    >>> result = transform.apply_to(dataset)
    >>>
    >>> # Subtract minimum value
    >>> transform = AutoBaselineTransform(method='min')

    Notes
    -----
    - 'first' method: Good for time series where initial values are baseline
    - 'last' method: Good for measurements that return to baseline
    - 'min' method: Good for ensuring all values are non-negative
    - Computed parameter: 'baseline' (value to subtract)
    - ``method`` and ``n_points`` are validated when the transform is applied,
      not when it is constructed
    """

    def __init__(self, n_points: int = 10, method: str = "first"):
        """
        Initialize auto-baseline transform.

        Parameters
        ----------
        n_points : int
            Number of points for baseline
        method : str
            Baseline computation method ('first', 'last', 'min')
        """
        super().__init__()
        self.n_points = n_points
        self.method = method

    def _compute_parameters(self, dataset: OneDimensionalDataset) -> dict[str, Any]:
        """
        Compute baseline from data.

        Parameters
        ----------
        dataset : OneDimensionalDataset
            Dataset to analyze

        Returns
        -------
        dict
            Dictionary with 'baseline' parameter

        Raises
        ------
        ValueError
            If method is not 'first', 'last', or 'min', or if n_points is
            less than 1 for the 'first'/'last' methods
        """
        y_data = dataset._dependent_variable_data

        if self.method == "min":
            # n_points is irrelevant here, so it is not validated.
            return {"baseline": jnp.min(y_data)}

        if self.method not in ("first", "last"):
            raise ValueError(f"Unknown method: {self.method}")

        # Guard the slice bounds: y[-0:] is the WHOLE array (not an empty
        # tail), y[:0] is empty (its mean is undefined), and a negative count
        # flips the meaning of the slice. Fail loudly instead of returning a
        # silently wrong baseline.
        if self.n_points < 1:
            raise ValueError(
                f"n_points must be at least 1 for method '{self.method}', got {self.n_points}"
            )

        # Assuming NumPy-style slicing on the backend array, a count larger
        # than the data simply selects every point.
        if self.method == "first":
            window = y_data[: self.n_points]
        else:
            window = y_data[-self.n_points :]
        return {"baseline": jnp.mean(window)}

    def _apply_with_parameters(
        self, dataset: OneDimensionalDataset, params: dict[str, Any]
    ) -> OneDimensionalDataset:
        """
        Subtract baseline.

        Parameters
        ----------
        dataset : OneDimensionalDataset
            Dataset to transform
        params : dict
            Parameters with 'baseline'

        Returns
        -------
        OneDimensionalDataset
            Baseline-corrected dataset (same object, dependent variable replaced)
        """
        y_data = dataset._dependent_variable_data
        y_corrected = y_data - params["baseline"]
        dataset._dependent_variable_data = y_corrected
        return dataset
# Public API of this module, in alphabetical order.
__all__ = ["AutoBaselineTransform", "AutoScaleTransform", "DynamicTransform", "LambdaTransform"]