aboutsummaryrefslogtreecommitdiffhomepage
diff options
context:
space:
mode:
authorGravatar Blake Hechtman <blakehechtman@google.com>2017-03-27 13:22:09 -0800
committerGravatar TensorFlower Gardener <gardener@tensorflow.org>2017-03-27 14:48:17 -0700
commit4eb07180f4a0fce2cb265010233bace014e2c026 (patch)
treed7513bbfcbcb062d6ffa086b9739f2be5292f047
parent5c65a93856028ea9717c769bbc96338e1ea05c76 (diff)
[XLA:CPU] Disable Dot strength reduction due to performance regressions. This
issue was reported by a customer that uses XLA:CPU in production. Change: 151370263
-rw-r--r--tensorflow/compiler/xla/service/algebraic_simplifier.cc34
-rw-r--r--tensorflow/compiler/xla/service/algebraic_simplifier.h9
-rw-r--r--tensorflow/compiler/xla/service/cpu/cpu_compiler.cc6
3 files changed, 34 insertions, 15 deletions
diff --git a/tensorflow/compiler/xla/service/algebraic_simplifier.cc b/tensorflow/compiler/xla/service/algebraic_simplifier.cc
index d171a8dfff..235299e575 100644
--- a/tensorflow/compiler/xla/service/algebraic_simplifier.cc
+++ b/tensorflow/compiler/xla/service/algebraic_simplifier.cc
@@ -165,15 +165,18 @@ class AlgebraicSimplifierVisitor : public DfsHloVisitorWithDefault {
// Runs the visitor on a computation.
static bool Run(
HloComputation* computation, bool is_layout_sensitive,
- AlgebraicSimplifier::ValidBitcastCallback valid_bitcast_callback);
+ AlgebraicSimplifier::ValidBitcastCallback valid_bitcast_callback,
+ bool enable_dot_simplification);
private:
explicit AlgebraicSimplifierVisitor(
HloComputation* computation, bool is_layout_sensitive,
- AlgebraicSimplifier::ValidBitcastCallback valid_bitcast_callback)
+ AlgebraicSimplifier::ValidBitcastCallback valid_bitcast_callback,
+ bool enable_dot_simplification)
: computation_(computation),
is_layout_sensitive_(is_layout_sensitive),
- valid_bitcast_callback_(std::move(valid_bitcast_callback)) {}
+ valid_bitcast_callback_(std::move(valid_bitcast_callback)),
+ enable_dot_simplification_(enable_dot_simplification) {}
// Convenience method for replacing an instruction with a bitcast.
void ReplaceWithBitcast(HloInstruction* instruction);
@@ -217,13 +220,18 @@ class AlgebraicSimplifierVisitor : public DfsHloVisitorWithDefault {
// Callback used to determine if a bitcast is possible.
AlgebraicSimplifier::ValidBitcastCallback valid_bitcast_callback_;
+
+ // Disable dot simplication on platforms where it causes a slowdown.
+ bool enable_dot_simplification_;
};
bool AlgebraicSimplifierVisitor::Run(
HloComputation* computation, bool is_layout_sensitive,
- AlgebraicSimplifier::ValidBitcastCallback valid_bitcast_callback) {
+ AlgebraicSimplifier::ValidBitcastCallback valid_bitcast_callback,
+ bool enable_dot_simplification) {
AlgebraicSimplifierVisitor visitor(computation, is_layout_sensitive,
- std::move(valid_bitcast_callback));
+ std::move(valid_bitcast_callback),
+ enable_dot_simplification);
TF_CHECK_OK(computation->Accept(&visitor));
return visitor.changed_;
}
@@ -328,6 +336,9 @@ Status AlgebraicSimplifierVisitor::HandleDivide(HloInstruction* divide,
Status AlgebraicSimplifierVisitor::HandleDot(HloInstruction* dot,
HloInstruction* lhs,
HloInstruction* rhs) {
+ if (!enable_dot_simplification_) {
+ return Status::OK();
+ }
// Only optimize F32 dot operations where the dot, rhs and lhs are rank 2 or
// below.
if (dot->shape().element_type() != F32 || ShapeUtil::Rank(lhs->shape()) > 2 ||
@@ -1320,12 +1331,13 @@ Status AlgebraicSimplifierVisitor::HandleMinimum(HloInstruction* minimum,
StatusOr<bool> AlgebraicSimplifier::Run(HloModule* module) {
XLA_VLOG_LINES(2,
"AlgebraicSimplifier::Run(), before:\n" + module->ToString());
- bool changed = std::any_of(
- module->computations().begin(), module->computations().end(),
- [=](const std::unique_ptr<HloComputation>& computation) {
- return AlgebraicSimplifierVisitor::Run(
- computation.get(), is_layout_sensitive_, valid_bitcast_callback_);
- });
+ bool changed =
+ std::any_of(module->computations().begin(), module->computations().end(),
+ [=](const std::unique_ptr<HloComputation>& computation) {
+ return AlgebraicSimplifierVisitor::Run(
+ computation.get(), is_layout_sensitive_,
+ valid_bitcast_callback_, enable_dot_simplification_);
+ });
XLA_VLOG_LINES(2,
"AlgebraicSimplifier::Run(), after:\n" + module->ToString());
return changed;
diff --git a/tensorflow/compiler/xla/service/algebraic_simplifier.h b/tensorflow/compiler/xla/service/algebraic_simplifier.h
index d10d1edc1d..5d59a27c71 100644
--- a/tensorflow/compiler/xla/service/algebraic_simplifier.h
+++ b/tensorflow/compiler/xla/service/algebraic_simplifier.h
@@ -38,9 +38,11 @@ class AlgebraicSimplifier : public HloPassInterface {
// returns true, then the pass will replace reshapes and tranposes with
// bitcasts.
AlgebraicSimplifier(bool is_layout_sensitive,
- ValidBitcastCallback valid_bitcast_callback)
+ ValidBitcastCallback valid_bitcast_callback,
+ bool enable_dot_simplification = true)
: is_layout_sensitive_(is_layout_sensitive),
- valid_bitcast_callback_(std::move(valid_bitcast_callback)) {}
+ valid_bitcast_callback_(std::move(valid_bitcast_callback)),
+ enable_dot_simplification_(enable_dot_simplification) {}
~AlgebraicSimplifier() override {}
tensorflow::StringPiece name() const override { return "algsimp"; }
@@ -51,6 +53,9 @@ class AlgebraicSimplifier : public HloPassInterface {
private:
bool is_layout_sensitive_;
ValidBitcastCallback valid_bitcast_callback_;
+
+ // Enable dot simplication on platforms where it is profitable.
+ bool enable_dot_simplification_;
};
} // namespace xla
diff --git a/tensorflow/compiler/xla/service/cpu/cpu_compiler.cc b/tensorflow/compiler/xla/service/cpu/cpu_compiler.cc
index 3c27809db3..786c9847dc 100644
--- a/tensorflow/compiler/xla/service/cpu/cpu_compiler.cc
+++ b/tensorflow/compiler/xla/service/cpu/cpu_compiler.cc
@@ -225,7 +225,8 @@ Status CpuCompiler::RunHloPasses(HloModule* hlo_module,
dump_hlo);
pass.AddPass<AlgebraicSimplifier>(
/*is_layout_sensitive=*/false,
- [](const Shape&, const Shape&) { return false; });
+ [](const Shape&, const Shape&) { return false; },
+ /*enable_dot_simplification=*/false);
pass.AddPass<ReshapeMover>();
pass.AddPass<HloConstantFolding>();
}
@@ -239,7 +240,8 @@ Status CpuCompiler::RunHloPasses(HloModule* hlo_module,
// duplicate or NOPs, so remove them with algebraic simplification and CSE.
pipeline.AddPass<HloPassFix<AlgebraicSimplifier>>(
/*is_layout_sensitive=*/true,
- [](const Shape&, const Shape&) { return true; });
+ [](const Shape&, const Shape&) { return true; },
+ /*enable_dot_simplification=*/false);
pipeline.AddPass<HloCSE>(/*is_layout_sensitive=*/true);
// Outline ops in the entry computation into calls to subcomputations.
legacy_flags::CpuCompilerFlags* flags = legacy_flags::GetCpuCompilerFlags();