Performance Guide
This guide covers performance optimization in piblin-jax, from basic best practices to advanced GPU acceleration and profiling techniques.
Overview
piblin-jax is designed for high-performance scientific computing through:
- JAX Backend
Automatic JIT (Just-In-Time) compilation for CPU/GPU acceleration.
- Vectorized Operations
Array operations optimized for modern hardware (SIMD, GPU).
- Smart Caching
Automatic caching of compiled functions for reuse.
- Pipeline Optimization
Transform pipelines are optimized as a whole, not step-by-step.
- NumPy Fallback
Graceful degradation to NumPy when JAX is unavailable.
Quick Start
Check Your Backend
Verify which backend piblin-jax is using:
from piblin_jax.backend import get_backend, is_jax_available
print(f"Backend: {get_backend()}")
print(f"JAX available: {is_jax_available()}")
if is_jax_available():
    import jax
    print(f"JAX version: {jax.__version__}")
    print(f"Devices: {jax.devices()}")
Expected output:
Backend: jax
JAX available: True
JAX version: 0.4.23
Devices: [CpuDevice(id=0)]
or, with a GPU (jax.devices() lists only the devices of the default backend):
Devices: [cuda(id=0)]
Basic Performance Tips
1. Use JAX backend (installed by default)
Install with GPU support for speedups of roughly 10-100x on large datasets:
pip install "jax[cuda12]" # NVIDIA CUDA 12
pip install jax-metal # Apple Silicon (M1/M2/M3), experimental Metal plugin
2. Reuse transform objects
JIT compilation happens on first call. Reusing transforms is fast:
# Good: Create once, use many times
smoother = GaussianSmoothing(sigma=2.0)
for dataset in datasets:
    smoothed = smoother.apply_to(dataset)
# Slower: Creates and compiles new transform each time
for dataset in datasets:
    smoothed = GaussianSmoothing(sigma=2.0).apply_to(dataset)
3. Use Pipelines
Pipelines optimize the entire sequence:
from piblin_jax.transform import Pipeline
pipeline = Pipeline([
    GaussianSmoothing(sigma=2.0),
    Derivative(order=1),
    Normalize(method='minmax')
])
# Single optimized pass
result = pipeline.apply_to(dataset)
4. Batch processing
Process multiple datasets together:
from piblin_jax.data.collections import MeasurementSet
# Group the datasets once, then apply the same compiled transform to each member
mset = MeasurementSet(datasets)
results = [transform.apply_to(ds) for ds in mset.measurements]
JAX Backend
JIT Compilation
JAX’s Just-In-Time compiler optimizes Python code to machine code:
- First call: Compilation overhead
import time
from piblin_jax.transform.dataset import GaussianSmoothing
smoother = GaussianSmoothing(sigma=2.0)
# First call: slow (compilation)
start = time.time()
result1 = smoother.apply_to(dataset)
print(f"First call: {(time.time() - start)*1000:.1f}ms")
# Second call: fast (compiled)
start = time.time()
result2 = smoother.apply_to(dataset)
print(f"Second call: {(time.time() - start)*1000:.1f}ms")
Output:
First call: 45.2ms (includes compilation)
Second call: 2.1ms (reuses compiled code)
Key insight: The first call is slow, but subsequent calls are 10-100x faster.
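The compiled code is reused only when later calls match the shapes and dtypes seen before. The following minimal sketch uses plain JAX rather than a piblin-jax transform; the function `scale` and the counter are illustrative names, and the counter works because Python side effects inside a jitted function run only while JAX traces it.

```python
import jax
import jax.numpy as jnp

trace_count = 0  # incremented only while JAX traces the function


@jax.jit
def scale(y):
    global trace_count
    trace_count += 1  # Python side effect: skipped when the cached code runs
    return y * 2.0


scale(jnp.ones(100))   # traces and compiles for shape (100,)
scale(jnp.zeros(100))  # same shape and dtype: reuses the compiled code
scale(jnp.ones(200))   # new shape: traces and compiles again
print(trace_count)  # 2
```

In practice this means datasets of many different lengths can each pay the compilation cost once, even when the same transform object is reused.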
Vectorization
JAX automatically vectorizes operations:
from piblin_jax.backend import jnp
# Prefer the vectorized form: the loop dispatches 1000 separate operations
# Manual loop (slower)
results = []
for i in range(1000):
    results.append(jnp.sin(data[i]))
# Vectorized (faster)
results = jnp.sin(data) # Operates on entire array at once
vmap (vectorizing map):
Apply function across array dimensions:
from piblin_jax.backend.operations import vmap
def process_single(x):
    return x ** 2 + 2 * x + 1
# Vectorize the function
process_batch = vmap(process_single)
# Process all data at once
batch_data = jnp.array([[1, 2, 3], [4, 5, 6]])
results = process_batch(batch_data)
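vmap is most useful when a function is written for a single item and cannot simply be called on the whole batch. The sketch below uses `jax.vmap` directly (the guide's `piblin_jax.backend.operations.vmap` is assumed, not confirmed, to behave the same way); `weighted_sum` is an illustrative function. The `in_axes` argument says which inputs are mapped and which are shared.

```python
import jax
import jax.numpy as jnp


def weighted_sum(row, weights):
    # Written for a single 1-D row.
    return jnp.dot(row, weights)


batch = jnp.array([[1.0, 2.0, 3.0], [4.0, 5.0, 6.0]])
weights = jnp.array([0.5, 0.25, 0.25])

# Map over axis 0 of `batch`; share `weights` across all rows.
batched = jax.vmap(weighted_sum, in_axes=(0, None))
result = batched(batch, weights)
print(result)  # one value per row: 1.75 and 4.75
```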
GPU Acceleration
Installation
NVIDIA GPU (CUDA):
# CUDA 12.x
pip install --upgrade pip
pip install "jax[cuda12]" -f https://storage.googleapis.com/jax-releases/jax_cuda_releases.html
# Or CUDA 11.x
pip install "jax[cuda11]" -f https://storage.googleapis.com/jax-releases/jax_cuda_releases.html
Apple Silicon (Metal):
pip install jax-metal
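AMD GPU (ROCm):
# The exact command depends on the JAX and ROCm versions; confirm it against the JAX installation guide
pip install "jax[rocm]" -f https://storage.googleapis.com/jax-releases/jax_rocm_releases.html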
Verification
Check if GPU is detected:
import jax
# jax.devices('gpu') raises RuntimeError when no GPU backend is present
try:
    gpus = jax.devices('gpu')
except RuntimeError:
    gpus = []
print(f"GPU available: {len(gpus) > 0}")
print(f"All devices: {jax.devices()}")
# Report the GPU model
if gpus:
    print(f"GPU: {gpus[0].device_kind}")
Automatic GPU Usage
piblin-jax automatically uses GPU when available - no code changes needed:
# This code runs on GPU automatically if available
from piblin_jax.transform.dataset import GaussianSmoothing
smoother = GaussianSmoothing(sigma=2.0)
result = smoother.apply_to(large_dataset)
Backend automatically:
Detects GPU
Allocates arrays on GPU
Compiles kernels for GPU
Executes on GPU
Returns results
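To see where JAX is placing arrays, the default backend and an array's devices can be inspected directly. This is a small sketch using plain JAX calls, independent of piblin-jax; it also shows explicit placement on the CPU for cases where a result should stay on the host.

```python
import jax
import jax.numpy as jnp

# Name of the backend JAX selected, for example 'cpu' or 'gpu'.
print(f"Default backend: {jax.default_backend()}")

x = jnp.arange(4.0)
# New arrays are placed on a device of the default backend.
print(f"Array lives on: {x.devices()}")

# Explicit placement on the CPU; a CPU device is always available.
cpu = jax.devices("cpu")[0]
x_cpu = jax.device_put(x, cpu)
print(f"Copy lives on: {x_cpu.devices()}")
```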
When to Use GPU
GPU excels at:
Large datasets (>10,000 points)
Repeated operations (transform reuse)
Parallel operations (batch processing)
Heavy computation (smoothing, FFT, convolution)
CPU may be faster for:
Small datasets (<1,000 points)
One-time operations
Memory-limited tasks
Simple operations (addition, multiplication)
Benchmark comparison:
| Operation | Data Size | CPU Time | GPU Time |
|---|---|---|---|
| Gaussian smoothing | 1,000 points | 2.1 ms | 5.3 ms (slower!) |
| Gaussian smoothing | 100,000 points | 45 ms | 1.2 ms (37x faster) |
| Derivative | 10,000 points | 8.5 ms | 0.8 ms (10x faster) |
| Bayesian fitting | 50 points, 2000 samples | 12 s | 0.8 s (15x faster) |
Memory Management
GPU memory is limited:
Monitor memory usage:
# Check allocated memory (NVIDIA)
!nvidia-smi
Best practices:
Process in batches for large datasets
Clear cache between experiments:
import jax
jax.clear_caches()
Use float32 instead of float64 (half the memory):
from piblin_jax.backend import jnp
data = jnp.array(data, dtype=jnp.float32)
Explicitly move to CPU if needed:
from piblin_jax.backend.operations import device_get
cpu_array = device_get(gpu_array)
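The "process in batches" advice can be sketched as follows, assuming an elementwise computation so that chunks are independent. The names `square_sine` and `process_in_chunks` are illustrative and not part of piblin-jax. Each chunk is converted to float32, processed on the device, and copied back to host memory before the next chunk is sent.

```python
import jax
import jax.numpy as jnp
import numpy as np


@jax.jit
def square_sine(chunk):
    # Placeholder elementwise computation standing in for a real transform.
    return jnp.sin(chunk) ** 2


def process_in_chunks(data, chunk_size):
    out = []
    for start in range(0, len(data), chunk_size):
        # Only one chunk is resident on the device at a time.
        chunk = jnp.asarray(data[start:start + chunk_size], dtype=jnp.float32)
        # np.asarray copies the result back to host memory.
        out.append(np.asarray(square_sine(chunk)))
    return np.concatenate(out)


data = np.linspace(0.0, 10.0, 1_000_000)
result = process_in_chunks(data, chunk_size=250_000)
print(result.shape, result.dtype)
```

Choosing a chunk size that divides the data evenly keeps every chunk the same shape, so the function is compiled only once; a shorter final chunk would trigger one extra compilation.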
Performance Optimization
Transform Optimization
Use JIT-compiled transforms:
Built-in transforms are already optimized. For custom transforms:
from piblin_jax.transform.base import DatasetTransform
from piblin_jax.backend.operations import jit
from piblin_jax.backend import jnp
class FastCustomTransform(DatasetTransform):
    # Assumes self.param is set elsewhere (for example in __init__)
    @staticmethod
    @jit
    def _compute(data, param):
        """JIT-compiled computation core."""
        return data * param + jnp.sin(data)
    def _apply(self, dataset):
        result = self._compute(dataset._dependent_variable_data, self.param)
        dataset._dependent_variable_data = result
        return dataset
Speedup: 3-100x depending on operation complexity.
Pipeline Optimization
Combine transforms into pipelines for optimization:
from piblin_jax.transform import Pipeline
# Optimized pipeline
pipeline = Pipeline([
    GaussianSmoothing(sigma=2.0),
    Derivative(order=1),
    Normalize(method='zscore')
])
# Single pass through data
result = pipeline.apply_to(dataset)
Why faster:
Single memory pass (cache-friendly)
Combined compilation
Reduced intermediate arrays
Automatic fusion of operations
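How Pipeline achieves this internally is not shown in this guide. As an illustration of what combined compilation can mean in JAX, the sketch below composes three plain functions and compiles the composition as one unit, which lets XLA optimize across step boundaries. The step functions and `compose` are illustrative stand-ins, not piblin-jax code.

```python
import jax
import jax.numpy as jnp


def smooth(y):
    # Five-point moving average as a stand-in for Gaussian smoothing.
    kernel = jnp.ones(5) / 5.0
    return jnp.convolve(y, kernel, mode="same")


def derivative(y):
    return jnp.gradient(y)


def minmax(y):
    return (y - y.min()) / (y.max() - y.min())


def compose(*steps):
    # Build one function that applies the steps in order.
    def run(y):
        for step in steps:
            y = step(y)
        return y
    return run


# One jit call covers the whole sequence instead of one per step.
fused = jax.jit(compose(smooth, derivative, minmax))

y = jnp.sin(jnp.linspace(0.0, 6.28, 1000))
out = fused(y)
print(out.shape)
```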
Batch Processing
Process multiple datasets efficiently:
# Inefficient: One at a time
results = []
for dataset in datasets:
    results.append(transform.apply_to(dataset))
# Efficient: Batch processing
from piblin_jax.data.collections import MeasurementSet
mset = MeasurementSet(datasets)
# Transform applies optimizations across all datasets
results = [transform.apply_to(ds, make_copy=False) for ds in mset.measurements]
Tip: Use make_copy=False for in-place operations (saves memory).
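When all datasets have the same length, a further option at the array level is to stack them and process the stack in a single call. The sketch below uses plain JAX on raw arrays; whether MeasurementSet offers an equivalent is not stated in this guide, and `smooth_one` is an illustrative function.

```python
import jax
import jax.numpy as jnp
import numpy as np


def smooth_one(y):
    # Five-point moving average for a single 1-D signal.
    kernel = jnp.ones(5) / 5.0
    return jnp.convolve(y, kernel, mode="same")


rng = np.random.default_rng(0)
signals = [rng.normal(size=2048) for _ in range(16)]

# Stack into one (16, 2048) array and map the function over axis 0.
stacked = jnp.stack([jnp.asarray(s) for s in signals])
smooth_all = jax.jit(jax.vmap(smooth_one))
smoothed = smooth_all(stacked)
print(smoothed.shape)
```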
Profiling
Time Measurements
Basic timing:
import time
start = time.time()
result = transform.apply_to(dataset)
elapsed = time.time() - start
print(f"Time: {elapsed*1000:.2f}ms")
Jupyter timing:
# Single run
%time result = transform.apply_to(dataset)
# Multiple runs (average)
%timeit result = transform.apply_to(dataset)
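One caveat when timing JAX code by hand: JAX dispatches work asynchronously, so a call can return before the computation has finished. A hedged sketch of a small helper follows; `benchmark` is a hypothetical name, and it separates the first (compiling) call from steady-state calls while waiting for results with `jax.block_until_ready`.

```python
import time

import jax
import jax.numpy as jnp


def benchmark(fn, *args, repeats=10):
    """Return (first_call_ms, best_steady_state_ms) for fn(*args)."""
    # First call includes tracing and compilation.
    start = time.perf_counter()
    jax.block_until_ready(fn(*args))
    first_ms = (time.perf_counter() - start) * 1000

    # Later calls reuse the compiled code; keep the fastest run.
    best_ms = float("inf")
    for _ in range(repeats):
        start = time.perf_counter()
        jax.block_until_ready(fn(*args))
        best_ms = min(best_ms, (time.perf_counter() - start) * 1000)
    return first_ms, best_ms


fn = jax.jit(lambda y: jnp.cumsum(jnp.sin(y)))
first_ms, best_ms = benchmark(fn, jnp.linspace(0.0, 1.0, 100_000))
print(f"First call: {first_ms:.2f} ms, steady state: {best_ms:.3f} ms")
```

`jax.block_until_ready` waits on JAX arrays (including arrays nested in lists, tuples, and dicts). If a transform returns a dataset object, the wait may need to be applied to the underlying array instead.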
JAX Profiling
Detailed performance profiling:
import jax
# Profile compilation and execution
with jax.profiler.trace("/tmp/jax-trace"):
    result = transform.apply_to(dataset)
# View the trace in TensorBoard (requires the profiler plugin):
# tensorboard --logdir /tmp/jax-trace
Memory Profiling
Track memory usage:
import psutil
import os
process = psutil.Process(os.getpid())
# Before
mem_before = process.memory_info().rss / 1024 / 1024 # MB
print(f"Memory before: {mem_before:.1f} MB")
# Operation
result = transform.apply_to(large_dataset)
# After
mem_after = process.memory_info().rss / 1024 / 1024
print(f"Memory after: {mem_after:.1f} MB")
print(f"Memory delta: {mem_after - mem_before:.1f} MB")
GPU memory (NVIDIA):
# Command line
nvidia-smi
# Python
import subprocess
result = subprocess.run(['nvidia-smi'], capture_output=True, text=True)
print(result.stdout)
Bottleneck Analysis
Identify slow operations:
import cProfile
import pstats
# Profile code
profiler = cProfile.Profile()
profiler.enable()
# Your code here
for dataset in datasets:
    result = transform.apply_to(dataset)
profiler.disable()
# View results
stats = pstats.Stats(profiler)
stats.sort_stats('cumulative')
stats.print_stats(20) # Top 20 slowest functions
Common Performance Patterns
Pattern 1: Precompute and Reuse
# Bad: Recomputes every time
for dataset in datasets:
    smoother = GaussianSmoothing(sigma=2.0)
    result = smoother.apply_to(dataset)
# Good: Compile once, use many times
smoother = GaussianSmoothing(sigma=2.0)
# First call compiles
results = [smoother.apply_to(ds) for ds in datasets]
Speedup: 10-50x for loops over many datasets.
Pattern 2: In-Place Operations
# Memory-intensive: Creates copies
result1 = transform1.apply_to(dataset, make_copy=True)
result2 = transform2.apply_to(result1, make_copy=True)
result3 = transform3.apply_to(result2, make_copy=True)
# Memory-efficient: In-place
result = dataset.copy() # Single copy at start
transform1.apply_to(result, make_copy=False)
transform2.apply_to(result, make_copy=False)
transform3.apply_to(result, make_copy=False)
Memory savings: 3x less memory usage.
Pattern 3: Lazy Evaluation
Defer computation until results are needed:
# Eager: Computes immediately
smoothed = smoother.apply_to(dataset)
normalized = normalizer.apply_to(smoothed)
# Use results later...
# Lazy: Compute only when needed
pipeline = Pipeline([smoother, normalizer])
# No computation yet
# Compute on demand
result = pipeline.apply_to(dataset)
Pattern 4: Array Reuse
from piblin_jax.backend import jnp
# Bad: Creates new arrays
for i in range(1000):
    temp = jnp.zeros(10000)
    result = compute_something(temp)
# Good: Reuse arrays
temp = jnp.zeros(10000)
for i in range(1000):
    result = compute_something(temp)
Benchmarks
Typical Performance
Performance vs reference implementation (NumPy-only):
| Operation | NumPy Baseline | JAX (CPU) | JAX (GPU) |
|---|---|---|---|
| Dataset creation | 1.0x | 1.2x | 1.1x |
| Gaussian smoothing | 1.0x | 5-10x | 40-100x |
| Derivative | 1.0x | 3-5x | 15-30x |
| FFT (large) | 1.0x | 2-4x | 20-50x |
| Transform pipeline | 1.0x | 5-10x | 50-100x |
| Bayesian fitting | 1.0x (scipy) | 10-15x | 90-150x |
| MCMC sampling (2000 samples) | 60s | 5s | 0.7s |
Hardware used:
- CPU: AMD Ryzen 9 5950X (16 cores)
- GPU: NVIDIA RTX 3090 (24GB)
Real-World Examples
Example 1: Smoothing large dataset
import numpy as np
import time
from piblin_jax.data.datasets import OneDimensionalDataset
from piblin_jax.transform.dataset import GaussianSmoothing
# Large dataset
x = np.linspace(0, 100, 100000)
y = np.sin(x) + np.random.normal(0, 0.1, size=len(x))
dataset = OneDimensionalDataset(
    independent_variable_data=x,
    dependent_variable_data=y
)
# Benchmark
smoother = GaussianSmoothing(sigma=5.0)
# Warmup (compilation)
_ = smoother.apply_to(dataset)
# Measure
start = time.time()
result = smoother.apply_to(dataset)
elapsed = time.time() - start
print(f"Smoothing 100k points: {elapsed*1000:.2f}ms")
Results:
NumPy backend: 125ms
JAX (CPU): 18ms (7x faster)
JAX (GPU): 1.2ms (104x faster)
Example 2: Bayesian power-law fitting
import time
import numpy as np
from piblin_jax.bayesian.models import PowerLawModel
# Generate data
shear_rate = np.logspace(-1, 2, 30)
viscosity = 5.0 * shear_rate ** (-0.4) + np.random.normal(0, 0.5, size=30)
# Benchmark
model = PowerLawModel(n_samples=2000, n_warmup=1000)
start = time.time()
model.fit(shear_rate, viscosity)
elapsed = time.time() - start
print(f"Bayesian fitting (2000 samples): {elapsed:.2f}s")
Results:
NumPy backend (scipy): Not available
JAX (CPU): 5.2s
JAX (GPU): 0.7s (7x faster than CPU)
Optimization Checklist
Before Optimizing
Profile first: Identify actual bottlenecks
Measure baseline: Know current performance
Set targets: Define acceptable performance
Start simple: Basic optimizations often sufficient
During Optimization
Use JAX backend: Install with pip install jax
Reuse transforms: Avoid repeated compilation
Use pipelines: Combine multiple transforms
Batch processing: Process multiple datasets together
In-place operations: Use make_copy=False carefully
GPU acceleration: For large datasets (>10k points)
After Optimization
Verify correctness: Ensure results unchanged
Measure improvement: Compare to baseline
Document: Note optimizations for maintenance
Monitor: Check performance over time
Performance Anti-Patterns
Anti-Pattern 1: Premature GPU usage
# Bad: GPU overhead > computation for small data
tiny_dataset = OneDimensionalDataset(x[:100], y[:100])
result = gpu_transform.apply_to(tiny_dataset)
# Good: Use CPU for small data
result = cpu_transform.apply_to(tiny_dataset)
Anti-Pattern 2: Repeated compilation
# Bad: New transform every iteration
for sigma in [1.0, 2.0, 3.0, 4.0, 5.0]:
    result = GaussianSmoothing(sigma=sigma).apply_to(dataset)
# Good: Parameterize properly
# (Note: GaussianSmoothing's sigma is a compile-time parameter,
# so recompilation is unavoidable here. For custom transforms, pass values as runtime parameters.)
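The difference between a compile-time and a runtime parameter can be shown with plain JAX. In this sketch (illustrative functions, not the GaussianSmoothing implementation), the same smoothing kernel is jitted twice: once with `sigma` marked static, so every new value triggers a new trace, and once with `sigma` as an ordinary traced argument. The kernel width is fixed at 21 points so that array shapes do not depend on `sigma`.

```python
from functools import partial

import jax
import jax.numpy as jnp

traces = {"static": 0, "runtime": 0}  # counts how often each version is traced


def gaussian_smooth(y, sigma):
    # Fixed-width kernel: shapes do not depend on sigma.
    x = jnp.arange(-10, 11)
    kernel = jnp.exp(-0.5 * (x / sigma) ** 2)
    return jnp.convolve(y, kernel / kernel.sum(), mode="same")


@partial(jax.jit, static_argnames="sigma")
def smooth_static(y, sigma):
    traces["static"] += 1  # runs only during tracing
    return gaussian_smooth(y, sigma)


@jax.jit
def smooth_runtime(y, sigma):
    traces["runtime"] += 1  # runs only during tracing
    return gaussian_smooth(y, sigma)


y = jnp.sin(jnp.linspace(0.0, 6.28, 1000))
for sigma in [1.0, 2.0, 3.0, 4.0, 5.0]:
    smooth_static(y, sigma)
    smooth_runtime(y, sigma)

print(traces)  # {'static': 5, 'runtime': 1}
```

A runtime parameter works here because `sigma` only affects values, not shapes or control flow; parameters that determine array sizes generally have to stay static.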
Anti-Pattern 3: Unnecessary copying
# Bad: Excessive copying
data1 = dataset.copy()
data2 = data1.copy()
data3 = data2.copy()
# Good: Single copy when needed
working_data = dataset.copy()
# Modify working_data in-place
Anti-Pattern 4: Mixed precision without intent
# Bad: Implicit float64 (slower, more memory)
data = np.array([1.0, 2.0, 3.0]) # float64 by default for float literals
# Good: Explicit float32 (faster, less memory)
data = np.array([1.0, 2.0, 3.0], dtype=np.float32)
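The dtype NumPy infers depends on the literals: integer literals give an integer array, float literals give float64. A quick check with NumPy alone makes the memory difference concrete.

```python
import numpy as np

ints = np.array([1, 2, 3])                              # integer literals
floats64 = np.array([1.0, 2.0, 3.0])                    # float literals
floats32 = np.array([1.0, 2.0, 3.0], dtype=np.float32)  # explicit dtype

print(ints.dtype)                       # platform integer type
print(floats64.dtype, floats64.nbytes)  # float64 24
print(floats32.dtype, floats32.nbytes)  # float32 12
```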
Troubleshooting
Slow Performance
Symptom: Operations slower than expected
Check:
JAX backend installed? pip install jax
GPU detected? import jax; print(jax.devices())
First call compilation? Time second call
Data size appropriate? GPU helps with >10k points
Profiled? Use %timeit to find bottleneck
Memory Issues
Symptom: Out of memory errors
Solutions:
Use make_copy=False for in-place operations
Process in batches
Use float32 instead of float64
Clear JAX cache: jax.clear_caches()
Reduce batch size for GPU
GPU Not Used
Symptom: GPU available but not used
Check:
JAX installed with GPU support?
CUDA/ROCm installed?
jax.devices('gpu') returns devices?
Data moved to GPU? (Automatic in piblin-jax)
Compilation Warnings
Symptom: Warnings during first call
Usually safe to ignore:
“Slow compilation” - Expected on first call
“Large constant” - JAX inlining arrays
“Tracing” - Normal JAX behavior
Action needed:
“Shape mismatch” - Check array shapes
“Type error” - Check data types
Further Reading
See Basic Workflow Tutorial for practical examples
See Core Concepts for architecture details
See JAX documentation: https://jax.readthedocs.io
See NumPyro documentation: https://num.pyro.ai
Hardware Recommendations
For Best Performance
CPU:
- Multi-core processor (8+ cores recommended)
- Large cache (L3 cache > 16MB)
- Modern architecture (2020 or newer)
GPU:
- NVIDIA: RTX 3060 or better (12GB+ VRAM)
- AMD: RX 6800 or better
- Apple Silicon: M1 Pro/Max or M2/M3
Memory:
- 16GB+ RAM for typical datasets
- 32GB+ for large datasets (>1M points)
- GPU VRAM: 8GB minimum, 16GB+ recommended
Storage:
- SSD for data loading
- NVMe for best performance with large files
Cost-Benefit Analysis
CPU-only setup:
- Cost: $0 extra
- Performance: Good for small datasets
- Use case: Exploratory analysis, small datasets
Consumer GPU (RTX 4070):
- Cost: ~$500
- Performance: 10-50x faster than CPU
- Use case: Regular batch processing, medium datasets
Professional GPU (RTX 4090):
- Cost: ~$1600
- Performance: 50-150x faster than CPU
- Use case: Production, large-scale analysis, research
Cloud GPU (Google Colab, AWS):
- Cost: $0.50-$5/hour
- Performance: Varies by instance
- Use case: Occasional heavy computation