aboutsummaryrefslogtreecommitdiffhomepage
path: root/tensorflow/compiler/xla/service/copy_insertion_test.cc
diff options
context:
space:
mode:
Diffstat (limited to 'tensorflow/compiler/xla/service/copy_insertion_test.cc')
-rw-r--r--tensorflow/compiler/xla/service/copy_insertion_test.cc41
1 files changed, 41 insertions, 0 deletions
diff --git a/tensorflow/compiler/xla/service/copy_insertion_test.cc b/tensorflow/compiler/xla/service/copy_insertion_test.cc
index cd735256b8..892d0d7b54 100644
--- a/tensorflow/compiler/xla/service/copy_insertion_test.cc
+++ b/tensorflow/compiler/xla/service/copy_insertion_test.cc
@@ -2007,5 +2007,46 @@ ENTRY TestComputation {
InsertCopies(module.get());
}
+TEST_F(CopyInsertionTest, NestedWhiles) {
+ // Verify that only no unnecessary copies remain after copy insertion for
+ // trivial nested whiles (b/112472605).
+ const string& hlo_string = R"(
+HloModule TestModule
+
+cond.inner {
+ ROOT param.cond.inner = pred[] parameter(0)
+}
+
+body.inner {
+ param.body.inner = pred[] parameter(0)
+ ROOT neg = pred[] negate(param.body.inner)
+}
+
+cond.outer {
+ ROOT param.cond.outer = pred[] parameter(0)
+}
+
+body.outer {
+ param.cond.outer = pred[] parameter(0)
+ ROOT while = pred[] while(param.cond.outer), condition=cond.inner, body=body.inner
+}
+
+ENTRY TestComputation {
+ entry_param = pred[] parameter(0)
+ ROOT while = pred[] while(entry_param), condition=cond.outer, body=body.outer
+}
+)";
+ TF_ASSERT_OK_AND_ASSIGN(
+ std::unique_ptr<HloModule> module,
+ HloRunner::CreateModuleFromString(hlo_string, GetDebugOptionsForTest()));
+ InsertCopies(module.get());
+
+ // There should only be a single copy inserted, and it's in the entry
+ // computation.
+ EXPECT_EQ(CountCopies(*module), 1);
+ EXPECT_THAT(module->entry_computation()->root_instruction(),
+ op::While(op::Copy(op::Parameter())));
+}
+
} // namespace
} // namespace xla