about | summary | refs | log | tree | commit | diff | homepage
path: root/tensorflow/compiler
diff options
context:
space:
mode:
author: Yuanzhong Xu <yuanzx@google.com> 2018-09-27 15:38:48 -0700
committer: TensorFlower Gardener <gardener@tensorflow.org> 2018-09-27 15:48:55 -0700
commit: bfec3d54fed955a4b145220e64c48b94fbb04ae7 (patch)
tree: a52d4bac97107f0fb4153c18bf218edc6ba82976 /tensorflow/compiler
parent: 8f85a9de475f0acf0abef4fabc12943e2e487bf7 (diff)
[XLA] Use a result cache to speed up InstructionFusion::CanFuseOnAllPaths()
PiperOrigin-RevId: 214848216
Diffstat (limited to 'tensorflow/compiler')
-rw-r--r-- tensorflow/compiler/xla/service/instruction_fusion.cc | 29
-rw-r--r-- tensorflow/compiler/xla/service/instruction_fusion.h | 11
2 files changed, 30 insertions, 10 deletions
diff --git a/tensorflow/compiler/xla/service/instruction_fusion.cc b/tensorflow/compiler/xla/service/instruction_fusion.cc
index 3fdc2cee9a..e884122fcb 100644
--- a/tensorflow/compiler/xla/service/instruction_fusion.cc
+++ b/tensorflow/compiler/xla/service/instruction_fusion.cc
@@ -188,13 +188,20 @@ bool InstructionFusion::EffectivelyAtMostUnary(HloInstruction* hlo) {
bool InstructionFusion::CanFuseOnAllPaths(
HloInstruction* producer, HloInstruction* consumer,
- const HloInstructionSet& do_not_duplicate) {
+ const HloInstructionSet& do_not_fuse,
+ tensorflow::gtl::FlatMap<std::pair<HloInstruction*, HloInstruction*>, bool>*
+ result_cache) {
if (consumer == producer) {
return true;
}
if (!consumer->IsFusible()) {
return false;
}
+ auto cache_it = result_cache->find(std::make_pair(producer, consumer));
+ if (cache_it != result_cache->end()) {
+ return cache_it->second;
+ }
+ bool result = true;
for (int64 i = 0, e = consumer->operand_count(); i < e; ++i) {
auto* consumer_operand = consumer->mutable_operand(i);
// If the operand is not on a path to the producer, it doesn't matter
@@ -202,20 +209,23 @@ bool InstructionFusion::CanFuseOnAllPaths(
if (!reachability_->IsReachable(producer, consumer_operand)) {
continue;
}
- if (do_not_duplicate.count(consumer_operand) > 0 ||
- !ShouldFuse(consumer, i)) {
- return false;
+ if (do_not_fuse.count(consumer_operand) > 0 || !ShouldFuse(consumer, i)) {
+ result = false;
+ break;
}
// The producer is reachable from consumer_operand which means we need
// to be able to fuse consumer_operand into consumer in order for
// producer to be fusible into consumer on all paths.
// Perform the recursive step: make sure producer can be fused into
// consumer_operand on all paths.
- if (!CanFuseOnAllPaths(producer, consumer_operand, do_not_duplicate)) {
- return false;
+ if (!CanFuseOnAllPaths(producer, consumer_operand, do_not_fuse,
+ result_cache)) {
+ result = false;
+ break;
}
}
- return true;
+ result_cache->emplace(std::make_pair(producer, consumer), result);
+ return result;
}
InstructionFusion::HloInstructionSet
@@ -231,6 +241,8 @@ InstructionFusion::ComputeGloballyUnfusible(
// fusing operations that require duplication later depending on
// is_expensive_().
HloInstructionSet do_not_duplicate;
+ tensorflow::gtl::FlatMap<std::pair<HloInstruction*, HloInstruction*>, bool>
+ can_fuse_on_all_paths_result_cache;
for (HloInstruction* consumer : post_order) {
for (HloInstruction* producer : consumer->operands()) {
if (do_not_duplicate.count(producer) > 0) {
@@ -286,7 +298,8 @@ InstructionFusion::ComputeGloballyUnfusible(
// A will be not allowed to be fused into B, as it cannot be fused via
// all paths.
if (producer->IsFusible() &&
- CanFuseOnAllPaths(producer, consumer, do_not_duplicate)) {
+ CanFuseOnAllPaths(producer, consumer, do_not_duplicate,
+ &can_fuse_on_all_paths_result_cache)) {
continue;
}
do_not_duplicate.insert(producer);
diff --git a/tensorflow/compiler/xla/service/instruction_fusion.h b/tensorflow/compiler/xla/service/instruction_fusion.h
index 7e1196fb7f..c1ec3b18a1 100644
--- a/tensorflow/compiler/xla/service/instruction_fusion.h
+++ b/tensorflow/compiler/xla/service/instruction_fusion.h
@@ -151,8 +151,15 @@ class InstructionFusion : public HloModulePass {
// Whether or not we can fuse producer into consumer on all paths
// from the producer to the consumer where nodes are HLOs and edges are uses.
- bool CanFuseOnAllPaths(HloInstruction* producer, HloInstruction* consumer,
- const HloInstructionSet& do_not_fuse);
+ //
+ // A map from <producer, consumer> to a bool is required as the result cache
+ // to store and query the results of calls to this function, in order to avoid
+ // repeated computations.
+ bool CanFuseOnAllPaths(
+ HloInstruction* producer, HloInstruction* consumer,
+ const HloInstructionSet& do_not_fuse,
+ tensorflow::gtl::FlatMap<std::pair<HloInstruction*, HloInstruction*>,
+ bool>* result_cache);
// Computes the set of nodes that we do not want to fuse into any of their
// consumers based on a global analysis of the HLO graph.