author    A. Unique TensorFlower <nobody@tensorflow.org>  2016-01-26 09:30:53 -0800
committer Vijay Vasudevan <vrv@google.com>  2016-01-26 10:59:37 -0800
commit    38242dd94a87a0e72f39315447866e55923d642c (patch)
tree      18df5f31800dd52067b965894e63eaad2acd6982
parent    7d4a063834cf9e5cdebe95d2115ce2e527296d79 (diff)
Add new matrix_solve_ls op for solving linear least-squares problems.
Change: 113064195
m---------  google/protobuf                                               0
-rw-r--r--  tensorflow/core/kernels/matrix_solve_ls_op.cc               183
-rw-r--r--  tensorflow/core/ops/linalg_ops.cc                           105
-rw-r--r--  tensorflow/core/ops/ops.pbtxt                                94
-rw-r--r--  tensorflow/g3doc/api_docs/python/index.md                     2
-rw-r--r--  tensorflow/g3doc/api_docs/python/math_ops.md                109
-rw-r--r--  tensorflow/g3doc/api_docs/python/state_ops.md                 3
-rw-r--r--  tensorflow/g3doc/api_docs/python/train.md                     8
-rw-r--r--  tensorflow/python/BUILD                                       1
-rw-r--r--  tensorflow/python/kernel_tests/matrix_solve_ls_op_test.py  198
-rw-r--r--  tensorflow/python/ops/linalg_ops.py                         137
-rw-r--r--  tensorflow/python/ops/math_ops.py                             3
12 files changed, 813 insertions, 30 deletions
diff --git a/google/protobuf b/google/protobuf
-Subproject 60f7fc51fb01ac98fd64d9372e7547f5ab267ce
+Subproject b13874d59e976371a1a87e6dc2bf347ed0a0ce5
diff --git a/tensorflow/core/kernels/matrix_solve_ls_op.cc b/tensorflow/core/kernels/matrix_solve_ls_op.cc
new file mode 100644
index 0000000000..b752a7ed6e
--- /dev/null
+++ b/tensorflow/core/kernels/matrix_solve_ls_op.cc
@@ -0,0 +1,183 @@
+/* Copyright 2015 Google Inc. All Rights Reserved.
+
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+ http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License.
+==============================================================================*/
+
+// See docs in ../ops/linalg_ops.cc.
+#include <cmath>
+
+#include "third_party/eigen3/Eigen/Cholesky"
+#include "third_party/eigen3/Eigen/Core"
+#include "third_party/eigen3/Eigen/QR"
+#include "tensorflow/core/framework/kernel_def_builder.h"
+#include "tensorflow/core/framework/op_kernel.h"
+#include "tensorflow/core/kernels/binary_linalg_ops_common.h"
+#include "tensorflow/core/lib/core/errors.h"
+#include "tensorflow/core/platform/logging.h"
+#include "tensorflow/core/platform/port.h"
+#include "tensorflow/core/public/tensor_shape.h"
+
+namespace tensorflow {
+
+template <class Scalar, bool SupportsBatchOperationT>
+class MatrixSolveLsOp
+ : public BinaryLinearAlgebraOp<Scalar, SupportsBatchOperationT> {
+ public:
+ explicit MatrixSolveLsOp(OpKernelConstruction* context)
+ : BinaryLinearAlgebraOp<Scalar, SupportsBatchOperationT>(context) {
+ OP_REQUIRES_OK(context, context->GetAttr("fast", &fast_));
+ }
+
+ ~MatrixSolveLsOp() override {}
+
+ TensorShape GetOutputMatrixShape(
+ const TensorShape& input_matrix_shape,
+ const TensorShape& rhs_matrix_shape) override {
+ CHECK_EQ(input_matrix_shape.dims(), rhs_matrix_shape.dims());
+ TensorShape output_matrix_shape = rhs_matrix_shape;
+ output_matrix_shape.set_dim(
+ output_matrix_shape.dims() - 2,
+ input_matrix_shape.dim_size(output_matrix_shape.dims() - 1));
+ return output_matrix_shape;
+ }
+
+ int64 GetCostPerUnit(const TensorShape& input_matrix_shape,
+ const TensorShape& rhs_matrix_shape) override {
+ const int64 rows = input_matrix_shape.dim_size(0);
+ const int64 rhss = rhs_matrix_shape.dim_size(1);
+ if (rows > (1LL << 20)) {
+ // A big number to cap the cost in case of overflow.
+ return kint32max;
+ } else {
+ return 2 * rows * rows * (rows + rhss);
+ }
+ }
+
+ using typename BinaryLinearAlgebraOp<Scalar, SupportsBatchOperationT>::Matrix;
+ using typename BinaryLinearAlgebraOp<Scalar,
+ SupportsBatchOperationT>::MatrixMap;
+ using typename BinaryLinearAlgebraOp<Scalar,
+ SupportsBatchOperationT>::ConstMatrixMap;
+
+ void ComputeMatrix(OpKernelContext* context, const ConstMatrixMap& matrix,
+ const ConstMatrixMap& rhs, MatrixMap* output) override {
+ const int64 rows = matrix.rows();
+ const int64 cols = matrix.cols();
+ OP_REQUIRES(
+ context, rows == rhs.rows(),
+ errors::InvalidArgument("Input matrix and rhs are incompatible."));
+ const auto& l2_regularizer_in = context->input(2);
+ OP_REQUIRES(
+ context, TensorShapeUtils::IsScalar(l2_regularizer_in.shape()),
+ errors::InvalidArgument("l2_regularizer must be scalar, got shape ",
+ l2_regularizer_in.shape().DebugString()));
+ const double l2_regularizer = l2_regularizer_in.scalar<double>()();
+
+ OP_REQUIRES(context, l2_regularizer >= 0,
+ errors::InvalidArgument("l2_regularizer must be >= 0."));
+ if (rows == 0 || cols == 0) {
+ // The result is the empty matrix.
+ return;
+ }
+ if (fast_) {
+ // The fast branch assumes that matrix is not rank deficient and
+ // not too ill-conditioned. Specifically, the reciprocal condition number
+ // should be greater than the square root of the machine precision, i.e.
+ // 1 / cond(matrix) > sqrt(std::numeric_limits<Scalar>::epsilon()).
+ // This branch solves over- or underdetermined least-squares problems
+ // via the normal equations and Cholesky decomposition.
+ if (matrix.rows() >= matrix.cols()) {
+ // Overdetermined case (rows >= cols): Solves the ordinary (possibly
+ // regularized) least-squares problem
+ // min || A * X - RHS ||_F^2 + l2_regularizer ||X||_F^2
+ // by solving the normal equations
+ // (A^T * A + l2_regularizer * I) X = A^T RHS
+ // using Cholesky decomposition.
+ Matrix gramian(cols, cols);
+ gramian.template triangularView<Eigen::Lower>() =
+ matrix.transpose() * matrix;
+ if (l2_regularizer > 0) {
+ gramian +=
+ (Scalar(l2_regularizer) * Matrix::Ones(cols, 1)).asDiagonal();
+ }
+ const Eigen::LLT<Matrix, Eigen::Lower> llt(gramian);
+ OP_REQUIRES(
+ context, llt.info() == Eigen::Success,
+ errors::InvalidArgument("Input matrix was rank deficient or "
+ "ill-conditioned. Try setting fast=False "
+ "or provide a larger l2_regularizer > 0."));
+ *output = llt.solve(matrix.transpose() * rhs);
+ } else {
+ // Underdetermined case (rows < cols): Solves the minimum-norm problem
+ // min ||X||_F^2 s.t. A*X = RHS
+ // by solving the normal equations of the second kind
+ // (A * A^T + l2_regularizer * I) Z = RHS, X = A^T * Z
+ // using Cholesky decomposition.
+ Matrix gramian(rows, rows);
+ gramian.template triangularView<Eigen::Lower>() =
+ matrix * matrix.transpose();
+ if (l2_regularizer > 0) {
+ gramian +=
+ (Scalar(l2_regularizer) * Matrix::Ones(rows, 1)).asDiagonal();
+ }
+ const Eigen::LLT<Matrix, Eigen::Lower> llt(gramian);
+ OP_REQUIRES(
+ context, llt.info() == Eigen::Success,
+ errors::InvalidArgument("Input matrix was rank deficient or "
+ "ill-conditioned. Try setting fast=False "
+ "or provide an l2_regularizer > 0."));
+ *output = matrix.transpose() * llt.solve(rhs);
+ }
+ } else {
+ // Use a rank revealing factorization (QR with column pivoting).
+ //
+ // NOTICE: Currently, Eigen's implementation of column pivoted Householder
+ // QR has a few deficiencies:
+ // 1. It does not implement the post-processing step to compute a
+ // complete orthogonal factorization. This means that it does not
+ // return a minimum-norm solution for underdetermined and
+ // rank-deficient matrices. We could use the Eigen SVD instead, but
+ // the currently available JacobiSVD is so slow that it is
+ // essentially useless (~100x slower than QR).
+ // 2. The implementation is not blocked, so for matrices that do not fit
+ // in cache, it is significantly slower than the equivalent blocked
+ // LAPACK routine xGEQP3 (e.g. Eigen is ~3x slower for 4k x 4k
+ // matrices). See http://www.netlib.org/lapack/lawnspdf/lawn114.pdf
+ // 3. The implementation uses the numerically unstable norm downdating
+ // formula from the original 1965 Businger & Golub paper. This can
+ // lead to incorrect rank determination for graded matrices. I
+ // (rmlarsen@) have a patch to bring this up to date by implementing
+ // the robust formula from
+ // http://www.netlib.org/lapack/lawnspdf/lawn176.pdf
+ //
+ // TODO(rmlarsen): a) Contribute 1. and 2. to Eigen.
+ // b) Evaluate new divide-and-conquer SVD in Eigen when
+ // it becomes available & robust.
+ *output = matrix.colPivHouseholderQr().solve(rhs);
+ }
+ }
+
+ private:
+ bool fast_;
+};
+
+REGISTER_BINARY_LINALG_OP("MatrixSolveLs", (MatrixSolveLsOp<float, false>),
+ float);
+REGISTER_BINARY_LINALG_OP("MatrixSolveLs", (MatrixSolveLsOp<double, false>),
+ double);
+REGISTER_BINARY_LINALG_OP("BatchMatrixSolveLs", (MatrixSolveLsOp<float, true>),
+ float);
+REGISTER_BINARY_LINALG_OP("BatchMatrixSolveLs", (MatrixSolveLsOp<double, true>),
+ double);
+
+} // namespace tensorflow
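The fast path in the kernel above reduces to the normal equations. A minimal NumPy sketch of the same math, for reference only: the helper name is made up, and `np.linalg.solve` stands in for the kernel's Cholesky factorization.

```python
import numpy as np

def solve_ls_fast(a, b, l2_regularizer=0.0):
    """Mirrors the kernel's fast path: normal equations solved directly."""
    m, n = a.shape
    if m >= n:
        # Overdetermined: solve (A^T A + lambda I) X = A^T B.
        gramian = a.T.dot(a) + l2_regularizer * np.eye(n)
        return np.linalg.solve(gramian, a.T.dot(b))
    else:
        # Underdetermined: X = A^T (A A^T + lambda I)^{-1} B.
        gramian = a.dot(a.T) + l2_regularizer * np.eye(m)
        return a.T.dot(np.linalg.solve(gramian, b))
```

For `l2_regularizer = 0` and a well-conditioned `a`, this agrees with `np.linalg.lstsq(a, b)[0]`.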
diff --git a/tensorflow/core/ops/linalg_ops.cc b/tensorflow/core/ops/linalg_ops.cc
index 74a19135c2..952a592d51 100644
--- a/tensorflow/core/ops/linalg_ops.cc
+++ b/tensorflow/core/ops/linalg_ops.cc
@@ -26,7 +26,6 @@ Calculates the determinant of a square matrix.
input: A tensor of shape `[M, M]`.
output: A scalar, equal to the determinant of the input.
-T: The type of values in the input and output.
)doc");
REGISTER_OP("BatchMatrixDeterminant")
@@ -42,7 +41,6 @@ for all input submatrices `[..., :, :]`.
input: Shape is `[..., M, M]`.
output: Shape is `[...]`.
-T: The type of values in the input and output.
)doc");
REGISTER_OP("MatrixInverse")
@@ -61,7 +59,6 @@ garbage result.
input: Shape is `[M, M]`.
output: Shape is `[M, M]` containing the matrix inverse of the input.
-T: The type of values in the input and output.
)doc");
REGISTER_OP("BatchMatrixInverse")
@@ -84,7 +81,6 @@ garbage result.
input: Shape is `[..., M, M]`.
output: Shape is `[..., M, M]`.
-T: The type of values in the input and output.
)doc");
REGISTER_OP("Cholesky")
@@ -103,7 +99,6 @@ input.
input: Shape is `[M, M]`.
output: Shape is `[M, M]`.
-T: The type of values in the input and output.
)doc");
REGISTER_OP("BatchCholesky")
@@ -120,7 +115,6 @@ containing the Cholesky decompositions for all input submatrices `[..., :, :]`.
input: Shape is `[..., M, M]`.
output: Shape is `[..., M, M]`.
-T: The type of values in the input and output.
)doc");
REGISTER_OP("SelfAdjointEig")
@@ -138,7 +132,6 @@ subsequent rows are eigenvectors.
input: Shape is `[M, M]`.
output: Shape is `[M+1, M]`.
-T: The type of values in the input and output.
)doc");
REGISTER_OP("BatchSelfAdjointEig")
@@ -157,7 +150,6 @@ eigenvalues, and subsequent [...,1:, :] containing the eigenvectors.
input: Shape is `[..., M, M]`.
output: Shape is `[..., M+1, M]`.
-T: The type of values in the input and output.
)doc");
REGISTER_OP("MatrixSolve")
@@ -172,7 +164,6 @@ matrix: Shape is `[M, M]`.
rhs: Shape is `[M, K]`.
output: Shape is `[M, K]` containing the tensor that solves
matrix * output = rhs.
-T: The type of values in the input and output.
)doc");
REGISTER_OP("BatchMatrixSolve")
@@ -191,7 +182,6 @@ matrix satisfies matrix[..., :, :] * output[..., :, :] = rhs[..., :, :].
matrix: Shape is `[..., M, M]`.
rhs: Shape is `[..., M, K]`.
output: Shape is `[..., M, K]`.
-T: The type of values in the input and output.
)doc");
REGISTER_OP("MatrixTriangularSolve")
@@ -218,7 +208,6 @@ matrix: Shape is `[M, M]`.
rhs: Shape is `[M, K]`.
output: Shape is `[M, K]`.
lower: Boolean indicating whether matrix is lower or upper triangular.
-T: The type of values in the input and output.
)doc");
REGISTER_OP("BatchMatrixTriangularSolve")
@@ -247,7 +236,99 @@ matrix: Shape is `[..., M, M]`.
rhs: Shape is `[..., M, K]`.
output: Shape is `[..., M, K]`.
lower: Boolean indicating whether matrix is lower or upper triangular.
-T: The type of values in the input and output.
+)doc");
+
+REGISTER_OP("MatrixSolveLs")
+ .Input("matrix: T")
+ .Input("rhs: T")
+ .Input("l2_regularizer: double")
+ .Output("output: T")
+ .Attr("T: {float, double}")
+ .Attr("fast: bool = True")
+ .Doc(R"doc(
+Solves a linear least-squares problem.
+
+Below we will use the following notation
+`matrix`=\\(A \in \Re^{m \times n}\\),
+`rhs`=\\(B \in \Re^{m \times k}\\),
+`output`=\\(X \in \Re^{n \times k}\\),
+`l2_regularizer`=\\(\lambda\\).
+
+If `fast` is `True`, then the solution is computed by solving the normal
+equations using Cholesky decomposition. Specifically, if \\(m \ge n\\) then
+\\(X = (A^T A + \lambda I)^{-1} A^T B\\), which solves the least-squares
+problem \\(X = \mathrm{argmin}_{Z \in \Re^{n \times k}} ||A Z - B||_F^2 +
+\lambda ||Z||_F^2\\). If \\(m \lt n\\) then `output` is computed as
+\\(X = A^T (A A^T + \lambda I)^{-1} B\\),
+which (for \\(\lambda = 0\\)) is the minimum-norm solution to the
+under-determined linear system, i.e.
+\\(X = \mathrm{argmin}_{Z \in \Re^{n \times k}} ||Z||_F^2 \\),
+subject to \\(A Z = B\\).
+Notice that the fast path is only numerically stable when \\(A\\) is
+numerically full rank and has a condition number
+\\(\mathrm{cond}(A) \lt \frac{1}{\sqrt{\epsilon_{mach}}}\\)
+or \\(\lambda\\) is sufficiently large.
+
+If `fast` is `False` then the solution is computed using the rank revealing QR
+decomposition with column pivoting. This will always compute a least-squares
+solution that minimizes the residual norm \\(||A X - B||_F^2 \\), even when
+\\( A \\) is rank deficient or ill-conditioned. Notice: The current version
+does not compute a minimum norm solution. If `fast` is `False` then
+`l2_regularizer` is ignored.
+
+matrix: Shape is `[M, N]`.
+rhs: Shape is `[M, K]`.
+output: Shape is `[N, K]` containing the tensor that solves
+ `matrix * output = rhs` in the least-squares sense.
+)doc");
+
+REGISTER_OP("BatchMatrixSolveLs")
+ .Input("matrix: T")
+ .Input("rhs: T")
+ .Input("l2_regularizer: double")
+ .Output("output: T")
+ .Attr("T: {float, double}")
+ .Attr("fast: bool = True")
+ .Doc(R"doc(
+Solves multiple linear least-squares problems.
+
+`matrix` is a tensor of shape `[..., M, N]` whose inner-most 2 dimensions
+form `M`-by-`N` matrices. Rhs is a tensor of shape `[..., M, K]`. The output
+is a tensor of shape `[..., N, K]` where each output matrix solves each of
+the equations matrix[..., :, :] * output[..., :, :] = rhs[..., :, :] in the
+least squares sense.
+
+Below we will use the following notation for each pair of
+matrix and right-hand sides in the batch:
+
+`matrix`=\\(A \in \Re^{m \times n}\\),
+`rhs`=\\(B \in \Re^{m \times k}\\),
+`output`=\\(X \in \Re^{n \times k}\\),
+`l2_regularizer`=\\(\lambda\\).
+
+If `fast` is `True`, then the solution is computed by solving the normal
+equations using Cholesky decomposition. Specifically, if \\(m \ge n\\) then
+\\(X = (A^T A + \lambda I)^{-1} A^T B\\), which solves the least-squares
+problem \\(X = \mathrm{argmin}_{Z \in \Re^{n \times k}} ||A Z - B||_F^2 +
+\lambda ||Z||_F^2\\). If \\(m \lt n\\) then `output` is computed as
+\\(X = A^T (A A^T + \lambda I)^{-1} B\\), which (for \\(\lambda = 0\\)) is the
+minimum-norm solution to the under-determined linear system, i.e.
+\\(X = \mathrm{argmin}_{Z \in \Re^{n \times k}} ||Z||_F^2 \\), subject to
+\\(A Z = B\\). Notice that the fast path is only numerically stable when
+\\(A\\) is numerically full rank and has a condition number
+\\(\mathrm{cond}(A) \lt \frac{1}{\sqrt{\epsilon_{mach}}}\\) or \\(\lambda\\) is
+sufficiently large.
+
+If `fast` is `False` then the solution is computed using the rank revealing QR
+decomposition with column pivoting. This will always compute a least-squares
+solution that minimizes the residual norm \\(||A X - B||_F^2\\), even when
+\\(A\\) is rank deficient or ill-conditioned. Notice: The current version does
+not compute a minimum norm solution. If `fast` is `False` then `l2_regularizer`
+is ignored.
+
+matrix: Shape is `[..., M, N]`.
+rhs: Shape is `[..., M, K]`.
+output: Shape is `[..., N, K]`.
)doc");
} // namespace tensorflow
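The distinction between the two paths documented in the op registrations above shows up on rank-deficient input. A sketch using the Python wrappers added later in this change; the example data is made up.

```python
import numpy as np
import tensorflow as tf

# Rank-deficient A: the second column is twice the first.
a = np.array([[1., 2.], [2., 4.], [3., 6.]])
b = np.array([[1.], [2.], [2.]])

with tf.Session() as sess:
    # fast=False (column-pivoted QR) still returns a least-squares solution.
    x_qr = sess.run(tf.matrix_solve_ls(a, b, fast=False))
    # The fast path needs A to be numerically full rank or l2_regularizer > 0.
    x_ridge = sess.run(tf.matrix_solve_ls(a, b, l2_regularizer=0.1, fast=True))
```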
diff --git a/tensorflow/core/ops/ops.pbtxt b/tensorflow/core/ops/ops.pbtxt
index d2f0885c6e..19f9afb94b 100644
--- a/tensorflow/core/ops/ops.pbtxt
+++ b/tensorflow/core/ops/ops.pbtxt
@@ -993,7 +993,6 @@ op {
attr {
name: "T"
type: "type"
- description: "The type of values in the input and output."
allowed_values {
list {
type: DT_DOUBLE
@@ -1067,7 +1066,6 @@ op {
attr {
name: "T"
type: "type"
- description: "The type of values in the input and output."
allowed_values {
list {
type: DT_FLOAT
@@ -1093,7 +1091,6 @@ op {
attr {
name: "T"
type: "type"
- description: "The type of values in the input and output."
allowed_values {
list {
type: DT_FLOAT
@@ -1124,7 +1121,6 @@ op {
attr {
name: "T"
type: "type"
- description: "The type of values in the input and output."
allowed_values {
list {
type: DT_FLOAT
@@ -1136,6 +1132,47 @@ op {
description: "Matrix is a tensor of shape `[..., M, M]` whose inner-most 2 dimensions\nform square matrices. Rhs is a tensor of shape\n`[..., M, K]`. The output is a tensor shape `[..., M, K]` where each output\nmatrix satisfies matrix[..., :, :] * output[..., :, :] = rhs[..., :, :]."
}
op {
+ name: "BatchMatrixSolveLs"
+ input_arg {
+ name: "matrix"
+ description: "Shape is `[..., M, N]`."
+ type_attr: "T"
+ }
+ input_arg {
+ name: "rhs"
+ description: "Shape is `[..., M, K]`."
+ type_attr: "T"
+ }
+ input_arg {
+ name: "l2_regularizer"
+ type: DT_DOUBLE
+ }
+ output_arg {
+ name: "output"
+ description: "Shape is `[..., N, K]`."
+ type_attr: "T"
+ }
+ attr {
+ name: "T"
+ type: "type"
+ allowed_values {
+ list {
+ type: DT_FLOAT
+ type: DT_DOUBLE
+ }
+ }
+ }
+ attr {
+ name: "fast"
+ type: "bool"
+ default_value {
+ b: true
+ }
+ }
+ summary: "Solves multiple linear least-squares problems."
+ description: "`matrix` is a tensor of shape `[..., M, N]` whose inner-most 2 dimensions\nform `M`-by-`N` matrices. Rhs is a tensor of shape `[..., M, K]`. The output\nis a tensor of shape `[..., N, K]` where each output matrix solves each of\nthe equations matrix[..., :, :] * output[..., :, :] = rhs[..., :, :] in the\nleast squares sense.\n\nBelow we will use the following notation for each pair of\nmatrix and right-hand sides in the batch:\n\n`matrix`=\\\\(A \\in \\Re^{m \\times n}\\\\),\n`rhs`=\\\\(B \\in \\Re^{m \\times k}\\\\),\n`output`=\\\\(X \\in \\Re^{n \\times k}\\\\),\n`l2_regularizer`=\\\\(\\lambda\\\\).\n\nIf `fast` is `True`, then the solution is computed by solving the normal\nequations using Cholesky decomposition. Specifically, if \\\\(m \\ge n\\\\) then\n\\\\(X = (A^T A + \\lambda I)^{-1} A^T B\\\\), which solves the least-squares\nproblem \\\\(X = \\mathrm{argmin}_{Z \\in \\Re^{n \\times k}} ||A Z - B||_F^2 +\n\\lambda ||Z||_F^2\\\\). If \\\\(m \\lt n\\\\) then `output` is computed as\n\\\\(X = A^T (A A^T + \\lambda I)^{-1} B\\\\), which (for \\\\(\\lambda = 0\\\\)) is the\nminimum-norm solution to the under-determined linear system, i.e.\n\\\\(X = \\mathrm{argmin}_{Z \\in \\Re^{n \\times k}} ||Z||_F^2 \\\\), subject to\n\\\\(A Z = B\\\\). Notice that the fast path is only numerically stable when\n\\\\(A\\\\) is numerically full rank and has a condition number\n\\\\(\\mathrm{cond}(A) \\lt \\frac{1}{\\sqrt{\\epsilon_{mach}}}\\\\) or \\\\(\\lambda\\\\) is\nsufficiently large.\n\nIf `fast` is `False` then the solution is computed using the rank revealing QR\ndecomposition with column pivoting. This will always compute a least-squares\nsolution that minimizes the residual norm \\\\(||A X - B||_F^2\\\\), even when\n\\\\(A\\\\) is rank deficient or ill-conditioned. Notice: The current version does\nnot compute a minimum norm solution. If `fast` is `False` then `l2_regularizer`\nis ignored."
+}
+op {
name: "BatchMatrixTriangularSolve"
input_arg {
name: "matrix"
@@ -1163,7 +1200,6 @@ op {
attr {
name: "T"
type: "type"
- description: "The type of values in the input and output."
allowed_values {
list {
type: DT_FLOAT
@@ -1336,7 +1372,6 @@ op {
attr {
name: "T"
type: "type"
- description: "The type of values in the input and output."
allowed_values {
list {
type: DT_DOUBLE
@@ -1493,7 +1528,6 @@ op {
attr {
name: "T"
type: "type"
- description: "The type of values in the input and output."
allowed_values {
list {
type: DT_DOUBLE
@@ -4123,7 +4157,6 @@ op {
attr {
name: "T"
type: "type"
- description: "The type of values in the input and output."
allowed_values {
list {
type: DT_FLOAT
@@ -4148,7 +4181,6 @@ op {
attr {
name: "T"
type: "type"
- description: "The type of values in the input and output."
allowed_values {
list {
type: DT_FLOAT
@@ -4179,7 +4211,6 @@ op {
attr {
name: "T"
type: "type"
- description: "The type of values in the input and output."
allowed_values {
list {
type: DT_FLOAT
@@ -4190,6 +4221,47 @@ op {
summary: "Solves a system of linear equations. Checks for invertibility."
}
op {
+ name: "MatrixSolveLs"
+ input_arg {
+ name: "matrix"
+ description: "Shape is `[M, N]`."
+ type_attr: "T"
+ }
+ input_arg {
+ name: "rhs"
+ description: "Shape is `[M, K]`."
+ type_attr: "T"
+ }
+ input_arg {
+ name: "l2_regularizer"
+ type: DT_DOUBLE
+ }
+ output_arg {
+ name: "output"
+ description: "Shape is `[N, K]` containing the tensor that solves\n`matrix * output = rhs` in the least-squares sense."
+ type_attr: "T"
+ }
+ attr {
+ name: "T"
+ type: "type"
+ allowed_values {
+ list {
+ type: DT_FLOAT
+ type: DT_DOUBLE
+ }
+ }
+ }
+ attr {
+ name: "fast"
+ type: "bool"
+ default_value {
+ b: true
+ }
+ }
+ summary: "Solves a linear least-squares problem."
+ description: "Below we will use the following notation\n`matrix`=\\\\(A \\in \\Re^{m \\times n}\\\\),\n`rhs`=\\\\(B \\in \\Re^{m \\times k}\\\\),\n`output`=\\\\(X \\in \\Re^{n \\times k}\\\\),\n`l2_regularizer`=\\\\(\\lambda\\\\).\n\nIf `fast` is `True`, then the solution is computed by solving the normal\nequations using Cholesky decomposition. Specifically, if \\\\(m \\ge n\\\\) then\n\\\\(X = (A^T A + \\lambda I)^{-1} A^T B\\\\), which solves the least-squares\nproblem \\\\(X = \\mathrm{argmin}_{Z \\in \\Re^{n \\times k}} ||A Z - B||_F^2 +\n\\lambda ||Z||_F^2\\\\). If \\\\(m \\lt n\\\\) then `output` is computed as\n\\\\(X = A^T (A A^T + \\lambda I)^{-1} B\\\\),\nwhich (for \\\\(\\lambda = 0\\\\)) is the minimum-norm solution to the\nunder-determined linear system, i.e.\n\\\\(X = \\mathrm{argmin}_{Z \\in \\Re^{n \\times k}} ||Z||_F^2 \\\\),\nsubject to \\\\(A Z = B\\\\).\nNotice that the fast path is only numerically stable when \\\\(A\\\\) is\nnumerically full rank and has a condition number\n\\\\(\\mathrm{cond}(A) \\lt \\frac{1}{\\sqrt{\\epsilon_{mach}}}\\\\)\nor \\\\(\\lambda\\\\) is sufficiently large.\n\nIf `fast` is `False` then the solution is computed using the rank revealing QR\ndecomposition with column pivoting. This will always compute a least-squares\nsolution that minimizes the residual norm \\\\(||A X - B||_F^2 \\\\), even when\n\\\\( A \\\\) is rank deficient or ill-conditioned. Notice: The current version\ndoes not compute a minimum norm solution. If `fast` is `False` then\n`l2_regularizer` is ignored."
+}
+op {
name: "MatrixTriangularSolve"
input_arg {
name: "matrix"
@@ -4217,7 +4289,6 @@ op {
attr {
name: "T"
type: "type"
- description: "The type of values in the input and output."
allowed_values {
list {
type: DT_FLOAT
@@ -7396,7 +7467,6 @@ op {
attr {
name: "T"
type: "type"
- description: "The type of values in the input and output."
allowed_values {
list {
type: DT_DOUBLE
diff --git a/tensorflow/g3doc/api_docs/python/index.md b/tensorflow/g3doc/api_docs/python/index.md
index e70f04a5e1..253f13f38f 100644
--- a/tensorflow/g3doc/api_docs/python/index.md
+++ b/tensorflow/g3doc/api_docs/python/index.md
@@ -120,6 +120,7 @@
* [`batch_matrix_determinant`](../../api_docs/python/math_ops.md#batch_matrix_determinant)
* [`batch_matrix_inverse`](../../api_docs/python/math_ops.md#batch_matrix_inverse)
* [`batch_matrix_solve`](../../api_docs/python/math_ops.md#batch_matrix_solve)
+ * [`batch_matrix_solve_ls`](../../api_docs/python/math_ops.md#batch_matrix_solve_ls)
* [`batch_matrix_triangular_solve`](../../api_docs/python/math_ops.md#batch_matrix_triangular_solve)
* [`batch_self_adjoint_eig`](../../api_docs/python/math_ops.md#batch_self_adjoint_eig)
* [`ceil`](../../api_docs/python/math_ops.md#ceil)
@@ -148,6 +149,7 @@
* [`matrix_determinant`](../../api_docs/python/math_ops.md#matrix_determinant)
* [`matrix_inverse`](../../api_docs/python/math_ops.md#matrix_inverse)
* [`matrix_solve`](../../api_docs/python/math_ops.md#matrix_solve)
+ * [`matrix_solve_ls`](../../api_docs/python/math_ops.md#matrix_solve_ls)
* [`matrix_triangular_solve`](../../api_docs/python/math_ops.md#matrix_triangular_solve)
* [`maximum`](../../api_docs/python/math_ops.md#maximum)
* [`minimum`](../../api_docs/python/math_ops.md#minimum)
diff --git a/tensorflow/g3doc/api_docs/python/math_ops.md b/tensorflow/g3doc/api_docs/python/math_ops.md
index 4fba7e08b2..8303c14a5b 100644
--- a/tensorflow/g3doc/api_docs/python/math_ops.md
+++ b/tensorflow/g3doc/api_docs/python/math_ops.md
@@ -1087,6 +1087,115 @@ If `lower` is false then the strictly then the output satisfies
+- - -
+
+### `tf.matrix_solve_ls(matrix, rhs, l2_regularizer=0.0, fast=True, name=None)` {#matrix_solve_ls}
+
+Solves a linear least-squares problem.
+
+Below we will use the following notation
+`matrix`=\\(A \in \Re^{m \times n}\\),
+`rhs`=\\(B \in \Re^{m \times k}\\),
+`output`=\\(X \in \Re^{n \times k}\\),
+`l2_regularizer`=\\(\lambda\\).
+
+If `fast` is `True`, then the solution is computed by solving the normal
+equations using Cholesky decomposition. Specifically, if \\(m \ge n\\) then
+\\(X = (A^T A + \lambda I)^{-1} A^T B\\), which solves the regularized
+least-squares problem \\(X = \mathrm{argmin}_{Z \in \Re^{n \times k}}
+||A Z - B||_F^2 + \lambda ||Z||_F^2\\). If \\(m \lt n\\) then `output` is
+computed as \\(X = A^T (A A^T + \lambda I)^{-1} B\\),
+which (for \\(\lambda = 0\\)) is the minimum-norm solution to the
+under-determined linear system, i.e.
+\\(X = \mathrm{argmin}_{Z \in \Re^{n \times k}} ||Z||_F^2 \\),
+subject to \\(A Z = B\\).
+Notice that the fast path is only numerically stable when \\(A\\) is
+numerically full rank and has a condition number
+\\(\mathrm{cond}(A) \lt \frac{1}{\sqrt{\epsilon_{mach}}}\\)
+or \\(\lambda\\) is sufficiently large.
+
+If `fast` is `False` then the solution is computed using the rank revealing
+QR decomposition with column pivoting. This will always compute a
+least-squares solution that minimizes the residual norm
+\\(||A X - B||_F^2 \\), even when \\(A\\) is rank deficient or
+ill-conditioned. Notice: The current version does not compute a minimum norm
+solution. If `fast` is `False` then `l2_regularizer` is ignored.
+
+##### Args:
+
+
+* <b>`matrix`</b>: 2-D `Tensor` of shape `[M, N]`.
+* <b>`rhs`</b>: 2-D `Tensor` of shape `[M, K]`.
+* <b>`l2_regularizer`</b>: 0-D `double` `Tensor`. Ignored if `fast=False`.
+* <b>`fast`</b>: bool. Defaults to `True`.
+
+##### Returns:
+
+
+* <b>`output`</b>: Matrix of shape `[N, K]` containing the matrix that solves
+ `matrix * output = rhs` in the least-squares sense.
+
+
+- - -
+
+### `tf.batch_matrix_solve_ls(matrix, rhs, l2_regularizer=0.0, fast=True, name=None)` {#batch_matrix_solve_ls}
+
+Solves multiple linear least-squares problems.
+
+`matrix` is a tensor of shape `[..., M, N]` whose inner-most 2 dimensions
+form `M`-by-`N` matrices. Rhs is a tensor of shape `[..., M, K]` whose
+inner-most 2 dimensions form `M`-by-`K` matrices. The computed output is a
+`Tensor` of shape `[..., N, K]` whose inner-most 2 dimensions form `N`-by-`K`
+matrices that solve the equations
+`matrix[..., :, :] * output[..., :, :] = rhs[..., :, :]` in the least squares
+sense.
+
+Below we will use the following notation for each pair of
+matrix and right-hand sides in the batch:
+
+`matrix`=\\(A \in \Re^{m \times n}\\),
+`rhs`=\\(B \in \Re^{m \times k}\\),
+`output`=\\(X \in \Re^{n \times k}\\),
+`l2_regularizer`=\\(\lambda\\).
+
+If `fast` is `True`, then the solution is computed by solving the normal
+equations using Cholesky decomposition. Specifically, if \\(m \ge n\\) then
+\\(X = (A^T A + \lambda I)^{-1} A^T B\\), which solves the least-squares
+problem \\(X = \mathrm{argmin}_{Z \in \Re^{n \times k}} ||A Z - B||_F^2 +
+\lambda ||Z||_F^2\\). If \\(m \lt n\\) then `output` is computed as
+\\(X = A^T (A A^T + \lambda I)^{-1} B\\), which (for \\(\lambda = 0\\)) is
+the minimum-norm solution to the under-determined linear system, i.e.
+\\(X = \mathrm{argmin}_{Z \in \Re^{n \times k}} ||Z||_F^2 \\), subject to
+\\(A Z = B\\). Notice that the fast path is only numerically stable when
+\\(A\\) is numerically full rank and has a condition number
+\\(\mathrm{cond}(A) \lt \frac{1}{\sqrt{\epsilon_{mach}}}\\) or \\(\lambda\\)
+is sufficiently large.
+
+If `fast` is `False` then the solution is computed using the rank revealing
+QR decomposition with column pivoting. This will always compute a
+least-squares solution that minimizes the residual norm \\(||A X - B||_F^2\\),
+even when \\(A\\) is rank deficient or ill-conditioned. Notice: The current
+version does not compute a minimum norm solution. If `fast` is `False` then
+`l2_regularizer` is ignored.
+
+##### Args:
+
+
+* <b>`matrix`</b>: `Tensor` of shape `[..., M, N]`.
+* <b>`rhs`</b>: `Tensor` of shape `[..., M, K]`.
+* <b>`l2_regularizer`</b>: 0-D `double` `Tensor`. Ignored if `fast=False`.
+* <b>`fast`</b>: bool. Defaults to `True`.
+
+##### Returns:
+
+
+* <b>`output`</b>: `Tensor` of shape `[..., N, K]` whose inner-most 2 dimensions form
+ `N`-by-`K` matrices that solve the equations
+ `matrix[..., :, :] * output[..., :, :] = rhs[..., :, :]` in the least
+ squares sense.
+
+
+
## Complex Number Functions
TensorFlow provides several operations that you can use to add complex number
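Returning to the `tf.matrix_solve_ls` documentation added above, a usage sketch that fits a straight line in the least-squares sense; the data values are made up.

```python
import numpy as np
import tensorflow as tf

# Fit y ~ slope * t + intercept: the design-matrix columns are [t, 1].
t = np.array([0., 1., 2., 3.])
y = np.array([[1.1], [1.9], [3.2], [3.9]])
a = np.column_stack([t, np.ones_like(t)])

with tf.Session() as sess:
    slope_intercept = sess.run(tf.matrix_solve_ls(a, y))  # shape [2, 1]
```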
diff --git a/tensorflow/g3doc/api_docs/python/state_ops.md b/tensorflow/g3doc/api_docs/python/state_ops.md
index 1a275f7ea2..feac191afa 100644
--- a/tensorflow/g3doc/api_docs/python/state_ops.md
+++ b/tensorflow/g3doc/api_docs/python/state_ops.md
@@ -897,7 +897,8 @@ with tf.variable_scope("foo", reuse=True)
If initializer is `None` (the default), the default initializer passed in
the constructor is used. If that one is `None` too, a
-`UniformUnitScalingInitializer` will be used.
+`UniformUnitScalingInitializer` will be used. The initializer can also be
+a Tensor, in which case the variable is initialized to this value and shape.
##### Args:
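A minimal sketch of the Tensor-initializer behavior described in the doc change above; the variable and tensor names are arbitrary.

```python
import tensorflow as tf

# Per the doc change above: when initializer is a Tensor, the variable takes
# both its value and its shape, so no explicit shape argument is needed.
init = tf.constant([[1.0, 2.0], [3.0, 4.0]])
v = tf.get_variable("v", initializer=init)
```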
diff --git a/tensorflow/g3doc/api_docs/python/train.md b/tensorflow/g3doc/api_docs/python/train.md
index f9a9144f59..7650d3fb5b 100644
--- a/tensorflow/g3doc/api_docs/python/train.md
+++ b/tensorflow/g3doc/api_docs/python/train.md
@@ -94,7 +94,7 @@ This must be called by the constructors of subclasses.
- - -
-#### `tf.train.Optimizer.minimize(loss, global_step=None, var_list=None, gate_gradients=1, aggregation_method=None, name=None)` {#Optimizer.minimize}
+#### `tf.train.Optimizer.minimize(loss, global_step=None, var_list=None, gate_gradients=1, aggregation_method=None, colocate_gradients_with_ops=False, name=None)` {#Optimizer.minimize}
Add operations to minimize `loss` by updating `var_list`.
@@ -116,6 +116,8 @@ of using this function.
`GATE_NONE`, `GATE_OP`, or `GATE_GRAPH`.
* <b>`aggregation_method`</b>: Specifies the method used to combine gradient terms.
Valid values are defined in the class `AggregationMethod`.
+* <b>`colocate_gradients_with_ops`</b>: If True, try colocating gradients with
+ the corresponding op.
* <b>`name`</b>: Optional name for the returned operation.
##### Returns:
@@ -131,7 +133,7 @@ of using this function.
- - -
-#### `tf.train.Optimizer.compute_gradients(loss, var_list=None, gate_gradients=1, aggregation_method=None)` {#Optimizer.compute_gradients}
+#### `tf.train.Optimizer.compute_gradients(loss, var_list=None, gate_gradients=1, aggregation_method=None, colocate_gradients_with_ops=False)` {#Optimizer.compute_gradients}
Compute gradients of `loss` for the variables in `var_list`.
@@ -152,6 +154,8 @@ given variable.
`GATE_NONE`, `GATE_OP`, or `GATE_GRAPH`.
* <b>`aggregation_method`</b>: Specifies the method used to combine gradient terms.
Valid values are defined in the class `AggregationMethod`.
+* <b>`colocate_gradients_with_ops`</b>: If True, try colocating gradients with
+ the corresponding op.
##### Returns:
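A sketch of the new `colocate_gradients_with_ops` argument documented above; the toy loss is made up.

```python
import tensorflow as tf

w = tf.Variable([1.0, 2.0])
loss = tf.reduce_sum(tf.square(w))

opt = tf.train.GradientDescentOptimizer(learning_rate=0.1)
# Keep each gradient op on the same device as the forward op it differentiates.
train_op = opt.minimize(loss, colocate_gradients_with_ops=True)
```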
diff --git a/tensorflow/python/BUILD b/tensorflow/python/BUILD
index f87fee1211..0862e39f6f 100644
--- a/tensorflow/python/BUILD
+++ b/tensorflow/python/BUILD
@@ -986,6 +986,7 @@ cpu_only_kernel_test_list = glob([
"kernel_tests/matrix_inverse_op_test.py",
"kernel_tests/matrix_solve_op_test.py",
"kernel_tests/matrix_triangular_solve_op_test.py",
+ "kernel_tests/matrix_solve_ls_op_test.py",
"kernel_tests/mutex_ops_test.py",
"kernel_tests/parsing_ops_test.py",
"kernel_tests/queue_ops_test.py",
diff --git a/tensorflow/python/kernel_tests/matrix_solve_ls_op_test.py b/tensorflow/python/kernel_tests/matrix_solve_ls_op_test.py
new file mode 100644
index 0000000000..2ee334a089
--- /dev/null
+++ b/tensorflow/python/kernel_tests/matrix_solve_ls_op_test.py
@@ -0,0 +1,198 @@
+# Copyright 2015 Google Inc. All Rights Reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+# ==============================================================================
+
+"""Tests for tensorflow.ops.math_ops.matrix_solve_ls."""
+# pylint: disable=unused-import,g-bad-import-order
+from __future__ import absolute_import
+from __future__ import division
+from __future__ import print_function
+
+import tensorflow.python.platform
+
+import numpy as np
+import tensorflow as tf
+
+# pylint: enable=unused-import,g-bad-import-order
+
+
+def BatchMatMul(a, b):
+ # A numpy implementation of tf.batch_matmul().
+ if a.ndim < 3:
+ return np.dot(a, b)
+ # Get the number of matrices.
+ n = np.prod(a.shape[:-2])
+ assert n == np.prod(b.shape[:-2])
+ a_flat = np.reshape(a, tuple([n]) + a.shape[-2:])
+ b_flat = np.reshape(b, tuple([n]) + b.shape[-2:])
+ c_flat_shape = [n, a.shape[-2], b.shape[-1]]
+ c_flat = np.empty(c_flat_shape)
+ for i in range(n):
+ c_flat[i, :, :] = np.dot(a_flat[i, :, :], b_flat[i, :, :])
+ return np.reshape(c_flat, a.shape[:-1] + b_flat.shape[-1:])
+
+
+def BatchRegularizedLeastSquares(matrices, rhss, l2_regularization=0.0):
+ # A numpy implementation of a regularized least-squares solver using
+ # the normal equations.
+ matrix_dims = matrices.shape
+ matrices_transposed = np.swapaxes(matrices, -2, -1)
+ rows = matrix_dims[-2]
+ cols = matrix_dims[-1]
+ if rows >= cols:
+ preconditioner = l2_regularization * np.identity(cols)
+ gramian = BatchMatMul(matrices_transposed, matrices) + preconditioner
+ inverse = np.linalg.inv(gramian)
+ left_pseudo_inverse = BatchMatMul(inverse, matrices_transposed)
+ return BatchMatMul(left_pseudo_inverse, rhss)
+ else:
+ preconditioner = l2_regularization * np.identity(rows)
+ gramian = BatchMatMul(matrices, matrices_transposed) + preconditioner
+ inverse = np.linalg.inv(gramian)
+ right_pseudo_inverse = BatchMatMul(matrices_transposed, inverse)
+ return BatchMatMul(right_pseudo_inverse, rhss)
+
+
+class MatrixSolveLsOpTest(tf.test.TestCase):
+
+ def _verifySolve(self, x, y):
+ for np_type in [np.float32, np.float64]:
+ a = x.astype(np_type)
+ b = y.astype(np_type)
+ np_ans, _, _, _ = np.linalg.lstsq(a, b)
+ for fast in [True, False]:
+ with self.test_session():
+ tf_ans = tf.matrix_solve_ls(a, b, fast=fast).eval()
+ self.assertEqual(np_ans.shape, tf_ans.shape)
+
+ # Check residual norm.
+ tf_r = b - BatchMatMul(a, tf_ans)
+ tf_r_norm = np.sum(tf_r * tf_r)
+ np_r = b - BatchMatMul(a, np_ans)
+ np_r_norm = np.sum(np_r * np_r)
+ self.assertAllClose(np_r_norm, tf_r_norm)
+
+ # Check solution.
+ if fast or a.shape[0] >= a.shape[1]:
+ # We skip this test for the underdetermined case when using the
+ # slow path, because Eigen does not return a minimum norm solution.
+ # TODO(rmlarsen): Enable this check for all paths if/when we fix
+ # Eigen's solver.
+ self.assertAllClose(np_ans, tf_ans, atol=1e-5, rtol=1e-5)
+
+ def _verifySolveBatch(self, x, y):
+ # Since numpy.linalg.lstsq does not support batch solves, as opposed
+ # to numpy.linalg.solve, we just perform this test for a fixed batch size
+ # of 2x3.
+ for np_type in [np.float32, np.float64]:
+ a = np.tile(x.astype(np_type), [2, 3, 1, 1])
+ b = np.tile(y.astype(np_type), [2, 3, 1, 1])
+ np_ans = np.empty([2, 3, a.shape[-1], b.shape[-1]])
+ for dim1 in range(2):
+ for dim2 in range(3):
+ np_ans[dim1, dim2, :, :], _, _, _ = np.linalg.lstsq(
+ a[dim1, dim2, :, :], b[dim1, dim2, :, :])
+ for fast in [True, False]:
+ with self.test_session():
+ tf_ans = tf.batch_matrix_solve_ls(a, b, fast=fast).eval()
+ self.assertEqual(np_ans.shape, tf_ans.shape)
+ # Check residual norm.
+ tf_r = b - BatchMatMul(a, tf_ans)
+ tf_r_norm = np.sum(tf_r * tf_r)
+ np_r = b - BatchMatMul(a, np_ans)
+ np_r_norm = np.sum(np_r * np_r)
+ self.assertAllClose(np_r_norm, tf_r_norm)
+ # Check solution.
+ if fast or a.shape[-2] >= a.shape[-1]:
+ # We skip this test for the underdetermined case when using the
+ # slow path, because Eigen does not return a minimum norm solution.
+ # TODO(rmlarsen): Enable this check for all paths if/when we fix
+ # Eigen's solver.
+ self.assertAllClose(np_ans, tf_ans, atol=1e-5, rtol=1e-5)
+
+ def _verifyRegularized(self, x, y, l2_regularizer):
+ for np_type in [np.float32, np.float64]:
+ # Test with a single matrix.
+ a = x.astype(np_type)
+ b = y.astype(np_type)
+ np_ans = BatchRegularizedLeastSquares(a, b, l2_regularizer)
+ with self.test_session():
+ tf_ans = tf.matrix_solve_ls(a,
+ b,
+ l2_regularizer=l2_regularizer,
+ fast=True).eval()
+ self.assertAllClose(np_ans, tf_ans, atol=1e-5, rtol=1e-5)
+ # Test with a 2x3 batch of matrices.
+ a = np.tile(x.astype(np_type), [2, 3, 1, 1])
+ b = np.tile(y.astype(np_type), [2, 3, 1, 1])
+ np_ans = BatchRegularizedLeastSquares(a, b, l2_regularizer)
+ with self.test_session():
+ tf_ans = tf.batch_matrix_solve_ls(a,
+ b,
+ l2_regularizer=l2_regularizer,
+ fast=True).eval()
+ self.assertAllClose(np_ans, tf_ans, atol=1e-5, rtol=1e-5)
+
+ def testSquare(self):
+ # 2x2 matrices, 2x3 right-hand sides.
+ matrix = np.array([[1., 2.], [3., 4.]])
+ rhs = np.array([[1., 0., 1.], [0., 1., 1.]])
+ self._verifySolve(matrix, rhs)
+ self._verifySolveBatch(matrix, rhs)
+ self._verifyRegularized(matrix, rhs, l2_regularizer=0.1)
+
+ def testOverdetermined(self):
+ # 3x2 matrices, 3x3 right-hand sides.
+ matrix = np.array([[1., 2.], [3., 4.], [5., 6.]])
+ rhs = np.array([[1., 0., 1.], [0., 1., 1.], [1., 1., 0.]])
+ self._verifySolve(matrix, rhs)
+ self._verifySolveBatch(matrix, rhs)
+ self._verifyRegularized(matrix, rhs, l2_regularizer=0.1)
+
+ def testUnderdetermined(self):
+ # 2x3 matrices, 2x3 right-hand sides.
+ matrix = np.array([[1., 2., 3.], [4., 5., 6.]])
+ rhs = np.array([[1., 0., 1.], [0., 1., 1.]])
+ self._verifySolve(matrix, rhs)
+ self._verifySolveBatch(matrix, rhs)
+ self._verifyRegularized(matrix, rhs, l2_regularizer=0.1)
+
+ def testWrongDimensions(self):
+ # The matrix and right-hand sides should have the same number of rows.
+ with self.test_session():
+ matrix = tf.constant([[1., 0.], [0., 1.]])
+ rhs = tf.constant([[1., 0.]])
+ with self.assertRaises(ValueError):
+ tf.matrix_solve_ls(matrix, rhs)
+ with self.assertRaises(ValueError):
+ tf.batch_matrix_solve_ls(matrix, rhs)
+
+ def testEmpty(self):
+ full = np.array([[1., 2.], [3., 4.], [5., 6.]])
+ empty0 = np.empty([3, 0])
+ empty1 = np.empty([0, 2])
+ for fast in [True, False]:
+ with self.test_session():
+ tf_ans = tf.matrix_solve_ls(empty0, empty0, fast=fast).eval()
+ self.assertEqual(tf_ans.shape, (0, 0))
+ tf_ans = tf.matrix_solve_ls(empty0, full, fast=fast).eval()
+ self.assertEqual(tf_ans.shape, (0, 2))
+ tf_ans = tf.matrix_solve_ls(full, empty0, fast=fast).eval()
+ self.assertEqual(tf_ans.shape, (2, 0))
+ tf_ans = tf.matrix_solve_ls(empty1, empty1, fast=fast).eval()
+ self.assertEqual(tf_ans.shape, (2, 2))
+
+
+if __name__ == "__main__":
+ tf.test.main()
diff --git a/tensorflow/python/ops/linalg_ops.py b/tensorflow/python/ops/linalg_ops.py
index 72e5098398..091016e950 100644
--- a/tensorflow/python/ops/linalg_ops.py
+++ b/tensorflow/python/ops/linalg_ops.py
@@ -108,9 +108,9 @@ def _MatrixSolveShape(op):
rhs_shape = op.inputs[1].get_shape().with_rank_at_least(2)
# The matrix must be square.
lhs_shape[0].assert_is_compatible_with(lhs_shape[1])
- # The matrix and righ-hand-side must have the same number of rows.
+ # The matrix and right-hand side must have the same number of rows.
lhs_shape[0].assert_is_compatible_with(rhs_shape[0])
- return [[lhs_shape[0], rhs_shape[1]]]
+ return [[lhs_shape[1], rhs_shape[1]]]
@ops.RegisterShape("BatchMatrixSolve")
@@ -119,7 +119,7 @@ def _BatchMatrixSolveShape(op):
rhs_shape = op.inputs[1].get_shape().with_rank_at_least(3)
# The matrices must be square.
lhs_shape[-1].assert_is_compatible_with(lhs_shape[-2])
- # The matrices and righ-hand-sides in the batch must have the same number of
+ # The matrices and right-hand sides in the batch must have the same number of
# rows.
lhs_shape[-2].assert_is_compatible_with(rhs_shape[-2])
return [lhs_shape[:-2].concatenate(rhs_shape[-1])]
@@ -146,3 +146,134 @@ def _BatchMatrixTriangularSolveShape(op):
# rows.
lhs_shape[-2].assert_is_compatible_with(rhs_shape[-2])
return [rhs_shape]
+
+
+@ops.RegisterShape("MatrixSolveLs")
+def _MatrixSolveLsShape(op):
+ lhs_shape = op.inputs[0].get_shape().with_rank(2)
+ rhs_shape = op.inputs[1].get_shape().with_rank_at_least(2)
+ # The matrix and right-hand side must have the same number of rows.
+ lhs_shape[0].assert_is_compatible_with(rhs_shape[0])
+ return [[lhs_shape[1], rhs_shape[1]]]
+
+
+@ops.RegisterShape("BatchMatrixSolveLs")
+def _BatchMatrixSolveLsShape(op):
+ lhs_shape = op.inputs[0].get_shape().with_rank_at_least(3)
+ rhs_shape = op.inputs[1].get_shape().with_rank_at_least(3)
+ # The matrices and right-hand sides in the batch must have the same number of
+ # rows.
+ lhs_shape[-2].assert_is_compatible_with(rhs_shape[-2])
+ return [lhs_shape[:-2].concatenate([lhs_shape[-1], rhs_shape[-1]])]
+
+
+def matrix_solve_ls(matrix, rhs, l2_regularizer=0.0, fast=True, name=None):
+ r"""Solves a linear least-squares problem.
+
+ Below we will use the following notation
+ `matrix`=\\(A \in \Re^{m \times n}\\),
+ `rhs`=\\(B \in \Re^{m \times k}\\),
+ `output`=\\(X \in \Re^{n \times k}\\),
+ `l2_regularizer`=\\(\lambda\\).
+
+ If `fast` is `True`, then the solution is computed by solving the normal
+ equations using Cholesky decomposition. Specifically, if \\(m \ge n\\) then
+ \\(X = (A^T A + \lambda I)^{-1} A^T B\\), which solves the regularized
+ least-squares problem \\(X = \mathrm{argmin}_{Z \in \Re^{n \times k}}
+ ||A Z - B||_F^2 + \lambda ||Z||_F^2\\). If \\(m \lt n\\) then `output` is
+ computed as \\(X = A^T (A A^T + \lambda I)^{-1} B\\),
+ which (for \\(\lambda = 0\\)) is the minimum-norm solution to the
+ under-determined linear system, i.e.
+ \\(X = \mathrm{argmin}_{Z \in \Re^{n \times k}} ||Z||_F^2 \\),
+ subject to \\(A Z = B\\).
+ Notice that the fast path is only numerically stable when \\(A\\) is
+ numerically full rank and has a condition number
+ \\(\mathrm{cond}(A) \lt \frac{1}{\sqrt{\epsilon_{mach}}}\\)
+ or \\(\lambda\\) is sufficiently large.
+
+ If `fast` is `False` then the solution is computed using the rank revealing
+ QR decomposition with column pivoting. This will always compute a
+ least-squares solution that minimizes the residual norm
+ \\(||A X - B||_F^2 \\), even when \\(A\\) is rank deficient or
+ ill-conditioned. Notice: The current version does not compute a minimum norm
+ solution. If `fast` is `False` then `l2_regularizer` is ignored.
+
+ Args:
+ matrix: 2-D `Tensor` of shape `[M, N]`.
+ rhs: 2-D `Tensor` of shape `[M, K]`.
+ l2_regularizer: 0-D `double` `Tensor`. Ignored if `fast=False`.
+ fast: bool. Defaults to `True`.
+ name: string, optional name of the operation.
+
+ Returns:
+ output: Matrix of shape `[N, K]` containing the matrix that solves
+ `matrix * output = rhs` in the least-squares sense.
+ """
+ return gen_linalg_ops.matrix_solve_ls(matrix,
+ rhs,
+ l2_regularizer,
+ fast=fast,
+ name=name)
+
+
+def batch_matrix_solve_ls(matrix,
+ rhs,
+ l2_regularizer=0.0,
+ fast=True,
+ name=None):
+ r"""Solves multiple linear least-squares problems.
+
+ `matrix` is a tensor of shape `[..., M, N]` whose inner-most 2 dimensions
+ form `M`-by-`N` matrices. Rhs is a tensor of shape `[..., M, K]` whose
+ inner-most 2 dimensions form `M`-by-`K` matrices. The computed output is a
+ `Tensor` of shape `[..., N, K]` whose inner-most 2 dimensions form `N`-by-`K`
+ matrices that solve the equations
+ `matrix[..., :, :] * output[..., :, :] = rhs[..., :, :]` in the least squares
+ sense.
+
+ Below we will use the following notation for each pair of
+ matrix and right-hand sides in the batch:
+
+ `matrix`=\\(A \in \Re^{m \times n}\\),
+ `rhs`=\\(B \in \Re^{m \times k}\\),
+ `output`=\\(X \in \Re^{n \times k}\\),
+ `l2_regularizer`=\\(\lambda\\).
+
+ If `fast` is `True`, then the solution is computed by solving the normal
+ equations using Cholesky decomposition. Specifically, if \\(m \ge n\\) then
+ \\(X = (A^T A + \lambda I)^{-1} A^T B\\), which solves the least-squares
+ problem \\(X = \mathrm{argmin}_{Z \in \Re^{n \times k}} ||A Z - B||_F^2 +
+ \lambda ||Z||_F^2\\). If \\(m \lt n\\) then `output` is computed as
+ \\(X = A^T (A A^T + \lambda I)^{-1} B\\), which (for \\(\lambda = 0\\)) is
+ the minimum-norm solution to the under-determined linear system, i.e.
+ \\(X = \mathrm{argmin}_{Z \in \Re^{n \times k}} ||Z||_F^2 \\), subject to
+ \\(A Z = B\\). Notice that the fast path is only numerically stable when
+ \\(A\\) is numerically full rank and has a condition number
+ \\(\mathrm{cond}(A) \lt \frac{1}{\sqrt{\epsilon_{mach}}}\\) or \\(\lambda\\)
+ is sufficiently large.
+
+ If `fast` is `False` then the solution is computed using the rank revealing
+ QR decomposition with column pivoting. This will always compute a
+ least-squares solution that minimizes the residual norm \\(||A X - B||_F^2\\),
+ even when \\(A\\) is rank deficient or ill-conditioned. Notice: The current
+ version does not compute a minimum norm solution. If `fast` is `False` then
+ `l2_regularizer` is ignored.
+
+ Args:
+ matrix: `Tensor` of shape `[..., M, N]`.
+ rhs: `Tensor` of shape `[..., M, K]`.
+ l2_regularizer: 0-D `double` `Tensor`. Ignored if `fast=False`.
+ fast: bool. Defaults to `True`.
+ name: string, optional name of the operation.
+
+ Returns:
+ output: `Tensor` of shape `[..., N, K]` whose inner-most 2 dimensions form
+ `N`-by-`K` matrices that solve the equations
+ `matrix[..., :, :] * output[..., :, :] = rhs[..., :, :]` in the least
+ squares sense.
+ """
+ return gen_linalg_ops.batch_matrix_solve_ls(matrix,
+ rhs,
+ l2_regularizer,
+ fast=fast,
+ name=name)
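A usage sketch of the batch wrapper defined above, with random data and arbitrary sizes.

```python
import numpy as np
import tensorflow as tf

# Five independent 10x3 least-squares problems, two right-hand sides each.
a = np.random.randn(5, 10, 3)
b = np.random.randn(5, 10, 2)

with tf.Session() as sess:
    x = sess.run(tf.batch_matrix_solve_ls(a, b, l2_regularizer=0.01))
print(x.shape)  # (5, 3, 2)
```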
diff --git a/tensorflow/python/ops/math_ops.py b/tensorflow/python/ops/math_ops.py
index b6f8b01c53..fc94b38ce8 100644
--- a/tensorflow/python/ops/math_ops.py
+++ b/tensorflow/python/ops/math_ops.py
@@ -82,6 +82,9 @@ mathematical functions for matrices to your graph.
@@matrix_triangular_solve
@@batch_matrix_triangular_solve
+@@matrix_solve_ls
+@@batch_matrix_solve_ls
+
## Complex Number Functions
TensorFlow provides several operations that you can use to add complex number