aboutsummaryrefslogtreecommitdiffhomepage
path: root/tensorflow/contrib/lite/kernels/internal/reference/reference_ops.h
diff options
context:
space:
mode:
Diffstat (limited to 'tensorflow/contrib/lite/kernels/internal/reference/reference_ops.h')
-rw-r--r--tensorflow/contrib/lite/kernels/internal/reference/reference_ops.h129
1 files changed, 90 insertions, 39 deletions
diff --git a/tensorflow/contrib/lite/kernels/internal/reference/reference_ops.h b/tensorflow/contrib/lite/kernels/internal/reference/reference_ops.h
index a027a47726..0abacf85e1 100644
--- a/tensorflow/contrib/lite/kernels/internal/reference/reference_ops.h
+++ b/tensorflow/contrib/lite/kernels/internal/reference/reference_ops.h
@@ -3488,8 +3488,7 @@ inline void Gather(const tflite::GatherParams& op_params,
const RuntimeShape& input_shape, const T* input_data,
const RuntimeShape& coords_shape, const int32* coords_data,
const RuntimeShape& output_shape, T* output_data) {
- // TODO(b/80418076): Enable these checks when moving legacy ops to
- // legacy_reference_ops.
+ // Enable these checks when moving legacy ops to legacy_reference_ops.
//
// TFLITE_DCHECK_EQ(coords_shape.DimensionsCount(), 1);
const int input_rank = op_params.input_rank;
@@ -3808,58 +3807,110 @@ inline void Pad(const tflite::PadParams& op_params,
}
template <typename T>
-inline void StridedSlice(const T* input_data, const Dims<4>& input_dims,
- int begin_mask, int end_mask, int shrink_axis_mask,
- const std::vector<int>& start_indices,
- const std::vector<int>& stop_indices,
- const std::vector<int>& strides, T* output_data,
- const Dims<4>& output_dims) {
- // Note that the axis orders are reversed for runtime ops, so the indices,
- // strides and masks must be as well too.
- TFLITE_DCHECK_EQ(start_indices.size(), 4);
- TFLITE_DCHECK_EQ(stop_indices.size(), 4);
- TFLITE_DCHECK_EQ(strides.size(), 4);
- const int start_b = strided_slice::StartForAxis(begin_mask, start_indices,
- strides, input_dims.sizes, 3);
+inline void StridedSlice(const tflite::StridedSliceParams& op_params,
+ const RuntimeShape& unextended_input_shape,
+ const T* input_data,
+ const RuntimeShape& unextended_output_shape,
+ T* output_data) {
+ // Note that the output_shape is not used herein.
+ tflite::StridedSliceParams params_copy = op_params;
+
+ TFLITE_DCHECK_LE(unextended_input_shape.DimensionsCount(), 4);
+ TFLITE_DCHECK_LE(unextended_output_shape.DimensionsCount(), 4);
+ RuntimeShape input_shape =
+ RuntimeShape::ExtendedShape(4, unextended_input_shape);
+ RuntimeShape output_shape =
+ RuntimeShape::ExtendedShape(4, unextended_output_shape);
+
+ // Reverse and pad to 4 dimensions because that is what the runtime code
+ // requires (ie. all shapes must be 4D and are given backwards).
+ strided_slice::StridedSlicePadIndices(&params_copy, 4);
+
+ const int start_b = strided_slice::StartForAxis(params_copy, input_shape, 0);
const int stop_b =
- strided_slice::StopForAxis(end_mask, shrink_axis_mask, stop_indices,
- strides, input_dims.sizes, 3, start_b);
- const int start_h = strided_slice::StartForAxis(begin_mask, start_indices,
- strides, input_dims.sizes, 2);
+ strided_slice::StopForAxis(params_copy, input_shape, 0, start_b);
+ const int start_h = strided_slice::StartForAxis(params_copy, input_shape, 1);
const int stop_h =
- strided_slice::StopForAxis(end_mask, shrink_axis_mask, stop_indices,
- strides, input_dims.sizes, 2, start_h);
- const int start_w = strided_slice::StartForAxis(begin_mask, start_indices,
- strides, input_dims.sizes, 1);
+ strided_slice::StopForAxis(params_copy, input_shape, 1, start_h);
+ const int start_w = strided_slice::StartForAxis(params_copy, input_shape, 2);
const int stop_w =
- strided_slice::StopForAxis(end_mask, shrink_axis_mask, stop_indices,
- strides, input_dims.sizes, 1, start_w);
- const int start_d = strided_slice::StartForAxis(begin_mask, start_indices,
- strides, input_dims.sizes, 0);
+ strided_slice::StopForAxis(params_copy, input_shape, 2, start_w);
+ const int start_d = strided_slice::StartForAxis(params_copy, input_shape, 3);
const int stop_d =
- strided_slice::StopForAxis(end_mask, shrink_axis_mask, stop_indices,
- strides, input_dims.sizes, 0, start_d);
+ strided_slice::StopForAxis(params_copy, input_shape, 3, start_d);
T* out_ptr = output_data;
for (int in_b = start_b;
- !strided_slice::LoopCondition(in_b, stop_b, strides[3]);
- in_b += strides[3]) {
+ !strided_slice::LoopCondition(in_b, stop_b, params_copy.strides[0]);
+ in_b += params_copy.strides[0]) {
for (int in_h = start_h;
- !strided_slice::LoopCondition(in_h, stop_h, strides[2]);
- in_h += strides[2]) {
+ !strided_slice::LoopCondition(in_h, stop_h, params_copy.strides[1]);
+ in_h += params_copy.strides[1]) {
for (int in_w = start_w;
- !strided_slice::LoopCondition(in_w, stop_w, strides[1]);
- in_w += strides[1]) {
- for (int in_d = start_d;
- !strided_slice::LoopCondition(in_d, stop_d, strides[0]);
- in_d += strides[0]) {
- *out_ptr++ = input_data[Offset(input_dims, in_d, in_w, in_h, in_b)];
+ !strided_slice::LoopCondition(in_w, stop_w, params_copy.strides[2]);
+ in_w += params_copy.strides[2]) {
+ for (int in_d = start_d; !strided_slice::LoopCondition(
+ in_d, stop_d, params_copy.strides[3]);
+ in_d += params_copy.strides[3]) {
+ *out_ptr++ = input_data[Offset(input_shape, in_b, in_h, in_w, in_d)];
}
}
}
}
}
+// TODO(b/80418076): Move to legacy ops file, update invocations.
+// Legacy.
+inline uint32 LegacyReverseBits32(uint32 n) {
+ n = ((n >> 1) & 0x55555555) | ((n & 0x55555555) << 1);
+ n = ((n >> 2) & 0x33333333) | ((n & 0x33333333) << 2);
+ n = ((n >> 4) & 0x0F0F0F0F) | ((n & 0x0F0F0F0F) << 4);
+ return (((n & 0xFF) << 24) | ((n & 0xFF00) << 8) | ((n & 0xFF0000) >> 8) |
+ ((n & 0xFF000000) >> 24));
+}
+
+inline void StridedSliceReverseIndices(tflite::StridedSliceParams* p) {
+ TFLITE_CHECK_EQ(p->start_indices_count, p->stop_indices_count);
+ TFLITE_CHECK_EQ(p->stop_indices_count, p->strides_count);
+
+ std::reverse(p->start_indices, p->start_indices + p->start_indices_count);
+ std::reverse(p->stop_indices, p->stop_indices + p->stop_indices_count);
+ std::reverse(p->strides, p->strides + p->strides_count);
+
+ p->begin_mask = LegacyReverseBits32(static_cast<uint32>(p->begin_mask)) >>
+ (32 - p->start_indices_count);
+ p->ellipsis_mask =
+ LegacyReverseBits32(static_cast<uint32>(p->ellipsis_mask)) >>
+ (32 - p->start_indices_count);
+ p->end_mask = LegacyReverseBits32(static_cast<uint32>(p->end_mask)) >>
+ (32 - p->start_indices_count);
+ p->new_axis_mask =
+ LegacyReverseBits32(static_cast<uint32>(p->new_axis_mask)) >>
+ (32 - p->start_indices_count);
+ p->shrink_axis_mask =
+ LegacyReverseBits32(static_cast<uint32>(p->shrink_axis_mask)) >>
+ (32 - p->start_indices_count);
+}
+
+// TODO(b/80418076): Move to legacy ops file, update invocations.
+// Legacy.
+template <typename T>
+inline void StridedSlice(const T* input_data, const Dims<4>& input_dims,
+ int begin_mask, int end_mask, int shrink_axis_mask,
+ const std::vector<int>& start_indices,
+ const std::vector<int>& stop_indices,
+ const std::vector<int>& strides, T* output_data,
+ const Dims<4>& output_dims) {
+ TFLITE_DCHECK_EQ(start_indices.size(), 4);
+ auto op_params = strided_slice::BuildStridedSliceParams(
+ begin_mask, end_mask, shrink_axis_mask, start_indices, stop_indices,
+ strides);
+ StridedSliceReverseIndices(&op_params);
+
+ StridedSlice(op_params, DimsToShape(input_dims), input_data,
+ DimsToShape(output_dims), output_data);
+}
+
template <typename T>
inline void Slice(const tflite::SliceParams& op_params,
const RuntimeShape& input_shape, const T* input_data,