diff options
author | A. Unique TensorFlower <gardener@tensorflow.org> | 2018-04-17 13:20:42 -0700 |
---|---|---|
committer | TensorFlower Gardener <gardener@tensorflow.org> | 2018-04-17 13:23:19 -0700 |
commit | 83418120b7c2659fedddd7c85b65d3c3e6aa94e3 (patch) | |
tree | 3f8cbb0db3ee2c059a8eb99f6956b6a78f088497 /tensorflow/contrib/lite/kernels/strided_slice.cc | |
parent | 33d55d7caff2bd32fa2b1c5cacb7ac251c48e27d (diff) |
Fixing a bug in strided slice. The op was not handling negative indices correctly.
PiperOrigin-RevId: 193245539
Diffstat (limited to 'tensorflow/contrib/lite/kernels/strided_slice.cc')
-rw-r--r-- | tensorflow/contrib/lite/kernels/strided_slice.cc | 22 |
1 files changed, 11 insertions, 11 deletions
diff --git a/tensorflow/contrib/lite/kernels/strided_slice.cc b/tensorflow/contrib/lite/kernels/strided_slice.cc index e6d5c300dc..40ac436b7d 100644 --- a/tensorflow/contrib/lite/kernels/strided_slice.cc +++ b/tensorflow/contrib/lite/kernels/strided_slice.cc @@ -87,6 +87,8 @@ inline int32_t ClampedIndex(int32_t index, int dim, bool pos_stride) { std::min(std::max(index, -dim), dim - 1), dim)); } +// TODO(b/77971377) this logic should be removed, as it's a duplication of +// StartForAxis() & StopForAxis() in kernels/internal/reference/reference_ops.h inline int32_t GetBeginValueAtIndex(StridedSliceContext* op_context, int idx) { const int dim = op_context->input->dims->data[idx]; const bool pos_stride = GetTensorData<int32_t>(op_context->strides)[idx] > 0; @@ -188,8 +190,8 @@ TfLiteStatus Eval(TfLiteContext* context, TfLiteNode* node) { std::vector<int32_t> strides; for (int idx = op_context.dims - 1; idx >= 0; --idx) { - starts.emplace_back(GetBeginValueAtIndex(&op_context, idx)); - stops.emplace_back(GetEndValueAtIndex(&op_context, idx)); + starts.emplace_back(GetTensorData<int32_t>(op_context.begin)[idx]); + stops.emplace_back(GetTensorData<int32_t>(op_context.end)[idx]); strides.emplace_back(GetTensorData<int32_t>(op_context.strides)[idx]); } @@ -202,15 +204,13 @@ TfLiteStatus Eval(TfLiteContext* context, TfLiteNode* node) { int begin_mask = ReverseMaskBits(op_context.params->begin_mask, op_context.dims); int end_mask = ReverseMaskBits(op_context.params->end_mask, op_context.dims); - int shrink_axis_mask = - ReverseMaskBits(op_context.params->shrink_axis_mask, op_context.dims); - -#define TF_LITE_STRIDED_SLICE(kernel_type, data_type) \ - kernel_type::StridedSlice( \ - GetTensorData<data_type>(op_context.input), \ - GetTensorDims(op_context.input), begin_mask, end_mask, shrink_axis_mask, \ - starts, stops, strides, GetTensorData<data_type>(op_context.output), \ - GetTensorDims(op_context.output)) + +#define TF_LITE_STRIDED_SLICE(kernel_type, data_type) \ + kernel_type::StridedSlice(GetTensorData<data_type>(op_context.input), \ + GetTensorDims(op_context.input), begin_mask, \ + end_mask, starts, stops, strides, \ + GetTensorData<data_type>(op_context.output), \ + GetTensorDims(op_context.output)) switch (op_context.input->type) { case kTfLiteFloat32: |