aboutsummaryrefslogtreecommitdiffhomepage
path: root/tensorflow/c/c_api_function_test.cc
diff options
context:
space:
mode:
authorGravatar Derek Murray <mrry@google.com>2017-12-13 08:56:20 -0800
committerGravatar TensorFlower Gardener <gardener@tensorflow.org>2017-12-13 09:00:51 -0800
commit185c593cb71cb6d8116ba05c97e9385642648f1b (patch)
tree8853277b53c58d69e93735be50e75f4f5afa9516 /tensorflow/c/c_api_function_test.cc
parent2b1b7dffcd2c76876efdbcfc431424e259da3bf4 (diff)
Automated g4 rollback of changelist 178759398
PiperOrigin-RevId: 178909147
Diffstat (limited to 'tensorflow/c/c_api_function_test.cc')
-rw-r--r--tensorflow/c/c_api_function_test.cc45
1 files changed, 45 insertions, 0 deletions
diff --git a/tensorflow/c/c_api_function_test.cc b/tensorflow/c/c_api_function_test.cc
index d5580b6589..2e2293ca85 100644
--- a/tensorflow/c/c_api_function_test.cc
+++ b/tensorflow/c/c_api_function_test.cc
@@ -1482,6 +1482,51 @@ TEST_F(CApiFunctionTest, GetOpDef) {
EXPECT_EQ(op_def.name(), func_name_);
EXPECT_EQ(op_def.input_arg_size(), 1);
EXPECT_EQ(op_def.output_arg_size(), 1);
+ EXPECT_FALSE(op_def.is_stateful());
+
+ TF_DeleteBuffer(buffer);
+}
+
+void DefineStatefulFunction(const char* name, TF_Function** func) {
+ std::unique_ptr<TF_Graph, decltype(&TF_DeleteGraph)> func_graph(
+ TF_NewGraph(), TF_DeleteGraph);
+ std::unique_ptr<TF_Status, decltype(&TF_DeleteStatus)> s(TF_NewStatus(),
+ TF_DeleteStatus);
+
+ TF_Tensor* tensor_shape = Int32Tensor({37, 1});
+ TF_Operation* shape = Const(tensor_shape, func_graph.get(), s.get(), "shape");
+ TF_Operation* random =
+ RandomUniform(shape, TF_FLOAT, func_graph.get(), s.get());
+
+ TF_Output inputs[] = {};
+ TF_Output outputs[] = {{random, 0}};
+ *func = TF_GraphToFunction(func_graph.get(), name, /*append_hash=*/false, -1,
+ /*opers=*/nullptr, 0, inputs, 1, outputs,
+ /*output_names=*/nullptr,
+ /*opts=*/nullptr, "", s.get());
+ ASSERT_EQ(TF_OK, TF_GetCode(s.get())) << TF_Message(s.get());
+ ASSERT_NE(*func, nullptr);
+ TF_DeleteTensor(tensor_shape);
+}
+
+TEST_F(CApiFunctionTest, StatefulOpDef) {
+ DefineStatefulFunction(func_name_, &func_);
+ TF_GraphCopyFunction(host_graph_, func_, nullptr, s_);
+ ASSERT_EQ(TF_OK, TF_GetCode(s_)) << TF_Message(s_);
+
+ // Test we can retrieve function OpDef from graph
+ TF_Buffer* buffer = TF_NewBuffer();
+ TF_GraphGetOpDef(host_graph_, func_name_, buffer, s_);
+ ASSERT_EQ(TF_OK, TF_GetCode(s_)) << TF_Message(s_);
+
+ // Sanity check returned OpDef
+ string data(static_cast<const char*>(buffer->data), buffer->length);
+ OpDef op_def;
+ op_def.ParseFromString(data);
+ EXPECT_EQ(op_def.name(), func_name_);
+ EXPECT_EQ(op_def.input_arg_size(), 0);
+ EXPECT_EQ(op_def.output_arg_size(), 1);
+ EXPECT_TRUE(op_def.is_stateful());
TF_DeleteBuffer(buffer);
}