diff options
author | A. Unique TensorFlower <gardener@tensorflow.org> | 2018-09-27 06:12:59 -0700 |
---|---|---|
committer | TensorFlower Gardener <gardener@tensorflow.org> | 2018-09-27 06:16:42 -0700 |
commit | abf26356209cba1ba895a06d9ce55ad01dad7fc6 (patch) | |
tree | 5ef1c907a30bf89d08ba241ef985b19938427420 /tensorflow/contrib/lite/kernels/strided_slice.cc | |
parent | 19d8963bc0ea64e10ff08ad4e7cc76813a182196 (diff) |
Update kernel evals to use new kernel signatures.
PiperOrigin-RevId: 214763814
Diffstat (limited to 'tensorflow/contrib/lite/kernels/strided_slice.cc')
-rw-r--r-- | tensorflow/contrib/lite/kernels/strided_slice.cc | 48 |
1 files changed, 19 insertions, 29 deletions
diff --git a/tensorflow/contrib/lite/kernels/strided_slice.cc b/tensorflow/contrib/lite/kernels/strided_slice.cc index 87ffcc4110..06b36dd196 100644 --- a/tensorflow/contrib/lite/kernels/strided_slice.cc +++ b/tensorflow/contrib/lite/kernels/strided_slice.cc @@ -57,17 +57,6 @@ struct StridedSliceContext { int dims; }; -// Reverse order of bits in the mask to match the expected order in kernel -inline int ReverseMaskBits(int mask, int num_dimensions) { - int out = 0; - for (int dim = 0; dim < num_dimensions; dim++) { - out <<= 1; - out += (mask & 1); - mask >>= 1; - } - return out; -} - // This Op only supports 1-4D cases and since we use the reference 4D // implementation, the 1-3D tensors are mapped to 4D. const int kMaxDim = 4; @@ -198,30 +187,31 @@ TfLiteStatus Eval(TfLiteContext* context, TfLiteNode* node) { std::vector<int32_t> stops; std::vector<int32_t> strides; - for (int idx = op_context.dims - 1; idx >= 0; --idx) { - starts.emplace_back(GetTensorData<int32_t>(op_context.begin)[idx]); - stops.emplace_back(GetTensorData<int32_t>(op_context.end)[idx]); - strides.emplace_back(GetTensorData<int32_t>(op_context.strides)[idx]); - } - for (int i = op_context.dims; i < kMaxDim; i++) { starts.emplace_back(0); stops.emplace_back(1); strides.emplace_back(1); } - int begin_mask = - ReverseMaskBits(op_context.params->begin_mask, op_context.dims); - int end_mask = ReverseMaskBits(op_context.params->end_mask, op_context.dims); - int shrink_axis_mask = - ReverseMaskBits(op_context.params->shrink_axis_mask, op_context.dims); - -#define TF_LITE_STRIDED_SLICE(kernel_type, data_type) \ - kernel_type::StridedSlice( \ - GetTensorData<data_type>(op_context.input), \ - GetTensorDims(op_context.input), begin_mask, end_mask, shrink_axis_mask, \ - starts, stops, strides, GetTensorData<data_type>(op_context.output), \ - GetTensorDims(op_context.output)) + for (int idx = 0; idx < op_context.dims; ++idx) { + starts.emplace_back(GetTensorData<int32_t>(op_context.begin)[idx]); + stops.emplace_back(GetTensorData<int32_t>(op_context.end)[idx]); + strides.emplace_back(GetTensorData<int32_t>(op_context.strides)[idx]); + } + + int begin_mask = op_context.params->begin_mask << (4 - op_context.dims); + int end_mask = op_context.params->end_mask << (4 - op_context.dims); + int shrink_axis_mask = op_context.params->shrink_axis_mask + << (4 - op_context.dims); + TF_LITE_ENSURE_EQ(context, starts.size(), 4); + auto op_params = ::tflite::strided_slice::BuildStridedSliceParams( + begin_mask, end_mask, shrink_axis_mask, starts, stops, strides); + +#define TF_LITE_STRIDED_SLICE(kernel_type, data_type) \ + kernel_type::StridedSlice(op_params, GetTensorShape(op_context.input), \ + GetTensorData<data_type>(op_context.input), \ + GetTensorShape(op_context.output), \ + GetTensorData<data_type>(op_context.output)) switch (op_context.input->type) { case kTfLiteFloat32: |