Source code for piblin_jax.bayesian.models.carreau_yasuda

"""
Carreau-Yasuda viscosity model for rheological analysis.

This module implements the Carreau-Yasuda model, which generalizes the Carreau
model for more flexible fitting of shear-thinning fluids.
"""

from typing import Any

import jax.numpy as jnp
import numpy as np
import numpyro
import numpyro.distributions as dist

from piblin_jax.backend.operations import jit
from piblin_jax.bayesian.base import BayesianModel


[docs] class CarreauYasudaModel(BayesianModel): """ Carreau-Yasuda viscosity model for shear-thinning fluids. The Carreau-Yasuda model is a generalization of the Carreau model that provides better flexibility in the transition region: η(γ̇) = η∞ + (η₀ - η∞) * [1 + (λγ̇)^a]^((n-1)/a) where: - η is the viscosity (Pa·s) - γ̇ is the shear rate (s⁻¹) - η₀ is the zero-shear viscosity (Pa·s) - η∞ is the infinite-shear viscosity (Pa·s) - λ is the relaxation time (s) - a is the transition parameter (dimensionless) - n is the power-law index (dimensionless) Special cases: - a → ∞: Reduces to power-law model in transition region - a = 2: Carreau model - a = 1: Cross model (approximately) Parameters ---------- n_samples : int, optional Number of MCMC samples to draw (default: 1000) n_warmup : int, optional Number of warmup samples for MCMC (default: 500) n_chains : int, optional Number of MCMC chains to run (default: 2) random_seed : int, optional Random seed for reproducibility (default: 0) Attributes ---------- samples : dict[str, array] | None Posterior samples from MCMC containing: - 'eta0': Zero-shear viscosity samples - 'eta_inf': Infinite-shear viscosity samples - ``'lambda_'``: Relaxation time samples - 'a': Transition parameter samples - 'n': Power-law index samples - 'sigma': Observation noise samples Examples -------- >>> import numpy as np >>> from piblin_jax.bayesian.models import CarreauYasudaModel >>> >>> # Generate synthetic Carreau-Yasuda data >>> shear_rate = np.logspace(-2, 3, 50) >>> eta0, eta_inf, lam, a, n = 100.0, 1.0, 1.0, 2.0, 0.5 >>> viscosity = eta_inf + (eta0 - eta_inf) * (1 + (lam * shear_rate)**a)**((n-1)/a) >>> >>> # Fit model >>> model = CarreauYasudaModel(n_samples=1000, n_warmup=500) >>> model.fit(shear_rate, viscosity) >>> >>> # Get parameter estimates >>> summary = model.summary() >>> print(f"Zero-shear viscosity: {summary['eta0']['mean']:.1f} Pa·s") >>> print(f"Power-law index: {summary['n']['mean']:.2f}") >>> >>> # Predict >>> predictions = model.predict(np.array([0.1, 1.0, 10.0])) Notes ----- The model uses the following priors: - eta0 ~ LogNormal(4, 2): Zero-shear viscosity - eta_inf ~ LogNormal(0, 2): Infinite-shear viscosity - ``lambda_`` ~ LogNormal(0, 2): Relaxation time - a ~ LogNormal(0.69, 0.5): Transition parameter (centered around 2) - n ~ Normal(0.5, 0.3): Power-law index (shear-thinning range) - sigma ~ HalfNormal(scale): Observation noise The Carreau-Yasuda model: - Five parameters provide maximum flexibility - Captures smooth transition between plateaus - Parameter 'a' controls transition sharpness - Reduces to simpler models as special cases The model assumes: - Monotonic shear-thinning behavior - No yield stress - Single relaxation mechanism For materials with yield stress, consider the Herschel-Bulkley model. For simpler analysis with fewer parameters, use Cross or Carreau models. References ---------- .. [1] Yasuda, K., Armstrong, R. C., & Cohen, R. E. (1981). "Shear flow properties of concentrated solutions of linear and star branched polystyrenes." Rheologica Acta, 20(2), 163-178. .. [2] Carreau, P. J. (1972). "Rheological equations from molecular network theories." Transactions of the Society of Rheology, 16(1), 99-127. .. [3] Bird, R. B., Armstrong, R. C., & Hassager, O. (1987). "Dynamics of Polymeric Liquids. Vol. 1: Fluid Mechanics." Wiley, New York. """
[docs] def model(self, x: Any, y: Any = None, **kwargs: Any) -> None: """ Define the NumPyro probabilistic model for Carreau-Yasuda viscosity. Parameters ---------- x : array_like Shear rate data (γ̇) in s⁻¹ y : array_like | None, optional Viscosity observations (η) in Pa·s. If None, generates prior samples. **kwargs : dict Additional model parameters (unused) Notes ----- This method is called internally by fit() and should not be called directly. """ # Convert to JAX arrays x = jnp.asarray(x) if y is not None: y = jnp.asarray(y) # Priors # eta0: Zero-shear viscosity (should be > eta_inf) eta0 = numpyro.sample("eta0", dist.LogNormal(4.0, 2.0)) # eta_inf: Infinite-shear viscosity eta_inf = numpyro.sample("eta_inf", dist.LogNormal(0.0, 2.0)) # lambda_: Relaxation time lambda_ = numpyro.sample("lambda_", dist.LogNormal(0.0, 2.0)) # a: Transition parameter (Yasuda parameter) # Centered around 2 (Carreau model), but allow flexibility a = numpyro.sample("a", dist.LogNormal(0.69, 0.5)) # log(2) ≈ 0.69 # n: Power-law index (< 1 for shear-thinning) n = numpyro.sample("n", dist.Normal(0.5, 0.3)) # sigma: Observation noise if y is not None: sigma_scale = jnp.maximum(jnp.std(y) * 0.1, 0.01) else: sigma_scale = jnp.array(1.0) sigma = numpyro.sample("sigma", dist.HalfNormal(sigma_scale)) # Model: η = η∞ + (η₀ - η∞) * [1 + (λγ̇)^a]^((n-1)/a) eta_pred = eta_inf + (eta0 - eta_inf) * (1 + (lambda_ * x) ** a) ** ((n - 1) / a) # Likelihood with numpyro.plate("data", x.shape[0]): numpyro.sample("obs", dist.Normal(eta_pred, sigma), obs=y)
@staticmethod @jit def _compute_predictions( # type: ignore[no-untyped-def] eta0_samples, eta_inf_samples, lambda_samples, a_samples, n_samples, shear_rate ): """ JIT-compiled prediction computation for 5-10x speedup. Parameters ---------- eta0_samples : array Posterior samples for zero-shear viscosity η₀ eta_inf_samples : array Posterior samples for infinite-shear viscosity η∞ lambda_samples : array Posterior samples for time constant λ a_samples : array Posterior samples for Yasuda parameter a n_samples : array Posterior samples for power-law index n shear_rate : array Shear rate values to predict at Returns ------- array Predicted viscosity samples (n_samples × n_points) Notes ----- This function is JIT-compiled with JAX for optimal performance. First call will be slower due to compilation, but subsequent calls will be 5-10x faster on CPU and up to 100x faster on GPU. Model: η = η∞ + (η₀ - η∞) * [1 + (λγ̇)^a]^((n-1)/a) """ return eta_inf_samples[:, None] + (eta0_samples[:, None] - eta_inf_samples[:, None]) * ( 1 + (lambda_samples[:, None] * shear_rate[None, :]) ** a_samples[:, None] ) ** ((n_samples[:, None] - 1) / a_samples[:, None])
[docs] def predict(self, shear_rate: Any, credible_interval: float = 0.95) -> dict[str, np.ndarray]: """ Predict viscosity with uncertainty at given shear rates. Uses posterior samples from MCMC to generate predictions with credible intervals. Parameters ---------- shear_rate : array_like Shear rate values (γ̇) in s⁻¹ at which to predict viscosity credible_interval : float, optional Credible interval level between 0 and 1 (default: 0.95) Returns ------- dict Dictionary containing: - 'mean': Mean predicted viscosity (array) - 'lower': Lower credible bound (array) - 'upper': Upper credible bound (array) - 'samples': Full posterior predictive samples (2D array) Raises ------ RuntimeError If model has not been fit yet Examples -------- >>> model = CarreauYasudaModel() >>> model.fit(shear_rate_data, viscosity_data) >>> predictions = model.predict(np.array([0.1, 1.0, 10.0, 100.0])) >>> print(predictions['mean']) [95.3 48.2 12.5 2.8] """ if self._samples is None: raise RuntimeError("Model must be fit before prediction") # Convert to JAX arrays for JIT compilation shear_rate = jnp.asarray(shear_rate) eta0_samples = jnp.asarray(self._samples["eta0"]) eta_inf_samples = jnp.asarray(self._samples["eta_inf"]) lambda_samples = jnp.asarray(self._samples["lambda_"]) a_samples = jnp.asarray(self._samples["a"]) n_samples = jnp.asarray(self._samples["n"]) # Use JIT-compiled prediction: 5-10x faster on CPU, up to 100x on GPU eta_samples = self._compute_predictions( eta0_samples, eta_inf_samples, lambda_samples, a_samples, n_samples, shear_rate ) # Compute statistics mean = jnp.mean(eta_samples, axis=0) alpha = 1 - credible_interval lower = jnp.percentile(eta_samples, 100 * alpha / 2, axis=0) upper = jnp.percentile(eta_samples, 100 * (1 - alpha / 2), axis=0) return { "mean": np.array(mean), "lower": np.array(lower), "upper": np.array(upper), "samples": np.array(eta_samples), }