diff options
author | A. Unique TensorFlower <gardener@tensorflow.org> | 2018-08-31 16:05:57 -0700 |
---|---|---|
committer | TensorFlower Gardener <gardener@tensorflow.org> | 2018-08-31 16:14:33 -0700 |
commit | 0fffc1bde7f85c3c5b985bf30500c53ace6c81eb (patch) | |
tree | 23631c6082ea8b32c7e302bd2a79b96da510288a /tensorflow/contrib/lite/toco/graph_transformations/propagate_fixed_sizes.cc | |
parent | 2681adc66ceb0530ea2f0f50c11938fba32caa3e (diff) |
Fixing Any operator.
PiperOrigin-RevId: 211159438
Diffstat (limited to 'tensorflow/contrib/lite/toco/graph_transformations/propagate_fixed_sizes.cc')
-rw-r--r-- | tensorflow/contrib/lite/toco/graph_transformations/propagate_fixed_sizes.cc | 46 |
1 files changed, 29 insertions, 17 deletions
diff --git a/tensorflow/contrib/lite/toco/graph_transformations/propagate_fixed_sizes.cc b/tensorflow/contrib/lite/toco/graph_transformations/propagate_fixed_sizes.cc index 28effc2a67..c25be078ff 100644 --- a/tensorflow/contrib/lite/toco/graph_transformations/propagate_fixed_sizes.cc +++ b/tensorflow/contrib/lite/toco/graph_transformations/propagate_fixed_sizes.cc @@ -561,26 +561,38 @@ void ProcessTensorFlowReductionOperator(Model* model, Operator* op) { const bool keep_dims = KeepDims(*op); if (op->inputs.size() == 2) { // There is a reduction_indices input. - const auto& reduction_array = model->GetArray(op->inputs[1]); - if (!reduction_array.buffer) { + const auto& reduction_indices_array = model->GetArray(op->inputs[1]); + if (!reduction_indices_array.buffer) { return; } - CHECK(reduction_array.buffer->type == ArrayDataType::kInt32); - const auto& reduction_array_vals = - reduction_array.GetBuffer<ArrayDataType::kInt32>().data; - auto& output_dims = *output_array.mutable_shape()->mutable_dims(); - output_dims.clear(); - for (int i = 0; i < input_shape.dimensions_count(); i++) { - bool is_reduction_dim = false; - for (int r : reduction_array_vals) { - if (i == r) { - is_reduction_dim = true; - } + CHECK(reduction_indices_array.buffer->type == ArrayDataType::kInt32); + + int input_rank = input_shape.dimensions_count(); + std::set<int32> true_indices; + const auto& reduction_indices = + reduction_indices_array.GetBuffer<ArrayDataType::kInt32>().data; + for (int i = 0; i < reduction_indices.size(); ++i) { + const int32 reduction_index = reduction_indices[i]; + if (reduction_index < -input_rank || reduction_index >= input_rank) { + CHECK(false) << "Invalid reduction dimension " << reduction_index + << " for input with " << input_rank << " dimensions"; + } + int32 wrapped_index = reduction_index; + if (wrapped_index < 0) { + wrapped_index += input_rank; } - if (!is_reduction_dim) { - output_dims.push_back(input_shape.dims(i)); - } else if (keep_dims) { - output_dims.push_back(1); + true_indices.insert(wrapped_index); + } + + auto* mutable_dims = output_array.mutable_shape()->mutable_dims(); + mutable_dims->clear(); + for (int i = 0; i < input_rank; ++i) { + if (true_indices.count(i) > 0) { + if (keep_dims) { + mutable_dims->emplace_back(1); + } + } else { + mutable_dims->emplace_back(input_shape.dims(i)); } } } else { |