Transformations
Overview
The piblin_jax.transform module provides a comprehensive framework for data processing
and manipulation in piblin_jax. Built on functional programming principles, the transform
system enables composable, reusable, and efficient data pipelines with JAX optimization
support.
The transform architecture is designed around several key principles:
Hierarchy-Aware: Transforms operate at different levels of the data hierarchy (Dataset, Measurement, MeasurementSet, etc.). Each transform level knows how to process its corresponding data structure while preserving metadata and relationships.
Composability: Transforms can be chained together using the Pipeline pattern, allowing complex operations to be built from simple, well-tested components. Pipelines themselves are transforms, enabling recursive composition.
JAX Optimization: The framework supports lazy evaluation and JIT compilation for high-performance numerical processing. Transform pipelines can be compiled once and executed efficiently on GPU/TPU devices.
Extensibility: Custom transforms can be created by subclassing base transform classes or by wrapping arbitrary functions using LambdaTransform. The framework also supports dynamic transforms that adapt parameters based on input data.
The module includes built-in transforms for common operations like smoothing, interpolation, baseline correction, normalization, and calculus operations, as well as collection-level transforms for filtering, splitting, and merging datasets.
Quick Examples
Basic Transform Pipeline
Chain multiple transforms together for sequential processing:
from piblin_jax.transform import Pipeline
from piblin_jax.transform.dataset import SmoothingTransform, NormalizeTransform
# Create a pipeline; transforms run in list order
pipeline = Pipeline([
    SmoothingTransform(window_length=5, polyorder=2),
    NormalizeTransform(method="minmax"),
])
# Apply to an existing dataset (a copy is made by default, make_copy=True)
processed_dataset = pipeline.apply_to(dataset)
Lazy Evaluation for Performance
Use lazy pipelines for JIT compilation and deferred execution:
from piblin_jax.transform import LazyPipeline
from piblin_jax.transform.dataset import InterpolateTransform, BaselineTransform
# Create lazy pipeline
lazy_pipeline = LazyPipeline([
    InterpolateTransform(num_points=1000),
    BaselineTransform(method="polynomial", degree=2),
])
# Apply lazily: returns a LazyResult, nothing is computed yet
result = lazy_pipeline.apply_to(dataset)
y_values = result.y_data  # First attribute access triggers (and caches) the computation
Custom Lambda Transforms
Create custom transforms from functions:
import jax.numpy as jnp
from piblin_jax.transform import LambdaTransform, Pipeline
from piblin_jax.transform.dataset import NormalizeTransform
# The wrapped function receives the dependent-variable array, not the dataset
def log_scale(y):
    # Use jnp (not np) so the function stays JIT-compatible
    return jnp.log10(y)
# Wrap as transform
log_transform = LambdaTransform(log_scale)
# Use in pipeline
pipeline = Pipeline([
    log_transform,
    NormalizeTransform(),
])
See Also
Data Structures - Data structures that transforms operate on
Backend Abstraction - JAX/NumPy backend for array operations
JAX Transformations - JAX JIT compilation
API Reference
Module Contents
Transform system for piblin-jax.
This module provides the transform framework for data processing:
- Base transform classes for each hierarchy level
- Pipeline composition for sequential transforms
- Lazy evaluation for JAX optimization
- JIT compilation support
- Region-based transforms for selective processing
- Lambda transforms for user-defined functions
- Dynamic transforms for data-driven parameters
- Core dataset transforms (interpolation, smoothing, normalization, etc.)
- Collection-level transforms (filtering, splitting, merging)
Hierarchy:
- Transform: Abstract base class
- DatasetTransform: Operates on Dataset objects
- MeasurementTransform: Operates on Measurement objects
- MeasurementSetTransform: Operates on MeasurementSet objects
- ExperimentTransform: Operates on Experiment objects
- ExperimentSetTransform: Operates on ExperimentSet objects
Pipeline:
- Pipeline: Sequential composition of transforms
- LazyPipeline: Pipeline with lazy evaluation
Region-Based:
- RegionTransform: Base class for region-based transforms
- RegionMultiplyTransform: Example region-based transform
Lambda and Dynamic:
- LambdaTransform: Wrap arbitrary functions as transforms
- DynamicTransform: Base class for data-driven transforms
- AutoScaleTransform: Automatic data scaling
- AutoBaselineTransform: Automatic baseline correction
Dataset Transforms:
dataset: Module containing core dataset-level transforms (interpolation, smoothing, baseline correction, normalization, calculus)
Measurement Transforms:
measurement: Module containing collection-level transforms (FilterDatasets, FilterMeasurements, SplitByRegion, MergeReplicates)
- class piblin_jax.transform.AutoBaselineTransform(n_points=10, method='first')
Bases: DynamicTransform
Automatically subtract a baseline computed from the data.
Computes the baseline from the first/last N points or from the minimum value, then subtracts it from all data. Useful for removing offsets and drift in experimental measurements.
- Parameters:
n_points – Number of leading or trailing points used by the 'first' and 'last' methods (default 10).
method – Baseline strategy: 'first', 'last', or 'min' (default 'first').
Examples
>>> from piblin_jax.transform import AutoBaselineTransform
>>> from piblin_jax.data.datasets import OneDimensionalDataset
>>> import numpy as np
>>>
>>> # Subtract baseline from first 10 points
>>> transform = AutoBaselineTransform(n_points=10, method='first')
>>> dataset = OneDimensionalDataset(
...     independent_variable_data=np.arange(100),
...     dependent_variable_data=np.random.randn(100) + 5.0  # offset by 5
... )
>>> result = transform.apply_to(dataset)
>>>
>>> # Subtract minimum value
>>> transform = AutoBaselineTransform(method='min')
Notes
‘first’ method: Good for time series where initial values are baseline
‘last’ method: Good for measurements that return to baseline
‘min’ method: Good for ensuring all values are non-negative
Computed parameter: ‘baseline’ (value to subtract)
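The three baseline strategies listed above can be sketched in plain NumPy. This is an illustration under assumptions, not piblin_jax's implementation: the helper names compute_baseline and subtract_baseline are hypothetical, and averaging the first/last n_points values (rather than, say, taking their median) is an assumed choice.

```python
import numpy as np


def compute_baseline(y, n_points=10, method="first"):
    """Hypothetical helper: estimate a scalar baseline from 1D data.

    Assumes 'first'/'last' average the leading/trailing n_points values
    and 'min' takes the global minimum.
    """
    y = np.asarray(y, dtype=float)
    if method == "first":
        return y[:n_points].mean()
    if method == "last":
        return y[-n_points:].mean()
    if method == "min":
        return y.min()
    raise ValueError(f"unknown method: {method!r}")


def subtract_baseline(y, n_points=10, method="first"):
    # The baseline is recomputed from the data on every call (a dynamic parameter).
    baseline = compute_baseline(y, n_points=n_points, method=method)
    return np.asarray(y, dtype=float) - baseline


y = np.array([5.0, 5.0, 6.0, 9.0, 5.0])
first = subtract_baseline(y, n_points=2, method="first")  # baseline = mean(5, 5) = 5
last = subtract_baseline(y, n_points=2, method="last")    # baseline = mean(9, 5) = 7
lowest = subtract_baseline(y, method="min")               # baseline = 5
```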
Methods
__call__(target[, make_copy, ...]) – Shorthand for apply_to.
apply_to(target[, make_copy, ...]) – Apply transform to Dataset.
- class piblin_jax.transform.AutoScaleTransform(target_min=0.0, target_max=1.0)
Bases: DynamicTransform
Automatically scale data to a specified range.
Computes min/max from data and scales to target range. Useful for normalizing data to standard ranges like [0, 1] or [-1, 1].
- Parameters:
target_min (float, default 0.0) – Lower bound of the target range.
target_max (float, default 1.0) – Upper bound of the target range.
Examples
>>> from piblin_jax.transform import AutoScaleTransform
>>> from piblin_jax.data.datasets import OneDimensionalDataset
>>> import numpy as np
>>>
>>> # Scale to [0, 1]
>>> transform = AutoScaleTransform()
>>> dataset = OneDimensionalDataset(
...     independent_variable_data=np.array([1, 2, 3]),
...     dependent_variable_data=np.array([10, 20, 30])
... )
>>> result = transform.apply_to(dataset)
>>> # result.dependent_variable_data is now [0, 0.5, 1]
>>>
>>> # Scale to [-1, 1]
>>> transform = AutoScaleTransform(target_min=-1.0, target_max=1.0)
Notes
Handles constant data (where min == max) by setting all values to target_min
Preserves data ordering and relative differences
Computed parameters: ‘scale’ (multiplicative factor) and ‘offset’ (additive shift)
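One way to realize the 'scale' and 'offset' parameters is the linear map y' = y * scale + offset, with scale = (target_max - target_min) / (max - min) and offset = target_min - min * scale. The sketch below assumes that formula (it is not taken from piblin_jax); the helper names are hypothetical, and the constant-data rule follows the Notes above.

```python
import numpy as np


def compute_scale_params(y, target_min=0.0, target_max=1.0):
    """Hypothetical helper: return (scale, offset) mapping [min, max] onto the target range."""
    y = np.asarray(y, dtype=float)
    y_min, y_max = y.min(), y.max()
    if y_max == y_min:
        # Constant data: send every value to target_min, as described in the Notes.
        return 0.0, target_min
    scale = (target_max - target_min) / (y_max - y_min)
    offset = target_min - y_min * scale
    return scale, offset


def auto_scale(y, target_min=0.0, target_max=1.0):
    scale, offset = compute_scale_params(y, target_min, target_max)
    return np.asarray(y, dtype=float) * scale + offset


unit = auto_scale([10, 20, 30])                  # maps onto [0, 1]
symmetric = auto_scale([10, 20, 30], -1.0, 1.0)  # maps onto [-1, 1]
```

Because the map is affine with a positive scale, ordering and relative differences are preserved, which matches the Notes.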
Methods
__call__(target[, make_copy, ...]) – Shorthand for apply_to.
apply_to(target[, make_copy, ...]) – Apply transform to Dataset.
- class piblin_jax.transform.DatasetTransform
Bases: Transform[Dataset]
Transform that operates on Dataset objects.
This is the lowest level transform in the hierarchy, operating on individual datasets (1D, 2D, 3D, etc.).
Examples
>>> from piblin_jax.transform.base import DatasetTransform
>>> from piblin_jax.data.datasets import OneDimensionalDataset
>>> import numpy as np
>>>
>>> class SmoothTransform(DatasetTransform):
...     def _apply(self, dataset):
...         # Smooth y-values with moving average
...         dataset.y_data = np.convolve(
...             dataset.y_data,
...             np.ones(3)/3,
...             mode='same'
...         )
...         return dataset
>>>
>>> dataset = OneDimensionalDataset(
...     x_data=np.array([1, 2, 3, 4, 5]),
...     y_data=np.array([1, 5, 2, 8, 3])
... )
>>> transform = SmoothTransform()
>>> smoothed = transform.apply_to(dataset)
Methods
__call__(target[, make_copy, ...]) – Shorthand for apply_to.
apply_to(target[, make_copy, ...]) – Apply transform to Dataset.
- class piblin_jax.transform.DynamicTransform
Bases: DatasetTransform
Base class for transforms with data-driven parameters.
Dynamic transforms compute parameters from the data itself, then apply transformations using those parameters. This enables adaptive processing where the transformation depends on data characteristics.
Subclasses must implement:
- _compute_parameters(dataset): Extract parameters from data
- _apply_with_parameters(dataset, params): Apply transformation
- Parameters:
None
Examples
>>> from piblin_jax.transform.lambda_transform import DynamicTransform
>>> from piblin_jax.backend import jnp
>>>
>>> class CustomDynamicTransform(DynamicTransform):
...     def _compute_parameters(self, dataset):
...         y_data = dataset._dependent_variable_data
...         return {'mean': jnp.mean(y_data)}
...
...     def _apply_with_parameters(self, dataset, params):
...         dataset._dependent_variable_data -= params['mean']
...         return dataset
Notes
Parameters are computed fresh each time the transform is applied
Caching can be implemented in subclasses if needed
Only works with OneDimensionalDataset
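The two-hook design above can be sketched without piblin_jax. In this illustration the base class DynamicArrayTransform and the subclass MeanCenter are hypothetical stand-ins that act on plain 1D arrays instead of datasets; the point is that apply_to always calls _compute_parameters before _apply_with_parameters, so parameters are recomputed on every call.

```python
from abc import ABC, abstractmethod

import numpy as np


class DynamicArrayTransform(ABC):
    """Hypothetical stand-in for the DynamicTransform pattern, acting on 1D arrays."""

    def apply_to(self, y):
        y = np.array(y, dtype=float)           # work on a copy
        params = self._compute_parameters(y)   # recomputed fresh on every call
        return self._apply_with_parameters(y, params)

    @abstractmethod
    def _compute_parameters(self, y):
        """Return a dict of parameters derived from the data."""

    @abstractmethod
    def _apply_with_parameters(self, y, params):
        """Apply the transformation using the computed parameters."""


class MeanCenter(DynamicArrayTransform):
    def _compute_parameters(self, y):
        return {"mean": y.mean()}

    def _apply_with_parameters(self, y, params):
        return y - params["mean"]


center = MeanCenter()
a = center.apply_to([1.0, 2.0, 6.0])  # mean is 3 for this input
b = center.apply_to([10.0, 20.0])     # mean is recomputed: 15
```

Keeping parameter estimation separate from application also makes caching easy to add in a subclass, as the Notes suggest.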
Methods
__call__(target[, make_copy, ...]) – Shorthand for apply_to.
apply_to(target[, make_copy, ...]) – Apply transform to Dataset.
- class piblin_jax.transform.ExperimentSetTransform
Bases: Transform[ExperimentSet]
Transform that operates on ExperimentSet objects.
ExperimentSets are the top-level container for multiple experiments. This transform level can operate across all experiments in a set.
Examples
>>> from piblin_jax.transform.base import ExperimentSetTransform
>>>
>>> class CrossExperimentNormalizeTransform(ExperimentSetTransform):
...     def _apply(self, experiment_set):
...         # Normalize across all experiments
...         global_max = 0
...         for exp in experiment_set.experiments.values():
...             for mset in exp.measurement_sets.values():
...                 for meas in mset.measurements.values():
...                     for ds in meas.datasets.values():
...                         if hasattr(ds, 'y_data'):
...                             global_max = max(global_max, ds.y_data.max())
...
...         for exp in experiment_set.experiments.values():
...             for mset in exp.measurement_sets.values():
...                 for meas in mset.measurements.values():
...                     for ds in meas.datasets.values():
...                         if hasattr(ds, 'y_data'):
...                             ds.y_data = ds.y_data / global_max
...         return experiment_set
Methods
__call__(target[, make_copy, ...]) – Shorthand for apply_to.
apply_to(target[, make_copy, ...]) – Apply transform to ExperimentSet.
- class piblin_jax.transform.ExperimentTransform
Bases: Transform[Experiment]
Transform that operates on Experiment objects.
Experiments contain multiple measurement sets. This transform level can operate across measurement sets within an experiment.
Examples
>>> from piblin_jax.transform.base import ExperimentTransform
>>>
>>> class TemperatureCorrectionTransform(ExperimentTransform):
...     def _apply(self, experiment):
...         # Apply temperature correction to all measurement sets
...         temp = experiment.metadata.get('temperature', 300)
...         correction = 1.0 + (temp - 300) * 0.001
...
...         for mset in experiment.measurement_sets.values():
...             for meas in mset.measurements.values():
...                 for ds in meas.datasets.values():
...                     if hasattr(ds, 'y_data'):
...                         ds.y_data = ds.y_data * correction
...         return experiment
Methods
__call__(target[, make_copy, ...]) – Shorthand for apply_to.
apply_to(target[, make_copy, ...]) – Apply transform to Experiment.
- class piblin_jax.transform.LambdaTransform(func=None, use_x=False, jit_compile=True, lambda_func=None)
Bases: DatasetTransform
Transform that wraps an arbitrary function.
Allows users to create custom transforms from simple functions without defining new classes. Supports both simple functions that operate on the dependent variable only, and functions that use both independent and dependent variables.
- Parameters:
func (Callable) – Function to apply to the dependent variable. Signature: func(y_data: ndarray) -> ndarray, or func(x_data: ndarray, y_data: ndarray) -> ndarray when use_x=True.
use_x (bool, default False) – If True, pass both x and y to func. If False, pass only y to func.
jit_compile (bool, default True) – If True, attempt JIT compilation of the function. Falls back to the regular function if compilation fails.
lambda_func (Callable, optional) – Accepted in place of func; a ValueError is raised if neither is provided.
Examples
>>> from piblin_jax.transform import LambdaTransform
>>> from piblin_jax.data.datasets import OneDimensionalDataset
>>> import numpy as np
>>>
>>> # Simple function
>>> transform = LambdaTransform(lambda y: y * 2.0)
>>> dataset = OneDimensionalDataset(
...     independent_variable_data=np.array([1, 2, 3]),
...     dependent_variable_data=np.array([2, 4, 6])
... )
>>> result = transform.apply_to(dataset)
>>>
>>> # Function using x and y
>>> transform = LambdaTransform(
...     lambda x, y: y / x.max(),
...     use_x=True
... )
>>>
>>> # JAX-compatible function for JIT
>>> import jax.numpy as jnp
>>> transform = LambdaTransform(
...     lambda y: jnp.exp(y) * jnp.sin(y),
...     jit_compile=True
... )
Notes
Functions should use jnp instead of np for JIT compilation
JIT compilation improves performance but requires JAX-compatible code
If compilation fails, the transform falls back to the uncompiled function
Only works with OneDimensionalDataset
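The parameter behaviour described above (pass y or (x, y), attempt JIT, fall back if compilation fails) can be sketched as a small wrapper factory. This is an assumption-laden illustration, not LambdaTransform's source: wrap_array_function is a hypothetical name, and JAX is treated as optional. One detail worth noting is that jax.jit traces lazily, so a non-traceable function fails on its first call rather than when jax.jit is applied; the fallback therefore wraps the call.

```python
import numpy as np

try:
    import jax
except ImportError:  # NumPy-only environment: no JIT available
    jax = None


def wrap_array_function(func, use_x=False, jit_compile=True):
    """Hypothetical sketch: return call(x, y) -> new_y with optional JIT and fallback."""
    compiled = jax.jit(func) if (jit_compile and jax is not None) else None
    state = {"use_jit": compiled is not None}

    def call(x, y):
        args = (x, y) if use_x else (y,)
        if state["use_jit"]:
            try:
                return compiled(*args)
            except Exception:
                # Tracing failed on this first call; use the plain function from now on.
                state["use_jit"] = False
        return func(*args)

    return call


x = np.array([1.0, 2.0, 4.0])
y = np.array([2.0, 4.0, 6.0])

double = wrap_array_function(lambda y: y * 2.0)
scale_by_x = wrap_array_function(lambda x, y: y / x.max(), use_x=True)


def not_traceable(y):
    # float() on a JAX tracer raises under jit, which forces the fallback path.
    return y * 2.0 if float(y.sum()) > 0 else y


fallback = wrap_array_function(not_traceable)
```

A genuine bug inside func still surfaces: after the compiled call fails, the plain call raises the same error.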
Methods
__call__(target[, make_copy, ...]) – Shorthand for apply_to.
apply_to(target[, make_copy, ...]) – Apply transform to Dataset.
- __init__(func=None, use_x=False, jit_compile=True, lambda_func=None)
Initialize lambda transform.
- Parameters:
func, use_x, jit_compile, lambda_func – As described for the class above.
- Raises:
TypeError – If func is not callable
ValueError – If neither func nor lambda_func is provided
Notes
The function should operate on arrays, not dataset objects. For example, write lambda y: y * 2.0, not lambda ds: ds.dependent_variable_data * 2.0.
- class piblin_jax.transform.LazyPipeline(transforms=None)
Bases: Pipeline[T]
Pipeline with lazy evaluation support.
Unlike the standard Pipeline, LazyPipeline defers computation until the results are actually accessed. This allows JAX to optimize the entire computation graph as a single operation.
Lazy evaluation is triggered on:
- Property access (e.g., result.y_data)
- Method calls (e.g., result.visualize())
- Export operations (e.g., result.export())
- Parameters:
transforms (list[Transform], optional) – Initial list of transforms to include in pipeline
Examples
>>> from piblin_jax.transform import LazyPipeline
>>>
>>> # Create lazy pipeline (MultiplyTransform as defined in the Pipeline example)
>>> pipeline = LazyPipeline([
...     MultiplyTransform(2.0),
...     MultiplyTransform(3.0),
... ])
>>>
>>> # Apply to dataset (computation deferred)
>>> lazy_result = pipeline.apply_to(dataset, make_copy=True)
>>>
>>> # Access property (triggers computation)
>>> y_values = lazy_result.y_data  # Computation happens here
Notes
Lazy evaluation allows JAX to optimize the entire pipeline
First property access triggers computation and caches result
Subsequent accesses use cached result
More efficient than eager evaluation for complex pipelines
Methods
__call__(target[, make_copy, ...]) – Shorthand for apply_to.
append(transform) – Add transform to end of pipeline.
apply_to(target[, make_copy, ...]) – Apply lazy pipeline to target.
clear() – Remove all transforms from the pipeline.
count(value) – Return the number of occurrences of value.
extend(values) – Append every transform from an iterable.
index(value, [start, [stop]]) – Return the first index of value; raises ValueError if the value is not present.
insert(index, value) – Insert transform at index.
invalidate_cache() – Invalidate cached results.
pop([index]) – Remove and return the transform at index (default: last); raises IndexError if the pipeline is empty or index is out of range.
remove(value) – Remove the first occurrence of value.
reverse() – Reverse the pipeline in place.
- __init__(transforms=None)
Initialize lazy pipeline.
- Parameters:
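transforms (list[Transform], optional) – Initial transforms to include in pipeline
- apply_to(target, make_copy=True, propagate_uncertainty=False)
Apply lazy pipeline to target.
Computation is deferred until results are accessed. Returns a LazyResult wrapper that triggers computation on property/method access.
- Parameters:
target (T) – Data structure to transform.
make_copy (bool, default True) – If True, copy the target once before the pipeline runs.
propagate_uncertainty (bool, default False) – If True and the target has uncertainty samples, propagate uncertainty through the pipeline.
- Returns:
Wrapper that triggers computation on access
- Return type:
LazyResult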
Notes
The actual transformation is not performed until the result is accessed. This allows JAX to optimize the entire computation graph.
- invalidate_cache()
Invalidate cached results.
Forces recomputation on next access. Useful if transforms have been modified or parameters changed.
Examples
>>> pipeline = LazyPipeline([transform1, transform2])
>>> result = pipeline.apply_to(dataset)
>>> _ = result.y_data  # Triggers computation
>>>
>>> # Modify pipeline
>>> pipeline.append(transform3)
>>> pipeline.invalidate_cache()  # Force recomputation
- class piblin_jax.transform.LazyResult(pipeline)
Bases: object
Wrapper that triggers lazy computation on property access.
This class wraps the actual result and defers computation until properties or methods are accessed.
- Parameters:
pipeline (LazyPipeline) – The lazy pipeline that will compute the result
Examples
>>> lazy_result = LazyResult(pipeline)
>>> # No computation yet
>>> y = lazy_result.y_data  # Triggers computation here
>>> # Subsequent accesses use cached result
>>> x = lazy_result.x_data  # No recomputation
Notes
This class is transparent to the user: it behaves like the actual result object but triggers computation on first access.
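The deferred-and-cached behaviour can be sketched with __getattr__, which Python calls only when normal attribute lookup fails, so the wrapper's own fields never trigger it. LazyValue is a hypothetical class that takes a zero-argument compute callable rather than a pipeline, and its invalidate method only mirrors the idea of invalidate_cache(); none of this is LazyResult's actual code.

```python
from types import SimpleNamespace


class LazyValue:
    """Hypothetical sketch: defer a computation until an attribute is accessed."""

    def __init__(self, compute):
        self._compute = compute
        self._result = None
        self._computed = False

    def _get(self):
        if not self._computed:
            self._result = self._compute()  # runs once, on first access
            self._computed = True
        return self._result

    def invalidate(self):
        # Force recomputation on the next access.
        self._computed = False
        self._result = None

    def __getattr__(self, name):
        # Reached only for names not defined on LazyValue itself.
        return getattr(self._get(), name)


calls = []


def expensive():
    calls.append(1)
    return SimpleNamespace(x_data=[1, 2, 3], y_data=[2, 4, 6])


lazy = LazyValue(expensive)  # nothing computed yet
y = lazy.y_data              # first access triggers the computation
x = lazy.x_data              # served from the cached result
```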
- __getattr__(name)
Get attribute from computed result.
Triggers computation on first access.
- Parameters:
name (str) – Attribute name
- Returns:
Attribute value from computed result
- Return type:
Any
- __init__(pipeline)
Initialize lazy result wrapper.
- Parameters:
pipeline (LazyPipeline) – Pipeline that will compute the result
- class piblin_jax.transform.MeasurementSetTransform
Bases: Transform[MeasurementSet]
Transform that operates on MeasurementSet objects.
MeasurementSets contain multiple measurements. This transform level can operate across measurements (e.g., normalization relative to the entire set).
Examples
>>> from piblin_jax.transform.base import MeasurementSetTransform
>>>
>>> class GlobalNormalizeTransform(MeasurementSetTransform):
...     def _apply(self, measurement_set):
...         # Find global max across all measurements
...         global_max = 0
...         for meas in measurement_set.measurements.values():
...             for ds in meas.datasets.values():
...                 if hasattr(ds, 'y_data'):
...                     global_max = max(global_max, ds.y_data.max())
...
...         # Normalize all datasets by global max
...         for meas in measurement_set.measurements.values():
...             for ds in meas.datasets.values():
...                 if hasattr(ds, 'y_data'):
...                     ds.y_data = ds.y_data / global_max
...         return measurement_set
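The same two-pass global normalization can be sketched functionally, without mutating the input. Here a measurement set is modelled, as an assumption for illustration only, as a nested dict {measurement_name: {dataset_name: y_array}}; global_normalize is a hypothetical helper. Dividing by the set-wide maximum keeps relative magnitudes between measurements, unlike per-dataset normalization.

```python
import numpy as np


def global_normalize(measurement_set):
    """Hypothetical sketch: divide every dataset's y by the set-wide maximum.

    Pass 1 finds the global maximum; pass 2 builds new, rescaled arrays,
    leaving the input dicts and arrays untouched.
    """
    global_max = max(
        float(np.max(y))
        for datasets in measurement_set.values()
        for y in datasets.values()
    )
    return {
        meas_name: {ds_name: np.asarray(y) / global_max for ds_name, y in datasets.items()}
        for meas_name, datasets in measurement_set.items()
    }


mset = {
    "run1": {"signal": np.array([1.0, 2.0]), "ref": np.array([0.5, 4.0])},
    "run2": {"signal": np.array([8.0, 2.0])},
}
normalized = global_normalize(mset)  # global maximum is 8.0
```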
Methods
__call__(target[, make_copy, ...]) – Shorthand for apply_to.
apply_to(target[, make_copy, ...]) – Apply transform to MeasurementSet.
- class piblin_jax.transform.MeasurementTransform
Bases: Transform[Measurement]
Transform that operates on Measurement objects.
Measurements contain multiple datasets with associated metadata. This transform level can operate on all datasets within a measurement.
Examples
>>> from piblin_jax.transform.base import MeasurementTransform
>>>
>>> class NormalizeTransform(MeasurementTransform):
...     def _apply(self, measurement):
...         # Normalize all datasets in measurement
...         for dataset in measurement.datasets.values():
...             if hasattr(dataset, 'y_data'):
...                 max_val = dataset.y_data.max()
...                 dataset.y_data = dataset.y_data / max_val
...         return measurement
Methods
__call__(target[, make_copy, ...]) – Shorthand for apply_to.
apply_to(target[, make_copy, ...]) – Apply transform to Measurement.
- class piblin_jax.transform.Pipeline(transforms=None)
Bases: Transform[T], MutableSequence[Transform[T]]
Pipeline for composing multiple transforms sequentially.
A pipeline applies a sequence of transforms to data in order. It implements the MutableSequence interface, so it can be used like a list of transforms.
The pipeline is memory-efficient: when make_copy=True, it creates a single copy at entry, then applies all transforms in-place.
- Parameters:
transforms (list[Transform], optional) – Initial list of transforms to include in pipeline
Examples
>>> from piblin_jax.transform import Pipeline, DatasetTransform
>>> from piblin_jax.data.datasets import OneDimensionalDataset
>>> import numpy as np
>>>
>>> # Create transforms
>>> class MultiplyTransform(DatasetTransform):
...     def __init__(self, factor):
...         super().__init__()
...         self.factor = factor
...
...     def _apply(self, dataset):
...         dataset.y_data = dataset.y_data * self.factor
...         return dataset
>>>
>>> # Create pipeline
>>> pipeline = Pipeline([
...     MultiplyTransform(2.0),
...     MultiplyTransform(3.0),  # Net effect: 6x
... ])
>>>
>>> # Apply to dataset
>>> dataset = OneDimensionalDataset(
...     x_data=np.array([1, 2, 3]),
...     y_data=np.array([2, 4, 6])
... )
>>> result = pipeline.apply_to(dataset, make_copy=True)
>>> # result.y_data is now [12, 24, 36]
Notes
Pipelines can be nested: a pipeline can contain other pipelines
Only one copy is made at entry, then all transforms apply in-place
This is much more memory efficient than copying at each step
Use lazy evaluation for even better performance with JAX
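The copy-once strategy and the MutableSequence interface can be sketched together. All class names here (ArrayTransform, Multiply, SimplePipeline) are hypothetical, and a plain dict stands in for a dataset; the sketch shows that one deep copy at entry is enough when every inner step runs with make_copy=False, and that a pipeline nests inside another because it is itself a transform.

```python
import copy
from collections.abc import MutableSequence

import numpy as np


class ArrayTransform:
    """Hypothetical minimal transform base: subclasses implement _apply (in place)."""

    def apply_to(self, target, make_copy=True):
        if make_copy:
            target = copy.deepcopy(target)
        return self._apply(target)

    def _apply(self, target):
        raise NotImplementedError


class Multiply(ArrayTransform):
    def __init__(self, factor):
        self.factor = factor

    def _apply(self, target):
        target["y"] = target["y"] * self.factor
        return target


class SimplePipeline(ArrayTransform, MutableSequence):
    def __init__(self, transforms=None):
        self._transforms = []
        for t in transforms or []:
            self.append(t)

    @staticmethod
    def _check(value):
        if not isinstance(value, ArrayTransform):
            raise TypeError(f"expected a transform, got {type(value).__name__}")
        return value

    # MutableSequence needs these five; append, extend, pop, remove,
    # reverse, clear, count and index are provided by the mixin.
    def __getitem__(self, index):
        return self._transforms[index]

    def __setitem__(self, index, value):  # single-index assignment only
        self._transforms[index] = self._check(value)

    def __delitem__(self, index):
        del self._transforms[index]

    def __len__(self):
        return len(self._transforms)

    def insert(self, index, value):
        self._transforms.insert(index, self._check(value))

    def _apply(self, target):
        # Any copy was made once in apply_to; each step mutates that copy.
        for transform in self._transforms:
            target = transform.apply_to(target, make_copy=False)
        return target


data = {"x": np.array([1.0, 2.0, 3.0]), "y": np.array([2.0, 4.0, 6.0])}
inner = SimplePipeline([Multiply(2.0)])
pipeline = SimplePipeline([inner, Multiply(3.0)])  # nested: net effect 6x
result = pipeline.apply_to(data)                   # one deep copy at entry
```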
Methods
__call__(target[, make_copy, ...]) – Shorthand for apply_to.
append(transform) – Add transform to end of pipeline.
apply_to(target[, make_copy, ...]) – Apply pipeline to target.
clear() – Remove all transforms from the pipeline.
count(value) – Return the number of occurrences of value.
extend(values) – Append every transform from an iterable.
index(value, [start, [stop]]) – Return the first index of value; raises ValueError if the value is not present.
insert(index, value) – Insert transform at index.
pop([index]) – Remove and return the transform at index (default: last); raises IndexError if the pipeline is empty or index is out of range.
remove(value) – Remove the first occurrence of value.
reverse() – Reverse the pipeline in place.
- __delitem__(index)
Delete transform(s) at index.
Examples
>>> pipeline = Pipeline([t1, t2, t3])
>>> del pipeline[0]  # Remove first transform; pipeline is now [t2, t3]
>>> del pipeline[1:]  # Remove all but the first transform; pipeline is now [t2]
- __getitem__(index)
Get transform(s) at index.
- Parameters:
index (int or slice) – Position or slice of transforms to retrieve.
- Returns:
Transform at index, or list of transforms for slice
- Return type:
Transform or list[Transform]
Examples
>>> pipeline = Pipeline([t1, t2, t3])
>>> pipeline[0]  # Get first transform
>>> pipeline[1:3]  # Get slice of transforms
- __init__(transforms=None)
Initialize pipeline.
- Parameters:
transforms (list[Transform], optional) – Initial transforms to include in pipeline
- __len__()
Get number of transforms in pipeline.
- Returns:
Number of transforms
- Return type:
int
Examples
>>> pipeline = Pipeline([t1, t2, t3])
>>> len(pipeline)
3
- __repr__()
String representation of pipeline.
- Returns:
String representation showing number of transforms
- Return type:
str
Examples
>>> pipeline = Pipeline([t1, t2, t3])
>>> repr(pipeline)
'Pipeline(3 transforms)'
- __setitem__(index, value)
Set transform(s) at index.
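- Parameters:
index (int or slice) – Position or slice to replace.
value (Transform or list[Transform]) – Replacement transform(s).
- Raises:
TypeError – If value is not a Transform instance
Examples
>>> pipeline = Pipeline([t1, t2, t3])
>>> pipeline[0] = new_transform  # Replace first transform
- __str__()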
Human-readable string representation.
- Returns:
String showing all transforms in pipeline
- Return type:
str
- append(transform)
Add transform to end of pipeline.
- Parameters:
transform (Transform) – Transform to append
- Raises:
TypeError – If transform is not a Transform instance
Examples
>>> pipeline = Pipeline([t1, t2])
>>> pipeline.append(t3)  # Add t3 to end
- apply_to(target, make_copy=True, propagate_uncertainty=False)
Apply pipeline to target.
Only makes copy once at entry, then applies all transforms in-place for memory efficiency.
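- Parameters:
target (T) – Data structure to transform.
make_copy (bool, default True) – If True, copy the target once at entry, then apply all transforms to that copy in place.
propagate_uncertainty (bool, default False) – If True and the target has uncertainty samples, propagate uncertainty through the whole pipeline.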
- Returns:
Transformed data structure
- Return type:
T
Notes
This is much more efficient than copying at each transform step. The single copy at entry ensures immutability while minimizing memory overhead.
When propagate_uncertainty=True, uncertainty is efficiently propagated through the entire pipeline in a single pass.
- class piblin_jax.transform.RegionMultiplyTransform(region, factor)
Bases: RegionTransform
Example transform: Multiply region by a factor.
This is a concrete implementation of RegionTransform that multiplies the dependent variable within the specified region(s) by a constant factor.
- Parameters:
region (LinearRegion | CompoundRegion) – Region(s) to transform
factor (float) – Multiplication factor
Examples
>>> import numpy as np
>>> from piblin_jax.data.datasets import OneDimensionalDataset
>>> from piblin_jax.data.roi import LinearRegion, CompoundRegion
>>> from piblin_jax.transform.region import RegionMultiplyTransform
>>> # Single region example
>>> x_data = np.linspace(0, 10, 11)
>>> y_data = np.ones(11)
>>> dataset = OneDimensionalDataset(
...     independent_variable_data=x_data,
...     dependent_variable_data=y_data
... )
>>> region = LinearRegion(x_min=3.0, x_max=7.0)
>>> transform = RegionMultiplyTransform(region, factor=2.0)
>>> result = transform.apply_to(dataset, make_copy=True)
>>> # Points in [3, 7] are multiplied by 2.0, others unchanged
>>> # Multiple disjoint regions example
>>> region1 = LinearRegion(x_min=1.0, x_max=2.0)
>>> region2 = LinearRegion(x_min=8.0, x_max=9.0)
>>> compound = CompoundRegion([region1, region2])
>>> transform = RegionMultiplyTransform(compound, factor=0.5)
>>> result = transform.apply_to(dataset, make_copy=True)
>>> # Points in [1, 2] OR [8, 9] are multiplied by 0.5
Notes
This is a simple example transform for demonstration and testing. More complex transforms can be implemented following the same pattern.
Methods
__call__(target[, make_copy, ...]) – Shorthand for apply_to.
apply_to(target[, make_copy, ...]) – Apply transform to Dataset.
- class piblin_jax.transform.RegionTransform(region)
Bases: DatasetTransform
Base class for transforms that operate on specific regions.
RegionTransform applies a transformation only within specified region(s), preserving data outside the regions. This enables selective processing of data based on independent variable ranges.
Subclasses should implement the _apply_to_region() method to define the specific transformation to apply within the region(s).
- Parameters:
region (LinearRegion | CompoundRegion) – Region(s) to transform
Examples
>>> import numpy as np
>>> from piblin_jax.data.datasets import OneDimensionalDataset
>>> from piblin_jax.data.roi import LinearRegion
>>> from piblin_jax.transform.region import RegionMultiplyTransform
>>> # Create dataset
>>> x_data = np.array([0, 1, 2, 3, 4, 5])
>>> y_data = np.array([1, 1, 1, 1, 1, 1])
>>> dataset = OneDimensionalDataset(
...     independent_variable_data=x_data,
...     dependent_variable_data=y_data
... )
>>> # Define region and transform
>>> region = LinearRegion(x_min=2.0, x_max=4.0)
>>> transform = RegionMultiplyTransform(region, factor=2.0)
>>> # Apply transform (only region [2, 4] is multiplied)
>>> result = transform.apply_to(dataset, make_copy=True)
>>> result.dependent_variable_data
array([1., 1., 2., 2., 2., 1.])
Notes
Currently optimized for OneDimensionalDataset
Data outside regions is preserved exactly
Transformations use NumPy arrays internally for compatibility
Region masks are generated from the independent variable
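Mask generation from the independent variable can be sketched in NumPy. The helper names are hypothetical; inclusive bounds are assumed because they reproduce the example output above (x = 2 and x = 4 are both multiplied), and a compound region is assumed to be the logical OR of its sub-region masks. In JAX code the final step would typically be written as jnp.where(mask, y * factor, y) instead of masked assignment.

```python
import numpy as np


def linear_mask(x, x_min, x_max):
    # Inclusive bounds, consistent with the RegionTransform example output.
    return (x >= x_min) & (x <= x_max)


def compound_mask(x, bounds):
    # A point is inside a compound region if it lies in ANY sub-region.
    mask = np.zeros(x.shape, dtype=bool)
    for x_min, x_max in bounds:
        mask |= linear_mask(x, x_min, x_max)
    return mask


def region_multiply(x, y, bounds, factor):
    y = np.asarray(y, dtype=float).copy()  # data outside the mask is left untouched
    mask = compound_mask(np.asarray(x), bounds)
    y[mask] = y[mask] * factor
    return y


x = np.array([0, 1, 2, 3, 4, 5])
y = np.ones(6)
single = region_multiply(x, y, [(2.0, 4.0)], 2.0)
disjoint = region_multiply(np.linspace(0, 10, 11), np.ones(11), [(1.0, 2.0), (8.0, 9.0)], 0.5)
```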
Methods
__call__(target[, make_copy, ...]) – Shorthand for apply_to.
apply_to(target[, make_copy, ...]) – Apply transform to Dataset.
- class piblin_jax.transform.Transform
Abstract base class for all transforms.
Transforms operate on data structures at various hierarchy levels:
- Dataset level: individual datasets
- Measurement level: collections of datasets
- MeasurementSet level: collections of measurements
- Experiment level: collections of measurement sets
- ExperimentSet level: collections of experiments
Transforms support:
- Lazy evaluation: computation deferred until results are accessed
- JIT compilation: automatic compilation with the JAX backend
- Immutability: optional copying for a functional programming style
- Pipeline composition: chaining multiple transforms
Examples
>>> from piblin_jax.transform.base import DatasetTransform
>>> from piblin_jax.data.datasets import OneDimensionalDataset
>>> import numpy as np
>>>
>>> class MultiplyTransform(DatasetTransform):
...     def __init__(self, factor):
...         super().__init__()
...         self.factor = factor
...
...     def _apply(self, dataset):
...         dataset.y_data = dataset.y_data * self.factor
...         return dataset
>>>
>>> dataset = OneDimensionalDataset(
...     x_data=np.array([1, 2, 3]),
...     y_data=np.array([2, 4, 6])
... )
>>> transform = MultiplyTransform(2.0)
>>> result = transform.apply_to(dataset, make_copy=True)
Methods
__call__(target[, make_copy, ...]) – Shorthand for apply_to.
apply_to(target[, make_copy, ...]) – Apply transform to target data structure.
- __call__(target, make_copy=True, propagate_uncertainty=False)
Shorthand for apply_to.
Allows using transform objects as callables:
>>> result = transform(data)
instead of:
>>> result = transform.apply_to(data)
- apply_to(target, make_copy=True, propagate_uncertainty=False)
Apply transform to target data structure.
This is the main public interface for applying transforms. It handles copying (if requested) and delegates to the subclass-specific _apply method.
- Parameters:
target (T) – Data structure to transform (Dataset, Measurement, etc.)
make_copy (bool, default True) – If True, create a deep copy before transforming (default). If False, transform in-place (more memory efficient but modifies the original object).
propagate_uncertainty (bool, default False) – If True and target has uncertainty samples, propagate uncertainty through the transform using Monte Carlo sampling.
- Returns:
Transformed data structure
- Return type:
T
Examples
>>> transform = MyTransform()
>>> # Create copy and transform
>>> result = transform.apply_to(data, make_copy=True)
>>> # Transform in-place (memory efficient)
>>> result = transform.apply_to(data, make_copy=False)
>>> # With uncertainty propagation
>>> result = transform.apply_to(data_with_unc, propagate_uncertainty=True)
Notes
make_copy=True ensures functional programming style (immutability)
make_copy=False is more memory efficient for large datasets
In pipelines, only one copy is made, at pipeline entry
Uncertainty propagation applies the transform to each uncertainty sample
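Per the note above, uncertainty propagation applies the transform to each uncertainty sample. A minimal Monte Carlo sketch, assuming the samples are stored as an array of shape (n_samples, n_points) and summarized by mean and standard deviation (both assumptions, as is the helper name propagate_samples):

```python
import numpy as np


def propagate_samples(transform, samples):
    """Hypothetical sketch: apply an array transform to every Monte Carlo sample.

    Each row of samples is one plausible realization of the data drawn
    from its uncertainty distribution.
    """
    transformed = np.stack([transform(row) for row in samples])
    return transformed.mean(axis=0), transformed.std(axis=0, ddof=1)


rng = np.random.default_rng(0)
true_y = np.array([1.0, 2.0, 3.0])
samples = true_y + rng.normal(scale=0.1, size=(2000, 3))

# A nonlinear transform: for y**2 the output spread grows roughly as 2 * y * sigma.
mean, std = propagate_samples(np.square, samples)
```

With JAX, jax.vmap(transform)(samples) would vectorize the same per-sample loop.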
- piblin_jax.transform.jit_transform(func)
Decorator to enable JIT compilation for transform _apply methods.
This decorator automatically compiles transform methods using JAX’s JIT compiler when the JAX backend is available. For NumPy backend, it gracefully falls back to the uncompiled function.
- Parameters:
func (Callable) – The _apply method to compile
- Returns:
JIT-compiled function (JAX) or original function (NumPy)
- Return type:
Callable
Examples
>>> class MyTransform(DatasetTransform):
...     @jit_transform
...     def _apply(self, dataset):
...         # This will be JIT compiled with JAX
...         dataset.y_data = dataset.y_data * 2.0
...         return dataset
Notes
JIT compilation can significantly improve performance
Only works with JAX backend (graceful fallback for NumPy)
First call may be slow (compilation), subsequent calls are fast
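Arguments that must be known at compile time (e.g., values that set array shapes) should be marked static, as with the static_argnums option of jax.jit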
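The graceful-fallback and static-argument notes can be illustrated with a small decorator. maybe_jit is a hypothetical name, not the jit_transform source; it applies jax.jit when JAX is importable and returns the function unchanged otherwise. The window length sets the kernel's shape, so under JIT it must be static: a traced value cannot serve as an array size.

```python
import numpy as np

try:
    import jax
    import jax.numpy as jnp
except ImportError:  # graceful fallback: NumPy backend, no compilation
    jax = None
    jnp = np


def maybe_jit(static_argnums=()):
    """Hypothetical decorator: jax.jit when JAX is available, identity otherwise."""
    def decorator(func):
        if jax is None:
            return func
        return jax.jit(func, static_argnums=static_argnums)
    return decorator


@maybe_jit(static_argnums=(1,))
def moving_average(y, window):
    # window determines the kernel shape, so it is marked static (argument index 1).
    kernel = jnp.ones(window) / window
    return jnp.convolve(y, kernel, mode="same")


y = np.array([1.0, 5.0, 2.0, 8.0, 3.0])
smoothed = moving_average(y, 3)  # with JAX, the first call per window value compiles
```

Each distinct window value triggers a separate compilation, which is the usual trade-off of static arguments.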
Base Transform Classes
Transform base classes for piblin-jax.
This module provides the abstract base class and hierarchy for all transforms:
- Transform: Abstract base class for all transforms
- DatasetTransform: Operates on Dataset objects
- MeasurementTransform: Operates on Measurement objects
- MeasurementSetTransform: Operates on MeasurementSet objects
- ExperimentTransform: Operates on Experiment objects
- ExperimentSetTransform: Operates on ExperimentSet objects
Transforms support:
- Lazy evaluation (computation deferred until needed)
- JIT compilation (via JAX backend)
- Immutability (via make_copy parameter)
- Pipeline composition (via Pipeline class)
- class piblin_jax.transform.base.DatasetTransform
Bases: Transform[Dataset]
Transform that operates on Dataset objects.
This is the lowest level transform in the hierarchy, operating on individual datasets (1D, 2D, 3D, etc.).
Examples
>>> from piblin_jax.transform.base import DatasetTransform
>>> from piblin_jax.data.datasets import OneDimensionalDataset
>>> import numpy as np
>>>
>>> class SmoothTransform(DatasetTransform):
...     def _apply(self, dataset):
...         # Smooth y-values with moving average
...         dataset.y_data = np.convolve(
...             dataset.y_data,
...             np.ones(3)/3,
...             mode='same'
...         )
...         return dataset
>>>
>>> dataset = OneDimensionalDataset(
...     x_data=np.array([1, 2, 3, 4, 5]),
...     y_data=np.array([1, 5, 2, 8, 3])
... )
>>> transform = SmoothTransform()
>>> smoothed = transform.apply_to(dataset)
Methods
__call__(target[, make_copy, ...]): Shorthand for apply_to.
apply_to(target[, make_copy, ...]): Apply transform to Dataset.
- class piblin_jax.transform.base.ExperimentSetTransform
Bases: Transform[ExperimentSet]
Transform that operates on ExperimentSet objects.
ExperimentSets are the top-level container for multiple experiments. This transform level can operate across all experiments in a set.
Examples
>>> from piblin_jax.transform.base import ExperimentSetTransform
>>>
>>> class CrossExperimentNormalizeTransform(ExperimentSetTransform):
...     def _apply(self, experiment_set):
...         # Normalize across all experiments
...         global_max = 0
...         for exp in experiment_set.experiments.values():
...             for mset in exp.measurement_sets.values():
...                 for meas in mset.measurements.values():
...                     for ds in meas.datasets.values():
...                         if hasattr(ds, 'y_data'):
...                             global_max = max(global_max, ds.y_data.max())
...
...         for exp in experiment_set.experiments.values():
...             for mset in exp.measurement_sets.values():
...                 for meas in mset.measurements.values():
...                     for ds in meas.datasets.values():
...                         if hasattr(ds, 'y_data'):
...                             ds.y_data = ds.y_data / global_max
...         return experiment_set
Methods
__call__(target[, make_copy, ...]): Shorthand for apply_to.
apply_to(target[, make_copy, ...]): Apply transform to ExperimentSet.
- class piblin_jax.transform.base.ExperimentTransform
Bases: Transform[Experiment]
Transform that operates on Experiment objects.
Experiments contain multiple measurement sets. This transform level can operate across measurement sets within an experiment.
Examples
>>> from piblin_jax.transform.base import ExperimentTransform
>>>
>>> class TemperatureCorrectionTransform(ExperimentTransform):
...     def _apply(self, experiment):
...         # Apply temperature correction to all measurement sets
...         temp = experiment.metadata.get('temperature', 300)
...         correction = 1.0 + (temp - 300) * 0.001
...
...         for mset in experiment.measurement_sets.values():
...             for meas in mset.measurements.values():
...                 for ds in meas.datasets.values():
...                     if hasattr(ds, 'y_data'):
...                         ds.y_data = ds.y_data * correction
...         return experiment
Methods
__call__(target[, make_copy, ...]): Shorthand for apply_to.
apply_to(target[, make_copy, ...]): Apply transform to Experiment.
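The temperature-correction example multiplies every dataset by 1 + (T - 300) × 0.001. For example, T = 350 gives 1 + 50 × 0.001 = 1.05. The sketch below is a minimal, self-contained illustration of that arithmetic and of walking an Experiment's nested containers. It uses `types.SimpleNamespace` objects as stand-ins for the real classes and assumes only the attribute names shown in the example (`metadata`, `measurement_sets`, `measurements`, `datasets`, `y_data`). The helpers `iter_experiment_datasets`, `temperature_correction` and `correct_experiment` are hypothetical, not piblin_jax API.

```python
from types import SimpleNamespace

import numpy as np


def iter_experiment_datasets(experiment):
    """Yield every dataset in an Experiment-like object (hypothetical helper)."""
    for mset in experiment.measurement_sets.values():
        for meas in mset.measurements.values():
            yield from meas.datasets.values()


def temperature_correction(temp, reference=300.0, coefficient=0.001):
    """Linear correction factor from the example: 1 + (T - T_ref) * coefficient."""
    return 1.0 + (temp - reference) * coefficient


def correct_experiment(experiment):
    # Same fallback (300) as experiment.metadata.get('temperature', 300) in the example.
    factor = temperature_correction(experiment.metadata.get("temperature", 300))
    for ds in iter_experiment_datasets(experiment):
        if hasattr(ds, "y_data"):
            ds.y_data = ds.y_data * factor
    return experiment


# Stand-in Experiment recorded at temperature 350, with two measurements.
ds_a = SimpleNamespace(y_data=np.array([10.0, 20.0]))
ds_b = SimpleNamespace(y_data=np.array([1.0]))
experiment = SimpleNamespace(
    metadata={"temperature": 350},
    measurement_sets={
        "s0": SimpleNamespace(measurements={
            "m0": SimpleNamespace(datasets={"a": ds_a}),
            "m1": SimpleNamespace(datasets={"b": ds_b}),
        })
    },
)
correct_experiment(experiment)  # every y_data is now scaled by 1.05
```

Factoring the traversal into a generator keeps the correction itself to a single loop, which is the same shape the example's four nested loops reduce to.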
- class piblin_jax.transform.base.MeasurementSetTransform
Bases: Transform[MeasurementSet]
Transform that operates on MeasurementSet objects.
MeasurementSets contain multiple measurements. This transform level can operate across measurements (e.g., normalization relative to the entire set).
Examples
>>> from piblin_jax.transform.base import MeasurementSetTransform
>>>
>>> class GlobalNormalizeTransform(MeasurementSetTransform):
...     def _apply(self, measurement_set):
...         # Find global max across all measurements
...         global_max = 0
...         for meas in measurement_set.measurements.values():
...             for ds in meas.datasets.values():
...                 if hasattr(ds, 'y_data'):
...                     global_max = max(global_max, ds.y_data.max())
...
...         # Normalize all datasets by global max
...         for meas in measurement_set.measurements.values():
...             for ds in meas.datasets.values():
...                 if hasattr(ds, 'y_data'):
...                     ds.y_data = ds.y_data / global_max
...         return measurement_set
Methods
__call__(target[, make_copy, ...]): Shorthand for apply_to.
apply_to(target[, make_copy, ...]): Apply transform to MeasurementSet.
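Why normalize at the MeasurementSet level rather than per dataset? The plain-NumPy sketch below contrasts the two choices on two hypothetical runs. The dict `measurement_set` is an illustrative stand-in, not the library's MeasurementSet class.

```python
import numpy as np

# Hypothetical stand-in for a measurement set: run name -> y-values.
measurement_set = {
    "run_a": np.array([1.0, 2.0, 4.0]),
    "run_b": np.array([2.0, 8.0, 6.0]),
}

# Per-dataset normalization: each run is divided by its own maximum,
# so both runs peak at 1.0 and their relative magnitudes are lost.
per_dataset = {name: y / y.max() for name, y in measurement_set.items()}

# Set-level normalization: one global maximum for every run,
# so run_b's peak stays twice as high as run_a's.
global_max = max(y.max() for y in measurement_set.values())
global_norm = {name: y / global_max for name, y in measurement_set.items()}
```

Only a transform that can see the whole set can compute `global_max`, which is why this operation belongs to MeasurementSetTransform rather than DatasetTransform.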
- class piblin_jax.transform.base.MeasurementTransform
Bases: Transform[Measurement]
Transform that operates on Measurement objects.
Measurements contain multiple datasets with associated metadata. This transform level can operate on all datasets within a measurement.
Examples
>>> from piblin_jax.transform.base import MeasurementTransform
>>>
>>> class NormalizeTransform(MeasurementTransform):
...     def _apply(self, measurement):
...         # Normalize all datasets in measurement
...         for dataset in measurement.datasets.values():
...             if hasattr(dataset, 'y_data'):
...                 max_val = dataset.y_data.max()
...                 dataset.y_data = dataset.y_data / max_val
...         return measurement
Methods
__call__(target[, make_copy, ...]): Shorthand for apply_to.
apply_to(target[, make_copy, ...]): Apply transform to Measurement.
- class piblin_jax.transform.base.Transform
Abstract base class for all transforms.
Transforms operate on data structures at various hierarchy levels:
- Dataset level: individual datasets
- Measurement level: collections of datasets
- MeasurementSet level: collections of measurements
- Experiment level: collections of measurement sets
- ExperimentSet level: collections of experiments
Transforms support:
- Lazy evaluation: computation deferred until results are accessed
- JIT compilation: automatic compilation with JAX backend
- Immutability: optional copying for functional programming style
- Pipeline composition: chaining multiple transforms
Examples
>>> from piblin_jax.transform.base import DatasetTransform
>>> from piblin_jax.data.datasets import OneDimensionalDataset
>>> import numpy as np
>>>
>>> class MultiplyTransform(DatasetTransform):
...     def __init__(self, factor):
...         super().__init__()
...         self.factor = factor
...
...     def _apply(self, dataset):
...         dataset.y_data = dataset.y_data * self.factor
...         return dataset
>>>
>>> dataset = OneDimensionalDataset(
...     x_data=np.array([1, 2, 3]),
...     y_data=np.array([2, 4, 6])
... )
>>> transform = MultiplyTransform(2.0)
>>> result = transform.apply_to(dataset, make_copy=True)
Methods
__call__(target[, make_copy, ...]): Shorthand for apply_to.
apply_to(target[, make_copy, ...]): Apply transform to target data structure.
- apply_to(target, make_copy=True, propagate_uncertainty=False)
Apply transform to target data structure.
This is the main public interface for applying transforms. It handles copying (if requested) and delegates to the subclass-specific _apply method.
- Parameters:
target (T) – Data structure to transform (Dataset, Measurement, etc.)
make_copy (bool, default True) – If True, create a deep copy before transforming. If False, transform in place (more memory efficient, but modifies the original object).
propagate_uncertainty (bool, default False) – If True and target has uncertainty samples, propagate uncertainty through the transform using Monte Carlo sampling.
- Returns:
Transformed data structure
- Return type:
T
Examples
>>> transform = MyTransform()
>>> # Create copy and transform
>>> result = transform.apply_to(data, make_copy=True)
>>> # Transform in-place (memory efficient)
>>> result = transform.apply_to(data, make_copy=False)
>>> # With uncertainty propagation
>>> result = transform.apply_to(data_with_unc, propagate_uncertainty=True)
Notes
make_copy=True ensures functional programming style (immutability)
make_copy=False is more memory efficient for large datasets
In a pipeline, a single copy is made at entry; the individual transforms then run in place
Uncertainty propagation applies the transform to each uncertainty sample
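apply_to copies (if asked) and then delegates to the subclass's _apply. The sketch below shows that pattern in isolation. `SketchTransform` and `Scale` are hypothetical stand-ins working on a plain dict, not the library's classes, and uncertainty propagation is deliberately left out.

```python
import copy
from abc import ABC, abstractmethod


class SketchTransform(ABC):
    """Hypothetical stand-in for Transform: copy (optionally), then delegate to _apply."""

    def apply_to(self, target, make_copy=True):
        if make_copy:
            # Deep copy so the caller's object is never modified.
            target = copy.deepcopy(target)
        return self._apply(target)

    def __call__(self, target, make_copy=True):
        # Shorthand for apply_to, as in the documented interface.
        return self.apply_to(target, make_copy=make_copy)

    @abstractmethod
    def _apply(self, target):
        """Subclasses transform target (in place is fine) and return it."""


class Scale(SketchTransform):
    def __init__(self, factor):
        self.factor = factor

    def _apply(self, target):
        target["y"] = [v * self.factor for v in target["y"]]
        return target


data = {"x": [1, 2, 3], "y": [2, 4, 6]}
copied = Scale(2.0).apply_to(data)                     # data itself is untouched
y_after_copy = list(data["y"])                         # still [2, 4, 6]
in_place = Scale(2.0).apply_to(data, make_copy=False)  # modifies and returns data
```

Because _apply is free to mutate its argument, the copy decision lives in exactly one place. This is also what lets a pipeline copy once and then call each step in place.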
- piblin_jax.transform.base.jit_transform(func)
Decorator to enable JIT compilation for transform _apply methods.
This decorator automatically compiles transform methods using JAX’s JIT compiler when the JAX backend is available. For NumPy backend, it gracefully falls back to the uncompiled function.
- Parameters:
func (Callable) – The _apply method to compile
- Returns:
JIT-compiled function (JAX) or original function (NumPy)
- Return type:
Callable
Examples
>>> class MyTransform(DatasetTransform):
...     @jit_transform
...     def _apply(self, dataset):
...         # This will be JIT compiled with JAX
...         dataset.y_data = dataset.y_data * 2.0
...         return dataset
Notes
JIT compilation can significantly improve performance
Only works with JAX backend (graceful fallback for NumPy)
First call may be slow (compilation), subsequent calls are fast
Static arguments should be marked appropriately
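The Notes above mention a graceful NumPy fallback and static arguments. The sketch below illustrates both ideas outside any class. `maybe_jit` and `scale` are hypothetical names, not part of piblin_jax. The sketch assumes that the code to compile is a pure function of arrays: `jax.jit` traces array arguments, and arbitrary Python objects cannot be passed as traced arguments unless they are registered as pytrees. Python values that should be baked into the compiled code are declared with `static_argnames`.

```python
import numpy as np

try:
    import jax
    import jax.numpy as xp
    HAS_JAX = True
except ImportError:  # JAX not installed: fall back to NumPy
    jax = None
    xp = np
    HAS_JAX = False


def maybe_jit(func, static_argnames=()):
    """Hypothetical helper: jax.jit(func) when JAX is available, else func unchanged."""
    if HAS_JAX:
        return jax.jit(func, static_argnames=static_argnames)
    return func


def scale(y, factor):
    # A pure array-in, array-out kernel: no dataset objects or attribute mutation.
    return y * factor


# factor is marked static only to illustrate the mechanism; every new value of a
# static argument triggers a fresh compilation.
scale_fast = maybe_jit(scale, static_argnames=("factor",))
result = np.asarray(scale_fast(xp.asarray([1.0, 2.0, 3.0]), factor=2.0))
```

The first call to `scale_fast` pays the compilation cost. Later calls with same-shaped arrays and the same static `factor` reuse the compiled version.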
Pipeline
Pipeline composition for transforms.
This module provides pipeline functionality for composing multiple transforms:
- Pipeline: Sequential composition of transforms
- LazyPipeline: Pipeline with lazy evaluation support
Pipelines support:
- MutableSequence interface (list-like operations)
- Sequential transform application
- Single copy at entry (memory efficient)
- Lazy evaluation (computation deferred)
- JIT compilation (entire pipeline)
- class piblin_jax.transform.pipeline.LazyPipeline(transforms=None)
Bases: Pipeline[T]
Pipeline with lazy evaluation support.
Unlike the standard Pipeline, LazyPipeline defers computation until the results are actually accessed. This allows JAX to optimize the entire computation graph as a single operation.
Lazy evaluation is triggered on:
- Property access (e.g., result.y_data)
- Method calls (e.g., result.visualize())
- Export operations (e.g., result.export())
- Parameters:
transforms (list[Transform], optional) – Initial list of transforms to include in pipeline
Examples
>>> from piblin_jax.transform import LazyPipeline
>>>
>>> # Create lazy pipeline
>>> pipeline = LazyPipeline([
...     MultiplyTransform(2.0),
...     MultiplyTransform(3.0),
... ])
>>>
>>> # Apply to dataset (computation deferred)
>>> lazy_result = pipeline.apply_to(dataset, make_copy=True)
>>>
>>> # Access property (triggers computation)
>>> y_values = lazy_result.y_data  # Computation happens here
Notes
Lazy evaluation allows JAX to optimize the entire pipeline
First property access triggers computation and caches result
Subsequent accesses use cached result
More efficient than eager evaluation for complex pipelines
Methods
__call__(target[, make_copy, ...]): Shorthand for apply_to.
append(transform): Add transform to end of pipeline.
apply_to(target[, make_copy, ...]): Apply lazy pipeline to target.
clear()
count(value)
extend(values): S.extend(iterable) -- extend sequence by appending elements from the iterable
index(value, [start, [stop]]): Raises ValueError if the value is not present.
insert(index, value): Insert transform at index.
invalidate_cache(): Invalidate cached results.
pop([index]): Raise IndexError if list is empty or index is out of range.
remove(value): S.remove(value) -- remove first occurrence of value.
reverse(): S.reverse() -- reverse IN PLACE
- __init__(transforms=None)
Initialize lazy pipeline.
- Parameters:
transforms (list[Transform], optional) – Initial transforms to include in pipeline
- apply_to(target, make_copy=True, propagate_uncertainty=False)
Apply lazy pipeline to target.
Computation is deferred until results are accessed. Returns a LazyResult wrapper that triggers computation on property/method access.
- Parameters:
- Returns:
Wrapper that triggers computation on access
- Return type:
LazyResult
Notes
The actual transformation is not performed until the result is accessed. This allows JAX to optimize the entire computation graph.
- invalidate_cache()
Invalidate cached results.
Forces recomputation on next access. Useful if transforms have been modified or parameters changed.
Examples
>>> pipeline = LazyPipeline([transform1, transform2])
>>> result = pipeline.apply_to(dataset)
>>> _ = result.y_data  # Triggers computation
>>>
>>> # Modify pipeline
>>> pipeline.append(transform3)
>>> pipeline.invalidate_cache()  # Force recomputation
- class piblin_jax.transform.pipeline.LazyResult(pipeline)
Bases: object
Wrapper that triggers lazy computation on property access.
This class wraps the actual result and defers computation until properties or methods are accessed.
- Parameters:
pipeline (LazyPipeline) – The lazy pipeline that will compute the result
Examples
>>> lazy_result = LazyResult(pipeline)
>>> # No computation yet
>>> y = lazy_result.y_data  # Triggers computation here
>>> # Subsequent accesses use cached result
>>> x = lazy_result.x_data  # No recomputation
Notes
This class is transparent to the user: it behaves like the actual result object, but triggers computation on first access.
- __init__(pipeline)
Initialize lazy result wrapper.
- Parameters:
pipeline (LazyPipeline) – Pipeline that will compute the result
- __getattr__(name)
Get attribute from computed result.
Triggers computation on first access.
- Parameters:
name (str) – Attribute name
- Returns:
Attribute value from computed result
- Return type:
Any
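How can a wrapper defer work until the first attribute access? The sketch below shows one way, using Python's `__getattr__` hook plus a cache flag. `LazyValue` is a hypothetical class that mirrors the documented behaviour (compute on first access, reuse the cache, recompute after invalidation). It is not the library's LazyResult, and it wraps a plain zero-argument callable instead of a pipeline.

```python
from types import SimpleNamespace


class LazyValue:
    """Hypothetical sketch of a LazyResult-style wrapper (not the library class)."""

    def __init__(self, compute):
        self._compute = compute      # zero-argument callable that builds the real result
        self._cached = None
        self._is_cached = False
        self.compute_count = 0       # how many times compute() has actually run

    def _materialize(self):
        if not self._is_cached:
            self._cached = self._compute()
            self._is_cached = True
            self.compute_count += 1
        return self._cached

    def invalidate_cache(self):
        # Force recomputation on the next attribute access.
        self._is_cached = False

    def __getattr__(self, name):
        # Python calls __getattr__ only when normal lookup fails, so the wrapper's
        # own attributes above are found directly and everything else is forwarded.
        return getattr(self._materialize(), name)


lazy = LazyValue(lambda: SimpleNamespace(x_data=[1, 2, 3], y_data=[2, 4, 6]))
count_before = lazy.compute_count   # nothing computed yet
y = lazy.y_data                     # first access triggers computation
x = lazy.x_data                     # served from the cache
lazy.invalidate_cache()
y_again = lazy.y_data               # recomputed after invalidation
```

Relying on `__getattr__` rather than `__getattribute__` keeps the wrapper's own bookkeeping attributes out of the forwarding path. That avoids infinite recursion.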
- class piblin_jax.transform.pipeline.Pipeline(transforms=None)
Bases: Transform[T], MutableSequence[Transform[T]]
Pipeline for composing multiple transforms sequentially.
A pipeline applies a sequence of transforms to data in order. It implements the MutableSequence interface, so it can be used like a list of transforms.
The pipeline is memory-efficient: when make_copy=True, it creates a single copy at entry, then applies all transforms in-place.
- Parameters:
transforms (list[Transform], optional) – Initial list of transforms to include in pipeline
Examples
>>> from piblin_jax.transform import Pipeline, DatasetTransform
>>> from piblin_jax.data.datasets import OneDimensionalDataset
>>> import numpy as np
>>>
>>> # Create transforms
>>> class MultiplyTransform(DatasetTransform):
...     def __init__(self, factor):
...         super().__init__()
...         self.factor = factor
...
...     def _apply(self, dataset):
...         dataset.y_data = dataset.y_data * self.factor
...         return dataset
>>>
>>> # Create pipeline
>>> pipeline = Pipeline([
...     MultiplyTransform(2.0),
...     MultiplyTransform(3.0),  # Net effect: 6x
... ])
>>>
>>> # Apply to dataset
>>> dataset = OneDimensionalDataset(
...     x_data=np.array([1, 2, 3]),
...     y_data=np.array([2, 4, 6])
... )
>>> result = pipeline.apply_to(dataset, make_copy=True)
>>> # result.y_data is now [12, 24, 36]
Notes
Pipelines can be nested: a pipeline can contain other pipelines
Only one copy is made at entry, then all transforms apply in-place
This is much more memory efficient than copying at each step
Use lazy evaluation for even better performance with JAX
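The copy-once-at-entry strategy can be sketched with plain Python. `SketchPipeline` and `multiply_by` are hypothetical names, and each step is assumed to modify its argument in place and return it. The sketch counts deep copies to make the memory argument concrete: n steps share a single copy instead of making n copies.

```python
import copy


class SketchPipeline:
    """Hypothetical sketch of copy-once-at-entry sequencing (not the library's Pipeline)."""

    def __init__(self, steps):
        self.steps = list(steps)  # each step modifies its argument in place and returns it
        self.copies_made = 0

    def apply_to(self, target, make_copy=True):
        if make_copy:
            target = copy.deepcopy(target)  # the only copy, made once at entry
            self.copies_made += 1
        for step in self.steps:
            target = step(target)           # every step works on the same object
        return target


def multiply_by(factor):
    def step(data):
        data["y"] = [v * factor for v in data["y"]]
        return data
    return step


pipeline = SketchPipeline([multiply_by(2.0), multiply_by(3.0)])  # net effect: 6x
original = {"y": [2, 4, 6]}
result = pipeline.apply_to(original)
```

The caller's `original` is still protected, because the one entry copy isolates it. The intermediate steps never pay for copies of their own.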
Methods
__call__(target[, make_copy, ...]): Shorthand for apply_to.
append(transform): Add transform to end of pipeline.
apply_to(target[, make_copy, ...]): Apply pipeline to target.
clear()
count(value)
extend(values): S.extend(iterable) -- extend sequence by appending elements from the iterable
index(value, [start, [stop]]): Raises ValueError if the value is not present.
insert(index, value): Insert transform at index.
pop([index]): Raise IndexError if list is empty or index is out of range.
remove(value): S.remove(value) -- remove first occurrence of value.
reverse(): S.reverse() -- reverse IN PLACE
- __init__(transforms=None)
Initialize pipeline.
- Parameters:
transforms (list[Transform], optional) – Initial transforms to include in pipeline
- apply_to(target, make_copy=True, propagate_uncertainty=False)
Apply pipeline to target.
Only makes copy once at entry, then applies all transforms in-place for memory efficiency.
- Parameters:
- Returns:
Transformed data structure
- Return type:
T
Notes
This is much more efficient than copying at each transform step. The single copy at entry ensures immutability while minimizing memory overhead.
When propagate_uncertainty=True, uncertainty is efficiently propagated through the entire pipeline in a single pass.
- __getitem__(index: int) -> Transform[T]
- __getitem__(index: slice) -> list[Transform[T]]
Get transform(s) at index.
- Parameters:
- Returns:
Transform at index, or list of transforms for slice
- Return type:
Transform or list[Transform]
Examples
>>> pipeline = Pipeline([t1, t2, t3])
>>> pipeline[0]  # Get first transform
>>> pipeline[1:3]  # Get slice of transforms
- __setitem__(index: int, value: Transform[T]) -> None
- __setitem__(index: slice, value: Iterable[Transform[T]]) -> None
Set transform(s) at index.
- Parameters:
- Raises:
TypeError – If value is not a Transform instance
Examples
>>> pipeline = Pipeline([t1, t2, t3])
>>> pipeline[0] = new_transform  # Replace first transform
- __delitem__(index)
Delete transform(s) at index.
Examples
>>> pipeline = Pipeline([t1, t2, t3])
>>> del pipeline[0]  # Remove first transform
>>> del pipeline[1:]  # Remove all but first transform
- __len__()
Get number of transforms in pipeline.
- Returns:
Number of transforms
- Return type:
int
Examples
>>> pipeline = Pipeline([t1, t2, t3])
>>> len(pipeline)
3
- insert(index, value)
Insert transform at index.
- Parameters:
index (int) – Index at which to insert transform
value (Transform) – Transform to insert
- Raises:
TypeError – If value is not a Transform instance
Examples
>>> pipeline = Pipeline([t1, t3])
>>> pipeline.insert(1, t2)  # Insert t2 between t1 and t3
- append(transform)
Add transform to end of pipeline.
- Parameters:
transform (Transform) – Transform to append
- Raises:
TypeError – If transform is not a Transform instance
Examples
>>> pipeline = Pipeline([t1, t2])
>>> pipeline.append(t3)  # Add t3 to end
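Subclassing `collections.abc.MutableSequence` requires only five methods: `__getitem__`, `__setitem__`, `__delitem__`, `__len__` and `insert`. The mixin derives append, extend, pop, remove, reverse and the rest from those five. The sketch below shows that pattern with the same TypeError checks documented above. `Step` and `StepList` are hypothetical stand-ins for Transform and Pipeline, and the library's own Pipeline may implement append differently.

```python
from collections.abc import MutableSequence


class Step:
    """Hypothetical stand-in for Transform."""


class StepList(MutableSequence):
    """Hypothetical sketch of a type-checked, list-like container of Steps."""

    def __init__(self, steps=None):
        self._steps = []
        for step in steps or []:
            self.append(step)  # inherited append() calls insert(), so it is type-checked

    @staticmethod
    def _check(value):
        if not isinstance(value, Step):
            raise TypeError(f"expected a Step, got {type(value).__name__}")
        return value

    def __getitem__(self, index):
        return self._steps[index]  # int -> Step, slice -> list of Steps

    def __setitem__(self, index, value):
        if isinstance(index, slice):
            self._steps[index] = [self._check(v) for v in value]
        else:
            self._steps[index] = self._check(value)

    def __delitem__(self, index):
        del self._steps[index]

    def __len__(self):
        return len(self._steps)

    def insert(self, index, value):
        self._steps.insert(index, self._check(value))


t1, t2, t3 = Step(), Step(), Step()
steps = StepList([t1, t3])
steps.insert(1, t2)   # now [t1, t2, t3]
try:
    steps.append("not a step")
    rejected = False
except TypeError:
    rejected = True
```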
Region-Based Transforms
Region-based transforms for piblin-jax.
This module provides transforms that operate on specific regions of data:
- RegionTransform: Base class for region-based transforms
- RegionMultiplyTransform: Example concrete implementation
Region-based transforms apply transformations only within specified regions, preserving data outside those regions. This is useful for selective processing such as background subtraction in specific spectral ranges or local smoothing.
- class piblin_jax.transform.region.RegionMultiplyTransform(region, factor)
Bases: RegionTransform
Example transform: multiply a region by a factor.
This is a concrete implementation of RegionTransform that multiplies the dependent variable within the specified region(s) by a constant factor.
- Parameters:
region (LinearRegion | CompoundRegion) – Region(s) to transform
factor (float) – Multiplication factor
Examples
>>> import numpy as np
>>> from piblin_jax.data.datasets import OneDimensionalDataset
>>> from piblin_jax.data.roi import LinearRegion, CompoundRegion
>>> from piblin_jax.transform.region import RegionMultiplyTransform
>>> # Single region example
>>> x_data = np.linspace(0, 10, 11)
>>> y_data = np.ones(11)
>>> dataset = OneDimensionalDataset(
...     independent_variable_data=x_data,
...     dependent_variable_data=y_data
... )
>>> region = LinearRegion(x_min=3.0, x_max=7.0)
>>> transform = RegionMultiplyTransform(region, factor=2.0)
>>> result = transform.apply_to(dataset, make_copy=True)
>>> # Points in [3, 7] are multiplied by 2.0, others unchanged
>>> # Multiple disjoint regions example
>>> region1 = LinearRegion(x_min=1.0, x_max=2.0)
>>> region2 = LinearRegion(x_min=8.0, x_max=9.0)
>>> compound = CompoundRegion([region1, region2])
>>> transform = RegionMultiplyTransform(compound, factor=0.5)
>>> result = transform.apply_to(dataset, make_copy=True)
>>> # Points in [1, 2] OR [8, 9] are multiplied by 0.5
Notes
This is a simple example transform for demonstration and testing. More complex transforms can be implemented following the same pattern.
Methods
__call__(target[, make_copy, ...]): Shorthand for apply_to.
apply_to(target[, make_copy, ...]): Apply transform to Dataset.
- class piblin_jax.transform.region.RegionTransform(region)
Bases: DatasetTransform
Base class for transforms that operate on specific regions.
RegionTransform applies a transformation only within specified region(s), preserving data outside the regions. This enables selective processing of data based on independent variable ranges.
Subclasses should implement the _apply_to_region() method to define the specific transformation to apply within the region(s).
- Parameters:
region (LinearRegion | CompoundRegion) – Region(s) to transform
Examples
>>> import numpy as np
>>> from piblin_jax.data.datasets import OneDimensionalDataset
>>> from piblin_jax.data.roi import LinearRegion
>>> from piblin_jax.transform.region import RegionMultiplyTransform
>>> # Create dataset
>>> x_data = np.array([0, 1, 2, 3, 4, 5])
>>> y_data = np.array([1, 1, 1, 1, 1, 1])
>>> dataset = OneDimensionalDataset(
...     independent_variable_data=x_data,
...     dependent_variable_data=y_data
... )
>>> # Define region and transform
>>> region = LinearRegion(x_min=2.0, x_max=4.0)
>>> transform = RegionMultiplyTransform(region, factor=2.0)
>>> # Apply transform (only region [2, 4] is multiplied)
>>> result = transform.apply_to(dataset, make_copy=True)
>>> result.dependent_variable_data
array([1., 1., 2., 2., 2., 1.])
Notes
Currently optimized for OneDimensionalDataset
Data outside regions is preserved exactly
Transformations use NumPy arrays internally for compatibility
Region masks are generated from the independent variable
Methods
__call__(target[, make_copy, ...]): Shorthand for apply_to.
apply_to(target[, make_copy, ...]): Apply transform to Dataset.
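Region masks are generated from the independent variable, and data outside them is preserved exactly. A compact NumPy sketch of that idea: build a boolean mask from (x_min, x_max) pairs and use `np.where` to change only the masked points. Regions are modelled as plain tuples rather than LinearRegion/CompoundRegion objects. The bounds are treated as inclusive, which matches the documented `array([1., 1., 2., 2., 2., 1.])` output where x = 2 and x = 4 are both scaled. `region_mask` and `multiply_in_regions` are hypothetical helpers.

```python
import numpy as np


def region_mask(x, regions):
    """True where x lies inside any (x_min, x_max) pair, bounds inclusive."""
    x = np.asarray(x)
    mask = np.zeros(x.shape, dtype=bool)
    for x_min, x_max in regions:
        mask |= (x >= x_min) & (x <= x_max)  # union of regions, like a CompoundRegion
    return mask


def multiply_in_regions(x, y, regions, factor):
    mask = region_mask(x, regions)
    # Points outside the mask are passed through unchanged.
    return np.where(mask, np.asarray(y) * factor, y)


x = np.linspace(0, 10, 11)
y = np.ones(11)
single = multiply_in_regions(x, y, [(3.0, 7.0)], 2.0)
compound = multiply_in_regions(x, y, [(1.0, 2.0), (8.0, 9.0)], 0.5)
```

These reproduce the two scenarios from the RegionMultiplyTransform example above.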
Lambda and Dynamic Transforms
Lambda transforms and dynamic parameter transforms for piblin-jax.
This module provides:
- LambdaTransform: Wraps arbitrary functions as transforms
- DynamicTransform: Base class for data-driven parameter computation
- AutoScaleTransform: Automatic data scaling to target range
- AutoBaselineTransform: Automatic baseline correction from data
Lambda transforms allow users to create custom transforms from simple functions without defining new classes. Dynamic transforms compute parameters from the data itself, enabling adaptive processing.
- class piblin_jax.transform.lambda_transform.AutoBaselineTransform(n_points=10, method='first')
Bases: DynamicTransform
Automatically subtract baseline computed from data.
Computes baseline from first/last N points or minimum value, then subtracts it from all data. Useful for removing offsets and drift in experimental measurements.
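- Parameters:
n_points (int, default 10) – Number of points used by the 'first' and 'last' methods
method (str, default 'first') – Baseline method: 'first', 'last', or 'min'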
Examples
>>> from piblin_jax.transform import AutoBaselineTransform
>>> from piblin_jax.data.datasets import OneDimensionalDataset
>>> import numpy as np
>>>
>>> # Subtract baseline from first 10 points
>>> transform = AutoBaselineTransform(n_points=10, method='first')
>>> dataset = OneDimensionalDataset(
...     independent_variable_data=np.arange(100),
...     dependent_variable_data=np.random.randn(100) + 5.0  # offset by 5
... )
>>> result = transform.apply_to(dataset)
>>>
>>> # Subtract minimum value
>>> transform = AutoBaselineTransform(method='min')
Notes
‘first’ method: Good for time series where initial values are baseline
‘last’ method: Good for measurements that return to baseline
‘min’ method: Good for ensuring all values are non-negative
Computed parameter: ‘baseline’ (value to subtract)
Methods
__call__(target[, make_copy, ...]): Shorthand for apply_to.
apply_to(target[, make_copy, ...]): Apply transform to Dataset.
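AutoBaselineTransform follows the compute-then-apply pattern: derive a 'baseline' parameter from the data, then subtract it. The sketch below illustrates that pattern in NumPy. The docs do not say how the first or last N points are reduced to one value, so using their **mean** here is an assumption. `compute_baseline` is a hypothetical helper.

```python
import numpy as np


def compute_baseline(y, n_points=10, method="first"):
    """Hypothetical sketch; reducing the first/last N points with a mean is an assumption."""
    y = np.asarray(y, dtype=float)
    if method == "first":
        return float(np.mean(y[:n_points]))
    if method == "last":
        return float(np.mean(y[-n_points:]))
    if method == "min":
        return float(np.min(y))
    raise ValueError(f"unknown method: {method!r}")


y = np.array([5.0, 5.0, 5.0, 7.0, 9.0, 5.0])
params = {"baseline": compute_baseline(y, n_points=3, method="first")}  # computed parameter
corrected = y - params["baseline"]   # offset of 5 removed from every point
```

Keeping parameter computation separate from application is what lets the same transform adapt to each dataset it sees.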
- class piblin_jax.transform.lambda_transform.AutoScaleTransform(target_min=0.0, target_max=1.0)
Bases: DynamicTransform
Automatically scale data to a specified range.
Computes min/max from data and scales to target range. Useful for normalizing data to standard ranges like [0, 1] or [-1, 1].
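- Parameters:
target_min (float, default 0.0) – Lower bound of the target range
target_max (float, default 1.0) – Upper bound of the target range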
Examples
>>> from piblin_jax.transform import AutoScaleTransform
>>> from piblin_jax.data.datasets import OneDimensionalDataset
>>> import numpy as np
>>>
>>> # Scale to [0, 1]
>>> transform = AutoScaleTransform()
>>> dataset = OneDimensionalDataset(
...     independent_variable_data=np.array([1, 2, 3]),
...     dependent_variable_data=np.array([10, 20, 30])
... )
>>> result = transform.apply_to(dataset)
>>> # result.dependent_variable_data is now [0, 0.5, 1]
>>>
>>> # Scale to [-1, 1]
>>> transform = AutoScaleTransform(target_min=-1.0, target_max=1.0)
Notes
Handles constant data (where min == max) by setting all values to target_min
Preserves data ordering and relative differences
Computed parameters: ‘scale’ (multiplicative factor) and ‘offset’ (additive shift)
Methods
__call__(target[, make_copy, ...]): Shorthand for apply_to.
apply_to(target[, make_copy, ...]): Apply transform to Dataset.
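The two computed parameters, 'scale' and 'offset', follow from mapping [min, max] linearly onto [target_min, target_max]:

- scale = (target_max − target_min) / (max − min)
- offset = target_min − min × scale

The sketch below assumes they are applied as `y * scale + offset`; the docs name the parameters but not the exact formula. It also reproduces the documented constant-data behaviour, where every value is set to target_min. Both function names are hypothetical.

```python
import numpy as np


def compute_scale_params(y, target_min=0.0, target_max=1.0):
    """Hypothetical sketch: linear map of [min(y), max(y)] onto [target_min, target_max]."""
    y_min, y_max = float(np.min(y)), float(np.max(y))
    if y_max == y_min:
        # Constant data: every value maps to target_min, as the Notes describe.
        return {"scale": 0.0, "offset": target_min}
    scale = (target_max - target_min) / (y_max - y_min)
    offset = target_min - y_min * scale
    return {"scale": scale, "offset": offset}


def apply_scale(y, params):
    return np.asarray(y, dtype=float) * params["scale"] + params["offset"]


y = np.array([10.0, 20.0, 30.0])
unit = apply_scale(y, compute_scale_params(y))                  # approx. [0.0, 0.5, 1.0]
symmetric = apply_scale(y, compute_scale_params(y, -1.0, 1.0))  # approx. [-1.0, 0.0, 1.0]
flat = apply_scale(np.full(3, 7.0), compute_scale_params(np.full(3, 7.0)))
```

The map is affine, so data ordering and relative differences are preserved, consistent with the Notes.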
- class piblin_jax.transform.lambda_transform.DynamicTransform
Bases: DatasetTransform
Base class for transforms with data-driven parameters.
Dynamic transforms compute parameters from the data itself, then apply transformations using those parameters. This enables adaptive processing where the transformation depends on data characteristics.
Subclasses must implement:
- _compute_parameters(dataset): Extract parameters from data
- _apply_with_parameters(dataset, params): Apply transformation
- Parameters:
None
Examples
>>> from piblin_jax.transform.lambda_transform import DynamicTransform
>>> from piblin_jax.backend import jnp
>>>
>>> class CustomDynamicTransform(DynamicTransform):
...     def _compute_parameters(self, dataset):
...         y_data = dataset._dependent_variable_data
...         return {'mean': jnp.mean(y_data)}
...
...     def _apply_with_parameters(self, dataset, params):
...         dataset._dependent_variable_data -= params['mean']
...         return dataset
Notes
Parameters are computed fresh each time the transform is applied
Caching can be implemented in subclasses if needed
Only works with OneDimensionalDataset
Methods
__call__(target[, make_copy, ...]): Shorthand for apply_to.
apply_to(target[, make_copy, ...]): Apply transform to Dataset.
- class piblin_jax.transform.lambda_transform.LambdaTransform(func=None, use_x=False, jit_compile=True, lambda_func=None)
Bases: DatasetTransform
Transform that wraps an arbitrary function.
Allows users to create custom transforms from simple functions without defining new classes. Supports both simple functions that operate on the dependent variable only, and functions that use both independent and dependent variables.
- Parameters:
func (Callable) – Function to apply to the dependent variable. Signature: func(y_data: ndarray) -> ndarray, or func(x_data: ndarray, y_data: ndarray) -> ndarray
use_x (bool, default False) – If True, pass both x and y to func. If False, pass only y to func.
jit_compile (bool, default True) – If True, attempt JIT compilation of the function. Falls back to the regular function if compilation fails.
Examples
>>> from piblin_jax.transform import LambdaTransform
>>> from piblin_jax.data.datasets import OneDimensionalDataset
>>> import numpy as np
>>>
>>> # Simple function
>>> transform = LambdaTransform(lambda y: y * 2.0)
>>> dataset = OneDimensionalDataset(
...     independent_variable_data=np.array([1, 2, 3]),
...     dependent_variable_data=np.array([2, 4, 6])
... )
>>> result = transform.apply_to(dataset)
>>>
>>> # Function using x and y
>>> transform = LambdaTransform(
...     lambda x, y: y / x.max(),
...     use_x=True
... )
>>>
>>> # JAX-compatible function for JIT
>>> import jax.numpy as jnp
>>> transform = LambdaTransform(
...     lambda y: jnp.exp(y) * jnp.sin(y),
...     jit_compile=True
... )
Notes
Functions should use jnp instead of np for JIT compilation
JIT compilation improves performance but requires JAX-compatible code
If compilation fails, the transform falls back to the uncompiled function
Only works with OneDimensionalDataset
Methods
__call__(target[, make_copy, ...]): Shorthand for apply_to.
apply_to(target[, make_copy, ...]): Apply transform to Dataset.
- __init__(func=None, use_x=False, jit_compile=True, lambda_func=None)
Initialize lambda transform.
- Parameters:
- Raises:
TypeError – If func is not callable
ValueError – If neither func nor lambda_func is provided
Notes
The function should operate on arrays, not dataset objects. For example, use lambda y: y * 2.0, not lambda ds: ds.dependent_variable_data * 2.0.
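Wrapping an array function with optional JIT and a fallback can be sketched as follows. `wrap_array_function` is a hypothetical helper, not the library's implementation; it only mirrors the documented `use_x` / `jit_compile` behaviour. Because `jax.jit` compiles lazily, tracing errors appear at the first call rather than at wrap time. The sketch therefore retries the uncompiled function when the compiled call raises. Genuine bugs are not hidden by this: they simply re-raise from the uncompiled call.

```python
import numpy as np


def wrap_array_function(func, use_x=False, jit_compile=True):
    """Hypothetical sketch of LambdaTransform-style wrapping (not the library code).

    Returns apply(x, y) -> ndarray that calls func(y), or func(x, y) when use_x=True.
    """
    kernel = func
    if jit_compile:
        try:
            import jax
            kernel = jax.jit(func)
        except ImportError:
            kernel = func  # JAX missing: keep the plain Python function

    def apply(x, y):
        args = (x, y) if use_x else (y,)
        try:
            out = kernel(*args)  # jax.jit compiles lazily, so errors surface here
        except Exception:
            # Not traceable by JAX (e.g. calls float() on an array element):
            # retry with the uncompiled function.
            out = func(*args)
        return np.asarray(out)

    return apply


x = np.array([1.0, 2.0, 4.0])
y = np.array([2.0, 4.0, 6.0])
doubled = wrap_array_function(lambda y: y * 2.0)(x, y)
relative = wrap_array_function(lambda x, y: y / x.max(), use_x=True)(x, y)
python_only = wrap_array_function(lambda y: np.array([float(v) + 1.0 for v in y]))(x, y)
```

The last call uses Python-level `float()` conversions that JAX cannot trace, so with JAX installed it exercises the fallback path.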
Dataset Transforms
Overview
Core dataset-level transforms for piblin-jax.
This module provides fundamental transforms for processing 1D datasets:
- Interpolation: Resample to new x-values
- Smoothing: Reduce noise (moving average, Gaussian)
- Baseline correction: Remove systematic offsets and drifts
- Normalization: Scale data to standard ranges
- Calculus: Derivatives and integration
All transforms are JAX-compatible with JIT compilation support and graceful fallback to NumPy when JAX is unavailable.
- class piblin_jax.transform.dataset.AsymmetricLeastSquaresBaseline(lambda_=100000.0, p=0.01, max_iter=10)
Bases: DatasetTransform
Fit and subtract a baseline using the Asymmetric Least Squares (ALS) method.
This advanced baseline correction method is particularly effective for data with positive peaks on a varying baseline (e.g., spectra, chromatograms). The ALS method penalizes positive residuals more than negative ones, causing the baseline to fit below peaks.
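- Parameters:
lambda_ (float, default 100000.0) – Smoothness parameter (typical: 1e4 to 1e7)
p (float, default 0.01) – Asymmetry parameter (typical: 0.001 to 0.1)
max_iter (int, default 10) – Maximum number of reweighting iterations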
Examples
>>> import numpy as np
>>> from piblin_jax.data.datasets import OneDimensionalDataset
>>> from piblin_jax.transform.dataset import AsymmetricLeastSquaresBaseline
>>>
>>> # Create spectrum with peaks on curved baseline
>>> x = np.linspace(0, 100, 1000)
>>> baseline = 10 + 0.1 * x + 0.001 * x**2
>>> peaks = 50 * np.exp(-((x - 30)**2) / 20)
>>> peaks += 30 * np.exp(-((x - 70)**2) / 15)
>>> y = baseline + peaks
>>> dataset = OneDimensionalDataset(x, y)
>>>
>>> # Remove baseline using ALS
>>> transform = AsymmetricLeastSquaresBaseline(lambda_=1e6, p=0.01)
>>> result = transform.apply_to(dataset)
Notes
ALS is iterative and may be slower than polynomial baseline
Very effective for spectroscopy and chromatography data
lambda_ controls smoothness (typical: 1e4 to 1e7)
p controls asymmetry (typical: 0.001 to 0.1)
Based on the method of Eilers & Boelens (2005)
References
P. H. C. Eilers and H. F. M. Boelens, “Baseline Correction with Asymmetric Least Squares Smoothing”, Leiden University Medical Centre Report, 2005.
Methods
__call__(target[, make_copy, ...]): Shorthand for apply_to.
apply_to(target[, make_copy, ...]): Apply transform to Dataset.
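The Eilers & Boelens iteration alternates two steps. First, solve a weighted smoothing problem (W + λ DᵀD) z = W y, where D is the second-difference matrix. Then reweight: points above the current baseline get weight p, points below get 1 − p. The sketch below is a **dense** NumPy version suitable only for short signals. The paper uses sparse matrices, which a practical implementation would need for large n. The names `als_baseline`, `lam` and `n_iter` are hypothetical, and this is not the library's code.

```python
import numpy as np


def als_baseline(y, lam=1e5, p=0.01, n_iter=10):
    """Dense NumPy sketch of Eilers & Boelens ALS baseline estimation (small n only)."""
    y = np.asarray(y, dtype=float)
    n = y.size
    # Second-order difference matrix D with shape (n - 2, n): rows of [1, -2, 1].
    D = np.diff(np.eye(n), n=2, axis=0)
    H = lam * D.T @ D            # roughness penalty; linear trends cost nothing
    w = np.ones(n)               # start with equal weights
    for _ in range(n_iter):
        z = np.linalg.solve(np.diag(w) + H, w * y)
        # Asymmetric reweighting: peaks (y > z) are almost ignored in the next fit.
        w = np.where(y > z, p, 1.0 - p)
    return z


x = np.linspace(0, 100, 200)
true_baseline = 10 + 0.1 * x
peak = 50 * np.exp(-((x - 50) ** 2) / 20)
y = true_baseline + peak
z = als_baseline(y, lam=1e5, p=0.01)
corrected = y - z   # close to zero away from the peak, close to 50 at its top
```

A straight-line baseline lies in the null space of D, so it is fitted with no penalty at all. Larger `lam` only stiffens the curved components, which is why λ acts as the smoothness control.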
- class piblin_jax.transform.dataset.CumulativeIntegral(method='trapezoid')
Bases: DatasetTransform
Compute cumulative integral of 1D dataset.
This transform computes the cumulative integral (running sum) of the dependent variable with respect to the independent variable, using the trapezoidal rule by default.
- Parameters:
method (str, default 'trapezoid') – Integration method:
- ‘trapezoid’: Trapezoidal rule (2nd order accurate)
- ‘simpson’: Simpson’s rule (4th order accurate, requires an odd number of points)
Examples
>>> import numpy as np
>>> from piblin_jax.data.datasets import OneDimensionalDataset
>>> from piblin_jax.transform.dataset import CumulativeIntegral
>>>
>>> # Create constant function (integral should be linear)
>>> x = np.linspace(0, 10, 100)
>>> y = np.ones_like(x)  # Integral of 1 is x
>>> dataset = OneDimensionalDataset(
...     independent_variable_data=x,
...     dependent_variable_data=y
... )
>>>
>>> # Compute cumulative integral
>>> integral = CumulativeIntegral()
>>> result = integral.apply_to(dataset)
>>> # Result should be approximately linear (x)
>>> result.dependent_variable_data[-1]  # Should be ~10
10.0
Notes
Trapezoidal rule: I[i] = sum over k = 1..i of (y[k] + y[k-1]) / 2 * (x[k] - x[k-1])
First value is always 0 (integral from x[0] to x[0])
JIT-compiled with JAX backend
Handles non-uniform spacing
For smoother results on noisy data, consider smoothing first
Methods
__call__(target[, make_copy, ...]): Shorthand for apply_to.
apply_to(target[, make_copy, ...]): Apply transform to Dataset.
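The trapezoidal formula in the Notes translates directly into vectorized NumPy: one trapezoid area per interval, then a running sum with I[0] = 0. `cumulative_trapezoid` is a hypothetical name for this sketch. Using `np.diff(x)` for the widths is what makes non-uniform spacing work.

```python
import numpy as np


def cumulative_trapezoid(x, y):
    """Running trapezoidal integral with I[0] = 0; handles non-uniform spacing."""
    x = np.asarray(x, dtype=float)
    y = np.asarray(y, dtype=float)
    areas = 0.5 * (y[1:] + y[:-1]) * np.diff(x)   # area of each trapezoid
    return np.concatenate(([0.0], np.cumsum(areas)))


x = np.array([0.0, 1.0, 3.0, 6.0])   # non-uniform spacing
y = 2.0 * x                          # exact integral is x**2
I = cumulative_trapezoid(x, y)       # [0.0, 1.0, 9.0, 36.0]
```

The trapezoidal rule is exact for linear integrands, so this example reproduces x² at every sample.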
- class piblin_jax.transform.dataset.DefiniteIntegral(x_min=None, x_max=None, method='trapezoid')
Bases: DatasetTransform
Compute definite integral over specified region.
This transform computes the definite integral (total area) under the curve between specified x-values.
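- Parameters:
x_min (float, optional) – Lower integration limit
x_max (float, optional) – Upper integration limit
method (str, default 'trapezoid') – Integration method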
Examples
>>> import numpy as np
>>> from piblin_jax.data.datasets import OneDimensionalDataset
>>> from piblin_jax.transform.dataset import DefiniteIntegral
>>>
>>> # Create data
>>> x = np.linspace(0, np.pi, 100)
>>> y = np.sin(x)  # Integral from 0 to pi is 2
>>> dataset = OneDimensionalDataset(x, y)
>>>
>>> # Compute definite integral
>>> integral = DefiniteIntegral()
>>> result = integral.apply_to(dataset)
>>> # Result stores integral value in metadata
Notes
Returns dataset with integral value stored in details
Original data is preserved
For cumulative integral, use CumulativeIntegral instead
Methods
__call__(target[, make_copy, ...]): Shorthand for apply_to.
apply_to(target[, make_copy, ...]): Apply transform to Dataset.
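A definite integral over [x_min, x_max] can be sketched by masking the samples to the requested range and summing trapezoid areas. In this sketch, a limit of None means "use the data's own end", and partial intervals at the limits are **not** interpolated. Both are assumptions about behaviour the docs do not specify. `definite_trapezoid` is a hypothetical helper.

```python
import numpy as np


def definite_trapezoid(x, y, x_min=None, x_max=None):
    """Trapezoidal area under y(x), restricted to samples with x_min <= x <= x_max."""
    x = np.asarray(x, dtype=float)
    y = np.asarray(y, dtype=float)
    mask = np.ones(x.shape, dtype=bool)
    if x_min is not None:
        mask &= x >= x_min
    if x_max is not None:
        mask &= x <= x_max
    xs, ys = x[mask], y[mask]
    # Sum of trapezoid areas; partial intervals at the limits are not interpolated.
    return float(np.sum(0.5 * (ys[1:] + ys[:-1]) * np.diff(xs)))


x = np.linspace(0.0, np.pi, 1001)
y = np.sin(x)
total = definite_trapezoid(x, y)                         # close to 2
first_half = definite_trapezoid(x, y, x_max=np.pi / 2)   # close to 1
```

Because the limits snap to the nearest included samples, accuracy at the limits depends on sample spacing. Interpolating y at x_min and x_max would remove that error.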
- class piblin_jax.transform.dataset.Derivative(order=1, method='gradient')
Bases: DatasetTransform
Compute numerical derivative of 1D dataset.
This transform computes numerical derivatives using finite differences. Supports first and second derivatives with various accuracy schemes.
- Parameters:
order (int, default 1) – Derivative order (1 or 2):
  - 1: First derivative (dy/dx)
  - 2: Second derivative (d²y/dx²)
method (str, default 'gradient') – Method for computing the derivative:
  - ‘gradient’: Central differences (2nd-order accurate)
  - ‘forward’: Forward differences (1st-order accurate)
  - ‘backward’: Backward differences (1st-order accurate)
- Raises:
ValueError – If order is not 1 or 2.
Examples
>>> import numpy as np
>>> from piblin_jax.data.datasets import OneDimensionalDataset
>>> from piblin_jax.transform.dataset import Derivative
>>>
>>> # Create data with known derivative
>>> x = np.linspace(0, 10, 100)
>>> y = x**2  # dy/dx = 2x
>>> dataset = OneDimensionalDataset(
...     independent_variable_data=x,
...     dependent_variable_data=y
... )
>>>
>>> # Compute first derivative
>>> deriv = Derivative(order=1)
>>> result = deriv.apply_to(dataset)
>>> # Result should be approximately 2*x
>>>
>>> # Compute second derivative
>>> deriv2 = Derivative(order=2)
>>> result2 = deriv2.apply_to(dataset)
>>> # Result should be approximately 2 (constant)
Notes
Uses jnp.gradient for central differences
Gradient method provides 2nd order accuracy
Handles non-uniform spacing in x
JIT-compiled with JAX backend
Edge effects present at boundaries
For noisy data, consider smoothing before differentiation
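The notes name jnp.gradient. This plain-Python sketch uses the three-point formula that NumPy documents for numpy.gradient on non-uniform grids, with first-order one-sided differences at the two end points (edge_order=1). The helper name is hypothetical. How piblin_jax builds `order=2` (for example, by applying the gradient twice) is not stated above.

```python
def gradient(x, y):
    """Second-order central differences on a possibly non-uniform grid.

    Interior points use the three-point formula documented for
    numpy.gradient; the end points use one-sided first-order
    differences, which is where the boundary edge effects come from.
    """
    n = len(x)
    if n < 2:
        raise ValueError("need at least two points")
    out = [0.0] * n
    out[0] = (y[1] - y[0]) / (x[1] - x[0])
    out[-1] = (y[-1] - y[-2]) / (x[-1] - x[-2])
    for i in range(1, n - 1):
        hs = x[i] - x[i - 1]  # spacing to the left
        hd = x[i + 1] - x[i]  # spacing to the right
        numerator = hs**2 * y[i + 1] + (hd**2 - hs**2) * y[i] - hd**2 * y[i - 1]
        out[i] = numerator / (hs * hd * (hs + hd))
    return out


# y = x**2 on an uneven grid: interior values equal 2x exactly
x = [0.0, 1.0, 3.0, 4.0, 6.0]
y = [xi**2 for xi in x]
print(gradient(x, y))  # [1.0, 2.0, 6.0, 8.0, 10.0]
```

The interior values are exact for a quadratic, which is what "2nd-order accurate" means here. The two end values (1.0 and 10.0 instead of 0.0 and 12.0) show the boundary error.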
Methods
__call__(target[, make_copy, ...]) – Shorthand for apply_to.
apply_to(target[, make_copy, ...]) – Apply transform to Dataset.
- __init__(order=1, method='gradient')
Initialize derivative transform.
- Parameters:
- Raises:
ValueError – If order is not 1 or 2.
- class piblin_jax.transform.dataset.GaussianSmooth(sigma=1.0, truncate=3.0)
Bases: DatasetTransform
Smooth data using Gaussian filter.
This transform applies a Gaussian filter for smoothing, which provides better frequency response than simple moving average.
- Parameters:
Examples
>>> from piblin_jax.data.datasets import OneDimensionalDataset
>>> from piblin_jax.transform.dataset import GaussianSmooth
>>> import numpy as np
>>>
>>> x = np.linspace(0, 10, 100)
>>> y = np.sin(x) + 0.3 * np.random.randn(100)
>>> dataset = OneDimensionalDataset(x, y)
>>>
>>> smooth = GaussianSmooth(sigma=2.0)
>>> result = smooth.apply_to(dataset)
Notes
Gaussian smoothing preserves features better than moving average
sigma controls the amount of smoothing
Larger sigma = more smoothing
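A Gaussian filter is a weighted moving average whose weights follow a normalized Gaussian. This sketch makes three assumptions that the documentation above does not spell out:

- `sigma` is measured in samples.
- `truncate` cuts the kernel off at `int(truncate * sigma + 0.5)` samples, the convention scipy.ndimage.gaussian_filter1d uses.
- Boundaries are handled by repeating the end values.

Both helper names are hypothetical.

```python
import math


def gaussian_kernel(sigma, truncate=3.0):
    """Normalized Gaussian weights, cut off at about truncate * sigma samples."""
    radius = int(truncate * sigma + 0.5)
    weights = [math.exp(-0.5 * (k / sigma) ** 2) for k in range(-radius, radius + 1)]
    total = sum(weights)
    return [w / total for w in weights]


def gaussian_smooth(y, sigma=1.0, truncate=3.0):
    """Convolve y with the Gaussian kernel; indices past the ends clamp to the edge."""
    kernel = gaussian_kernel(sigma, truncate)
    radius = len(kernel) // 2
    n = len(y)
    out = []
    for i in range(n):
        acc = 0.0
        for k, w in enumerate(kernel):
            j = min(max(i + k - radius, 0), n - 1)  # clamp at the boundaries
            acc += w * y[j]
        out.append(acc)
    return out


# A unit spike is spread out but keeps its total area
spike = [0.0, 0.0, 0.0, 1.0, 0.0, 0.0, 0.0]
smoothed = gaussian_smooth(spike, sigma=1.0)
print(round(max(smoothed), 3))  # 0.399
```

Increasing `sigma` widens the kernel and lowers this peak further, which is the "larger sigma = more smoothing" behaviour noted above.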
Methods
__call__(target[, make_copy, ...]) – Shorthand for apply_to.
apply_to(target[, make_copy, ...]) – Apply transform to Dataset.
- class piblin_jax.transform.dataset.Interpolate1D(new_x, method='linear')
Bases: DatasetTransform
Interpolate 1D dataset to new x-values.
This transform resamples a 1D dataset to a new set of independent variable values using linear interpolation. It supports both JAX and NumPy backends with automatic fallback.
- Parameters:
new_x (array-like) – New independent variable values for interpolation.
method (str, default 'linear') – Interpolation method. Currently only ‘linear’ is fully supported. Future versions may support ‘cubic’, ‘spline’, etc.
- Attributes:
new_x (ndarray) – Target x-values for interpolation.
Examples
>>> import numpy as np
>>> from piblin_jax.data.datasets import OneDimensionalDataset
>>> from piblin_jax.transform.dataset import Interpolate1D
>>>
>>> # Create original dataset
>>> x = np.array([0, 1, 2, 3, 4])
>>> y = np.array([0, 1, 4, 9, 16])  # y = x^2
>>> dataset = OneDimensionalDataset(
...     independent_variable_data=x,
...     dependent_variable_data=y
... )
>>>
>>> # Interpolate to finer grid
>>> new_x = np.linspace(0, 4, 17)  # 17 points from 0 to 4
>>> interp = Interpolate1D(new_x, method='linear')
>>> result = interp.apply_to(dataset)
>>>
>>> # Result has new x-values with interpolated y-values
>>> result.independent_variable_data.shape
(17,)
Notes
Linear interpolation is used for both JAX and NumPy backends
For JAX backend, uses jnp.interp (compiled with JIT)
For NumPy backend, uses np.interp
Extrapolation uses constant values (edge values)
Metadata (conditions, details) is preserved from original dataset
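The constant-edge extrapolation described in these notes is numpy.interp's convention. Here is a plain-Python sketch of that rule. The helper `interp_linear` is hypothetical and assumes `x` is strictly increasing.

```python
from bisect import bisect_right


def interp_linear(new_x, x, y):
    """Piecewise-linear interpolation with constant (edge-value) extrapolation.

    Assumes x is strictly increasing, as numpy.interp does.
    """
    out = []
    for xv in new_x:
        if xv <= x[0]:
            out.append(y[0])  # left of the data: hold the first value
        elif xv >= x[-1]:
            out.append(y[-1])  # right of the data: hold the last value
        else:
            j = bisect_right(x, xv)  # x[j - 1] <= xv < x[j]
            t = (xv - x[j - 1]) / (x[j] - x[j - 1])
            out.append(y[j - 1] + t * (y[j] - y[j - 1]))
    return out


# Same y = x^2 samples as the example above
x = [0, 1, 2, 3, 4]
y = [0, 1, 4, 9, 16]
print(interp_linear([-1, 0.5, 2.5, 5], x, y))  # [0, 0.5, 6.5, 16]
```

Note that 6.5 differs from the true 2.5**2 = 6.25. Linear interpolation only reproduces the underlying curve between samples when that curve is itself linear.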
Methods
__call__(target[, make_copy, ...]) – Shorthand for apply_to.
apply_to(target[, make_copy, ...]) – Apply transform to Dataset.
- class piblin_jax.transform.dataset.MaxNormalize
Bases: DatasetTransform
Normalize data by dividing by maximum absolute value.
This simple normalization scales data so that the maximum absolute value is 1. Preserves zero and sign of data.
Examples
>>> import numpy as np
>>> from piblin_jax.data.datasets import OneDimensionalDataset
>>> from piblin_jax.transform.dataset import MaxNormalize
>>>
>>> x = np.linspace(0, 10, 100)
>>> y = np.linspace(-50, 100, 100)
>>> dataset = OneDimensionalDataset(x, y)
>>>
>>> transform = MaxNormalize()
>>> result = transform.apply_to(dataset)
>>> np.max(np.abs(result.dependent_variable_data))
1.0
Notes
Formula: y_norm = y / max(abs(y))
Preserves zero and sign
Maximum absolute value becomes 1
Simple and fast
Methods
__call__(target[, make_copy, ...]) – Shorthand for apply_to.
apply_to(target[, make_copy, ...]) – Apply transform to Dataset.
- class piblin_jax.transform.dataset.MinMaxNormalize(feature_range=(0, 1))
Bases: DatasetTransform
Min-max normalization to scale data to a specific range.
This transform scales the dependent variable to a target range, typically [0, 1]. This is useful for comparing datasets with different scales or preparing data for machine learning.
- Parameters:
feature_range (tuple of float, default (0, 1)) – Target range for normalization as (min, max).
Examples
>>> import numpy as np
>>> from piblin_jax.data.datasets import OneDimensionalDataset
>>> from piblin_jax.transform.dataset import MinMaxNormalize
>>>
>>> # Create data with arbitrary range
>>> x = np.linspace(0, 10, 100)
>>> y = np.linspace(5, 25, 100)  # Range: 5 to 25
>>> dataset = OneDimensionalDataset(
...     independent_variable_data=x,
...     dependent_variable_data=y
... )
>>>
>>> # Normalize to [0, 1]
>>> transform = MinMaxNormalize()
>>> result = transform.apply_to(dataset)
>>> # The epsilon in the denominator makes the extremes approximate, not exact
>>> bool(np.isclose(np.min(result.dependent_variable_data), 0.0))
True
>>> bool(np.isclose(np.max(result.dependent_variable_data), 1.0))
True
>>>
>>> # Normalize to custom range
>>> transform = MinMaxNormalize(feature_range=(-1, 1))
>>> result = transform.apply_to(dataset)
Notes
Formula: y_scaled = (y - y_min) / (y_max - y_min) * (max - min) + min
Small epsilon added to denominator to avoid division by zero
JIT-compiled with JAX backend for efficiency
Independent variable is preserved
Metadata is preserved
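The min-max formula in these notes translates directly into a few lines of Python. This sketch is illustrative only. The helper name is hypothetical, and `eps=1e-12` is an assumed value, since the notes only say that a small epsilon is used.

```python
def min_max_normalize(y, feature_range=(0.0, 1.0), eps=1e-12):
    """Scale y linearly so that min(y) -> range min and max(y) -> range max.

    eps (an assumed value) keeps a constant signal from dividing by zero.
    """
    lo, hi = feature_range
    y_min, y_max = min(y), max(y)
    scale = (hi - lo) / (y_max - y_min + eps)
    return [(v - y_min) * scale + lo for v in y]


print([round(v, 6) for v in min_max_normalize([5.0, 10.0, 25.0])])  # [0.0, 0.25, 1.0]
print([round(v, 6) for v in min_max_normalize([5.0, 25.0], (-1.0, 1.0))])  # [-1.0, 1.0]
```

Because of `eps`, the top of the range is reached only to within rounding. That is why comparisons against exactly 1.0 are best done with a tolerance.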
Methods
__call__(target[, make_copy, ...]) – Shorthand for apply_to.
apply_to(target[, make_copy, ...]) – Apply transform to Dataset.
- class piblin_jax.transform.dataset.MovingAverageSmooth(window_size=5)
Bases: DatasetTransform
Smooth data using moving average filter.
This transform applies a simple moving average (box filter) to smooth noisy data. It uses convolution for efficient computation.
- Parameters:
window_size (int, default 5) – Size of the moving average window. Must be odd to ensure symmetry.
- Raises:
ValueError – If window_size is not odd.
Examples
>>> import numpy as np
>>> from piblin_jax.data.datasets import OneDimensionalDataset
>>> from piblin_jax.transform.dataset import MovingAverageSmooth
>>>
>>> # Create noisy data
>>> x = np.linspace(0, 10, 100)
>>> y = np.sin(x) + 0.5 * np.random.randn(100)
>>> dataset = OneDimensionalDataset(
...     independent_variable_data=x,
...     dependent_variable_data=y
... )
>>>
>>> # Apply smoothing
>>> smooth = MovingAverageSmooth(window_size=5)
>>> result = smooth.apply_to(dataset)
>>>
>>> # Result has smoothed y-values (x unchanged)
>>> result.independent_variable_data  # Same as original
>>> result.dependent_variable_data  # Smoothed version
Notes
Uses ‘same’ mode for convolution (output same size as input)
Edge effects present at boundaries (first/last few points)
For JAX backend, convolution is JIT-compiled for efficiency
Window must be odd to ensure symmetric smoothing
Larger windows = more smoothing but more distortion
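'Same'-mode convolution with a box kernel treats samples outside the signal as zero, and this is where the edge attenuation in these notes comes from. Below is a plain-Python sketch of that behaviour. It assumes an odd window and a signal at least as long as the window, and the helper name is hypothetical.

```python
def moving_average_same(y, window_size=5):
    """Box-filter smoothing with numpy.convolve(y, ones(w) / w, mode='same') semantics.

    Out-of-range samples count as zero, so the first and last
    window_size // 2 outputs are pulled toward zero.
    """
    if window_size % 2 == 0:
        raise ValueError("window_size must be odd")
    half = window_size // 2
    n = len(y)
    out = []
    for i in range(n):
        window = y[max(0, i - half): min(n, i + half + 1)]
        # Divide by the full width, exactly as zero-padding would
        out.append(sum(window) / window_size)
    return out


print(moving_average_same([3.0] * 5, window_size=3))  # [2.0, 3.0, 3.0, 3.0, 2.0]
```

A constant signal should come back unchanged. The drop to 2.0 at both ends is the boundary effect, and it grows with the window size.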
Methods
__call__(target[, make_copy, ...]) – Shorthand for apply_to.
apply_to(target[, make_copy, ...]) – Apply transform to Dataset.
- __init__(window_size=5)
Initialize moving average smoothing transform.
- Parameters:
window_size (int, default 5) – Size of moving average window (must be odd).
- Raises:
ValueError – If window_size is even.
- class piblin_jax.transform.dataset.PolynomialBaseline(degree=1)
Bases: DatasetTransform
Fit and subtract polynomial baseline from data.
This transform fits a polynomial to the data and subtracts it, removing systematic trends and offsets. Commonly used in spectroscopy to remove background signals.
- Parameters:
degree (int, default 1) – Degree of polynomial to fit:
  - 0: Constant offset
  - 1: Linear drift
  - 2: Quadratic curvature
  - Higher: More complex baselines
Examples
>>> import numpy as np
>>> from piblin_jax.data.datasets import OneDimensionalDataset
>>> from piblin_jax.transform.dataset import PolynomialBaseline
>>>
>>> # Create data with linear drift
>>> x = np.linspace(0, 10, 100)
>>> signal = np.sin(x)
>>> baseline = 2.0 * x + 5.0  # Linear drift
>>> y = signal + baseline
>>> dataset = OneDimensionalDataset(
...     independent_variable_data=x,
...     dependent_variable_data=y
... )
>>>
>>> # Remove linear baseline
>>> transform = PolynomialBaseline(degree=1)
>>> result = transform.apply_to(dataset)
>>>
>>> # Result is close to the original signal; the line fit also absorbs
>>> # part of sin(x)'s own offset and trend on [0, 10], hence the tolerance
>>> np.allclose(result.dependent_variable_data, signal, atol=0.3)
True
Notes
Uses least-squares polynomial fitting (np.polyfit)
Works for both JAX and NumPy backends (uses NumPy for fitting)
Higher degree polynomials may overfit to noise
For complex baselines, consider spline-based methods
Independent variable (x) is preserved
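For degree=1, the least-squares fit has a closed form, so the subtraction can be sketched without np.polyfit. The sketch covers degree 1 only and uses a hypothetical helper. It also shows a practical caveat: the fitted line absorbs any offset or trend that the signal itself carries. For sin(x) on [0, 10], that comes to roughly 0.27 at x = 0.

```python
import math


def subtract_linear_baseline(x, y):
    """Fit y ~ slope * x + intercept by least squares and return the residuals."""
    n = len(x)
    x_mean = sum(x) / n
    y_mean = sum(y) / n
    sxx = sum((xi - x_mean) ** 2 for xi in x)
    sxy = sum((xi - x_mean) * (yi - y_mean) for xi, yi in zip(x, y))
    slope = sxy / sxx
    intercept = y_mean - slope * x_mean
    return [yi - (slope * xi + intercept) for xi, yi in zip(x, y)]


xs = [10 * i / 99 for i in range(100)]  # 100 evenly spaced points on [0, 10]
signal = [math.sin(v) for v in xs]
ys = [s + 2.0 * v + 5.0 for s, v in zip(signal, xs)]
residual = subtract_linear_baseline(xs, ys)
# Largest gap between the corrected data and the true signal
gap = max(abs(r - s) for r, s in zip(residual, signal))
print(gap)
```

The drift itself (2x + 5) is removed exactly. The remaining gap is the part of sin(x) that a straight line can explain, which is why a loose tolerance is needed when comparing against the clean signal.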
Methods
__call__(target[, make_copy, ...]) – Shorthand for apply_to.
apply_to(target[, make_copy, ...]) – Apply transform to Dataset.
- class piblin_jax.transform.dataset.RobustNormalize
Bases: DatasetTransform
Robust normalization using median and IQR.
This transform normalizes data using median and interquartile range (IQR) instead of mean and standard deviation. More robust to outliers than z-score normalization.
Examples
>>> import numpy as np
>>> from piblin_jax.data.datasets import OneDimensionalDataset
>>> from piblin_jax.transform.dataset import RobustNormalize
>>>
>>> # Create data with outliers
>>> x = np.linspace(0, 10, 100)
>>> y = np.random.randn(100)
>>> y[0] = 100  # Outlier
>>> dataset = OneDimensionalDataset(x, y)
>>>
>>> # Robust normalization (less affected by outlier)
>>> transform = RobustNormalize()
>>> result = transform.apply_to(dataset)
Notes
Formula: y_robust = (y - median(y)) / IQR(y)
IQR = Q3 - Q1 (interquartile range)
More robust to outliers than z-score
JIT-compiled with JAX backend
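The median and IQR can be computed with the standard library's statistics module. This sketch assumes quartiles use linear interpolation between order statistics: statistics.quantiles with method="inclusive", which matches numpy.percentile's default. The documentation above does not say whether piblin_jax uses the same quartile rule. The helper name is hypothetical, and a zero IQR would raise ZeroDivisionError here.

```python
import statistics


def robust_normalize(y):
    """(y - median) / IQR, with quartiles from linear interpolation."""
    q1, _, q3 = statistics.quantiles(y, n=4, method="inclusive")
    iqr = q3 - q1
    med = statistics.median(y)
    return [(v - med) / iqr for v in y]


# The bulk stays within [-1, 1]; the outlier maps far away but does not
# distort the scaling of the other points
print(robust_normalize([1, 2, 3, 4, 5, 100]))  # [-1.0, -0.6, -0.2, 0.2, 0.6, 38.6]
```

With z-scoring, the same outlier would inflate the standard deviation and squeeze the five ordinary points together. This is the robustness difference the notes describe.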
Methods
__call__(target[, make_copy, ...]) – Shorthand for apply_to.
apply_to(target[, make_copy, ...]) – Apply transform to Dataset.
- class piblin_jax.transform.dataset.ZScoreNormalize
Bases: DatasetTransform
Z-score normalization (standardization).
This transform standardizes data to have zero mean and unit variance. Also known as standardization or z-score transformation. Useful for comparing datasets with different units or scales.
Examples
>>> import numpy as np
>>> from piblin_jax.data.datasets import OneDimensionalDataset
>>> from piblin_jax.transform.dataset import ZScoreNormalize
>>>
>>> # Create data with arbitrary mean and std
>>> x = np.linspace(0, 10, 100)
>>> y = 5.0 * np.random.randn(100) + 10.0  # mean=10, std=5
>>> dataset = OneDimensionalDataset(
...     independent_variable_data=x,
...     dependent_variable_data=y
... )
>>>
>>> # Standardize to mean ~0, std ~1 (approximate because of rounding and epsilon)
>>> transform = ZScoreNormalize()
>>> result = transform.apply_to(dataset)
>>> bool(np.isclose(np.mean(result.dependent_variable_data), 0.0, atol=1e-6))
True
>>> bool(np.isclose(np.std(result.dependent_variable_data), 1.0, atol=1e-6))
True
Notes
Formula: y_zscore = (y - mean(y)) / std(y)
Small epsilon added to denominator to avoid division by zero
Results have mean ≈ 0 and standard deviation ≈ 1
JIT-compiled with JAX backend
Preserves shape of distribution
Sensitive to outliers (unlike robust scaling)
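Z-scoring follows the same pattern with the mean and standard deviation. This sketch uses the population standard deviation (ddof=0, NumPy's default for np.std). The epsilon value and the helper name are assumptions.

```python
import statistics


def z_score(y, eps=1e-12):
    """(y - mean) / (std + eps), using the population standard deviation."""
    mu = statistics.fmean(y)
    sigma = statistics.pstdev(y)
    return [(v - mu) / (sigma + eps) for v in y]


# mean 5, population std 2
print([round(v, 3) for v in z_score([2, 4, 4, 4, 5, 5, 7, 9])])
# [-1.5, -0.5, -0.5, -0.5, 0.0, 0.0, 1.0, 2.0]
```

Because the transform is linear, the shape of the distribution is unchanged. Only its location and scale move, as the notes state.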
Methods
__call__(target[, make_copy, ...]) – Shorthand for apply_to.
apply_to(target[, make_copy, ...]) – Apply transform to Dataset.
Smoothing Transforms
Smoothing transforms for 1D datasets.
This module provides various smoothing/filtering transforms to reduce noise in time series and spectral data.
- class piblin_jax.transform.dataset.smoothing.GaussianSmooth(sigma=1.0, truncate=3.0)
Bases: DatasetTransform
Smooth data using Gaussian filter.
This transform applies a Gaussian filter for smoothing, which provides better frequency response than simple moving average.
- Parameters:
Examples
>>> from piblin_jax.data.datasets import OneDimensionalDataset
>>> from piblin_jax.transform.dataset import GaussianSmooth
>>> import numpy as np
>>>
>>> x = np.linspace(0, 10, 100)
>>> y = np.sin(x) + 0.3 * np.random.randn(100)
>>> dataset = OneDimensionalDataset(x, y)
>>>
>>> smooth = GaussianSmooth(sigma=2.0)
>>> result = smooth.apply_to(dataset)
Notes
Gaussian smoothing preserves features better than moving average
sigma controls the amount of smoothing
Larger sigma = more smoothing
Methods
__call__(target[, make_copy, ...]) – Shorthand for apply_to.
apply_to(target[, make_copy, ...]) – Apply transform to Dataset.
- class piblin_jax.transform.dataset.smoothing.MovingAverageSmooth(window_size=5)
Bases: DatasetTransform
Smooth data using moving average filter.
This transform applies a simple moving average (box filter) to smooth noisy data. It uses convolution for efficient computation.
- Parameters:
window_size (int, default 5) – Size of the moving average window. Must be odd to ensure symmetry.
- Attributes:
window_size (int) – Window size for moving average.
- Raises:
ValueError – If window_size is not odd.
Examples
>>> import numpy as np
>>> from piblin_jax.data.datasets import OneDimensionalDataset
>>> from piblin_jax.transform.dataset import MovingAverageSmooth
>>>
>>> # Create noisy data
>>> x = np.linspace(0, 10, 100)
>>> y = np.sin(x) + 0.5 * np.random.randn(100)
>>> dataset = OneDimensionalDataset(
...     independent_variable_data=x,
...     dependent_variable_data=y
... )
>>>
>>> # Apply smoothing
>>> smooth = MovingAverageSmooth(window_size=5)
>>> result = smooth.apply_to(dataset)
>>>
>>> # Result has smoothed y-values (x unchanged)
>>> result.independent_variable_data  # Same as original
>>> result.dependent_variable_data  # Smoothed version
Notes
Uses ‘same’ mode for convolution (output same size as input)
Edge effects present at boundaries (first/last few points)
For JAX backend, convolution is JIT-compiled for efficiency
Window must be odd to ensure symmetric smoothing
Larger windows = more smoothing but more distortion
Methods
__call__(target[, make_copy, ...]) – Shorthand for apply_to.
apply_to(target[, make_copy, ...]) – Apply transform to Dataset.
- __init__(window_size=5)
Initialize moving average smoothing transform.
- Parameters:
window_size (int, default 5) – Size of moving average window (must be odd).
- Raises:
ValueError – If window_size is even.
Interpolation Transforms
Interpolation transforms for 1D datasets.
This module provides interpolation transforms that resample datasets to new x-values using various interpolation methods.
- class piblin_jax.transform.dataset.interpolate.Interpolate1D(new_x, method='linear')
Bases: DatasetTransform
Interpolate 1D dataset to new x-values.
This transform resamples a 1D dataset to a new set of independent variable values using linear interpolation. It supports both JAX and NumPy backends with automatic fallback.
- Parameters:
new_x (array-like) – New independent variable values for interpolation.
method (str, default 'linear') – Interpolation method. Currently only ‘linear’ is fully supported. Future versions may support ‘cubic’, ‘spline’, etc.
- Attributes:
new_x (ndarray) – Target x-values for interpolation.
method (str) – Interpolation method name.
Examples
>>> import numpy as np
>>> from piblin_jax.data.datasets import OneDimensionalDataset
>>> from piblin_jax.transform.dataset import Interpolate1D
>>>
>>> # Create original dataset
>>> x = np.array([0, 1, 2, 3, 4])
>>> y = np.array([0, 1, 4, 9, 16])  # y = x^2
>>> dataset = OneDimensionalDataset(
...     independent_variable_data=x,
...     dependent_variable_data=y
... )
>>>
>>> # Interpolate to finer grid
>>> new_x = np.linspace(0, 4, 17)  # 17 points from 0 to 4
>>> interp = Interpolate1D(new_x, method='linear')
>>> result = interp.apply_to(dataset)
>>>
>>> # Result has new x-values with interpolated y-values
>>> result.independent_variable_data.shape
(17,)
Notes
Linear interpolation is used for both JAX and NumPy backends
For JAX backend, uses jnp.interp (compiled with JIT)
For NumPy backend, uses np.interp
Extrapolation uses constant values (edge values)
Metadata (conditions, details) is preserved from original dataset
Methods
__call__(target[, make_copy, ...]) – Shorthand for apply_to.
apply_to(target[, make_copy, ...]) – Apply transform to Dataset.
Normalization Transforms
Normalization transforms for 1D datasets.
This module provides various normalization and scaling transforms to standardize data for comparison and analysis.
- class piblin_jax.transform.dataset.normalization.MaxNormalize
Bases: DatasetTransform
Normalize data by dividing by maximum absolute value.
This simple normalization scales data so that the maximum absolute value is 1. Preserves zero and sign of data.
Examples
>>> import numpy as np
>>> from piblin_jax.data.datasets import OneDimensionalDataset
>>> from piblin_jax.transform.dataset import MaxNormalize
>>>
>>> x = np.linspace(0, 10, 100)
>>> y = np.linspace(-50, 100, 100)
>>> dataset = OneDimensionalDataset(x, y)
>>>
>>> transform = MaxNormalize()
>>> result = transform.apply_to(dataset)
>>> np.max(np.abs(result.dependent_variable_data))
1.0
Notes
Formula: y_norm = y / max(abs(y))
Preserves zero and sign
Maximum absolute value becomes 1
Simple and fast
Methods
__call__(target[, make_copy, ...]) – Shorthand for apply_to.
apply_to(target[, make_copy, ...]) – Apply transform to Dataset.
- class piblin_jax.transform.dataset.normalization.MinMaxNormalize(feature_range=(0, 1))
Bases: DatasetTransform
Min-max normalization to scale data to a specific range.
This transform scales the dependent variable to a target range, typically [0, 1]. This is useful for comparing datasets with different scales or preparing data for machine learning.
- Parameters:
feature_range (tuple of float, default (0, 1)) – Target range for normalization as (min, max).
- Attributes:
feature_range (tuple of float) – Target range for scaled data.
Examples
>>> import numpy as np
>>> from piblin_jax.data.datasets import OneDimensionalDataset
>>> from piblin_jax.transform.dataset import MinMaxNormalize
>>>
>>> # Create data with arbitrary range
>>> x = np.linspace(0, 10, 100)
>>> y = np.linspace(5, 25, 100)  # Range: 5 to 25
>>> dataset = OneDimensionalDataset(
...     independent_variable_data=x,
...     dependent_variable_data=y
... )
>>>
>>> # Normalize to [0, 1]
>>> transform = MinMaxNormalize()
>>> result = transform.apply_to(dataset)
>>> # The epsilon in the denominator makes the extremes approximate, not exact
>>> bool(np.isclose(np.min(result.dependent_variable_data), 0.0))
True
>>> bool(np.isclose(np.max(result.dependent_variable_data), 1.0))
True
>>>
>>> # Normalize to custom range
>>> transform = MinMaxNormalize(feature_range=(-1, 1))
>>> result = transform.apply_to(dataset)
Notes
Formula: y_scaled = (y - y_min) / (y_max - y_min) * (max - min) + min
Small epsilon added to denominator to avoid division by zero
JIT-compiled with JAX backend for efficiency
Independent variable is preserved
Metadata is preserved
Methods
__call__(target[, make_copy, ...]) – Shorthand for apply_to.
apply_to(target[, make_copy, ...]) – Apply transform to Dataset.
- class piblin_jax.transform.dataset.normalization.RobustNormalize
Bases: DatasetTransform
Robust normalization using median and IQR.
This transform normalizes data using median and interquartile range (IQR) instead of mean and standard deviation. More robust to outliers than z-score normalization.
Examples
>>> import numpy as np
>>> from piblin_jax.data.datasets import OneDimensionalDataset
>>> from piblin_jax.transform.dataset import RobustNormalize
>>>
>>> # Create data with outliers
>>> x = np.linspace(0, 10, 100)
>>> y = np.random.randn(100)
>>> y[0] = 100  # Outlier
>>> dataset = OneDimensionalDataset(x, y)
>>>
>>> # Robust normalization (less affected by outlier)
>>> transform = RobustNormalize()
>>> result = transform.apply_to(dataset)
Notes
Formula: y_robust = (y - median(y)) / IQR(y)
IQR = Q3 - Q1 (interquartile range)
More robust to outliers than z-score
JIT-compiled with JAX backend
Methods
__call__(target[, make_copy, ...]) – Shorthand for apply_to.
apply_to(target[, make_copy, ...]) – Apply transform to Dataset.
- class piblin_jax.transform.dataset.normalization.ZScoreNormalize
Bases: DatasetTransform
Z-score normalization (standardization).
This transform standardizes data to have zero mean and unit variance. Also known as standardization or z-score transformation. Useful for comparing datasets with different units or scales.
Examples
>>> import numpy as np
>>> from piblin_jax.data.datasets import OneDimensionalDataset
>>> from piblin_jax.transform.dataset import ZScoreNormalize
>>>
>>> # Create data with arbitrary mean and std
>>> x = np.linspace(0, 10, 100)
>>> y = 5.0 * np.random.randn(100) + 10.0  # mean=10, std=5
>>> dataset = OneDimensionalDataset(
...     independent_variable_data=x,
...     dependent_variable_data=y
... )
>>>
>>> # Standardize to mean ~0, std ~1 (approximate because of rounding and epsilon)
>>> transform = ZScoreNormalize()
>>> result = transform.apply_to(dataset)
>>> bool(np.isclose(np.mean(result.dependent_variable_data), 0.0, atol=1e-6))
True
>>> bool(np.isclose(np.std(result.dependent_variable_data), 1.0, atol=1e-6))
True
Notes
Formula: y_zscore = (y - mean(y)) / std(y)
Small epsilon added to denominator to avoid division by zero
Results have mean ≈ 0 and standard deviation ≈ 1
JIT-compiled with JAX backend
Preserves shape of distribution
Sensitive to outliers (unlike robust scaling)
Methods
__call__(target[, make_copy, ...]) – Shorthand for apply_to.
apply_to(target[, make_copy, ...]) – Apply transform to Dataset.
Baseline Correction Transforms
Baseline subtraction transforms for 1D datasets.
This module provides baseline correction transforms to remove systematic offsets and drifts from spectroscopic and chromatographic data.
- class piblin_jax.transform.dataset.baseline.AsymmetricLeastSquaresBaseline(lambda_=100000.0, p=0.01, max_iter=10)
Bases: DatasetTransform
Fit and subtract baseline using Asymmetric Least Squares (ALS) method.
This advanced baseline correction method is particularly effective for data with positive peaks on a varying baseline (e.g., spectra, chromatograms). ALS gives points above the current baseline (positive residuals) a small weight p and points below it a weight 1 - p. With a small p, the peaks barely pull on the fit, so the baseline settles beneath them.
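- Parameters:
lambda_ (float, default 100000.0) – Smoothness penalty; larger values give a smoother, stiffer baseline (typical: 1e4 to 1e7).
p (float, default 0.01) – Asymmetry weight given to points above the baseline (typical: 0.001 to 0.1).
max_iter (int, default 10) – Maximum number of reweighting iterations.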
Examples
>>> import numpy as np
>>> from piblin_jax.data.datasets import OneDimensionalDataset
>>> from piblin_jax.transform.dataset import AsymmetricLeastSquaresBaseline
>>>
>>> # Create spectrum with peaks on curved baseline
>>> x = np.linspace(0, 100, 1000)
>>> baseline = 10 + 0.1 * x + 0.001 * x**2
>>> peaks = 50 * np.exp(-((x - 30)**2) / 20)
>>> peaks += 30 * np.exp(-((x - 70)**2) / 15)
>>> y = baseline + peaks
>>> dataset = OneDimensionalDataset(x, y)
>>>
>>> # Remove baseline using ALS
>>> transform = AsymmetricLeastSquaresBaseline(lambda_=1e6, p=0.01)
>>> result = transform.apply_to(dataset)
Notes
ALS is iterative and may be slower than polynomial baseline
Very effective for spectroscopy and chromatography data
lambda controls smoothness (typical: 1e4 to 1e7)
p controls asymmetry (typical: 0.001 to 0.1)
Based on Eilers & Boelens (2005) paper
References
P. H. C. Eilers and H. F. M. Boelens, “Baseline Correction with Asymmetric Least Squares Smoothing”, Leiden University Medical Centre Report, 2005.
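The reweighting loop from the Eilers & Boelens report can be sketched in plain Python as repeated steps:

1. Solve (W + lambda_ * D^T D) z = W y, where D takes second differences.
2. Give weight p to points above z and 1 - p to points below it.

This toy uses a dense solver, so it is only practical for short signals; a practical implementation would use a sparse banded solver. The function names are hypothetical. The second-difference penalty follows the original paper, and piblin_jax's internals may differ.

```python
import math


def _solve_dense(a, b):
    """Solve a @ x = b by Gaussian elimination with partial pivoting (small n only)."""
    n = len(b)
    m = [row[:] + [b[i]] for i, row in enumerate(a)]  # augmented matrix
    for col in range(n):
        piv = max(range(col, n), key=lambda r: abs(m[r][col]))
        m[col], m[piv] = m[piv], m[col]
        for r in range(col + 1, n):
            f = m[r][col] / m[col][col]
            if f != 0.0:
                for c in range(col, n + 1):
                    m[r][c] -= f * m[col][c]
    x = [0.0] * n
    for r in range(n - 1, -1, -1):
        s = m[r][n] - sum(m[r][c] * x[c] for c in range(r + 1, n))
        x[r] = s / m[r][r]
    return x


def als_baseline(y, lambda_=1e5, p=0.01, max_iter=10):
    """Asymmetric least squares baseline, dense toy version."""
    n = len(y)
    # Accumulate D^T D, where each row of D is a second difference [1, -2, 1]
    dtd = [[0.0] * n for _ in range(n)]
    for k in range(n - 2):
        row = ((k, 1.0), (k + 1, -2.0), (k + 2, 1.0))
        for i, ci in row:
            for j, cj in row:
                dtd[i][j] += ci * cj
    w = [1.0] * n
    z = list(y)
    for _ in range(max_iter):
        # Solve (W + lambda_ * D^T D) z = W y for the current weights
        a = [[lambda_ * dtd[i][j] + (w[i] if i == j else 0.0) for j in range(n)]
             for i in range(n)]
        z = _solve_dense(a, [w[i] * y[i] for i in range(n)])
        # Points above the current baseline (peaks) get the small weight p
        w = [p if y[i] > z[i] else 1.0 - p for i in range(n)]
    return z


# Linear baseline plus one Gaussian peak centred at sample 30
ys = [2 + 0.05 * i + 5 * math.exp(-(((i - 30) / 3) ** 2)) for i in range(60)]
z = als_baseline(ys)
corrected = [yi - zi for yi, zi in zip(ys, z)]
print(round(corrected[30], 2), round(corrected[0], 2))
```

A linear baseline has zero second differences, so it costs nothing under the penalty. The fitted z therefore tracks the drift, while the down-weighted peak survives in the corrected signal.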
Methods
__call__(target[, make_copy, ...]) – Shorthand for apply_to.
apply_to(target[, make_copy, ...]) – Apply transform to Dataset.
- class piblin_jax.transform.dataset.baseline.PolynomialBaseline(degree=1)
Bases: DatasetTransform
Fit and subtract polynomial baseline from data.
This transform fits a polynomial to the data and subtracts it, removing systematic trends and offsets. Commonly used in spectroscopy to remove background signals.
- Parameters:
degree (int, default 1) – Degree of polynomial to fit:
  - 0: Constant offset
  - 1: Linear drift
  - 2: Quadratic curvature
  - Higher: More complex baselines
- Attributes:
degree (int) – Polynomial degree for baseline fitting.
Examples
>>> import numpy as np
>>> from piblin_jax.data.datasets import OneDimensionalDataset
>>> from piblin_jax.transform.dataset import PolynomialBaseline
>>>
>>> # Create data with linear drift
>>> x = np.linspace(0, 10, 100)
>>> signal = np.sin(x)
>>> baseline = 2.0 * x + 5.0  # Linear drift
>>> y = signal + baseline
>>> dataset = OneDimensionalDataset(
...     independent_variable_data=x,
...     dependent_variable_data=y
... )
>>>
>>> # Remove linear baseline
>>> transform = PolynomialBaseline(degree=1)
>>> result = transform.apply_to(dataset)
>>>
>>> # Result is close to the original signal; the line fit also absorbs
>>> # part of sin(x)'s own offset and trend on [0, 10], hence the tolerance
>>> np.allclose(result.dependent_variable_data, signal, atol=0.3)
True
Notes
Uses least-squares polynomial fitting (np.polyfit)
Works for both JAX and NumPy backends (uses NumPy for fitting)
Higher degree polynomials may overfit to noise
For complex baselines, consider spline-based methods
Independent variable (x) is preserved
Methods
__call__(target[, make_copy, ...]) – Shorthand for apply_to.
apply_to(target[, make_copy, ...]) – Apply transform to Dataset.
Calculus Transforms
Calculus-based transforms for 1D datasets.
This module provides derivatives and integration transforms for numerical differentiation and integration of experimental data.
- class piblin_jax.transform.dataset.calculus.CumulativeIntegral(method='trapezoid')
Bases: DatasetTransform
Compute cumulative integral of 1D dataset.
This transform computes the cumulative integral (running sum) of the dependent variable with respect to the independent variable using the trapezoidal rule.
- Parameters:
method (str, default 'trapezoid') – Integration method:
  - ‘trapezoid’: Trapezoidal rule (2nd-order accurate)
  - ‘simpson’: Simpson’s rule (4th-order accurate, requires an odd number of points)
- Attributes:
method (str) – Integration method.
Examples
>>> import numpy as np
>>> from piblin_jax.data.datasets import OneDimensionalDataset
>>> from piblin_jax.transform.dataset import CumulativeIntegral
>>>
>>> # Create constant function (integral should be linear)
>>> x = np.linspace(0, 10, 100)
>>> y = np.ones_like(x)  # Integral of 1 is x
>>> dataset = OneDimensionalDataset(
...     independent_variable_data=x,
...     dependent_variable_data=y
... )
>>>
>>> # Compute cumulative integral
>>> integral = CumulativeIntegral()
>>> result = integral.apply_to(dataset)
>>> # Result should be approximately linear (x); the last value is ~10
>>> round(float(result.dependent_variable_data[-1]), 4)
10.0
Notes
Trapezoidal rule: I[i] = sum over k = 1..i of (y[k] + y[k-1]) / 2 * (x[k] - x[k-1])
First value is always 0 (integral from x[0] to x[0])
JIT-compiled with JAX backend
Handles non-uniform spacing
For smoother results on noisy data, consider smoothing first
Methods
__call__(target[, make_copy, ...]) – Shorthand for apply_to.
apply_to(target[, make_copy, ...]) – Apply transform to Dataset.
- class piblin_jax.transform.dataset.calculus.DefiniteIntegral(x_min=None, x_max=None, method='trapezoid')
Bases: DatasetTransform
Compute definite integral over specified region.
This transform computes the definite integral (total area) under the curve between specified x-values.
- Parameters:
- Attributes:
method (str) – Integration method.
Examples
>>> import numpy as np
>>> from piblin_jax.data.datasets import OneDimensionalDataset
>>> from piblin_jax.transform.dataset import DefiniteIntegral
>>>
>>> # Create data
>>> x = np.linspace(0, np.pi, 100)
>>> y = np.sin(x)  # Integral from 0 to pi is 2
>>> dataset = OneDimensionalDataset(x, y)
>>>
>>> # Compute definite integral
>>> integral = DefiniteIntegral()
>>> result = integral.apply_to(dataset)
>>> # Result stores integral value in metadata
Notes
Returns dataset with integral value stored in details
Original data is preserved
For cumulative integral, use CumulativeIntegral instead
Methods
__call__(target[, make_copy, ...]) – Shorthand for apply_to.
apply_to(target[, make_copy, ...]) – Apply transform to Dataset.
- class piblin_jax.transform.dataset.calculus.Derivative(order=1, method='gradient')
Bases: DatasetTransform
Compute numerical derivative of 1D dataset.
This transform computes numerical derivatives using finite differences. Supports first and second derivatives with various accuracy schemes.
- Parameters:
order (int, default 1) – Derivative order (1 or 2):
  - 1: First derivative (dy/dx)
  - 2: Second derivative (d²y/dx²)
method (str, default 'gradient') – Method for computing the derivative:
  - ‘gradient’: Central differences (2nd-order accurate)
  - ‘forward’: Forward differences (1st-order accurate)
  - ‘backward’: Backward differences (1st-order accurate)
- Attributes:
order (int) – Derivative order.
method (str) – Differentiation method.
- Raises:
ValueError – If order is not 1 or 2.
Examples
>>> import numpy as np
>>> from piblin_jax.data.datasets import OneDimensionalDataset
>>> from piblin_jax.transform.dataset import Derivative
>>>
>>> # Create data with known derivative
>>> x = np.linspace(0, 10, 100)
>>> y = x**2  # dy/dx = 2x
>>> dataset = OneDimensionalDataset(
...     independent_variable_data=x,
...     dependent_variable_data=y
... )
>>>
>>> # Compute first derivative
>>> deriv = Derivative(order=1)
>>> result = deriv.apply_to(dataset)
>>> # Result should be approximately 2*x
>>>
>>> # Compute second derivative
>>> deriv2 = Derivative(order=2)
>>> result2 = deriv2.apply_to(dataset)
>>> # Result should be approximately 2 (constant)
Notes
Uses jnp.gradient for central differences
Gradient method provides 2nd order accuracy
Handles non-uniform spacing in x
JIT-compiled with JAX backend
Edge effects present at boundaries
For noisy data, consider smoothing before differentiation
Methods
__call__(target[, make_copy, ...]) – Shorthand for apply_to.
apply_to(target[, make_copy, ...]) – Apply transform to Dataset.
- __init__(order=1, method='gradient')
Initialize derivative transform.
- Parameters:
- Raises:
ValueError – If order is not 1 or 2.
Measurement Transforms
Overview
Measurement-level transforms for piblin-jax.
This module provides transforms that operate on Measurement and MeasurementSet objects:
- FilterDatasets: Filter datasets within a Measurement
- FilterMeasurements: Filter measurements within a MeasurementSet
- SplitByRegion: Split datasets by regions
- MergeReplicates: Merge measurements with identical conditions
- class piblin_jax.transform.measurement.FilterDatasets(predicate=None, dataset_type=None)
Bases: MeasurementTransform
Filter datasets within a Measurement.
This transform filters datasets based on either their type or a custom predicate function. Returns a new Measurement containing only the datasets that match the filter criteria.
- Parameters:
predicate (Callable[[Dataset], bool] | None) – Function that returns True for datasets to keep. Cannot be used with dataset_type.
dataset_type (type | None) – Filter by dataset type (alternative to predicate). Cannot be used with predicate.
- Raises:
ValueError – If neither predicate nor dataset_type is provided, or if both are provided.
Examples
>>> from piblin_jax.data.datasets import OneDimensionalDataset
>>> from piblin_jax.transform.measurement import FilterDatasets
>>>
>>> # Filter by type
>>> transform = FilterDatasets(dataset_type=OneDimensionalDataset)
>>> result = transform.apply_to(measurement)
>>>
>>> # Filter by predicate
>>> transform = FilterDatasets(
...     predicate=lambda ds: ds.conditions.get('temp', 0) > 25
... )
>>> result = transform.apply_to(measurement)
Notes
Returns a new Measurement with filtered datasets
Preserves measurement-level conditions and details
Empty dataset list is allowed if no datasets match
Methods
__call__(target[, make_copy, ...]) – Shorthand for apply_to.
apply_to(target[, make_copy, ...]) – Apply transform to Measurement.
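The rules for FilterDatasets (exactly one of predicate or dataset_type, a new Measurement returned, an empty result allowed) can be summarised in a plain-Python sketch. SketchDataset, SketchOneDimDataset, SketchMeasurement and filter_datasets are hypothetical stand-ins for illustration; the real transform operates on piblin_jax Measurement objects.

```python
from dataclasses import dataclass, field, replace
from typing import Callable, Optional


@dataclass
class SketchDataset:
    conditions: dict = field(default_factory=dict)


@dataclass
class SketchOneDimDataset(SketchDataset):
    """Stands in for a specific dataset type such as OneDimensionalDataset."""


@dataclass
class SketchMeasurement:
    datasets: list
    conditions: dict = field(default_factory=dict)


def filter_datasets(measurement: SketchMeasurement,
                    predicate: Optional[Callable[[SketchDataset], bool]] = None,
                    dataset_type: Optional[type] = None) -> SketchMeasurement:
    # Exactly one selection criterion is accepted
    if predicate is None and dataset_type is None:
        raise ValueError("provide either predicate or dataset_type")
    if predicate is not None and dataset_type is not None:
        raise ValueError("predicate and dataset_type cannot be combined")

    def keep(ds):
        return predicate(ds) if predicate is not None else isinstance(ds, dataset_type)

    # replace() builds a new measurement; its conditions field is carried over
    return replace(measurement,
                   datasets=[ds for ds in measurement.datasets if keep(ds)])


m = SketchMeasurement(
    datasets=[SketchOneDimDataset({"temp": 30}), SketchDataset({"temp": 20})],
    conditions={"sample": "A"},
)
print(len(filter_datasets(m, dataset_type=SketchOneDimDataset).datasets))  # 1
print(filter_datasets(m, predicate=lambda ds: ds.conditions.get("temp", 0) > 40).datasets)  # []
```

Note that an isinstance check also matches subclasses, so filtering by a base dataset type would keep every derived type as well.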
- class piblin_jax.transform.measurement.FilterMeasurements(predicate)
Bases: MeasurementSetTransform
Filter measurements within a MeasurementSet.
This transform filters measurements based on a predicate function. Returns a new MeasurementSet containing only the measurements that match the filter criteria.
- Parameters:
predicate (Callable[[Measurement], bool]) – Function that returns True for measurements to keep.
- Raises:
TypeError – If predicate is not callable.
Examples
>>> from piblin_jax.transform.measurement import FilterMeasurements
>>>
>>> # Filter by condition
>>> transform = FilterMeasurements(
...     predicate=lambda m: m.conditions.get('temp', 0) > 25
... )
>>> result = transform.apply_to(measurement_set)
>>>
>>> # Filter by replicate number
>>> transform = FilterMeasurements(
...     predicate=lambda m: m.conditions.get('replicate', 0) <= 3
... )
>>> result = transform.apply_to(measurement_set)
Notes
Returns a new MeasurementSet with filtered measurements
Preserves measurement-set-level conditions and details
Empty measurement list is allowed if no measurements match
Methods
__call__(target[, make_copy, ...]) – Shorthand for apply_to.
apply_to(target[, make_copy, ...]) – Apply transform to MeasurementSet.
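Because the predicate receives a whole measurement, the two example filters above can be combined into a single callable. The sketch below mimics the documented contract (TypeError for a non-callable predicate, set-level conditions preserved, empty result allowed) using plain dictionaries in place of Measurement objects, so conditions are read with m["conditions"] rather than m.conditions; filter_measurements and all_of are hypothetical helpers.

```python
from typing import Callable


def filter_measurements(measurement_set: dict, predicate: Callable) -> dict:
    if not callable(predicate):
        raise TypeError("predicate must be callable")
    # New container: same set-level conditions, only the matching measurements
    return {
        "conditions": measurement_set["conditions"],
        "measurements": [m for m in measurement_set["measurements"] if predicate(m)],
    }


def all_of(*predicates: Callable) -> Callable:
    """Combine predicates so that a measurement must satisfy every one."""
    return lambda m: all(p(m) for p in predicates)


ms = {
    "conditions": {"operator": "X"},
    "measurements": [
        {"conditions": {"temp": 30, "replicate": 1}},
        {"conditions": {"temp": 30, "replicate": 4}},
        {"conditions": {"temp": 20, "replicate": 2}},
    ],
}
hot_and_early = all_of(
    lambda m: m["conditions"].get("temp", 0) > 25,
    lambda m: m["conditions"].get("replicate", 0) <= 3,
)
print(filter_measurements(ms, hot_and_early)["measurements"])
# [{'conditions': {'temp': 30, 'replicate': 1}}]
```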
- class piblin_jax.transform.measurement.MergeReplicates(strategy='average')
Bases: MeasurementSetTransform
Merge measurements with identical conditions.
This transform groups measurements by their conditions and merges replicates (measurements with identical conditions) using either averaging or concatenation.
- Parameters:
strategy (str, default ‘average’) – Merge strategy: ‘average’ or ‘concatenate’.
- Raises:
ValueError – If strategy is not ‘average’ or ‘concatenate’.
Examples
>>> from piblin_jax.transform.measurement import MergeReplicates
>>>
>>> # Average replicate measurements
>>> transform = MergeReplicates(strategy='average')
>>> result = transform.apply_to(measurement_set)
Notes
Groups measurements by conditions (all key-value pairs must match)
For ‘average’ strategy:
- Averages dependent variable data across replicates
- Assumes all replicates have the same independent variable data
- Assumes all replicates have the same dataset structure
For ‘concatenate’ strategy:
- Not yet implemented; currently returns the first measurement instead of concatenating
Preserves measurement-set-level conditions and details
Single measurements (no replicates) are returned unchanged
Methods
__call__(target[, make_copy, ...]) – Shorthand for apply_to.
apply_to(target[, make_copy, ...]) – Apply transform to MeasurementSet.
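The ‘average’ strategy can be pictured as two steps: group measurements whose conditions match on every key-value pair, then average the dependent values point by point. The sketch below works on plain dictionaries and assumes each measurement holds a single y list on a shared x grid, whereas the real transform handles full dataset structures; merge_replicates is a hypothetical name.

```python
from collections import defaultdict


def merge_replicates(measurements: list) -> list:
    # Key on the full set of condition items: all pairs must match to be replicates.
    # Assumes condition values are hashable.
    groups = defaultdict(list)
    for m in measurements:
        groups[frozenset(m["conditions"].items())].append(m)

    merged = []
    for members in groups.values():
        if len(members) == 1:
            merged.append(members[0])  # no replicates: returned unchanged
            continue
        ys = [m["y"] for m in members]
        # Point-by-point mean; assumes every replicate shares the same x grid
        mean_y = [sum(vals) / len(vals) for vals in zip(*ys)]
        merged.append({"conditions": dict(members[0]["conditions"]),
                       "x": members[0]["x"], "y": mean_y})
    return merged


data = [
    {"conditions": {"temp": 25}, "x": [0, 1], "y": [1.0, 3.0]},
    {"conditions": {"temp": 25}, "x": [0, 1], "y": [3.0, 5.0]},
    {"conditions": {"temp": 40}, "x": [0, 1], "y": [7.0, 9.0]},
]
merged = merge_replicates(data)
print([m["y"] for m in merged])  # [[2.0, 4.0], [7.0, 9.0]]
```

Under an all-pairs-must-match rule, a per-replicate key such as 'replicate' stored in the conditions would keep replicates in separate groups, so such a key has to be absent for them to merge.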
- class piblin_jax.transform.measurement.SplitByRegion(regions)
Bases: MeasurementTransform
Split datasets by multiple regions, creating a new Measurement.
This transform splits each OneDimensionalDataset in a Measurement into multiple datasets based on the specified regions. Each region creates a new dataset containing only the data points within that region.
- Parameters:
regions (list[LinearRegion]) – List of regions to split by.
- Raises:
ValueError – If the regions list is empty.
Examples
>>> from piblin_jax.data.roi import LinearRegion
>>> from piblin_jax.transform.measurement import SplitByRegion
>>>
>>> regions = [
...     LinearRegion(0, 5),
...     LinearRegion(5, 10)
... ]
>>> transform = SplitByRegion(regions)
>>> result = transform.apply_to(measurement)
Notes
Only processes OneDimensionalDataset objects
Non-1D datasets are silently skipped
Empty regions (no data points) are included as empty datasets
Preserves dataset-level conditions and details for each split
Preserves measurement-level conditions and details
Methods
__call__(target[, make_copy, ...]) – Shorthand for apply_to.
apply_to(target[, make_copy, ...]) – Apply transform to Measurement.
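Splitting one dataset by several regions amounts to masking the independent variable once per region. This sketch assumes each region is a closed interval [x_min, x_max], so a point lying exactly on a shared boundary (x = 5 with the regions in the example above) lands in both pieces; whether LinearRegion uses open or closed bounds is not stated here. split_by_regions is a hypothetical helper that takes (x_min, x_max) tuples in place of LinearRegion objects.

```python
def split_by_regions(x, y, regions):
    if not regions:
        raise ValueError("regions list is empty")
    pieces = []
    for x_min, x_max in regions:
        # Assumption: closed interval, both bounds inclusive
        idx = [i for i, xi in enumerate(x) if x_min <= xi <= x_max]
        # A region with no points still yields an (empty) piece, as in the notes above
        pieces.append(([x[i] for i in idx], [y[i] for i in idx]))
    return pieces


x = [0, 2, 4, 5, 6, 8, 10]
y = [v * 10 for v in x]
for xs, ys in split_by_regions(x, y, [(0, 5), (5, 10), (20, 30)]):
    print(xs, ys)
# [0, 2, 4, 5] [0, 20, 40, 50]
# [5, 6, 8, 10] [50, 60, 80, 100]
# [] []
```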
Filtering Transforms
Collection-level transforms for Measurements and MeasurementSets.
This module provides transforms that operate on collections:
- FilterDatasets: Filter datasets within a Measurement
- FilterMeasurements: Filter measurements within a MeasurementSet
- SplitByRegion: Split datasets by regions
- MergeReplicates: Merge measurements with identical conditions
- class piblin_jax.transform.measurement.filter.FilterDatasets(predicate=None, dataset_type=None)
Bases: MeasurementTransform
Filter datasets within a Measurement.
- __init__(predicate=None, dataset_type=None)
Initialize FilterDatasets transform.
- Parameters:
predicate (Callable[[Dataset], bool] | None) – Function that returns True for datasets to keep.
dataset_type (type | None) – Filter by dataset type.
- class piblin_jax.transform.measurement.filter.FilterMeasurements(predicate)
Bases: MeasurementSetTransform
Filter measurements within a MeasurementSet.
- __init__(predicate)
Initialize FilterMeasurements transform.
- Parameters:
predicate (Callable[[Measurement], bool]) – Function that returns True for measurements to keep.
- class piblin_jax.transform.measurement.filter.MergeReplicates(strategy='average')
Bases: MeasurementSetTransform
Merge measurements with identical conditions.
- class piblin_jax.transform.measurement.filter.SplitByRegion(regions)
Bases: MeasurementTransform
Split datasets by multiple regions, creating a new Measurement.
- __init__(regions)
Initialize SplitByRegion transform.
- Parameters:
regions (list[LinearRegion]) – List of regions to split by.