diff options
author | Skye Wanderman-Milne <skyewm@google.com> | 2017-12-12 10:58:31 -0800 |
---|---|---|
committer | TensorFlower Gardener <gardener@tensorflow.org> | 2017-12-12 11:02:23 -0800 |
commit | 968da4bf2722b1303cc223e8342357d62c27dfc1 (patch) | |
tree | dab12578913f6bfc11b50d92f39ae95fd7301963 /tensorflow/core/graph/graph_constructor.cc | |
parent | c8a5ffdeb2a17df2d2822c7a6df8a44f3ab85614 (diff) |
Raise exception on missing unused input_map keys with C API enabled.
Without this change, the C++ ImportGraphDef API returns unused
input_map keys (which are plumbed through to the C API as
well). However, the Python import_graph_def API requires slightly
different semantics: it throws an error for unused input_map keys that
are missing from the GraphDef.
This change modifies the C and C++ APIs to limit the returned keys to
those missing from the GraphDef, and plumbs this through to the C
API-enabled import_graph_def implementation.
Note that this is a change to the existing C API. Luckily the modified
method hasn't been released yet, so it's ok to change it.
PiperOrigin-RevId: 178783957
Diffstat (limited to 'tensorflow/core/graph/graph_constructor.cc')
-rw-r--r-- | tensorflow/core/graph/graph_constructor.cc | 71 |
1 files changed, 46 insertions, 25 deletions
diff --git a/tensorflow/core/graph/graph_constructor.cc b/tensorflow/core/graph/graph_constructor.cc index 6e72d73918..e19f4aebba 100644 --- a/tensorflow/core/graph/graph_constructor.cc +++ b/tensorflow/core/graph/graph_constructor.cc @@ -113,20 +113,20 @@ class GraphConstructor { typedef gtl::ArraySlice<const NodeDef*> NodeDefSlice; // versions and library may be nullptr - static Status Construct(const Options& opts, NodeDefSlice node_defs, - const VersionDef* versions, - const FunctionDefLibrary* library, Graph* g, - ShapeRefiner* refiner, - std::vector<std::pair<Node*, int>>* return_tensors, - std::vector<Node*>* return_nodes, - std::vector<TensorId>* unused_input_map_keys) { + static Status Construct( + const Options& opts, NodeDefSlice node_defs, const VersionDef* versions, + const FunctionDefLibrary* library, Graph* g, ShapeRefiner* refiner, + std::vector<std::pair<Node*, int>>* return_tensors, + std::vector<Node*>* return_nodes, + std::vector<TensorId>* missing_unused_input_map_keys) { if (versions) { TF_RETURN_IF_ERROR(CheckVersions(*versions, TF_GRAPH_DEF_VERSION, TF_GRAPH_DEF_VERSION_MIN_PRODUCER, "GraphDef", "graph")); } GraphConstructor c(opts, node_defs, versions, library, g, refiner, - return_tensors, return_nodes, unused_input_map_keys); + return_tensors, return_nodes, + missing_unused_input_map_keys); const Status s = c.TryImport(); if (!s.ok()) c.Undo(); return s; @@ -139,7 +139,7 @@ class GraphConstructor { ShapeRefiner* refiner, std::vector<std::pair<Node*, int>>* return_tensors, std::vector<Node*>* return_nodes, - std::vector<TensorId>* unused_input_map_keys) + std::vector<TensorId>* missing_unused_input_map_keys) : opts_(opts), node_defs_(node_defs), versions_(versions), @@ -150,7 +150,7 @@ class GraphConstructor { refiner_(refiner), return_tensors_(return_tensors), return_nodes_(return_nodes), - unused_input_map_keys_(unused_input_map_keys) {} + missing_unused_input_map_keys_(missing_unused_input_map_keys) {} Status TryImport() { TF_RETURN_IF_ERROR(EnsureNoNameCollisions()); @@ -162,6 +162,7 @@ class GraphConstructor { TF_RETURN_IF_ERROR(UpdateVersionDef()); TF_RETURN_IF_ERROR(PopulateReturnTensors()); TF_RETURN_IF_ERROR(PopulateReturnNodes()); + TF_RETURN_IF_ERROR(PopulateMissingUnusedInputMapKeys()); UpdateUniquifiedColocationNames(); FixupSourceAndSinkEdges(g_); return Status::OK(); @@ -176,6 +177,7 @@ class GraphConstructor { Status UpdateVersionDef(); Status PopulateReturnTensors(); Status PopulateReturnNodes(); + Status PopulateMissingUnusedInputMapKeys(); void Undo(); @@ -242,9 +244,10 @@ class GraphConstructor { std::vector<Node*>* return_nodes_; // May be null. Not owned. - std::vector<TensorId>* unused_input_map_keys_; + std::vector<TensorId>* missing_unused_input_map_keys_; - // Intermediate datastructure used to populate `unused_input_map_keys_`. + // Intermediate datastructure used to populate + // `missing_unused_input_map_keys_`. std::set<TensorId> used_input_map_keys_; // Mapping from node name to the index within node_defs_. @@ -1024,15 +1027,6 @@ Status GraphConstructor::Convert() { " nodes in a cycle"); } - // Update unused_input_map_keys_ - if (unused_input_map_keys_ != nullptr) { - for (const auto& pair : opts_.input_map) { - if (used_input_map_keys_.find(pair.first) == used_input_map_keys_.end()) { - unused_input_map_keys_->push_back(pair.first); - } - } - } - return Status::OK(); } @@ -1122,6 +1116,33 @@ Status GraphConstructor::PopulateReturnNodes() { return Status::OK(); } +Status GraphConstructor::PopulateMissingUnusedInputMapKeys() { + if (missing_unused_input_map_keys_ == nullptr) return Status::OK(); + for (const auto& input_map_pair : opts_.input_map) { + TensorId key = input_map_pair.first; + if (used_input_map_keys_.count(key) > 0) continue; + + auto pair = gdef_nodes_.find(key.first); + if (pair == gdef_nodes_.end()) { + // key's node doesn't exist in GraphDef + missing_unused_input_map_keys_->push_back(key); + continue; + } + + // Check that key's index is in bounds. Get the number of outputs from the + // NodeDef, rather than the imported Node, since the Node may not exist if + // opts_.skip_mapped_nodes is true. + const NodeDef* node_def = node_defs_[pair->second.gdef_index]; + const OpDef* op_def; + TF_RETURN_IF_ERROR(g_->op_registry()->LookUpOpDef(node_def->op(), &op_def)); + if (key.second >= op_def->output_arg_size()) { + // key's index out of bounds + missing_unused_input_map_keys_->push_back(key); + } + } + return Status::OK(); +} + void GraphConstructor::Undo() { for (const auto& iter : gdef_nodes_) { if (iter.second.node != nullptr) { @@ -1153,7 +1174,7 @@ Status ConvertGraphDefToGraph(const GraphConstructorOptions& opts, return GraphConstructor::Construct( opts, gdef.node(), &gdef.versions(), &gdef.library(), g, &refiner, /*return_tensors=*/nullptr, /*return_nodes=*/nullptr, - /*unused_input_map_keys=*/nullptr); + /*missing_unused_input_map_keys=*/nullptr); } Status ConvertNodeDefsToGraph(const GraphConstructorOptions& opts, @@ -1167,7 +1188,7 @@ Status ConvertNodeDefsToGraph(const GraphConstructorOptions& opts, return GraphConstructor::Construct(opts, node_defs, nullptr, nullptr, g, &refiner, /*return_tensors=*/nullptr, /*return_nodes=*/nullptr, - /*unused_input_map_keys=*/nullptr); + /*missing_unused_input_map_keys=*/nullptr); } Status ImportGraphDef(const ImportGraphDefOptions& opts, const GraphDef& gdef, @@ -1196,7 +1217,7 @@ Status ImportGraphDef(const ImportGraphDefOptions& opts, const GraphDef& gdef, if (results != nullptr) { if (!results->return_tensors.empty() || !results->return_nodes.empty() || - !results->unused_input_map_keys.empty()) { + !results->missing_unused_input_map_keys.empty()) { return errors::InvalidArgument( "All fields in results argument to ImportGraphDef() must be empty."); } @@ -1239,7 +1260,7 @@ Status ImportGraphDef(const ImportGraphDefOptions& opts, const GraphDef& gdef, return GraphConstructor::Construct( opts, gdef.node(), &gdef.versions(), &gdef.library(), g, refiner, &results->return_tensors, &results->return_nodes, - &results->unused_input_map_keys); + &results->missing_unused_input_map_keys); } } |