aboutsummaryrefslogtreecommitdiffhomepage
path: root/tensorflow/compiler/jit/deadness_analysis_test.cc
diff options
context:
space:
mode:
Diffstat (limited to 'tensorflow/compiler/jit/deadness_analysis_test.cc')
-rw-r--r--tensorflow/compiler/jit/deadness_analysis_test.cc443
1 files changed, 443 insertions, 0 deletions
diff --git a/tensorflow/compiler/jit/deadness_analysis_test.cc b/tensorflow/compiler/jit/deadness_analysis_test.cc
new file mode 100644
index 0000000000..584385cab7
--- /dev/null
+++ b/tensorflow/compiler/jit/deadness_analysis_test.cc
@@ -0,0 +1,443 @@
+/* Copyright 2018 The TensorFlow Authors. All Rights Reserved.
+
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+ http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License.
+==============================================================================*/
+
+#include "tensorflow/compiler/jit/deadness_analysis.h"
+
+#include "tensorflow/cc/framework/ops.h"
+#include "tensorflow/cc/ops/array_ops.h"
+#include "tensorflow/cc/ops/control_flow_ops_internal.h"
+#include "tensorflow/cc/ops/function_ops.h"
+#include "tensorflow/cc/ops/sendrecv_ops.h"
+#include "tensorflow/cc/ops/standard_ops.h"
+#include "tensorflow/compiler/jit/defs.h"
+#include "tensorflow/compiler/tf2xla/xla_op_kernel.h"
+#include "tensorflow/compiler/tf2xla/xla_op_registry.h"
+#include "tensorflow/core/framework/node_def_util.h"
+#include "tensorflow/core/framework/op.h"
+#include "tensorflow/core/graph/algorithm.h"
+#include "tensorflow/core/graph/graph_constructor.h"
+#include "tensorflow/core/graph/graph_def_builder.h"
+#include "tensorflow/core/graph/graph_def_builder_util.h"
+#include "tensorflow/core/lib/core/status_test_util.h"
+#include "tensorflow/core/lib/strings/str_util.h"
+#include "tensorflow/core/platform/test.h"
+
+namespace tensorflow {
+namespace {
+
+Status AnalyzeDeadness(Graph* graph,
+ std::unique_ptr<DeadnessAnalysis>* result) {
+ FixupSourceAndSinkEdges(graph);
+ return DeadnessAnalysis::Run(*graph, result);
+}
+
+ops::Switch CreateSwitch(const Scope& root, const string& prefix) {
+ Output value = ops::Placeholder(root.WithOpName(prefix + "/value"), DT_FLOAT);
+ Output predicate =
+ ops::Placeholder(root.WithOpName(prefix + "/pred"), DT_BOOL);
+ return ops::Switch(root.WithOpName(prefix + "/switch"), value, predicate);
+}
+
+Output CreateInductionVariable(const Scope& root, const string& prefix,
+ const string& frame_name, int32 init) {
+ Output initial_value = ops::Const(root.WithOpName(prefix + "/init"), init);
+ Output enter_initial_value = ops::internal::Enter(
+ root.WithOpName(prefix + "/enter"), initial_value, frame_name);
+
+ ops::Merge iv(root.WithOpName(prefix + "/iv"), {enter_initial_value});
+ Output increment_by = ops::Const(root.WithOpName(prefix + "/incr"), 1);
+ Output final_value = ops::Const(root.WithOpName(prefix + "/final"), 10);
+ Output loop_cond_expr =
+ ops::Less(root.WithOpName(prefix + "/less"), iv.output, final_value);
+ Output loop_cond =
+ ops::LoopCond(root.WithOpName(prefix + "/cond"), loop_cond_expr);
+ ops::Switch latch(root.WithOpName(prefix + "/latch"), iv.output, loop_cond);
+ ops::internal::Exit exit(root.WithOpName(prefix + "/exit"), iv.output);
+ Output iv_next =
+ ops::Add(root.WithOpName(prefix + "/ivnext"), iv.output, increment_by);
+ Output next_iteration =
+ ops::NextIteration(root.WithOpName(prefix + "next_iteration"), iv_next);
+
+ root.graph()->AddEdge(next_iteration.node(), 0, iv.output.node(), 1);
+ root.graph()->AddControlEdge(iv.output.node(), increment_by.node());
+ root.graph()->AddControlEdge(iv.output.node(), final_value.node());
+
+ return iv.output;
+}
+
+TEST(DeadnessAnalysisTest, BasicPositive) {
+ Scope root = Scope::NewRootScope().ExitOnError();
+
+ ops::Switch sw = CreateSwitch(root, "0");
+ Output add =
+ ops::Add(root.WithOpName("add"), sw.output_true, sw.output_false);
+
+ std::unique_ptr<DeadnessAnalysis> result;
+ TF_ASSERT_OK(AnalyzeDeadness(root.graph(), &result));
+
+ EXPECT_TRUE(result->HasInputsWithMismatchingDeadness(*add.node()));
+}
+
+TEST(DeadnessAnalysisTest, BasicNegative) {
+ Scope root = Scope::NewRootScope().ExitOnError();
+
+ Output a = ops::Placeholder(root.WithOpName("a"), DT_FLOAT);
+ Output b = ops::Placeholder(root.WithOpName("b"), DT_FLOAT);
+ Output add = ops::Add(root.WithOpName("add"), a, b);
+
+ std::unique_ptr<DeadnessAnalysis> result;
+ TF_ASSERT_OK(AnalyzeDeadness(root.graph(), &result));
+
+ EXPECT_FALSE(result->HasInputsWithMismatchingDeadness(*add.node()));
+}
+
+TEST(DeadnessAnalysisTest, AndIsCommutative) {
+ Scope root = Scope::NewRootScope().ExitOnError();
+
+ ops::Switch sw_0 = CreateSwitch(root, "0");
+ ops::Switch sw_1 = CreateSwitch(root, "1");
+
+ Output a0 =
+ ops::Add(root.WithOpName("a0"), sw_0.output_false, sw_1.output_false);
+ Output a1 =
+ ops::Add(root.WithOpName("a1"), sw_1.output_false, sw_0.output_false);
+
+ Output b0 =
+ ops::Add(root.WithOpName("b0"), sw_0.output_false, sw_1.output_true);
+ Output b1 =
+ ops::Add(root.WithOpName("b1"), sw_1.output_true, sw_0.output_false);
+
+ Output live0 = ops::Add(root.WithOpName("live0"), a0, a1);
+ Output live1 = ops::Add(root.WithOpName("live1"), b0, b1);
+
+ Output halfdead0 = ops::Add(root.WithOpName("halfdead0"), a0, b0);
+ Output halfdead1 = ops::Add(root.WithOpName("halfdead1"), a1, b1);
+
+ std::unique_ptr<DeadnessAnalysis> result;
+ TF_ASSERT_OK(AnalyzeDeadness(root.graph(), &result));
+
+ EXPECT_FALSE(result->HasInputsWithMismatchingDeadness(*live0.node()));
+ EXPECT_FALSE(result->HasInputsWithMismatchingDeadness(*live1.node()));
+
+ EXPECT_TRUE(result->HasInputsWithMismatchingDeadness(*halfdead0.node()));
+ EXPECT_TRUE(result->HasInputsWithMismatchingDeadness(*halfdead1.node()));
+}
+
+TEST(DeadnessAnalysisTest, AndIsAssociative) {
+ Scope root = Scope::NewRootScope().ExitOnError();
+
+ ops::Switch sw_0 = CreateSwitch(root, "0");
+ ops::Switch sw_1 = CreateSwitch(root, "1");
+ ops::Switch sw_2 = CreateSwitch(root, "2");
+
+ Output a0 =
+ ops::Add(root.WithOpName("a0"), sw_0.output_false, sw_1.output_false);
+ Output a1 = ops::Add(root.WithOpName("a1"), a0, sw_2.output_false);
+
+ Output b0 =
+ ops::Add(root.WithOpName("b0"), sw_1.output_false, sw_2.output_false);
+ Output b1 = ops::Add(root.WithOpName("b1"), sw_0.output_false, b0);
+
+ Output add = ops::Add(root.WithOpName("add"), a1, b1);
+
+ std::unique_ptr<DeadnessAnalysis> result;
+ TF_ASSERT_OK(AnalyzeDeadness(root.graph(), &result));
+
+ EXPECT_FALSE(result->HasInputsWithMismatchingDeadness(*add.node()));
+}
+
+TEST(DeadnessAnalysisTest, OrIsCommutative) {
+ Scope root = Scope::NewRootScope().ExitOnError();
+
+ ops::Switch sw_0 = CreateSwitch(root, "0");
+ ops::Switch sw_1 = CreateSwitch(root, "1");
+
+ ops::Merge m0(root.WithOpName("m0"), {sw_0.output_false, sw_1.output_false});
+ ops::Merge m1(root.WithOpName("m1"), {sw_1.output_false, sw_0.output_false});
+ ops::Merge m2(root.WithOpName("m2"), {sw_0.output_false, sw_1.output_true});
+ ops::Merge m3(root.WithOpName("m3"), {sw_1.output_true, sw_0.output_false});
+
+ Output live0 = ops::Add(root.WithOpName("live0"), m0.output, m1.output);
+ Output live1 = ops::Add(root.WithOpName("live1"), m2.output, m3.output);
+
+ Output halfdead0 =
+ ops::Add(root.WithOpName("halfdead0"), m0.output, m2.output);
+ Output halfdead1 =
+ ops::Add(root.WithOpName("halfdead1"), m1.output, m3.output);
+
+ std::unique_ptr<DeadnessAnalysis> result;
+ TF_ASSERT_OK(AnalyzeDeadness(root.graph(), &result));
+
+ EXPECT_FALSE(result->HasInputsWithMismatchingDeadness(*live0.node()));
+ EXPECT_FALSE(result->HasInputsWithMismatchingDeadness(*live1.node()));
+
+ EXPECT_TRUE(result->HasInputsWithMismatchingDeadness(*halfdead0.node()));
+ EXPECT_TRUE(result->HasInputsWithMismatchingDeadness(*halfdead1.node()));
+}
+
+TEST(DeadnessAnalysisTest, OrIsAssociative) {
+ Scope root = Scope::NewRootScope().ExitOnError();
+
+ ops::Switch sw_0 = CreateSwitch(root, "0");
+ ops::Switch sw_1 = CreateSwitch(root, "1");
+ ops::Switch sw_2 = CreateSwitch(root, "2");
+
+ ops::Merge m0(root.WithOpName("m0"), {sw_0.output_false, sw_1.output_false});
+ ops::Merge m1(root.WithOpName("m1"), {m0.output, sw_2.output_false});
+ ops::Merge m2(root.WithOpName("m2"), {sw_1.output_false, sw_2.output_false});
+ ops::Merge m3(root.WithOpName("m3"), {sw_0.output_false, m2.output});
+
+ Output add = ops::Add(root.WithOpName("add"), m1.output, m3.output);
+
+ std::unique_ptr<DeadnessAnalysis> result;
+ TF_ASSERT_OK(AnalyzeDeadness(root.graph(), &result));
+
+ EXPECT_FALSE(result->HasInputsWithMismatchingDeadness(*add.node()));
+}
+
+TEST(DeadnessAnalysisTest, AndOfOr) {
+ Scope root = Scope::NewRootScope().ExitOnError();
+
+ ops::Switch sw_0 = CreateSwitch(root, "0");
+ ops::Switch sw_1 = CreateSwitch(root, "1");
+ ops::Switch sw_2 = CreateSwitch(root, "2");
+ ops::Switch sw_3 = CreateSwitch(root, "3");
+
+ ops::Merge m0(root.WithOpName("m0"), {sw_0.output_false, sw_1.output_false});
+ ops::Merge m1(root.WithOpName("m1"), {sw_2.output_false, sw_3.output_false});
+
+ Output add0 = ops::Add(root.WithOpName("add0"), m0.output, m1.output);
+ Output add1 = ops::Add(root.WithOpName("add1"), m0.output, m1.output);
+
+ Output add2 = ops::Add(root.WithOpName("add2"), add0, add1);
+
+ std::unique_ptr<DeadnessAnalysis> result;
+ TF_ASSERT_OK(AnalyzeDeadness(root.graph(), &result));
+
+ EXPECT_FALSE(result->HasInputsWithMismatchingDeadness(*add2.node()));
+}
+
+TEST(DeadnessAnalysisTest, OrOfAnd) {
+ Scope root = Scope::NewRootScope().ExitOnError();
+
+ ops::Switch sw_0 = CreateSwitch(root, "0");
+ ops::Switch sw_1 = CreateSwitch(root, "1");
+ ops::Switch sw_2 = CreateSwitch(root, "2");
+ ops::Switch sw_3 = CreateSwitch(root, "3");
+
+ Output add0 =
+ ops::Add(root.WithOpName("add0"), sw_0.output_false, sw_1.output_false);
+ Output add1 =
+ ops::Add(root.WithOpName("add1"), sw_2.output_false, sw_3.output_false);
+
+ ops::Merge m0(root.WithOpName("m0"), {add0, add1});
+ ops::Merge m1(root.WithOpName("m1"), {add0, add1});
+
+ Output add2 = ops::Add(root.WithOpName("add2"), m0.output, m1.output);
+
+ std::unique_ptr<DeadnessAnalysis> result;
+ TF_ASSERT_OK(AnalyzeDeadness(root.graph(), &result));
+
+ EXPECT_FALSE(result->HasInputsWithMismatchingDeadness(*add2.node()));
+}
+
+TEST(DeadnessAnalysisTest, NEGATIVE_AndOrDistributive) {
+ // This demonstrates one of the weaknesses in the current approach -- since we
+ // only do some basic simplifications we can't see that "(A|B)&C" ==
+ // "(A&C)|(B&C)".
+ Scope root = Scope::NewRootScope().ExitOnError();
+
+ ops::Switch sw_0 = CreateSwitch(root, "0");
+ ops::Switch sw_1 = CreateSwitch(root, "1");
+ ops::Switch sw_2 = CreateSwitch(root, "2");
+
+ ops::Merge m0(root.WithOpName("m0"), {sw_0.output_false, sw_1.output_false});
+ Output add0 = ops::Add(root.WithOpName("add0"), m0.output, sw_2.output_false);
+
+ Output add1 =
+ ops::Add(root.WithOpName("add1"), sw_0.output_false, sw_2.output_false);
+ Output add2 =
+ ops::Add(root.WithOpName("add2"), sw_1.output_false, sw_2.output_false);
+ ops::Merge m1(root.WithOpName("m1"), {add1, add2});
+
+ Output add3 = ops::Add(root.WithOpName("add3"), add0, m1.output);
+
+ std::unique_ptr<DeadnessAnalysis> result;
+ TF_ASSERT_OK(AnalyzeDeadness(root.graph(), &result));
+
+ EXPECT_TRUE(result->HasInputsWithMismatchingDeadness(*add2.node()));
+}
+
+TEST(DeadnessAnalysisTest, Ternary) {
+ Scope root = Scope::NewRootScope().ExitOnError();
+
+ Output predicate = ops::Placeholder(root.WithOpName("predicate"), DT_BOOL);
+ Output true_value = ops::Placeholder(root.WithOpName("true_value"), DT_FLOAT);
+ Output false_value =
+ ops::Placeholder(root.WithOpName("false_value"), DT_FLOAT);
+
+ ops::Switch predicated_true(root.WithOpName("predicated_true"), true_value,
+ predicate);
+
+ ops::Switch predicated_false(root.WithOpName("predicated_false"), true_value,
+ predicate);
+ ops::Merge merge(root.WithOpName("ternary"), {predicated_true.output_true,
+ predicated_false.output_false});
+ Output addend = ops::Placeholder(root.WithOpName("addend"), DT_FLOAT);
+ Output add = ops::Add(root.WithOpName("add"), merge.output, addend);
+
+ std::unique_ptr<DeadnessAnalysis> result;
+ TF_ASSERT_OK(AnalyzeDeadness(root.graph(), &result));
+
+ EXPECT_FALSE(result->HasInputsWithMismatchingDeadness(*add.node()));
+}
+
+TEST(DeadnessAnalysisTest, Recv) {
+ Scope root = Scope::NewRootScope().ExitOnError();
+
+ Output recv_a = ops::_Recv(root.WithOpName("recv_a"), DT_FLOAT, "tensor_a",
+ "sender", 0, "receiver");
+ Output recv_b = ops::_Recv(root.WithOpName("recv_b"), DT_FLOAT, "tensor_b",
+ "sender", 0, "receiver");
+ Output add = ops::Add(root.WithOpName("add"), recv_a, recv_b);
+
+ std::unique_ptr<DeadnessAnalysis> result;
+ TF_ASSERT_OK(AnalyzeDeadness(root.graph(), &result));
+
+ EXPECT_TRUE(result->HasInputsWithMismatchingDeadness(*add.node()));
+}
+
+TEST(DeadnessAnalysisTest, HostRecv) {
+ Scope root = Scope::NewRootScope().ExitOnError();
+
+ Output recv_a = ops::_HostRecv(root.WithOpName("recv_a"), DT_FLOAT,
+ "tensor_a", "sender", 0, "receiver");
+ Output recv_b = ops::_HostRecv(root.WithOpName("recv_b"), DT_FLOAT,
+ "tensor_b", "sender", 0, "receiver");
+ Output add = ops::Add(root.WithOpName("add"), recv_a, recv_b);
+
+ std::unique_ptr<DeadnessAnalysis> result;
+ TF_ASSERT_OK(AnalyzeDeadness(root.graph(), &result));
+
+ EXPECT_TRUE(result->HasInputsWithMismatchingDeadness(*add.node()));
+}
+
+TEST(DeadnessAnalysisTest, Loop) {
+ Scope root = Scope::NewRootScope().ExitOnError();
+ Output iv0 = CreateInductionVariable(root, "iv0", "fr0", 0);
+ Output iv1 = CreateInductionVariable(root, "iv1", "fr0", 0);
+ Output iv2 = CreateInductionVariable(root, "iv2", "fr0", 1);
+ Output add0 = ops::Add(root.WithOpName("add0"), iv0, iv1);
+ Output add1 = ops::Add(root.WithOpName("add1"), iv1, iv2);
+
+ std::unique_ptr<DeadnessAnalysis> result;
+ TF_ASSERT_OK(AnalyzeDeadness(root.graph(), &result));
+
+ // NB! iv0 and iv1 are equivalent and a smarter deadness analysis would have
+ // noticed that. Today we are pessimistic here because we assign an
+ // uninterpreted symbol to merges with backedges.
+
+ EXPECT_TRUE(result->HasInputsWithMismatchingDeadness(*add0.node()));
+ EXPECT_TRUE(result->HasInputsWithMismatchingDeadness(*add1.node()));
+}
+
+TEST(DeadnessAnalysisTest, ControlInputs) {
+ Scope root = Scope::NewRootScope().ExitOnError();
+ ops::Switch sw = CreateSwitch(root, "0");
+
+ Output id0 = ops::Identity(root.WithOpName("id0"), sw.output_false);
+ Output id1 = ops::Identity(root.WithOpName("id1"), sw.output_true);
+
+ Output const0 = ops::Const(root.WithOpName("const0"), 1);
+ Output const1 = ops::Const(root.WithOpName("const1"), 2);
+
+ Output add = ops::Add(root.WithOpName("add"), const0, const1);
+
+ root.graph()->AddControlEdge(id0.node(), const0.node());
+ root.graph()->AddControlEdge(id1.node(), const1.node());
+
+ std::unique_ptr<DeadnessAnalysis> result;
+ TF_ASSERT_OK(AnalyzeDeadness(root.graph(), &result));
+
+ EXPECT_TRUE(result->HasInputsWithMismatchingDeadness(*add.node()));
+}
+
+TEST(DeadnessAnalysisTest, ControlTrigger) {
+ Scope root = Scope::NewRootScope().ExitOnError();
+ ops::Switch sw = CreateSwitch(root, "0");
+
+ Output id0 = ops::Identity(root.WithOpName("id0"), sw.output_false);
+ Output id1 = ops::Identity(root.WithOpName("id1"), sw.output_true);
+
+ ops::ControlTrigger ctrl_trigger0(root.WithOpName("ctrl_trigger0"));
+ ops::ControlTrigger ctrl_trigger1(root.WithOpName("ctrl_trigger1"));
+
+ Output const0 = ops::Const(root.WithOpName("const0"), 1);
+ Output const1 = ops::Const(root.WithOpName("const1"), 2);
+
+ Output add = ops::Add(root.WithOpName("add"), const0, const1);
+
+ root.graph()->AddControlEdge(id0.node(), ctrl_trigger0.operation.node());
+ root.graph()->AddControlEdge(ctrl_trigger0.operation.node(), const0.node());
+
+ root.graph()->AddControlEdge(id1.node(), ctrl_trigger1.operation.node());
+ root.graph()->AddControlEdge(ctrl_trigger1.operation.node(), const1.node());
+
+ std::unique_ptr<DeadnessAnalysis> result;
+ TF_ASSERT_OK(AnalyzeDeadness(root.graph(), &result));
+
+ EXPECT_FALSE(result->HasInputsWithMismatchingDeadness(*add.node()));
+}
+
+TEST(DeadnessAnalysisTest, ControlInputsToMerge) {
+ Scope root = Scope::NewRootScope().ExitOnError();
+ ops::Switch sw = CreateSwitch(root, "0");
+
+ Output id0 = ops::Identity(root.WithOpName("id0"), sw.output_false);
+ Output id1 = ops::Identity(root.WithOpName("id1"), sw.output_true);
+
+ Output constant = ops::Const(root.WithOpName("constant"), 5);
+ ops::Merge m0(root.WithOpName("m0"), {constant});
+ ops::Merge m1(root.WithOpName("m0"), {constant});
+ Output add = ops::Add(root.WithOpName("add"), m0.output, m1.output);
+
+ root.graph()->AddControlEdge(id0.node(), m0.output.node());
+ root.graph()->AddControlEdge(id1.node(), m1.output.node());
+
+ std::unique_ptr<DeadnessAnalysis> result;
+ TF_ASSERT_OK(AnalyzeDeadness(root.graph(), &result));
+
+ EXPECT_FALSE(result->HasInputsWithMismatchingDeadness(*add.node()));
+}
+
+TEST(DeadnessAnalysisTest, RecvVsSwitch) {
+ // Demonstrates why we need the must_be_true bit on SymbolP.
+ Scope root = Scope::NewRootScope().ExitOnError();
+
+ Output recv = ops::_Recv(root.WithOpName("recv"), DT_BOOL, "tensor", "sender",
+ 0, "receiver");
+ Output value = ops::Placeholder(root.WithOpName("value"), DT_BOOL);
+ ops::Switch sw(root.WithOpName("switch"), value, recv);
+ Output logical_and =
+ ops::LogicalAnd(root.WithOpName("and"), recv, sw.output_true);
+
+ std::unique_ptr<DeadnessAnalysis> result;
+ TF_ASSERT_OK(AnalyzeDeadness(root.graph(), &result));
+
+ EXPECT_TRUE(result->HasInputsWithMismatchingDeadness(*logical_and.node()));
+}
+
+} // namespace
+} // namespace tensorflow