Diffstat (limited to 'tensorflow/core/kernels/slice_op.cc')
-rw-r--r--  tensorflow/core/kernels/slice_op.cc | 116
1 file changed, 67 insertions(+), 49 deletions(-)
diff --git a/tensorflow/core/kernels/slice_op.cc b/tensorflow/core/kernels/slice_op.cc
index 28a379774b..d46701749b 100644
--- a/tensorflow/core/kernels/slice_op.cc
+++ b/tensorflow/core/kernels/slice_op.cc
@@ -190,25 +190,41 @@ class SliceOp : public OpKernel {
}
return;
}
-#define HANDLE_DIM(NDIM) \
- if (input_dims == NDIM) { \
- functor::Slice<Device, T, NDIM>()( \
- context->eigen_device<Device>(), result, input, begin, size); \
- return; \
+#define HANDLE_DIM(NDIM) \
+ if (input_dims == NDIM) { \
+ HandleCase<NDIM>(context, begin, size, result); \
+ return; \
}
+
HANDLE_DIM(1);
HANDLE_DIM(2);
HANDLE_DIM(3);
HANDLE_DIM(4);
HANDLE_DIM(5);
HANDLE_DIM(6);
+ HANDLE_DIM(7);
#undef HANDLE_DIM
- // handle cases which dim >= 7
- functor::Slice<Device, T, 7>()(
- context->eigen_device<Device>(), result, input, begin, size);
+ OP_REQUIRES(context, false, errors::Unimplemented(
+ "SliceOp : Unhandled input dimensions"));
+ }
+ }
+
+ private:
+ template <int NDIM>
+ void HandleCase(OpKernelContext* context, const gtl::ArraySlice<int64>& begin,
+ const gtl::ArraySlice<int64>& size, Tensor* result) {
+ Eigen::DSizes<Eigen::DenseIndex, NDIM> indices;
+ Eigen::DSizes<Eigen::DenseIndex, NDIM> sizes;
+ for (int i = 0; i < NDIM; ++i) {
+ indices[i] = begin[i];
+ sizes[i] = size[i];
}
+
+ functor::Slice<Device, T, NDIM>()(
+ context->eigen_device<Device>(), result->tensor<T, NDIM>(),
+ context->input(0).tensor<T, NDIM>(), indices, sizes);
}
};
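Note: after this hunk, ranks 1 through 7 each dispatch to the templated HandleCase<NDIM>, and any other rank now reports an Unimplemented error instead of silently falling through to the rank-7 functor. For reference, a HANDLE_DIM expansion is just the macro body with NDIM substituted; a sketch of what HANDLE_DIM(2) produces after preprocessing (illustrative, not copied from generated output):

    if (input_dims == 2) {
      HandleCase<2>(context, begin, size, result);
      return;
    }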
@@ -248,16 +264,11 @@ class MklSliceOp : public OpKernel {
}
return;
}
- // Special case for handling 4-D tensor slice.
- if (input_dims == 4) {
- HandleCase4D(context, begin, size, result);
- } else {
-#define HANDLE_DIM(NDIM) \
- if (input_dims == NDIM) { \
- functor::Slice<Device, T, NDIM>()( \
- context->eigen_device<Device>(), result, input, begin, size); \
- return; \
- }
+#define HANDLE_DIM(NDIM) \
+ if (input_dims == NDIM) { \
+ HandleCase<NDIM>(context, begin, size, result); \
+ return; \
+ }
HANDLE_DIM(1);
HANDLE_DIM(2);
@@ -265,13 +276,12 @@ class MklSliceOp : public OpKernel {
HANDLE_DIM(4);
HANDLE_DIM(5);
HANDLE_DIM(6);
+ HANDLE_DIM(7);
#undef HANDLE_DIM
- // handle cases which dim >= 7
- functor::Slice<Device, T, 7>()(
- context->eigen_device<Device>(), result, input, begin, size);
- }
+ OP_REQUIRES(context, false, errors::Unimplemented(
+ "SliceOp : Unhandled input dimensions"));
}
}
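Note: OP_REQUIRES(context, false, ...) is the usual TensorFlow idiom for failing an op from inside Compute(): with a false condition it records the error status on the OpKernelContext and returns from the enclosing function. A rough hand-written equivalent of the guard added above (sketch only):

    context->SetStatus(errors::Unimplemented("SliceOp : Unhandled input dimensions"));
    return;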
@@ -318,7 +328,8 @@ class MklSliceOp : public OpKernel {
return false;
}
- void HandleCase4D(OpKernelContext* context,
+ template <int NDIM>
+ void HandleCase(OpKernelContext* context,
const gtl::ArraySlice<int64>& begin,
const gtl::ArraySlice<int64>& size, Tensor* result) {
int slice_dim = -1;
@@ -327,7 +338,8 @@ class MklSliceOp : public OpKernel {
// differs from the input tensor in only 1 out of 4 dimensions.
// This case arises when slicing a 4-D tensor in NHWC or NCHW
// format over the channel dimension.
- if (DoesSliceShapeDifferInOnly1D(in_shape, begin, size, &slice_dim)) {
+ if (NDIM == 4 &&
+ DoesSliceShapeDifferInOnly1D(in_shape, begin, size, &slice_dim)) {
size_t in_strides[4] = { (size_t) in_shape.dim_size(1) *
in_shape.dim_size(2) *
in_shape.dim_size(3),
@@ -391,8 +403,16 @@ class MklSliceOp : public OpKernel {
// slice_dim is not 1 or 3, then we fall back to the Eigen implementation.
}
- functor::Slice<Device, T, 4>()(
- context->eigen_device<Device>(), result, context->input(0), begin, size);
+ Eigen::DSizes<Eigen::DenseIndex, NDIM> indices;
+ Eigen::DSizes<Eigen::DenseIndex, NDIM> sizes;
+ for (int i = 0; i < NDIM; ++i) {
+ indices[i] = begin[i];
+ sizes[i] = size[i];
+ }
+
+ functor::Slice<Device, T, NDIM>()(
+ context->eigen_device<Device>(), result->tensor<T, NDIM>(),
+ context->input(0).tensor<T, NDIM>(), indices, sizes);
}
};
#endif
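Note: the NDIM == 4 fast path above works on flat row-major strides of the input so the channel slice can be copied with plain pointer arithmetic. The hunk shows only the first element of in_strides; for a row-major shape [d0, d1, d2, d3] the full stride set follows the standard pattern (a sketch mirroring the visible initializer, not the commit verbatim):

    // Row-major strides: the innermost dimension is contiguous.
    size_t in_strides[4] = {
        (size_t)in_shape.dim_size(1) * in_shape.dim_size(2) * in_shape.dim_size(3),  // d1*d2*d3
        (size_t)in_shape.dim_size(2) * in_shape.dim_size(3),                         // d2*d3
        (size_t)in_shape.dim_size(3),                                                // d3
        (size_t)1};                                                                  // 1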
@@ -400,13 +420,13 @@ class MklSliceOp : public OpKernel {
// Forward declarations of the functor specializations declared in the
// sharded source files.
namespace functor {
-#define DECLARE_CPU_SPEC(T, NDIM) \
- template <> \
- void Slice<CPUDevice, T, NDIM>::operator()( \
- const CPUDevice& d, Tensor* output, \
- const Tensor& input, \
- const gtl::ArraySlice<int64>& slice_indices, \
- const gtl::ArraySlice<int64>& slice_sizes); \
+#define DECLARE_CPU_SPEC(T, NDIM) \
+ template <> \
+ void Slice<CPUDevice, T, NDIM>::operator()( \
+ const CPUDevice& d, typename TTypes<T, NDIM>::Tensor output, \
+ typename TTypes<T, NDIM>::ConstTensor input, \
+ const Eigen::DSizes<Eigen::DenseIndex, NDIM>& indices, \
+ const Eigen::DSizes<Eigen::DenseIndex, NDIM>& sizes); \
extern template struct Slice<CPUDevice, T, NDIM>;
#define DECLARE_FOR_N(T) \
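Note: these forward declarations now describe a functor that works directly on typed Eigen tensor maps and fixed-rank offsets/extents, matching the new HandleCase call sites. The per-rank definitions live in the sharded source files mentioned in the comment; under this interface they are typically a thin wrapper over Eigen's slice expression. A minimal sketch of such a definition, assuming only the signature declared above:

    template <typename Device, typename T, int NDIM>
    struct Slice {
      void operator()(const Device& d, typename TTypes<T, NDIM>::Tensor output,
                      typename TTypes<T, NDIM>::ConstTensor input,
                      const Eigen::DSizes<Eigen::DenseIndex, NDIM>& indices,
                      const Eigen::DSizes<Eigen::DenseIndex, NDIM>& sizes) {
        // Evaluate the Eigen slice expression on the target device.
        output.device(d) = input.slice(indices, sizes);
      }
    };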
@@ -456,14 +476,13 @@ REGISTER_SLICE(bfloat16);
#if GOOGLE_CUDA
// Forward declarations of the functor specializations for GPU.
namespace functor {
-#define DECLARE_GPU_SPEC(T, NDIM) \
- template <> \
- void Slice<GPUDevice, T, NDIM>::operator()( \
- const GPUDevice& d, \
- Tensor* output, \
- const Tensor& input, \
- const gtl::ArraySlice<int64>& slice_indices, \
- const gtl::ArraySlice<int64>& slice_sizes); \
+#define DECLARE_GPU_SPEC(T, NDIM) \
+ template <> \
+ void Slice<GPUDevice, T, NDIM>::operator()( \
+ const GPUDevice& d, typename TTypes<T, NDIM>::Tensor output, \
+ typename TTypes<T, NDIM>::ConstTensor input, \
+ const Eigen::DSizes<Eigen::DenseIndex, NDIM>& indices, \
+ const Eigen::DSizes<Eigen::DenseIndex, NDIM>& sizes); \
extern template struct Slice<GPUDevice, T, NDIM>;
#define DECLARE_FOR_N(T) \
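Note: the extern template line suppresses implicit instantiation of the GPU functor in this translation unit; the actual instantiations are expected to come from explicitly instantiated definitions compiled in the CUDA sources. Schematically, the pairing looks like this for a single type and rank (file name illustrative):

    // slice_op.cc: declaration only, no code generated here.
    extern template struct functor::Slice<GPUDevice, float, 4>;

    // slice_op_gpu.cu.cc (or a similar sharded .cu file): the explicit instantiation.
    template struct functor::Slice<GPUDevice, float, 4>;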
@@ -517,14 +536,13 @@ REGISTER_KERNEL_BUILDER(Name("Slice")
#ifdef TENSORFLOW_USE_SYCL
// Forward declarations of the functor specializations for SYCL.
namespace functor {
-#define DECLARE_SYCL_SPEC(T, NDIM) \
- template <> \
- void Slice<SYCLDevice, T, NDIM>::operator()( \
- const SYCLDevice& d, \
- Tensor* output, \
- const Tensor& input, \
- const gtl::ArraySlice<int64>& slice_indices, \
- const gtl::ArraySlice<int64>& slice_sizes); \
+#define DECLARE_SYCL_SPEC(T, NDIM) \
+ template <> \
+ void Slice<SYCLDevice, T, NDIM>::operator()( \
+ const SYCLDevice& d, typename TTypes<T, NDIM>::Tensor output,\
+ typename TTypes<T, NDIM>::ConstTensor input, \
+ const Eigen::DSizes<Eigen::DenseIndex, NDIM>& indices, \
+ const Eigen::DSizes<Eigen::DenseIndex, NDIM>& sizes); \
extern template struct Slice<SYCLDevice, T, NDIM>;
#define DECLARE_FOR_N(T) \