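# First install PyStan 2 using pip or conda. This script uses the PyStan 2 API
# (`pystan.StanModel`), which PyStan 3 removed, so pin the major version:
# `pip install "pystan<3"` or `conda install -c conda-forge "pystan<3"`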
#%%
# PyStan
import pystan
# visualization
import matplotlib.pyplot as plt
import seaborn as sns
# numerical calc and data handling
import numpy as np
import pandas as pd
# Fix NumPy's global seed so the synthetic data (and any later NumPy
# shuffling) is reproducible, then switch matplotlib to seaborn's theme.
np.random.seed(101)
sns.set()
# Stan program for a simple Bayesian linear regression:
#   y[i] ~ Normal(alpha + beta * x[i], sigma),  for i = 1..N
# The model block declares no priors for alpha, beta or sigma, so Stan treats
# them as (improper) flat priors over their support (sigma >= 0).
model = """
data {
int<lower=0> N;
vector[N] x;
vector[N] y;
}
parameters {
real alpha;
real beta;
real<lower=0> sigma;
}
model {
y ~ normal(alpha + beta * x, sigma);
}
"""
# Ground-truth parameters of the simulated model y = alpha + beta * x + noise.
alpha, beta, sigma = 2.0, 1.5, 2.0
# 100 predictor values drawn uniformly from [0, 10).
x = np.random.rand(100) * 10
# Noisy observations: Gaussian noise with std `sigma` around the true line.
y = np.random.normal(alpha + beta * x, scale=sigma)
# Scatter the simulated (ground-truth) observations.
plt.scatter(x, y)
plt.show()
# The sampler takes its data as a dict whose keys match the
# names declared in the Stan `data` block.
data = dict(N=len(x), x=x, y=y)
# Compile the Stan program into a model object (this is the slow step).
#TODO serialize/deserialize using pickle (next code block)
sm = pystan.StanModel(model_code=model)
# Fit the model by drawing posterior samples.
fit = sm.sampling(
    data=data,
    iter=1000,   # total iterations per chain, INCLUDING the warmup below
    chains=4,    # number of independent Markov chains
    warmup=500,  # first 500 iterations of each chain are warmup and discarded
    thin=1,      # keep every post-warmup draw (no thinning)
    seed=100,    # sampler seed, for reproducible draws
)
# Retained draws: (iter - warmup) / thin = 500 per chain, 2000 across 4 chains.
# Posterior summary as a DataFrame: one row per summarized quantity, one
# column per statistic (only the 'mean' column is used below).
summary_dict = fit.summary()
summary_df = pd.DataFrame(
    data=summary_dict['summary'],
    index=summary_dict['summary_rownames'],
    columns=summary_dict['summary_colnames'],
)
# Posterior-mean intercept and slope.
alpha_mean = summary_df.loc['alpha', 'mean']
beta_mean = summary_df.loc['beta', 'mean']
# Draw the fitted line slightly past the [0, 10) range of x for visual margin.
x_line = np.linspace(-0.5, 10.5, 100)
plt.plot(x_line, alpha_mean + beta_mean * x_line)
plt.show()
# Extract posterior draws, one array element per retained draw.
# Note: this rebinds alpha/beta/sigma, replacing the ground-truth values above.
alpha = fit['alpha']
beta = fit['beta']
sigma = fit['sigma']
lp = fit['lp__']
# Shuffle with ONE shared permutation so every index still refers to the same
# joint posterior draw. Shuffling alpha and beta independently would pair the
# intercept of one draw with the slope of another. That destroys their
# posterior correlation and yields invalid regression lines when plotted.
perm = np.random.permutation(len(alpha))
alpha, beta, sigma, lp = alpha[perm], beta[perm], sigma[perm], lp[perm]
#TODO: plot traces (regression lines)
#TODO: plot posteriors
#%% sample code to serialize a Stan model to a file using pickle
import pickle
# Reuse the model compiled above rather than calling pystan.StanModel again:
# recompiling the same `model` source is slow and yields an equivalent model.
stan_model = sm
filename = 'sample_stan_model.pkl'
# to save (pickle requires a binary-mode file)
with open(filename, 'wb') as f:
    pickle.dump(stan_model, f)
# to load -- pickle.load can execute arbitrary code, so only unpickle
# files you created yourself / trust.
with open(filename, 'rb') as f:
    stan_model = pickle.load(f)