diff options
-rw-r--r-- | tensorflow/core/ops/math_ops.cc | 24 | ||||
-rw-r--r-- | tensorflow/core/ops/math_ops_test.cc | 11 |
2 files changed, 30 insertions, 5 deletions
diff --git a/tensorflow/core/ops/math_ops.cc b/tensorflow/core/ops/math_ops.cc index d30b847696..df75caca37 100644 --- a/tensorflow/core/ops/math_ops.cc +++ b/tensorflow/core/ops/math_ops.cc @@ -2331,11 +2331,25 @@ REGISTER_OP("Cross") .Input("b: T") .Output("product: T") .Attr("T: realnumbertype") - // TODO(cwhipkey): implement these shape inference constraints here: - // * Both inputs have the same shape. - // * Input rank >= 1. - // * input_shape[-1] == 3. - .SetShapeFn(shape_inference::UnchangedShape) + .SetShapeFn([](InferenceContext* c) { + ShapeHandle a_shape; + ShapeHandle b_shape; + // * Input rank >= 1. + TF_RETURN_IF_ERROR(c->WithRankAtLeast(c->input(0), 1, &a_shape)); + TF_RETURN_IF_ERROR(c->WithRankAtLeast(c->input(1), 1, &b_shape)); + + // * Both inputs have the same shape. + TF_RETURN_IF_ERROR(c->Merge(a_shape, b_shape, &a_shape)); + + // * input_shape[-1] == 3. + if (c->RankKnown(a_shape)) { + int rank = c->Rank(a_shape); + auto dim = c->Dim(a_shape, rank - 1); + TF_RETURN_IF_ERROR(c->WithValue(dim, 3, &dim)); + } + c->set_output(0, a_shape); + return Status::OK(); + }) .Doc(R"doc( Compute the pairwise cross product. diff --git a/tensorflow/core/ops/math_ops_test.cc b/tensorflow/core/ops/math_ops_test.cc index 28f9969de5..3dfa776d26 100644 --- a/tensorflow/core/ops/math_ops_test.cc +++ b/tensorflow/core/ops/math_ops_test.cc @@ -515,4 +515,15 @@ TEST(MathOpstest, RequantizationRange_ShapeFn) { INFER_ERROR("must be rank 0", op, "?;?;[2]"); } +TEST(MathOpsTest, Cross_ShapeFn) { + ShapeInferenceTestOp op("Cross"); + + INFER_ERROR("Shape must be at least rank 1 but is rank 0", op, "[];[]"); + INFER_ERROR("Dimension 0 in both shapes must be equal, but", op, "[3];[5]"); + INFER_ERROR("Dimension must be 3 but", op, "[3,5];[3,5]"); + + INFER_OK(op, "?;?", "?"); + INFER_OK(op, "[?];[?]", "in0"); + INFER_OK(op, "[1,?,3];[?,?,?]", "in0"); +} } // end namespace tensorflow |