aboutsummaryrefslogtreecommitdiffhomepage
path: root/tensorflow/compiler/jit/mark_for_compilation_pass.cc
diff options
context:
space:
mode:
Diffstat (limited to 'tensorflow/compiler/jit/mark_for_compilation_pass.cc')
-rw-r--r--tensorflow/compiler/jit/mark_for_compilation_pass.cc23
1 files changed, 17 insertions, 6 deletions
diff --git a/tensorflow/compiler/jit/mark_for_compilation_pass.cc b/tensorflow/compiler/jit/mark_for_compilation_pass.cc
index 8c3882116d..38eb6d830f 100644
--- a/tensorflow/compiler/jit/mark_for_compilation_pass.cc
+++ b/tensorflow/compiler/jit/mark_for_compilation_pass.cc
@@ -21,6 +21,7 @@ limitations under the License.
#include <unordered_map>
#include <unordered_set>
+#include "tensorflow/compiler/jit/deadness_analysis.h"
#include "tensorflow/compiler/jit/defs.h"
#include "tensorflow/compiler/jit/graphcycles/graphcycles.h"
#include "tensorflow/compiler/jit/legacy_flags/mark_for_compilation_pass_flags.h"
@@ -28,6 +29,7 @@ limitations under the License.
#include "tensorflow/compiler/jit/xla_cluster_util.h"
#include "tensorflow/compiler/tf2xla/dump_graph.h"
#include "tensorflow/compiler/tf2xla/xla_op_registry.h"
+#include "tensorflow/compiler/xla/util.h"
#include "tensorflow/core/common_runtime/function.h"
#include "tensorflow/core/framework/graph_def_util.h"
#include "tensorflow/core/framework/memory_types.h"
@@ -462,18 +464,19 @@ Status MarkForCompilationPass::Run(
VLOG(1) << "flags->tf_xla_fusion_only = " << flags->tf_xla_fusion_only;
const FunctionLibraryDefinition* fld = options.flib_def;
- auto is_compilable = [global_jit_level, cpu_global_jit, fusion_only, fld](
- const Node* node, const DeviceType& device_type) {
+ std::unique_ptr<DeadnessAnalysis> deadness;
+ {
+ XLA_SCOPED_LOGGING_TIMER_LEVEL("DeadnessAnalysis", 1);
+ TF_RETURN_IF_ERROR(DeadnessAnalysis::Run(**options.graph, &deadness));
+ }
+
+ auto is_compilable = [&](const Node* node, const DeviceType& device_type) {
const XlaOpRegistry::DeviceRegistration* registration;
if (!XlaOpRegistry::GetCompilationDevice(device_type.type(),
&registration)) {
return false;
}
- // Don't compile control trigger nodes. We won't preserve their deadness
- // semantics correctly, so it's safest not to compile them.
- if (node->IsControlTrigger()) return false;
-
// If this device requires a JIT, we must say yes.
if (registration->requires_compilation) return true;
@@ -485,6 +488,14 @@ Status MarkForCompilationPass::Run(
status = fld->GetAttr(*node, kXlaCompileAttr, &compile);
if (status.ok()) return compile;
+ // If inputs to `node` can have conflicting deadness (i.e. some are alive
+ // and some are dead) then don't compile it. XLA cannot represent the
+ // deadness semantics of these nodes correctly and auto-clustering these
+ // nodes can cause deadness to propagate to nodes that should be live.
+ if (node->IsMerge() || deadness->HasInputsWithMismatchingDeadness(*node)) {
+ return false;
+ }
+
// Check for fusable ops only if requested.
if (global_jit_level > 0 && fusion_only && !IsXlaFusable(node->def())) {
return false;