diff options
Diffstat (limited to 'tensorflow/compiler/xla/shape_tree.h')
-rw-r--r-- | tensorflow/compiler/xla/shape_tree.h | 140 |
1 files changed, 110 insertions, 30 deletions
diff --git a/tensorflow/compiler/xla/shape_tree.h b/tensorflow/compiler/xla/shape_tree.h index 4aacc87b78..c74dd648ad 100644 --- a/tensorflow/compiler/xla/shape_tree.h +++ b/tensorflow/compiler/xla/shape_tree.h @@ -44,10 +44,6 @@ struct ShapeTreeNode { // Data corresponding to this node. std::pair<ShapeIndex, T> data; - // Children of this node, as indices into the container's nodes_ array. - std::vector<size_t> children; - - // Tells whether this is a leaf node. bool is_leaf = true; explicit ShapeTreeNode(ShapeIndex index) @@ -56,6 +52,20 @@ struct ShapeTreeNode { : data(std::move(index), std::move(data)) {} }; +// Internal representation of an index table entry. +struct IndexTableEntry { + // Index of the node in the ShapeTreeNode vector. + uint32 index; + // Index of the first child in a IndexTableEntry vector. In the index + // table all children entries for a given node will be placed next to each + // other. This allows us to use a single field to index them. + uint32 children_start; +#ifndef NDEBUG + // Number of children, used for bounds checking. + uint32 children_count; +#endif +}; + } // namespace internal template <typename ContainerType, typename IteratorType, typename ValueType> @@ -84,6 +94,7 @@ template <typename T> class ShapeTree { public: using Node = internal::ShapeTreeNode<T>; + using Index = internal::IndexTableEntry; // Default constructor creates a tree with a nil shape (i.e. an empty tuple). ShapeTree() : ShapeTree(ShapeUtil::MakeNil()) {} @@ -267,11 +278,12 @@ class ShapeTree { private: // Initialize node->children based on 'shape'. All children are assigned the // the given 'init_value'. - void InitChildren(const Shape& shape, const T& init_value, Node* node); + void InitChildren(const Shape& shape, const T& init_value, Node* node, + Index* index); // Initialize node->children based on 'shape'. All children have // default-constructed data values. - void InitChildren(const Shape& shape, Node* node); + void InitChildren(const Shape& shape, Node* node, Index* index); // Returns the number of subshapes, including interior nodes, in shape. int64 CountSubshapes(const Shape& shape); @@ -291,6 +303,9 @@ class ShapeTree { // The nodes in this shape tree. std::vector<Node> nodes_; + // Index table for node lookups. + std::vector<Index> index_table_; + // If we own our Shape, this field contains it, and shape_ is a pointer into // here. Otherwise if we don't own our shape, this is nullptr. std::shared_ptr<Shape> shape_storage_; @@ -373,36 +388,74 @@ int64 ShapeTree<T>::CountSubshapes(const Shape& shape) { template <typename T> void ShapeTree<T>::InitChildren(const Shape& shape, const T& init_value, - Node* node) { + Node* node, Index* index) { if (ShapeUtil::IsTuple(shape)) { const int64 size = ShapeUtil::TupleElementCount(shape); - node->children.reserve(size); +#ifndef NDEBUG + index->children_count = size; +#endif node->is_leaf = false; ShapeIndex shape_index = node->data.first; shape_index.push_back(0); + + // At the end of the index_table, reserve a continuous space to hold the + // children of current node. In order to enforce the invariant that all + // children of a given node are placed together, we need to do the + // reservation before we recurse into any of its children. + int64 children_start_position = index_table_.size(); + index_table_.resize(index_table_.size() + size); + for (int i = 0; i < size; ++i) { shape_index[shape_index.size() - 1] = i; - node->children.push_back(nodes_.size()); + index_table_[children_start_position + i].index = nodes_.size(); + // The first child of the node in the index table is placed at the end of + // the table. + index_table_[children_start_position + i].children_start = + index_table_.size(); nodes_.emplace_back(shape_index, init_value); - InitChildren(shape.tuple_shapes(i), init_value, &nodes_.back()); + InitChildren(shape.tuple_shapes(i), init_value, &nodes_.back(), + &index_table_[children_start_position + i]); } + } else { +#ifndef NDEBUG + index->children_count = 0; +#endif } } template <typename T> -void ShapeTree<T>::InitChildren(const Shape& shape, Node* node) { +void ShapeTree<T>::InitChildren(const Shape& shape, Node* node, Index* index) { if (ShapeUtil::IsTuple(shape)) { const int64 size = ShapeUtil::TupleElementCount(shape); - node->children.reserve(size); +#ifndef NDEBUG + index->children_count = size; +#endif node->is_leaf = false; ShapeIndex shape_index = node->data.first; shape_index.push_back(0); + + // At the end of the index_table, reserve a continuous space to hold the + // children of current node. In order to enforce the invariant that all + // children of a given node are placed together, we need to do the + // reservation before we recurse into any of its children. + int64 children_start_position = index_table_.size(); + index_table_.resize(index_table_.size() + size); + for (int i = 0; i < size; ++i) { shape_index[shape_index.size() - 1] = i; - node->children.push_back(nodes_.size()); + index_table_[children_start_position + i].index = nodes_.size(); + // The first child of the node in the index table is placed at the end of + // the table. + index_table_[children_start_position + i].children_start = + index_table_.size(); nodes_.emplace_back(shape_index); - InitChildren(shape.tuple_shapes(i), &nodes_.back()); + InitChildren(shape.tuple_shapes(i), &nodes_.back(), + &index_table_[children_start_position + i]); } + } else { +#ifndef NDEBUG + index->children_count = 0; +#endif } } @@ -413,24 +466,36 @@ ShapeTree<T>::ShapeTree(Shape shape) // The shape_ field is just used to hold the structure of the shape. // It should not be relied upon to store layout information. LayoutUtil::ClearLayout(shape_storage_.get()); - nodes_.reserve(CountSubshapes(*shape_)); + const int64 count = CountSubshapes(*shape_); + nodes_.reserve(count); nodes_.emplace_back(ShapeIndex{}); - InitChildren(*shape_, &nodes_[0]); + + index_table_.reserve(count); + index_table_.emplace_back(Index{0, 1}); + InitChildren(*shape_, &nodes_[0], &index_table_[0]); } template <typename T> ShapeTree<T>::ShapeTree(const Shape* shape) : shape_(shape) { - nodes_.reserve(CountSubshapes(*shape_)); + const int64 count = CountSubshapes(*shape_); + nodes_.reserve(count); nodes_.emplace_back(ShapeIndex{}); - InitChildren(*shape_, &nodes_[0]); + + index_table_.reserve(count); + index_table_.emplace_back(Index{0, 1}); + InitChildren(*shape_, &nodes_[0], &index_table_[0]); } template <typename T> ShapeTree<T>::ShapeTree(const std::shared_ptr<Shape>& shape) : shape_storage_(shape), shape_(shape_storage_.get()) { - nodes_.reserve(CountSubshapes(*shape_)); + const int64 count = CountSubshapes(*shape_); + nodes_.reserve(count); nodes_.emplace_back(ShapeIndex{}); - InitChildren(*shape_, &nodes_[0]); + + index_table_.reserve(count); + index_table_.emplace_back(Index{0, 1}); + InitChildren(*shape_, &nodes_[0], &index_table_[0]); } template <typename T> @@ -440,26 +505,38 @@ ShapeTree<T>::ShapeTree(Shape shape, const T& init_value) // The shape_ field is just used to hold the structure of the shape. // It should not be relied upon to store layout information. LayoutUtil::ClearLayout(shape_storage_.get()); - nodes_.reserve(CountSubshapes(*shape_)); + const int64 count = CountSubshapes(*shape_); + nodes_.reserve(count); nodes_.emplace_back(ShapeIndex{}, init_value); - InitChildren(*shape_, init_value, &nodes_[0]); + + index_table_.reserve(count); + index_table_.emplace_back(Index{0, 1}); + InitChildren(*shape_, init_value, &nodes_[0], &index_table_[0]); } template <typename T> ShapeTree<T>::ShapeTree(const Shape* shape, const T& init_value) : shape_(shape) { - nodes_.reserve(CountSubshapes(*shape_)); + const int64 count = CountSubshapes(*shape_); + nodes_.reserve(count); nodes_.emplace_back(ShapeIndex{}, init_value); - InitChildren(*shape_, init_value, &nodes_[0]); + + index_table_.reserve(count); + index_table_.emplace_back(Index{0, 1}); + InitChildren(*shape_, init_value, &nodes_[0], &index_table_[0]); } template <typename T> ShapeTree<T>::ShapeTree(const std::shared_ptr<Shape>& shape, const T& init_value) : shape_storage_(shape), shape_(shape_storage_.get()) { - nodes_.reserve(CountSubshapes(*shape_)); + const int64 count = CountSubshapes(*shape_); + nodes_.reserve(count); nodes_.emplace_back(ShapeIndex{}, init_value); - InitChildren(*shape_, init_value, &nodes_[0]); + + index_table_.reserve(count); + index_table_.emplace_back(Index{0, 1}); + InitChildren(*shape_, init_value, &nodes_[0], &index_table_[0]); } template <typename T> @@ -474,13 +551,16 @@ T* ShapeTree<T>::mutable_element(ShapeIndexView index) { template <typename T> internal::ShapeTreeNode<T>* ShapeTree<T>::Lookup(ShapeIndexView index) { - Node* node = &nodes_[0]; + Index* iter = &index_table_[0]; for (const int64 i : index) { CHECK_GE(i, 0); - CHECK_LT(i, node->children.size()); - node = &nodes_[node->children[i]]; +#ifndef NDEBUG + CHECK_LT(i, iter->children_count); +#endif + iter = &index_table_[iter->children_start + i]; } - return node; + + return &nodes_[iter->index]; } template <typename T> |