# notebooks/py/stan.py

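# First install PyStan 2 using pip or conda. This script uses the PyStan 2 API
# (e.g. `pystan.StanModel` and `StanModel.sampling`), which PyStan 3 removed:
# `pip install "pystan<3"` or `conda install "pystan<3"`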

#%%
# PyStan
import pystan

# visualization
import matplotlib.pyplot as plt
import seaborn as sns

# numerical calc and data handling
import numpy as np
import pandas as pd


# Initialize plotting style and make the simulated data reproducible.
sns.set()  # switch matplotlib figures to seaborn's default styling
np.random.seed(101)  # seeds NumPy's global RNG, used below to draw x, y and to shuffle draws

# Stan program for simple linear regression:
#   y[i] ~ Normal(alpha + beta * x[i], sigma)
# The `data` block declares the inputs the sampler expects (matched by key in
# the `data` dict passed to `sm.sampling` below); `parameters` are the unknowns
# to infer. No prior statements are given, so Stan treats alpha, beta and sigma
# as having improper uniform priors over their declared support
# (sigma is constrained to be >= 0).

model = """
  data {
    int<lower=0> N;
    vector[N] x;
    vector[N] y;
  }
  parameters {
    real alpha;
    real beta;
    real<lower=0> sigma;
  }
  model {
    y ~ normal(alpha + beta * x, sigma);
  }
"""

# Ground-truth parameters used to simulate the data; the posterior obtained
# below should concentrate near these values.
alpha = 2.0
beta = 1.5
sigma = 2.0

# Simulate 100 points: x uniform on [0, 10), and y drawn from a normal
# distribution centred on the true line alpha + beta * x with std. dev. sigma.
x = 10 * np.random.rand(100)
y = np.random.normal(alpha + beta * x, scale=sigma)

# Plot the simulated (noisy) observations.
plt.scatter(x, y)
plt.xlabel('x')
plt.ylabel('y')
plt.title('Simulated data')
plt.show()

# Sampler inputs: the keys must match the variables declared in the Stan
# program's `data` block (N, x, y).
data = dict(N=len(x), x=x, y=y)

# Build (compile) the Stan model from the program text in `model`.
# TODO: cache the compiled model with pickle (demonstrated in the next cell).
sm = pystan.StanModel(model_code=model)


# Draw posterior samples. `iter` counts ALL iterations per chain, warmup
# included, so each chain keeps iter - warmup = 500 draws (with thin=1),
# i.e. 2000 draws in total across the 4 chains.
fit = sm.sampling(
    data=data,
    iter=1000,   # total iterations per chain (warmup + kept draws)
    chains=4,    # number of independent Markov chains
    warmup=500,  # initial iterations per chain used for adaptation, then discarded
    thin=1,      # keep every draw (no thinning)
    seed=100,    # random seed for reproducible sampling
)

# Collect the posterior summary into a DataFrame: one row per parameter,
# one column per summary statistic (e.g. 'mean').
summary_dict = fit.summary()

summary_df = pd.DataFrame(
    summary_dict['summary'],
    index=summary_dict['summary_rownames'],
    columns=summary_dict['summary_colnames'],
)

# Posterior-mean estimates of the intercept and slope.
alpha_mean = summary_df.loc['alpha', 'mean']
beta_mean = summary_df.loc['beta', 'mean']

# Overlay the fitted line on the data so the fit can be judged visually
# (the earlier plt.show() already closed the scatter figure). The x-range is
# taken from the data, with a small margin on each side.
x_line = np.linspace(x.min() - 0.5, x.max() + 0.5, 100)
plt.scatter(x, y, label='data')
plt.plot(x_line, alpha_mean + beta_mean * x_line, color='C1', label='posterior-mean fit')
plt.xlabel('x')
plt.ylabel('y')
plt.legend()
plt.show()

# Extract posterior draws per parameter (with PyStan 2's default `extract`
# settings these are presumably post-warmup draws pooled across chains —
# confirm for the installed version).
# NOTE: this rebinds alpha/beta/sigma, replacing the ground-truth values above.
alpha = fit['alpha']
beta = fit['beta']
sigma = fit['sigma']
lp = fit['lp__']  # Stan's lp__: log density (up to a constant) of each draw

# Shuffle all traces with ONE shared permutation. Shuffling alpha and beta
# independently would break the pairing of draw i across parameters and
# destroy their joint (correlated) posterior, so lines built from
# (alpha[i], beta[i]) would no longer be posterior samples.
perm = np.random.permutation(len(alpha))
alpha = alpha[perm]
beta = beta[perm]
sigma = sigma[perm]
lp = lp[perm]

# TODO: plot traces (regression lines from paired draws alpha[i], beta[i])
# TODO: plot posteriors

#%% sample code to serialize a Stan model to a file using pickle

import pickle

# Reuse the model compiled earlier from the same `model` string instead of
# compiling the identical program a second time (compilation is the slow step).
stan_model = sm
filename = 'sample_stan_model.pkl'

# to save
with open(filename, 'wb') as f:
    pickle.dump(stan_model, f)

# to load
# Only unpickle files you created yourself: pickle.load can execute arbitrary
# code embedded in a malicious file.
with open(filename, 'rb') as f:
    stan_model = pickle.load(f)