aboutsummaryrefslogtreecommitdiffhomepage
path: root/tensorflow/core/common_runtime/function_test.cc
diff options
context:
space:
mode:
authorGravatar Derek Murray <mrry@google.com>2018-09-06 14:13:56 -0700
committerGravatar TensorFlower Gardener <gardener@tensorflow.org>2018-09-06 14:23:53 -0700
commit76a5936cd283d9a32c89635577b2da9c8e46785b (patch)
tree2c34f17ff00e23c2e8f06fb1fbee3235c9b5ae42 /tensorflow/core/common_runtime/function_test.cc
parent64fd29ca227707a4c6212638346a6b92885bf18a (diff)
Enable unused "_Arg" nodes to be pruned from a function body.
Previously, because "_Arg" nodes are considered to be "stateful", these nodes were unconditionally included in the seed set of nodes for pruning a function body. Since an "_Arg" node has no visible side effect, we can safely prune these, which makes small projection functions (like `lambda x, y: y`) more efficient. PiperOrigin-RevId: 211867380
Diffstat (limited to 'tensorflow/core/common_runtime/function_test.cc')
-rw-r--r--tensorflow/core/common_runtime/function_test.cc22
1 files changed, 12 insertions, 10 deletions
diff --git a/tensorflow/core/common_runtime/function_test.cc b/tensorflow/core/common_runtime/function_test.cc
index 120f480198..7bab9be9a6 100644
--- a/tensorflow/core/common_runtime/function_test.cc
+++ b/tensorflow/core/common_runtime/function_test.cc
@@ -802,9 +802,9 @@ TEST_F(FunctionLibraryRuntimeTest, PruneBody) {
// Name
"SquareAndAddOneWithStatefulNodes",
// Args
- {"x: int32"},
+ {"x: int32", "y: float32"},
// Return values
- {"y: int32"},
+ {"z: int32"},
// Attrs
{},
// Nodes
@@ -822,12 +822,13 @@ TEST_F(FunctionLibraryRuntimeTest, PruneBody) {
"RandomUniform",
{"shape"},
{{"T", T}, {"dtype", DT_FLOAT}}},
- // y = Add<T>(a, o)
- {{"y"}, "Add", {"a", "o"}, {{"T", T}}}});
+ // z = Add<T>(a, o)
+ {{"z"}, "Add", {"a", "o"}, {{"T", T}}}});
Init({stateful_func});
auto x = test::AsTensor<int32>({1, 2, 3, 4});
- Tensor y;
+ auto y = test::AsTensor<float>({1.0, 2.0, 3.0, 4.0});
+ Tensor z;
FunctionLibraryRuntime::Handle handle;
TF_CHECK_OK(
@@ -837,18 +838,19 @@ TEST_F(FunctionLibraryRuntimeTest, PruneBody) {
StepStatsCollector stats_collector(&stats);
FunctionLibraryRuntime::Options opts;
opts.stats_collector = &stats_collector;
- TF_CHECK_OK(Run(flr0_, handle, opts, {x}, {&y}));
+ TF_CHECK_OK(Run(flr0_, handle, opts, {x, y}, {&z}));
TF_CHECK_OK(flr0_->ReleaseHandle(handle));
TF_CHECK_OK(InstantiateAndRun(flr0_, "SquareAndAddOneWithStatefulNodes", {},
- {x}, {&y}));
- test::ExpectTensorEqual<int>(y, test::AsTensor<int32>({2, 5, 10, 17}));
+ {x, y}, {&z}));
+ test::ExpectTensorEqual<int>(z, test::AsTensor<int32>({2, 5, 10, 17}));
stats_collector.FinalizeAndSwap(&stats);
- // Note that we do not expect the nodes named "x1", "x2", or "x3" to execute.
+ // Note that we do not expect the nodes named "y", "x1", "x2", or "x3" to
+ // execute.
std::set<string> expected_node_names(
- {"_SOURCE", "shape", "x", "o", "a", "keep_me", "y", "y_RetVal"});
+ {"_SOURCE", "shape", "x", "o", "a", "keep_me", "z", "z_RetVal"});
std::set<string> executed_node_names;
for (const auto& node_stats : stats.dev_stats()[0].node_stats()) {
executed_node_names.insert(node_stats.node_name());