aboutsummaryrefslogtreecommitdiffhomepage
path: root/tensorflow
diff options
context:
space:
mode:
author: Sanjoy Das <sanjoy@google.com> 2018-05-25 17:22:11 -0700
committer: TensorFlower Gardener <gardener@tensorflow.org> 2018-05-25 17:24:56 -0700
commit: 8fcc95ebf42ed8eea543ec2edf1a1ed1c62ca7e8 (patch)
tree: 66758f11ac719ff941050e4bb5182c1dfea2c4c0 /tensorflow
parent: 06717b77e05bd602d10fe40f4519dbb105fabd5c (diff)
Enable while loop constant sinking for GPU
To avoid keeping constants in while loop bodies after optimization (where they may cause extra copies), we run a late pass of LICM that has been asked to hoist constants whenever it can.
PiperOrigin-RevId: 198126497
Diffstat (limited to 'tensorflow')
-rw-r--r-- tensorflow/compiler/xla/service/BUILD | 1
-rw-r--r-- tensorflow/compiler/xla/service/gpu/BUILD | 2
-rw-r--r-- tensorflow/compiler/xla/service/gpu/gpu_compiler.cc | 12
-rw-r--r-- tensorflow/compiler/xla/service/while_loop_invariant_code_motion.cc | 27
-rw-r--r-- tensorflow/compiler/xla/service/while_loop_invariant_code_motion.h | 16
-rw-r--r-- tensorflow/compiler/xla/service/while_loop_invariant_code_motion_test.cc | 73
6 files changed, 127 insertions, 4 deletions
diff --git a/tensorflow/compiler/xla/service/BUILD b/tensorflow/compiler/xla/service/BUILD
index 749873e560..2976bdb9e9 100644
--- a/tensorflow/compiler/xla/service/BUILD
+++ b/tensorflow/compiler/xla/service/BUILD
@@ -2862,6 +2862,7 @@ tf_cc_test(
":while_loop_invariant_code_motion",
"//tensorflow/compiler/xla:test",
"//tensorflow/compiler/xla/tests:hlo_verified_test_base",
+ "//tensorflow/compiler/xla/tools/parser:hlo_parser",
"//tensorflow/core:test",
],
)
diff --git a/tensorflow/compiler/xla/service/gpu/BUILD b/tensorflow/compiler/xla/service/gpu/BUILD
index ffb1af2d87..2794930248 100644
--- a/tensorflow/compiler/xla/service/gpu/BUILD
+++ b/tensorflow/compiler/xla/service/gpu/BUILD
@@ -546,6 +546,8 @@ cc_library(
"//tensorflow/compiler/xla/service:reshape_mover",
"//tensorflow/compiler/xla/service:transpose_folding",
"//tensorflow/compiler/xla/service:tuple_simplifier",
+ "//tensorflow/compiler/xla/service:while_loop_constant_sinking",
+ "//tensorflow/compiler/xla/service:while_loop_invariant_code_motion",
"//tensorflow/compiler/xla/service:while_loop_simplifier",
"//tensorflow/compiler/xla/service:zero_sized_hlo_elimination",
"//tensorflow/compiler/xla/service/gpu:cudnn_batchnorm_rewriter",
diff --git a/tensorflow/compiler/xla/service/gpu/gpu_compiler.cc b/tensorflow/compiler/xla/service/gpu/gpu_compiler.cc
index 5ef422c90b..b857219807 100644
--- a/tensorflow/compiler/xla/service/gpu/gpu_compiler.cc
+++ b/tensorflow/compiler/xla/service/gpu/gpu_compiler.cc
@@ -73,6 +73,8 @@ limitations under the License.
#include "tensorflow/compiler/xla/service/reshape_mover.h"
#include "tensorflow/compiler/xla/service/transpose_folding.h"
#include "tensorflow/compiler/xla/service/tuple_simplifier.h"
+#include "tensorflow/compiler/xla/service/while_loop_constant_sinking.h"
+#include "tensorflow/compiler/xla/service/while_loop_invariant_code_motion.h"
#include "tensorflow/compiler/xla/service/while_loop_simplifier.h"
#include "tensorflow/compiler/xla/service/zero_sized_hlo_elimination.h"
#include "tensorflow/compiler/xla/status_macros.h"
@@ -176,6 +178,7 @@ Status OptimizeHloModule(HloModule* hlo_module, se::StreamExecutor* stream_exec,
/*is_layout_sensitive=*/false,
[](const Shape&, const Shape&) { return false; });
pass.AddPass<TupleSimplifier>();
+ pass.AddPass<WhileLoopConstantSinking>();
pass.AddPass<WhileLoopSimplifier>();
pass.AddPass<HloDCE>();
pass.AddPass<ReshapeMover>();
@@ -274,6 +277,15 @@ Status OptimizeHloModule(HloModule* hlo_module, se::StreamExecutor* stream_exec,
TF_RETURN_IF_ERROR(fusion.Run(hlo_module).status());
}
}
+
+ {
+ // Do an aggressive LICM pass over while loops. In particular, this hoists
+ // constants that were sunk by WhileLoopConstantSinking. Leaving them in
+ // the while loop may result in unnecessary copies.
+ HloPassPipeline pipeline("while-loop-licm");
+ pipeline.AddPass<WhileLoopInvariantCodeMotion>(true);
+ TF_RETURN_IF_ERROR(pipeline.Run(hlo_module).status());
+ }
return Status::OK();
}
diff --git a/tensorflow/compiler/xla/service/while_loop_invariant_code_motion.cc b/tensorflow/compiler/xla/service/while_loop_invariant_code_motion.cc
index 321fdeb1ea..09ddcffb22 100644
--- a/tensorflow/compiler/xla/service/while_loop_invariant_code_motion.cc
+++ b/tensorflow/compiler/xla/service/while_loop_invariant_code_motion.cc
@@ -98,14 +98,17 @@ static void CreateLoopInvariantCopy(
// Returns true if `instruction` is worth hoisting only if it lets us hoist some
// instruction using it. The rationale is that hoisting these instructions will
// prevent simplification and fusion in the while body.
-static bool NotWorthHoistingIndividually(const HloInstruction& instruction) {
+bool WhileLoopInvariantCodeMotion::NotWorthHoistingIndividually(
+ const HloInstruction& instruction) {
switch (instruction.opcode()) {
default:
return false;
+ case HloOpcode::kConstant:
+ return !hoist_constants_;
+
case HloOpcode::kBitcast:
case HloOpcode::kBroadcast:
- case HloOpcode::kConstant:
case HloOpcode::kReshape:
case HloOpcode::kReverse:
case HloOpcode::kSlice:
@@ -115,7 +118,8 @@ static bool NotWorthHoistingIndividually(const HloInstruction& instruction) {
}
}
-static StatusOr<bool> TryHoistingInvariantInstructionsFromWhileBody(
+StatusOr<bool>
+WhileLoopInvariantCodeMotion::TryHoistingInvariantInstructionsFromWhileBody(
HloInstruction* while_instr) {
auto print_no_metadata = HloPrintOptions{}.set_print_metadata(false);
@@ -161,12 +165,16 @@ static StatusOr<bool> TryHoistingInvariantInstructionsFromWhileBody(
}
}
- if (unhoisted_invariant_instructions.empty()) {
+ if (unhoisted_invariant_instructions.empty() && !hoist_constants_) {
// There are no obviously loop invariant elements in the state being
// threaded through the while loop so give up. In theory this precondition
// is too strong -- we could have code that e.g. permutes the elements in
// the while state but uses a select to pick the same value on every
// iteration.
+ //
+ // If we were asked to hoist constants, we need to scan the while body for
+ // constants even if we didn't find any loop invariant values in the while
+ // state tuple.
return false;
}
@@ -243,6 +251,9 @@ static StatusOr<bool> TryHoistingInvariantInstructionsFromWhileBody(
}
StatusOr<bool> WhileLoopInvariantCodeMotion::Run(HloModule* module) {
+ VLOG(2) << "HLO module before WhileLoopConstantSinking:";
+ XLA_VLOG_LINES(2, module->ToString());
+
bool changed = false;
std::vector<HloInstruction*> while_instrs;
for (auto* comp : module->computations()) {
@@ -270,6 +281,14 @@ StatusOr<bool> WhileLoopInvariantCodeMotion::Run(HloModule* module) {
TryHoistingInvariantInstructionsFromWhileBody(while_instr));
changed |= result;
}
+
+ if (changed) {
+ VLOG(2) << "HLO module after WhileLoopConstantSinking:";
+ XLA_VLOG_LINES(2, module->ToString());
+ } else {
+ VLOG(2) << "HLO module unchanged after WhileLoopConstantSinking";
+ }
+
return changed;
}
} // namespace xla
diff --git a/tensorflow/compiler/xla/service/while_loop_invariant_code_motion.h b/tensorflow/compiler/xla/service/while_loop_invariant_code_motion.h
index 8c4b765b00..8e6cc87875 100644
--- a/tensorflow/compiler/xla/service/while_loop_invariant_code_motion.h
+++ b/tensorflow/compiler/xla/service/while_loop_invariant_code_motion.h
@@ -27,12 +27,28 @@ namespace xla {
class WhileLoopInvariantCodeMotion : public HloPassInterface {
public:
+ // If `hoist_constants` is true then constants are always hoisted out of while
+ // loop bodies. Otherwise they are only hoisted out if they enable other
+ // non-trivial computations to be hoisted out.
+ //
+ // Setting `hoist_constants` to false can be helpful if LICM is run in the
+ // mid-level HLO pipeline, because hoisting constants out of while loop bodies
+ // can break optimizations like constant folding.
+ explicit WhileLoopInvariantCodeMotion(bool hoist_constants = false)
+ : hoist_constants_(hoist_constants) {}
~WhileLoopInvariantCodeMotion() override = default;
tensorflow::StringPiece name() const override {
return "while-loop-invariant-code-motion";
}
StatusOr<bool> Run(HloModule* module) override;
+
+ private:
+ bool NotWorthHoistingIndividually(const HloInstruction& instruction);
+ StatusOr<bool> TryHoistingInvariantInstructionsFromWhileBody(
+ HloInstruction* while_instr);
+
+ bool hoist_constants_;
};
} // namespace xla
diff --git a/tensorflow/compiler/xla/service/while_loop_invariant_code_motion_test.cc b/tensorflow/compiler/xla/service/while_loop_invariant_code_motion_test.cc
index 799340fda9..e1ec12192f 100644
--- a/tensorflow/compiler/xla/service/while_loop_invariant_code_motion_test.cc
+++ b/tensorflow/compiler/xla/service/while_loop_invariant_code_motion_test.cc
@@ -18,6 +18,7 @@ limitations under the License.
#include "tensorflow/compiler/xla/service/hlo_matchers.h"
#include "tensorflow/compiler/xla/test.h"
#include "tensorflow/compiler/xla/tests/hlo_verified_test_base.h"
+#include "tensorflow/compiler/xla/tools/parser/hlo_parser.h"
#include "tensorflow/core/lib/core/status_test_util.h"
namespace xla {
@@ -438,5 +439,77 @@ TEST_F(WhileLoopInvariantCodeMotionTest, BodyHasNonTupleRoot) {
EXPECT_FALSE(simplified_loop);
}
+const char* const kConstantHoistingTestCase = R"(
+HloModule ModuleWithWhile
+
+body {
+ p_body = (f32[2]{0}) parameter(0)
+ p_body.1 = f32[2]{0} get-tuple-element(p_body), index=0
+ const = f32[2]{0} constant({3, 4})
+ add.0 = f32[2]{0} add(p_body.1, const)
+ ROOT root = (f32[2]{0}) tuple(add.0)
+}
+
+condition {
+ p_cond = (f32[2]{0}) parameter(0)
+ ROOT result = pred[] constant(true)
+}
+
+ENTRY entry {
+ const_0 = f32[2]{0} constant({1, 2})
+ while_init = (f32[2]{0}) tuple(const_0)
+ ROOT while = (f32[2]{0}) while(while_init), condition=condition, body=body
+}
+)";
+
+TEST_F(WhileLoopInvariantCodeMotionTest, HoistsConstantWhenAsked) {
+ ParseAndVerifyModule(kConstantHoistingTestCase);
+
+ TF_ASSERT_OK_AND_ASSIGN(
+ bool simplified_loop,
+ WhileLoopInvariantCodeMotion{/*hoist_constants=*/true}.Run(&module()));
+ EXPECT_TRUE(simplified_loop);
+
+ HloComputation* while_body = module().GetComputationWithName("wide.body");
+ ASSERT_NE(while_body, nullptr);
+
+ // We expect the while body to be the equivalent of:
+ //
+ // wide.body {
+ // wide_param.1 = (f32[2]{0}, f32[2]{0}) parameter(0)
+ // get-tuple-element.1 = f32[2]{0} get-tuple-element(wide_param.1), index=0
+ // tuple.1 = (f32[2]{0}) tuple(get-tuple-element.1)
+ // get-tuple-element.4 = f32[2]{0} get-tuple-element(tuple.1), index=0
+ // get-tuple-element.7 = f32[2]{0} get-tuple-element(wide_param.1), index=1
+ // add.1 = f32[2]{0} add(get-tuple-element.4, get-tuple-element.7)
+ // tuple.3 = (f32[2]{0}) tuple(add.1)
+ // get-tuple-element.8 = f32[2]{0} get-tuple-element(tuple.3), index=0
+ // get-tuple-element.9 = f32[2]{0} get-tuple-element(wide_param.1), index=1
+ // ROOT tuple.4 = (f32[2]{0}, f32[2]{0}) tuple(get-tuple-element.8,
+ // get-tuple-element.9)
+ // }
+
+ auto wide_param_1 = op::Parameter(0);
+ auto get_tuple_element_1 = op::GetTupleElement(wide_param_1, 0);
+ auto tuple_1 = op::Tuple(get_tuple_element_1);
+ auto get_tuple_element_4 = op::GetTupleElement(tuple_1, 0);
+ auto get_tuple_element_7 = op::GetTupleElement(wide_param_1, 1);
+ auto add_1 = op::Add(get_tuple_element_4, get_tuple_element_7);
+ auto tuple_3 = op::Tuple(add_1);
+ auto get_tuple_element_8 = op::GetTupleElement(tuple_3, 0);
+ auto get_tuple_element_9 = op::GetTupleElement(wide_param_1, 1);
+ auto tuple_4 = op::Tuple(get_tuple_element_8, get_tuple_element_9);
+
+ EXPECT_THAT(while_body->root_instruction(), tuple_4);
+}
+
+TEST_F(WhileLoopInvariantCodeMotionTest, DoesNotHoistConstantByDefault) {
+ ParseAndVerifyModule(kConstantHoistingTestCase);
+
+ TF_ASSERT_OK_AND_ASSIGN(bool simplified_loop,
+ WhileLoopInvariantCodeMotion{}.Run(&module()));
+ EXPECT_FALSE(simplified_loop);
+}
+
} // namespace
} // namespace xla