diff options
author | Skye Wanderman-Milne <skyewm@google.com> | 2017-10-30 08:07:11 -0700 |
---|---|---|
committer | TensorFlower Gardener <gardener@tensorflow.org> | 2017-10-30 08:10:56 -0700 |
commit | ce0238198052358d102ca7786ad9be60a5e76d28 (patch) | |
tree | b1694c3fe23b4933b7967f9494cb7337e673b07e /tensorflow/c/c_api_test.cc | |
parent | ef4490f637e17f3ce599f55522e63d06f470e540 (diff) |
Add ability to fetch return nodes and unused input mappings from C API GraphDef import
This change introduces yet another ImportGraphDef function to the C
API (TF_GraphImportGraphDefWithResults), but this one has extensible
return values so we shouldn't have to add more in the future.
This change also modifies the ImportGraphDef C interface to manage all
string data for the user.
PiperOrigin-RevId: 173894710
Diffstat (limited to 'tensorflow/c/c_api_test.cc')
-rw-r--r-- | tensorflow/c/c_api_test.cc | 135 |
1 files changed, 129 insertions, 6 deletions
diff --git a/tensorflow/c/c_api_test.cc b/tensorflow/c/c_api_test.cc index d220bc5e95..05881e619b 100644 --- a/tensorflow/c/c_api_test.cc +++ b/tensorflow/c/c_api_test.cc @@ -573,7 +573,7 @@ TEST(CAPI, ImportGraphDef) { TF_GraphToGraphDef(graph, graph_def, s); ASSERT_EQ(TF_OK, TF_GetCode(s)) << TF_Message(s); - // Import it again, with a prefix, in a fresh graph. + // Import it, with a prefix, in a fresh graph. TF_DeleteGraph(graph); graph = TF_NewGraph(); TF_ImportGraphDefOptions* opts = TF_NewImportGraphDefOptions(); @@ -588,8 +588,8 @@ TEST(CAPI, ImportGraphDef) { ASSERT_TRUE(feed != nullptr); ASSERT_TRUE(neg != nullptr); - // Import it again, with an input mapping and return outputs, into the same - // graph. + // Import it again, with an input mapping, return outputs, and a return + // operation, into the same graph. TF_DeleteImportGraphDefOptions(opts); opts = TF_NewImportGraphDefOptions(); TF_ImportGraphDefOptionsSetPrefix(opts, "imported2"); @@ -597,9 +597,10 @@ TEST(CAPI, ImportGraphDef) { TF_ImportGraphDefOptionsAddReturnOutput(opts, "feed", 0); TF_ImportGraphDefOptionsAddReturnOutput(opts, "scalar", 0); EXPECT_EQ(2, TF_ImportGraphDefOptionsNumReturnOutputs(opts)); - TF_Output return_outputs[2]; - TF_GraphImportGraphDefWithReturnOutputs(graph, graph_def, opts, - return_outputs, 2, s); + TF_ImportGraphDefOptionsAddReturnOperation(opts, "scalar"); + EXPECT_EQ(1, TF_ImportGraphDefOptionsNumReturnOperations(opts)); + TF_ImportGraphDefResults* results = + TF_GraphImportGraphDefWithResults(graph, graph_def, opts, s); ASSERT_EQ(TF_OK, TF_GetCode(s)) << TF_Message(s); TF_Operation* scalar2 = TF_GraphOperationByName(graph, "imported2/scalar"); @@ -615,11 +616,26 @@ TEST(CAPI, ImportGraphDef) { EXPECT_EQ(0, neg_input.index); // Check return outputs + TF_Output* return_outputs; + int num_return_outputs; + TF_ImportGraphDefResultsReturnOutputs(results, &num_return_outputs, + &return_outputs); + ASSERT_EQ(2, num_return_outputs); EXPECT_EQ(feed2, return_outputs[0].oper); EXPECT_EQ(0, return_outputs[0].index); EXPECT_EQ(scalar, return_outputs[1].oper); // remapped EXPECT_EQ(0, return_outputs[1].index); + // Check return operation + TF_Operation** return_opers; + int num_return_opers; + TF_ImportGraphDefResultsReturnOperations(results, &num_return_opers, + &return_opers); + ASSERT_EQ(1, num_return_opers); + EXPECT_EQ(scalar2, return_opers[0]); // not remapped + + TF_DeleteImportGraphDefResults(results); + // Import again, with control dependencies, into the same graph. TF_DeleteImportGraphDefOptions(opts); opts = TF_NewImportGraphDefOptions(); @@ -689,6 +705,113 @@ TEST(CAPI, ImportGraphDef) { TF_DeleteStatus(s); } +TEST(CAPI, ImportGraphDef_WithReturnOutputs) { + TF_Status* s = TF_NewStatus(); + TF_Graph* graph = TF_NewGraph(); + + // Create a graph with two nodes: x and 3 + Placeholder(graph, s); + ASSERT_EQ(TF_OK, TF_GetCode(s)) << TF_Message(s); + ASSERT_TRUE(TF_GraphOperationByName(graph, "feed") != nullptr); + TF_Operation* oper = ScalarConst(3, graph, s); + ASSERT_EQ(TF_OK, TF_GetCode(s)) << TF_Message(s); + ASSERT_TRUE(TF_GraphOperationByName(graph, "scalar") != nullptr); + Neg(oper, graph, s); + ASSERT_EQ(TF_OK, TF_GetCode(s)) << TF_Message(s); + ASSERT_TRUE(TF_GraphOperationByName(graph, "neg") != nullptr); + + // Export to a GraphDef. + TF_Buffer* graph_def = TF_NewBuffer(); + TF_GraphToGraphDef(graph, graph_def, s); + ASSERT_EQ(TF_OK, TF_GetCode(s)) << TF_Message(s); + + // Import it in a fresh graph with return outputs. + TF_DeleteGraph(graph); + graph = TF_NewGraph(); + TF_ImportGraphDefOptions* opts = TF_NewImportGraphDefOptions(); + TF_ImportGraphDefOptionsAddReturnOutput(opts, "feed", 0); + TF_ImportGraphDefOptionsAddReturnOutput(opts, "scalar", 0); + EXPECT_EQ(2, TF_ImportGraphDefOptionsNumReturnOutputs(opts)); + TF_Output return_outputs[2]; + TF_GraphImportGraphDefWithReturnOutputs(graph, graph_def, opts, + return_outputs, 2, s); + ASSERT_EQ(TF_OK, TF_GetCode(s)) << TF_Message(s); + + TF_Operation* scalar = TF_GraphOperationByName(graph, "scalar"); + TF_Operation* feed = TF_GraphOperationByName(graph, "feed"); + TF_Operation* neg = TF_GraphOperationByName(graph, "neg"); + ASSERT_TRUE(scalar != nullptr); + ASSERT_TRUE(feed != nullptr); + ASSERT_TRUE(neg != nullptr); + + // Check return outputs + EXPECT_EQ(feed, return_outputs[0].oper); + EXPECT_EQ(0, return_outputs[0].index); + EXPECT_EQ(scalar, return_outputs[1].oper); + EXPECT_EQ(0, return_outputs[1].index); + + TF_DeleteImportGraphDefOptions(opts); + TF_DeleteBuffer(graph_def); + TF_DeleteGraph(graph); + TF_DeleteStatus(s); +} + +TEST(CAPI, ImportGraphDef_UnusedInputMappings) { + TF_Status* s = TF_NewStatus(); + TF_Graph* graph = TF_NewGraph(); + + // Create a graph with two nodes: x and 3 + Placeholder(graph, s); + ASSERT_EQ(TF_OK, TF_GetCode(s)) << TF_Message(s); + ASSERT_TRUE(TF_GraphOperationByName(graph, "feed") != nullptr); + TF_Operation* oper = ScalarConst(3, graph, s); + ASSERT_EQ(TF_OK, TF_GetCode(s)) << TF_Message(s); + ASSERT_TRUE(TF_GraphOperationByName(graph, "scalar") != nullptr); + Neg(oper, graph, s); + ASSERT_EQ(TF_OK, TF_GetCode(s)) << TF_Message(s); + ASSERT_TRUE(TF_GraphOperationByName(graph, "neg") != nullptr); + + // Export to a GraphDef. + TF_Buffer* graph_def = TF_NewBuffer(); + TF_GraphToGraphDef(graph, graph_def, s); + ASSERT_EQ(TF_OK, TF_GetCode(s)) << TF_Message(s); + + // Import it in a fresh graph. + TF_DeleteGraph(graph); + graph = TF_NewGraph(); + TF_ImportGraphDefOptions* opts = TF_NewImportGraphDefOptions(); + TF_GraphImportGraphDef(graph, graph_def, opts, s); + ASSERT_EQ(TF_OK, TF_GetCode(s)) << TF_Message(s); + + TF_Operation* scalar = TF_GraphOperationByName(graph, "scalar"); + + // Import it in a fresh graph with an unused input mapping. + TF_DeleteImportGraphDefOptions(opts); + opts = TF_NewImportGraphDefOptions(); + TF_ImportGraphDefOptionsSetPrefix(opts, "imported"); + TF_ImportGraphDefOptionsAddInputMapping(opts, "scalar", 0, {scalar, 0}); + TF_ImportGraphDefOptionsAddInputMapping(opts, "fake", 0, {scalar, 0}); + TF_ImportGraphDefResults* results = + TF_GraphImportGraphDefWithResults(graph, graph_def, opts, s); + ASSERT_EQ(TF_OK, TF_GetCode(s)) << TF_Message(s); + + // Check unused input mappings + int num_unused_input_mappings; + const char** src_names; + int* src_indexes; + TF_ImportGraphDefResultsUnusedInputMappings( + results, &num_unused_input_mappings, &src_names, &src_indexes); + ASSERT_EQ(1, num_unused_input_mappings); + EXPECT_EQ(string("fake"), string(src_names[0])); + EXPECT_EQ(0, src_indexes[0]); + + TF_DeleteImportGraphDefResults(results); + TF_DeleteImportGraphDefOptions(opts); + TF_DeleteBuffer(graph_def); + TF_DeleteGraph(graph); + TF_DeleteStatus(s); +} + TEST(CAPI, Session) { TF_Status* s = TF_NewStatus(); TF_Graph* graph = TF_NewGraph(); |