From 12e164d1e7c0b197f06d5d3c2ed26318b89b5e4c Mon Sep 17 00:00:00 2001 From: Yu-Cheng Ling Date: Tue, 9 Oct 2018 11:38:15 -0700 Subject: Return ::tensorflow::Status in Toco Graph Transformations. PiperOrigin-RevId: 216392908 --- .../convert_expanddims_to_reshape.cc | 16 ++++++---- .../convert_pure_conv_to_depthwise.cc | 24 +++++++++------ .../graph_transformations/convert_reorder_axes.cc | 15 ++++++--- .../convert_squeeze_to_reshape.cc | 18 ++++++----- .../convert_trivial_addn_to_add.cc | 12 +++++--- .../convert_trivial_pack_to_reshape.cc | 16 ++++++---- .../convert_trivial_tile_to_concat.cc | 16 ++++++---- .../convert_trivial_transpose_to_reshape.cc | 16 ++++++---- .../graph_transformations/create_im2col_arrays.cc | 12 +++++--- .../lite/toco/graph_transformations/dequantize.cc | 14 ++++++--- .../toco/graph_transformations/drop_fake_quant.cc | 13 +++++--- .../graph_transformations/drop_im2col_arrays.cc | 11 ++++--- .../graph_transformations/ensure_bias_vectors.cc | 9 ++++-- ...ure_uint8_weights_safe_for_fast_int8_kernels.cc | 14 +++++---- .../fuse_activation_functions.cc | 22 +++++++------ .../fuse_binary_into_following_affine.cc | 32 ++++++++++--------- .../fuse_binary_into_preceding_affine.cc | 36 ++++++++++++---------- .../fuse_broadcast_into_following_binary.cc | 16 ++++++---- .../graph_transformations/graph_transformations.cc | 2 +- .../graph_transformations/graph_transformations.h | 29 ++++++++++------- .../toco/graph_transformations/hardcode_min_max.cc | 7 +++-- .../graph_transformations/identify_dilated_conv.cc | 16 ++++++---- .../identify_l2_normalization.cc | 22 +++++++------ .../toco/graph_transformations/identify_l2_pool.cc | 15 +++++---- .../toco/graph_transformations/identify_lstm.cc | 33 +++++++++++--------- .../identify_lstm_merge_inputs.cc | 16 ++++++---- .../identify_lstm_split_inputs.cc | 16 ++++++---- .../toco/graph_transformations/identify_prelu.cc | 19 +++++++----- .../toco/graph_transformations/identify_relu1.cc | 17 +++++----- .../make_initial_dequantize_operator.cc | 8 +++-- .../merge_reshape_into_preceding_transpose.cc | 26 ++++++++-------- .../move_binary_operator_before_reshape.cc | 30 ++++++++++-------- ...propagate_activation_function_into_constants.cc | 20 ++++++------ .../propagate_array_data_types.cc | 18 ++++++----- .../propagate_default_min_max.cc | 8 +++-- .../propagate_fake_quant_num_bits.cc | 12 +++++--- .../graph_transformations/propagate_fixed_sizes.cc | 12 +++++--- .../lite/toco/graph_transformations/quantize.cc | 13 +++++--- ...rray_minmax_and_narrow_range_from_fake_quant.cc | 12 +++++--- .../remove_final_dequantize_op.cc | 12 +++++--- .../remove_tensorflow_assert.cc | 10 ++++-- .../remove_tensorflow_identity.cc | 10 ++++-- .../graph_transformations/remove_trivial_binary.cc | 22 +++++++------ .../remove_trivial_concatenation.cc | 12 +++++--- .../remove_trivial_concatenation_input.cc | 12 +++++--- .../remove_trivial_fake_quant.cc | 12 +++++--- .../remove_trivial_quantized_activation_func.cc | 15 +++++---- .../remove_trivial_quantized_min_max.cc | 12 +++++--- .../remove_trivial_reshape.cc | 12 +++++--- .../graph_transformations/remove_trivial_slice.cc | 11 ++++--- .../toco/graph_transformations/remove_unused_op.cc | 15 +++++---- .../reorder_elementwise_unary.cc | 18 ++++++----- .../reorder_reshape_transpose.cc | 24 +++++++++------ .../resolve_batch_normalization.cc | 12 +++++--- .../resolve_batch_to_space_nd_attributes.cc | 21 ++++++++----- .../resolve_constant_binary.cc | 16 ++++++---- .../resolve_constant_concatenation.cc | 24 ++++++++++----- .../resolve_constant_fake_quant.cc | 16 ++++++---- .../graph_transformations/resolve_constant_fill.cc | 26 +++++++++------- .../resolve_constant_gather.cc | 20 +++++++----- .../graph_transformations/resolve_constant_pack.cc | 16 ++++++---- .../resolve_constant_random_uniform.cc | 18 ++++++----- .../resolve_constant_range.cc | 20 +++++++----- .../resolve_constant_reshape.cc | 20 +++++++----- .../resolve_constant_select.cc | 21 ++++++++----- .../resolve_constant_shape_or_rank.cc | 16 ++++++---- .../resolve_constant_slice.cc | 28 +++++++++-------- .../resolve_constant_strided_slice.cc | 20 +++++++----- .../graph_transformations/resolve_constant_tile.cc | 16 ++++++---- .../resolve_constant_transpose.cc | 18 ++++++----- .../resolve_constant_unary.cc | 28 +++++++++-------- .../resolve_fake_quant_args_from_vars.cc | 14 ++++++--- .../resolve_gather_attributes.cc | 20 +++++++----- .../resolve_multiply_by_zero.cc | 30 ++++++++++-------- .../resolve_pad_attributes.cc | 17 ++++++---- .../resolve_padv2_attributes.cc | 17 ++++++---- .../resolve_reduce_attributes.cc | 30 +++++++++++++----- .../graph_transformations/resolve_reorder_axes.cc | 13 +++++--- .../resolve_reshape_attributes.cc | 14 ++++++--- .../resolve_slice_attributes.cc | 22 ++++++++----- .../resolve_space_to_batch_nd_attributes.cc | 21 ++++++++----- .../resolve_squeeze_attributes.cc | 12 +++++--- .../resolve_strided_slice_attributes.cc | 32 +++++++++++-------- .../resolve_tensorflow_concat.cc | 12 +++++--- .../resolve_tensorflow_matmul.cc | 12 +++++--- .../resolve_tensorflow_merge.cc | 12 +++++--- .../resolve_tensorflow_switch.cc | 12 +++++--- .../resolve_transpose_attributes.cc | 18 +++++++---- .../graph_transformations/shuffle_fc_weights.cc | 27 ++++++++-------- .../tests/resolve_constant_concatenation_test.cc | 15 +++++++-- .../tests/resolve_constant_unary_test.cc | 3 +- .../unfuse_activation_functions.cc | 12 +++++--- .../unpartition_embedding_lookup.cc | 24 +++++++++------ .../graph_transformations/unroll_batch_matmul.cc | 15 ++++++--- 94 files changed, 1003 insertions(+), 617 deletions(-) diff --git a/tensorflow/contrib/lite/toco/graph_transformations/convert_expanddims_to_reshape.cc b/tensorflow/contrib/lite/toco/graph_transformations/convert_expanddims_to_reshape.cc index 310a88484c..8a945ac435 100644 --- a/tensorflow/contrib/lite/toco/graph_transformations/convert_expanddims_to_reshape.cc +++ b/tensorflow/contrib/lite/toco/graph_transformations/convert_expanddims_to_reshape.cc @@ -25,10 +25,13 @@ limitations under the License. namespace toco { -bool ConvertExpandDimsToReshape::Run(Model* model, std::size_t op_index) { +::tensorflow::Status ConvertExpandDimsToReshape::Run(Model* model, + std::size_t op_index, + bool* modified) { + *modified = false; auto expand_it = model->operators.begin() + op_index; if (expand_it->get()->type != OperatorType::kExpandDims) { - return false; + return ::tensorflow::Status::OK(); } ExpandDimsOperator* expand_op = static_cast(expand_it->get()); @@ -38,18 +41,18 @@ bool ConvertExpandDimsToReshape::Run(Model* model, std::size_t op_index) { const auto& input_array = model->GetArray(expand_op->inputs[0]); if (!input_array.has_shape()) { // Yield until input dims have been resolved. - return false; + return ::tensorflow::Status::OK(); } const auto& axis_array = model->GetArray(expand_op->inputs[1]); if (!axis_array.has_shape()) { // Yield until input axis array shape has been resolved. - return false; + return ::tensorflow::Status::OK(); } CHECK_EQ(RequiredBufferSizeForShape(axis_array.shape()), 1); if (!axis_array.buffer) { // Yield until the input axis array is constant - return false; + return ::tensorflow::Status::OK(); } int axis = axis_array.GetBuffer().data[0]; std::vector reshape_dims(input_array.shape().dims()); @@ -90,7 +93,8 @@ bool ConvertExpandDimsToReshape::Run(Model* model, std::size_t op_index) { CHECK_EQ(expand_it->get(), expand_op); model->operators.erase(expand_it); - return true; + *modified = true; + return ::tensorflow::Status::OK(); } } // namespace toco diff --git a/tensorflow/contrib/lite/toco/graph_transformations/convert_pure_conv_to_depthwise.cc b/tensorflow/contrib/lite/toco/graph_transformations/convert_pure_conv_to_depthwise.cc index e88839be5d..a151012891 100644 --- a/tensorflow/contrib/lite/toco/graph_transformations/convert_pure_conv_to_depthwise.cc +++ b/tensorflow/contrib/lite/toco/graph_transformations/convert_pure_conv_to_depthwise.cc @@ -24,29 +24,32 @@ limitations under the License. namespace toco { -bool ConvertPureConvToDepthwise::Run(Model* model, std::size_t op_index) { +::tensorflow::Status ConvertPureConvToDepthwise::Run(Model* model, + std::size_t op_index, + bool* modified) { + *modified = false; auto conv_it = model->operators.begin() + op_index; if (conv_it->get()->type != OperatorType::kConv) { - return false; + return ::tensorflow::Status::OK(); } const auto* conv_op = static_cast(conv_it->get()); if (conv_op->stride_width != conv_op->stride_height) { - return false; + return ::tensorflow::Status::OK(); } if ((conv_op->dilation_width_factor != 1) || (conv_op->dilation_height_factor != 1)) { // Depthwise conv does not support dilation - return false; + return ::tensorflow::Status::OK(); } auto& input_array = model->GetArray(conv_op->inputs[0]); if (!input_array.has_shape()) { // Shapes not propagated yet - return false; + return ::tensorflow::Status::OK(); } if (input_array.shape().dims(3) != 1) { // Not a pure convolution: Conv does accumulation across the depth // dimension. - return false; + return ::tensorflow::Status::OK(); } const auto& weights_name = conv_op->inputs[1]; @@ -56,15 +59,15 @@ bool ConvertPureConvToDepthwise::Run(Model* model, std::size_t op_index) { "Not changing %s to DepthwiseConv because the weights is consumed by " "another op.", LogName(*conv_op)); - return false; + return ::tensorflow::Status::OK(); } auto& weights_array = model->GetArray(weights_name); if (!weights_array.buffer) { // Yield until the weights are resolved as a constant array. - return false; + return ::tensorflow::Status::OK(); } if (weights_array.data_type != ArrayDataType::kFloat) { - return false; + return ::tensorflow::Status::OK(); } // At this point we know we have a pure conv. Rewrite it as DepthwiseConv. AddMessageF( @@ -112,7 +115,8 @@ bool ConvertPureConvToDepthwise::Run(Model* model, std::size_t op_index) { } *weights_array.mutable_shape()->mutable_dims() = {1, width, height, depth}; weights_buffer.data = depthwise_conv_weights_data; - return true; + *modified = true; + return ::tensorflow::Status::OK(); } } // namespace toco diff --git a/tensorflow/contrib/lite/toco/graph_transformations/convert_reorder_axes.cc b/tensorflow/contrib/lite/toco/graph_transformations/convert_reorder_axes.cc index 0d274fc687..4a264e1cf1 100644 --- a/tensorflow/contrib/lite/toco/graph_transformations/convert_reorder_axes.cc +++ b/tensorflow/contrib/lite/toco/graph_transformations/convert_reorder_axes.cc @@ -86,9 +86,12 @@ TransposeOperator* CreateTransposeFromReorderAxes( // Converts ReorderAxes into Transpose and Reshape which are compatible with the // TFLite interpreter. -bool ConvertReorderAxes::Run(Model* model, std::size_t op_index) { +::tensorflow::Status ConvertReorderAxes::Run(Model* model, std::size_t op_index, + bool* modified) { + *modified = false; auto reorder_it = model->operators.begin() + op_index; - if (reorder_it->get()->type != OperatorType::kReorderAxes) return false; + if (reorder_it->get()->type != OperatorType::kReorderAxes) + return ::tensorflow::Status::OK(); auto* reorder_op = static_cast(reorder_it->get()); CHECK_EQ(reorder_op->inputs.size(), 1); @@ -113,8 +116,9 @@ bool ConvertReorderAxes::Run(Model* model, std::size_t op_index) { // Yield if input array contains constants or if output array size has not // been adjusted to reflect the permutations in ReorderAxes. ReorderAxes will // be merged into a constant array when possible. - if (IsConstantParameterArray(*model, constant_input_array_name)) return false; - if (!output_array.has_shape()) return false; + if (IsConstantParameterArray(*model, constant_input_array_name)) + return ::tensorflow::Status::OK(); + if (!output_array.has_shape()) return ::tensorflow::Status::OK(); const auto input_axes_order = reorder_op->input_axes_order; const auto output_axes_order = reorder_op->output_axes_order; @@ -143,7 +147,8 @@ bool ConvertReorderAxes::Run(Model* model, std::size_t op_index) { CHECK_EQ(reorder_it->get(), reorder_op); model->operators.erase(reorder_it); - return true; + *modified = true; + return ::tensorflow::Status::OK(); } } // namespace toco diff --git a/tensorflow/contrib/lite/toco/graph_transformations/convert_squeeze_to_reshape.cc b/tensorflow/contrib/lite/toco/graph_transformations/convert_squeeze_to_reshape.cc index 81cedb5dad..a0bd1ed4a4 100644 --- a/tensorflow/contrib/lite/toco/graph_transformations/convert_squeeze_to_reshape.cc +++ b/tensorflow/contrib/lite/toco/graph_transformations/convert_squeeze_to_reshape.cc @@ -30,10 +30,13 @@ namespace toco { // means that the data layout will never change with this op, just the shape. // By converting these to reshapes once we have run shape propagation we allow // standard reshape optimization transforms to do their magic. -bool ConvertSqueezeToReshape::Run(Model* model, std::size_t op_index) { +::tensorflow::Status ConvertSqueezeToReshape::Run(Model* model, + std::size_t op_index, + bool* modified) { + *modified = false; auto squeeze_it = model->operators.begin() + op_index; if (squeeze_it->get()->type != OperatorType::kSqueeze) { - return false; + return ::tensorflow::Status::OK(); } auto squeeze_op = static_cast(squeeze_it->get()); CHECK_EQ(squeeze_op->inputs.size(), 1); @@ -42,16 +45,16 @@ bool ConvertSqueezeToReshape::Run(Model* model, std::size_t op_index) { const auto& input_array = model->GetArray(squeeze_op->inputs[0]); if (!input_array.has_shape()) { // Yield until input dims have been resolved. - return false; + return ::tensorflow::Status::OK(); } if (input_array.shape().dimensions_count() == 0) { // Input array cannot be 0-D. - return false; + return ::tensorflow::Status::OK(); } if (!model->HasArray(squeeze_op->outputs[0]) || !model->GetArray(squeeze_op->outputs[0]).has_shape()) { // Yield until shape propagation has set the output shape for us. - return false; + return ::tensorflow::Status::OK(); } // We use the output shape that has been calculated by shape propagation. @@ -59,7 +62,7 @@ bool ConvertSqueezeToReshape::Run(Model* model, std::size_t op_index) { // Empty shapes will not work as empty data arrays. if (output_shape.dimensions_count() == 0) { - return false; + return ::tensorflow::Status::OK(); } auto* reshape_op = new TensorFlowReshapeOperator; @@ -79,7 +82,8 @@ bool ConvertSqueezeToReshape::Run(Model* model, std::size_t op_index) { CHECK_EQ(squeeze_it->get(), squeeze_op); model->operators.erase(squeeze_it); - return true; + *modified = true; + return ::tensorflow::Status::OK(); } } // namespace toco diff --git a/tensorflow/contrib/lite/toco/graph_transformations/convert_trivial_addn_to_add.cc b/tensorflow/contrib/lite/toco/graph_transformations/convert_trivial_addn_to_add.cc index dcaaddbf3b..d7cacf77f4 100644 --- a/tensorflow/contrib/lite/toco/graph_transformations/convert_trivial_addn_to_add.cc +++ b/tensorflow/contrib/lite/toco/graph_transformations/convert_trivial_addn_to_add.cc @@ -20,10 +20,13 @@ namespace toco { // This pass will convert an AddN operator with only 2 inputs into a regular Add // operator, to which more optimizations may apply. -bool ConvertTrivialAddNToAdd::Run(Model* model, std::size_t op_index) { +::tensorflow::Status ConvertTrivialAddNToAdd::Run(Model* model, + std::size_t op_index, + bool* modified) { + *modified = false; auto addn_it = model->operators.begin() + op_index; if (addn_it->get()->type != OperatorType::kAddN) { - return false; + return ::tensorflow::Status::OK(); } AddNOperator* addn_op = static_cast(addn_it->get()); CHECK_GE(addn_op->inputs.size(), 2); @@ -31,7 +34,7 @@ bool ConvertTrivialAddNToAdd::Run(Model* model, std::size_t op_index) { // We only reduce AddN with N=2 to a regular Add. if (addn_op->inputs.size() != 2) { - return false; + return ::tensorflow::Status::OK(); } // Copy inputs & outputs to regular Add. @@ -45,7 +48,8 @@ bool ConvertTrivialAddNToAdd::Run(Model* model, std::size_t op_index) { addn_it = add_it + 1; CHECK_EQ(addn_it->get(), addn_op); model->operators.erase(addn_it); - return true; + *modified = true; + return ::tensorflow::Status::OK(); } } // namespace toco diff --git a/tensorflow/contrib/lite/toco/graph_transformations/convert_trivial_pack_to_reshape.cc b/tensorflow/contrib/lite/toco/graph_transformations/convert_trivial_pack_to_reshape.cc index 75113a2a8c..78779243a9 100644 --- a/tensorflow/contrib/lite/toco/graph_transformations/convert_trivial_pack_to_reshape.cc +++ b/tensorflow/contrib/lite/toco/graph_transformations/convert_trivial_pack_to_reshape.cc @@ -25,27 +25,30 @@ limitations under the License. namespace toco { -bool ConvertTrivialPackToReshape::Run(Model* model, std::size_t op_index) { +::tensorflow::Status ConvertTrivialPackToReshape::Run(Model* model, + std::size_t op_index, + bool* modified) { + *modified = false; auto pack_it = model->operators.begin() + op_index; if (pack_it->get()->type != OperatorType::kPack) { - return false; + return ::tensorflow::Status::OK(); } auto* pack_op = static_cast(pack_it->get()); if (pack_op->inputs.size() > 1) { // Not trivial. - return false; + return ::tensorflow::Status::OK(); } CHECK_EQ(pack_op->outputs.size(), 1); const auto& input_array = model->GetArray(pack_op->inputs[0]); if (!input_array.has_shape()) { // Yield until input dims have been resolved. - return false; + return ::tensorflow::Status::OK(); } if (input_array.shape().dimensions_count() == 0) { // Input array cannot be 0-D. // (Unsure if this is TF behavior, but was required to get a test to pass.) - return false; + return ::tensorflow::Status::OK(); } AddMessageF("Converting trivial %s to a reshape", LogName(*pack_op)); @@ -75,7 +78,8 @@ bool ConvertTrivialPackToReshape::Run(Model* model, std::size_t op_index) { CHECK_EQ(pack_it->get(), pack_op); model->operators.erase(pack_it); - return true; + *modified = true; + return ::tensorflow::Status::OK(); } } // namespace toco diff --git a/tensorflow/contrib/lite/toco/graph_transformations/convert_trivial_tile_to_concat.cc b/tensorflow/contrib/lite/toco/graph_transformations/convert_trivial_tile_to_concat.cc index b689be0792..b6d712ca44 100644 --- a/tensorflow/contrib/lite/toco/graph_transformations/convert_trivial_tile_to_concat.cc +++ b/tensorflow/contrib/lite/toco/graph_transformations/convert_trivial_tile_to_concat.cc @@ -21,10 +21,13 @@ limitations under the License. namespace toco { -bool ConvertTrivialTileToConcat::Run(Model* model, std::size_t op_index) { +::tensorflow::Status ConvertTrivialTileToConcat::Run(Model* model, + std::size_t op_index, + bool* modified) { + *modified = false; auto tile_it = model->operators.begin() + op_index; if (tile_it->get()->type != OperatorType::kTile) { - return false; + return ::tensorflow::Status::OK(); } auto* tile_op = static_cast(tile_it->get()); @@ -34,13 +37,13 @@ bool ConvertTrivialTileToConcat::Run(Model* model, std::size_t op_index) { if (!input_array.has_shape() || !multiples_array.has_shape() || !output_array.has_shape()) { // Yield until PropagateFixedSizes has been run on this op. - return false; + return ::tensorflow::Status::OK(); } // Note: We can assume we have error checked inputs in PropagateFixedSizes. if (!multiples_array.buffer) { // Yield until the multiples is constant. - return false; + return ::tensorflow::Status::OK(); } std::vector const& multiples = multiples_array.GetBuffer().data; @@ -59,7 +62,7 @@ bool ConvertTrivialTileToConcat::Run(Model* model, std::size_t op_index) { // The tile is non-trivial. Good luck. AddMessageF("Tile %s is non-trivial (has more than one multiply dimension)", LogName(*tile_op)); - return false; + return ::tensorflow::Status::OK(); } // The tile is like a concat. @@ -88,7 +91,8 @@ bool ConvertTrivialTileToConcat::Run(Model* model, std::size_t op_index) { CHECK_EQ(tile_it->get(), tile_op); model->operators.erase(tile_it); - return true; + *modified = true; + return ::tensorflow::Status::OK(); } } // namespace toco diff --git a/tensorflow/contrib/lite/toco/graph_transformations/convert_trivial_transpose_to_reshape.cc b/tensorflow/contrib/lite/toco/graph_transformations/convert_trivial_transpose_to_reshape.cc index 5a36a90b38..e5a96d4335 100644 --- a/tensorflow/contrib/lite/toco/graph_transformations/convert_trivial_transpose_to_reshape.cc +++ b/tensorflow/contrib/lite/toco/graph_transformations/convert_trivial_transpose_to_reshape.cc @@ -48,10 +48,13 @@ bool TransposeAffectsMemoryOrder(std::vector perm, } // namespace -bool ConvertTrivialTransposeToReshape::Run(Model* model, std::size_t op_index) { +::tensorflow::Status ConvertTrivialTransposeToReshape::Run(Model* model, + std::size_t op_index, + bool* modified) { + *modified = false; auto transpose_it = model->operators.begin() + op_index; if (transpose_it->get()->type != OperatorType::kTranspose) { - return false; + return ::tensorflow::Status::OK(); } TransposeOperator* transpose_op = static_cast(transpose_it->get()); @@ -60,14 +63,14 @@ bool ConvertTrivialTransposeToReshape::Run(Model* model, std::size_t op_index) { const auto& output_array = model->GetArray(transpose_op->outputs[0]); if (!input_array.has_shape() || !output_array.has_shape()) { // Yield until PropagateFixedSizes has been run on this op. - return false; + return ::tensorflow::Status::OK(); } // Note: We can assume we have error checked inputs in PropagateFixedSizes. // Check that the permutation has propogated. std::vector const& perm = transpose_op->perm; if (perm.empty()) { - return false; + return ::tensorflow::Status::OK(); } // This transpose is trivial if non-unitary dimensions remain in the same @@ -76,7 +79,7 @@ bool ConvertTrivialTransposeToReshape::Run(Model* model, std::size_t op_index) { std::vector const& output_dims = output_array.shape().dims(); if (TransposeAffectsMemoryOrder(perm, input_dims)) { - return false; + return ::tensorflow::Status::OK(); } // This transpose is trivial. Replace it with a Reshape op. @@ -109,7 +112,8 @@ bool ConvertTrivialTransposeToReshape::Run(Model* model, std::size_t op_index) { CHECK_EQ(transpose_it->get(), transpose_op); model->operators.erase(transpose_it); - return true; + *modified = true; + return ::tensorflow::Status::OK(); } } // namespace toco diff --git a/tensorflow/contrib/lite/toco/graph_transformations/create_im2col_arrays.cc b/tensorflow/contrib/lite/toco/graph_transformations/create_im2col_arrays.cc index 1e68cd678b..ebc0e9afca 100644 --- a/tensorflow/contrib/lite/toco/graph_transformations/create_im2col_arrays.cc +++ b/tensorflow/contrib/lite/toco/graph_transformations/create_im2col_arrays.cc @@ -73,18 +73,22 @@ bool ProcessTransposeConvOperator(Model* model, TransposeConvOperator* op) { return true; } -bool CreateIm2colArrays::Run(Model* model, std::size_t op_index) { +::tensorflow::Status CreateIm2colArrays::Run(Model* model, std::size_t op_index, + bool* modified) { + *modified = false; auto it = model->operators.begin() + op_index; auto* op = it->get(); switch (op->type) { case OperatorType::kConv: - return ProcessConvOperator(model, static_cast(op)); + *modified = ProcessConvOperator(model, static_cast(op)); + return ::tensorflow::Status::OK(); case OperatorType::kTransposeConv: - return ProcessTransposeConvOperator( + *modified = ProcessTransposeConvOperator( model, static_cast(op)); + return ::tensorflow::Status::OK(); default: - return false; + return ::tensorflow::Status::OK(); } } diff --git a/tensorflow/contrib/lite/toco/graph_transformations/dequantize.cc b/tensorflow/contrib/lite/toco/graph_transformations/dequantize.cc index 1688586733..2119174950 100644 --- a/tensorflow/contrib/lite/toco/graph_transformations/dequantize.cc +++ b/tensorflow/contrib/lite/toco/graph_transformations/dequantize.cc @@ -186,24 +186,27 @@ bool DequantizeArray(const string& array_name, } // namespace -bool Dequantize::Run(Model* model, std::size_t op_index) { +::tensorflow::Status Dequantize::Run(Model* model, std::size_t op_index, + bool* modified) { + *modified = false; const auto op_it = model->operators.begin() + op_index; auto* op = op_it->get(); if (op->type == OperatorType::kDequantize) { auto& input_array = model->GetArray(op->inputs[0]); if (input_array.data_type == ArrayDataType::kFloat) { - return false; + return ::tensorflow::Status::OK(); } if (input_array.final_data_type != ArrayDataType::kFloat) { - return false; + return ::tensorflow::Status::OK(); } input_array.data_type = ArrayDataType::kFloat; input_array.quantization_params = nullptr; auto& output_array = model->GetArray(op->outputs[0]); output_array.data_type = ArrayDataType::kFloat; output_array.quantization_params = nullptr; - return RemoveTrivialPassthroughOp(this, model, op_index); + *modified = RemoveTrivialPassthroughOp(this, model, op_index); + return ::tensorflow::Status::OK(); } std::vector arrays; @@ -220,7 +223,8 @@ bool Dequantize::Run(Model* model, std::size_t op_index) { } } - return changed; + *modified = changed; + return ::tensorflow::Status::OK(); } } // namespace toco diff --git a/tensorflow/contrib/lite/toco/graph_transformations/drop_fake_quant.cc b/tensorflow/contrib/lite/toco/graph_transformations/drop_fake_quant.cc index 95558ef5ec..1555cf60a1 100644 --- a/tensorflow/contrib/lite/toco/graph_transformations/drop_fake_quant.cc +++ b/tensorflow/contrib/lite/toco/graph_transformations/drop_fake_quant.cc @@ -25,21 +25,23 @@ limitations under the License. namespace toco { -bool DropFakeQuant::Run(Model* model, std::size_t op_index) { +::tensorflow::Status DropFakeQuant::Run(Model* model, std::size_t op_index, + bool* modified) { + *modified = false; const auto fakequant_it = model->operators.begin() + op_index; auto* fakequant_base_op = fakequant_it->get(); if (fakequant_base_op->type != OperatorType::kFakeQuant) { - return false; + return ::tensorflow::Status::OK(); } auto* fakequant_op = static_cast(fakequant_base_op); if (!fakequant_op->minmax) { - return false; + return ::tensorflow::Status::OK(); } const auto& output_array = model->GetArray(fakequant_op->outputs[0]); if (!output_array.minmax) { - return false; + return ::tensorflow::Status::OK(); } // Drop min/max inputs @@ -50,7 +52,8 @@ bool DropFakeQuant::Run(Model* model, std::size_t op_index) { } fakequant_op->inputs.resize(1); - return RemoveTrivialPassthroughOp(this, model, op_index); + *modified = RemoveTrivialPassthroughOp(this, model, op_index); + return ::tensorflow::Status::OK(); } } // namespace toco diff --git a/tensorflow/contrib/lite/toco/graph_transformations/drop_im2col_arrays.cc b/tensorflow/contrib/lite/toco/graph_transformations/drop_im2col_arrays.cc index f7fd878b7e..7d66ea5dd2 100644 --- a/tensorflow/contrib/lite/toco/graph_transformations/drop_im2col_arrays.cc +++ b/tensorflow/contrib/lite/toco/graph_transformations/drop_im2col_arrays.cc @@ -19,15 +19,17 @@ limitations under the License. namespace toco { -bool DropIm2colArrays::Run(Model* model, std::size_t op_index) { +::tensorflow::Status DropIm2colArrays::Run(Model* model, std::size_t op_index, + bool* modified) { + *modified = false; auto conv_it = model->operators.begin() + op_index; if (conv_it->get()->type != OperatorType::kConv) { - return false; + return ::tensorflow::Status::OK(); } auto* conv_op = static_cast(conv_it->get()); if (conv_op->outputs.size() < 2) { // Conv op does not have im2col. - return false; + return ::tensorflow::Status::OK(); } // Drop the im2col array. @@ -36,7 +38,8 @@ bool DropIm2colArrays::Run(Model* model, std::size_t op_index) { conv_op->outputs.resize(1); AddMessageF("Dropped an im2col array for %s", LogName(*conv_op)); - return true; + *modified = true; + return ::tensorflow::Status::OK(); } } // namespace toco diff --git a/tensorflow/contrib/lite/toco/graph_transformations/ensure_bias_vectors.cc b/tensorflow/contrib/lite/toco/graph_transformations/ensure_bias_vectors.cc index e80ed036b3..72b1dda3be 100644 --- a/tensorflow/contrib/lite/toco/graph_transformations/ensure_bias_vectors.cc +++ b/tensorflow/contrib/lite/toco/graph_transformations/ensure_bias_vectors.cc @@ -62,17 +62,20 @@ bool ProcessLinearOperator(Model* model, Operator* op) { } } // namespace -bool EnsureBiasVectors::Run(Model* model, std::size_t op_index) { +::tensorflow::Status EnsureBiasVectors::Run(Model* model, std::size_t op_index, + bool* modified) { + *modified = false; auto* op = model->operators[op_index].get(); if (op->type == OperatorType::kConv || op->type == OperatorType::kDepthwiseConv || op->type == OperatorType::kFullyConnected) { if (ProcessLinearOperator(model, op)) { AddMessageF("Added bias vector to %s as %s", LogName(*op), op->inputs[2]); - return true; + *modified = true; + return ::tensorflow::Status::OK(); } } - return false; + return ::tensorflow::Status::OK(); } } // namespace toco diff --git a/tensorflow/contrib/lite/toco/graph_transformations/ensure_uint8_weights_safe_for_fast_int8_kernels.cc b/tensorflow/contrib/lite/toco/graph_transformations/ensure_uint8_weights_safe_for_fast_int8_kernels.cc index c13fc0de75..60dcd52684 100644 --- a/tensorflow/contrib/lite/toco/graph_transformations/ensure_uint8_weights_safe_for_fast_int8_kernels.cc +++ b/tensorflow/contrib/lite/toco/graph_transformations/ensure_uint8_weights_safe_for_fast_int8_kernels.cc @@ -108,8 +108,9 @@ namespace toco { // we can foresee these 'fast int8 kernels' to remain important to have into // the 2020s. // -bool EnsureUint8WeightsSafeForFastInt8Kernels::Run(Model* model, - std::size_t op_index) { +::tensorflow::Status EnsureUint8WeightsSafeForFastInt8Kernels::Run( + Model* model, std::size_t op_index, bool* modified) { + *modified = false; const auto& op = *model->operators[op_index]; int weights_index = 0; switch (op.type) { @@ -148,16 +149,16 @@ bool EnsureUint8WeightsSafeForFastInt8Kernels::Run(Model* model, // That's why at the moment we only handle operators that use a GEMM // (Conv, fully-connected --- note that LSTM merely wraps a // fully-connected operator). - return false; + return ::tensorflow::Status::OK(); } const string& name = op.inputs[weights_index]; auto& array = model->GetArray(name); if (!array.buffer) { - return false; + return ::tensorflow::Status::OK(); } if (array.data_type != ArrayDataType::kUint8) { - return false; + return ::tensorflow::Status::OK(); } auto& buffer_data = array.GetMutableBuffer().data; @@ -212,7 +213,8 @@ bool EnsureUint8WeightsSafeForFastInt8Kernels::Run(Model* model, AddMessageF("Tweaked weights values for %s", LogName(op)); } - return changed; + *modified = changed; + return ::tensorflow::Status::OK(); } } // namespace toco diff --git a/tensorflow/contrib/lite/toco/graph_transformations/fuse_activation_functions.cc b/tensorflow/contrib/lite/toco/graph_transformations/fuse_activation_functions.cc index c5ce3fcd95..88511a7d3c 100644 --- a/tensorflow/contrib/lite/toco/graph_transformations/fuse_activation_functions.cc +++ b/tensorflow/contrib/lite/toco/graph_transformations/fuse_activation_functions.cc @@ -25,27 +25,30 @@ limitations under the License. namespace toco { -bool FuseActivationFunctions::Run(Model* model, std::size_t op_index) { +::tensorflow::Status FuseActivationFunctions::Run(Model* model, + std::size_t op_index, + bool* modified) { + *modified = false; const auto ac_it = model->operators.begin() + op_index; const auto* ac_op = ac_it->get(); if (ac_op->type != OperatorType::kRelu6 && ac_op->type != OperatorType::kRelu1 && ac_op->type != OperatorType::kRelu) { - return false; + return ::tensorflow::Status::OK(); } // Find the op producing the array passed to this activation function Operator* op = GetOpWithOutput(*model, ac_op->inputs[0]); - if (!op) return false; + if (!op) return ::tensorflow::Status::OK(); if (CountTrueOutputs(*model, *op) > 1) { AddMessageF( "Not fusing activation function %s into %s because it has more than " "one consumed output", LogName(*ac_op), LogName(*op)); - return false; + return ::tensorflow::Status::OK(); } CHECK_EQ(op->outputs[0], ac_op->inputs[0]); @@ -57,7 +60,7 @@ bool FuseActivationFunctions::Run(Model* model, std::size_t op_index) { "Not fusing activation function into %s because it is consumed by more " "than 1 other operator", LogName(*ac_op), LogName(*op)); - return false; + return ::tensorflow::Status::OK(); } if (!IsDiscardableArray(*model, op->outputs[0])) { @@ -65,7 +68,7 @@ bool FuseActivationFunctions::Run(Model* model, std::size_t op_index) { "Not fusing activation function %s into %s because output %s it is not " "discardable", LogName(*ac_op), LogName(*op), op->outputs[0]); - return false; + return ::tensorflow::Status::OK(); } if (op->fused_activation_function != FusedActivationFunctionType::kNone) { @@ -73,7 +76,7 @@ bool FuseActivationFunctions::Run(Model* model, std::size_t op_index) { "Not fusing activation function %s into %s because it already has a " "fused activation function", LogName(*ac_op), LogName(*op)); - return false; + return ::tensorflow::Status::OK(); } if (!OperatorSupportsFusedActivation(op->type)) { @@ -81,7 +84,7 @@ bool FuseActivationFunctions::Run(Model* model, std::size_t op_index) { "Not fusing activation function %s because the %s op doesn't support " "it", LogName(*ac_op), LogName(*op)); - return false; + return ::tensorflow::Status::OK(); } AddMessageF("Fusing activation function %s into the preceding %s", @@ -98,7 +101,8 @@ bool FuseActivationFunctions::Run(Model* model, std::size_t op_index) { model->EraseArray(ac_op->inputs[0]); op->outputs[0] = ac_op->outputs[0]; model->operators.erase(ac_it); - return true; + *modified = true; + return ::tensorflow::Status::OK(); } } // namespace toco diff --git a/tensorflow/contrib/lite/toco/graph_transformations/fuse_binary_into_following_affine.cc b/tensorflow/contrib/lite/toco/graph_transformations/fuse_binary_into_following_affine.cc index dcbbead517..0de22b8ff4 100644 --- a/tensorflow/contrib/lite/toco/graph_transformations/fuse_binary_into_following_affine.cc +++ b/tensorflow/contrib/lite/toco/graph_transformations/fuse_binary_into_following_affine.cc @@ -150,14 +150,17 @@ void FuseMulOrDivParamsIntoFollowingAffine(Model* model, Operator* following_op, } // namespace -bool FuseBinaryIntoFollowingAffine::Run(Model* model, std::size_t op_index) { +::tensorflow::Status FuseBinaryIntoFollowingAffine::Run(Model* model, + std::size_t op_index, + bool* modified) { + *modified = false; const auto binary_it = model->operators.begin() + op_index; auto* binary_op = binary_it->get(); if (binary_op->type != OperatorType::kAdd && binary_op->type != OperatorType::kMul && binary_op->type != OperatorType::kSub && binary_op->type != OperatorType::kDiv) { - return false; + return ::tensorflow::Status::OK(); } CHECK_EQ(binary_op->inputs.size(), 2); @@ -175,12 +178,12 @@ bool FuseBinaryIntoFollowingAffine::Run(Model* model, std::size_t op_index) { }; if (!is_input_constant[0] && !is_input_constant[1]) { // Neither input is constant, so nothing we can fuse into a constant. - return false; + return ::tensorflow::Status::OK(); } if (is_input_constant[0] && is_input_constant[1]) { // Both inputs are constants. That's a job for constants // propagation, not for us to handle here. - return false; + return ::tensorflow::Status::OK(); } const int index_of_constant_input = is_input_constant[0] ? 0 : 1; const int index_of_variable_input = is_input_constant[0] ? 1 : 0; @@ -192,7 +195,7 @@ bool FuseBinaryIntoFollowingAffine::Run(Model* model, std::size_t op_index) { if (index_of_constant_input != 1) { AddMessageF("Not fusing %s because the denominator is not constant", LogName(*binary_op)); - return false; + return ::tensorflow::Status::OK(); } } @@ -204,7 +207,7 @@ bool FuseBinaryIntoFollowingAffine::Run(Model* model, std::size_t op_index) { "Not fusing %s into the following affine op, because we only know " "how to do so when the constant operand is a scalar", LogName(*binary_op)); - return false; + return ::tensorflow::Status::OK(); } } @@ -212,7 +215,7 @@ bool FuseBinaryIntoFollowingAffine::Run(Model* model, std::size_t op_index) { FusedActivationFunctionType::kNone) { AddMessageF("Not fusing %s because it has a fused activation function", LogName(*binary_op)); - return false; + return ::tensorflow::Status::OK(); } Operator* following_op = GetOpWithInput(*model, binary_op->outputs[0]); @@ -221,7 +224,7 @@ bool FuseBinaryIntoFollowingAffine::Run(Model* model, std::size_t op_index) { AddMessageF( "Not fusing %s because it is not consumed by exactly one other op", LogName(*binary_op)); - return false; + return ::tensorflow::Status::OK(); } if (following_op->type != OperatorType::kConv && @@ -231,14 +234,14 @@ bool FuseBinaryIntoFollowingAffine::Run(Model* model, std::size_t op_index) { "Not fusing %s because the following %s is not of one of the supported " "types", LogName(*binary_op), LogName(*following_op)); - return false; + return ::tensorflow::Status::OK(); } if (following_op->inputs.size() < 3) { AddMessageF( "Not fusing %s because the following %s does not have a bias vector", LogName(*following_op), LogName(*binary_op)); - return false; + return ::tensorflow::Status::OK(); } const auto& weights = model->GetArray(following_op->inputs[1]); @@ -248,7 +251,7 @@ bool FuseBinaryIntoFollowingAffine::Run(Model* model, std::size_t op_index) { "Not fusing %s because the following %s has non-constant weights or " "bias arrays", LogName(*binary_op), LogName(*following_op)); - return false; + return ::tensorflow::Status::OK(); } // Try to fuse the binary params into the following op's params @@ -260,7 +263,7 @@ bool FuseBinaryIntoFollowingAffine::Run(Model* model, std::size_t op_index) { AddMessageF( "Not fusing %s because the following %s does not use VALID padding", LogName(*binary_op), LogName(*following_op)); - return false; + return ::tensorflow::Status::OK(); } } if (following_op->type == OperatorType::kDepthwiseConv) { @@ -269,7 +272,7 @@ bool FuseBinaryIntoFollowingAffine::Run(Model* model, std::size_t op_index) { AddMessageF( "Not fusing %s because the following %s does not use VALID padding", LogName(*binary_op), LogName(*following_op)); - return false; + return ::tensorflow::Status::OK(); } } FuseAddOrSubParamsIntoFollowingAffine(model, following_op, binary_op, @@ -294,7 +297,8 @@ bool FuseBinaryIntoFollowingAffine::Run(Model* model, std::size_t op_index) { model->EraseArray(old_constant_param_name); } model->operators.erase(binary_it); - return true; + *modified = true; + return ::tensorflow::Status::OK(); } } // namespace toco diff --git a/tensorflow/contrib/lite/toco/graph_transformations/fuse_binary_into_preceding_affine.cc b/tensorflow/contrib/lite/toco/graph_transformations/fuse_binary_into_preceding_affine.cc index b324631579..b8da756d85 100644 --- a/tensorflow/contrib/lite/toco/graph_transformations/fuse_binary_into_preceding_affine.cc +++ b/tensorflow/contrib/lite/toco/graph_transformations/fuse_binary_into_preceding_affine.cc @@ -188,14 +188,17 @@ void FuseMulOrDivParamsIntoPrecedingAffine(Model* model, Operator* preceding_op, } } // namespace -bool FuseBinaryIntoPrecedingAffine::Run(Model* model, std::size_t op_index) { +::tensorflow::Status FuseBinaryIntoPrecedingAffine::Run(Model* model, + std::size_t op_index, + bool* modified) { + *modified = false; const auto binary_it = model->operators.begin() + op_index; const auto* binary_op = binary_it->get(); if (binary_op->type != OperatorType::kAdd && binary_op->type != OperatorType::kMul && binary_op->type != OperatorType::kSub && binary_op->type != OperatorType::kDiv) { - return false; + return ::tensorflow::Status::OK(); } CHECK_EQ(binary_op->inputs.size(), 2); @@ -213,12 +216,12 @@ bool FuseBinaryIntoPrecedingAffine::Run(Model* model, std::size_t op_index) { }; if (!is_input_constant[0] && !is_input_constant[1]) { // Neither input is constant, so nothing we can fuse into a constant. - return false; + return ::tensorflow::Status::OK(); } if (is_input_constant[0] && is_input_constant[1]) { // Both inputs are constants. That's a job for constants // propagation, not for us to handle here. - return false; + return ::tensorflow::Status::OK(); } const int index_of_constant_input = is_input_constant[0] ? 0 : 1; const int index_of_variable_input = is_input_constant[0] ? 1 : 0; @@ -230,7 +233,7 @@ bool FuseBinaryIntoPrecedingAffine::Run(Model* model, std::size_t op_index) { if (index_of_constant_input != 1) { AddMessageF("Not fusing %s because the denominator is not constant", LogName(*binary_op)); - return false; + return ::tensorflow::Status::OK(); } } @@ -239,12 +242,12 @@ bool FuseBinaryIntoPrecedingAffine::Run(Model* model, std::size_t op_index) { if (!preceding_op) { AddMessageF("Not fusing %s because it is not the output of another op", LogName(*binary_op)); - return false; + return ::tensorflow::Status::OK(); } for (const string& output_array : model->flags.output_arrays()) { if (preceding_op->outputs[0] == output_array) { - return false; + return ::tensorflow::Status::OK(); } } @@ -255,7 +258,7 @@ bool FuseBinaryIntoPrecedingAffine::Run(Model* model, std::size_t op_index) { "Not fusing %s because the preceding %s is not of one of the supported " "types", LogName(*binary_op), LogName(*preceding_op)); - return false; + return ::tensorflow::Status::OK(); } if (preceding_op->fused_activation_function != @@ -264,14 +267,14 @@ bool FuseBinaryIntoPrecedingAffine::Run(Model* model, std::size_t op_index) { "Not fusing %s because the preceding %s has a fused activation " "function", LogName(*binary_op), LogName(*preceding_op)); - return false; + return ::tensorflow::Status::OK(); } if (preceding_op->inputs.size() < 3) { AddMessageF( "Not fusing %s because the preceding %s does not have a bias vector", LogName(*binary_op), LogName(*preceding_op)); - return false; + return ::tensorflow::Status::OK(); } const auto& weights_name = preceding_op->inputs[1]; @@ -289,14 +292,14 @@ bool FuseBinaryIntoPrecedingAffine::Run(Model* model, std::size_t op_index) { "Not fusing %s because the preceding %s has a non-constant bias " "array", LogName(*binary_op), LogName(*preceding_op)); - return false; + return ::tensorflow::Status::OK(); } if (count_ops_consuming_bias > 1) { AddMessageF( "Not fusing %s because the bias of the preceding %s is consumed by " "another op", LogName(*binary_op), LogName(*preceding_op)); - return false; + return ::tensorflow::Status::OK(); } } else { if (!weights.buffer || !bias.buffer) { @@ -304,14 +307,14 @@ bool FuseBinaryIntoPrecedingAffine::Run(Model* model, std::size_t op_index) { "Not fusing %s because the preceding %s has non-constant weights or " "bias arrays", LogName(*binary_op), LogName(*preceding_op)); - return false; + return ::tensorflow::Status::OK(); } if (count_ops_consuming_weights > 1 || count_ops_consuming_bias > 1) { AddMessageF( "Not fusing %s because the weights or bias of the preceding %s is " "consumed by another op", LogName(*binary_op), LogName(*preceding_op)); - return false; + return ::tensorflow::Status::OK(); } } @@ -323,7 +326,7 @@ bool FuseBinaryIntoPrecedingAffine::Run(Model* model, std::size_t op_index) { "Not fusing %s because the output of the preceding %s is consumed by " "another op", LogName(*binary_op), LogName(*preceding_op)); - return false; + return ::tensorflow::Status::OK(); } AddMessageF("Fusing %s into the preceding %s", LogName(*binary_op), @@ -352,7 +355,8 @@ bool FuseBinaryIntoPrecedingAffine::Run(Model* model, std::size_t op_index) { model->EraseArray(old_constant_param_name); } model->operators.erase(binary_it); - return true; + *modified = true; + return ::tensorflow::Status::OK(); } } // namespace toco diff --git a/tensorflow/contrib/lite/toco/graph_transformations/fuse_broadcast_into_following_binary.cc b/tensorflow/contrib/lite/toco/graph_transformations/fuse_broadcast_into_following_binary.cc index 874d8def57..4848867b9a 100644 --- a/tensorflow/contrib/lite/toco/graph_transformations/fuse_broadcast_into_following_binary.cc +++ b/tensorflow/contrib/lite/toco/graph_transformations/fuse_broadcast_into_following_binary.cc @@ -51,19 +51,22 @@ bool IsBroadcastingOp(const Model& model, Operator* op) { // Finds an operation that looks like a broadcast (concat of the same sources // along the last dimension) and drops it by relying on the ability of certain // binary ops to perform an implicit broadcast. -bool FuseBroadcastIntoFollowingBinary::Run(Model* model, std::size_t op_index) { +::tensorflow::Status FuseBroadcastIntoFollowingBinary::Run(Model* model, + std::size_t op_index, + bool* modified) { + *modified = false; const auto binary_it = model->operators.begin() + op_index; auto* binary_op = binary_it->get(); // Test for binary ops of types that we know how to resolve if (binary_op->inputs.size() != 2) { - return false; + return ::tensorflow::Status::OK(); } if (binary_op->type != OperatorType::kAdd && binary_op->type != OperatorType::kMul && binary_op->type != OperatorType::kSub && binary_op->type != OperatorType::kDiv) { - return false; + return ::tensorflow::Status::OK(); } // NOTE: either of these ops may be nullptr if the input array is constant. @@ -78,14 +81,14 @@ bool FuseBroadcastIntoFollowingBinary::Run(Model* model, std::size_t op_index) { if (!is_op_0_broadcast && !is_op_1_broadcast) { // Neither input is a broadcast-looking thing. AddMessageF("Neither input looks broadcasty"); - return false; + return ::tensorflow::Status::OK(); } else if (is_op_0_broadcast && is_op_1_broadcast) { AddMessageF( "Unable to fuse broadcast into %s as both inputs (%s, %s) are " "broadcasts", LogName(*binary_op), op[0] ? LogName(*op[0]) : "(?)", op[1] ? LogName(*op[1]) : "(?)"); - return false; + return ::tensorflow::Status::OK(); } int broadcast_index = is_op_0_broadcast ? 0 : 1; @@ -96,7 +99,8 @@ bool FuseBroadcastIntoFollowingBinary::Run(Model* model, std::size_t op_index) { binary_op->inputs[broadcast_index] = op[broadcast_index]->inputs[0]; // We leave the broadcast op in; it'll get cleaned up if it's not used later. - return true; + *modified = true; + return ::tensorflow::Status::OK(); } } // namespace toco diff --git a/tensorflow/contrib/lite/toco/graph_transformations/graph_transformations.cc b/tensorflow/contrib/lite/toco/graph_transformations/graph_transformations.cc index 6961e23690..8b0bc2d865 100644 --- a/tensorflow/contrib/lite/toco/graph_transformations/graph_transformations.cc +++ b/tensorflow/contrib/lite/toco/graph_transformations/graph_transformations.cc @@ -142,7 +142,7 @@ bool GraphTransformationsPass(int increment, Model* model, for (const auto& transformation : transformations) { CHECK(!changed_now); CHECK(transformation->Messages().empty()); - changed_now = transformation->Run(model, op_index); + CHECK(transformation->Run(model, op_index, &changed_now).ok()); const char* made_a_change_msg = changed_now ? "made a change" : "did NOT make a change"; const int log_level = diff --git a/tensorflow/contrib/lite/toco/graph_transformations/graph_transformations.h b/tensorflow/contrib/lite/toco/graph_transformations/graph_transformations.h index 4d213b3f9c..a89db320ea 100644 --- a/tensorflow/contrib/lite/toco/graph_transformations/graph_transformations.h +++ b/tensorflow/contrib/lite/toco/graph_transformations/graph_transformations.h @@ -27,7 +27,8 @@ namespace toco { class GraphTransformation { public: - virtual bool Run(Model* model, std::size_t op_index) = 0; + virtual ::tensorflow::Status Run(Model* model, std::size_t op_index, + bool* modified) = 0; virtual const char* Name() const = 0; virtual ~GraphTransformation() {} // Returns the list of messages that this graph transformation @@ -104,11 +105,12 @@ class GraphTransformationsSet { void RunGraphTransformations(Model* model, const string& message, const GraphTransformationsSet& transformations); -#define DECLARE_GRAPH_TRANSFORMATION(GTName) \ - class GTName : public GraphTransformation { \ - public: \ - bool Run(Model* model, std::size_t op_index) override; \ - const char* Name() const override { return #GTName; } \ +#define DECLARE_GRAPH_TRANSFORMATION(GTName) \ + class GTName : public GraphTransformation { \ + public: \ + ::tensorflow::Status Run(Model* model, std::size_t op_index, \ + bool* modified) override; \ + const char* Name() const override { return #GTName; } \ }; // List of all graph transformations @@ -200,7 +202,8 @@ DECLARE_GRAPH_TRANSFORMATION(ResolveGatherAttributes) class PropagateDefaultMinMax : public GraphTransformation { public: - bool Run(Model* model, std::size_t op_index) override; + ::tensorflow::Status Run(Model* model, std::size_t op_index, + bool* modified) override; const char* Name() const override { return "PropagateDefaultMinMax"; } bool has_any_ranges_defined() const { return !type_ranges_.empty(); } @@ -218,7 +221,8 @@ class PropagateDefaultMinMax : public GraphTransformation { class RemoveTrivialReshape : public GraphTransformation { public: - bool Run(Model* model, std::size_t op_index) override; + ::tensorflow::Status Run(Model* model, std::size_t op_index, + bool* modified) override; const char* Name() const override { return "RemoveTrivialReshape"; } bool treat_expand_dims_as_trivial() const { return treat_expand_dims_as_trivial_; @@ -233,7 +237,8 @@ class RemoveTrivialReshape : public GraphTransformation { class ResolveConstantFakeQuant : public GraphTransformation { public: - bool Run(Model* model, std::size_t op_index) override; + ::tensorflow::Status Run(Model* model, std::size_t op_index, + bool* modified) override; const char* Name() const override { return "ResolveConstantFakeQuant"; } // True if the num_bits should adjust the final data type. @@ -250,7 +255,8 @@ class ResolveConstantFakeQuant : public GraphTransformation { class EnsureUint8WeightsSafeForFastInt8Kernels : public GraphTransformation { public: - bool Run(Model* model, std::size_t op_index) override; + ::tensorflow::Status Run(Model* model, std::size_t op_index, + bool* modified) override; const char* Name() const override { return "EnsureUint8WeightsSafeForFastInt8Kernels"; } @@ -267,7 +273,8 @@ class EnsureUint8WeightsSafeForFastInt8Kernels : public GraphTransformation { class IdentifyDilatedConv : public GraphTransformation { public: - bool Run(Model* model, std::size_t op_index) override; + ::tensorflow::Status Run(Model* model, std::size_t op_index, + bool* modified) override; const char* Name() const override { return "IdentifyDilatedConv"; } bool identify_depthwise_conv() const { return identify_depthwise_conv_; } void set_identify_depthwise_conv(bool val) { identify_depthwise_conv_ = val; } diff --git a/tensorflow/contrib/lite/toco/graph_transformations/hardcode_min_max.cc b/tensorflow/contrib/lite/toco/graph_transformations/hardcode_min_max.cc index 3114fa93e8..72df53548b 100644 --- a/tensorflow/contrib/lite/toco/graph_transformations/hardcode_min_max.cc +++ b/tensorflow/contrib/lite/toco/graph_transformations/hardcode_min_max.cc @@ -372,7 +372,9 @@ bool HardcodeMinMaxForLstmCell(Model* model, Operator* op) { } } // namespace -bool HardcodeMinMax::Run(Model* model, std::size_t op_index) { +::tensorflow::Status HardcodeMinMax::Run(Model* model, std::size_t op_index, + bool* modified) { + *modified = false; auto it = model->operators.begin() + op_index; auto* op = it->get(); bool changed = false; @@ -467,7 +469,8 @@ bool HardcodeMinMax::Run(Model* model, std::size_t op_index) { if (changed) { AddMessageF("Hardcoded min-max through %s", LogName(*op)); } - return changed; + *modified = changed; + return ::tensorflow::Status::OK(); } } // namespace toco diff --git a/tensorflow/contrib/lite/toco/graph_transformations/identify_dilated_conv.cc b/tensorflow/contrib/lite/toco/graph_transformations/identify_dilated_conv.cc index aac77eb39e..9e4a3005a1 100644 --- a/tensorflow/contrib/lite/toco/graph_transformations/identify_dilated_conv.cc +++ b/tensorflow/contrib/lite/toco/graph_transformations/identify_dilated_conv.cc @@ -168,7 +168,10 @@ bool ResolveDilatedConv(Model* model, Operator* conv_base_op, Operator* stb_op, return true; } -bool IdentifyDilatedConv::Run(Model* model, std::size_t op_index) { +::tensorflow::Status IdentifyDilatedConv::Run(Model* model, + std::size_t op_index, + bool* modified) { + *modified = false; const auto it = model->operators.begin() + op_index; auto* stb_op = it->get(); @@ -176,17 +179,17 @@ bool IdentifyDilatedConv::Run(Model* model, std::size_t op_index) { // *************************************************************************** // SpaceToBatch Op. if (stb_op->type != OperatorType::kSpaceToBatchND) { - return false; + return ::tensorflow::Status::OK(); } if (stb_op->inputs.size() != 3) { - return false; + return ::tensorflow::Status::OK(); } CHECK_EQ(stb_op->outputs.size(), 1); // Extract the dilation factor from Input[1] of SpaceToBatch // TODO(mjmatthews): Support 2D dilation factors. const auto& block_shape_array = model->GetArray(stb_op->inputs[1]); if (!block_shape_array.buffer) { - return false; + return ::tensorflow::Status::OK(); } CHECK_EQ(block_shape_array.shape().dimensions_count(), 1); int dilation_factor = @@ -195,7 +198,7 @@ bool IdentifyDilatedConv::Run(Model* model, std::size_t op_index) { // Expand Op auto* post_stb_op = GetOpWithInput(*model, stb_op->outputs[0]); if (!post_stb_op) { - return false; + return ::tensorflow::Status::OK(); } bool has_expand_op = false; if (post_stb_op->type == OperatorType::kExpandDims) { @@ -229,7 +232,8 @@ bool IdentifyDilatedConv::Run(Model* model, std::size_t op_index) { } } - return changed; + *modified = changed; + return ::tensorflow::Status::OK(); } } // namespace toco diff --git a/tensorflow/contrib/lite/toco/graph_transformations/identify_l2_normalization.cc b/tensorflow/contrib/lite/toco/graph_transformations/identify_l2_normalization.cc index b78efd7fc3..78f60f52fb 100644 --- a/tensorflow/contrib/lite/toco/graph_transformations/identify_l2_normalization.cc +++ b/tensorflow/contrib/lite/toco/graph_transformations/identify_l2_normalization.cc @@ -39,7 +39,10 @@ std::vector>::iterator FindOperator( } } // namespace -bool IdentifyL2Normalization::Run(Model* model, std::size_t op_index) { +::tensorflow::Status IdentifyL2Normalization::Run(Model* model, + std::size_t op_index, + bool* modified) { + *modified = false; const auto div_it = model->operators.begin() + op_index; const auto* div_or_mul_op = div_it->get(); OperatorType expected_op_type_producing_div_or_mul_input; @@ -48,7 +51,7 @@ bool IdentifyL2Normalization::Run(Model* model, std::size_t op_index) { } else if (div_or_mul_op->type == OperatorType::kMul) { expected_op_type_producing_div_or_mul_input = OperatorType::kRsqrt; } else { - return false; + return ::tensorflow::Status::OK(); } CHECK_EQ(div_or_mul_op->inputs.size(), 2); Operator* op_producing_div_or_mul_input[2] = { @@ -58,14 +61,14 @@ bool IdentifyL2Normalization::Run(Model* model, std::size_t op_index) { if (!op_producing_div_or_mul_input[1] || op_producing_div_or_mul_input[1]->type != expected_op_type_producing_div_or_mul_input) { - return false; + return ::tensorflow::Status::OK(); } Operator* sqrt_or_rsqrt_op = op_producing_div_or_mul_input[1]; CHECK_EQ(sqrt_or_rsqrt_op->inputs.size(), 1); Operator* op_producing_sqrt_or_rsqrt_input = GetOpWithOutput(*model, sqrt_or_rsqrt_op->inputs[0]); if (!op_producing_sqrt_or_rsqrt_input) { - return false; + return ::tensorflow::Status::OK(); } // There may be an Add or a Maximum here, adding or clamping to a "small" @@ -105,7 +108,7 @@ bool IdentifyL2Normalization::Run(Model* model, std::size_t op_index) { " because the operator producing the input to the square root, %s," ", does not match the expected pattern", LogName(*op_producing_sqrt_or_rsqrt_input)); - return false; + return ::tensorflow::Status::OK(); } } @@ -116,7 +119,7 @@ bool IdentifyL2Normalization::Run(Model* model, std::size_t op_index) { "Giving up trying to identify L2Normalization subgraph: " "expected Sum op, got %s", LogName(*sum_op)); - return false; + return ::tensorflow::Status::OK(); } Operator* square_op = GetOpWithOutput(*model, sum_op->inputs[0]); @@ -125,7 +128,7 @@ bool IdentifyL2Normalization::Run(Model* model, std::size_t op_index) { "Giving up trying to identify L2Normalization subgraph: " "expected Square op, got %s", LogName(*square_op)); - return false; + return ::tensorflow::Status::OK(); } CHECK_EQ(square_op->inputs.size(), 1); @@ -135,7 +138,7 @@ bool IdentifyL2Normalization::Run(Model* model, std::size_t op_index) { "Giving up trying to identify L2Normalization subgraph: %s does not " "take the same input as the Mul/Div node", LogName(*square_op)); - return false; + return ::tensorflow::Status::OK(); } // Create and emplace the new L2Normalization @@ -162,7 +165,8 @@ bool IdentifyL2Normalization::Run(Model* model, std::size_t op_index) { model->operators.erase(FindOperator(model, sqrt_or_rsqrt_op)); model->EraseArray(div_or_mul_op->inputs[1]); model->operators.erase(FindOperator(model, div_or_mul_op)); - return true; + *modified = true; + return ::tensorflow::Status::OK(); } } // namespace toco diff --git a/tensorflow/contrib/lite/toco/graph_transformations/identify_l2_pool.cc b/tensorflow/contrib/lite/toco/graph_transformations/identify_l2_pool.cc index 705e73779b..13664bb344 100644 --- a/tensorflow/contrib/lite/toco/graph_transformations/identify_l2_pool.cc +++ b/tensorflow/contrib/lite/toco/graph_transformations/identify_l2_pool.cc @@ -38,11 +38,13 @@ std::vector>::iterator FindOperator( } } // namespace -bool IdentifyL2Pool::Run(Model* model, std::size_t op_index) { +::tensorflow::Status IdentifyL2Pool::Run(Model* model, std::size_t op_index, + bool* modified) { + *modified = false; const auto sqrt_it = model->operators.begin() + op_index; const auto* sqrt_op = sqrt_it->get(); if (sqrt_op->type != OperatorType::kSqrt) { - return false; + return ::tensorflow::Status::OK(); } CHECK_EQ(sqrt_op->inputs.size(), 1); @@ -56,7 +58,7 @@ bool IdentifyL2Pool::Run(Model* model, std::size_t op_index) { AddMessageF( "Giving up trying to identify L2Pool subgraph: " "expected AveragePool op, but Sqrt op has no preceding op"); - return false; + return ::tensorflow::Status::OK(); } if (prev_to_sqrt_op->type != OperatorType::kAveragePool) { @@ -64,7 +66,7 @@ bool IdentifyL2Pool::Run(Model* model, std::size_t op_index) { "Giving up trying to identify L2Pool subgraph: " "expected AveragePool op, got %s", LogName(*prev_to_sqrt_op)); - return false; + return ::tensorflow::Status::OK(); } avpool_op = static_cast(prev_to_sqrt_op); @@ -77,7 +79,7 @@ bool IdentifyL2Pool::Run(Model* model, std::size_t op_index) { "Giving up trying to identify L2Pool subgraph: " "expected Square op, got %s", LogName(*square_op)); - return false; + return ::tensorflow::Status::OK(); } // Create and emplace L2Pool node. @@ -107,7 +109,8 @@ bool IdentifyL2Pool::Run(Model* model, std::size_t op_index) { model->operators.erase(FindOperator(model, avpool_op)); model->operators.erase(FindOperator(model, sqrt_op)); - return true; + *modified = true; + return ::tensorflow::Status::OK(); } } // namespace toco diff --git a/tensorflow/contrib/lite/toco/graph_transformations/identify_lstm.cc b/tensorflow/contrib/lite/toco/graph_transformations/identify_lstm.cc index c0b014b45e..7fd8f906e2 100644 --- a/tensorflow/contrib/lite/toco/graph_transformations/identify_lstm.cc +++ b/tensorflow/contrib/lite/toco/graph_transformations/identify_lstm.cc @@ -132,7 +132,9 @@ bool MatchOperatorInputs(const Operator& op, const Model& model, } // namespace -bool IdentifyLstmCell::Run(Model* model, std::size_t op_index) { +::tensorflow::Status IdentifyLstmCell::Run(Model* model, std::size_t op_index, + bool* modified) { + *modified = false; // This LSTM cell identification method is not invariant to commutation of // commutative operator inputs. For example, if input[0] and input[1] of the // final output multiplication were swapped, this method would not identify it @@ -143,13 +145,13 @@ bool IdentifyLstmCell::Run(Model* model, std::size_t op_index) { auto op_it = model->operators.begin() + op_index; Operator* final_output_mul = op_it->get(); if (final_output_mul->type != OperatorType::kMul) { - return false; + return ::tensorflow::Status::OK(); } Operator *state_output_tanh, *fc_output_sig; if (!MatchOperatorInputs(*final_output_mul, *model, OperatorType::kTanh, &state_output_tanh, OperatorType::kLogistic, &fc_output_sig)) { - return false; + return ::tensorflow::Status::OK(); } // State output TanH @@ -158,7 +160,7 @@ bool IdentifyLstmCell::Run(Model* model, std::size_t op_index) { Operator* state_combine_add; if (!MatchOperatorInputs(*state_output_tanh, *model, OperatorType::kAdd, &state_combine_add)) { - return false; + return ::tensorflow::Status::OK(); } // State forget & remember addition @@ -166,7 +168,7 @@ bool IdentifyLstmCell::Run(Model* model, std::size_t op_index) { if (!MatchOperatorInputs(*state_combine_add, *model, OperatorType::kMul, &state_forget_mul, OperatorType::kMul, &state_remember_mul)) { - return false; + return ::tensorflow::Status::OK(); } const string prev_state = state_forget_mul->inputs[0]; @@ -175,7 +177,7 @@ bool IdentifyLstmCell::Run(Model* model, std::size_t op_index) { if (!MatchOperatorInputs(*state_forget_mul, *model, OperatorType::kNone, nullptr, OperatorType::kLogistic, &state_forget_sig)) { - return false; + return ::tensorflow::Status::OK(); } // State remember gate @@ -183,40 +185,40 @@ bool IdentifyLstmCell::Run(Model* model, std::size_t op_index) { if (!MatchOperatorInputs(*state_remember_mul, *model, OperatorType::kLogistic, &state_remember_sig, OperatorType::kTanh, &state_info_tanh)) { - return false; + return ::tensorflow::Status::OK(); } // State remember "information" activation function Operator* fc_output_split; if (!MatchOperatorInputs(*state_info_tanh, *model, OperatorType::kSplit, &fc_output_split)) { - return false; + return ::tensorflow::Status::OK(); } // State remember gate activation function Operator* tmp; if (!MatchOperatorInputs(*state_remember_sig, *model, OperatorType::kSplit, &tmp) || (tmp != fc_output_split)) { - return false; + return ::tensorflow::Status::OK(); } // State forget gate activation function if (!MatchOperatorInputs(*state_forget_sig, *model, OperatorType::kSplit, &tmp) || (tmp != fc_output_split)) { - return false; + return ::tensorflow::Status::OK(); } // Fully connected output activation function if (!MatchOperatorInputs(*fc_output_sig, *model, OperatorType::kSplit, &tmp) || (tmp != fc_output_split)) { - return false; + return ::tensorflow::Status::OK(); } // Fully connected output split Operator* fully_connected; if (!MatchOperatorInputs(*fc_output_split, *model, OperatorType::kNone, nullptr, OperatorType::kFullyConnected, &fully_connected)) { - return false; + return ::tensorflow::Status::OK(); } // Fully connected op @@ -225,13 +227,13 @@ bool IdentifyLstmCell::Run(Model* model, std::size_t op_index) { OperatorType::kConcatenation, &concat_inputs, OperatorType::kNone, nullptr, OperatorType::kNone, nullptr)) { - return false; + return ::tensorflow::Status::OK(); } if (static_cast(fully_connected)->weights_format != FullyConnectedWeightsFormat::kDefault) { // Not yet implemented: experimental shuffled weights in fused LSTM cell. - return false; + return ::tensorflow::Status::OK(); } // Emplace a new LSTM cell operator @@ -300,7 +302,8 @@ bool IdentifyLstmCell::Run(Model* model, std::size_t op_index) { model->operators.erase(FindOperator(model, *fully_connected)); DeleteArrayIfUnused(concat_inputs->outputs[0], model); model->operators.erase(FindOperator(model, *concat_inputs)); - return true; + *modified = true; + return ::tensorflow::Status::OK(); } } // namespace toco diff --git a/tensorflow/contrib/lite/toco/graph_transformations/identify_lstm_merge_inputs.cc b/tensorflow/contrib/lite/toco/graph_transformations/identify_lstm_merge_inputs.cc index 5b6a984ee1..6ccce923f3 100644 --- a/tensorflow/contrib/lite/toco/graph_transformations/identify_lstm_merge_inputs.cc +++ b/tensorflow/contrib/lite/toco/graph_transformations/identify_lstm_merge_inputs.cc @@ -25,19 +25,22 @@ limitations under the License. namespace toco { -bool MergeLstmCellInputs::Run(Model* model, std::size_t op_index) { +::tensorflow::Status MergeLstmCellInputs::Run(Model* model, + std::size_t op_index, + bool* modified) { + *modified = false; // Find lstm cell. auto op_it = model->operators.begin() + op_index; auto src_op = op_it->get(); if (src_op->type != OperatorType::kLstmCell) { - return false; + return ::tensorflow::Status::OK(); } // Already a compact LstmCell. Do not need to merge cell inputs. const auto* src_lstm_op = static_cast(src_op); if (src_lstm_op->kernel_type != LstmCellOperator::KERNEL_FULL || src_lstm_op->inputs.size() != kExtendedLstmInputCount) { - return false; + return ::tensorflow::Status::OK(); } // Identify prev_activ_input, prev_state_input as required Op inputs, @@ -45,12 +48,12 @@ bool MergeLstmCellInputs::Run(Model* model, std::size_t op_index) { string prev_activ_input; if (!GetMatchingRnnArray(model, src_op->outputs[kOutputTensor], &prev_activ_input)) { - return false; + return ::tensorflow::Status::OK(); } string prev_state_input; if (!GetMatchingRnnArray(model, src_op->outputs[kCellStateTensor], &prev_state_input)) { - return false; + return ::tensorflow::Status::OK(); } // Get LstmCell's cell, input, output size. @@ -184,7 +187,8 @@ bool MergeLstmCellInputs::Run(Model* model, std::size_t op_index) { DeleteArrayIfUnused(src_op->inputs[kOutputGateBiasTensor], model); model->operators.erase(FindOp(*model, src_op)); - return true; + *modified = true; + return ::tensorflow::Status::OK(); } } // namespace toco diff --git a/tensorflow/contrib/lite/toco/graph_transformations/identify_lstm_split_inputs.cc b/tensorflow/contrib/lite/toco/graph_transformations/identify_lstm_split_inputs.cc index 46d1fce50e..ad5120e2aa 100644 --- a/tensorflow/contrib/lite/toco/graph_transformations/identify_lstm_split_inputs.cc +++ b/tensorflow/contrib/lite/toco/graph_transformations/identify_lstm_split_inputs.cc @@ -25,19 +25,22 @@ limitations under the License. namespace toco { -bool SplitLstmCellInputs::Run(Model* model, std::size_t op_index) { +::tensorflow::Status SplitLstmCellInputs::Run(Model* model, + std::size_t op_index, + bool* modified) { + *modified = false; // Find lstm cell. auto op_it = model->operators.begin() + op_index; auto curr_op = op_it->get(); if (curr_op->type != OperatorType::kLstmCell) { - return false; + return ::tensorflow::Status::OK(); } const auto* curr_lstm_op = static_cast(curr_op); // Already an extended LstmCell. Do not need to split cell inputs. if (curr_lstm_op->kernel_type != LstmCellOperator::KERNEL_BASIC || curr_lstm_op->inputs.size() != LstmCellOperator::NUM_INPUTS) { - return false; + return ::tensorflow::Status::OK(); } // Make sure the WEIGHTS_INPUT and BIASES_INPUT are constant arrays, @@ -46,13 +49,13 @@ bool SplitLstmCellInputs::Run(Model* model, std::size_t op_index) { *model, curr_op->inputs[LstmCellOperator::WEIGHTS_INPUT]) || !IsConstantParameterArray( *model, curr_op->inputs[LstmCellOperator::BIASES_INPUT])) { - return false; + return ::tensorflow::Status::OK(); } // Make sure propagate_fixed_sizes has defined the size of the output. if (!model->GetArray(curr_op->outputs[LstmCellOperator::ACTIV_OUTPUT]) .has_shape()) { - return false; + return ::tensorflow::Status::OK(); } // Emplace a new LstmCell operator with extended inputs (kernel/lstm.cc). @@ -168,7 +171,8 @@ bool SplitLstmCellInputs::Run(Model* model, std::size_t op_index) { DeleteArrayIfUnused(curr_op->inputs[LstmCellOperator::BIASES_INPUT], model); model->operators.erase(FindOp(*model, curr_op)); - return true; + *modified = true; + return ::tensorflow::Status::OK(); } } // namespace toco diff --git a/tensorflow/contrib/lite/toco/graph_transformations/identify_prelu.cc b/tensorflow/contrib/lite/toco/graph_transformations/identify_prelu.cc index b90a156a0d..c11fee4dc9 100644 --- a/tensorflow/contrib/lite/toco/graph_transformations/identify_prelu.cc +++ b/tensorflow/contrib/lite/toco/graph_transformations/identify_prelu.cc @@ -43,13 +43,15 @@ limitations under the License. namespace toco { -bool IdentifyPRelu::Run(Model* model, std::size_t op_index) { +::tensorflow::Status IdentifyPRelu::Run(Model* model, std::size_t op_index, + bool* modified) { + *modified = false; const auto add_op_it = model->operators.begin() + op_index; const auto* add_op = add_op_it->get(); if (add_op == nullptr || add_op->type != OperatorType::kAdd || add_op->inputs.size() != 2 || add_op->fused_activation_function != FusedActivationFunctionType::kNone) { - return false; + return ::tensorflow::Status::OK(); } const auto* relu_input_op = GetOpWithOutput(*model, add_op->inputs[0]); @@ -57,7 +59,7 @@ bool IdentifyPRelu::Run(Model* model, std::size_t op_index) { relu_input_op->inputs.size() != 1 || relu_input_op->fused_activation_function != FusedActivationFunctionType::kNone) { - return false; + return ::tensorflow::Status::OK(); } // TODO(ycling): Both Add and Mul are commutative. Support the case where @@ -66,7 +68,7 @@ bool IdentifyPRelu::Run(Model* model, std::size_t op_index) { if (mul_op == nullptr || mul_op->type != OperatorType::kMul || mul_op->inputs.size() != 2 || mul_op->fused_activation_function != FusedActivationFunctionType::kNone) { - return false; + return ::tensorflow::Status::OK(); } const auto neg_alpha_tensor_name = mul_op->inputs[0]; @@ -75,7 +77,7 @@ bool IdentifyPRelu::Run(Model* model, std::size_t op_index) { if (relu_neg_input_op == nullptr || relu_neg_input_op->inputs.size() != 1) { - return false; + return ::tensorflow::Status::OK(); } const Operator* final_input_op; @@ -92,13 +94,13 @@ bool IdentifyPRelu::Run(Model* model, std::size_t op_index) { relu_neg_input_op->type != OperatorType::kRelu || relu_neg_input_op->fused_activation_function != FusedActivationFunctionType::kNone) { - return false; + return ::tensorflow::Status::OK(); } final_input_op = neg_input_op; } if (relu_input_op->inputs[0] != final_input_op->inputs[0]) { - return false; + return ::tensorflow::Status::OK(); } const auto input_tensor_name = relu_input_op->inputs[0]; @@ -128,7 +130,8 @@ bool IdentifyPRelu::Run(Model* model, std::size_t op_index) { // intermediate tensors aren't used by other ops, those will be removed by // other graph transformation rules. model->operators.erase(FindOp(*model, add_op)); - return true; + *modified = true; + return ::tensorflow::Status::OK(); } } // namespace toco diff --git a/tensorflow/contrib/lite/toco/graph_transformations/identify_relu1.cc b/tensorflow/contrib/lite/toco/graph_transformations/identify_relu1.cc index 94820a0166..51d0629362 100644 --- a/tensorflow/contrib/lite/toco/graph_transformations/identify_relu1.cc +++ b/tensorflow/contrib/lite/toco/graph_transformations/identify_relu1.cc @@ -56,13 +56,15 @@ int GetSingleScalarInputIndexOfBinaryOp(Model* model, const Operator* op, } } // namespace -bool IdentifyRelu1::Run(Model* model, std::size_t op_index) { +::tensorflow::Status IdentifyRelu1::Run(Model* model, std::size_t op_index, + bool* modified) { + *modified = false; // Follow sequences of min+max and max+min. First get the leading op. const auto op_it = model->operators.begin() + op_index; const auto* op_0 = op_it->get(); if (op_0->type != OperatorType::kMinimum && op_0->type != OperatorType::kMaximum) { - return false; + return ::tensorflow::Status::OK(); } // Get the paired op and ensure it's the counter to the first. @@ -71,17 +73,17 @@ bool IdentifyRelu1::Run(Model* model, std::size_t op_index) { (op_1->type != OperatorType::kMinimum && op_1->type != OperatorType::kMaximum) || op_0->type == op_1->type) { - return false; + return ::tensorflow::Status::OK(); } const auto* min_op = op_0->type == OperatorType::kMinimum ? op_0 : op_1; const auto* max_op = op_0->type == OperatorType::kMaximum ? op_0 : op_1; if (min_op->inputs.size() != 2 || max_op->inputs.size() != 2) { - return false; + return ::tensorflow::Status::OK(); } if (min_op->outputs.size() != 1 || max_op->outputs.size() != 1) { - return false; + return ::tensorflow::Status::OK(); } // Get the original input to the min+max pair. @@ -90,7 +92,7 @@ bool IdentifyRelu1::Run(Model* model, std::size_t op_index) { int max_scalar_input_index = GetSingleScalarInputIndexOfBinaryOp(model, max_op, -1.0f); if (min_scalar_input_index == -1 || max_scalar_input_index == -1) { - return false; + return ::tensorflow::Status::OK(); } int op_0_scalar_input_index = op_0 == min_op ? min_scalar_input_index : max_scalar_input_index; @@ -111,7 +113,8 @@ bool IdentifyRelu1::Run(Model* model, std::size_t op_index) { model->operators.erase(FindOperator(model, op_0)); model->operators.erase(FindOperator(model, op_1)); - return true; + *modified = true; + return ::tensorflow::Status::OK(); } } // namespace toco diff --git a/tensorflow/contrib/lite/toco/graph_transformations/make_initial_dequantize_operator.cc b/tensorflow/contrib/lite/toco/graph_transformations/make_initial_dequantize_operator.cc index f684de08ab..5bf17d5b4c 100644 --- a/tensorflow/contrib/lite/toco/graph_transformations/make_initial_dequantize_operator.cc +++ b/tensorflow/contrib/lite/toco/graph_transformations/make_initial_dequantize_operator.cc @@ -97,7 +97,10 @@ bool AddDequantizeOperatorToInput(const string& input_name, const Operator* op, return true; } -bool MakeInitialDequantizeOperator::Run(Model* model, std::size_t op_index) { +::tensorflow::Status MakeInitialDequantizeOperator::Run(Model* model, + std::size_t op_index, + bool* modified) { + *modified = false; // This is effectively a transformation applied to edges. We iterate over the // specified node (op) and proceed for input edges. const auto it = model->operators.begin() + op_index; @@ -114,7 +117,8 @@ bool MakeInitialDequantizeOperator::Run(Model* model, std::size_t op_index) { } } } - return change_made; + *modified = change_made; + return ::tensorflow::Status::OK(); } } // namespace toco diff --git a/tensorflow/contrib/lite/toco/graph_transformations/merge_reshape_into_preceding_transpose.cc b/tensorflow/contrib/lite/toco/graph_transformations/merge_reshape_into_preceding_transpose.cc index 95bc7f7d4b..06de9b1cd8 100644 --- a/tensorflow/contrib/lite/toco/graph_transformations/merge_reshape_into_preceding_transpose.cc +++ b/tensorflow/contrib/lite/toco/graph_transformations/merge_reshape_into_preceding_transpose.cc @@ -102,18 +102,19 @@ std::vector ReshapeToTranspose(const Model& model, // to be merged if the reshape does not affect memory ordering and does not // affects the number of dimensions. This only occurs when only unary dimensions // are shifting position. -bool MergeReshapeIntoPrecedingTranspose::Run(Model* model, - std::size_t op_index) { +::tensorflow::Status MergeReshapeIntoPrecedingTranspose::Run( + Model* model, std::size_t op_index, bool* modified) { + *modified = false; auto it = model->operators.begin() + op_index; auto* reshape_op = ConvertOperator( it->get(), OperatorType::kReshape); if (reshape_op == nullptr) { - return false; + return ::tensorflow::Status::OK(); } if (!OperatorReady(*model, reshape_op) || reshape_op->shape.empty()) { - return false; + return ::tensorflow::Status::OK(); } const string intermediate_name = reshape_op->inputs[0]; @@ -121,13 +122,13 @@ bool MergeReshapeIntoPrecedingTranspose::Run(Model* model, // Guarantee the input is only consume by the reshape. if (CountOpsWithInput(*model, intermediate_name) != 1) { - return false; + return ::tensorflow::Status::OK(); } // Check for the parent operator. const auto& transpose_it = FindOpWithOutput(*model, intermediate_name); if (transpose_it == model->operators.end()) { - return false; + return ::tensorflow::Status::OK(); } // Find the parent operator and guarantee it is a transpose. @@ -135,16 +136,16 @@ bool MergeReshapeIntoPrecedingTranspose::Run(Model* model, transpose_it->get(), OperatorType::kTranspose); if (transpose_op == nullptr) { - return false; + return ::tensorflow::Status::OK(); } if (!OperatorReady(*model, transpose_op) || transpose_op->perm.empty()) { - return false; + return ::tensorflow::Status::OK(); } if (!ReshapeIsEquivalentToTranspose(*model, reshape_op, false /*allow_extra_unary_dimensions*/)) { - return false; + return ::tensorflow::Status::OK(); } // Check that the intermediate is not an output array. @@ -153,7 +154,7 @@ bool MergeReshapeIntoPrecedingTranspose::Run(Model* model, "Cannot fuse %s and %s as it would invalidate the transpose " "output array.", LogName(*transpose_op), LogName(*reshape_op)); - return false; + return ::tensorflow::Status::OK(); } AddMessageF("Merging operations %s and %s", LogName(*transpose_op), @@ -172,7 +173,7 @@ bool MergeReshapeIntoPrecedingTranspose::Run(Model* model, // Remove the reshape as passthrough operation. if (!RemoveTrivialPassthroughOp(this, model, op_index)) { - return false; + return ::tensorflow::Status::OK(); } // Update transpose_op's constant buffer to contain the new permutation. @@ -184,7 +185,8 @@ bool MergeReshapeIntoPrecedingTranspose::Run(Model* model, // transpose_ops's shape will likely has changed. model->GetArray(transpose_op->outputs[0]).clear_shape(); - return true; + *modified = true; + return ::tensorflow::Status::OK(); } } // namespace toco diff --git a/tensorflow/contrib/lite/toco/graph_transformations/move_binary_operator_before_reshape.cc b/tensorflow/contrib/lite/toco/graph_transformations/move_binary_operator_before_reshape.cc index 7f44c65285..f0d8d924ad 100644 --- a/tensorflow/contrib/lite/toco/graph_transformations/move_binary_operator_before_reshape.cc +++ b/tensorflow/contrib/lite/toco/graph_transformations/move_binary_operator_before_reshape.cc @@ -54,7 +54,10 @@ bool IsTailOfShape(const Shape& tail, const Shape& shape) { // // Note we are testing for one particular case of a broader set of possible // binary-reshape op transformations. This transformation could be generalized. -bool MoveBinaryOperatorBeforeReshape::Run(Model* model, std::size_t op_index) { +::tensorflow::Status MoveBinaryOperatorBeforeReshape::Run(Model* model, + std::size_t op_index, + bool* modified) { + *modified = false; const auto binary_it = model->operators.begin() + op_index; Operator* binary_op = binary_it->get(); if (binary_op->type != OperatorType::kAdd && @@ -69,7 +72,7 @@ bool MoveBinaryOperatorBeforeReshape::Run(Model* model, std::size_t op_index) { binary_op->type != OperatorType::kLessEqual && binary_op->type != OperatorType::kGreater && binary_op->type != OperatorType::kGreaterEqual) { - return false; + return ::tensorflow::Status::OK(); } // BINARY OP INPUT CHECKS @@ -81,11 +84,11 @@ bool MoveBinaryOperatorBeforeReshape::Run(Model* model, std::size_t op_index) { if (!input_is_const[0] && !input_is_const[1]) { // To limit our scope, we require one constant input. Though there's no // reason this transformation wouldn't work with all variable inputs. - return false; + return ::tensorflow::Status::OK(); } if (input_is_const[0] && input_is_const[1]) { // Both inputs are constants. Leave this for constants propagation. - return false; + return ::tensorflow::Status::OK(); } const int constant_input_idx = input_is_const[0] ? 0 : 1; const int variable_input_idx = input_is_const[0] ? 1 : 0; @@ -98,13 +101,13 @@ bool MoveBinaryOperatorBeforeReshape::Run(Model* model, std::size_t op_index) { AddMessageF( "Not moving %s because it's non-constant input shape is not resolved.", LogName(*binary_op)); - return false; + return ::tensorflow::Status::OK(); } if (!IsTailOfShape( model->GetArray(binary_op->inputs[constant_input_idx]).shape(), model->GetArray(binary_op->inputs[variable_input_idx]).shape())) { // Constant array shape must be the latter part of the variable shape. - return false; + return ::tensorflow::Status::OK(); } // RESHAPE OP CHECKS @@ -113,13 +116,13 @@ bool MoveBinaryOperatorBeforeReshape::Run(Model* model, std::size_t op_index) { if (reshape_it == model->operators.end()) { AddMessageF("Not moving %s because it's variable input is not connected.", LogName(*binary_op)); - return false; + return ::tensorflow::Status::OK(); } Operator* reshape_op = reshape_it->get(); if (reshape_op->type != OperatorType::kReshape) { AddMessageF("Not moving %s because the preceding %s is not a reshape op", LogName(*binary_op), LogName(*reshape_op)); - return false; + return ::tensorflow::Status::OK(); } const auto& reshape_input_array = model->GetArray(reshape_op->inputs[0]); if (!reshape_input_array.has_shape()) { @@ -127,14 +130,14 @@ bool MoveBinaryOperatorBeforeReshape::Run(Model* model, std::size_t op_index) { "Not moving %s because it's non-constant input shape is not resolved " "yet", LogName(*binary_op)); - return false; + return ::tensorflow::Status::OK(); } if (!IsTailOfShape( model->GetArray(binary_op->inputs[constant_input_idx]).shape(), model->GetArray(reshape_op->outputs[0]).shape())) { // Constant array shape must be the latter part of the binary op output // shape. - return false; + return ::tensorflow::Status::OK(); } // EXTRA CHECKS ON CONNECTING ARRAY @@ -143,7 +146,7 @@ bool MoveBinaryOperatorBeforeReshape::Run(Model* model, std::size_t op_index) { AddMessageF( "Not moving %s because the output of reshape op %s is an output op.", LogName(*binary_op), LogName(*reshape_op)); - return false; + return ::tensorflow::Status::OK(); } } int count_ops_consuming_output = @@ -154,7 +157,7 @@ bool MoveBinaryOperatorBeforeReshape::Run(Model* model, std::size_t op_index) { "Not moving %s because the output of reshape op %s is consumed by " "another op", LogName(*binary_op), LogName(*reshape_op)); - return false; + return ::tensorflow::Status::OK(); } // SWAP ORDER OF BINARY AND RESHAPE OPS @@ -172,7 +175,8 @@ bool MoveBinaryOperatorBeforeReshape::Run(Model* model, std::size_t op_index) { // Clear binary output shape so it will be re-propagated model->GetArray(binary_op->outputs[0]).clear_shape(); - return true; + *modified = true; + return ::tensorflow::Status::OK(); } } // namespace toco diff --git a/tensorflow/contrib/lite/toco/graph_transformations/propagate_activation_function_into_constants.cc b/tensorflow/contrib/lite/toco/graph_transformations/propagate_activation_function_into_constants.cc index cf17c49b10..9c1ed2b732 100644 --- a/tensorflow/contrib/lite/toco/graph_transformations/propagate_activation_function_into_constants.cc +++ b/tensorflow/contrib/lite/toco/graph_transformations/propagate_activation_function_into_constants.cc @@ -26,20 +26,21 @@ limitations under the License. namespace toco { -bool PropagateActivationFunctionIntoConstants::Run(Model* model, - std::size_t op_index) { +::tensorflow::Status PropagateActivationFunctionIntoConstants::Run( + Model* model, std::size_t op_index, bool* modified) { + *modified = false; const auto ac_it = model->operators.begin() + op_index; const auto* ac_op = ac_it->get(); if (ac_op->type != OperatorType::kRelu6 && ac_op->type != OperatorType::kRelu1 && ac_op->type != OperatorType::kRelu) { - return false; + return ::tensorflow::Status::OK(); } // Find the op producing the array passed to this activation function. auto* src_op = GetOpWithOutput(*model, ac_op->inputs[0]); if (!src_op) { - return false; + return ::tensorflow::Status::OK(); } // Ensure the src_op is not used without the activation function applied. @@ -57,7 +58,7 @@ bool PropagateActivationFunctionIntoConstants::Run(Model* model, src_op_input = src_op->inputs[0]; break; default: - return false; + return ::tensorflow::Status::OK(); } CHECK_EQ(src_op->outputs[0], ac_op->inputs[0]); @@ -69,7 +70,7 @@ bool PropagateActivationFunctionIntoConstants::Run(Model* model, "Not propagating activation function %s into %s:%s because it is not " "constant", LogName(*ac_op), LogName(*src_op), src_op_input); - return false; + return ::tensorflow::Status::OK(); } // Get the array we'll be working with and ensure it's a compatible type. @@ -79,7 +80,7 @@ bool PropagateActivationFunctionIntoConstants::Run(Model* model, "Not propagating activation function %s into %s:%s because it is " "non-float data", LogName(*ac_op), LogName(*src_op), src_op_input); - return false; + return ::tensorflow::Status::OK(); } auto& const_array_data = const_array.GetMutableBuffer().data; @@ -108,14 +109,15 @@ bool PropagateActivationFunctionIntoConstants::Run(Model* model, } default: LOG(FATAL) << "Unsupported activation function " << LogName(*ac_op); - return false; + return ::tensorflow::Status::OK(); } const_array_data[i] = new_value; } AddMessageF("Propagated activation function %s into %s:%s", LogName(*ac_op), LogName(*src_op), src_op_input); - return RemoveTrivialPassthroughOp(this, model, op_index); + *modified = RemoveTrivialPassthroughOp(this, model, op_index); + return ::tensorflow::Status::OK(); } } // namespace toco diff --git a/tensorflow/contrib/lite/toco/graph_transformations/propagate_array_data_types.cc b/tensorflow/contrib/lite/toco/graph_transformations/propagate_array_data_types.cc index 323eefcd3a..40cd6dea82 100644 --- a/tensorflow/contrib/lite/toco/graph_transformations/propagate_array_data_types.cc +++ b/tensorflow/contrib/lite/toco/graph_transformations/propagate_array_data_types.cc @@ -32,7 +32,10 @@ void SetDataTypeForAllOutputs(Model* model, Operator* op, } } // namespace -bool PropagateArrayDataTypes::Run(Model* model, std::size_t op_index) { +::tensorflow::Status PropagateArrayDataTypes::Run(Model* model, + std::size_t op_index, + bool* modified) { + *modified = false; auto it = model->operators.begin() + op_index; auto* op = it->get(); @@ -40,7 +43,7 @@ bool PropagateArrayDataTypes::Run(Model* model, std::size_t op_index) { for (const auto& input : op->inputs) { if (!model->IsOptionalArray(input) && model->GetArray(input).data_type == ArrayDataType::kNone) { - return false; + return ::tensorflow::Status::OK(); } } // Record data types of output before processing, so we can see at the @@ -131,7 +134,7 @@ bool PropagateArrayDataTypes::Run(Model* model, std::size_t op_index) { auto* rand_op = static_cast(op); // The output type of RandomUniform is specified with an attribute if (rand_op->dtype == ArrayDataType::kNone) { - return false; + return ::tensorflow::Status::OK(); } CHECK_EQ(op->outputs.size(), 1); SetDataTypeForAllOutputs(model, op, rand_op->dtype); @@ -153,7 +156,7 @@ bool PropagateArrayDataTypes::Run(Model* model, std::size_t op_index) { // This can make unsupported_op->output_data_types have more elements than // op->outputs. if (unsupported_op->output_data_types.size() < op->outputs.size()) { - return false; + return ::tensorflow::Status::OK(); } for (int i = 0; i < op->outputs.size(); ++i) { const string& output = op->outputs[i]; @@ -164,7 +167,7 @@ bool PropagateArrayDataTypes::Run(Model* model, std::size_t op_index) { } case OperatorType::kExpandDims: { // Yield on ExpandDim until it is converted to Reshape - return false; + return ::tensorflow::Status::OK(); } case OperatorType::kSelect: { // Select produces outputs with the same type as their 2nd input @@ -248,10 +251,11 @@ bool PropagateArrayDataTypes::Run(Model* model, std::size_t op_index) { // Return true if any output data type changed, false if none changed. for (const auto& output : op->outputs) { if (old_output_data_types[output] != model->GetArray(output).data_type) { - return true; + *modified = true; + return ::tensorflow::Status::OK(); } } - return false; + return ::tensorflow::Status::OK(); } } // namespace toco diff --git a/tensorflow/contrib/lite/toco/graph_transformations/propagate_default_min_max.cc b/tensorflow/contrib/lite/toco/graph_transformations/propagate_default_min_max.cc index cd078ef189..3cf191436d 100644 --- a/tensorflow/contrib/lite/toco/graph_transformations/propagate_default_min_max.cc +++ b/tensorflow/contrib/lite/toco/graph_transformations/propagate_default_min_max.cc @@ -39,7 +39,10 @@ bool SupportsMinMax(const Array& array) { // When provided a set of min/max values for uint8 arrays this will rescale // the values for other data types as required and preserving the floating point // range within the new type. -bool PropagateDefaultMinMax::Run(Model* model, std::size_t op_index) { +::tensorflow::Status PropagateDefaultMinMax::Run(Model* model, + std::size_t op_index, + bool* modified) { + *modified = false; const auto it = model->operators.begin() + op_index; const auto* op = it->get(); @@ -61,7 +64,8 @@ bool PropagateDefaultMinMax::Run(Model* model, std::size_t op_index) { } } - return did_change; + *modified = did_change; + return ::tensorflow::Status::OK(); } // Sets the min/max on the given array, adjusting the reference_minmax for the diff --git a/tensorflow/contrib/lite/toco/graph_transformations/propagate_fake_quant_num_bits.cc b/tensorflow/contrib/lite/toco/graph_transformations/propagate_fake_quant_num_bits.cc index 3ad6b0ec6f..d0113237ce 100644 --- a/tensorflow/contrib/lite/toco/graph_transformations/propagate_fake_quant_num_bits.cc +++ b/tensorflow/contrib/lite/toco/graph_transformations/propagate_fake_quant_num_bits.cc @@ -277,11 +277,14 @@ bool RecursivelyForwardPropagateDataType(GraphTransformation* transformation, // nice logging and integration with the graphviz video dumping mode. // In general you should not copy this style of transformation and stick to // local-only changes as seen in the other transformations. -bool PropagateFakeQuantNumBits::Run(Model* model, std::size_t op_index) { +::tensorflow::Status PropagateFakeQuantNumBits::Run(Model* model, + std::size_t op_index, + bool* modified) { + *modified = false; auto it = model->operators.begin() + op_index; auto* op = it->get(); if (op->type != OperatorType::kFakeQuant) { - return false; + return ::tensorflow::Status::OK(); } auto* fakequant_op = static_cast(op); @@ -290,7 +293,7 @@ bool PropagateFakeQuantNumBits::Run(Model* model, std::size_t op_index) { &quantized_data_type)) { AddMessageF("FakeQuant op %s num_bits=%d is out of range, ignoring", LogName(*op), fakequant_op->num_bits); - return false; + return ::tensorflow::Status::OK(); } const auto& final_minmax = *fakequant_op->minmax; @@ -311,7 +314,8 @@ bool PropagateFakeQuantNumBits::Run(Model* model, std::size_t op_index) { did_change |= RecursivelyForwardPropagateDataType(this, model, op, quantized_data_type); - return did_change; + *modified = did_change; + return ::tensorflow::Status::OK(); } } // namespace toco diff --git a/tensorflow/contrib/lite/toco/graph_transformations/propagate_fixed_sizes.cc b/tensorflow/contrib/lite/toco/graph_transformations/propagate_fixed_sizes.cc index d056a8add7..5496e2093e 100644 --- a/tensorflow/contrib/lite/toco/graph_transformations/propagate_fixed_sizes.cc +++ b/tensorflow/contrib/lite/toco/graph_transformations/propagate_fixed_sizes.cc @@ -1622,7 +1622,10 @@ void ProcessUnpackOperator(Model* model, UnpackOperator* op) { } // namespace -bool PropagateFixedSizes::Run(Model* model, std::size_t op_index) { +::tensorflow::Status PropagateFixedSizes::Run(Model* model, + std::size_t op_index, + bool* modified) { + *modified = false; auto it = model->operators.begin() + op_index; auto* op = it->get(); std::unordered_map> old_output_dims; @@ -1836,7 +1839,7 @@ bool PropagateFixedSizes::Run(Model* model, std::size_t op_index) { static_cast(op); // Attribute can be not specified, ignore it. if (unsupported_op->output_shapes.size() < op->outputs.size()) { - return false; + return ::tensorflow::Status::OK(); } for (int i = 0; i < op->outputs.size(); ++i) { const string& output = op->outputs[i]; @@ -1886,10 +1889,11 @@ bool PropagateFixedSizes::Run(Model* model, std::size_t op_index) { (old_output_dims[output] != model->GetArray(output).shape().dims())) { AddMessageF("Set shape of %s to [%s]", output, absl::StrJoin(model->GetArray(output).shape().dims(), ",")); - return true; + *modified = true; + return ::tensorflow::Status::OK(); } } - return false; + return ::tensorflow::Status::OK(); } } // namespace toco diff --git a/tensorflow/contrib/lite/toco/graph_transformations/quantize.cc b/tensorflow/contrib/lite/toco/graph_transformations/quantize.cc index fb299c31b7..29ea17dc61 100644 --- a/tensorflow/contrib/lite/toco/graph_transformations/quantize.cc +++ b/tensorflow/contrib/lite/toco/graph_transformations/quantize.cc @@ -439,7 +439,9 @@ void FixMinMaxPostQuantization(GraphTransformation* transformation, } // namespace -bool Quantize::Run(Model* model, std::size_t op_index) { +::tensorflow::Status Quantize::Run(Model* model, std::size_t op_index, + bool* modified) { + *modified = false; // Our general "quantization" graph transformation consists in replacing // QuantizedInputArrays[] -> // DequantizeOperators[] -> @@ -460,7 +462,7 @@ bool Quantize::Run(Model* model, std::size_t op_index) { auto& op = *model->operators[op_index]; if (op.type == OperatorType::kDequantize || op.type == OperatorType::kFakeQuant) { - return false; + return ::tensorflow::Status::OK(); } // Our assumption here is that the input arrays are already quantized - @@ -497,7 +499,7 @@ bool Quantize::Run(Model* model, std::size_t op_index) { if (!array.minmax && !array.buffer) { LOG(ERROR) << "Can't quantize input array " << input << " because it lacks min/max info"; - return false; + return ::tensorflow::Status::OK(); } const auto* other_op = GetOpWithOutput(*model, input); if (other_op && other_op->type != OperatorType::kDequantize) { @@ -507,7 +509,7 @@ bool Quantize::Run(Model* model, std::size_t op_index) { "which means that we should yield and let other ops " "get quantized first", LogName(op), input); - return false; + return ::tensorflow::Status::OK(); } } } @@ -672,7 +674,8 @@ bool Quantize::Run(Model* model, std::size_t op_index) { } } - return changed; + *modified = changed; + return ::tensorflow::Status::OK(); } } // namespace toco diff --git a/tensorflow/contrib/lite/toco/graph_transformations/read_array_minmax_and_narrow_range_from_fake_quant.cc b/tensorflow/contrib/lite/toco/graph_transformations/read_array_minmax_and_narrow_range_from_fake_quant.cc index eaa9d3bcda..0c32218ff2 100644 --- a/tensorflow/contrib/lite/toco/graph_transformations/read_array_minmax_and_narrow_range_from_fake_quant.cc +++ b/tensorflow/contrib/lite/toco/graph_transformations/read_array_minmax_and_narrow_range_from_fake_quant.cc @@ -51,18 +51,19 @@ bool ApplyAttrsToArray(GraphTransformation* transformation, Model* model, } // end namespace -bool ReadArrayMinmaxAndNarrowRangeFromFakeQuant::Run(Model* model, - std::size_t op_index) { +::tensorflow::Status ReadArrayMinmaxAndNarrowRangeFromFakeQuant::Run( + Model* model, std::size_t op_index, bool* modified) { + *modified = false; const auto fakequant_it = model->operators.begin() + op_index; auto* fakequant_base_op = fakequant_it->get(); if (fakequant_base_op->type != OperatorType::kFakeQuant) { - return false; + return ::tensorflow::Status::OK(); } auto* fq_op = static_cast(fakequant_base_op); if (!fq_op->minmax) { // Need to be resolved first by ResolveFakeQuantArgsFromVars. - return false; + return ::tensorflow::Status::OK(); } // At this point, this FakeQuantOperator should have a MinMax @@ -74,7 +75,8 @@ bool ReadArrayMinmaxAndNarrowRangeFromFakeQuant::Run(Model* model, bool changed = false; changed |= ApplyAttrsToArray(this, model, *fq_op, fq_op->inputs[0]); changed |= ApplyAttrsToArray(this, model, *fq_op, fq_op->outputs[0]); - return changed; + *modified = changed; + return ::tensorflow::Status::OK(); } } // namespace toco diff --git a/tensorflow/contrib/lite/toco/graph_transformations/remove_final_dequantize_op.cc b/tensorflow/contrib/lite/toco/graph_transformations/remove_final_dequantize_op.cc index c3b2709a33..fe8023ab8f 100644 --- a/tensorflow/contrib/lite/toco/graph_transformations/remove_final_dequantize_op.cc +++ b/tensorflow/contrib/lite/toco/graph_transformations/remove_final_dequantize_op.cc @@ -25,11 +25,14 @@ limitations under the License. namespace toco { -bool RemoveFinalDequantizeOp::Run(Model* model, std::size_t op_index) { +::tensorflow::Status RemoveFinalDequantizeOp::Run(Model* model, + std::size_t op_index, + bool* modified) { + *modified = false; const auto dequantize_it = model->operators.begin() + op_index; const auto* dequantize_op = dequantize_it->get(); if (dequantize_op->type != OperatorType::kDequantize) { - return false; + return ::tensorflow::Status::OK(); } const auto& output = dequantize_op->outputs[0]; // We can remove any dequantize op whose output is not consumed by @@ -38,7 +41,7 @@ bool RemoveFinalDequantizeOp::Run(Model* model, std::size_t op_index) { // in the middle of the graph might be designated as an output // array. if (CountOpsWithInput(*model, output)) { - return false; + return ::tensorflow::Status::OK(); } // If one of the model's output arrays was actually the Dequantize op's @@ -53,7 +56,8 @@ bool RemoveFinalDequantizeOp::Run(Model* model, std::size_t op_index) { AddMessageF("Removed final %s", LogName(*dequantize_op)); model->EraseArray(output); model->operators.erase(dequantize_it); - return true; + *modified = true; + return ::tensorflow::Status::OK(); } } // namespace toco diff --git a/tensorflow/contrib/lite/toco/graph_transformations/remove_tensorflow_assert.cc b/tensorflow/contrib/lite/toco/graph_transformations/remove_tensorflow_assert.cc index 73ad326299..be8c0acc7b 100644 --- a/tensorflow/contrib/lite/toco/graph_transformations/remove_tensorflow_assert.cc +++ b/tensorflow/contrib/lite/toco/graph_transformations/remove_tensorflow_assert.cc @@ -23,11 +23,14 @@ limitations under the License. namespace toco { -bool RemoveTensorFlowAssert::Run(Model* model, std::size_t op_index) { +::tensorflow::Status RemoveTensorFlowAssert::Run(Model* model, + std::size_t op_index, + bool* modified) { + *modified = false; const auto assert_it = model->operators.begin() + op_index; const auto* assert_op = assert_it->get(); if (assert_op->type != OperatorType::kAssert) { - return false; + return ::tensorflow::Status::OK(); } bool changed = false; @@ -54,7 +57,8 @@ bool RemoveTensorFlowAssert::Run(Model* model, std::size_t op_index) { // That's it. We can stop here, no need to duplicate the work that // RemoveUnusedOp will do removing this now-unused node. - return changed; + *modified = changed; + return ::tensorflow::Status::OK(); } } // namespace toco diff --git a/tensorflow/contrib/lite/toco/graph_transformations/remove_tensorflow_identity.cc b/tensorflow/contrib/lite/toco/graph_transformations/remove_tensorflow_identity.cc index 7ec7752f25..37fe5fa3d7 100644 --- a/tensorflow/contrib/lite/toco/graph_transformations/remove_tensorflow_identity.cc +++ b/tensorflow/contrib/lite/toco/graph_transformations/remove_tensorflow_identity.cc @@ -25,14 +25,18 @@ limitations under the License. namespace toco { -bool RemoveTensorFlowIdentity::Run(Model* model, std::size_t op_index) { +::tensorflow::Status RemoveTensorFlowIdentity::Run(Model* model, + std::size_t op_index, + bool* modified) { + *modified = false; const auto passthru_it = model->operators.begin() + op_index; const auto* passthru_op = passthru_it->get(); if (passthru_op->type != OperatorType::kIdentity) { - return false; + return ::tensorflow::Status::OK(); } - return RemoveTrivialPassthroughOp(this, model, op_index); + *modified = RemoveTrivialPassthroughOp(this, model, op_index); + return ::tensorflow::Status::OK(); } } // namespace toco diff --git a/tensorflow/contrib/lite/toco/graph_transformations/remove_trivial_binary.cc b/tensorflow/contrib/lite/toco/graph_transformations/remove_trivial_binary.cc index 0dfdc40e4c..68c6fb65c5 100644 --- a/tensorflow/contrib/lite/toco/graph_transformations/remove_trivial_binary.cc +++ b/tensorflow/contrib/lite/toco/graph_transformations/remove_trivial_binary.cc @@ -46,14 +46,17 @@ bool AreAllBufferElementsEqualTo(const std::vector& buffer_data, // For example, an Add operator is trivial if // one of its operands is constant 0, a Mul operator is trivial // if one of its operands is constant 1, etc. -bool RemoveTrivialBinaryOperator::Run(Model* model, std::size_t op_index) { +::tensorflow::Status RemoveTrivialBinaryOperator::Run(Model* model, + std::size_t op_index, + bool* modified) { + *modified = false; const auto binary_it = model->operators.begin() + op_index; auto* binary_op = binary_it->get(); if (binary_op->type != OperatorType::kAdd && binary_op->type != OperatorType::kMul && binary_op->type != OperatorType::kSub && binary_op->type != OperatorType::kDiv) { - return false; + return ::tensorflow::Status::OK(); } CHECK_EQ(binary_op->inputs.size(), 2); @@ -66,12 +69,12 @@ bool RemoveTrivialBinaryOperator::Run(Model* model, std::size_t op_index) { }; if (!is_input_constant[0] && !is_input_constant[1]) { // Neither input is constant, so nothing we can resolve here. - return false; + return ::tensorflow::Status::OK(); } if (is_input_constant[0] && is_input_constant[1]) { // Both inputs are constants. That's a job for constants // propagation, not for us to handle here. - return false; + return ::tensorflow::Status::OK(); } const int index_of_constant_input = is_input_constant[0] ? 0 : 1; const int index_of_variable_input = is_input_constant[0] ? 1 : 0; @@ -84,7 +87,7 @@ bool RemoveTrivialBinaryOperator::Run(Model* model, std::size_t op_index) { const auto& input_array_1 = model->GetArray(binary_op->inputs[1]); if (!input_array_0.has_shape() || !input_array_1.has_shape()) { // Both input shapes must be known. - return false; + return ::tensorflow::Status::OK(); } if (input_array_0.shape().dimensions_count() == input_array_1.shape().dimensions_count() && @@ -94,7 +97,7 @@ bool RemoveTrivialBinaryOperator::Run(Model* model, std::size_t op_index) { "(lhs %s, rhs %s)", LogName(*binary_op), ShapeToString(input_array_0.shape()), ShapeToString(input_array_1.shape())); - return false; + return ::tensorflow::Status::OK(); } // Now check if the constant operand makes this binary @@ -103,7 +106,7 @@ bool RemoveTrivialBinaryOperator::Run(Model* model, std::size_t op_index) { model->GetArray(binary_op->inputs[index_of_constant_input]); // For now, we only handle floats here. if (constant_input_array.data_type != ArrayDataType::kFloat) { - return false; + return ::tensorflow::Status::OK(); } const auto& constant_input_float_data = constant_input_array.GetBuffer().data; @@ -121,12 +124,13 @@ bool RemoveTrivialBinaryOperator::Run(Model* model, std::size_t op_index) { } if (!is_trivial) { - return false; + return ::tensorflow::Status::OK(); } // Now we know that this node is trivial, so we can remove it. AddMessageF("Removing trivial %s", LogName(*binary_op)); - return RemoveTrivialPassthroughOp(this, model, op_index); + *modified = RemoveTrivialPassthroughOp(this, model, op_index); + return ::tensorflow::Status::OK(); } } // namespace toco diff --git a/tensorflow/contrib/lite/toco/graph_transformations/remove_trivial_concatenation.cc b/tensorflow/contrib/lite/toco/graph_transformations/remove_trivial_concatenation.cc index 3ceb93d8ee..faaa2a828e 100644 --- a/tensorflow/contrib/lite/toco/graph_transformations/remove_trivial_concatenation.cc +++ b/tensorflow/contrib/lite/toco/graph_transformations/remove_trivial_concatenation.cc @@ -25,16 +25,20 @@ limitations under the License. namespace toco { -bool RemoveTrivialConcatenation::Run(Model* model, std::size_t op_index) { +::tensorflow::Status RemoveTrivialConcatenation::Run(Model* model, + std::size_t op_index, + bool* modified) { + *modified = false; const auto concat_it = model->operators.begin() + op_index; auto* concat_op = concat_it->get(); if (concat_op->type != OperatorType::kConcatenation) { - return false; + return ::tensorflow::Status::OK(); } if (concat_op->inputs.size() != 1) { - return false; + return ::tensorflow::Status::OK(); } - return RemoveTrivialPassthroughOp(this, model, op_index); + *modified = RemoveTrivialPassthroughOp(this, model, op_index); + return ::tensorflow::Status::OK(); } } // namespace toco diff --git a/tensorflow/contrib/lite/toco/graph_transformations/remove_trivial_concatenation_input.cc b/tensorflow/contrib/lite/toco/graph_transformations/remove_trivial_concatenation_input.cc index 936854a04f..ccfc181fe0 100644 --- a/tensorflow/contrib/lite/toco/graph_transformations/remove_trivial_concatenation_input.cc +++ b/tensorflow/contrib/lite/toco/graph_transformations/remove_trivial_concatenation_input.cc @@ -25,7 +25,10 @@ limitations under the License. namespace toco { -bool RemoveTrivialConcatenationInput::Run(Model* model, std::size_t op_index) { +::tensorflow::Status RemoveTrivialConcatenationInput::Run(Model* model, + std::size_t op_index, + bool* modified) { + *modified = false; // TensorFlow allows Concatenation nodes to have 0-D inputs, // and they are then treated as empty i.e. omitted from concatenation, // in violation of the notion that 0-D is equivalent to 1x1x1x1. @@ -36,7 +39,7 @@ bool RemoveTrivialConcatenationInput::Run(Model* model, std::size_t op_index) { const auto concat_it = model->operators.begin() + op_index; auto* concat_op = concat_it->get(); if (concat_op->type != OperatorType::kConcatenation) { - return false; + return ::tensorflow::Status::OK(); } std::vector trivial_inputs; std::vector nontrivial_inputs; @@ -52,7 +55,7 @@ bool RemoveTrivialConcatenationInput::Run(Model* model, std::size_t op_index) { } if (trivial_inputs.empty()) { - return false; + return ::tensorflow::Status::OK(); } // Drop trivial inputs. @@ -63,7 +66,8 @@ bool RemoveTrivialConcatenationInput::Run(Model* model, std::size_t op_index) { } } concat_op->inputs = nontrivial_inputs; - return true; + *modified = true; + return ::tensorflow::Status::OK(); } } // namespace toco diff --git a/tensorflow/contrib/lite/toco/graph_transformations/remove_trivial_fake_quant.cc b/tensorflow/contrib/lite/toco/graph_transformations/remove_trivial_fake_quant.cc index 2c8d04440f..5448a816bc 100644 --- a/tensorflow/contrib/lite/toco/graph_transformations/remove_trivial_fake_quant.cc +++ b/tensorflow/contrib/lite/toco/graph_transformations/remove_trivial_fake_quant.cc @@ -64,23 +64,27 @@ bool IsFakeQuantTrivial(GraphTransformation* transformation, const Model& model, } // namespace // Removes FakeQuant ops that are trivial (have no effect, are redundant, etc). -bool RemoveTrivialFakeQuant::Run(Model* model, std::size_t op_index) { +::tensorflow::Status RemoveTrivialFakeQuant::Run(Model* model, + std::size_t op_index, + bool* modified) { + *modified = false; const auto op_it = model->operators.begin() + op_index; auto* op = op_it->get(); if (op->type != OperatorType::kFakeQuant) { - return false; + return ::tensorflow::Status::OK(); } auto* fakequant_op = static_cast(op); if (!IsFakeQuantTrivial(this, *model, *fakequant_op)) { AddMessageF("%s is not trivial", LogName(*fakequant_op)); - return false; + return ::tensorflow::Status::OK(); } AddMessageF("Removing trivial %s", LogName(*fakequant_op)); CHECK_EQ(fakequant_op->inputs.size(), 1); - return RemoveTrivialPassthroughOp(this, model, op_index); + *modified = RemoveTrivialPassthroughOp(this, model, op_index); + return ::tensorflow::Status::OK(); } } // namespace toco diff --git a/tensorflow/contrib/lite/toco/graph_transformations/remove_trivial_quantized_activation_func.cc b/tensorflow/contrib/lite/toco/graph_transformations/remove_trivial_quantized_activation_func.cc index 752560e075..4133815285 100644 --- a/tensorflow/contrib/lite/toco/graph_transformations/remove_trivial_quantized_activation_func.cc +++ b/tensorflow/contrib/lite/toco/graph_transformations/remove_trivial_quantized_activation_func.cc @@ -94,12 +94,13 @@ bool IsTrivialFusedActivationFunc( // Attempts to remove both fused and unfused activation functions if the // quantization params indicate that the representable values fall inside the // activation range. -bool RemoveTrivialQuantizedActivationFunc::Run(Model* model, - std::size_t op_index) { +::tensorflow::Status RemoveTrivialQuantizedActivationFunc::Run( + Model* model, std::size_t op_index, bool* modified) { + *modified = false; const auto it = model->operators.begin() + op_index; auto* op = it->get(); if (op->inputs.empty()) { - return false; + return ::tensorflow::Status::OK(); } if (IsTrivialUnfusedActivationFunc(this, *model, op->type, op->inputs[0])) { @@ -107,7 +108,8 @@ bool RemoveTrivialQuantizedActivationFunc::Run(Model* model, "Removing trivial unfused activation function %s because the input " "minmax imply at least as tight a clamp anyway.", LogName(*op)); - return RemoveTrivialPassthroughOp(this, model, op_index); + *modified = RemoveTrivialPassthroughOp(this, model, op_index); + return ::tensorflow::Status::OK(); } if (IsTrivialFusedActivationFunc(this, *model, op->fused_activation_function, op->outputs[0])) { @@ -117,9 +119,10 @@ bool RemoveTrivialQuantizedActivationFunc::Run(Model* model, "because the output quantization parameters imply at least as tight " "a clamp anyway.", LogName(*op)); - return true; + *modified = true; + return ::tensorflow::Status::OK(); } - return false; + return ::tensorflow::Status::OK(); } } // namespace toco diff --git a/tensorflow/contrib/lite/toco/graph_transformations/remove_trivial_quantized_min_max.cc b/tensorflow/contrib/lite/toco/graph_transformations/remove_trivial_quantized_min_max.cc index 142c876b15..0f0ae4af69 100644 --- a/tensorflow/contrib/lite/toco/graph_transformations/remove_trivial_quantized_min_max.cc +++ b/tensorflow/contrib/lite/toco/graph_transformations/remove_trivial_quantized_min_max.cc @@ -69,22 +69,26 @@ bool IsTrivialMinMax(GraphTransformation* transformation, const Model& model, // Attempts to remove min/max functions if the quantization params indicate that // the representable values fall inside the clip range. -bool RemoveTrivialQuantizedMinMax::Run(Model* model, std::size_t op_index) { +::tensorflow::Status RemoveTrivialQuantizedMinMax::Run(Model* model, + std::size_t op_index, + bool* modified) { + *modified = false; const auto it = model->operators.begin() + op_index; auto* op = it->get(); if ((op->type != OperatorType::kMinimum && op->type != OperatorType::kMaximum) || op->inputs.size() != 2) { - return false; + return ::tensorflow::Status::OK(); } if (IsTrivialMinMax(this, *model, op->type, op->inputs[0], op->inputs[1])) { AddMessageF( "Removing trivial min/max %s because the quantization parameters imply " "at least as tight a clamp anyway.", LogName(*op)); - return RemoveTrivialPassthroughOp(this, model, op_index); + *modified = RemoveTrivialPassthroughOp(this, model, op_index); + return ::tensorflow::Status::OK(); } - return false; + return ::tensorflow::Status::OK(); } } // namespace toco diff --git a/tensorflow/contrib/lite/toco/graph_transformations/remove_trivial_reshape.cc b/tensorflow/contrib/lite/toco/graph_transformations/remove_trivial_reshape.cc index 5295eeccec..1caf944879 100644 --- a/tensorflow/contrib/lite/toco/graph_transformations/remove_trivial_reshape.cc +++ b/tensorflow/contrib/lite/toco/graph_transformations/remove_trivial_reshape.cc @@ -81,22 +81,26 @@ bool IsReshapeTrivial(const Model& model, const Operator& op, } // namespace -bool RemoveTrivialReshape::Run(Model* model, std::size_t op_index) { +::tensorflow::Status RemoveTrivialReshape::Run(Model* model, + std::size_t op_index, + bool* modified) { + *modified = false; const auto reshape_it = model->operators.begin() + op_index; auto* reshape_op = reshape_it->get(); if (reshape_op->type != OperatorType::kReshape) { - return false; + return ::tensorflow::Status::OK(); } if (!IsReshapeTrivial(*model, *reshape_op, this)) { AddMessageF("%s is not trivial", LogName(*reshape_op)); - return false; + return ::tensorflow::Status::OK(); } AddMessageF("Removing trivial %s", LogName(*reshape_op)); CHECK_EQ(reshape_op->inputs.size(), 2); - return RemoveTrivialPassthroughOp(this, model, op_index); + *modified = RemoveTrivialPassthroughOp(this, model, op_index); + return ::tensorflow::Status::OK(); } } // namespace toco diff --git a/tensorflow/contrib/lite/toco/graph_transformations/remove_trivial_slice.cc b/tensorflow/contrib/lite/toco/graph_transformations/remove_trivial_slice.cc index 0cbbcd7c81..dcb0148d58 100644 --- a/tensorflow/contrib/lite/toco/graph_transformations/remove_trivial_slice.cc +++ b/tensorflow/contrib/lite/toco/graph_transformations/remove_trivial_slice.cc @@ -49,21 +49,24 @@ bool IsSliceTrivial(const Model& model, const Operator& op, } // namespace -bool RemoveTrivialSlice::Run(Model* model, std::size_t op_index) { +::tensorflow::Status RemoveTrivialSlice::Run(Model* model, std::size_t op_index, + bool* modified) { + *modified = false; const auto reshape_it = model->operators.begin() + op_index; auto* slice_op = reshape_it->get(); if (slice_op->type != OperatorType::kSlice) { - return false; + return ::tensorflow::Status::OK(); } if (!IsSliceTrivial(*model, *slice_op, this)) { - return false; + return ::tensorflow::Status::OK(); } AddMessageF("Removing trivial %s", LogName(*slice_op)); CHECK_EQ(slice_op->inputs.size(), 3); - return RemoveTrivialPassthroughOp(this, model, op_index); + *modified = RemoveTrivialPassthroughOp(this, model, op_index); + return ::tensorflow::Status::OK(); } } // namespace toco diff --git a/tensorflow/contrib/lite/toco/graph_transformations/remove_unused_op.cc b/tensorflow/contrib/lite/toco/graph_transformations/remove_unused_op.cc index dde91234a8..3cd5d06bae 100644 --- a/tensorflow/contrib/lite/toco/graph_transformations/remove_unused_op.cc +++ b/tensorflow/contrib/lite/toco/graph_transformations/remove_unused_op.cc @@ -25,7 +25,9 @@ limitations under the License. namespace toco { -bool RemoveUnusedOp::Run(Model* model, std::size_t op_index) { +::tensorflow::Status RemoveUnusedOp::Run(Model* model, std::size_t op_index, + bool* modified) { + *modified = false; const auto it = model->operators.begin() + op_index; const auto* op = it->get(); @@ -58,7 +60,7 @@ bool RemoveUnusedOp::Run(Model* model, std::size_t op_index) { } for (const string& output_array : model->flags.output_arrays()) { if (output == output_array) { - return false; + return ::tensorflow::Status::OK(); } } for (const auto& rnn_state : model->flags.rnn_states()) { @@ -67,19 +69,19 @@ bool RemoveUnusedOp::Run(Model* model, std::size_t op_index) { if (!IsDiscardableArray(*model, rnn_state.back_edge_source_array()) || !IsDiscardableArray(*model, rnn_state.state_array()) || CountOpsWithInput(*model, rnn_state.state_array())) { - return false; + return ::tensorflow::Status::OK(); } } } if (CountOpsWithInput(*model, output)) { - return false; + return ::tensorflow::Status::OK(); } } if (op->unresolved_outputs) { AddMessageF("Not discarding %s because it has unresolved outputs.", LogName(*op)); - return false; + return ::tensorflow::Status::OK(); } AddMessageF("Discarding %s because none of its outputs is used.", @@ -105,7 +107,8 @@ bool RemoveUnusedOp::Run(Model* model, std::size_t op_index) { } } model->operators.erase(it); - return true; + *modified = true; + return ::tensorflow::Status::OK(); } } // namespace toco diff --git a/tensorflow/contrib/lite/toco/graph_transformations/reorder_elementwise_unary.cc b/tensorflow/contrib/lite/toco/graph_transformations/reorder_elementwise_unary.cc index 550de83018..3c8d411089 100644 --- a/tensorflow/contrib/lite/toco/graph_transformations/reorder_elementwise_unary.cc +++ b/tensorflow/contrib/lite/toco/graph_transformations/reorder_elementwise_unary.cc @@ -63,29 +63,32 @@ bool IsMoveOperator(OperatorType optype) { // Swap elementwise operators such that all value operators occur before all // element move operators, e.g. negation then transpose. -bool ReorderElementwiseUnary::Run(Model* model, std::size_t op_index) { +::tensorflow::Status ReorderElementwiseUnary::Run(Model* model, + std::size_t op_index, + bool* modified) { + *modified = false; const auto element_op_it = model->operators.begin() + op_index; std::unique_ptr& element_op = *element_op_it; if (!IsElementwiseOperator(element_op->type)) { - return false; + return ::tensorflow::Status::OK(); } const string intermediate_name = element_op->inputs[0]; auto it = FindOpWithOutput(*model, intermediate_name); if (it == model->operators.end()) { AddMessageF("No preceding operator"); - return false; + return ::tensorflow::Status::OK(); } std::unique_ptr& move_op = *it; if (!IsMoveOperator(move_op->type)) { AddMessageF("Preceding operator is not a move operator"); - return false; + return ::tensorflow::Status::OK(); } if (CountOpsWithInput(*model, intermediate_name) != 1) { AddMessageF("Input %s used elsewhere", intermediate_name); - return false; + return ::tensorflow::Status::OK(); } // Check that the intermediate is discardable. @@ -94,7 +97,7 @@ bool ReorderElementwiseUnary::Run(Model* model, std::size_t op_index) { "Cannot swap elementwise as it would invalidate %s which is " "an output array.", intermediate_name); - return false; + return ::tensorflow::Status::OK(); } // op->inputs may change so we need to keep a value by copy. @@ -147,7 +150,8 @@ bool ReorderElementwiseUnary::Run(Model* model, std::size_t op_index) { // Swap the order of the operators. element_op.swap(move_op); - return true; + *modified = true; + return ::tensorflow::Status::OK(); } } // namespace toco diff --git a/tensorflow/contrib/lite/toco/graph_transformations/reorder_reshape_transpose.cc b/tensorflow/contrib/lite/toco/graph_transformations/reorder_reshape_transpose.cc index c907a597cb..a2c06e71e8 100644 --- a/tensorflow/contrib/lite/toco/graph_transformations/reorder_reshape_transpose.cc +++ b/tensorflow/contrib/lite/toco/graph_transformations/reorder_reshape_transpose.cc @@ -101,37 +101,40 @@ std::vector ComputeNewPerm(std::vector input_dims, // Swaps reshape-transpose to transpose-reshape whenever possible. This is // possible when the reshape does not affect memory ordering. -bool ReorderReshapeTranspose::Run(Model* model, std::size_t op_index) { +::tensorflow::Status ReorderReshapeTranspose::Run(Model* model, + std::size_t op_index, + bool* modified) { + *modified = false; auto transpose_it = model->operators.begin() + op_index; TransposeOperator* transpose_op = ConvertOperator( transpose_it->get(), OperatorType::kTranspose); if (transpose_op == nullptr) { - return false; + return ::tensorflow::Status::OK(); } if (!OperatorReady(*model, transpose_op) || transpose_op->perm.empty()) { // Wait for values to propagate. - return false; + return ::tensorflow::Status::OK(); } // Find the operator that produces the transpose op. auto reshape_it = FindOpWithOutput(*model, transpose_op->inputs[0]); if (reshape_it == model->operators.end()) { - return false; + return ::tensorflow::Status::OK(); } TensorFlowReshapeOperator* reshape_op = ConvertOperator(reshape_it->get(), OperatorType::kReshape); if (reshape_op == nullptr) { - return false; + return ::tensorflow::Status::OK(); } // Ignore if the reshape is uninitialized. if (!OperatorReady(*model, reshape_op) || reshape_op->shape.empty()) { - return false; + return ::tensorflow::Status::OK(); } // Need to copy to keep static if permutated. @@ -142,7 +145,7 @@ bool ReorderReshapeTranspose::Run(Model* model, std::size_t op_index) { // Intermediate should not be consumed by any other operators. if (CountOpsWithInput(*model, intermediate_name) != 1) { AddMessageF("Input %s used elsewhere", intermediate_name); - return false; + return ::tensorflow::Status::OK(); } // Check that the intermediate is not an output array. @@ -151,7 +154,7 @@ bool ReorderReshapeTranspose::Run(Model* model, std::size_t op_index) { "Cannot reorder reshape-transpose as it would invalidate %s which is " "an output array.", intermediate_name); - return false; + return ::tensorflow::Status::OK(); } // Get the arrays. @@ -173,7 +176,7 @@ bool ReorderReshapeTranspose::Run(Model* model, std::size_t op_index) { // dimensions then it can be moved between the transpose. if (!ReshapeIsEquivalentToTranspose(*model, reshape_op, true /*allow_extra_unary_dims*/)) { - return false; + return ::tensorflow::Status::OK(); } if (!IsDiscardableArray(*model, output_name)) { @@ -242,7 +245,8 @@ bool ReorderReshapeTranspose::Run(Model* model, std::size_t op_index) { // Swap the order of the operators. transpose_it->swap(*reshape_it); - return true; + *modified = true; + return ::tensorflow::Status::OK(); } } // namespace toco diff --git a/tensorflow/contrib/lite/toco/graph_transformations/resolve_batch_normalization.cc b/tensorflow/contrib/lite/toco/graph_transformations/resolve_batch_normalization.cc index 8f2c1f8162..a79779f55d 100644 --- a/tensorflow/contrib/lite/toco/graph_transformations/resolve_batch_normalization.cc +++ b/tensorflow/contrib/lite/toco/graph_transformations/resolve_batch_normalization.cc @@ -25,10 +25,13 @@ limitations under the License. namespace toco { -bool ResolveBatchNormalization::Run(Model* model, std::size_t op_index) { +::tensorflow::Status ResolveBatchNormalization::Run(Model* model, + std::size_t op_index, + bool* modified) { + *modified = false; auto bn_it = model->operators.begin() + op_index; if (bn_it->get()->type != OperatorType::kBatchNormalization) { - return false; + return ::tensorflow::Status::OK(); } const auto* bn_op = static_cast(bn_it->get()); @@ -53,7 +56,7 @@ bool ResolveBatchNormalization::Run(Model* model, std::size_t op_index) { // so we need to exit early if these buffers don't exist (i.e. if the params // haven't yet been resolved as constants). if (!mean_array.buffer || !multiplier_array.buffer || !offset_array.buffer) { - return false; + return ::tensorflow::Status::OK(); } // Create the new Mul, Add operators @@ -142,7 +145,8 @@ bool ResolveBatchNormalization::Run(Model* model, std::size_t op_index) { DCHECK_EQ(bn_it->get(), bn_op); model->operators.erase(bn_it); - return true; + *modified = true; + return ::tensorflow::Status::OK(); } } // namespace toco diff --git a/tensorflow/contrib/lite/toco/graph_transformations/resolve_batch_to_space_nd_attributes.cc b/tensorflow/contrib/lite/toco/graph_transformations/resolve_batch_to_space_nd_attributes.cc index b8b35161d7..d039d7d690 100644 --- a/tensorflow/contrib/lite/toco/graph_transformations/resolve_batch_to_space_nd_attributes.cc +++ b/tensorflow/contrib/lite/toco/graph_transformations/resolve_batch_to_space_nd_attributes.cc @@ -24,31 +24,35 @@ limitations under the License. namespace toco { -bool ResolveBatchToSpaceNDAttributes::Run(Model* model, std::size_t op_index) { +::tensorflow::Status ResolveBatchToSpaceNDAttributes::Run(Model* model, + std::size_t op_index, + bool* modified) { + *modified = false; const auto op_it = model->operators.begin() + op_index; - if (op_it->get()->type != OperatorType::kBatchToSpaceND) return false; + if (op_it->get()->type != OperatorType::kBatchToSpaceND) + return ::tensorflow::Status::OK(); auto* op = static_cast(op_it->get()); // The attributes are resolved only when the 3 attributes (block_shape, // before_crops, after_crops) are all constant. if (!op->block_shape.empty()) { - return false; + return ::tensorflow::Status::OK(); } CHECK_EQ(op->inputs.size(), 3); if (!IsConstantParameterArray(*model, op->inputs[1]) || !IsConstantParameterArray(*model, op->inputs[2])) - return false; + return ::tensorflow::Status::OK(); // Handle crops const auto& crops_array = model->GetArray(op->inputs[2]); - if (!crops_array.has_shape()) return false; + if (!crops_array.has_shape()) return ::tensorflow::Status::OK(); const std::vector& crops_dims = crops_array.shape().dims(); if (crops_dims.size() != 2) { // Code only handles crops of 2 dimensions. Perhaps another transformation // will delete this op. - return false; + return ::tensorflow::Status::OK(); } const std::vector& crops_buffer = crops_array.GetBuffer().data; @@ -59,7 +63,7 @@ bool ResolveBatchToSpaceNDAttributes::Run(Model* model, std::size_t op_index) { // Handle block_shape const auto& block_shape_array = model->GetArray(op->inputs[1]); - if (!block_shape_array.has_shape()) return false; + if (!block_shape_array.has_shape()) return ::tensorflow::Status::OK(); const std::vector& block_shape_dims = block_shape_array.shape().dims(); CHECK_EQ(block_shape_dims.size(), 1); const std::vector& block_shape_buffer = @@ -68,7 +72,8 @@ bool ResolveBatchToSpaceNDAttributes::Run(Model* model, std::size_t op_index) { op->block_shape.push_back(block_shape_buffer[i]); } - return true; + *modified = true; + return ::tensorflow::Status::OK(); } } // namespace toco diff --git a/tensorflow/contrib/lite/toco/graph_transformations/resolve_constant_binary.cc b/tensorflow/contrib/lite/toco/graph_transformations/resolve_constant_binary.cc index f7e5aa6609..586f546a30 100644 --- a/tensorflow/contrib/lite/toco/graph_transformations/resolve_constant_binary.cc +++ b/tensorflow/contrib/lite/toco/graph_transformations/resolve_constant_binary.cc @@ -188,7 +188,10 @@ void EvaluateBinaryOperatorOnConstantInputs(Model* model, } } // namespace -bool ResolveConstantBinaryOperator::Run(Model* model, std::size_t op_index) { +::tensorflow::Status ResolveConstantBinaryOperator::Run(Model* model, + std::size_t op_index, + bool* modified) { + *modified = false; const auto binary_it = model->operators.begin() + op_index; const auto* binary_op = binary_it->get(); // Test for binary ops of types that we know how to resolve @@ -204,7 +207,7 @@ bool ResolveConstantBinaryOperator::Run(Model* model, std::size_t op_index) { binary_op->type != OperatorType::kLessEqual && binary_op->type != OperatorType::kGreater && binary_op->type != OperatorType::kGreaterEqual) { - return false; + return ::tensorflow::Status::OK(); } CHECK_EQ(binary_op->inputs.size(), 2); @@ -212,13 +215,13 @@ bool ResolveConstantBinaryOperator::Run(Model* model, std::size_t op_index) { const auto& input1_array = model->GetArray(binary_op->inputs[1]); // Check if both inputs are constant parameters. if (!input0_array.buffer || !input1_array.buffer) { - return false; + return ::tensorflow::Status::OK(); } auto& output_array = model->GetArray(binary_op->outputs[0]); // Yield until the output array dims have been resolved. if (!output_array.has_shape()) { - return false; + return ::tensorflow::Status::OK(); } // At the moment we don't want to care about fused activation functions. @@ -229,7 +232,7 @@ bool ResolveConstantBinaryOperator::Run(Model* model, std::size_t op_index) { AddMessageF( "Not resolving constant %s because it has a fused activation function", LogName(*binary_op)); - return false; + return ::tensorflow::Status::OK(); } // Check that input data types agree. @@ -253,7 +256,8 @@ bool ResolveConstantBinaryOperator::Run(Model* model, std::size_t op_index) { AddMessageF("Resolved constant %s to the equivalent constant array", LogName(*binary_op)); model->operators.erase(binary_it); - return true; + *modified = true; + return ::tensorflow::Status::OK(); } } // namespace toco diff --git a/tensorflow/contrib/lite/toco/graph_transformations/resolve_constant_concatenation.cc b/tensorflow/contrib/lite/toco/graph_transformations/resolve_constant_concatenation.cc index d916ae0ddf..0c60fdfeb3 100644 --- a/tensorflow/contrib/lite/toco/graph_transformations/resolve_constant_concatenation.cc +++ b/tensorflow/contrib/lite/toco/graph_transformations/resolve_constant_concatenation.cc @@ -135,11 +135,14 @@ void SetMinMaxForConcatenedArray(GraphTransformation* transformation, } // namespace // Resolves the concatenation operator if all its inputs are constant arrays. -bool ResolveConstantConcatenation::Run(Model* model, std::size_t op_index) { +::tensorflow::Status ResolveConstantConcatenation::Run(Model* model, + std::size_t op_index, + bool* modified) { + *modified = false; const auto concat_it = model->operators.begin() + op_index; const auto* concat_base_op = concat_it->get(); if (concat_base_op->type != OperatorType::kConcatenation) { - return false; + return ::tensorflow::Status::OK(); } const auto* concat_op = static_cast(concat_base_op); @@ -149,11 +152,15 @@ bool ResolveConstantConcatenation::Run(Model* model, std::size_t op_index) { // We also make sure the shapes of the input arrays are known and they are // all discardable. const Operator* input_op = GetOpWithOutput(*model, input_name); - if (input_op) return false; - if (!IsConstantParameterArray(*model, input_name)) return false; - if (!model->GetArray(input_name).has_shape()) return false; - if (model->GetArray(input_name).quantization_params) return false; - if (!IsDiscardableArray(*model, input_name)) return false; + if (input_op) return ::tensorflow::Status::OK(); + if (!IsConstantParameterArray(*model, input_name)) + return ::tensorflow::Status::OK(); + if (!model->GetArray(input_name).has_shape()) + return ::tensorflow::Status::OK(); + if (model->GetArray(input_name).quantization_params) + return ::tensorflow::Status::OK(); + if (!IsDiscardableArray(*model, input_name)) + return ::tensorflow::Status::OK(); } const int concatenation_axis = concat_op->axis; @@ -205,7 +212,8 @@ bool ResolveConstantConcatenation::Run(Model* model, std::size_t op_index) { // Remove concatenate operator. model->operators.erase(concat_it); - return true; + *modified = true; + return ::tensorflow::Status::OK(); } } // namespace toco diff --git a/tensorflow/contrib/lite/toco/graph_transformations/resolve_constant_fake_quant.cc b/tensorflow/contrib/lite/toco/graph_transformations/resolve_constant_fake_quant.cc index f5f2f77460..4f330fdd84 100644 --- a/tensorflow/contrib/lite/toco/graph_transformations/resolve_constant_fake_quant.cc +++ b/tensorflow/contrib/lite/toco/graph_transformations/resolve_constant_fake_quant.cc @@ -59,11 +59,14 @@ void GetBoundsForQuantizedDataType(ArrayDataType quantized_data_type, } } -bool ResolveConstantFakeQuant::Run(Model* model, std::size_t op_index) { +::tensorflow::Status ResolveConstantFakeQuant::Run(Model* model, + std::size_t op_index, + bool* modified) { + *modified = false; const auto fakequant_it = model->operators.begin() + op_index; const auto* fakequant_base_op = fakequant_it->get(); if (fakequant_base_op->type != OperatorType::kFakeQuant) { - return false; + return ::tensorflow::Status::OK(); } const auto* fakequant_op = @@ -71,12 +74,12 @@ bool ResolveConstantFakeQuant::Run(Model* model, std::size_t op_index) { // Yield until the fakequant MinMax has been resolved. if (!fakequant_op->minmax) { - return false; + return ::tensorflow::Status::OK(); } // This transformation only applies when the input array is constant. if (!IsConstantParameterArray(*model, fakequant_op->inputs[0])) { - return false; + return ::tensorflow::Status::OK(); } const auto& input_array = model->GetArray(fakequant_op->inputs[0]); @@ -87,7 +90,7 @@ bool ResolveConstantFakeQuant::Run(Model* model, std::size_t op_index) { if (!InferQuantizedDataTypeFromFakeQuant(*fakequant_op, &quantized_data_type)) { AddMessageF("Unsupported FakeQuant num_bits=%d", fakequant_op->num_bits); - return false; + return ::tensorflow::Status::OK(); } AddMessageF("Resolving constant %s", LogName(*fakequant_op)); @@ -136,7 +139,8 @@ bool ResolveConstantFakeQuant::Run(Model* model, std::size_t op_index) { } model->operators.erase(fakequant_it); - return true; + *modified = true; + return ::tensorflow::Status::OK(); } } // namespace toco diff --git a/tensorflow/contrib/lite/toco/graph_transformations/resolve_constant_fill.cc b/tensorflow/contrib/lite/toco/graph_transformations/resolve_constant_fill.cc index f6f95481b5..5400d395ff 100644 --- a/tensorflow/contrib/lite/toco/graph_transformations/resolve_constant_fill.cc +++ b/tensorflow/contrib/lite/toco/graph_transformations/resolve_constant_fill.cc @@ -41,11 +41,14 @@ bool ComputeFillArray(Model* model, FillOperator* op) { return true; } -bool ResolveConstantFill::Run(Model* model, std::size_t op_index) { +::tensorflow::Status ResolveConstantFill::Run(Model* model, + std::size_t op_index, + bool* modified) { + *modified = false; const auto fill_it = model->operators.begin() + op_index; auto* base_op = fill_it->get(); if (base_op->type != OperatorType::kFill) { - return false; + return ::tensorflow::Status::OK(); } auto* op = static_cast(base_op); @@ -55,44 +58,44 @@ bool ResolveConstantFill::Run(Model* model, std::size_t op_index) { auto& output_array = model->GetArray(op->outputs[0]); if (output_array.data_type == ArrayDataType::kNone) { // Yield until the output type has been set by PropagateArrayDataTypes - return false; + return ::tensorflow::Status::OK(); } if (!output_array.has_shape()) { // Yield until the output shape has been set by PropagateFixedShapes - return false; + return ::tensorflow::Status::OK(); } const auto& val_array = model->GetArray(op->inputs[1]); if (!val_array.has_shape()) { // Yield until the value shape has been resolved. - return false; + return ::tensorflow::Status::OK(); } if (!IsConstantParameterArray(*model, op->inputs[1])) { // Yield until the value is constant. - return false; + return ::tensorflow::Status::OK(); } CHECK_EQ(RequiredBufferSizeForShape(val_array.shape()), 1); switch (output_array.data_type) { case ArrayDataType::kFloat: if (!ComputeFillArray(model, op)) { - return false; + return ::tensorflow::Status::OK(); } break; case ArrayDataType::kUint8: if (!ComputeFillArray(model, op)) { - return false; + return ::tensorflow::Status::OK(); } break; case ArrayDataType::kInt32: if (!ComputeFillArray(model, op)) { - return false; + return ::tensorflow::Status::OK(); } break; case ArrayDataType::kInt64: if (!ComputeFillArray(model, op)) { - return false; + return ::tensorflow::Status::OK(); } break; default: @@ -114,7 +117,8 @@ bool ResolveConstantFill::Run(Model* model, std::size_t op_index) { // Erase the operator model->operators.erase(fill_it); - return true; + *modified = true; + return ::tensorflow::Status::OK(); } } // namespace toco diff --git a/tensorflow/contrib/lite/toco/graph_transformations/resolve_constant_gather.cc b/tensorflow/contrib/lite/toco/graph_transformations/resolve_constant_gather.cc index 36d7dad0ce..6e3a6a69c2 100644 --- a/tensorflow/contrib/lite/toco/graph_transformations/resolve_constant_gather.cc +++ b/tensorflow/contrib/lite/toco/graph_transformations/resolve_constant_gather.cc @@ -61,11 +61,14 @@ inline void Gather(const Array& input_array, int input_rank, // Resolves a constant Gather operation. // This simply performs the gather and produces the output array with the // appropriate values. -bool ResolveConstantGather::Run(Model* model, std::size_t op_index) { +::tensorflow::Status ResolveConstantGather::Run(Model* model, + std::size_t op_index, + bool* modified) { + *modified = false; auto it = model->operators.begin() + op_index; const auto* base_op = it->get(); if (base_op->type != OperatorType::kGather) { - return false; + return ::tensorflow::Status::OK(); } const auto* op = static_cast(base_op); @@ -74,28 +77,28 @@ bool ResolveConstantGather::Run(Model* model, std::size_t op_index) { auto& output_array = model->GetArray(op->outputs[0]); if (output_array.data_type == ArrayDataType::kNone) { // Yield until the output type has been set by PropagateArrayDataTypes. - return false; + return ::tensorflow::Status::OK(); } if (!output_array.has_shape()) { // Yield until the output shape has been set by PropagateFixedShapes. - return false; + return ::tensorflow::Status::OK(); } if (!op->axis) { // Yield until axis has been set by ResolveGatherAttributes. - return false; + return ::tensorflow::Status::OK(); } if (op->axis.value() != 0) { // Only handling axis=0 for now. AddMessageF("%s has axis %d; only axis=0 is supported", LogName(*op), op->axis.value()); - return false; + return ::tensorflow::Status::OK(); } // We require constant inputs. if (!IsConstantParameterArray(*model, op->inputs[0]) || !IsConstantParameterArray(*model, op->inputs[1])) { - return false; + return ::tensorflow::Status::OK(); } const Array& input_array = model->GetArray(op->inputs[0]); const Array& coords_array = model->GetArray(op->inputs[1]); @@ -142,7 +145,8 @@ bool ResolveConstantGather::Run(Model* model, std::size_t op_index) { // Erase the operator. model->operators.erase(it); - return true; + *modified = true; + return ::tensorflow::Status::OK(); } } // namespace toco diff --git a/tensorflow/contrib/lite/toco/graph_transformations/resolve_constant_pack.cc b/tensorflow/contrib/lite/toco/graph_transformations/resolve_constant_pack.cc index e86616574d..e257ec37e8 100644 --- a/tensorflow/contrib/lite/toco/graph_transformations/resolve_constant_pack.cc +++ b/tensorflow/contrib/lite/toco/graph_transformations/resolve_constant_pack.cc @@ -49,11 +49,14 @@ void Pack(Model* model, PackOperator const& op) { } // namespace -bool ResolveConstantPack::Run(Model* model, std::size_t op_index) { +::tensorflow::Status ResolveConstantPack::Run(Model* model, + std::size_t op_index, + bool* modified) { + *modified = false; auto it = model->operators.begin() + op_index; const auto* base_op = it->get(); if (base_op->type != OperatorType::kPack) { - return false; + return ::tensorflow::Status::OK(); } const auto* op = static_cast(base_op); @@ -62,18 +65,18 @@ bool ResolveConstantPack::Run(Model* model, std::size_t op_index) { auto& output_array = model->GetArray(op->outputs[0]); if (output_array.data_type == ArrayDataType::kNone) { // Yield until the output type has been set by PropagateArrayDataTypes - return false; + return ::tensorflow::Status::OK(); } if (!output_array.has_shape()) { // Yield until the output shape has been set by PropagateFixedShapes - return false; + return ::tensorflow::Status::OK(); } for (const auto& input : op->inputs) { if (!IsConstantParameterArray(*model, input)) { // Yield if any input is mutable - return false; + return ::tensorflow::Status::OK(); } } @@ -111,7 +114,8 @@ bool ResolveConstantPack::Run(Model* model, std::size_t op_index) { // Erase the operator model->operators.erase(it); - return true; + *modified = true; + return ::tensorflow::Status::OK(); } } // namespace toco diff --git a/tensorflow/contrib/lite/toco/graph_transformations/resolve_constant_random_uniform.cc b/tensorflow/contrib/lite/toco/graph_transformations/resolve_constant_random_uniform.cc index 88d06d7dc7..db0fbba528 100644 --- a/tensorflow/contrib/lite/toco/graph_transformations/resolve_constant_random_uniform.cc +++ b/tensorflow/contrib/lite/toco/graph_transformations/resolve_constant_random_uniform.cc @@ -59,11 +59,14 @@ bool ComputeRandomUniformArray(Model* model, RandomUniformOperator* op) { return true; } -bool ResolveConstantRandomUniform::Run(Model* model, std::size_t op_index) { +::tensorflow::Status ResolveConstantRandomUniform::Run(Model* model, + std::size_t op_index, + bool* modified) { + *modified = false; const auto it = model->operators.begin() + op_index; auto* base_op = it->get(); if (base_op->type != OperatorType::kRandomUniform) { - return false; + return ::tensorflow::Status::OK(); } auto* op = static_cast(base_op); @@ -73,12 +76,12 @@ bool ResolveConstantRandomUniform::Run(Model* model, std::size_t op_index) { auto& output_array = model->GetArray(op->outputs[0]); if (output_array.data_type == ArrayDataType::kNone) { // Yield until the output type has been set by PropagateArrayDataTypes - return false; + return ::tensorflow::Status::OK(); } if (!output_array.has_shape()) { // Yield until the output shape has been set by PropagateFixedShapes - return false; + return ::tensorflow::Status::OK(); } if ((op->seed == 0) && (op->seed2 == 0)) { @@ -86,13 +89,13 @@ bool ResolveConstantRandomUniform::Run(Model* model, std::size_t op_index) { << "\" is truly random (using /dev/random system entropy). " "Therefore, cannot resolve as constant. Set \"seed\" or " "\"seed2\" attr non-zero to fix this"; - return false; + return ::tensorflow::Status::OK(); } switch (output_array.data_type) { case ArrayDataType::kFloat: if (!ComputeRandomUniformArray(model, op)) { - return false; + return ::tensorflow::Status::OK(); } break; // For future support of double or half. @@ -110,7 +113,8 @@ bool ResolveConstantRandomUniform::Run(Model* model, std::size_t op_index) { // Erase the operator model->operators.erase(it); - return true; + *modified = true; + return ::tensorflow::Status::OK(); } } // namespace toco diff --git a/tensorflow/contrib/lite/toco/graph_transformations/resolve_constant_range.cc b/tensorflow/contrib/lite/toco/graph_transformations/resolve_constant_range.cc index 1a0ba9e2bc..069d4dafaa 100644 --- a/tensorflow/contrib/lite/toco/graph_transformations/resolve_constant_range.cc +++ b/tensorflow/contrib/lite/toco/graph_transformations/resolve_constant_range.cc @@ -19,11 +19,14 @@ limitations under the License. namespace toco { -bool ResolveConstantRange::Run(Model* model, std::size_t op_index) { +::tensorflow::Status ResolveConstantRange::Run(Model* model, + std::size_t op_index, + bool* modified) { + *modified = false; const auto it = model->operators.begin() + op_index; auto* base_op = it->get(); if (base_op->type != OperatorType::kRange) { - return false; + return ::tensorflow::Status::OK(); } auto* op = static_cast(base_op); @@ -31,23 +34,23 @@ bool ResolveConstantRange::Run(Model* model, std::size_t op_index) { const auto& start_array = model->GetArray(op->inputs[0]); if (!start_array.has_shape()) { // Yield until all input dims have been resolved. - return false; + return ::tensorflow::Status::OK(); } const auto& limit_array = model->GetArray(op->inputs[1]); if (!limit_array.has_shape()) { // Yield until all input dims have been resolved. - return false; + return ::tensorflow::Status::OK(); } const auto& delta_array = model->GetArray(op->inputs[2]); if (!delta_array.has_shape()) { // Yield until all input dims have been resolved. - return false; + return ::tensorflow::Status::OK(); } for (const auto& input : op->inputs) { if (!IsConstantParameterArray(*model, input)) { // yield if any input is mutable - return false; + return ::tensorflow::Status::OK(); } } @@ -55,7 +58,7 @@ bool ResolveConstantRange::Run(Model* model, std::size_t op_index) { auto& output_array = model->GetArray(op->outputs[0]); if (output_array.data_type == ArrayDataType::kNone) { // Yield until the output type has been set by PropagateArrayDataTypes - return false; + return ::tensorflow::Status::OK(); } CHECK_EQ(RequiredBufferSizeForShape(start_array.shape()), 1) @@ -101,7 +104,8 @@ bool ResolveConstantRange::Run(Model* model, std::size_t op_index) { // Delete the operator model->operators.erase(it); - return true; + *modified = true; + return ::tensorflow::Status::OK(); } } // namespace toco diff --git a/tensorflow/contrib/lite/toco/graph_transformations/resolve_constant_reshape.cc b/tensorflow/contrib/lite/toco/graph_transformations/resolve_constant_reshape.cc index a6f665b5f0..fccecef600 100644 --- a/tensorflow/contrib/lite/toco/graph_transformations/resolve_constant_reshape.cc +++ b/tensorflow/contrib/lite/toco/graph_transformations/resolve_constant_reshape.cc @@ -22,11 +22,14 @@ limitations under the License. namespace toco { // Resolves a constant reshape operation by copying the buffer. -bool ResolveConstantReshape::Run(Model* model, std::size_t op_index) { +::tensorflow::Status ResolveConstantReshape::Run(Model* model, + std::size_t op_index, + bool* modified) { + *modified = false; auto it = model->operators.begin() + op_index; const auto* base_op = it->get(); if (base_op->type != OperatorType::kReshape) { - return false; + return ::tensorflow::Status::OK(); } const auto* op = static_cast(base_op); @@ -36,17 +39,17 @@ bool ResolveConstantReshape::Run(Model* model, std::size_t op_index) { // We require constant inputs. if (!IsConstantParameterArray(*model, op->inputs[0]) || !IsConstantParameterArray(*model, op->inputs[1])) { - return false; + return ::tensorflow::Status::OK(); } auto& output_array = model->GetArray(op->outputs[0]); if (output_array.data_type == ArrayDataType::kNone) { // Yield until the output type has been set by PropagateArrayDataTypes. - return false; + return ::tensorflow::Status::OK(); } if (!output_array.has_shape()) { // Yield until the output shape has been set by PropagateFixedShapes. - return false; + return ::tensorflow::Status::OK(); } const Array& input_array = model->GetArray(op->inputs[0]); @@ -54,7 +57,7 @@ bool ResolveConstantReshape::Run(Model* model, std::size_t op_index) { AddMessageF("Constant reshape is non-trivial (%s -> %s)", ShapeToString(input_array.shape()), ShapeToString(output_array.shape())); - return false; + return ::tensorflow::Status::OK(); } CHECK(!output_array.buffer); @@ -95,7 +98,7 @@ bool ResolveConstantReshape::Run(Model* model, std::size_t op_index) { default: LOG(FATAL) << "Unsupported data type: " << ArrayDataTypeName(input_array.data_type); - return false; + return ::tensorflow::Status::OK(); } AddMessageF("Resolving constant reshape of %s", LogName(*op)); @@ -112,7 +115,8 @@ bool ResolveConstantReshape::Run(Model* model, std::size_t op_index) { // Erase the operator. model->operators.erase(it); - return true; + *modified = true; + return ::tensorflow::Status::OK(); } } // namespace toco diff --git a/tensorflow/contrib/lite/toco/graph_transformations/resolve_constant_select.cc b/tensorflow/contrib/lite/toco/graph_transformations/resolve_constant_select.cc index e880a3f44d..ab1e0bd7a0 100644 --- a/tensorflow/contrib/lite/toco/graph_transformations/resolve_constant_select.cc +++ b/tensorflow/contrib/lite/toco/graph_transformations/resolve_constant_select.cc @@ -27,11 +27,14 @@ namespace toco { // This implementation is looking strictly for all-or-nothing on the select // condition. It's possible to enhance this by looking per-element and possibly // producing a Mul op. -bool ResolveConstantSelect::Run(Model* model, std::size_t op_index) { +::tensorflow::Status ResolveConstantSelect::Run(Model* model, + std::size_t op_index, + bool* modified) { + *modified = false; auto it = model->operators.begin() + op_index; const auto* base_op = it->get(); if (base_op->type != OperatorType::kSelect) { - return false; + return ::tensorflow::Status::OK(); } const auto* op = static_cast(base_op); @@ -40,23 +43,23 @@ bool ResolveConstantSelect::Run(Model* model, std::size_t op_index) { auto& output_array = model->GetArray(op->outputs[0]); if (output_array.data_type == ArrayDataType::kNone) { // Yield until the output type has been set by PropagateArrayDataTypes. - return false; + return ::tensorflow::Status::OK(); } if (!output_array.has_shape()) { // Yield until the output shape has been set by PropagateFixedShapes. - return false; + return ::tensorflow::Status::OK(); } // We require the cond input to be constant. if (!IsConstantParameterArray(*model, op->inputs[0])) { - return false; + return ::tensorflow::Status::OK(); } const Array& cond_array = model->GetArray(op->inputs[0]); CHECK(cond_array.data_type == ArrayDataType::kBool) << "Only bool conditions are supported"; const auto& cond_data = cond_array.GetBuffer().data; if (cond_data.empty()) { - return false; + return ::tensorflow::Status::OK(); } // Check if the condition is the same for all elements. @@ -67,12 +70,14 @@ bool ResolveConstantSelect::Run(Model* model, std::size_t op_index) { "Cannot resolve %s as constant; cond_array has differing " "per-element values", LogName(*op)); - return false; + return ::tensorflow::Status::OK(); } } // Pass-through the selected input. - return RemoveTrivialPassthroughOp(this, model, op_index, cond_value ? 1 : 2); + *modified = + RemoveTrivialPassthroughOp(this, model, op_index, cond_value ? 1 : 2); + return ::tensorflow::Status::OK(); } } // namespace toco diff --git a/tensorflow/contrib/lite/toco/graph_transformations/resolve_constant_shape_or_rank.cc b/tensorflow/contrib/lite/toco/graph_transformations/resolve_constant_shape_or_rank.cc index 8a0e3e8995..a1756a8207 100644 --- a/tensorflow/contrib/lite/toco/graph_transformations/resolve_constant_shape_or_rank.cc +++ b/tensorflow/contrib/lite/toco/graph_transformations/resolve_constant_shape_or_rank.cc @@ -19,29 +19,32 @@ limitations under the License. namespace toco { -bool ResolveConstantShapeOrRank::Run(Model* model, std::size_t op_index) { +::tensorflow::Status ResolveConstantShapeOrRank::Run(Model* model, + std::size_t op_index, + bool* modified) { + *modified = false; const auto it = model->operators.begin() + op_index; const auto* op = it->get(); if (!(op->type == OperatorType::kShape || op->type == OperatorType::kRank)) { - return false; + return ::tensorflow::Status::OK(); } CHECK_EQ(op->outputs.size(), 1); auto& output_array = model->GetArray(op->outputs[0]); if (output_array.data_type == ArrayDataType::kNone) { // Yield until the output type has been resolved - return false; + return ::tensorflow::Status::OK(); } const auto& input_array = model->GetArray(op->inputs[0]); if (!input_array.has_shape()) { // Yield until the input array's shape has been resolved. - return false; + return ::tensorflow::Status::OK(); } if (!output_array.has_shape()) { // Yield until the output shape has been resolved. - return false; + return ::tensorflow::Status::OK(); } // Compute the output @@ -65,7 +68,8 @@ bool ResolveConstantShapeOrRank::Run(Model* model, std::size_t op_index) { } model->operators.erase(it); - return true; + *modified = true; + return ::tensorflow::Status::OK(); } } // namespace toco diff --git a/tensorflow/contrib/lite/toco/graph_transformations/resolve_constant_slice.cc b/tensorflow/contrib/lite/toco/graph_transformations/resolve_constant_slice.cc index b35c3e19c4..869dfae98e 100644 --- a/tensorflow/contrib/lite/toco/graph_transformations/resolve_constant_slice.cc +++ b/tensorflow/contrib/lite/toco/graph_transformations/resolve_constant_slice.cc @@ -86,11 +86,14 @@ bool Slice(SliceOperator const& op, Array const& input_array, } // namespace -bool ResolveConstantSlice::Run(Model* model, std::size_t op_index) { +::tensorflow::Status ResolveConstantSlice::Run(Model* model, + std::size_t op_index, + bool* modified) { + *modified = false; const auto it = model->operators.begin() + op_index; const auto* base_op = it->get(); if (base_op->type != OperatorType::kSlice) { - return false; + return ::tensorflow::Status::OK(); } const SliceOperator* op = static_cast(base_op); @@ -99,49 +102,49 @@ bool ResolveConstantSlice::Run(Model* model, std::size_t op_index) { auto& output_array = model->GetArray(op->outputs[0]); if (output_array.data_type == ArrayDataType::kNone) { // Yield until the output type has been set by PropagateArrayDataTypes. - return false; + return ::tensorflow::Status::OK(); } if (!output_array.has_shape()) { // Yield until the output shape has been set by PropagateFixedShapes. - return false; + return ::tensorflow::Status::OK(); } if (op->begin.empty() || op->size.empty()) { // Attributes have not resolved yet. - return false; + return ::tensorflow::Status::OK(); } const auto& input_array = model->GetArray(op->inputs[0]); if (!input_array.has_shape()) { // Yield until the value shape has been resolved. - return false; + return ::tensorflow::Status::OK(); } if (!IsConstantParameterArray(*model, op->inputs[0])) { // Yield until the value is constant. - return false; + return ::tensorflow::Status::OK(); } CHECK(!output_array.buffer); switch (output_array.data_type) { case ArrayDataType::kFloat: if (!Slice(*op, input_array, &output_array)) { - return false; + return ::tensorflow::Status::OK(); } break; case ArrayDataType::kUint8: if (!Slice(*op, input_array, &output_array)) { - return false; + return ::tensorflow::Status::OK(); } break; case ArrayDataType::kInt32: if (!Slice(*op, input_array, &output_array)) { - return false; + return ::tensorflow::Status::OK(); } break; case ArrayDataType::kInt64: if (!Slice(*op, input_array, &output_array)) { - return false; + return ::tensorflow::Status::OK(); } break; default: @@ -159,7 +162,8 @@ bool ResolveConstantSlice::Run(Model* model, std::size_t op_index) { // Erase the operator model->operators.erase(it); - return true; + *modified = true; + return ::tensorflow::Status::OK(); } } // namespace toco diff --git a/tensorflow/contrib/lite/toco/graph_transformations/resolve_constant_strided_slice.cc b/tensorflow/contrib/lite/toco/graph_transformations/resolve_constant_strided_slice.cc index 8853ed87e6..99c5a64662 100644 --- a/tensorflow/contrib/lite/toco/graph_transformations/resolve_constant_strided_slice.cc +++ b/tensorflow/contrib/lite/toco/graph_transformations/resolve_constant_strided_slice.cc @@ -103,11 +103,14 @@ void StridedSlice(StridedSliceOperator const& op, Array const& input_array, } // anonymous namespace -bool ResolveConstantStridedSlice::Run(Model* model, std::size_t op_index) { +::tensorflow::Status ResolveConstantStridedSlice::Run(Model* model, + std::size_t op_index, + bool* modified) { + *modified = false; const auto it = model->operators.begin() + op_index; const auto* base_op = it->get(); if (base_op->type != OperatorType::kStridedSlice) { - return false; + return ::tensorflow::Status::OK(); } const StridedSliceOperator* op = @@ -117,28 +120,28 @@ bool ResolveConstantStridedSlice::Run(Model* model, std::size_t op_index) { auto& output_array = model->GetArray(op->outputs[0]); if (output_array.data_type == ArrayDataType::kNone) { // Yield until the output type has been set by PropagateArrayDataTypes - return false; + return ::tensorflow::Status::OK(); } if (!output_array.has_shape()) { // Yield until the output shape has been set by PropagateFixedShapes - return false; + return ::tensorflow::Status::OK(); } if (op->start_indices.empty() || op->stop_indices.empty() || op->strides.empty()) { // Attributes have not resolved yet. - return false; + return ::tensorflow::Status::OK(); } const auto& input_array = model->GetArray(op->inputs[0]); if (!input_array.has_shape()) { // Yield until the value shape has been resolved. - return false; + return ::tensorflow::Status::OK(); } if (!IsConstantParameterArray(*model, op->inputs[0])) { // Yield until the value is constant. - return false; + return ::tensorflow::Status::OK(); } CHECK(!output_array.buffer); @@ -164,7 +167,8 @@ bool ResolveConstantStridedSlice::Run(Model* model, std::size_t op_index) { DeleteOpAndArraysIfUnused(model, it->get()); - return true; + *modified = true; + return ::tensorflow::Status::OK(); } } // namespace toco diff --git a/tensorflow/contrib/lite/toco/graph_transformations/resolve_constant_tile.cc b/tensorflow/contrib/lite/toco/graph_transformations/resolve_constant_tile.cc index 5cfa1a5582..c5e93c9bad 100644 --- a/tensorflow/contrib/lite/toco/graph_transformations/resolve_constant_tile.cc +++ b/tensorflow/contrib/lite/toco/graph_transformations/resolve_constant_tile.cc @@ -97,11 +97,14 @@ inline void Tile(const Array& input_array, const Array& multiples_array, } // namespace // Resolves a constant Tile operation. -bool ResolveConstantTile::Run(Model* model, std::size_t op_index) { +::tensorflow::Status ResolveConstantTile::Run(Model* model, + std::size_t op_index, + bool* modified) { + *modified = false; auto it = model->operators.begin() + op_index; const auto* base_op = it->get(); if (base_op->type != OperatorType::kTile) { - return false; + return ::tensorflow::Status::OK(); } const auto* op = static_cast(base_op); @@ -110,17 +113,17 @@ bool ResolveConstantTile::Run(Model* model, std::size_t op_index) { auto& output_array = model->GetArray(op->outputs[0]); if (output_array.data_type == ArrayDataType::kNone) { // Yield until the output type has been set by PropagateArrayDataTypes. - return false; + return ::tensorflow::Status::OK(); } if (!output_array.has_shape()) { // Yield until the output shape has been set by PropagateFixedShapes. - return false; + return ::tensorflow::Status::OK(); } // We require constant inputs. if (!IsConstantParameterArray(*model, op->inputs[0]) || !IsConstantParameterArray(*model, op->inputs[1])) { - return false; + return ::tensorflow::Status::OK(); } const Array& input_array = model->GetArray(op->inputs[0]); const Array& multiples_array = model->GetArray(op->inputs[1]); @@ -159,7 +162,8 @@ bool ResolveConstantTile::Run(Model* model, std::size_t op_index) { // Erase the operator. model->operators.erase(it); - return true; + *modified = true; + return ::tensorflow::Status::OK(); } } // namespace toco diff --git a/tensorflow/contrib/lite/toco/graph_transformations/resolve_constant_transpose.cc b/tensorflow/contrib/lite/toco/graph_transformations/resolve_constant_transpose.cc index fe15dfa06f..b759c4d6dd 100644 --- a/tensorflow/contrib/lite/toco/graph_transformations/resolve_constant_transpose.cc +++ b/tensorflow/contrib/lite/toco/graph_transformations/resolve_constant_transpose.cc @@ -101,11 +101,14 @@ void Transpose(Model* model, const Array& input_array, } // namespace -bool ResolveConstantTranspose::Run(Model* model, std::size_t op_index) { +::tensorflow::Status ResolveConstantTranspose::Run(Model* model, + std::size_t op_index, + bool* modified) { + *modified = false; auto it = model->operators.begin() + op_index; const auto* base_op = it->get(); if (base_op->type != OperatorType::kTranspose) { - return false; + return ::tensorflow::Status::OK(); } const auto* op = static_cast(base_op); @@ -114,17 +117,17 @@ bool ResolveConstantTranspose::Run(Model* model, std::size_t op_index) { auto& output_array = model->GetArray(op->outputs[0]); if (output_array.data_type == ArrayDataType::kNone) { // Yield until the output type has been set by PropagateArrayDataTypes. - return false; + return ::tensorflow::Status::OK(); } if (!output_array.has_shape()) { // Yield until the output shape has been set by PropagateFixedShapes. - return false; + return ::tensorflow::Status::OK(); } // We require constant inputs. if (!IsConstantParameterArray(*model, op->inputs[0]) || !IsConstantParameterArray(*model, op->inputs[1])) { - return false; + return ::tensorflow::Status::OK(); } const Array& input_array = model->GetArray(op->inputs[0]); @@ -132,7 +135,7 @@ bool ResolveConstantTranspose::Run(Model* model, std::size_t op_index) { if (op->perm.empty()) { // Yield until perm has been populated by ResolveTransposeAttributes. - return false; + return ::tensorflow::Status::OK(); } // We currently only support 1-4 dimensions. @@ -174,7 +177,8 @@ bool ResolveConstantTranspose::Run(Model* model, std::size_t op_index) { // Erase the operator. model->operators.erase(it); - return true; + *modified = true; + return ::tensorflow::Status::OK(); } } // namespace toco diff --git a/tensorflow/contrib/lite/toco/graph_transformations/resolve_constant_unary.cc b/tensorflow/contrib/lite/toco/graph_transformations/resolve_constant_unary.cc index 5364eebbc9..3034c1b1eb 100644 --- a/tensorflow/contrib/lite/toco/graph_transformations/resolve_constant_unary.cc +++ b/tensorflow/contrib/lite/toco/graph_transformations/resolve_constant_unary.cc @@ -112,7 +112,10 @@ bool CopyMinMaxFromFirstInput(const Operator& op, Model* model) { return true; } -bool ResolveConstantUnaryOperator::Run(Model* model, std::size_t op_index) { +::tensorflow::Status ResolveConstantUnaryOperator::Run(Model* model, + std::size_t op_index, + bool* modified) { + *modified = false; const auto unary_it = model->operators.begin() + op_index; const auto* unary_op = unary_it->get(); // Test for unary ops of types that we know how to resolve. @@ -133,28 +136,28 @@ bool ResolveConstantUnaryOperator::Run(Model* model, std::size_t op_index) { case OperatorType::kRelu: break; default: - return false; + return ::tensorflow::Status::OK(); } // Check if the input is a constant parameter. if (!IsConstantParameterArray(*model, unary_op->inputs[0])) { - return false; + return ::tensorflow::Status::OK(); } // if the unary op involves a tensor required by a rnn state, ignore it for (const auto& rnn_state : model->flags.rnn_states()) { if (unary_op->inputs[0] == rnn_state.back_edge_source_array()) { - return false; + return ::tensorflow::Status::OK(); } if (unary_op->inputs[0] == rnn_state.state_array()) { - return false; + return ::tensorflow::Status::OK(); } } auto& output_array = model->GetArray(unary_op->outputs[0]); if (!output_array.has_shape()) { // Yield until the output array dims have been resolved. - return false; + return ::tensorflow::Status::OK(); } // At the moment we don't want to care about fused activation functions. @@ -166,7 +169,7 @@ bool ResolveConstantUnaryOperator::Run(Model* model, std::size_t op_index) { "Not resolving constant %s " " because it has a fused activation function", LogName(*unary_op)); - return false; + return ::tensorflow::Status::OK(); } // The min-max is only copied for ops that copy data without arithmetic. @@ -187,7 +190,7 @@ bool ResolveConstantUnaryOperator::Run(Model* model, std::size_t op_index) { "Not resolving constant %s because we currently only support casting " "to float", LogName(*unary_op)); - return false; + return ::tensorflow::Status::OK(); } if (cast_op->src_data_type != input_array.buffer->type) { AddMessageF( @@ -197,7 +200,7 @@ bool ResolveConstantUnaryOperator::Run(Model* model, std::size_t op_index) { } } else { if (input_array.buffer->type != ArrayDataType::kFloat) { - return false; + return ::tensorflow::Status::OK(); } input_float_data = &(input_array.GetBuffer().data); } @@ -239,7 +242,7 @@ bool ResolveConstantUnaryOperator::Run(Model* model, std::size_t op_index) { CHECK_EQ(unary_op->inputs.size(), 2) << "Sum needs 2 inputs"; if (!IsConstantParameterArray(*model, unary_op->inputs[1])) { AddMessageF("Axis input is non-constant"); - return false; + return ::tensorflow::Status::OK(); } auto& axis_array = model->GetArray(unary_op->inputs[1]); CHECK(axis_array.data_type == ArrayDataType::kInt32); @@ -336,7 +339,7 @@ bool ResolveConstantUnaryOperator::Run(Model* model, std::size_t op_index) { default: LOG(FATAL) << "Unsupported activation function " << LogName(*unary_op); - return false; + return ::tensorflow::Status::OK(); } output_float_data[i] = new_value; } @@ -351,7 +354,8 @@ bool ResolveConstantUnaryOperator::Run(Model* model, std::size_t op_index) { AddMessageF("Resolved constant %s to the equivalent constant array", LogName(*unary_op)); model->operators.erase(unary_it); - return true; + *modified = true; + return ::tensorflow::Status::OK(); } } // namespace toco diff --git a/tensorflow/contrib/lite/toco/graph_transformations/resolve_fake_quant_args_from_vars.cc b/tensorflow/contrib/lite/toco/graph_transformations/resolve_fake_quant_args_from_vars.cc index 0dda1fd0b3..eed971c1d5 100644 --- a/tensorflow/contrib/lite/toco/graph_transformations/resolve_fake_quant_args_from_vars.cc +++ b/tensorflow/contrib/lite/toco/graph_transformations/resolve_fake_quant_args_from_vars.cc @@ -25,17 +25,20 @@ limitations under the License. namespace toco { -bool ResolveFakeQuantArgsFromVars::Run(Model* model, std::size_t op_index) { +::tensorflow::Status ResolveFakeQuantArgsFromVars::Run(Model* model, + std::size_t op_index, + bool* modified) { + *modified = false; const auto fakequant_it = model->operators.begin() + op_index; auto* fakequant_base_op = fakequant_it->get(); if (fakequant_base_op->type != OperatorType::kFakeQuant) { - return false; + return ::tensorflow::Status::OK(); } auto* fakequant_op = static_cast(fakequant_base_op); if (fakequant_op->minmax) { // Already resolved. - return false; + return ::tensorflow::Status::OK(); } CHECK_EQ(fakequant_op->inputs.size(), 3); @@ -43,7 +46,7 @@ bool ResolveFakeQuantArgsFromVars::Run(Model* model, std::size_t op_index) { // resolved to constant arrays. for (int i = 1; i <= 2; i++) { if (!IsConstantParameterArray(*model, fakequant_op->inputs[i])) { - return false; + return ::tensorflow::Status::OK(); } } @@ -74,7 +77,8 @@ bool ResolveFakeQuantArgsFromVars::Run(Model* model, std::size_t op_index) { DeleteArrayIfUsedOnce(fakequant_op->inputs[i], model); } fakequant_op->inputs.resize(1); - return true; + *modified = true; + return ::tensorflow::Status::OK(); } } // namespace toco diff --git a/tensorflow/contrib/lite/toco/graph_transformations/resolve_gather_attributes.cc b/tensorflow/contrib/lite/toco/graph_transformations/resolve_gather_attributes.cc index ce825c91af..69209b8dec 100644 --- a/tensorflow/contrib/lite/toco/graph_transformations/resolve_gather_attributes.cc +++ b/tensorflow/contrib/lite/toco/graph_transformations/resolve_gather_attributes.cc @@ -24,20 +24,25 @@ limitations under the License. namespace toco { -bool ResolveGatherAttributes::Run(Model* model, std::size_t op_index) { +::tensorflow::Status ResolveGatherAttributes::Run(Model* model, + std::size_t op_index, + bool* modified) { + *modified = false; auto* gather_op = model->operators[op_index].get(); - if (gather_op->type != OperatorType::kGather) return false; + if (gather_op->type != OperatorType::kGather) + return ::tensorflow::Status::OK(); auto* op = static_cast(gather_op); if (op->axis) { // Attributes already resolved - return false; + return ::tensorflow::Status::OK(); } - if (op->inputs.size() != 3) return false; - if (!IsConstantParameterArray(*model, op->inputs[2])) return false; + if (op->inputs.size() != 3) return ::tensorflow::Status::OK(); + if (!IsConstantParameterArray(*model, op->inputs[2])) + return ::tensorflow::Status::OK(); const auto& indices_array = model->GetArray(op->inputs[2]); - if (!indices_array.has_shape()) return false; + if (!indices_array.has_shape()) return ::tensorflow::Status::OK(); const auto& axis_data = indices_array.GetBuffer().data; CHECK_EQ(axis_data.size(), 1) << "Multidimensional gather not supported on " << LogName(*op); @@ -47,7 +52,8 @@ bool ResolveGatherAttributes::Run(Model* model, std::size_t op_index) { DeleteArrayIfUsedOnce(op->inputs[2], model); op->inputs.resize(2); - return true; + *modified = true; + return ::tensorflow::Status::OK(); } } // namespace toco diff --git a/tensorflow/contrib/lite/toco/graph_transformations/resolve_multiply_by_zero.cc b/tensorflow/contrib/lite/toco/graph_transformations/resolve_multiply_by_zero.cc index b2b2ea151b..ac94f45321 100644 --- a/tensorflow/contrib/lite/toco/graph_transformations/resolve_multiply_by_zero.cc +++ b/tensorflow/contrib/lite/toco/graph_transformations/resolve_multiply_by_zero.cc @@ -51,27 +51,30 @@ void FillArrayWithZeros(Array* array) { // Removes a multiplication by array of constant zeros by making the output // array an array of constant zeros and removing the input arrays if they are no // longer needed. -bool ResolveMultiplyByZero::Run(Model* model, std::size_t op_index) { +::tensorflow::Status ResolveMultiplyByZero::Run(Model* model, + std::size_t op_index, + bool* modified) { + *modified = false; const auto mul_it = model->operators.begin() + op_index; auto* mul_op = mul_it->get(); if (mul_op->type != OperatorType::kMul) { - return false; + return ::tensorflow::Status::OK(); } const auto& output_array_name = mul_op->outputs[0]; auto& output_array = model->GetArray(output_array_name); if (!IsDiscardableArray(*model, output_array_name)) { - return false; + return ::tensorflow::Status::OK(); } if (output_array.data_type == ArrayDataType::kNone) { // Yield until the output type has been set by PropagateArrayDataTypes - return false; + return ::tensorflow::Status::OK(); } // Yield if the output shape is not known yet. if (!output_array.has_shape()) { - return false; + return ::tensorflow::Status::OK(); } // This transformation only handles the case where one operand is all 0's and @@ -83,12 +86,12 @@ bool ResolveMultiplyByZero::Run(Model* model, std::size_t op_index) { }; if (!is_input_constant[0] && !is_input_constant[1]) { // Neither input is constant, so nothing we can resolve here. - return false; + return ::tensorflow::Status::OK(); } if (is_input_constant[0] && is_input_constant[1]) { // Both inputs are constants. That's a job for constants propagation, not // for us to handle here. - return false; + return ::tensorflow::Status::OK(); } const int index_of_constant_input = is_input_constant[0] ? 0 : 1; const int index_of_variable_input = is_input_constant[0] ? 1 : 0; @@ -105,7 +108,7 @@ bool ResolveMultiplyByZero::Run(Model* model, std::size_t op_index) { constant_input_array.GetBuffer().data; if (!AreAllBufferElementsZero>( constant_input_data)) { - return false; + return ::tensorflow::Status::OK(); } FillArrayWithZeros(&output_array); } break; @@ -114,7 +117,7 @@ bool ResolveMultiplyByZero::Run(Model* model, std::size_t op_index) { constant_input_array.GetBuffer().data; if (!AreAllBufferElementsZero>( constant_input_data)) { - return false; + return ::tensorflow::Status::OK(); } FillArrayWithZeros(&output_array); } break; @@ -123,7 +126,7 @@ bool ResolveMultiplyByZero::Run(Model* model, std::size_t op_index) { constant_input_array.GetBuffer().data; if (!AreAllBufferElementsZero>( constant_input_data)) { - return false; + return ::tensorflow::Status::OK(); } FillArrayWithZeros(&output_array); } break; @@ -132,14 +135,14 @@ bool ResolveMultiplyByZero::Run(Model* model, std::size_t op_index) { constant_input_array.GetBuffer().data; if (!AreAllBufferElementsZero>( constant_input_data)) { - return false; + return ::tensorflow::Status::OK(); } FillArrayWithZeros(&output_array); } break; default: AddMessageF( "Cannot resolve multiply by 0 because of unsupported data type\n"); - return false; + return ::tensorflow::Status::OK(); } // Erase input arrays to the multiply if no longer used @@ -149,7 +152,8 @@ bool ResolveMultiplyByZero::Run(Model* model, std::size_t op_index) { // Erase the multiply operator. model->operators.erase(mul_it); - return true; + *modified = true; + return ::tensorflow::Status::OK(); } } // namespace toco diff --git a/tensorflow/contrib/lite/toco/graph_transformations/resolve_pad_attributes.cc b/tensorflow/contrib/lite/toco/graph_transformations/resolve_pad_attributes.cc index 8a8e723cf7..adc87753bc 100644 --- a/tensorflow/contrib/lite/toco/graph_transformations/resolve_pad_attributes.cc +++ b/tensorflow/contrib/lite/toco/graph_transformations/resolve_pad_attributes.cc @@ -24,19 +24,23 @@ limitations under the License. namespace toco { -bool ResolvePadAttributes::Run(Model* model, std::size_t op_index) { +::tensorflow::Status ResolvePadAttributes::Run(Model* model, + std::size_t op_index, + bool* modified) { + *modified = false; const auto pad_it = model->operators.begin() + op_index; auto* pad_op = pad_it->get(); - if (pad_op->type != OperatorType::kPad) return false; + if (pad_op->type != OperatorType::kPad) return ::tensorflow::Status::OK(); auto* op = static_cast(pad_op); - if (!op->left_padding.empty()) return false; + if (!op->left_padding.empty()) return ::tensorflow::Status::OK(); CHECK_EQ(op->inputs.size(), 2); - if (!IsConstantParameterArray(*model, op->inputs[1])) return false; + if (!IsConstantParameterArray(*model, op->inputs[1])) + return ::tensorflow::Status::OK(); const auto& array = model->GetArray(op->inputs[1]); - if (!array.has_shape()) return false; + if (!array.has_shape()) return ::tensorflow::Status::OK(); const std::vector& dims = array.shape().dims(); CHECK_EQ(dims.size(), 2); @@ -50,6 +54,7 @@ bool ResolvePadAttributes::Run(Model* model, std::size_t op_index) { // TODO(dkalenichenko): Delete the extra input? - return true; + *modified = true; + return ::tensorflow::Status::OK(); } } // namespace toco diff --git a/tensorflow/contrib/lite/toco/graph_transformations/resolve_padv2_attributes.cc b/tensorflow/contrib/lite/toco/graph_transformations/resolve_padv2_attributes.cc index ebb023e342..1f0f17a37a 100644 --- a/tensorflow/contrib/lite/toco/graph_transformations/resolve_padv2_attributes.cc +++ b/tensorflow/contrib/lite/toco/graph_transformations/resolve_padv2_attributes.cc @@ -24,19 +24,23 @@ limitations under the License. namespace toco { -bool ResolvePadV2Attributes::Run(Model* model, std::size_t op_index) { +::tensorflow::Status ResolvePadV2Attributes::Run(Model* model, + std::size_t op_index, + bool* modified) { + *modified = false; const auto pad_it = model->operators.begin() + op_index; auto* pad_op = pad_it->get(); - if (pad_op->type != OperatorType::kPadV2) return false; + if (pad_op->type != OperatorType::kPadV2) return ::tensorflow::Status::OK(); auto* op = static_cast(pad_op); - if (!op->left_padding.empty()) return false; + if (!op->left_padding.empty()) return ::tensorflow::Status::OK(); CHECK_EQ(op->inputs.size(), 3); - if (!IsConstantParameterArray(*model, op->inputs[1])) return false; + if (!IsConstantParameterArray(*model, op->inputs[1])) + return ::tensorflow::Status::OK(); const auto& array = model->GetArray(op->inputs[1]); - if (!array.has_shape()) return false; + if (!array.has_shape()) return ::tensorflow::Status::OK(); const std::vector& dims = array.shape().dims(); CHECK_EQ(dims.size(), 2); @@ -50,6 +54,7 @@ bool ResolvePadV2Attributes::Run(Model* model, std::size_t op_index) { // TODO(dkalenichenko): Delete the extra input? - return true; + *modified = true; + return ::tensorflow::Status::OK(); } } // namespace toco diff --git a/tensorflow/contrib/lite/toco/graph_transformations/resolve_reduce_attributes.cc b/tensorflow/contrib/lite/toco/graph_transformations/resolve_reduce_attributes.cc index 73198ac7c0..c3246ab90f 100644 --- a/tensorflow/contrib/lite/toco/graph_transformations/resolve_reduce_attributes.cc +++ b/tensorflow/contrib/lite/toco/graph_transformations/resolve_reduce_attributes.cc @@ -39,23 +39,37 @@ bool ResolveAttributes(Model* model, T* op) { return true; } -bool ResolveReduceAttributes::Run(Model* model, std::size_t op_index) { +::tensorflow::Status ResolveReduceAttributes::Run(Model* model, + std::size_t op_index, + bool* modified) { + *modified = false; Operator* op = model->operators[op_index].get(); switch (op->type) { case OperatorType::kMean: - return ResolveAttributes(model, static_cast(op)); + *modified = ResolveAttributes(model, static_cast(op)); + return ::tensorflow::Status::OK(); case OperatorType::kSum: - return ResolveAttributes(model, static_cast(op)); + *modified = + ResolveAttributes(model, static_cast(op)); + return ::tensorflow::Status::OK(); case OperatorType::kReduceProd: - return ResolveAttributes(model, static_cast(op)); + *modified = + ResolveAttributes(model, static_cast(op)); + return ::tensorflow::Status::OK(); case OperatorType::kReduceMin: - return ResolveAttributes(model, static_cast(op)); + *modified = + ResolveAttributes(model, static_cast(op)); + return ::tensorflow::Status::OK(); case OperatorType::kReduceMax: - return ResolveAttributes(model, static_cast(op)); + *modified = + ResolveAttributes(model, static_cast(op)); + return ::tensorflow::Status::OK(); case OperatorType::kAny: - return ResolveAttributes(model, static_cast(op)); + *modified = + ResolveAttributes(model, static_cast(op)); + return ::tensorflow::Status::OK(); default: - return false; + return ::tensorflow::Status::OK(); } } diff --git a/tensorflow/contrib/lite/toco/graph_transformations/resolve_reorder_axes.cc b/tensorflow/contrib/lite/toco/graph_transformations/resolve_reorder_axes.cc index 8e150db6fa..ee5c4810e6 100644 --- a/tensorflow/contrib/lite/toco/graph_transformations/resolve_reorder_axes.cc +++ b/tensorflow/contrib/lite/toco/graph_transformations/resolve_reorder_axes.cc @@ -78,11 +78,13 @@ void ReorderAxes(AxesOrder input_axes_order, AxesOrder output_axes_order, } } -bool ResolveReorderAxes::Run(Model* model, std::size_t op_index) { +::tensorflow::Status ResolveReorderAxes::Run(Model* model, std::size_t op_index, + bool* modified) { + *modified = false; auto it = model->operators.begin() + op_index; auto* op = it->get(); if (op->type != OperatorType::kReorderAxes) { - return false; + return ::tensorflow::Status::OK(); } auto* reorder_op = static_cast(op); @@ -93,11 +95,11 @@ bool ResolveReorderAxes::Run(Model* model, std::size_t op_index) { auto& input_array = model->GetArray(input_array_name); auto& output_array = model->GetArray(output_array_name); if (!input_array.buffer) { - return false; + return ::tensorflow::Status::OK(); } // Yield until output dims have been resolved. if (!output_array.has_shape()) { - return false; + return ::tensorflow::Status::OK(); } // Reorder the input array dims and buffer data if (input_array.buffer->type == ArrayDataType::kFloat) { @@ -120,7 +122,8 @@ bool ResolveReorderAxes::Run(Model* model, std::size_t op_index) { DeleteOpAndArraysIfUnused(model, op); RenameArray(model, output_array_name, input_array_name); - return true; + *modified = true; + return ::tensorflow::Status::OK(); } } // namespace toco diff --git a/tensorflow/contrib/lite/toco/graph_transformations/resolve_reshape_attributes.cc b/tensorflow/contrib/lite/toco/graph_transformations/resolve_reshape_attributes.cc index b615c9a545..7b7a59264f 100644 --- a/tensorflow/contrib/lite/toco/graph_transformations/resolve_reshape_attributes.cc +++ b/tensorflow/contrib/lite/toco/graph_transformations/resolve_reshape_attributes.cc @@ -25,25 +25,29 @@ limitations under the License. namespace toco { -bool ResolveReshapeAttributes::Run(Model* model, std::size_t op_index) { +::tensorflow::Status ResolveReshapeAttributes::Run(Model* model, + std::size_t op_index, + bool* modified) { + *modified = false; const auto reshape_it = model->operators.begin() + op_index; auto* reshape_op = reshape_it->get(); if (reshape_op->type != OperatorType::kReshape) { - return false; + return ::tensorflow::Status::OK(); } auto* op = static_cast(reshape_op); - if (!op->shape.empty()) return false; + if (!op->shape.empty()) return ::tensorflow::Status::OK(); if (IsConstantParameterArray(*model, reshape_op->inputs[1])) { const auto& constant_input_array = model->GetArray(reshape_op->inputs[1]); op->shape = constant_input_array.GetBuffer().data; } - if (op->shape.empty()) return false; + if (op->shape.empty()) return ::tensorflow::Status::OK(); - return true; + *modified = true; + return ::tensorflow::Status::OK(); } } // namespace toco diff --git a/tensorflow/contrib/lite/toco/graph_transformations/resolve_slice_attributes.cc b/tensorflow/contrib/lite/toco/graph_transformations/resolve_slice_attributes.cc index e760d08e5a..5a838168de 100644 --- a/tensorflow/contrib/lite/toco/graph_transformations/resolve_slice_attributes.cc +++ b/tensorflow/contrib/lite/toco/graph_transformations/resolve_slice_attributes.cc @@ -24,29 +24,35 @@ limitations under the License. namespace toco { -bool ResolveSliceAttributes::Run(Model* model, std::size_t op_index) { +::tensorflow::Status ResolveSliceAttributes::Run(Model* model, + std::size_t op_index, + bool* modified) { + *modified = false; const auto slice_it = model->operators.begin() + op_index; auto* slice_op = slice_it->get(); - if (slice_op->type != OperatorType::kSlice) return false; + if (slice_op->type != OperatorType::kSlice) return ::tensorflow::Status::OK(); auto* op = static_cast(slice_op); - if (!op->begin.empty()) return false; + if (!op->begin.empty()) return ::tensorflow::Status::OK(); CHECK_EQ(op->inputs.size(), 3); - if (!IsConstantParameterArray(*model, op->inputs[1])) return false; - if (!IsConstantParameterArray(*model, op->inputs[2])) return false; + if (!IsConstantParameterArray(*model, op->inputs[1])) + return ::tensorflow::Status::OK(); + if (!IsConstantParameterArray(*model, op->inputs[2])) + return ::tensorflow::Status::OK(); const auto& begin_array = model->GetArray(op->inputs[1]); - if (!begin_array.has_shape()) return false; + if (!begin_array.has_shape()) return ::tensorflow::Status::OK(); const auto& size_array = model->GetArray(op->inputs[2]); - if (!size_array.has_shape()) return false; + if (!size_array.has_shape()) return ::tensorflow::Status::OK(); op->begin = begin_array.GetBuffer().data; op->size = size_array.GetBuffer().data; // TODO(dkalenichenko): Delete the extra inputs? - return true; + *modified = true; + return ::tensorflow::Status::OK(); } } // namespace toco diff --git a/tensorflow/contrib/lite/toco/graph_transformations/resolve_space_to_batch_nd_attributes.cc b/tensorflow/contrib/lite/toco/graph_transformations/resolve_space_to_batch_nd_attributes.cc index fab50bec1f..3804145c4f 100644 --- a/tensorflow/contrib/lite/toco/graph_transformations/resolve_space_to_batch_nd_attributes.cc +++ b/tensorflow/contrib/lite/toco/graph_transformations/resolve_space_to_batch_nd_attributes.cc @@ -24,16 +24,20 @@ limitations under the License. namespace toco { -bool ResolveSpaceToBatchNDAttributes::Run(Model* model, std::size_t op_index) { +::tensorflow::Status ResolveSpaceToBatchNDAttributes::Run(Model* model, + std::size_t op_index, + bool* modified) { + *modified = false; const auto op_it = model->operators.begin() + op_index; - if (op_it->get()->type != OperatorType::kSpaceToBatchND) return false; + if (op_it->get()->type != OperatorType::kSpaceToBatchND) + return ::tensorflow::Status::OK(); auto* op = static_cast(op_it->get()); // The attributes are resolved only when the 3 attributes (block_shape, // before_paddings, after_paddings) are all constant. if (!op->block_shape.empty()) { - return false; + return ::tensorflow::Status::OK(); } const int block_shape_index = 1; @@ -42,16 +46,16 @@ bool ResolveSpaceToBatchNDAttributes::Run(Model* model, std::size_t op_index) { CHECK_EQ(op->inputs.size(), 3); if (!IsConstantParameterArray(*model, op->inputs[block_shape_index]) || !IsConstantParameterArray(*model, op->inputs[paddings_index])) - return false; + return ::tensorflow::Status::OK(); // Handle paddings. const auto& paddings_array = model->GetArray(op->inputs[paddings_index]); - if (!paddings_array.has_shape()) return false; + if (!paddings_array.has_shape()) return ::tensorflow::Status::OK(); const std::vector& paddings_dims = paddings_array.shape().dims(); if (paddings_dims.size() != 2) { // Code only handles padding of 2 dimensions. Perhaps another transformation // will delete this op. - return false; + return ::tensorflow::Status::OK(); } const std::vector& paddings_buffer = paddings_array.GetBuffer().data; @@ -63,7 +67,7 @@ bool ResolveSpaceToBatchNDAttributes::Run(Model* model, std::size_t op_index) { // Handle block_shape. const auto& block_shape_array = model->GetArray(op->inputs[block_shape_index]); - if (!block_shape_array.has_shape()) return false; + if (!block_shape_array.has_shape()) return ::tensorflow::Status::OK(); const std::vector& block_shape_dims = block_shape_array.shape().dims(); CHECK_EQ(block_shape_dims.size(), 1); const std::vector& block_shape_buffer = @@ -72,7 +76,8 @@ bool ResolveSpaceToBatchNDAttributes::Run(Model* model, std::size_t op_index) { op->block_shape.push_back(block_shape_buffer[i]); } - return true; + *modified = true; + return ::tensorflow::Status::OK(); } } // namespace toco diff --git a/tensorflow/contrib/lite/toco/graph_transformations/resolve_squeeze_attributes.cc b/tensorflow/contrib/lite/toco/graph_transformations/resolve_squeeze_attributes.cc index e8bb85704e..c601b0774e 100644 --- a/tensorflow/contrib/lite/toco/graph_transformations/resolve_squeeze_attributes.cc +++ b/tensorflow/contrib/lite/toco/graph_transformations/resolve_squeeze_attributes.cc @@ -25,10 +25,13 @@ limitations under the License. namespace toco { -bool ResolveSqueezeAttributes::Run(Model* model, std::size_t op_index) { +::tensorflow::Status ResolveSqueezeAttributes::Run(Model* model, + std::size_t op_index, + bool* modified) { + *modified = false; auto* squeeze_op = model->operators[op_index].get(); if (squeeze_op->type != OperatorType::kSqueeze) { - return false; + return ::tensorflow::Status::OK(); } DCHECK_EQ(squeeze_op->inputs.size(), 1); DCHECK_EQ(squeeze_op->outputs.size(), 1); @@ -42,10 +45,11 @@ bool ResolveSqueezeAttributes::Run(Model* model, std::size_t op_index) { "Reshape op", LogName(*squeeze_op)); - return RemoveTrivialPassthroughOp(this, model, op_index); + *modified = RemoveTrivialPassthroughOp(this, model, op_index); + return ::tensorflow::Status::OK(); } } - return false; + return ::tensorflow::Status::OK(); } } // namespace toco diff --git a/tensorflow/contrib/lite/toco/graph_transformations/resolve_strided_slice_attributes.cc b/tensorflow/contrib/lite/toco/graph_transformations/resolve_strided_slice_attributes.cc index 65132d7d1e..f54f5b42a1 100644 --- a/tensorflow/contrib/lite/toco/graph_transformations/resolve_strided_slice_attributes.cc +++ b/tensorflow/contrib/lite/toco/graph_transformations/resolve_strided_slice_attributes.cc @@ -37,40 +37,47 @@ int PadAttributeArray(Array* attribute_array, std::vector pad_values, return mask; } -bool ResolveStridedSliceAttributes::Run(Model* model, std::size_t op_index) { +::tensorflow::Status ResolveStridedSliceAttributes::Run(Model* model, + std::size_t op_index, + bool* modified) { + *modified = false; const auto slice_it = model->operators.begin() + op_index; auto* slice_op = slice_it->get(); - if (slice_op->type != OperatorType::kStridedSlice) return false; + if (slice_op->type != OperatorType::kStridedSlice) + return ::tensorflow::Status::OK(); auto* op = static_cast(slice_op); if (!op->start_indices.empty()) { // We have already resolved these attributes - return false; + return ::tensorflow::Status::OK(); } CHECK_EQ(op->inputs.size(), 4); const auto& input_array = model->GetArray(op->inputs[0]); if (!input_array.has_shape()) { // We require the dimensionality of the input to pad the indices - return false; + return ::tensorflow::Status::OK(); } auto& start_array = model->GetArray(op->inputs[1]); - if (!start_array.has_shape()) return false; + if (!start_array.has_shape()) return ::tensorflow::Status::OK(); if (toco::RequiredBufferSizeForShape(start_array.shape()) > 4) { // Only 1-4D arrays are supported for now. - return false; + return ::tensorflow::Status::OK(); } auto& stop_array = model->GetArray(op->inputs[2]); - if (!stop_array.has_shape()) return false; + if (!stop_array.has_shape()) return ::tensorflow::Status::OK(); auto& stride_array = model->GetArray(op->inputs[3]); - if (!stride_array.has_shape()) return false; + if (!stride_array.has_shape()) return ::tensorflow::Status::OK(); - if (!IsConstantParameterArray(*model, op->inputs[1])) return false; - if (!IsConstantParameterArray(*model, op->inputs[2])) return false; - if (!IsConstantParameterArray(*model, op->inputs[3])) return false; + if (!IsConstantParameterArray(*model, op->inputs[1])) + return ::tensorflow::Status::OK(); + if (!IsConstantParameterArray(*model, op->inputs[2])) + return ::tensorflow::Status::OK(); + if (!IsConstantParameterArray(*model, op->inputs[3])) + return ::tensorflow::Status::OK(); int num_input_axes = input_array.shape().dimensions_count(); int start_indices_size = start_array.shape().dims(0); @@ -112,6 +119,7 @@ bool ResolveStridedSliceAttributes::Run(Model* model, std::size_t op_index) { op->stop_indices = stop_array.GetBuffer().data; op->strides = stride_array.GetBuffer().data; - return true; + *modified = true; + return ::tensorflow::Status::OK(); } } // namespace toco diff --git a/tensorflow/contrib/lite/toco/graph_transformations/resolve_tensorflow_concat.cc b/tensorflow/contrib/lite/toco/graph_transformations/resolve_tensorflow_concat.cc index fa5ee89933..4927ccd95d 100644 --- a/tensorflow/contrib/lite/toco/graph_transformations/resolve_tensorflow_concat.cc +++ b/tensorflow/contrib/lite/toco/graph_transformations/resolve_tensorflow_concat.cc @@ -25,12 +25,15 @@ limitations under the License. namespace toco { -bool ResolveTensorFlowConcat::Run(Model* model, std::size_t op_index) { +::tensorflow::Status ResolveTensorFlowConcat::Run(Model* model, + std::size_t op_index, + bool* modified) { + *modified = false; auto concat_it = model->operators.begin() + op_index; const auto* tf_concat_op = concat_it->get(); if (tf_concat_op->type != OperatorType::kConcat && tf_concat_op->type != OperatorType::kConcatV2) { - return false; + return ::tensorflow::Status::OK(); } CHECK_GE(tf_concat_op->inputs.size(), 2); @@ -54,7 +57,7 @@ bool ResolveTensorFlowConcat::Run(Model* model, std::size_t op_index) { if (!axis_array.buffer) { AddMessageF("Waiting for the axis of %s to be resolved to a constant", LogName(*tf_concat_op)); - return false; + return ::tensorflow::Status::OK(); } CHECK(axis_array.data_type == ArrayDataType::kInt32); @@ -79,7 +82,8 @@ bool ResolveTensorFlowConcat::Run(Model* model, std::size_t op_index) { } // Remove the TensorFlowConcat op model->operators.erase(concat_it); - return true; + *modified = true; + return ::tensorflow::Status::OK(); } } // namespace toco diff --git a/tensorflow/contrib/lite/toco/graph_transformations/resolve_tensorflow_matmul.cc b/tensorflow/contrib/lite/toco/graph_transformations/resolve_tensorflow_matmul.cc index 65346c4fe4..da039da546 100644 --- a/tensorflow/contrib/lite/toco/graph_transformations/resolve_tensorflow_matmul.cc +++ b/tensorflow/contrib/lite/toco/graph_transformations/resolve_tensorflow_matmul.cc @@ -55,10 +55,13 @@ TransposeOperator* FindTransposeOpWithInput(const Model& model, } // namespace -bool ResolveTensorFlowMatMul::Run(Model* model, std::size_t op_index) { +::tensorflow::Status ResolveTensorFlowMatMul::Run(Model* model, + std::size_t op_index, + bool* modified) { + *modified = false; auto matmul_it = model->operators.begin() + op_index; if (matmul_it->get()->type != OperatorType::kMatMul) { - return false; + return ::tensorflow::Status::OK(); } const auto* matmul_op = static_cast(matmul_it->get()); @@ -73,7 +76,7 @@ bool ResolveTensorFlowMatMul::Run(Model* model, std::size_t op_index) { "Not replacing %s by a FullyConnected operator, because it has " "the transpose_a attribute", LogName(*matmul_op)); - return false; + return ::tensorflow::Status::OK(); } // Reorder the axes on the second input. TensorFlow uses row-major ordering @@ -198,7 +201,8 @@ bool ResolveTensorFlowMatMul::Run(Model* model, std::size_t op_index) { // erase the MatMul operator model->operators.erase(matmul_it); - return true; + *modified = true; + return ::tensorflow::Status::OK(); } } // namespace toco diff --git a/tensorflow/contrib/lite/toco/graph_transformations/resolve_tensorflow_merge.cc b/tensorflow/contrib/lite/toco/graph_transformations/resolve_tensorflow_merge.cc index 4edffe3d48..9beea3e937 100644 --- a/tensorflow/contrib/lite/toco/graph_transformations/resolve_tensorflow_merge.cc +++ b/tensorflow/contrib/lite/toco/graph_transformations/resolve_tensorflow_merge.cc @@ -24,11 +24,14 @@ limitations under the License. namespace toco { -bool ResolveTensorFlowMerge::Run(Model* model, std::size_t op_index) { +::tensorflow::Status ResolveTensorFlowMerge::Run(Model* model, + std::size_t op_index, + bool* modified) { + *modified = false; const auto merge_it = model->operators.begin() + op_index; const auto* merge_op = merge_it->get(); if (merge_op->type != OperatorType::kMerge) { - return false; + return ::tensorflow::Status::OK(); } // We need to yield until this Merge node has only 1 input, which will mean @@ -37,7 +40,7 @@ bool ResolveTensorFlowMerge::Run(Model* model, std::size_t op_index) { // non-selected inputs, so that at some point there will be only 1 input left. if (merge_op->inputs.size() > 1) { AddMessageF("Waiting for %s to be resolved", LogName(*merge_op)); - return false; + return ::tensorflow::Status::OK(); } // Now that the merge node has 1 input exactly, it is the same as an Identity @@ -57,7 +60,8 @@ bool ResolveTensorFlowMerge::Run(Model* model, std::size_t op_index) { AddMessageF("Removing already-resolved %s", LogName(*merge_op)); model->EraseArray(merge_op->outputs[0]); model->operators.erase(merge_it); - return true; + *modified = true; + return ::tensorflow::Status::OK(); } } // namespace toco diff --git a/tensorflow/contrib/lite/toco/graph_transformations/resolve_tensorflow_switch.cc b/tensorflow/contrib/lite/toco/graph_transformations/resolve_tensorflow_switch.cc index 8bef440afd..e215981b42 100644 --- a/tensorflow/contrib/lite/toco/graph_transformations/resolve_tensorflow_switch.cc +++ b/tensorflow/contrib/lite/toco/graph_transformations/resolve_tensorflow_switch.cc @@ -24,11 +24,14 @@ limitations under the License. namespace toco { -bool ResolveTensorFlowSwitch::Run(Model* model, std::size_t op_index) { +::tensorflow::Status ResolveTensorFlowSwitch::Run(Model* model, + std::size_t op_index, + bool* modified) { + *modified = false; const auto switch_it = model->operators.begin() + op_index; const auto* switch_op = switch_it->get(); if (switch_op->type != OperatorType::kSwitch) { - return false; + return ::tensorflow::Status::OK(); } CHECK_EQ(switch_op->inputs.size(), 2); @@ -40,7 +43,7 @@ bool ResolveTensorFlowSwitch::Run(Model* model, std::size_t op_index) { AddMessageF( "Waiting for the boolean predicate of %s to be resolved to a constant", LogName(*switch_op)); - return false; + return ::tensorflow::Status::OK(); } // The predicate should be boolean, and should consist of a single value. @@ -119,7 +122,8 @@ bool ResolveTensorFlowSwitch::Run(Model* model, std::size_t op_index) { // Remove the switch node itself. AddMessageF("Removing already-resolved %s", LogName(*switch_op)); model->operators.erase(switch_it); - return true; + *modified = true; + return ::tensorflow::Status::OK(); } } // namespace toco diff --git a/tensorflow/contrib/lite/toco/graph_transformations/resolve_transpose_attributes.cc b/tensorflow/contrib/lite/toco/graph_transformations/resolve_transpose_attributes.cc index a657ee00af..aa7945391c 100644 --- a/tensorflow/contrib/lite/toco/graph_transformations/resolve_transpose_attributes.cc +++ b/tensorflow/contrib/lite/toco/graph_transformations/resolve_transpose_attributes.cc @@ -24,19 +24,24 @@ limitations under the License. namespace toco { -bool ResolveTransposeAttributes::Run(Model* model, std::size_t op_index) { +::tensorflow::Status ResolveTransposeAttributes::Run(Model* model, + std::size_t op_index, + bool* modified) { + *modified = false; const auto op_it = model->operators.begin() + op_index; - if (op_it->get()->type != OperatorType::kTranspose) return false; + if (op_it->get()->type != OperatorType::kTranspose) + return ::tensorflow::Status::OK(); auto* op = static_cast(op_it->get()); - if (!op->perm.empty()) return false; + if (!op->perm.empty()) return ::tensorflow::Status::OK(); CHECK_EQ(op->inputs.size(), 2); - if (!IsConstantParameterArray(*model, op->inputs[1])) return false; + if (!IsConstantParameterArray(*model, op->inputs[1])) + return ::tensorflow::Status::OK(); // Handling perm. const auto& perm_array = model->GetArray(op->inputs[1]); - if (!perm_array.has_shape()) return false; + if (!perm_array.has_shape()) return ::tensorflow::Status::OK(); const std::vector& perm_dims = perm_array.shape().dims(); CHECK_EQ(perm_dims.size(), 1); @@ -47,7 +52,8 @@ bool ResolveTransposeAttributes::Run(Model* model, std::size_t op_index) { op->perm.push_back(perm_buffer[i]); } - return true; + *modified = true; + return ::tensorflow::Status::OK(); } } // namespace toco diff --git a/tensorflow/contrib/lite/toco/graph_transformations/shuffle_fc_weights.cc b/tensorflow/contrib/lite/toco/graph_transformations/shuffle_fc_weights.cc index 22c258cec5..e9f24a29ab 100644 --- a/tensorflow/contrib/lite/toco/graph_transformations/shuffle_fc_weights.cc +++ b/tensorflow/contrib/lite/toco/graph_transformations/shuffle_fc_weights.cc @@ -24,15 +24,17 @@ limitations under the License. namespace toco { -bool ShuffleFCWeights::Run(Model* model, std::size_t op_index) { +::tensorflow::Status ShuffleFCWeights::Run(Model* model, std::size_t op_index, + bool* modified) { + *modified = false; Operator* op = model->operators[op_index].get(); if (op->type != OperatorType::kFullyConnected) { - return false; + return ::tensorflow::Status::OK(); } FullyConnectedOperator* fc_op = static_cast(op); // Exit if this FC op already has shuffled weights if (fc_op->weights_format != FullyConnectedWeightsFormat::kDefault) { - return false; + return ::tensorflow::Status::OK(); } const Array& input_array = model->GetArray(fc_op->inputs[0]); const string& weights_name = fc_op->inputs[1]; @@ -46,11 +48,11 @@ bool ShuffleFCWeights::Run(Model* model, std::size_t op_index) { output_array.data_type != ArrayDataType::kInt16 || !input_array.quantization_params || !weights_array.quantization_params || !output_array.quantization_params) { - return false; + return ::tensorflow::Status::OK(); } // Exit if the shapes aren't known if (!input_array.has_shape() || !weights_array.has_shape()) { - return false; + return ::tensorflow::Status::OK(); } // Exit if, based on the known shapes, this FC op is not a GEMV. // The shuffling of FC weights is only useful to enable fast GEMV paths. @@ -64,7 +66,7 @@ bool ShuffleFCWeights::Run(Model* model, std::size_t op_index) { "the input shape is not 1D or 2D (possibly with additional inner " "dimensions of size 1)", LogName(*op)); - return false; + return ::tensorflow::Status::OK(); } } if (input_shape.dims(0) != 1 && input_shape.dims(0) != 4) { @@ -73,7 +75,7 @@ bool ShuffleFCWeights::Run(Model* model, std::size_t op_index) { "the input shape's leading dimension, i.e. the 'batch size', is not " "equal to 1 or 4", LogName(*op)); - return false; + return ::tensorflow::Status::OK(); } // Exit if the weights shape isn't an integral multiple of the shuffled // block shape, 4x16. We don't want to have to write code dealing with @@ -88,7 +90,7 @@ bool ShuffleFCWeights::Run(Model* model, std::size_t op_index) { // two. const Shape& weights_shape = weights_array.shape(); if (weights_shape.dimensions_count() != 2) { - return false; + return ::tensorflow::Status::OK(); } const int rows = weights_shape.dims(0); const int cols = weights_shape.dims(1); @@ -97,11 +99,11 @@ bool ShuffleFCWeights::Run(Model* model, std::size_t op_index) { "Not applying experimental shuffling to the weights of %s because its " "shape isn't a multiple of the shuffling block shape, 4x16", LogName(*op)); - return false; + return ::tensorflow::Status::OK(); } // Exit if the weights aren't already a constant array. if (!weights_array.buffer) { - return false; + return ::tensorflow::Status::OK(); } // Exit if the weights are used by more than one op. if (CountOpsWithInput(*model, weights_name) != 1) { @@ -109,7 +111,7 @@ bool ShuffleFCWeights::Run(Model* model, std::size_t op_index) { "Not applying experimental shuffling to the weights of %s because that " "array is consumed by other operators", LogName(*op)); - return false; + return ::tensorflow::Status::OK(); } // Compute the shuffled weights auto& weights_data = @@ -152,7 +154,8 @@ bool ShuffleFCWeights::Run(Model* model, std::size_t op_index) { shuffled_input_workspace_array.GetOrCreateQuantizationParams() = input_array.GetQuantizationParams(); - return true; + *modified = true; + return ::tensorflow::Status::OK(); } } // namespace toco diff --git a/tensorflow/contrib/lite/toco/graph_transformations/tests/resolve_constant_concatenation_test.cc b/tensorflow/contrib/lite/toco/graph_transformations/tests/resolve_constant_concatenation_test.cc index 66cfed4ac2..e2a6f12481 100644 --- a/tensorflow/contrib/lite/toco/graph_transformations/tests/resolve_constant_concatenation_test.cc +++ b/tensorflow/contrib/lite/toco/graph_transformations/tests/resolve_constant_concatenation_test.cc @@ -166,7 +166,10 @@ TEST_F(ResolveConstantConcatenationTest, ConcatAtAxis0) { GraphTransformationsSet graph_transformation_set; graph_transformation_set.Add(new toco::ResolveConstantConcatenation); EXPECT_THAT(model.GetArrayMap().size(), 5); - (*graph_transformation_set.begin())->Run(&model, /*op_index=*/0); + bool modified; + ASSERT_TRUE((*graph_transformation_set.begin()) + ->Run(&model, /*op_index=*/0, &modified) + .ok()); EXPECT_THAT(model.GetArrayMap().size(), 1); auto& concatenated_array = (*model.GetArrayMap().begin()).second; @@ -185,7 +188,10 @@ TEST_F(ResolveConstantConcatenationTest, ConcatAtAxis1) { GraphTransformationsSet graph_transformation_set; graph_transformation_set.Add(new toco::ResolveConstantConcatenation); EXPECT_THAT(model.GetArrayMap().size(), 5); - (*graph_transformation_set.begin())->Run(&model, /*op_index=*/0); + bool modified; + ASSERT_TRUE((*graph_transformation_set.begin()) + ->Run(&model, /*op_index=*/0, &modified) + .ok()); EXPECT_THAT(model.GetArrayMap().size(), 1); auto& concatenated_array = (*model.GetArrayMap().begin()).second; @@ -204,7 +210,10 @@ TEST_F(ResolveConstantConcatenationTest, ConcatAtAxis2) { GraphTransformationsSet graph_transformation_set; graph_transformation_set.Add(new toco::ResolveConstantConcatenation); EXPECT_THAT(model.GetArrayMap().size(), 5); - (*graph_transformation_set.begin())->Run(&model, /*op_index=*/0); + bool modified; + ASSERT_TRUE((*graph_transformation_set.begin()) + ->Run(&model, /*op_index=*/0, &modified) + .ok()); EXPECT_THAT(model.GetArrayMap().size(), 1); auto& concatenated_array = (*model.GetArrayMap().begin()).second; diff --git a/tensorflow/contrib/lite/toco/graph_transformations/tests/resolve_constant_unary_test.cc b/tensorflow/contrib/lite/toco/graph_transformations/tests/resolve_constant_unary_test.cc index a53abc9941..57d85a0435 100644 --- a/tensorflow/contrib/lite/toco/graph_transformations/tests/resolve_constant_unary_test.cc +++ b/tensorflow/contrib/lite/toco/graph_transformations/tests/resolve_constant_unary_test.cc @@ -50,7 +50,8 @@ void RunResolveSum(const std::vector& input, sum_op->inputs = {"input0", "input1"}; sum_op->outputs = {"output"}; model.operators.push_back(std::move(sum_op)); - ResolveConstantUnaryOperator().Run(&model, 0); + bool modified; + ASSERT_TRUE(ResolveConstantUnaryOperator().Run(&model, 0, &modified).ok()); EXPECT_EQ(model.GetArray("output").GetBuffer().data, expected_output); EXPECT_EQ(model.GetArray("output").shape().dims(), output_shape); diff --git a/tensorflow/contrib/lite/toco/graph_transformations/unfuse_activation_functions.cc b/tensorflow/contrib/lite/toco/graph_transformations/unfuse_activation_functions.cc index 69bad2fa89..4ada5c3fd0 100644 --- a/tensorflow/contrib/lite/toco/graph_transformations/unfuse_activation_functions.cc +++ b/tensorflow/contrib/lite/toco/graph_transformations/unfuse_activation_functions.cc @@ -25,13 +25,16 @@ limitations under the License. namespace toco { -bool UnfuseActivationFunctions::Run(Model* model, std::size_t op_index) { +::tensorflow::Status UnfuseActivationFunctions::Run(Model* model, + std::size_t op_index, + bool* modified) { + *modified = false; const auto it = model->operators.begin() + op_index; auto* op = it->get(); // If a conv operation has an im2col array, yield: it should be dropped first. if ((op->type == OperatorType::kConv) && (op->outputs.size() == 2)) { - return false; + return ::tensorflow::Status::OK(); } Operator* ac_op = nullptr; @@ -46,7 +49,7 @@ bool UnfuseActivationFunctions::Run(Model* model, std::size_t op_index) { ac_op = new Relu1Operator; break; default: - return false; + return ::tensorflow::Status::OK(); } // At this point we know that the op has a fused activation function. At the @@ -74,7 +77,8 @@ bool UnfuseActivationFunctions::Run(Model* model, std::size_t op_index) { ac_op->inputs = {tmp_array_name}; op->outputs = {tmp_array_name}; - return true; + *modified = true; + return ::tensorflow::Status::OK(); } } // namespace toco diff --git a/tensorflow/contrib/lite/toco/graph_transformations/unpartition_embedding_lookup.cc b/tensorflow/contrib/lite/toco/graph_transformations/unpartition_embedding_lookup.cc index dd9e26e68b..e19527968d 100644 --- a/tensorflow/contrib/lite/toco/graph_transformations/unpartition_embedding_lookup.cc +++ b/tensorflow/contrib/lite/toco/graph_transformations/unpartition_embedding_lookup.cc @@ -22,7 +22,10 @@ limitations under the License. namespace toco { -bool UnpartitionEmbeddingLookup::Run(Model* model, std::size_t op_index) { +::tensorflow::Status UnpartitionEmbeddingLookup::Run(Model* model, + std::size_t op_index, + bool* modified) { + *modified = false; // Collapses a partitioned tf.nn.embedding_lookup back into a single Gather. // https://www.tensorflow.org/api_docs/python/tf/nn/embedding_lookup // This transform attempts to identify the len(params) > 1 case and collapse @@ -47,7 +50,7 @@ bool UnpartitionEmbeddingLookup::Run(Model* model, std::size_t op_index) { // First look for the final DynamicStitch. auto op_it = model->operators.begin() + op_index; if (op_it->get()->type != OperatorType::kDynamicStitch) { - return false; + return ::tensorflow::Status::OK(); } auto* stitch_op = static_cast(op_it->get()); @@ -72,7 +75,7 @@ bool UnpartitionEmbeddingLookup::Run(Model* model, std::size_t op_index) { "Skipping because indices input %s into " "%s is unexpected", LogName(*op), LogName(*stitch_op)); - return false; + return ::tensorflow::Status::OK(); } if (!indices_partition_op) { indices_partition_op = static_cast(op); @@ -83,7 +86,7 @@ bool UnpartitionEmbeddingLookup::Run(Model* model, std::size_t op_index) { "Skipping because indices input %s into " "%s is from a different source op than others", LogName(*op), LogName(*stitch_op)); - return false; + return ::tensorflow::Status::OK(); } } } @@ -92,12 +95,12 @@ bool UnpartitionEmbeddingLookup::Run(Model* model, std::size_t op_index) { // The data for the indices must be a constant range of the array shape. if (!IsConstantParameterArray(*model, indices_partition_op->inputs[0])) { AddMessageF("Skipping because indices partition data is non-constant"); - return false; + return ::tensorflow::Status::OK(); } auto& indices_data_array = model->GetArray(indices_partition_op->inputs[0]); if (indices_data_array.data_type == ArrayDataType::kNone) { // Yield until data types are propagated. - return false; + return ::tensorflow::Status::OK(); } CHECK(indices_data_array.data_type == ArrayDataType::kInt32) << "Indices partition inputs must be int32"; @@ -117,7 +120,7 @@ bool UnpartitionEmbeddingLookup::Run(Model* model, std::size_t op_index) { "Skipping because data input %s into %s " "is unexpected", LogName(*op), LogName(*stitch_op)); - return false; + return ::tensorflow::Status::OK(); } gather_ops.push_back(static_cast(op)); } @@ -132,7 +135,7 @@ bool UnpartitionEmbeddingLookup::Run(Model* model, std::size_t op_index) { "Skipping because data input %s into " "%s is unexpected", LogName(*op), LogName(*gather_op)); - return false; + return ::tensorflow::Status::OK(); } if (!data_partition_op) { data_partition_op = static_cast(op); @@ -143,7 +146,7 @@ bool UnpartitionEmbeddingLookup::Run(Model* model, std::size_t op_index) { "Skipping because data input %s into " "%s is from a different source op than others", LogName(*op), LogName(*gather_op)); - return false; + return ::tensorflow::Status::OK(); } } } @@ -236,7 +239,8 @@ bool UnpartitionEmbeddingLookup::Run(Model* model, std::size_t op_index) { DeleteOpAndArraysIfUnused(model, indices_partition_op); DeleteOpAndArraysIfUnused(model, data_partition_op); DeleteOpAndArraysIfUnused(model, stitch_op); - return true; + *modified = true; + return ::tensorflow::Status::OK(); } } // namespace toco diff --git a/tensorflow/contrib/lite/toco/graph_transformations/unroll_batch_matmul.cc b/tensorflow/contrib/lite/toco/graph_transformations/unroll_batch_matmul.cc index fedf4441e2..5ff39aa313 100644 --- a/tensorflow/contrib/lite/toco/graph_transformations/unroll_batch_matmul.cc +++ b/tensorflow/contrib/lite/toco/graph_transformations/unroll_batch_matmul.cc @@ -36,10 +36,12 @@ namespace toco { // slice_c = tf.matmul(slice_a, slice_b) // result_slices[bat] = slice_c // result = tf.stack(result_slices) -bool UnrollBatchMatMul::Run(Model* model, std::size_t op_index) { +::tensorflow::Status UnrollBatchMatMul::Run(Model* model, std::size_t op_index, + bool* modified) { + *modified = false; auto batch_op_it = model->operators.begin() + op_index; if (batch_op_it->get()->type != OperatorType::kBatchMatMul) { - return false; + return ::tensorflow::Status::OK(); } const auto* batch_op = static_cast(batch_op_it->get()); @@ -47,7 +49,8 @@ bool UnrollBatchMatMul::Run(Model* model, std::size_t op_index) { // We must have the shape of at least one input to know our batch size. const auto& input_array_a = model->GetArray(batch_op->inputs[0]); const auto& input_array_b = model->GetArray(batch_op->inputs[1]); - if (!input_array_a.has_shape() || !input_array_b.has_shape()) return false; + if (!input_array_a.has_shape() || !input_array_b.has_shape()) + return ::tensorflow::Status::OK(); // We only support the rank 3 case. If you are batching on rank > 3 you'll // have to figure that out. @@ -66,7 +69,8 @@ bool UnrollBatchMatMul::Run(Model* model, std::size_t op_index) { batch_op_it = matmul_op_it + 1; CHECK_EQ(batch_op_it->get(), batch_op); model->operators.erase(batch_op_it); - return true; + *modified = true; + return ::tensorflow::Status::OK(); } CHECK_EQ(input_array_a.shape().dimensions_count(), 3) << "Input arrays must have rank 3"; @@ -167,7 +171,8 @@ bool UnrollBatchMatMul::Run(Model* model, std::size_t op_index) { CHECK(batch_op_it != model->operators.end()); CHECK(batch_op_it->get() == batch_op); model->operators.erase(batch_op_it); - return true; + *modified = true; + return ::tensorflow::Status::OK(); } } // namespace toco -- cgit v1.2.3