diff options
author | 2016-07-21 15:07:26 -0800 | |
---|---|---|
committer | 2016-07-21 16:20:18 -0700 | |
commit | b35842b44b4a414287a09d39d7d365993ad381e9 (patch) | |
tree | a498202a5a90ebee174b34546ac847578483390a | |
parent | 2641ef0c440161c8f75683d29d4ec53c52de51ed (diff) |
Add C++ shape inference function for BatchMatMul.
Change: 128117162
-rw-r--r-- | tensorflow/core/ops/math_ops.cc | 34 | ||||
-rw-r--r-- | tensorflow/core/ops/math_ops_test.cc | 39 |
2 files changed, 72 insertions, 1 deletions
diff --git a/tensorflow/core/ops/math_ops.cc b/tensorflow/core/ops/math_ops.cc index 7dcc2d6aad..152d3627cb 100644 --- a/tensorflow/core/ops/math_ops.cc +++ b/tensorflow/core/ops/math_ops.cc @@ -125,6 +125,39 @@ REGISTER_OP("BatchMatMul") .Attr("T: {half, float, double, int32, complex64, complex128}") .Attr("adj_x: bool = false") .Attr("adj_y: bool = false") + .SetShapeFn([](InferenceContext* c) { + const Shape* a_shape; + const Shape* b_shape; + TF_RETURN_IF_ERROR(c->WithRankAtLeast(c->input(0), 3, &a_shape)); + TF_RETURN_IF_ERROR(c->WithRankAtLeast(c->input(1), 3, &b_shape)); + + // Determine output rows and cols. + bool adj_x; + bool adj_y; + TF_RETURN_IF_ERROR(c->GetAttr("adj_x", &adj_x)); + TF_RETURN_IF_ERROR(c->GetAttr("adj_y", &adj_y)); + const Dimension* output_rows = c->Dim(a_shape, adj_x ? -1 : -2); + const Dimension* output_cols = c->Dim(b_shape, adj_y ? -2 : -1); + + // Batch dims match between inputs. + const Shape* a_batch_dims; + const Shape* b_batch_dims; + const Shape* batch_dims; + TF_RETURN_IF_ERROR(c->Subshape(a_shape, 0, -2, &a_batch_dims)); + TF_RETURN_IF_ERROR(c->Subshape(b_shape, 0, -2, &b_batch_dims)); + TF_RETURN_IF_ERROR(c->Merge(a_batch_dims, b_batch_dims, &batch_dims)); + + // Assert inner dims match. + const Dimension* unused; + TF_RETURN_IF_ERROR(c->Merge(c->Dim(a_shape, adj_x ? -2 : -1), + c->Dim(b_shape, adj_y ? -1 : -2), &unused)); + + const Shape* out; + TF_RETURN_IF_ERROR(c->Concatenate( + batch_dims, c->Matrix(output_rows, output_cols), &out)); + c->set_output(0, out); + return Status::OK(); + }) .Doc(R"doc( Multiplies slices of two tensors in batches. @@ -729,7 +762,6 @@ Returns the truth value of x OR y element-wise. // -------------------------------------------------------------------------- -// TODO(cwhipkey): review what the python code here does. REGISTER_OP("Select") .Input("condition: bool") .Input("t: T") diff --git a/tensorflow/core/ops/math_ops_test.cc b/tensorflow/core/ops/math_ops_test.cc index 422b99a2d5..61da4a3ac8 100644 --- a/tensorflow/core/ops/math_ops_test.cc +++ b/tensorflow/core/ops/math_ops_test.cc @@ -310,4 +310,43 @@ TEST(MathOpsTest, UnsortedSegmentSum_ShapeFn) { op, "[3];[3];?"); } +TEST(MathOpsTest, BatchMatMul_ShapeFn) { + ShapeInferenceTestOp op("BatchMatMul"); + auto set_adj = [&op](bool adj_x, bool adj_y) { + TF_CHECK_OK(NodeDefBuilder("test", "BatchMatMul") + .Input({"a", 0, DT_FLOAT}) + .Input({"b", 0, DT_FLOAT}) + .Attr("adj_x", adj_x) + .Attr("adj_y", adj_y) + .Finalize(&op.node_def)); + }; + + set_adj(false, false); + + // Rank checks. + INFER_ERROR("at least rank 3", op, "[1,2];?"); + INFER_ERROR("at least rank 3", op, "?;[1,2]"); + + INFER_OK(op, "?;?", "?"); + + // 2 batch dims. + INFER_OK(op, "[?,?,?,?];?", "[d0_0,d0_1,d0_2,?]"); + + // Test adj_a, testing output and that inner dims are compared. + set_adj(false, false); + INFER_OK(op, "[1,2,3,4];[1,2,?,?]", "[d0_0,d0_1,d0_2,d1_3]"); + INFER_ERROR("are 2 and 3", op, "[?,1,2];[?,3,1]"); // inner dim mismatch + set_adj(true, false); + INFER_OK(op, "[1,2,3,4];[1,2,?,?]", "[d0_0,d0_1,d0_3,d1_3]"); + INFER_ERROR("are 2 and 3", op, "[?,2,1];[?,3,1]"); // inner dim mismatch + + // Test adj_b=true. + set_adj(false, true); + INFER_OK(op, "[1,2,?,?];[1,2,3,4]", "[d0_0,d0_1,d0_2,d1_2]"); + INFER_ERROR("are 2 and 3", op, "[?,1,2];[?,1,3]"); // inner dim mismatch + set_adj(true, true); + INFER_OK(op, "[1,2,?,?];[1,2,3,4]", "[d0_0,d0_1,d0_3,d1_2]"); + INFER_ERROR("are 2 and 3", op, "[?,2,1];[?,1,3]"); // inner dim mismatch +} + } // end namespace tensorflow |