aboutsummaryrefslogtreecommitdiffhomepage
path: root/tensorflow/compiler/xla/shape_util.cc
diff options
context:
space:
mode:
authorGravatar A. Unique TensorFlower <gardener@tensorflow.org>2018-04-27 18:24:57 -0700
committerGravatar TensorFlower Gardener <gardener@tensorflow.org>2018-04-27 18:27:28 -0700
commitb2b8dca5833344a0dfe4233ad57c907f3c553f0d (patch)
treeb97799f7c80606be41918d0c7e9a6422c322e4e5 /tensorflow/compiler/xla/shape_util.cc
parent864e0566bd0da15b5f93bcb1873c1e19b90f83cc (diff)
[XLA] Fix bug in ShapeUtil::StripDegenerateDimensions
PiperOrigin-RevId: 194621163
Diffstat (limited to 'tensorflow/compiler/xla/shape_util.cc')
-rw-r--r--tensorflow/compiler/xla/shape_util.cc15
1 files changed, 11 insertions, 4 deletions
diff --git a/tensorflow/compiler/xla/shape_util.cc b/tensorflow/compiler/xla/shape_util.cc
index ac7e201bfd..d58baa3220 100644
--- a/tensorflow/compiler/xla/shape_util.cc
+++ b/tensorflow/compiler/xla/shape_util.cc
@@ -905,10 +905,17 @@ bool ShapeUtil::IsLeafIndex(const Shape& shape, const ShapeIndex& index) {
std::is_permutation(minor_to_major.begin(), minor_to_major.end(),
dims.begin()));
}
- Shape stripped_shape =
- shape.has_layout() ? MakeShapeWithLayout(shape.element_type(),
- dimension_sizes, minor_to_major)
- : MakeShape(shape.element_type(), dimension_sizes);
+ Shape stripped_shape;
+ if (LayoutUtil::IsDenseArray(shape)) {
+ stripped_shape = MakeShapeWithLayout(shape.element_type(), dimension_sizes,
+ minor_to_major);
+ } else if (LayoutUtil::IsSparseArray(shape)) {
+ stripped_shape =
+ MakeShapeWithSparseLayout(shape.element_type(), dimension_sizes,
+ shape.layout().max_sparse_elements());
+ } else {
+ stripped_shape = MakeShape(shape.element_type(), dimension_sizes);
+ }
VLOG(10) << "Original_shape: " << HumanStringWithLayout(shape);
VLOG(10) << "Stripped_shape: " << HumanStringWithLayout(stripped_shape);