Diffstat (limited to 'tensorflow')
 tensorflow/compiler/xla/client/xla_client/xla_builder.cc         | 34
 tensorflow/compiler/xla/service/hlo_computation.cc               |  5
 tensorflow/compiler/xla/service/hlo_dataflow_analysis.cc         | 21
 tensorflow/compiler/xla/service/hlo_dataflow_analysis_test.cc    | 16
 tensorflow/compiler/xla/service/hlo_domain_test.cc               | 39
 tensorflow/compiler/xla/service/hlo_instruction.cc               |  4
 tensorflow/compiler/xla/service/hlo_instructions.cc              | 14
 tensorflow/compiler/xla/service/hlo_module_group_metadata.cc     |  3
 tensorflow/compiler/xla/service/hlo_parser_test.cc               | 34
 tensorflow/compiler/xla/service/hlo_verifier.cc                  | 20
 tensorflow/compiler/xla/service/layout_assignment.cc             |  5
 tensorflow/compiler/xla/service/layout_assignment_test.cc        | 11
 tensorflow/compiler/xla/service/logical_buffer_analysis.cc       | 17
 tensorflow/compiler/xla/service/tuple_points_to_analysis.cc      | 29
 tensorflow/compiler/xla/service/tuple_points_to_analysis_test.cc |  2
 15 files changed, 152 insertions(+), 102 deletions(-)
diff --git a/tensorflow/compiler/xla/client/xla_client/xla_builder.cc b/tensorflow/compiler/xla/client/xla_client/xla_builder.cc
index 09e7e87918..a0004281cb 100644
--- a/tensorflow/compiler/xla/client/xla_client/xla_builder.cc
+++ b/tensorflow/compiler/xla/client/xla_client/xla_builder.cc
@@ -1850,18 +1850,19 @@ void XlaBuilder::Send(const XlaOp& operand, const ChannelHandle& handle) {
TF_ASSIGN_OR_RETURN(XlaOp token, AddInstruction(std::move(token_instr),
HloOpcode::kAfterAll, {}));
- // Send instruction produces a tuple of {aliased operand, U32 context}.
+ // Send instruction produces a tuple of {aliased operand, U32 context,
+ // token}.
HloInstructionProto send_instr;
TF_ASSIGN_OR_RETURN(const Shape& shape, GetShape(operand));
- *send_instr.mutable_shape() =
- ShapeUtil::MakeTupleShape({shape, ShapeUtil::MakeShape(U32, {})});
+ *send_instr.mutable_shape() = ShapeUtil::MakeTupleShape(
+ {shape, ShapeUtil::MakeShape(U32, {}), ShapeUtil::MakeTokenShape()});
send_instr.set_channel_id(handle.handle());
TF_ASSIGN_OR_RETURN(XlaOp send,
AddInstruction(std::move(send_instr), HloOpcode::kSend,
{operand, token}));
HloInstructionProto send_done_instr;
- *send_done_instr.mutable_shape() = ShapeUtil::MakeNil();
+ *send_done_instr.mutable_shape() = ShapeUtil::MakeTokenShape();
send_done_instr.set_channel_id(handle.handle());
return AddInstruction(std::move(send_done_instr), HloOpcode::kSendDone,
{send});
@@ -1879,19 +1880,32 @@ XlaOp XlaBuilder::Recv(const Shape& shape, const ChannelHandle& handle) {
TF_ASSIGN_OR_RETURN(XlaOp token, AddInstruction(std::move(token_instr),
HloOpcode::kAfterAll, {}));
- // Recv instruction produces a tuple of {receive buffer, U32 context}.
+ // Recv instruction produces a tuple of {receive buffer, U32 context,
+ // token}.
HloInstructionProto recv_instr;
- *recv_instr.mutable_shape() =
- ShapeUtil::MakeTupleShape({shape, ShapeUtil::MakeShape(U32, {})});
+ *recv_instr.mutable_shape() = ShapeUtil::MakeTupleShape(
+ {shape, ShapeUtil::MakeShape(U32, {}), ShapeUtil::MakeTokenShape()});
recv_instr.set_channel_id(handle.handle());
TF_ASSIGN_OR_RETURN(XlaOp recv, AddInstruction(std::move(recv_instr),
HloOpcode::kRecv, {token}));
HloInstructionProto recv_done_instr;
- *recv_done_instr.mutable_shape() = shape;
+ *recv_done_instr.mutable_shape() =
+ ShapeUtil::MakeTupleShape({shape, ShapeUtil::MakeTokenShape()});
recv_done_instr.set_channel_id(handle.handle());
- return AddInstruction(std::move(recv_done_instr), HloOpcode::kRecvDone,
- {recv});
+ TF_ASSIGN_OR_RETURN(XlaOp recv_done,
+ AddInstruction(std::move(recv_done_instr),
+ HloOpcode::kRecvDone, {recv}));
+
+  // RecvDone produces a tuple of the received data and a token. Return an
+  // XlaOp containing just the data.
+ // TODO(b/80000000): Remove this when clients have been updated to handle
+ // tokens.
+ HloInstructionProto recv_data;
+ *recv_data.mutable_shape() = shape;
+ recv_data.set_tuple_index(0);
+ return AddInstruction(std::move(recv_data), HloOpcode::kGetTupleElement,
+ {recv_done});
});
}
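
A minimal sketch of the client-visible effect of the compatibility shim above, assuming `handle` is a previously constructed ChannelHandle (the f32[4] payload is a hypothetical example, not part of the patch):

    XlaBuilder b("recv_example");
    XlaOp result = b.Recv(ShapeUtil::MakeShape(F32, {4}), handle);
    // The builder now emits internally:
    //   recv              -> (f32[4], u32[], token[])
    //   recv-done         -> (f32[4], token[])
    //   get-tuple-element -> f32[4]   (tuple_index 0 of recv-done)
    // so `result` keeps the bare data shape that pre-token clients expect.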
diff --git a/tensorflow/compiler/xla/service/hlo_computation.cc b/tensorflow/compiler/xla/service/hlo_computation.cc
index 34b18b0e21..e36bef60a3 100644
--- a/tensorflow/compiler/xla/service/hlo_computation.cc
+++ b/tensorflow/compiler/xla/service/hlo_computation.cc
@@ -284,9 +284,8 @@ void HloComputation::set_root_instruction(
if (!IsFusionComputation()) {
CHECK(ShapeUtil::Compatible(new_root_instruction->shape(),
root_instruction_->shape()))
- << new_root_instruction->shape().ShortDebugString()
- << " is incompatible with "
- << root_instruction_->shape().ShortDebugString();
+ << new_root_instruction->shape() << " is incompatible with "
+ << root_instruction_->shape();
}
bool root_found = false;
for (auto& instruction : instructions_) {
diff --git a/tensorflow/compiler/xla/service/hlo_dataflow_analysis.cc b/tensorflow/compiler/xla/service/hlo_dataflow_analysis.cc
index 8a4a9b5986..ebed2cfc59 100644
--- a/tensorflow/compiler/xla/service/hlo_dataflow_analysis.cc
+++ b/tensorflow/compiler/xla/service/hlo_dataflow_analysis.cc
@@ -398,18 +398,17 @@ bool HloDataflowAnalysis::UpdateSendValueSet(HloInstruction* send) {
bool HloDataflowAnalysis::UpdateRecvDoneValueSet(HloInstruction* recv_done) {
CHECK_EQ(recv_done->opcode(), HloOpcode::kRecvDone);
bool changed = false;
- // RecvDone forwards the operand value at {0} to the output.
+ // RecvDone forwards the operand value at {0} to element {0} of its output.
for (auto& pair : GetInstructionValueSet(recv_done)) {
ShapeIndex& index = pair.first;
HloValueSet& value_set = pair.second;
- ShapeIndex operand_index = {0};
- for (int64 i : index) {
- operand_index.push_back(i);
+ if (index.empty() || index[0] != 0) {
+ continue;
}
const HloValueSet& operand_value_set =
- GetValueSet(recv_done->operand(0), operand_index);
+ GetValueSet(recv_done->operand(0), index);
if (value_set != operand_value_set) {
value_set = operand_value_set;
changed = true;
@@ -857,14 +856,18 @@ Status HloDataflowAnalysis::InitializeInstructionValueSets() {
define_top_level_only();
break;
case HloOpcode::kRecvDone:
- // RecvDone aliases its input tuple element {0}, therefore does not
- // define any values.
+ // RecvDone produces a two-element tuple. Element zero aliases its
+ // input tuple element {0}; element one is a token.
+ define_value_at(/*index=*/{});
+ define_value_at(/*index=*/{1});
break;
case HloOpcode::kSend:
- // Send produces a tuple of {aliased operand, U32 context}, therefore
- // only defines the top-level tuple and the tuple element at {1}.
+ // Send produces a tuple of {aliased operand, U32 context, token},
+ // therefore only defines the top-level tuple and the tuple elements
+ // at {1} and {2}.
define_value_at(/*index=*/{});
define_value_at(/*index=*/{1});
+ define_value_at(/*index=*/{2});
break;
default:
define_all_values();
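
After this hunk, only output indices under {0} forward value sets from the Recv operand; all other RecvDone indices are defined fresh. A tiny sketch of the predicate this encodes, using the same ShapeIndex accessors as the code above:

    // True when a RecvDone output index copies its value set from the operand
    // ({0}, {0,1}, ...); {} and {1} are defined by RecvDone itself.
    bool ForwardsFromRecv(const ShapeIndex& index) {
      return !index.empty() && index[0] == 0;
    }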
diff --git a/tensorflow/compiler/xla/service/hlo_dataflow_analysis_test.cc b/tensorflow/compiler/xla/service/hlo_dataflow_analysis_test.cc
index 70254e2c1a..343f5e7b39 100644
--- a/tensorflow/compiler/xla/service/hlo_dataflow_analysis_test.cc
+++ b/tensorflow/compiler/xla/service/hlo_dataflow_analysis_test.cc
@@ -1167,20 +1167,21 @@ TEST_P(HloDataflowAnalysisTest, SendAndSendDone) {
bool ssa_form = GetParam();
const HloDataflowAnalysis& analysis = RunAnalysis(ssa_form);
- EXPECT_EQ(analysis.values().size(), 5);
+ EXPECT_EQ(analysis.values().size(), 6);
EXPECT_TRUE(analysis.ValueIsDefinedAt(param));
EXPECT_TRUE(analysis.ValueIsDefinedAt(send, /*index=*/{}));
EXPECT_FALSE(analysis.ValueIsDefinedAt(send, /*index=*/{0}));
EXPECT_TRUE(analysis.ValueIsDefinedAt(send, /*index=*/{1}));
+ EXPECT_TRUE(analysis.ValueIsDefinedAt(send, /*index=*/{2}));
EXPECT_TRUE(analysis.ValueIsDefinedAt(send_done));
EXPECT_THAT(HloValuesAt(send, /*index=*/{0}),
UnorderedElementsAre(analysis.GetValueDefinedAt(param)));
}
TEST_P(HloDataflowAnalysisTest, RecvAndRecvDone) {
- // Test that a RecvDone forwards its operand tuple element at {0} to the
- // output.
+ // Test that a RecvDone forwards its operand tuple element at {0} to element
+ // {0} of the output.
auto builder = HloComputation::Builder(TestName());
auto token = builder.AddInstruction(HloInstruction::CreateAfterAll({}));
auto recv = builder.AddInstruction(
@@ -1191,13 +1192,16 @@ TEST_P(HloDataflowAnalysisTest, RecvAndRecvDone) {
bool ssa_form = GetParam();
const HloDataflowAnalysis& analysis = RunAnalysis(ssa_form);
- EXPECT_EQ(analysis.values().size(), 4);
+ EXPECT_EQ(analysis.values().size(), 7);
EXPECT_TRUE(analysis.ValueIsDefinedAt(recv, /*index=*/{}));
EXPECT_TRUE(analysis.ValueIsDefinedAt(recv, /*index=*/{0}));
EXPECT_TRUE(analysis.ValueIsDefinedAt(recv, /*index=*/{1}));
- EXPECT_FALSE(analysis.ValueIsDefinedAt(recv_done));
- EXPECT_THAT(HloValuesAt(recv_done),
+ EXPECT_TRUE(analysis.ValueIsDefinedAt(recv, /*index=*/{2}));
+ EXPECT_TRUE(analysis.ValueIsDefinedAt(recv_done, /*index=*/{}));
+ EXPECT_FALSE(analysis.ValueIsDefinedAt(recv_done, /*index=*/{0}));
+ EXPECT_TRUE(analysis.ValueIsDefinedAt(recv_done, /*index=*/{1}));
+ EXPECT_THAT(HloValuesAt(recv_done, /*index=*/{0}),
UnorderedElementsAre(analysis.GetValueDefinedAt(recv, {0})));
EXPECT_TRUE(
analysis.GetValueDefinedAt(recv, /*index=*/{0}).live_out_of_module());
diff --git a/tensorflow/compiler/xla/service/hlo_domain_test.cc b/tensorflow/compiler/xla/service/hlo_domain_test.cc
index c1412f7c68..3859e4cae6 100644
--- a/tensorflow/compiler/xla/service/hlo_domain_test.cc
+++ b/tensorflow/compiler/xla/service/hlo_domain_test.cc
@@ -202,12 +202,13 @@ ENTRY entry {
p0 = (f32[4]) parameter(0)
a = f32[4] get-tuple-element(p0), index=0
token = token[] after-all()
- b = (f32[4], u32[]) send(a, token), channel_id=1, sharding={maximal device=0}
- c = () send-done(b), channel_id=1, sharding={maximal device=0}
- d = (f32[4], u32[]) recv(token), channel_id=2, sharding={maximal device=0}
- e = f32[4] recv-done(d), channel_id=2, sharding={maximal device=0}
- f = f32[4] add(a, e)
- g = f32[4] subtract(a, e)
+ b = (f32[4], u32[], token[]) send(a, token), channel_id=1, sharding={maximal device=0}
+ c = token[] send-done(b), channel_id=1, sharding={maximal device=0}
+ d = (f32[4], u32[], token[]) recv(token), channel_id=2, sharding={maximal device=0}
+ e = (f32[4], token[]) recv-done(d), channel_id=2, sharding={maximal device=0}
+ e_element = f32[4] get-tuple-element(e), index=0, sharding={maximal device=0}
+ f = f32[4] add(a, e_element)
+ g = f32[4] subtract(a, e_element)
ROOT h = (f32[4], f32[4]) tuple(f, g)
}
)";
@@ -220,7 +221,7 @@ ENTRY entry {
EXPECT_TRUE(isolator_changed);
EXPECT_TRUE(HasDomainEdge(module, "b", "a"));
- EXPECT_TRUE(HasDomainEdge(module, "f", "e"));
+ EXPECT_TRUE(HasDomainEdge(module, "f", "e_element"));
EXPECT_FALSE(HasDomainEdge(module, "a", "p0"));
EXPECT_FALSE(HasDomainEdge(module, "c", "b"));
EXPECT_FALSE(HasDomainEdge(module, "e", "d"));
@@ -231,7 +232,7 @@ ENTRY entry {
EXPECT_TRUE(remover_changed);
EXPECT_FALSE(HasDomainEdge(module, "b", "a"));
- EXPECT_FALSE(HasDomainEdge(module, "f", "e"));
+ EXPECT_FALSE(HasDomainEdge(module, "f", "e_element"));
}
TEST_F(HloDomainTest, CheckNoDomainAddedOnPureIOComputation) {
@@ -240,11 +241,12 @@ HloModule Module
ENTRY entry {
token = token[] after-all(), sharding={maximal device=-1}
- a = (f32[4], u32[]) recv(token), channel_id=1, sharding={maximal device=-1}
- b = f32[4] recv-done(a), channel_id=1, sharding={maximal device=-1}
- c = f32[4] add(b, b), sharding={maximal device=-1}
- d = (f32[4], u32[]) send(c, token), channel_id=2, sharding={maximal device=-1}
- ROOT e = () send-done(d), channel_id=2, sharding={maximal device=-1}
+ a = (f32[4], u32[], token[]) recv(token), channel_id=1, sharding={maximal device=-1}
+ b = (f32[4], token[]) recv-done(a), channel_id=1, sharding={maximal device=-1}
+ b_element = f32[4] get-tuple-element(b), index=0, sharding={maximal device=-1}
+ c = f32[4] add(b_element, b_element), sharding={maximal device=-1}
+ d = (f32[4], u32[], token[]) send(c, token), channel_id=2, sharding={maximal device=-1}
+ ROOT e = token[] send-done(d), channel_id=2, sharding={maximal device=-1}
}
)";
@@ -262,11 +264,12 @@ HloModule Module
ENTRY entry {
token = token[] after-all(), sharding={maximal device=0}
- a = (f32[4], u32[]) recv(token), channel_id=1, sharding={maximal device=0}
- b = f32[4] recv-done(a), channel_id=1, sharding={maximal device=0}
- c = f32[4] add(b, b)
- d = (f32[4], u32[]) send(c, token), channel_id=2, sharding={maximal device=0}
- ROOT e = () send-done(d), channel_id=2, sharding={maximal device=0}
+ a = (f32[4], u32[], token[]) recv(token), channel_id=1, sharding={maximal device=0}
+ b = (f32[4], token[]) recv-done(a), channel_id=1, sharding={maximal device=0}
+ b_element = f32[4] get-tuple-element(b), index=0, sharding={maximal device=0}
+ c = f32[4] add(b_element, b_element)
+ d = (f32[4], u32[], token[]) send(c, token), channel_id=2, sharding={maximal device=0}
+ ROOT e = token[] send-done(d), channel_id=2, sharding={maximal device=0}
}
)";
diff --git a/tensorflow/compiler/xla/service/hlo_instruction.cc b/tensorflow/compiler/xla/service/hlo_instruction.cc
index 5533da6eb7..98dac792fa 100644
--- a/tensorflow/compiler/xla/service/hlo_instruction.cc
+++ b/tensorflow/compiler/xla/service/hlo_instruction.cc
@@ -1623,8 +1623,8 @@ Status HloInstruction::ReplaceOperandWith(int64 operand_num,
TF_RET_CHECK(ShapeUtil::CompatibleIgnoringFpPrecision(old_operand->shape(),
new_operand->shape()))
- << old_operand->shape().ShortDebugString() << " is not compatible with "
- << new_operand->shape().ShortDebugString();
+ << old_operand->shape() << " is not compatible with "
+ << new_operand->shape();
operands_[operand_num] = new_operand;
VLOG(3) << "Replacing operand " << operand_num << " of " << name() << " with "
diff --git a/tensorflow/compiler/xla/service/hlo_instructions.cc b/tensorflow/compiler/xla/service/hlo_instructions.cc
index dcc1e3c8af..7052e236cd 100644
--- a/tensorflow/compiler/xla/service/hlo_instructions.cc
+++ b/tensorflow/compiler/xla/service/hlo_instructions.cc
@@ -207,8 +207,9 @@ HloSendInstruction::HloSendInstruction(HloInstruction* operand,
HloInstruction* token, int64 channel_id)
: HloSendRecvInstruction(
HloOpcode::kSend,
- ShapeUtil::MakeTupleShape(
- {CHECK_NOTNULL(operand)->shape(), ShapeUtil::MakeShape(U32, {})}),
+ ShapeUtil::MakeTupleShape({CHECK_NOTNULL(operand)->shape(),
+ ShapeUtil::MakeShape(U32, {}),
+ ShapeUtil::MakeTokenShape()}),
channel_id) {
AppendOperand(operand);
AppendOperand(token);
@@ -224,7 +225,7 @@ std::unique_ptr<HloInstruction> HloSendInstruction::CloneWithNewOperandsImpl(
}
HloSendDoneInstruction::HloSendDoneInstruction(HloSendInstruction* operand)
- : HloSendRecvInstruction(HloOpcode::kSendDone, ShapeUtil::MakeNil(),
+ : HloSendRecvInstruction(HloOpcode::kSendDone, ShapeUtil::MakeTokenShape(),
CHECK_NOTNULL(operand)->channel_id()) {
AppendOperand(operand);
}
@@ -244,7 +245,8 @@ HloRecvInstruction::HloRecvInstruction(const Shape& shape,
HloInstruction* token, int64 channel_id)
: HloSendRecvInstruction(
HloOpcode::kRecv,
- ShapeUtil::MakeTupleShape({shape, ShapeUtil::MakeShape(U32, {})}),
+ ShapeUtil::MakeTupleShape({shape, ShapeUtil::MakeShape(U32, {}),
+ ShapeUtil::MakeTokenShape()}),
channel_id) {
AppendOperand(token);
}
@@ -261,7 +263,9 @@ std::unique_ptr<HloInstruction> HloRecvInstruction::CloneWithNewOperandsImpl(
HloRecvDoneInstruction::HloRecvDoneInstruction(HloRecvInstruction* operand)
: HloSendRecvInstruction(
HloOpcode::kRecvDone,
- ShapeUtil::GetTupleElementShape(operand->shape(), 0),
+ ShapeUtil::MakeTupleShape(
+ {ShapeUtil::GetTupleElementShape(operand->shape(), 0),
+ ShapeUtil::MakeTokenShape()}),
CHECK_NOTNULL(operand)->channel_id()) {
AppendOperand(operand);
}
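
Taken together, the constructors above give the four channel instructions the result shapes sketched below, built with the same ShapeUtil calls; `payload` is a hypothetical stand-in for the transferred shape:

    Shape payload = ShapeUtil::MakeShape(F32, {4});  // hypothetical
    Shape u32 = ShapeUtil::MakeShape(U32, {});
    Shape tok = ShapeUtil::MakeTokenShape();
    Shape send_shape = ShapeUtil::MakeTupleShape({payload, u32, tok});
    Shape send_done_shape = tok;                    // previously ()
    Shape recv_shape = ShapeUtil::MakeTupleShape({payload, u32, tok});
    Shape recv_done_shape =
        ShapeUtil::MakeTupleShape({payload, tok});  // previously just payload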
diff --git a/tensorflow/compiler/xla/service/hlo_module_group_metadata.cc b/tensorflow/compiler/xla/service/hlo_module_group_metadata.cc
index bf33640db1..6bcd7b042d 100644
--- a/tensorflow/compiler/xla/service/hlo_module_group_metadata.cc
+++ b/tensorflow/compiler/xla/service/hlo_module_group_metadata.cc
@@ -382,7 +382,8 @@ Status HloModuleGroupMetadata::VerifyChannelInstructions() {
// Check if the shapes match for each channel.
for (const Channel& channel : channels_) {
const Shape& send_shape = channel.send->operand(0)->shape();
- const Shape& recv_shape = channel.recv_done->shape();
+ const Shape& recv_shape =
+ ShapeUtil::GetTupleElementShape(channel.recv_done->shape(), 0);
if (!ShapeUtil::Compatible(send_shape, recv_shape)) {
return FailedPrecondition("send/recv shapes do not match");
}
diff --git a/tensorflow/compiler/xla/service/hlo_parser_test.cc b/tensorflow/compiler/xla/service/hlo_parser_test.cc
index f40cd60907..88f3309baa 100644
--- a/tensorflow/compiler/xla/service/hlo_parser_test.cc
+++ b/tensorflow/compiler/xla/service/hlo_parser_test.cc
@@ -277,13 +277,13 @@ ENTRY %WhileWithScalarS32Result.v2 () -> s32[] {
"SendRecv",
R"(HloModule TwoSendRecvBothWayRecvFist_module
-ENTRY %TwoSendRecvBothWayRecvFist.v3 () -> f32[] {
+ENTRY %TwoSendRecvBothWayRecvFist.v3 () -> (f32[], token[]) {
%token = token[] after-all()
- %recv = (f32[], u32[]) recv(token[] %token), channel_id=15, sharding={maximal device=1}
- ROOT %recv-done = f32[] recv-done((f32[], u32[]) %recv), channel_id=15, sharding={maximal device=1}
+ %recv = (f32[], u32[], token[]) recv(token[] %token), channel_id=15, sharding={maximal device=1}
+ ROOT %recv-done = (f32[], token[]) recv-done((f32[], u32[], token[]) %recv), channel_id=15, sharding={maximal device=1}
%constant = f32[] constant(2.1), sharding={maximal device=0}
- %send = (f32[], u32[]) send(f32[] %constant, token[] %token), channel_id=16, sharding={maximal device=0}, control-predecessors={%recv}
- %send-done = () send-done((f32[], u32[]) %send), channel_id=16, sharding={maximal device=0}
+ %send = (f32[], u32[], token[]) send(f32[] %constant, token[] %token), channel_id=16, sharding={maximal device=0}, control-predecessors={%recv}
+ %send-done = token[] send-done((f32[], u32[], token[]) %send), channel_id=16, sharding={maximal device=0}
}
)"
@@ -1223,11 +1223,11 @@ TEST_F(HloParserTest, UnexpectedAttribute) {
ENTRY %TwoSendRecvBothWayRecvFist.v3 () -> f32[] {
%token = token[] after-all()
- %recv = (f32[], u32[]) recv(token[] %token), channel_id=15
- %recv-done = f32[] recv-done((f32[], u32[]) %recv), channel_id=15
+ %recv = (f32[], u32[], token[]) recv(token[] %token), channel_id=15
+ %recv-done = (f32[], token[]) recv-done((f32[], u32[], token[]) %recv), channel_id=15
ROOT %constant = f32[] constant(2.1)
- %send = (f32[], u32[]) send(f32[] %constant, token[] %token), channel_id=16, calls=%recv
- %send-done = () send-done((f32[], u32[]) %send), channel_id=16
+ %send = (f32[], u32[], token[]) send(f32[] %constant, token[] %token), channel_id=16, calls=%recv
+ %send-done = token[] send-done((f32[], u32[], token[]) %send), channel_id=16
}
)";
@@ -1240,11 +1240,11 @@ TEST_F(HloParserTest, MissingAttribute) {
ENTRY %TwoSendRecvBothWayRecvFist.v3 () -> f32[] {
%token = token[] after-all()
- %recv = (f32[], u32[]) recv(token[] %token), channel_id=15
- %recv-done = f32[] recv-done((f32[], u32[]) %recv), channel_id=15
+ %recv = (f32[], u32[], token[]) recv(token[] %token), channel_id=15
+ %recv-done = (f32[], token[]) recv-done((f32[], u32[], token[]) %recv), channel_id=15
ROOT %constant = f32[] constant(-2.1)
- %send = (f32[], u32[]) send(f32[] %constant, token[] %token)
- %send-done = () send-done((f32[], u32[]) %send), channel_id=16
+ %send = (f32[], u32[], token[]) send(f32[] %constant, token[] %token)
+ %send-done = token[] send-done((f32[], u32[], token[]) %send), channel_id=16
}
)";
@@ -1257,11 +1257,11 @@ TEST_F(HloParserTest, PredecessorUndefined) {
ENTRY %TwoSendRecvBothWayRecvFist.v3 () -> f32[] {
%token = token[] after-all()
- %recv = (f32[], u32[]) recv(token[] %token), channel_id=15
- %recv-done = f32[] recv-done((f32[], u32[]) %recv), channel_id=15
+ %recv = (f32[], u32[], token[]) recv(token[] %token), channel_id=15
+ %recv-done = (f32[], token[]) recv-done((f32[], u32[], token[]) %recv), channel_id=15
ROOT %constant = f32[] constant(2.1)
- %send = (f32[], u32[]) send(f32[] %constant, token[] %token), channel_id=16, control-predecessors={%done}
- %send-done = () send-done((f32[], u32[]) %send), channel_id=16
+ %send = (f32[], u32[], token[]) send(f32[] %constant, token[] %token), channel_id=16, control-predecessors={%done}
+ %send-done = token[] send-done((f32[], u32[], token[]) %send), channel_id=16
}
)";
diff --git a/tensorflow/compiler/xla/service/hlo_verifier.cc b/tensorflow/compiler/xla/service/hlo_verifier.cc
index 765245096b..2e6ea14426 100644
--- a/tensorflow/compiler/xla/service/hlo_verifier.cc
+++ b/tensorflow/compiler/xla/service/hlo_verifier.cc
@@ -346,9 +346,10 @@ Status ShapeVerifier::HandleSend(HloInstruction* send) {
TF_RET_CHECK(send_done->opcode() == HloOpcode::kSendDone);
TF_RETURN_IF_ERROR(CheckSameChannel(send, send_done));
TF_RETURN_IF_ERROR(CheckIsTokenOperand(send, 1));
- return CheckShape(
- send, ShapeUtil::MakeTupleShape(
- {send->operand(0)->shape(), ShapeUtil::MakeShape(U32, {})}));
+ return CheckShape(send,
+ ShapeUtil::MakeTupleShape({send->operand(0)->shape(),
+ ShapeUtil::MakeShape(U32, {}),
+ ShapeUtil::MakeTokenShape()}));
}
Status ShapeVerifier::HandleSendDone(HloInstruction* send_done) {
@@ -357,7 +358,7 @@ Status ShapeVerifier::HandleSendDone(HloInstruction* send_done) {
TF_RET_CHECK(send->opcode() == HloOpcode::kSend);
TF_RETURN_IF_ERROR(CheckSameChannel(send, send_done));
- return CheckShape(send_done, ShapeUtil::MakeNil());
+ return CheckShape(send_done, ShapeUtil::MakeTokenShape());
}
Status ShapeVerifier::HandleRecv(HloInstruction* recv) {
@@ -366,9 +367,10 @@ Status ShapeVerifier::HandleRecv(HloInstruction* recv) {
TF_RET_CHECK(recv_done->opcode() == HloOpcode::kRecvDone);
TF_RETURN_IF_ERROR(CheckSameChannel(recv, recv_done));
TF_RETURN_IF_ERROR(CheckIsTokenOperand(recv, 0));
- return CheckShape(recv,
- ShapeUtil::MakeTupleShape(
- {recv_done->shape(), ShapeUtil::MakeShape(U32, {})}));
+ return CheckShape(
+ recv, ShapeUtil::MakeTupleShape(
+ {ShapeUtil::GetTupleElementShape(recv_done->shape(), 0),
+ ShapeUtil::MakeShape(U32, {}), ShapeUtil::MakeTokenShape()}));
}
Status ShapeVerifier::HandleRecvDone(HloInstruction* recv_done) {
@@ -376,7 +378,9 @@ Status ShapeVerifier::HandleRecvDone(HloInstruction* recv_done) {
const HloInstruction* recv = recv_done->operand(0);
TF_RET_CHECK(recv->opcode() == HloOpcode::kRecv);
TF_RETURN_IF_ERROR(CheckSameChannel(recv, recv_done));
- return CheckShape(recv_done, recv->shape().tuple_shapes(0));
+ return CheckShape(recv_done,
+ ShapeUtil::MakeTupleShape({recv->shape().tuple_shapes(0),
+ ShapeUtil::MakeTokenShape()}));
}
Status ShapeVerifier::HandleBatchNormTraining(
diff --git a/tensorflow/compiler/xla/service/layout_assignment.cc b/tensorflow/compiler/xla/service/layout_assignment.cc
index 36fdfa868d..fedc83c8f8 100644
--- a/tensorflow/compiler/xla/service/layout_assignment.cc
+++ b/tensorflow/compiler/xla/service/layout_assignment.cc
@@ -1630,7 +1630,8 @@ Status LayoutAssignment::ConstrainChannelLayouts(
for (HloInstruction* instruction : computation->instructions()) {
if (instruction->opcode() == HloOpcode::kRecvDone) {
const Layout* layout = channel_constraints->ConstrainChannel(
- instruction->channel_id(), instruction->shape().layout());
+ instruction->channel_id(),
+ ShapeUtil::GetSubshape(instruction->shape(), {0}).layout());
TF_RET_CHECK(layout == nullptr)
<< instruction->ToString()
<< " cannot constrain layout as it was set to "
@@ -1647,7 +1648,7 @@ Status LayoutAssignment::ConstrainChannelLayouts(
instruction->channel_id(), operand->shape().layout());
if (layout != nullptr) {
// We found an already constrained layout which does not match the one
- // the kSend wants to impose. Eitehr add a new kCopy, or use the
+ // the kSend wants to impose. Either add a new kCopy, or use the
// existing one to marshal the correct shape.
Shape shape = operand->shape();
*shape.mutable_layout() = *layout;
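
Because recv-done now returns a tuple, the layout constraining a channel must be read from the data element rather than the top-level shape. A sketch of the lookup, assuming a `recv_done` HloInstruction* as in the loop above:

    // The channel-constraining layout lives at tuple element {0} of RecvDone.
    const Layout& data_layout =
        ShapeUtil::GetSubshape(recv_done->shape(), {0}).layout();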
diff --git a/tensorflow/compiler/xla/service/layout_assignment_test.cc b/tensorflow/compiler/xla/service/layout_assignment_test.cc
index 4cd584bf8b..a673901c75 100644
--- a/tensorflow/compiler/xla/service/layout_assignment_test.cc
+++ b/tensorflow/compiler/xla/service/layout_assignment_test.cc
@@ -830,12 +830,13 @@ TEST_F(LayoutAssignmentTest, ChannelLayoutMismatch) {
param = (f32[2,2]) parameter(0)
gte = f32[2,2] get-tuple-element(param), index=0
token = token[] after-all()
- recv = (f32[2,2], u32[]) recv(token), channel_id=1, sharding={maximal device=1}
- ROOT recv-done = f32[2,2] recv-done(recv), channel_id=1,
+ recv = (f32[2,2], u32[], token[]) recv(token), channel_id=1, sharding={maximal device=1}
+ recv-done = (f32[2,2], token[]) recv-done(recv), channel_id=1,
sharding={maximal device=1}
- send = (f32[2,2], u32[]) send(gte, token), channel_id=1,
+ ROOT root = f32[2,2] get-tuple-element(recv-done), index=0
+ send = (f32[2,2], u32[], token[]) send(gte, token), channel_id=1,
sharding={maximal device=0}
- send-done = () send-done(send), channel_id=1, sharding={maximal device=0}
+ send-done = token[] send-done(send), channel_id=1, sharding={maximal device=0}
}
)";
@@ -854,7 +855,7 @@ TEST_F(LayoutAssignmentTest, ChannelLayoutMismatch) {
AssignLayouts(module.get(), &computation_layout, &channel_constraints);
EXPECT_THAT(LayoutOf(module.get(), "gte"), ElementsAre(0, 1));
- EXPECT_THAT(LayoutOf(module.get(), "recv-done"), ElementsAre(1, 0));
+ EXPECT_THAT(LayoutOf(module.get(), "root"), ElementsAre(1, 0));
EXPECT_TRUE(
ShapeUtil::Equal(ShapeUtil::GetSubshape(
FindInstruction(module.get(), "send")->shape(), {0}),
diff --git a/tensorflow/compiler/xla/service/logical_buffer_analysis.cc b/tensorflow/compiler/xla/service/logical_buffer_analysis.cc
index f410921b4b..5da26d832b 100644
--- a/tensorflow/compiler/xla/service/logical_buffer_analysis.cc
+++ b/tensorflow/compiler/xla/service/logical_buffer_analysis.cc
@@ -131,18 +131,23 @@ Status LogicalBufferAnalysis::HandleDomain(HloInstruction*) {
return Status::OK();
}
-Status LogicalBufferAnalysis::HandleRecvDone(HloInstruction*) {
- // RecvDone doesn't create a new buffer but rather aliases its input (Recv)
- // tuple element at {0} to its output.
+Status LogicalBufferAnalysis::HandleRecvDone(HloInstruction* recv_done) {
+  // RecvDone produces a two-element tuple containing the data value (which
+  // aliases element {0} of its Recv operand) and a token. Only the tuple
+  // index table and the token are defined by the RecvDone.
+ NewLogicalBuffer(recv_done, /*index=*/{});
+ NewLogicalBuffer(recv_done, /*index=*/{1});
return Status::OK();
}
Status LogicalBufferAnalysis::HandleSend(HloInstruction* send) {
- // Send creates new buffers for the top-level tuple and the context (tuple
- // element at {1}). Tuple element at {0} is an alias of the Send operand, so
- // we don't need to create a new Logical Buffer for that.
+ // Send creates new buffers for the top-level tuple, the context (tuple
+ // element at {1}), and the token (tuple element at {2}). Tuple element at {0}
+ // is an alias of the Send operand, so we don't need to create a new Logical
+ // Buffer for that.
NewLogicalBuffer(send, /*index=*/{});
NewLogicalBuffer(send, /*index=*/{1});
+ NewLogicalBuffer(send, /*index=*/{2});
return Status::OK();
}
diff --git a/tensorflow/compiler/xla/service/tuple_points_to_analysis.cc b/tensorflow/compiler/xla/service/tuple_points_to_analysis.cc
index d1e1744647..a1aa875009 100644
--- a/tensorflow/compiler/xla/service/tuple_points_to_analysis.cc
+++ b/tensorflow/compiler/xla/service/tuple_points_to_analysis.cc
@@ -292,22 +292,29 @@ Status TuplePointsToAnalysis::HandleSlice(HloInstruction* slice) {
}
Status TuplePointsToAnalysis::HandleRecvDone(HloInstruction* recv_done) {
- // RecvDone aliases its input (Recv) tuple element {0} to its output.
+ // RecvDone aliases its input (Recv) tuple element {0} to element {0} of its
+ // output. The other indices ({} and {1}) define their own buffers.
PointsToSet& points_to_set = CreateEmptyPointsToSet(recv_done);
+ points_to_set.AddPointedToBuffer(
+ logical_buffer_analysis_->GetBuffer(recv_done, /*index=*/{}),
+ /*index=*/{});
+ points_to_set.AddPointedToBuffer(
+ logical_buffer_analysis_->GetBuffer(recv_done, /*index=*/{1}),
+ /*index=*/{1});
+
const PointsToSet& operand_points_to_set =
GetPointsToSet(recv_done->operand(0));
- // Recursively copy the points to set of the operand tuple {0}.
+ // Recursively copy the points to set of the operand tuple {0} to the output
+ // element {0}.
points_to_set.ForEachMutableElement(
[this, &points_to_set, &operand_points_to_set](
const ShapeIndex& index, PointsToSet::BufferList* buffers) {
- ShapeIndex src_index({0});
- for (auto element : index) {
- src_index.push_back(element);
+ if (index.empty() || index[0] != 0) {
+ return;
}
- *buffers = operand_points_to_set.element(src_index);
- for (auto& tuple_source :
- operand_points_to_set.tuple_sources(src_index)) {
+ *buffers = operand_points_to_set.element(index);
+ for (auto& tuple_source : operand_points_to_set.tuple_sources(index)) {
points_to_set.add_tuple_source(index, tuple_source);
}
});
@@ -315,7 +322,7 @@ Status TuplePointsToAnalysis::HandleRecvDone(HloInstruction* recv_done) {
}
Status TuplePointsToAnalysis::HandleSend(HloInstruction* send) {
- // Send creates a tuple of {aliased operand, U32 context}.
+ // Send creates a tuple of {aliased operand, U32 context, token}.
PointsToSet& points_to_set = CreateEmptyPointsToSet(send);
-  // Creates the points to set for the tuple and its element at {1}.
+  // Creates the points to set for the tuple and its elements at {1} and {2}.
@@ -328,6 +335,10 @@ Status TuplePointsToAnalysis::HandleSend(HloInstruction* send) {
context_buffer->push_back(
&logical_buffer_analysis_->GetBuffer(send, ShapeIndex({1})));
+ auto token_buffer = points_to_set.mutable_element(ShapeIndex({2}));
+ token_buffer->push_back(
+ &logical_buffer_analysis_->GetBuffer(send, ShapeIndex({2})));
+
// Recursively copy the points to set of the operand to output tuple {0}.
const PointsToSet& operand_points_to_set = GetPointsToSet(send->operand(0));
operand_points_to_set.ForEachElement(
diff --git a/tensorflow/compiler/xla/service/tuple_points_to_analysis_test.cc b/tensorflow/compiler/xla/service/tuple_points_to_analysis_test.cc
index a8f885fd86..1e7d058eb6 100644
--- a/tensorflow/compiler/xla/service/tuple_points_to_analysis_test.cc
+++ b/tensorflow/compiler/xla/service/tuple_points_to_analysis_test.cc
@@ -357,7 +357,7 @@ TEST_F(TuplePointsToAnalysisTest, RecvAndRecvDone) {
ExpectHasTopLevelBuffers(
points_to_analysis_->GetPointsToSet(recv).element({}), {recv});
- ExpectHasBufferAliases(recv, {0}, {{recv, {0}}, {recv_done, {}}});
+ ExpectHasBufferAliases(recv, {0}, {{recv, {0}}, {recv_done, {0}}});
}
TEST_F(TuplePointsToAnalysisTest, TupleSelect) {