diff options
author | Yuanzhong Xu <yuanzx@google.com> | 2018-09-27 15:38:48 -0700 |
---|---|---|
committer | TensorFlower Gardener <gardener@tensorflow.org> | 2018-09-27 15:48:55 -0700 |
commit | bfec3d54fed955a4b145220e64c48b94fbb04ae7 (patch) | |
tree | a52d4bac97107f0fb4153c18bf218edc6ba82976 /tensorflow/compiler | |
parent | 8f85a9de475f0acf0abef4fabc12943e2e487bf7 (diff) |
[XLA] Use a result cache to speed up InstructionFusion::CanFuseOnAllPaths()
PiperOrigin-RevId: 214848216
Diffstat (limited to 'tensorflow/compiler')
-rw-r--r-- | tensorflow/compiler/xla/service/instruction_fusion.cc | 29 |
-rw-r--r-- | tensorflow/compiler/xla/service/instruction_fusion.h | 11 |
2 files changed, 30 insertions(+), 10 deletions(-)
diff --git a/tensorflow/compiler/xla/service/instruction_fusion.cc b/tensorflow/compiler/xla/service/instruction_fusion.cc index 3fdc2cee9a..e884122fcb 100644 --- a/tensorflow/compiler/xla/service/instruction_fusion.cc +++ b/tensorflow/compiler/xla/service/instruction_fusion.cc @@ -188,13 +188,20 @@ bool InstructionFusion::EffectivelyAtMostUnary(HloInstruction* hlo) { bool InstructionFusion::CanFuseOnAllPaths( HloInstruction* producer, HloInstruction* consumer, - const HloInstructionSet& do_not_duplicate) { + const HloInstructionSet& do_not_fuse, + tensorflow::gtl::FlatMap<std::pair<HloInstruction*, HloInstruction*>, bool>* + result_cache) { if (consumer == producer) { return true; } if (!consumer->IsFusible()) { return false; } + auto cache_it = result_cache->find(std::make_pair(producer, consumer)); + if (cache_it != result_cache->end()) { + return cache_it->second; + } + bool result = true; for (int64 i = 0, e = consumer->operand_count(); i < e; ++i) { auto* consumer_operand = consumer->mutable_operand(i); // If the operand is not on a path to the producer, it doesn't matter @@ -202,20 +209,23 @@ bool InstructionFusion::CanFuseOnAllPaths( if (!reachability_->IsReachable(producer, consumer_operand)) { continue; } - if (do_not_duplicate.count(consumer_operand) > 0 || - !ShouldFuse(consumer, i)) { - return false; + if (do_not_fuse.count(consumer_operand) > 0 || !ShouldFuse(consumer, i)) { + result = false; + break; } // The producer is reachable from consumer_operand which means we need // to be able to fuse consumer_operand into consumer in order for // producer to be fusible into consumer on all paths. // Perform the recursive step: make sure producer can be fused into // consumer_operand on all paths. 
- if (!CanFuseOnAllPaths(producer, consumer_operand, do_not_duplicate)) { - return false; + if (!CanFuseOnAllPaths(producer, consumer_operand, do_not_fuse, + result_cache)) { + result = false; + break; } } - return true; + result_cache->emplace(std::make_pair(producer, consumer), result); + return result; } InstructionFusion::HloInstructionSet @@ -231,6 +241,8 @@ InstructionFusion::ComputeGloballyUnfusible( // fusing operations that require duplication later depending on // is_expensive_(). HloInstructionSet do_not_duplicate; + tensorflow::gtl::FlatMap<std::pair<HloInstruction*, HloInstruction*>, bool> + can_fuse_on_all_paths_result_cache; for (HloInstruction* consumer : post_order) { for (HloInstruction* producer : consumer->operands()) { if (do_not_duplicate.count(producer) > 0) { @@ -286,7 +298,8 @@ InstructionFusion::ComputeGloballyUnfusible( // A will be not allowed to be fused into B, as it cannot be fused via // all paths. if (producer->IsFusible() && - CanFuseOnAllPaths(producer, consumer, do_not_duplicate)) { + CanFuseOnAllPaths(producer, consumer, do_not_duplicate, + &can_fuse_on_all_paths_result_cache)) { continue; } do_not_duplicate.insert(producer); diff --git a/tensorflow/compiler/xla/service/instruction_fusion.h b/tensorflow/compiler/xla/service/instruction_fusion.h index 7e1196fb7f..c1ec3b18a1 100644 --- a/tensorflow/compiler/xla/service/instruction_fusion.h +++ b/tensorflow/compiler/xla/service/instruction_fusion.h @@ -151,8 +151,15 @@ class InstructionFusion : public HloModulePass { // Whether or not we can fuse producer into consumer on all paths // from the producer to the consumer where nodes are HLOs and edges are uses. - bool CanFuseOnAllPaths(HloInstruction* producer, HloInstruction* consumer, - const HloInstructionSet& do_not_fuse); + // + // A map from <producer, consumer> to a bool is required as the result cache + // to store and query the results of calls to this function, in order to avoid + // repeated computations. 
+ bool CanFuseOnAllPaths( + HloInstruction* producer, HloInstruction* consumer, + const HloInstructionSet& do_not_fuse, + tensorflow::gtl::FlatMap<std::pair<HloInstruction*, HloInstruction*>, + bool>* result_cache); // Computes the set of nodes that we do not want to fuse into any of their // consumers based on a global analysis of the HLO graph. |