aboutsummaryrefslogtreecommitdiffhomepage
diff options
context:
space:
mode:
authorGravatar A. Unique TensorFlower <gardener@tensorflow.org>2016-07-18 16:55:42 -0800
committerGravatar TensorFlower Gardener <gardener@tensorflow.org>2016-07-18 18:01:59 -0700
commit6696de2b02d9f2c6d16fd39ef388be87309525bb (patch)
tree6c5e7585992ba21fcad4686550fc898faa443f41
parent13a2e787aa8f5fac6b7269fbd86e377bdc9cf8af (diff)
Add C++ shape inference functions for linalg_ops.cc.
Also changed python inference function for MatrixSolve and MatrixSolveLs to use with_rank instead of with_rank_at_least for rhs. Change: 127781866
-rw-r--r--tensorflow/core/BUILD1
-rw-r--r--tensorflow/core/ops/linalg_ops.cc177
-rw-r--r--tensorflow/core/ops/linalg_ops_test.cc189
-rw-r--r--tensorflow/python/ops/linalg_ops.py4
4 files changed, 368 insertions, 3 deletions
diff --git a/tensorflow/core/BUILD b/tensorflow/core/BUILD
index d60f167229..8eb38d2104 100644
--- a/tensorflow/core/BUILD
+++ b/tensorflow/core/BUILD
@@ -1856,6 +1856,7 @@ tf_cc_tests(
"ops/array_ops_test.cc",
"ops/data_flow_ops_test.cc",
"ops/io_ops_test.cc",
+ "ops/linalg_ops_test.cc",
"ops/math_ops_test.cc",
"ops/nn_ops_test.cc",
"ops/training_ops_test.cc",
diff --git a/tensorflow/core/ops/linalg_ops.cc b/tensorflow/core/ops/linalg_ops.cc
index 2c8bde919d..3b3681ed82 100644
--- a/tensorflow/core/ops/linalg_ops.cc
+++ b/tensorflow/core/ops/linalg_ops.cc
@@ -14,13 +14,121 @@ limitations under the License.
==============================================================================*/
#include "tensorflow/core/framework/op.h"
+#include "tensorflow/core/framework/shape_inference.h"
namespace tensorflow {
+typedef shape_inference::Dimension Dimension;
+typedef shape_inference::InferenceContext InferenceContext;
+typedef shape_inference::Shape Shape;
+static constexpr auto kUnknownDim = InferenceContext::kUnknownDim;
+
+namespace {
+
+// Return in <out> the result of making <s> a square matrix.
+Status MakeSquareMatrix(InferenceContext* c, const Shape* s,
+ const Shape** out) {
+ TF_RETURN_IF_ERROR(c->WithRank(s, 2, &s));
+ const Dimension* d;
+ TF_RETURN_IF_ERROR(c->Merge(c->Dim(s, 0), c->Dim(s, 1), &d));
+ *out = c->Matrix(d, d);
+ return Status::OK();
+}
+
+Status UnchangedSquareShapeFn(InferenceContext* c) {
+ const Shape* out;
+ TF_RETURN_IF_ERROR(MakeSquareMatrix(c, c->input(0), &out));
+ c->set_output(0, out);
+ return Status::OK();
+}
+
+// Return in <out> the result of making the end of <s> a square matrix.
+Status MakeBatchSquareMatrix(InferenceContext* c, const Shape* input,
+ const Shape** out) {
+ const Shape* s;
+ TF_RETURN_IF_ERROR(c->WithRankAtLeast(input, 2, &s));
+
+ const Dimension* d;
+ TF_RETURN_IF_ERROR(c->Merge(c->Dim(s, -2), c->Dim(s, -1), &d));
+
+ const Shape* batch_shape;
+ TF_RETURN_IF_ERROR(c->Subshape(s, 0, -2, &batch_shape));
+ TF_RETURN_IF_ERROR(c->Concatenate(batch_shape, c->Matrix(d, d), out));
+ return Status::OK();
+}
+
+Status BatchUnchangedSquareShapeFn(InferenceContext* c) {
+ const Shape* out;
+ TF_RETURN_IF_ERROR(MakeBatchSquareMatrix(c, c->input(0), &out));
+ c->set_output(0, out);
+ return Status::OK();
+}
+
+Status SquareMatrixSolveShapeFn(InferenceContext* c) {
+ const Shape* lhs;
+ const Shape* rhs;
+ TF_RETURN_IF_ERROR(MakeSquareMatrix(c, c->input(0), &lhs));
+ TF_RETURN_IF_ERROR(c->WithRank(c->input(1), 2, &rhs));
+
+ // lhs and rhs have the same number of rows. Make a new output
+ // shape that has the merged-rows and the rest of the rhs.
+ const Dimension* rows;
+ TF_RETURN_IF_ERROR(c->Merge(c->Dim(lhs, 0), c->Dim(rhs, 0), &rows));
+ const Shape* rhs_remaining;
+ TF_RETURN_IF_ERROR(c->Subshape(rhs, 1, &rhs_remaining));
+ TF_RETURN_IF_ERROR(c->Concatenate(c->Vector(rows), rhs_remaining, &rhs));
+ c->set_output(0, rhs);
+ return Status::OK();
+}
+
+// Inputs are [...,M,N] and [...,M,K]. Output is [...,N,K].
+// If <square>, then input is [...,M,M].
+Status BatchMatrixSolveShapeFn(InferenceContext* c, bool square) {
+ const Shape* lhs;
+ const Shape* rhs;
+ if (square) {
+ TF_RETURN_IF_ERROR(MakeBatchSquareMatrix(c, c->input(0), &lhs));
+ } else {
+ TF_RETURN_IF_ERROR(c->WithRankAtLeast(c->input(0), 2, &lhs));
+ }
+ TF_RETURN_IF_ERROR(c->WithRankAtLeast(c->input(1), 2, &rhs));
+
+ // Make the common batch subshape between the two dimensions.
+ const Shape* lhs_batch_shape;
+ const Shape* batch_shape;
+ TF_RETURN_IF_ERROR(c->Subshape(lhs, 0, -2, &lhs_batch_shape));
+ TF_RETURN_IF_ERROR(c->Subshape(rhs, 0, -2, &batch_shape));
+ TF_RETURN_IF_ERROR(c->Merge(lhs_batch_shape, batch_shape, &batch_shape));
+
+ // lhs and rhs have the same value for m.
+ const Dimension* m;
+ TF_RETURN_IF_ERROR(c->Merge(c->Dim(lhs, -2), c->Dim(rhs, -2), &m));
+
+ const Dimension* n = c->Dim(lhs, -1);
+ if (square) {
+ TF_RETURN_IF_ERROR(c->Merge(m, n, &n));
+ }
+
+ // Build final shape (batch_shape + n + k) in <out>.
+ const Shape* out;
+ TF_RETURN_IF_ERROR(c->Concatenate(batch_shape, c->Vector(n), &out));
+ TF_RETURN_IF_ERROR(c->Concatenate(out, c->Vector(c->Dim(rhs, -1)), &out));
+ c->set_output(0, out);
+ return Status::OK();
+}
+
+} // namespace
+
REGISTER_OP("MatrixDeterminant")
.Input("input: T")
.Output("output: T")
.Attr("T: {float, double}")
+ .SetShapeFn(OpShapeInferenceFn([](InferenceContext* c) {
+ const Shape* input;
+ TF_RETURN_IF_ERROR(MakeSquareMatrix(c, c->input(0), &input));
+ c->set_output(0, c->Scalar());
+ return Status::OK();
+ }))
.Doc(R"doc(
Calculates the determinant of a square matrix.
@@ -32,11 +140,24 @@ REGISTER_OP("BatchMatrixDeterminant")
.Input("input: T")
.Output("output: T")
.Attr("T: {float, double}")
+ .SetShapeFn(OpShapeInferenceFn([](InferenceContext* c) {
+ const Shape* input;
+ TF_RETURN_IF_ERROR(c->WithRankAtLeast(c->input(0), 2, &input));
+
+ const Dimension* unused;
+ TF_RETURN_IF_ERROR(
+ c->Merge(c->Dim(input, -1), c->Dim(input, -2), &unused));
+
+ const Shape* out;
+ TF_RETURN_IF_ERROR(c->Subshape(input, 0, -2, &out));
+ c->set_output(0, out);
+ return Status::OK();
+ }))
.Doc(R"doc(
Calculates the determinants for a batch of square matrices.
The input is a tensor of shape `[..., M, M]` whose inner-most 2 dimensions
-form square matrices. The output is a 1-D tensor containing the determinants
+form square matrices. The output is a tensor containing the determinants
for all input submatrices `[..., :, :]`.
input: Shape is `[..., M, M]`.
@@ -48,6 +169,7 @@ REGISTER_OP("MatrixInverse")
.Output("output: T")
.Attr("adjoint: bool = False")
.Attr("T: {double, float}")
+ .SetShapeFn(OpShapeInferenceFn(UnchangedSquareShapeFn))
.Doc(R"doc(
Calculates the inverse of a square invertible matrix or its adjoint (conjugate
transpose).
@@ -69,6 +191,7 @@ REGISTER_OP("BatchMatrixInverse")
.Output("output: T")
.Attr("adjoint: bool = False")
.Attr("T: {double, float}")
+ .SetShapeFn(OpShapeInferenceFn(BatchUnchangedSquareShapeFn))
.Doc(R"doc(
Calculates the inverse of square invertible matrices or their adjoints
(conjugate transposes).
@@ -91,6 +214,7 @@ REGISTER_OP("Cholesky")
.Input("input: T")
.Output("output: T")
.Attr("T: {double, float}")
+ .SetShapeFn(OpShapeInferenceFn(UnchangedSquareShapeFn))
.Doc(R"doc(
Calculates the Cholesky decomposition of a square matrix.
@@ -109,6 +233,7 @@ REGISTER_OP("BatchCholesky")
.Input("input: T")
.Output("output: T")
.Attr("T: {double, float}")
+ .SetShapeFn(OpShapeInferenceFn(BatchUnchangedSquareShapeFn))
.Doc(R"doc(
Calculates the Cholesky decomposition of a batch of square matrices.
@@ -126,6 +251,7 @@ REGISTER_OP("CholeskyGrad")
.Input("grad: T")
.Output("output: T")
.Attr("T: {float, double}")
+ .SetShapeFn(OpShapeInferenceFn(UnchangedSquareShapeFn))
.Doc(R"doc(
Calculates the reverse mode backpropagated gradient of the Cholesky algorithm.
@@ -144,6 +270,7 @@ REGISTER_OP("BatchCholeskyGrad")
.Input("grad: T")
.Output("output: T")
.Attr("T: {float, double}")
+ .SetShapeFn(OpShapeInferenceFn(BatchUnchangedSquareShapeFn))
.Doc(R"doc(
Calculates the reverse mode backpropagated gradient of the Cholesky algorithm.
@@ -163,6 +290,16 @@ REGISTER_OP("SelfAdjointEig")
.Input("input: T")
.Output("output: T")
.Attr("T: {double, float}")
+ .SetShapeFn(OpShapeInferenceFn([](InferenceContext* c) {
+ const Shape* input;
+ TF_RETURN_IF_ERROR(MakeSquareMatrix(c, c->input(0), &input));
+
+ const Dimension* d = c->Dim(input, 0);
+ const Dimension* d_plus_1;
+ TF_RETURN_IF_ERROR(c->Add(d, 1, &d_plus_1));
+ c->set_output(0, c->Matrix(d_plus_1, d));
+ return Status::OK();
+ }))
.Doc(R"doc(
Calculates the Eigen Decomposition of a square Self-Adjoint matrix.
@@ -180,6 +317,20 @@ REGISTER_OP("BatchSelfAdjointEig")
.Input("input: T")
.Output("output: T")
.Attr("T: {double, float}")
+ .SetShapeFn(OpShapeInferenceFn([](InferenceContext* c) {
+ const Shape* input;
+ TF_RETURN_IF_ERROR(MakeBatchSquareMatrix(c, c->input(0), &input));
+
+ const Dimension* d = c->Dim(input, -1);
+ const Dimension* d_plus_1;
+ TF_RETURN_IF_ERROR(c->Add(d, 1, &d_plus_1));
+
+ const Shape* s;
+ TF_RETURN_IF_ERROR(c->Subshape(input, 0, -2, &s));
+ TF_RETURN_IF_ERROR(c->Concatenate(s, c->Matrix(d_plus_1, d), &s));
+ c->set_output(0, s);
+ return Status::OK();
+ }))
.Doc(R"doc(
Calculates the Eigen Decomposition of a batch of square self-adjoint matrices.
@@ -200,6 +351,7 @@ REGISTER_OP("MatrixSolve")
.Output("output: T")
.Attr("adjoint: bool = False")
.Attr("T: {double, float}")
+ .SetShapeFn(OpShapeInferenceFn(SquareMatrixSolveShapeFn))
.Doc(R"doc(
Solves a system of linear equations. Checks for invertibility.
@@ -217,6 +369,9 @@ REGISTER_OP("BatchMatrixSolve")
.Output("output: T")
.Attr("adjoint: bool = False")
.Attr("T: {double, float}")
+ .SetShapeFn(OpShapeInferenceFn([](InferenceContext* c) {
+ return BatchMatrixSolveShapeFn(c, true /* square (*/);
+ }))
.Doc(R"doc(
Solves systems of linear equations. Checks for invertibility.
@@ -241,6 +396,7 @@ REGISTER_OP("MatrixTriangularSolve")
.Attr("lower: bool = True")
.Attr("adjoint: bool = False")
.Attr("T: {double, float}")
+ .SetShapeFn(OpShapeInferenceFn(SquareMatrixSolveShapeFn))
.Doc(R"doc(
Solves a system of linear equations with an upper or lower triangular matrix by
backsubstitution.
@@ -272,6 +428,9 @@ REGISTER_OP("BatchMatrixTriangularSolve")
.Attr("lower: bool = True")
.Attr("adjoint: bool = False")
.Attr("T: {double, float}")
+ .SetShapeFn(OpShapeInferenceFn([](InferenceContext* c) {
+ return BatchMatrixSolveShapeFn(c, true /* square (*/);
+ }))
.Doc(R"doc(
Solves systems of linear equations with upper or lower triangular matrices by
backsubstitution.
@@ -306,6 +465,19 @@ REGISTER_OP("MatrixSolveLs")
.Output("output: T")
.Attr("T: {double, float}")
.Attr("fast: bool = True")
+ .SetShapeFn(OpShapeInferenceFn([](InferenceContext* c) {
+ const Shape* lhs;
+ const Shape* rhs;
+ TF_RETURN_IF_ERROR(c->WithRank(c->input(0), 2, &lhs));
+ TF_RETURN_IF_ERROR(c->WithRank(c->input(1), 2, &rhs));
+
+ // The matrix and right-hand side must have the same number of rows.
+ const Dimension* unused;
+ TF_RETURN_IF_ERROR(c->Merge(c->Dim(lhs, 0), c->Dim(rhs, 0), &unused));
+
+ c->set_output(0, c->Matrix(c->Dim(lhs, 1), c->Dim(rhs, 1)));
+ return Status::OK();
+ }))
.Doc(R"doc(
Solves a linear least-squares problem.
@@ -349,6 +521,9 @@ REGISTER_OP("BatchMatrixSolveLs")
.Output("output: T")
.Attr("T: {double, float}")
.Attr("fast: bool = True")
+ .SetShapeFn(OpShapeInferenceFn([](InferenceContext* c) {
+ return BatchMatrixSolveShapeFn(c, false /* square */);
+ }))
.Doc(R"doc(
Solves multiple linear least-squares problems.
diff --git a/tensorflow/core/ops/linalg_ops_test.cc b/tensorflow/core/ops/linalg_ops_test.cc
new file mode 100644
index 0000000000..7c88fc97b0
--- /dev/null
+++ b/tensorflow/core/ops/linalg_ops_test.cc
@@ -0,0 +1,189 @@
+/* Copyright 2016 The TensorFlow Authors. All Rights Reserved.
+
+Licensed under the Apache License, Version 2.0 (the "License");
+
+You may obtain a copy of the License at
+
+ http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License.
+==============================================================================*/
+
+#include "tensorflow/core/framework/op.h"
+#include "tensorflow/core/framework/shape_inference_testutil.h"
+#include "tensorflow/core/platform/test.h"
+
+namespace tensorflow {
+
+TEST(LinalgOpsTest, MatrixDeterminant_ShapeFn) {
+ ShapeInferenceTestOp op("MatrixDeterminant");
+ INFER_OK(op, "?", "[]");
+ INFER_OK(op, "[?,?]", "[]");
+ INFER_OK(op, "[1,?]", "[]");
+ INFER_OK(op, "[?,1]", "[]");
+ INFER_ERROR("Shape must be rank 2 but is rank 1", op, "[1]");
+ INFER_ERROR("Dimensions must be equal, but are 1 and 2", op, "[1,2]");
+}
+
+TEST(LinalgOpsTest, BatchMatrixDeterminant_ShapeFn) {
+ ShapeInferenceTestOp op("BatchMatrixDeterminant");
+ INFER_OK(op, "?", "?");
+ INFER_ERROR("Shape must be at least rank 2 but is rank 1", op, "[1]");
+ INFER_ERROR("Dimensions must be equal, but are 2 and 1", op, "[1,?,3,4,1,2]");
+
+ INFER_OK(op, "[?,?]", "[]");
+ INFER_OK(op, "[1,?]", "[]");
+ INFER_OK(op, "[?,1]", "[]");
+
+ // Repeat previous block of tests with input rank > 2.
+ INFER_OK(op, "[1,?,3,4,?,?]", "[d0_0,d0_1,d0_2,d0_3]");
+ INFER_OK(op, "[1,?,3,4,1,?]", "[d0_0,d0_1,d0_2,d0_3]");
+ INFER_OK(op, "[1,?,3,4,?,1]", "[d0_0,d0_1,d0_2,d0_3]");
+}
+
+TEST(LinalgOpsTest, UnchangedSquare_ShapeFn) {
+ for (const char* op_name : {"Cholesky", "CholeskyGrad", "MatrixInverse"}) {
+ ShapeInferenceTestOp op(op_name);
+ INFER_OK(op, "?", "[?,?]");
+ INFER_OK(op, "[?,?]", "[d0_0|d0_1,d0_0|d0_1]");
+ INFER_OK(op, "[1,?]", "[d0_0,d0_0]");
+ INFER_OK(op, "[?,1]", "[d0_1,d0_1]");
+ INFER_ERROR("Shape must be rank 2 but is rank 1", op, "[1]");
+ INFER_ERROR("Dimensions must be equal, but are 1 and 2", op, "[1,2]");
+ }
+}
+
+TEST(LinalgOpsTest, BatchUnchangedSquare_ShapeFn) {
+ for (const char* op_name :
+ {"BatchCholesky", "BatchCholeskyGrad", "BatchMatrixInverse"}) {
+ ShapeInferenceTestOp op(op_name);
+ INFER_OK(op, "?", "?");
+ INFER_ERROR("Shape must be at least rank 2 but is rank 1", op, "[1]");
+ INFER_ERROR("Dimensions must be equal, but are 1 and 2", op, "[1,2]");
+
+ INFER_OK(op, "[?,?]", "[d0_0|d0_1,d0_0|d0_1]");
+ INFER_OK(op, "[1,?]", "[d0_0,d0_0]");
+ INFER_OK(op, "[?,1]", "[d0_1,d0_1]");
+
+ // Repeat previous block of tests with input rank > 2.
+ INFER_OK(op, "[5,?,7,?,?]", "[d0_0,d0_1,d0_2,d0_3|d0_4,d0_3|d0_4]");
+ INFER_OK(op, "[5,?,7,1,?]", "[d0_0,d0_1,d0_2,d0_3,d0_3]");
+ INFER_OK(op, "[5,?,7,?,1]", "[d0_0,d0_1,d0_2,d0_4,d0_4]");
+ }
+}
+
+TEST(LinalgOpsTest, SelfAdjointEig_ShapeFn) {
+ ShapeInferenceTestOp op("SelfAdjointEig");
+ INFER_OK(op, "?", "[?,?]");
+ INFER_OK(op, "[?,?]", "[?,d0_0|d0_1]");
+ INFER_OK(op, "[1,?]", "[2,d0_0]");
+ INFER_OK(op, "[?,1]", "[2,d0_1]");
+ INFER_ERROR("Shape must be rank 2 but is rank 1", op, "[1]");
+ INFER_ERROR("Dimensions must be equal, but are 1 and 2", op, "[1,2]");
+}
+
+TEST(LinalgOpsTest, BatchSelfAdjointEig_ShapeFn) {
+ ShapeInferenceTestOp op("BatchSelfAdjointEig");
+ INFER_OK(op, "?", "?");
+ INFER_ERROR("Shape must be at least rank 2 but is rank 1", op, "[1]");
+ INFER_ERROR("Dimensions must be equal, but are 1 and 2", op, "[1,2]");
+
+ INFER_OK(op, "[?,?]", "[?,d0_0|d0_1]");
+ INFER_OK(op, "[1,?]", "[2,d0_0]");
+ INFER_OK(op, "[?,1]", "[2,d0_1]");
+
+ // Repeat previous block of tests with input rank > 2.
+ INFER_OK(op, "[5,?,7,?,?]", "[d0_0,d0_1,d0_2,?,d0_3|d0_4]");
+ INFER_OK(op, "[5,?,7,1,?]", "[d0_0,d0_1,d0_2,2,d0_3]");
+ INFER_OK(op, "[5,?,7,?,1]", "[d0_0,d0_1,d0_2,2,d0_4]");
+}
+
+TEST(LinalgOpsTest, SquareMatrixSolve_ShapeFn) {
+ for (const char* op_name : {"MatrixSolve", "MatrixTriangularSolve"}) {
+ ShapeInferenceTestOp op(op_name);
+ INFER_OK(op, "?;?", "[?,?]");
+ INFER_OK(op, "[?,?];?", "[d0_0,?]");
+
+ // Inputs are [M,M] and [M,K]. Output is [M,K].
+ INFER_OK(op, "[?,?];[1,?]", "[d1_0,d1_1]");
+ INFER_OK(op, "[1,?];[1,?]", "[d0_0|d1_0,d1_1]");
+ INFER_OK(op, "[?,1];[1,?]", "[d0_1|d1_0,d1_1]");
+ INFER_OK(op, "[1,1];[?,?]", "[d0_0|d0_1,d1_1]");
+ INFER_OK(op, "[1,1];[1,?]", "[d0_0|d0_1|d1_0,d1_1]");
+ INFER_ERROR("Shape must be rank 2 but is rank 1", op, "[1];?");
+ INFER_ERROR("Shape must be rank 2 but is rank 1", op, "?;[1]");
+ INFER_ERROR("Dimensions must be equal, but are 1 and 2", op, "[1,2];?");
+ }
+}
+
+TEST(LinalgOpsTest, BatchSquareMatrixSolve_ShapeFn) {
+ for (const char* op_name :
+ {"BatchMatrixSolve", "BatchMatrixTriangularSolve"}) {
+ ShapeInferenceTestOp op(op_name);
+ INFER_OK(op, "?;?", "?");
+ INFER_ERROR("Shape must be at least rank 2 but is rank 1", op, "[1];?");
+ INFER_ERROR("Shape must be at least rank 2 but is rank 1", op, "?;[1]");
+ INFER_ERROR("Dimensions must be equal, but are 1 and 2", op, "[1,2];?");
+
+ INFER_OK(op, "[?,?];?", "[d0_0|d0_1,?]");
+
+ // Inputs are [...,M,M] and [...,M,K]. Output is [...,M,K].
+ // First test where ... is empty.
+ INFER_OK(op, "[?,?];[1,?]", "[d1_0,d1_1]");
+ INFER_OK(op, "[1,?];[1,?]", "[d0_0|d1_0,d1_1]");
+ INFER_OK(op, "[?,1];[1,?]", "[d0_1|d1_0,d1_1]");
+ INFER_OK(op, "[1,1];[?,?]", "[d0_0,d1_1]");
+ INFER_OK(op, "[1,1];[1,?]", "[d0_0|d0_1|d1_0,d1_1]");
+
+ // Test with ... being 2-d.
+ INFER_OK(op, "[10,?,?,?];[?,20,1,?]", "[d0_0,d1_1,d1_2,d1_3]");
+ INFER_OK(op, "[10,?,1,?];[?,20,1,?]", "[d0_0,d1_1,d0_2|d1_2,d1_3]");
+ INFER_OK(op, "[10,?,?,1];[?,20,1,?]", "[d0_0,d1_1,d0_3|d1_2,d1_3]");
+ INFER_OK(op, "[10,?,1,1];[?,20,?,?]", "[d0_0,d1_1,d0_2,d1_3]");
+ INFER_OK(op, "[10,?,1,1];[?,20,1,?]", "[d0_0,d1_1,d0_2|d0_3|d1_2,d1_3]");
+ }
+}
+
+TEST(LinalgOpsTest, MatrixSolveLs_ShapeFn) {
+ ShapeInferenceTestOp op("MatrixSolveLs");
+ INFER_OK(op, "?;?", "[?,?]");
+
+ // Inputs are [M,N] and [M,K]. Output is [N,K]
+ INFER_OK(op, "[1,?];[1,?]", "[d0_1,d1_1]");
+ INFER_OK(op, "[1,2];[1,3]", "[d0_1,d1_1]");
+
+ // First dims must be compatible.
+ INFER_ERROR("Dimensions must be equal, but are 5 and 6", op, "[5,?];[6,?]");
+
+ // Rank checks.
+ INFER_ERROR("Shape must be rank 2 but is rank 1", op, "[1];?");
+ INFER_ERROR("Shape must be rank 2 but is rank 1", op, "?;[1]");
+}
+
+TEST(LinalgOpsTest, BatchMatrixSolveLs_ShapeFn) {
+ ShapeInferenceTestOp op("BatchMatrixSolveLs");
+ INFER_OK(op, "?;?", "?");
+
+ // Inputs are [...,M,N] and [...,M,K]. Output is [...,N,K]
+
+ // Test with no batch dims.
+ INFER_OK(op, "[1,?];[1,?]", "[d0_1,d1_1]");
+ INFER_OK(op, "[1,2];[1,3]", "[d0_1,d1_1]");
+ INFER_ERROR("Dimensions must be equal, but are 5 and 6", op, "[5,?];[6,?]");
+
+ // Test with batch dims.
+ INFER_OK(op, "[10,?,1,?];[?,20,1,?]", "[d0_0,d1_1,d0_3,d1_3]");
+ INFER_OK(op, "[10,20,1,2];[10,20,1,3]", "[d0_0|d1_0,d0_1|d1_1,d0_3,d1_3]");
+ INFER_ERROR("Dimensions must be equal, but are 5 and 6", op,
+ "[10,?,5,?];[?,20,6,?]");
+
+ // Rank checks.
+ INFER_ERROR("Shape must be at least rank 2 but is rank 1", op, "[1];?");
+ INFER_ERROR("Shape must be at least rank 2 but is rank 1", op, "?;[1]");
+}
+
+} // end namespace tensorflow
diff --git a/tensorflow/python/ops/linalg_ops.py b/tensorflow/python/ops/linalg_ops.py
index 861c04dd63..0e76f772ca 100644
--- a/tensorflow/python/ops/linalg_ops.py
+++ b/tensorflow/python/ops/linalg_ops.py
@@ -93,7 +93,7 @@ def _BatchSelfAdjointEigShape(op):
@ops.RegisterShape("MatrixTriangularSolve")
def _SquareMatrixSolveShape(op):
lhs_shape = op.inputs[0].get_shape().with_rank(2)
- rhs_shape = op.inputs[1].get_shape().with_rank_at_least(2)
+ rhs_shape = op.inputs[1].get_shape().with_rank(2)
# The matrix must be square.
lhs_shape[0].assert_is_compatible_with(lhs_shape[1])
# The matrix and right-hand side must have the same number of rows.
@@ -117,7 +117,7 @@ def _BatchSquareMatrixSolveShape(op):
@ops.RegisterShape("MatrixSolveLs")
def _MatrixSolveLsShape(op):
lhs_shape = op.inputs[0].get_shape().with_rank(2)
- rhs_shape = op.inputs[1].get_shape().with_rank_at_least(2)
+ rhs_shape = op.inputs[1].get_shape().with_rank(2)
# The matrix and right-hand side must have the same number of rows.
lhs_shape[0].assert_is_compatible_with(rhs_shape[0])
return [[lhs_shape[1], rhs_shape[1]]]