aboutsummaryrefslogtreecommitdiffhomepage
path: root/tensorflow/core/graph/graph_constructor.cc
diff options
context:
space:
mode:
authorGravatar Skye Wanderman-Milne <skyewm@google.com>2017-12-12 10:58:31 -0800
committerGravatar TensorFlower Gardener <gardener@tensorflow.org>2017-12-12 11:02:23 -0800
commit968da4bf2722b1303cc223e8342357d62c27dfc1 (patch)
treedab12578913f6bfc11b50d92f39ae95fd7301963 /tensorflow/core/graph/graph_constructor.cc
parentc8a5ffdeb2a17df2d2822c7a6df8a44f3ab85614 (diff)
Raise exception on missing unused input_map keys with C API enabled.
Without this change, the C++ ImportGraphDef API returns unused input_map keys (which are plumbed through to the C API as well). However, the Python import_graph_def API requires slightly different semantics: it throws an error for unused input_map keys that are missing from the GraphDef. This change modifies the C and C++ APIs to limit the returned keys to those missing from the GraphDef, and plumbs this through to the C API-enabled import_graph_def implementation. Note that this is a change to the existing C API. Luckily the modified method hasn't been released yet, so it's ok to change it. PiperOrigin-RevId: 178783957
Diffstat (limited to 'tensorflow/core/graph/graph_constructor.cc')
-rw-r--r--tensorflow/core/graph/graph_constructor.cc71
1 files changed, 46 insertions, 25 deletions
diff --git a/tensorflow/core/graph/graph_constructor.cc b/tensorflow/core/graph/graph_constructor.cc
index 6e72d73918..e19f4aebba 100644
--- a/tensorflow/core/graph/graph_constructor.cc
+++ b/tensorflow/core/graph/graph_constructor.cc
@@ -113,20 +113,20 @@ class GraphConstructor {
typedef gtl::ArraySlice<const NodeDef*> NodeDefSlice;
// versions and library may be nullptr
- static Status Construct(const Options& opts, NodeDefSlice node_defs,
- const VersionDef* versions,
- const FunctionDefLibrary* library, Graph* g,
- ShapeRefiner* refiner,
- std::vector<std::pair<Node*, int>>* return_tensors,
- std::vector<Node*>* return_nodes,
- std::vector<TensorId>* unused_input_map_keys) {
+ static Status Construct(
+ const Options& opts, NodeDefSlice node_defs, const VersionDef* versions,
+ const FunctionDefLibrary* library, Graph* g, ShapeRefiner* refiner,
+ std::vector<std::pair<Node*, int>>* return_tensors,
+ std::vector<Node*>* return_nodes,
+ std::vector<TensorId>* missing_unused_input_map_keys) {
if (versions) {
TF_RETURN_IF_ERROR(CheckVersions(*versions, TF_GRAPH_DEF_VERSION,
TF_GRAPH_DEF_VERSION_MIN_PRODUCER,
"GraphDef", "graph"));
}
GraphConstructor c(opts, node_defs, versions, library, g, refiner,
- return_tensors, return_nodes, unused_input_map_keys);
+ return_tensors, return_nodes,
+ missing_unused_input_map_keys);
const Status s = c.TryImport();
if (!s.ok()) c.Undo();
return s;
@@ -139,7 +139,7 @@ class GraphConstructor {
ShapeRefiner* refiner,
std::vector<std::pair<Node*, int>>* return_tensors,
std::vector<Node*>* return_nodes,
- std::vector<TensorId>* unused_input_map_keys)
+ std::vector<TensorId>* missing_unused_input_map_keys)
: opts_(opts),
node_defs_(node_defs),
versions_(versions),
@@ -150,7 +150,7 @@ class GraphConstructor {
refiner_(refiner),
return_tensors_(return_tensors),
return_nodes_(return_nodes),
- unused_input_map_keys_(unused_input_map_keys) {}
+ missing_unused_input_map_keys_(missing_unused_input_map_keys) {}
Status TryImport() {
TF_RETURN_IF_ERROR(EnsureNoNameCollisions());
@@ -162,6 +162,7 @@ class GraphConstructor {
TF_RETURN_IF_ERROR(UpdateVersionDef());
TF_RETURN_IF_ERROR(PopulateReturnTensors());
TF_RETURN_IF_ERROR(PopulateReturnNodes());
+ TF_RETURN_IF_ERROR(PopulateMissingUnusedInputMapKeys());
UpdateUniquifiedColocationNames();
FixupSourceAndSinkEdges(g_);
return Status::OK();
@@ -176,6 +177,7 @@ class GraphConstructor {
Status UpdateVersionDef();
Status PopulateReturnTensors();
Status PopulateReturnNodes();
+ Status PopulateMissingUnusedInputMapKeys();
void Undo();
@@ -242,9 +244,10 @@ class GraphConstructor {
std::vector<Node*>* return_nodes_;
// May be null. Not owned.
- std::vector<TensorId>* unused_input_map_keys_;
+ std::vector<TensorId>* missing_unused_input_map_keys_;
- // Intermediate datastructure used to populate `unused_input_map_keys_`.
+ // Intermediate datastructure used to populate
+ // `missing_unused_input_map_keys_`.
std::set<TensorId> used_input_map_keys_;
// Mapping from node name to the index within node_defs_.
@@ -1024,15 +1027,6 @@ Status GraphConstructor::Convert() {
" nodes in a cycle");
}
- // Update unused_input_map_keys_
- if (unused_input_map_keys_ != nullptr) {
- for (const auto& pair : opts_.input_map) {
- if (used_input_map_keys_.find(pair.first) == used_input_map_keys_.end()) {
- unused_input_map_keys_->push_back(pair.first);
- }
- }
- }
-
return Status::OK();
}
@@ -1122,6 +1116,33 @@ Status GraphConstructor::PopulateReturnNodes() {
return Status::OK();
}
+Status GraphConstructor::PopulateMissingUnusedInputMapKeys() {
+ if (missing_unused_input_map_keys_ == nullptr) return Status::OK();
+ for (const auto& input_map_pair : opts_.input_map) {
+ TensorId key = input_map_pair.first;
+ if (used_input_map_keys_.count(key) > 0) continue;
+
+ auto pair = gdef_nodes_.find(key.first);
+ if (pair == gdef_nodes_.end()) {
+ // key's node doesn't exist in GraphDef
+ missing_unused_input_map_keys_->push_back(key);
+ continue;
+ }
+
+ // Check that key's index is in bounds. Get the number of outputs from the
+ // NodeDef, rather than the imported Node, since the Node may not exist if
+ // opts_.skip_mapped_nodes is true.
+ const NodeDef* node_def = node_defs_[pair->second.gdef_index];
+ const OpDef* op_def;
+ TF_RETURN_IF_ERROR(g_->op_registry()->LookUpOpDef(node_def->op(), &op_def));
+ if (key.second >= op_def->output_arg_size()) {
+ // key's index out of bounds
+ missing_unused_input_map_keys_->push_back(key);
+ }
+ }
+ return Status::OK();
+}
+
void GraphConstructor::Undo() {
for (const auto& iter : gdef_nodes_) {
if (iter.second.node != nullptr) {
@@ -1153,7 +1174,7 @@ Status ConvertGraphDefToGraph(const GraphConstructorOptions& opts,
return GraphConstructor::Construct(
opts, gdef.node(), &gdef.versions(), &gdef.library(), g, &refiner,
/*return_tensors=*/nullptr, /*return_nodes=*/nullptr,
- /*unused_input_map_keys=*/nullptr);
+ /*missing_unused_input_map_keys=*/nullptr);
}
Status ConvertNodeDefsToGraph(const GraphConstructorOptions& opts,
@@ -1167,7 +1188,7 @@ Status ConvertNodeDefsToGraph(const GraphConstructorOptions& opts,
return GraphConstructor::Construct(opts, node_defs, nullptr, nullptr, g,
&refiner, /*return_tensors=*/nullptr,
/*return_nodes=*/nullptr,
- /*unused_input_map_keys=*/nullptr);
+ /*missing_unused_input_map_keys=*/nullptr);
}
Status ImportGraphDef(const ImportGraphDefOptions& opts, const GraphDef& gdef,
@@ -1196,7 +1217,7 @@ Status ImportGraphDef(const ImportGraphDefOptions& opts, const GraphDef& gdef,
if (results != nullptr) {
if (!results->return_tensors.empty() || !results->return_nodes.empty() ||
- !results->unused_input_map_keys.empty()) {
+ !results->missing_unused_input_map_keys.empty()) {
return errors::InvalidArgument(
"All fields in results argument to ImportGraphDef() must be empty.");
}
@@ -1239,7 +1260,7 @@ Status ImportGraphDef(const ImportGraphDefOptions& opts, const GraphDef& gdef,
return GraphConstructor::Construct(
opts, gdef.node(), &gdef.versions(), &gdef.library(), g, refiner,
&results->return_tensors, &results->return_nodes,
- &results->unused_input_map_keys);
+ &results->missing_unused_input_map_keys);
}
}