home *** CD-ROM | disk | FTP | other *** search
/ Usenet 1994 January / usenetsourcesnewsgroupsinfomagicjanuary1994.iso / sources / misc / volume14 / back-prop / part04 / misc.c < prev    next >
Encoding:
C/C++ Source or Header  |  1990-09-15  |  10.7 KB  |  457 lines

  1. /* **************************************************** */
  2. /* file misc.c:  contains pattern manipulation routines */
  3. /*               and miscellaneous other functions.     */
  4. /*                                                      */
  5. /* Copyright (c) 1990 by Donald R. Tveter               */
  6. /*                                                      */
  7. /* **************************************************** */
  8.  
  9. #include <stdio.h>
  10. #ifdef INTEGER
  11. #include "ibp.h"
  12. #else
  13. #include "rbp.h"
  14. #endif
  15.  
  16. extern short backoutput();
  17. extern void backinner();
  18. extern short cbackoutput();
  19. extern void cbackinner();
  20. extern WTTYPE rdr();
  21. extern WTTYPE readchar();
  22. extern void saveweights();
  23. extern WTTYPE scale();
  24. extern double unscaleint();
  25. extern void updatej();
  26. extern void updateo();
  27.  
  28. extern char backprop;
  29. extern FILE *data;
  30. extern char datafilename[50];
  31. extern UNIT *hlayer;
  32. extern UNIT *ilayer;
  33. extern char informat;
  34. extern UNIT *jlayer;
  35. extern UNIT *klayer;
  36. extern LAYER *last;
  37. extern int lastprint;
  38. extern int npats;
  39. extern int prevnpats;
  40. extern int readerror;
  41. extern int saverate;
  42. extern int skiprate;
  43. extern LAYER *start;
  44. extern char summary;
  45. extern WTTYPE toler;
  46. #ifdef INTEGER
  47. extern int totaldiff;
  48. #else
  49. extern double totaldiff;
  50. #endif
  51. extern int totaliter;
  52. extern int unlearnedpats;
  53. extern char update;
  54. extern WTTYPE wtlimit;
  55. extern char wtlimithit;
  56. extern int wttotal;
  57.  
  58. void nullpatterns()  /* dispose of any patterns before reading more */
  59. {
  60.   PATLIST *pl, *nextpl;
  61.   PATNODE *pn, *nextpn;
  62.   if (start->patstart != NULL)
  63.      {
  64.        pl = start->patstart;
  65.        nextpl = pl->next;
  66.        while (pl != NULL)
  67.           {
  68.             pn = pl->pats;
  69.             nextpn = pn->next;
  70.             while (pn != NULL)
  71.                {
  72.                  free(pn);
  73.                  pn = nextpn;
  74.                  nextpn = pn->next;
  75.                };
  76.             free(pl);
  77.             pl = nextpl;
  78.             nextpl = pl->next;
  79.           };
  80.        pl = last->patstart;
  81.        nextpl = pl->next;
  82.        while (pl != NULL)
  83.           {
  84.             pn = pl->pats;
  85.             nextpn = pn->next;
  86.             while (pn != NULL)
  87.                {
  88.                  free(pn);
  89.                  pn = nextpn;
  90.                  nextpn = pn->next;
  91.                };
  92.             free(pl);
  93.             pl = nextpl;
  94.             nextpl = pl->next;
  95.           };
  96.      };
  97.   start->patstart = NULL;
  98.   last->patstart = NULL;
  99.   npats = 0;
  100.   prevnpats = 0;
  101. }
  102.  
  103. void resetpats()
  104. {
  105.  start->currentpat = NULL;
  106.  last->currentpat = NULL;
  107. }
  108.  
  109. void findendofpats(layer)  /* purpose is to set all layer->currentpat */
  110. LAYER *layer;              /* fields to end of pattern list so more   */
  111.                            /* patterns can be added at the end.       */
  112. {
  113.  PATLIST *pl;
  114.  
  115.  pl = (PATLIST *) layer->patstart;
  116.  while (pl->next != NULL) pl = pl->next;
  117.  layer->currentpat = pl;
  118. }
  119.  
  120. int copyhidden(input,hidden,l)
  121. UNIT *input, **hidden;
  122. int l;
  123. {
  124.   if (hidden == NULL)
  125.      {
  126.        printf("ran out of hidden units in layer %d\n",l);
  127.        return(1);
  128.      }
  129.   input->oj = (*hidden)->oj;
  130.   *hidden = (*hidden)->next;
  131.   return(0);
  132. }
  133.  
  134. void nextpat()
  135. {
  136.   if (start->currentpat == NULL)
  137.      {
  138.        start->currentpat = start->patstart;
  139.        last->currentpat = last->patstart;
  140.      }
  141.   else
  142.      {
  143.        start->currentpat = (start->currentpat)->next;
  144.        last->currentpat = (last->currentpat)->next;
  145.      };
  146. }
  147.  
  148. void setonepat()       /* sets up patterns on input units */
  149. {
  150.   register PATNODE *p;
  151.   register UNIT *u;
  152.   register LAYER *innerlayers;
  153.   UNIT *hunit, *iunit, *junit, *kunit;
  154.   PATLIST *pl;
  155.   
  156.   hunit = hlayer;
  157.   iunit = ilayer;
  158.   junit = jlayer;
  159.   kunit = klayer;
  160.   pl = start->currentpat;
  161.   p = (PATNODE *) pl->pats;
  162.   u = (UNIT *) start->units;
  163.   while (p != NULL)
  164.      {
  165.        if (p->val > KCODE) u->oj = p->val;
  166.        else if (p->val == HCODE)
  167.                {if (copyhidden(u,&hunit,2) == 1) return;}
  168.        else if (p->val == ICODE)
  169.                {if (copyhidden(u,&iunit,3) == 1) return;}
  170.        else if (p->val == JCODE)
  171.                {if (copyhidden(u,&junit,4) == 1) return;}
  172.        else if (copyhidden(u,&kunit,5) == 1) return;
  173.        u = u->next;
  174.        p = p->next;
  175.      };
  176.  
  177.   innerlayers = start->next;
  178.   while (innerlayers->next != NULL)
  179.      {  /* set errors on the inner layer units to 0 */
  180.        u = (UNIT *) innerlayers->units;
  181.        while (u != NULL)
  182.           {
  183.             u->error = 0;
  184.             u = u->next;
  185.           };
  186.        innerlayers = innerlayers->next;
  187.      };
  188. }
  189.  
  190. void limitwts()
  191. {
  192.   register LAYER *layer;
  193.   register UNIT *u;
  194.   register WTNODE *w;
  195.  
  196.   layer = start->next;
  197.   while (layer != NULL)
  198.    {
  199.     u = (UNIT *) layer->units;
  200.     while (u != NULL)
  201.      {
  202.       w = (WTNODE *) u->wtlist;
  203.       while (w != NULL)
  204.        {
  205. #ifdef SYMMETRIC
  206.         if (*(w->weight) > wtlimit)
  207.            {
  208.              *(w->weight) = wtlimit;
  209.              wtlimithit = 1;
  210.            }
  211.         else if (*(w->weight) < -wtlimit)
  212.            {
  213.              *(w->weight) = -wtlimit;
  214.              wtlimithit = 1;
  215.            };
  216. #else
  217.         if (w->weight > wtlimit)
  218.            {
  219.              w->weight = wtlimit;
  220.              wtlimithit = 1;
  221.            }
  222.         else if (w->weight < -wtlimit)
  223.            {
  224.              w->weight = -wtlimit;
  225.              wtlimithit = 1;
  226.            };
  227. #endif
  228.         w = w->next;
  229.        };
  230.       u = u->next;
  231.      };
  232.     layer = layer->next;
  233.    };
  234. }
  235.  
  236. #ifndef SYMMETRIC
  237.  
  238. void whittle(amount)    /* removes weights whose absolute */
  239. WTTYPE amount;          /* value is less than amount      */
  240. {LAYER *layer;
  241.  UNIT *u;
  242.  WTNODE *w, *wprev;
  243.  
  244.  layer = start->next;
  245.  while (layer != NULL)
  246.    {
  247.      u = (UNIT *) layer->units;
  248.      while (u != NULL)
  249.        {
  250.          w = (WTNODE *) u->wtlist;
  251.          wprev = (WTNODE *) NULL;
  252.          while (w->next != (WTNODE *) NULL)
  253.            {
  254.              if ((w->weight) < amount && (w->weight) > -amount)
  255.                {
  256.                  if (wprev == NULL) (WTNODE *) u->wtlist = w->next;
  257.                  else (WTNODE *) wprev->next = w->next;
  258.                  wttotal = wttotal - 1;
  259.                }
  260.              else wprev = w;
  261.              w = w->next;
  262.            }
  263.          u = u->next;
  264.        }
  265.      layer = layer->next;
  266.    }
  267. }
  268.  
  269. #endif
  270.  
  271. void oneset() /* go through the patterns once and update weights */
  272. { int i;
  273.   LAYER *layer;
  274.   register UNIT *u;
  275.   register WTNODE *w;
  276.   short numbernotclose, attempted, passed;
  277.  
  278. begin:
  279.  layer = last;      /* make all b->totals = 0 */
  280.  while (layer->backlayer != NULL)
  281.     {
  282.       u = (UNIT *) layer->units;
  283.       while (u != NULL)
  284.          {
  285.            w = (WTNODE *) u->wtlist;
  286.            while (w != NULL)
  287.               {
  288. #ifdef SYMMETRIC
  289.                 *(w->total) = 0;
  290. #else
  291.                 w->total = 0;
  292. #endif
  293.                 w = w->next;
  294.               };
  295.            u = u->next;
  296.          };
  297.       layer = layer->backlayer;
  298.     };
  299.  attempted = 0;
  300.  passed = 0;
  301.  unlearnedpats = npats;
  302.  resetpats();
  303.  for(i=1;i<=npats;i++)
  304.     {
  305.       nextpat();
  306.       if (last->currentpat->bypass <= 0)
  307.          {
  308.            setonepat();
  309.            forward();
  310.            attempted = attempted + 1;
  311.            if (update == 'c' || update == 'C')
  312.               numbernotclose = cbackoutput();
  313.            else numbernotclose = backoutput();
  314.            if (numbernotclose != 0)
  315.               {
  316. #ifndef SYMMETRIC
  317.                 if (update == 'c' || update == 'C') cbackinner();
  318.                 else backinner();
  319. #endif
  320.               }
  321.            else /* this one pattern has been learned */
  322.               {
  323.                 passed = passed + 1;
  324.                 unlearnedpats = unlearnedpats - 1;
  325.                 last->currentpat->bypass = skiprate;
  326. #ifndef SYMMETRIC
  327.                 if (backprop)
  328.                    {
  329.                      if (update == 'c' || update == 'C') cbackinner();
  330.                      else backinner();
  331.                    };
  332. #endif
  333.               }
  334.          }
  335.       else last->currentpat->bypass = last->currentpat->bypass - 1;
  336.     };
  337.  if (unlearnedpats == 0) return;
  338.  if (attempted == passed)
  339.     {
  340.       resetpats();
  341.       for (i=1;i<=npats;i++)
  342.          {
  343.            nextpat();
  344.            last->currentpat->bypass = 0;
  345.          };
  346.       goto begin;
  347.     };
  348.  if (update == 'j') updatej();
  349.  else if (update == 'o' || update == 'd') updateo();
  350.  if (wtlimit != 0) limitwts();
  351. }
  352.  
  353. void kick(size,amount) /* give the network a kick */
  354. WTTYPE size;
  355. WTTYPE amount;
  356. { LAYER *layer;
  357.   UNIT *u;
  358.   WTNODE *w;
  359.   WTTYPE value;
  360.   WTTYPE delta;
  361.   int sign;
  362.  
  363.   layer = start->next;
  364.   while (layer != NULL)
  365.    {
  366.     u = (UNIT *) layer->units;
  367.     while (u != NULL)
  368.      {
  369.       w = (WTNODE *) u->wtlist;
  370.       while (w != NULL)
  371.        {
  372. #ifdef SYMMETRIC
  373.          value = *(w->weight);
  374. #else
  375.          value = w->weight;
  376. #endif
  377.          if (value != 0) sign = 1;
  378.          else if (rand() > 16383) sign = -1;
  379.          else sign = 1;
  380.          delta = (sign * amount * rand()) / 32768;
  381.          if (value >= size) value = value - delta;
  382.          else if (value < -size) value = value + delta;
  383. #ifdef SYMMETRIC
  384.          if (((UNIT *) w->backunit)->unitnumber != u->unitnumber)
  385.             *(w->weight) = value;
  386. #else
  387.          w->weight = value;
  388. #endif
  389.          w = w->next;
  390.        }
  391.       u = u->next;
  392.      }
  393.     layer = layer->next;
  394.    } 
  395. }
  396.  
  397. void printpats(first,finish,printheader,printerrors,callfromrun)
  398. int first,finish,printheader,printerrors,callfromrun;
  399. {
  400.   int i;
  401.   double err;
  402.  
  403.   if (summary == '+' && callfromrun)
  404.      {
  405.        printf("%6d   ",totaliter);
  406.        printf("%6d learned ",npats-unlearnedpats);
  407.        printf("%6d unlearned     ",unlearnedpats);
  408.        err = unscaleint(totaldiff) / (npats * last->unitcount);
  409.        printf("%7.5lf error/unit\n",err);
  410.        return;
  411.      };
  412.   lastprint = totaliter;
  413.   if (printheader == 1)
  414.      printf("%d iterations, file = %s\n",totaliter,datafilename);
  415.   resetpats();
  416.   for (i=2;i<=first;i++) nextpat();
  417.   for (i=first;i<=finish;i++)
  418.      { 
  419.        nextpat();
  420.        setonepat();
  421.        printf("%3d ",i);
  422.        forward();
  423.        printoutunits(last,printerrors);
  424.      };
  425. }
  426.  
  427. void run(n,prpatsrate)
  428. int n;              /* the number of iterations to run */
  429. int prpatsrate;     /* rate at which to print output patterns */
  430.  
  431. { int i;
  432.   char wtlimitbefore;
  433.  
  434.   printf("running . . .\n");
  435.   for (i=1;i<=n;i++)
  436.     {
  437.       totaldiff = 0;
  438.       wtlimitbefore = wtlimithit;
  439.       oneset();
  440.       totaliter = totaliter + 1;
  441.       if (wtlimitbefore == 0 && wtlimithit == 1)
  442.          printf(">>>>> WEIGHT LIMIT HIT <<<<< at %d\n",totaliter);
  443.       if (unlearnedpats == 0)
  444.         {
  445.           if (update != 'c' && update != 'C') totaliter = totaliter - 1;
  446.           if ((prpatsrate > 0) && (lastprint != totaliter))
  447.              printpats(1,npats,1,1,1);
  448.           printf("patterns learned to within %4.2lf",unscale(toler));
  449.           printf(" at iteration %d\n",totaliter);
  450.           return;
  451.         };
  452.       if (totaliter % saverate == 0) saveweights();
  453.       if ((prpatsrate > 0) && ((i % prpatsrate == 0) || (i == n)))
  454.          printpats(1,npats,1,1,1);
  455.     };
  456.