"""
Normalization transforms for 1D datasets.
This module provides various normalization and scaling transforms
to standardize data for comparison and analysis.
"""
from typing import Any
from piblin_jax.backend import jnp
from piblin_jax.backend.operations import jit
from piblin_jax.data.datasets import OneDimensionalDataset
from piblin_jax.transform.base import DatasetTransform
[docs]
class MinMaxNormalize(DatasetTransform):
"""
Min-max normalization to scale data to a specific range.
This transform scales the dependent variable to a target range,
typically [0, 1]. This is useful for comparing datasets with
different scales or preparing data for machine learning.
Parameters
----------
feature_range : tuple of float, default=(0, 1)
Target range for normalization as (min, max).
Attributes
----------
feature_range : tuple
Target range for scaled data.
Examples
--------
>>> import numpy as np
>>> from piblin_jax.data.datasets import OneDimensionalDataset
>>> from piblin_jax.transform.dataset import MinMaxNormalize
>>>
>>> # Create data with arbitrary range
>>> x = np.linspace(0, 10, 100)
>>> y = np.linspace(5, 25, 100) # Range: 5 to 25
>>> dataset = OneDimensionalDataset(
... independent_variable_data=x,
... dependent_variable_data=y
... )
>>>
>>> # Normalize to [0, 1]
>>> transform = MinMaxNormalize()
>>> result = transform.apply_to(dataset)
>>> np.min(result.dependent_variable_data) # Should be ~0
0.0
>>> np.max(result.dependent_variable_data) # Should be ~1
1.0
>>>
>>> # Normalize to custom range
>>> transform = MinMaxNormalize(feature_range=(-1, 1))
>>> result = transform.apply_to(dataset)
Notes
-----
- Formula: y_scaled = (y - y_min) / (y_max - y_min) * (max - min) + min
- Small epsilon added to denominator to avoid division by zero
- JIT-compiled with JAX backend for efficiency
- Independent variable is preserved
- Metadata is preserved
"""
[docs]
def __init__(self, feature_range: tuple[float, float] = (0, 1)):
"""
Initialize min-max normalization transform.
Parameters
----------
feature_range : tuple of float, default=(0, 1)
Target range as (min, max).
"""
super().__init__()
self.feature_range = feature_range
@staticmethod
@jit
def _compute_minmax_norm(y: Any, target_min: float, target_max: float) -> Any:
"""JIT-compiled min-max normalization for 3-5x speedup."""
y_min = jnp.min(y)
y_max = jnp.max(y)
# Normalize to [0, 1]
y_norm = (y - y_min) / (y_max - y_min + 1e-10)
# Scale to target range
return y_norm * (target_max - target_min) + target_min
def _apply(self, dataset: OneDimensionalDataset) -> OneDimensionalDataset: # type: ignore[override]
"""
Apply min-max normalization to dataset.
Parameters
----------
dataset : OneDimensionalDataset
Input dataset to normalize.
Returns
-------
OneDimensionalDataset
Dataset with normalized dependent variable.
Notes
-----
Modifies dependent variable in-place.
"""
y = jnp.asarray(dataset.dependent_variable_data)
# Use JIT-compiled normalization: 3-5x faster
target_min, target_max = self.feature_range
y_scaled = MinMaxNormalize._compute_minmax_norm(y, target_min, target_max) # type: ignore[call-arg]
# Update dataset
dataset._dependent_variable_data = y_scaled
return dataset
[docs]
class ZScoreNormalize(DatasetTransform):
"""
Z-score normalization (standardization).
This transform standardizes data to have zero mean and unit variance.
Also known as standardization or z-score transformation. Useful for
comparing datasets with different units or scales.
Examples
--------
>>> import numpy as np
>>> from piblin_jax.data.datasets import OneDimensionalDataset
>>> from piblin_jax.transform.dataset import ZScoreNormalize
>>>
>>> # Create data with arbitrary mean and std
>>> x = np.linspace(0, 10, 100)
>>> y = 5.0 * np.random.randn(100) + 10.0 # mean=10, std=5
>>> dataset = OneDimensionalDataset(
... independent_variable_data=x,
... dependent_variable_data=y
... )
>>>
>>> # Standardize to mean=0, std=1
>>> transform = ZScoreNormalize()
>>> result = transform.apply_to(dataset)
>>> np.mean(result.dependent_variable_data) # Should be ~0
0.0
>>> np.std(result.dependent_variable_data) # Should be ~1
1.0
Notes
-----
- Formula: y_zscore = (y - mean(y)) / std(y)
- Small epsilon added to denominator to avoid division by zero
- Results have mean ≈ 0 and standard deviation ≈ 1
- JIT-compiled with JAX backend
- Preserves shape of distribution
- Sensitive to outliers (unlike robust scaling)
"""
@staticmethod
@jit
def _compute_zscore(y: Any) -> Any:
"""JIT-compiled z-score normalization for 3-5x speedup."""
mean = jnp.mean(y)
std = jnp.std(y)
return (y - mean) / (std + 1e-10)
def _apply(self, dataset: OneDimensionalDataset) -> OneDimensionalDataset: # type: ignore[override]
"""
Apply z-score normalization to dataset.
Parameters
----------
dataset : OneDimensionalDataset
Input dataset to standardize.
Returns
-------
OneDimensionalDataset
Dataset with standardized dependent variable.
"""
y = jnp.asarray(dataset.dependent_variable_data)
# Use JIT-compiled z-score normalization: 3-5x faster
y_zscore = ZScoreNormalize._compute_zscore(y)
# Update dataset
dataset._dependent_variable_data = y_zscore
return dataset
[docs]
class RobustNormalize(DatasetTransform):
"""
Robust normalization using median and IQR.
This transform normalizes data using median and interquartile range (IQR)
instead of mean and standard deviation. More robust to outliers than
z-score normalization.
Examples
--------
>>> import numpy as np
>>> from piblin_jax.data.datasets import OneDimensionalDataset
>>> from piblin_jax.transform.dataset import RobustNormalize
>>>
>>> # Create data with outliers
>>> x = np.linspace(0, 10, 100)
>>> y = np.random.randn(100)
>>> y[0] = 100 # Outlier
>>> dataset = OneDimensionalDataset(x, y)
>>>
>>> # Robust normalization (less affected by outlier)
>>> transform = RobustNormalize()
>>> result = transform.apply_to(dataset)
Notes
-----
- Formula: y_robust = (y - median(y)) / IQR(y)
- IQR = Q3 - Q1 (interquartile range)
- More robust to outliers than z-score
- JIT-compiled with JAX backend
"""
@staticmethod
@jit
def _compute_robust_norm(y: Any) -> Any:
"""JIT-compiled robust normalization for 3-5x speedup."""
median = jnp.median(y)
q75 = jnp.percentile(y, 75)
q25 = jnp.percentile(y, 25)
iqr = q75 - q25
return (y - median) / (iqr + 1e-10)
def _apply(self, dataset: OneDimensionalDataset) -> OneDimensionalDataset: # type: ignore[override]
"""
Apply robust normalization to dataset.
Parameters
----------
dataset : OneDimensionalDataset
Input dataset.
Returns
-------
OneDimensionalDataset
Dataset with robust-normalized dependent variable.
"""
y = jnp.asarray(dataset.dependent_variable_data)
# Use JIT-compiled robust normalization: 3-5x faster
y_robust = RobustNormalize._compute_robust_norm(y)
# Update dataset
dataset._dependent_variable_data = y_robust
return dataset
[docs]
class MaxNormalize(DatasetTransform):
"""
Normalize data by dividing by maximum absolute value.
This simple normalization scales data so that the maximum absolute
value is 1. Preserves zero and sign of data.
Examples
--------
>>> import numpy as np
>>> from piblin_jax.data.datasets import OneDimensionalDataset
>>> from piblin_jax.transform.dataset import MaxNormalize
>>>
>>> x = np.linspace(0, 10, 100)
>>> y = np.linspace(-50, 100, 100)
>>> dataset = OneDimensionalDataset(x, y)
>>>
>>> transform = MaxNormalize()
>>> result = transform.apply_to(dataset)
>>> np.max(np.abs(result.dependent_variable_data))
1.0
Notes
-----
- Formula: y_norm = y / max(abs(y))
- Preserves zero and sign
- Maximum absolute value becomes 1
- Simple and fast
"""
@staticmethod
@jit
def _compute_max_norm(y: Any) -> Any:
"""JIT-compiled max normalization for 3-5x speedup."""
max_abs = jnp.max(jnp.abs(y))
return y / (max_abs + 1e-10)
def _apply(self, dataset: OneDimensionalDataset) -> OneDimensionalDataset: # type: ignore[override]
"""
Apply max normalization to dataset.
Parameters
----------
dataset : OneDimensionalDataset
Input dataset.
Returns
-------
OneDimensionalDataset
Dataset with max-normalized dependent variable.
"""
y = jnp.asarray(dataset.dependent_variable_data)
# Use JIT-compiled max normalization: 3-5x faster
y_norm = MaxNormalize._compute_max_norm(y)
# Update dataset
dataset._dependent_variable_data = y_norm
return dataset
__all__ = [
"MaxNormalize",
"MinMaxNormalize",
"RobustNormalize",
"ZScoreNormalize",
]