aboutsummaryrefslogtreecommitdiffhomepage
diff options
context:
space:
mode:
authorGravatar A. Unique TensorFlower <gardener@tensorflow.org>2016-07-21 15:07:26 -0800
committerGravatar TensorFlower Gardener <gardener@tensorflow.org>2016-07-21 16:20:18 -0700
commitb35842b44b4a414287a09d39d7d365993ad381e9 (patch)
treea498202a5a90ebee174b34546ac847578483390a
parent2641ef0c440161c8f75683d29d4ec53c52de51ed (diff)
Add C++ shape inference function for BatchMatMul.
Change: 128117162
-rw-r--r--tensorflow/core/ops/math_ops.cc34
-rw-r--r--tensorflow/core/ops/math_ops_test.cc39
2 files changed, 72 insertions, 1 deletions
diff --git a/tensorflow/core/ops/math_ops.cc b/tensorflow/core/ops/math_ops.cc
index 7dcc2d6aad..152d3627cb 100644
--- a/tensorflow/core/ops/math_ops.cc
+++ b/tensorflow/core/ops/math_ops.cc
@@ -125,6 +125,39 @@ REGISTER_OP("BatchMatMul")
.Attr("T: {half, float, double, int32, complex64, complex128}")
.Attr("adj_x: bool = false")
.Attr("adj_y: bool = false")
+ .SetShapeFn([](InferenceContext* c) {
+ const Shape* a_shape;
+ const Shape* b_shape;
+ TF_RETURN_IF_ERROR(c->WithRankAtLeast(c->input(0), 3, &a_shape));
+ TF_RETURN_IF_ERROR(c->WithRankAtLeast(c->input(1), 3, &b_shape));
+
+ // Determine output rows and cols.
+ bool adj_x;
+ bool adj_y;
+ TF_RETURN_IF_ERROR(c->GetAttr("adj_x", &adj_x));
+ TF_RETURN_IF_ERROR(c->GetAttr("adj_y", &adj_y));
+ const Dimension* output_rows = c->Dim(a_shape, adj_x ? -1 : -2);
+ const Dimension* output_cols = c->Dim(b_shape, adj_y ? -2 : -1);
+
+ // Batch dims match between inputs.
+ const Shape* a_batch_dims;
+ const Shape* b_batch_dims;
+ const Shape* batch_dims;
+ TF_RETURN_IF_ERROR(c->Subshape(a_shape, 0, -2, &a_batch_dims));
+ TF_RETURN_IF_ERROR(c->Subshape(b_shape, 0, -2, &b_batch_dims));
+ TF_RETURN_IF_ERROR(c->Merge(a_batch_dims, b_batch_dims, &batch_dims));
+
+ // Assert inner dims match.
+ const Dimension* unused;
+ TF_RETURN_IF_ERROR(c->Merge(c->Dim(a_shape, adj_x ? -2 : -1),
+ c->Dim(b_shape, adj_y ? -1 : -2), &unused));
+
+ const Shape* out;
+ TF_RETURN_IF_ERROR(c->Concatenate(
+ batch_dims, c->Matrix(output_rows, output_cols), &out));
+ c->set_output(0, out);
+ return Status::OK();
+ })
.Doc(R"doc(
Multiplies slices of two tensors in batches.
@@ -729,7 +762,6 @@ Returns the truth value of x OR y element-wise.
// --------------------------------------------------------------------------
-// TODO(cwhipkey): review what the python code here does.
REGISTER_OP("Select")
.Input("condition: bool")
.Input("t: T")
diff --git a/tensorflow/core/ops/math_ops_test.cc b/tensorflow/core/ops/math_ops_test.cc
index 422b99a2d5..61da4a3ac8 100644
--- a/tensorflow/core/ops/math_ops_test.cc
+++ b/tensorflow/core/ops/math_ops_test.cc
@@ -310,4 +310,43 @@ TEST(MathOpsTest, UnsortedSegmentSum_ShapeFn) {
op, "[3];[3];?");
}
+TEST(MathOpsTest, BatchMatMul_ShapeFn) {
+ ShapeInferenceTestOp op("BatchMatMul");
+ auto set_adj = [&op](bool adj_x, bool adj_y) {
+ TF_CHECK_OK(NodeDefBuilder("test", "BatchMatMul")
+ .Input({"a", 0, DT_FLOAT})
+ .Input({"b", 0, DT_FLOAT})
+ .Attr("adj_x", adj_x)
+ .Attr("adj_y", adj_y)
+ .Finalize(&op.node_def));
+ };
+
+ set_adj(false, false);
+
+ // Rank checks.
+ INFER_ERROR("at least rank 3", op, "[1,2];?");
+ INFER_ERROR("at least rank 3", op, "?;[1,2]");
+
+ INFER_OK(op, "?;?", "?");
+
+ // 2 batch dims.
+ INFER_OK(op, "[?,?,?,?];?", "[d0_0,d0_1,d0_2,?]");
+
+ // Test adj_a, testing output and that inner dims are compared.
+ set_adj(false, false);
+ INFER_OK(op, "[1,2,3,4];[1,2,?,?]", "[d0_0,d0_1,d0_2,d1_3]");
+ INFER_ERROR("are 2 and 3", op, "[?,1,2];[?,3,1]"); // inner dim mismatch
+ set_adj(true, false);
+ INFER_OK(op, "[1,2,3,4];[1,2,?,?]", "[d0_0,d0_1,d0_3,d1_3]");
+ INFER_ERROR("are 2 and 3", op, "[?,2,1];[?,3,1]"); // inner dim mismatch
+
+ // Test adj_b=true.
+ set_adj(false, true);
+ INFER_OK(op, "[1,2,?,?];[1,2,3,4]", "[d0_0,d0_1,d0_2,d1_2]");
+ INFER_ERROR("are 2 and 3", op, "[?,1,2];[?,1,3]"); // inner dim mismatch
+ set_adj(true, true);
+ INFER_OK(op, "[1,2,?,?];[1,2,3,4]", "[d0_0,d0_1,d0_3,d1_2]");
+ INFER_ERROR("are 2 and 3", op, "[?,2,1];[?,1,3]"); // inner dim mismatch
+}
+
} // end namespace tensorflow