aboutsummaryrefslogtreecommitdiffhomepage
path: root/tensorflow/compiler/tf2xla/kernels/strided_slice_op.cc
diff options
context:
space:
mode:
Diffstat (limited to 'tensorflow/compiler/tf2xla/kernels/strided_slice_op.cc')
-rw-r--r--tensorflow/compiler/tf2xla/kernels/strided_slice_op.cc44
1 files changed, 9 insertions, 35 deletions
diff --git a/tensorflow/compiler/tf2xla/kernels/strided_slice_op.cc b/tensorflow/compiler/tf2xla/kernels/strided_slice_op.cc
index 8037e90791..6af4bd0496 100644
--- a/tensorflow/compiler/tf2xla/kernels/strided_slice_op.cc
+++ b/tensorflow/compiler/tf2xla/kernels/strided_slice_op.cc
@@ -72,55 +72,29 @@ class StridedSliceOp : public XlaOpKernel {
&dummy, &dummy, &dummy, &begin, &end, &strides));
gtl::InlinedVector<int64, 4> dimensions_to_reverse;
- gtl::InlinedVector<int64, 4> slice_begin, slice_end;
- bool simple_strides = true;
+ gtl::InlinedVector<int64, 4> slice_begin, slice_end, slice_strides;
+
for (int i = 0; i < begin.size(); ++i) {
- simple_strides &= (std::abs(strides[i]) == 1);
if (strides[i] > 0) {
slice_begin.push_back(begin[i]);
slice_end.push_back(end[i]);
+ slice_strides.push_back(strides[i]);
} else {
// Negative stride: swap begin and end, add 1 because the interval
// is semi-open, and mark the dimension to be reversed.
- slice_begin.push_back(end[i] + 1);
- slice_end.push_back(begin[i] + 1);
+ slice_begin.push_back(input_shape.dim_size(i) - begin[i] - 1);
+ slice_end.push_back(input_shape.dim_size(i) - end[i] - 1);
+ slice_strides.push_back(-strides[i]);
dimensions_to_reverse.push_back(i);
}
}
- xla::ComputationDataHandle slice =
- ctx->builder()->Slice(ctx->Input(0), slice_begin, slice_end);
+
+ xla::ComputationDataHandle slice = ctx->Input(0);
if (!dimensions_to_reverse.empty()) {
slice = ctx->builder()->Rev(slice, dimensions_to_reverse);
}
- // If at least one of the strides is > 1 (or < -1) then use Slice
- // to pull out each of the strided slices, and Concat to put them
- // together again.
- if (!simple_strides) {
- // Re-adjust the begin and end now that the periphery has been
- // sliced away.
- for (int d = 0; d < strides.size(); ++d) {
- slice_end[d] -= slice_begin[d];
- slice_begin[d] = 0;
- }
-
- for (int d = 0; d < strides.size(); ++d) {
- int64 stride = std::abs(strides[d]);
- if (stride > 1) {
- std::vector<xla::ComputationDataHandle> to_concat;
- int64 end = slice_end[d];
- for (int64 i = 0; i < end; i += stride) {
- slice_begin[d] = i;
- slice_end[d] = i + 1;
- to_concat.push_back(
- ctx->builder()->Slice(slice, slice_begin, slice_end));
- }
- slice = ctx->builder()->ConcatInDim(to_concat, d);
- slice_begin[d] = 0;
- slice_end[d] = to_concat.size();
- }
- }
- }
+ slice = ctx->builder()->Slice(slice, slice_begin, slice_end, slice_strides);
slice = ctx->builder()->Reshape(slice, final_shape.dim_sizes());
ctx->SetOutput(0, slice);