aboutsummaryrefslogtreecommitdiffhomepage
path: root/tensorflow/compiler/xla/service/shape_inference_test.cc
diff options
context:
space:
mode:
Diffstat (limited to 'tensorflow/compiler/xla/service/shape_inference_test.cc')
-rw-r--r--tensorflow/compiler/xla/service/shape_inference_test.cc26
1 files changed, 25 insertions, 1 deletions
diff --git a/tensorflow/compiler/xla/service/shape_inference_test.cc b/tensorflow/compiler/xla/service/shape_inference_test.cc
index 864ed43118..7b65e8c1c9 100644
--- a/tensorflow/compiler/xla/service/shape_inference_test.cc
+++ b/tensorflow/compiler/xla/service/shape_inference_test.cc
@@ -1618,13 +1618,37 @@ TEST_F(ShapeInferenceTest, BadSort) {
auto values = ShapeUtil::MakeShape(F32, {5});
StatusOr<Shape> statusor =
ShapeInference::InferVariadicOpShape(HloOpcode::kSort, {&keys, &values});
- ASSERT_FALSE(statusor.ok());
+ EXPECT_FALSE(statusor.ok());
+ EXPECT_THAT(statusor.status().error_message(),
+ HasSubstr("dimensions must match"))
+ << statusor.status();
+}
+TEST_F(ShapeInferenceTest, BadSortValuesMismatch) {
+ auto keys = ShapeUtil::MakeShape(F32, {4});
+ auto values_good = ShapeUtil::MakeShape(F32, {4});
+ auto values_bad = ShapeUtil::MakeShape(F32, {5});
+ StatusOr<Shape> statusor = ShapeInference::InferVariadicOpShape(
+ HloOpcode::kSort, {&keys, &values_good, &values_bad});
+ EXPECT_FALSE(statusor.ok());
EXPECT_THAT(statusor.status().error_message(),
HasSubstr("dimensions must match"))
<< statusor.status();
}
+TEST_F(ShapeInferenceTest, SortManyValues) {
+ auto keys = ShapeUtil::MakeShape(F32, {4});
+ auto values_s32 = ShapeUtil::MakeShape(S32, {4});
+ auto values_u32 = ShapeUtil::MakeShape(U32, {4});
+ StatusOr<Shape> statusor = ShapeInference::InferVariadicOpShape(
+ HloOpcode::kSort, {&keys, &values_s32, &values_u32});
+ EXPECT_IS_OK(statusor);
+ Shape inferred_shape = statusor.ValueOrDie();
+ EXPECT_TRUE(ShapeUtil::Compatible(
+ inferred_shape,
+ ShapeUtil::MakeTupleShape({keys, values_s32, values_u32})));
+}
+
class ScatterGatherShapeInferenceTest : public ShapeInferenceTest {
protected:
const Shape s64_scalar_ = ShapeUtil::MakeShape(S64, {});