diff options
Diffstat (limited to 'tensorflow/compiler/xla/shape_util.cc')
-rw-r--r-- | tensorflow/compiler/xla/shape_util.cc | 11 |
1 files changed, 6 insertions, 5 deletions
diff --git a/tensorflow/compiler/xla/shape_util.cc b/tensorflow/compiler/xla/shape_util.cc index 34869cc507..b69c346f1e 100644 --- a/tensorflow/compiler/xla/shape_util.cc +++ b/tensorflow/compiler/xla/shape_util.cc @@ -1014,12 +1014,13 @@ bool ShapeUtil::IsLeafIndex(const Shape& shape, const ShapeIndex& index) { } /* static */ int64 ShapeUtil::GetLeafCount(const Shape& shape) { + if (!IsTuple(shape)) { + return 1; + } int64 count = 0; - ForEachSubshape(shape, [&](const Shape&, const ShapeIndex& index) { - if (IsLeafIndex(shape, index)) { - ++count; - } - }); + for (const Shape& subshape : shape.tuple_shapes()) { + count += GetLeafCount(subshape); + } return count; } |