aboutsummaryrefslogtreecommitdiffhomepage
path: root/tensorflow/c/c_api.cc
diff options
context:
space:
mode:
authorGravatar Skye Wanderman-Milne <skyewm@google.com>2017-12-12 10:58:31 -0800
committerGravatar TensorFlower Gardener <gardener@tensorflow.org>2017-12-12 11:02:23 -0800
commit968da4bf2722b1303cc223e8342357d62c27dfc1 (patch)
treedab12578913f6bfc11b50d92f39ae95fd7301963 /tensorflow/c/c_api.cc
parentc8a5ffdeb2a17df2d2822c7a6df8a44f3ab85614 (diff)
Raise exception on missing unused input_map keys with C API enabled.
Without this change, the C++ ImportGraphDef API returns unused input_map keys (which are plumbed through to the C API as well). However, the Python import_graph_def API requires slightly different semantics: it throws an error for unused input_map keys that are missing from the GraphDef. This change modifies the C and C++ APIs to limit the returned keys to those missing from the GraphDef, and plumbs this through to the C API-enabled import_graph_def implementation. Note that this is a change to the existing C API. Luckily the modified method hasn't been released yet, so it's ok to change it. PiperOrigin-RevId: 178783957
Diffstat (limited to 'tensorflow/c/c_api.cc')
-rw-r--r--tensorflow/c/c_api.cc37
1 files changed, 20 insertions, 17 deletions
diff --git a/tensorflow/c/c_api.cc b/tensorflow/c/c_api.cc
index 13253ced49..6f5abd074c 100644
--- a/tensorflow/c/c_api.cc
+++ b/tensorflow/c/c_api.cc
@@ -1917,12 +1917,12 @@ void TF_ImportGraphDefResultsReturnOperations(TF_ImportGraphDefResults* results,
*opers = results->return_nodes.data();
}
-void TF_ImportGraphDefResultsUnusedInputMappings(
- TF_ImportGraphDefResults* results, int* num_unused_input_mappings,
+void TF_ImportGraphDefResultsMissingUnusedInputMappings(
+ TF_ImportGraphDefResults* results, int* num_missing_unused_input_mappings,
const char*** src_names, int** src_indexes) {
- *num_unused_input_mappings = results->unused_key_names.size();
- *src_names = results->unused_key_names.data();
- *src_indexes = results->unused_key_indexes.data();
+ *num_missing_unused_input_mappings = results->missing_unused_key_names.size();
+ *src_names = results->missing_unused_key_names.data();
+ *src_indexes = results->missing_unused_key_indexes.data();
}
void TF_DeleteImportGraphDefResults(TF_ImportGraphDefResults* results) {
@@ -1962,18 +1962,21 @@ static void GraphImportGraphDefLocked(TF_Graph* graph, const GraphDef& def,
tf_results->return_nodes[i] = ToOperation(results.return_nodes[i]);
}
- // Populate unused map keys
- DCHECK(tf_results->unused_key_names.empty());
- DCHECK(tf_results->unused_key_indexes.empty());
- DCHECK(tf_results->unused_key_names_data.empty());
- tf_results->unused_key_names.resize(results.unused_input_map_keys.size());
- tf_results->unused_key_indexes.resize(results.unused_input_map_keys.size());
- for (int i = 0; i < results.unused_input_map_keys.size(); ++i) {
- TensorId id = results.unused_input_map_keys[i];
- tf_results->unused_key_names_data.push_back(id.first.ToString());
- tf_results->unused_key_names[i] =
- tf_results->unused_key_names_data.back().c_str();
- tf_results->unused_key_indexes[i] = id.second;
+ // Populate missing unused map keys
+ DCHECK(tf_results->missing_unused_key_names.empty());
+ DCHECK(tf_results->missing_unused_key_indexes.empty());
+ DCHECK(tf_results->missing_unused_key_names_data.empty());
+
+ size_t size = results.missing_unused_input_map_keys.size();
+ tf_results->missing_unused_key_names.resize(size);
+ tf_results->missing_unused_key_indexes.resize(size);
+
+ for (int i = 0; i < size; ++i) {
+ TensorId id = results.missing_unused_input_map_keys[i];
+ tf_results->missing_unused_key_names_data.push_back(id.first.ToString());
+ tf_results->missing_unused_key_names[i] =
+ tf_results->missing_unused_key_names_data.back().c_str();
+ tf_results->missing_unused_key_indexes[i] = id.second;
}
}