diff options
author | 2017-11-09 14:48:37 -0800 | |
---|---|---|
committer | 2017-11-10 16:14:40 -0800 | |
commit | 51895becce83ef4dc8bac263377d158fc50e4d53 (patch) | |
tree | 81ba187df17b1c2a3b0784f776f68fc43060f544 /tensorflow/compiler/xla/service/hlo_dataflow_analysis_test.cc | |
parent | 2bb46f6376c35ec86279e80a24dc06b068b41556 (diff) |
Change for asynchronous Send and Recv by splitting Send into {Send, SendDone}
and Recv into {Recv, RecvDone}. See operation_semantics.md for the updated
semantics.
PiperOrigin-RevId: 175216012
Diffstat (limited to 'tensorflow/compiler/xla/service/hlo_dataflow_analysis_test.cc')
-rw-r--r-- | tensorflow/compiler/xla/service/hlo_dataflow_analysis_test.cc | 48 |
1 files changed, 48 insertions, 0 deletions
diff --git a/tensorflow/compiler/xla/service/hlo_dataflow_analysis_test.cc b/tensorflow/compiler/xla/service/hlo_dataflow_analysis_test.cc index 4b8eb237a6..66a538fc51 100644 --- a/tensorflow/compiler/xla/service/hlo_dataflow_analysis_test.cc +++ b/tensorflow/compiler/xla/service/hlo_dataflow_analysis_test.cc @@ -1139,6 +1139,54 @@ TEST_P(HloDataflowAnalysisTest, TupleCopy) { analysis.GetValueDefinedAt(copy, /*index=*/{}).live_out_of_module()); } +TEST_P(HloDataflowAnalysisTest, SendAndSendDone) { + // Test that a Send forwards its operand to the output tuple at {0}. + auto builder = HloComputation::Builder(TestName()); + auto param = builder.AddInstruction( + HloInstruction::CreateParameter(0, scalar_shape_, "param0")); + auto send = builder.AddInstruction( + HloInstruction::CreateSend(param, /*channel_id=*/0)); + auto send_done = builder.AddInstruction(HloInstruction::CreateSendDone(send)); + module_->AddEntryComputation(builder.Build()); + + bool ssa_form = GetParam(); + const HloDataflowAnalysis& analysis = RunAnalysis(ssa_form); + + EXPECT_EQ(analysis.values().size(), 4); + + EXPECT_TRUE(analysis.ValueIsDefinedAt(param)); + EXPECT_TRUE(analysis.ValueIsDefinedAt(send, /*index=*/{})); + EXPECT_FALSE(analysis.ValueIsDefinedAt(send, /*index=*/{0})); + EXPECT_TRUE(analysis.ValueIsDefinedAt(send, /*index=*/{1})); + EXPECT_TRUE(analysis.ValueIsDefinedAt(send_done)); + EXPECT_THAT(HloValuesAt(send, /*index=*/{0}), + UnorderedElementsAre(analysis.GetValueDefinedAt(param))); +} + +TEST_P(HloDataflowAnalysisTest, RecvAndRecvDone) { + // Test that a RecvDone forwards its operand tuple element at {0} to the + // output. + auto builder = HloComputation::Builder(TestName()); + auto recv = builder.AddInstruction( + HloInstruction::CreateRecv(scalar_shape_, /*channel_id=*/0)); + auto recv_done = builder.AddInstruction(HloInstruction::CreateRecvDone(recv)); + module_->AddEntryComputation(builder.Build()); + + bool ssa_form = GetParam(); + const HloDataflowAnalysis& analysis = RunAnalysis(ssa_form); + + EXPECT_EQ(analysis.values().size(), 3); + + EXPECT_TRUE(analysis.ValueIsDefinedAt(recv, /*index=*/{})); + EXPECT_TRUE(analysis.ValueIsDefinedAt(recv, /*index=*/{0})); + EXPECT_TRUE(analysis.ValueIsDefinedAt(recv, /*index=*/{1})); + EXPECT_FALSE(analysis.ValueIsDefinedAt(recv_done)); + EXPECT_THAT(HloValuesAt(recv_done), + UnorderedElementsAre(analysis.GetValueDefinedAt(recv, {0}))); + EXPECT_TRUE( + analysis.GetValueDefinedAt(recv, /*index=*/{0}).live_out_of_module()); +} + TEST_P(HloDataflowAnalysisTest, ElementwiseChainInterference) { // A simple chain of elementwise operations. No values should interfere. // |