diff options
author | Derek Murray <mrry@google.com> | 2018-09-06 14:13:56 -0700 |
---|---|---|
committer | TensorFlower Gardener <gardener@tensorflow.org> | 2018-09-06 14:23:53 -0700 |
commit | 76a5936cd283d9a32c89635577b2da9c8e46785b (patch) | |
tree | 2c34f17ff00e23c2e8f06fb1fbee3235c9b5ae42 /tensorflow/core/common_runtime/function_test.cc | |
parent | 64fd29ca227707a4c6212638346a6b92885bf18a (diff) |
Enable unused "_Arg" nodes to be pruned from a function body.
Previously, because "_Arg" nodes are considered to be "stateful", these nodes were unconditionally included in the seed set of nodes for pruning a function body. Since an "_Arg" node has no visible side effect, we can safely prune these, which makes small projection functions (like `lambda x, y: y`) more efficient.
PiperOrigin-RevId: 211867380
Diffstat (limited to 'tensorflow/core/common_runtime/function_test.cc')
-rw-r--r-- | tensorflow/core/common_runtime/function_test.cc | 22 |
1 files changed, 12 insertions, 10 deletions
diff --git a/tensorflow/core/common_runtime/function_test.cc b/tensorflow/core/common_runtime/function_test.cc index 120f480198..7bab9be9a6 100644 --- a/tensorflow/core/common_runtime/function_test.cc +++ b/tensorflow/core/common_runtime/function_test.cc @@ -802,9 +802,9 @@ TEST_F(FunctionLibraryRuntimeTest, PruneBody) { // Name "SquareAndAddOneWithStatefulNodes", // Args - {"x: int32"}, + {"x: int32", "y: float32"}, // Return values - {"y: int32"}, + {"z: int32"}, // Attrs {}, // Nodes @@ -822,12 +822,13 @@ TEST_F(FunctionLibraryRuntimeTest, PruneBody) { "RandomUniform", {"shape"}, {{"T", T}, {"dtype", DT_FLOAT}}}, - // y = Add<T>(a, o) - {{"y"}, "Add", {"a", "o"}, {{"T", T}}}}); + // z = Add<T>(a, o) + {{"z"}, "Add", {"a", "o"}, {{"T", T}}}}); Init({stateful_func}); auto x = test::AsTensor<int32>({1, 2, 3, 4}); - Tensor y; + auto y = test::AsTensor<float>({1.0, 2.0, 3.0, 4.0}); + Tensor z; FunctionLibraryRuntime::Handle handle; TF_CHECK_OK( @@ -837,18 +838,19 @@ TEST_F(FunctionLibraryRuntimeTest, PruneBody) { StepStatsCollector stats_collector(&stats); FunctionLibraryRuntime::Options opts; opts.stats_collector = &stats_collector; - TF_CHECK_OK(Run(flr0_, handle, opts, {x}, {&y})); + TF_CHECK_OK(Run(flr0_, handle, opts, {x, y}, {&z})); TF_CHECK_OK(flr0_->ReleaseHandle(handle)); TF_CHECK_OK(InstantiateAndRun(flr0_, "SquareAndAddOneWithStatefulNodes", {}, - {x}, {&y})); - test::ExpectTensorEqual<int>(y, test::AsTensor<int32>({2, 5, 10, 17})); + {x, y}, {&z})); + test::ExpectTensorEqual<int>(z, test::AsTensor<int32>({2, 5, 10, 17})); stats_collector.FinalizeAndSwap(&stats); - // Note that we do not expect the nodes named "x1", "x2", or "x3" to execute. + // Note that we do not expect the nodes named "y", "x1", "x2", or "x3" to + // execute. std::set<string> expected_node_names( - {"_SOURCE", "shape", "x", "o", "a", "keep_me", "y", "y_RetVal"}); + {"_SOURCE", "shape", "x", "o", "a", "keep_me", "z", "z_RetVal"}); std::set<string> executed_node_names; for (const auto& node_stats : stats.dev_stats()[0].node_stats()) { executed_node_names.insert(node_stats.node_name()); |