diff --git a/py/.gitignore b/py/.gitignore new file mode 100644 index 0000000..a0c3440 --- /dev/null +++ b/py/.gitignore @@ -0,0 +1,3 @@ +.DS_Store +._* +tmp/ \ No newline at end of file diff --git a/py/.gitignore b/py/.gitignore new file mode 100644 index 0000000..a0c3440 --- /dev/null +++ b/py/.gitignore @@ -0,0 +1,3 @@ +.DS_Store +._* +tmp/ \ No newline at end of file diff --git a/py/stan.py b/py/stan.py index 58343ad..c4884ed 100644 --- a/py/stan.py +++ b/py/stan.py @@ -2,7 +2,9 @@ # First install PyStan using conda or pip: # `pip install pystan` or `conda install pystan` -#%% +#%% cell 1. inference using Stan + + # PyStan import pystan @@ -14,6 +16,9 @@ import numpy as np import pandas as pd +# to serialize stan models to file +import pickle +import os # initialize visualizarion and reproducability sns.set() @@ -48,14 +53,26 @@ # plot grand truth model plt.scatter(x,y) +plt.title("Grand truth model") plt.show() # prepare dictionary data to pass to the sampler data = {'N': len(x), 'x': x, 'y': y} # stan model -#TODO serialize/deserialize using pickle (next code block) -sm = pystan.StanModel(model_code = model) +# * to speed things up, it serializes/deserializes model to a file with pickle + +sm_filename = './tmp/ln_stan_model.pkl' + +if not os.path.exists(sm_filename): + # save + with open(sm_filename, 'wb') as f: + sm = pystan.StanModel(model_code = model) + pickle.dump(sm, f) +else: + # load + f = open(sm_filename, 'rb') + sm = pickle.load(f) # train and sample the model @@ -74,37 +91,80 @@ index = summary_dict['summary_rownames'], columns = summary_dict['summary_colnames']) - -alpha_mean, beta_mean = summary_df['mean']['alpha'], summary_df['mean']['beta'] - -x_line = np.linspace(-.5, 10.5, 100) # with margins -plt.plot(x_line, alpha_mean + beta_mean * x_line) - -plt.show() - # extract sampling traces alpha = fit['alpha'] beta = fit['beta'] sigma = fit['sigma'] -lp = fit['lp__'] +lp = fit['lp__'] # log posterior. 
it should have converged -np.random.shuffle(alpha) -np.random.shuffle(beta) -#TODO: plot traces (regression lines) -#TODO: plot posteriors +# plot traces +x_line = np.linspace(-.5, 10.5, 100) # with margins -#%% sample code to serialize a Stan model to a file using pickle +#np.random.shuffle(alpha) +#np.random.shuffle(beta) + +for i in range(len(x_line)): + plt.plot(x_line, alpha[i] + beta[i] * x_line, alpha=0.05, color='blue') + +# plot mean line +alpha_mean, beta_mean = summary_df['mean']['alpha'], summary_df['mean']['beta'] + +plt.plot(x_line, alpha_mean + beta_mean * x_line, color='red') + +plt.title("Stan Diagnostics") +plt.show() +# see 3rd cell for posterior and trace visualizations for each parameter + +#%% cell 2. extra: sample code to serialize a Stan model to a file using pickle import pickle -stan_model = pystan.StanModel(model_code=model) -filename = 'sample_stan_model.pkl' +filename = '._sample_stan_model.pkl' # to save with open(filename, 'wb') as f: + stan_model = pystan.StanModel(model_code=model) pickle.dump(stan_model, f) # to load +# todo: open must throw an error if the file does not exist with open(filename, 'rb') as f: - stan_model = pickle.load(f) \ No newline at end of file + stan_model = pickle.load(f) + + +#%% cell 3. 
plots + +# * trace and parameter visualization (alpha) +# todo: change to a generic function that plots everything for a parameter +plt.plot(alpha) +plt.axhline(np.mean(alpha), linestyle='--', lw=1, color='lightblue') + +# 90% CI lines (5th and 95th percentiles) +low_conf, high_conf = np.percentile(alpha, 5), np.percentile(alpha, 95) +plt.axhline(low_conf, linestyle = '--', lw=1, color='darkblue') +plt.axhline(high_conf, linestyle = '--', lw=1, color='darkblue', label='90% CI') + +plt.title("Diagnostic trace for alpha") +plt.ylabel("alpha") +plt.xlabel("sample") +plt.legend() + +plt.show() + + +# * density plot +plt.hist(alpha, bins=30, density=True) +sns.kdeplot(alpha) # density as a line + +# CI +plt.axvline(low_conf, linestyle='--', lw=1, color='darkblue') +plt.axvline(high_conf, linestyle='--', lw=1, color='darkblue', label='90% CI') + +plt.xlabel('alpha') +plt.ylabel('density') +plt.legend() + +plt.title("Posterior distribution of alpha") +plt.show() +#%%