# First install PyStan (this script uses the PyStan 2.x interface):
# `pip install pystan` or `conda install pystan`
#%% cell 1. inference using Stan
# PyStan
import pystan
# visualization
import matplotlib.pyplot as plt
import seaborn as sns
# numerical calc and data handling
import numpy as np
import pandas as pd
# to serialize stan models to file
import pickle
import os
# initialize visualization settings and seed for reproducibility
sns.set()
np.random.seed(101)
# define the Stan model code
model = """
data {
  int<lower=0> N;
  vector[N] x;
  vector[N] y;
}
parameters {
  real alpha;
  real beta;
  real<lower=0> sigma;
}
model {
  y ~ normal(alpha + beta * x, sigma);
}
"""
# ground truth parameters and synthetic data
alpha = 2.0
beta = 1.5
sigma = 2.0
x = 10 * np.random.rand(100)
y = alpha + beta * x
y = np.random.normal(y, scale=sigma)
# plot the ground truth data
plt.scatter(x, y)
plt.title("Ground truth data")
plt.show()
# prepare the data dictionary to pass to the sampler
data = {'N': len(x), 'x': x, 'y': y}
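# note: the dictionary keys must match the variable names declared in the
# Stan data block (N, x, y)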
# compile the Stan model
# * to avoid recompiling on every run, the compiled model is cached to a
#   file with pickle
sm_filename = './tmp/ln_stan_model.pkl'
if not os.path.exists(sm_filename):
    # compile and save
    os.makedirs(os.path.dirname(sm_filename), exist_ok=True)
    sm = pystan.StanModel(model_code=model)
    with open(sm_filename, 'wb') as f:
        pickle.dump(sm, f)
else:
    # load the cached model
    with open(sm_filename, 'rb') as f:
        sm = pickle.load(f)
# train and sample the model
fit = sm.sampling(
data = data,
iter=1000, # number of samples per chain
chains=4, # number of markov chains
warmup=500, # discard samples from starting point
thin=1, # retention
seed= 100 # random seed
)
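# optional sanity check: printing the fit shows per-parameter posterior
# means, credible intervals, n_eff and Rhat (Rhat close to 1.0 indicates
# the chains have converged)
print(fit)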
summary_dict = fit.summary()
summary_df = pd.DataFrame(summary_dict['summary'],
                          index=summary_dict['summary_rownames'],
                          columns=summary_dict['summary_colnames'])
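# the estimates can be compared against the ground truth values used to
# generate the data, e.g. (assuming the standard PyStan summary column names):
# print(summary_df.loc[['alpha', 'beta', 'sigma'], ['mean', 'sd', 'Rhat']])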
# extract sampling traces
alpha = fit['alpha']
beta = fit['beta']
sigma = fit['sigma']
lp = fit['lp__']  # log posterior (up to a constant); its trace should look stationary once the chains have converged
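# equivalently, all draws can be pulled at once as a dict of arrays;
# extract() returns warmup-excluded, permuted draws by default:
# posterior = fit.extract(permuted=True)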
# plot regression lines drawn from the posterior samples
x_line = np.linspace(-.5, 10.5, 100)  # x range with margins
#np.random.shuffle(alpha)
#np.random.shuffle(beta)
for i in range(len(x_line)):  # plot the first 100 posterior draws
    plt.plot(x_line, alpha[i] + beta[i] * x_line, alpha=0.05, color='blue')
# plot the posterior mean regression line
alpha_mean, beta_mean = summary_df['mean']['alpha'], summary_df['mean']['beta']
plt.plot(x_line, alpha_mean + beta_mean * x_line, color='red')
plt.title("Posterior regression lines")
plt.show()
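# optional: a 90% posterior credible band for the regression mean, computed
# from the same alpha/beta draws (a minimal sketch, not part of the plot above)
mu = alpha[:, None] + beta[:, None] * x_line[None, :]
band_low, band_high = np.percentile(mu, [5, 95], axis=0)
plt.scatter(x, y, s=10)
plt.fill_between(x_line, band_low, band_high, color='red', alpha=0.3, label='90% CI')
plt.plot(x_line, alpha_mean + beta_mean * x_line, color='red', label='posterior mean')
plt.title("Posterior mean and 90% credible band")
plt.legend()
plt.show()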
# see 3rd cell for posterior and trace visualizations for each parameter
#%% cell 2. extra: sample code to serialize a Stan model to a file using pickle
import pickle
filename = '._sample_stan_model.pkl'
# to save
with open(filename, 'wb') as f:
    stan_model = pystan.StanModel(model_code=model)
    pickle.dump(stan_model, f)
# to load
# note: open() raises FileNotFoundError if the file does not exist
with open(filename, 'rb') as f:
    stan_model = pickle.load(f)
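# optional: a hash-keyed caching helper (a hypothetical sketch of the common
# PyStan caching pattern) so the model is only recompiled when the Stan code
# itself changes
import os
from hashlib import md5

def stan_model_cache(model_code, cache_dir='.'):
    """Compile a Stan model, or load a cached compile when the code is unchanged."""
    code_hash = md5(model_code.encode('utf-8')).hexdigest()
    cache_file = os.path.join(cache_dir, 'stan_model_{}.pkl'.format(code_hash))
    try:
        with open(cache_file, 'rb') as f:
            return pickle.load(f)
    except FileNotFoundError:
        sm = pystan.StanModel(model_code=model_code)
        with open(cache_file, 'wb') as f:
            pickle.dump(sm, f)
        return sm

# e.g. stan_model = stan_model_cache(model)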
#%% cell 3. plots
# * trace and parameter visualization (alpha)
# todo: change to a generic function that plots everything for a parameter
#       (see the sketch at the end of this cell)
plt.plot(alpha)
plt.axhline(np.mean(alpha), linestyle='--', lw=1, color='lightblue')
# CI lines
low_conf, high_conf = np.percentile(alpha, 5), np.percentile(alpha, 95)
plt.axhline(low_conf, linestyle = '--', lw=1, color='darkblue')
plt.axhline(high_conf, linestyle = '--', lw=1, color='darkblue', label='90% CI')
plt.title("Diagnostic trace for alpha")
plt.ylabel("alpha")
plt.xlabel("sample")
plt.legend()
plt.show()
# * density plot
plt.hist(alpha, bins=30, density=True)
sns.kdeplot(alpha) # density as a line
# CI
plt.axvline(low_conf, linestyle='--', lw=1, color='darkblue')
plt.axvline(high_conf, linestyle='--', lw=1, color='darkblue', label='90% CI')
plt.xlabel('alpha')
plt.ylabel('density')
plt.legend()
plt.title("Posterior distribution of alpha")
plt.show()
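# a possible generic version of the plots above (a hypothetical helper,
# sketching the todo at the top of this cell): trace + posterior density
# with a 90% CI for any sampled parameter
def plot_parameter(samples, name):
    low, high = np.percentile(samples, [5, 95])
    # trace plot
    plt.plot(samples)
    plt.axhline(np.mean(samples), linestyle='--', lw=1, color='lightblue')
    plt.axhline(low, linestyle='--', lw=1, color='darkblue')
    plt.axhline(high, linestyle='--', lw=1, color='darkblue', label='90% CI')
    plt.title("Diagnostic trace for {}".format(name))
    plt.xlabel("sample")
    plt.ylabel(name)
    plt.legend()
    plt.show()
    # posterior density plot
    plt.hist(samples, bins=30, density=True)
    sns.kdeplot(samples)
    plt.axvline(low, linestyle='--', lw=1, color='darkblue')
    plt.axvline(high, linestyle='--', lw=1, color='darkblue', label='90% CI')
    plt.title("Posterior distribution of {}".format(name))
    plt.xlabel(name)
    plt.ylabel("density")
    plt.legend()
    plt.show()

# e.g. plot_parameter(beta, 'beta'); plot_parameter(sigma, 'sigma')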
#%%