Diffstat (limited to 'tensorflow/compiler/xla/service/instruction_fusion.cc')
-rw-r--r--  tensorflow/compiler/xla/service/instruction_fusion.cc | 260
1 file changed, 156 insertions(+), 104 deletions(-)
diff --git a/tensorflow/compiler/xla/service/instruction_fusion.cc b/tensorflow/compiler/xla/service/instruction_fusion.cc
index 8c907eae0c..3fdc2cee9a 100644
--- a/tensorflow/compiler/xla/service/instruction_fusion.cc
+++ b/tensorflow/compiler/xla/service/instruction_fusion.cc
@@ -22,6 +22,7 @@ limitations under the License.
#include <vector>
#include "absl/algorithm/container.h"
+#include "absl/memory/memory.h"
#include "tensorflow/compiler/xla/map_util.h"
#include "tensorflow/compiler/xla/service/hlo_opcode.h"
#include "tensorflow/core/lib/core/errors.h"
@@ -295,6 +296,138 @@ InstructionFusion::ComputeGloballyUnfusible(
return do_not_duplicate;
}
+namespace {
+
+// A FusionQueue that uses reverse post order.
+//
+// We want to be able to remove arbitrary instructions from the post order and
+// also compare positions of instructions in the post order. To make this
+// possible, create a vector of instructions in post order and a map from each
+// HloInstruction* to the instruction's index in the vector. An instruction is
+// "removed" from the vector by setting its element to nullptr.
+class ReversePostOrderFusionQueue : public FusionQueue {
+ public:
+ explicit ReversePostOrderFusionQueue(HloComputation* computation) {
+ post_order_ = computation->MakeInstructionPostOrder();
+
+ for (size_t i = 0; i < post_order_.size(); ++i) {
+ InsertOrDie(&post_order_index_, post_order_[i], i);
+ }
+ }
+
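+ // Returns the next instruction to visit in reverse post order, together with
+ // the indices of its operands, ordered so that operands appearing later in
+ // the post order are considered first. Operands that are no longer in the
+ // queue are omitted. Returns {nullptr, {}} once the queue is empty.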
+ std::pair<HloInstruction*, std::vector<int64>>
+ DequeueNextInstructionAndOperandsToFuseInOrder() override {
+ // Instructions are "removed" from the post order by nulling out the element
+ // in the vector, so if the pointer is null, continue to the next
+ // instruction in the sort.
+ while (!post_order_.empty() && post_order_.back() == nullptr) {
+ post_order_.pop_back();
+ }
+ if (post_order_.empty()) {
+ return std::pair<HloInstruction*, std::vector<int64>>{nullptr, {}};
+ }
+ // We want to iterate in reverse post order, so remove from the back of the
+ // vector.
+ HloInstruction* instruction = post_order_.back();
+ post_order_.pop_back();
+
+ CHECK(instruction != nullptr);
+ // Remove instruction from the index map to ensure the vector and map stay
+ // consistent.
+ post_order_index_.erase(instruction);
+
+ // Consider each operand of this instruction for fusion into this
+ // instruction. We want to consider the operands in a particular order to
+ // avoid creating duplicate instruction clones in the fusion instruction.
+ // For example, consider the following expression:
+ //
+ // A = ...
+ // B = op(A)
+ // C = op(A, B)
+ //
+ // Suppose we are considering the operands of C for fusion into C. We might
+ // fuse A or B first. If we fuse A first, we get:
+ //
+ // A = ...
+ // B = op(A)
+ // C_fusion = { A' = ...
+ // C' = op(A', B) }
+ //
+ // Where A' and C' are clones of A and C, respectively. Now only B is an
+ // operand of the fusion instruction C_fusion, so then we fuse B:
+ //
+ // A = ...
+ // B = op(A)
+ // C_fusion = { A' = ...
+ // B' = op(A)
+ // C' = op(A', B') }
+ //
+ // Now A is an operand of C_fusion again, so we then fuse A (again!):
+ //
+ // A = ...
+ // B = op(A)
+ // C_fusion = { A' = ...
+ // A" = ..
+ // B' = op(A")
+ // C' = op(A', B') }
+ //
+ // We prevent this duplication by considering the operands in the order
+ // they appear in the queue. In the example, this ensures that B will be
+ // considered before A.
+ //
+ // We store the original indices of the operands to pass to ShouldFuse.
+ std::vector<int64> sorted_operand_numbers;
+ sorted_operand_numbers.reserve(instruction->operands().size());
+ for (int i = 0; i < instruction->operands().size(); ++i) {
+ // This will happen if we have two possible instructions to fuse the
+ // same operand into; once the operand is fused into one instruction,
+ // the other instruction will get a new get-tuple-element as its
+ // operand, which is not in the queue.
+ // TODO(tjoerg): Look into fusing past these multi-output fuse points.
+ if (!ContainsKey(post_order_index_, instruction->mutable_operand(i))) {
+ continue;
+ }
+ sorted_operand_numbers.push_back(i);
+ }
+ std::sort(
+ sorted_operand_numbers.begin(), sorted_operand_numbers.end(),
+ [&](int64 i, int64 j) {
+ // Instructions with higher priority in the queue come first.
+ return (
+ FindOrDie(post_order_index_, instruction->mutable_operand(i)) >
+ FindOrDie(post_order_index_, instruction->mutable_operand(j)));
+ });
+ return std::make_pair(instruction, sorted_operand_numbers);
+ }
+
+ void OnFusingInstruction(HloInstruction* fusion,
+ HloInstruction* original_producer,
+ HloInstruction* original_consumer) override {
+ // Fusing an instruction into a fusion instruction can change the operand
+ // set of the fusion instruction. For simplicity just re-enqueue the fusion
+ // instruction and reconsider it for further fusion in the next iteration.
+ InsertOrDie(&post_order_index_, fusion, post_order_.size());
+ post_order_.push_back(fusion);
+ }
+
+ void RemoveInstruction(HloInstruction* instruction) override {
+ post_order_[FindOrDie(post_order_index_, instruction)] = nullptr;
+ post_order_index_.erase(instruction);
+ }
+
+ private:
+ std::vector<HloInstruction*> post_order_;
+ tensorflow::gtl::FlatMap<HloInstruction*, int> post_order_index_;
+};
+
+} // namespace
+
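+// Returns the default queue used to drive fusion: a reverse post order
+// traversal. This default implementation does not use skip_producer.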
+std::unique_ptr<FusionQueue> InstructionFusion::GetFusionQueue(
+ HloComputation* computation,
+ const std::function<bool(HloInstruction*)>& skip_producer) {
+ return absl::make_unique<ReversePostOrderFusionQueue>(computation);
+}
+
StatusOr<bool> InstructionFusion::Run(HloModule* module) {
VLOG(2) << "Before instruction fusion:";
XLA_VLOG_LINES(2, module->ToString());
@@ -306,111 +439,31 @@ StatusOr<bool> InstructionFusion::Run(HloModule* module) {
computation_ = computation;
reachability_ = computation_->ComputeReachability();
- // We want to be able to remove arbitrary instructions from the post order
- // and also compare positions of instructions in the post order. To make
- // this possible, create vector of instructions in post order and create a
- // map from HloInstruction* to the instruction's index in the vector. An
- // instruction is "removed" from the vector by setting it's element to
- // nullptr.
- std::vector<HloInstruction*> post_order =
- computation_->MakeInstructionPostOrder();
-
- tensorflow::gtl::FlatMap<HloInstruction*, int> post_order_index;
- for (size_t i = 0; i < post_order.size(); ++i) {
- InsertOrDie(&post_order_index, post_order[i], i);
- }
-
- HloInstructionSet do_not_duplicate = ComputeGloballyUnfusible(post_order);
+ HloInstructionSet do_not_duplicate =
+ ComputeGloballyUnfusible(computation_->MakeInstructionPostOrder());
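+ // Producers in do_not_duplicate must not be duplicated into their consumers;
+ // a queue implementation may use this predicate to skip them (the default
+ // ReversePostOrderFusionQueue ignores it).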
+ auto fusion_queue =
+ GetFusionQueue(computation_, [&](HloInstruction* producer) {
+ return do_not_duplicate.count(producer) > 0;
+ });
// Instruction fusion effectively fuses edges in the computation graph
// (producer instruction -> consumer instruction) so we iterate over all
// edges. When we fuse an edge, we create a copy of the producer inside the
// fusion instruction.
- while (!post_order.empty()) {
- // We want to iterate in reverse post order, so remove from the back of
- // the vector.
- HloInstruction* instruction = post_order.back();
- post_order.pop_back();
-
- // Instructions are "removed" from the post order by nulling out the
- // element in the vector, so if the pointer is null, continue to the next
- // instruction in the sort.
+ while (true) {
+ auto next_entry =
+ fusion_queue->DequeueNextInstructionAndOperandsToFuseInOrder();
+ auto instruction = next_entry.first;
if (instruction == nullptr) {
- continue;
+ break;
}
- // Remove instruction from the index map to ensure the vector and map stay
- // consistent.
- post_order_index.erase(instruction);
-
if (!instruction->IsFusible() &&
instruction->opcode() != HloOpcode::kFusion) {
continue;
}
- // Consider each operand of this instruction for fusion into this
- // instruction. We want to consider the operands in a particular order to
- // avoid creating duplicate instruction clones in the fusion instruction.
- // For example, consider the following expression:
- //
- // A = ...
- // B = op(A)
- // C = op(A, B)
- //
- // If we are considering the operands of C for fusion into C. We might
- // fuse A or B first. If we fuse A first, we get:
- //
- // A = ...
- // B = op(A)
- // C_fusion = { A' = ...
- // C' = op(A', B) }
- //
- // Where A' and C' are clones of A and C, respectively. Now only B is an
- // operand of the fusion instruction C_fusion, so then we fuse B:
- //
- // A = ...
- // B = op(A)
- // C_fusion = { A' = ...
- // B' = op(A)
- // C' = op(A', B') }
- //
- // Now A is an operand of C_fusion again, so we then fuse A (again!):
- //
- // A = ...
- // B = op(A)
- // C_fusion = { A' = ...
- // A" = ..
- // B' = op(A")
- // C' = op(A', B') }
- //
- // We prevent this duplication by considering the operands in the reverse
- // order they appear in the instruction post order. In the example, this
- // ensures that B will be considered before A.
- //
- // We store the original indices of the operands to pass to ShouldFuse.
- std::vector<int64> sorted_operand_numbers;
- sorted_operand_numbers.reserve(instruction->operands().size());
- for (int i = 0; i < instruction->operands().size(); ++i) {
- // This will happen if we have two possible instructions to fuse the
- // same operand into; once the operand is fused into one instruction,
- // the other instruction will get a new get-tuple-element as its
- // operand, which is not in the post-order index.
- // TODO(tjoerg): Look into fusing past these multi-output fuse points.
- if (post_order_index.find(instruction->mutable_operand(i)) ==
- post_order_index.end()) {
- continue;
- }
- sorted_operand_numbers.push_back(i);
- }
- std::sort(
- sorted_operand_numbers.begin(), sorted_operand_numbers.end(),
- [&](int64 i, int64 j) {
- // Instructions with higher indices in the post order come
- // first.
- return (
- FindOrDie(post_order_index, instruction->mutable_operand(i)) >
- FindOrDie(post_order_index, instruction->mutable_operand(j)));
- });
+ std::vector<int64>& sorted_operand_numbers = next_entry.second;
for (int64 i : sorted_operand_numbers) {
HloInstruction* operand = instruction->mutable_operand(i);
@@ -425,32 +478,31 @@ StatusOr<bool> InstructionFusion::Run(HloModule* module) {
// TODO(tjoerg): Consider making multi-output fusion the default.
if (ShouldFuse(instruction, i) &&
do_not_duplicate.count(operand) == 0) {
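+ // Notify the queue that 'operand' is about to be fused into 'instruction'.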
+ fusion_queue->PreFusion(operand, instruction);
fusion_instruction = Fuse(operand, instruction);
} else if (ShouldFuseIntoMultiOutput(instruction, i) &&
!MultiOutputFusionCreatesCycle(operand, instruction)) {
+ fusion_queue->PreFusion(operand, instruction);
fusion_instruction = FuseIntoMultiOutput(operand, instruction);
} else {
continue;
}
- // Fusing an instruction into a fusion instruction can change the
- // operand set of the fusion instruction. For simplicity just push the
- // instruction to the top of the post_order and reconsider it for
- // further fusion in the next iteration of the outer loop.
- post_order.push_back(fusion_instruction);
- InsertOrDie(&post_order_index, fusion_instruction,
- post_order.size() - 1);
+ fusion_queue->OnFusingInstruction(fusion_instruction, operand,
+ instruction);
changed = true;
if (operand->user_count() == 0) {
- // Operand is now dead. Remove from post order by setting its
- // location to nullptr.
- post_order[FindOrDie(post_order_index, operand)] = nullptr;
- post_order_index.erase(operand);
-
+ do_not_duplicate.erase(operand);
+ // Operand is now dead. Remove from queue.
+ fusion_queue->RemoveInstruction(operand);
// Remove from computation.
TF_RETURN_IF_ERROR(computation_->RemoveInstruction(operand));
}
+
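+ // If fusing created a new fusion instruction, the original consumer
+ // 'instruction' has been replaced by it, so its do_not_duplicate entry is
+ // stale and can be dropped.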
+ if (fusion_instruction != instruction) {
+ do_not_duplicate.erase(instruction);
+ }
break;
}
}
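
Note: the point of the new GetFusionQueue hook is that a derived fusion pass can substitute its own traversal order. Below is a minimal sketch of what overriding it might look like; it assumes GetFusionQueue is declared as a virtual, non-private member of InstructionFusion in instruction_fusion.h (the header is not part of this diff), and the pass name is hypothetical.

#include <functional>
#include <memory>

#include "tensorflow/compiler/xla/service/instruction_fusion.h"

namespace xla {

// Hypothetical backend pass that plugs in its own fusion traversal order.
class MyBackendInstructionFusion : public InstructionFusion {
 public:
  // Reuse the base class constructors.
  using InstructionFusion::InstructionFusion;

 protected:
  std::unique_ptr<FusionQueue> GetFusionQueue(
      HloComputation* computation,
      const std::function<bool(HloInstruction*)>& skip_producer) override {
    // A real backend would return its own FusionQueue implementation here,
    // e.g. one ordered by a profitability heuristic, and could consult
    // skip_producer to avoid offering producers that must not be duplicated.
    // Falling back to the base class keeps the default reverse post order.
    return InstructionFusion::GetFusionQueue(computation, skip_producer);
  }
};

}  // namespace xla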