aboutsummaryrefslogtreecommitdiffhomepage
path: root/tensorflow/compiler/xla/shape_tree.h
diff options
context:
space:
mode:
authorGravatar Benjamin Kramer <kramerb@google.com>2018-06-21 08:36:58 -0700
committerGravatar TensorFlower Gardener <gardener@tensorflow.org>2018-06-21 08:40:07 -0700
commit4ec30cf37a44b64f0d48aa78adc77c09531dd981 (patch)
treef05e20a3ead86de3acf9d952ec14c600d6deb6d1 /tensorflow/compiler/xla/shape_tree.h
parent2c4fb3633e618941c2bed6e1672052706b849189 (diff)
[XLA] Make ShapeTree use ShapeIndexViews
Avoids creating temporary std::vectors on the consumer side. Also push ShapeIndexViews through the GPU backend a bit. PiperOrigin-RevId: 201529722
Diffstat (limited to 'tensorflow/compiler/xla/shape_tree.h')
-rw-r--r--tensorflow/compiler/xla/shape_tree.h22
1 files changed, 11 insertions, 11 deletions
diff --git a/tensorflow/compiler/xla/shape_tree.h b/tensorflow/compiler/xla/shape_tree.h
index 18e54d23c2..4aacc87b78 100644
--- a/tensorflow/compiler/xla/shape_tree.h
+++ b/tensorflow/compiler/xla/shape_tree.h
@@ -105,8 +105,8 @@ class ShapeTree {
// Returns the data element associated with the array in the shape at the
// given index (see ShapeUtil::GetSubshape for how indexes are defined).
- const T& element(const ShapeIndex& index) const;
- T* mutable_element(const ShapeIndex& index);
+ const T& element(ShapeIndexView index) const;
+ T* mutable_element(ShapeIndexView index);
// Return the shape represented with this ShapeTree.
const Shape& shape() const { return *shape_; }
@@ -125,7 +125,7 @@ class ShapeTree {
// Returns true if the node at the given index is a leaf node (an array
// shape).
- bool IsLeaf(const ShapeIndex& index) const { return Lookup(index)->is_leaf; }
+ bool IsLeaf(ShapeIndexView index) const { return Lookup(index)->is_leaf; }
ShapeTree(const ShapeTree&) = default;
ShapeTree& operator=(const ShapeTree&) = default;
@@ -211,12 +211,12 @@ class ShapeTree {
// Returns an iterator pointing to the given ShapeIndex.
// REQUIRES: index must exist in the ShapeTree.
- iterator find(const ShapeIndex& index) {
+ iterator find(ShapeIndexView index) {
Node* element = Lookup(index);
return iterator(&nodes_, typename std::vector<Node>::iterator(element),
/*iterate_leaves_only=*/false);
}
- const_iterator find(const ShapeIndex& index) const {
+ const_iterator find(ShapeIndexView index) const {
Node* element = Lookup(index);
return iterator(&nodes_,
typename std::vector<Node>::const_iterator(element),
@@ -285,8 +285,8 @@ class ShapeTree {
static Status ForEachMutableHelper(const Fn& func, std::vector<Node>* nodes);
// Return the tree node at the given index.
- Node* Lookup(const ShapeIndex& index);
- const Node* Lookup(const ShapeIndex& index) const;
+ Node* Lookup(ShapeIndexView index);
+ const Node* Lookup(ShapeIndexView index) const;
// The nodes in this shape tree.
std::vector<Node> nodes_;
@@ -463,17 +463,17 @@ ShapeTree<T>::ShapeTree(const std::shared_ptr<Shape>& shape,
}
template <typename T>
-const T& ShapeTree<T>::element(const ShapeIndex& index) const {
+const T& ShapeTree<T>::element(ShapeIndexView index) const {
return Lookup(index)->data.second;
}
template <typename T>
-T* ShapeTree<T>::mutable_element(const ShapeIndex& index) {
+T* ShapeTree<T>::mutable_element(ShapeIndexView index) {
return &Lookup(index)->data.second;
}
template <typename T>
-internal::ShapeTreeNode<T>* ShapeTree<T>::Lookup(const ShapeIndex& index) {
+internal::ShapeTreeNode<T>* ShapeTree<T>::Lookup(ShapeIndexView index) {
Node* node = &nodes_[0];
for (const int64 i : index) {
CHECK_GE(i, 0);
@@ -485,7 +485,7 @@ internal::ShapeTreeNode<T>* ShapeTree<T>::Lookup(const ShapeIndex& index) {
template <typename T>
const internal::ShapeTreeNode<T>* ShapeTree<T>::Lookup(
- const ShapeIndex& index) const {
+ ShapeIndexView index) const {
return const_cast<ShapeTree*>(this)->Lookup(index);
}