aboutsummaryrefslogtreecommitdiffhomepage
path: root/tensorflow/compiler/jit/mark_for_compilation_pass_test.cc
diff options
context:
space:
mode:
authorGravatar Sanjoy Das <sanjoy@google.com>2018-09-18 15:42:44 -0700
committerGravatar TensorFlower Gardener <gardener@tensorflow.org>2018-09-18 15:46:41 -0700
commite1a32c98210f8ebba42a0397259d948e1433c09e (patch)
tree0a62289fe29cf2c0f481bdca90c03811a47caadf /tensorflow/compiler/jit/mark_for_compilation_pass_test.cc
parent6c8f6920e8bad10429ac0b88abbe0ace5a5e9a72 (diff)
"Isolate" must-be-constant side effecting operations
I first tried to fix this issue in cr/209996730 but didn't quite fix the problem for XLA_* devices. A node assigned to an XLA_* device must be compiled so the cr/209996730 fix of simply not compiling the nodes doesn't generalize to XLA_* devices. Instead we now "isolate" these nodes, only putting them in a trivial one-node cluster. For non-XLA devices even this trivial cluster is ignored because of flags->tf_xla_min_cluster_size. I was initially considering a more principled data-flow-analysis based solution but then decided the upfront work isn't worth it until I see a clear motivating example. PiperOrigin-RevId: 213531437
Diffstat (limited to 'tensorflow/compiler/jit/mark_for_compilation_pass_test.cc')
-rw-r--r--tensorflow/compiler/jit/mark_for_compilation_pass_test.cc66
1 files changed, 66 insertions, 0 deletions
diff --git a/tensorflow/compiler/jit/mark_for_compilation_pass_test.cc b/tensorflow/compiler/jit/mark_for_compilation_pass_test.cc
index c59770a4c8..4f9145b479 100644
--- a/tensorflow/compiler/jit/mark_for_compilation_pass_test.cc
+++ b/tensorflow/compiler/jit/mark_for_compilation_pass_test.cc
@@ -894,5 +894,71 @@ TEST(XlaCompilationTest, RandomShapeWithFunc) {
EXPECT_EQ(clusters["fn_call"], "");
}
+TEST(XlaCompilationTest, RandomShapeOnXlaDevice) {
+ absl::string_view xla_gpu_device =
+ "/job:worker/replica:0/task:0/device:XLA_GPU:0";
+
+ Scope root = Scope::NewRootScope().ExitOnError();
+ Output shape_shape =
+ ops::Const(root.WithOpName("test/shape_shape"), {2}, {1});
+ Output shape =
+ ops::RandomUniformInt(root.WithOpName("test/shape_rng"), shape_shape,
+ ops::Const(root.WithOpName("test/minval"), 1),
+ ops::Const(root.WithOpName("test/maxval"), 20));
+ Output reshape_input =
+ ops::Placeholder(root.WithOpName("test/reshape_input"), DT_FLOAT,
+ ops::Placeholder::Shape(TensorShape({500, 500})));
+ Output reshape =
+ ops::Reshape(root.WithOpName("test/reshape"), reshape_input, shape);
+
+ std::unique_ptr<Graph> graph(new Graph(OpRegistry::Global()));
+ TF_ASSERT_OK(root.ToGraph(graph.get()));
+
+ for (Node* n : graph->nodes()) {
+ if (absl::StartsWith(n->name(), /*prefix=*/"test/")) {
+ n->set_assigned_device_name(string(xla_gpu_device));
+ }
+ }
+ TF_ASSERT_OK(MarkForCompilationPassTestHelper::MarkForCompilation(&graph));
+
+ std::unordered_map<string, string> clusters = GetClusters(*graph);
+ EXPECT_NE(clusters["test/shape_rng"], "");
+ EXPECT_NE(clusters["test/reshape"], "");
+ EXPECT_NE(clusters["test/shape_rng"], clusters["test/reshape"]);
+}
+
+TEST(XlaCompilationTest, TensorArrayShapeOnXlaDevice) {
+ absl::string_view xla_gpu_device =
+ "/job:worker/replica:0/task:0/device:XLA_GPU:0";
+ Scope root = Scope::NewRootScope().ExitOnError();
+ ops::TensorArray tensor_array(root.WithOpName("test/tensor_array"), 1,
+ DT_INT32);
+ Output zero = ops::Const(root.WithOpName("test/zero"), 0);
+ ops::TensorArrayWrite tensor_array_write(
+ root.WithOpName("test/write"), tensor_array.handle, zero,
+ ops::Const(root.WithOpName("test/forty_two"), 42.0f), tensor_array.flow);
+ Output tensor_array_read =
+ ops::TensorArrayRead(root.WithOpName("test/read"), tensor_array.handle,
+ zero, tensor_array_write.flow_out, DT_INT32);
+ Output reshape =
+ ops::Reshape(root.WithOpName("test/reshape"),
+ ops::Placeholder(root.WithOpName("placeholder"), DT_FLOAT),
+ tensor_array_read);
+
+ std::unique_ptr<Graph> graph(new Graph(OpRegistry::Global()));
+ TF_ASSERT_OK(root.ToGraph(graph.get()));
+
+ for (Node* n : graph->nodes()) {
+ if (absl::StartsWith(n->name(), /*prefix=*/"test/")) {
+ n->set_assigned_device_name(string(xla_gpu_device));
+ }
+ }
+ TF_ASSERT_OK(MarkForCompilationPassTestHelper::MarkForCompilation(&graph));
+
+ std::unordered_map<string, string> clusters = GetClusters(*graph);
+ EXPECT_NE(clusters["test/read"], "");
+ EXPECT_EQ(clusters["test/read"], clusters["test/reshape"]);
+}
+
} // namespace
} // namespace tensorflow