# Newer
# Older
# adaptive-nback / generators / nb_gm_001.py
import random
import csv

import heapq


class SequenceGenerator:
    """Generate N-Back sequences by random sampling with a ramping target rate.

    Each trial is either a *target* (it repeats the item shown ``n`` trials
    earlier) or a non-target (any other choice).  The probability of drawing
    a target moves linearly from ``target_probability_start`` on the first
    trial to ``target_probability_end`` on the last trial.

    Args:
        choices: Items to sample from.  Items may be of any type that
            supports ``==`` comparison.
        n: The N-back distance (expected to be >= 1).
        trials: Number of items produced by :meth:`generate`.
        target_probability_start: Target probability used for the first trial.
        target_probability_end: Target probability used for the last trial.
    """

    def __init__(
        self,
        choices: list,
        n=2,
        trials=24,
        target_probability_start=0.33,
        target_probability_end=0.50
    ):
        self.n, self.choices, self.trials = n, choices, trials
        self._target_probability_start = target_probability_start
        self._target_probability_end = target_probability_end
        # The ramp is spread over the (trials - 1) gaps between trials so the
        # first trial uses `start` and the last one uses `end`.  max() guards
        # against a zero/negative divisor when trials <= 1.
        self.target_probability_step = (
            (target_probability_end - target_probability_start) / max(trials - 1, 1)
        )
        self.target_probability = target_probability_start
        self.seq = []

    def generate(self) -> list:
        """Build and return a fresh sequence of ``self.trials`` items.

        The sequence is also stored in ``self.seq``.  The probability ramp is
        restarted on every call, so repeated calls are independent.
        """
        self.seq = []
        self.target_probability = self._target_probability_start
        for _ in range(self.trials):
            # append (not +=) so non-iterable or multi-character items are
            # stored as a single element.
            self.seq.append(self.random_sample())
        return self.seq

    def random_sample(self):
        """Draw the next item for ``self.seq`` and advance the target probability.

        The item is not appended to ``self.seq``; the caller does that.
        """
        probability = self.target_probability
        # Advance the ramp, clamped so extra calls never leave [start, end].
        low, high = sorted((self._target_probability_start, self._target_probability_end))
        self.target_probability = min(
            max(probability + self.target_probability_step, low), high
        )

        # No n-back item exists yet: any choice is acceptable.
        if len(self.seq) < self.n:
            return random.choice(list(self.choices))

        n_back_item = self.seq[-self.n]
        # random.random() is in [0.0, 1.0), so this is true with the intended
        # probability (never for 0.0, always for 1.0).
        if random.random() < probability:
            return n_back_item

        non_targets = [item for item in self.choices if item != n_back_item]
        # With a single distinct choice there is no non-target to pick; fall
        # back to the full list instead of failing on an empty sequence.
        return random.choice(non_targets or list(self.choices))

    def count_targets_and_lures(self):
        """Count targets and lures in ``self.seq``.

        A target is an item equal to the one ``n`` positions earlier.  A lure
        is a non-target equal to the item ``n - 1`` or ``n + 1`` positions
        earlier.  Only positions from ``n`` onwards are inspected.

        Returns:
            A ``(targets, lures)`` tuple of ints.
        """
        n = self.n
        seq = self.seq
        # An offset of 0 (n == 1) would compare an item with itself.
        lure_offsets = [offset for offset in (n - 1, n + 1) if offset >= 1]
        targets = 0
        lures = 0
        for index in range(n, len(seq)):
            item = seq[index]
            if item == seq[index - n]:
                targets += 1
            elif any(
                # index >= offset keeps the lookup from wrapping around to
                # the end of the list via a negative index.
                index >= offset and item == seq[index - offset]
                for offset in lure_offsets
            ):
                lures += 1
        return targets, lures


def __test_generate_stat_csv(filename):
    """Write per-item frequency statistics for 1000 generated sequences to a CSV file.

    Each row holds the run index, how often every choice occurred in one
    generated 2-back sequence of 24 trials, and a ``ralph_skewed`` flag.  The
    flag is true when the most frequent half of the choices together account
    for more than two thirds of the trials.

    Args:
        filename: Path of the CSV file to create (overwritten if it exists).
    """
    alphabetic_choices = ['A', 'B', 'C', 'D', 'E', 'F', 'G', 'H']
    trials = 24
    n = 2
    top_count = len(alphabetic_choices) // 2
    # newline='' hands line-ending control to the csv module; without it,
    # platforms that translate '\n' emit an extra blank line after each row.
    with open(filename, mode='w', newline='') as stat_dist_file:
        writer = csv.writer(stat_dist_file, delimiter=',', quotechar='"', quoting=csv.QUOTE_MINIMAL)
        writer.writerow(['index'] + alphabetic_choices + ['ralph_skewed'])
        for i in range(1000):
            generator = SequenceGenerator(alphabetic_choices, n=n, trials=trials)
            seq = generator.generate()
            # Counts are kept as floats so the cell format ("3.0") is unchanged.
            dist = [float(seq.count(c)) for c in alphabetic_choices]
            ralph_skewed = sum(heapq.nlargest(top_count, dist)) > (trials*2/3)
            writer.writerow([str(i)] + dist + [str(ralph_skewed)])


# Script entry point: runs only when this file is executed directly, not on import.
if __name__ == '__main__':
    # The output path is relative, so it is resolved against the current
    # working directory of the process.
    __test_generate_stat_csv('../benchmarks/nb_gm_001_2back_24trials.csv')