home *** CD-ROM | disk | FTP | other *** search
/ Usenet 1994 January / usenetsourcesnewsgroupsinfomagicjanuary1994.iso / sources / misc / volume14 / back-prop / part02 / bp.c
Encoding:
C/C++ Source or Header  |  1990-09-15  |  35.3 KB  |  1,129 lines

  1. /* ************************************************** */
  2. /* file bp.c:  contains the main program and network  */
  3. /*             creation routines.                     */
  4. /*                                                    */
  5. /* Copyright (c) 1990 by Donald R. Tveter             */
  6. /*                                                    */
  7. /* ************************************************** */
  8.  
  9. #include <stdio.h>
  10. #include <malloc.h>
  11. #include <signal.h>
  12. #include <setjmp.h>
  13. #define SIGINT 2
  14. #define MAXINT 2147483647
  15.  
  16. #ifdef INTEGER
  17. #include "ibp.h"
  18. #else
  19. #include "rbp.h"
  20. #endif
  21.  
  22. extern int rand();           /* built-in C function */
  23. extern int srand();          /* built-in C function */
  24.  
  25. extern void forward();       /* from int.c or real.c */
  26.  
  27. #ifdef INTEGER
  28. extern int scale();          /* from io.c */
  29. extern double unscale();     /* from io.c */
  30. #endif
  31.  
  32. extern int copyhidden();     /* from misc.c */
  33. extern void findendofpats(); /* from misc.c */
  34. extern void help();          /* from io.c */
  35. extern void kick();          /* from misc.c */
  36. extern void nullpatterns();  /* from misc.c */
  37. extern void printoutunits(); /* from misc.c */
  38. extern void printpats();     /* from misc.c */
  39. extern void printweights();  /* from io.c */
  40. extern void run();           /* from misc.c */
  41. extern WTTYPE rdr();         /* from io.c */
  42. extern int readch();         /* from io.c */
  43. extern double readchar();    /* from io.c */
  44. extern int readint();        /* from io.c */
  45. extern void restoreweights();/* from io.c */
  46. extern void saveweights();   /* from io.c */
  47. extern void texterror();     /* from io.c */
  48. extern void whittle();       /* from misc.c */
  49.  
  50. /* global variables used in all versions */
  51.  
  52. char activation;      /* activation function, p or s */
  53. WTTYPE alpha;         /* momentum term */
  54. char backprop;        /* flags whether to back propagate error for */
  55.                       /* units close to their targets */
  56. int bufferend;        /* index of last character in input line */
  57. int bufferptr;        /* position of next character in buffer */
  58. char buffer[buffsize];/* holds contents of one input line */
  59. int ch;               /* general purpose character variable */
  60. char cmdfilename[50]; /* name of file to take extra commands from */
  61. jmp_buf cmdloopstate; /* to save state in case of a SIGINT */
  62. WTTYPE D;             /* sigmoid sharpness */
  63. FILE *data;           /* file for original data */
  64. char datafilename[50];/* copy of the data file name saved here */
  65. WTTYPE dbdeta;        /* the initial eta value for the DBD method */
  66. WTTYPE decay;         /* the decay parameter for the DBD method */
  67. char deriv;           /* flags type of derivative to use */
  68. int echo;             /* controls echoing of characters during input */
  69. WTTYPE eta;           /* basic learning rate */
  70. WTTYPE eta2;          /* DSZ learning rate for inner layers */
  71. WTTYPE etamax;        /* the maximum eta for the DBD method */
  72. int extraconnect;     /* flags the use of connections between */
  73.                       /* non-adjacent layers */
  74. int format[maxformat];/* each value in format indicates where to put */
  75.                       /* a blank for compressed output mode or a */
  76.                       /* carriage return for real output */
  77. UNIT *hlayer;         /* pointer to list of units in second layer */
  78. UNIT *ilayer;         /* pointer to list of units in third layer */
  79. char informat;        /* controls format to read numbers */
  80. WTTYPE initialkick;   /* the range weights are initialized to */
  81. int iter;             /* for counting iterations in one run */
  82. UNIT *jlayer;         /* pointer to list of units in fourth layer */
  83. WTTYPE kappa;         /* the DBD learning parameter */
  84. UNIT *klayer;         /* pointer to list of units in fifth layer */
  85. LAYER *last;          /* has address of the output layer */
  86. int lastprint;        /* last iteration pattern responses printed */
  87. int lastsave;         /* last time weights were saved */
  88. short nlayers;        /* number of layers in network */
  89. int npats;            /* number of patterns currently in use */
  90. char outformat;       /* controls format to print output */
  91. int prevnpats;        /* previous number of patterns, initially 0 */
  92. WTTYPE qmark;         /* value for ? in compressed input */
  93. int readerror;        /* flags an error in reading a value */
  94. int readingpattern;   /* flags reading pattern state */
  95. int saverate;         /* rate at which to save weights */
  96. unsigned seed;        /* seed for generating random weights */
  97. short skiprate;       /* number of times to bypass a learned pattern */
  98. LAYER *start;         /* has address of the input layer */
  99. char summary;         /* flags summary output mode */
  100. WTTYPE theta1;        /* the DBD parameter */
  101. WTTYPE theta2;        /* 1 - theta1 */
  102. WTTYPE toler;         /* value used in testing for completion */
  103. WTTYPE toosmall;      /* weights smaller than toosmall were removed */
  104. #ifdef INTEGER
  105. int totaldiff;        /* totals errors to find average error per unit */
  106. #else
  107. double totaldiff;
  108. #endif
  109. int totaliter;        /* counts total iterations for the program */
  110. int unlearnedpats;    /* number unlearned in last learning cycle */
  111. char update;          /* flags type of update rule to use */
  112. char wtformat;        /* controls format to save and restore weights */
  113. WTTYPE wtlimit;       /* adjustable limit on weights */
  114. char wtlimithit;      /* flags whether the limit has been hit */
  115. int wttotal;          /* total number of weights in use */
  116.  
  117. /* global variable for the symmetric integer version */
  118.  
  119. #ifdef SYMMETRIC
  120. WTTYPE  stdthresh;    /* the standard threshold weight value */
  121. #endif
  122.  
  123. UNIT *locateunit(layerno,unitno)  /* given a layer number and unit */
  124. int layerno, unitno;              /* number this routine returns the */
  125. {int i;                           /* address of the unit */
  126.  UNIT *u;
  127.  LAYER *layer;
  128.  
  129.  if (layerno >= 1 && layerno <= nlayers)
  130.     {
  131.       layer = start;
  132.       for(i=1;i<=(layerno-1);i++) layer = layer->next;
  133.       u = (UNIT *) layer->units;
  134.       while (u != NULL && u->unitnumber != unitno) u = u->next;
  135.       if (u == NULL)
  136.          printf("there is no unit %3d in layer %3d\n",unitno,layerno);
  137.     }
  138.  else
  139.     {
  140.       printf("there is no layer %3d\n",layerno);
  141.       return(NULL);
  142.     };
  143.  return(u);     
  144. }
  145.  
  146. #ifdef SYMMETRIC
  147.  
  148. int wtaddress(i,j,biasunit,type,size) /* Returns the address of a */
  149. int i,j;                              /* weight (1), olddw (2),   */
  150. int biasunit;                         /* eta (3) or total (4).    */
  151. int type;                             /* One is created if it     */
  152. int size;                             /* doesn't already exist.   */
  153.  
  154. { int k, addr;
  155.   UNIT *u;
  156.   WTNODE *w;
  157.  
  158.   if (biasunit == 1) addr = (int) malloc(size);
  159.   else if (j >= i) addr = (int) malloc(size);
  160.   else /* the item already exists, so find its address */
  161.      {
  162.        u = locateunit(2,j);
  163.        w = (WTNODE *) u->wtlist;
  164.        k = 1;
  165.        while (k < i)
  166.           {
  167.             w = w->next;
  168.             k = k + 1;
  169.           };
  170.        if (type == 1) addr = (int) w->weight;
  171.        else if (type == 2) addr = (int) w->olddw;
  172.        else if (type == 3) addr = (int) w->eta;
  173.                else addr = (int) w->total;
  174.      };
  175.   return(addr);
  176. }
  177.  
  178. void setweight(w,i,j,biasunit) /* set initial values in w */
  179. WTNODE *w;
  180. int i, j;
  181. int biasunit;
  182. {WTTYPE *s;
  183.  
  184.   s = (WTTYPE *) wtaddress(i,j,biasunit,1,WTSIZE);
  185.   *s = 0;
  186.   w->weight = s;
  187.   s = (WTTYPE *) wtaddress(i,j,biasunit,2,WTSIZE);
  188.   *s = 0;
  189.   w->olddw = s;
  190.   s = (WTTYPE *) wtaddress(i,j,biasunit,3,WTSIZE);
  191.   *s = eta;
  192.   w->eta = s;
  193. #ifdef INTEGER
  194.   w->total = (int *) wtaddress(i,j,biasunit,4,sizeof(int));
  195. #else
  196.   w->total = (double *) wtaddress(i,j,biasunit,4,WTSIZE);
  197. #endif
  198. }
  199.  
  200. #else
  201.  
  202. void setweight(w,i,j,biasunit) /* set initial values in w */
  203. WTNODE *w;
  204. short i,j;
  205. int biasunit;
  206. {
  207.   w->weight = 0;
  208.   w->olddw = 0;
  209.   w->eta = dbdeta;
  210. }
  211.  
  212. #endif
  213.  
  214. LAYER *mklayer(prevlayer,n)  /* creates a layer of n units, pointers */
  215. LAYER *prevlayer;            /* and weights back to the units in the */
  216. int n;                       /* previous layer and links this new */
  217.                              /* layer into the list of layers */
  218. {UNIT *front, *p, *q, *bias, *prev, *ptr;
  219.  WTNODE *wfront, *wprev, *w;
  220.  LAYER *lptr;
  221.  int i, j, count;
  222.  
  223. /* make a list of nodes in this layer */
  224.  
  225.  count = 1;
  226.  front = (UNIT *) malloc(sizeof(UNIT));
  227.  front->unitnumber = count;
  228.  front->layernumber = nlayers;
  229.  prev = front;
  230.  for(i=1;i<n;i++)
  231.     {
  232.       count = count + 1;
  233.       ptr = (UNIT *) malloc(sizeof(UNIT));
  234.       prev->next = ptr;
  235.       ptr->unitnumber = count;
  236.       ptr->layernumber = nlayers;
  237.       prev = ptr;
  238.     };
  239.  prev->next = NULL;
  240.  
  241. /* make a LAYER node to point to this list of units */
  242.  
  243.  lptr = (LAYER *) malloc(sizeof(LAYER));
  244.  lptr->unitcount = n;
  245.  lptr->patstart = NULL;
  246.  lptr->currentpat = NULL;
  247.  lptr->backlayer = prevlayer;
  248.  lptr->next = NULL;
  249.  (UNIT *) lptr->units = front;   /* connect the list of units */
  250.  
  251. /* return if this is the input layer */
  252.  
  253.  if (prevlayer == NULL) return(lptr);
  254.  prevlayer->next = lptr;
  255.  
  256. /* If we are working on a deeper layer, for every node in this layer, */
  257. /* create a linked list back to units in the previous layer. */
  258.  
  259.  i = 1;
  260.  q = front;
  261.  while (q != NULL) /* do a unit */
  262.    {    
  263.      j = 1;            /* handle first connection */
  264.      p = (UNIT *) prevlayer->units;
  265.      wfront = (WTNODE *) malloc(sizeof(WTNODE));
  266.      wttotal = wttotal + 1;
  267.      (WTNODE *) q->wtlist = wfront;
  268.      wprev = wfront;
  269.      (UNIT *) wfront->backunit = p;
  270.      setweight(wfront,i,j,0);
  271.      p = p->next;
  272.      while (p != NULL) /* handle rest of connections */
  273.         {
  274.           j = j + 1;
  275.           w = (WTNODE *) malloc(sizeof(WTNODE));
  276.           wttotal = wttotal + 1;
  277.           wprev->next = w;
  278.           (UNIT *) w->backunit = p;
  279.           setweight(w,i,j,0);
  280.           wprev = w;
  281.           p = p->next;
  282.         };
  283.      j = j + 1;
  284.      bias = (UNIT *) malloc(sizeof(UNIT));   /* create a bias unit */
  285.      bias->oj = scale(1.0);
  286.      bias->layernumber = nlayers;
  287.      bias->unitnumber = 32767;           /* bias unit is unit 32767 */
  288.      w = (WTNODE *) malloc(sizeof(WTNODE)); /* connect to end of list */
  289.      wttotal = wttotal + 1;
  290.      wprev->next = w;
  291.      (UNIT *) w->backunit = bias;
  292.      setweight(w,n+2,i,1);
  293.      w->next = NULL;
  294.      q = q->next;
  295.      i = i + 1;
  296.    };
  297.  return(lptr);
  298. }
  299.  
  300. #ifndef SYMMETRIC
  301.  
  302. void connect(a,b,range)  /* add a connection from unit a to unit b */
  303. UNIT *a, *b;             /* connections go in increasing order */
  304. WTTYPE range;
  305.  
  306. {WTNODE *wnew, *w, *wprev;
  307.  UNIT *wunit;
  308.  int farenough;
  309.  
  310.  wnew = (WTNODE *) malloc(sizeof(WTNODE));
  311.  wttotal = wttotal + 1;
  312.  wnew->eta = dbdeta;
  313.  wnew->weight = range * rand() / 32768;
  314.  if (rand() > 16383) wnew->weight = -wnew->weight;
  315.  wnew->olddw = 0;
  316.  (UNIT *) wnew->backunit = a;
  317.  w = (WTNODE *) b->wtlist;
  318.  wprev = NULL;
  319.  wunit = (UNIT *) w->backunit;
  320.  farenough = 0;                  /* insert the weight in order */
  321.  while (w != NULL && !farenough)
  322.     if (wunit->layernumber > a->layernumber) farenough = 1;
  323.     else if (wunit->layernumber == a->layernumber)
  324.             {
  325.               while (w != NULL && !farenough)
  326.                  {
  327.                    if (wunit->unitnumber < a->unitnumber &&
  328.                        wunit->layernumber == a->layernumber)
  329.                       {
  330.                         wprev = w;
  331.                         w = w->next;
  332.                         wunit = (UNIT *) w->backunit;
  333.                       }
  334.                    else farenough = 1;
  335.                  };
  336.             }      
  337.     else
  338.        {
  339.          wprev = w;
  340.          w = w->next;
  341.          wunit = (UNIT *) w->backunit;
  342.        }
  343.  if (wprev == NULL)
  344.     {
  345.       wnew->next = w;
  346.       (WTNODE *) b->wtlist = wnew;
  347.     }
  348.  else
  349.     {
  350.       wnew->next = w;
  351.       wprev->next = wnew;
  352.     };
  353. }
  354.  
  355. void addhiddenunit(layerno,range)
  356. int layerno;  /* add hidden unit to end of the layer */
  357. WTTYPE range;
  358. {
  359.  LAYER *lptr, *prevlayer, *nextlayer;
  360.  UNIT *u, *prevu, *p, *bias;
  361.  WTNODE *wnode;
  362.  int i, unitno;
  363.  
  364.  lptr = start;
  365.  for (i=1;i <= (layerno - 1); i++) lptr = lptr->next;
  366.  unitno = lptr->unitcount;
  367.  lptr->unitcount = unitno + 1;
  368.  prevu = locateunit(layerno,unitno);
  369.  if (prevu == NULL) return;
  370.  u = (UNIT *) malloc(sizeof(UNIT));
  371.  prevu->next = u;
  372.  u->next = NULL;
  373.  u->unitnumber = unitno + 1;
  374.  u->layernumber = layerno;
  375.  bias = (UNIT *) malloc(sizeof(UNIT));
  376.  bias->oj = scale(1.0);
  377.  bias->layernumber = layerno;
  378.  bias->unitnumber = 32767;           /* bias unit is unit 32767 */
  379.  wnode = (WTNODE *) malloc(sizeof(WTNODE));
  380.  wttotal = wttotal + 1;
  381.  wnode->weight = range * rand() / 32768;
  382.  if (rand() > 16383) wnode->weight = -wnode->weight;
  383.  wnode->olddw = 0;
  384.  wnode->eta = dbdeta;
  385.  wnode->next = NULL;
  386.  (UNIT *) wnode->backunit = bias;
  387.  (WTNODE *) u->wtlist = wnode;
  388.  prevlayer = lptr->backlayer;
  389.  p = (UNIT *) prevlayer->units;
  390.  while (p != NULL)
  391.     {
  392.       connect(p,u,range);
  393.       p = p->next;
  394.     };
  395.  nextlayer = lptr->next;
  396.  p = (UNIT *) nextlayer->units;
  397.  while (p != NULL)
  398.     {
  399.       connect(u,p,range);
  400.       p = p->next;
  401.     };
  402. }      
  403.  
  404. #endif
  405.  
  406. void readpatson(layer,command) /* reads the patterns for layer */
  407. LAYER *layer;
  408. int command;
  409.  
  410. {PATNODE *p, *prevp;
  411.  PATLIST *pl;
  412.  int i;
  413.  
  414.  pl = (PATLIST *) malloc(sizeof(PATLIST));
  415.  pl->next = NULL;
  416.  pl->bypass = 0;      /* number of times to bypass this pattern */
  417.  pl->pats = NULL;     /* no patterns read yet */
  418.  if (layer->patstart == NULL) (PATLIST *) layer->patstart = pl;
  419.  else layer->currentpat->next = pl;
  420.  layer->currentpat = pl;
  421.  
  422.  prevp = NULL; /* read in each number */
  423.  for (i=1;i<=layer->unitcount;i++)
  424.     {
  425.       p = (PATNODE *) malloc(sizeof(PATNODE));
  426.       if (informat == 'r') p->val = rdr(GE,(double) HCODE,command);
  427.       else p->val = scale(readchar());
  428.       if (readerror == 1)
  429.          {
  430.            printf("pattern not read\n");
  431.            return;
  432.          };
  433.       p->next = NULL;
  434.       if (prevp == NULL) pl->pats = p; else prevp->next = p;
  435.       prevp = p;
  436.     };
  437. }
  438.  
  439. void readpats(new,command)  /* reads the input and output patterns */
  440. int new;
  441. int command;
  442. { int i;
  443.   PATLIST *pl;
  444.   
  445.   for (i=prevnpats + 1;i<=npats;i++)
  446.      {
  447.        readpatson(start,command);
  448.        if (readerror == 1) goto failure;
  449.        readpatson(last,command);
  450.        if (readerror == 1) goto failure;
  451.      };
  452.   return;
  453. failure:
  454.   if (data != stdin)
  455.     {
  456.       printf("error while reading pattern %d\n",i);
  457.       exit(5);
  458.     };
  459.   if (new == 0)
  460.      {
  461.        resetpats();
  462.        for (i=1;i<prevnpats;i++) setonepat();
  463.        pl = (PATLIST *) start->currentpat;
  464.        pl->next = NULL;
  465.        pl = (PATLIST *) last->currentpat;
  466.        pl->next = NULL;
  467.      };
  468.   printf("no patterns added\n");
  469.   printf("%d patterns in use\n",prevnpats);   
  470.   npats = prevnpats;
  471. }
  472.  
  473. void init()    /* initializes almost everything */
  474. {int i;
  475.  
  476.  activation = 'p';          /* piece-wise activation function */
  477.  alpha = scale(0.5);
  478.  backprop = 1;              /* always back-propagate errors */
  479.  bufferend = 0;
  480.  bufferptr = buffsize + 1;
  481.  ch = ' ';
  482.  D = scale(1.0);
  483.  dbdeta = scale(0.5);
  484.  decay = scale(0.5);
  485.  deriv = 'f';               /* use Fahlman's derivative */
  486.  eta = scale(0.5);
  487.  eta2 = scale(0.05);
  488.  etamax = scale(30.0);
  489.  extraconnect = 0;
  490.  format[0] = 0;  /* set default places for breaks in output patterns */
  491.  for(i=1;i<=maxformat-1;i++) format[i] = format[i-1] + 10;
  492.  informat = 'c';            /* input format is compressed */
  493.  initialkick = -1;          /* weights have not been kicked yet */
  494.  kappa = scale(0.5);
  495.  lastprint = 0;
  496.  lastsave = 0;
  497.  outformat = 'r';           /* output format is real */
  498.  skiprate = 0;
  499.  prevnpats = 0;
  500.  qmark = scale(0.5);
  501.  saverate = 100000;         /* effectively, never save weights */
  502.  seed = 0;
  503. #ifdef SYMMETRIC
  504.  stdthresh = -32768;        /* indicates no threshold set */
  505. #endif
  506.  summary = '-';             /* don't summarize learning */
  507.  theta1 = scale(0.5);
  508.  theta2 = scale(1.0) - theta1;
  509.  toler = scale(0.1);
  510.  toosmall = -1;             /* indicates no weights whittled away */
  511.  totaliter = 0;
  512.  update = 'o';              /* update formulas are the original */
  513.  wtformat = 'r';            /* save weights in real format */
  514.  wtlimit = scale(0.0);      /* no limit on weights */
  515.  wtlimithit = 0;            /* weight limit not yet hit */
  516.  wttotal = 0;
  517. }
  518.  
  519. void restartcmdloop() /* for a SIGINT, restart in cmdloop */
  520. {
  521.  if (data != stdin) ch = EOF;
  522.  signal(SIGINT,restartcmdloop);
  523.  longjmp(cmdloopstate,1);
  524. }
  525.  
  526. void cmdloop()    /* read commands and process them */
  527. {
  528.  int finished, layerno, unitno, layer1, layer2, node1, node2;
  529.  int i, itemp, itemp2;
  530.  WTTYPE temp, temp2;
  531.  UNIT *u, *n1, *n2, *hunit, *iunit, *junit, *kunit;
  532.  LAYER *p;
  533.  char string[81];
  534.  WTNODE *w;
  535.  
  536.  setjmp(cmdloopstate); /* position to recover from SIGINT */
  537.  finished = 0;         /* loop until finished == 1 */
  538.  do{
  539. restart:
  540. #ifdef SYMMETRIC
  541.     if (data == stdin) printf("[?!*AabCEefHhijklmnoPpQqRrSsTtWwx]? ");
  542. #else
  543.     if (data == stdin) printf("[?!*AabCcEefHhijklmnoPpQqRrSstWwx]? ");
  544. #endif
  545.      while(ch == ' ' || ch == '\n') ch = readch();
  546.      switch (ch) {
  547. case EOF: if (data == stdin) exit(6); else data = stdin;
  548.           printf("taking commands from stdin now\n");
  549.           bufferend = 0;             /* force a read from stdin */
  550.           bufferptr = buffsize + 1;  /* when readch is called */
  551.           ch = ' ';
  552.           goto restart;
  553. case '?': printf("\n%d iterations, s %1d  ",totaliter,seed);
  554.           printf("k 0 %5.3lf,  ",unscale(initialkick));
  555.           printf("file = %s\n",datafilename);
  556.           printf("Algorithm: a%c",activation);
  557.           if (backprop) printf(" b+"); else printf(" b-");
  558.           printf(" D%5.2lf d%c ",unscale(D),deriv);
  559.           printf("l%6.2lf s%1d u%c\n",unscale(wtlimit),skiprate,update);
  560.           printf("e %7.5lf %7.5lf",unscale(eta),unscale(eta2));
  561.           printf(" --- a %7.5lf\n",unscale(alpha));
  562.           printf("j d %8.5lf e %8.5lf",unscale(decay),unscale(dbdeta));
  563.           printf(" k %8.5lf m %8.5lf",unscale(kappa),unscale(etamax));
  564.           printf(" t %8.5lf\n",unscale(theta1));
  565.           printf("tolerance = %4.2lf\n",unscale(toler));
  566.           printf("f i%c o%c",informat,outformat);
  567.           printf(" s%c w%c\n",summary,wtformat);
  568.           printf("format breaks after: ");
  569.           for (i=1;i<=10;i++) printf("%4d",format[i]);
  570.           printf("\n                     ");
  571.           for (i=11;i<=maxformat-1;i++) printf("%4d",format[i]);
  572.           printf("\nlast time weights were saved: %d\n",lastsave);
  573.           printf("saving weights every %d iterations\n",saverate);
  574.           if (wtlimithit) printf(">>>>> WEIGHT LIMIT HIT <<<<<\n");
  575.           printf("network size: ");
  576.           p = start;
  577.           while (p != NULL)
  578.              {
  579.                printf(" %1d",p->unitcount);
  580.                p = p->next;
  581.              };
  582.           if (extraconnect == 1) printf(" with extra connections");
  583.           printf(" (total:  %1d weights)\n",wttotal);
  584.           if (toosmall != -1)
  585.              {
  586.                printf("removed non-bias weights with absolute ");
  587.                printf("value below  %4.2lf\n",unscale(toosmall));
  588.              };
  589. #ifdef SYMMETRIC
  590.           if (stdthresh != -32768)
  591.              printf("thresholds frozen at %lf\n", unscale(stdthresh));
  592. #endif
  593.           printf("%d patterns        ",npats);
  594.           printf("%d learned        ",npats-unlearnedpats);
  595.           printf("%d unlearned on last pass\n",unlearnedpats);
  596.           printf("? = %lf\n",unscale(qmark));
  597.           printf("for help, type h followed by");
  598.           printf(" the letter of the command\n\n");
  599.           break;
  600. case '!': i = 0;
  601.           ch = readch();
  602.           while (ch != '\n' && i <= 80)
  603.              {
  604.                string[i] = ch;
  605.                ch = readch();
  606.                i = i + 1;
  607.              };
  608.           bufferptr = bufferptr - 1; /* ungetc(ch,data); */
  609.           string[i] = '\0';
  610.           system(string);
  611.           break;
  612. case '*': break;  /* * on a line is a comment */
  613. case 'A': while (ch != '\n' && ch != '*')
  614.            {
  615.             ch = readch();
  616.             if (ch == 'a')
  617.              {
  618.               do ch = readch(); while (ch == ' ');
  619.               if (ch == 'p') activation = 'p';
  620. #ifndef INTEGER
  621.               else if (ch == 's') activation = 's';
  622. #endif
  623.               else texterror();
  624.              }
  625.             else if (ch == 'b')
  626.              {
  627.               do ch = readch(); while (ch == ' ');
  628.               if (ch == '+') backprop = 1;
  629.               else if (ch == '-') backprop = 0;
  630.               else texterror();
  631.              }
  632.             else if (ch == 'D')
  633.              {
  634.                temp = rdr(GT,0.0,'A');
  635.                if (readerror == 0) D = temp;
  636.              }
  637.             else if (ch == 'd')
  638.              {
  639.               do ch = readch(); while (ch == ' ');
  640.               if (ch == 'd' || ch == 'f' || ch == 'o') deriv = ch;
  641.               else texterror();
  642.              }
  643.             else if (ch == 'l')
  644.              {
  645.                temp = rdr(GE,0.0,'A');
  646.                if (readerror == 0)
  647.                   {
  648.                     wtlimit = temp;
  649.                     if (wtlimit == 0) wtlimithit = 0;
  650.                   };
  651.              }
  652.             else if (ch == 's')
  653.              {
  654.               itemp = readint(0,32767,'s');
  655.               if (readerror == 0) skiprate = itemp;
  656.              }
  657.             else if (ch == 'u')
  658.              {
  659.               do ch = readch(); while (ch == ' ');
  660.               if (ch == 'c' || ch == 'C' || ch == 'd' ||
  661.                   ch == 'j' || ch == 'o') update = ch;
  662.               else texterror();
  663.              }
  664.             else if (ch == '*' || ch == '\n' || ch == ' ');
  665.             else texterror();
  666.            }
  667.           bufferptr = bufferptr - 1;
  668.           break;
  669. case 'a': temp = rdr(GE,0.0,'a');
  670.           if (readerror == 0) alpha = temp;
  671.           break;
  672. case 'b': itemp = 0;
  673.           ch = readch();
  674.           while (ch != '\n' && ch != '*')
  675.              {
  676.                bufferptr = bufferptr - 1;
  677.                itemp2 = readint(format[itemp],MAXINT,'b');
  678.                if (readerror == 1) goto endb;
  679.                itemp = itemp + 1;
  680.                if (itemp < maxformat) format[itemp] = itemp2;
  681.                else printf("format too long\n");
  682.                ch = readch();
  683.                while (ch == ' ') ch = readch();
  684.                /* if its the start of a number, back up */
  685.                if (ch != '\n') bufferptr = bufferptr - 1;
  686.              };
  687.           if (itemp < maxformat-1)
  688.              for (i=itemp+1;i <= maxformat-1; i++)
  689.                 format[i] = format[i-1] + 10;
  690.           bufferptr = bufferptr - 1;
  691.     endb: break;
  692. case 'C': if (toosmall != -1)
  693.              {
  694.                printf("cannot restart with the weights removed\n");
  695.                break;
  696.              };
  697.           wtlimithit = 0;
  698.           totaliter = 0;
  699.           lastsave = 0;
  700.           initialkick = -1;
  701.           lastprint = 0;
  702.           seed = 0;
  703.           p = start->next;
  704.           while (p != NULL)
  705.              {
  706.                u = (UNIT *) p->units;
  707.                while (u != NULL)
  708.                   {
  709.                     w = (WTNODE *) u->wtlist;
  710.                     while (w != NULL)
  711.                        {
  712. #ifdef SYMMETRIC
  713.                          if (w->next != NULL)
  714.                             { /* skip threshold weight */
  715.                               *(w->weight) = 0;
  716.                               *(w->olddw) = 0;
  717.                               *(w->eta) = dbdeta;
  718.                             };
  719. #else
  720.                          w->weight = 0;
  721.                          w->olddw = 0;
  722.                          w->eta = dbdeta;
  723. #endif
  724.                          w = w->next;
  725.                        };
  726.                     u = u->next;
  727.                   };
  728.                p = p->next;
  729.              };
  730.           break;
  731. #ifndef SYMMETRIC
  732. case 'c': layer1 = readint(1,nlayers,'c');
  733.           if (readerror == 1) break;
  734.           node1 = readint(1,MAXINT,'c');
  735.           if (readerror == 1) break;
  736.           layer2 = readint(1,nlayers,'c');
  737.           if (readerror == 1) break;
  738.           node2 = readint(1,MAXINT,'c');
  739.           if (readerror == 1) break;
  740.           if (layer1 >= layer2)
  741.              {
  742.                printf("backward connections in c command not");
  743.                printf(" implemented\n");
  744.                break;
  745.              };
  746.           n1 = locateunit(layer1,node1);
  747.           n2 = locateunit(layer2,node2);
  748.           if (n1 != NULL && n2 != NULL)
  749.              {
  750.                connect(n1,n2,0);
  751.                extraconnect = 1;
  752.              }
  753.           else printf("connection not made: %d %d %d %d\n",
  754.                        layer1, node1, layer2, node2);
  755.           break;
  756. #endif
  757.  
  758. case 'E': itemp = readint(0,1,'E');
  759.           if (readerror == 1) break;
  760.           else echo = itemp;
  761.           break;
  762. case 'e': temp = rdr(GT,0.0,'e');
  763.           if (readerror == 0) eta = temp;
  764.           while (ch == ' ') ch = readch();
  765.           if (ch != '\n' && ch != '*')
  766.              {
  767.                bufferptr = bufferptr - 1;
  768.                temp = rdr(GT,0.0,'r');
  769.                if (readerror != 1) eta2 = temp;
  770.              }
  771.           else eta2 = eta / 10;
  772.           bufferptr = bufferptr - 1;
  773.           break;
  774. case 'f': while (ch != '\n' && ch != '*')
  775.            {
  776.             ch = readch();
  777.             if (ch == 'i')
  778.              {
  779.               do ch = readch(); while (ch == ' ');
  780.               if (ch == 'c' || ch == 'r') informat = ch;
  781.               else texterror();
  782.              }
  783.             else if (ch == 'o')
  784.              {
  785.               do ch = readch(); while (ch == ' ');
  786.               if (ch == 'a' || ch == 'c' || ch == 'r') outformat = ch;
  787.               else texterror();
  788.              }
  789.             else if (ch == 's')
  790.              {
  791.               do ch = readch(); while (ch == ' ');
  792.               if (ch == '+' || summary == '-') summary = ch;
  793.               else texterror();
  794.              }
  795.             else if (ch == 'w')
  796.              {
  797.               do ch = readch(); while (ch == ' ');
  798.               if (ch == 'r' || ch == 'R' || ch == 'b' || ch == 'B')
  799.                  wtformat = ch;
  800.               else texterror();
  801.              }
  802.             else if (ch == ' ' || ch == '*' || ch == '\n');
  803.             else texterror();
  804.            }
  805.           bufferptr = bufferptr - 1;
  806.           break;
  807. #ifndef SYMMETRIC
  808. case 'H': itemp = readint(2,nlayers,'H');
  809.           if (readerror == 1) break;
  810.           temp = rdr(GE,0.0,'H');
  811.           if (readerror == 0) addhiddenunit(itemp,temp);
  812.           break;
  813. #endif
  814. case 'h': help();
  815.           break;
  816. case 'i': ch = readch();
  817.           while(ch == ' ') ch = readch();
  818.           itemp = 0;
  819.           while(ch != ' ' && ch != '\n' && itemp < 49)
  820.              {
  821.                cmdfilename[itemp] = ch;
  822.                itemp = itemp + 1;
  823.                ch = readch();
  824.              };
  825.           cmdfilename[itemp] = '\0';
  826.           if ((data = fopen(cmdfilename,"r")) == (FILE *) NULL)
  827.              {
  828.                printf("cannot open: %s\n",cmdfilename);
  829.                data = stdin;
  830.                printf("taking commands from stdin now\n");
  831.              }
  832.           bufferend = 0;
  833.           bufferptr = buffsize + 1;
  834.           ch = ' ';
  835.           goto restart;
  836. case 'j': while (ch != '\n' && ch != '*')
  837.            {
  838.             ch = readch();
  839.             if (ch == 'd')
  840.              {
  841.               temp = rdr(GT,0.0,'j');
  842.               if (readerror == 0) decay = temp;
  843.              }
  844.             else if (ch == 'e')
  845.              {
  846.               temp = rdr(GT,0.0,'d');
  847.               if (readerror == 0)
  848.                {
  849.                 dbdeta = temp;
  850.                 p = start->next;
  851.                 while (p != NULL)
  852.                  {
  853.                   u = (UNIT *) p->units;
  854.                   while (u != NULL)
  855.                    {
  856.                     w = (WTNODE *) u->wtlist;
  857.                     while (w != NULL)
  858.                      {
  859. #ifdef SYMMETRIC
  860.                       *(w->eta) = dbdeta;
  861. #else
  862.                       w->eta = dbdeta;
  863. #endif
  864.                       w = w->next;
  865.                      }
  866.                     u = u->next;
  867.                    }
  868.                   p = p->next;
  869.                  }
  870.                }
  871.              }
  872.             else if (ch == 'k')
  873.              {
  874.               temp = rdr(GT,0.0,'j');
  875.               if (readerror == 0) kappa = temp;
  876.              }
  877.             else if (ch == 'm')
  878.              {
  879.               temp = rdr(GT,0.0,'j');
  880.               if (readerror == 0) etamax = temp;
  881.              }
  882.             else if (ch == 't')
  883.              {
  884.               temp = rdr(GE,0.0,'j');
  885.               if (readerror == 0)
  886.                  {
  887.                   theta1 = temp;
  888.                   theta2 = scale(1.0) - theta1;
  889.                  };
  890.              }
  891.             else if (ch == '*' || ch == '\n' || ch == ' ');
  892.             else texterror();
  893.            }
  894.           bufferptr = bufferptr - 1;
  895.           break;
  896. case 'k': temp = rdr(GE,0.0,'k');
  897.           if (readerror == 1) break;
  898.           temp2 = rdr(GT,0.0,'k');
  899.           if (readerror == 0)
  900.            {
  901.             if (initialkick == -1 && temp == 0) initialkick = temp2;
  902.             kick(temp,temp2);
  903.            }
  904.           break;
  905. case 'l': layerno = readint(1,nlayers,'l'); 
  906.           if (readerror == 1) break;
  907.           p = start;
  908.           for (i=2;i<=layerno;i++) p = p->next;
  909.           printoutunits(p,0);
  910.           break;
  911. case 'm': nlayers = 0;
  912.           ch = readch();
  913.           p = NULL;
  914.           while (ch != '\n' && ch != '*')
  915.              {
  916.                itemp = readint(1,MAXINT,'m');
  917.                if (readerror == 1) goto endm;
  918.                nlayers = nlayers + 1;
  919.                p = mklayer(p,itemp);
  920.                if (nlayers == 1) start = p;
  921.                ch = readch();
  922.                while (ch == ' ') ch = readch();
  923.                /* if its the start of a number, back up */
  924.                if (ch != '\n') bufferptr = bufferptr - 1;
  925.              };
  926.           last = p;
  927.           p = start;
  928.           p = p->next;
  929.           hlayer = (UNIT *) p->units;
  930.           p = p->next;
  931.           if (p != NULL)
  932.              {
  933.                ilayer = (UNIT *) p->units;
  934.                p = p->next;
  935.                if (p != NULL)
  936.                   {
  937.                     jlayer = (UNIT *) p->units;
  938.                     p = p->next;
  939.                     if (p != NULL) klayer = (UNIT *) p->units;
  940.                   }
  941.              };
  942.           bufferptr = bufferptr - 1;
  943.           nullpatterns();
  944.     endm: break;
  945. case 'n': if (start == NULL)
  946.              {
  947.                printf("the network must be defined first\n");
  948.                break;
  949.              };
  950.           itemp = readint(1,MAXINT,'n');
  951.           if (readerror == 1) break;
  952.           nullpatterns();
  953.           npats = itemp;
  954.           readingpattern = 1;
  955.           readpats(1,'n');
  956.           readingpattern = 0;
  957.           unlearnedpats = npats;
  958.           break;
  959. case 'o': do ch = readch(); while (ch == ' ' || ch == '\n');
  960.           if (ch == 'r' || ch == 'a' || ch == 'c') outformat = ch;
  961.           else printf("incorrect output format: %c\n",ch);
  962.           break;
  963. case 'P': do ch = readch(); while (ch == ' ');
  964.           bufferptr = bufferptr - 1;
  965.           if (ch == '\n' || ch == '*') itemp = 0;
  966.           else
  967.              {
  968.                itemp = readint(0,npats,'P');
  969.                if (readerror == 1) break;
  970.              };
  971.           if (itemp == 0) printpats(1,npats,0,1,0);
  972.           else printpats(itemp,itemp,0,1,0);
  973.           break;
  974. case 'p': u = (UNIT *) start->units;
  975.           readingpattern = 1;
  976.           hunit = hlayer;
  977.           iunit = ilayer;
  978.           junit = jlayer;
  979.           kunit = klayer;
  980.           while (u != NULL)
  981.            {
  982.             if (informat == 'r') u->oj = rdr(GE,(double) HCODE,'p');
  983.             else u->oj = scale(readchar());
  984.             if (readerror == 1) goto endp;
  985.             if (u->oj <= KCODE) /* do hidden unit codes */
  986.              {
  987.               if (u->oj == HCODE)
  988.                  {if (copyhidden(u,&hunit,2) == 1) goto endp;}
  989.               else if (u->oj == ICODE)
  990.                  {if (copyhidden(u,&iunit,3) == 1) goto endp;}
  991.               else if (u->oj == JCODE)
  992.                  {if (copyhidden(u,&junit,4) == 1) goto endp;}
  993.               else if (copyhidden(u,&kunit,5) == 1) goto endp;
  994.              };
  995.             u = u->next;
  996.            };
  997.           forward();
  998.           printoutunits(last,0);
  999.     endp: readingpattern = 0;
  1000.           break;
  1001. case 'Q': temp = rdr(GT,(double) KCODE,'Q');
  1002.           if (readerror == 0) qmark = temp;
  1003.           break;
  1004. case 'q': finished = 1;
  1005.           break;
  1006. case 'R': restoreweights();
  1007.           break;
  1008. case 'r': if (start == NULL)
  1009.              {
  1010.                printf("the network must be defined first\n");
  1011.                break;
  1012.              };
  1013.           iter = readint(1,MAXINT,'r'); 
  1014.           if (readerror == 1) break;
  1015.           while (ch == ' ') ch = readch();
  1016.           if (ch != '\n' && ch != '*')
  1017.              {
  1018.                bufferptr = bufferptr - 1;
  1019.                itemp = readint(1,MAXINT,'r');
  1020.                if (readerror != 1) run(iter,itemp);
  1021.              }
  1022.           else run(iter,-1);
  1023.           bufferptr = bufferptr - 1;
  1024.           break;
  1025. case 'S': do ch = readch(); while (ch == ' ');
  1026.           bufferptr = bufferptr - 1;
  1027.           if (ch == '\n' || ch == '*') itemp = 0;
  1028.           else
  1029.              {
  1030.                itemp = readint(0,MAXINT,'S');
  1031.                if (readerror == 1) break;
  1032.              };
  1033.           if (itemp == 0) saveweights();
  1034.           else saverate = itemp;
  1035.           break;
  1036. case 's': seed = readint(0,MAXINT,'s');
  1037.           srand(seed);
  1038.           break;
  1039. #ifdef SYMMETRIC
  1040. case 'T': stdthresh = rdr(GT,-unscale(32767),'T');
  1041.           if (readerror == 1) break;
  1042.           u = (UNIT *) last->units;
  1043.           while (u != NULL)
  1044.              {
  1045.                w = (WTNODE *) u->wtlist;
  1046.                while (w->next != NULL) w = w->next;
  1047.                *(w->weight) = stdthresh;
  1048.                u = u->next;
  1049.              };
  1050.           break;
  1051. #endif
  1052. case 't': temp = rdr(GT,0.0,'t');
  1053.           if (readerror == 1) break;
  1054.           else if (temp < scale(1.0)) toler = temp;
  1055.           else printf("tolerance value out of range\n");
  1056.           break;
  1057. #ifndef SYMMETRIC
  1058. case 'W': temp = rdr(GT,0.0,'W');
  1059.           if (readerror == 0)
  1060.              {
  1061.                toosmall = temp;
  1062.                whittle(temp);
  1063.                printf("total weights now: %1d\n",wttotal);
  1064.              };
  1065.           break;
  1066. #endif
  1067. case 'w': layerno = readint(2,nlayers,'w');
  1068.           if (readerror == 1) break;
  1069.           unitno = readint(1,MAXINT,'w');
  1070.           if (readerror == 1) break;
  1071.           u = locateunit(layerno,unitno);
  1072.           if (u != NULL) printweights(u);
  1073.           break;
  1074. case 'x': if (start == NULL)
  1075.              {
  1076.                printf("the network must be defined first\n");
  1077.                break;
  1078.              };
  1079.           itemp = readint(1,MAXINT,'x');
  1080.           if (readerror == 1) break;
  1081.           prevnpats = npats;
  1082.           npats = npats + itemp;
  1083.           findendofpats(start);
  1084.           findendofpats(last);
  1085.           readingpattern = 1;
  1086.           readpats(0,'x');
  1087.           readingpattern = 0;
  1088.           unlearnedpats = npats;
  1089.           break;
  1090. default : texterror();
  1091.           break;
  1092.       };
  1093.     ch = readch();
  1094.     while(ch != '\n') ch = readch();
  1095.   }while (finished == 0);
  1096. }
  1097.  
  1098. void main(argc,argv)
  1099. int argc;
  1100. char *argv[];
  1101. {
  1102.  char *fnamestr, *i;
  1103.  
  1104. printf("Fast Backpropagation Copyright (c) 1990 by Donald R. Tveter\n");
  1105.  
  1106.  setbuf(stdout,NULL);  /* set unbuffered output */
  1107.  if (argc == 1) /* check for file argument, if any */
  1108.     {
  1109.       printf("missing data file name, stdin assumed\n");
  1110.       data = stdin;
  1111.       *datafilename = '\0';
  1112.     }
  1113.  else
  1114.     if ((data = fopen(argv[1],"r")) == (FILE *) NULL)
  1115.        {
  1116.          printf("cannot open: %s\n",argv[1]);
  1117.          exit(1);
  1118.        }
  1119.     else /* make a copy of the file name in a global variable */
  1120.        {
  1121.          fnamestr = argv[1];
  1122.          i = datafilename;
  1123.          while(*fnamestr != '\0') *i++ = *fnamestr++;
  1124.        };
  1125.  init();
  1126.  signal(SIGINT,restartcmdloop); /* restart from interrrupt */
  1127.  cmdloop();
  1128. }
  1129.