aboutsummaryrefslogtreecommitdiffhomepage
path: root/tensorflow/contrib/lite/kernels/strided_slice.cc
diff options
context:
space:
mode:
authorGravatar A. Unique TensorFlower <gardener@tensorflow.org>2018-09-27 06:12:59 -0700
committerGravatar TensorFlower Gardener <gardener@tensorflow.org>2018-09-27 06:16:42 -0700
commitabf26356209cba1ba895a06d9ce55ad01dad7fc6 (patch)
tree5ef1c907a30bf89d08ba241ef985b19938427420 /tensorflow/contrib/lite/kernels/strided_slice.cc
parent19d8963bc0ea64e10ff08ad4e7cc76813a182196 (diff)
Update kernel evals to use new kernel signatures.
PiperOrigin-RevId: 214763814
Diffstat (limited to 'tensorflow/contrib/lite/kernels/strided_slice.cc')
-rw-r--r--tensorflow/contrib/lite/kernels/strided_slice.cc48
1 files changed, 19 insertions, 29 deletions
diff --git a/tensorflow/contrib/lite/kernels/strided_slice.cc b/tensorflow/contrib/lite/kernels/strided_slice.cc
index 87ffcc4110..06b36dd196 100644
--- a/tensorflow/contrib/lite/kernels/strided_slice.cc
+++ b/tensorflow/contrib/lite/kernels/strided_slice.cc
@@ -57,17 +57,6 @@ struct StridedSliceContext {
int dims;
};
-// Reverse order of bits in the mask to match the expected order in kernel
-inline int ReverseMaskBits(int mask, int num_dimensions) {
- int out = 0;
- for (int dim = 0; dim < num_dimensions; dim++) {
- out <<= 1;
- out += (mask & 1);
- mask >>= 1;
- }
- return out;
-}
-
// This Op only supports 1-4D cases and since we use the reference 4D
// implementation, the 1-3D tensors are mapped to 4D.
const int kMaxDim = 4;
@@ -198,30 +187,31 @@ TfLiteStatus Eval(TfLiteContext* context, TfLiteNode* node) {
std::vector<int32_t> stops;
std::vector<int32_t> strides;
- for (int idx = op_context.dims - 1; idx >= 0; --idx) {
- starts.emplace_back(GetTensorData<int32_t>(op_context.begin)[idx]);
- stops.emplace_back(GetTensorData<int32_t>(op_context.end)[idx]);
- strides.emplace_back(GetTensorData<int32_t>(op_context.strides)[idx]);
- }
-
for (int i = op_context.dims; i < kMaxDim; i++) {
starts.emplace_back(0);
stops.emplace_back(1);
strides.emplace_back(1);
}
- int begin_mask =
- ReverseMaskBits(op_context.params->begin_mask, op_context.dims);
- int end_mask = ReverseMaskBits(op_context.params->end_mask, op_context.dims);
- int shrink_axis_mask =
- ReverseMaskBits(op_context.params->shrink_axis_mask, op_context.dims);
-
-#define TF_LITE_STRIDED_SLICE(kernel_type, data_type) \
- kernel_type::StridedSlice( \
- GetTensorData<data_type>(op_context.input), \
- GetTensorDims(op_context.input), begin_mask, end_mask, shrink_axis_mask, \
- starts, stops, strides, GetTensorData<data_type>(op_context.output), \
- GetTensorDims(op_context.output))
+ for (int idx = 0; idx < op_context.dims; ++idx) {
+ starts.emplace_back(GetTensorData<int32_t>(op_context.begin)[idx]);
+ stops.emplace_back(GetTensorData<int32_t>(op_context.end)[idx]);
+ strides.emplace_back(GetTensorData<int32_t>(op_context.strides)[idx]);
+ }
+
+ int begin_mask = op_context.params->begin_mask << (4 - op_context.dims);
+ int end_mask = op_context.params->end_mask << (4 - op_context.dims);
+ int shrink_axis_mask = op_context.params->shrink_axis_mask
+ << (4 - op_context.dims);
+ TF_LITE_ENSURE_EQ(context, starts.size(), 4);
+ auto op_params = ::tflite::strided_slice::BuildStridedSliceParams(
+ begin_mask, end_mask, shrink_axis_mask, starts, stops, strides);
+
+#define TF_LITE_STRIDED_SLICE(kernel_type, data_type) \
+ kernel_type::StridedSlice(op_params, GetTensorShape(op_context.input), \
+ GetTensorData<data_type>(op_context.input), \
+ GetTensorShape(op_context.output), \
+ GetTensorData<data_type>(op_context.output))
switch (op_context.input->type) {
case kTfLiteFloat32: