# adaptive-nback / generators / skewed_random.py
# Morteza Ansarinia, 20 Feb 2019: generate skewed data for stat analysis.
import logging
import random
import csv

class SequenceGenerator:
    """Generates random sequence of stimuli for the n-back task. Implementation is based on Ralph (2014).

    On every trial a random number in ``[1, remaining_trials]`` decides the kind of
    stimulus, so a target (or lure) is placed with probability equal to the number of
    targets (or lures) still to be placed divided by the number of trials remaining.
    If the chosen kind cannot be placed yet (not enough history), a distractor is
    drawn instead.

    Args:
        choices: pool of stimuli to draw from.
        n: the "n" of the n-back task.
        trials: number of total trials.
        targets: number of targets (a repeat of the n-back item).
        lures1: number of lures (foil) similar to the (N+1)-back.
        lures2: number of lures (foil) similar to the (N-1)-back; never placed when n < 2.
    """

    def __init__(
        self,
        choices: list,
        n=2,
        trials=64,                              # Number of total trials
        targets=22,                             # Number of targets
        lures1=1,                               # Number of lures (foil) similar to the (N+1)-back
        lures2=1                                # Number of lures (foil) similar to the (N-1)-back
    ):
        self.n, self.choices, self.trials, self.targets, self.lures1, self.lures2 = n, choices, trials, targets, lures1, lures2
        self.distractors = trials - targets - lures1 - lures2
        self.seq = []

    def generate(self) -> list:
        """Build and return a new sequence of ``self.trials`` stimuli (also stored in ``self.seq``).

        The target/lure/distractor counters are consumed while generating and restored
        afterwards, so repeated calls on the same instance sample from the same
        configured distribution.
        """
        counts = (self.targets, self.lures1, self.lures2, self.distractors)
        self.seq = []
        try:
            for trial in range(1, self.trials + 1):
                # append (not +=): `list += str` would split multi-character stimuli into letters.
                self.seq.append(self.random_stimulus(trial))
        finally:
            # random_stimulus() decrements these; restore them so the next call is not
            # drawn from a depleted distribution.
            self.targets, self.lures1, self.lures2, self.distractors = counts
        return self.seq

    def random_stimulus(self, trial):
        """Return the stimulus for the 1-based ``trial`` given the history in ``self.seq``.

        Decrements the counter of whichever kind (target, lure1, lure2, distractor)
        was produced.
        """
        rnd = random.randint(1, self.trials - trial + 1)
        targets, lures1, lures2 = self.targets, self.lures1, self.lures2
        history = len(self.seq)
        if rnd <= targets and history >= self.n:
            self.targets -= 1
            return self.seq[-self.n]
        elif targets < rnd <= targets + lures1 and history >= self.n + 1:
            self.lures1 -= 1
            return self.seq[-(self.n + 1)]
        elif (targets + lures1 < rnd <= targets + lures1 + lures2
              and self.n > 1 and history >= self.n - 1):
            # n > 1 guard: for n == 1 the "(N-1)-back" index would be seq[-0], i.e. seq[0].
            self.lures2 -= 1
            return self.seq[-(self.n - 1)]

        # distract: avoid accidentally reproducing the n-back or (n+1)-back item.
        # (Index-based instead of a slice: the slice end -n+1 becomes 0 for n == 1.)
        self.distractors -= 1
        excluded = [self.seq[-k] for k in (self.n, self.n + 1) if 0 < k <= history]
        candidates = [item for item in self.choices if item not in excluded]
        # With very few choices every option may be excluded; fall back to the full
        # pool instead of calling random.choice on an empty list.
        return random.choice(candidates or self.choices)


if __name__ == '__main__':
    # Generate 1000 sequences from one generator and write, per sequence, how many
    # times each choice occurs (one CSV row per sequence).
    from pathlib import Path

    choices = ['A', 'B', 'C', 'D', 'E', 'F', 'G', 'H']
    generator = SequenceGenerator(choices, n=2, trials=240)
    # Resolve relative to this script (generators/) instead of the current working directory.
    out_path = Path(__file__).resolve().parent.parent / 'demo' / 'data' / 'skewed_random_statistical_distributions.csv'
    # newline='' is required by the csv module so rows are not separated by blank lines on Windows.
    with open(out_path, mode='w', newline='') as stat_dist_file:
        writer = csv.writer(stat_dist_file, delimiter=',', quotechar='"', quoting=csv.QUOTE_MINIMAL)
        writer.writerow(['index'] + choices)
        for i in range(1000):
            seq = generator.generate()
            dist = [str(seq.count(c)) for c in choices]
            writer.writerow([i] + dist)