Source code for piblin_jax.transform.dataset.normalization

"""
Normalization transforms for 1D datasets.

This module provides various normalization and scaling transforms
to standardize data for comparison and analysis.
"""

from typing import Any

from piblin_jax.backend import jnp
from piblin_jax.backend.operations import jit
from piblin_jax.data.datasets import OneDimensionalDataset
from piblin_jax.transform.base import DatasetTransform


class MinMaxNormalize(DatasetTransform):
    """
    Linearly rescale the dependent variable onto a target interval.

    The smallest value of the dependent variable is mapped to the lower
    bound of ``feature_range`` and the largest value to (approximately)
    the upper bound. With the default range the result lies in [0, 1],
    which makes datasets recorded on different scales directly comparable.

    Parameters
    ----------
    feature_range : tuple of float, default=(0, 1)
        Target interval given as ``(min, max)``.

    Attributes
    ----------
    feature_range : tuple
        The target interval passed to the constructor, stored unchanged.

    Examples
    --------
    >>> import numpy as np
    >>> from piblin_jax.data.datasets import OneDimensionalDataset
    >>> from piblin_jax.transform.dataset import MinMaxNormalize
    >>>
    >>> # Data spanning 5 to 25
    >>> x = np.linspace(0, 10, 100)
    >>> y = np.linspace(5, 25, 100)
    >>> dataset = OneDimensionalDataset(
    ...     independent_variable_data=x,
    ...     dependent_variable_data=y
    ... )
    >>>
    >>> # Default target interval [0, 1]
    >>> transform = MinMaxNormalize()
    >>> result = transform.apply_to(dataset)
    >>> lowest = np.min(result.dependent_variable_data)   # close to 0
    >>> highest = np.max(result.dependent_variable_data)  # close to 1
    >>>
    >>> # Custom target interval
    >>> transform = MinMaxNormalize(feature_range=(-1, 1))
    >>> result = transform.apply_to(dataset)

    Notes
    -----
    - Formula:
      ``y_scaled = (y - y_min) / (y_max - y_min + 1e-10) * (max - min) + min``
    - The ``1e-10`` term keeps the denominator non-zero; constant input
      therefore maps every sample to the lower bound of the interval.
    - The numeric kernel is wrapped with the backend's ``jit``.
    - Only the dataset's dependent variable is replaced; this transform
      does not touch any other dataset field.
    """

    def __init__(self, feature_range: tuple[float, float] = (0, 1)):
        """
        Store the target interval for later use by ``_apply``.

        Parameters
        ----------
        feature_range : tuple of float, default=(0, 1)
            Target interval given as ``(min, max)``.
        """
        super().__init__()
        self.feature_range = feature_range

    @staticmethod
    @jit
    def _compute_minmax_norm(y: Any, target_min: float, target_max: float) -> Any:
        """
        Map ``y`` onto ``[target_min, target_max]`` (jit-wrapped kernel).

        Parameters
        ----------
        y : array-like
            Values to rescale.
        target_min : float
            Lower bound of the output interval.
        target_max : float
            Upper bound of the output interval.

        Returns
        -------
        array-like
            Rescaled values with the same shape as ``y``.
        """
        lowest = jnp.min(y)
        highest = jnp.max(y)

        # The additive epsilon protects against a zero range (constant data).
        unit_scaled = (y - lowest) / (highest - lowest + 1e-10)

        # Stretch the unit interval to the requested width, then shift it.
        return unit_scaled * (target_max - target_min) + target_min

    def _apply(self, dataset: OneDimensionalDataset) -> OneDimensionalDataset:  # type: ignore[override]
        """
        Rescale the dependent variable of ``dataset``.

        Parameters
        ----------
        dataset : OneDimensionalDataset
            Dataset whose dependent variable is rescaled.

        Returns
        -------
        OneDimensionalDataset
            The same dataset object, now holding the rescaled values.

        Notes
        -----
        The dataset is updated in place: its ``_dependent_variable_data``
        attribute is reassigned and no copy is made by this method.
        """
        values = jnp.asarray(dataset.dependent_variable_data)
        target_min, target_max = self.feature_range

        dataset._dependent_variable_data = MinMaxNormalize._compute_minmax_norm(  # type: ignore[call-arg]
            values, target_min, target_max
        )
        return dataset
class ZScoreNormalize(DatasetTransform):
    """
    Standardize the dependent variable (z-score transformation).

    The mean of the dependent variable is subtracted and the result is
    divided by its standard deviation, so the output has approximately
    zero mean and unit standard deviation. This puts datasets with
    different units or scales on a common footing.

    Examples
    --------
    >>> import numpy as np
    >>> from piblin_jax.data.datasets import OneDimensionalDataset
    >>> from piblin_jax.transform.dataset import ZScoreNormalize
    >>>
    >>> # Data with mean near 10 and standard deviation near 5
    >>> x = np.linspace(0, 10, 100)
    >>> y = 5.0 * np.random.randn(100) + 10.0
    >>> dataset = OneDimensionalDataset(
    ...     independent_variable_data=x,
    ...     dependent_variable_data=y
    ... )
    >>>
    >>> transform = ZScoreNormalize()
    >>> result = transform.apply_to(dataset)
    >>> centre = np.mean(result.dependent_variable_data)  # close to 0
    >>> spread = np.std(result.dependent_variable_data)   # close to 1

    Notes
    -----
    - Formula: ``y_zscore = (y - mean(y)) / (std(y) + 1e-10)``
    - The ``1e-10`` term keeps the denominator non-zero for constant data.
    - The transformation is affine, so the shape of the distribution is
      unchanged.
    - Mean and standard deviation are both pulled by outliers; a
      median/IQR based scaling is less affected by them.
    - The numeric kernel is wrapped with the backend's ``jit``.
    """

    @staticmethod
    @jit
    def _compute_zscore(y: Any) -> Any:
        """
        Return ``y`` centred on its mean and scaled by its standard deviation.

        Parameters
        ----------
        y : array-like
            Values to standardize.

        Returns
        -------
        array-like
            Standardized values with the same shape as ``y``.
        """
        centre = jnp.mean(y)
        spread = jnp.std(y)
        # The additive epsilon protects against a zero standard deviation.
        return (y - centre) / (spread + 1e-10)

    def _apply(self, dataset: OneDimensionalDataset) -> OneDimensionalDataset:  # type: ignore[override]
        """
        Standardize the dependent variable of ``dataset``.

        Parameters
        ----------
        dataset : OneDimensionalDataset
            Dataset whose dependent variable is standardized.

        Returns
        -------
        OneDimensionalDataset
            The same dataset object, now holding the standardized values.
        """
        values = jnp.asarray(dataset.dependent_variable_data)

        # Reassign the private attribute; the dataset is updated in place.
        dataset._dependent_variable_data = ZScoreNormalize._compute_zscore(values)
        return dataset
class RobustNormalize(DatasetTransform):
    """
    Robust normalization using median and IQR.

    The dependent variable is centred on its median and divided by its
    interquartile range (IQR). Because both statistics are rank based,
    the result is less affected by outliers than z-score normalization.

    Examples
    --------
    >>> import numpy as np
    >>> from piblin_jax.data.datasets import OneDimensionalDataset
    >>> from piblin_jax.transform.dataset import RobustNormalize
    >>>
    >>> # Create data with an outlier
    >>> x = np.linspace(0, 10, 100)
    >>> y = np.random.randn(100)
    >>> y[0] = 100  # Outlier
    >>> dataset = OneDimensionalDataset(x, y)
    >>>
    >>> # Robust normalization (less affected by the outlier)
    >>> transform = RobustNormalize()
    >>> result = transform.apply_to(dataset)

    Notes
    -----
    - Formula: ``y_robust = (y - median(y)) / IQR(y)``
    - ``IQR = Q3 - Q1`` (75th percentile minus 25th percentile)
    - When the IQR is exactly zero the data is only centred on the median
      (the divisor falls back to 1) instead of being divided by zero.
    - No additive epsilon is used, so the scaling is independent of the
      magnitude of the data.
    - The numeric kernel is wrapped with the backend's ``jit``.
    """

    @staticmethod
    @jit
    def _compute_robust_norm(y: Any) -> Any:
        """
        Return ``y`` centred on its median and scaled by its IQR.

        Parameters
        ----------
        y : array-like
            Values to normalize.

        Returns
        -------
        array-like
            Normalized values with the same shape as ``y``.
        """
        median = jnp.median(y)
        iqr = jnp.percentile(y, 75) - jnp.percentile(y, 25)

        # Divide by the true IQR. An additive epsilon would distort data
        # whose spread is comparable to the epsilon and would inflate the
        # output by ~1e10 when the IQR is exactly zero. Fall back to a
        # unit divisor only in that degenerate case.
        safe_iqr = jnp.where(iqr == 0, 1.0, iqr)
        return (y - median) / safe_iqr

    def _apply(self, dataset: OneDimensionalDataset) -> OneDimensionalDataset:  # type: ignore[override]
        """
        Apply robust normalization to the dependent variable of ``dataset``.

        Parameters
        ----------
        dataset : OneDimensionalDataset
            Input dataset.

        Returns
        -------
        OneDimensionalDataset
            The same dataset object, now holding the robust-normalized
            dependent variable.
        """
        y = jnp.asarray(dataset.dependent_variable_data)
        y_robust = RobustNormalize._compute_robust_norm(y)

        # Reassign the private attribute; the dataset is updated in place.
        dataset._dependent_variable_data = y_robust
        return dataset
class MaxNormalize(DatasetTransform):
    """
    Normalize data by dividing by its maximum absolute value.

    The dependent variable is scaled so that its largest magnitude
    becomes 1. Zero stays zero and every sample keeps its sign.

    Examples
    --------
    >>> import numpy as np
    >>> from piblin_jax.data.datasets import OneDimensionalDataset
    >>> from piblin_jax.transform.dataset import MaxNormalize
    >>>
    >>> x = np.linspace(0, 10, 100)
    >>> y = np.linspace(-50, 100, 100)
    >>> dataset = OneDimensionalDataset(x, y)
    >>>
    >>> transform = MaxNormalize()
    >>> result = transform.apply_to(dataset)
    >>> peak = np.max(np.abs(result.dependent_variable_data))  # equals 1

    Notes
    -----
    - Formula: ``y_norm = y / max(abs(y))``
    - Preserves zero and sign
    - Maximum absolute value becomes 1, whatever the magnitude of the
      input (no additive epsilon is used)
    - All-zero input is returned as zeros (the divisor falls back to 1)
    - The numeric kernel is wrapped with the backend's ``jit``.
    """

    @staticmethod
    @jit
    def _compute_max_norm(y: Any) -> Any:
        """
        Return ``y`` divided by its largest absolute value.

        Parameters
        ----------
        y : array-like
            Values to normalize.

        Returns
        -------
        array-like
            Normalized values with the same shape as ``y``.
        """
        max_abs = jnp.max(jnp.abs(y))

        # Divide by the true peak so the documented "peak becomes 1"
        # contract holds for small-magnitude signals as well. An additive
        # epsilon would shrink them. Only all-zero input needs a guard.
        safe_max = jnp.where(max_abs == 0, 1.0, max_abs)
        return y / safe_max

    def _apply(self, dataset: OneDimensionalDataset) -> OneDimensionalDataset:  # type: ignore[override]
        """
        Apply max normalization to the dependent variable of ``dataset``.

        Parameters
        ----------
        dataset : OneDimensionalDataset
            Input dataset.

        Returns
        -------
        OneDimensionalDataset
            The same dataset object, now holding the max-normalized
            dependent variable.
        """
        y = jnp.asarray(dataset.dependent_variable_data)
        y_norm = MaxNormalize._compute_max_norm(y)

        # Reassign the private attribute; the dataset is updated in place.
        dataset._dependent_variable_data = y_norm
        return dataset
# Public API of this module: the normalization transforms it defines.
__all__ = [
    "MaxNormalize",
    "MinMaxNormalize",
    "RobustNormalize",
    "ZScoreNormalize",
]