home *** CD-ROM | disk | FTP | other *** search
/ Geek Gadgets 1 / ADE-1.bin / ade-dist / octave-1.1.1p1-src.tgz / tar.out / fsf / octave / liboctave / dDiagMatrix.cc < prev    next >
C/C++ Source or Header  |  1996-09-28  |  9KB  |  473 lines

  1. // DiagMatrix manipulations.                             -*- C++ -*-
  2. /*
  3.  
  4. Copyright (C) 1992, 1993, 1994, 1995 John W. Eaton
  5.  
  6. This file is part of Octave.
  7.  
  8. Octave is free software; you can redistribute it and/or modify it
  9. under the terms of the GNU General Public License as published by the
  10. Free Software Foundation; either version 2, or (at your option) any
  11. later version.
  12.  
  13. Octave is distributed in the hope that it will be useful, but WITHOUT
  14. ANY WARRANTY; without even the implied warranty of MERCHANTABILITY or
  15. FITNESS FOR A PARTICULAR PURPOSE.  See the GNU General Public License
  16. for more details.
  17.  
  18. You should have received a copy of the GNU General Public License
  19. along with Octave; see the file COPYING.  If not, write to the Free
  20. Software Foundation, 675 Mass Ave, Cambridge, MA 02139, USA.
  21.  
  22. */
  23.  
  24. #ifdef HAVE_CONFIG_H
  25. #include "config.h"
  26. #endif
  27.  
  28. #include <iostream.h>
  29.  
  30. #include <Complex.h>
  31.  
  32. #include "mx-base.h"
  33. #include "mx-inlines.cc"
  34. #include "lo-error.h"
  35.  
  36. /*
  37.  * Diagonal Matrix class.
  38.  */
  39.  
  40. #define KLUDGE_DIAG_MATRICES
  41. #define TYPE double
  42. #define KL_DMAT_TYPE DiagMatrix
  43. #include "mx-kludge.cc"
  44. #undef KLUDGE_DIAG_MATRICES
  45. #undef TYPE
  46. #undef KL_DMAT_TYPE
  47.  
  48. int
  49. DiagMatrix::operator == (const DiagMatrix& a) const
  50. {
  51.   if (rows () != a.rows () || cols () != a.cols ())
  52.     return 0;
  53.  
  54.   return equal (data (), a.data (), length ());
  55. }
  56.  
  57. int
  58. DiagMatrix::operator != (const DiagMatrix& a) const
  59. {
  60.   return !(*this == a);
  61. }
  62.  
  63. DiagMatrix&
  64. DiagMatrix::fill (double val)
  65. {
  66.   for (int i = 0; i < length (); i++)
  67.     elem (i, i) = val;
  68.   return *this;
  69. }
  70.  
  71. DiagMatrix&
  72. DiagMatrix::fill (double val, int beg, int end)
  73. {
  74.   if (beg < 0 || end >= length () || end < beg)
  75.     {
  76.       (*current_liboctave_error_handler) ("range error for fill");
  77.       return *this;
  78.     }
  79.  
  80.   for (int i = beg; i < end; i++)
  81.     elem (i, i) = val;
  82.  
  83.   return *this;
  84. }
  85.  
  86. DiagMatrix&
  87. DiagMatrix::fill (const ColumnVector& a)
  88. {
  89.   int len = length ();
  90.   if (a.length () != len)
  91.     {
  92.       (*current_liboctave_error_handler) ("range error for fill");
  93.       return *this;
  94.     }
  95.  
  96.   for (int i = 0; i < len; i++)
  97.     elem (i, i) = a.elem (i);
  98.  
  99.   return *this;
  100. }
  101.  
  102. DiagMatrix&
  103. DiagMatrix::fill (const RowVector& a)
  104. {
  105.   int len = length ();
  106.   if (a.length () != len)
  107.     {
  108.       (*current_liboctave_error_handler) ("range error for fill");
  109.       return *this;
  110.     }
  111.  
  112.   for (int i = 0; i < len; i++)
  113.     elem (i, i) = a.elem (i);
  114.  
  115.   return *this;
  116. }
  117.  
  118. DiagMatrix&
  119. DiagMatrix::fill (const ColumnVector& a, int beg)
  120. {
  121.   int a_len = a.length ();
  122.   if (beg < 0 || beg + a_len >= length ())
  123.     {
  124.       (*current_liboctave_error_handler) ("range error for fill");
  125.       return *this;
  126.     }
  127.  
  128.   for (int i = 0; i < a_len; i++)
  129.     elem (i+beg, i+beg) = a.elem (i);
  130.  
  131.   return *this;
  132. }
  133.  
  134. DiagMatrix&
  135. DiagMatrix::fill (const RowVector& a, int beg)
  136. {
  137.   int a_len = a.length ();
  138.   if (beg < 0 || beg + a_len >= length ())
  139.     {
  140.       (*current_liboctave_error_handler) ("range error for fill");
  141.       return *this;
  142.     }
  143.  
  144.   for (int i = 0; i < a_len; i++)
  145.     elem (i+beg, i+beg) = a.elem (i);
  146.  
  147.   return *this;
  148. }
  149.  
  150. DiagMatrix
  151. DiagMatrix::transpose (void) const
  152. {
  153.   return DiagMatrix (dup (data (), length ()), cols (), rows ());
  154. }
  155.  
  156. DiagMatrix
  157. real (const ComplexDiagMatrix& a)
  158. {
  159.   DiagMatrix retval;
  160.   int a_len = a.length ();
  161.   if (a_len > 0)
  162.     retval = DiagMatrix (real_dup (a.data (), a_len), a.rows (),
  163.              a.cols ());
  164.   return retval;
  165. }
  166.  
  167. DiagMatrix
  168. imag (const ComplexDiagMatrix& a)
  169. {
  170.   DiagMatrix retval;
  171.   int a_len = a.length ();
  172.   if (a_len > 0)
  173.     retval = DiagMatrix (imag_dup (a.data (), a_len), a.rows (),
  174.              a.cols ());
  175.   return retval;
  176. }
  177.  
  178. Matrix
  179. DiagMatrix::extract (int r1, int c1, int r2, int c2) const
  180. {
  181.   if (r1 > r2) { int tmp = r1; r1 = r2; r2 = tmp; }
  182.   if (c1 > c2) { int tmp = c1; c1 = c2; c2 = tmp; }
  183.  
  184.   int new_r = r2 - r1 + 1;
  185.   int new_c = c2 - c1 + 1;
  186.  
  187.   Matrix result (new_r, new_c);
  188.  
  189.   for (int j = 0; j < new_c; j++)
  190.     for (int i = 0; i < new_r; i++)
  191.       result.elem (i, j) = elem (r1+i, c1+j);
  192.  
  193.   return result;
  194. }
  195.  
  196. // extract row or column i.
  197.  
  198. RowVector
  199. DiagMatrix::row (int i) const
  200. {
  201.   int nr = rows ();
  202.   int nc = cols ();
  203.   if (i < 0 || i >= nr)
  204.     {
  205.       (*current_liboctave_error_handler) ("invalid row selection");
  206.       return RowVector (); 
  207.     }
  208.  
  209.   RowVector retval (nc, 0.0);
  210.   if (nr <= nc || (nr > nc && i < nc))
  211.     retval.elem (i) = elem (i, i);
  212.  
  213.   return retval;
  214. }
  215.  
  216. RowVector
  217. DiagMatrix::row (char *s) const
  218. {
  219.   if (! s)
  220.     {
  221.       (*current_liboctave_error_handler) ("invalid row selection");
  222.       return RowVector (); 
  223.     }
  224.  
  225.   char c = *s;
  226.   if (c == 'f' || c == 'F')
  227.     return row (0);
  228.   else if (c == 'l' || c == 'L')
  229.     return row (rows () - 1);
  230.   else
  231.     {
  232.       (*current_liboctave_error_handler) ("invalid row selection");
  233.       return RowVector (); 
  234.     }
  235. }
  236.  
  237. ColumnVector
  238. DiagMatrix::column (int i) const
  239. {
  240.   int nr = rows ();
  241.   int nc = cols ();
  242.   if (i < 0 || i >= nc)
  243.     {
  244.       (*current_liboctave_error_handler) ("invalid column selection");
  245.       return ColumnVector (); 
  246.     }
  247.  
  248.   ColumnVector retval (nr, 0.0);
  249.   if (nr >= nc || (nr < nc && i < nr))
  250.     retval.elem (i) = elem (i, i);
  251.  
  252.   return retval;
  253. }
  254.  
  255. ColumnVector
  256. DiagMatrix::column (char *s) const
  257. {
  258.   if (! s)
  259.     {
  260.       (*current_liboctave_error_handler) ("invalid column selection");
  261.       return ColumnVector (); 
  262.     }
  263.  
  264.   char c = *s;
  265.   if (c == 'f' || c == 'F')
  266.     return column (0);
  267.   else if (c == 'l' || c == 'L')
  268.     return column (cols () - 1);
  269.   else
  270.     {
  271.       (*current_liboctave_error_handler) ("invalid column selection");
  272.       return ColumnVector (); 
  273.     }
  274. }
  275.  
  276. DiagMatrix
  277. DiagMatrix::inverse (void) const
  278. {
  279.   int info;
  280.   return inverse (info);
  281. }
  282.  
  283. DiagMatrix
  284. DiagMatrix::inverse (int &info) const
  285. {
  286.   int nr = rows ();
  287.   int nc = cols ();
  288.   int len = length ();
  289.   if (nr != nc)
  290.     {
  291.       (*current_liboctave_error_handler) ("inverse requires square matrix");
  292.       return DiagMatrix ();
  293.     }
  294.  
  295.   info = 0;
  296.   double *tmp_data = dup (data (), len);
  297.   for (int i = 0; i < len; i++)
  298.     {
  299.       if (elem (i, i) == 0.0)
  300.     {
  301.       info = -1;
  302.       copy (tmp_data, data (), len); // Restore contents.
  303.       break;
  304.     }
  305.       else
  306.     {
  307.       tmp_data[i] = 1.0 / elem (i, i);
  308.     }
  309.     }
  310.  
  311.   return DiagMatrix (tmp_data, nr, nc);
  312. }
  313.  
  314. // diagonal matrix by diagonal matrix -> diagonal matrix operations
  315.  
  316. DiagMatrix&
  317. DiagMatrix::operator += (const DiagMatrix& a)
  318. {
  319.   int nr = rows ();
  320.   int nc = cols ();
  321.   if (nr != a.rows () || nc != a.cols ())
  322.     {
  323.       (*current_liboctave_error_handler)
  324.     ("nonconformant matrix += operation attempted");
  325.       return *this;
  326.     }
  327.  
  328.   if (nc == 0 || nr == 0)
  329.     return *this;
  330.  
  331.   double *d = fortran_vec (); // Ensures only one reference to my privates!
  332.  
  333.   add2 (d, a.data (), length ());
  334.   return *this;
  335. }
  336.  
  337. DiagMatrix&
  338. DiagMatrix::operator -= (const DiagMatrix& a)
  339. {
  340.   int nr = rows ();
  341.   int nc = cols ();
  342.   if (nr != a.rows () || nc != a.cols ())
  343.     {
  344.       (*current_liboctave_error_handler)
  345.     ("nonconformant matrix -= operation attempted");
  346.       return *this;
  347.     }
  348.  
  349.   if (nr == 0 || nc == 0)
  350.     return *this;
  351.  
  352.   double *d = fortran_vec (); // Ensures only one reference to my privates!
  353.  
  354.   subtract2 (d, a.data (), length ());
  355.   return *this;
  356. }
  357.  
  358. // diagonal matrix by diagonal matrix -> diagonal matrix operations
  359.  
  360. DiagMatrix
  361. operator * (const DiagMatrix& a, const DiagMatrix& b)
  362. {
  363.   int nr_a = a.rows ();
  364.   int nc_a = a.cols ();
  365.   int nr_b = b.rows ();
  366.   int nc_b = b.cols ();
  367.   if (nc_a != nr_b)
  368.     {
  369.       (*current_liboctave_error_handler)
  370.         ("nonconformant matrix multiplication attempted");
  371.       return DiagMatrix ();
  372.     }
  373.  
  374.   if (nr_a == 0 || nc_a == 0 || nc_b == 0)
  375.     return DiagMatrix (nr_a, nc_a, 0.0);
  376.  
  377.   DiagMatrix c (nr_a, nc_b);
  378.  
  379.   int len = nr_a < nc_b ? nr_a : nc_b;
  380.  
  381.   for (int i = 0; i < len; i++)
  382.     {
  383.       double a_element = a.elem (i, i);
  384.       double b_element = b.elem (i, i);
  385.  
  386.       if (a_element == 0.0 || b_element == 0.0)
  387.         c.elem (i, i) = 0.0;
  388.       else if (a_element == 1.0)
  389.         c.elem (i, i) = b_element;
  390.       else if (b_element == 1.0)
  391.         c.elem (i, i) = a_element;
  392.       else
  393.         c.elem (i, i) = a_element * b_element;
  394.     }
  395.  
  396.   return c;
  397. }
  398.  
  399. // other operations
  400.  
  401. ColumnVector
  402. DiagMatrix::diag (void) const
  403. {
  404.   return diag (0);
  405. }
  406.  
  407. // Could be optimized...
  408.  
  409. ColumnVector
  410. DiagMatrix::diag (int k) const
  411. {
  412.   int nnr = rows ();
  413.   int nnc = cols ();
  414.   if (k > 0)
  415.     nnc -= k;
  416.   else if (k < 0)
  417.     nnr += k;
  418.  
  419.   ColumnVector d;
  420.  
  421.   if (nnr > 0 && nnc > 0)
  422.     {
  423.       int ndiag = (nnr < nnc) ? nnr : nnc;
  424.  
  425.       d.resize (ndiag);
  426.  
  427.       if (k > 0)
  428.     {
  429.       for (int i = 0; i < ndiag; i++)
  430.         d.elem (i) = elem (i, i+k);
  431.     }
  432.       else if ( k < 0)
  433.     {
  434.       for (int i = 0; i < ndiag; i++)
  435.         d.elem (i) = elem (i-k, i);
  436.     }
  437.       else
  438.     {
  439.       for (int i = 0; i < ndiag; i++)
  440.         d.elem (i) = elem (i, i);
  441.     }
  442.     }
  443.   else
  444.     cerr << "diag: requested diagonal out of range\n";
  445.  
  446.   return d;
  447. }
  448.  
  449. ostream&
  450. operator << (ostream& os, const DiagMatrix& a)
  451. {
  452. //  int field_width = os.precision () + 7;
  453.   for (int i = 0; i < a.rows (); i++)
  454.     {
  455.       for (int j = 0; j < a.cols (); j++)
  456.     {
  457.       if (i == j)
  458.         os << " " /* setw (field_width) */ << a.elem (i, i);
  459.       else
  460.         os << " " /* setw (field_width) */ << 0.0;
  461.     }
  462.       os << "\n";
  463.     }
  464.   return os;
  465. }
  466.  
  467. /*
  468. ;;; Local Variables: ***
  469. ;;; mode: C++ ***
  470. ;;; page-delimiter: "^/\\*" ***
  471. ;;; End: ***
  472. */
  473.