aboutsummaryrefslogtreecommitdiffhomepage
path: root/tensorflow/contrib/lite/toco
diff options
context:
space:
mode:
authorGravatar A. Unique TensorFlower <gardener@tensorflow.org>2018-10-02 09:32:20 -0700
committerGravatar TensorFlower Gardener <gardener@tensorflow.org>2018-10-02 09:36:34 -0700
commitce41d2f95e1e5883f1808030c94fd9aaa57d9f10 (patch)
tree52b27f286cf79177fee601e7e1555a2b1e81d43e /tensorflow/contrib/lite/toco
parent28757ad658243526d84fd16d53b9eefbf809c6ff (diff)
Generate an error when --rnn_states refers to array names that aren't produced/consumed by any op.
PiperOrigin-RevId: 215402308
Diffstat (limited to 'tensorflow/contrib/lite/toco')
-rw-r--r--tensorflow/contrib/lite/toco/graph_transformations/resolve_multiply_by_zero.cc14
-rw-r--r--tensorflow/contrib/lite/toco/model_cmdline_flags.cc18
-rw-r--r--tensorflow/contrib/lite/toco/tooling_util.cc29
3 files changed, 41 insertions, 20 deletions
diff --git a/tensorflow/contrib/lite/toco/graph_transformations/resolve_multiply_by_zero.cc b/tensorflow/contrib/lite/toco/graph_transformations/resolve_multiply_by_zero.cc
index 4bb1217828..b2b2ea151b 100644
--- a/tensorflow/contrib/lite/toco/graph_transformations/resolve_multiply_by_zero.cc
+++ b/tensorflow/contrib/lite/toco/graph_transformations/resolve_multiply_by_zero.cc
@@ -60,6 +60,10 @@ bool ResolveMultiplyByZero::Run(Model* model, std::size_t op_index) {
const auto& output_array_name = mul_op->outputs[0];
auto& output_array = model->GetArray(output_array_name);
+ if (!IsDiscardableArray(*model, output_array_name)) {
+ return false;
+ }
+
if (output_array.data_type == ArrayDataType::kNone) {
// Yield until the output type has been set by PropagateArrayDataTypes
return false;
@@ -139,14 +143,8 @@ bool ResolveMultiplyByZero::Run(Model* model, std::size_t op_index) {
}
// Erase input arrays to the multiply if no longer used
- if (IsDiscardableArray(*model, mul_op->inputs[0]) &&
- CountOpsWithInput(*model, mul_op->inputs[0]) == 1) {
- model->EraseArray(mul_op->inputs[0]);
- }
- if (IsDiscardableArray(*model, mul_op->inputs[1]) &&
- CountOpsWithInput(*model, mul_op->inputs[1]) == 1) {
- model->EraseArray(mul_op->inputs[1]);
- }
+ DeleteArrayIfUsedOnce(mul_op->inputs[0], model);
+ DeleteArrayIfUsedOnce(mul_op->inputs[1], model);
// Erase the multiply operator.
model->operators.erase(mul_it);
diff --git a/tensorflow/contrib/lite/toco/model_cmdline_flags.cc b/tensorflow/contrib/lite/toco/model_cmdline_flags.cc
index d34da63e43..b6a401aaf2 100644
--- a/tensorflow/contrib/lite/toco/model_cmdline_flags.cc
+++ b/tensorflow/contrib/lite/toco/model_cmdline_flags.cc
@@ -394,12 +394,18 @@ void ReadModelFlagsFromCommandLineFlags(
}
}
- model_flags->set_allow_nonascii_arrays(
- parsed_model_flags.allow_nonascii_arrays.value());
- model_flags->set_allow_nonexistent_arrays(
- parsed_model_flags.allow_nonexistent_arrays.value());
- model_flags->set_change_concat_input_ranges(
- parsed_model_flags.change_concat_input_ranges.value());
+ if (!model_flags->has_allow_nonascii_arrays()) {
+ model_flags->set_allow_nonascii_arrays(
+ parsed_model_flags.allow_nonascii_arrays.value());
+ }
+ if (!model_flags->has_allow_nonexistent_arrays()) {
+ model_flags->set_allow_nonexistent_arrays(
+ parsed_model_flags.allow_nonexistent_arrays.value());
+ }
+ if (!model_flags->has_change_concat_input_ranges()) {
+ model_flags->set_change_concat_input_ranges(
+ parsed_model_flags.change_concat_input_ranges.value());
+ }
if (parsed_model_flags.arrays_extra_info_file.specified()) {
string arrays_extra_info_file_contents;
diff --git a/tensorflow/contrib/lite/toco/tooling_util.cc b/tensorflow/contrib/lite/toco/tooling_util.cc
index 4a1ae35cb5..b87e01fbf0 100644
--- a/tensorflow/contrib/lite/toco/tooling_util.cc
+++ b/tensorflow/contrib/lite/toco/tooling_util.cc
@@ -843,24 +843,40 @@ void CheckNonAsciiIOArrays(const ModelFlags& model_flags) {
}
void CheckNonExistentIOArrays(const Model& model) {
+ // "non-existent" is interpreted in the stronger sense of
+ // "not actually produced/consumed by an op".
+ // Rationale: we have to artificially fix up TensorFlow graphs by creating
+ // any array that it refers to, so just checking that arrays exist isn't
+ // sufficient. The real invariant here is whether arrays are produced/consumed
+ // by something.
if (model.flags.allow_nonexistent_arrays()) {
return;
}
for (const auto& input_array : model.flags.input_arrays()) {
- CHECK(model.HasArray(input_array.name()))
- << "Input array not found: " << input_array.name();
+ QCHECK(GetOpWithInput(model, input_array.name()))
+ << "Specified input array " << input_array.name()
+ << " is not consumed by any op in this graph. Is it a typo?";
}
for (const string& output_array : model.flags.output_arrays()) {
- CHECK(model.HasArray(output_array))
- << "Output array not found: " << output_array;
+ QCHECK(GetOpWithOutput(model, output_array))
+ << "Specified output array " << output_array
+ << " is not produced by any op in this graph. Is it a typo?";
}
for (const auto& rnn_state : model.flags.rnn_states()) {
if (!rnn_state.discardable()) {
- CHECK(model.HasArray(rnn_state.state_array()));
- CHECK(model.HasArray(rnn_state.back_edge_source_array()));
+ // Check that all RNN states are consumed
+ QCHECK(GetOpWithInput(model, rnn_state.state_array()))
+ << "Specified RNN state " << rnn_state.state_array()
+ << " is not consumed by any op in this graph. Is it a typo?";
+ // Check that all RNN back-edge source arrays are produced
+ QCHECK(GetOpWithOutput(model, rnn_state.back_edge_source_array()))
+ << "Specified RNN back-edge source array "
+ << rnn_state.back_edge_source_array()
+ << " is not produced by any op in this graph. Is it a typo?";
}
}
}
+
} // namespace
void CheckNoMissingArray(const Model& model) {
@@ -1597,6 +1613,7 @@ void ResolveModelFlags(const ModelFlags& model_flags, Model* model) {
input_array.GetOrCreateMinMax() = input_minmax;
}
}
+
// Creation of the RNN state arrays
for (const auto& rnn_state : model->flags.rnn_states()) {
CreateOrCheckRnnStateArray(rnn_state.state_array(), rnn_state.size(),