aboutsummaryrefslogtreecommitdiffhomepage
diff options
context:
space:
mode:
authorGravatar Yu-Cheng Ling <ycling@google.com>2018-10-09 11:38:15 -0700
committerGravatar TensorFlower Gardener <gardener@tensorflow.org>2018-10-09 11:48:46 -0700
commit12e164d1e7c0b197f06d5d3c2ed26318b89b5e4c (patch)
treed2f0b6ba463baff8e3607575f41d3655762f3d14
parent931353c5f79c2d419afb3a5ecac59184c5558351 (diff)
Return ::tensorflow::Status in Toco Graph Transformations.
PiperOrigin-RevId: 216392908
-rw-r--r--tensorflow/contrib/lite/toco/graph_transformations/convert_expanddims_to_reshape.cc16
-rw-r--r--tensorflow/contrib/lite/toco/graph_transformations/convert_pure_conv_to_depthwise.cc24
-rw-r--r--tensorflow/contrib/lite/toco/graph_transformations/convert_reorder_axes.cc15
-rw-r--r--tensorflow/contrib/lite/toco/graph_transformations/convert_squeeze_to_reshape.cc18
-rw-r--r--tensorflow/contrib/lite/toco/graph_transformations/convert_trivial_addn_to_add.cc12
-rw-r--r--tensorflow/contrib/lite/toco/graph_transformations/convert_trivial_pack_to_reshape.cc16
-rw-r--r--tensorflow/contrib/lite/toco/graph_transformations/convert_trivial_tile_to_concat.cc16
-rw-r--r--tensorflow/contrib/lite/toco/graph_transformations/convert_trivial_transpose_to_reshape.cc16
-rw-r--r--tensorflow/contrib/lite/toco/graph_transformations/create_im2col_arrays.cc12
-rw-r--r--tensorflow/contrib/lite/toco/graph_transformations/dequantize.cc14
-rw-r--r--tensorflow/contrib/lite/toco/graph_transformations/drop_fake_quant.cc13
-rw-r--r--tensorflow/contrib/lite/toco/graph_transformations/drop_im2col_arrays.cc11
-rw-r--r--tensorflow/contrib/lite/toco/graph_transformations/ensure_bias_vectors.cc9
-rw-r--r--tensorflow/contrib/lite/toco/graph_transformations/ensure_uint8_weights_safe_for_fast_int8_kernels.cc14
-rw-r--r--tensorflow/contrib/lite/toco/graph_transformations/fuse_activation_functions.cc22
-rw-r--r--tensorflow/contrib/lite/toco/graph_transformations/fuse_binary_into_following_affine.cc32
-rw-r--r--tensorflow/contrib/lite/toco/graph_transformations/fuse_binary_into_preceding_affine.cc36
-rw-r--r--tensorflow/contrib/lite/toco/graph_transformations/fuse_broadcast_into_following_binary.cc16
-rw-r--r--tensorflow/contrib/lite/toco/graph_transformations/graph_transformations.cc2
-rw-r--r--tensorflow/contrib/lite/toco/graph_transformations/graph_transformations.h29
-rw-r--r--tensorflow/contrib/lite/toco/graph_transformations/hardcode_min_max.cc7
-rw-r--r--tensorflow/contrib/lite/toco/graph_transformations/identify_dilated_conv.cc16
-rw-r--r--tensorflow/contrib/lite/toco/graph_transformations/identify_l2_normalization.cc22
-rw-r--r--tensorflow/contrib/lite/toco/graph_transformations/identify_l2_pool.cc15
-rw-r--r--tensorflow/contrib/lite/toco/graph_transformations/identify_lstm.cc33
-rw-r--r--tensorflow/contrib/lite/toco/graph_transformations/identify_lstm_merge_inputs.cc16
-rw-r--r--tensorflow/contrib/lite/toco/graph_transformations/identify_lstm_split_inputs.cc16
-rw-r--r--tensorflow/contrib/lite/toco/graph_transformations/identify_prelu.cc19
-rw-r--r--tensorflow/contrib/lite/toco/graph_transformations/identify_relu1.cc17
-rw-r--r--tensorflow/contrib/lite/toco/graph_transformations/make_initial_dequantize_operator.cc8
-rw-r--r--tensorflow/contrib/lite/toco/graph_transformations/merge_reshape_into_preceding_transpose.cc26
-rw-r--r--tensorflow/contrib/lite/toco/graph_transformations/move_binary_operator_before_reshape.cc30
-rw-r--r--tensorflow/contrib/lite/toco/graph_transformations/propagate_activation_function_into_constants.cc20
-rw-r--r--tensorflow/contrib/lite/toco/graph_transformations/propagate_array_data_types.cc18
-rw-r--r--tensorflow/contrib/lite/toco/graph_transformations/propagate_default_min_max.cc8
-rw-r--r--tensorflow/contrib/lite/toco/graph_transformations/propagate_fake_quant_num_bits.cc12
-rw-r--r--tensorflow/contrib/lite/toco/graph_transformations/propagate_fixed_sizes.cc12
-rw-r--r--tensorflow/contrib/lite/toco/graph_transformations/quantize.cc13
-rw-r--r--tensorflow/contrib/lite/toco/graph_transformations/read_array_minmax_and_narrow_range_from_fake_quant.cc12
-rw-r--r--tensorflow/contrib/lite/toco/graph_transformations/remove_final_dequantize_op.cc12
-rw-r--r--tensorflow/contrib/lite/toco/graph_transformations/remove_tensorflow_assert.cc10
-rw-r--r--tensorflow/contrib/lite/toco/graph_transformations/remove_tensorflow_identity.cc10
-rw-r--r--tensorflow/contrib/lite/toco/graph_transformations/remove_trivial_binary.cc22
-rw-r--r--tensorflow/contrib/lite/toco/graph_transformations/remove_trivial_concatenation.cc12
-rw-r--r--tensorflow/contrib/lite/toco/graph_transformations/remove_trivial_concatenation_input.cc12
-rw-r--r--tensorflow/contrib/lite/toco/graph_transformations/remove_trivial_fake_quant.cc12
-rw-r--r--tensorflow/contrib/lite/toco/graph_transformations/remove_trivial_quantized_activation_func.cc15
-rw-r--r--tensorflow/contrib/lite/toco/graph_transformations/remove_trivial_quantized_min_max.cc12
-rw-r--r--tensorflow/contrib/lite/toco/graph_transformations/remove_trivial_reshape.cc12
-rw-r--r--tensorflow/contrib/lite/toco/graph_transformations/remove_trivial_slice.cc11
-rw-r--r--tensorflow/contrib/lite/toco/graph_transformations/remove_unused_op.cc15
-rw-r--r--tensorflow/contrib/lite/toco/graph_transformations/reorder_elementwise_unary.cc18
-rw-r--r--tensorflow/contrib/lite/toco/graph_transformations/reorder_reshape_transpose.cc24
-rw-r--r--tensorflow/contrib/lite/toco/graph_transformations/resolve_batch_normalization.cc12
-rw-r--r--tensorflow/contrib/lite/toco/graph_transformations/resolve_batch_to_space_nd_attributes.cc21
-rw-r--r--tensorflow/contrib/lite/toco/graph_transformations/resolve_constant_binary.cc16
-rw-r--r--tensorflow/contrib/lite/toco/graph_transformations/resolve_constant_concatenation.cc24
-rw-r--r--tensorflow/contrib/lite/toco/graph_transformations/resolve_constant_fake_quant.cc16
-rw-r--r--tensorflow/contrib/lite/toco/graph_transformations/resolve_constant_fill.cc26
-rw-r--r--tensorflow/contrib/lite/toco/graph_transformations/resolve_constant_gather.cc20
-rw-r--r--tensorflow/contrib/lite/toco/graph_transformations/resolve_constant_pack.cc16
-rw-r--r--tensorflow/contrib/lite/toco/graph_transformations/resolve_constant_random_uniform.cc18
-rw-r--r--tensorflow/contrib/lite/toco/graph_transformations/resolve_constant_range.cc20
-rw-r--r--tensorflow/contrib/lite/toco/graph_transformations/resolve_constant_reshape.cc20
-rw-r--r--tensorflow/contrib/lite/toco/graph_transformations/resolve_constant_select.cc21
-rw-r--r--tensorflow/contrib/lite/toco/graph_transformations/resolve_constant_shape_or_rank.cc16
-rw-r--r--tensorflow/contrib/lite/toco/graph_transformations/resolve_constant_slice.cc28
-rw-r--r--tensorflow/contrib/lite/toco/graph_transformations/resolve_constant_strided_slice.cc20
-rw-r--r--tensorflow/contrib/lite/toco/graph_transformations/resolve_constant_tile.cc16
-rw-r--r--tensorflow/contrib/lite/toco/graph_transformations/resolve_constant_transpose.cc18
-rw-r--r--tensorflow/contrib/lite/toco/graph_transformations/resolve_constant_unary.cc28
-rw-r--r--tensorflow/contrib/lite/toco/graph_transformations/resolve_fake_quant_args_from_vars.cc14
-rw-r--r--tensorflow/contrib/lite/toco/graph_transformations/resolve_gather_attributes.cc20
-rw-r--r--tensorflow/contrib/lite/toco/graph_transformations/resolve_multiply_by_zero.cc30
-rw-r--r--tensorflow/contrib/lite/toco/graph_transformations/resolve_pad_attributes.cc17
-rw-r--r--tensorflow/contrib/lite/toco/graph_transformations/resolve_padv2_attributes.cc17
-rw-r--r--tensorflow/contrib/lite/toco/graph_transformations/resolve_reduce_attributes.cc30
-rw-r--r--tensorflow/contrib/lite/toco/graph_transformations/resolve_reorder_axes.cc13
-rw-r--r--tensorflow/contrib/lite/toco/graph_transformations/resolve_reshape_attributes.cc14
-rw-r--r--tensorflow/contrib/lite/toco/graph_transformations/resolve_slice_attributes.cc22
-rw-r--r--tensorflow/contrib/lite/toco/graph_transformations/resolve_space_to_batch_nd_attributes.cc21
-rw-r--r--tensorflow/contrib/lite/toco/graph_transformations/resolve_squeeze_attributes.cc12
-rw-r--r--tensorflow/contrib/lite/toco/graph_transformations/resolve_strided_slice_attributes.cc32
-rw-r--r--tensorflow/contrib/lite/toco/graph_transformations/resolve_tensorflow_concat.cc12
-rw-r--r--tensorflow/contrib/lite/toco/graph_transformations/resolve_tensorflow_matmul.cc12
-rw-r--r--tensorflow/contrib/lite/toco/graph_transformations/resolve_tensorflow_merge.cc12
-rw-r--r--tensorflow/contrib/lite/toco/graph_transformations/resolve_tensorflow_switch.cc12
-rw-r--r--tensorflow/contrib/lite/toco/graph_transformations/resolve_transpose_attributes.cc18
-rw-r--r--tensorflow/contrib/lite/toco/graph_transformations/shuffle_fc_weights.cc27
-rw-r--r--tensorflow/contrib/lite/toco/graph_transformations/tests/resolve_constant_concatenation_test.cc15
-rw-r--r--tensorflow/contrib/lite/toco/graph_transformations/tests/resolve_constant_unary_test.cc3
-rw-r--r--tensorflow/contrib/lite/toco/graph_transformations/unfuse_activation_functions.cc12
-rw-r--r--tensorflow/contrib/lite/toco/graph_transformations/unpartition_embedding_lookup.cc24
-rw-r--r--tensorflow/contrib/lite/toco/graph_transformations/unroll_batch_matmul.cc15
94 files changed, 1003 insertions, 617 deletions
diff --git a/tensorflow/contrib/lite/toco/graph_transformations/convert_expanddims_to_reshape.cc b/tensorflow/contrib/lite/toco/graph_transformations/convert_expanddims_to_reshape.cc
index 310a88484c..8a945ac435 100644
--- a/tensorflow/contrib/lite/toco/graph_transformations/convert_expanddims_to_reshape.cc
+++ b/tensorflow/contrib/lite/toco/graph_transformations/convert_expanddims_to_reshape.cc
@@ -25,10 +25,13 @@ limitations under the License.
namespace toco {
-bool ConvertExpandDimsToReshape::Run(Model* model, std::size_t op_index) {
+::tensorflow::Status ConvertExpandDimsToReshape::Run(Model* model,
+ std::size_t op_index,
+ bool* modified) {
+ *modified = false;
auto expand_it = model->operators.begin() + op_index;
if (expand_it->get()->type != OperatorType::kExpandDims) {
- return false;
+ return ::tensorflow::Status::OK();
}
ExpandDimsOperator* expand_op =
static_cast<ExpandDimsOperator*>(expand_it->get());
@@ -38,18 +41,18 @@ bool ConvertExpandDimsToReshape::Run(Model* model, std::size_t op_index) {
const auto& input_array = model->GetArray(expand_op->inputs[0]);
if (!input_array.has_shape()) {
// Yield until input dims have been resolved.
- return false;
+ return ::tensorflow::Status::OK();
}
const auto& axis_array = model->GetArray(expand_op->inputs[1]);
if (!axis_array.has_shape()) {
// Yield until input axis array shape has been resolved.
- return false;
+ return ::tensorflow::Status::OK();
}
CHECK_EQ(RequiredBufferSizeForShape(axis_array.shape()), 1);
if (!axis_array.buffer) {
// Yield until the input axis array is constant
- return false;
+ return ::tensorflow::Status::OK();
}
int axis = axis_array.GetBuffer<ArrayDataType::kInt32>().data[0];
std::vector<int> reshape_dims(input_array.shape().dims());
@@ -90,7 +93,8 @@ bool ConvertExpandDimsToReshape::Run(Model* model, std::size_t op_index) {
CHECK_EQ(expand_it->get(), expand_op);
model->operators.erase(expand_it);
- return true;
+ *modified = true;
+ return ::tensorflow::Status::OK();
}
} // namespace toco
diff --git a/tensorflow/contrib/lite/toco/graph_transformations/convert_pure_conv_to_depthwise.cc b/tensorflow/contrib/lite/toco/graph_transformations/convert_pure_conv_to_depthwise.cc
index e88839be5d..a151012891 100644
--- a/tensorflow/contrib/lite/toco/graph_transformations/convert_pure_conv_to_depthwise.cc
+++ b/tensorflow/contrib/lite/toco/graph_transformations/convert_pure_conv_to_depthwise.cc
@@ -24,29 +24,32 @@ limitations under the License.
namespace toco {
-bool ConvertPureConvToDepthwise::Run(Model* model, std::size_t op_index) {
+::tensorflow::Status ConvertPureConvToDepthwise::Run(Model* model,
+ std::size_t op_index,
+ bool* modified) {
+ *modified = false;
auto conv_it = model->operators.begin() + op_index;
if (conv_it->get()->type != OperatorType::kConv) {
- return false;
+ return ::tensorflow::Status::OK();
}
const auto* conv_op = static_cast<ConvOperator*>(conv_it->get());
if (conv_op->stride_width != conv_op->stride_height) {
- return false;
+ return ::tensorflow::Status::OK();
}
if ((conv_op->dilation_width_factor != 1) ||
(conv_op->dilation_height_factor != 1)) {
// Depthwise conv does not support dilation
- return false;
+ return ::tensorflow::Status::OK();
}
auto& input_array = model->GetArray(conv_op->inputs[0]);
if (!input_array.has_shape()) {
// Shapes not propagated yet
- return false;
+ return ::tensorflow::Status::OK();
}
if (input_array.shape().dims(3) != 1) {
// Not a pure convolution: Conv does accumulation across the depth
// dimension.
- return false;
+ return ::tensorflow::Status::OK();
}
const auto& weights_name = conv_op->inputs[1];
@@ -56,15 +59,15 @@ bool ConvertPureConvToDepthwise::Run(Model* model, std::size_t op_index) {
"Not changing %s to DepthwiseConv because the weights is consumed by "
"another op.",
LogName(*conv_op));
- return false;
+ return ::tensorflow::Status::OK();
}
auto& weights_array = model->GetArray(weights_name);
if (!weights_array.buffer) {
// Yield until the weights are resolved as a constant array.
- return false;
+ return ::tensorflow::Status::OK();
}
if (weights_array.data_type != ArrayDataType::kFloat) {
- return false;
+ return ::tensorflow::Status::OK();
}
// At this point we know we have a pure conv. Rewrite it as DepthwiseConv.
AddMessageF(
@@ -112,7 +115,8 @@ bool ConvertPureConvToDepthwise::Run(Model* model, std::size_t op_index) {
}
*weights_array.mutable_shape()->mutable_dims() = {1, width, height, depth};
weights_buffer.data = depthwise_conv_weights_data;
- return true;
+ *modified = true;
+ return ::tensorflow::Status::OK();
}
} // namespace toco
diff --git a/tensorflow/contrib/lite/toco/graph_transformations/convert_reorder_axes.cc b/tensorflow/contrib/lite/toco/graph_transformations/convert_reorder_axes.cc
index 0d274fc687..4a264e1cf1 100644
--- a/tensorflow/contrib/lite/toco/graph_transformations/convert_reorder_axes.cc
+++ b/tensorflow/contrib/lite/toco/graph_transformations/convert_reorder_axes.cc
@@ -86,9 +86,12 @@ TransposeOperator* CreateTransposeFromReorderAxes(
// Converts ReorderAxes into Transpose and Reshape which are compatible with the
// TFLite interpreter.
-bool ConvertReorderAxes::Run(Model* model, std::size_t op_index) {
+::tensorflow::Status ConvertReorderAxes::Run(Model* model, std::size_t op_index,
+ bool* modified) {
+ *modified = false;
auto reorder_it = model->operators.begin() + op_index;
- if (reorder_it->get()->type != OperatorType::kReorderAxes) return false;
+ if (reorder_it->get()->type != OperatorType::kReorderAxes)
+ return ::tensorflow::Status::OK();
auto* reorder_op = static_cast<ReorderAxesOperator*>(reorder_it->get());
CHECK_EQ(reorder_op->inputs.size(), 1);
@@ -113,8 +116,9 @@ bool ConvertReorderAxes::Run(Model* model, std::size_t op_index) {
// Yield if input array contains constants or if output array size has not
// been adjusted to reflect the permutations in ReorderAxes. ReorderAxes will
// be merged into a constant array when possible.
- if (IsConstantParameterArray(*model, constant_input_array_name)) return false;
- if (!output_array.has_shape()) return false;
+ if (IsConstantParameterArray(*model, constant_input_array_name))
+ return ::tensorflow::Status::OK();
+ if (!output_array.has_shape()) return ::tensorflow::Status::OK();
const auto input_axes_order = reorder_op->input_axes_order;
const auto output_axes_order = reorder_op->output_axes_order;
@@ -143,7 +147,8 @@ bool ConvertReorderAxes::Run(Model* model, std::size_t op_index) {
CHECK_EQ(reorder_it->get(), reorder_op);
model->operators.erase(reorder_it);
- return true;
+ *modified = true;
+ return ::tensorflow::Status::OK();
}
} // namespace toco
diff --git a/tensorflow/contrib/lite/toco/graph_transformations/convert_squeeze_to_reshape.cc b/tensorflow/contrib/lite/toco/graph_transformations/convert_squeeze_to_reshape.cc
index 81cedb5dad..a0bd1ed4a4 100644
--- a/tensorflow/contrib/lite/toco/graph_transformations/convert_squeeze_to_reshape.cc
+++ b/tensorflow/contrib/lite/toco/graph_transformations/convert_squeeze_to_reshape.cc
@@ -30,10 +30,13 @@ namespace toco {
// means that the data layout will never change with this op, just the shape.
// By converting these to reshapes once we have run shape propagation we allow
// standard reshape optimization transforms to do their magic.
-bool ConvertSqueezeToReshape::Run(Model* model, std::size_t op_index) {
+::tensorflow::Status ConvertSqueezeToReshape::Run(Model* model,
+ std::size_t op_index,
+ bool* modified) {
+ *modified = false;
auto squeeze_it = model->operators.begin() + op_index;
if (squeeze_it->get()->type != OperatorType::kSqueeze) {
- return false;
+ return ::tensorflow::Status::OK();
}
auto squeeze_op = static_cast<SqueezeOperator*>(squeeze_it->get());
CHECK_EQ(squeeze_op->inputs.size(), 1);
@@ -42,16 +45,16 @@ bool ConvertSqueezeToReshape::Run(Model* model, std::size_t op_index) {
const auto& input_array = model->GetArray(squeeze_op->inputs[0]);
if (!input_array.has_shape()) {
// Yield until input dims have been resolved.
- return false;
+ return ::tensorflow::Status::OK();
}
if (input_array.shape().dimensions_count() == 0) {
// Input array cannot be 0-D.
- return false;
+ return ::tensorflow::Status::OK();
}
if (!model->HasArray(squeeze_op->outputs[0]) ||
!model->GetArray(squeeze_op->outputs[0]).has_shape()) {
// Yield until shape propagation has set the output shape for us.
- return false;
+ return ::tensorflow::Status::OK();
}
// We use the output shape that has been calculated by shape propagation.
@@ -59,7 +62,7 @@ bool ConvertSqueezeToReshape::Run(Model* model, std::size_t op_index) {
// Empty shapes will not work as empty data arrays.
if (output_shape.dimensions_count() == 0) {
- return false;
+ return ::tensorflow::Status::OK();
}
auto* reshape_op = new TensorFlowReshapeOperator;
@@ -79,7 +82,8 @@ bool ConvertSqueezeToReshape::Run(Model* model, std::size_t op_index) {
CHECK_EQ(squeeze_it->get(), squeeze_op);
model->operators.erase(squeeze_it);
- return true;
+ *modified = true;
+ return ::tensorflow::Status::OK();
}
} // namespace toco
diff --git a/tensorflow/contrib/lite/toco/graph_transformations/convert_trivial_addn_to_add.cc b/tensorflow/contrib/lite/toco/graph_transformations/convert_trivial_addn_to_add.cc
index dcaaddbf3b..d7cacf77f4 100644
--- a/tensorflow/contrib/lite/toco/graph_transformations/convert_trivial_addn_to_add.cc
+++ b/tensorflow/contrib/lite/toco/graph_transformations/convert_trivial_addn_to_add.cc
@@ -20,10 +20,13 @@ namespace toco {
// This pass will convert an AddN operator with only 2 inputs into a regular Add
// operator, to which more optimizations may apply.
-bool ConvertTrivialAddNToAdd::Run(Model* model, std::size_t op_index) {
+::tensorflow::Status ConvertTrivialAddNToAdd::Run(Model* model,
+ std::size_t op_index,
+ bool* modified) {
+ *modified = false;
auto addn_it = model->operators.begin() + op_index;
if (addn_it->get()->type != OperatorType::kAddN) {
- return false;
+ return ::tensorflow::Status::OK();
}
AddNOperator* addn_op = static_cast<AddNOperator*>(addn_it->get());
CHECK_GE(addn_op->inputs.size(), 2);
@@ -31,7 +34,7 @@ bool ConvertTrivialAddNToAdd::Run(Model* model, std::size_t op_index) {
// We only reduce AddN with N=2 to a regular Add.
if (addn_op->inputs.size() != 2) {
- return false;
+ return ::tensorflow::Status::OK();
}
// Copy inputs & outputs to regular Add.
@@ -45,7 +48,8 @@ bool ConvertTrivialAddNToAdd::Run(Model* model, std::size_t op_index) {
addn_it = add_it + 1;
CHECK_EQ(addn_it->get(), addn_op);
model->operators.erase(addn_it);
- return true;
+ *modified = true;
+ return ::tensorflow::Status::OK();
}
} // namespace toco
diff --git a/tensorflow/contrib/lite/toco/graph_transformations/convert_trivial_pack_to_reshape.cc b/tensorflow/contrib/lite/toco/graph_transformations/convert_trivial_pack_to_reshape.cc
index 75113a2a8c..78779243a9 100644
--- a/tensorflow/contrib/lite/toco/graph_transformations/convert_trivial_pack_to_reshape.cc
+++ b/tensorflow/contrib/lite/toco/graph_transformations/convert_trivial_pack_to_reshape.cc
@@ -25,27 +25,30 @@ limitations under the License.
namespace toco {
-bool ConvertTrivialPackToReshape::Run(Model* model, std::size_t op_index) {
+::tensorflow::Status ConvertTrivialPackToReshape::Run(Model* model,
+ std::size_t op_index,
+ bool* modified) {
+ *modified = false;
auto pack_it = model->operators.begin() + op_index;
if (pack_it->get()->type != OperatorType::kPack) {
- return false;
+ return ::tensorflow::Status::OK();
}
auto* pack_op = static_cast<PackOperator*>(pack_it->get());
if (pack_op->inputs.size() > 1) {
// Not trivial.
- return false;
+ return ::tensorflow::Status::OK();
}
CHECK_EQ(pack_op->outputs.size(), 1);
const auto& input_array = model->GetArray(pack_op->inputs[0]);
if (!input_array.has_shape()) {
// Yield until input dims have been resolved.
- return false;
+ return ::tensorflow::Status::OK();
}
if (input_array.shape().dimensions_count() == 0) {
// Input array cannot be 0-D.
// (Unsure if this is TF behavior, but was required to get a test to pass.)
- return false;
+ return ::tensorflow::Status::OK();
}
AddMessageF("Converting trivial %s to a reshape", LogName(*pack_op));
@@ -75,7 +78,8 @@ bool ConvertTrivialPackToReshape::Run(Model* model, std::size_t op_index) {
CHECK_EQ(pack_it->get(), pack_op);
model->operators.erase(pack_it);
- return true;
+ *modified = true;
+ return ::tensorflow::Status::OK();
}
} // namespace toco
diff --git a/tensorflow/contrib/lite/toco/graph_transformations/convert_trivial_tile_to_concat.cc b/tensorflow/contrib/lite/toco/graph_transformations/convert_trivial_tile_to_concat.cc
index b689be0792..b6d712ca44 100644
--- a/tensorflow/contrib/lite/toco/graph_transformations/convert_trivial_tile_to_concat.cc
+++ b/tensorflow/contrib/lite/toco/graph_transformations/convert_trivial_tile_to_concat.cc
@@ -21,10 +21,13 @@ limitations under the License.
namespace toco {
-bool ConvertTrivialTileToConcat::Run(Model* model, std::size_t op_index) {
+::tensorflow::Status ConvertTrivialTileToConcat::Run(Model* model,
+ std::size_t op_index,
+ bool* modified) {
+ *modified = false;
auto tile_it = model->operators.begin() + op_index;
if (tile_it->get()->type != OperatorType::kTile) {
- return false;
+ return ::tensorflow::Status::OK();
}
auto* tile_op = static_cast<TransposeOperator*>(tile_it->get());
@@ -34,13 +37,13 @@ bool ConvertTrivialTileToConcat::Run(Model* model, std::size_t op_index) {
if (!input_array.has_shape() || !multiples_array.has_shape() ||
!output_array.has_shape()) {
// Yield until PropagateFixedSizes has been run on this op.
- return false;
+ return ::tensorflow::Status::OK();
}
// Note: We can assume we have error checked inputs in PropagateFixedSizes.
if (!multiples_array.buffer) {
// Yield until the multiples is constant.
- return false;
+ return ::tensorflow::Status::OK();
}
std::vector<int32> const& multiples =
multiples_array.GetBuffer<ArrayDataType::kInt32>().data;
@@ -59,7 +62,7 @@ bool ConvertTrivialTileToConcat::Run(Model* model, std::size_t op_index) {
// The tile is non-trivial. Good luck.
AddMessageF("Tile %s is non-trivial (has more than one multiply dimension)",
LogName(*tile_op));
- return false;
+ return ::tensorflow::Status::OK();
}
// The tile is like a concat.
@@ -88,7 +91,8 @@ bool ConvertTrivialTileToConcat::Run(Model* model, std::size_t op_index) {
CHECK_EQ(tile_it->get(), tile_op);
model->operators.erase(tile_it);
- return true;
+ *modified = true;
+ return ::tensorflow::Status::OK();
}
} // namespace toco
diff --git a/tensorflow/contrib/lite/toco/graph_transformations/convert_trivial_transpose_to_reshape.cc b/tensorflow/contrib/lite/toco/graph_transformations/convert_trivial_transpose_to_reshape.cc
index 5a36a90b38..e5a96d4335 100644
--- a/tensorflow/contrib/lite/toco/graph_transformations/convert_trivial_transpose_to_reshape.cc
+++ b/tensorflow/contrib/lite/toco/graph_transformations/convert_trivial_transpose_to_reshape.cc
@@ -48,10 +48,13 @@ bool TransposeAffectsMemoryOrder(std::vector<int> perm,
} // namespace
-bool ConvertTrivialTransposeToReshape::Run(Model* model, std::size_t op_index) {
+::tensorflow::Status ConvertTrivialTransposeToReshape::Run(Model* model,
+ std::size_t op_index,
+ bool* modified) {
+ *modified = false;
auto transpose_it = model->operators.begin() + op_index;
if (transpose_it->get()->type != OperatorType::kTranspose) {
- return false;
+ return ::tensorflow::Status::OK();
}
TransposeOperator* transpose_op =
static_cast<TransposeOperator*>(transpose_it->get());
@@ -60,14 +63,14 @@ bool ConvertTrivialTransposeToReshape::Run(Model* model, std::size_t op_index) {
const auto& output_array = model->GetArray(transpose_op->outputs[0]);
if (!input_array.has_shape() || !output_array.has_shape()) {
// Yield until PropagateFixedSizes has been run on this op.
- return false;
+ return ::tensorflow::Status::OK();
}
// Note: We can assume we have error checked inputs in PropagateFixedSizes.
// Check that the permutation has propogated.
std::vector<int> const& perm = transpose_op->perm;
if (perm.empty()) {
- return false;
+ return ::tensorflow::Status::OK();
}
// This transpose is trivial if non-unitary dimensions remain in the same
@@ -76,7 +79,7 @@ bool ConvertTrivialTransposeToReshape::Run(Model* model, std::size_t op_index) {
std::vector<int> const& output_dims = output_array.shape().dims();
if (TransposeAffectsMemoryOrder(perm, input_dims)) {
- return false;
+ return ::tensorflow::Status::OK();
}
// This transpose is trivial. Replace it with a Reshape op.
@@ -109,7 +112,8 @@ bool ConvertTrivialTransposeToReshape::Run(Model* model, std::size_t op_index) {
CHECK_EQ(transpose_it->get(), transpose_op);
model->operators.erase(transpose_it);
- return true;
+ *modified = true;
+ return ::tensorflow::Status::OK();
}
} // namespace toco
diff --git a/tensorflow/contrib/lite/toco/graph_transformations/create_im2col_arrays.cc b/tensorflow/contrib/lite/toco/graph_transformations/create_im2col_arrays.cc
index 1e68cd678b..ebc0e9afca 100644
--- a/tensorflow/contrib/lite/toco/graph_transformations/create_im2col_arrays.cc
+++ b/tensorflow/contrib/lite/toco/graph_transformations/create_im2col_arrays.cc
@@ -73,18 +73,22 @@ bool ProcessTransposeConvOperator(Model* model, TransposeConvOperator* op) {
return true;
}
-bool CreateIm2colArrays::Run(Model* model, std::size_t op_index) {
+::tensorflow::Status CreateIm2colArrays::Run(Model* model, std::size_t op_index,
+ bool* modified) {
+ *modified = false;
auto it = model->operators.begin() + op_index;
auto* op = it->get();
switch (op->type) {
case OperatorType::kConv:
- return ProcessConvOperator(model, static_cast<ConvOperator*>(op));
+ *modified = ProcessConvOperator(model, static_cast<ConvOperator*>(op));
+ return ::tensorflow::Status::OK();
case OperatorType::kTransposeConv:
- return ProcessTransposeConvOperator(
+ *modified = ProcessTransposeConvOperator(
model, static_cast<TransposeConvOperator*>(op));
+ return ::tensorflow::Status::OK();
default:
- return false;
+ return ::tensorflow::Status::OK();
}
}
diff --git a/tensorflow/contrib/lite/toco/graph_transformations/dequantize.cc b/tensorflow/contrib/lite/toco/graph_transformations/dequantize.cc
index 1688586733..2119174950 100644
--- a/tensorflow/contrib/lite/toco/graph_transformations/dequantize.cc
+++ b/tensorflow/contrib/lite/toco/graph_transformations/dequantize.cc
@@ -186,24 +186,27 @@ bool DequantizeArray(const string& array_name,
} // namespace
-bool Dequantize::Run(Model* model, std::size_t op_index) {
+::tensorflow::Status Dequantize::Run(Model* model, std::size_t op_index,
+ bool* modified) {
+ *modified = false;
const auto op_it = model->operators.begin() + op_index;
auto* op = op_it->get();
if (op->type == OperatorType::kDequantize) {
auto& input_array = model->GetArray(op->inputs[0]);
if (input_array.data_type == ArrayDataType::kFloat) {
- return false;
+ return ::tensorflow::Status::OK();
}
if (input_array.final_data_type != ArrayDataType::kFloat) {
- return false;
+ return ::tensorflow::Status::OK();
}
input_array.data_type = ArrayDataType::kFloat;
input_array.quantization_params = nullptr;
auto& output_array = model->GetArray(op->outputs[0]);
output_array.data_type = ArrayDataType::kFloat;
output_array.quantization_params = nullptr;
- return RemoveTrivialPassthroughOp(this, model, op_index);
+ *modified = RemoveTrivialPassthroughOp(this, model, op_index);
+ return ::tensorflow::Status::OK();
}
std::vector<string> arrays;
@@ -220,7 +223,8 @@ bool Dequantize::Run(Model* model, std::size_t op_index) {
}
}
- return changed;
+ *modified = changed;
+ return ::tensorflow::Status::OK();
}
} // namespace toco
diff --git a/tensorflow/contrib/lite/toco/graph_transformations/drop_fake_quant.cc b/tensorflow/contrib/lite/toco/graph_transformations/drop_fake_quant.cc
index 95558ef5ec..1555cf60a1 100644
--- a/tensorflow/contrib/lite/toco/graph_transformations/drop_fake_quant.cc
+++ b/tensorflow/contrib/lite/toco/graph_transformations/drop_fake_quant.cc
@@ -25,21 +25,23 @@ limitations under the License.
namespace toco {
-bool DropFakeQuant::Run(Model* model, std::size_t op_index) {
+::tensorflow::Status DropFakeQuant::Run(Model* model, std::size_t op_index,
+ bool* modified) {
+ *modified = false;
const auto fakequant_it = model->operators.begin() + op_index;
auto* fakequant_base_op = fakequant_it->get();
if (fakequant_base_op->type != OperatorType::kFakeQuant) {
- return false;
+ return ::tensorflow::Status::OK();
}
auto* fakequant_op = static_cast<FakeQuantOperator*>(fakequant_base_op);
if (!fakequant_op->minmax) {
- return false;
+ return ::tensorflow::Status::OK();
}
const auto& output_array = model->GetArray(fakequant_op->outputs[0]);
if (!output_array.minmax) {
- return false;
+ return ::tensorflow::Status::OK();
}
// Drop min/max inputs
@@ -50,7 +52,8 @@ bool DropFakeQuant::Run(Model* model, std::size_t op_index) {
}
fakequant_op->inputs.resize(1);
- return RemoveTrivialPassthroughOp(this, model, op_index);
+ *modified = RemoveTrivialPassthroughOp(this, model, op_index);
+ return ::tensorflow::Status::OK();
}
} // namespace toco
diff --git a/tensorflow/contrib/lite/toco/graph_transformations/drop_im2col_arrays.cc b/tensorflow/contrib/lite/toco/graph_transformations/drop_im2col_arrays.cc
index f7fd878b7e..7d66ea5dd2 100644
--- a/tensorflow/contrib/lite/toco/graph_transformations/drop_im2col_arrays.cc
+++ b/tensorflow/contrib/lite/toco/graph_transformations/drop_im2col_arrays.cc
@@ -19,15 +19,17 @@ limitations under the License.
namespace toco {
-bool DropIm2colArrays::Run(Model* model, std::size_t op_index) {
+::tensorflow::Status DropIm2colArrays::Run(Model* model, std::size_t op_index,
+ bool* modified) {
+ *modified = false;
auto conv_it = model->operators.begin() + op_index;
if (conv_it->get()->type != OperatorType::kConv) {
- return false;
+ return ::tensorflow::Status::OK();
}
auto* conv_op = static_cast<ConvOperator*>(conv_it->get());
if (conv_op->outputs.size() < 2) {
// Conv op does not have im2col.
- return false;
+ return ::tensorflow::Status::OK();
}
// Drop the im2col array.
@@ -36,7 +38,8 @@ bool DropIm2colArrays::Run(Model* model, std::size_t op_index) {
conv_op->outputs.resize(1);
AddMessageF("Dropped an im2col array for %s", LogName(*conv_op));
- return true;
+ *modified = true;
+ return ::tensorflow::Status::OK();
}
} // namespace toco
diff --git a/tensorflow/contrib/lite/toco/graph_transformations/ensure_bias_vectors.cc b/tensorflow/contrib/lite/toco/graph_transformations/ensure_bias_vectors.cc
index e80ed036b3..72b1dda3be 100644
--- a/tensorflow/contrib/lite/toco/graph_transformations/ensure_bias_vectors.cc
+++ b/tensorflow/contrib/lite/toco/graph_transformations/ensure_bias_vectors.cc
@@ -62,17 +62,20 @@ bool ProcessLinearOperator(Model* model, Operator* op) {
}
} // namespace
-bool EnsureBiasVectors::Run(Model* model, std::size_t op_index) {
+::tensorflow::Status EnsureBiasVectors::Run(Model* model, std::size_t op_index,
+ bool* modified) {
+ *modified = false;
auto* op = model->operators[op_index].get();
if (op->type == OperatorType::kConv ||
op->type == OperatorType::kDepthwiseConv ||
op->type == OperatorType::kFullyConnected) {
if (ProcessLinearOperator(model, op)) {
AddMessageF("Added bias vector to %s as %s", LogName(*op), op->inputs[2]);
- return true;
+ *modified = true;
+ return ::tensorflow::Status::OK();
}
}
- return false;
+ return ::tensorflow::Status::OK();
}
} // namespace toco
diff --git a/tensorflow/contrib/lite/toco/graph_transformations/ensure_uint8_weights_safe_for_fast_int8_kernels.cc b/tensorflow/contrib/lite/toco/graph_transformations/ensure_uint8_weights_safe_for_fast_int8_kernels.cc
index c13fc0de75..60dcd52684 100644
--- a/tensorflow/contrib/lite/toco/graph_transformations/ensure_uint8_weights_safe_for_fast_int8_kernels.cc
+++ b/tensorflow/contrib/lite/toco/graph_transformations/ensure_uint8_weights_safe_for_fast_int8_kernels.cc
@@ -108,8 +108,9 @@ namespace toco {
// we can foresee these 'fast int8 kernels' to remain important to have into
// the 2020s.
//
-bool EnsureUint8WeightsSafeForFastInt8Kernels::Run(Model* model,
- std::size_t op_index) {
+::tensorflow::Status EnsureUint8WeightsSafeForFastInt8Kernels::Run(
+ Model* model, std::size_t op_index, bool* modified) {
+ *modified = false;
const auto& op = *model->operators[op_index];
int weights_index = 0;
switch (op.type) {
@@ -148,16 +149,16 @@ bool EnsureUint8WeightsSafeForFastInt8Kernels::Run(Model* model,
// That's why at the moment we only handle operators that use a GEMM
// (Conv, fully-connected --- note that LSTM merely wraps a
// fully-connected operator).
- return false;
+ return ::tensorflow::Status::OK();
}
const string& name = op.inputs[weights_index];
auto& array = model->GetArray(name);
if (!array.buffer) {
- return false;
+ return ::tensorflow::Status::OK();
}
if (array.data_type != ArrayDataType::kUint8) {
- return false;
+ return ::tensorflow::Status::OK();
}
auto& buffer_data = array.GetMutableBuffer<ArrayDataType::kUint8>().data;
@@ -212,7 +213,8 @@ bool EnsureUint8WeightsSafeForFastInt8Kernels::Run(Model* model,
AddMessageF("Tweaked weights values for %s", LogName(op));
}
- return changed;
+ *modified = changed;
+ return ::tensorflow::Status::OK();
}
} // namespace toco
diff --git a/tensorflow/contrib/lite/toco/graph_transformations/fuse_activation_functions.cc b/tensorflow/contrib/lite/toco/graph_transformations/fuse_activation_functions.cc
index c5ce3fcd95..88511a7d3c 100644
--- a/tensorflow/contrib/lite/toco/graph_transformations/fuse_activation_functions.cc
+++ b/tensorflow/contrib/lite/toco/graph_transformations/fuse_activation_functions.cc
@@ -25,27 +25,30 @@ limitations under the License.
namespace toco {
-bool FuseActivationFunctions::Run(Model* model, std::size_t op_index) {
+::tensorflow::Status FuseActivationFunctions::Run(Model* model,
+ std::size_t op_index,
+ bool* modified) {
+ *modified = false;
const auto ac_it = model->operators.begin() + op_index;
const auto* ac_op = ac_it->get();
if (ac_op->type != OperatorType::kRelu6 &&
ac_op->type != OperatorType::kRelu1 &&
ac_op->type != OperatorType::kRelu) {
- return false;
+ return ::tensorflow::Status::OK();
}
// Find the op producing the array passed to this activation function
Operator* op = GetOpWithOutput(*model, ac_op->inputs[0]);
- if (!op) return false;
+ if (!op) return ::tensorflow::Status::OK();
if (CountTrueOutputs(*model, *op) > 1) {
AddMessageF(
"Not fusing activation function %s into %s because it has more than "
"one consumed output",
LogName(*ac_op), LogName(*op));
- return false;
+ return ::tensorflow::Status::OK();
}
CHECK_EQ(op->outputs[0], ac_op->inputs[0]);
@@ -57,7 +60,7 @@ bool FuseActivationFunctions::Run(Model* model, std::size_t op_index) {
"Not fusing activation function into %s because it is consumed by more "
"than 1 other operator",
LogName(*ac_op), LogName(*op));
- return false;
+ return ::tensorflow::Status::OK();
}
if (!IsDiscardableArray(*model, op->outputs[0])) {
@@ -65,7 +68,7 @@ bool FuseActivationFunctions::Run(Model* model, std::size_t op_index) {
"Not fusing activation function %s into %s because output %s it is not "
"discardable",
LogName(*ac_op), LogName(*op), op->outputs[0]);
- return false;
+ return ::tensorflow::Status::OK();
}
if (op->fused_activation_function != FusedActivationFunctionType::kNone) {
@@ -73,7 +76,7 @@ bool FuseActivationFunctions::Run(Model* model, std::size_t op_index) {
"Not fusing activation function %s into %s because it already has a "
"fused activation function",
LogName(*ac_op), LogName(*op));
- return false;
+ return ::tensorflow::Status::OK();
}
if (!OperatorSupportsFusedActivation(op->type)) {
@@ -81,7 +84,7 @@ bool FuseActivationFunctions::Run(Model* model, std::size_t op_index) {
"Not fusing activation function %s because the %s op doesn't support "
"it",
LogName(*ac_op), LogName(*op));
- return false;
+ return ::tensorflow::Status::OK();
}
AddMessageF("Fusing activation function %s into the preceding %s",
@@ -98,7 +101,8 @@ bool FuseActivationFunctions::Run(Model* model, std::size_t op_index) {
model->EraseArray(ac_op->inputs[0]);
op->outputs[0] = ac_op->outputs[0];
model->operators.erase(ac_it);
- return true;
+ *modified = true;
+ return ::tensorflow::Status::OK();
}
} // namespace toco
diff --git a/tensorflow/contrib/lite/toco/graph_transformations/fuse_binary_into_following_affine.cc b/tensorflow/contrib/lite/toco/graph_transformations/fuse_binary_into_following_affine.cc
index dcbbead517..0de22b8ff4 100644
--- a/tensorflow/contrib/lite/toco/graph_transformations/fuse_binary_into_following_affine.cc
+++ b/tensorflow/contrib/lite/toco/graph_transformations/fuse_binary_into_following_affine.cc
@@ -150,14 +150,17 @@ void FuseMulOrDivParamsIntoFollowingAffine(Model* model, Operator* following_op,
} // namespace
-bool FuseBinaryIntoFollowingAffine::Run(Model* model, std::size_t op_index) {
+::tensorflow::Status FuseBinaryIntoFollowingAffine::Run(Model* model,
+ std::size_t op_index,
+ bool* modified) {
+ *modified = false;
const auto binary_it = model->operators.begin() + op_index;
auto* binary_op = binary_it->get();
if (binary_op->type != OperatorType::kAdd &&
binary_op->type != OperatorType::kMul &&
binary_op->type != OperatorType::kSub &&
binary_op->type != OperatorType::kDiv) {
- return false;
+ return ::tensorflow::Status::OK();
}
CHECK_EQ(binary_op->inputs.size(), 2);
@@ -175,12 +178,12 @@ bool FuseBinaryIntoFollowingAffine::Run(Model* model, std::size_t op_index) {
};
if (!is_input_constant[0] && !is_input_constant[1]) {
// Neither input is constant, so nothing we can fuse into a constant.
- return false;
+ return ::tensorflow::Status::OK();
}
if (is_input_constant[0] && is_input_constant[1]) {
// Both inputs are constants. That's a job for constants
// propagation, not for us to handle here.
- return false;
+ return ::tensorflow::Status::OK();
}
const int index_of_constant_input = is_input_constant[0] ? 0 : 1;
const int index_of_variable_input = is_input_constant[0] ? 1 : 0;
@@ -192,7 +195,7 @@ bool FuseBinaryIntoFollowingAffine::Run(Model* model, std::size_t op_index) {
if (index_of_constant_input != 1) {
AddMessageF("Not fusing %s because the denominator is not constant",
LogName(*binary_op));
- return false;
+ return ::tensorflow::Status::OK();
}
}
@@ -204,7 +207,7 @@ bool FuseBinaryIntoFollowingAffine::Run(Model* model, std::size_t op_index) {
"Not fusing %s into the following affine op, because we only know "
"how to do so when the constant operand is a scalar",
LogName(*binary_op));
- return false;
+ return ::tensorflow::Status::OK();
}
}
@@ -212,7 +215,7 @@ bool FuseBinaryIntoFollowingAffine::Run(Model* model, std::size_t op_index) {
FusedActivationFunctionType::kNone) {
AddMessageF("Not fusing %s because it has a fused activation function",
LogName(*binary_op));
- return false;
+ return ::tensorflow::Status::OK();
}
Operator* following_op = GetOpWithInput(*model, binary_op->outputs[0]);
@@ -221,7 +224,7 @@ bool FuseBinaryIntoFollowingAffine::Run(Model* model, std::size_t op_index) {
AddMessageF(
"Not fusing %s because it is not consumed by exactly one other op",
LogName(*binary_op));
- return false;
+ return ::tensorflow::Status::OK();
}
if (following_op->type != OperatorType::kConv &&
@@ -231,14 +234,14 @@ bool FuseBinaryIntoFollowingAffine::Run(Model* model, std::size_t op_index) {
"Not fusing %s because the following %s is not of one of the supported "
"types",
LogName(*binary_op), LogName(*following_op));
- return false;
+ return ::tensorflow::Status::OK();
}
if (following_op->inputs.size() < 3) {
AddMessageF(
"Not fusing %s because the following %s does not have a bias vector",
LogName(*following_op), LogName(*binary_op));
- return false;
+ return ::tensorflow::Status::OK();
}
const auto& weights = model->GetArray(following_op->inputs[1]);
@@ -248,7 +251,7 @@ bool FuseBinaryIntoFollowingAffine::Run(Model* model, std::size_t op_index) {
"Not fusing %s because the following %s has non-constant weights or "
"bias arrays",
LogName(*binary_op), LogName(*following_op));
- return false;
+ return ::tensorflow::Status::OK();
}
// Try to fuse the binary params into the following op's params
@@ -260,7 +263,7 @@ bool FuseBinaryIntoFollowingAffine::Run(Model* model, std::size_t op_index) {
AddMessageF(
"Not fusing %s because the following %s does not use VALID padding",
LogName(*binary_op), LogName(*following_op));
- return false;
+ return ::tensorflow::Status::OK();
}
}
if (following_op->type == OperatorType::kDepthwiseConv) {
@@ -269,7 +272,7 @@ bool FuseBinaryIntoFollowingAffine::Run(Model* model, std::size_t op_index) {
AddMessageF(
"Not fusing %s because the following %s does not use VALID padding",
LogName(*binary_op), LogName(*following_op));
- return false;
+ return ::tensorflow::Status::OK();
}
}
FuseAddOrSubParamsIntoFollowingAffine(model, following_op, binary_op,
@@ -294,7 +297,8 @@ bool FuseBinaryIntoFollowingAffine::Run(Model* model, std::size_t op_index) {
model->EraseArray(old_constant_param_name);
}
model->operators.erase(binary_it);
- return true;
+ *modified = true;
+ return ::tensorflow::Status::OK();
}
} // namespace toco
diff --git a/tensorflow/contrib/lite/toco/graph_transformations/fuse_binary_into_preceding_affine.cc b/tensorflow/contrib/lite/toco/graph_transformations/fuse_binary_into_preceding_affine.cc
index b324631579..b8da756d85 100644
--- a/tensorflow/contrib/lite/toco/graph_transformations/fuse_binary_into_preceding_affine.cc
+++ b/tensorflow/contrib/lite/toco/graph_transformations/fuse_binary_into_preceding_affine.cc
@@ -188,14 +188,17 @@ void FuseMulOrDivParamsIntoPrecedingAffine(Model* model, Operator* preceding_op,
}
} // namespace
-bool FuseBinaryIntoPrecedingAffine::Run(Model* model, std::size_t op_index) {
+::tensorflow::Status FuseBinaryIntoPrecedingAffine::Run(Model* model,
+ std::size_t op_index,
+ bool* modified) {
+ *modified = false;
const auto binary_it = model->operators.begin() + op_index;
const auto* binary_op = binary_it->get();
if (binary_op->type != OperatorType::kAdd &&
binary_op->type != OperatorType::kMul &&
binary_op->type != OperatorType::kSub &&
binary_op->type != OperatorType::kDiv) {
- return false;
+ return ::tensorflow::Status::OK();
}
CHECK_EQ(binary_op->inputs.size(), 2);
@@ -213,12 +216,12 @@ bool FuseBinaryIntoPrecedingAffine::Run(Model* model, std::size_t op_index) {
};
if (!is_input_constant[0] && !is_input_constant[1]) {
// Neither input is constant, so nothing we can fuse into a constant.
- return false;
+ return ::tensorflow::Status::OK();
}
if (is_input_constant[0] && is_input_constant[1]) {
// Both inputs are constants. That's a job for constants
// propagation, not for us to handle here.
- return false;
+ return ::tensorflow::Status::OK();
}
const int index_of_constant_input = is_input_constant[0] ? 0 : 1;
const int index_of_variable_input = is_input_constant[0] ? 1 : 0;
@@ -230,7 +233,7 @@ bool FuseBinaryIntoPrecedingAffine::Run(Model* model, std::size_t op_index) {
if (index_of_constant_input != 1) {
AddMessageF("Not fusing %s because the denominator is not constant",
LogName(*binary_op));
- return false;
+ return ::tensorflow::Status::OK();
}
}
@@ -239,12 +242,12 @@ bool FuseBinaryIntoPrecedingAffine::Run(Model* model, std::size_t op_index) {
if (!preceding_op) {
AddMessageF("Not fusing %s because it is not the output of another op",
LogName(*binary_op));
- return false;
+ return ::tensorflow::Status::OK();
}
for (const string& output_array : model->flags.output_arrays()) {
if (preceding_op->outputs[0] == output_array) {
- return false;
+ return ::tensorflow::Status::OK();
}
}
@@ -255,7 +258,7 @@ bool FuseBinaryIntoPrecedingAffine::Run(Model* model, std::size_t op_index) {
"Not fusing %s because the preceding %s is not of one of the supported "
"types",
LogName(*binary_op), LogName(*preceding_op));
- return false;
+ return ::tensorflow::Status::OK();
}
if (preceding_op->fused_activation_function !=
@@ -264,14 +267,14 @@ bool FuseBinaryIntoPrecedingAffine::Run(Model* model, std::size_t op_index) {
"Not fusing %s because the preceding %s has a fused activation "
"function",
LogName(*binary_op), LogName(*preceding_op));
- return false;
+ return ::tensorflow::Status::OK();
}
if (preceding_op->inputs.size() < 3) {
AddMessageF(
"Not fusing %s because the preceding %s does not have a bias vector",
LogName(*binary_op), LogName(*preceding_op));
- return false;
+ return ::tensorflow::Status::OK();
}
const auto& weights_name = preceding_op->inputs[1];
@@ -289,14 +292,14 @@ bool FuseBinaryIntoPrecedingAffine::Run(Model* model, std::size_t op_index) {
"Not fusing %s because the preceding %s has a non-constant bias "
"array",
LogName(*binary_op), LogName(*preceding_op));
- return false;
+ return ::tensorflow::Status::OK();
}
if (count_ops_consuming_bias > 1) {
AddMessageF(
"Not fusing %s because the bias of the preceding %s is consumed by "
"another op",
LogName(*binary_op), LogName(*preceding_op));
- return false;
+ return ::tensorflow::Status::OK();
}
} else {
if (!weights.buffer || !bias.buffer) {
@@ -304,14 +307,14 @@ bool FuseBinaryIntoPrecedingAffine::Run(Model* model, std::size_t op_index) {
"Not fusing %s because the preceding %s has non-constant weights or "
"bias arrays",
LogName(*binary_op), LogName(*preceding_op));
- return false;
+ return ::tensorflow::Status::OK();
}
if (count_ops_consuming_weights > 1 || count_ops_consuming_bias > 1) {
AddMessageF(
"Not fusing %s because the weights or bias of the preceding %s is "
"consumed by another op",
LogName(*binary_op), LogName(*preceding_op));
- return false;
+ return ::tensorflow::Status::OK();
}
}
@@ -323,7 +326,7 @@ bool FuseBinaryIntoPrecedingAffine::Run(Model* model, std::size_t op_index) {
"Not fusing %s because the output of the preceding %s is consumed by "
"another op",
LogName(*binary_op), LogName(*preceding_op));
- return false;
+ return ::tensorflow::Status::OK();
}
AddMessageF("Fusing %s into the preceding %s", LogName(*binary_op),
@@ -352,7 +355,8 @@ bool FuseBinaryIntoPrecedingAffine::Run(Model* model, std::size_t op_index) {
model->EraseArray(old_constant_param_name);
}
model->operators.erase(binary_it);
- return true;
+ *modified = true;
+ return ::tensorflow::Status::OK();
}
} // namespace toco
diff --git a/tensorflow/contrib/lite/toco/graph_transformations/fuse_broadcast_into_following_binary.cc b/tensorflow/contrib/lite/toco/graph_transformations/fuse_broadcast_into_following_binary.cc
index 874d8def57..4848867b9a 100644
--- a/tensorflow/contrib/lite/toco/graph_transformations/fuse_broadcast_into_following_binary.cc
+++ b/tensorflow/contrib/lite/toco/graph_transformations/fuse_broadcast_into_following_binary.cc
@@ -51,19 +51,22 @@ bool IsBroadcastingOp(const Model& model, Operator* op) {
// Finds an operation that looks like a broadcast (concat of the same sources
// along the last dimension) and drops it by relying on the ability of certain
// binary ops to perform an implicit broadcast.
-bool FuseBroadcastIntoFollowingBinary::Run(Model* model, std::size_t op_index) {
+::tensorflow::Status FuseBroadcastIntoFollowingBinary::Run(Model* model,
+ std::size_t op_index,
+ bool* modified) {
+ *modified = false;
const auto binary_it = model->operators.begin() + op_index;
auto* binary_op = binary_it->get();
// Test for binary ops of types that we know how to resolve
if (binary_op->inputs.size() != 2) {
- return false;
+ return ::tensorflow::Status::OK();
}
if (binary_op->type != OperatorType::kAdd &&
binary_op->type != OperatorType::kMul &&
binary_op->type != OperatorType::kSub &&
binary_op->type != OperatorType::kDiv) {
- return false;
+ return ::tensorflow::Status::OK();
}
// NOTE: either of these ops may be nullptr if the input array is constant.
@@ -78,14 +81,14 @@ bool FuseBroadcastIntoFollowingBinary::Run(Model* model, std::size_t op_index) {
if (!is_op_0_broadcast && !is_op_1_broadcast) {
// Neither input is a broadcast-looking thing.
AddMessageF("Neither input looks broadcasty");
- return false;
+ return ::tensorflow::Status::OK();
} else if (is_op_0_broadcast && is_op_1_broadcast) {
AddMessageF(
"Unable to fuse broadcast into %s as both inputs (%s, %s) are "
"broadcasts",
LogName(*binary_op), op[0] ? LogName(*op[0]) : "(?)",
op[1] ? LogName(*op[1]) : "(?)");
- return false;
+ return ::tensorflow::Status::OK();
}
int broadcast_index = is_op_0_broadcast ? 0 : 1;
@@ -96,7 +99,8 @@ bool FuseBroadcastIntoFollowingBinary::Run(Model* model, std::size_t op_index) {
binary_op->inputs[broadcast_index] = op[broadcast_index]->inputs[0];
// We leave the broadcast op in; it'll get cleaned up if it's not used later.
- return true;
+ *modified = true;
+ return ::tensorflow::Status::OK();
}
} // namespace toco
diff --git a/tensorflow/contrib/lite/toco/graph_transformations/graph_transformations.cc b/tensorflow/contrib/lite/toco/graph_transformations/graph_transformations.cc
index 6961e23690..8b0bc2d865 100644
--- a/tensorflow/contrib/lite/toco/graph_transformations/graph_transformations.cc
+++ b/tensorflow/contrib/lite/toco/graph_transformations/graph_transformations.cc
@@ -142,7 +142,7 @@ bool GraphTransformationsPass(int increment, Model* model,
for (const auto& transformation : transformations) {
CHECK(!changed_now);
CHECK(transformation->Messages().empty());
- changed_now = transformation->Run(model, op_index);
+ CHECK(transformation->Run(model, op_index, &changed_now).ok());
const char* made_a_change_msg =
changed_now ? "made a change" : "did NOT make a change";
const int log_level =
diff --git a/tensorflow/contrib/lite/toco/graph_transformations/graph_transformations.h b/tensorflow/contrib/lite/toco/graph_transformations/graph_transformations.h
index 4d213b3f9c..a89db320ea 100644
--- a/tensorflow/contrib/lite/toco/graph_transformations/graph_transformations.h
+++ b/tensorflow/contrib/lite/toco/graph_transformations/graph_transformations.h
@@ -27,7 +27,8 @@ namespace toco {
class GraphTransformation {
public:
- virtual bool Run(Model* model, std::size_t op_index) = 0;
+ virtual ::tensorflow::Status Run(Model* model, std::size_t op_index,
+ bool* modified) = 0;
virtual const char* Name() const = 0;
virtual ~GraphTransformation() {}
// Returns the list of messages that this graph transformation
@@ -104,11 +105,12 @@ class GraphTransformationsSet {
void RunGraphTransformations(Model* model, const string& message,
const GraphTransformationsSet& transformations);
-#define DECLARE_GRAPH_TRANSFORMATION(GTName) \
- class GTName : public GraphTransformation { \
- public: \
- bool Run(Model* model, std::size_t op_index) override; \
- const char* Name() const override { return #GTName; } \
+#define DECLARE_GRAPH_TRANSFORMATION(GTName) \
+ class GTName : public GraphTransformation { \
+ public: \
+ ::tensorflow::Status Run(Model* model, std::size_t op_index, \
+ bool* modified) override; \
+ const char* Name() const override { return #GTName; } \
};
// List of all graph transformations
@@ -200,7 +202,8 @@ DECLARE_GRAPH_TRANSFORMATION(ResolveGatherAttributes)
class PropagateDefaultMinMax : public GraphTransformation {
public:
- bool Run(Model* model, std::size_t op_index) override;
+ ::tensorflow::Status Run(Model* model, std::size_t op_index,
+ bool* modified) override;
const char* Name() const override { return "PropagateDefaultMinMax"; }
bool has_any_ranges_defined() const { return !type_ranges_.empty(); }
@@ -218,7 +221,8 @@ class PropagateDefaultMinMax : public GraphTransformation {
class RemoveTrivialReshape : public GraphTransformation {
public:
- bool Run(Model* model, std::size_t op_index) override;
+ ::tensorflow::Status Run(Model* model, std::size_t op_index,
+ bool* modified) override;
const char* Name() const override { return "RemoveTrivialReshape"; }
bool treat_expand_dims_as_trivial() const {
return treat_expand_dims_as_trivial_;
@@ -233,7 +237,8 @@ class RemoveTrivialReshape : public GraphTransformation {
class ResolveConstantFakeQuant : public GraphTransformation {
public:
- bool Run(Model* model, std::size_t op_index) override;
+ ::tensorflow::Status Run(Model* model, std::size_t op_index,
+ bool* modified) override;
const char* Name() const override { return "ResolveConstantFakeQuant"; }
// True if the num_bits should adjust the final data type.
@@ -250,7 +255,8 @@ class ResolveConstantFakeQuant : public GraphTransformation {
class EnsureUint8WeightsSafeForFastInt8Kernels : public GraphTransformation {
public:
- bool Run(Model* model, std::size_t op_index) override;
+ ::tensorflow::Status Run(Model* model, std::size_t op_index,
+ bool* modified) override;
const char* Name() const override {
return "EnsureUint8WeightsSafeForFastInt8Kernels";
}
@@ -267,7 +273,8 @@ class EnsureUint8WeightsSafeForFastInt8Kernels : public GraphTransformation {
class IdentifyDilatedConv : public GraphTransformation {
public:
- bool Run(Model* model, std::size_t op_index) override;
+ ::tensorflow::Status Run(Model* model, std::size_t op_index,
+ bool* modified) override;
const char* Name() const override { return "IdentifyDilatedConv"; }
bool identify_depthwise_conv() const { return identify_depthwise_conv_; }
void set_identify_depthwise_conv(bool val) { identify_depthwise_conv_ = val; }
diff --git a/tensorflow/contrib/lite/toco/graph_transformations/hardcode_min_max.cc b/tensorflow/contrib/lite/toco/graph_transformations/hardcode_min_max.cc
index 3114fa93e8..72df53548b 100644
--- a/tensorflow/contrib/lite/toco/graph_transformations/hardcode_min_max.cc
+++ b/tensorflow/contrib/lite/toco/graph_transformations/hardcode_min_max.cc
@@ -372,7 +372,9 @@ bool HardcodeMinMaxForLstmCell(Model* model, Operator* op) {
}
} // namespace
-bool HardcodeMinMax::Run(Model* model, std::size_t op_index) {
+::tensorflow::Status HardcodeMinMax::Run(Model* model, std::size_t op_index,
+ bool* modified) {
+ *modified = false;
auto it = model->operators.begin() + op_index;
auto* op = it->get();
bool changed = false;
@@ -467,7 +469,8 @@ bool HardcodeMinMax::Run(Model* model, std::size_t op_index) {
if (changed) {
AddMessageF("Hardcoded min-max through %s", LogName(*op));
}
- return changed;
+ *modified = changed;
+ return ::tensorflow::Status::OK();
}
} // namespace toco
diff --git a/tensorflow/contrib/lite/toco/graph_transformations/identify_dilated_conv.cc b/tensorflow/contrib/lite/toco/graph_transformations/identify_dilated_conv.cc
index aac77eb39e..9e4a3005a1 100644
--- a/tensorflow/contrib/lite/toco/graph_transformations/identify_dilated_conv.cc
+++ b/tensorflow/contrib/lite/toco/graph_transformations/identify_dilated_conv.cc
@@ -168,7 +168,10 @@ bool ResolveDilatedConv(Model* model, Operator* conv_base_op, Operator* stb_op,
return true;
}
-bool IdentifyDilatedConv::Run(Model* model, std::size_t op_index) {
+::tensorflow::Status IdentifyDilatedConv::Run(Model* model,
+ std::size_t op_index,
+ bool* modified) {
+ *modified = false;
const auto it = model->operators.begin() + op_index;
auto* stb_op = it->get();
@@ -176,17 +179,17 @@ bool IdentifyDilatedConv::Run(Model* model, std::size_t op_index) {
// ***************************************************************************
// SpaceToBatch Op.
if (stb_op->type != OperatorType::kSpaceToBatchND) {
- return false;
+ return ::tensorflow::Status::OK();
}
if (stb_op->inputs.size() != 3) {
- return false;
+ return ::tensorflow::Status::OK();
}
CHECK_EQ(stb_op->outputs.size(), 1);
// Extract the dilation factor from Input[1] of SpaceToBatch
// TODO(mjmatthews): Support 2D dilation factors.
const auto& block_shape_array = model->GetArray(stb_op->inputs[1]);
if (!block_shape_array.buffer) {
- return false;
+ return ::tensorflow::Status::OK();
}
CHECK_EQ(block_shape_array.shape().dimensions_count(), 1);
int dilation_factor =
@@ -195,7 +198,7 @@ bool IdentifyDilatedConv::Run(Model* model, std::size_t op_index) {
// Expand Op
auto* post_stb_op = GetOpWithInput(*model, stb_op->outputs[0]);
if (!post_stb_op) {
- return false;
+ return ::tensorflow::Status::OK();
}
bool has_expand_op = false;
if (post_stb_op->type == OperatorType::kExpandDims) {
@@ -229,7 +232,8 @@ bool IdentifyDilatedConv::Run(Model* model, std::size_t op_index) {
}
}
- return changed;
+ *modified = changed;
+ return ::tensorflow::Status::OK();
}
} // namespace toco
diff --git a/tensorflow/contrib/lite/toco/graph_transformations/identify_l2_normalization.cc b/tensorflow/contrib/lite/toco/graph_transformations/identify_l2_normalization.cc
index b78efd7fc3..78f60f52fb 100644
--- a/tensorflow/contrib/lite/toco/graph_transformations/identify_l2_normalization.cc
+++ b/tensorflow/contrib/lite/toco/graph_transformations/identify_l2_normalization.cc
@@ -39,7 +39,10 @@ std::vector<std::unique_ptr<Operator>>::iterator FindOperator(
}
} // namespace
-bool IdentifyL2Normalization::Run(Model* model, std::size_t op_index) {
+::tensorflow::Status IdentifyL2Normalization::Run(Model* model,
+ std::size_t op_index,
+ bool* modified) {
+ *modified = false;
const auto div_it = model->operators.begin() + op_index;
const auto* div_or_mul_op = div_it->get();
OperatorType expected_op_type_producing_div_or_mul_input;
@@ -48,7 +51,7 @@ bool IdentifyL2Normalization::Run(Model* model, std::size_t op_index) {
} else if (div_or_mul_op->type == OperatorType::kMul) {
expected_op_type_producing_div_or_mul_input = OperatorType::kRsqrt;
} else {
- return false;
+ return ::tensorflow::Status::OK();
}
CHECK_EQ(div_or_mul_op->inputs.size(), 2);
Operator* op_producing_div_or_mul_input[2] = {
@@ -58,14 +61,14 @@ bool IdentifyL2Normalization::Run(Model* model, std::size_t op_index) {
if (!op_producing_div_or_mul_input[1] ||
op_producing_div_or_mul_input[1]->type !=
expected_op_type_producing_div_or_mul_input) {
- return false;
+ return ::tensorflow::Status::OK();
}
Operator* sqrt_or_rsqrt_op = op_producing_div_or_mul_input[1];
CHECK_EQ(sqrt_or_rsqrt_op->inputs.size(), 1);
Operator* op_producing_sqrt_or_rsqrt_input =
GetOpWithOutput(*model, sqrt_or_rsqrt_op->inputs[0]);
if (!op_producing_sqrt_or_rsqrt_input) {
- return false;
+ return ::tensorflow::Status::OK();
}
// There may be an Add or a Maximum here, adding or clamping to a "small"
@@ -105,7 +108,7 @@ bool IdentifyL2Normalization::Run(Model* model, std::size_t op_index) {
" because the operator producing the input to the square root, %s,"
", does not match the expected pattern",
LogName(*op_producing_sqrt_or_rsqrt_input));
- return false;
+ return ::tensorflow::Status::OK();
}
}
@@ -116,7 +119,7 @@ bool IdentifyL2Normalization::Run(Model* model, std::size_t op_index) {
"Giving up trying to identify L2Normalization subgraph: "
"expected Sum op, got %s",
LogName(*sum_op));
- return false;
+ return ::tensorflow::Status::OK();
}
Operator* square_op = GetOpWithOutput(*model, sum_op->inputs[0]);
@@ -125,7 +128,7 @@ bool IdentifyL2Normalization::Run(Model* model, std::size_t op_index) {
"Giving up trying to identify L2Normalization subgraph: "
"expected Square op, got %s",
LogName(*square_op));
- return false;
+ return ::tensorflow::Status::OK();
}
CHECK_EQ(square_op->inputs.size(), 1);
@@ -135,7 +138,7 @@ bool IdentifyL2Normalization::Run(Model* model, std::size_t op_index) {
"Giving up trying to identify L2Normalization subgraph: %s does not "
"take the same input as the Mul/Div node",
LogName(*square_op));
- return false;
+ return ::tensorflow::Status::OK();
}
// Create and emplace the new L2Normalization
@@ -162,7 +165,8 @@ bool IdentifyL2Normalization::Run(Model* model, std::size_t op_index) {
model->operators.erase(FindOperator(model, sqrt_or_rsqrt_op));
model->EraseArray(div_or_mul_op->inputs[1]);
model->operators.erase(FindOperator(model, div_or_mul_op));
- return true;
+ *modified = true;
+ return ::tensorflow::Status::OK();
}
} // namespace toco
diff --git a/tensorflow/contrib/lite/toco/graph_transformations/identify_l2_pool.cc b/tensorflow/contrib/lite/toco/graph_transformations/identify_l2_pool.cc
index 705e73779b..13664bb344 100644
--- a/tensorflow/contrib/lite/toco/graph_transformations/identify_l2_pool.cc
+++ b/tensorflow/contrib/lite/toco/graph_transformations/identify_l2_pool.cc
@@ -38,11 +38,13 @@ std::vector<std::unique_ptr<Operator>>::iterator FindOperator(
}
} // namespace
-bool IdentifyL2Pool::Run(Model* model, std::size_t op_index) {
+::tensorflow::Status IdentifyL2Pool::Run(Model* model, std::size_t op_index,
+ bool* modified) {
+ *modified = false;
const auto sqrt_it = model->operators.begin() + op_index;
const auto* sqrt_op = sqrt_it->get();
if (sqrt_op->type != OperatorType::kSqrt) {
- return false;
+ return ::tensorflow::Status::OK();
}
CHECK_EQ(sqrt_op->inputs.size(), 1);
@@ -56,7 +58,7 @@ bool IdentifyL2Pool::Run(Model* model, std::size_t op_index) {
AddMessageF(
"Giving up trying to identify L2Pool subgraph: "
"expected AveragePool op, but Sqrt op has no preceding op");
- return false;
+ return ::tensorflow::Status::OK();
}
if (prev_to_sqrt_op->type != OperatorType::kAveragePool) {
@@ -64,7 +66,7 @@ bool IdentifyL2Pool::Run(Model* model, std::size_t op_index) {
"Giving up trying to identify L2Pool subgraph: "
"expected AveragePool op, got %s",
LogName(*prev_to_sqrt_op));
- return false;
+ return ::tensorflow::Status::OK();
}
avpool_op = static_cast<const AveragePoolOperator*>(prev_to_sqrt_op);
@@ -77,7 +79,7 @@ bool IdentifyL2Pool::Run(Model* model, std::size_t op_index) {
"Giving up trying to identify L2Pool subgraph: "
"expected Square op, got %s",
LogName(*square_op));
- return false;
+ return ::tensorflow::Status::OK();
}
// Create and emplace L2Pool node.
@@ -107,7 +109,8 @@ bool IdentifyL2Pool::Run(Model* model, std::size_t op_index) {
model->operators.erase(FindOperator(model, avpool_op));
model->operators.erase(FindOperator(model, sqrt_op));
- return true;
+ *modified = true;
+ return ::tensorflow::Status::OK();
}
} // namespace toco
diff --git a/tensorflow/contrib/lite/toco/graph_transformations/identify_lstm.cc b/tensorflow/contrib/lite/toco/graph_transformations/identify_lstm.cc
index c0b014b45e..7fd8f906e2 100644
--- a/tensorflow/contrib/lite/toco/graph_transformations/identify_lstm.cc
+++ b/tensorflow/contrib/lite/toco/graph_transformations/identify_lstm.cc
@@ -132,7 +132,9 @@ bool MatchOperatorInputs(const Operator& op, const Model& model,
} // namespace
-bool IdentifyLstmCell::Run(Model* model, std::size_t op_index) {
+::tensorflow::Status IdentifyLstmCell::Run(Model* model, std::size_t op_index,
+ bool* modified) {
+ *modified = false;
// This LSTM cell identification method is not invariant to commutation of
// commutative operator inputs. For example, if input[0] and input[1] of the
// final output multiplication were swapped, this method would not identify it
@@ -143,13 +145,13 @@ bool IdentifyLstmCell::Run(Model* model, std::size_t op_index) {
auto op_it = model->operators.begin() + op_index;
Operator* final_output_mul = op_it->get();
if (final_output_mul->type != OperatorType::kMul) {
- return false;
+ return ::tensorflow::Status::OK();
}
Operator *state_output_tanh, *fc_output_sig;
if (!MatchOperatorInputs(*final_output_mul, *model, OperatorType::kTanh,
&state_output_tanh, OperatorType::kLogistic,
&fc_output_sig)) {
- return false;
+ return ::tensorflow::Status::OK();
}
// State output TanH
@@ -158,7 +160,7 @@ bool IdentifyLstmCell::Run(Model* model, std::size_t op_index) {
Operator* state_combine_add;
if (!MatchOperatorInputs(*state_output_tanh, *model, OperatorType::kAdd,
&state_combine_add)) {
- return false;
+ return ::tensorflow::Status::OK();
}
// State forget & remember addition
@@ -166,7 +168,7 @@ bool IdentifyLstmCell::Run(Model* model, std::size_t op_index) {
if (!MatchOperatorInputs(*state_combine_add, *model, OperatorType::kMul,
&state_forget_mul, OperatorType::kMul,
&state_remember_mul)) {
- return false;
+ return ::tensorflow::Status::OK();
}
const string prev_state = state_forget_mul->inputs[0];
@@ -175,7 +177,7 @@ bool IdentifyLstmCell::Run(Model* model, std::size_t op_index) {
if (!MatchOperatorInputs(*state_forget_mul, *model, OperatorType::kNone,
nullptr, OperatorType::kLogistic,
&state_forget_sig)) {
- return false;
+ return ::tensorflow::Status::OK();
}
// State remember gate
@@ -183,40 +185,40 @@ bool IdentifyLstmCell::Run(Model* model, std::size_t op_index) {
if (!MatchOperatorInputs(*state_remember_mul, *model, OperatorType::kLogistic,
&state_remember_sig, OperatorType::kTanh,
&state_info_tanh)) {
- return false;
+ return ::tensorflow::Status::OK();
}
// State remember "information" activation function
Operator* fc_output_split;
if (!MatchOperatorInputs(*state_info_tanh, *model, OperatorType::kSplit,
&fc_output_split)) {
- return false;
+ return ::tensorflow::Status::OK();
}
// State remember gate activation function
Operator* tmp;
if (!MatchOperatorInputs(*state_remember_sig, *model, OperatorType::kSplit,
&tmp) ||
(tmp != fc_output_split)) {
- return false;
+ return ::tensorflow::Status::OK();
}
// State forget gate activation function
if (!MatchOperatorInputs(*state_forget_sig, *model, OperatorType::kSplit,
&tmp) ||
(tmp != fc_output_split)) {
- return false;
+ return ::tensorflow::Status::OK();
}
// Fully connected output activation function
if (!MatchOperatorInputs(*fc_output_sig, *model, OperatorType::kSplit,
&tmp) ||
(tmp != fc_output_split)) {
- return false;
+ return ::tensorflow::Status::OK();
}
// Fully connected output split
Operator* fully_connected;
if (!MatchOperatorInputs(*fc_output_split, *model, OperatorType::kNone,
nullptr, OperatorType::kFullyConnected,
&fully_connected)) {
- return false;
+ return ::tensorflow::Status::OK();
}
// Fully connected op
@@ -225,13 +227,13 @@ bool IdentifyLstmCell::Run(Model* model, std::size_t op_index) {
OperatorType::kConcatenation, &concat_inputs,
OperatorType::kNone, nullptr, OperatorType::kNone,
nullptr)) {
- return false;
+ return ::tensorflow::Status::OK();
}
if (static_cast<FullyConnectedOperator*>(fully_connected)->weights_format !=
FullyConnectedWeightsFormat::kDefault) {
// Not yet implemented: experimental shuffled weights in fused LSTM cell.
- return false;
+ return ::tensorflow::Status::OK();
}
// Emplace a new LSTM cell operator
@@ -300,7 +302,8 @@ bool IdentifyLstmCell::Run(Model* model, std::size_t op_index) {
model->operators.erase(FindOperator(model, *fully_connected));
DeleteArrayIfUnused(concat_inputs->outputs[0], model);
model->operators.erase(FindOperator(model, *concat_inputs));
- return true;
+ *modified = true;
+ return ::tensorflow::Status::OK();
}
} // namespace toco
diff --git a/tensorflow/contrib/lite/toco/graph_transformations/identify_lstm_merge_inputs.cc b/tensorflow/contrib/lite/toco/graph_transformations/identify_lstm_merge_inputs.cc
index 5b6a984ee1..6ccce923f3 100644
--- a/tensorflow/contrib/lite/toco/graph_transformations/identify_lstm_merge_inputs.cc
+++ b/tensorflow/contrib/lite/toco/graph_transformations/identify_lstm_merge_inputs.cc
@@ -25,19 +25,22 @@ limitations under the License.
namespace toco {
-bool MergeLstmCellInputs::Run(Model* model, std::size_t op_index) {
+::tensorflow::Status MergeLstmCellInputs::Run(Model* model,
+ std::size_t op_index,
+ bool* modified) {
+ *modified = false;
// Find lstm cell.
auto op_it = model->operators.begin() + op_index;
auto src_op = op_it->get();
if (src_op->type != OperatorType::kLstmCell) {
- return false;
+ return ::tensorflow::Status::OK();
}
// Already a compact LstmCell. Do not need to merge cell inputs.
const auto* src_lstm_op = static_cast<LstmCellOperator*>(src_op);
if (src_lstm_op->kernel_type != LstmCellOperator::KERNEL_FULL ||
src_lstm_op->inputs.size() != kExtendedLstmInputCount) {
- return false;
+ return ::tensorflow::Status::OK();
}
// Identify prev_activ_input, prev_state_input as required Op inputs,
@@ -45,12 +48,12 @@ bool MergeLstmCellInputs::Run(Model* model, std::size_t op_index) {
string prev_activ_input;
if (!GetMatchingRnnArray(model, src_op->outputs[kOutputTensor],
&prev_activ_input)) {
- return false;
+ return ::tensorflow::Status::OK();
}
string prev_state_input;
if (!GetMatchingRnnArray(model, src_op->outputs[kCellStateTensor],
&prev_state_input)) {
- return false;
+ return ::tensorflow::Status::OK();
}
// Get LstmCell's cell, input, output size.
@@ -184,7 +187,8 @@ bool MergeLstmCellInputs::Run(Model* model, std::size_t op_index) {
DeleteArrayIfUnused(src_op->inputs[kOutputGateBiasTensor], model);
model->operators.erase(FindOp(*model, src_op));
- return true;
+ *modified = true;
+ return ::tensorflow::Status::OK();
}
} // namespace toco
diff --git a/tensorflow/contrib/lite/toco/graph_transformations/identify_lstm_split_inputs.cc b/tensorflow/contrib/lite/toco/graph_transformations/identify_lstm_split_inputs.cc
index 46d1fce50e..ad5120e2aa 100644
--- a/tensorflow/contrib/lite/toco/graph_transformations/identify_lstm_split_inputs.cc
+++ b/tensorflow/contrib/lite/toco/graph_transformations/identify_lstm_split_inputs.cc
@@ -25,19 +25,22 @@ limitations under the License.
namespace toco {
-bool SplitLstmCellInputs::Run(Model* model, std::size_t op_index) {
+::tensorflow::Status SplitLstmCellInputs::Run(Model* model,
+ std::size_t op_index,
+ bool* modified) {
+ *modified = false;
// Find lstm cell.
auto op_it = model->operators.begin() + op_index;
auto curr_op = op_it->get();
if (curr_op->type != OperatorType::kLstmCell) {
- return false;
+ return ::tensorflow::Status::OK();
}
const auto* curr_lstm_op = static_cast<LstmCellOperator*>(curr_op);
// Already an extended LstmCell. Do not need to split cell inputs.
if (curr_lstm_op->kernel_type != LstmCellOperator::KERNEL_BASIC ||
curr_lstm_op->inputs.size() != LstmCellOperator::NUM_INPUTS) {
- return false;
+ return ::tensorflow::Status::OK();
}
// Make sure the WEIGHTS_INPUT and BIASES_INPUT are constant arrays,
@@ -46,13 +49,13 @@ bool SplitLstmCellInputs::Run(Model* model, std::size_t op_index) {
*model, curr_op->inputs[LstmCellOperator::WEIGHTS_INPUT]) ||
!IsConstantParameterArray(
*model, curr_op->inputs[LstmCellOperator::BIASES_INPUT])) {
- return false;
+ return ::tensorflow::Status::OK();
}
// Make sure propagate_fixed_sizes has defined the size of the output.
if (!model->GetArray(curr_op->outputs[LstmCellOperator::ACTIV_OUTPUT])
.has_shape()) {
- return false;
+ return ::tensorflow::Status::OK();
}
// Emplace a new LstmCell operator with extended inputs (kernel/lstm.cc).
@@ -168,7 +171,8 @@ bool SplitLstmCellInputs::Run(Model* model, std::size_t op_index) {
DeleteArrayIfUnused(curr_op->inputs[LstmCellOperator::BIASES_INPUT], model);
model->operators.erase(FindOp(*model, curr_op));
- return true;
+ *modified = true;
+ return ::tensorflow::Status::OK();
}
} // namespace toco
diff --git a/tensorflow/contrib/lite/toco/graph_transformations/identify_prelu.cc b/tensorflow/contrib/lite/toco/graph_transformations/identify_prelu.cc
index b90a156a0d..c11fee4dc9 100644
--- a/tensorflow/contrib/lite/toco/graph_transformations/identify_prelu.cc
+++ b/tensorflow/contrib/lite/toco/graph_transformations/identify_prelu.cc
@@ -43,13 +43,15 @@ limitations under the License.
namespace toco {
-bool IdentifyPRelu::Run(Model* model, std::size_t op_index) {
+::tensorflow::Status IdentifyPRelu::Run(Model* model, std::size_t op_index,
+ bool* modified) {
+ *modified = false;
const auto add_op_it = model->operators.begin() + op_index;
const auto* add_op = add_op_it->get();
if (add_op == nullptr || add_op->type != OperatorType::kAdd ||
add_op->inputs.size() != 2 ||
add_op->fused_activation_function != FusedActivationFunctionType::kNone) {
- return false;
+ return ::tensorflow::Status::OK();
}
const auto* relu_input_op = GetOpWithOutput(*model, add_op->inputs[0]);
@@ -57,7 +59,7 @@ bool IdentifyPRelu::Run(Model* model, std::size_t op_index) {
relu_input_op->inputs.size() != 1 ||
relu_input_op->fused_activation_function !=
FusedActivationFunctionType::kNone) {
- return false;
+ return ::tensorflow::Status::OK();
}
// TODO(ycling): Both Add and Mul are commutative. Support the case where
@@ -66,7 +68,7 @@ bool IdentifyPRelu::Run(Model* model, std::size_t op_index) {
if (mul_op == nullptr || mul_op->type != OperatorType::kMul ||
mul_op->inputs.size() != 2 ||
mul_op->fused_activation_function != FusedActivationFunctionType::kNone) {
- return false;
+ return ::tensorflow::Status::OK();
}
const auto neg_alpha_tensor_name = mul_op->inputs[0];
@@ -75,7 +77,7 @@ bool IdentifyPRelu::Run(Model* model, std::size_t op_index) {
if (relu_neg_input_op == nullptr ||
relu_neg_input_op->inputs.size() != 1) {
- return false;
+ return ::tensorflow::Status::OK();
}
const Operator* final_input_op;
@@ -92,13 +94,13 @@ bool IdentifyPRelu::Run(Model* model, std::size_t op_index) {
relu_neg_input_op->type != OperatorType::kRelu ||
relu_neg_input_op->fused_activation_function !=
FusedActivationFunctionType::kNone) {
- return false;
+ return ::tensorflow::Status::OK();
}
final_input_op = neg_input_op;
}
if (relu_input_op->inputs[0] != final_input_op->inputs[0]) {
- return false;
+ return ::tensorflow::Status::OK();
}
const auto input_tensor_name = relu_input_op->inputs[0];
@@ -128,7 +130,8 @@ bool IdentifyPRelu::Run(Model* model, std::size_t op_index) {
// intermediate tensors aren't used by other ops, those will be removed by
// other graph transformation rules.
model->operators.erase(FindOp(*model, add_op));
- return true;
+ *modified = true;
+ return ::tensorflow::Status::OK();
}
} // namespace toco
diff --git a/tensorflow/contrib/lite/toco/graph_transformations/identify_relu1.cc b/tensorflow/contrib/lite/toco/graph_transformations/identify_relu1.cc
index 94820a0166..51d0629362 100644
--- a/tensorflow/contrib/lite/toco/graph_transformations/identify_relu1.cc
+++ b/tensorflow/contrib/lite/toco/graph_transformations/identify_relu1.cc
@@ -56,13 +56,15 @@ int GetSingleScalarInputIndexOfBinaryOp(Model* model, const Operator* op,
}
} // namespace
-bool IdentifyRelu1::Run(Model* model, std::size_t op_index) {
+::tensorflow::Status IdentifyRelu1::Run(Model* model, std::size_t op_index,
+ bool* modified) {
+ *modified = false;
// Follow sequences of min+max and max+min. First get the leading op.
const auto op_it = model->operators.begin() + op_index;
const auto* op_0 = op_it->get();
if (op_0->type != OperatorType::kMinimum &&
op_0->type != OperatorType::kMaximum) {
- return false;
+ return ::tensorflow::Status::OK();
}
// Get the paired op and ensure it's the counter to the first.
@@ -71,17 +73,17 @@ bool IdentifyRelu1::Run(Model* model, std::size_t op_index) {
(op_1->type != OperatorType::kMinimum &&
op_1->type != OperatorType::kMaximum) ||
op_0->type == op_1->type) {
- return false;
+ return ::tensorflow::Status::OK();
}
const auto* min_op = op_0->type == OperatorType::kMinimum ? op_0 : op_1;
const auto* max_op = op_0->type == OperatorType::kMaximum ? op_0 : op_1;
if (min_op->inputs.size() != 2 || max_op->inputs.size() != 2) {
- return false;
+ return ::tensorflow::Status::OK();
}
if (min_op->outputs.size() != 1 || max_op->outputs.size() != 1) {
- return false;
+ return ::tensorflow::Status::OK();
}
// Get the original input to the min+max pair.
@@ -90,7 +92,7 @@ bool IdentifyRelu1::Run(Model* model, std::size_t op_index) {
int max_scalar_input_index =
GetSingleScalarInputIndexOfBinaryOp(model, max_op, -1.0f);
if (min_scalar_input_index == -1 || max_scalar_input_index == -1) {
- return false;
+ return ::tensorflow::Status::OK();
}
int op_0_scalar_input_index =
op_0 == min_op ? min_scalar_input_index : max_scalar_input_index;
@@ -111,7 +113,8 @@ bool IdentifyRelu1::Run(Model* model, std::size_t op_index) {
model->operators.erase(FindOperator(model, op_0));
model->operators.erase(FindOperator(model, op_1));
- return true;
+ *modified = true;
+ return ::tensorflow::Status::OK();
}
} // namespace toco
diff --git a/tensorflow/contrib/lite/toco/graph_transformations/make_initial_dequantize_operator.cc b/tensorflow/contrib/lite/toco/graph_transformations/make_initial_dequantize_operator.cc
index f684de08ab..5bf17d5b4c 100644
--- a/tensorflow/contrib/lite/toco/graph_transformations/make_initial_dequantize_operator.cc
+++ b/tensorflow/contrib/lite/toco/graph_transformations/make_initial_dequantize_operator.cc
@@ -97,7 +97,10 @@ bool AddDequantizeOperatorToInput(const string& input_name, const Operator* op,
return true;
}
-bool MakeInitialDequantizeOperator::Run(Model* model, std::size_t op_index) {
+::tensorflow::Status MakeInitialDequantizeOperator::Run(Model* model,
+ std::size_t op_index,
+ bool* modified) {
+ *modified = false;
// This is effectively a transformation applied to edges. We iterate over the
// specified node (op) and proceed for input edges.
const auto it = model->operators.begin() + op_index;
@@ -114,7 +117,8 @@ bool MakeInitialDequantizeOperator::Run(Model* model, std::size_t op_index) {
}
}
}
- return change_made;
+ *modified = change_made;
+ return ::tensorflow::Status::OK();
}
} // namespace toco
diff --git a/tensorflow/contrib/lite/toco/graph_transformations/merge_reshape_into_preceding_transpose.cc b/tensorflow/contrib/lite/toco/graph_transformations/merge_reshape_into_preceding_transpose.cc
index 95bc7f7d4b..06de9b1cd8 100644
--- a/tensorflow/contrib/lite/toco/graph_transformations/merge_reshape_into_preceding_transpose.cc
+++ b/tensorflow/contrib/lite/toco/graph_transformations/merge_reshape_into_preceding_transpose.cc
@@ -102,18 +102,19 @@ std::vector<int32> ReshapeToTranspose(const Model& model,
// to be merged if the reshape does not affect memory ordering and does not
// affects the number of dimensions. This only occurs when only unary dimensions
// are shifting position.
-bool MergeReshapeIntoPrecedingTranspose::Run(Model* model,
- std::size_t op_index) {
+::tensorflow::Status MergeReshapeIntoPrecedingTranspose::Run(
+ Model* model, std::size_t op_index, bool* modified) {
+ *modified = false;
auto it = model->operators.begin() + op_index;
auto* reshape_op = ConvertOperator<TensorFlowReshapeOperator*>(
it->get(), OperatorType::kReshape);
if (reshape_op == nullptr) {
- return false;
+ return ::tensorflow::Status::OK();
}
if (!OperatorReady(*model, reshape_op) || reshape_op->shape.empty()) {
- return false;
+ return ::tensorflow::Status::OK();
}
const string intermediate_name = reshape_op->inputs[0];
@@ -121,13 +122,13 @@ bool MergeReshapeIntoPrecedingTranspose::Run(Model* model,
// Guarantee the input is only consume by the reshape.
if (CountOpsWithInput(*model, intermediate_name) != 1) {
- return false;
+ return ::tensorflow::Status::OK();
}
// Check for the parent operator.
const auto& transpose_it = FindOpWithOutput(*model, intermediate_name);
if (transpose_it == model->operators.end()) {
- return false;
+ return ::tensorflow::Status::OK();
}
// Find the parent operator and guarantee it is a transpose.
@@ -135,16 +136,16 @@ bool MergeReshapeIntoPrecedingTranspose::Run(Model* model,
transpose_it->get(), OperatorType::kTranspose);
if (transpose_op == nullptr) {
- return false;
+ return ::tensorflow::Status::OK();
}
if (!OperatorReady(*model, transpose_op) || transpose_op->perm.empty()) {
- return false;
+ return ::tensorflow::Status::OK();
}
if (!ReshapeIsEquivalentToTranspose(*model, reshape_op,
false /*allow_extra_unary_dimensions*/)) {
- return false;
+ return ::tensorflow::Status::OK();
}
// Check that the intermediate is not an output array.
@@ -153,7 +154,7 @@ bool MergeReshapeIntoPrecedingTranspose::Run(Model* model,
"Cannot fuse %s and %s as it would invalidate the transpose "
"output array.",
LogName(*transpose_op), LogName(*reshape_op));
- return false;
+ return ::tensorflow::Status::OK();
}
AddMessageF("Merging operations %s and %s", LogName(*transpose_op),
@@ -172,7 +173,7 @@ bool MergeReshapeIntoPrecedingTranspose::Run(Model* model,
// Remove the reshape as passthrough operation.
if (!RemoveTrivialPassthroughOp(this, model, op_index)) {
- return false;
+ return ::tensorflow::Status::OK();
}
// Update transpose_op's constant buffer to contain the new permutation.
@@ -184,7 +185,8 @@ bool MergeReshapeIntoPrecedingTranspose::Run(Model* model,
// transpose_ops's shape will likely has changed.
model->GetArray(transpose_op->outputs[0]).clear_shape();
- return true;
+ *modified = true;
+ return ::tensorflow::Status::OK();
}
} // namespace toco
diff --git a/tensorflow/contrib/lite/toco/graph_transformations/move_binary_operator_before_reshape.cc b/tensorflow/contrib/lite/toco/graph_transformations/move_binary_operator_before_reshape.cc
index 7f44c65285..f0d8d924ad 100644
--- a/tensorflow/contrib/lite/toco/graph_transformations/move_binary_operator_before_reshape.cc
+++ b/tensorflow/contrib/lite/toco/graph_transformations/move_binary_operator_before_reshape.cc
@@ -54,7 +54,10 @@ bool IsTailOfShape(const Shape& tail, const Shape& shape) {
//
// Note we are testing for one particular case of a broader set of possible
// binary-reshape op transformations. This transformation could be generalized.
-bool MoveBinaryOperatorBeforeReshape::Run(Model* model, std::size_t op_index) {
+::tensorflow::Status MoveBinaryOperatorBeforeReshape::Run(Model* model,
+ std::size_t op_index,
+ bool* modified) {
+ *modified = false;
const auto binary_it = model->operators.begin() + op_index;
Operator* binary_op = binary_it->get();
if (binary_op->type != OperatorType::kAdd &&
@@ -69,7 +72,7 @@ bool MoveBinaryOperatorBeforeReshape::Run(Model* model, std::size_t op_index) {
binary_op->type != OperatorType::kLessEqual &&
binary_op->type != OperatorType::kGreater &&
binary_op->type != OperatorType::kGreaterEqual) {
- return false;
+ return ::tensorflow::Status::OK();
}
// BINARY OP INPUT CHECKS
@@ -81,11 +84,11 @@ bool MoveBinaryOperatorBeforeReshape::Run(Model* model, std::size_t op_index) {
if (!input_is_const[0] && !input_is_const[1]) {
// To limit our scope, we require one constant input. Though there's no
// reason this transformation wouldn't work with all variable inputs.
- return false;
+ return ::tensorflow::Status::OK();
}
if (input_is_const[0] && input_is_const[1]) {
// Both inputs are constants. Leave this for constants propagation.
- return false;
+ return ::tensorflow::Status::OK();
}
const int constant_input_idx = input_is_const[0] ? 0 : 1;
const int variable_input_idx = input_is_const[0] ? 1 : 0;
@@ -98,13 +101,13 @@ bool MoveBinaryOperatorBeforeReshape::Run(Model* model, std::size_t op_index) {
AddMessageF(
"Not moving %s because it's non-constant input shape is not resolved.",
LogName(*binary_op));
- return false;
+ return ::tensorflow::Status::OK();
}
if (!IsTailOfShape(
model->GetArray(binary_op->inputs[constant_input_idx]).shape(),
model->GetArray(binary_op->inputs[variable_input_idx]).shape())) {
// Constant array shape must be the latter part of the variable shape.
- return false;
+ return ::tensorflow::Status::OK();
}
// RESHAPE OP CHECKS
@@ -113,13 +116,13 @@ bool MoveBinaryOperatorBeforeReshape::Run(Model* model, std::size_t op_index) {
if (reshape_it == model->operators.end()) {
AddMessageF("Not moving %s because it's variable input is not connected.",
LogName(*binary_op));
- return false;
+ return ::tensorflow::Status::OK();
}
Operator* reshape_op = reshape_it->get();
if (reshape_op->type != OperatorType::kReshape) {
AddMessageF("Not moving %s because the preceding %s is not a reshape op",
LogName(*binary_op), LogName(*reshape_op));
- return false;
+ return ::tensorflow::Status::OK();
}
const auto& reshape_input_array = model->GetArray(reshape_op->inputs[0]);
if (!reshape_input_array.has_shape()) {
@@ -127,14 +130,14 @@ bool MoveBinaryOperatorBeforeReshape::Run(Model* model, std::size_t op_index) {
"Not moving %s because it's non-constant input shape is not resolved "
"yet",
LogName(*binary_op));
- return false;
+ return ::tensorflow::Status::OK();
}
if (!IsTailOfShape(
model->GetArray(binary_op->inputs[constant_input_idx]).shape(),
model->GetArray(reshape_op->outputs[0]).shape())) {
// Constant array shape must be the latter part of the binary op output
// shape.
- return false;
+ return ::tensorflow::Status::OK();
}
// EXTRA CHECKS ON CONNECTING ARRAY
@@ -143,7 +146,7 @@ bool MoveBinaryOperatorBeforeReshape::Run(Model* model, std::size_t op_index) {
AddMessageF(
"Not moving %s because the output of reshape op %s is an output op.",
LogName(*binary_op), LogName(*reshape_op));
- return false;
+ return ::tensorflow::Status::OK();
}
}
int count_ops_consuming_output =
@@ -154,7 +157,7 @@ bool MoveBinaryOperatorBeforeReshape::Run(Model* model, std::size_t op_index) {
"Not moving %s because the output of reshape op %s is consumed by "
"another op",
LogName(*binary_op), LogName(*reshape_op));
- return false;
+ return ::tensorflow::Status::OK();
}
// SWAP ORDER OF BINARY AND RESHAPE OPS
@@ -172,7 +175,8 @@ bool MoveBinaryOperatorBeforeReshape::Run(Model* model, std::size_t op_index) {
// Clear binary output shape so it will be re-propagated
model->GetArray(binary_op->outputs[0]).clear_shape();
- return true;
+ *modified = true;
+ return ::tensorflow::Status::OK();
}
} // namespace toco
diff --git a/tensorflow/contrib/lite/toco/graph_transformations/propagate_activation_function_into_constants.cc b/tensorflow/contrib/lite/toco/graph_transformations/propagate_activation_function_into_constants.cc
index cf17c49b10..9c1ed2b732 100644
--- a/tensorflow/contrib/lite/toco/graph_transformations/propagate_activation_function_into_constants.cc
+++ b/tensorflow/contrib/lite/toco/graph_transformations/propagate_activation_function_into_constants.cc
@@ -26,20 +26,21 @@ limitations under the License.
namespace toco {
-bool PropagateActivationFunctionIntoConstants::Run(Model* model,
- std::size_t op_index) {
+::tensorflow::Status PropagateActivationFunctionIntoConstants::Run(
+ Model* model, std::size_t op_index, bool* modified) {
+ *modified = false;
const auto ac_it = model->operators.begin() + op_index;
const auto* ac_op = ac_it->get();
if (ac_op->type != OperatorType::kRelu6 &&
ac_op->type != OperatorType::kRelu1 &&
ac_op->type != OperatorType::kRelu) {
- return false;
+ return ::tensorflow::Status::OK();
}
// Find the op producing the array passed to this activation function.
auto* src_op = GetOpWithOutput(*model, ac_op->inputs[0]);
if (!src_op) {
- return false;
+ return ::tensorflow::Status::OK();
}
// Ensure the src_op is not used without the activation function applied.
@@ -57,7 +58,7 @@ bool PropagateActivationFunctionIntoConstants::Run(Model* model,
src_op_input = src_op->inputs[0];
break;
default:
- return false;
+ return ::tensorflow::Status::OK();
}
CHECK_EQ(src_op->outputs[0], ac_op->inputs[0]);
@@ -69,7 +70,7 @@ bool PropagateActivationFunctionIntoConstants::Run(Model* model,
"Not propagating activation function %s into %s:%s because it is not "
"constant",
LogName(*ac_op), LogName(*src_op), src_op_input);
- return false;
+ return ::tensorflow::Status::OK();
}
// Get the array we'll be working with and ensure it's a compatible type.
@@ -79,7 +80,7 @@ bool PropagateActivationFunctionIntoConstants::Run(Model* model,
"Not propagating activation function %s into %s:%s because it is "
"non-float data",
LogName(*ac_op), LogName(*src_op), src_op_input);
- return false;
+ return ::tensorflow::Status::OK();
}
auto& const_array_data =
const_array.GetMutableBuffer<ArrayDataType::kFloat>().data;
@@ -108,14 +109,15 @@ bool PropagateActivationFunctionIntoConstants::Run(Model* model,
}
default:
LOG(FATAL) << "Unsupported activation function " << LogName(*ac_op);
- return false;
+ return ::tensorflow::Status::OK();
}
const_array_data[i] = new_value;
}
AddMessageF("Propagated activation function %s into %s:%s", LogName(*ac_op),
LogName(*src_op), src_op_input);
- return RemoveTrivialPassthroughOp(this, model, op_index);
+ *modified = RemoveTrivialPassthroughOp(this, model, op_index);
+ return ::tensorflow::Status::OK();
}
} // namespace toco
diff --git a/tensorflow/contrib/lite/toco/graph_transformations/propagate_array_data_types.cc b/tensorflow/contrib/lite/toco/graph_transformations/propagate_array_data_types.cc
index 323eefcd3a..40cd6dea82 100644
--- a/tensorflow/contrib/lite/toco/graph_transformations/propagate_array_data_types.cc
+++ b/tensorflow/contrib/lite/toco/graph_transformations/propagate_array_data_types.cc
@@ -32,7 +32,10 @@ void SetDataTypeForAllOutputs(Model* model, Operator* op,
}
} // namespace
-bool PropagateArrayDataTypes::Run(Model* model, std::size_t op_index) {
+::tensorflow::Status PropagateArrayDataTypes::Run(Model* model,
+ std::size_t op_index,
+ bool* modified) {
+ *modified = false;
auto it = model->operators.begin() + op_index;
auto* op = it->get();
@@ -40,7 +43,7 @@ bool PropagateArrayDataTypes::Run(Model* model, std::size_t op_index) {
for (const auto& input : op->inputs) {
if (!model->IsOptionalArray(input) &&
model->GetArray(input).data_type == ArrayDataType::kNone) {
- return false;
+ return ::tensorflow::Status::OK();
}
}
// Record data types of output before processing, so we can see at the
@@ -131,7 +134,7 @@ bool PropagateArrayDataTypes::Run(Model* model, std::size_t op_index) {
auto* rand_op = static_cast<RandomUniformOperator*>(op);
// The output type of RandomUniform is specified with an attribute
if (rand_op->dtype == ArrayDataType::kNone) {
- return false;
+ return ::tensorflow::Status::OK();
}
CHECK_EQ(op->outputs.size(), 1);
SetDataTypeForAllOutputs(model, op, rand_op->dtype);
@@ -153,7 +156,7 @@ bool PropagateArrayDataTypes::Run(Model* model, std::size_t op_index) {
// This can make unsupported_op->output_data_types have more elements than
// op->outputs.
if (unsupported_op->output_data_types.size() < op->outputs.size()) {
- return false;
+ return ::tensorflow::Status::OK();
}
for (int i = 0; i < op->outputs.size(); ++i) {
const string& output = op->outputs[i];
@@ -164,7 +167,7 @@ bool PropagateArrayDataTypes::Run(Model* model, std::size_t op_index) {
}
case OperatorType::kExpandDims: {
// Yield on ExpandDim until it is converted to Reshape
- return false;
+ return ::tensorflow::Status::OK();
}
case OperatorType::kSelect: {
// Select produces outputs with the same type as their 2nd input
@@ -248,10 +251,11 @@ bool PropagateArrayDataTypes::Run(Model* model, std::size_t op_index) {
// Return true if any output data type changed, false if none changed.
for (const auto& output : op->outputs) {
if (old_output_data_types[output] != model->GetArray(output).data_type) {
- return true;
+ *modified = true;
+ return ::tensorflow::Status::OK();
}
}
- return false;
+ return ::tensorflow::Status::OK();
}
} // namespace toco
diff --git a/tensorflow/contrib/lite/toco/graph_transformations/propagate_default_min_max.cc b/tensorflow/contrib/lite/toco/graph_transformations/propagate_default_min_max.cc
index cd078ef189..3cf191436d 100644
--- a/tensorflow/contrib/lite/toco/graph_transformations/propagate_default_min_max.cc
+++ b/tensorflow/contrib/lite/toco/graph_transformations/propagate_default_min_max.cc
@@ -39,7 +39,10 @@ bool SupportsMinMax(const Array& array) {
// When provided a set of min/max values for uint8 arrays this will rescale
// the values for other data types as required and preserving the floating point
// range within the new type.
-bool PropagateDefaultMinMax::Run(Model* model, std::size_t op_index) {
+::tensorflow::Status PropagateDefaultMinMax::Run(Model* model,
+ std::size_t op_index,
+ bool* modified) {
+ *modified = false;
const auto it = model->operators.begin() + op_index;
const auto* op = it->get();
@@ -61,7 +64,8 @@ bool PropagateDefaultMinMax::Run(Model* model, std::size_t op_index) {
}
}
- return did_change;
+ *modified = did_change;
+ return ::tensorflow::Status::OK();
}
// Sets the min/max on the given array, adjusting the reference_minmax for the
diff --git a/tensorflow/contrib/lite/toco/graph_transformations/propagate_fake_quant_num_bits.cc b/tensorflow/contrib/lite/toco/graph_transformations/propagate_fake_quant_num_bits.cc
index 3ad6b0ec6f..d0113237ce 100644
--- a/tensorflow/contrib/lite/toco/graph_transformations/propagate_fake_quant_num_bits.cc
+++ b/tensorflow/contrib/lite/toco/graph_transformations/propagate_fake_quant_num_bits.cc
@@ -277,11 +277,14 @@ bool RecursivelyForwardPropagateDataType(GraphTransformation* transformation,
// nice logging and integration with the graphviz video dumping mode.
// In general you should not copy this style of transformation and stick to
// local-only changes as seen in the other transformations.
-bool PropagateFakeQuantNumBits::Run(Model* model, std::size_t op_index) {
+::tensorflow::Status PropagateFakeQuantNumBits::Run(Model* model,
+ std::size_t op_index,
+ bool* modified) {
+ *modified = false;
auto it = model->operators.begin() + op_index;
auto* op = it->get();
if (op->type != OperatorType::kFakeQuant) {
- return false;
+ return ::tensorflow::Status::OK();
}
auto* fakequant_op = static_cast<FakeQuantOperator*>(op);
@@ -290,7 +293,7 @@ bool PropagateFakeQuantNumBits::Run(Model* model, std::size_t op_index) {
&quantized_data_type)) {
AddMessageF("FakeQuant op %s num_bits=%d is out of range, ignoring",
LogName(*op), fakequant_op->num_bits);
- return false;
+ return ::tensorflow::Status::OK();
}
const auto& final_minmax = *fakequant_op->minmax;
@@ -311,7 +314,8 @@ bool PropagateFakeQuantNumBits::Run(Model* model, std::size_t op_index) {
did_change |=
RecursivelyForwardPropagateDataType(this, model, op, quantized_data_type);
- return did_change;
+ *modified = did_change;
+ return ::tensorflow::Status::OK();
}
} // namespace toco
diff --git a/tensorflow/contrib/lite/toco/graph_transformations/propagate_fixed_sizes.cc b/tensorflow/contrib/lite/toco/graph_transformations/propagate_fixed_sizes.cc
index d056a8add7..5496e2093e 100644
--- a/tensorflow/contrib/lite/toco/graph_transformations/propagate_fixed_sizes.cc
+++ b/tensorflow/contrib/lite/toco/graph_transformations/propagate_fixed_sizes.cc
@@ -1622,7 +1622,10 @@ void ProcessUnpackOperator(Model* model, UnpackOperator* op) {
} // namespace
-bool PropagateFixedSizes::Run(Model* model, std::size_t op_index) {
+::tensorflow::Status PropagateFixedSizes::Run(Model* model,
+ std::size_t op_index,
+ bool* modified) {
+ *modified = false;
auto it = model->operators.begin() + op_index;
auto* op = it->get();
std::unordered_map<string, std::vector<int>> old_output_dims;
@@ -1836,7 +1839,7 @@ bool PropagateFixedSizes::Run(Model* model, std::size_t op_index) {
static_cast<TensorFlowUnsupportedOperator*>(op);
// Attribute can be not specified, ignore it.
if (unsupported_op->output_shapes.size() < op->outputs.size()) {
- return false;
+ return ::tensorflow::Status::OK();
}
for (int i = 0; i < op->outputs.size(); ++i) {
const string& output = op->outputs[i];
@@ -1886,10 +1889,11 @@ bool PropagateFixedSizes::Run(Model* model, std::size_t op_index) {
(old_output_dims[output] != model->GetArray(output).shape().dims())) {
AddMessageF("Set shape of %s to [%s]", output,
absl::StrJoin(model->GetArray(output).shape().dims(), ","));
- return true;
+ *modified = true;
+ return ::tensorflow::Status::OK();
}
}
- return false;
+ return ::tensorflow::Status::OK();
}
} // namespace toco
diff --git a/tensorflow/contrib/lite/toco/graph_transformations/quantize.cc b/tensorflow/contrib/lite/toco/graph_transformations/quantize.cc
index fb299c31b7..29ea17dc61 100644
--- a/tensorflow/contrib/lite/toco/graph_transformations/quantize.cc
+++ b/tensorflow/contrib/lite/toco/graph_transformations/quantize.cc
@@ -439,7 +439,9 @@ void FixMinMaxPostQuantization(GraphTransformation* transformation,
} // namespace
-bool Quantize::Run(Model* model, std::size_t op_index) {
+::tensorflow::Status Quantize::Run(Model* model, std::size_t op_index,
+ bool* modified) {
+ *modified = false;
// Our general "quantization" graph transformation consists in replacing
// QuantizedInputArrays[] ->
// DequantizeOperators[] ->
@@ -460,7 +462,7 @@ bool Quantize::Run(Model* model, std::size_t op_index) {
auto& op = *model->operators[op_index];
if (op.type == OperatorType::kDequantize ||
op.type == OperatorType::kFakeQuant) {
- return false;
+ return ::tensorflow::Status::OK();
}
// Our assumption here is that the input arrays are already quantized -
@@ -497,7 +499,7 @@ bool Quantize::Run(Model* model, std::size_t op_index) {
if (!array.minmax && !array.buffer) {
LOG(ERROR) << "Can't quantize input array " << input
<< " because it lacks min/max info";
- return false;
+ return ::tensorflow::Status::OK();
}
const auto* other_op = GetOpWithOutput(*model, input);
if (other_op && other_op->type != OperatorType::kDequantize) {
@@ -507,7 +509,7 @@ bool Quantize::Run(Model* model, std::size_t op_index) {
"which means that we should yield and let other ops "
"get quantized first",
LogName(op), input);
- return false;
+ return ::tensorflow::Status::OK();
}
}
}
@@ -672,7 +674,8 @@ bool Quantize::Run(Model* model, std::size_t op_index) {
}
}
- return changed;
+ *modified = changed;
+ return ::tensorflow::Status::OK();
}
} // namespace toco
diff --git a/tensorflow/contrib/lite/toco/graph_transformations/read_array_minmax_and_narrow_range_from_fake_quant.cc b/tensorflow/contrib/lite/toco/graph_transformations/read_array_minmax_and_narrow_range_from_fake_quant.cc
index eaa9d3bcda..0c32218ff2 100644
--- a/tensorflow/contrib/lite/toco/graph_transformations/read_array_minmax_and_narrow_range_from_fake_quant.cc
+++ b/tensorflow/contrib/lite/toco/graph_transformations/read_array_minmax_and_narrow_range_from_fake_quant.cc
@@ -51,18 +51,19 @@ bool ApplyAttrsToArray(GraphTransformation* transformation, Model* model,
} // end namespace
-bool ReadArrayMinmaxAndNarrowRangeFromFakeQuant::Run(Model* model,
- std::size_t op_index) {
+::tensorflow::Status ReadArrayMinmaxAndNarrowRangeFromFakeQuant::Run(
+ Model* model, std::size_t op_index, bool* modified) {
+ *modified = false;
const auto fakequant_it = model->operators.begin() + op_index;
auto* fakequant_base_op = fakequant_it->get();
if (fakequant_base_op->type != OperatorType::kFakeQuant) {
- return false;
+ return ::tensorflow::Status::OK();
}
auto* fq_op = static_cast<FakeQuantOperator*>(fakequant_base_op);
if (!fq_op->minmax) {
// Need to be resolved first by ResolveFakeQuantArgsFromVars.
- return false;
+ return ::tensorflow::Status::OK();
}
// At this point, this FakeQuantOperator should have a MinMax
@@ -74,7 +75,8 @@ bool ReadArrayMinmaxAndNarrowRangeFromFakeQuant::Run(Model* model,
bool changed = false;
changed |= ApplyAttrsToArray(this, model, *fq_op, fq_op->inputs[0]);
changed |= ApplyAttrsToArray(this, model, *fq_op, fq_op->outputs[0]);
- return changed;
+ *modified = changed;
+ return ::tensorflow::Status::OK();
}
} // namespace toco
diff --git a/tensorflow/contrib/lite/toco/graph_transformations/remove_final_dequantize_op.cc b/tensorflow/contrib/lite/toco/graph_transformations/remove_final_dequantize_op.cc
index c3b2709a33..fe8023ab8f 100644
--- a/tensorflow/contrib/lite/toco/graph_transformations/remove_final_dequantize_op.cc
+++ b/tensorflow/contrib/lite/toco/graph_transformations/remove_final_dequantize_op.cc
@@ -25,11 +25,14 @@ limitations under the License.
namespace toco {
-bool RemoveFinalDequantizeOp::Run(Model* model, std::size_t op_index) {
+::tensorflow::Status RemoveFinalDequantizeOp::Run(Model* model,
+ std::size_t op_index,
+ bool* modified) {
+ *modified = false;
const auto dequantize_it = model->operators.begin() + op_index;
const auto* dequantize_op = dequantize_it->get();
if (dequantize_op->type != OperatorType::kDequantize) {
- return false;
+ return ::tensorflow::Status::OK();
}
const auto& output = dequantize_op->outputs[0];
// We can remove any dequantize op whose output is not consumed by
@@ -38,7 +41,7 @@ bool RemoveFinalDequantizeOp::Run(Model* model, std::size_t op_index) {
// in the middle of the graph might be designated as an output
// array.
if (CountOpsWithInput(*model, output)) {
- return false;
+ return ::tensorflow::Status::OK();
}
// If one of the model's output arrays was actually the Dequantize op's
@@ -53,7 +56,8 @@ bool RemoveFinalDequantizeOp::Run(Model* model, std::size_t op_index) {
AddMessageF("Removed final %s", LogName(*dequantize_op));
model->EraseArray(output);
model->operators.erase(dequantize_it);
- return true;
+ *modified = true;
+ return ::tensorflow::Status::OK();
}
} // namespace toco
diff --git a/tensorflow/contrib/lite/toco/graph_transformations/remove_tensorflow_assert.cc b/tensorflow/contrib/lite/toco/graph_transformations/remove_tensorflow_assert.cc
index 73ad326299..be8c0acc7b 100644
--- a/tensorflow/contrib/lite/toco/graph_transformations/remove_tensorflow_assert.cc
+++ b/tensorflow/contrib/lite/toco/graph_transformations/remove_tensorflow_assert.cc
@@ -23,11 +23,14 @@ limitations under the License.
namespace toco {
-bool RemoveTensorFlowAssert::Run(Model* model, std::size_t op_index) {
+::tensorflow::Status RemoveTensorFlowAssert::Run(Model* model,
+ std::size_t op_index,
+ bool* modified) {
+ *modified = false;
const auto assert_it = model->operators.begin() + op_index;
const auto* assert_op = assert_it->get();
if (assert_op->type != OperatorType::kAssert) {
- return false;
+ return ::tensorflow::Status::OK();
}
bool changed = false;
@@ -54,7 +57,8 @@ bool RemoveTensorFlowAssert::Run(Model* model, std::size_t op_index) {
// That's it. We can stop here, no need to duplicate the work that
// RemoveUnusedOp will do removing this now-unused node.
- return changed;
+ *modified = changed;
+ return ::tensorflow::Status::OK();
}
} // namespace toco
diff --git a/tensorflow/contrib/lite/toco/graph_transformations/remove_tensorflow_identity.cc b/tensorflow/contrib/lite/toco/graph_transformations/remove_tensorflow_identity.cc
index 7ec7752f25..37fe5fa3d7 100644
--- a/tensorflow/contrib/lite/toco/graph_transformations/remove_tensorflow_identity.cc
+++ b/tensorflow/contrib/lite/toco/graph_transformations/remove_tensorflow_identity.cc
@@ -25,14 +25,18 @@ limitations under the License.
namespace toco {
-bool RemoveTensorFlowIdentity::Run(Model* model, std::size_t op_index) {
+::tensorflow::Status RemoveTensorFlowIdentity::Run(Model* model,
+ std::size_t op_index,
+ bool* modified) {
+ *modified = false;
const auto passthru_it = model->operators.begin() + op_index;
const auto* passthru_op = passthru_it->get();
if (passthru_op->type != OperatorType::kIdentity) {
- return false;
+ return ::tensorflow::Status::OK();
}
- return RemoveTrivialPassthroughOp(this, model, op_index);
+ *modified = RemoveTrivialPassthroughOp(this, model, op_index);
+ return ::tensorflow::Status::OK();
}
} // namespace toco
diff --git a/tensorflow/contrib/lite/toco/graph_transformations/remove_trivial_binary.cc b/tensorflow/contrib/lite/toco/graph_transformations/remove_trivial_binary.cc
index 0dfdc40e4c..68c6fb65c5 100644
--- a/tensorflow/contrib/lite/toco/graph_transformations/remove_trivial_binary.cc
+++ b/tensorflow/contrib/lite/toco/graph_transformations/remove_trivial_binary.cc
@@ -46,14 +46,17 @@ bool AreAllBufferElementsEqualTo(const std::vector<Scalar>& buffer_data,
// For example, an Add operator is trivial if
// one of its operands is constant 0, a Mul operator is trivial
// if one of its operands is constant 1, etc.
-bool RemoveTrivialBinaryOperator::Run(Model* model, std::size_t op_index) {
+::tensorflow::Status RemoveTrivialBinaryOperator::Run(Model* model,
+ std::size_t op_index,
+ bool* modified) {
+ *modified = false;
const auto binary_it = model->operators.begin() + op_index;
auto* binary_op = binary_it->get();
if (binary_op->type != OperatorType::kAdd &&
binary_op->type != OperatorType::kMul &&
binary_op->type != OperatorType::kSub &&
binary_op->type != OperatorType::kDiv) {
- return false;
+ return ::tensorflow::Status::OK();
}
CHECK_EQ(binary_op->inputs.size(), 2);
@@ -66,12 +69,12 @@ bool RemoveTrivialBinaryOperator::Run(Model* model, std::size_t op_index) {
};
if (!is_input_constant[0] && !is_input_constant[1]) {
// Neither input is constant, so nothing we can resolve here.
- return false;
+ return ::tensorflow::Status::OK();
}
if (is_input_constant[0] && is_input_constant[1]) {
// Both inputs are constants. That's a job for constants
// propagation, not for us to handle here.
- return false;
+ return ::tensorflow::Status::OK();
}
const int index_of_constant_input = is_input_constant[0] ? 0 : 1;
const int index_of_variable_input = is_input_constant[0] ? 1 : 0;
@@ -84,7 +87,7 @@ bool RemoveTrivialBinaryOperator::Run(Model* model, std::size_t op_index) {
const auto& input_array_1 = model->GetArray(binary_op->inputs[1]);
if (!input_array_0.has_shape() || !input_array_1.has_shape()) {
// Both input shapes must be known.
- return false;
+ return ::tensorflow::Status::OK();
}
if (input_array_0.shape().dimensions_count() ==
input_array_1.shape().dimensions_count() &&
@@ -94,7 +97,7 @@ bool RemoveTrivialBinaryOperator::Run(Model* model, std::size_t op_index) {
"(lhs %s, rhs %s)",
LogName(*binary_op), ShapeToString(input_array_0.shape()),
ShapeToString(input_array_1.shape()));
- return false;
+ return ::tensorflow::Status::OK();
}
// Now check if the constant operand makes this binary
@@ -103,7 +106,7 @@ bool RemoveTrivialBinaryOperator::Run(Model* model, std::size_t op_index) {
model->GetArray(binary_op->inputs[index_of_constant_input]);
// For now, we only handle floats here.
if (constant_input_array.data_type != ArrayDataType::kFloat) {
- return false;
+ return ::tensorflow::Status::OK();
}
const auto& constant_input_float_data =
constant_input_array.GetBuffer<ArrayDataType::kFloat>().data;
@@ -121,12 +124,13 @@ bool RemoveTrivialBinaryOperator::Run(Model* model, std::size_t op_index) {
}
if (!is_trivial) {
- return false;
+ return ::tensorflow::Status::OK();
}
// Now we know that this node is trivial, so we can remove it.
AddMessageF("Removing trivial %s", LogName(*binary_op));
- return RemoveTrivialPassthroughOp(this, model, op_index);
+ *modified = RemoveTrivialPassthroughOp(this, model, op_index);
+ return ::tensorflow::Status::OK();
}
} // namespace toco
diff --git a/tensorflow/contrib/lite/toco/graph_transformations/remove_trivial_concatenation.cc b/tensorflow/contrib/lite/toco/graph_transformations/remove_trivial_concatenation.cc
index 3ceb93d8ee..faaa2a828e 100644
--- a/tensorflow/contrib/lite/toco/graph_transformations/remove_trivial_concatenation.cc
+++ b/tensorflow/contrib/lite/toco/graph_transformations/remove_trivial_concatenation.cc
@@ -25,16 +25,20 @@ limitations under the License.
namespace toco {
-bool RemoveTrivialConcatenation::Run(Model* model, std::size_t op_index) {
+::tensorflow::Status RemoveTrivialConcatenation::Run(Model* model,
+ std::size_t op_index,
+ bool* modified) {
+ *modified = false;
const auto concat_it = model->operators.begin() + op_index;
auto* concat_op = concat_it->get();
if (concat_op->type != OperatorType::kConcatenation) {
- return false;
+ return ::tensorflow::Status::OK();
}
if (concat_op->inputs.size() != 1) {
- return false;
+ return ::tensorflow::Status::OK();
}
- return RemoveTrivialPassthroughOp(this, model, op_index);
+ *modified = RemoveTrivialPassthroughOp(this, model, op_index);
+ return ::tensorflow::Status::OK();
}
} // namespace toco
diff --git a/tensorflow/contrib/lite/toco/graph_transformations/remove_trivial_concatenation_input.cc b/tensorflow/contrib/lite/toco/graph_transformations/remove_trivial_concatenation_input.cc
index 936854a04f..ccfc181fe0 100644
--- a/tensorflow/contrib/lite/toco/graph_transformations/remove_trivial_concatenation_input.cc
+++ b/tensorflow/contrib/lite/toco/graph_transformations/remove_trivial_concatenation_input.cc
@@ -25,7 +25,10 @@ limitations under the License.
namespace toco {
-bool RemoveTrivialConcatenationInput::Run(Model* model, std::size_t op_index) {
+::tensorflow::Status RemoveTrivialConcatenationInput::Run(Model* model,
+ std::size_t op_index,
+ bool* modified) {
+ *modified = false;
// TensorFlow allows Concatenation nodes to have 0-D inputs,
// and they are then treated as empty i.e. omitted from concatenation,
// in violation of the notion that 0-D is equivalent to 1x1x1x1.
@@ -36,7 +39,7 @@ bool RemoveTrivialConcatenationInput::Run(Model* model, std::size_t op_index) {
const auto concat_it = model->operators.begin() + op_index;
auto* concat_op = concat_it->get();
if (concat_op->type != OperatorType::kConcatenation) {
- return false;
+ return ::tensorflow::Status::OK();
}
std::vector<string> trivial_inputs;
std::vector<string> nontrivial_inputs;
@@ -52,7 +55,7 @@ bool RemoveTrivialConcatenationInput::Run(Model* model, std::size_t op_index) {
}
if (trivial_inputs.empty()) {
- return false;
+ return ::tensorflow::Status::OK();
}
// Drop trivial inputs.
@@ -63,7 +66,8 @@ bool RemoveTrivialConcatenationInput::Run(Model* model, std::size_t op_index) {
}
}
concat_op->inputs = nontrivial_inputs;
- return true;
+ *modified = true;
+ return ::tensorflow::Status::OK();
}
} // namespace toco
diff --git a/tensorflow/contrib/lite/toco/graph_transformations/remove_trivial_fake_quant.cc b/tensorflow/contrib/lite/toco/graph_transformations/remove_trivial_fake_quant.cc
index 2c8d04440f..5448a816bc 100644
--- a/tensorflow/contrib/lite/toco/graph_transformations/remove_trivial_fake_quant.cc
+++ b/tensorflow/contrib/lite/toco/graph_transformations/remove_trivial_fake_quant.cc
@@ -64,23 +64,27 @@ bool IsFakeQuantTrivial(GraphTransformation* transformation, const Model& model,
} // namespace
// Removes FakeQuant ops that are trivial (have no effect, are redundant, etc).
-bool RemoveTrivialFakeQuant::Run(Model* model, std::size_t op_index) {
+::tensorflow::Status RemoveTrivialFakeQuant::Run(Model* model,
+ std::size_t op_index,
+ bool* modified) {
+ *modified = false;
const auto op_it = model->operators.begin() + op_index;
auto* op = op_it->get();
if (op->type != OperatorType::kFakeQuant) {
- return false;
+ return ::tensorflow::Status::OK();
}
auto* fakequant_op = static_cast<FakeQuantOperator*>(op);
if (!IsFakeQuantTrivial(this, *model, *fakequant_op)) {
AddMessageF("%s is not trivial", LogName(*fakequant_op));
- return false;
+ return ::tensorflow::Status::OK();
}
AddMessageF("Removing trivial %s", LogName(*fakequant_op));
CHECK_EQ(fakequant_op->inputs.size(), 1);
- return RemoveTrivialPassthroughOp(this, model, op_index);
+ *modified = RemoveTrivialPassthroughOp(this, model, op_index);
+ return ::tensorflow::Status::OK();
}
} // namespace toco
diff --git a/tensorflow/contrib/lite/toco/graph_transformations/remove_trivial_quantized_activation_func.cc b/tensorflow/contrib/lite/toco/graph_transformations/remove_trivial_quantized_activation_func.cc
index 752560e075..4133815285 100644
--- a/tensorflow/contrib/lite/toco/graph_transformations/remove_trivial_quantized_activation_func.cc
+++ b/tensorflow/contrib/lite/toco/graph_transformations/remove_trivial_quantized_activation_func.cc
@@ -94,12 +94,13 @@ bool IsTrivialFusedActivationFunc(
// Attempts to remove both fused and unfused activation functions if the
// quantization params indicate that the representable values fall inside the
// activation range.
-bool RemoveTrivialQuantizedActivationFunc::Run(Model* model,
- std::size_t op_index) {
+::tensorflow::Status RemoveTrivialQuantizedActivationFunc::Run(
+ Model* model, std::size_t op_index, bool* modified) {
+ *modified = false;
const auto it = model->operators.begin() + op_index;
auto* op = it->get();
if (op->inputs.empty()) {
- return false;
+ return ::tensorflow::Status::OK();
}
if (IsTrivialUnfusedActivationFunc(this, *model, op->type, op->inputs[0])) {
@@ -107,7 +108,8 @@ bool RemoveTrivialQuantizedActivationFunc::Run(Model* model,
"Removing trivial unfused activation function %s because the input "
"minmax imply at least as tight a clamp anyway.",
LogName(*op));
- return RemoveTrivialPassthroughOp(this, model, op_index);
+ *modified = RemoveTrivialPassthroughOp(this, model, op_index);
+ return ::tensorflow::Status::OK();
}
if (IsTrivialFusedActivationFunc(this, *model, op->fused_activation_function,
op->outputs[0])) {
@@ -117,9 +119,10 @@ bool RemoveTrivialQuantizedActivationFunc::Run(Model* model,
"because the output quantization parameters imply at least as tight "
"a clamp anyway.",
LogName(*op));
- return true;
+ *modified = true;
+ return ::tensorflow::Status::OK();
}
- return false;
+ return ::tensorflow::Status::OK();
}
} // namespace toco
diff --git a/tensorflow/contrib/lite/toco/graph_transformations/remove_trivial_quantized_min_max.cc b/tensorflow/contrib/lite/toco/graph_transformations/remove_trivial_quantized_min_max.cc
index 142c876b15..0f0ae4af69 100644
--- a/tensorflow/contrib/lite/toco/graph_transformations/remove_trivial_quantized_min_max.cc
+++ b/tensorflow/contrib/lite/toco/graph_transformations/remove_trivial_quantized_min_max.cc
@@ -69,22 +69,26 @@ bool IsTrivialMinMax(GraphTransformation* transformation, const Model& model,
// Attempts to remove min/max functions if the quantization params indicate that
// the representable values fall inside the clip range.
-bool RemoveTrivialQuantizedMinMax::Run(Model* model, std::size_t op_index) {
+::tensorflow::Status RemoveTrivialQuantizedMinMax::Run(Model* model,
+ std::size_t op_index,
+ bool* modified) {
+ *modified = false;
const auto it = model->operators.begin() + op_index;
auto* op = it->get();
if ((op->type != OperatorType::kMinimum &&
op->type != OperatorType::kMaximum) ||
op->inputs.size() != 2) {
- return false;
+ return ::tensorflow::Status::OK();
}
if (IsTrivialMinMax(this, *model, op->type, op->inputs[0], op->inputs[1])) {
AddMessageF(
"Removing trivial min/max %s because the quantization parameters imply "
"at least as tight a clamp anyway.",
LogName(*op));
- return RemoveTrivialPassthroughOp(this, model, op_index);
+ *modified = RemoveTrivialPassthroughOp(this, model, op_index);
+ return ::tensorflow::Status::OK();
}
- return false;
+ return ::tensorflow::Status::OK();
}
} // namespace toco
diff --git a/tensorflow/contrib/lite/toco/graph_transformations/remove_trivial_reshape.cc b/tensorflow/contrib/lite/toco/graph_transformations/remove_trivial_reshape.cc
index 5295eeccec..1caf944879 100644
--- a/tensorflow/contrib/lite/toco/graph_transformations/remove_trivial_reshape.cc
+++ b/tensorflow/contrib/lite/toco/graph_transformations/remove_trivial_reshape.cc
@@ -81,22 +81,26 @@ bool IsReshapeTrivial(const Model& model, const Operator& op,
} // namespace
-bool RemoveTrivialReshape::Run(Model* model, std::size_t op_index) {
+::tensorflow::Status RemoveTrivialReshape::Run(Model* model,
+ std::size_t op_index,
+ bool* modified) {
+ *modified = false;
const auto reshape_it = model->operators.begin() + op_index;
auto* reshape_op = reshape_it->get();
if (reshape_op->type != OperatorType::kReshape) {
- return false;
+ return ::tensorflow::Status::OK();
}
if (!IsReshapeTrivial(*model, *reshape_op, this)) {
AddMessageF("%s is not trivial", LogName(*reshape_op));
- return false;
+ return ::tensorflow::Status::OK();
}
AddMessageF("Removing trivial %s", LogName(*reshape_op));
CHECK_EQ(reshape_op->inputs.size(), 2);
- return RemoveTrivialPassthroughOp(this, model, op_index);
+ *modified = RemoveTrivialPassthroughOp(this, model, op_index);
+ return ::tensorflow::Status::OK();
}
} // namespace toco
diff --git a/tensorflow/contrib/lite/toco/graph_transformations/remove_trivial_slice.cc b/tensorflow/contrib/lite/toco/graph_transformations/remove_trivial_slice.cc
index 0cbbcd7c81..dcb0148d58 100644
--- a/tensorflow/contrib/lite/toco/graph_transformations/remove_trivial_slice.cc
+++ b/tensorflow/contrib/lite/toco/graph_transformations/remove_trivial_slice.cc
@@ -49,21 +49,24 @@ bool IsSliceTrivial(const Model& model, const Operator& op,
} // namespace
-bool RemoveTrivialSlice::Run(Model* model, std::size_t op_index) {
+::tensorflow::Status RemoveTrivialSlice::Run(Model* model, std::size_t op_index,
+ bool* modified) {
+ *modified = false;
const auto reshape_it = model->operators.begin() + op_index;
auto* slice_op = reshape_it->get();
if (slice_op->type != OperatorType::kSlice) {
- return false;
+ return ::tensorflow::Status::OK();
}
if (!IsSliceTrivial(*model, *slice_op, this)) {
- return false;
+ return ::tensorflow::Status::OK();
}
AddMessageF("Removing trivial %s", LogName(*slice_op));
CHECK_EQ(slice_op->inputs.size(), 3);
- return RemoveTrivialPassthroughOp(this, model, op_index);
+ *modified = RemoveTrivialPassthroughOp(this, model, op_index);
+ return ::tensorflow::Status::OK();
}
} // namespace toco
diff --git a/tensorflow/contrib/lite/toco/graph_transformations/remove_unused_op.cc b/tensorflow/contrib/lite/toco/graph_transformations/remove_unused_op.cc
index dde91234a8..3cd5d06bae 100644
--- a/tensorflow/contrib/lite/toco/graph_transformations/remove_unused_op.cc
+++ b/tensorflow/contrib/lite/toco/graph_transformations/remove_unused_op.cc
@@ -25,7 +25,9 @@ limitations under the License.
namespace toco {
-bool RemoveUnusedOp::Run(Model* model, std::size_t op_index) {
+::tensorflow::Status RemoveUnusedOp::Run(Model* model, std::size_t op_index,
+ bool* modified) {
+ *modified = false;
const auto it = model->operators.begin() + op_index;
const auto* op = it->get();
@@ -58,7 +60,7 @@ bool RemoveUnusedOp::Run(Model* model, std::size_t op_index) {
}
for (const string& output_array : model->flags.output_arrays()) {
if (output == output_array) {
- return false;
+ return ::tensorflow::Status::OK();
}
}
for (const auto& rnn_state : model->flags.rnn_states()) {
@@ -67,19 +69,19 @@ bool RemoveUnusedOp::Run(Model* model, std::size_t op_index) {
if (!IsDiscardableArray(*model, rnn_state.back_edge_source_array()) ||
!IsDiscardableArray(*model, rnn_state.state_array()) ||
CountOpsWithInput(*model, rnn_state.state_array())) {
- return false;
+ return ::tensorflow::Status::OK();
}
}
}
if (CountOpsWithInput(*model, output)) {
- return false;
+ return ::tensorflow::Status::OK();
}
}
if (op->unresolved_outputs) {
AddMessageF("Not discarding %s because it has unresolved outputs.",
LogName(*op));
- return false;
+ return ::tensorflow::Status::OK();
}
AddMessageF("Discarding %s because none of its outputs is used.",
@@ -105,7 +107,8 @@ bool RemoveUnusedOp::Run(Model* model, std::size_t op_index) {
}
}
model->operators.erase(it);
- return true;
+ *modified = true;
+ return ::tensorflow::Status::OK();
}
} // namespace toco
diff --git a/tensorflow/contrib/lite/toco/graph_transformations/reorder_elementwise_unary.cc b/tensorflow/contrib/lite/toco/graph_transformations/reorder_elementwise_unary.cc
index 550de83018..3c8d411089 100644
--- a/tensorflow/contrib/lite/toco/graph_transformations/reorder_elementwise_unary.cc
+++ b/tensorflow/contrib/lite/toco/graph_transformations/reorder_elementwise_unary.cc
@@ -63,29 +63,32 @@ bool IsMoveOperator(OperatorType optype) {
// Swap elementwise operators such that all value operators occur before all
// element move operators, e.g. negation then transpose.
-bool ReorderElementwiseUnary::Run(Model* model, std::size_t op_index) {
+::tensorflow::Status ReorderElementwiseUnary::Run(Model* model,
+ std::size_t op_index,
+ bool* modified) {
+ *modified = false;
const auto element_op_it = model->operators.begin() + op_index;
std::unique_ptr<Operator>& element_op = *element_op_it;
if (!IsElementwiseOperator(element_op->type)) {
- return false;
+ return ::tensorflow::Status::OK();
}
const string intermediate_name = element_op->inputs[0];
auto it = FindOpWithOutput(*model, intermediate_name);
if (it == model->operators.end()) {
AddMessageF("No preceding operator");
- return false;
+ return ::tensorflow::Status::OK();
}
std::unique_ptr<Operator>& move_op = *it;
if (!IsMoveOperator(move_op->type)) {
AddMessageF("Preceding operator is not a move operator");
- return false;
+ return ::tensorflow::Status::OK();
}
if (CountOpsWithInput(*model, intermediate_name) != 1) {
AddMessageF("Input %s used elsewhere", intermediate_name);
- return false;
+ return ::tensorflow::Status::OK();
}
// Check that the intermediate is discardable.
@@ -94,7 +97,7 @@ bool ReorderElementwiseUnary::Run(Model* model, std::size_t op_index) {
"Cannot swap elementwise as it would invalidate %s which is "
"an output array.",
intermediate_name);
- return false;
+ return ::tensorflow::Status::OK();
}
// op->inputs may change so we need to keep a value by copy.
@@ -147,7 +150,8 @@ bool ReorderElementwiseUnary::Run(Model* model, std::size_t op_index) {
// Swap the order of the operators.
element_op.swap(move_op);
- return true;
+ *modified = true;
+ return ::tensorflow::Status::OK();
}
} // namespace toco
diff --git a/tensorflow/contrib/lite/toco/graph_transformations/reorder_reshape_transpose.cc b/tensorflow/contrib/lite/toco/graph_transformations/reorder_reshape_transpose.cc
index c907a597cb..a2c06e71e8 100644
--- a/tensorflow/contrib/lite/toco/graph_transformations/reorder_reshape_transpose.cc
+++ b/tensorflow/contrib/lite/toco/graph_transformations/reorder_reshape_transpose.cc
@@ -101,37 +101,40 @@ std::vector<int> ComputeNewPerm(std::vector<int> input_dims,
// Swaps reshape-transpose to transpose-reshape whenever possible. This is
// possible when the reshape does not affect memory ordering.
-bool ReorderReshapeTranspose::Run(Model* model, std::size_t op_index) {
+::tensorflow::Status ReorderReshapeTranspose::Run(Model* model,
+ std::size_t op_index,
+ bool* modified) {
+ *modified = false;
auto transpose_it = model->operators.begin() + op_index;
TransposeOperator* transpose_op = ConvertOperator<TransposeOperator*>(
transpose_it->get(), OperatorType::kTranspose);
if (transpose_op == nullptr) {
- return false;
+ return ::tensorflow::Status::OK();
}
if (!OperatorReady(*model, transpose_op) || transpose_op->perm.empty()) {
// Wait for values to propagate.
- return false;
+ return ::tensorflow::Status::OK();
}
// Find the operator that produces the transpose op.
auto reshape_it = FindOpWithOutput(*model, transpose_op->inputs[0]);
if (reshape_it == model->operators.end()) {
- return false;
+ return ::tensorflow::Status::OK();
}
TensorFlowReshapeOperator* reshape_op =
ConvertOperator<TensorFlowReshapeOperator*>(reshape_it->get(),
OperatorType::kReshape);
if (reshape_op == nullptr) {
- return false;
+ return ::tensorflow::Status::OK();
}
// Ignore if the reshape is uninitialized.
if (!OperatorReady(*model, reshape_op) || reshape_op->shape.empty()) {
- return false;
+ return ::tensorflow::Status::OK();
}
// Need to copy to keep static if permutated.
@@ -142,7 +145,7 @@ bool ReorderReshapeTranspose::Run(Model* model, std::size_t op_index) {
// Intermediate should not be consumed by any other operators.
if (CountOpsWithInput(*model, intermediate_name) != 1) {
AddMessageF("Input %s used elsewhere", intermediate_name);
- return false;
+ return ::tensorflow::Status::OK();
}
// Check that the intermediate is not an output array.
@@ -151,7 +154,7 @@ bool ReorderReshapeTranspose::Run(Model* model, std::size_t op_index) {
"Cannot reorder reshape-transpose as it would invalidate %s which is "
"an output array.",
intermediate_name);
- return false;
+ return ::tensorflow::Status::OK();
}
// Get the arrays.
@@ -173,7 +176,7 @@ bool ReorderReshapeTranspose::Run(Model* model, std::size_t op_index) {
// dimensions then it can be moved between the transpose.
if (!ReshapeIsEquivalentToTranspose(*model, reshape_op,
true /*allow_extra_unary_dims*/)) {
- return false;
+ return ::tensorflow::Status::OK();
}
if (!IsDiscardableArray(*model, output_name)) {
@@ -242,7 +245,8 @@ bool ReorderReshapeTranspose::Run(Model* model, std::size_t op_index) {
// Swap the order of the operators.
transpose_it->swap(*reshape_it);
- return true;
+ *modified = true;
+ return ::tensorflow::Status::OK();
}
} // namespace toco
diff --git a/tensorflow/contrib/lite/toco/graph_transformations/resolve_batch_normalization.cc b/tensorflow/contrib/lite/toco/graph_transformations/resolve_batch_normalization.cc
index 8f2c1f8162..a79779f55d 100644
--- a/tensorflow/contrib/lite/toco/graph_transformations/resolve_batch_normalization.cc
+++ b/tensorflow/contrib/lite/toco/graph_transformations/resolve_batch_normalization.cc
@@ -25,10 +25,13 @@ limitations under the License.
namespace toco {
-bool ResolveBatchNormalization::Run(Model* model, std::size_t op_index) {
+::tensorflow::Status ResolveBatchNormalization::Run(Model* model,
+ std::size_t op_index,
+ bool* modified) {
+ *modified = false;
auto bn_it = model->operators.begin() + op_index;
if (bn_it->get()->type != OperatorType::kBatchNormalization) {
- return false;
+ return ::tensorflow::Status::OK();
}
const auto* bn_op =
static_cast<const BatchNormalizationOperator*>(bn_it->get());
@@ -53,7 +56,7 @@ bool ResolveBatchNormalization::Run(Model* model, std::size_t op_index) {
// so we need to exit early if these buffers don't exist (i.e. if the params
// haven't yet been resolved as constants).
if (!mean_array.buffer || !multiplier_array.buffer || !offset_array.buffer) {
- return false;
+ return ::tensorflow::Status::OK();
}
// Create the new Mul, Add operators
@@ -142,7 +145,8 @@ bool ResolveBatchNormalization::Run(Model* model, std::size_t op_index) {
DCHECK_EQ(bn_it->get(), bn_op);
model->operators.erase(bn_it);
- return true;
+ *modified = true;
+ return ::tensorflow::Status::OK();
}
} // namespace toco
diff --git a/tensorflow/contrib/lite/toco/graph_transformations/resolve_batch_to_space_nd_attributes.cc b/tensorflow/contrib/lite/toco/graph_transformations/resolve_batch_to_space_nd_attributes.cc
index b8b35161d7..d039d7d690 100644
--- a/tensorflow/contrib/lite/toco/graph_transformations/resolve_batch_to_space_nd_attributes.cc
+++ b/tensorflow/contrib/lite/toco/graph_transformations/resolve_batch_to_space_nd_attributes.cc
@@ -24,31 +24,35 @@ limitations under the License.
namespace toco {
-bool ResolveBatchToSpaceNDAttributes::Run(Model* model, std::size_t op_index) {
+::tensorflow::Status ResolveBatchToSpaceNDAttributes::Run(Model* model,
+ std::size_t op_index,
+ bool* modified) {
+ *modified = false;
const auto op_it = model->operators.begin() + op_index;
- if (op_it->get()->type != OperatorType::kBatchToSpaceND) return false;
+ if (op_it->get()->type != OperatorType::kBatchToSpaceND)
+ return ::tensorflow::Status::OK();
auto* op = static_cast<BatchToSpaceNDOperator*>(op_it->get());
// The attributes are resolved only when the 3 attributes (block_shape,
// before_crops, after_crops) are all constant.
if (!op->block_shape.empty()) {
- return false;
+ return ::tensorflow::Status::OK();
}
CHECK_EQ(op->inputs.size(), 3);
if (!IsConstantParameterArray(*model, op->inputs[1]) ||
!IsConstantParameterArray(*model, op->inputs[2]))
- return false;
+ return ::tensorflow::Status::OK();
// Handle crops
const auto& crops_array = model->GetArray(op->inputs[2]);
- if (!crops_array.has_shape()) return false;
+ if (!crops_array.has_shape()) return ::tensorflow::Status::OK();
const std::vector<int>& crops_dims = crops_array.shape().dims();
if (crops_dims.size() != 2) {
// Code only handles crops of 2 dimensions. Perhaps another transformation
// will delete this op.
- return false;
+ return ::tensorflow::Status::OK();
}
const std::vector<int>& crops_buffer =
crops_array.GetBuffer<ArrayDataType::kInt32>().data;
@@ -59,7 +63,7 @@ bool ResolveBatchToSpaceNDAttributes::Run(Model* model, std::size_t op_index) {
// Handle block_shape
const auto& block_shape_array = model->GetArray(op->inputs[1]);
- if (!block_shape_array.has_shape()) return false;
+ if (!block_shape_array.has_shape()) return ::tensorflow::Status::OK();
const std::vector<int>& block_shape_dims = block_shape_array.shape().dims();
CHECK_EQ(block_shape_dims.size(), 1);
const std::vector<int>& block_shape_buffer =
@@ -68,7 +72,8 @@ bool ResolveBatchToSpaceNDAttributes::Run(Model* model, std::size_t op_index) {
op->block_shape.push_back(block_shape_buffer[i]);
}
- return true;
+ *modified = true;
+ return ::tensorflow::Status::OK();
}
} // namespace toco
diff --git a/tensorflow/contrib/lite/toco/graph_transformations/resolve_constant_binary.cc b/tensorflow/contrib/lite/toco/graph_transformations/resolve_constant_binary.cc
index f7e5aa6609..586f546a30 100644
--- a/tensorflow/contrib/lite/toco/graph_transformations/resolve_constant_binary.cc
+++ b/tensorflow/contrib/lite/toco/graph_transformations/resolve_constant_binary.cc
@@ -188,7 +188,10 @@ void EvaluateBinaryOperatorOnConstantInputs(Model* model,
}
} // namespace
-bool ResolveConstantBinaryOperator::Run(Model* model, std::size_t op_index) {
+::tensorflow::Status ResolveConstantBinaryOperator::Run(Model* model,
+ std::size_t op_index,
+ bool* modified) {
+ *modified = false;
const auto binary_it = model->operators.begin() + op_index;
const auto* binary_op = binary_it->get();
// Test for binary ops of types that we know how to resolve
@@ -204,7 +207,7 @@ bool ResolveConstantBinaryOperator::Run(Model* model, std::size_t op_index) {
binary_op->type != OperatorType::kLessEqual &&
binary_op->type != OperatorType::kGreater &&
binary_op->type != OperatorType::kGreaterEqual) {
- return false;
+ return ::tensorflow::Status::OK();
}
CHECK_EQ(binary_op->inputs.size(), 2);
@@ -212,13 +215,13 @@ bool ResolveConstantBinaryOperator::Run(Model* model, std::size_t op_index) {
const auto& input1_array = model->GetArray(binary_op->inputs[1]);
// Check if both inputs are constant parameters.
if (!input0_array.buffer || !input1_array.buffer) {
- return false;
+ return ::tensorflow::Status::OK();
}
auto& output_array = model->GetArray(binary_op->outputs[0]);
// Yield until the output array dims have been resolved.
if (!output_array.has_shape()) {
- return false;
+ return ::tensorflow::Status::OK();
}
// At the moment we don't want to care about fused activation functions.
@@ -229,7 +232,7 @@ bool ResolveConstantBinaryOperator::Run(Model* model, std::size_t op_index) {
AddMessageF(
"Not resolving constant %s because it has a fused activation function",
LogName(*binary_op));
- return false;
+ return ::tensorflow::Status::OK();
}
// Check that input data types agree.
@@ -253,7 +256,8 @@ bool ResolveConstantBinaryOperator::Run(Model* model, std::size_t op_index) {
AddMessageF("Resolved constant %s to the equivalent constant array",
LogName(*binary_op));
model->operators.erase(binary_it);
- return true;
+ *modified = true;
+ return ::tensorflow::Status::OK();
}
} // namespace toco
diff --git a/tensorflow/contrib/lite/toco/graph_transformations/resolve_constant_concatenation.cc b/tensorflow/contrib/lite/toco/graph_transformations/resolve_constant_concatenation.cc
index d916ae0ddf..0c60fdfeb3 100644
--- a/tensorflow/contrib/lite/toco/graph_transformations/resolve_constant_concatenation.cc
+++ b/tensorflow/contrib/lite/toco/graph_transformations/resolve_constant_concatenation.cc
@@ -135,11 +135,14 @@ void SetMinMaxForConcatenedArray(GraphTransformation* transformation,
} // namespace
// Resolves the concatenation operator if all its inputs are constant arrays.
-bool ResolveConstantConcatenation::Run(Model* model, std::size_t op_index) {
+::tensorflow::Status ResolveConstantConcatenation::Run(Model* model,
+ std::size_t op_index,
+ bool* modified) {
+ *modified = false;
const auto concat_it = model->operators.begin() + op_index;
const auto* concat_base_op = concat_it->get();
if (concat_base_op->type != OperatorType::kConcatenation) {
- return false;
+ return ::tensorflow::Status::OK();
}
const auto* concat_op =
static_cast<const ConcatenationOperator*>(concat_base_op);
@@ -149,11 +152,15 @@ bool ResolveConstantConcatenation::Run(Model* model, std::size_t op_index) {
// We also make sure the shapes of the input arrays are known and they are
// all discardable.
const Operator* input_op = GetOpWithOutput(*model, input_name);
- if (input_op) return false;
- if (!IsConstantParameterArray(*model, input_name)) return false;
- if (!model->GetArray(input_name).has_shape()) return false;
- if (model->GetArray(input_name).quantization_params) return false;
- if (!IsDiscardableArray(*model, input_name)) return false;
+ if (input_op) return ::tensorflow::Status::OK();
+ if (!IsConstantParameterArray(*model, input_name))
+ return ::tensorflow::Status::OK();
+ if (!model->GetArray(input_name).has_shape())
+ return ::tensorflow::Status::OK();
+ if (model->GetArray(input_name).quantization_params)
+ return ::tensorflow::Status::OK();
+ if (!IsDiscardableArray(*model, input_name))
+ return ::tensorflow::Status::OK();
}
const int concatenation_axis = concat_op->axis;
@@ -205,7 +212,8 @@ bool ResolveConstantConcatenation::Run(Model* model, std::size_t op_index) {
// Remove concatenate operator.
model->operators.erase(concat_it);
- return true;
+ *modified = true;
+ return ::tensorflow::Status::OK();
}
} // namespace toco
diff --git a/tensorflow/contrib/lite/toco/graph_transformations/resolve_constant_fake_quant.cc b/tensorflow/contrib/lite/toco/graph_transformations/resolve_constant_fake_quant.cc
index f5f2f77460..4f330fdd84 100644
--- a/tensorflow/contrib/lite/toco/graph_transformations/resolve_constant_fake_quant.cc
+++ b/tensorflow/contrib/lite/toco/graph_transformations/resolve_constant_fake_quant.cc
@@ -59,11 +59,14 @@ void GetBoundsForQuantizedDataType(ArrayDataType quantized_data_type,
}
}
-bool ResolveConstantFakeQuant::Run(Model* model, std::size_t op_index) {
+::tensorflow::Status ResolveConstantFakeQuant::Run(Model* model,
+ std::size_t op_index,
+ bool* modified) {
+ *modified = false;
const auto fakequant_it = model->operators.begin() + op_index;
const auto* fakequant_base_op = fakequant_it->get();
if (fakequant_base_op->type != OperatorType::kFakeQuant) {
- return false;
+ return ::tensorflow::Status::OK();
}
const auto* fakequant_op =
@@ -71,12 +74,12 @@ bool ResolveConstantFakeQuant::Run(Model* model, std::size_t op_index) {
// Yield until the fakequant MinMax has been resolved.
if (!fakequant_op->minmax) {
- return false;
+ return ::tensorflow::Status::OK();
}
// This transformation only applies when the input array is constant.
if (!IsConstantParameterArray(*model, fakequant_op->inputs[0])) {
- return false;
+ return ::tensorflow::Status::OK();
}
const auto& input_array = model->GetArray(fakequant_op->inputs[0]);
@@ -87,7 +90,7 @@ bool ResolveConstantFakeQuant::Run(Model* model, std::size_t op_index) {
if (!InferQuantizedDataTypeFromFakeQuant(*fakequant_op,
&quantized_data_type)) {
AddMessageF("Unsupported FakeQuant num_bits=%d", fakequant_op->num_bits);
- return false;
+ return ::tensorflow::Status::OK();
}
AddMessageF("Resolving constant %s", LogName(*fakequant_op));
@@ -136,7 +139,8 @@ bool ResolveConstantFakeQuant::Run(Model* model, std::size_t op_index) {
}
model->operators.erase(fakequant_it);
- return true;
+ *modified = true;
+ return ::tensorflow::Status::OK();
}
} // namespace toco
diff --git a/tensorflow/contrib/lite/toco/graph_transformations/resolve_constant_fill.cc b/tensorflow/contrib/lite/toco/graph_transformations/resolve_constant_fill.cc
index f6f95481b5..5400d395ff 100644
--- a/tensorflow/contrib/lite/toco/graph_transformations/resolve_constant_fill.cc
+++ b/tensorflow/contrib/lite/toco/graph_transformations/resolve_constant_fill.cc
@@ -41,11 +41,14 @@ bool ComputeFillArray(Model* model, FillOperator* op) {
return true;
}
-bool ResolveConstantFill::Run(Model* model, std::size_t op_index) {
+::tensorflow::Status ResolveConstantFill::Run(Model* model,
+ std::size_t op_index,
+ bool* modified) {
+ *modified = false;
const auto fill_it = model->operators.begin() + op_index;
auto* base_op = fill_it->get();
if (base_op->type != OperatorType::kFill) {
- return false;
+ return ::tensorflow::Status::OK();
}
auto* op = static_cast<FillOperator*>(base_op);
@@ -55,44 +58,44 @@ bool ResolveConstantFill::Run(Model* model, std::size_t op_index) {
auto& output_array = model->GetArray(op->outputs[0]);
if (output_array.data_type == ArrayDataType::kNone) {
// Yield until the output type has been set by PropagateArrayDataTypes
- return false;
+ return ::tensorflow::Status::OK();
}
if (!output_array.has_shape()) {
// Yield until the output shape has been set by PropagateFixedShapes
- return false;
+ return ::tensorflow::Status::OK();
}
const auto& val_array = model->GetArray(op->inputs[1]);
if (!val_array.has_shape()) {
// Yield until the value shape has been resolved.
- return false;
+ return ::tensorflow::Status::OK();
}
if (!IsConstantParameterArray(*model, op->inputs[1])) {
// Yield until the value is constant.
- return false;
+ return ::tensorflow::Status::OK();
}
CHECK_EQ(RequiredBufferSizeForShape(val_array.shape()), 1);
switch (output_array.data_type) {
case ArrayDataType::kFloat:
if (!ComputeFillArray<ArrayDataType::kFloat>(model, op)) {
- return false;
+ return ::tensorflow::Status::OK();
}
break;
case ArrayDataType::kUint8:
if (!ComputeFillArray<ArrayDataType::kUint8>(model, op)) {
- return false;
+ return ::tensorflow::Status::OK();
}
break;
case ArrayDataType::kInt32:
if (!ComputeFillArray<ArrayDataType::kInt32>(model, op)) {
- return false;
+ return ::tensorflow::Status::OK();
}
break;
case ArrayDataType::kInt64:
if (!ComputeFillArray<ArrayDataType::kInt64>(model, op)) {
- return false;
+ return ::tensorflow::Status::OK();
}
break;
default:
@@ -114,7 +117,8 @@ bool ResolveConstantFill::Run(Model* model, std::size_t op_index) {
// Erase the operator
model->operators.erase(fill_it);
- return true;
+ *modified = true;
+ return ::tensorflow::Status::OK();
}
} // namespace toco
diff --git a/tensorflow/contrib/lite/toco/graph_transformations/resolve_constant_gather.cc b/tensorflow/contrib/lite/toco/graph_transformations/resolve_constant_gather.cc
index 36d7dad0ce..6e3a6a69c2 100644
--- a/tensorflow/contrib/lite/toco/graph_transformations/resolve_constant_gather.cc
+++ b/tensorflow/contrib/lite/toco/graph_transformations/resolve_constant_gather.cc
@@ -61,11 +61,14 @@ inline void Gather(const Array& input_array, int input_rank,
// Resolves a constant Gather operation.
// This simply performs the gather and produces the output array with the
// appropriate values.
-bool ResolveConstantGather::Run(Model* model, std::size_t op_index) {
+::tensorflow::Status ResolveConstantGather::Run(Model* model,
+ std::size_t op_index,
+ bool* modified) {
+ *modified = false;
auto it = model->operators.begin() + op_index;
const auto* base_op = it->get();
if (base_op->type != OperatorType::kGather) {
- return false;
+ return ::tensorflow::Status::OK();
}
const auto* op = static_cast<const GatherOperator*>(base_op);
@@ -74,28 +77,28 @@ bool ResolveConstantGather::Run(Model* model, std::size_t op_index) {
auto& output_array = model->GetArray(op->outputs[0]);
if (output_array.data_type == ArrayDataType::kNone) {
// Yield until the output type has been set by PropagateArrayDataTypes.
- return false;
+ return ::tensorflow::Status::OK();
}
if (!output_array.has_shape()) {
// Yield until the output shape has been set by PropagateFixedShapes.
- return false;
+ return ::tensorflow::Status::OK();
}
if (!op->axis) {
// Yield until axis has been set by ResolveGatherAttributes.
- return false;
+ return ::tensorflow::Status::OK();
}
if (op->axis.value() != 0) {
// Only handling axis=0 for now.
AddMessageF("%s has axis %d; only axis=0 is supported", LogName(*op),
op->axis.value());
- return false;
+ return ::tensorflow::Status::OK();
}
// We require constant inputs.
if (!IsConstantParameterArray(*model, op->inputs[0]) ||
!IsConstantParameterArray(*model, op->inputs[1])) {
- return false;
+ return ::tensorflow::Status::OK();
}
const Array& input_array = model->GetArray(op->inputs[0]);
const Array& coords_array = model->GetArray(op->inputs[1]);
@@ -142,7 +145,8 @@ bool ResolveConstantGather::Run(Model* model, std::size_t op_index) {
// Erase the operator.
model->operators.erase(it);
- return true;
+ *modified = true;
+ return ::tensorflow::Status::OK();
}
} // namespace toco
diff --git a/tensorflow/contrib/lite/toco/graph_transformations/resolve_constant_pack.cc b/tensorflow/contrib/lite/toco/graph_transformations/resolve_constant_pack.cc
index e86616574d..e257ec37e8 100644
--- a/tensorflow/contrib/lite/toco/graph_transformations/resolve_constant_pack.cc
+++ b/tensorflow/contrib/lite/toco/graph_transformations/resolve_constant_pack.cc
@@ -49,11 +49,14 @@ void Pack(Model* model, PackOperator const& op) {
} // namespace
-bool ResolveConstantPack::Run(Model* model, std::size_t op_index) {
+::tensorflow::Status ResolveConstantPack::Run(Model* model,
+ std::size_t op_index,
+ bool* modified) {
+ *modified = false;
auto it = model->operators.begin() + op_index;
const auto* base_op = it->get();
if (base_op->type != OperatorType::kPack) {
- return false;
+ return ::tensorflow::Status::OK();
}
const auto* op = static_cast<const PackOperator*>(base_op);
@@ -62,18 +65,18 @@ bool ResolveConstantPack::Run(Model* model, std::size_t op_index) {
auto& output_array = model->GetArray(op->outputs[0]);
if (output_array.data_type == ArrayDataType::kNone) {
// Yield until the output type has been set by PropagateArrayDataTypes
- return false;
+ return ::tensorflow::Status::OK();
}
if (!output_array.has_shape()) {
// Yield until the output shape has been set by PropagateFixedShapes
- return false;
+ return ::tensorflow::Status::OK();
}
for (const auto& input : op->inputs) {
if (!IsConstantParameterArray(*model, input)) {
// Yield if any input is mutable
- return false;
+ return ::tensorflow::Status::OK();
}
}
@@ -111,7 +114,8 @@ bool ResolveConstantPack::Run(Model* model, std::size_t op_index) {
// Erase the operator
model->operators.erase(it);
- return true;
+ *modified = true;
+ return ::tensorflow::Status::OK();
}
} // namespace toco
diff --git a/tensorflow/contrib/lite/toco/graph_transformations/resolve_constant_random_uniform.cc b/tensorflow/contrib/lite/toco/graph_transformations/resolve_constant_random_uniform.cc
index 88d06d7dc7..db0fbba528 100644
--- a/tensorflow/contrib/lite/toco/graph_transformations/resolve_constant_random_uniform.cc
+++ b/tensorflow/contrib/lite/toco/graph_transformations/resolve_constant_random_uniform.cc
@@ -59,11 +59,14 @@ bool ComputeRandomUniformArray(Model* model, RandomUniformOperator* op) {
return true;
}
-bool ResolveConstantRandomUniform::Run(Model* model, std::size_t op_index) {
+::tensorflow::Status ResolveConstantRandomUniform::Run(Model* model,
+ std::size_t op_index,
+ bool* modified) {
+ *modified = false;
const auto it = model->operators.begin() + op_index;
auto* base_op = it->get();
if (base_op->type != OperatorType::kRandomUniform) {
- return false;
+ return ::tensorflow::Status::OK();
}
auto* op = static_cast<RandomUniformOperator*>(base_op);
@@ -73,12 +76,12 @@ bool ResolveConstantRandomUniform::Run(Model* model, std::size_t op_index) {
auto& output_array = model->GetArray(op->outputs[0]);
if (output_array.data_type == ArrayDataType::kNone) {
// Yield until the output type has been set by PropagateArrayDataTypes
- return false;
+ return ::tensorflow::Status::OK();
}
if (!output_array.has_shape()) {
// Yield until the output shape has been set by PropagateFixedShapes
- return false;
+ return ::tensorflow::Status::OK();
}
if ((op->seed == 0) && (op->seed2 == 0)) {
@@ -86,13 +89,13 @@ bool ResolveConstantRandomUniform::Run(Model* model, std::size_t op_index) {
<< "\" is truly random (using /dev/random system entropy). "
"Therefore, cannot resolve as constant. Set \"seed\" or "
"\"seed2\" attr non-zero to fix this";
- return false;
+ return ::tensorflow::Status::OK();
}
switch (output_array.data_type) {
case ArrayDataType::kFloat:
if (!ComputeRandomUniformArray<ArrayDataType::kFloat>(model, op)) {
- return false;
+ return ::tensorflow::Status::OK();
}
break;
// For future support of double or half.
@@ -110,7 +113,8 @@ bool ResolveConstantRandomUniform::Run(Model* model, std::size_t op_index) {
// Erase the operator
model->operators.erase(it);
- return true;
+ *modified = true;
+ return ::tensorflow::Status::OK();
}
} // namespace toco
diff --git a/tensorflow/contrib/lite/toco/graph_transformations/resolve_constant_range.cc b/tensorflow/contrib/lite/toco/graph_transformations/resolve_constant_range.cc
index 1a0ba9e2bc..069d4dafaa 100644
--- a/tensorflow/contrib/lite/toco/graph_transformations/resolve_constant_range.cc
+++ b/tensorflow/contrib/lite/toco/graph_transformations/resolve_constant_range.cc
@@ -19,11 +19,14 @@ limitations under the License.
namespace toco {
-bool ResolveConstantRange::Run(Model* model, std::size_t op_index) {
+::tensorflow::Status ResolveConstantRange::Run(Model* model,
+ std::size_t op_index,
+ bool* modified) {
+ *modified = false;
const auto it = model->operators.begin() + op_index;
auto* base_op = it->get();
if (base_op->type != OperatorType::kRange) {
- return false;
+ return ::tensorflow::Status::OK();
}
auto* op = static_cast<RangeOperator*>(base_op);
@@ -31,23 +34,23 @@ bool ResolveConstantRange::Run(Model* model, std::size_t op_index) {
const auto& start_array = model->GetArray(op->inputs[0]);
if (!start_array.has_shape()) {
// Yield until all input dims have been resolved.
- return false;
+ return ::tensorflow::Status::OK();
}
const auto& limit_array = model->GetArray(op->inputs[1]);
if (!limit_array.has_shape()) {
// Yield until all input dims have been resolved.
- return false;
+ return ::tensorflow::Status::OK();
}
const auto& delta_array = model->GetArray(op->inputs[2]);
if (!delta_array.has_shape()) {
// Yield until all input dims have been resolved.
- return false;
+ return ::tensorflow::Status::OK();
}
for (const auto& input : op->inputs) {
if (!IsConstantParameterArray(*model, input)) {
// yield if any input is mutable
- return false;
+ return ::tensorflow::Status::OK();
}
}
@@ -55,7 +58,7 @@ bool ResolveConstantRange::Run(Model* model, std::size_t op_index) {
auto& output_array = model->GetArray(op->outputs[0]);
if (output_array.data_type == ArrayDataType::kNone) {
// Yield until the output type has been set by PropagateArrayDataTypes
- return false;
+ return ::tensorflow::Status::OK();
}
CHECK_EQ(RequiredBufferSizeForShape(start_array.shape()), 1)
@@ -101,7 +104,8 @@ bool ResolveConstantRange::Run(Model* model, std::size_t op_index) {
// Delete the operator
model->operators.erase(it);
- return true;
+ *modified = true;
+ return ::tensorflow::Status::OK();
}
} // namespace toco
diff --git a/tensorflow/contrib/lite/toco/graph_transformations/resolve_constant_reshape.cc b/tensorflow/contrib/lite/toco/graph_transformations/resolve_constant_reshape.cc
index a6f665b5f0..fccecef600 100644
--- a/tensorflow/contrib/lite/toco/graph_transformations/resolve_constant_reshape.cc
+++ b/tensorflow/contrib/lite/toco/graph_transformations/resolve_constant_reshape.cc
@@ -22,11 +22,14 @@ limitations under the License.
namespace toco {
// Resolves a constant reshape operation by copying the buffer.
-bool ResolveConstantReshape::Run(Model* model, std::size_t op_index) {
+::tensorflow::Status ResolveConstantReshape::Run(Model* model,
+ std::size_t op_index,
+ bool* modified) {
+ *modified = false;
auto it = model->operators.begin() + op_index;
const auto* base_op = it->get();
if (base_op->type != OperatorType::kReshape) {
- return false;
+ return ::tensorflow::Status::OK();
}
const auto* op = static_cast<const TensorFlowReshapeOperator*>(base_op);
@@ -36,17 +39,17 @@ bool ResolveConstantReshape::Run(Model* model, std::size_t op_index) {
// We require constant inputs.
if (!IsConstantParameterArray(*model, op->inputs[0]) ||
!IsConstantParameterArray(*model, op->inputs[1])) {
- return false;
+ return ::tensorflow::Status::OK();
}
auto& output_array = model->GetArray(op->outputs[0]);
if (output_array.data_type == ArrayDataType::kNone) {
// Yield until the output type has been set by PropagateArrayDataTypes.
- return false;
+ return ::tensorflow::Status::OK();
}
if (!output_array.has_shape()) {
// Yield until the output shape has been set by PropagateFixedShapes.
- return false;
+ return ::tensorflow::Status::OK();
}
const Array& input_array = model->GetArray(op->inputs[0]);
@@ -54,7 +57,7 @@ bool ResolveConstantReshape::Run(Model* model, std::size_t op_index) {
AddMessageF("Constant reshape is non-trivial (%s -> %s)",
ShapeToString(input_array.shape()),
ShapeToString(output_array.shape()));
- return false;
+ return ::tensorflow::Status::OK();
}
CHECK(!output_array.buffer);
@@ -95,7 +98,7 @@ bool ResolveConstantReshape::Run(Model* model, std::size_t op_index) {
default:
LOG(FATAL) << "Unsupported data type: "
<< ArrayDataTypeName(input_array.data_type);
- return false;
+ return ::tensorflow::Status::OK();
}
AddMessageF("Resolving constant reshape of %s", LogName(*op));
@@ -112,7 +115,8 @@ bool ResolveConstantReshape::Run(Model* model, std::size_t op_index) {
// Erase the operator.
model->operators.erase(it);
- return true;
+ *modified = true;
+ return ::tensorflow::Status::OK();
}
} // namespace toco
diff --git a/tensorflow/contrib/lite/toco/graph_transformations/resolve_constant_select.cc b/tensorflow/contrib/lite/toco/graph_transformations/resolve_constant_select.cc
index e880a3f44d..ab1e0bd7a0 100644
--- a/tensorflow/contrib/lite/toco/graph_transformations/resolve_constant_select.cc
+++ b/tensorflow/contrib/lite/toco/graph_transformations/resolve_constant_select.cc
@@ -27,11 +27,14 @@ namespace toco {
// This implementation is looking strictly for all-or-nothing on the select
// condition. It's possible to enhance this by looking per-element and possibly
// producing a Mul op.
-bool ResolveConstantSelect::Run(Model* model, std::size_t op_index) {
+::tensorflow::Status ResolveConstantSelect::Run(Model* model,
+ std::size_t op_index,
+ bool* modified) {
+ *modified = false;
auto it = model->operators.begin() + op_index;
const auto* base_op = it->get();
if (base_op->type != OperatorType::kSelect) {
- return false;
+ return ::tensorflow::Status::OK();
}
const auto* op = static_cast<const SelectOperator*>(base_op);
@@ -40,23 +43,23 @@ bool ResolveConstantSelect::Run(Model* model, std::size_t op_index) {
auto& output_array = model->GetArray(op->outputs[0]);
if (output_array.data_type == ArrayDataType::kNone) {
// Yield until the output type has been set by PropagateArrayDataTypes.
- return false;
+ return ::tensorflow::Status::OK();
}
if (!output_array.has_shape()) {
// Yield until the output shape has been set by PropagateFixedShapes.
- return false;
+ return ::tensorflow::Status::OK();
}
// We require the cond input to be constant.
if (!IsConstantParameterArray(*model, op->inputs[0])) {
- return false;
+ return ::tensorflow::Status::OK();
}
const Array& cond_array = model->GetArray(op->inputs[0]);
CHECK(cond_array.data_type == ArrayDataType::kBool)
<< "Only bool conditions are supported";
const auto& cond_data = cond_array.GetBuffer<ArrayDataType::kBool>().data;
if (cond_data.empty()) {
- return false;
+ return ::tensorflow::Status::OK();
}
// Check if the condition is the same for all elements.
@@ -67,12 +70,14 @@ bool ResolveConstantSelect::Run(Model* model, std::size_t op_index) {
"Cannot resolve %s as constant; cond_array has differing "
"per-element values",
LogName(*op));
- return false;
+ return ::tensorflow::Status::OK();
}
}
// Pass-through the selected input.
- return RemoveTrivialPassthroughOp(this, model, op_index, cond_value ? 1 : 2);
+ *modified =
+ RemoveTrivialPassthroughOp(this, model, op_index, cond_value ? 1 : 2);
+ return ::tensorflow::Status::OK();
}
} // namespace toco
diff --git a/tensorflow/contrib/lite/toco/graph_transformations/resolve_constant_shape_or_rank.cc b/tensorflow/contrib/lite/toco/graph_transformations/resolve_constant_shape_or_rank.cc
index 8a0e3e8995..a1756a8207 100644
--- a/tensorflow/contrib/lite/toco/graph_transformations/resolve_constant_shape_or_rank.cc
+++ b/tensorflow/contrib/lite/toco/graph_transformations/resolve_constant_shape_or_rank.cc
@@ -19,29 +19,32 @@ limitations under the License.
namespace toco {
-bool ResolveConstantShapeOrRank::Run(Model* model, std::size_t op_index) {
+::tensorflow::Status ResolveConstantShapeOrRank::Run(Model* model,
+ std::size_t op_index,
+ bool* modified) {
+ *modified = false;
const auto it = model->operators.begin() + op_index;
const auto* op = it->get();
if (!(op->type == OperatorType::kShape || op->type == OperatorType::kRank)) {
- return false;
+ return ::tensorflow::Status::OK();
}
CHECK_EQ(op->outputs.size(), 1);
auto& output_array = model->GetArray(op->outputs[0]);
if (output_array.data_type == ArrayDataType::kNone) {
// Yield until the output type has been resolved
- return false;
+ return ::tensorflow::Status::OK();
}
const auto& input_array = model->GetArray(op->inputs[0]);
if (!input_array.has_shape()) {
// Yield until the input array's shape has been resolved.
- return false;
+ return ::tensorflow::Status::OK();
}
if (!output_array.has_shape()) {
// Yield until the output shape has been resolved.
- return false;
+ return ::tensorflow::Status::OK();
}
// Compute the output
@@ -65,7 +68,8 @@ bool ResolveConstantShapeOrRank::Run(Model* model, std::size_t op_index) {
}
model->operators.erase(it);
- return true;
+ *modified = true;
+ return ::tensorflow::Status::OK();
}
} // namespace toco
diff --git a/tensorflow/contrib/lite/toco/graph_transformations/resolve_constant_slice.cc b/tensorflow/contrib/lite/toco/graph_transformations/resolve_constant_slice.cc
index b35c3e19c4..869dfae98e 100644
--- a/tensorflow/contrib/lite/toco/graph_transformations/resolve_constant_slice.cc
+++ b/tensorflow/contrib/lite/toco/graph_transformations/resolve_constant_slice.cc
@@ -86,11 +86,14 @@ bool Slice(SliceOperator const& op, Array const& input_array,
} // namespace
-bool ResolveConstantSlice::Run(Model* model, std::size_t op_index) {
+::tensorflow::Status ResolveConstantSlice::Run(Model* model,
+ std::size_t op_index,
+ bool* modified) {
+ *modified = false;
const auto it = model->operators.begin() + op_index;
const auto* base_op = it->get();
if (base_op->type != OperatorType::kSlice) {
- return false;
+ return ::tensorflow::Status::OK();
}
const SliceOperator* op = static_cast<const SliceOperator*>(base_op);
@@ -99,49 +102,49 @@ bool ResolveConstantSlice::Run(Model* model, std::size_t op_index) {
auto& output_array = model->GetArray(op->outputs[0]);
if (output_array.data_type == ArrayDataType::kNone) {
// Yield until the output type has been set by PropagateArrayDataTypes.
- return false;
+ return ::tensorflow::Status::OK();
}
if (!output_array.has_shape()) {
// Yield until the output shape has been set by PropagateFixedShapes.
- return false;
+ return ::tensorflow::Status::OK();
}
if (op->begin.empty() || op->size.empty()) {
// Attributes have not resolved yet.
- return false;
+ return ::tensorflow::Status::OK();
}
const auto& input_array = model->GetArray(op->inputs[0]);
if (!input_array.has_shape()) {
// Yield until the value shape has been resolved.
- return false;
+ return ::tensorflow::Status::OK();
}
if (!IsConstantParameterArray(*model, op->inputs[0])) {
// Yield until the value is constant.
- return false;
+ return ::tensorflow::Status::OK();
}
CHECK(!output_array.buffer);
switch (output_array.data_type) {
case ArrayDataType::kFloat:
if (!Slice<ArrayDataType::kFloat>(*op, input_array, &output_array)) {
- return false;
+ return ::tensorflow::Status::OK();
}
break;
case ArrayDataType::kUint8:
if (!Slice<ArrayDataType::kUint8>(*op, input_array, &output_array)) {
- return false;
+ return ::tensorflow::Status::OK();
}
break;
case ArrayDataType::kInt32:
if (!Slice<ArrayDataType::kInt32>(*op, input_array, &output_array)) {
- return false;
+ return ::tensorflow::Status::OK();
}
break;
case ArrayDataType::kInt64:
if (!Slice<ArrayDataType::kInt64>(*op, input_array, &output_array)) {
- return false;
+ return ::tensorflow::Status::OK();
}
break;
default:
@@ -159,7 +162,8 @@ bool ResolveConstantSlice::Run(Model* model, std::size_t op_index) {
// Erase the operator
model->operators.erase(it);
- return true;
+ *modified = true;
+ return ::tensorflow::Status::OK();
}
} // namespace toco
diff --git a/tensorflow/contrib/lite/toco/graph_transformations/resolve_constant_strided_slice.cc b/tensorflow/contrib/lite/toco/graph_transformations/resolve_constant_strided_slice.cc
index 8853ed87e6..99c5a64662 100644
--- a/tensorflow/contrib/lite/toco/graph_transformations/resolve_constant_strided_slice.cc
+++ b/tensorflow/contrib/lite/toco/graph_transformations/resolve_constant_strided_slice.cc
@@ -103,11 +103,14 @@ void StridedSlice(StridedSliceOperator const& op, Array const& input_array,
} // anonymous namespace
-bool ResolveConstantStridedSlice::Run(Model* model, std::size_t op_index) {
+::tensorflow::Status ResolveConstantStridedSlice::Run(Model* model,
+ std::size_t op_index,
+ bool* modified) {
+ *modified = false;
const auto it = model->operators.begin() + op_index;
const auto* base_op = it->get();
if (base_op->type != OperatorType::kStridedSlice) {
- return false;
+ return ::tensorflow::Status::OK();
}
const StridedSliceOperator* op =
@@ -117,28 +120,28 @@ bool ResolveConstantStridedSlice::Run(Model* model, std::size_t op_index) {
auto& output_array = model->GetArray(op->outputs[0]);
if (output_array.data_type == ArrayDataType::kNone) {
// Yield until the output type has been set by PropagateArrayDataTypes
- return false;
+ return ::tensorflow::Status::OK();
}
if (!output_array.has_shape()) {
// Yield until the output shape has been set by PropagateFixedShapes
- return false;
+ return ::tensorflow::Status::OK();
}
if (op->start_indices.empty() || op->stop_indices.empty() ||
op->strides.empty()) {
// Attributes have not resolved yet.
- return false;
+ return ::tensorflow::Status::OK();
}
const auto& input_array = model->GetArray(op->inputs[0]);
if (!input_array.has_shape()) {
// Yield until the value shape has been resolved.
- return false;
+ return ::tensorflow::Status::OK();
}
if (!IsConstantParameterArray(*model, op->inputs[0])) {
// Yield until the value is constant.
- return false;
+ return ::tensorflow::Status::OK();
}
CHECK(!output_array.buffer);
@@ -164,7 +167,8 @@ bool ResolveConstantStridedSlice::Run(Model* model, std::size_t op_index) {
DeleteOpAndArraysIfUnused(model, it->get());
- return true;
+ *modified = true;
+ return ::tensorflow::Status::OK();
}
} // namespace toco
diff --git a/tensorflow/contrib/lite/toco/graph_transformations/resolve_constant_tile.cc b/tensorflow/contrib/lite/toco/graph_transformations/resolve_constant_tile.cc
index 5cfa1a5582..c5e93c9bad 100644
--- a/tensorflow/contrib/lite/toco/graph_transformations/resolve_constant_tile.cc
+++ b/tensorflow/contrib/lite/toco/graph_transformations/resolve_constant_tile.cc
@@ -97,11 +97,14 @@ inline void Tile(const Array& input_array, const Array& multiples_array,
} // namespace
// Resolves a constant Tile operation.
-bool ResolveConstantTile::Run(Model* model, std::size_t op_index) {
+::tensorflow::Status ResolveConstantTile::Run(Model* model,
+ std::size_t op_index,
+ bool* modified) {
+ *modified = false;
auto it = model->operators.begin() + op_index;
const auto* base_op = it->get();
if (base_op->type != OperatorType::kTile) {
- return false;
+ return ::tensorflow::Status::OK();
}
const auto* op = static_cast<const TensorFlowTileOperator*>(base_op);
@@ -110,17 +113,17 @@ bool ResolveConstantTile::Run(Model* model, std::size_t op_index) {
auto& output_array = model->GetArray(op->outputs[0]);
if (output_array.data_type == ArrayDataType::kNone) {
// Yield until the output type has been set by PropagateArrayDataTypes.
- return false;
+ return ::tensorflow::Status::OK();
}
if (!output_array.has_shape()) {
// Yield until the output shape has been set by PropagateFixedShapes.
- return false;
+ return ::tensorflow::Status::OK();
}
// We require constant inputs.
if (!IsConstantParameterArray(*model, op->inputs[0]) ||
!IsConstantParameterArray(*model, op->inputs[1])) {
- return false;
+ return ::tensorflow::Status::OK();
}
const Array& input_array = model->GetArray(op->inputs[0]);
const Array& multiples_array = model->GetArray(op->inputs[1]);
@@ -159,7 +162,8 @@ bool ResolveConstantTile::Run(Model* model, std::size_t op_index) {
// Erase the operator.
model->operators.erase(it);
- return true;
+ *modified = true;
+ return ::tensorflow::Status::OK();
}
} // namespace toco
diff --git a/tensorflow/contrib/lite/toco/graph_transformations/resolve_constant_transpose.cc b/tensorflow/contrib/lite/toco/graph_transformations/resolve_constant_transpose.cc
index fe15dfa06f..b759c4d6dd 100644
--- a/tensorflow/contrib/lite/toco/graph_transformations/resolve_constant_transpose.cc
+++ b/tensorflow/contrib/lite/toco/graph_transformations/resolve_constant_transpose.cc
@@ -101,11 +101,14 @@ void Transpose(Model* model, const Array& input_array,
} // namespace
-bool ResolveConstantTranspose::Run(Model* model, std::size_t op_index) {
+::tensorflow::Status ResolveConstantTranspose::Run(Model* model,
+ std::size_t op_index,
+ bool* modified) {
+ *modified = false;
auto it = model->operators.begin() + op_index;
const auto* base_op = it->get();
if (base_op->type != OperatorType::kTranspose) {
- return false;
+ return ::tensorflow::Status::OK();
}
const auto* op = static_cast<const TransposeOperator*>(base_op);
@@ -114,17 +117,17 @@ bool ResolveConstantTranspose::Run(Model* model, std::size_t op_index) {
auto& output_array = model->GetArray(op->outputs[0]);
if (output_array.data_type == ArrayDataType::kNone) {
// Yield until the output type has been set by PropagateArrayDataTypes.
- return false;
+ return ::tensorflow::Status::OK();
}
if (!output_array.has_shape()) {
// Yield until the output shape has been set by PropagateFixedShapes.
- return false;
+ return ::tensorflow::Status::OK();
}
// We require constant inputs.
if (!IsConstantParameterArray(*model, op->inputs[0]) ||
!IsConstantParameterArray(*model, op->inputs[1])) {
- return false;
+ return ::tensorflow::Status::OK();
}
const Array& input_array = model->GetArray(op->inputs[0]);
@@ -132,7 +135,7 @@ bool ResolveConstantTranspose::Run(Model* model, std::size_t op_index) {
if (op->perm.empty()) {
// Yield until perm has been populated by ResolveTransposeAttributes.
- return false;
+ return ::tensorflow::Status::OK();
}
// We currently only support 1-4 dimensions.
@@ -174,7 +177,8 @@ bool ResolveConstantTranspose::Run(Model* model, std::size_t op_index) {
// Erase the operator.
model->operators.erase(it);
- return true;
+ *modified = true;
+ return ::tensorflow::Status::OK();
}
} // namespace toco
diff --git a/tensorflow/contrib/lite/toco/graph_transformations/resolve_constant_unary.cc b/tensorflow/contrib/lite/toco/graph_transformations/resolve_constant_unary.cc
index 5364eebbc9..3034c1b1eb 100644
--- a/tensorflow/contrib/lite/toco/graph_transformations/resolve_constant_unary.cc
+++ b/tensorflow/contrib/lite/toco/graph_transformations/resolve_constant_unary.cc
@@ -112,7 +112,10 @@ bool CopyMinMaxFromFirstInput(const Operator& op, Model* model) {
return true;
}
-bool ResolveConstantUnaryOperator::Run(Model* model, std::size_t op_index) {
+::tensorflow::Status ResolveConstantUnaryOperator::Run(Model* model,
+ std::size_t op_index,
+ bool* modified) {
+ *modified = false;
const auto unary_it = model->operators.begin() + op_index;
const auto* unary_op = unary_it->get();
// Test for unary ops of types that we know how to resolve.
@@ -133,28 +136,28 @@ bool ResolveConstantUnaryOperator::Run(Model* model, std::size_t op_index) {
case OperatorType::kRelu:
break;
default:
- return false;
+ return ::tensorflow::Status::OK();
}
// Check if the input is a constant parameter.
if (!IsConstantParameterArray(*model, unary_op->inputs[0])) {
- return false;
+ return ::tensorflow::Status::OK();
}
// if the unary op involves a tensor required by a rnn state, ignore it
for (const auto& rnn_state : model->flags.rnn_states()) {
if (unary_op->inputs[0] == rnn_state.back_edge_source_array()) {
- return false;
+ return ::tensorflow::Status::OK();
}
if (unary_op->inputs[0] == rnn_state.state_array()) {
- return false;
+ return ::tensorflow::Status::OK();
}
}
auto& output_array = model->GetArray(unary_op->outputs[0]);
if (!output_array.has_shape()) {
// Yield until the output array dims have been resolved.
- return false;
+ return ::tensorflow::Status::OK();
}
// At the moment we don't want to care about fused activation functions.
@@ -166,7 +169,7 @@ bool ResolveConstantUnaryOperator::Run(Model* model, std::size_t op_index) {
"Not resolving constant %s "
" because it has a fused activation function",
LogName(*unary_op));
- return false;
+ return ::tensorflow::Status::OK();
}
// The min-max is only copied for ops that copy data without arithmetic.
@@ -187,7 +190,7 @@ bool ResolveConstantUnaryOperator::Run(Model* model, std::size_t op_index) {
"Not resolving constant %s because we currently only support casting "
"to float",
LogName(*unary_op));
- return false;
+ return ::tensorflow::Status::OK();
}
if (cast_op->src_data_type != input_array.buffer->type) {
AddMessageF(
@@ -197,7 +200,7 @@ bool ResolveConstantUnaryOperator::Run(Model* model, std::size_t op_index) {
}
} else {
if (input_array.buffer->type != ArrayDataType::kFloat) {
- return false;
+ return ::tensorflow::Status::OK();
}
input_float_data = &(input_array.GetBuffer<ArrayDataType::kFloat>().data);
}
@@ -239,7 +242,7 @@ bool ResolveConstantUnaryOperator::Run(Model* model, std::size_t op_index) {
CHECK_EQ(unary_op->inputs.size(), 2) << "Sum needs 2 inputs";
if (!IsConstantParameterArray(*model, unary_op->inputs[1])) {
AddMessageF("Axis input is non-constant");
- return false;
+ return ::tensorflow::Status::OK();
}
auto& axis_array = model->GetArray(unary_op->inputs[1]);
CHECK(axis_array.data_type == ArrayDataType::kInt32);
@@ -336,7 +339,7 @@ bool ResolveConstantUnaryOperator::Run(Model* model, std::size_t op_index) {
default:
LOG(FATAL) << "Unsupported activation function "
<< LogName(*unary_op);
- return false;
+ return ::tensorflow::Status::OK();
}
output_float_data[i] = new_value;
}
@@ -351,7 +354,8 @@ bool ResolveConstantUnaryOperator::Run(Model* model, std::size_t op_index) {
AddMessageF("Resolved constant %s to the equivalent constant array",
LogName(*unary_op));
model->operators.erase(unary_it);
- return true;
+ *modified = true;
+ return ::tensorflow::Status::OK();
}
} // namespace toco
diff --git a/tensorflow/contrib/lite/toco/graph_transformations/resolve_fake_quant_args_from_vars.cc b/tensorflow/contrib/lite/toco/graph_transformations/resolve_fake_quant_args_from_vars.cc
index 0dda1fd0b3..eed971c1d5 100644
--- a/tensorflow/contrib/lite/toco/graph_transformations/resolve_fake_quant_args_from_vars.cc
+++ b/tensorflow/contrib/lite/toco/graph_transformations/resolve_fake_quant_args_from_vars.cc
@@ -25,17 +25,20 @@ limitations under the License.
namespace toco {
-bool ResolveFakeQuantArgsFromVars::Run(Model* model, std::size_t op_index) {
+::tensorflow::Status ResolveFakeQuantArgsFromVars::Run(Model* model,
+ std::size_t op_index,
+ bool* modified) {
+ *modified = false;
const auto fakequant_it = model->operators.begin() + op_index;
auto* fakequant_base_op = fakequant_it->get();
if (fakequant_base_op->type != OperatorType::kFakeQuant) {
- return false;
+ return ::tensorflow::Status::OK();
}
auto* fakequant_op = static_cast<FakeQuantOperator*>(fakequant_base_op);
if (fakequant_op->minmax) {
// Already resolved.
- return false;
+ return ::tensorflow::Status::OK();
}
CHECK_EQ(fakequant_op->inputs.size(), 3);
@@ -43,7 +46,7 @@ bool ResolveFakeQuantArgsFromVars::Run(Model* model, std::size_t op_index) {
// resolved to constant arrays.
for (int i = 1; i <= 2; i++) {
if (!IsConstantParameterArray(*model, fakequant_op->inputs[i])) {
- return false;
+ return ::tensorflow::Status::OK();
}
}
@@ -74,7 +77,8 @@ bool ResolveFakeQuantArgsFromVars::Run(Model* model, std::size_t op_index) {
DeleteArrayIfUsedOnce(fakequant_op->inputs[i], model);
}
fakequant_op->inputs.resize(1);
- return true;
+ *modified = true;
+ return ::tensorflow::Status::OK();
}
} // namespace toco
diff --git a/tensorflow/contrib/lite/toco/graph_transformations/resolve_gather_attributes.cc b/tensorflow/contrib/lite/toco/graph_transformations/resolve_gather_attributes.cc
index ce825c91af..69209b8dec 100644
--- a/tensorflow/contrib/lite/toco/graph_transformations/resolve_gather_attributes.cc
+++ b/tensorflow/contrib/lite/toco/graph_transformations/resolve_gather_attributes.cc
@@ -24,20 +24,25 @@ limitations under the License.
namespace toco {
-bool ResolveGatherAttributes::Run(Model* model, std::size_t op_index) {
+::tensorflow::Status ResolveGatherAttributes::Run(Model* model,
+ std::size_t op_index,
+ bool* modified) {
+ *modified = false;
auto* gather_op = model->operators[op_index].get();
- if (gather_op->type != OperatorType::kGather) return false;
+ if (gather_op->type != OperatorType::kGather)
+ return ::tensorflow::Status::OK();
auto* op = static_cast<GatherOperator*>(gather_op);
if (op->axis) {
// Attributes already resolved
- return false;
+ return ::tensorflow::Status::OK();
}
- if (op->inputs.size() != 3) return false;
- if (!IsConstantParameterArray(*model, op->inputs[2])) return false;
+ if (op->inputs.size() != 3) return ::tensorflow::Status::OK();
+ if (!IsConstantParameterArray(*model, op->inputs[2]))
+ return ::tensorflow::Status::OK();
const auto& indices_array = model->GetArray(op->inputs[2]);
- if (!indices_array.has_shape()) return false;
+ if (!indices_array.has_shape()) return ::tensorflow::Status::OK();
const auto& axis_data = indices_array.GetBuffer<ArrayDataType::kInt32>().data;
CHECK_EQ(axis_data.size(), 1)
<< "Multidimensional gather not supported on " << LogName(*op);
@@ -47,7 +52,8 @@ bool ResolveGatherAttributes::Run(Model* model, std::size_t op_index) {
DeleteArrayIfUsedOnce(op->inputs[2], model);
op->inputs.resize(2);
- return true;
+ *modified = true;
+ return ::tensorflow::Status::OK();
}
} // namespace toco
diff --git a/tensorflow/contrib/lite/toco/graph_transformations/resolve_multiply_by_zero.cc b/tensorflow/contrib/lite/toco/graph_transformations/resolve_multiply_by_zero.cc
index b2b2ea151b..ac94f45321 100644
--- a/tensorflow/contrib/lite/toco/graph_transformations/resolve_multiply_by_zero.cc
+++ b/tensorflow/contrib/lite/toco/graph_transformations/resolve_multiply_by_zero.cc
@@ -51,27 +51,30 @@ void FillArrayWithZeros(Array* array) {
// Removes a multiplication by array of constant zeros by making the output
// array an array of constant zeros and removing the input arrays if they are no
// longer needed.
-bool ResolveMultiplyByZero::Run(Model* model, std::size_t op_index) {
+::tensorflow::Status ResolveMultiplyByZero::Run(Model* model,
+ std::size_t op_index,
+ bool* modified) {
+ *modified = false;
const auto mul_it = model->operators.begin() + op_index;
auto* mul_op = mul_it->get();
if (mul_op->type != OperatorType::kMul) {
- return false;
+ return ::tensorflow::Status::OK();
}
const auto& output_array_name = mul_op->outputs[0];
auto& output_array = model->GetArray(output_array_name);
if (!IsDiscardableArray(*model, output_array_name)) {
- return false;
+ return ::tensorflow::Status::OK();
}
if (output_array.data_type == ArrayDataType::kNone) {
// Yield until the output type has been set by PropagateArrayDataTypes
- return false;
+ return ::tensorflow::Status::OK();
}
// Yield if the output shape is not known yet.
if (!output_array.has_shape()) {
- return false;
+ return ::tensorflow::Status::OK();
}
// This transformation only handles the case where one operand is all 0's and
@@ -83,12 +86,12 @@ bool ResolveMultiplyByZero::Run(Model* model, std::size_t op_index) {
};
if (!is_input_constant[0] && !is_input_constant[1]) {
// Neither input is constant, so nothing we can resolve here.
- return false;
+ return ::tensorflow::Status::OK();
}
if (is_input_constant[0] && is_input_constant[1]) {
// Both inputs are constants. That's a job for constants propagation, not
// for us to handle here.
- return false;
+ return ::tensorflow::Status::OK();
}
const int index_of_constant_input = is_input_constant[0] ? 0 : 1;
const int index_of_variable_input = is_input_constant[0] ? 1 : 0;
@@ -105,7 +108,7 @@ bool ResolveMultiplyByZero::Run(Model* model, std::size_t op_index) {
constant_input_array.GetBuffer<ArrayDataType::kFloat>().data;
if (!AreAllBufferElementsZero<DataType<ArrayDataType::kFloat>>(
constant_input_data)) {
- return false;
+ return ::tensorflow::Status::OK();
}
FillArrayWithZeros<ArrayDataType::kFloat>(&output_array);
} break;
@@ -114,7 +117,7 @@ bool ResolveMultiplyByZero::Run(Model* model, std::size_t op_index) {
constant_input_array.GetBuffer<ArrayDataType::kUint8>().data;
if (!AreAllBufferElementsZero<DataType<ArrayDataType::kUint8>>(
constant_input_data)) {
- return false;
+ return ::tensorflow::Status::OK();
}
FillArrayWithZeros<ArrayDataType::kUint8>(&output_array);
} break;
@@ -123,7 +126,7 @@ bool ResolveMultiplyByZero::Run(Model* model, std::size_t op_index) {
constant_input_array.GetBuffer<ArrayDataType::kInt32>().data;
if (!AreAllBufferElementsZero<DataType<ArrayDataType::kInt32>>(
constant_input_data)) {
- return false;
+ return ::tensorflow::Status::OK();
}
FillArrayWithZeros<ArrayDataType::kInt32>(&output_array);
} break;
@@ -132,14 +135,14 @@ bool ResolveMultiplyByZero::Run(Model* model, std::size_t op_index) {
constant_input_array.GetBuffer<ArrayDataType::kInt64>().data;
if (!AreAllBufferElementsZero<DataType<ArrayDataType::kInt64>>(
constant_input_data)) {
- return false;
+ return ::tensorflow::Status::OK();
}
FillArrayWithZeros<ArrayDataType::kInt64>(&output_array);
} break;
default:
AddMessageF(
"Cannot resolve multiply by 0 because of unsupported data type\n");
- return false;
+ return ::tensorflow::Status::OK();
}
// Erase input arrays to the multiply if no longer used
@@ -149,7 +152,8 @@ bool ResolveMultiplyByZero::Run(Model* model, std::size_t op_index) {
// Erase the multiply operator.
model->operators.erase(mul_it);
- return true;
+ *modified = true;
+ return ::tensorflow::Status::OK();
}
} // namespace toco
diff --git a/tensorflow/contrib/lite/toco/graph_transformations/resolve_pad_attributes.cc b/tensorflow/contrib/lite/toco/graph_transformations/resolve_pad_attributes.cc
index 8a8e723cf7..adc87753bc 100644
--- a/tensorflow/contrib/lite/toco/graph_transformations/resolve_pad_attributes.cc
+++ b/tensorflow/contrib/lite/toco/graph_transformations/resolve_pad_attributes.cc
@@ -24,19 +24,23 @@ limitations under the License.
namespace toco {
-bool ResolvePadAttributes::Run(Model* model, std::size_t op_index) {
+::tensorflow::Status ResolvePadAttributes::Run(Model* model,
+ std::size_t op_index,
+ bool* modified) {
+ *modified = false;
const auto pad_it = model->operators.begin() + op_index;
auto* pad_op = pad_it->get();
- if (pad_op->type != OperatorType::kPad) return false;
+ if (pad_op->type != OperatorType::kPad) return ::tensorflow::Status::OK();
auto* op = static_cast<PadOperator*>(pad_op);
- if (!op->left_padding.empty()) return false;
+ if (!op->left_padding.empty()) return ::tensorflow::Status::OK();
CHECK_EQ(op->inputs.size(), 2);
- if (!IsConstantParameterArray(*model, op->inputs[1])) return false;
+ if (!IsConstantParameterArray(*model, op->inputs[1]))
+ return ::tensorflow::Status::OK();
const auto& array = model->GetArray(op->inputs[1]);
- if (!array.has_shape()) return false;
+ if (!array.has_shape()) return ::tensorflow::Status::OK();
const std::vector<int>& dims = array.shape().dims();
CHECK_EQ(dims.size(), 2);
@@ -50,6 +54,7 @@ bool ResolvePadAttributes::Run(Model* model, std::size_t op_index) {
// TODO(dkalenichenko): Delete the extra input?
- return true;
+ *modified = true;
+ return ::tensorflow::Status::OK();
}
} // namespace toco
diff --git a/tensorflow/contrib/lite/toco/graph_transformations/resolve_padv2_attributes.cc b/tensorflow/contrib/lite/toco/graph_transformations/resolve_padv2_attributes.cc
index ebb023e342..1f0f17a37a 100644
--- a/tensorflow/contrib/lite/toco/graph_transformations/resolve_padv2_attributes.cc
+++ b/tensorflow/contrib/lite/toco/graph_transformations/resolve_padv2_attributes.cc
@@ -24,19 +24,23 @@ limitations under the License.
namespace toco {
-bool ResolvePadV2Attributes::Run(Model* model, std::size_t op_index) {
+::tensorflow::Status ResolvePadV2Attributes::Run(Model* model,
+ std::size_t op_index,
+ bool* modified) {
+ *modified = false;
const auto pad_it = model->operators.begin() + op_index;
auto* pad_op = pad_it->get();
- if (pad_op->type != OperatorType::kPadV2) return false;
+ if (pad_op->type != OperatorType::kPadV2) return ::tensorflow::Status::OK();
auto* op = static_cast<PadV2Operator*>(pad_op);
- if (!op->left_padding.empty()) return false;
+ if (!op->left_padding.empty()) return ::tensorflow::Status::OK();
CHECK_EQ(op->inputs.size(), 3);
- if (!IsConstantParameterArray(*model, op->inputs[1])) return false;
+ if (!IsConstantParameterArray(*model, op->inputs[1]))
+ return ::tensorflow::Status::OK();
const auto& array = model->GetArray(op->inputs[1]);
- if (!array.has_shape()) return false;
+ if (!array.has_shape()) return ::tensorflow::Status::OK();
const std::vector<int>& dims = array.shape().dims();
CHECK_EQ(dims.size(), 2);
@@ -50,6 +54,7 @@ bool ResolvePadV2Attributes::Run(Model* model, std::size_t op_index) {
// TODO(dkalenichenko): Delete the extra input?
- return true;
+ *modified = true;
+ return ::tensorflow::Status::OK();
}
} // namespace toco
diff --git a/tensorflow/contrib/lite/toco/graph_transformations/resolve_reduce_attributes.cc b/tensorflow/contrib/lite/toco/graph_transformations/resolve_reduce_attributes.cc
index 73198ac7c0..c3246ab90f 100644
--- a/tensorflow/contrib/lite/toco/graph_transformations/resolve_reduce_attributes.cc
+++ b/tensorflow/contrib/lite/toco/graph_transformations/resolve_reduce_attributes.cc
@@ -39,23 +39,37 @@ bool ResolveAttributes(Model* model, T* op) {
return true;
}
-bool ResolveReduceAttributes::Run(Model* model, std::size_t op_index) {
+::tensorflow::Status ResolveReduceAttributes::Run(Model* model,
+ std::size_t op_index,
+ bool* modified) {
+ *modified = false;
Operator* op = model->operators[op_index].get();
switch (op->type) {
case OperatorType::kMean:
- return ResolveAttributes(model, static_cast<MeanOperator*>(op));
+ *modified = ResolveAttributes(model, static_cast<MeanOperator*>(op));
+ return ::tensorflow::Status::OK();
case OperatorType::kSum:
- return ResolveAttributes(model, static_cast<TensorFlowSumOperator*>(op));
+ *modified =
+ ResolveAttributes(model, static_cast<TensorFlowSumOperator*>(op));
+ return ::tensorflow::Status::OK();
case OperatorType::kReduceProd:
- return ResolveAttributes(model, static_cast<TensorFlowProdOperator*>(op));
+ *modified =
+ ResolveAttributes(model, static_cast<TensorFlowProdOperator*>(op));
+ return ::tensorflow::Status::OK();
case OperatorType::kReduceMin:
- return ResolveAttributes(model, static_cast<TensorFlowMinOperator*>(op));
+ *modified =
+ ResolveAttributes(model, static_cast<TensorFlowMinOperator*>(op));
+ return ::tensorflow::Status::OK();
case OperatorType::kReduceMax:
- return ResolveAttributes(model, static_cast<TensorFlowMaxOperator*>(op));
+ *modified =
+ ResolveAttributes(model, static_cast<TensorFlowMaxOperator*>(op));
+ return ::tensorflow::Status::OK();
case OperatorType::kAny:
- return ResolveAttributes(model, static_cast<TensorFlowMaxOperator*>(op));
+ *modified =
+ ResolveAttributes(model, static_cast<TensorFlowMaxOperator*>(op));
+ return ::tensorflow::Status::OK();
default:
- return false;
+ return ::tensorflow::Status::OK();
}
}
diff --git a/tensorflow/contrib/lite/toco/graph_transformations/resolve_reorder_axes.cc b/tensorflow/contrib/lite/toco/graph_transformations/resolve_reorder_axes.cc
index 8e150db6fa..ee5c4810e6 100644
--- a/tensorflow/contrib/lite/toco/graph_transformations/resolve_reorder_axes.cc
+++ b/tensorflow/contrib/lite/toco/graph_transformations/resolve_reorder_axes.cc
@@ -78,11 +78,13 @@ void ReorderAxes(AxesOrder input_axes_order, AxesOrder output_axes_order,
}
}
-bool ResolveReorderAxes::Run(Model* model, std::size_t op_index) {
+::tensorflow::Status ResolveReorderAxes::Run(Model* model, std::size_t op_index,
+ bool* modified) {
+ *modified = false;
auto it = model->operators.begin() + op_index;
auto* op = it->get();
if (op->type != OperatorType::kReorderAxes) {
- return false;
+ return ::tensorflow::Status::OK();
}
auto* reorder_op = static_cast<ReorderAxesOperator*>(op);
@@ -93,11 +95,11 @@ bool ResolveReorderAxes::Run(Model* model, std::size_t op_index) {
auto& input_array = model->GetArray(input_array_name);
auto& output_array = model->GetArray(output_array_name);
if (!input_array.buffer) {
- return false;
+ return ::tensorflow::Status::OK();
}
// Yield until output dims have been resolved.
if (!output_array.has_shape()) {
- return false;
+ return ::tensorflow::Status::OK();
}
// Reorder the input array dims and buffer data
if (input_array.buffer->type == ArrayDataType::kFloat) {
@@ -120,7 +122,8 @@ bool ResolveReorderAxes::Run(Model* model, std::size_t op_index) {
DeleteOpAndArraysIfUnused(model, op);
RenameArray(model, output_array_name, input_array_name);
- return true;
+ *modified = true;
+ return ::tensorflow::Status::OK();
}
} // namespace toco
diff --git a/tensorflow/contrib/lite/toco/graph_transformations/resolve_reshape_attributes.cc b/tensorflow/contrib/lite/toco/graph_transformations/resolve_reshape_attributes.cc
index b615c9a545..7b7a59264f 100644
--- a/tensorflow/contrib/lite/toco/graph_transformations/resolve_reshape_attributes.cc
+++ b/tensorflow/contrib/lite/toco/graph_transformations/resolve_reshape_attributes.cc
@@ -25,25 +25,29 @@ limitations under the License.
namespace toco {
-bool ResolveReshapeAttributes::Run(Model* model, std::size_t op_index) {
+::tensorflow::Status ResolveReshapeAttributes::Run(Model* model,
+ std::size_t op_index,
+ bool* modified) {
+ *modified = false;
const auto reshape_it = model->operators.begin() + op_index;
auto* reshape_op = reshape_it->get();
if (reshape_op->type != OperatorType::kReshape) {
- return false;
+ return ::tensorflow::Status::OK();
}
auto* op = static_cast<TensorFlowReshapeOperator*>(reshape_op);
- if (!op->shape.empty()) return false;
+ if (!op->shape.empty()) return ::tensorflow::Status::OK();
if (IsConstantParameterArray(*model, reshape_op->inputs[1])) {
const auto& constant_input_array = model->GetArray(reshape_op->inputs[1]);
op->shape = constant_input_array.GetBuffer<ArrayDataType::kInt32>().data;
}
- if (op->shape.empty()) return false;
+ if (op->shape.empty()) return ::tensorflow::Status::OK();
- return true;
+ *modified = true;
+ return ::tensorflow::Status::OK();
}
} // namespace toco
diff --git a/tensorflow/contrib/lite/toco/graph_transformations/resolve_slice_attributes.cc b/tensorflow/contrib/lite/toco/graph_transformations/resolve_slice_attributes.cc
index e760d08e5a..5a838168de 100644
--- a/tensorflow/contrib/lite/toco/graph_transformations/resolve_slice_attributes.cc
+++ b/tensorflow/contrib/lite/toco/graph_transformations/resolve_slice_attributes.cc
@@ -24,29 +24,35 @@ limitations under the License.
namespace toco {
-bool ResolveSliceAttributes::Run(Model* model, std::size_t op_index) {
+::tensorflow::Status ResolveSliceAttributes::Run(Model* model,
+ std::size_t op_index,
+ bool* modified) {
+ *modified = false;
const auto slice_it = model->operators.begin() + op_index;
auto* slice_op = slice_it->get();
- if (slice_op->type != OperatorType::kSlice) return false;
+ if (slice_op->type != OperatorType::kSlice) return ::tensorflow::Status::OK();
auto* op = static_cast<SliceOperator*>(slice_op);
- if (!op->begin.empty()) return false;
+ if (!op->begin.empty()) return ::tensorflow::Status::OK();
CHECK_EQ(op->inputs.size(), 3);
- if (!IsConstantParameterArray(*model, op->inputs[1])) return false;
- if (!IsConstantParameterArray(*model, op->inputs[2])) return false;
+ if (!IsConstantParameterArray(*model, op->inputs[1]))
+ return ::tensorflow::Status::OK();
+ if (!IsConstantParameterArray(*model, op->inputs[2]))
+ return ::tensorflow::Status::OK();
const auto& begin_array = model->GetArray(op->inputs[1]);
- if (!begin_array.has_shape()) return false;
+ if (!begin_array.has_shape()) return ::tensorflow::Status::OK();
const auto& size_array = model->GetArray(op->inputs[2]);
- if (!size_array.has_shape()) return false;
+ if (!size_array.has_shape()) return ::tensorflow::Status::OK();
op->begin = begin_array.GetBuffer<ArrayDataType::kInt32>().data;
op->size = size_array.GetBuffer<ArrayDataType::kInt32>().data;
// TODO(dkalenichenko): Delete the extra inputs?
- return true;
+ *modified = true;
+ return ::tensorflow::Status::OK();
}
} // namespace toco
diff --git a/tensorflow/contrib/lite/toco/graph_transformations/resolve_space_to_batch_nd_attributes.cc b/tensorflow/contrib/lite/toco/graph_transformations/resolve_space_to_batch_nd_attributes.cc
index fab50bec1f..3804145c4f 100644
--- a/tensorflow/contrib/lite/toco/graph_transformations/resolve_space_to_batch_nd_attributes.cc
+++ b/tensorflow/contrib/lite/toco/graph_transformations/resolve_space_to_batch_nd_attributes.cc
@@ -24,16 +24,20 @@ limitations under the License.
namespace toco {
-bool ResolveSpaceToBatchNDAttributes::Run(Model* model, std::size_t op_index) {
+::tensorflow::Status ResolveSpaceToBatchNDAttributes::Run(Model* model,
+ std::size_t op_index,
+ bool* modified) {
+ *modified = false;
const auto op_it = model->operators.begin() + op_index;
- if (op_it->get()->type != OperatorType::kSpaceToBatchND) return false;
+ if (op_it->get()->type != OperatorType::kSpaceToBatchND)
+ return ::tensorflow::Status::OK();
auto* op = static_cast<SpaceToBatchNDOperator*>(op_it->get());
// The attributes are resolved only when the 3 attributes (block_shape,
// before_paddings, after_paddings) are all constant.
if (!op->block_shape.empty()) {
- return false;
+ return ::tensorflow::Status::OK();
}
const int block_shape_index = 1;
@@ -42,16 +46,16 @@ bool ResolveSpaceToBatchNDAttributes::Run(Model* model, std::size_t op_index) {
CHECK_EQ(op->inputs.size(), 3);
if (!IsConstantParameterArray(*model, op->inputs[block_shape_index]) ||
!IsConstantParameterArray(*model, op->inputs[paddings_index]))
- return false;
+ return ::tensorflow::Status::OK();
// Handle paddings.
const auto& paddings_array = model->GetArray(op->inputs[paddings_index]);
- if (!paddings_array.has_shape()) return false;
+ if (!paddings_array.has_shape()) return ::tensorflow::Status::OK();
const std::vector<int>& paddings_dims = paddings_array.shape().dims();
if (paddings_dims.size() != 2) {
// Code only handles padding of 2 dimensions. Perhaps another transformation
// will delete this op.
- return false;
+ return ::tensorflow::Status::OK();
}
const std::vector<int>& paddings_buffer =
paddings_array.GetBuffer<ArrayDataType::kInt32>().data;
@@ -63,7 +67,7 @@ bool ResolveSpaceToBatchNDAttributes::Run(Model* model, std::size_t op_index) {
// Handle block_shape.
const auto& block_shape_array =
model->GetArray(op->inputs[block_shape_index]);
- if (!block_shape_array.has_shape()) return false;
+ if (!block_shape_array.has_shape()) return ::tensorflow::Status::OK();
const std::vector<int>& block_shape_dims = block_shape_array.shape().dims();
CHECK_EQ(block_shape_dims.size(), 1);
const std::vector<int>& block_shape_buffer =
@@ -72,7 +76,8 @@ bool ResolveSpaceToBatchNDAttributes::Run(Model* model, std::size_t op_index) {
op->block_shape.push_back(block_shape_buffer[i]);
}
- return true;
+ *modified = true;
+ return ::tensorflow::Status::OK();
}
} // namespace toco
diff --git a/tensorflow/contrib/lite/toco/graph_transformations/resolve_squeeze_attributes.cc b/tensorflow/contrib/lite/toco/graph_transformations/resolve_squeeze_attributes.cc
index e8bb85704e..c601b0774e 100644
--- a/tensorflow/contrib/lite/toco/graph_transformations/resolve_squeeze_attributes.cc
+++ b/tensorflow/contrib/lite/toco/graph_transformations/resolve_squeeze_attributes.cc
@@ -25,10 +25,13 @@ limitations under the License.
namespace toco {
-bool ResolveSqueezeAttributes::Run(Model* model, std::size_t op_index) {
+::tensorflow::Status ResolveSqueezeAttributes::Run(Model* model,
+ std::size_t op_index,
+ bool* modified) {
+ *modified = false;
auto* squeeze_op = model->operators[op_index].get();
if (squeeze_op->type != OperatorType::kSqueeze) {
- return false;
+ return ::tensorflow::Status::OK();
}
DCHECK_EQ(squeeze_op->inputs.size(), 1);
DCHECK_EQ(squeeze_op->outputs.size(), 1);
@@ -42,10 +45,11 @@ bool ResolveSqueezeAttributes::Run(Model* model, std::size_t op_index) {
"Reshape op",
LogName(*squeeze_op));
- return RemoveTrivialPassthroughOp(this, model, op_index);
+ *modified = RemoveTrivialPassthroughOp(this, model, op_index);
+ return ::tensorflow::Status::OK();
}
}
- return false;
+ return ::tensorflow::Status::OK();
}
} // namespace toco
diff --git a/tensorflow/contrib/lite/toco/graph_transformations/resolve_strided_slice_attributes.cc b/tensorflow/contrib/lite/toco/graph_transformations/resolve_strided_slice_attributes.cc
index 65132d7d1e..f54f5b42a1 100644
--- a/tensorflow/contrib/lite/toco/graph_transformations/resolve_strided_slice_attributes.cc
+++ b/tensorflow/contrib/lite/toco/graph_transformations/resolve_strided_slice_attributes.cc
@@ -37,40 +37,47 @@ int PadAttributeArray(Array* attribute_array, std::vector<int> pad_values,
return mask;
}
-bool ResolveStridedSliceAttributes::Run(Model* model, std::size_t op_index) {
+::tensorflow::Status ResolveStridedSliceAttributes::Run(Model* model,
+ std::size_t op_index,
+ bool* modified) {
+ *modified = false;
const auto slice_it = model->operators.begin() + op_index;
auto* slice_op = slice_it->get();
- if (slice_op->type != OperatorType::kStridedSlice) return false;
+ if (slice_op->type != OperatorType::kStridedSlice)
+ return ::tensorflow::Status::OK();
auto* op = static_cast<StridedSliceOperator*>(slice_op);
if (!op->start_indices.empty()) {
// We have already resolved these attributes
- return false;
+ return ::tensorflow::Status::OK();
}
CHECK_EQ(op->inputs.size(), 4);
const auto& input_array = model->GetArray(op->inputs[0]);
if (!input_array.has_shape()) {
// We require the dimensionality of the input to pad the indices
- return false;
+ return ::tensorflow::Status::OK();
}
auto& start_array = model->GetArray(op->inputs[1]);
- if (!start_array.has_shape()) return false;
+ if (!start_array.has_shape()) return ::tensorflow::Status::OK();
if (toco::RequiredBufferSizeForShape(start_array.shape()) > 4) {
// Only 1-4D arrays are supported for now.
- return false;
+ return ::tensorflow::Status::OK();
}
auto& stop_array = model->GetArray(op->inputs[2]);
- if (!stop_array.has_shape()) return false;
+ if (!stop_array.has_shape()) return ::tensorflow::Status::OK();
auto& stride_array = model->GetArray(op->inputs[3]);
- if (!stride_array.has_shape()) return false;
+ if (!stride_array.has_shape()) return ::tensorflow::Status::OK();
- if (!IsConstantParameterArray(*model, op->inputs[1])) return false;
- if (!IsConstantParameterArray(*model, op->inputs[2])) return false;
- if (!IsConstantParameterArray(*model, op->inputs[3])) return false;
+ if (!IsConstantParameterArray(*model, op->inputs[1]))
+ return ::tensorflow::Status::OK();
+ if (!IsConstantParameterArray(*model, op->inputs[2]))
+ return ::tensorflow::Status::OK();
+ if (!IsConstantParameterArray(*model, op->inputs[3]))
+ return ::tensorflow::Status::OK();
int num_input_axes = input_array.shape().dimensions_count();
int start_indices_size = start_array.shape().dims(0);
@@ -112,6 +119,7 @@ bool ResolveStridedSliceAttributes::Run(Model* model, std::size_t op_index) {
op->stop_indices = stop_array.GetBuffer<ArrayDataType::kInt32>().data;
op->strides = stride_array.GetBuffer<ArrayDataType::kInt32>().data;
- return true;
+ *modified = true;
+ return ::tensorflow::Status::OK();
}
} // namespace toco
diff --git a/tensorflow/contrib/lite/toco/graph_transformations/resolve_tensorflow_concat.cc b/tensorflow/contrib/lite/toco/graph_transformations/resolve_tensorflow_concat.cc
index fa5ee89933..4927ccd95d 100644
--- a/tensorflow/contrib/lite/toco/graph_transformations/resolve_tensorflow_concat.cc
+++ b/tensorflow/contrib/lite/toco/graph_transformations/resolve_tensorflow_concat.cc
@@ -25,12 +25,15 @@ limitations under the License.
namespace toco {
-bool ResolveTensorFlowConcat::Run(Model* model, std::size_t op_index) {
+::tensorflow::Status ResolveTensorFlowConcat::Run(Model* model,
+ std::size_t op_index,
+ bool* modified) {
+ *modified = false;
auto concat_it = model->operators.begin() + op_index;
const auto* tf_concat_op = concat_it->get();
if (tf_concat_op->type != OperatorType::kConcat &&
tf_concat_op->type != OperatorType::kConcatV2) {
- return false;
+ return ::tensorflow::Status::OK();
}
CHECK_GE(tf_concat_op->inputs.size(), 2);
@@ -54,7 +57,7 @@ bool ResolveTensorFlowConcat::Run(Model* model, std::size_t op_index) {
if (!axis_array.buffer) {
AddMessageF("Waiting for the axis of %s to be resolved to a constant",
LogName(*tf_concat_op));
- return false;
+ return ::tensorflow::Status::OK();
}
CHECK(axis_array.data_type == ArrayDataType::kInt32);
@@ -79,7 +82,8 @@ bool ResolveTensorFlowConcat::Run(Model* model, std::size_t op_index) {
}
// Remove the TensorFlowConcat op
model->operators.erase(concat_it);
- return true;
+ *modified = true;
+ return ::tensorflow::Status::OK();
}
} // namespace toco
diff --git a/tensorflow/contrib/lite/toco/graph_transformations/resolve_tensorflow_matmul.cc b/tensorflow/contrib/lite/toco/graph_transformations/resolve_tensorflow_matmul.cc
index 65346c4fe4..da039da546 100644
--- a/tensorflow/contrib/lite/toco/graph_transformations/resolve_tensorflow_matmul.cc
+++ b/tensorflow/contrib/lite/toco/graph_transformations/resolve_tensorflow_matmul.cc
@@ -55,10 +55,13 @@ TransposeOperator* FindTransposeOpWithInput(const Model& model,
} // namespace
-bool ResolveTensorFlowMatMul::Run(Model* model, std::size_t op_index) {
+::tensorflow::Status ResolveTensorFlowMatMul::Run(Model* model,
+ std::size_t op_index,
+ bool* modified) {
+ *modified = false;
auto matmul_it = model->operators.begin() + op_index;
if (matmul_it->get()->type != OperatorType::kMatMul) {
- return false;
+ return ::tensorflow::Status::OK();
}
const auto* matmul_op =
static_cast<const TensorFlowMatMulOperator*>(matmul_it->get());
@@ -73,7 +76,7 @@ bool ResolveTensorFlowMatMul::Run(Model* model, std::size_t op_index) {
"Not replacing %s by a FullyConnected operator, because it has "
"the transpose_a attribute",
LogName(*matmul_op));
- return false;
+ return ::tensorflow::Status::OK();
}
// Reorder the axes on the second input. TensorFlow uses row-major ordering
@@ -198,7 +201,8 @@ bool ResolveTensorFlowMatMul::Run(Model* model, std::size_t op_index) {
// erase the MatMul operator
model->operators.erase(matmul_it);
- return true;
+ *modified = true;
+ return ::tensorflow::Status::OK();
}
} // namespace toco
diff --git a/tensorflow/contrib/lite/toco/graph_transformations/resolve_tensorflow_merge.cc b/tensorflow/contrib/lite/toco/graph_transformations/resolve_tensorflow_merge.cc
index 4edffe3d48..9beea3e937 100644
--- a/tensorflow/contrib/lite/toco/graph_transformations/resolve_tensorflow_merge.cc
+++ b/tensorflow/contrib/lite/toco/graph_transformations/resolve_tensorflow_merge.cc
@@ -24,11 +24,14 @@ limitations under the License.
namespace toco {
-bool ResolveTensorFlowMerge::Run(Model* model, std::size_t op_index) {
+::tensorflow::Status ResolveTensorFlowMerge::Run(Model* model,
+ std::size_t op_index,
+ bool* modified) {
+ *modified = false;
const auto merge_it = model->operators.begin() + op_index;
const auto* merge_op = merge_it->get();
if (merge_op->type != OperatorType::kMerge) {
- return false;
+ return ::tensorflow::Status::OK();
}
// We need to yield until this Merge node has only 1 input, which will mean
@@ -37,7 +40,7 @@ bool ResolveTensorFlowMerge::Run(Model* model, std::size_t op_index) {
// non-selected inputs, so that at some point there will be only 1 input left.
if (merge_op->inputs.size() > 1) {
AddMessageF("Waiting for %s to be resolved", LogName(*merge_op));
- return false;
+ return ::tensorflow::Status::OK();
}
// Now that the merge node has 1 input exactly, it is the same as an Identity
@@ -57,7 +60,8 @@ bool ResolveTensorFlowMerge::Run(Model* model, std::size_t op_index) {
AddMessageF("Removing already-resolved %s", LogName(*merge_op));
model->EraseArray(merge_op->outputs[0]);
model->operators.erase(merge_it);
- return true;
+ *modified = true;
+ return ::tensorflow::Status::OK();
}
} // namespace toco
diff --git a/tensorflow/contrib/lite/toco/graph_transformations/resolve_tensorflow_switch.cc b/tensorflow/contrib/lite/toco/graph_transformations/resolve_tensorflow_switch.cc
index 8bef440afd..e215981b42 100644
--- a/tensorflow/contrib/lite/toco/graph_transformations/resolve_tensorflow_switch.cc
+++ b/tensorflow/contrib/lite/toco/graph_transformations/resolve_tensorflow_switch.cc
@@ -24,11 +24,14 @@ limitations under the License.
namespace toco {
-bool ResolveTensorFlowSwitch::Run(Model* model, std::size_t op_index) {
+::tensorflow::Status ResolveTensorFlowSwitch::Run(Model* model,
+ std::size_t op_index,
+ bool* modified) {
+ *modified = false;
const auto switch_it = model->operators.begin() + op_index;
const auto* switch_op = switch_it->get();
if (switch_op->type != OperatorType::kSwitch) {
- return false;
+ return ::tensorflow::Status::OK();
}
CHECK_EQ(switch_op->inputs.size(), 2);
@@ -40,7 +43,7 @@ bool ResolveTensorFlowSwitch::Run(Model* model, std::size_t op_index) {
AddMessageF(
"Waiting for the boolean predicate of %s to be resolved to a constant",
LogName(*switch_op));
- return false;
+ return ::tensorflow::Status::OK();
}
// The predicate should be boolean, and should consist of a single value.
@@ -119,7 +122,8 @@ bool ResolveTensorFlowSwitch::Run(Model* model, std::size_t op_index) {
// Remove the switch node itself.
AddMessageF("Removing already-resolved %s", LogName(*switch_op));
model->operators.erase(switch_it);
- return true;
+ *modified = true;
+ return ::tensorflow::Status::OK();
}
} // namespace toco
diff --git a/tensorflow/contrib/lite/toco/graph_transformations/resolve_transpose_attributes.cc b/tensorflow/contrib/lite/toco/graph_transformations/resolve_transpose_attributes.cc
index a657ee00af..aa7945391c 100644
--- a/tensorflow/contrib/lite/toco/graph_transformations/resolve_transpose_attributes.cc
+++ b/tensorflow/contrib/lite/toco/graph_transformations/resolve_transpose_attributes.cc
@@ -24,19 +24,24 @@ limitations under the License.
namespace toco {
-bool ResolveTransposeAttributes::Run(Model* model, std::size_t op_index) {
+::tensorflow::Status ResolveTransposeAttributes::Run(Model* model,
+ std::size_t op_index,
+ bool* modified) {
+ *modified = false;
const auto op_it = model->operators.begin() + op_index;
- if (op_it->get()->type != OperatorType::kTranspose) return false;
+ if (op_it->get()->type != OperatorType::kTranspose)
+ return ::tensorflow::Status::OK();
auto* op = static_cast<TransposeOperator*>(op_it->get());
- if (!op->perm.empty()) return false;
+ if (!op->perm.empty()) return ::tensorflow::Status::OK();
CHECK_EQ(op->inputs.size(), 2);
- if (!IsConstantParameterArray(*model, op->inputs[1])) return false;
+ if (!IsConstantParameterArray(*model, op->inputs[1]))
+ return ::tensorflow::Status::OK();
// Handling perm.
const auto& perm_array = model->GetArray(op->inputs[1]);
- if (!perm_array.has_shape()) return false;
+ if (!perm_array.has_shape()) return ::tensorflow::Status::OK();
const std::vector<int>& perm_dims = perm_array.shape().dims();
CHECK_EQ(perm_dims.size(), 1);
@@ -47,7 +52,8 @@ bool ResolveTransposeAttributes::Run(Model* model, std::size_t op_index) {
op->perm.push_back(perm_buffer[i]);
}
- return true;
+ *modified = true;
+ return ::tensorflow::Status::OK();
}
} // namespace toco
diff --git a/tensorflow/contrib/lite/toco/graph_transformations/shuffle_fc_weights.cc b/tensorflow/contrib/lite/toco/graph_transformations/shuffle_fc_weights.cc
index 22c258cec5..e9f24a29ab 100644
--- a/tensorflow/contrib/lite/toco/graph_transformations/shuffle_fc_weights.cc
+++ b/tensorflow/contrib/lite/toco/graph_transformations/shuffle_fc_weights.cc
@@ -24,15 +24,17 @@ limitations under the License.
namespace toco {
-bool ShuffleFCWeights::Run(Model* model, std::size_t op_index) {
+::tensorflow::Status ShuffleFCWeights::Run(Model* model, std::size_t op_index,
+ bool* modified) {
+ *modified = false;
Operator* op = model->operators[op_index].get();
if (op->type != OperatorType::kFullyConnected) {
- return false;
+ return ::tensorflow::Status::OK();
}
FullyConnectedOperator* fc_op = static_cast<FullyConnectedOperator*>(op);
// Exit if this FC op already has shuffled weights
if (fc_op->weights_format != FullyConnectedWeightsFormat::kDefault) {
- return false;
+ return ::tensorflow::Status::OK();
}
const Array& input_array = model->GetArray(fc_op->inputs[0]);
const string& weights_name = fc_op->inputs[1];
@@ -46,11 +48,11 @@ bool ShuffleFCWeights::Run(Model* model, std::size_t op_index) {
output_array.data_type != ArrayDataType::kInt16 ||
!input_array.quantization_params || !weights_array.quantization_params ||
!output_array.quantization_params) {
- return false;
+ return ::tensorflow::Status::OK();
}
// Exit if the shapes aren't known
if (!input_array.has_shape() || !weights_array.has_shape()) {
- return false;
+ return ::tensorflow::Status::OK();
}
// Exit if, based on the known shapes, this FC op is not a GEMV.
// The shuffling of FC weights is only useful to enable fast GEMV paths.
@@ -64,7 +66,7 @@ bool ShuffleFCWeights::Run(Model* model, std::size_t op_index) {
"the input shape is not 1D or 2D (possibly with additional inner "
"dimensions of size 1)",
LogName(*op));
- return false;
+ return ::tensorflow::Status::OK();
}
}
if (input_shape.dims(0) != 1 && input_shape.dims(0) != 4) {
@@ -73,7 +75,7 @@ bool ShuffleFCWeights::Run(Model* model, std::size_t op_index) {
"the input shape's leading dimension, i.e. the 'batch size', is not "
"equal to 1 or 4",
LogName(*op));
- return false;
+ return ::tensorflow::Status::OK();
}
// Exit if the weights shape isn't an integral multiple of the shuffled
// block shape, 4x16. We don't want to have to write code dealing with
@@ -88,7 +90,7 @@ bool ShuffleFCWeights::Run(Model* model, std::size_t op_index) {
// two.
const Shape& weights_shape = weights_array.shape();
if (weights_shape.dimensions_count() != 2) {
- return false;
+ return ::tensorflow::Status::OK();
}
const int rows = weights_shape.dims(0);
const int cols = weights_shape.dims(1);
@@ -97,11 +99,11 @@ bool ShuffleFCWeights::Run(Model* model, std::size_t op_index) {
"Not applying experimental shuffling to the weights of %s because its "
"shape isn't a multiple of the shuffling block shape, 4x16",
LogName(*op));
- return false;
+ return ::tensorflow::Status::OK();
}
// Exit if the weights aren't already a constant array.
if (!weights_array.buffer) {
- return false;
+ return ::tensorflow::Status::OK();
}
// Exit if the weights are used by more than one op.
if (CountOpsWithInput(*model, weights_name) != 1) {
@@ -109,7 +111,7 @@ bool ShuffleFCWeights::Run(Model* model, std::size_t op_index) {
"Not applying experimental shuffling to the weights of %s because that "
"array is consumed by other operators",
LogName(*op));
- return false;
+ return ::tensorflow::Status::OK();
}
// Compute the shuffled weights
auto& weights_data =
@@ -152,7 +154,8 @@ bool ShuffleFCWeights::Run(Model* model, std::size_t op_index) {
shuffled_input_workspace_array.GetOrCreateQuantizationParams() =
input_array.GetQuantizationParams();
- return true;
+ *modified = true;
+ return ::tensorflow::Status::OK();
}
} // namespace toco
diff --git a/tensorflow/contrib/lite/toco/graph_transformations/tests/resolve_constant_concatenation_test.cc b/tensorflow/contrib/lite/toco/graph_transformations/tests/resolve_constant_concatenation_test.cc
index 66cfed4ac2..e2a6f12481 100644
--- a/tensorflow/contrib/lite/toco/graph_transformations/tests/resolve_constant_concatenation_test.cc
+++ b/tensorflow/contrib/lite/toco/graph_transformations/tests/resolve_constant_concatenation_test.cc
@@ -166,7 +166,10 @@ TEST_F(ResolveConstantConcatenationTest, ConcatAtAxis0) {
GraphTransformationsSet graph_transformation_set;
graph_transformation_set.Add(new toco::ResolveConstantConcatenation);
EXPECT_THAT(model.GetArrayMap().size(), 5);
- (*graph_transformation_set.begin())->Run(&model, /*op_index=*/0);
+ bool modified;
+ ASSERT_TRUE((*graph_transformation_set.begin())
+ ->Run(&model, /*op_index=*/0, &modified)
+ .ok());
EXPECT_THAT(model.GetArrayMap().size(), 1);
auto& concatenated_array = (*model.GetArrayMap().begin()).second;
@@ -185,7 +188,10 @@ TEST_F(ResolveConstantConcatenationTest, ConcatAtAxis1) {
GraphTransformationsSet graph_transformation_set;
graph_transformation_set.Add(new toco::ResolveConstantConcatenation);
EXPECT_THAT(model.GetArrayMap().size(), 5);
- (*graph_transformation_set.begin())->Run(&model, /*op_index=*/0);
+ bool modified;
+ ASSERT_TRUE((*graph_transformation_set.begin())
+ ->Run(&model, /*op_index=*/0, &modified)
+ .ok());
EXPECT_THAT(model.GetArrayMap().size(), 1);
auto& concatenated_array = (*model.GetArrayMap().begin()).second;
@@ -204,7 +210,10 @@ TEST_F(ResolveConstantConcatenationTest, ConcatAtAxis2) {
GraphTransformationsSet graph_transformation_set;
graph_transformation_set.Add(new toco::ResolveConstantConcatenation);
EXPECT_THAT(model.GetArrayMap().size(), 5);
- (*graph_transformation_set.begin())->Run(&model, /*op_index=*/0);
+ bool modified;
+ ASSERT_TRUE((*graph_transformation_set.begin())
+ ->Run(&model, /*op_index=*/0, &modified)
+ .ok());
EXPECT_THAT(model.GetArrayMap().size(), 1);
auto& concatenated_array = (*model.GetArrayMap().begin()).second;
diff --git a/tensorflow/contrib/lite/toco/graph_transformations/tests/resolve_constant_unary_test.cc b/tensorflow/contrib/lite/toco/graph_transformations/tests/resolve_constant_unary_test.cc
index a53abc9941..57d85a0435 100644
--- a/tensorflow/contrib/lite/toco/graph_transformations/tests/resolve_constant_unary_test.cc
+++ b/tensorflow/contrib/lite/toco/graph_transformations/tests/resolve_constant_unary_test.cc
@@ -50,7 +50,8 @@ void RunResolveSum(const std::vector<float>& input,
sum_op->inputs = {"input0", "input1"};
sum_op->outputs = {"output"};
model.operators.push_back(std::move(sum_op));
- ResolveConstantUnaryOperator().Run(&model, 0);
+ bool modified;
+ ASSERT_TRUE(ResolveConstantUnaryOperator().Run(&model, 0, &modified).ok());
EXPECT_EQ(model.GetArray("output").GetBuffer<ArrayDataType::kFloat>().data,
expected_output);
EXPECT_EQ(model.GetArray("output").shape().dims(), output_shape);
diff --git a/tensorflow/contrib/lite/toco/graph_transformations/unfuse_activation_functions.cc b/tensorflow/contrib/lite/toco/graph_transformations/unfuse_activation_functions.cc
index 69bad2fa89..4ada5c3fd0 100644
--- a/tensorflow/contrib/lite/toco/graph_transformations/unfuse_activation_functions.cc
+++ b/tensorflow/contrib/lite/toco/graph_transformations/unfuse_activation_functions.cc
@@ -25,13 +25,16 @@ limitations under the License.
namespace toco {
-bool UnfuseActivationFunctions::Run(Model* model, std::size_t op_index) {
+::tensorflow::Status UnfuseActivationFunctions::Run(Model* model,
+ std::size_t op_index,
+ bool* modified) {
+ *modified = false;
const auto it = model->operators.begin() + op_index;
auto* op = it->get();
// If a conv operation has an im2col array, yield: it should be dropped first.
if ((op->type == OperatorType::kConv) && (op->outputs.size() == 2)) {
- return false;
+ return ::tensorflow::Status::OK();
}
Operator* ac_op = nullptr;
@@ -46,7 +49,7 @@ bool UnfuseActivationFunctions::Run(Model* model, std::size_t op_index) {
ac_op = new Relu1Operator;
break;
default:
- return false;
+ return ::tensorflow::Status::OK();
}
// At this point we know that the op has a fused activation function. At the
@@ -74,7 +77,8 @@ bool UnfuseActivationFunctions::Run(Model* model, std::size_t op_index) {
ac_op->inputs = {tmp_array_name};
op->outputs = {tmp_array_name};
- return true;
+ *modified = true;
+ return ::tensorflow::Status::OK();
}
} // namespace toco
diff --git a/tensorflow/contrib/lite/toco/graph_transformations/unpartition_embedding_lookup.cc b/tensorflow/contrib/lite/toco/graph_transformations/unpartition_embedding_lookup.cc
index dd9e26e68b..e19527968d 100644
--- a/tensorflow/contrib/lite/toco/graph_transformations/unpartition_embedding_lookup.cc
+++ b/tensorflow/contrib/lite/toco/graph_transformations/unpartition_embedding_lookup.cc
@@ -22,7 +22,10 @@ limitations under the License.
namespace toco {
-bool UnpartitionEmbeddingLookup::Run(Model* model, std::size_t op_index) {
+::tensorflow::Status UnpartitionEmbeddingLookup::Run(Model* model,
+ std::size_t op_index,
+ bool* modified) {
+ *modified = false;
// Collapses a partitioned tf.nn.embedding_lookup back into a single Gather.
// https://www.tensorflow.org/api_docs/python/tf/nn/embedding_lookup
// This transform attempts to identify the len(params) > 1 case and collapse
@@ -47,7 +50,7 @@ bool UnpartitionEmbeddingLookup::Run(Model* model, std::size_t op_index) {
// First look for the final DynamicStitch.
auto op_it = model->operators.begin() + op_index;
if (op_it->get()->type != OperatorType::kDynamicStitch) {
- return false;
+ return ::tensorflow::Status::OK();
}
auto* stitch_op = static_cast<DynamicStitchOperator*>(op_it->get());
@@ -72,7 +75,7 @@ bool UnpartitionEmbeddingLookup::Run(Model* model, std::size_t op_index) {
"Skipping because indices input %s into "
"%s is unexpected",
LogName(*op), LogName(*stitch_op));
- return false;
+ return ::tensorflow::Status::OK();
}
if (!indices_partition_op) {
indices_partition_op = static_cast<DynamicPartitionOperator*>(op);
@@ -83,7 +86,7 @@ bool UnpartitionEmbeddingLookup::Run(Model* model, std::size_t op_index) {
"Skipping because indices input %s into "
"%s is from a different source op than others",
LogName(*op), LogName(*stitch_op));
- return false;
+ return ::tensorflow::Status::OK();
}
}
}
@@ -92,12 +95,12 @@ bool UnpartitionEmbeddingLookup::Run(Model* model, std::size_t op_index) {
// The data for the indices must be a constant range of the array shape.
if (!IsConstantParameterArray(*model, indices_partition_op->inputs[0])) {
AddMessageF("Skipping because indices partition data is non-constant");
- return false;
+ return ::tensorflow::Status::OK();
}
auto& indices_data_array = model->GetArray(indices_partition_op->inputs[0]);
if (indices_data_array.data_type == ArrayDataType::kNone) {
// Yield until data types are propagated.
- return false;
+ return ::tensorflow::Status::OK();
}
CHECK(indices_data_array.data_type == ArrayDataType::kInt32)
<< "Indices partition inputs must be int32";
@@ -117,7 +120,7 @@ bool UnpartitionEmbeddingLookup::Run(Model* model, std::size_t op_index) {
"Skipping because data input %s into %s "
"is unexpected",
LogName(*op), LogName(*stitch_op));
- return false;
+ return ::tensorflow::Status::OK();
}
gather_ops.push_back(static_cast<GatherOperator*>(op));
}
@@ -132,7 +135,7 @@ bool UnpartitionEmbeddingLookup::Run(Model* model, std::size_t op_index) {
"Skipping because data input %s into "
"%s is unexpected",
LogName(*op), LogName(*gather_op));
- return false;
+ return ::tensorflow::Status::OK();
}
if (!data_partition_op) {
data_partition_op = static_cast<DynamicPartitionOperator*>(op);
@@ -143,7 +146,7 @@ bool UnpartitionEmbeddingLookup::Run(Model* model, std::size_t op_index) {
"Skipping because data input %s into "
"%s is from a different source op than others",
LogName(*op), LogName(*gather_op));
- return false;
+ return ::tensorflow::Status::OK();
}
}
}
@@ -236,7 +239,8 @@ bool UnpartitionEmbeddingLookup::Run(Model* model, std::size_t op_index) {
DeleteOpAndArraysIfUnused(model, indices_partition_op);
DeleteOpAndArraysIfUnused(model, data_partition_op);
DeleteOpAndArraysIfUnused(model, stitch_op);
- return true;
+ *modified = true;
+ return ::tensorflow::Status::OK();
}
} // namespace toco
diff --git a/tensorflow/contrib/lite/toco/graph_transformations/unroll_batch_matmul.cc b/tensorflow/contrib/lite/toco/graph_transformations/unroll_batch_matmul.cc
index fedf4441e2..5ff39aa313 100644
--- a/tensorflow/contrib/lite/toco/graph_transformations/unroll_batch_matmul.cc
+++ b/tensorflow/contrib/lite/toco/graph_transformations/unroll_batch_matmul.cc
@@ -36,10 +36,12 @@ namespace toco {
// slice_c = tf.matmul(slice_a, slice_b)
// result_slices[bat] = slice_c
// result = tf.stack(result_slices)
-bool UnrollBatchMatMul::Run(Model* model, std::size_t op_index) {
+::tensorflow::Status UnrollBatchMatMul::Run(Model* model, std::size_t op_index,
+ bool* modified) {
+ *modified = false;
auto batch_op_it = model->operators.begin() + op_index;
if (batch_op_it->get()->type != OperatorType::kBatchMatMul) {
- return false;
+ return ::tensorflow::Status::OK();
}
const auto* batch_op =
static_cast<const BatchMatMulOperator*>(batch_op_it->get());
@@ -47,7 +49,8 @@ bool UnrollBatchMatMul::Run(Model* model, std::size_t op_index) {
// We must have the shape of at least one input to know our batch size.
const auto& input_array_a = model->GetArray(batch_op->inputs[0]);
const auto& input_array_b = model->GetArray(batch_op->inputs[1]);
- if (!input_array_a.has_shape() || !input_array_b.has_shape()) return false;
+ if (!input_array_a.has_shape() || !input_array_b.has_shape())
+ return ::tensorflow::Status::OK();
// We only support the rank 3 case. If you are batching on rank > 3 you'll
// have to figure that out.
@@ -66,7 +69,8 @@ bool UnrollBatchMatMul::Run(Model* model, std::size_t op_index) {
batch_op_it = matmul_op_it + 1;
CHECK_EQ(batch_op_it->get(), batch_op);
model->operators.erase(batch_op_it);
- return true;
+ *modified = true;
+ return ::tensorflow::Status::OK();
}
CHECK_EQ(input_array_a.shape().dimensions_count(), 3)
<< "Input arrays must have rank 3";
@@ -167,7 +171,8 @@ bool UnrollBatchMatMul::Run(Model* model, std::size_t op_index) {
CHECK(batch_op_it != model->operators.end());
CHECK(batch_op_it->get() == batch_op);
model->operators.erase(batch_op_it);
- return true;
+ *modified = true;
+ return ::tensorflow::Status::OK();
}
} // namespace toco