aboutsummaryrefslogtreecommitdiffhomepage
diff options
context:
space:
mode:
authorGravatar Scott Tseng <scott.tseng@kikatech.com>2018-03-08 15:29:18 +0800
committerGravatar Gunhan Gulsoy <gunan@google.com>2018-03-07 23:29:18 -0800
commitab1cab51265f8b0fb38d007a1d3d93a857ca864d (patch)
tree48adf87ef2ecbbf45314ba563408dac7b935c046
parent584aa04bfc816a6cf9f0390d33c3595837355935 (diff)
Fix a bug in tf.strided_slice() (#16989)
Current implementation modifies TfLiteNode::builtin_data every time when a loaded graph is executed. The three masks in params will continually flipping, and cause the op produce incorrect result every two executions.
-rw-r--r--tensorflow/contrib/lite/kernels/strided_slice.cc12
1 files changed, 6 insertions, 6 deletions
diff --git a/tensorflow/contrib/lite/kernels/strided_slice.cc b/tensorflow/contrib/lite/kernels/strided_slice.cc
index fb1e11e0ca..3907a6620d 100644
--- a/tensorflow/contrib/lite/kernels/strided_slice.cc
+++ b/tensorflow/contrib/lite/kernels/strided_slice.cc
@@ -48,7 +48,7 @@ struct StridedSliceContext {
output = GetOutput(context, node, kOutputTensor);
dims = NumDimensions(input);
}
- TfLiteStridedSliceParams* params;
+ const TfLiteStridedSliceParams* params;
TfLiteTensor* input;
TfLiteTensor* begin;
TfLiteTensor* end;
@@ -199,18 +199,18 @@ TfLiteStatus Eval(TfLiteContext* context, TfLiteNode* node) {
strides.emplace_back(1);
}
- op_context.params->begin_mask =
+ int begin_mask =
ReverseMaskBits(op_context.params->begin_mask, op_context.dims);
- op_context.params->end_mask =
+ int end_mask =
ReverseMaskBits(op_context.params->end_mask, op_context.dims);
- op_context.params->shrink_axis_mask =
+ int shrink_axis_mask =
ReverseMaskBits(op_context.params->shrink_axis_mask, op_context.dims);
#define TF_LITE_STRIDED_SLICE(kernel_type, data_type) \
kernel_type::StridedSlice( \
GetTensorData<data_type>(op_context.input), \
- GetTensorDims(op_context.input), op_context.params->begin_mask, \
- op_context.params->end_mask, op_context.params->shrink_axis_mask, \
+ GetTensorDims(op_context.input), \
+ begin_mask, end_mask, shrink_axis_mask, \
starts, stops, strides, GetTensorData<data_type>(op_context.output), \
GetTensorDims(op_context.output))