diff options
author | 2018-03-08 15:29:18 +0800 | |
---|---|---|
committer | 2018-03-07 23:29:18 -0800 | |
commit | ab1cab51265f8b0fb38d007a1d3d93a857ca864d (patch) | |
tree | 48adf87ef2ecbbf45314ba563408dac7b935c046 | |
parent | 584aa04bfc816a6cf9f0390d33c3595837355935 (diff) |
Fix a bug in tf.strided_slice() (#16989)
Current implementation modifies TfLiteNode::builtin_data every time when
a loaded graph is executed. The three masks in params will continually
flipping, and cause the op produce incorrect result every two executions.
-rw-r--r-- | tensorflow/contrib/lite/kernels/strided_slice.cc | 12 |
1 files changed, 6 insertions, 6 deletions
diff --git a/tensorflow/contrib/lite/kernels/strided_slice.cc b/tensorflow/contrib/lite/kernels/strided_slice.cc index fb1e11e0ca..3907a6620d 100644 --- a/tensorflow/contrib/lite/kernels/strided_slice.cc +++ b/tensorflow/contrib/lite/kernels/strided_slice.cc @@ -48,7 +48,7 @@ struct StridedSliceContext { output = GetOutput(context, node, kOutputTensor); dims = NumDimensions(input); } - TfLiteStridedSliceParams* params; + const TfLiteStridedSliceParams* params; TfLiteTensor* input; TfLiteTensor* begin; TfLiteTensor* end; @@ -199,18 +199,18 @@ TfLiteStatus Eval(TfLiteContext* context, TfLiteNode* node) { strides.emplace_back(1); } - op_context.params->begin_mask = + int begin_mask = ReverseMaskBits(op_context.params->begin_mask, op_context.dims); - op_context.params->end_mask = + int end_mask = ReverseMaskBits(op_context.params->end_mask, op_context.dims); - op_context.params->shrink_axis_mask = + int shrink_axis_mask = ReverseMaskBits(op_context.params->shrink_axis_mask, op_context.dims); #define TF_LITE_STRIDED_SLICE(kernel_type, data_type) \ kernel_type::StridedSlice( \ GetTensorData<data_type>(op_context.input), \ - GetTensorDims(op_context.input), op_context.params->begin_mask, \ - op_context.params->end_mask, op_context.params->shrink_axis_mask, \ + GetTensorDims(op_context.input), \ + begin_mask, end_mask, shrink_axis_mask, \ starts, stops, strides, GetTensorData<data_type>(op_context.output), \ GetTensorDims(op_context.output)) |