From 05412bd367198ec491ca034b4bc634784c03125c Mon Sep 17 00:00:00 2001 From: Mark Heffernan Date: Tue, 6 Jun 2017 15:42:43 -0700 Subject: [XLA] Simplify Shape traversal visitors. Simplify shape traversal visitors in ShapeUtil and ShapeTree. Add a non-Status form because most uses of the traversal methods do not use it, and remove is_leaf parameter from ShapeTree.ForEach* as it is not frequently used. PiperOrigin-RevId: 158201574 --- tensorflow/compiler/xla/shape_util_test.cc | 39 +++++++++++++++++++++++------- 1 file changed, 30 insertions(+), 9 deletions(-) (limited to 'tensorflow/compiler/xla/shape_util_test.cc') diff --git a/tensorflow/compiler/xla/shape_util_test.cc b/tensorflow/compiler/xla/shape_util_test.cc index 73538b8b88..8ac2e8345b 100644 --- a/tensorflow/compiler/xla/shape_util_test.cc +++ b/tensorflow/compiler/xla/shape_util_test.cc @@ -322,6 +322,30 @@ TEST(ShapeUtilTest, GetSubshape) { ShapeUtil::GetSubshape(nested_tuple_shape, {2, 0}))); } +TEST(ShapeUtilTest, IsLeafIndex) { + // Test array shape. + Shape array_shape = ShapeUtil::MakeShape(F32, {42, 42, 123}); + EXPECT_TRUE(ShapeUtil::IsLeafIndex(array_shape, {})); + + // Test tuple shape. + Shape tuple_shape = ShapeUtil::MakeTupleShape({array_shape, array_shape}); + EXPECT_FALSE(ShapeUtil::IsLeafIndex(tuple_shape, {})); + EXPECT_TRUE(ShapeUtil::IsLeafIndex(tuple_shape, {0})); + EXPECT_TRUE(ShapeUtil::IsLeafIndex(tuple_shape, {1})); + + // Test nested tuple shape. + Shape nested_tuple_shape = ShapeUtil::MakeTupleShape( + {array_shape, ShapeUtil::MakeTupleShape({array_shape, array_shape}), + ShapeUtil::MakeTupleShape( + {ShapeUtil::MakeTupleShape({array_shape, array_shape}), + array_shape})}); + EXPECT_FALSE(ShapeUtil::IsLeafIndex(nested_tuple_shape, {})); + EXPECT_TRUE(ShapeUtil::IsLeafIndex(nested_tuple_shape, {0})); + EXPECT_FALSE(ShapeUtil::IsLeafIndex(nested_tuple_shape, {1})); + EXPECT_TRUE(ShapeUtil::IsLeafIndex(nested_tuple_shape, {1, 0})); + EXPECT_TRUE(ShapeUtil::IsLeafIndex(nested_tuple_shape, {1, 1})); +} + TEST(ShapeUtilTest, HumanString) { Shape opaque = ShapeUtil::MakeOpaqueShape(); Shape scalar = ShapeUtil::MakeShape(F32, {}); @@ -380,13 +404,12 @@ TEST(ShapeUtilTest, HumanString) { TEST(ShapeUtilTest, ForEachSubshapeArray) { const Shape shape = ShapeUtil::MakeShape(F32, {2, 3}); int calls = 0; - EXPECT_IS_OK(ShapeUtil::ForEachSubshape( + ShapeUtil::ForEachSubshape( shape, [&calls, &shape](const Shape& subshape, const ShapeIndex& index) { EXPECT_EQ(&shape, &subshape); EXPECT_TRUE(index.empty()); ++calls; - return tensorflow::Status::OK(); - })); + }); EXPECT_EQ(1, calls); } @@ -396,7 +419,7 @@ TEST(ShapeUtilTest, ForEachSubshapeNestedTuple) { ShapeUtil::MakeTupleShape({ShapeUtil::MakeShape(F32, {101}), ShapeUtil::MakeShape(PRED, {33})})}); int calls = 0; - EXPECT_IS_OK(ShapeUtil::ForEachSubshape( + ShapeUtil::ForEachSubshape( shape, [&calls, &shape](const Shape& subshape, const ShapeIndex& index) { EXPECT_TRUE( ShapeUtil::Equal(subshape, ShapeUtil::GetSubshape(shape, index))); @@ -408,8 +431,7 @@ TEST(ShapeUtilTest, ForEachSubshapeNestedTuple) { EXPECT_EQ(33, ShapeUtil::ElementsIn(subshape)); } ++calls; - return tensorflow::Status::OK(); - })); + }); EXPECT_EQ(5, calls); } @@ -419,7 +441,7 @@ TEST(ShapeUtilTest, ForEachMutableSubshapeNestedTuple) { ShapeUtil::MakeTupleShape({ShapeUtil::MakeShape(F32, {101}), ShapeUtil::MakeShape(PRED, {33})})}); int calls = 0; - EXPECT_IS_OK(ShapeUtil::ForEachMutableSubshape( + ShapeUtil::ForEachMutableSubshape( &shape, [&calls, &shape](const Shape* subshape, const ShapeIndex& index) { // Pointer values should be equal EXPECT_EQ(subshape, ShapeUtil::GetMutableSubshape(&shape, index)); @@ -431,8 +453,7 @@ TEST(ShapeUtilTest, ForEachMutableSubshapeNestedTuple) { EXPECT_EQ(33, ShapeUtil::ElementsIn(*subshape)); } ++calls; - return tensorflow::Status::OK(); - })); + }); EXPECT_EQ(5, calls); } -- cgit v1.2.3