diff options
Diffstat (limited to 'tensorflow/core/kernels/split_op.cc')
-rw-r--r-- | tensorflow/core/kernels/split_op.cc | 85 |
1 file changed, 85 insertions, 0 deletions
#ifdef TENSORFLOW_USE_SYCL
typedef Eigen::SyclDevice SYCLDevice;

// SYCL implementation of the Split op: partitions `input` into `num_split`
// equal-sized output tensors along dimension `split_dim`.
template <typename T>
class SplitOpSYCL : public SplitOpBase<SYCLDevice, T> {
 public:
  typedef SplitOpBase<SYCLDevice, T> Base;
  explicit SplitOpSYCL(OpKernelConstruction* c) : Base(c) {}

  void Compute(OpKernelContext* context) override {
    bool done = false;
    // Let the base class handle trivial cases (e.g. num_split == 1 or an
    // empty input); it sets `done` when the outputs are fully produced.
    Base::ComputeEasyCases(context, &done);
    if (!context->status().ok() || done) {
      return;
    }
    const int32 split_dim = context->input(0).flat<int32>()(0);
    const int32 num_split = Base::num_outputs();
    const Tensor& input = context->input(1);
    const TensorShape& input_shape = input.shape();

    // Eigen::DenseIndex may be 32-bit on some platforms (e.g. Android), so
    // reject inputs whose flat element count it cannot represent.
    OP_REQUIRES(
        context,
        FastBoundsCheck(input.NumElements(),
                        std::numeric_limits<Eigen::DenseIndex>::max()),
        errors::InvalidArgument("Split requires input size < ",
                                std::numeric_limits<Eigen::DenseIndex>::max()));

    Eigen::DenseIndex prefix_dim_size;
    Eigen::DenseIndex split_dim_size;
    Eigen::DenseIndex suffix_dim_size;

    // Collapse the input to 3-D: [prefix, split_dim, suffix].
    std::tie(prefix_dim_size, split_dim_size, suffix_dim_size) =
        Base::template SetDims<Eigen::DenseIndex>(input_shape, split_dim);
    auto input_reshaped =
        input.shaped<T, 3>({prefix_dim_size, split_dim_size, suffix_dim_size});

    // NOTE(review): divisibility of split_dim_size by num_split is assumed to
    // have been validated in ComputeEasyCases — confirm against SplitOpBase.
    const int64 split_dim_output_size = split_dim_size / num_split;
    TensorShape output_shape(input_shape);
    output_shape.set_dim(split_dim, split_dim_output_size);

    // Slice start (advanced along the split dimension each iteration) and the
    // fixed slice extent for every output.
    Eigen::DSizes<Eigen::DenseIndex, 3> indices{0, 0, 0};
    const Eigen::DSizes<Eigen::DenseIndex, 3> sizes{
        prefix_dim_size, split_dim_output_size, suffix_dim_size};

    for (int i = 0; i < num_split; ++i) {
      Tensor* result = nullptr;
      OP_REQUIRES_OK(context,
                     context->allocate_output(i, output_shape, &result));
      // Skip launching device work entirely for empty slices.
      if (prefix_dim_size * split_dim_output_size * suffix_dim_size > 0) {
        auto result_shaped = result->shaped<T, 3>(
            {prefix_dim_size, split_dim_output_size, suffix_dim_size});

        // `indices` and `sizes` are already Eigen::DenseIndex-typed, so they
        // are passed to the functor directly; the original per-element copy
        // into separate DSizes variables was redundant.
        functor::Split<SYCLDevice, T>()(context->eigen_device<SYCLDevice>(),
                                        result_shaped, input_reshaped, indices,
                                        sizes);
      }
      indices[1] += split_dim_output_size;
    }
  }
};

// Register Split for the SYCL device. `split_dim` stays in host memory
// because Compute reads it on the host before launching device work.
#define REGISTER_SYCL(type)                              \
  REGISTER_KERNEL_BUILDER(Name("Split")                  \
                              .Device(DEVICE_SYCL)       \
                              .TypeConstraint<type>("T") \
                              .HostMemory("split_dim"),  \
                          SplitOpSYCL<type>)

TF_CALL_GPU_NUMBER_TYPES(REGISTER_SYCL);
#undef REGISTER_SYCL

#endif  // TENSORFLOW_USE_SYCL