diff options
author | Igor Ganichev <iga@google.com> | 2017-09-19 14:48:51 -0700 |
---|---|---|
committer | TensorFlower Gardener <gardener@tensorflow.org> | 2017-09-19 14:52:35 -0700 |
commit | d67679f1aee7c037fd9c2ac35121720133cd5bd9 (patch) | |
tree | afbd64028c999013886dab1bbe8b45e55b31810b /tensorflow/c/c_api_function_test.cc | |
parent | 74680a3904f88238f58f9566d8bd8e80c3f9dca4 (diff) |
Implement TF_FunctionImportFunctionDef
PiperOrigin-RevId: 169304057
Diffstat (limited to 'tensorflow/c/c_api_function_test.cc')
-rw-r--r-- | tensorflow/c/c_api_function_test.cc | 75 |
1 files changed, 75 insertions, 0 deletions
diff --git a/tensorflow/c/c_api_function_test.cc b/tensorflow/c/c_api_function_test.cc index 8cd910ccbc..9b0279dc17 100644 --- a/tensorflow/c/c_api_function_test.cc +++ b/tensorflow/c/c_api_function_test.cc @@ -15,9 +15,11 @@ limitations under the License. #include "tensorflow/c/c_api.h" +#include "tensorflow/c/c_api_internal.h" #include "tensorflow/c/c_test_util.h" #include "tensorflow/core/framework/function.pb.h" #include "tensorflow/core/framework/op_def.pb.h" +#include "tensorflow/core/lib/core/status.h" #include "tensorflow/core/lib/strings/str_util.h" #include "tensorflow/core/lib/strings/strcat.h" #include "tensorflow/core/platform/logging.h" @@ -354,6 +356,22 @@ class CApiFunctionTest : public ::testing::Test { VerifyFDefEdges(fdef, e_edges, c_edges, is_exact_edges); } + // Serialize func_ to fdef and import it back + void Reincarnate() { + // func_ -> fdef + tensorflow::FunctionDef fdef; + ASSERT_TRUE(GetFunctionDef(func_, &fdef)); + TF_DeleteFunction(func_); + + // fdef -> func_ + TF_Buffer* buf = TF_NewBuffer(); + Status s = MessageToBuffer(fdef, buf); + ASSERT_EQ(Status::OK(), s) << s.error_message(); + func_ = TF_FunctionImportFunctionDef(buf, s_); + ASSERT_EQ(TF_OK, TF_GetCode(s_)) << TF_Message(s_); + TF_DeleteBuffer(buf); + } + const char* func_name_ = "MyFunc"; const char* func_node_name_ = "MyFunc_0"; TF_Status* s_; @@ -1331,5 +1349,62 @@ TEST_F(CApiFunctionTest, GradientErrorCases) { TF_DeleteFunction(grad_func2); } +TEST_F(CApiFunctionTest, ImportFunctionDef) { + /* + * Using a fairly complex function with output names + * + * | | | + * v v / + * add / + * | | + * +------+ | + * | | | + * | v v + * | add + * | | + * v v + * internal_out final_out + */ + // Define + TF_Operation* feed1 = Placeholder(func_graph_, s_, "feed1"); + TF_Operation* feed2 = Placeholder(func_graph_, s_, "feed2"); + TF_Operation* feed3 = Placeholder(func_graph_, s_, "feed3"); + TF_Operation* add1 = Add(feed1, feed2, func_graph_, s_, "add1"); + TF_Operation* add2 = Add(add1, feed3, func_graph_, s_, "add2"); + Define(-1, {}, {feed1, feed2, feed3}, {add1, add2}, + {"internal_out", "final_out"}); + + // Save func_ to FunctionDef and import it back + Reincarnate(); + + // Use, run, and verify + TF_Operation* two = ScalarConst(2, host_graph_, s_, "two"); + TF_Operation* ten = ScalarConst(10, host_graph_, s_, "ten"); + TF_Operation* func_feed = Placeholder(host_graph_, s_); + TF_Operation* func_op = Use({two, ten, func_feed}); + Run({{func_feed, Int32Tensor(3)}}, {{func_op, 0}, {func_op, 1}}, {12, 15}); + VerifyFDef({"add1", "add2"}, M({{"feed1"}, {"feed2"}, {"feed3"}}), + M({{"internal_out"}, {"final_out"}}), + {{"feed1", "add1:0"}, + {"feed2", "add1:1"}, + {"add1:sum:0", "add2:0"}, + {"feed3", "add2:1"}, + {"add1:sum:0", "internal_out"}, + {"add2:sum:0", "final_out"}}, + {}); +} + +TEST_F(CApiFunctionTest, ImportFunctionDef_InvalidProto) { + // Invalid protobuf data (protos cannot start with 4 bytes of zeros) + char proto[] = {0x0, 0x0, 0x0, 0x0}; + TF_Buffer* buf = TF_NewBufferFromString(proto, 4); + func_ = TF_FunctionImportFunctionDef(buf, s_); + TF_DeleteBuffer(buf); + EXPECT_TRUE(func_ == nullptr); + EXPECT_EQ(TF_INVALID_ARGUMENT, TF_GetCode(s_)); + EXPECT_EQ(string("Invalid FunctionDef given to TF_FunctionImportFunctionDef"), + string(TF_Message(s_))); +} + } // namespace } // namespace tensorflow |