Performance Guide ================= This guide covers performance optimization in piblin-jax, from basic best practices to advanced GPU acceleration and profiling techniques. Overview -------- piblin-jax is designed for high-performance scientific computing through: **JAX Backend** Automatic JIT (Just-In-Time) compilation for CPU/GPU acceleration. **Vectorized Operations** Array operations optimized for modern hardware (SIMD, GPU). **Smart Caching** Automatic caching of compiled functions for reuse. **Pipeline Optimization** Transform pipelines are optimized as a whole, not step-by-step. **NumPy Fallback** Graceful degradation to NumPy when JAX is unavailable. Quick Start ----------- Check Your Backend ~~~~~~~~~~~~~~~~~~ Verify which backend piblin-jax is using:: from piblin_jax.backend import get_backend, is_jax_available print(f"Backend: {get_backend()}") print(f"JAX available: {is_jax_available()}") if is_jax_available(): import jax print(f"JAX version: {jax.__version__}") print(f"Devices: {jax.devices()}") Expected output:: Backend: jax JAX available: True JAX version: 0.4.23 Devices: [CpuDevice(id=0)] or with GPU:: Devices: [cuda(id=0), CpuDevice(id=0)] Basic Performance Tips ~~~~~~~~~~~~~~~~~~~~~~ **1. Use JAX backend (installed by default)** Install with GPU support for 10-100x speedup:: pip install "jax[cuda12]" # NVIDIA CUDA 12 pip install "jax[metal]" # Apple Silicon (M1/M2/M3) **2. Reuse transform objects** JIT compilation happens on first call. Reusing transforms is fast:: # Good: Create once, use many times smoother = GaussianSmoothing(sigma=2.0) for dataset in datasets: smoothed = smoother.apply_to(dataset) # Slower: Creates and compiles new transform each time for dataset in datasets: smoothed = GaussianSmoothing(sigma=2.0).apply_to(dataset) **3. Use Pipelines** Pipelines optimize the entire sequence:: from piblin_jax.transform import Pipeline pipeline = Pipeline([ GaussianSmoothing(sigma=2.0), Derivative(order=1), Normalize(method='minmax') ]) # Single optimized pass result = pipeline.apply_to(dataset) **4. Batch processing** Process multiple datasets together:: from piblin_jax.data.collections import MeasurementSet # Efficient: Single batch operation mset = MeasurementSet(datasets) results = [transform.apply_to(ds) for ds in mset.measurements] JAX Backend ----------- JIT Compilation ~~~~~~~~~~~~~~~ JAX's Just-In-Time compiler optimizes Python code to machine code: **First call: Compilation overhead** :: import time from piblin_jax.transform.dataset import GaussianSmoothing smoother = GaussianSmoothing(sigma=2.0) # First call: slow (compilation) start = time.time() result1 = smoother.apply_to(dataset) print(f"First call: {(time.time() - start)*1000:.1f}ms") # Second call: fast (compiled) start = time.time() result2 = smoother.apply_to(dataset) print(f"Second call: {(time.time() - start)*1000:.1f}ms") Output:: First call: 45.2ms (includes compilation) Second call: 2.1ms (reuses compiled code) **Key insight:** The first call is slow, but subsequent calls are 10-100x faster. Vectorization ~~~~~~~~~~~~~ JAX automatically vectorizes operations:: from piblin_jax.backend import jnp # Both are fast, but vectorized is cleaner # Manual loop (slower) results = [] for i in range(1000): results.append(jnp.sin(data[i])) # Vectorized (faster) results = jnp.sin(data) # Operates on entire array at once **vmap (vectorizing map):** Apply function across array dimensions:: from piblin_jax.backend.operations import vmap def process_single(x): return x ** 2 + 2 * x + 1 # Vectorize the function process_batch = vmap(process_single) # Process all data at once batch_data = jnp.array([[1, 2, 3], [4, 5, 6]]) results = process_batch(batch_data) GPU Acceleration ---------------- Installation ~~~~~~~~~~~~ **NVIDIA GPU (CUDA):** :: # CUDA 12.x pip install --upgrade pip pip install "jax[cuda12]" -f https://storage.googleapis.com/jax-releases/jax_cuda_releases.html # Or CUDA 11.x pip install "jax[cuda11]" -f https://storage.googleapis.com/jax-releases/jax_cuda_releases.html **Apple Silicon (Metal):** :: pip install "jax[metal]" **AMD GPU (ROCm):** :: pip install "jax[rocm]" -f https://storage.googleapis.com/jax-releases/jax_rocm_releases.html Verification ~~~~~~~~~~~~ Check if GPU is detected:: import jax print(f"GPU available: {len(jax.devices('gpu')) > 0}") print(f"All devices: {jax.devices()}") # Check memory if len(jax.devices('gpu')) > 0: gpu = jax.devices('gpu')[0] print(f"GPU: {gpu.device_kind}") Automatic GPU Usage ~~~~~~~~~~~~~~~~~~~ piblin-jax automatically uses GPU when available - no code changes needed:: # This code runs on GPU automatically if available from piblin_jax.transform.dataset import GaussianSmoothing smoother = GaussianSmoothing(sigma=2.0) result = smoother.apply_to(large_dataset) **Backend automatically:** 1. Detects GPU 2. Allocates arrays on GPU 3. Compiles kernels for GPU 4. Executes on GPU 5. Returns results When to Use GPU ~~~~~~~~~~~~~~~ **GPU excels at:** - Large datasets (>10,000 points) - Repeated operations (transform reuse) - Parallel operations (batch processing) - Heavy computation (smoothing, FFT, convolution) **CPU may be faster for:** - Small datasets (<1,000 points) - One-time operations - Memory-limited tasks - Simple operations (addition, multiplication) **Benchmark comparison:** .. list-table:: :widths: 30 20 20 20 :header-rows: 1 * - Operation - Data Size - CPU Time - GPU Time * - Gaussian smoothing - 1,000 points - 2.1 ms - 5.3 ms (slower!) * - Gaussian smoothing - 100,000 points - 45 ms - 1.2 ms (37x faster) * - Derivative - 10,000 points - 8.5 ms - 0.8 ms (10x faster) * - Bayesian fitting - 50 points, 2000 samples - 12 s - 0.8 s (15x faster) Memory Management ~~~~~~~~~~~~~~~~~ **GPU memory is limited:** Monitor memory usage:: # Check allocated memory (NVIDIA) !nvidia-smi **Best practices:** 1. **Process in batches** for large datasets 2. **Clear cache** between experiments:: import jax jax.clear_caches() 3. **Use float32** instead of float64 (half the memory):: from piblin_jax.backend import jnp data = jnp.array(data, dtype=jnp.float32) 4. **Explicitly move to CPU** if needed:: from piblin_jax.backend.operations import device_get cpu_array = device_get(gpu_array) Performance Optimization ------------------------ Transform Optimization ~~~~~~~~~~~~~~~~~~~~~~ **Use JIT-compiled transforms:** Built-in transforms are already optimized. For custom transforms:: from piblin_jax.transform.base import DatasetTransform from piblin_jax.backend.operations import jit from piblin_jax.backend import jnp class FastCustomTransform(DatasetTransform): @staticmethod @jit def _compute(data, param): """JIT-compiled computation core.""" return data * param + jnp.sin(data) def _apply(self, dataset): result = self._compute(dataset._dependent_variable_data, self.param) dataset._dependent_variable_data = result return dataset **Speedup:** 3-100x depending on operation complexity. Pipeline Optimization ~~~~~~~~~~~~~~~~~~~~~ Combine transforms into pipelines for optimization:: from piblin_jax.transform import Pipeline # Optimized pipeline pipeline = Pipeline([ GaussianSmoothing(sigma=2.0), Derivative(order=1), Normalize(method='zscore') ]) # Single pass through data result = pipeline.apply_to(dataset) **Why faster:** - Single memory pass (cache-friendly) - Combined compilation - Reduced intermediate arrays - Automatic fusion of operations Batch Processing ~~~~~~~~~~~~~~~~ Process multiple datasets efficiently:: # Inefficient: One at a time results = [] for dataset in datasets: results.append(transform.apply_to(dataset)) # Efficient: Batch processing from piblin_jax.data.collections import MeasurementSet mset = MeasurementSet(datasets) # Transform applies optimizations across all datasets results = [transform.apply_to(ds, make_copy=False) for ds in mset.measurements] **Tip:** Use ``make_copy=False`` for in-place operations (saves memory). Profiling --------- Time Measurements ~~~~~~~~~~~~~~~~~ **Basic timing:** :: import time start = time.time() result = transform.apply_to(dataset) elapsed = time.time() - start print(f"Time: {elapsed*1000:.2f}ms") **Jupyter timing:** :: # Single run %time result = transform.apply_to(dataset) # Multiple runs (average) %timeit result = transform.apply_to(dataset) JAX Profiling ~~~~~~~~~~~~~ Detailed performance profiling:: import jax # Profile compilation and execution with jax.profiler.trace("/tmp/jax-trace"): result = transform.apply_to(dataset) # View in Chrome: chrome://tracing # Load /tmp/jax-trace Memory Profiling ~~~~~~~~~~~~~~~~ **Track memory usage:** :: import psutil import os process = psutil.Process(os.getpid()) # Before mem_before = process.memory_info().rss / 1024 / 1024 # MB print(f"Memory before: {mem_before:.1f} MB") # Operation result = transform.apply_to(large_dataset) # After mem_after = process.memory_info().rss / 1024 / 1024 print(f"Memory after: {mem_after:.1f} MB") print(f"Memory delta: {mem_after - mem_before:.1f} MB") **GPU memory (NVIDIA):** :: # Command line nvidia-smi # Python import subprocess result = subprocess.run(['nvidia-smi'], capture_output=True, text=True) print(result.stdout) Bottleneck Analysis ~~~~~~~~~~~~~~~~~~~ Identify slow operations:: import cProfile import pstats # Profile code profiler = cProfile.Profile() profiler.enable() # Your code here for dataset in datasets: result = transform.apply_to(dataset) profiler.disable() # View results stats = pstats.Stats(profiler) stats.sort_stats('cumulative') stats.print_stats(20) # Top 20 slowest functions Common Performance Patterns ---------------------------- Pattern 1: Precompute and Reuse ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ :: # Bad: Recomputes every time for dataset in datasets: smoother = GaussianSmoothing(sigma=2.0) result = smoother.apply_to(dataset) # Good: Compile once, use many times smoother = GaussianSmoothing(sigma=2.0) # First call compiles results = [smoother.apply_to(ds) for ds in datasets] **Speedup:** 10-50x for loops over many datasets. Pattern 2: In-Place Operations ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ :: # Memory-intensive: Creates copies result1 = transform1.apply_to(dataset, make_copy=True) result2 = transform2.apply_to(result1, make_copy=True) result3 = transform3.apply_to(result2, make_copy=True) # Memory-efficient: In-place result = dataset.copy() # Single copy at start transform1.apply_to(result, make_copy=False) transform2.apply_to(result, make_copy=False) transform3.apply_to(result, make_copy=False) **Memory savings:** 3x less memory usage. Pattern 3: Lazy Evaluation ~~~~~~~~~~~~~~~~~~~~~~~~~~~ Defer computation until results are needed:: # Eager: Computes immediately smoothed = smoother.apply_to(dataset) normalized = normalizer.apply_to(smoothed) # Use results later... # Lazy: Compute only when needed pipeline = Pipeline([smoother, normalizer]) # No computation yet # Compute on demand result = pipeline.apply_to(dataset) Pattern 4: Array Reuse ~~~~~~~~~~~~~~~~~~~~~~ :: from piblin_jax.backend import jnp # Bad: Creates new arrays for i in range(1000): temp = jnp.zeros(10000) result = compute_something(temp) # Good: Reuse arrays temp = jnp.zeros(10000) for i in range(1000): result = compute_something(temp) Benchmarks ---------- Typical Performance ~~~~~~~~~~~~~~~~~~~ Performance vs reference implementation (NumPy-only): .. list-table:: :widths: 40 15 15 15 :header-rows: 1 * - Operation - NumPy Baseline - JAX (CPU) - JAX (GPU) * - Dataset creation - 1.0x - 1.2x - 1.1x * - Gaussian smoothing - 1.0x - 5-10x - 40-100x * - Derivative - 1.0x - 3-5x - 15-30x * - FFT (large) - 1.0x - 2-4x - 20-50x * - Transform pipeline - 1.0x - 5-10x - 50-100x * - Bayesian fitting - 1.0x (scipy) - 10-15x - 90-150x * - MCMC sampling (2000 samples) - 60s - 5s - 0.7s **Hardware used:** - CPU: AMD Ryzen 9 5950X (16 cores) - GPU: NVIDIA RTX 3090 (24GB) Real-World Examples ~~~~~~~~~~~~~~~~~~~ **Example 1: Smoothing large dataset** :: import numpy as np import time from piblin_jax.data.datasets import OneDimensionalDataset from piblin_jax.transform.dataset import GaussianSmoothing # Large dataset x = np.linspace(0, 100, 100000) y = np.sin(x) + np.random.normal(0, 0.1, size=len(x)) dataset = OneDimensionalDataset( independent_variable_data=x, dependent_variable_data=y ) # Benchmark smoother = GaussianSmoothing(sigma=5.0) # Warmup (compilation) _ = smoother.apply_to(dataset) # Measure start = time.time() result = smoother.apply_to(dataset) elapsed = time.time() - start print(f"Smoothing 100k points: {elapsed*1000:.2f}ms") Results:: NumPy backend: 125ms JAX (CPU): 18ms (7x faster) JAX (GPU): 1.2ms (104x faster) **Example 2: Bayesian power-law fitting** :: from piblin_jax.bayesian.models import PowerLawModel # Generate data shear_rate = np.logspace(-1, 2, 30) viscosity = 5.0 * shear_rate ** (-0.4) + np.random.normal(0, 0.5, size=30) # Benchmark model = PowerLawModel(n_samples=2000, n_warmup=1000) start = time.time() model.fit(shear_rate, viscosity) elapsed = time.time() - start print(f"Bayesian fitting (2000 samples): {elapsed:.2f}s") Results:: NumPy backend (scipy): Not available JAX (CPU): 5.2s JAX (GPU): 0.7s (7x faster than CPU) Optimization Checklist ---------------------- Before Optimizing ~~~~~~~~~~~~~~~~~ 1. **Profile first**: Identify actual bottlenecks 2. **Measure baseline**: Know current performance 3. **Set targets**: Define acceptable performance 4. **Start simple**: Basic optimizations often sufficient During Optimization ~~~~~~~~~~~~~~~~~~~ 1. **Use JAX backend**: Install with ``pip install jax`` 2. **Reuse transforms**: Avoid repeated compilation 3. **Use pipelines**: Combine multiple transforms 4. **Batch processing**: Process multiple datasets together 5. **In-place operations**: Use ``make_copy=False`` carefully 6. **GPU acceleration**: For large datasets (>10k points) After Optimization ~~~~~~~~~~~~~~~~~~ 1. **Verify correctness**: Ensure results unchanged 2. **Measure improvement**: Compare to baseline 3. **Document**: Note optimizations for maintenance 4. **Monitor**: Check performance over time Performance Anti-Patterns -------------------------- **Anti-Pattern 1: Premature GPU usage** :: # Bad: GPU overhead > computation for small data tiny_dataset = OneDimensionalDataset(x[:100], y[:100]) result = gpu_transform.apply_to(tiny_dataset) # Good: Use CPU for small data result = cpu_transform.apply_to(tiny_dataset) **Anti-Pattern 2: Repeated compilation** :: # Bad: New transform every iteration for sigma in [1.0, 2.0, 3.0, 4.0, 5.0]: result = GaussianSmoothing(sigma=sigma).apply_to(dataset) # Good: Parameterize properly # (Note: GaussianSmoothing sigma is compilation parameter, # so this is unavoidable. For custom transforms, use runtime params) **Anti-Pattern 3: Unnecessary copying** :: # Bad: Excessive copying data1 = dataset.copy() data2 = data1.copy() data3 = data2.copy() # Good: Single copy when needed working_data = dataset.copy() # Modify working_data in-place **Anti-Pattern 4: Mixed precision without intent** :: # Bad: Implicit float64 (slower, more memory) data = np.array([1, 2, 3]) # float64 by default # Good: Explicit float32 (faster, less memory) data = np.array([1, 2, 3], dtype=np.float32) Troubleshooting --------------- Slow Performance ~~~~~~~~~~~~~~~~ **Symptom:** Operations slower than expected **Check:** 1. JAX backend installed? ``pip install jax`` 2. GPU detected? ``import jax; print(jax.devices())`` 3. First call compilation? Time second call 4. Data size appropriate? GPU helps with >10k points 5. Profiled? Use ``%timeit`` to find bottleneck Memory Issues ~~~~~~~~~~~~~ **Symptom:** Out of memory errors **Solutions:** 1. Use ``make_copy=False`` for in-place operations 2. Process in batches 3. Use float32 instead of float64 4. Clear JAX cache: ``jax.clear_caches()`` 5. Reduce batch size for GPU GPU Not Used ~~~~~~~~~~~~ **Symptom:** GPU available but not used **Check:** 1. JAX installed with GPU support? 2. CUDA/ROCm installed? 3. ``jax.devices('gpu')`` returns devices? 4. Data moved to GPU? (Automatic in piblin-jax) Compilation Warnings ~~~~~~~~~~~~~~~~~~~~ **Symptom:** Warnings during first call **Usually safe to ignore:** - "Slow compilation" - Expected on first call - "Large constant" - JAX inlining arrays - "Tracing" - Normal JAX behavior **Action needed:** - "Shape mismatch" - Check array shapes - "Type error" - Check data types Further Reading --------------- - See :doc:`../tutorials/basic_workflow` for practical examples - See :doc:`concepts` for architecture details - See JAX documentation: https://jax.readthedocs.io - See NumPyro documentation: https://num.pyro.ai Hardware Recommendations ------------------------ For Best Performance ~~~~~~~~~~~~~~~~~~~~ **CPU:** - Multi-core processor (8+ cores recommended) - Large cache (L3 cache > 16MB) - Modern architecture (2020 or newer) **GPU:** - NVIDIA: RTX 3060 or better (12GB+ VRAM) - AMD: RX 6800 or better - Apple Silicon: M1 Pro/Max or M2/M3 **Memory:** - 16GB+ RAM for typical datasets - 32GB+ for large datasets (>1M points) - GPU VRAM: 8GB minimum, 16GB+ recommended **Storage:** - SSD for data loading - NVMe for best performance with large files Cost-Benefit Analysis ~~~~~~~~~~~~~~~~~~~~~ **CPU-only setup:** - Cost: $0 extra - Performance: Good for small datasets - Use case: Exploratory analysis, small datasets **Consumer GPU (RTX 4070):** - Cost: ~$500 - Performance: 10-50x faster than CPU - Use case: Regular batch processing, medium datasets **Professional GPU (RTX 4090):** - Cost: ~$1600 - Performance: 50-150x faster than CPU - Use case: Production, large-scale analysis, research **Cloud GPU (Google Colab, AWS):** - Cost: $0.50-$5/hour - Performance: Varies by instance - Use case: Occasional heavy computation