diff options
author | Skye Wanderman-Milne <skyewm@google.com> | 2017-01-20 13:25:27 -0800 |
---|---|---|
committer | TensorFlower Gardener <gardener@tensorflow.org> | 2017-01-20 13:48:06 -0800 |
commit | bb71ec089658fb8a91423a7cf7195e5c900c2c98 (patch) | |
tree | cc0dc350bbd9d1cd5f87bcc1441fe616bb00dabe /tensorflow/c/c_api_test.cc | |
parent | 177684c5002ae877c25b9a9d0654347ec6a27c9c (diff) |
Expose more ImportGraphDef functionality in the C API.
This patch addes the following methods to the C API:
TF_ImportGraphDefOptionsAddInputMapping()
TF_ImportGraphDefOptionsAddControlDependency()
TF_ImportGraphDefOptionsAddReturnOutput()
TF_ImportGraphDefOptionsNumReturnOutputs()
TF_GraphImportGraphDefWithReturnOutputs()
Change: 145120572
Diffstat (limited to 'tensorflow/c/c_api_test.cc')
-rw-r--r-- | tensorflow/c/c_api_test.cc | 76 |
1 files changed, 72 insertions, 4 deletions
diff --git a/tensorflow/c/c_api_test.cc b/tensorflow/c/c_api_test.cc index 4ea53ab230..22026f81aa 100644 --- a/tensorflow/c/c_api_test.cc +++ b/tensorflow/c/c_api_test.cc @@ -675,9 +675,12 @@ TEST(CAPI, ImportGraphDef) { Placeholder(graph, s); ASSERT_EQ(TF_OK, TF_GetCode(s)) << TF_Message(s); ASSERT_TRUE(TF_GraphOperationByName(graph, "feed") != nullptr); - ScalarConst(3, graph, s); + TF_Operation* oper = ScalarConst(3, graph, s); ASSERT_EQ(TF_OK, TF_GetCode(s)) << TF_Message(s); ASSERT_TRUE(TF_GraphOperationByName(graph, "scalar") != nullptr); + Neg(oper, graph, s); + ASSERT_EQ(TF_OK, TF_GetCode(s)) << TF_Message(s); + ASSERT_TRUE(TF_GraphOperationByName(graph, "neg") != nullptr); // Export to a GraphDef TF_Buffer* graph_def = TF_NewBuffer(); @@ -692,13 +695,78 @@ TEST(CAPI, ImportGraphDef) { TF_GraphImportGraphDef(graph, graph_def, opts, s); ASSERT_EQ(TF_OK, TF_GetCode(s)) << TF_Message(s); - TF_DeleteImportGraphDefOptions(opts); - TF_DeleteBuffer(graph_def); - TF_Operation* scalar = TF_GraphOperationByName(graph, "imported/scalar"); TF_Operation* feed = TF_GraphOperationByName(graph, "imported/feed"); + TF_Operation* neg = TF_GraphOperationByName(graph, "imported/neg"); ASSERT_TRUE(scalar != nullptr); ASSERT_TRUE(feed != nullptr); + ASSERT_TRUE(neg != nullptr); + + // Import it again, with an input mapping and return outputs, into the same + // graph. + TF_DeleteImportGraphDefOptions(opts); + opts = TF_NewImportGraphDefOptions(); + TF_ImportGraphDefOptionsSetPrefix(opts, "imported2"); + TF_ImportGraphDefOptionsAddInputMapping(opts, "scalar", 0, {scalar, 0}); + TF_ImportGraphDefOptionsAddReturnOutput(opts, "feed", 0); + TF_ImportGraphDefOptionsAddReturnOutput(opts, "scalar", 0); + EXPECT_EQ(2, TF_ImportGraphDefOptionsNumReturnOutputs(opts)); + TF_Output return_outputs[2]; + TF_GraphImportGraphDefWithReturnOutputs(graph, graph_def, opts, + return_outputs, 2, s); + ASSERT_EQ(TF_OK, TF_GetCode(s)) << TF_Message(s); + + TF_Operation* scalar2 = TF_GraphOperationByName(graph, "imported2/scalar"); + TF_Operation* feed2 = TF_GraphOperationByName(graph, "imported2/feed"); + TF_Operation* neg2 = TF_GraphOperationByName(graph, "imported2/neg"); + ASSERT_TRUE(scalar2 != nullptr); + ASSERT_TRUE(feed2 != nullptr); + ASSERT_TRUE(neg2 != nullptr); + + // Check input mapping + TF_Output neg_input = TF_OperationInput({neg, 0}); + EXPECT_EQ(scalar, neg_input.oper); + EXPECT_EQ(0, neg_input.index); + + // Check return outputs + EXPECT_EQ(feed2, return_outputs[0].oper); + EXPECT_EQ(0, return_outputs[0].index); + EXPECT_EQ(scalar, return_outputs[1].oper); // remapped + EXPECT_EQ(0, return_outputs[1].index); + + // Import again, with control dependencies, into the same graph. + TF_DeleteImportGraphDefOptions(opts); + opts = TF_NewImportGraphDefOptions(); + TF_ImportGraphDefOptionsSetPrefix(opts, "imported3"); + TF_ImportGraphDefOptionsAddControlDependency(opts, feed); + TF_ImportGraphDefOptionsAddControlDependency(opts, feed2); + TF_GraphImportGraphDef(graph, graph_def, opts, s); + ASSERT_EQ(TF_OK, TF_GetCode(s)) << TF_Message(s); + + TF_Operation* scalar3 = TF_GraphOperationByName(graph, "imported3/scalar"); + TF_Operation* feed3 = TF_GraphOperationByName(graph, "imported3/feed"); + TF_Operation* neg3 = TF_GraphOperationByName(graph, "imported3/neg"); + ASSERT_TRUE(scalar3 != nullptr); + ASSERT_TRUE(feed3 != nullptr); + ASSERT_TRUE(neg3 != nullptr); + + // Check that newly-imported scalar and feed have control deps (neg3 will + // inherit them from input) + TF_Operation* control_inputs[100]; + int num_control_inputs = TF_OperationGetControlInputs( + scalar3, control_inputs, TF_OperationNumControlInputs(scalar3)); + ASSERT_EQ(2, num_control_inputs); + EXPECT_EQ(feed, control_inputs[0]); + EXPECT_EQ(feed2, control_inputs[1]); + + num_control_inputs = TF_OperationGetControlInputs( + feed3, control_inputs, TF_OperationNumControlInputs(feed3)); + ASSERT_EQ(2, num_control_inputs); + EXPECT_EQ(feed, control_inputs[0]); + EXPECT_EQ(feed2, control_inputs[1]); + + TF_DeleteImportGraphDefOptions(opts); + TF_DeleteBuffer(graph_def); // Can add nodes to the imported graph without trouble. Add(feed, scalar, graph, s); |