diff --git a/py/spike_train_visualization.py b/py/spike_train_visualization.py index 09488d9..45729cb 100644 --- a/py/spike_train_visualization.py +++ b/py/spike_train_visualization.py @@ -1,33 +1,59 @@ #%% [markdown] -# This snippet plots a spiking train in a raster plot. -# `spike_times` contains the timestamp for each even in the spike train. +# This snippet plots several spiking trains in a raster plot. - +# Install the following package before running : +# - streamlit for the demo app +# - matplotlib and seaborn for visualization +# - numpy for simple random generator and array handling #%% +import streamlit as st import numpy as np import matplotlib.pyplot as plt import seaborn as sns -# Set the theme -sns.set_style("white") +bar = st.sidebar.progress(0) -# generate random timestamps (0 to 100s) -spike_times = np.random.random([8,50]) * 100 +def app(): + # theme and colors + sns.set_style("white") + st.title("Spike Train Visualizer") + st.sidebar.title("Parameters") + num_of_channels = st.sidebar.slider("Number of channels?", 1, 20, 5) + + plot = spikeplot(num_of_channels) + plot.show() + + txt = st.empty() + txt.markdown("This sample code uses `matplotlib.pyplot.eventplot` to visualize a set of spike trains in a single plot.") + st.pyplot() + # demo app + bar.progress(100) + bar.empty() + +def spikeplot(num_of_channels=5): + + # `spike_times` contains the timestamp for each even in the spike train. + # generate random timestamps (0 to 100s) + spike_times = np.random.random([num_of_channels,50]) * 100 + + bar.progress(10) + + #* creating the plot. It also accepts color and linelength as arrays for colors and lengths. + plt.eventplot(spike_times, + color="blue", + linelength=0.9) + + bar.progress(50) + + # title and axis + plt.title("Sample spike train plot") + plt.xlabel("time") + # plt.ylabel("channels") + #plt.axis('off') + plt.yticks(range(0, num_of_channels),[f'Channel {l+1}' for l in range(num_of_channels)]) -#* creating the plot. It also accepts color and linelength as arrays for colors and lengths. -plt.eventplot(spike_times, - color="blue", - linelength=0.9) + return plt -# title and axis -plt.title("Sample spike train plot") -plt.xlabel("time") -# plt.ylabel("channels") -#plt.axis('off') -plt.yticks(np.arange(0, 8),[f'Channel {l+1}' for l in range(8)]) - - -plt.show() - -#%% +if __name__ == "__main__": + app() \ No newline at end of file