aboutsummaryrefslogtreecommitdiffhomepage
path: root/tensorflow/compiler/jit/mark_for_compilation_pass.cc
diff options
context:
space:
mode:
Diffstat (limited to 'tensorflow/compiler/jit/mark_for_compilation_pass.cc')
-rw-r--r--tensorflow/compiler/jit/mark_for_compilation_pass.cc102
1 files changed, 60 insertions, 42 deletions
diff --git a/tensorflow/compiler/jit/mark_for_compilation_pass.cc b/tensorflow/compiler/jit/mark_for_compilation_pass.cc
index 45d422943c..90d5d56998 100644
--- a/tensorflow/compiler/jit/mark_for_compilation_pass.cc
+++ b/tensorflow/compiler/jit/mark_for_compilation_pass.cc
@@ -65,6 +65,7 @@ bool HasXLAKernel(const Node& node, const DeviceType& jit_device_type) {
// XLA cluster so it can't implement the forward-tensor-ref semantic. Leave
// such nodes out of XLA clusters.
if (HasForwardedRefInput(node)) {
+ VLOG(2) << "Rejecting " << node.name() << ": Identity with unsafe cast.";
return false;
}
@@ -84,14 +85,13 @@ bool IsCompilableCall(const NodeDef& call_def,
bool IsCompilableWhile(const Node& while_node,
const DeviceType& jit_device_type, int depth,
FunctionLibraryRuntime* lib_runtime) {
- VLOG(2) << "Loop marking: " << while_node.type_string();
-
const NameAttrList* name_attr;
NodeDef call;
Status status;
status = GetNodeAttr(while_node.attrs(), "cond", &name_attr);
if (!status.ok()) {
- VLOG(2) << "Missing 'cond' attribute on While node.";
+ VLOG(2) << "Rejecting While " << while_node.name()
+ << ": missing 'cond' attribute on While node.";
return false;
}
const string cond_func = name_attr->name();
@@ -99,12 +99,14 @@ bool IsCompilableWhile(const Node& while_node,
call.set_op(cond_func);
*call.mutable_attr() = name_attr->attr();
if (!IsCompilableCall(call, jit_device_type, depth + 1, lib_runtime)) {
- VLOG(2) << "Can't compile loop condition: " << cond_func;
+ VLOG(2) << "Rejecting While " << while_node.name()
+ << ": can't compile loop condition: " << cond_func;
return false;
}
status = GetNodeAttr(while_node.attrs(), "body", &name_attr);
if (!status.ok()) {
- VLOG(2) << "Missing 'body' attribute on While node.";
+ VLOG(2) << "Rejecting While " << while_node.name()
+ << ": missing 'body' attribute on While node.";
return false;
}
const string body_func = name_attr->name();
@@ -112,10 +114,10 @@ bool IsCompilableWhile(const Node& while_node,
call.set_op(body_func);
*call.mutable_attr() = name_attr->attr();
if (!IsCompilableCall(call, jit_device_type, depth + 1, lib_runtime)) {
- VLOG(2) << "Can't compile loop body: " << body_func;
+ VLOG(2) << "Rejecting While " << while_node.name()
+ << ": can't compile loop body: " << body_func;
return false;
}
- VLOG(2) << "Loop is compilable.";
return true;
}
@@ -125,10 +127,9 @@ bool IsCompilableWhile(const Node& while_node,
bool IsCompilableCall(const NodeDef& call_def,
const DeviceType& jit_device_type, int depth,
FunctionLibraryRuntime* lib_runtime) {
- VLOG(2) << "Function marking: " << call_def.op();
-
if (depth > kMaxRecursionDepth) {
- VLOG(2) << "Function depth limit exceeded";
+ VLOG(2) << "Rejecting " << call_def.op()
+ << ": function depth limit exceeded.";
return false;
}
@@ -136,7 +137,8 @@ bool IsCompilableCall(const NodeDef& call_def,
Status status =
lib_runtime->Instantiate(call_def.op(), AttrSlice(call_def), &handle);
if (!status.ok()) {
- VLOG(2) << "Could not instantiate " << call_def.op() << ": " << status;
+ VLOG(2) << "Rejecting " << call_def.op()
+ << ": could not instantiate: " << status;
return false;
}
const FunctionBody* fbody = lib_runtime->GetFunctionBody(handle);
@@ -150,7 +152,8 @@ bool IsCompilableCall(const NodeDef& call_def,
// tf2xla to translate the TF graph into XLA. So we avoid this for now.
//
// TODO(b/36139787): Create a mechanism to set inlining hints.
- VLOG(2) << "Can't compile noinline function: " << fdef.DebugString();
+ VLOG(2) << "Rejecting " << call_def.op()
+ << ": can't compile noinline function.";
return false;
}
@@ -164,23 +167,14 @@ bool IsCompilableCall(const NodeDef& call_def,
if (!HasXLAKernel(*node, jit_device_type) &&
!IsCompilableCall(node->def(), jit_device_type, depth + 1,
lib_runtime)) {
- VLOG(2) << "Function marking failed: unsupported op " << node->name()
- << ": " << node->def().ShortDebugString();
+ VLOG(2) << "Rejecting " << call_def.op() << ": unsupported op "
+ << node->name() << ": " << node->def().ShortDebugString();
return false;
}
}
- VLOG(2) << "Function is compilable: " << call_def.op();
return true;
}
-// Tests whether `node` has a DT_RESOURCE typed input or output.
-bool HasResourceInputOrOutput(const Node& node) {
- return std::find(node.input_types().begin(), node.input_types().end(),
- DT_RESOURCE) != node.input_types().end() ||
- std::find(node.output_types().begin(), node.output_types().end(),
- DT_RESOURCE) != node.output_types().end();
-}
-
// Returns true if the op can be decomposed into XLA ops for which
// there are fusable elemental implementations.
//
@@ -357,24 +351,27 @@ Status FindCompilationCandidates(
}
std::sort(sorted_nodes.begin(), sorted_nodes.end(), NodeComparatorID());
+ if (fuel >= std::numeric_limits<int64>::max() / 2) {
+ // The assumption is that if fuel started out as INT64_MAX, it will forever
+ // stay greater than INT64_MAX / 2.
+ VLOG(2) << "Starting fuel: infinity";
+ } else {
+ VLOG(2) << "Starting fuel: " << fuel;
+ }
+
for (Node* node : sorted_nodes) {
- VLOG(2) << "Fuel: " << fuel;
if (fuel <= 0) {
- VLOG(2)
+ VLOG(1)
<< "Hit fuel limit; not marking any remaining ops as clusterable.";
break;
}
- VLOG(2) << "FindCompilationCandidates(): Processing "
- << node->DebugString();
-
DeviceType device_type("");
TF_RETURN_IF_ERROR(
DeviceToDeviceType(node->assigned_device_name(), &device_type));
if (is_compilable_fn && !is_compilable_fn(node, device_type)) {
- VLOG(2) << "Compilation rejected node: not compilable " << node->name()
- << ": " << node->type_string();
+ // is_compilable_fn has already logged the reason if it returned false.
continue;
}
@@ -384,14 +381,14 @@ Status FindCompilationCandidates(
DeviceType jit_device_type(registration->compilation_device_name);
if (!HasXLAKernel(*node, jit_device_type) &&
!IsCompilableCall(node->def(), jit_device_type, 0, lib_runtime)) {
- VLOG(2) << "Compilation rejected node: unsupported op " << node->name()
- << ": " << node->type_string();
+ VLOG(2) << "Rejecting " << node->name() << ": unsupported op "
+ << node->type_string();
continue;
}
if (!registration->compile_resource_ops &&
HasResourceInputOrOutput(*node)) {
- VLOG(2) << "Compilation rejected node: resource input/output "
- << node->name() << ": " << node->type_string();
+ VLOG(2) << "Rejecting: " << node->name() << ": resource input/output "
+ << node->type_string();
continue;
}
if (node->type_string() == "While" &&
@@ -401,15 +398,11 @@ Status FindCompilationCandidates(
// _Arg nodes in a top-level function represent feeds.
// Do not compile them.
if (node->type_string() == "_Arg") {
- VLOG(2) << "Skipping jit compilation for '_Arg'-typed node "
- << node->DebugString();
continue;
}
// _Retval nodes in a top-level function represent fetches.
// Do not compile them.
if (node->type_string() == "_Retval") {
- VLOG(2) << "Compilation rejected node: return value " << node->name()
- << ": " << node->type_string();
continue;
}
candidates->insert(node);
@@ -475,6 +468,7 @@ Status MarkForCompilationPass::Run(
const XlaOpRegistry::DeviceRegistration* registration;
if (!XlaOpRegistry::GetCompilationDevice(device_type.type(),
&registration)) {
+ VLOG(2) << "Rejecting " << node->name() << ": could not find JIT device.";
return false;
}
@@ -484,21 +478,36 @@ Status MarkForCompilationPass::Run(
// If there is a _XlaCompile annotation, use its value.
bool compile = false;
Status status = GetNodeAttr(node->attrs(), kXlaCompileAttr, &compile);
- if (status.ok()) return compile;
+ if (status.ok()) {
+ if (!compile) {
+ VLOG(2) << "Rejecting " << node->name() << ": kXlaCompileAttr("
+ << kXlaCompileAttr << ") is false.";
+ }
+ return compile;
+ }
status = fld->GetAttr(*node, kXlaCompileAttr, &compile);
- if (status.ok()) return compile;
+ if (status.ok()) {
+ if (!compile) {
+ VLOG(2) << "Rejecting " << node->name() << ": kXlaCompileAttr("
+ << kXlaCompileAttr << ") on callee is false.";
+ }
+ return compile;
+ }
// If inputs to `node` can have conflicting deadness (i.e. some are alive
// and some are dead) then don't compile it. XLA cannot represent the
// deadness semantics of these nodes correctly and auto-clustering these
// nodes can cause deadness to propagate to nodes that should be live.
if (node->IsMerge() || deadness->HasInputsWithMismatchingDeadness(*node)) {
+ VLOG(2) << "Rejecting " << node->name() << ": mismatching deadness.";
return false;
}
// Check for fusable ops only if requested.
if (global_jit_level > 0 && fusion_only && !IsXlaFusable(node->def())) {
+ VLOG(2) << "Rejecting " << node->name()
+ << ": not fusable op but fusion_only enabled.";
return false;
}
@@ -506,8 +515,17 @@ Status MarkForCompilationPass::Run(
// Ignore enable_jit_by_default if global jit compilation for CPU
// is explicitly requested via tf_xla_cpu_global_jit flag
bool ignore_registration = cpu_global_jit && device_type == DEVICE_CPU;
- return (ignore_registration || registration->enable_jit_by_default) &&
- global_jit_level > 0;
+ bool should_compile =
+ (ignore_registration || registration->enable_jit_by_default) &&
+ global_jit_level > 0;
+ if (!should_compile) {
+ if (global_jit_level <= 0) {
+ VLOG(2) << "Rejecting " << node->name() << ": global jit disabled.";
+ } else {
+ VLOG(2) << "Rejecting " << node->name() << ": JIT for device disabled.";
+ }
+ }
+ return should_compile;
};
return RunImpl(options, is_compilable);
}