diff options
author | 2018-10-01 10:41:58 -0700 | |
---|---|---|
committer | 2018-10-01 10:48:05 -0700 | |
commit | 57a831d20929e71279d164905fed93e1f518ee37 (patch) | |
tree | ddf8a2f66aa82f64ce2091348efe83a9ca467efc /tensorflow/compiler/jit | |
parent | a6478312ef296ba9684931135851e9c7bb460444 (diff) |
Bugfix: When a subgraph is encapsulated and replaced by an XlaLaunch op, the requested device placement of the XlaLaunch op must be derived from the subgraph.
PiperOrigin-RevId: 215239672
Diffstat (limited to 'tensorflow/compiler/jit')
3 files changed, 14 insertions, 3 deletions
diff --git a/tensorflow/compiler/jit/encapsulate_subgraphs_pass.cc b/tensorflow/compiler/jit/encapsulate_subgraphs_pass.cc index e0632ff7e4..15faf31077 100644 --- a/tensorflow/compiler/jit/encapsulate_subgraphs_pass.cc +++ b/tensorflow/compiler/jit/encapsulate_subgraphs_pass.cc @@ -748,6 +748,12 @@ Node* Encapsulator::Subgraph::MakeNodeImage(const Graph* graph_in, Node* node) { graph_->set_versions(graph_in->versions()); } + // TODO(b/116981129): Enhance how the device for the encapsulated subgraph is + // determined. In case of hard placement, ensure all the encapsulated nodes + // have the same requested device, which in turn will be the requested device + // for the entire encapsulated subgraph. In case of soft placement, use a + // deterministic approach to fill in the requested device. Handle co-location + // constraints similarly if they exist. if (device_.empty()) { device_ = node->assigned_device_name().empty() ? node->requested_device() diff --git a/tensorflow/compiler/jit/encapsulate_xla_computations_pass.cc b/tensorflow/compiler/jit/encapsulate_xla_computations_pass.cc index 97ef8cd3cb..755c364c62 100644 --- a/tensorflow/compiler/jit/encapsulate_xla_computations_pass.cc +++ b/tensorflow/compiler/jit/encapsulate_xla_computations_pass.cc @@ -297,7 +297,9 @@ Status RewriteSubgraph(const std::vector<OutputTensor>& arg_source_tensors, // Target the XLA CPU/GPU backends. 
VLOG(2) << "Replacing with XlaLaunch"; + VLOG(2) << "Device is " << launch->requested_device(); def.set_op("XlaLaunch"); + def.set_device(launch->requested_device()); AddNodeAttr("Tconstants", DataTypeVector{}, &def); AddNodeAttr("Targs", arg_types, &def); AddNodeAttr("Nresources", num_variables, &def); diff --git a/tensorflow/compiler/jit/encapsulate_xla_computations_pass_test.cc b/tensorflow/compiler/jit/encapsulate_xla_computations_pass_test.cc index f643fb0cfe..479038ac8e 100644 --- a/tensorflow/compiler/jit/encapsulate_xla_computations_pass_test.cc +++ b/tensorflow/compiler/jit/encapsulate_xla_computations_pass_test.cc @@ -55,6 +55,7 @@ static std::unique_ptr<Graph> MakeOuterGraph( .Input(u.node()->name(), 0, DT_RESOURCE) .Input(v.node()->name(), 0, DT_RESOURCE) .Input(w.node()->name(), 0, DT_RESOURCE) + .Device("/gpu:0") .Attr(EncapsulateXlaComputationsPass::kXlaClusterAttr, "launch0") .Attr("_variable_start_index", 4) .Finalize(&def)); @@ -107,10 +108,11 @@ static std::unique_ptr<Graph> MakeBodyGraph() { auto add_attrs = [](Node* node) { node->AddAttr(EncapsulateXlaComputationsPass::kXlaClusterAttr, "launch0"); + node->set_requested_device("/gpu:0"); }; auto b_identity = ops::Identity(scope.WithOpName("B_identity"), arg1); - + add_attrs(b_identity.node()); auto read_u = ops::ReadVariableOp(scope.WithOpName("ReadU"), arg4, DT_FLOAT); add_attrs(read_u.node()); auto read_v = ops::ReadVariableOp(scope.WithOpName("ReadV"), arg5, DT_FLOAT); @@ -215,6 +217,7 @@ TEST(EncapsulateXlaComputations, Encapsulate) { auto add_attrs = [](Node* node) { node->AddAttr(EncapsulateXlaComputationsPass::kXlaClusterAttr, "launch0"); + node->set_requested_device("/gpu:0"); }; auto b_identity = ops::Identity(scope.WithOpName("B_identity"), b); @@ -317,8 +320,8 @@ TEST(EncapsulateXlaComputations, BuildXlaLaunchOp) { NameAttrList function; function.set_name("launch0"); auto launch = ops::XlaLaunch( - scope.WithOpName("launch0"), std::initializer_list<Input>{}, - 
std::initializer_list<Input>{a, b, c, d}, + scope.WithOpName("launch0").WithDevice("/gpu:0"), + std::initializer_list<Input>{}, std::initializer_list<Input>{a, b, c, d}, std::initializer_list<Input>{u, v, w}, DataTypeVector{DT_FLOAT, DT_INT32, DT_FLOAT, DT_FLOAT}, function); |