aboutsummaryrefslogtreecommitdiffhomepage
path: root/tensorflow/compiler/xla/service/hlo_dataflow_analysis_test.cc
diff options
context:
space:
mode:
authorGravatar A. Unique TensorFlower <gardener@tensorflow.org>2018-06-21 17:31:36 -0700
committerGravatar TensorFlower Gardener <gardener@tensorflow.org>2018-06-21 17:34:27 -0700
commita350f66ed250c3dee43cc27b0778c3759f07e810 (patch)
tree6416b984ae6ffb1147937ce0a61f376176680c01 /tensorflow/compiler/xla/service/hlo_dataflow_analysis_test.cc
parentc7776b996d88c83e0e94aa0fde0f32c4fb23144b (diff)
Add backend specific lambda to decide whether a fusion instruction can share buffer with its operand.
PiperOrigin-RevId: 201615582
Diffstat (limited to 'tensorflow/compiler/xla/service/hlo_dataflow_analysis_test.cc')
-rw-r--r--tensorflow/compiler/xla/service/hlo_dataflow_analysis_test.cc36
1 files changed, 34 insertions, 2 deletions
diff --git a/tensorflow/compiler/xla/service/hlo_dataflow_analysis_test.cc b/tensorflow/compiler/xla/service/hlo_dataflow_analysis_test.cc
index a2d08f797c..0ea8bdcab6 100644
--- a/tensorflow/compiler/xla/service/hlo_dataflow_analysis_test.cc
+++ b/tensorflow/compiler/xla/service/hlo_dataflow_analysis_test.cc
@@ -1880,9 +1880,14 @@ class HloDataflowAnalysisTestBase : public HloTestBase {
computation_ = module_->AddEntryComputation(std::move(computation));
}
- void RunAnalysis() {
+ void RunAnalysis(const HloDataflowAnalysis::FusionCanShareBufferFunction&
+ fusion_can_share_buffer = nullptr) {
CHECK_NOTNULL(module_.get());
- dataflow_analysis_ = HloDataflowAnalysis::Run(*module_).ConsumeValueOrDie();
+ dataflow_analysis_ =
+ HloDataflowAnalysis::Run(*module_, /*ssa_form=*/false,
+ /*bitcast_defines_value=*/false,
+ fusion_can_share_buffer)
+ .ConsumeValueOrDie();
}
void BuildModuleAndRunAnalysis(std::unique_ptr<HloComputation> computation) {
@@ -2283,6 +2288,33 @@ TEST_F(CanShareOperandBufferWithUserTest, OutputFusionCantAliasOperandBuffer) {
fusion, {}));
}
+TEST_F(CanShareOperandBufferWithUserTest, FusionCanShareBufferCustomized) {
+ auto builder = HloComputation::Builder(TestName());
+ Shape data_shape = ShapeUtil::MakeShape(F32, {2, 2});
+
+ auto one = builder.AddInstruction(
+ HloInstruction::CreateConstant(Literal::CreateR0<float>(1.0)));
+ auto operand = builder.AddInstruction(
+ HloInstruction::CreateBroadcast(data_shape, one, {1}));
+ auto mul = builder.AddInstruction(HloInstruction::CreateBinary(
+ data_shape, HloOpcode::kMultiply, operand, operand));
+ auto two = builder.AddInstruction(HloInstruction::CreateConstant(
+ Literal::CreateR2<float>({{2.0, 2.0}, {2.0, 2.0}})));
+ auto add = builder.AddInstruction(
+ HloInstruction::CreateBinary(data_shape, HloOpcode::kAdd, mul, two));
+
+ BuildModule(builder.Build());
+ auto fusion = computation_->CreateFusionInstruction(
+ {add, two, mul}, HloInstruction::FusionKind::kInput);
+ RunAnalysis(/*fusion_can_share_buffer=*/[](const HloInstruction* fusion,
+ const HloInstruction*) {
+ return fusion->fusion_kind() == HloInstruction::FusionKind::kLoop;
+ });
+
+ EXPECT_FALSE(dataflow_analysis_->CanShareOperandBufferWithUser(operand, {},
+ fusion, {}));
+}
+
TEST_F(CanShareOperandBufferWithUserTest, WhileCanShare) {
Shape data_shape = ShapeUtil::MakeShape(F32, {8});