diff options
-rw-r--r-- | tensorflow/c/c_api.h | 2 | ||||
-rw-r--r-- | tensorflow/c/c_api_test.cc | 27 |
2 files changed, 28 insertions, 1 deletions
diff --git a/tensorflow/c/c_api.h b/tensorflow/c/c_api.h index a686f7f701..9b08f9d981 100644 --- a/tensorflow/c/c_api.h +++ b/tensorflow/c/c_api.h @@ -818,7 +818,7 @@ extern void TF_ImportGraphDefOptionsAddInputMapping( // replaced with `dst`. `src_name` refers to a node in the graph to be imported, // `dst` references an operation already existing in the graph being imported // into. -extern void TF_GraphImportGraphDefOptionsRemapControlDependency( +extern void TF_ImportGraphDefOptionsRemapControlDependency( TF_ImportGraphDefOptions* opts, const char* src_name, TF_Operation* dst); // Cause the imported graph to have a control dependency on `oper`. `oper` diff --git a/tensorflow/c/c_api_test.cc b/tensorflow/c/c_api_test.cc index 32143f4f2f..d846daa71b 100644 --- a/tensorflow/c/c_api_test.cc +++ b/tensorflow/c/c_api_test.cc @@ -821,6 +821,33 @@ TEST(CAPI, ImportGraphDef) { EXPECT_EQ(feed, control_inputs[0]); EXPECT_EQ(feed2, control_inputs[1]); + // Export to a graph def so we can import a graph with control dependencies + TF_DeleteBuffer(graph_def); + graph_def = TF_NewBuffer(); + TF_GraphToGraphDef(graph, graph_def, s); + ASSERT_EQ(TF_OK, TF_GetCode(s)) << TF_Message(s); + + // Import again, with remapped control dependency, into the same graph + TF_DeleteImportGraphDefOptions(opts); + opts = TF_NewImportGraphDefOptions(); + TF_ImportGraphDefOptionsSetPrefix(opts, "imported4"); + TF_ImportGraphDefOptionsRemapControlDependency(opts, "imported/feed", feed); + TF_GraphImportGraphDef(graph, graph_def, opts, s); + ASSERT_EQ(TF_OK, TF_GetCode(s)) << TF_Message(s); + + TF_Operation* scalar4 = + TF_GraphOperationByName(graph, "imported4/imported3/scalar"); + TF_Operation* feed4 = + TF_GraphOperationByName(graph, "imported4/imported2/feed"); + + // Check that imported `imported3/scalar` has remapped control dep from + // original graph and imported control dep + num_control_inputs = TF_OperationGetControlInputs( + scalar4, control_inputs, TF_OperationNumControlInputs(scalar4)); + ASSERT_EQ(2, num_control_inputs); + EXPECT_EQ(feed, control_inputs[0]); + EXPECT_EQ(feed4, control_inputs[1]); + TF_DeleteImportGraphDefOptions(opts); TF_DeleteBuffer(graph_def); |