aboutsummaryrefslogtreecommitdiffhomepage
path: root/tensorflow/c/c_api_function_test.cc
diff options
context:
space:
mode:
authorGravatar Igor Ganichev <iga@google.com>2017-09-19 14:48:51 -0700
committerGravatar TensorFlower Gardener <gardener@tensorflow.org>2017-09-19 14:52:35 -0700
commitd67679f1aee7c037fd9c2ac35121720133cd5bd9 (patch)
treeafbd64028c999013886dab1bbe8b45e55b31810b /tensorflow/c/c_api_function_test.cc
parent74680a3904f88238f58f9566d8bd8e80c3f9dca4 (diff)
Implement TF_FunctionImportFunctionDef
PiperOrigin-RevId: 169304057
Diffstat (limited to 'tensorflow/c/c_api_function_test.cc')
-rw-r--r--tensorflow/c/c_api_function_test.cc75
1 files changed, 75 insertions, 0 deletions
diff --git a/tensorflow/c/c_api_function_test.cc b/tensorflow/c/c_api_function_test.cc
index 8cd910ccbc..9b0279dc17 100644
--- a/tensorflow/c/c_api_function_test.cc
+++ b/tensorflow/c/c_api_function_test.cc
@@ -15,9 +15,11 @@ limitations under the License.
#include "tensorflow/c/c_api.h"
+#include "tensorflow/c/c_api_internal.h"
#include "tensorflow/c/c_test_util.h"
#include "tensorflow/core/framework/function.pb.h"
#include "tensorflow/core/framework/op_def.pb.h"
+#include "tensorflow/core/lib/core/status.h"
#include "tensorflow/core/lib/strings/str_util.h"
#include "tensorflow/core/lib/strings/strcat.h"
#include "tensorflow/core/platform/logging.h"
@@ -354,6 +356,22 @@ class CApiFunctionTest : public ::testing::Test {
VerifyFDefEdges(fdef, e_edges, c_edges, is_exact_edges);
}
+ // Serialize func_ to fdef and import it back
+ void Reincarnate() {
+ // func_ -> fdef
+ tensorflow::FunctionDef fdef;
+ ASSERT_TRUE(GetFunctionDef(func_, &fdef));
+ TF_DeleteFunction(func_);
+
+ // fdef -> func_
+ TF_Buffer* buf = TF_NewBuffer();
+ Status s = MessageToBuffer(fdef, buf);
+ ASSERT_EQ(Status::OK(), s) << s.error_message();
+ func_ = TF_FunctionImportFunctionDef(buf, s_);
+ ASSERT_EQ(TF_OK, TF_GetCode(s_)) << TF_Message(s_);
+ TF_DeleteBuffer(buf);
+ }
+
const char* func_name_ = "MyFunc";
const char* func_node_name_ = "MyFunc_0";
TF_Status* s_;
@@ -1331,5 +1349,62 @@ TEST_F(CApiFunctionTest, GradientErrorCases) {
TF_DeleteFunction(grad_func2);
}
+TEST_F(CApiFunctionTest, ImportFunctionDef) {
+ /*
+ * Using a fairly complex function with output names
+ *
+ * | | |
+ * v v /
+ * add /
+ * | |
+ * +------+ |
+ * | | |
+ * | v v
+ * | add
+ * | |
+ * v v
+ * internal_out final_out
+ */
+ // Define
+ TF_Operation* feed1 = Placeholder(func_graph_, s_, "feed1");
+ TF_Operation* feed2 = Placeholder(func_graph_, s_, "feed2");
+ TF_Operation* feed3 = Placeholder(func_graph_, s_, "feed3");
+ TF_Operation* add1 = Add(feed1, feed2, func_graph_, s_, "add1");
+ TF_Operation* add2 = Add(add1, feed3, func_graph_, s_, "add2");
+ Define(-1, {}, {feed1, feed2, feed3}, {add1, add2},
+ {"internal_out", "final_out"});
+
+ // Save func_ to FunctionDef and import it back
+ Reincarnate();
+
+ // Use, run, and verify
+ TF_Operation* two = ScalarConst(2, host_graph_, s_, "two");
+ TF_Operation* ten = ScalarConst(10, host_graph_, s_, "ten");
+ TF_Operation* func_feed = Placeholder(host_graph_, s_);
+ TF_Operation* func_op = Use({two, ten, func_feed});
+ Run({{func_feed, Int32Tensor(3)}}, {{func_op, 0}, {func_op, 1}}, {12, 15});
+ VerifyFDef({"add1", "add2"}, M({{"feed1"}, {"feed2"}, {"feed3"}}),
+ M({{"internal_out"}, {"final_out"}}),
+ {{"feed1", "add1:0"},
+ {"feed2", "add1:1"},
+ {"add1:sum:0", "add2:0"},
+ {"feed3", "add2:1"},
+ {"add1:sum:0", "internal_out"},
+ {"add2:sum:0", "final_out"}},
+ {});
+}
+
+TEST_F(CApiFunctionTest, ImportFunctionDef_InvalidProto) {
+ // Invalid protobuf data (protos cannot start with 4 bytes of zeros)
+ char proto[] = {0x0, 0x0, 0x0, 0x0};
+ TF_Buffer* buf = TF_NewBufferFromString(proto, 4);
+ func_ = TF_FunctionImportFunctionDef(buf, s_);
+ TF_DeleteBuffer(buf);
+ EXPECT_TRUE(func_ == nullptr);
+ EXPECT_EQ(TF_INVALID_ARGUMENT, TF_GetCode(s_));
+ EXPECT_EQ(string("Invalid FunctionDef given to TF_FunctionImportFunctionDef"),
+ string(TF_Message(s_)));
+}
+
} // namespace
} // namespace tensorflow