author     Skye Wanderman-Milne <skyewm@google.com>          2017-10-16 15:29:59 -0700
committer  TensorFlower Gardener <gardener@tensorflow.org>   2017-10-16 15:34:07 -0700
commit     dc442f4ce2d3b11b56721337fe2b9e2282be93be (patch)
tree       ee2d7796823a1430bc4c7a9f2dd577204aa28321
parent     7b6eec7e1175624458a48945bba3f6400e754d33 (diff)
Add return_nodes option to ImportGraphDef
This is similar to the return_tensors option. return_tensors cannot be used to fetch nodes with no outputs, so return_nodes is necessary. In addition, this change refactors the ImportGraphDef signature to return all optional return values in a single struct. This keeps the ImportGraphDef signature from getting too long and also makes the call sites simpler.
PiperOrigin-RevId: 172388270
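To illustrate the refactored signature, here is a minimal caller-side sketch (not taken from the commit; the node names "t1" and "init_op" are hypothetical and must exist in the imported GraphDef, and the includes assume the TensorFlow source layout of this era):

#include "tensorflow/core/framework/graph.pb.h"
#include "tensorflow/core/graph/graph.h"
#include "tensorflow/core/graph/graph_constructor.h"
#include "tensorflow/core/lib/core/status.h"

// Imports `gdef` into `g`, requesting one output tensor and one node back.
tensorflow::Status ImportWithReturnNodes(const tensorflow::GraphDef& gdef,
                                         tensorflow::Graph* g,
                                         tensorflow::ShapeRefiner* refiner) {
  tensorflow::ImportGraphDefOptions opts;
  // Request an output tensor by (node name, output index)...
  opts.return_tensors.push_back({"t1", 0});
  // ...and a node by name. Unlike return_tensors, this also works for ops
  // that have no outputs, which is the motivation for return_nodes.
  opts.return_nodes.push_back("init_op");

  // All optional outputs now come back through a single struct, whose fields
  // must be empty on entry.
  tensorflow::ImportGraphDefResults results;
  tensorflow::Status s =
      tensorflow::ImportGraphDef(opts, gdef, g, refiner, &results);
  if (!s.ok()) return s;

  // Each return_tensors entry is a (Node*, output index) pair; the index may
  // differ from the requested one if the tensor was remapped via input_map.
  tensorflow::Node* t1 = results.return_tensors[0].first;
  int t1_index = results.return_tensors[0].second;
  // Each return_nodes entry is the imported node itself.
  tensorflow::Node* init = results.return_nodes[0];
  (void)t1;
  (void)t1_index;
  (void)init;
  return tensorflow::Status::OK();
}

Keeping return_tensors, return_nodes, and unused_input_map_keys in one ImportGraphDefResults struct is what lets the ImportGraphDef signature stay short as further optional outputs are added; the updated call sites in the diff below follow this pattern.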
-rw-r--r--  tensorflow/c/c_api.cc                            |  18
-rw-r--r--  tensorflow/c/while_loop_test.cc                  |   6
-rw-r--r--  tensorflow/core/graph/graph_constructor.cc       |  73
-rw-r--r--  tensorflow/core/graph/graph_constructor.h        |  60
-rw-r--r--  tensorflow/core/graph/graph_constructor_test.cc  | 191
5 files changed, 240 insertions, 108 deletions
diff --git a/tensorflow/c/c_api.cc b/tensorflow/c/c_api.cc
index 334f867e47..79fbd8c90c 100644
--- a/tensorflow/c/c_api.cc
+++ b/tensorflow/c/c_api.cc
@@ -1854,18 +1854,18 @@ static void GraphImportGraphDefLocked(TF_Graph* graph, const GraphDef& def,
return;
}
const int last_node_id = graph->graph.num_node_ids();
- std::vector<std::pair<Node*, int>> return_outputs_vec;
- status->status = tensorflow::ImportGraphDef(
- opts->opts, def, &graph->graph, &graph->refiner, &return_outputs_vec);
+ tensorflow::ImportGraphDefResults results;
+ status->status = tensorflow::ImportGraphDef(opts->opts, def, &graph->graph,
+ &graph->refiner, &results);
if (!status->status.ok()) return;
for (int i = last_node_id; i < graph->graph.num_node_ids(); ++i) {
auto* node = graph->graph.FindNodeId(i);
if (node != nullptr) graph->name_map[node->name()] = node;
}
- DCHECK_EQ(return_outputs_vec.size(), num_return_outputs);
+ DCHECK_EQ(results.return_tensors.size(), num_return_outputs);
for (int i = 0; i < num_return_outputs; ++i) {
- return_outputs[i].oper = ToOperation(return_outputs_vec[i].first);
- return_outputs[i].index = return_outputs_vec[i].second;
+ return_outputs[i].oper = ToOperation(results.return_tensors[i].first);
+ return_outputs[i].index = results.return_tensors[i].second;
}
}
@@ -1945,11 +1945,11 @@ Status CopyGraph(Graph* src_graph, Graph* dst_graph,
}
// TOOD(skyewm): change to OutputTensor
- std::vector<std::pair<Node*, int>> return_tensors;
+ tensorflow::ImportGraphDefResults results;
TF_RETURN_IF_ERROR(
- ImportGraphDef(opts, gdef, dst_graph, dst_refiner, &return_tensors));
+ ImportGraphDef(opts, gdef, dst_graph, dst_refiner, &results));
- for (const auto& pair : return_tensors) {
+ for (const auto& pair : results.return_tensors) {
return_nodes->emplace_back(pair.first, pair.second);
}
return Status::OK();
diff --git a/tensorflow/c/while_loop_test.cc b/tensorflow/c/while_loop_test.cc
index 2423d83dda..d2d887f32c 100644
--- a/tensorflow/c/while_loop_test.cc
+++ b/tensorflow/c/while_loop_test.cc
@@ -318,7 +318,7 @@ TEST_F(CApiWhileLoopTest, InvalidCondOutputNode) {
// TODO(skyewm): this error message could be more informative. Add explicit
// checks for this case in the while loop implementation?
ExpectError(TF_INVALID_ARGUMENT,
- "Requested return node 'p0' not found in graph def");
+ "Requested return tensor 'p0:0' not found in graph def");
}
TEST_F(CApiWhileLoopTest, InvalidCondOutputIndex) {
@@ -358,7 +358,7 @@ TEST_F(CApiWhileLoopTest, InvalidBodyOutputNode) {
// TODO(skyewm): this error message could be more informative. Add explicit
// checks for this case in the while loop implementation?
ExpectError(TF_INVALID_ARGUMENT,
- "Requested return node 'p0' not found in graph def");
+ "Requested return tensor 'p0:0' not found in graph def");
}
// TODO(skyewm): enable this when it works (currently segfaults!)
@@ -389,7 +389,7 @@ TEST_F(CApiWhileLoopTest, WrongGraph) {
params_->body_outputs[0] = inputs_[0];
// TODO(skyewm): improve error message
ExpectError(TF_INVALID_ARGUMENT,
- "Requested return node 'p0' not found in graph def");
+ "Requested return tensor 'p0:0' not found in graph def");
}
TEST_F(CApiWhileLoopTest, BadTypes) {
diff --git a/tensorflow/core/graph/graph_constructor.cc b/tensorflow/core/graph/graph_constructor.cc
index 15f7b9fe8c..92b4843221 100644
--- a/tensorflow/core/graph/graph_constructor.cc
+++ b/tensorflow/core/graph/graph_constructor.cc
@@ -79,6 +79,7 @@ class GraphConstructor {
skip_mapped_nodes(in.skip_mapped_nodes),
control_dependencies(in.control_dependencies),
return_tensors(in.return_tensors),
+ return_nodes(in.return_nodes),
importing(true) {}
bool allow_internal_ops;
@@ -89,6 +90,7 @@ class GraphConstructor {
bool skip_mapped_nodes;
std::vector<string> control_dependencies;
std::vector<TensorId> return_tensors;
+ std::vector<StringPiece> return_nodes;
// TODO(ashankar): This bool exists to separate out functionality required
// to make ImportGraphDef a close equivalent of Python's import_graph_def
@@ -109,6 +111,7 @@ class GraphConstructor {
const FunctionDefLibrary* library, Graph* g,
ShapeRefiner* refiner,
std::vector<std::pair<Node*, int>>* return_tensors,
+ std::vector<Node*>* return_nodes,
std::vector<TensorId>* unused_input_map_keys) {
if (versions) {
TF_RETURN_IF_ERROR(CheckVersions(*versions, TF_GRAPH_DEF_VERSION,
@@ -116,7 +119,7 @@ class GraphConstructor {
"GraphDef", "graph"));
}
GraphConstructor c(opts, node_defs, versions, library, g, refiner,
- return_tensors, unused_input_map_keys);
+ return_tensors, return_nodes, unused_input_map_keys);
const Status s = c.TryImport();
if (!s.ok()) c.Undo();
return s;
@@ -128,6 +131,7 @@ class GraphConstructor {
const FunctionDefLibrary* library, Graph* g,
ShapeRefiner* refiner,
std::vector<std::pair<Node*, int>>* return_tensors,
+ std::vector<Node*>* return_nodes,
std::vector<TensorId>* unused_input_map_keys)
: opts_(opts),
node_defs_(node_defs),
@@ -137,6 +141,7 @@ class GraphConstructor {
original_versions_(g->versions()),
refiner_(refiner),
return_tensors_(return_tensors),
+ return_nodes_(return_nodes),
unused_input_map_keys_(unused_input_map_keys) {}
Status TryImport() {
@@ -148,6 +153,7 @@ class GraphConstructor {
TF_RETURN_IF_ERROR(AddBackEdges());
TF_RETURN_IF_ERROR(UpdateVersionDef());
TF_RETURN_IF_ERROR(PopulateReturnTensors());
+ TF_RETURN_IF_ERROR(PopulateReturnNodes());
FixupSourceAndSinkEdges(g_);
return Status::OK();
}
@@ -160,6 +166,7 @@ class GraphConstructor {
Status AddBackEdges();
Status UpdateVersionDef();
Status PopulateReturnTensors();
+ Status PopulateReturnNodes();
void Undo();
@@ -197,6 +204,9 @@ class GraphConstructor {
std::vector<std::pair<Node*, int>>* return_tensors_;
// May be null. Not owned.
+ std::vector<Node*>* return_nodes_;
+
+ // May be null. Not owned.
std::vector<TensorId>* unused_input_map_keys_;
// Intermediate datastructure used to populate `unused_input_map_keys_`.
@@ -913,7 +923,8 @@ Status GraphConstructor::PopulateReturnTensors() {
// Locate id in imported nodes
auto iter = gdef_nodes_.find(id.first);
if (iter == gdef_nodes_.end()) {
- return errors::InvalidArgument("Requested return node '", id.first,
+ return errors::InvalidArgument("Requested return tensor '",
+ id.ToString(),
"' not found in graph def");
}
int num_outputs = iter->second.node->num_outputs();
@@ -935,6 +946,19 @@ Status GraphConstructor::PopulateReturnTensors() {
return Status::OK();
}
+Status GraphConstructor::PopulateReturnNodes() {
+ if (opts_.return_nodes.empty()) return Status::OK();
+ for (StringPiece name : opts_.return_nodes) {
+ auto iter = gdef_nodes_.find(name);
+ if (iter == gdef_nodes_.end()) {
+ return errors::InvalidArgument("Requested return node '", name,
+ "' not found in graph def");
+ }
+ return_nodes_->push_back(iter->second.node);
+ }
+ return Status::OK();
+}
+
void GraphConstructor::Undo() {
for (const auto& iter : gdef_nodes_) {
if (iter.second.node != nullptr) {
@@ -965,7 +989,8 @@ Status ConvertGraphDefToGraph(const GraphConstructorOptions& opts,
ShapeRefiner refiner(gdef.versions().producer(), g->op_registry());
return GraphConstructor::Construct(
opts, gdef.node(), &gdef.versions(), &gdef.library(), g, &refiner,
- /*return_tensors=*/nullptr, /*unused_input_map_keys=*/nullptr);
+ /*return_tensors=*/nullptr, /*return_nodes=*/nullptr,
+ /*unused_input_map_keys=*/nullptr);
}
Status ConvertNodeDefsToGraph(const GraphConstructorOptions& opts,
@@ -978,31 +1003,40 @@ Status ConvertNodeDefsToGraph(const GraphConstructorOptions& opts,
}
return GraphConstructor::Construct(opts, node_defs, nullptr, nullptr, g,
&refiner, /*return_tensors=*/nullptr,
+ /*return_nodes=*/nullptr,
/*unused_input_map_keys=*/nullptr);
}
Status ImportGraphDef(const ImportGraphDefOptions& opts, const GraphDef& gdef,
Graph* g, ShapeRefiner* refiner,
- std::vector<std::pair<Node*, int>>* return_tensors,
- std::vector<TensorId>* unused_input_map_keys) {
+ ImportGraphDefResults* results) {
if (!opts.return_tensors.empty()) {
- if (return_tensors == nullptr) {
+ if (results == nullptr) {
return errors::InvalidArgument(
- "return_tensors argument to ImportGraphDef() must be non-null if "
+ "results argument to ImportGraphDef() must be non-null if "
"opts.return_tensors is non-empty");
}
- if (!return_tensors->empty()) {
+ }
+
+ if (!opts.return_nodes.empty()) {
+ if (opts.skip_mapped_nodes) {
+ return errors::InvalidArgument(
+ "Requesting return_nodes with skip_mapped_nodes set is not currently "
+ "supported");
+ }
+ if (results == nullptr) {
return errors::InvalidArgument(
- "return_tensors argument to ImportGraphDef() should be empty (has "
- "size ",
- return_tensors->size(), ")");
+ "results argument to ImportGraphDef() must be non-null if "
+ "opts.return_nodes is non-empty");
}
}
- if (unused_input_map_keys != nullptr && !unused_input_map_keys->empty()) {
- return errors::InvalidArgument(
- "If non-null, unused_input_map_keys argument to ImportGraphDef() should"
- " be empty (has size ",
- unused_input_map_keys->size(), ")");
+
+ if (results != nullptr) {
+ if (!results->return_tensors.empty() || !results->return_nodes.empty() ||
+ !results->unused_input_map_keys.empty()) {
+ return errors::InvalidArgument(
+ "All fields in results argument to ImportGraphDef() must be empty.");
+ }
}
ShapeRefiner default_refiner(gdef.versions().producer(), g->op_registry());
@@ -1034,9 +1068,10 @@ Status ImportGraphDef(const ImportGraphDefOptions& opts, const GraphDef& gdef,
refiner->set_graph_def_version(
std::min(refiner->graph_def_version(), gdef.versions().producer()));
- return GraphConstructor::Construct(opts, gdef.node(), &gdef.versions(),
- &gdef.library(), g, refiner,
- return_tensors, unused_input_map_keys);
+ return GraphConstructor::Construct(
+ opts, gdef.node(), &gdef.versions(), &gdef.library(), g, refiner,
+ &results->return_tensors, &results->return_nodes,
+ &results->unused_input_map_keys);
}
void CopyGraph(const Graph& src, Graph* dest) {
diff --git a/tensorflow/core/graph/graph_constructor.h b/tensorflow/core/graph/graph_constructor.h
index a8f9f2b245..6cd9347d96 100644
--- a/tensorflow/core/graph/graph_constructor.h
+++ b/tensorflow/core/graph/graph_constructor.h
@@ -72,8 +72,6 @@ struct ImportGraphDefOptions {
// used to create the existing nodes referenced in `input_map`.
// TODO(skyewm): can we remove this requirement? How do we access the original
// shape refiner?
- //
- // TODO(skyewm): add functionality to retrieve unused `input_map` keys
std::map<TensorId, TensorId> input_map;
// If true, nodes that will have all output edges removed because of
@@ -88,10 +86,10 @@ struct ImportGraphDefOptions {
// other nodes in `gdef`.
std::vector<string> control_dependencies;
- // Tensors in `gdef` that will be returned via the `return_tensors` output
- // parameter of `ImportGraphDef()`. If this list is non-empty, the caller must
- // pass an empty vector to `ImportGraphDef()`. The vector will be populated
- // with the imported nodes in `g`.
+ // Tensors in `gdef` that will be returned via the ImportGraphDefResults
+ // output parameter of `ImportGraphDef()`. If this list is non-empty, the
+ // caller must pass a results object to `ImportGraphDef()`. The
+ // `return_tensors` field will be populated with the imported nodes in `g`.
//
// Entries should not include `prefix`, i.e., each TensorId's name should be
// the name as it originally appears in `gdef`.
@@ -100,12 +98,43 @@ struct ImportGraphDefOptions {
// corresponding existing tensor in `g` will be returned.
std::vector<TensorId> return_tensors;
+ // The names of nodes in `gdef` that will be returned via the
+ // ImportGraphDefResults output parameter of `ImportGraphDef()`. If this list
+ // is non-empty, the caller must pass a results object to
+ // `ImportGraphDef()`. The `return_nodes` field will be populated with the
+ // imported nodes in `g`.
+ //
+ // Entries should not include `prefix`, i.e., each node's name should be the
+ // name as it originally appears in `gdef`.
+ //
+ // Unlike `return_tensors`, `input_map` has no effect on the nodes
+ // returned. `return_nodes` must be empty if `skip_mapped_nodes` is true.
+ // TODO(skyewm): make this work with `skip_mapped_nodes` if there's a need.
+ std::vector<StringPiece> return_nodes;
+
// TODO(ashankar): Enable handling of GraphDefs produced by newer binaries
// with ops that are not defined in the binary calling ImportGraphDef.
// Similar to the producer_op_list argument to import_graph_def in the
// python API.
};
+// Optional results that may be returned by ImportGraphDef.
+struct ImportGraphDefResults {
+ // The requested tensors associated with
+ // ImportGraphDefOptions::return_tensors. Note that the index may be different
+ // than the requested index if the returned tensor has been remapped according
+ // to `input_map`.
+ typedef int Index;
+ std::vector<std::pair<Node*, Index>> return_tensors;
+
+ // The requested nodes associated with ImportGraphDefOptions::return_nodes.
+ std::vector<Node*> return_nodes;
+
+ // Keys in ImportGraphDefOptions::input_map that weren't used as an input to
+ // any node in`gdef`.
+ std::vector<TensorId> unused_input_map_keys;
+};
+
// Adds the graph in GraphDef `gdef` into an existing Graph `*g`.
//
// On error, returns non-OK and leaves `*g` unmodified.
@@ -115,21 +144,16 @@ struct ImportGraphDefOptions {
// allows the caller to validate shapes of those nodes (since
// ShapeRefiner::AddNode must be called in topological order).
//
-// Each `return_tensors` entry is the requested node and output index. The index
-// is included in case the returned tensor has been remapped according to
-// `input_map`.
-//
-// If `unused_input_map_keys` is non-null, it should be empty and will be
-// populated with any keys in `opts.input_map` that aren't used as an input to
-// any node in `gdef`.
+// `results` must be non-null if `opts.return_tensors` or `opts.result_nodes` is
+// non-empty. It can also be set to fetch the unused input map keys. If it's
+// non-null, all the vector fields must be empty.
//
// TODO(ashankar): Push this mechanism and get rid of Session::Extend()
// as a means of enhancing an existing Graph.
-extern Status ImportGraphDef(
- const ImportGraphDefOptions& opts, const GraphDef& gdef, Graph* g,
- ShapeRefiner* refiner,
- std::vector<std::pair<Node*, int>>* return_tensors = nullptr,
- std::vector<TensorId>* unused_input_map_keys = nullptr);
+extern Status ImportGraphDef(const ImportGraphDefOptions& opts,
+ const GraphDef& gdef, Graph* g,
+ ShapeRefiner* refiner,
+ ImportGraphDefResults* results = nullptr);
// Make a copy of "src" into "*dest".
//
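A minimal sketch of the interaction documented in the header comments above, using hypothetical node names ("new_input" imported from gdef, "existing_input" already present in g): input_map remaps the requested return tensor, return_nodes is unaffected by the mapping, and unused input_map keys are reported through the results struct rather than a separate output parameter.

#include "tensorflow/core/framework/graph.pb.h"
#include "tensorflow/core/graph/graph.h"
#include "tensorflow/core/graph/graph_constructor.h"
#include "tensorflow/core/graph/tensor_id.h"
#include "tensorflow/core/lib/core/status.h"
#include "tensorflow/core/platform/logging.h"

// Imports `gdef` into `g` with an input_map entry, showing which return
// options are affected by the remapping.
tensorflow::Status ImportWithInputMap(const tensorflow::GraphDef& gdef,
                                      tensorflow::Graph* g,
                                      tensorflow::ShapeRefiner* refiner) {
  tensorflow::ImportGraphDefOptions opts;
  // Remap the imported tensor "new_input:0" to "existing_input:0".
  // "existing_input" must already be a node in `g`.
  opts.input_map[{"new_input", 0}] = {"existing_input", 0};

  // This return tensor follows the mapping: the returned (Node*, index) pair
  // refers to the existing tensor in `g`, not the newly imported one.
  opts.return_tensors.push_back({"new_input", 0});

  // return_nodes ignores input_map: this always yields the imported node.
  opts.return_nodes.push_back("new_input");

  tensorflow::ImportGraphDefResults results;
  tensorflow::Status s =
      tensorflow::ImportGraphDef(opts, gdef, g, refiner, &results);
  if (!s.ok()) return s;

  // input_map keys that never matched an input in `gdef` are reported via
  // the results struct.
  for (const tensorflow::TensorId& unused : results.unused_input_map_keys) {
    VLOG(1) << "Unused input_map key: " << unused.ToString();
  }
  return tensorflow::Status::OK();
}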
diff --git a/tensorflow/core/graph/graph_constructor_test.cc b/tensorflow/core/graph/graph_constructor_test.cc
index f88d707ec5..5242c56ce6 100644
--- a/tensorflow/core/graph/graph_constructor_test.cc
+++ b/tensorflow/core/graph/graph_constructor_test.cc
@@ -71,14 +71,12 @@ class GraphConstructorTest : public ::testing::Test {
void ExpectError(const string& gdef_ascii, const ImportGraphDefOptions& opts,
const std::vector<string>& expected_error_strs,
ShapeRefiner* refiner = nullptr,
- std::vector<std::pair<Node*, int>>* return_tensors = nullptr,
- std::vector<TensorId>* unused_input_map_keys = nullptr) {
+ ImportGraphDefResults* results = nullptr) {
// Used to verify that errors don't change graph
const string original_graph_description = GraphDebugString();
Convert(gdef_ascii);
- Status status = ImportGraphDef(opts, gdef_, &graph_, refiner,
- return_tensors, unused_input_map_keys);
+ Status status = ImportGraphDef(opts, gdef_, &graph_, refiner, results);
EXPECT_FALSE(status.ok());
for (const string& error : expected_error_strs) {
@@ -97,11 +95,9 @@ class GraphConstructorTest : public ::testing::Test {
void ExpectOK(const string& gdef_ascii, const ImportGraphDefOptions& opts,
ShapeRefiner* refiner = nullptr,
- std::vector<std::pair<Node*, int>>* return_tensors = nullptr,
- std::vector<TensorId>* unused_input_map_keys = nullptr) {
+ ImportGraphDefResults* results = nullptr) {
Convert(gdef_ascii);
- Status s = ImportGraphDef(opts, gdef_, &graph_, refiner, return_tensors,
- unused_input_map_keys);
+ Status s = ImportGraphDef(opts, gdef_, &graph_, refiner, results);
EXPECT_EQ(Status::OK(), s) << s;
}
@@ -1440,26 +1436,25 @@ TEST_F(GraphConstructorTest, ImportGraphDef_InputMapDuplicateNodeNames) {
TEST_F(GraphConstructorTest, ImportGraphDef_InputMapUnusedKeys) {
ShapeRefiner refiner(TF_GRAPH_DEF_VERSION, graph_.op_registry());
- std::vector<TensorId> unused_input_map_keys;
-
// No input map
ImportGraphDefOptions opts;
+ ImportGraphDefResults results;
ExpectOK(
"node { name: 'W1' op: 'TestParams' }"
"node { name: 'input' op: 'TestInput' }",
- opts, &refiner, nullptr, &unused_input_map_keys);
- EXPECT_TRUE(unused_input_map_keys.empty());
+ opts, &refiner, &results);
+ EXPECT_TRUE(results.unused_input_map_keys.empty());
// Non-empty unused_input_map_keys
- unused_input_map_keys.push_back(TensorId());
- ExpectError("node { name: 'W2' op: 'TestParams' }", opts,
- {"If non-null, unused_input_map_keys argument to ImportGraphDef()"
- " should be empty (has size 1)"},
- &refiner, nullptr, &unused_input_map_keys);
+ results.unused_input_map_keys.push_back(TensorId());
+ ExpectError(
+ "node { name: 'W2' op: 'TestParams' }", opts,
+ {"All fields in results argument to ImportGraphDef() must be empty."},
+ &refiner, &results);
// Input map with some used, some unused keys
const int kControlSlot = Graph::kControlSlot;
- unused_input_map_keys.clear();
+ results.unused_input_map_keys.clear();
opts.input_map[TensorId("W2", kControlSlot)] = TensorId("W1", kControlSlot);
opts.input_map[TensorId("new_input", 0)] = TensorId("input", 0);
opts.input_map[TensorId("new_input", 1)] = TensorId("input", 0);
@@ -1473,11 +1468,11 @@ TEST_F(GraphConstructorTest, ImportGraphDef_InputMapUnusedKeys) {
node { name: 't1' op: 'TestMul' input: [ 'new_input:0', 'new_input:1' ] }
node { name: 't2' op: 'TestMul' input: [ 't1:0', 't1:0' ] }
)EOF",
- opts, &refiner, nullptr, &unused_input_map_keys);
+ opts, &refiner, &results);
std::vector<TensorId> expected_unused_keys = {
TensorId("new_input", kControlSlot), TensorId("t1", 1)};
- EXPECT_EQ(unused_input_map_keys, expected_unused_keys);
+ EXPECT_EQ(results.unused_input_map_keys, expected_unused_keys);
}
TEST_F(GraphConstructorTest, ImportGraphDef_SkipMappedNodes_FullyMapped) {
@@ -1567,11 +1562,11 @@ TEST_F(GraphConstructorTest, ImportGraphDef_ReturnTensors) {
opts.return_tensors.push_back({"input", 1});
opts.return_tensors.push_back({"t1", 0});
opts.return_tensors.push_back({"input", 0});
- std::vector<std::pair<Node*, int>> return_tensors;
+ ImportGraphDefResults results;
ExpectOK(
"node { name: 'input' op: 'TestInput' }"
"node { name: 't1' op: 'TestMul' input: ['input:0', 'input:1'] }",
- opts, &refiner, &return_tensors);
+ opts, &refiner, &results);
// Sanity checks
EXPECT_TRUE(HasNode("input"));
@@ -1580,74 +1575,70 @@ TEST_F(GraphConstructorTest, ImportGraphDef_ReturnTensors) {
EXPECT_TRUE(HasEdge("input", 1, "t1", 1));
// Check return tensors
- ASSERT_EQ(return_tensors.size(), 3);
- EXPECT_EQ(return_tensors[0].first->name(), "input");
- EXPECT_EQ(return_tensors[0].second, 1);
- EXPECT_EQ(return_tensors[1].first->name(), "t1");
- EXPECT_EQ(return_tensors[1].second, 0);
- EXPECT_EQ(return_tensors[2].first->name(), "input");
- EXPECT_EQ(return_tensors[2].second, 0);
+ ASSERT_EQ(results.return_tensors.size(), 3);
+ EXPECT_EQ(results.return_tensors[0].first->name(), "input");
+ EXPECT_EQ(results.return_tensors[0].second, 1);
+ EXPECT_EQ(results.return_tensors[1].first->name(), "t1");
+ EXPECT_EQ(results.return_tensors[1].second, 0);
+ EXPECT_EQ(results.return_tensors[2].first->name(), "input");
+ EXPECT_EQ(results.return_tensors[2].second, 0);
// Test using prefix and returning element from input_map
opts.return_tensors.clear();
- return_tensors.clear();
+ results = ImportGraphDefResults();
opts.prefix = "import";
opts.input_map[{"new_input", 1}] = {"input", 0};
opts.return_tensors.push_back({"new_input", 0});
opts.return_tensors.push_back({"new_input", 1});
ExpectOK("node { name: 'new_input' op: 'TestInput' }", opts, &refiner,
- &return_tensors);
+ &results);
EXPECT_TRUE(HasNode("import/new_input"));
- ASSERT_EQ(return_tensors.size(), 2);
- EXPECT_EQ(return_tensors[0].first->name(), "import/new_input");
- EXPECT_EQ(return_tensors[0].second, 0);
- EXPECT_EQ(return_tensors[1].first->name(), "input");
- EXPECT_EQ(return_tensors[1].second, 0);
+ ASSERT_EQ(results.return_tensors.size(), 2);
+ EXPECT_EQ(results.return_tensors[0].first->name(), "import/new_input");
+ EXPECT_EQ(results.return_tensors[0].second, 0);
+ EXPECT_EQ(results.return_tensors[1].first->name(), "input");
+ EXPECT_EQ(results.return_tensors[1].second, 0);
// Test returning node remapped to source node
opts.prefix.clear();
opts.input_map.clear();
opts.return_tensors.clear();
- return_tensors.clear();
+ results = ImportGraphDefResults();
opts.input_map[{"new_input", 0}] = {"_SOURCE", 0};
opts.return_tensors.push_back({"new_input", 0});
ExpectOK("node { name: 'new_input' op: 'TestInput' }", opts, &refiner,
- &return_tensors);
+ &results);
EXPECT_TRUE(HasNode("new_input"));
- ASSERT_EQ(return_tensors.size(), 1);
- EXPECT_EQ(return_tensors[0].first->name(), "_SOURCE");
- EXPECT_EQ(return_tensors[0].second, 0);
+ ASSERT_EQ(results.return_tensors.size(), 1);
+ EXPECT_EQ(results.return_tensors[0].first->name(), "_SOURCE");
+ EXPECT_EQ(results.return_tensors[0].second, 0);
}
TEST_F(GraphConstructorTest, ImportGraphDef_ReturnTensorsErrors) {
- // Passing in return_tensors with empty opts.return_tensors is OK
+ // Null results with non-empty opts.return_tensors
ImportGraphDefOptions opts;
- std::vector<std::pair<Node*, int>> return_tensors;
- ExpectOK("node { name: 'input' op: 'TestInput' }", opts, nullptr,
- &return_tensors);
-
- // Null return_tensors with non-empty opts.return_tensors
opts.return_tensors.push_back({"new_input", 0});
ExpectError("node { name: 'new_input' op: 'TestInput' }", opts,
- {"return_tensors argument to ImportGraphDef() must be non-null "
- "if opts.return_tensors is non-empty"});
+ {"results argument to ImportGraphDef() must be non-null if "
+ "opts.return_tensors is non-empty"});
- // Non-empty return_tensors
- return_tensors.push_back({nullptr, 0});
- ExpectError("node { name: 'new_input' op: 'TestInput' }", opts,
- {"return_tensors argument to ImportGraphDef() should be empty "
- "(has size 1)"},
- nullptr, &return_tensors);
+ // Non-empty results.return_tensors
+ ImportGraphDefResults results;
+ results.return_tensors.push_back({nullptr, 0});
+ ExpectError(
+ "node { name: 'new_input' op: 'TestInput' }", opts,
+ {"All fields in results argument to ImportGraphDef() must be empty."},
+ nullptr, &results);
// Requesting tensor that isn't in graph def
- return_tensors.clear();
+ results.return_tensors.clear();
ExpectError("node { name: 'W1' op: 'TestParams' }", opts,
- {"Requested return node 'new_input' not found in graph def"},
- nullptr, &return_tensors);
+ {"Requested return tensor 'new_input:0' not found in graph def"},
+ nullptr, &results);
// Requesting invalid node index
opts.return_tensors.clear();
@@ -1655,7 +1646,89 @@ TEST_F(GraphConstructorTest, ImportGraphDef_ReturnTensorsErrors) {
ExpectError("node { name: 'new_input' op: 'TestInput' }", opts,
{"Invalid return output 2 of node 'new_input', which has 2 "
"output(s)"},
- nullptr, &return_tensors);
+ nullptr, &results);
+}
+
+TEST_F(GraphConstructorTest, ImportGraphDef_ReturnNodes) {
+ ShapeRefiner refiner(TF_GRAPH_DEF_VERSION, graph_.op_registry());
+
+ ImportGraphDefOptions opts;
+ opts.return_nodes.push_back("input");
+ opts.return_nodes.push_back("t1");
+ ImportGraphDefResults results;
+ ExpectOK(
+ "node { name: 'input' op: 'TestInput' }"
+ "node { name: 'input2' op: 'TestInput' }"
+ "node { name: 't1' op: 'TestMul' input: ['input:0', 'input2:1'] }",
+ opts, &refiner, &results);
+
+ // Sanity checks
+ EXPECT_TRUE(HasNode("input"));
+ EXPECT_TRUE(HasNode("input2"));
+ EXPECT_TRUE(HasNode("t1"));
+ EXPECT_TRUE(HasEdge("input", 0, "t1", 0));
+ EXPECT_TRUE(HasEdge("input2", 1, "t1", 1));
+
+ // Check return tensors
+ ASSERT_EQ(results.return_nodes.size(), 2);
+ EXPECT_EQ(results.return_tensors.size(), 0);
+ EXPECT_EQ(results.unused_input_map_keys.size(), 0);
+ EXPECT_EQ(results.return_nodes[0]->name(), "input");
+ EXPECT_EQ(results.return_nodes[1]->name(), "t1");
+
+ // Test using prefix
+ opts = ImportGraphDefOptions();
+ results = ImportGraphDefResults();
+ opts.prefix = "import";
+ opts.return_nodes.push_back("input");
+ ExpectOK("node { name: 'input' op: 'TestInput' }", opts, &refiner, &results);
+
+ EXPECT_TRUE(HasNode("import/input"));
+
+ ASSERT_EQ(results.return_nodes.size(), 1);
+ EXPECT_EQ(results.return_nodes[0]->name(), "import/input");
+
+ // Test that input_map has no effect
+ opts = ImportGraphDefOptions();
+ results = ImportGraphDefResults();
+ opts.input_map[{"new_input", 0}] = {"input", 0};
+ opts.return_nodes.push_back("new_input");
+ ExpectOK("node { name: 'new_input' op: 'TestInput' }", opts, &refiner,
+ &results);
+
+ EXPECT_TRUE(HasNode("new_input"));
+
+ ASSERT_EQ(results.return_nodes.size(), 1);
+ EXPECT_EQ(results.return_nodes[0]->name(), "new_input");
+}
+
+TEST_F(GraphConstructorTest, ImportGraphDef_ReturnNodesErrors) {
+ // Null results with non-empty opts.return_nodes
+ ImportGraphDefOptions opts;
+ opts.return_nodes.push_back("new_input");
+ ExpectError("node { name: 'new_input' op: 'TestInput' }", opts,
+ {"results argument to ImportGraphDef() must be non-null if "
+ "opts.return_nodes is non-empty"});
+
+ // Non-empty results.return_nodes
+ ImportGraphDefResults results;
+ results.return_nodes.push_back(nullptr);
+ ExpectError(
+ "node { name: 'new_input' op: 'TestInput' }", opts,
+ {"All fields in results argument to ImportGraphDef() must be empty."},
+ nullptr, &results);
+
+ // Requesting node that isn't in graph def
+ results.return_nodes.clear();
+ ExpectError("node { name: 'W1' op: 'TestParams' }", opts,
+ {"Requested return node 'new_input' not found in graph def"},
+ nullptr, &results);
+
+ // Requesting return_nodes with skip_mapped_nodes not yet implemented
+ opts.skip_mapped_nodes = true;
+ ExpectError("node { name: 'new_input' op: 'TestInput' }", opts,
+ {"Requesting return_nodes with skip_mapped_nodes set is not "
+ "currently supported"});
}
TEST_F(GraphConstructorTest, ImportGraphDef_WithCycle) {