Source code for piblin_jax.transform.pipeline

"""
Pipeline composition for transforms.

This module provides pipeline functionality for composing multiple transforms:
- Pipeline: Sequential composition of transforms
- LazyPipeline: Pipeline with lazy evaluation support

Pipelines support:
- MutableSequence interface (list-like operations)
- Sequential transform application
- Single copy at entry (memory efficient)
- Lazy evaluation (computation deferred)
- JIT compilation (entire pipeline)
"""

from collections.abc import Iterable, MutableSequence
from typing import Any, TypeVar, overload

from .base import Transform

T = TypeVar("T")


class Pipeline(Transform[T], MutableSequence[Transform[T]]):
    """
    Pipeline for composing multiple transforms sequentially.

    A pipeline applies its transforms to data in order, feeding the output
    of each transform into the next one. It implements the
    ``MutableSequence`` interface, so it can be manipulated like a list of
    transforms (indexing, slicing, ``append``, ``insert``, ``del``, ``len``).

    When ``apply_to`` is called with ``make_copy=True`` the pipeline copies
    the target once at entry and then runs every contained transform with
    ``make_copy=False``.

    Parameters
    ----------
    transforms : list[Transform], optional
        Initial transforms, in application order. Every entry must be a
        ``Transform`` instance.

    Raises
    ------
    TypeError
        If any initial entry is not a ``Transform`` instance.

    Examples
    --------
    >>> from piblin_jax.transform import Pipeline, DatasetTransform
    >>> from piblin_jax.data.datasets import OneDimensionalDataset
    >>> import numpy as np
    >>>
    >>> # Create transforms
    >>> class MultiplyTransform(DatasetTransform):
    ...     def __init__(self, factor):
    ...         super().__init__()
    ...         self.factor = factor
    ...
    ...     def _apply(self, dataset):
    ...         dataset.y_data = dataset.y_data * self.factor
    ...         return dataset
    >>>
    >>> # Create pipeline
    >>> pipeline = Pipeline([
    ...     MultiplyTransform(2.0),
    ...     MultiplyTransform(3.0),  # Net effect: 6x
    ... ])
    >>>
    >>> # Apply to dataset
    >>> dataset = OneDimensionalDataset(
    ...     x_data=np.array([1, 2, 3]),
    ...     y_data=np.array([2, 4, 6])
    ... )
    >>> result = pipeline.apply_to(dataset, make_copy=True)
    >>> # result.y_data is now [12, 24, 36]

    Notes
    -----
    - A ``Pipeline`` is itself a ``Transform``, so pipelines can be nested.
    - Only one copy is made at entry; the contained transforms are then
      invoked with ``make_copy=False``.
    """

    def __init__(self, transforms: list[Transform[T]] | None = None):
        """
        Initialize pipeline.

        Parameters
        ----------
        transforms : list[Transform], optional
            Initial transforms to include in pipeline. ``None`` or an empty
            collection yields an empty pipeline.

        Raises
        ------
        TypeError
            If any entry is not a ``Transform`` instance.
        """
        super().__init__()
        initial: list[Transform[T]] = list(transforms) if transforms else []
        # Enforce the same invariant as __setitem__/insert/append so that a
        # bad entry is rejected here rather than failing later in _apply.
        for candidate in initial:
            self._require_transform(candidate)
        self._transforms: list[Transform[T]] = initial
        self._lazy = False  # Standard pipeline is eager by default

    @staticmethod
    def _require_transform(candidate: object) -> None:
        """Raise ``TypeError`` unless *candidate* is a ``Transform`` instance."""
        if not isinstance(candidate, Transform):
            raise TypeError("Pipeline can only contain Transform objects")

    def _apply(self, target: T, propagate_uncertainty: bool = False) -> T:
        """
        Apply all transforms in sequence.

        Each transform receives the value returned by the previous one and
        is invoked with ``make_copy=False``; any copying is the
        responsibility of ``Pipeline.apply_to``.

        Parameters
        ----------
        target : T
            Data structure to transform
        propagate_uncertainty : bool, default=False
            Forwarded unchanged to every contained transform

        Returns
        -------
        T
            Value returned by the last transform, or ``target`` itself when
            the pipeline is empty
        """
        current = target
        for step in self._transforms:
            current = step.apply_to(
                current,
                make_copy=False,
                propagate_uncertainty=propagate_uncertainty,
            )
        return current

    def apply_to(
        self, target: T, make_copy: bool = True, propagate_uncertainty: bool = False
    ) -> T:
        """
        Apply pipeline to target.

        Parameters
        ----------
        target : T
            Data structure to transform
        make_copy : bool, default=True
            If True, copy ``target`` once (via ``self._copy_tree``,
            presumably inherited from ``Transform``) before any transform
            runs
        propagate_uncertainty : bool, default=False
            Forwarded to every contained transform

        Returns
        -------
        T
            Transformed data structure
        """
        if make_copy:
            # Single copy at entry; the steps themselves never copy.
            target = self._copy_tree(target)
        return self._apply(target, propagate_uncertainty=propagate_uncertainty)

    # ------------------------------------------------------------------
    # MutableSequence interface: lets a Pipeline be used like a list
    # ------------------------------------------------------------------

    @overload
    def __getitem__(self, index: int) -> Transform[T]: ...

    @overload
    def __getitem__(self, index: slice) -> list[Transform[T]]: ...

    def __getitem__(self, index: int | slice) -> Transform[T] | list[Transform[T]]:
        """
        Get transform(s) at index.

        Parameters
        ----------
        index : int or slice
            Index or slice to retrieve

        Returns
        -------
        Transform or list[Transform]
            Transform at an integer index, or a plain ``list`` of transforms
            (not a ``Pipeline``) for a slice

        Examples
        --------
        >>> pipeline = Pipeline([t1, t2, t3])
        >>> pipeline[0]  # Get first transform
        >>> pipeline[1:3]  # Get slice of transforms
        """
        # list.__getitem__ already handles both the int and the slice case.
        return self._transforms[index]

    @overload
    def __setitem__(self, index: int, value: Transform[T]) -> None: ...

    @overload
    def __setitem__(self, index: slice, value: Iterable[Transform[T]]) -> None: ...

    def __setitem__(
        self, index: int | slice, value: Transform[T] | Iterable[Transform[T]]
    ) -> None:
        """
        Set transform(s) at index.

        Parameters
        ----------
        index : int or slice
            Index or slice to set
        value : Transform or iterable of Transform
            Replacement transform(s). For a slice, a bare ``Transform``
            (including a nested ``Pipeline``) is treated as a single
            element rather than being iterated.

        Raises
        ------
        TypeError
            If any value is not a Transform instance

        Examples
        --------
        >>> pipeline = Pipeline([t1, t2, t3])
        >>> pipeline[0] = new_transform  # Replace first transform
        """
        if not isinstance(index, slice):
            self._require_transform(value)
            self._transforms[index] = value
            return

        # Check for Transform first: a Pipeline is both a Transform and an
        # iterable, and must be inserted as one element here.
        replacements = [value] if isinstance(value, Transform) else list(value)
        for candidate in replacements:
            self._require_transform(candidate)
        self._transforms[index] = replacements

    def __delitem__(self, index: int | slice) -> None:
        """
        Delete transform(s) at index.

        Parameters
        ----------
        index : int or slice
            Index or slice to delete

        Examples
        --------
        >>> pipeline = Pipeline([t1, t2, t3])
        >>> del pipeline[0]  # Remove first transform
        >>> del pipeline[1:]  # Remove all but first transform
        """
        del self._transforms[index]

    def __len__(self) -> int:
        """
        Get number of transforms in pipeline.

        Returns
        -------
        int
            Number of transforms

        Examples
        --------
        >>> pipeline = Pipeline([t1, t2, t3])
        >>> len(pipeline)
        3
        """
        return len(self._transforms)

    def insert(self, index: int, value: Transform[T]) -> None:
        """
        Insert transform at index.

        Parameters
        ----------
        index : int
            Index at which to insert transform
        value : Transform
            Transform to insert

        Raises
        ------
        TypeError
            If value is not a Transform instance

        Examples
        --------
        >>> pipeline = Pipeline([t1, t3])
        >>> pipeline.insert(1, t2)  # Insert t2 between t1 and t3
        """
        self._require_transform(value)
        self._transforms.insert(index, value)

    def append(self, transform: Transform[T]) -> None:
        """
        Add transform to end of pipeline.

        Parameters
        ----------
        transform : Transform
            Transform to append

        Raises
        ------
        TypeError
            If transform is not a Transform instance

        Examples
        --------
        >>> pipeline = Pipeline([t1, t2])
        >>> pipeline.append(t3)  # Add t3 to end
        """
        self._require_transform(transform)
        self._transforms.append(transform)

    def __repr__(self) -> str:
        """
        String representation of pipeline.

        Returns
        -------
        str
            String representation showing number of transforms

        Examples
        --------
        >>> pipeline = Pipeline([t1, t2, t3])
        >>> repr(pipeline)
        'Pipeline(3 transforms)'
        """
        return f"Pipeline({len(self._transforms)} transforms)"

    def __str__(self) -> str:
        """
        Human-readable string representation.

        Returns
        -------
        str
            ``"Pipeline(empty)"`` for an empty pipeline, otherwise a
            ``"Pipeline:"`` header followed by one numbered line per
            transform giving its class name
        """
        if not self._transforms:
            return "Pipeline(empty)"
        entries = [
            f" {i}. {transform.__class__.__name__}"
            for i, transform in enumerate(self._transforms)
        ]
        return "\n".join(["Pipeline:", *entries])
class LazyPipeline(Pipeline[T]):
    """
    Pipeline with lazy evaluation support.

    Unlike the standard ``Pipeline``, ``apply_to`` does not run any
    transform. It records the target on the pipeline and returns a
    ``LazyResult`` wrapper; the transforms run when ``_compute`` is first
    called, and the result is cached on the pipeline until it is marked
    dirty again.

    Parameters
    ----------
    transforms : list[Transform], optional
        Initial list of transforms to include in pipeline

    Examples
    --------
    >>> from piblin_jax.transform import LazyPipeline
    >>>
    >>> # Create lazy pipeline
    >>> pipeline = LazyPipeline([
    ...     MultiplyTransform(2.0),
    ...     MultiplyTransform(3.0),
    ... ])
    >>>
    >>> # Apply to dataset (computation deferred)
    >>> lazy_result = pipeline.apply_to(dataset, make_copy=True)
    >>>
    >>> # Access property (triggers computation)
    >>> y_values = lazy_result.y_data  # Computation happens here

    Notes
    -----
    - The pending target, the uncertainty flag and the cached result are
      stored on the pipeline itself, so the pipeline tracks only its most
      recent ``apply_to`` call.
    - NOTE(review): the previous documentation stated that deferral lets
      JAX optimise the whole pipeline as one graph; nothing in this class
      performs compilation itself - confirm where that happens.
    """

    def __init__(self, transforms: list[Transform[T]] | None = None):
        """
        Initialize lazy pipeline.

        Parameters
        ----------
        transforms : list[Transform], optional
            Initial transforms to include in pipeline
        """
        super().__init__(transforms)
        self._lazy = True  # Lazy pipelines defer computation
        self._target: T | None = None  # Most recent target passed to apply_to
        self._result_cache: T | None = None  # Result of the last computation
        self._dirty = True  # True while the cache does not reflect _target
        self._propagate_unc = False  # Uncertainty flag of the pending run

    def apply_to(
        self, target: T, make_copy: bool = True, propagate_uncertainty: bool = False
    ) -> Any:
        """
        Apply lazy pipeline to target.

        No transform is executed here. The (optionally copied) target and
        the uncertainty flag are stored on the pipeline and a ``LazyResult``
        bound to this pipeline is returned.

        Parameters
        ----------
        target : T
            Data structure to transform
        make_copy : bool, default=True
            If True, copy ``target`` via ``self._copy_tree`` before storing
        propagate_uncertainty : bool, default=False
            Forwarded to the transforms when the computation finally runs

        Returns
        -------
        LazyResult
            Wrapper that triggers computation on access
        """
        if make_copy:
            target = self._copy_tree(target)

        self._target = target
        self._propagate_unc = propagate_uncertainty
        # Drop the previous run's result as well as marking the pipeline
        # dirty: _compute() skips recomputation when the target is None and
        # would otherwise hand back the stale result of an earlier run.
        self._result_cache = None
        self._dirty = True

        return LazyResult(self)

    def _compute(self) -> T | None:
        """
        Execute the pipeline computation.

        Runs the transforms on the stored target if the pipeline is dirty
        and a target is present, caches the outcome and clears the dirty
        flag. Otherwise the cached value is returned as is.

        Returns
        -------
        T | None
            Cached result; ``None`` if nothing has been computed
        """
        if self._dirty and self._target is not None:
            self._result_cache = self._apply(
                self._target, propagate_uncertainty=self._propagate_unc
            )
            self._dirty = False
        return self._result_cache

    def invalidate_cache(self) -> None:
        """
        Invalidate cached results.

        Marks the pipeline dirty and clears the cached result so that the
        next ``_compute`` call runs the transforms on the stored target
        again.

        Examples
        --------
        >>> pipeline = LazyPipeline([transform1, transform2])
        >>> result = pipeline.apply_to(dataset)
        >>> _ = result.y_data  # Triggers computation
        >>>
        >>> # Modify pipeline
        >>> pipeline.append(transform3)
        >>> pipeline.invalidate_cache()  # Force recomputation
        """
        self._dirty = True
        self._result_cache = None
class LazyResult:
    """
    Wrapper that triggers lazy computation on attribute access.

    The wrapper forwards attribute reads and writes to the object produced
    by the pipeline, computing that object on first use and caching it on
    the wrapper afterwards.

    At construction the wrapper records which target the pipeline is
    currently pointed at (``pipeline._target``) together with its
    uncertainty flag (``pipeline._propagate_unc``). If the pipeline has
    been applied to a different target by the time this result is first
    accessed, the recorded target is evaluated directly, so each result
    always reflects the target it was created for.

    Parameters
    ----------
    pipeline : LazyPipeline
        The lazy pipeline that will compute the result

    Examples
    --------
    >>> lazy_result = LazyResult(pipeline)
    >>> # No computation yet
    >>> y = lazy_result.y_data  # Triggers computation here
    >>> # Subsequent accesses use cached result
    >>> x = lazy_result.x_data  # No recomputation

    Notes
    -----
    A computed value of ``None`` is indistinguishable from "not computed
    yet", so such a result is re-requested from the pipeline on every
    access.
    """

    def __init__(self, pipeline: "LazyPipeline[Any]"):
        """
        Initialize lazy result wrapper.

        Parameters
        ----------
        pipeline : LazyPipeline
            Pipeline that will compute the result
        """
        # Use object.__setattr__ so the forwarding __setattr__ below is
        # bypassed for the wrapper's own bookkeeping.
        object.__setattr__(self, "_pipeline", pipeline)
        object.__setattr__(self, "_computed", None)
        # Snapshot what this result belongs to; the pipeline keeps only its
        # most recent target and may be re-applied before we are accessed.
        object.__setattr__(self, "_bound_target", getattr(pipeline, "_target", None))
        object.__setattr__(
            self, "_bound_propagate", getattr(pipeline, "_propagate_unc", False)
        )

    def __materialize(self) -> Any:
        """Return the computed object, running the pipeline if necessary."""
        computed = object.__getattribute__(self, "_computed")
        if computed is not None:
            return computed

        pipeline = object.__getattribute__(self, "_pipeline")
        target = object.__getattribute__(self, "_bound_target")
        if target is None or getattr(pipeline, "_target", None) is target:
            # The pipeline is still pointed at our target (or exposes none):
            # let it compute and cache as usual.
            computed = pipeline._compute()
        else:
            # The pipeline has moved on to another target; evaluate ours
            # directly and leave the pipeline's own cache untouched.
            computed = pipeline._apply(
                target,
                propagate_uncertainty=object.__getattribute__(
                    self, "_bound_propagate"
                ),
            )
        object.__setattr__(self, "_computed", computed)
        return computed

    def __getattr__(self, name: str) -> Any:
        """
        Get attribute from computed result.

        Only invoked when normal lookup on the wrapper fails; triggers
        computation on first access.

        Parameters
        ----------
        name : str
            Attribute name

        Returns
        -------
        Any
            Attribute value from computed result
        """
        return getattr(self.__materialize(), name)

    def __setattr__(self, name: str, value: Any) -> None:
        """
        Set attribute on computed result.

        The wrapper's own bookkeeping attributes are stored on the wrapper;
        every other name triggers computation if needed and is set on the
        computed object.

        Parameters
        ----------
        name : str
            Attribute name
        value : Any
            Attribute value
        """
        if name in ("_pipeline", "_computed", "_bound_target", "_bound_propagate"):
            object.__setattr__(self, name, value)
        else:
            setattr(self.__materialize(), name, value)

    def __repr__(self) -> str:
        """String representation; reports whether a result is cached yet."""
        return f"LazyResult(computed={object.__getattribute__(self, '_computed') is not None})"
# Public names exported by this module.
__all__ = [
    "LazyPipeline",
    "LazyResult",
    "Pipeline",
]