Diffstat (limited to 'tensorflow/core/kernels/split_op.cc')
-rw-r--r--  tensorflow/core/kernels/split_op.cc  85
1 file changed, 85 insertions, 0 deletions
diff --git a/tensorflow/core/kernels/split_op.cc b/tensorflow/core/kernels/split_op.cc
index 4b12e1f995..cca2fc41c2 100644
--- a/tensorflow/core/kernels/split_op.cc
+++ b/tensorflow/core/kernels/split_op.cc
@@ -36,6 +36,9 @@ namespace tensorflow {
typedef Eigen::ThreadPoolDevice CPUDevice;
typedef Eigen::GpuDevice GPUDevice;
+#ifdef TENSORFLOW_USE_SYCL
+typedef Eigen::SyclDevice SYCLDevice;
+#endif // TENSORFLOW_USE_SYCL
template <typename Device, typename T>
class SplitOpBase : public OpKernel {
@@ -243,6 +246,75 @@ class SplitOpGPU : public SplitOpBase<GPUDevice, T> {
};
#endif // GOOGLE_CUDA
+#ifdef TENSORFLOW_USE_SYCL
+
+template <typename T>
+class SplitOpSYCL : public SplitOpBase<SYCLDevice, T> {
+ public:
+ typedef SplitOpBase<SYCLDevice, T> Base;
+ explicit SplitOpSYCL(OpKernelConstruction* c) : Base(c) {}
+
+ void Compute(OpKernelContext* context) override {
+ bool done = false;
+ Base::ComputeEasyCases(context, &done);
+ if (!context->status().ok() || done) {
+ return;
+ }
+ const int32 split_dim = context->input(0).flat<int32>()(0);
+ const int32 num_split = Base::num_outputs();
+ const Tensor& input = context->input(1);
+ const TensorShape& input_shape = input.shape();
+
+ // Android also uses int32 indexing, so check here also.
+ OP_REQUIRES(
+ context, FastBoundsCheck(input.NumElements(),
+ std::numeric_limits<Eigen::DenseIndex>::max()),
+ errors::InvalidArgument("Split requires input size < ",
+ std::numeric_limits<Eigen::DenseIndex>::max()));
+
+ Eigen::DenseIndex prefix_dim_size;
+ Eigen::DenseIndex split_dim_size;
+ Eigen::DenseIndex suffix_dim_size;
+
+ std::tie(prefix_dim_size, split_dim_size, suffix_dim_size) =
+ Base::template SetDims<Eigen::DenseIndex>(input_shape, split_dim);
+ auto input_reshaped =
+ input.shaped<T, 3>({prefix_dim_size, split_dim_size, suffix_dim_size});
+
+ const int64 split_dim_output_size = split_dim_size / num_split;
+ TensorShape output_shape(input_shape);
+ output_shape.set_dim(split_dim, split_dim_output_size);
+
+ Eigen::DSizes<Eigen::DenseIndex, 3> indices{0, 0, 0};
+ Eigen::DSizes<Eigen::DenseIndex, 3> sizes{
+ prefix_dim_size, split_dim_output_size, suffix_dim_size};
+
+ for (int i = 0; i < num_split; ++i) {
+ Tensor* result = nullptr;
+ OP_REQUIRES_OK(context,
+ context->allocate_output(i, output_shape, &result));
+ if (prefix_dim_size * split_dim_output_size * suffix_dim_size > 0) {
+ Eigen::DSizes<Eigen::DenseIndex, 3> slice_indices;
+ Eigen::DSizes<Eigen::DenseIndex, 3> slice_sizes;
+ for (int j = 0; j < 3; ++j) {
+ slice_indices[j] = indices[j];
+ slice_sizes[j] = sizes[j];
+ }
+
+ auto result_shaped = result->shaped<T, 3>(
+ {prefix_dim_size, split_dim_output_size, suffix_dim_size});
+
+ functor::Split<SYCLDevice, T>()(context->eigen_device<SYCLDevice>(),
+ result_shaped, input_reshaped,
+ slice_indices, slice_sizes);
+ }
+ indices[1] += split_dim_output_size;
+ }
+ }
+};
+
+#endif // TENSORFLOW_USE_SYCL
+
#define REGISTER_SPLIT(type) \
REGISTER_KERNEL_BUILDER(Name("Split") \
.Device(DEVICE_CPU) \
@@ -269,4 +341,17 @@ TF_CALL_GPU_NUMBER_TYPES(REGISTER_GPU);
#endif // GOOGLE_CUDA
+#ifdef TENSORFLOW_USE_SYCL
+#define REGISTER_SYCL(type) \
+ REGISTER_KERNEL_BUILDER(Name("Split") \
+ .Device(DEVICE_SYCL) \
+ .TypeConstraint<type>("T") \
+ .HostMemory("split_dim"), \
+ SplitOpSYCL<type>)
+
+TF_CALL_GPU_NUMBER_TYPES(REGISTER_SYCL);
+#undef REGISTER_SYCL
+
+#endif // TENSORFLOW_USE_SYCL
+
} // end namespace tensorflow
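
Note: the Compute() method added above delegates the per-output device copy to
functor::Split<SYCLDevice, T>. That functor is defined outside this patch (in
the split_lib headers/implementation files), so the code below is only a
minimal sketch, under the assumption that the functor takes 3-D Eigen tensors
and slice indices/sizes as shown in the call site; the real definition and its
template parameters may differ.

namespace tensorflow {
namespace functor {

// Sketch only (not part of this patch): evaluate one output slice of the
// split on the SYCL device. `slice_indices` carries the running offset along
// the split dimension and `slice_sizes` is
// {prefix_dim_size, split_dim_output_size, suffix_dim_size}, matching the
// 3-D reshape performed by SplitOpSYCL::Compute above.
template <typename T>
struct Split<Eigen::SyclDevice, T> {
  void operator()(const Eigen::SyclDevice& d,
                  typename TTypes<T, 3>::Tensor output,
                  typename TTypes<T, 3>::ConstTensor input,
                  const Eigen::DSizes<Eigen::DenseIndex, 3>& slice_indices,
                  const Eigen::DSizes<Eigen::DenseIndex, 3>& slice_sizes) {
    // Let Eigen generate the slice-and-copy expression and run it on the
    // SYCL device queue.
    output.device(d) = input.slice(slice_indices, slice_sizes);
  }
};

}  // namespace functor
}  // namespace tensorflow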