aboutsummaryrefslogtreecommitdiffhomepage
path: root/tensorflow/core/graph/graph_constructor_test.cc
diff options
context:
space:
mode:
Diffstat (limited to 'tensorflow/core/graph/graph_constructor_test.cc')
-rw-r--r--tensorflow/core/graph/graph_constructor_test.cc191
1 files changed, 132 insertions, 59 deletions
diff --git a/tensorflow/core/graph/graph_constructor_test.cc b/tensorflow/core/graph/graph_constructor_test.cc
index f88d707ec5..5242c56ce6 100644
--- a/tensorflow/core/graph/graph_constructor_test.cc
+++ b/tensorflow/core/graph/graph_constructor_test.cc
@@ -71,14 +71,12 @@ class GraphConstructorTest : public ::testing::Test {
void ExpectError(const string& gdef_ascii, const ImportGraphDefOptions& opts,
const std::vector<string>& expected_error_strs,
ShapeRefiner* refiner = nullptr,
- std::vector<std::pair<Node*, int>>* return_tensors = nullptr,
- std::vector<TensorId>* unused_input_map_keys = nullptr) {
+ ImportGraphDefResults* results = nullptr) {
// Used to verify that errors don't change graph
const string original_graph_description = GraphDebugString();
Convert(gdef_ascii);
- Status status = ImportGraphDef(opts, gdef_, &graph_, refiner,
- return_tensors, unused_input_map_keys);
+ Status status = ImportGraphDef(opts, gdef_, &graph_, refiner, results);
EXPECT_FALSE(status.ok());
for (const string& error : expected_error_strs) {
@@ -97,11 +95,9 @@ class GraphConstructorTest : public ::testing::Test {
void ExpectOK(const string& gdef_ascii, const ImportGraphDefOptions& opts,
ShapeRefiner* refiner = nullptr,
- std::vector<std::pair<Node*, int>>* return_tensors = nullptr,
- std::vector<TensorId>* unused_input_map_keys = nullptr) {
+ ImportGraphDefResults* results = nullptr) {
Convert(gdef_ascii);
- Status s = ImportGraphDef(opts, gdef_, &graph_, refiner, return_tensors,
- unused_input_map_keys);
+ Status s = ImportGraphDef(opts, gdef_, &graph_, refiner, results);
EXPECT_EQ(Status::OK(), s) << s;
}
@@ -1440,26 +1436,25 @@ TEST_F(GraphConstructorTest, ImportGraphDef_InputMapDuplicateNodeNames) {
TEST_F(GraphConstructorTest, ImportGraphDef_InputMapUnusedKeys) {
ShapeRefiner refiner(TF_GRAPH_DEF_VERSION, graph_.op_registry());
- std::vector<TensorId> unused_input_map_keys;
-
// No input map
ImportGraphDefOptions opts;
+ ImportGraphDefResults results;
ExpectOK(
"node { name: 'W1' op: 'TestParams' }"
"node { name: 'input' op: 'TestInput' }",
- opts, &refiner, nullptr, &unused_input_map_keys);
- EXPECT_TRUE(unused_input_map_keys.empty());
+ opts, &refiner, &results);
+ EXPECT_TRUE(results.unused_input_map_keys.empty());
// Non-empty unused_input_map_keys
- unused_input_map_keys.push_back(TensorId());
- ExpectError("node { name: 'W2' op: 'TestParams' }", opts,
- {"If non-null, unused_input_map_keys argument to ImportGraphDef()"
- " should be empty (has size 1)"},
- &refiner, nullptr, &unused_input_map_keys);
+ results.unused_input_map_keys.push_back(TensorId());
+ ExpectError(
+ "node { name: 'W2' op: 'TestParams' }", opts,
+ {"All fields in results argument to ImportGraphDef() must be empty."},
+ &refiner, &results);
// Input map with some used, some unused keys
const int kControlSlot = Graph::kControlSlot;
- unused_input_map_keys.clear();
+ results.unused_input_map_keys.clear();
opts.input_map[TensorId("W2", kControlSlot)] = TensorId("W1", kControlSlot);
opts.input_map[TensorId("new_input", 0)] = TensorId("input", 0);
opts.input_map[TensorId("new_input", 1)] = TensorId("input", 0);
@@ -1473,11 +1468,11 @@ TEST_F(GraphConstructorTest, ImportGraphDef_InputMapUnusedKeys) {
node { name: 't1' op: 'TestMul' input: [ 'new_input:0', 'new_input:1' ] }
node { name: 't2' op: 'TestMul' input: [ 't1:0', 't1:0' ] }
)EOF",
- opts, &refiner, nullptr, &unused_input_map_keys);
+ opts, &refiner, &results);
std::vector<TensorId> expected_unused_keys = {
TensorId("new_input", kControlSlot), TensorId("t1", 1)};
- EXPECT_EQ(unused_input_map_keys, expected_unused_keys);
+ EXPECT_EQ(results.unused_input_map_keys, expected_unused_keys);
}
TEST_F(GraphConstructorTest, ImportGraphDef_SkipMappedNodes_FullyMapped) {
@@ -1567,11 +1562,11 @@ TEST_F(GraphConstructorTest, ImportGraphDef_ReturnTensors) {
opts.return_tensors.push_back({"input", 1});
opts.return_tensors.push_back({"t1", 0});
opts.return_tensors.push_back({"input", 0});
- std::vector<std::pair<Node*, int>> return_tensors;
+ ImportGraphDefResults results;
ExpectOK(
"node { name: 'input' op: 'TestInput' }"
"node { name: 't1' op: 'TestMul' input: ['input:0', 'input:1'] }",
- opts, &refiner, &return_tensors);
+ opts, &refiner, &results);
// Sanity checks
EXPECT_TRUE(HasNode("input"));
@@ -1580,74 +1575,70 @@ TEST_F(GraphConstructorTest, ImportGraphDef_ReturnTensors) {
EXPECT_TRUE(HasEdge("input", 1, "t1", 1));
// Check return tensors
- ASSERT_EQ(return_tensors.size(), 3);
- EXPECT_EQ(return_tensors[0].first->name(), "input");
- EXPECT_EQ(return_tensors[0].second, 1);
- EXPECT_EQ(return_tensors[1].first->name(), "t1");
- EXPECT_EQ(return_tensors[1].second, 0);
- EXPECT_EQ(return_tensors[2].first->name(), "input");
- EXPECT_EQ(return_tensors[2].second, 0);
+ ASSERT_EQ(results.return_tensors.size(), 3);
+ EXPECT_EQ(results.return_tensors[0].first->name(), "input");
+ EXPECT_EQ(results.return_tensors[0].second, 1);
+ EXPECT_EQ(results.return_tensors[1].first->name(), "t1");
+ EXPECT_EQ(results.return_tensors[1].second, 0);
+ EXPECT_EQ(results.return_tensors[2].first->name(), "input");
+ EXPECT_EQ(results.return_tensors[2].second, 0);
// Test using prefix and returning element from input_map
opts.return_tensors.clear();
- return_tensors.clear();
+ results = ImportGraphDefResults();
opts.prefix = "import";
opts.input_map[{"new_input", 1}] = {"input", 0};
opts.return_tensors.push_back({"new_input", 0});
opts.return_tensors.push_back({"new_input", 1});
ExpectOK("node { name: 'new_input' op: 'TestInput' }", opts, &refiner,
- &return_tensors);
+ &results);
EXPECT_TRUE(HasNode("import/new_input"));
- ASSERT_EQ(return_tensors.size(), 2);
- EXPECT_EQ(return_tensors[0].first->name(), "import/new_input");
- EXPECT_EQ(return_tensors[0].second, 0);
- EXPECT_EQ(return_tensors[1].first->name(), "input");
- EXPECT_EQ(return_tensors[1].second, 0);
+ ASSERT_EQ(results.return_tensors.size(), 2);
+ EXPECT_EQ(results.return_tensors[0].first->name(), "import/new_input");
+ EXPECT_EQ(results.return_tensors[0].second, 0);
+ EXPECT_EQ(results.return_tensors[1].first->name(), "input");
+ EXPECT_EQ(results.return_tensors[1].second, 0);
// Test returning node remapped to source node
opts.prefix.clear();
opts.input_map.clear();
opts.return_tensors.clear();
- return_tensors.clear();
+ results = ImportGraphDefResults();
opts.input_map[{"new_input", 0}] = {"_SOURCE", 0};
opts.return_tensors.push_back({"new_input", 0});
ExpectOK("node { name: 'new_input' op: 'TestInput' }", opts, &refiner,
- &return_tensors);
+ &results);
EXPECT_TRUE(HasNode("new_input"));
- ASSERT_EQ(return_tensors.size(), 1);
- EXPECT_EQ(return_tensors[0].first->name(), "_SOURCE");
- EXPECT_EQ(return_tensors[0].second, 0);
+ ASSERT_EQ(results.return_tensors.size(), 1);
+ EXPECT_EQ(results.return_tensors[0].first->name(), "_SOURCE");
+ EXPECT_EQ(results.return_tensors[0].second, 0);
}
TEST_F(GraphConstructorTest, ImportGraphDef_ReturnTensorsErrors) {
- // Passing in return_tensors with empty opts.return_tensors is OK
+ // Null results with non-empty opts.return_tensors
ImportGraphDefOptions opts;
- std::vector<std::pair<Node*, int>> return_tensors;
- ExpectOK("node { name: 'input' op: 'TestInput' }", opts, nullptr,
- &return_tensors);
-
- // Null return_tensors with non-empty opts.return_tensors
opts.return_tensors.push_back({"new_input", 0});
ExpectError("node { name: 'new_input' op: 'TestInput' }", opts,
- {"return_tensors argument to ImportGraphDef() must be non-null "
- "if opts.return_tensors is non-empty"});
+ {"results argument to ImportGraphDef() must be non-null if "
+ "opts.return_tensors is non-empty"});
- // Non-empty return_tensors
- return_tensors.push_back({nullptr, 0});
- ExpectError("node { name: 'new_input' op: 'TestInput' }", opts,
- {"return_tensors argument to ImportGraphDef() should be empty "
- "(has size 1)"},
- nullptr, &return_tensors);
+ // Non-empty results.return_tensors
+ ImportGraphDefResults results;
+ results.return_tensors.push_back({nullptr, 0});
+ ExpectError(
+ "node { name: 'new_input' op: 'TestInput' }", opts,
+ {"All fields in results argument to ImportGraphDef() must be empty."},
+ nullptr, &results);
// Requesting tensor that isn't in graph def
- return_tensors.clear();
+ results.return_tensors.clear();
ExpectError("node { name: 'W1' op: 'TestParams' }", opts,
- {"Requested return node 'new_input' not found in graph def"},
- nullptr, &return_tensors);
+ {"Requested return tensor 'new_input:0' not found in graph def"},
+ nullptr, &results);
// Requesting invalid node index
opts.return_tensors.clear();
@@ -1655,7 +1646,89 @@ TEST_F(GraphConstructorTest, ImportGraphDef_ReturnTensorsErrors) {
ExpectError("node { name: 'new_input' op: 'TestInput' }", opts,
{"Invalid return output 2 of node 'new_input', which has 2 "
"output(s)"},
- nullptr, &return_tensors);
+ nullptr, &results);
+}
+
+TEST_F(GraphConstructorTest, ImportGraphDef_ReturnNodes) {
+ ShapeRefiner refiner(TF_GRAPH_DEF_VERSION, graph_.op_registry());
+
+ ImportGraphDefOptions opts;
+ opts.return_nodes.push_back("input");
+ opts.return_nodes.push_back("t1");
+ ImportGraphDefResults results;
+ ExpectOK(
+ "node { name: 'input' op: 'TestInput' }"
+ "node { name: 'input2' op: 'TestInput' }"
+ "node { name: 't1' op: 'TestMul' input: ['input:0', 'input2:1'] }",
+ opts, &refiner, &results);
+
+ // Sanity checks
+ EXPECT_TRUE(HasNode("input"));
+ EXPECT_TRUE(HasNode("input2"));
+ EXPECT_TRUE(HasNode("t1"));
+ EXPECT_TRUE(HasEdge("input", 0, "t1", 0));
+ EXPECT_TRUE(HasEdge("input2", 1, "t1", 1));
+
+ // Check return tensors
+ ASSERT_EQ(results.return_nodes.size(), 2);
+ EXPECT_EQ(results.return_tensors.size(), 0);
+ EXPECT_EQ(results.unused_input_map_keys.size(), 0);
+ EXPECT_EQ(results.return_nodes[0]->name(), "input");
+ EXPECT_EQ(results.return_nodes[1]->name(), "t1");
+
+ // Test using prefix
+ opts = ImportGraphDefOptions();
+ results = ImportGraphDefResults();
+ opts.prefix = "import";
+ opts.return_nodes.push_back("input");
+ ExpectOK("node { name: 'input' op: 'TestInput' }", opts, &refiner, &results);
+
+ EXPECT_TRUE(HasNode("import/input"));
+
+ ASSERT_EQ(results.return_nodes.size(), 1);
+ EXPECT_EQ(results.return_nodes[0]->name(), "import/input");
+
+ // Test that input_map has no effect
+ opts = ImportGraphDefOptions();
+ results = ImportGraphDefResults();
+ opts.input_map[{"new_input", 0}] = {"input", 0};
+ opts.return_nodes.push_back("new_input");
+ ExpectOK("node { name: 'new_input' op: 'TestInput' }", opts, &refiner,
+ &results);
+
+ EXPECT_TRUE(HasNode("new_input"));
+
+ ASSERT_EQ(results.return_nodes.size(), 1);
+ EXPECT_EQ(results.return_nodes[0]->name(), "new_input");
+}
+
+TEST_F(GraphConstructorTest, ImportGraphDef_ReturnNodesErrors) {
+ // Null results with non-empty opts.return_nodes
+ ImportGraphDefOptions opts;
+ opts.return_nodes.push_back("new_input");
+ ExpectError("node { name: 'new_input' op: 'TestInput' }", opts,
+ {"results argument to ImportGraphDef() must be non-null if "
+ "opts.return_nodes is non-empty"});
+
+ // Non-empty results.return_nodes
+ ImportGraphDefResults results;
+ results.return_nodes.push_back(nullptr);
+ ExpectError(
+ "node { name: 'new_input' op: 'TestInput' }", opts,
+ {"All fields in results argument to ImportGraphDef() must be empty."},
+ nullptr, &results);
+
+ // Requesting node that isn't in graph def
+ results.return_nodes.clear();
+ ExpectError("node { name: 'W1' op: 'TestParams' }", opts,
+ {"Requested return node 'new_input' not found in graph def"},
+ nullptr, &results);
+
+ // Requesting return_nodes with skip_mapped_nodes not yet implemented
+ opts.skip_mapped_nodes = true;
+ ExpectError("node { name: 'new_input' op: 'TestInput' }", opts,
+ {"Requesting return_nodes with skip_mapped_nodes set is not "
+ "currently supported"});
}
TEST_F(GraphConstructorTest, ImportGraphDef_WithCycle) {