# Bayesian linear regression with PyStan 2 (pystan.StanModel / sm.sampling API).
# First install PyStan using conda or pip:
#   `pip install pystan` or `conda install pystan`

#%%
# PyStan
import pystan

# visualization
import matplotlib.pyplot as plt
import seaborn as sns

# numerical calc and data handling
import numpy as np
import pandas as pd

# initialize visualization and reproducibility
sns.set()
np.random.seed(101)

# Stan model code: y ~ Normal(alpha + beta * x, sigma)
model = """
data {
    int<lower=0> N;
    vector[N] x;
    vector[N] y;
}
parameters {
    real alpha;
    real beta;
    real<lower=0> sigma;
}
model {
    y ~ normal(alpha + beta * x, sigma);
}
"""

# ground truth parameters and simulated data
alpha = 2.0
beta = 1.5
sigma = 2.0
x = 10 * np.random.rand(100)          # 100 points in [0, 10)
y = alpha + beta * x
y = np.random.normal(y, scale=sigma)  # add Gaussian noise around the true line

# plot the simulated data
plt.scatter(x, y)
plt.show()

# prepare dictionary data to pass to the sampler (keys match the Stan data block)
data = {'N': len(x), 'x': x, 'y': y}

# compile the Stan model (slow C++ build; see the pickle cell below to cache it)
sm = pystan.StanModel(model_code=model)

# sample from the posterior
fit = sm.sampling(
    data=data,
    iter=1000,   # total iterations per chain, INCLUDING warmup
    chains=4,    # number of Markov chains
    warmup=500,  # initial iterations per chain that are discarded
    thin=1,      # keep every post-warmup draw (no thinning)
    seed=100,    # random seed for the sampler
)
# -> 4 chains * (1000 - 500) = 2000 post-warmup draws per parameter

summary_dict = fit.summary()
summary_df = pd.DataFrame(
    summary_dict['summary'],
    index=summary_dict['summary_rownames'],
    columns=summary_dict['summary_colnames'],
)

# posterior means (single .loc lookup instead of chained indexing)
alpha_mean = summary_df.loc['alpha', 'mean']
beta_mean = summary_df.loc['beta', 'mean']

# plot the posterior-mean regression line over the data
x_line = np.linspace(-.5, 10.5, 100)  # data range plus small margins
plt.scatter(x, y)
plt.plot(x_line, alpha_mean + beta_mean * x_line, color='C1')
plt.show()

# extract sampling traces (these rebind the ground-truth names above)
alpha = fit['alpha']
beta = fit['beta']
sigma = fit['sigma']
lp = fit['lp__']

# Shuffle all traces with ONE shared permutation. Draw i of alpha belongs
# together with draw i of beta (and sigma, lp__). Shuffling each array
# independently would break that pairing and the posterior alpha/beta
# correlation, so lines built from (alpha[i], beta[i]) would be wrong.
perm = np.random.permutation(len(alpha))
alpha, beta, sigma, lp = alpha[perm], beta[perm], sigma[perm], lp[perm]

# TODO: plot traces (regression lines)
# TODO: plot posteriors

#%% sample code to serialize a Stan model to a file using pickle
import pickle

# Reuse the model compiled above rather than compiling an identical copy.
stan_model = sm
filename = 'sample_stan_model.pkl'

# to save
with open(filename, 'wb') as f:
    pickle.dump(stan_model, f)

# to load: `import pystan` must run before unpickling a StanModel.
# Only unpickle files you trust -- pickle can execute arbitrary code.
with open(filename, 'rb') as f:
    stan_model = pickle.load(f)