/*****************************************************************************
 ****                                                                     ****
 **** lf.c                                                                ****
 ****                                                                     ****
 **** atree release 2.0                                                   ****
 **** Adaptive Logic Network (ALN) simulation program.                    ****
 **** Copyright (C) A. Dwelly, R. Manderscheid, W.W. Armstrong, 1991.     ****
 ****                                                                     ****
 **** License:                                                            ****
 **** A royalty-free license is granted for the use of this software for  ****
 **** NON_COMMERCIAL PURPOSES ONLY. The software may be copied and/or     ****
 **** modified provided this notice appears in its entirety and unchanged ****
 **** in all derived source programs.  Persons modifying the code are     ****
 **** requested to state the date, the changes made and who made them     ****
 **** in the modification history.                                        ****
 ****                                                                     ****
 **** Patent License:                                                     ****
 **** The use of a digital circuit which transmits a signal indicating    ****
 **** heuristic responsibility is protected by U. S. Patent 3,934,231     ****
 **** and others assigned to Dendronic Decisions Limited of Edmonton,     ****
 **** W. W. Armstrong, President.  A royalty-free license is granted      ****
 **** by the company to use this patent for NON_COMMERCIAL PURPOSES ONLY  ****
 **** to adapt logic trees using this program and its modifications.      ****
 ****                                                                     ****
 **** Limited Warranty:                                                   ****
 **** This software is provided "as is" without warranty of any kind,     ****
 **** either expressed or implied, including, but not limited to, the     ****
 **** implied warrantees of merchantability and fitness for a particular  ****
 **** purpose.  The entire risk as to the quality and performance of the  ****
 **** program is with the user.  Neither the authors, nor the             ****
 **** University of Alberta, its officers, agents, servants or employees  ****
 **** shall be liable or responsible in any way for any damage to         ****
 **** property or direct personal or consequential injury of any nature   ****
 **** whatsoever that may be suffered or sustained by any licensee, user  ****
 **** or any other party as a consequence of the use or disposition of    ****
 **** this software.                                                      ****
 ****                                                                     ****
 **** Modification history:                                               ****
 ****                                                                     ****
 **** 90.05.09 Initial implementation, A.Dwelly                           ****
 **** 91.07.15 Release 2, Rolf Manderscheid                               ****
 ****                                                                     ****
 *****************************************************************************/

#include <stdio.h>
#include <math.h>
#include "atree.h"
#include "lf.h"

/* NOTE(review): pre-ANSI declaration of malloc; it would conflict with
   <stdlib.h> if that header were ever pulled in (directly or via
   atree.h / lf.h) -- confirm before modernising. */
extern char *malloc();

/* printf with its return value explicitly discarded. */
#define Printf (void) printf
/* Verbosity argument passed to atree_train(). */
#define VERBOSITY 0

/* Scanner/parser interface: yyin is the input stream read by yyparse()
   and line_no its line counter (presumably a lex/yacc pair -- confirm). */
extern FILE *yyin;
extern int line_no;

/* Description of the program being run; populated while the program
   text is parsed, then consumed by process_prog() and test_prog(). */
prog_t prog;
bool_t fold_flag; /* true if saving folded trees */
                  /* remaining flags true if corresponding statement
                     has been specified. */
bool_t test_size_flag;
bool_t train_size_flag;
bool_t largest_flag;
bool_t smallest_flag;
bool_t code_flag;
bool_t quant_flag;

/* Encoded training set: element i is the concatenated coding of the
   input (domain) / output (codomain) columns of training row i. */
static bit_vec *domain_set;
static bit_vec *codomain_set;
/* forest[i][v] is voter v's tree for bit i of the codomain coding. */
static atree ***forest;

/*
 * Reset the global program description to its defaults before a program
 * is parsed.  Calls atree_init() first, then marks every optional
 * statement as not yet seen, selects one voter per output bit and no
 * training epochs, and clears all tree/code file names.
 */
prog_init()
{
    atree_init();

    /* No optional statement has been seen yet. */
    fold_flag = FALSE;
    quant_flag = FALSE;
    code_flag = FALSE;
    smallest_flag = FALSE;
    largest_flag = FALSE;
    test_size_flag = FALSE;
    train_size_flag = FALSE;

    /* Default numeric parameters. */
    prog.vote = 1;
    prog.tree_sz = 0;
    prog.min_correct = 0;
    prog.max_epochs = 0;
    prog.forest_folded = FALSE;

    /* Nothing to load from or save to. */
    prog.load_tree = NULL;
    prog.save_tree = NULL;
    prog.load_code = NULL;
    prog.save_code = NULL;
}

/*
 * Parse an lf program from the already-open stream `fp`, storing the
 * result in the global `prog`.  The value returned by yyparse() is
 * discarded; the caller does not close `fp` here.
 */
void
read_prog(fp)

FILE *fp;

{
    /* Reset the line counter and point the scanner at the new input. */
    line_no = 1;
    yyin = fp;

    (void) yyparse();
}

/*
 * Fill in whichever bounds of the coding for column `dim` were not given
 * explicitly in the program: prog.code[dim].high becomes the column
 * maximum unless largest_flag is set, and prog.code[dim].low the column
 * minimum unless smallest_flag is set.  Both the training and the test
 * tables are scanned, so computed bounds cover every value that is later
 * passed to atree_encode().
 *
 * The first value examined seeds the bounds.  This means an empty
 * training table is never indexed.
 *
 * Returns 0 when a bound had to be computed but the column holds no
 * values at all, and non-zero otherwise.
 */
static int
find_column_range(int dim)
{
    int i;
    int seen = 0;   /* number of values examined so far */

    for (i = 0; i < prog.trainset_sz; i++, seen++)
    {
        if (!largest_flag
        && (seen == 0 || prog.train_table[dim][i] > prog.code[dim].high))
        {
            prog.code[dim].high = prog.train_table[dim][i];
        }
        if (!smallest_flag
        && (seen == 0 || prog.train_table[dim][i] < prog.code[dim].low))
        {
            prog.code[dim].low = prog.train_table[dim][i];
        }
    }

    for (i = 0; i < prog.testset_sz; i++, seen++)
    {
        if (!largest_flag
        && (seen == 0 || prog.test_table[dim][i] > prog.code[dim].high))
        {
            prog.code[dim].high = prog.test_table[dim][i];
        }
        if (!smallest_flag
        && (seen == 0 || prog.test_table[dim][i] < prog.code[dim].low))
        {
            prog.code[dim].low = prog.test_table[dim][i];
        }
    }

    return seen > 0 || (largest_flag && smallest_flag);
}

/*
 * Carry out the program stored in the global `prog` by read_prog():
 * check the specification, build (or load) the coding of every column,
 * build (or load) the forest, encode the training set, train, and save
 * trees and/or codings if requested.  Every error is reported on stderr
 * and terminates the process with exit(1).
 *
 * On return, the file-scope `forest` holds prog.codomain_width rows of
 * prog.vote trees each.  prog.forest_folded is TRUE if the trees were
 * folded before being saved.
 */
void
process_prog()

{
    int dim;
    int voter;
    int i;
    int n_alloc;
    bit_vec **concat;
    FILE *save_tree_fp = NULL;
    FILE *save_code_fp = NULL;

    /*
     * When this function is called, the program has been read in, and
     * the relevant details have been stored in the global structure
     * 'prog'. We have some semantic processing to do, then the trees
     * can be trained as specified.
     */

    if (prog.tree_sz == 0 && prog.load_tree == NULL) {
        (void) fprintf(stderr, "lf: either tree size or tree load file must be specified\n");
        exit(1);
    }

    if (prog.load_code == NULL && (!code_flag || !quant_flag)) {
        (void) fprintf(stderr, "lf: either coding+quantization or code load file must be specified.\n");
        exit(1);
    }

    /* Output files are opened up front, so an unwritable path is
       reported before any training is done. */
    if (prog.save_tree != NULL
    && (save_tree_fp = fopen(prog.save_tree, "w")) == NULL)
    {
        perror(prog.save_tree);
        exit(1);
    }

    if (prog.save_code != NULL
    && (save_code_fp = fopen(prog.save_code, "w")) == NULL)
    {
        perror(prog.save_code);
        exit(1);
    }

    /*
     * One encoded input/output pair per training row.  malloc(0) may
     * return NULL, which MEMCHECK would presumably treat as an allocation
     * failure, so always request at least one element.
     */
    n_alloc = prog.trainset_sz > 0 ? prog.trainset_sz : 1;
    domain_set = (bit_vec *) malloc((unsigned) sizeof(bit_vec) * n_alloc);
    MEMCHECK(domain_set);
    codomain_set = (bit_vec *) malloc((unsigned) sizeof(bit_vec) * n_alloc);
    MEMCHECK(codomain_set);

    if (prog.load_code)
    {
        FILE *fp;

        if ((fp = fopen(prog.load_code, "r")) == NULL)
        {
            perror(prog.load_code);
            exit(1);
        }
        for (dim = 0; dim < prog.total_dimensions; dim++)
        {
            if (atree_read_code(fp, &prog.code[dim]) == NULL)
            {
                (void) fprintf(stderr, "lf: too few codings in file '%s'\n",
                               prog.load_code);
                exit(1);
            }
        }
        (void) fclose(fp);
    }
    else
    {
        for (dim = 0; dim < prog.total_dimensions; dim++)
        {
            if (!find_column_range(dim))
            {
                (void) fprintf(stderr, "lf: no data to determine the range of column %d\n", dim + 1);
                exit(1);
            }

            if (prog.code[dim].high <= prog.code[dim].low)
            {
                (void) fprintf(stderr, "lf: largest value must be greater than smallest value, column %d\n", dim + 1);
                exit(1);
            }

            /* finish constructing code for this dimension */

            if (atree_set_code(&prog.code[dim], prog.code[dim].low,
                                                prog.code[dim].high,
                                                prog.code[dim].vector_count,
                                                prog.code[dim].width,
                                                prog.walk_step[dim]))
            {
                (void) fprintf(stderr,
                               "lf: random walk failed for dimension %d\n",
                               dim + 1);
                exit(1);
            }
        } /* for (dim...) */
    }

    /* The first prog.dimensions columns are inputs; the rest are outputs. */
    prog.domain_width = 0;
    prog.codomain_width = 0;
    for (dim = 0; dim < prog.total_dimensions; dim++)
    {
        if (dim < prog.dimensions)
        {
            prog.domain_width += prog.code[dim].width;
        }
        else
        {
            prog.codomain_width += prog.code[dim].width;
        }
    }

    /*
     * Now we can create the forest: one row per bit of the codomain
     * coding, each row holding prog.vote trees.
     */
    forest = (atree ***)
             malloc((unsigned) sizeof(atree **) * prog.codomain_width);
    MEMCHECK(forest);

    for (i = 0; i < prog.codomain_width; i++)
    {
        forest[i] = (atree **) malloc((unsigned) sizeof(atree *) * prog.vote);
        MEMCHECK(forest[i]);
    }

    if (prog.load_tree)
    {
        FILE *fp;

        if ((fp = fopen(prog.load_tree, "r")) == NULL) {
            perror(prog.load_tree);
            exit(1);
        }
        /* Voter-major order; must match the order used when saving below. */
        for (voter = 0; voter < prog.vote; voter++)
        {
            for (i = 0; i < prog.codomain_width; i++)
            {
                if ((forest[i][voter] = atree_read(fp)) == NULL)
                {
                    (void) fprintf(stderr, "lf: too few trees in file '%s'\n",
                                   prog.load_tree);
                    exit(1);
                }
            }
        }
        (void) fclose(fp);
    }
    else
    {
        for (voter = 0; voter < prog.vote; voter++)
        {
            for (i = 0; i < prog.codomain_width; i++)
            {
                forest[i][voter] = atree_create(prog.domain_width,prog.tree_sz);
            }
        }
    }

    /*
     * code[dim] covers column dim.  We now create the training set of bit
     * vectors.  Each column value selects one vector of that column's
     * coding, and the domain and codomain vectors are concatenated
     * separately.
     */
    concat = (bit_vec **) malloc((unsigned)(prog.total_dimensions) *
                                  sizeof(bit_vec *));
    MEMCHECK(concat);

    for (i = 0; i < prog.trainset_sz; i++)
    {
        for (dim = 0; dim < prog.total_dimensions; dim++)
        {
            concat[dim] = prog.code[dim].vector
                          + atree_encode(prog.train_table[dim][i],
                                         &prog.code[dim]);
        }
        /*
         * NOTE(review): bv_concat returns a newly allocated bit_vec (test_prog
         * bv_free()s its result).  Only the contents are copied here, so
         * each returned header is never freed.  This is left unchanged
         * because bv_free may also release the bit storage the copy still
         * refers to; confirm against bv_concat/bv_free before fixing.
         */
        domain_set[i] = *(bv_concat(prog.dimensions, concat));
        codomain_set[i] = *(bv_concat(prog.codimensions,
                                      concat + prog.dimensions));
    }

    (void) free((char *) concat);

    /*
     * Train the trees.  Tree forest[i][voter] receives index i, which
     * presumably selects bit i of the codomain vectors.
     */
    if (prog.max_epochs > 0) {
        for (i = 0; i < prog.codomain_width; i++)
        {
            for (voter = 0; voter < prog.vote; voter++)
            {
                (void) atree_train(forest[i][voter], domain_set, codomain_set,
                                   i, prog.trainset_sz, prog.min_correct,
                                   prog.max_epochs, VERBOSITY);
            }
        }
    }

    /*
     * Save trees and codings.  fclose() is checked because it flushes
     * buffered output, and a failure there means the file is incomplete.
     */
    if (save_tree_fp != NULL)
    {
        if (fold_flag)
        {
            fold_forest();
            prog.forest_folded = TRUE;
        }

        /* Voter-major order, matching the loading loop above. */
        for (voter = 0; voter < prog.vote; voter++)
        {
            for (i = 0; i < prog.codomain_width; i++)
            {
                atree_write(save_tree_fp, forest[i][voter]);
            }
        }
        if (fclose(save_tree_fp) == EOF)
        {
            perror(prog.save_tree);
            exit(1);
        }
    }

    if (save_code_fp != NULL)
    {
        for (dim = 0; dim < prog.total_dimensions; dim++)
        {
            atree_write_code(save_code_fp, &prog.code[dim]);
        }
        if (fclose(save_code_fp) == EOF)
        {
            perror(prog.save_code);
            exit(1);
        }
    }
}

/*
 * Test the forest against the test set.
 *
 * Writes to stdout the number of codomain dimensions, then one line per
 * test row.  Each line holds every column's table value and its
 * atree_encode() index, followed by each codomain column's value decoded
 * from the forest's majority vote together with the decoded index.
 */
void
test_prog()
{
    int row;
    int d;
    int tree;
    int level;
    bit_vec *input;
    bit_vec **outputs;
    bit_vec **parts;

    parts = (bit_vec **) malloc((unsigned)(prog.dimensions) *
                                sizeof(bit_vec *));
    MEMCHECK(parts);

    /*
     * One output vector per codomain column, indexed by column number,
     * because bv_diff expects vectors of equal length.
     */
    outputs = (bit_vec **)
              malloc((unsigned)(prog.total_dimensions) * sizeof(bit_vec *));
    MEMCHECK(outputs);
    for (d = prog.dimensions; d < prog.total_dimensions; d++)
    {
        outputs[d] = bv_create(prog.code[d].width);
    }

    Printf("%d\n", prog.codimensions);

    for (row = 0; row < prog.testset_sz; row++)
    {
        /*
         * Echo every column and assemble the input vector from the domain
         * columns.  The concatenation is bv_free()d at the end of the row.
         */
        for (d = 0; d < prog.total_dimensions; d++)
        {
            level = atree_encode(prog.test_table[d][row], &prog.code[d]);
            if (d < prog.dimensions)
            {
                parts[d] = prog.code[d].vector + level;
            }
            Printf("%s%f %d", d ? "\t" : "", prog.test_table[d][row], level);
        }
        input = bv_concat(prog.dimensions, parts);

        /* `tree` walks the forest rows: one per bit of the codomain coding. */
        tree = 0;
        for (d = prog.dimensions; d < prog.total_dimensions; d++)
        {
            int bit;
            int nearest;

            for (bit = 0; bit < prog.code[d].width; bit++, tree++)
            {
                int ayes = 0;
                int v;

                for (v = 0; v < prog.vote; v++)
                {
                    ayes += atree_eval(forest[tree][v], input) ? 1 : 0;
                }
                /* The bit is set only on a strict majority of voters. */
                bv_set(bit, outputs[d], ayes > prog.vote / 2);
            }

            /* Map the voted bit pattern back to the nearest coded value. */
            nearest = atree_decode(outputs[d], &prog.code[d]);
            Printf("\t%f %d",
                   prog.code[d].low + prog.code[d].step * nearest,
                   nearest);
        }
        Printf("\n");
        bv_free(input);
    }

    (void) free((char *) parts);
    for (d = prog.dimensions; d < prog.total_dimensions; d++)
    {
        bv_free(outputs[d]);
    }
    (void) free((char *) outputs);
}

fold_forest()
{
    int i;
    int voter;

    for (i = 0; i < prog.codomain_width; i++)
    {
        for (voter = 0; voter < prog.vote; voter++)
        {
            forest[i][voter] = atree_fold(forest[i][voter]);
        }
    }
}

/*
 * lf [file]
 *
 * Read an lf program from `file`, or from standard input when no argument
 * is given.  Then train the forest it describes and run the test set
 * through it.  The process exit status is prog.error.
 */
main(argc,argv)

int argc;
char *argv[];

{
    FILE *fp;

    prog_init();

    /* At most one argument is accepted: the program file. */
    if (argc > 2)
    {
        (void) fprintf(stderr, "Usage: lf\n");
        (void) fprintf(stderr, "       lf file\n");
        exit(1);
    }

    if (argc != 1)
    {
        fp = fopen(argv[1], "r");
        if (fp == NULL)
        {
            (void) fprintf(stderr, "lf: can't open %s\n", argv[1]);
            exit(1);
        }
        read_prog(fp);
        (void) fclose(fp);
    }
    else
    {
        read_prog(stdin);
    }

    /* Train only if the program was read without error ... */
    if (!prog.error)
    {
        process_prog();
    }

    /* ... and test only if no error has been flagged since. */
    if (!prog.error)
    {
        /* Fold the forest before testing unless process_prog already did. */
        if (!prog.forest_folded)
        {
            fold_forest();
        }
        test_prog();
    }

    exit(prog.error);
}
