diff options
author | Blake Hechtman <blakehechtman@google.com> | 2017-03-27 13:22:09 -0800 |
---|---|---|
committer | TensorFlower Gardener <gardener@tensorflow.org> | 2017-03-27 14:48:17 -0700 |
commit | 4eb07180f4a0fce2cb265010233bace014e2c026 (patch) | |
tree | d7513bbfcbcb062d6ffa086b9739f2be5292f047 | |
parent | 5c65a93856028ea9717c769bbc96338e1ea05c76 (diff) |
[XLA:CPU] Disable Dot strength reduction due to performance regressions. This
issue was reported by a customer that uses XLA:CPU in production.
Change: 151370263
-rw-r--r-- | tensorflow/compiler/xla/service/algebraic_simplifier.cc | 34 | ||||
-rw-r--r-- | tensorflow/compiler/xla/service/algebraic_simplifier.h | 9 | ||||
-rw-r--r-- | tensorflow/compiler/xla/service/cpu/cpu_compiler.cc | 6 |
3 files changed, 34 insertions, 15 deletions
diff --git a/tensorflow/compiler/xla/service/algebraic_simplifier.cc b/tensorflow/compiler/xla/service/algebraic_simplifier.cc index d171a8dfff..235299e575 100644 --- a/tensorflow/compiler/xla/service/algebraic_simplifier.cc +++ b/tensorflow/compiler/xla/service/algebraic_simplifier.cc @@ -165,15 +165,18 @@ class AlgebraicSimplifierVisitor : public DfsHloVisitorWithDefault { // Runs the visitor on a computation. static bool Run( HloComputation* computation, bool is_layout_sensitive, - AlgebraicSimplifier::ValidBitcastCallback valid_bitcast_callback); + AlgebraicSimplifier::ValidBitcastCallback valid_bitcast_callback, + bool enable_dot_simplification); private: explicit AlgebraicSimplifierVisitor( HloComputation* computation, bool is_layout_sensitive, - AlgebraicSimplifier::ValidBitcastCallback valid_bitcast_callback) + AlgebraicSimplifier::ValidBitcastCallback valid_bitcast_callback, + bool enable_dot_simplification) : computation_(computation), is_layout_sensitive_(is_layout_sensitive), - valid_bitcast_callback_(std::move(valid_bitcast_callback)) {} + valid_bitcast_callback_(std::move(valid_bitcast_callback)), + enable_dot_simplification_(enable_dot_simplification) {} // Convenience method for replacing an instruction with a bitcast. void ReplaceWithBitcast(HloInstruction* instruction); @@ -217,13 +220,18 @@ class AlgebraicSimplifierVisitor : public DfsHloVisitorWithDefault { // Callback used to determine if a bitcast is possible. AlgebraicSimplifier::ValidBitcastCallback valid_bitcast_callback_; + + // Disable dot simplification on platforms where it causes a slowdown. 
+ bool enable_dot_simplification_; }; bool AlgebraicSimplifierVisitor::Run( HloComputation* computation, bool is_layout_sensitive, - AlgebraicSimplifier::ValidBitcastCallback valid_bitcast_callback) { + AlgebraicSimplifier::ValidBitcastCallback valid_bitcast_callback, + bool enable_dot_simplification) { AlgebraicSimplifierVisitor visitor(computation, is_layout_sensitive, - std::move(valid_bitcast_callback)); + std::move(valid_bitcast_callback), + enable_dot_simplification); TF_CHECK_OK(computation->Accept(&visitor)); return visitor.changed_; } @@ -328,6 +336,9 @@ Status AlgebraicSimplifierVisitor::HandleDivide(HloInstruction* divide, Status AlgebraicSimplifierVisitor::HandleDot(HloInstruction* dot, HloInstruction* lhs, HloInstruction* rhs) { + if (!enable_dot_simplification_) { + return Status::OK(); + } // Only optimize F32 dot operations where the dot, rhs and lhs are rank 2 or // below. if (dot->shape().element_type() != F32 || ShapeUtil::Rank(lhs->shape()) > 2 || @@ -1320,12 +1331,13 @@ Status AlgebraicSimplifierVisitor::HandleMinimum(HloInstruction* minimum, StatusOr<bool> AlgebraicSimplifier::Run(HloModule* module) { XLA_VLOG_LINES(2, "AlgebraicSimplifier::Run(), before:\n" + module->ToString()); - bool changed = std::any_of( - module->computations().begin(), module->computations().end(), - [=](const std::unique_ptr<HloComputation>& computation) { - return AlgebraicSimplifierVisitor::Run( - computation.get(), is_layout_sensitive_, valid_bitcast_callback_); - }); + bool changed = + std::any_of(module->computations().begin(), module->computations().end(), + [=](const std::unique_ptr<HloComputation>& computation) { + return AlgebraicSimplifierVisitor::Run( + computation.get(), is_layout_sensitive_, + valid_bitcast_callback_, enable_dot_simplification_); + }); XLA_VLOG_LINES(2, "AlgebraicSimplifier::Run(), after:\n" + module->ToString()); return changed; diff --git a/tensorflow/compiler/xla/service/algebraic_simplifier.h 
b/tensorflow/compiler/xla/service/algebraic_simplifier.h index d10d1edc1d..5d59a27c71 100644 --- a/tensorflow/compiler/xla/service/algebraic_simplifier.h +++ b/tensorflow/compiler/xla/service/algebraic_simplifier.h @@ -38,9 +38,11 @@ class AlgebraicSimplifier : public HloPassInterface { // returns true, then the pass will replace reshapes and transposes with // bitcasts. AlgebraicSimplifier(bool is_layout_sensitive, - ValidBitcastCallback valid_bitcast_callback) + ValidBitcastCallback valid_bitcast_callback, + bool enable_dot_simplification = true) : is_layout_sensitive_(is_layout_sensitive), - valid_bitcast_callback_(std::move(valid_bitcast_callback)) {} + valid_bitcast_callback_(std::move(valid_bitcast_callback)), + enable_dot_simplification_(enable_dot_simplification) {} ~AlgebraicSimplifier() override {} tensorflow::StringPiece name() const override { return "algsimp"; } @@ -51,6 +53,9 @@ class AlgebraicSimplifier : public HloPassInterface { private: bool is_layout_sensitive_; ValidBitcastCallback valid_bitcast_callback_; + + // Enable dot simplification on platforms where it is profitable. + bool enable_dot_simplification_; }; } // namespace xla diff --git a/tensorflow/compiler/xla/service/cpu/cpu_compiler.cc b/tensorflow/compiler/xla/service/cpu/cpu_compiler.cc index 3c27809db3..786c9847dc 100644 --- a/tensorflow/compiler/xla/service/cpu/cpu_compiler.cc +++ b/tensorflow/compiler/xla/service/cpu/cpu_compiler.cc @@ -225,7 +225,8 @@ Status CpuCompiler::RunHloPasses(HloModule* hlo_module, dump_hlo); pass.AddPass<AlgebraicSimplifier>( /*is_layout_sensitive=*/false, - [](const Shape&, const Shape&) { return false; }); + [](const Shape&, const Shape&) { return false; }, + /*enable_dot_simplification=*/false); pass.AddPass<ReshapeMover>(); pass.AddPass<HloConstantFolding>(); } @@ -239,7 +240,8 @@ Status CpuCompiler::RunHloPasses(HloModule* hlo_module, // duplicate or NOPs, so remove them with algebraic simplification and CSE. 
pipeline.AddPass<HloPassFix<AlgebraicSimplifier>>( /*is_layout_sensitive=*/true, - [](const Shape&, const Shape&) { return true; }); + [](const Shape&, const Shape&) { return true; }, + /*enable_dot_simplification=*/false); pipeline.AddPass<HloCSE>(/*is_layout_sensitive=*/true); // Outline ops in the entry computation into calls to subcomputations. legacy_flags::CpuCompilerFlags* flags = legacy_flags::GetCpuCompilerFlags(); |