aboutsummaryrefslogtreecommitdiffhomepage
path: root/tensorflow/compiler/xla/service/hlo_dataflow_analysis_test.cc
diff options
context:
space:
mode:
authorGravatar HyoukJoong Lee <hyouklee@google.com>2017-11-09 14:48:37 -0800
committerGravatar Andrew Selle <aselle@andyselle.com>2017-11-10 16:14:40 -0800
commit51895becce83ef4dc8bac263377d158fc50e4d53 (patch)
tree81ba187df17b1c2a3b0784f776f68fc43060f544 /tensorflow/compiler/xla/service/hlo_dataflow_analysis_test.cc
parent2bb46f6376c35ec86279e80a24dc06b068b41556 (diff)
Change for asynchronous Send and Recv by splitting Send into {Send, SendDone}
and Recv into {Recv, RecvDone}. See operation_semantics.md for the updated semantics. PiperOrigin-RevId: 175216012
Diffstat (limited to 'tensorflow/compiler/xla/service/hlo_dataflow_analysis_test.cc')
-rw-r--r--tensorflow/compiler/xla/service/hlo_dataflow_analysis_test.cc48
1 files changed, 48 insertions, 0 deletions
diff --git a/tensorflow/compiler/xla/service/hlo_dataflow_analysis_test.cc b/tensorflow/compiler/xla/service/hlo_dataflow_analysis_test.cc
index 4b8eb237a6..66a538fc51 100644
--- a/tensorflow/compiler/xla/service/hlo_dataflow_analysis_test.cc
+++ b/tensorflow/compiler/xla/service/hlo_dataflow_analysis_test.cc
@@ -1139,6 +1139,54 @@ TEST_P(HloDataflowAnalysisTest, TupleCopy) {
analysis.GetValueDefinedAt(copy, /*index=*/{}).live_out_of_module());
}
+TEST_P(HloDataflowAnalysisTest, SendAndSendDone) {
+ // Test that a Send forwards its operand to the output tuple at {0}.
+ auto builder = HloComputation::Builder(TestName());
+ auto param = builder.AddInstruction(
+ HloInstruction::CreateParameter(0, scalar_shape_, "param0"));
+ auto send = builder.AddInstruction(
+ HloInstruction::CreateSend(param, /*channel_id=*/0));
+ auto send_done = builder.AddInstruction(HloInstruction::CreateSendDone(send));
+ module_->AddEntryComputation(builder.Build());
+
+ bool ssa_form = GetParam();
+ const HloDataflowAnalysis& analysis = RunAnalysis(ssa_form);
+
+ EXPECT_EQ(analysis.values().size(), 4);
+
+ EXPECT_TRUE(analysis.ValueIsDefinedAt(param));
+ EXPECT_TRUE(analysis.ValueIsDefinedAt(send, /*index=*/{}));
+ EXPECT_FALSE(analysis.ValueIsDefinedAt(send, /*index=*/{0}));
+ EXPECT_TRUE(analysis.ValueIsDefinedAt(send, /*index=*/{1}));
+ EXPECT_TRUE(analysis.ValueIsDefinedAt(send_done));
+ EXPECT_THAT(HloValuesAt(send, /*index=*/{0}),
+ UnorderedElementsAre(analysis.GetValueDefinedAt(param)));
+}
+
+TEST_P(HloDataflowAnalysisTest, RecvAndRecvDone) {
+ // Test that a RecvDone forwards its operand tuple element at {0} to the
+ // output.
+ auto builder = HloComputation::Builder(TestName());
+ auto recv = builder.AddInstruction(
+ HloInstruction::CreateRecv(scalar_shape_, /*channel_id=*/0));
+ auto recv_done = builder.AddInstruction(HloInstruction::CreateRecvDone(recv));
+ module_->AddEntryComputation(builder.Build());
+
+ bool ssa_form = GetParam();
+ const HloDataflowAnalysis& analysis = RunAnalysis(ssa_form);
+
+ EXPECT_EQ(analysis.values().size(), 3);
+
+ EXPECT_TRUE(analysis.ValueIsDefinedAt(recv, /*index=*/{}));
+ EXPECT_TRUE(analysis.ValueIsDefinedAt(recv, /*index=*/{0}));
+ EXPECT_TRUE(analysis.ValueIsDefinedAt(recv, /*index=*/{1}));
+ EXPECT_FALSE(analysis.ValueIsDefinedAt(recv_done));
+ EXPECT_THAT(HloValuesAt(recv_done),
+ UnorderedElementsAre(analysis.GetValueDefinedAt(recv, {0})));
+ EXPECT_TRUE(
+ analysis.GetValueDefinedAt(recv, /*index=*/{0}).live_out_of_module());
+}
+
TEST_P(HloDataflowAnalysisTest, ElementwiseChainInterference) {
// A simple chain of elementwise operations. No values should interfere.
//