"""
Created on Wed Apr 22 15:21:11 2015
Code to compute spike-triggered average.
"""
from __future__ import division
import numpy as np
import matplotlib.pyplot as plt
def compute_sta(stim, rho, num_timesteps):
    """Compute the spike-triggered average from a stimulus and spike-train.

    Args:
        stim: stimulus time-series (assumed 1-D and at least as long as
            ``rho`` -- TODO confirm for callers passing other shapes).
        rho: spike-train time-series; any nonzero sample marks a spike bin.
        num_timesteps: how many timesteps to use in STA (window length).

    Returns:
        Array of shape ``(num_timesteps,)`` holding the element-wise mean of
        the stimulus windows ``stim[t - num_timesteps + 1 : t + 1]`` over all
        spike bins ``t >= num_timesteps``. Returns all zeros when there are
        no such spikes.
    """
    rho = np.asarray(rho)
    stim = np.asarray(stim)
    sta = np.zeros((num_timesteps,))

    # Indices of all spike bins that occur at or after sample num_timesteps.
    # Earlier spikes are skipped because they lack a full window of history.
    # The original comments describe this as "after 300 ms", which presumably
    # means num_timesteps=150 at a 2 ms sampling period.
    spike_times = rho[num_timesteps:].nonzero()[0] + num_timesteps

    # Count exactly the spikes that contribute a window. Summing rho would
    # count spike magnitudes rather than bins and yield a float.
    num_spikes = len(spike_times)
    print('num of spikes for STA:', num_spikes)

    if num_spikes == 0:
        # No qualifying spikes: avoid 0/0 and return the zero vector.
        return sta

    # Each window starts num_timesteps samples before the spike (exclusive)
    # and ends at the spike sample itself (inclusive), giving exactly
    # num_timesteps samples. Windows are averaged element-wise.
    for t in spike_times:
        sta += stim[t - num_timesteps + 1:t + 1]
    sta /= num_spikes
    return sta