# notebooks/py/stan.py

# First install PyStan using conda or pip:
# `pip install pystan` or `conda install pystan`

#%% cell 1. inference using Stan


# PyStan
import pystan

# visualization
import matplotlib.pyplot as plt
import seaborn as sns

# numerical calc and data handling
import numpy as np
import pandas as pd

# to serialize stan models to file
import pickle
import os

# initialize visualization defaults (seaborn theme) and seed NumPy's RNG
# so the synthetic data below is reproducible
sns.set()
np.random.seed(101)

# initialize Stan model code

# Stan program for simple linear regression:
#   y_i ~ Normal(alpha + beta * x_i, sigma)
# No explicit priors are declared, so Stan uses flat (improper) priors on
# alpha and beta, and a flat prior on sigma restricted to sigma > 0.
model = """
  data {
    int<lower=0> N;
    vector[N] x;
    vector[N] y;
  }
  parameters {
    real alpha;
    real beta;
    real<lower=0> sigma;
  }
  model {
    y ~ normal(alpha + beta * x, sigma);
  }
"""

# ground-truth parameters used to simulate the data set
# (the recovered posterior means should land close to these)
alpha = 2.0   # intercept
beta = 1.5    # slope
sigma = 2.0   # observation-noise standard deviation

# simulate N=100 observations: y = alpha + beta*x + Normal(0, sigma)
x = 10 * np.random.rand(100)
y = alpha + beta * x
y = np.random.normal(y, scale=sigma)

# visualize the simulated ground-truth data
plt.scatter(x, y)
plt.title("Ground truth model")
plt.show()

# pack the observations into the dict layout expected by the Stan data block
data = {'N': len(x), 'x': x, 'y': y}

# stan model
# * Stan compilation is slow, so cache the compiled model on disk with pickle
#   and rebuild only when no cached copy exists.
# * NOTE: pickle.load on an untrusted file can execute arbitrary code — only
#   load cache files this script wrote itself.

sm_filename = './tmp/ln_stan_model.pkl'

if not os.path.exists(sm_filename):
  # compile first, then write: this avoids leaving a truncated/empty cache
  # file behind if StanModel() raises during compilation
  sm = pystan.StanModel(model_code=model)
  # ensure the ./tmp cache directory exists before opening the file
  os.makedirs(os.path.dirname(sm_filename), exist_ok=True)
  with open(sm_filename, 'wb') as f:
    pickle.dump(sm, f)
else:
  # load cached model; `with` guarantees the handle is closed
  # (the original opened it without a context manager and never closed it)
  with open(sm_filename, 'rb') as f:
    sm = pickle.load(f)


# Run NUTS sampling on the compiled model: 4 chains of 1000 iterations each,
# with the first 500 iterations per chain discarded as warmup.
sampler_config = {
  'iter': 1000,    # total draws per chain
  'chains': 4,     # independent Markov chains
  'warmup': 500,   # burn-in draws dropped from the start of each chain
  'thin': 1,       # retention: keep every draw
  'seed': 100,     # sampler RNG seed for reproducibility
}
fit = sm.sampling(data=data, **sampler_config)

# Tabulate the fit summary (one row per parameter, one column per statistic).
fit_summary = fit.summary()
summary_df = pd.DataFrame(
  fit_summary['summary'],
  index=fit_summary['summary_rownames'],
  columns=fit_summary['summary_colnames'],
)

# Posterior draws for each parameter, extracted from the fit object.
alpha = fit['alpha']
beta = fit['beta']
sigma = fit['sigma']
lp = fit['lp__']  # log posterior; better to be converged


# plot a spread of posterior regression lines over the data range
x_line = np.linspace(-.5, 10.5, 100)  # x grid, with margins past the data

# BUGFIX: the original looped `for i in range(len(x_line))`, indexing the
# posterior draw arrays by the length of the x *grid* — that only worked
# because the grid (100 points) happened to be shorter than the number of
# draws.  Index by the draw count explicitly instead.
n_lines = min(100, len(alpha))  # number of posterior draws to overlay
for a_draw, b_draw in zip(alpha[:n_lines], beta[:n_lines]):
  # each faint blue line is one posterior sample of (alpha, beta)
  plt.plot(x_line, a_draw + b_draw * x_line, alpha=0.05, color='blue')

# overlay the posterior-mean regression line in red
alpha_mean, beta_mean = summary_df['mean']['alpha'], summary_df['mean']['beta']

plt.plot(x_line, alpha_mean + beta_mean * x_line, color='red')

plt.title("Stan Diagnostics")
plt.show()
# see 3rd cell for posterior and trace visualizations for each parameter

#%% cell 2. extra: sample code to serialize a Stan model to a file using pickle
# (pickle and pystan are already imported at the top of the file, so the
# duplicate `import pickle` was removed)

filename = '._sample_stan_model.pkl'

# to save: compile the model and write the pickled object
with open(filename, 'wb') as f:
  stan_model = pystan.StanModel(model_code=model)
  pickle.dump(stan_model, f)

# to load: open(..., 'rb') already raises FileNotFoundError when the file
# does not exist, so no extra existence check is needed
with open(filename, 'rb') as f:
  stan_model = pickle.load(f)


#%% cell 3. plottings

# * trace and posterior visualization, generalized to any scalar parameter
#   (resolves the TODO: one reusable function that plots everything for a
#   parameter instead of alpha-specific inline code)
def plot_param_diagnostics(samples, name, ci=(5, 95)):
  """Plot the sampling trace and posterior density for one scalar parameter.

  Produces two figures: a trace plot with mean and credible-interval lines,
  and a histogram + KDE of the posterior with the same interval marked.

  Args:
    samples: 1-D array of posterior draws for the parameter.
    name: parameter name, used for axis labels and titles.
    ci: (low, high) percentile bounds of the credible interval; the default
        (5, 95) reproduces the original 90% CI.
  """
  low_conf = np.percentile(samples, ci[0])
  high_conf = np.percentile(samples, ci[1])
  ci_label = '{:g}% CI'.format(ci[1] - ci[0])

  # trace plot: draws in sampling order, with mean and CI reference lines
  plt.plot(samples)
  plt.axhline(np.mean(samples), linestyle='--', lw=1, color='lightblue')
  plt.axhline(low_conf, linestyle='--', lw=1, color='darkblue')
  plt.axhline(high_conf, linestyle='--', lw=1, color='darkblue', label=ci_label)
  plt.title("Diagnostic trace for " + name)
  plt.ylabel(name)
  plt.xlabel("sample")
  plt.legend()
  plt.show()

  # density plot: normalized histogram plus KDE line, same CI bounds
  plt.hist(samples, bins=30, density=True)
  sns.kdeplot(samples)  # density as a line
  plt.axvline(low_conf, linestyle='--', lw=1, color='darkblue')
  plt.axvline(high_conf, linestyle='--', lw=1, color='darkblue', label=ci_label)
  plt.xlabel(name)
  plt.ylabel('density')
  plt.legend()
  plt.title("Posterior distribution of " + name)
  plt.show()


plot_param_diagnostics(alpha, 'alpha')
#%%