# First install PyStan using conda or pip:
# `pip install pystan` or `conda install pystan`
# Note: this script uses the PyStan 2.x interface (pystan.StanModel / sampling),
# so make sure a 2.x release is installed (e.g. `pip install "pystan<3"`).

#%% cell 1. inference using Stan

# PyStan
import pystan
# visualization
import matplotlib.pyplot as plt
import seaborn as sns
# numerical calc and data handling
import numpy as np
import pandas as pd
# to serialize stan models to file
import pickle
import os

# initialize visualization and reproducibility
sns.set()
np.random.seed(101)

# Stan model code: simple linear regression with Gaussian noise
model = """
data {
    int<lower=0> N;
    vector[N] x;
    vector[N] y;
}
parameters {
    real alpha;
    real beta;
    real<lower=0> sigma;
}
model {
    y ~ normal(alpha + beta * x, sigma);
}
"""

# ground truth parameters and simulated data
alpha = 2.0
beta = 1.5
sigma = 2.0

x = 10 * np.random.rand(100)
y = alpha + beta * x
y = np.random.normal(y, scale=sigma)

# plot the simulated (ground truth) data
plt.scatter(x, y)
plt.title("Ground truth model")
plt.show()

# prepare dictionary data to pass to the sampler
data = {'N': len(x), 'x': x, 'y': y}

# compile the Stan model
# * to speed things up, the compiled model is serialized/deserialized to a file with pickle
# * a reusable helper version of this logic is sketched in cell 2b below
sm_filename = './tmp/ln_stan_model.pkl'
if not os.path.exists(sm_filename):
    # compile and save
    sm = pystan.StanModel(model_code=model)
    os.makedirs(os.path.dirname(sm_filename), exist_ok=True)
    with open(sm_filename, 'wb') as f:
        pickle.dump(sm, f)
else:
    # load the previously compiled model
    with open(sm_filename, 'rb') as f:
        sm = pickle.load(f)

# run MCMC sampling
fit = sm.sampling(
    data=data,
    iter=1000,    # total iterations per chain (including warmup)
    chains=4,     # number of Markov chains
    warmup=500,   # warmup iterations discarded from the start of each chain
    thin=1,       # keep every sample (no thinning)
    seed=100      # random seed
)

summary_dict = fit.summary()
summary_df = pd.DataFrame(summary_dict['summary'],
                          index=summary_dict['summary_rownames'],
                          columns=summary_dict['summary_colnames'])

# extract posterior draws
alpha = fit['alpha']
beta = fit['beta']
sigma = fit['sigma']
lp = fit['lp__']  # log posterior; should look converged across chains

# plot posterior regression lines
x_line = np.linspace(-.5, 10.5, 100)  # with margins
#np.random.shuffle(alpha)
#np.random.shuffle(beta)
for i in range(100):  # plot 100 posterior draws as faint lines
    plt.plot(x_line, alpha[i] + beta[i] * x_line, alpha=0.05, color='blue')
# plot the posterior-mean regression line
alpha_mean, beta_mean = summary_df['mean']['alpha'], summary_df['mean']['beta']
plt.plot(x_line, alpha_mean + beta_mean * x_line, color='red')
plt.title("Posterior regression lines")
plt.show()
# see cell 3 for posterior and trace visualizations for each parameter

#%% cell 2. extra: sample code to serialize a Stan model to a file using pickle
import pickle

filename = '._sample_stan_model.pkl'

# to save
with open(filename, 'wb') as f:
    stan_model = pystan.StanModel(model_code=model)
    pickle.dump(stan_model, f)

# to load
# note: open() raises FileNotFoundError if the file does not exist
with open(filename, 'rb') as f:
    stan_model = pickle.load(f)
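#%% cell 2b. extra: compile-or-load helper
# A minimal sketch wrapping the pickle-cache logic from cell 1 into a reusable
# function. The function name `get_stan_model` and the cache-directory handling
# are my own additions, not part of the original script; it assumes the
# PyStan 2.x API (pystan.StanModel) used above.
def get_stan_model(model_code, cache_file):
    """Compile a Stan model, caching the compiled object to `cache_file` with pickle."""
    if os.path.exists(cache_file):
        # reuse the previously compiled model
        with open(cache_file, 'rb') as f:
            return pickle.load(f)
    # compile (slow) and cache for the next run
    compiled = pystan.StanModel(model_code=model_code)
    os.makedirs(os.path.dirname(cache_file) or '.', exist_ok=True)
    with open(cache_file, 'wb') as f:
        pickle.dump(compiled, f)
    return compiled

# example usage (would replace the if/else block in cell 1):
# sm = get_stan_model(model, './tmp/ln_stan_model.pkl')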
#%% cell 3. plotting

# * trace and parameter visualization (alpha)
# todo: change to a generic function that plots everything for a parameter
#       (see the sketch in the last cell below)
plt.plot(alpha)
plt.axhline(np.mean(alpha), linestyle='--', lw=1, color='lightblue', label='mean')
# CI lines
low_conf, high_conf = np.percentile(alpha, 5), np.percentile(alpha, 95)
plt.axhline(low_conf, linestyle='--', lw=1, color='darkblue')
plt.axhline(high_conf, linestyle='--', lw=1, color='darkblue', label='90% CI')
plt.title("Diagnostic trace for alpha")
plt.ylabel("alpha")
plt.xlabel("sample")
plt.legend()
plt.show()

# * density plot
plt.hist(alpha, bins=30, density=True)
sns.kdeplot(alpha)  # density as a line
# CI lines
plt.axvline(low_conf, linestyle='--', lw=1, color='darkblue')
plt.axvline(high_conf, linestyle='--', lw=1, color='darkblue', label='90% CI')
plt.xlabel('alpha')
plt.ylabel('density')
plt.legend()
plt.title("Posterior distribution of alpha")
plt.show()

#%%
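# A minimal sketch of the generic function the todo in cell 3 asks for.
# The name `plot_parameter` and the (5, 95) percentile default are my own
# choices, not part of the original script; it only reuses the matplotlib
# and seaborn calls already used in cell 3.
def plot_parameter(samples, name, ci=(5, 95)):
    """Plot the trace and posterior density of one parameter with CI lines."""
    low, high = np.percentile(samples, ci[0]), np.percentile(samples, ci[1])
    ci_label = f"{ci[1] - ci[0]}% CI"

    # trace plot with mean and CI lines
    plt.plot(samples)
    plt.axhline(np.mean(samples), linestyle='--', lw=1, color='lightblue', label='mean')
    plt.axhline(low, linestyle='--', lw=1, color='darkblue')
    plt.axhline(high, linestyle='--', lw=1, color='darkblue', label=ci_label)
    plt.title(f"Diagnostic trace for {name}")
    plt.xlabel("sample")
    plt.ylabel(name)
    plt.legend()
    plt.show()

    # density plot with CI lines
    plt.hist(samples, bins=30, density=True)
    sns.kdeplot(samples)  # density as a line
    plt.axvline(low, linestyle='--', lw=1, color='darkblue')
    plt.axvline(high, linestyle='--', lw=1, color='darkblue', label=ci_label)
    plt.title(f"Posterior distribution of {name}")
    plt.xlabel(name)
    plt.ylabel("density")
    plt.legend()
    plt.show()

# example usage:
# for name, samples in [('alpha', alpha), ('beta', beta), ('sigma', sigma)]:
#     plot_parameter(samples, name)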