/* Copyright 2017 The TensorFlow Authors. All Rights Reserved.

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

    http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/

#ifndef TENSORFLOW_COMPILER_XLA_SHAPE_TREE_H_
#define TENSORFLOW_COMPILER_XLA_SHAPE_TREE_H_

#include <functional>
#include <memory>
#include <vector>

#include "tensorflow/compiler/xla/layout_util.h"
#include "tensorflow/compiler/xla/ptr_util.h"
#include "tensorflow/compiler/xla/shape_util.h"
#include "tensorflow/compiler/xla/status_macros.h"
#include "tensorflow/compiler/xla/xla_data.pb.h"
#include "tensorflow/core/lib/core/errors.h"
#include "tensorflow/core/lib/core/status.h"
#include "tensorflow/core/lib/gtl/array_slice.h"
#include "tensorflow/core/lib/gtl/optional.h"
#include "tensorflow/core/platform/logging.h"
#include "tensorflow/core/platform/types.h"

namespace xla {

namespace internal {

// Internal representation of each node in a ShapeTree.
template <typename T>
struct ShapeTreeNode {
  // Data corresponding to this node.
  T data;

  // Children of this node.
  std::vector<std::unique_ptr<ShapeTreeNode>> children;

  ShapeTreeNode() = default;
  explicit ShapeTreeNode(const T& data) : data(data) {}

  ShapeTreeNode(const ShapeTreeNode& other)
      : data(other.data), children(other.children.size()) {
    for (size_t i = 0; i < children.size(); ++i) {
      children[i] = MakeUnique<ShapeTreeNode>(*other.children[i]);
    }
  }

  ShapeTreeNode& operator=(const ShapeTreeNode& other) {
    if (this != &other) {
      data = other.data;
      children.resize(other.children.size());
      for (size_t i = 0; i < children.size(); ++i) {
        children[i] = MakeUnique<ShapeTreeNode>(*other.children[i]);
      }
    }
    return *this;
  }
};

}  // namespace internal

// A ShapeTree<T> is a recursive data structure which mirrors the structure of
// an XLA shape and holds a value of type T for each subshape (i.e. tuple or
// array) in the shape. For array shapes, a ShapeTree trivially holds a single
// value of type T.
//
// For tuple shapes which can be an arbitrary tree with arrays at the leaves, a
// ShapeTree is an identically structured tree with data elements of type T at
// every node. I.e. the root is a tuple by definition, all interior nodes are
// also tuples, and all leaves are arrays.
//
// Like the Shape data structure, this is a tree and tuple elements cannot be
// duplicated. That is, every distinct ShapeIndex in the Shape has a unique T
// object.
//
// Normally a ShapeTree owns its Shape, but for efficiency reasons, sometimes
// it's helpful not to copy a Shape just to make a ShapeTree. In these cases,
// you can pass a Shape* instead of a Shape& to the ShapeTree constructor. It's
// then up to you to ensure that the pointed-to Shape doesn't die or mutate
// before its ShapeTree goes away.
template <typename T>
class ShapeTree {
 public:
  // Default constructor creates a tree with a nil shape (i.e. an empty tuple).
  ShapeTree() : ShapeTree(ShapeUtil::MakeNil()) {}
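  // Example usage (illustrative sketch only, not part of the original API
  // documentation; the shapes and ShapeUtil helper calls below are assumed to
  // be available as used elsewhere in XLA):
  //
  //   Shape shape = ShapeUtil::MakeTupleShape(
  //       {ShapeUtil::MakeShape(F32, {2, 3}), ShapeUtil::MakeShape(S32, {4})});
  //   ShapeTree<int> tree(shape, 0);    // one int per subshape, all zero
  //   *tree.mutable_element({0}) = 42;  // value for the first tuple element
  //   CHECK_EQ(tree.element({0}), 42);
  //   CHECK(tree.IsLeaf({0}));          // array subshapes are leaves
  //   CHECK(!tree.IsLeaf({}));          // the tuple root is not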
  // Create ShapeTree with the given shape, and default-constructed T values
  // for all nodes.
  //
  // The version that takes a pointer may be cheaper because it doesn't require
  // any Shape copies, but then it's up to you to ensure that the pointer stays
  // alive longer than this ShapeTree.
  explicit ShapeTree(Shape shape);
  explicit ShapeTree(const Shape* shape);

  // Create ShapeTree with the given shape, and init_value for all nodes.
  ShapeTree(Shape shape, const T& init_value);
  ShapeTree(const Shape* shape, const T& init_value);

  ShapeTree(const ShapeTree& other)
      : root_(other.root_), shape_storage_(other.shape_storage_) {
    // Fix up internal pointer if necessary.
    if (shape_storage_) {
      CHECK_EQ(other.shape_, &*other.shape_storage_);
      shape_ = &*shape_storage_;
    } else {
      shape_ = other.shape_;
    }
  }

  ShapeTree& operator=(const ShapeTree& other) {
    root_ = other.root_;
    shape_storage_ = other.shape_storage_;

    // Fix up internal pointer if necessary.
    if (shape_storage_) {
      CHECK_EQ(other.shape_, &*other.shape_storage_);
      shape_ = &*shape_storage_;
    } else {
      shape_ = other.shape_;
    }
    return *this;
  }

  // Returns the data element associated with the array in the shape at the
  // given index (see ShapeUtil::GetSubshape for how indexes are defined).
  const T& element(const ShapeIndex& index) const;
  T* mutable_element(const ShapeIndex& index);

  // Returns the shape represented by this ShapeTree.
  const Shape& shape() const { return *shape_; }

  // Returns true if the node at the given index is a leaf node (an array
  // shape).
  bool IsLeaf(const ShapeIndex& index) const {
    return Lookup(index)->children.empty();
  }

  // Recursively traverses the shape and calls the given function at each
  // element. The function has the following arguments:
  //
  //   Fn :    A callable of type void(const ShapeIndex& index, const T& data)
  //           (or compatible).
  //   index : the index of the element in the shape. See
  //           ShapeUtil::GetSubshape for the definition of index.
  //   data :  The data value at this element.
  template <typename Fn>
  void ForEachElement(const Fn& func) const;

  // Like ForEachElement, but the callable has type
  //
  //   void (const ShapeIndex& index, T* data).
  //
  template <typename Fn>
  void ForEachMutableElement(const Fn& func);

  // Like ForEach(Mutable)Element, but the callable returns a Status instead of
  // void. The first non-OK return value is returned by the ForEach* function.
  template <typename Fn>
  Status ForEachElementWithStatus(const Fn& func) const;
  template <typename Fn>
  Status ForEachMutableElementWithStatus(const Fn& func);

  // Copy the subtree of values from 'other' rooted at ShapeIndex
  // 'source_base_index' into the subtree of values in this ShapeTree rooted at
  // 'target_base_index'.
  //
  // Precondition: The subshape of other.shape() at index source_base_index
  // must be compatible with the subshape of shape() at index
  // target_base_index.
  void CopySubtreeFrom(const ShapeTree<T>& other,
                       const ShapeIndex& source_base_index,
                       const ShapeIndex& target_base_index);

  bool operator==(const ShapeTree<T>& other) const;
  bool operator!=(const ShapeTree<T>& other) const { return !(*this == other); }

 private:
  using Node = internal::ShapeTreeNode<T>;

  // Initialize node->children based on 'shape'. All children are assigned the
  // given 'init_value'.
  void InitChildren(const Shape& shape, const T& init_value, Node* node);

  // Initialize node->children based on 'shape'. All children have
  // default-constructed data values.
  void InitChildren(const Shape& shape, Node* node);

  // Helpers for traversing the shape via ForEachElement. The helpers
  // recursively traverse the subtree rooted at "index" (defined as in
  // ShapeUtil::GetSubshape).
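  //
  // For example (illustrative note, not from the original header): for a shape
  // (F32[2], (S32[3], F32[4])) the traversal visits indices in pre-order,
  //   {}, {0}, {1}, {1,0}, {1,1}
  // i.e. each node is visited before its children, children in tuple order.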
  template <typename Fn>
  static Status ForEachHelper(const Fn& func, const Node& node,
                              ShapeIndex* index);
  template <typename Fn>
  static Status ForEachMutableHelper(const Fn& func, Node* node,
                                     ShapeIndex* index);

  // Return the tree node at the given index.
  Node* Lookup(const ShapeIndex& index);
  const Node* Lookup(const ShapeIndex& index) const;

  // The root node, which contains all other nodes.
  Node root_;

  // If we own our Shape, this field contains it, and shape_ is a pointer into
  // here. Otherwise if we don't own our shape, this is nullopt.
  tensorflow::gtl::optional<Shape> shape_storage_;

  // The XLA shape mirrored in this ShapeTree. This is either a pointer into
  // shape_storage_ or the Shape pointer passed to our constructor.
  const Shape* shape_;
};

template <typename T>
void ShapeTree<T>::InitChildren(const Shape& shape, const T& init_value,
                                Node* node) {
  if (ShapeUtil::IsTuple(shape)) {
    for (int i = 0; i < ShapeUtil::TupleElementCount(shape); ++i) {
      node->children.emplace_back(new Node(init_value));
      InitChildren(shape.tuple_shapes(i), init_value,
                   node->children.back().get());
    }
  }
}

template <typename T>
void ShapeTree<T>::InitChildren(const Shape& shape, Node* node) {
  if (ShapeUtil::IsTuple(shape)) {
    for (int i = 0; i < ShapeUtil::TupleElementCount(shape); ++i) {
      node->children.emplace_back(new Node());
      InitChildren(shape.tuple_shapes(i), node->children.back().get());
    }
  }
}

template <typename T>
ShapeTree<T>::ShapeTree(Shape shape)
    : root_(), shape_storage_(std::move(shape)), shape_(&*shape_storage_) {
  // The shape_ field is just used to hold the structure of the shape.
  // It should not be relied upon to store layout information.
  LayoutUtil::ClearLayout(&*shape_storage_);
  InitChildren(*shape_, &root_);
}

template <typename T>
ShapeTree<T>::ShapeTree(const Shape* shape) : root_(), shape_(shape) {
  InitChildren(*shape_, &root_);
}
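// Illustrative sketch (not part of the original header): the by-value
// constructors copy the Shape into shape_storage_, while the pointer-taking
// constructors alias it, so the caller must keep the Shape alive.
//
//   Shape shape = ShapeUtil::MakeShape(F32, {8});
//   ShapeTree<bool> owning(shape);     // copies 'shape'; safe to let it die
//   ShapeTree<bool> aliasing(&shape);  // 'shape' must outlive 'aliasing'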
template <typename T>
ShapeTree<T>::ShapeTree(Shape shape, const T& init_value)
    : root_(init_value),
      shape_storage_(std::move(shape)),
      shape_(&*shape_storage_) {
  // The shape_ field is just used to hold the structure of the shape.
  // It should not be relied upon to store layout information.
  LayoutUtil::ClearLayout(&*shape_storage_);
  InitChildren(*shape_, init_value, &root_);
}

template <typename T>
ShapeTree<T>::ShapeTree(const Shape* shape, const T& init_value)
    : root_(init_value), shape_(shape) {
  InitChildren(*shape_, init_value, &root_);
}

template <typename T>
const T& ShapeTree<T>::element(const ShapeIndex& index) const {
  return Lookup(index)->data;
}

template <typename T>
T* ShapeTree<T>::mutable_element(const ShapeIndex& index) {
  return &Lookup(index)->data;
}

template <typename T>
internal::ShapeTreeNode<T>* ShapeTree<T>::Lookup(const ShapeIndex& index) {
  Node* node = &root_;
  for (const int64 i : index) {
    CHECK_GE(i, 0);
    CHECK_LT(i, node->children.size());
    node = node->children[i].get();
  }
  return node;
}

template <typename T>
const internal::ShapeTreeNode<T>* ShapeTree<T>::Lookup(
    const ShapeIndex& index) const {
  return const_cast<ShapeTree*>(this)->Lookup(index);
}

/* static */
template <typename T>
template <typename Fn>
Status ShapeTree<T>::ForEachHelper(const Fn& func, const Node& node,
                                   ShapeIndex* index) {
  TF_RETURN_IF_ERROR(func(*index, node.data));
  for (int64 i = 0; i < node.children.size(); ++i) {
    index->push_back(i);
    TF_RETURN_IF_ERROR(ForEachHelper(func, *node.children[i], index));
    index->pop_back();
  }
  return Status::OK();
}

/* static */
template <typename T>
template <typename Fn>
Status ShapeTree<T>::ForEachMutableHelper(const Fn& func, Node* node,
                                          ShapeIndex* index) {
  TF_RETURN_IF_ERROR(func(*index, &node->data));
  for (int64 i = 0; i < node->children.size(); ++i) {
    index->push_back(i);
    TF_RETURN_IF_ERROR(
        ForEachMutableHelper(func, node->children[i].get(), index));
    index->pop_back();
  }
  return Status::OK();
}

template <typename T>
template <typename Fn>
Status ShapeTree<T>::ForEachElementWithStatus(const Fn& func) const {
  ShapeIndex index;
  return ForEachHelper(func, root_, &index);
}

template <typename T>
template <typename Fn>
Status ShapeTree<T>::ForEachMutableElementWithStatus(const Fn& func) {
  ShapeIndex index;
  return ForEachMutableHelper(func, &root_, &index);
}

template <typename T>
template <typename Fn>
void ShapeTree<T>::ForEachElement(const Fn& func) const {
  ShapeIndex index;
  return ForEachHelper(
             [&func](const ShapeIndex& index, const T& data) {
               func(index, data);
               return Status::OK();
             },
             root_, &index)
      .IgnoreError();
}

template <typename T>
template <typename Fn>
void ShapeTree<T>::ForEachMutableElement(const Fn& func) {
  ShapeIndex index;
  return ForEachMutableHelper(
             [&func](const ShapeIndex& index, T* data) {
               func(index, data);
               return Status::OK();
             },
             &root_, &index)
      .IgnoreError();
}

template <typename T>
void ShapeTree<T>::CopySubtreeFrom(const ShapeTree<T>& other,
                                   const ShapeIndex& source_base_index,
                                   const ShapeIndex& target_base_index) {
  CHECK(ShapeUtil::Compatible(
      ShapeUtil::GetSubshape(shape(), target_base_index),
      ShapeUtil::GetSubshape(other.shape(), source_base_index)));
  ForEachMutableElement([this, &other, &source_base_index, &target_base_index](
                            const ShapeIndex& index, T* data) {
    // Copy the data element only if index is in the
    // subtree rooted at target_base_index.
    for (int i = 0; i < target_base_index.size(); ++i) {
      if (i >= index.size() || index[i] != target_base_index[i]) {
        return;
      }
    }
    // Construct source element index to copy from.
    ShapeIndex source_index = source_base_index;
    for (int i = target_base_index.size(); i < index.size(); ++i) {
      source_index.push_back(index[i]);
    }
    *data = other.element(source_index);
  });
}

template <typename T>
bool ShapeTree<T>::operator==(const ShapeTree<T>& other) const {
  bool equal = true;
  ForEachElement(
      [this, &other, &equal](const ShapeIndex& index, const T& data) {
        if (data != other.element(index)) {
          equal = false;
        }
      });
  return equal;
}

}  // namespace xla

#endif  // TENSORFLOW_COMPILER_XLA_SHAPE_TREE_H_