"""
Region of Interest (ROI) classes for piblin-jax.

This module provides classes for defining regions on independent variables:

- LinearRegion: contiguous region on a 1D independent variable
- CompoundRegion: container for multiple LinearRegion objects (their union)

Regions are used with RegionTransform to apply transformations only
within specified regions while preserving data outside those regions.
"""
import numpy as np
class LinearRegion:
    """
    Represents a contiguous region on a 1D independent variable.

    A LinearRegion defines a contiguous range [x_min, x_max] (inclusive)
    on an independent variable. It can generate boolean masks to select
    data points within this range.

    Parameters
    ----------
    x_min : float
        Lower bound (inclusive)
    x_max : float
        Upper bound (inclusive)

    Raises
    ------
    ValueError
        If x_min is not strictly less than x_max. This includes the case
        where either bound is NaN.

    Examples
    --------
    >>> import numpy as np
    >>> from piblin_jax.data.roi import LinearRegion
    >>> # Define region from 2.0 to 5.0
    >>> region = LinearRegion(x_min=2.0, x_max=5.0)
    >>> # Generate mask for data
    >>> x_data = np.array([0, 1, 2, 3, 4, 5, 6, 7])
    >>> mask = region.get_mask(x_data)
    >>> mask
    array([False, False,  True,  True,  True,  True, False, False])
    >>> # Extract data within region
    >>> x_data[mask]
    array([2, 3, 4, 5])

    Notes
    -----
    - Bounds are inclusive: both x_min and x_max are included in the region
    - Plain Python sequences given to ``get_mask`` are converted to NumPy
      arrays; array objects are compared directly
    - Use with RegionTransform to apply selective transformations
    """

    def __init__(self, x_min: float, x_max: float):
        """
        Initialize LinearRegion.

        Parameters
        ----------
        x_min : float
            Lower bound (inclusive)
        x_max : float
            Upper bound (inclusive)

        Raises
        ------
        ValueError
            If x_min is not strictly less than x_max (including NaN bounds).
        """
        # Written as ``not (x_min < x_max)`` rather than ``x_min >= x_max``
        # so that NaN bounds (for which every comparison is False) are
        # rejected instead of silently producing an always-empty region.
        if not (x_min < x_max):
            raise ValueError(f"x_min ({x_min}) must be less than x_max ({x_max})")
        self.x_min = x_min
        self.x_max = x_max

    def get_mask(self, x_data: np.ndarray) -> np.ndarray:
        """
        Generate boolean mask for data within region.

        Creates a boolean array where True indicates points within
        the region [x_min, x_max] (inclusive).

        Parameters
        ----------
        x_data : array_like
            Independent variable data. Objects without a ``shape``
            attribute (e.g. lists or tuples) are converted with
            ``np.asarray``; array objects are compared as-is.

        Returns
        -------
        np.ndarray
            Boolean mask (True for points in region)

        Examples
        --------
        >>> region = LinearRegion(x_min=2.0, x_max=5.0)
        >>> x_data = np.array([1.0, 2.0, 3.0, 4.0, 5.0, 6.0])
        >>> region.get_mask(x_data)
        array([False,  True,  True,  True,  True, False])
        """
        if not hasattr(x_data, "shape"):
            # Python lists/tuples do not support element-wise comparison
            # with a float, so convert them. Array objects are left as they
            # are so their own comparison semantics (and type) are kept.
            x_data = np.asarray(x_data)
        return (x_data >= self.x_min) & (x_data <= self.x_max)

    def __repr__(self) -> str:
        """Return string representation of LinearRegion."""
        return f"LinearRegion(x_min={self.x_min}, x_max={self.x_max})"
class CompoundRegion:
    """
    Container for multiple LinearRegion objects (union of regions).

    A CompoundRegion represents the union of multiple disjoint or
    overlapping LinearRegion objects. It generates combined masks
    that include all points in any of the constituent regions.

    Parameters
    ----------
    regions : iterable of LinearRegion
        LinearRegion objects to combine (a list, tuple, generator, ...).

    Raises
    ------
    ValueError
        If regions is None or empty
    TypeError
        If any element is not a LinearRegion

    Examples
    --------
    >>> import numpy as np
    >>> from piblin_jax.data.roi import LinearRegion, CompoundRegion
    >>> # Define two disjoint regions
    >>> region1 = LinearRegion(x_min=1.0, x_max=2.0)
    >>> region2 = LinearRegion(x_min=4.0, x_max=5.0)
    >>> compound = CompoundRegion([region1, region2])
    >>> # Generate combined mask
    >>> x_data = np.array([0, 1, 2, 3, 4, 5, 6])
    >>> mask = compound.get_mask(x_data)
    >>> mask
    array([False,  True,  True, False,  True,  True, False])
    >>> # Extract data from both regions
    >>> x_data[mask]
    array([1, 2, 4, 5])

    Notes
    -----
    - The mask is the union (OR) of all constituent region masks
    - Regions can be disjoint or overlapping
    - Access individual regions using indexing: compound[0], compound[1], etc.
    - Get number of regions using len(compound)
    """

    def __init__(self, regions: list[LinearRegion]):
        """
        Initialize CompoundRegion.

        Parameters
        ----------
        regions : iterable of LinearRegion
            LinearRegion objects to combine. The iterable is copied into a
            new list, so later changes to the caller's container do not
            affect this object.

        Raises
        ------
        ValueError
            If regions is None or empty
        TypeError
            If any element is not a LinearRegion
        """
        # Materialize before validating. Testing truthiness of, or iterating,
        # a one-shot iterator (e.g. a generator) before copying it would
        # exhaust it and leave ``self.regions`` empty despite passing the
        # checks. ``None`` keeps mapping to the "empty" ValueError.
        regions = [] if regions is None else list(regions)
        if not regions:
            raise ValueError("CompoundRegion requires at least one region")
        if not all(isinstance(r, LinearRegion) for r in regions):
            raise TypeError("All regions must be LinearRegion objects")
        self.regions = regions

    def get_mask(self, x_data: np.ndarray) -> np.ndarray:
        """
        Generate combined boolean mask (union of all regions).

        Creates a boolean array where True indicates points within
        any of the constituent regions.

        Parameters
        ----------
        x_data : array_like
            Independent variable data. Objects without a ``shape``
            attribute (e.g. lists or tuples) are converted with
            ``np.asarray`` before any region is evaluated.

        Returns
        -------
        np.ndarray
            Boolean mask (True for points in any region)

        Examples
        --------
        >>> region1 = LinearRegion(x_min=1.0, x_max=2.0)
        >>> region2 = LinearRegion(x_min=4.0, x_max=5.0)
        >>> compound = CompoundRegion([region1, region2])
        >>> x_data = np.array([0, 1, 2, 3, 4, 5, 6])
        >>> compound.get_mask(x_data)
        array([False,  True,  True, False,  True,  True, False])
        """
        if not hasattr(x_data, "shape"):
            # Convert plain sequences once, so each region receives an array
            # rather than a raw list (which has no element-wise comparison).
            x_data = np.asarray(x_data)
        combined_mask = np.zeros_like(x_data, dtype=bool)
        for region in self.regions:
            combined_mask |= region.get_mask(x_data)
        return combined_mask

    def __len__(self) -> int:
        """Return number of regions."""
        return len(self.regions)

    def __getitem__(self, index: int) -> LinearRegion:
        """Get region by index."""
        return self.regions[index]

    def __repr__(self) -> str:
        """Return string representation of CompoundRegion."""
        return f"CompoundRegion({len(self.regions)} regions)"
# Public API of this module (names exported by ``from ... import *``).
__all__ = [
    "CompoundRegion",
    "LinearRegion",
]