diff options
Diffstat (limited to 'tensorflow/contrib/lite/kernels/slice.cc')
-rw-r--r-- | tensorflow/contrib/lite/kernels/slice.cc | 26 |
1 files changed, 22 insertions, 4 deletions
diff --git a/tensorflow/contrib/lite/kernels/slice.cc b/tensorflow/contrib/lite/kernels/slice.cc index 6a20e802a9..55e16506df 100644 --- a/tensorflow/contrib/lite/kernels/slice.cc +++ b/tensorflow/contrib/lite/kernels/slice.cc @@ -159,10 +159,28 @@ TfLiteStatus Eval(TfLiteContext* context, TfLiteNode* node) { sizes.push_back(1); } -#define TF_LITE_SLICE(data_type) \ - optimized_ops::Slice<data_type>( \ - GetTensorData<data_type>(input), GetTensorDims(input), begins, sizes, \ - GetTensorData<data_type>(output), GetTensorDims(output)) + // The original Slice op implementation only accepted 4-D sizes. That + // constraint is, for the present, maintained here. + // + // The dimensions in the kernel used to be in reverse-order, and TFLite + // arranged the begins and sizes vectors accordingly. This macro incorporates + // the needed reversing. +#define TF_LITE_SLICE(data_type) \ + { \ + TF_LITE_ENSURE_EQ(context, begins.size(), 4); \ + TF_LITE_ENSURE_EQ(context, sizes.size(), 4); \ + tflite::SliceParams op_params; \ + op_params.begin_count = 4; \ + op_params.size_count = 4; \ + for (int i = 0; i < 4; ++i) { \ + op_params.begin[i] = begins[3 - i]; \ + op_params.size[i] = sizes[3 - i]; \ + } \ + \ + optimized_ops::Slice<data_type>( \ + op_params, GetTensorShape(input), GetTensorData<data_type>(input), \ + GetTensorShape(output), GetTensorData<data_type>(output)); \ + } switch (input->type) { case kTfLiteFloat32: |