aboutsummaryrefslogtreecommitdiffhomepage
path: root/tensorflow/core/kernels/split_lib_cpu.cc
diff options
context:
space:
mode:
Diffstat (limited to 'tensorflow/core/kernels/split_lib_cpu.cc')
-rw-r--r--tensorflow/core/kernels/split_lib_cpu.cc19
1 files changed, 19 insertions, 0 deletions
diff --git a/tensorflow/core/kernels/split_lib_cpu.cc b/tensorflow/core/kernels/split_lib_cpu.cc
index 41b2d6f0f5..e377e4d97a 100644
--- a/tensorflow/core/kernels/split_lib_cpu.cc
+++ b/tensorflow/core/kernels/split_lib_cpu.cc
@@ -43,5 +43,24 @@ TF_CALL_ALL_TYPES(DEFINE_CPU_KERNELS)
DEFINE_CPU_KERNELS(quint8)
DEFINE_CPU_KERNELS(bfloat16)
+#ifdef TENSORFLOW_USE_SYCL
+template <typename T>
+void Split<Eigen::SyclDevice, T>::operator()(
+ const Eigen::SyclDevice& d, typename TTypes<T, 3>::Tensor output,
+ typename TTypes<T, 3>::ConstTensor input,
+ const Eigen::DSizes<Eigen::DenseIndex, 3>& slice_indices,
+ const Eigen::DSizes<Eigen::DenseIndex, 3>& slice_sizes) {
+ if (output.size() < 131072) {
+ output = input.slice(slice_indices, slice_sizes);
+ } else {
+ output.device(d) = input.slice(slice_indices, slice_sizes);
+ }
+}
+
+#define DEFINE_SYCL_KERNELS(T) template struct Split<Eigen::SyclDevice, T>;
+
+TF_CALL_GPU_NUMBER_TYPES(DEFINE_SYCL_KERNELS)
+#endif // TENSORFLOW_USE_SYCL
+
} // namespace functor
} // namespace tensorflow