aboutsummaryrefslogtreecommitdiffhomepage
path: root/tensorflow/compiler/xla/service/hlo_dataflow_analysis_test.cc
diff options
context:
space:
mode:
authorGravatar Mark Heffernan <meheff@google.com>2017-06-19 18:38:20 -0700
committerGravatar TensorFlower Gardener <gardener@tensorflow.org>2017-06-19 18:41:40 -0700
commit5b6a203c5c759656b2b7018271219916ddd85cb6 (patch)
treee4ba01c8a30d2066ee7f05a147638e5d0cbe246b /tensorflow/compiler/xla/service/hlo_dataflow_analysis_test.cc
parenta36488d812780e78f869a3eb2b692cf3c236f1cc (diff)
[XLA] Add live range interference querying to dataflow analysis.
Add method MayInterfere to HloDataflowAnalysis which returns whether the live ranges of two values interfere. This will replace buffer_liveness.cc. The cl includes a few related changes: (1) HloOrdering: Apply an order to the condition and body computations. Specifically, for the purposes of HLO ordering the condition is ordered before the body. This ensures that the live ranges of values in the condition do not interfere with the live ranges in the body. (2) Add a Dominates method to CallGraph for determining whether a computation dominates another in the call graph. (3) Tightened the definition of "use" in the dataflow analysis. Now an instruction which passes through a value without reading it is no longer considered a use of the value. This new definition is reflected in the HloUse objects returned by HloValue::uses(). PiperOrigin-RevId: 159509724
Diffstat (limited to 'tensorflow/compiler/xla/service/hlo_dataflow_analysis_test.cc')
-rw-r--r--tensorflow/compiler/xla/service/hlo_dataflow_analysis_test.cc528
1 files changed, 442 insertions, 86 deletions
diff --git a/tensorflow/compiler/xla/service/hlo_dataflow_analysis_test.cc b/tensorflow/compiler/xla/service/hlo_dataflow_analysis_test.cc
index a97620cd0d..79edd0fcb5 100644
--- a/tensorflow/compiler/xla/service/hlo_dataflow_analysis_test.cc
+++ b/tensorflow/compiler/xla/service/hlo_dataflow_analysis_test.cc
@@ -39,14 +39,14 @@ using ::testing::UnorderedElementsAre;
class HloDataflowAnalysisTest : public HloTestBase,
public ::testing::WithParamInterface<bool> {
protected:
- HloDataflowAnalysisTest() : module_(TestName()) {}
+ HloDataflowAnalysisTest() : module_(CreateNewModule()) {}
// Run dataflow analysis on the member module. For convenience returns a
// reference to the generated analysis stored in analysis_.
const HloDataflowAnalysis& RunAnalysis(bool ssa_form,
bool bitcast_defines_value = false) {
analysis_ =
- HloDataflowAnalysis::Run(&module_, ssa_form, bitcast_defines_value)
+ HloDataflowAnalysis::Run(module_.get(), ssa_form, bitcast_defines_value)
.ConsumeValueOrDie();
return *analysis_;
}
@@ -63,10 +63,22 @@ class HloDataflowAnalysisTest : public HloTestBase,
return values;
}
- HloModule module_;
+ // Returns true if the top-level values for instructions 'a' and 'b' may
+ // interfere. Precondition: 'a' and 'b' define array-shaped values.
+ bool InstructionsMayInterfere(const HloOrdering& ordering,
+ const HloInstruction* a,
+ const HloInstruction* b) {
+ EXPECT_FALSE(ShapeUtil::IsTuple(a->shape()));
+ EXPECT_FALSE(ShapeUtil::IsTuple(b->shape()));
+ return analysis_->MayInterfere(analysis_->GetValueDefinedAt(a),
+ analysis_->GetValueDefinedAt(b), ordering);
+ }
+
+ std::unique_ptr<HloModule> module_;
std::unique_ptr<HloDataflowAnalysis> analysis_;
const Shape scalar_shape_ = ShapeUtil::MakeShape(F32, {});
+ const Shape vector_shape_ = ShapeUtil::MakeShape(F32, {42});
};
TEST_P(HloDataflowAnalysisTest, BinaryOperation) {
@@ -78,7 +90,7 @@ TEST_P(HloDataflowAnalysisTest, BinaryOperation) {
HloInstruction::CreateConstant(Literal::CreateR0<float>(2.0)));
auto add = builder.AddInstruction(HloInstruction::CreateBinary(
scalar_shape_, HloOpcode::kAdd, constant1, constant2));
- module_.AddEntryComputation(builder.Build());
+ module_->AddEntryComputation(builder.Build());
bool ssa_form = GetParam();
const HloDataflowAnalysis& analysis = RunAnalysis(ssa_form);
@@ -126,7 +138,7 @@ TEST_P(HloDataflowAnalysisTest, TupleAndGtes) {
HloInstruction::CreateGetTupleElement(scalar_shape_, tuple, 1));
auto add = builder.AddInstruction(
HloInstruction::CreateBinary(scalar_shape_, HloOpcode::kAdd, gte0, gte1));
- module_.AddEntryComputation(builder.Build());
+ module_->AddEntryComputation(builder.Build());
bool ssa_form = GetParam();
const HloDataflowAnalysis& analysis = RunAnalysis(ssa_form);
@@ -158,22 +170,16 @@ TEST_P(HloDataflowAnalysisTest, TupleAndGtes) {
// Verify uses. Of interest is that a GetTupleElement instruction is only a
// use of the top-level value in the tuple operand.
EXPECT_THAT(analysis.GetValueDefinedAt(param0).uses(),
- UnorderedElementsAre(HloUse{tuple, 0, {}}, HloUse{add, 0, {}}));
+ UnorderedElementsAre(HloUse{add, 0, {}}));
EXPECT_THAT(analysis.GetValueDefinedAt(param1).uses(),
- UnorderedElementsAre(HloUse{tuple, 1, {}}, HloUse{add, 1, {}}));
+ UnorderedElementsAre(HloUse{add, 1, {}}));
EXPECT_THAT(analysis.GetValueDefinedAt(tuple, /*index=*/{}).uses(),
UnorderedElementsAre(HloUse{gte0, 0, {}}, HloUse{gte1, 0, {}}));
EXPECT_TRUE(analysis.GetValueDefinedAt(add).live_out_of_module());
}
TEST_P(HloDataflowAnalysisTest, NestedTuple) {
- // Verify the dataflow through a nested tuple of the following form for two
- // constants %constant1 and %constant2:
- //
- // %nested_tuple = {{%constant1, %constant2},
- // {%constant1, %constant2},
- // %constant1}
- //
+ // Verify the dataflow through a nested tuple.
auto builder = HloComputation::Builder(TestName());
auto constant1 = builder.AddInstruction(
HloInstruction::CreateConstant(Literal::CreateR0<float>(1.0)));
@@ -187,7 +193,7 @@ TEST_P(HloDataflowAnalysisTest, NestedTuple) {
HloInstruction::CreateGetTupleElement(tuple->shape(), nested_tuple, 1));
auto gte_out = builder.AddInstruction(
HloInstruction::CreateGetTupleElement(scalar_shape_, gte_tuple, 0));
- module_.AddEntryComputation(builder.Build());
+ module_->AddEntryComputation(builder.Build());
bool ssa_form = GetParam();
const HloDataflowAnalysis& analysis = RunAnalysis(ssa_form);
@@ -202,18 +208,15 @@ TEST_P(HloDataflowAnalysisTest, NestedTuple) {
HloLocation{nested_tuple, {0, 0}}, HloLocation{nested_tuple, {1, 0}},
HloLocation{nested_tuple, {2}}, HloLocation{gte_tuple, {0}},
HloLocation{gte_out, {}}));
- EXPECT_THAT(analysis.GetValueDefinedAt(constant1).uses(),
- UnorderedElementsAre(
- HloUse{tuple, 0, {}}, HloUse{nested_tuple, 0, {0}},
- HloUse{nested_tuple, 1, {0}}, HloUse{nested_tuple, 2, {}}));
- EXPECT_THAT(
- analysis.GetValueDefinedAt(constant2).uses(),
- UnorderedElementsAre(HloUse{tuple, 1, {}}, HloUse{nested_tuple, 0, {1}},
- HloUse{nested_tuple, 1, {1}}));
+ // Constant values should have no uses though one is live out. The locations
+ // where they appear as operands are on instructions which do not use the
+ // values (eg, Tuple).
+ EXPECT_TRUE(analysis.GetValueDefinedAt(constant1).uses().empty());
+ EXPECT_TRUE(analysis.GetValueDefinedAt(constant2).uses().empty());
+
+ // The top-level tuple values are used in GTE instructions.
EXPECT_THAT(analysis.GetValueDefinedAt(tuple, /*index=*/{}).uses(),
- UnorderedElementsAre(HloUse{nested_tuple, 0, {}},
- HloUse{nested_tuple, 1, {}},
- HloUse{gte_out, 0, {}}));
+ UnorderedElementsAre(HloUse{gte_out, 0, {}}));
EXPECT_THAT(analysis.GetValueDefinedAt(nested_tuple, /*index=*/{}).uses(),
UnorderedElementsAre(HloUse{gte_tuple, 0, {}}));
@@ -236,7 +239,7 @@ TEST_P(HloDataflowAnalysisTest, SingleCall) {
auto add = subbuilder.AddInstruction(HloInstruction::CreateBinary(
scalar_shape_, HloOpcode::kAdd, subparam0, subparam1));
HloComputation* called_computation =
- module_.AddEmbeddedComputation(subbuilder.Build());
+ module_->AddEmbeddedComputation(subbuilder.Build());
auto builder = HloComputation::Builder(TestName());
auto constant1 = builder.AddInstruction(
@@ -245,7 +248,7 @@ TEST_P(HloDataflowAnalysisTest, SingleCall) {
HloInstruction::CreateConstant(Literal::CreateR0<float>(2.0)));
auto call = builder.AddInstruction(HloInstruction::CreateCall(
scalar_shape_, {constant1, constant2}, called_computation));
- module_.AddEntryComputation(builder.Build());
+ module_->AddEntryComputation(builder.Build());
bool ssa_form = GetParam();
const HloDataflowAnalysis& analysis = RunAnalysis(ssa_form);
@@ -268,11 +271,12 @@ TEST_P(HloDataflowAnalysisTest, SingleCall) {
EXPECT_EQ(analysis.GetUniqueValueAt(call), analysis.GetValueDefinedAt(add));
EXPECT_THAT(analysis.GetValueDefinedAt(constant1).uses(),
- UnorderedElementsAre(HloUse{add, 0, {}}, HloUse{call, 0, {}}));
+ UnorderedElementsAre(HloUse{add, 0, {}}));
EXPECT_THAT(analysis.GetValueDefinedAt(constant2).uses(),
- UnorderedElementsAre(HloUse{add, 1, {}}, HloUse{call, 1, {}}));
+ UnorderedElementsAre(HloUse{add, 1, {}}));
EXPECT_TRUE(analysis.GetValueDefinedAt(add).live_out_of_module());
+ EXPECT_TRUE(analysis.GetValueDefinedAt(add).live_out_of_computation());
}
TEST_P(HloDataflowAnalysisTest, ComputationCalledTwiceWithSameArguments) {
@@ -285,7 +289,7 @@ TEST_P(HloDataflowAnalysisTest, ComputationCalledTwiceWithSameArguments) {
auto add = subbuilder.AddInstruction(HloInstruction::CreateBinary(
scalar_shape_, HloOpcode::kAdd, subparam0, subparam1));
HloComputation* called_computation =
- module_.AddEmbeddedComputation(subbuilder.Build());
+ module_->AddEmbeddedComputation(subbuilder.Build());
auto builder = HloComputation::Builder(TestName());
auto constant1 = builder.AddInstruction(
@@ -298,7 +302,7 @@ TEST_P(HloDataflowAnalysisTest, ComputationCalledTwiceWithSameArguments) {
scalar_shape_, {constant1, constant2}, called_computation));
auto sub = builder.AddInstruction(HloInstruction::CreateBinary(
scalar_shape_, HloOpcode::kSubtract, call1, call2));
- module_.AddEntryComputation(builder.Build());
+ module_->AddEntryComputation(builder.Build());
bool ssa_form = GetParam();
const HloDataflowAnalysis& analysis = RunAnalysis(ssa_form);
@@ -316,17 +320,18 @@ TEST_P(HloDataflowAnalysisTest, ComputationCalledTwiceWithSameArguments) {
EXPECT_TRUE(analysis.ValueIsDefinedAt(sub));
EXPECT_THAT(analysis.GetValueDefinedAt(constant1).uses(),
- UnorderedElementsAre(HloUse{add, 0, {}}, HloUse{call1, 0, {}},
- HloUse{call2, 0, {}}));
+ UnorderedElementsAre(HloUse{add, 0, {}}));
EXPECT_THAT(analysis.GetValueDefinedAt(constant2).uses(),
- UnorderedElementsAre(HloUse{add, 1, {}}, HloUse{call1, 1, {}},
- HloUse{call2, 1, {}}));
+ UnorderedElementsAre(HloUse{add, 1, {}}));
// The Add from the subcomputation is used as both operands of the Subtract.
EXPECT_THAT(analysis.GetValueDefinedAt(add).uses(),
UnorderedElementsAre(HloUse{sub, 0, {}}, HloUse{sub, 1, {}}));
EXPECT_FALSE(analysis.GetValueDefinedAt(add).live_out_of_module());
+ EXPECT_TRUE(analysis.GetValueDefinedAt(add).live_out_of_computation());
+
EXPECT_TRUE(analysis.GetValueDefinedAt(sub).live_out_of_module());
+ EXPECT_TRUE(analysis.GetValueDefinedAt(sub).live_out_of_computation());
}
TEST_P(HloDataflowAnalysisTest, ComputationCalledTwiceWithDifferentArguments) {
@@ -339,7 +344,7 @@ TEST_P(HloDataflowAnalysisTest, ComputationCalledTwiceWithDifferentArguments) {
auto add = subbuilder.AddInstruction(HloInstruction::CreateBinary(
scalar_shape_, HloOpcode::kAdd, subparam0, subparam1));
HloComputation* called_computation =
- module_.AddEmbeddedComputation(subbuilder.Build());
+ module_->AddEmbeddedComputation(subbuilder.Build());
auto builder = HloComputation::Builder(TestName());
auto constant1 = builder.AddInstruction(
@@ -350,7 +355,7 @@ TEST_P(HloDataflowAnalysisTest, ComputationCalledTwiceWithDifferentArguments) {
scalar_shape_, {constant1, constant2}, called_computation));
auto call2 = builder.AddInstruction(HloInstruction::CreateCall(
scalar_shape_, {call1, constant2}, called_computation));
- module_.AddEntryComputation(builder.Build());
+ module_->AddEntryComputation(builder.Build());
bool ssa_form = GetParam();
const HloDataflowAnalysis& analysis = RunAnalysis(ssa_form);
@@ -392,7 +397,7 @@ TEST_P(HloDataflowAnalysisTest, NestedCalls) {
auto add = inner_builder.AddInstruction(HloInstruction::CreateBinary(
scalar_shape_, HloOpcode::kAdd, inner_param0, inner_param1));
HloComputation* inner_computation =
- module_.AddEmbeddedComputation(inner_builder.Build());
+ module_->AddEmbeddedComputation(inner_builder.Build());
auto outer_builder = HloComputation::Builder("OuterComputation");
auto outer_param0 = outer_builder.AddInstruction(
@@ -400,19 +405,19 @@ TEST_P(HloDataflowAnalysisTest, NestedCalls) {
auto outer_param1 = outer_builder.AddInstruction(
HloInstruction::CreateParameter(1, scalar_shape_, "param1"));
// Swizzle parameters.
- auto nested_call = outer_builder.AddInstruction(HloInstruction::CreateCall(
+ outer_builder.AddInstruction(HloInstruction::CreateCall(
scalar_shape_, {outer_param1, outer_param0}, inner_computation));
HloComputation* outer_computation =
- module_.AddEmbeddedComputation(outer_builder.Build());
+ module_->AddEmbeddedComputation(outer_builder.Build());
auto builder = HloComputation::Builder(TestName());
auto constant1 = builder.AddInstruction(
HloInstruction::CreateConstant(Literal::CreateR0<float>(1.0)));
auto constant2 = builder.AddInstruction(
HloInstruction::CreateConstant(Literal::CreateR0<float>(2.0)));
- auto call = builder.AddInstruction(HloInstruction::CreateCall(
+ builder.AddInstruction(HloInstruction::CreateCall(
scalar_shape_, {constant1, constant2}, outer_computation));
- module_.AddEntryComputation(builder.Build());
+ module_->AddEntryComputation(builder.Build());
bool ssa_form = GetParam();
const HloDataflowAnalysis& analysis = RunAnalysis(ssa_form);
@@ -423,14 +428,10 @@ TEST_P(HloDataflowAnalysisTest, NestedCalls) {
// Verify that the uses of the constants are properly swizzled by parameter
// permutation in nested_call.
- EXPECT_THAT(
- analysis.GetValueDefinedAt(constant1).uses(),
- UnorderedElementsAre(HloUse{call, 0, {}}, HloUse{nested_call, 1, {}},
- HloUse{add, 1, {}}));
- EXPECT_THAT(
- analysis.GetValueDefinedAt(constant2).uses(),
- UnorderedElementsAre(HloUse{call, 1, {}}, HloUse{nested_call, 0, {}},
- HloUse{add, 0, {}}));
+ EXPECT_THAT(analysis.GetValueDefinedAt(constant1).uses(),
+ UnorderedElementsAre(HloUse{add, 1, {}}));
+ EXPECT_THAT(analysis.GetValueDefinedAt(constant2).uses(),
+ UnorderedElementsAre(HloUse{add, 0, {}}));
EXPECT_TRUE(analysis.GetValueDefinedAt(add).live_out_of_module());
}
@@ -465,18 +466,18 @@ TEST_P(HloDataflowAnalysisTest, SingleWhile) {
HloInstruction::CreateGetTupleElement(scalar_shape_, body_param, 1));
auto add = body_builder.AddInstruction(HloInstruction::CreateBinary(
scalar_shape_, HloOpcode::kAdd, body_element_0, body_element_1));
- auto body_tuple = body_builder.AddInstruction(
+ body_builder.AddInstruction(
HloInstruction::CreateTuple({body_element_0, add}));
- HloComputation* body = module_.AddEmbeddedComputation(body_builder.Build());
+ HloComputation* body = module_->AddEmbeddedComputation(body_builder.Build());
// Condition computation trivially returns a constant "false".
auto cond_builder = HloComputation::Builder("condition");
auto cond_param = cond_builder.AddInstruction(
HloInstruction::CreateParameter(0, tuple_shape, "param"));
- cond_builder.AddInstruction(
+ auto cond_constant = cond_builder.AddInstruction(
HloInstruction::CreateConstant(Literal::CreateR0<bool>(false)));
HloComputation* condition =
- module_.AddEmbeddedComputation(cond_builder.Build());
+ module_->AddEmbeddedComputation(cond_builder.Build());
auto builder = HloComputation::Builder(TestName());
auto constant1 = builder.AddInstruction(
@@ -487,11 +488,15 @@ TEST_P(HloDataflowAnalysisTest, SingleWhile) {
HloInstruction::CreateTuple({constant1, constant2}));
auto xla_while = builder.AddInstruction(
HloInstruction::CreateWhile(tuple_shape, condition, body, tuple));
- module_.AddEntryComputation(builder.Build());
+ module_->AddEntryComputation(builder.Build());
bool ssa_form = GetParam();
const HloDataflowAnalysis& analysis = RunAnalysis(ssa_form);
+ EXPECT_TRUE(
+ analysis.GetValueDefinedAt(cond_constant).live_out_of_computation());
+ EXPECT_FALSE(analysis.GetValueDefinedAt(cond_constant).live_out_of_module());
+
if (ssa_form) {
// Element 0 of the tuple passed through the body so no phi value is
// defined.
@@ -507,15 +512,17 @@ TEST_P(HloDataflowAnalysisTest, SingleWhile) {
EXPECT_TRUE(analysis.ValueIsDefinedAt(cond_param, /*index=*/{1}));
EXPECT_TRUE(analysis.GetValueDefinedAt(cond_param, /*index=*/{1}).is_phi());
- EXPECT_THAT(analysis.GetValueDefinedAt(constant1).uses(),
- UnorderedElementsAre(HloUse{add, 0, {}}, HloUse{tuple, 0, {}},
- HloUse{xla_while, 0, {0}},
- HloUse{body_tuple, 0, {}}));
+ EXPECT_THAT(
+ analysis.GetValueDefinedAt(constant1).uses(),
+ UnorderedElementsAre(HloUse{add, 0, {}}, HloUse{xla_while, 0, {0}}));
// Constant1 passes through the body and out of the module.
EXPECT_TRUE(analysis.GetValueDefinedAt(constant1).live_out_of_module());
EXPECT_TRUE(analysis.GetValueDefinedAt(xla_while, /*index=*/{1})
.live_out_of_module());
+
+ EXPECT_TRUE(analysis.GetValueDefinedAt(add).live_out_of_computation());
+ EXPECT_FALSE(analysis.GetValueDefinedAt(add).live_out_of_module());
} else {
// While instruction and subcomputation parameters should not define values
// in non-ssa form.
@@ -528,6 +535,7 @@ TEST_P(HloDataflowAnalysisTest, SingleWhile) {
EXPECT_TRUE(analysis.GetValueDefinedAt(constant1).live_out_of_module());
EXPECT_TRUE(analysis.GetValueDefinedAt(add).live_out_of_module());
+ EXPECT_TRUE(analysis.GetValueDefinedAt(add).live_out_of_computation());
}
}
@@ -565,7 +573,7 @@ TEST_P(HloDataflowAnalysisTest, SequentialWhiles) {
scalar_shape_, HloOpcode::kAdd, body_element_0, body_element_1));
body_builder.AddInstruction(
HloInstruction::CreateTuple({body_element_0, add}));
- HloComputation* body = module_.AddEmbeddedComputation(body_builder.Build());
+ HloComputation* body = module_->AddEmbeddedComputation(body_builder.Build());
auto cond_builder = HloComputation::Builder("condition");
cond_builder.AddInstruction(
@@ -573,7 +581,7 @@ TEST_P(HloDataflowAnalysisTest, SequentialWhiles) {
cond_builder.AddInstruction(
HloInstruction::CreateConstant(Literal::CreateR0<bool>(false)));
HloComputation* condition =
- module_.AddEmbeddedComputation(cond_builder.Build());
+ module_->AddEmbeddedComputation(cond_builder.Build());
auto builder = HloComputation::Builder(TestName());
auto constant1 = builder.AddInstruction(
@@ -588,7 +596,7 @@ TEST_P(HloDataflowAnalysisTest, SequentialWhiles) {
HloInstruction::CreateWhile(tuple_shape, condition, body, xla_while0));
auto xla_while2 = builder.AddInstruction(
HloInstruction::CreateWhile(tuple_shape, condition, body, xla_while1));
- module_.AddEntryComputation(builder.Build());
+ module_->AddEntryComputation(builder.Build());
bool ssa_form = GetParam();
const HloDataflowAnalysis& analysis = RunAnalysis(ssa_form);
@@ -632,7 +640,7 @@ TEST_P(HloDataflowAnalysisTest, NestedWhiles) {
cond_builder.AddInstruction(
HloInstruction::CreateConstant(Literal::CreateR0<bool>(false)));
HloComputation* condition =
- module_.AddEmbeddedComputation(cond_builder.Build());
+ module_->AddEmbeddedComputation(cond_builder.Build());
// Element 0 passes transparently through the body.
auto inner_builder = HloComputation::Builder("inner_body");
@@ -647,7 +655,7 @@ TEST_P(HloDataflowAnalysisTest, NestedWhiles) {
inner_builder.AddInstruction(
HloInstruction::CreateTuple({inner_element_0, add}));
HloComputation* inner_body =
- module_.AddEmbeddedComputation(inner_builder.Build());
+ module_->AddEmbeddedComputation(inner_builder.Build());
// Element 1 passes transparently through the body.
auto outer_builder = HloComputation::Builder("outer_body");
@@ -664,7 +672,7 @@ TEST_P(HloDataflowAnalysisTest, NestedWhiles) {
auto nested_while = outer_builder.AddInstruction(HloInstruction::CreateWhile(
tuple_shape, condition, inner_body, outer_tuple));
HloComputation* outer_body =
- module_.AddEmbeddedComputation(outer_builder.Build());
+ module_->AddEmbeddedComputation(outer_builder.Build());
auto builder = HloComputation::Builder(TestName());
auto constant1 = builder.AddInstruction(
@@ -675,7 +683,7 @@ TEST_P(HloDataflowAnalysisTest, NestedWhiles) {
HloInstruction::CreateTuple({constant1, constant2}));
auto entry_while = builder.AddInstruction(
HloInstruction::CreateWhile(tuple_shape, condition, outer_body, tuple));
- module_.AddEntryComputation(builder.Build());
+ module_->AddEntryComputation(builder.Build());
bool ssa_form = GetParam();
const HloDataflowAnalysis& analysis = RunAnalysis(ssa_form);
@@ -751,7 +759,7 @@ TEST_P(HloDataflowAnalysisTest, SwizzlingWhile) {
HloInstruction::CreateGetTupleElement(scalar_shape_, body_param, 1));
body_builder.AddInstruction(
HloInstruction::CreateTuple({body_element_1, body_element_0}));
- HloComputation* body = module_.AddEmbeddedComputation(body_builder.Build());
+ HloComputation* body = module_->AddEmbeddedComputation(body_builder.Build());
auto cond_builder = HloComputation::Builder("condition");
auto cond_param = cond_builder.AddInstruction(
@@ -759,7 +767,7 @@ TEST_P(HloDataflowAnalysisTest, SwizzlingWhile) {
cond_builder.AddInstruction(
HloInstruction::CreateConstant(Literal::CreateR0<bool>(false)));
HloComputation* condition =
- module_.AddEmbeddedComputation(cond_builder.Build());
+ module_->AddEmbeddedComputation(cond_builder.Build());
auto builder = HloComputation::Builder(TestName());
auto constant1 = builder.AddInstruction(
@@ -770,7 +778,7 @@ TEST_P(HloDataflowAnalysisTest, SwizzlingWhile) {
HloInstruction::CreateTuple({constant1, constant2}));
auto xla_while = builder.AddInstruction(
HloInstruction::CreateWhile(tuple_shape, condition, body, tuple));
- module_.AddEntryComputation(builder.Build());
+ module_->AddEntryComputation(builder.Build());
bool ssa_form = GetParam();
const HloDataflowAnalysis& analysis = RunAnalysis(ssa_form);
@@ -825,7 +833,7 @@ TEST_P(HloDataflowAnalysisTest, ArraySelect) {
auto select = builder.AddInstruction(HloInstruction::CreateTernary(
scalar_shape_, HloOpcode::kSelect, pred, constant1, constant2));
- module_.AddEntryComputation(builder.Build());
+ module_->AddEntryComputation(builder.Build());
bool ssa_form = GetParam();
const HloDataflowAnalysis& analysis = RunAnalysis(ssa_form);
@@ -868,7 +876,7 @@ TEST_P(HloDataflowAnalysisTest, TupleSelect) {
auto select1234 = builder.AddInstruction(HloInstruction::CreateTernary(
tuple_shape, HloOpcode::kSelect, pred, select12, select34));
- module_.AddEntryComputation(builder.Build());
+ module_->AddEntryComputation(builder.Build());
bool ssa_form = GetParam();
const HloDataflowAnalysis& analysis = RunAnalysis(ssa_form);
@@ -899,14 +907,16 @@ TEST_P(HloDataflowAnalysisTest, TupleSelect) {
analysis.GetValueDefinedAt(constant4)));
EXPECT_THAT(
- analysis.GetValueDefinedAt(constant1).uses(),
- UnorderedElementsAre(HloUse{tuple1, 0, {}}, HloUse{select11, 1, {0}},
- HloUse{select11, 2, {0}}, HloUse{select12, 1, {0}},
- HloUse{select1234, 1, {0}}));
- EXPECT_THAT(
- analysis.GetValueDefinedAt(constant2).uses(),
- UnorderedElementsAre(HloUse{tuple2, 0, {}}, HloUse{select12, 2, {0}},
- HloUse{select1234, 1, {0}}));
+ analysis.GetValueDefinedAt(tuple1, /*index=*/{}).uses(),
+ UnorderedElementsAre(HloUse{select11, 1, {}}, HloUse{select11, 2, {}},
+ HloUse{select12, 1, {}}));
+
+ // The two constant values just pass through the Selects and are not
+ // used. They are live out however.
+ EXPECT_TRUE(analysis.GetValueDefinedAt(constant1).uses().empty());
+ EXPECT_TRUE(analysis.GetValueDefinedAt(constant2).uses().empty());
+ EXPECT_TRUE(analysis.GetValueDefinedAt(constant1).live_out_of_module());
+ EXPECT_TRUE(analysis.GetValueDefinedAt(constant2).live_out_of_module());
}
TEST_P(HloDataflowAnalysisTest, NestedTupleSelect) {
@@ -935,7 +945,7 @@ TEST_P(HloDataflowAnalysisTest, NestedTupleSelect) {
auto select = builder.AddInstruction(HloInstruction::CreateTernary(
tuple1->shape(), HloOpcode::kSelect, pred, tuple1, tuple2));
- module_.AddEntryComputation(builder.Build());
+ module_->AddEntryComputation(builder.Build());
bool ssa_form = GetParam();
const HloDataflowAnalysis& analysis = RunAnalysis(ssa_form);
@@ -993,7 +1003,7 @@ TEST_P(HloDataflowAnalysisTest, TupleSelectToWhile) {
scalar_shape_, HloOpcode::kAdd, body_element_0, body_element_1));
body_builder.AddInstruction(
HloInstruction::CreateTuple({body_element_0, add}));
- HloComputation* body = module_.AddEmbeddedComputation(body_builder.Build());
+ HloComputation* body = module_->AddEmbeddedComputation(body_builder.Build());
auto cond_builder = HloComputation::Builder("condition");
cond_builder.AddInstruction(
@@ -1001,7 +1011,7 @@ TEST_P(HloDataflowAnalysisTest, TupleSelectToWhile) {
cond_builder.AddInstruction(
HloInstruction::CreateConstant(Literal::CreateR0<bool>(false)));
HloComputation* condition =
- module_.AddEmbeddedComputation(cond_builder.Build());
+ module_->AddEmbeddedComputation(cond_builder.Build());
auto pred = builder.AddInstruction(
HloInstruction::CreateConstant(Literal::CreateR0<bool>(false)));
@@ -1024,7 +1034,7 @@ TEST_P(HloDataflowAnalysisTest, TupleSelectToWhile) {
auto xla_while = builder.AddInstruction(
HloInstruction::CreateWhile(tuple->shape(), condition, body, tuple));
- module_.AddEntryComputation(builder.Build());
+ module_->AddEntryComputation(builder.Build());
bool ssa_form = GetParam();
const HloDataflowAnalysis& analysis = RunAnalysis(ssa_form);
@@ -1066,7 +1076,7 @@ TEST_P(HloDataflowAnalysisTest, BitcastDefinesValue) {
auto bitcast = builder.AddInstruction(HloInstruction::CreateUnary(
scalar_shape_, HloOpcode::kBitcast, constant));
- module_.AddEntryComputation(builder.Build());
+ module_->AddEntryComputation(builder.Build());
bool ssa_form = GetParam();
{
@@ -1102,7 +1112,7 @@ TEST_P(HloDataflowAnalysisTest, TupleCopy) {
builder.AddInstruction(HloInstruction::CreateTuple({param0, param1}));
auto copy = builder.AddInstruction(
HloInstruction::CreateUnary(tuple->shape(), HloOpcode::kCopy, tuple));
- module_.AddEntryComputation(builder.Build());
+ module_->AddEntryComputation(builder.Build());
bool ssa_form = GetParam();
const HloDataflowAnalysis& analysis = RunAnalysis(ssa_form);
@@ -1126,6 +1136,352 @@ TEST_P(HloDataflowAnalysisTest, TupleCopy) {
analysis.GetValueDefinedAt(copy, /*index=*/{}).live_out_of_module());
}
+TEST_P(HloDataflowAnalysisTest, ElementwiseChainInterference) {
+ // A simple chain of elementwise operations. No values should interfere.
+ //
+ // param --> negate -> exp -> log
+ //
+ auto builder = HloComputation::Builder(TestName());
+ auto param = builder.AddInstruction(
+ HloInstruction::CreateParameter(0, vector_shape_, "param"));
+ auto negate = builder.AddInstruction(
+ HloInstruction::CreateUnary(vector_shape_, HloOpcode::kNegate, param));
+ auto exp = builder.AddInstruction(
+ HloInstruction::CreateUnary(vector_shape_, HloOpcode::kExp, negate));
+ auto log = builder.AddInstruction(
+ HloInstruction::CreateUnary(vector_shape_, HloOpcode::kLog, exp));
+
+ module_->AddEntryComputation(builder.Build());
+ RunAnalysis(GetParam());
+
+ DependencyHloOrdering ordering(module_.get());
+
+ // No values should interfere.
+ EXPECT_FALSE(InstructionsMayInterfere(ordering, param, negate));
+ EXPECT_FALSE(InstructionsMayInterfere(ordering, param, exp));
+ EXPECT_FALSE(InstructionsMayInterfere(ordering, param, log));
+ EXPECT_FALSE(InstructionsMayInterfere(ordering, negate, exp));
+ EXPECT_FALSE(InstructionsMayInterfere(ordering, negate, log));
+ EXPECT_FALSE(InstructionsMayInterfere(ordering, exp, negate));
+ EXPECT_FALSE(InstructionsMayInterfere(ordering, exp, log));
+ EXPECT_FALSE(InstructionsMayInterfere(ordering, log, negate));
+ EXPECT_FALSE(InstructionsMayInterfere(ordering, log, exp));
+
+ // Values should interfere with itself.
+ EXPECT_TRUE(InstructionsMayInterfere(ordering, exp, exp));
+}
+
+TEST_P(HloDataflowAnalysisTest, MultipleEntryParameters_Sequential) {
+ // Two entry params, which interfere with each other.
+ //
+ // param0 --> negate ---------------\
+ // param1 --> exp --> add
+ auto builder = HloComputation::Builder(TestName());
+ auto param0 = builder.AddInstruction(
+ HloInstruction::CreateParameter(0, vector_shape_, "param0"));
+ auto param1 = builder.AddInstruction(
+ HloInstruction::CreateParameter(1, vector_shape_, "param1"));
+ auto negate = builder.AddInstruction(
+ HloInstruction::CreateUnary(vector_shape_, HloOpcode::kNegate, param0));
+ auto exp = builder.AddInstruction(
+ HloInstruction::CreateUnary(vector_shape_, HloOpcode::kExp, param1));
+ auto add = builder.AddInstruction(HloInstruction::CreateBinary(
+ vector_shape_, HloOpcode::kAdd, negate, exp));
+
+ auto entry = module_->AddEntryComputation(builder.Build());
+ RunAnalysis(GetParam());
+
+ SequentialHloOrdering::HloModuleSequence sequence;
+ sequence.insert({entry, {param0, negate, param1, exp, add}});
+ SequentialHloOrdering ordering(module_.get(), sequence);
+
+ // Entry parameters interfere as if they are defined simultaneously at
+ // the very beginning.
+ EXPECT_TRUE(InstructionsMayInterfere(ordering, param0, param1));
+ EXPECT_FALSE(InstructionsMayInterfere(ordering, param0, negate));
+ EXPECT_FALSE(InstructionsMayInterfere(ordering, param0, exp));
+ EXPECT_FALSE(InstructionsMayInterfere(ordering, param0, add));
+ EXPECT_TRUE(InstructionsMayInterfere(ordering, param1, param0));
+ EXPECT_TRUE(InstructionsMayInterfere(ordering, param1, negate));
+ EXPECT_FALSE(InstructionsMayInterfere(ordering, param1, exp));
+ EXPECT_FALSE(InstructionsMayInterfere(ordering, param1, add));
+
+ // Negate and exp still interfere.
+ EXPECT_TRUE(InstructionsMayInterfere(ordering, negate, exp));
+ EXPECT_TRUE(InstructionsMayInterfere(ordering, exp, negate));
+
+ // But {negate, add} and {exp, add} don't interfere.
+ EXPECT_FALSE(InstructionsMayInterfere(ordering, negate, add));
+ EXPECT_FALSE(InstructionsMayInterfere(ordering, add, negate));
+ EXPECT_FALSE(InstructionsMayInterfere(ordering, exp, add));
+ EXPECT_FALSE(InstructionsMayInterfere(ordering, add, exp));
+}
+
+TEST_P(HloDataflowAnalysisTest, WhileParameters_Sequential) {
+ // Similar to MultipleEntryParameters_Sequential, but the parameter is of
+ // while body computation. Body computation in the sequential order:
+ //
+ // %constant = Constant(...)
+ // %exp = Exp(%constant)
+ // %param = Param(0)
+ // %add = Add(%param, %exp) ;; Root of body
+ // %dead_constant = Constant(...)
+ // %dead_negate = Negate(%dead_constant)
+ //
+ // %constant and its only use %exp are ordered before 'param'. However, the
+ // %constant and %param values still interfere because the parameter is
+ // considered live into the while body.
+ //
+ // Similarly, %dead_constant and %dead_negate are ordered after the root of
+ // the body computation %add. However, %add is liveout of the computation so
+ // %dead_constant and %add interfere.
+ auto body_builder = HloComputation::Builder(TestName());
+ auto body_param = body_builder.AddInstruction(
+ HloInstruction::CreateParameter(0, scalar_shape_, "body_param"));
+ auto constant = body_builder.AddInstruction(
+ HloInstruction::CreateConstant(Literal::CreateR0<float>(1.0)));
+ auto exp = body_builder.AddInstruction(
+ HloInstruction::CreateUnary(scalar_shape_, HloOpcode::kExp, constant));
+ auto add = body_builder.AddInstruction(HloInstruction::CreateBinary(
+ scalar_shape_, HloOpcode::kAdd, exp, body_param));
+ auto dead_constant = body_builder.AddInstruction(
+ HloInstruction::CreateConstant(Literal::CreateR0<float>(1.0)));
+ auto dead_negate = body_builder.AddInstruction(HloInstruction::CreateUnary(
+ scalar_shape_, HloOpcode::kNegate, dead_constant));
+ HloComputation* body = module_->AddEmbeddedComputation(
+ body_builder.Build(/*root_instruction=*/add));
+
+ auto cond_builder = HloComputation::Builder("condition");
+ auto cond_param = cond_builder.AddInstruction(
+ HloInstruction::CreateParameter(0, scalar_shape_, "cond_param"));
+ auto cond_constant = cond_builder.AddInstruction(
+ HloInstruction::CreateConstant(Literal::CreateR0<bool>(false)));
+ HloComputation* condition =
+ module_->AddEmbeddedComputation(cond_builder.Build());
+
+ auto builder = HloComputation::Builder(TestName());
+ auto param = builder.AddInstruction(
+ HloInstruction::CreateParameter(0, scalar_shape_, "param"));
+ auto xla_while = builder.AddInstruction(
+ HloInstruction::CreateWhile(scalar_shape_, condition, body, param));
+
+ auto entry = module_->AddEntryComputation(builder.Build());
+ bool ssa_form = GetParam();
+ const HloDataflowAnalysis& analysis = RunAnalysis(ssa_form);
+
+ SequentialHloOrdering::HloModuleSequence sequence;
+ sequence.insert({entry, {param, xla_while}});
+ sequence.insert({condition, {cond_param, cond_constant}});
+ // Construct the order such that 'constant' and its use 'exp' are before
+ // body_param.
+ sequence.insert({body, {constant, exp, body_param, add}});
+
+ SequentialHloOrdering ordering(module_.get(), sequence);
+
+ // 'add' is the body root even though later instructions follow in the order
+ // like 'dead_negate'. Only 'add' should be live out of the computation.
+ EXPECT_TRUE(analysis.GetValueDefinedAt(add).live_out_of_computation());
+ EXPECT_FALSE(
+ analysis.GetValueDefinedAt(dead_negate).live_out_of_computation());
+
+ // 'add' is live out of the body and will interfere with an later instructions
+ // such as 'dead_constant' and 'dead_negate'.
+ EXPECT_TRUE(InstructionsMayInterfere(ordering, add, dead_constant));
+ EXPECT_TRUE(InstructionsMayInterfere(ordering, add, dead_negate));
+
+ // The remaining checks test phi values defined by body and condition
+ // parameters which only occur in the SSA form of the analysis.
+ if (ssa_form) {
+ // Though the ordering suggests 'constant' and 'param' should not interfere,
+ // 'param' is live in and thus interferes with any earlier instruction of
+ // the computation in the order (eg 'constant')'
+ EXPECT_TRUE(InstructionsMayInterfere(ordering, body_param, constant));
+ EXPECT_TRUE(InstructionsMayInterfere(ordering, body_param, exp));
+ EXPECT_FALSE(InstructionsMayInterfere(ordering, body_param, add));
+
+ // The following values end up in the same buffer:
+ // (1) the init value: 'param'
+ // (2) the body parameter: 'body_param'
+ // (3) the condition parameter: 'cond_param'
+ // (4) the root value of the while body: 'add'
+ // (5) the while value: 'xla_while'
+ // None should interfere.
+ EXPECT_FALSE(InstructionsMayInterfere(ordering, param, body_param));
+ EXPECT_FALSE(InstructionsMayInterfere(ordering, param, cond_param));
+ EXPECT_FALSE(InstructionsMayInterfere(ordering, param, add));
+ EXPECT_FALSE(InstructionsMayInterfere(ordering, param, xla_while));
+
+ EXPECT_FALSE(InstructionsMayInterfere(ordering, body_param, cond_param));
+ EXPECT_FALSE(InstructionsMayInterfere(ordering, body_param, add));
+ EXPECT_FALSE(InstructionsMayInterfere(ordering, body_param, xla_while));
+
+ EXPECT_FALSE(InstructionsMayInterfere(ordering, cond_param, add));
+ EXPECT_FALSE(InstructionsMayInterfere(ordering, cond_param, xla_while));
+
+ EXPECT_FALSE(InstructionsMayInterfere(ordering, add, xla_while));
+ }
+}
+
+TEST_P(HloDataflowAnalysisTest, NonElementwiseOperand) {
+ // A chain of operations with two elementwise and one non-elementwise. The
+ // elementwise op should not interfere with its operand, while the
+ // non-elementwise op should interfere. Entry params always interfere.
+ //
+ // param --> exp -> negate -> reverse
+ //
+ auto builder = HloComputation::Builder(TestName());
+ auto param = builder.AddInstruction(
+ HloInstruction::CreateParameter(0, vector_shape_, "param"));
+ auto exp = builder.AddInstruction(
+ HloInstruction::CreateUnary(vector_shape_, HloOpcode::kExp, param));
+ auto negate = builder.AddInstruction(
+ HloInstruction::CreateUnary(vector_shape_, HloOpcode::kNegate, exp));
+ auto reverse = builder.AddInstruction(
+ HloInstruction::CreateReverse(vector_shape_, negate, {0}));
+
+ module_->AddEntryComputation(builder.Build());
+ RunAnalysis(GetParam());
+
+ DependencyHloOrdering ordering(module_.get());
+
+ EXPECT_FALSE(InstructionsMayInterfere(ordering, param, exp));
+ EXPECT_FALSE(InstructionsMayInterfere(ordering, param, negate));
+ EXPECT_FALSE(InstructionsMayInterfere(ordering, param, reverse));
+
+ // Negate is elementwise, so doesn't interfere with its operand.
+ // Reverse is non-elementwise, so does interfere with its operand.
+ EXPECT_FALSE(InstructionsMayInterfere(ordering, exp, negate));
+ EXPECT_TRUE(InstructionsMayInterfere(ordering, negate, reverse));
+}
+
+TEST_P(HloDataflowAnalysisTest, OverlappedValues) {
+ // Verify simultaneously live values interfere (exp and negate).
+ //
+ // param --> negate -> add
+ // \---> exp -----/
+ //
+ auto builder = HloComputation::Builder(TestName());
+ auto param = builder.AddInstruction(
+ HloInstruction::CreateParameter(0, vector_shape_, "param"));
+ auto negate = builder.AddInstruction(
+ HloInstruction::CreateUnary(vector_shape_, HloOpcode::kNegate, param));
+ auto exp = builder.AddInstruction(
+ HloInstruction::CreateUnary(vector_shape_, HloOpcode::kExp, param));
+ auto add = builder.AddInstruction(HloInstruction::CreateBinary(
+ vector_shape_, HloOpcode::kAdd, negate, exp));
+
+ module_->AddEntryComputation(builder.Build());
+ RunAnalysis(GetParam());
+
+ DependencyHloOrdering ordering(module_.get());
+
+ EXPECT_TRUE(InstructionsMayInterfere(ordering, param, negate));
+ EXPECT_TRUE(InstructionsMayInterfere(ordering, param, exp));
+ EXPECT_FALSE(InstructionsMayInterfere(ordering, param, add));
+
+ // Negate and exp interfere with each other, but not with add.
+ EXPECT_TRUE(InstructionsMayInterfere(ordering, negate, exp));
+ EXPECT_TRUE(InstructionsMayInterfere(ordering, exp, negate));
+ EXPECT_FALSE(InstructionsMayInterfere(ordering, negate, add));
+ EXPECT_FALSE(InstructionsMayInterfere(ordering, add, negate));
+ EXPECT_FALSE(InstructionsMayInterfere(ordering, exp, add));
+ EXPECT_FALSE(InstructionsMayInterfere(ordering, add, exp));
+}
+
+TEST_P(HloDataflowAnalysisTest, OverlappedValuesSequentialOrder) {
+ // Identical to the test OverlappedValue but using a sequential ordering of
+ // HLO instructions.
+ //
+ // param --> negate -> add
+ // \---> exp -----/
+ //
+ // Sequential order:
+ // param, negate, exp, add
+ //
+ // Liveness is identical to the DependencyHloOrdering.
+ auto builder = HloComputation::Builder(TestName());
+ auto param = builder.AddInstruction(
+ HloInstruction::CreateParameter(0, vector_shape_, "param"));
+ auto negate = builder.AddInstruction(
+ HloInstruction::CreateUnary(vector_shape_, HloOpcode::kNegate, param));
+ auto exp = builder.AddInstruction(
+ HloInstruction::CreateUnary(vector_shape_, HloOpcode::kExp, param));
+ auto add = builder.AddInstruction(HloInstruction::CreateBinary(
+ vector_shape_, HloOpcode::kAdd, negate, exp));
+
+ auto entry = module_->AddEntryComputation(builder.Build());
+ RunAnalysis(GetParam());
+
+ SequentialHloOrdering::HloModuleSequence sequence;
+ std::vector<const HloInstruction*> order = {param, negate, exp, add};
+ sequence.emplace(entry, order);
+
+ SequentialHloOrdering ordering(module_.get(), sequence);
+
+ EXPECT_TRUE(InstructionsMayInterfere(ordering, param, negate));
+ EXPECT_FALSE(InstructionsMayInterfere(ordering, param, exp));
+ EXPECT_FALSE(InstructionsMayInterfere(ordering, param, add));
+
+ // Negate and exp interfere with each other, but not with add.
+ EXPECT_TRUE(InstructionsMayInterfere(ordering, negate, exp));
+ EXPECT_TRUE(InstructionsMayInterfere(ordering, exp, negate));
+ EXPECT_FALSE(InstructionsMayInterfere(ordering, negate, add));
+ EXPECT_FALSE(InstructionsMayInterfere(ordering, add, negate));
+ EXPECT_FALSE(InstructionsMayInterfere(ordering, exp, add));
+ EXPECT_FALSE(InstructionsMayInterfere(ordering, add, exp));
+}
+
+TEST_P(HloDataflowAnalysisTest, EmbeddedComputationInterference) {
+ // Test MayInterfere() for embedded computation, specifically the interference
+ // of values in different computations.
+ //
+ // embedded_computation:
+ // %embedded_param = Param(0)
+ // %embedded_log = Log(%embedded_param)
+ //
+ // entry computation:
+ // %param = Param(0)
+ // %negate = Negate(%param)
+ // %exp = Negate(%exp)
+ // %call = Call(embedded_computation, {%exp})
+ // %add = Add(%negate, %call)
+ //
+ // Note %negate is live across the call and should interfere with all values
+ // in the embedded computation.
+ auto embedded_builder = HloComputation::Builder(TestName() + "_embedded");
+ auto embedded_param = embedded_builder.AddInstruction(
+ HloInstruction::CreateParameter(0, vector_shape_, "embedded_param"));
+ auto embedded_log =
+ embedded_builder.AddInstruction(HloInstruction::CreateUnary(
+ vector_shape_, HloOpcode::kLog, embedded_param));
+ auto embedded_computation =
+ module_->AddEmbeddedComputation(embedded_builder.Build());
+
+ auto builder = HloComputation::Builder(TestName());
+ auto param = builder.AddInstruction(
+ HloInstruction::CreateParameter(0, vector_shape_, "param"));
+ auto negate = builder.AddInstruction(
+ HloInstruction::CreateUnary(vector_shape_, HloOpcode::kNegate, param));
+ auto exp = builder.AddInstruction(
+ HloInstruction::CreateUnary(vector_shape_, HloOpcode::kExp, param));
+ auto call = builder.AddInstruction(
+ HloInstruction::CreateCall(vector_shape_, {exp}, embedded_computation));
+ builder.AddInstruction(HloInstruction::CreateBinary(
+ vector_shape_, HloOpcode::kAdd, negate, call));
+ module_->AddEntryComputation(builder.Build());
+ RunAnalysis(GetParam());
+
+ DependencyHloOrdering ordering(module_.get());
+
+ // Exp only use is the call so it should not interfere with values inside the
+ // embedded computation.
+ EXPECT_FALSE(InstructionsMayInterfere(ordering, exp, embedded_log));
+
+ // Negate is live across the call and should interfere with values in the
+ // embedded computation
+ EXPECT_TRUE(InstructionsMayInterfere(ordering, negate, embedded_log));
+}
+
INSTANTIATE_TEST_CASE_P(HloDataflowAnalysisInstantiation,
HloDataflowAnalysisTest,
::testing::Values(false, true));