aboutsummaryrefslogtreecommitdiffhomepage
path: root/tensorflow/compiler/xla/service/hlo_tfgraph_builder.cc
diff options
context:
space:
mode:
authorGravatar A. Unique TensorFlower <gardener@tensorflow.org>2017-10-30 14:05:29 -0700
committerGravatar TensorFlower Gardener <gardener@tensorflow.org>2017-10-30 14:09:44 -0700
commitefcbf6e34e4519172d38be76c08c2d99792fd7be (patch)
treed2226010e7fd0548f9ca137052959fd55006cf1e /tensorflow/compiler/xla/service/hlo_tfgraph_builder.cc
parent682a6ed64f961d73ecdde5c3b80c6188fedcf5ee (diff)
Supported in this CL:
* Attaching sharding descriptors to HLO ops * Partitioning the HLO graph into per-device computations based on those sharding descriptors. * All operator support for device placement and ops replicated on all devices. * Elementwise op support for tiled shardings. * 2D Convolution support for tiled shardings (no stride or dilation support). PiperOrigin-RevId: 173946036
Diffstat (limited to 'tensorflow/compiler/xla/service/hlo_tfgraph_builder.cc')
-rw-r--r--tensorflow/compiler/xla/service/hlo_tfgraph_builder.cc7
1 files changed, 4 insertions, 3 deletions
diff --git a/tensorflow/compiler/xla/service/hlo_tfgraph_builder.cc b/tensorflow/compiler/xla/service/hlo_tfgraph_builder.cc
index 2007a8f11d..06abe00747 100644
--- a/tensorflow/compiler/xla/service/hlo_tfgraph_builder.cc
+++ b/tensorflow/compiler/xla/service/hlo_tfgraph_builder.cc
@@ -198,9 +198,10 @@ Status HloTfGraphBuilder::AddInstruction(const HloInstruction* instruction) {
NodeDef* node_def = graph_def_.add_node();
node_def->set_name(GetNodeNameForInstruction(instruction));
node_def->set_op(GetOpDefName(instruction));
- if (instruction->device_assignment().has_device()) {
- node_def->set_device(
- GetDeviceName(instruction->device_assignment().device()));
+ if (instruction->has_sharding() &&
+ instruction->sharding().HasUniqueDevice()) {
+ TF_ASSIGN_OR_RETURN(int64 device, instruction->sharding().UniqueDevice());
+ node_def->set_device(GetDeviceName(device));
}
SetNodeAttrs(instruction, node_def);
if (instruction->opcode() == HloOpcode::kFusion) {