about summary refs log tree commit diff homepage
diff options
context:
space:
mode:
-rw-r--r-- tensorflow/compiler/jit/create_xla_launch_op.cc 22
-rw-r--r-- tensorflow/compiler/tests/eager_test.py 16
-rw-r--r-- tensorflow/core/common_runtime/eager/execute.cc 14
3 files changed, 43 insertions, 9 deletions
diff --git a/tensorflow/compiler/jit/create_xla_launch_op.cc b/tensorflow/compiler/jit/create_xla_launch_op.cc
index 731b8ebfdc..a2e6285339 100644
--- a/tensorflow/compiler/jit/create_xla_launch_op.cc
+++ b/tensorflow/compiler/jit/create_xla_launch_op.cc
@@ -66,8 +66,28 @@ class SinglePassSearch {
Status CompilationRequested(const FunctionLibraryRuntime& flr,
const NodeDef& node_def) {
+ const FunctionDef* function_def =
+ flr.GetFunctionLibraryDefinition()->Find(node_def.name());
+ if (function_def == nullptr) {
+ // The node def is not calling a function. Individual ops can be
+ // run directly using on-demand mode, no need to create XlaLaunch
+ // kernel for them.
+ // TODO(b/110359382): Make custom kernel creation return a bool instead of
+ // status.
+ // We don't set error messages here to avoid unnecessary string copy.
+ // Similarly below.
+ return Status(error::INVALID_ARGUMENT, "");
+ }
+
+ // If kXlaCompileAttr is set on the node_def, use its value.
+ const auto& it = node_def.attr().find(kXlaCompileAttr);
+ if (it != node_def.attr().end()) {
+ return it->second.b() ? Status::OK() : Status(error::INVALID_ARGUMENT, "");
+ }
+
+ // kXlaCompileAttr is not set on node_def, check if it is set on
+ // FunctionDef.
bool xla_compile = false;
- // Check if op is marked _XlaCompile=true.
Status status = flr.GetFunctionLibraryDefinition()->GetAttr(
node_def, kXlaCompileAttr, &xla_compile);
if (!status.ok() || !xla_compile) {
diff --git a/tensorflow/compiler/tests/eager_test.py b/tensorflow/compiler/tests/eager_test.py
index 3bb3049e87..e438832a23 100644
--- a/tensorflow/compiler/tests/eager_test.py
+++ b/tensorflow/compiler/tests/eager_test.py
@@ -290,7 +290,7 @@ class EagerFunctionTest(XLATestCase):
def testBasic(self):
with self.test_scope():
- matmul = function.defun(math_ops.matmul, compiled=True)
+ matmul = function.defun(math_ops.matmul)
t = constant_op.constant([[1.0, 2.0], [3.0, 4.0]])
sq = matmul(t, t, transpose_a=True)
self.assertAllEqual(sq.numpy().reshape(-1), [10, 14, 14, 20])
@@ -312,7 +312,7 @@ class EagerFunctionTest(XLATestCase):
def model(x):
x = conv(x)
return pool(x)
- model = function.defun(model, compiled=True)
+ model = function.defun(model)
x = array_ops.ones([1, 4, 4, 1])
y = model(x)
@@ -322,7 +322,7 @@ class EagerFunctionTest(XLATestCase):
with self.test_scope():
v = resource_variable_ops.ResourceVariable(1.0)
- @function.defun(compiled=True)
+ @function.defun
def f():
return v.read_value()
@@ -337,7 +337,7 @@ class EagerFunctionTest(XLATestCase):
v.assign_add(1.0)
return v
- f = function.defun(f, compiled=True)
+ f = function.defun(f)
var = f(v)
self.assertEqual(2.0, var.numpy())
@@ -365,7 +365,7 @@ class EagerFunctionTest(XLATestCase):
d = r2 * v2
return a, b, c, d
- foo = function.defun(foo, compiled=True)
+ foo = function.defun(foo)
c1 = [0, 0]
c2 = array_ops.ones([2], dtype=dtypes.int32)
@@ -387,7 +387,7 @@ class EagerFunctionTest(XLATestCase):
with self.test_scope():
v0 = resource_variable_ops.ResourceVariable(5.0)
- @function.defun(compiled=True)
+ @function.defun
def f(x):
x = v0 * v0 * x
return x
@@ -450,7 +450,7 @@ class ExcessivePaddingTest(XLATestCase):
def testAsFunctionInput(self):
with self.test_scope():
- @function.defun(compiled=True)
+ @function.defun
def f(x):
return math_ops.reduce_sum(x, axis=2)
@@ -461,7 +461,7 @@ class ExcessivePaddingTest(XLATestCase):
def testAsFunctionOutput(self):
with self.test_scope():
- @function.defun(compiled=True)
+ @function.defun
def f(x):
return x * constant_op.constant(100 * [[[10.0, 2.0]]])
diff --git a/tensorflow/core/common_runtime/eager/execute.cc b/tensorflow/core/common_runtime/eager/execute.cc
index c619857b78..08abded4e4 100644
--- a/tensorflow/core/common_runtime/eager/execute.cc
+++ b/tensorflow/core/common_runtime/eager/execute.cc
@@ -39,6 +39,11 @@ namespace tensorflow {
namespace {
+// Copy of the definition in third_party/tensorflow/compiler/jit/defs.h
+// Copied here because we don't currently compile XLA on windows. So, can't
+// depend on it directly.
+const char* const kXlaCompileAttr = "_XlaCompile";
+
// Initializes the step stats if needed.
void MaybeInitializeStepStats(StepStats* step_stats, EagerContext* ctx) {
// Lazily initialize the RunMetadata with information about all devices if
@@ -472,6 +477,15 @@ Status EagerLocalExecute(EagerOperation* op,
device == nullptr ? "unspecified" : device->name());
KernelAndDevice* kernel = ctx->GetCachedKernel(cache_key);
if (kernel == nullptr) {
+ // If we are running a function on explicitly requested TPU,
+ // compile it with XLA.
+ // Note that it is not ideal, but currently ok, to set this
+ // attribute after computing the kernel cache key above.
+ if (op->is_function() && device != nullptr &&
+ device->device_type() == "TPU") {
+ op->MutableAttrs()->Set(kXlaCompileAttr, true);
+ }
+
const NodeDef& ndef = op->MutableAttrs()->BuildNodeDef();
if (device == nullptr) {
status = SelectDevice(ndef, ctx, &device);