aboutsummaryrefslogtreecommitdiffhomepage
path: root/tensorflow/c/c_api_test.cc
diff options
context:
space:
mode:
authorGravatar Skye Wanderman-Milne <skyewm@google.com>2017-10-30 08:07:11 -0700
committerGravatar TensorFlower Gardener <gardener@tensorflow.org>2017-10-30 08:10:56 -0700
commitce0238198052358d102ca7786ad9be60a5e76d28 (patch)
treeb1694c3fe23b4933b7967f9494cb7337e673b07e /tensorflow/c/c_api_test.cc
parentef4490f637e17f3ce599f55522e63d06f470e540 (diff)
Add ability to fetch return nodes and unused input mappings from C API GraphDef import
This change introduces yet another ImportGraphDef function to the C API (TF_GraphImportGraphDefWithResults), but this one has extensible return values so we shouldn't have to add more in the future. This change also modifies the ImportGraphDef C interface to manage all string data for the user. PiperOrigin-RevId: 173894710
Diffstat (limited to 'tensorflow/c/c_api_test.cc')
-rw-r--r--tensorflow/c/c_api_test.cc135
1 files changed, 129 insertions, 6 deletions
diff --git a/tensorflow/c/c_api_test.cc b/tensorflow/c/c_api_test.cc
index d220bc5e95..05881e619b 100644
--- a/tensorflow/c/c_api_test.cc
+++ b/tensorflow/c/c_api_test.cc
@@ -573,7 +573,7 @@ TEST(CAPI, ImportGraphDef) {
TF_GraphToGraphDef(graph, graph_def, s);
ASSERT_EQ(TF_OK, TF_GetCode(s)) << TF_Message(s);
- // Import it again, with a prefix, in a fresh graph.
+ // Import it, with a prefix, in a fresh graph.
TF_DeleteGraph(graph);
graph = TF_NewGraph();
TF_ImportGraphDefOptions* opts = TF_NewImportGraphDefOptions();
@@ -588,8 +588,8 @@ TEST(CAPI, ImportGraphDef) {
ASSERT_TRUE(feed != nullptr);
ASSERT_TRUE(neg != nullptr);
- // Import it again, with an input mapping and return outputs, into the same
- // graph.
+ // Import it again, with an input mapping, return outputs, and a return
+ // operation, into the same graph.
TF_DeleteImportGraphDefOptions(opts);
opts = TF_NewImportGraphDefOptions();
TF_ImportGraphDefOptionsSetPrefix(opts, "imported2");
@@ -597,9 +597,10 @@ TEST(CAPI, ImportGraphDef) {
TF_ImportGraphDefOptionsAddReturnOutput(opts, "feed", 0);
TF_ImportGraphDefOptionsAddReturnOutput(opts, "scalar", 0);
EXPECT_EQ(2, TF_ImportGraphDefOptionsNumReturnOutputs(opts));
- TF_Output return_outputs[2];
- TF_GraphImportGraphDefWithReturnOutputs(graph, graph_def, opts,
- return_outputs, 2, s);
+ TF_ImportGraphDefOptionsAddReturnOperation(opts, "scalar");
+ EXPECT_EQ(1, TF_ImportGraphDefOptionsNumReturnOperations(opts));
+ TF_ImportGraphDefResults* results =
+ TF_GraphImportGraphDefWithResults(graph, graph_def, opts, s);
ASSERT_EQ(TF_OK, TF_GetCode(s)) << TF_Message(s);
TF_Operation* scalar2 = TF_GraphOperationByName(graph, "imported2/scalar");
@@ -615,11 +616,26 @@ TEST(CAPI, ImportGraphDef) {
EXPECT_EQ(0, neg_input.index);
// Check return outputs
+ TF_Output* return_outputs;
+ int num_return_outputs;
+ TF_ImportGraphDefResultsReturnOutputs(results, &num_return_outputs,
+ &return_outputs);
+ ASSERT_EQ(2, num_return_outputs);
EXPECT_EQ(feed2, return_outputs[0].oper);
EXPECT_EQ(0, return_outputs[0].index);
EXPECT_EQ(scalar, return_outputs[1].oper); // remapped
EXPECT_EQ(0, return_outputs[1].index);
+ // Check return operation
+ TF_Operation** return_opers;
+ int num_return_opers;
+ TF_ImportGraphDefResultsReturnOperations(results, &num_return_opers,
+ &return_opers);
+ ASSERT_EQ(1, num_return_opers);
+ EXPECT_EQ(scalar2, return_opers[0]); // not remapped
+
+ TF_DeleteImportGraphDefResults(results);
+
// Import again, with control dependencies, into the same graph.
TF_DeleteImportGraphDefOptions(opts);
opts = TF_NewImportGraphDefOptions();
@@ -689,6 +705,113 @@ TEST(CAPI, ImportGraphDef) {
TF_DeleteStatus(s);
}
+TEST(CAPI, ImportGraphDef_WithReturnOutputs) {
+ TF_Status* s = TF_NewStatus();
+ TF_Graph* graph = TF_NewGraph();
+
+ // Create a graph with two nodes: x and 3
+ Placeholder(graph, s);
+ ASSERT_EQ(TF_OK, TF_GetCode(s)) << TF_Message(s);
+ ASSERT_TRUE(TF_GraphOperationByName(graph, "feed") != nullptr);
+ TF_Operation* oper = ScalarConst(3, graph, s);
+ ASSERT_EQ(TF_OK, TF_GetCode(s)) << TF_Message(s);
+ ASSERT_TRUE(TF_GraphOperationByName(graph, "scalar") != nullptr);
+ Neg(oper, graph, s);
+ ASSERT_EQ(TF_OK, TF_GetCode(s)) << TF_Message(s);
+ ASSERT_TRUE(TF_GraphOperationByName(graph, "neg") != nullptr);
+
+ // Export to a GraphDef.
+ TF_Buffer* graph_def = TF_NewBuffer();
+ TF_GraphToGraphDef(graph, graph_def, s);
+ ASSERT_EQ(TF_OK, TF_GetCode(s)) << TF_Message(s);
+
+ // Import it in a fresh graph with return outputs.
+ TF_DeleteGraph(graph);
+ graph = TF_NewGraph();
+ TF_ImportGraphDefOptions* opts = TF_NewImportGraphDefOptions();
+ TF_ImportGraphDefOptionsAddReturnOutput(opts, "feed", 0);
+ TF_ImportGraphDefOptionsAddReturnOutput(opts, "scalar", 0);
+ EXPECT_EQ(2, TF_ImportGraphDefOptionsNumReturnOutputs(opts));
+ TF_Output return_outputs[2];
+ TF_GraphImportGraphDefWithReturnOutputs(graph, graph_def, opts,
+ return_outputs, 2, s);
+ ASSERT_EQ(TF_OK, TF_GetCode(s)) << TF_Message(s);
+
+ TF_Operation* scalar = TF_GraphOperationByName(graph, "scalar");
+ TF_Operation* feed = TF_GraphOperationByName(graph, "feed");
+ TF_Operation* neg = TF_GraphOperationByName(graph, "neg");
+ ASSERT_TRUE(scalar != nullptr);
+ ASSERT_TRUE(feed != nullptr);
+ ASSERT_TRUE(neg != nullptr);
+
+ // Check return outputs
+ EXPECT_EQ(feed, return_outputs[0].oper);
+ EXPECT_EQ(0, return_outputs[0].index);
+ EXPECT_EQ(scalar, return_outputs[1].oper);
+ EXPECT_EQ(0, return_outputs[1].index);
+
+ TF_DeleteImportGraphDefOptions(opts);
+ TF_DeleteBuffer(graph_def);
+ TF_DeleteGraph(graph);
+ TF_DeleteStatus(s);
+}
+
+TEST(CAPI, ImportGraphDef_UnusedInputMappings) {
+ TF_Status* s = TF_NewStatus();
+ TF_Graph* graph = TF_NewGraph();
+
+ // Create a graph with two nodes: x and 3
+ Placeholder(graph, s);
+ ASSERT_EQ(TF_OK, TF_GetCode(s)) << TF_Message(s);
+ ASSERT_TRUE(TF_GraphOperationByName(graph, "feed") != nullptr);
+ TF_Operation* oper = ScalarConst(3, graph, s);
+ ASSERT_EQ(TF_OK, TF_GetCode(s)) << TF_Message(s);
+ ASSERT_TRUE(TF_GraphOperationByName(graph, "scalar") != nullptr);
+ Neg(oper, graph, s);
+ ASSERT_EQ(TF_OK, TF_GetCode(s)) << TF_Message(s);
+ ASSERT_TRUE(TF_GraphOperationByName(graph, "neg") != nullptr);
+
+ // Export to a GraphDef.
+ TF_Buffer* graph_def = TF_NewBuffer();
+ TF_GraphToGraphDef(graph, graph_def, s);
+ ASSERT_EQ(TF_OK, TF_GetCode(s)) << TF_Message(s);
+
+ // Import it in a fresh graph.
+ TF_DeleteGraph(graph);
+ graph = TF_NewGraph();
+ TF_ImportGraphDefOptions* opts = TF_NewImportGraphDefOptions();
+ TF_GraphImportGraphDef(graph, graph_def, opts, s);
+ ASSERT_EQ(TF_OK, TF_GetCode(s)) << TF_Message(s);
+
+ TF_Operation* scalar = TF_GraphOperationByName(graph, "scalar");
+
+ // Import it in a fresh graph with an unused input mapping.
+ TF_DeleteImportGraphDefOptions(opts);
+ opts = TF_NewImportGraphDefOptions();
+ TF_ImportGraphDefOptionsSetPrefix(opts, "imported");
+ TF_ImportGraphDefOptionsAddInputMapping(opts, "scalar", 0, {scalar, 0});
+ TF_ImportGraphDefOptionsAddInputMapping(opts, "fake", 0, {scalar, 0});
+ TF_ImportGraphDefResults* results =
+ TF_GraphImportGraphDefWithResults(graph, graph_def, opts, s);
+ ASSERT_EQ(TF_OK, TF_GetCode(s)) << TF_Message(s);
+
+ // Check unused input mappings
+ int num_unused_input_mappings;
+ const char** src_names;
+ int* src_indexes;
+ TF_ImportGraphDefResultsUnusedInputMappings(
+ results, &num_unused_input_mappings, &src_names, &src_indexes);
+ ASSERT_EQ(1, num_unused_input_mappings);
+ EXPECT_EQ(string("fake"), string(src_names[0]));
+ EXPECT_EQ(0, src_indexes[0]);
+
+ TF_DeleteImportGraphDefResults(results);
+ TF_DeleteImportGraphDefOptions(opts);
+ TF_DeleteBuffer(graph_def);
+ TF_DeleteGraph(graph);
+ TF_DeleteStatus(s);
+}
+
TEST(CAPI, Session) {
TF_Status* s = TF_NewStatus();
TF_Graph* graph = TF_NewGraph();