aboutsummaryrefslogtreecommitdiffhomepage
path: root/tensorflow/contrib/lite/kernels/strided_slice.cc
diff options
context:
space:
mode:
authorGravatar A. Unique TensorFlower <gardener@tensorflow.org>2018-04-17 13:20:42 -0700
committerGravatar TensorFlower Gardener <gardener@tensorflow.org>2018-04-17 13:23:19 -0700
commit83418120b7c2659fedddd7c85b65d3c3e6aa94e3 (patch)
tree3f8cbb0db3ee2c059a8eb99f6956b6a78f088497 /tensorflow/contrib/lite/kernels/strided_slice.cc
parent33d55d7caff2bd32fa2b1c5cacb7ac251c48e27d (diff)
Fixing a bug in strided slice. The op was not handling negative indices correctly.
PiperOrigin-RevId: 193245539
Diffstat (limited to 'tensorflow/contrib/lite/kernels/strided_slice.cc')
-rw-r--r--tensorflow/contrib/lite/kernels/strided_slice.cc22
1 files changed, 11 insertions, 11 deletions
diff --git a/tensorflow/contrib/lite/kernels/strided_slice.cc b/tensorflow/contrib/lite/kernels/strided_slice.cc
index e6d5c300dc..40ac436b7d 100644
--- a/tensorflow/contrib/lite/kernels/strided_slice.cc
+++ b/tensorflow/contrib/lite/kernels/strided_slice.cc
@@ -87,6 +87,8 @@ inline int32_t ClampedIndex(int32_t index, int dim, bool pos_stride) {
std::min(std::max(index, -dim), dim - 1), dim));
}
+// TODO(b/77971377) this logic should be removed, as it's a duplication of
+// StartForAxis() & StopForAxis() in kernels/internal/reference/reference_ops.h
inline int32_t GetBeginValueAtIndex(StridedSliceContext* op_context, int idx) {
const int dim = op_context->input->dims->data[idx];
const bool pos_stride = GetTensorData<int32_t>(op_context->strides)[idx] > 0;
@@ -188,8 +190,8 @@ TfLiteStatus Eval(TfLiteContext* context, TfLiteNode* node) {
std::vector<int32_t> strides;
for (int idx = op_context.dims - 1; idx >= 0; --idx) {
- starts.emplace_back(GetBeginValueAtIndex(&op_context, idx));
- stops.emplace_back(GetEndValueAtIndex(&op_context, idx));
+ starts.emplace_back(GetTensorData<int32_t>(op_context.begin)[idx]);
+ stops.emplace_back(GetTensorData<int32_t>(op_context.end)[idx]);
strides.emplace_back(GetTensorData<int32_t>(op_context.strides)[idx]);
}
@@ -202,15 +204,13 @@ TfLiteStatus Eval(TfLiteContext* context, TfLiteNode* node) {
int begin_mask =
ReverseMaskBits(op_context.params->begin_mask, op_context.dims);
int end_mask = ReverseMaskBits(op_context.params->end_mask, op_context.dims);
- int shrink_axis_mask =
- ReverseMaskBits(op_context.params->shrink_axis_mask, op_context.dims);
-
-#define TF_LITE_STRIDED_SLICE(kernel_type, data_type) \
- kernel_type::StridedSlice( \
- GetTensorData<data_type>(op_context.input), \
- GetTensorDims(op_context.input), begin_mask, end_mask, shrink_axis_mask, \
- starts, stops, strides, GetTensorData<data_type>(op_context.output), \
- GetTensorDims(op_context.output))
+
+#define TF_LITE_STRIDED_SLICE(kernel_type, data_type) \
+ kernel_type::StridedSlice(GetTensorData<data_type>(op_context.input), \
+ GetTensorDims(op_context.input), begin_mask, \
+ end_mask, starts, stops, strides, \
+ GetTensorData<data_type>(op_context.output), \
+ GetTensorDims(op_context.output))
switch (op_context.input->type) {
case kTfLiteFloat32: