"""
Backend-agnostic array operations.
This module provides unified interfaces for common array operations that work
with both JAX and NumPy backends. It also provides JIT compilation decorators
and device placement utilities that gracefully degrade to no-ops when using
the NumPy backend.
"""
from collections.abc import Callable, Sequence
from functools import wraps
from typing import Any, ParamSpec, TypeVar
import numpy as np
from . import _JAX_AVAILABLE, jnp
# Type variables for generic callable decorators
P = ParamSpec("P")
R = TypeVar("R")
# Array Operations
[docs]
def copy(arr: Any) -> Any:
"""
Create a copy of an array.
Parameters
----------
arr : array_like
Input array.
Returns
-------
array_like
Copy of the input array.
Examples
--------
>>> from piblin_jax.backend import jnp
>>> from piblin_jax.backend.operations import copy
>>> arr = jnp.array([1, 2, 3])
>>> arr_copy = copy(arr)
"""
if _JAX_AVAILABLE:
# JAX arrays are immutable, so copy is just array creation
return jnp.array(arr)
else:
return np.copy(arr)
[docs]
def concatenate(arrays: Sequence[Any], axis: int = 0) -> Any:
"""
Concatenate arrays along an existing axis.
Parameters
----------
arrays : sequence of array_like
Arrays to concatenate. All arrays must have the same shape except
in the dimension corresponding to axis.
axis : int, optional
Axis along which to concatenate. Default is 0.
Returns
-------
array_like
Concatenated array.
Examples
--------
>>> from piblin_jax.backend import jnp
>>> from piblin_jax.backend.operations import concatenate
>>> arr1 = jnp.array([1, 2])
>>> arr2 = jnp.array([3, 4])
>>> result = concatenate([arr1, arr2])
"""
return jnp.concatenate(arrays, axis=axis)
[docs]
def stack(arrays: Sequence[Any], axis: int = 0) -> Any:
"""
Stack arrays along a new axis.
Parameters
----------
arrays : sequence of array_like
Arrays to stack. All arrays must have the same shape.
axis : int, optional
Axis along which to stack. Default is 0.
Returns
-------
array_like
Stacked array with one additional dimension.
Examples
--------
>>> from piblin_jax.backend import jnp
>>> from piblin_jax.backend.operations import stack
>>> arr1 = jnp.array([1, 2, 3])
>>> arr2 = jnp.array([4, 5, 6])
>>> result = stack([arr1, arr2])
>>> result.shape
(2, 3)
"""
return jnp.stack(arrays, axis=axis)
[docs]
def reshape(arr: Any, shape: int | Sequence[int]) -> Any:
"""
Reshape an array.
Parameters
----------
arr : array_like
Input array.
shape : int or sequence of ints
New shape. One dimension can be -1, in which case it's inferred.
Returns
-------
array_like
Reshaped array.
Examples
--------
>>> from piblin_jax.backend import jnp
>>> from piblin_jax.backend.operations import reshape
>>> arr = jnp.array([1, 2, 3, 4, 5, 6])
>>> result = reshape(arr, (2, 3))
>>> result.shape
(2, 3)
"""
return jnp.reshape(arr, shape)
# JIT Compilation
[docs]
def jit[**P, R](
func: Callable[P, R] | None = None, **kwargs: Any
) -> Callable[P, R] | Callable[[Callable[P, R]], Callable[P, R]]:
"""
Just-in-time compilation decorator.
For JAX backend, this uses jax.jit for compilation and optimization.
For NumPy backend, this is a no-op that returns the function unchanged.
Parameters
----------
func : callable, optional
Function to JIT compile.
**kwargs
Additional arguments passed to jax.jit (ignored for NumPy backend).
Returns
-------
callable
JIT-compiled function (JAX) or original function (NumPy).
Examples
--------
>>> from piblin_jax.backend.operations import jit
>>> from piblin_jax.backend import jnp
>>>
>>> @jit
... def compute(x):
... return x ** 2 + 2 * x + 1
>>>
>>> result = compute(jnp.array([1.0, 2.0, 3.0]))
"""
def decorator(f: Callable[P, R]) -> Callable[P, R]:
"""
Apply JIT compilation or no-op depending on backend availability.
Parameters
----------
f : callable
Function to compile.
Returns
-------
callable
Compiled function or wrapper.
Examples
--------
>>> def my_func(x):
... return x + 1
>>> compiled = decorator(my_func)
"""
if _JAX_AVAILABLE:
import jax
return jax.jit(f, **kwargs)
else:
# No-op for NumPy backend
@wraps(f)
def wrapper(*args: P.args, **kwargs: P.kwargs) -> R:
"""
Wrapper function for NumPy backend compatibility.
Returns
-------
Any
Result of wrapped function.
Examples
--------
>>> wrapper(1, 2, x=3) # doctest: +SKIP
"""
return f(*args, **kwargs)
return wrapper
# Support both @jit and @jit(static_argnums=0) syntax
if func is None:
return decorator
else:
return decorator(func)
[docs]
def vmap(
func: Callable[..., Any],
in_axes: int | Sequence[int | None] = 0,
out_axes: int = 0,
**kwargs: Any,
) -> Callable[..., Any]:
"""
Vectorizing map decorator.
For JAX backend, this uses jax.vmap for automatic vectorization.
For NumPy backend, this provides a simple implementation using iteration.
Parameters
----------
func : callable
Function to vectorize.
in_axes : int or sequence of int/None, optional
Axis to map over for each input. Default is 0.
out_axes : int, optional
Axis of output to map over. Default is 0.
**kwargs
Additional arguments passed to jax.vmap (ignored for NumPy backend).
Returns
-------
callable
Vectorized function.
Examples
--------
>>> from piblin_jax.backend.operations import vmap
>>> from piblin_jax.backend import jnp
>>>
>>> def add_one(x):
... return x + 1
>>>
>>> batched_add = vmap(add_one)
>>> result = batched_add(jnp.array([1, 2, 3]))
"""
if _JAX_AVAILABLE:
import jax
return jax.vmap(func, in_axes=in_axes, out_axes=out_axes, **kwargs)
else:
# Simple NumPy implementation
@wraps(func)
def wrapper(*args: Any) -> Any:
"""
Simplified vectorization wrapper for NumPy backend.
Returns
-------
array_like
Stacked results of function applied to each element.
Examples
--------
>>> wrapper([1, 2, 3]) # doctest: +SKIP
"""
# Basic implementation - map over first axis
if not args:
return func()
# Handle single input case
if len(args) == 1:
arr = args[0]
results = [func(arr[i]) for i in range(len(arr))]
return np.stack(results, axis=out_axes)
# Handle multiple inputs - this is simplified
# Real implementation would need to handle in_axes properly
raise NotImplementedError(
"NumPy backend vmap with multiple inputs not fully implemented. "
"Use JAX backend for full vmap support."
)
return wrapper
[docs]
def grad(
func: Callable[..., Any], argnums: int | Sequence[int] = 0, **kwargs: Any
) -> Callable[..., Any]:
"""
Gradient computation decorator.
For JAX backend, this uses jax.grad for automatic differentiation.
For NumPy backend, this raises NotImplementedError.
Parameters
----------
func : callable
Function to differentiate. Should return a scalar.
argnums : int or sequence of int, optional
Which arguments to differentiate with respect to. Default is 0.
**kwargs
Additional arguments passed to jax.grad (ignored for NumPy backend).
Returns
-------
callable
Function that computes gradients.
Examples
--------
>>> from piblin_jax.backend.operations import grad
>>> from piblin_jax.backend import jnp
>>>
>>> def loss(x):
... return jnp.sum(x ** 2)
>>>
>>> grad_loss = grad(loss)
>>> gradient = grad_loss(jnp.array([1.0, 2.0, 3.0]))
"""
if _JAX_AVAILABLE:
import jax
return jax.grad(func, argnums=argnums, **kwargs)
else:
def not_implemented(*args: Any, **kwargs: Any) -> Any:
"""
Raise error when automatic differentiation is unavailable.
Returns
-------
None
Never returns, always raises.
Raises
------
NotImplementedError
Always raised when JAX is unavailable.
Examples
--------
>>> not_implemented() # doctest: +SKIP
Traceback (most recent call last):
...
NotImplementedError: Automatic differentiation requires JAX backend
"""
raise NotImplementedError(
"Automatic differentiation requires JAX backend. "
"Install JAX or use numerical differentiation."
)
return not_implemented
# Device Management
[docs]
def device_put(arr: Any, device: Any | None = None) -> Any:
"""
Transfer array to a specific device.
For JAX backend, this uses jax.device_put for device placement.
For NumPy backend, this is a no-op that returns the array unchanged.
Parameters
----------
arr : array_like
Input array.
device : optional
Target device (JAX device object). Ignored for NumPy backend.
Returns
-------
array_like
Array on the specified device (JAX) or original array (NumPy).
Examples
--------
>>> from piblin_jax.backend.operations import device_put
>>> from piblin_jax.backend import jnp
>>> arr = jnp.array([1, 2, 3])
>>> arr_on_device = device_put(arr)
"""
if _JAX_AVAILABLE:
import jax
if device is None:
return jax.device_put(arr)
else:
return jax.device_put(arr, device)
else:
return arr
[docs]
def device_get(arr: Any) -> np.ndarray:
"""
Transfer array from device to host (NumPy array).
For JAX backend, this converts DeviceArray to NumPy array.
For NumPy backend, this is effectively a no-op.
Parameters
----------
arr : array_like
Input array.
Returns
-------
np.ndarray
NumPy array.
Examples
--------
>>> from piblin_jax.backend.operations import device_get
>>> from piblin_jax.backend import jnp
>>> arr = jnp.array([1, 2, 3])
>>> np_arr = device_get(arr)
"""
return np.asarray(arr)
# Type Conversions
[docs]
def ensure_array(arr: Any, dtype: Any | None = None) -> Any:
"""
Ensure input is a backend array with optional dtype conversion.
Parameters
----------
arr : array_like
Input data (array, list, scalar, etc.).
dtype : dtype, optional
Desired data type.
Returns
-------
array_like
Backend array.
Examples
--------
>>> from piblin_jax.backend.operations import ensure_array
>>> arr = ensure_array([1, 2, 3], dtype=float)
"""
if dtype is None:
return jnp.asarray(arr)
else:
return jnp.asarray(arr, dtype=dtype)
[docs]
def astype(arr: Any, dtype: Any) -> Any:
"""
Cast array to specified dtype.
Parameters
----------
arr : array_like
Input array.
dtype : dtype
Target data type.
Returns
-------
array_like
Array cast to specified dtype.
Examples
--------
>>> from piblin_jax.backend.operations import astype
>>> from piblin_jax.backend import jnp
>>> arr = jnp.array([1, 2, 3])
>>> arr_float = astype(arr, jnp.float32)
"""
return arr.astype(dtype)
# Export public API
__all__ = [
"astype",
"concatenate",
# Array operations
"copy",
"device_get",
# Device management
"device_put",
# Type conversions
"ensure_array",
"grad",
# JIT and vectorization
"jit",
"reshape",
"stack",
"vmap",
]