diff options
author | 2018-09-18 15:42:44 -0700 | |
---|---|---|
committer | 2018-09-18 15:46:41 -0700 | |
commit | e1a32c98210f8ebba42a0397259d948e1433c09e (patch) | |
tree | 0a62289fe29cf2c0f481bdca90c03811a47caadf /tensorflow/compiler/jit/mark_for_compilation_pass_test.cc | |
parent | 6c8f6920e8bad10429ac0b88abbe0ace5a5e9a72 (diff) |
"Isolate" must-be-constant side effecting operations
I first tried to fix this issue in cr/209996730 but didn't quite fix the problem
for XLA_* devices. A node assigned to an XLA_* device must be compiled so
the cr/209996730 fix of simply not compiling the nodes doesn't generalize to
XLA_* devices. Instead we now "isolate" these nodes, only putting them in a
trivial one-node cluster. For non-XLA devices even this trivial cluster is
ignored because of flags->tf_xla_min_cluster_size.
I was initially considering a more principled data-flow-analysis based solution
but then decided the upfront work isn't worth it until I see a clear motivating
example.
PiperOrigin-RevId: 213531437
Diffstat (limited to 'tensorflow/compiler/jit/mark_for_compilation_pass_test.cc')
-rw-r--r-- | tensorflow/compiler/jit/mark_for_compilation_pass_test.cc | 66 |
1 files changed, 66 insertions, 0 deletions
diff --git a/tensorflow/compiler/jit/mark_for_compilation_pass_test.cc b/tensorflow/compiler/jit/mark_for_compilation_pass_test.cc index c59770a4c8..4f9145b479 100644 --- a/tensorflow/compiler/jit/mark_for_compilation_pass_test.cc +++ b/tensorflow/compiler/jit/mark_for_compilation_pass_test.cc @@ -894,5 +894,71 @@ TEST(XlaCompilationTest, RandomShapeWithFunc) { EXPECT_EQ(clusters["fn_call"], ""); } +TEST(XlaCompilationTest, RandomShapeOnXlaDevice) { + absl::string_view xla_gpu_device = + "/job:worker/replica:0/task:0/device:XLA_GPU:0"; + + Scope root = Scope::NewRootScope().ExitOnError(); + Output shape_shape = + ops::Const(root.WithOpName("test/shape_shape"), {2}, {1}); + Output shape = + ops::RandomUniformInt(root.WithOpName("test/shape_rng"), shape_shape, + ops::Const(root.WithOpName("test/minval"), 1), + ops::Const(root.WithOpName("test/maxval"), 20)); + Output reshape_input = + ops::Placeholder(root.WithOpName("test/reshape_input"), DT_FLOAT, + ops::Placeholder::Shape(TensorShape({500, 500}))); + Output reshape = + ops::Reshape(root.WithOpName("test/reshape"), reshape_input, shape); + + std::unique_ptr<Graph> graph(new Graph(OpRegistry::Global())); + TF_ASSERT_OK(root.ToGraph(graph.get())); + + for (Node* n : graph->nodes()) { + if (absl::StartsWith(n->name(), /*prefix=*/"test/")) { + n->set_assigned_device_name(string(xla_gpu_device)); + } + } + TF_ASSERT_OK(MarkForCompilationPassTestHelper::MarkForCompilation(&graph)); + + std::unordered_map<string, string> clusters = GetClusters(*graph); + EXPECT_NE(clusters["test/shape_rng"], ""); + EXPECT_NE(clusters["test/reshape"], ""); + EXPECT_NE(clusters["test/shape_rng"], clusters["test/reshape"]); +} + +TEST(XlaCompilationTest, TensorArrayShapeOnXlaDevice) { + absl::string_view xla_gpu_device = + "/job:worker/replica:0/task:0/device:XLA_GPU:0"; + Scope root = Scope::NewRootScope().ExitOnError(); + ops::TensorArray tensor_array(root.WithOpName("test/tensor_array"), 1, + DT_INT32); + Output zero = 
ops::Const(root.WithOpName("test/zero"), 0); + ops::TensorArrayWrite tensor_array_write( + root.WithOpName("test/write"), tensor_array.handle, zero, + ops::Const(root.WithOpName("test/forty_two"), 42.0f), tensor_array.flow); + Output tensor_array_read = + ops::TensorArrayRead(root.WithOpName("test/read"), tensor_array.handle, + zero, tensor_array_write.flow_out, DT_INT32); + Output reshape = + ops::Reshape(root.WithOpName("test/reshape"), + ops::Placeholder(root.WithOpName("placeholder"), DT_FLOAT), + tensor_array_read); + + std::unique_ptr<Graph> graph(new Graph(OpRegistry::Global())); + TF_ASSERT_OK(root.ToGraph(graph.get())); + + for (Node* n : graph->nodes()) { + if (absl::StartsWith(n->name(), /*prefix=*/"test/")) { + n->set_assigned_device_name(string(xla_gpu_device)); + } + } + TF_ASSERT_OK(MarkForCompilationPassTestHelper::MarkForCompilation(&graph)); + + std::unordered_map<string, string> clusters = GetClusters(*graph); + EXPECT_NE(clusters["test/read"], ""); + EXPECT_EQ(clusters["test/read"], clusters["test/reshape"]); +} + } // namespace } // namespace tensorflow |