diff options
author | A. Unique TensorFlower <gardener@tensorflow.org> | 2018-10-02 09:32:20 -0700 |
---|---|---|
committer | TensorFlower Gardener <gardener@tensorflow.org> | 2018-10-02 09:36:34 -0700 |
commit | ce41d2f95e1e5883f1808030c94fd9aaa57d9f10 (patch) | |
tree | 52b27f286cf79177fee601e7e1555a2b1e81d43e /tensorflow/contrib/lite/toco | |
parent | 28757ad658243526d84fd16d53b9eefbf809c6ff (diff) |
Generate an error when --rnn_states refers to array names that aren't produced/consumed by any op.
PiperOrigin-RevId: 215402308
Diffstat (limited to 'tensorflow/contrib/lite/toco')
3 files changed, 41 insertions, 20 deletions
diff --git a/tensorflow/contrib/lite/toco/graph_transformations/resolve_multiply_by_zero.cc b/tensorflow/contrib/lite/toco/graph_transformations/resolve_multiply_by_zero.cc index 4bb1217828..b2b2ea151b 100644 --- a/tensorflow/contrib/lite/toco/graph_transformations/resolve_multiply_by_zero.cc +++ b/tensorflow/contrib/lite/toco/graph_transformations/resolve_multiply_by_zero.cc @@ -60,6 +60,10 @@ bool ResolveMultiplyByZero::Run(Model* model, std::size_t op_index) { const auto& output_array_name = mul_op->outputs[0]; auto& output_array = model->GetArray(output_array_name); + if (!IsDiscardableArray(*model, output_array_name)) { + return false; + } + if (output_array.data_type == ArrayDataType::kNone) { // Yield until the output type has been set by PropagateArrayDataTypes return false; @@ -139,14 +143,8 @@ bool ResolveMultiplyByZero::Run(Model* model, std::size_t op_index) { } // Erase input arrays to the multiply if no longer used - if (IsDiscardableArray(*model, mul_op->inputs[0]) && - CountOpsWithInput(*model, mul_op->inputs[0]) == 1) { - model->EraseArray(mul_op->inputs[0]); - } - if (IsDiscardableArray(*model, mul_op->inputs[1]) && - CountOpsWithInput(*model, mul_op->inputs[1]) == 1) { - model->EraseArray(mul_op->inputs[1]); - } + DeleteArrayIfUsedOnce(mul_op->inputs[0], model); + DeleteArrayIfUsedOnce(mul_op->inputs[1], model); // Erase the multiply operator. model->operators.erase(mul_it); diff --git a/tensorflow/contrib/lite/toco/model_cmdline_flags.cc b/tensorflow/contrib/lite/toco/model_cmdline_flags.cc index d34da63e43..b6a401aaf2 100644 --- a/tensorflow/contrib/lite/toco/model_cmdline_flags.cc +++ b/tensorflow/contrib/lite/toco/model_cmdline_flags.cc @@ -394,12 +394,18 @@ void ReadModelFlagsFromCommandLineFlags( } } - model_flags->set_allow_nonascii_arrays( - parsed_model_flags.allow_nonascii_arrays.value()); - model_flags->set_allow_nonexistent_arrays( - parsed_model_flags.allow_nonexistent_arrays.value()); - model_flags->set_change_concat_input_ranges( - parsed_model_flags.change_concat_input_ranges.value()); + if (!model_flags->has_allow_nonascii_arrays()) { + model_flags->set_allow_nonascii_arrays( + parsed_model_flags.allow_nonascii_arrays.value()); + } + if (!model_flags->has_allow_nonexistent_arrays()) { + model_flags->set_allow_nonexistent_arrays( + parsed_model_flags.allow_nonexistent_arrays.value()); + } + if (!model_flags->has_change_concat_input_ranges()) { + model_flags->set_change_concat_input_ranges( + parsed_model_flags.change_concat_input_ranges.value()); + } if (parsed_model_flags.arrays_extra_info_file.specified()) { string arrays_extra_info_file_contents; diff --git a/tensorflow/contrib/lite/toco/tooling_util.cc b/tensorflow/contrib/lite/toco/tooling_util.cc index 4a1ae35cb5..b87e01fbf0 100644 --- a/tensorflow/contrib/lite/toco/tooling_util.cc +++ b/tensorflow/contrib/lite/toco/tooling_util.cc @@ -843,24 +843,40 @@ void CheckNonAsciiIOArrays(const ModelFlags& model_flags) { } void CheckNonExistentIOArrays(const Model& model) { + // "non-existent" is interpreted in the stronger sense of + // "not actually produced/consumed by an op". + // Rationale: we have to artificially fix up TensorFlow graphs by creating + // any array that it refers to, so just checking that arrays exist isn't + // sufficient. The real invariant here is whether arrays are produced/consumed + // by something. if (model.flags.allow_nonexistent_arrays()) { return; } for (const auto& input_array : model.flags.input_arrays()) { - CHECK(model.HasArray(input_array.name())) - << "Input array not found: " << input_array.name(); + QCHECK(GetOpWithInput(model, input_array.name())) + << "Specified input array " << input_array.name() + << " is not consumed by any op in this graph. Is it a typo?"; } for (const string& output_array : model.flags.output_arrays()) { - CHECK(model.HasArray(output_array)) - << "Output array not found: " << output_array; + QCHECK(GetOpWithOutput(model, output_array)) + << "Specified output array " << output_array + << " is not produced by any op in this graph. Is it a typo?"; } for (const auto& rnn_state : model.flags.rnn_states()) { if (!rnn_state.discardable()) { - CHECK(model.HasArray(rnn_state.state_array())); - CHECK(model.HasArray(rnn_state.back_edge_source_array())); + // Check that all RNN states are consumed + QCHECK(GetOpWithInput(model, rnn_state.state_array())) + << "Specified RNN state " << rnn_state.state_array() + << " is not consumed by any op in this graph. Is it a typo?"; + // Check that all RNN back-edge source arrays are produced + QCHECK(GetOpWithOutput(model, rnn_state.back_edge_source_array())) + << "Specified RNN back-edge source array " + << rnn_state.back_edge_source_array() + << " is not produced by any op in this graph. Is it a typo?"; } } } + } // namespace void CheckNoMissingArray(const Model& model) { @@ -1597,6 +1613,7 @@ void ResolveModelFlags(const ModelFlags& model_flags, Model* model) { input_array.GetOrCreateMinMax() = input_minmax; } } + // Creation of the RNN state arrays for (const auto& rnn_state : model->flags.rnn_states()) { CreateOrCheckRnnStateArray(rnn_state.state_array(), rnn_state.size(), |