Uncertainty Quantification
This guide covers uncertainty quantification in piblin-jax, from the theoretical foundations to the practical implementation details.
Overview
piblin-jax provides Bayesian uncertainty quantification through NumPyro, a probabilistic programming library built on JAX. This enables:
- Full posterior distributions
Not just point estimates, but complete probability distributions over parameter values given your data.
- Credible intervals
Rigorous uncertainty ranges with clear probabilistic interpretation (e.g., “95% probability the parameter is in this range”).
- Uncertainty propagation
Automatically propagate uncertainty through transforms and computations.
- Model comparison
Principled comparison of competing models using Bayesian evidence.
- GPU acceleration
Leverages JAX for fast MCMC sampling on GPU hardware.
Why Bayesian Methods?
Traditional vs Bayesian Fitting
Traditional Least-Squares:
Provides point estimates of parameters
Standard errors assume asymptotic normality
Difficult to incorporate prior knowledge
No natural way to compare models
Bayesian Approach:
Provides full posterior distributions
Credible intervals have direct probabilistic meaning
Naturally incorporates prior information
Model comparison via Bayes factors or information criteria
Uncertainty propagation is straightforward
Probabilistic Interpretation
Consider fitting a power-law model to viscosity data:
- Least-squares result:
K = 5.0 ± 0.3 (standard error)
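This means: “If we repeated the experiment many times and computed K each time, about 68% of those estimates would fall within ±0.3 of the true value.” It is a statement about the estimation procedure under repeated sampling, not about K given the data in hand.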
- Bayesian result:
K = 5.0 [4.5, 5.5] (95% credible interval)
This means: “Given the data, there’s a 95% probability that the true K is between 4.5 and 5.5.”
The Bayesian interpretation is more intuitive and directly answers the question “What do we know about this parameter?”
Bayesian Workflow
Standard Analysis Pipeline
1. Specify the Model: Define the mathematical relationship between parameters and data:
from piblin_jax.bayesian.models import PowerLawModel
model = PowerLawModel(n_samples=2000)
2. Fit the Model: Run MCMC to sample from the posterior distribution:
model.fit(shear_rate, viscosity)
3. Check Diagnostics: Verify that MCMC chains have converged:
if model.is_fitted:
    print("Sampling successful")
else:
    print("Warning: Check convergence diagnostics")
4. Examine Posterior: Analyze parameter distributions:
print(model.summary())
samples = model.samples
5. Make Predictions: Generate predictions with uncertainty:
predictions = model.predict(new_x, return_uncertainty=True)
6. Compare Models: Evaluate competing models:
aic = model.aic()
bic = model.bic()
Built-in Models
Power-Law Model
Equation: $\eta(\dot{\gamma}) = K \, \dot{\gamma}^{\,n-1}$
Parameters:
K: consistency index (Pa·s^n)
n: flow behavior index (dimensionless)
sigma: observation noise standard deviation
Use case: Shear-thinning/thickening fluids
Example:
from piblin_jax.bayesian.models import PowerLawModel
import numpy as np
# Data
shear_rate = np.logspace(-1, 2, 20)
viscosity = np.array([...]) # Experimental data
# Fit model
model = PowerLawModel(n_samples=2000, n_warmup=1000)
model.fit(shear_rate, viscosity)
# View results
print(model.summary())
# Extract samples
K_samples = model.samples['K']
n_samples = model.samples['n']
Priors:
K ~ LogNormal(log(10), 2.0): Weakly informative, with median 10 Pa·s^n
n ~ Normal(0.8, 0.5): Weakly informative, centered in the shear-thinning regime (n < 1)
sigma ~ HalfNormal(1.0): Observation noise
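To see what these priors imply before any data are used, one option is to draw from them directly and push the draws through the power-law relation. The sketch below uses plain NumPy rather than piblin-jax. The seed, the variable names, and the choice of shear rates are illustrative assumptions, and observation noise is left out because the likelihood form is not stated in this guide.

```python
import numpy as np

rng = np.random.default_rng(0)
n_draws = 5000

# Draw from the documented priors
K = rng.lognormal(mean=np.log(10.0), sigma=2.0, size=n_draws)  # LogNormal(log(10), 2.0)
n = rng.normal(loc=0.8, scale=0.5, size=n_draws)               # Normal(0.8, 0.5)

# Prior predictive viscosity, eta = K * shear_rate**(n - 1), at a few shear rates
shear_rate = np.logspace(-1, 2, 5)
eta_prior = K[:, None] * shear_rate[None, :] ** (n[:, None] - 1.0)

# Median and central 95% band of the prior predictive curves
prior_median = np.median(eta_prior, axis=0)
prior_lo, prior_hi = np.percentile(eta_prior, [2.5, 97.5], axis=0)

for rate, lo, med, hi in zip(shear_rate, prior_lo, prior_median, prior_hi):
    print(f"shear rate {rate:8.2f}: {lo:.3g} < {med:.3g} < {hi:.3g}")
```

A band that spans several orders of magnitude is what "weakly informative" looks like in practice: the priors rule out absurd values without pinning the curve down.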
Arrhenius Model
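Equation (standard Arrhenius form): $\eta(T) = A \exp\left(\frac{E_a}{R T}\right)$, where R = 8.314 J/(mol·K) is the gas constant and T is the absolute temperature in K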
Parameters:
A: pre-exponential factor (Pa·s)
Ea: activation energy (J/mol)
sigma: observation noise
Use case: Temperature-dependent viscosity
Example:
from piblin_jax.bayesian.models import ArrheniusModel
# Temperature data (K)
temperature = np.array([273, 298, 323, 348, 373])
viscosity = np.array([15.2, 8.5, 5.1, 3.2, 2.1])
# Fit model
model = ArrheniusModel(n_samples=2000)
model.fit(temperature, viscosity)
# Extract activation energy
Ea_mean = np.mean(model.samples['Ea'])
print(f"Activation energy: {Ea_mean/1000:.1f} kJ/mol")
Priors:
A ~ LogNormal(log(1e-3), 5.0): Very weak prior
Ea ~ Normal(50000, 20000): Centered at typical liquid Ea
sigma ~ HalfNormal(1.0)
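A quick least-squares estimate is a useful sanity check on the posterior mean of Ea. The sketch below assumes the standard Arrhenius form η = A·exp(Ea/(R·T)), which becomes a straight line in 1/T after taking logarithms. It reuses the example data above and plain NumPy; the names `Ea_ls` and `A_ls` are illustrative.

```python
import numpy as np

R = 8.314  # gas constant, J/(mol*K)

temperature = np.array([273, 298, 323, 348, 373], dtype=float)  # K
viscosity = np.array([15.2, 8.5, 5.1, 3.2, 2.1])                # Pa*s

# ln(eta) = ln(A) + (Ea / R) * (1 / T): fit a line in 1/T
slope, intercept = np.polyfit(1.0 / temperature, np.log(viscosity), deg=1)

Ea_ls = slope * R         # activation energy, J/mol
A_ls = np.exp(intercept)  # pre-exponential factor, Pa*s

print(f"Least-squares Ea: {Ea_ls / 1000:.1f} kJ/mol")
print(f"Least-squares A:  {A_ls:.3g} Pa*s")
```

If the Bayesian estimate sits far from this value, it is worth checking whether the prior on Ea is pulling the posterior, since five data points carry limited information.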
Cross Model
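Equation (standard Cross form): $\eta(\dot{\gamma}) = \eta_\infty + \frac{\eta_0 - \eta_\infty}{1 + (\lambda \dot{\gamma})^m}$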
Parameters:
η₀: zero-shear viscosity (Pa·s)
η∞: infinite-shear viscosity (Pa·s)
λ: relaxation time (s)
m: rate constant (dimensionless)
sigma: observation noise
Use case: Polymer melts/solutions with zero-shear plateau
Example:
from piblin_jax.bayesian.models import CrossModel
# Wide shear rate range
shear_rate = np.logspace(-3, 3, 50)
viscosity = np.array([...])
# Fit model
model = CrossModel(n_samples=2000)
model.fit(shear_rate, viscosity)
# Extract plateaus
eta_0 = np.mean(model.samples['eta_0'])
eta_inf = np.mean(model.samples['eta_inf'])
print(f"Shear-thinning ratio: {eta_0/eta_inf:.1f}x")
Priors:
η₀ ~ LogNormal(log(100), 2.0)
η∞ ~ LogNormal(log(1), 2.0)
λ ~ LogNormal(log(1), 2.0)
m ~ Normal(0.7, 0.3): Constrained to (0, ∞)
sigma ~ HalfNormal(scale based on data)
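To build intuition for the parameters, the Cross curve can be evaluated directly. This is a minimal sketch that assumes the standard Cross form, η = η∞ + (η₀ − η∞) / (1 + (λ·γ̇)^m). The function name `cross_viscosity` is hypothetical and not part of piblin-jax, and the parameter values are simply the prior medians listed above.

```python
import numpy as np

def cross_viscosity(shear_rate, eta_0, eta_inf, lam, m):
    """Standard Cross model: eta_inf + (eta_0 - eta_inf) / (1 + (lam * rate)**m)."""
    shear_rate = np.asarray(shear_rate, dtype=float)
    return eta_inf + (eta_0 - eta_inf) / (1.0 + (lam * shear_rate) ** m)

shear_rate = np.logspace(-3, 3, 50)
eta = cross_viscosity(shear_rate, eta_0=100.0, eta_inf=1.0, lam=1.0, m=0.7)

# At lam * rate = 1 the viscosity is halfway between the two plateaus
eta_mid = cross_viscosity(1.0, eta_0=100.0, eta_inf=1.0, lam=1.0, m=0.7)
print(f"eta at lowest rate:  {eta[0]:.1f} Pa*s")
print(f"eta at lam*rate = 1: {eta_mid:.1f} Pa*s")
print(f"eta at highest rate: {eta[-1]:.1f} Pa*s")
```

Note that with m = 0.7 the curve has not reached either plateau even over six decades of shear rate, which is why η₀ and η∞ are poorly constrained unless the data cover a wide range.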
Carreau-Yasuda Model
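Equation (standard Carreau-Yasuda form): $\eta(\dot{\gamma}) = \eta_\infty + (\eta_0 - \eta_\infty)\left[1 + (\lambda \dot{\gamma})^a\right]^{(n-1)/a}$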
Parameters:
η₀, η∞: viscosity plateaus (Pa·s)
λ: time constant (s)
a: transition parameter (dimensionless)
n: power-law index (dimensionless)
sigma: observation noise
Use case: Complex non-Newtonian behavior with smooth transitions
Example:
from piblin_jax.bayesian.models import CarreauYasudaModel
model = CarreauYasudaModel(n_samples=2000)
model.fit(shear_rate, viscosity)
# Most flexible model, but requires good data
# Compare with simpler models using AIC
Priors:
η₀ ~ LogNormal(log(1000), 2.0)
η∞ ~ LogNormal(log(0.1), 2.0)
λ ~ LogNormal(log(1), 2.0)
a ~ LogNormal(log(2), 1.0)
n ~ Normal(0.5, 0.3): Constrained to (0, 1)
sigma ~ HalfNormal(scale based on data)
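The roles of a and n are easier to see numerically. The sketch below assumes the standard Carreau-Yasuda form, η = η∞ + (η₀ − η∞)·[1 + (λ·γ̇)^a]^((n−1)/a). The function name is hypothetical, and the parameter values are the prior medians listed above (with n at its prior mean).

```python
import numpy as np

def carreau_yasuda_viscosity(shear_rate, eta_0, eta_inf, lam, a, n):
    """Standard Carreau-Yasuda model for shear-dependent viscosity."""
    shear_rate = np.asarray(shear_rate, dtype=float)
    bracket = 1.0 + (lam * shear_rate) ** a
    return eta_inf + (eta_0 - eta_inf) * bracket ** ((n - 1.0) / a)

shear_rate = np.logspace(-3, 3, 50)
eta = carreau_yasuda_viscosity(
    shear_rate, eta_0=1000.0, eta_inf=0.1, lam=1.0, a=2.0, n=0.5
)

# Local slope on log-log axes; in the power-law region it approaches n - 1
slope = np.gradient(np.log(eta), np.log(shear_rate))
print(f"low-rate viscosity:  {eta[0]:.1f} Pa*s")
print(f"high-rate log slope: {slope[-2]:.3f}")
```

The parameter a only changes the shape of the curve near λ·γ̇ ≈ 1, so it is identifiable only when the data resolve the transition between the plateau and the power-law region.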
MCMC Sampling
How MCMC Works
Markov Chain Monte Carlo (MCMC) is an algorithm for sampling from probability distributions that are difficult to sample from directly.
Key concepts:
Markov Chain: Sequence of samples where each depends only on the previous one
Stationary Distribution: Target distribution (posterior) that chain converges to
Burn-in (warmup): Initial samples discarded before convergence
Thinning: Optional subsampling to reduce autocorrelation
NumPyro’s NUTS sampler:
piblin-jax uses the No-U-Turn Sampler (NUTS), a variant of Hamiltonian Monte Carlo:
Automatically tunes step size and trajectory length
More efficient than basic MCMC (Metropolis-Hastings)
Typically requires fewer samples for same accuracy
Works well in high dimensions
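The key concepts above can be made concrete with the simplest MCMC algorithm. The sketch below is a random-walk Metropolis sampler in plain NumPy, not NUTS and not what piblin-jax runs. The target (a standard normal), the step size, and the function names are all illustrative assumptions.

```python
import numpy as np

def log_target(theta):
    """Unnormalized log-density of the target: a standard normal, for illustration."""
    return -0.5 * theta ** 2

def metropolis(log_prob, n_samples, n_warmup, step=2.0, thin=1, seed=0):
    """Random-walk Metropolis with warmup discarding and optional thinning."""
    rng = np.random.default_rng(seed)
    theta = 5.0  # deliberately poor starting point
    current_lp = log_prob(theta)
    kept = []
    for i in range(n_warmup + n_samples * thin):
        # Markov property: the proposal depends only on the current state
        proposal = theta + step * rng.normal()
        proposal_lp = log_prob(proposal)
        # Accept with probability min(1, p(proposal) / p(current))
        if np.log(rng.uniform()) < proposal_lp - current_lp:
            theta, current_lp = proposal, proposal_lp
        # Discard warmup draws, then keep every `thin`-th state
        if i >= n_warmup and (i - n_warmup) % thin == 0:
            kept.append(theta)
    return np.array(kept)

draws = metropolis(log_target, n_samples=2000, n_warmup=1000, thin=5)
print(f"mean={draws.mean():.3f}, std={draws.std():.3f}")
```

NUTS replaces the blind random-walk proposal with gradient-guided trajectories, which is why it usually needs far fewer draws for the same accuracy and rarely needs thinning.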
Sampling Parameters
- n_samples (default: 1000)
Number of posterior samples to draw after warmup. More samples = better posterior approximation but slower.
1000-2000: Good for most applications
3000-5000: High-precision uncertainty estimates
10000+: Publication-quality results
- n_warmup (default: 1000)
Number of warmup/burn-in samples to discard. Used to tune sampler parameters and reach stationary distribution.
500-1000: Usually sufficient
2000+: Complex models or difficult posteriors
- n_chains (default: 1)
Number of independent MCMC chains to run. Multiple chains help diagnose convergence issues.
1: Fast, but less diagnostic information
4: Standard for convergence diagnostics (R-hat, ESS)
Example:
# High-quality fit with convergence diagnostics
model = PowerLawModel(
    n_samples=2000,
    n_warmup=1000,
    n_chains=4
)
model.fit(shear_rate, viscosity)
Convergence Diagnostics
Always check if sampling succeeded:
if not model.is_fitted:
    print("Warning: Sampling may not have converged")
    # Increase n_samples or n_warmup
    # Check for model misspecification
- R-hat statistic (Gelman-Rubin):
Measures agreement between chains.
R-hat ≈ 1.0: Good convergence
R-hat > 1.1: Poor convergence, increase warmup
- Effective sample size (ESS):
Accounts for autocorrelation in samples.
ESS ≈ n_samples: Low autocorrelation (good)
ESS << n_samples: High autocorrelation (increase samples)
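Both diagnostics can be computed by hand from raw draws, which helps when interpreting them. The sketch below implements the basic (non-split) Gelman-Rubin statistic and a simple autocorrelation-based ESS in NumPy. These are simplified textbook versions with hypothetical function names, and the chains are synthetic; NumPyro provides more robust estimators in its `numpyro.diagnostics` module.

```python
import numpy as np

def gelman_rubin_rhat(chains):
    """Basic R-hat for an array of shape (n_chains, n_draws)."""
    chains = np.asarray(chains, dtype=float)
    n_draws = chains.shape[1]
    within = chains.var(axis=1, ddof=1).mean()           # W: mean within-chain variance
    between = n_draws * chains.mean(axis=1).var(ddof=1)  # B: between-chain variance
    var_hat = (n_draws - 1) / n_draws * within + between / n_draws
    return np.sqrt(var_hat / within)

def effective_sample_size(chain):
    """ESS = n / (1 + 2 * sum of autocorrelations), truncated at the first non-positive lag."""
    x = np.asarray(chain, dtype=float)
    n = x.size
    x = x - x.mean()
    acov = np.correlate(x, x, mode="full")[n - 1:] / n
    rho = acov / acov[0]
    tau = 1.0
    for k in range(1, n):
        if rho[k] <= 0:
            break
        tau += 2.0 * rho[k]
    return n / tau

rng = np.random.default_rng(1)

# Four well-mixed chains, and the same chains with one stuck in the wrong place
good = rng.normal(size=(4, 1000))
bad = good + np.array([0.0, 0.0, 0.0, 3.0])[:, None]

# A strongly autocorrelated AR(1) chain
ar = np.empty(2000)
ar[0] = rng.normal()
for t in range(1, ar.size):
    ar[t] = 0.9 * ar[t - 1] + rng.normal()

rhat_good = gelman_rubin_rhat(good)
rhat_bad = gelman_rubin_rhat(bad)
ess_iid = effective_sample_size(good[0])
ess_ar = effective_sample_size(ar)
print(f"R-hat (mixed): {rhat_good:.3f}, R-hat (one chain offset): {rhat_bad:.3f}")
print(f"ESS (independent draws): {ess_iid:.0f} of 1000, ESS (AR(1)): {ess_ar:.0f} of 2000")
```

R-hat needs at least two chains to be meaningful, which is the practical reason to prefer n_chains=4 over the default of 1 for final results.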
Posterior Analysis
Summary Statistics
The summary() method provides key statistics:
summary = model.summary()
print(summary)
Output:
Parameter Posterior Summary:
----------------------------
K: mean=5.02, std=0.15, 95% CI=[4.73, 5.31]
n: mean=0.598, std=0.012, 95% CI=[0.575, 0.621]
sigma: mean=0.51, std=0.09, 95% CI=[0.38, 0.71]
Interpreting statistics:
mean: Posterior mean (Bayesian point estimate)
std: Posterior standard deviation (uncertainty measure)
95% CI: 95% credible interval (central posterior density)
Accessing Samples
Direct access to posterior samples:
samples = model.samples
# Extract specific parameter
K_samples = samples['K'] # Array of shape (n_samples,)
n_samples = samples['n']
# Custom statistics
K_median = np.median(K_samples)
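# Approximate mode: center of the most populated histogram bin
counts, edges = np.histogram(K_samples, bins=50)
K_mode = 0.5 * (edges[np.argmax(counts)] + edges[np.argmax(counts) + 1])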
# Quantiles
K_quantiles = np.percentile(K_samples, [2.5, 50, 97.5])
Credible Intervals
- Equal-tailed interval (ETI):
Default method. 2.5th and 97.5th percentiles for 95% CI:
lower = np.percentile(K_samples, 2.5)
upper = np.percentile(K_samples, 97.5)
- Highest density interval (HDI):
Shortest interval containing 95% of probability mass. Preferred for skewed distributions:
from piblin_jax.bayesian.utils import compute_hdi
lower, upper = compute_hdi(K_samples, credible_mass=0.95)
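The difference between the two interval types shows up clearly on a skewed sample. The sketch below computes an HDI with a sliding window over the sorted draws, which is a common approach for unimodal posteriors. The function `hdi_from_samples` is a hypothetical stand-in and is not claimed to match how `compute_hdi` is implemented; the lognormal sample is synthetic.

```python
import numpy as np

def hdi_from_samples(samples, credible_mass=0.95):
    """Shortest interval containing `credible_mass` of the draws (unimodal case)."""
    x = np.sort(np.asarray(samples, dtype=float))
    n = x.size
    n_inside = int(np.ceil(credible_mass * n))  # draws each candidate interval must hold
    # Width of every window of n_inside consecutive sorted draws
    widths = x[n_inside - 1:] - x[: n - n_inside + 1]
    start = int(np.argmin(widths))
    return x[start], x[start + n_inside - 1]

rng = np.random.default_rng(2)
skewed = rng.lognormal(mean=0.0, sigma=1.0, size=20000)

hdi_lo, hdi_hi = hdi_from_samples(skewed, credible_mass=0.95)
eti_lo, eti_hi = np.percentile(skewed, [2.5, 97.5])

print(f"ETI: [{eti_lo:.2f}, {eti_hi:.2f}], width {eti_hi - eti_lo:.2f}")
print(f"HDI: [{hdi_lo:.2f}, {hdi_hi:.2f}], width {hdi_hi - hdi_lo:.2f}")
```

For a right-skewed posterior the HDI shifts toward the mode and is narrower than the ETI. For a symmetric posterior the two nearly coincide.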
Parameter Correlations
Posterior samples reveal parameter correlations:
import matplotlib.pyplot as plt
# Joint distribution
plt.scatter(samples['K'], samples['n'], alpha=0.3, s=1)
plt.xlabel('K')
plt.ylabel('n')
plt.title('Joint Posterior Distribution')
# Correlation coefficient
corr = np.corrcoef(samples['K'], samples['n'])[0, 1]
print(f"Correlation: {corr:.3f}")
High correlation indicates parameters are not independently identifiable from the data (common in complex models).
Model Comparison
Information Criteria
Akaike Information Criterion (AIC):
Lower is better. Balances fit quality and model complexity:
aic = model.aic()
Bayesian Information Criterion (BIC):
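Similar to AIC, but the penalty per parameter is ln(N) rather than 2, so BIC penalizes complexity more heavily once there are more than about 7 data points: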
bic = model.bic()
Comparing models:
from piblin_jax.bayesian.models import PowerLawModel, CrossModel
# Fit both models
power_law = PowerLawModel(n_samples=2000)
power_law.fit(shear_rate, viscosity)
cross = CrossModel(n_samples=2000)
cross.fit(shear_rate, viscosity)
# Compare
print(f"Power-law AIC: {power_law.aic():.1f}")
print(f"Cross AIC: {cross.aic():.1f}")
delta_aic = abs(power_law.aic() - cross.aic())
preferred = "power-law" if power_law.aic() < cross.aic() else "Cross"
if delta_aic < 2:
    print("Models are essentially equivalent")
elif delta_aic < 10:
    print(f"Moderate evidence for the {preferred} model")
else:
    print(f"Strong evidence for the {preferred} model")
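For reference, both criteria are simple functions of the maximized log-likelihood, the number of parameters k, and the number of observations N: AIC = 2k − 2·ln L and BIC = k·ln N − 2·ln L. The sketch below applies these textbook definitions to made-up log-likelihood values. How piblin-jax evaluates the log-likelihood inside `aic()` and `bic()` is not stated in this guide, so the helper names and numbers here are assumptions for illustration only.

```python
import numpy as np

def aic_from_loglik(log_likelihood, n_params):
    """AIC = 2k - 2 ln L."""
    return 2.0 * n_params - 2.0 * log_likelihood

def bic_from_loglik(log_likelihood, n_params, n_obs):
    """BIC = k ln N - 2 ln L."""
    return n_params * np.log(n_obs) - 2.0 * log_likelihood

n_obs = 20

# Hypothetical maximized log-likelihoods (not real fit results)
ll_power_law, k_power_law = -31.0, 3  # K, n, sigma
ll_cross, k_cross = -29.5, 5          # eta_0, eta_inf, lambda, m, sigma

aic_pl = aic_from_loglik(ll_power_law, k_power_law)
aic_cr = aic_from_loglik(ll_cross, k_cross)
bic_pl = bic_from_loglik(ll_power_law, k_power_law, n_obs)
bic_cr = bic_from_loglik(ll_cross, k_cross, n_obs)

print(f"AIC: power-law {aic_pl:.1f}, Cross {aic_cr:.1f}")
print(f"BIC: power-law {bic_pl:.1f}, Cross {bic_cr:.1f}")
```

In this made-up case the Cross model fits slightly better, but not by enough to pay for its two extra parameters, and BIC penalizes it more strongly than AIC does.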
Bayes Factors
More rigorous model comparison using marginal likelihoods.
Note: Requires setting enable_bayes_factor=True during fitting:
model = PowerLawModel(n_samples=2000)
model.fit(shear_rate, viscosity, enable_bayes_factor=True)
# Get log marginal likelihood
log_ml = model.log_marginal_likelihood
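Once each model exposes a log marginal likelihood, the Bayes factor is the exponential of their difference. The sketch below uses made-up values in place of `model.log_marginal_likelihood` and also converts them to posterior model probabilities under the assumption of equal prior odds.

```python
import numpy as np

# Hypothetical log marginal likelihoods (stand-ins for model.log_marginal_likelihood)
log_ml_power_law = -35.2
log_ml_cross = -32.9

# Bayes factor of Cross over power-law, computed in log space for stability
log_bf = log_ml_cross - log_ml_power_law
bayes_factor = np.exp(log_bf)

# Posterior model probabilities assuming equal prior probabilities
log_mls = np.array([log_ml_power_law, log_ml_cross])
model_probs = np.exp(log_mls - np.logaddexp.reduce(log_mls))

print(f"log BF (Cross vs power-law): {log_bf:.2f}")
print(f"Bayes factor: {bayes_factor:.1f}")
print(f"P(power-law | data) = {model_probs[0]:.3f}, P(Cross | data) = {model_probs[1]:.3f}")
```

Unlike AIC and BIC, marginal likelihoods depend on the priors, so a Bayes factor should be reported together with the priors that produced it.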
Uncertainty Propagation
Dataset Integration
Add uncertainty to datasets:
from piblin_jax.data.datasets import OneDimensionalDataset
# Create dataset
dataset = OneDimensionalDataset(
    independent_variable_data=shear_rate,
    dependent_variable_data=viscosity
)
# Fit Bayesian model
model = PowerLawModel(n_samples=2000)
model.fit(shear_rate, viscosity)
# Add uncertainty to dataset
dataset_with_unc = dataset.with_uncertainty(
    model=model,
    n_samples=1000,
    keep_samples=True
)
# Check status
print(f"Has uncertainty: {dataset_with_unc.has_uncertainty}")
Transform Propagation
Propagate uncertainty through transforms:
from piblin_jax.transform.dataset import GaussianSmoothing
# Apply transform with uncertainty propagation
smoother = GaussianSmoothing(sigma=2.0)
smoothed = smoother.apply_to(
    dataset_with_unc,
    propagate_uncertainty=True
)
# Uncertainty is now propagated
lower, upper = smoothed.get_credible_intervals(level=0.95)
Monte Carlo Propagation
For custom operations, use Monte Carlo:
# Function to propagate uncertainty through
def my_operation(K, n, shear_rate):
    return K * shear_rate ** (n - 1)
# Sample-based propagation
results = []
for i in range(len(samples['K'])):
    K_i = samples['K'][i]
    n_i = samples['n'][i]
    result_i = my_operation(K_i, n_i, new_shear_rate)
    results.append(result_i)
results = np.array(results)
# Compute uncertainty
mean_result = np.mean(results, axis=0)
lower_result = np.percentile(results, 2.5, axis=0)
upper_result = np.percentile(results, 97.5, axis=0)
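When the operation is built from NumPy arithmetic, the loop can be replaced by broadcasting, which is usually much faster for thousands of draws. The sketch below is self-contained: the `samples` dictionary is a synthetic stand-in for `model.samples`, with values chosen only to resemble the summary output shown earlier.

```python
import numpy as np

rng = np.random.default_rng(3)

# Synthetic stand-in for model.samples (not real posterior draws)
samples = {
    "K": rng.normal(5.0, 0.15, size=2000),
    "n": rng.normal(0.6, 0.012, size=2000),
}
new_shear_rate = np.logspace(-1, 2, 30)

def my_operation(K, n, shear_rate):
    return K * shear_rate ** (n - 1)

# Broadcast (n_draws, 1) parameters against (1, n_points) shear rates
results = my_operation(
    samples["K"][:, None], samples["n"][:, None], new_shear_rate[None, :]
)

# One row per posterior draw, one column per shear rate
mean_result = results.mean(axis=0)
lower_result, upper_result = np.percentile(results, [2.5, 97.5], axis=0)
print(results.shape)
```

Drawing K and n from independent normals ignores their posterior correlation. With real output, index `model.samples` row by row as shown here so that each draw keeps its own (K, n) pair.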
Best Practices
Model Selection
Start simple: Begin with power-law or Arrhenius
Check residuals: Look for systematic patterns
Add complexity: Move to Cross or Carreau-Yasuda if needed
Compare formally: Use AIC/BIC to justify complexity
Physical meaning: Prefer models with interpretable parameters
Prior Selection
Weakly informative priors (default in piblin-jax):
Constrain parameters to physically reasonable ranges
Don’t dominate the data
Help with numerical stability
Custom priors:
For domain expertise, modify priors:
# Example: Strong prior on power-law index
# (requires model subclassing, see API docs)
Computational Efficiency
- Start small:
Test with n_samples=500, n_warmup=500 first
- Increase gradually:
Double samples until results stabilize
- Use GPU:
Install JAX with GPU support. The speedup depends on model size and hardware, and can reach 10-100x on large problems:
pip install "jax[cuda12]" # NVIDIA
pip install jax-metal # Apple Silicon (experimental Metal plugin)
- Parallel chains:
Use n_chains=4 on multi-core CPU
Common Issues
Poor convergence (R-hat > 1.1):
Increase n_warmup
Try different initial values
Simplify the model
Low ESS:
Increase n_samples
Check for high parameter correlation
Consider reparameterization
Unrealistic posteriors:
Check data scaling (avoid extreme values)
Verify model is appropriate for data
Inspect prior sensitivity
Slow sampling:
Reduce n_samples for initial exploration
Use GPU acceleration
Consider simpler model
Advanced Topics
Custom Models
Subclass BayesianModel to create custom models.
See API documentation for details.
Hierarchical Models
Model variation across groups (e.g., multiple experiments). Requires custom NumPyro model definition.
Model Averaging
Combine predictions from multiple models weighted by evidence:
# Fit multiple models
models = [power_law, cross, carreau_yasuda]
weights = compute_model_weights(models) # Based on AIC
# Weighted average predictions
predictions = sum(w * m.predict(x) for w, m in zip(weights, models))
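The helper `compute_model_weights` is not defined in this guide. One common choice is Akaike weights, where each model's weight is proportional to exp(−ΔAIC/2). The sketch below implements that under the assumption that this is the intended scheme. The AIC values and prediction arrays are made up for illustration.

```python
import numpy as np

def akaike_weights(aic_values):
    """Akaike weights: w_i proportional to exp(-0.5 * (AIC_i - AIC_min))."""
    aic_values = np.asarray(aic_values, dtype=float)
    delta = aic_values - aic_values.min()
    relative_likelihood = np.exp(-0.5 * delta)
    return relative_likelihood / relative_likelihood.sum()

# Hypothetical AIC values for power-law, Cross, and Carreau-Yasuda fits
aic_values = [68.0, 69.0, 74.0]
weights = akaike_weights(aic_values)

# Hypothetical point predictions from each model at three shear rates
model_predictions = np.array([
    [12.0, 5.0, 2.1],
    [11.5, 5.2, 2.3],
    [11.8, 5.1, 2.2],
])

# Weighted average across models (rows)
averaged = weights @ model_predictions
print(np.round(weights, 3), np.round(averaged, 3))
```

This averages point predictions only. To carry uncertainty through, one would instead mix posterior predictive draws from each model in proportion to the weights.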
References
Bayesian Statistics:
Gelman, A., et al. (2013). Bayesian Data Analysis, 3rd Edition. Chapman and Hall/CRC.
McElreath, R. (2020). Statistical Rethinking, 2nd Edition. CRC Press.
MCMC Methods:
Hoffman, M.D., & Gelman, A. (2014). “The No-U-Turn Sampler: Adaptively Setting Path Lengths in Hamiltonian Monte Carlo.” Journal of Machine Learning Research, 15, 1593-1623.
NumPyro:
Phan, D., et al. (2019). “Composable Effects for Flexible and Accelerated Probabilistic Programming in NumPyro.” arXiv:1912.11554.
Next Steps
See Uncertainty Quantification Tutorial for step-by-step examples
See Rheological Models Tutorial for model-specific guidance
See API reference for detailed method documentation
See examples/bayesian_*.py for complete working examples