import logging
import random
import csv
import heapq
class SequenceGenerator:
    """nb_gm_002 generator

    Generates a skewed random sequence of stimuli for the n-back task, based on Ralph (2014).
    Each sequence contains a specific number of matched trials (targets) and lures.
    """

    def __init__(
        self,
        choices: list,
        n=2,
        trials=24,  # Number of total trials
        targets=8,  # Number of targets
        lures1=2,  # Number of lures (foil) similar to the (N+1)-back
        lures2=2  # Number of lures (foil) similar to the (N-1)-back
    ):
        """Store the stimulus pool and the requested sequence composition.

        :param choices: pool of stimuli to draw from; each generated item is one element of it.
        :param n: the N of the n-back task.
        :param trials: total length of a generated sequence.
        :param targets: requested number of N-back repetitions.
        :param lures1: requested number of (N+1)-back repetitions.
        :param lures2: requested number of (N-1)-back repetitions (only placed when n > 1).
        """
        self.n = n
        self.choices = choices
        self.trials = trials
        self.targets = targets
        self.lures1 = lures1
        self.lures2 = lures2
        # Every trial that is not a target or a lure is a distractor.
        self.distractors = trials - targets - lures1 - lures2
        self.seq = []

    def generate(self) -> list:
        """Build and return a fresh sequence of ``self.trials`` stimuli.

        ``random_stimulus`` consumes the quota attributes (targets, lures1, lures2,
        distractors) while the sequence is built. They are restored afterwards so
        that calling ``generate()`` again yields a sequence with the same requested
        composition, instead of one with no targets or lures left.
        """
        quota = (self.targets, self.lures1, self.lures2, self.distractors)
        self.seq = []
        try:
            for trial in range(1, self.trials + 1):
                # append, not +=: `+=` would splice the characters of a string stimulus.
                self.seq.append(self.random_stimulus(trial))
        finally:
            self.targets, self.lures1, self.lures2, self.distractors = quota
        return self.seq

    def random_stimulus(self, trial):
        """Pick the stimulus for 1-based position ``trial`` and update the remaining quotas.

        A number in [1, remaining trials] is drawn. It selects a target, an
        (N+1)-back lure or an (N-1)-back lure in proportion to the remaining
        quota of each. If the required look-back position does not exist yet
        (or n == 1 for the (N-1)-back lure), a distractor is produced instead.

        :param trial: 1-based index of the trial being generated.
        :return: one element of ``self.choices``.
        :raises IndexError: if every choice is excluded for a distractor
            (e.g. only two distinct choices are available).
        """
        n, seq = self.n, self.seq
        rnd = random.randint(1, self.trials - trial + 1)
        targets, lures1, lures2 = self.targets, self.lures1, self.lures2
        if rnd <= targets and len(seq) >= n:
            self.targets -= 1
            return seq[-n]
        if targets < rnd <= targets + lures1 and len(seq) >= n + 1:
            self.lures1 -= 1
            return seq[-(n + 1)]
        # An (N-1)-back lure needs n > 1: for n == 1 it would be a 0-back "repeat"
        # and seq[-0] silently indexes the first element.
        if (targets + lures1 < rnd <= targets + lures1 + lures2
                and n > 1 and len(seq) >= n - 1):
            self.lures2 -= 1
            return seq[-(n - 1)]
        # distract: avoid the items at N-back and (N+1)-back so a distractor does not
        # form an accidental target or (N+1)-back lure. (N-1)-back repeats are not excluded.
        self.distractors -= 1
        excluded = [seq[-k] for k in (n, n + 1) if k <= len(seq)]
        choices = [item for item in self.choices if item not in excluded]
        return random.choice(choices)

    def count_targets_and_lures(self):
        """Count targets and lures actually present in ``self.seq``.

        A position is a target if it equals the item N positions earlier.
        Otherwise it is a lure if it equals the item N-1 positions earlier
        (only when n > 1) or N+1 positions earlier. Look-backs that would
        fall before the start of the sequence are skipped rather than
        wrapping around to the end.

        :return: ``(targets, lures)`` as floats.
        """
        n = self.n
        seq = self.seq
        targets = 0.0
        lures = 0.0
        for index, item in enumerate(seq):
            if index >= n and item == seq[index - n]:
                targets += 1.0
            elif (n > 1 and index >= n - 1 and item == seq[index - (n - 1)]) or \
                    (index >= n + 1 and item == seq[index - (n + 1)]):
                lures += 1.0
        return targets, lures
def __test_generate_stat_csv(filename):
    """Write letter-frequency statistics for 1000 generated 2-back sequences to a CSV file.

    Each data row holds the run index, the number of occurrences of each letter
    A-H in one 24-trial sequence, and ``ralph_skewed``. That flag is True when
    the most frequent half of the letters accounts for more than two thirds
    of the trials.

    :param filename: path of the CSV file to (over)write.
    """
    alphabetic_choices = ['A', 'B', 'C', 'D', 'E', 'F', 'G', 'H']
    trials = 24
    n = 2
    half = len(alphabetic_choices) // 2
    # newline='' lets the csv writer control line endings; without it, rows get
    # an extra blank line on platforms that translate '\n' (e.g. Windows).
    with open(filename, mode='w', newline='') as stat_dist_file:
        writer = csv.writer(stat_dist_file, delimiter=',', quotechar='"', quoting=csv.QUOTE_MINIMAL)
        writer.writerow(['index'] + alphabetic_choices + ['ralph_skewed'])
        for i in range(1000):
            generator = SequenceGenerator(alphabetic_choices, n=n, trials=trials)
            seq = generator.generate()
            dist = [float(seq.count(c)) for c in alphabetic_choices]
            ralph_skewed = sum(heapq.nlargest(half, dist)) > (trials * 2 / 3)
            writer.writerow([str(i)] + dist + [str(ralph_skewed)])
if __name__ == "__main__":
    # Regenerate the benchmark statistics for the 2-back, 24-trial configuration.
    output_path = '../benchmarks/nb_gm_002_2back_24trials.csv'
    __test_generate_stat_csv(output_path)