aboutsummaryrefslogtreecommitdiffhomepage
path: root/tensorflow/compiler/xla/service/hlo_subcomputation_unification_test.cc
diff options
context:
space:
mode:
authorGravatar Justin Lebar <jlebar@google.com>2017-09-29 14:02:29 -0700
committerGravatar TensorFlower Gardener <gardener@tensorflow.org>2017-09-29 14:13:20 -0700
commit7ec44b7541faabe781bb9b6113534452cda7598c (patch)
treef35a3a706553bd2b044a3944c4f0368bd34fa736 /tensorflow/compiler/xla/service/hlo_subcomputation_unification_test.cc
parentb1f00fc15047967698618a8e9218fac6c2278414 (diff)
[XLA] Make HloModule::computations() return raw pointers.
Like HloComputation::instructions(), HloModule::computations() used to return a list of unique_ptrs. But this is an implementation detail that shouldn't be leaked into the public API. This patch also adds HloModule::MakeNonFusionComputations(), because many of the callers of computations() went on to filter out all the fusion computations. It would be possible to implement MakeNonFusionComputations() "in place" using a filtering iterator, but I don't think it's necessary -- we never have *that* many computations, and since many callers go on to copy the list of non-fusion computations, making it unconditionally a copy is simpler and avoids a footgun. PiperOrigin-RevId: 170529051
Diffstat (limited to 'tensorflow/compiler/xla/service/hlo_subcomputation_unification_test.cc')
-rw-r--r--tensorflow/compiler/xla/service/hlo_subcomputation_unification_test.cc16
1 files changed, 8 insertions, 8 deletions
diff --git a/tensorflow/compiler/xla/service/hlo_subcomputation_unification_test.cc b/tensorflow/compiler/xla/service/hlo_subcomputation_unification_test.cc
index 33b3634cfc..7b601f9a95 100644
--- a/tensorflow/compiler/xla/service/hlo_subcomputation_unification_test.cc
+++ b/tensorflow/compiler/xla/service/hlo_subcomputation_unification_test.cc
@@ -85,7 +85,7 @@ TEST_F(HloSubcomputationUnificationTest, UnifyIdentities) {
module->AddEntryComputation(builder.Build());
- EXPECT_EQ(3, module->computations().size());
+ EXPECT_EQ(3, module->computation_count());
EXPECT_NE(x->to_apply(), y->to_apply());
if (VLOG_IS_ON(1)) {
hlo_graph_dumper::DumpGraph(*module->entry_computation(),
@@ -98,7 +98,7 @@ TEST_F(HloSubcomputationUnificationTest, UnifyIdentities) {
"after unification",
module->config().debug_options());
}
- EXPECT_EQ(2, module->computations().size());
+ EXPECT_EQ(2, module->computation_count());
EXPECT_EQ(x->to_apply(), y->to_apply());
}
@@ -124,7 +124,7 @@ TEST_F(HloSubcomputationUnificationTest, UnifyAdditions) {
module->AddEntryComputation(builder.Build());
- EXPECT_EQ(3, module->computations().size());
+ EXPECT_EQ(3, module->computation_count());
EXPECT_NE(x->to_apply(), y->to_apply());
if (VLOG_IS_ON(1)) {
hlo_graph_dumper::DumpGraph(*module->entry_computation(),
@@ -137,7 +137,7 @@ TEST_F(HloSubcomputationUnificationTest, UnifyAdditions) {
"after unification",
module->config().debug_options());
}
- EXPECT_EQ(2, module->computations().size());
+ EXPECT_EQ(2, module->computation_count());
EXPECT_EQ(x->to_apply(), y->to_apply());
}
@@ -164,7 +164,7 @@ TEST_F(HloSubcomputationUnificationTest, DifferentParameterShapes) {
module->AddEntryComputation(builder.Build());
- EXPECT_EQ(3, module->computations().size());
+ EXPECT_EQ(3, module->computation_count());
EXPECT_NE(x->to_apply(), y->to_apply());
if (VLOG_IS_ON(1)) {
hlo_graph_dumper::DumpGraph(*module->entry_computation(),
@@ -177,7 +177,7 @@ TEST_F(HloSubcomputationUnificationTest, DifferentParameterShapes) {
"after unification",
module->config().debug_options());
}
- EXPECT_EQ(3, module->computations().size());
+ EXPECT_EQ(3, module->computation_count());
EXPECT_NE(x->to_apply(), y->to_apply());
}
@@ -201,8 +201,8 @@ TEST_F(HloSubcomputationUnificationTest, TwoIdenticalComputations) {
}
EXPECT_TRUE(HloSubcomputationUnification().Run(module.get()).ValueOrDie());
- EXPECT_EQ(1, module->computations().size());
- EXPECT_EQ(module->computations().front().get(), module->entry_computation());
+ EXPECT_EQ(1, module->computation_count());
+ EXPECT_EQ(*module->computations().begin(), module->entry_computation());
}
} // namespace xla