diff options
author | A. Unique TensorFlower <gardener@tensorflow.org> | 2018-08-31 16:05:57 -0700 |
---|---|---|
committer | TensorFlower Gardener <gardener@tensorflow.org> | 2018-08-31 16:14:33 -0700 |
commit | 0fffc1bde7f85c3c5b985bf30500c53ace6c81eb (patch) | |
tree | 23631c6082ea8b32c7e302bd2a79b96da510288a | |
parent | 2681adc66ceb0530ea2f0f50c11938fba32caa3e (diff) |
Fixing Any operator.
PiperOrigin-RevId: 211159438
4 files changed, 37 insertions, 20 deletions
diff --git a/tensorflow/contrib/lite/toco/export_tensorflow.cc b/tensorflow/contrib/lite/toco/export_tensorflow.cc index 6fdf47dedc..b52a79282c 100644 --- a/tensorflow/contrib/lite/toco/export_tensorflow.cc +++ b/tensorflow/contrib/lite/toco/export_tensorflow.cc @@ -1701,9 +1701,11 @@ void ConvertReduceOperator(const Model& model, const T& src_op, *new_op->add_input() = src_op.inputs[0]; *new_op->add_input() = src_op.inputs[1]; - const tensorflow::DataType params_type = - GetTensorFlowDataType(model, src_op.inputs[0]); - (*new_op->mutable_attr())["T"].set_type(params_type); + if (src_op.type != OperatorType::kAny) { + const tensorflow::DataType params_type = + GetTensorFlowDataType(model, src_op.inputs[0]); + (*new_op->mutable_attr())["T"].set_type(params_type); + } const tensorflow::DataType indices_type = GetTensorFlowDataType(model, src_op.inputs[1]); (*new_op->mutable_attr())["Tidx"].set_type(indices_type); diff --git a/tensorflow/contrib/lite/toco/graph_transformations/propagate_fixed_sizes.cc b/tensorflow/contrib/lite/toco/graph_transformations/propagate_fixed_sizes.cc index 28effc2a67..c25be078ff 100644 --- a/tensorflow/contrib/lite/toco/graph_transformations/propagate_fixed_sizes.cc +++ b/tensorflow/contrib/lite/toco/graph_transformations/propagate_fixed_sizes.cc @@ -561,26 +561,38 @@ void ProcessTensorFlowReductionOperator(Model* model, Operator* op) { const bool keep_dims = KeepDims(*op); if (op->inputs.size() == 2) { // There is a reduction_indices input. - const auto& reduction_array = model->GetArray(op->inputs[1]); - if (!reduction_array.buffer) { + const auto& reduction_indices_array = model->GetArray(op->inputs[1]); + if (!reduction_indices_array.buffer) { return; } - CHECK(reduction_array.buffer->type == ArrayDataType::kInt32); - const auto& reduction_array_vals = - reduction_array.GetBuffer<ArrayDataType::kInt32>().data; - auto& output_dims = *output_array.mutable_shape()->mutable_dims(); - output_dims.clear(); - for (int i = 0; i < input_shape.dimensions_count(); i++) { - bool is_reduction_dim = false; - for (int r : reduction_array_vals) { - if (i == r) { - is_reduction_dim = true; - } + CHECK(reduction_indices_array.buffer->type == ArrayDataType::kInt32); + + int input_rank = input_shape.dimensions_count(); + std::set<int32> true_indices; + const auto& reduction_indices = + reduction_indices_array.GetBuffer<ArrayDataType::kInt32>().data; + for (int i = 0; i < reduction_indices.size(); ++i) { + const int32 reduction_index = reduction_indices[i]; + if (reduction_index < -input_rank || reduction_index >= input_rank) { + CHECK(false) << "Invalid reduction dimension " << reduction_index + << " for input with " << input_rank << " dimensions"; + } + int32 wrapped_index = reduction_index; + if (wrapped_index < 0) { + wrapped_index += input_rank; } - if (!is_reduction_dim) { - output_dims.push_back(input_shape.dims(i)); - } else if (keep_dims) { - output_dims.push_back(1); + true_indices.insert(wrapped_index); + } + + auto* mutable_dims = output_array.mutable_shape()->mutable_dims(); + mutable_dims->clear(); + for (int i = 0; i < input_rank; ++i) { + if (true_indices.count(i) > 0) { + if (keep_dims) { + mutable_dims->emplace_back(1); + } + } else { + mutable_dims->emplace_back(input_shape.dims(i)); } } } else { diff --git a/tensorflow/contrib/lite/toco/graph_transformations/resolve_reduce_attributes.cc b/tensorflow/contrib/lite/toco/graph_transformations/resolve_reduce_attributes.cc index 7d456af2fb..73198ac7c0 100644 --- a/tensorflow/contrib/lite/toco/graph_transformations/resolve_reduce_attributes.cc +++ b/tensorflow/contrib/lite/toco/graph_transformations/resolve_reduce_attributes.cc @@ -52,6 +52,8 @@ bool ResolveReduceAttributes::Run(Model* model, std::size_t op_index) { return ResolveAttributes(model, static_cast<TensorFlowMinOperator*>(op)); case OperatorType::kReduceMax: return ResolveAttributes(model, static_cast<TensorFlowMaxOperator*>(op)); + case OperatorType::kAny: + return ResolveAttributes(model, static_cast<TensorFlowMaxOperator*>(op)); default: return false; } diff --git a/tensorflow/contrib/lite/toco/model.h b/tensorflow/contrib/lite/toco/model.h index fa1c459f0e..2e100e37f6 100644 --- a/tensorflow/contrib/lite/toco/model.h +++ b/tensorflow/contrib/lite/toco/model.h @@ -1768,6 +1768,7 @@ struct PowOperator : Operator { // // Inputs: // Inputs[0]: required: A boolean input tensor. +// Inputs[1]: required: reduction_indices. // // TensorFlow equivalent: tf.reduce_any. struct TensorFlowAnyOperator : Operator { |