diff options
author | 2017-03-15 16:59:17 -0800 | |
---|---|---|
committer | 2017-03-15 18:10:30 -0700 | |
commit | b05a83916f21becf59eff4e9db1d375eeb0fe904 (patch) | |
tree | 7c1e242cd968cb9738cc4d5546e5613ed7a49126 | |
parent | 3d489f2ef23b540aa835d2182b12c1830833b4f0 (diff) |
[TF:XLA] Don't compile functions that are marked "noinline".
The underlying function mechanism uses LocalExecutor to call the function,
which interacts poorly with the LocalExecutor used by tf2xla to translate
the TF graph into XLA.
Change: 150268961
-rw-r--r-- | tensorflow/compiler/jit/mark_for_compilation_pass.cc | 12 | ||||
-rw-r--r-- | tensorflow/compiler/jit/mark_for_compilation_pass_test.cc | 15 | ||||
-rw-r--r-- | tensorflow/compiler/tests/function_test.py | 3 | ||||
-rw-r--r-- | tensorflow/compiler/tests/jit_test.py | 12 |
4 files changed, 33 insertions, 9 deletions
diff --git a/tensorflow/compiler/jit/mark_for_compilation_pass.cc b/tensorflow/compiler/jit/mark_for_compilation_pass.cc index 403e41a4fe..22dbf7ec99 100644 --- a/tensorflow/compiler/jit/mark_for_compilation_pass.cc +++ b/tensorflow/compiler/jit/mark_for_compilation_pass.cc @@ -116,6 +116,18 @@ bool IsCompilableCall(const NodeDef& call_def, DeviceType jit_device_type, } const FunctionBody* fbody = lib_runtime->GetFunctionBody(handle); CHECK(fbody); + const FunctionDef& fdef = fbody->fdef; + bool noinline = false; + if (GetNodeAttr(AttrSlice(&fdef.attr()), "_noinline", &noinline).ok() && + noinline) { + // The underlying mechanism that calls non-inlined functions uses + // LocalExecutor, which interacts poorly with the LocalExecutor used by + // tf2xla to translate the TF graph into XLA. So we avoid this for now. + // + // TODO(b/36139787): Create a mechanism to set inlining hints. + VLOG(2) << "Can't compile noinline function: " << fdef.DebugString(); + return false; + } for (Node* node : fbody->graph->nodes()) { if (node->IsSource() || node->IsSink()) continue; diff --git a/tensorflow/compiler/jit/mark_for_compilation_pass_test.cc b/tensorflow/compiler/jit/mark_for_compilation_pass_test.cc index 78ec713937..91e4a2b41c 100644 --- a/tensorflow/compiler/jit/mark_for_compilation_pass_test.cc +++ b/tensorflow/compiler/jit/mark_for_compilation_pass_test.cc @@ -184,13 +184,20 @@ TEST(XlaCompilationTest, ConcatWithConstArg) { } TEST(XlaCompilationTest, FunctionCalls) { - FunctionDefLibrary flib; - *flib.add_function() = FunctionDefHelper::Define( + FunctionDef compilable = FunctionDefHelper::Define( "CompilableFn", {"n_a:float", "n_b:float"}, {"n_c:float"}, {}, {{{"n_c"}, "Add", {"n_a", "n_b"}, {{"T", DT_FLOAT}}}}); - *flib.add_function() = + FunctionDef uncompilable = FunctionDefHelper::Define("UncompilableFn", {"n_a:float"}, {"n_c:float"}, {}, {{{"n_c"}, "UncompilableUnary", {"n_a"}}}); + FunctionDef noinline = compilable; + noinline.mutable_signature()->set_name("NoInlineFn"); + AddAttr("_noinline", bool(true), noinline.mutable_attr()); + + FunctionDefLibrary flib; + *flib.add_function() = compilable; + *flib.add_function() = uncompilable; + *flib.add_function() = noinline; FunctionLibraryDefinition flib_def(OpRegistry::Global(), flib); std::unique_ptr<Graph> graph(new Graph(&flib_def)); @@ -202,6 +209,7 @@ TEST(XlaCompilationTest, FunctionCalls) { Node* b = ops::BinaryOp("CompilableFn", a, a, builder.opts().WithName("B")); Node* c = ops::UnaryOp("Relu", b, builder.opts().WithName("C")); ops::UnaryOp("UncompilableFn", c, builder.opts().WithName("D")); + ops::BinaryOp("NoInlineFn", c, c, builder.opts().WithName("E")); TF_EXPECT_OK(builder.ToGraph(graph.get())); } @@ -213,6 +221,7 @@ TEST(XlaCompilationTest, FunctionCalls) { EXPECT_EQ(clusters["B"], clusters["C"]); EXPECT_TRUE(clusters.find("A") == clusters.cend()); EXPECT_TRUE(clusters.find("D") == clusters.cend()); + EXPECT_TRUE(clusters.find("E") == clusters.cend()); } // Metadata-only operators such as Shape/Rank/Size may not be the root of a diff --git a/tensorflow/compiler/tests/function_test.py b/tensorflow/compiler/tests/function_test.py index 40cc7a5d60..cbe2888696 100644 --- a/tensorflow/compiler/tests/function_test.py +++ b/tensorflow/compiler/tests/function_test.py @@ -103,7 +103,8 @@ class FunctionTest(XLATestCase): result = sess.run(call_f) self.assertAllClose(result, expected, rtol=1e-3) - def testFunctionsNoInline(self): + # TODO(b/36139787): Re-enable this test when noinline works again. + def DISABLED_testFunctionsNoInline(self): @function.Defun(dtypes.float32, noinline=True) def TimesTwo(x): diff --git a/tensorflow/compiler/tests/jit_test.py b/tensorflow/compiler/tests/jit_test.py index 8a568d6d58..11914080ec 100644 --- a/tensorflow/compiler/tests/jit_test.py +++ b/tensorflow/compiler/tests/jit_test.py @@ -160,12 +160,14 @@ class JitLaunchTest(test.TestCase): # function (say, Bar) which is not inlined. When the compiler compiles # Foo, it needs to symbolic execute Bar correctly regardless whether # Bar is inlined or not. - # + + # TODO(b/36139787): Re-enable this test when noinline works again. # Tests compiled=True and noinline=True. - self._compare( - AddOnceReturnTwice, [np.array( - [[[0.5, -1.0]]], dtype=np.float32)], - noinline=True) + # self._compare( + # AddOnceReturnTwice, [np.array( + # [[[0.5, -1.0]]], dtype=np.float32)], + # noinline=True) + # Tests compiled=True and noinline=False. self._compare( AddOnceReturnTwice, [np.array( |