about | summary | refs | log | tree | commit | diff | homepage
path: root/tensorflow/core/kernels/cwise_op_select.cc
diff options
context:
space:
mode:
Diffstat (limited to 'tensorflow/core/kernels/cwise_op_select.cc')
-rw-r--r-- tensorflow/core/kernels/cwise_op_select.cc 59
1 files changed, 51 insertions, 8 deletions
diff --git a/tensorflow/core/kernels/cwise_op_select.cc b/tensorflow/core/kernels/cwise_op_select.cc
index add26b5ac8..709628da13 100644
--- a/tensorflow/core/kernels/cwise_op_select.cc
+++ b/tensorflow/core/kernels/cwise_op_select.cc
@@ -28,6 +28,10 @@ namespace tensorflow {
typedef Eigen::ThreadPoolDevice CPUDevice;
typedef Eigen::GpuDevice GPUDevice;
+#ifdef TENSORFLOW_USE_SYCL
+typedef Eigen::SyclDevice SYCLDevice;
+#endif // TENSORFLOW_USE_SYCL
+
template <typename Device, typename T>
class SelectOp : public OpKernel {
public:
@@ -169,12 +173,24 @@ REGISTER_SELECT_GPU(complex128);
#endif // GOOGLE_CUDA
+#ifdef TENSORFLOW_USE_SYCL
+// Registration of the SYCL implementations.
+#define REGISTER_SELECT_SYCL(type) \
+ REGISTER_KERNEL_BUILDER( \
+ Name("Select").Device(DEVICE_SYCL).TypeConstraint<type>("T"), \
+ SelectOp<SYCLDevice, type>);
+
+REGISTER_SELECT_SYCL(float);
+REGISTER_SELECT_SYCL(int32);
+#undef REGISTER_SELECT_SYCL
+#endif // TENSORFLOW_USE_SYCL
+
namespace functor {
// CPU Specializations of Select functors.
-template <typename T>
-struct SelectFunctor<CPUDevice, T> {
- void operator()(const CPUDevice& d, typename TTypes<T>::Flat out,
+template <typename Device, typename T>
+struct SelectFunctorBase {
+ void operator()(const Device& d, typename TTypes<T>::Flat out,
typename TTypes<bool>::ConstFlat cond_flat,
typename TTypes<T>::ConstFlat then_flat,
typename TTypes<T>::ConstFlat else_flat) {
@@ -182,10 +198,18 @@ struct SelectFunctor<CPUDevice, T> {
}
};
-// CPU Specializations of Select functors with scalar
template <typename T>
-struct SelectScalarFunctor<CPUDevice, T> {
- void operator()(const CPUDevice& d, typename TTypes<T>::Flat out,
+struct SelectFunctor<CPUDevice, T>
+ : SelectFunctorBase<CPUDevice, T> {};
+#ifdef TENSORFLOW_USE_SYCL
+template <typename T>
+struct SelectFunctor<SYCLDevice, T>
+ : SelectFunctorBase<SYCLDevice, T> {};
+#endif // TENSORFLOW_USE_SYCL
+
+template <typename Device, typename T>
+struct SelectScalarFunctorBase {
+ void operator()(const Device& d, typename TTypes<T>::Flat out,
TTypes<bool>::ConstScalar cond,
typename TTypes<T>::ConstFlat then_flat,
typename TTypes<T>::ConstFlat else_flat) {
@@ -193,9 +217,19 @@ struct SelectScalarFunctor<CPUDevice, T> {
}
};
+// CPU specialization of the Select functor with a scalar condition.
template <typename T>
-struct BatchSelectFunctor<CPUDevice, T> {
- void operator()(const CPUDevice& d,
+struct SelectScalarFunctor<CPUDevice, T>
+ : SelectScalarFunctorBase<CPUDevice, T> {};
+#ifdef TENSORFLOW_USE_SYCL
+template <typename T>
+struct SelectScalarFunctor<SYCLDevice, T>
+ : SelectScalarFunctorBase<SYCLDevice, T> {};
+#endif // TENSORFLOW_USE_SYCL
+
+template <typename Device, typename T>
+struct BatchSelectFunctorBase {
+ void operator()(const Device& d,
typename TTypes<T>::Matrix output_flat_outer_dims,
TTypes<bool>::ConstVec cond_vec,
typename TTypes<T>::ConstMatrix then_flat_outer_dims,
@@ -220,6 +254,15 @@ struct BatchSelectFunctor<CPUDevice, T> {
}
};
+template <typename T>
+struct BatchSelectFunctor<CPUDevice, T>
+ : BatchSelectFunctorBase<CPUDevice, T> {};
+#ifdef TENSORFLOW_USE_SYCL
+template <typename T>
+struct BatchSelectFunctor<SYCLDevice, T>
+ : BatchSelectFunctorBase<SYCLDevice, T> {};
+#endif // TENSORFLOW_USE_SYCL
+
} // namespace functor
} // namespace tensorflow