aboutsummaryrefslogtreecommitdiffhomepage
path: root/tensorflow/compiler/xla/service/algebraic_simplifier.cc
diff options
context:
space:
mode:
authorGravatar Dimitris Vardoulakis <dimvar@google.com>2018-03-22 14:38:41 -0700
committerGravatar TensorFlower Gardener <gardener@tensorflow.org>2018-03-22 14:41:16 -0700
commit1a99109e8832bc94710d2dcfb5d9525688913a50 (patch)
tree68074eda34060864dae283a782dc4b1441e63122 /tensorflow/compiler/xla/service/algebraic_simplifier.cc
parent1004396a769ad9fdf350ed28083bca5b6ad00402 (diff)
Merge consecutive broadcast HLO instructions.
As an optimization, replace consecutive broadcast instructions with a single equivalent broadcast in algebraic simplification. PiperOrigin-RevId: 190127730
Diffstat (limited to 'tensorflow/compiler/xla/service/algebraic_simplifier.cc')
-rw-r--r--tensorflow/compiler/xla/service/algebraic_simplifier.cc22
1 files changed, 17 insertions, 5 deletions
diff --git a/tensorflow/compiler/xla/service/algebraic_simplifier.cc b/tensorflow/compiler/xla/service/algebraic_simplifier.cc
index 971c2935c8..88f6ff0a07 100644
--- a/tensorflow/compiler/xla/service/algebraic_simplifier.cc
+++ b/tensorflow/compiler/xla/service/algebraic_simplifier.cc
@@ -1121,10 +1121,10 @@ bool OutputIsSubsetOfOperandElements(HloInstruction* instruction,
Status AlgebraicSimplifierVisitor::HandleBroadcast(HloInstruction* broadcast) {
auto operand = broadcast->mutable_operand(0);
+ auto dims = broadcast->dimensions();
// A degenerate broadcast of a reshape that does not change the number of
// elements can be replaced by a reshape.
- if (std::is_sorted(broadcast->dimensions().begin(),
- broadcast->dimensions().end()) &&
+ if (std::is_sorted(dims.begin(), dims.end()) &&
ShapeUtil::ElementsIn(broadcast->shape()) ==
ShapeUtil::ElementsIn(operand->shape())) {
VLOG(10) << "transform broadcast(X) -> reshape(X) where "
@@ -1142,8 +1142,8 @@ Status AlgebraicSimplifierVisitor::HandleBroadcast(HloInstruction* broadcast) {
VLOG(10) << "transform broadcast(X) -> transpose(X) where "
"n(broadcast(X)) == n(X)";
return ReplaceWithNewInstruction(
- broadcast, HloInstruction::CreateTranspose(broadcast->shape(), operand,
- broadcast->dimensions()));
+ broadcast,
+ HloInstruction::CreateTranspose(broadcast->shape(), operand, dims));
}
// A broadcast of a reshape which merely inserts 1-sized dimensions can
@@ -1157,7 +1157,6 @@ Status AlgebraicSimplifierVisitor::HandleBroadcast(HloInstruction* broadcast) {
if (merely_inserts_or_deletes_1_sized_dimensions &&
deleted_indices.empty()) {
std::reverse(inserted_indices.begin(), inserted_indices.end());
- auto dims = broadcast->dimensions();
for (auto inserted_index : inserted_indices) {
dims.erase(dims.begin() + inserted_index);
}
@@ -1201,6 +1200,19 @@ Status AlgebraicSimplifierVisitor::HandleBroadcast(HloInstruction* broadcast) {
return user->ReplaceAllUsesWith(new_broadcast);
}
}
+ return Status::OK();
+ }
+
+ // Merge two consecutive broadcasts into a single one.
+ if (operand->opcode() == HloOpcode::kBroadcast) {
+ std::vector<int64> new_dimensions(operand->dimensions().size());
+ for (auto dim : operand->dimensions()) {
+ new_dimensions.push_back(dims[dim]);
+ }
+ return ReplaceWithNewInstruction(
+ broadcast,
+ HloInstruction::CreateBroadcast(
+ broadcast->shape(), operand->mutable_operand(0), new_dimensions));
}
return Status::OK();
}