aboutsummaryrefslogtreecommitdiffhomepage
path: root/tensorflow/core/kernels/slice_op.cc
diff options
context:
space:
mode:
Diffstat (limited to 'tensorflow/core/kernels/slice_op.cc')
-rw-r--r--tensorflow/core/kernels/slice_op.cc116
1 files changed, 49 insertions, 67 deletions
diff --git a/tensorflow/core/kernels/slice_op.cc b/tensorflow/core/kernels/slice_op.cc
index d46701749b..28a379774b 100644
--- a/tensorflow/core/kernels/slice_op.cc
+++ b/tensorflow/core/kernels/slice_op.cc
@@ -190,41 +190,25 @@ class SliceOp : public OpKernel {
}
return;
}
-#define HANDLE_DIM(NDIM) \
- if (input_dims == NDIM) { \
- HandleCase<NDIM>(context, begin, size, result); \
- return; \
+#define HANDLE_DIM(NDIM) \
+ if (input_dims == NDIM) { \
+ functor::Slice<Device, T, NDIM>()( \
+ context->eigen_device<Device>(), result, input, begin, size); \
+ return; \
}
-
HANDLE_DIM(1);
HANDLE_DIM(2);
HANDLE_DIM(3);
HANDLE_DIM(4);
HANDLE_DIM(5);
HANDLE_DIM(6);
- HANDLE_DIM(7);
#undef HANDLE_DIM
- OP_REQUIRES(context, false, errors::Unimplemented(
- "SliceOp : Unhandled input dimensions"));
- }
- }
-
- private:
- template <int NDIM>
- void HandleCase(OpKernelContext* context, const gtl::ArraySlice<int64>& begin,
- const gtl::ArraySlice<int64>& size, Tensor* result) {
- Eigen::DSizes<Eigen::DenseIndex, NDIM> indices;
- Eigen::DSizes<Eigen::DenseIndex, NDIM> sizes;
- for (int i = 0; i < NDIM; ++i) {
- indices[i] = begin[i];
- sizes[i] = size[i];
+ // handle cases which dim >= 7
+ functor::Slice<Device, T, 7>()(
+ context->eigen_device<Device>(), result, input, begin, size);
}
-
- functor::Slice<Device, T, NDIM>()(
- context->eigen_device<Device>(), result->tensor<T, NDIM>(),
- context->input(0).tensor<T, NDIM>(), indices, sizes);
}
};
@@ -264,11 +248,16 @@ class MklSliceOp : public OpKernel {
}
return;
}
-#define HANDLE_DIM(NDIM) \
- if (input_dims == NDIM) { \
- HandleCase<NDIM>(context, begin, size, result); \
- return; \
- }
+ // Special case for handling 4-D tensor slice.
+ if (input_dims == 4) {
+ HandleCase4D(context, begin, size, result);
+ } else {
+#define HANDLE_DIM(NDIM) \
+ if (input_dims == NDIM) { \
+ functor::Slice<Device, T, NDIM>()( \
+ context->eigen_device<Device>(), result, input, begin, size); \
+ return; \
+ }
HANDLE_DIM(1);
HANDLE_DIM(2);
@@ -276,12 +265,13 @@ class MklSliceOp : public OpKernel {
HANDLE_DIM(4);
HANDLE_DIM(5);
HANDLE_DIM(6);
- HANDLE_DIM(7);
#undef HANDLE_DIM
- OP_REQUIRES(context, false, errors::Unimplemented(
- "SliceOp : Unhandled input dimensions"));
+ // handle cases which dim >= 7
+ functor::Slice<Device, T, 7>()(
+ context->eigen_device<Device>(), result, input, begin, size);
+ }
}
}
@@ -328,8 +318,7 @@ class MklSliceOp : public OpKernel {
return false;
}
- template <int NDIM>
- void HandleCase(OpKernelContext* context,
+ void HandleCase4D(OpKernelContext* context,
const gtl::ArraySlice<int64>& begin,
const gtl::ArraySlice<int64>& size, Tensor* result) {
int slice_dim = -1;
@@ -338,8 +327,7 @@ class MklSliceOp : public OpKernel {
// differs from the input tensor in only 1 out of 4 dimensions.
// This case arises in the context of Slice of 4-D tensor in NHWC or NCHW
// format over channel dimension.
- if (NDIM == 4 &&
- DoesSliceShapeDifferInOnly1D(in_shape, begin, size, &slice_dim)) {
+ if (DoesSliceShapeDifferInOnly1D(in_shape, begin, size, &slice_dim)) {
size_t in_strides[4] = { (size_t) in_shape.dim_size(1) *
in_shape.dim_size(2) *
in_shape.dim_size(3),
@@ -403,16 +391,8 @@ class MklSliceOp : public OpKernel {
// slice_dim is not 1 or 3, then we fallback to Eigen implementation.
}
- Eigen::DSizes<Eigen::DenseIndex, NDIM> indices;
- Eigen::DSizes<Eigen::DenseIndex, NDIM> sizes;
- for (int i = 0; i < NDIM; ++i) {
- indices[i] = begin[i];
- sizes[i] = size[i];
- }
-
- functor::Slice<Device, T, NDIM>()(
- context->eigen_device<Device>(), result->tensor<T, NDIM>(),
- context->input(0).tensor<T, NDIM>(), indices, sizes);
+ functor::Slice<Device, T, 4>()(
+ context->eigen_device<Device>(), result, context->input(0), begin, size);
}
};
#endif
@@ -420,13 +400,13 @@ class MklSliceOp : public OpKernel {
// Forward declarations of the functor specializations for declared in the
// sharded source files.
namespace functor {
-#define DECLARE_CPU_SPEC(T, NDIM) \
- template <> \
- void Slice<CPUDevice, T, NDIM>::operator()( \
- const CPUDevice& d, typename TTypes<T, NDIM>::Tensor output, \
- typename TTypes<T, NDIM>::ConstTensor input, \
- const Eigen::DSizes<Eigen::DenseIndex, NDIM>& indices, \
- const Eigen::DSizes<Eigen::DenseIndex, NDIM>& sizes); \
+#define DECLARE_CPU_SPEC(T, NDIM) \
+ template <> \
+ void Slice<CPUDevice, T, NDIM>::operator()( \
+ const CPUDevice& d, Tensor* output, \
+ const Tensor& input, \
+ const gtl::ArraySlice<int64>& slice_indices, \
+ const gtl::ArraySlice<int64>& slice_sizes); \
extern template struct Slice<CPUDevice, T, NDIM>;
#define DECLARE_FOR_N(T) \
@@ -476,13 +456,14 @@ REGISTER_SLICE(bfloat16);
#if GOOGLE_CUDA
// Forward declarations of the functor specializations for GPU.
namespace functor {
-#define DECLARE_GPU_SPEC(T, NDIM) \
- template <> \
- void Slice<GPUDevice, T, NDIM>::operator()( \
- const GPUDevice& d, typename TTypes<T, NDIM>::Tensor output, \
- typename TTypes<T, NDIM>::ConstTensor input, \
- const Eigen::DSizes<Eigen::DenseIndex, NDIM>& indices, \
- const Eigen::DSizes<Eigen::DenseIndex, NDIM>& sizes); \
+#define DECLARE_GPU_SPEC(T, NDIM) \
+ template <> \
+ void Slice<GPUDevice, T, NDIM>::operator()( \
+ const GPUDevice& d, \
+ Tensor* output, \
+ const Tensor& input, \
+ const gtl::ArraySlice<int64>& slice_indices, \
+ const gtl::ArraySlice<int64>& slice_sizes); \
extern template struct Slice<GPUDevice, T, NDIM>;
#define DECLARE_FOR_N(T) \
@@ -536,13 +517,14 @@ REGISTER_KERNEL_BUILDER(Name("Slice")
#ifdef TENSORFLOW_USE_SYCL
// Forward declarations of the functor specializations for SYCL.
namespace functor {
-#define DECLARE_SYCL_SPEC(T, NDIM) \
- template <> \
- void Slice<SYCLDevice, T, NDIM>::operator()( \
- const SYCLDevice& d, typename TTypes<T, NDIM>::Tensor output,\
- typename TTypes<T, NDIM>::ConstTensor input, \
- const Eigen::DSizes<Eigen::DenseIndex, NDIM>& indices, \
- const Eigen::DSizes<Eigen::DenseIndex, NDIM>& sizes); \
+#define DECLARE_SYCL_SPEC(T, NDIM) \
+ template <> \
+ void Slice<SYCLDevice, T, NDIM>::operator()( \
+ const SYCLDevice& d, \
+ Tensor* output, \
+ const Tensor& input, \
+ const gtl::ArraySlice<int64>& slice_indices, \
+ const gtl::ArraySlice<int64>& slice_sizes); \
extern template struct Slice<SYCLDevice, T, NDIM>;
#define DECLARE_FOR_N(T) \