Diffstat (limited to 'tensorflow/core/kernels/slice_op.cc')
-rw-r--r--  tensorflow/core/kernels/slice_op.cc  258
1 file changed, 16 insertions(+), 242 deletions(-)
diff --git a/tensorflow/core/kernels/slice_op.cc b/tensorflow/core/kernels/slice_op.cc
index d46701749b..ee6f9a28cd 100644
--- a/tensorflow/core/kernels/slice_op.cc
+++ b/tensorflow/core/kernels/slice_op.cc
@@ -118,43 +118,6 @@ static void SharedValidation(OpKernelContext* context,
}
}
-// Code extracted from SliceOp::Compute so that MklSliceOp can reuse this
-// generic code.
-template <typename T>
-static void SharedSliceCommonCases(OpKernelContext* context,
- TensorShape* output_shape,
- gtl::InlinedVector<int64, 4>* begin,
- gtl::InlinedVector<int64, 4>* size,
- Tensor** result,
- bool* done) {
- bool is_identity = true;
- bool slice_dim0 = true;
- *done = false;
-
- SharedValidation(context, output_shape, &is_identity, &slice_dim0, begin,
- size);
- if (!context->status().ok()) return;
- const Tensor& input = context->input(0);
- if (is_identity) {
- VLOG(1) << "Slice identity";
- context->set_output(0, input);
- *done = true;
- return;
- }
-
- if (slice_dim0 && IsDim0SliceAligned<T>(input.shape(), (*begin)[0],
- (*size)[0])) {
- VLOG(1) << "Slice dim 0: " << input.shape().DebugString();
- CHECK_GE(input.dims(), 1); // Otherwise, is_identity should be true.
- context->set_output(0, input.Slice((*begin)[0], (*begin)[0] + (*size)[0]));
- *done = true;
- return;
- }
-
- OP_REQUIRES_OK(context, context->allocate_output(0, *output_shape, result));
-}
-
-
template <typename Device, typename T>
class SliceOp : public OpKernel {
public:
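For orientation, here is a standalone sketch of the two zero-copy fast paths the removed helper encoded: returning the input unchanged when the slice spans the whole tensor, and aliasing a contiguous dim-0 range the way Tensor::Slice does. Both paths assume the slice is full-extent in every dimension after the first. This is a simplified analogue, not TensorFlow API; View, InnerSize, and SliceDim0 are illustrative names.

#include <cstdint>
#include <functional>
#include <numeric>
#include <vector>

// A borrowed, row-major view of a dense float tensor.
struct View {
  const float* data;
  std::vector<int64_t> shape;
};

// Number of elements in one dim-0 "row": the product of the trailing dims.
int64_t InnerSize(const std::vector<int64_t>& shape) {
  return std::accumulate(shape.begin() + 1, shape.end(), int64_t{1},
                         std::multiplies<int64_t>());
}

// View of input[begin0 : begin0 + size0, ...] without copying: a dim-0
// slice of a dense row-major tensor is contiguous, so offsetting the base
// pointer is enough.
View SliceDim0(const View& in, int64_t begin0, int64_t size0) {
  if (begin0 == 0 && size0 == in.shape[0]) return in;  // identity fast path
  View out = in;
  out.data = in.data + begin0 * InnerSize(in.shape);
  out.shape[0] = size0;
  return out;
}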
@@ -162,89 +125,29 @@ class SliceOp : public OpKernel {
void Compute(OpKernelContext* context) override {
TensorShape output_shape;
+ bool is_identity = true;
+ bool slice_dim0 = true;
gtl::InlinedVector<int64, 4> begin;
gtl::InlinedVector<int64, 4> size;
- Tensor* result = nullptr;
- bool done = false;
- SharedSliceCommonCases<T>(context, &output_shape, &begin, &size, &result,
- &done);
- if (!context->status().ok() || done == true) return;
-
+ SharedValidation(context, &output_shape, &is_identity, &slice_dim0, &begin,
+ &size);
+ if (!context->status().ok()) return;
const Tensor& input = context->input(0);
- const int input_dims = input.dims();
-
- if (output_shape.num_elements() > 0) {
- if (std::is_same<Device, CPUDevice>::value && input_dims == 2 &&
- DataTypeCanUseMemcpy(DataTypeToEnum<T>::v())) {
- auto input = context->input(0).tensor<T, 2>();
- auto output = result->tensor<T, 2>();
- // TODO(agarwal): Consider multi-threading this loop for cases where
- // size[0] is very large.
- for (int i = 0; i < size[0]; ++i) {
- const int64 row = begin[0] + i;
- if (i + 1 < size[0]) {
- port::prefetch<port::PREFETCH_HINT_T0>(&output(i + 1, 0));
- port::prefetch<port::PREFETCH_HINT_T0>(&input(row + 1, begin[1]));
- }
- memcpy(&output(i, 0), &input(row, begin[1]), size[1] * sizeof(T));
- }
- return;
- }
-#define HANDLE_DIM(NDIM) \
- if (input_dims == NDIM) { \
- HandleCase<NDIM>(context, begin, size, result); \
- return; \
- }
-
- HANDLE_DIM(1);
- HANDLE_DIM(2);
- HANDLE_DIM(3);
- HANDLE_DIM(4);
- HANDLE_DIM(5);
- HANDLE_DIM(6);
- HANDLE_DIM(7);
-
-#undef HANDLE_DIM
-
- OP_REQUIRES(context, false, errors::Unimplemented(
- "SliceOp : Unhandled input dimensions"));
+ if (is_identity) {
+ VLOG(1) << "Slice identity";
+ context->set_output(0, input);
+ return;
}
- }
- private:
- template <int NDIM>
- void HandleCase(OpKernelContext* context, const gtl::ArraySlice<int64>& begin,
- const gtl::ArraySlice<int64>& size, Tensor* result) {
- Eigen::DSizes<Eigen::DenseIndex, NDIM> indices;
- Eigen::DSizes<Eigen::DenseIndex, NDIM> sizes;
- for (int i = 0; i < NDIM; ++i) {
- indices[i] = begin[i];
- sizes[i] = size[i];
+ if (slice_dim0 && IsDim0SliceAligned<T>(input.shape(), begin[0], size[0])) {
+ VLOG(1) << "Slice dim 0: " << input.shape().DebugString();
+ CHECK_GE(input.dims(), 1); // Otherwise, is_identity should be true.
+ context->set_output(0, input.Slice(begin[0], begin[0] + size[0]));
+ return;
}
- functor::Slice<Device, T, NDIM>()(
- context->eigen_device<Device>(), result->tensor<T, NDIM>(),
- context->input(0).tensor<T, NDIM>(), indices, sizes);
- }
-};
-
-#ifdef INTEL_MKL
-template <typename Device, typename T>
-class MklSliceOp : public OpKernel {
- public:
- explicit MklSliceOp(OpKernelConstruction* context) : OpKernel(context) {}
-
- void Compute(OpKernelContext* context) override {
- TensorShape output_shape;
- gtl::InlinedVector<int64, 4> begin;
- gtl::InlinedVector<int64, 4> size;
Tensor* result = nullptr;
- bool done = false;
- SharedSliceCommonCases<T>(context, &output_shape, &begin, &size, &result,
- &done);
- if (!context->status().ok() || done == true) return;
-
- const Tensor& input = context->input(0);
+ OP_REQUIRES_OK(context, context->allocate_output(0, output_shape, &result));
const int input_dims = input.dims();
if (output_shape.num_elements() > 0) {
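The rank-2 fast path kept in Compute (its removed duplicate is visible earlier in this hunk) works because each output row is a contiguous run of size[1] elements in the input, so one memcpy per row suffices. A minimal standalone sketch of that idea, with the prefetch hints and Tensor accessors of the real kernel omitted; Slice2D is an illustrative name.

#include <cstdint>
#include <cstring>

// Copies in[begin[0] : begin[0]+size[0], begin[1] : begin[1]+size[1]] into
// out, one contiguous row at a time. in_cols is the input's column count.
template <typename T>
void Slice2D(const T* in, int64_t in_cols, const int64_t begin[2],
             const int64_t size[2], T* out) {
  for (int64_t i = 0; i < size[0]; ++i) {
    const T* src = in + (begin[0] + i) * in_cols + begin[1];
    std::memcpy(out + i * size[1], src, size[1] * sizeof(T));
  }
}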
@@ -286,123 +189,9 @@ class MklSliceOp : public OpKernel {
}
private:
- // Helper function for DoesSliceShapeDifferInOnly1D. Returns true iff the
- // slice's begin indices are 0 in all dimensions except slice_dim, and its
- // sizes equal the input's sizes in all dimensions except slice_dim;
- // otherwise returns false.
- bool DoesSliceShapeDifferInOnly1DHelper(const TensorShape& input_shape,
- const gtl::ArraySlice<int64>& begin,
- const gtl::ArraySlice<int64>& size,
- int slice_dim) {
- for (int dim = 0; dim < 4; dim++) {
- if (dim != slice_dim &&
- (begin[dim] != 0 || size[dim] != input_shape.dim_size(dim))) {
- return false;
- }
- }
- return true;
- }
-
- // Is 'input' tensor being sliced over a single dimension out of 4?
- //
- // This check is applicable in the context of Slice of a 4-D tensor in
- // NHWC or NCHW format over channel dimension.
- //
- // If the slice's begin indices are 0 in all dimensions except one, and its
- // sizes equal the input's sizes in all dimensions except that one, then we
- // are slicing over a single dimension.
- //
- // Returns true if slicing over a single dimension, and sets slice_dim to
- // the dimension that satisfies the criteria.
- bool DoesSliceShapeDifferInOnly1D(const TensorShape& input_shape,
- const gtl::ArraySlice<int64>& begin,
- const gtl::ArraySlice<int64>& size,
- int* slice_dim) {
- for (int dim = 0; dim < 4; dim++) {
- if (DoesSliceShapeDifferInOnly1DHelper(input_shape, begin, size, dim)) {
- *slice_dim = dim;
- return true;
- }
- }
- return false;
- }
-
template <int NDIM>
- void HandleCase(OpKernelContext* context,
- const gtl::ArraySlice<int64>& begin,
+ void HandleCase(OpKernelContext* context, const gtl::ArraySlice<int64>& begin,
const gtl::ArraySlice<int64>& size, Tensor* result) {
- int slice_dim = -1;
- TensorShape in_shape = context->input(0).shape();
- // Special case for handling 4-D tensor slice when shape of the slice
- // differs from the input tensor in only 1 out of 4 dimensions.
- // This case arises in the context of Slice of 4-D tensor in NHWC or NCHW
- // format over channel dimension.
- if (NDIM == 4 &&
- DoesSliceShapeDifferInOnly1D(in_shape, begin, size, &slice_dim)) {
- size_t in_strides[4] = { (size_t) in_shape.dim_size(1) *
- in_shape.dim_size(2) *
- in_shape.dim_size(3),
- (size_t) in_shape.dim_size(2) *
- in_shape.dim_size(3),
- (size_t) in_shape.dim_size(3),
- (size_t) 1
- };
-
- size_t out_strides[4] = { (size_t) size[1] * size[2] * size[3],
- (size_t) size[2] * size[3],
- (size_t) size[3],
- (size_t) 1 };
-
- T *in_buf = const_cast<T*>(const_cast<const T*>(
- context->input(0).flat<T>().data()));
- T *op_buf = result->flat<T>().data();
-
- if (slice_dim == 1) {
- /* data format = NCHW */
-
- #pragma omp parallel for
- for (size_t d0 = begin[0]; d0 < begin[0] + size[0]; d0++) {
- T *ip = in_buf + (d0 * in_strides[0]);
- T *op = op_buf + ((d0 - begin[0]) * out_strides[0]);
- #pragma omp parallel for
- for (size_t d1 = begin[1]; d1 < begin[1] + size[1]; d1++) {
- T *ip1 = ip + (d1 * in_strides[1]);
- T *op1 = op + ((d1 - begin[1]) * out_strides[1]);
- // For NCHW, H and W will be contiguous. So we can copy
- // both with one memcpy.
- memcpy(static_cast<void*>(op1), static_cast<void*>(ip1),
- sizeof(T) * in_strides[1]);
- }
- }
- return;
- } else if (slice_dim == 3) {
- /* data_format = NHWC */
-
- #pragma omp parallel for
- for (size_t d0 = begin[0]; d0 < begin[0] + size[0]; d0++) {
- T *ip = in_buf + (d0 * in_strides[0]);
- T *op = op_buf + ((d0 - begin[0]) * out_strides[0]);
- #pragma omp parallel for
- for (size_t d1 = begin[1]; d1 < begin[1] + size[1]; d1++) {
- T *ip1 = ip + (d1 * in_strides[1]);
- T *op1 = op + ((d1 - begin[1]) * out_strides[1]);
- #pragma omp parallel for
- for (size_t d2 = begin[2]; d2 < begin[2] + size[2]; d2++) {
- T *ip2 = ip1 + (d2 * in_strides[2]);
- T *ip3 = ip2 + begin[3];
- T *op2 = op1 + ((d2 - begin[2]) * out_strides[2]);
- T *op3 = op2;
- memcpy(static_cast<void*>(op3), static_cast<void*>(ip3),
- sizeof(T) * size[3]);
- }
- }
- }
- return;
- }
- // If slice_dim is not 1 or 3, we fall back to the Eigen implementation.
- }
-
Eigen::DSizes<Eigen::DenseIndex, NDIM> indices;
Eigen::DSizes<Eigen::DenseIndex, NDIM> sizes;
for (int i = 0; i < NDIM; ++i) {
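The deleted special case above rests on one observation: when a 4-D NHWC tensor is sliced only over the channel dimension, the copy reduces to one memcpy of size[3] elements per (n, h, w) coordinate. A simplified standalone sketch of that access pattern; SliceChannels is an illustrative name, and the original's OpenMP parallelism is omitted.

#include <cstdint>
#include <cstring>

// Slices channels [c_begin, c_begin + c_size) out of a dense NHWC tensor
// whose dimensions are dims[0..3] = {N, H, W, C}. Assumes the slice is
// full-extent in N, H, and W, which is the precondition the removed
// DoesSliceShapeDifferInOnly1D check established.
template <typename T>
void SliceChannels(const T* in, const int64_t dims[4], int64_t c_begin,
                   int64_t c_size, T* out) {
  const int64_t in_row = dims[3];                    // one (n,h,w) row
  const int64_t rows = dims[0] * dims[1] * dims[2];  // number of such rows
  for (int64_t r = 0; r < rows; ++r) {
    std::memcpy(out + r * c_size, in + r * in_row + c_begin,
                c_size * sizeof(T));
  }
}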
@@ -415,7 +204,6 @@ class MklSliceOp : public OpKernel {
context->input(0).tensor<T, NDIM>(), indices, sizes);
}
};
-#endif
// Forward declarations of the functor specializations declared in the
// sharded source files.
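The generic path the kernel falls back to converts begin/size into Eigen index and extent arrays and lets functor::Slice perform the strided copy on the target device. A minimal sketch of the same operation written directly against Eigen's unsupported Tensor module, with the device dispatch omitted; EigenSlice is an illustrative name.

#include <unsupported/Eigen/CXX11/Tensor>

// Materializes input.slice(begin, size): begin holds per-dimension start
// indices, size the per-dimension extents.
template <typename T, int NDIM>
Eigen::Tensor<T, NDIM> EigenSlice(
    const Eigen::Tensor<T, NDIM>& input,
    const Eigen::DSizes<Eigen::DenseIndex, NDIM>& begin,
    const Eigen::DSizes<Eigen::DenseIndex, NDIM>& size) {
  Eigen::Tensor<T, NDIM> out = input.slice(begin, size);
  return out;
}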
@@ -445,7 +233,6 @@ DECLARE_FOR_N(bfloat16);
#undef DECLARE_CPU_SPEC
} // namespace functor
-#ifndef INTEL_MKL
#define REGISTER_SLICE(type) \
REGISTER_KERNEL_BUILDER(Name("Slice") \
.Device(DEVICE_CPU) \
@@ -457,21 +244,8 @@ DECLARE_FOR_N(bfloat16);
TF_CALL_POD_STRING_TYPES(REGISTER_SLICE);
TF_CALL_QUANTIZED_TYPES(REGISTER_SLICE);
REGISTER_SLICE(bfloat16);
-#undef REGISTER_SLICE
-#else
-#define REGISTER_SLICE(type) \
- REGISTER_KERNEL_BUILDER(Name("Slice") \
- .Device(DEVICE_CPU) \
- .TypeConstraint<type>("T") \
- .HostMemory("begin") \
- .HostMemory("size"), \
- MklSliceOp<CPUDevice, type>)
-TF_CALL_POD_STRING_TYPES(REGISTER_SLICE);
-TF_CALL_QUANTIZED_TYPES(REGISTER_SLICE);
-REGISTER_SLICE(bfloat16);
#undef REGISTER_SLICE
-#endif // INTEL_MKL
#if GOOGLE_CUDA
// Forward declarations of the functor specializations for GPU.
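Stepping back: after this change a single REGISTER_SLICE macro registers the one remaining CPU kernel for every supported type. The macro's tail falls outside the hunk context shown above, so the expansion below infers the final argument by symmetry with the removed MKL variant; treat it as an illustration rather than the verbatim macro body.

// Hypothetical expansion of REGISTER_SLICE(float). The final argument is
// assumed to be SliceOp<CPUDevice, float>, since the macro tail is not
// visible in this hunk.
REGISTER_KERNEL_BUILDER(Name("Slice")
                            .Device(DEVICE_CPU)
                            .TypeConstraint<float>("T")
                            .HostMemory("begin")
                            .HostMemory("size"),
                        SliceOp<CPUDevice, float>);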