diff options
author | 2018-05-25 17:22:11 -0700 | |
---|---|---|
committer | 2018-05-25 17:24:56 -0700 | |
commit | 8fcc95ebf42ed8eea543ec2edf1a1ed1c62ca7e8 (patch) | |
tree | 66758f11ac719ff941050e4bb5182c1dfea2c4c0 /tensorflow | |
parent | 06717b77e05bd602d10fe40f4519dbb105fabd5c (diff) |
Enable while loop constant sinking for GPU
To avoid keeping constants in while loop bodies after optimization (where they
may cause extra copies) we run a late pass of LICM that has been asked to hoist
constants when it can.
PiperOrigin-RevId: 198126497
Diffstat (limited to 'tensorflow')
6 files changed, 127 insertions, 4 deletions
diff --git a/tensorflow/compiler/xla/service/BUILD b/tensorflow/compiler/xla/service/BUILD index 749873e560..2976bdb9e9 100644 --- a/tensorflow/compiler/xla/service/BUILD +++ b/tensorflow/compiler/xla/service/BUILD @@ -2862,6 +2862,7 @@ tf_cc_test( ":while_loop_invariant_code_motion", "//tensorflow/compiler/xla:test", "//tensorflow/compiler/xla/tests:hlo_verified_test_base", + "//tensorflow/compiler/xla/tools/parser:hlo_parser", "//tensorflow/core:test", ], ) diff --git a/tensorflow/compiler/xla/service/gpu/BUILD b/tensorflow/compiler/xla/service/gpu/BUILD index ffb1af2d87..2794930248 100644 --- a/tensorflow/compiler/xla/service/gpu/BUILD +++ b/tensorflow/compiler/xla/service/gpu/BUILD @@ -546,6 +546,8 @@ cc_library( "//tensorflow/compiler/xla/service:reshape_mover", "//tensorflow/compiler/xla/service:transpose_folding", "//tensorflow/compiler/xla/service:tuple_simplifier", + "//tensorflow/compiler/xla/service:while_loop_constant_sinking", + "//tensorflow/compiler/xla/service:while_loop_invariant_code_motion", "//tensorflow/compiler/xla/service:while_loop_simplifier", "//tensorflow/compiler/xla/service:zero_sized_hlo_elimination", "//tensorflow/compiler/xla/service/gpu:cudnn_batchnorm_rewriter", diff --git a/tensorflow/compiler/xla/service/gpu/gpu_compiler.cc b/tensorflow/compiler/xla/service/gpu/gpu_compiler.cc index 5ef422c90b..b857219807 100644 --- a/tensorflow/compiler/xla/service/gpu/gpu_compiler.cc +++ b/tensorflow/compiler/xla/service/gpu/gpu_compiler.cc @@ -73,6 +73,8 @@ limitations under the License. #include "tensorflow/compiler/xla/service/reshape_mover.h" #include "tensorflow/compiler/xla/service/transpose_folding.h" #include "tensorflow/compiler/xla/service/tuple_simplifier.h" +#include "tensorflow/compiler/xla/service/while_loop_constant_sinking.h" +#include "tensorflow/compiler/xla/service/while_loop_invariant_code_motion.h" #include "tensorflow/compiler/xla/service/while_loop_simplifier.h" #include "tensorflow/compiler/xla/service/zero_sized_hlo_elimination.h" #include "tensorflow/compiler/xla/status_macros.h" @@ -176,6 +178,7 @@ Status OptimizeHloModule(HloModule* hlo_module, se::StreamExecutor* stream_exec, /*is_layout_sensitive=*/false, [](const Shape&, const Shape&) { return false; }); pass.AddPass<TupleSimplifier>(); + pass.AddPass<WhileLoopConstantSinking>(); pass.AddPass<WhileLoopSimplifier>(); pass.AddPass<HloDCE>(); pass.AddPass<ReshapeMover>(); @@ -274,6 +277,15 @@ Status OptimizeHloModule(HloModule* hlo_module, se::StreamExecutor* stream_exec, TF_RETURN_IF_ERROR(fusion.Run(hlo_module).status()); } } + + { + // Do an aggressive LICM pass over while loops. In particular, this hoists + // constants that were sunk by WhileLoopConstantSinking. Leaving them in + // the while loop may result in unnecessary copies. + HloPassPipeline pipeline("while-loop-licm"); + pipeline.AddPass<WhileLoopInvariantCodeMotion>(true); + TF_RETURN_IF_ERROR(pipeline.Run(hlo_module).status()); + } return Status::OK(); } diff --git a/tensorflow/compiler/xla/service/while_loop_invariant_code_motion.cc b/tensorflow/compiler/xla/service/while_loop_invariant_code_motion.cc index 321fdeb1ea..09ddcffb22 100644 --- a/tensorflow/compiler/xla/service/while_loop_invariant_code_motion.cc +++ b/tensorflow/compiler/xla/service/while_loop_invariant_code_motion.cc @@ -98,14 +98,17 @@ static void CreateLoopInvariantCopy( // Returns true if `instruction` is worth hoisting only if it lets us hoist some // instruction using it. The rationale is that hoisting these instructions will // prevent simplification and fusion in the while body. -static bool NotWorthHoistingIndividually(const HloInstruction& instruction) { +bool WhileLoopInvariantCodeMotion::NotWorthHoistingIndividually( + const HloInstruction& instruction) { switch (instruction.opcode()) { default: return false; + case HloOpcode::kConstant: + return !hoist_constants_; + case HloOpcode::kBitcast: case HloOpcode::kBroadcast: - case HloOpcode::kConstant: case HloOpcode::kReshape: case HloOpcode::kReverse: case HloOpcode::kSlice: @@ -115,7 +118,8 @@ static bool NotWorthHoistingIndividually(const HloInstruction& instruction) { } } -static StatusOr<bool> TryHoistingInvariantInstructionsFromWhileBody( +StatusOr<bool> +WhileLoopInvariantCodeMotion::TryHoistingInvariantInstructionsFromWhileBody( HloInstruction* while_instr) { auto print_no_metadata = HloPrintOptions{}.set_print_metadata(false); @@ -161,12 +165,16 @@ static StatusOr<bool> TryHoistingInvariantInstructionsFromWhileBody( } } - if (unhoisted_invariant_instructions.empty()) { + if (unhoisted_invariant_instructions.empty() && !hoist_constants_) { // There are no obviously loop invariant elements in the state being // threaded through the while loop so give up. In theory this precondition // is too strong -- we could have code that e.g. permutes the elements in // the while state but uses a select to pick the same value on every // iteration. + // + // If we were asked to hoist constants, we need to scan the while body for + // constants even if we didn't find any loop invariant values in the while + // state tuple. return false; } @@ -243,6 +251,9 @@ static StatusOr<bool> TryHoistingInvariantInstructionsFromWhileBody( } StatusOr<bool> WhileLoopInvariantCodeMotion::Run(HloModule* module) { + VLOG(2) << "HLO module before WhileLoopConstantSinking:"; + XLA_VLOG_LINES(2, module->ToString()); + bool changed = false; std::vector<HloInstruction*> while_instrs; for (auto* comp : module->computations()) { @@ -270,6 +281,14 @@ StatusOr<bool> WhileLoopInvariantCodeMotion::Run(HloModule* module) { TryHoistingInvariantInstructionsFromWhileBody(while_instr)); changed |= result; } + + if (changed) { + VLOG(2) << "HLO module after WhileLoopConstantSinking:"; + XLA_VLOG_LINES(2, module->ToString()); + } else { + VLOG(2) << "HLO module unchanged after WhileLoopConstantSinking"; + } + return changed; } } // namespace xla diff --git a/tensorflow/compiler/xla/service/while_loop_invariant_code_motion.h b/tensorflow/compiler/xla/service/while_loop_invariant_code_motion.h index 8c4b765b00..8e6cc87875 100644 --- a/tensorflow/compiler/xla/service/while_loop_invariant_code_motion.h +++ b/tensorflow/compiler/xla/service/while_loop_invariant_code_motion.h @@ -27,12 +27,28 @@ namespace xla { class WhileLoopInvariantCodeMotion : public HloPassInterface { public: + // If `hoist_constants` is true then constants are always hoisted out of while + // loop bodies. Otherwise they are only hoisted out if they enable other + // non-trivial computations to be hoisted out. + // + // Setting `hoist_constants` to false can be help if LICM is run in the mid + // level HLO pipeline because hoisting constants out of while loop bodies can + // break optimizations like constant folding. + explicit WhileLoopInvariantCodeMotion(bool hoist_constants = false) + : hoist_constants_(hoist_constants) {} ~WhileLoopInvariantCodeMotion() override = default; tensorflow::StringPiece name() const override { return "while-loop-invariant-code-motion"; } StatusOr<bool> Run(HloModule* module) override; + + private: + bool NotWorthHoistingIndividually(const HloInstruction& instruction); + StatusOr<bool> TryHoistingInvariantInstructionsFromWhileBody( + HloInstruction* while_instr); + + bool hoist_constants_; }; } // namespace xla diff --git a/tensorflow/compiler/xla/service/while_loop_invariant_code_motion_test.cc b/tensorflow/compiler/xla/service/while_loop_invariant_code_motion_test.cc index 799340fda9..e1ec12192f 100644 --- a/tensorflow/compiler/xla/service/while_loop_invariant_code_motion_test.cc +++ b/tensorflow/compiler/xla/service/while_loop_invariant_code_motion_test.cc @@ -18,6 +18,7 @@ limitations under the License. #include "tensorflow/compiler/xla/service/hlo_matchers.h" #include "tensorflow/compiler/xla/test.h" #include "tensorflow/compiler/xla/tests/hlo_verified_test_base.h" +#include "tensorflow/compiler/xla/tools/parser/hlo_parser.h" #include "tensorflow/core/lib/core/status_test_util.h" namespace xla { @@ -438,5 +439,77 @@ TEST_F(WhileLoopInvariantCodeMotionTest, BodyHasNonTupleRoot) { EXPECT_FALSE(simplified_loop); } +const char* const kConstantHoistingTestCase = R"( +HloModule ModuleWithWhile + +body { + p_body = (f32[2]{0}) parameter(0) + p_body.1 = f32[2]{0} get-tuple-element(p_body), index=0 + const = f32[2]{0} constant({3, 4}) + add.0 = f32[2]{0} add(p_body.1, const) + ROOT root = (f32[2]{0}) tuple(add.0) +} + +condition { + p_cond = (f32[2]{0}) parameter(0) + ROOT result = pred[] constant(true) +} + +ENTRY entry { + const_0 = f32[2]{0} constant({1, 2}) + while_init = (f32[2]{0}) tuple(const_0) + ROOT while = (f32[2]{0}) while(while_init), condition=condition, body=body +} +)"; + +TEST_F(WhileLoopInvariantCodeMotionTest, HoistsConstantWhenAsked) { + ParseAndVerifyModule(kConstantHoistingTestCase); + + TF_ASSERT_OK_AND_ASSIGN( + bool simplified_loop, + WhileLoopInvariantCodeMotion{/*hoist_constants=*/true}.Run(&module())); + EXPECT_TRUE(simplified_loop); + + HloComputation* while_body = module().GetComputationWithName("wide.body"); + ASSERT_NE(while_body, nullptr); + + // We expect the while body to be the equivalent of: + // + // wide.body { + // wide_param.1 = (f32[2]{0}, f32[2]{0}) parameter(0) + // get-tuple-element.1 = f32[2]{0} get-tuple-element(wide_param.1), index=0 + // tuple.1 = (f32[2]{0}) tuple(get-tuple-element.1) + // get-tuple-element.4 = f32[2]{0} get-tuple-element(tuple.1), index=0 + // get-tuple-element.7 = f32[2]{0} get-tuple-element(wide_param.1), index=1 + // add.1 = f32[2]{0} add(get-tuple-element.4, get-tuple-element.7) + // tuple.3 = (f32[2]{0}) tuple(add.1) + // get-tuple-element.8 = f32[2]{0} get-tuple-element(tuple.3), index=0 + // get-tuple-element.9 = f32[2]{0} get-tuple-element(wide_param.1), index=1 + // ROOT tuple.4 = (f32[2]{0}, f32[2]{0}) tuple(get-tuple-element.8, + // get-tuple-element.9) + // } + + auto wide_param_1 = op::Parameter(0); + auto get_tuple_element_1 = op::GetTupleElement(wide_param_1, 0); + auto tuple_1 = op::Tuple(get_tuple_element_1); + auto get_tuple_element_4 = op::GetTupleElement(tuple_1, 0); + auto get_tuple_element_7 = op::GetTupleElement(wide_param_1, 1); + auto add_1 = op::Add(get_tuple_element_4, get_tuple_element_7); + auto tuple_3 = op::Tuple(add_1); + auto get_tuple_element_8 = op::GetTupleElement(tuple_3, 0); + auto get_tuple_element_9 = op::GetTupleElement(wide_param_1, 1); + auto tuple_4 = op::Tuple(get_tuple_element_8, get_tuple_element_9); + + EXPECT_THAT(while_body->root_instruction(), tuple_4); +} + +TEST_F(WhileLoopInvariantCodeMotionTest, DoesNotHoistConstantByDefault) { + ParseAndVerifyModule(kConstantHoistingTestCase); + + TF_ASSERT_OK_AND_ASSIGN(bool simplified_loop, + WhileLoopInvariantCodeMotion{}.Run(&module())); + EXPECT_FALSE(simplified_loop); +} + } // namespace } // namespace xla |