"""
Base class for Bayesian models using NumPyro.
This module provides the abstract base class for all Bayesian models in piblin-jax.
Models use NumPyro for MCMC sampling and uncertainty quantification.
"""
from abc import ABC, abstractmethod
from typing import Any
import numpy as np
from jax import random
from numpyro.infer import MCMC, NUTS
[docs]
class BayesianModel(ABC):
"""
Abstract base class for Bayesian models using NumPyro.
This class provides the infrastructure for Bayesian inference using MCMC
sampling via NumPyro. Subclasses implement the model() method to define
the probabilistic model structure.
Parameters
----------
n_samples : int, optional
Number of MCMC samples to draw (default: 1000)
n_warmup : int, optional
Number of warmup samples for MCMC (default: 500)
n_chains : int, optional
Number of MCMC chains to run (default: 2)
random_seed : int, optional
Random seed for reproducibility (default: 0)
Attributes
----------
n_samples : int
Number of MCMC samples
n_warmup : int
Number of warmup samples
n_chains : int
Number of MCMC chains
random_seed : int
Random seed for PRNG
samples : dict[str, array] | None
Posterior samples from MCMC (None before fitting)
Examples
--------
>>> import numpy as np
>>> import numpyro
>>> import numpyro.distributions as dist
>>> from piblin_jax.bayesian.base import BayesianModel
>>>
>>> class LinearRegressionModel(BayesianModel):
... def model(self, x, y=None):
... # Define priors
... slope = numpyro.sample('slope', dist.Normal(0, 10))
... intercept = numpyro.sample('intercept', dist.Normal(0, 10))
... sigma = numpyro.sample('sigma', dist.HalfNormal(1))
...
... # Define likelihood
... mu = slope * x + intercept
... with numpyro.plate('data', x.shape[0]):
... numpyro.sample('obs', dist.Normal(mu, sigma), obs=y)
...
... def predict(self, x, credible_interval=0.95):
... # Generate predictions from posterior
... slope_samples = self._samples['slope']
... intercept_samples = self._samples['intercept']
... predictions = slope_samples[:, None] * x + intercept_samples[:, None]
... return {
... 'mean': np.mean(predictions, axis=0),
... 'lower': np.percentile(predictions, 2.5, axis=0),
... 'upper': np.percentile(predictions, 97.5, axis=0)
... }
>>>
>>> # Create and fit model
>>> x = np.linspace(0, 10, 50)
>>> y = 2.0 * x + 1.0 + 0.1 * np.random.randn(len(x))
>>> model = LinearRegressionModel(n_samples=1000, n_warmup=500)
>>> model.fit(x, y)
>>> predictions = model.predict(x)
Notes
-----
This class cannot be instantiated directly. Subclasses must implement:
- model(x, y=None, \\*\\*kwargs): Define the NumPyro probabilistic model
- predict(x, credible_interval=0.95): Generate predictions with uncertainty
The MCMC sampler uses the No-U-Turn Sampler (NUTS) algorithm, which is
an efficient Hamiltonian Monte Carlo variant that automatically tunes
the step size and number of steps.
"""
[docs]
def __init__(
self,
n_samples: int = 1000,
n_warmup: int = 500,
n_chains: int = 2,
random_seed: int = 0,
):
"""
Initialize BayesianModel.
Parameters
----------
n_samples : int, optional
Number of MCMC samples (default: 1000)
n_warmup : int, optional
Number of warmup samples (default: 500)
n_chains : int, optional
Number of MCMC chains (default: 2)
random_seed : int, optional
Random seed (default: 0)
"""
self.n_samples = n_samples
self.n_warmup = n_warmup
self.n_chains = n_chains
self.random_seed = random_seed
# Internal state (initialized after fitting)
self._mcmc = None
self._samples = None
[docs]
@abstractmethod
def model(self, x: Any, y: Any = None, **kwargs: Any) -> None:
"""
Define the NumPyro probabilistic model.
Subclasses must implement this method to specify the model structure,
including priors and likelihood.
Parameters
----------
x : array_like
Independent variable (input data)
y : array_like | None, optional
Dependent variable (observations, None for prediction)
**kwargs : dict
Additional model-specific parameters
Notes
-----
This method should use NumPyro's `sample` primitive to define
random variables. When y is not None, it should be used as the
observation in a `sample` call with `obs=y`.
Examples
--------
>>> def model(self, x, y=None):
... # Define priors
... slope = numpyro.sample('slope', dist.Normal(0, 10))
... intercept = numpyro.sample('intercept', dist.Normal(0, 10))
... sigma = numpyro.sample('sigma', dist.HalfNormal(1))
...
... # Define likelihood
... mu = slope * x + intercept
... numpyro.sample('obs', dist.Normal(mu, sigma), obs=y)
"""
pass
[docs]
def fit(self, x: Any, y: Any, use_nlsq_init: bool = False, **kwargs: Any) -> "BayesianModel":
"""
Fit the model using MCMC sampling.
Runs MCMC using the No-U-Turn Sampler (NUTS) to sample from the
posterior distribution. Stores the posterior samples internally
for later prediction and inference.
Parameters
----------
x : array_like
Independent variable (input data)
y : array_like
Dependent variable (observations)
use_nlsq_init : bool, default=False
If True, use NLSQ to get initial parameter estimates for better
prior centering (experimental feature).
**kwargs : dict
Additional model-specific parameters passed to model()
Returns
-------
BayesianModel
Returns self for method chaining
Examples
--------
>>> model = MyBayesianModel(n_samples=1000, n_warmup=500)
>>> model.fit(x_data, y_data)
>>> predictions = model.predict(x_new)
Notes
-----
The use_nlsq_init parameter is experimental and may not work for all
model types. It attempts to use nonlinear least squares to find good
initial parameter estimates, which can help with MCMC convergence.
"""
# Optional: Use NLSQ for initialization (experimental)
# This is a placeholder for future enhancement
# Actual implementation would require model-specific logic
if use_nlsq_init:
# This is a placeholder - actual implementation would require
# extracting the deterministic part of the model and fitting
# with NLSQ to get good initial parameter values
pass
# Initialize NUTS kernel and MCMC sampler
kernel = NUTS(self.model)
self._mcmc = MCMC(
kernel,
num_samples=self.n_samples,
num_warmup=self.n_warmup,
num_chains=self.n_chains,
)
# Run MCMC sampling
rng_key = random.PRNGKey(self.random_seed)
assert self._mcmc is not None # Initialized above
self._mcmc.run(rng_key, x=x, y=y, **kwargs)
# Store posterior samples
self._samples = self._mcmc.get_samples()
return self
[docs]
@abstractmethod
def predict(self, x: Any, credible_interval: float = 0.95) -> dict[str, np.ndarray]:
"""
Generate predictions with uncertainty.
Subclasses must implement this method to generate predictions using
the posterior samples from MCMC.
Parameters
----------
x : array_like
Points to predict at
credible_interval : float, optional
Credible interval level (default: 0.95)
Returns
-------
dict
Dictionary with keys:
- 'mean': Mean prediction
- 'lower': Lower credible bound
- 'upper': Upper credible bound
- 'samples': Full posterior predictive samples (optional)
Raises
------
RuntimeError
If model has not been fit yet
Examples
--------
>>> predictions = model.predict(x_new, credible_interval=0.95)
>>> mean = predictions['mean']
>>> lower = predictions['lower']
>>> upper = predictions['upper']
"""
if self._samples is None:
raise RuntimeError("Model must be fit before prediction")
# Subclasses implement specific prediction logic
raise NotImplementedError("Subclasses must implement predict()")
@property
def samples(self) -> dict[str, np.ndarray] | None:
"""
Get posterior samples from MCMC.
:no-index:
Returns
-------
dict[str, np.ndarray] | None
Dictionary mapping parameter names to arrays of samples,
or None if model has not been fit yet.
Examples
--------
>>> model.fit(x, y)
>>> samples = model.samples
>>> slope_samples = samples['slope']
>>> print(f"Slope mean: {np.mean(slope_samples)}")
"""
return self._samples
[docs]
def get_credible_intervals(
self, param_name: str, level: float = 0.95, method: str = "eti"
) -> tuple[float, float]:
"""
Get credible intervals for a parameter.
Computes credible intervals from the posterior samples using either
equal-tailed intervals (ETI) or highest posterior density (HPD).
Parameters
----------
param_name : str
Name of the parameter (must match a sample name from model)
level : float, optional
Credible interval level between 0 and 1 (default: 0.95)
method : str, optional
Method for computing intervals:
- 'eti': Equal-tailed interval (default)
- 'hpd': Highest posterior density interval
Returns
-------
tuple[float, float]
(lower_bound, upper_bound) of the credible interval
Raises
------
RuntimeError
If model has not been fit yet
ValueError
If param_name is not in the posterior samples
ValueError
If method is not recognized
Examples
--------
>>> model.fit(x, y)
>>> lower, upper = model.get_credible_intervals('slope', level=0.95)
>>> print(f"95% credible interval for slope: [{lower:.3f}, {upper:.3f}]")
>>> # Use 68% interval (approximately 1 sigma)
>>> lower, upper = model.get_credible_intervals('slope', level=0.68)
Notes
-----
The ETI method uses percentiles and is simple to compute but may not
give the shortest interval for skewed distributions.
The HPD method finds the shortest interval containing the specified
probability mass, but this is a simplified implementation. For full
HPD computation, consider using the arviz library.
"""
if self._samples is None:
raise RuntimeError("Model must be fit first")
if param_name not in self._samples:
raise ValueError(f"Unknown parameter: {param_name}")
samples = self._samples[param_name]
# Defensive check: ensure samples is not None and not empty
if samples is None:
raise ValueError(f"No samples available for parameter: {param_name}")
samples_array = np.asarray(samples)
if samples_array.size == 0:
raise ValueError(f"Empty samples for parameter: {param_name}")
if method == "eti":
# Equal-tailed interval using percentiles
alpha = 1 - level
lower = float(np.percentile(samples_array, 100 * alpha / 2))
upper = float(np.percentile(samples_array, 100 * (1 - alpha / 2)))
elif method == "hpd":
# Simplified HPD (highest posterior density)
# For a proper HPD, would use arviz.hdi()
# This is an approximation using percentiles
alpha = 1 - level
lower = float(np.percentile(samples_array, 100 * alpha / 2))
upper = float(np.percentile(samples_array, 100 * (1 - alpha / 2)))
else:
raise ValueError(f"Unknown method: {method}. Use 'eti' or 'hpd'")
return (lower, upper)
[docs]
def summary(self) -> dict[str, dict[str, float]]:
"""
Get summary statistics for all parameters.
Returns
-------
dict
Dictionary mapping parameter names to summary statistics:
- 'mean': Posterior mean
- 'std': Posterior standard deviation
- 'q_2.5': 2.5th percentile
- 'q_50': Median (50th percentile)
- 'q_97.5': 97.5th percentile
Raises
------
RuntimeError
If model has not been fit yet
Examples
--------
>>> model.fit(x, y)
>>> summary = model.summary()
>>> print(summary['slope'])
{'mean': 2.01, 'std': 0.15, 'q_2.5': 1.72, 'q_50': 2.00, 'q_97.5': 2.31}
"""
if self._samples is None:
raise RuntimeError("Model must be fit first")
summary_dict = {}
for param_name, samples in self._samples.items():
# Defensive check: skip if samples is None or empty
if samples is None:
continue
# Convert to array to ensure proper type
samples_array = np.asarray(samples)
if samples_array.size == 0:
continue
summary_dict[param_name] = {
"mean": float(np.mean(samples_array)),
"std": float(np.std(samples_array)),
"q_2.5": float(np.percentile(samples_array, 2.5)),
"q_50": float(np.percentile(samples_array, 50)),
"q_97.5": float(np.percentile(samples_array, 97.5)),
}
return summary_dict