aboutsummaryrefslogtreecommitdiffhomepage
path: root/tensorflow/compiler/xla/shape_tree.h
diff options
context:
space:
mode:
Diffstat (limited to 'tensorflow/compiler/xla/shape_tree.h')
-rw-r--r--tensorflow/compiler/xla/shape_tree.h140
1 files changed, 110 insertions, 30 deletions
diff --git a/tensorflow/compiler/xla/shape_tree.h b/tensorflow/compiler/xla/shape_tree.h
index 4aacc87b78..c74dd648ad 100644
--- a/tensorflow/compiler/xla/shape_tree.h
+++ b/tensorflow/compiler/xla/shape_tree.h
@@ -44,10 +44,6 @@ struct ShapeTreeNode {
// Data corresponding to this node.
std::pair<ShapeIndex, T> data;
- // Children of this node, as indices into the container's nodes_ array.
- std::vector<size_t> children;
-
- // Tells whether this is a leaf node.
bool is_leaf = true;
explicit ShapeTreeNode(ShapeIndex index)
@@ -56,6 +52,20 @@ struct ShapeTreeNode {
: data(std::move(index), std::move(data)) {}
};
+// Internal representation of an index table entry.
+struct IndexTableEntry {
+ // Index of the node in the ShapeTreeNode vector.
+ uint32 index;
+ // Index of the first child in a IndexTableEntry vector. In the index
+ // table all children entries for a given node will be placed next to each
+ // other. This allows us to use a single field to index them.
+ uint32 children_start;
+#ifndef NDEBUG
+ // Number of children, used for bounds checking.
+ uint32 children_count;
+#endif
+};
+
} // namespace internal
template <typename ContainerType, typename IteratorType, typename ValueType>
@@ -84,6 +94,7 @@ template <typename T>
class ShapeTree {
public:
using Node = internal::ShapeTreeNode<T>;
+ using Index = internal::IndexTableEntry;
// Default constructor creates a tree with a nil shape (i.e. an empty tuple).
ShapeTree() : ShapeTree(ShapeUtil::MakeNil()) {}
@@ -267,11 +278,12 @@ class ShapeTree {
private:
// Initialize node->children based on 'shape'. All children are assigned the
// the given 'init_value'.
- void InitChildren(const Shape& shape, const T& init_value, Node* node);
+ void InitChildren(const Shape& shape, const T& init_value, Node* node,
+ Index* index);
// Initialize node->children based on 'shape'. All children have
// default-constructed data values.
- void InitChildren(const Shape& shape, Node* node);
+ void InitChildren(const Shape& shape, Node* node, Index* index);
// Returns the number of subshapes, including interior nodes, in shape.
int64 CountSubshapes(const Shape& shape);
@@ -291,6 +303,9 @@ class ShapeTree {
// The nodes in this shape tree.
std::vector<Node> nodes_;
+ // Index table for node lookups.
+ std::vector<Index> index_table_;
+
// If we own our Shape, this field contains it, and shape_ is a pointer into
// here. Otherwise if we don't own our shape, this is nullptr.
std::shared_ptr<Shape> shape_storage_;
@@ -373,36 +388,74 @@ int64 ShapeTree<T>::CountSubshapes(const Shape& shape) {
template <typename T>
void ShapeTree<T>::InitChildren(const Shape& shape, const T& init_value,
- Node* node) {
+ Node* node, Index* index) {
if (ShapeUtil::IsTuple(shape)) {
const int64 size = ShapeUtil::TupleElementCount(shape);
- node->children.reserve(size);
+#ifndef NDEBUG
+ index->children_count = size;
+#endif
node->is_leaf = false;
ShapeIndex shape_index = node->data.first;
shape_index.push_back(0);
+
+ // At the end of the index_table, reserve a continuous space to hold the
+ // children of current node. In order to enforce the invariant that all
+ // children of a given node are placed together, we need to do the
+ // reservation before we recurse into any of its children.
+ int64 children_start_position = index_table_.size();
+ index_table_.resize(index_table_.size() + size);
+
for (int i = 0; i < size; ++i) {
shape_index[shape_index.size() - 1] = i;
- node->children.push_back(nodes_.size());
+ index_table_[children_start_position + i].index = nodes_.size();
+ // The first child of the node in the index table is placed at the end of
+ // the table.
+ index_table_[children_start_position + i].children_start =
+ index_table_.size();
nodes_.emplace_back(shape_index, init_value);
- InitChildren(shape.tuple_shapes(i), init_value, &nodes_.back());
+ InitChildren(shape.tuple_shapes(i), init_value, &nodes_.back(),
+ &index_table_[children_start_position + i]);
}
+ } else {
+#ifndef NDEBUG
+ index->children_count = 0;
+#endif
}
}
template <typename T>
-void ShapeTree<T>::InitChildren(const Shape& shape, Node* node) {
+void ShapeTree<T>::InitChildren(const Shape& shape, Node* node, Index* index) {
if (ShapeUtil::IsTuple(shape)) {
const int64 size = ShapeUtil::TupleElementCount(shape);
- node->children.reserve(size);
+#ifndef NDEBUG
+ index->children_count = size;
+#endif
node->is_leaf = false;
ShapeIndex shape_index = node->data.first;
shape_index.push_back(0);
+
+ // At the end of the index_table, reserve a continuous space to hold the
+ // children of current node. In order to enforce the invariant that all
+ // children of a given node are placed together, we need to do the
+ // reservation before we recurse into any of its children.
+ int64 children_start_position = index_table_.size();
+ index_table_.resize(index_table_.size() + size);
+
for (int i = 0; i < size; ++i) {
shape_index[shape_index.size() - 1] = i;
- node->children.push_back(nodes_.size());
+ index_table_[children_start_position + i].index = nodes_.size();
+ // The first child of the node in the index table is placed at the end of
+ // the table.
+ index_table_[children_start_position + i].children_start =
+ index_table_.size();
nodes_.emplace_back(shape_index);
- InitChildren(shape.tuple_shapes(i), &nodes_.back());
+ InitChildren(shape.tuple_shapes(i), &nodes_.back(),
+ &index_table_[children_start_position + i]);
}
+ } else {
+#ifndef NDEBUG
+ index->children_count = 0;
+#endif
}
}
@@ -413,24 +466,36 @@ ShapeTree<T>::ShapeTree(Shape shape)
// The shape_ field is just used to hold the structure of the shape.
// It should not be relied upon to store layout information.
LayoutUtil::ClearLayout(shape_storage_.get());
- nodes_.reserve(CountSubshapes(*shape_));
+ const int64 count = CountSubshapes(*shape_);
+ nodes_.reserve(count);
nodes_.emplace_back(ShapeIndex{});
- InitChildren(*shape_, &nodes_[0]);
+
+ index_table_.reserve(count);
+ index_table_.emplace_back(Index{0, 1});
+ InitChildren(*shape_, &nodes_[0], &index_table_[0]);
}
template <typename T>
ShapeTree<T>::ShapeTree(const Shape* shape) : shape_(shape) {
- nodes_.reserve(CountSubshapes(*shape_));
+ const int64 count = CountSubshapes(*shape_);
+ nodes_.reserve(count);
nodes_.emplace_back(ShapeIndex{});
- InitChildren(*shape_, &nodes_[0]);
+
+ index_table_.reserve(count);
+ index_table_.emplace_back(Index{0, 1});
+ InitChildren(*shape_, &nodes_[0], &index_table_[0]);
}
template <typename T>
ShapeTree<T>::ShapeTree(const std::shared_ptr<Shape>& shape)
: shape_storage_(shape), shape_(shape_storage_.get()) {
- nodes_.reserve(CountSubshapes(*shape_));
+ const int64 count = CountSubshapes(*shape_);
+ nodes_.reserve(count);
nodes_.emplace_back(ShapeIndex{});
- InitChildren(*shape_, &nodes_[0]);
+
+ index_table_.reserve(count);
+ index_table_.emplace_back(Index{0, 1});
+ InitChildren(*shape_, &nodes_[0], &index_table_[0]);
}
template <typename T>
@@ -440,26 +505,38 @@ ShapeTree<T>::ShapeTree(Shape shape, const T& init_value)
// The shape_ field is just used to hold the structure of the shape.
// It should not be relied upon to store layout information.
LayoutUtil::ClearLayout(shape_storage_.get());
- nodes_.reserve(CountSubshapes(*shape_));
+ const int64 count = CountSubshapes(*shape_);
+ nodes_.reserve(count);
nodes_.emplace_back(ShapeIndex{}, init_value);
- InitChildren(*shape_, init_value, &nodes_[0]);
+
+ index_table_.reserve(count);
+ index_table_.emplace_back(Index{0, 1});
+ InitChildren(*shape_, init_value, &nodes_[0], &index_table_[0]);
}
template <typename T>
ShapeTree<T>::ShapeTree(const Shape* shape, const T& init_value)
: shape_(shape) {
- nodes_.reserve(CountSubshapes(*shape_));
+ const int64 count = CountSubshapes(*shape_);
+ nodes_.reserve(count);
nodes_.emplace_back(ShapeIndex{}, init_value);
- InitChildren(*shape_, init_value, &nodes_[0]);
+
+ index_table_.reserve(count);
+ index_table_.emplace_back(Index{0, 1});
+ InitChildren(*shape_, init_value, &nodes_[0], &index_table_[0]);
}
template <typename T>
ShapeTree<T>::ShapeTree(const std::shared_ptr<Shape>& shape,
const T& init_value)
: shape_storage_(shape), shape_(shape_storage_.get()) {
- nodes_.reserve(CountSubshapes(*shape_));
+ const int64 count = CountSubshapes(*shape_);
+ nodes_.reserve(count);
nodes_.emplace_back(ShapeIndex{}, init_value);
- InitChildren(*shape_, init_value, &nodes_[0]);
+
+ index_table_.reserve(count);
+ index_table_.emplace_back(Index{0, 1});
+ InitChildren(*shape_, init_value, &nodes_[0], &index_table_[0]);
}
template <typename T>
@@ -474,13 +551,16 @@ T* ShapeTree<T>::mutable_element(ShapeIndexView index) {
template <typename T>
internal::ShapeTreeNode<T>* ShapeTree<T>::Lookup(ShapeIndexView index) {
- Node* node = &nodes_[0];
+ Index* iter = &index_table_[0];
for (const int64 i : index) {
CHECK_GE(i, 0);
- CHECK_LT(i, node->children.size());
- node = &nodes_[node->children[i]];
+#ifndef NDEBUG
+ CHECK_LT(i, iter->children_count);
+#endif
+ iter = &index_table_[iter->children_start + i];
}
- return node;
+
+ return &nodes_[iter->index];
}
template <typename T>