diff options
-rw-r--r-- | tensorflow/compiler/jit/xla_device_ops.h | 4 | ||||
-rw-r--r-- | tensorflow/compiler/tests/xla_device_test.py | 7 | ||||
-rw-r--r-- | tensorflow/compiler/tf2xla/kernels/no_op.cc | 5 |
3 files changed, 12 insertions, 4 deletions
diff --git a/tensorflow/compiler/jit/xla_device_ops.h b/tensorflow/compiler/jit/xla_device_ops.h index 9c00a0682c..59822a18c0 100644 --- a/tensorflow/compiler/jit/xla_device_ops.h +++ b/tensorflow/compiler/jit/xla_device_ops.h @@ -73,7 +73,9 @@ class XlaDeviceDummyOp : public OpKernel { \ REGISTER_KERNEL_BUILDER( \ Name("VarHandleOp").Device(DEVICE).HostMemory("resource"), \ - ResourceHandleOp<Var>); + ResourceHandleOp<Var>); \ + REGISTER_KERNEL_BUILDER(Name("ControlTrigger").Device(DEVICE), \ + ControlTriggerOp); } // namespace tensorflow diff --git a/tensorflow/compiler/tests/xla_device_test.py b/tensorflow/compiler/tests/xla_device_test.py index b707bd0963..f0b010fa67 100644 --- a/tensorflow/compiler/tests/xla_device_test.py +++ b/tensorflow/compiler/tests/xla_device_test.py @@ -23,6 +23,7 @@ import numpy as np from tensorflow.compiler.tests.xla_test import XLATestCase from tensorflow.python.framework import ops from tensorflow.python.ops import array_ops +from tensorflow.python.ops import gen_control_flow_ops from tensorflow.python.platform import test @@ -46,6 +47,12 @@ class XlaDeviceTest(XLATestCase): result = sess.run(z, {x: inputs}) self.assertAllCloseAccordingToType(result, inputs + inputs) + def testControlTrigger(self): + with self.test_session() as sess: + with self.test_scope(): + x = gen_control_flow_ops.control_trigger() + sess.run(x) + if __name__ == "__main__": test.main() diff --git a/tensorflow/compiler/tf2xla/kernels/no_op.cc b/tensorflow/compiler/tf2xla/kernels/no_op.cc index 8c8a9bbe78..65ab9da8d7 100644 --- a/tensorflow/compiler/tf2xla/kernels/no_op.cc +++ b/tensorflow/compiler/tf2xla/kernels/no_op.cc @@ -24,8 +24,7 @@ namespace tensorflow { REGISTER_XLA_OP(Name("NoOp").CompilationOnly(), NoOp); // We register ControlTrigger as a no-op. This is correct since nodes seen -// by the XLA compiler are never dead. This may need rethinking when we add -// support for conditionals to XLA. -REGISTER_XLA_OP(Name("ControlTrigger"), NoOp); +// by the XLA compiler are never dead. +REGISTER_XLA_OP(Name("ControlTrigger").CompilationOnly(), NoOp); } // namespace tensorflow |