diff options
Diffstat (limited to 'tensorflow/core/kernels/slice_op.cc')
-rw-r--r-- | tensorflow/core/kernels/slice_op.cc | 258 |
1 files changed, 16 insertions, 242 deletions
diff --git a/tensorflow/core/kernels/slice_op.cc b/tensorflow/core/kernels/slice_op.cc index d46701749b..ee6f9a28cd 100644 --- a/tensorflow/core/kernels/slice_op.cc +++ b/tensorflow/core/kernels/slice_op.cc @@ -118,43 +118,6 @@ static void SharedValidation(OpKernelContext* context, } } -// Extracted out code in SliceOp::Compute so that MklSliceOp can reuse this -// generic code -template <typename T> -static void SharedSliceCommonCases(OpKernelContext* context, - TensorShape* output_shape, - gtl::InlinedVector<int64, 4>* begin, - gtl::InlinedVector<int64, 4>* size, - Tensor** result, - bool* done) { - bool is_identity = true; - bool slice_dim0 = true; - *done = false; - - SharedValidation(context, output_shape, &is_identity, &slice_dim0, begin, - size); - if (!context->status().ok()) return; - const Tensor& input = context->input(0); - if (is_identity) { - VLOG(1) << "Slice identity"; - context->set_output(0, input); - *done = true; - return; - } - - if (slice_dim0 && IsDim0SliceAligned<T>(input.shape(), (*begin)[0], - (*size)[0])) { - VLOG(1) << "Slice dim 0: " << input.shape().DebugString(); - CHECK_GE(input.dims(), 1); // Otherwise, is_identity should be true. - context->set_output(0, input.Slice((*begin)[0], (*begin)[0] + (*size)[0])); - *done = true; - return; - } - - OP_REQUIRES_OK(context, context->allocate_output(0, *output_shape, result)); -} - - template <typename Device, typename T> class SliceOp : public OpKernel { public: @@ -162,89 +125,29 @@ class SliceOp : public OpKernel { void Compute(OpKernelContext* context) override { TensorShape output_shape; + bool is_identity = true; + bool slice_dim0 = true; gtl::InlinedVector<int64, 4> begin; gtl::InlinedVector<int64, 4> size; - Tensor* result = nullptr; - bool done = false; - SharedSliceCommonCases<T>(context, &output_shape, &begin, &size, &result, - &done); - if (!context->status().ok() || done == true) return; - + SharedValidation(context, &output_shape, &is_identity, &slice_dim0, &begin, + &size); + if (!context->status().ok()) return; const Tensor& input = context->input(0); - const int input_dims = input.dims(); - - if (output_shape.num_elements() > 0) { - if (std::is_same<Device, CPUDevice>::value && input_dims == 2 && - DataTypeCanUseMemcpy(DataTypeToEnum<T>::v())) { - auto input = context->input(0).tensor<T, 2>(); - auto output = result->tensor<T, 2>(); - // TODO(agarwal): Consider multi-threading this loop for cases where - // size[0] is very large. - for (int i = 0; i < size[0]; ++i) { - const int64 row = begin[0] + i; - if (i + 1 < size[0]) { - port::prefetch<port::PREFETCH_HINT_T0>(&output(i + 1, 0)); - port::prefetch<port::PREFETCH_HINT_T0>(&input(row + 1, begin[1])); - } - memcpy(&output(i, 0), &input(row, begin[1]), size[1] * sizeof(T)); - } - return; - } -#define HANDLE_DIM(NDIM) \ - if (input_dims == NDIM) { \ - HandleCase<NDIM>(context, begin, size, result); \ - return; \ - } - - HANDLE_DIM(1); - HANDLE_DIM(2); - HANDLE_DIM(3); - HANDLE_DIM(4); - HANDLE_DIM(5); - HANDLE_DIM(6); - HANDLE_DIM(7); - -#undef HANDLE_DIM - - OP_REQUIRES(context, false, errors::Unimplemented( - "SliceOp : Unhandled input dimensions")); + if (is_identity) { + VLOG(1) << "Slice identity"; + context->set_output(0, input); + return; } - } - private: - template <int NDIM> - void HandleCase(OpKernelContext* context, const gtl::ArraySlice<int64>& begin, - const gtl::ArraySlice<int64>& size, Tensor* result) { - Eigen::DSizes<Eigen::DenseIndex, NDIM> indices; - Eigen::DSizes<Eigen::DenseIndex, NDIM> sizes; - for (int i = 0; i < NDIM; ++i) { - indices[i] = begin[i]; - sizes[i] = size[i]; + if (slice_dim0 && IsDim0SliceAligned<T>(input.shape(), begin[0], size[0])) { + VLOG(1) << "Slice dim 0: " << input.shape().DebugString(); + CHECK_GE(input.dims(), 1); // Otherwise, is_identity should be true. + context->set_output(0, input.Slice(begin[0], begin[0] + size[0])); + return; } - functor::Slice<Device, T, NDIM>()( - context->eigen_device<Device>(), result->tensor<T, NDIM>(), - context->input(0).tensor<T, NDIM>(), indices, sizes); - } -}; - -#ifdef INTEL_MKL -template <typename Device, typename T> -class MklSliceOp : public OpKernel { - public: - explicit MklSliceOp(OpKernelConstruction* context) : OpKernel(context) {} - - void Compute(OpKernelContext* context) override { - TensorShape output_shape; - gtl::InlinedVector<int64, 4> begin; - gtl::InlinedVector<int64, 4> size; Tensor* result = nullptr; - bool done = false; - SharedSliceCommonCases<T>(context, &output_shape, &begin, &size, &result, - &done); - if (!context->status().ok() || done == true) return; - - const Tensor& input = context->input(0); + OP_REQUIRES_OK(context, context->allocate_output(0, output_shape, &result)); const int input_dims = input.dims(); if (output_shape.num_elements() > 0) { @@ -286,123 +189,9 @@ class MklSliceOp : public OpKernel { } private: - // Helper function for DoesSliceShapeDifferInOnly1D. Checks if the following - // criteria matches for slice_dim: if indices for slice are 0 in all dims - // except slice_dim and if sizes of all the dimensions of the slice are same - // as the sizes of all the dimensions of the input except slice_dim, then - // returns True. Otherwise, returns False. - bool DoesSliceShapeDifferInOnly1DHelper(const TensorShape& input_shape, - const gtl::ArraySlice<int64>& begin, - const gtl::ArraySlice<int64>& size, - int slice_dim) { - for (int dim = 0; dim < 4; dim++) { - if (dim != slice_dim && - (begin[dim] != 0 || size[dim] != input_shape.dim_size(dim))) { - return false; - } - } - return true; - } - - // Is 'input' tensor being sliced over a single dimension out of 4? - // - // This check is applicable in the context of Slice of a 4-D tensor in - // NHWC or NCHW format over channel dimension. - // - // If indices for slice are 0 in all dims except one dimension and if sizes of - // all dimensions of slice are same as sizes of all dimensions of inputs - // except that dimension, then we are slicing over a single dimension. - // - // Returns True if Slicing over a single dimension, and sets slice_dim - // to the number of the dimension that satisfies criteria. - bool DoesSliceShapeDifferInOnly1D(const TensorShape& input_shape, - const gtl::ArraySlice<int64>& begin, - const gtl::ArraySlice<int64>& size, - int* slice_dim) { - for (int dim = 0; dim < 4; dim++) { - if (DoesSliceShapeDifferInOnly1DHelper(input_shape, begin, size, dim)) { - *slice_dim = dim; - return true; - } - } - return false; - } - template <int NDIM> - void HandleCase(OpKernelContext* context, - const gtl::ArraySlice<int64>& begin, + void HandleCase(OpKernelContext* context, const gtl::ArraySlice<int64>& begin, const gtl::ArraySlice<int64>& size, Tensor* result) { - int slice_dim = -1; - TensorShape in_shape = context->input(0).shape(); - // Special case for handling 4-D tensor slice when shape of the slice - // differs from the input tensor in only 1 out of 4 dimensions. - // This case arises in the context of Slice of 4-D tensor in NHWC or NCHW - // format over channel dimension. - if (NDIM == 4 && - DoesSliceShapeDifferInOnly1D(in_shape, begin, size, &slice_dim)) { - size_t in_strides[4] = { (size_t) in_shape.dim_size(1) * - in_shape.dim_size(2) * - in_shape.dim_size(3), - (size_t) in_shape.dim_size(2) * - in_shape.dim_size(3), - (size_t) in_shape.dim_size(3), - (size_t) 1 - }; - - size_t out_strides[4] = { (size_t) size[1] * size[2] * size[3], - (size_t) size[2] * size[3], - (size_t) size[3], - (size_t) 1 }; - - T *in_buf = const_cast<T*>(const_cast<const T*>( - context->input(0).flat<T>().data())); - T *op_buf = result->flat<T>().data(); - - if (slice_dim == 1) { - /* data format = NCHW */ - - #pragma omp parallel for - for (size_t d0 = begin[0]; d0 < begin[0] + size[0]; d0++) { - T *ip = in_buf + (d0 * in_strides[0]); - T *op = op_buf + ((d0 - begin[0]) * out_strides[0]); - #pragma omp parallel for - for (size_t d1 = begin[1]; d1 < begin[1] + size[1]; d1++) { - T *ip1 = ip + (d1 * in_strides[1]); - T *op1 = op + ((d1 - begin[1]) * out_strides[1]); - // For NCHW, H and W will be contiguous. So we can copy - // both with one memcpy. - memcpy(static_cast<void*>(op1), static_cast<void*>(ip1), - sizeof(T) * in_strides[1]); - } - } - return; - } else if (slice_dim == 3) { - /* data_format = NHWC */ - - #pragma omp parallel for - for (size_t d0 = begin[0]; d0 < begin[0] + size[0]; d0++) { - T *ip = in_buf + (d0 * in_strides[0]); - T *op = op_buf + ((d0 - begin[0]) * out_strides[0]); - #pragma omp parallel for - for (size_t d1 = begin[1]; d1 < begin[1] + size[1]; d1++) { - T *ip1 = ip + (d1 * in_strides[1]); - T *op1 = op + ((d1 - begin[1]) * out_strides[1]); - #pragma omp parallel for - for (size_t d2 = begin[2]; d2 < begin[2] + size[2]; d2++) { - T *ip2 = ip1 + (d2 * in_strides[2]); - T *ip3 = ip2 + begin[3]; - T *op2 = op1 + ((d2 - begin[2]) * out_strides[2]); - T *op3 = op2; - memcpy(static_cast<void*>(op3), static_cast<void*>(ip3), - sizeof(T) * size[3]); - } - } - } - return; - } - // slice_dim is not 1 or 3, then we fallback to Eigen implementation. - } - Eigen::DSizes<Eigen::DenseIndex, NDIM> indices; Eigen::DSizes<Eigen::DenseIndex, NDIM> sizes; for (int i = 0; i < NDIM; ++i) { @@ -415,7 +204,6 @@ class MklSliceOp : public OpKernel { context->input(0).tensor<T, NDIM>(), indices, sizes); } }; -#endif // Forward declarations of the functor specializations for declared in the // sharded source files. @@ -445,7 +233,6 @@ DECLARE_FOR_N(bfloat16); #undef DECLARE_CPU_SPEC } // namespace functor -#ifndef INTEL_MKL #define REGISTER_SLICE(type) \ REGISTER_KERNEL_BUILDER(Name("Slice") \ .Device(DEVICE_CPU) \ @@ -457,21 +244,8 @@ DECLARE_FOR_N(bfloat16); TF_CALL_POD_STRING_TYPES(REGISTER_SLICE); TF_CALL_QUANTIZED_TYPES(REGISTER_SLICE); REGISTER_SLICE(bfloat16); -#undef REGISTER_SLICE -#else -#define REGISTER_SLICE(type) \ - REGISTER_KERNEL_BUILDER(Name("Slice") \ - .Device(DEVICE_CPU) \ - .TypeConstraint<type>("T") \ - .HostMemory("begin") \ - .HostMemory("size"), \ - MklSliceOp<CPUDevice, type>) -TF_CALL_POD_STRING_TYPES(REGISTER_SLICE); -TF_CALL_QUANTIZED_TYPES(REGISTER_SLICE); -REGISTER_SLICE(bfloat16); #undef REGISTER_SLICE -#endif // INTEL_MKL #if GOOGLE_CUDA // Forward declarations of the functor specializations for GPU. |