aboutsummaryrefslogtreecommitdiffhomepage
path: root/tensorflow/compiler/xla/shape_util_test.cc
diff options
context:
space:
mode:
authorGravatar A. Unique TensorFlower <gardener@tensorflow.org>2018-09-12 21:22:34 -0700
committerGravatar TensorFlower Gardener <gardener@tensorflow.org>2018-09-12 21:26:37 -0700
commit725dfe9cd0eef3f4b858eaeda38728813c99a210 (patch)
treeb51c0d7d188693cf11e789e5b5e37a194bbb0198 /tensorflow/compiler/xla/shape_util_test.cc
parentf4d8442e13356ab645446c9f4a9b3b6cedddcd63 (diff)
internal change only.
PiperOrigin-RevId: 212754752
Diffstat (limited to 'tensorflow/compiler/xla/shape_util_test.cc')
-rw-r--r--tensorflow/compiler/xla/shape_util_test.cc16
1 files changed, 16 insertions, 0 deletions
diff --git a/tensorflow/compiler/xla/shape_util_test.cc b/tensorflow/compiler/xla/shape_util_test.cc
index 6ca4085aaf..c622ecdca1 100644
--- a/tensorflow/compiler/xla/shape_util_test.cc
+++ b/tensorflow/compiler/xla/shape_util_test.cc
@@ -445,6 +445,22 @@ TEST(ShapeUtilTest, ElementsIn) {
EXPECT_EQ(221, ShapeUtil::ElementsIn(ShapeUtil::MakeShape(S32, {13, 17})));
}
+TEST(ShapeUtilTest, HasPrimitiveType) {
+ EXPECT_TRUE(ShapeUtil::HasPrimitiveType(ShapeUtil::MakeShape(S32, {}), S32));
+ EXPECT_FALSE(ShapeUtil::HasPrimitiveType(ShapeUtil::MakeShape(S32, {}), S16));
+ EXPECT_TRUE(ShapeUtil::HasPrimitiveType(ShapeUtil::MakeShape(S32, {0}), S32));
+ EXPECT_FALSE(ShapeUtil::HasPrimitiveType(ShapeUtil::MakeTupleShape({}), S32));
+ EXPECT_TRUE(ShapeUtil::HasPrimitiveType(
+ ShapeUtil::MakeTupleShape(
+ {ShapeUtil::MakeShape(S32, {}), ShapeUtil::MakeShape(S32, {})}),
+ S32));
+ EXPECT_TRUE(ShapeUtil::HasPrimitiveType(
+ ShapeUtil::MakeTupleShape(
+ {ShapeUtil::MakeShape(S32, {}),
+ ShapeUtil::MakeTupleShape({ShapeUtil::MakeShape(S16, {})})}),
+ S16));
+}
+
TEST(ShapeUtilTest, IsZeroElementArray) {
EXPECT_FALSE(ShapeUtil::IsZeroElementArray(ShapeUtil::MakeShape(S32, {})));
EXPECT_TRUE(ShapeUtil::IsZeroElementArray(ShapeUtil::MakeShape(S32, {0})));