-rw-r--r--  tensorflow/compiler/jit/create_xla_launch_op.cc  | 22
-rw-r--r--  tensorflow/compiler/tests/eager_test.py          | 16
-rw-r--r--  tensorflow/core/common_runtime/eager/execute.cc  | 14
3 files changed, 43 insertions, 9 deletions
diff --git a/tensorflow/compiler/jit/create_xla_launch_op.cc b/tensorflow/compiler/jit/create_xla_launch_op.cc
index 731b8ebfdc..a2e6285339 100644
--- a/tensorflow/compiler/jit/create_xla_launch_op.cc
+++ b/tensorflow/compiler/jit/create_xla_launch_op.cc
@@ -66,8 +66,28 @@ class SinglePassSearch {
 
 Status CompilationRequested(const FunctionLibraryRuntime& flr,
                             const NodeDef& node_def) {
+  const FunctionDef* function_def =
+      flr.GetFunctionLibraryDefinition()->Find(node_def.name());
+  if (function_def == nullptr) {
+    // The node def is not calling a function. Individual ops can be
+    // run directly using on-demand mode, no need to create XlaLaunch
+    // kernel for them.
+    // TODO(b/110359382): Make custom kernel creation return a bool instead of
+    // status.
+    // We don't set error messages here to avoid unnecessary string copy.
+    // Similarly below.
+    return Status(error::INVALID_ARGUMENT, "");
+  }
+
+  // If kXlaCompileAttr is set on the node_def, use its value.
+  const auto& it = node_def.attr().find(kXlaCompileAttr);
+  if (it != node_def.attr().end()) {
+    return it->second.b() ? Status::OK() : Status(error::INVALID_ARGUMENT, "");
+  }
+
+  // kXlaCompileAttr is not set on node_def, check if it is set on
+  // FunctionDef.
   bool xla_compile = false;
-  // Check if op is marked _XlaCompile=true.
   Status status = flr.GetFunctionLibraryDefinition()->GetAttr(
       node_def, kXlaCompileAttr, &xla_compile);
   if (!status.ok() || !xla_compile) {
diff --git a/tensorflow/compiler/tests/eager_test.py b/tensorflow/compiler/tests/eager_test.py
index 3bb3049e87..e438832a23 100644
--- a/tensorflow/compiler/tests/eager_test.py
+++ b/tensorflow/compiler/tests/eager_test.py
@@ -290,7 +290,7 @@ class EagerFunctionTest(XLATestCase):
 
   def testBasic(self):
     with self.test_scope():
-      matmul = function.defun(math_ops.matmul, compiled=True)
+      matmul = function.defun(math_ops.matmul)
       t = constant_op.constant([[1.0, 2.0], [3.0, 4.0]])
       sq = matmul(t, t, transpose_a=True)
       self.assertAllEqual(sq.numpy().reshape(-1), [10, 14, 14, 20])
@@ -312,7 +312,7 @@ class EagerFunctionTest(XLATestCase):
       def model(x):
         x = conv(x)
         return pool(x)
-      model = function.defun(model, compiled=True)
+      model = function.defun(model)
 
       x = array_ops.ones([1, 4, 4, 1])
       y = model(x)
@@ -322,7 +322,7 @@ class EagerFunctionTest(XLATestCase):
     with self.test_scope():
       v = resource_variable_ops.ResourceVariable(1.0)
 
-      @function.defun(compiled=True)
+      @function.defun
       def f():
         return v.read_value()
 
@@ -337,7 +337,7 @@ class EagerFunctionTest(XLATestCase):
         v.assign_add(1.0)
         return v
 
-      f = function.defun(f, compiled=True)
+      f = function.defun(f)
 
       var = f(v)
       self.assertEqual(2.0, var.numpy())
@@ -365,7 +365,7 @@ class EagerFunctionTest(XLATestCase):
         d = r2 * v2
         return a, b, c, d
 
-      foo = function.defun(foo, compiled=True)
+      foo = function.defun(foo)
 
       c1 = [0, 0]
       c2 = array_ops.ones([2], dtype=dtypes.int32)
@@ -387,7 +387,7 @@ class EagerFunctionTest(XLATestCase):
     with self.test_scope():
      v0 = resource_variable_ops.ResourceVariable(5.0)
 
-      @function.defun(compiled=True)
+      @function.defun
       def f(x):
         x = v0 * v0 * x
         return x
@@ -450,7 +450,7 @@ class ExcessivePaddingTest(XLATestCase):
 
   def testAsFunctionInput(self):
     with self.test_scope():
-      @function.defun(compiled=True)
+      @function.defun
       def f(x):
         return math_ops.reduce_sum(x, axis=2)
 
@@ -461,7 +461,7 @@ class ExcessivePaddingTest(XLATestCase):
 
   def testAsFunctionOutput(self):
     with self.test_scope():
-      @function.defun(compiled=True)
+      @function.defun
       def f(x):
         return x * constant_op.constant(100 * [[[10.0, 2.0]]])
 
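The tests above exercise the updated call pattern: function.defun no longer takes a compiled=True argument, and the decision to compile with XLA is made by the runtime instead. A minimal sketch of that pattern, using the same modules as the test file and assuming eager execution is already enabled:

    from tensorflow.python.eager import function
    from tensorflow.python.framework import constant_op
    from tensorflow.python.ops import math_ops

    # Was: function.defun(math_ops.matmul, compiled=True)
    matmul = function.defun(math_ops.matmul)

    t = constant_op.constant([[1.0, 2.0], [3.0, 4.0]])
    sq = matmul(t, t, transpose_a=True)  # t^T @ t
    print(sq.numpy())                    # [[10. 14.] [14. 20.]]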
diff --git a/tensorflow/core/common_runtime/eager/execute.cc b/tensorflow/core/common_runtime/eager/execute.cc
index c619857b78..08abded4e4 100644
--- a/tensorflow/core/common_runtime/eager/execute.cc
+++ b/tensorflow/core/common_runtime/eager/execute.cc
@@ -39,6 +39,11 @@ namespace tensorflow {
 
 namespace {
 
+// Copy of the definition in third_party/tensorflow/compiler/jit/defs.h
+// Copied here because we don't currently compile XLA on windows. So, can't
+// depend on it directly.
+const char* const kXlaCompileAttr = "_XlaCompile";
+
 // Initializes the step stats if needed.
 void MaybeInitializeStepStats(StepStats* step_stats, EagerContext* ctx) {
   // Lazily initialize the RunMetadata with information about all devices if
@@ -472,6 +477,15 @@ Status EagerLocalExecute(EagerOperation* op,
                            device == nullptr ? "unspecified" : device->name());
   KernelAndDevice* kernel = ctx->GetCachedKernel(cache_key);
   if (kernel == nullptr) {
+    // If we are running a function on explicitly requested TPU,
+    // compile it with XLA.
+    // Note that it is not ideal, but currently ok, to set this
+    // attribute after computing the kernel cache key above.
+    if (op->is_function() && device != nullptr &&
+        device->device_type() == "TPU") {
+      op->MutableAttrs()->Set(kXlaCompileAttr, true);
+    }
+
     const NodeDef& ndef = op->MutableAttrs()->BuildNodeDef();
     if (device == nullptr) {
       status = SelectDevice(ndef, ctx, &device);
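A hypothetical end-to-end sketch of the behavior added in execute.cc: when a defun-wrapped function is explicitly placed on a TPU device, EagerLocalExecute sets _XlaCompile=true on it, and create_xla_launch_op.cc then builds an XlaLaunch kernel for it. The square function and the device name below are illustrative only, and the snippet assumes eager execution is enabled and a TPU device is visible to the eager context:

    from tensorflow.python.eager import function
    from tensorflow.python.framework import constant_op, ops
    from tensorflow.python.ops import math_ops

    @function.defun
    def square(x):
      return math_ops.matmul(x, x)

    # Explicit TPU placement; the runtime marks the function _XlaCompile=true.
    with ops.device('/device:TPU:0'):
      t = constant_op.constant([[1.0, 2.0], [3.0, 4.0]])
      print(square(t))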