GPU Acceleration Best Practices

This tutorial covers best practices for leveraging GPU acceleration in piblin-jax using JAX’s automatic device placement and JIT compilation.

Prerequisites

  • Basic Workflow Tutorial - Basic piblin-jax usage

  • Linux with NVIDIA GPU + CUDA 12+ (GPU acceleration is Linux-only)

  • JAX with GPU support installed (see Installation)

  • Basic understanding of GPU computing concepts

Warning

GPU Support Platform Constraints:

  • Linux + NVIDIA GPU + CUDA 12+: Full GPU acceleration (up to 50-100x speedup on large workloads)

  • macOS: CPU-only (no NVIDIA GPU support)

  • Windows: CPU-only (native CUDA support in JAX is experimental and unstable)

Deprecated backends removed: Apple Metal and AMD ROCm are no longer supported.

Installation for GPU Support

If you haven’t already installed GPU support, see the detailed installation guide in Installation. Quick summary:

Recommended (from repository):

git clone https://github.com/piblin/piblin-jax.git
cd piblin-jax
make install-gpu-cuda

Manual installation:

pip uninstall -y jax jaxlib
pip install "jax[cuda12-local]>=0.8.0,<0.9.0"
pip install piblin-jax

Overview

piblin-jax leverages JAX’s automatic GPU acceleration to deliver large speedups (roughly 5-100x in the benchmarks at the end of this tutorial) for large datasets and compute-intensive operations. This tutorial shows you how to maximize GPU performance in your workflows.

Key benefits:

  • Automatic device placement - No manual GPU management

  • JIT compilation - Functions compiled once, executed fast

  • Batch processing - Efficient parallel computation

  • Transparent fallback - Works on CPU when GPU unavailable
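You can observe automatic device placement directly in JAX. The following is a minimal sketch that uses only public `jax` calls (`jax.default_backend`, `Array.devices`, `jax.device_put`) and does not depend on piblin-jax. It shows where new arrays land and how to pin one to the CPU explicitly when you need to:

```python
import jax
import jax.numpy as jnp

# New arrays are placed on JAX's default device (the first GPU when one is usable)
x = jnp.arange(5.0)
print(jax.default_backend())  # 'gpu' on a working CUDA setup, otherwise 'cpu'
print(x.devices())            # set containing the default device

# Explicit placement is still possible, e.g. pinning an array to the CPU
x_cpu = jax.device_put(x, jax.devices("cpu")[0])
print(x_cpu.devices())
```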

Checking GPU Availability

First, verify GPU access:

from piblin_jax.backend import get_backend, get_device_info, is_jax_available

# Check backend
backend = get_backend()
print(f"Backend: {backend}")  # 'jax' or 'numpy'

# Check device info
if is_jax_available():
    info = get_device_info()
    print(f"Device Type: {info['device_type']}")  # 'cpu', 'gpu', or 'tpu'
    print(f"Platform: {info['os_platform']}")     # 'linux', 'macos', 'windows'
    print(f"GPU Supported: {info['gpu_supported']}")  # True/False
    print(f"Devices: {info['devices']}")

    if info['device_type'] == 'gpu':
        print("✓ GPU acceleration available!")
        print(f"CUDA version: {info['cuda_version']}")
    else:
        print("⚠ No GPU detected, using CPU")

        if not info['gpu_supported']:
            print(f"  Reason: GPU not supported on {info['os_platform']}")
else:
    print("⚠ JAX not installed, using NumPy backend")

Expected output with GPU:

Backend: jax
Device Type: gpu
Platform: linux
GPU Supported: True
Devices: ['CudaDevice(id=0)']
✓ GPU acceleration available!
CUDA version: (12, 3)

Expected output on macOS/Windows:

Backend: jax
Device Type: cpu
Platform: macos
GPU Supported: False
Devices: ['CpuDevice(id=0)']
⚠ No GPU detected, using CPU
  Reason: GPU not supported on macos
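To cross-check what `get_device_info()` reports, you can query JAX directly. This is a minimal sketch that uses only public JAX functions. `gpu_devices` is a hypothetical helper name. It relies on `jax.devices("gpu")` raising `RuntimeError` when no GPU backend is present:

```python
import jax

def gpu_devices():
    """Return JAX's GPU devices, or an empty list if no GPU backend initialized."""
    try:
        return jax.devices("gpu")
    except RuntimeError:
        # Raised for CPU-only jaxlib builds, JAX_PLATFORMS=cpu, or a failed CUDA init
        return []

print(f"Default backend: {jax.default_backend()}")
print(f"All devices:     {jax.devices()}")
print(f"GPU devices:     {gpu_devices()}")
```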

Understanding Performance Characteristics

CPU vs GPU Trade-offs

GPUs and CPUs suit different workloads:

GPU Advantages:

  • Parallel operations on large arrays (>10,000 elements)

  • Matrix operations (transforms, smoothing)

  • Repeated operations (MCMC sampling, batch processing)

  • Vectorized computations

CPU Advantages:

  • Small datasets (<1,000 elements)

  • Sequential operations

  • Complex control flow

  • Single operations (no repetition)

Rule of Thumb:

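def prefer_gpu(dataset_size: int, repeated_operations: bool) -> bool:
    """Rule of thumb: the GPU pays off for large or repeated work."""
    # Below roughly 10,000 elements, kernel-launch and host-to-device transfer
    # overhead usually outweighs the GPU's parallelism, so the CPU is fine.
    return dataset_size > 10_000 or repeated_operations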
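Treat the 10,000-element threshold as a rough guide. The actual crossover depends on your GPU, CPU, and workload. Below is a minimal sketch in plain JAX for measuring it yourself; `workload` and `seconds_per_call` are illustrative names. Run it once normally and once with the environment variable `JAX_PLATFORMS=cpu` set, then compare the timings per size:

```python
import time

import jax
import jax.numpy as jnp

@jax.jit
def workload(x):
    return jnp.sum(jnp.exp(x) * jnp.sin(x) ** 2)

def seconds_per_call(fn, x, n_repeats=20):
    """Average wall-clock time per call, excluding compilation."""
    fn(x).block_until_ready()  # warm-up: compiles for this shape and dtype
    start = time.perf_counter()
    for _ in range(n_repeats):
        fn(x).block_until_ready()  # JAX dispatches asynchronously; wait for the result
    return (time.perf_counter() - start) / n_repeats

timings = {n: seconds_per_call(workload, jnp.linspace(0.0, 1.0, n))
           for n in (1_000, 10_000, 1_000_000)}
for n, t in timings.items():
    print(f"{n:>9,} elements: {t * 1e6:10.1f} us/call")
```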

JIT Compilation

Basic JIT Usage

JIT compilation traces a function once and compiles it with XLA. Later calls with the same input shapes and dtypes reuse the compiled code:

from piblin_jax.backend.operations import jit
from piblin_jax.backend import jnp

# Decorate functions for JIT compilation
@jit
def compute_gradient(x):
    """Compute gradient with JIT compilation."""
    return jnp.gradient(x)

data = jnp.linspace(0.0, 10.0, 100_000)

# First call: traces, compiles, then executes (e.g. ~100 ms; varies by hardware)
result1 = compute_gradient(data)

# Subsequent calls with the same shape/dtype reuse the compiled code (e.g. ~1 ms)
result2 = compute_gradient(data)  # Much faster!

Performance Tips:

  1. JIT functions you’ll call repeatedly

  2. First call has compilation overhead - that’s normal

  3. Compiled functions are cached per input shape and dtype - reused automatically; a new shape triggers recompilation

  4. Works on both CPU and GPU - same code, automatic optimization
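Tip 3 has a caveat that is easiest to see in code: JAX caches compiled code per input shape and dtype, so a new shape triggers a fresh trace and compile. The following is a minimal sketch in plain JAX. It counts traces with a Python side effect, which runs only while JAX is tracing the function and not on cached calls:

```python
import jax
import jax.numpy as jnp

trace_count = 0

@jax.jit
def scaled_sum(x):
    global trace_count
    trace_count += 1  # Python side effect: executes only during tracing
    return jnp.sum(2.0 * x)

scaled_sum(jnp.ones(100))   # first call: trace + compile for shape (100,)
scaled_sum(jnp.zeros(100))  # same shape and dtype: compiled version reused
scaled_sum(jnp.ones(200))   # new shape (200,): traced and compiled again
print(trace_count)  # 2
```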

When to Use JIT

from piblin_jax.backend.operations import jit
from piblin_jax.backend import jnp

# ✓ Good candidates for JIT
@jit
def heavy_computation(x):
    """Complex mathematical operation - JIT improves performance."""
    return jnp.sum(jnp.exp(x) * jnp.sin(x) ** 2)

@jit
def matrix_operation(x):
    """Matrix ops benefit from JIT."""
    return jnp.dot(x.T, x)

# ✗ Poor candidates for JIT
def simple_operation(x):
    """Too simple - JIT overhead not worth it."""
    return x + 1

def data_dependent_control(x):
    """Python `if` on array values fails under @jit: traced values have no
    concrete True/False. Use jnp.where or jax.lax.cond instead."""
    if jnp.mean(x) > 0:  # Avoid this pattern in jitted code
        return x * 2
    else:
        return x / 2
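You can usually rewrite data-dependent branching so it works under jit. This minimal sketch in plain JAX shows two standard options, `jnp.where` and `jax.lax.cond`; the function names are illustrative. It also confirms that the Python `if` version fails under jit:

```python
import jax
import jax.numpy as jnp

@jax.jit
def scale_by_mean_sign(x):
    # jnp.where computes both candidate results and selects one; no Python `if`
    return jnp.where(jnp.mean(x) > 0, x * 2, x / 2)

@jax.jit
def scale_by_mean_sign_cond(x):
    # lax.cond chooses between two branch functions using a traced scalar predicate
    return jax.lax.cond(jnp.mean(x) > 0, lambda v: v * 2, lambda v: v / 2, x)

# The original pattern: converting a traced value to bool raises an error
failed_under_jit = False
try:
    jax.jit(lambda x: x * 2 if jnp.mean(x) > 0 else x / 2)(jnp.ones(3))
except jax.errors.ConcretizationTypeError:
    failed_under_jit = True
print(f"Python `if` failed under jit: {failed_under_jit}")
```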

Batch Processing for GPU Efficiency

Processing Multiple Datasets

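GPUs are most efficient when a single operation covers many datasets. Rather than calling the pipeline once per dataset, stack equal-length datasets into one array and process them together: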

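from piblin_jax.transform import Pipeline
from piblin_jax.transform.dataset import GaussianSmoothing
from piblin_jax.backend import jnp
from piblin_jax.backend.operations import jit, vmap

# Assumes `datasets` (a list of equal-length OneDimensionalDataset objects) and
# OneDimensionalDataset itself are available, as in the Basic Workflow Tutorial
pipeline = Pipeline([GaussianSmoothing(sigma=2.0)])

# Sequential: one pipeline call, with its own device dispatches, per dataset
results = [pipeline.apply(dataset) for dataset in datasets]

# Better: stack the y-values into one (n_datasets, n_points) array
stacked_data = jnp.stack([ds.y for ds in datasets])

SIGMA = 2.0
RADIUS = int(4 * SIGMA)  # truncate the Gaussian kernel at 4 sigma

@jit
def batch_smooth(data_batch):
    """Gaussian-smooth every row of data_batch in one compiled call.

    Hand-written kernel for illustration; its edge handling may differ
    from GaussianSmoothing's.
    """
    offsets = jnp.arange(-RADIUS, RADIUS + 1)
    kernel = jnp.exp(-0.5 * (offsets / SIGMA) ** 2)
    kernel = kernel / kernel.sum()
    # Edge-pad each row so the output keeps the input length
    padded = jnp.pad(data_batch, ((0, 0), (RADIUS, RADIUS)), mode="edge")
    return vmap(lambda row: jnp.convolve(row, kernel, mode="valid"))(padded)

smoothed_batch = batch_smooth(stacked_data)

# Unstack results
results = [OneDimensionalDataset(ds.x, y)
           for ds, y in zip(datasets, smoothed_batch)]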

Vectorization with vmap

Use vmap for automatic vectorization:

from piblin_jax.backend import jnp
from piblin_jax.backend.operations import vmap

def process_single(x):
    """Running (cumulative) mean of a single 1D array."""
    return jnp.cumsum(x) / jnp.arange(1, len(x) + 1)

# Vectorize across the leading batch dimension
process_batch = vmap(process_single)

# Process the entire batch at once (datasets must share the same length)
batch_data = jnp.stack([dataset.y for dataset in datasets])
results = process_batch(batch_data)  # Parallel on GPU!
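To see exactly what `vmap` computes, compare it with an explicit Python loop. This minimal sketch in plain JAX uses `running_mean`, which mirrors `process_single` above, and a hypothetical extension, `scaled_running_mean`. The extension shows `in_axes`, which lets one argument be shared across the whole batch instead of mapped:

```python
import jax
import jax.numpy as jnp

def running_mean(x):
    """Cumulative mean of a 1D array."""
    return jnp.cumsum(x) / jnp.arange(1, x.shape[0] + 1)

def scaled_running_mean(x, scale):
    return scale * running_mean(x)

batch = jnp.arange(12.0).reshape(3, 4)  # 3 datasets of 4 points each

batched = jax.vmap(running_mean)(batch)                # one vectorized call
looped = jnp.stack([running_mean(row) for row in batch])  # row-by-row reference

# in_axes=(0, None): map over rows of `batch`, reuse the same `scale` for every row
scaled = jax.vmap(scaled_running_mean, in_axes=(0, None))(batch, 10.0)
```

The two results are identical. The difference is that `vmap` emits a single batched computation, which the GPU can run in parallel, instead of one dispatch per row.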

Memory Management

GPU Memory Constraints

GPU memory is typically much smaller than host RAM, and each transform allocates intermediate arrays on the device:

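# ✗ Bad: may exhaust GPU memory on large inputs
# (create_dataset, create_chunk and combine_chunks are placeholders for your own helpers)
huge_dataset = create_dataset(size=100_000_000)  # 100M points
result = pipeline.apply(huge_dataset)  # Possible out-of-memory error

# ✓ Good: Process in chunks
chunk_size = 1_000_000
results = []

for i in range(0, len(huge_dataset.x), chunk_size):
    chunk = create_chunk(huge_dataset, i, i + chunk_size)
    result = pipeline.apply(chunk)
    results.append(result)

# Combine results. For windowed transforms (smoothing, derivatives), points near
# chunk edges see truncated windows; overlap chunks and trim if that matters.
final_result = combine_chunks(results)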
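Here is the same idea in self-contained form. This minimal sketch assumes an elementwise transform (an illustrative `tanh` scaling). It keeps the full array in host memory as NumPy and moves only one chunk at a time to the device. `apply_in_chunks` is a hypothetical helper, not part of piblin-jax:

```python
import numpy as np
import jax
import jax.numpy as jnp

@jax.jit
def transform(chunk):
    # Elementwise, so chunk boundaries do not affect the result
    return jnp.tanh(chunk) * 2.0

def apply_in_chunks(host_data, fn, chunk_size):
    """Apply fn chunk by chunk; only one chunk lives on the device at a time."""
    outputs = []
    for start in range(0, host_data.shape[0], chunk_size):
        chunk = jax.device_put(host_data[start:start + chunk_size])  # host -> device
        # Copy the result back to host memory so the device buffers can be released
        outputs.append(np.asarray(fn(chunk)))
    return np.concatenate(outputs)

host = np.linspace(-3.0, 3.0, 10_001)
out = apply_in_chunks(host, transform, chunk_size=4_096)
```

The final, shorter chunk has a different shape, so it costs one extra compilation. Padding the last chunk to `chunk_size` avoids that if compile time matters.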

Monitoring Memory Usage

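import jax
import jax.numpy as jnp

device = jax.devices()[0]

# For CUDA GPUs
if device.platform == 'gpu':
    # By default JAX preallocates 75% of GPU memory on first use; set
    # XLA_PYTHON_CLIENT_PREALLOCATE=false before importing JAX to allocate on demand.
    stats = device.memory_stats()  # allocator statistics (may be None on some backends)
    if stats:
        print(f"In use: {stats['bytes_in_use'] / 1e9:.2f} GB "
              f"of {stats['bytes_limit'] / 1e9:.2f} GB")

# Best practice: drop references to large arrays when done
large_array = jnp.zeros((10000, 10000))
result = process(large_array)  # `process` is a placeholder for your own computation
del large_array  # the device buffer is freed once no references remain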
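You can also disable preallocation from inside a script instead of the shell. This is a minimal sketch using JAX's `XLA_PYTHON_CLIENT_*` environment variables, which only take effect if they are set before JAX initializes its backends, so the import comes after them:

```python
import os

# Allocate GPU memory on demand instead of reserving a large pool up front
os.environ["XLA_PYTHON_CLIENT_PREALLOCATE"] = "false"
# Alternative: keep preallocation but cap it, e.g. at half of GPU memory
# os.environ["XLA_PYTHON_CLIENT_MEM_FRACTION"] = ".50"

import jax  # imported after the environment is configured

print(jax.devices())
```

On-demand allocation can fragment memory and run slightly slower, so it is mainly useful when several processes share one GPU.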

Optimizing Transform Pipelines

Pipeline-Level Optimization

from piblin_jax.transform import Pipeline
from piblin_jax.transform.dataset import (
    GaussianSmoothing,
    MinMaxNormalization,
    Derivative
)

# Create pipeline
pipeline = Pipeline([
    GaussianSmoothing(sigma=2.0),  # GPU-optimized
    Derivative(order=1),           # GPU-optimized
    MinMaxNormalization()          # GPU-optimized
])

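# Warm-up: trigger JIT compilation. Use a dataset with the same length and dtype
# as the real data, since compiled code is cached per input shape.
_ = pipeline.apply(sample_dataset)

# Now process many same-shaped datasets efficiently
for dataset in large_dataset_collection:
    result = pipeline.apply(dataset)  # Fast!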
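A warm-up call compiles only for the shape and dtype it sees. If you know the shape in advance, JAX's ahead-of-time API can compile without any data. This is a minimal sketch in plain JAX. `smooth_and_normalize` is an illustrative stand-in, not how piblin-jax's `Pipeline` is implemented:

```python
import jax
import jax.numpy as jnp

def smooth_and_normalize(y):
    # Illustrative pipeline: 5-point moving average, then min-max scaling
    smoothed = jnp.convolve(y, jnp.ones(5) / 5.0, mode="same")
    return (smoothed - smoothed.min()) / (smoothed.max() - smoothed.min())

n_points = 100_000
spec = jax.ShapeDtypeStruct((n_points,), jnp.float32)

# Ahead of time: trace and compile for the expected shape/dtype, no real data needed
compiled = jax.jit(smooth_and_normalize).lower(spec).compile()

y = jnp.sin(jnp.linspace(0.0, 20.0, n_points, dtype=jnp.float32))
result = compiled(y)  # runs the precompiled executable; nothing is compiled here
```

A compiled executable accepts only inputs that match `spec`. Calling it with a different shape or dtype raises an error instead of silently recompiling.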

Custom GPU-Optimized Transforms

Create transforms that leverage GPU:

from piblin_jax.transform.base import DatasetTransform
from piblin_jax.backend.operations import jit
from piblin_jax.backend import jnp

class GPUOptimizedTransform(DatasetTransform):
    """Transform optimized for GPU execution."""

    def __init__(self, param=1.0):
        super().__init__()  # assumes the base constructor takes no required arguments
        self.param = param

    @staticmethod
    @jit  # JIT compile for GPU
    def _compute(y, param):
        """GPU-accelerated computation."""
        # JAX operations automatically use GPU
        return jnp.fft.fft(y * param).real

    def apply(self, dataset):
        """Apply transform."""
        result_y = self._compute(dataset.y, self.param)
        # OneDimensionalDataset: import it as in the Basic Workflow Tutorial
        return OneDimensionalDataset(dataset.x, result_y)
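One design choice in the class above is that `self.param` is passed into the jitted function as an argument instead of being read from `self` inside it. This minimal stand-alone sketch shows why that matters. It uses plain JAX and a hypothetical `ScaledSpectrum` class with no piblin-jax base class. Because scalar arguments are traced, new parameter values reuse the same compiled code:

```python
import jax
import jax.numpy as jnp

class ScaledSpectrum:
    """Hypothetical transform: real part of the FFT of a scaled signal."""
    n_traces = 0  # counts how often JAX traces _compute

    def __init__(self, param):
        self.param = param

    @staticmethod
    @jax.jit
    def _compute(y, param):
        ScaledSpectrum.n_traces += 1  # runs only while tracing
        return jnp.fft.fft(y * param).real

    def __call__(self, y):
        # param is a traced argument, so its value is not baked into the compiled code
        return self._compute(y, self.param)

y = jnp.ones(8)
a = ScaledSpectrum(2.0)(y)
b = ScaledSpectrum(3.0)(y)  # different param, same shapes: no recompilation
print(ScaledSpectrum.n_traces)  # 1
```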

Performance Benchmarking

Measuring GPU Speedup

import time

import jax
from piblin_jax.backend import get_device_info
from piblin_jax.transform import Pipeline
from piblin_jax.transform.dataset import GaussianSmoothing

def benchmark_pipeline(pipeline, dataset, n_iterations=10):
    """Benchmark pipeline performance."""
    # Warm-up: keeps JIT compilation out of the timing
    jax.block_until_ready(pipeline.apply(dataset).y)

    # Benchmark. JAX dispatches work asynchronously, so wait for each
    # result to be computed before reading the clock.
    start = time.perf_counter()
    for _ in range(n_iterations):
        result = pipeline.apply(dataset)
        jax.block_until_ready(result.y)
    end = time.perf_counter()

    avg_time = (end - start) / n_iterations
    device = get_device_info()['device_type']

    print(f"Device: {device}")
    print(f"Average time: {avg_time*1000:.2f} ms")
    print(f"Throughput: {len(dataset.x)/avg_time:.0f} points/second")

    return avg_time

# Compare CPU vs GPU: run this script twice, once with the environment
# variable JAX_PLATFORMS=cpu (forces CPU) and once without it (uses the GPU).
pipeline = Pipeline([GaussianSmoothing(sigma=2.0)])
dataset = create_large_dataset(100_000)  # placeholder for your own 100K-point dataset

avg_time = benchmark_pipeline(pipeline, dataset)

# With both timings recorded:
# speedup = cpu_time / gpu_time
# print(f"GPU Speedup: {speedup:.1f}x")

MCMC/Bayesian Acceleration

Bayesian models benefit enormously from GPU:

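import numpy as np
from piblin_jax.bayesian import PowerLawModel

# Illustrative synthetic data: 50 points of a power-law fluid,
# viscosity = K * shear_rate**(n - 1) with K = 10 and n = 0.6, plus 5% noise
rng = np.random.default_rng(0)
shear_rate = np.logspace(-1, 2, 50)
viscosity = 10.0 * shear_rate ** (0.6 - 1) * (1 + 0.05 * rng.standard_normal(50))

# Create model (automatically uses GPU if available)
model = PowerLawModel(
    n_samples=5000,  # More samples with GPU
    n_warmup=2000,
    n_chains=4       # Parallel chains on GPU
)

# Fit model (runs on the GPU when one is available)
model.fit(shear_rate, viscosity)

# Expected performance (illustrative; depends on hardware):
# CPU: ~60 seconds
# GPU: ~2-5 seconds (10-30x faster)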

GPU MCMC Tips:

  1. Use more samples - GPU makes large sample sizes feasible

  2. Run multiple chains - independent chains enable convergence diagnostics such as R-hat

  3. Batch predictions - Get posterior predictive for many x values at once
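Tip 3 can be illustrated in plain JAX. This minimal sketch assumes a power-law viscosity model, eta = K * shear_rate**(n - 1), with posterior draws stored as arrays `K_samples` and `n_samples`. These names are hypothetical, and how piblin-jax exposes its samples is not covered here. The sketch evaluates every draw over a whole shear-rate grid in one call:

```python
import jax
import jax.numpy as jnp

def power_law(shear_rate, K, n):
    # Power-law viscosity: eta = K * shear_rate**(n - 1)
    return K * shear_rate ** (n - 1.0)

@jax.jit
def posterior_predictive(K_samples, n_samples, shear_grid):
    # Map over posterior draws (axis 0); the grid is shared by every draw
    return jax.vmap(power_law, in_axes=(None, 0, 0))(shear_grid, K_samples, n_samples)

K_samples = jnp.array([2.0, 1.0])
n_samples = jnp.array([1.0, 2.0])
shear_grid = jnp.array([1.0, 10.0])
curves = posterior_predictive(K_samples, n_samples, shear_grid)  # shape (n_draws, n_grid)
```

Percentiles over axis 0 of `curves` (for example with `jnp.percentile`) then give credible bands for every grid point at once.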

Common Issues and Solutions

Issue: GPU Not Detected

# Symptom: JAX reports 'cpu' instead of 'gpu'

# Solution 1: Verify platform support
from piblin_jax.backend import get_device_info
info = get_device_info()
if not info['gpu_supported']:
    print(f"GPU not supported on {info['os_platform']}")
    print("GPU acceleration requires Linux + CUDA 12+")

# Solution 2: Verify JAX GPU installation
import jax
print(jax.devices())  # Should show GPU devices

# Solution 3: Check CUDA/drivers (NVIDIA)
# Run: nvidia-smi (command line)

# Solution 4: Reinstall JAX with GPU support
# pip uninstall -y jax jaxlib
# pip install "jax[cuda12-local]>=0.8.0,<0.9.0"

Issue: Out of Memory Errors

# Symptom: "Out of memory" or "XLA allocation failed"

# Solution 1: Reduce batch size
chunk_size = 10_000  # Instead of 100_000

# Solution 2: Use smaller data types. JAX defaults to float32; float64
# arrays only appear when jax_enable_x64 is enabled.
data = data.astype(jnp.float32)  # Instead of float64

# Solution 3: Clear memory between operations
del large_intermediate_array

# Solution 4: Disable preallocation
# Set the environment variable before starting Python:
# export XLA_PYTHON_CLIENT_PREALLOCATE=false
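You can check the memory saving from Solution 2 directly. This minimal sketch compares the size of the same one million values stored as float64 in NumPy and as float32 in JAX:

```python
import numpy as np
import jax.numpy as jnp

host = np.random.default_rng(0).normal(size=1_000_000)  # NumPy defaults to float64
device = jnp.asarray(host, dtype=jnp.float32)

print(host.dtype, host.nbytes)      # float64 8000000
print(device.dtype, device.nbytes)  # float32 4000000
```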

Issue: Slow First Execution

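from piblin_jax.backend import jnp
from piblin_jax.backend.operations import jit

# Symptom: First call to JIT function is very slow

# This is normal! JIT compilation happens on first call.
# Subsequent calls with the same input shape and dtype use the cached version.

# Solution: Warm up your functions
@jit
def my_function(x):
    return jnp.sum(x ** 2)

my_data = jnp.linspace(0.0, 1.0, 100_000)

# Warm-up call (compile) with an input of the same shape and dtype as the
# real data; a differently shaped input such as jnp.array([1, 2, 3])
# would compile a separate version and not help
_ = my_function(jnp.zeros_like(my_data))

# Now fast for all subsequent calls
result = my_function(my_data)  # Fast!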

Best Practices Summary

  1. Use GPU on Linux with CUDA 12+ - Only platform with full GPU support

  2. Use GPU for large datasets (>10,000 elements) and repeated operations

  3. Apply JIT to performance-critical functions - first call compiles, subsequent calls are fast

  4. Process in batches - stack datasets and process together

  5. Use vmap for vectorization - automatic parallelization

  6. Monitor memory - chunk large datasets, delete unused arrays

  7. Warm up pipelines - run once, with representatively shaped data, before benchmarking

  8. Leverage Bayesian GPU acceleration - massive speedup for MCMC

Performance Comparison Table

Expected speedups (GPU vs CPU on Linux with CUDA 12+):

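===============================  ============  ===========
Operation                        Dataset Size  GPU Speedup
===============================  ============  ===========
Gaussian smoothing               10K points    5-10x
Gaussian smoothing               100K points   20-50x
Transform pipeline (3 steps)     100K points   30-70x
Bayesian MCMC (2K samples)       50 points     10-30x
Bayesian MCMC (10K samples)      50 points     50-100x
Batch processing (100 datasets)  10K each      40-80x
===============================  ============  ===========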

Note

Performance results are from benchmarks on Linux with NVIDIA A100 GPU and CUDA 12.3. Your results may vary based on hardware and workload.

Next Steps

See also