aboutsummaryrefslogtreecommitdiffhomepage
path: root/tensorflow/compiler/xla/service/hlo_dataflow_analysis_test.cc
diff options
context:
space:
mode:
authorGravatar Mark Heffernan <meheff@google.com>2017-08-09 15:01:40 -0700
committerGravatar TensorFlower Gardener <gardener@tensorflow.org>2017-08-09 15:04:53 -0700
commit56633fe0cecba03929738df0a0788216f57cf8e9 (patch)
tree7c0be8c004e12172e4e33f98eb1307a395454247 /tensorflow/compiler/xla/service/hlo_dataflow_analysis_test.cc
parent83accbb3745b4019c21f59c3e6f9ab92250261ba (diff)
Make HloDataFlowAnalysis updatable after transforming the HLO graph.
Updating is possible if operands/uses or computation roots change in the graph. Updating is not possible if instructions are deleted or if new instructions are added. Specific changes: * Add verification methods for asserting invariants and checking the analysis after updating. * Always add phi values at while instructions. Previously these were added only if the phi had different inputs. The advantage of using phi's unconditionally is that the set of values is fixed for a module. Updates due to changing operands/uses in the graph do not create new values. * Store values in a vector rather than a map. With unconditional phi values, the number of HloValues is fixed so the values can be held in a vector with stable references to elements. PiperOrigin-RevId: 164778750
Diffstat (limited to 'tensorflow/compiler/xla/service/hlo_dataflow_analysis_test.cc')
-rw-r--r--tensorflow/compiler/xla/service/hlo_dataflow_analysis_test.cc346
1 files changed, 305 insertions, 41 deletions
diff --git a/tensorflow/compiler/xla/service/hlo_dataflow_analysis_test.cc b/tensorflow/compiler/xla/service/hlo_dataflow_analysis_test.cc
index 2b685e355f..9f3dd539ef 100644
--- a/tensorflow/compiler/xla/service/hlo_dataflow_analysis_test.cc
+++ b/tensorflow/compiler/xla/service/hlo_dataflow_analysis_test.cc
@@ -26,6 +26,7 @@ limitations under the License.
#include "tensorflow/compiler/xla/test_helpers.h"
#include "tensorflow/compiler/xla/tests/hlo_test_base.h"
#include "tensorflow/compiler/xla/xla_data.pb.h"
+#include "tensorflow/core/lib/core/status_test_util.h"
#include "tensorflow/core/platform/logging.h"
#include "tensorflow/core/platform/test.h"
@@ -43,8 +44,8 @@ class HloDataflowAnalysisTest : public HloTestBase,
// Run dataflow analysis on the member module. For convenience returns a
// reference to the generated analysis stored in analysis_.
- const HloDataflowAnalysis& RunAnalysis(bool ssa_form,
- bool bitcast_defines_value = false) {
+ HloDataflowAnalysis& RunAnalysis(bool ssa_form,
+ bool bitcast_defines_value = false) {
analysis_ =
HloDataflowAnalysis::Run(module_.get(), ssa_form, bitcast_defines_value)
.ConsumeValueOrDie();
@@ -498,26 +499,37 @@ TEST_P(HloDataflowAnalysisTest, SingleWhile) {
EXPECT_FALSE(analysis.GetValueDefinedAt(cond_constant).live_out_of_module());
if (ssa_form) {
- // Element 0 of the tuple passed through the body so no phi value is
- // defined.
- EXPECT_FALSE(analysis.ValueIsDefinedAt(xla_while, /*index=*/{0}));
- EXPECT_FALSE(analysis.ValueIsDefinedAt(body_param, /*index=*/{0}));
- EXPECT_FALSE(analysis.ValueIsDefinedAt(cond_param, /*index=*/{0}));
+ // While instruction should define phi values. The value at index {0} is a
+ // degenerate phi with a single input 'constant1'.
+ EXPECT_TRUE(analysis.ValueIsDefinedAt(xla_while, /*index=*/{0}));
+ EXPECT_TRUE(analysis.GetValueDefinedAt(xla_while, /*index=*/{0}).is_phi());
+ EXPECT_EQ(analysis.ResolvePhi(xla_while, /*index=*/{0}),
+ &analysis.GetValueDefinedAt(constant1));
+ EXPECT_TRUE(analysis.ValueIsDefinedAt(body_param, /*index=*/{0}));
+ EXPECT_TRUE(analysis.GetValueDefinedAt(body_param, /*index=*/{0}).is_phi());
+ EXPECT_EQ(analysis.ResolvePhi(body_param, /*index=*/{0}),
+ &analysis.GetValueDefinedAt(constant1));
+ EXPECT_TRUE(analysis.ValueIsDefinedAt(cond_param, /*index=*/{0}));
+ EXPECT_TRUE(analysis.GetValueDefinedAt(cond_param, /*index=*/{0}).is_phi());
+ EXPECT_EQ(analysis.ResolvePhi(cond_param, /*index=*/{0}),
+ &analysis.GetValueDefinedAt(constant1));
- // Element 1 of the tuple should be a phi value.
EXPECT_TRUE(analysis.ValueIsDefinedAt(xla_while, /*index=*/{1}));
EXPECT_TRUE(analysis.GetValueDefinedAt(xla_while, /*index=*/{1}).is_phi());
+ EXPECT_EQ(analysis.ResolvePhi(xla_while, /*index=*/{1}), nullptr);
EXPECT_TRUE(analysis.ValueIsDefinedAt(body_param, /*index=*/{1}));
EXPECT_TRUE(analysis.GetValueDefinedAt(body_param, /*index=*/{1}).is_phi());
+ EXPECT_EQ(analysis.ResolvePhi(body_param, /*index=*/{1}), nullptr);
EXPECT_TRUE(analysis.ValueIsDefinedAt(cond_param, /*index=*/{1}));
EXPECT_TRUE(analysis.GetValueDefinedAt(cond_param, /*index=*/{1}).is_phi());
+ EXPECT_EQ(analysis.ResolvePhi(cond_param, /*index=*/{1}), nullptr);
- EXPECT_THAT(
- analysis.GetValueDefinedAt(constant1).uses(),
- UnorderedElementsAre(HloUse{add, 0, {}}, HloUse{xla_while, 0, {0}}));
+ EXPECT_THAT(analysis.GetValueDefinedAt(constant1).uses(),
+ UnorderedElementsAre(HloUse{xla_while, 0, {0}}));
- // Constant1 passes through the body and out of the module.
- EXPECT_TRUE(analysis.GetValueDefinedAt(constant1).live_out_of_module());
+ EXPECT_FALSE(analysis.GetValueDefinedAt(constant1).live_out_of_module());
+ EXPECT_TRUE(analysis.GetValueDefinedAt(xla_while, /*index=*/{0})
+ .live_out_of_module());
EXPECT_TRUE(analysis.GetValueDefinedAt(xla_while, /*index=*/{1})
.live_out_of_module());
@@ -601,15 +613,20 @@ TEST_P(HloDataflowAnalysisTest, SequentialWhiles) {
bool ssa_form = GetParam();
const HloDataflowAnalysis& analysis = RunAnalysis(ssa_form);
- // Element 0 is passed through all the while instructions and out of the
- // module..
- EXPECT_EQ(analysis.GetUniqueValueAt(xla_while0, /*index=*/{0}),
- analysis.GetValueDefinedAt(constant1));
- EXPECT_EQ(analysis.GetUniqueValueAt(xla_while1, /*index=*/{0}),
- analysis.GetValueDefinedAt(constant1));
- EXPECT_EQ(analysis.GetUniqueValueAt(xla_while2, /*index=*/{0}),
- analysis.GetValueDefinedAt(constant1));
- EXPECT_TRUE(analysis.GetValueDefinedAt(constant1).live_out_of_module());
+ if (ssa_form) {
+ EXPECT_TRUE(analysis.GetValueDefinedAt(xla_while2).live_out_of_module());
+ EXPECT_FALSE(analysis.GetValueDefinedAt(constant1).live_out_of_module());
+ } else {
+ // Element 0 is passed through all the while instructions and out of the
+ // module.
+ EXPECT_EQ(analysis.GetUniqueValueAt(xla_while0, /*index=*/{0}),
+ analysis.GetValueDefinedAt(constant1));
+ EXPECT_EQ(analysis.GetUniqueValueAt(xla_while1, /*index=*/{0}),
+ analysis.GetValueDefinedAt(constant1));
+ EXPECT_EQ(analysis.GetUniqueValueAt(xla_while2, /*index=*/{0}),
+ analysis.GetValueDefinedAt(constant1));
+ EXPECT_TRUE(analysis.GetValueDefinedAt(constant1).live_out_of_module());
+ }
}
TEST_P(HloDataflowAnalysisTest, NestedWhiles) {
@@ -688,18 +705,13 @@ TEST_P(HloDataflowAnalysisTest, NestedWhiles) {
bool ssa_form = GetParam();
const HloDataflowAnalysis& analysis = RunAnalysis(ssa_form);
- EXPECT_THAT(HloValuesAt(inner_param, /*index=*/{0}),
- UnorderedElementsAre(analysis.GetValueDefinedAt(negate)));
if (ssa_form) {
EXPECT_TRUE(analysis.ValueIsDefinedAt(inner_param, /*index=*/{1}));
EXPECT_TRUE(
analysis.GetValueDefinedAt(inner_param, /*index=*/{1}).is_phi());
-
- // Element 0 of the nested while is %negate.
- EXPECT_FALSE(analysis.ValueIsDefinedAt(nested_while, /*index=*/{0}));
- EXPECT_THAT(HloValuesAt(inner_param, /*index=*/{0}),
- UnorderedElementsAre(analysis.GetValueDefinedAt(negate)));
- // Element 1 is a phi value (join of %add and %constant2).
+ EXPECT_TRUE(analysis.ValueIsDefinedAt(nested_while, /*index=*/{0}));
+ EXPECT_TRUE(
+ analysis.GetValueDefinedAt(inner_param, /*index=*/{1}).is_phi());
EXPECT_TRUE(analysis.ValueIsDefinedAt(nested_while, /*index=*/{1}));
EXPECT_TRUE(
analysis.GetValueDefinedAt(nested_while, /*index=*/{1}).is_phi());
@@ -712,6 +724,8 @@ TEST_P(HloDataflowAnalysisTest, NestedWhiles) {
EXPECT_TRUE(
analysis.GetValueDefinedAt(entry_while, /*index=*/{1}).is_phi());
} else {
+ EXPECT_THAT(HloValuesAt(inner_param, /*index=*/{0}),
+ UnorderedElementsAre(analysis.GetValueDefinedAt(negate)));
EXPECT_THAT(HloValuesAt(inner_param, /*index=*/{1}),
UnorderedElementsAre(analysis.GetValueDefinedAt(add),
analysis.GetValueDefinedAt(constant2)));
@@ -952,17 +966,17 @@ TEST_P(HloDataflowAnalysisTest, NestedTupleSelect) {
EXPECT_TRUE(analysis.ValueIsDefinedAt(select));
- EXPECT_THAT(HloValuesAt(select, /*index=*/{0}),
- UnorderedElementsAre(analysis.GetValueDefinedAt(constant1),
- analysis.GetValueDefinedAt(constant4)));
- EXPECT_THAT(HloValuesAt(select, /*index=*/{1}),
- UnorderedElementsAre(analysis.GetValueDefinedAt(inner_tuple1),
- analysis.GetValueDefinedAt(inner_tuple2)));
- EXPECT_THAT(HloValuesAt(select, /*index=*/{1, 0}),
- UnorderedElementsAre(analysis.GetValueDefinedAt(constant2),
- analysis.GetValueDefinedAt(constant5)));
- EXPECT_THAT(HloValuesAt(select, /*index=*/{1, 1}),
- UnorderedElementsAre(analysis.GetValueDefinedAt(constant3)));
+ EXPECT_THAT(HloValuesAt(select, /*index=*/{0}),
+ UnorderedElementsAre(analysis.GetValueDefinedAt(constant1),
+ analysis.GetValueDefinedAt(constant4)));
+ EXPECT_THAT(HloValuesAt(select, /*index=*/{1}),
+ UnorderedElementsAre(analysis.GetValueDefinedAt(inner_tuple1),
+ analysis.GetValueDefinedAt(inner_tuple2)));
+ EXPECT_THAT(HloValuesAt(select, /*index=*/{1, 0}),
+ UnorderedElementsAre(analysis.GetValueDefinedAt(constant2),
+ analysis.GetValueDefinedAt(constant5)));
+ EXPECT_THAT(HloValuesAt(select, /*index=*/{1, 1}),
+ UnorderedElementsAre(analysis.GetValueDefinedAt(constant3)));
}
TEST_P(HloDataflowAnalysisTest, TupleSelectToWhile) {
@@ -1482,6 +1496,256 @@ TEST_P(HloDataflowAnalysisTest, EmbeddedComputationInterference) {
EXPECT_TRUE(InstructionsMayInterfere(ordering, negate, embedded_log));
}
+TEST_P(HloDataflowAnalysisTest, UpdateAnalysisForWhile) {
+ // Test updating dataflow after modifying a module with an array shaped while:
+ //
+ // body(F32[] %param):
+ // %negate = Negate(%param)
+ //
+ // condition(F32[] %param):
+ // return Constant(false)
+ //
+ // entry:
+ // %constant = Constant(1.0)
+ // %exp = Exp(%constant)
+ // return While(%exp, body, condition)
+ //
+ auto body_builder = HloComputation::Builder("body");
+ auto body_param = body_builder.AddInstruction(
+ HloInstruction::CreateParameter(0, scalar_shape_, "param"));
+ auto negate = body_builder.AddInstruction(HloInstruction::CreateUnary(
+ scalar_shape_, HloOpcode::kNegate, body_param));
+ HloComputation* body = module_->AddEmbeddedComputation(body_builder.Build());
+
+ // Condition computation trivially returns a constant "false".
+ auto cond_builder = HloComputation::Builder("condition");
+ auto cond_param = cond_builder.AddInstruction(
+ HloInstruction::CreateParameter(0, scalar_shape_, "param"));
+ cond_builder.AddInstruction(
+ HloInstruction::CreateConstant(Literal::CreateR0<bool>(false)));
+ HloComputation* condition =
+ module_->AddEmbeddedComputation(cond_builder.Build());
+
+ auto builder = HloComputation::Builder(TestName());
+ auto constant = builder.AddInstruction(
+ HloInstruction::CreateConstant(Literal::CreateR0<float>(1.0)));
+ auto exp = builder.AddInstruction(
+ HloInstruction::CreateUnary(scalar_shape_, HloOpcode::kExp, constant));
+ auto xla_while = builder.AddInstruction(
+ HloInstruction::CreateWhile(scalar_shape_, condition, body, exp));
+ module_->AddEntryComputation(builder.Build());
+
+ bool ssa_form = GetParam();
+ HloDataflowAnalysis& analysis = RunAnalysis(ssa_form);
+
+ // Sanity check the initial dataflow analysis before transforming the HLO
+ // graph.
+ if (ssa_form) {
+ EXPECT_TRUE(analysis.ValueIsDefinedAt(body_param));
+ EXPECT_TRUE(analysis.GetValueDefinedAt(body_param).is_phi());
+ EXPECT_EQ(analysis.ResolvePhi(body_param), nullptr);
+
+ EXPECT_TRUE(analysis.ValueIsDefinedAt(cond_param));
+ EXPECT_TRUE(analysis.GetValueDefinedAt(cond_param).is_phi());
+ EXPECT_EQ(analysis.ResolvePhi(cond_param), nullptr);
+
+ EXPECT_FALSE(analysis.GetValueDefinedAt(exp).live_out_of_module());
+ EXPECT_FALSE(analysis.GetValueDefinedAt(negate).live_out_of_module());
+ } else {
+ EXPECT_THAT(HloValuesAt(body_param),
+ UnorderedElementsAre(analysis.GetValueDefinedAt(exp),
+ analysis.GetValueDefinedAt(negate)));
+ EXPECT_THAT(HloValuesAt(cond_param),
+ UnorderedElementsAre(analysis.GetValueDefinedAt(exp),
+ analysis.GetValueDefinedAt(negate)));
+ EXPECT_THAT(HloValuesAt(xla_while),
+ UnorderedElementsAre(analysis.GetValueDefinedAt(exp),
+ analysis.GetValueDefinedAt(negate)));
+
+ EXPECT_TRUE(analysis.GetValueDefinedAt(negate).live_out_of_module());
+ EXPECT_TRUE(analysis.GetValueDefinedAt(exp).live_out_of_module());
+ }
+
+ // Set the body root to the body_param. Previously it was Negate(body_param).
+ body->set_root_instruction(body_param);
+
+ // Prior to updating, verify that the dataflow analysis is no longer valid.
+ Status verify_status = analysis.VerifyAgainstReference();
+ EXPECT_FALSE(verify_status.ok());
+
+ analysis.UpdateAfterChangingRoot(/*old_root=*/negate,
+ /*new_root=*/body_param);
+
+ // Analysis should be valid after the update.
+ TF_EXPECT_OK(analysis.VerifyAgainstReference());
+
+ if (ssa_form) {
+ // The phis should now be resolvable as 'exp' is passed through the body
+ // transparently.
+ EXPECT_EQ(analysis.ResolvePhi(body_param),
+ &analysis.GetValueDefinedAt(exp));
+ EXPECT_EQ(analysis.ResolvePhi(cond_param),
+ &analysis.GetValueDefinedAt(exp));
+ EXPECT_EQ(analysis.ResolvePhi(xla_while), &analysis.GetValueDefinedAt(exp));
+ EXPECT_FALSE(analysis.GetValueDefinedAt(exp).live_out_of_module());
+ } else {
+ EXPECT_THAT(HloValuesAt(body_param),
+ UnorderedElementsAre(analysis.GetValueDefinedAt(exp)));
+ EXPECT_THAT(HloValuesAt(cond_param),
+ UnorderedElementsAre(analysis.GetValueDefinedAt(exp)));
+ EXPECT_THAT(HloValuesAt(xla_while),
+ UnorderedElementsAre(analysis.GetValueDefinedAt(exp)));
+ EXPECT_TRUE(analysis.GetValueDefinedAt(exp).live_out_of_module());
+ }
+ EXPECT_FALSE(analysis.GetValueDefinedAt(negate).live_out_of_module());
+
+ // Now replace the operand of the while with %constant (was %exp).
+ TF_ASSERT_OK(exp->ReplaceUseWith(xla_while, constant));
+ analysis.UpdateAfterChangingOperand(xla_while, /*old_operand=*/exp,
+ /*new_operand=*/constant);
+
+ // Verify that the dataflow is correct.
+ TF_ASSERT_OK(analysis.VerifyAgainstReference());
+
+ if (ssa_form) {
+ // The phis now resolve to 'constant'.
+ EXPECT_EQ(analysis.ResolvePhi(body_param),
+ &analysis.GetValueDefinedAt(constant));
+ EXPECT_EQ(analysis.ResolvePhi(cond_param),
+ &analysis.GetValueDefinedAt(constant));
+ EXPECT_EQ(analysis.ResolvePhi(xla_while),
+ &analysis.GetValueDefinedAt(constant));
+ } else {
+ EXPECT_THAT(HloValuesAt(body_param),
+ UnorderedElementsAre(analysis.GetValueDefinedAt(constant)));
+ EXPECT_THAT(HloValuesAt(cond_param),
+ UnorderedElementsAre(analysis.GetValueDefinedAt(constant)));
+ EXPECT_THAT(HloValuesAt(xla_while),
+ UnorderedElementsAre(analysis.GetValueDefinedAt(constant)));
+ EXPECT_TRUE(analysis.GetValueDefinedAt(constant).live_out_of_module());
+ }
+
+ // And finally make the negate the root of the body again.
+ body->set_root_instruction(negate);
+ analysis.UpdateAfterChangingRoot(/*old_root=*/body_param,
+ /*new_root=*/negate);
+
+ // Verify that the dataflow is correct.
+ TF_ASSERT_OK(analysis.VerifyAgainstReference());
+
+ if (ssa_form) {
+ // Phis should no longer be resolvable.
+ EXPECT_EQ(analysis.ResolvePhi(body_param), nullptr);
+ EXPECT_EQ(analysis.ResolvePhi(cond_param), nullptr);
+ EXPECT_EQ(analysis.ResolvePhi(xla_while), nullptr);
+ } else {
+ EXPECT_THAT(HloValuesAt(body_param),
+ UnorderedElementsAre(analysis.GetValueDefinedAt(constant),
+ analysis.GetValueDefinedAt(negate)));
+ EXPECT_THAT(HloValuesAt(cond_param),
+ UnorderedElementsAre(analysis.GetValueDefinedAt(constant),
+ analysis.GetValueDefinedAt(negate)));
+ EXPECT_THAT(HloValuesAt(xla_while),
+ UnorderedElementsAre(analysis.GetValueDefinedAt(constant),
+ analysis.GetValueDefinedAt(negate)));
+
+ EXPECT_FALSE(analysis.GetValueDefinedAt(exp).live_out_of_module());
+ EXPECT_TRUE(analysis.GetValueDefinedAt(negate).live_out_of_module());
+ EXPECT_TRUE(analysis.GetValueDefinedAt(constant).live_out_of_module());
+ }
+
+ // After the updates, verify that the dataflow is correct.
+ TF_ASSERT_OK(analysis.VerifyAgainstReference());
+}
+
+TEST_P(HloDataflowAnalysisTest, UpdateOfATupleSelect) {
+ // Test changing the operands of kSelects of a tuple value and updating the
+ // dataflow.
+ auto builder = HloComputation::Builder(TestName());
+ auto pred = builder.AddInstruction(
+ HloInstruction::CreateConstant(Literal::CreateR0<bool>(false)));
+ auto a = builder.AddInstruction(
+ HloInstruction::CreateConstant(Literal::CreateR0<float>(1.0)));
+ auto b = builder.AddInstruction(
+ HloInstruction::CreateConstant(Literal::CreateR0<float>(2.0)));
+ auto c = builder.AddInstruction(
+ HloInstruction::CreateConstant(Literal::CreateR0<float>(3.0)));
+ auto d = builder.AddInstruction(
+ HloInstruction::CreateConstant(Literal::CreateR0<float>(4.0)));
+ auto tuple_a = builder.AddInstruction(HloInstruction::CreateTuple({a}));
+ auto tuple_b = builder.AddInstruction(HloInstruction::CreateTuple({b}));
+ auto tuple_c = builder.AddInstruction(HloInstruction::CreateTuple({c}));
+ auto tuple_d = builder.AddInstruction(HloInstruction::CreateTuple({d}));
+ const Shape tuple_shape = tuple_a->shape();
+ auto select_aa = builder.AddInstruction(HloInstruction::CreateTernary(
+ tuple_shape, HloOpcode::kSelect, pred, tuple_a, tuple_a));
+ auto select_ab = builder.AddInstruction(HloInstruction::CreateTernary(
+ tuple_shape, HloOpcode::kSelect, pred, tuple_a, tuple_b));
+ auto select_cd = builder.AddInstruction(HloInstruction::CreateTernary(
+ tuple_shape, HloOpcode::kSelect, pred, tuple_c, tuple_d));
+ auto select_abcd = builder.AddInstruction(HloInstruction::CreateTernary(
+ tuple_shape, HloOpcode::kSelect, pred, select_ab, select_cd));
+
+ module_->AddEntryComputation(builder.Build());
+
+ bool ssa_form = GetParam();
+ HloDataflowAnalysis& analysis = RunAnalysis(ssa_form);
+
+ // Sanity check dataflow before changing the graph and updating.
+ EXPECT_THAT(HloValuesAt(select_aa, /*index=*/{0}),
+ UnorderedElementsAre(analysis.GetValueDefinedAt(a)));
+ EXPECT_THAT(HloValuesAt(select_ab, /*index=*/{0}),
+ UnorderedElementsAre(analysis.GetValueDefinedAt(a),
+ analysis.GetValueDefinedAt(b)));
+ EXPECT_THAT(HloValuesAt(select_cd, /*index=*/{0}),
+ UnorderedElementsAre(analysis.GetValueDefinedAt(c),
+ analysis.GetValueDefinedAt(d)));
+ EXPECT_THAT(HloValuesAt(select_abcd, /*index=*/{0}),
+ UnorderedElementsAre(analysis.GetValueDefinedAt(a),
+ analysis.GetValueDefinedAt(b),
+ analysis.GetValueDefinedAt(c),
+ analysis.GetValueDefinedAt(d)));
+ EXPECT_TRUE(analysis.GetValueDefinedAt(a).live_out_of_module());
+ EXPECT_TRUE(analysis.GetValueDefinedAt(b).live_out_of_module());
+ EXPECT_TRUE(analysis.GetValueDefinedAt(c).live_out_of_module());
+ EXPECT_TRUE(analysis.GetValueDefinedAt(d).live_out_of_module());
+
+ // Set the rhs of 'select_aa' to be 'd'.
+ TF_ASSERT_OK(select_aa->ReplaceOperandWith(2, tuple_d));
+ analysis.UpdateAfterChangingOperand(select_aa, /*old_operand=*/tuple_a,
+ /*new_operand=*/tuple_d);
+
+ // Verify that the dataflow is correct.
+ TF_ASSERT_OK(analysis.VerifyAgainstReference());
+
+ EXPECT_THAT(HloValuesAt(select_aa, /*index=*/{0}),
+ UnorderedElementsAre(analysis.GetValueDefinedAt(a),
+ analysis.GetValueDefinedAt(d)));
+
+ // Set the lhs of 'select_cd' to be 'a'.
+ TF_ASSERT_OK(select_cd->ReplaceOperandWith(1, tuple_a));
+ analysis.UpdateAfterChangingOperand(select_cd, /*old_operand=*/tuple_c,
+ /*new_operand=*/tuple_a);
+
+ // Verify that the dataflow is correct.
+ TF_ASSERT_OK(analysis.VerifyAgainstReference());
+
+ EXPECT_THAT(HloValuesAt(select_cd, /*index=*/{0}),
+ UnorderedElementsAre(analysis.GetValueDefinedAt(a),
+ analysis.GetValueDefinedAt(d)));
+ EXPECT_THAT(HloValuesAt(select_abcd, /*index=*/{0}),
+ UnorderedElementsAre(analysis.GetValueDefinedAt(a),
+ analysis.GetValueDefinedAt(b),
+ analysis.GetValueDefinedAt(d)));
+ EXPECT_TRUE(analysis.GetValueDefinedAt(a).live_out_of_module());
+ EXPECT_TRUE(analysis.GetValueDefinedAt(b).live_out_of_module());
+ EXPECT_FALSE(analysis.GetValueDefinedAt(c).live_out_of_module());
+ EXPECT_TRUE(analysis.GetValueDefinedAt(d).live_out_of_module());
+
+ // After the updates, verify that the dataflow is correct.
+ TF_ASSERT_OK(analysis.VerifyAgainstReference());
+}
+
INSTANTIATE_TEST_CASE_P(HloDataflowAnalysisInstantiation,
HloDataflowAnalysisTest,
::testing::Values(false, true));