Main Page   Namespace List   Compound List   File List   Compound Members

# jama_svd.h

Go to the documentation of this file.
00001 #ifndef JAMA_SVD_H
00002 #define JAMA_SVD_H
00003
00004
00005 #include "tnt_array1d.h"
00006 #include "tnt_array1d_utils.h"
00007 #include "tnt_array2d.h"
00008 #include "tnt_array2d_utils.h"
00009 #include "tnt_math_utils.h"
00010
00011
00012 using namespace TNT;
00013
00014 namespace JAMA
00015 {
00033 template <class Real>
00034 class SVD
00035 {
00036
00037
00038         Array2D<Real> U, V;
00039         Array1D<Real> s;
00040         int m, n;
00041
00042   public:
00043
00044
00045    SVD (const Array2D<Real> &Arg) {
00046
00047
00048       m = Arg.dim1();
00049       n = Arg.dim2();
00050       int nu = min(m,n);
00051       s = Array1D<Real>(min(m+1,n));
00052       U = Array2D<Real>(m, nu, Real(0));
00053       V = Array2D<Real>(n,n);
00054       Array1D<Real> e(n);
00055       Array1D<Real> work(m);
00056           Array2D<Real> A(Arg.copy());
00057       int wantu = 1;                                    /* boolean */
00058       int wantv = 1;                                    /* boolean */
00059           int i=0, j=0, k=0;
00060
00061       // Reduce A to bidiagonal form, storing the diagonal elements
00062       // in s and the super-diagonal elements in e.
00063
00064       int nct = min(m-1,n);
00065       int nrt = max(0,min(n-2,m));
00066       for (k = 0; k < max(nct,nrt); k++) {
00067          if (k < nct) {
00068
00069             // Compute the transformation for the k-th column and
00070             // place the k-th diagonal in s[k].
00071             // Compute 2-norm of k-th column without under/overflow.
00072             s[k] = 0;
00073             for (i = k; i < m; i++) {
00074                s[k] = hypot(s[k],A[i][k]);
00075             }
00076             if (s[k] != 0.0) {
00077                if (A[k][k] < 0.0) {
00078                   s[k] = -s[k];
00079                }
00080                for (i = k; i < m; i++) {
00081                   A[i][k] /= s[k];
00082                }
00083                A[k][k] += 1.0;
00084             }
00085             s[k] = -s[k];
00086          }
00087          for (j = k+1; j < n; j++) {
00088             if ((k < nct) && (s[k] != 0.0))  {
00089
00090             // Apply the transformation.
00091
00092                double t = 0;
00093                for (i = k; i < m; i++) {
00094                   t += A[i][k]*A[i][j];
00095                }
00096                t = -t/A[k][k];
00097                for (i = k; i < m; i++) {
00098                   A[i][j] += t*A[i][k];
00099                }
00100             }
00101
00102             // Place the k-th row of A into e for the
00103             // subsequent calculation of the row transformation.
00104
00105             e[j] = A[k][j];
00106          }
00107          if (wantu & (k < nct)) {
00108
00109             // Place the transformation in U for subsequent back
00110             // multiplication.
00111
00112             for (i = k; i < m; i++) {
00113                U[i][k] = A[i][k];
00114             }
00115          }
00116          if (k < nrt) {
00117
00118             // Compute the k-th row transformation and place the
00119             // k-th super-diagonal in e[k].
00120             // Compute 2-norm without under/overflow.
00121             e[k] = 0;
00122             for (i = k+1; i < n; i++) {
00123                e[k] = hypot(e[k],e[i]);
00124             }
00125             if (e[k] != 0.0) {
00126                if (e[k+1] < 0.0) {
00127                   e[k] = -e[k];
00128                }
00129                for (i = k+1; i < n; i++) {
00130                   e[i] /= e[k];
00131                }
00132                e[k+1] += 1.0;
00133             }
00134             e[k] = -e[k];
00135             if ((k+1 < m) & (e[k] != 0.0)) {
00136
00137             // Apply the transformation.
00138
00139                for (i = k+1; i < m; i++) {
00140                   work[i] = 0.0;
00141                }
00142                for (j = k+1; j < n; j++) {
00143                   for (i = k+1; i < m; i++) {
00144                      work[i] += e[j]*A[i][j];
00145                   }
00146                }
00147                for (j = k+1; j < n; j++) {
00148                   double t = -e[j]/e[k+1];
00149                   for (i = k+1; i < m; i++) {
00150                      A[i][j] += t*work[i];
00151                   }
00152                }
00153             }
00154             if (wantv) {
00155
00156             // Place the transformation in V for subsequent
00157             // back multiplication.
00158
00159                for (i = k+1; i < n; i++) {
00160                   V[i][k] = e[i];
00161                }
00162             }
00163          }
00164       }
00165
00166       // Set up the final bidiagonal matrix or order p.
00167
00168       int p = min(n,m+1);
00169       if (nct < n) {
00170          s[nct] = A[nct][nct];
00171       }
00172       if (m < p) {
00173          s[p-1] = 0.0;
00174       }
00175       if (nrt+1 < p) {
00176          e[nrt] = A[nrt][p-1];
00177       }
00178       e[p-1] = 0.0;
00179
00180       // If required, generate U.
00181
00182       if (wantu) {
00183          for (j = nct; j < nu; j++) {
00184             for (i = 0; i < m; i++) {
00185                U[i][j] = 0.0;
00186             }
00187             U[j][j] = 1.0;
00188          }
00189          for (k = nct-1; k >= 0; k--) {
00190             if (s[k] != 0.0) {
00191                for (j = k+1; j < nu; j++) {
00192                   double t = 0;
00193                   for (i = k; i < m; i++) {
00194                      t += U[i][k]*U[i][j];
00195                   }
00196                   t = -t/U[k][k];
00197                   for (i = k; i < m; i++) {
00198                      U[i][j] += t*U[i][k];
00199                   }
00200                }
00201                for (i = k; i < m; i++ ) {
00202                   U[i][k] = -U[i][k];
00203                }
00204                U[k][k] = 1.0 + U[k][k];
00205                for (i = 0; i < k-1; i++) {
00206                   U[i][k] = 0.0;
00207                }
00208             } else {
00209                for (i = 0; i < m; i++) {
00210                   U[i][k] = 0.0;
00211                }
00212                U[k][k] = 1.0;
00213             }
00214          }
00215       }
00216
00217       // If required, generate V.
00218
00219       if (wantv) {
00220          for (k = n-1; k >= 0; k--) {
00221             if ((k < nrt) & (e[k] != 0.0)) {
00222                for (j = k+1; j < nu; j++) {
00223                   double t = 0;
00224                   for (i = k+1; i < n; i++) {
00225                      t += V[i][k]*V[i][j];
00226                   }
00227                   t = -t/V[k+1][k];
00228                   for (i = k+1; i < n; i++) {
00229                      V[i][j] += t*V[i][k];
00230                   }
00231                }
00232             }
00233             for (i = 0; i < n; i++) {
00234                V[i][k] = 0.0;
00235             }
00236             V[k][k] = 1.0;
00237          }
00238       }
00239
00240       // Main iteration loop for the singular values.
00241
00242       int pp = p-1;
00243       int iter = 0;
00244       double eps = pow(2.0,-52.0);
00245       while (p > 0) {
00246          int k=0;
00247                  int kase=0;
00248
00249          // Here is where a test for too many iterations would go.
00250
00251          // This section of the program inspects for
00252          // negligible elements in the s and e arrays.  On
00253          // completion the variables kase and k are set as follows.
00254
00255          // kase = 1     if s(p) and e[k-1] are negligible and k<p
00256          // kase = 2     if s(k) is negligible and k<p
00257          // kase = 3     if e[k-1] is negligible, k<p, and
00258          //              s(k), ..., s(p) are not negligible (qr step).
00259          // kase = 4     if e(p-1) is negligible (convergence).
00260
00261          for (k = p-2; k >= -1; k--) {
00262             if (k == -1) {
00263                break;
00264             }
00265             if (abs(e[k]) <= eps*(abs(s[k]) + abs(s[k+1]))) {
00266                e[k] = 0.0;
00267                break;
00268             }
00269          }
00270          if (k == p-2) {
00271             kase = 4;
00272          } else {
00273             int ks;
00274             for (ks = p-1; ks >= k; ks--) {
00275                if (ks == k) {
00276                   break;
00277                }
00278                double t = (ks != p ? abs(e[ks]) : 0.) +
00279                           (ks != k+1 ? abs(e[ks-1]) : 0.);
00280                if (abs(s[ks]) <= eps*t)  {
00281                   s[ks] = 0.0;
00282                   break;
00283                }
00284             }
00285             if (ks == k) {
00286                kase = 3;
00287             } else if (ks == p-1) {
00288                kase = 1;
00289             } else {
00290                kase = 2;
00291                k = ks;
00292             }
00293          }
00294          k++;
00295
00296          // Perform the task indicated by kase.
00297
00298          switch (kase) {
00299
00300             // Deflate negligible s(p).
00301
00302             case 1: {
00303                double f = e[p-2];
00304                e[p-2] = 0.0;
00305                for (j = p-2; j >= k; j--) {
00306                   double t = hypot(s[j],f);
00307                   double cs = s[j]/t;
00308                   double sn = f/t;
00309                   s[j] = t;
00310                   if (j != k) {
00311                      f = -sn*e[j-1];
00312                      e[j-1] = cs*e[j-1];
00313                   }
00314                   if (wantv) {
00315                      for (i = 0; i < n; i++) {
00316                         t = cs*V[i][j] + sn*V[i][p-1];
00317                         V[i][p-1] = -sn*V[i][j] + cs*V[i][p-1];
00318                         V[i][j] = t;
00319                      }
00320                   }
00321                }
00322             }
00323             break;
00324
00325             // Split at negligible s(k).
00326
00327             case 2: {
00328                double f = e[k-1];
00329                e[k-1] = 0.0;
00330                for (j = k; j < p; j++) {
00331                   double t = hypot(s[j],f);
00332                   double cs = s[j]/t;
00333                   double sn = f/t;
00334                   s[j] = t;
00335                   f = -sn*e[j];
00336                   e[j] = cs*e[j];
00337                   if (wantu) {
00338                      for (i = 0; i < m; i++) {
00339                         t = cs*U[i][j] + sn*U[i][k-1];
00340                         U[i][k-1] = -sn*U[i][j] + cs*U[i][k-1];
00341                         U[i][j] = t;
00342                      }
00343                   }
00344                }
00345             }
00346             break;
00347
00348             // Perform one qr step.
00349
00350             case 3: {
00351
00352                // Calculate the shift.
00353
00354                double scale = max(max(max(max(
00355                        abs(s[p-1]),abs(s[p-2])),abs(e[p-2])),
00356                        abs(s[k])),abs(e[k]));
00357                double sp = s[p-1]/scale;
00358                double spm1 = s[p-2]/scale;
00359                double epm1 = e[p-2]/scale;
00360                double sk = s[k]/scale;
00361                double ek = e[k]/scale;
00362                double b = ((spm1 + sp)*(spm1 - sp) + epm1*epm1)/2.0;
00363                double c = (sp*epm1)*(sp*epm1);
00364                double shift = 0.0;
00365                if ((b != 0.0) | (c != 0.0)) {
00366                   shift = sqrt(b*b + c);
00367                   if (b < 0.0) {
00368                      shift = -shift;
00369                   }
00370                   shift = c/(b + shift);
00371                }
00372                double f = (sk + sp)*(sk - sp) + shift;
00373                double g = sk*ek;
00374
00375                // Chase zeros.
00376
00377                for (j = k; j < p-1; j++) {
00378                   double t = hypot(f,g);
00379                   double cs = f/t;
00380                   double sn = g/t;
00381                   if (j != k) {
00382                      e[j-1] = t;
00383                   }
00384                   f = cs*s[j] + sn*e[j];
00385                   e[j] = cs*e[j] - sn*s[j];
00386                   g = sn*s[j+1];
00387                   s[j+1] = cs*s[j+1];
00388                   if (wantv) {
00389                      for (i = 0; i < n; i++) {
00390                         t = cs*V[i][j] + sn*V[i][j+1];
00391                         V[i][j+1] = -sn*V[i][j] + cs*V[i][j+1];
00392                         V[i][j] = t;
00393                      }
00394                   }
00395                   t = hypot(f,g);
00396                   cs = f/t;
00397                   sn = g/t;
00398                   s[j] = t;
00399                   f = cs*e[j] + sn*s[j+1];
00400                   s[j+1] = -sn*e[j] + cs*s[j+1];
00401                   g = sn*e[j+1];
00402                   e[j+1] = cs*e[j+1];
00403                   if (wantu && (j < m-1)) {
00404                      for (i = 0; i < m; i++) {
00405                         t = cs*U[i][j] + sn*U[i][j+1];
00406                         U[i][j+1] = -sn*U[i][j] + cs*U[i][j+1];
00407                         U[i][j] = t;
00408                      }
00409                   }
00410                }
00411                e[p-2] = f;
00412                iter = iter + 1;
00413             }
00414             break;
00415
00416             // Convergence.
00417
00418             case 4: {
00419
00420                // Make the singular values positive.
00421
00422                if (s[k] <= 0.0) {
00423                   s[k] = (s[k] < 0.0 ? -s[k] : 0.0);
00424                   if (wantv) {
00425                      for (i = 0; i <= pp; i++) {
00426                         V[i][k] = -V[i][k];
00427                      }
00428                   }
00429                }
00430
00431                // Order the singular values.
00432
00433                while (k < pp) {
00434                   if (s[k] >= s[k+1]) {
00435                      break;
00436                   }
00437                   double t = s[k];
00438                   s[k] = s[k+1];
00439                   s[k+1] = t;
00440                   if (wantv && (k < n-1)) {
00441                      for (i = 0; i < n; i++) {
00442                         t = V[i][k+1]; V[i][k+1] = V[i][k]; V[i][k] = t;
00443                      }
00444                   }
00445                   if (wantu && (k < m-1)) {
00446                      for (i = 0; i < m; i++) {
00447                         t = U[i][k+1]; U[i][k+1] = U[i][k]; U[i][k] = t;
00448                      }
00449                   }
00450                   k++;
00451                }
00452                iter = 0;
00453                p--;
00454             }
00455             break;
00456          }
00457       }
00458    }
00459
00460
00461    void getU (Array2D<Real> &A)
00462    {
00463           int minm = min(m+1,n);
00464
00465           A = Array2D<Real>(m, minm);
00466
00467           for (int i=0; i<m; i++)
00468                 for (int j=0; j<minm; j++)
00469                         A[i][j] = U[i][j];
00470
00471    }
00472
00473    /* Return the right singular vectors */
00474
00475    void getV (Array2D<Real> &A)
00476    {
00477           A = V;
00478    }
00479
00481
00482    void getSingularValues (Array1D<Real> &x)
00483    {
00484       x = s;
00485    }
00486
00490
00491    void getS (Array2D<Real> &A) {
00492           A = Array2D<Real>(n,n);
00493       for (int i = 0; i < n; i++) {
00494          for (int j = 0; j < n; j++) {
00495             A[i][j] = 0.0;
00496          }
00497          A[i][i] = s[i];
00498       }
00499    }
00500
00502
00503    double norm2 () {
00504       return s[0];
00505    }
00506
00508
00509    double cond () {
00510       return s[0]/s[min(m,n)-1];
00511    }
00512
00516
00517    int rank ()
00518    {
00519       double eps = pow(2.0,-52.0);
00520       double tol = max(m,n)*s[0]*eps;
00521       int r = 0;
00522       for (int i = 0; i < s.dim(); i++) {
00523          if (s[i] > tol) {
00524             r++;
00525          }
00526       }
00527       return r;
00528    }
00529 };
00530
00531 }
00532 #endif
00533 // JAMA_SVD_H

Generated at Mon Jan 20 07:47:17 2003 for JAMA/C++ by Doxygen 1.2.5, written by Dimitri van Heesch, © 1997-2001