aboutsummaryrefslogtreecommitdiffhomepage
path: root/tensorflow/contrib/lite/toco/graph_transformations/propagate_fixed_sizes.cc
diff options
context:
space:
mode:
authorGravatar A. Unique TensorFlower <gardener@tensorflow.org>2018-04-26 13:35:35 -0700
committerGravatar TensorFlower Gardener <gardener@tensorflow.org>2018-04-26 13:37:55 -0700
commit7b0e865d79d8b9bacf855779b9c3ccf73d2571ac (patch)
tree764d82dfa3a4e1e20fb9b742fd73e46ba3b3c271 /tensorflow/contrib/lite/toco/graph_transformations/propagate_fixed_sizes.cc
parentb6189a23a5f6afa59ced097d7844d58c7fd24901 (diff)
Adding some slightly more exhaustive strided_slice test parameters.
PiperOrigin-RevId: 194446000
Diffstat (limited to 'tensorflow/contrib/lite/toco/graph_transformations/propagate_fixed_sizes.cc')
-rw-r--r--tensorflow/contrib/lite/toco/graph_transformations/propagate_fixed_sizes.cc91
1 files changed, 7 insertions, 84 deletions
diff --git a/tensorflow/contrib/lite/toco/graph_transformations/propagate_fixed_sizes.cc b/tensorflow/contrib/lite/toco/graph_transformations/propagate_fixed_sizes.cc
index be6e0e07dd..19037bc503 100644
--- a/tensorflow/contrib/lite/toco/graph_transformations/propagate_fixed_sizes.cc
+++ b/tensorflow/contrib/lite/toco/graph_transformations/propagate_fixed_sizes.cc
@@ -20,6 +20,7 @@ limitations under the License.
#include <vector>
#include "absl/strings/str_join.h"
+#include "tensorflow/contrib/lite/kernels/internal/strided_slice_logic.h"
#include "tensorflow/contrib/lite/toco/graph_transformations/graph_transformations.h"
#include "tensorflow/contrib/lite/toco/model.h"
#include "tensorflow/contrib/lite/toco/tooling_util.h"
@@ -1235,83 +1236,6 @@ void ProcessStackOperator(Model* model, StackOperator* op) {
output_array.copy_shape(*stacked_shape);
}
-// These StridedSlice utility functions are essentially a COPY of those in
-// reference_ops.h. See comments there.
-
-// Use until std::clamp() is available from C++17.
-int Clamp(const int v, const int lo, const int hi) {
- if (hi < v) return hi;
- if (v < lo) return lo;
- return v;
-}
-
-int StartForAxis(StridedSliceOperator const& op, Shape const& input_shape,
- int axis) {
- // Begin with the specified index
- int start = op.start_indices[axis];
-
- // begin_mask override
- if (op.begin_mask & 1 << axis) {
- if (op.strides[axis] > 0) {
- // Forward iteration - use the first element. These values will get
- // clamped below (Note: We could have set them to 0 and axis_size-1, but
- // use lowest() and max() to maintain symmetry with StopForAxis())
- start = std::numeric_limits<int>::lowest();
- } else {
- // Backward iteration - use the last element.
- start = std::numeric_limits<int>::max();
- }
- }
-
- // Handle negative indices
- int axis_size = input_shape.dims(axis);
- if (start < 0) {
- start += axis_size;
- }
-
- // Clamping
- start = Clamp(start, 0, axis_size - 1);
-
- return start;
-}
-
-int StopForAxis(StridedSliceOperator const& op, Shape const& input_shape,
- int axis) {
- // Begin with the specified index
- int stop = op.stop_indices[axis];
-
- // end_mask override
- if (op.end_mask & (1 << axis)) {
- if (op.strides[axis] > 0) {
- // Forward iteration - use the last element. These values will get
- // clamped below
- stop = std::numeric_limits<int>::max();
- } else {
- // Backward iteration - use the first element.
- stop = std::numeric_limits<int>::lowest();
- }
- }
-
- // Handle negative indices
- int axis_size = input_shape.dims(axis);
- if (stop < 0) {
- stop += axis_size;
- }
-
- // Clamping
- // Because the end index points one past the last element, we need slightly
- // different clamping ranges depending on the direction.
- if (op.strides[axis] > 0) {
- // Forward iteration
- stop = Clamp(stop, 0, axis_size);
- } else {
- // Backward iteration
- stop = Clamp(stop, -1, axis_size - 1);
- }
-
- return stop;
-}
-
void ProcessStridedSliceOperator(Model* model, StridedSliceOperator* op) {
CHECK_GE(op->inputs.size(), 1);
CHECK_EQ(op->outputs.size(), 1);
@@ -1364,18 +1288,17 @@ void ProcessStridedSliceOperator(Model* model, StridedSliceOperator* op) {
<< " has stride=" << op->strides[i] << ".";
}
- // The TensorFlow documentation is not explicit on how it handles fewer
- // supplied indices than dimensions, but they are accepted. We emulate TF's
- // behavior by fully iterating over each "forgotten" dimension.
- op->PadIndices(num_input_axes);
-
// Create output shape
std::vector<int>* dims = output_array.mutable_shape()->mutable_dims();
// Compute output shape
for (int axis = 0; axis < num_input_axes; ++axis) {
- int start_index = StartForAxis(*op, input_array.shape(), axis);
- int stop_index = StopForAxis(*op, input_array.shape(), axis);
+ int start_index = tflite::strided_slice::StartForAxis(
+ op->begin_mask, op->start_indices, op->strides,
+ input_array.shape().dims().data(), axis);
+ int stop_index = tflite::strided_slice::StopForAxis(
+ op->end_mask, op->stop_indices, op->strides,
+ input_array.shape().dims().data(), axis);
int dim_size =
ceil(static_cast<float>(stop_index - start_index) / op->strides[axis]);