home *** CD-ROM | disk | FTP | other *** search
/ Collection of Education / collectionofeducationcarat1997.iso / COMPUSCI / NNUTL101.ZIP / NNWHERE / NNBKPROP.C next >
C/C++ Source or Header  |  1993-07-12  |  6KB  |  168 lines

  1. /*-----------------------------------------------------------------------*
  2.  * Greg Stevens                                                   6/24/93*
  3.  *                              NNBKPROP.C                               *
  4.  *                                             [file 6 in a series of 6] *
  5.  *                                                                       *
  6.  * This file contains the functions for calculating the error and weight *
  7.  * changes for the backpropagation algorithm for the network.  Defined   *
  8.  * in this file are EPSILON, the weight change increment/coefficient,    *
  9.  * and function that updates the weights [UpDateWeightandThresh()].  It  *
  10.  * also contains code for a function called GetDerivs(), but this is to  *
  11.  * be used in the function UpDateWeightandThresh(), not by a main        *
  12.  * program.  It also contains a function InitOutPatterns, which is       *
  13.  * similar to InitPatterns for the input patterns, but takes from a      *
  14.  * different file that would contain the corresponding desired output    *
  15.  * patterns for the set of input patterns.                               *
  16.  *                                                                       *
  17.  *-----------------------------------------------------------------------*/
  18. #include "nnloadin.c"
  19.  
  20. #define EPSILON 0.25        /* constant incrementation for weight change */
  21.  
  22. /* type for holding error values for weight connections */
  23.  
  24. typedef struct 
  25.            {
  26.               float e[ NUMLAYERS ][ MAXNODES ][ MAXNODES ];
  27.            } wERRORtype;
  28.  
  29. /* type for holding error values for threshhold weights */
  30.  
  31. typedef struct
  32.            {
  33.               float e[ NUMLAYERS ][ MAXNODES ];
  34.            } tERRORtype;
  35.  
  36. /* Function Prototypes */
  37. PATTERNtype InitOutPatterns( void );
  38. tERRORtype GetDerivs( NNETtype n, PATTERNtype GoalOut, int Pattern );
  39. NNETtype UpDateWeightandThresh(NNETtype nn, PATTERNtype goal, int p );
  40.  
  41. /* Function Definitions */
  42. PATTERNtype InitOutPatterns( void )
  43. {
  44.    FILE *InFile;                                        /*file w/ pattern*/
  45.                                                         /*data           */
  46.    PATTERNtype patns;                                   /*stores patterns*/
  47.    float val;                                           /*pattern value  */
  48.    int P,U;                                             /*loop variables */
  49.  
  50.    InFile = fopen( "nnoutput.dat", "rt" );              /*open: read text*/
  51.  
  52.    if ( InFile==NULL )                                  /* if no file... */
  53.      {
  54.        printf( "File nnoutput.dat does not exist!\n" ); /*error message  */
  55.        return( patns );                                 /*leaves function*/
  56.      }
  57.  
  58.    for (P=0; (P<NUM_PATTERNS); ++P)              /* for each pattern.... */
  59.      for (U=0; (U<OUTPUT_LAYER_SIZE); ++U)       /*  for each unit in it:*/
  60.        {
  61.          fscanf( InFile, "%f", &val );
  62.          patns.p[P][U] = val;
  63.        }
  64.  
  65.    fclose( InFile );
  66.  
  67.    return( patns );
  68. }
  69.  
  70.  
  71. tERRORtype GetDerivs(NNETtype n, PATTERNtype GoalOut, int pattern)
  72. {
  73.    int layer;                    /* looping variables */
  74.    int node;
  75.    int tonode;
  76.  
  77.    tERRORtype Deriv1;            /* for holding dE/dy */
  78.    tERRORtype Deriv2;            /* for holding dE/ds */
  79.  
  80.    layer = NUMLAYERS - 1;        /* set layer to output layer */
  81.  
  82.    /* calculate dE/dy for output nodes */
  83.    for (node=0; (node<NUMNODES[layer]); ++node)  /* for each output node */
  84.     {
  85.      Deriv1.e[layer][node]=GoalOut.p[pattern][node]-n.unit[layer][node].state;
  86.     }
  87.  
  88.    /* calculate dE/ds for output nodes */
  89.    for (node=0; (node<NUMNODES[layer]); ++node)
  90.      {
  91.        if (n.unit[layer][node].actfn==0)       /* if it's a linear node...  */
  92.          Deriv2.e[layer][node] = Deriv1.e[layer][node];
  93.        else if (n.unit[layer][node].actfn==1)  /* if it's a logistic node...*/
  94.          Deriv2.e[layer][node] = Deriv1.e[layer][node] *
  95.                                  n.unit[layer][node].state *
  96.                                  (1.0 - n.unit[layer][node].state);
  97.      }
  98.  
  99.    /* calculate  dE/dy and dE/ds for hidden layers (backwards from output,*/
  100.    /* not including input layer).                                         */
  101.    for (layer=NUMLAYERS-2; (layer>0); --layer )
  102.      {
  103.        /* calculate dE/dy */
  104.        for (node=0; (node<NUMNODES[layer]); ++node)
  105.          {
  106.            Deriv1.e[layer][node] = 0;
  107.            for (tonode=0; (tonode<NUMNODES[layer+1]); ++tonode)
  108.              {
  109.                Deriv1.e[layer][node] += Deriv2.e[layer+1][tonode] *
  110.                                         n.unit[layer+1][tonode].weights[node];
  111.              }
  112.          }
  113.  
  114.        /* calculate dE/ds */
  115.        for (node=0; (node<NUMNODES[layer]); ++node)
  116.          {
  117.            if (n.unit[layer][node].actfn==0)
  118.              Deriv2.e[layer][node] = Deriv1.e[layer][node];
  119.            else if (n.unit[layer][node].actfn==1)
  120.              Deriv2.e[layer][node] = Deriv1.e[layer][node] *
  121.                                      n.unit[layer][node].state *
  122.                                      (1.0 - n.unit[layer][node].state);
  123.          }
  124.      }
  125.  
  126.    return( Deriv2 );   /* return dE/ds for each layer */
  127. }
  128.  
  129. NNETtype UpDateWeightandThresh(NNETtype nn, PATTERNtype goal, int p )
  130. {
  131.    NNETtype newnet;
  132.    wERRORtype WeightError;
  133.    tERRORtype ThreshError;
  134.    tERRORtype Derivs;
  135.    int layer, unit, inunit;
  136.  
  137.    /* find WeightError and ThreshError */
  138.  
  139.    Derivs = GetDerivs( nn, goal, p );
  140.    for (layer=1; (layer<NUMLAYERS); ++layer )
  141.      {
  142.        for (unit=0; (unit<NUMNODES[ layer ]); ++unit )
  143.          {
  144.            ThreshError.e[ layer ][ unit ] = Derivs.e[ layer ][ unit ];
  145.  
  146.            for (inunit=0; (inunit<NUMNODES[ layer-1 ]); ++inunit)
  147.              {
  148.                 WeightError.e[layer][unit][inunit] = Derivs.e[layer][unit] *
  149.                                               nn.unit[layer-1][inunit].state;
  150.              }
  151.          }
  152.      }
  153.  
  154.    /* Change Weights */
  155.    newnet = nn;
  156.    for (layer=1; (layer<NUMLAYERS); ++layer)
  157.      for (unit=0; (unit<NUMNODES[layer]); ++unit)
  158.       {
  159.         newnet.unit[layer][unit].thresh += EPSILON*ThreshError.e[layer][unit];
  160.  
  161.         for (inunit=0; (inunit<NUMNODES[layer-1]); ++inunit)
  162.           newnet.unit[layer][unit].weights[inunit] += EPSILON*
  163.                                            WeightError.e[layer][unit][inunit];
  164.       }
  165.  
  166.    return( newnet );
  167. }
  168.