aboutsummaryrefslogtreecommitdiffhomepage
path: root/tensorflow/compiler/xla/shape_tree_test.cc
diff options
context:
space:
mode:
authorGravatar Justin Lebar <jlebar@google.com>2017-06-28 21:50:44 -0700
committerGravatar TensorFlower Gardener <gardener@tensorflow.org>2017-06-28 21:54:41 -0700
commit181816fe27684585bface6e2260a0ff1c890e3e9 (patch)
tree36dfeab13e57a6a5f37f34afabae7f8aafe37108 /tensorflow/compiler/xla/shape_tree_test.cc
parente6a45475735ee8a31c7d6c8e28e9164cda7d1853 (diff)
Speed up TuplePointsToAnalysis.
This analysis is one of the most expensive parts of the HLO optimization pipeline. - Avoid one or two unnecessary hashtable lookups in PopulateDefinedBuffersAndAliases. - Add a mode to ShapeTree wherein we avoid copying Shapes. - Use templated functors rather than std::function in ShapeTree's iterators, thus avoiding the overhead of std::function. PiperOrigin-RevId: 160487485
Diffstat (limited to 'tensorflow/compiler/xla/shape_tree_test.cc')
-rw-r--r--tensorflow/compiler/xla/shape_tree_test.cc26
1 files changed, 26 insertions, 0 deletions
diff --git a/tensorflow/compiler/xla/shape_tree_test.cc b/tensorflow/compiler/xla/shape_tree_test.cc
index afc3a2b2a3..3a5db1b3a6 100644
--- a/tensorflow/compiler/xla/shape_tree_test.cc
+++ b/tensorflow/compiler/xla/shape_tree_test.cc
@@ -365,5 +365,31 @@ TEST_F(ShapeTreeTest, OperatorEquals) {
}
}
+TEST_F(ShapeTreeTest, ConstructWithPointerToShape) {
+ // Construct a ShapeTree using a pointer to a shape, rather than a reference
+ // to a shape. This constructor is an optimization to let us avoid
+ // constructing and destroying temporary shapes when we have many ShapeTrees.
+ ShapeTree<int> t(&nested_tuple_shape_, 42);
+ int num_nodes = 0;
+ t.ForEachElement([&num_nodes](const ShapeIndex& /*index*/, int data) {
+ EXPECT_EQ(42, data);
+ ++num_nodes;
+ });
+ EXPECT_EQ(10, num_nodes);
+}
+
+TEST_F(ShapeTreeTest, CopyWithPointerToShape) {
+ ShapeTree<int> source(&nested_tuple_shape_, 0);
+ ShapeTree<int> dest(source);
+ EXPECT_EQ(&dest.shape(), &nested_tuple_shape_);
+}
+
+TEST_F(ShapeTreeTest, CopyAssignWithPointerToShape) {
+ ShapeTree<int> source(&nested_tuple_shape_, 0);
+ ShapeTree<int> dest;
+ dest = source;
+ EXPECT_EQ(&dest.shape(), &nested_tuple_shape_);
+}
+
} // namespace
} // namespace xla