diff options
Diffstat (limited to 'tensorflow/compiler/xla/shape_util_test.cc')
-rw-r--r-- | tensorflow/compiler/xla/shape_util_test.cc | 10 |
1 files changed, 10 insertions, 0 deletions
diff --git a/tensorflow/compiler/xla/shape_util_test.cc b/tensorflow/compiler/xla/shape_util_test.cc index 13582a2a26..f7675e97da 100644 --- a/tensorflow/compiler/xla/shape_util_test.cc +++ b/tensorflow/compiler/xla/shape_util_test.cc @@ -713,6 +713,16 @@ TEST(ShapeUtilTest, ReshapeIsBitcast_3x2x2_6x2_Dim1IsMostMinor) { ShapeUtil::MakeShapeWithLayout(F32, {6, 2}, {0, 1}))); } +TEST(ShapeUtilTest, StripDegenerateDimensions) { + EXPECT_TRUE(ShapeUtil::Equal(ShapeUtil::StripDegenerateDimensions( + ShapeUtil::MakeShape(F32, {3, 1, 2})), + ShapeUtil::MakeShape(F32, {3, 2}))); + EXPECT_TRUE(ShapeUtil::Equal( + ShapeUtil::StripDegenerateDimensions( + ShapeUtil::MakeShapeWithSparseLayout(F32, {3, 1, 2}, 10)), + ShapeUtil::MakeShapeWithSparseLayout(F32, {3, 2}, 10))); +} + TEST(AlgebraicSimplifierTest, ReshapeIsBitcast_3x2x2_6x2_Dim0IsMostMinor) { EXPECT_FALSE(ShapeUtil::ReshapeIsBitcast( ShapeUtil::MakeShapeWithLayout(F32, {3, 2, 2}, {0, 1, 2}), |