diff options
author | A. Unique TensorFlower <gardener@tensorflow.org> | 2016-09-09 16:00:09 -0800 |
---|---|---|
committer | TensorFlower Gardener <gardener@tensorflow.org> | 2016-09-09 17:04:53 -0700 |
commit | f51e1964b57ecce5f3b38d4156cf61e80e5c4181 (patch) | |
tree | bbea34065e0681d1a2145dc7fff495e3e99631b0 | |
parent | af02c57adec03a0c037420d29bfac2f8e047db0b (diff) |
Move shared validation for StridedSliceOp to a separate header, in preparation
for calling it from the shape fn.
Change: 132732667
-rw-r--r-- | tensorflow/core/BUILD | 1 | ||||
-rw-r--r-- | tensorflow/core/kernels/strided_slice_op.cc | 313 | ||||
-rw-r--r-- | tensorflow/core/util/strided_slice_op.cc | 296 | ||||
-rw-r--r-- | tensorflow/core/util/strided_slice_op.h | 42 |
4 files changed, 361 insertions, 291 deletions
diff --git a/tensorflow/core/BUILD b/tensorflow/core/BUILD index a3a75cee32..33d01682ec 100644 --- a/tensorflow/core/BUILD +++ b/tensorflow/core/BUILD @@ -287,6 +287,7 @@ tf_cuda_library( "util/sparse/sparse_tensor.h", "util/stat_summarizer.h", "util/stream_executor_util.h", + "util/strided_slice_op.h", "util/tensor_format.h", "util/tensor_slice_reader.h", "util/tensor_slice_reader_cache.h", diff --git a/tensorflow/core/kernels/strided_slice_op.cc b/tensorflow/core/kernels/strided_slice_op.cc index b0dae752be..2accd7f3e5 100644 --- a/tensorflow/core/kernels/strided_slice_op.cc +++ b/tensorflow/core/kernels/strided_slice_op.cc @@ -35,286 +35,10 @@ limitations under the License. #include "tensorflow/core/lib/core/status.h" #include "tensorflow/core/lib/gtl/array_slice.h" #include "tensorflow/core/platform/mem.h" +#include "tensorflow/core/util/strided_slice_op.h" namespace tensorflow { -namespace { - -/// Constants -constexpr int32 kShrinkAxis = -1, kNewAxis = -2; - -// Sparse slicing specification -// if one does foo[3:5, ..., -3], this will have 3 length tensors -struct StridedSliceSparseSpec { - int64 dims; - int32 num_add_axis_after_ellipsis; - const Tensor& begin_tensor; - const Tensor& end_tensor; - const Tensor& strides_tensor; - const int32 begin_mask, end_mask; - int32 ellipsis_mask; - const int32 new_axis_mask, shrink_axis_mask; -}; - -// Dense slicing specification -// all ellipses and newaxis' are expanded out. So if -// foo[3:5, ..., -3] where foo is 10 dimensional, -// each inlinedVector will have 10 entries whereas the -// sparse had 3 length tensors. -struct StridedSliceDenseSpec { - const int64 dims; - int32 begin_mask; - int32 end_mask; - gtl::InlinedVector<int64, 4>& begin; - gtl::InlinedVector<int64, 4>& end; - gtl::InlinedVector<int64, 4>& strides; - // This vector helps construct the final shape of the slice. - // The final tensor is reduced in rank whenever a single index e.g. foo[3] - // is called for. The final tensor increases in rank with tf.newaxis - // entries. If an index in this array is positive, the size of the dimension - // is obtained from canonical end-begin. Otherwise, if it is a kNewAxis, - // it will be 1. A shrunk dimension is skipped. - gtl::InlinedVector<int32, 4> final_shape_gather_indices; - // The dense indexed shrink mask is which processing dimensions - // should be shrunk. For example, if foo.shape = (10,10,10,10) - // foo[3, ..., 5] has sparse_shrink_axis_mask of 0x5 and - // dense_shrink_axis_mask of 0x9, yielding a final shape (10,10). - int32 shrink_axis_mask; -}; - -} // namespace - -typedef Eigen::ThreadPoolDevice CPUDevice; -typedef Eigen::GpuDevice GPUDevice; - -template <class T> -static void BuildDenseSpec(const StridedSliceSparseSpec& sparse, - StridedSliceDenseSpec* dense) { - // Build expanded begin, end, strides, begin_mask, end_mask - // to remove any ellipsis - dense->begin.resize(dense->dims); - dense->end.resize(dense->dims); - dense->strides.resize(dense->dims); - // What indices to get the final shape from. - dense->begin_mask = 0; - dense->end_mask = 0; - dense->shrink_axis_mask = 0; - { - int full_index = 0; - - const auto& begin_flat = sparse.begin_tensor.flat<T>(); - const auto& end_flat = sparse.end_tensor.flat<T>(); - const auto& strides_flat = sparse.strides_tensor.flat<T>(); - - for (int i = 0; i < sparse.dims; i++) { - if ((1 << i) & sparse.ellipsis_mask) { - // Expand the ellipsis into the appropriate indices - // NOTE: this only works because we guaranteed one ellipsis - int32 next_index = std::min(dense->dims - (sparse.dims - i) + 1 + - sparse.num_add_axis_after_ellipsis, - dense->dims); - for (; full_index < next_index; full_index++) { - // new_axis' aren't real axis so you have to skip - dense->begin[full_index] = dense->end[full_index] = 0; - dense->strides[full_index] = 1; - dense->begin_mask |= (1 << full_index); - dense->end_mask |= (1 << full_index); - dense->final_shape_gather_indices.push_back(full_index); - } - } else if ((1 << i) & sparse.new_axis_mask) { - dense->final_shape_gather_indices.push_back(kNewAxis); - } else { - // Gather slicing spec into appropriate index - dense->begin[full_index] = internal::SubtleMustCopy<T>(begin_flat(i)); - dense->end[full_index] = internal::SubtleMustCopy<T>(end_flat(i)); - dense->strides[full_index] = - internal::SubtleMustCopy<T>(strides_flat(i)); - if (sparse.begin_mask & (1 << i)) { - dense->begin_mask |= (1 << full_index); - } - if (sparse.end_mask & (1 << i)) { - dense->end_mask |= (1 << full_index); - } - // If shrink, record where to get the dimensionality from (i.e. - // new_axis creates a fake 1 size dimension. Also remember shrink - // axis (now in dense form) so we can ignore dense->end below. - if (sparse.shrink_axis_mask & (1 << i)) { - dense->final_shape_gather_indices.push_back(kShrinkAxis); - dense->shrink_axis_mask |= (1 << full_index); - } else { - dense->final_shape_gather_indices.push_back(full_index); - } - full_index++; - } - } - } -} - -// Shared code that is not dependent on the type of T. We do this to reduce -// code size by not duplicating all this for all T (float, double, int32, etc.) -static void SharedValidation( - OpKernelContext* context, const TensorShape& input_shape, - int32 begin_mask_spec, int32 end_mask_spec, const int32 ellipsis_mask, - int32 new_axis_mask, int32 shrink_axis_mask, TensorShape* processing_shape, - TensorShape* final_shape, bool* is_identity, bool* is_simple_slice, - bool* slice_dim0, gtl::InlinedVector<int64, 4>* begin, - gtl::InlinedVector<int64, 4>* end, gtl::InlinedVector<int64, 4>* strides) { - const Tensor& begin_tensor = context->input(1); - const Tensor& end_tensor = context->input(2); - const Tensor& strides_tensor = context->input(3); - OP_REQUIRES( - context, TensorShapeUtils::IsVector(begin_tensor.shape()) && - TensorShapeUtils::IsVector(end_tensor.shape()) && - TensorShapeUtils::IsVector(strides_tensor.shape()) && - strides_tensor.dims() == 1 && - strides_tensor.dims() == begin_tensor.dims() && - strides_tensor.dims() == end_tensor.dims() && - begin_tensor.dim_size(0) == end_tensor.dim_size(0) && - begin_tensor.dim_size(0) == strides_tensor.dim_size(0) && - begin_tensor.dim_size(0) < 32, // using 32 bit masks - errors::InvalidArgument( - "Expected begin, end, and strides to be 1D equal size tensors, ", - "but got shapes ", begin_tensor.shape().DebugString(), ", ", - end_tensor.shape().DebugString(), ", and ", - strides_tensor.shape().DebugString(), " instead.")); - // Use bit compares to ensure ellipsis_mask is 0 or a power of 2 - // i.e. there exists only no more than one ellipsis - OP_REQUIRES(context, - !ellipsis_mask || (ellipsis_mask & (ellipsis_mask - 1)) == 0, - errors::InvalidArgument("Multiple ellipsis' in slice " - "spec not allowed")); - - // Step 1: Account for ellipsis and new axis - // - // Check for ellipses and count how many non-newaxis' there are after - // TODO(aselle): Convert this to do a fast log2 followed by iteration - // counting ones in next guys - bool ellipsis_seen = false; - - StridedSliceSparseSpec sparse_spec = {begin_tensor.NumElements(), - 0, - begin_tensor, - end_tensor, - strides_tensor, - begin_mask_spec, - end_mask_spec, - ellipsis_mask, - new_axis_mask, - shrink_axis_mask}; - - for (int32 i = 0; i < sparse_spec.dims; i++) { - if (ellipsis_seen && ((1 << i) & new_axis_mask) != 0) { - sparse_spec.num_add_axis_after_ellipsis++; - } - if ((1 << i) & ellipsis_mask) { - ellipsis_seen = true; - } - } - // If no ellipsis insert one at the end - if (!ellipsis_seen) { - sparse_spec.ellipsis_mask |= (1 << sparse_spec.dims); - sparse_spec.dims++; // this effects loop iteration below - } - - // Step 2: Make a sparse spec into a full index spec - // - // The sparse spec does not corresopnds to the number of dimensions - // Make a dense spec that corresponds to thte number of dimensions - // - // For example suppose foo[...,3:] on foo.shape=(2,2,3) then - // we need to produce the missing begin_mask for the the first two - // dimensions i.e. from begin_mask_spec=0, end_mask_spec=2 - // we achieve begin_mask=6, end_mask=7 - StridedSliceDenseSpec dense_spec = { - input_shape.dims(), 0, 0, *begin, *end, *strides}; - - if (begin_tensor.dtype() == DT_INT32) { - BuildDenseSpec<int32>(sparse_spec, &dense_spec); - } else if (begin_tensor.dtype() == DT_INT64) { - BuildDenseSpec<int64>(sparse_spec, &dense_spec); - } else { - LOG(FATAL) << "begin must be either int32 or int64"; - } - - // Step 3: Make implicit ranges (non-zero begin_masks and end_masks) explicit - // and bounds check! - *is_identity = true; - *slice_dim0 = true; - *is_simple_slice = true; - for (int i = 0; i < dense_spec.dims; ++i) { - int64& begin_i = (*begin)[i]; - int64& end_i = (*end)[i]; - int64& stride_i = (*strides)[i]; - int64 dim_i = input_shape.dim_size(i); - OP_REQUIRES(context, stride_i != 0, - errors::InvalidArgument("strides[", i, "] must be non-zero")); - - int64 masks[] = {dense_spec.begin_mask & (1 << i), - dense_spec.end_mask & (1 << i)}; - int64 valid_range[] = {stride_i > 0 ? 0 : -1, - stride_i > 0 ? dim_i : dim_i - 1}; - - auto canonical = [stride_i, i, dim_i, masks, valid_range](int64 x, int c) { - if (masks[c]) { - return stride_i > 0 ? valid_range[c] : valid_range[(c + 1) & 1]; - } else { - int64 x_fwd = x < 0 ? dim_i + x : x; // make negative indices positive - return x_fwd < valid_range[0] - ? valid_range[0] - : x_fwd > valid_range[1] ? valid_range[1] : x_fwd; - } - }; - if (dense_spec.shrink_axis_mask & (1 << i)) { - // If we are shrinking, the end index is now possibly incorrect. In - // particular foo[-1] produces sparse_begin = -1, sparse_end = 0. - // and canonical puts these to n-1 and 0, which implies a degenerate - // interval. Fortunately, it is now safe to re-create end as begin+1. - int64 x_fwd = begin_i < 0 ? dim_i + begin_i : begin_i; - begin_i = x_fwd; - end_i = begin_i + 1; - OP_REQUIRES(context, stride_i > 0, - errors::InvalidArgument("only stride 1 allowed on" - " non-range indexing.")); - OP_REQUIRES( - context, x_fwd >= 0 && x_fwd < dim_i, - errors::InvalidArgument("slice index ", begin_i, " of dimension ", i, - " out of bounds.")); - } else { - begin_i = canonical(begin_i, 0); - end_i = canonical(end_i, 1); - } - // Update optimization values - (*is_simple_slice) &= stride_i == 1; - bool take_all_in_dimension = - stride_i == 1 && begin_i == 0 && end_i == input_shape.dim_size(i); - (*is_identity) &= take_all_in_dimension; - (*slice_dim0) &= (i == 0 && stride_i == 1) || take_all_in_dimension; - - // Compute the processing shape (the intermediate Eigen will produce) - int64 interval_length = end_i - begin_i; - int64 size_i; - // Hold zero if the interval is degenerate, otherwise account for remainder - if (interval_length == 0 || ((interval_length < 0) != (stride_i < 0))) - size_i = 0; - else - size_i = interval_length / stride_i + - (interval_length % stride_i != 0 ? 1 : 0); - processing_shape->AddDim(size_i); - } - - // Step 4: Compute the final shape - // - // new_axis will increase dimension by 1 (with a one-size dimension) - // slices like foo[3,...] will reduce dimension by 1. - // This cannot be done earlier, because it depends on Step 3. - for (auto gather_index : dense_spec.final_shape_gather_indices) { - if (gather_index >= 0) - final_shape->AddDim(processing_shape->dim_size(gather_index)); - else if (gather_index == kNewAxis) - final_shape->AddDim(1); - } -} - template <typename Device, typename T> class StridedSliceOp : public OpKernel { public: @@ -336,11 +60,13 @@ class StridedSliceOp : public OpKernel { gtl::InlinedVector<int64, 4> end; gtl::InlinedVector<int64, 4> strides; - SharedValidation(context, context->input(0).shape(), begin_mask, end_mask, - ellipsis_mask, new_axis_mask, shrink_axis_mask, - &processing_shape, &final_shape, &is_identity, - &is_simple_slice, &slice_dim0, &begin, &end, &strides); - if (!context->status().ok()) return; + OP_REQUIRES_OK(context, + ValidateStridedSliceOp( + context->input(1), context->input(2), context->input(3), + context->input(0).shape(), begin_mask, end_mask, + ellipsis_mask, new_axis_mask, shrink_axis_mask, + &processing_shape, &final_shape, &is_identity, + &is_simple_slice, &slice_dim0, &begin, &end, &strides)); const Tensor& input = context->input(0); @@ -460,10 +186,13 @@ class StridedSliceGradOp : public OpKernel { LOG(FATAL) << "shape must have type int32 or int64."; } - SharedValidation(context, input_shape, begin_mask, end_mask, ellipsis_mask, - new_axis_mask, shrink_axis_mask, &processing_shape, - &final_shape, &is_identity, &is_simple_slice, &slice_dim0, - &begin, &end, &strides); + OP_REQUIRES_OK( + context, + ValidateStridedSliceOp( + context->input(1), context->input(2), context->input(3), + input_shape, begin_mask, end_mask, ellipsis_mask, new_axis_mask, + shrink_axis_mask, &processing_shape, &final_shape, &is_identity, + &is_simple_slice, &slice_dim0, &begin, &end, &strides)); // Check to make sure dy is consistent with the original slice TensorShape dy_shape = context->input(4).shape(); @@ -527,11 +256,13 @@ class StridedSliceAssignOp : public OpKernel { context->forward_ref_input_to_ref_output(0, 0); Tensor old_lhs = context->mutable_input(0, true); - SharedValidation(context, old_lhs.shape(), begin_mask, end_mask, - ellipsis_mask, new_axis_mask, shrink_axis_mask, - &processing_shape, &final_shape, &is_identity, - &is_simple_slice, &slice_dim0, &begin, &end, &strides); - if (!context->status().ok()) return; + OP_REQUIRES_OK( + context, + ValidateStridedSliceOp( + context->input(1), context->input(2), context->input(3), + old_lhs.shape(), begin_mask, end_mask, ellipsis_mask, new_axis_mask, + shrink_axis_mask, &processing_shape, &final_shape, &is_identity, + &is_simple_slice, &slice_dim0, &begin, &end, &strides)); if (processing_shape.num_elements()) { const Tensor& input = context->input(4); diff --git a/tensorflow/core/util/strided_slice_op.cc b/tensorflow/core/util/strided_slice_op.cc new file mode 100644 index 0000000000..27f53939d2 --- /dev/null +++ b/tensorflow/core/util/strided_slice_op.cc @@ -0,0 +1,296 @@ +/* Copyright 2015 Google Inc. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + +#include "tensorflow/core/util/strided_slice_op.h" + +#include "tensorflow/core/kernels/bounds_check.h" +#include "tensorflow/core/lib/core/status.h" + +namespace tensorflow { + +namespace { + +/// Constants +constexpr int32 kShrinkAxis = -1, kNewAxis = -2; + +// Sparse slicing specification +// if one does foo[3:5, ..., -3], this will have 3 length tensors +struct StridedSliceSparseSpec { + int64 dims; + int32 num_add_axis_after_ellipsis; + const Tensor& begin_tensor; + const Tensor& end_tensor; + const Tensor& strides_tensor; + const int32 begin_mask, end_mask; + int32 ellipsis_mask; + const int32 new_axis_mask, shrink_axis_mask; +}; + +// Dense slicing specification +// all ellipses and newaxis' are expanded out. So if +// foo[3:5, ..., -3] where foo is 10 dimensional, +// each inlinedVector will have 10 entries whereas the +// sparse had 3 length tensors. +struct StridedSliceDenseSpec { + const int64 dims; + int32 begin_mask; + int32 end_mask; + gtl::InlinedVector<int64, 4>& begin; + gtl::InlinedVector<int64, 4>& end; + gtl::InlinedVector<int64, 4>& strides; + // This vector helps construct the final shape of the slice. + // The final tensor is reduced in rank whenever a single index e.g. foo[3] + // is called for. The final tensor increases in rank with tf.newaxis + // entries. If an index in this array is positive, the size of the dimension + // is obtained from canonical end-begin. Otherwise, if it is a kNewAxis, + // it will be 1. A shrunk dimension is skipped. + gtl::InlinedVector<int32, 4> final_shape_gather_indices; + // The dense indexed shrink mask is which processing dimensions + // should be shrunk. For example, if foo.shape = (10,10,10,10) + // foo[3, ..., 5] has sparse_shrink_axis_mask of 0x5 and + // dense_shrink_axis_mask of 0x9, yielding a final shape (10,10). + int32 shrink_axis_mask; +}; + +} // namespace + +template <class T> +static void BuildDenseSpec(const StridedSliceSparseSpec& sparse, + StridedSliceDenseSpec* dense) { + // Build expanded begin, end, strides, begin_mask, end_mask + // to remove any ellipsis + dense->begin.resize(dense->dims); + dense->end.resize(dense->dims); + dense->strides.resize(dense->dims); + // What indices to get the final shape from. + dense->begin_mask = 0; + dense->end_mask = 0; + dense->shrink_axis_mask = 0; + { + int full_index = 0; + + const auto& begin_flat = sparse.begin_tensor.flat<T>(); + const auto& end_flat = sparse.end_tensor.flat<T>(); + const auto& strides_flat = sparse.strides_tensor.flat<T>(); + + for (int i = 0; i < sparse.dims; i++) { + if ((1 << i) & sparse.ellipsis_mask) { + // Expand the ellipsis into the appropriate indices + // NOTE: this only works because we guaranteed one ellipsis + int32 next_index = std::min(dense->dims - (sparse.dims - i) + 1 + + sparse.num_add_axis_after_ellipsis, + dense->dims); + for (; full_index < next_index; full_index++) { + // new_axis' aren't real axis so you have to skip + dense->begin[full_index] = dense->end[full_index] = 0; + dense->strides[full_index] = 1; + dense->begin_mask |= (1 << full_index); + dense->end_mask |= (1 << full_index); + dense->final_shape_gather_indices.push_back(full_index); + } + } else if ((1 << i) & sparse.new_axis_mask) { + dense->final_shape_gather_indices.push_back(kNewAxis); + } else { + // Gather slicing spec into appropriate index + dense->begin[full_index] = internal::SubtleMustCopy<T>(begin_flat(i)); + dense->end[full_index] = internal::SubtleMustCopy<T>(end_flat(i)); + dense->strides[full_index] = + internal::SubtleMustCopy<T>(strides_flat(i)); + if (sparse.begin_mask & (1 << i)) { + dense->begin_mask |= (1 << full_index); + } + if (sparse.end_mask & (1 << i)) { + dense->end_mask |= (1 << full_index); + } + // If shrink, record where to get the dimensionality from (i.e. + // new_axis creates a fake 1 size dimension. Also remember shrink + // axis (now in dense form) so we can ignore dense->end below. + if (sparse.shrink_axis_mask & (1 << i)) { + dense->final_shape_gather_indices.push_back(kShrinkAxis); + dense->shrink_axis_mask |= (1 << full_index); + } else { + dense->final_shape_gather_indices.push_back(full_index); + } + full_index++; + } + } + } +} + +Status ValidateStridedSliceOp( + const Tensor& begin_tensor, const Tensor& end_tensor, + const Tensor& strides_tensor, const TensorShape& input_shape, + int32 begin_mask_spec, int32 end_mask_spec, const int32 ellipsis_mask, + int32 new_axis_mask, int32 shrink_axis_mask, TensorShape* processing_shape, + TensorShape* final_shape, bool* is_identity, bool* is_simple_slice, + bool* slice_dim0, gtl::InlinedVector<int64, 4>* begin, + gtl::InlinedVector<int64, 4>* end, gtl::InlinedVector<int64, 4>* strides) { + if (!(TensorShapeUtils::IsVector(begin_tensor.shape()) && + TensorShapeUtils::IsVector(end_tensor.shape()) && + TensorShapeUtils::IsVector(strides_tensor.shape()) && + strides_tensor.dims() == 1 && + strides_tensor.dims() == begin_tensor.dims() && + strides_tensor.dims() == end_tensor.dims() && + begin_tensor.dim_size(0) == end_tensor.dim_size(0) && + begin_tensor.dim_size(0) == strides_tensor.dim_size(0) && + begin_tensor.dim_size(0) < 32 /* using 32 bit masks */)) { + return errors::InvalidArgument( + "Expected begin, end, and strides to be 1D equal size tensors, ", + "but got shapes ", begin_tensor.shape().DebugString(), ", ", + end_tensor.shape().DebugString(), ", and ", + strides_tensor.shape().DebugString(), " instead."); + } + // Use bit compares to ensure ellipsis_mask is 0 or a power of 2 + // i.e. there exists only no more than one ellipsis + if (ellipsis_mask && ((ellipsis_mask & (ellipsis_mask - 1)) != 0)) { + return errors::InvalidArgument( + "Multiple ellipsis' in slice spec not allowed"); + } + + // Step 1: Account for ellipsis and new axis + // + // Check for ellipses and count how many non-newaxis' there are after + // TODO(aselle): Convert this to do a fast log2 followed by iteration + // counting ones in next guys + bool ellipsis_seen = false; + + StridedSliceSparseSpec sparse_spec = {begin_tensor.NumElements(), + 0, + begin_tensor, + end_tensor, + strides_tensor, + begin_mask_spec, + end_mask_spec, + ellipsis_mask, + new_axis_mask, + shrink_axis_mask}; + + for (int32 i = 0; i < sparse_spec.dims; i++) { + if (ellipsis_seen && ((1 << i) & new_axis_mask) != 0) { + sparse_spec.num_add_axis_after_ellipsis++; + } + if ((1 << i) & ellipsis_mask) { + ellipsis_seen = true; + } + } + // If no ellipsis insert one at the end + if (!ellipsis_seen) { + sparse_spec.ellipsis_mask |= (1 << sparse_spec.dims); + sparse_spec.dims++; // this effects loop iteration below + } + + // Step 2: Make a sparse spec into a full index spec + // + // The sparse spec does not corresopnds to the number of dimensions + // Make a dense spec that corresponds to thte number of dimensions + // + // For example suppose foo[...,3:] on foo.shape=(2,2,3) then + // we need to produce the missing begin_mask for the the first two + // dimensions i.e. from begin_mask_spec=0, end_mask_spec=2 + // we achieve begin_mask=6, end_mask=7 + StridedSliceDenseSpec dense_spec = { + input_shape.dims(), 0, 0, *begin, *end, *strides}; + + if (begin_tensor.dtype() == DT_INT32) { + BuildDenseSpec<int32>(sparse_spec, &dense_spec); + } else if (begin_tensor.dtype() == DT_INT64) { + BuildDenseSpec<int64>(sparse_spec, &dense_spec); + } else { + LOG(FATAL) << "begin must be either int32 or int64"; + } + + // Step 3: Make implicit ranges (non-zero begin_masks and end_masks) explicit + // and bounds check! + *is_identity = true; + *slice_dim0 = true; + *is_simple_slice = true; + for (int i = 0; i < dense_spec.dims; ++i) { + int64& begin_i = (*begin)[i]; + int64& end_i = (*end)[i]; + int64& stride_i = (*strides)[i]; + int64 dim_i = input_shape.dim_size(i); + if (stride_i == 0) { + return errors::InvalidArgument("strides[", i, "] must be non-zero"); + } + + int64 masks[] = {dense_spec.begin_mask & (1 << i), + dense_spec.end_mask & (1 << i)}; + int64 valid_range[] = {stride_i > 0 ? 0 : -1, + stride_i > 0 ? dim_i : dim_i - 1}; + + auto canonical = [stride_i, i, dim_i, masks, valid_range](int64 x, int c) { + if (masks[c]) { + return stride_i > 0 ? valid_range[c] : valid_range[(c + 1) & 1]; + } else { + int64 x_fwd = x < 0 ? dim_i + x : x; // make negative indices positive + return x_fwd < valid_range[0] + ? valid_range[0] + : x_fwd > valid_range[1] ? valid_range[1] : x_fwd; + } + }; + if (dense_spec.shrink_axis_mask & (1 << i)) { + // If we are shrinking, the end index is now possibly incorrect. In + // particular foo[-1] produces sparse_begin = -1, sparse_end = 0. + // and canonical puts these to n-1 and 0, which implies a degenerate + // interval. Fortunately, it is now safe to re-create end as begin+1. + int64 x_fwd = begin_i < 0 ? dim_i + begin_i : begin_i; + begin_i = x_fwd; + end_i = begin_i + 1; + if (stride_i <= 0) { + return errors::InvalidArgument( + "only stride 1 allowed on non-range indexing."); + } + if (x_fwd < 0 || x_fwd >= dim_i) { + return errors::InvalidArgument("slice index ", begin_i, + " of dimension ", i, " out of bounds."); + } + } else { + begin_i = canonical(begin_i, 0); + end_i = canonical(end_i, 1); + } + // Update optimization values + (*is_simple_slice) &= stride_i == 1; + bool take_all_in_dimension = + stride_i == 1 && begin_i == 0 && end_i == input_shape.dim_size(i); + (*is_identity) &= take_all_in_dimension; + (*slice_dim0) &= (i == 0 && stride_i == 1) || take_all_in_dimension; + + // Compute the processing shape (the intermediate Eigen will produce) + int64 interval_length = end_i - begin_i; + int64 size_i; + // Hold zero if the interval is degenerate, otherwise account for remainder + if (interval_length == 0 || ((interval_length < 0) != (stride_i < 0))) + size_i = 0; + else + size_i = interval_length / stride_i + + (interval_length % stride_i != 0 ? 1 : 0); + processing_shape->AddDim(size_i); + } + + // Step 4: Compute the final shape + // + // new_axis will increase dimension by 1 (with a one-size dimension) + // slices like foo[3,...] will reduce dimension by 1. + // This cannot be done earlier, because it depends on Step 3. + for (auto gather_index : dense_spec.final_shape_gather_indices) { + if (gather_index >= 0) + final_shape->AddDim(processing_shape->dim_size(gather_index)); + else if (gather_index == kNewAxis) + final_shape->AddDim(1); + } + return Status::OK(); +} + +} // namespace tensorflow diff --git a/tensorflow/core/util/strided_slice_op.h b/tensorflow/core/util/strided_slice_op.h new file mode 100644 index 0000000000..84d308a3c6 --- /dev/null +++ b/tensorflow/core/util/strided_slice_op.h @@ -0,0 +1,42 @@ +/* Copyright 2016 Google Inc. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ +#ifndef THIRD_PARTY_TENSORFLOW_CORE_UTIL_STRIDED_SLICE_OP_H_ +#define THIRD_PARTY_TENSORFLOW_CORE_UTIL_STRIDED_SLICE_OP_H_ + +#include "tensorflow/core/framework/tensor.h" +#include "tensorflow/core/framework/tensor_shape.h" +#include "tensorflow/core/framework/types.h" +#include "tensorflow/core/lib/core/status.h" +#include "tensorflow/core/lib/gtl/inlined_vector.h" + +namespace tensorflow { + +// Runs validation on the strided slice op parameters. +// +// Is a separate translation unit from the kernel so that: +// 1. The op's shape function can use it. +// 2. The code size is reduced vs templating this on the kernel's type. +Status ValidateStridedSliceOp( + const Tensor& begin_tensor, const Tensor& end_tensor, + const Tensor& strides_tensor, const TensorShape& input_shape, + int32 begin_mask_spec, int32 end_mask_spec, const int32 ellipsis_mask, + int32 new_axis_mask, int32 shrink_axis_mask, TensorShape* processing_shape, + TensorShape* final_shape, bool* is_identity, bool* is_simple_slice, + bool* slice_dim0, gtl::InlinedVector<int64, 4>* begin, + gtl::InlinedVector<int64, 4>* end, gtl::InlinedVector<int64, 4>* strides); + +} // namespace tensorflow + +#endif // THIRD_PARTY_TENSORFLOW_CORE_UTIL_STRIDED_SLICE_OP_H_ |