diff options
author | A. Unique TensorFlower <gardener@tensorflow.org> | 2018-04-27 18:24:57 -0700 |
---|---|---|
committer | TensorFlower Gardener <gardener@tensorflow.org> | 2018-04-27 18:27:28 -0700 |
commit | b2b8dca5833344a0dfe4233ad57c907f3c553f0d (patch) | |
tree | b97799f7c80606be41918d0c7e9a6422c322e4e5 /tensorflow/compiler | |
parent | 864e0566bd0da15b5f93bcb1873c1e19b90f83cc (diff) |
[XLA] Fix bug in ShapeUtil::StripDegenerateDimensions
PiperOrigin-RevId: 194621163
Diffstat (limited to 'tensorflow/compiler')
-rw-r--r-- | tensorflow/compiler/xla/shape_util.cc | 15 | ||||
-rw-r--r-- | tensorflow/compiler/xla/shape_util_test.cc | 10 |
2 files changed, 21 insertions, 4 deletions
diff --git a/tensorflow/compiler/xla/shape_util.cc b/tensorflow/compiler/xla/shape_util.cc index ac7e201bfd..d58baa3220 100644 --- a/tensorflow/compiler/xla/shape_util.cc +++ b/tensorflow/compiler/xla/shape_util.cc @@ -905,10 +905,17 @@ bool ShapeUtil::IsLeafIndex(const Shape& shape, const ShapeIndex& index) { std::is_permutation(minor_to_major.begin(), minor_to_major.end(), dims.begin())); } - Shape stripped_shape = - shape.has_layout() ? MakeShapeWithLayout(shape.element_type(), - dimension_sizes, minor_to_major) - : MakeShape(shape.element_type(), dimension_sizes); + Shape stripped_shape; + if (LayoutUtil::IsDenseArray(shape)) { + stripped_shape = MakeShapeWithLayout(shape.element_type(), dimension_sizes, + minor_to_major); + } else if (LayoutUtil::IsSparseArray(shape)) { + stripped_shape = + MakeShapeWithSparseLayout(shape.element_type(), dimension_sizes, + shape.layout().max_sparse_elements()); + } else { + stripped_shape = MakeShape(shape.element_type(), dimension_sizes); + } VLOG(10) << "Original_shape: " << HumanStringWithLayout(shape); VLOG(10) << "Stripped_shape: " << HumanStringWithLayout(stripped_shape); diff --git a/tensorflow/compiler/xla/shape_util_test.cc b/tensorflow/compiler/xla/shape_util_test.cc index 13582a2a26..f7675e97da 100644 --- a/tensorflow/compiler/xla/shape_util_test.cc +++ b/tensorflow/compiler/xla/shape_util_test.cc @@ -713,6 +713,16 @@ TEST(ShapeUtilTest, ReshapeIsBitcast_3x2x2_6x2_Dim1IsMostMinor) { ShapeUtil::MakeShapeWithLayout(F32, {6, 2}, {0, 1}))); } +TEST(ShapeUtilTest, StripDegenerateDimensions) { + EXPECT_TRUE(ShapeUtil::Equal(ShapeUtil::StripDegenerateDimensions( + ShapeUtil::MakeShape(F32, {3, 1, 2})), + ShapeUtil::MakeShape(F32, {3, 2}))); + EXPECT_TRUE(ShapeUtil::Equal( + ShapeUtil::StripDegenerateDimensions( + ShapeUtil::MakeShapeWithSparseLayout(F32, {3, 1, 2}, 10)), + ShapeUtil::MakeShapeWithSparseLayout(F32, {3, 2}, 10))); +} + TEST(AlgebraicSimplifierTest, ReshapeIsBitcast_3x2x2_6x2_Dim0IsMostMinor) { EXPECT_FALSE(ShapeUtil::ReshapeIsBitcast( ShapeUtil::MakeShapeWithLayout(F32, {3, 2, 2}, {0, 1, 2}), |