author    Skye Wanderman-Milne <skyewm@google.com>  2017-12-12 10:58:31 -0800
committer TensorFlower Gardener <gardener@tensorflow.org>  2017-12-12 11:02:23 -0800
commit    968da4bf2722b1303cc223e8342357d62c27dfc1 (patch)
tree      dab12578913f6bfc11b50d92f39ae95fd7301963
parent    c8a5ffdeb2a17df2d2822c7a6df8a44f3ab85614 (diff)
Raise exception on missing unused input_map keys with C API enabled.
Without this change, the C++ ImportGraphDef API returns unused input_map keys (which are plumbed through to the C API as well). However, the Python import_graph_def API requires slightly different semantics: it throws an error for unused input_map keys that are missing from the GraphDef.

This change modifies the C and C++ APIs to limit the returned keys to those missing from the GraphDef, and plumbs this through to the C API-enabled import_graph_def implementation.

Note that this is a change to the existing C API. Luckily the modified method hasn't been released yet, so it's ok to change it.

PiperOrigin-RevId: 178783957
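As context for the semantics described above, here is a minimal C++ sketch (not part of the patch) of how a caller of the C++ ImportGraphDef API might turn the new missing_unused_input_map_keys field into an error, roughly what the Python import_graph_def layer now does. The include paths and the ImportGraphDef signature are taken from the files touched by this patch; the helper name, the "typo" key, and the assumption that dst_tensor already exists in *graph are illustrative only.

#include "tensorflow/core/graph/graph_constructor.h"
#include "tensorflow/core/lib/core/errors.h"

namespace tensorflow {

// Imports `gdef` into `*graph` and reports any input_map key that doesn't name
// a node (or a valid output index) in `gdef` as an InvalidArgument error.
Status ImportAndCheckInputMap(const GraphDef& gdef, Graph* graph,
                              ShapeRefiner* refiner, TensorId dst_tensor) {
  ImportGraphDefOptions opts;
  // "typo" is assumed not to exist in `gdef`, so this key ends up in
  // missing_unused_input_map_keys; keys that exist in `gdef` but simply go
  // unused are no longer reported after this change.
  opts.input_map[TensorId("typo", 0)] = dst_tensor;

  ImportGraphDefResults results;
  TF_RETURN_IF_ERROR(ImportGraphDef(opts, gdef, graph, refiner, &results));

  if (!results.missing_unused_input_map_keys.empty()) {
    return errors::InvalidArgument(
        "input_map key not found in GraphDef: ",
        results.missing_unused_input_map_keys[0].ToString());
  }
  return Status::OK();
}

}  // namespace tensorflow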
-rw-r--r--  tensorflow/c/c_api.cc  37
-rw-r--r--  tensorflow/c/c_api.h  14
-rw-r--r--  tensorflow/c/c_api_internal.h  8
-rw-r--r--  tensorflow/c/c_api_test.cc  4
-rw-r--r--  tensorflow/core/graph/graph_constructor.cc  71
-rw-r--r--  tensorflow/core/graph/graph_constructor.h  7
-rw-r--r--  tensorflow/core/graph/graph_constructor_test.cc  54
-rw-r--r--  tensorflow/python/BUILD  1
-rw-r--r--  tensorflow/python/client/tf_session.i  17
-rw-r--r--  tensorflow/python/client/tf_session_helper.cc  15
-rw-r--r--  tensorflow/python/client/tf_session_helper.h  4
-rw-r--r--  tensorflow/python/framework/importer.py  12
-rw-r--r--  tensorflow/python/framework/importer_test.py  14
13 files changed, 179 insertions, 79 deletions
diff --git a/tensorflow/c/c_api.cc b/tensorflow/c/c_api.cc
index 13253ced49..6f5abd074c 100644
--- a/tensorflow/c/c_api.cc
+++ b/tensorflow/c/c_api.cc
@@ -1917,12 +1917,12 @@ void TF_ImportGraphDefResultsReturnOperations(TF_ImportGraphDefResults* results,
*opers = results->return_nodes.data();
}
-void TF_ImportGraphDefResultsUnusedInputMappings(
- TF_ImportGraphDefResults* results, int* num_unused_input_mappings,
+void TF_ImportGraphDefResultsMissingUnusedInputMappings(
+ TF_ImportGraphDefResults* results, int* num_missing_unused_input_mappings,
const char*** src_names, int** src_indexes) {
- *num_unused_input_mappings = results->unused_key_names.size();
- *src_names = results->unused_key_names.data();
- *src_indexes = results->unused_key_indexes.data();
+ *num_missing_unused_input_mappings = results->missing_unused_key_names.size();
+ *src_names = results->missing_unused_key_names.data();
+ *src_indexes = results->missing_unused_key_indexes.data();
}
void TF_DeleteImportGraphDefResults(TF_ImportGraphDefResults* results) {
@@ -1962,18 +1962,21 @@ static void GraphImportGraphDefLocked(TF_Graph* graph, const GraphDef& def,
tf_results->return_nodes[i] = ToOperation(results.return_nodes[i]);
}
- // Populate unused map keys
- DCHECK(tf_results->unused_key_names.empty());
- DCHECK(tf_results->unused_key_indexes.empty());
- DCHECK(tf_results->unused_key_names_data.empty());
- tf_results->unused_key_names.resize(results.unused_input_map_keys.size());
- tf_results->unused_key_indexes.resize(results.unused_input_map_keys.size());
- for (int i = 0; i < results.unused_input_map_keys.size(); ++i) {
- TensorId id = results.unused_input_map_keys[i];
- tf_results->unused_key_names_data.push_back(id.first.ToString());
- tf_results->unused_key_names[i] =
- tf_results->unused_key_names_data.back().c_str();
- tf_results->unused_key_indexes[i] = id.second;
+ // Populate missing unused map keys
+ DCHECK(tf_results->missing_unused_key_names.empty());
+ DCHECK(tf_results->missing_unused_key_indexes.empty());
+ DCHECK(tf_results->missing_unused_key_names_data.empty());
+
+ size_t size = results.missing_unused_input_map_keys.size();
+ tf_results->missing_unused_key_names.resize(size);
+ tf_results->missing_unused_key_indexes.resize(size);
+
+ for (int i = 0; i < size; ++i) {
+ TensorId id = results.missing_unused_input_map_keys[i];
+ tf_results->missing_unused_key_names_data.push_back(id.first.ToString());
+ tf_results->missing_unused_key_names[i] =
+ tf_results->missing_unused_key_names_data.back().c_str();
+ tf_results->missing_unused_key_indexes[i] = id.second;
}
}
diff --git a/tensorflow/c/c_api.h b/tensorflow/c/c_api.h
index df7fe222b1..de9527f86d 100644
--- a/tensorflow/c/c_api.h
+++ b/tensorflow/c/c_api.h
@@ -962,16 +962,16 @@ TF_CAPI_EXPORT extern void TF_ImportGraphDefResultsReturnOperations(
TF_ImportGraphDefResults* results, int* num_opers, TF_Operation*** opers);
// Fetches any input mappings requested via
-// TF_ImportGraphDefOptionsAddInputMapping() that weren't used as input to any
-// node in the imported graph def. The number of fetched mappings is returned in
-// `num_unused_input_mappings`. The array of each mapping's source node name is
-// returned in `src_names`, and the array of each mapping's source index is
-// returned in `src_indexes`.
+// TF_ImportGraphDefOptionsAddInputMapping() that didn't appear in the GraphDef
+// and weren't used as input to any node in the imported graph def. The number
+// of fetched mappings is returned in `num_missing_unused_input_mappings`. The
+// array of each mapping's source node name is returned in `src_names`, and the
+// array of each mapping's source index is returned in `src_indexes`.
//
// `*src_names`, `*src_indexes`, and the memory backing each string in
// `src_names` are owned by and have the lifetime of `results`.
-TF_CAPI_EXPORT extern void TF_ImportGraphDefResultsUnusedInputMappings(
- TF_ImportGraphDefResults* results, int* num_unused_input_mappings,
+TF_CAPI_EXPORT extern void TF_ImportGraphDefResultsMissingUnusedInputMappings(
+ TF_ImportGraphDefResults* results, int* num_missing_unused_input_mappings,
const char*** src_names, int** src_indexes);
// Deletes a results object returned by TF_GraphImportGraphDefWithResults().
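For reference, a hedged sketch of driving the renamed entry point from C++ through the C API, modeled on the ImportGraphDef_MissingUnusedInputMappings test below. The helper name is illustrative; `graph_def` is assumed to hold a serialized GraphDef and `mapped_value` an output already present in `graph`.

#include <string>
#include <vector>

#include "tensorflow/c/c_api.h"

// Collects the missing-and-unused input_map keys reported for an import.
std::vector<std::string> MissingInputMapKeys(TF_Graph* graph,
                                             const TF_Buffer* graph_def,
                                             TF_Output mapped_value) {
  TF_Status* s = TF_NewStatus();
  TF_ImportGraphDefOptions* opts = TF_NewImportGraphDefOptions();
  // "fake" does not name a node in `graph_def`, so the key fake:0 is both
  // missing and unused and should be reported below.
  TF_ImportGraphDefOptionsAddInputMapping(opts, "fake", 0, mapped_value);

  TF_ImportGraphDefResults* results =
      TF_GraphImportGraphDefWithResults(graph, graph_def, opts, s);

  std::vector<std::string> keys;
  if (TF_GetCode(s) == TF_OK) {
    int num_missing_unused;
    const char** src_names;
    int* src_indexes;
    TF_ImportGraphDefResultsMissingUnusedInputMappings(
        results, &num_missing_unused, &src_names, &src_indexes);
    for (int i = 0; i < num_missing_unused; ++i) {
      // src_names/src_indexes are owned by `results`, so copy them out.
      keys.push_back(std::string(src_names[i]) + ":" +
                     std::to_string(src_indexes[i]));
    }
  }
  if (results != nullptr) TF_DeleteImportGraphDefResults(results);
  TF_DeleteImportGraphDefOptions(opts);
  TF_DeleteStatus(s);
  return keys;
}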
diff --git a/tensorflow/c/c_api_internal.h b/tensorflow/c/c_api_internal.h
index aac333d9e2..6df77a7f9b 100644
--- a/tensorflow/c/c_api_internal.h
+++ b/tensorflow/c/c_api_internal.h
@@ -143,11 +143,11 @@ struct TF_ImportGraphDefOptions {
struct TF_ImportGraphDefResults {
std::vector<TF_Output> return_tensors;
std::vector<TF_Operation*> return_nodes;
- std::vector<const char*> unused_key_names;
- std::vector<int> unused_key_indexes;
+ std::vector<const char*> missing_unused_key_names;
+ std::vector<int> missing_unused_key_indexes;
- // Backing memory for unused_key_names values.
- std::list<tensorflow::string> unused_key_names_data;
+ // Backing memory for missing_unused_key_names values.
+ std::list<tensorflow::string> missing_unused_key_names_data;
};
struct TF_DeviceList {
diff --git a/tensorflow/c/c_api_test.cc b/tensorflow/c/c_api_test.cc
index 6ec1db8ccf..4e89b4fc43 100644
--- a/tensorflow/c/c_api_test.cc
+++ b/tensorflow/c/c_api_test.cc
@@ -773,7 +773,7 @@ TEST(CAPI, ImportGraphDef_WithReturnOutputs) {
TF_DeleteStatus(s);
}
-TEST(CAPI, ImportGraphDef_UnusedInputMappings) {
+TEST(CAPI, ImportGraphDef_MissingUnusedInputMappings) {
TF_Status* s = TF_NewStatus();
TF_Graph* graph = TF_NewGraph();
@@ -816,7 +816,7 @@ TEST(CAPI, ImportGraphDef_UnusedInputMappings) {
int num_unused_input_mappings;
const char** src_names;
int* src_indexes;
- TF_ImportGraphDefResultsUnusedInputMappings(
+ TF_ImportGraphDefResultsMissingUnusedInputMappings(
results, &num_unused_input_mappings, &src_names, &src_indexes);
ASSERT_EQ(1, num_unused_input_mappings);
EXPECT_EQ(string("fake"), string(src_names[0]));
diff --git a/tensorflow/core/graph/graph_constructor.cc b/tensorflow/core/graph/graph_constructor.cc
index 6e72d73918..e19f4aebba 100644
--- a/tensorflow/core/graph/graph_constructor.cc
+++ b/tensorflow/core/graph/graph_constructor.cc
@@ -113,20 +113,20 @@ class GraphConstructor {
typedef gtl::ArraySlice<const NodeDef*> NodeDefSlice;
// versions and library may be nullptr
- static Status Construct(const Options& opts, NodeDefSlice node_defs,
- const VersionDef* versions,
- const FunctionDefLibrary* library, Graph* g,
- ShapeRefiner* refiner,
- std::vector<std::pair<Node*, int>>* return_tensors,
- std::vector<Node*>* return_nodes,
- std::vector<TensorId>* unused_input_map_keys) {
+ static Status Construct(
+ const Options& opts, NodeDefSlice node_defs, const VersionDef* versions,
+ const FunctionDefLibrary* library, Graph* g, ShapeRefiner* refiner,
+ std::vector<std::pair<Node*, int>>* return_tensors,
+ std::vector<Node*>* return_nodes,
+ std::vector<TensorId>* missing_unused_input_map_keys) {
if (versions) {
TF_RETURN_IF_ERROR(CheckVersions(*versions, TF_GRAPH_DEF_VERSION,
TF_GRAPH_DEF_VERSION_MIN_PRODUCER,
"GraphDef", "graph"));
}
GraphConstructor c(opts, node_defs, versions, library, g, refiner,
- return_tensors, return_nodes, unused_input_map_keys);
+ return_tensors, return_nodes,
+ missing_unused_input_map_keys);
const Status s = c.TryImport();
if (!s.ok()) c.Undo();
return s;
@@ -139,7 +139,7 @@ class GraphConstructor {
ShapeRefiner* refiner,
std::vector<std::pair<Node*, int>>* return_tensors,
std::vector<Node*>* return_nodes,
- std::vector<TensorId>* unused_input_map_keys)
+ std::vector<TensorId>* missing_unused_input_map_keys)
: opts_(opts),
node_defs_(node_defs),
versions_(versions),
@@ -150,7 +150,7 @@ class GraphConstructor {
refiner_(refiner),
return_tensors_(return_tensors),
return_nodes_(return_nodes),
- unused_input_map_keys_(unused_input_map_keys) {}
+ missing_unused_input_map_keys_(missing_unused_input_map_keys) {}
Status TryImport() {
TF_RETURN_IF_ERROR(EnsureNoNameCollisions());
@@ -162,6 +162,7 @@ class GraphConstructor {
TF_RETURN_IF_ERROR(UpdateVersionDef());
TF_RETURN_IF_ERROR(PopulateReturnTensors());
TF_RETURN_IF_ERROR(PopulateReturnNodes());
+ TF_RETURN_IF_ERROR(PopulateMissingUnusedInputMapKeys());
UpdateUniquifiedColocationNames();
FixupSourceAndSinkEdges(g_);
return Status::OK();
@@ -176,6 +177,7 @@ class GraphConstructor {
Status UpdateVersionDef();
Status PopulateReturnTensors();
Status PopulateReturnNodes();
+ Status PopulateMissingUnusedInputMapKeys();
void Undo();
@@ -242,9 +244,10 @@ class GraphConstructor {
std::vector<Node*>* return_nodes_;
// May be null. Not owned.
- std::vector<TensorId>* unused_input_map_keys_;
+ std::vector<TensorId>* missing_unused_input_map_keys_;
- // Intermediate datastructure used to populate `unused_input_map_keys_`.
+ // Intermediate datastructure used to populate
+ // `missing_unused_input_map_keys_`.
std::set<TensorId> used_input_map_keys_;
// Mapping from node name to the index within node_defs_.
@@ -1024,15 +1027,6 @@ Status GraphConstructor::Convert() {
" nodes in a cycle");
}
- // Update unused_input_map_keys_
- if (unused_input_map_keys_ != nullptr) {
- for (const auto& pair : opts_.input_map) {
- if (used_input_map_keys_.find(pair.first) == used_input_map_keys_.end()) {
- unused_input_map_keys_->push_back(pair.first);
- }
- }
- }
-
return Status::OK();
}
@@ -1122,6 +1116,33 @@ Status GraphConstructor::PopulateReturnNodes() {
return Status::OK();
}
+Status GraphConstructor::PopulateMissingUnusedInputMapKeys() {
+ if (missing_unused_input_map_keys_ == nullptr) return Status::OK();
+ for (const auto& input_map_pair : opts_.input_map) {
+ TensorId key = input_map_pair.first;
+ if (used_input_map_keys_.count(key) > 0) continue;
+
+ auto pair = gdef_nodes_.find(key.first);
+ if (pair == gdef_nodes_.end()) {
+ // key's node doesn't exist in GraphDef
+ missing_unused_input_map_keys_->push_back(key);
+ continue;
+ }
+
+ // Check that key's index is in bounds. Get the number of outputs from the
+ // NodeDef, rather than the imported Node, since the Node may not exist if
+ // opts_.skip_mapped_nodes is true.
+ const NodeDef* node_def = node_defs_[pair->second.gdef_index];
+ const OpDef* op_def;
+ TF_RETURN_IF_ERROR(g_->op_registry()->LookUpOpDef(node_def->op(), &op_def));
+ if (key.second >= op_def->output_arg_size()) {
+ // key's index out of bounds
+ missing_unused_input_map_keys_->push_back(key);
+ }
+ }
+ return Status::OK();
+}
+
void GraphConstructor::Undo() {
for (const auto& iter : gdef_nodes_) {
if (iter.second.node != nullptr) {
@@ -1153,7 +1174,7 @@ Status ConvertGraphDefToGraph(const GraphConstructorOptions& opts,
return GraphConstructor::Construct(
opts, gdef.node(), &gdef.versions(), &gdef.library(), g, &refiner,
/*return_tensors=*/nullptr, /*return_nodes=*/nullptr,
- /*unused_input_map_keys=*/nullptr);
+ /*missing_unused_input_map_keys=*/nullptr);
}
Status ConvertNodeDefsToGraph(const GraphConstructorOptions& opts,
@@ -1167,7 +1188,7 @@ Status ConvertNodeDefsToGraph(const GraphConstructorOptions& opts,
return GraphConstructor::Construct(opts, node_defs, nullptr, nullptr, g,
&refiner, /*return_tensors=*/nullptr,
/*return_nodes=*/nullptr,
- /*unused_input_map_keys=*/nullptr);
+ /*missing_unused_input_map_keys=*/nullptr);
}
Status ImportGraphDef(const ImportGraphDefOptions& opts, const GraphDef& gdef,
@@ -1196,7 +1217,7 @@ Status ImportGraphDef(const ImportGraphDefOptions& opts, const GraphDef& gdef,
if (results != nullptr) {
if (!results->return_tensors.empty() || !results->return_nodes.empty() ||
- !results->unused_input_map_keys.empty()) {
+ !results->missing_unused_input_map_keys.empty()) {
return errors::InvalidArgument(
"All fields in results argument to ImportGraphDef() must be empty.");
}
@@ -1239,7 +1260,7 @@ Status ImportGraphDef(const ImportGraphDefOptions& opts, const GraphDef& gdef,
return GraphConstructor::Construct(
opts, gdef.node(), &gdef.versions(), &gdef.library(), g, refiner,
&results->return_tensors, &results->return_nodes,
- &results->unused_input_map_keys);
+ &results->missing_unused_input_map_keys);
}
}
diff --git a/tensorflow/core/graph/graph_constructor.h b/tensorflow/core/graph/graph_constructor.h
index b4dd2ba51a..07814b2ef7 100644
--- a/tensorflow/core/graph/graph_constructor.h
+++ b/tensorflow/core/graph/graph_constructor.h
@@ -148,9 +148,10 @@ struct ImportGraphDefResults {
// The requested nodes associated with ImportGraphDefOptions::return_nodes.
std::vector<Node*> return_nodes;
- // Keys in ImportGraphDefOptions::input_map that weren't used as an input to
- // any node in`gdef`.
- std::vector<TensorId> unused_input_map_keys;
+ // Keys in ImportGraphDefOptions::input_map that don't appear in `gdef` and
+ // weren't used as an input to any node in `gdef`. These keys are likely due
+ // to typos, and callers may wish to treat their existence as an error.
+ std::vector<TensorId> missing_unused_input_map_keys;
};
// Adds the graph in GraphDef `gdef` into an existing Graph `*g`.
diff --git a/tensorflow/core/graph/graph_constructor_test.cc b/tensorflow/core/graph/graph_constructor_test.cc
index 9be3de2388..01bb1ac748 100644
--- a/tensorflow/core/graph/graph_constructor_test.cc
+++ b/tensorflow/core/graph/graph_constructor_test.cc
@@ -1433,7 +1433,7 @@ TEST_F(GraphConstructorTest, ImportGraphDef_InputMapDuplicateNodeNames) {
&refiner);
}
-TEST_F(GraphConstructorTest, ImportGraphDef_InputMapUnusedKeys) {
+TEST_F(GraphConstructorTest, ImportGraphDef_InputMapMissingUnusedKeys) {
ShapeRefiner refiner(TF_GRAPH_DEF_VERSION, graph_.op_registry());
// No input map
@@ -1443,10 +1443,10 @@ TEST_F(GraphConstructorTest, ImportGraphDef_InputMapUnusedKeys) {
"node { name: 'W1' op: 'TestParams' }"
"node { name: 'input' op: 'TestInput' }",
opts, &refiner, &results);
- EXPECT_TRUE(results.unused_input_map_keys.empty());
+ EXPECT_TRUE(results.missing_unused_input_map_keys.empty());
- // Non-empty unused_input_map_keys
- results.unused_input_map_keys.push_back(TensorId());
+ // Non-empty missing_unused_input_map_keys
+ results.missing_unused_input_map_keys.push_back(TensorId());
ExpectError(
"node { name: 'W2' op: 'TestParams' }", opts,
{"All fields in results argument to ImportGraphDef() must be empty."},
@@ -1454,13 +1454,16 @@ TEST_F(GraphConstructorTest, ImportGraphDef_InputMapUnusedKeys) {
// Input map with some used, some unused keys
const int kControlSlot = Graph::kControlSlot;
- results.unused_input_map_keys.clear();
+ results.missing_unused_input_map_keys.clear();
opts.input_map[TensorId("W2", kControlSlot)] = TensorId("W1", kControlSlot);
opts.input_map[TensorId("new_input", 0)] = TensorId("input", 0);
opts.input_map[TensorId("new_input", 1)] = TensorId("input", 0);
- opts.input_map[TensorId("new_input", kControlSlot)] =
- TensorId("input", kControlSlot);
- opts.input_map[TensorId("t1", 1)] = TensorId("input", 0);
+ // Unused and missing (nonexistent index)
+ opts.input_map[TensorId("new_input", 3)] = TensorId("input", 0);
+ // Unused and missing (nonexistent node)
+ opts.input_map[TensorId("DNE", 0)] = TensorId("input", 0);
+ // Unused but not missing
+ opts.input_map[TensorId("t1", 0)] = TensorId("W1", 0);
ExpectOK(
R"EOF(
node { name: 'W2' op: 'TestParams' }
@@ -1470,9 +1473,36 @@ TEST_F(GraphConstructorTest, ImportGraphDef_InputMapUnusedKeys) {
)EOF",
opts, &refiner, &results);
- std::vector<TensorId> expected_unused_keys = {
- TensorId("new_input", kControlSlot), TensorId("t1", 1)};
- EXPECT_EQ(results.unused_input_map_keys, expected_unused_keys);
+ std::set<TensorId> expected_unused_keys = {TensorId("new_input", 3),
+ TensorId("DNE", 0)};
+ ASSERT_EQ(results.missing_unused_input_map_keys.size(),
+ expected_unused_keys.size());
+
+ std::set<TensorId> actual_unused_keys(
+ results.missing_unused_input_map_keys.begin(),
+ results.missing_unused_input_map_keys.end());
+ EXPECT_EQ(actual_unused_keys, expected_unused_keys);
+
+ // Test edge case: node isn't imported due to skip_mapped_nodes, but we still
+ // have a bad input_map key involving it.
+ opts = ImportGraphDefOptions();
+ opts.input_map[TensorId("new_input", 0)] = TensorId("input", 0);
+ opts.input_map[TensorId("new_input", 1)] = TensorId("input", 1);
+ // Index out of bounds
+ opts.input_map[TensorId("new_input", 2)] = TensorId("input", 1);
+ opts.skip_mapped_nodes = true;
+ opts.prefix = "import";
+ results = ImportGraphDefResults();
+ ExpectOK(
+ R"EOF(
+ node { name: 'W2' op: 'TestParams' }
+ node { name: 'new_input' op: 'TestInput' input: [ '^W2' ] }
+ node { name: 't1' op: 'TestMul' input: [ 'new_input:0', 'new_input:1' ] }
+ )EOF",
+ opts, &refiner, &results);
+
+ ASSERT_EQ(results.missing_unused_input_map_keys.size(), 1);
+ EXPECT_EQ(results.missing_unused_input_map_keys[0], TensorId("new_input", 2));
}
TEST_F(GraphConstructorTest, ImportGraphDef_InputMapWithUnboundInput) {
@@ -1709,7 +1739,7 @@ TEST_F(GraphConstructorTest, ImportGraphDef_ReturnNodes) {
// Check return tensors
ASSERT_EQ(results.return_nodes.size(), 2);
EXPECT_EQ(results.return_tensors.size(), 0);
- EXPECT_EQ(results.unused_input_map_keys.size(), 0);
+ EXPECT_EQ(results.missing_unused_input_map_keys.size(), 0);
EXPECT_EQ(results.return_nodes[0]->name(), "input");
EXPECT_EQ(results.return_nodes[1]->name(), "t1");
diff --git a/tensorflow/python/BUILD b/tensorflow/python/BUILD
index 3566a36ddd..20944d1678 100644
--- a/tensorflow/python/BUILD
+++ b/tensorflow/python/BUILD
@@ -3018,6 +3018,7 @@ tf_cuda_library(
"//tensorflow/core:direct_session",
"//tensorflow/core:framework",
"//tensorflow/core:framework_internal",
+ "//tensorflow/core:graph",
"//tensorflow/core:lib",
"//tensorflow/core:protos_all_cc",
"//third_party/py/numpy:headers",
diff --git a/tensorflow/python/client/tf_session.i b/tensorflow/python/client/tf_session.i
index f57c5d73bc..e424e19c77 100644
--- a/tensorflow/python/client/tf_session.i
+++ b/tensorflow/python/client/tf_session.i
@@ -183,6 +183,23 @@ tensorflow::TF_OperationOutputConsumers_wrapper {
}
}
+%ignore TF_ImportGraphDefResultsMissingUnusedInputMappings;
+%unignore TF_ImportGraphDefResultsMissingUnusedInputMappings_wrapper;
+// See comment for "%noexception TF_SessionRun_wrapper;"
+%noexception TF_ImportGraphDefResultsMissingUnusedInputMappings_wrapper;
+
+%typemap(out) std::vector<string>
+TF_ImportGraphDefResultsMissingUnusedInputMappings_wrapper{
+ $result = PyList_New($1.size());
+ if (!$result) {
+ SWIG_exception_fail(SWIG_MemoryError, "$symname: couldn't create list");
+ }
+ for (size_t i = 0; i < $1.size(); ++i) {
+ const string& input_str = $1[i];
+ PyList_SET_ITEM($result, i, PyBytes_FromStringAndSize(input_str.data(),
+ input_str.size()));
+ }
+}
////////////////////////////////////////////////////////////////////////////////
// BEGIN TYPEMAPS FOR tensorflow::TF_Run_wrapper()
diff --git a/tensorflow/python/client/tf_session_helper.cc b/tensorflow/python/client/tf_session_helper.cc
index a00fade7ac..efe50dc247 100644
--- a/tensorflow/python/client/tf_session_helper.cc
+++ b/tensorflow/python/client/tf_session_helper.cc
@@ -22,6 +22,7 @@ limitations under the License.
#include "tensorflow/core/framework/allocator.h"
#include "tensorflow/core/framework/log_memory.h"
#include "tensorflow/core/framework/op_kernel.h"
+#include "tensorflow/core/graph/tensor_id.h"
#include "tensorflow/core/lib/core/coding.h"
#include "tensorflow/core/platform/types.h"
#include "tensorflow/core/util/equal_graph_def.h"
@@ -439,4 +440,18 @@ std::vector<int64_t> TF_GraphGetTensorShape_wrapper(TF_Graph* graph,
return dims;
}
+std::vector<string> TF_ImportGraphDefResultsMissingUnusedInputMappings_wrapper(
+ TF_ImportGraphDefResults* results) {
+ int num_missing_unused_input_mappings;
+ const char** src_names;
+ int* src_indexes;
+ TF_ImportGraphDefResultsMissingUnusedInputMappings(
+ results, &num_missing_unused_input_mappings, &src_names, &src_indexes);
+ std::vector<string> input_strs(num_missing_unused_input_mappings);
+ for (int i = 0; i < num_missing_unused_input_mappings; ++i) {
+ input_strs[i] = TensorId(src_names[i], src_indexes[i]).ToString();
+ }
+ return input_strs;
+}
+
} // namespace tensorflow
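The wrapper above returns each key in TensorId string form, which is exactly what importer.py splices into its new error message. A small hedged illustration (standalone, not part of the patch), using the strings the updated importer tests below expect:

#include <iostream>

#include "tensorflow/core/graph/tensor_id.h"

int main() {
  using tensorflow::TensorId;
  // These match the expectations in importer_test.py, e.g.
  // "Attempted to map inputs that were not found in graph_def: [B:0]".
  std::cout << TensorId("B", 0).ToString() << "\n";  // prints "B:0"
  std::cout << TensorId("A", 2).ToString() << "\n";  // prints "A:2"
  return 0;
}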
diff --git a/tensorflow/python/client/tf_session_helper.h b/tensorflow/python/client/tf_session_helper.h
index 3a8506de4d..cdb68d2a23 100644
--- a/tensorflow/python/client/tf_session_helper.h
+++ b/tensorflow/python/client/tf_session_helper.h
@@ -187,6 +187,10 @@ std::vector<int64_t> TF_GraphGetTensorShape_wrapper(TF_Graph* graph,
int num_dims,
TF_Status* status);
+// Returns the string representations of the missing unused input mappings.
+std::vector<string> TF_ImportGraphDefResultsMissingUnusedInputMappings_wrapper(
+ TF_ImportGraphDefResults* results);
+
} // namespace tensorflow
#endif // TENSORFLOW_PYTHON_CLIENT_TF_SESSION_HELPER_H_
diff --git a/tensorflow/python/framework/importer.py b/tensorflow/python/framework/importer.py
index 62765aff00..d74fb25bb3 100644
--- a/tensorflow/python/framework/importer.py
+++ b/tensorflow/python/framework/importer.py
@@ -478,7 +478,17 @@ def import_graph_def(graph_def, input_map=None, return_elements=None,
f.add_to_graph(graph)
# pylint: enable=protected-access
- # TODO(skyewm): error if unused input map key
+ # Treat input mappings that don't appear in the graph as an error, because
+ # they are likely to be due to a typo.
+ missing_unused_input_keys = (
+ c_api.TF_ImportGraphDefResultsMissingUnusedInputMappings_wrapper(
+ results))
+ if missing_unused_input_keys:
+ missing_unused_input_keys = [compat.as_str(s)
+ for s in missing_unused_input_keys]
+ raise ValueError(
+ 'Attempted to map inputs that were not found in graph_def: [%s]'
+ % ', '.join(missing_unused_input_keys))
if return_elements is None:
return None
diff --git a/tensorflow/python/framework/importer_test.py b/tensorflow/python/framework/importer_test.py
index 7bf13ba93d..0da651c607 100644
--- a/tensorflow/python/framework/importer_test.py
+++ b/tensorflow/python/framework/importer_test.py
@@ -570,20 +570,17 @@ class ImportGraphDefTest(test.TestCase):
return_elements=["A:B:0"])
def testMissingInputMap(self):
- if ops._USE_C_API: return # TODO(skyewm): make this work with C API
-
with ops.Graph().as_default():
- with self.assertRaises(ValueError) as e:
+ with self.assertRaisesRegexp(
+ ValueError,
+ r"Attempted to map inputs that were not found in graph_def: \[B:0\]"):
importer.import_graph_def(
self._MakeGraphDef("""
node { name: 'A' op: 'None' }
"""),
input_map={"B:0": constant_op.constant(5.0)})
- self.assertTrue("not found in graph_def: [B:0]" in str(e.exception))
def testInputMapUnusedAsInput(self):
- if ops._USE_C_API: return # TODO(skyewm): make this work with C API
-
with ops.Graph().as_default():
# Mapping an unused node output should succeed.
importer.import_graph_def(
@@ -593,13 +590,14 @@ class ImportGraphDefTest(test.TestCase):
input_map={"A:0": constant_op.constant(5.0)})
# Mapping a non-existent output of an existing node should fail.
- with self.assertRaises(ValueError) as e:
+ with self.assertRaisesRegexp(
+ ValueError,
+ r"Attempted to map inputs that were not found in graph_def: \[A:2\]"):
importer.import_graph_def(
self._MakeGraphDef("""
node { name: 'A' op: 'IntOutput' }
"""),
input_map={"A:2": constant_op.constant(5.0)})
- self.assertTrue("not found in graph_def: [A:2]" in str(e.exception))
def testInputMapTypeMismatch(self):
if ops._USE_C_API: