diff options
author | 2017-09-18 20:19:13 -0700 | |
---|---|---|
committer | 2017-09-18 20:22:15 -0700 | |
commit | 1da763a1cc94fc5e4ad1822788b444b77623538c (patch) | |
tree | 414b168e2f3b1dc8be71c58b71e36c34818ff1c6 /tensorflow/c/c_api_function_test.cc | |
parent | d10902f0a947da40f80479d74e9a487617759085 (diff) |
Add function gradient support to C API
Also, change the internal representation of TF_Function and
rename TF_GraphAddFunction to TF_GraphAddFunctionCopy to make it
clear that a copy of the function is added to the graph. Any
subsequent modifications to the function will not be reflected
in the copy added to the graph.
PiperOrigin-RevId: 169187793
Diffstat (limited to 'tensorflow/c/c_api_function_test.cc')
-rw-r--r-- | tensorflow/c/c_api_function_test.cc | 159 |
1 files changed, 158 insertions, 1 deletions
diff --git a/tensorflow/c/c_api_function_test.cc b/tensorflow/c/c_api_function_test.cc index c9dd38ea15..88d2f1bd27 100644 --- a/tensorflow/c/c_api_function_test.cc +++ b/tensorflow/c/c_api_function_test.cc @@ -174,7 +174,7 @@ class CApiFunctionTest : public ::testing::Test { ASSERT_EQ(TF_OK, TF_GetCode(s_)) << TF_Message(s_); ASSERT_NE(func_, nullptr); - TF_GraphAddFunction(host_graph_, func_, s_); + TF_GraphCopyFunction(host_graph_, func_, nullptr, s_); ASSERT_EQ(TF_OK, TF_GetCode(s_)) << TF_Message(s_); } @@ -1035,5 +1035,162 @@ TEST_F(CApiFunctionTest, OutputOpNotInBody) { string(TF_Message(s_))); } +void DefineFunction(const char* name, TF_Function** func) { + std::unique_ptr<TF_Graph, decltype(&TF_DeleteGraph)> func_graph( + TF_NewGraph(), TF_DeleteGraph); + std::unique_ptr<TF_Status, decltype(&TF_DeleteStatus)> s(TF_NewStatus(), + TF_DeleteStatus); + + TF_Operation* feed = Placeholder(func_graph.get(), s.get()); + TF_Operation* neg = Neg(feed, func_graph.get(), s.get()); + + TF_Output inputs[] = {{feed, 0}}; + TF_Output outputs[] = {{neg, 0}}; + *func = TF_GraphToFunction(func_graph.get(), name, -1, + /*opers=*/nullptr, 1, inputs, 1, outputs, + /*output_names=*/nullptr, + /*opts=*/nullptr, s.get()); + ASSERT_EQ(TF_OK, TF_GetCode(s.get())) << TF_Message(s.get()); + ASSERT_NE(*func, nullptr); +} + +TEST_F(CApiFunctionTest, SetGradientAndRun) { + // Define the function and its grad + DefineFunction(func_name_, &func_); + TF_Function* grad_func; + DefineFunction("MyGrad", &grad_func); + + // Add func and its gradient to host graph + TF_GraphCopyFunction(host_graph_, func_, grad_func, s_); + ASSERT_EQ(TF_OK, TF_GetCode(s_)) << TF_Message(s_); + + // Verify that function and its grad are in host graph's GraphDef + GraphDef gdef; + GetGraphDef(host_graph_, &gdef); + std::vector<string> func_names = GetFuncNames(gdef); + ASSERT_EQ(2, func_names.size()); + ASSERT_EQ(func_name_, func_names[0]); + ASSERT_EQ("MyGrad", func_names[1]); + std::vector<std::pair<string, string>> grads = GetGradDefs(gdef); + ASSERT_EQ(1, grads.size()); + ASSERT_EQ(func_name_, grads[0].first); + ASSERT_EQ("MyGrad", grads[0].second); + + // These calls must be noops + TF_GraphCopyFunction(host_graph_, func_, grad_func, s_); + ASSERT_EQ(TF_OK, TF_GetCode(s_)) << TF_Message(s_); + TF_GraphCopyFunction(host_graph_, func_, nullptr, s_); + ASSERT_EQ(TF_OK, TF_GetCode(s_)) << TF_Message(s_); + + // Delete the gradient func. + // It is safe to delete after adding a copy to host graph. + TF_DeleteFunction(grad_func); + + // Check that GraphDef did not change + GraphDef gdef2; + GetGraphDef(host_graph_, &gdef2); + ASSERT_EQ(gdef.DebugString(), gdef2.DebugString()); + + // Use and run func + TF_Operation* func_feed = Placeholder(host_graph_, s_); + TF_Operation* func_op = Use({func_feed}); + Run({{func_feed, Int32Tensor(3)}}, func_op, -3); +} + +TEST_F(CApiFunctionTest, SameGradForTwoFunctions) { + // Define the functions + TF_Function* func1; + TF_Function* func2; + TF_Function* grad_func; + DefineFunction("FooFunc1", &func1); + DefineFunction("FooFunc2", &func2); + DefineFunction("MyGrad", &grad_func); + + // Make grad_func be a gradient of func1 and func2 + TF_GraphCopyFunction(host_graph_, func1, grad_func, s_); + ASSERT_EQ(TF_OK, TF_GetCode(s_)) << TF_Message(s_); + TF_GraphCopyFunction(host_graph_, func2, grad_func, s_); + ASSERT_EQ(TF_OK, TF_GetCode(s_)) << TF_Message(s_); + + // Verify that functions and their gradients are in host graph's GraphDef + GraphDef gdef; + GetGraphDef(host_graph_, &gdef); + std::vector<std::pair<string, string>> grads = GetGradDefs(gdef); + ASSERT_EQ(2, grads.size()); + ASSERT_EQ("FooFunc1", grads[0].first); + ASSERT_EQ("MyGrad", grads[0].second); + ASSERT_EQ("FooFunc2", grads[1].first); + ASSERT_EQ("MyGrad", grads[1].second); + + TF_DeleteFunction(func1); + TF_DeleteFunction(func2); + TF_DeleteFunction(grad_func); +} + +TEST_F(CApiFunctionTest, AddFunctionsThenMakeOneGradientOfAnother) { + // Define the functions + TF_Function* func; + TF_Function* grad_func; + DefineFunction("FooFunc", &func); + DefineFunction("MyGrad", &grad_func); + + // Add functions individually + TF_GraphCopyFunction(host_graph_, func, nullptr, s_); + ASSERT_EQ(TF_OK, TF_GetCode(s_)) << TF_Message(s_); + TF_GraphCopyFunction(host_graph_, grad_func, nullptr, s_); + ASSERT_EQ(TF_OK, TF_GetCode(s_)) << TF_Message(s_); + + // Check that functions are added but not linked + GraphDef gdef; + GetGraphDef(host_graph_, &gdef); + std::vector<string> func_names = GetFuncNames(gdef); + ASSERT_EQ(2, func_names.size()); + ASSERT_EQ("FooFunc", func_names[0]); + ASSERT_EQ("MyGrad", func_names[1]); + ASSERT_EQ(0, GetGradDefs(gdef).size()); + + // Make grad_func a gradient of func + TF_GraphCopyFunction(host_graph_, func, grad_func, s_); + ASSERT_EQ(TF_OK, TF_GetCode(s_)) << TF_Message(s_); + + // Verify that function and its grad are linked + gdef.Clear(); + GetGraphDef(host_graph_, &gdef); + std::vector<std::pair<string, string>> grads = GetGradDefs(gdef); + ASSERT_EQ(1, grads.size()); + ASSERT_EQ("FooFunc", grads[0].first); + ASSERT_EQ("MyGrad", grads[0].second); + + TF_DeleteFunction(func); + TF_DeleteFunction(grad_func); +} + +TEST_F(CApiFunctionTest, GradientErrorCases) { + // Define the function + DefineFunction(func_name_, &func_); + TF_Function* grad_func1; + TF_Function* grad_func2; + DefineFunction("MyGrad1", &grad_func1); + DefineFunction("MyGrad2", &grad_func2); + + // func cannot be null + TF_GraphCopyFunction(host_graph_, nullptr, func_, s_); + EXPECT_EQ(TF_INVALID_ARGUMENT, TF_GetCode(s_)); + EXPECT_EQ(string("'func' argument to TF_GraphCopyFunction cannot be null"), + string(TF_Message(s_))); + + // Cannot change gradient + TF_GraphCopyFunction(host_graph_, func_, grad_func1, s_); + ASSERT_EQ(TF_OK, TF_GetCode(s_)) << TF_Message(s_); + TF_GraphCopyFunction(host_graph_, func_, grad_func2, s_); + EXPECT_EQ(TF_INVALID_ARGUMENT, TF_GetCode(s_)); + EXPECT_EQ(string("Cannot assign gradient function 'MyGrad2' to 'MyFunc' " + "because it already has gradient function 'MyGrad1'"), + string(TF_Message(s_))); + + TF_DeleteFunction(grad_func1); + TF_DeleteFunction(grad_func2); +} + } // namespace } // namespace tensorflow |