diff options
Diffstat (limited to 'tensorflow/contrib/lite/kernels/internal/reference/reference_ops.h')
-rw-r--r-- | tensorflow/contrib/lite/kernels/internal/reference/reference_ops.h | 129 |
1 files changed, 90 insertions, 39 deletions
diff --git a/tensorflow/contrib/lite/kernels/internal/reference/reference_ops.h b/tensorflow/contrib/lite/kernels/internal/reference/reference_ops.h index a027a47726..0abacf85e1 100644 --- a/tensorflow/contrib/lite/kernels/internal/reference/reference_ops.h +++ b/tensorflow/contrib/lite/kernels/internal/reference/reference_ops.h @@ -3488,8 +3488,7 @@ inline void Gather(const tflite::GatherParams& op_params, const RuntimeShape& input_shape, const T* input_data, const RuntimeShape& coords_shape, const int32* coords_data, const RuntimeShape& output_shape, T* output_data) { - // TODO(b/80418076): Enable these checks when moving legacy ops to - // legacy_reference_ops. + // Enable these checks when moving legacy ops to legacy_reference_ops. // // TFLITE_DCHECK_EQ(coords_shape.DimensionsCount(), 1); const int input_rank = op_params.input_rank; @@ -3808,58 +3807,110 @@ inline void Pad(const tflite::PadParams& op_params, } template <typename T> -inline void StridedSlice(const T* input_data, const Dims<4>& input_dims, - int begin_mask, int end_mask, int shrink_axis_mask, - const std::vector<int>& start_indices, - const std::vector<int>& stop_indices, - const std::vector<int>& strides, T* output_data, - const Dims<4>& output_dims) { - // Note that the axis orders are reversed for runtime ops, so the indices, - // strides and masks must be as well too. - TFLITE_DCHECK_EQ(start_indices.size(), 4); - TFLITE_DCHECK_EQ(stop_indices.size(), 4); - TFLITE_DCHECK_EQ(strides.size(), 4); - const int start_b = strided_slice::StartForAxis(begin_mask, start_indices, - strides, input_dims.sizes, 3); +inline void StridedSlice(const tflite::StridedSliceParams& op_params, + const RuntimeShape& unextended_input_shape, + const T* input_data, + const RuntimeShape& unextended_output_shape, + T* output_data) { + // Note that the output_shape is not used herein. + tflite::StridedSliceParams params_copy = op_params; + + TFLITE_DCHECK_LE(unextended_input_shape.DimensionsCount(), 4); + TFLITE_DCHECK_LE(unextended_output_shape.DimensionsCount(), 4); + RuntimeShape input_shape = + RuntimeShape::ExtendedShape(4, unextended_input_shape); + RuntimeShape output_shape = + RuntimeShape::ExtendedShape(4, unextended_output_shape); + + // Reverse and pad to 4 dimensions because that is what the runtime code + // requires (ie. all shapes must be 4D and are given backwards). + strided_slice::StridedSlicePadIndices(¶ms_copy, 4); + + const int start_b = strided_slice::StartForAxis(params_copy, input_shape, 0); const int stop_b = - strided_slice::StopForAxis(end_mask, shrink_axis_mask, stop_indices, - strides, input_dims.sizes, 3, start_b); - const int start_h = strided_slice::StartForAxis(begin_mask, start_indices, - strides, input_dims.sizes, 2); + strided_slice::StopForAxis(params_copy, input_shape, 0, start_b); + const int start_h = strided_slice::StartForAxis(params_copy, input_shape, 1); const int stop_h = - strided_slice::StopForAxis(end_mask, shrink_axis_mask, stop_indices, - strides, input_dims.sizes, 2, start_h); - const int start_w = strided_slice::StartForAxis(begin_mask, start_indices, - strides, input_dims.sizes, 1); + strided_slice::StopForAxis(params_copy, input_shape, 1, start_h); + const int start_w = strided_slice::StartForAxis(params_copy, input_shape, 2); const int stop_w = - strided_slice::StopForAxis(end_mask, shrink_axis_mask, stop_indices, - strides, input_dims.sizes, 1, start_w); - const int start_d = strided_slice::StartForAxis(begin_mask, start_indices, - strides, input_dims.sizes, 0); + strided_slice::StopForAxis(params_copy, input_shape, 2, start_w); + const int start_d = strided_slice::StartForAxis(params_copy, input_shape, 3); const int stop_d = - strided_slice::StopForAxis(end_mask, shrink_axis_mask, stop_indices, - strides, input_dims.sizes, 0, start_d); + strided_slice::StopForAxis(params_copy, input_shape, 3, start_d); T* out_ptr = output_data; for (int in_b = start_b; - !strided_slice::LoopCondition(in_b, stop_b, strides[3]); - in_b += strides[3]) { + !strided_slice::LoopCondition(in_b, stop_b, params_copy.strides[0]); + in_b += params_copy.strides[0]) { for (int in_h = start_h; - !strided_slice::LoopCondition(in_h, stop_h, strides[2]); - in_h += strides[2]) { + !strided_slice::LoopCondition(in_h, stop_h, params_copy.strides[1]); + in_h += params_copy.strides[1]) { for (int in_w = start_w; - !strided_slice::LoopCondition(in_w, stop_w, strides[1]); - in_w += strides[1]) { - for (int in_d = start_d; - !strided_slice::LoopCondition(in_d, stop_d, strides[0]); - in_d += strides[0]) { - *out_ptr++ = input_data[Offset(input_dims, in_d, in_w, in_h, in_b)]; + !strided_slice::LoopCondition(in_w, stop_w, params_copy.strides[2]); + in_w += params_copy.strides[2]) { + for (int in_d = start_d; !strided_slice::LoopCondition( + in_d, stop_d, params_copy.strides[3]); + in_d += params_copy.strides[3]) { + *out_ptr++ = input_data[Offset(input_shape, in_b, in_h, in_w, in_d)]; } } } } } +// TODO(b/80418076): Move to legacy ops file, update invocations. +// Legacy. +inline uint32 LegacyReverseBits32(uint32 n) { + n = ((n >> 1) & 0x55555555) | ((n & 0x55555555) << 1); + n = ((n >> 2) & 0x33333333) | ((n & 0x33333333) << 2); + n = ((n >> 4) & 0x0F0F0F0F) | ((n & 0x0F0F0F0F) << 4); + return (((n & 0xFF) << 24) | ((n & 0xFF00) << 8) | ((n & 0xFF0000) >> 8) | + ((n & 0xFF000000) >> 24)); +} + +inline void StridedSliceReverseIndices(tflite::StridedSliceParams* p) { + TFLITE_CHECK_EQ(p->start_indices_count, p->stop_indices_count); + TFLITE_CHECK_EQ(p->stop_indices_count, p->strides_count); + + std::reverse(p->start_indices, p->start_indices + p->start_indices_count); + std::reverse(p->stop_indices, p->stop_indices + p->stop_indices_count); + std::reverse(p->strides, p->strides + p->strides_count); + + p->begin_mask = LegacyReverseBits32(static_cast<uint32>(p->begin_mask)) >> + (32 - p->start_indices_count); + p->ellipsis_mask = + LegacyReverseBits32(static_cast<uint32>(p->ellipsis_mask)) >> + (32 - p->start_indices_count); + p->end_mask = LegacyReverseBits32(static_cast<uint32>(p->end_mask)) >> + (32 - p->start_indices_count); + p->new_axis_mask = + LegacyReverseBits32(static_cast<uint32>(p->new_axis_mask)) >> + (32 - p->start_indices_count); + p->shrink_axis_mask = + LegacyReverseBits32(static_cast<uint32>(p->shrink_axis_mask)) >> + (32 - p->start_indices_count); +} + +// TODO(b/80418076): Move to legacy ops file, update invocations. +// Legacy. +template <typename T> +inline void StridedSlice(const T* input_data, const Dims<4>& input_dims, + int begin_mask, int end_mask, int shrink_axis_mask, + const std::vector<int>& start_indices, + const std::vector<int>& stop_indices, + const std::vector<int>& strides, T* output_data, + const Dims<4>& output_dims) { + TFLITE_DCHECK_EQ(start_indices.size(), 4); + auto op_params = strided_slice::BuildStridedSliceParams( + begin_mask, end_mask, shrink_axis_mask, start_indices, stop_indices, + strides); + StridedSliceReverseIndices(&op_params); + + StridedSlice(op_params, DimsToShape(input_dims), input_data, + DimsToShape(output_dims), output_data); +} + template <typename T> inline void Slice(const tflite::SliceParams& op_params, const RuntimeShape& input_shape, const T* input_data, |