/* Multiplexor test: train atree(s) to learn a multiplexor function and report accuracy on random vectors. */

#include <stdio.h>
#include "atree.h"

#define CONTROL_BITS    2       /* select bits at the front of every input vector */
#define LEAVES          256     /* second argument to atree_create() -- presumably leaves per tree, TODO confirm */
#define TEST_SIZE       500     /* number of fresh random vectors scored in the final test */
#define TRAIN_SIZE      500     /* number of random training vectors (also passed to atree_train()) */
#define VOTERS          1       /* independently trained trees combined by majority vote; must be odd */
#define EPOCHS          100     /* passed to atree_train() -- presumably the maximum number of epochs */
#define VERBOSITY       1       /* verbosity argument passed to atree_train() */

/* Input vector width: CONTROL_BITS select bits followed by 2^CONTROL_BITS data bits. */
#define WIDTH           ((CONTROL_BITS) + (1 << (CONTROL_BITS)))


/*
 * multiplexor: the reference function the trees are trained to learn.
 *
 * v holds WIDTH entries. The first CONTROL_BITS entries are read as a
 * binary address, most significant bit first. The data entry at that
 * address (the data entries follow the control bits) is returned.
 *
 * Any nonzero entry is treated as a 1 bit. This keeps the computed address
 * inside the 2^CONTROL_BITS data entries even if v holds values other
 * than 0/1.
 *
 * Returns 1 if the selected data entry is nonzero, otherwise 0.
 */
char multiplexor(const char *v)
{
    int i;
    int lead = 0;

    for (i = 0; i < CONTROL_BITS; i++)
    {
        /* Normalize to 0/1 so lead stays below 1 << CONTROL_BITS. */
        lead = (lead << 1) + (v[i] != 0);
    }
    return (char) (v[lead + CONTROL_BITS] != 0);
}

/*
 * fill_random_vector: set each of the WIDTH entries of v to RANDOM(2).
 * Shared by the training and test phases so both draw inputs the same way.
 */
static void fill_random_vector(char *v)
{
    int j;

    for (j = 0; j < WIDTH; j++)
    {
        v[j] = RANDOM(2);
    }
}

/*
 * Multiplexor experiment driver.
 *
 * 1. Builds TRAIN_SIZE random input vectors and labels each one with
 *    multiplexor().
 * 2. Creates VOTERS trees with atree_create(WIDTH, LEAVES) and trains each
 *    one on that data with atree_train().
 * 3. Scores the trees' majority vote on TEST_SIZE fresh random vectors and
 *    prints how many were answered correctly.
 *
 * Returns 1 (after a message on stderr) if VOTERS is even, otherwise 0.
 */
int main(void)
{
    atree *tree[VOTERS];
    bit_vec training_set[TRAIN_SIZE];
    bit_vec result_set[TRAIN_SIZE];
    bit_vec *test;
    char vec[WIDTH];
    char unpacked_result[1];
    int correct;
    int voter;
    int i;

    /* Initialize */

    atree_init();

    /* An odd voter count keeps the majority vote from tying
       (assuming atree_eval() returns 0 or 1). */
    if (VOTERS % 2 != 1)
    {
        (void) fprintf(stderr, "VOTERS must be odd\n");
        return 1;
    }

    /* Create the training data */

    (void) printf("Creating training data\n");

    for (i = 0; i < TRAIN_SIZE; i++)
    {
        fill_random_vector(vec);
        /*
         * NOTE(review): the bit_vec returned by bv_pack() is copied by value
         * and the returned pointer is never released. bv_free() cannot simply
         * be called here if the copy shares storage with it -- confirm against
         * the bv_pack()/bv_free() implementation.
         */
        training_set[i] = *(bv_pack(vec, WIDTH));
        unpacked_result[0] = multiplexor(vec);
        result_set[i] = *(bv_pack(unpacked_result, 1));
    }

    /* Create VOTERS trees and train each one */

    (void) printf("Training tree\n");

    for (voter = 0; voter < VOTERS; voter++)
    {
        tree[voter] = atree_create(WIDTH, LEAVES);
        (void) atree_train(tree[voter], training_set, result_set, 0,
                           TRAIN_SIZE, TRAIN_SIZE-1, EPOCHS, VERBOSITY);
    }

    /* Test the trained trees on fresh random vectors */

    (void) printf("Testing the tree\n");

    correct = 0;
    for (i = 0; i < TEST_SIZE; i++)
    {
        int weight = 0;     /* sum of the trees' outputs -- assumes each is 0/1 */

        fill_random_vector(vec);
        test = bv_pack(vec, WIDTH);

        for (voter = 0; voter < VOTERS; voter++)
        {
            weight += atree_eval(tree[voter], test);
        }

        /* The ensemble answers 1 when more than half of the trees do. */
        if (multiplexor(vec) == (weight > VOTERS / 2))
        {
            correct++;
        }

        bv_free(test);
    }

    (void) printf("%d correct out of %d in final test\n", correct, TEST_SIZE);

    return 0;
}
