diff options
Diffstat (limited to 'tensorflow/compiler/xla/service/shape_inference_test.cc')
-rw-r--r-- | tensorflow/compiler/xla/service/shape_inference_test.cc | 26 |
1 files changed, 25 insertions, 1 deletions
diff --git a/tensorflow/compiler/xla/service/shape_inference_test.cc b/tensorflow/compiler/xla/service/shape_inference_test.cc index 864ed43118..7b65e8c1c9 100644 --- a/tensorflow/compiler/xla/service/shape_inference_test.cc +++ b/tensorflow/compiler/xla/service/shape_inference_test.cc @@ -1618,13 +1618,37 @@ TEST_F(ShapeInferenceTest, BadSort) { auto values = ShapeUtil::MakeShape(F32, {5}); StatusOr<Shape> statusor = ShapeInference::InferVariadicOpShape(HloOpcode::kSort, {&keys, &values}); - ASSERT_FALSE(statusor.ok()); + EXPECT_FALSE(statusor.ok()); + EXPECT_THAT(statusor.status().error_message(), + HasSubstr("dimensions must match")) + << statusor.status(); +} +TEST_F(ShapeInferenceTest, BadSortValuesMismatch) { + auto keys = ShapeUtil::MakeShape(F32, {4}); + auto values_good = ShapeUtil::MakeShape(F32, {4}); + auto values_bad = ShapeUtil::MakeShape(F32, {5}); + StatusOr<Shape> statusor = ShapeInference::InferVariadicOpShape( + HloOpcode::kSort, {&keys, &values_good, &values_bad}); + EXPECT_FALSE(statusor.ok()); EXPECT_THAT(statusor.status().error_message(), HasSubstr("dimensions must match")) << statusor.status(); } +TEST_F(ShapeInferenceTest, SortManyValues) { + auto keys = ShapeUtil::MakeShape(F32, {4}); + auto values_s32 = ShapeUtil::MakeShape(S32, {4}); + auto values_u32 = ShapeUtil::MakeShape(U32, {4}); + StatusOr<Shape> statusor = ShapeInference::InferVariadicOpShape( + HloOpcode::kSort, {&keys, &values_s32, &values_u32}); + EXPECT_IS_OK(statusor); + Shape inferred_shape = statusor.ValueOrDie(); + EXPECT_TRUE(ShapeUtil::Compatible( + inferred_shape, + ShapeUtil::MakeTupleShape({keys, values_s32, values_u32}))); +} + class ScatterGatherShapeInferenceTest : public ShapeInferenceTest { protected: const Shape s64_scalar_ = ShapeUtil::MakeShape(S64, {}); |