aboutsummaryrefslogtreecommitdiffhomepage
path: root/tensorflow/compiler/xla/service/shape_inference_test.cc
diff options
context:
space:
mode:
authorGravatar Michael Kuperstein <mkuper@google.com>2018-07-17 14:24:43 -0700
committerGravatar TensorFlower Gardener <gardener@tensorflow.org>2018-07-17 14:28:24 -0700
commit8c5d2127182e0fadc0dcd6e97cb4acfba3a4c343 (patch)
treedfcdb842e4871717c64f15b05e749acbb8a64ac7 /tensorflow/compiler/xla/service/shape_inference_test.cc
parent2f93ac4891f81137ce5fc40a8bbb2714b6cf2151 (diff)
[XLA] Shape inference should verify the shapes of sort keys and sort values match.
PiperOrigin-RevId: 204974328
Diffstat (limited to 'tensorflow/compiler/xla/service/shape_inference_test.cc')
-rw-r--r--tensorflow/compiler/xla/service/shape_inference_test.cc12
1 files changed, 12 insertions, 0 deletions
diff --git a/tensorflow/compiler/xla/service/shape_inference_test.cc b/tensorflow/compiler/xla/service/shape_inference_test.cc
index 9b1ce143c6..6046d50c6d 100644
--- a/tensorflow/compiler/xla/service/shape_inference_test.cc
+++ b/tensorflow/compiler/xla/service/shape_inference_test.cc
@@ -1524,6 +1524,18 @@ TEST_F(ShapeInferenceTest, BadSlice) {
<< statusor.status();
}
+TEST_F(ShapeInferenceTest, BadSort) {
+ auto keys = ShapeUtil::MakeShape(F32, {4});
+ auto values = ShapeUtil::MakeShape(F32, {5});
+ StatusOr<Shape> statusor =
+ ShapeInference::InferVariadicOpShape(HloOpcode::kSort, {&keys, &values});
+ ASSERT_FALSE(statusor.ok());
+
+ EXPECT_THAT(statusor.status().error_message(),
+ HasSubstr("dimensions must match"))
+ << statusor.status();
+}
+
class GatherShapeInferenceTest : public ShapeInferenceTest {
protected:
const Shape s64_scalar_ = ShapeUtil::MakeShape(S64, {});