path: root/tensorflow/core/kernels/matrix_solve_ls_op_impl.h
author    A. Unique TensorFlower <gardener@tensorflow.org>    2017-08-18 13:24:55 -0700
committer TensorFlower Gardener <gardener@tensorflow.org>    2017-08-18 13:29:08 -0700
commit    109ecf823d1ba2b9570102e09af87db4c6202b73 (patch)
tree      929ea650b2dc240931258bb5e38c13f3abe34ab1 /tensorflow/core/kernels/matrix_solve_ls_op_impl.h
parent    51441302d45c50750c42727af7ecc9bc709ae0f3 (diff)
Add support for complex in matrix_solve_ls_op.
Split into separate files for each data type to speed up build.

PiperOrigin-RevId: 165744539
Diffstat (limited to 'tensorflow/core/kernels/matrix_solve_ls_op_impl.h')
-rw-r--r--    tensorflow/core/kernels/matrix_solve_ls_op_impl.h    161
1 file changed, 161 insertions, 0 deletions
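
The commit message above notes that the op was split into separate files per data type to speed up the build; those registration files are not shown in this diffstat. As a rough sketch only (the file layout and exact contents here are assumptions, not part of this commit), a per-type registration translation unit built on top of this header could look like:

#include "tensorflow/core/kernels/matrix_solve_ls_op_impl.h"

namespace tensorflow {

// Hypothetical per-type registration unit: instantiating the MatrixSolveLsOp
// template once per scalar type keeps each compile small. REGISTER_LINALG_OP
// is the registration macro from linalg_ops_common.h, which the header below
// includes.
REGISTER_LINALG_OP("MatrixSolveLs", (MatrixSolveLsOp<float>), float);
REGISTER_LINALG_OP("MatrixSolveLs", (MatrixSolveLsOp<double>), double);

}  // namespace tensorflow
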
diff --git a/tensorflow/core/kernels/matrix_solve_ls_op_impl.h b/tensorflow/core/kernels/matrix_solve_ls_op_impl.h
new file mode 100644
index 0000000000..0e09078365
--- /dev/null
+++ b/tensorflow/core/kernels/matrix_solve_ls_op_impl.h
@@ -0,0 +1,161 @@
+/* Copyright 2015 The TensorFlow Authors. All Rights Reserved.
+
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+ http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License.
+==============================================================================*/
+
+// See docs in ../ops/linalg_ops.cc.
+
+#include "third_party/eigen3/Eigen/Cholesky"
+#include "third_party/eigen3/Eigen/Core"
+#include "third_party/eigen3/Eigen/QR"
+#include "tensorflow/core/framework/kernel_def_builder.h"
+#include "tensorflow/core/framework/op_kernel.h"
+#include "tensorflow/core/framework/tensor_shape.h"
+#include "tensorflow/core/kernels/linalg_ops_common.h"
+#include "tensorflow/core/lib/core/errors.h"
+#include "tensorflow/core/platform/logging.h"
+#include "tensorflow/core/platform/types.h"
+
+namespace tensorflow {
+
+template <class Scalar>
+class MatrixSolveLsOp : public LinearAlgebraOp<Scalar> {
+ public:
+ typedef LinearAlgebraOp<Scalar> Base;
+
+ explicit MatrixSolveLsOp(OpKernelConstruction* context) : Base(context) {
+ OP_REQUIRES_OK(context, context->GetAttr("fast", &fast_));
+ }
+
+ using TensorShapes = typename Base::TensorShapes;
+ using Matrix = typename Base::Matrix;
+ using MatrixMaps = typename Base::MatrixMaps;
+ using ConstMatrixMap = typename Base::ConstMatrixMap;
+ using ConstMatrixMaps = typename Base::ConstMatrixMaps;
+
+ // Tell the base class to ignore the regularization parameter
+ // in context->input(2).
+ int NumMatrixInputs(const OpKernelContext* context) const final { return 2; }
+
+ void ValidateInputMatrixShapes(
+ OpKernelContext* context,
+ const TensorShapes& input_matrix_shapes) const final {
+ Base::ValidateSolver(context, input_matrix_shapes);
+ }
+
+ TensorShapes GetOutputMatrixShapes(
+ const TensorShapes& input_matrix_shapes) const final {
+ return TensorShapes({TensorShape({input_matrix_shapes[0].dim_size(1),
+ input_matrix_shapes[1].dim_size(1)})});
+ }
+
+ int64 GetCostPerUnit(const TensorShapes& input_matrix_shapes) const final {
+ double m = static_cast<double>(input_matrix_shapes[0].dim_size(0));
+ double n = static_cast<double>(input_matrix_shapes[0].dim_size(1));
+ double num_rhss = static_cast<double>(input_matrix_shapes[1].dim_size(1));
+ double cost = std::max(m, n) * std::min(m, n) * (std::min(m, n) + num_rhss);
+ return cost >= static_cast<double>(kint64max) ? kint64max
+ : static_cast<int64>(cost);
+ }
+
+ bool EnableInputForwarding() const final { return false; }
+
+ void ComputeMatrix(OpKernelContext* context, const ConstMatrixMaps& inputs,
+ MatrixMaps* outputs) final {
+ const ConstMatrixMap& matrix = inputs[0];
+ const ConstMatrixMap& rhs = inputs[1];
+ const auto& l2_regularizer_in = context->input(2);
+ OP_REQUIRES(
+ context, TensorShapeUtils::IsScalar(l2_regularizer_in.shape()),
+ errors::InvalidArgument("l2_regularizer must be scalar, got shape ",
+ l2_regularizer_in.shape().DebugString()));
+ const double l2_regularizer = l2_regularizer_in.scalar<double>()();
+ OP_REQUIRES(context, l2_regularizer >= 0,
+ errors::InvalidArgument("l2_regularizer must be >= 0."));
+
+ const int64 rows = matrix.rows();
+ const int64 cols = matrix.cols();
+ if (rows == 0 || cols == 0) {
+ // The result is the empty matrix.
+ return;
+ }
+ if (fast_) {
+ // The fast branch assumes that the matrix is not rank deficient and
+ // not too ill-conditioned. Specifically, the reciprocal condition number
+ // should be greater than the square root of the machine precision, i.e.
+ // 1 / cond(matrix) > sqrt(std::numeric_limits<Scalar>::epsilon()).
+ // This branch solves over- or underdetermined least-squares problems
+ // via the normal equations and Cholesky decomposition.
+ if (matrix.rows() >= matrix.cols()) {
+ // Overdetermined case (rows >= cols): Solves the ordinary (possibly
+ // regularized) least-squares problem
+ // min || A * X - RHS ||_F^2 + l2_regularizer ||X||_F^2
+ // by solving the normal equations
+ // (A^* * A + l2_regularizer * I) X = A^* RHS
+ // using Cholesky decomposition.
+ Matrix gramian(cols, cols);
+ gramian.template triangularView<Eigen::Lower>() =
+ matrix.adjoint() * matrix;
+ if (l2_regularizer > 0) {
+ gramian +=
+ (Scalar(l2_regularizer) * Matrix::Ones(cols, 1)).asDiagonal();
+ }
+ const Eigen::LLT<Eigen::Ref<Matrix>, Eigen::Lower> llt(gramian);
+ OP_REQUIRES(
+ context, llt.info() == Eigen::Success,
+ errors::InvalidArgument("Input matrix was rank deficient or "
+ "ill-conditioned. Try setting fast=False "
+ "or provide a larger l2_regularizer > 0."));
+ outputs->at(0).noalias() = matrix.adjoint() * rhs;
+ llt.solveInPlace(outputs->at(0));
+ } else {
+ // Underdetermined case (rows < cols): Solves the minimum-norm problem
+ // min ||X||_F^2 s.t. A*X = RHS
+ // by solving the normal equations of the second kind
+ // (A * A^* + l2_regularizer * I) Z = RHS, X = A^* * Z
+ // using Cholesky decomposition.
+ Matrix gramian(rows, rows);
+ gramian.template triangularView<Eigen::Lower>() =
+ matrix * matrix.adjoint();
+ if (l2_regularizer > 0) {
+ gramian +=
+ (Scalar(l2_regularizer) * Matrix::Ones(rows, 1)).asDiagonal();
+ }
+ const Eigen::LLT<Eigen::Ref<Matrix>, Eigen::Lower> llt(gramian);
+ OP_REQUIRES(
+ context, llt.info() == Eigen::Success,
+ errors::InvalidArgument("Input matrix was rank deficient or "
+ "ill-conditioned. Try setting fast=False "
+ "or provide an l2_regularizer > 0."));
+ outputs->at(0).noalias() = matrix.adjoint() * llt.solve(rhs);
+ }
+ } else {
+ // Use complete orthogonal decomposition, which is backward stable and
+ // will compute the minimum-norm solution for rank-deficient matrices.
+ // This is 6-7 times slower than the fast path.
+ //
+ // TODO(rmlarsen): The implementation of
+ // Eigen::CompleteOrthogonalDecomposition is not blocked, so for
+ // matrices that do not fit in cache, it is significantly slower than
+ // the equivalent blocked LAPACK routine xGELSY (e.g. Eigen is ~3x
+ // slower for 4k x 4k matrices).
+ // See http://www.netlib.org/lapack/lawnspdf/lawn114.pdf
+ outputs->at(0) = matrix.completeOrthogonalDecomposition().solve(rhs);
+ }
+ }
+
+ private:
+ bool fast_;
+};
+
+} // namespace tensorflow
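
For reference, a standalone, self-contained sketch (not part of the commit) of the two Cholesky-based strategies used in ComputeMatrix, compared against Eigen's CompleteOrthogonalDecomposition fallback. Matrix sizes and data are arbitrary illustration values; with no regularization and well-conditioned inputs, each fast-path solution should agree with the COD solution up to roundoff.

#include <Eigen/Dense>
#include <iostream>

int main() {
  // Overdetermined case (rows >= cols): solve the normal equations
  //   (A^* A) x = A^* rhs
  // with a Cholesky (LLT) factorization, as in the fast_ branch above.
  {
    Eigen::MatrixXd A = Eigen::MatrixXd::Random(6, 3);
    Eigen::VectorXd rhs = Eigen::VectorXd::Random(6);
    Eigen::MatrixXd gramian = A.adjoint() * A;
    Eigen::LLT<Eigen::MatrixXd> llt(gramian);
    Eigen::VectorXd x = llt.solve(A.adjoint() * rhs);
    Eigen::VectorXd x_cod = A.completeOrthogonalDecomposition().solve(rhs);
    std::cout << "overdetermined |x - x_cod| = " << (x - x_cod).norm() << "\n";
  }

  // Underdetermined case (rows < cols): normal equations of the second kind
  //   (A A^*) z = rhs,  x = A^* z,
  // which gives the minimum-norm solution when A has full row rank.
  {
    Eigen::MatrixXd A = Eigen::MatrixXd::Random(3, 6);
    Eigen::VectorXd rhs = Eigen::VectorXd::Random(3);
    Eigen::MatrixXd gramian = A * A.adjoint();
    Eigen::LLT<Eigen::MatrixXd> llt(gramian);
    Eigen::VectorXd x = A.adjoint() * llt.solve(rhs);
    Eigen::VectorXd x_cod = A.completeOrthogonalDecomposition().solve(rhs);
    std::cout << "underdetermined |x - x_cod| = " << (x - x_cod).norm() << "\n";
  }
  return 0;
}

Both fast-path solves square the condition number of A, which is why the comments above require 1 / cond(matrix) > sqrt(epsilon) and why the backward-stable complete orthogonal decomposition remains the default when fast=False.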