about summary refs log tree commit diff homepage
path: root/tensorflow/compiler
diff options
context:
space:
mode:
authorGravatar A. Unique TensorFlower <gardener@tensorflow.org>2017-07-26 08:35:19 -0700
committerGravatar TensorFlower Gardener <gardener@tensorflow.org>2017-07-26 08:39:12 -0700
commit78a9b95436f45438abf3e818307f707e9ae92343 (patch)
tree94dfdfa894f0dec6ba917b905908985f6594b223 /tensorflow/compiler
parent49495697cddef73a0dd870176dab488bb2a65520 (diff)
[XLA] Finish normalizing fusion computations into standard computations
PiperOrigin-RevId: 163210327
Diffstat (limited to 'tensorflow/compiler')
-rw-r--r--tensorflow/compiler/xla/service/BUILD1
-rw-r--r--tensorflow/compiler/xla/service/algebraic_simplifier.cc3
-rw-r--r--tensorflow/compiler/xla/service/batchnorm_rewriter.cc3
-rw-r--r--tensorflow/compiler/xla/service/buffer_assignment.cc6
-rw-r--r--tensorflow/compiler/xla/service/buffer_liveness.cc3
-rw-r--r--tensorflow/compiler/xla/service/copy_insertion.cc3
-rw-r--r--tensorflow/compiler/xla/service/cpu/cpu_compiler.cc9
-rw-r--r--tensorflow/compiler/xla/service/cpu/cpu_parallelization_preparation.cc3
-rw-r--r--tensorflow/compiler/xla/service/gpu/fusion_merger.cc9
-rw-r--r--tensorflow/compiler/xla/service/gpu/hlo_schedule.cc3
-rw-r--r--tensorflow/compiler/xla/service/hlo_constant_folding.cc3
-rw-r--r--tensorflow/compiler/xla/service/hlo_cost_analysis_test.cc82
-rw-r--r--tensorflow/compiler/xla/service/hlo_cse.cc3
-rw-r--r--tensorflow/compiler/xla/service/hlo_dce.cc3
-rw-r--r--tensorflow/compiler/xla/service/hlo_instruction.cc69
-rw-r--r--tensorflow/compiler/xla/service/hlo_instruction.h4
-rw-r--r--tensorflow/compiler/xla/service/hlo_instruction_test.cc215
-rw-r--r--tensorflow/compiler/xla/service/hlo_ordering.cc3
-rw-r--r--tensorflow/compiler/xla/service/hlo_rematerialization.cc3
-rw-r--r--tensorflow/compiler/xla/service/hlo_scheduling.cc4
-rw-r--r--tensorflow/compiler/xla/service/instruction_fusion.cc11
-rw-r--r--tensorflow/compiler/xla/service/layout_assignment.cc5
-rw-r--r--tensorflow/compiler/xla/service/name_uniquer.cc4
-rw-r--r--tensorflow/compiler/xla/service/reduce_precision_insertion.cc3
-rw-r--r--tensorflow/compiler/xla/service/reshape_mover.cc11
-rw-r--r--tensorflow/compiler/xla/service/reshape_mover_test.cc11
-rw-r--r--tensorflow/compiler/xla/service/transpose_folding.cc9
-rw-r--r--tensorflow/compiler/xla/service/tuple_points_to_analysis.cc6
28 files changed, 302 insertions, 190 deletions
diff --git a/tensorflow/compiler/xla/service/BUILD b/tensorflow/compiler/xla/service/BUILD
index a4612bb6c1..8fb0faf026 100644
--- a/tensorflow/compiler/xla/service/BUILD
+++ b/tensorflow/compiler/xla/service/BUILD
@@ -1210,6 +1210,7 @@ cc_test(
"//tensorflow/compiler/xla/client:computation_builder",
"//tensorflow/compiler/xla/client:local_client",
"//tensorflow/compiler/xla/client:padding",
+ "//tensorflow/compiler/xla/tests:hlo_test_base",
"//tensorflow/core:lib",
"//tensorflow/core:test_main",
],
diff --git a/tensorflow/compiler/xla/service/algebraic_simplifier.cc b/tensorflow/compiler/xla/service/algebraic_simplifier.cc
index 4837402c15..691f9f2296 100644
--- a/tensorflow/compiler/xla/service/algebraic_simplifier.cc
+++ b/tensorflow/compiler/xla/service/algebraic_simplifier.cc
@@ -1586,6 +1586,9 @@ StatusOr<bool> AlgebraicSimplifier::Run(HloModule* module) {
// module, invalidating iteration.
std::vector<HloComputation*> computations;
for (auto& comp : module->computations()) {
+ if (comp->IsFusionComputation()) {
+ continue;
+ }
computations.push_back(comp.get());
}
for (auto& comp : computations) {
diff --git a/tensorflow/compiler/xla/service/batchnorm_rewriter.cc b/tensorflow/compiler/xla/service/batchnorm_rewriter.cc
index 5d5d3caa2f..ca2d413e11 100644
--- a/tensorflow/compiler/xla/service/batchnorm_rewriter.cc
+++ b/tensorflow/compiler/xla/service/batchnorm_rewriter.cc
@@ -268,6 +268,9 @@ StatusOr<bool> BatchNormRewriter::Run(HloModule* module) {
// module, invalidating iteration.
std::vector<HloComputation*> computations;
for (auto& comp : module->computations()) {
+ if (comp->IsFusionComputation()) {
+ continue;
+ }
computations.push_back(comp.get());
}
for (auto& comp : computations) {
diff --git a/tensorflow/compiler/xla/service/buffer_assignment.cc b/tensorflow/compiler/xla/service/buffer_assignment.cc
index ddc3d11b7c..ae31135a1a 100644
--- a/tensorflow/compiler/xla/service/buffer_assignment.cc
+++ b/tensorflow/compiler/xla/service/buffer_assignment.cc
@@ -1219,6 +1219,9 @@ void BufferAssigner::BuildColocatedBufferSets(
const TuplePointsToAnalysis& points_to_analysis =
buffer_liveness.points_to_analysis();
for (const HloComputation* computation : module->MakeComputationPostOrder()) {
+ if (computation->IsFusionComputation()) {
+ continue;
+ }
for (const HloInstruction* instruction :
computation->MakeInstructionPostOrder()) {
const HloOpcode opcode = instruction->opcode();
@@ -1386,6 +1389,9 @@ StatusOr<std::unique_ptr<BufferAssignment>> BufferAssigner::CreateAssignment(
// their own BufferAllocation.
for (auto* computation : thread_local_computations) {
TF_RET_CHECK(computation != module->entry_computation());
+ if (computation->IsFusionComputation()) {
+ continue;
+ }
TF_RETURN_IF_ERROR(AssignBuffersForComputation(
computation, module->config().debug_options(),
/*is_thread_local=*/true, colocated_buffers, colocated_allocations,
diff --git a/tensorflow/compiler/xla/service/buffer_liveness.cc b/tensorflow/compiler/xla/service/buffer_liveness.cc
index 6720a90ef8..f085ffa6bc 100644
--- a/tensorflow/compiler/xla/service/buffer_liveness.cc
+++ b/tensorflow/compiler/xla/service/buffer_liveness.cc
@@ -47,6 +47,9 @@ StatusOr<std::unique_ptr<BufferLiveness>> BufferLiveness::Run(
tensorflow::Status BufferLiveness::Analyze() {
TF_ASSIGN_OR_RETURN(points_to_analysis_, TuplePointsToAnalysis::Run(module_));
for (auto& computation : module_->computations()) {
+ if (computation->IsFusionComputation()) {
+ continue;
+ }
// Gather all instructions whose buffers might alias other instructions into
// the set aliased_buffers_. This includes those contained as a tuple
// element in other instruction's output.
diff --git a/tensorflow/compiler/xla/service/copy_insertion.cc b/tensorflow/compiler/xla/service/copy_insertion.cc
index a3803c34ba..c47abe9c62 100644
--- a/tensorflow/compiler/xla/service/copy_insertion.cc
+++ b/tensorflow/compiler/xla/service/copy_insertion.cc
@@ -551,6 +551,9 @@ StatusOr<bool> CopyInsertion::Run(HloModule* module) {
// Add copies of computation root instructions, if needed.
FlatMap<const HloComputation*, ShapeTree<bool>> while_body_read_only_indices;
for (auto& computation : module->computations()) {
+ if (computation->IsFusionComputation()) {
+ continue;
+ }
VLOG(2) << "computation " << computation->name();
InstructionCopier root_copier(computation->root_instruction(),
/*copy_users=*/{});
diff --git a/tensorflow/compiler/xla/service/cpu/cpu_compiler.cc b/tensorflow/compiler/xla/service/cpu/cpu_compiler.cc
index b86342d0b3..59e8c75b91 100644
--- a/tensorflow/compiler/xla/service/cpu/cpu_compiler.cc
+++ b/tensorflow/compiler/xla/service/cpu/cpu_compiler.cc
@@ -519,6 +519,9 @@ StatusOr<std::unique_ptr<Executable>> CpuCompiler::Compile(
new std::map<HloInstruction*, string>());
for (auto embedded_computation :
computation->MakeEmbeddedComputationsList()) {
+ if (embedded_computation->IsFusionComputation()) {
+ continue;
+ }
auto parallel_computation_iter =
parallel_computations.find(embedded_computation);
// All parallel computations are considered to be an entry computation for
@@ -591,6 +594,9 @@ StatusOr<std::unique_ptr<Executable>> CpuCompiler::Compile(
for (auto embedded_computation :
computation->MakeEmbeddedComputationsList()) {
+ if (embedded_computation->IsFusionComputation()) {
+ continue;
+ }
TF_RETURN_IF_ERROR(
ir_emitter
.EmitComputation(embedded_computation,
@@ -755,6 +761,9 @@ CpuCompiler::CompileAheadOfTime(std::vector<std::unique_ptr<HloModule>> modules,
HloComputation* computation = module->entry_computation();
for (auto embedded_computation :
computation->MakeEmbeddedComputationsList()) {
+ if (embedded_computation->IsFusionComputation()) {
+ continue;
+ }
TF_RETURN_IF_ERROR(
ir_emitter
.EmitComputation(embedded_computation,
diff --git a/tensorflow/compiler/xla/service/cpu/cpu_parallelization_preparation.cc b/tensorflow/compiler/xla/service/cpu/cpu_parallelization_preparation.cc
index af931f7b01..4d0e0f744a 100644
--- a/tensorflow/compiler/xla/service/cpu/cpu_parallelization_preparation.cc
+++ b/tensorflow/compiler/xla/service/cpu/cpu_parallelization_preparation.cc
@@ -125,6 +125,9 @@ StatusOr<bool> ParallelizationPreparation::Run(HloModule* module) {
TF_ASSIGN_OR_RETURN(auto points_to_analysis,
TuplePointsToAnalysis::Run(module));
for (auto& computation : module->computations()) {
+ if (computation->IsFusionComputation()) {
+ continue;
+ }
HloInstruction* root = computation->root_instruction();
// Copy root instruction if it does not define its own top-level buffer.
// TODO(b/32885001) Remove these copies (at least for the unambiguous case).
diff --git a/tensorflow/compiler/xla/service/gpu/fusion_merger.cc b/tensorflow/compiler/xla/service/gpu/fusion_merger.cc
index e698646d18..a9ef204b46 100644
--- a/tensorflow/compiler/xla/service/gpu/fusion_merger.cc
+++ b/tensorflow/compiler/xla/service/gpu/fusion_merger.cc
@@ -293,12 +293,19 @@ Status FusionInstructionMerger::HandleFusion(HloInstruction* fusion) {
StatusOr<bool> FusionMerger::Run(HloModule* module) {
bool changed = false;
VLOG(2) << "FusionMerger for module: " << module->name();
+ std::vector<HloComputation*> computations;
for (auto& computation : module->computations()) {
+ if (computation->IsFusionComputation()) {
+ continue;
+ }
+ computations.push_back(computation.get());
+ }
+ for (auto& computation : computations) {
VLOG(1) << "Before running FusionInstructionMerger for computation: "
<< computation->name();
XLA_VLOG_LINES(3, computation->ToString());
- FusionInstructionMerger fusion_merger(computation.get());
+ FusionInstructionMerger fusion_merger(computation);
TF_RETURN_IF_ERROR(fusion_merger.Run());
changed |= fusion_merger.changed();
diff --git a/tensorflow/compiler/xla/service/gpu/hlo_schedule.cc b/tensorflow/compiler/xla/service/gpu/hlo_schedule.cc
index c61e47a93c..81e905a066 100644
--- a/tensorflow/compiler/xla/service/gpu/hlo_schedule.cc
+++ b/tensorflow/compiler/xla/service/gpu/hlo_schedule.cc
@@ -120,7 +120,8 @@ GpuHloOrdering::GpuHloOrdering(
// do that yet since it's hard to ensure that the order here is the order used
// by IrEmitterNested. And mismatched ordering bugs would be hard to find.
for (auto& computation : module->computations()) {
- if (computation.get() != module->entry_computation()) {
+ if (computation.get() != module->entry_computation() &&
+ !computation->IsFusionComputation()) {
predecessors_.emplace(computation.get(),
computation->ComputeReachability());
}
diff --git a/tensorflow/compiler/xla/service/hlo_constant_folding.cc b/tensorflow/compiler/xla/service/hlo_constant_folding.cc
index 804efdd906..1a2eed5f60 100644
--- a/tensorflow/compiler/xla/service/hlo_constant_folding.cc
+++ b/tensorflow/compiler/xla/service/hlo_constant_folding.cc
@@ -42,6 +42,9 @@ StatusOr<bool> HloConstantFolding::Run(HloModule* module) {
bool changed = false;
for (auto& computation : module->computations()) {
+ if (computation->IsFusionComputation()) {
+ continue;
+ }
for (auto instruction : computation->MakeInstructionPostOrder()) {
// Skip dead code.
if (instruction->user_count() == 0 &&
diff --git a/tensorflow/compiler/xla/service/hlo_cost_analysis_test.cc b/tensorflow/compiler/xla/service/hlo_cost_analysis_test.cc
index f745683165..0a288a77ad 100644
--- a/tensorflow/compiler/xla/service/hlo_cost_analysis_test.cc
+++ b/tensorflow/compiler/xla/service/hlo_cost_analysis_test.cc
@@ -31,6 +31,7 @@ limitations under the License.
#include "tensorflow/compiler/xla/service/user_computation.h"
#include "tensorflow/compiler/xla/service/versioned_computation_handle.h"
#include "tensorflow/compiler/xla/shape_util.h"
+#include "tensorflow/compiler/xla/tests/hlo_test_base.h"
#include "tensorflow/core/platform/logging.h"
#include "tensorflow/compiler/xla/statusor.h"
@@ -329,7 +330,7 @@ TEST_F(HloCostAnalysisTest, MatmulAndConvolutionCanBeTheSameComputation) {
EXPECT_EQ(conv_analysis.flop_count(), matmul_analysis.flop_count());
}
-using FusionCostAnalysis = ::testing::Test;
+using FusionCostAnalysis = HloTestBase;
TEST_F(FusionCostAnalysis, LoopFusion) {
// Do this 4 times with different per-second rates to test the computation of
@@ -345,32 +346,32 @@ TEST_F(FusionCostAnalysis, LoopFusion) {
// mul = Mul(exp, C3)
// sub = Sub(mul, clamp)
// tuple = Tuple({sub, sub, mul, C1})
- auto c1 = HloInstruction::CreateConstant(Literal::CreateR2F32Linspace(
- /*from=*/0.0f, /*to=*/1.0f, /*rows=*/2, /*cols=*/2));
- auto c2 = HloInstruction::CreateConstant(Literal::CreateR2F32Linspace(
- /*from=*/1.0f, /*to=*/2.0f, /*rows=*/2, /*cols=*/2));
- auto c3 = HloInstruction::CreateConstant(Literal::CreateR2F32Linspace(
- /*from=*/2.0f, /*to=*/3.0f, /*rows=*/2, /*cols=*/2));
-
- auto add = HloInstruction::CreateBinary(r2f32, HloOpcode::kAdd, c1.get(),
- c2.get());
- auto clamp = HloInstruction::CreateTernary(r2f32, HloOpcode::kClamp,
- c2.get(), add.get(), add.get());
- auto exp = HloInstruction::CreateUnary(r2f32, HloOpcode::kExp, add.get());
- auto mul = HloInstruction::CreateBinary(r2f32, HloOpcode::kMultiply,
- exp.get(), c3.get());
- auto sub = HloInstruction::CreateBinary(r2f32, HloOpcode::kSubtract,
- mul.get(), clamp.get());
- auto tuple = HloInstruction::CreateTuple(
- {sub.get(), sub.get(), mul.get(), c1.get()});
-
- auto fusion = HloInstruction::CreateFusion(
- r2f32, HloInstruction::FusionKind::kLoop, tuple.get());
- fusion->FuseInstruction(sub.get());
- fusion->FuseInstruction(mul.get());
- fusion->FuseInstruction(exp.get());
- fusion->FuseInstruction(clamp.get());
- fusion->FuseInstruction(add.get());
+ HloComputation::Builder builder(TestName());
+ auto c1 = builder.AddInstruction(
+ HloInstruction::CreateConstant(Literal::CreateR2F32Linspace(
+ /*from=*/0.0f, /*to=*/1.0f, /*rows=*/2, /*cols=*/2)));
+ auto c2 = builder.AddInstruction(
+ HloInstruction::CreateConstant(Literal::CreateR2F32Linspace(
+ /*from=*/1.0f, /*to=*/2.0f, /*rows=*/2, /*cols=*/2)));
+ auto c3 = builder.AddInstruction(
+ HloInstruction::CreateConstant(Literal::CreateR2F32Linspace(
+ /*from=*/2.0f, /*to=*/3.0f, /*rows=*/2, /*cols=*/2)));
+ auto add = builder.AddInstruction(
+ HloInstruction::CreateBinary(r2f32, HloOpcode::kAdd, c1, c2));
+ auto clamp = builder.AddInstruction(
+ HloInstruction::CreateTernary(r2f32, HloOpcode::kClamp, c2, add, add));
+ auto exp = builder.AddInstruction(
+ HloInstruction::CreateUnary(r2f32, HloOpcode::kExp, add));
+ auto mul = builder.AddInstruction(
+ HloInstruction::CreateBinary(r2f32, HloOpcode::kMultiply, exp, c3));
+ auto sub = builder.AddInstruction(
+ HloInstruction::CreateBinary(r2f32, HloOpcode::kSubtract, mul, clamp));
+ auto tuple = HloInstruction::CreateTuple({sub, sub, mul, c1});
+
+ HloModule module(TestName());
+ auto* computation = module.AddEntryComputation(builder.Build());
+ auto* fusion = computation->CreateFusionInstruction(
+ {sub, mul, exp, clamp, add}, HloInstruction::FusionKind::kLoop);
// The time given these rates at i == 0 is exactly even among the properties
// at 1.0 seconds. For other values, one of the rates is slower so that it
@@ -398,18 +399,21 @@ TEST_F(FusionCostAnalysis, NoLayout) {
Shape shape_without_layout = shape_with_layout;
shape_without_layout.clear_layout();
- auto c1 = HloInstruction::CreateConstant(
- Literal::CreateR4FromArray4D(Array4D<float>(2, 3, 4, 5)));
- auto c2 = HloInstruction::CreateConstant(Literal::CreateR1<float>({1, 2, 3}));
-
- auto broadcast =
- HloInstruction::CreateBroadcast(shape_without_layout, c2.get(), {1});
- auto add = HloInstruction::CreateBinary(shape_with_layout, HloOpcode::kAdd,
- c1.get(), broadcast.get());
-
- auto fusion = HloInstruction::CreateFusion(
- shape_with_layout, HloInstruction::FusionKind::kLoop, add.get());
- fusion->FuseInstruction(broadcast.get());
+ HloComputation::Builder builder(TestName());
+ auto c1 = builder.AddInstruction(HloInstruction::CreateConstant(
+ Literal::CreateR4FromArray4D(Array4D<float>(2, 3, 4, 5))));
+ auto c2 = builder.AddInstruction(
+ HloInstruction::CreateConstant(Literal::CreateR1<float>({1, 2, 3})));
+
+ auto broadcast = builder.AddInstruction(
+ HloInstruction::CreateBroadcast(shape_without_layout, c2, {1}));
+ auto add = builder.AddInstruction(HloInstruction::CreateBinary(
+ shape_with_layout, HloOpcode::kAdd, c1, broadcast));
+
+ HloModule module(TestName());
+ auto* computation = module.AddEntryComputation(builder.Build());
+ auto* fusion = computation->CreateFusionInstruction(
+ {add, broadcast}, HloInstruction::FusionKind::kLoop);
HloCostAnalysis fusion_analysis(ShapeSize);
ASSERT_IS_OK(fusion->Accept(&fusion_analysis));
diff --git a/tensorflow/compiler/xla/service/hlo_cse.cc b/tensorflow/compiler/xla/service/hlo_cse.cc
index 0fef89a06d..690c084efb 100644
--- a/tensorflow/compiler/xla/service/hlo_cse.cc
+++ b/tensorflow/compiler/xla/service/hlo_cse.cc
@@ -92,6 +92,9 @@ bool CombineConstants(HloComputation* computation, bool is_layout_sensitive) {
StatusOr<bool> HloCSE::Run(HloModule* module) {
bool changed = false;
for (auto& computation : module->computations()) {
+ if (computation->IsFusionComputation()) {
+ continue;
+ }
changed |= CombineConstants(computation.get(), is_layout_sensitive_);
std::list<HloInstruction*> post_order =
diff --git a/tensorflow/compiler/xla/service/hlo_dce.cc b/tensorflow/compiler/xla/service/hlo_dce.cc
index 3755b9e4c0..5b2c57da4f 100644
--- a/tensorflow/compiler/xla/service/hlo_dce.cc
+++ b/tensorflow/compiler/xla/service/hlo_dce.cc
@@ -38,6 +38,9 @@ StatusOr<bool> HloDCE::Run(HloModule* module) {
bool changed = false;
for (auto& computation : module->computations()) {
+ if (computation->IsFusionComputation()) {
+ continue;
+ }
std::unordered_set<HloInstruction*> live_instructions;
TF_RETURN_IF_ERROR(computation->root_instruction()->Accept(
[&live_instructions](HloInstruction* instruction) {
diff --git a/tensorflow/compiler/xla/service/hlo_instruction.cc b/tensorflow/compiler/xla/service/hlo_instruction.cc
index f52882cca5..ed8a942d03 100644
--- a/tensorflow/compiler/xla/service/hlo_instruction.cc
+++ b/tensorflow/compiler/xla/service/hlo_instruction.cc
@@ -560,19 +560,20 @@ HloInstruction* HloInstruction::CloneAndFuseInternal(
HloInstruction* instruction_to_fuse) {
CHECK_EQ(opcode_, HloOpcode::kFusion);
CHECK(instruction_to_fuse->IsFusable());
-
+ if (GetModule()) {
+ XLA_VLOG_LINES(1, GetModule()->ToString());
+ }
HloInstruction* clone = nullptr;
- if (fused_instructions_computation_ == nullptr) {
+ if (called_computations_.empty()) {
// New fusion instruction.
auto builder = HloComputation::Builder("fused_computation", true);
builder.AddInstruction(instruction_to_fuse->Clone(/*suffix=*/""));
- fused_instructions_computation_ = builder.Build();
+ called_computations_.push_back(
+ CHECK_NOTNULL(GetModule())->AddEmbeddedComputation(builder.Build()));
clone = fused_expression_root();
clone->parent_fusion_instruction_ = this;
} else {
- CHECK(fused_instructions_computation_ != nullptr &&
- fused_instructions_computation_->IsFusionComputation());
- clone = fused_instructions_computation_->AddInstruction(
+ clone = fused_instructions_computation()->AddInstruction(
instruction_to_fuse->Clone(/*suffix=*/""));
clone->parent_fusion_instruction_ = this;
// instruction_to_fuse is necessarily an operand of the fusion instruction.
@@ -583,7 +584,7 @@ HloInstruction* HloInstruction::CloneAndFuseInternal(
CHECK(std::find(operands_.begin(), operands_.end(), instruction_to_fuse) !=
operands_.end());
const std::vector<HloInstruction*>& fused_parameters_ =
- fused_instructions_computation_->parameter_instructions();
+ fused_instructions_computation()->parameter_instructions();
for (int64 operand_num = 0; operand_num < operand_count(); ++operand_num) {
if (instruction_to_fuse == operands_[operand_num]) {
// replace the fused parameter instruction's uses with the clone.
@@ -593,7 +594,7 @@ HloInstruction* HloInstruction::CloneAndFuseInternal(
// Remove the corresponding fused parameter and operand from their
// respective vectors.
TF_CHECK_OK(
- fused_instructions_computation_->RemoveParameter(operand_num));
+ fused_instructions_computation()->RemoveParameter(operand_num));
operands_.erase(operands_.begin() + operand_num);
break;
}
@@ -605,7 +606,7 @@ HloInstruction* HloInstruction::CloneAndFuseInternal(
// Reread the parameters in the computation.
const std::vector<HloInstruction*>& fused_parameters_ =
- fused_instructions_computation_->parameter_instructions();
+ fused_instructions_computation()->parameter_instructions();
// Add each operand of the clone as an operand of the fusion instruction. A
// complication is that some clone operands may already be operands of the
@@ -638,7 +639,7 @@ HloInstruction* HloInstruction::CloneAndFuseInternal(
CreateParameter(param_no, operand->shape(), param_name);
param_instruction->parent_fusion_instruction_ = this;
- fused_param = fused_instructions_computation_->AddParameter(
+ fused_param = fused_instructions_computation()->AddParameter(
std::move(param_instruction));
AppendOperand(operand);
}
@@ -652,7 +653,6 @@ HloInstruction* HloInstruction::CloneAndFuseInternal(
called_computations_.push_back(computation);
}
}
-
return clone;
}
@@ -663,17 +663,15 @@ RandomDistribution HloInstruction::random_distribution() const {
void HloInstruction::CheckFusionInstruction() const {
CHECK_EQ(opcode_, HloOpcode::kFusion);
- CHECK(fused_instructions_computation_ != nullptr &&
- fused_instructions_computation_->IsFusionComputation());
const std::list<std::unique_ptr<HloInstruction>>& fused_instructions_ =
- fused_instructions_computation_->instructions();
+ fused_instructions_computation()->instructions();
// All instructions owned by this fusion instruction must be fused, and the
// parent fusion instruction of the fused instructions must be 'this'.
for (auto& instruction : fused_instructions_) {
CHECK(instruction->IsFused());
CHECK_EQ(this, instruction->fusion_instruction());
- CHECK_EQ(fused_instructions_computation_.get(), instruction->parent())
+ CHECK_EQ(fused_instructions_computation(), instruction->parent())
<< instruction->ToString();
}
@@ -976,8 +974,6 @@ std::unique_ptr<HloInstruction> HloInstruction::CloneFusionWithNewOperands(
const Shape& shape, tensorflow::gtl::ArraySlice<HloInstruction*> operands) {
CHECK_EQ(opcode_, HloOpcode::kFusion);
CHECK(parent() != nullptr);
- CHECK(fused_instructions_computation_ != nullptr &&
- fused_instructions_computation_->IsFusionComputation());
auto new_instruction =
WrapUnique(new HloInstruction(HloOpcode::kFusion, shape));
@@ -992,9 +988,9 @@ std::unique_ptr<HloInstruction> HloInstruction::CloneFusionWithNewOperands(
// fused instructions.
std::vector<HloInstruction*> new_fused_parameters;
const std::vector<HloInstruction*>& fused_parameters_ =
- fused_instructions_computation_->parameter_instructions();
+ fused_instructions_computation()->parameter_instructions();
const std::list<std::unique_ptr<HloInstruction>>& fused_instructions_ =
- fused_instructions_computation_->instructions();
+ fused_instructions_computation()->instructions();
for (HloInstruction* old_fused_parameter : fused_parameters_) {
new_fused_instructions.push_back(old_fused_parameter->Clone());
@@ -1028,7 +1024,7 @@ std::unique_ptr<HloInstruction> HloInstruction::CloneFusionWithNewOperands(
}
new_instruction->fusion_kind_ = fusion_kind_;
auto computation_builder = HloComputation::Builder(
- fused_instructions_computation_->name() + ".clone", true);
+ fused_instructions_computation()->name() + ".clone", true);
// We iterated the fusion instructions in reverse post order which means
// that we must reverse our new list of fusion instructions.
for (auto new_fused_instruction_iter = new_fused_instructions.rbegin();
@@ -1037,8 +1033,10 @@ std::unique_ptr<HloInstruction> HloInstruction::CloneFusionWithNewOperands(
computation_builder.AddInstruction(std::move(*new_fused_instruction_iter));
}
auto fused_root_ = fused_expression_root();
- new_instruction->fused_instructions_computation_ =
- computation_builder.Build(FindOrDie(old_to_new, fused_root_));
+ new_instruction->called_computations_.push_back(
+ CHECK_NOTNULL(GetModule())
+ ->AddEmbeddedComputation(
+ computation_builder.Build(FindOrDie(old_to_new, fused_root_))));
new_instruction->set_parent(parent());
new_instruction->CheckFusionInstruction();
return new_instruction;
@@ -1769,7 +1767,10 @@ bool HloInstruction::IsFusable() const {
HloComputation* HloInstruction::fused_instructions_computation() const {
CHECK_EQ(opcode_, HloOpcode::kFusion);
- return fused_instructions_computation_.get();
+ CHECK(!called_computations_.empty());
+ auto* fused_instructions_computation = called_computations_.front();
+ CHECK(fused_instructions_computation->IsFusionComputation());
+ return fused_instructions_computation;
}
HloInstruction* HloInstruction::fusion_instruction() const {
@@ -1779,32 +1780,24 @@ HloInstruction* HloInstruction::fusion_instruction() const {
HloInstruction* HloInstruction::fused_expression_root() const {
CHECK_EQ(opcode_, HloOpcode::kFusion);
- CHECK(fused_instructions_computation_ != nullptr &&
- fused_instructions_computation_->IsFusionComputation());
- return fused_instructions_computation_->root_instruction();
+ return fused_instructions_computation()->root_instruction();
}
HloInstruction* HloInstruction::fused_parameter(int64 parameter_number) const {
CHECK_EQ(opcode_, HloOpcode::kFusion);
- CHECK(fused_instructions_computation_ != nullptr &&
- fused_instructions_computation_->IsFusionComputation());
- return fused_instructions_computation_->parameter_instruction(
+ return fused_instructions_computation()->parameter_instruction(
parameter_number);
}
const std::vector<HloInstruction*>& HloInstruction::fused_parameters() const {
CHECK_EQ(opcode_, HloOpcode::kFusion);
- CHECK(fused_instructions_computation_ != nullptr &&
- fused_instructions_computation_->IsFusionComputation());
- return fused_instructions_computation_->parameter_instructions();
+ return fused_instructions_computation()->parameter_instructions();
}
const std::list<std::unique_ptr<HloInstruction>>&
HloInstruction::fused_instructions() const {
CHECK_EQ(opcode_, HloOpcode::kFusion);
- CHECK(fused_instructions_computation_ != nullptr &&
- fused_instructions_computation_->IsFusionComputation());
- return fused_instructions_computation_->instructions();
+ return fused_instructions_computation()->instructions();
}
HloInstruction::HloInstruction(HloOpcode opcode, const Shape& shape)
@@ -2039,7 +2032,7 @@ static Status PostOrderDFS(HloInstruction* root, DfsHloVisitor* visitor,
Status HloInstruction::Accept(DfsHloVisitor* visitor, bool call_finish_visit,
bool ignore_control_predecessors) {
- VLOG(2) << "HloInstruction::Accept(" << name() << ")";
+ VLOG(3) << "HloInstruction::Accept(" << name() << ")";
TF_RETURN_IF_ERROR(
PostOrderDFS(this, visitor, nullptr, ignore_control_predecessors));
if (call_finish_visit) {
@@ -2055,8 +2048,11 @@ Status HloInstruction::AcceptWithOperandOrder(
TF_RETURN_IF_ERROR(PostOrderDFS(this, visitor, &operand_order,
/*ignore_control_predecessors=*/false));
if (call_finish_visit) {
+ VLOG(3) << "HloInstruction::AcceptWithOperandOrder BEFORE FINISH VISIT";
TF_RETURN_IF_ERROR(visitor->FinishVisit(this));
+ VLOG(3) << "HloInstruction::AcceptWithOperandOrder AFTER FINISH VISIT";
}
+ VLOG(2) << "HloInstruction::AcceptWithOperandOrder EXIT";
return Status::OK();
}
@@ -2458,6 +2454,7 @@ HloModule* HloInstruction::GetModule() const {
}
void HloInstruction::UniquifyName(NameUniquer* name_uniquer) {
+ string parent_str = parent() == nullptr ? "noparent" : parent()->name();
name_ = name_uniquer->GetUniqueName(name_);
}
diff --git a/tensorflow/compiler/xla/service/hlo_instruction.h b/tensorflow/compiler/xla/service/hlo_instruction.h
index e2e77e5219..3c188ec83f 100644
--- a/tensorflow/compiler/xla/service/hlo_instruction.h
+++ b/tensorflow/compiler/xla/service/hlo_instruction.h
@@ -935,10 +935,6 @@ class HloInstruction {
// padding of this pad instruction. Only set for pad instructions.
std::unique_ptr<PaddingConfig> padding_config_;
- // The computation that stores of instructions fused into this fusion
- // instruction. Only set for fusion instructions.
- std::unique_ptr<HloComputation> fused_instructions_computation_;
-
// If this instruction is fused into a fusion instruction, this field points
// to the fusion instruction.
HloInstruction* parent_fusion_instruction_ = nullptr;
diff --git a/tensorflow/compiler/xla/service/hlo_instruction_test.cc b/tensorflow/compiler/xla/service/hlo_instruction_test.cc
index bb1b477e13..5951c833db 100644
--- a/tensorflow/compiler/xla/service/hlo_instruction_test.cc
+++ b/tensorflow/compiler/xla/service/hlo_instruction_test.cc
@@ -557,78 +557,89 @@ TEST_F(HloInstructionTest, PostProcessAllVisitedNodes) {
}
TEST_F(HloInstructionTest, SingletonFusionOp) {
+ HloComputation::Builder builder(TestName());
// Create a fusion instruction containing a single unary operation.
- auto constant =
- HloInstruction::CreateConstant(Literal::CreateR0<float>(1.1f));
- auto exp =
- HloInstruction::CreateUnary(r0f32_, HloOpcode::kExp, constant.get());
-
- auto fusion = HloInstruction::CreateFusion(
- r0f32_, HloInstruction::FusionKind::kLoop, exp.get());
-
- EXPECT_THAT(fusion->operands(), ElementsAre(constant.get()));
- EXPECT_THAT(constant->users(), UnorderedElementsAre(fusion.get(), exp.get()));
+ auto constant = builder.AddInstruction(
+ HloInstruction::CreateConstant(Literal::CreateR0<float>(1.1f)));
+ auto exp = builder.AddInstruction(
+ HloInstruction::CreateUnary(r0f32_, HloOpcode::kExp, constant));
+ HloModule module(TestName());
+ auto* computation = module.AddEntryComputation(builder.Build());
+ auto* fusion = computation->CreateFusionInstruction(
+ {exp}, HloInstruction::FusionKind::kLoop);
+
+ EXPECT_THAT(fusion->operands(), ElementsAre(constant));
+ EXPECT_THAT(constant->users(), ElementsAre(fusion));
}
TEST_F(HloInstructionTest, BinaryFusionOp) {
+ HloComputation::Builder builder(TestName());
// Create a fusion instruction containing a single binary operation.
- auto constant1 =
- HloInstruction::CreateConstant(Literal::CreateR0<float>(1.1f));
- auto constant2 =
- HloInstruction::CreateConstant(Literal::CreateR0<float>(42.1f));
- auto add = HloInstruction::CreateBinary(r0f32_, HloOpcode::kAdd,
- constant1.get(), constant2.get());
-
- auto fusion = HloInstruction::CreateFusion(
- r0f32_, HloInstruction::FusionKind::kLoop, add.get());
-
- EXPECT_THAT(fusion->operands(),
- ElementsAre(constant1.get(), constant2.get()));
- EXPECT_THAT(constant1->users(),
- UnorderedElementsAre(fusion.get(), add.get()));
- EXPECT_THAT(constant2->users(),
- UnorderedElementsAre(fusion.get(), add.get()));
+ auto constant1 = builder.AddInstruction(
+ HloInstruction::CreateConstant(Literal::CreateR0<float>(1.1f)));
+ auto constant2 = builder.AddInstruction(
+ HloInstruction::CreateConstant(Literal::CreateR0<float>(42.1f)));
+ auto add = builder.AddInstruction(HloInstruction::CreateBinary(
+ r0f32_, HloOpcode::kAdd, constant1, constant2));
+ HloModule module(TestName());
+ auto* computation = module.AddEntryComputation(builder.Build());
+ auto* fusion = computation->CreateFusionInstruction(
+ {add}, HloInstruction::FusionKind::kLoop);
+
+ EXPECT_THAT(fusion->operands(), ElementsAre(constant1, constant2));
+ EXPECT_THAT(constant1->users(), ElementsAre(fusion));
+ EXPECT_THAT(constant2->users(), ElementsAre(fusion));
}
TEST_F(HloInstructionTest, ChainFusionOp) {
+ HloComputation::Builder builder(TestName());
// Create a chain of fused unary ops.
- auto constant =
- HloInstruction::CreateConstant(Literal::CreateR0<float>(1.1f));
- auto exp1 =
- HloInstruction::CreateUnary(r0f32_, HloOpcode::kExp, constant.get());
- auto exp2 = HloInstruction::CreateUnary(r0f32_, HloOpcode::kExp, exp1.get());
- auto exp3 = HloInstruction::CreateUnary(r0f32_, HloOpcode::kExp, exp2.get());
-
- auto fusion = HloInstruction::CreateFusion(
- r0f32_, HloInstruction::FusionKind::kLoop, exp3.get());
- fusion->FuseInstruction(exp2.get());
- fusion->FuseInstruction(exp1.get());
-
- EXPECT_THAT(fusion->operands(), ElementsAre(constant.get()));
- EXPECT_THAT(constant->users(),
- UnorderedElementsAre(fusion.get(), exp1.get()));
+ auto constant = builder.AddInstruction(
+ HloInstruction::CreateConstant(Literal::CreateR0<float>(1.1f)));
+ auto exp1 = builder.AddInstruction(
+ HloInstruction::CreateUnary(r0f32_, HloOpcode::kExp, constant));
+ auto exp2 = builder.AddInstruction(
+ HloInstruction::CreateUnary(r0f32_, HloOpcode::kExp, exp1));
+ auto exp3 = builder.AddInstruction(
+ HloInstruction::CreateUnary(r0f32_, HloOpcode::kExp, exp2));
+
+ HloModule module(TestName());
+ auto* computation = module.AddEntryComputation(builder.Build());
+ auto* fusion = computation->CreateFusionInstruction(
+ {exp3, exp2, exp1}, HloInstruction::FusionKind::kLoop);
+
+ EXPECT_THAT(fusion->operands(), ElementsAre(constant));
+ EXPECT_THAT(constant->users(), ElementsAre(fusion));
}
TEST_F(HloInstructionTest, PreserveMetadataInFusionAndClone) {
+ HloComputation::Builder builder(TestName());
// Create a chain of fused unary ops.
- auto constant =
- HloInstruction::CreateConstant(Literal::CreateR0<float>(1.1f));
- auto exp1 =
- HloInstruction::CreateUnary(r0f32_, HloOpcode::kExp, constant.get());
- auto exp2 = HloInstruction::CreateUnary(r0f32_, HloOpcode::kExp, exp1.get());
+ auto constant = builder.AddInstruction(
+ HloInstruction::CreateConstant(Literal::CreateR0<float>(1.1f)));
+ auto exp1 = builder.AddInstruction(
+ HloInstruction::CreateUnary(r0f32_, HloOpcode::kExp, constant));
+ auto exp2 = builder.AddInstruction(
+ HloInstruction::CreateUnary(r0f32_, HloOpcode::kExp, exp1));
OpMetadata metadata;
metadata.set_op_name("tf_op");
exp1->set_metadata(metadata);
exp2->set_metadata(metadata);
- auto fusion = HloInstruction::CreateFusion(
- r0f32_, HloInstruction::FusionKind::kLoop, exp2.get());
- auto* fused = fusion->FuseInstruction(exp1.get());
+ HloModule module(TestName());
+ auto* computation = module.AddEntryComputation(builder.Build());
+ auto* fusion = computation->CreateFusionInstruction(
+ {exp2, exp1}, HloInstruction::FusionKind::kLoop);
+
EXPECT_TRUE(protobuf_util::ProtobufEquals(metadata, fusion->metadata()));
- EXPECT_TRUE(protobuf_util::ProtobufEquals(metadata, fused->metadata()));
+ EXPECT_TRUE(protobuf_util::ProtobufEquals(
+ metadata, fusion->fused_expression_root()->metadata()));
+ EXPECT_TRUE(protobuf_util::ProtobufEquals(
+ metadata, fusion->fused_expression_root()->operand(0)->metadata()));
}
TEST_F(HloInstructionTest, FusionOpWithCalledComputations) {
+ HloComputation::Builder builder(TestName());
// Create a fusion instruction containing a single unary operation.
const Shape scalar_shape = ShapeUtil::MakeShape(F32, {});
@@ -642,33 +653,36 @@ TEST_F(HloInstructionTest, FusionOpWithCalledComputations) {
std::unique_ptr<HloComputation> computation_x = make_map_computation();
std::unique_ptr<HloComputation> computation_y = make_map_computation();
- auto constant =
- HloInstruction::CreateConstant(Literal::CreateR0<float>(1.1f));
- auto map_1_x =
- HloInstruction::CreateMap(scalar_shape, {constant.get()},
- computation_x.get(), /*static_operands=*/{});
- auto map_2_x =
- HloInstruction::CreateMap(scalar_shape, {map_1_x.get()},
- computation_x.get(), /*static_operands=*/{});
- auto map_3_y =
- HloInstruction::CreateMap(scalar_shape, {map_2_x.get()},
- computation_y.get(), /*static_operands=*/{});
-
- auto fusion = HloInstruction::CreateFusion(
- scalar_shape, HloInstruction::FusionKind::kLoop, map_3_y.get());
-
- EXPECT_THAT(fusion->called_computations(), ElementsAre(computation_y.get()));
-
- fusion->FuseInstruction(map_2_x.get());
+ auto constant = builder.AddInstruction(
+ HloInstruction::CreateConstant(Literal::CreateR0<float>(1.1f)));
+ auto map_1_x = builder.AddInstruction(HloInstruction::CreateMap(
+ scalar_shape, {constant}, computation_x.get(), /*static_operands=*/{}));
+ auto map_2_x = builder.AddInstruction(HloInstruction::CreateMap(
+ scalar_shape, {map_1_x}, computation_x.get(), /*static_operands=*/{}));
+ auto map_3_y = builder.AddInstruction(HloInstruction::CreateMap(
+ scalar_shape, {map_2_x}, computation_y.get(), /*static_operands=*/{}));
+
+ HloModule module(TestName());
+ auto* computation = module.AddEntryComputation(builder.Build());
+ auto* fusion = computation->CreateFusionInstruction(
+ {map_3_y}, HloInstruction::FusionKind::kLoop);
+ auto* fused_computation = fusion->fused_instructions_computation();
EXPECT_THAT(fusion->called_computations(),
- ElementsAre(computation_y.get(), computation_x.get()));
+ ElementsAre(fused_computation, computation_y.get()));
- fusion->FuseInstruction(map_1_x.get());
- EXPECT_THAT(fusion->called_computations(),
- ElementsAre(computation_y.get(), computation_x.get()));
+ fusion->FuseInstruction(map_2_x);
+ EXPECT_THAT(
+ fusion->called_computations(),
+ ElementsAre(fused_computation, computation_y.get(), computation_x.get()));
+
+ fusion->FuseInstruction(map_1_x);
+ EXPECT_THAT(
+ fusion->called_computations(),
+ ElementsAre(fused_computation, computation_y.get(), computation_x.get()));
}
TEST_F(HloInstructionTest, ComplexFusionOp) {
+ HloComputation::Builder builder(TestName());
// Fuse all instructions in complicated expression:
//
// add = Add(C1, C2)
@@ -680,35 +694,35 @@ TEST_F(HloInstructionTest, ComplexFusionOp) {
//
// Notable complexities are repeated operands in a same instruction, different
// shapes, use of value in different expressions.
- auto c1 = HloInstruction::CreateConstant(Literal::CreateR0<float>(1.1f));
- auto c2 = HloInstruction::CreateConstant(Literal::CreateR0<float>(2.1f));
- auto c3 = HloInstruction::CreateConstant(Literal::CreateR0<float>(9.0f));
-
- auto add =
- HloInstruction::CreateBinary(r0f32_, HloOpcode::kAdd, c1.get(), c2.get());
- auto clamp = HloInstruction::CreateTernary(r0f32_, HloOpcode::kClamp,
- c2.get(), add.get(), add.get());
- auto exp = HloInstruction::CreateUnary(r0f32_, HloOpcode::kExp, add.get());
- auto mul = HloInstruction::CreateBinary(r0f32_, HloOpcode::kMultiply,
- exp.get(), c3.get());
- auto sub = HloInstruction::CreateBinary(r0f32_, HloOpcode::kSubtract,
- mul.get(), clamp.get());
+ auto c1 = builder.AddInstruction(
+ HloInstruction::CreateConstant(Literal::CreateR0<float>(1.1f)));
+ auto c2 = builder.AddInstruction(
+ HloInstruction::CreateConstant(Literal::CreateR0<float>(2.1f)));
+ auto c3 = builder.AddInstruction(
+ HloInstruction::CreateConstant(Literal::CreateR0<float>(9.0f)));
+
+ auto add = builder.AddInstruction(
+ HloInstruction::CreateBinary(r0f32_, HloOpcode::kAdd, c1, c2));
+ auto clamp = builder.AddInstruction(
+ HloInstruction::CreateTernary(r0f32_, HloOpcode::kClamp, c2, add, add));
+ auto exp = builder.AddInstruction(
+ HloInstruction::CreateUnary(r0f32_, HloOpcode::kExp, add));
+ auto mul = builder.AddInstruction(
+ HloInstruction::CreateBinary(r0f32_, HloOpcode::kMultiply, exp, c3));
+ auto sub = builder.AddInstruction(
+ HloInstruction::CreateBinary(r0f32_, HloOpcode::kSubtract, mul, clamp));
auto tuple =
- HloInstruction::CreateTuple({sub.get(), sub.get(), mul.get(), c1.get()});
+ builder.AddInstruction(HloInstruction::CreateTuple({sub, sub, mul, c1}));
- auto fusion = HloInstruction::CreateFusion(
- r0f32_, HloInstruction::FusionKind::kLoop, tuple.get());
- fusion->FuseInstruction(sub.get());
- fusion->FuseInstruction(mul.get());
- fusion->FuseInstruction(exp.get());
- fusion->FuseInstruction(clamp.get());
- fusion->FuseInstruction(add.get());
+ HloModule module(TestName());
+ auto* computation = module.AddEntryComputation(builder.Build());
+ auto* fusion = computation->CreateFusionInstruction(
+ {tuple, sub, mul, exp, clamp, add}, HloInstruction::FusionKind::kLoop);
// Operands in the fusion instruction's operands() vector should be in the
// order in which their users were added fused.
- EXPECT_THAT(fusion->operands(), ElementsAre(c1.get(), c3.get(), c2.get()));
- EXPECT_THAT(c1->users(),
- UnorderedElementsAre(add.get(), tuple.get(), fusion.get()));
+ EXPECT_THAT(fusion->operands(), ElementsAre(c1, c3, c2));
+ EXPECT_THAT(c1->users(), ElementsAre(fusion));
}
// Convenience function for comparing two HloInstructions inside of
@@ -864,7 +878,8 @@ TEST_F(HloInstructionTest, PartiallyElementwise) {
HloInstruction* max = builder.AddInstruction(
HloInstruction::CreateBinary(r2f32, HloOpcode::kMaximum, div, broadcast));
- auto computation = builder.Build();
+ HloModule module(TestName());
+ auto* computation = module.AddEntryComputation(builder.Build());
HloInstruction* fusion = computation->CreateFusionInstruction(
{max, broadcast, div, mul}, HloInstruction::FusionKind::kLoop);
EXPECT_FALSE(fusion->IsElementwise());
@@ -906,7 +921,8 @@ TEST_F(HloInstructionTest, PartiallyElementwiseWithReuse) {
HloInstruction* sub = builder.AddInstruction(HloInstruction::CreateBinary(
r1f32, HloOpcode::kSubtract, min, broadcast));
- auto computation = builder.Build();
+ HloModule module(TestName());
+ auto* computation = module.AddEntryComputation(builder.Build());
HloInstruction* fusion = computation->CreateFusionInstruction(
{sub, broadcast, min}, HloInstruction::FusionKind::kLoop);
EXPECT_FALSE(fusion->IsElementwise());
@@ -945,7 +961,8 @@ TEST_F(HloInstructionTest, CloneOfFusionPreservesShape) {
HloInstruction* dot = builder.AddInstruction(
HloInstruction::CreateBinary(sout, HloOpcode::kDot, x, reshape));
- auto computation = builder.Build();
+ HloModule module(TestName());
+ auto* computation = module.AddEntryComputation(builder.Build());
HloInstruction* fusion = computation->CreateFusionInstruction(
{dot, reshape}, HloInstruction::FusionKind::kTransposeDot);
diff --git a/tensorflow/compiler/xla/service/hlo_ordering.cc b/tensorflow/compiler/xla/service/hlo_ordering.cc
index 7230682d0b..4c3ff3bdaf 100644
--- a/tensorflow/compiler/xla/service/hlo_ordering.cc
+++ b/tensorflow/compiler/xla/service/hlo_ordering.cc
@@ -183,6 +183,9 @@ DependencyHloOrdering::DependencyHloOrdering(const HloModule* module)
// ordering based on dependencies. ExecutesBefore will return true iff there
// exists a path in the HLO computation graph from 'a' to 'b'.
for (auto& computation : module->computations()) {
+ if (computation->IsFusionComputation()) {
+ continue;
+ }
predecessors_.emplace(computation.get(),
computation->ComputeReachability());
}
diff --git a/tensorflow/compiler/xla/service/hlo_rematerialization.cc b/tensorflow/compiler/xla/service/hlo_rematerialization.cc
index d19e8034ac..fd08796e50 100644
--- a/tensorflow/compiler/xla/service/hlo_rematerialization.cc
+++ b/tensorflow/compiler/xla/service/hlo_rematerialization.cc
@@ -1202,6 +1202,9 @@ StatusOr<bool> HloRematerialization::Run(
// After DCE, the module sequence may include instructions which no longer
// exist.
for (const auto& computation : module->computations()) {
+ if (computation->IsFusionComputation()) {
+ continue;
+ }
if (sequence->at(computation.get()).size() !=
computation->instruction_count()) {
// A size mismatch between the computation instruction count and the size
diff --git a/tensorflow/compiler/xla/service/hlo_scheduling.cc b/tensorflow/compiler/xla/service/hlo_scheduling.cc
index 17f55f9cfb..922236ee1e 100644
--- a/tensorflow/compiler/xla/service/hlo_scheduling.cc
+++ b/tensorflow/compiler/xla/service/hlo_scheduling.cc
@@ -400,6 +400,9 @@ CreateMemoryMinimizingSequence(
TF_ASSIGN_OR_RETURN(std::unique_ptr<TuplePointsToAnalysis> points_to_analysis,
TuplePointsToAnalysis::Run(&module));
for (const auto& computation : module.computations()) {
+ if (computation->IsFusionComputation()) {
+ continue;
+ }
TF_ASSIGN_OR_RETURN(sequence[computation.get()],
CreateMemoryMinimizingSequence(
*computation, *points_to_analysis, size_function));
@@ -410,6 +413,7 @@ CreateMemoryMinimizingSequence(
StatusOr<std::vector<const HloInstruction*>> CreateMemoryMinimizingSequence(
const HloComputation& computation,
const LogicalBuffer::SizeFunction& size_function) {
+ CHECK(!computation.IsFusionComputation());
TF_ASSIGN_OR_RETURN(std::unique_ptr<TuplePointsToAnalysis> points_to_analysis,
TuplePointsToAnalysis::Run(computation.parent()));
return CreateMemoryMinimizingSequence(computation, *points_to_analysis,
diff --git a/tensorflow/compiler/xla/service/instruction_fusion.cc b/tensorflow/compiler/xla/service/instruction_fusion.cc
index 482ab9b94a..24af07bd4b 100644
--- a/tensorflow/compiler/xla/service/instruction_fusion.cc
+++ b/tensorflow/compiler/xla/service/instruction_fusion.cc
@@ -211,8 +211,17 @@ bool InstructionFusion::CanFuseOnAllPaths(
StatusOr<bool> InstructionFusion::Run(HloModule* module) {
bool changed = false;
+
+ std::vector<HloComputation*> computations;
for (auto& computation : module->computations()) {
- computation_ = computation.get();
+ if (computation->IsFusionComputation()) {
+ continue;
+ }
+ computations.push_back(computation.get());
+ }
+ for (auto& computation : computations) {
+ CHECK(!computation->IsFusionComputation());
+ computation_ = computation;
// We want to be able to remove arbitrary instructions from the post order
// and also compare positions of instructions in the post order. To make
diff --git a/tensorflow/compiler/xla/service/layout_assignment.cc b/tensorflow/compiler/xla/service/layout_assignment.cc
index aafface0b9..7d41be94ce 100644
--- a/tensorflow/compiler/xla/service/layout_assignment.cc
+++ b/tensorflow/compiler/xla/service/layout_assignment.cc
@@ -611,6 +611,9 @@ Status CheckLayouts(
TF_ASSIGN_OR_RETURN(auto points_to_analysis,
TuplePointsToAnalysis::Run(module));
for (auto& computation : module->computations()) {
+ if (computation->IsFusionComputation()) {
+ continue;
+ }
for (auto& instruction : computation->instructions()) {
// Verify every instruction has a layout and the layout is valid for the
// shape.
@@ -1356,6 +1359,8 @@ StatusOr<bool> LayoutAssignment::Run(HloModule* module) {
if (computation == module->entry_computation()) {
TF_RETURN_IF_ERROR(RunOnComputation(*entry_computation_layout_,
module->entry_computation()));
+ } else if (computation->IsFusionComputation()) {
+ continue;
} else {
ComputationLayout computation_layout(computation->ComputeProgramShape());
// Setting all embedded computations to the default layout is potentially
diff --git a/tensorflow/compiler/xla/service/name_uniquer.cc b/tensorflow/compiler/xla/service/name_uniquer.cc
index 4014856b9b..069f85af72 100644
--- a/tensorflow/compiler/xla/service/name_uniquer.cc
+++ b/tensorflow/compiler/xla/service/name_uniquer.cc
@@ -29,7 +29,11 @@ string NameUniquer::GetUniqueName(tensorflow::StringPiece prefix) {
return root;
} else {
tensorflow::strings::StrAppend(&root, separator_, *count);
+ // Increment lookup under old 'root' name.
(*count)++;
+ // Initialize count under new 'root' name.
+ count = &(generated_names_[root]);
+ *count = 1;
return root;
}
}
diff --git a/tensorflow/compiler/xla/service/reduce_precision_insertion.cc b/tensorflow/compiler/xla/service/reduce_precision_insertion.cc
index e083226b14..9f12471ffd 100644
--- a/tensorflow/compiler/xla/service/reduce_precision_insertion.cc
+++ b/tensorflow/compiler/xla/service/reduce_precision_insertion.cc
@@ -26,6 +26,9 @@ StatusOr<bool> ReducePrecisionInsertion::Run(HloModule* module) {
VLOG(1) << "Running ReducePrecisionInsertion pass on " << module->name();
for (auto& computation : module->computations()) {
+ if (computation->IsFusionComputation()) {
+ continue;
+ }
std::vector<HloInstruction*> instructions_to_suffix;
for (auto& instruction : computation->instructions()) {
diff --git a/tensorflow/compiler/xla/service/reshape_mover.cc b/tensorflow/compiler/xla/service/reshape_mover.cc
index 2d35ba5e54..1c648d58c7 100644
--- a/tensorflow/compiler/xla/service/reshape_mover.cc
+++ b/tensorflow/compiler/xla/service/reshape_mover.cc
@@ -312,10 +312,17 @@ StatusOr<bool> TrySinkReshapeOrTranspose(HloComputation* computation,
StatusOr<bool> ReshapeMover::Run(HloModule* module) {
bool changed = false;
- for (const auto& comp : module->computations()) {
+ std::vector<HloComputation*> computations;
+ for (auto& computation : module->computations()) {
+ if (computation->IsFusionComputation()) {
+ continue;
+ }
+ computations.push_back(computation.get());
+ }
+ for (const auto& comp : computations) {
for (HloInstruction* instruction : comp->MakeInstructionPostOrder()) {
TF_ASSIGN_OR_RETURN(bool did_change,
- TrySinkReshapeOrTranspose(comp.get(), instruction));
+ TrySinkReshapeOrTranspose(comp, instruction));
changed |= did_change;
}
}
diff --git a/tensorflow/compiler/xla/service/reshape_mover_test.cc b/tensorflow/compiler/xla/service/reshape_mover_test.cc
index 49c1755520..1589d52a25 100644
--- a/tensorflow/compiler/xla/service/reshape_mover_test.cc
+++ b/tensorflow/compiler/xla/service/reshape_mover_test.cc
@@ -351,16 +351,15 @@ TEST_F(ReshapeMoverTest, EquivalentReshapesMovedAcrossFusion) {
auto add = builder.AddInstruction(HloInstruction::CreateBinary(
root_shape, HloOpcode::kAdd, reshape0, reshape1));
- auto module = CreateNewModule();
- auto computation = module->AddEntryComputation(builder.Build());
- auto fusion = computation->AddInstruction(HloInstruction::CreateFusion(
- add->shape(), HloInstruction::FusionKind::kLoop, add));
- TF_CHECK_OK(computation->ReplaceInstruction(add, fusion));
+ HloModule module(TestName());
+ auto computation = module.AddEntryComputation(builder.Build());
+ computation->CreateFusionInstruction({add},
+ HloInstruction::FusionKind::kLoop);
EXPECT_THAT(computation->root_instruction(),
op::Fusion(op::Reshape(param0), op::Reshape(param1)));
- EXPECT_TRUE(ReshapeMover().Run(module.get()).ValueOrDie());
+ EXPECT_TRUE(ReshapeMover().Run(&module).ValueOrDie());
EXPECT_THAT(computation->root_instruction(),
op::Reshape(op::Fusion(param0, param1)));
diff --git a/tensorflow/compiler/xla/service/transpose_folding.cc b/tensorflow/compiler/xla/service/transpose_folding.cc
index a0c88c6bbc..5858335736 100644
--- a/tensorflow/compiler/xla/service/transpose_folding.cc
+++ b/tensorflow/compiler/xla/service/transpose_folding.cc
@@ -172,7 +172,14 @@ StatusOr<bool> TransposeFolding::Run(HloModule* module) {
return tensorflow::Status::OK();
};
- for (auto& comp : module->computations()) {
+ std::vector<HloComputation*> computations;
+ for (auto& computation : module->computations()) {
+ if (computation->IsFusionComputation()) {
+ continue;
+ }
+ computations.push_back(computation.get());
+ }
+ for (auto& comp : computations) {
TF_RETURN_IF_ERROR(comp->Accept(visit_fn));
}
diff --git a/tensorflow/compiler/xla/service/tuple_points_to_analysis.cc b/tensorflow/compiler/xla/service/tuple_points_to_analysis.cc
index 182e99cf1c..3c4dc19aef 100644
--- a/tensorflow/compiler/xla/service/tuple_points_to_analysis.cc
+++ b/tensorflow/compiler/xla/service/tuple_points_to_analysis.cc
@@ -135,6 +135,9 @@ TuplePointsToAnalysis::Run(const HloModule* module) {
Status TuplePointsToAnalysis::Analyze() {
points_to_.clear();
for (auto& computation : module_->computations()) {
+ if (computation->IsFusionComputation()) {
+ continue;
+ }
TF_RETURN_IF_ERROR(computation->Accept(this));
TF_RETURN_IF_ERROR(
PopulateDefinedBuffersAndAliases(computation->instructions()));
@@ -451,6 +454,9 @@ string TuplePointsToAnalysis::ToString() const {
string output = tensorflow::strings::Printf(
"TuplePointsToSet for module %s:\n", module_->name().c_str());
for (const auto& computation : module_->computations()) {
+ if (computation->IsFusionComputation()) {
+ continue;
+ }
const char* entry =
computation.get() == module_->entry_computation() ? "entry " : "";
tensorflow::strings::StrAppend(&output, entry, "computation ",