"""
Carreau-Yasuda viscosity model for rheological analysis.
This module implements the Carreau-Yasuda model, which generalizes the Carreau
model for more flexible fitting of shear-thinning fluids.
"""
from typing import Any
import jax.numpy as jnp
import numpy as np
import numpyro
import numpyro.distributions as dist
from piblin_jax.backend.operations import jit
from piblin_jax.bayesian.base import BayesianModel
[docs]
class CarreauYasudaModel(BayesianModel):
"""
Carreau-Yasuda viscosity model for shear-thinning fluids.
The Carreau-Yasuda model is a generalization of the Carreau model that
provides better flexibility in the transition region:
η(γ̇) = η∞ + (η₀ - η∞) * [1 + (λγ̇)^a]^((n-1)/a)
where:
- η is the viscosity (Pa·s)
- γ̇ is the shear rate (s⁻¹)
- η₀ is the zero-shear viscosity (Pa·s)
- η∞ is the infinite-shear viscosity (Pa·s)
- λ is the relaxation time (s)
- a is the transition parameter (dimensionless)
- n is the power-law index (dimensionless)
Special cases:
- a → ∞: Reduces to power-law model in transition region
- a = 2: Carreau model
- a = 1: Cross model (approximately)
Parameters
----------
n_samples : int, optional
Number of MCMC samples to draw (default: 1000)
n_warmup : int, optional
Number of warmup samples for MCMC (default: 500)
n_chains : int, optional
Number of MCMC chains to run (default: 2)
random_seed : int, optional
Random seed for reproducibility (default: 0)
Attributes
----------
samples : dict[str, array] | None
Posterior samples from MCMC containing:
- 'eta0': Zero-shear viscosity samples
- 'eta_inf': Infinite-shear viscosity samples
- ``'lambda_'``: Relaxation time samples
- 'a': Transition parameter samples
- 'n': Power-law index samples
- 'sigma': Observation noise samples
Examples
--------
>>> import numpy as np
>>> from piblin_jax.bayesian.models import CarreauYasudaModel
>>>
>>> # Generate synthetic Carreau-Yasuda data
>>> shear_rate = np.logspace(-2, 3, 50)
>>> eta0, eta_inf, lam, a, n = 100.0, 1.0, 1.0, 2.0, 0.5
>>> viscosity = eta_inf + (eta0 - eta_inf) * (1 + (lam * shear_rate)**a)**((n-1)/a)
>>>
>>> # Fit model
>>> model = CarreauYasudaModel(n_samples=1000, n_warmup=500)
>>> model.fit(shear_rate, viscosity)
>>>
>>> # Get parameter estimates
>>> summary = model.summary()
>>> print(f"Zero-shear viscosity: {summary['eta0']['mean']:.1f} Pa·s")
>>> print(f"Power-law index: {summary['n']['mean']:.2f}")
>>>
>>> # Predict
>>> predictions = model.predict(np.array([0.1, 1.0, 10.0]))
Notes
-----
The model uses the following priors:
- eta0 ~ LogNormal(4, 2): Zero-shear viscosity
- eta_inf ~ LogNormal(0, 2): Infinite-shear viscosity
- ``lambda_`` ~ LogNormal(0, 2): Relaxation time
- a ~ LogNormal(0.69, 0.5): Transition parameter (centered around 2)
- n ~ Normal(0.5, 0.3): Power-law index (shear-thinning range)
- sigma ~ HalfNormal(scale): Observation noise
The Carreau-Yasuda model:
- Five parameters provide maximum flexibility
- Captures smooth transition between plateaus
- Parameter 'a' controls transition sharpness
- Reduces to simpler models as special cases
The model assumes:
- Monotonic shear-thinning behavior
- No yield stress
- Single relaxation mechanism
For materials with yield stress, consider the Herschel-Bulkley model.
For simpler analysis with fewer parameters, use Cross or Carreau models.
References
----------
.. [1] Yasuda, K., Armstrong, R. C., & Cohen, R. E. (1981). "Shear flow
properties of concentrated solutions of linear and star branched
polystyrenes." Rheologica Acta, 20(2), 163-178.
.. [2] Carreau, P. J. (1972). "Rheological equations from molecular network
theories." Transactions of the Society of Rheology, 16(1), 99-127.
.. [3] Bird, R. B., Armstrong, R. C., & Hassager, O. (1987). "Dynamics of
Polymeric Liquids. Vol. 1: Fluid Mechanics." Wiley, New York.
"""
[docs]
def model(self, x: Any, y: Any = None, **kwargs: Any) -> None:
"""
Define the NumPyro probabilistic model for Carreau-Yasuda viscosity.
Parameters
----------
x : array_like
Shear rate data (γ̇) in s⁻¹
y : array_like | None, optional
Viscosity observations (η) in Pa·s. If None, generates prior samples.
**kwargs : dict
Additional model parameters (unused)
Notes
-----
This method is called internally by fit() and should not be called directly.
"""
# Convert to JAX arrays
x = jnp.asarray(x)
if y is not None:
y = jnp.asarray(y)
# Priors
# eta0: Zero-shear viscosity (should be > eta_inf)
eta0 = numpyro.sample("eta0", dist.LogNormal(4.0, 2.0))
# eta_inf: Infinite-shear viscosity
eta_inf = numpyro.sample("eta_inf", dist.LogNormal(0.0, 2.0))
# lambda_: Relaxation time
lambda_ = numpyro.sample("lambda_", dist.LogNormal(0.0, 2.0))
# a: Transition parameter (Yasuda parameter)
# Centered around 2 (Carreau model), but allow flexibility
a = numpyro.sample("a", dist.LogNormal(0.69, 0.5)) # log(2) ≈ 0.69
# n: Power-law index (< 1 for shear-thinning)
n = numpyro.sample("n", dist.Normal(0.5, 0.3))
# sigma: Observation noise
if y is not None:
sigma_scale = jnp.maximum(jnp.std(y) * 0.1, 0.01)
else:
sigma_scale = jnp.array(1.0)
sigma = numpyro.sample("sigma", dist.HalfNormal(sigma_scale))
# Model: η = η∞ + (η₀ - η∞) * [1 + (λγ̇)^a]^((n-1)/a)
eta_pred = eta_inf + (eta0 - eta_inf) * (1 + (lambda_ * x) ** a) ** ((n - 1) / a)
# Likelihood
with numpyro.plate("data", x.shape[0]):
numpyro.sample("obs", dist.Normal(eta_pred, sigma), obs=y)
@staticmethod
@jit
def _compute_predictions( # type: ignore[no-untyped-def]
eta0_samples, eta_inf_samples, lambda_samples, a_samples, n_samples, shear_rate
):
"""
JIT-compiled prediction computation for 5-10x speedup.
Parameters
----------
eta0_samples : array
Posterior samples for zero-shear viscosity η₀
eta_inf_samples : array
Posterior samples for infinite-shear viscosity η∞
lambda_samples : array
Posterior samples for time constant λ
a_samples : array
Posterior samples for Yasuda parameter a
n_samples : array
Posterior samples for power-law index n
shear_rate : array
Shear rate values to predict at
Returns
-------
array
Predicted viscosity samples (n_samples × n_points)
Notes
-----
This function is JIT-compiled with JAX for optimal performance.
First call will be slower due to compilation, but subsequent calls
will be 5-10x faster on CPU and up to 100x faster on GPU.
Model: η = η∞ + (η₀ - η∞) * [1 + (λγ̇)^a]^((n-1)/a)
"""
return eta_inf_samples[:, None] + (eta0_samples[:, None] - eta_inf_samples[:, None]) * (
1 + (lambda_samples[:, None] * shear_rate[None, :]) ** a_samples[:, None]
) ** ((n_samples[:, None] - 1) / a_samples[:, None])
[docs]
def predict(self, shear_rate: Any, credible_interval: float = 0.95) -> dict[str, np.ndarray]:
"""
Predict viscosity with uncertainty at given shear rates.
Uses posterior samples from MCMC to generate predictions with
credible intervals.
Parameters
----------
shear_rate : array_like
Shear rate values (γ̇) in s⁻¹ at which to predict viscosity
credible_interval : float, optional
Credible interval level between 0 and 1 (default: 0.95)
Returns
-------
dict
Dictionary containing:
- 'mean': Mean predicted viscosity (array)
- 'lower': Lower credible bound (array)
- 'upper': Upper credible bound (array)
- 'samples': Full posterior predictive samples (2D array)
Raises
------
RuntimeError
If model has not been fit yet
Examples
--------
>>> model = CarreauYasudaModel()
>>> model.fit(shear_rate_data, viscosity_data)
>>> predictions = model.predict(np.array([0.1, 1.0, 10.0, 100.0]))
>>> print(predictions['mean'])
[95.3 48.2 12.5 2.8]
"""
if self._samples is None:
raise RuntimeError("Model must be fit before prediction")
# Convert to JAX arrays for JIT compilation
shear_rate = jnp.asarray(shear_rate)
eta0_samples = jnp.asarray(self._samples["eta0"])
eta_inf_samples = jnp.asarray(self._samples["eta_inf"])
lambda_samples = jnp.asarray(self._samples["lambda_"])
a_samples = jnp.asarray(self._samples["a"])
n_samples = jnp.asarray(self._samples["n"])
# Use JIT-compiled prediction: 5-10x faster on CPU, up to 100x on GPU
eta_samples = self._compute_predictions(
eta0_samples, eta_inf_samples, lambda_samples, a_samples, n_samples, shear_rate
)
# Compute statistics
mean = jnp.mean(eta_samples, axis=0)
alpha = 1 - credible_interval
lower = jnp.percentile(eta_samples, 100 * alpha / 2, axis=0)
upper = jnp.percentile(eta_samples, 100 * (1 - alpha / 2), axis=0)
return {
"mean": np.array(mean),
"lower": np.array(lower),
"upper": np.array(upper),
"samples": np.array(eta_samples),
}