aboutsummaryrefslogtreecommitdiffhomepage
path: root/tensorflow/compiler/xla/shape_util.cc
diff options
context:
space:
mode:
Diffstat (limited to 'tensorflow/compiler/xla/shape_util.cc')
-rw-r--r--tensorflow/compiler/xla/shape_util.cc11
1 files changed, 6 insertions, 5 deletions
diff --git a/tensorflow/compiler/xla/shape_util.cc b/tensorflow/compiler/xla/shape_util.cc
index 34869cc507..b69c346f1e 100644
--- a/tensorflow/compiler/xla/shape_util.cc
+++ b/tensorflow/compiler/xla/shape_util.cc
@@ -1014,12 +1014,13 @@ bool ShapeUtil::IsLeafIndex(const Shape& shape, const ShapeIndex& index) {
}
/* static */ int64 ShapeUtil::GetLeafCount(const Shape& shape) {
+ if (!IsTuple(shape)) {
+ return 1;
+ }
int64 count = 0;
- ForEachSubshape(shape, [&](const Shape&, const ShapeIndex& index) {
- if (IsLeafIndex(shape, index)) {
- ++count;
- }
- });
+ for (const Shape& subshape : shape.tuple_shapes()) {
+ count += GetLeafCount(subshape);
+ }
return count;
}