Source code for piblin_jax.bayesian.models.power_law

"""
Power-law viscosity model for rheological analysis.

This module implements the power-law (Ostwald-de Waele) model for shear-thinning
and shear-thickening fluids using Bayesian inference.
"""

from typing import Any

import jax.numpy as jnp
import numpy as np
import numpyro
import numpyro.distributions as dist

from piblin_jax.backend.operations import jit
from piblin_jax.bayesian.base import BayesianModel


class PowerLawModel(BayesianModel):
    """
    Power-law viscosity model for non-Newtonian fluids.

    The power-law model describes the relationship between viscosity and
    shear rate:

        η(γ̇) = K * γ̇^(n-1)

    where:

    - η is the viscosity (Pa·s)
    - γ̇ is the shear rate (s⁻¹)
    - K is the consistency index (Pa·s^n)
    - n is the power-law index (dimensionless)

    The power-law index n characterizes the flow behavior:

    - n < 1: Shear-thinning (pseudoplastic) behavior
    - n = 1: Newtonian behavior
    - n > 1: Shear-thickening (dilatant) behavior

    Parameters
    ----------
    n_samples : int, optional
        Number of MCMC samples to draw (default: 1000)
    n_warmup : int, optional
        Number of warmup samples for MCMC (default: 500)
    n_chains : int, optional
        Number of MCMC chains to run (default: 2)
    random_seed : int, optional
        Random seed for reproducibility (default: 0)

    Attributes
    ----------
    samples : dict[str, array] | None
        Posterior samples from MCMC containing:

        - 'K': Consistency index samples
        - 'n': Power-law index samples
        - 'sigma': Observation noise samples

    Examples
    --------
    >>> import numpy as np
    >>> from piblin_jax.bayesian.models import PowerLawModel
    >>>
    >>> # Generate synthetic power-law data
    >>> shear_rate = np.logspace(-1, 2, 30)  # 0.1 to 100 s^-1
    >>> viscosity = 5.0 * shear_rate ** (0.6 - 1)  # K=5, n=0.6 (shear-thinning)
    >>>
    >>> # Fit model
    >>> model = PowerLawModel(n_samples=1000, n_warmup=500)
    >>> model.fit(shear_rate, viscosity)
    >>>
    >>> # Get parameter estimates
    >>> summary = model.summary()
    >>> print(f"K: {summary['K']['mean']:.2f} +/- {summary['K']['std']:.2f}")
    >>> print(f"n: {summary['n']['mean']:.2f} +/- {summary['n']['std']:.2f}")
    >>>
    >>> # Predict with uncertainty
    >>> shear_rate_new = np.array([1.0, 10.0, 50.0])
    >>> predictions = model.predict(shear_rate_new, credible_interval=0.95)
    >>> print(f"Predicted viscosity at γ̇=10: {predictions['mean'][1]:.2f}")
    >>> print(f"95% CI: [{predictions['lower'][1]:.2f}, {predictions['upper'][1]:.2f}]")

    Notes
    -----
    The model uses the following priors:

    - K ~ LogNormal(0, 2): Ensures positive consistency index
    - n ~ Normal(0.5, 0.5): Centers around typical shear-thinning behavior
    - sigma ~ HalfNormal(1): Observation noise

    The likelihood is additive Gaussian noise on the viscosity itself,
    η_obs ~ Normal(K * γ̇^(n-1), sigma), so sigma is expressed in Pa·s.

    The power-law model is simple but effective for many non-Newtonian
    fluids. However, it has limitations:

    - Predicts infinite viscosity at zero shear rate (for n < 1)
    - Predicts zero viscosity at infinite shear rate (for n < 1)
    - Does not capture zero-shear or infinite-shear plateaus

    For fluids with plateau regions, consider using CrossModel or
    CarreauYasudaModel.

    References
    ----------
    .. [1] Ostwald, W. (1925). "Über die rechnerische Darstellung des
           Strukturgebietes der Viskosität." Kolloid-Zeitschrift, 36, 99-117.
    .. [2] de Waele, A. (1923). "Viscometry and plastometry." Journal of the
           Oil and Colour Chemists Association, 6, 33-69.
    """

    def model(self, x: Any, y: Any = None, **kwargs: Any) -> None:
        """
        Define the NumPyro probabilistic model for power-law viscosity.

        This method is called internally by the MCMC inference engine and
        defines the probabilistic generative process for power-law viscosity
        data.

        Parameters
        ----------
        x : array_like
            Shear rate data (γ̇) in s⁻¹. Must be at least 1-D, because the
            data plate is sized from ``x.shape[0]``.
        y : array_like | None, optional
            Viscosity observations (η) in Pa·s. If None, the ``obs`` site is
            sampled rather than conditioned, i.e. prior predictive draws.
        **kwargs : dict
            Additional model parameters (unused)

        Examples
        --------
        This method is typically not called directly. Instead, use the fit()
        method:

        >>> model = PowerLawModel()
        >>> model.fit(shear_rate, viscosity)  # Internally calls model()

        Notes
        -----
        This method is called internally by fit() and should not be called
        directly. It defines the generative model η = K * γ̇^(n-1) + ε where
        ε ~ Normal(0, σ).
        """
        # Convert to JAX arrays
        x = jnp.asarray(x)
        if y is not None:
            y = jnp.asarray(y)

        # Priors
        # K: Consistency index (positive, log-normal prior)
        K = numpyro.sample("K", dist.LogNormal(0.0, 2.0))
        # n: Power-law index (prior centred on shear-thinning, 0.5 ± 0.5)
        n = numpyro.sample("n", dist.Normal(0.5, 0.5))
        # sigma: Observation noise (positive), same units as viscosity
        sigma = numpyro.sample("sigma", dist.HalfNormal(1.0))

        # Model: η = K * γ̇^(n-1)
        eta_pred = K * x ** (n - 1)

        # Likelihood: independent Gaussian noise per data point
        with numpyro.plate("data", x.shape[0]):
            numpyro.sample("obs", dist.Normal(eta_pred, sigma), obs=y)

    @staticmethod
    @jit
    def _compute_predictions(K_samples, n_samples, shear_rate):  # type: ignore[no-untyped-def]
        """
        Evaluate η = K * γ̇^(n-1) for every posterior draw and shear rate.

        Parameters
        ----------
        K_samples : array
            1-D array of posterior draws of the consistency index K.
        n_samples : array
            1-D array of posterior draws of the power-law index n (the
            index itself, not a sample count), aligned with ``K_samples``.
        shear_rate : array
            1-D array of shear rate values to predict at.

        Returns
        -------
        array
            Predicted viscosity samples with shape
            (number of draws, number of shear rates).

        Notes
        -----
        Wrapped with the backend ``jit`` decorator. When that resolves to
        ``jax.jit``, the first call for a given input shape pays a
        compilation cost and later calls with the same shapes reuse the
        compiled function.
        """
        # Broadcast draws along axis 0 against shear rates along axis 1.
        return K_samples[:, None] * shear_rate[None, :] ** (n_samples[:, None] - 1)

    def predict(
        self, shear_rate: Any, credible_interval: float = 0.95
    ) -> dict[str, np.ndarray]:
        """
        Predict viscosity with uncertainty at given shear rates.

        Uses posterior samples from MCMC to generate predictions with
        credible intervals.

        Parameters
        ----------
        shear_rate : array_like
            Shear rate values (γ̇) in s⁻¹ at which to predict viscosity. A
            scalar is accepted and treated as a length-1 array, so the
            returned arrays then have one element.
        credible_interval : float, optional
            Credible interval level in [0, 1] (default: 0.95). The bounds are
            the (1 - level)/2 and (1 + level)/2 quantiles of the posterior
            predictive samples.

        Returns
        -------
        dict
            Dictionary containing:

            - 'mean': Mean predicted viscosity, shape (n_points,)
            - 'lower': Lower credible bound, shape (n_points,)
            - 'upper': Upper credible bound, shape (n_points,)
            - 'samples': Full posterior predictive samples,
              shape (number of draws, n_points)

        Raises
        ------
        RuntimeError
            If model has not been fit yet
        ValueError
            If ``credible_interval`` is not within [0, 1] (including NaN).

        Examples
        --------
        >>> model = PowerLawModel()
        >>> model.fit(shear_rate_data, viscosity_data)
        >>> predictions = model.predict(np.array([1.0, 10.0, 100.0]))
        >>> print(predictions['mean'])
        [5.23 2.41 1.11]
        >>> print(predictions['lower'])
        [4.89 2.21 1.01]
        >>> print(predictions['upper'])
        [5.61 2.65 1.23]
        """
        if self._samples is None:
            raise RuntimeError("Model must be fit before prediction")

        # Values outside [0, 1] map to percentiles outside [0, 100]. Unlike
        # NumPy, jnp.percentile is not guaranteed to reject those, so a caller
        # passing e.g. 95 instead of 0.95 would get meaningless bounds.
        # The chained comparison is False for NaN, which is rejected too.
        if not 0.0 <= credible_interval <= 1.0:
            raise ValueError(
                f"credible_interval must be between 0 and 1, got {credible_interval!r}"
            )

        # _compute_predictions indexes shear_rate[None, :], which requires at
        # least one dimension; promote scalars to shape (1,).
        shear_rate = jnp.atleast_1d(jnp.asarray(shear_rate))
        # Assumes posterior draws are stored flattened to 1-D (one value per
        # draw), as required by the [:, None] indexing in _compute_predictions.
        K_samples = jnp.asarray(self._samples["K"])
        n_samples = jnp.asarray(self._samples["n"])

        # Shape: (number of draws, number of shear rates)
        eta_samples = self._compute_predictions(K_samples, n_samples, shear_rate)

        # Summarize across posterior draws (axis 0) for each shear rate.
        mean = jnp.mean(eta_samples, axis=0)
        alpha = 1 - credible_interval
        lower = jnp.percentile(eta_samples, 100 * alpha / 2, axis=0)
        upper = jnp.percentile(eta_samples, 100 * (1 - alpha / 2), axis=0)

        # Return host NumPy arrays so callers are not tied to JAX arrays.
        return {
            "mean": np.array(mean),
            "lower": np.array(lower),
            "upper": np.array(upper),
            "samples": np.array(eta_samples),
        }