aboutsummaryrefslogtreecommitdiffhomepage
diff options
context:
space:
mode:
authorGravatar Michael Kuperstein <mkuper@google.com>2018-07-17 14:24:43 -0700
committerGravatar TensorFlower Gardener <gardener@tensorflow.org>2018-07-17 14:28:24 -0700
commit8c5d2127182e0fadc0dcd6e97cb4acfba3a4c343 (patch)
treedfcdb842e4871717c64f15b05e749acbb8a64ac7
parent2f93ac4891f81137ce5fc40a8bbb2714b6cf2151 (diff)
[XLA] Shape inference should verify the shapes of sort keys and sort values match.
PiperOrigin-RevId: 204974328
-rw-r--r--tensorflow/compiler/xla/service/shape_inference.cc8
-rw-r--r--tensorflow/compiler/xla/service/shape_inference_test.cc12
2 files changed, 20 insertions, 0 deletions
diff --git a/tensorflow/compiler/xla/service/shape_inference.cc b/tensorflow/compiler/xla/service/shape_inference.cc
index 214146cf68..35df792b07 100644
--- a/tensorflow/compiler/xla/service/shape_inference.cc
+++ b/tensorflow/compiler/xla/service/shape_inference.cc
@@ -970,6 +970,14 @@ ShapeInference::InferDegenerateDimensionBroadcastShape(HloOpcode operation,
if (operand_shapes.size() == 1) {
return *operand_shapes[0];
} else if (operand_shapes.size() == 2) {
+ if (!ShapeUtil::SameDimensions(*operand_shapes[0],
+ *operand_shapes[1])) {
+ return InvalidArgument(
+ "Sort keys and values dimensions must match. "
+ "Keys shape is: %s\n, Values shape is: %s",
+ ShapeUtil::HumanString(*operand_shapes[0]).c_str(),
+ ShapeUtil::HumanString(*operand_shapes[1]).c_str());
+ }
return ShapeUtil::MakeTupleShape(
{*operand_shapes[0], *operand_shapes[1]});
}
diff --git a/tensorflow/compiler/xla/service/shape_inference_test.cc b/tensorflow/compiler/xla/service/shape_inference_test.cc
index 9b1ce143c6..6046d50c6d 100644
--- a/tensorflow/compiler/xla/service/shape_inference_test.cc
+++ b/tensorflow/compiler/xla/service/shape_inference_test.cc
@@ -1524,6 +1524,18 @@ TEST_F(ShapeInferenceTest, BadSlice) {
<< statusor.status();
}
+TEST_F(ShapeInferenceTest, BadSort) {
+ auto keys = ShapeUtil::MakeShape(F32, {4});
+ auto values = ShapeUtil::MakeShape(F32, {5});
+ StatusOr<Shape> statusor =
+ ShapeInference::InferVariadicOpShape(HloOpcode::kSort, {&keys, &values});
+ ASSERT_FALSE(statusor.ok());
+
+ EXPECT_THAT(statusor.status().error_message(),
+ HasSubstr("dimensions must match"))
+ << statusor.status();
+}
+
class GatherShapeInferenceTest : public ShapeInferenceTest {
protected:
const Shape s64_scalar_ = ShapeUtil::MakeShape(S64, {});