aboutsummaryrefslogtreecommitdiffhomepage
diff options
context:
space:
mode:
authorGravatar Sanjoy Das <sanjoy@google.com>2018-05-25 17:22:11 -0700
committerGravatar TensorFlower Gardener <gardener@tensorflow.org>2018-05-25 17:24:56 -0700
commit8fcc95ebf42ed8eea543ec2edf1a1ed1c62ca7e8 (patch)
tree66758f11ac719ff941050e4bb5182c1dfea2c4c0
parent06717b77e05bd602d10fe40f4519dbb105fabd5c (diff)
Enable while loop constant sinking for GPU
To avoid keeping constants in while loop bodies after optimization (where they may cause extra copies), we run a late pass of LICM (loop-invariant code motion) that has been asked to hoist constants whenever it can. PiperOrigin-RevId: 198126497
-rw-r--r--tensorflow/compiler/xla/service/BUILD1
-rw-r--r--tensorflow/compiler/xla/service/gpu/BUILD2
-rw-r--r--tensorflow/compiler/xla/service/gpu/gpu_compiler.cc12
-rw-r--r--tensorflow/compiler/xla/service/while_loop_invariant_code_motion.cc27
-rw-r--r--tensorflow/compiler/xla/service/while_loop_invariant_code_motion.h16
-rw-r--r--tensorflow/compiler/xla/service/while_loop_invariant_code_motion_test.cc73
6 files changed, 127 insertions, 4 deletions
diff --git a/tensorflow/compiler/xla/service/BUILD b/tensorflow/compiler/xla/service/BUILD
index 749873e560..2976bdb9e9 100644
--- a/tensorflow/compiler/xla/service/BUILD
+++ b/tensorflow/compiler/xla/service/BUILD
@@ -2862,6 +2862,7 @@ tf_cc_test(
":while_loop_invariant_code_motion",
"//tensorflow/compiler/xla:test",
"//tensorflow/compiler/xla/tests:hlo_verified_test_base",
+ "//tensorflow/compiler/xla/tools/parser:hlo_parser",
"//tensorflow/core:test",
],
)
diff --git a/tensorflow/compiler/xla/service/gpu/BUILD b/tensorflow/compiler/xla/service/gpu/BUILD
index ffb1af2d87..2794930248 100644
--- a/tensorflow/compiler/xla/service/gpu/BUILD
+++ b/tensorflow/compiler/xla/service/gpu/BUILD
@@ -546,6 +546,8 @@ cc_library(
"//tensorflow/compiler/xla/service:reshape_mover",
"//tensorflow/compiler/xla/service:transpose_folding",
"//tensorflow/compiler/xla/service:tuple_simplifier",
+ "//tensorflow/compiler/xla/service:while_loop_constant_sinking",
+ "//tensorflow/compiler/xla/service:while_loop_invariant_code_motion",
"//tensorflow/compiler/xla/service:while_loop_simplifier",
"//tensorflow/compiler/xla/service:zero_sized_hlo_elimination",
"//tensorflow/compiler/xla/service/gpu:cudnn_batchnorm_rewriter",
diff --git a/tensorflow/compiler/xla/service/gpu/gpu_compiler.cc b/tensorflow/compiler/xla/service/gpu/gpu_compiler.cc
index 5ef422c90b..b857219807 100644
--- a/tensorflow/compiler/xla/service/gpu/gpu_compiler.cc
+++ b/tensorflow/compiler/xla/service/gpu/gpu_compiler.cc
@@ -73,6 +73,8 @@ limitations under the License.
#include "tensorflow/compiler/xla/service/reshape_mover.h"
#include "tensorflow/compiler/xla/service/transpose_folding.h"
#include "tensorflow/compiler/xla/service/tuple_simplifier.h"
+#include "tensorflow/compiler/xla/service/while_loop_constant_sinking.h"
+#include "tensorflow/compiler/xla/service/while_loop_invariant_code_motion.h"
#include "tensorflow/compiler/xla/service/while_loop_simplifier.h"
#include "tensorflow/compiler/xla/service/zero_sized_hlo_elimination.h"
#include "tensorflow/compiler/xla/status_macros.h"
@@ -176,6 +178,7 @@ Status OptimizeHloModule(HloModule* hlo_module, se::StreamExecutor* stream_exec,
/*is_layout_sensitive=*/false,
[](const Shape&, const Shape&) { return false; });
pass.AddPass<TupleSimplifier>();
+ pass.AddPass<WhileLoopConstantSinking>();
pass.AddPass<WhileLoopSimplifier>();
pass.AddPass<HloDCE>();
pass.AddPass<ReshapeMover>();
@@ -274,6 +277,15 @@ Status OptimizeHloModule(HloModule* hlo_module, se::StreamExecutor* stream_exec,
TF_RETURN_IF_ERROR(fusion.Run(hlo_module).status());
}
}
+
+ {
+ // Do an aggressive LICM pass over while loops. In particular, this hoists
+ // constants that were sunk by WhileLoopConstantSinking. Leaving them in
+ // the while loop may result in unnecessary copies.
+ HloPassPipeline pipeline("while-loop-licm");
+ pipeline.AddPass<WhileLoopInvariantCodeMotion>(true);
+ TF_RETURN_IF_ERROR(pipeline.Run(hlo_module).status());
+ }
return Status::OK();
}
diff --git a/tensorflow/compiler/xla/service/while_loop_invariant_code_motion.cc b/tensorflow/compiler/xla/service/while_loop_invariant_code_motion.cc
index 321fdeb1ea..09ddcffb22 100644
--- a/tensorflow/compiler/xla/service/while_loop_invariant_code_motion.cc
+++ b/tensorflow/compiler/xla/service/while_loop_invariant_code_motion.cc
@@ -98,14 +98,17 @@ static void CreateLoopInvariantCopy(
// Returns true if `instruction` is worth hoisting only if it lets us hoist some
// instruction using it. The rationale is that hoisting these instructions will
// prevent simplification and fusion in the while body.
-static bool NotWorthHoistingIndividually(const HloInstruction& instruction) {
+bool WhileLoopInvariantCodeMotion::NotWorthHoistingIndividually(
+ const HloInstruction& instruction) {
switch (instruction.opcode()) {
default:
return false;
+ case HloOpcode::kConstant:
+ return !hoist_constants_;
+
case HloOpcode::kBitcast:
case HloOpcode::kBroadcast:
- case HloOpcode::kConstant:
case HloOpcode::kReshape:
case HloOpcode::kReverse:
case HloOpcode::kSlice:
@@ -115,7 +118,8 @@ static bool NotWorthHoistingIndividually(const HloInstruction& instruction) {
}
}
-static StatusOr<bool> TryHoistingInvariantInstructionsFromWhileBody(
+StatusOr<bool>
+WhileLoopInvariantCodeMotion::TryHoistingInvariantInstructionsFromWhileBody(
HloInstruction* while_instr) {
auto print_no_metadata = HloPrintOptions{}.set_print_metadata(false);
@@ -161,12 +165,16 @@ static StatusOr<bool> TryHoistingInvariantInstructionsFromWhileBody(
}
}
- if (unhoisted_invariant_instructions.empty()) {
+ if (unhoisted_invariant_instructions.empty() && !hoist_constants_) {
// There are no obviously loop invariant elements in the state being
// threaded through the while loop so give up. In theory this precondition
// is too strong -- we could have code that e.g. permutes the elements in
// the while state but uses a select to pick the same value on every
// iteration.
+ //
+ // If we were asked to hoist constants, we need to scan the while body for
+ // constants even if we didn't find any loop invariant values in the while
+ // state tuple.
return false;
}
@@ -243,6 +251,9 @@ static StatusOr<bool> TryHoistingInvariantInstructionsFromWhileBody(
}
StatusOr<bool> WhileLoopInvariantCodeMotion::Run(HloModule* module) {
+ VLOG(2) << "HLO module before WhileLoopInvariantCodeMotion:";
+ XLA_VLOG_LINES(2, module->ToString());
+
bool changed = false;
std::vector<HloInstruction*> while_instrs;
for (auto* comp : module->computations()) {
@@ -270,6 +281,14 @@ StatusOr<bool> WhileLoopInvariantCodeMotion::Run(HloModule* module) {
TryHoistingInvariantInstructionsFromWhileBody(while_instr));
changed |= result;
}
+
+ if (changed) {
+ VLOG(2) << "HLO module after WhileLoopInvariantCodeMotion:";
+ XLA_VLOG_LINES(2, module->ToString());
+ } else {
+ VLOG(2) << "HLO module unchanged after WhileLoopInvariantCodeMotion";
+ }
+
return changed;
}
} // namespace xla
diff --git a/tensorflow/compiler/xla/service/while_loop_invariant_code_motion.h b/tensorflow/compiler/xla/service/while_loop_invariant_code_motion.h
index 8c4b765b00..8e6cc87875 100644
--- a/tensorflow/compiler/xla/service/while_loop_invariant_code_motion.h
+++ b/tensorflow/compiler/xla/service/while_loop_invariant_code_motion.h
@@ -27,12 +27,28 @@ namespace xla {
class WhileLoopInvariantCodeMotion : public HloPassInterface {
public:
+ // If `hoist_constants` is true then constants are always hoisted out of while
+ // loop bodies. Otherwise they are only hoisted out if they enable other
+ // non-trivial computations to be hoisted out.
+ //
+ // Setting `hoist_constants` to false can help if LICM is run in the mid
+ // level HLO pipeline because hoisting constants out of while loop bodies can
+ // break optimizations like constant folding.
+ explicit WhileLoopInvariantCodeMotion(bool hoist_constants = false)
+ : hoist_constants_(hoist_constants) {}
~WhileLoopInvariantCodeMotion() override = default;
tensorflow::StringPiece name() const override {
return "while-loop-invariant-code-motion";
}
StatusOr<bool> Run(HloModule* module) override;
+
+ private:
+ bool NotWorthHoistingIndividually(const HloInstruction& instruction);
+ StatusOr<bool> TryHoistingInvariantInstructionsFromWhileBody(
+ HloInstruction* while_instr);
+
+ bool hoist_constants_;
};
} // namespace xla
diff --git a/tensorflow/compiler/xla/service/while_loop_invariant_code_motion_test.cc b/tensorflow/compiler/xla/service/while_loop_invariant_code_motion_test.cc
index 799340fda9..e1ec12192f 100644
--- a/tensorflow/compiler/xla/service/while_loop_invariant_code_motion_test.cc
+++ b/tensorflow/compiler/xla/service/while_loop_invariant_code_motion_test.cc
@@ -18,6 +18,7 @@ limitations under the License.
#include "tensorflow/compiler/xla/service/hlo_matchers.h"
#include "tensorflow/compiler/xla/test.h"
#include "tensorflow/compiler/xla/tests/hlo_verified_test_base.h"
+#include "tensorflow/compiler/xla/tools/parser/hlo_parser.h"
#include "tensorflow/core/lib/core/status_test_util.h"
namespace xla {
@@ -438,5 +439,77 @@ TEST_F(WhileLoopInvariantCodeMotionTest, BodyHasNonTupleRoot) {
EXPECT_FALSE(simplified_loop);
}
+const char* const kConstantHoistingTestCase = R"(
+HloModule ModuleWithWhile
+
+body {
+ p_body = (f32[2]{0}) parameter(0)
+ p_body.1 = f32[2]{0} get-tuple-element(p_body), index=0
+ const = f32[2]{0} constant({3, 4})
+ add.0 = f32[2]{0} add(p_body.1, const)
+ ROOT root = (f32[2]{0}) tuple(add.0)
+}
+
+condition {
+ p_cond = (f32[2]{0}) parameter(0)
+ ROOT result = pred[] constant(true)
+}
+
+ENTRY entry {
+ const_0 = f32[2]{0} constant({1, 2})
+ while_init = (f32[2]{0}) tuple(const_0)
+ ROOT while = (f32[2]{0}) while(while_init), condition=condition, body=body
+}
+)";
+
+TEST_F(WhileLoopInvariantCodeMotionTest, HoistsConstantWhenAsked) {
+ ParseAndVerifyModule(kConstantHoistingTestCase);
+
+ TF_ASSERT_OK_AND_ASSIGN(
+ bool simplified_loop,
+ WhileLoopInvariantCodeMotion{/*hoist_constants=*/true}.Run(&module()));
+ EXPECT_TRUE(simplified_loop);
+
+ HloComputation* while_body = module().GetComputationWithName("wide.body");
+ ASSERT_NE(while_body, nullptr);
+
+ // We expect the while body to be the equivalent of:
+ //
+ // wide.body {
+ // wide_param.1 = (f32[2]{0}, f32[2]{0}) parameter(0)
+ // get-tuple-element.1 = f32[2]{0} get-tuple-element(wide_param.1), index=0
+ // tuple.1 = (f32[2]{0}) tuple(get-tuple-element.1)
+ // get-tuple-element.4 = f32[2]{0} get-tuple-element(tuple.1), index=0
+ // get-tuple-element.7 = f32[2]{0} get-tuple-element(wide_param.1), index=1
+ // add.1 = f32[2]{0} add(get-tuple-element.4, get-tuple-element.7)
+ // tuple.3 = (f32[2]{0}) tuple(add.1)
+ // get-tuple-element.8 = f32[2]{0} get-tuple-element(tuple.3), index=0
+ // get-tuple-element.9 = f32[2]{0} get-tuple-element(wide_param.1), index=1
+ // ROOT tuple.4 = (f32[2]{0}, f32[2]{0}) tuple(get-tuple-element.8,
+ // get-tuple-element.9)
+ // }
+
+ auto wide_param_1 = op::Parameter(0);
+ auto get_tuple_element_1 = op::GetTupleElement(wide_param_1, 0);
+ auto tuple_1 = op::Tuple(get_tuple_element_1);
+ auto get_tuple_element_4 = op::GetTupleElement(tuple_1, 0);
+ auto get_tuple_element_7 = op::GetTupleElement(wide_param_1, 1);
+ auto add_1 = op::Add(get_tuple_element_4, get_tuple_element_7);
+ auto tuple_3 = op::Tuple(add_1);
+ auto get_tuple_element_8 = op::GetTupleElement(tuple_3, 0);
+ auto get_tuple_element_9 = op::GetTupleElement(wide_param_1, 1);
+ auto tuple_4 = op::Tuple(get_tuple_element_8, get_tuple_element_9);
+
+ EXPECT_THAT(while_body->root_instruction(), tuple_4);
+}
+
+TEST_F(WhileLoopInvariantCodeMotionTest, DoesNotHoistConstantByDefault) {
+ ParseAndVerifyModule(kConstantHoistingTestCase);
+
+ TF_ASSERT_OK_AND_ASSIGN(bool simplified_loop,
+ WhileLoopInvariantCodeMotion{}.Run(&module()));
+ EXPECT_FALSE(simplified_loop);
+}
+
} // namespace
} // namespace xla