aboutsummaryrefslogtreecommitdiffhomepage
path: root/tensorflow/c/c_api_test.cc
diff options
context:
space:
mode:
authorGravatar Skye Wanderman-Milne <skyewm@google.com>2017-01-20 13:25:27 -0800
committerGravatar TensorFlower Gardener <gardener@tensorflow.org>2017-01-20 13:48:06 -0800
commitbb71ec089658fb8a91423a7cf7195e5c900c2c98 (patch)
treecc0dc350bbd9d1cd5f87bcc1441fe616bb00dabe /tensorflow/c/c_api_test.cc
parent177684c5002ae877c25b9a9d0654347ec6a27c9c (diff)
Expose more ImportGraphDef functionality in the C API.
This patch addes the following methods to the C API: TF_ImportGraphDefOptionsAddInputMapping() TF_ImportGraphDefOptionsAddControlDependency() TF_ImportGraphDefOptionsAddReturnOutput() TF_ImportGraphDefOptionsNumReturnOutputs() TF_GraphImportGraphDefWithReturnOutputs() Change: 145120572
Diffstat (limited to 'tensorflow/c/c_api_test.cc')
-rw-r--r--tensorflow/c/c_api_test.cc76
1 files changed, 72 insertions, 4 deletions
diff --git a/tensorflow/c/c_api_test.cc b/tensorflow/c/c_api_test.cc
index 4ea53ab230..22026f81aa 100644
--- a/tensorflow/c/c_api_test.cc
+++ b/tensorflow/c/c_api_test.cc
@@ -675,9 +675,12 @@ TEST(CAPI, ImportGraphDef) {
Placeholder(graph, s);
ASSERT_EQ(TF_OK, TF_GetCode(s)) << TF_Message(s);
ASSERT_TRUE(TF_GraphOperationByName(graph, "feed") != nullptr);
- ScalarConst(3, graph, s);
+ TF_Operation* oper = ScalarConst(3, graph, s);
ASSERT_EQ(TF_OK, TF_GetCode(s)) << TF_Message(s);
ASSERT_TRUE(TF_GraphOperationByName(graph, "scalar") != nullptr);
+ Neg(oper, graph, s);
+ ASSERT_EQ(TF_OK, TF_GetCode(s)) << TF_Message(s);
+ ASSERT_TRUE(TF_GraphOperationByName(graph, "neg") != nullptr);
// Export to a GraphDef
TF_Buffer* graph_def = TF_NewBuffer();
@@ -692,13 +695,78 @@ TEST(CAPI, ImportGraphDef) {
TF_GraphImportGraphDef(graph, graph_def, opts, s);
ASSERT_EQ(TF_OK, TF_GetCode(s)) << TF_Message(s);
- TF_DeleteImportGraphDefOptions(opts);
- TF_DeleteBuffer(graph_def);
-
TF_Operation* scalar = TF_GraphOperationByName(graph, "imported/scalar");
TF_Operation* feed = TF_GraphOperationByName(graph, "imported/feed");
+ TF_Operation* neg = TF_GraphOperationByName(graph, "imported/neg");
ASSERT_TRUE(scalar != nullptr);
ASSERT_TRUE(feed != nullptr);
+ ASSERT_TRUE(neg != nullptr);
+
+ // Import it again, with an input mapping and return outputs, into the same
+ // graph.
+ TF_DeleteImportGraphDefOptions(opts);
+ opts = TF_NewImportGraphDefOptions();
+ TF_ImportGraphDefOptionsSetPrefix(opts, "imported2");
+ TF_ImportGraphDefOptionsAddInputMapping(opts, "scalar", 0, {scalar, 0});
+ TF_ImportGraphDefOptionsAddReturnOutput(opts, "feed", 0);
+ TF_ImportGraphDefOptionsAddReturnOutput(opts, "scalar", 0);
+ EXPECT_EQ(2, TF_ImportGraphDefOptionsNumReturnOutputs(opts));
+ TF_Output return_outputs[2];
+ TF_GraphImportGraphDefWithReturnOutputs(graph, graph_def, opts,
+ return_outputs, 2, s);
+ ASSERT_EQ(TF_OK, TF_GetCode(s)) << TF_Message(s);
+
+ TF_Operation* scalar2 = TF_GraphOperationByName(graph, "imported2/scalar");
+ TF_Operation* feed2 = TF_GraphOperationByName(graph, "imported2/feed");
+ TF_Operation* neg2 = TF_GraphOperationByName(graph, "imported2/neg");
+ ASSERT_TRUE(scalar2 != nullptr);
+ ASSERT_TRUE(feed2 != nullptr);
+ ASSERT_TRUE(neg2 != nullptr);
+
+ // Check input mapping
+ TF_Output neg_input = TF_OperationInput({neg, 0});
+ EXPECT_EQ(scalar, neg_input.oper);
+ EXPECT_EQ(0, neg_input.index);
+
+ // Check return outputs
+ EXPECT_EQ(feed2, return_outputs[0].oper);
+ EXPECT_EQ(0, return_outputs[0].index);
+ EXPECT_EQ(scalar, return_outputs[1].oper); // remapped
+ EXPECT_EQ(0, return_outputs[1].index);
+
+ // Import again, with control dependencies, into the same graph.
+ TF_DeleteImportGraphDefOptions(opts);
+ opts = TF_NewImportGraphDefOptions();
+ TF_ImportGraphDefOptionsSetPrefix(opts, "imported3");
+ TF_ImportGraphDefOptionsAddControlDependency(opts, feed);
+ TF_ImportGraphDefOptionsAddControlDependency(opts, feed2);
+ TF_GraphImportGraphDef(graph, graph_def, opts, s);
+ ASSERT_EQ(TF_OK, TF_GetCode(s)) << TF_Message(s);
+
+ TF_Operation* scalar3 = TF_GraphOperationByName(graph, "imported3/scalar");
+ TF_Operation* feed3 = TF_GraphOperationByName(graph, "imported3/feed");
+ TF_Operation* neg3 = TF_GraphOperationByName(graph, "imported3/neg");
+ ASSERT_TRUE(scalar3 != nullptr);
+ ASSERT_TRUE(feed3 != nullptr);
+ ASSERT_TRUE(neg3 != nullptr);
+
+ // Check that newly-imported scalar and feed have control deps (neg3 will
+ // inherit them from input)
+ TF_Operation* control_inputs[100];
+ int num_control_inputs = TF_OperationGetControlInputs(
+ scalar3, control_inputs, TF_OperationNumControlInputs(scalar3));
+ ASSERT_EQ(2, num_control_inputs);
+ EXPECT_EQ(feed, control_inputs[0]);
+ EXPECT_EQ(feed2, control_inputs[1]);
+
+ num_control_inputs = TF_OperationGetControlInputs(
+ feed3, control_inputs, TF_OperationNumControlInputs(feed3));
+ ASSERT_EQ(2, num_control_inputs);
+ EXPECT_EQ(feed, control_inputs[0]);
+ EXPECT_EQ(feed2, control_inputs[1]);
+
+ TF_DeleteImportGraphDefOptions(opts);
+ TF_DeleteBuffer(graph_def);
// Can add nodes to the imported graph without trouble.
Add(feed, scalar, graph, s);