aboutsummaryrefslogtreecommitdiffhomepage
path: root/tensorflow/compiler/xla/shape_tree.h
diff options
context:
space:
mode:
Diffstat (limited to 'tensorflow/compiler/xla/shape_tree.h')
-rw-r--r--tensorflow/compiler/xla/shape_tree.h9
1 files changed, 4 insertions, 5 deletions
diff --git a/tensorflow/compiler/xla/shape_tree.h b/tensorflow/compiler/xla/shape_tree.h
index 52c895e8d4..df610102b4 100644
--- a/tensorflow/compiler/xla/shape_tree.h
+++ b/tensorflow/compiler/xla/shape_tree.h
@@ -224,14 +224,13 @@ class ShapeTree {
// REQUIRES: index must exist in the ShapeTree.
iterator find(ShapeIndexView index) {
Node* element = Lookup(index);
- return iterator(&nodes_, typename std::vector<Node>::iterator(element),
- /*iterate_leaves_only=*/false);
+ auto element_iter = nodes_.begin() + (element - &nodes_[0]);
+ return iterator(&nodes_, element_iter, /*iterate_leaves_only=*/false);
}
const_iterator find(ShapeIndexView index) const {
Node* element = Lookup(index);
- return iterator(&nodes_,
- typename std::vector<Node>::const_iterator(element),
- /*iterate_leaves_only=*/false);
+ auto element_iter = nodes_.cbegin() + (element - &nodes_[0]);
+ return const_iterator(&nodes_, element_iter, /*iterate_leaves_only=*/false);
}
// Returns the number of leaf nodes in the tree.