diff options
author | A. Unique TensorFlower <gardener@tensorflow.org> | 2018-08-28 01:52:59 -0700 |
---|---|---|
committer | TensorFlower Gardener <gardener@tensorflow.org> | 2018-08-28 01:57:31 -0700 |
commit | 3e13ae966115b1aaf793601b0647b40efb25a2da (patch) | |
tree | 50aa649f843a698d23c99a4d6ef9c5adc6752895 /tensorflow/contrib/lite/toco/graph_transformations/propagate_fixed_sizes.cc | |
parent | f255b51c6e637ac7701996b4457157d3c313dca4 (diff) |
Implementation of reduce_any.
PiperOrigin-RevId: 210507220
Diffstat (limited to 'tensorflow/contrib/lite/toco/graph_transformations/propagate_fixed_sizes.cc')
-rw-r--r-- | tensorflow/contrib/lite/toco/graph_transformations/propagate_fixed_sizes.cc | 64 |
1 files changed, 3 insertions, 61 deletions
diff --git a/tensorflow/contrib/lite/toco/graph_transformations/propagate_fixed_sizes.cc b/tensorflow/contrib/lite/toco/graph_transformations/propagate_fixed_sizes.cc index fa2be961f5..28effc2a67 100644 --- a/tensorflow/contrib/lite/toco/graph_transformations/propagate_fixed_sizes.cc +++ b/tensorflow/contrib/lite/toco/graph_transformations/propagate_fixed_sizes.cc @@ -539,6 +539,8 @@ bool KeepDims(const Operator& op) { return static_cast<const TensorFlowProdOperator&>(op).keep_dims; case OperatorType::kMean: return static_cast<const MeanOperator&>(op).keep_dims; + case OperatorType::kAny: + return static_cast<const TensorFlowAnyOperator&>(op).keep_dims; default: LOG(FATAL) << "Not a reduction operator!"; return false; @@ -1515,65 +1517,6 @@ void ProcessTileOperator(Model* model, TensorFlowTileOperator* op) { } } -void ProcessAnyOperator(Model* model, AnyOperator* op) { - CHECK_EQ(op->inputs.size(), 2); - CHECK_EQ(op->outputs.size(), 1); - - auto& output_array = model->GetArray(op->outputs[0]); - if (output_array.has_shape()) { - // We have already run. - return; - } - - const auto& input_array = model->GetArray(op->inputs[0]); - if (!input_array.has_shape()) { - // Yield until input dims have been resolved. - return; - } - const auto& input_shape = input_array.shape(); - - auto& reduction_indices_array = model->GetArray(op->inputs[1]); - if (!reduction_indices_array.has_shape()) { - // Yield until reduction indices shape been resolved. - return; - } - if (!reduction_indices_array.buffer) { - // Yield until the reduction indices are constant. - return; - } - CHECK(reduction_indices_array.data_type == ArrayDataType::kInt32) - << "Any reduction input must be int32"; - - int input_rank = input_shape.dimensions_count(); - std::set<int32> true_indices; - const auto& reduction_indices = - reduction_indices_array.GetBuffer<ArrayDataType::kInt32>().data; - for (int i = 0; i < reduction_indices.size(); ++i) { - const int32 reduction_index = reduction_indices[i]; - if (reduction_index < -input_rank || reduction_index >= input_rank) { - CHECK(false) << "Invalid reduction dimension " << reduction_index - << " for input with " << input_rank << " dimensions"; - } - int32 wrapped_index = reduction_index; - if (wrapped_index < 0) { - wrapped_index += input_rank; - } - true_indices.insert(wrapped_index); - } - - auto* mutable_dims = output_array.mutable_shape()->mutable_dims(); - mutable_dims->clear(); - for (int i = 0; i < input_rank; ++i) { - if (true_indices.count(i) > 0) { - if (op->keep_dims) { - mutable_dims->emplace_back(1); - } - } else { - mutable_dims->emplace_back(input_shape.dims(i)); - } - } -} - void ProcessOneHotOperator(Model* model, OneHotOperator* op) { CHECK_EQ(op->inputs.size(), 4); CHECK_EQ(op->outputs.size(), 1); @@ -1769,6 +1712,7 @@ bool PropagateFixedSizes::Run(Model* model, std::size_t op_index) { case OperatorType::kSum: case OperatorType::kReduceProd: case OperatorType::kMean: + case OperatorType::kAny: ProcessTensorFlowReductionOperator(model, op); break; case OperatorType::kSelect: @@ -1900,8 +1844,6 @@ bool PropagateFixedSizes::Run(Model* model, std::size_t op_index) { case OperatorType::kTile: ProcessTileOperator(model, static_cast<TensorFlowTileOperator*>(op)); break; - case OperatorType::kAny: - ProcessAnyOperator(model, static_cast<AnyOperator*>(op)); break; case OperatorType::kOneHot: ProcessOneHotOperator(model, static_cast<OneHotOperator*>(op)); |