diff options
author | A. Unique TensorFlower <gardener@tensorflow.org> | 2018-07-18 14:12:13 -0700 |
---|---|---|
committer | TensorFlower Gardener <gardener@tensorflow.org> | 2018-07-18 14:15:24 -0700 |
commit | 4ca04537c0d1d75ea37944aa3bb2dc749428031a (patch) | |
tree | 663b7a6b36e560258dd95a603e0440a4c7b278ac /tensorflow/contrib/lite/toco/graph_transformations/propagate_fixed_sizes.cc | |
parent | 44af531d952a35c887770ecc4cfddfb0431c2478 (diff) |
Import/export support for Any, LogicalAnd, and LogicalNot ops.
PiperOrigin-RevId: 205134621
Diffstat (limited to 'tensorflow/contrib/lite/toco/graph_transformations/propagate_fixed_sizes.cc')
-rw-r--r-- | tensorflow/contrib/lite/toco/graph_transformations/propagate_fixed_sizes.cc | 68 |
1 files changed, 66 insertions, 2 deletions
diff --git a/tensorflow/contrib/lite/toco/graph_transformations/propagate_fixed_sizes.cc b/tensorflow/contrib/lite/toco/graph_transformations/propagate_fixed_sizes.cc index 4275ee9a03..62ed5c46e9 100644 --- a/tensorflow/contrib/lite/toco/graph_transformations/propagate_fixed_sizes.cc +++ b/tensorflow/contrib/lite/toco/graph_transformations/propagate_fixed_sizes.cc @@ -529,7 +529,7 @@ void ProcessAddNOperator(Model* model, Operator* op) { bool KeepDims(const Operator& op) { switch (op.type) { - case OperatorType::kMin: // Reduction Min + case OperatorType::kReduceMin: // Reduction Min return static_cast<const TensorFlowMinOperator&>(op).keep_dims; case OperatorType::kReduceMax: // Reduction Max return static_cast<const TensorFlowMaxOperator&>(op).keep_dims; @@ -1519,6 +1519,65 @@ void ProcessTileOperator(Model* model, TensorFlowTileOperator* op) { } } +void ProcessAnyOperator(Model* model, AnyOperator* op) { + CHECK_EQ(op->inputs.size(), 2); + CHECK_EQ(op->outputs.size(), 1); + + auto& output_array = model->GetArray(op->outputs[0]); + if (output_array.has_shape()) { + // We have already run. + return; + } + + const auto& input_array = model->GetArray(op->inputs[0]); + if (!input_array.has_shape()) { + // Yield until input dims have been resolved. + return; + } + const auto& input_shape = input_array.shape(); + + auto& reduction_indices_array = model->GetArray(op->inputs[1]); + if (!reduction_indices_array.has_shape()) { + // Yield until reduction indices shape been resolved. + return; + } + if (!reduction_indices_array.buffer) { + // Yield until the reduction indices are constant. + return; + } + CHECK(reduction_indices_array.data_type == ArrayDataType::kInt32) + << "Any reduction input must be int32"; + + int input_rank = input_shape.dimensions_count(); + std::set<int32> true_indices; + const auto& reduction_indices = + reduction_indices_array.GetBuffer<ArrayDataType::kInt32>().data; + for (int i = 0; i < reduction_indices.size(); ++i) { + const int32 reduction_index = reduction_indices[i]; + if (reduction_index < -input_rank || reduction_index >= input_rank) { + CHECK(false) << "Invalid reduction dimension " << reduction_index + << " for input with " << input_rank << " dimensions"; + } + int32 wrapped_index = reduction_index; + if (wrapped_index < 0) { + wrapped_index += input_rank; + } + true_indices.insert(wrapped_index); + } + + auto* mutable_dims = output_array.mutable_shape()->mutable_dims(); + mutable_dims->clear(); + for (int i = 0; i < input_rank; ++i) { + if (true_indices.count(i) > 0) { + if (op->keep_dims) { + mutable_dims->emplace_back(1); + } + } else { + mutable_dims->emplace_back(input_shape.dims(i)); + } + } +} + } // namespace bool PropagateFixedSizes::Run(Model* model, std::size_t op_index) { @@ -1557,6 +1616,8 @@ bool PropagateFixedSizes::Run(Model* model, std::size_t op_index) { case OperatorType::kFloor: case OperatorType::kExp: case OperatorType::kSin: + case OperatorType::kLogicalAnd: + case OperatorType::kLogicalNot: ProcessSimpleOperator(model, op, 0); break; case OperatorType::kGather: @@ -1625,7 +1686,7 @@ bool PropagateFixedSizes::Run(Model* model, std::size_t op_index) { case OperatorType::kL2Pool: ProcessL2PoolOperator(model, static_cast<L2PoolOperator*>(op)); break; - case OperatorType::kMin: // Reduction Min + case OperatorType::kReduceMin: // Reduction Min case OperatorType::kReduceMax: // Reduction Max case OperatorType::kSum: case OperatorType::kReduceProd: @@ -1750,6 +1811,9 @@ bool PropagateFixedSizes::Run(Model* model, std::size_t op_index) { case OperatorType::kTile: ProcessTileOperator(model, static_cast<TensorFlowTileOperator*>(op)); break; + case OperatorType::kAny: + ProcessAnyOperator(model, static_cast<AnyOperator*>(op)); + break; default: // Unimplemented, another graph transformation should drop it. LOG(FATAL) << "Unhandled operator type " << OperatorTypeName(op->type); |