diff options
author | Blake Hechtman <blakehechtman@google.com> | 2018-06-11 11:44:49 -0700 |
---|---|---|
committer | TensorFlower Gardener <gardener@tensorflow.org> | 2018-06-11 11:50:03 -0700 |
commit | 68d7bcaa52a2b3307e805e2c8512a8dc47fd3272 (patch) | |
tree | 14c85c1ffe2e97b55aa386e90859e50a75fc3603 /tensorflow/compiler/xla/service/algebraic_simplifier.cc | |
parent | c73cd1afce146aa2559cafa4ac72fe638db43860 (diff) |
[XLA] Fold consecutive reduces.
PiperOrigin-RevId: 200086761
Diffstat (limited to 'tensorflow/compiler/xla/service/algebraic_simplifier.cc')
-rw-r--r-- | tensorflow/compiler/xla/service/algebraic_simplifier.cc | 31 |
1 files changed, 31 insertions, 0 deletions
diff --git a/tensorflow/compiler/xla/service/algebraic_simplifier.cc b/tensorflow/compiler/xla/service/algebraic_simplifier.cc index dc5f1b31bf..3b36939b8a 100644 --- a/tensorflow/compiler/xla/service/algebraic_simplifier.cc +++ b/tensorflow/compiler/xla/service/algebraic_simplifier.cc @@ -1783,6 +1783,37 @@ Status AlgebraicSimplifierVisitor::HandleReduce(HloInstruction* reduce) { return ReplaceWithNewInstruction( reduce, HloInstruction::CreateReshape(reduce->shape(), arg)); } + + // If a reduce feeds a reduce with the same computation and initial value, + // they can be combined into a single reduce. + if (arg->opcode() == HloOpcode::kReduce && + init_value->Identical(*arg->operand(1)) && + *function == *arg->to_apply()) { + // Create a new reduce with the combined reduction dimensions of both + // reduces. + std::vector<int64> arg_dims = arg->dimensions(); + std::sort(arg_dims.begin(), arg_dims.end()); + std::vector<int64> reduce_dims = reduce->dimensions(); + std::sort(reduce_dims.begin(), reduce_dims.end()); + // Transform reduce_dims to the same rank as the operand of the operand. + for (int64 arg_dim : arg_dims) { + for (int64& dim : reduce_dims) { + if (dim >= arg_dim) { + ++dim; + } + } + } + std::vector<int64> new_dimensions; + new_dimensions.reserve(arg->dimensions().size() + + reduce->dimensions().size()); + std::merge(arg_dims.begin(), arg_dims.end(), reduce_dims.begin(), + reduce_dims.end(), std::back_inserter(new_dimensions)); + return ReplaceWithNewInstruction( + reduce, + HloInstruction::CreateReduce(reduce->shape(), arg->mutable_operand(0), + init_value, new_dimensions, function)); + } + // A reshape that collapses multiple dimensions into a dimension being // reduced can just reduce all of those dimensions instead of doing a // collapsing reshape before a reduction. |