# Source code for piblin_jax.transform.region

"""
Region-based transforms for piblin-jax.

This module provides transforms that operate on specific regions of data:
- RegionTransform: Base class for region-based transforms
- RegionMultiplyTransform: Example concrete implementation

Region-based transforms apply transformations only within specified regions,
preserving data outside those regions. This is useful for selective processing
such as background subtraction in specific spectral ranges or local smoothing.
"""

from typing import Union

import numpy as np

from piblin_jax.backend import from_numpy, jnp, to_numpy
from piblin_jax.data.datasets import OneDimensionalDataset
from piblin_jax.data.roi import CompoundRegion, LinearRegion
from piblin_jax.transform.base import DatasetTransform


class RegionTransform(DatasetTransform):
    """
    Base class for transforms that operate on specific regions.

    RegionTransform applies a transformation only within specified region(s),
    preserving data outside the regions. This enables selective processing
    of data based on independent variable ranges.

    Subclasses should implement the _apply_to_region() method to define
    the specific transformation to apply within the region(s).

    Parameters
    ----------
    region : LinearRegion | CompoundRegion
        Region(s) to transform

    Examples
    --------
    >>> import numpy as np
    >>> from piblin_jax.data.datasets import OneDimensionalDataset
    >>> from piblin_jax.data.roi import LinearRegion
    >>> from piblin_jax.transform.region import RegionMultiplyTransform
    >>> # Create dataset
    >>> x_data = np.array([0, 1, 2, 3, 4, 5])
    >>> y_data = np.array([1, 1, 1, 1, 1, 1])
    >>> dataset = OneDimensionalDataset(
    ...     independent_variable_data=x_data,
    ...     dependent_variable_data=y_data
    ... )
    >>> # Define region and transform
    >>> region = LinearRegion(x_min=2.0, x_max=4.0)
    >>> transform = RegionMultiplyTransform(region, factor=2.0)
    >>> # Apply transform (only region [2, 4] is multiplied)
    >>> result = transform.apply_to(dataset, make_copy=True)
    >>> result.dependent_variable_data
    array([1., 1., 2., 2., 2., 1.])

    Notes
    -----
    - Currently optimized for OneDimensionalDataset
    - Values outside regions are preserved; the output dtype is the NumPy
      promotion of the input dtype and the transformed region's dtype, so
      e.g. integer data scaled by a float factor is not truncated
    - Transformations use NumPy arrays internally for compatibility
    - Region masks are generated from the independent variable
    """

    def __init__(self, region: LinearRegion | CompoundRegion):
        """
        Initialize RegionTransform.

        Parameters
        ----------
        region : LinearRegion | CompoundRegion
            Region(s) to transform
        """
        super().__init__()
        self.region = region

    def _apply(self, dataset: OneDimensionalDataset) -> OneDimensionalDataset:  # type: ignore[override]
        """
        Apply transform within region, preserve outside.

        This method handles the region masking logic. Subclasses should
        override _apply_to_region() instead to define the specific
        transformation to apply.

        Parameters
        ----------
        dataset : OneDimensionalDataset
            Dataset to transform. Its ``_dependent_variable_data`` attribute
            is overwritten and the same object is returned.

        Returns
        -------
        OneDimensionalDataset
            Transformed dataset

        Raises
        ------
        TypeError
            If dataset is not OneDimensionalDataset
        """
        if not isinstance(dataset, OneDimensionalDataset):
            raise TypeError("RegionTransform only works with OneDimensionalDataset")

        # Coerce to NumPy explicitly: the masked item assignment below needs a
        # mutable array, which immutable backend (e.g. JAX) arrays are not.
        x_data = np.asarray(dataset.independent_variable_data)
        y_data = np.asarray(dataset.dependent_variable_data)

        # Generate mask from the independent variable
        mask = np.asarray(self.region.get_mask(x_data))

        # Extract region data
        x_region = x_data[mask]
        y_region = y_data[mask]

        # Apply transform to region (subclass implements this)
        y_region_transformed = np.asarray(self._apply_to_region(x_region, y_region))

        # Promote the output buffer so a wider result dtype (e.g. float values
        # produced from integer input) is not silently truncated on assignment.
        result_dtype = np.result_type(y_data, y_region_transformed)
        y_data_full = y_data.astype(result_dtype, copy=True)
        y_data_full[mask] = y_region_transformed

        # Update dataset with backend arrays
        dataset._dependent_variable_data = from_numpy(y_data_full)

        return dataset

    def _apply_to_region(self, x_region: np.ndarray, y_region: np.ndarray) -> np.ndarray:
        """
        Apply transformation to region data.

        Subclasses override this method to implement specific transforms.
        This method receives only the data within the region and should
        return the transformed dependent variable.

        Parameters
        ----------
        x_region : np.ndarray
            Independent variable within region
        y_region : np.ndarray
            Dependent variable within region

        Returns
        -------
        np.ndarray
            Transformed dependent variable

        Raises
        ------
        NotImplementedError
            If subclass doesn't implement this method
        """
        raise NotImplementedError("Subclasses must implement _apply_to_region()")
class RegionMultiplyTransform(RegionTransform):
    """
    Scale the dependent variable inside a region by a constant factor.

    A minimal concrete RegionTransform: every dependent-variable value that
    falls inside the configured region(s) is multiplied by ``factor``.

    Parameters
    ----------
    region : LinearRegion | CompoundRegion
        Region(s) whose values are scaled
    factor : float
        Constant multiplier applied inside the region(s)

    Examples
    --------
    >>> import numpy as np
    >>> from piblin_jax.data.datasets import OneDimensionalDataset
    >>> from piblin_jax.data.roi import LinearRegion, CompoundRegion
    >>> from piblin_jax.transform.region import RegionMultiplyTransform
    >>> # Single region example
    >>> x_data = np.linspace(0, 10, 11)
    >>> y_data = np.ones(11)
    >>> dataset = OneDimensionalDataset(
    ...     independent_variable_data=x_data,
    ...     dependent_variable_data=y_data
    ... )
    >>> region = LinearRegion(x_min=3.0, x_max=7.0)
    >>> transform = RegionMultiplyTransform(region, factor=2.0)
    >>> result = transform.apply_to(dataset, make_copy=True)
    >>> # Points in [3, 7] are multiplied by 2.0, others unchanged
    >>> # Multiple disjoint regions example
    >>> region1 = LinearRegion(x_min=1.0, x_max=2.0)
    >>> region2 = LinearRegion(x_min=8.0, x_max=9.0)
    >>> compound = CompoundRegion([region1, region2])
    >>> transform = RegionMultiplyTransform(compound, factor=0.5)
    >>> result = transform.apply_to(dataset, make_copy=True)
    >>> # Points in [1, 2] OR [8, 9] are multiplied by 0.5

    Notes
    -----
    Intended as a simple demonstration/testing transform; richer region
    transforms can follow the same subclassing pattern.
    """

    def __init__(self, region: LinearRegion | CompoundRegion, factor: float):
        """
        Store the region and the multiplier.

        Parameters
        ----------
        region : LinearRegion | CompoundRegion
            Region(s) whose values are scaled
        factor : float
            Constant multiplier applied inside the region(s)
        """
        super().__init__(region)
        self.factor = factor

    def _apply_to_region(self, x_region: np.ndarray, y_region: np.ndarray) -> np.ndarray:
        """
        Return the region's dependent values scaled by ``self.factor``.

        Parameters
        ----------
        x_region : np.ndarray
            Independent variable within region (ignored by this transform)
        y_region : np.ndarray
            Dependent variable within region

        Returns
        -------
        np.ndarray
            ``y_region * factor``
        """
        scaled = y_region * self.factor
        return scaled
# Public API of this module (names exported by ``from ... import *``).
__all__ = [
    "RegionMultiplyTransform",
    "RegionTransform",
]