path: root/tensorflow/core/common_runtime/function_test.cc
author     Derek Murray <mrry@google.com>  2017-12-18 09:16:40 -0800
committer  TensorFlower Gardener <gardener@tensorflow.org>  2017-12-18 11:56:05 -0800
commit     6548e417f8d26e81d10ee577f8575b1cebc443a8 (patch)
tree       7da9be3e8f0e0319672090f61d43913a59e6b317 /tensorflow/core/common_runtime/function_test.cc
parent     511181cc1c4e70330ad46f4dbcabc511d1a9af4a (diff)
Prune unused stateless nodes from function bodies.
Previously, all nodes in a TensorFlow function would be executed unconditionally, which led to surprising performance issues (such as executing an expensive image summary op that was created but unused in a preprocessing function). We can prune nodes that are not reverse-reachable from the return values of a function, provided they are stateless and not reverse-reachable from a stateful node.

PiperOrigin-RevId: 179430810
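In effect, the rule above is a reverse-reachability sweep over the function body. The sketch below is a simplified, stand-alone illustration of that sweep, not the code added by this change; ToyNode and NodesToKeep are hypothetical stand-ins for TensorFlow's graph types. It seeds a breadth-first search with the return-value nodes and every stateful node, walks input edges backwards, and keeps only what it visits; anything left unvisited is stateless dead code.

// Minimal sketch of the pruning rule, assuming a toy graph representation.
#include <deque>
#include <string>
#include <unordered_set>
#include <vector>

struct ToyNode {
  std::string name;
  bool is_stateful = false;
  bool is_retval = false;
  std::vector<ToyNode*> inputs;  // edges from a consumer back to its producers
};

// Returns the names of nodes that survive pruning: every node that is
// reverse-reachable (via input edges) from a return value or from a
// stateful node is kept; all other nodes are safe to drop.
std::unordered_set<std::string> NodesToKeep(const std::vector<ToyNode*>& nodes) {
  std::deque<ToyNode*> queue;
  std::unordered_set<ToyNode*> visited;
  for (ToyNode* n : nodes) {
    if (n->is_retval || n->is_stateful) {  // BFS seeds
      queue.push_back(n);
      visited.insert(n);
    }
  }
  while (!queue.empty()) {
    ToyNode* n = queue.front();
    queue.pop_front();
    for (ToyNode* in : n->inputs) {
      if (visited.insert(in).second) queue.push_back(in);
    }
  }
  std::unordered_set<std::string> keep;
  for (ToyNode* n : visited) keep.insert(n->name);
  return keep;
}

Applied to the SquareAndAddOneWithStatefulNodes function defined in the test below, such a sweep keeps x, o, a, shape, keep_me and y, while x1, x2 and x3 are never reached and can be dropped, which is exactly what the StepStats assertion checks.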
Diffstat (limited to 'tensorflow/core/common_runtime/function_test.cc')
-rw-r--r--  tensorflow/core/common_runtime/function_test.cc  60
1 file changed, 60 insertions, 0 deletions
diff --git a/tensorflow/core/common_runtime/function_test.cc b/tensorflow/core/common_runtime/function_test.cc
index 52bfb9e0ed..7b553c2dcd 100644
--- a/tensorflow/core/common_runtime/function_test.cc
+++ b/tensorflow/core/common_runtime/function_test.cc
@@ -27,6 +27,7 @@ limitations under the License.
#include "tensorflow/core/common_runtime/executor.h"
#include "tensorflow/core/common_runtime/function_testlib.h"
#include "tensorflow/core/common_runtime/rendezvous_mgr.h"
+#include "tensorflow/core/common_runtime/step_stats_collector.h"
#include "tensorflow/core/framework/function.h"
#include "tensorflow/core/framework/function_testlib.h"
#include "tensorflow/core/framework/op.h"
@@ -566,6 +567,65 @@ TEST_F(FunctionLibraryRuntimeTest, ExpandInlineFunctionsWithControlDeps) {
}
}

+TEST_F(FunctionLibraryRuntimeTest, PruneBody) {
+  auto T = DT_INT32;
+  FunctionDef stateful_func = FDH::Define(
+      // Name
+      "SquareAndAddOneWithStatefulNodes",
+      // Args
+      {"x: int32"},
+      // Return values
+      {"y: int32"},
+      // Attrs
+      {},
+      // Nodes
+      {// a = Square<T>(x)
+       {{"a"}, "Square", {"x"}, {{"T", T}}},
+       // 1
+       FDH::Const("o", 1),
+       // A bunch of extra arithmetic that y doesn't depend on
+       {{"x1"}, "Add", {"o", "o"}, {{"T", T}}},
+       {{"x2"}, "Mul", {"a", "x1"}, {{"T", T}}},
+       {{"x3"}, "Mul", {"x1", "x2"}, {{"T", T}}},
+       FDH::Const<int32>("shape", {1, 2}),
+       // A stateful node.
+       {{"keep_me"},
+        "RandomUniform",
+        {"shape"},
+        {{"T", T}, {"dtype", DT_FLOAT}}},
+       // y = Add<T>(a, o)
+       {{"y"}, "Add", {"a", "o"}, {{"T", T}}}});
+  Init({stateful_func});
+
+  auto x = test::AsTensor<int32>({1, 2, 3, 4});
+  Tensor y;
+
+  FunctionLibraryRuntime::Handle handle;
+  TF_CHECK_OK(
+      Instantiate(flr0_, "SquareAndAddOneWithStatefulNodes", {}, &handle));
+
+  StepStats stats;
+  StepStatsCollector stats_collector(&stats);
+  FunctionLibraryRuntime::Options opts;
+  opts.stats_collector = &stats_collector;
+  TF_CHECK_OK(Run(flr0_, handle, opts, {x}, {&y}));
+
+  TF_CHECK_OK(InstantiateAndRun(flr0_, "SquareAndAddOneWithStatefulNodes", {},
+                                {x}, {&y}));
+  test::ExpectTensorEqual<int>(y, test::AsTensor<int32>({2, 5, 10, 17}));
+
+  stats_collector.FinalizeAndSwap(&stats);
+
+  // Note that we do not expect the nodes named "x1", "x2", or "x3" to execute.
+  std::set<string> expected_node_names(
+      {"_SOURCE", "shape", "x", "o", "a", "keep_me", "y", "y_RetVal"});
+  std::set<string> executed_node_names;
+  for (const auto& node_stats : stats.dev_stats()[0].node_stats()) {
+    executed_node_names.insert(node_stats.node_name());
+  }
+  EXPECT_EQ(expected_node_names, executed_node_names);
+}
+
TEST_F(FunctionLibraryRuntimeTest, OptimizeGraph) {
Init({test::function::XTimesTwo(), test::function::XTimesFour(),
test::function::XTimes16()});