diff options
author | Mark Heffernan <meheff@google.com> | 2018-07-17 18:01:52 -0700 |
---|---|---|
committer | TensorFlower Gardener <gardener@tensorflow.org> | 2018-07-17 18:07:00 -0700 |
commit | 07cc6474b219ee3ad9f55860e621f61b34bb6bd1 (patch) | |
tree | c09bbe69d49db3b6a6acd6f38545915ae2369035 /tensorflow/compiler/xla/service/hlo_parser.cc | |
parent | 1662a105497e60d002e101161987cbbd48ba06c6 (diff) |
Add single-sided host send and receive operations.
Adds a bit on kSend/kReceive instructions and their Done variants indicated whether the operations communicates with the host or another device (the default). Host send/recv operations are single-sided without a complementary recv/send instruction in another module.
Host send/recv operations are exposed in the XLA builder API as SendToHost and RecvFromHost.
PiperOrigin-RevId: 205008138
Diffstat (limited to 'tensorflow/compiler/xla/service/hlo_parser.cc')
-rw-r--r-- | tensorflow/compiler/xla/service/hlo_parser.cc | 40 |
1 files changed, 33 insertions, 7 deletions
diff --git a/tensorflow/compiler/xla/service/hlo_parser.cc b/tensorflow/compiler/xla/service/hlo_parser.cc index d387539350..496eca0739 100644 --- a/tensorflow/compiler/xla/service/hlo_parser.cc +++ b/tensorflow/compiler/xla/service/hlo_parser.cc @@ -119,6 +119,7 @@ class HloParser { // Types of attributes. enum class AttrTy { + kBool, kInt64, kInt32, kFloat, @@ -681,18 +682,27 @@ bool HloParser::ParseInstruction(HloComputation::Builder* builder, } case HloOpcode::kRecv: { optional<tensorflow::int64> channel_id; + // If the is_host_transfer attribute is not present then default to false. + optional<bool> is_host_transfer = false; attrs["channel_id"] = {/*required=*/true, AttrTy::kInt64, &channel_id}; + attrs["is_host_transfer"] = {/*required=*/false, AttrTy::kBool, + &is_host_transfer}; if (!ParseOperands(&operands, /*expected_size=*/1) || !ParseAttributes(attrs)) { return false; } + // If the is_host_transfer attribute is not present then default to false. instruction = builder->AddInstruction(HloInstruction::CreateRecv( - shape.tuple_shapes(0), operands[0], *channel_id)); + shape.tuple_shapes(0), operands[0], *channel_id, *is_host_transfer)); break; } case HloOpcode::kRecvDone: { optional<tensorflow::int64> channel_id; + // If the is_host_transfer attribute is not present then default to false. + optional<bool> is_host_transfer = false; attrs["channel_id"] = {/*required=*/true, AttrTy::kInt64, &channel_id}; + attrs["is_host_transfer"] = {/*required=*/false, AttrTy::kBool, + &is_host_transfer}; if (!ParseOperands(&operands, /*expected_size=*/1) || !ParseAttributes(attrs)) { return false; @@ -700,24 +710,32 @@ bool HloParser::ParseInstruction(HloComputation::Builder* builder, if (channel_id != operands[0]->channel_id()) { return false; } - instruction = - builder->AddInstruction(HloInstruction::CreateRecvDone(operands[0])); + instruction = builder->AddInstruction( + HloInstruction::CreateRecvDone(operands[0], *is_host_transfer)); break; } case HloOpcode::kSend: { optional<tensorflow::int64> channel_id; + // If the is_host_transfer attribute is not present then default to false. + optional<bool> is_host_transfer = false; attrs["channel_id"] = {/*required=*/true, AttrTy::kInt64, &channel_id}; + attrs["is_host_transfer"] = {/*required=*/false, AttrTy::kBool, + &is_host_transfer}; if (!ParseOperands(&operands, /*expected_size=*/2) || !ParseAttributes(attrs)) { return false; } - instruction = builder->AddInstruction( - HloInstruction::CreateSend(operands[0], operands[1], *channel_id)); + instruction = builder->AddInstruction(HloInstruction::CreateSend( + operands[0], operands[1], *channel_id, *is_host_transfer)); break; } case HloOpcode::kSendDone: { optional<tensorflow::int64> channel_id; + // If the is_host_transfer attribute is not present then default to false. + optional<bool> is_host_transfer = false; attrs["channel_id"] = {/*required=*/true, AttrTy::kInt64, &channel_id}; + attrs["is_host_transfer"] = {/*required=*/false, AttrTy::kBool, + &is_host_transfer}; if (!ParseOperands(&operands, /*expected_size=*/1) || !ParseAttributes(attrs)) { return false; @@ -725,8 +743,8 @@ bool HloParser::ParseInstruction(HloComputation::Builder* builder, if (channel_id != operands[0]->channel_id()) { return false; } - instruction = - builder->AddInstruction(HloInstruction::CreateSendDone(operands[0])); + instruction = builder->AddInstruction( + HloInstruction::CreateSendDone(operands[0], *is_host_transfer)); break; } case HloOpcode::kGetTupleElement: { @@ -2043,6 +2061,14 @@ bool HloParser::ParseAttributeHelper( bool success = [&] { LocTy attr_loc = lexer_.GetLoc(); switch (attr_type) { + case AttrTy::kBool: { + bool result; + if (!ParseBool(&result)) { + return false; + } + static_cast<optional<bool>*>(attr_out_ptr)->emplace(result); + return true; + } case AttrTy::kInt64: { tensorflow::int64 result; if (!ParseInt64(&result)) { |