# adaptive-nback / generators / nb_gm_003.py
import random
import scipy.stats


class SequenceGenerator:
    """nb_gm_003: greedy, chunk-wise n-back sequence generator.

    Strategy:
        - pseudo-random sampling.
        - specific number of matching trials.
        - even distribution of stimuli.

    The sequence is built ``n + 1`` trials at a time. At every step each
    ordered chunk of distinct stimuli is scored with :meth:`cost` against the
    sequence built so far, and the cheapest chunk is appended. Candidates are
    shuffled first, so ties are broken randomly.
    """

    def __init__(
        self,
        choices,
        n,
        targets_ratio=0.33            # ratio of matched trials (targets) in all trials
    ):
        """
        :param choices: stimuli to sample from; copied so the caller's
            sequence is never mutated by this generator.
        :param n: the n-back distance.
        :param targets_ratio: desired ratio of targets (matched trials) in
            all trials.
        """
        self.trials, self.choices, self.n, self.targets_ratio = None, list(choices), n, targets_ratio
        self.seq = []

        # normal distributions used by the cost functions; the ones that
        # depend on the number of trials are (re)created in reset()
        self.evendist_norm = None
        self.targets_ratio_norm = scipy.stats.norm(targets_ratio, 0.2)
        self.tlratio_norm = None

    def reset(self, trials):
        """Prepare the generator for a new sequence of ``trials`` trials.

        :param trials: number of trials to generate; must be non-negative.
        :raises ValueError: if ``trials`` is negative.
        """
        if trials < 0:
            raise ValueError(f"trials must be non-negative, got {trials}")
        self.trials = trials
        self.evendist_norm = scipy.stats.norm(0, trials / len(self.choices))
        self.tlratio_norm = scipy.stats.norm(2.0, trials / 2)
        # Rebind rather than clear(): a list returned by a previous
        # generate() call must not be wiped out by the next one.
        self.seq = []

    def generate(self, trials):
        """Generate a sequence of ``trials`` stimuli.

        :param trials: number of trials to generate (non-negative).
        :return: a fresh list of ``trials`` stimuli drawn from ``choices``.
        :raises ValueError: if ``trials`` is negative, or if ``n + 1`` exceeds
            the number of choices (no chunk of distinct stimuli exists).
        """
        self.reset(trials)
        while len(self.seq) < self.trials:
            # full chunks of n + 1 trials; the final chunk is truncated to fit
            chunk_size = min(self.n + 1, self.trials - len(self.seq))
            self.seq += self.best_chunk(chunk_size)
        return self.seq

    def best_chunk(self, chunk_size) -> list:
        """Return the lowest-cost chunk of ``chunk_size`` distinct stimuli.

        :param chunk_size: number of trials in the chunk.
        :return: the chunk (as a list) that minimises :meth:`cost` when
            appended to the current sequence.
        :raises ValueError: if ``chunk_size`` exceeds the number of choices.
        """
        from itertools import permutations
        if chunk_size > len(self.choices):
            raise ValueError(
                f"chunk_size ({chunk_size}) exceeds the number of choices "
                f"({len(self.choices)}); n must be smaller than len(choices)"
            )
        candidates = list(permutations(self.choices, chunk_size))
        random.shuffle(candidates)  # random tie-breaking among equal costs
        # min() keeps the first of equal-cost candidates, like a strict '<' scan
        best = min(candidates, key=lambda chunk: self.cost(self.seq + list(chunk)))
        return list(best)

    def best_choice(self) -> list:
        """Return a one-element list holding the lowest-cost next stimulus.

        Candidates are visited in random order to avoid an ordering effect on
        ties; ``self.choices`` itself is left untouched. Returns ``[None]``
        when there are no choices.
        """
        candidates = random.sample(self.choices, len(self.choices))
        best = min(
            candidates,
            key=lambda choice: self.cost(self.seq + [choice]),
            default=None,
        )
        return [best]

    def cost(self, seq):
        """Total cost of ``seq`` (lower is better): even-distribution cost
        plus targets-ratio cost."""
        return self.even_distribution_cost(seq) + self.targets_ratio_cost(seq)

    def even_distribution_cost(self, seq):
        """Cost of the worst deviation of any stimulus count from an even split.

        The even share is ``trials / len(choices)``, computed for the full
        target length. The cost is ``1 - pdf`` of the largest absolute
        deviation under ``evendist_norm``, so smaller deviations cost less.
        """
        even_share = self.trials / len(self.choices)
        max_deviation = max(abs(seq.count(c) - even_share) for c in self.choices)
        return 1.0 - self.evendist_norm.pdf(max_deviation)

    def targets_ratio_cost(self, seq):
        """Cost of the fraction of targets in ``seq`` relative to ``targets_ratio``.

        Computed as ``1 - pdf`` under ``norm(targets_ratio, 0.2)``. That pdf
        peaks at about 1.99, so this cost can be negative; only its relative
        value matters when comparing candidates. An empty ``seq`` is treated
        as having a targets ratio of 0.
        """
        targets, _ = self.count_targets_and_lures(seq, self.n)
        ratio = targets / len(seq) if seq else 0.0
        return 1.0 - self.targets_ratio_norm.pdf(ratio)

    @classmethod
    def count_targets_and_lures(cls, seq, n):
        """Count targets and lures in ``seq`` for an ``n``-back task.

        A trial at position ``i >= n`` is a target when it equals the trial
        ``n`` positions back. It is a lure when it is not a target but equals
        the trial ``n - 1`` or ``n + 1`` positions back. A lag of zero, or a
        lag reaching before the start of the sequence, is skipped.

        :param seq: sequence of trials.
        :param n: the n-back distance.
        :return: ``(targets, lures)``
        """
        targets = lures = 0
        for index in range(n, len(seq)):
            current = seq[index]
            if current == seq[index - n]:
                targets += 1
            elif any(current == seq[index - lag]
                     for lag in (n - 1, n + 1) if 0 < lag <= index):
                lures += 1
        return targets, lures

    def tlratio_cost(self, seq) -> float:
        """
         Calculates the T:L ratio deviation from the norm in a block of trials.

        When there are no lures, the target count itself is used as the ratio.
        The cost is ``1 - pdf`` under ``norm(2.0, trials / 2)``.

        :param seq: sequence of trials
        :return: float number between 0 and 1 representing the tl_ratio cost
            (for ``trials >= 1``)
        """

        targets, lures = self.count_targets_and_lures(seq, self.n)
        tl = targets / lures if lures != 0 else targets
        return 1.0 - self.tlratio_norm.pdf(tl)


if __name__ == '__main__':
    import time

    # Demo / quick benchmark: generate sequences of increasing length and
    # report their (targets, lures) counts and how long generation took.
    n = 2
    gen = SequenceGenerator(['1', '2', '3', '4', '5', '6'], n)

    for trials in (24, 48, 72):
        # perf_counter() is monotonic and high-resolution, unlike time.time()
        st = time.perf_counter()
        s = gen.generate(trials)
        st = time.perf_counter() - st
        print(f"{gen.count_targets_and_lures(s, n)}")
        print(f"time = {st:0.2f}s")