author     A. Unique TensorFlower <nobody@tensorflow.org>    2016-01-26 09:30:53 -0800
committer  Vijay Vasudevan <vrv@google.com>                  2016-01-26 10:59:37 -0800
commit     38242dd94a87a0e72f39315447866e55923d642c (patch)
tree       18df5f31800dd52067b965894e63eaad2acd6982
parent     7d4a063834cf9e5cdebe95d2115ce2e527296d79 (diff)
Add new matrix_solve_ls op for solving linear least-squares problems.
Change: 113064195
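For context, here is a minimal usage sketch of the op this change introduces. The data is illustrative, and the session-style calls follow the TensorFlow API of this era, as exercised by the tests below:

    import tensorflow as tf

    # Overdetermined system: A is 3x2, B is 3x1, so X is the 2x1 least-squares fit.
    a = tf.constant([[1., 2.], [3., 4.], [5., 6.]])
    b = tf.constant([[1.], [0.], [1.]])

    x_fast = tf.matrix_solve_ls(a, b)                     # normal equations + Cholesky
    x_slow = tf.matrix_solve_ls(a, b, fast=False)         # column-pivoted QR
    x_reg = tf.matrix_solve_ls(a, b, l2_regularizer=0.1)  # regularized fast path

    with tf.Session() as sess:
        print(sess.run(x_fast))  # close to np.linalg.lstsq on the same data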
m---------  google/protobuf | 0
-rw-r--r--  tensorflow/core/kernels/matrix_solve_ls_op.cc | 183
-rw-r--r--  tensorflow/core/ops/linalg_ops.cc | 105
-rw-r--r--  tensorflow/core/ops/ops.pbtxt | 94
-rw-r--r--  tensorflow/g3doc/api_docs/python/index.md | 2
-rw-r--r--  tensorflow/g3doc/api_docs/python/math_ops.md | 109
-rw-r--r--  tensorflow/g3doc/api_docs/python/state_ops.md | 3
-rw-r--r--  tensorflow/g3doc/api_docs/python/train.md | 8
-rw-r--r--  tensorflow/python/BUILD | 1
-rw-r--r--  tensorflow/python/kernel_tests/matrix_solve_ls_op_test.py | 198
-rw-r--r--  tensorflow/python/ops/linalg_ops.py | 137
-rw-r--r--  tensorflow/python/ops/math_ops.py | 3
12 files changed, 813 insertions(+), 30 deletions(-)
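The new kernel (first file below) implements the `fast=True` path by forming the Gramian and factoring it with Cholesky. As a reference for that math only, here is the same computation sketched in NumPy; `np.linalg.solve` stands in for the kernel's LLT factorization:

    import numpy as np

    def solve_ls_fast(a, b, l2_regularizer=0.0):
        # Normal equations of the first or second kind, as in the kernel below.
        m, n = a.shape
        if m >= n:
            # Overdetermined: X = (A^T A + lambda*I)^{-1} A^T B.
            gramian = np.dot(a.T, a) + l2_regularizer * np.eye(n)
            return np.linalg.solve(gramian, np.dot(a.T, b))
        else:
            # Underdetermined: X = A^T (A A^T + lambda*I)^{-1} B, the
            # minimum-norm solution when lambda == 0.
            gramian = np.dot(a, a.T) + l2_regularizer * np.eye(m)
            return np.dot(a.T, np.linalg.solve(gramian, b))

    a = np.array([[1., 2., 3.], [4., 5., 6.]])  # 2x3: underdetermined
    b = np.array([[1.], [0.]])
    x = solve_ls_fast(a, b)
    print(np.allclose(np.dot(a, x), b))  # True: A X reproduces B exactly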
diff --git a/google/protobuf b/google/protobuf -Subproject 60f7fc51fb01ac98fd64d9372e7547f5ab267ce +Subproject b13874d59e976371a1a87e6dc2bf347ed0a0ce5 diff --git a/tensorflow/core/kernels/matrix_solve_ls_op.cc b/tensorflow/core/kernels/matrix_solve_ls_op.cc new file mode 100644 index 0000000000..b752a7ed6e --- /dev/null +++ b/tensorflow/core/kernels/matrix_solve_ls_op.cc @@ -0,0 +1,183 @@ +/* Copyright 2015 Google Inc. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + +// See docs in ../ops/linalg_ops.cc. +#include <cmath> + +#include "third_party/eigen3/Eigen/Cholesky" +#include "third_party/eigen3/Eigen/Core" +#include "third_party/eigen3/Eigen/QR" +#include "tensorflow/core/framework/kernel_def_builder.h" +#include "tensorflow/core/framework/op_kernel.h" +#include "tensorflow/core/kernels/binary_linalg_ops_common.h" +#include "tensorflow/core/lib/core/errors.h" +#include "tensorflow/core/platform/logging.h" +#include "tensorflow/core/platform/port.h" +#include "tensorflow/core/public/tensor_shape.h" + +namespace tensorflow { + +template <class Scalar, bool SupportsBatchOperationT> +class MatrixSolveLsOp + : public BinaryLinearAlgebraOp<Scalar, SupportsBatchOperationT> { + public: + explicit MatrixSolveLsOp(OpKernelConstruction* context) + : BinaryLinearAlgebraOp<Scalar, SupportsBatchOperationT>(context) { + OP_REQUIRES_OK(context, context->GetAttr("fast", &fast_)); + } + + ~MatrixSolveLsOp() override {} + + TensorShape GetOutputMatrixShape( + const TensorShape& input_matrix_shape, + const TensorShape& rhs_matrix_shape) override { + CHECK_EQ(input_matrix_shape.dims(), rhs_matrix_shape.dims()); + TensorShape output_matrix_shape = rhs_matrix_shape; + output_matrix_shape.set_dim( + output_matrix_shape.dims() - 2, + input_matrix_shape.dim_size(output_matrix_shape.dims() - 1)); + return output_matrix_shape; + } + + int64 GetCostPerUnit(const TensorShape& input_matrix_shape, + const TensorShape& rhs_matrix_shape) override { + const int64 rows = input_matrix_shape.dim_size(0); + const int64 rhss = rhs_matrix_shape.dim_size(1); + if (rows > (1LL << 20)) { + // A big number to cap the cost in case of overflow.
+ return kint32max; + } else { + return 2 * rows * rows * (rows + rhss); + } + } + + using typename BinaryLinearAlgebraOp<Scalar, SupportsBatchOperationT>::Matrix; + using typename BinaryLinearAlgebraOp<Scalar, + SupportsBatchOperationT>::MatrixMap; + using typename BinaryLinearAlgebraOp<Scalar, + SupportsBatchOperationT>::ConstMatrixMap; + + void ComputeMatrix(OpKernelContext* context, const ConstMatrixMap& matrix, + const ConstMatrixMap& rhs, MatrixMap* output) override { + const int64 rows = matrix.rows(); + const int64 cols = matrix.cols(); + OP_REQUIRES( + context, rows == rhs.rows(), + errors::InvalidArgument("Input matrix and rhs are incompatible.")); + const auto& l2_regularizer_in = context->input(2); + OP_REQUIRES( + context, TensorShapeUtils::IsScalar(l2_regularizer_in.shape()), + errors::InvalidArgument("l2_regularizer must be scalar, got shape ", + l2_regularizer_in.shape().DebugString())); + const double l2_regularizer = l2_regularizer_in.scalar<double>()(); + + OP_REQUIRES(context, l2_regularizer >= 0, + errors::InvalidArgument("l2_regularizer must be >= 0.")); + if (rows == 0 || cols == 0) { + // The result is the empty matrix. + return; + } + if (fast_) { + // The fast branch assumes that matrix is not rank deficient and + // not too ill-conditioned. Specifically, the reciprocal condition number + // should be greater than the square root of the machine precision, i.e. + // 1 / cond(matrix) > sqrt(std::numeric_limits<Scalar>::epsilon()). + // This branch solves over- or underdetermined least-squares problems + // via the normal equations and Cholesky decomposition. + if (matrix.rows() >= matrix.cols()) { + // Overdetermined case (rows >= cols): Solves the ordinary (possibly + // regularized) least-squares problem + // min || A * X - RHS ||_F^2 + l2_regularizer ||X||_F^2 + // by solving the normal equations + // (A^T * A + l2_regularizer * I) X = A^T RHS + // using Cholesky decomposition. + Matrix gramian(cols, cols); + gramian.template triangularView<Eigen::Lower>() = + matrix.transpose() * matrix; + if (l2_regularizer > 0) { + gramian += + (Scalar(l2_regularizer) * Matrix::Ones(cols, 1)).asDiagonal(); + } + const Eigen::LLT<Matrix, Eigen::Lower> llt(gramian); + OP_REQUIRES( + context, llt.info() == Eigen::Success, + errors::InvalidArgument("Input matrix was rank deficient or " + "ill-conditioned. Try setting fast=False " + "or provide a larger l2_regularizer > 0.")); + *output = llt.solve(matrix.transpose() * rhs); + } else { + // Underdetermined case (rows < cols): Solves the minimum-norm problem + // min ||X||_F^2 s.t. A*X = RHS + // by solving the normal equations of the second kind + // (A * A^T + l2_regularizer * I) Z = RHS, X = A^T * Z + // using Cholesky decomposition. + Matrix gramian(rows, rows); + gramian.template triangularView<Eigen::Lower>() = + matrix * matrix.transpose(); + if (l2_regularizer > 0) { + gramian += + (Scalar(l2_regularizer) * Matrix::Ones(rows, 1)).asDiagonal(); + } + const Eigen::LLT<Matrix, Eigen::Lower> llt(gramian); + OP_REQUIRES( + context, llt.info() == Eigen::Success, + errors::InvalidArgument("Input matrix was rank deficient or " + "ill-conditioned. Try setting fast=False " + "or provide an l2_regularizer > 0.")); + *output = matrix.transpose() * llt.solve(rhs); + } + } else { + // Use a rank revealing factorization (QR with column pivoting). + // + // NOTICE: Currently, Eigen's implementation of column pivoted Householder + // QR has a few deficiencies: + // 1.
It does not implement the post-processing step to compute a + // complete orthogonal factorization. This means that it does not + // return a minimum-norm solution for underdetermined and + // rank-deficient matrices. We could use the Eigen SVD instead, but + // the currently available JacobiSVD is so slow that it is + // essentially useless (~100x slower than QR). + // 2. The implementation is not blocked, so for matrices that do not fit + // in cache, it is significantly slower than the equivalent blocked + // LAPACK routine xGEQP3 (e.g. Eigen is ~3x slower for 4k x 4k + // matrices). See http://www.netlib.org/lapack/lawnspdf/lawn114.pdf + // 3. The implementation uses the numerically unstable norm downdating + // formula from the original 1965 Businger & Golub paper. This can + // lead to incorrect rank determination for graded matrices. I + // (rmlarsen@) have a patch to bring this up to date by implementing + // the robust formula from + // http://www.netlib.org/lapack/lawnspdf/lawn176.pdf + // + // TODO(rmlarsen): a) Contribute 1. and 2. to Eigen. + // b) Evaluate new divide-and-conquer SVD in Eigen when + // it becomes available & robust. + *output = matrix.colPivHouseholderQr().solve(rhs); + } + } + + private: + bool fast_; +}; + +REGISTER_BINARY_LINALG_OP("MatrixSolveLs", (MatrixSolveLsOp<float, false>), + float); +REGISTER_BINARY_LINALG_OP("MatrixSolveLs", (MatrixSolveLsOp<double, false>), + double); +REGISTER_BINARY_LINALG_OP("BatchMatrixSolveLs", (MatrixSolveLsOp<float, true>), + float); +REGISTER_BINARY_LINALG_OP("BatchMatrixSolveLs", (MatrixSolveLsOp<double, true>), + double); + +} // namespace tensorflow diff --git a/tensorflow/core/ops/linalg_ops.cc b/tensorflow/core/ops/linalg_ops.cc index 74a19135c2..952a592d51 100644 --- a/tensorflow/core/ops/linalg_ops.cc +++ b/tensorflow/core/ops/linalg_ops.cc @@ -26,7 +26,6 @@ Calculates the determinant of a square matrix. input: A tensor of shape `[M, M]`. output: A scalar, equal to the determinant of the input. -T: The type of values in the input and output. )doc"); REGISTER_OP("BatchMatrixDeterminant") @@ -42,7 +41,6 @@ for all input submatrices `[..., :, :]`. input: Shape is `[..., M, M]`. output: Shape is `[...]`. -T: The type of values in the input and output. )doc"); REGISTER_OP("MatrixInverse") @@ -61,7 +59,6 @@ garbage result. input: Shape is `[M, M]`. output: Shape is `[M, M]` containing the matrix inverse of the input. -T: The type of values in the input and output. )doc"); REGISTER_OP("BatchMatrixInverse") @@ -84,7 +81,6 @@ garbage result. input: Shape is `[..., M, M]`. output: Shape is `[..., M, M]`. -T: The type of values in the input and output. )doc"); REGISTER_OP("Cholesky") @@ -103,7 +99,6 @@ input. input: Shape is `[M, M]`. output: Shape is `[M, M]`. -T: The type of values in the input and output. )doc"); REGISTER_OP("BatchCholesky") @@ -120,7 +115,6 @@ containing the Cholesky decompositions for all input submatrices `[..., :, :]`. input: Shape is `[..., M, M]`. output: Shape is `[..., M, M]`. -T: The type of values in the input and output. )doc"); REGISTER_OP("SelfAdjointEig") @@ -138,7 +132,6 @@ subsequent rows are eigenvectors. input: Shape is `[M, M]`. output: Shape is `[M+1, M]`. -T: The type of values in the input and output. )doc"); REGISTER_OP("BatchSelfAdjointEig") @@ -157,7 +150,6 @@ eigenvalues, and subsequent [...,1:, :] containing the eigenvectors. input: Shape is `[..., M, M]`. output: Shape is `[..., M+1, M]`. -T: The type of values in the input and output.
)doc"); REGISTER_OP("MatrixSolve") @@ -172,7 +164,6 @@ matrix: Shape is `[M, M]`. rhs: Shape is `[M, K]`. output: Shape is `[M, K]` containing the tensor that solves matrix * output = rhs. -T: The type of values in the input and output. )doc"); REGISTER_OP("BatchMatrixSolve") @@ -191,7 +182,6 @@ matrix satisfies matrix[..., :, :] * output[..., :, :] = rhs[..., :, :]. matrix: Shape is `[..., M, M]`. rhs: Shape is `[..., M, K]`. output: Shape is `[..., M, K]`. -T: The type of values in the input and output. )doc"); REGISTER_OP("MatrixTriangularSolve") @@ -218,7 +208,6 @@ matrix: Shape is `[M, M]`. rhs: Shape is `[M, K]`. output: Shape is `[M, K]`. lower: Boolean indicating whether matrix is lower or upper triangular. -T: The type of values in the input and output. )doc"); REGISTER_OP("BatchMatrixTriangularSolve") @@ -247,7 +236,99 @@ matrix: Shape is `[..., M, M]`. rhs: Shape is `[..., M, K]`. output: Shape is `[..., M, K]`. lower: Boolean indicating whether matrix is lower or upper triangular. -T: The type of values in the input and output. +)doc"); + +REGISTER_OP("MatrixSolveLs") + .Input("matrix: T") + .Input("rhs: T") + .Input("l2_regularizer: double") + .Output("output: T") + .Attr("T: {float, double}") + .Attr("fast: bool = True") + .Doc(R"doc( +Solves a linear least-squares problem. + +Below we will use the following notation +`matrix`=\\(A \in \Re^{m \times n}\\), +`rhs`=\\(B \in \Re^{m \times k}\\), +`output`=\\(X \in \Re^{n \times k}\\), +`l2_regularizer`=\\(\lambda\\). + +If `fast` is `True`, then the solution is computed by solving the normal +equations using Cholesky decomposition. Specifically, if \\(m \ge n\\) then +\\(X = (A^T A + \lambda I)^{-1} A^T B\\), which solves the least-squares +problem \\(X = \mathrm{argmin}_{Z \in \Re^{n \times k}} ||A Z - B||_F^2 + +\lambda ||Z||_F^2\\). If \\(m \lt n\\) then `output` is computed as +\\(X = A^T (A A^T + \lambda I)^{-1} B\\), +which (for \\(\lambda = 0\\)) is the minimum-norm solution to the +under-determined linear system, i.e. +\\(X = \mathrm{argmin}_{Z \in \Re^{n \times k}} ||Z||_F^2 \\), +subject to \\(A Z = B\\). +Notice that the fast path is only numerically stable when \\(A\\) is +numerically full rank and has a condition number +\\(\mathrm{cond}(A) \lt \frac{1}{\sqrt{\epsilon_{mach}}}\\) +or \\(\lambda\\) is sufficiently large. + +If `fast` is `False` then the solution is computed using the rank revealing QR +decomposition with column pivoting. This will always compute a least-squares +solution that minimizes the residual norm \\(||A X - B||_F^2 \\), even when +\\( A \\) is rank deficient or ill-conditioned. Notice: The current version +does not compute a minimum norm solution. If `fast` is `False` then +`l2_regularizer` is ignored. + +matrix: Shape is `[M, N]`. +rhs: Shape is `[M, K]`. +output: Shape is `[N, K]` containing the tensor that solves + `matrix * output = rhs` in the least-squares sense. +)doc"); + +REGISTER_OP("BatchMatrixSolveLs") + .Input("matrix: T") + .Input("rhs: T") + .Input("l2_regularizer: double") + .Output("output: T") + .Attr("T: {float, double}") + .Attr("fast: bool = True") + .Doc(R"doc( +Solves multiple linear least-squares problems. + +`matrix` is a tensor of shape `[..., M, N]` whose inner-most 2 dimensions +form `M`-by-`N` matrices. Rhs is a tensor of shape `[..., M, K]`. The output +is a tensor of shape `[..., N, K]` where each output matrix solves each of +the equations matrix[..., :, :] * output[..., :, :] = rhs[..., :, :] in the +least squares sense.
+ +Below we will use the following notation for each pair of +matrix and right-hand sides in the batch: + +`matrix`=\\(A \in \Re^{m \times n}\\), +`rhs`=\\(B \in \Re^{m \times k}\\), +`output`=\\(X \in \Re^{n \times k}\\), +`l2_regularizer`=\\(\lambda\\). + +If `fast` is `True`, then the solution is computed by solving the normal +equations using Cholesky decomposition. Specifically, if \\(m \ge n\\) then +\\(X = (A^T A + \lambda I)^{-1} A^T B\\), which solves the least-squares +problem \\(X = \mathrm{argmin}_{Z \in \Re^{n \times k}} ||A Z - B||_F^2 + +\lambda ||Z||_F^2\\). If \\(m \lt n\\) then `output` is computed as +\\(X = A^T (A A^T + \lambda I)^{-1} B\\), which (for \\(\lambda = 0\\)) is the +minimum-norm solution to the under-determined linear system, i.e. +\\(X = \mathrm{argmin}_{Z \in \Re^{n \times k}} ||Z||_F^2 \\), subject to +\\(A Z = B\\). Notice that the fast path is only numerically stable when +\\(A\\) is numerically full rank and has a condition number +\\(\mathrm{cond}(A) \lt \frac{1}{\sqrt{\epsilon_{mach}}}\\) or\\(\lambda\\) is +sufficiently large. + +If `fast` is `False` then the solution is computed using the rank revealing QR +decomposition with column pivoting. This will always compute a least-squares +solution that minimizes the residual norm \\(||A X - B||_F^2\\), even when +\\(A\\) is rank deficient or ill-conditioned. Notice: The current version does +not compute a minimum norm solution. If `fast` is `False` then `l2_regularizer` +is ignored. + +matrix: Shape is `[..., M, N]`. +rhs: Shape is `[..., M, K]`. +output: Shape is `[..., N, K]`. )doc"); } // namespace tensorflow diff --git a/tensorflow/core/ops/ops.pbtxt b/tensorflow/core/ops/ops.pbtxt index d2f0885c6e..19f9afb94b 100644 --- a/tensorflow/core/ops/ops.pbtxt +++ b/tensorflow/core/ops/ops.pbtxt @@ -993,7 +993,6 @@ op { attr { name: "T" type: "type" - description: "The type of values in the input and output." allowed_values { list { type: DT_DOUBLE @@ -1067,7 +1066,6 @@ op { attr { name: "T" type: "type" - description: "The type of values in the input and output." allowed_values { list { type: DT_FLOAT @@ -1093,7 +1091,6 @@ op { attr { name: "T" type: "type" - description: "The type of values in the input and output." allowed_values { list { type: DT_FLOAT @@ -1124,7 +1121,6 @@ op { attr { name: "T" type: "type" - description: "The type of values in the input and output." allowed_values { list { type: DT_FLOAT @@ -1136,6 +1132,47 @@ op { description: "Matrix is a tensor of shape `[..., M, M]` whose inner-most 2 dimensions\nform square matrices. Rhs is a tensor of shape\n`[..., M, K]`. The output is a tensor shape `[..., M, K]` where each output\nmatrix satisfies matrix[..., :, :] * output[..., :, :] = rhs[..., :, :]." } op { + name: "BatchMatrixSolveLs" + input_arg { + name: "matrix" + description: "Shape is `[..., M, N]`." + type_attr: "T" + } + input_arg { + name: "rhs" + description: "Shape is `[..., M, K]`." + type_attr: "T" + } + input_arg { + name: "l2_regularizer" + type: DT_DOUBLE + } + output_arg { + name: "output" + description: "Shape is `[..., N, K]`." + type_attr: "T" + } + attr { + name: "T" + type: "type" + allowed_values { + list { + type: DT_FLOAT + type: DT_DOUBLE + } + } + } + attr { + name: "fast" + type: "bool" + default_value { + b: true + } + } + summary: "Solves multiple linear least-squares problems." + description: "`matrix` is a tensor of shape `[..., M, N]` whose inner-most 2 dimensions\nform square matrices. Rhs is a tensor of shape `[..., M, K]`. 
The output\nis a tensor shape `[..., N, K]` where each output matrix solves each of\nthe equations matrix[..., :, :] * output[..., :, :] = rhs[..., :, :] in the\nleast squares sense.\n\nBelow we will use the following notation for each pair of\nmatrix and right-hand sides in the batch:\n\n`matrix`=\\\\(A \\in \\Re^{m \\times n}\\\\),\n`rhs`=\\\\(B \\in \\Re^{m \\times k}\\\\),\n`output`=\\\\(X \\in \\Re^{n \\times k}\\\\),\n`l2_regularizer`=\\\\(\\lambda\\\\).\n\nIf `fast` is `True`, then the solution is computed by solving the normal\nequations using Cholesky decomposition. Specifically, if \\\\(m \\ge n\\\\) then\n\\\\(X = (A^T A + \\lambda I)^{-1} A^T B\\\\), which solves the least-squares\nproblem \\\\(X = \\mathrm{argmin}_{Z \\in \\Re^{n \\times k}} ||A Z - B||_F^2 +\n\\lambda ||Z||_F^2\\\\). If \\\\(m \\lt n\\\\) then `output` is computed as\n\\\\(X = A^T (A A^T + \\lambda I)^{-1} B\\\\), which (for \\\\(\\lambda = 0\\\\)) is the\nminimum-norm solution to the under-determined linear system, i.e.\n\\\\(X = \\mathrm{argmin}_{Z \\in \\Re^{n \\times k}} ||Z||_F^2 \\\\), subject to\n\\\\(A Z = B\\\\). Notice that the fast path is only numerically stable when\n\\\\(A\\\\) is numerically full rank and has a condition number\n\\\\(\\mathrm{cond}(A) \\lt \\frac{1}{\\sqrt{\\epsilon_{mach}}}\\\\) or\\\\(\\lambda\\\\) is\nsufficiently large.\n\nIf `fast` is `False` then the solution is computed using the rank revealing QR\ndecomposition with column pivoting. This will always compute a least-squares\nsolution that minimizes the residual norm \\\\(||A X - B||_F^2\\\\), even when\n\\\\(A\\\\) is rank deficient or ill-conditioned. Notice: The current version does\nnot compute a minimum norm solution. If `fast` is `False` then `l2_regularizer`\nis ignored." +} +op { name: "BatchMatrixTriangularSolve" input_arg { name: "matrix" @@ -1163,7 +1200,6 @@ op { attr { name: "T" type: "type" - description: "The type of values in the input and output." allowed_values { list { type: DT_FLOAT @@ -1336,7 +1372,6 @@ op { attr { name: "T" type: "type" - description: "The type of values in the input and output." allowed_values { list { type: DT_DOUBLE @@ -1493,7 +1528,6 @@ op { attr { name: "T" type: "type" - description: "The type of values in the input and output." allowed_values { list { type: DT_DOUBLE @@ -4123,7 +4157,6 @@ op { attr { name: "T" type: "type" - description: "The type of values in the input and output." allowed_values { list { type: DT_FLOAT @@ -4148,7 +4181,6 @@ op { attr { name: "T" type: "type" - description: "The type of values in the input and output." allowed_values { list { type: DT_FLOAT @@ -4179,7 +4211,6 @@ op { attr { name: "T" type: "type" - description: "The type of values in the input and output." allowed_values { list { type: DT_FLOAT @@ -4190,6 +4221,47 @@ op { summary: "Solves a system of linear equations. Checks for invertibility." } op { + name: "MatrixSolveLs" + input_arg { + name: "matrix" + description: "Shape is `[M, N]`." + type_attr: "T" + } + input_arg { + name: "rhs" + description: "Shape is `[M, K]`." + type_attr: "T" + } + input_arg { + name: "l2_regularizer" + type: DT_DOUBLE + } + output_arg { + name: "output" + description: "Shape is `[N, K]` containing the tensor that solves\n`matrix * output = rhs` in the least-squares sense." 
+ type_attr: "T" + } + attr { + name: "T" + type: "type" + allowed_values { + list { + type: DT_FLOAT + type: DT_DOUBLE + } + } + } + attr { + name: "fast" + type: "bool" + default_value { + b: true + } + } + summary: "Solves a linear least-squares problem." + description: "Below we will use the following notation\n`matrix`=\\\\(A \\in \\Re^{m \\times n}\\\\),\n`rhs`=\\\\(B \\in \\Re^{m \\times k}\\\\),\n`output`=\\\\(X \\in \\Re^{n \\times k}\\\\),\n`l2_regularizer`=\\\\(\\lambda\\\\).\n\nIf `fast` is `True`, then the solution is computed by solving the normal\nequations using Cholesky decomposition. Specifically, if \\\\(m \\ge n\\\\) then\n\\\\(X = (A^T A + \\lambda I)^{-1} A^T B\\\\), which solves the least-squares\nproblem \\\\(X = \\mathrm{argmin}_{Z \\in \\Re^{n \\times k}} ||A Z - B||_F^2 +\n\\lambda ||Z||_F^2\\\\). If \\\\(m \\lt n\\\\) then `output` is computed as\n\\\\(X = A^T (A A^T + \\lambda I)^{-1} B\\\\),\nwhich (for \\\\(\\lambda = 0\\\\)) is the minimum-norm solution to the\nunder-determined linear system, i.e.\n\\\\(X = \\mathrm{argmin}_{Z \\in \\Re^{n \\times k}} ||Z||_F^2 \\\\),\nsubject to \\\\(A Z = B\\\\).\nNotice that the fast path is only numerically stable when \\\\(A\\\\) is\nnumerically full rank and has a condition number\n\\\\(\\mathrm{cond}(A) \\lt \\frac{1}{\\sqrt{\\epsilon_{mach}}}\\\\)\nor \\\\(\\lambda\\\\) is sufficiently large.\n\nIf `fast` is `False` then the solution is computed using the rank revealing QR\ndecomposition with column pivoting. This will always compute a least-squares\nsolution that minimizes the residual norm \\\\(||A X - B||_F^2 \\\\), even when\n\\\\( A \\\\) is rank deficient or ill-conditioned. Notice: The current version\ndoes not compute a minimum norm solution. If `fast` is `False` then\n`l2_regularizer` is ignored." +} +op { name: "MatrixTriangularSolve" input_arg { name: "matrix" @@ -4217,7 +4289,6 @@ op { attr { name: "T" type: "type" - description: "The type of values in the input and output." allowed_values { list { type: DT_FLOAT @@ -7396,7 +7467,6 @@ op { attr { name: "T" type: "type" - description: "The type of values in the input and output." 
allowed_values { list { type: DT_DOUBLE diff --git a/tensorflow/g3doc/api_docs/python/index.md b/tensorflow/g3doc/api_docs/python/index.md index e70f04a5e1..253f13f38f 100644 --- a/tensorflow/g3doc/api_docs/python/index.md +++ b/tensorflow/g3doc/api_docs/python/index.md @@ -120,6 +120,7 @@ * [`batch_matrix_determinant`](../../api_docs/python/math_ops.md#batch_matrix_determinant) * [`batch_matrix_inverse`](../../api_docs/python/math_ops.md#batch_matrix_inverse) * [`batch_matrix_solve`](../../api_docs/python/math_ops.md#batch_matrix_solve) + * [`batch_matrix_solve_ls`](../../api_docs/python/math_ops.md#batch_matrix_solve_ls) * [`batch_matrix_triangular_solve`](../../api_docs/python/math_ops.md#batch_matrix_triangular_solve) * [`batch_self_adjoint_eig`](../../api_docs/python/math_ops.md#batch_self_adjoint_eig) * [`ceil`](../../api_docs/python/math_ops.md#ceil) @@ -148,6 +149,7 @@ * [`matrix_determinant`](../../api_docs/python/math_ops.md#matrix_determinant) * [`matrix_inverse`](../../api_docs/python/math_ops.md#matrix_inverse) * [`matrix_solve`](../../api_docs/python/math_ops.md#matrix_solve) + * [`matrix_solve_ls`](../../api_docs/python/math_ops.md#matrix_solve_ls) * [`matrix_triangular_solve`](../../api_docs/python/math_ops.md#matrix_triangular_solve) * [`maximum`](../../api_docs/python/math_ops.md#maximum) * [`minimum`](../../api_docs/python/math_ops.md#minimum) diff --git a/tensorflow/g3doc/api_docs/python/math_ops.md b/tensorflow/g3doc/api_docs/python/math_ops.md index 4fba7e08b2..8303c14a5b 100644 --- a/tensorflow/g3doc/api_docs/python/math_ops.md +++ b/tensorflow/g3doc/api_docs/python/math_ops.md @@ -1087,6 +1087,115 @@ If `lower` is false then the strictly then the output satisfies +- - - + +### `tf.matrix_solve_ls(matrix, rhs, l2_regularizer=0.0, fast=True, name=None)` {#matrix_solve_ls} + +Solves a linear least-squares problem. + +Below we will use the following notation +`matrix`=\\(A \in \Re^{m \times n}\\), +`rhs`=\\(B \in \Re^{m \times k}\\), +`output`=\\(X \in \Re^{n \times k}\\), +`l2_regularizer`=\\(\lambda\\). + +If `fast` is `True`, then the solution is computed by solving the normal +equations using Cholesky decomposition. Specifically, if \\(m \ge n\\) then +\\(X = (A^T A + \lambda I)^{-1} A^T B\\), which solves the regularized +least-squares problem \\(X = \mathrm{argmin}_{Z \in \Re^{n \times k}} +||A Z - B||_F^2 + \lambda ||Z||_F^2\\). If \\(m \lt n\\) then `output` is +computed as \\(X = A^T (A A^T + \lambda I)^{-1} B\\), +which (for \\(\lambda = 0\\)) is the minimum-norm solution to the +under-determined linear system, i.e. +\\(X = \mathrm{argmin}_{Z \in \Re^{n \times k}} ||Z||_F^2 \\), +subject to \\(A Z = B\\). +Notice that the fast path is only numerically stable when \\(A\\) is +numerically full rank and has a condition number +\\(\mathrm{cond}(A) \lt \frac{1}{\sqrt{\epsilon_{mach}}}\\) +or \\(\lambda\\) is sufficiently large. + +If `fast` is `False` then the solution is computed using the rank revealing +QR decomposition with column pivoting. This will always compute a +least-squares solution that minimizes the residual norm +\\(||A X - B||_F^2 \\), even when \\(A\\) is rank deficient or +ill-conditioned. Notice: The current version does not compute a minimum norm +solution. If `fast` is `False` then `l2_regularizer` is ignored. + +##### Args: + + +* <b>`matrix`</b>: 2-D `Tensor` of shape `[M, N]`. +* <b>`rhs`</b>: 2-D `Tensor` of shape is `[M, K]`. +* <b>`l2_regularizer`</b>: 0-D `double` `Tensor`. Ignored if `fast=False`. +* <b>`fast`</b>: bool. 
Defaults to `True`. + +##### Returns: + + +* <b>`output`</b>: Matrix of shape `[N, K]` containing the matrix that solves + `matrix * output = rhs` in the least-squares sense. + + +- - - + +### `tf.batch_matrix_solve_ls(matrix, rhs, l2_regularizer=0.0, fast=True, name=None)` {#batch_matrix_solve_ls} + +Solves multiple linear least-squares problems. + +`matrix` is a tensor of shape `[..., M, N]` whose inner-most 2 dimensions +form `M`-by-`N` matrices. Rhs is a tensor of shape `[..., M, K]` whose +inner-most 2 dimensions form `M`-by-`K` matrices. The computed output is a +`Tensor` of shape `[..., N, K]` whose inner-most 2 dimensions form `M`-by-`K` +matrices that solve the equations +`matrix[..., :, :] * output[..., :, :] = rhs[..., :, :]` in the least squares +sense. + +Below we will use the following notation for each pair of +matrix and right-hand sides in the batch: + +`matrix`=\\(A \in \Re^{m \times n}\\), +`rhs`=\\(B \in \Re^{m \times k}\\), +`output`=\\(X \in \Re^{n \times k}\\), +`l2_regularizer`=\\(\lambda\\). + +If `fast` is `True`, then the solution is computed by solving the normal +equations using Cholesky decomposition. Specifically, if \\(m \ge n\\) then +\\(X = (A^T A + \lambda I)^{-1} A^T B\\), which solves the least-squares +problem \\(X = \mathrm{argmin}_{Z \in \Re^{n \times k}} ||A Z - B||_F^2 + +\lambda ||Z||_F^2\\). If \\(m \lt n\\) then `output` is computed as +\\(X = A^T (A A^T + \lambda I)^{-1} B\\), which (for \\(\lambda = 0\\)) is +the minimum-norm solution to the under-determined linear system, i.e. +\\(X = \mathrm{argmin}_{Z \in \Re^{n \times k}} ||Z||_F^2 \\), subject to +\\(A Z = B\\). Notice that the fast path is only numerically stable when +\\(A\\) is numerically full rank and has a condition number +\\(\mathrm{cond}(A) \lt \frac{1}{\sqrt{\epsilon_{mach}}}\\) or\\(\lambda\\) +is sufficiently large. + +If `fast` is `False` then the solution is computed using the rank revealing +QR decomposition with column pivoting. This will always compute a +least-squares solution that minimizes the residual norm \\(||A X - B||_F^2\\), +even when \\(A\\) is rank deficient or ill-conditioned. Notice: The current +version does not compute a minimum norm solution. If `fast` is `False` then +`l2_regularizer` is ignored. + +##### Args: + + +* <b>`matrix`</b>: `Tensor` of shape `[..., M, N]`. +* <b>`rhs`</b>: `Tensor` of shape `[..., M, K]`. +* <b>`l2_regularizer`</b>: 0-D `double` `Tensor`. Ignored if `fast=False`. +* <b>`fast`</b>: bool. Defaults to `True`. + +##### Returns: + + +* <b>`output`</b>: `Tensor` of shape `[..., N, K]` whose inner-most 2 dimensions form + `M`-by-`K` matrices that solve the equations + `matrix[..., :, :] * output[..., :, :] = rhs[..., :, :]` in the least + squares sense. + + + ## Complex Number Functions TensorFlow provides several operations that you can use to add complex number diff --git a/tensorflow/g3doc/api_docs/python/state_ops.md b/tensorflow/g3doc/api_docs/python/state_ops.md index 1a275f7ea2..feac191afa 100644 --- a/tensorflow/g3doc/api_docs/python/state_ops.md +++ b/tensorflow/g3doc/api_docs/python/state_ops.md @@ -897,7 +897,8 @@ with tf.variable_scope("foo", reuse=True) If initializer is `None` (the default), the default initializer passed in the constructor is used. If that one is `None` too, a -`UniformUnitScalingInitializer` will be used. +`UniformUnitScalingInitializer` will be used. The initializer can also be +a Tensor, in which case the variable is initialized to this value and shape. 
##### Args: diff --git a/tensorflow/g3doc/api_docs/python/train.md b/tensorflow/g3doc/api_docs/python/train.md index f9a9144f59..7650d3fb5b 100644 --- a/tensorflow/g3doc/api_docs/python/train.md +++ b/tensorflow/g3doc/api_docs/python/train.md @@ -94,7 +94,7 @@ This must be called by the constructors of subclasses. - - - -#### `tf.train.Optimizer.minimize(loss, global_step=None, var_list=None, gate_gradients=1, aggregation_method=None, name=None)` {#Optimizer.minimize} +#### `tf.train.Optimizer.minimize(loss, global_step=None, var_list=None, gate_gradients=1, aggregation_method=None, colocate_gradients_with_ops=False, name=None)` {#Optimizer.minimize} Add operations to minimize `loss` by updating `var_list`. @@ -116,6 +116,8 @@ of using this function. `GATE_NONE`, `GATE_OP`, or `GATE_GRAPH`. * <b>`aggregation_method`</b>: Specifies the method used to combine gradient terms. Valid values are defined in the class `AggregationMethod`. +* <b>`colocate_gradients_with_ops`</b>: If True, try colocating gradients with + the corresponding op. * <b>`name`</b>: Optional name for the returned operation. ##### Returns: @@ -131,7 +133,7 @@ of using this function. - - - -#### `tf.train.Optimizer.compute_gradients(loss, var_list=None, gate_gradients=1, aggregation_method=None)` {#Optimizer.compute_gradients} +#### `tf.train.Optimizer.compute_gradients(loss, var_list=None, gate_gradients=1, aggregation_method=None, colocate_gradients_with_ops=False)` {#Optimizer.compute_gradients} Compute gradients of `loss` for the variables in `var_list`. @@ -152,6 +154,8 @@ given variable. `GATE_NONE`, `GATE_OP`, or `GATE_GRAPH`. * <b>`aggregation_method`</b>: Specifies the method used to combine gradient terms. Valid values are defined in the class `AggregationMethod`. +* <b>`colocate_gradients_with_ops`</b>: If True, try colocating gradients with + the corresponding op. ##### Returns: diff --git a/tensorflow/python/BUILD b/tensorflow/python/BUILD index f87fee1211..0862e39f6f 100644 --- a/tensorflow/python/BUILD +++ b/tensorflow/python/BUILD @@ -986,6 +986,7 @@ cpu_only_kernel_test_list = glob([ "kernel_tests/matrix_inverse_op_test.py", "kernel_tests/matrix_solve_op_test.py", "kernel_tests/matrix_triangular_solve_op_test.py", + "kernel_tests/matrix_solve_ls_op_test.py", "kernel_tests/mutex_ops_test.py", "kernel_tests/parsing_ops_test.py", "kernel_tests/queue_ops_test.py", diff --git a/tensorflow/python/kernel_tests/matrix_solve_ls_op_test.py b/tensorflow/python/kernel_tests/matrix_solve_ls_op_test.py new file mode 100644 index 0000000000..2ee334a089 --- /dev/null +++ b/tensorflow/python/kernel_tests/matrix_solve_ls_op_test.py @@ -0,0 +1,198 @@ +# Copyright 2015 Google Inc. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
+# ============================================================================== + +"""Tests for tensorflow.ops.math_ops.matrix_solve_ls.""" +# pylint: disable=unused-import,g-bad-import-order +from __future__ import absolute_import +from __future__ import division +from __future__ import print_function + +import tensorflow.python.platform + +import numpy as np +import tensorflow as tf + +# pylint: enable=unused-import,g-bad-import-order + + +def BatchMatMul(a, b): + # A numpy implementation of tf.batch_matmul(). + if a.ndim < 3: + return np.dot(a, b) + # Get the number of matrices. + n = np.prod(a.shape[:-2]) + assert n == np.prod(b.shape[:-2]) + a_flat = np.reshape(a, tuple([n]) + a.shape[-2:]) + b_flat = np.reshape(b, tuple([n]) + b.shape[-2:]) + c_flat_shape = [n, a.shape[-2], b.shape[-1]] + c_flat = np.empty(c_flat_shape) + for i in range(n): + c_flat[i, :, :] = np.dot(a_flat[i, :, :], b_flat[i, :, :]) + return np.reshape(c_flat, a.shape[:-1] + b_flat.shape[-1:]) + + +def BatchRegularizedLeastSquares(matrices, rhss, l2_regularization=0.0): + # A numpy implementation of regularized least squares solver using + # the normal equations. + matrix_dims = matrices.shape + matrices_transposed = np.swapaxes(matrices, -2, -1) + rows = matrix_dims[-2] + cols = matrix_dims[-1] + if rows >= cols: + preconditioner = l2_regularization * np.identity(cols) + gramian = BatchMatMul(matrices_transposed, matrices) + preconditioner + inverse = np.linalg.inv(gramian) + left_pseudo_inverse = BatchMatMul(inverse, matrices_transposed) + return BatchMatMul(left_pseudo_inverse, rhss) + else: + preconditioner = l2_regularization * np.identity(rows) + gramian = BatchMatMul(matrices, matrices_transposed) + preconditioner + inverse = np.linalg.inv(gramian) + right_pseudo_inverse = BatchMatMul(matrices_transposed, inverse) + return BatchMatMul(right_pseudo_inverse, rhss) + + +class MatrixSolveLsOpTest(tf.test.TestCase): + + def _verifySolve(self, x, y): + for np_type in [np.float32, np.float64]: + a = x.astype(np_type) + b = y.astype(np_type) + np_ans, _, _, _ = np.linalg.lstsq(a, b) + for fast in [True, False]: + with self.test_session(): + tf_ans = tf.matrix_solve_ls(a, b, fast=fast).eval() + self.assertEqual(np_ans.shape, tf_ans.shape) + + # Check residual norm. + tf_r = b - BatchMatMul(a, tf_ans) + tf_r_norm = np.sum(tf_r * tf_r) + np_r = b - BatchMatMul(a, np_ans) + np_r_norm = np.sum(np_r * np_r) + self.assertAllClose(np_r_norm, tf_r_norm) + + # Check solution. + if fast or a.shape[0] >= a.shape[1]: + # We skip this test for the underdetermined case when using the + # slow path, because Eigen does not return a minimum norm solution. + # TODO(rmlarsen): Enable this check for all paths if/when we fix + # Eigen's solver. + self.assertAllClose(np_ans, tf_ans, atol=1e-5, rtol=1e-5) + + def _verifySolveBatch(self, x, y): + # Since numpy.linalg.lstsq does not support batch solves, as opposed + # to numpy.linalg.solve, we just perform this test for a fixed batch size + # of 2x3. + for np_type in [np.float32, np.float64]: + a = np.tile(x.astype(np_type), [2, 3, 1, 1]) + b = np.tile(y.astype(np_type), [2, 3, 1, 1]) + np_ans = np.empty([2, 3, a.shape[-1], b.shape[-1]]) + for dim1 in range(2): + for dim2 in range(3): + np_ans[dim1, dim2, :, :], _, _, _ = np.linalg.lstsq( + a[dim1, dim2, :, :], b[dim1, dim2, :, :]) + for fast in [True, False]: + with self.test_session(): + tf_ans = tf.batch_matrix_solve_ls(a, b, fast=fast).eval() + self.assertEqual(np_ans.shape, tf_ans.shape) + # Check residual norm.
+ tf_r = b - BatchMatMul(a, tf_ans) + tf_r_norm = np.sum(tf_r * tf_r) + np_r = b - BatchMatMul(a, np_ans) + np_r_norm = np.sum(np_r * np_r) + self.assertAllClose(np_r_norm, tf_r_norm) + # Check solution. + if fast or a.shape[-2] >= a.shape[-1]: + # We skip this test for the underdetermined case when using the + # slow path, because Eigen does not return a minimum norm solution. + # TODO(rmlarsen): Enable this check for all paths if/when we fix + # Eigen's solver. + self.assertAllClose(np_ans, tf_ans, atol=1e-5, rtol=1e-5) + + def _verifyRegularized(self, x, y, l2_regularizer): + for np_type in [np.float32, np.float64]: + # Test with a single matrix. + a = x.astype(np_type) + b = y.astype(np_type) + np_ans = BatchRegularizedLeastSquares(a, b, l2_regularizer) + with self.test_session(): + tf_ans = tf.matrix_solve_ls(a, + b, + l2_regularizer=l2_regularizer, + fast=True).eval() + self.assertAllClose(np_ans, tf_ans, atol=1e-5, rtol=1e-5) + # Test with a 2x3 batch of matrices. + a = np.tile(x.astype(np_type), [2, 3, 1, 1]) + b = np.tile(y.astype(np_type), [2, 3, 1, 1]) + np_ans = BatchRegularizedLeastSquares(a, b, l2_regularizer) + with self.test_session(): + tf_ans = tf.batch_matrix_solve_ls(a, + b, + l2_regularizer=l2_regularizer, + fast=True).eval() + self.assertAllClose(np_ans, tf_ans, atol=1e-5, rtol=1e-5) + + def testSquare(self): + # 2x2 matrices, 2x3 right-hand sides. + matrix = np.array([[1., 2.], [3., 4.]]) + rhs = np.array([[1., 0., 1.], [0., 1., 1.]]) + self._verifySolve(matrix, rhs) + self._verifySolveBatch(matrix, rhs) + self._verifyRegularized(matrix, rhs, l2_regularizer=0.1) + + def testOverdetermined(self): + # 3x2 matrices, 3x3 right-hand sides. + matrix = np.array([[1., 2.], [3., 4.], [5., 6.]]) + rhs = np.array([[1., 0., 1.], [0., 1., 1.], [1., 1., 0.]]) + self._verifySolve(matrix, rhs) + self._verifySolveBatch(matrix, rhs) + self._verifyRegularized(matrix, rhs, l2_regularizer=0.1) + + def testUnderdetermined(self): + # 2x3 matrices, 2x3 right-hand sides. + matrix = np.array([[1., 2., 3], [4., 5., 6.]]) + rhs = np.array([[1., 0., 1.], [0., 1., 1.]]) + self._verifySolve(matrix, rhs) + self._verifySolveBatch(matrix, rhs) + self._verifyRegularized(matrix, rhs, l2_regularizer=0.1) + + def testWrongDimensions(self): + # The matrix and right-hand sides should have the same number of rows.
+ with self.test_session(): + matrix = tf.constant([[1., 0.], [0., 1.]]) + rhs = tf.constant([[1., 0.]]) + with self.assertRaises(ValueError): + tf.matrix_solve_ls(matrix, rhs) + with self.assertRaises(ValueError): + tf.batch_matrix_solve_ls(matrix, rhs) + + def testEmpty(self): + full = np.array([[1., 2.], [3., 4.], [5., 6.]]) + empty0 = np.empty([3, 0]) + empty1 = np.empty([0, 2]) + for fast in [True, False]: + with self.test_session(): + tf_ans = tf.matrix_solve_ls(empty0, empty0, fast=fast).eval() + self.assertEqual(tf_ans.shape, (0, 0)) + tf_ans = tf.matrix_solve_ls(empty0, full, fast=fast).eval() + self.assertEqual(tf_ans.shape, (0, 2)) + tf_ans = tf.matrix_solve_ls(full, empty0, fast=fast).eval() + self.assertEqual(tf_ans.shape, (2, 0)) + tf_ans = tf.matrix_solve_ls(empty1, empty1, fast=fast).eval() + self.assertEqual(tf_ans.shape, (2, 2)) + + +if __name__ == "__main__": + tf.test.main() diff --git a/tensorflow/python/ops/linalg_ops.py b/tensorflow/python/ops/linalg_ops.py index 72e5098398..091016e950 100644 --- a/tensorflow/python/ops/linalg_ops.py +++ b/tensorflow/python/ops/linalg_ops.py @@ -108,9 +108,9 @@ def _MatrixSolveShape(op): rhs_shape = op.inputs[1].get_shape().with_rank_at_least(2) # The matrix must be square. lhs_shape[0].assert_is_compatible_with(lhs_shape[1]) - # The matrix and righ-hand-side must have the same number of rows. + # The matrix and right-hand side must have the same number of rows. lhs_shape[0].assert_is_compatible_with(rhs_shape[0]) - return [[lhs_shape[0], rhs_shape[1]]] + return [[lhs_shape[1], rhs_shape[1]]] @ops.RegisterShape("BatchMatrixSolve") @@ -119,7 +119,7 @@ def _BatchMatrixSolveShape(op): rhs_shape = op.inputs[1].get_shape().with_rank_at_least(3) # The matrices must be square. lhs_shape[-1].assert_is_compatible_with(lhs_shape[-2]) - # The matrices and righ-hand-sides in the batch must have the same number of + # The matrices and right-hand sides in the batch must have the same number of # rows. lhs_shape[-2].assert_is_compatible_with(rhs_shape[-2]) return [lhs_shape[:-2].concatenate(rhs_shape[-1])] @@ -146,3 +146,134 @@ def _BatchMatrixTriangularSolveShape(op): # rows. lhs_shape[-2].assert_is_compatible_with(rhs_shape[-2]) return [rhs_shape] + + +@ops.RegisterShape("MatrixSolveLs") +def _MatrixSolveLsShape(op): + lhs_shape = op.inputs[0].get_shape().with_rank(2) + rhs_shape = op.inputs[1].get_shape().with_rank_at_least(2) + # The matrix and right-hand side must have the same number of rows. + lhs_shape[0].assert_is_compatible_with(rhs_shape[0]) + return [[lhs_shape[1], rhs_shape[1]]] + + +@ops.RegisterShape("BatchMatrixSolveLs") +def _BatchMatrixSolveLsShape(op): + lhs_shape = op.inputs[0].get_shape().with_rank_at_least(3) + rhs_shape = op.inputs[1].get_shape().with_rank_at_least(3) + # The matrices and right-hand sides in the batch must have the same number of + # rows. + lhs_shape[-2].assert_is_compatible_with(rhs_shape[-2]) + return [lhs_shape[:-3].concatenate([lhs_shape[-1], rhs_shape[-1]])] + + +def matrix_solve_ls(matrix, rhs, l2_regularizer=0.0, fast=True, name=None): + r"""Solves a linear least-squares problem. + + Below we will use the following notation + `matrix`=\\(A \in \Re^{m \times n}\\), + `rhs`=\\(B \in \Re^{m \times k}\\), + `output`=\\(X \in \Re^{n \times k}\\), + `l2_regularizer`=\\(\lambda\\). + + If `fast` is `True`, then the solution is computed by solving the normal + equations using Cholesky decomposition. 
Specifically, if \\(m \ge n\\) then + \\(X = (A^T A + \lambda I)^{-1} A^T B\\), which solves the regularized + least-squares problem \\(X = \mathrm{argmin}_{Z \in \Re^{n \times k}} + ||A Z - B||_F^2 + \lambda ||Z||_F^2\\). If \\(m \lt n\\) then `output` is + computed as \\(X = A^T (A A^T + \lambda I)^{-1} B\\), + which (for \\(\lambda = 0\\)) is the minimum-norm solution to the + under-determined linear system, i.e. + \\(X = \mathrm{argmin}_{Z \in \Re^{n \times k}} ||Z||_F^2 \\), + subject to \\(A Z = B\\). + Notice that the fast path is only numerically stable when \\(A\\) is + numerically full rank and has a condition number + \\(\mathrm{cond}(A) \lt \frac{1}{\sqrt{\epsilon_{mach}}}\\) + or \\(\lambda\\) is sufficiently large. + + If `fast` is `False` then the solution is computed using the rank revealing + QR decomposition with column pivoting. This will always compute a + least-squares solution that minimizes the residual norm + \\(||A X - B||_F^2 \\), even when \\(A\\) is rank deficient or + ill-conditioned. Notice: The current version does not compute a minimum norm + solution. If `fast` is `False` then `l2_regularizer` is ignored. + + Args: + matrix: 2-D `Tensor` of shape `[M, N]`. + rhs: 2-D `Tensor` of shape `[M, K]`. + l2_regularizer: 0-D `double` `Tensor`. Ignored if `fast=False`. + fast: bool. Defaults to `True`. + name: string, optional name of the operation. + + Returns: + output: Matrix of shape `[N, K]` containing the matrix that solves + `matrix * output = rhs` in the least-squares sense. + """ + return gen_linalg_ops.matrix_solve_ls(matrix, + rhs, + l2_regularizer, + fast=fast, + name=name) + + +def batch_matrix_solve_ls(matrix, + rhs, + l2_regularizer=0.0, + fast=True, + name=None): + r"""Solves multiple linear least-squares problems. + + `matrix` is a tensor of shape `[..., M, N]` whose inner-most 2 dimensions + form `M`-by-`N` matrices. Rhs is a tensor of shape `[..., M, K]` whose + inner-most 2 dimensions form `M`-by-`K` matrices. The computed output is a + `Tensor` of shape `[..., N, K]` whose inner-most 2 dimensions form `N`-by-`K` + matrices that solve the equations + `matrix[..., :, :] * output[..., :, :] = rhs[..., :, :]` in the least squares + sense. + + Below we will use the following notation for each pair of + matrix and right-hand sides in the batch: + + `matrix`=\\(A \in \Re^{m \times n}\\), + `rhs`=\\(B \in \Re^{m \times k}\\), + `output`=\\(X \in \Re^{n \times k}\\), + `l2_regularizer`=\\(\lambda\\). + + If `fast` is `True`, then the solution is computed by solving the normal + equations using Cholesky decomposition. Specifically, if \\(m \ge n\\) then + \\(X = (A^T A + \lambda I)^{-1} A^T B\\), which solves the least-squares + problem \\(X = \mathrm{argmin}_{Z \in \Re^{n \times k}} ||A Z - B||_F^2 + + \lambda ||Z||_F^2\\). If \\(m \lt n\\) then `output` is computed as + \\(X = A^T (A A^T + \lambda I)^{-1} B\\), which (for \\(\lambda = 0\\)) is + the minimum-norm solution to the under-determined linear system, i.e. + \\(X = \mathrm{argmin}_{Z \in \Re^{n \times k}} ||Z||_F^2 \\), subject to + \\(A Z = B\\). Notice that the fast path is only numerically stable when + \\(A\\) is numerically full rank and has a condition number + \\(\mathrm{cond}(A) \lt \frac{1}{\sqrt{\epsilon_{mach}}}\\) or \\(\lambda\\) + is sufficiently large. + + If `fast` is `False` then the solution is computed using the rank revealing + QR decomposition with column pivoting.
This will always compute a + least-squares solution that minimizes the residual norm \\(||A X - B||_F^2\\), + even when \\(A\\) is rank deficient or ill-conditioned. Notice: The current + version does not compute a minimum norm solution. If `fast` is `False` then + `l2_regularizer` is ignored. + + Args: + matrix: `Tensor` of shape `[..., M, N]`. + rhs: `Tensor` of shape `[..., M, K]`. + l2_regularizer: 0-D `double` `Tensor`. Ignored if `fast=False`. + fast: bool. Defaults to `True`. + name: string, optional name of the operation. + + Returns: + output: `Tensor` of shape `[..., N, K]` whose inner-most 2 dimensions form + `N`-by-`K` matrices that solve the equations + `matrix[..., :, :] * output[..., :, :] = rhs[..., :, :]` in the least + squares sense. + """ + return gen_linalg_ops.batch_matrix_solve_ls(matrix, + rhs, + l2_regularizer, + fast=fast, + name=name) diff --git a/tensorflow/python/ops/math_ops.py b/tensorflow/python/ops/math_ops.py index b6f8b01c53..fc94b38ce8 100644 --- a/tensorflow/python/ops/math_ops.py +++ b/tensorflow/python/ops/math_ops.py @@ -82,6 +82,9 @@ mathematical functions for matrices to your graph. @@matrix_triangular_solve @@batch_matrix_triangular_solve +@@matrix_solve_ls +@@batch_matrix_solve_ls + ## Complex Number Functions TensorFlow provides several operations that you can use to add complex number
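To close the loop on the Python wrappers above, here is a short batch round-trip check in the spirit of the new tests; the random data is illustrative:

    import numpy as np
    import tensorflow as tf

    # A batch of two independent 3x2 least-squares problems, one rhs column each.
    a = np.random.rand(2, 3, 2)
    b = np.random.rand(2, 3, 1)

    with tf.Session() as sess:
        x = sess.run(tf.batch_matrix_solve_ls(a, b))

    # Each x[i] should match numpy's per-matrix least-squares solution.
    for i in range(2):
        expected, _, _, _ = np.linalg.lstsq(a[i], b[i])
        assert np.allclose(x[i], expected, atol=1e-6)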