aboutsummaryrefslogtreecommitdiffhomepage
path: root/tensorflow/contrib/lite/toco/tooling_util.cc
diff options
context:
space:
mode:
authorGravatar A. Unique TensorFlower <gardener@tensorflow.org>2018-10-02 09:32:20 -0700
committerGravatar TensorFlower Gardener <gardener@tensorflow.org>2018-10-02 09:36:34 -0700
commitce41d2f95e1e5883f1808030c94fd9aaa57d9f10 (patch)
tree52b27f286cf79177fee601e7e1555a2b1e81d43e /tensorflow/contrib/lite/toco/tooling_util.cc
parent28757ad658243526d84fd16d53b9eefbf809c6ff (diff)
Generate an error when --rnn_states refers to array names that aren't produced/consumed by any op.
PiperOrigin-RevId: 215402308
Diffstat (limited to 'tensorflow/contrib/lite/toco/tooling_util.cc')
-rw-r--r--tensorflow/contrib/lite/toco/tooling_util.cc29
1 files changed, 23 insertions, 6 deletions
diff --git a/tensorflow/contrib/lite/toco/tooling_util.cc b/tensorflow/contrib/lite/toco/tooling_util.cc
index 4a1ae35cb5..b87e01fbf0 100644
--- a/tensorflow/contrib/lite/toco/tooling_util.cc
+++ b/tensorflow/contrib/lite/toco/tooling_util.cc
@@ -843,24 +843,40 @@ void CheckNonAsciiIOArrays(const ModelFlags& model_flags) {
}
void CheckNonExistentIOArrays(const Model& model) {
+ // "non-existent" is interpreted in the stronger sense of
+ // "not actually produced/consumed by an op".
+ // Rationale: we have to artificially fix up TensorFlow graphs by creating
+ // any array that it refers to, so just checking that arrays exist isn't
+ // sufficient. The real invariant here is whether arrays are produced/consumed
+ // by something.
if (model.flags.allow_nonexistent_arrays()) {
return;
}
for (const auto& input_array : model.flags.input_arrays()) {
- CHECK(model.HasArray(input_array.name()))
- << "Input array not found: " << input_array.name();
+ QCHECK(GetOpWithInput(model, input_array.name()))
+ << "Specified input array " << input_array.name()
+ << " is not consumed by any op in this graph. Is it a typo?";
}
for (const string& output_array : model.flags.output_arrays()) {
- CHECK(model.HasArray(output_array))
- << "Output array not found: " << output_array;
+ QCHECK(GetOpWithOutput(model, output_array))
+ << "Specified output array " << output_array
+ << " is not produced by any op in this graph. Is it a typo?";
}
for (const auto& rnn_state : model.flags.rnn_states()) {
if (!rnn_state.discardable()) {
- CHECK(model.HasArray(rnn_state.state_array()));
- CHECK(model.HasArray(rnn_state.back_edge_source_array()));
+ // Check that all RNN states are consumed
+ QCHECK(GetOpWithInput(model, rnn_state.state_array()))
+ << "Specified RNN state " << rnn_state.state_array()
+ << " is not consumed by any op in this graph. Is it a typo?";
+ // Check that all RNN back-edge source arrays are produced
+ QCHECK(GetOpWithOutput(model, rnn_state.back_edge_source_array()))
+ << "Specified RNN back-edge source array "
+ << rnn_state.back_edge_source_array()
+ << " is not produced by any op in this graph. Is it a typo?";
}
}
}
+
} // namespace
void CheckNoMissingArray(const Model& model) {
@@ -1597,6 +1613,7 @@ void ResolveModelFlags(const ModelFlags& model_flags, Model* model) {
input_array.GetOrCreateMinMax() = input_minmax;
}
}
+
// Creation of the RNN state arrays
for (const auto& rnn_state : model->flags.rnn_states()) {
CreateOrCheckRnnStateArray(rnn_state.state_array(), rnn_state.size(),