import logging
import random
import csv
import heapq


class SequenceGenerator:
    """Generate a random sequence of stimuli for the n-back task.

    Implementation is based on Ralph (2014). On every trial a number in
    ``[1, remaining trials]`` is drawn and compared against the cumulative
    remaining counts of targets, (N+1)-back lures and (N-1)-back lures; any
    value beyond those (or a kind that cannot be placed yet because the
    sequence is still too short) produces a distractor.
    """

    def __init__(
        self,
        choices: list,
        n=2,
        trials=24,  # Number of total trials
        targets=8,  # Number of targets
        lures1=2,  # Number of lures (foil) similar to the (N+1)-back
        lures2=2,  # Number of lures (foil) similar to the (N-1)-back
    ):
        self.n = n
        self.choices = choices
        self.trials = trials
        self.targets = targets
        self.lures1 = lures1
        self.lures2 = lures2
        # Every trial that is neither a target nor a lure is a distractor.
        self.distractors = trials - targets - lures1 - lures2
        self.seq = []

    def generate(self) -> list:
        """Build and return a fresh sequence of ``self.trials`` stimuli.

        The remaining-count attributes (``targets``, ``lures1``, ``lures2``,
        ``distractors``) are consumed while generating and restored
        afterwards, so one generator can produce several independent
        sequences.
        """
        saved_counts = (self.targets, self.lures1, self.lures2, self.distractors)
        self.seq = []
        try:
            for trial in range(1, self.trials + 1):
                # append, not +=: `list += str` extends with the characters
                # of the stimulus and would split multi-character items.
                self.seq.append(self.random_stimulus(trial))
        finally:
            (self.targets, self.lures1,
             self.lures2, self.distractors) = saved_counts
        return self.seq

    def random_stimulus(self, trial):
        """Pick the stimulus for the 1-based ``trial`` given ``self.seq``.

        Decrements the remaining count of whichever kind (target, lure or
        distractor) was chosen and returns the chosen item.
        """
        n, seq = self.n, self.seq
        rnd = random.randint(1, self.trials - trial + 1)
        targets, lures1, lures2 = self.targets, self.lures1, self.lures2

        if rnd <= targets and len(seq) >= n:
            self.targets -= 1
            return seq[-n]
        if targets < rnd <= targets + lures1 and len(seq) >= n + 1:
            self.lures1 -= 1
            return seq[-(n + 1)]
        # An (N-1)-back lure only exists for n >= 2; for n == 1 it would be
        # seq[-0], i.e. the *first* item, and crash on an empty sequence.
        if (n >= 2
                and targets + lures1 < rnd <= targets + lures1 + lures2
                and len(seq) >= n - 1):
            self.lures2 -= 1
            return seq[-(n - 1)]

        # Distractor: avoid the items n-1, n and n+1 back so that it cannot
        # accidentally turn into a target or a lure.
        self.distractors -= 1
        recent = [seq[-k] for k in (n - 1, n, n + 1) if 1 <= k <= len(seq)]
        choices = [item for item in self.choices if item not in recent]
        # With very few choices everything may be excluded; fall back to the
        # full set rather than crashing in random.choice([]).
        return random.choice(choices or self.choices)

    def count_targets_and_lures(self):
        """Count the targets and lures actually present in ``self.seq``.

        Returns a ``(targets, lures)`` pair of floats. A position is a target
        when it repeats the item n back; otherwise it is a lure when it
        repeats the item n-1 back (only meaningful for n >= 2) or n+1 back.
        Only positions that really have such a predecessor are compared, so
        negative indices never wrap around to the end of the sequence.
        """
        n, seq = self.n, self.seq
        targets = 0.0
        lures = 0.0
        for index, item in enumerate(seq):
            if index >= n and item == seq[index - n]:
                targets += 1.0
            elif (n >= 2 and index >= n - 1 and item == seq[index - (n - 1)]) or (
                index >= n + 1 and item == seq[index - (n + 1)]
            ):
                lures += 1.0
        return targets, lures


def __test_generate_stat_csv(filename, runs=1000, trials=240, n=2):
    """Write per-run stimulus frequency statistics to the CSV ``filename``.

    For each of ``runs`` generated sequences, one row is written with the run
    index, the count of each letter, and a ``ralph_skewed`` flag. The flag is
    true when the most frequent half of the letters together occur in more
    than two thirds of the trials.
    """
    alphabetic_choices = ['A', 'B', 'C', 'D', 'E', 'F', 'G', 'H']
    half = len(alphabetic_choices) // 2
    # newline='' is required by the csv module to avoid blank lines on
    # platforms that translate line endings.
    with open(filename, mode='w', newline='') as stat_dist_file:
        writer = csv.writer(stat_dist_file, delimiter=',', quotechar='"',
                            quoting=csv.QUOTE_MINIMAL)
        writer.writerow(['index'] + alphabetic_choices + ['ralph_skewed'])
        for i in range(runs):
            generator = SequenceGenerator(alphabetic_choices, n=n, trials=trials)
            seq = generator.generate()
            dist = [float(seq.count(c)) for c in alphabetic_choices]
            ralph_skewed = sum(heapq.nlargest(half, dist)) > (trials * 2 / 3)
            writer.writerow([str(i)] + dist + [str(ralph_skewed)])


if __name__ == '__main__':
    __test_generate_stat_csv('../stat/skewed_random_statistical_distributions_240trials_1000runs.csv')