diff options
author | A. Unique TensorFlower <gardener@tensorflow.org> | 2018-10-02 09:32:20 -0700 |
---|---|---|
committer | TensorFlower Gardener <gardener@tensorflow.org> | 2018-10-02 09:36:34 -0700 |
commit | ce41d2f95e1e5883f1808030c94fd9aaa57d9f10 (patch) | |
tree | 52b27f286cf79177fee601e7e1555a2b1e81d43e /tensorflow/contrib/lite/toco/tooling_util.cc | |
parent | 28757ad658243526d84fd16d53b9eefbf809c6ff (diff) |
Generate an error when --rnn_states refers to array names that aren't produced/consumed by any op.
PiperOrigin-RevId: 215402308
Diffstat (limited to 'tensorflow/contrib/lite/toco/tooling_util.cc')
-rw-r--r-- | tensorflow/contrib/lite/toco/tooling_util.cc | 29 |
1 files changed, 23 insertions, 6 deletions
diff --git a/tensorflow/contrib/lite/toco/tooling_util.cc b/tensorflow/contrib/lite/toco/tooling_util.cc index 4a1ae35cb5..b87e01fbf0 100644 --- a/tensorflow/contrib/lite/toco/tooling_util.cc +++ b/tensorflow/contrib/lite/toco/tooling_util.cc @@ -843,24 +843,40 @@ void CheckNonAsciiIOArrays(const ModelFlags& model_flags) { } void CheckNonExistentIOArrays(const Model& model) { + // "non-existent" is interpreted in the stronger sense of + // "not actually produced/consumed by an op". + // Rationale: we have to artificially fix up TensorFlow graphs by creating + // any array that it refers to, so just checking that arrays exist isn't + // sufficient. The real invariant here is whether arrays are produced/consumed + // by something. if (model.flags.allow_nonexistent_arrays()) { return; } for (const auto& input_array : model.flags.input_arrays()) { - CHECK(model.HasArray(input_array.name())) - << "Input array not found: " << input_array.name(); + QCHECK(GetOpWithInput(model, input_array.name())) + << "Specified input array " << input_array.name() + << " is not consumed by any op in this graph. Is it a typo?"; } for (const string& output_array : model.flags.output_arrays()) { - CHECK(model.HasArray(output_array)) - << "Output array not found: " << output_array; + QCHECK(GetOpWithOutput(model, output_array)) + << "Specified output array " << output_array + << " is not produced by any op in this graph. Is it a typo?"; } for (const auto& rnn_state : model.flags.rnn_states()) { if (!rnn_state.discardable()) { - CHECK(model.HasArray(rnn_state.state_array())); - CHECK(model.HasArray(rnn_state.back_edge_source_array())); + // Check that all RNN states are consumed + QCHECK(GetOpWithInput(model, rnn_state.state_array())) + << "Specified RNN state " << rnn_state.state_array() + << " is not consumed by any op in this graph. Is it a typo?"; + // Check that all RNN back-edge source arrays are produced + QCHECK(GetOpWithOutput(model, rnn_state.back_edge_source_array())) + << "Specified RNN back-edge source array " + << rnn_state.back_edge_source_array() + << " is not produced by any op in this graph. Is it a typo?"; } } } + } // namespace void CheckNoMissingArray(const Model& model) { @@ -1597,6 +1613,7 @@ void ResolveModelFlags(const ModelFlags& model_flags, Model* model) { input_array.GetOrCreateMinMax() = input_minmax; } } + // Creation of the RNN state arrays for (const auto& rnn_state : model->flags.rnn_states()) { CreateOrCheckRnnStateArray(rnn_state.state_array(), rnn_state.size(), |