aboutsummaryrefslogtreecommitdiffhomepage
path: root/tensorflow/contrib/lite/toco/graph_transformations/propagate_fixed_sizes.cc
diff options
context:
space:
mode:
authorGravatar A. Unique TensorFlower <gardener@tensorflow.org>2018-08-31 16:05:57 -0700
committerGravatar TensorFlower Gardener <gardener@tensorflow.org>2018-08-31 16:14:33 -0700
commit0fffc1bde7f85c3c5b985bf30500c53ace6c81eb (patch)
tree23631c6082ea8b32c7e302bd2a79b96da510288a /tensorflow/contrib/lite/toco/graph_transformations/propagate_fixed_sizes.cc
parent2681adc66ceb0530ea2f0f50c11938fba32caa3e (diff)
Fixing Any operator.
PiperOrigin-RevId: 211159438
Diffstat (limited to 'tensorflow/contrib/lite/toco/graph_transformations/propagate_fixed_sizes.cc')
-rw-r--r--tensorflow/contrib/lite/toco/graph_transformations/propagate_fixed_sizes.cc46
1 files changed, 29 insertions, 17 deletions
diff --git a/tensorflow/contrib/lite/toco/graph_transformations/propagate_fixed_sizes.cc b/tensorflow/contrib/lite/toco/graph_transformations/propagate_fixed_sizes.cc
index 28effc2a67..c25be078ff 100644
--- a/tensorflow/contrib/lite/toco/graph_transformations/propagate_fixed_sizes.cc
+++ b/tensorflow/contrib/lite/toco/graph_transformations/propagate_fixed_sizes.cc
@@ -561,26 +561,38 @@ void ProcessTensorFlowReductionOperator(Model* model, Operator* op) {
const bool keep_dims = KeepDims(*op);
if (op->inputs.size() == 2) {
// There is a reduction_indices input.
- const auto& reduction_array = model->GetArray(op->inputs[1]);
- if (!reduction_array.buffer) {
+ const auto& reduction_indices_array = model->GetArray(op->inputs[1]);
+ if (!reduction_indices_array.buffer) {
return;
}
- CHECK(reduction_array.buffer->type == ArrayDataType::kInt32);
- const auto& reduction_array_vals =
- reduction_array.GetBuffer<ArrayDataType::kInt32>().data;
- auto& output_dims = *output_array.mutable_shape()->mutable_dims();
- output_dims.clear();
- for (int i = 0; i < input_shape.dimensions_count(); i++) {
- bool is_reduction_dim = false;
- for (int r : reduction_array_vals) {
- if (i == r) {
- is_reduction_dim = true;
- }
+ CHECK(reduction_indices_array.buffer->type == ArrayDataType::kInt32);
+
+ int input_rank = input_shape.dimensions_count();
+ std::set<int32> true_indices;
+ const auto& reduction_indices =
+ reduction_indices_array.GetBuffer<ArrayDataType::kInt32>().data;
+ for (int i = 0; i < reduction_indices.size(); ++i) {
+ const int32 reduction_index = reduction_indices[i];
+ if (reduction_index < -input_rank || reduction_index >= input_rank) {
+ CHECK(false) << "Invalid reduction dimension " << reduction_index
+ << " for input with " << input_rank << " dimensions";
+ }
+ int32 wrapped_index = reduction_index;
+ if (wrapped_index < 0) {
+ wrapped_index += input_rank;
}
- if (!is_reduction_dim) {
- output_dims.push_back(input_shape.dims(i));
- } else if (keep_dims) {
- output_dims.push_back(1);
+ true_indices.insert(wrapped_index);
+ }
+
+ auto* mutable_dims = output_array.mutable_shape()->mutable_dims();
+ mutable_dims->clear();
+ for (int i = 0; i < input_rank; ++i) {
+ if (true_indices.count(i) > 0) {
+ if (keep_dims) {
+ mutable_dims->emplace_back(1);
+ }
+ } else {
+ mutable_dims->emplace_back(input_shape.dims(i));
}
}
} else {