Backend Abstraction
Overview
The piblin_jax.backend module provides a unified interface for both JAX and NumPy
array operations, enabling piblin-jax to leverage JAX’s performance features (JIT
compilation, GPU acceleration, automatic differentiation) while maintaining full
compatibility with NumPy-only environments.
This abstraction layer is crucial for several reasons:
Graceful Degradation: The module automatically detects available backends at import time and falls back to NumPy when JAX is unavailable. This ensures piblin-jax works in any Python environment without requiring JAX as a hard dependency.
Unified API: All piblin-jax code uses the jnp interface exported by this module, which points to either jax.numpy or numpy depending on availability. This enables writing backend-agnostic code that works optimally with both.
Performance Benefits: When JAX is available, piblin-jax automatically benefits from:
JIT Compilation: Functions are compiled to optimized machine code for significant speedups, especially for repeated operations
GPU/TPU Acceleration: Computations automatically utilize available accelerators
Vectorization: Advanced automatic vectorization for batch operations
Memory Efficiency: JAX’s functional approach enables better memory management
Device Management: The module provides utilities for querying available compute devices (CPU, GPU, TPU) and their properties, enabling device-aware optimizations.
Conversion Utilities: Functions for converting between JAX and NumPy arrays, as well as handling nested structures (pytrees), ensure seamless interoperability at API boundaries.
The backend abstraction is transparent to most users - you simply import and use
piblin_jax and it will automatically use the best available backend. Advanced users
can query backend status and device information for performance tuning.
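The import-time detection and fallback described above can be sketched as follows. This is a minimal illustration, assuming a plain try/except around the JAX import; it is not piblin-jax's actual source, and the helper mean_squared is a hypothetical name used only to show backend-agnostic code.

```python
# Hypothetical sketch of import-time backend detection; an assumption about
# how such a module could work, not piblin-jax's actual implementation.
try:
    import jax.numpy as jnp  # preferred backend when JAX is installed
    BACKEND = "jax"
except ImportError:
    import numpy as jnp  # NumPy exposes a largely compatible array API
    BACKEND = "numpy"


def get_backend() -> str:
    """Return the name of the backend selected at import time."""
    return BACKEND


def is_jax_available() -> bool:
    """Report whether the JAX backend was selected."""
    return BACKEND == "jax"


def mean_squared(x):
    """Backend-agnostic computation written against the shared jnp alias."""
    return jnp.mean(jnp.asarray(x) ** 2)


print(get_backend(), float(mean_squared([1.0, 2.0, 3.0, 4.0])))
```

Because the choice is made once at import, every caller sees the same jnp alias and no per-call backend check is needed.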
Quick Examples
Basic Backend Detection
Check which backend is being used:
from piblin_jax.backend import get_backend, is_jax_available
# Check backend
backend = get_backend()
print(f"Using backend: {backend}") # 'jax' or 'numpy'
# Check JAX availability
if is_jax_available():
    print("JAX available - accelerators are used when present")
else:
    print("Using NumPy fallback")
Using the Unified Array Interface
Write backend-agnostic code:
from piblin_jax.backend import jnp
# Works with both JAX and NumPy
def compute_mean_squared(x):
    return jnp.mean(x ** 2)
# This code runs on either backend
import numpy as np
data = np.array([1.0, 2.0, 3.0, 4.0])
result = compute_mean_squared(data)
Device Information and Management
Query available compute devices:
from piblin_jax.backend import get_device_info
# Get device information
info = get_device_info()
print(f"Backend: {info['backend']}")
print(f"Available devices: {info['devices']}")
print(f"Default device: {info['default_device']}")
if 'platform' in info:
    print(f"Platform: {info['platform']}")  # cpu, gpu, tpu
# Example output with JAX on GPU:
# Backend: jax
# Available devices: ['cuda:0', 'cuda:1']
# Default device: cuda:0
# Platform: gpu
Array Conversion Utilities
Convert between backends for API boundaries:
from piblin_jax.backend import jnp, to_numpy, from_numpy
# Create array with current backend
jax_array = jnp.array([1, 2, 3, 4, 5])
# Convert to NumPy (for saving, plotting, etc.)
numpy_array = to_numpy(jax_array)
print(type(numpy_array)) # <class 'numpy.ndarray'>
# Convert back to backend format
backend_array = from_numpy(numpy_array)
# Handle nested structures (pytrees)
from piblin_jax.backend import to_numpy_pytree, from_numpy_pytree
pytree = {
    'params': {'weights': jnp.array([1, 2]), 'bias': jnp.array([0.5])},
    'metrics': [jnp.array([0.95]), jnp.array([0.98])],
}
# Convert entire structure to NumPy
numpy_pytree = to_numpy_pytree(pytree)
See Also
Transformations - Transforms that leverage JAX optimization
Curve Fitting - Curve fitting with JAX acceleration
JAX Documentation - JAX fundamentals and API
JAX Device Management - Using multiple devices
API Reference
Module Contents
Backend abstraction layer for piblin-jax.
This module provides a unified interface for both JAX and NumPy backends, with automatic fallback to NumPy when JAX is unavailable.
The backend is detected at module import time and stored in the BACKEND global variable. All array operations should use the exported jnp interface which points to either jax.numpy or numpy depending on availability.
- piblin_jax.backend.from_numpy(arr)
Convert a NumPy array to a backend array.
This function converts NumPy arrays to the current backend format (a JAX array if JAX is available; otherwise the NumPy array is returned unchanged).
- Parameters:
arr (np.ndarray) – Input NumPy array.
- Returns:
Backend array (JAX array if JAX is available, else NumPy array).
- Return type:
array_like
Examples
>>> import numpy as np
>>> from piblin_jax.backend import from_numpy
>>> np_arr = np.array([1, 2, 3])
>>> backend_arr = from_numpy(np_arr)
- piblin_jax.backend.from_numpy_pytree(pytree)
Convert a pytree (nested structure) of NumPy arrays to backend arrays.
Handles nested dictionaries, lists, and tuples containing arrays.
- Parameters:
pytree (Any) – Nested structure containing NumPy arrays.
- Returns:
Same structure with all arrays converted to backend format.
- Return type:
Any
Examples
>>> import numpy as np
>>> from piblin_jax.backend import from_numpy_pytree
>>> pytree = {'a': np.array([1, 2]), 'b': [np.array([3, 4])]}
>>> backend_pytree = from_numpy_pytree(pytree)
- piblin_jax.backend.get_backend()
Get the name of the current backend.
- Returns:
Either ‘jax’ or ‘numpy’ depending on which backend is in use.
- Return type:
str
Examples
>>> from piblin_jax.backend import get_backend
>>> backend = get_backend()
>>> print(f"Using backend: {backend}")
- piblin_jax.backend.get_device_info()
Get information about available compute devices.
- Returns:
Dictionary containing:
    - ‘backend’: str, name of backend (‘jax’ or ‘numpy’)
    - ‘devices’: list, available compute devices
    - ‘default_device’: str, the default device being used
    - ‘os_platform’: str, detected OS platform (‘linux’, ‘macos’, ‘windows’)
    - ‘gpu_supported’: bool, whether GPU is supported on current platform
    - ‘cuda_version’: tuple or None, CUDA version (major, minor) if available
    - Additional JAX-specific info if JAX is available
- Return type:
dict
Examples
>>> from piblin_jax.backend import get_device_info
>>> info = get_device_info()
>>> print(f"Backend: {info['backend']}")
>>> print(f"Devices: {info['devices']}")
- piblin_jax.backend.is_jax_available()
Check if JAX backend is available.
- Returns:
True if JAX is available and being used, False if using NumPy fallback.
- Return type:
bool
Examples
>>> from piblin_jax.backend import is_jax_available
>>> if is_jax_available():
...     print("Using JAX backend with GPU acceleration")
... else:
...     print("Using NumPy fallback")
- piblin_jax.backend.to_numpy(arr)
Convert a backend array to a NumPy array.
This function handles conversion from JAX arrays to NumPy arrays, and passes through NumPy arrays unchanged. Useful for API boundaries where pure NumPy arrays are required.
- Parameters:
arr (array_like) – Input array (JAX or NumPy array, or nested structure).
- Returns:
NumPy array with the same data.
- Return type:
np.ndarray
Examples
>>> from piblin_jax.backend import jnp, to_numpy
>>> jax_arr = jnp.array([1, 2, 3])
>>> np_arr = to_numpy(jax_arr)
>>> type(np_arr)
<class 'numpy.ndarray'>
- piblin_jax.backend.to_numpy_pytree(pytree)
Convert a pytree (nested structure) of arrays to NumPy.
Handles nested dictionaries, lists, and tuples containing arrays.
- Parameters:
pytree (Any) – Nested structure containing arrays.
- Returns:
Same structure with all arrays converted to NumPy.
- Return type:
Any
Examples
>>> from piblin_jax.backend import jnp, to_numpy_pytree
>>> pytree = {'a': jnp.array([1, 2]), 'b': [jnp.array([3, 4])]}
>>> np_pytree = to_numpy_pytree(pytree)
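To make "handles nested dictionaries, lists, and tuples" concrete, here is one way such a recursive conversion could be written. It is a sketch only: the name to_numpy_pytree_sketch is hypothetical, every non-container leaf is assumed to be array-like, and named tuples are not handled. The library's own implementation may differ.

```python
import numpy as np


def to_numpy_pytree_sketch(tree):
    """Recursively convert array leaves to NumPy, preserving container types.

    Hypothetical illustration; assumes every non-container leaf is array-like.
    """
    if isinstance(tree, dict):
        return {key: to_numpy_pytree_sketch(value) for key, value in tree.items()}
    if isinstance(tree, (list, tuple)):
        # Rebuild the same container type (list stays list, tuple stays tuple).
        return type(tree)(to_numpy_pytree_sketch(value) for value in tree)
    # Leaf: np.asarray accepts NumPy arrays and objects exposing __array__.
    return np.asarray(tree)


nested = {
    "params": {"weights": np.array([1, 2]), "bias": np.array([0.5])},
    "metrics": [np.array([0.95]), (np.array([0.98]),)],
}
converted = to_numpy_pytree_sketch(nested)
```

The structure of converted mirrors nested exactly; only the leaves are touched.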
Operations
Backend-agnostic array operations.
This module provides unified interfaces for common array operations that work with both JAX and NumPy backends. It also provides JIT compilation decorators and device placement utilities that gracefully degrade to no-ops when using the NumPy backend.
- piblin_jax.backend.operations.astype(arr, dtype)
Cast array to specified dtype.
- Parameters:
arr (array_like) – Input array.
dtype (dtype) – Target data type.
- Returns:
Array cast to specified dtype.
- Return type:
array_like
Examples
>>> from piblin_jax.backend.operations import astype
>>> from piblin_jax.backend import jnp
>>> arr = jnp.array([1, 2, 3])
>>> arr_float = astype(arr, jnp.float32)
- piblin_jax.backend.operations.concatenate(arrays, axis=0)
Concatenate arrays along an existing axis.
- Parameters:
arrays (sequence of array_like) – Arrays to concatenate. All arrays must have the same shape except in the dimension corresponding to axis.
axis (int, optional) – Axis along which to concatenate. Default is 0.
- Returns:
Concatenated array.
- Return type:
array_like
Examples
>>> from piblin_jax.backend import jnp
>>> from piblin_jax.backend.operations import concatenate
>>> arr1 = jnp.array([1, 2])
>>> arr2 = jnp.array([3, 4])
>>> result = concatenate([arr1, arr2])
- piblin_jax.backend.operations.copy(arr)
Create a copy of an array.
- Parameters:
arr (array_like) – Input array.
- Returns:
Copy of the input array.
- Return type:
array_like
Examples
>>> from piblin_jax.backend import jnp
>>> from piblin_jax.backend.operations import copy
>>> arr = jnp.array([1, 2, 3])
>>> arr_copy = copy(arr)
- piblin_jax.backend.operations.device_get(arr)
Transfer array from device to host (NumPy array).
For the JAX backend, this converts a JAX array to a NumPy array. For the NumPy backend, this is effectively a no-op.
- Parameters:
arr (array_like) – Input array.
- Returns:
NumPy array.
- Return type:
np.ndarray
Examples
>>> from piblin_jax.backend.operations import device_get
>>> from piblin_jax.backend import jnp
>>> arr = jnp.array([1, 2, 3])
>>> np_arr = device_get(arr)
- piblin_jax.backend.operations.device_put(arr, device=None)
Transfer array to a specific device.
For JAX backend, this uses jax.device_put for device placement. For NumPy backend, this is a no-op that returns the array unchanged.
- Parameters:
arr (array_like) – Input array.
device (optional) – Target device (JAX device object). Ignored for NumPy backend.
- Returns:
Array on the specified device (JAX) or original array (NumPy).
- Return type:
array_like
Examples
>>> from piblin_jax.backend.operations import device_put
>>> from piblin_jax.backend import jnp
>>> arr = jnp.array([1, 2, 3])
>>> arr_on_device = device_put(arr)
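The "no-op on NumPy" behaviour of device_put and device_get can be illustrated with a pair of small wrappers. This is a sketch under assumptions: the names device_put_sketch and device_get_sketch and the _HAS_JAX flag are hypothetical, and the real functions may be structured differently. Only jax.device_put and jax.device_get are actual JAX calls.

```python
import numpy as np

try:
    import jax
    _HAS_JAX = True  # hypothetical flag standing in for the library's backend check
except ImportError:
    _HAS_JAX = False


def device_put_sketch(arr, device=None):
    """Place arr on a device with JAX; return it unchanged without JAX."""
    if _HAS_JAX:
        return jax.device_put(arr, device)
    return arr


def device_get_sketch(arr):
    """Bring arr back to host memory as a NumPy array on either backend."""
    if _HAS_JAX:
        return np.asarray(jax.device_get(arr))
    return np.asarray(arr)


host = np.array([1.0, 2.0])
round_trip = device_get_sketch(device_put_sketch(host))
```

Calling code can then move data explicitly without branching on the backend itself.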
- piblin_jax.backend.operations.ensure_array(arr, dtype=None)
Ensure input is a backend array with optional dtype conversion.
- Parameters:
arr (array_like) – Input data (array, list, scalar, etc.).
dtype (dtype, optional) – Desired data type.
- Returns:
Backend array.
- Return type:
array_like
Examples
>>> from piblin_jax.backend.operations import ensure_array
>>> arr = ensure_array([1, 2, 3], dtype=float)
- piblin_jax.backend.operations.grad(func, argnums=0, **kwargs)
Gradient computation decorator.
For JAX backend, this uses jax.grad for automatic differentiation. For NumPy backend, this raises NotImplementedError.
- Parameters:
- Returns:
Function that computes gradients.
- Return type:
callable
Examples
>>> from piblin_jax.backend.operations import grad
>>> from piblin_jax.backend import jnp
>>>
>>> def loss(x):
...     return jnp.sum(x ** 2)
>>>
>>> grad_loss = grad(loss)
>>> gradient = grad_loss(jnp.array([1.0, 2.0, 3.0]))
- piblin_jax.backend.operations.jit(func=None, **kwargs)
Just-in-time compilation decorator.
For JAX backend, this uses jax.jit for compilation and optimization. For NumPy backend, this is a no-op that returns the function unchanged.
- Parameters:
func (callable, optional) – Function to JIT compile.
**kwargs (Any) – Additional arguments passed to jax.jit (ignored for NumPy backend).
- Returns:
JIT-compiled function (JAX) or original function (NumPy).
- Return type:
callable
Examples
>>> from piblin_jax.backend.operations import jit
>>> from piblin_jax.backend import jnp
>>>
>>> @jit
... def compute(x):
...     return x ** 2 + 2 * x + 1
>>>
>>> result = compute(jnp.array([1.0, 2.0, 3.0]))
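The func=None default in the signature suggests the decorator can be used both bare (@jit) and with arguments (@jit(...)). A minimal sketch of that pattern, with a no-op fallback when JAX is missing, might look like this. The name jit_sketch and the _HAS_JAX flag are hypothetical; this is not the library's source.

```python
import functools

try:
    import jax
    _HAS_JAX = True  # hypothetical flag standing in for the library's backend check
except ImportError:
    _HAS_JAX = False


def jit_sketch(func=None, **kwargs):
    """Compile func with jax.jit when possible; otherwise return it as-is."""
    if func is None:
        # Called as @jit_sketch(...): return a decorator that remembers kwargs.
        return functools.partial(jit_sketch, **kwargs)
    if _HAS_JAX:
        return jax.jit(func, **kwargs)
    return func  # NumPy fallback: kwargs are ignored


@jit_sketch
def poly(x):
    return x ** 2 + 2 * x + 1


@jit_sketch()
def double(x):
    return 2 * x
```

Both decorated functions are callable in the usual way on either backend, for example poly(3.0) and double(2.0).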
- piblin_jax.backend.operations.reshape(arr, shape)
Reshape an array.
- Parameters:
arr (array_like) – Input array.
shape (int or sequence of ints) – New shape. One dimension can be -1, in which case it’s inferred.
- Returns:
Reshaped array.
- Return type:
array_like
Examples
>>> from piblin_jax.backend import jnp
>>> from piblin_jax.backend.operations import reshape
>>> arr = jnp.array([1, 2, 3, 4, 5, 6])
>>> result = reshape(arr, (2, 3))
>>> result.shape
(2, 3)
- piblin_jax.backend.operations.stack(arrays, axis=0)
Stack arrays along a new axis.
- Parameters:
arrays (sequence of array_like) – Arrays to stack. All arrays must have the same shape.
axis (int, optional) – Axis along which to stack. Default is 0.
- Returns:
Stacked array with one additional dimension.
- Return type:
array_like
Examples
>>> from piblin_jax.backend import jnp
>>> from piblin_jax.backend.operations import stack
>>> arr1 = jnp.array([1, 2, 3])
>>> arr2 = jnp.array([4, 5, 6])
>>> result = stack([arr1, arr2])
>>> result.shape
(2, 3)
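The difference between joining along an existing axis (concatenate) and along a new axis (stack) is easiest to see in the resulting shapes. The sketch below uses plain NumPy calls rather than the piblin_jax wrappers, on the assumption that the wrappers follow the same semantics as numpy.concatenate and numpy.stack.

```python
import numpy as np

a = np.array([1, 2, 3])
b = np.array([4, 5, 6])

# Existing axis: the two length-3 vectors become one length-6 vector.
joined = np.concatenate([a, b], axis=0)

# New leading axis: the inputs become the rows of a 2 x 3 array.
stacked_rows = np.stack([a, b], axis=0)

# New trailing axis: the inputs become the columns of a 3 x 2 array.
stacked_cols = np.stack([a, b], axis=1)

print(joined.shape, stacked_rows.shape, stacked_cols.shape)
```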
- piblin_jax.backend.operations.vmap(func, in_axes=0, out_axes=0, **kwargs)
Vectorizing map decorator.
For JAX backend, this uses jax.vmap for automatic vectorization. For NumPy backend, this provides a simple implementation using iteration.
- Parameters:
- Returns:
Vectorized function.
- Return type:
callable
Examples
>>> from piblin_jax.backend.operations import vmap
>>> from piblin_jax.backend import jnp
>>>
>>> def add_one(x):
...     return x + 1
>>>
>>> batched_add = vmap(add_one)
>>> result = batched_add(jnp.array([1, 2, 3]))
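A "simple implementation using iteration" for the NumPy backend could be as small as the following. This is a sketch, not the library's code: the name vmap_fallback_sketch is hypothetical, and it assumes a single array argument with integer in_axes and out_axes, which is narrower than what jax.vmap supports.

```python
import numpy as np


def vmap_fallback_sketch(func, in_axes=0, out_axes=0):
    """Apply func to each slice along in_axes and stack results on out_axes.

    Hypothetical single-argument fallback; a Python loop replaces true
    vectorization, so it is correct but not fast.
    """
    def batched(x):
        x = np.asarray(x)
        # Move the mapped axis to the front so iteration yields one slice each.
        slices = np.moveaxis(x, in_axes, 0)
        results = [func(item) for item in slices]
        return np.stack(results, axis=out_axes)

    return batched


matrix = np.array([[1, 2], [3, 4]])
row_sums = vmap_fallback_sketch(np.sum)(matrix)             # map over rows
col_sums = vmap_fallback_sketch(np.sum, in_axes=1)(matrix)  # map over columns
```

The loop keeps the call signature usable on NumPy-only installs, at the cost of the speedup that jax.vmap provides.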