diff options
Diffstat (limited to 'tensorflow/compiler/tf2xla/kernels/strided_slice_op.cc')
-rw-r--r-- | tensorflow/compiler/tf2xla/kernels/strided_slice_op.cc | 44 |
1 files changed, 9 insertions, 35 deletions
diff --git a/tensorflow/compiler/tf2xla/kernels/strided_slice_op.cc b/tensorflow/compiler/tf2xla/kernels/strided_slice_op.cc index 8037e90791..6af4bd0496 100644 --- a/tensorflow/compiler/tf2xla/kernels/strided_slice_op.cc +++ b/tensorflow/compiler/tf2xla/kernels/strided_slice_op.cc @@ -72,55 +72,29 @@ class StridedSliceOp : public XlaOpKernel { &dummy, &dummy, &dummy, &begin, &end, &strides)); gtl::InlinedVector<int64, 4> dimensions_to_reverse; - gtl::InlinedVector<int64, 4> slice_begin, slice_end; - bool simple_strides = true; + gtl::InlinedVector<int64, 4> slice_begin, slice_end, slice_strides; + for (int i = 0; i < begin.size(); ++i) { - simple_strides &= (std::abs(strides[i]) == 1); if (strides[i] > 0) { slice_begin.push_back(begin[i]); slice_end.push_back(end[i]); + slice_strides.push_back(strides[i]); } else { // Negative stride: swap begin and end, add 1 because the interval // is semi-open, and mark the dimension to be reversed. - slice_begin.push_back(end[i] + 1); - slice_end.push_back(begin[i] + 1); + slice_begin.push_back(input_shape.dim_size(i) - begin[i] - 1); + slice_end.push_back(input_shape.dim_size(i) - end[i] - 1); + slice_strides.push_back(-strides[i]); dimensions_to_reverse.push_back(i); } } - xla::ComputationDataHandle slice = - ctx->builder()->Slice(ctx->Input(0), slice_begin, slice_end); + + xla::ComputationDataHandle slice = ctx->Input(0); if (!dimensions_to_reverse.empty()) { slice = ctx->builder()->Rev(slice, dimensions_to_reverse); } - // If at least one of the strides is > 1 (or < -1) then use Slice - // to pull out each of the strided slices, and Concat to put them - // together again. - if (!simple_strides) { - // Re-adjust the begin and end now that the periphery has been - // sliced away. - for (int d = 0; d < strides.size(); ++d) { - slice_end[d] -= slice_begin[d]; - slice_begin[d] = 0; - } - - for (int d = 0; d < strides.size(); ++d) { - int64 stride = std::abs(strides[d]); - if (stride > 1) { - std::vector<xla::ComputationDataHandle> to_concat; - int64 end = slice_end[d]; - for (int64 i = 0; i < end; i += stride) { - slice_begin[d] = i; - slice_end[d] = i + 1; - to_concat.push_back( - ctx->builder()->Slice(slice, slice_begin, slice_end)); - } - slice = ctx->builder()->ConcatInDim(to_concat, d); - slice_begin[d] = 0; - slice_end[d] = to_concat.size(); - } - } - } + slice = ctx->builder()->Slice(slice, slice_begin, slice_end, slice_strides); slice = ctx->builder()->Reshape(slice, final_shape.dim_sizes()); ctx->SetOutput(0, slice); |