aboutsummaryrefslogtreecommitdiffhomepage
path: root/tensorflow/contrib/lite/toco/model.h
diff options
context:
space:
mode:
authorGravatar A. Unique TensorFlower <gardener@tensorflow.org>2018-04-17 13:20:42 -0700
committerGravatar TensorFlower Gardener <gardener@tensorflow.org>2018-04-17 13:23:19 -0700
commit83418120b7c2659fedddd7c85b65d3c3e6aa94e3 (patch)
tree3f8cbb0db3ee2c059a8eb99f6956b6a78f088497 /tensorflow/contrib/lite/toco/model.h
parent33d55d7caff2bd32fa2b1c5cacb7ac251c48e27d (diff)
Fixing a bug in strided slice. The op was not handling negative indices correctly.
PiperOrigin-RevId: 193245539
Diffstat (limited to 'tensorflow/contrib/lite/toco/model.h')
-rw-r--r--tensorflow/contrib/lite/toco/model.h55
1 files changed, 55 insertions, 0 deletions
diff --git a/tensorflow/contrib/lite/toco/model.h b/tensorflow/contrib/lite/toco/model.h
index 1c4c96ae70..705a9d69a6 100644
--- a/tensorflow/contrib/lite/toco/model.h
+++ b/tensorflow/contrib/lite/toco/model.h
@@ -24,6 +24,7 @@ limitations under the License.
#include "tensorflow/contrib/lite/toco/model_flags.pb.h"
#include "tensorflow/contrib/lite/toco/runtime/types.h"
+#include "tensorflow/contrib/lite/toco/toco_port.h"
#include "tensorflow/contrib/lite/toco/toco_types.h"
#include "tensorflow/core/platform/logging.h"
@@ -845,6 +846,60 @@ struct StridedSliceOperator : Operator {
int end_mask;
int new_axis_mask;
int shrink_axis_mask;
+
+ StridedSliceOperator(const StridedSliceOperator& other)
+ : Operator(OperatorType::kStridedSlice) {
+ inputs = other.inputs;
+ outputs = other.outputs;
+
+ start_indices = other.start_indices;
+ stop_indices = other.stop_indices;
+ strides = other.strides;
+
+ begin_mask = other.begin_mask;
+ ellipsis_mask = other.ellipsis_mask;
+ end_mask = other.end_mask;
+ new_axis_mask = other.new_axis_mask;
+ shrink_axis_mask = other.shrink_axis_mask;
+ }
+
+ void PadIndices(int dim_count) {
+ // Add indices and mask bits to fully include extra dimensions
+ CHECK_GE(dim_count, start_indices.size());
+ CHECK_EQ(start_indices.size(), stop_indices.size());
+ CHECK_EQ(stop_indices.size(), strides.size());
+
+ for (int i = start_indices.size(); i < dim_count; i++) {
+ start_indices.push_back(0);
+ stop_indices.push_back(0);
+ strides.push_back(1);
+ begin_mask |= 1 << i;
+ end_mask |= 1 << i;
+ }
+ }
+
+ void ReverseIndices() {
+ CHECK_EQ(start_indices.size(), stop_indices.size());
+ CHECK_EQ(stop_indices.size(), strides.size());
+
+ std::reverse(start_indices.begin(), start_indices.end());
+ std::reverse(stop_indices.begin(), stop_indices.end());
+ std::reverse(strides.begin(), strides.end());
+
+ begin_mask = toco::port::ReverseBits32(static_cast<uint32>(begin_mask)) >>
+ (32 - start_indices.size());
+ ellipsis_mask =
+ toco::port::ReverseBits32(static_cast<uint32>(ellipsis_mask)) >>
+ (32 - start_indices.size());
+ end_mask = toco::port::ReverseBits32(static_cast<uint32>(end_mask)) >>
+ (32 - start_indices.size());
+ new_axis_mask =
+ toco::port::ReverseBits32(static_cast<uint32>(new_axis_mask)) >>
+ (32 - start_indices.size());
+ shrink_axis_mask =
+ toco::port::ReverseBits32(static_cast<uint32>(shrink_axis_mask)) >>
+ (32 - start_indices.size());
+ }
};
// Reshaping operator, reshaping its input array to a two-dimensional shape