aboutsummaryrefslogtreecommitdiffhomepage
path: root/tensorflow/compiler/jit/mark_for_compilation_pass_test.cc
diff options
context:
space:
mode:
Diffstat (limited to 'tensorflow/compiler/jit/mark_for_compilation_pass_test.cc')
-rw-r--r--tensorflow/compiler/jit/mark_for_compilation_pass_test.cc70
1 files changed, 25 insertions, 45 deletions
diff --git a/tensorflow/compiler/jit/mark_for_compilation_pass_test.cc b/tensorflow/compiler/jit/mark_for_compilation_pass_test.cc
index 2c5f4fb774..a780d4a936 100644
--- a/tensorflow/compiler/jit/mark_for_compilation_pass_test.cc
+++ b/tensorflow/compiler/jit/mark_for_compilation_pass_test.cc
@@ -13,7 +13,7 @@ See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/
-#include "tensorflow/compiler/jit/mark_for_compilation_pass.h"
+#include "tensorflow/compiler/jit/mark_for_compilation_pass_test_helper.h"
#include "tensorflow/cc/framework/ops.h"
#include "tensorflow/cc/ops/array_ops.h"
@@ -39,27 +39,6 @@ namespace {
REGISTER_OP("UncompilableNullary").Output("o: float");
REGISTER_OP("UncompilableUnary").Input("a: float").Output("o: float");
-Status MarkForCompilation(std::unique_ptr<Graph>* graph,
- FunctionLibraryDefinition* flib_def) {
- // Assign all nodes to the CPU device.
- static const char* kCpuDevice = "/job:localhost/replica:0/task:0/cpu:0";
- for (Node* n : (*graph)->nodes()) {
- n->set_assigned_device_name(kCpuDevice);
- }
-
- GraphOptimizationPassOptions opt_options;
- opt_options.graph = graph;
- opt_options.flib_def = flib_def;
- MarkForCompilationPass pass;
- return pass.RunImpl(opt_options);
-}
-
-Status MarkForCompilation(std::unique_ptr<Graph>* graph) {
- FunctionDefLibrary flib;
- FunctionLibraryDefinition flib_def((*graph)->op_registry(), flib);
- return MarkForCompilation(graph, &flib_def);
-}
-
std::unordered_map<string, string> GetClusters(const Graph& graph) {
std::unordered_map<string, string> ids;
for (Node* node : graph.nodes()) {
@@ -88,7 +67,7 @@ TEST(XlaCompilationTest, Chains) {
TF_EXPECT_OK(GraphDefBuilderToGraph(builder, graph.get()));
}
- TF_ASSERT_OK(MarkForCompilation(&graph));
+ TF_ASSERT_OK(MarkForCompilationPassTestHelper::MarkForCompilation(&graph));
auto clusters = GetClusters(*graph);
EXPECT_EQ(4, clusters.size());
EXPECT_EQ(clusters["B"], clusters["C"]);
@@ -113,7 +92,7 @@ TEST(XlaCompilationTest, UncompilableCycles) {
TF_EXPECT_OK(GraphDefBuilderToGraph(builder, graph.get()));
}
- TF_ASSERT_OK(MarkForCompilation(&graph));
+ TF_ASSERT_OK(MarkForCompilationPassTestHelper::MarkForCompilation(&graph));
auto clusters = GetClusters(*graph);
EXPECT_TRUE(clusters.empty());
@@ -133,7 +112,7 @@ TEST(XlaCompilationTest, CompilableCycles) {
TF_EXPECT_OK(GraphDefBuilderToGraph(builder, graph.get()));
}
- TF_ASSERT_OK(MarkForCompilation(&graph));
+ TF_ASSERT_OK(MarkForCompilationPassTestHelper::MarkForCompilation(&graph));
auto clusters = GetClusters(*graph);
EXPECT_EQ(3, clusters.size());
@@ -156,7 +135,7 @@ TEST(XlaCompilationTest, Complex128Unsupported) {
TF_EXPECT_OK(GraphDefBuilderToGraph(builder, graph.get()));
}
- TF_ASSERT_OK(MarkForCompilation(&graph));
+ TF_ASSERT_OK(MarkForCompilationPassTestHelper::MarkForCompilation(&graph));
auto clusters = GetClusters(*graph);
EXPECT_TRUE(clusters.empty());
}
@@ -177,7 +156,7 @@ TEST(XlaCompilationTest, HalfSupported) {
TF_EXPECT_OK(GraphDefBuilderToGraph(builder, graph.get()));
}
- TF_ASSERT_OK(MarkForCompilation(&graph));
+ TF_ASSERT_OK(MarkForCompilationPassTestHelper::MarkForCompilation(&graph));
auto clusters = GetClusters(*graph);
EXPECT_FALSE(clusters.empty());
}
@@ -206,7 +185,7 @@ TEST(XlaCompilationTest, ConcatWithConstArg) {
TF_EXPECT_OK(GraphDefBuilderToGraph(builder, graph.get()));
}
- TF_ASSERT_OK(MarkForCompilation(&graph));
+ TF_ASSERT_OK(MarkForCompilationPassTestHelper::MarkForCompilation(&graph));
auto clusters = GetClusters(*graph);
EXPECT_EQ(3, clusters.size()); // Everything should be compiled.
}
@@ -241,7 +220,8 @@ TEST(XlaCompilationTest, FunctionCalls) {
TF_EXPECT_OK(GraphDefBuilderToGraph(builder, graph.get()));
}
- TF_ASSERT_OK(MarkForCompilation(&graph, &flib_def));
+ TF_ASSERT_OK(
+ MarkForCompilationPassTestHelper::MarkForCompilation(&graph, &flib_def));
auto clusters = GetClusters(*graph);
EXPECT_EQ(2, clusters.size());
@@ -272,7 +252,7 @@ TEST(XlaCompilationTest, MetadataOpsDontStartClusters) {
ops::UnaryOp("Shape", d, builder.opts().WithName("E"));
TF_EXPECT_OK(GraphDefBuilderToGraph(builder, graph.get()));
}
- TF_ASSERT_OK(MarkForCompilation(&graph));
+ TF_ASSERT_OK(MarkForCompilationPassTestHelper::MarkForCompilation(&graph));
auto clusters = GetClusters(*graph);
EXPECT_EQ(0, clusters.size()); // Nothing should be compiled.
}
@@ -359,7 +339,7 @@ TEST(XlaCompilationTest, SymbolicGradients) {
TF_EXPECT_OK(GraphDefBuilderToGraph(builder, graph.get()));
}
- TF_ASSERT_OK(MarkForCompilation(&graph));
+ TF_ASSERT_OK(MarkForCompilationPassTestHelper::MarkForCompilation(&graph));
auto clusters = GetClusters(*graph);
EXPECT_EQ(2, clusters.size());
@@ -384,7 +364,7 @@ TEST(XlaCompilationTest, Loops) {
std::unique_ptr<Graph> graph(new Graph(OpRegistry::Global()));
TF_EXPECT_OK(root.ToGraph(graph.get()));
- TF_ASSERT_OK(MarkForCompilation(&graph));
+ TF_ASSERT_OK(MarkForCompilationPassTestHelper::MarkForCompilation(&graph));
auto clusters = GetClusters(*graph);
// Nothing should be compiled. In particular, 'd' and 'c' must not be
@@ -411,7 +391,7 @@ TEST(XlaCompilationTest, CyclesWithAllDifferentScopes) {
TF_CHECK_OK(GraphDefBuilderToGraph(builder, graph.get()));
}
- TF_ASSERT_OK(MarkForCompilation(&graph));
+ TF_ASSERT_OK(MarkForCompilationPassTestHelper::MarkForCompilation(&graph));
auto clusters = GetClusters(*graph);
// The computation is: C = A + relu(A)
@@ -442,7 +422,7 @@ TEST(XlaCompilationTest, CyclesWithSplittingScopes) {
TF_CHECK_OK(GraphDefBuilderToGraph(builder, graph.get()));
}
- TF_ASSERT_OK(MarkForCompilation(&graph));
+ TF_ASSERT_OK(MarkForCompilationPassTestHelper::MarkForCompilation(&graph));
auto clusters = GetClusters(*graph);
// The computation is: D = relu(A) + (A @ relu(A))
@@ -472,7 +452,7 @@ TEST(XlaCompilationTest, CyclesWithDifferentScopesAndBridge) {
TF_CHECK_OK(GraphDefBuilderToGraph(builder, graph.get()));
}
- TF_ASSERT_OK(MarkForCompilation(&graph));
+ TF_ASSERT_OK(MarkForCompilationPassTestHelper::MarkForCompilation(&graph));
auto clusters = GetClusters(*graph);
// The computation is: C = A @ relu(A)
@@ -512,7 +492,7 @@ TEST(XlaCompilationTest, Resources) {
ops::UnaryOp("Relu", d, builder.opts().WithName("E"));
TF_EXPECT_OK(GraphDefBuilderToGraph(builder, graph.get()));
}
- TF_ASSERT_OK(MarkForCompilation(&graph));
+ TF_ASSERT_OK(MarkForCompilationPassTestHelper::MarkForCompilation(&graph));
auto clusters = GetClusters(*graph);
EXPECT_EQ(0, clusters.size()); // Nothing should be compiled.
}
@@ -542,7 +522,7 @@ TEST(XlaCompilationTest, IllegalCycle_UsefulErrorMessage) {
TF_EXPECT_OK(root.ToGraph(graph.get()));
- Status status = MarkForCompilation(&graph);
+ Status status = MarkForCompilationPassTestHelper::MarkForCompilation(&graph);
EXPECT_FALSE(status.ok());
EXPECT_TRUE(str_util::StrContains(status.ToString(),
"Edge from c to a would create a cycle.\n"
@@ -570,7 +550,7 @@ TEST(XlaCompilationTest, Retval) {
TF_EXPECT_OK(GraphDefBuilderToGraph(builder, graph.get()));
}
- TF_ASSERT_OK(MarkForCompilation(&graph));
+ TF_ASSERT_OK(MarkForCompilationPassTestHelper::MarkForCompilation(&graph));
auto clusters = GetClusters(*graph);
EXPECT_EQ(2, clusters.size());
@@ -588,7 +568,7 @@ TEST(XlaCompilationTest, DontCountIdentityOps) {
auto r = ops::_Retval(root.WithOpName("R"), c, 0);
}
TF_ASSERT_OK(root.ToGraph(graph.get()));
- TF_ASSERT_OK(MarkForCompilation(&graph));
+ TF_ASSERT_OK(MarkForCompilationPassTestHelper::MarkForCompilation(&graph));
auto clusters = GetClusters(*graph);
EXPECT_TRUE(clusters.empty());
@@ -604,7 +584,7 @@ TEST(XlaCompilationTest, DontCountIdentityOpsWithLocalJit) {
auto r = ops::_Retval(root.WithOpName("R"), b, 0);
}
TF_ASSERT_OK(root.ToGraph(graph.get()));
- TF_ASSERT_OK(MarkForCompilation(&graph));
+ TF_ASSERT_OK(MarkForCompilationPassTestHelper::MarkForCompilation(&graph));
auto clusters = GetClusters(*graph);
EXPECT_TRUE(clusters.empty());
@@ -618,7 +598,7 @@ TEST(XlaCompilationTest, ConstOp) {
auto c = ops::Const(root.WithOpName("const"), 0.5f);
c.node()->AddAttr(kXlaCompileAttr, true);
TF_ASSERT_OK(root.ToGraph(graph.get()));
- TF_ASSERT_OK(MarkForCompilation(&graph));
+ TF_ASSERT_OK(MarkForCompilationPassTestHelper::MarkForCompilation(&graph));
EXPECT_EQ(1, GetClusters(*graph).size());
}
@@ -629,7 +609,7 @@ TEST(XlaCompilationTest, ConstOp) {
auto c = ops::Const(root.WithOpName("const"), string("string"));
c.node()->AddAttr(kXlaCompileAttr, true);
TF_ASSERT_OK(root.ToGraph(graph.get()));
- TF_ASSERT_OK(MarkForCompilation(&graph));
+ TF_ASSERT_OK(MarkForCompilationPassTestHelper::MarkForCompilation(&graph));
EXPECT_TRUE(GetClusters(*graph).empty());
}
}
@@ -644,7 +624,7 @@ TEST(XlaCompilationTest, DontClusterIdentityWithRefInput) {
std::unique_ptr<Graph> graph(new Graph(OpRegistry::Global()));
TF_ASSERT_OK(root.ToGraph(graph.get()));
- TF_ASSERT_OK(MarkForCompilation(&graph));
+ TF_ASSERT_OK(MarkForCompilationPassTestHelper::MarkForCompilation(&graph));
std::unordered_map<string, string> clusters = GetClusters(*graph);
@@ -667,7 +647,7 @@ TEST(XlaCompilationTest, ClusterIdentityWithNonRefInput) {
std::unique_ptr<Graph> graph(new Graph(OpRegistry::Global()));
TF_ASSERT_OK(root.ToGraph(graph.get()));
- TF_ASSERT_OK(MarkForCompilation(&graph));
+ TF_ASSERT_OK(MarkForCompilationPassTestHelper::MarkForCompilation(&graph));
std::unordered_map<string, string> clusters = GetClusters(*graph);
@@ -699,7 +679,7 @@ TEST(XlaCompilationTest, ClusterControlTrigger) {
std::unique_ptr<Graph> graph(new Graph(OpRegistry::Global()));
TF_ASSERT_OK(root.ToGraph(graph.get()));
- TF_ASSERT_OK(MarkForCompilation(&graph));
+ TF_ASSERT_OK(MarkForCompilationPassTestHelper::MarkForCompilation(&graph));
std::unordered_map<string, string> clusters = GetClusters(*graph);