author     A. Unique TensorFlower <gardener@tensorflow.org>  2018-10-01 10:41:58 -0700
committer  TensorFlower Gardener <gardener@tensorflow.org>   2018-10-01 10:48:05 -0700
commit     57a831d20929e71279d164905fed93e1f518ee37 (patch)
tree       ddf8a2f66aa82f64ce2091348efe83a9ca467efc /tensorflow/compiler/jit
parent     a6478312ef296ba9684931135851e9c7bb460444 (diff)
Bugfix: When a subgraph is encapsulated and replaced by an XlaLaunch op, the requested device placement of the XlaLaunch op must be derived from the subgraph.
PiperOrigin-RevId: 215239672
Diffstat (limited to 'tensorflow/compiler/jit')
-rw-r--r--  tensorflow/compiler/jit/encapsulate_subgraphs_pass.cc              6
-rw-r--r--  tensorflow/compiler/jit/encapsulate_xla_computations_pass.cc       2
-rw-r--r--  tensorflow/compiler/jit/encapsulate_xla_computations_pass_test.cc  9
3 files changed, 14 insertions, 3 deletions
diff --git a/tensorflow/compiler/jit/encapsulate_subgraphs_pass.cc b/tensorflow/compiler/jit/encapsulate_subgraphs_pass.cc
index e0632ff7e4..15faf31077 100644
--- a/tensorflow/compiler/jit/encapsulate_subgraphs_pass.cc
+++ b/tensorflow/compiler/jit/encapsulate_subgraphs_pass.cc
@@ -748,6 +748,12 @@ Node* Encapsulator::Subgraph::MakeNodeImage(const Graph* graph_in, Node* node) {
graph_->set_versions(graph_in->versions());
}
+ // TODO(b/116981129): Enhance how the device for the encapsulated subgraph is
+ // determined. In case of hard placement, ensure all the encapsulated nodes
+ // have the same requested device, which in turn will be the requested device
+ // for the entire encapsulated subgraph. In case of soft placement, use a
+ // deterministic approach to fill in the requested device. Handle co-location
+ // constraints similarly if they exist.
if (device_.empty()) {
device_ = node->assigned_device_name().empty()
? node->requested_device()
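
The TODO above describes deriving a single requested device for the whole encapsulated subgraph. A minimal sketch of that idea for the hard-placement case, where every member node must agree on a device (the helper name and signature below are illustrative, not part of this change):

#include <string>
#include <vector>

#include "tensorflow/core/graph/graph.h"

// Hypothetical helper, not part of this commit: pick the requested device for
// an encapsulated subgraph from its member nodes. Prefers an assigned device
// over a requested one, mirroring the fallback in MakeNodeImage above, and
// returns an empty string if the nodes disagree.
std::string DeriveSubgraphDevice(const std::vector<tensorflow::Node*>& nodes) {
  std::string device;
  for (const tensorflow::Node* node : nodes) {
    const std::string& d = node->assigned_device_name().empty()
                               ? node->requested_device()
                               : node->assigned_device_name();
    if (d.empty()) continue;
    if (device.empty()) {
      device = d;  // First placement seen becomes the candidate.
    } else if (device != d) {
      return "";  // Conflicting placements: no single device can be derived.
    }
  }
  return device;
}
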
diff --git a/tensorflow/compiler/jit/encapsulate_xla_computations_pass.cc b/tensorflow/compiler/jit/encapsulate_xla_computations_pass.cc
index 97ef8cd3cb..755c364c62 100644
--- a/tensorflow/compiler/jit/encapsulate_xla_computations_pass.cc
+++ b/tensorflow/compiler/jit/encapsulate_xla_computations_pass.cc
@@ -297,7 +297,9 @@ Status RewriteSubgraph(const std::vector<OutputTensor>& arg_source_tensors,
// Target the XLA CPU/GPU backends.
VLOG(2) << "Replacing with XlaLaunch";
+ VLOG(2) << "Device is " << launch->requested_device();
def.set_op("XlaLaunch");
+ def.set_device(launch->requested_device());
AddNodeAttr("Tconstants", DataTypeVector{}, &def);
AddNodeAttr("Targs", arg_types, &def);
AddNodeAttr("Nresources", num_variables, &def);
diff --git a/tensorflow/compiler/jit/encapsulate_xla_computations_pass_test.cc b/tensorflow/compiler/jit/encapsulate_xla_computations_pass_test.cc
index f643fb0cfe..479038ac8e 100644
--- a/tensorflow/compiler/jit/encapsulate_xla_computations_pass_test.cc
+++ b/tensorflow/compiler/jit/encapsulate_xla_computations_pass_test.cc
@@ -55,6 +55,7 @@ static std::unique_ptr<Graph> MakeOuterGraph(
.Input(u.node()->name(), 0, DT_RESOURCE)
.Input(v.node()->name(), 0, DT_RESOURCE)
.Input(w.node()->name(), 0, DT_RESOURCE)
+ .Device("/gpu:0")
.Attr(EncapsulateXlaComputationsPass::kXlaClusterAttr, "launch0")
.Attr("_variable_start_index", 4)
.Finalize(&def));
@@ -107,10 +108,11 @@ static std::unique_ptr<Graph> MakeBodyGraph() {
auto add_attrs = [](Node* node) {
node->AddAttr(EncapsulateXlaComputationsPass::kXlaClusterAttr, "launch0");
+ node->set_requested_device("/gpu:0");
};
auto b_identity = ops::Identity(scope.WithOpName("B_identity"), arg1);
-
+ add_attrs(b_identity.node());
auto read_u = ops::ReadVariableOp(scope.WithOpName("ReadU"), arg4, DT_FLOAT);
add_attrs(read_u.node());
auto read_v = ops::ReadVariableOp(scope.WithOpName("ReadV"), arg5, DT_FLOAT);
@@ -215,6 +217,7 @@ TEST(EncapsulateXlaComputations, Encapsulate) {
auto add_attrs = [](Node* node) {
node->AddAttr(EncapsulateXlaComputationsPass::kXlaClusterAttr, "launch0");
+ node->set_requested_device("/gpu:0");
};
auto b_identity = ops::Identity(scope.WithOpName("B_identity"), b);
@@ -317,8 +320,8 @@ TEST(EncapsulateXlaComputations, BuildXlaLaunchOp) {
NameAttrList function;
function.set_name("launch0");
auto launch = ops::XlaLaunch(
- scope.WithOpName("launch0"), std::initializer_list<Input>{},
- std::initializer_list<Input>{a, b, c, d},
+ scope.WithOpName("launch0").WithDevice("/gpu:0"),
+ std::initializer_list<Input>{}, std::initializer_list<Input>{a, b, c, d},
std::initializer_list<Input>{u, v, w},
DataTypeVector{DT_FLOAT, DT_INT32, DT_FLOAT, DT_FLOAT}, function);
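
As a usage note, the test's expected graph now builds its nodes with an explicit device so that the graph comparison also checks placement. A standalone sketch of the same pattern with NodeDefBuilder (the node and input names here are made up for illustration):

#include "tensorflow/core/framework/node_def.pb.h"
#include "tensorflow/core/framework/node_def_builder.h"
#include "tensorflow/core/framework/types.h"
#include "tensorflow/core/lib/core/status.h"

// Illustrative sketch: a device requested through NodeDefBuilder::Device()
// lands in NodeDef::device(), which is what graph equality in the test sees.
tensorflow::Status BuildIdentityOnGpu(tensorflow::NodeDef* def) {
  return tensorflow::NodeDefBuilder("B_identity", "Identity")
      .Input("b", 0, tensorflow::DT_FLOAT)
      .Device("/gpu:0")
      .Finalize(def);
}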