aboutsummaryrefslogtreecommitdiffhomepage
path: root/tensorflow/compiler/xla/service/hlo_parser.cc
diff options
context:
space:
mode:
authorGravatar Mark Heffernan <meheff@google.com>2018-07-17 18:01:52 -0700
committerGravatar TensorFlower Gardener <gardener@tensorflow.org>2018-07-17 18:07:00 -0700
commit07cc6474b219ee3ad9f55860e621f61b34bb6bd1 (patch)
treec09bbe69d49db3b6a6acd6f38545915ae2369035 /tensorflow/compiler/xla/service/hlo_parser.cc
parent1662a105497e60d002e101161987cbbd48ba06c6 (diff)
Add single-sided host send and receive operations.
Adds a bit on kSend/kReceive instructions and their Done variants indicated whether the operations communicates with the host or another device (the default). Host send/recv operations are single-sided without a complementary recv/send instruction in another module. Host send/recv operations are exposed in the XLA builder API as SendToHost and RecvFromHost. PiperOrigin-RevId: 205008138
Diffstat (limited to 'tensorflow/compiler/xla/service/hlo_parser.cc')
-rw-r--r--tensorflow/compiler/xla/service/hlo_parser.cc40
1 files changed, 33 insertions, 7 deletions
diff --git a/tensorflow/compiler/xla/service/hlo_parser.cc b/tensorflow/compiler/xla/service/hlo_parser.cc
index d387539350..496eca0739 100644
--- a/tensorflow/compiler/xla/service/hlo_parser.cc
+++ b/tensorflow/compiler/xla/service/hlo_parser.cc
@@ -119,6 +119,7 @@ class HloParser {
// Types of attributes.
enum class AttrTy {
+ kBool,
kInt64,
kInt32,
kFloat,
@@ -681,18 +682,27 @@ bool HloParser::ParseInstruction(HloComputation::Builder* builder,
}
case HloOpcode::kRecv: {
optional<tensorflow::int64> channel_id;
+ // If the is_host_transfer attribute is not present then default to false.
+ optional<bool> is_host_transfer = false;
attrs["channel_id"] = {/*required=*/true, AttrTy::kInt64, &channel_id};
+ attrs["is_host_transfer"] = {/*required=*/false, AttrTy::kBool,
+ &is_host_transfer};
if (!ParseOperands(&operands, /*expected_size=*/1) ||
!ParseAttributes(attrs)) {
return false;
}
+ // If the is_host_transfer attribute is not present then default to false.
instruction = builder->AddInstruction(HloInstruction::CreateRecv(
- shape.tuple_shapes(0), operands[0], *channel_id));
+ shape.tuple_shapes(0), operands[0], *channel_id, *is_host_transfer));
break;
}
case HloOpcode::kRecvDone: {
optional<tensorflow::int64> channel_id;
+ // If the is_host_transfer attribute is not present then default to false.
+ optional<bool> is_host_transfer = false;
attrs["channel_id"] = {/*required=*/true, AttrTy::kInt64, &channel_id};
+ attrs["is_host_transfer"] = {/*required=*/false, AttrTy::kBool,
+ &is_host_transfer};
if (!ParseOperands(&operands, /*expected_size=*/1) ||
!ParseAttributes(attrs)) {
return false;
@@ -700,24 +710,32 @@ bool HloParser::ParseInstruction(HloComputation::Builder* builder,
if (channel_id != operands[0]->channel_id()) {
return false;
}
- instruction =
- builder->AddInstruction(HloInstruction::CreateRecvDone(operands[0]));
+ instruction = builder->AddInstruction(
+ HloInstruction::CreateRecvDone(operands[0], *is_host_transfer));
break;
}
case HloOpcode::kSend: {
optional<tensorflow::int64> channel_id;
+ // If the is_host_transfer attribute is not present then default to false.
+ optional<bool> is_host_transfer = false;
attrs["channel_id"] = {/*required=*/true, AttrTy::kInt64, &channel_id};
+ attrs["is_host_transfer"] = {/*required=*/false, AttrTy::kBool,
+ &is_host_transfer};
if (!ParseOperands(&operands, /*expected_size=*/2) ||
!ParseAttributes(attrs)) {
return false;
}
- instruction = builder->AddInstruction(
- HloInstruction::CreateSend(operands[0], operands[1], *channel_id));
+ instruction = builder->AddInstruction(HloInstruction::CreateSend(
+ operands[0], operands[1], *channel_id, *is_host_transfer));
break;
}
case HloOpcode::kSendDone: {
optional<tensorflow::int64> channel_id;
+ // If the is_host_transfer attribute is not present then default to false.
+ optional<bool> is_host_transfer = false;
attrs["channel_id"] = {/*required=*/true, AttrTy::kInt64, &channel_id};
+ attrs["is_host_transfer"] = {/*required=*/false, AttrTy::kBool,
+ &is_host_transfer};
if (!ParseOperands(&operands, /*expected_size=*/1) ||
!ParseAttributes(attrs)) {
return false;
@@ -725,8 +743,8 @@ bool HloParser::ParseInstruction(HloComputation::Builder* builder,
if (channel_id != operands[0]->channel_id()) {
return false;
}
- instruction =
- builder->AddInstruction(HloInstruction::CreateSendDone(operands[0]));
+ instruction = builder->AddInstruction(
+ HloInstruction::CreateSendDone(operands[0], *is_host_transfer));
break;
}
case HloOpcode::kGetTupleElement: {
@@ -2043,6 +2061,14 @@ bool HloParser::ParseAttributeHelper(
bool success = [&] {
LocTy attr_loc = lexer_.GetLoc();
switch (attr_type) {
+ case AttrTy::kBool: {
+ bool result;
+ if (!ParseBool(&result)) {
+ return false;
+ }
+ static_cast<optional<bool>*>(attr_out_ptr)->emplace(result);
+ return true;
+ }
case AttrTy::kInt64: {
tensorflow::int64 result;
if (!ParseInt64(&result)) {