path: root/tensorflow/core/kernels/slice_op.cc
author    A. Unique TensorFlower <gardener@tensorflow.org>    2017-06-26 12:54:12 -0700
committer TensorFlower Gardener <gardener@tensorflow.org>    2017-06-26 12:57:46 -0700
commit    f3c89936e97c99dead1ca3310246691c1b221adf (patch)
tree      3c99b66936ed59028b32609115a239f52798907d /tensorflow/core/kernels/slice_op.cc
parent    0b9b09a8531004b44b133a52c3fcc67bc6759bd8 (diff)
Merge changes from github.
END_PUBLIC
Note: this CL will break builds. cl/159887762 to follow to fix all the breakages.

--- Commit 2336cdf7f authored by Maxwell Paul Brickner<mbrickn@users.noreply.github.com>, committed by gunan<gunan@google.com>:
    Updated link to use HTTPS (#10998)
    Howdy! I just updated a link to use https instead of http. Thanks!
--- Commit ad0892df1 authored by Luke Iwanski<luke@codeplay.com>, committed by Luke Iwanski<luke@codeplay.com>:
    [OpenCL] Fixes run_metadata_test for SYCL. This test is designed to test CUDA-specific behavior.
--- Commit 6b37a0725 authored by Todd Wang<toddwang@gmail.com>, committed by GitHub<noreply@github.com>:
    Update comments
--- Commit 1699d904a authored by John Lawson<john@codeplay.com>, committed by Luke Iwanski<luke@codeplay.com>:
    [OpenCL] Fixes CUDA-specific test run on SYCL (#56)
    The testBadParentValuesOnGPU test should only be run on CUDA devices, as it checks for particular CUDA behaviour. We don't actually provide a SYCL kernel for GatherTree, so it's not a problem that the tests don't target SYCL.
--- Commit 3c1946230 authored by myPrecious<Moriadry@users.noreply.github.com>, committed by Shanqing Cai<cais@google.com>:
    Java API to get the size of specified input list of operations. (#10865)
    * Java API to get the size of specified input list of operations
    * Remove unnecessary explanation to avoid bringing a new term to users
--- Commit e911c7480 authored by Luke Iwanski<luke@codeplay.com>, committed by Luke Iwanski<luke@codeplay.com>:
    [OpenCL] REGISTER -> REGISTER6
--- Commit fbf6c4cec authored by superryanguo<superryanguo@gmail.com>, committed by superryanguo<superryanguo@gmail.com>:
    Simplify the Quickstart section; the weblink is better
--- Commit 72e2918cc authored by Taehoon Lee<taehoonlee@snu.ac.kr>, committed by Taehoon Lee<taehoonlee@snu.ac.kr>:
    Fix typos
--- Commit 90c4406b7 authored by Rishabh Patel<patelrishabh@users.noreply.github.com>, committed by GitHub<noreply@github.com>:
    Correct the learning rate as per the code snippet
--- Commit 03da61134 authored by Todd Wang<toddwang@gmail.com>, committed by GitHub<noreply@github.com>:
    Update ir_array.cc
--- Commit 2df6cd3ac authored by Todd Wang<toddwang@gmail.com>, committed by GitHub<noreply@github.com>:
    Another try
--- Commit af0cbace1 authored by Luke Iwanski<luke@codeplay.com>, committed by Benoit Steiner<benoitsteiner@users.noreply.github.com>:
    [OpenCL] Transpose to go through Eigen (#10321)
--- Commit fc7361081 authored by Luke Iwanski<luke@codeplay.com>, committed by Benoit Steiner<benoitsteiner@users.noreply.github.com>:
    [OpenCL] Registers RGBToHSV and HSVToRGB (#91) (#10848)
    * [OpenCL] Added RGBToHSV and HSVToRGB
    * Aligning '\'
--- Commit 832894ef8 authored by Luke Iwanski<luke@codeplay.com>, committed by Benoit Steiner<benoitsteiner@users.noreply.github.com>:
    [OpenCL] Registers AdjustContrastv2 (#10949)
    * [OpenCL] Registers AdjustContrastv2 (#93)
    * [OpenCL] Extended adjust_contrast_op_benchmark_test for OpenCL (#96)
    * Simplified to #ifndef
    * Changed to "#if GOOGLE_CUDA"
    * Update adjust_contrast_op_benchmark_test.cc
    * Added comments
--- Commit cb4c2f8d1 authored by Yifei Feng<yifeif@google.com>, committed by Yifei Feng<yifeif@google.com>:
    Make TransferBufferToInFeed not virtual so it compiles.
--- Commit e89f04d80 authored by Yifei Feng<yifeif@google.com>, committed by Yifei Feng<yifeif@google.com>:
    Fix calling Literal member functions.
--- Commit 15a8df724 authored by Yifei Feng<yifeif@google.com>, committed by Yifei Feng<yifeif@google.com>:
    Fix mac build; clone of meheff's change: [XLA] Change return type of DeviceAssignment::Deserialize to fix build breakage on mac. The mac build had the following error:
        error: incomplete type 'xla::DeviceAssignment' used in type trait expression
    This was due to a static method returning a StatusOr<DeviceAssignment> inside of the definition of DeviceAssignment.
--- Commit a54d43fa4 authored by Yifei Feng<yifeif@google.com>, committed by Yifei Feng<yifeif@google.com>:
    Replace LiteralUtil with Literal in compiler/plugin/executor
--- Commit 88a6bb80c authored by Guenther Schmuelling<guschmue@microsoft.com>, committed by Guenther Schmuelling<guschmue@microsoft.com>:
    Expand inline for debug builds to limit the number of symbols
--- Commit 62fb49d31 authored by Yifei Feng<yifeif@google.com>, committed by Yifei Feng<yifeif@google.com>:
    Fix visibility error for contrib/remote_fused_graph/pylib/BUILD.
--- Commit 4c75252f2 authored by Mark Neumann<markn@allenai.org>, committed by Mark Neumann<markn@allenai.org>:
    Fix initial test values to avoid numerical instability
--- Commit b58d98353 authored by sj6077<epik03sj@gmail.com>, committed by Benoit Steiner<benoitsteiner@users.noreply.github.com>:
    Fixes of AutoParallel bug (#10368)
    * Fix the bug that auto_parallel could replicate variable snapshot name
    * Use NodeName in grappler:utils instead of substr; convert variables->variable_def of grappler item
    * Remove variable_def from grappler item; exclude snapshot nodes from dont_replicate_nodes in auto_parallel
--- Commit a286b7db8 authored by Yifei Feng<yifeif@google.com>, committed by Yifei Feng<yifeif@google.com>:
    Make debug_test slice integer.
--- Commit 97fcfdfa6 authored by Toby Boyd<tobyboyd@google.com>, committed by GitHub<noreply@github.com>:
    Fixed path to seq2seq.py and minor formatting
--- Commit 63c1befb8 authored by Anish Shah<shah.anish07@gmail.com>, committed by Anish Shah<shah.anish07@gmail.com>:
    Improve docs for tf.nn.depthwise_conv2d_native
--- Commit 8d42202b2 authored by Yong Tang<yong.tang.github@outlook.com>, committed by Yong Tang<yong.tang.github@outlook.com>:
    Fix mismatched delete in mkl_tfconv_op.cc
    This fixes a mismatched new[]/delete in mkl_tfconv_op.cc (the file went through clang-format, so there are some additional changes).
    Signed-off-by: Yong Tang <yong.tang.github@outlook.com>
--- Commit 26301bd55 authored by Danny Goodman<goodman.danny@gmail.com>, committed by Danny Goodman<goodman.danny@gmail.com>:
    Fix error format
--- Commit b3f33ad46 authored by Yao Zhang<yaozhang@google.com>, committed by TensorFlower Gardener<gardener@tensorflow.org>:
    Make changes to prepare for the fused option of batch norm to be set to None (None means using fused batch norm if possible).
    PiperOrigin-RevId: 159649743
--- Commit a4a469832 authored by A. Unique TensorFlower<gardener@tensorflow.org>, committed by TensorFlower Gardener<gardener@tensorflow.org>:
    [XLA] Add tests for select ops and while loops that produce tuples that contain predicates.
    PiperOrigin-RevId: 159645900
--- Commit 980d3f2be authored by A. Unique TensorFlower<gardener@tensorflow.org>, committed by TensorFlower Gardener<gardener@tensorflow.org>:
    Use C API to implement Operation.name property
    This name property is used in many existing tests, including those that already run with the C API enabled (math_ops_test, framework_ops_test, session_test, session_partial_run_test, math_ops_test_gpu, etc).
    PiperOrigin-RevId: 159645767
--- Commit 26239c706 authored by A. Unique TensorFlower<gardener@tensorflow.org>, committed by TensorFlower Gardener<gardener@tensorflow.org>:
    Previously we didn't have an implementation of BatchNormInference and BatchNormTraining, which gives a linker error if anyone ever tries to call them. A dummy implementation is friendlier than a linker error.
    PiperOrigin-RevId: 159645612
--- Commit f671c5caa authored by A. Unique TensorFlower<gardener@tensorflow.org>, committed by TensorFlower Gardener<gardener@tensorflow.org>:
    BEGIN_PUBLIC
    Automated g4 rollback of changelist 159570549

PiperOrigin-RevId: 160182040
Diffstat (limited to 'tensorflow/core/kernels/slice_op.cc')
-rw-r--r--  tensorflow/core/kernels/slice_op.cc | 258
1 file changed, 242 insertions(+), 16 deletions(-)
diff --git a/tensorflow/core/kernels/slice_op.cc b/tensorflow/core/kernels/slice_op.cc
index ee6f9a28cd..d46701749b 100644
--- a/tensorflow/core/kernels/slice_op.cc
+++ b/tensorflow/core/kernels/slice_op.cc
@@ -118,6 +118,43 @@ static void SharedValidation(OpKernelContext* context,
}
}
+// Code extracted from SliceOp::Compute so that MklSliceOp can reuse this
+// generic handling of the common slice cases.
+template <typename T>
+static void SharedSliceCommonCases(OpKernelContext* context,
+ TensorShape* output_shape,
+ gtl::InlinedVector<int64, 4>* begin,
+ gtl::InlinedVector<int64, 4>* size,
+ Tensor** result,
+ bool* done) {
+ bool is_identity = true;
+ bool slice_dim0 = true;
+ *done = false;
+
+ SharedValidation(context, output_shape, &is_identity, &slice_dim0, begin,
+ size);
+ if (!context->status().ok()) return;
+ const Tensor& input = context->input(0);
+ if (is_identity) {
+ VLOG(1) << "Slice identity";
+ context->set_output(0, input);
+ *done = true;
+ return;
+ }
+
+ if (slice_dim0 && IsDim0SliceAligned<T>(input.shape(), (*begin)[0],
+ (*size)[0])) {
+ VLOG(1) << "Slice dim 0: " << input.shape().DebugString();
+ CHECK_GE(input.dims(), 1); // Otherwise, is_identity should be true.
+ context->set_output(0, input.Slice((*begin)[0], (*begin)[0] + (*size)[0]));
+ *done = true;
+ return;
+ }
+
+ OP_REQUIRES_OK(context, context->allocate_output(0, *output_shape, result));
+}
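The helper above resolves three cases: an identity slice forwards the input tensor unchanged, an aligned dim-0 slice aliases the input buffer via Tensor::Slice, and only the general case allocates a fresh output for the caller to fill. A minimal sketch of the aliasing fast path, assuming TensorFlow's C++ Tensor API (the helper name and indices here are hypothetical):

    #include "tensorflow/core/framework/tensor.h"

    // Tensor::Slice(dim0_start, dim0_limit) copies no elements; the returned
    // tensor shares the underlying buffer with 'input'.
    tensorflow::Tensor TakeRows(const tensorflow::Tensor& input) {
      return input.Slice(/*dim0_start=*/2, /*dim0_limit=*/5);  // rows [2, 5)
    }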
+
+
template <typename Device, typename T>
class SliceOp : public OpKernel {
public:
@@ -125,29 +162,89 @@ class SliceOp : public OpKernel {
void Compute(OpKernelContext* context) override {
TensorShape output_shape;
- bool is_identity = true;
- bool slice_dim0 = true;
gtl::InlinedVector<int64, 4> begin;
gtl::InlinedVector<int64, 4> size;
- SharedValidation(context, &output_shape, &is_identity, &slice_dim0, &begin,
- &size);
- if (!context->status().ok()) return;
+ Tensor* result = nullptr;
+ bool done = false;
+ SharedSliceCommonCases<T>(context, &output_shape, &begin, &size, &result,
+ &done);
+ if (!context->status().ok() || done == true) return;
+
const Tensor& input = context->input(0);
- if (is_identity) {
- VLOG(1) << "Slice identity";
- context->set_output(0, input);
- return;
+ const int input_dims = input.dims();
+
+ if (output_shape.num_elements() > 0) {
+ if (std::is_same<Device, CPUDevice>::value && input_dims == 2 &&
+ DataTypeCanUseMemcpy(DataTypeToEnum<T>::v())) {
+ auto input = context->input(0).tensor<T, 2>();
+ auto output = result->tensor<T, 2>();
+ // TODO(agarwal): Consider multi-threading this loop for cases where
+ // size[0] is very large.
+ for (int i = 0; i < size[0]; ++i) {
+ const int64 row = begin[0] + i;
+ if (i + 1 < size[0]) {
+ port::prefetch<port::PREFETCH_HINT_T0>(&output(i + 1, 0));
+ port::prefetch<port::PREFETCH_HINT_T0>(&input(row + 1, begin[1]));
+ }
+ memcpy(&output(i, 0), &input(row, begin[1]), size[1] * sizeof(T));
+ }
+ return;
+ }
+#define HANDLE_DIM(NDIM) \
+ if (input_dims == NDIM) { \
+ HandleCase<NDIM>(context, begin, size, result); \
+ return; \
+ }
+
+ HANDLE_DIM(1);
+ HANDLE_DIM(2);
+ HANDLE_DIM(3);
+ HANDLE_DIM(4);
+ HANDLE_DIM(5);
+ HANDLE_DIM(6);
+ HANDLE_DIM(7);
+
+#undef HANDLE_DIM
+
+ OP_REQUIRES(context, false, errors::Unimplemented(
+ "SliceOp : Unhandled input dimensions"));
}
+ }
- if (slice_dim0 && IsDim0SliceAligned<T>(input.shape(), begin[0], size[0])) {
- VLOG(1) << "Slice dim 0: " << input.shape().DebugString();
- CHECK_GE(input.dims(), 1); // Otherwise, is_identity should be true.
- context->set_output(0, input.Slice(begin[0], begin[0] + size[0]));
- return;
+ private:
+ template <int NDIM>
+ void HandleCase(OpKernelContext* context, const gtl::ArraySlice<int64>& begin,
+ const gtl::ArraySlice<int64>& size, Tensor* result) {
+ Eigen::DSizes<Eigen::DenseIndex, NDIM> indices;
+ Eigen::DSizes<Eigen::DenseIndex, NDIM> sizes;
+ for (int i = 0; i < NDIM; ++i) {
+ indices[i] = begin[i];
+ sizes[i] = size[i];
}
+ functor::Slice<Device, T, NDIM>()(
+ context->eigen_device<Device>(), result->tensor<T, NDIM>(),
+ context->input(0).tensor<T, NDIM>(), indices, sizes);
+ }
+};
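The rank-2 fast path in Compute above reduces, for a row-major matrix, to one contiguous copy per output row. A standalone sketch with plain arrays (the name and signature are illustrative, not part of the kernel):

    #include <cstdint>
    #include <cstring>

    // Copy the size0 x size1 block starting at (begin0, begin1) out of a
    // row-major matrix with in_cols columns. Each output row is contiguous,
    // so a single memcpy per row suffices; the kernel above additionally
    // prefetches the next input and output rows.
    void Slice2D(const float* in, int64_t in_cols, float* out, int64_t begin0,
                 int64_t begin1, int64_t size0, int64_t size1) {
      for (int64_t i = 0; i < size0; ++i) {
        std::memcpy(out + i * size1, in + (begin0 + i) * in_cols + begin1,
                    size1 * sizeof(float));
      }
    }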
+
+#ifdef INTEL_MKL
+template <typename Device, typename T>
+class MklSliceOp : public OpKernel {
+ public:
+ explicit MklSliceOp(OpKernelConstruction* context) : OpKernel(context) {}
+
+ void Compute(OpKernelContext* context) override {
+ TensorShape output_shape;
+ gtl::InlinedVector<int64, 4> begin;
+ gtl::InlinedVector<int64, 4> size;
Tensor* result = nullptr;
- OP_REQUIRES_OK(context, context->allocate_output(0, output_shape, &result));
+ bool done = false;
+ SharedSliceCommonCases<T>(context, &output_shape, &begin, &size, &result,
+ &done);
+ if (!context->status().ok() || done == true) return;
+
+ const Tensor& input = context->input(0);
const int input_dims = input.dims();
if (output_shape.num_elements() > 0) {
@@ -189,9 +286,123 @@ class SliceOp : public OpKernel {
}
private:
+ // Helper for DoesSliceShapeDifferInOnly1D. Returns true iff the slice
+ // begins at 0 in every dimension except slice_dim and its size matches the
+ // input size in every dimension except slice_dim; otherwise returns false.
+ bool DoesSliceShapeDifferInOnly1DHelper(const TensorShape& input_shape,
+ const gtl::ArraySlice<int64>& begin,
+ const gtl::ArraySlice<int64>& size,
+ int slice_dim) {
+ for (int dim = 0; dim < 4; dim++) {
+ if (dim != slice_dim &&
+ (begin[dim] != 0 || size[dim] != input_shape.dim_size(dim))) {
+ return false;
+ }
+ }
+ return true;
+ }
+
+ // Is the 'input' tensor being sliced over a single dimension out of 4?
+ //
+ // This check is applicable when slicing a 4-D tensor in NHWC or NCHW
+ // format over the channel dimension.
+ //
+ // If the slice begins at 0 in all dimensions except one, and its size
+ // matches the input size in all dimensions except that one, then we are
+ // slicing over a single dimension.
+ //
+ // Returns true if slicing over a single dimension, and sets slice_dim to
+ // the index of the dimension that meets these criteria.
+ bool DoesSliceShapeDifferInOnly1D(const TensorShape& input_shape,
+ const gtl::ArraySlice<int64>& begin,
+ const gtl::ArraySlice<int64>& size,
+ int* slice_dim) {
+ for (int dim = 0; dim < 4; dim++) {
+ if (DoesSliceShapeDifferInOnly1DHelper(input_shape, begin, size, dim)) {
+ *slice_dim = dim;
+ return true;
+ }
+ }
+ return false;
+ }
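As a standalone illustration of the predicate above (plain arrays instead of TensorShape; the function name is hypothetical), a channel slice of an NHWC tensor differs from the input in dimension 3 only:

    #include <cstdint>

    bool DiffersInOnly1D(const int64_t shape[4], const int64_t begin[4],
                         const int64_t size[4], int* slice_dim) {
      for (int d = 0; d < 4; ++d) {
        // Check that every dimension other than d is taken in full.
        bool only_d_differs = true;
        for (int k = 0; k < 4; ++k) {
          if (k != d && (begin[k] != 0 || size[k] != shape[k])) {
            only_d_differs = false;
            break;
          }
        }
        if (only_d_differs) {
          *slice_dim = d;
          return true;
        }
      }
      return false;
    }

    // Example: begin = {0, 0, 0, 16}, size = {1, 224, 224, 16} against an
    // NHWC input of shape {1, 224, 224, 32} sets slice_dim to 3.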
+
template <int NDIM>
- void HandleCase(OpKernelContext* context, const gtl::ArraySlice<int64>& begin,
+ void HandleCase(OpKernelContext* context,
+ const gtl::ArraySlice<int64>& begin,
const gtl::ArraySlice<int64>& size, Tensor* result) {
+ int slice_dim = -1;
+ TensorShape in_shape = context->input(0).shape();
+ // Special case for a 4-D tensor slice whose shape differs from the input
+ // tensor in only one of the four dimensions. This case arises when slicing
+ // a 4-D tensor in NHWC or NCHW format over the channel dimension.
+ if (NDIM == 4 &&
+ DoesSliceShapeDifferInOnly1D(in_shape, begin, size, &slice_dim)) {
+ size_t in_strides[4] = { (size_t) in_shape.dim_size(1) *
+ in_shape.dim_size(2) *
+ in_shape.dim_size(3),
+ (size_t) in_shape.dim_size(2) *
+ in_shape.dim_size(3),
+ (size_t) in_shape.dim_size(3),
+ (size_t) 1
+ };
+
+ size_t out_strides[4] = { (size_t) size[1] * size[2] * size[3],
+ (size_t) size[2] * size[3],
+ (size_t) size[3],
+ (size_t) 1 };
+
+ T *in_buf = const_cast<T*>(const_cast<const T*>(
+ context->input(0).flat<T>().data()));
+ T *op_buf = result->flat<T>().data();
+
+ if (slice_dim == 1) {
+ /* data_format = NCHW */
+
+ #pragma omp parallel for
+ for (size_t d0 = begin[0]; d0 < begin[0] + size[0]; d0++) {
+ T *ip = in_buf + (d0 * in_strides[0]);
+ T *op = op_buf + ((d0 - begin[0]) * out_strides[0]);
+ #pragma omp parallel for
+ for (size_t d1 = begin[1]; d1 < begin[1] + size[1]; d1++) {
+ T *ip1 = ip + (d1 * in_strides[1]);
+ T *op1 = op + ((d1 - begin[1]) * out_strides[1]);
+ // For NCHW, H and W will be contiguous. So we can copy
+ // both with one memcpy.
+ memcpy(static_cast<void*>(op1), static_cast<void*>(ip1),
+ sizeof(T) * in_strides[1]);
+ }
+ }
+ return;
+ } else if (slice_dim == 3) {
+ /* data_format = NHWC */
+
+ #pragma omp parallel for
+ for (size_t d0 = begin[0]; d0 < begin[0] + size[0]; d0++) {
+ T *ip = in_buf + (d0 * in_strides[0]);
+ T *op = op_buf + ((d0 - begin[0]) * out_strides[0]);
+ #pragma omp parallel for
+ for (size_t d1 = begin[1]; d1 < begin[1] + size[1]; d1++) {
+ T *ip1 = ip + (d1 * in_strides[1]);
+ T *op1 = op + ((d1 - begin[1]) * out_strides[1]);
+ #pragma omp parallel for
+ for (size_t d2 = begin[2]; d2 < begin[2] + size[2]; d2++) {
+ T *ip2 = ip1 + (d2 * in_strides[2]);
+ T *ip3 = ip2 + begin[3];
+ T *op2 = op1 + ((d2 - begin[2]) * out_strides[2]);
+ T *op3 = op2;
+ memcpy(static_cast<void*>(op3), static_cast<void*>(ip3),
+ sizeof(T) * size[3]);
+ }
+ }
+ }
+ return;
+ }
+ // slice_dim is neither 1 nor 3; fall through to the Eigen implementation.
+ }
+
Eigen::DSizes<Eigen::DenseIndex, NDIM> indices;
Eigen::DSizes<Eigen::DenseIndex, NDIM> sizes;
for (int i = 0; i < NDIM; ++i) {
@@ -204,6 +415,7 @@ class SliceOp : public OpKernel {
context->input(0).tensor<T, NDIM>(), indices, sizes);
}
};
+#endif
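The strided copies above rest on row-major layout arithmetic: element (i0, i1, i2, i3) of a {d0, d1, d2, d3} tensor sits at linear offset i0*d1*d2*d3 + i1*d2*d3 + i2*d3 + i3. When only one dimension is sliced, everything after it stays contiguous, which is why the NCHW branch copies H*W elements (in_strides[1]) per memcpy while the NHWC branch copies size[3] channels at a time. A sketch of the stride computation (illustrative names, not part of the patch):

    #include <cstddef>
    #include <cstdint>

    // Row-major strides for a 4-D shape {d0, d1, d2, d3}: advancing one step
    // in dimension i skips strides[i] elements.
    void RowMajorStrides4D(const int64_t dims[4], size_t strides[4]) {
      strides[3] = 1;
      strides[2] = static_cast<size_t>(dims[3]);
      strides[1] = static_cast<size_t>(dims[2] * dims[3]);
      strides[0] = static_cast<size_t>(dims[1] * dims[2] * dims[3]);
    }

Note also that the nested '#pragma omp parallel for' directives in the copy loops only create nested parallelism if it is enabled (e.g. via omp_set_nested); by default, the inner pragmas execute sequentially within each outer thread.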
// Forward declarations of the functor specializations declared in the
// sharded source files.
@@ -233,6 +445,7 @@ DECLARE_FOR_N(bfloat16);
#undef DECLARE_CPU_SPEC
} // namespace functor
+#ifndef INTEL_MKL
#define REGISTER_SLICE(type) \
REGISTER_KERNEL_BUILDER(Name("Slice") \
.Device(DEVICE_CPU) \
@@ -244,8 +457,21 @@ DECLARE_FOR_N(bfloat16);
TF_CALL_POD_STRING_TYPES(REGISTER_SLICE);
TF_CALL_QUANTIZED_TYPES(REGISTER_SLICE);
REGISTER_SLICE(bfloat16);
+#undef REGISTER_SLICE
+#else
+#define REGISTER_SLICE(type) \
+ REGISTER_KERNEL_BUILDER(Name("Slice") \
+ .Device(DEVICE_CPU) \
+ .TypeConstraint<type>("T") \
+ .HostMemory("begin") \
+ .HostMemory("size"), \
+ MklSliceOp<CPUDevice, type>)
+TF_CALL_POD_STRING_TYPES(REGISTER_SLICE);
+TF_CALL_QUANTIZED_TYPES(REGISTER_SLICE);
+REGISTER_SLICE(bfloat16);
#undef REGISTER_SLICE
+#endif // INTEL_MKL
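For reference, applying the MKL-side macro as REGISTER_SLICE(float); expands to a registration of this shape (shown purely for illustration):

    REGISTER_KERNEL_BUILDER(Name("Slice")
                                .Device(DEVICE_CPU)
                                .TypeConstraint<float>("T")
                                .HostMemory("begin")
                                .HostMemory("size"),
                            MklSliceOp<CPUDevice, float>);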
#if GOOGLE_CUDA
// Forward declarations of the functor specializations for GPU.