diff options
author | Dimitris Vardoulakis <dimvar@google.com> | 2018-03-22 14:38:41 -0700 |
---|---|---|
committer | TensorFlower Gardener <gardener@tensorflow.org> | 2018-03-22 14:41:16 -0700 |
commit | 1a99109e8832bc94710d2dcfb5d9525688913a50 (patch) | |
tree | 68074eda34060864dae283a782dc4b1441e63122 /tensorflow/compiler/xla/service/algebraic_simplifier.cc | |
parent | 1004396a769ad9fdf350ed28083bca5b6ad00402 (diff) |
Merge consecutive broadcast HLO instructions.
As an optimization, replace consecutive broadcast instructions with a single equivalent broadcast in algebraic simplification.
PiperOrigin-RevId: 190127730
Diffstat (limited to 'tensorflow/compiler/xla/service/algebraic_simplifier.cc')
-rw-r--r-- | tensorflow/compiler/xla/service/algebraic_simplifier.cc | 22 |
1 files changed, 17 insertions, 5 deletions
diff --git a/tensorflow/compiler/xla/service/algebraic_simplifier.cc b/tensorflow/compiler/xla/service/algebraic_simplifier.cc index 971c2935c8..88f6ff0a07 100644 --- a/tensorflow/compiler/xla/service/algebraic_simplifier.cc +++ b/tensorflow/compiler/xla/service/algebraic_simplifier.cc @@ -1121,10 +1121,10 @@ bool OutputIsSubsetOfOperandElements(HloInstruction* instruction, Status AlgebraicSimplifierVisitor::HandleBroadcast(HloInstruction* broadcast) { auto operand = broadcast->mutable_operand(0); + auto dims = broadcast->dimensions(); // A degenerate broadcast of a reshape that does not change the number of // elements can be replaced by a reshape. - if (std::is_sorted(broadcast->dimensions().begin(), - broadcast->dimensions().end()) && + if (std::is_sorted(dims.begin(), dims.end()) && ShapeUtil::ElementsIn(broadcast->shape()) == ShapeUtil::ElementsIn(operand->shape())) { VLOG(10) << "transform broadcast(X) -> reshape(X) where " @@ -1142,8 +1142,8 @@ Status AlgebraicSimplifierVisitor::HandleBroadcast(HloInstruction* broadcast) { VLOG(10) << "transform broadcast(X) -> transpose(X) where " "n(broadcast(X)) == n(X)"; return ReplaceWithNewInstruction( - broadcast, HloInstruction::CreateTranspose(broadcast->shape(), operand, - broadcast->dimensions())); + broadcast, + HloInstruction::CreateTranspose(broadcast->shape(), operand, dims)); } // A broadcast of a reshape which merely inserts 1-sized dimensions can @@ -1157,7 +1157,6 @@ Status AlgebraicSimplifierVisitor::HandleBroadcast(HloInstruction* broadcast) { if (merely_inserts_or_deletes_1_sized_dimensions && deleted_indices.empty()) { std::reverse(inserted_indices.begin(), inserted_indices.end()); - auto dims = broadcast->dimensions(); for (auto inserted_index : inserted_indices) { dims.erase(dims.begin() + inserted_index); } @@ -1201,6 +1200,19 @@ Status AlgebraicSimplifierVisitor::HandleBroadcast(HloInstruction* broadcast) { return user->ReplaceAllUsesWith(new_broadcast); } } + return Status::OK(); + } + + // Merge two consecutive broadcasts into a single one. + if (operand->opcode() == HloOpcode::kBroadcast) { + std::vector<int64> new_dimensions(operand->dimensions().size()); + for (auto dim : operand->dimensions()) { + new_dimensions.push_back(dims[dim]); + } + return ReplaceWithNewInstruction( + broadcast, + HloInstruction::CreateBroadcast( + broadcast->shape(), operand->mutable_operand(0), new_dimensions)); } return Status::OK(); } |