
/**********************************************************************
 * $Id: rbp.c,v 1.3 92/11/30 12:01:14 drew Exp $
 **********************************************************************/

/**********************************************************************
 *   Copyright 1990,1991,1992,1993 by The University of Toronto,
 *		      Toronto, Ontario, Canada.
 * 
 *			 All Rights Reserved
 * 
 * Permission to use, copy, modify, distribute, and sell this software
 * and its  documentation for  any purpose is  hereby granted  without
 * fee, provided that the above copyright notice appears in all copies
 * and  that both the  copyright notice  and   this  permission notice
 * appear in   supporting documentation,  and  that the  name  of  The
 * University  of Toronto  not  be  used in  advertising or  publicity
 * pertaining   to  distribution   of  the  software without specific,
 * written prior  permission.   The  University of   Toronto makes  no
 * representations  about  the  suitability of  this software  for any
 * purpose.  It  is    provided   "as is"  without express or  implied
 * warranty.
 *
 * THE UNIVERSITY OF TORONTO DISCLAIMS  ALL WARRANTIES WITH REGARD  TO
 * THIS SOFTWARE,  INCLUDING ALL IMPLIED WARRANTIES OF MERCHANTABILITY
 * AND FITNESS, IN NO EVENT SHALL THE UNIVERSITY OF TORONTO  BE LIABLE
 * FOR ANY  SPECIAL, INDIRECT OR CONSEQUENTIAL  DAMAGES OR ANY DAMAGES
 * WHATSOEVER RESULTING FROM LOSS OF USE, DATA OR  PROFITS, WHETHER IN
 * AN ACTION OF CONTRACT, NEGLIGENCE OR OTHER TORTIOUS ACTION, ARISING
 * OUT  OF  OR  IN  CONNECTION WITH  THE  USE  OR PERFORMANCE  OF THIS
 * SOFTWARE.
 *
 **********************************************************************/

#include <stdio.h>
#include <values.h>
#include <math.h>
#include <nan.h>

#include <xerion/useful.h>
#include <xerion/simulator.h>

#include "rbp.h"
#include "help.h"

/*
 * Trace fired by netForward after each time slice of a forward pass, but
 * only while Mrunning(net) is not TRUE (the batch error/gradient
 * procedures below set it TRUE while they run).
 */
struct TRACE	forwardEachTime ;

/* Per-unit arithmetic helpers for a single time slice. */
static Real	unitError      ARGS((const Unit	unit, int	time)) ;
static Real	dotProduct     ARGS((Unit	unit, int	time)) ;
static void	backDotProduct ARGS((Unit	unit, int	time)) ;

/* Unit update procedures installed on groups by initGroup. */
static Proc	forward      ARGS((Unit	unit)) ;
static Proc	backward     ARGS((Unit	unit)) ;

/* Net-level passes over all time slices, and their per-group callbacks. */
static void	netForward  ARGS((Net	 net)) ;
static void	netBackward ARGS((Net	 net)) ;
static void	groupActivityUpdate ARGS((Group group, void	*data)) ;
static void	groupGradientUpdate ARGS((Group group, void	*data)) ;

/* Batch error / gradient procedures installed on the net by initNet. */
static void	calculateNetErrorDeriv ARGS((Net net, ExampleSet exampleSet)) ;
static void	calculateNetError      ARGS((Net net, ExampleSet exampleSet)) ;

/* Creation / destruction hooks registered in main. */
static void	initNet   ARGS((Net   net)) ;
static void	deinitNet ARGS((Net   net)) ;
static void	initUnit   ARGS((Unit   unit)) ;
static void	deinitUnit ARGS((Unit   unit)) ;
static void	initGroup ARGS((Group group)) ;

/* Small utilities; zeroUnit, zeroLinks and stepUnit are netForAllUnits
 * callbacks (their data argument is unused). */
static Real	square     ARGS((double  x)) ;
static void	zeroUnit   ARGS((Unit	unit, void	*data)) ;
static void	zeroLinks  ARGS((Unit	unit, void	*data)) ;

static void	stepUnit   ARGS((Unit	unit, void	*data)) ;

/***********************************************************************
 *	Name:		main
 *	Description:	entry point of the recurrent backprop simulator.
 *			Installs this module's net/unit/group creation and
 *			destruction hooks, registers the "TANH" group
 *			class, initializes the simulator and then runs the
 *			command loop on stdin/stdout.
 *	Parameters:
 *		int	argc	- number of command line arguments
 *		char	**argv  - the command line argument strings
 *	Return Value:
 *		int	main	- always 0
 ***********************************************************************/
int main (int argc, char **argv)
{
  /* hooks invoked whenever the simulator creates / destroys objects */
  setCreateNetHook   (initNet) ;
  setDestroyNetHook  (deinitNet) ;
  setCreateUnitHook  (initUnit) ;
  setDestroyUnitHook (deinitUnit) ;
  setCreateGroupHook (initGroup) ;

  /* lets groups be given the TANH class (tanh transfer function) */
  registerClassMask(TANH, GROUP_CLASS, "TANH") ;

  /* standard simulator setup; argc is passed by address, presumably so
   * the simulator can consume its own options */
  IStandardInit(&argc, argv) ;

  /* read and execute commands (and handle graphics) until done */
  ICommandLoop(stdin, stdout, NULL) ;

  return 0 ;
}
/**********************************************************************/


/***********************************************************************
 *	Name:		initNet
 *	Description:	creation hook for nets.  Installs the recurrent
 *			error / gradient / activity procedures (only
 *			recurrent nets are supported), then allocates the
 *			net extension record and sets its defaults.
 *	Parameters:
 *		Net	net - the net to act on
 *	Return Value:	NONE
 ***********************************************************************/
static void	initNet (net)
  Net	net ;
{
  if (net->type & RECURRENT) {
    net->calculateErrorDerivProc = calculateNetErrorDeriv ;
    net->calculateErrorProc      = calculateNetError ;
    net->activityUpdateProc      = netForward ;
    net->gradientUpdateProc      = netBackward ;
  } else {
    IErrorAbort("Net must be recurrent") ;
  }

  net->extension = (NetExtension)calloc(1, sizeof(NetExtensionRec)) ;
  if (net->extension == NULL) {
    /* MzeroErrorRadius / Mrunning presumably live in the extension
     * record; never write through a NULL one. */
    IErrorAbort("initNet: cannot allocate net extension") ;
    return ;
  }

  /* outputs within this distance of their target count as zero error
   * (see unitError) */
  MzeroErrorRadius(net) = 0.2 ;
  /* not inside a batch error / gradient calculation */
  Mrunning(net) = FALSE ;
}
/**********************************************************************/
/***********************************************************************
 *	Name:		deinitNet
 *	Description:	destruction hook for nets; releases the extension
 *			record allocated by initNet and clears the pointer
 *			so it cannot be used or freed a second time.
 *	Parameters:
 *		Net	net - the net being destroyed
 *	Return Value:	NONE
 ***********************************************************************/
static void	deinitNet (net)
  Net	net ;
{
  if (net->extension != NULL) {
    free(net->extension) ;
    net->extension = NULL ;
  }
}
/**********************************************************************/


/***********************************************************************
 *	Name:		initUnit
 *	Description:	creation hook for units; allocates a zeroed
 *			extension record for the unit.
 *	Parameters:
 *		Unit	unit - the unit to act on
 *	Return Value:	NONE
 ***********************************************************************/
static void	initUnit (unit)
  Unit	unit ;
{
  unit->extension = (UnitExtension)calloc(1, sizeof(UnitExtensionRec)) ;

  /* Per-unit accessors such as MnextOutputDeriv presumably go through
   * the extension; report a failed allocation now instead of leaving a
   * NULL for them to dereference later. */
  if (unit->extension == NULL)
    IErrorAbort("initUnit: cannot allocate unit extension") ;
}
/**********************************************************************/
/***********************************************************************
 *	Name:		deinitUnit
 *	Description:	destruction hook for units; releases the extension
 *			record allocated by initUnit and clears the pointer
 *			so it cannot be used or freed a second time.
 *	Parameters:
 *		Unit	unit - the unit being destroyed
 *	Return Value:	NONE
 ***********************************************************************/
static void	deinitUnit (unit)
  Unit	unit ;
{
  if (unit->extension != NULL) {
    free(unit->extension) ;
    unit->extension = NULL ;
  }
}
/**********************************************************************/


/***********************************************************************
 *	Name:		initGroup
 *	Description:	creation hook for groups; makes the group's units
 *			use this module's per-unit procedures: `forward'
 *			for activity updates and `backward' for gradient
 *			updates.
 *	Parameters:
 *		Group	group - the newly created group
 *	Return Value:	NONE
 ***********************************************************************/
static void	initGroup (group)
  Group	group ;
{
  /* backward pass: per-unit gradient computation */
  group->unitGradientUpdateProc = backward ;

  /* forward pass: per-unit activity computation */
  group->unitActivityUpdateProc = forward ;
}
/**********************************************************************/


/***********************************************************************
 *	Name:		unitError
 *	Description:	calculates the error value of a unit assuming
 *			a zeroErrorRadius
 *	Parameters:	
 *		const Unit	unit - the unit to calclate the error of
 *		int		time - the time slice of the unit.
 *	Return Value:	
 *		Real		unitError - the error of the unit
 ***********************************************************************/
static Real	unitError(unit, time)
  const Unit	unit ;
  int		time ;
{
  Real		error ;
  Real		target = unit->targetHistory[time] ;
  Real		output = unit->outputHistory[time] ;

  if (MzeroErrorRadius(unit->net) <= 0.0) {
    error = (output - target) ;
  } else {
    Real	upper = target + MzeroErrorRadius(unit->net) ;
    Real	lower = target - MzeroErrorRadius(unit->net) ;
    upper = MIN(1.0, upper) ;
    lower = MAX(0.0, lower) ;
    if (output > upper)
      error = (output - upper) ;
    else if (output < lower)
      error = (output - lower) ;
    else
      error = 0.0 ;
  }

  return error ;
}
/**********************************************************************/


/***********************************************************************
 *	Name:		dotProduct
 *	Description:	weighted sum over the unit's incoming links of
 *			weight * pre-unit `output'.  For non-bias pre-units
 *			that field is loaded by stepUnit with the previous
 *			time slice's output.  The sum is also stored in the
 *			unit's totalInput field.
 *	Parameters:
 *		Unit	unit - the unit whose net input is computed
 *		int	time - the time slice (not used by the sum)
 *	Return Value:
 *		Real	dotProduct - the weighted sum
 ***********************************************************************/
static Real	dotProduct(unit, time)
  Unit	unit ;
  int	time ;
{
  Link	*links = unit->incomingLink ;
  int	count  = unit->numIncoming ;
  Real	sum    = 0.0 ;
  int	i      = 0 ;

  while (i < count) {
    Link	in = links[i] ;

    sum += in->weight * in->preUnit->output ;
    ++i ;
  }

  unit->totalInput = sum ;
  return sum ;
}
/**********************************************************************/


/***********************************************************************
 *	Name:		backDotProduct
 *	Description:	back-propagates the unit's inputDeriv through its
 *			incoming links.  For each link, the pre-unit's
 *			next-output-deriv accumulator gains
 *			inputDeriv * weight, and the link's deriv gains
 *			inputDeriv * (pre-unit output at time - 1).
 *	Parameters:
 *		Unit	unit - the unit to back propagate from
 *		int	time - the time slice of the unit; must be >= 1
 *			       because outputHistory[time-1] is read
 *	Return Value:	NONE
 ***********************************************************************/
static void	backDotProduct(unit, time)
  Unit	unit ;
  int	time ;
{
  Real	delta  = unit->inputDeriv ;
  Link	*links = unit->incomingLink ;
  int	count  = unit->numIncoming ;
  int	prev   = time - 1 ;
  int	i ;

  for (i = 0 ; i < count ; i++) {
    Link	in   = links[i] ;
    Unit	from = in->preUnit ;

    MnextOutputDeriv(from) += delta*in->weight ;
    in->deriv              += delta*from->outputHistory[prev] ;
  }
}
/**********************************************************************/


/***********************************************************************
 *	Name:		forward / backward
 *	Description:	update procedures for units
 *	Parameters:	
 *		Unit	unit - the unit of concern
 *	Return Value:	
 *		NONE
 ***********************************************************************/
static Proc	forward(unit)
  Unit	unit ;
{
  int	time = unit->net->currentTime ;

  /* if not clamped, sum the inputs and pass the result through a sigmoid */
  if (IS_CLAMPED(unit, time)) {
    unit->outputHistory[time] = unit->clampingHistory[time] ;
  } else {
    dotProduct(unit, time) ;
    if (unit->group->type & TANH) 
      unit->outputHistory[time] = tanh(unit->totalInput) ;
    else
      unit->outputHistory[time] = sigmoid(unit->totalInput) ;
  }

  /* update the network error for output units */
  if (HAS_TARGET(unit, time))
    unit->net->error += square(unitError(unit, time)) ;
}
/**********************************************************************/
/***********************************************************************
 *	Name:		backward
 *	Description:	gradient update for one unit at the net's current
 *			time slice.  inputDeriv is the transfer derivative
 *			times the sum of:
 *			  - 2*error, if the unit has a target, and
 *			  - the unit's outputDeriv (back-propagated from
 *			    the following time slice).
 *			The result is then pushed back through the unit's
 *			incoming links.
 *	Parameters:
 *		Unit	unit - the unit to update
 *	Return Value:
 *		NONE
 ***********************************************************************/
static Proc	backward(unit)
  Unit	unit ;
{
  int	now = unit->net->currentTime ;
  Real	out = unit->outputHistory[now] ;
  Real	slope ;

  /* derivative of the transfer function, expressed in its output */
  if (!(unit->group->type & TANH))
    slope = sigmoidDeriv(out) ;
  else
    slope = tanhDeriv(out) ;

  unit->inputDeriv = (HAS_TARGET(unit, now))
		     ? 2.0*slope*unitError(unit, now)
		     : 0.0 ;
  unit->inputDeriv += slope*(unit->outputDeriv) ;

  backDotProduct(unit, now) ;
}
/**********************************************************************/


/***********************************************************************
 *	Name:		calculateNetErrorDeriv
 *	Description:	gradient calculation procedure for backprop net
 *			It processes 'MbatchSize(net)' examples
 *	Parameters:	
 *		Net		net - the net to use
 *		ExampleSet	exampleSet - the examples to use
 *	Return Value:	
 *		NONE
 ***********************************************************************/
static void	calculateNetErrorDeriv(net, exampleSet)
  Net		net ;
  ExampleSet	exampleSet ;
{
  int		numExamples ;

  Mrunning(net) = TRUE ;

  /* zero the net error and all derivative fields in the links */
  net->error = 0.0 ;
  netForAllUnits(net, ALL, zeroLinks, NULL) ;

  /* For each example	- zero the derivative fields in the units
   *			- do a forward pass updating the activities
   *			- do a backward pass updating the derivatives
   */
  for (numExamples = 0 ; numExamples < MbatchSize(net) ; ++numExamples) {
    MgetNext(exampleSet) ;

    netForAllUnits(net, ~BIAS, zeroUnit, NULL) ;
    MupdateNetActivities(net) ;
    MupdateNetGradients(net) ;
  }

  if (numExamples <= 0)
    IErrorAbort("calculateNetErrorDeriv: no examples processed") ;

  /* update the cost after everything else is done */
  MevaluateCostAndDerivs(net) ;
}
/**********************************************************************/


/***********************************************************************
 *	Name:		calculateNetError
 *	Description:	error calculation procedure for the recurrent
 *			backprop net.  Runs forward passes over
 *			'MbatchSize(net)' examples from exampleSet,
 *			accumulating net->error, then evaluates the cost.
 *			Mrunning(net) is TRUE while this runs, which
 *			suppresses netForward's display traces, and is
 *			reset to FALSE before returning.
 *	Parameters:
 *		Net		net - the net to use
 *		ExampleSet	exampleSet - the examples to use
 *	Return Value:
 *		NONE
 ***********************************************************************/
static void	calculateNetError(net, exampleSet)
  Net		net ;
  ExampleSet	exampleSet ;
{
  int		numExamples ;

  Mrunning(net) = TRUE ;

  net->error = 0.0 ;
  for (numExamples = 0 ; numExamples < MbatchSize(net) ; ++numExamples) {
    MgetNext(exampleSet) ;
    MupdateNetActivities(net) ;
  }

  if (numExamples <= 0) {
    /* don't leave the net flagged as running if the abort unwinds */
    Mrunning(net) = FALSE ;
    IErrorAbort("calculateNetError: no examples processed") ;
  }

  /* update the cost after everything else is done */
  MevaluateCost(net) ;

  Mrunning(net) = FALSE ;
}
/**********************************************************************/


/***********************************************************************
 *	Name:		groupActivityUpdate
 *	Description:	netForAllGroups callback used by netForward.
 *			Applies MupdateGroupActivities to the group, but
 *			only when the group has a non-NULL
 *			groupActivityUpdateProc.
 *	Parameters:
 *		Group	group - the group to update
 *		void	*data - UNUSED
 *	Return Value:
 *		NONE
 ***********************************************************************/
static void	groupActivityUpdate (group, data)
  Group		group ;
  void		*data ;
{
  if (group->groupActivityUpdateProc == NULL)
    return ;

  MupdateGroupActivities(group) ;
}
/**********************************************************************/
/***********************************************************************
 *	Name:		netForward
 *	Description:	forward pass through all time slices.  For each
 *			slice, makes it current, steps every non-bias unit
 *			(see stepUnit) and updates the activities of every
 *			non-bias group.  When the net is not running a
 *			batch calculation (Mrunning(net) != TRUE), it also
 *			fires the forwardEachTime trace after each slice.
 *			It then busy-waits 1000*MdelayCount(net)
 *			iterations, presumably to slow the display down.
 *	Parameters:
 *		Net	net - the net to act on
 *	Return Value:
 *		NONE
 ***********************************************************************/
static void	netForward (net)
  Net		net ;
{
  int	time ;
  int	maxTime = net->timeSlices ;

  for (time = 0 ; time < maxTime ; ++time) {
    net->currentTime = time ;
    netForAllUnits  (net, ~BIAS, stepUnit,            NULL) ;
    netForAllGroups (net, ~BIAS, groupActivityUpdate, NULL) ;

    /* Do the display traces if we're not testing or training */
    if (Mrunning(net) != TRUE) {
      int		delay = 1000*MdelayCount(net) ;
      /* volatile: an empty loop has no observable effect, so an
       * optimizing compiler may delete it and the delay disappears */
      volatile int	idx ;

      IDoTrace(&forwardEachTime) ;
      for (idx = 0 ; idx < delay ; idx++)
	;
    }
  }
}
/**********************************************************************/


/***********************************************************************
 *	Name:		groupGradientUpdate
 *	Description:	netForAllGroupsBack callback used by netBackward.
 *			Applies MupdateGroupGradients to the group, but
 *			only when the group has a non-NULL
 *			groupGradientUpdateProc.
 *	Parameters:
 *		Group	group - the group to update
 *		void	*data - UNUSED
 *	Return Value:
 *		NONE
 ***********************************************************************/
static void	groupGradientUpdate (group, data)
  Group		group ;
  void		*data ;
{
  if (group->groupGradientUpdateProc == NULL)
    return ;

  MupdateGroupGradients(group) ;
}
/**********************************************************************/
/***********************************************************************
 *	Name:		netBackward
 *	Description:	backward pass through time.  Walks the slices from
 *			the last one down to slice 1; slice 0 is not
 *			back-propagated.  At each slice it steps every
 *			non-bias unit and runs groupGradientUpdate over all
 *			groups via netForAllGroupsBack.  Afterwards the
 *			units are stepped once more as if at slice
 *			`timeSlices', so their output fields hold the last
 *			slice's activities.  currentTime is left at the
 *			last slice.
 *	Parameters:
 *		Net	net - the net to act on
 *	Return Value:
 *		NONE
 ***********************************************************************/
static void	netBackward (net)
  Net		net ;
{
  int	lastSlice = net->timeSlices - 1 ;
  int	t ;

  for (t = lastSlice ; t >= 1 ; t--) {
    net->currentTime = t ;
    netForAllUnits      (net, ~BIAS, stepUnit,            NULL) ;
    netForAllGroupsBack (net, ALL,   groupGradientUpdate, NULL) ;
  }

  /* step "one past the end" so each unit's output is the final activity */
  net->currentTime = lastSlice + 1 ;
  netForAllUnits (net, ~BIAS, stepUnit, NULL) ;
  net->currentTime = lastSlice ;
}
/**********************************************************************/


/*********************************************************************
 *	Name:		zeroLinks
 *	Description:	sets the deriv field of every incoming link of a
 *			unit to 0.0
 *	Parameters:
 *	  Unit		unit - the unit whose incoming links are cleared
 *	  void		*data - UNUSED
 *	Return Value:
 *	  static void	zeroLinks - NONE
 *********************************************************************/
static void	zeroLinks(unit, data)
  Unit		unit ;
  void		*data ;
{
  Link	*next = unit->incomingLink ;
  int	left  = unit->numIncoming ;

  while (left-- > 0)
    (*next++)->deriv = 0.0 ;
}
/********************************************************************/


/*********************************************************************
 *	Name:		zeroUnit
 *	Description:	resets a unit before an example is processed.  It
 *			clears the output, inputDeriv and outputDeriv
 *			fields and the next-time-slice output derivative
 *			accumulator (MnextOutputDeriv).
 *	Parameters:
 *	  Unit		unit  - the unit to reset
 *	  void		*data - UNUSED
 *	Return Value:
 *	  static void	zeroUnit - NONE
 *********************************************************************/
static void	zeroUnit(unit, data)
  Unit		unit ;
  void		*data ;
{
  MnextOutputDeriv(unit) = 0.0 ;
  unit->outputDeriv      = 0.0 ;
  unit->inputDeriv       = 0.0 ;
  unit->output           = 0.0 ;
}
/********************************************************************/
/*********************************************************************
 *	Name:		stepUnit
 *	Description:	advances a unit to the net's current time slice.
 *			`output' receives the previous slice's output
 *			(0.0 at slice 0).  The derivative accumulated in
 *			MnextOutputDeriv becomes outputDeriv, and the
 *			accumulator is then cleared.
 *	Parameters:
 *	  Unit		unit  - the unit to step
 *	  void		*data - UNUSED
 *	Return Value:
 *	  static void	stepUnit - NONE
 *********************************************************************/
static void	stepUnit(unit, data)
  Unit		unit ;
  void		*data ;
{
  int	now = unit->net->currentTime ;

  unit->output = (now > 0) ? unit->outputHistory[now - 1] : 0.0 ;

  /* hand the accumulated derivative over, then start a fresh sum */
  unit->outputDeriv      = MnextOutputDeriv(unit) ;
  MnextOutputDeriv(unit) = 0.0 ;
}
/********************************************************************/


/*********************************************************************
 *	Name:		square
 *	Description:	returns x*x converted to Real
 *	Parameters:
 *	  double	x - the value to square
 *	Return Value:
 *	  static Real	square - x^2
 *********************************************************************/
static Real	square(x)
  double	x ;
{
  Real	result = x * x ;

  return result ;
}
/********************************************************************/
