diff options
author | 2017-07-26 08:35:19 -0700 | |
---|---|---|
committer | 2017-07-26 08:39:12 -0700 | |
commit | 78a9b95436f45438abf3e818307f707e9ae92343 (patch) | |
tree | 94dfdfa894f0dec6ba917b905908985f6594b223 /tensorflow/compiler | |
parent | 49495697cddef73a0dd870176dab488bb2a65520 (diff) |
[XLA] Finish normalizing fusion computations into standard computations
PiperOrigin-RevId: 163210327
Diffstat (limited to 'tensorflow/compiler')
28 files changed, 302 insertions, 190 deletions
diff --git a/tensorflow/compiler/xla/service/BUILD b/tensorflow/compiler/xla/service/BUILD index a4612bb6c1..8fb0faf026 100644 --- a/tensorflow/compiler/xla/service/BUILD +++ b/tensorflow/compiler/xla/service/BUILD @@ -1210,6 +1210,7 @@ cc_test( "//tensorflow/compiler/xla/client:computation_builder", "//tensorflow/compiler/xla/client:local_client", "//tensorflow/compiler/xla/client:padding", + "//tensorflow/compiler/xla/tests:hlo_test_base", "//tensorflow/core:lib", "//tensorflow/core:test_main", ], diff --git a/tensorflow/compiler/xla/service/algebraic_simplifier.cc b/tensorflow/compiler/xla/service/algebraic_simplifier.cc index 4837402c15..691f9f2296 100644 --- a/tensorflow/compiler/xla/service/algebraic_simplifier.cc +++ b/tensorflow/compiler/xla/service/algebraic_simplifier.cc @@ -1586,6 +1586,9 @@ StatusOr<bool> AlgebraicSimplifier::Run(HloModule* module) { // module, invalidating iteration. std::vector<HloComputation*> computations; for (auto& comp : module->computations()) { + if (comp->IsFusionComputation()) { + continue; + } computations.push_back(comp.get()); } for (auto& comp : computations) { diff --git a/tensorflow/compiler/xla/service/batchnorm_rewriter.cc b/tensorflow/compiler/xla/service/batchnorm_rewriter.cc index 5d5d3caa2f..ca2d413e11 100644 --- a/tensorflow/compiler/xla/service/batchnorm_rewriter.cc +++ b/tensorflow/compiler/xla/service/batchnorm_rewriter.cc @@ -268,6 +268,9 @@ StatusOr<bool> BatchNormRewriter::Run(HloModule* module) { // module, invalidating iteration. std::vector<HloComputation*> computations; for (auto& comp : module->computations()) { + if (comp->IsFusionComputation()) { + continue; + } computations.push_back(comp.get()); } for (auto& comp : computations) { diff --git a/tensorflow/compiler/xla/service/buffer_assignment.cc b/tensorflow/compiler/xla/service/buffer_assignment.cc index ddc3d11b7c..ae31135a1a 100644 --- a/tensorflow/compiler/xla/service/buffer_assignment.cc +++ b/tensorflow/compiler/xla/service/buffer_assignment.cc @@ -1219,6 +1219,9 @@ void BufferAssigner::BuildColocatedBufferSets( const TuplePointsToAnalysis& points_to_analysis = buffer_liveness.points_to_analysis(); for (const HloComputation* computation : module->MakeComputationPostOrder()) { + if (computation->IsFusionComputation()) { + continue; + } for (const HloInstruction* instruction : computation->MakeInstructionPostOrder()) { const HloOpcode opcode = instruction->opcode(); @@ -1386,6 +1389,9 @@ StatusOr<std::unique_ptr<BufferAssignment>> BufferAssigner::CreateAssignment( // their own BufferAllocation. for (auto* computation : thread_local_computations) { TF_RET_CHECK(computation != module->entry_computation()); + if (computation->IsFusionComputation()) { + continue; + } TF_RETURN_IF_ERROR(AssignBuffersForComputation( computation, module->config().debug_options(), /*is_thread_local=*/true, colocated_buffers, colocated_allocations, diff --git a/tensorflow/compiler/xla/service/buffer_liveness.cc b/tensorflow/compiler/xla/service/buffer_liveness.cc index 6720a90ef8..f085ffa6bc 100644 --- a/tensorflow/compiler/xla/service/buffer_liveness.cc +++ b/tensorflow/compiler/xla/service/buffer_liveness.cc @@ -47,6 +47,9 @@ StatusOr<std::unique_ptr<BufferLiveness>> BufferLiveness::Run( tensorflow::Status BufferLiveness::Analyze() { TF_ASSIGN_OR_RETURN(points_to_analysis_, TuplePointsToAnalysis::Run(module_)); for (auto& computation : module_->computations()) { + if (computation->IsFusionComputation()) { + continue; + } // Gather all instructions whose buffers might alias other instructions into // the set aliased_buffers_. This includes those contained as a tuple // element in other instruction's output. diff --git a/tensorflow/compiler/xla/service/copy_insertion.cc b/tensorflow/compiler/xla/service/copy_insertion.cc index a3803c34ba..c47abe9c62 100644 --- a/tensorflow/compiler/xla/service/copy_insertion.cc +++ b/tensorflow/compiler/xla/service/copy_insertion.cc @@ -551,6 +551,9 @@ StatusOr<bool> CopyInsertion::Run(HloModule* module) { // Add copies of computation root instructions, if needed. FlatMap<const HloComputation*, ShapeTree<bool>> while_body_read_only_indices; for (auto& computation : module->computations()) { + if (computation->IsFusionComputation()) { + continue; + } VLOG(2) << "computation " << computation->name(); InstructionCopier root_copier(computation->root_instruction(), /*copy_users=*/{}); diff --git a/tensorflow/compiler/xla/service/cpu/cpu_compiler.cc b/tensorflow/compiler/xla/service/cpu/cpu_compiler.cc index b86342d0b3..59e8c75b91 100644 --- a/tensorflow/compiler/xla/service/cpu/cpu_compiler.cc +++ b/tensorflow/compiler/xla/service/cpu/cpu_compiler.cc @@ -519,6 +519,9 @@ StatusOr<std::unique_ptr<Executable>> CpuCompiler::Compile( new std::map<HloInstruction*, string>()); for (auto embedded_computation : computation->MakeEmbeddedComputationsList()) { + if (embedded_computation->IsFusionComputation()) { + continue; + } auto parallel_computation_iter = parallel_computations.find(embedded_computation); // All parallel computations are considered to be an entry computation for @@ -591,6 +594,9 @@ StatusOr<std::unique_ptr<Executable>> CpuCompiler::Compile( for (auto embedded_computation : computation->MakeEmbeddedComputationsList()) { + if (embedded_computation->IsFusionComputation()) { + continue; + } TF_RETURN_IF_ERROR( ir_emitter .EmitComputation(embedded_computation, @@ -755,6 +761,9 @@ CpuCompiler::CompileAheadOfTime(std::vector<std::unique_ptr<HloModule>> modules, HloComputation* computation = module->entry_computation(); for (auto embedded_computation : computation->MakeEmbeddedComputationsList()) { + if (embedded_computation->IsFusionComputation()) { + continue; + } TF_RETURN_IF_ERROR( ir_emitter .EmitComputation(embedded_computation, diff --git a/tensorflow/compiler/xla/service/cpu/cpu_parallelization_preparation.cc b/tensorflow/compiler/xla/service/cpu/cpu_parallelization_preparation.cc index af931f7b01..4d0e0f744a 100644 --- a/tensorflow/compiler/xla/service/cpu/cpu_parallelization_preparation.cc +++ b/tensorflow/compiler/xla/service/cpu/cpu_parallelization_preparation.cc @@ -125,6 +125,9 @@ StatusOr<bool> ParallelizationPreparation::Run(HloModule* module) { TF_ASSIGN_OR_RETURN(auto points_to_analysis, TuplePointsToAnalysis::Run(module)); for (auto& computation : module->computations()) { + if (computation->IsFusionComputation()) { + continue; + } HloInstruction* root = computation->root_instruction(); // Copy root instruction if it does not define its own top-level buffer. // TODO(b/32885001) Remove these copies (at least for the unambiguous case). diff --git a/tensorflow/compiler/xla/service/gpu/fusion_merger.cc b/tensorflow/compiler/xla/service/gpu/fusion_merger.cc index e698646d18..a9ef204b46 100644 --- a/tensorflow/compiler/xla/service/gpu/fusion_merger.cc +++ b/tensorflow/compiler/xla/service/gpu/fusion_merger.cc @@ -293,12 +293,19 @@ Status FusionInstructionMerger::HandleFusion(HloInstruction* fusion) { StatusOr<bool> FusionMerger::Run(HloModule* module) { bool changed = false; VLOG(2) << "FusionMerger for module: " << module->name(); + std::vector<HloComputation*> computations; for (auto& computation : module->computations()) { + if (computation->IsFusionComputation()) { + continue; + } + computations.push_back(computation.get()); + } + for (auto& computation : computations) { VLOG(1) << "Before running FusionInstructionMerger for computation: " << computation->name(); XLA_VLOG_LINES(3, computation->ToString()); - FusionInstructionMerger fusion_merger(computation.get()); + FusionInstructionMerger fusion_merger(computation); TF_RETURN_IF_ERROR(fusion_merger.Run()); changed |= fusion_merger.changed(); diff --git a/tensorflow/compiler/xla/service/gpu/hlo_schedule.cc b/tensorflow/compiler/xla/service/gpu/hlo_schedule.cc index c61e47a93c..81e905a066 100644 --- a/tensorflow/compiler/xla/service/gpu/hlo_schedule.cc +++ b/tensorflow/compiler/xla/service/gpu/hlo_schedule.cc @@ -120,7 +120,8 @@ GpuHloOrdering::GpuHloOrdering( // do that yet since it's hard to ensure that the order here is the order used // by IrEmitterNested. And mismatched ordering bugs would be hard to find. for (auto& computation : module->computations()) { - if (computation.get() != module->entry_computation()) { + if (computation.get() != module->entry_computation() && + !computation->IsFusionComputation()) { predecessors_.emplace(computation.get(), computation->ComputeReachability()); } diff --git a/tensorflow/compiler/xla/service/hlo_constant_folding.cc b/tensorflow/compiler/xla/service/hlo_constant_folding.cc index 804efdd906..1a2eed5f60 100644 --- a/tensorflow/compiler/xla/service/hlo_constant_folding.cc +++ b/tensorflow/compiler/xla/service/hlo_constant_folding.cc @@ -42,6 +42,9 @@ StatusOr<bool> HloConstantFolding::Run(HloModule* module) { bool changed = false; for (auto& computation : module->computations()) { + if (computation->IsFusionComputation()) { + continue; + } for (auto instruction : computation->MakeInstructionPostOrder()) { // Skip dead code. if (instruction->user_count() == 0 && diff --git a/tensorflow/compiler/xla/service/hlo_cost_analysis_test.cc b/tensorflow/compiler/xla/service/hlo_cost_analysis_test.cc index f745683165..0a288a77ad 100644 --- a/tensorflow/compiler/xla/service/hlo_cost_analysis_test.cc +++ b/tensorflow/compiler/xla/service/hlo_cost_analysis_test.cc @@ -31,6 +31,7 @@ limitations under the License. #include "tensorflow/compiler/xla/service/user_computation.h" #include "tensorflow/compiler/xla/service/versioned_computation_handle.h" #include "tensorflow/compiler/xla/shape_util.h" +#include "tensorflow/compiler/xla/tests/hlo_test_base.h" #include "tensorflow/core/platform/logging.h" #include "tensorflow/compiler/xla/statusor.h" @@ -329,7 +330,7 @@ TEST_F(HloCostAnalysisTest, MatmulAndConvolutionCanBeTheSameComputation) { EXPECT_EQ(conv_analysis.flop_count(), matmul_analysis.flop_count()); } -using FusionCostAnalysis = ::testing::Test; +using FusionCostAnalysis = HloTestBase; TEST_F(FusionCostAnalysis, LoopFusion) { // Do this 4 times with different per-second rates to test the computation of @@ -345,32 +346,32 @@ TEST_F(FusionCostAnalysis, LoopFusion) { // mul = Mul(exp, C3) // sub = Sub(mul, clamp) // tuple = Tuple({sub, sub, mul, C1}) - auto c1 = HloInstruction::CreateConstant(Literal::CreateR2F32Linspace( - /*from=*/0.0f, /*to=*/1.0f, /*rows=*/2, /*cols=*/2)); - auto c2 = HloInstruction::CreateConstant(Literal::CreateR2F32Linspace( - /*from=*/1.0f, /*to=*/2.0f, /*rows=*/2, /*cols=*/2)); - auto c3 = HloInstruction::CreateConstant(Literal::CreateR2F32Linspace( - /*from=*/2.0f, /*to=*/3.0f, /*rows=*/2, /*cols=*/2)); - - auto add = HloInstruction::CreateBinary(r2f32, HloOpcode::kAdd, c1.get(), - c2.get()); - auto clamp = HloInstruction::CreateTernary(r2f32, HloOpcode::kClamp, - c2.get(), add.get(), add.get()); - auto exp = HloInstruction::CreateUnary(r2f32, HloOpcode::kExp, add.get()); - auto mul = HloInstruction::CreateBinary(r2f32, HloOpcode::kMultiply, - exp.get(), c3.get()); - auto sub = HloInstruction::CreateBinary(r2f32, HloOpcode::kSubtract, - mul.get(), clamp.get()); - auto tuple = HloInstruction::CreateTuple( - {sub.get(), sub.get(), mul.get(), c1.get()}); - - auto fusion = HloInstruction::CreateFusion( - r2f32, HloInstruction::FusionKind::kLoop, tuple.get()); - fusion->FuseInstruction(sub.get()); - fusion->FuseInstruction(mul.get()); - fusion->FuseInstruction(exp.get()); - fusion->FuseInstruction(clamp.get()); - fusion->FuseInstruction(add.get()); + HloComputation::Builder builder(TestName()); + auto c1 = builder.AddInstruction( + HloInstruction::CreateConstant(Literal::CreateR2F32Linspace( + /*from=*/0.0f, /*to=*/1.0f, /*rows=*/2, /*cols=*/2))); + auto c2 = builder.AddInstruction( + HloInstruction::CreateConstant(Literal::CreateR2F32Linspace( + /*from=*/1.0f, /*to=*/2.0f, /*rows=*/2, /*cols=*/2))); + auto c3 = builder.AddInstruction( + HloInstruction::CreateConstant(Literal::CreateR2F32Linspace( + /*from=*/2.0f, /*to=*/3.0f, /*rows=*/2, /*cols=*/2))); + auto add = builder.AddInstruction( + HloInstruction::CreateBinary(r2f32, HloOpcode::kAdd, c1, c2)); + auto clamp = builder.AddInstruction( + HloInstruction::CreateTernary(r2f32, HloOpcode::kClamp, c2, add, add)); + auto exp = builder.AddInstruction( + HloInstruction::CreateUnary(r2f32, HloOpcode::kExp, add)); + auto mul = builder.AddInstruction( + HloInstruction::CreateBinary(r2f32, HloOpcode::kMultiply, exp, c3)); + auto sub = builder.AddInstruction( + HloInstruction::CreateBinary(r2f32, HloOpcode::kSubtract, mul, clamp)); + auto tuple = HloInstruction::CreateTuple({sub, sub, mul, c1}); + + HloModule module(TestName()); + auto* computation = module.AddEntryComputation(builder.Build()); + auto* fusion = computation->CreateFusionInstruction( + {sub, mul, exp, clamp, add}, HloInstruction::FusionKind::kLoop); // The time given these rates at i == 0 is exactly even among the properties // at 1.0 seconds. For other values, one of the rates is slower so that it @@ -398,18 +399,21 @@ TEST_F(FusionCostAnalysis, NoLayout) { Shape shape_without_layout = shape_with_layout; shape_without_layout.clear_layout(); - auto c1 = HloInstruction::CreateConstant( - Literal::CreateR4FromArray4D(Array4D<float>(2, 3, 4, 5))); - auto c2 = HloInstruction::CreateConstant(Literal::CreateR1<float>({1, 2, 3})); - - auto broadcast = - HloInstruction::CreateBroadcast(shape_without_layout, c2.get(), {1}); - auto add = HloInstruction::CreateBinary(shape_with_layout, HloOpcode::kAdd, - c1.get(), broadcast.get()); - - auto fusion = HloInstruction::CreateFusion( - shape_with_layout, HloInstruction::FusionKind::kLoop, add.get()); - fusion->FuseInstruction(broadcast.get()); + HloComputation::Builder builder(TestName()); + auto c1 = builder.AddInstruction(HloInstruction::CreateConstant( + Literal::CreateR4FromArray4D(Array4D<float>(2, 3, 4, 5)))); + auto c2 = builder.AddInstruction( + HloInstruction::CreateConstant(Literal::CreateR1<float>({1, 2, 3}))); + + auto broadcast = builder.AddInstruction( + HloInstruction::CreateBroadcast(shape_without_layout, c2, {1})); + auto add = builder.AddInstruction(HloInstruction::CreateBinary( + shape_with_layout, HloOpcode::kAdd, c1, broadcast)); + + HloModule module(TestName()); + auto* computation = module.AddEntryComputation(builder.Build()); + auto* fusion = computation->CreateFusionInstruction( + {add, broadcast}, HloInstruction::FusionKind::kLoop); HloCostAnalysis fusion_analysis(ShapeSize); ASSERT_IS_OK(fusion->Accept(&fusion_analysis)); diff --git a/tensorflow/compiler/xla/service/hlo_cse.cc b/tensorflow/compiler/xla/service/hlo_cse.cc index 0fef89a06d..690c084efb 100644 --- a/tensorflow/compiler/xla/service/hlo_cse.cc +++ b/tensorflow/compiler/xla/service/hlo_cse.cc @@ -92,6 +92,9 @@ bool CombineConstants(HloComputation* computation, bool is_layout_sensitive) { StatusOr<bool> HloCSE::Run(HloModule* module) { bool changed = false; for (auto& computation : module->computations()) { + if (computation->IsFusionComputation()) { + continue; + } changed |= CombineConstants(computation.get(), is_layout_sensitive_); std::list<HloInstruction*> post_order = diff --git a/tensorflow/compiler/xla/service/hlo_dce.cc b/tensorflow/compiler/xla/service/hlo_dce.cc index 3755b9e4c0..5b2c57da4f 100644 --- a/tensorflow/compiler/xla/service/hlo_dce.cc +++ b/tensorflow/compiler/xla/service/hlo_dce.cc @@ -38,6 +38,9 @@ StatusOr<bool> HloDCE::Run(HloModule* module) { bool changed = false; for (auto& computation : module->computations()) { + if (computation->IsFusionComputation()) { + continue; + } std::unordered_set<HloInstruction*> live_instructions; TF_RETURN_IF_ERROR(computation->root_instruction()->Accept( [&live_instructions](HloInstruction* instruction) { diff --git a/tensorflow/compiler/xla/service/hlo_instruction.cc b/tensorflow/compiler/xla/service/hlo_instruction.cc index f52882cca5..ed8a942d03 100644 --- a/tensorflow/compiler/xla/service/hlo_instruction.cc +++ b/tensorflow/compiler/xla/service/hlo_instruction.cc @@ -560,19 +560,20 @@ HloInstruction* HloInstruction::CloneAndFuseInternal( HloInstruction* instruction_to_fuse) { CHECK_EQ(opcode_, HloOpcode::kFusion); CHECK(instruction_to_fuse->IsFusable()); - + if (GetModule()) { + XLA_VLOG_LINES(1, GetModule()->ToString()); + } HloInstruction* clone = nullptr; - if (fused_instructions_computation_ == nullptr) { + if (called_computations_.empty()) { // New fusion instruction. auto builder = HloComputation::Builder("fused_computation", true); builder.AddInstruction(instruction_to_fuse->Clone(/*suffix=*/"")); - fused_instructions_computation_ = builder.Build(); + called_computations_.push_back( + CHECK_NOTNULL(GetModule())->AddEmbeddedComputation(builder.Build())); clone = fused_expression_root(); clone->parent_fusion_instruction_ = this; } else { - CHECK(fused_instructions_computation_ != nullptr && - fused_instructions_computation_->IsFusionComputation()); - clone = fused_instructions_computation_->AddInstruction( + clone = fused_instructions_computation()->AddInstruction( instruction_to_fuse->Clone(/*suffix=*/"")); clone->parent_fusion_instruction_ = this; // instruction_to_fuse is necessarily an operand of the fusion instruction. @@ -583,7 +584,7 @@ HloInstruction* HloInstruction::CloneAndFuseInternal( CHECK(std::find(operands_.begin(), operands_.end(), instruction_to_fuse) != operands_.end()); const std::vector<HloInstruction*>& fused_parameters_ = - fused_instructions_computation_->parameter_instructions(); + fused_instructions_computation()->parameter_instructions(); for (int64 operand_num = 0; operand_num < operand_count(); ++operand_num) { if (instruction_to_fuse == operands_[operand_num]) { // replace the fused parameter instruction's uses with the clone. @@ -593,7 +594,7 @@ HloInstruction* HloInstruction::CloneAndFuseInternal( // Remove the corresponding fused parameter and operand from their // respective vectors. TF_CHECK_OK( - fused_instructions_computation_->RemoveParameter(operand_num)); + fused_instructions_computation()->RemoveParameter(operand_num)); operands_.erase(operands_.begin() + operand_num); break; } @@ -605,7 +606,7 @@ HloInstruction* HloInstruction::CloneAndFuseInternal( // Reread the parameters in the computation. const std::vector<HloInstruction*>& fused_parameters_ = - fused_instructions_computation_->parameter_instructions(); + fused_instructions_computation()->parameter_instructions(); // Add each operand of the clone as an operand of the fusion instruction. A // complication is that some clone operands may already be operands of the @@ -638,7 +639,7 @@ HloInstruction* HloInstruction::CloneAndFuseInternal( CreateParameter(param_no, operand->shape(), param_name); param_instruction->parent_fusion_instruction_ = this; - fused_param = fused_instructions_computation_->AddParameter( + fused_param = fused_instructions_computation()->AddParameter( std::move(param_instruction)); AppendOperand(operand); } @@ -652,7 +653,6 @@ HloInstruction* HloInstruction::CloneAndFuseInternal( called_computations_.push_back(computation); } } - return clone; } @@ -663,17 +663,15 @@ RandomDistribution HloInstruction::random_distribution() const { void HloInstruction::CheckFusionInstruction() const { CHECK_EQ(opcode_, HloOpcode::kFusion); - CHECK(fused_instructions_computation_ != nullptr && - fused_instructions_computation_->IsFusionComputation()); const std::list<std::unique_ptr<HloInstruction>>& fused_instructions_ = - fused_instructions_computation_->instructions(); + fused_instructions_computation()->instructions(); // All instructions owned by this fusion instruction must be fused, and the // parent fusion instruction of the fused instructions must be 'this'. for (auto& instruction : fused_instructions_) { CHECK(instruction->IsFused()); CHECK_EQ(this, instruction->fusion_instruction()); - CHECK_EQ(fused_instructions_computation_.get(), instruction->parent()) + CHECK_EQ(fused_instructions_computation(), instruction->parent()) << instruction->ToString(); } @@ -976,8 +974,6 @@ std::unique_ptr<HloInstruction> HloInstruction::CloneFusionWithNewOperands( const Shape& shape, tensorflow::gtl::ArraySlice<HloInstruction*> operands) { CHECK_EQ(opcode_, HloOpcode::kFusion); CHECK(parent() != nullptr); - CHECK(fused_instructions_computation_ != nullptr && - fused_instructions_computation_->IsFusionComputation()); auto new_instruction = WrapUnique(new HloInstruction(HloOpcode::kFusion, shape)); @@ -992,9 +988,9 @@ std::unique_ptr<HloInstruction> HloInstruction::CloneFusionWithNewOperands( // fused instructions. std::vector<HloInstruction*> new_fused_parameters; const std::vector<HloInstruction*>& fused_parameters_ = - fused_instructions_computation_->parameter_instructions(); + fused_instructions_computation()->parameter_instructions(); const std::list<std::unique_ptr<HloInstruction>>& fused_instructions_ = - fused_instructions_computation_->instructions(); + fused_instructions_computation()->instructions(); for (HloInstruction* old_fused_parameter : fused_parameters_) { new_fused_instructions.push_back(old_fused_parameter->Clone()); @@ -1028,7 +1024,7 @@ std::unique_ptr<HloInstruction> HloInstruction::CloneFusionWithNewOperands( } new_instruction->fusion_kind_ = fusion_kind_; auto computation_builder = HloComputation::Builder( - fused_instructions_computation_->name() + ".clone", true); + fused_instructions_computation()->name() + ".clone", true); // We iterated the fusion instructions in reverse post order which means // that we must reverse our new list of fusion instructions. for (auto new_fused_instruction_iter = new_fused_instructions.rbegin(); @@ -1037,8 +1033,10 @@ std::unique_ptr<HloInstruction> HloInstruction::CloneFusionWithNewOperands( computation_builder.AddInstruction(std::move(*new_fused_instruction_iter)); } auto fused_root_ = fused_expression_root(); - new_instruction->fused_instructions_computation_ = - computation_builder.Build(FindOrDie(old_to_new, fused_root_)); + new_instruction->called_computations_.push_back( + CHECK_NOTNULL(GetModule()) + ->AddEmbeddedComputation( + computation_builder.Build(FindOrDie(old_to_new, fused_root_)))); new_instruction->set_parent(parent()); new_instruction->CheckFusionInstruction(); return new_instruction; @@ -1769,7 +1767,10 @@ bool HloInstruction::IsFusable() const { HloComputation* HloInstruction::fused_instructions_computation() const { CHECK_EQ(opcode_, HloOpcode::kFusion); - return fused_instructions_computation_.get(); + CHECK(!called_computations_.empty()); + auto* fused_instructions_computation = called_computations_.front(); + CHECK(fused_instructions_computation->IsFusionComputation()); + return fused_instructions_computation; } HloInstruction* HloInstruction::fusion_instruction() const { @@ -1779,32 +1780,24 @@ HloInstruction* HloInstruction::fusion_instruction() const { HloInstruction* HloInstruction::fused_expression_root() const { CHECK_EQ(opcode_, HloOpcode::kFusion); - CHECK(fused_instructions_computation_ != nullptr && - fused_instructions_computation_->IsFusionComputation()); - return fused_instructions_computation_->root_instruction(); + return fused_instructions_computation()->root_instruction(); } HloInstruction* HloInstruction::fused_parameter(int64 parameter_number) const { CHECK_EQ(opcode_, HloOpcode::kFusion); - CHECK(fused_instructions_computation_ != nullptr && - fused_instructions_computation_->IsFusionComputation()); - return fused_instructions_computation_->parameter_instruction( + return fused_instructions_computation()->parameter_instruction( parameter_number); } const std::vector<HloInstruction*>& HloInstruction::fused_parameters() const { CHECK_EQ(opcode_, HloOpcode::kFusion); - CHECK(fused_instructions_computation_ != nullptr && - fused_instructions_computation_->IsFusionComputation()); - return fused_instructions_computation_->parameter_instructions(); + return fused_instructions_computation()->parameter_instructions(); } const std::list<std::unique_ptr<HloInstruction>>& HloInstruction::fused_instructions() const { CHECK_EQ(opcode_, HloOpcode::kFusion); - CHECK(fused_instructions_computation_ != nullptr && - fused_instructions_computation_->IsFusionComputation()); - return fused_instructions_computation_->instructions(); + return fused_instructions_computation()->instructions(); } HloInstruction::HloInstruction(HloOpcode opcode, const Shape& shape) @@ -2039,7 +2032,7 @@ static Status PostOrderDFS(HloInstruction* root, DfsHloVisitor* visitor, Status HloInstruction::Accept(DfsHloVisitor* visitor, bool call_finish_visit, bool ignore_control_predecessors) { - VLOG(2) << "HloInstruction::Accept(" << name() << ")"; + VLOG(3) << "HloInstruction::Accept(" << name() << ")"; TF_RETURN_IF_ERROR( PostOrderDFS(this, visitor, nullptr, ignore_control_predecessors)); if (call_finish_visit) { @@ -2055,8 +2048,11 @@ Status HloInstruction::AcceptWithOperandOrder( TF_RETURN_IF_ERROR(PostOrderDFS(this, visitor, &operand_order, /*ignore_control_predecessors=*/false)); if (call_finish_visit) { + VLOG(3) << "HloInstruction::AcceptWithOperandOrder BEFORE FINISH VISIT"; TF_RETURN_IF_ERROR(visitor->FinishVisit(this)); + VLOG(3) << "HloInstruction::AcceptWithOperandOrder AFTER FINISH VISIT"; } + VLOG(2) << "HloInstruction::AcceptWithOperandOrder EXIT"; return Status::OK(); } @@ -2458,6 +2454,7 @@ HloModule* HloInstruction::GetModule() const { } void HloInstruction::UniquifyName(NameUniquer* name_uniquer) { + string parent_str = parent() == nullptr ? "noparent" : parent()->name(); name_ = name_uniquer->GetUniqueName(name_); } diff --git a/tensorflow/compiler/xla/service/hlo_instruction.h b/tensorflow/compiler/xla/service/hlo_instruction.h index e2e77e5219..3c188ec83f 100644 --- a/tensorflow/compiler/xla/service/hlo_instruction.h +++ b/tensorflow/compiler/xla/service/hlo_instruction.h @@ -935,10 +935,6 @@ class HloInstruction { // padding of this pad instruction. Only set for pad instructions. std::unique_ptr<PaddingConfig> padding_config_; - // The computation that stores of instructions fused into this fusion - // instruction. Only set for fusion instructions. - std::unique_ptr<HloComputation> fused_instructions_computation_; - // If this instruction is fused into a fusion instruction, this field points // to the fusion instruction. HloInstruction* parent_fusion_instruction_ = nullptr; diff --git a/tensorflow/compiler/xla/service/hlo_instruction_test.cc b/tensorflow/compiler/xla/service/hlo_instruction_test.cc index bb1b477e13..5951c833db 100644 --- a/tensorflow/compiler/xla/service/hlo_instruction_test.cc +++ b/tensorflow/compiler/xla/service/hlo_instruction_test.cc @@ -557,78 +557,89 @@ TEST_F(HloInstructionTest, PostProcessAllVisitedNodes) { } TEST_F(HloInstructionTest, SingletonFusionOp) { + HloComputation::Builder builder(TestName()); // Create a fusion instruction containing a single unary operation. - auto constant = - HloInstruction::CreateConstant(Literal::CreateR0<float>(1.1f)); - auto exp = - HloInstruction::CreateUnary(r0f32_, HloOpcode::kExp, constant.get()); - - auto fusion = HloInstruction::CreateFusion( - r0f32_, HloInstruction::FusionKind::kLoop, exp.get()); - - EXPECT_THAT(fusion->operands(), ElementsAre(constant.get())); - EXPECT_THAT(constant->users(), UnorderedElementsAre(fusion.get(), exp.get())); + auto constant = builder.AddInstruction( + HloInstruction::CreateConstant(Literal::CreateR0<float>(1.1f))); + auto exp = builder.AddInstruction( + HloInstruction::CreateUnary(r0f32_, HloOpcode::kExp, constant)); + HloModule module(TestName()); + auto* computation = module.AddEntryComputation(builder.Build()); + auto* fusion = computation->CreateFusionInstruction( + {exp}, HloInstruction::FusionKind::kLoop); + + EXPECT_THAT(fusion->operands(), ElementsAre(constant)); + EXPECT_THAT(constant->users(), ElementsAre(fusion)); } TEST_F(HloInstructionTest, BinaryFusionOp) { + HloComputation::Builder builder(TestName()); // Create a fusion instruction containing a single binary operation. - auto constant1 = - HloInstruction::CreateConstant(Literal::CreateR0<float>(1.1f)); - auto constant2 = - HloInstruction::CreateConstant(Literal::CreateR0<float>(42.1f)); - auto add = HloInstruction::CreateBinary(r0f32_, HloOpcode::kAdd, - constant1.get(), constant2.get()); - - auto fusion = HloInstruction::CreateFusion( - r0f32_, HloInstruction::FusionKind::kLoop, add.get()); - - EXPECT_THAT(fusion->operands(), - ElementsAre(constant1.get(), constant2.get())); - EXPECT_THAT(constant1->users(), - UnorderedElementsAre(fusion.get(), add.get())); - EXPECT_THAT(constant2->users(), - UnorderedElementsAre(fusion.get(), add.get())); + auto constant1 = builder.AddInstruction( + HloInstruction::CreateConstant(Literal::CreateR0<float>(1.1f))); + auto constant2 = builder.AddInstruction( + HloInstruction::CreateConstant(Literal::CreateR0<float>(42.1f))); + auto add = builder.AddInstruction(HloInstruction::CreateBinary( + r0f32_, HloOpcode::kAdd, constant1, constant2)); + HloModule module(TestName()); + auto* computation = module.AddEntryComputation(builder.Build()); + auto* fusion = computation->CreateFusionInstruction( + {add}, HloInstruction::FusionKind::kLoop); + + EXPECT_THAT(fusion->operands(), ElementsAre(constant1, constant2)); + EXPECT_THAT(constant1->users(), ElementsAre(fusion)); + EXPECT_THAT(constant2->users(), ElementsAre(fusion)); } TEST_F(HloInstructionTest, ChainFusionOp) { + HloComputation::Builder builder(TestName()); // Create a chain of fused unary ops. - auto constant = - HloInstruction::CreateConstant(Literal::CreateR0<float>(1.1f)); - auto exp1 = - HloInstruction::CreateUnary(r0f32_, HloOpcode::kExp, constant.get()); - auto exp2 = HloInstruction::CreateUnary(r0f32_, HloOpcode::kExp, exp1.get()); - auto exp3 = HloInstruction::CreateUnary(r0f32_, HloOpcode::kExp, exp2.get()); - - auto fusion = HloInstruction::CreateFusion( - r0f32_, HloInstruction::FusionKind::kLoop, exp3.get()); - fusion->FuseInstruction(exp2.get()); - fusion->FuseInstruction(exp1.get()); - - EXPECT_THAT(fusion->operands(), ElementsAre(constant.get())); - EXPECT_THAT(constant->users(), - UnorderedElementsAre(fusion.get(), exp1.get())); + auto constant = builder.AddInstruction( + HloInstruction::CreateConstant(Literal::CreateR0<float>(1.1f))); + auto exp1 = builder.AddInstruction( + HloInstruction::CreateUnary(r0f32_, HloOpcode::kExp, constant)); + auto exp2 = builder.AddInstruction( + HloInstruction::CreateUnary(r0f32_, HloOpcode::kExp, exp1)); + auto exp3 = builder.AddInstruction( + HloInstruction::CreateUnary(r0f32_, HloOpcode::kExp, exp2)); + + HloModule module(TestName()); + auto* computation = module.AddEntryComputation(builder.Build()); + auto* fusion = computation->CreateFusionInstruction( + {exp3, exp2, exp1}, HloInstruction::FusionKind::kLoop); + + EXPECT_THAT(fusion->operands(), ElementsAre(constant)); + EXPECT_THAT(constant->users(), ElementsAre(fusion)); } TEST_F(HloInstructionTest, PreserveMetadataInFusionAndClone) { + HloComputation::Builder builder(TestName()); // Create a chain of fused unary ops. - auto constant = - HloInstruction::CreateConstant(Literal::CreateR0<float>(1.1f)); - auto exp1 = - HloInstruction::CreateUnary(r0f32_, HloOpcode::kExp, constant.get()); - auto exp2 = HloInstruction::CreateUnary(r0f32_, HloOpcode::kExp, exp1.get()); + auto constant = builder.AddInstruction( + HloInstruction::CreateConstant(Literal::CreateR0<float>(1.1f))); + auto exp1 = builder.AddInstruction( + HloInstruction::CreateUnary(r0f32_, HloOpcode::kExp, constant)); + auto exp2 = builder.AddInstruction( + HloInstruction::CreateUnary(r0f32_, HloOpcode::kExp, exp1)); OpMetadata metadata; metadata.set_op_name("tf_op"); exp1->set_metadata(metadata); exp2->set_metadata(metadata); - auto fusion = HloInstruction::CreateFusion( - r0f32_, HloInstruction::FusionKind::kLoop, exp2.get()); - auto* fused = fusion->FuseInstruction(exp1.get()); + HloModule module(TestName()); + auto* computation = module.AddEntryComputation(builder.Build()); + auto* fusion = computation->CreateFusionInstruction( + {exp2, exp1}, HloInstruction::FusionKind::kLoop); + EXPECT_TRUE(protobuf_util::ProtobufEquals(metadata, fusion->metadata())); - EXPECT_TRUE(protobuf_util::ProtobufEquals(metadata, fused->metadata())); + EXPECT_TRUE(protobuf_util::ProtobufEquals( + metadata, fusion->fused_expression_root()->metadata())); + EXPECT_TRUE(protobuf_util::ProtobufEquals( + metadata, fusion->fused_expression_root()->operand(0)->metadata())); } TEST_F(HloInstructionTest, FusionOpWithCalledComputations) { + HloComputation::Builder builder(TestName()); // Create a fusion instruction containing a single unary operation. const Shape scalar_shape = ShapeUtil::MakeShape(F32, {}); @@ -642,33 +653,36 @@ TEST_F(HloInstructionTest, FusionOpWithCalledComputations) { std::unique_ptr<HloComputation> computation_x = make_map_computation(); std::unique_ptr<HloComputation> computation_y = make_map_computation(); - auto constant = - HloInstruction::CreateConstant(Literal::CreateR0<float>(1.1f)); - auto map_1_x = - HloInstruction::CreateMap(scalar_shape, {constant.get()}, - computation_x.get(), /*static_operands=*/{}); - auto map_2_x = - HloInstruction::CreateMap(scalar_shape, {map_1_x.get()}, - computation_x.get(), /*static_operands=*/{}); - auto map_3_y = - HloInstruction::CreateMap(scalar_shape, {map_2_x.get()}, - computation_y.get(), /*static_operands=*/{}); - - auto fusion = HloInstruction::CreateFusion( - scalar_shape, HloInstruction::FusionKind::kLoop, map_3_y.get()); - - EXPECT_THAT(fusion->called_computations(), ElementsAre(computation_y.get())); - - fusion->FuseInstruction(map_2_x.get()); + auto constant = builder.AddInstruction( + HloInstruction::CreateConstant(Literal::CreateR0<float>(1.1f))); + auto map_1_x = builder.AddInstruction(HloInstruction::CreateMap( + scalar_shape, {constant}, computation_x.get(), /*static_operands=*/{})); + auto map_2_x = builder.AddInstruction(HloInstruction::CreateMap( + scalar_shape, {map_1_x}, computation_x.get(), /*static_operands=*/{})); + auto map_3_y = builder.AddInstruction(HloInstruction::CreateMap( + scalar_shape, {map_2_x}, computation_y.get(), /*static_operands=*/{})); + + HloModule module(TestName()); + auto* computation = module.AddEntryComputation(builder.Build()); + auto* fusion = computation->CreateFusionInstruction( + {map_3_y}, HloInstruction::FusionKind::kLoop); + auto* fused_computation = fusion->fused_instructions_computation(); EXPECT_THAT(fusion->called_computations(), - ElementsAre(computation_y.get(), computation_x.get())); + ElementsAre(fused_computation, computation_y.get())); - fusion->FuseInstruction(map_1_x.get()); - EXPECT_THAT(fusion->called_computations(), - ElementsAre(computation_y.get(), computation_x.get())); + fusion->FuseInstruction(map_2_x); + EXPECT_THAT( + fusion->called_computations(), + ElementsAre(fused_computation, computation_y.get(), computation_x.get())); + + fusion->FuseInstruction(map_1_x); + EXPECT_THAT( + fusion->called_computations(), + ElementsAre(fused_computation, computation_y.get(), computation_x.get())); } TEST_F(HloInstructionTest, ComplexFusionOp) { + HloComputation::Builder builder(TestName()); // Fuse all instructions in complicated expression: // // add = Add(C1, C2) @@ -680,35 +694,35 @@ TEST_F(HloInstructionTest, ComplexFusionOp) { // // Notable complexities are repeated operands in a same instruction, different // shapes, use of value in different expressions. - auto c1 = HloInstruction::CreateConstant(Literal::CreateR0<float>(1.1f)); - auto c2 = HloInstruction::CreateConstant(Literal::CreateR0<float>(2.1f)); - auto c3 = HloInstruction::CreateConstant(Literal::CreateR0<float>(9.0f)); - - auto add = - HloInstruction::CreateBinary(r0f32_, HloOpcode::kAdd, c1.get(), c2.get()); - auto clamp = HloInstruction::CreateTernary(r0f32_, HloOpcode::kClamp, - c2.get(), add.get(), add.get()); - auto exp = HloInstruction::CreateUnary(r0f32_, HloOpcode::kExp, add.get()); - auto mul = HloInstruction::CreateBinary(r0f32_, HloOpcode::kMultiply, - exp.get(), c3.get()); - auto sub = HloInstruction::CreateBinary(r0f32_, HloOpcode::kSubtract, - mul.get(), clamp.get()); + auto c1 = builder.AddInstruction( + HloInstruction::CreateConstant(Literal::CreateR0<float>(1.1f))); + auto c2 = builder.AddInstruction( + HloInstruction::CreateConstant(Literal::CreateR0<float>(2.1f))); + auto c3 = builder.AddInstruction( + HloInstruction::CreateConstant(Literal::CreateR0<float>(9.0f))); + + auto add = builder.AddInstruction( + HloInstruction::CreateBinary(r0f32_, HloOpcode::kAdd, c1, c2)); + auto clamp = builder.AddInstruction( + HloInstruction::CreateTernary(r0f32_, HloOpcode::kClamp, c2, add, add)); + auto exp = builder.AddInstruction( + HloInstruction::CreateUnary(r0f32_, HloOpcode::kExp, add)); + auto mul = builder.AddInstruction( + HloInstruction::CreateBinary(r0f32_, HloOpcode::kMultiply, exp, c3)); + auto sub = builder.AddInstruction( + HloInstruction::CreateBinary(r0f32_, HloOpcode::kSubtract, mul, clamp)); auto tuple = - HloInstruction::CreateTuple({sub.get(), sub.get(), mul.get(), c1.get()}); + builder.AddInstruction(HloInstruction::CreateTuple({sub, sub, mul, c1})); - auto fusion = HloInstruction::CreateFusion( - r0f32_, HloInstruction::FusionKind::kLoop, tuple.get()); - fusion->FuseInstruction(sub.get()); - fusion->FuseInstruction(mul.get()); - fusion->FuseInstruction(exp.get()); - fusion->FuseInstruction(clamp.get()); - fusion->FuseInstruction(add.get()); + HloModule module(TestName()); + auto* computation = module.AddEntryComputation(builder.Build()); + auto* fusion = computation->CreateFusionInstruction( + {tuple, sub, mul, exp, clamp, add}, HloInstruction::FusionKind::kLoop); // Operands in the fusion instruction's operands() vector should be in the // order in which their users were added fused. - EXPECT_THAT(fusion->operands(), ElementsAre(c1.get(), c3.get(), c2.get())); - EXPECT_THAT(c1->users(), - UnorderedElementsAre(add.get(), tuple.get(), fusion.get())); + EXPECT_THAT(fusion->operands(), ElementsAre(c1, c3, c2)); + EXPECT_THAT(c1->users(), ElementsAre(fusion)); } // Convenience function for comparing two HloInstructions inside of @@ -864,7 +878,8 @@ TEST_F(HloInstructionTest, PartiallyElementwise) { HloInstruction* max = builder.AddInstruction( HloInstruction::CreateBinary(r2f32, HloOpcode::kMaximum, div, broadcast)); - auto computation = builder.Build(); + HloModule module(TestName()); + auto* computation = module.AddEntryComputation(builder.Build()); HloInstruction* fusion = computation->CreateFusionInstruction( {max, broadcast, div, mul}, HloInstruction::FusionKind::kLoop); EXPECT_FALSE(fusion->IsElementwise()); @@ -906,7 +921,8 @@ TEST_F(HloInstructionTest, PartiallyElementwiseWithReuse) { HloInstruction* sub = builder.AddInstruction(HloInstruction::CreateBinary( r1f32, HloOpcode::kSubtract, min, broadcast)); - auto computation = builder.Build(); + HloModule module(TestName()); + auto* computation = module.AddEntryComputation(builder.Build()); HloInstruction* fusion = computation->CreateFusionInstruction( {sub, broadcast, min}, HloInstruction::FusionKind::kLoop); EXPECT_FALSE(fusion->IsElementwise()); @@ -945,7 +961,8 @@ TEST_F(HloInstructionTest, CloneOfFusionPreservesShape) { HloInstruction* dot = builder.AddInstruction( HloInstruction::CreateBinary(sout, HloOpcode::kDot, x, reshape)); - auto computation = builder.Build(); + HloModule module(TestName()); + auto* computation = module.AddEntryComputation(builder.Build()); HloInstruction* fusion = computation->CreateFusionInstruction( {dot, reshape}, HloInstruction::FusionKind::kTransposeDot); diff --git a/tensorflow/compiler/xla/service/hlo_ordering.cc b/tensorflow/compiler/xla/service/hlo_ordering.cc index 7230682d0b..4c3ff3bdaf 100644 --- a/tensorflow/compiler/xla/service/hlo_ordering.cc +++ b/tensorflow/compiler/xla/service/hlo_ordering.cc @@ -183,6 +183,9 @@ DependencyHloOrdering::DependencyHloOrdering(const HloModule* module) // ordering based on dependencies. ExecutesBefore will return true iff there // exists a path in the HLO computation graph from 'a' to 'b'. for (auto& computation : module->computations()) { + if (computation->IsFusionComputation()) { + continue; + } predecessors_.emplace(computation.get(), computation->ComputeReachability()); } diff --git a/tensorflow/compiler/xla/service/hlo_rematerialization.cc b/tensorflow/compiler/xla/service/hlo_rematerialization.cc index d19e8034ac..fd08796e50 100644 --- a/tensorflow/compiler/xla/service/hlo_rematerialization.cc +++ b/tensorflow/compiler/xla/service/hlo_rematerialization.cc @@ -1202,6 +1202,9 @@ StatusOr<bool> HloRematerialization::Run( // After DCE, the module sequence may include instructions which no longer // exist. for (const auto& computation : module->computations()) { + if (computation->IsFusionComputation()) { + continue; + } if (sequence->at(computation.get()).size() != computation->instruction_count()) { // A size mismatch between the computation instruction count and the size diff --git a/tensorflow/compiler/xla/service/hlo_scheduling.cc b/tensorflow/compiler/xla/service/hlo_scheduling.cc index 17f55f9cfb..922236ee1e 100644 --- a/tensorflow/compiler/xla/service/hlo_scheduling.cc +++ b/tensorflow/compiler/xla/service/hlo_scheduling.cc @@ -400,6 +400,9 @@ CreateMemoryMinimizingSequence( TF_ASSIGN_OR_RETURN(std::unique_ptr<TuplePointsToAnalysis> points_to_analysis, TuplePointsToAnalysis::Run(&module)); for (const auto& computation : module.computations()) { + if (computation->IsFusionComputation()) { + continue; + } TF_ASSIGN_OR_RETURN(sequence[computation.get()], CreateMemoryMinimizingSequence( *computation, *points_to_analysis, size_function)); @@ -410,6 +413,7 @@ CreateMemoryMinimizingSequence( StatusOr<std::vector<const HloInstruction*>> CreateMemoryMinimizingSequence( const HloComputation& computation, const LogicalBuffer::SizeFunction& size_function) { + CHECK(!computation.IsFusionComputation()); TF_ASSIGN_OR_RETURN(std::unique_ptr<TuplePointsToAnalysis> points_to_analysis, TuplePointsToAnalysis::Run(computation.parent())); return CreateMemoryMinimizingSequence(computation, *points_to_analysis, diff --git a/tensorflow/compiler/xla/service/instruction_fusion.cc b/tensorflow/compiler/xla/service/instruction_fusion.cc index 482ab9b94a..24af07bd4b 100644 --- a/tensorflow/compiler/xla/service/instruction_fusion.cc +++ b/tensorflow/compiler/xla/service/instruction_fusion.cc @@ -211,8 +211,17 @@ bool InstructionFusion::CanFuseOnAllPaths( StatusOr<bool> InstructionFusion::Run(HloModule* module) { bool changed = false; + + std::vector<HloComputation*> computations; for (auto& computation : module->computations()) { - computation_ = computation.get(); + if (computation->IsFusionComputation()) { + continue; + } + computations.push_back(computation.get()); + } + for (auto& computation : computations) { + CHECK(!computation->IsFusionComputation()); + computation_ = computation; // We want to be able to remove arbitrary instructions from the post order // and also compare positions of instructions in the post order. To make diff --git a/tensorflow/compiler/xla/service/layout_assignment.cc b/tensorflow/compiler/xla/service/layout_assignment.cc index aafface0b9..7d41be94ce 100644 --- a/tensorflow/compiler/xla/service/layout_assignment.cc +++ b/tensorflow/compiler/xla/service/layout_assignment.cc @@ -611,6 +611,9 @@ Status CheckLayouts( TF_ASSIGN_OR_RETURN(auto points_to_analysis, TuplePointsToAnalysis::Run(module)); for (auto& computation : module->computations()) { + if (computation->IsFusionComputation()) { + continue; + } for (auto& instruction : computation->instructions()) { // Verify every instruction has a layout and the layout is valid for the // shape. @@ -1356,6 +1359,8 @@ StatusOr<bool> LayoutAssignment::Run(HloModule* module) { if (computation == module->entry_computation()) { TF_RETURN_IF_ERROR(RunOnComputation(*entry_computation_layout_, module->entry_computation())); + } else if (computation->IsFusionComputation()) { + continue; } else { ComputationLayout computation_layout(computation->ComputeProgramShape()); // Setting all embedded computations to the default layout is potentially diff --git a/tensorflow/compiler/xla/service/name_uniquer.cc b/tensorflow/compiler/xla/service/name_uniquer.cc index 4014856b9b..069f85af72 100644 --- a/tensorflow/compiler/xla/service/name_uniquer.cc +++ b/tensorflow/compiler/xla/service/name_uniquer.cc @@ -29,7 +29,11 @@ string NameUniquer::GetUniqueName(tensorflow::StringPiece prefix) { return root; } else { tensorflow::strings::StrAppend(&root, separator_, *count); + // Increment lookup under old 'root' name. (*count)++; + // Initialize count under new 'root' name. + count = &(generated_names_[root]); + *count = 1; return root; } } diff --git a/tensorflow/compiler/xla/service/reduce_precision_insertion.cc b/tensorflow/compiler/xla/service/reduce_precision_insertion.cc index e083226b14..9f12471ffd 100644 --- a/tensorflow/compiler/xla/service/reduce_precision_insertion.cc +++ b/tensorflow/compiler/xla/service/reduce_precision_insertion.cc @@ -26,6 +26,9 @@ StatusOr<bool> ReducePrecisionInsertion::Run(HloModule* module) { VLOG(1) << "Running ReducePrecisionInsertion pass on " << module->name(); for (auto& computation : module->computations()) { + if (computation->IsFusionComputation()) { + continue; + } std::vector<HloInstruction*> instructions_to_suffix; for (auto& instruction : computation->instructions()) { diff --git a/tensorflow/compiler/xla/service/reshape_mover.cc b/tensorflow/compiler/xla/service/reshape_mover.cc index 2d35ba5e54..1c648d58c7 100644 --- a/tensorflow/compiler/xla/service/reshape_mover.cc +++ b/tensorflow/compiler/xla/service/reshape_mover.cc @@ -312,10 +312,17 @@ StatusOr<bool> TrySinkReshapeOrTranspose(HloComputation* computation, StatusOr<bool> ReshapeMover::Run(HloModule* module) { bool changed = false; - for (const auto& comp : module->computations()) { + std::vector<HloComputation*> computations; + for (auto& computation : module->computations()) { + if (computation->IsFusionComputation()) { + continue; + } + computations.push_back(computation.get()); + } + for (const auto& comp : computations) { for (HloInstruction* instruction : comp->MakeInstructionPostOrder()) { TF_ASSIGN_OR_RETURN(bool did_change, - TrySinkReshapeOrTranspose(comp.get(), instruction)); + TrySinkReshapeOrTranspose(comp, instruction)); changed |= did_change; } } diff --git a/tensorflow/compiler/xla/service/reshape_mover_test.cc b/tensorflow/compiler/xla/service/reshape_mover_test.cc index 49c1755520..1589d52a25 100644 --- a/tensorflow/compiler/xla/service/reshape_mover_test.cc +++ b/tensorflow/compiler/xla/service/reshape_mover_test.cc @@ -351,16 +351,15 @@ TEST_F(ReshapeMoverTest, EquivalentReshapesMovedAcrossFusion) { auto add = builder.AddInstruction(HloInstruction::CreateBinary( root_shape, HloOpcode::kAdd, reshape0, reshape1)); - auto module = CreateNewModule(); - auto computation = module->AddEntryComputation(builder.Build()); - auto fusion = computation->AddInstruction(HloInstruction::CreateFusion( - add->shape(), HloInstruction::FusionKind::kLoop, add)); - TF_CHECK_OK(computation->ReplaceInstruction(add, fusion)); + HloModule module(TestName()); + auto computation = module.AddEntryComputation(builder.Build()); + computation->CreateFusionInstruction({add}, + HloInstruction::FusionKind::kLoop); EXPECT_THAT(computation->root_instruction(), op::Fusion(op::Reshape(param0), op::Reshape(param1))); - EXPECT_TRUE(ReshapeMover().Run(module.get()).ValueOrDie()); + EXPECT_TRUE(ReshapeMover().Run(&module).ValueOrDie()); EXPECT_THAT(computation->root_instruction(), op::Reshape(op::Fusion(param0, param1))); diff --git a/tensorflow/compiler/xla/service/transpose_folding.cc b/tensorflow/compiler/xla/service/transpose_folding.cc index a0c88c6bbc..5858335736 100644 --- a/tensorflow/compiler/xla/service/transpose_folding.cc +++ b/tensorflow/compiler/xla/service/transpose_folding.cc @@ -172,7 +172,14 @@ StatusOr<bool> TransposeFolding::Run(HloModule* module) { return tensorflow::Status::OK(); }; - for (auto& comp : module->computations()) { + std::vector<HloComputation*> computations; + for (auto& computation : module->computations()) { + if (computation->IsFusionComputation()) { + continue; + } + computations.push_back(computation.get()); + } + for (auto& comp : computations) { TF_RETURN_IF_ERROR(comp->Accept(visit_fn)); } diff --git a/tensorflow/compiler/xla/service/tuple_points_to_analysis.cc b/tensorflow/compiler/xla/service/tuple_points_to_analysis.cc index 182e99cf1c..3c4dc19aef 100644 --- a/tensorflow/compiler/xla/service/tuple_points_to_analysis.cc +++ b/tensorflow/compiler/xla/service/tuple_points_to_analysis.cc @@ -135,6 +135,9 @@ TuplePointsToAnalysis::Run(const HloModule* module) { Status TuplePointsToAnalysis::Analyze() { points_to_.clear(); for (auto& computation : module_->computations()) { + if (computation->IsFusionComputation()) { + continue; + } TF_RETURN_IF_ERROR(computation->Accept(this)); TF_RETURN_IF_ERROR( PopulateDefinedBuffersAndAliases(computation->instructions())); @@ -451,6 +454,9 @@ string TuplePointsToAnalysis::ToString() const { string output = tensorflow::strings::Printf( "TuplePointsToSet for module %s:\n", module_->name().c_str()); for (const auto& computation : module_->computations()) { + if (computation->IsFusionComputation()) { + continue; + } const char* entry = computation.get() == module_->entry_computation() ? "entry " : ""; tensorflow::strings::StrAppend(&output, entry, "computation ", |