Uncertainty Quantification

This guide provides a comprehensive overview of uncertainty quantification in piblin-jax, covering both the theoretical foundations and practical implementation details.

Overview

piblin-jax provides Bayesian uncertainty quantification through NumPyro, a probabilistic programming framework built on JAX. This enables:

Full posterior distributions

Not just point estimates, but complete probability distributions over parameter values given your data.

Credible intervals

Rigorous uncertainty ranges with clear probabilistic interpretation (e.g., “95% probability the parameter is in this range”).

Uncertainty propagation

Automatically propagate uncertainty through transforms and computations.

Model comparison

Principled comparison of competing models using Bayesian evidence.

GPU acceleration

Leverages JAX for fast MCMC sampling on GPU hardware.

Why Bayesian Methods?

Traditional vs Bayesian Fitting

Traditional Least-Squares:

  • Provides point estimates of parameters

  • Standard errors assume asymptotic normality

  • Difficult to incorporate prior knowledge

  • Model comparison needs separate tests or criteria (e.g., F-tests, AIC)

Bayesian Approach:

  • Provides full posterior distributions

  • Credible intervals have direct probabilistic meaning

  • Naturally incorporates prior information

  • Model comparison via Bayes factors or information criteria

  • Uncertainty propagation is straightforward

Probabilistic Interpretation

Consider fitting a power-law model to viscosity data:

Least-squares result:

K = 5.0 ± 0.3 (standard error)

This means: “If we repeated the experiment many times and computed K each time, about 68% of those estimates would fall within ±0.3 of the true value” (assuming the estimator is approximately normally distributed).

Bayesian result:

K = 5.0 [4.5, 5.5] (95% credible interval)

This means: “Given the data, and under the assumed model and priors, there’s a 95% probability that the true K is between 4.5 and 5.5.”

The Bayesian interpretation is more intuitive and directly answers the question “What do we know about this parameter?”
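Because posterior samples approximate the posterior distribution, probability statements like the one above reduce to counting samples. In the sketch below, synthetic normal draws (the hypothetical arrays K_draws and n_draws) stand in for model.samples['K'] and model.samples['n']; the numbers are illustrative, not results from a real fit. The same counting also answers questions least squares cannot phrase directly, such as the probability that the fluid is shear-thinning (n < 1).

```python
import numpy as np

rng = np.random.default_rng(0)

# Hypothetical stand-ins for posterior draws (illustrative values only)
K_draws = rng.normal(loc=5.0, scale=0.25, size=4000)
n_draws = rng.normal(loc=0.6, scale=0.012, size=4000)

# P(4.5 < K < 5.5 | data): fraction of draws inside the interval
p_K_in_range = np.mean((K_draws > 4.5) & (K_draws < 5.5))

# P(n < 1 | data): posterior probability that the fluid is shear-thinning
p_shear_thinning = np.mean(n_draws < 1.0)

print(f"P(4.5 < K < 5.5) = {p_K_in_range:.3f}")
print(f"P(n < 1) = {p_shear_thinning:.3f}")
```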

Bayesian Workflow

Standard Analysis Pipeline

  1. Specify the Model: Define the mathematical relationship between parameters and data:

    from piblin_jax.bayesian.models import PowerLawModel
    model = PowerLawModel(n_samples=2000)
    
  2. Fit the Model: Run MCMC to sample from the posterior distribution:

    model.fit(shear_rate, viscosity)
    
  3. Check Diagnostics: Confirm that fitting completed, then verify that the MCMC chains have converged (is_fitted only reports that fitting ran; see Convergence Diagnostics for R-hat and ESS):

    if model.is_fitted:
        print("Fit completed; now check R-hat and ESS")
    else:
        print("Warning: model is not fitted")
    
  4. Examine Posterior: Analyze parameter distributions:

    print(model.summary())
    samples = model.samples
    
  5. Make Predictions: Generate predictions with uncertainty:

    predictions = model.predict(new_x, return_uncertainty=True)
    
  6. Compare Models: Evaluate competing models:

    aic = model.aic()
    bic = model.bic()
    

Built-in Models

Power-Law Model

Equation:

\[\eta = K \dot{\gamma}^{n-1}\]

Parameters:

  • K: consistency index (Pa·s^n)

  • n: flow behavior index (dimensionless)

  • sigma: observation noise standard deviation

Use case: Shear-thinning/thickening fluids

Example:

from piblin_jax.bayesian.models import PowerLawModel
import numpy as np

# Data
shear_rate = np.logspace(-1, 2, 20)
viscosity = np.array([...])  # Experimental data

# Fit model
model = PowerLawModel(n_samples=2000, n_warmup=1000)
model.fit(shear_rate, viscosity)

# View results
print(model.summary())

# Extract samples
K_samples = model.samples['K']
n_index_samples = model.samples['n']  # flow behavior index n
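Taking logarithms turns the power law into a straight line, ln η = ln K + (n − 1) ln γ̇. A least-squares fit in log-log space is therefore a cheap sanity check on the data, or a way to see roughly what K and n to expect before running MCMC. This is a minimal sketch with synthetic, noise-free data; power_law and loglog_estimate are hypothetical helpers, not how PowerLawModel fits, and the estimate carries no uncertainty.

```python
import numpy as np

def power_law(shear_rate, K, n):
    """Hypothetical helper: eta = K * shear_rate**(n - 1)."""
    return K * shear_rate ** (n - 1)

def loglog_estimate(shear_rate, viscosity):
    """Hypothetical helper: (K, n) from a straight-line fit of ln(eta) vs ln(shear_rate)."""
    slope, intercept = np.polyfit(np.log(shear_rate), np.log(viscosity), deg=1)
    return np.exp(intercept), slope + 1.0  # K = exp(intercept), n = slope + 1

# Synthetic, noise-free data with K = 5.0 and n = 0.6
shear_rate = np.logspace(-1, 2, 20)
viscosity = power_law(shear_rate, K=5.0, n=0.6)

K_hat, n_hat = loglog_estimate(shear_rate, viscosity)
print(f"K = {K_hat:.3f}, n = {n_hat:.3f}")
```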

Priors:

  • K ~ LogNormal(log(10), 2.0): Weakly informative, with a median of 10 Pa·s^n

  • n ~ Normal(0.8, 0.5): Weakly informative, centered at 0.8 (mildly shear-thinning)

  • sigma ~ HalfNormal(1.0): Observation noise
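Drawing from the priors shows what they imply before any data are seen (a prior predictive check in its simplest form). The sketch below samples the three power-law priors listed above with NumPy, assuming they use the same parameterization as written; check the model source for the exact definitions. Note that NumPy's Generator.lognormal(mean, sigma) takes the mean and standard deviation of the underlying normal, so LogNormal(log(10), 2.0) has a median of 10 but spans several orders of magnitude.

```python
import numpy as np

rng = np.random.default_rng(1)
n_draws = 100_000

# Priors as listed above (parameterization assumed to match the list)
K_prior = rng.lognormal(mean=np.log(10.0), sigma=2.0, size=n_draws)
n_prior = rng.normal(loc=0.8, scale=0.5, size=n_draws)
sigma_prior = np.abs(rng.normal(loc=0.0, scale=1.0, size=n_draws))  # HalfNormal(1.0)

print(f"median K: {np.median(K_prior):.1f} Pa·s^n")
print(f"central 90% of K prior: [{np.percentile(K_prior, 5):.2f}, {np.percentile(K_prior, 95):.0f}]")
print(f"P(n < 1) under the prior: {np.mean(n_prior < 1):.2f}")
print(f"median sigma: {np.median(sigma_prior):.2f}")
```

The wide K range is what "weakly informative" means in practice: the prior rules out absurd values without pinning K near 10.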

Arrhenius Model

Equation:

\[\eta(T) = A \exp\left(\frac{E_a}{RT}\right)\]

Parameters:

  • A: pre-exponential factor (Pa·s)

  • Ea: activation energy (J/mol)

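  • sigma: observation noise

Here R is the universal gas constant (8.314 J/(mol·K)) and T is the absolute temperature in kelvin.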

Use case: Temperature-dependent viscosity

Example:

from piblin_jax.bayesian.models import ArrheniusModel
import numpy as np

# Temperature data in kelvin
temperature = np.array([273, 298, 323, 348, 373])
viscosity = np.array([15.2, 8.5, 5.1, 3.2, 2.1])

# Fit model
model = ArrheniusModel(n_samples=2000)
model.fit(temperature, viscosity)

# Extract activation energy
Ea_mean = np.mean(model.samples['Ea'])
print(f"Activation energy: {Ea_mean/1000:.1f} kJ/mol")
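Taking logarithms linearizes the Arrhenius equation: ln η = ln A + (Ea/R)(1/T). Regressing ln η on 1/T therefore gives a quick point estimate of Ea from the slope, which is useful as a sanity check against the posterior. This minimal sketch uses synthetic, noise-free data; arrhenius and arrhenius_estimate are hypothetical helpers, not the method ArrheniusModel uses.

```python
import numpy as np

R = 8.314  # gas constant, J/(mol·K)

def arrhenius(T, A, Ea):
    """Hypothetical helper: eta = A * exp(Ea / (R * T)), T in kelvin."""
    return A * np.exp(Ea / (R * T))

def arrhenius_estimate(T, viscosity):
    """Hypothetical helper: (A, Ea) from a straight-line fit of ln(eta) against 1/T."""
    slope, intercept = np.polyfit(1.0 / T, np.log(viscosity), deg=1)
    return np.exp(intercept), slope * R  # slope = Ea / R

# Synthetic data: A = 1e-6 Pa·s, Ea = 40 kJ/mol
T = np.array([273.0, 298.0, 323.0, 348.0, 373.0])
eta = arrhenius(T, A=1e-6, Ea=40_000.0)

A_hat, Ea_hat = arrhenius_estimate(T, eta)
print(f"Ea = {Ea_hat / 1000:.1f} kJ/mol")
```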

Priors:

  • A ~ LogNormal(log(1e-3), 5.0): Very weak prior

  • Ea ~ Normal(50000, 20000): Centered at 50 kJ/mol, a typical magnitude for liquids

  • sigma ~ HalfNormal(1.0)

Cross Model

Equation:

\[\eta = \eta_\infty + \frac{\eta_0 - \eta_\infty}{1 + (\lambda \dot{\gamma})^m}\]

Parameters:

  • η₀: zero-shear viscosity (Pa·s)

  • η∞: infinite-shear viscosity (Pa·s)

  • λ: relaxation time (s)

  • m: rate constant (dimensionless)

  • sigma: observation noise

Use case: Polymer melts/solutions with zero-shear plateau

Example:

from piblin_jax.bayesian.models import CrossModel
import numpy as np

# Wide shear rate range
shear_rate = np.logspace(-3, 3, 50)
viscosity = np.array([...])

# Fit model
model = CrossModel(n_samples=2000)
model.fit(shear_rate, viscosity)

# Extract plateaus
eta_0 = np.mean(model.samples['eta_0'])
eta_inf = np.mean(model.samples['eta_inf'])
print(f"Shear-thinning ratio: {eta_0/eta_inf:.1f}x")
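Writing the Cross equation as a plain function makes its limits easy to check: η approaches η₀ as γ̇ → 0, approaches η∞ as γ̇ → ∞, and at γ̇ = 1/λ it sits exactly halfway between the two plateaus. The sketch below uses a hypothetical standalone helper, cross_viscosity (not part of the piblin-jax API), which could also be evaluated on each posterior draw to plot a band of curves.

```python
import numpy as np

def cross_viscosity(shear_rate, eta_0, eta_inf, lam, m):
    """Hypothetical helper for the Cross model:
    eta = eta_inf + (eta_0 - eta_inf) / (1 + (lam * shear_rate)**m)."""
    return eta_inf + (eta_0 - eta_inf) / (1.0 + (lam * shear_rate) ** m)

eta_0, eta_inf, lam, m = 100.0, 1.0, 2.0, 0.7

low = cross_viscosity(1e-8, eta_0, eta_inf, lam, m)            # near the zero-shear plateau
halfway = cross_viscosity(1.0 / lam, eta_0, eta_inf, lam, m)   # lam * shear_rate = 1
high = cross_viscosity(1e8, eta_0, eta_inf, lam, m)            # toward the infinite-shear plateau

print(low, halfway, high)
```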

Priors:

  • η₀ ~ LogNormal(log(100), 2.0)

  • η∞ ~ LogNormal(log(1), 2.0)

  • λ ~ LogNormal(log(1), 2.0)

  • m ~ Normal(0.7, 0.3): Truncated to (0, ∞)

  • sigma ~ HalfNormal(scale based on data)

Carreau-Yasuda Model

Equation:

\[\eta = \eta_\infty + (\eta_0 - \eta_\infty)[1 + (\lambda \dot{\gamma})^a]^{(n-1)/a}\]

Parameters:

  • η₀, η∞: viscosity plateaus (Pa·s)

  • λ: time constant (s)

  • a: transition parameter (dimensionless)

  • n: power-law index (dimensionless)

  • sigma: observation noise

Use case: Complex non-Newtonian behavior with smooth transitions

Example:

from piblin_jax.bayesian.models import CarreauYasudaModel

model = CarreauYasudaModel(n_samples=2000)
model.fit(shear_rate, viscosity)

# Most flexible model, but requires good data
# Compare with simpler models using AIC
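As with the Cross model, a standalone function helps check the Carreau-Yasuda limits. With n = 1 the bracketed term is raised to the power 0, so the viscosity is constant at η₀. Far above 1/λ, the excess viscosity η − η∞ falls off as (λγ̇)^(n−1), which is the power-law regime. The helper carreau_yasuda_viscosity below is a hypothetical sketch for plotting or checking posterior draws, not a piblin-jax function; the second check sets η∞ = 0 so the power-law limit is easy to compare.

```python
import numpy as np

def carreau_yasuda_viscosity(shear_rate, eta_0, eta_inf, lam, a, n):
    """Hypothetical helper for Carreau-Yasuda:
    eta = eta_inf + (eta_0 - eta_inf) * (1 + (lam * shear_rate)**a)**((n - 1) / a)."""
    return eta_inf + (eta_0 - eta_inf) * (1.0 + (lam * shear_rate) ** a) ** ((n - 1.0) / a)

shear_rate = np.logspace(-3, 3, 7)

# n = 1: no shear thinning, so the viscosity equals eta_0 at every shear rate
newtonian = carreau_yasuda_viscosity(shear_rate, 1000.0, 0.1, 1.0, 2.0, 1.0)

# Far above 1/lam (with eta_inf = 0), eta approaches eta_0 * (lam * shear_rate)**(n - 1)
g = 1e6
eta = carreau_yasuda_viscosity(g, 1000.0, 0.0, 1.0, 2.0, 0.5)
ratio = eta / (1000.0 * (1.0 * g) ** (0.5 - 1.0))

print(newtonian, ratio)
```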

Priors:

  • η₀ ~ LogNormal(log(1000), 2.0)

  • η∞ ~ LogNormal(log(0.1), 2.0)

  • λ ~ LogNormal(log(1), 2.0)

  • a ~ LogNormal(log(2), 1.0)

  • n ~ Normal(0.5, 0.3): Truncated to (0, 1)

  • sigma ~ HalfNormal(scale based on data)

MCMC Sampling

How MCMC Works

Markov Chain Monte Carlo (MCMC) is an algorithm for sampling from probability distributions that are difficult to sample from directly.

Key concepts:

  1. Markov Chain: Sequence of samples where each depends only on the previous one

  2. Stationary Distribution: Target distribution (posterior) that the chain converges to

  3. Burn-in (warmup): Initial samples discarded before convergence

  4. Thinning: Optional subsampling to reduce autocorrelation
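To make these concepts concrete, here is a minimal random-walk Metropolis sampler in NumPy. It is deliberately simpler than NUTS (no gradients, fixed step size) and targets a toy one-dimensional log-density rather than a rheology posterior. The function random_walk_metropolis is a hypothetical illustration; it only shows the Markov chain, the warmup draws that are discarded, and optional thinning.

```python
import numpy as np

def random_walk_metropolis(log_density, x0, n_samples, n_warmup, step=1.0, thin=1, seed=0):
    """Hypothetical illustration: random-walk Metropolis for a 1-D log-density."""
    rng = np.random.default_rng(seed)
    x, log_p = x0, log_density(x0)
    kept = []
    for i in range(n_warmup + n_samples * thin):
        proposal = x + step * rng.normal()  # depends only on the current state (Markov)
        log_p_new = log_density(proposal)
        # Accept with probability min(1, p(proposal) / p(x))
        if np.log(rng.uniform()) < log_p_new - log_p:
            x, log_p = proposal, log_p_new
        # Discard warmup draws, then keep every `thin`-th draw
        if i >= n_warmup and (i - n_warmup) % thin == 0:
            kept.append(x)
    return np.array(kept)

# Toy target: standard normal log-density (up to a constant), started far from the mode
draws = random_walk_metropolis(lambda x: -0.5 * x**2, x0=5.0,
                               n_samples=5000, n_warmup=1000, thin=2)
print(draws.mean(), draws.std())
```

Starting at x0 = 5.0 shows why warmup matters: the early draws drift toward the mode and would bias the estimates if they were kept.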

NumPyro’s NUTS sampler:

piblin-jax uses the No-U-Turn Sampler (NUTS), a variant of Hamiltonian Monte Carlo:

  • Automatically tunes the step size during warmup and chooses the trajectory length adaptively

  • Typically more efficient than random-walk Metropolis-Hastings

  • Typically requires fewer samples for the same accuracy

  • Scales well to models with many parameters

Sampling Parameters

n_samples (default: 1000)

Number of posterior samples to draw after warmup. More samples give a better posterior approximation but take longer.

  • 1000-2000: Good for most applications

  • 3000-5000: High-precision uncertainty estimates

  • 10000+: Publication-quality results

n_warmup (default: 1000)

Number of warmup/burn-in samples to discard. They are used to tune the sampler parameters and to reach the stationary distribution.

  • 500-1000: Usually sufficient

  • 2000+: Complex models or difficult posteriors

n_chains (default: 1)

Number of independent MCMC chains to run. Multiple chains help diagnose convergence issues.

  • 1: Fast, but less diagnostic information

  • 4: Standard for convergence diagnostics (R-hat, ESS)

Example:

# High-quality fit with convergence diagnostics
model = PowerLawModel(
    n_samples=2000,
    n_warmup=1000,
    n_chains=4
)
model.fit(shear_rate, viscosity)

Convergence Diagnostics

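is_fitted confirms that fitting ran to completion; it is not a convergence test. Check it first, then inspect R-hat and ESS:

if not model.is_fitted:
    print("Warning: model is not fitted")
    # Re-run fit(); if problems persist, increase n_warmup
    # and check for model misspecification

R-hat statistic (Gelman-Rubin):

Compares the variance between chains with the variance within each chain. It is most informative when n_chains > 1.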

  • R-hat ≈ 1.0: Good convergence

  • R-hat > 1.1: Poor convergence, increase warmup
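The sketch below implements split R-hat, the form described in Gelman et al., Bayesian Data Analysis (3rd ed.): each chain is cut in half so that drift within a chain also shows up as disagreement. It expects draws arranged as (n_chains, n_draws); whether model.samples keeps the chain dimension is an assumption to check for your version of piblin-jax. split_rhat is a hypothetical helper; NumPyro's numpyro.diagnostics.split_gelman_rubin computes the same statistic.

```python
import numpy as np

def split_rhat(draws):
    """Hypothetical helper: split R-hat for draws of shape (n_chains, n_draws)."""
    draws = np.asarray(draws, dtype=float)
    half = draws.shape[1] // 2
    # Split every chain into first and second halves -> (2 * n_chains, half)
    split = np.concatenate([draws[:, :half], draws[:, -half:]], axis=0)
    n = split.shape[1]
    W = split.var(axis=1, ddof=1).mean()     # mean within-chain variance
    B_over_n = split.mean(axis=1).var(ddof=1)  # variance of chain means (= B / n)
    var_plus = (n - 1) / n * W + B_over_n    # pooled estimate of posterior variance
    return np.sqrt(var_plus / W)

rng = np.random.default_rng(0)
mixed = rng.normal(size=(4, 1000))                       # 4 chains sampling the same target
stuck = mixed + np.array([[0.0], [0.0], [0.0], [3.0]])   # one chain stuck elsewhere
print(split_rhat(mixed), split_rhat(stuck))
```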

Effective sample size (ESS):

Accounts for autocorrelation in samples.

  • ESS ≈ n_samples: Low autocorrelation (good)

  • ESS << n_samples: High autocorrelation (increase samples)
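ESS divides the number of draws by the integrated autocorrelation time τ = 1 + 2 Σ ρ_t, where ρ_t is the lag-t autocorrelation. The sketch below estimates ρ_t with an FFT and truncates the sum with Geyer's initial positive sequence (adding adjacent pairs of autocorrelations until a pair turns negative). effective_sample_size here is a hypothetical single-chain helper; NumPyro's numpyro.diagnostics.effective_sample_size provides a multi-chain version.

```python
import numpy as np

def effective_sample_size(chain):
    """Hypothetical helper: single-chain ESS with Geyer's initial positive sequence."""
    x = np.asarray(chain, dtype=float)
    n = x.size
    x = x - x.mean()
    # Autocovariance at lags 0..n-1 via zero-padded FFT (biased estimator, divides by n)
    f = np.fft.rfft(x, n=2 * n)
    acov = np.fft.irfft(f * np.conj(f), n=2 * n)[:n] / n
    rho = acov / acov[0]
    # tau = 1 + 2 * sum_{t>=1} rho_t, stopping when a pair rho_2k + rho_2k+1 turns negative
    tau = -1.0
    for k in range(0, n - 1, 2):
        pair = rho[k] + rho[k + 1]
        if pair < 0:
            break
        tau += 2.0 * pair
    return n / tau

rng = np.random.default_rng(0)
iid = rng.normal(size=5000)

# AR(1) chain with strong autocorrelation (phi = 0.9), like a poorly mixing sampler
ar = np.zeros(5000)
for t in range(1, ar.size):
    ar[t] = 0.9 * ar[t - 1] + rng.normal()

print(effective_sample_size(iid), effective_sample_size(ar))
```

For an AR(1) chain the theoretical τ is (1 + φ)/(1 − φ) = 19, so 5000 draws carry roughly the information of a few hundred independent ones.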

Posterior Analysis

Summary Statistics

The summary() method provides key statistics:

summary = model.summary()
print(summary)

Output:

Parameter Posterior Summary:
----------------------------
K: mean=5.02, std=0.15, 95% CI=[4.73, 5.31]
n: mean=0.598, std=0.012, 95% CI=[0.575, 0.621]
sigma: mean=0.51, std=0.09, 95% CI=[0.38, 0.71]

Interpreting statistics:

  • mean: Posterior mean (Bayesian point estimate)

  • std: Posterior standard deviation (uncertainty measure)

  • 95% CI: 95% credible interval (equal-tailed: from the 2.5th to the 97.5th percentile)

Accessing Samples

Direct access to posterior samples:

samples = model.samples

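# Extract specific parameter
K_samples = samples['K']  # Array of shape (n_samples,)
n_index_samples = samples['n']  # flow behavior index n

# Custom statistics
K_median = np.median(K_samples)
counts, edges = np.histogram(K_samples, bins=50)
i_max = np.argmax(counts)
K_mode = 0.5 * (edges[i_max] + edges[i_max + 1])  # center of the fullest bin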

# Quantiles
K_quantiles = np.percentile(K_samples, [2.5, 50, 97.5])

Credible Intervals

Equal-tailed interval (ETI):

Default method. The 2.5th and 97.5th percentiles give the 95% CI:

lower = np.percentile(K_samples, 2.5)
upper = np.percentile(K_samples, 97.5)

Highest density interval (HDI):

Shortest interval containing 95% of probability mass. Preferred for skewed distributions:

from piblin_jax.bayesian.utils import compute_hdi
lower, upper = compute_hdi(K_samples, credible_mass=0.95)
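The idea behind an HDI is easy to implement from samples: sort the draws and find the narrowest window that contains the required fraction of them. The function hdi_from_samples below is a hypothetical sketch of that idea, not the source of compute_hdi. An exponential distribution stands in for a right-skewed posterior; there the HDI hugs the mode near zero and is shorter than the equal-tailed interval.

```python
import numpy as np

def hdi_from_samples(samples, credible_mass=0.95):
    """Hypothetical helper: narrowest interval containing `credible_mass` of the draws."""
    x = np.sort(np.asarray(samples, dtype=float))
    n = x.size
    k = int(np.ceil(credible_mass * n))     # number of draws inside the interval
    widths = x[k - 1:] - x[: n - k + 1]     # width of every window of k consecutive draws
    i = int(np.argmin(widths))
    return x[i], x[i + k - 1]

rng = np.random.default_rng(0)
skewed = rng.exponential(scale=1.0, size=20_000)   # right-skewed stand-in posterior

hdi_lo, hdi_hi = hdi_from_samples(skewed)
eti_lo, eti_hi = np.percentile(skewed, [2.5, 97.5])
print(f"HDI: [{hdi_lo:.3f}, {hdi_hi:.3f}]  ETI: [{eti_lo:.3f}, {eti_hi:.3f}]")
```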

Parameter Correlations

Posterior samples reveal parameter correlations:

import matplotlib.pyplot as plt

# Joint distribution
plt.scatter(samples['K'], samples['n'], alpha=0.3, s=1)
plt.xlabel('K')
plt.ylabel('n')
plt.title('Joint Posterior Distribution')

# Correlation coefficient
corr = np.corrcoef(samples['K'], samples['n'])[0, 1]
print(f"Correlation: {corr:.3f}")

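High correlation means the data constrain a combination of the parameters more tightly than either parameter alone, so they are hard to identify independently (common in complex models). Strongly correlated posteriors also tend to slow MCMC sampling.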
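For the power law, much of the K-n correlation comes from where K is defined: K is the viscosity at γ̇ = 1 s⁻¹, which may lie far from the measured range. A common remedy is to write the model around a reference shear rate near the middle of the data, η = η_ref (γ̇/γ̇_ref)^(n−1). The sketch below illustrates this with ordinary least squares on ln η, where the parameter covariance is proportional to (XᵀX)⁻¹. Both the helper coefficient_correlation and the reference-rate form are illustrative assumptions, not built-in piblin-jax features, and the built-in models may use a different noise model.

```python
import numpy as np

def coefficient_correlation(x):
    """Hypothetical helper: correlation of intercept and slope estimates in OLS of y on [1, x]."""
    X = np.column_stack([np.ones_like(x), x])
    cov = np.linalg.inv(X.T @ X)  # parameter covariance up to the noise variance
    return cov[0, 1] / np.sqrt(cov[0, 0] * cov[1, 1])

shear_rate = np.logspace(1, 3, 20)  # data between 10 and 1000 1/s
log_rate = np.log(shear_rate)

# Intercept = ln K (reference point at shear_rate = 1, outside the data)
corr_K = coefficient_correlation(log_rate)

# Intercept = ln eta_ref at the geometric-mean shear rate (center of the data)
corr_ref = coefficient_correlation(log_rate - np.mean(log_rate))

print(f"corr(ln K, n) = {corr_K:.2f}, corr(ln eta_ref, n) = {corr_ref:.2f}")
```

Since the slope is n − 1, its correlation with the intercept equals the correlation with n. Centering removes it entirely in this linearized picture.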

Model Comparison

Information Criteria

Akaike Information Criterion (AIC):

Lower is better. Balances fit quality and model complexity:

aic = model.aic()

Bayesian Information Criterion (BIC):

Similar to AIC, but each parameter is penalized by ln(N) instead of 2 (N = number of data points), so complexity is penalized more heavily once N ≥ 8:

bic = model.bic()
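Both criteria combine the maximized log-likelihood ln L̂ with a penalty on the number of parameters k: AIC = 2k − 2 ln L̂ and BIC = k ln(N) − 2 ln L̂. How piblin-jax obtains ln L̂ from the posterior (for example, at the posterior mean or at the best sample) is not specified here, so the sketch below only evaluates the two definitions for a Gaussian likelihood; gaussian_log_likelihood, aic, and bic are hypothetical helpers.

```python
import numpy as np

def gaussian_log_likelihood(residuals, sigma):
    """Hypothetical helper: sum of Normal(0, sigma) log-densities of the residuals."""
    r = np.asarray(residuals, dtype=float)
    return np.sum(-0.5 * np.log(2 * np.pi * sigma**2) - 0.5 * (r / sigma) ** 2)

def aic(log_lik, k):
    """Hypothetical helper: Akaike Information Criterion."""
    return 2 * k - 2 * log_lik

def bic(log_lik, k, n_points):
    """Hypothetical helper: Bayesian Information Criterion."""
    return k * np.log(n_points) - 2 * log_lik

# Example: 20 points, power law with K, n, sigma -> k = 3
residuals = np.zeros(20)  # a perfect fit, so the numbers are easy to check by hand
ll = gaussian_log_likelihood(residuals, sigma=1.0)
print(aic(ll, k=3), bic(ll, k=3, n_points=20))
```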

Comparing models:

from piblin_jax.bayesian.models import PowerLawModel, CrossModel

# Fit both models
power_law = PowerLawModel(n_samples=2000)
power_law.fit(shear_rate, viscosity)

cross = CrossModel(n_samples=2000)
cross.fit(shear_rate, viscosity)

# Compare
aic_pl = power_law.aic()
aic_cross = cross.aic()
print(f"Power-law AIC: {aic_pl:.1f}")
print(f"Cross AIC: {aic_cross:.1f}")

delta_aic = abs(aic_pl - aic_cross)
preferred = "Power-law" if aic_pl < aic_cross else "Cross"  # lower AIC is better
if delta_aic < 2:
    print("Models are essentially equivalent")
elif delta_aic < 10:
    print(f"Moderate evidence for the {preferred} model")
else:
    print(f"Strong evidence for the {preferred} model")

Bayes Factors

More rigorous model comparison using marginal likelihoods.

Note: Requires setting enable_bayes_factor=True during fitting:

model = PowerLawModel(n_samples=2000)
model.fit(shear_rate, viscosity, enable_bayes_factor=True)

# Get log marginal likelihood
log_ml = model.log_marginal_likelihood
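Given log marginal likelihoods for two models fitted to the same data, the Bayes factor is BF₁₂ = exp(ln p(y | M₁) − ln p(y | M₂)); working with the difference of logs avoids overflow. The hypothetical helper below (not part of piblin-jax) also maps 2 ln BF onto the descriptive scale of Kass and Raftery (1995), one common convention for reading Bayes factors. The input values in the example are made up for illustration.

```python
def bayes_factor_summary(log_ml_1, log_ml_2):
    """Hypothetical helper: log Bayes factor of model 1 over model 2, plus an evidence label."""
    log_bf = log_ml_1 - log_ml_2
    two_log_bf = 2.0 * abs(log_bf)
    # Kass & Raftery (1995) scale on 2 ln BF
    if two_log_bf < 2:
        label = "not worth more than a bare mention"
    elif two_log_bf < 6:
        label = "positive"
    elif two_log_bf < 10:
        label = "strong"
    else:
        label = "very strong"
    favored = "model 1" if log_bf >= 0 else "model 2"
    return log_bf, f"{label} evidence for {favored}"

# e.g., log_ml_1 = power_law.log_marginal_likelihood, log_ml_2 = cross.log_marginal_likelihood
log_bf, verdict = bayes_factor_summary(-120.4, -124.9)
print(log_bf, verdict)
```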

Uncertainty Propagation

Dataset Integration

Add uncertainty to datasets:

from piblin_jax.data.datasets import OneDimensionalDataset

# Create dataset
dataset = OneDimensionalDataset(
    independent_variable_data=shear_rate,
    dependent_variable_data=viscosity
)

# Fit Bayesian model
model = PowerLawModel(n_samples=2000)
model.fit(shear_rate, viscosity)

# Add uncertainty to dataset
dataset_with_unc = dataset.with_uncertainty(
    model=model,
    n_samples=1000,
    keep_samples=True
)

# Check status
print(f"Has uncertainty: {dataset_with_unc.has_uncertainty}")

Transform Propagation

Propagate uncertainty through transforms:

from piblin_jax.transform.dataset import GaussianSmoothing

# Apply transform with uncertainty propagation
smoother = GaussianSmoothing(sigma=2.0)
smoothed = smoother.apply_to(
    dataset_with_unc,
    propagate_uncertainty=True
)

# Uncertainty is now propagated
lower, upper = smoothed.get_credible_intervals(level=0.95)
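One generic way to propagate uncertainty through a transform such as smoothing is to apply it to every posterior predictive curve and then summarize pointwise. The sketch below does this with a hand-rolled Gaussian kernel on synthetic curves. It illustrates the idea only and is not how GaussianSmoothing is implemented; the helper gaussian_smooth, its normalization, and its edge handling are assumptions.

```python
import numpy as np

def gaussian_smooth(y, sigma):
    """Hypothetical helper: smooth each row of y with a normalized Gaussian kernel (width in points)."""
    radius = int(np.ceil(4 * sigma))
    offsets = np.arange(-radius, radius + 1)
    kernel = np.exp(-0.5 * (offsets / sigma) ** 2)
    kernel /= kernel.sum()
    # Edge-pad each curve so the smoothed output keeps the input length
    padded = np.pad(y, ((0, 0), (radius, radius)), mode="edge")
    return np.array([np.convolve(row, kernel, mode="valid") for row in padded])

rng = np.random.default_rng(0)
x = np.linspace(0.0, 1.0, 100)
# Synthetic stand-in for posterior predictive curves: 500 draws x 100 points
curves = np.sin(2 * np.pi * x) + 0.1 * rng.normal(size=(500, x.size))

smoothed = gaussian_smooth(curves, sigma=2.0)

# Pointwise 95% credible band before and after smoothing
raw_lower, raw_upper = np.percentile(curves, [2.5, 97.5], axis=0)
lower, upper = np.percentile(smoothed, [2.5, 97.5], axis=0)
```

Because smoothing averages neighboring points, independent noise partly cancels and the band narrows; propagating draw by draw captures this automatically.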

Monte Carlo Propagation

For custom operations, use Monte Carlo: evaluate the function on every posterior draw, then summarize the results:

# Function to propagate uncertainty through (vectorized)
def my_operation(K, n, shear_rate):
    return K * shear_rate ** (n - 1)

new_shear_rate = np.logspace(-1, 2, 50)  # example evaluation points

# Broadcast: one row per posterior draw, one column per shear rate
results = np.asarray(my_operation(
    samples['K'][:, None],
    samples['n'][:, None],
    new_shear_rate[None, :],
))

# Compute uncertainty
mean_result = np.mean(results, axis=0)
lower_result = np.percentile(results, 2.5, axis=0)
upper_result = np.percentile(results, 97.5, axis=0)

Best Practices

Model Selection

  1. Start simple: Begin with power-law or Arrhenius

  2. Check residuals: Look for systematic patterns

  3. Add complexity: Move to Cross or Carreau-Yasuda if needed

  4. Compare formally: Use AIC/BIC to justify complexity

  5. Physical meaning: Prefer models with interpretable parameters

Prior Selection

Weakly informative priors (default in piblin-jax):

  • Constrain parameters to physically reasonable ranges

  • Don’t dominate the data

  • Help with numerical stability

Custom priors:

For domain expertise, modify priors:

# Example: Strong prior on power-law index
# (requires model subclassing, see API docs)

Computational Efficiency

Start small:

Test with n_samples=500, n_warmup=500 first

Increase gradually:

Double n_samples until the posterior summaries stop changing appreciably

Use GPU:

Install JAX with GPU support, which can substantially speed up sampling for larger models and datasets:

pip install "jax[cuda12]"  # NVIDIA
pip install jax-metal       # Apple Silicon (experimental)

Parallel chains:

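Use n_chains=4 on a multi-core CPU. If chains are run with NumPyro's parallel chain method on CPU, they only run in parallel when multiple host devices are exposed (numpyro.set_host_device_count) before JAX initializes; otherwise they run sequentially.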

Common Issues

Poor convergence (R-hat > 1.1):

  • Increase n_warmup

  • Try different initial values

  • Simplify the model

Low ESS:

  • Increase n_samples

  • Check for high parameter correlation

  • Consider reparameterization

Unrealistic posteriors:

  • Check data scaling (avoid extreme values)

  • Verify model is appropriate for data

  • Inspect prior sensitivity

Slow sampling:

  • Reduce n_samples for initial exploration

  • Use GPU acceleration

  • Consider simpler model

Advanced Topics

Custom Models

Subclass BayesianModel to create custom models. See API documentation for details.

Hierarchical Models

Model variation across groups (e.g., multiple experiments). Requires custom NumPyro model definition.

Model Averaging

Combine predictions from multiple models, weighted by their relative support (here, AIC-based weights):

# Fit multiple models
models = [power_law, cross, carreau_yasuda]
weights = compute_model_weights(models)  # hypothetical helper: AIC-based weights

# Weighted average predictions
predictions = sum(w * m.predict(x) for w, m in zip(weights, models))
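The compute_model_weights call above is not defined on this page. One standard AIC-based choice is Akaike weights: with Δᵢ = AICᵢ − min AIC, each model gets wᵢ = exp(−Δᵢ/2) / Σⱼ exp(−Δⱼ/2). The hypothetical helper below takes a list of AIC values rather than model objects; the AIC values in the example are made up.

```python
import numpy as np

def akaike_weights(aic_values):
    """Hypothetical helper: Akaike weights, exp(-delta/2) normalized to sum to 1."""
    aic = np.asarray(aic_values, dtype=float)
    delta = aic - aic.min()        # subtracting the minimum avoids underflow
    rel = np.exp(-0.5 * delta)     # relative likelihood of each model
    return rel / rel.sum()

weights = akaike_weights([102.3, 100.3, 110.0])
print(weights.round(3))
```

A model 2 AIC units worse than the best receives e⁻¹ ≈ 0.37 times its weight, so models with ΔAIC above about 10 contribute almost nothing to the average.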

References

Bayesian Statistics:

  • Gelman, A., et al. (2013). Bayesian Data Analysis, 3rd Edition. Chapman and Hall/CRC.

  • McElreath, R. (2020). Statistical Rethinking, 2nd Edition. CRC Press.

MCMC Methods:

  • Hoffman, M.D., & Gelman, A. (2014). “The No-U-Turn Sampler: Adaptively Setting Path Lengths in Hamiltonian Monte Carlo.” Journal of Machine Learning Research, 15, 1593-1623.

NumPyro:

  • Phan, D., et al. (2019). “Composable Effects for Flexible and Accelerated Probabilistic Programming in NumPyro.” arXiv:1912.11554.

Next Steps