/* Copyright 2017 The TensorFlow Authors. All Rights Reserved. Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance with the License. You may obtain a copy of the License at http://www.apache.org/licenses/LICENSE-2.0 Unless required by applicable law or agreed to in writing, software distributed under the License is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the License for the specific language governing permissions and limitations under the License. ==============================================================================*/ #ifndef TENSORFLOW_COMPILER_XLA_SHAPE_TREE_H_ #define TENSORFLOW_COMPILER_XLA_SHAPE_TREE_H_ #include #include #include #include #include "absl/memory/memory.h" #include "absl/types/optional.h" #include "absl/types/span.h" #include "tensorflow/compiler/xla/layout_util.h" #include "tensorflow/compiler/xla/shape_util.h" #include "tensorflow/compiler/xla/status_macros.h" #include "tensorflow/compiler/xla/xla_data.pb.h" #include "tensorflow/core/lib/core/errors.h" #include "tensorflow/core/lib/core/status.h" #include "tensorflow/core/lib/gtl/iterator_range.h" #include "tensorflow/core/platform/logging.h" #include "tensorflow/core/platform/types.h" namespace xla { namespace internal { // Internal representation of each node in a ShapeTree. template struct ShapeTreeNode { // Data corresponding to this node. std::pair data; bool is_leaf = true; explicit ShapeTreeNode(ShapeIndex index) : ShapeTreeNode(std::move(index), T()) {} ShapeTreeNode(ShapeIndex index, T data) : data(std::move(index), std::move(data)) {} }; // Internal representation of an index table entry. struct IndexTableEntry { // Index of the node in the ShapeTreeNode vector. uint32 index; // Index of the first child in a IndexTableEntry vector. In the index // table all children entries for a given node will be placed next to each // other. This allows us to use a single field to index them. uint32 children_start; #ifndef NDEBUG // Number of children, used for bounds checking. uint32 children_count; #endif }; } // namespace internal template class ShapeTreeIterator; // A ShapeTree is a recursive data structure which mirrors the structure of a // XLA shape and holds a value of type T for each subshape (i.e. tuple or array) // in the shape. For array shapes, a ShapeTree trivially holds a single value of // type T. // // For tuple shapes which can be an arbitrary tree with arrays at the leaves, a // ShapeTree is an identically structured tree with data elements of type T at // every node. I.e. the root is a tuple by definition, all interior nodes are // also tuples, and all leaves are arrays. // // Like the Shape data structure, this is a tree and tuple elements cannot be // duplicated. That is, every distinct ShapeIndex in the Shape has a unique T // object. // // Normally a ShapeTree owns its Shape, but for efficiency reasons, sometimes // it's helpful not to copy a Shape just to make a ShapeTree. In these cases, // you can pass a Shape* instead of a Shape& to the ShapeTree constructor. It's // then up to you to ensure that the pointed-to Shape doesn't die or mutate // before its ShapeTree goes away. template class ShapeTree { public: using Node = internal::ShapeTreeNode; using Index = internal::IndexTableEntry; // Default constructor creates a tree with a nil shape (i.e. an empty tuple). ShapeTree() : ShapeTree(ShapeUtil::MakeNil()) {} // Create ShapeTree with the given shape, and default-constructed T values for // all nodes. // // The version that takes a pointer may be cheaper because it doesn't require // any Shape copies, but then it's up to you to ensure that the pointer stays // alive longer than this ShapeTree. explicit ShapeTree(Shape shape); explicit ShapeTree(const Shape* shape); explicit ShapeTree(const std::shared_ptr& shape); // Create ShapeTree with the given shape, and init_value for all nodes. ShapeTree(Shape shape, const T& init_value); ShapeTree(const Shape* shape, const T& init_value); ShapeTree(const std::shared_ptr& shape, const T& init_value); // Returns the data element associated with the array in the shape at the // given index (see ShapeUtil::GetSubshape for how indexes are defined). const T& element(ShapeIndexView index) const; T* mutable_element(ShapeIndexView index); // Return the shape represented with this ShapeTree. const Shape& shape() const { return *shape_; } // Replaces *only* the underlying shape of this ShapeTree. The caller must own // the Shape object and hence shape_storage_ is not updated. // // Only safe to use this if the ShapeTree was constructed with 'explicit // ShapeTree(const Shape* shape)' or is moved from one such ShapeTree. The // caller must ensure that the input shape is consistent with the underlying // tree. void replace_shape_ptr(const Shape* shape) { CHECK(shape_storage_.get() == nullptr); shape_ = shape; } // Returns true if the node at the given index is a leaf node (an array // shape). bool IsLeaf(ShapeIndexView index) const { return Lookup(index)->is_leaf; } ShapeTree(const ShapeTree&) = default; ShapeTree& operator=(const ShapeTree&) = default; ShapeTree(ShapeTree&&) = default; ShapeTree& operator=(ShapeTree&& other) = default; // iterator implements a bidirectional_iterator with // value_type = std::pair. // // The iteration order is guaranteed to be a pre-order walk of the ShapeTree. using iterator = ShapeTreeIterator, typename std::vector::iterator, std::pair>; using const_iterator = ShapeTreeIterator, typename std::vector::const_iterator, const std::pair>; using reverse_iterator = std::reverse_iterator; using const_reverse_iterator = std::reverse_iterator; // begin/end for iterating over all nodes. iterator begin() { return iterator(&nodes_, nodes_.begin(), /*iterate_leaves_only=*/false); } iterator end() { return iterator(&nodes_, nodes_.end(), /*iterate_leaves_only=*/false); } const_iterator begin() const { return const_iterator(&nodes_, nodes_.begin(), /*iterate_leaves_only=*/false); } const_iterator end() const { return const_iterator(&nodes_, nodes_.end(), /*iterate_leaves_only=*/false); } // rbegin/rend for iterating over all nodes in reverse. reverse_iterator rbegin() { return reverse_iterator(end()); } reverse_iterator rend() { return reverse_iterator(begin()); } const_reverse_iterator rbegin() const { return const_reverse_iterator(end()); } const_reverse_iterator rend() const { return const_reverse_iterator(begin()); } // leaf_begin()/leaf_end() iterates over all leaf nodes (nodes with no // children). iterator leaf_begin() { return iterator(&nodes_, nodes_.begin(), /*iterate_leaves_only=*/true); } iterator leaf_end() { return iterator(&nodes_, nodes_.end(), /*iterate_leaves_only=*/true); } const_iterator leaf_begin() const { return const_iterator(&nodes_, nodes_.begin(), /*iterate_leaves_only=*/true); } const_iterator leaf_end() const { return const_iterator(&nodes_, nodes_.end(), /*iterate_leaves_only=*/true); } // range-based iterator for leaf_begin()/leaf_end(). tensorflow::gtl::iterator_range leaves() { return tensorflow::gtl::make_range(leaf_begin(), leaf_end()); } tensorflow::gtl::iterator_range leaves() const { return tensorflow::gtl::make_range(leaf_begin(), leaf_end()); } reverse_iterator leaf_rbegin() { return reverse_iterator(leaf_end()); } reverse_iterator leaf_rend() { return reverse_iterator(leaf_begin()); } const_reverse_iterator leaf_rbegin() const { return const_reverse_iterator(leaf_end()); } const_reverse_iterator leaf_rend() const { return const_reverse_iterator(leaf_begin()); } // Returns an iterator pointing to the given ShapeIndex. // REQUIRES: index must exist in the ShapeTree. iterator find(ShapeIndexView index) { Node* element = Lookup(index); auto element_iter = nodes_.begin() + (element - &nodes_[0]); return iterator(&nodes_, element_iter, /*iterate_leaves_only=*/false); } const_iterator find(ShapeIndexView index) const { Node* element = Lookup(index); auto element_iter = nodes_.cbegin() + (element - &nodes_[0]); return const_iterator(&nodes_, element_iter, /*iterate_leaves_only=*/false); } // Returns the number of leaf nodes in the tree. int64 leaf_count() const { return std::distance(leaf_begin(), leaf_end()); } // Recursively traverses the shape and calls the given function at each // element. The function has the following arguments: // // Fn : A callable of type void(const ShapeIndex& index, const T& data) // (or compatible). // index : the index of the element in the shape. See ShapeUtil::GetSubshape // for definition of index. // data : The data value at this element. template void ForEachElement(const Fn& func) const; // Like ForEachElement, but the callable has type // // void (const ShapeIndex& index, T* data). // template void ForEachMutableElement(const Fn& func); // Like ForEach(Mutable)Element, but the callable returns a Status instead of // void. The first non-OK return value is returned by the ForEach* function. template Status ForEachElementWithStatus(const Fn& func) const; template Status ForEachMutableElementWithStatus(const Fn& func); // Maps each element to generate a new tree with the same shape. template ShapeTree Map(const std::function& func) { ShapeTree result(shape_storage_); ForEachElement([&](const ShapeIndex& index, const T& t) { *result.mutable_element(index) = func(t); }); return result; } template ShapeTree Map(const std::function& func) { ShapeTree result(shape_storage_); ForEachMutableElement([&](const ShapeIndex& index, T* t) { *result.mutable_element(index) = func(t); }); return result; } // Copy the subtree of values from 'other' rooted at ShapeIndex // 'source_base_index' into the subtree of value in this ShapeTree rooted at // 'target_base_index'. // // Precondition: The subshape of other.shape() at index source_base_index must // be compatible with the subshape of shape() at index target_base_index. void CopySubtreeFrom(const ShapeTree& other, const ShapeIndex& source_base_index, const ShapeIndex& target_base_index); bool operator==(const ShapeTree& other) const; bool operator!=(const ShapeTree& other) const { return !(*this == other); } private: // Initialize node->children based on 'shape'. All children are assigned the // the given 'init_value'. void InitChildren(const Shape& shape, const T& init_value, Node* node, Index* index); // Initialize node->children based on 'shape'. All children have // default-constructed data values. void InitChildren(const Shape& shape, Node* node, Index* index); // Returns the number of subshapes, including interior nodes, in shape. int64 CountSubshapes(const Shape& shape); // Helpers for traversing the shape via ForEachElement. The helpers // recursively traverse the subtree rooted at "index" (defined as in // ShapeUtil::GetSubshape). template static Status ForEachHelper(const Fn& func, const std::vector& nodes); template static Status ForEachMutableHelper(const Fn& func, std::vector* nodes); // Return the tree node at the given index. Node* Lookup(ShapeIndexView index); const Node* Lookup(ShapeIndexView index) const; // The nodes in this shape tree. std::vector nodes_; // Index table for node lookups. std::vector index_table_; // If we own our Shape, this field contains it, and shape_ is a pointer into // here. Otherwise if we don't own our shape, this is nullptr. std::shared_ptr shape_storage_; // The XLA shape mirrored in this ShapeTree. This is either // shape_storage_.get() or the Shape pointer passed to our constructor. const Shape* shape_; }; // Internal iterator that performs a pre-order walk. This is cheap to copy. // The iterator value_type is equivalent to a // std::pair&, similar to std::map. template class ShapeTreeIterator : public std::iterator { public: ShapeTreeIterator(ContainerType* nodes, IteratorType node, bool iterate_leaves_only) : nodes_(nodes), node_(std::move(node)), iterate_leaves_only_(iterate_leaves_only) { while (iterate_leaves_only && node_ != nodes_->end() && !node_->is_leaf) { ++node_; } } ShapeTreeIterator& operator++() { ++node_; while (iterate_leaves_only_ && node_ != nodes_->end() && !node_->is_leaf) { ++node_; } return *this; } ShapeTreeIterator operator++(int) { auto i = *this; ++(*this); return i; } ShapeTreeIterator& operator--() { --node_; while (iterate_leaves_only_ && node_ > nodes_->begin() && !node_->is_leaf) { --node_; } return *this; } ShapeTreeIterator operator--(int) { auto i = *this; --(*this); return i; } bool operator==(const ShapeTreeIterator& other) const { return node_ == other.node_; } bool operator!=(const ShapeTreeIterator& other) const { return node_ != other.node_; } ValueType& operator*() { return node_->data; } ValueType* operator->() { return &node_->data; } private: ContainerType* nodes_; IteratorType node_; // True if we should not include interior nodes in our walk. const bool iterate_leaves_only_; }; template int64 ShapeTree::CountSubshapes(const Shape& shape) { int64 current_count = 1; if (ShapeUtil::IsTuple(shape)) { int64 count = ShapeUtil::TupleElementCount(shape); for (int i = 0; i < count; ++i) { current_count += CountSubshapes(shape.tuple_shapes(i)); } } return current_count; } template void ShapeTree::InitChildren(const Shape& shape, const T& init_value, Node* node, Index* index) { if (ShapeUtil::IsTuple(shape)) { const int64 size = ShapeUtil::TupleElementCount(shape); #ifndef NDEBUG index->children_count = size; #endif node->is_leaf = false; ShapeIndex shape_index = node->data.first; shape_index.push_back(0); // At the end of the index_table, reserve a continuous space to hold the // children of current node. In order to enforce the invariant that all // children of a given node are placed together, we need to do the // reservation before we recurse into any of its children. int64 children_start_position = index_table_.size(); index_table_.resize(index_table_.size() + size); for (int i = 0; i < size; ++i) { shape_index[shape_index.size() - 1] = i; index_table_[children_start_position + i].index = nodes_.size(); // The first child of the node in the index table is placed at the end of // the table. index_table_[children_start_position + i].children_start = index_table_.size(); nodes_.emplace_back(shape_index, init_value); InitChildren(shape.tuple_shapes(i), init_value, &nodes_.back(), &index_table_[children_start_position + i]); } } else { #ifndef NDEBUG index->children_count = 0; #endif } } template void ShapeTree::InitChildren(const Shape& shape, Node* node, Index* index) { if (ShapeUtil::IsTuple(shape)) { const int64 size = ShapeUtil::TupleElementCount(shape); #ifndef NDEBUG index->children_count = size; #endif node->is_leaf = false; ShapeIndex shape_index = node->data.first; shape_index.push_back(0); // At the end of the index_table, reserve a continuous space to hold the // children of current node. In order to enforce the invariant that all // children of a given node are placed together, we need to do the // reservation before we recurse into any of its children. int64 children_start_position = index_table_.size(); index_table_.resize(index_table_.size() + size); for (int i = 0; i < size; ++i) { shape_index[shape_index.size() - 1] = i; index_table_[children_start_position + i].index = nodes_.size(); // The first child of the node in the index table is placed at the end of // the table. index_table_[children_start_position + i].children_start = index_table_.size(); nodes_.emplace_back(shape_index); InitChildren(shape.tuple_shapes(i), &nodes_.back(), &index_table_[children_start_position + i]); } } else { #ifndef NDEBUG index->children_count = 0; #endif } } template ShapeTree::ShapeTree(Shape shape) : shape_storage_(std::make_shared(std::move(shape))), shape_(shape_storage_.get()) { const int64 count = CountSubshapes(*shape_); nodes_.reserve(count); nodes_.emplace_back(ShapeIndex{}); index_table_.reserve(count); index_table_.emplace_back(Index{0, 1}); InitChildren(*shape_, &nodes_[0], &index_table_[0]); } template ShapeTree::ShapeTree(const Shape* shape) : shape_(shape) { const int64 count = CountSubshapes(*shape_); nodes_.reserve(count); nodes_.emplace_back(ShapeIndex{}); index_table_.reserve(count); index_table_.emplace_back(Index{0, 1}); InitChildren(*shape_, &nodes_[0], &index_table_[0]); } template ShapeTree::ShapeTree(const std::shared_ptr& shape) : shape_storage_(shape), shape_(shape_storage_.get()) { const int64 count = CountSubshapes(*shape_); nodes_.reserve(count); nodes_.emplace_back(ShapeIndex{}); index_table_.reserve(count); index_table_.emplace_back(Index{0, 1}); InitChildren(*shape_, &nodes_[0], &index_table_[0]); } template ShapeTree::ShapeTree(Shape shape, const T& init_value) : shape_storage_(std::make_shared(std::move(shape))), shape_(shape_storage_.get()) { const int64 count = CountSubshapes(*shape_); nodes_.reserve(count); nodes_.emplace_back(ShapeIndex{}, init_value); index_table_.reserve(count); index_table_.emplace_back(Index{0, 1}); InitChildren(*shape_, init_value, &nodes_[0], &index_table_[0]); } template ShapeTree::ShapeTree(const Shape* shape, const T& init_value) : shape_(shape) { const int64 count = CountSubshapes(*shape_); nodes_.reserve(count); nodes_.emplace_back(ShapeIndex{}, init_value); index_table_.reserve(count); index_table_.emplace_back(Index{0, 1}); InitChildren(*shape_, init_value, &nodes_[0], &index_table_[0]); } template ShapeTree::ShapeTree(const std::shared_ptr& shape, const T& init_value) : shape_storage_(shape), shape_(shape_storage_.get()) { const int64 count = CountSubshapes(*shape_); nodes_.reserve(count); nodes_.emplace_back(ShapeIndex{}, init_value); index_table_.reserve(count); index_table_.emplace_back(Index{0, 1}); InitChildren(*shape_, init_value, &nodes_[0], &index_table_[0]); } template const T& ShapeTree::element(ShapeIndexView index) const { return Lookup(index)->data.second; } template T* ShapeTree::mutable_element(ShapeIndexView index) { return &Lookup(index)->data.second; } template internal::ShapeTreeNode* ShapeTree::Lookup(ShapeIndexView index) { Index* iter = &index_table_[0]; for (const int64 i : index) { CHECK_GE(i, 0); #ifndef NDEBUG CHECK_LT(i, iter->children_count); #endif iter = &index_table_[iter->children_start + i]; } return &nodes_[iter->index]; } template const internal::ShapeTreeNode* ShapeTree::Lookup( ShapeIndexView index) const { return const_cast(this)->Lookup(index); } /* static */ template template Status ShapeTree::ForEachHelper(const Fn& func, const std::vector& nodes) { for (const auto& node : nodes) { TF_RETURN_IF_ERROR(func(node.data.first, node.data.second)); } return Status::OK(); } /* static */ template template Status ShapeTree::ForEachMutableHelper(const Fn& func, std::vector* nodes) { for (auto& node : *nodes) { TF_RETURN_IF_ERROR(func(node.data.first, &node.data.second)); } return Status::OK(); } template template Status ShapeTree::ForEachElementWithStatus(const Fn& func) const { return ForEachHelper(func, nodes_); } template template Status ShapeTree::ForEachMutableElementWithStatus(const Fn& func) { return ForEachMutableHelper(func, &nodes_); } template template void ShapeTree::ForEachElement(const Fn& func) const { return ForEachHelper( [&func](const ShapeIndex& index, const T& data) { func(index, data); return Status::OK(); }, nodes_) .IgnoreError(); } template template void ShapeTree::ForEachMutableElement(const Fn& func) { return ForEachMutableHelper( [&func](const ShapeIndex& index, T* data) { func(index, data); return Status::OK(); }, &nodes_) .IgnoreError(); } template void ShapeTree::CopySubtreeFrom(const ShapeTree& other, const ShapeIndex& source_base_index, const ShapeIndex& target_base_index) { CHECK(ShapeUtil::Compatible( ShapeUtil::GetSubshape(shape(), target_base_index), ShapeUtil::GetSubshape(other.shape(), source_base_index))); ForEachMutableElement([this, &other, &source_base_index, &target_base_index]( const ShapeIndex& index, T* data) { // Copy the data element only if index is in the // subtree rooted at target_base_index. for (int i = 0; i < target_base_index.size(); ++i) { if (i >= index.size() || index[i] != target_base_index[i]) { return; } } // Construct source element index to copy from. ShapeIndex source_index = source_base_index; for (int i = target_base_index.size(); i < index.size(); ++i) { source_index.push_back(index[i]); } *data = other.element(source_index); }); } template bool ShapeTree::operator==(const ShapeTree& other) const { bool equal = true; ForEachElement( [this, &other, &equal](const ShapeIndex& index, const T& data) { if (data != other.element(index)) { equal = false; } }); return equal; } } // namespace xla #endif // TENSORFLOW_COMPILER_XLA_SHAPE_TREE_H_