diff options
author | 2017-10-16 15:29:59 -0700 | |
---|---|---|
committer | 2017-10-16 15:34:07 -0700 | |
commit | dc442f4ce2d3b11b56721337fe2b9e2282be93be (patch) | |
tree | ee2d7796823a1430bc4c7a9f2dd577204aa28321 | |
parent | 7b6eec7e1175624458a48945bba3f6400e754d33 (diff) |
Add return_nodes option to ImportGraphDef
The is similar to the return_tensors option. return_tensors cannot be
used to fetch nodes with no outputs, so return_nodes is necessary.
In addition, this change also refactors the ImportGraphDef signature
to return all optional return values in a single struct. This is to
keep the ImportGraphDef signature from getting too long, and also
makes the call sites simpler.
PiperOrigin-RevId: 172388270
-rw-r--r-- | tensorflow/c/c_api.cc | 18 | ||||
-rw-r--r-- | tensorflow/c/while_loop_test.cc | 6 | ||||
-rw-r--r-- | tensorflow/core/graph/graph_constructor.cc | 73 | ||||
-rw-r--r-- | tensorflow/core/graph/graph_constructor.h | 60 | ||||
-rw-r--r-- | tensorflow/core/graph/graph_constructor_test.cc | 191 |
5 files changed, 240 insertions, 108 deletions
diff --git a/tensorflow/c/c_api.cc b/tensorflow/c/c_api.cc index 334f867e47..79fbd8c90c 100644 --- a/tensorflow/c/c_api.cc +++ b/tensorflow/c/c_api.cc @@ -1854,18 +1854,18 @@ static void GraphImportGraphDefLocked(TF_Graph* graph, const GraphDef& def, return; } const int last_node_id = graph->graph.num_node_ids(); - std::vector<std::pair<Node*, int>> return_outputs_vec; - status->status = tensorflow::ImportGraphDef( - opts->opts, def, &graph->graph, &graph->refiner, &return_outputs_vec); + tensorflow::ImportGraphDefResults results; + status->status = tensorflow::ImportGraphDef(opts->opts, def, &graph->graph, + &graph->refiner, &results); if (!status->status.ok()) return; for (int i = last_node_id; i < graph->graph.num_node_ids(); ++i) { auto* node = graph->graph.FindNodeId(i); if (node != nullptr) graph->name_map[node->name()] = node; } - DCHECK_EQ(return_outputs_vec.size(), num_return_outputs); + DCHECK_EQ(results.return_tensors.size(), num_return_outputs); for (int i = 0; i < num_return_outputs; ++i) { - return_outputs[i].oper = ToOperation(return_outputs_vec[i].first); - return_outputs[i].index = return_outputs_vec[i].second; + return_outputs[i].oper = ToOperation(results.return_tensors[i].first); + return_outputs[i].index = results.return_tensors[i].second; } } @@ -1945,11 +1945,11 @@ Status CopyGraph(Graph* src_graph, Graph* dst_graph, } // TOOD(skyewm): change to OutputTensor - std::vector<std::pair<Node*, int>> return_tensors; + tensorflow::ImportGraphDefResults results; TF_RETURN_IF_ERROR( - ImportGraphDef(opts, gdef, dst_graph, dst_refiner, &return_tensors)); + ImportGraphDef(opts, gdef, dst_graph, dst_refiner, &results)); - for (const auto& pair : return_tensors) { + for (const auto& pair : results.return_tensors) { return_nodes->emplace_back(pair.first, pair.second); } return Status::OK(); diff --git a/tensorflow/c/while_loop_test.cc b/tensorflow/c/while_loop_test.cc index 2423d83dda..d2d887f32c 100644 --- a/tensorflow/c/while_loop_test.cc +++ b/tensorflow/c/while_loop_test.cc @@ -318,7 +318,7 @@ TEST_F(CApiWhileLoopTest, InvalidCondOutputNode) { // TODO(skyewm): this error message could be more informative. Add explicit // checks for this case in the while loop implementation? ExpectError(TF_INVALID_ARGUMENT, - "Requested return node 'p0' not found in graph def"); + "Requested return tensor 'p0:0' not found in graph def"); } TEST_F(CApiWhileLoopTest, InvalidCondOutputIndex) { @@ -358,7 +358,7 @@ TEST_F(CApiWhileLoopTest, InvalidBodyOutputNode) { // TODO(skyewm): this error message could be more informative. Add explicit // checks for this case in the while loop implementation? ExpectError(TF_INVALID_ARGUMENT, - "Requested return node 'p0' not found in graph def"); + "Requested return tensor 'p0:0' not found in graph def"); } // TODO(skyewm): enable this when it works (currently segfaults!) @@ -389,7 +389,7 @@ TEST_F(CApiWhileLoopTest, WrongGraph) { params_->body_outputs[0] = inputs_[0]; // TODO(skyewm): improve error message ExpectError(TF_INVALID_ARGUMENT, - "Requested return node 'p0' not found in graph def"); + "Requested return tensor 'p0:0' not found in graph def"); } TEST_F(CApiWhileLoopTest, BadTypes) { diff --git a/tensorflow/core/graph/graph_constructor.cc b/tensorflow/core/graph/graph_constructor.cc index 15f7b9fe8c..92b4843221 100644 --- a/tensorflow/core/graph/graph_constructor.cc +++ b/tensorflow/core/graph/graph_constructor.cc @@ -79,6 +79,7 @@ class GraphConstructor { skip_mapped_nodes(in.skip_mapped_nodes), control_dependencies(in.control_dependencies), return_tensors(in.return_tensors), + return_nodes(in.return_nodes), importing(true) {} bool allow_internal_ops; @@ -89,6 +90,7 @@ class GraphConstructor { bool skip_mapped_nodes; std::vector<string> control_dependencies; std::vector<TensorId> return_tensors; + std::vector<StringPiece> return_nodes; // TODO(ashankar): This bool exists to separate out functionality required // to make ImportGraphDef a close equivalent of Python's import_graph_def @@ -109,6 +111,7 @@ class GraphConstructor { const FunctionDefLibrary* library, Graph* g, ShapeRefiner* refiner, std::vector<std::pair<Node*, int>>* return_tensors, + std::vector<Node*>* return_nodes, std::vector<TensorId>* unused_input_map_keys) { if (versions) { TF_RETURN_IF_ERROR(CheckVersions(*versions, TF_GRAPH_DEF_VERSION, @@ -116,7 +119,7 @@ class GraphConstructor { "GraphDef", "graph")); } GraphConstructor c(opts, node_defs, versions, library, g, refiner, - return_tensors, unused_input_map_keys); + return_tensors, return_nodes, unused_input_map_keys); const Status s = c.TryImport(); if (!s.ok()) c.Undo(); return s; @@ -128,6 +131,7 @@ class GraphConstructor { const FunctionDefLibrary* library, Graph* g, ShapeRefiner* refiner, std::vector<std::pair<Node*, int>>* return_tensors, + std::vector<Node*>* return_nodes, std::vector<TensorId>* unused_input_map_keys) : opts_(opts), node_defs_(node_defs), @@ -137,6 +141,7 @@ class GraphConstructor { original_versions_(g->versions()), refiner_(refiner), return_tensors_(return_tensors), + return_nodes_(return_nodes), unused_input_map_keys_(unused_input_map_keys) {} Status TryImport() { @@ -148,6 +153,7 @@ class GraphConstructor { TF_RETURN_IF_ERROR(AddBackEdges()); TF_RETURN_IF_ERROR(UpdateVersionDef()); TF_RETURN_IF_ERROR(PopulateReturnTensors()); + TF_RETURN_IF_ERROR(PopulateReturnNodes()); FixupSourceAndSinkEdges(g_); return Status::OK(); } @@ -160,6 +166,7 @@ class GraphConstructor { Status AddBackEdges(); Status UpdateVersionDef(); Status PopulateReturnTensors(); + Status PopulateReturnNodes(); void Undo(); @@ -197,6 +204,9 @@ class GraphConstructor { std::vector<std::pair<Node*, int>>* return_tensors_; // May be null. Not owned. + std::vector<Node*>* return_nodes_; + + // May be null. Not owned. std::vector<TensorId>* unused_input_map_keys_; // Intermediate datastructure used to populate `unused_input_map_keys_`. @@ -913,7 +923,8 @@ Status GraphConstructor::PopulateReturnTensors() { // Locate id in imported nodes auto iter = gdef_nodes_.find(id.first); if (iter == gdef_nodes_.end()) { - return errors::InvalidArgument("Requested return node '", id.first, + return errors::InvalidArgument("Requested return tensor '", + id.ToString(), "' not found in graph def"); } int num_outputs = iter->second.node->num_outputs(); @@ -935,6 +946,19 @@ Status GraphConstructor::PopulateReturnTensors() { return Status::OK(); } +Status GraphConstructor::PopulateReturnNodes() { + if (opts_.return_nodes.empty()) return Status::OK(); + for (StringPiece name : opts_.return_nodes) { + auto iter = gdef_nodes_.find(name); + if (iter == gdef_nodes_.end()) { + return errors::InvalidArgument("Requested return node '", name, + "' not found in graph def"); + } + return_nodes_->push_back(iter->second.node); + } + return Status::OK(); +} + void GraphConstructor::Undo() { for (const auto& iter : gdef_nodes_) { if (iter.second.node != nullptr) { @@ -965,7 +989,8 @@ Status ConvertGraphDefToGraph(const GraphConstructorOptions& opts, ShapeRefiner refiner(gdef.versions().producer(), g->op_registry()); return GraphConstructor::Construct( opts, gdef.node(), &gdef.versions(), &gdef.library(), g, &refiner, - /*return_tensors=*/nullptr, /*unused_input_map_keys=*/nullptr); + /*return_tensors=*/nullptr, /*return_nodes=*/nullptr, + /*unused_input_map_keys=*/nullptr); } Status ConvertNodeDefsToGraph(const GraphConstructorOptions& opts, @@ -978,31 +1003,40 @@ Status ConvertNodeDefsToGraph(const GraphConstructorOptions& opts, } return GraphConstructor::Construct(opts, node_defs, nullptr, nullptr, g, &refiner, /*return_tensors=*/nullptr, + /*return_nodes=*/nullptr, /*unused_input_map_keys=*/nullptr); } Status ImportGraphDef(const ImportGraphDefOptions& opts, const GraphDef& gdef, Graph* g, ShapeRefiner* refiner, - std::vector<std::pair<Node*, int>>* return_tensors, - std::vector<TensorId>* unused_input_map_keys) { + ImportGraphDefResults* results) { if (!opts.return_tensors.empty()) { - if (return_tensors == nullptr) { + if (results == nullptr) { return errors::InvalidArgument( - "return_tensors argument to ImportGraphDef() must be non-null if " + "results argument to ImportGraphDef() must be non-null if " "opts.return_tensors is non-empty"); } - if (!return_tensors->empty()) { + } + + if (!opts.return_nodes.empty()) { + if (opts.skip_mapped_nodes) { + return errors::InvalidArgument( + "Requesting return_nodes with skip_mapped_nodes set is not currently " + "supported"); + } + if (results == nullptr) { return errors::InvalidArgument( - "return_tensors argument to ImportGraphDef() should be empty (has " - "size ", - return_tensors->size(), ")"); + "results argument to ImportGraphDef() must be non-null if " + "opts.return_nodes is non-empty"); } } - if (unused_input_map_keys != nullptr && !unused_input_map_keys->empty()) { - return errors::InvalidArgument( - "If non-null, unused_input_map_keys argument to ImportGraphDef() should" - " be empty (has size ", - unused_input_map_keys->size(), ")"); + + if (results != nullptr) { + if (!results->return_tensors.empty() || !results->return_nodes.empty() || + !results->unused_input_map_keys.empty()) { + return errors::InvalidArgument( + "All fields in results argument to ImportGraphDef() must be empty."); + } } ShapeRefiner default_refiner(gdef.versions().producer(), g->op_registry()); @@ -1034,9 +1068,10 @@ Status ImportGraphDef(const ImportGraphDefOptions& opts, const GraphDef& gdef, refiner->set_graph_def_version( std::min(refiner->graph_def_version(), gdef.versions().producer())); - return GraphConstructor::Construct(opts, gdef.node(), &gdef.versions(), - &gdef.library(), g, refiner, - return_tensors, unused_input_map_keys); + return GraphConstructor::Construct( + opts, gdef.node(), &gdef.versions(), &gdef.library(), g, refiner, + &results->return_tensors, &results->return_nodes, + &results->unused_input_map_keys); } void CopyGraph(const Graph& src, Graph* dest) { diff --git a/tensorflow/core/graph/graph_constructor.h b/tensorflow/core/graph/graph_constructor.h index a8f9f2b245..6cd9347d96 100644 --- a/tensorflow/core/graph/graph_constructor.h +++ b/tensorflow/core/graph/graph_constructor.h @@ -72,8 +72,6 @@ struct ImportGraphDefOptions { // used to create the existing nodes referenced in `input_map`. // TODO(skyewm): can we remove this requirement? How do we access the original // shape refiner? - // - // TODO(skyewm): add functionality to retrieve unused `input_map` keys std::map<TensorId, TensorId> input_map; // If true, nodes that will have all output edges removed because of @@ -88,10 +86,10 @@ struct ImportGraphDefOptions { // other nodes in `gdef`. std::vector<string> control_dependencies; - // Tensors in `gdef` that will be returned via the `return_tensors` output - // parameter of `ImportGraphDef()`. If this list is non-empty, the caller must - // pass an empty vector to `ImportGraphDef()`. The vector will be populated - // with the imported nodes in `g`. + // Tensors in `gdef` that will be returned via the ImportGraphDefResults + // output parameter of `ImportGraphDef()`. If this list is non-empty, the + // caller must pass a results object to `ImportGraphDef()`. The + // `return_tensors` field will be populated with the imported nodes in `g`. // // Entries should not include `prefix`, i.e., each TensorId's name should be // the name as it originally appears in `gdef`. @@ -100,12 +98,43 @@ struct ImportGraphDefOptions { // corresponding existing tensor in `g` will be returned. std::vector<TensorId> return_tensors; + // The names of nodes in `gdef` that will be returned via the + // ImportGraphDefResults output parameter of `ImportGraphDef()`. If this list + // is non-empty, the caller must pass a results object to + // `ImportGraphDef()`. The `return_nodes` field will be populated with the + // imported nodes in `g`. + // + // Entries should not include `prefix`, i.e., each node's name should be the + // name as it originally appears in `gdef`. + // + // Unlike `return_tensors`, `input_map` has no effect on the nodes + // returned. `return_nodes` must be empty if `skip_mapped_nodes` is true. + // TODO(skyewm): make this work with `skip_mapped_nodes` if there's a need. + std::vector<StringPiece> return_nodes; + // TODO(ashankar): Enable handling of GraphDefs produced by newer binaries // with ops that are not defined in the binary calling ImportGraphDef. // Similar to the producer_op_list argument to import_graph_def in the // python API. }; +// Optional results that may be returned by ImportGraphDef. +struct ImportGraphDefResults { + // The requested tensors associated with + // ImportGraphDefOptions::return_tensors. Note that the index may be different + // than the requested index if the returned tensor has been remapped according + // to `input_map`. + typedef int Index; + std::vector<std::pair<Node*, Index>> return_tensors; + + // The requested nodes associated with ImportGraphDefOptions::return_nodes. + std::vector<Node*> return_nodes; + + // Keys in ImportGraphDefOptions::input_map that weren't used as an input to + // any node in`gdef`. + std::vector<TensorId> unused_input_map_keys; +}; + // Adds the graph in GraphDef `gdef` into an existing Graph `*g`. // // On error, returns non-OK and leaves `*g` unmodified. @@ -115,21 +144,16 @@ struct ImportGraphDefOptions { // allows the caller to validate shapes of those nodes (since // ShapeRefiner::AddNode must be called in topological order). // -// Each `return_tensors` entry is the requested node and output index. The index -// is included in case the returned tensor has been remapped according to -// `input_map`. -// -// If `unused_input_map_keys` is non-null, it should be empty and will be -// populated with any keys in `opts.input_map` that aren't used as an input to -// any node in `gdef`. +// `results` must be non-null if `opts.return_tensors` or `opts.result_nodes` is +// non-empty. It can also be set to fetch the unused input map keys. If it's +// non-null, all the vector fields must be empty. // // TODO(ashankar): Push this mechanism and get rid of Session::Extend() // as a means of enhancing an existing Graph. -extern Status ImportGraphDef( - const ImportGraphDefOptions& opts, const GraphDef& gdef, Graph* g, - ShapeRefiner* refiner, - std::vector<std::pair<Node*, int>>* return_tensors = nullptr, - std::vector<TensorId>* unused_input_map_keys = nullptr); +extern Status ImportGraphDef(const ImportGraphDefOptions& opts, + const GraphDef& gdef, Graph* g, + ShapeRefiner* refiner, + ImportGraphDefResults* results = nullptr); // Make a copy of "src" into "*dest". // diff --git a/tensorflow/core/graph/graph_constructor_test.cc b/tensorflow/core/graph/graph_constructor_test.cc index f88d707ec5..5242c56ce6 100644 --- a/tensorflow/core/graph/graph_constructor_test.cc +++ b/tensorflow/core/graph/graph_constructor_test.cc @@ -71,14 +71,12 @@ class GraphConstructorTest : public ::testing::Test { void ExpectError(const string& gdef_ascii, const ImportGraphDefOptions& opts, const std::vector<string>& expected_error_strs, ShapeRefiner* refiner = nullptr, - std::vector<std::pair<Node*, int>>* return_tensors = nullptr, - std::vector<TensorId>* unused_input_map_keys = nullptr) { + ImportGraphDefResults* results = nullptr) { // Used to verify that errors don't change graph const string original_graph_description = GraphDebugString(); Convert(gdef_ascii); - Status status = ImportGraphDef(opts, gdef_, &graph_, refiner, - return_tensors, unused_input_map_keys); + Status status = ImportGraphDef(opts, gdef_, &graph_, refiner, results); EXPECT_FALSE(status.ok()); for (const string& error : expected_error_strs) { @@ -97,11 +95,9 @@ class GraphConstructorTest : public ::testing::Test { void ExpectOK(const string& gdef_ascii, const ImportGraphDefOptions& opts, ShapeRefiner* refiner = nullptr, - std::vector<std::pair<Node*, int>>* return_tensors = nullptr, - std::vector<TensorId>* unused_input_map_keys = nullptr) { + ImportGraphDefResults* results = nullptr) { Convert(gdef_ascii); - Status s = ImportGraphDef(opts, gdef_, &graph_, refiner, return_tensors, - unused_input_map_keys); + Status s = ImportGraphDef(opts, gdef_, &graph_, refiner, results); EXPECT_EQ(Status::OK(), s) << s; } @@ -1440,26 +1436,25 @@ TEST_F(GraphConstructorTest, ImportGraphDef_InputMapDuplicateNodeNames) { TEST_F(GraphConstructorTest, ImportGraphDef_InputMapUnusedKeys) { ShapeRefiner refiner(TF_GRAPH_DEF_VERSION, graph_.op_registry()); - std::vector<TensorId> unused_input_map_keys; - // No input map ImportGraphDefOptions opts; + ImportGraphDefResults results; ExpectOK( "node { name: 'W1' op: 'TestParams' }" "node { name: 'input' op: 'TestInput' }", - opts, &refiner, nullptr, &unused_input_map_keys); - EXPECT_TRUE(unused_input_map_keys.empty()); + opts, &refiner, &results); + EXPECT_TRUE(results.unused_input_map_keys.empty()); // Non-empty unused_input_map_keys - unused_input_map_keys.push_back(TensorId()); - ExpectError("node { name: 'W2' op: 'TestParams' }", opts, - {"If non-null, unused_input_map_keys argument to ImportGraphDef()" - " should be empty (has size 1)"}, - &refiner, nullptr, &unused_input_map_keys); + results.unused_input_map_keys.push_back(TensorId()); + ExpectError( + "node { name: 'W2' op: 'TestParams' }", opts, + {"All fields in results argument to ImportGraphDef() must be empty."}, + &refiner, &results); // Input map with some used, some unused keys const int kControlSlot = Graph::kControlSlot; - unused_input_map_keys.clear(); + results.unused_input_map_keys.clear(); opts.input_map[TensorId("W2", kControlSlot)] = TensorId("W1", kControlSlot); opts.input_map[TensorId("new_input", 0)] = TensorId("input", 0); opts.input_map[TensorId("new_input", 1)] = TensorId("input", 0); @@ -1473,11 +1468,11 @@ TEST_F(GraphConstructorTest, ImportGraphDef_InputMapUnusedKeys) { node { name: 't1' op: 'TestMul' input: [ 'new_input:0', 'new_input:1' ] } node { name: 't2' op: 'TestMul' input: [ 't1:0', 't1:0' ] } )EOF", - opts, &refiner, nullptr, &unused_input_map_keys); + opts, &refiner, &results); std::vector<TensorId> expected_unused_keys = { TensorId("new_input", kControlSlot), TensorId("t1", 1)}; - EXPECT_EQ(unused_input_map_keys, expected_unused_keys); + EXPECT_EQ(results.unused_input_map_keys, expected_unused_keys); } TEST_F(GraphConstructorTest, ImportGraphDef_SkipMappedNodes_FullyMapped) { @@ -1567,11 +1562,11 @@ TEST_F(GraphConstructorTest, ImportGraphDef_ReturnTensors) { opts.return_tensors.push_back({"input", 1}); opts.return_tensors.push_back({"t1", 0}); opts.return_tensors.push_back({"input", 0}); - std::vector<std::pair<Node*, int>> return_tensors; + ImportGraphDefResults results; ExpectOK( "node { name: 'input' op: 'TestInput' }" "node { name: 't1' op: 'TestMul' input: ['input:0', 'input:1'] }", - opts, &refiner, &return_tensors); + opts, &refiner, &results); // Sanity checks EXPECT_TRUE(HasNode("input")); @@ -1580,74 +1575,70 @@ TEST_F(GraphConstructorTest, ImportGraphDef_ReturnTensors) { EXPECT_TRUE(HasEdge("input", 1, "t1", 1)); // Check return tensors - ASSERT_EQ(return_tensors.size(), 3); - EXPECT_EQ(return_tensors[0].first->name(), "input"); - EXPECT_EQ(return_tensors[0].second, 1); - EXPECT_EQ(return_tensors[1].first->name(), "t1"); - EXPECT_EQ(return_tensors[1].second, 0); - EXPECT_EQ(return_tensors[2].first->name(), "input"); - EXPECT_EQ(return_tensors[2].second, 0); + ASSERT_EQ(results.return_tensors.size(), 3); + EXPECT_EQ(results.return_tensors[0].first->name(), "input"); + EXPECT_EQ(results.return_tensors[0].second, 1); + EXPECT_EQ(results.return_tensors[1].first->name(), "t1"); + EXPECT_EQ(results.return_tensors[1].second, 0); + EXPECT_EQ(results.return_tensors[2].first->name(), "input"); + EXPECT_EQ(results.return_tensors[2].second, 0); // Test using prefix and returning element from input_map opts.return_tensors.clear(); - return_tensors.clear(); + results = ImportGraphDefResults(); opts.prefix = "import"; opts.input_map[{"new_input", 1}] = {"input", 0}; opts.return_tensors.push_back({"new_input", 0}); opts.return_tensors.push_back({"new_input", 1}); ExpectOK("node { name: 'new_input' op: 'TestInput' }", opts, &refiner, - &return_tensors); + &results); EXPECT_TRUE(HasNode("import/new_input")); - ASSERT_EQ(return_tensors.size(), 2); - EXPECT_EQ(return_tensors[0].first->name(), "import/new_input"); - EXPECT_EQ(return_tensors[0].second, 0); - EXPECT_EQ(return_tensors[1].first->name(), "input"); - EXPECT_EQ(return_tensors[1].second, 0); + ASSERT_EQ(results.return_tensors.size(), 2); + EXPECT_EQ(results.return_tensors[0].first->name(), "import/new_input"); + EXPECT_EQ(results.return_tensors[0].second, 0); + EXPECT_EQ(results.return_tensors[1].first->name(), "input"); + EXPECT_EQ(results.return_tensors[1].second, 0); // Test returning node remapped to source node opts.prefix.clear(); opts.input_map.clear(); opts.return_tensors.clear(); - return_tensors.clear(); + results = ImportGraphDefResults(); opts.input_map[{"new_input", 0}] = {"_SOURCE", 0}; opts.return_tensors.push_back({"new_input", 0}); ExpectOK("node { name: 'new_input' op: 'TestInput' }", opts, &refiner, - &return_tensors); + &results); EXPECT_TRUE(HasNode("new_input")); - ASSERT_EQ(return_tensors.size(), 1); - EXPECT_EQ(return_tensors[0].first->name(), "_SOURCE"); - EXPECT_EQ(return_tensors[0].second, 0); + ASSERT_EQ(results.return_tensors.size(), 1); + EXPECT_EQ(results.return_tensors[0].first->name(), "_SOURCE"); + EXPECT_EQ(results.return_tensors[0].second, 0); } TEST_F(GraphConstructorTest, ImportGraphDef_ReturnTensorsErrors) { - // Passing in return_tensors with empty opts.return_tensors is OK + // Null results with non-empty opts.return_tensors ImportGraphDefOptions opts; - std::vector<std::pair<Node*, int>> return_tensors; - ExpectOK("node { name: 'input' op: 'TestInput' }", opts, nullptr, - &return_tensors); - - // Null return_tensors with non-empty opts.return_tensors opts.return_tensors.push_back({"new_input", 0}); ExpectError("node { name: 'new_input' op: 'TestInput' }", opts, - {"return_tensors argument to ImportGraphDef() must be non-null " - "if opts.return_tensors is non-empty"}); + {"results argument to ImportGraphDef() must be non-null if " + "opts.return_tensors is non-empty"}); - // Non-empty return_tensors - return_tensors.push_back({nullptr, 0}); - ExpectError("node { name: 'new_input' op: 'TestInput' }", opts, - {"return_tensors argument to ImportGraphDef() should be empty " - "(has size 1)"}, - nullptr, &return_tensors); + // Non-empty results.return_tensors + ImportGraphDefResults results; + results.return_tensors.push_back({nullptr, 0}); + ExpectError( + "node { name: 'new_input' op: 'TestInput' }", opts, + {"All fields in results argument to ImportGraphDef() must be empty."}, + nullptr, &results); // Requesting tensor that isn't in graph def - return_tensors.clear(); + results.return_tensors.clear(); ExpectError("node { name: 'W1' op: 'TestParams' }", opts, - {"Requested return node 'new_input' not found in graph def"}, - nullptr, &return_tensors); + {"Requested return tensor 'new_input:0' not found in graph def"}, + nullptr, &results); // Requesting invalid node index opts.return_tensors.clear(); @@ -1655,7 +1646,89 @@ TEST_F(GraphConstructorTest, ImportGraphDef_ReturnTensorsErrors) { ExpectError("node { name: 'new_input' op: 'TestInput' }", opts, {"Invalid return output 2 of node 'new_input', which has 2 " "output(s)"}, - nullptr, &return_tensors); + nullptr, &results); +} + +TEST_F(GraphConstructorTest, ImportGraphDef_ReturnNodes) { + ShapeRefiner refiner(TF_GRAPH_DEF_VERSION, graph_.op_registry()); + + ImportGraphDefOptions opts; + opts.return_nodes.push_back("input"); + opts.return_nodes.push_back("t1"); + ImportGraphDefResults results; + ExpectOK( + "node { name: 'input' op: 'TestInput' }" + "node { name: 'input2' op: 'TestInput' }" + "node { name: 't1' op: 'TestMul' input: ['input:0', 'input2:1'] }", + opts, &refiner, &results); + + // Sanity checks + EXPECT_TRUE(HasNode("input")); + EXPECT_TRUE(HasNode("input2")); + EXPECT_TRUE(HasNode("t1")); + EXPECT_TRUE(HasEdge("input", 0, "t1", 0)); + EXPECT_TRUE(HasEdge("input2", 1, "t1", 1)); + + // Check return tensors + ASSERT_EQ(results.return_nodes.size(), 2); + EXPECT_EQ(results.return_tensors.size(), 0); + EXPECT_EQ(results.unused_input_map_keys.size(), 0); + EXPECT_EQ(results.return_nodes[0]->name(), "input"); + EXPECT_EQ(results.return_nodes[1]->name(), "t1"); + + // Test using prefix + opts = ImportGraphDefOptions(); + results = ImportGraphDefResults(); + opts.prefix = "import"; + opts.return_nodes.push_back("input"); + ExpectOK("node { name: 'input' op: 'TestInput' }", opts, &refiner, &results); + + EXPECT_TRUE(HasNode("import/input")); + + ASSERT_EQ(results.return_nodes.size(), 1); + EXPECT_EQ(results.return_nodes[0]->name(), "import/input"); + + // Test that input_map has no effect + opts = ImportGraphDefOptions(); + results = ImportGraphDefResults(); + opts.input_map[{"new_input", 0}] = {"input", 0}; + opts.return_nodes.push_back("new_input"); + ExpectOK("node { name: 'new_input' op: 'TestInput' }", opts, &refiner, + &results); + + EXPECT_TRUE(HasNode("new_input")); + + ASSERT_EQ(results.return_nodes.size(), 1); + EXPECT_EQ(results.return_nodes[0]->name(), "new_input"); +} + +TEST_F(GraphConstructorTest, ImportGraphDef_ReturnNodesErrors) { + // Null results with non-empty opts.return_nodes + ImportGraphDefOptions opts; + opts.return_nodes.push_back("new_input"); + ExpectError("node { name: 'new_input' op: 'TestInput' }", opts, + {"results argument to ImportGraphDef() must be non-null if " + "opts.return_nodes is non-empty"}); + + // Non-empty results.return_nodes + ImportGraphDefResults results; + results.return_nodes.push_back(nullptr); + ExpectError( + "node { name: 'new_input' op: 'TestInput' }", opts, + {"All fields in results argument to ImportGraphDef() must be empty."}, + nullptr, &results); + + // Requesting node that isn't in graph def + results.return_nodes.clear(); + ExpectError("node { name: 'W1' op: 'TestParams' }", opts, + {"Requested return node 'new_input' not found in graph def"}, + nullptr, &results); + + // Requesting return_nodes with skip_mapped_nodes not yet implemented + opts.skip_mapped_nodes = true; + ExpectError("node { name: 'new_input' op: 'TestInput' }", opts, + {"Requesting return_nodes with skip_mapped_nodes set is not " + "currently supported"}); } TEST_F(GraphConstructorTest, ImportGraphDef_WithCycle) { |