diff options
author | 2018-05-31 17:03:07 -0700 | |
---|---|---|
committer | 2018-05-31 17:06:19 -0700 | |
commit | ba6d01807feaeaeb10272c9e55a7002306b63db5 (patch) | |
tree | 1e12c08f7a8eae14962e9f0556efb235158613fe /tensorflow/core | |
parent | 6a6cfbfe4bd79fb0eb21b3d0753d3ddf6ee86ce8 (diff) |
[TF:XLA] Preliminary support for tpu.replicate() inside of TF control flow (such as tf.while_loop()).
Register the remaining control-flow operators on XLA devices.
PiperOrigin-RevId: 198803131
Diffstat (limited to 'tensorflow/core')
-rw-r--r-- | tensorflow/core/kernels/control_flow_ops.cc | 22 | ||||
-rw-r--r-- | tensorflow/core/kernels/control_flow_ops.h | 16 |
2 files changed, 22 insertions, 16 deletions
diff --git a/tensorflow/core/kernels/control_flow_ops.cc b/tensorflow/core/kernels/control_flow_ops.cc index 7d5d54e5be..ebf844d75f 100644 --- a/tensorflow/core/kernels/control_flow_ops.cc +++ b/tensorflow/core/kernels/control_flow_ops.cc @@ -587,24 +587,14 @@ REGISTER_SYCL_HOST_KERNEL(string); #undef REGISTER_SYCL_HOST_KERNEL #endif // TENSORFLOW_USE_SYCL -// A LoopCond op has one input and one output. The input is a boolean -// scalar representing the taken branches of the "pivot" Switch that -// determines loop termination. As a contract, any high-level front-end -// should always use port '0' of the "pivot" switches for loop exit. -class LoopCondOp : public OpKernel { - public: - explicit LoopCondOp(OpKernelConstruction* context) : OpKernel(context) {} - - void Compute(OpKernelContext* context) override { - context->set_output(0, context->input(0)); - } +LoopCondOp::LoopCondOp(OpKernelConstruction* context) : OpKernel(context) {} +LoopCondOp::~LoopCondOp() = default; - bool IsExpensive() override { return false; } - - ~LoopCondOp() override {} +void LoopCondOp::Compute(OpKernelContext* context) { + context->set_output(0, context->input(0)); +} - TF_DISALLOW_COPY_AND_ASSIGN(LoopCondOp); -}; +bool LoopCondOp::IsExpensive() { return false; } REGISTER_KERNEL_BUILDER(Name("LoopCond").Device(DEVICE_CPU), LoopCondOp); REGISTER_KERNEL_BUILDER(Name("LoopCond") diff --git a/tensorflow/core/kernels/control_flow_ops.h b/tensorflow/core/kernels/control_flow_ops.h index 4838f2e2bf..8edbcc9077 100644 --- a/tensorflow/core/kernels/control_flow_ops.h +++ b/tensorflow/core/kernels/control_flow_ops.h @@ -97,6 +97,22 @@ class NextIterationOp : public OpKernel { TF_DISALLOW_COPY_AND_ASSIGN(NextIterationOp); }; +// A LoopCond op has one input and one output. The input is a boolean +// scalar representing the taken branches of the "pivot" Switch that +// determines loop termination. As a contract, any high-level front-end +// should always use port '0' of the "pivot" switches for loop exit. +class LoopCondOp : public OpKernel { + public: + explicit LoopCondOp(OpKernelConstruction* context); + ~LoopCondOp() override; + + void Compute(OpKernelContext* context) override; + + bool IsExpensive() override; + + TF_DISALLOW_COPY_AND_ASSIGN(LoopCondOp); +}; + } // namespace tensorflow #endif // TENSORFLOW_KERNELS_CONTROL_FLOW_OPS_H_ |