diff options
author | A. Unique TensorFlower <gardener@tensorflow.org> | 2018-04-17 13:20:42 -0700 |
---|---|---|
committer | TensorFlower Gardener <gardener@tensorflow.org> | 2018-04-17 13:23:19 -0700 |
commit | 83418120b7c2659fedddd7c85b65d3c3e6aa94e3 (patch) | |
tree | 3f8cbb0db3ee2c059a8eb99f6956b6a78f088497 /tensorflow/contrib/lite/toco/model.h | |
parent | 33d55d7caff2bd32fa2b1c5cacb7ac251c48e27d (diff) |
Fixing a bug in strided slice. The op was not handling negative indices correctly.
PiperOrigin-RevId: 193245539
Diffstat (limited to 'tensorflow/contrib/lite/toco/model.h')
-rw-r--r-- | tensorflow/contrib/lite/toco/model.h | 55 |
1 files changed, 55 insertions, 0 deletions
diff --git a/tensorflow/contrib/lite/toco/model.h b/tensorflow/contrib/lite/toco/model.h index 1c4c96ae70..705a9d69a6 100644 --- a/tensorflow/contrib/lite/toco/model.h +++ b/tensorflow/contrib/lite/toco/model.h @@ -24,6 +24,7 @@ limitations under the License. #include "tensorflow/contrib/lite/toco/model_flags.pb.h" #include "tensorflow/contrib/lite/toco/runtime/types.h" +#include "tensorflow/contrib/lite/toco/toco_port.h" #include "tensorflow/contrib/lite/toco/toco_types.h" #include "tensorflow/core/platform/logging.h" @@ -845,6 +846,60 @@ struct StridedSliceOperator : Operator { int end_mask; int new_axis_mask; int shrink_axis_mask; + + StridedSliceOperator(const StridedSliceOperator& other) + : Operator(OperatorType::kStridedSlice) { + inputs = other.inputs; + outputs = other.outputs; + + start_indices = other.start_indices; + stop_indices = other.stop_indices; + strides = other.strides; + + begin_mask = other.begin_mask; + ellipsis_mask = other.ellipsis_mask; + end_mask = other.end_mask; + new_axis_mask = other.new_axis_mask; + shrink_axis_mask = other.shrink_axis_mask; + } + + void PadIndices(int dim_count) { + // Add indices and mask bits to fully include extra dimensions + CHECK_GE(dim_count, start_indices.size()); + CHECK_EQ(start_indices.size(), stop_indices.size()); + CHECK_EQ(stop_indices.size(), strides.size()); + + for (int i = start_indices.size(); i < dim_count; i++) { + start_indices.push_back(0); + stop_indices.push_back(0); + strides.push_back(1); + begin_mask |= 1 << i; + end_mask |= 1 << i; + } + } + + void ReverseIndices() { + CHECK_EQ(start_indices.size(), stop_indices.size()); + CHECK_EQ(stop_indices.size(), strides.size()); + + std::reverse(start_indices.begin(), start_indices.end()); + std::reverse(stop_indices.begin(), stop_indices.end()); + std::reverse(strides.begin(), strides.end()); + + begin_mask = toco::port::ReverseBits32(static_cast<uint32>(begin_mask)) >> + (32 - start_indices.size()); + ellipsis_mask = + toco::port::ReverseBits32(static_cast<uint32>(ellipsis_mask)) >> + (32 - start_indices.size()); + end_mask = toco::port::ReverseBits32(static_cast<uint32>(end_mask)) >> + (32 - start_indices.size()); + new_axis_mask = + toco::port::ReverseBits32(static_cast<uint32>(new_axis_mask)) >> + (32 - start_indices.size()); + shrink_axis_mask = + toco::port::ReverseBits32(static_cast<uint32>(shrink_axis_mask)) >> + (32 - start_indices.size()); + } }; // Reshaping operator, reshaping its input array to a two-dimensional shape |