"""
Arrhenius temperature-viscosity model for rheological analysis.
This module implements the Arrhenius equation for modeling temperature-dependent
viscosity using Bayesian inference.
"""
from typing import Any
import jax.numpy as jnp
import numpy as np
import numpyro
import numpyro.distributions as dist
from piblin_jax.backend.operations import jit
from piblin_jax.bayesian.base import BayesianModel
class ArrheniusModel(BayesianModel):
    """
    Arrhenius temperature-viscosity model.

    The Arrhenius equation describes how viscosity changes with temperature:

        η(T) = A * exp(Ea / (R*T))

    where:

    - η is the viscosity (Pa·s)
    - T is the absolute temperature (K)
    - A is the pre-exponential factor (Pa·s)
    - Ea is the activation energy (J/mol)
    - R is the universal gas constant (8.314 J/(mol·K))

    This model is widely used for polymer melts, glass-forming liquids, and
    other materials where viscosity decreases exponentially with temperature.

    Parameters
    ----------
    n_samples : int, optional
        Number of MCMC samples to draw (default: 1000)
    n_warmup : int, optional
        Number of warmup samples for MCMC (default: 500)
    n_chains : int, optional
        Number of MCMC chains to run (default: 2)
    random_seed : int, optional
        Random seed for reproducibility (default: 0)

    Attributes
    ----------
    samples : dict[str, array] | None
        Posterior samples from MCMC containing:

        - 'A': Pre-exponential factor samples
        - 'Ea': Activation energy samples
        - 'sigma': Observation noise samples
    R : float
        Universal gas constant (8.314 J/(mol·K))
        :no-index:

    Examples
    --------
    >>> import numpy as np
    >>> from piblin_jax.bayesian.models import ArrheniusModel
    >>>
    >>> # Temperature-dependent viscosity data
    >>> temperature = np.array([300, 320, 340, 360, 380, 400])  # K
    >>> viscosity = np.array([1000, 450, 220, 120, 70, 45])  # Pa·s
    >>>
    >>> # Fit Arrhenius model
    >>> model = ArrheniusModel(n_samples=1000, n_warmup=500)
    >>> model.fit(temperature, viscosity)
    >>>
    >>> # Get activation energy
    >>> summary = model.summary()
    >>> Ea_mean = summary['Ea']['mean']
    >>> print(f"Activation energy: {Ea_mean/1000:.1f} kJ/mol")
    >>>
    >>> # Predict at new temperature
    >>> temp_new = np.array([350])
    >>> predictions = model.predict(temp_new)
    >>> print(f"Predicted viscosity at 350K: {predictions['mean'][0]:.1f} Pa·s")

    Notes
    -----
    The model uses the following priors:

    - A ~ LogNormal(-10, 5): Wide prior on pre-exponential factor
    - Ea ~ Normal(50000, 30000): Prior centered around typical activation energies
    - sigma ~ HalfNormal(scale): Observation noise, where scale is 10% of the
      mean observed viscosity (floored at 0.01), or 1.0 when no observations
      are supplied

    The Arrhenius equation assumes:

    - Activation energy is constant over the temperature range
    - Single relaxation mechanism (no structural transitions)
    - Newtonian behavior at each temperature

    For materials with glass transitions or multiple relaxation processes,
    consider the Williams-Landel-Ferry (WLF) equation or Vogel-Fulcher-Tammann
    (VFT) equation.

    References
    ----------
    .. [1] Arrhenius, S. (1889). "Über die Reaktionsgeschwindigkeit bei der
       Inversion von Rohrzucker durch Säuren." Zeitschrift für Physikalische
       Chemie, 4, 226-248.
    .. [2] Ferry, J. D. (1980). "Viscoelastic Properties of Polymers," 3rd ed.
       Wiley, New York.
    """

    # Universal gas constant (J/(mol·K))
    R = 8.314

    def model(self, x: Any, y: Any = None, **kwargs: Any) -> None:
        """
        Define the NumPyro probabilistic model for Arrhenius viscosity.

        Declares the sample sites ``A``, ``Ea`` and ``sigma`` and a Normal
        likelihood site ``obs`` centred on ``A * exp(Ea / (R * x))``.

        Parameters
        ----------
        x : array_like
            Temperature data (T) in Kelvin. Must have at least one dimension,
            because its leading length sizes the ``"data"`` plate.
        y : array_like | None, optional
            Viscosity observations (η) in Pa·s. If None, the ``obs`` site is
            left unobserved.
        **kwargs : dict
            Additional model parameters (unused)

        Examples
        --------
        This method is typically not called directly. Instead, use the fit() method:

        >>> model = ArrheniusModel()
        >>> model.fit(temperature, viscosity)  # Internally calls model()

        Notes
        -----
        This method is called internally by fit() and should not be called directly.
        It defines the generative model η(T) = A * exp(Ea/(R*T)) + ε where
        ε ~ Normal(0, σ).
        """
        # Convert to JAX arrays
        x = jnp.asarray(x)
        if y is not None:
            y = jnp.asarray(y)

        # Priors
        # A: Pre-exponential factor (very small positive value).
        # Log-normal with wide variance to accommodate exponential scaling.
        A = numpyro.sample("A", dist.LogNormal(-10.0, 5.0))

        # Ea: Activation energy (J/mol).
        # Typical range: 20-100 kJ/mol = 20000-100000 J/mol
        Ea = numpyro.sample("Ea", dist.Normal(50000.0, 30000.0))

        # sigma: Observation noise. The prior scale adapts to the data
        # (10% of the mean viscosity, floored at 0.01 so it stays positive);
        # without data a unit scale is used.
        if y is not None:
            sigma_scale = jnp.maximum(jnp.mean(y) * 0.1, 0.01)
        else:
            sigma_scale = jnp.array(1.0)
        sigma = numpyro.sample("sigma", dist.HalfNormal(sigma_scale))

        # Model: η(T) = A * exp(Ea / (R*T))
        eta_pred = A * jnp.exp(Ea / (self.R * x))

        # Likelihood
        with numpyro.plate("data", x.shape[0]):
            numpyro.sample("obs", dist.Normal(eta_pred, sigma), obs=y)

    @staticmethod
    @jit
    def _compute_predictions(A_samples, Ea_samples, temperature, R):  # type: ignore[no-untyped-def]
        """
        Evaluate the Arrhenius curve for every posterior sample.

        Parameters
        ----------
        A_samples : array
            Posterior samples for pre-exponential factor A, shape (n_samples,)
        Ea_samples : array
            Posterior samples for activation energy Ea, shape (n_samples,)
        temperature : array
            Temperature values to predict at (Kelvin), shape (n_points,)
        R : float
            Universal gas constant

        Returns
        -------
        array
            Predicted viscosity samples (n_samples × n_points)

        Notes
        -----
        The function is wrapped by the project's ``jit`` helper (presumably a
        thin wrapper around ``jax.jit`` -- confirm in
        ``piblin_jax.backend.operations``), so the first call may include
        compilation time.
        """
        # Broadcast samples along rows and temperatures along columns.
        return A_samples[:, None] * jnp.exp(Ea_samples[:, None] / (R * temperature[None, :]))

    def predict(self, temperature: Any, credible_interval: float = 0.95) -> dict[str, np.ndarray]:
        """
        Predict viscosity with uncertainty at given temperatures.

        Evaluates the Arrhenius curve for each posterior draw of ``A`` and
        ``Ea`` and summarises the resulting curves with a mean and an
        equal-tailed credible interval. Observation noise (``sigma``) is not
        added to the draws.

        Parameters
        ----------
        temperature : array_like
            Temperature values (T) in Kelvin at which to predict viscosity.
            A scalar is treated as a single-element array.
        credible_interval : float, optional
            Credible interval level between 0 and 1 (default: 0.95)

        Returns
        -------
        dict
            Dictionary containing:

            - 'mean': Mean predicted viscosity (array)
            - 'lower': Lower credible bound (array)
            - 'upper': Upper credible bound (array)
            - 'samples': Predicted viscosity for every posterior draw
              (2D array, n_samples × n_points)

        Raises
        ------
        RuntimeError
            If model has not been fit yet
        ValueError
            If ``credible_interval`` is not within [0, 1]

        Examples
        --------
        >>> model = ArrheniusModel()
        >>> model.fit(temp_data, viscosity_data)
        >>> predictions = model.predict(np.array([300, 350, 400]))
        >>> print(predictions['mean'])  # doctest: +SKIP
        [980.5 165.3 48.2]
        """
        if self._samples is None:
            raise RuntimeError("Model must be fit before prediction")

        # Reject levels that would map to percentiles outside [0, 100]
        # (also catches NaN, for which the chained comparison is False).
        if not 0.0 <= credible_interval <= 1.0:
            raise ValueError(
                f"credible_interval must be between 0 and 1, got {credible_interval!r}"
            )

        # Convert to JAX arrays; atleast_1d lets a scalar temperature through
        # the column broadcast in _compute_predictions.
        temperature = jnp.atleast_1d(jnp.asarray(temperature))
        A_samples = jnp.asarray(self._samples["A"])
        Ea_samples = jnp.asarray(self._samples["Ea"])

        eta_samples = self._compute_predictions(A_samples, Ea_samples, temperature, self.R)

        # Summaries are taken across posterior draws (axis 0).
        mean = jnp.mean(eta_samples, axis=0)
        alpha = 1 - credible_interval
        lower = jnp.percentile(eta_samples, 100 * alpha / 2, axis=0)
        upper = jnp.percentile(eta_samples, 100 * (1 - alpha / 2), axis=0)

        return {
            "mean": np.array(mean),
            "lower": np.array(lower),
            "upper": np.array(upper),
            "samples": np.array(eta_samples),
        }