about summary refs log tree commit diff homepage
path: root/tensorflow/compiler/jit/xla_fusion_optimizer.cc
diff options
context:
space:
mode:
Diffstat (limited to 'tensorflow/compiler/jit/xla_fusion_optimizer.cc')
-rw-r--r-- tensorflow/compiler/jit/xla_fusion_optimizer.cc 12
1 files changed, 12 insertions, 0 deletions
diff --git a/tensorflow/compiler/jit/xla_fusion_optimizer.cc b/tensorflow/compiler/jit/xla_fusion_optimizer.cc
index 74257b09a8..4b499b1613 100644
--- a/tensorflow/compiler/jit/xla_fusion_optimizer.cc
+++ b/tensorflow/compiler/jit/xla_fusion_optimizer.cc
@@ -20,6 +20,7 @@ limitations under the License.
#include <unordered_map>
#include <unordered_set>
+#include "tensorflow/compiler/jit/deadness_analysis.h"
#include "tensorflow/compiler/jit/defs.h"
#include "tensorflow/compiler/jit/graphcycles/graphcycles.h"
#include "tensorflow/compiler/jit/union_find.h"
@@ -146,6 +147,9 @@ Status XlaFusionOptimizer::Optimize(grappler::Cluster* cluster,
TF_RETURN_IF_ERROR(
ImportGraphDef(options, item.graph, &graph, &shape_refiner));
+ std::unique_ptr<DeadnessAnalysis> deadness;
+ TF_RETURN_IF_ERROR(DeadnessAnalysis::Run(graph, &deadness));
+
// Collect nodes that can be fused via XLA, while ignoring those that
// explicitly ask for XLA: (*) nodes that are marked to be compiled
// explicitly. (*) nodes assigned to XLA device.
@@ -185,6 +189,14 @@ Status XlaFusionOptimizer::Optimize(grappler::Cluster* cluster,
continue;
}
+ // If inputs to `node` can have conflicting deadness (i.e. some are alive
+ // and some are dead) then don't compile it. XLA cannot represent the
+ // deadness semantics of these nodes correctly and auto-clustering these
+ // nodes can cause deadness to propagate to nodes that should be live.
+ if (node->IsMerge() || deadness->HasInputsWithMismatchingDeadness(*node)) {
+ continue;
+ }
+
compilation_candidates.insert(node);
}