aboutsummaryrefslogtreecommitdiffhomepage
path: root/tensorflow/compiler/xla/shape_util_test.cc
diff options
context:
space:
mode:
authorGravatar Mark Heffernan <meheff@google.com>2017-06-06 15:42:43 -0700
committerGravatar TensorFlower Gardener <gardener@tensorflow.org>2017-06-06 15:46:30 -0700
commit05412bd367198ec491ca034b4bc634784c03125c (patch)
tree6b1e76ec79446337d55055dcc3ca6503bd7b345a /tensorflow/compiler/xla/shape_util_test.cc
parent69c9365b4b71b9ab9663ee4f2a0fb226ce2fd26d (diff)
[XLA] Simplify Shape traversal visitors.
Simplify shape traversal visitors in ShapeUtil and ShapeTree. Add a non-Status form because most uses of the traversal methods do not use it, and remove is_leaf parameter from ShapeTree.ForEach* as it is not frequently used. PiperOrigin-RevId: 158201574
Diffstat (limited to 'tensorflow/compiler/xla/shape_util_test.cc')
-rw-r--r--tensorflow/compiler/xla/shape_util_test.cc39
1 files changed, 30 insertions, 9 deletions
diff --git a/tensorflow/compiler/xla/shape_util_test.cc b/tensorflow/compiler/xla/shape_util_test.cc
index 73538b8b88..8ac2e8345b 100644
--- a/tensorflow/compiler/xla/shape_util_test.cc
+++ b/tensorflow/compiler/xla/shape_util_test.cc
@@ -322,6 +322,30 @@ TEST(ShapeUtilTest, GetSubshape) {
ShapeUtil::GetSubshape(nested_tuple_shape, {2, 0})));
}
+TEST(ShapeUtilTest, IsLeafIndex) {
+ // Test array shape.
+ Shape array_shape = ShapeUtil::MakeShape(F32, {42, 42, 123});
+ EXPECT_TRUE(ShapeUtil::IsLeafIndex(array_shape, {}));
+
+ // Test tuple shape.
+ Shape tuple_shape = ShapeUtil::MakeTupleShape({array_shape, array_shape});
+ EXPECT_FALSE(ShapeUtil::IsLeafIndex(tuple_shape, {}));
+ EXPECT_TRUE(ShapeUtil::IsLeafIndex(tuple_shape, {0}));
+ EXPECT_TRUE(ShapeUtil::IsLeafIndex(tuple_shape, {1}));
+
+ // Test nested tuple shape.
+ Shape nested_tuple_shape = ShapeUtil::MakeTupleShape(
+ {array_shape, ShapeUtil::MakeTupleShape({array_shape, array_shape}),
+ ShapeUtil::MakeTupleShape(
+ {ShapeUtil::MakeTupleShape({array_shape, array_shape}),
+ array_shape})});
+ EXPECT_FALSE(ShapeUtil::IsLeafIndex(nested_tuple_shape, {}));
+ EXPECT_TRUE(ShapeUtil::IsLeafIndex(nested_tuple_shape, {0}));
+ EXPECT_FALSE(ShapeUtil::IsLeafIndex(nested_tuple_shape, {1}));
+ EXPECT_TRUE(ShapeUtil::IsLeafIndex(nested_tuple_shape, {1, 0}));
+ EXPECT_TRUE(ShapeUtil::IsLeafIndex(nested_tuple_shape, {1, 1}));
+}
+
TEST(ShapeUtilTest, HumanString) {
Shape opaque = ShapeUtil::MakeOpaqueShape();
Shape scalar = ShapeUtil::MakeShape(F32, {});
@@ -380,13 +404,12 @@ TEST(ShapeUtilTest, HumanString) {
TEST(ShapeUtilTest, ForEachSubshapeArray) {
const Shape shape = ShapeUtil::MakeShape(F32, {2, 3});
int calls = 0;
- EXPECT_IS_OK(ShapeUtil::ForEachSubshape(
+ ShapeUtil::ForEachSubshape(
shape, [&calls, &shape](const Shape& subshape, const ShapeIndex& index) {
EXPECT_EQ(&shape, &subshape);
EXPECT_TRUE(index.empty());
++calls;
- return tensorflow::Status::OK();
- }));
+ });
EXPECT_EQ(1, calls);
}
@@ -396,7 +419,7 @@ TEST(ShapeUtilTest, ForEachSubshapeNestedTuple) {
ShapeUtil::MakeTupleShape({ShapeUtil::MakeShape(F32, {101}),
ShapeUtil::MakeShape(PRED, {33})})});
int calls = 0;
- EXPECT_IS_OK(ShapeUtil::ForEachSubshape(
+ ShapeUtil::ForEachSubshape(
shape, [&calls, &shape](const Shape& subshape, const ShapeIndex& index) {
EXPECT_TRUE(
ShapeUtil::Equal(subshape, ShapeUtil::GetSubshape(shape, index)));
@@ -408,8 +431,7 @@ TEST(ShapeUtilTest, ForEachSubshapeNestedTuple) {
EXPECT_EQ(33, ShapeUtil::ElementsIn(subshape));
}
++calls;
- return tensorflow::Status::OK();
- }));
+ });
EXPECT_EQ(5, calls);
}
@@ -419,7 +441,7 @@ TEST(ShapeUtilTest, ForEachMutableSubshapeNestedTuple) {
ShapeUtil::MakeTupleShape({ShapeUtil::MakeShape(F32, {101}),
ShapeUtil::MakeShape(PRED, {33})})});
int calls = 0;
- EXPECT_IS_OK(ShapeUtil::ForEachMutableSubshape(
+ ShapeUtil::ForEachMutableSubshape(
&shape, [&calls, &shape](const Shape* subshape, const ShapeIndex& index) {
// Pointer values should be equal
EXPECT_EQ(subshape, ShapeUtil::GetMutableSubshape(&shape, index));
@@ -431,8 +453,7 @@ TEST(ShapeUtilTest, ForEachMutableSubshapeNestedTuple) {
EXPECT_EQ(33, ShapeUtil::ElementsIn(*subshape));
}
++calls;
- return tensorflow::Status::OK();
- }));
+ });
EXPECT_EQ(5, calls);
}