aboutsummaryrefslogtreecommitdiffhomepage
diff options
context:
space:
mode:
authorGravatar A. Unique TensorFlower <nobody@tensorflow.org>2016-02-11 15:14:11 -0800
committerGravatar TensorFlower Gardener <gardener@tensorflow.org>2016-02-11 17:14:08 -0800
commit80e682562a202d5ed57c871aa76f0faec4061590 (patch)
tree1cfdc58299fe61b7a9e6648150d6e525acd7b5cc
parent3c933761e40101952fdcc0896bb3de1c5654192b (diff)
Switch the slow path in matrix_solve_ls to using Eigen::CompleteOrthogonalDecomposition (COD), which I recently contributed to Eigen in https://bitbucket.org/eigen/eigen/pull-requests/163/implement-complete-orthogonal/diff
The advantage of COD over column pivoted QR is that it is able to compute the minimum-norm solution when the matrix is rank-deficient, which is usually the desired behavior and makes it consistent with the fast path. Change: 114483303
-rw-r--r--WORKSPACE4
-rw-r--r--eigen.BUILD3
-rw-r--r--tensorflow/core/kernels/matrix_solve_ls_op.cc34
-rw-r--r--tensorflow/core/ops/linalg_ops.cc20
-rw-r--r--tensorflow/core/ops/ops.pbtxt4
-rw-r--r--tensorflow/python/ops/linalg_ops.py9
-rw-r--r--tensorflow/python/platform/__init__.py20
-rw-r--r--third_party/eigen3/Eigen/Cholesky2
-rw-r--r--third_party/eigen3/Eigen/Core2
-rw-r--r--third_party/eigen3/Eigen/Eigenvalues3
-rw-r--r--third_party/eigen3/Eigen/LU2
-rw-r--r--third_party/eigen3/Eigen/QR2
-rw-r--r--third_party/eigen3/unsupported/Eigen/CXX11/Tensor2
13 files changed, 34 insertions, 73 deletions
diff --git a/WORKSPACE b/WORKSPACE
index a52fdf8345..f0fb83ce97 100644
--- a/WORKSPACE
+++ b/WORKSPACE
@@ -21,8 +21,8 @@ new_http_archive(
new_http_archive(
name = "eigen_archive",
- url = "https://bitbucket.org/eigen/eigen/get/726c779.tar.gz",
- sha256 = "30e0c5d84cfefc6a0bf7ae1e682b22788b5b2e408e7db7d9ea2d2aa9f70a72a9",
+ url = "https://bitbucket.org/eigen/eigen/get/0b9ab889fac2.tar.gz",
+ sha256 = "b9cff4ca8eb4889b1f52316b9f7362eec177898323c14d60d9fdb5ad2649c301",
build_file = "eigen.BUILD",
)
diff --git a/eigen.BUILD b/eigen.BUILD
index 084689c6f4..c8ce8191a9 100644
--- a/eigen.BUILD
+++ b/eigen.BUILD
@@ -1,6 +1,6 @@
package(default_visibility = ["//visibility:public"])
-archive_dir = "eigen-eigen-726c779797e8"
+archive_dir = "eigen-eigen-0b9ab889fac2"
cc_library(
name = "eigen",
@@ -8,4 +8,3 @@ cc_library(
includes = [ archive_dir ],
visibility = ["//visibility:public"],
)
-
diff --git a/tensorflow/core/kernels/matrix_solve_ls_op.cc b/tensorflow/core/kernels/matrix_solve_ls_op.cc
index c69c93fcb1..9fba14a138 100644
--- a/tensorflow/core/kernels/matrix_solve_ls_op.cc
+++ b/tensorflow/core/kernels/matrix_solve_ls_op.cc
@@ -139,31 +139,17 @@ class MatrixSolveLsOp
*output = matrix.transpose() * llt.solve(rhs);
}
} else {
- // Use a rank revealing factorization (QR with column pivoting).
+ // Use complete orthogonal decomposition which is backwards stable and
+ // will compute the minimum-norm solution for rank-deficient matrices.
+ // This is 6-7 times slower than the fast path.
//
- // NOTICE: Currently, Eigen's implementation of column pivoted Householder
- // QR has a few deficiencies:
- // 1. It does not implement the post-processing step to compute a
- // complete orthogonal factorization. This means that it does not
- // return a minimum-norm solution for underdetermined and
- // rank-deficient matrices. We could use the Eigen SVD instead, but
- // the currently available JacobiSVD is so slow that is it is
- // essentially useless (~100x slower than QR).
- // 2. The implementation is not blocked, so for matrics that do not fit
- // in cache, it is significantly slower than the equivalent blocked
- // LAPACK routine xGEQP3 (e.g. Eigen is ~3x slower for 4k x 4k
- // matrices). See http://www.netlib.org/lapack/lawnspdf/lawn114.pdf
- // 3. The implementation uses the numerically unstable norm downdating
- // formula from the original 1965 Businger & Golub paper. This can
- // lead to incorrect rank determination for graded matrices. I
- // (rmlarsen@) have a patch to bring this up to date by implementing
- // the robust formula from
- // http://www.netlib.org/lapack/lawnspdf/lawn176.pdf
- //
- // TODO(rmlarsen): a) Contribute 1. and 2. to Eigen.
- // b) Evaluate new divide-and-conquer SVD in Eigen when
- // it becomes available & robust.
- *output = matrix.colPivHouseholderQr().solve(rhs);
+ // TODO(rmlarsen): The implementation of
+ // Eigen::CompleteOrthogonalDecomposition is not blocked, so for
+ // matrices that do not fit in cache, it is significantly slower than
+ // the equivalent blocked LAPACK routine xGELSY (e.g. Eigen is ~3x
+ // slower for 4k x 4k matrices).
+ // See http://www.netlib.org/lapack/lawnspdf/lawn114.pdf
+ *output = matrix.completeOrthogonalDecomposition().solve(rhs);
}
}
diff --git a/tensorflow/core/ops/linalg_ops.cc b/tensorflow/core/ops/linalg_ops.cc
index 952a592d51..834a898374 100644
--- a/tensorflow/core/ops/linalg_ops.cc
+++ b/tensorflow/core/ops/linalg_ops.cc
@@ -269,11 +269,10 @@ numerically full rank and has a condition number
\\(\mathrm{cond}(A) \lt \frac{1}{\sqrt{\epsilon_{mach}}}\\)
or \\(\lambda\\) is sufficiently large.
-If `fast` is `False` then the solution is computed using the rank revealing QR
-decomposition with column pivoting. This will always compute a least-squares
-solution that minimizes the residual norm \\(||A X - B||_F^2 \\), even when
-\\( A \\) is rank deficient or ill-conditioned. Notice: The current version
-does not compute a minimum norm solution. If `fast` is `False` then
+If `fast` is `False` an algorithm based on the numerically robust complete
+orthogonal decomposition is used. This computes the minimum-norm
+least-squares solution, even when \\(A\\) is rank deficient. This path is
+typically 6-7 times slower than the fast path. If `fast` is `False` then
`l2_regularizer` is ignored.
matrix: Shape is `[M, N]`.
@@ -319,12 +318,11 @@ minimum-norm solution to the under-determined linear system, i.e.
\\(\mathrm{cond}(A) \lt \frac{1}{\sqrt{\epsilon_{mach}}}\\) or\\(\lambda\\) is
sufficiently large.
-If `fast` is `False` then the solution is computed using the rank revealing QR
-decomposition with column pivoting. This will always compute a least-squares
-solution that minimizes the residual norm \\(||A X - B||_F^2\\), even when
-\\(A\\) is rank deficient or ill-conditioned. Notice: The current version does
-not compute a minimum norm solution. If `fast` is `False` then `l2_regularizer`
-is ignored.
+If `fast` is `False` an algorithm based on the numerically robust complete
+orthogonal decomposition is used. This computes the minimum-norm
+least-squares solution, even when \\(A\\) is rank deficient. This path is
+typically 6-7 times slower than the fast path. If `fast` is `False` then
+`l2_regularizer` is ignored.
matrix: Shape is `[..., M, N]`.
rhs: Shape is `[..., M, K]`.
diff --git a/tensorflow/core/ops/ops.pbtxt b/tensorflow/core/ops/ops.pbtxt
index 0fc5a225e0..874e0dc2c3 100644
--- a/tensorflow/core/ops/ops.pbtxt
+++ b/tensorflow/core/ops/ops.pbtxt
@@ -1170,7 +1170,7 @@ op {
}
}
summary: "Solves multiple linear least-squares problems."
- description: "`matrix` is a tensor of shape `[..., M, N]` whose inner-most 2 dimensions\nform square matrices. Rhs is a tensor of shape `[..., M, K]`. The output\nis a tensor shape `[..., N, K]` where each output matrix solves each of\nthe equations matrix[..., :, :] * output[..., :, :] = rhs[..., :, :] in the\nleast squares sense.\n\nBelow we will use the following notation for each pair of\nmatrix and right-hand sides in the batch:\n\n`matrix`=\\\\(A \\in \\Re^{m \\times n}\\\\),\n`rhs`=\\\\(B \\in \\Re^{m \\times k}\\\\),\n`output`=\\\\(X \\in \\Re^{n \\times k}\\\\),\n`l2_regularizer`=\\\\(\\lambda\\\\).\n\nIf `fast` is `True`, then the solution is computed by solving the normal\nequations using Cholesky decomposition. Specifically, if \\\\(m \\ge n\\\\) then\n\\\\(X = (A^T A + \\lambda I)^{-1} A^T B\\\\), which solves the least-squares\nproblem \\\\(X = \\mathrm{argmin}_{Z \\in \\Re^{n \\times k}} ||A Z - B||_F^2 +\n\\lambda ||Z||_F^2\\\\). If \\\\(m \\lt n\\\\) then `output` is computed as\n\\\\(X = A^T (A A^T + \\lambda I)^{-1} B\\\\), which (for \\\\(\\lambda = 0\\\\)) is the\nminimum-norm solution to the under-determined linear system, i.e.\n\\\\(X = \\mathrm{argmin}_{Z \\in \\Re^{n \\times k}} ||Z||_F^2 \\\\), subject to\n\\\\(A Z = B\\\\). Notice that the fast path is only numerically stable when\n\\\\(A\\\\) is numerically full rank and has a condition number\n\\\\(\\mathrm{cond}(A) \\lt \\frac{1}{\\sqrt{\\epsilon_{mach}}}\\\\) or\\\\(\\lambda\\\\) is\nsufficiently large.\n\nIf `fast` is `False` then the solution is computed using the rank revealing QR\ndecomposition with column pivoting. This will always compute a least-squares\nsolution that minimizes the residual norm \\\\(||A X - B||_F^2\\\\), even when\n\\\\(A\\\\) is rank deficient or ill-conditioned. Notice: The current version does\nnot compute a minimum norm solution. If `fast` is `False` then `l2_regularizer`\nis ignored."
+ description: "`matrix` is a tensor of shape `[..., M, N]` whose inner-most 2 dimensions\nform square matrices. Rhs is a tensor of shape `[..., M, K]`. The output\nis a tensor shape `[..., N, K]` where each output matrix solves each of\nthe equations matrix[..., :, :] * output[..., :, :] = rhs[..., :, :] in the\nleast squares sense.\n\nBelow we will use the following notation for each pair of\nmatrix and right-hand sides in the batch:\n\n`matrix`=\\\\(A \\in \\Re^{m \\times n}\\\\),\n`rhs`=\\\\(B \\in \\Re^{m \\times k}\\\\),\n`output`=\\\\(X \\in \\Re^{n \\times k}\\\\),\n`l2_regularizer`=\\\\(\\lambda\\\\).\n\nIf `fast` is `True`, then the solution is computed by solving the normal\nequations using Cholesky decomposition. Specifically, if \\\\(m \\ge n\\\\) then\n\\\\(X = (A^T A + \\lambda I)^{-1} A^T B\\\\), which solves the least-squares\nproblem \\\\(X = \\mathrm{argmin}_{Z \\in \\Re^{n \\times k}} ||A Z - B||_F^2 +\n\\lambda ||Z||_F^2\\\\). If \\\\(m \\lt n\\\\) then `output` is computed as\n\\\\(X = A^T (A A^T + \\lambda I)^{-1} B\\\\), which (for \\\\(\\lambda = 0\\\\)) is the\nminimum-norm solution to the under-determined linear system, i.e.\n\\\\(X = \\mathrm{argmin}_{Z \\in \\Re^{n \\times k}} ||Z||_F^2 \\\\), subject to\n\\\\(A Z = B\\\\). Notice that the fast path is only numerically stable when\n\\\\(A\\\\) is numerically full rank and has a condition number\n\\\\(\\mathrm{cond}(A) \\lt \\frac{1}{\\sqrt{\\epsilon_{mach}}}\\\\) or\\\\(\\lambda\\\\) is\nsufficiently large.\n\nIf `fast` is `False` an algorithm based on the numerically robust complete\northogonal decomposition is used. This computes the minimum-norm\nleast-squares solution, even when \\\\(A\\\\) is rank deficient. This path is\ntypically 6-7 times slower than the fast path. If `fast` is `False` then\n`l2_regularizer` is ignored."
}
op {
name: "BatchMatrixTriangularSolve"
@@ -4344,7 +4344,7 @@ op {
}
}
summary: "Solves a linear least-squares problem."
- description: "Below we will use the following notation\n`matrix`=\\\\(A \\in \\Re^{m \\times n}\\\\),\n`rhs`=\\\\(B \\in \\Re^{m \\times k}\\\\),\n`output`=\\\\(X \\in \\Re^{n \\times k}\\\\),\n`l2_regularizer`=\\\\(\\lambda\\\\).\n\nIf `fast` is `True`, then the solution is computed by solving the normal\nequations using Cholesky decomposition. Specifically, if \\\\(m \\ge n\\\\) then\n\\\\(X = (A^T A + \\lambda I)^{-1} A^T B\\\\), which solves the least-squares\nproblem \\\\(X = \\mathrm{argmin}_{Z \\in \\Re^{n \\times k}} ||A Z - B||_F^2 +\n\\lambda ||Z||_F^2\\\\). If \\\\(m \\lt n\\\\) then `output` is computed as\n\\\\(X = A^T (A A^T + \\lambda I)^{-1} B\\\\),\nwhich (for \\\\(\\lambda = 0\\\\)) is the minimum-norm solution to the\nunder-determined linear system, i.e.\n\\\\(X = \\mathrm{argmin}_{Z \\in \\Re^{n \\times k}} ||Z||_F^2 \\\\),\nsubject to \\\\(A Z = B\\\\).\nNotice that the fast path is only numerically stable when \\\\(A\\\\) is\nnumerically full rank and has a condition number\n\\\\(\\mathrm{cond}(A) \\lt \\frac{1}{\\sqrt{\\epsilon_{mach}}}\\\\)\nor \\\\(\\lambda\\\\) is sufficiently large.\n\nIf `fast` is `False` then the solution is computed using the rank revealing QR\ndecomposition with column pivoting. This will always compute a least-squares\nsolution that minimizes the residual norm \\\\(||A X - B||_F^2 \\\\), even when\n\\\\( A \\\\) is rank deficient or ill-conditioned. Notice: The current version\ndoes not compute a minimum norm solution. If `fast` is `False` then\n`l2_regularizer` is ignored."
+ description: "Below we will use the following notation\n`matrix`=\\\\(A \\in \\Re^{m \\times n}\\\\),\n`rhs`=\\\\(B \\in \\Re^{m \\times k}\\\\),\n`output`=\\\\(X \\in \\Re^{n \\times k}\\\\),\n`l2_regularizer`=\\\\(\\lambda\\\\).\n\nIf `fast` is `True`, then the solution is computed by solving the normal\nequations using Cholesky decomposition. Specifically, if \\\\(m \\ge n\\\\) then\n\\\\(X = (A^T A + \\lambda I)^{-1} A^T B\\\\), which solves the least-squares\nproblem \\\\(X = \\mathrm{argmin}_{Z \\in \\Re^{n \\times k}} ||A Z - B||_F^2 +\n\\lambda ||Z||_F^2\\\\). If \\\\(m \\lt n\\\\) then `output` is computed as\n\\\\(X = A^T (A A^T + \\lambda I)^{-1} B\\\\),\nwhich (for \\\\(\\lambda = 0\\\\)) is the minimum-norm solution to the\nunder-determined linear system, i.e.\n\\\\(X = \\mathrm{argmin}_{Z \\in \\Re^{n \\times k}} ||Z||_F^2 \\\\),\nsubject to \\\\(A Z = B\\\\).\nNotice that the fast path is only numerically stable when \\\\(A\\\\) is\nnumerically full rank and has a condition number\n\\\\(\\mathrm{cond}(A) \\lt \\frac{1}{\\sqrt{\\epsilon_{mach}}}\\\\)\nor \\\\(\\lambda\\\\) is sufficiently large.\n\nIf `fast` is `False` an algorithm based on the numerically robust complete\northogonal decomposition is used. This computes the minimum-norm\nleast-squares solution, even when \\\\(A\\\\) is rank deficient. This path is\ntypically 6-7 times slower than the fast path. If `fast` is `False` then\n`l2_regularizer` is ignored."
}
op {
name: "MatrixTriangularSolve"
diff --git a/tensorflow/python/ops/linalg_ops.py b/tensorflow/python/ops/linalg_ops.py
index 8070cc1238..8d963fef68 100644
--- a/tensorflow/python/ops/linalg_ops.py
+++ b/tensorflow/python/ops/linalg_ops.py
@@ -253,11 +253,10 @@ def batch_matrix_solve_ls(matrix,
\\(\mathrm{cond}(A) \lt \frac{1}{\sqrt{\epsilon_{mach}}}\\) or\\(\lambda\\)
is sufficiently large.
- If `fast` is `False` then the solution is computed using the rank revealing
- QR decomposition with column pivoting. This will always compute a
- least-squares solution that minimizes the residual norm \\(||A X - B||_F^2\\),
- even when \\(A\\) is rank deficient or ill-conditioned. Notice: The current
- version does not compute a minimum norm solution. If `fast` is `False` then
+ If `fast` is `False` an algorithm based on the numerically robust complete
+ orthogonal decomposition is used. This computes the minimum-norm
+ least-squares solution, even when \\(A\\) is rank deficient. This path is
+ typically 6-7 times slower than the fast path. If `fast` is `False` then
`l2_regularizer` is ignored.
Args:
diff --git a/tensorflow/python/platform/__init__.py b/tensorflow/python/platform/__init__.py
deleted file mode 100644
index aee1acdd46..0000000000
--- a/tensorflow/python/platform/__init__.py
+++ /dev/null
@@ -1,20 +0,0 @@
-# Copyright 2015 Google Inc. All Rights Reserved.
-#
-# Licensed under the Apache License, Version 2.0 (the "License");
-# you may not use this file except in compliance with the License.
-# You may obtain a copy of the License at
-#
-# http://www.apache.org/licenses/LICENSE-2.0
-#
-# Unless required by applicable law or agreed to in writing, software
-# distributed under the License is distributed on an "AS IS" BASIS,
-# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
-# See the License for the specific language governing permissions and
-# limitations under the License.
-# ==============================================================================
-
-"""DEPRECATED: Setup system-specific platform environment for TensorFlow."""
-from __future__ import absolute_import
-from __future__ import division
-from __future__ import print_function
-
diff --git a/third_party/eigen3/Eigen/Cholesky b/third_party/eigen3/Eigen/Cholesky
index 671ec3d4a6..27fcc7e50a 100644
--- a/third_party/eigen3/Eigen/Cholesky
+++ b/third_party/eigen3/Eigen/Cholesky
@@ -1 +1 @@
-#include "eigen-eigen-726c779797e8/Eigen/Cholesky"
+#include "eigen-eigen-0b9ab889fac2/Eigen/Cholesky"
diff --git a/third_party/eigen3/Eigen/Core b/third_party/eigen3/Eigen/Core
index 38f45037e6..4a21b2fc3b 100644
--- a/third_party/eigen3/Eigen/Core
+++ b/third_party/eigen3/Eigen/Core
@@ -1 +1 @@
-#include "eigen-eigen-726c779797e8/Eigen/Core"
+#include "eigen-eigen-0b9ab889fac2/Eigen/Core"
diff --git a/third_party/eigen3/Eigen/Eigenvalues b/third_party/eigen3/Eigen/Eigenvalues
index 64f4200304..92185d0565 100644
--- a/third_party/eigen3/Eigen/Eigenvalues
+++ b/third_party/eigen3/Eigen/Eigenvalues
@@ -1,2 +1 @@
-#include "eigen-eigen-726c779797e8/Eigen/Eigenvalues"
-
+#include "eigen-eigen-0b9ab889fac2/Eigen/Eigenvalues"
diff --git a/third_party/eigen3/Eigen/LU b/third_party/eigen3/Eigen/LU
index ab9e6cb4c5..e2be07ee89 100644
--- a/third_party/eigen3/Eigen/LU
+++ b/third_party/eigen3/Eigen/LU
@@ -1 +1 @@
-#include "eigen-eigen-726c779797e8/Eigen/LU"
+#include "eigen-eigen-0b9ab889fac2/Eigen/LU"
diff --git a/third_party/eigen3/Eigen/QR b/third_party/eigen3/Eigen/QR
index 9ecf7be16d..3bffd1df6d 100644
--- a/third_party/eigen3/Eigen/QR
+++ b/third_party/eigen3/Eigen/QR
@@ -1 +1 @@
-#include "eigen-eigen-726c779797e8/Eigen/QR"
+#include "eigen-eigen-0b9ab889fac2/Eigen/QR"
diff --git a/third_party/eigen3/unsupported/Eigen/CXX11/Tensor b/third_party/eigen3/unsupported/Eigen/CXX11/Tensor
index a80816717b..81a51b4593 100644
--- a/third_party/eigen3/unsupported/Eigen/CXX11/Tensor
+++ b/third_party/eigen3/unsupported/Eigen/CXX11/Tensor
@@ -1 +1 @@
-#include "eigen-eigen-726c779797e8/unsupported/Eigen/CXX11/Tensor"
+#include "eigen-eigen-0b9ab889fac2/unsupported/Eigen/CXX11/Tensor"