aboutsummaryrefslogtreecommitdiffhomepage
diff options
context:
space:
mode:
authorGravatar A. Unique TensorFlower <gardener@tensorflow.org>2017-03-15 16:59:17 -0800
committerGravatar TensorFlower Gardener <gardener@tensorflow.org>2017-03-15 18:10:30 -0700
commitb05a83916f21becf59eff4e9db1d375eeb0fe904 (patch)
tree7c1e242cd968cb9738cc4d5546e5613ed7a49126
parent3d489f2ef23b540aa835d2182b12c1830833b4f0 (diff)
[TF:XLA] Don't compile functions that are marked "noinline".
The underlying function mechanism uses LocalExecutor to call the function, which interacts poorly with the LocalExecutor used by tf2xla to translate the TF graph into XLA. Change: 150268961
-rw-r--r--tensorflow/compiler/jit/mark_for_compilation_pass.cc12
-rw-r--r--tensorflow/compiler/jit/mark_for_compilation_pass_test.cc15
-rw-r--r--tensorflow/compiler/tests/function_test.py3
-rw-r--r--tensorflow/compiler/tests/jit_test.py12
4 files changed, 33 insertions, 9 deletions
diff --git a/tensorflow/compiler/jit/mark_for_compilation_pass.cc b/tensorflow/compiler/jit/mark_for_compilation_pass.cc
index 403e41a4fe..22dbf7ec99 100644
--- a/tensorflow/compiler/jit/mark_for_compilation_pass.cc
+++ b/tensorflow/compiler/jit/mark_for_compilation_pass.cc
@@ -116,6 +116,18 @@ bool IsCompilableCall(const NodeDef& call_def, DeviceType jit_device_type,
}
const FunctionBody* fbody = lib_runtime->GetFunctionBody(handle);
CHECK(fbody);
+ const FunctionDef& fdef = fbody->fdef;
+ bool noinline = false;
+ if (GetNodeAttr(AttrSlice(&fdef.attr()), "_noinline", &noinline).ok() &&
+ noinline) {
+ // The underlying mechanism that calls non-inlined functions uses
+ // LocalExecutor, which interacts poorly with the LocalExecutor used by
+ // tf2xla to translate the TF graph into XLA. So we avoid this for now.
+ //
+ // TODO(b/36139787): Create a mechanism to set inlining hints.
+ VLOG(2) << "Can't compile noinline function: " << fdef.DebugString();
+ return false;
+ }
for (Node* node : fbody->graph->nodes()) {
if (node->IsSource() || node->IsSink()) continue;
diff --git a/tensorflow/compiler/jit/mark_for_compilation_pass_test.cc b/tensorflow/compiler/jit/mark_for_compilation_pass_test.cc
index 78ec713937..91e4a2b41c 100644
--- a/tensorflow/compiler/jit/mark_for_compilation_pass_test.cc
+++ b/tensorflow/compiler/jit/mark_for_compilation_pass_test.cc
@@ -184,13 +184,20 @@ TEST(XlaCompilationTest, ConcatWithConstArg) {
}
TEST(XlaCompilationTest, FunctionCalls) {
- FunctionDefLibrary flib;
- *flib.add_function() = FunctionDefHelper::Define(
+ FunctionDef compilable = FunctionDefHelper::Define(
"CompilableFn", {"n_a:float", "n_b:float"}, {"n_c:float"}, {},
{{{"n_c"}, "Add", {"n_a", "n_b"}, {{"T", DT_FLOAT}}}});
- *flib.add_function() =
+ FunctionDef uncompilable =
FunctionDefHelper::Define("UncompilableFn", {"n_a:float"}, {"n_c:float"},
{}, {{{"n_c"}, "UncompilableUnary", {"n_a"}}});
+ FunctionDef noinline = compilable;
+ noinline.mutable_signature()->set_name("NoInlineFn");
+ AddAttr("_noinline", bool(true), noinline.mutable_attr());
+
+ FunctionDefLibrary flib;
+ *flib.add_function() = compilable;
+ *flib.add_function() = uncompilable;
+ *flib.add_function() = noinline;
FunctionLibraryDefinition flib_def(OpRegistry::Global(), flib);
std::unique_ptr<Graph> graph(new Graph(&flib_def));
@@ -202,6 +209,7 @@ TEST(XlaCompilationTest, FunctionCalls) {
Node* b = ops::BinaryOp("CompilableFn", a, a, builder.opts().WithName("B"));
Node* c = ops::UnaryOp("Relu", b, builder.opts().WithName("C"));
ops::UnaryOp("UncompilableFn", c, builder.opts().WithName("D"));
+ ops::BinaryOp("NoInlineFn", c, c, builder.opts().WithName("E"));
TF_EXPECT_OK(builder.ToGraph(graph.get()));
}
@@ -213,6 +221,7 @@ TEST(XlaCompilationTest, FunctionCalls) {
EXPECT_EQ(clusters["B"], clusters["C"]);
EXPECT_TRUE(clusters.find("A") == clusters.cend());
EXPECT_TRUE(clusters.find("D") == clusters.cend());
+ EXPECT_TRUE(clusters.find("E") == clusters.cend());
}
// Metadata-only operators such as Shape/Rank/Size may not be the root of a
diff --git a/tensorflow/compiler/tests/function_test.py b/tensorflow/compiler/tests/function_test.py
index 40cc7a5d60..cbe2888696 100644
--- a/tensorflow/compiler/tests/function_test.py
+++ b/tensorflow/compiler/tests/function_test.py
@@ -103,7 +103,8 @@ class FunctionTest(XLATestCase):
result = sess.run(call_f)
self.assertAllClose(result, expected, rtol=1e-3)
- def testFunctionsNoInline(self):
+ # TODO(b/36139787): Re-enable this test when noinline works again.
+ def DISABLED_testFunctionsNoInline(self):
@function.Defun(dtypes.float32, noinline=True)
def TimesTwo(x):
diff --git a/tensorflow/compiler/tests/jit_test.py b/tensorflow/compiler/tests/jit_test.py
index 8a568d6d58..11914080ec 100644
--- a/tensorflow/compiler/tests/jit_test.py
+++ b/tensorflow/compiler/tests/jit_test.py
@@ -160,12 +160,14 @@ class JitLaunchTest(test.TestCase):
# function (say, Bar) which is not inlined. When the compiler compiles
# Foo, it needs to symbolic execute Bar correctly regardless whether
# Bar is inlined or not.
- #
+
+ # TODO(b/36139787): Re-enable this test when noinline works again.
# Tests compiled=True and noinline=True.
- self._compare(
- AddOnceReturnTwice, [np.array(
- [[[0.5, -1.0]]], dtype=np.float32)],
- noinline=True)
+ # self._compare(
+ # AddOnceReturnTwice, [np.array(
+ # [[[0.5, -1.0]]], dtype=np.float32)],
+ # noinline=True)
+
# Tests compiled=True and noinline=False.
self._compare(
AddOnceReturnTwice, [np.array(