diff options
author | Derek Murray <mrry@google.com> | 2017-12-11 14:41:17 -0800 |
---|---|---|
committer | TensorFlower Gardener <gardener@tensorflow.org> | 2017-12-11 14:45:28 -0800 |
commit | 037f036b2c76ef363148276dce83b7dd1d79e878 (patch) | |
tree | 359d2af4fcec51ccb60d42409ce09b03000cfc8e /tensorflow/c/c_api_function_test.cc | |
parent | dd77f385591c8b6ef7ab8dae7429c7eff7813a1e (diff) |
Mark a FunctionDef's signature as stateful when it contains a stateful node.
This fixes a bug where two calls to the same stateful function will erroneously be eliminated as common subexpressions. It is also a step towards pruning nodes from function bodies, which is necessary for a variety of `Dataset` optimizations.
PiperOrigin-RevId: 178675527
Diffstat (limited to 'tensorflow/c/c_api_function_test.cc')
-rw-r--r-- | tensorflow/c/c_api_function_test.cc | 45 |
1 files changed, 45 insertions, 0 deletions
diff --git a/tensorflow/c/c_api_function_test.cc b/tensorflow/c/c_api_function_test.cc index d5580b6589..4ffc9d6931 100644 --- a/tensorflow/c/c_api_function_test.cc +++ b/tensorflow/c/c_api_function_test.cc @@ -1482,6 +1482,51 @@ TEST_F(CApiFunctionTest, GetOpDef) { EXPECT_EQ(op_def.name(), func_name_); EXPECT_EQ(op_def.input_arg_size(), 1); EXPECT_EQ(op_def.output_arg_size(), 1); + EXPECT_FALSE(op_def.is_stateful()); + + TF_DeleteBuffer(buffer); +} + +void DefineStatefulFunction(const char* name, TF_Function** func) { + std::unique_ptr<TF_Graph, decltype(&TF_DeleteGraph)> func_graph( + TF_NewGraph(), TF_DeleteGraph); + std::unique_ptr<TF_Status, decltype(&TF_DeleteStatus)> s(TF_NewStatus(), + TF_DeleteStatus); + + TF_Tensor* tensor_shape = Int32Tensor({37, 1}); + TF_Operation* shape = Const(tensor_shape, func_graph.get(), s.get(), "shape"); + TF_Operation* random = + RandomUniform(shape, TF_FLOAT, func_graph.get(), s.get()); + + TF_Output inputs[] = {}; + TF_Output outputs[] = {{random, 0}}; + *func = TF_GraphToFunction(func_graph.get(), name, /*append_hash=*/0, -1, + /*opers=*/nullptr, 0, inputs, 1, outputs, + /*output_names=*/nullptr, + /*opts=*/nullptr, "", s.get()); + ASSERT_EQ(TF_OK, TF_GetCode(s.get())) << TF_Message(s.get()); + ASSERT_NE(*func, nullptr); + TF_DeleteTensor(tensor_shape); +} + +TEST_F(CApiFunctionTest, StatefulOpDef) { + DefineStatefulFunction(func_name_, &func_); + TF_GraphCopyFunction(host_graph_, func_, nullptr, s_); + ASSERT_EQ(TF_OK, TF_GetCode(s_)) << TF_Message(s_); + + // Test we can retrieve function OpDef from graph + TF_Buffer* buffer = TF_NewBuffer(); + TF_GraphGetOpDef(host_graph_, func_name_, buffer, s_); + ASSERT_EQ(TF_OK, TF_GetCode(s_)) << TF_Message(s_); + + // Sanity check returned OpDef + string data(static_cast<const char*>(buffer->data), buffer->length); + OpDef op_def; + op_def.ParseFromString(data); + EXPECT_EQ(op_def.name(), func_name_); + EXPECT_EQ(op_def.input_arg_size(), 0); + EXPECT_EQ(op_def.output_arg_size(), 1); + EXPECT_TRUE(op_def.is_stateful()); TF_DeleteBuffer(buffer); } |