author    Igor Ganichev <iga@google.com>  2017-09-18 20:19:13 -0700
committer TensorFlower Gardener <gardener@tensorflow.org>  2017-09-18 20:22:15 -0700
commit    1da763a1cc94fc5e4ad1822788b444b77623538c (patch)
tree      414b168e2f3b1dc8be71c58b71e36c34818ff1c6 /tensorflow/c/c_api_function_test.cc
parent    d10902f0a947da40f80479d74e9a487617759085 (diff)
Add function gradient support to C API
Also, change the internal representation of TF_Function and rename TF_GraphAddFunction to TF_GraphCopyFunction to make it clear that a copy of the function is added to the graph. Any subsequent modifications to the function will not be reflected in the copy added to the graph.

PiperOrigin-RevId: 169187793
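As a usage note (not part of the diff below): a minimal sketch of the renamed API, assuming the TF_Function handles were built elsewhere (e.g. via TF_GraphToFunction on a separate function graph). The helper name RegisterFuncWithGradient is illustrative only; it shows the copy semantics described above, so the original handles can be deleted once the graph holds its copies.

// Minimal sketch, not from this change; assumes `func` and `grad` were
// created elsewhere (e.g. with TF_GraphToFunction). The helper name is
// hypothetical.
#include <cstdio>
#include "tensorflow/c/c_api.h"

void RegisterFuncWithGradient(TF_Graph* graph, TF_Function* func,
                              TF_Function* grad) {
  TF_Status* s = TF_NewStatus();
  // Adds copies of `func` and `grad` to `graph` and records `grad` as the
  // gradient of `func`; later edits to the originals do not affect the graph.
  TF_GraphCopyFunction(graph, func, grad, s);
  if (TF_GetCode(s) != TF_OK) {
    fprintf(stderr, "TF_GraphCopyFunction failed: %s\n", TF_Message(s));
  }
  TF_DeleteStatus(s);
  // The graph owns its copies, so the original handles can be released now.
  TF_DeleteFunction(func);
  TF_DeleteFunction(grad);
}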
Diffstat (limited to 'tensorflow/c/c_api_function_test.cc')
-rw-r--r--  tensorflow/c/c_api_function_test.cc | 159
1 file changed, 158 insertions(+), 1 deletion(-)
diff --git a/tensorflow/c/c_api_function_test.cc b/tensorflow/c/c_api_function_test.cc
index c9dd38ea15..88d2f1bd27 100644
--- a/tensorflow/c/c_api_function_test.cc
+++ b/tensorflow/c/c_api_function_test.cc
@@ -174,7 +174,7 @@ class CApiFunctionTest : public ::testing::Test {
ASSERT_EQ(TF_OK, TF_GetCode(s_)) << TF_Message(s_);
ASSERT_NE(func_, nullptr);
- TF_GraphAddFunction(host_graph_, func_, s_);
+ TF_GraphCopyFunction(host_graph_, func_, nullptr, s_);
ASSERT_EQ(TF_OK, TF_GetCode(s_)) << TF_Message(s_);
}
@@ -1035,5 +1035,162 @@ TEST_F(CApiFunctionTest, OutputOpNotInBody) {
string(TF_Message(s_)));
}
+void DefineFunction(const char* name, TF_Function** func) {
+ std::unique_ptr<TF_Graph, decltype(&TF_DeleteGraph)> func_graph(
+ TF_NewGraph(), TF_DeleteGraph);
+ std::unique_ptr<TF_Status, decltype(&TF_DeleteStatus)> s(TF_NewStatus(),
+ TF_DeleteStatus);
+
+ TF_Operation* feed = Placeholder(func_graph.get(), s.get());
+ TF_Operation* neg = Neg(feed, func_graph.get(), s.get());
+
+ TF_Output inputs[] = {{feed, 0}};
+ TF_Output outputs[] = {{neg, 0}};
+ *func = TF_GraphToFunction(func_graph.get(), name, -1,
+ /*opers=*/nullptr, 1, inputs, 1, outputs,
+ /*output_names=*/nullptr,
+ /*opts=*/nullptr, s.get());
+ ASSERT_EQ(TF_OK, TF_GetCode(s.get())) << TF_Message(s.get());
+ ASSERT_NE(*func, nullptr);
+}
+
+TEST_F(CApiFunctionTest, SetGradientAndRun) {
+ // Define the function and its grad
+ DefineFunction(func_name_, &func_);
+ TF_Function* grad_func;
+ DefineFunction("MyGrad", &grad_func);
+
+ // Add func and its gradient to host graph
+ TF_GraphCopyFunction(host_graph_, func_, grad_func, s_);
+ ASSERT_EQ(TF_OK, TF_GetCode(s_)) << TF_Message(s_);
+
+ // Verify that function and its grad are in host graph's GraphDef
+ GraphDef gdef;
+ GetGraphDef(host_graph_, &gdef);
+ std::vector<string> func_names = GetFuncNames(gdef);
+ ASSERT_EQ(2, func_names.size());
+ ASSERT_EQ(func_name_, func_names[0]);
+ ASSERT_EQ("MyGrad", func_names[1]);
+ std::vector<std::pair<string, string>> grads = GetGradDefs(gdef);
+ ASSERT_EQ(1, grads.size());
+ ASSERT_EQ(func_name_, grads[0].first);
+ ASSERT_EQ("MyGrad", grads[0].second);
+
+ // These calls must be noops
+ TF_GraphCopyFunction(host_graph_, func_, grad_func, s_);
+ ASSERT_EQ(TF_OK, TF_GetCode(s_)) << TF_Message(s_);
+ TF_GraphCopyFunction(host_graph_, func_, nullptr, s_);
+ ASSERT_EQ(TF_OK, TF_GetCode(s_)) << TF_Message(s_);
+
+ // Delete the gradient func.
+ // It is safe to delete after adding a copy to host graph.
+ TF_DeleteFunction(grad_func);
+
+ // Check that GraphDef did not change
+ GraphDef gdef2;
+ GetGraphDef(host_graph_, &gdef2);
+ ASSERT_EQ(gdef.DebugString(), gdef2.DebugString());
+
+ // Use and run func
+ TF_Operation* func_feed = Placeholder(host_graph_, s_);
+ TF_Operation* func_op = Use({func_feed});
+ Run({{func_feed, Int32Tensor(3)}}, func_op, -3);
+}
+
+TEST_F(CApiFunctionTest, SameGradForTwoFunctions) {
+ // Define the functions
+ TF_Function* func1;
+ TF_Function* func2;
+ TF_Function* grad_func;
+ DefineFunction("FooFunc1", &func1);
+ DefineFunction("FooFunc2", &func2);
+ DefineFunction("MyGrad", &grad_func);
+
+ // Make grad_func be a gradient of func1 and func2
+ TF_GraphCopyFunction(host_graph_, func1, grad_func, s_);
+ ASSERT_EQ(TF_OK, TF_GetCode(s_)) << TF_Message(s_);
+ TF_GraphCopyFunction(host_graph_, func2, grad_func, s_);
+ ASSERT_EQ(TF_OK, TF_GetCode(s_)) << TF_Message(s_);
+
+ // Verify that functions and their gradients are in host graph's GraphDef
+ GraphDef gdef;
+ GetGraphDef(host_graph_, &gdef);
+ std::vector<std::pair<string, string>> grads = GetGradDefs(gdef);
+ ASSERT_EQ(2, grads.size());
+ ASSERT_EQ("FooFunc1", grads[0].first);
+ ASSERT_EQ("MyGrad", grads[0].second);
+ ASSERT_EQ("FooFunc2", grads[1].first);
+ ASSERT_EQ("MyGrad", grads[1].second);
+
+ TF_DeleteFunction(func1);
+ TF_DeleteFunction(func2);
+ TF_DeleteFunction(grad_func);
+}
+
+TEST_F(CApiFunctionTest, AddFunctionsThenMakeOneGradientOfAnother) {
+ // Define the functions
+ TF_Function* func;
+ TF_Function* grad_func;
+ DefineFunction("FooFunc", &func);
+ DefineFunction("MyGrad", &grad_func);
+
+ // Add functions individually
+ TF_GraphCopyFunction(host_graph_, func, nullptr, s_);
+ ASSERT_EQ(TF_OK, TF_GetCode(s_)) << TF_Message(s_);
+ TF_GraphCopyFunction(host_graph_, grad_func, nullptr, s_);
+ ASSERT_EQ(TF_OK, TF_GetCode(s_)) << TF_Message(s_);
+
+ // Check that functions are added but not linked
+ GraphDef gdef;
+ GetGraphDef(host_graph_, &gdef);
+ std::vector<string> func_names = GetFuncNames(gdef);
+ ASSERT_EQ(2, func_names.size());
+ ASSERT_EQ("FooFunc", func_names[0]);
+ ASSERT_EQ("MyGrad", func_names[1]);
+ ASSERT_EQ(0, GetGradDefs(gdef).size());
+
+ // Make grad_func a gradient of func
+ TF_GraphCopyFunction(host_graph_, func, grad_func, s_);
+ ASSERT_EQ(TF_OK, TF_GetCode(s_)) << TF_Message(s_);
+
+ // Verify that function and its grad are linked
+ gdef.Clear();
+ GetGraphDef(host_graph_, &gdef);
+ std::vector<std::pair<string, string>> grads = GetGradDefs(gdef);
+ ASSERT_EQ(1, grads.size());
+ ASSERT_EQ("FooFunc", grads[0].first);
+ ASSERT_EQ("MyGrad", grads[0].second);
+
+ TF_DeleteFunction(func);
+ TF_DeleteFunction(grad_func);
+}
+
+TEST_F(CApiFunctionTest, GradientErrorCases) {
+ // Define the function
+ DefineFunction(func_name_, &func_);
+ TF_Function* grad_func1;
+ TF_Function* grad_func2;
+ DefineFunction("MyGrad1", &grad_func1);
+ DefineFunction("MyGrad2", &grad_func2);
+
+ // func cannot be null
+ TF_GraphCopyFunction(host_graph_, nullptr, func_, s_);
+ EXPECT_EQ(TF_INVALID_ARGUMENT, TF_GetCode(s_));
+ EXPECT_EQ(string("'func' argument to TF_GraphCopyFunction cannot be null"),
+ string(TF_Message(s_)));
+
+ // Cannot change gradient
+ TF_GraphCopyFunction(host_graph_, func_, grad_func1, s_);
+ ASSERT_EQ(TF_OK, TF_GetCode(s_)) << TF_Message(s_);
+ TF_GraphCopyFunction(host_graph_, func_, grad_func2, s_);
+ EXPECT_EQ(TF_INVALID_ARGUMENT, TF_GetCode(s_));
+ EXPECT_EQ(string("Cannot assign gradient function 'MyGrad2' to 'MyFunc' "
+ "because it already has gradient function 'MyGrad1'"),
+ string(TF_Message(s_)));
+
+ TF_DeleteFunction(grad_func1);
+ TF_DeleteFunction(grad_func2);
+}
+
} // namespace
} // namespace tensorflow