diff options
author | Michael Kuperstein <mkuper@google.com> | 2018-07-17 14:24:43 -0700 |
---|---|---|
committer | TensorFlower Gardener <gardener@tensorflow.org> | 2018-07-17 14:28:24 -0700 |
commit | 8c5d2127182e0fadc0dcd6e97cb4acfba3a4c343 (patch) | |
tree | dfcdb842e4871717c64f15b05e749acbb8a64ac7 /tensorflow/compiler/xla/service/shape_inference_test.cc | |
parent | 2f93ac4891f81137ce5fc40a8bbb2714b6cf2151 (diff) |
[XLA] Shape inference should verify the shapes of sort keys and sort values match.
PiperOrigin-RevId: 204974328
Diffstat (limited to 'tensorflow/compiler/xla/service/shape_inference_test.cc')
-rw-r--r-- | tensorflow/compiler/xla/service/shape_inference_test.cc | 12 |
1 files changed, 12 insertions, 0 deletions
diff --git a/tensorflow/compiler/xla/service/shape_inference_test.cc b/tensorflow/compiler/xla/service/shape_inference_test.cc index 9b1ce143c6..6046d50c6d 100644 --- a/tensorflow/compiler/xla/service/shape_inference_test.cc +++ b/tensorflow/compiler/xla/service/shape_inference_test.cc @@ -1524,6 +1524,18 @@ TEST_F(ShapeInferenceTest, BadSlice) { << statusor.status(); } +TEST_F(ShapeInferenceTest, BadSort) { + auto keys = ShapeUtil::MakeShape(F32, {4}); + auto values = ShapeUtil::MakeShape(F32, {5}); + StatusOr<Shape> statusor = + ShapeInference::InferVariadicOpShape(HloOpcode::kSort, {&keys, &values}); + ASSERT_FALSE(statusor.ok()); + + EXPECT_THAT(statusor.status().error_message(), + HasSubstr("dimensions must match")) + << statusor.status(); +} + class GatherShapeInferenceTest : public ShapeInferenceTest { protected: const Shape s64_scalar_ = ShapeUtil::MakeShape(S64, {}); |