aboutsummaryrefslogtreecommitdiffhomepage
path: root/tensorflow/compiler/xla/service/hlo_module_dce_test.cc
diff options
context:
space:
mode:
authorGravatar A. Unique TensorFlower <gardener@tensorflow.org>2018-09-12 20:32:37 -0700
committerGravatar TensorFlower Gardener <gardener@tensorflow.org>2018-09-12 20:36:47 -0700
commitf4d8442e13356ab645446c9f4a9b3b6cedddcd63 (patch)
tree0538dabff85c0cd8a64be7bc0a589482bd7a859c /tensorflow/compiler/xla/service/hlo_module_dce_test.cc
parentf03e8e0b9b149f95003099937dd35a220e3dfc95 (diff)
Do not DCE while bodies which have IO operations.
PiperOrigin-RevId: 212750173
Diffstat (limited to 'tensorflow/compiler/xla/service/hlo_module_dce_test.cc')
-rw-r--r--tensorflow/compiler/xla/service/hlo_module_dce_test.cc34
1 files changed, 34 insertions, 0 deletions
diff --git a/tensorflow/compiler/xla/service/hlo_module_dce_test.cc b/tensorflow/compiler/xla/service/hlo_module_dce_test.cc
index 363862e490..d025edbb9c 100644
--- a/tensorflow/compiler/xla/service/hlo_module_dce_test.cc
+++ b/tensorflow/compiler/xla/service/hlo_module_dce_test.cc
@@ -367,5 +367,39 @@ TEST_F(HloModuleDceTest, TwoWhilesWithDeadTupleElementSwizzled) {
"while.2", 1));
}
+// Tests that a while whose body has outfeed operations is not DCE-ed.
+TEST_F(HloModuleDceTest, WhileWithOutfeed) {
+ auto module = ParseHloString(R"(
+ HloModule OutfeedLoop
+ WhileBody {
+ loop_var.1 = (s32[]) parameter(0)
+ token = token[] after-all()
+ constant.2 = s32[] constant(2)
+ outfeed_tuple = (s32[]) outfeed(constant.2, token)
+ get-tuple-element.1 = s32[] get-tuple-element(loop_var.1), index=0
+ constant.1 = s32[] constant(1)
+ add = s32[] add(get-tuple-element.1, constant.1)
+ ROOT tuple = (s32[]) tuple(add)
+ }
+ WhileCondition {
+ loop_var.2 = (s32[]) parameter(0)
+ get-tuple-element.3 = s32[] get-tuple-element(loop_var.2), index=0
+ constant.2 = s32[] constant(10)
+ ROOT less-than = pred[] less-than(get-tuple-element.3, constant.2)
+ }
+ ENTRY SimpleLoop {
+ constant.3 = s32[] constant(0)
+ tuple.1 = (s32[]) tuple(constant.3)
+ ROOT while = (s32[]) while(tuple.1), condition=WhileCondition,
+ body=WhileBody
+ })")
+ .ValueOrDie();
+
+ HloModuleDCE dce;
+ EXPECT_FALSE(dce.Run(module.get()).ValueOrDie());
+ EXPECT_FALSE(WhileBodyHasPassThroughTupleElement(module->entry_computation(),
+ "while", 0));
+}
+
} // namespace
} // namespace xla