"""
Backend abstraction layer for piblin-jax.
This module provides a unified interface for both JAX and NumPy backends,
with automatic fallback to NumPy when JAX is unavailable.
The backend is detected at module import time and stored in the BACKEND global variable.
All array operations should use the exported `jnp` interface which points to either
jax.numpy or numpy depending on availability.
"""
import sys
import types
import warnings
from typing import Any, Union
import numpy as np
# Backend state, initialised to the NumPy fallback. The JAX import attempt
# further down in this module overwrites all three names when JAX imports.
BACKEND = "numpy"
_JAX_AVAILABLE = False
jnp: types.ModuleType = np  # array namespace all array operations should use
def _detect_platform() -> str:
"""
Detect the current operating system platform.
Returns
-------
str
One of 'linux', 'macos', or 'windows'.
"""
platform = sys.platform.lower()
if platform.startswith("linux"):
return "linux"
elif platform == "darwin":
return "macos"
elif platform.startswith("win"):
return "windows"
else:
# Default to the actual platform string for unknown platforms
return platform
def _get_cuda_version() -> tuple[int, int] | None:
"""
Get CUDA version from JAX backend.
Returns
-------
tuple of (int, int) or None
Tuple of (major, minor) version numbers, or None if CUDA unavailable.
"""
try:
import jax
# Try newer JAX API first (v0.8.0+)
try:
from jax.extend import backend as jax_backend
backend = jax_backend.get_backend()
except (ImportError, AttributeError):
# Fallback to older API for JAX < 0.8.0
backend = jax.lib.xla_bridge.get_backend()
version_string = backend.platform_version
# Parse version string (e.g., "12.0", "11.8", "12.3.1")
parts = version_string.split(".")
if len(parts) >= 2:
major = int(parts[0])
minor = int(parts[1])
return (major, minor)
return None
except Exception:
# CUDA not available or error accessing version
return None
def _validate_cuda_version(cuda_version: tuple[int, int] | None) -> bool:
"""
Validate that CUDA version meets minimum requirements.
Parameters
----------
cuda_version : tuple of (int, int) or None
CUDA version tuple (major, minor).
Returns
-------
bool
True if CUDA version >= 12.0, False otherwise.
"""
if cuda_version is None:
return False
major, _ = cuda_version
return major >= 12
def _check_legacy_gpu_extras() -> None:
"""
Check for legacy GPU extras (gpu-metal, gpu-rocm) and issue deprecation warning.
This function attempts to detect if the user has installed deprecated GPU extras
and warns them to migrate to gpu-cuda on Linux.
"""
try:
# Check if JAX was installed with Metal or ROCm backend
import jax
# Try to detect Metal backend (macOS)
try:
devices = jax.devices()
for device in devices:
device_str = str(device).lower()
if "metal" in device_str or "gpu" in device_str:
platform = _detect_platform()
if platform == "macos":
warnings.warn(
"Detected JAX with Metal backend. gpu-metal is deprecated. "
"GPU support is now only available on Linux with CUDA 12+. "
"On macOS, CPU-only mode is recommended.",
DeprecationWarning,
stacklevel=2,
)
return
except Exception: # nosec B110 # Intentional: silently ignore detection errors
pass
# Try to detect ROCm backend (AMD GPUs)
try:
devices = jax.devices()
for device in devices:
device_str = str(device).lower()
if "rocm" in device_str or "amd" in device_str:
warnings.warn(
"Detected JAX with ROCm backend. gpu-rocm is deprecated. "
"GPU support is now only available on Linux with CUDA 12+.",
DeprecationWarning,
stacklevel=2,
)
return
except Exception: # nosec B110 # Intentional: silently ignore detection errors
pass
except ImportError:
# JAX not installed, no legacy extras to check
pass
# Select the backend: prefer JAX, fall back to NumPy if it cannot be imported.
try:
    import jax
    import jax.numpy as jnp_jax
except ImportError:
    warnings.warn(
        "JAX not available, using NumPy (reduced performance). "
        "Install JAX for GPU acceleration and JIT compilation: pip install jax jaxlib",
        UserWarning,
        stacklevel=2,
    )
    _JAX_AVAILABLE = False
    BACKEND = "numpy"
    jnp = np
else:
    # JAX imported: publish it as the active backend. JAX stays enabled even
    # when no usable GPU is found; it then runs in CPU mode.
    _JAX_AVAILABLE = True
    BACKEND = "jax"
    jnp = jnp_jax

    # Warn about deprecated Metal/ROCm installs before the CUDA checks.
    _check_legacy_gpu_extras()

    detected_platform = _detect_platform()
    if detected_platform != "linux":
        # GPU acceleration is Linux-only, so other platforms always use CPU.
        warnings.warn(
            "GPU acceleration is only available on Linux with CUDA 12+. Using JAX in CPU mode.",
            UserWarning,
            stacklevel=2,
        )
    else:
        cuda_version = _get_cuda_version()
        if not _validate_cuda_version(cuda_version):
            warnings.warn(
                "GPU acceleration requires CUDA 12+. Using JAX in CPU mode.",
                UserWarning,
                stacklevel=2,
            )
def is_jax_available() -> bool:
    """
    Check if JAX backend is available.

    Returns
    -------
    bool
        True if ``jax`` and ``jax.numpy`` imported successfully when this
        module was loaded, False if the NumPy fallback is in use. A True
        result does not imply GPU acceleration: JAX may be running in CPU
        mode.

    Examples
    --------
    >>> from piblin_jax.backend import is_jax_available
    >>> if is_jax_available():
    ...     print("Using JAX backend")
    ... else:
    ...     print("Using NumPy fallback")
    """
    return _JAX_AVAILABLE
def get_backend() -> str:
    """
    Get the name of the current backend.

    Returns
    -------
    str
        Either 'jax' or 'numpy'; the same value as the module-level
        ``BACKEND`` constant.

    Examples
    --------
    >>> from piblin_jax.backend import get_backend
    >>> backend = get_backend()
    >>> print(f"Using backend: {backend}")
    """
    return BACKEND
def get_device_info() -> dict[str, Any]:
    """
    Get information about available compute devices.

    Returns
    -------
    dict
        Dictionary containing:

        - 'backend': str, name of backend ('jax' or 'numpy')
        - 'devices': list of str, available compute devices
        - 'default_device': str, the default device being used
        - 'os_platform': str, detected OS platform ('linux', 'macos',
          'windows', or the raw ``sys.platform`` value otherwise)
        - 'gpu_supported': bool, whether a CUDA 12+ GPU was detected
          (only probed on Linux)
        - 'cuda_version': tuple or None, CUDA version (major, minor) if
          available
        - 'platform': str, 'numpy' for the NumPy backend, the JAX backend
          platform name when JAX is used, or 'cpu' if querying JAX failed
        - 'device_count': int, number of JAX devices (present only when the
          JAX device query succeeded)

    Warns
    -----
    UserWarning
        If JAX is available but its devices cannot be queried.

    Examples
    --------
    >>> from piblin_jax.backend import get_device_info
    >>> info = get_device_info()
    >>> print(f"Backend: {info['backend']}")
    >>> print(f"Devices: {info['devices']}")
    """
    info: dict[str, Any] = {
        "backend": BACKEND,
        "devices": [],
        "default_device": "cpu",
        "os_platform": _detect_platform(),
        "gpu_supported": False,
        "cuda_version": None,
    }

    # CUDA GPU support is Linux-only, so only probe for it there.
    if info["os_platform"] == "linux":
        cuda_version = _get_cuda_version()
        info["cuda_version"] = cuda_version
        info["gpu_supported"] = _validate_cuda_version(cuda_version)

    if not _JAX_AVAILABLE:
        info["devices"] = ["cpu"]
        info["platform"] = "numpy"
        return info

    try:
        import jax

        devices = jax.devices()
        # Gather everything into a separate dict first so that a late failure
        # cannot leave ``info`` half-updated (e.g. a GPU default_device next
        # to a ["cpu"] device list).
        jax_info: dict[str, Any] = {
            "devices": [str(d) for d in devices],
            "default_device": str(devices[0]) if devices else "cpu",
            "device_count": len(devices),
        }
        # Add platform information using the newer JAX API when present.
        try:
            from jax.extend import backend as jax_backend

            jax_info["platform"] = jax_backend.get_backend().platform
        except (ImportError, AttributeError):
            # Fallback for older JAX versions: derive it from the device name.
            jax_info["platform"] = str(devices[0]).split(":")[0] if devices else "cpu"
    except Exception as e:
        warnings.warn(f"Could not get JAX device info: {e}", UserWarning, stacklevel=2)
        info["devices"] = ["cpu"]
        info["platform"] = "cpu"
    else:
        info.update(jax_info)
    return info
[docs]
def to_numpy(arr: Any) -> np.ndarray:
"""
Convert a backend array to NumPy array.
This function handles conversion from JAX arrays to NumPy arrays,
and passes through NumPy arrays unchanged. Useful for API boundaries
where pure NumPy arrays are required.
Parameters
----------
arr : array_like
Input array (JAX or NumPy array, or nested structure).
Returns
-------
np.ndarray
NumPy array with the same data.
Examples
--------
>>> from piblin_jax.backend import jnp, to_numpy
>>> jax_arr = jnp.array([1, 2, 3])
>>> np_arr = to_numpy(jax_arr)
>>> type(np_arr)
<class 'numpy.ndarray'>
"""
if isinstance(arr, np.ndarray):
return arr
if _JAX_AVAILABLE:
# For JAX arrays, use np.asarray which handles DeviceArray conversion
try:
return np.asarray(arr)
except Exception:
# Fallback for complex types
return np.array(arr)
else:
# Already using NumPy backend
return np.asarray(arr)
def from_numpy(arr: np.ndarray) -> Any:
    """
    Convert a NumPy array to backend array.

    With the JAX backend the input is converted via ``jnp.asarray``. With the
    NumPy fallback it is returned unchanged (same object).

    Parameters
    ----------
    arr : np.ndarray
        Input NumPy array.

    Returns
    -------
    array_like
        A JAX array if JAX is available, otherwise ``arr`` itself.

    Examples
    --------
    >>> import numpy as np
    >>> from piblin_jax.backend import from_numpy
    >>> np_arr = np.array([1, 2, 3])
    >>> backend_arr = from_numpy(np_arr)
    """
    return jnp.asarray(arr) if _JAX_AVAILABLE else arr
[docs]
def to_numpy_pytree(pytree: Any) -> Any:
"""
Convert a pytree (nested structure) of arrays to NumPy.
Handles nested dictionaries, lists, and tuples containing arrays.
Parameters
----------
pytree : Any
Nested structure containing arrays.
Returns
-------
Any
Same structure with all arrays converted to NumPy.
Examples
--------
>>> from piblin_jax.backend import jnp, to_numpy_pytree
>>> pytree = {'a': jnp.array([1, 2]), 'b': [jnp.array([3, 4])]}
>>> np_pytree = to_numpy_pytree(pytree)
"""
if isinstance(pytree, dict):
return {k: to_numpy_pytree(v) for k, v in pytree.items()}
elif isinstance(pytree, (list, tuple)):
converted = [to_numpy_pytree(item) for item in pytree]
return type(pytree)(converted)
elif hasattr(pytree, "__array__"):
# Anything that looks like an array
return to_numpy(pytree)
else:
return pytree
[docs]
def from_numpy_pytree(pytree: Any) -> Any:
"""
Convert a pytree (nested structure) of NumPy arrays to backend arrays.
Handles nested dictionaries, lists, and tuples containing arrays.
Parameters
----------
pytree : Any
Nested structure containing NumPy arrays.
Returns
-------
Any
Same structure with all arrays converted to backend format.
Examples
--------
>>> import numpy as np
>>> from piblin_jax.backend import from_numpy_pytree
>>> pytree = {'a': np.array([1, 2]), 'b': [np.array([3, 4])]}
>>> backend_pytree = from_numpy_pytree(pytree)
"""
if isinstance(pytree, dict):
return {k: from_numpy_pytree(v) for k, v in pytree.items()}
elif isinstance(pytree, (list, tuple)):
converted = [from_numpy_pytree(item) for item in pytree]
return type(pytree)(converted)
elif isinstance(pytree, np.ndarray):
return from_numpy(pytree)
else:
return pytree
# Public API re-exported by ``from piblin_jax.backend import *`` (kept in
# ASCII-sorted order); underscore-prefixed helpers stay internal.
__all__ = [
    "BACKEND",
    "from_numpy", "from_numpy_pytree",
    "get_backend", "get_device_info",
    "is_jax_available",
    "jnp",
    "to_numpy", "to_numpy_pytree",
]