diff options
Diffstat (limited to 'tensorflow/core/kernels/slice_op.cc')
-rw-r--r-- | tensorflow/core/kernels/slice_op.cc | 116 |
1 files changed, 67 insertions, 49 deletions
diff --git a/tensorflow/core/kernels/slice_op.cc b/tensorflow/core/kernels/slice_op.cc index 28a379774b..d46701749b 100644 --- a/tensorflow/core/kernels/slice_op.cc +++ b/tensorflow/core/kernels/slice_op.cc @@ -190,25 +190,41 @@ class SliceOp : public OpKernel { } return; } -#define HANDLE_DIM(NDIM) \ - if (input_dims == NDIM) { \ - functor::Slice<Device, T, NDIM>()( \ - context->eigen_device<Device>(), result, input, begin, size); \ - return; \ +#define HANDLE_DIM(NDIM) \ + if (input_dims == NDIM) { \ + HandleCase<NDIM>(context, begin, size, result); \ + return; \ } + HANDLE_DIM(1); HANDLE_DIM(2); HANDLE_DIM(3); HANDLE_DIM(4); HANDLE_DIM(5); HANDLE_DIM(6); + HANDLE_DIM(7); #undef HANDLE_DIM - // handle cases which dim >= 7 - functor::Slice<Device, T, 7>()( - context->eigen_device<Device>(), result, input, begin, size); + OP_REQUIRES(context, false, errors::Unimplemented( + "SliceOp : Unhandled input dimensions")); + } + } + + private: + template <int NDIM> + void HandleCase(OpKernelContext* context, const gtl::ArraySlice<int64>& begin, + const gtl::ArraySlice<int64>& size, Tensor* result) { + Eigen::DSizes<Eigen::DenseIndex, NDIM> indices; + Eigen::DSizes<Eigen::DenseIndex, NDIM> sizes; + for (int i = 0; i < NDIM; ++i) { + indices[i] = begin[i]; + sizes[i] = size[i]; } + + functor::Slice<Device, T, NDIM>()( + context->eigen_device<Device>(), result->tensor<T, NDIM>(), + context->input(0).tensor<T, NDIM>(), indices, sizes); } }; @@ -248,16 +264,11 @@ class MklSliceOp : public OpKernel { } return; } - // Special case for handling 4-D tensor slice. - if (input_dims == 4) { - HandleCase4D(context, begin, size, result); - } else { -#define HANDLE_DIM(NDIM) \ - if (input_dims == NDIM) { \ - functor::Slice<Device, T, NDIM>()( \ - context->eigen_device<Device>(), result, input, begin, size); \ - return; \ - } +#define HANDLE_DIM(NDIM) \ + if (input_dims == NDIM) { \ + HandleCase<NDIM>(context, begin, size, result); \ + return; \ + } HANDLE_DIM(1); HANDLE_DIM(2); @@ -265,13 +276,12 @@ class MklSliceOp : public OpKernel { HANDLE_DIM(4); HANDLE_DIM(5); HANDLE_DIM(6); + HANDLE_DIM(7); #undef HANDLE_DIM - // handle cases which dim >= 7 - functor::Slice<Device, T, 7>()( - context->eigen_device<Device>(), result, input, begin, size); - } + OP_REQUIRES(context, false, errors::Unimplemented( + "SliceOp : Unhandled input dimensions")); } } @@ -318,7 +328,8 @@ class MklSliceOp : public OpKernel { return false; } - void HandleCase4D(OpKernelContext* context, + template <int NDIM> + void HandleCase(OpKernelContext* context, const gtl::ArraySlice<int64>& begin, const gtl::ArraySlice<int64>& size, Tensor* result) { int slice_dim = -1; @@ -327,7 +338,8 @@ class MklSliceOp : public OpKernel { // differs from the input tensor in only 1 out of 4 dimensions. // This case arises in the context of Slice of 4-D tensor in NHWC or NCHW // format over channel dimension. - if (DoesSliceShapeDifferInOnly1D(in_shape, begin, size, &slice_dim)) { + if (NDIM == 4 && + DoesSliceShapeDifferInOnly1D(in_shape, begin, size, &slice_dim)) { size_t in_strides[4] = { (size_t) in_shape.dim_size(1) * in_shape.dim_size(2) * in_shape.dim_size(3), @@ -391,8 +403,16 @@ class MklSliceOp : public OpKernel { // slice_dim is not 1 or 3, then we fallback to Eigen implementation. } - functor::Slice<Device, T, 4>()( - context->eigen_device<Device>(), result, context->input(0), begin, size); + Eigen::DSizes<Eigen::DenseIndex, NDIM> indices; + Eigen::DSizes<Eigen::DenseIndex, NDIM> sizes; + for (int i = 0; i < NDIM; ++i) { + indices[i] = begin[i]; + sizes[i] = size[i]; + } + + functor::Slice<Device, T, NDIM>()( + context->eigen_device<Device>(), result->tensor<T, NDIM>(), + context->input(0).tensor<T, NDIM>(), indices, sizes); } }; #endif @@ -400,13 +420,13 @@ class MklSliceOp : public OpKernel { // Forward declarations of the functor specializations for declared in the // sharded source files. namespace functor { -#define DECLARE_CPU_SPEC(T, NDIM) \ - template <> \ - void Slice<CPUDevice, T, NDIM>::operator()( \ - const CPUDevice& d, Tensor* output, \ - const Tensor& input, \ - const gtl::ArraySlice<int64>& slice_indices, \ - const gtl::ArraySlice<int64>& slice_sizes); \ +#define DECLARE_CPU_SPEC(T, NDIM) \ + template <> \ + void Slice<CPUDevice, T, NDIM>::operator()( \ + const CPUDevice& d, typename TTypes<T, NDIM>::Tensor output, \ + typename TTypes<T, NDIM>::ConstTensor input, \ + const Eigen::DSizes<Eigen::DenseIndex, NDIM>& indices, \ + const Eigen::DSizes<Eigen::DenseIndex, NDIM>& sizes); \ extern template struct Slice<CPUDevice, T, NDIM>; #define DECLARE_FOR_N(T) \ @@ -456,14 +476,13 @@ REGISTER_SLICE(bfloat16); #if GOOGLE_CUDA // Forward declarations of the functor specializations for GPU. namespace functor { -#define DECLARE_GPU_SPEC(T, NDIM) \ - template <> \ - void Slice<GPUDevice, T, NDIM>::operator()( \ - const GPUDevice& d, \ - Tensor* output, \ - const Tensor& input, \ - const gtl::ArraySlice<int64>& slice_indices, \ - const gtl::ArraySlice<int64>& slice_sizes); \ +#define DECLARE_GPU_SPEC(T, NDIM) \ + template <> \ + void Slice<GPUDevice, T, NDIM>::operator()( \ + const GPUDevice& d, typename TTypes<T, NDIM>::Tensor output, \ + typename TTypes<T, NDIM>::ConstTensor input, \ + const Eigen::DSizes<Eigen::DenseIndex, NDIM>& indices, \ + const Eigen::DSizes<Eigen::DenseIndex, NDIM>& sizes); \ extern template struct Slice<GPUDevice, T, NDIM>; #define DECLARE_FOR_N(T) \ @@ -517,14 +536,13 @@ REGISTER_KERNEL_BUILDER(Name("Slice") #ifdef TENSORFLOW_USE_SYCL // Forward declarations of the functor specializations for SYCL. namespace functor { -#define DECLARE_SYCL_SPEC(T, NDIM) \ - template <> \ - void Slice<SYCLDevice, T, NDIM>::operator()( \ - const SYCLDevice& d, \ - Tensor* output, \ - const Tensor& input, \ - const gtl::ArraySlice<int64>& slice_indices, \ - const gtl::ArraySlice<int64>& slice_sizes); \ +#define DECLARE_SYCL_SPEC(T, NDIM) \ + template <> \ + void Slice<SYCLDevice, T, NDIM>::operator()( \ + const SYCLDevice& d, typename TTypes<T, NDIM>::Tensor output,\ + typename TTypes<T, NDIM>::ConstTensor input, \ + const Eigen::DSizes<Eigen::DenseIndex, NDIM>& indices, \ + const Eigen::DSizes<Eigen::DenseIndex, NDIM>& sizes); \ extern template struct Slice<SYCLDevice, T, NDIM>; #define DECLARE_FOR_N(T) \ |