aboutsummaryrefslogtreecommitdiffhomepage
path: root/tensorflow/compiler/xla/service/hlo_dataflow_analysis_test.cc
diff options
context:
space:
mode:
authorGravatar Mark Heffernan <meheff@google.com>2018-07-03 10:34:20 -0700
committerGravatar TensorFlower Gardener <gardener@tensorflow.org>2018-07-03 10:37:09 -0700
commit31a5fa1ee88f8f3bb1a46f3734136b6d85e8642f (patch)
tree8ada804c9a30d162778decc9fcef009acfa39d49 /tensorflow/compiler/xla/service/hlo_dataflow_analysis_test.cc
parentbce9c2ef1ad72a5d962faac0d114932af6a69bf9 (diff)
Change Send, SendDone, Recv and RecvDone to produce tokens.
This is a follow up to cl/202069017 which added tokens as operands to Send and Recv. PiperOrigin-RevId: 203145403
Diffstat (limited to 'tensorflow/compiler/xla/service/hlo_dataflow_analysis_test.cc')
-rw-r--r--tensorflow/compiler/xla/service/hlo_dataflow_analysis_test.cc16
1 files changed, 10 insertions, 6 deletions
diff --git a/tensorflow/compiler/xla/service/hlo_dataflow_analysis_test.cc b/tensorflow/compiler/xla/service/hlo_dataflow_analysis_test.cc
index 70254e2c1a..343f5e7b39 100644
--- a/tensorflow/compiler/xla/service/hlo_dataflow_analysis_test.cc
+++ b/tensorflow/compiler/xla/service/hlo_dataflow_analysis_test.cc
@@ -1167,20 +1167,21 @@ TEST_P(HloDataflowAnalysisTest, SendAndSendDone) {
bool ssa_form = GetParam();
const HloDataflowAnalysis& analysis = RunAnalysis(ssa_form);
- EXPECT_EQ(analysis.values().size(), 5);
+ EXPECT_EQ(analysis.values().size(), 6);
EXPECT_TRUE(analysis.ValueIsDefinedAt(param));
EXPECT_TRUE(analysis.ValueIsDefinedAt(send, /*index=*/{}));
EXPECT_FALSE(analysis.ValueIsDefinedAt(send, /*index=*/{0}));
EXPECT_TRUE(analysis.ValueIsDefinedAt(send, /*index=*/{1}));
+ EXPECT_TRUE(analysis.ValueIsDefinedAt(send, /*index=*/{2}));
EXPECT_TRUE(analysis.ValueIsDefinedAt(send_done));
EXPECT_THAT(HloValuesAt(send, /*index=*/{0}),
UnorderedElementsAre(analysis.GetValueDefinedAt(param)));
}
TEST_P(HloDataflowAnalysisTest, RecvAndRecvDone) {
- // Test that a RecvDone forwards its operand tuple element at {0} to the
- // output.
+ // Test that a RecvDone forwards its operand tuple element at {0} to element
+ // {0} of the output.
auto builder = HloComputation::Builder(TestName());
auto token = builder.AddInstruction(HloInstruction::CreateAfterAll({}));
auto recv = builder.AddInstruction(
@@ -1191,13 +1192,16 @@ TEST_P(HloDataflowAnalysisTest, RecvAndRecvDone) {
bool ssa_form = GetParam();
const HloDataflowAnalysis& analysis = RunAnalysis(ssa_form);
- EXPECT_EQ(analysis.values().size(), 4);
+ EXPECT_EQ(analysis.values().size(), 7);
EXPECT_TRUE(analysis.ValueIsDefinedAt(recv, /*index=*/{}));
EXPECT_TRUE(analysis.ValueIsDefinedAt(recv, /*index=*/{0}));
EXPECT_TRUE(analysis.ValueIsDefinedAt(recv, /*index=*/{1}));
- EXPECT_FALSE(analysis.ValueIsDefinedAt(recv_done));
- EXPECT_THAT(HloValuesAt(recv_done),
+ EXPECT_TRUE(analysis.ValueIsDefinedAt(recv, /*index=*/{2}));
+ EXPECT_TRUE(analysis.ValueIsDefinedAt(recv_done, /*index=*/{}));
+ EXPECT_FALSE(analysis.ValueIsDefinedAt(recv_done, /*index=*/{0}));
+ EXPECT_TRUE(analysis.ValueIsDefinedAt(recv_done, /*index=*/{1}));
+ EXPECT_THAT(HloValuesAt(recv_done, /*index=*/{0}),
UnorderedElementsAre(analysis.GetValueDefinedAt(recv, {0})));
EXPECT_TRUE(
analysis.GetValueDefinedAt(recv, /*index=*/{0}).live_out_of_module());