# Newer
# Older
# adaptive-nback / tests / benchmark_skewness.py
import random
import heapq
import csv

# Common parameters shared by every benchmark in this module.
# Symbols handed to the sequence generators; also used as the per-letter CSV column names.
alphabetic_choices = ['A', 'B', 'C', 'D', 'E', 'F', 'G', 'H']
# (min, max) sequence length; both ends are inclusive because they feed random.randint.
trials_range = (10, 40)
# Passed as the second argument to each SequenceGenerator
# (presumably the n-back level -- TODO confirm against the generators).
n = 3
# Number of sequences generated (and CSV rows written) per benchmark run.
sample_size = 100


def to_benchmark_csv_row(sample_index, seq, choices=None):
    """Build one CSV row describing how evenly a sequence uses its symbols.

    Args:
        sample_index: Identifier of the sequence; written as the first cell.
        seq: The generated sequence; anything supporting ``len`` and
            ``count`` (e.g. a list or a string).
        choices: Symbols to count, in column order. Defaults to the
            module-level ``alphabetic_choices`` when omitted.

    Returns:
        A list ``[str(sample_index), <one float count per choice>, str(skewed)]``
        where ``skewed`` is True when the most frequent half of the choices
        accounts for more than two thirds of all trials.
    """
    if choices is None:
        choices = alphabetic_choices
    trials = len(seq)
    freqs = [float(seq.count(c)) for c in choices]
    # Combined frequency of the most common half of the symbols.
    top_half = heapq.nlargest(len(choices) // 2, freqs)
    ralph_skewed = sum(top_half) > (trials * 2 / 3)
    return [str(sample_index)] + freqs + [str(ralph_skewed)]

def benchmark_nb_gm_001(filename):
    """Generate ``sample_size`` sequences with nb_gm_001 and record their letter counts.

    Each sequence gets a random length drawn from ``trials_range`` (both ends
    inclusive). One CSV row per sequence is written to *filename*, and the
    finished file is then plotted via ``show_skweness_diagram``.

    Args:
        filename: Path of the CSV file to create (overwritten if it exists).
    """
    import generators.nb_gm_001 as nb_gm_001

    generator = nb_gm_001.SequenceGenerator(alphabetic_choices, n, trials_range[0])

    # newline='' is required by the csv module for writer files; without it
    # the writer's '\r\n' terminators are translated again and blank rows
    # appear on platforms that use '\r\n' line endings.
    with open(filename, mode='w', newline='') as benchmark_results_file:
        writer = csv.writer(benchmark_results_file, delimiter=',', quotechar='"', quoting=csv.QUOTE_MINIMAL)
        writer.writerow(['index'] + alphabetic_choices + ['ralph_skewed'])
        for i in range(sample_size):
            trials = random.randint(trials_range[0], trials_range[1])
            seq = generator.generate(trials)
            print(f"sequence {i}/{sample_size}: {trials} trials")
            writer.writerow(to_benchmark_csv_row(i, seq))
    # Plot only after the file has been closed so every row is flushed to disk.
    show_skweness_diagram(filename, alphabetic_choices, 'nb_gm_001')


def benchmark_nb_gm_002(filename):
    """Generate ``sample_size`` sequences with nb_gm_002 and record their letter counts.

    Each sequence gets a random length drawn from ``trials_range`` (both ends
    inclusive). One CSV row per sequence is written to *filename*, and the
    finished file is then plotted via ``show_skweness_diagram``.

    Args:
        filename: Path of the CSV file to create (overwritten if it exists).
    """
    import generators.nb_gm_002 as nb_gm_002

    generator = nb_gm_002.SequenceGenerator(alphabetic_choices, n, trials_range[0])

    # newline='' is required by the csv module for writer files; without it
    # the writer's '\r\n' terminators are translated again and blank rows
    # appear on platforms that use '\r\n' line endings.
    with open(filename, mode='w', newline='') as benchmark_results_file:
        writer = csv.writer(benchmark_results_file, delimiter=',', quotechar='"', quoting=csv.QUOTE_MINIMAL)
        writer.writerow(['index'] + alphabetic_choices + ['ralph_skewed'])
        for i in range(sample_size):
            trials = random.randint(trials_range[0], trials_range[1])
            seq = generator.generate(trials)
            # No per-sequence progress output in this benchmark.
            writer.writerow(to_benchmark_csv_row(i, seq))
    # Plot only after the file has been closed so every row is flushed to disk.
    show_skweness_diagram(filename, alphabetic_choices, 'nb_gm_002')


def benchmark_nb_gm_003(filename):
    """Generate ``sample_size`` sequences with nb_gm_003 and record their letter counts.

    Each sequence gets a random length drawn from ``trials_range`` (both ends
    inclusive). One CSV row per sequence is written to *filename*, and the
    finished file is then plotted via ``show_skweness_diagram``.

    Args:
        filename: Path of the CSV file to create (overwritten if it exists).
    """
    import generators.nb_gm_003 as nb_gm_003

    generator = nb_gm_003.SequenceGenerator(alphabetic_choices, n, trials_range[0])

    # newline='' is required by the csv module for writer files; without it
    # the writer's '\r\n' terminators are translated again and blank rows
    # appear on platforms that use '\r\n' line endings.
    with open(filename, mode='w', newline='') as benchmark_results_file:
        writer = csv.writer(benchmark_results_file, delimiter=',', quotechar='"', quoting=csv.QUOTE_MINIMAL)
        writer.writerow(['index'] + alphabetic_choices + ['ralph_skewed'])
        for i in range(sample_size):
            trials = random.randint(trials_range[0], trials_range[1])
            # Progress is announced before generating, so the message is
            # already visible while generate() is still running.
            print(f"sequence {i}/{sample_size}: {trials} trials")
            seq = generator.generate(trials)
            writer.writerow(to_benchmark_csv_row(i, seq))
    # Plot only after the file has been closed so every row is flushed to disk.
    show_skweness_diagram(filename, alphabetic_choices, 'nb_gm_003')


def show_skweness_diagram(filename, choices, title):
    """Plot the share of skewed sequences against sequence length.

    Reads the benchmark CSV at *filename*, derives each sequence's length as
    the sum of its per-choice counts, and for every distinct length computes
    the percentage of sequences whose ``ralph_skewed`` flag is set. The
    percentages are drawn as a scatter plot with a polynomial trend line,
    saved to ``<title>.png`` in the current working directory, and shown.

    Args:
        filename: Path of the CSV file to read.
        choices: Names of the per-symbol count columns in the CSV.
        title: Plot title; also the stem of the PNG file that is written.
    """
    import pandas as pd
    import numpy as np
    from matplotlib import pyplot as plt

    print(filename)
    data = pd.read_csv(filename)
    data['trials'] = data[choices].sum(axis=1)

    # The mean of a boolean flag is the fraction of rows where it is set, so
    # grouping by length yields the skewed share per length directly.
    # Zero-length rows carry no information and are left out.
    measured = data[data['trials'] > 0]
    share = measured.groupby('trials')['ralph_skewed'].mean() * 100.0
    stats = pd.DataFrame({
        'trials': share.index.to_numpy(),
        'skewness': share.to_numpy(dtype=float),
    })

    # Draw on a dedicated figure so repeated calls never overlay each other.
    fig, ax = plt.subplots()
    ax.scatter(stats.trials, stats.skewness)
    if len(stats) > 0:
        # A cubic needs at least four points; use a lower degree when fewer
        # distinct lengths are available instead of an ill-conditioned fit.
        degree = min(3, len(stats) - 1)
        trend = np.poly1d(np.polyfit(stats.trials, stats.skewness, degree))
        ax.plot(stats.trials, trend(stats.trials), color='red')
    ax.set_title(title)
    ax.set_ylabel('skewed blocks (%)')
    ax.set_xlabel('# of trials')
    fig.savefig(f'{title}.png', bbox_inches='tight')
    plt.show()
    # Release the figure so it does not linger in pyplot's global state.
    plt.close(fig)


if __name__ == '__main__':
    # Only the nb_gm_003 benchmark runs when this file is executed as a script;
    # the output path is relative to the current working directory.
    # NOTE(review): the file name says "2back" while the module-level n is 3 -- confirm.
    benchmark_nb_gm_003('../benchmarks/nb_gm_003_2back.csv')