diff options
Diffstat (limited to 'tensorflow/core/kernels/slice_op.cc')
-rw-r--r-- | tensorflow/core/kernels/slice_op.cc | 116 |
1 file changed, 49 insertions, 67 deletions
diff --git a/tensorflow/core/kernels/slice_op.cc b/tensorflow/core/kernels/slice_op.cc index d46701749b..28a379774b 100644 --- a/tensorflow/core/kernels/slice_op.cc +++ b/tensorflow/core/kernels/slice_op.cc @@ -190,41 +190,25 @@ class SliceOp : public OpKernel { } return; } -#define HANDLE_DIM(NDIM) \ - if (input_dims == NDIM) { \ - HandleCase<NDIM>(context, begin, size, result); \ - return; \ +#define HANDLE_DIM(NDIM) \ + if (input_dims == NDIM) { \ + functor::Slice<Device, T, NDIM>()( \ + context->eigen_device<Device>(), result, input, begin, size); \ + return; \ } - HANDLE_DIM(1); HANDLE_DIM(2); HANDLE_DIM(3); HANDLE_DIM(4); HANDLE_DIM(5); HANDLE_DIM(6); - HANDLE_DIM(7); #undef HANDLE_DIM - OP_REQUIRES(context, false, errors::Unimplemented( - "SliceOp : Unhandled input dimensions")); - } - } - - private: - template <int NDIM> - void HandleCase(OpKernelContext* context, const gtl::ArraySlice<int64>& begin, - const gtl::ArraySlice<int64>& size, Tensor* result) { - Eigen::DSizes<Eigen::DenseIndex, NDIM> indices; - Eigen::DSizes<Eigen::DenseIndex, NDIM> sizes; - for (int i = 0; i < NDIM; ++i) { - indices[i] = begin[i]; - sizes[i] = size[i]; + // handle cases which dim >= 7 + functor::Slice<Device, T, 7>()( + context->eigen_device<Device>(), result, input, begin, size); } - - functor::Slice<Device, T, NDIM>()( - context->eigen_device<Device>(), result->tensor<T, NDIM>(), - context->input(0).tensor<T, NDIM>(), indices, sizes); } }; @@ -264,11 +248,16 @@ class MklSliceOp : public OpKernel { } return; } -#define HANDLE_DIM(NDIM) \ - if (input_dims == NDIM) { \ - HandleCase<NDIM>(context, begin, size, result); \ - return; \ - } + // Special case for handling 4-D tensor slice. 
+ if (input_dims == 4) { + HandleCase4D(context, begin, size, result); + } else { +#define HANDLE_DIM(NDIM) \ + if (input_dims == NDIM) { \ + functor::Slice<Device, T, NDIM>()( \ + context->eigen_device<Device>(), result, input, begin, size); \ + return; \ + } HANDLE_DIM(1); HANDLE_DIM(2); @@ -276,12 +265,13 @@ class MklSliceOp : public OpKernel { HANDLE_DIM(4); HANDLE_DIM(5); HANDLE_DIM(6); - HANDLE_DIM(7); #undef HANDLE_DIM - OP_REQUIRES(context, false, errors::Unimplemented( - "SliceOp : Unhandled input dimensions")); + // handle cases which dim >= 7 + functor::Slice<Device, T, 7>()( + context->eigen_device<Device>(), result, input, begin, size); + } } } @@ -328,8 +318,7 @@ class MklSliceOp : public OpKernel { return false; } - template <int NDIM> - void HandleCase(OpKernelContext* context, + void HandleCase4D(OpKernelContext* context, const gtl::ArraySlice<int64>& begin, const gtl::ArraySlice<int64>& size, Tensor* result) { int slice_dim = -1; @@ -338,8 +327,7 @@ class MklSliceOp : public OpKernel { // differs from the input tensor in only 1 out of 4 dimensions. // This case arises in the context of Slice of 4-D tensor in NHWC or NCHW // format over channel dimension. - if (NDIM == 4 && - DoesSliceShapeDifferInOnly1D(in_shape, begin, size, &slice_dim)) { + if (DoesSliceShapeDifferInOnly1D(in_shape, begin, size, &slice_dim)) { size_t in_strides[4] = { (size_t) in_shape.dim_size(1) * in_shape.dim_size(2) * in_shape.dim_size(3), @@ -403,16 +391,8 @@ class MklSliceOp : public OpKernel { // slice_dim is not 1 or 3, then we fallback to Eigen implementation. 
} - Eigen::DSizes<Eigen::DenseIndex, NDIM> indices; - Eigen::DSizes<Eigen::DenseIndex, NDIM> sizes; - for (int i = 0; i < NDIM; ++i) { - indices[i] = begin[i]; - sizes[i] = size[i]; - } - - functor::Slice<Device, T, NDIM>()( - context->eigen_device<Device>(), result->tensor<T, NDIM>(), - context->input(0).tensor<T, NDIM>(), indices, sizes); + functor::Slice<Device, T, 4>()( + context->eigen_device<Device>(), result, context->input(0), begin, size); } }; #endif @@ -420,13 +400,13 @@ class MklSliceOp : public OpKernel { // Forward declarations of the functor specializations for declared in the // sharded source files. namespace functor { -#define DECLARE_CPU_SPEC(T, NDIM) \ - template <> \ - void Slice<CPUDevice, T, NDIM>::operator()( \ - const CPUDevice& d, typename TTypes<T, NDIM>::Tensor output, \ - typename TTypes<T, NDIM>::ConstTensor input, \ - const Eigen::DSizes<Eigen::DenseIndex, NDIM>& indices, \ - const Eigen::DSizes<Eigen::DenseIndex, NDIM>& sizes); \ +#define DECLARE_CPU_SPEC(T, NDIM) \ + template <> \ + void Slice<CPUDevice, T, NDIM>::operator()( \ + const CPUDevice& d, Tensor* output, \ + const Tensor& input, \ + const gtl::ArraySlice<int64>& slice_indices, \ + const gtl::ArraySlice<int64>& slice_sizes); \ extern template struct Slice<CPUDevice, T, NDIM>; #define DECLARE_FOR_N(T) \ @@ -476,13 +456,14 @@ REGISTER_SLICE(bfloat16); #if GOOGLE_CUDA // Forward declarations of the functor specializations for GPU. 
namespace functor { -#define DECLARE_GPU_SPEC(T, NDIM) \ - template <> \ - void Slice<GPUDevice, T, NDIM>::operator()( \ - const GPUDevice& d, typename TTypes<T, NDIM>::Tensor output, \ - typename TTypes<T, NDIM>::ConstTensor input, \ - const Eigen::DSizes<Eigen::DenseIndex, NDIM>& indices, \ - const Eigen::DSizes<Eigen::DenseIndex, NDIM>& sizes); \ +#define DECLARE_GPU_SPEC(T, NDIM) \ + template <> \ + void Slice<GPUDevice, T, NDIM>::operator()( \ + const GPUDevice& d, \ + Tensor* output, \ + const Tensor& input, \ + const gtl::ArraySlice<int64>& slice_indices, \ + const gtl::ArraySlice<int64>& slice_sizes); \ extern template struct Slice<GPUDevice, T, NDIM>; #define DECLARE_FOR_N(T) \ @@ -536,13 +517,14 @@ REGISTER_KERNEL_BUILDER(Name("Slice") #ifdef TENSORFLOW_USE_SYCL // Forward declarations of the functor specializations for SYCL. namespace functor { -#define DECLARE_SYCL_SPEC(T, NDIM) \ - template <> \ - void Slice<SYCLDevice, T, NDIM>::operator()( \ - const SYCLDevice& d, typename TTypes<T, NDIM>::Tensor output,\ - typename TTypes<T, NDIM>::ConstTensor input, \ - const Eigen::DSizes<Eigen::DenseIndex, NDIM>& indices, \ - const Eigen::DSizes<Eigen::DenseIndex, NDIM>& sizes); \ +#define DECLARE_SYCL_SPEC(T, NDIM) \ + template <> \ + void Slice<SYCLDevice, T, NDIM>::operator()( \ + const SYCLDevice& d, \ + Tensor* output, \ + const Tensor& input, \ + const gtl::ArraySlice<int64>& slice_indices, \ + const gtl::ArraySlice<int64>& slice_sizes); \ extern template struct Slice<SYCLDevice, T, NDIM>; #define DECLARE_FOR_N(T) \ |