# notebooks/python/spike_train_visualization.py
#%% [markdown]
# This snippet plots several spike trains in a single raster plot.

# Install the following packages before running:
# - streamlit for the demo app
# - matplotlib and seaborn for visualization
# - numpy for simple random generator and array handling
#%%
import streamlit as st
import numpy as np
import matplotlib.pyplot as plt
import seaborn as sns

# Sidebar progress indicator shared by `app` and `spikeplot`; it is created at
# import time so both functions can advance it while the page is being built.
bar = st.sidebar.progress(value=0)

def app():
    """Build the Streamlit page: sidebar controls, description and raster plot.

    Reads the channel count from a sidebar slider, draws the raster with
    `spikeplot`, renders that figure explicitly with `st.pyplot`, closes it,
    and finally completes and hides the module-level sidebar progress bar.
    """
    # theme and colors
    sns.set_style("white")
    st.title("Spike Train Visualizer")
    st.sidebar.title("Parameters")
    num_of_channels = st.sidebar.slider("Number of channels?", 1, 20, 5)

    # `spikeplot` returns the pyplot module; the raster is on its current figure.
    fig = spikeplot(num_of_channels).gcf()

    st.markdown("This sample code uses `matplotlib.pyplot.eventplot` to visualize a set of spike trains in a single plot.")

    # Pass the figure explicitly: st.pyplot() without a figure relies on
    # pyplot's global state, a form Streamlit has deprecated. plt.show() is not
    # called because Streamlit renders the figure itself.
    st.pyplot(fig)
    # Close the figure once rendered so repeated reruns do not keep figures open.
    plt.close(fig)

    # demo app: mark the work as finished and remove the progress bar
    bar.progress(100)
    bar.empty()

def spikeplot(num_of_channels=5, num_of_spikes=50, duration=100.0):
    """Draw a raster plot of random spike trains on a clean pyplot figure.

    Channel ``i`` (0-based) receives ``num_of_spikes`` timestamps drawn
    uniformly from ``[0, duration)``. It is drawn as a row of vertical ticks at
    ``y = i`` and labelled ``Channel i+1`` on the y axis.

    Args:
        num_of_channels: Number of spike trains (raster rows) to draw.
        num_of_spikes: Number of spikes generated per channel.
        duration: Upper bound of the spike timestamps, in seconds.

    Returns:
        The ``matplotlib.pyplot`` module; the raster is on its current figure
        (for example, ``spikeplot(3).gcf()`` returns that figure).

    Side effects:
        Clears the current pyplot figure and advances the module-level sidebar
        progress bar ``bar`` to 10 and then 50.
    """
    # Start from a clean figure so repeated calls (re-running the cell, or a
    # Streamlit rerun) never overlay a new raster on top of an old one.
    plt.clf()

    # `spike_times[i]` holds the timestamp of each event in channel i's train.
    spike_times = np.random.random([num_of_channels, num_of_spikes]) * duration

    bar.progress(10)

    # Pin channel i to y = i so the tick labels set below always line up.
    # With eventplot's default scalar `lineoffsets`, a single train is drawn at
    # y = 1 rather than 0, which mislabelled the one-channel case.
    # The documented plural keywords (`colors`, `linelengths`) are used; both
    # also accept one value per channel.
    plt.eventplot(spike_times,
                  lineoffsets=np.arange(num_of_channels),
                  colors="blue",
                  linelengths=0.9)

    bar.progress(50)

    # title and axis
    plt.title("Sample spike train plot")
    plt.xlabel("time")
    plt.yticks(range(num_of_channels),
               [f'Channel {i + 1}' for i in range(num_of_channels)])

    return plt

# Entry point. `streamlit run <this file>` executes the script with
# __name__ set to "__main__", so this guard also builds the page there.
if __name__ == '__main__':
    app()