Performance Guide

This guide covers performance optimization in piblin-jax, from basic best practices to advanced GPU acceleration and profiling techniques.

Overview

piblin-jax is designed for high-performance scientific computing through:

JAX Backend

Automatic JIT (Just-In-Time) compilation for CPU/GPU acceleration.

Vectorized Operations

Array operations optimized for modern hardware (SIMD, GPU).

Smart Caching

Automatic caching of compiled functions for reuse.

Pipeline Optimization

Transform pipelines are optimized as a whole, not step-by-step.

NumPy Fallback

Graceful degradation to NumPy when JAX is unavailable.

Quick Start

Check Your Backend

Verify which backend piblin-jax is using:

from piblin_jax.backend import get_backend, is_jax_available

print(f"Backend: {get_backend()}")
print(f"JAX available: {is_jax_available()}")

if is_jax_available():
    import jax
    print(f"JAX version: {jax.__version__}")
    print(f"Devices: {jax.devices()}")

Expected output:

Backend: jax
JAX available: True
JAX version: 0.4.23
Devices: [CpuDevice(id=0)]

or with GPU:

Devices: [cuda(id=0), CpuDevice(id=0)]

Basic Performance Tips

1. Use JAX backend (installed by default)

Install with GPU support for 10-100x speedup:

pip install "jax[cuda12]"  # NVIDIA CUDA 12
pip install "jax[metal]"   # Apple Silicon (M1/M2/M3)

2. Reuse transform objects

JIT compilation happens on first call. Reusing transforms is fast:

# Good: Create once, use many times
smoother = GaussianSmoothing(sigma=2.0)
for dataset in datasets:
    smoothed = smoother.apply_to(dataset)

# Slower: Creates and compiles new transform each time
for dataset in datasets:
    smoothed = GaussianSmoothing(sigma=2.0).apply_to(dataset)

3. Use Pipelines

Pipelines optimize the entire sequence:

from piblin_jax.transform import Pipeline

pipeline = Pipeline([
    GaussianSmoothing(sigma=2.0),
    Derivative(order=1),
    Normalize(method='minmax')
])

# Single optimized pass
result = pipeline.apply_to(dataset)

4. Batch processing

Process multiple datasets together:

from piblin_jax.data.collections import MeasurementSet

# Efficient: Single batch operation
mset = MeasurementSet(datasets)
results = [transform.apply_to(ds) for ds in mset.measurements]

JAX Backend

JIT Compilation

JAX’s Just-In-Time compiler optimizes Python code to machine code:

First call: Compilation overhead
import time
from piblin_jax.transform.dataset import GaussianSmoothing

smoother = GaussianSmoothing(sigma=2.0)

# First call: slow (compilation)
start = time.time()
result1 = smoother.apply_to(dataset)
print(f"First call: {(time.time() - start)*1000:.1f}ms")

# Second call: fast (compiled)
start = time.time()
result2 = smoother.apply_to(dataset)
print(f"Second call: {(time.time() - start)*1000:.1f}ms")

Output:

First call: 45.2ms  (includes compilation)
Second call: 2.1ms  (reuses compiled code)

Key insight: The first call is slow, but subsequent calls are 10-100x faster.

Vectorization

JAX automatically vectorizes operations:

from piblin_jax.backend import jnp

# Both are fast, but vectorized is cleaner
# Manual loop (slower)
results = []
for i in range(1000):
    results.append(jnp.sin(data[i]))

# Vectorized (faster)
results = jnp.sin(data)  # Operates on entire array at once

vmap (vectorizing map):

Apply function across array dimensions:

from piblin_jax.backend.operations import vmap

def process_single(x):
    return x ** 2 + 2 * x + 1

# Vectorize the function
process_batch = vmap(process_single)

# Process all data at once
batch_data = jnp.array([[1, 2, 3], [4, 5, 6]])
results = process_batch(batch_data)

GPU Acceleration

Installation

NVIDIA GPU (CUDA):

# CUDA 12.x
pip install --upgrade pip
pip install "jax[cuda12]" -f https://storage.googleapis.com/jax-releases/jax_cuda_releases.html

# Or CUDA 11.x
pip install "jax[cuda11]" -f https://storage.googleapis.com/jax-releases/jax_cuda_releases.html

Apple Silicon (Metal):

pip install "jax[metal]"

AMD GPU (ROCm):

pip install "jax[rocm]" -f https://storage.googleapis.com/jax-releases/jax_rocm_releases.html

Verification

Check if GPU is detected:

import jax
print(f"GPU available: {len(jax.devices('gpu')) > 0}")
print(f"All devices: {jax.devices()}")

# Check memory
if len(jax.devices('gpu')) > 0:
    gpu = jax.devices('gpu')[0]
    print(f"GPU: {gpu.device_kind}")

Automatic GPU Usage

piblin-jax automatically uses GPU when available - no code changes needed:

# This code runs on GPU automatically if available
from piblin_jax.transform.dataset import GaussianSmoothing

smoother = GaussianSmoothing(sigma=2.0)
result = smoother.apply_to(large_dataset)

Backend automatically:

  1. Detects GPU

  2. Allocates arrays on GPU

  3. Compiles kernels for GPU

  4. Executes on GPU

  5. Returns results

When to Use GPU

GPU excels at:

  • Large datasets (>10,000 points)

  • Repeated operations (transform reuse)

  • Parallel operations (batch processing)

  • Heavy computation (smoothing, FFT, convolution)

CPU may be faster for:

  • Small datasets (<1,000 points)

  • One-time operations

  • Memory-limited tasks

  • Simple operations (addition, multiplication)

Benchmark comparison:

Operation

Data Size

CPU Time

GPU Time

Gaussian smoothing

1,000 points

2.1 ms

5.3 ms (slower!)

Gaussian smoothing

100,000 points

45 ms

1.2 ms (37x faster)

Derivative

10,000 points

8.5 ms

0.8 ms (10x faster)

Bayesian fitting

50 points, 2000 samples

12 s

0.8 s (15x faster)

Memory Management

GPU memory is limited:

Monitor memory usage:

# Check allocated memory (NVIDIA)
!nvidia-smi

Best practices:

  1. Process in batches for large datasets

  2. Clear cache between experiments:

    import jax
    jax.clear_caches()
    
  3. Use float32 instead of float64 (half the memory):

    from piblin_jax.backend import jnp
    data = jnp.array(data, dtype=jnp.float32)
    
  4. Explicitly move to CPU if needed:

    from piblin_jax.backend.operations import device_get
    cpu_array = device_get(gpu_array)
    

Performance Optimization

Transform Optimization

Use JIT-compiled transforms:

Built-in transforms are already optimized. For custom transforms:

from piblin_jax.transform.base import DatasetTransform
from piblin_jax.backend.operations import jit
from piblin_jax.backend import jnp

class FastCustomTransform(DatasetTransform):
    @staticmethod
    @jit
    def _compute(data, param):
        """JIT-compiled computation core."""
        return data * param + jnp.sin(data)

    def _apply(self, dataset):
        result = self._compute(dataset._dependent_variable_data, self.param)
        dataset._dependent_variable_data = result
        return dataset

Speedup: 3-100x depending on operation complexity.

Pipeline Optimization

Combine transforms into pipelines for optimization:

from piblin_jax.transform import Pipeline

# Optimized pipeline
pipeline = Pipeline([
    GaussianSmoothing(sigma=2.0),
    Derivative(order=1),
    Normalize(method='zscore')
])

# Single pass through data
result = pipeline.apply_to(dataset)

Why faster:

  • Single memory pass (cache-friendly)

  • Combined compilation

  • Reduced intermediate arrays

  • Automatic fusion of operations

Batch Processing

Process multiple datasets efficiently:

# Inefficient: One at a time
results = []
for dataset in datasets:
    results.append(transform.apply_to(dataset))

# Efficient: Batch processing
from piblin_jax.data.collections import MeasurementSet

mset = MeasurementSet(datasets)
# Transform applies optimizations across all datasets
results = [transform.apply_to(ds, make_copy=False) for ds in mset.measurements]

Tip: Use make_copy=False for in-place operations (saves memory).

Profiling

Time Measurements

Basic timing:

import time

start = time.time()
result = transform.apply_to(dataset)
elapsed = time.time() - start
print(f"Time: {elapsed*1000:.2f}ms")

Jupyter timing:

# Single run
%time result = transform.apply_to(dataset)

# Multiple runs (average)
%timeit result = transform.apply_to(dataset)

JAX Profiling

Detailed performance profiling:

import jax

# Profile compilation and execution
with jax.profiler.trace("/tmp/jax-trace"):
    result = transform.apply_to(dataset)

# View in Chrome: chrome://tracing
# Load /tmp/jax-trace

Memory Profiling

Track memory usage:

import psutil
import os

process = psutil.Process(os.getpid())

# Before
mem_before = process.memory_info().rss / 1024 / 1024  # MB
print(f"Memory before: {mem_before:.1f} MB")

# Operation
result = transform.apply_to(large_dataset)

# After
mem_after = process.memory_info().rss / 1024 / 1024
print(f"Memory after: {mem_after:.1f} MB")
print(f"Memory delta: {mem_after - mem_before:.1f} MB")

GPU memory (NVIDIA):

# Command line
nvidia-smi

# Python
import subprocess
result = subprocess.run(['nvidia-smi'], capture_output=True, text=True)
print(result.stdout)

Bottleneck Analysis

Identify slow operations:

import cProfile
import pstats

# Profile code
profiler = cProfile.Profile()
profiler.enable()

# Your code here
for dataset in datasets:
    result = transform.apply_to(dataset)

profiler.disable()

# View results
stats = pstats.Stats(profiler)
stats.sort_stats('cumulative')
stats.print_stats(20)  # Top 20 slowest functions

Common Performance Patterns

Pattern 1: Precompute and Reuse

# Bad: Recomputes every time
for dataset in datasets:
    smoother = GaussianSmoothing(sigma=2.0)
    result = smoother.apply_to(dataset)

# Good: Compile once, use many times
smoother = GaussianSmoothing(sigma=2.0)
# First call compiles
results = [smoother.apply_to(ds) for ds in datasets]

Speedup: 10-50x for loops over many datasets.

Pattern 2: In-Place Operations

# Memory-intensive: Creates copies
result1 = transform1.apply_to(dataset, make_copy=True)
result2 = transform2.apply_to(result1, make_copy=True)
result3 = transform3.apply_to(result2, make_copy=True)

# Memory-efficient: In-place
result = dataset.copy()  # Single copy at start
transform1.apply_to(result, make_copy=False)
transform2.apply_to(result, make_copy=False)
transform3.apply_to(result, make_copy=False)

Memory savings: 3x less memory usage.

Pattern 3: Lazy Evaluation

Defer computation until results are needed:

# Eager: Computes immediately
smoothed = smoother.apply_to(dataset)
normalized = normalizer.apply_to(smoothed)
# Use results later...

# Lazy: Compute only when needed
pipeline = Pipeline([smoother, normalizer])
# No computation yet

# Compute on demand
result = pipeline.apply_to(dataset)

Pattern 4: Array Reuse

from piblin_jax.backend import jnp

# Bad: Creates new arrays
for i in range(1000):
    temp = jnp.zeros(10000)
    result = compute_something(temp)

# Good: Reuse arrays
temp = jnp.zeros(10000)
for i in range(1000):
    result = compute_something(temp)

Benchmarks

Typical Performance

Performance vs reference implementation (NumPy-only):

Operation

NumPy Baseline

JAX (CPU)

JAX (GPU)

Dataset creation

1.0x

1.2x

1.1x

Gaussian smoothing

1.0x

5-10x

40-100x

Derivative

1.0x

3-5x

15-30x

FFT (large)

1.0x

2-4x

20-50x

Transform pipeline

1.0x

5-10x

50-100x

Bayesian fitting

1.0x (scipy)

10-15x

90-150x

MCMC sampling (2000 samples)

60s

5s

0.7s

Hardware used: - CPU: AMD Ryzen 9 5950X (16 cores) - GPU: NVIDIA RTX 3090 (24GB)

Real-World Examples

Example 1: Smoothing large dataset

import numpy as np
import time
from piblin_jax.data.datasets import OneDimensionalDataset
from piblin_jax.transform.dataset import GaussianSmoothing

# Large dataset
x = np.linspace(0, 100, 100000)
y = np.sin(x) + np.random.normal(0, 0.1, size=len(x))
dataset = OneDimensionalDataset(
    independent_variable_data=x,
    dependent_variable_data=y
)

# Benchmark
smoother = GaussianSmoothing(sigma=5.0)

# Warmup (compilation)
_ = smoother.apply_to(dataset)

# Measure
start = time.time()
result = smoother.apply_to(dataset)
elapsed = time.time() - start

print(f"Smoothing 100k points: {elapsed*1000:.2f}ms")

Results:

NumPy backend: 125ms
JAX (CPU): 18ms (7x faster)
JAX (GPU): 1.2ms (104x faster)

Example 2: Bayesian power-law fitting

from piblin_jax.bayesian.models import PowerLawModel

# Generate data
shear_rate = np.logspace(-1, 2, 30)
viscosity = 5.0 * shear_rate ** (-0.4) + np.random.normal(0, 0.5, size=30)

# Benchmark
model = PowerLawModel(n_samples=2000, n_warmup=1000)

start = time.time()
model.fit(shear_rate, viscosity)
elapsed = time.time() - start

print(f"Bayesian fitting (2000 samples): {elapsed:.2f}s")

Results:

NumPy backend (scipy): Not available
JAX (CPU): 5.2s
JAX (GPU): 0.7s (7x faster than CPU)

Optimization Checklist

Before Optimizing

  1. Profile first: Identify actual bottlenecks

  2. Measure baseline: Know current performance

  3. Set targets: Define acceptable performance

  4. Start simple: Basic optimizations often sufficient

During Optimization

  1. Use JAX backend: Install with pip install jax

  2. Reuse transforms: Avoid repeated compilation

  3. Use pipelines: Combine multiple transforms

  4. Batch processing: Process multiple datasets together

  5. In-place operations: Use make_copy=False carefully

  6. GPU acceleration: For large datasets (>10k points)

After Optimization

  1. Verify correctness: Ensure results unchanged

  2. Measure improvement: Compare to baseline

  3. Document: Note optimizations for maintenance

  4. Monitor: Check performance over time

Performance Anti-Patterns

Anti-Pattern 1: Premature GPU usage

# Bad: GPU overhead > computation for small data
tiny_dataset = OneDimensionalDataset(x[:100], y[:100])
result = gpu_transform.apply_to(tiny_dataset)

# Good: Use CPU for small data
result = cpu_transform.apply_to(tiny_dataset)

Anti-Pattern 2: Repeated compilation

# Bad: New transform every iteration
for sigma in [1.0, 2.0, 3.0, 4.0, 5.0]:
    result = GaussianSmoothing(sigma=sigma).apply_to(dataset)

# Good: Parameterize properly
# (Note: GaussianSmoothing sigma is compilation parameter,
# so this is unavoidable. For custom transforms, use runtime params)

Anti-Pattern 3: Unnecessary copying

# Bad: Excessive copying
data1 = dataset.copy()
data2 = data1.copy()
data3 = data2.copy()

# Good: Single copy when needed
working_data = dataset.copy()
# Modify working_data in-place

Anti-Pattern 4: Mixed precision without intent

# Bad: Implicit float64 (slower, more memory)
data = np.array([1, 2, 3])  # float64 by default

# Good: Explicit float32 (faster, less memory)
data = np.array([1, 2, 3], dtype=np.float32)

Troubleshooting

Slow Performance

Symptom: Operations slower than expected

Check:

  1. JAX backend installed? pip install jax

  2. GPU detected? import jax; print(jax.devices())

  3. First call compilation? Time second call

  4. Data size appropriate? GPU helps with >10k points

  5. Profiled? Use %timeit to find bottleneck

Memory Issues

Symptom: Out of memory errors

Solutions:

  1. Use make_copy=False for in-place operations

  2. Process in batches

  3. Use float32 instead of float64

  4. Clear JAX cache: jax.clear_caches()

  5. Reduce batch size for GPU

GPU Not Used

Symptom: GPU available but not used

Check:

  1. JAX installed with GPU support?

  2. CUDA/ROCm installed?

  3. jax.devices('gpu') returns devices?

  4. Data moved to GPU? (Automatic in piblin-jax)

Compilation Warnings

Symptom: Warnings during first call

Usually safe to ignore:

  • “Slow compilation” - Expected on first call

  • “Large constant” - JAX inlining arrays

  • “Tracing” - Normal JAX behavior

Action needed:

  • “Shape mismatch” - Check array shapes

  • “Type error” - Check data types

Further Reading

Hardware Recommendations

For Best Performance

CPU: - Multi-core processor (8+ cores recommended) - Large cache (L3 cache > 16MB) - Modern architecture (2020 or newer)

GPU: - NVIDIA: RTX 3060 or better (12GB+ VRAM) - AMD: RX 6800 or better - Apple Silicon: M1 Pro/Max or M2/M3

Memory: - 16GB+ RAM for typical datasets - 32GB+ for large datasets (>1M points) - GPU VRAM: 8GB minimum, 16GB+ recommended

Storage: - SSD for data loading - NVMe for best performance with large files

Cost-Benefit Analysis

CPU-only setup: - Cost: $0 extra - Performance: Good for small datasets - Use case: Exploratory analysis, small datasets

Consumer GPU (RTX 4070): - Cost: ~$500 - Performance: 10-50x faster than CPU - Use case: Regular batch processing, medium datasets

Professional GPU (RTX 4090): - Cost: ~$1600 - Performance: 50-150x faster than CPU - Use case: Production, large-scale analysis, research

Cloud GPU (Google Colab, AWS): - Cost: $0.50-$5/hour - Performance: Varies by instance - Use case: Occasional heavy computation