From 6696de2b02d9f2c6d16fd39ef388be87309525bb Mon Sep 17 00:00:00 2001 From: "A. Unique TensorFlower" Date: Mon, 18 Jul 2016 16:55:42 -0800 Subject: Add C++ shape inference functions for linalg_ops.cc. Also changed python inference function for MatrixSolve and MatrixSolveLs to use with_rank instead of with_rank_at_least for rhs. Change: 127781866 --- tensorflow/core/BUILD | 1 + tensorflow/core/ops/linalg_ops.cc | 177 +++++++++++++++++++++++++++++- tensorflow/core/ops/linalg_ops_test.cc | 189 +++++++++++++++++++++++++++++++++ tensorflow/python/ops/linalg_ops.py | 4 +- 4 files changed, 368 insertions(+), 3 deletions(-) create mode 100644 tensorflow/core/ops/linalg_ops_test.cc diff --git a/tensorflow/core/BUILD b/tensorflow/core/BUILD index d60f167229..8eb38d2104 100644 --- a/tensorflow/core/BUILD +++ b/tensorflow/core/BUILD @@ -1856,6 +1856,7 @@ tf_cc_tests( "ops/array_ops_test.cc", "ops/data_flow_ops_test.cc", "ops/io_ops_test.cc", + "ops/linalg_ops_test.cc", "ops/math_ops_test.cc", "ops/nn_ops_test.cc", "ops/training_ops_test.cc", diff --git a/tensorflow/core/ops/linalg_ops.cc b/tensorflow/core/ops/linalg_ops.cc index 2c8bde919d..3b3681ed82 100644 --- a/tensorflow/core/ops/linalg_ops.cc +++ b/tensorflow/core/ops/linalg_ops.cc @@ -14,13 +14,121 @@ limitations under the License. ==============================================================================*/ #include "tensorflow/core/framework/op.h" +#include "tensorflow/core/framework/shape_inference.h" namespace tensorflow { +typedef shape_inference::Dimension Dimension; +typedef shape_inference::InferenceContext InferenceContext; +typedef shape_inference::Shape Shape; +static constexpr auto kUnknownDim = InferenceContext::kUnknownDim; + +namespace { + +// Return in the result of making a square matrix. +Status MakeSquareMatrix(InferenceContext* c, const Shape* s, + const Shape** out) { + TF_RETURN_IF_ERROR(c->WithRank(s, 2, &s)); + const Dimension* d; + TF_RETURN_IF_ERROR(c->Merge(c->Dim(s, 0), c->Dim(s, 1), &d)); + *out = c->Matrix(d, d); + return Status::OK(); +} + +Status UnchangedSquareShapeFn(InferenceContext* c) { + const Shape* out; + TF_RETURN_IF_ERROR(MakeSquareMatrix(c, c->input(0), &out)); + c->set_output(0, out); + return Status::OK(); +} + +// Return in the result of making the end of a square matrix. +Status MakeBatchSquareMatrix(InferenceContext* c, const Shape* input, + const Shape** out) { + const Shape* s; + TF_RETURN_IF_ERROR(c->WithRankAtLeast(input, 2, &s)); + + const Dimension* d; + TF_RETURN_IF_ERROR(c->Merge(c->Dim(s, -2), c->Dim(s, -1), &d)); + + const Shape* batch_shape; + TF_RETURN_IF_ERROR(c->Subshape(s, 0, -2, &batch_shape)); + TF_RETURN_IF_ERROR(c->Concatenate(batch_shape, c->Matrix(d, d), out)); + return Status::OK(); +} + +Status BatchUnchangedSquareShapeFn(InferenceContext* c) { + const Shape* out; + TF_RETURN_IF_ERROR(MakeBatchSquareMatrix(c, c->input(0), &out)); + c->set_output(0, out); + return Status::OK(); +} + +Status SquareMatrixSolveShapeFn(InferenceContext* c) { + const Shape* lhs; + const Shape* rhs; + TF_RETURN_IF_ERROR(MakeSquareMatrix(c, c->input(0), &lhs)); + TF_RETURN_IF_ERROR(c->WithRank(c->input(1), 2, &rhs)); + + // lhs and rhs have the same number of rows. Make a new output + // shape that has the merged-rows and the rest of the rhs. + const Dimension* rows; + TF_RETURN_IF_ERROR(c->Merge(c->Dim(lhs, 0), c->Dim(rhs, 0), &rows)); + const Shape* rhs_remaining; + TF_RETURN_IF_ERROR(c->Subshape(rhs, 1, &rhs_remaining)); + TF_RETURN_IF_ERROR(c->Concatenate(c->Vector(rows), rhs_remaining, &rhs)); + c->set_output(0, rhs); + return Status::OK(); +} + +// Inputs are [...,M,N] and [...,M,K]. Output is [...,N,K]. +// If , then input is [...,M,M]. +Status BatchMatrixSolveShapeFn(InferenceContext* c, bool square) { + const Shape* lhs; + const Shape* rhs; + if (square) { + TF_RETURN_IF_ERROR(MakeBatchSquareMatrix(c, c->input(0), &lhs)); + } else { + TF_RETURN_IF_ERROR(c->WithRankAtLeast(c->input(0), 2, &lhs)); + } + TF_RETURN_IF_ERROR(c->WithRankAtLeast(c->input(1), 2, &rhs)); + + // Make the common batch subshape between the two dimensions. + const Shape* lhs_batch_shape; + const Shape* batch_shape; + TF_RETURN_IF_ERROR(c->Subshape(lhs, 0, -2, &lhs_batch_shape)); + TF_RETURN_IF_ERROR(c->Subshape(rhs, 0, -2, &batch_shape)); + TF_RETURN_IF_ERROR(c->Merge(lhs_batch_shape, batch_shape, &batch_shape)); + + // lhs and rhs have the same value for m. + const Dimension* m; + TF_RETURN_IF_ERROR(c->Merge(c->Dim(lhs, -2), c->Dim(rhs, -2), &m)); + + const Dimension* n = c->Dim(lhs, -1); + if (square) { + TF_RETURN_IF_ERROR(c->Merge(m, n, &n)); + } + + // Build final shape (batch_shape + n + k) in . + const Shape* out; + TF_RETURN_IF_ERROR(c->Concatenate(batch_shape, c->Vector(n), &out)); + TF_RETURN_IF_ERROR(c->Concatenate(out, c->Vector(c->Dim(rhs, -1)), &out)); + c->set_output(0, out); + return Status::OK(); +} + +} // namespace + REGISTER_OP("MatrixDeterminant") .Input("input: T") .Output("output: T") .Attr("T: {float, double}") + .SetShapeFn(OpShapeInferenceFn([](InferenceContext* c) { + const Shape* input; + TF_RETURN_IF_ERROR(MakeSquareMatrix(c, c->input(0), &input)); + c->set_output(0, c->Scalar()); + return Status::OK(); + })) .Doc(R"doc( Calculates the determinant of a square matrix. @@ -32,11 +140,24 @@ REGISTER_OP("BatchMatrixDeterminant") .Input("input: T") .Output("output: T") .Attr("T: {float, double}") + .SetShapeFn(OpShapeInferenceFn([](InferenceContext* c) { + const Shape* input; + TF_RETURN_IF_ERROR(c->WithRankAtLeast(c->input(0), 2, &input)); + + const Dimension* unused; + TF_RETURN_IF_ERROR( + c->Merge(c->Dim(input, -1), c->Dim(input, -2), &unused)); + + const Shape* out; + TF_RETURN_IF_ERROR(c->Subshape(input, 0, -2, &out)); + c->set_output(0, out); + return Status::OK(); + })) .Doc(R"doc( Calculates the determinants for a batch of square matrices. The input is a tensor of shape `[..., M, M]` whose inner-most 2 dimensions -form square matrices. The output is a 1-D tensor containing the determinants +form square matrices. The output is a tensor containing the determinants for all input submatrices `[..., :, :]`. input: Shape is `[..., M, M]`. @@ -48,6 +169,7 @@ REGISTER_OP("MatrixInverse") .Output("output: T") .Attr("adjoint: bool = False") .Attr("T: {double, float}") + .SetShapeFn(OpShapeInferenceFn(UnchangedSquareShapeFn)) .Doc(R"doc( Calculates the inverse of a square invertible matrix or its adjoint (conjugate transpose). @@ -69,6 +191,7 @@ REGISTER_OP("BatchMatrixInverse") .Output("output: T") .Attr("adjoint: bool = False") .Attr("T: {double, float}") + .SetShapeFn(OpShapeInferenceFn(BatchUnchangedSquareShapeFn)) .Doc(R"doc( Calculates the inverse of square invertible matrices or their adjoints (conjugate transposes). @@ -91,6 +214,7 @@ REGISTER_OP("Cholesky") .Input("input: T") .Output("output: T") .Attr("T: {double, float}") + .SetShapeFn(OpShapeInferenceFn(UnchangedSquareShapeFn)) .Doc(R"doc( Calculates the Cholesky decomposition of a square matrix. @@ -109,6 +233,7 @@ REGISTER_OP("BatchCholesky") .Input("input: T") .Output("output: T") .Attr("T: {double, float}") + .SetShapeFn(OpShapeInferenceFn(BatchUnchangedSquareShapeFn)) .Doc(R"doc( Calculates the Cholesky decomposition of a batch of square matrices. @@ -126,6 +251,7 @@ REGISTER_OP("CholeskyGrad") .Input("grad: T") .Output("output: T") .Attr("T: {float, double}") + .SetShapeFn(OpShapeInferenceFn(UnchangedSquareShapeFn)) .Doc(R"doc( Calculates the reverse mode backpropagated gradient of the Cholesky algorithm. @@ -144,6 +270,7 @@ REGISTER_OP("BatchCholeskyGrad") .Input("grad: T") .Output("output: T") .Attr("T: {float, double}") + .SetShapeFn(OpShapeInferenceFn(BatchUnchangedSquareShapeFn)) .Doc(R"doc( Calculates the reverse mode backpropagated gradient of the Cholesky algorithm. @@ -163,6 +290,16 @@ REGISTER_OP("SelfAdjointEig") .Input("input: T") .Output("output: T") .Attr("T: {double, float}") + .SetShapeFn(OpShapeInferenceFn([](InferenceContext* c) { + const Shape* input; + TF_RETURN_IF_ERROR(MakeSquareMatrix(c, c->input(0), &input)); + + const Dimension* d = c->Dim(input, 0); + const Dimension* d_plus_1; + TF_RETURN_IF_ERROR(c->Add(d, 1, &d_plus_1)); + c->set_output(0, c->Matrix(d_plus_1, d)); + return Status::OK(); + })) .Doc(R"doc( Calculates the Eigen Decomposition of a square Self-Adjoint matrix. @@ -180,6 +317,20 @@ REGISTER_OP("BatchSelfAdjointEig") .Input("input: T") .Output("output: T") .Attr("T: {double, float}") + .SetShapeFn(OpShapeInferenceFn([](InferenceContext* c) { + const Shape* input; + TF_RETURN_IF_ERROR(MakeBatchSquareMatrix(c, c->input(0), &input)); + + const Dimension* d = c->Dim(input, -1); + const Dimension* d_plus_1; + TF_RETURN_IF_ERROR(c->Add(d, 1, &d_plus_1)); + + const Shape* s; + TF_RETURN_IF_ERROR(c->Subshape(input, 0, -2, &s)); + TF_RETURN_IF_ERROR(c->Concatenate(s, c->Matrix(d_plus_1, d), &s)); + c->set_output(0, s); + return Status::OK(); + })) .Doc(R"doc( Calculates the Eigen Decomposition of a batch of square self-adjoint matrices. @@ -200,6 +351,7 @@ REGISTER_OP("MatrixSolve") .Output("output: T") .Attr("adjoint: bool = False") .Attr("T: {double, float}") + .SetShapeFn(OpShapeInferenceFn(SquareMatrixSolveShapeFn)) .Doc(R"doc( Solves a system of linear equations. Checks for invertibility. @@ -217,6 +369,9 @@ REGISTER_OP("BatchMatrixSolve") .Output("output: T") .Attr("adjoint: bool = False") .Attr("T: {double, float}") + .SetShapeFn(OpShapeInferenceFn([](InferenceContext* c) { + return BatchMatrixSolveShapeFn(c, true /* square (*/); + })) .Doc(R"doc( Solves systems of linear equations. Checks for invertibility. @@ -241,6 +396,7 @@ REGISTER_OP("MatrixTriangularSolve") .Attr("lower: bool = True") .Attr("adjoint: bool = False") .Attr("T: {double, float}") + .SetShapeFn(OpShapeInferenceFn(SquareMatrixSolveShapeFn)) .Doc(R"doc( Solves a system of linear equations with an upper or lower triangular matrix by backsubstitution. @@ -272,6 +428,9 @@ REGISTER_OP("BatchMatrixTriangularSolve") .Attr("lower: bool = True") .Attr("adjoint: bool = False") .Attr("T: {double, float}") + .SetShapeFn(OpShapeInferenceFn([](InferenceContext* c) { + return BatchMatrixSolveShapeFn(c, true /* square (*/); + })) .Doc(R"doc( Solves systems of linear equations with upper or lower triangular matrices by backsubstitution. @@ -306,6 +465,19 @@ REGISTER_OP("MatrixSolveLs") .Output("output: T") .Attr("T: {double, float}") .Attr("fast: bool = True") + .SetShapeFn(OpShapeInferenceFn([](InferenceContext* c) { + const Shape* lhs; + const Shape* rhs; + TF_RETURN_IF_ERROR(c->WithRank(c->input(0), 2, &lhs)); + TF_RETURN_IF_ERROR(c->WithRank(c->input(1), 2, &rhs)); + + // The matrix and right-hand side must have the same number of rows. + const Dimension* unused; + TF_RETURN_IF_ERROR(c->Merge(c->Dim(lhs, 0), c->Dim(rhs, 0), &unused)); + + c->set_output(0, c->Matrix(c->Dim(lhs, 1), c->Dim(rhs, 1))); + return Status::OK(); + })) .Doc(R"doc( Solves a linear least-squares problem. @@ -349,6 +521,9 @@ REGISTER_OP("BatchMatrixSolveLs") .Output("output: T") .Attr("T: {double, float}") .Attr("fast: bool = True") + .SetShapeFn(OpShapeInferenceFn([](InferenceContext* c) { + return BatchMatrixSolveShapeFn(c, false /* square */); + })) .Doc(R"doc( Solves multiple linear least-squares problems. diff --git a/tensorflow/core/ops/linalg_ops_test.cc b/tensorflow/core/ops/linalg_ops_test.cc new file mode 100644 index 0000000000..7c88fc97b0 --- /dev/null +++ b/tensorflow/core/ops/linalg_ops_test.cc @@ -0,0 +1,189 @@ +/* Copyright 2016 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); + +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + +#include "tensorflow/core/framework/op.h" +#include "tensorflow/core/framework/shape_inference_testutil.h" +#include "tensorflow/core/platform/test.h" + +namespace tensorflow { + +TEST(LinalgOpsTest, MatrixDeterminant_ShapeFn) { + ShapeInferenceTestOp op("MatrixDeterminant"); + INFER_OK(op, "?", "[]"); + INFER_OK(op, "[?,?]", "[]"); + INFER_OK(op, "[1,?]", "[]"); + INFER_OK(op, "[?,1]", "[]"); + INFER_ERROR("Shape must be rank 2 but is rank 1", op, "[1]"); + INFER_ERROR("Dimensions must be equal, but are 1 and 2", op, "[1,2]"); +} + +TEST(LinalgOpsTest, BatchMatrixDeterminant_ShapeFn) { + ShapeInferenceTestOp op("BatchMatrixDeterminant"); + INFER_OK(op, "?", "?"); + INFER_ERROR("Shape must be at least rank 2 but is rank 1", op, "[1]"); + INFER_ERROR("Dimensions must be equal, but are 2 and 1", op, "[1,?,3,4,1,2]"); + + INFER_OK(op, "[?,?]", "[]"); + INFER_OK(op, "[1,?]", "[]"); + INFER_OK(op, "[?,1]", "[]"); + + // Repeat previous block of tests with input rank > 2. + INFER_OK(op, "[1,?,3,4,?,?]", "[d0_0,d0_1,d0_2,d0_3]"); + INFER_OK(op, "[1,?,3,4,1,?]", "[d0_0,d0_1,d0_2,d0_3]"); + INFER_OK(op, "[1,?,3,4,?,1]", "[d0_0,d0_1,d0_2,d0_3]"); +} + +TEST(LinalgOpsTest, UnchangedSquare_ShapeFn) { + for (const char* op_name : {"Cholesky", "CholeskyGrad", "MatrixInverse"}) { + ShapeInferenceTestOp op(op_name); + INFER_OK(op, "?", "[?,?]"); + INFER_OK(op, "[?,?]", "[d0_0|d0_1,d0_0|d0_1]"); + INFER_OK(op, "[1,?]", "[d0_0,d0_0]"); + INFER_OK(op, "[?,1]", "[d0_1,d0_1]"); + INFER_ERROR("Shape must be rank 2 but is rank 1", op, "[1]"); + INFER_ERROR("Dimensions must be equal, but are 1 and 2", op, "[1,2]"); + } +} + +TEST(LinalgOpsTest, BatchUnchangedSquare_ShapeFn) { + for (const char* op_name : + {"BatchCholesky", "BatchCholeskyGrad", "BatchMatrixInverse"}) { + ShapeInferenceTestOp op(op_name); + INFER_OK(op, "?", "?"); + INFER_ERROR("Shape must be at least rank 2 but is rank 1", op, "[1]"); + INFER_ERROR("Dimensions must be equal, but are 1 and 2", op, "[1,2]"); + + INFER_OK(op, "[?,?]", "[d0_0|d0_1,d0_0|d0_1]"); + INFER_OK(op, "[1,?]", "[d0_0,d0_0]"); + INFER_OK(op, "[?,1]", "[d0_1,d0_1]"); + + // Repeat previous block of tests with input rank > 2. + INFER_OK(op, "[5,?,7,?,?]", "[d0_0,d0_1,d0_2,d0_3|d0_4,d0_3|d0_4]"); + INFER_OK(op, "[5,?,7,1,?]", "[d0_0,d0_1,d0_2,d0_3,d0_3]"); + INFER_OK(op, "[5,?,7,?,1]", "[d0_0,d0_1,d0_2,d0_4,d0_4]"); + } +} + +TEST(LinalgOpsTest, SelfAdjointEig_ShapeFn) { + ShapeInferenceTestOp op("SelfAdjointEig"); + INFER_OK(op, "?", "[?,?]"); + INFER_OK(op, "[?,?]", "[?,d0_0|d0_1]"); + INFER_OK(op, "[1,?]", "[2,d0_0]"); + INFER_OK(op, "[?,1]", "[2,d0_1]"); + INFER_ERROR("Shape must be rank 2 but is rank 1", op, "[1]"); + INFER_ERROR("Dimensions must be equal, but are 1 and 2", op, "[1,2]"); +} + +TEST(LinalgOpsTest, BatchSelfAdjointEig_ShapeFn) { + ShapeInferenceTestOp op("BatchSelfAdjointEig"); + INFER_OK(op, "?", "?"); + INFER_ERROR("Shape must be at least rank 2 but is rank 1", op, "[1]"); + INFER_ERROR("Dimensions must be equal, but are 1 and 2", op, "[1,2]"); + + INFER_OK(op, "[?,?]", "[?,d0_0|d0_1]"); + INFER_OK(op, "[1,?]", "[2,d0_0]"); + INFER_OK(op, "[?,1]", "[2,d0_1]"); + + // Repeat previous block of tests with input rank > 2. + INFER_OK(op, "[5,?,7,?,?]", "[d0_0,d0_1,d0_2,?,d0_3|d0_4]"); + INFER_OK(op, "[5,?,7,1,?]", "[d0_0,d0_1,d0_2,2,d0_3]"); + INFER_OK(op, "[5,?,7,?,1]", "[d0_0,d0_1,d0_2,2,d0_4]"); +} + +TEST(LinalgOpsTest, SquareMatrixSolve_ShapeFn) { + for (const char* op_name : {"MatrixSolve", "MatrixTriangularSolve"}) { + ShapeInferenceTestOp op(op_name); + INFER_OK(op, "?;?", "[?,?]"); + INFER_OK(op, "[?,?];?", "[d0_0,?]"); + + // Inputs are [M,M] and [M,K]. Output is [M,K]. + INFER_OK(op, "[?,?];[1,?]", "[d1_0,d1_1]"); + INFER_OK(op, "[1,?];[1,?]", "[d0_0|d1_0,d1_1]"); + INFER_OK(op, "[?,1];[1,?]", "[d0_1|d1_0,d1_1]"); + INFER_OK(op, "[1,1];[?,?]", "[d0_0|d0_1,d1_1]"); + INFER_OK(op, "[1,1];[1,?]", "[d0_0|d0_1|d1_0,d1_1]"); + INFER_ERROR("Shape must be rank 2 but is rank 1", op, "[1];?"); + INFER_ERROR("Shape must be rank 2 but is rank 1", op, "?;[1]"); + INFER_ERROR("Dimensions must be equal, but are 1 and 2", op, "[1,2];?"); + } +} + +TEST(LinalgOpsTest, BatchSquareMatrixSolve_ShapeFn) { + for (const char* op_name : + {"BatchMatrixSolve", "BatchMatrixTriangularSolve"}) { + ShapeInferenceTestOp op(op_name); + INFER_OK(op, "?;?", "?"); + INFER_ERROR("Shape must be at least rank 2 but is rank 1", op, "[1];?"); + INFER_ERROR("Shape must be at least rank 2 but is rank 1", op, "?;[1]"); + INFER_ERROR("Dimensions must be equal, but are 1 and 2", op, "[1,2];?"); + + INFER_OK(op, "[?,?];?", "[d0_0|d0_1,?]"); + + // Inputs are [...,M,M] and [...,M,K]. Output is [...,M,K]. + // First test where ... is empty. + INFER_OK(op, "[?,?];[1,?]", "[d1_0,d1_1]"); + INFER_OK(op, "[1,?];[1,?]", "[d0_0|d1_0,d1_1]"); + INFER_OK(op, "[?,1];[1,?]", "[d0_1|d1_0,d1_1]"); + INFER_OK(op, "[1,1];[?,?]", "[d0_0,d1_1]"); + INFER_OK(op, "[1,1];[1,?]", "[d0_0|d0_1|d1_0,d1_1]"); + + // Test with ... being 2-d. + INFER_OK(op, "[10,?,?,?];[?,20,1,?]", "[d0_0,d1_1,d1_2,d1_3]"); + INFER_OK(op, "[10,?,1,?];[?,20,1,?]", "[d0_0,d1_1,d0_2|d1_2,d1_3]"); + INFER_OK(op, "[10,?,?,1];[?,20,1,?]", "[d0_0,d1_1,d0_3|d1_2,d1_3]"); + INFER_OK(op, "[10,?,1,1];[?,20,?,?]", "[d0_0,d1_1,d0_2,d1_3]"); + INFER_OK(op, "[10,?,1,1];[?,20,1,?]", "[d0_0,d1_1,d0_2|d0_3|d1_2,d1_3]"); + } +} + +TEST(LinalgOpsTest, MatrixSolveLs_ShapeFn) { + ShapeInferenceTestOp op("MatrixSolveLs"); + INFER_OK(op, "?;?", "[?,?]"); + + // Inputs are [M,N] and [M,K]. Output is [N,K] + INFER_OK(op, "[1,?];[1,?]", "[d0_1,d1_1]"); + INFER_OK(op, "[1,2];[1,3]", "[d0_1,d1_1]"); + + // First dims must be compatible. + INFER_ERROR("Dimensions must be equal, but are 5 and 6", op, "[5,?];[6,?]"); + + // Rank checks. + INFER_ERROR("Shape must be rank 2 but is rank 1", op, "[1];?"); + INFER_ERROR("Shape must be rank 2 but is rank 1", op, "?;[1]"); +} + +TEST(LinalgOpsTest, BatchMatrixSolveLs_ShapeFn) { + ShapeInferenceTestOp op("BatchMatrixSolveLs"); + INFER_OK(op, "?;?", "?"); + + // Inputs are [...,M,N] and [...,M,K]. Output is [...,N,K] + + // Test with no batch dims. + INFER_OK(op, "[1,?];[1,?]", "[d0_1,d1_1]"); + INFER_OK(op, "[1,2];[1,3]", "[d0_1,d1_1]"); + INFER_ERROR("Dimensions must be equal, but are 5 and 6", op, "[5,?];[6,?]"); + + // Test with batch dims. + INFER_OK(op, "[10,?,1,?];[?,20,1,?]", "[d0_0,d1_1,d0_3,d1_3]"); + INFER_OK(op, "[10,20,1,2];[10,20,1,3]", "[d0_0|d1_0,d0_1|d1_1,d0_3,d1_3]"); + INFER_ERROR("Dimensions must be equal, but are 5 and 6", op, + "[10,?,5,?];[?,20,6,?]"); + + // Rank checks. + INFER_ERROR("Shape must be at least rank 2 but is rank 1", op, "[1];?"); + INFER_ERROR("Shape must be at least rank 2 but is rank 1", op, "?;[1]"); +} + +} // end namespace tensorflow diff --git a/tensorflow/python/ops/linalg_ops.py b/tensorflow/python/ops/linalg_ops.py index 861c04dd63..0e76f772ca 100644 --- a/tensorflow/python/ops/linalg_ops.py +++ b/tensorflow/python/ops/linalg_ops.py @@ -93,7 +93,7 @@ def _BatchSelfAdjointEigShape(op): @ops.RegisterShape("MatrixTriangularSolve") def _SquareMatrixSolveShape(op): lhs_shape = op.inputs[0].get_shape().with_rank(2) - rhs_shape = op.inputs[1].get_shape().with_rank_at_least(2) + rhs_shape = op.inputs[1].get_shape().with_rank(2) # The matrix must be square. lhs_shape[0].assert_is_compatible_with(lhs_shape[1]) # The matrix and right-hand side must have the same number of rows. @@ -117,7 +117,7 @@ def _BatchSquareMatrixSolveShape(op): @ops.RegisterShape("MatrixSolveLs") def _MatrixSolveLsShape(op): lhs_shape = op.inputs[0].get_shape().with_rank(2) - rhs_shape = op.inputs[1].get_shape().with_rank_at_least(2) + rhs_shape = op.inputs[1].get_shape().with_rank(2) # The matrix and right-hand side must have the same number of rows. lhs_shape[0].assert_is_compatible_with(rhs_shape[0]) return [[lhs_shape[1], rhs_shape[1]]] -- cgit v1.2.3