aboutsummaryrefslogtreecommitdiffhomepage
diff options
context:
space:
mode:
-rw-r--r--tensorflow/core/ops/math_ops.cc24
-rw-r--r--tensorflow/core/ops/math_ops_test.cc11
2 files changed, 30 insertions, 5 deletions
diff --git a/tensorflow/core/ops/math_ops.cc b/tensorflow/core/ops/math_ops.cc
index d30b847696..df75caca37 100644
--- a/tensorflow/core/ops/math_ops.cc
+++ b/tensorflow/core/ops/math_ops.cc
@@ -2331,11 +2331,25 @@ REGISTER_OP("Cross")
.Input("b: T")
.Output("product: T")
.Attr("T: realnumbertype")
- // TODO(cwhipkey): implement these shape inference constraints here:
- // * Both inputs have the same shape.
- // * Input rank >= 1.
- // * input_shape[-1] == 3.
- .SetShapeFn(shape_inference::UnchangedShape)
+ .SetShapeFn([](InferenceContext* c) {
+ ShapeHandle a_shape;
+ ShapeHandle b_shape;
+ // * Input rank >= 1.
+ TF_RETURN_IF_ERROR(c->WithRankAtLeast(c->input(0), 1, &a_shape));
+ TF_RETURN_IF_ERROR(c->WithRankAtLeast(c->input(1), 1, &b_shape));
+
+ // * Both inputs have the same shape.
+ TF_RETURN_IF_ERROR(c->Merge(a_shape, b_shape, &a_shape));
+
+ // * input_shape[-1] == 3.
+ if (c->RankKnown(a_shape)) {
+ int rank = c->Rank(a_shape);
+ auto dim = c->Dim(a_shape, rank - 1);
+ TF_RETURN_IF_ERROR(c->WithValue(dim, 3, &dim));
+ }
+ c->set_output(0, a_shape);
+ return Status::OK();
+ })
.Doc(R"doc(
Compute the pairwise cross product.
diff --git a/tensorflow/core/ops/math_ops_test.cc b/tensorflow/core/ops/math_ops_test.cc
index 28f9969de5..3dfa776d26 100644
--- a/tensorflow/core/ops/math_ops_test.cc
+++ b/tensorflow/core/ops/math_ops_test.cc
@@ -515,4 +515,15 @@ TEST(MathOpstest, RequantizationRange_ShapeFn) {
INFER_ERROR("must be rank 0", op, "?;?;[2]");
}
+TEST(MathOpsTest, Cross_ShapeFn) {
+ ShapeInferenceTestOp op("Cross");
+
+ INFER_ERROR("Shape must be at least rank 1 but is rank 0", op, "[];[]");
+ INFER_ERROR("Dimension 0 in both shapes must be equal, but", op, "[3];[5]");
+ INFER_ERROR("Dimension must be 3 but", op, "[3,5];[3,5]");
+
+ INFER_OK(op, "?;?", "?");
+ INFER_OK(op, "[?];[?]", "in0");
+ INFER_OK(op, "[1,?,3];[?,?,?]", "in0");
+}
} // end namespace tensorflow