diff options
author | Derek Murray <mrry@google.com> | 2017-12-13 08:56:20 -0800 |
---|---|---|
committer | TensorFlower Gardener <gardener@tensorflow.org> | 2017-12-13 09:00:51 -0800 |
commit | 185c593cb71cb6d8116ba05c97e9385642648f1b (patch) | |
tree | 8853277b53c58d69e93735be50e75f4f5afa9516 /tensorflow/c/c_api_function_test.cc | |
parent | 2b1b7dffcd2c76876efdbcfc431424e259da3bf4 (diff) |
Automated g4 rollback of changelist 178759398
PiperOrigin-RevId: 178909147
Diffstat (limited to 'tensorflow/c/c_api_function_test.cc')
-rw-r--r-- | tensorflow/c/c_api_function_test.cc | 45 |
1 files changed, 45 insertions, 0 deletions
diff --git a/tensorflow/c/c_api_function_test.cc b/tensorflow/c/c_api_function_test.cc index d5580b6589..2e2293ca85 100644 --- a/tensorflow/c/c_api_function_test.cc +++ b/tensorflow/c/c_api_function_test.cc @@ -1482,6 +1482,51 @@ TEST_F(CApiFunctionTest, GetOpDef) { EXPECT_EQ(op_def.name(), func_name_); EXPECT_EQ(op_def.input_arg_size(), 1); EXPECT_EQ(op_def.output_arg_size(), 1); + EXPECT_FALSE(op_def.is_stateful()); + + TF_DeleteBuffer(buffer); +} + +void DefineStatefulFunction(const char* name, TF_Function** func) { + std::unique_ptr<TF_Graph, decltype(&TF_DeleteGraph)> func_graph( + TF_NewGraph(), TF_DeleteGraph); + std::unique_ptr<TF_Status, decltype(&TF_DeleteStatus)> s(TF_NewStatus(), + TF_DeleteStatus); + + TF_Tensor* tensor_shape = Int32Tensor({37, 1}); + TF_Operation* shape = Const(tensor_shape, func_graph.get(), s.get(), "shape"); + TF_Operation* random = + RandomUniform(shape, TF_FLOAT, func_graph.get(), s.get()); + + TF_Output inputs[] = {}; + TF_Output outputs[] = {{random, 0}}; + *func = TF_GraphToFunction(func_graph.get(), name, /*append_hash=*/false, -1, + /*opers=*/nullptr, 0, inputs, 1, outputs, + /*output_names=*/nullptr, + /*opts=*/nullptr, "", s.get()); + ASSERT_EQ(TF_OK, TF_GetCode(s.get())) << TF_Message(s.get()); + ASSERT_NE(*func, nullptr); + TF_DeleteTensor(tensor_shape); +} + +TEST_F(CApiFunctionTest, StatefulOpDef) { + DefineStatefulFunction(func_name_, &func_); + TF_GraphCopyFunction(host_graph_, func_, nullptr, s_); + ASSERT_EQ(TF_OK, TF_GetCode(s_)) << TF_Message(s_); + + // Test we can retrieve function OpDef from graph + TF_Buffer* buffer = TF_NewBuffer(); + TF_GraphGetOpDef(host_graph_, func_name_, buffer, s_); + ASSERT_EQ(TF_OK, TF_GetCode(s_)) << TF_Message(s_); + + // Sanity check returned OpDef + string data(static_cast<const char*>(buffer->data), buffer->length); + OpDef op_def; + op_def.ParseFromString(data); + EXPECT_EQ(op_def.name(), func_name_); + EXPECT_EQ(op_def.input_arg_size(), 0); + EXPECT_EQ(op_def.output_arg_size(), 1); + EXPECT_TRUE(op_def.is_stateful()); TF_DeleteBuffer(buffer); } |