diff options
author | 2018-07-17 14:24:43 -0700 | |
---|---|---|
committer | 2018-07-17 14:28:24 -0700 | |
commit | 8c5d2127182e0fadc0dcd6e97cb4acfba3a4c343 (patch) | |
tree | dfcdb842e4871717c64f15b05e749acbb8a64ac7 | |
parent | 2f93ac4891f81137ce5fc40a8bbb2714b6cf2151 (diff) |
[XLA] Shape inference should verify the shapes of sort keys and sort values match.
PiperOrigin-RevId: 204974328
-rw-r--r-- | tensorflow/compiler/xla/service/shape_inference.cc | 8 | ||||
-rw-r--r-- | tensorflow/compiler/xla/service/shape_inference_test.cc | 12 |
2 files changed, 20 insertions, 0 deletions
diff --git a/tensorflow/compiler/xla/service/shape_inference.cc b/tensorflow/compiler/xla/service/shape_inference.cc index 214146cf68..35df792b07 100644 --- a/tensorflow/compiler/xla/service/shape_inference.cc +++ b/tensorflow/compiler/xla/service/shape_inference.cc @@ -970,6 +970,14 @@ ShapeInference::InferDegenerateDimensionBroadcastShape(HloOpcode operation, if (operand_shapes.size() == 1) { return *operand_shapes[0]; } else if (operand_shapes.size() == 2) { + if (!ShapeUtil::SameDimensions(*operand_shapes[0], + *operand_shapes[1])) { + return InvalidArgument( + "Sort keys and values dimensions must match. " + "Keys shape is: %s\n, Values shape is: %s", + ShapeUtil::HumanString(*operand_shapes[0]).c_str(), + ShapeUtil::HumanString(*operand_shapes[1]).c_str()); + } return ShapeUtil::MakeTupleShape( {*operand_shapes[0], *operand_shapes[1]}); } diff --git a/tensorflow/compiler/xla/service/shape_inference_test.cc b/tensorflow/compiler/xla/service/shape_inference_test.cc index 9b1ce143c6..6046d50c6d 100644 --- a/tensorflow/compiler/xla/service/shape_inference_test.cc +++ b/tensorflow/compiler/xla/service/shape_inference_test.cc @@ -1524,6 +1524,18 @@ TEST_F(ShapeInferenceTest, BadSlice) { << statusor.status(); } +TEST_F(ShapeInferenceTest, BadSort) { + auto keys = ShapeUtil::MakeShape(F32, {4}); + auto values = ShapeUtil::MakeShape(F32, {5}); + StatusOr<Shape> statusor = + ShapeInference::InferVariadicOpShape(HloOpcode::kSort, {&keys, &values}); + ASSERT_FALSE(statusor.ok()); + + EXPECT_THAT(statusor.status().error_message(), + HasSubstr("dimensions must match")) + << statusor.status(); +} + class GatherShapeInferenceTest : public ShapeInferenceTest { protected: const Shape s64_scalar_ = ShapeUtil::MakeShape(S64, {}); |