aboutsummaryrefslogtreecommitdiffhomepage
path: root/tensorflow/compiler/xla/service/shape_inference_test.cc
diff options
context:
space:
mode:
authorGravatar Michael Kuperstein <mkuper@google.com>2018-10-09 19:41:35 -0700
committerGravatar TensorFlower Gardener <gardener@tensorflow.org>2018-10-09 19:46:05 -0700
commit58fcfc98cd59ae3952399fc55380b8733df08df9 (patch)
tree24d5ac5d6691e73c227f5afa5ef68ba2ecba4ec0 /tensorflow/compiler/xla/service/shape_inference_test.cc
parent93eef55c4d04af24a6c8080f34629db179634f07 (diff)
[XLA] Add documentation and HLO-level support for multi-value sort.
No support in any of the backends, and not yet exposed through XlaBuilder. PiperOrigin-RevId: 216465753
Diffstat (limited to 'tensorflow/compiler/xla/service/shape_inference_test.cc')
-rw-r--r--tensorflow/compiler/xla/service/shape_inference_test.cc26
1 files changed, 25 insertions, 1 deletions
diff --git a/tensorflow/compiler/xla/service/shape_inference_test.cc b/tensorflow/compiler/xla/service/shape_inference_test.cc
index 864ed43118..7b65e8c1c9 100644
--- a/tensorflow/compiler/xla/service/shape_inference_test.cc
+++ b/tensorflow/compiler/xla/service/shape_inference_test.cc
@@ -1618,13 +1618,37 @@ TEST_F(ShapeInferenceTest, BadSort) {
auto values = ShapeUtil::MakeShape(F32, {5});
StatusOr<Shape> statusor =
ShapeInference::InferVariadicOpShape(HloOpcode::kSort, {&keys, &values});
- ASSERT_FALSE(statusor.ok());
+ EXPECT_FALSE(statusor.ok());
+ EXPECT_THAT(statusor.status().error_message(),
+ HasSubstr("dimensions must match"))
+ << statusor.status();
+}
+TEST_F(ShapeInferenceTest, BadSortValuesMismatch) {
+ auto keys = ShapeUtil::MakeShape(F32, {4});
+ auto values_good = ShapeUtil::MakeShape(F32, {4});
+ auto values_bad = ShapeUtil::MakeShape(F32, {5});
+ StatusOr<Shape> statusor = ShapeInference::InferVariadicOpShape(
+ HloOpcode::kSort, {&keys, &values_good, &values_bad});
+ EXPECT_FALSE(statusor.ok());
EXPECT_THAT(statusor.status().error_message(),
HasSubstr("dimensions must match"))
<< statusor.status();
}
+TEST_F(ShapeInferenceTest, SortManyValues) {
+ auto keys = ShapeUtil::MakeShape(F32, {4});
+ auto values_s32 = ShapeUtil::MakeShape(S32, {4});
+ auto values_u32 = ShapeUtil::MakeShape(U32, {4});
+ StatusOr<Shape> statusor = ShapeInference::InferVariadicOpShape(
+ HloOpcode::kSort, {&keys, &values_s32, &values_u32});
+ EXPECT_IS_OK(statusor);
+ Shape inferred_shape = statusor.ValueOrDie();
+ EXPECT_TRUE(ShapeUtil::Compatible(
+ inferred_shape,
+ ShapeUtil::MakeTupleShape({keys, values_s32, values_u32})));
+}
+
class ScatterGatherShapeInferenceTest : public ShapeInferenceTest {
protected:
const Shape s64_scalar_ = ShapeUtil::MakeShape(S64, {});