diff options
author | 2018-09-12 20:32:37 -0700 | |
---|---|---|
committer | 2018-09-12 20:36:47 -0700 | |
commit | f4d8442e13356ab645446c9f4a9b3b6cedddcd63 (patch) | |
tree | 0538dabff85c0cd8a64be7bc0a589482bd7a859c /tensorflow/compiler/xla/service/hlo_module_dce_test.cc | |
parent | f03e8e0b9b149f95003099937dd35a220e3dfc95 (diff) |
Do not DCE while bodies which have IO operations.
PiperOrigin-RevId: 212750173
Diffstat (limited to 'tensorflow/compiler/xla/service/hlo_module_dce_test.cc')
-rw-r--r-- | tensorflow/compiler/xla/service/hlo_module_dce_test.cc | 34 |
1 files changed, 34 insertions, 0 deletions
diff --git a/tensorflow/compiler/xla/service/hlo_module_dce_test.cc b/tensorflow/compiler/xla/service/hlo_module_dce_test.cc index 363862e490..d025edbb9c 100644 --- a/tensorflow/compiler/xla/service/hlo_module_dce_test.cc +++ b/tensorflow/compiler/xla/service/hlo_module_dce_test.cc @@ -367,5 +367,39 @@ TEST_F(HloModuleDceTest, TwoWhilesWithDeadTupleElementSwizzled) { "while.2", 1)); } +// Tests that a while whose body has outfeed operations is not DCE-ed. +TEST_F(HloModuleDceTest, WhileWithOutfeed) { + auto module = ParseHloString(R"( + HloModule OutfeedLoop + WhileBody { + loop_var.1 = (s32[]) parameter(0) + token = token[] after-all() + constant.2 = s32[] constant(2) + outfeed_tuple = (s32[]) outfeed(constant.2, token) + get-tuple-element.1 = s32[] get-tuple-element(loop_var.1), index=0 + constant.1 = s32[] constant(1) + add = s32[] add(get-tuple-element.1, constant.1) + ROOT tuple = (s32[]) tuple(add) + } + WhileCondition { + loop_var.2 = (s32[]) parameter(0) + get-tuple-element.3 = s32[] get-tuple-element(loop_var.2), index=0 + constant.2 = s32[] constant(10) + ROOT less-than = pred[] less-than(get-tuple-element.3, constant.2) + } + ENTRY SimpleLoop { + constant.3 = s32[] constant(0) + tuple.1 = (s32[]) tuple(constant.3) + ROOT while = (s32[]) while(tuple.1), condition=WhileCondition, + body=WhileBody + })") + .ValueOrDie(); + + HloModuleDCE dce; + EXPECT_FALSE(dce.Run(module.get()).ValueOrDie()); + EXPECT_FALSE(WhileBodyHasPassThroughTupleElement(module->entry_computation(), + "while", 0)); +} + } // namespace } // namespace xla |