aboutsummaryrefslogtreecommitdiffhomepage
path: root/tensorflow/compiler/xla/service/algebraic_simplifier.cc
diff options
context:
space:
mode:
authorGravatar Blake Hechtman <blakehechtman@google.com>2018-06-11 11:44:49 -0700
committerGravatar TensorFlower Gardener <gardener@tensorflow.org>2018-06-11 11:50:03 -0700
commit68d7bcaa52a2b3307e805e2c8512a8dc47fd3272 (patch)
tree14c85c1ffe2e97b55aa386e90859e50a75fc3603 /tensorflow/compiler/xla/service/algebraic_simplifier.cc
parentc73cd1afce146aa2559cafa4ac72fe638db43860 (diff)
[XLA] Fold consecutive reduces.
PiperOrigin-RevId: 200086761
Diffstat (limited to 'tensorflow/compiler/xla/service/algebraic_simplifier.cc')
-rw-r--r--tensorflow/compiler/xla/service/algebraic_simplifier.cc31
1 files changed, 31 insertions, 0 deletions
diff --git a/tensorflow/compiler/xla/service/algebraic_simplifier.cc b/tensorflow/compiler/xla/service/algebraic_simplifier.cc
index dc5f1b31bf..3b36939b8a 100644
--- a/tensorflow/compiler/xla/service/algebraic_simplifier.cc
+++ b/tensorflow/compiler/xla/service/algebraic_simplifier.cc
@@ -1783,6 +1783,37 @@ Status AlgebraicSimplifierVisitor::HandleReduce(HloInstruction* reduce) {
return ReplaceWithNewInstruction(
reduce, HloInstruction::CreateReshape(reduce->shape(), arg));
}
+
+ // If a reduce feeds a reduce with the same computation and initial value,
+ // they can be combined into a single reduce.
+ if (arg->opcode() == HloOpcode::kReduce &&
+ init_value->Identical(*arg->operand(1)) &&
+ *function == *arg->to_apply()) {
+ // Create a new reduce with the combined reduction dimensions of both
+ // reduces.
+ std::vector<int64> arg_dims = arg->dimensions();
+ std::sort(arg_dims.begin(), arg_dims.end());
+ std::vector<int64> reduce_dims = reduce->dimensions();
+ std::sort(reduce_dims.begin(), reduce_dims.end());
+ // Transform reduce_dims to the same rank as the operand of the operand.
+ for (int64 arg_dim : arg_dims) {
+ for (int64& dim : reduce_dims) {
+ if (dim >= arg_dim) {
+ ++dim;
+ }
+ }
+ }
+ std::vector<int64> new_dimensions;
+ new_dimensions.reserve(arg->dimensions().size() +
+ reduce->dimensions().size());
+ std::merge(arg_dims.begin(), arg_dims.end(), reduce_dims.begin(),
+ reduce_dims.end(), std::back_inserter(new_dimensions));
+ return ReplaceWithNewInstruction(
+ reduce,
+ HloInstruction::CreateReduce(reduce->shape(), arg->mutable_operand(0),
+ init_value, new_dimensions, function));
+ }
+
// A reshape that collapses multiple dimensions into a dimension being
// reduced can just reduce all of those dimensions instead of doing a
// collapsing reshape before a reduction.