diff options
Diffstat (limited to 'tensorflow/contrib/lite/toco/graph_transformations/propagate_fixed_sizes.cc')
-rw-r--r-- | tensorflow/contrib/lite/toco/graph_transformations/propagate_fixed_sizes.cc | 153 |
1 files changed, 126 insertions, 27 deletions
diff --git a/tensorflow/contrib/lite/toco/graph_transformations/propagate_fixed_sizes.cc b/tensorflow/contrib/lite/toco/graph_transformations/propagate_fixed_sizes.cc index 82b3ab96fe..a03b589bae 100644 --- a/tensorflow/contrib/lite/toco/graph_transformations/propagate_fixed_sizes.cc +++ b/tensorflow/contrib/lite/toco/graph_transformations/propagate_fixed_sizes.cc @@ -437,6 +437,7 @@ void ProcessTensorFlowReshapeOperator(Model* model, product_non_wildcard_dims *= shape_data[i]; } } + const int input_flat_size = RequiredBufferSizeForShape(input_shape); if (has_wildcard) { CHECK_GE(input_flat_size, product_non_wildcard_dims) @@ -445,6 +446,12 @@ void ProcessTensorFlowReshapeOperator(Model* model, << op->outputs[0] << "\". Are your input shapes correct?"; shape_data[wildcard_index] = input_flat_size / product_non_wildcard_dims; } + + if (shape_data.size() == 1 && shape_data[0] == 0) { + // We have reshaped a scalar, so preserve as a scalar. + shape_data.clear(); + } + auto& output_shape = *output_array.mutable_shape(); *output_shape.mutable_dims() = shape_data; CHECK_EQ(input_flat_size, RequiredBufferSizeForShape(output_shape)) @@ -522,12 +529,14 @@ void ProcessAddNOperator(Model* model, Operator* op) { bool KeepDims(const Operator& op) { switch (op.type) { - case OperatorType::kMin: // Reduction Min + case OperatorType::kReduceMin: // Reduction Min return static_cast<const TensorFlowMinOperator&>(op).keep_dims; - case OperatorType::kMax: // Reduction Max + case OperatorType::kReduceMax: // Reduction Max return static_cast<const TensorFlowMaxOperator&>(op).keep_dims; case OperatorType::kSum: return static_cast<const TensorFlowSumOperator&>(op).keep_dims; + case OperatorType::kReduceProd: + return static_cast<const TensorFlowProdOperator&>(op).keep_dims; case OperatorType::kMean: return static_cast<const MeanOperator&>(op).keep_dims; default: @@ -1034,20 +1043,28 @@ void ProcessGatherOperator(Model* model, GatherOperator* op) { return; } + // Yield until the axis has been resolved. + if (!op->axis) { + return; + } + int axis = op->axis.value(); + const auto& input_shape = input_array.shape(); const auto& indices_shape = indices_array.shape(); QCHECK_GE(input_shape.dimensions_count(), 1); op->input_rank = input_shape.dimensions_count(); + QCHECK_LT(axis, op->input_rank); - // We only support 1-D indices. - QCHECK_EQ(indices_shape.dimensions_count(), 1); - - // Copy the input dimensions to the output except for dimension 0, + // Copy the input dimensions to the output except for the axis dimensions // where the dimension of indices_shape is used. - // TODO(mgubin): if axis != 0 this is not true, change when it's supported. auto output_dims = output_array.mutable_shape()->mutable_dims(); - output_dims->push_back(indices_shape.dims(0)); - for (int dim = 1; dim < input_shape.dimensions_count(); dim++) { + for (int dim = 0; dim < axis; ++dim) { + output_dims->push_back(input_shape.dims(dim)); + } + for (int dim = 0; dim < indices_shape.dimensions_count(); ++dim) { + output_dims->push_back(indices_shape.dims(dim)); + } + for (int dim = axis + 1; dim < input_shape.dimensions_count(); ++dim) { output_dims->push_back(input_shape.dims(dim)); } } @@ -1193,7 +1210,7 @@ void ProcessShapeOperator(Model* model, TensorFlowShapeOperator* op) { output_shape->ReplaceDims({input_array.shape().dimensions_count()}); } -void ProcessStackOperator(Model* model, StackOperator* op) { +void ProcessPackOperator(Model* model, PackOperator* op) { CHECK_GE(op->inputs.size(), 1); CHECK_EQ(op->outputs.size(), 1); auto& output_array = model->GetArray(op->outputs[0]); @@ -1202,7 +1219,7 @@ void ProcessStackOperator(Model* model, StackOperator* op) { return; } - std::unique_ptr<Shape> stacked_shape; + std::unique_ptr<Shape> packed_shape; for (const auto& input : op->inputs) { const auto& input_array = model->GetArray(input); if (!input_array.has_shape()) { @@ -1211,23 +1228,23 @@ void ProcessStackOperator(Model* model, StackOperator* op) { } Shape shape = input_array.shape(); - if (!stacked_shape) { - stacked_shape.reset(new Shape(shape)); + if (!packed_shape) { + packed_shape.reset(new Shape(shape)); } else { - CHECK(*stacked_shape == shape) << "All input arrays to Stack operators " - "must have the same shape. Input \"" - << input << "\" is different."; + CHECK(*packed_shape == shape) << "All input arrays to Pack operators " + "must have the same shape. Input \"" + << input << "\" is different."; } } int axis = op->axis; if (axis < 0) { // Handle negative axis - axis += stacked_shape->dims().size() + 1; + axis += packed_shape->dims().size() + 1; } - stacked_shape->mutable_dims()->insert( - stacked_shape->mutable_dims()->begin() + axis, op->inputs.size()); - output_array.copy_shape(*stacked_shape); + packed_shape->mutable_dims()->insert( + packed_shape->mutable_dims()->begin() + axis, op->inputs.size()); + output_array.copy_shape(*packed_shape); } void ProcessStridedSliceOperator(Model* model, StridedSliceOperator* op) { @@ -1407,7 +1424,8 @@ void ProcessTransposeOperator(Model* model, TransposeOperator* op) { } } -void ProcessArgMaxOperator(Model* model, ArgMaxOperator* op) { +template <typename Op> +void ProcessArgMinMaxOperator(Model* model, Op* op) { CHECK_EQ(op->inputs.size(), 2); const auto& input_array = model->GetArray(op->inputs[0]); // Yield until input dims have been resolved. @@ -1501,6 +1519,65 @@ void ProcessTileOperator(Model* model, TensorFlowTileOperator* op) { } } +void ProcessAnyOperator(Model* model, AnyOperator* op) { + CHECK_EQ(op->inputs.size(), 2); + CHECK_EQ(op->outputs.size(), 1); + + auto& output_array = model->GetArray(op->outputs[0]); + if (output_array.has_shape()) { + // We have already run. + return; + } + + const auto& input_array = model->GetArray(op->inputs[0]); + if (!input_array.has_shape()) { + // Yield until input dims have been resolved. + return; + } + const auto& input_shape = input_array.shape(); + + auto& reduction_indices_array = model->GetArray(op->inputs[1]); + if (!reduction_indices_array.has_shape()) { + // Yield until reduction indices shape been resolved. + return; + } + if (!reduction_indices_array.buffer) { + // Yield until the reduction indices are constant. + return; + } + CHECK(reduction_indices_array.data_type == ArrayDataType::kInt32) + << "Any reduction input must be int32"; + + int input_rank = input_shape.dimensions_count(); + std::set<int32> true_indices; + const auto& reduction_indices = + reduction_indices_array.GetBuffer<ArrayDataType::kInt32>().data; + for (int i = 0; i < reduction_indices.size(); ++i) { + const int32 reduction_index = reduction_indices[i]; + if (reduction_index < -input_rank || reduction_index >= input_rank) { + CHECK(false) << "Invalid reduction dimension " << reduction_index + << " for input with " << input_rank << " dimensions"; + } + int32 wrapped_index = reduction_index; + if (wrapped_index < 0) { + wrapped_index += input_rank; + } + true_indices.insert(wrapped_index); + } + + auto* mutable_dims = output_array.mutable_shape()->mutable_dims(); + mutable_dims->clear(); + for (int i = 0; i < input_rank; ++i) { + if (true_indices.count(i) > 0) { + if (op->keep_dims) { + mutable_dims->emplace_back(1); + } + } else { + mutable_dims->emplace_back(input_shape.dims(i)); + } + } +} + } // namespace bool PropagateFixedSizes::Run(Model* model, std::size_t op_index) { @@ -1539,6 +1616,8 @@ bool PropagateFixedSizes::Run(Model* model, std::size_t op_index) { case OperatorType::kFloor: case OperatorType::kExp: case OperatorType::kSin: + case OperatorType::kLogicalAnd: + case OperatorType::kLogicalNot: ProcessSimpleOperator(model, op, 0); break; case OperatorType::kGather: @@ -1607,9 +1686,10 @@ bool PropagateFixedSizes::Run(Model* model, std::size_t op_index) { case OperatorType::kL2Pool: ProcessL2PoolOperator(model, static_cast<L2PoolOperator*>(op)); break; - case OperatorType::kMin: // Reduction Min - case OperatorType::kMax: // Reduction Max + case OperatorType::kReduceMin: // Reduction Min + case OperatorType::kReduceMax: // Reduction Max case OperatorType::kSum: + case OperatorType::kReduceProd: case OperatorType::kMean: ProcessTensorFlowReductionOperator(model, op); break; @@ -1658,8 +1738,8 @@ bool PropagateFixedSizes::Run(Model* model, std::size_t op_index) { case OperatorType::kShape: ProcessShapeOperator(model, static_cast<TensorFlowShapeOperator*>(op)); break; - case OperatorType::kStack: - ProcessStackOperator(model, static_cast<StackOperator*>(op)); + case OperatorType::kPack: + ProcessPackOperator(model, static_cast<PackOperator*>(op)); break; case OperatorType::kReorderAxes: ProcessReorderAxesOperator(model, static_cast<ReorderAxesOperator*>(op)); @@ -1699,10 +1779,26 @@ bool PropagateFixedSizes::Run(Model* model, std::size_t op_index) { static_cast<StridedSliceOperator*>(op)); break; case OperatorType::kArgMax: - ProcessArgMaxOperator(model, static_cast<ArgMaxOperator*>(op)); + ProcessArgMinMaxOperator<ArgMaxOperator>( + model, static_cast<ArgMaxOperator*>(op)); + break; + case OperatorType::kArgMin: + ProcessArgMinMaxOperator<ArgMinOperator>( + model, static_cast<ArgMinOperator*>(op)); break; - case OperatorType::kUnsupported: + case OperatorType::kUnsupported: { + const auto* unsupported_op = + static_cast<TensorFlowUnsupportedOperator*>(op); + // Attribute can be not specified, ignore it. + if (unsupported_op->output_shapes.size() < op->outputs.size()) { + return false; + } + for (int i = 0; i < op->outputs.size(); ++i) { + const string& output = op->outputs[i]; + model->GetArray(output).copy_shape(unsupported_op->output_shapes.at(i)); + } break; + } case OperatorType::kSvdf: ProcessSvdfOperator(model, static_cast<SvdfOperator*>(op)); break; @@ -1726,6 +1822,9 @@ bool PropagateFixedSizes::Run(Model* model, std::size_t op_index) { case OperatorType::kTile: ProcessTileOperator(model, static_cast<TensorFlowTileOperator*>(op)); break; + case OperatorType::kAny: + ProcessAnyOperator(model, static_cast<AnyOperator*>(op)); + break; default: // Unimplemented, another graph transformation should drop it. LOG(FATAL) << "Unhandled operator type " << OperatorTypeName(op->type); |