aboutsummaryrefslogtreecommitdiffhomepage
path: root/tensorflow/core/kernels/cwise_op_select.cc
diff options
context:
space:
mode:
Diffstat (limited to 'tensorflow/core/kernels/cwise_op_select.cc')
-rw-r--r--tensorflow/core/kernels/cwise_op_select.cc56
1 files changed, 41 insertions, 15 deletions
diff --git a/tensorflow/core/kernels/cwise_op_select.cc b/tensorflow/core/kernels/cwise_op_select.cc
index 98df0844ea..d6988a562c 100644
--- a/tensorflow/core/kernels/cwise_op_select.cc
+++ b/tensorflow/core/kernels/cwise_op_select.cc
@@ -33,6 +33,11 @@ typedef Eigen::GpuDevice GPUDevice;
typedef Eigen::SyclDevice SYCLDevice;
#endif // TENSORFLOW_USE_SYCL
+namespace functor {
+template <typename Device, typename T>
+struct SelectScalarHandler;
+} // namespace functor
+
template <typename Device, typename T>
class SelectOp : public OpKernel {
public:
@@ -131,16 +136,8 @@ class SelectOp : public OpKernel {
then->shape().DebugString(), " vs. ",
else_->shape().DebugString()));
- Tensor* output = nullptr;
- OP_REQUIRES_OK(ctx, ctx->forward_input_or_allocate_output(
- {"t", "e"}, "output", then->shape(), &output));
-
- if (output->NumElements() > 0) {
- functor::SelectScalarFunctor<Device, T> func;
- TTypes<bool>::ConstScalar cond_scalar = cond->scalar<bool>();
- func(ctx->eigen_device<Device>(), output->flat<T>(), cond_scalar,
- then->flat<T>(), else_->flat<T>());
- }
+ functor::SelectScalarHandler<Device, T> handler;
+ handler(ctx, cond, then, else_);
}
private:
@@ -209,6 +206,40 @@ struct SelectFunctor<SYCLDevice, T> : SelectFunctorBase<SYCLDevice, T> {};
#endif // TENSORFLOW_USE_SYCL
template <typename Device, typename T>
+struct SelectScalarHandler {
+ void operator()(OpKernelContext* ctx, const Tensor* cond, const Tensor* then,
+ const Tensor* else_) {
+ Tensor* output = nullptr;
+ OP_REQUIRES_OK(ctx, ctx->forward_input_or_allocate_output(
+ {"t", "e"}, "output", then->shape(), &output));
+
+ if (output->NumElements() > 0) {
+ functor::SelectScalarFunctor<Device, T> func;
+ TTypes<bool>::ConstScalar cond_scalar = cond->scalar<bool>();
+ func(ctx->eigen_device<Device>(), output->flat<T>(), cond_scalar,
+ then->flat<T>(), else_->flat<T>());
+ }
+ }
+};
+
+// Specilization for CPU device. Forward input to output depending on the `cond`
+// value.
+// TODO(sjhwang): Consider specializing for GPUDevice as well by using
+// GPUDevice::memcpyDeviceToHost() to fetch bool value.
+template <typename T>
+struct SelectScalarHandler<CPUDevice, T> {
+ void operator()(OpKernelContext* ctx, const Tensor* cond, const Tensor* then,
+ const Tensor* else_) {
+ if (cond->scalar<bool>()()) {
+ OP_REQUIRES_OK(ctx, ctx->set_output("output", *then));
+ } else {
+ OP_REQUIRES_OK(ctx, ctx->set_output("output", *else_));
+ }
+ }
+};
+
+#ifdef TENSORFLOW_USE_SYCL
+template <typename Device, typename T>
struct SelectScalarFunctorBase {
void operator()(const Device& d, typename TTypes<T>::Flat out,
TTypes<bool>::ConstScalar cond,
@@ -218,11 +249,6 @@ struct SelectScalarFunctorBase {
}
};
-// CPU Specializations of Select functors with scalar
-template <typename T>
-struct SelectScalarFunctor<CPUDevice, T>
- : SelectScalarFunctorBase<CPUDevice, T> {};
-#ifdef TENSORFLOW_USE_SYCL
template <typename T>
struct SelectScalarFunctor<SYCLDevice, T>
: SelectScalarFunctorBase<SYCLDevice, T> {};