author     A. Unique TensorFlower <gardener@tensorflow.org>  2018-08-28 01:52:59 -0700
committer  TensorFlower Gardener <gardener@tensorflow.org>   2018-08-28 01:57:31 -0700
commit     3e13ae966115b1aaf793601b0647b40efb25a2da
tree       50aa649f843a698d23c99a4d6ef9c5adc6752895  /tensorflow/contrib/lite/toco/graph_transformations/propagate_fixed_sizes.cc
parent     f255b51c6e637ac7701996b4457157d3c313dca4
Implementation of reduce_any.
PiperOrigin-RevId: 210507220
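
The change routes OperatorType::kAny through the shared ProcessTensorFlowReductionOperator path and extends KeepDims accordingly, replacing the per-op ProcessAnyOperator shown removed below. As a rough illustration of the shape rule involved (negative axes wrap around; reduced dimensions are dropped, or kept as size 1 when keep_dims is set), here is a minimal standalone C++ sketch; ReducedShape and its signature are illustrative only, not TOCO code.

#include <cstdint>
#include <set>
#include <vector>

// Illustrative sketch, not TOCO code: output shape of a reduction such as
// reduce_any over `input_shape`, reducing over `axes`. This mirrors the
// logic of the removed ProcessAnyOperator below.
std::vector<int> ReducedShape(const std::vector<int>& input_shape,
                              const std::vector<int32_t>& axes,
                              bool keep_dims) {
  const int rank = static_cast<int>(input_shape.size());
  std::set<int> reduced;
  for (int32_t axis : axes) {
    reduced.insert(axis < 0 ? axis + rank : axis);  // wrap negative axes
  }
  std::vector<int> output;
  for (int i = 0; i < rank; ++i) {
    if (reduced.count(i) > 0) {
      if (keep_dims) output.push_back(1);  // reduced dim kept as size 1
    } else {
      output.push_back(input_shape[i]);  // untouched dims pass through
    }
  }
  return output;
}

// Example: a [2, 3, 4] input reduced over axis -1 yields [2, 3],
// or [2, 3, 1] with keep_dims.
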
Diffstat (limited to 'tensorflow/contrib/lite/toco/graph_transformations/propagate_fixed_sizes.cc')
-rw-r--r--  tensorflow/contrib/lite/toco/graph_transformations/propagate_fixed_sizes.cc  64
1 file changed, 3 insertions(+), 61 deletions(-)
diff --git a/tensorflow/contrib/lite/toco/graph_transformations/propagate_fixed_sizes.cc b/tensorflow/contrib/lite/toco/graph_transformations/propagate_fixed_sizes.cc
index fa2be961f5..28effc2a67 100644
--- a/tensorflow/contrib/lite/toco/graph_transformations/propagate_fixed_sizes.cc
+++ b/tensorflow/contrib/lite/toco/graph_transformations/propagate_fixed_sizes.cc
@@ -539,6 +539,8 @@ bool KeepDims(const Operator& op) {
       return static_cast<const TensorFlowProdOperator&>(op).keep_dims;
     case OperatorType::kMean:
       return static_cast<const MeanOperator&>(op).keep_dims;
+    case OperatorType::kAny:
+      return static_cast<const TensorFlowAnyOperator&>(op).keep_dims;
     default:
       LOG(FATAL) << "Not a reduction operator!";
       return false;
@@ -1515,65 +1517,6 @@ void ProcessTileOperator(Model* model, TensorFlowTileOperator* op) {
   }
 }
 
-void ProcessAnyOperator(Model* model, AnyOperator* op) {
-  CHECK_EQ(op->inputs.size(), 2);
-  CHECK_EQ(op->outputs.size(), 1);
-
-  auto& output_array = model->GetArray(op->outputs[0]);
-  if (output_array.has_shape()) {
-    // We have already run.
-    return;
-  }
-
-  const auto& input_array = model->GetArray(op->inputs[0]);
-  if (!input_array.has_shape()) {
-    // Yield until input dims have been resolved.
-    return;
-  }
-  const auto& input_shape = input_array.shape();
-
-  auto& reduction_indices_array = model->GetArray(op->inputs[1]);
-  if (!reduction_indices_array.has_shape()) {
-    // Yield until reduction indices shape been resolved.
-    return;
-  }
-  if (!reduction_indices_array.buffer) {
-    // Yield until the reduction indices are constant.
-    return;
-  }
-  CHECK(reduction_indices_array.data_type == ArrayDataType::kInt32)
-      << "Any reduction input must be int32";
-
-  int input_rank = input_shape.dimensions_count();
-  std::set<int32> true_indices;
-  const auto& reduction_indices =
-      reduction_indices_array.GetBuffer<ArrayDataType::kInt32>().data;
-  for (int i = 0; i < reduction_indices.size(); ++i) {
-    const int32 reduction_index = reduction_indices[i];
-    if (reduction_index < -input_rank || reduction_index >= input_rank) {
-      CHECK(false) << "Invalid reduction dimension " << reduction_index
-                   << " for input with " << input_rank << " dimensions";
-    }
-    int32 wrapped_index = reduction_index;
-    if (wrapped_index < 0) {
-      wrapped_index += input_rank;
-    }
-    true_indices.insert(wrapped_index);
-  }
-
-  auto* mutable_dims = output_array.mutable_shape()->mutable_dims();
-  mutable_dims->clear();
-  for (int i = 0; i < input_rank; ++i) {
-    if (true_indices.count(i) > 0) {
-      if (op->keep_dims) {
-        mutable_dims->emplace_back(1);
-      }
-    } else {
-      mutable_dims->emplace_back(input_shape.dims(i));
-    }
-  }
-}
-
 void ProcessOneHotOperator(Model* model, OneHotOperator* op) {
   CHECK_EQ(op->inputs.size(), 4);
   CHECK_EQ(op->outputs.size(), 1);
@@ -1769,6 +1712,7 @@ bool PropagateFixedSizes::Run(Model* model, std::size_t op_index) {
     case OperatorType::kSum:
     case OperatorType::kReduceProd:
     case OperatorType::kMean:
+    case OperatorType::kAny:
       ProcessTensorFlowReductionOperator(model, op);
       break;
     case OperatorType::kSelect:
@@ -1900,8 +1844,6 @@ bool PropagateFixedSizes::Run(Model* model, std::size_t op_index) {
     case OperatorType::kTile:
       ProcessTileOperator(model, static_cast<TensorFlowTileOperator*>(op));
       break;
-    case OperatorType::kAny:
-      ProcessAnyOperator(model, static_cast<AnyOperator*>(op));
       break;
     case OperatorType::kOneHot:
       ProcessOneHotOperator(model, static_cast<OneHotOperator*>(op));