
/**********************************************************************
 * $Id: scl.c,v 1.3 92/11/30 12:02:05 drew Exp $
 **********************************************************************/

/**********************************************************************
 *   Copyright 1990,1991,1992,1993 by The University of Toronto,
 *		      Toronto, Ontario, Canada.
 * 
 *			 All Rights Reserved
 * 
 * Permission to use, copy, modify, distribute, and sell this software
 * and its documentation for any purpose is hereby granted without fee, 
 * provided that the above copyright notice appears in all copies and that 
 * both the copyright notice and this permission notice appear in 
 * supporting documentation, and that the name of University of Toronto 
 * not be used in advertising or publicity pertaining to distribution 
 * of the software without specific, written prior permission.  
 * University of Toronto makes no representations about the suitability 
 * of this software for any purpose. It is provided "as is" without 
 * express or implied warranty. 
 *
 * UNIVERSITY OF TORONTO DISCLAIMS ALL WARRANTIES WITH REGARD TO THIS 
 * SOFTWARE, INCLUDING ALL IMPLIED WARRANTIES OF MERCHANTABILITY AND 
 * FITNESS, IN NO EVENT SHALL UNIVERSITY OF TORONTO BE LIABLE FOR ANY 
 * SPECIAL, INDIRECT OR CONSEQUENTIAL DAMAGES OR ANY DAMAGES WHATSOEVER 
 * RESULTING FROM LOSS OF USE, DATA OR PROFITS, WHETHER IN AN ACTION OF 
 * CONTRACT, NEGLIGENCE OR OTHER TORTIOUS ACTION, ARISING OUT OF OR IN 
 * CONNECTION WITH THE USE OR PERFORMANCE OF THIS SOFTWARE.
 *
 **********************************************************************/


 /*********************************************************************
 *
 *  KCL/SCL/HCL competitive learning modules written by Sue Becker
 *                                           Dept. of Computer Science
 *  November 1991                            Univ. of Toronto
 *
 **********************************************************************/

#include <stdio.h>
#include <math.h>

#include <xerion/useful.h>
#include <xerion/version.h>
#include <xerion/simulator.h>
#include <xerion/commands.h>
#include <xerion/minimize.h>

#include "scl.h"
#include "help.h"

/* Net-level hooks and the error / gradient procedures installed by initNet */
static void	initNet                ARGS((Net	net)) ;
static void	deinitNet              ARGS((Net	net)) ;
static void	calculateNetErrorDeriv ARGS((Net net,ExampleSet exampleSet));
static void	calculateNetError      ARGS((Net net,ExampleSet exampleSet));
static void	calculateAveNetVariance ARGS((Net net,ExampleSet exampleSet));

/* Group-level hook and the group activity update procedure */
static void	initGroup              ARGS((Group	group)) ;
static Proc     softCompGroupActivityUpdate ARGS((Group group)) ;

/* Unit-level hook, the unit gradient procedure, and per-unit callbacks
 * passed to groupForAllUnits / netForAllUnits */
static void	initUnit               ARGS((Unit  unit)) ;
static Proc     setSoftCompSumExpInput ARGS((Unit  unit, Real *data)) ;
static void     setSoftCompTotalInput  ARGS((Unit  unit)) ;
static Proc     setSoftCompOutput      ARGS((Unit  unit, Real *data)) ;
static Proc	gradUpdate             ARGS((Unit  unit)) ;
static Proc	zeroLinks              ARGS((Unit  unit, void *data)) ;
static Proc	zeroVariance           ARGS((Unit  unit, void *data)) ;
static Proc	incVariance            ARGS((Unit  unit, void *data)) ;
static Proc	normalizeVariance      ARGS((Unit  unit, int  *data)) ;
static Proc	incNetVariance         ARGS((Unit  unit, void *data)) ;

/* Numeric helper */
static Real	square    ARGS((double  x)) ;

/* Command: suggest initial / minimum variance values for the current net */
int command_suggestInitialVariance     ARGS((int tokc, char *tokv[])) ;

/***********************************************************************
 *	Name:		installSclHooks
 *	Description:	registers this module's create/destroy hooks so
 *			every new net, group and unit gets the SCL procedures
 *			and extension records.
 *	Parameters:	NONE
 *	Return Value:	NONE
 ***********************************************************************/
static void	installSclHooks ()
{
  setCreateNetHook  (initNet) ;
  setDestroyNetHook (deinitNet) ;
  setCreateGroupHook(initGroup) ;
  setCreateUnitHook (initUnit) ;
}

/***********************************************************************
 *	Name:		main 
 *	Description:	entry point of the SCL xerion simulator.  Registers
 *			the "scatter" display, runs the standard simulator
 *			initialization, installs the SCL hooks and then hands
 *			control to the interactive command loop.
 *	Parameters:	
 *		int	argc	- the number of input args
 *		char	**argv  - array of argument strings from command 
 *				  line
 *	Return Value:	
 *		int	main	- 0
 ***********************************************************************/
int main (argc, argv)
  int	argc ;
  char	**argv ;
{
  extern void	createScatterDisplay() ;
  extern int	addUserDisplay() ;

  authors = "Sue Becker" ;

  /* the display must be registered before the simulator initializes */
  (void) addUserDisplay("scatter", createScatterDisplay) ;
  IStandardInit(&argc, argv) ;

  installSclHooks() ;

  /* read commands / service graphics until the user quits */
  ICommandLoop(stdin, stdout, NULL) ;
  return 0 ;
}
/**********************************************************************/


/***********************************************************************
 *	Name:		initNet 
 *	Description:	create-net hook.  Allocates the zero-filled net
 *			extension record, installs the SCL error and gradient
 *			procedures and sets the default training parameters.
 *			If the extension cannot be allocated, the SCL
 *			procedures are NOT installed, because they all
 *			dereference the extension.
 *	Parameters:	
 *		Net	net - the net to act on
 *	Return Value:	NONE
 ***********************************************************************/
static void	initNet (net)
  Net	net ;
{
  net->extension = (NetExtension)calloc(1, sizeof(NetExtensionRec)) ;
  if (net->extension == NULL) {
    IErrorAbort("initNet: unable to allocate net extension record") ;
    return ;	/* only reached if IErrorAbort returns */
  }

  net->calculateErrorDerivProc = calculateNetErrorDeriv ;
  net->calculateErrorProc      = calculateNetError ;

  /* learning rate; reset to initialLearningRate at epoch 0 */
  Mepsilon(net) = 0.0001 ;
  MinitialLearningRate(net) = 0.0001 ;
  /* 1: epsilon is multiplied by learningRateDecay each epoch,
   * 2: epsilon decays with time constant tau (see calculateNetErrorDeriv) */
  MlearningRateDecayFunction(net) = 1 ;
  MlearningRateDecay(net) = 1.0 ;
  Mtau(net) = 1000.0 ;

  /* global Gaussian variance: starts at initialVariance and is decayed
   * by epochVarianceDecay each epoch, never below minVariance */
  Mvariance(net) = 0.2 ;
  MinitialVariance(net) = 0.2 ;
  MminVariance(net) = 0.02 ;
  MepochVarianceDecay(net) = 0.98 ;

  /* plain fixed-step steepest descent */
  Mmomentum(net) = 0.0 ;
  MdirectionMethod(net) = MZSTEEPEST;
  MstepMethod(net) = MZFIXEDSTEP;
}

/***********************************************************************
 *	Name:		deinitNet
 *	Description:	destroy-net hook.  Releases the net extension record
 *			allocated by initNet and clears the pointer so it
 *			cannot be freed or read again.
 *	Parameters:	
 *		Net	net - the net being destroyed
 *	Return Value:	NONE
 ***********************************************************************/
static void	deinitNet (net)
  Net	net ;
{
  /* free(NULL) is a no-op, so no guard is needed */
  free(net->extension) ;
  net->extension = NULL ;
}
/**********************************************************************/


/***********************************************************************
 *	Name:		initGroup 
 *	Description:	create-group hook.  Allocates the zero-filled group
 *			extension record and installs the soft-competitive
 *			group activity procedure and the unit gradient
 *			procedure.  No per-unit activity procedure is used,
 *			because the group procedure updates every unit.
 *			If the allocation fails, the procedures (which read the
 *			extension) are not installed.
 *			NOTE(review): no destroy-group hook is registered in
 *			this file, so this record does not appear to be freed.
 *	Parameters:	
 *		Group	group - the group to set the procedures for
 *	Return Value:	NONE
 ***********************************************************************/
static void	initGroup (group)
  Group	group ;
{
  group->extension = (GroupExtension)calloc(1, sizeof(GroupExtensionRec)) ;
  if (group->extension == NULL) {
    IErrorAbort("initGroup: unable to allocate group extension record") ;
    return ;	/* only reached if IErrorAbort returns */
  }

  group->groupActivityUpdateProc = softCompGroupActivityUpdate;
  group->unitActivityUpdateProc = NULL ;
  group->unitGradientUpdateProc = gradUpdate ;
}
/**********************************************************************/

/***********************************************************************
 *	Name:		initUnit
 *	Description:	create-unit hook; allocates the zero-filled unit
 *			extension record (expInput, variance, ...).
 *			NOTE(review): no destroy-unit hook is registered in
 *			this file, so this record does not appear to be freed.
 *	Parameters:	
 *		Unit	unit - the unit to allocate the space for
 *	Return Value:	NONE
 ***********************************************************************/
static void	initUnit (unit)
  Unit	unit ;
{
  unit->extension = (UnitExtension)calloc(1, sizeof(UnitExtensionRec)) ;
  if (unit->extension == NULL)
    IErrorAbort("initUnit: unable to allocate unit extension record") ;
}
/**********************************************************************/


/***********************************************************************
 *	Name:		softCompGroupActivityUpdate
 *	Description:
 *             Sets each unit's totalInput to the squared difference
 *	         between its weight vector and input vector
 *             Each unit's output is a Gaussian of the total input, normalized
 *               by the sum of Gaussians for the group.
 *             The global error is decremented by the log prob of this case.
 *      Parameters:
 *		Group	group - the group object ;
 *	Return Value:	
 *		NONE
 ***********************************************************************/
static Proc	softCompGroupActivityUpdate (group)
  Group		group ;
{
  Real twoVar = Mvariance(group->net) * 2.0;
  Real sqrt2PiVar; 
  Real norm;
  Real npats;
  Real nunits = (Real)group->numUnits;

  if (group->type & INPUT) return;

  sqrt2PiVar = sqrt(twoVar * PI);
  norm = pow(sqrt2PiVar,(Real)group->numUnits);
  npats = MbatchSize(group->net);
  if (npats == 0) npats = (Real)group->net->trainingExampleSet->numExamples;

  group->extension->sumExpInput = 0.0;
  groupForAllUnits(group, setSoftCompSumExpInput, &twoVar);
  groupForAllUnits(group, setSoftCompOutput, &group->extension->sumExpInput);

  /* decrement net->error by log probability of this case
   *   = mean of each unit's Gaussian(totalInput)
   */
   group->net->error -= log(group->extension->sumExpInput /
			    (norm * nunits * npats));
}
/**********************************************************************/

/***********************************************************************
 *	Name:	     setSoftCompSumExpInput
 *	Description:
 *              Set the following unit fields:
 *                totalInput = squared difference between 
 *                             weight vector and current input vector
 *                expInput   =  gaussian of totalInput
 *              Increment group's sumExpInput field by unit's expInput.
 *	Parameters:	
 *		Unit   unit - the unit to set the totalInput for
 *		Data   *Real - a pointer to twoVar, double the global variance
 *	Return Value:	NONE
 ***********************************************************************/
static Proc  setSoftCompSumExpInput(unit, data)
  Unit unit;
  Real *data;
{
  Real  twoVar = *data;

  setSoftCompTotalInput(unit);

  unit->group->extension->sumExpInput +=
       (unit->extension->expInput = exp(- unit->totalInput / twoVar));
   }
/**********************************************************************/

/***********************************************************************
 *	Name:		setSoftCompTotalInput
 *	Description:	sets unit->totalInput to the sum over the unit's
 *			incoming links of (weight - preUnit output)^2.
 *	Parameters:	
 *		Unit	unit - the unit to update
 *	Return Value:	NONE
 ***********************************************************************/
static void setSoftCompTotalInput (unit)
  Unit unit;
{
  Link	*next      = unit->incomingLink ;
  int	remaining  = unit->numIncoming ;
  Real	distance   = 0.0 ;

  while (remaining-- > 0) {
    Link	link = *next++ ;

    distance += square(link->weight - link->preUnit->output) ;
  }
  unit->totalInput = distance ;
}
  
/***********************************************************************
 *	Name:		setSoftCompOutput
 *	Description:	groupForAllUnits callback; sets the unit's output
 *			(its probability within the group) to
 *			expInput / sumExpInput.
 *	Parameters:	
 *		Unit	unit - the unit to set the output for
 *		Real	*data - points at the group's sumExpInput
 *	Return Value:	NONE
 ***********************************************************************/
static Proc  setSoftCompOutput(unit, data)
  Unit	unit ;
  Real	*data ;
{
  unit->output = unit->extension->expInput / *data ;
}
/**********************************************************************/

/***********************************************************************
 *	Name:		gradUpdate
 *	Description:	unit gradient procedure.  For every incoming link,
 *			subtracts output * (preUnit output - weight) from the
 *			link's deriv, so a unit's weights are pulled toward the
 *			input in proportion to its probability.
 *			NOTE(review): the exact derivative of -log(sum of
 *			Gaussians) also carries a 1/variance factor; it is not
 *			applied here -- presumably folded into the learning
 *			rate, confirm.
 *	Parameters:	
 *		Unit	unit - the unit to set the grads for
 *	Return Value:	NONE
 ***********************************************************************/
static Proc	gradUpdate(unit)
  Unit	unit ;
{
  Real	responsibility = unit->output ;
  Link	*next          = unit->incomingLink ;
  int	remaining      = unit->numIncoming ;

  for ( ; remaining > 0 ; --remaining, ++next) {
    Link	link = *next ;

    link->deriv -= responsibility * (link->preUnit->output - link->weight) ;
  }
}
/**********************************************************************/

/***********************************************************************
 *	Name:		calculateNetErrorDeriv
 *	Description:	gradient calculation procedure for scl net.
 *			Resets net->error and all link derivatives, applies the
 *			variance / learning-rate schedule, then runs a forward
 *			and a gradient pass over 'MbatchSize(net)' examples.
 *			A batch size of 0 means the whole example set.
 *
 *			Schedule (currentEpoch == 0 resets both values):
 *			  variance *= epochVarianceDecay, never below
 *			             minVariance
 *			  learningRateDecayFunction 1:
 *			    epsilon *= learningRateDecay
 *			  learningRateDecayFunction 2:
 *			    epsilon = initialLearningRate / (1 + epoch/tau)
 *			              (left unchanged when tau <= 0)
 *	Parameters:	
 *		Net		net - the net to use
 *		ExampleSet	exampleSet - the examples to use
 *	Return Value:	
 *		NONE
 ***********************************************************************/
static void	calculateNetErrorDeriv(net, exampleSet)
  Net		net ;
  ExampleSet	exampleSet ;
{
  int		numExamples ;
  int		batchSize = (int)MbatchSize(net) ;
  Real          varDecay = MepochVarianceDecay(net);
  Real          minVar = MminVariance(net);

  /* batch size 0 = every example, the same interpretation
   * softCompGroupActivityUpdate uses when normalizing the error.
   * Check before touching any state so an aborted call leaves the
   * schedule untouched. */
  if (batchSize == 0 && exampleSet != NULL)
    batchSize = exampleSet->numExamples ;
  if (batchSize <= 0) {
    IErrorAbort("calculateNetErrorDeriv: no examples processed") ;
    return ;	/* only reached if IErrorAbort returns */
  }

  /* zero the net error and link derivative fields */
  net->error = 0.0 ;
  netForAllUnits(net, ~INPUT, zeroLinks, NULL) ;

  if (net->currentEpoch == 0) {
    /* first epoch: start from the initial variance and learning rate */
    Mepsilon(net)  = MinitialLearningRate(net) ;
    Mvariance(net) = MinitialVariance(net) ;
  } else {
    /* later epochs: decay the global variance and the learning rate */
    if (varDecay > 0.0 && Mvariance(net) > minVar) 
      Mvariance(net) = MAX(minVar, varDecay * Mvariance(net)) ;

    switch (MlearningRateDecayFunction(net)) {
    case 1:
      Mepsilon(net) *= MlearningRateDecay(net) ;
      break ;
    case 2:
      /* Computed from the initial rate, not the current one: dividing
       * the current rate every epoch compounds into a super-exponential
       * collapse (a product of (1 + k/tau) terms). */
      if (Mtau(net) > 0.0)
	Mepsilon(net) = MinitialLearningRate(net)
	  / (1.0 + (Real)McurrentEpoch(net) / Mtau(net)) ;
      break ;
    default:
      /* unknown schedule: leave the learning rate unchanged */
      break ;
    }
  }

  /* For each example	- do a forward pass updating the activities
   *			- update the derivatives
   */
  for (numExamples = 0 ; numExamples < batchSize ; ++numExamples) {
    MgetNext(exampleSet) ;
    MupdateNetActivities(net) ;
    MupdateNetGradients(net) ;
  }

  /* update the cost after everything else is done */
  MevaluateCostAndDerivs(net) ;
}
/**********************************************************************/


/***********************************************************************
 *	Name:		calculateNetError
 *	Description:	error calculation procedure for scl net
 *			It processes 'MbatchSize(net)' examples
 *	Parameters:	
 *		Net		net - the net to use
 *		ExampleSet	exampleSet - the examples to use
 *	Return Value:	
 *		NONE
 ***********************************************************************/
static void	calculateNetError(net, exampleSet)
  Net		net ;
  ExampleSet	exampleSet ;
{
  int		numExamples ;

  net->error = 0.0 ;
  if (net->currentEpoch == 0) {
     Mvariance(net) = MinitialVariance(net);
   }

  for (numExamples = 0 ; numExamples < MbatchSize(net) ; ++numExamples) {
    MgetNext(exampleSet) ;
    MupdateNetActivities(net) ;
  }
  if (numExamples <= 0)
    IErrorAbort("calculateNetError: no examples processed") ;

  /* update the cost after everything else is done */
  MevaluateCost(net) ;
}
/**********************************************************************/

/***********************************************************************
 *	Name:		calculateAveNetVariance
 *	Description:	for every non-input unit, averages over exampleSet
 *			the squared distance between the input and the unit's
 *			weight vector (stored in unit->extension->variance).
 *			The mean of those per-unit values is stored in
 *			MaveVariance(net).
 *			Forward passes are run for every example, so unit
 *			activities (and presumably net->error) are modified
 *			as a side effect.
 *	Parameters:	
 *		Net		net - the net to use
 *		ExampleSet	exampleSet - the examples to use
 *	Return Value:	
 *		NONE
 ***********************************************************************/
static void	calculateAveNetVariance(net, exampleSet)
  Net		net ;
  ExampleSet	exampleSet ;
{
  int		numExamples ;
  int		setSize = exampleSet->numExamples ;

  /* nothing to average over: abort before dividing by setSize */
  if (setSize <= 0) {
    IErrorAbort("calculateAveNetVariance: no examples processed") ;
    return ;	/* only reached if IErrorAbort returns */
  }

  /* calculate variance of each input pattern about each unit's mean
   * and store in unit->extension->variance
   */
  netForAllUnits(net, ~INPUT, zeroVariance, NULL) ;
  for (numExamples = 0 ; numExamples < setSize ; ++numExamples) {
    MgetNext(exampleSet) ;
    MupdateNetActivities(net) ;
    netForAllUnits(net, ~INPUT, incVariance, NULL) ;
  }
  netForAllUnits(net, ~INPUT, normalizeVariance, &setSize) ;

  /* calculate mean of units' variances,
   * and store in MaveVariance(net)
   */
  MaveVariance(net) = 0.0 ;
  MunitCounter(net) = 0.0 ;
  netForAllUnits(net, ~INPUT, incNetVariance, net) ;
  if (MunitCounter(net) > 0)	/* no non-input units: leave it at 0 */
    MaveVariance(net) /= (Real)MunitCounter(net) ;
}
/**********************************************************************/

/***********************************************************************
 *	Name:		zeroLinks
 *	Description:	netForAllUnits callback; sets deriv to 0 on every
 *			incoming link of the unit.
 *	Parameters:	
 *		Unit	unit - the unit whose links are cleared
 *		void	*data - unused
 *	Return Value:	NONE
 ***********************************************************************/
static Proc	zeroLinks(unit, data)
  Unit	unit ;
  void	*data ;
{
  Link	*next      = unit->incomingLink ;
  int	remaining  = unit->numIncoming ;

  while (remaining-- > 0)
    (*next++)->deriv = 0.0 ;
}

/**********************************************************************/
static Proc	zeroVariance(unit, data)
  Unit	unit ;
  void	*data ;
{
  unit->extension->variance = 0.0;
}

/**********************************************************************/

static Proc	incVariance(unit, data)
  Unit	unit ;
  void	*data ;
{
  setSoftCompTotalInput(unit);
  unit->extension->variance += unit->totalInput;
}
/**********************************************************************/

/***********************************************************************
 *	Name:		normalizeVariance
 *	Description:	netForAllUnits callback; divides the unit's variance
 *			accumulator by the number of examples, turning the sum
 *			into a mean.  A non-positive count leaves the
 *			accumulator untouched instead of dividing by zero.
 *	Parameters:	
 *		Unit	unit - the unit to update
 *		int	*data - points at the example count
 *	Return Value:	NONE
 ***********************************************************************/
static Proc	normalizeVariance(unit, data)
  Unit	unit ;
  int	*data ;
{
  int setSize = *data ;

  if (setSize > 0)
    unit->extension->variance /= (Real)setSize ;
}

/**********************************************************************/

static Proc	incNetVariance(unit, data)
  Unit	unit ;
  void	*data;
{
  Net   net = (Net)data ;
  ++(MunitCounter(net));
  MaveVariance(net) += unit->extension->variance;
}

/***********************************************************************
 *	Name:		square
 *	Description:	returns x*x, computed in double and converted to Real.
 *	Parameters:	
 *		double	x - the value to square
 *	Return Value:	
 *		Real	square - x*x
 ***********************************************************************/
static Real	square(x)
  double	x ;
{
  Real	result = (Real) (x * x) ;

  return result ;
}
/**********************************************************************/


/***********************************************************************
 *	Name:		command_suggestInitialVariance
 *	Description:	command handler.  Computes the average squared
 *			distance of the current net's training inputs from each
 *			unit's weight vector, then prints two suggestions:
 *			  initialVariance = initialVarianceProportion * average
 *			  minVariance     = finalVarianceProportion
 *			                    * initialVariance
 *			Refuses to run when there is no current net or it has
 *			no training example set, since both are dereferenced.
 *	Parameters:	
 *		int	tokc - number of command tokens
 *		char	*tokv[] - the command tokens
 *	Return Value:	
 *		int	1 on every path, as in the help and success cases
 ***********************************************************************/
int command_suggestInitialVariance(tokc, tokv)
int tokc;
char *tokv[];
{
  Real initialVariance, finalVariance;

  if (GiveHelp(tokc)) {
    IUsage(" ");
    ISynopsis("Suggest initial and minimum variance values");
    IHelp(IHelpArgs, NULL);
    return 1;
  }

  if (currentNet == NULL) {
    IErrorAbort("suggestInitialVariance: there is no current net") ;
    return 1 ;	/* only reached if IErrorAbort returns */
  }
  if (currentNet->trainingExampleSet == NULL) {
    IErrorAbort("suggestInitialVariance: the current net has no training set") ;
    return 1 ;	/* only reached if IErrorAbort returns */
  }

  calculateAveNetVariance(currentNet, currentNet->trainingExampleSet);
  initialVariance = initialVarianceProportion * MaveVariance(currentNet);
  finalVariance   = finalVarianceProportion   * initialVariance;

  fprintf(dout, "I suggest an initialVariance of %g which is %g times\n the average variance of the input vectors about each unit's weight vector. \nI also suggest a minVariance of %g\n which is %g times the suggested initialVariance.\n",
	  initialVariance, initialVarianceProportion,
	  finalVariance, finalVarianceProportion);

  return 1;
}

