Newer
Older
adaptive-nback / generators / skewed_random.py
import logging
import random
import csv

import heapq

class SequenceGenerator:
    """Generates random sequence of stimuli for the n-back task. Implementation is based on Ralph (2014).

    Each trial is drawn as a target (repeat of the n-back stimulus), an
    (n+1)-back lure, an (n-1)-back lure, or a distractor.  The choice is made
    with probability proportional to how many of each kind are still left to
    place.
    """

    def __init__(
        self,
        choices: list,
        n=2,
        trials=24,                              # Number of total trials
        targets=8,                              # Number of targets
        lures1=2,                               # Number of lures (foil) similar to the (N+1)-back
        lures2=2                                # Number of lures (foil) similar to the (N-1)-back
    ):
        self.n, self.choices, self.trials, self.targets, self.lures1, self.lures2 = n, choices, trials, targets, lures1, lures2
        self.distractors = trials - targets - lures1 - lures2
        # random_stimulus() consumes these counters; remember the requested
        # budget so generate() can restore it on every call.
        self._initial_counts = (targets, lures1, lures2, self.distractors)
        self.seq = []

    def generate(self) -> list:
        """Build and return a fresh sequence of ``self.trials`` stimuli.

        The per-kind counters are reset first, so repeated calls each get the
        full target/lure budget.  Otherwise a second call would produce a
        sequence made almost entirely of distractors.
        """
        self.targets, self.lures1, self.lures2, self.distractors = self._initial_counts
        self.seq = []
        for trial in range(1, self.trials + 1):
            # append (not +=): += would splice multi-character string stimuli
            # and raise TypeError for non-iterable stimuli such as ints.
            self.seq.append(self.random_stimulus(trial))
        return self.seq

    def random_stimulus(self, trial):
        """Draw the stimulus for 1-based position ``trial`` and consume one counter.

        A random number in [1, remaining trials] selects the kind:
        - the first ``targets`` values pick a target (the n-back stimulus);
        - the next ``lures1`` values pick an (n+1)-back lure;
        - the next ``lures2`` values pick an (n-1)-back lure.

        If the chosen kind cannot be placed yet (the sequence is too short),
        or the number falls past those ranges, a distractor is returned.
        """
        n, seq = self.n, self.seq
        rnd = random.randint(1, self.trials - trial + 1)
        targets, lures1, lures2 = self.targets, self.lures1, self.lures2
        if rnd <= targets and len(seq) >= n:
            self.targets -= 1
            return seq[-n]
        elif targets < rnd <= targets + lures1 and len(seq) >= n + 1:
            self.lures1 -= 1
            return seq[-(n + 1)]
        # An (n-1)-back lure only exists for n >= 2.  For n == 1 it would be
        # seq[-0] == seq[0], or an IndexError on an empty sequence.
        elif targets + lures1 < rnd <= targets + lures1 + lures2 and n > 1 and len(seq) >= n - 1:
            self.lures2 -= 1
            return seq[-(n - 1)]

        # distract: exclude the (n+1)-, n- and (n-1)-back stimuli so the
        # distractor cannot accidentally become a target or a lure.  The slice
        # end is exclusive, so it stops one past the (n-1)-back index.
        # Bounds are clamped at 0 so short sequences never wrap around.
        self.distractors -= 1
        size = len(seq)
        recent = seq[max(0, size - n - 1):max(0, size - n + 2)]
        choices = [item for item in self.choices if item not in recent]
        # With very few choices every option may be excluded.  Fall back to the
        # full set rather than letting random.choice raise IndexError.
        return random.choice(choices or self.choices)

    def count_targets_and_lures(self):
        """Count targets and lures present in ``self.seq``.

        Position i is a target if it equals the stimulus n back.  Otherwise it
        is a lure if it equals the stimulus n-1 back (only for n >= 2) or n+1
        back.  Comparisons that would reach before the start of the sequence
        are skipped; they must not wrap around via negative indices.

        Returns:
            (targets, lures) as floats.
        """
        n = self.n
        seq = self.seq
        targets = 0.0
        lures = 0.0
        for index, item in enumerate(seq):
            if index >= n and item == seq[index - n]:
                targets += 1.0
            elif n > 1 and index >= n - 1 and item == seq[index - (n - 1)]:
                lures += 1.0
            elif index >= n + 1 and item == seq[index - (n + 1)]:
                lures += 1.0
        return targets, lures


def __test_generate_stat_csv(filename, runs=1000, trials=240, n=2):
    """Write per-run stimulus frequency statistics of SequenceGenerator to a CSV file.

    Each row holds:
    - the run index;
    - how often each choice occurred in that run's generated sequence;
    - ``ralph_skewed``: True when the most frequent half of the choices
      together cover more than 2/3 of all trials.

    :param filename: destination CSV path (overwritten).
    :param runs: number of sequences to generate.
    :param trials: length of each generated sequence.
    :param n: n-back level passed to SequenceGenerator.
    """
    alphabetic_choices = ['A', 'B', 'C', 'D', 'E', 'F', 'G', 'H']
    # newline='' is required by the csv module; without it rows get extra
    # blank lines on platforms that translate '\n' (e.g. Windows).
    with open(filename, mode='w', newline='') as stat_dist_file:
        writer = csv.writer(stat_dist_file, delimiter=',', quotechar='"', quoting=csv.QUOTE_MINIMAL)
        writer.writerow(['index'] + alphabetic_choices + ['ralph_skewed'])
        for i in range(runs):
            generator = SequenceGenerator(alphabetic_choices, n=n, trials=trials)
            seq = generator.generate()
            dist = [float(seq.count(c)) for c in alphabetic_choices]
            ralph_skewed = sum(heapq.nlargest(len(alphabetic_choices) // 2, dist)) > (trials * 2 / 3)
            writer.writerow([str(i)] + dist + [str(ralph_skewed)])


if __name__ == '__main__':
    # Regenerate the distribution statistics; the path is resolved relative
    # to the current working directory.
    stat_csv_path = '../stat/skewed_random_statistical_distributions_240trials_1000runs.csv'
    __test_generate_stat_csv(stat_csv_path)