"""
Created on Wed Apr 22 15:21:11 2015

Code to compute spike-triggered average.
"""

from __future__ import division

import numpy as np

try:
    import matplotlib.pyplot as plt
except ImportError:
    # Plotting is optional: compute_sta only needs NumPy, so a missing
    # matplotlib must not prevent importing this module.
    plt = None


def compute_sta(stim, rho, num_timesteps):
    """Compute the spike-triggered average from a stimulus and spike-train.

    For every spike that occurs at index ``t >= num_timesteps`` the window
    ``stim[t - num_timesteps + 1 : t + 1]`` is collected, i.e. the
    ``num_timesteps`` samples that end at the spike itself (inclusive).
    The windows are then averaged element-wise.

    Args:
        stim: stimulus time-series (flattened to 1-D before use)
        rho: spike-train time-series, sampled on the same grid as ``stim``
            (flattened to 1-D before use); every non-zero bin is treated
            as one spike event
        num_timesteps: how many timesteps to use in STA

    Returns:
        1-D array of length ``num_timesteps`` holding the spike-triggered
        average; the last element corresponds to the time of the spike.
        If no spike occurs at or after index ``num_timesteps`` the average
        is undefined and an array filled with NaN is returned.
    """
    stim = np.ravel(stim)
    rho = np.ravel(rho)

    # Indices of all spikes that have at least num_timesteps samples of
    # recording before them; earlier spikes are deliberately not counted.
    spike_times = rho[num_timesteps:].nonzero()[0] + num_timesteps

    # Count the spikes actually used, rather than summing rho, so the
    # divisor always matches the number of accumulated windows.
    num_spikes = len(spike_times)
    print('num of spikes for STA:', num_spikes)

    if num_spikes == 0:
        # 0 / 0 is undefined; return NaN explicitly instead of triggering
        # a divide-by-zero warning.
        return np.full((num_timesteps,), np.nan)

    sta = np.zeros((num_timesteps,))
    for t in spike_times:
        # Window starts num_timesteps before the spike (exclusive) and
        # ends at the spike (inclusive).
        sta += stim[t - num_timesteps + 1:t + 1]

    return sta / num_spikes