#%%
import numpy as np
import matplotlib.pyplot as plt
import math
# encode a message using poisson rate coding
def encode_message(symbol: str,
                   signal_freq: float,
                   duration: float,
                   **kwargs):
    """Generate a homogeneous Poisson spike train on ``[0, duration)``.

    Spikes are drawn with exponentially distributed inter-spike intervals at
    the combined rate ``signal_freq + noise_rate``.

    Args:
        symbol: Symbol to encode. Currently unused -- the spike timing does
            not yet depend on it.
        signal_freq: Signal firing rate (spikes per unit time).
        duration: Length of the spike train, in the same time unit as the
            rates.
        **kwargs:
            codebook: Required; any value other than ``None`` is accepted.
            noise_rate: Extra background rate added to ``signal_freq``
                (default 0).
            rng: Object providing ``exponential(scale=..., size=...)``, e.g.
                a ``numpy.random.Generator``. Defaults to the global
                ``numpy.random`` state, so ``np.random.seed`` still applies.

    Returns:
        Sorted 1-D float array of spike times in ``[0, duration)``.

    Raises:
        TypeError: If no codebook is given.
        ValueError: If the combined rate or the duration is negative.
    """
    codebook = kwargs.get("codebook", None)
    noise_freq = kwargs.get("noise_rate", 0)
    rng = kwargs.get("rng", np.random)
    if codebook is None:
        raise TypeError("Invalid code book.")

    freq = signal_freq + noise_freq
    if freq < 0:
        raise ValueError(f"Total spike rate must be non-negative, got {freq}.")
    if duration < 0:
        raise ValueError(f"Duration must be non-negative, got {duration}.")
    if freq == 0 or duration == 0:
        # No spikes are possible; this also avoids dividing by a zero rate.
        return np.empty(0)

    scale = 1 / freq
    # Each batch holds roughly twice the expected spike count, so one batch
    # usually suffices.
    batch = max(1, math.ceil(2 * freq * duration))
    spikes = np.cumsum(rng.exponential(scale=scale, size=batch))
    # Keep extending the train until it passes `duration`. Stopping after a
    # fixed number of draws would silently truncate (and bias) trains whose
    # intervals happen to be short.
    while spikes[-1] < duration:
        more = spikes[-1] + np.cumsum(rng.exponential(scale=scale, size=batch))
        spikes = np.concatenate((spikes, more))
    # TODO: add leak spikes
    return spikes[spikes < duration]
# generate a spike train with signal rate 4 and no noise over 2 time units.
spikes = encode_message("A", 4, duration=2, codebook=['A', 'B'], noise_rate=0)
#%% [markdown]
# The following code plots the signal spike train alongside the noise spikes.
#%%
# generate signal spikes
signal_spikes = encode_message("A", 5, duration=10, codebook=['A', 'B'])
# generate noise spikes (symbol and codebook are placeholders; only the rate
# matters for the spike times)
noise_spikes = encode_message("", 1, duration=10, codebook=[])
# merged train; np.unique already returns its result sorted
spikes = np.unique(np.concatenate((signal_spikes, noise_spikes)))
# draw each train on its own row so signal and noise spikes stay distinguishable
plt.eventplot([spikes, signal_spikes, noise_spikes],
              colors=['black', 'green', 'red'],
              lineoffsets=[0, 1, 2])
plt.yticks([0, 1, 2], ["combined", "signal", "noise"])
plt.xlabel("time")
plt.show()
#%%