import csv
import heapq
import random


class SequenceGenerator:
    """Generate N-Back sequences with random sampling, increasing the
    probability of an N-back match (target) linearly over the sequence.

    The probability of a target starts at ``target_probability_start`` on the
    first trial and reaches ``target_probability_end`` on the last trial.
    """

    def __init__(
        self,
        choices: list,
        n=2,
        trials=24,
        target_probability_start=0.33,
        target_probability_end=0.50,
    ):
        """
        Args:
            choices: Items to sample from. Each trial yields one whole item.
            n: How many positions back a match must be to count as a target.
            trials: Length of the generated sequence.
            target_probability_start: Target probability on the first trial.
            target_probability_end: Target probability on the last trial.
        """
        self.n, self.choices, self.trials = n, choices, trials
        self.target_probability_start = target_probability_start
        self.target_probability_end = target_probability_end
        # Spread the increase over the trials. Applying the full
        # (end - start) range on every step would overshoot the end value
        # after a single trial.
        self.target_probability_step = (
            (target_probability_end - target_probability_start) / (trials - 1)
            if trials > 1
            else 0.0
        )
        self.target_probability = target_probability_start
        self.seq = []

    def generate(self) -> list:
        """Build and return a fresh sequence of ``self.trials`` items.

        The probability schedule is reset on each call, so repeated calls
        on the same generator are statistically equivalent.
        """
        self.seq = []
        self.target_probability = self.target_probability_start
        for _ in range(self.trials):
            # append (not +=) so multi-character string choices stay whole
            self.seq.append(self.random_sample())
        return self.seq

    def random_sample(self):
        """Draw the next item for ``self.seq`` and advance the probability.

        Returns ``self.seq[-n]`` (a target) with the current target
        probability. Otherwise returns a random choice that differs from
        ``self.seq[-n]``. While fewer than ``n`` items exist, a target is
        impossible, so any choice may be returned.
        """
        # random.random() is in [0, 1), so "<" yields a target with
        # probability exactly target_probability.
        is_target = random.random() < self.target_probability
        self.target_probability += self.target_probability_step
        if len(self.seq) < self.n:
            return random.choice(self.choices)
        n_back = self.seq[-self.n]
        if is_target:
            return n_back
        non_targets = [item for item in self.choices if item != n_back]
        # If every choice equals the n-back item (e.g. a single-element
        # choices list), a non-target is impossible; fall back to the target
        # instead of letting random.choice raise on an empty list.
        return random.choice(non_targets) if non_targets else n_back

    def count_targets_and_lures(self):
        """Count targets and lures in ``self.seq``.

        A target is an item equal to the item ``n`` positions back. A lure is
        a non-target equal to the item ``n-1`` or ``n+1`` positions back.
        Each position counts at most once.

        Returns:
            A ``(targets, lures)`` tuple of ints.
        """
        n = self.n
        seq = self.seq
        targets = 0
        lures = 0
        for index in range(n, len(seq)):
            item = seq[index]
            if item == seq[index - n]:
                targets += 1
                continue
            for offset in (n - 1, n + 1):
                # Skip offset 0 (the item itself, when n == 1) and offsets
                # that would reach before the start of the sequence. A
                # negative index would silently wrap to the end of the list.
                if 0 < offset <= index and item == seq[index - offset]:
                    lures += 1
                    break
        return targets, lures


def __test_generate_stat_csv(filename):
    """Write per-letter frequency statistics for 1000 generated sequences.

    Each row holds the sequence index, the count of each letter, and a flag
    telling whether the most frequent half of the letters covers more than
    two thirds of the trials.
    """
    alphabetic_choices = ['A', 'B', 'C', 'D', 'E', 'F', 'G', 'H']
    trials = 24
    n = 2
    # newline='' is required by the csv module to avoid blank lines / doubled
    # line terminators on some platforms.
    with open(filename, mode='w', newline='') as stat_dist_file:
        writer = csv.writer(
            stat_dist_file, delimiter=',', quotechar='"', quoting=csv.QUOTE_MINIMAL
        )
        writer.writerow(['index'] + alphabetic_choices + ['ralph_skewed'])
        for i in range(1000):
            generator = SequenceGenerator(alphabetic_choices, n=n, trials=trials)
            seq = generator.generate()
            dist = [float(seq.count(c)) for c in alphabetic_choices]
            top_half = heapq.nlargest(len(alphabetic_choices) // 2, dist)
            ralph_skewed = sum(top_half) > (trials * 2 / 3)
            writer.writerow([str(i)] + dist + [str(ralph_skewed)])


if __name__ == '__main__':
    __test_generate_stat_csv('../benchmarks/nb_gm_001_2back_24trials.csv')