diff options
Diffstat (limited to 'tensorflow/core/graph/graph_constructor_test.cc')
-rw-r--r-- | tensorflow/core/graph/graph_constructor_test.cc | 191 |
1 files changed, 132 insertions, 59 deletions
diff --git a/tensorflow/core/graph/graph_constructor_test.cc b/tensorflow/core/graph/graph_constructor_test.cc index f88d707ec5..5242c56ce6 100644 --- a/tensorflow/core/graph/graph_constructor_test.cc +++ b/tensorflow/core/graph/graph_constructor_test.cc @@ -71,14 +71,12 @@ class GraphConstructorTest : public ::testing::Test { void ExpectError(const string& gdef_ascii, const ImportGraphDefOptions& opts, const std::vector<string>& expected_error_strs, ShapeRefiner* refiner = nullptr, - std::vector<std::pair<Node*, int>>* return_tensors = nullptr, - std::vector<TensorId>* unused_input_map_keys = nullptr) { + ImportGraphDefResults* results = nullptr) { // Used to verify that errors don't change graph const string original_graph_description = GraphDebugString(); Convert(gdef_ascii); - Status status = ImportGraphDef(opts, gdef_, &graph_, refiner, - return_tensors, unused_input_map_keys); + Status status = ImportGraphDef(opts, gdef_, &graph_, refiner, results); EXPECT_FALSE(status.ok()); for (const string& error : expected_error_strs) { @@ -97,11 +95,9 @@ class GraphConstructorTest : public ::testing::Test { void ExpectOK(const string& gdef_ascii, const ImportGraphDefOptions& opts, ShapeRefiner* refiner = nullptr, - std::vector<std::pair<Node*, int>>* return_tensors = nullptr, - std::vector<TensorId>* unused_input_map_keys = nullptr) { + ImportGraphDefResults* results = nullptr) { Convert(gdef_ascii); - Status s = ImportGraphDef(opts, gdef_, &graph_, refiner, return_tensors, - unused_input_map_keys); + Status s = ImportGraphDef(opts, gdef_, &graph_, refiner, results); EXPECT_EQ(Status::OK(), s) << s; } @@ -1440,26 +1436,25 @@ TEST_F(GraphConstructorTest, ImportGraphDef_InputMapDuplicateNodeNames) { TEST_F(GraphConstructorTest, ImportGraphDef_InputMapUnusedKeys) { ShapeRefiner refiner(TF_GRAPH_DEF_VERSION, graph_.op_registry()); - std::vector<TensorId> unused_input_map_keys; - // No input map ImportGraphDefOptions opts; + ImportGraphDefResults results; ExpectOK( "node { name: 'W1' op: 'TestParams' }" "node { name: 'input' op: 'TestInput' }", - opts, &refiner, nullptr, &unused_input_map_keys); - EXPECT_TRUE(unused_input_map_keys.empty()); + opts, &refiner, &results); + EXPECT_TRUE(results.unused_input_map_keys.empty()); // Non-empty unused_input_map_keys - unused_input_map_keys.push_back(TensorId()); - ExpectError("node { name: 'W2' op: 'TestParams' }", opts, - {"If non-null, unused_input_map_keys argument to ImportGraphDef()" - " should be empty (has size 1)"}, - &refiner, nullptr, &unused_input_map_keys); + results.unused_input_map_keys.push_back(TensorId()); + ExpectError( + "node { name: 'W2' op: 'TestParams' }", opts, + {"All fields in results argument to ImportGraphDef() must be empty."}, + &refiner, &results); // Input map with some used, some unused keys const int kControlSlot = Graph::kControlSlot; - unused_input_map_keys.clear(); + results.unused_input_map_keys.clear(); opts.input_map[TensorId("W2", kControlSlot)] = TensorId("W1", kControlSlot); opts.input_map[TensorId("new_input", 0)] = TensorId("input", 0); opts.input_map[TensorId("new_input", 1)] = TensorId("input", 0); @@ -1473,11 +1468,11 @@ TEST_F(GraphConstructorTest, ImportGraphDef_InputMapUnusedKeys) { node { name: 't1' op: 'TestMul' input: [ 'new_input:0', 'new_input:1' ] } node { name: 't2' op: 'TestMul' input: [ 't1:0', 't1:0' ] } )EOF", - opts, &refiner, nullptr, &unused_input_map_keys); + opts, &refiner, &results); std::vector<TensorId> expected_unused_keys = { TensorId("new_input", kControlSlot), TensorId("t1", 1)}; - EXPECT_EQ(unused_input_map_keys, expected_unused_keys); + EXPECT_EQ(results.unused_input_map_keys, expected_unused_keys); } TEST_F(GraphConstructorTest, ImportGraphDef_SkipMappedNodes_FullyMapped) { @@ -1567,11 +1562,11 @@ TEST_F(GraphConstructorTest, ImportGraphDef_ReturnTensors) { opts.return_tensors.push_back({"input", 1}); opts.return_tensors.push_back({"t1", 0}); opts.return_tensors.push_back({"input", 0}); - std::vector<std::pair<Node*, int>> return_tensors; + ImportGraphDefResults results; ExpectOK( "node { name: 'input' op: 'TestInput' }" "node { name: 't1' op: 'TestMul' input: ['input:0', 'input:1'] }", - opts, &refiner, &return_tensors); + opts, &refiner, &results); // Sanity checks EXPECT_TRUE(HasNode("input")); @@ -1580,74 +1575,70 @@ TEST_F(GraphConstructorTest, ImportGraphDef_ReturnTensors) { EXPECT_TRUE(HasEdge("input", 1, "t1", 1)); // Check return tensors - ASSERT_EQ(return_tensors.size(), 3); - EXPECT_EQ(return_tensors[0].first->name(), "input"); - EXPECT_EQ(return_tensors[0].second, 1); - EXPECT_EQ(return_tensors[1].first->name(), "t1"); - EXPECT_EQ(return_tensors[1].second, 0); - EXPECT_EQ(return_tensors[2].first->name(), "input"); - EXPECT_EQ(return_tensors[2].second, 0); + ASSERT_EQ(results.return_tensors.size(), 3); + EXPECT_EQ(results.return_tensors[0].first->name(), "input"); + EXPECT_EQ(results.return_tensors[0].second, 1); + EXPECT_EQ(results.return_tensors[1].first->name(), "t1"); + EXPECT_EQ(results.return_tensors[1].second, 0); + EXPECT_EQ(results.return_tensors[2].first->name(), "input"); + EXPECT_EQ(results.return_tensors[2].second, 0); // Test using prefix and returning element from input_map opts.return_tensors.clear(); - return_tensors.clear(); + results = ImportGraphDefResults(); opts.prefix = "import"; opts.input_map[{"new_input", 1}] = {"input", 0}; opts.return_tensors.push_back({"new_input", 0}); opts.return_tensors.push_back({"new_input", 1}); ExpectOK("node { name: 'new_input' op: 'TestInput' }", opts, &refiner, - &return_tensors); + &results); EXPECT_TRUE(HasNode("import/new_input")); - ASSERT_EQ(return_tensors.size(), 2); - EXPECT_EQ(return_tensors[0].first->name(), "import/new_input"); - EXPECT_EQ(return_tensors[0].second, 0); - EXPECT_EQ(return_tensors[1].first->name(), "input"); - EXPECT_EQ(return_tensors[1].second, 0); + ASSERT_EQ(results.return_tensors.size(), 2); + EXPECT_EQ(results.return_tensors[0].first->name(), "import/new_input"); + EXPECT_EQ(results.return_tensors[0].second, 0); + EXPECT_EQ(results.return_tensors[1].first->name(), "input"); + EXPECT_EQ(results.return_tensors[1].second, 0); // Test returning node remapped to source node opts.prefix.clear(); opts.input_map.clear(); opts.return_tensors.clear(); - return_tensors.clear(); + results = ImportGraphDefResults(); opts.input_map[{"new_input", 0}] = {"_SOURCE", 0}; opts.return_tensors.push_back({"new_input", 0}); ExpectOK("node { name: 'new_input' op: 'TestInput' }", opts, &refiner, - &return_tensors); + &results); EXPECT_TRUE(HasNode("new_input")); - ASSERT_EQ(return_tensors.size(), 1); - EXPECT_EQ(return_tensors[0].first->name(), "_SOURCE"); - EXPECT_EQ(return_tensors[0].second, 0); + ASSERT_EQ(results.return_tensors.size(), 1); + EXPECT_EQ(results.return_tensors[0].first->name(), "_SOURCE"); + EXPECT_EQ(results.return_tensors[0].second, 0); } TEST_F(GraphConstructorTest, ImportGraphDef_ReturnTensorsErrors) { - // Passing in return_tensors with empty opts.return_tensors is OK + // Null results with non-empty opts.return_tensors ImportGraphDefOptions opts; - std::vector<std::pair<Node*, int>> return_tensors; - ExpectOK("node { name: 'input' op: 'TestInput' }", opts, nullptr, - &return_tensors); - - // Null return_tensors with non-empty opts.return_tensors opts.return_tensors.push_back({"new_input", 0}); ExpectError("node { name: 'new_input' op: 'TestInput' }", opts, - {"return_tensors argument to ImportGraphDef() must be non-null " - "if opts.return_tensors is non-empty"}); + {"results argument to ImportGraphDef() must be non-null if " + "opts.return_tensors is non-empty"}); - // Non-empty return_tensors - return_tensors.push_back({nullptr, 0}); - ExpectError("node { name: 'new_input' op: 'TestInput' }", opts, - {"return_tensors argument to ImportGraphDef() should be empty " - "(has size 1)"}, - nullptr, &return_tensors); + // Non-empty results.return_tensors + ImportGraphDefResults results; + results.return_tensors.push_back({nullptr, 0}); + ExpectError( + "node { name: 'new_input' op: 'TestInput' }", opts, + {"All fields in results argument to ImportGraphDef() must be empty."}, + nullptr, &results); // Requesting tensor that isn't in graph def - return_tensors.clear(); + results.return_tensors.clear(); ExpectError("node { name: 'W1' op: 'TestParams' }", opts, - {"Requested return node 'new_input' not found in graph def"}, - nullptr, &return_tensors); + {"Requested return tensor 'new_input:0' not found in graph def"}, + nullptr, &results); // Requesting invalid node index opts.return_tensors.clear(); @@ -1655,7 +1646,89 @@ TEST_F(GraphConstructorTest, ImportGraphDef_ReturnTensorsErrors) { ExpectError("node { name: 'new_input' op: 'TestInput' }", opts, {"Invalid return output 2 of node 'new_input', which has 2 " "output(s)"}, - nullptr, &return_tensors); + nullptr, &results); +} + +TEST_F(GraphConstructorTest, ImportGraphDef_ReturnNodes) { + ShapeRefiner refiner(TF_GRAPH_DEF_VERSION, graph_.op_registry()); + + ImportGraphDefOptions opts; + opts.return_nodes.push_back("input"); + opts.return_nodes.push_back("t1"); + ImportGraphDefResults results; + ExpectOK( + "node { name: 'input' op: 'TestInput' }" + "node { name: 'input2' op: 'TestInput' }" + "node { name: 't1' op: 'TestMul' input: ['input:0', 'input2:1'] }", + opts, &refiner, &results); + + // Sanity checks + EXPECT_TRUE(HasNode("input")); + EXPECT_TRUE(HasNode("input2")); + EXPECT_TRUE(HasNode("t1")); + EXPECT_TRUE(HasEdge("input", 0, "t1", 0)); + EXPECT_TRUE(HasEdge("input2", 1, "t1", 1)); + + // Check return tensors + ASSERT_EQ(results.return_nodes.size(), 2); + EXPECT_EQ(results.return_tensors.size(), 0); + EXPECT_EQ(results.unused_input_map_keys.size(), 0); + EXPECT_EQ(results.return_nodes[0]->name(), "input"); + EXPECT_EQ(results.return_nodes[1]->name(), "t1"); + + // Test using prefix + opts = ImportGraphDefOptions(); + results = ImportGraphDefResults(); + opts.prefix = "import"; + opts.return_nodes.push_back("input"); + ExpectOK("node { name: 'input' op: 'TestInput' }", opts, &refiner, &results); + + EXPECT_TRUE(HasNode("import/input")); + + ASSERT_EQ(results.return_nodes.size(), 1); + EXPECT_EQ(results.return_nodes[0]->name(), "import/input"); + + // Test that input_map has no effect + opts = ImportGraphDefOptions(); + results = ImportGraphDefResults(); + opts.input_map[{"new_input", 0}] = {"input", 0}; + opts.return_nodes.push_back("new_input"); + ExpectOK("node { name: 'new_input' op: 'TestInput' }", opts, &refiner, + &results); + + EXPECT_TRUE(HasNode("new_input")); + + ASSERT_EQ(results.return_nodes.size(), 1); + EXPECT_EQ(results.return_nodes[0]->name(), "new_input"); +} + +TEST_F(GraphConstructorTest, ImportGraphDef_ReturnNodesErrors) { + // Null results with non-empty opts.return_nodes + ImportGraphDefOptions opts; + opts.return_nodes.push_back("new_input"); + ExpectError("node { name: 'new_input' op: 'TestInput' }", opts, + {"results argument to ImportGraphDef() must be non-null if " + "opts.return_nodes is non-empty"}); + + // Non-empty results.return_nodes + ImportGraphDefResults results; + results.return_nodes.push_back(nullptr); + ExpectError( + "node { name: 'new_input' op: 'TestInput' }", opts, + {"All fields in results argument to ImportGraphDef() must be empty."}, + nullptr, &results); + + // Requesting node that isn't in graph def + results.return_nodes.clear(); + ExpectError("node { name: 'W1' op: 'TestParams' }", opts, + {"Requested return node 'new_input' not found in graph def"}, + nullptr, &results); + + // Requesting return_nodes with skip_mapped_nodes not yet implemented + opts.skip_mapped_nodes = true; + ExpectError("node { name: 'new_input' op: 'TestInput' }", opts, + {"Requesting return_nodes with skip_mapped_nodes set is not " + "currently supported"}); } TEST_F(GraphConstructorTest, ImportGraphDef_WithCycle) { |