diff options
author | 2018-06-21 08:36:58 -0700 | |
---|---|---|
committer | 2018-06-21 08:40:07 -0700 | |
commit | 4ec30cf37a44b64f0d48aa78adc77c09531dd981 (patch) | |
tree | f05e20a3ead86de3acf9d952ec14c600d6deb6d1 /tensorflow/compiler/xla/shape_tree.h | |
parent | 2c4fb3633e618941c2bed6e1672052706b849189 (diff) |
[XLA] Make ShapeTree use ShapeIndexViews
Avoids creating temporary std::vectors on the consumer side. Also push ShapeIndexViews
through the GPU backend a bit.
PiperOrigin-RevId: 201529722
Diffstat (limited to 'tensorflow/compiler/xla/shape_tree.h')
-rw-r--r-- | tensorflow/compiler/xla/shape_tree.h | 22 |
1 files changed, 11 insertions, 11 deletions
diff --git a/tensorflow/compiler/xla/shape_tree.h b/tensorflow/compiler/xla/shape_tree.h index 18e54d23c2..4aacc87b78 100644 --- a/tensorflow/compiler/xla/shape_tree.h +++ b/tensorflow/compiler/xla/shape_tree.h @@ -105,8 +105,8 @@ class ShapeTree { // Returns the data element associated with the array in the shape at the // given index (see ShapeUtil::GetSubshape for how indexes are defined). - const T& element(const ShapeIndex& index) const; - T* mutable_element(const ShapeIndex& index); + const T& element(ShapeIndexView index) const; + T* mutable_element(ShapeIndexView index); // Return the shape represented with this ShapeTree. const Shape& shape() const { return *shape_; } @@ -125,7 +125,7 @@ class ShapeTree { // Returns true if the node at the given index is a leaf node (an array // shape). - bool IsLeaf(const ShapeIndex& index) const { return Lookup(index)->is_leaf; } + bool IsLeaf(ShapeIndexView index) const { return Lookup(index)->is_leaf; } ShapeTree(const ShapeTree&) = default; ShapeTree& operator=(const ShapeTree&) = default; @@ -211,12 +211,12 @@ class ShapeTree { // Returns an iterator pointing to the given ShapeIndex. // REQUIRES: index must exist in the ShapeTree. - iterator find(const ShapeIndex& index) { + iterator find(ShapeIndexView index) { Node* element = Lookup(index); return iterator(&nodes_, typename std::vector<Node>::iterator(element), /*iterate_leaves_only=*/false); } - const_iterator find(const ShapeIndex& index) const { + const_iterator find(ShapeIndexView index) const { Node* element = Lookup(index); return iterator(&nodes_, typename std::vector<Node>::const_iterator(element), @@ -285,8 +285,8 @@ class ShapeTree { static Status ForEachMutableHelper(const Fn& func, std::vector<Node>* nodes); // Return the tree node at the given index. - Node* Lookup(const ShapeIndex& index); - const Node* Lookup(const ShapeIndex& index) const; + Node* Lookup(ShapeIndexView index); + const Node* Lookup(ShapeIndexView index) const; // The nodes in this shape tree. std::vector<Node> nodes_; @@ -463,17 +463,17 @@ ShapeTree<T>::ShapeTree(const std::shared_ptr<Shape>& shape, } template <typename T> -const T& ShapeTree<T>::element(const ShapeIndex& index) const { +const T& ShapeTree<T>::element(ShapeIndexView index) const { return Lookup(index)->data.second; } template <typename T> -T* ShapeTree<T>::mutable_element(const ShapeIndex& index) { +T* ShapeTree<T>::mutable_element(ShapeIndexView index) { return &Lookup(index)->data.second; } template <typename T> -internal::ShapeTreeNode<T>* ShapeTree<T>::Lookup(const ShapeIndex& index) { +internal::ShapeTreeNode<T>* ShapeTree<T>::Lookup(ShapeIndexView index) { Node* node = &nodes_[0]; for (const int64 i : index) { CHECK_GE(i, 0); @@ -485,7 +485,7 @@ internal::ShapeTreeNode<T>* ShapeTree<T>::Lookup(const ShapeIndex& index) { template <typename T> const internal::ShapeTreeNode<T>* ShapeTree<T>::Lookup( - const ShapeIndex& index) const { + ShapeIndexView index) const { return const_cast<ShapeTree*>(this)->Lookup(index); } |