From 9c270922715306efefce848b87dee3690cdddd27 Mon Sep 17 00:00:00 2001 From: Yuanzhong Xu Date: Wed, 12 Sep 2018 11:30:07 -0700 Subject: [XLA] A queue interface to allow fusion in different orders. PiperOrigin-RevId: 212674212 --- .../compiler/xla/service/instruction_fusion.cc | 260 ++++++++++++--------- 1 file changed, 156 insertions(+), 104 deletions(-) (limited to 'tensorflow/compiler/xla/service/instruction_fusion.cc') diff --git a/tensorflow/compiler/xla/service/instruction_fusion.cc b/tensorflow/compiler/xla/service/instruction_fusion.cc index 8c907eae0c..3fdc2cee9a 100644 --- a/tensorflow/compiler/xla/service/instruction_fusion.cc +++ b/tensorflow/compiler/xla/service/instruction_fusion.cc @@ -22,6 +22,7 @@ limitations under the License. #include #include "absl/algorithm/container.h" +#include "absl/memory/memory.h" #include "tensorflow/compiler/xla/map_util.h" #include "tensorflow/compiler/xla/service/hlo_opcode.h" #include "tensorflow/core/lib/core/errors.h" @@ -295,6 +296,138 @@ InstructionFusion::ComputeGloballyUnfusible( return do_not_duplicate; } +namespace { + +// A FusionQueue that uses reverse post order. +// +// We want to be able to remove arbitrary instructions from the post order and +// also compare positions of instructions in the post order. To make this +// possible, create vector of instructions in post order and create a map from +// HloInstruction* to the instruction's index in the vector. An instruction is +// "removed" from the vector by setting it's element to nullptr. +class ReversePostOrderFusionQueue : public FusionQueue { + public: + explicit ReversePostOrderFusionQueue(HloComputation* computation) { + post_order_ = computation->MakeInstructionPostOrder(); + + for (size_t i = 0; i < post_order_.size(); ++i) { + InsertOrDie(&post_order_index_, post_order_[i], i); + } + } + + std::pair> + DequeueNextInstructionAndOperandsToFuseInOrder() override { + // Instructions are "removed" from the post order by nulling out the element + // in the vector, so if the pointer is null, continue to the next + // instruction in the sort. + while (!post_order_.empty() && post_order_.back() == nullptr) { + post_order_.pop_back(); + } + if (post_order_.empty()) { + return std::pair>{nullptr, {}}; + } + // We want to iterate in reverse post order, so remove from the back of the + // vector. + HloInstruction* instruction = post_order_.back(); + post_order_.pop_back(); + + CHECK(instruction != nullptr); + // Remove instruction from the index map to ensure the vector and map stay + // consistent. + post_order_index_.erase(instruction); + + // Consider each operand of this instruction for fusion into this + // instruction. We want to consider the operands in a particular order to + // avoid creating duplicate instruction clones in the fusion instruction. + // For example, consider the following expression: + // + // A = ... + // B = op(A) + // C = op(A, B) + // + // If we are considering the operands of C for fusion into C. We might + // fuse A or B first. If we fuse A first, we get: + // + // A = ... + // B = op(A) + // C_fusion = { A' = ... + // C' = op(A', B) } + // + // Where A' and C' are clones of A and C, respectively. Now only B is an + // operand of the fusion instruction C_fusion, so then we fuse B: + // + // A = ... + // B = op(A) + // C_fusion = { A' = ... + // B' = op(A) + // C' = op(A', B') } + // + // Now A is an operand of C_fusion again, so we then fuse A (again!): + // + // A = ... + // B = op(A) + // C_fusion = { A' = ... + // A" = .. + // B' = op(A") + // C' = op(A', B') } + // + // We prevent this duplication by considering the operands in the order + // they appear int the queue. In the example, this ensures that B will be + // considered before A. + // + // We store the original indices of the operands to pass to ShouldFuse. + std::vector sorted_operand_numbers; + sorted_operand_numbers.reserve(instruction->operands().size()); + for (int i = 0; i < instruction->operands().size(); ++i) { + // This will happen if we have two possible instructions to fuse the + // same operand into; once the operand is fused into one instruction, + // the other instruction will get a new get-tuple-element as its + // operand, which is not in the queue. + // TODO(tjoerg): Look into fusing past these multi-output fuse points. + if (!ContainsKey(post_order_index_, instruction->mutable_operand(i))) { + continue; + } + sorted_operand_numbers.push_back(i); + } + std::sort( + sorted_operand_numbers.begin(), sorted_operand_numbers.end(), + [&](int64 i, int64 j) { + // Instructions with higher priority in the queue come first. + return ( + FindOrDie(post_order_index_, instruction->mutable_operand(i)) > + FindOrDie(post_order_index_, instruction->mutable_operand(j))); + }); + return std::make_pair(instruction, sorted_operand_numbers); + } + + void OnFusingInstruction(HloInstruction* fusion, + HloInstruction* original_producer, + HloInstruction* original_consumer) override { + // Fusing an instruction into a fusion instruction can change the operand + // set of the fusion instruction. For simplicity just re-enqueue the + // instruction and reconsider it for further fusion in the next iteration. + InsertOrDie(&post_order_index_, fusion, post_order_.size()); + post_order_.push_back(fusion); + } + + void RemoveInstruction(HloInstruction* instruction) override { + post_order_[FindOrDie(post_order_index_, instruction)] = nullptr; + post_order_index_.erase(instruction); + } + + private: + std::vector post_order_; + tensorflow::gtl::FlatMap post_order_index_; +}; + +} // namespace + +std::unique_ptr InstructionFusion::GetFusionQueue( + HloComputation* computation, + const std::function& skip_producer) { + return absl::make_unique(computation); +} + StatusOr InstructionFusion::Run(HloModule* module) { VLOG(2) << "Before instruction fusion:"; XLA_VLOG_LINES(2, module->ToString()); @@ -306,111 +439,31 @@ StatusOr InstructionFusion::Run(HloModule* module) { computation_ = computation; reachability_ = computation_->ComputeReachability(); - // We want to be able to remove arbitrary instructions from the post order - // and also compare positions of instructions in the post order. To make - // this possible, create vector of instructions in post order and create a - // map from HloInstruction* to the instruction's index in the vector. An - // instruction is "removed" from the vector by setting it's element to - // nullptr. - std::vector post_order = - computation_->MakeInstructionPostOrder(); - - tensorflow::gtl::FlatMap post_order_index; - for (size_t i = 0; i < post_order.size(); ++i) { - InsertOrDie(&post_order_index, post_order[i], i); - } - - HloInstructionSet do_not_duplicate = ComputeGloballyUnfusible(post_order); + HloInstructionSet do_not_duplicate = + ComputeGloballyUnfusible(computation_->MakeInstructionPostOrder()); + auto fusion_queue = + GetFusionQueue(computation_, [&](HloInstruction* producer) { + return do_not_duplicate.count(producer) > 0; + }); // Instruction fusion effectively fuses edges in the computation graph // (producer instruction -> consumer instruction) so we iterate over all // edges. When we fuse an edge, we create a copy of the producer inside the // fusion instruction. - while (!post_order.empty()) { - // We want to iterate in reverse post order, so remove from the back of - // the vector. - HloInstruction* instruction = post_order.back(); - post_order.pop_back(); - - // Instructions are "removed" from the post order by nulling out the - // element in the vector, so if the pointer is null, continue to the next - // instruction in the sort. + while (true) { + auto next_entry = + fusion_queue->DequeueNextInstructionAndOperandsToFuseInOrder(); + auto instruction = next_entry.first; if (instruction == nullptr) { - continue; + break; } - // Remove instruction from the index map to ensure the vector and map stay - // consistent. - post_order_index.erase(instruction); - if (!instruction->IsFusible() && instruction->opcode() != HloOpcode::kFusion) { continue; } - // Consider each operand of this instruction for fusion into this - // instruction. We want to consider the operands in a particular order to - // avoid creating duplicate instruction clones in the fusion instruction. - // For example, consider the following expression: - // - // A = ... - // B = op(A) - // C = op(A, B) - // - // If we are considering the operands of C for fusion into C. We might - // fuse A or B first. If we fuse A first, we get: - // - // A = ... - // B = op(A) - // C_fusion = { A' = ... - // C' = op(A', B) } - // - // Where A' and C' are clones of A and C, respectively. Now only B is an - // operand of the fusion instruction C_fusion, so then we fuse B: - // - // A = ... - // B = op(A) - // C_fusion = { A' = ... - // B' = op(A) - // C' = op(A', B') } - // - // Now A is an operand of C_fusion again, so we then fuse A (again!): - // - // A = ... - // B = op(A) - // C_fusion = { A' = ... - // A" = .. - // B' = op(A") - // C' = op(A', B') } - // - // We prevent this duplication by considering the operands in the reverse - // order they appear in the instruction post order. In the example, this - // ensures that B will be considered before A. - // - // We store the original indices of the operands to pass to ShouldFuse. - std::vector sorted_operand_numbers; - sorted_operand_numbers.reserve(instruction->operands().size()); - for (int i = 0; i < instruction->operands().size(); ++i) { - // This will happen if we have two possible instructions to fuse the - // same operand into; once the operand is fused into one instruction, - // the other instruction will get a new get-tuple-element as its - // operand, which is not in the post-order index. - // TODO(tjoerg): Look into fusing past these multi-output fuse points. - if (post_order_index.find(instruction->mutable_operand(i)) == - post_order_index.end()) { - continue; - } - sorted_operand_numbers.push_back(i); - } - std::sort( - sorted_operand_numbers.begin(), sorted_operand_numbers.end(), - [&](int64 i, int64 j) { - // Instructions with higher indices in the post order come - // first. - return ( - FindOrDie(post_order_index, instruction->mutable_operand(i)) > - FindOrDie(post_order_index, instruction->mutable_operand(j))); - }); + std::vector& sorted_operand_numbers = next_entry.second; for (int64 i : sorted_operand_numbers) { HloInstruction* operand = instruction->mutable_operand(i); @@ -425,32 +478,31 @@ StatusOr InstructionFusion::Run(HloModule* module) { // TODO(tjoerg): Consider making multi-output fusion the default. if (ShouldFuse(instruction, i) && do_not_duplicate.count(operand) == 0) { + fusion_queue->PreFusion(operand, instruction); fusion_instruction = Fuse(operand, instruction); } else if (ShouldFuseIntoMultiOutput(instruction, i) && !MultiOutputFusionCreatesCycle(operand, instruction)) { + fusion_queue->PreFusion(operand, instruction); fusion_instruction = FuseIntoMultiOutput(operand, instruction); } else { continue; } - // Fusing an instruction into a fusion instruction can change the - // operand set of the fusion instruction. For simplicity just push the - // instruction to the top of the post_order and reconsider it for - // further fusion in the next iteration of the outer loop. - post_order.push_back(fusion_instruction); - InsertOrDie(&post_order_index, fusion_instruction, - post_order.size() - 1); + fusion_queue->OnFusingInstruction(fusion_instruction, operand, + instruction); changed = true; if (operand->user_count() == 0) { - // Operand is now dead. Remove from post order by setting its - // location to nullptr. - post_order[FindOrDie(post_order_index, operand)] = nullptr; - post_order_index.erase(operand); - + do_not_duplicate.erase(operand); + // Operand is now dead. Remove from queue. + fusion_queue->RemoveInstruction(operand); // Remove from computation. TF_RETURN_IF_ERROR(computation_->RemoveInstruction(operand)); } + + if (fusion_instruction != instruction) { + do_not_duplicate.erase(instruction); + } break; } } -- cgit v1.2.3