diff options
author | Skye Wanderman-Milne <skyewm@google.com> | 2017-12-12 10:58:31 -0800 |
---|---|---|
committer | TensorFlower Gardener <gardener@tensorflow.org> | 2017-12-12 11:02:23 -0800 |
commit | 968da4bf2722b1303cc223e8342357d62c27dfc1 (patch) | |
tree | dab12578913f6bfc11b50d92f39ae95fd7301963 | |
parent | c8a5ffdeb2a17df2d2822c7a6df8a44f3ab85614 (diff) |
Raise exception on missing unused input_map keys with C API enabled.
Without this change, the C++ ImportGraphDef API returns unused
input_map keys (which are plumbed through to the C API as
well). However, the Python import_graph_def API requires slightly
different semantics: it throws an error for unused input_map keys that
are missing from the GraphDef.
This change modifies the C and C++ APIs to limit the returned keys to
those missing from the GraphDef, and plumbs this through to the C
API-enabled import_graph_def implementation.
Note that this is a change to the existing C API. Luckily the modified
method hasn't been released yet, so it's ok to change it.
PiperOrigin-RevId: 178783957
-rw-r--r-- | tensorflow/c/c_api.cc | 37 | ||||
-rw-r--r-- | tensorflow/c/c_api.h | 14 | ||||
-rw-r--r-- | tensorflow/c/c_api_internal.h | 8 | ||||
-rw-r--r-- | tensorflow/c/c_api_test.cc | 4 | ||||
-rw-r--r-- | tensorflow/core/graph/graph_constructor.cc | 71 | ||||
-rw-r--r-- | tensorflow/core/graph/graph_constructor.h | 7 | ||||
-rw-r--r-- | tensorflow/core/graph/graph_constructor_test.cc | 54 | ||||
-rw-r--r-- | tensorflow/python/BUILD | 1 | ||||
-rw-r--r-- | tensorflow/python/client/tf_session.i | 17 | ||||
-rw-r--r-- | tensorflow/python/client/tf_session_helper.cc | 15 | ||||
-rw-r--r-- | tensorflow/python/client/tf_session_helper.h | 4 | ||||
-rw-r--r-- | tensorflow/python/framework/importer.py | 12 | ||||
-rw-r--r-- | tensorflow/python/framework/importer_test.py | 14 |
13 files changed, 179 insertions, 79 deletions
diff --git a/tensorflow/c/c_api.cc b/tensorflow/c/c_api.cc index 13253ced49..6f5abd074c 100644 --- a/tensorflow/c/c_api.cc +++ b/tensorflow/c/c_api.cc @@ -1917,12 +1917,12 @@ void TF_ImportGraphDefResultsReturnOperations(TF_ImportGraphDefResults* results, *opers = results->return_nodes.data(); } -void TF_ImportGraphDefResultsUnusedInputMappings( - TF_ImportGraphDefResults* results, int* num_unused_input_mappings, +void TF_ImportGraphDefResultsMissingUnusedInputMappings( + TF_ImportGraphDefResults* results, int* num_missing_unused_input_mappings, const char*** src_names, int** src_indexes) { - *num_unused_input_mappings = results->unused_key_names.size(); - *src_names = results->unused_key_names.data(); - *src_indexes = results->unused_key_indexes.data(); + *num_missing_unused_input_mappings = results->missing_unused_key_names.size(); + *src_names = results->missing_unused_key_names.data(); + *src_indexes = results->missing_unused_key_indexes.data(); } void TF_DeleteImportGraphDefResults(TF_ImportGraphDefResults* results) { @@ -1962,18 +1962,21 @@ static void GraphImportGraphDefLocked(TF_Graph* graph, const GraphDef& def, tf_results->return_nodes[i] = ToOperation(results.return_nodes[i]); } - // Populate unused map keys - DCHECK(tf_results->unused_key_names.empty()); - DCHECK(tf_results->unused_key_indexes.empty()); - DCHECK(tf_results->unused_key_names_data.empty()); - tf_results->unused_key_names.resize(results.unused_input_map_keys.size()); - tf_results->unused_key_indexes.resize(results.unused_input_map_keys.size()); - for (int i = 0; i < results.unused_input_map_keys.size(); ++i) { - TensorId id = results.unused_input_map_keys[i]; - tf_results->unused_key_names_data.push_back(id.first.ToString()); - tf_results->unused_key_names[i] = - tf_results->unused_key_names_data.back().c_str(); - tf_results->unused_key_indexes[i] = id.second; + // Populate missing unused map keys + DCHECK(tf_results->missing_unused_key_names.empty()); + DCHECK(tf_results->missing_unused_key_indexes.empty()); + DCHECK(tf_results->missing_unused_key_names_data.empty()); + + size_t size = results.missing_unused_input_map_keys.size(); + tf_results->missing_unused_key_names.resize(size); + tf_results->missing_unused_key_indexes.resize(size); + + for (int i = 0; i < size; ++i) { + TensorId id = results.missing_unused_input_map_keys[i]; + tf_results->missing_unused_key_names_data.push_back(id.first.ToString()); + tf_results->missing_unused_key_names[i] = + tf_results->missing_unused_key_names_data.back().c_str(); + tf_results->missing_unused_key_indexes[i] = id.second; } } diff --git a/tensorflow/c/c_api.h b/tensorflow/c/c_api.h index df7fe222b1..de9527f86d 100644 --- a/tensorflow/c/c_api.h +++ b/tensorflow/c/c_api.h @@ -962,16 +962,16 @@ TF_CAPI_EXPORT extern void TF_ImportGraphDefResultsReturnOperations( TF_ImportGraphDefResults* results, int* num_opers, TF_Operation*** opers); // Fetches any input mappings requested via -// TF_ImportGraphDefOptionsAddInputMapping() that weren't used as input to any -// node in the imported graph def. The number of fetched mappings is returned in -// `num_unused_input_mappings`. The array of each mapping's source node name is -// returned in `src_names`, and the array of each mapping's source index is -// returned in `src_indexes`. +// TF_ImportGraphDefOptionsAddInputMapping() that didn't appear in the GraphDef +// and weren't used as input to any node in the imported graph def. The number +// of fetched mappings is returned in `num_missing_unused_input_mappings`. The +// array of each mapping's source node name is returned in `src_names`, and the +// array of each mapping's source index is returned in `src_indexes`. // // `*src_names`, `*src_indexes`, and the memory backing each string in // `src_names` are owned by and have the lifetime of `results`. -TF_CAPI_EXPORT extern void TF_ImportGraphDefResultsUnusedInputMappings( - TF_ImportGraphDefResults* results, int* num_unused_input_mappings, +TF_CAPI_EXPORT extern void TF_ImportGraphDefResultsMissingUnusedInputMappings( + TF_ImportGraphDefResults* results, int* num_missing_unused_input_mappings, const char*** src_names, int** src_indexes); // Deletes a results object returned by TF_GraphImportGraphDefWithResults(). diff --git a/tensorflow/c/c_api_internal.h b/tensorflow/c/c_api_internal.h index aac333d9e2..6df77a7f9b 100644 --- a/tensorflow/c/c_api_internal.h +++ b/tensorflow/c/c_api_internal.h @@ -143,11 +143,11 @@ struct TF_ImportGraphDefOptions { struct TF_ImportGraphDefResults { std::vector<TF_Output> return_tensors; std::vector<TF_Operation*> return_nodes; - std::vector<const char*> unused_key_names; - std::vector<int> unused_key_indexes; + std::vector<const char*> missing_unused_key_names; + std::vector<int> missing_unused_key_indexes; - // Backing memory for unused_key_names values. - std::list<tensorflow::string> unused_key_names_data; + // Backing memory for missing_unused_key_names values. + std::list<tensorflow::string> missing_unused_key_names_data; }; struct TF_DeviceList { diff --git a/tensorflow/c/c_api_test.cc b/tensorflow/c/c_api_test.cc index 6ec1db8ccf..4e89b4fc43 100644 --- a/tensorflow/c/c_api_test.cc +++ b/tensorflow/c/c_api_test.cc @@ -773,7 +773,7 @@ TEST(CAPI, ImportGraphDef_WithReturnOutputs) { TF_DeleteStatus(s); } -TEST(CAPI, ImportGraphDef_UnusedInputMappings) { +TEST(CAPI, ImportGraphDef_MissingUnusedInputMappings) { TF_Status* s = TF_NewStatus(); TF_Graph* graph = TF_NewGraph(); @@ -816,7 +816,7 @@ TEST(CAPI, ImportGraphDef_UnusedInputMappings) { int num_unused_input_mappings; const char** src_names; int* src_indexes; - TF_ImportGraphDefResultsUnusedInputMappings( + TF_ImportGraphDefResultsMissingUnusedInputMappings( results, &num_unused_input_mappings, &src_names, &src_indexes); ASSERT_EQ(1, num_unused_input_mappings); EXPECT_EQ(string("fake"), string(src_names[0])); diff --git a/tensorflow/core/graph/graph_constructor.cc b/tensorflow/core/graph/graph_constructor.cc index 6e72d73918..e19f4aebba 100644 --- a/tensorflow/core/graph/graph_constructor.cc +++ b/tensorflow/core/graph/graph_constructor.cc @@ -113,20 +113,20 @@ class GraphConstructor { typedef gtl::ArraySlice<const NodeDef*> NodeDefSlice; // versions and library may be nullptr - static Status Construct(const Options& opts, NodeDefSlice node_defs, - const VersionDef* versions, - const FunctionDefLibrary* library, Graph* g, - ShapeRefiner* refiner, - std::vector<std::pair<Node*, int>>* return_tensors, - std::vector<Node*>* return_nodes, - std::vector<TensorId>* unused_input_map_keys) { + static Status Construct( + const Options& opts, NodeDefSlice node_defs, const VersionDef* versions, + const FunctionDefLibrary* library, Graph* g, ShapeRefiner* refiner, + std::vector<std::pair<Node*, int>>* return_tensors, + std::vector<Node*>* return_nodes, + std::vector<TensorId>* missing_unused_input_map_keys) { if (versions) { TF_RETURN_IF_ERROR(CheckVersions(*versions, TF_GRAPH_DEF_VERSION, TF_GRAPH_DEF_VERSION_MIN_PRODUCER, "GraphDef", "graph")); } GraphConstructor c(opts, node_defs, versions, library, g, refiner, - return_tensors, return_nodes, unused_input_map_keys); + return_tensors, return_nodes, + missing_unused_input_map_keys); const Status s = c.TryImport(); if (!s.ok()) c.Undo(); return s; @@ -139,7 +139,7 @@ class GraphConstructor { ShapeRefiner* refiner, std::vector<std::pair<Node*, int>>* return_tensors, std::vector<Node*>* return_nodes, - std::vector<TensorId>* unused_input_map_keys) + std::vector<TensorId>* missing_unused_input_map_keys) : opts_(opts), node_defs_(node_defs), versions_(versions), @@ -150,7 +150,7 @@ class GraphConstructor { refiner_(refiner), return_tensors_(return_tensors), return_nodes_(return_nodes), - unused_input_map_keys_(unused_input_map_keys) {} + missing_unused_input_map_keys_(missing_unused_input_map_keys) {} Status TryImport() { TF_RETURN_IF_ERROR(EnsureNoNameCollisions()); @@ -162,6 +162,7 @@ class GraphConstructor { TF_RETURN_IF_ERROR(UpdateVersionDef()); TF_RETURN_IF_ERROR(PopulateReturnTensors()); TF_RETURN_IF_ERROR(PopulateReturnNodes()); + TF_RETURN_IF_ERROR(PopulateMissingUnusedInputMapKeys()); UpdateUniquifiedColocationNames(); FixupSourceAndSinkEdges(g_); return Status::OK(); @@ -176,6 +177,7 @@ class GraphConstructor { Status UpdateVersionDef(); Status PopulateReturnTensors(); Status PopulateReturnNodes(); + Status PopulateMissingUnusedInputMapKeys(); void Undo(); @@ -242,9 +244,10 @@ class GraphConstructor { std::vector<Node*>* return_nodes_; // May be null. Not owned. - std::vector<TensorId>* unused_input_map_keys_; + std::vector<TensorId>* missing_unused_input_map_keys_; - // Intermediate datastructure used to populate `unused_input_map_keys_`. + // Intermediate datastructure used to populate + // `missing_unused_input_map_keys_`. std::set<TensorId> used_input_map_keys_; // Mapping from node name to the index within node_defs_. @@ -1024,15 +1027,6 @@ Status GraphConstructor::Convert() { " nodes in a cycle"); } - // Update unused_input_map_keys_ - if (unused_input_map_keys_ != nullptr) { - for (const auto& pair : opts_.input_map) { - if (used_input_map_keys_.find(pair.first) == used_input_map_keys_.end()) { - unused_input_map_keys_->push_back(pair.first); - } - } - } - return Status::OK(); } @@ -1122,6 +1116,33 @@ Status GraphConstructor::PopulateReturnNodes() { return Status::OK(); } +Status GraphConstructor::PopulateMissingUnusedInputMapKeys() { + if (missing_unused_input_map_keys_ == nullptr) return Status::OK(); + for (const auto& input_map_pair : opts_.input_map) { + TensorId key = input_map_pair.first; + if (used_input_map_keys_.count(key) > 0) continue; + + auto pair = gdef_nodes_.find(key.first); + if (pair == gdef_nodes_.end()) { + // key's node doesn't exist in GraphDef + missing_unused_input_map_keys_->push_back(key); + continue; + } + + // Check that key's index is in bounds. Get the number of outputs from the + // NodeDef, rather than the imported Node, since the Node may not exist if + // opts_.skip_mapped_nodes is true. + const NodeDef* node_def = node_defs_[pair->second.gdef_index]; + const OpDef* op_def; + TF_RETURN_IF_ERROR(g_->op_registry()->LookUpOpDef(node_def->op(), &op_def)); + if (key.second >= op_def->output_arg_size()) { + // key's index out of bounds + missing_unused_input_map_keys_->push_back(key); + } + } + return Status::OK(); +} + void GraphConstructor::Undo() { for (const auto& iter : gdef_nodes_) { if (iter.second.node != nullptr) { @@ -1153,7 +1174,7 @@ Status ConvertGraphDefToGraph(const GraphConstructorOptions& opts, return GraphConstructor::Construct( opts, gdef.node(), &gdef.versions(), &gdef.library(), g, &refiner, /*return_tensors=*/nullptr, /*return_nodes=*/nullptr, - /*unused_input_map_keys=*/nullptr); + /*missing_unused_input_map_keys=*/nullptr); } Status ConvertNodeDefsToGraph(const GraphConstructorOptions& opts, @@ -1167,7 +1188,7 @@ Status ConvertNodeDefsToGraph(const GraphConstructorOptions& opts, return GraphConstructor::Construct(opts, node_defs, nullptr, nullptr, g, &refiner, /*return_tensors=*/nullptr, /*return_nodes=*/nullptr, - /*unused_input_map_keys=*/nullptr); + /*missing_unused_input_map_keys=*/nullptr); } Status ImportGraphDef(const ImportGraphDefOptions& opts, const GraphDef& gdef, @@ -1196,7 +1217,7 @@ Status ImportGraphDef(const ImportGraphDefOptions& opts, const GraphDef& gdef, if (results != nullptr) { if (!results->return_tensors.empty() || !results->return_nodes.empty() || - !results->unused_input_map_keys.empty()) { + !results->missing_unused_input_map_keys.empty()) { return errors::InvalidArgument( "All fields in results argument to ImportGraphDef() must be empty."); } @@ -1239,7 +1260,7 @@ Status ImportGraphDef(const ImportGraphDefOptions& opts, const GraphDef& gdef, return GraphConstructor::Construct( opts, gdef.node(), &gdef.versions(), &gdef.library(), g, refiner, &results->return_tensors, &results->return_nodes, - &results->unused_input_map_keys); + &results->missing_unused_input_map_keys); } } diff --git a/tensorflow/core/graph/graph_constructor.h b/tensorflow/core/graph/graph_constructor.h index b4dd2ba51a..07814b2ef7 100644 --- a/tensorflow/core/graph/graph_constructor.h +++ b/tensorflow/core/graph/graph_constructor.h @@ -148,9 +148,10 @@ struct ImportGraphDefResults { // The requested nodes associated with ImportGraphDefOptions::return_nodes. std::vector<Node*> return_nodes; - // Keys in ImportGraphDefOptions::input_map that weren't used as an input to - // any node in`gdef`. - std::vector<TensorId> unused_input_map_keys; + // Keys in ImportGraphDefOptions::input_map that don't appear in `gdef` and + // weren't used as an input to any node in `gdef`. These keys are likely due + // to typos, and callers may wish to treat their existence as an error. + std::vector<TensorId> missing_unused_input_map_keys; }; // Adds the graph in GraphDef `gdef` into an existing Graph `*g`. diff --git a/tensorflow/core/graph/graph_constructor_test.cc b/tensorflow/core/graph/graph_constructor_test.cc index 9be3de2388..01bb1ac748 100644 --- a/tensorflow/core/graph/graph_constructor_test.cc +++ b/tensorflow/core/graph/graph_constructor_test.cc @@ -1433,7 +1433,7 @@ TEST_F(GraphConstructorTest, ImportGraphDef_InputMapDuplicateNodeNames) { &refiner); } -TEST_F(GraphConstructorTest, ImportGraphDef_InputMapUnusedKeys) { +TEST_F(GraphConstructorTest, ImportGraphDef_InputMapMissingUnusedKeys) { ShapeRefiner refiner(TF_GRAPH_DEF_VERSION, graph_.op_registry()); // No input map @@ -1443,10 +1443,10 @@ TEST_F(GraphConstructorTest, ImportGraphDef_InputMapUnusedKeys) { "node { name: 'W1' op: 'TestParams' }" "node { name: 'input' op: 'TestInput' }", opts, &refiner, &results); - EXPECT_TRUE(results.unused_input_map_keys.empty()); + EXPECT_TRUE(results.missing_unused_input_map_keys.empty()); - // Non-empty unused_input_map_keys - results.unused_input_map_keys.push_back(TensorId()); + // Non-empty missing_unused_input_map_keys + results.missing_unused_input_map_keys.push_back(TensorId()); ExpectError( "node { name: 'W2' op: 'TestParams' }", opts, {"All fields in results argument to ImportGraphDef() must be empty."}, @@ -1454,13 +1454,16 @@ TEST_F(GraphConstructorTest, ImportGraphDef_InputMapUnusedKeys) { // Input map with some used, some unused keys const int kControlSlot = Graph::kControlSlot; - results.unused_input_map_keys.clear(); + results.missing_unused_input_map_keys.clear(); opts.input_map[TensorId("W2", kControlSlot)] = TensorId("W1", kControlSlot); opts.input_map[TensorId("new_input", 0)] = TensorId("input", 0); opts.input_map[TensorId("new_input", 1)] = TensorId("input", 0); - opts.input_map[TensorId("new_input", kControlSlot)] = - TensorId("input", kControlSlot); - opts.input_map[TensorId("t1", 1)] = TensorId("input", 0); + // Unused and missing (nonexistent index) + opts.input_map[TensorId("new_input", 3)] = TensorId("input", 0); + // Unused and missing (nonexistent node) + opts.input_map[TensorId("DNE", 0)] = TensorId("input", 0); + // Unused but not missing + opts.input_map[TensorId("t1", 0)] = TensorId("W1", 0); ExpectOK( R"EOF( node { name: 'W2' op: 'TestParams' } @@ -1470,9 +1473,36 @@ TEST_F(GraphConstructorTest, ImportGraphDef_InputMapUnusedKeys) { )EOF", opts, &refiner, &results); - std::vector<TensorId> expected_unused_keys = { - TensorId("new_input", kControlSlot), TensorId("t1", 1)}; - EXPECT_EQ(results.unused_input_map_keys, expected_unused_keys); + std::set<TensorId> expected_unused_keys = {TensorId("new_input", 3), + TensorId("DNE", 0)}; + ASSERT_EQ(results.missing_unused_input_map_keys.size(), + expected_unused_keys.size()); + + std::set<TensorId> actual_unused_keys( + results.missing_unused_input_map_keys.begin(), + results.missing_unused_input_map_keys.end()); + EXPECT_EQ(actual_unused_keys, expected_unused_keys); + + // Test edge case: node isn't imported due to skip_mapped_nodes, but we still + // have a bad input_map key involving it. + opts = ImportGraphDefOptions(); + opts.input_map[TensorId("new_input", 0)] = TensorId("input", 0); + opts.input_map[TensorId("new_input", 1)] = TensorId("input", 1); + // Index out of bounds + opts.input_map[TensorId("new_input", 2)] = TensorId("input", 1); + opts.skip_mapped_nodes = true; + opts.prefix = "import"; + results = ImportGraphDefResults(); + ExpectOK( + R"EOF( + node { name: 'W2' op: 'TestParams' } + node { name: 'new_input' op: 'TestInput' input: [ '^W2' ] } + node { name: 't1' op: 'TestMul' input: [ 'new_input:0', 'new_input:1' ] } + )EOF", + opts, &refiner, &results); + + ASSERT_EQ(results.missing_unused_input_map_keys.size(), 1); + EXPECT_EQ(results.missing_unused_input_map_keys[0], TensorId("new_input", 2)); } TEST_F(GraphConstructorTest, ImportGraphDef_InputMapWithUnboundInput) { @@ -1709,7 +1739,7 @@ TEST_F(GraphConstructorTest, ImportGraphDef_ReturnNodes) { // Check return tensors ASSERT_EQ(results.return_nodes.size(), 2); EXPECT_EQ(results.return_tensors.size(), 0); - EXPECT_EQ(results.unused_input_map_keys.size(), 0); + EXPECT_EQ(results.missing_unused_input_map_keys.size(), 0); EXPECT_EQ(results.return_nodes[0]->name(), "input"); EXPECT_EQ(results.return_nodes[1]->name(), "t1"); diff --git a/tensorflow/python/BUILD b/tensorflow/python/BUILD index 3566a36ddd..20944d1678 100644 --- a/tensorflow/python/BUILD +++ b/tensorflow/python/BUILD @@ -3018,6 +3018,7 @@ tf_cuda_library( "//tensorflow/core:direct_session", "//tensorflow/core:framework", "//tensorflow/core:framework_internal", + "//tensorflow/core:graph", "//tensorflow/core:lib", "//tensorflow/core:protos_all_cc", "//third_party/py/numpy:headers", diff --git a/tensorflow/python/client/tf_session.i b/tensorflow/python/client/tf_session.i index f57c5d73bc..e424e19c77 100644 --- a/tensorflow/python/client/tf_session.i +++ b/tensorflow/python/client/tf_session.i @@ -183,6 +183,23 @@ tensorflow::TF_OperationOutputConsumers_wrapper { } } +%ignore TF_ImportGraphDefResultsMissingUnusedInputMappings; +%unignore TF_ImportGraphDefResultsMissingUnusedInputMappings_wrapper; +// See comment for "%noexception TF_SessionRun_wrapper;" +%noexception TF_ImportGraphDefResultsMissingUnusedInputMappings_wrapper; + +%typemap(out) std::vector<string> +TF_ImportGraphDefResultsMissingUnusedInputMappings_wrapper{ + $result = PyList_New($1.size()); + if (!$result) { + SWIG_exception_fail(SWIG_MemoryError, "$symname: couldn't create list"); + } + for (size_t i = 0; i < $1.size(); ++i) { + const string& input_str = $1[i]; + PyList_SET_ITEM($result, i, PyBytes_FromStringAndSize(input_str.data(), + input_str.size())); + } +} //////////////////////////////////////////////////////////////////////////////// // BEGIN TYPEMAPS FOR tensorflow::TF_Run_wrapper() diff --git a/tensorflow/python/client/tf_session_helper.cc b/tensorflow/python/client/tf_session_helper.cc index a00fade7ac..efe50dc247 100644 --- a/tensorflow/python/client/tf_session_helper.cc +++ b/tensorflow/python/client/tf_session_helper.cc @@ -22,6 +22,7 @@ limitations under the License. #include "tensorflow/core/framework/allocator.h" #include "tensorflow/core/framework/log_memory.h" #include "tensorflow/core/framework/op_kernel.h" +#include "tensorflow/core/graph/tensor_id.h" #include "tensorflow/core/lib/core/coding.h" #include "tensorflow/core/platform/types.h" #include "tensorflow/core/util/equal_graph_def.h" @@ -439,4 +440,18 @@ std::vector<int64_t> TF_GraphGetTensorShape_wrapper(TF_Graph* graph, return dims; } +std::vector<string> TF_ImportGraphDefResultsMissingUnusedInputMappings_wrapper( + TF_ImportGraphDefResults* results) { + int num_missing_unused_input_mappings; + const char** src_names; + int* src_indexes; + TF_ImportGraphDefResultsMissingUnusedInputMappings( + results, &num_missing_unused_input_mappings, &src_names, &src_indexes); + std::vector<string> input_strs(num_missing_unused_input_mappings); + for (int i = 0; i < num_missing_unused_input_mappings; ++i) { + input_strs[i] = TensorId(src_names[i], src_indexes[i]).ToString(); + } + return input_strs; +} + } // namespace tensorflow diff --git a/tensorflow/python/client/tf_session_helper.h b/tensorflow/python/client/tf_session_helper.h index 3a8506de4d..cdb68d2a23 100644 --- a/tensorflow/python/client/tf_session_helper.h +++ b/tensorflow/python/client/tf_session_helper.h @@ -187,6 +187,10 @@ std::vector<int64_t> TF_GraphGetTensorShape_wrapper(TF_Graph* graph, int num_dims, TF_Status* status); +// Returns the string representations of the missing unused input mappings. +std::vector<string> TF_ImportGraphDefResultsMissingUnusedInputMappings_wrapper( + TF_ImportGraphDefResults* results); + } // namespace tensorflow #endif // TENSORFLOW_PYTHON_CLIENT_TF_SESSION_HELPER_H_ diff --git a/tensorflow/python/framework/importer.py b/tensorflow/python/framework/importer.py index 62765aff00..d74fb25bb3 100644 --- a/tensorflow/python/framework/importer.py +++ b/tensorflow/python/framework/importer.py @@ -478,7 +478,17 @@ def import_graph_def(graph_def, input_map=None, return_elements=None, f.add_to_graph(graph) # pylint: enable=protected-access - # TODO(skyewm): error if unused input map key + # Treat input mappings that don't appear in the graph as an error, because + # they are likely to be due to a typo. + missing_unused_input_keys = ( + c_api.TF_ImportGraphDefResultsMissingUnusedInputMappings_wrapper( + results)) + if missing_unused_input_keys: + missing_unused_input_keys = [compat.as_str(s) + for s in missing_unused_input_keys] + raise ValueError( + 'Attempted to map inputs that were not found in graph_def: [%s]' + % ', '.join(missing_unused_input_keys)) if return_elements is None: return None diff --git a/tensorflow/python/framework/importer_test.py b/tensorflow/python/framework/importer_test.py index 7bf13ba93d..0da651c607 100644 --- a/tensorflow/python/framework/importer_test.py +++ b/tensorflow/python/framework/importer_test.py @@ -570,20 +570,17 @@ class ImportGraphDefTest(test.TestCase): return_elements=["A:B:0"]) def testMissingInputMap(self): - if ops._USE_C_API: return # TODO(skyewm): make this work with C API - with ops.Graph().as_default(): - with self.assertRaises(ValueError) as e: + with self.assertRaisesRegexp( + ValueError, + r"Attempted to map inputs that were not found in graph_def: \[B:0\]"): importer.import_graph_def( self._MakeGraphDef(""" node { name: 'A' op: 'None' } """), input_map={"B:0": constant_op.constant(5.0)}) - self.assertTrue("not found in graph_def: [B:0]" in str(e.exception)) def testInputMapUnusedAsInput(self): - if ops._USE_C_API: return # TODO(skyewm): make this work with C API - with ops.Graph().as_default(): # Mapping an unused node output should succeed. importer.import_graph_def( @@ -593,13 +590,14 @@ class ImportGraphDefTest(test.TestCase): input_map={"A:0": constant_op.constant(5.0)}) # Mapping a non-existent output of an existing node should fail. - with self.assertRaises(ValueError) as e: + with self.assertRaisesRegexp( + ValueError, + r"Attempted to map inputs that were not found in graph_def: \[A:2\]"): importer.import_graph_def( self._MakeGraphDef(""" node { name: 'A' op: 'IntOutput' } """), input_map={"A:2": constant_op.constant(5.0)}) - self.assertTrue("not found in graph_def: [A:2]" in str(e.exception)) def testInputMapTypeMismatch(self): if ops._USE_C_API: |