diff options
author | A. Unique TensorFlower <gardener@tensorflow.org> | 2018-04-26 13:35:35 -0700 |
---|---|---|
committer | TensorFlower Gardener <gardener@tensorflow.org> | 2018-04-26 13:37:55 -0700 |
commit | 7b0e865d79d8b9bacf855779b9c3ccf73d2571ac (patch) | |
tree | 764d82dfa3a4e1e20fb9b742fd73e46ba3b3c271 /tensorflow/contrib/lite/toco/graph_transformations/propagate_fixed_sizes.cc | |
parent | b6189a23a5f6afa59ced097d7844d58c7fd24901 (diff) |
Adding some slightly more exhaustive strided_slice test parameters.
PiperOrigin-RevId: 194446000
Diffstat (limited to 'tensorflow/contrib/lite/toco/graph_transformations/propagate_fixed_sizes.cc')
-rw-r--r-- | tensorflow/contrib/lite/toco/graph_transformations/propagate_fixed_sizes.cc | 91 |
1 files changed, 7 insertions, 84 deletions
diff --git a/tensorflow/contrib/lite/toco/graph_transformations/propagate_fixed_sizes.cc b/tensorflow/contrib/lite/toco/graph_transformations/propagate_fixed_sizes.cc index be6e0e07dd..19037bc503 100644 --- a/tensorflow/contrib/lite/toco/graph_transformations/propagate_fixed_sizes.cc +++ b/tensorflow/contrib/lite/toco/graph_transformations/propagate_fixed_sizes.cc @@ -20,6 +20,7 @@ limitations under the License. #include <vector> #include "absl/strings/str_join.h" +#include "tensorflow/contrib/lite/kernels/internal/strided_slice_logic.h" #include "tensorflow/contrib/lite/toco/graph_transformations/graph_transformations.h" #include "tensorflow/contrib/lite/toco/model.h" #include "tensorflow/contrib/lite/toco/tooling_util.h" @@ -1235,83 +1236,6 @@ void ProcessStackOperator(Model* model, StackOperator* op) { output_array.copy_shape(*stacked_shape); } -// These StridedSlice utility functions are essentially a COPY of those in -// reference_ops.h. See comments there. - -// Use until std::clamp() is available from C++17. -int Clamp(const int v, const int lo, const int hi) { - if (hi < v) return hi; - if (v < lo) return lo; - return v; -} - -int StartForAxis(StridedSliceOperator const& op, Shape const& input_shape, - int axis) { - // Begin with the specified index - int start = op.start_indices[axis]; - - // begin_mask override - if (op.begin_mask & 1 << axis) { - if (op.strides[axis] > 0) { - // Forward iteration - use the first element. These values will get - // clamped below (Note: We could have set them to 0 and axis_size-1, but - // use lowest() and max() to maintain symmetry with StopForAxis()) - start = std::numeric_limits<int>::lowest(); - } else { - // Backward iteration - use the last element. - start = std::numeric_limits<int>::max(); - } - } - - // Handle negative indices - int axis_size = input_shape.dims(axis); - if (start < 0) { - start += axis_size; - } - - // Clamping - start = Clamp(start, 0, axis_size - 1); - - return start; -} - -int StopForAxis(StridedSliceOperator const& op, Shape const& input_shape, - int axis) { - // Begin with the specified index - int stop = op.stop_indices[axis]; - - // end_mask override - if (op.end_mask & (1 << axis)) { - if (op.strides[axis] > 0) { - // Forward iteration - use the last element. These values will get - // clamped below - stop = std::numeric_limits<int>::max(); - } else { - // Backward iteration - use the first element. - stop = std::numeric_limits<int>::lowest(); - } - } - - // Handle negative indices - int axis_size = input_shape.dims(axis); - if (stop < 0) { - stop += axis_size; - } - - // Clamping - // Because the end index points one past the last element, we need slightly - // different clamping ranges depending on the direction. - if (op.strides[axis] > 0) { - // Forward iteration - stop = Clamp(stop, 0, axis_size); - } else { - // Backward iteration - stop = Clamp(stop, -1, axis_size - 1); - } - - return stop; -} - void ProcessStridedSliceOperator(Model* model, StridedSliceOperator* op) { CHECK_GE(op->inputs.size(), 1); CHECK_EQ(op->outputs.size(), 1); @@ -1364,18 +1288,17 @@ void ProcessStridedSliceOperator(Model* model, StridedSliceOperator* op) { << " has stride=" << op->strides[i] << "."; } - // The TensorFlow documentation is not explicit on how it handles fewer - // supplied indices than dimensions, but they are accepted. We emulate TF's - // behavior by fully iterating over each "forgotten" dimension. - op->PadIndices(num_input_axes); - // Create output shape std::vector<int>* dims = output_array.mutable_shape()->mutable_dims(); // Compute output shape for (int axis = 0; axis < num_input_axes; ++axis) { - int start_index = StartForAxis(*op, input_array.shape(), axis); - int stop_index = StopForAxis(*op, input_array.shape(), axis); + int start_index = tflite::strided_slice::StartForAxis( + op->begin_mask, op->start_indices, op->strides, + input_array.shape().dims().data(), axis); + int stop_index = tflite::strided_slice::StopForAxis( + op->end_mask, op->stop_indices, op->strides, + input_array.shape().dims().data(), axis); int dim_size = ceil(static_cast<float>(stop_index - start_index) / op->strides[axis]); |