| | x | y |
|---|---|---|
| 0 | 0.000000 | 3.113179 |
| 1 | 0.010101 | 3.774512 |
| 2 | 0.020202 | 4.045562 |
| 3 | 0.030303 | 3.207971 |
| 4 | 0.040404 | 3.336638 |
| ... | ... | ... |
| 95 | 0.959596 | 1.951793 |
| 96 | 0.969697 | 0.224769 |
| 97 | 0.979798 | -0.387220 |
| 98 | 0.989899 | 1.304032 |
| 99 | 1.000000 | 0.174600 |
100 rows × 2 columns
Lecture 25
import pymc as pm
X = d.x.to_numpy().reshape(-1, 1)  # PyMC's GP API expects a 2-D (n, input_dim) array
y = d.y.to_numpy()
with pm.Model() as model:
    l = pm.Gamma("l", alpha=2, beta=1)
    s = pm.HalfCauchy("s", beta=5)
    nug = pm.HalfCauchy("nug", beta=5)
    cov = s**2 * pm.gp.cov.ExpQuad(input_dim=1, ls=l)
    gp = pm.gp.Marginal(cov_func=cov)
    y_ = gp.marginal_likelihood(
        "y", X=X, y=y, sigma=nug
    )
Beyond letting you choose different step methods, PyMC can also run NUTS using different sampler implementations.
The implementation is selected with the nuts_sampler argument of pm.sample(), which currently supports:
pymc - PyMC's own NUTS sampler, run through PyTensor's compiled C backend
blackjax - uses BlackJAX, a library of samplers written for JAX
numpyro - uses NumPyro, a probabilistic programming library built on JAX that provides a NumPy backend for Pyro
nutpie - wraps the nuts-rs Rust library (a slight variation on the standard NUTS implementation)
6.11 s ± 35.2 ms per loop (mean ± std. dev. of 3 runs, 1 loop each)
%%timeit -r 3
with model:
    post_jax = pm.sample(nuts_sampler="blackjax", chains=4, progressbar=False)
3.96 s ± 21.3 ms per loop (mean ± std. dev. of 3 runs, 1 loop each)
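Outside a notebook, the same kind of timing comparison can be made with the standard-library timeit module. The sketch below is a minimal, hedged stand-in for %%timeit -r 3; the helper time_callables is hypothetical (not part of PyMC), and the commented usage assumes the model defined above plus an installed blackjax.

```python
import statistics
import timeit


def time_callables(fns, repeat=3, number=1):
    """Time each zero-argument callable in `fns`, in the spirit of `%%timeit -r 3`.

    Returns {name: (mean_seconds, std_seconds)} per loop, computed across
    `repeat` runs of `number` calls each (population standard deviation).
    """
    results = {}
    for name, fn in fns.items():
        runs = timeit.repeat(fn, repeat=repeat, number=number)  # total seconds per run
        per_loop = [t / number for t in runs]
        results[name] = (statistics.mean(per_loop), statistics.pstdev(per_loop))
    return results


# Hypothetical usage with the PyMC model above (needs pymc and blackjax installed):
# with model:
#     timings = time_callables({
#         "pymc": lambda: pm.sample(chains=4, progressbar=False),
#         "blackjax": lambda: pm.sample(nuts_sampler="blackjax", chains=4, progressbar=False),
#     })
# for name, (mean, sd) in timings.items():
#     print(f"{name}: {mean:.2f} s ± {sd * 1000:.1f} ms")
```

Here number is fixed at 1 to match the "1 loop each" reported above, since a single sampling run already takes seconds.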
At the moment, both Python and R offer two kinds of interface to Stan:
pystan & RStan - native-language interfaces to the underlying Stan C++ libraries
CmdStanPy & CmdStanR - wrappers around the CmdStan command-line interface
Any of these tools requires a modern C++ toolchain (with C++17 support).
Stan programs are divided into blocks according to their purpose. Every block is optional, but the blocks that are present must appear in the order given below.
functions {
// user-defined functions
}
data {
// declares the required data for the model
}
transformed data {
// allows the definition of constants and transformations of the data
}
parameters {
// declares the model’s parameters
}
transformed parameters {
// variables defined in terms of data and parameters
}
model {
// defines the log probability function
}
generated quantities {
// derived quantities based on parameters, data, and random number generation
}
CmdStanMCMC: model=bernoulli chains=4['method=sample', 'algorithm=hmc', 'adapt', 'engaged=1']
csv_files:
/var/folders/v7/wrxd7cdj6l5gzr0191__m9lr0000gr/T/tmpmrtup1np/bernoulli_umwgr17/bernoulli-20250416091714_1.csv
/var/folders/v7/wrxd7cdj6l5gzr0191__m9lr0000gr/T/tmpmrtup1np/bernoulli_umwgr17/bernoulli-20250416091714_2.csv
/var/folders/v7/wrxd7cdj6l5gzr0191__m9lr0000gr/T/tmpmrtup1np/bernoulli_umwgr17/bernoulli-20250416091714_3.csv
/var/folders/v7/wrxd7cdj6l5gzr0191__m9lr0000gr/T/tmpmrtup1np/bernoulli_umwgr17/bernoulli-20250416091714_4.csv
output_files:
/var/folders/v7/wrxd7cdj6l5gzr0191__m9lr0000gr/T/tmpmrtup1np/bernoulli_umwgr17/bernoulli-20250416091714_0-stdout.txt
/var/folders/v7/wrxd7cdj6l5gzr0191__m9lr0000gr/T/tmpmrtup1np/bernoulli_umwgr17/bernoulli-20250416091714_1-stdout.txt
/var/folders/v7/wrxd7cdj6l5gzr0191__m9lr0000gr/T/tmpmrtup1np/bernoulli_umwgr17/bernoulli-20250416091714_2-stdout.txt
/var/folders/v7/wrxd7cdj6l5gzr0191__m9lr0000gr/T/tmpmrtup1np/bernoulli_umwgr17/bernoulli-20250416091714_3-stdout.txt
{'theta': array([0.16229, 0.17837, 0.1677 , 0.3568 , 0.14067, 0.37961, 0.20291, 0.44423, 0.39214, 0.19888, 0.17334, 0.52999, 0.42058, 0.16633, 0.16216, 0.26243, 0.26993, 0.26993, 0.2764 , 0.4083 , 0.37723,
0.3893 , 0.26481, 0.39696, 0.5198 , 0.47839, 0.08574, 0.27821, 0.21389, 0.22669, ..., 0.45443, 0.45443, 0.35297, 0.34254, 0.13414, 0.09956, 0.12715, 0.13077, 0.30545, 0.18828, 0.18828,
0.19007, 0.11608, 0.43204, 0.35169, 0.30381, 0.30381, 0.12406, 0.12406, 0.17983, 0.2686 , 0.4043 , 0.39981, 0.27243, 0.28782, 0.44704, 0.50479, 0.48816, 0.61159, 0.33675], shape=(4000,))}
Checking sampler transitions treedepth.
Treedepth satisfactory for all transitions.
Checking sampler transitions for divergences.
No divergent transitions found.
Checking E-BFMI - sampler transitions HMC potential energy.
E-BFMI satisfactory.
Rank-normalized split effective sample size satisfactory for all parameters.
Rank-normalized split R-hat values satisfactory for all parameters.
Processing complete, no problems detected.
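The block-ordering rule for Stan programs can be checked mechanically. This is a rough sketch only: block_order_ok is a hypothetical helper (not part of Stan or CmdStanPy) that looks for block headers at the start of a line with a regular expression, and it does not skip comments or strings, so the Stan compiler (stanc) remains the real authority.

```python
import re

# Stan's program blocks in their required order; every block is optional.
STAN_BLOCKS = [
    "functions",
    "data",
    "transformed data",
    "parameters",
    "transformed parameters",
    "model",
    "generated quantities",
]

# A block header is a block name at the start of a line followed by "{".
_HEADER = re.compile(
    r"^[ \t]*(" + "|".join(b.replace(" ", r"\s+") for b in STAN_BLOCKS) + r")\s*\{",
    re.MULTILINE,
)


def block_order_ok(program):
    """Return True if the blocks present appear at most once each, in Stan's order."""
    # Normalise internal whitespace, e.g. "transformed   data" -> "transformed data".
    found = [" ".join(m.group(1).split()) for m in _HEADER.finditer(program)]
    positions = [STAN_BLOCKS.index(name) for name in found]
    # Strictly increasing positions: correct order and no repeated blocks.
    return all(a < b for a, b in zip(positions, positions[1:]))
```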
Lec25/gp.stan
data {
  int<lower=1> N;
  array[N] real x;
  vector[N] y;
}
parameters {
  real<lower=0> l;
  real<lower=0> s;
  real<lower=0> nug;
}
model {
  // Covariance
  matrix[N, N] K = gp_exp_quad_cov(x, s, l);
  K = add_diag(K, nug^2);
  matrix[N, N] L = cholesky_decompose(K);
  // priors
  l ~ gamma(2, 1);
  s ~ cauchy(0, 5);
  nug ~ cauchy(0, 1);
  // model
  y ~ multi_normal_cholesky(rep_vector(0, N), L);
}
09:17:14 - cmdstanpy - INFO - CmdStan start processing
09:17:14 - cmdstanpy - INFO - Chain [1] start processing
09:17:14 - cmdstanpy - INFO - Chain [2] start processing
09:17:14 - cmdstanpy - INFO - Chain [3] start processing
09:17:14 - cmdstanpy - INFO - Chain [4] start processing
09:17:17 - cmdstanpy - INFO - Chain [4] done processing
09:17:17 - cmdstanpy - INFO - Chain [1] done processing
09:17:17 - cmdstanpy - INFO - Chain [3] done processing
09:17:17 - cmdstanpy - INFO - Chain [2] done processing
09:17:17 - cmdstanpy - WARNING - Non-fatal error during sampling:
Exception: cholesky_decompose: A is not symmetric. A[1,2] = nan, but A[2,1] = nan (in 'gp.stan', line 15, column 2 to column 41)
Exception: cholesky_decompose: A is not symmetric. A[1,2] = nan, but A[2,1] = nan (in 'gp.stan', line 15, column 2 to column 41)
Exception: gp_exp_quad_cov: length_scale is 0, but must be positive! (in 'gp.stan', line 13, column 2 to column 44)
Exception: gp_exp_quad_cov: length_scale is 0, but must be positive! (in 'gp.stan', line 13, column 2 to column 44)
Exception: cholesky_decompose: Matrix m is not positive definite (in 'gp.stan', line 15, column 2 to column 41)
Exception: cholesky_decompose: Matrix m is not positive definite (in 'gp.stan', line 15, column 2 to column 41)
Exception: cholesky_decompose: Matrix m is not positive definite (in 'gp.stan', line 15, column 2 to column 41)
Consider re-running with show_console=True if the above output is unclear!
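These non-fatal exceptions most likely come from early warmup, when the sampler can propose extreme values on the unconstrained scale; Stan rejects such proposals and carries on. As a hedged, plain-Python illustration (the helpers below are hypothetical, not Stan's implementation), two mechanisms can produce messages like these: a real<lower=0> parameter is stored as u with value exp(u), which can underflow to exactly 0, and an extreme length scale can make the covariance matrix numerically singular unless the nugget on the diagonal is large enough.

```python
import math


def exp_quad_cov(x, s, l):
    """Squared-exponential covariance: k(x, x') = s^2 * exp(-(x - x')^2 / (2 l^2))."""
    return [[s**2 * math.exp(-((a - b) ** 2) / (2 * l**2)) for b in x] for a in x]


def cholesky(A):
    """Lower-triangular Cholesky factor; raises ValueError if A is not positive definite."""
    n = len(A)
    L = [[0.0] * n for _ in range(n)]
    for i in range(n):
        for j in range(i + 1):
            total = A[i][j] - sum(L[i][k] * L[j][k] for k in range(j))
            if i == j:
                if total <= 0.0:
                    raise ValueError("matrix is not positive definite")
                L[i][i] = math.sqrt(total)
            else:
                L[i][j] = total / L[j][j]
    return L


# 1) A <lower=0> parameter equals exp(u); an extreme unconstrained u underflows to 0.0.
underflowed = math.exp(-800.0)

# 2) With an enormous length scale every entry of K rounds to s^2, so K is exactly
#    singular and the factorisation fails; adding a nugget to the diagonal fixes it.
x = [i / 9 for i in range(10)]
K = exp_quad_cov(x, s=1.0, l=1e9)
try:
    cholesky(K)
    no_nugget_ok = True
except ValueError:
    no_nugget_ok = False

nug = 0.1
K_nug = [[K[i][j] + (nug**2 if i == j else 0.0) for j in range(len(x))] for i in range(len(x))]
with_nugget_ok = len(cholesky(K_nug)) == len(x)
```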
| | Mean | MCSE | StdDev | MAD | 5% | 50% | 95% | ESS_bulk | ESS_tail | R_hat |
|---|---|---|---|---|---|---|---|---|---|---|
| lp__ | -43.020600 | 0.032322 | 1.262810 | 1.073850 | -45.525500 | -42.711700 | -41.604900 | 1701.51 | 1828.55 | 1.00344 |
| l | 0.108023 | 0.000618 | 0.025276 | 0.023351 | 0.072006 | 0.104541 | 0.154795 | 1893.82 | 1925.06 | 1.00196 |
| s | 2.263910 | 0.022461 | 0.862234 | 0.633634 | 1.338890 | 2.064410 | 3.828720 | 1856.83 | 1641.52 | 1.00099 |
| nug | 0.730725 | 0.001202 | 0.057528 | 0.056061 | 0.641708 | 0.727218 | 0.830269 | 2308.52 | 2305.62 | 1.00166 |
Checking sampler transitions treedepth.
Treedepth satisfactory for all transitions.
Checking sampler transitions for divergences.
No divergent transitions found.
Checking E-BFMI - sampler transitions HMC potential energy.
E-BFMI satisfactory.
Rank-normalized split effective sample size satisfactory for all parameters.
Rank-normalized split R-hat values satisfactory for all parameters.
Processing complete, no problems detected.
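The model block of gp.stan, like the PyMC model's gp.marginal_likelihood (which passes sigma=nug), evaluates log N(y | 0, K + nug^2 I) through a Cholesky factor: log p(y) = -1/2 y^T (K + nug^2 I)^-1 y - 1/2 log|K + nug^2 I| - (N/2) log(2 pi). A plain-Python sketch of that computation for fixed hyperparameters, for intuition only; the function names are hypothetical, and Stan's lp__ additionally includes the priors and Jacobian terms (and drops constants), so the values are not directly comparable to the summary above.

```python
import math


def exp_quad_cov(x, s, l):
    """Squared-exponential covariance, as in gp_exp_quad_cov(x, s, l)."""
    return [[s**2 * math.exp(-((a - b) ** 2) / (2 * l**2)) for b in x] for a in x]


def cholesky(A):
    """Lower-triangular L with A = L L^T; raises ValueError if A is not positive definite."""
    n = len(A)
    L = [[0.0] * n for _ in range(n)]
    for i in range(n):
        for j in range(i + 1):
            total = A[i][j] - sum(L[i][k] * L[j][k] for k in range(j))
            if i == j:
                if total <= 0.0:
                    raise ValueError("matrix is not positive definite")
                L[i][i] = math.sqrt(total)
            else:
                L[i][j] = total / L[j][j]
    return L


def solve_lower(L, b):
    """Forward substitution: solve L z = b for lower-triangular L."""
    z = []
    for i in range(len(b)):
        z.append((b[i] - sum(L[i][k] * z[k] for k in range(i))) / L[i][i])
    return z


def gp_log_marginal_likelihood(x, y, s, l, nug):
    """log N(y | 0, K + nug^2 I), with K the squared-exponential covariance of x."""
    n = len(x)
    K = exp_quad_cov(x, s, l)
    for i in range(n):
        K[i][i] += nug**2  # plays the role of add_diag(K, nug^2)
    L = cholesky(K)  # plays the role of cholesky_decompose(K)
    z = solve_lower(L, y)  # z = L^{-1} y, so y^T K^{-1} y = z^T z
    quad = sum(zi * zi for zi in z)
    log_det = 2.0 * sum(math.log(L[i][i]) for i in range(n))  # log|K| from the factor
    return -0.5 * quad - 0.5 * log_det - 0.5 * n * math.log(2.0 * math.pi)
```

Working with the Cholesky factor avoids forming an explicit inverse and gives the log-determinant for free, which is why both Stan (multi_normal_cholesky) and PyMC take this route.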
The nutpie package can also compile and run Stan models; it uses the BridgeStan library to interface with Stan.
import nutpie
m = nutpie.compile_stan_model(filename="Lec25/gp.stan")
m = m.with_data(x=d["x"],y=d["y"],N=len(d["x"]))
gp_fit_nutpie = nutpie.sample(m, chains=4)
Sampler Progress
Total Chains: 4
Active Chains: 0
Finished Chains: 4
| Draws | Divergences | Step Size | Gradients/Draw |
|---|---|---|---|
| 1400 | 0 | 0.83 | 3 |
| 1400 | 0 | 0.82 | 3 |
| 1400 | 0 | 0.80 | 7 |
| 1400 | 0 | 0.80 | 7 |
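The data handed to nutpie's with_data (or to CmdStanPy's sample(data=...)) must match the names and sizes declared in the Stan data block, including N. A small sketch under that assumption; make_stan_data is a hypothetical helper, not part of nutpie or CmdStanPy. Deriving N from x keeps the two from falling out of sync, and converting to plain floats keeps the dict JSON-serialisable, which is the data file format CmdStan accepts.

```python
import json


def make_stan_data(x, y, **extra):
    """Hypothetical helper: build the data for gp.stan's data block (N, x, y).

    Values become plain Python floats so the dict is JSON-serialisable;
    extra entries (e.g. Np and xp for gp2.stan) are passed through unchanged.
    """
    x = [float(v) for v in x]
    y = [float(v) for v in y]
    if len(x) != len(y):
        raise ValueError(f"x has {len(x)} values but y has {len(y)}")
    data = {"N": len(x), "x": x, "y": y}
    data.update(extra)
    return data


# The resulting dict serialises directly to JSON.
print(json.dumps(make_stan_data([0.0, 1.0], [0.5, -0.5])))
# {"N": 2, "x": [0.0, 1.0], "y": [0.5, -0.5]}
```

With the objects from the cells above, this would be used as m.with_data(**make_stan_data(d["x"], d["y"])).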
Lec25/gp2.stan
functions {
  // From https://mc-stan.org/docs/stan-users-guide/gaussian-processes.html#predictive-inference-with-a-gaussian-process
  vector gp_pred_rng(
    array[] real x2,
    vector y1,
    array[] real x1,
    real alpha,
    real rho,
    real sigma,
    real delta
  ) {
    int N1 = rows(y1);
    int N2 = size(x2);
    vector[N2] f2;
    {
      matrix[N1, N1] L_K;
      vector[N1] K_div_y1;
      matrix[N1, N2] k_x1_x2;
      matrix[N1, N2] v_pred;
      vector[N2] f2_mu;
      matrix[N2, N2] cov_f2;
      matrix[N2, N2] diag_delta;
      matrix[N1, N1] K;
      K = gp_exp_quad_cov(x1, alpha, rho);
      for (n in 1:N1) {
        K[n, n] = K[n, n] + square(sigma);
      }
      L_K = cholesky_decompose(K);
      K_div_y1 = mdivide_left_tri_low(L_K, y1);
      K_div_y1 = mdivide_right_tri_low(K_div_y1', L_K)';
      k_x1_x2 = gp_exp_quad_cov(x1, x2, alpha, rho);
      f2_mu = (k_x1_x2' * K_div_y1);
      v_pred = mdivide_left_tri_low(L_K, k_x1_x2);
      cov_f2 = gp_exp_quad_cov(x2, alpha, rho) - v_pred' * v_pred;
      diag_delta = diag_matrix(rep_vector(delta, N2));
      f2 = multi_normal_rng(f2_mu, cov_f2 + diag_delta);
    }
    return f2;
  }
}
data {
  int<lower=1> N; // number of observations
  array[N] real x; // univariate covariate
  vector[N] y; // target variable
  int<lower=1> Np; // number of test points
  array[Np] real xp; // univariate test points
}
transformed data {
  real delta = 1e-9;
}
parameters {
  real<lower=0> l;
  real<lower=0> s;
  real<lower=0> nug;
}
model {
  // Covariance
  matrix[N, N] K = gp_exp_quad_cov(x, s, l);
  K = add_diag(K, nug^2);
  matrix[N, N] L = cholesky_decompose(K);
  // priors
  l ~ gamma(2, 1);
  s ~ cauchy(0, 5);
  nug ~ cauchy(0, 1);
  // model
  y ~ multi_normal_cholesky(rep_vector(0, N), L);
}
generated quantities {
  vector[Np] f = gp_pred_rng(xp, y, x, s, l, nug, delta);
}
09:17:43 - cmdstanpy - INFO - CmdStan start processing
09:17:43 - cmdstanpy - INFO - Chain [1] start processing
09:17:43 - cmdstanpy - INFO - Chain [2] start processing
09:17:43 - cmdstanpy - INFO - Chain [3] start processing
09:17:43 - cmdstanpy - INFO - Chain [4] start processing
09:17:46 - cmdstanpy - INFO - Chain [4] done processing
09:17:46 - cmdstanpy - INFO - Chain [1] done processing
09:17:46 - cmdstanpy - INFO - Chain [2] done processing
09:17:46 - cmdstanpy - INFO - Chain [3] done processing
09:17:46 - cmdstanpy - WARNING - Non-fatal error during sampling:
Exception: cholesky_decompose: Matrix m is not positive definite (in 'gp2.stan', line 61, column 2 to column 41)
Exception: cholesky_decompose: Matrix m is not positive definite (in 'gp2.stan', line 61, column 2 to column 41)
Exception: gp_exp_quad_cov: sigma is 0, but must be positive! (in 'gp2.stan', line 59, column 2 to column 44)
Exception: cholesky_decompose: Matrix m is not positive definite (in 'gp2.stan', line 61, column 2 to column 41)
Exception: cholesky_decompose: Matrix m is not positive definite (in 'gp2.stan', line 61, column 2 to column 41)
Exception: cholesky_decompose: Matrix m is not positive definite (in 'gp2.stan', line 61, column 2 to column 41)
Exception: cholesky_decompose: Matrix m is not positive definite (in 'gp2.stan', line 61, column 2 to column 41)
Exception: cholesky_decompose: Matrix m is not positive definite (in 'gp2.stan', line 61, column 2 to column 41)
Consider re-running with show_console=True if the above output is unclear!
| | Mean | MCSE | StdDev | MAD | 5% | 50% | 95% | ESS_bulk | ESS_tail | R_hat |
|---|---|---|---|---|---|---|---|---|---|---|
| lp__ | -42.978800 | 0.029290 | 1.230960 | 1.058280 | -45.340000 | -42.688600 | -41.599400 | 1835.49 | 2245.48 | 1.00102 |
| l | 0.106693 | 0.000558 | 0.024443 | 0.021555 | 0.071577 | 0.103845 | 0.152534 | 2073.21 | 1806.14 | 1.00285 |
| s | 2.217030 | 0.020762 | 0.829421 | 0.598059 | 1.316430 | 2.025060 | 3.716110 | 2131.79 | 1924.12 | 1.00211 |
| nug | 0.730309 | 0.001118 | 0.056770 | 0.055789 | 0.644002 | 0.726151 | 0.828365 | 2617.77 | 2518.06 | 1.00335 |
| f[1] | 3.474290 | 0.007158 | 0.436905 | 0.430087 | 2.749300 | 3.469260 | 4.189870 | 3727.72 | 3818.26 | 1.00035 |
| ... | ... | ... | ... | ... | ... | ... | ... | ... | ... | ... |
| f[117] | -0.687289 | 0.033885 | 2.052400 | 1.841560 | -4.102700 | -0.637946 | 2.538720 | 3738.04 | 3595.41 | 1.00081 |
| f[118] | -0.653030 | 0.034672 | 2.101530 | 1.872460 | -4.197940 | -0.578823 | 2.647560 | 3747.05 | 3566.01 | 1.00086 |
| f[119] | -0.616564 | 0.035333 | 2.143370 | 1.913370 | -4.244260 | -0.547837 | 2.696430 | 3752.84 | 3568.81 | 1.00116 |
| f[120] | -0.578995 | 0.035891 | 2.179080 | 1.926490 | -4.269130 | -0.501916 | 2.816850 | 3764.33 | 3479.13 | 1.00133 |
| f[121] | -0.541126 | 0.036365 | 2.209780 | 1.911760 | -4.319800 | -0.484247 | 2.923980 | 3768.46 | 3514.99 | 1.00115 |
125 rows × 10 columns
array([ 3.47429, 3.57525, 3.62487, 3.61779, 3.55041, 3.42122, 3.23093, 2.98257, 2.68135, 2.33445, 1.95065, 1.53986, 1.11254, 0.67917, 0.24975, -0.16664, -0.56212, -0.93022, -1.26588,
-1.56547, -1.8266 , -2.04792, -2.22885, -2.36928, -2.4694 , -2.52945, -2.54968, -2.53042, -2.47218, -2.37592, -2.24323, -2.0765 , -1.87898, -1.6547 , -1.40827, -1.14463, -0.8689 , -0.58623,
-0.30175, -0.0207 , 0.25158, 0.50953, 0.74746, 0.95963, 1.14046, 1.2848 , 1.38822, 1.44747, 1.46083, 1.42853, 1.35294, 1.23873, 1.09269, 0.92342, 0.74079, 0.55531, 0.37741,
0.21679, 0.0817 , -0.02153, -0.08898, -0.11938, -0.11416, -0.07734, -0.01514, 0.06461, 0.15324, 0.24215, 0.32376, 0.39228, 0.44441, 0.47957, 0.49987, 0.50972, 0.51507, 0.52254,
0.53852, 0.56833, 0.61563, 0.68204, 0.76708, 0.86828, 0.98154, 1.10165, 1.22279, 1.33903, 1.44477, 1.53496, 1.60526, 1.65199, 1.67219, 1.66366, 1.62504, 1.556 , 1.45736,
1.33125, 1.18103, 1.0112 , 0.8271 , 0.63462, 0.43976, 0.24831, 0.06554, -0.10406, -0.25697, -0.39064, -0.50358, -0.5953 , -0.66617, -0.71728, -0.75026, -0.76712, -0.77006, -0.76136,
-0.74328, -0.71794, -0.68729, -0.65303, -0.61656, -0.579 , -0.54113])
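gp_pred_rng computes the GP predictive mean k(x1, x2)^T K^-1 y1 and covariance k(x2, x2) - v^T v with v = L^-1 k(x1, x2), where K = k(x1, x1) + sigma^2 I = L L^T, and then draws from a multivariate normal with a small jitter delta on the diagonal. Because generated quantities runs once per posterior draw, the means of f reported above also average over uncertainty in s, l and nug. A plain-Python sketch of the same linear algebra for a single set of hyperparameters (alpha = s, rho = l, sigma = nug, as in the call in gp2.stan); the function names are hypothetical and this is for intuition, not a replacement for the Stan code.

```python
import math
import random


def sq_exp(a, b, s, l):
    """Squared-exponential kernel k(a, b) = s^2 exp(-(a - b)^2 / (2 l^2))."""
    return s**2 * math.exp(-((a - b) ** 2) / (2 * l**2))


def cholesky(A):
    """Lower-triangular L with A = L L^T; raises ValueError if A is not positive definite."""
    n = len(A)
    L = [[0.0] * n for _ in range(n)]
    for i in range(n):
        for j in range(i + 1):
            total = A[i][j] - sum(L[i][k] * L[j][k] for k in range(j))
            if i == j:
                if total <= 0.0:
                    raise ValueError("matrix is not positive definite")
                L[i][i] = math.sqrt(total)
            else:
                L[i][j] = total / L[j][j]
    return L


def solve_lower(L, b):
    """Forward substitution: solve L z = b (like mdivide_left_tri_low)."""
    z = []
    for i in range(len(b)):
        z.append((b[i] - sum(L[i][k] * z[k] for k in range(i))) / L[i][i])
    return z


def solve_lower_t(L, b):
    """Back substitution: solve L^T w = b for lower-triangular L."""
    n = len(b)
    w = [0.0] * n
    for i in reversed(range(n)):
        w[i] = (b[i] - sum(L[k][i] * w[k] for k in range(i + 1, n))) / L[i][i]
    return w


def gp_predict(xp, y, x, s, l, nug):
    """Predictive mean and covariance of f at xp, given (x, y) and fixed s, l, nug."""
    K = [[sq_exp(a, b, s, l) + (nug**2 if i == j else 0.0) for j, b in enumerate(x)]
         for i, a in enumerate(x)]
    L = cholesky(K)  # L_K
    alpha = solve_lower_t(L, solve_lower(L, y))  # K^{-1} y, i.e. K_div_y1
    mean, V = [], []
    for p in xp:
        k = [sq_exp(xi, p, s, l) for xi in x]  # one column of k_x1_x2
        mean.append(sum(ki * ai for ki, ai in zip(k, alpha)))
        V.append(solve_lower(L, k))  # one column of v_pred
    cov = [[sq_exp(p, q, s, l) - sum(u * v for u, v in zip(V[i], V[j]))
            for j, q in enumerate(xp)] for i, p in enumerate(xp)]
    return mean, cov


def gp_pred_draw(mean, cov, delta=1e-9, rng=None):
    """One draw from N(mean, cov + delta I), mirroring multi_normal_rng(f2_mu, cov_f2 + diag_delta)."""
    rng = rng if rng is not None else random.Random(0)
    m = len(mean)
    C = cholesky([[cov[i][j] + (delta if i == j else 0.0) for j in range(m)] for i in range(m)])
    z = [rng.gauss(0.0, 1.0) for _ in range(m)]
    return [mean[i] + sum(C[i][k] * z[k] for k in range(i + 1)) for i in range(m)]
```

Far from the training inputs the predictive mean falls back to the prior mean of 0 and the variance to s^2, which matches the widening 5%/95% intervals for the last f entries in the summary above.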
Sta 663 - Spring 2025