diff options
author | Mark Heffernan <meheff@google.com> | 2018-09-05 17:17:23 -0700 |
---|---|---|
committer | TensorFlower Gardener <gardener@tensorflow.org> | 2018-09-05 17:22:22 -0700 |
commit | 6bd9f8fa0c17c55fc0c11ba0d9281cab1688b115 (patch) | |
tree | 1afd3dff710c4f63bae267807435abdcec784edb | |
parent | 017599d0a1fa7a7227a43649db67e96311033a4e (diff) |
Rollforward of cl/211656888 after fixing failing unit test.
*** Original change description ***
Add HloSchedule class representing a sequential order of an HloModule.
Currently we represent a sequential schedule of a module using a SequentialHloOrdering::HloModuleSequence, which is a type alias of a bare map from HloComputation* to std::vector<HloInstruction*>. This CL replaces this with a proper class, which results in better encapsulation...
***
PiperOrigin-RevId: 211726890
27 files changed, 1325 insertions, 905 deletions
diff --git a/tensorflow/compiler/xla/service/BUILD b/tensorflow/compiler/xla/service/BUILD index 64141ed191..ab86dce510 100644 --- a/tensorflow/compiler/xla/service/BUILD +++ b/tensorflow/compiler/xla/service/BUILD @@ -989,6 +989,7 @@ tf_cc_test( "//tensorflow/compiler/xla:xla_data_proto", "//tensorflow/compiler/xla/tests:hlo_test_base", "//tensorflow/compiler/xla/tests:xla_internal_test_main", + "//tensorflow/core:test", "@com_google_absl//absl/memory", ], ) @@ -1036,6 +1037,7 @@ tf_cc_test( ":flatten_call_graph", ":hlo", ":hlo_ordering", + ":hlo_schedule", ":hlo_scheduling", "//tensorflow/compiler/xla:literal", "//tensorflow/compiler/xla:shape_util", @@ -1049,6 +1051,7 @@ tf_cc_test( "//tensorflow/compiler/xla/tests:hlo_verified_test_base", "//tensorflow/compiler/xla/tests:xla_internal_test_main", "//tensorflow/core:lib", + "//tensorflow/core:test", "@com_google_absl//absl/memory", ], ) @@ -1062,6 +1065,7 @@ cc_library( ":hlo", ":hlo_dataflow_analysis", ":hlo_proto", + ":hlo_schedule", ":hlo_value", "//tensorflow/compiler/xla:shape_util", "//tensorflow/compiler/xla:status_macros", @@ -1082,6 +1086,7 @@ tf_cc_test( ":hlo", ":hlo_dataflow_analysis", ":hlo_ordering", + ":hlo_schedule", ":hlo_scheduling", "//tensorflow/compiler/xla:shape_util", "//tensorflow/compiler/xla:types", @@ -1089,6 +1094,7 @@ tf_cc_test( "//tensorflow/compiler/xla/service:hlo_parser", "//tensorflow/compiler/xla/tests:hlo_test_base", "//tensorflow/compiler/xla/tests:xla_internal_test_main", + "//tensorflow/core:test", ], ) @@ -1102,6 +1108,7 @@ cc_library( ":hlo", ":hlo_ordering", ":hlo_proto", + ":hlo_schedule", ":tuple_points_to_analysis", "//tensorflow/compiler/xla:statusor", "//tensorflow/compiler/xla:util", @@ -1125,6 +1132,7 @@ tf_cc_test( "//tensorflow/compiler/xla/tests:hlo_test_base", "//tensorflow/compiler/xla/tests:xla_internal_test_main", "//tensorflow/core:lib", + "//tensorflow/core:test", "@com_google_absl//absl/memory", ], ) @@ -1170,6 +1178,43 @@ cc_library( ) cc_library( + 
name = "hlo_schedule", + srcs = ["hlo_schedule.cc"], + hdrs = ["hlo_schedule.h"], + deps = [ + ":hlo", + "//tensorflow/compiler/xla:status", + "//tensorflow/compiler/xla:status_macros", + "//tensorflow/compiler/xla:util", + "//tensorflow/core:lib_internal", + "@com_google_absl//absl/strings", + "@com_google_absl//absl/strings:str_format", + "@com_google_absl//absl/types:span", + ], +) + +tf_cc_test( + name = "hlo_schedule_test", + srcs = ["hlo_schedule_test.cc"], + deps = [ + ":heap_simulator", + ":hlo", + ":hlo_dce", + ":hlo_ordering", + ":hlo_parser", + ":hlo_schedule", + ":hlo_scheduling", + "//tensorflow/compiler/xla:shape_util", + "//tensorflow/compiler/xla:types", + "//tensorflow/compiler/xla:xla_data_proto", + "//tensorflow/compiler/xla/tests:hlo_test_base", + "//tensorflow/compiler/xla/tests:xla_internal_test_main", + "//tensorflow/core:test", + "@com_google_absl//absl/algorithm:container", + ], +) + +cc_library( name = "hlo_scheduling", srcs = ["hlo_scheduling.cc"], hdrs = ["hlo_scheduling.h"], @@ -1177,6 +1222,7 @@ cc_library( ":heap_simulator", ":hlo", ":hlo_ordering", + ":hlo_schedule", ":logical_buffer", ":tuple_points_to_analysis", "//tensorflow/compiler/xla:shape_util", @@ -1205,6 +1251,7 @@ tf_cc_test( "//tensorflow/compiler/xla/tests:hlo_test_base", "//tensorflow/compiler/xla/tests:xla_internal_test_main", "//tensorflow/core:test", + "@com_google_absl//absl/algorithm:container", ], ) @@ -2366,6 +2413,7 @@ cc_library( ":hlo", ":hlo_dce", ":hlo_ordering", + ":hlo_schedule", ":hlo_scheduling", ":logical_buffer", ":tuple_points_to_analysis", diff --git a/tensorflow/compiler/xla/service/buffer_assignment.cc b/tensorflow/compiler/xla/service/buffer_assignment.cc index 8b8c6bfd26..0f0af57626 100644 --- a/tensorflow/compiler/xla/service/buffer_assignment.cc +++ b/tensorflow/compiler/xla/service/buffer_assignment.cc @@ -617,18 +617,24 @@ Status BufferAssignment::ComputeSummaryStats() { } // Only compute total fragmentation if all computations have 
schedules. - SequentialHloOrdering::HloModuleSequence module_sequence; + HloSchedule schedule(module_); + bool schedule_complete = true; for (const auto& computation : module_->computations()) { - const std::vector<const HloInstruction*>* sequence = - liveness_->hlo_ordering().SequentialOrder(*computation); - if (sequence != nullptr) { - module_sequence.emplace(computation, *sequence); + if (!computation->IsFusionComputation()) { + const std::vector<const HloInstruction*>* sequence = + liveness_->hlo_ordering().SequentialOrder(*computation); + if (sequence == nullptr) { + schedule_complete = false; + } else { + schedule.set_sequence(computation, *sequence); + } } } - if (module_sequence.size() == module_->computation_count()) { + if (schedule_complete) { + TF_RETURN_IF_ERROR(schedule.Verify()); TF_ASSIGN_OR_RETURN( const int64 min_size, - HeapSimulator::MinimumMemoryForModule(module_sequence, buffer_size_)); + HeapSimulator::MinimumMemoryForModule(schedule, buffer_size_)); stats_.total_fragmentation_bytes = stats_.total_allocation_bytes - min_size; } @@ -1064,7 +1070,7 @@ Status BufferAssigner::AssignBuffersWithSequentialOrdering( // since buffers for kCall, kWhile, and kConditional sub-computations are // only live for the duration of their calling instructions. 
VLOG(1) << "Running whole-module heap simulation"; - SequentialHloOrdering::HloModuleSequence module_sequence; + HloSchedule schedule(&assignment->module()); FlatSet<const LogicalBuffer*> all_buffers_to_assign; for (const auto& pair : buffers_to_assign_sequentially) { const HloComputation* computation = pair.first; @@ -1072,7 +1078,7 @@ Status BufferAssigner::AssignBuffersWithSequentialOrdering( const std::vector<const HloInstruction*>* instruction_sequence = hlo_ordering.SequentialOrder(*computation); CHECK(instruction_sequence != nullptr) << computation->name(); - module_sequence[computation] = *instruction_sequence; + schedule.set_sequence(computation, *instruction_sequence); all_buffers_to_assign.insert(buffers_to_assign.begin(), buffers_to_assign.end()); } @@ -1090,7 +1096,7 @@ Status BufferAssigner::AssignBuffersWithSequentialOrdering( const HeapSimulator::Result result, HeapSimulator::Run(absl::make_unique<DecreasingSizeRunsHeap>( absl::make_unique<LazyBestFitHeap>(alignment)), - assignment->module(), module_sequence, + assignment->module(), schedule, assignment->points_to_analysis(), assignment->buffer_size_, options)); AssignBuffersFromHeapSimulator(result, assignment, @@ -1121,7 +1127,7 @@ Status BufferAssigner::AssignBuffersWithSequentialOrdering( HeapSimulator::Run( absl::make_unique<DecreasingSizeRunsHeap>( absl::make_unique<LazyBestFitHeap>(alignment)), - *computation, *instruction_sequence, + *computation, HloInstructionSequence(*instruction_sequence), assignment->points_to_analysis(), assignment->buffer_size_, options)); AssignBuffersFromHeapSimulator(result, assignment, diff --git a/tensorflow/compiler/xla/service/buffer_assignment_test.cc b/tensorflow/compiler/xla/service/buffer_assignment_test.cc index 56bd67fb55..5a231c173d 100644 --- a/tensorflow/compiler/xla/service/buffer_assignment_test.cc +++ b/tensorflow/compiler/xla/service/buffer_assignment_test.cc @@ -33,6 +33,7 @@ limitations under the License. 
#include "tensorflow/compiler/xla/service/hlo_opcode.h" #include "tensorflow/compiler/xla/service/hlo_ordering.h" #include "tensorflow/compiler/xla/service/hlo_parser.h" +#include "tensorflow/compiler/xla/service/hlo_schedule.h" #include "tensorflow/compiler/xla/service/hlo_scheduling.h" #include "tensorflow/compiler/xla/shape_util.h" #include "tensorflow/compiler/xla/test.h" @@ -40,6 +41,7 @@ limitations under the License. #include "tensorflow/compiler/xla/tests/hlo_verified_test_base.h" #include "tensorflow/compiler/xla/types.h" #include "tensorflow/compiler/xla/xla_data.pb.h" +#include "tensorflow/core/lib/core/status_test_util.h" #include "tensorflow/core/platform/macros.h" namespace xla { @@ -120,14 +122,10 @@ class BufferAssignmentTest : public HloVerifiedTestBase { HloModule* module, absl::Span<const HloInstruction* const> instruction_sequence, int64 alignment = 1) { - SequentialHloOrdering::HloModuleSequence module_sequence; - module_sequence[module->entry_computation()] = - std::vector<const HloInstruction*>(instruction_sequence.begin(), - instruction_sequence.end()); + HloSchedule schedule(module); + schedule.set_sequence(module->entry_computation(), instruction_sequence); return BufferAssigner::Run( - module, - absl::make_unique<SequentialHloOrdering>(module, - module_sequence), + module, absl::make_unique<SequentialHloOrdering>(schedule), backend().compiler()->BufferSizeBytesFunction(), [alignment](LogicalBuffer::Color) { return alignment; }, /*allow_input_output_aliasing=*/false, @@ -1785,11 +1783,10 @@ class WhileBufferAssignmentTest : public HloVerifiedTestBase { std::unique_ptr<BufferAssignment> RunBufferAssignment(HloModule* module, int64 alignment = 1) { - auto sequence = - ScheduleComputationsInModule(*module, ByteSizeOf).ConsumeValueOrDie(); + HloSchedule schedule = + ScheduleModule(*module, ByteSizeOf).ConsumeValueOrDie(); return BufferAssigner::Run( - module, - absl::make_unique<SequentialHloOrdering>(module, sequence), + module, 
absl::make_unique<SequentialHloOrdering>(schedule), ByteSizeOf, [alignment](LogicalBuffer::Color) { return alignment; }, /*allow_input_output_aliasing=*/false, @@ -2096,17 +2093,25 @@ TEST_F(WhileBufferAssignmentTest, ColocatedBuffers) { // Create a sequential order among all the instructions in the entry // computation, since the issue this test stresses depends on the order the // nodes are traversed during BufferAssignment. - SequentialHloOrdering::HloModuleSequence sequence; - sequence[module->entry_computation()] = { - token, infeed, infeed_data, while0, while1, zero, add, while2, tuple}; + TF_ASSERT_OK_AND_ASSIGN( + HloSchedule schedule, + ScheduleModule(*module, [](const BufferValue& buffer) { + return ShapeUtil::ByteSizeOf(buffer.shape(), + /*pointer_size=*/sizeof(void*)); + })); + schedule.set_sequence( + module->entry_computation(), + {token, infeed, infeed_data, while0, while1, zero, add, while2, tuple}); + TF_ASSERT_OK(schedule.Verify()); + TF_ASSERT_OK_AND_ASSIGN( auto assignment, - BufferAssigner::Run( - module, absl::make_unique<SequentialHloOrdering>(module, sequence), - backend().compiler()->BufferSizeBytesFunction(), - [](LogicalBuffer::Color) { return 1; }, - /*allow_input_output_aliasing=*/false, - /*allocate_buffers_for_constants=*/true)); + BufferAssigner::Run(module, + absl::make_unique<SequentialHloOrdering>(schedule), + backend().compiler()->BufferSizeBytesFunction(), + [](LogicalBuffer::Color) { return 1; }, + /*allow_input_output_aliasing=*/false, + /*allocate_buffers_for_constants=*/true)); // The result tuple elements must be assigned with different buffers. 
TF_ASSERT_OK_AND_ASSIGN(auto slice0, assignment->GetUniqueSlice(tuple, {0})); @@ -2263,29 +2268,6 @@ ENTRY Main { GetAllocation(*buffers, param0, {1, 1})); } -static bool IsPostOrderTraversal( - const std::vector<const HloInstruction*>& sequence) { - tensorflow::gtl::FlatSet<const HloInstruction*> seen_so_far; - auto has_not_been_seen_yet = [&](const HloInstruction* instruction) { - return seen_so_far.count(instruction) == 0; - }; - - for (auto instruction : sequence) { - if (std::any_of(instruction->operands().begin(), - instruction->operands().end(), has_not_been_seen_yet) || - std::any_of(instruction->control_predecessors().begin(), - instruction->control_predecessors().end(), - has_not_been_seen_yet)) { - return false; // Not a post order. - } - if (!seen_so_far.insert(instruction).second) { - return false; // Not a "traversal". - } - } - - return true; -} - TEST_F(WhileBufferAssignmentTest, WhileLoopsInterferingResultRange) { auto module = CreateNewModule(); auto builder = HloComputation::Builder(TestName()); @@ -2340,27 +2322,27 @@ TEST_F(WhileBufferAssignmentTest, WhileLoopsInterferingResultRange) { RunCopyInsertion(module); - auto sequence = - ScheduleComputationsInModule(*module, ByteSizeOf).ConsumeValueOrDie(); + HloSchedule schedule = + ScheduleModule(*module, ByteSizeOf).ConsumeValueOrDie(); - // To trigger b/38494731, we want a specific Hlo sequence for the + // To trigger b/38494731, we want a specific Hlo schedule for the // root computation, so we overwrite that entry with a manually // crafted sequence. 
- sequence[module->entry_computation()] = { - input1, weights1, one, output1, while1->operand(0), while1, - input0, weights0, zero, output0, while0->operand(0), while0, - gte0, gte1, root_add}; + schedule.set_sequence(module->entry_computation(), + {input1, weights1, one, output1, while1->operand(0), + while1, input0, weights0, zero, output0, + while0->operand(0), while0, gte0, gte1, root_add}); - // If this ASSERT_TRUE fails, we constructed a bogus sequence above - // and this test itself is buggy. - ASSERT_TRUE(IsPostOrderTraversal(sequence[module->entry_computation()])); + // If this ASSERT fails, we constructed a bogus sequence above and this test + // itself is buggy. + TF_ASSERT_OK(schedule.Verify()); auto assignment = - BufferAssigner::Run( - module, absl::make_unique<SequentialHloOrdering>(module, sequence), - ByteSizeOf, [](LogicalBuffer::Color) { return 1; }, - /*allow_input_output_aliasing=*/false, - /*allocate_buffers_for_constants=*/true) + BufferAssigner::Run(module, + absl::make_unique<SequentialHloOrdering>(schedule), + ByteSizeOf, [](LogicalBuffer::Color) { return 1; }, + /*allow_input_output_aliasing=*/false, + /*allocate_buffers_for_constants=*/true) .ConsumeValueOrDie(); EXPECT_TRUE(BuffersDistinct({while0}, {while1}, *assignment)); diff --git a/tensorflow/compiler/xla/service/buffer_liveness_test.cc b/tensorflow/compiler/xla/service/buffer_liveness_test.cc index 26e26e316d..414bfe7999 100644 --- a/tensorflow/compiler/xla/service/buffer_liveness_test.cc +++ b/tensorflow/compiler/xla/service/buffer_liveness_test.cc @@ -27,6 +27,7 @@ limitations under the License. 
#include "tensorflow/compiler/xla/tests/hlo_test_base.h" #include "tensorflow/compiler/xla/types.h" #include "tensorflow/compiler/xla/xla_data.pb.h" +#include "tensorflow/core/lib/core/status_test_util.h" namespace xla { namespace { @@ -166,12 +167,12 @@ TEST_F(BufferLivenessTest, MultipleEntryParameters_Sequential) { auto module = CreateNewModule(); HloComputation* entry = module->AddEntryComputation(builder.Build()); - SequentialHloOrdering::HloModuleSequence sequence; - sequence.insert({entry, {param0, negate, param1, exp, add}}); - auto liveness = BufferLiveness::Run(module.get(), - absl::make_unique<SequentialHloOrdering>( - module.get(), sequence)) - .ConsumeValueOrDie(); + HloSchedule schedule(module.get()); + schedule.set_sequence(entry, {param0, negate, param1, exp, add}); + auto liveness = + BufferLiveness::Run(module.get(), + absl::make_unique<SequentialHloOrdering>(schedule)) + .ConsumeValueOrDie(); // Entry parameters interfere as if they are defined simultaneously at // the very beginning. 
@@ -291,13 +292,12 @@ TEST_F(BufferLivenessTest, OverlappedBuffersSequentialOrder) { auto module = CreateNewModule(); auto computation = module->AddEntryComputation(builder.Build()); - SequentialHloOrdering::HloModuleSequence module_sequence; - std::vector<const HloInstruction*> order = {param, negate, exp, add}; - module_sequence.emplace(computation, order); - auto liveness = BufferLiveness::Run(module.get(), - absl::make_unique<SequentialHloOrdering>( - module.get(), module_sequence)) - .ConsumeValueOrDie(); + HloSchedule schedule(module.get()); + schedule.set_sequence(computation, {param, negate, exp, add}); + auto liveness = + BufferLiveness::Run(module.get(), + absl::make_unique<SequentialHloOrdering>(schedule)) + .ConsumeValueOrDie(); EXPECT_TRUE(InstructionsMayInterfere(*liveness, param, negate)); EXPECT_FALSE(InstructionsMayInterfere(*liveness, param, exp)); @@ -339,14 +339,14 @@ TEST_F(BufferLivenessTest, RootInstructionIsNotLastInSequentialOrder) { auto module = CreateNewModule(); auto computation = module->AddEntryComputation(builder.Build(add)); - SequentialHloOrdering::HloModuleSequence module_sequence; - std::vector<const HloInstruction*> order = {param, add, recv, - recv_done, send, send_done}; - module_sequence.emplace(computation, order); - auto liveness = BufferLiveness::Run(module.get(), - absl::make_unique<SequentialHloOrdering>( - module.get(), module_sequence)) - .ConsumeValueOrDie(); + HloSchedule schedule(module.get()); + schedule.set_sequence(computation, + {param, add, token, recv, recv_done, send, send_done}); + TF_ASSERT_OK(schedule.Verify()); + auto liveness = + BufferLiveness::Run(module.get(), + absl::make_unique<SequentialHloOrdering>(schedule)) + .ConsumeValueOrDie(); EXPECT_FALSE(InstructionsMayInterfere(*liveness, param, add)); // Check the root instruction (add) buffer interferes with the recv buffer. 
diff --git a/tensorflow/compiler/xla/service/cpu/cpu_compiler.cc b/tensorflow/compiler/xla/service/cpu/cpu_compiler.cc index 796f36510e..e7b6075994 100644 --- a/tensorflow/compiler/xla/service/cpu/cpu_compiler.cc +++ b/tensorflow/compiler/xla/service/cpu/cpu_compiler.cc @@ -584,16 +584,14 @@ StatusOr<std::unique_ptr<Executable>> CpuCompiler::RunBackend( // computation. Using this sequence enables tighter buffer liveness analysis // and reduced memory usage (as compared to using DependencyHloOrdering). TF_ASSIGN_OR_RETURN( - SequentialHloOrdering::HloModuleSequence module_sequence, - ScheduleComputationsInModule(*module, BufferSizeBytesFunction(), - DFSMemoryScheduler)); + HloSchedule schedule, + ScheduleModule(*module, BufferSizeBytesFunction(), DFSMemoryScheduler)); // Run buffer allocation on the HLO graph. TF_ASSIGN_OR_RETURN( std::unique_ptr<BufferAssignment> assignment, BufferAssigner::Run(module.get(), - absl::make_unique<SequentialHloOrdering>( - module.get(), module_sequence), + absl::make_unique<SequentialHloOrdering>(schedule), BufferSizeBytesFunction(), memory_alignment, /*allow_input_output_aliasing=*/false, /*allocate_buffers_for_constants=*/true)); @@ -627,9 +625,10 @@ StatusOr<std::unique_ptr<Executable>> CpuCompiler::RunBackend( } TF_RETURN_IF_ERROR( ir_emitter - .EmitComputation(embedded_computation, embedded_computation->name(), - /*is_top_level_computation=*/false, - &module_sequence.at(embedded_computation)) + .EmitComputation( + embedded_computation, embedded_computation->name(), + /*is_top_level_computation=*/false, + &schedule.sequence(embedded_computation).instructions()) .status()); } string function_name_prefix = entry_computation->name().empty() @@ -637,9 +636,10 @@ StatusOr<std::unique_ptr<Executable>> CpuCompiler::RunBackend( : entry_computation->name(); TF_ASSIGN_OR_RETURN( llvm::Function * entry_function, - ir_emitter.EmitComputation(entry_computation, function_name_prefix, - /*is_top_level_computation=*/true, - 
&module_sequence.at(entry_computation))); + ir_emitter.EmitComputation( + entry_computation, function_name_prefix, + /*is_top_level_computation=*/true, + &schedule.sequence(entry_computation).instructions())); string function_name = [&]() { llvm::SmallVector<char, 40> function_name_vector; @@ -771,20 +771,18 @@ CpuCompiler::CompileAheadOfTime(std::vector<std::unique_ptr<HloModule>> modules, VLOG(2) << "After optimization:"; XLA_VLOG_LINES(2, module->ToString()); - TF_ASSIGN_OR_RETURN( - SequentialHloOrdering::HloModuleSequence module_sequence, - ScheduleComputationsInModule(*module, BufferSizeBytesFunction())); + TF_ASSIGN_OR_RETURN(HloSchedule schedule, + ScheduleModule(*module, BufferSizeBytesFunction())); // Run buffer analysis on the HLO graph. This analysis figures out which // temporary buffers are required to run the computation. TF_ASSIGN_OR_RETURN( std::unique_ptr<BufferAssignment> assignment, - BufferAssigner::Run( - module, - absl::make_unique<SequentialHloOrdering>(module, module_sequence), - BufferSizeBytesFunction(), memory_alignment, - /*allow_input_output_aliasing=*/false, - /*allocate_buffers_for_constants=*/true)); + BufferAssigner::Run(module, + absl::make_unique<SequentialHloOrdering>(schedule), + BufferSizeBytesFunction(), memory_alignment, + /*allow_input_output_aliasing=*/false, + /*allocate_buffers_for_constants=*/true)); // BufferAssignment::ToString() includes a header, so no need for us to // print one ourselves. 
XLA_VLOG_LINES(2, assignment->ToString()); @@ -824,18 +822,18 @@ CpuCompiler::CompileAheadOfTime(std::vector<std::unique_ptr<HloModule>> modules, } TF_RETURN_IF_ERROR( ir_emitter - .EmitComputation(embedded_computation, - embedded_computation->name(), - /*is_top_level_computation=*/false, - &module_sequence.at(embedded_computation)) + .EmitComputation( + embedded_computation, embedded_computation->name(), + /*is_top_level_computation=*/false, + &schedule.sequence(embedded_computation).instructions()) .status()); } const string& entry_point_name = options.entry_point_name(); - TF_ASSIGN_OR_RETURN( - llvm::Function * entry_function, - ir_emitter.EmitComputation(computation, entry_point_name, - /*is_top_level_computation=*/true, - &module_sequence.at(computation))); + TF_ASSIGN_OR_RETURN(llvm::Function * entry_function, + ir_emitter.EmitComputation( + computation, entry_point_name, + /*is_top_level_computation=*/true, + &schedule.sequence(computation).instructions())); CHECK(entry_function->getName() == llvm_ir::AsStringRef(entry_point_name)); diff --git a/tensorflow/compiler/xla/service/cpu/ir_emitter.cc b/tensorflow/compiler/xla/service/cpu/ir_emitter.cc index e5cf15c686..df8c2a636b 100644 --- a/tensorflow/compiler/xla/service/cpu/ir_emitter.cc +++ b/tensorflow/compiler/xla/service/cpu/ir_emitter.cc @@ -110,7 +110,7 @@ IrEmitter::IrEmitter( StatusOr<llvm::Function*> IrEmitter::EmitComputation( HloComputation* computation, const string& function_name_prefix, bool is_top_level_computation, - std::vector<const HloInstruction*>* instruction_order) { + const std::vector<const HloInstruction*>* instruction_order) { string function_name = name_uniquer_.GetUniqueName(function_name_prefix); VLOG(2) << "Emitting IR for CPU function [" << function_name_prefix << "]; ordered? 
" << (instruction_order != nullptr); diff --git a/tensorflow/compiler/xla/service/cpu/ir_emitter.h b/tensorflow/compiler/xla/service/cpu/ir_emitter.h index 58a333b8fb..3df99464ba 100644 --- a/tensorflow/compiler/xla/service/cpu/ir_emitter.h +++ b/tensorflow/compiler/xla/service/cpu/ir_emitter.h @@ -98,7 +98,7 @@ class IrEmitter : public DfsHloVisitorWithDefault, StatusOr<llvm::Function*> EmitComputation( HloComputation* computation, const string& function_name_prefix, bool is_top_level_computation, - std::vector<const HloInstruction*>* instruction_order); + const std::vector<const HloInstruction*>* instruction_order); llvm::IRBuilder<>* b() { return &b_; } diff --git a/tensorflow/compiler/xla/service/gpu/BUILD b/tensorflow/compiler/xla/service/gpu/BUILD index a68b7a1bef..13ccff35f8 100644 --- a/tensorflow/compiler/xla/service/gpu/BUILD +++ b/tensorflow/compiler/xla/service/gpu/BUILD @@ -813,6 +813,7 @@ cc_library( "//tensorflow/compiler/xla/service:hlo", "//tensorflow/compiler/xla/service:hlo_ordering", "//tensorflow/compiler/xla/service:hlo_reachability", + "//tensorflow/compiler/xla/service:hlo_schedule", "//tensorflow/compiler/xla/service:hlo_scheduling", "@com_google_absl//absl/memory", ], diff --git a/tensorflow/compiler/xla/service/gpu/gpu_hlo_schedule.cc b/tensorflow/compiler/xla/service/gpu/gpu_hlo_schedule.cc index 743035a84e..ea9376e101 100644 --- a/tensorflow/compiler/xla/service/gpu/gpu_hlo_schedule.cc +++ b/tensorflow/compiler/xla/service/gpu/gpu_hlo_schedule.cc @@ -22,6 +22,7 @@ limitations under the License. 
#include "absl/memory/memory.h" #include "tensorflow/compiler/xla/service/buffer_value.h" #include "tensorflow/compiler/xla/service/hlo_reachability.h" +#include "tensorflow/compiler/xla/service/hlo_schedule.h" #include "tensorflow/compiler/xla/service/hlo_scheduling.h" #include "tensorflow/compiler/xla/types.h" @@ -198,11 +199,12 @@ StatusOr<std::unique_ptr<GpuHloSchedule>> GpuHloSchedule::Build( // All kernels are launched on a single stream, so there's no loss of // concurrency by optimizing for minimal memory usage. TF_ASSIGN_OR_RETURN( - schedule->thunk_launch_order_, - ScheduleOneComputation( + HloInstructionSequence sequence, + ScheduleComputation( *entry_computation, [pointer_size](const BufferValue& buffer) { return ShapeUtil::ByteSizeOf(buffer.shape(), pointer_size); })); + schedule->thunk_launch_order_ = sequence.instructions(); } else { // BFS tends to increase concurrency, but also increases memory usage. BFSLaunchOrder(entry_computation, &schedule->thunk_launch_order_); diff --git a/tensorflow/compiler/xla/service/gpu/gpu_hlo_schedule.h b/tensorflow/compiler/xla/service/gpu/gpu_hlo_schedule.h index 30a0e7cecd..07a7fc67aa 100644 --- a/tensorflow/compiler/xla/service/gpu/gpu_hlo_schedule.h +++ b/tensorflow/compiler/xla/service/gpu/gpu_hlo_schedule.h @@ -33,7 +33,9 @@ namespace gpu { // launches, because thunks may be scheduled onto concurrent streams. This // schedule is used by BufferAssigner to determine buffer liveness (i.e. to // minimize allocations), and also by ThunkSchedule to determine the thunk -// launch order. +// launch order. This class differs from xla::HloSchedule in that HloSchedule +// represents a total order of all instructions in the module for backends which +// execute HLO instructions strictly sequentially. 
class GpuHloSchedule { public: // Constructs an GpuHloSchedule for the given module, based on the given diff --git a/tensorflow/compiler/xla/service/heap_simulator.cc b/tensorflow/compiler/xla/service/heap_simulator.cc index 38c3982ebf..e0f3a7e0e2 100644 --- a/tensorflow/compiler/xla/service/heap_simulator.cc +++ b/tensorflow/compiler/xla/service/heap_simulator.cc @@ -29,13 +29,13 @@ using tensorflow::gtl::FlatSet; /*static*/ StatusOr<int64> HeapSimulator::MinimumMemoryForModule( - const SequentialHloOrdering::HloModuleSequence& module_sequence, + const HloSchedule& schedule, const LogicalBuffer::SizeFunction& size_function) { - if (module_sequence.empty()) { + if (schedule.empty()) { return 0; } - const HloModule* module = module_sequence.begin()->first->parent(); + const HloModule* module = schedule.module(); TF_ASSIGN_OR_RETURN(std::unique_ptr<TuplePointsToAnalysis> points_to_analysis, TuplePointsToAnalysis::Run(module)); @@ -47,14 +47,13 @@ StatusOr<int64> HeapSimulator::MinimumMemoryForModule( TF_ASSIGN_OR_RETURN( HeapSimulator::Result result, HeapSimulator::Run(absl::make_unique<NoFragmentationStatsHeap>(), *module, - module_sequence, *points_to_analysis, size_function)); + schedule, *points_to_analysis, size_function)); return result.heap_size; } /*static*/ StatusOr<int64> HeapSimulator::MinimumMemoryForComputation( - const HloComputation& computation, - const std::vector<const HloInstruction*>& sequence, + const HloComputation& computation, const HloInstructionSequence& sequence, const TuplePointsToAnalysis& points_to_analysis, const LogicalBuffer::SizeFunction& size_function, const tensorflow::gtl::FlatMap<const HloComputation*, int64>* @@ -71,13 +70,13 @@ StatusOr<int64> HeapSimulator::MinimumMemoryForComputation( /*static*/ StatusOr<HeapSimulator::Result> HeapSimulator::Run( std::unique_ptr<HeapAlgorithm> algorithm, const HloModule& module, - const SequentialHloOrdering::HloModuleSequence& module_sequence, + const HloSchedule& schedule, const 
TuplePointsToAnalysis& points_to_analysis, const BufferValue::SizeFunction& size_fn, const Options& options) { - HeapSimulator heap(std::move(algorithm), size_fn, options, &module_sequence); + HeapSimulator heap(std::move(algorithm), size_fn, options, &schedule); const HloComputation* entry_computation = module.entry_computation(); - const std::vector<const HloInstruction*>& instruction_sequence = - FindOrDie(module_sequence, entry_computation); + const HloInstructionSequence& instruction_sequence = + schedule.sequence(entry_computation); TF_RETURN_IF_ERROR(heap.RunComputation( *entry_computation, instruction_sequence, points_to_analysis)); return heap.Finish(); @@ -86,13 +85,13 @@ StatusOr<HeapSimulator::Result> HeapSimulator::Run( /*static*/ StatusOr<HeapSimulator::Result> HeapSimulator::Run( std::unique_ptr<HeapAlgorithm> algorithm, const HloComputation& computation, - const std::vector<const HloInstruction*>& instruction_sequence, + const HloInstructionSequence& instruction_sequence, const TuplePointsToAnalysis& points_to_analysis, const BufferValue::SizeFunction& size_fn, const Options& options, const tensorflow::gtl::FlatMap<const HloComputation*, int64>* memory_by_computation) { HeapSimulator heap(std::move(algorithm), size_fn, options, - /*module_sequence=*/nullptr, memory_by_computation); + /*schedule=*/nullptr, memory_by_computation); TF_RETURN_IF_ERROR(heap.RunComputation(computation, instruction_sequence, points_to_analysis)); return heap.Finish(); @@ -102,7 +101,7 @@ StatusOr<HeapSimulator::Result> HeapSimulator::Run( // 'instruction_sequence'. 
Status HeapSimulator::RunComputation( const HloComputation& computation, - const std::vector<const HloInstruction*>& instruction_sequence, + const HloInstructionSequence& instruction_sequence, const TuplePointsToAnalysis& points_to_analysis) { VLOG(3) << "Computation:\n" << computation.ToString(); // The goal here is to minimize memory usage, assuming the given sequential @@ -133,7 +132,8 @@ Status HeapSimulator::RunComputation( // set of instructions that need to be visited contains all users of all // aliases, that is, all users of all instructions that have the buffer // contained in their points-to set. - for (const HloInstruction* instruction : instruction_sequence) { + for (const HloInstruction* instruction : + instruction_sequence.instructions()) { const PointsToSet& points_to = points_to_analysis.GetPointsToSet(instruction); const PointsToSet::BufferSet& buffer_set = points_to.CreateFlattenedSet(); @@ -166,7 +166,8 @@ Status HeapSimulator::RunComputation( std::vector<const BufferValue*> dead_buffers_to_free; std::vector<const BufferValue*> operand_buffers_to_free; - for (const HloInstruction* instruction : instruction_sequence) { + for (const HloInstruction* instruction : + instruction_sequence.instructions()) { const TuplePointsToAnalysis::BufferDefinitionVector& buffers_defined_by_instruction = points_to_analysis.GetBuffersDefinedByInstruction(instruction); @@ -285,14 +286,14 @@ Status HeapSimulator::RunComputation( // The order that the sub-computations are simulated does not affect // correctness; since the whole module has been scheduled, we know that the // sub-computations will never be run concurrently. 
- if (module_sequence_ != nullptr) { + if (schedule_ != nullptr) { if (instruction->opcode() == HloOpcode::kCall || instruction->opcode() == HloOpcode::kConditional || instruction->opcode() == HloOpcode::kWhile) { for (const HloComputation* called_computation : instruction->called_computations()) { - const std::vector<const HloInstruction*>& called_sequence = - FindOrDie(*module_sequence_, called_computation); + const HloInstructionSequence& called_sequence = + schedule_->sequence(called_computation); TF_RETURN_IF_ERROR(RunComputation( *called_computation, called_sequence, points_to_analysis)); } @@ -343,16 +344,16 @@ Status HeapSimulator::RunComputation( HeapSimulator::HeapSimulator( std::unique_ptr<HeapAlgorithm> algorithm, const BufferValue::SizeFunction& size_fn, const Options& options, - const SequentialHloOrdering::HloModuleSequence* module_sequence, + const HloSchedule* schedule, const tensorflow::gtl::FlatMap<const HloComputation*, int64>* memory_by_computation) : no_fragmentation_stats_(absl::make_unique<NoFragmentationStatsHeap>()), algorithm_(std::move(algorithm)), size_fn_(size_fn), options_(options), - module_sequence_(module_sequence), + schedule_(schedule), memory_by_computation_(memory_by_computation) { - debug_trace_.set_whole_module_simulation(module_sequence_ != nullptr); + debug_trace_.set_whole_module_simulation(schedule_ != nullptr); } HeapSimulator::~HeapSimulator() {} diff --git a/tensorflow/compiler/xla/service/heap_simulator.h b/tensorflow/compiler/xla/service/heap_simulator.h index af05bedee7..ffbf947d5a 100644 --- a/tensorflow/compiler/xla/service/heap_simulator.h +++ b/tensorflow/compiler/xla/service/heap_simulator.h @@ -27,6 +27,7 @@ limitations under the License. 
#include "tensorflow/compiler/xla/service/hlo_computation.h" #include "tensorflow/compiler/xla/service/hlo_instruction.h" #include "tensorflow/compiler/xla/service/hlo_ordering.h" +#include "tensorflow/compiler/xla/service/hlo_schedule.h" #include "tensorflow/compiler/xla/service/tuple_points_to_analysis.h" #include "tensorflow/compiler/xla/statusor.h" #include "tensorflow/core/lib/gtl/flatmap.h" @@ -88,23 +89,22 @@ class HeapSimulator { // Returns the minimum memory required to compute an HLO module where all // computations have been scheduled (represented by the given - // module_sequence), assuming no fragmentation. + // schedule), assuming no fragmentation. static StatusOr<int64> MinimumMemoryForModule( - const SequentialHloOrdering::HloModuleSequence& module_sequence, + const HloSchedule& schedule, const LogicalBuffer::SizeFunction& size_function); // Returns the minimum memory required to compute the given computation, // assuming no fragmentation. static StatusOr<int64> MinimumMemoryForComputation( - const HloComputation& computation, - const std::vector<const HloInstruction*>& sequence, + const HloComputation& computation, const HloInstructionSequence& sequence, const TuplePointsToAnalysis& points_to_analysis, const LogicalBuffer::SizeFunction& size_function, const tensorflow::gtl::FlatMap<const HloComputation*, int64>* memory_by_computation = nullptr); // Run the heap simulation with the given algorithm, assuming the given - // module_sequence, which must contain a topologically-consistent total + // schedule, which must contain a topologically-consistent total // ordering of all instructions within each computation. The result is invalid // if instructions are not run in exactly this sequence. // @@ -112,12 +112,12 @@ class HeapSimulator { // to running on a per-computation basis, since we can re-use buffer space for // called sub-computations. 
// - static StatusOr<Result> Run( - std::unique_ptr<HeapAlgorithm> algorithm, const HloModule& module, - const SequentialHloOrdering::HloModuleSequence& module_sequence, - const TuplePointsToAnalysis& points_to_analysis, - const BufferValue::SizeFunction& size_fn, - const Options& options = Options()); + static StatusOr<Result> Run(std::unique_ptr<HeapAlgorithm> algorithm, + const HloModule& module, + const HloSchedule& schedule, + const TuplePointsToAnalysis& points_to_analysis, + const BufferValue::SizeFunction& size_fn, + const Options& options = Options()); // Same as above, but runs on a single computation. The 'instruction_sequence' // must contain a topologically-consistent total ordering of all instructions @@ -126,7 +126,7 @@ class HeapSimulator { static StatusOr<Result> Run( std::unique_ptr<HeapAlgorithm> algorithm, const HloComputation& computation, - const std::vector<const HloInstruction*>& instruction_sequence, + const HloInstructionSequence& instruction_sequence, const TuplePointsToAnalysis& points_to_analysis, const BufferValue::SizeFunction& size_fn, const Options& options = Options(), @@ -134,21 +134,19 @@ class HeapSimulator { memory_by_computation = nullptr); private: - // If 'module_sequence' is non-null, it is used to find kCall and kWhile + // If 'schedule' is non-null, it is used to find kCall and kWhile // sub-computations, and the heap simulation for those sub-computations will // be run recursively. I.e. the simulation is run over the whole module. 
- HeapSimulator( - std::unique_ptr<HeapAlgorithm> algorithm, - const BufferValue::SizeFunction& size_fn, const Options& options, - const SequentialHloOrdering::HloModuleSequence* module_sequence = nullptr, - const tensorflow::gtl::FlatMap<const HloComputation*, int64>* - memory_by_computation = nullptr); + HeapSimulator(std::unique_ptr<HeapAlgorithm> algorithm, + const BufferValue::SizeFunction& size_fn, + const Options& options, const HloSchedule* schedule = nullptr, + const tensorflow::gtl::FlatMap<const HloComputation*, int64>* + memory_by_computation = nullptr); ~HeapSimulator(); - Status RunComputation( - const HloComputation& computation, - const std::vector<const HloInstruction*>& instruction_sequence, - const TuplePointsToAnalysis& points_to_analysis); + Status RunComputation(const HloComputation& computation, + const HloInstructionSequence& instruction_sequence, + const TuplePointsToAnalysis& points_to_analysis); bool IgnoreBuffer(const BufferValue* buffer) const; void Alloc(const BufferValue* buffer, const HloInstruction* instruction); @@ -169,11 +167,11 @@ class HeapSimulator { const std::unique_ptr<HeapAlgorithm> algorithm_; const BufferValue::SizeFunction size_fn_; const Options options_; - // module_sequence_ is set by buffer assignment, and memory_by_computation_ is + // schedule_ is set by buffer assignment, and memory_by_computation_ is // set by hlo scheduling. Then, in RunComputation, we check both in order to // handle subcomputations. It would be good to unify the handling of // subcomputations, but it's not clear how. 
- const SequentialHloOrdering::HloModuleSequence* module_sequence_; + const HloSchedule* schedule_; const tensorflow::gtl::FlatMap<const HloComputation*, int64>* memory_by_computation_; diff --git a/tensorflow/compiler/xla/service/heap_simulator_test.cc b/tensorflow/compiler/xla/service/heap_simulator_test.cc index 7ad8a107e1..00a25db467 100644 --- a/tensorflow/compiler/xla/service/heap_simulator_test.cc +++ b/tensorflow/compiler/xla/service/heap_simulator_test.cc @@ -30,6 +30,7 @@ limitations under the License. #include "tensorflow/compiler/xla/service/tuple_points_to_analysis.h" #include "tensorflow/compiler/xla/status_macros.h" #include "tensorflow/compiler/xla/tests/hlo_test_base.h" +#include "tensorflow/core/lib/core/status_test_util.h" #include "tensorflow/core/lib/gtl/flatmap.h" namespace xla { @@ -85,13 +86,16 @@ TEST_F(MinimumMemoryForSequenceTest, MultiComputation) { return ShapeUtil::ByteSizeOf(buffer.shape(), /*pointer_size=*/8); }; - SequentialHloOrdering::HloModuleSequence module_sequence; - module_sequence[cond_computation] = {cond_param, cond_iter, cond_data, - cond_lt}; - module_sequence[body_computation] = {body_param}; - module_sequence[entry_computation] = {iter, data, tuple, while_op}; - EXPECT_EQ(56, HeapSimulator::MinimumMemoryForModule(module_sequence, size_fn) - .ValueOrDie()); + HloSchedule schedule(module.get()); + schedule.set_sequence(cond_computation, + {cond_param, cond_iter, cond_data, cond_lt}); + schedule.set_sequence(body_computation, {body_param}); + schedule.set_sequence(entry_computation, {iter, data, tuple, while_op}); + TF_ASSERT_OK(schedule.Verify()); + + EXPECT_EQ( + 56, + HeapSimulator::MinimumMemoryForModule(schedule, size_fn).ValueOrDie()); } const char kAlloc[] = "Alloc"; @@ -149,10 +153,11 @@ class HeapSimulatorTracker { auto zero_size = [](const BufferValue& buffer) { return 0; }; auto algorithm = absl::make_unique<DecreasingSizeRunsHeap>( absl::make_unique<HeapCallRecorder>(&actual_calls_)); - result_ = 
HeapSimulator::Run( - std::move(algorithm), *module_->entry_computation(), - instruction_sequence, *points_to_analysis_, zero_size) - .ConsumeValueOrDie(); + result_ = + HeapSimulator::Run(std::move(algorithm), *module_->entry_computation(), + HloInstructionSequence(instruction_sequence), + *points_to_analysis_, zero_size) + .ConsumeValueOrDie(); } explicit HeapSimulatorTracker(const string& name) { @@ -168,11 +173,12 @@ class HeapSimulatorTracker { TuplePointsToAnalysis::Run(module_.get()).ConsumeValueOrDie(); // Construct the module sequence grouped by computation. - SequentialHloOrdering::HloModuleSequence module_sequence; + HloSchedule schedule(module_.get()); tensorflow::gtl::FlatMap<const HloInstruction*, int> reverse_position; for (int i = 0; i < full_module_sequence.size(); ++i) { const HloInstruction* instruction = full_module_sequence[i]; - module_sequence[instruction->parent()].push_back(instruction); + schedule.GetOrCreateSequence(instruction->parent()) + .push_back(instruction); reverse_position[instruction] = full_module_sequence.size() - i; } @@ -185,8 +191,8 @@ class HeapSimulatorTracker { }; auto algorithm = absl::make_unique<DecreasingSizeRunsHeap>( absl::make_unique<HeapCallRecorder>(&actual_calls_)); - result_ = HeapSimulator::Run(std::move(algorithm), *module_, - module_sequence, *points_to_analysis_, size_fn) + result_ = HeapSimulator::Run(std::move(algorithm), *module_, schedule, + *points_to_analysis_, size_fn) .ConsumeValueOrDie(); } diff --git a/tensorflow/compiler/xla/service/hlo_alias_analysis_test.cc b/tensorflow/compiler/xla/service/hlo_alias_analysis_test.cc index 54abe3345d..0cd0ab36fc 100644 --- a/tensorflow/compiler/xla/service/hlo_alias_analysis_test.cc +++ b/tensorflow/compiler/xla/service/hlo_alias_analysis_test.cc @@ -885,18 +885,20 @@ TEST_F(HloAliasAnalysisTest, WhileInterference) { // For a sequential order, if there is interference iff the negate is after // the while. 
- SequentialHloOrdering::HloModuleSequence sequence; - sequence[body] = {body_param, body_root}; - sequence[condition] = {cond_param, cond_root}; + HloSchedule schedule(module_); + schedule.set_sequence(body, {body_param, body_root}); + schedule.set_sequence(condition, {cond_param, cond_root}); { - sequence[entry] = {init, xla_while, negate, entry_root}; - SequentialHloOrdering ordering(module_, sequence); + schedule.set_sequence(entry, {init, xla_while, negate, entry_root}); + TF_ASSERT_OK(schedule.Verify()); + SequentialHloOrdering ordering(schedule); EXPECT_TRUE(analysis.HasLiveRangeInterference(ordering)); } { - sequence[entry] = {init, negate, xla_while, entry_root}; - SequentialHloOrdering ordering(module_, sequence); + schedule.set_sequence(entry, {init, negate, xla_while, entry_root}); + TF_ASSERT_OK(schedule.Verify()); + SequentialHloOrdering ordering(schedule); EXPECT_FALSE(analysis.HasLiveRangeInterference(ordering)); } } diff --git a/tensorflow/compiler/xla/service/hlo_dataflow_analysis_test.cc b/tensorflow/compiler/xla/service/hlo_dataflow_analysis_test.cc index 72b236801a..510d6360a1 100644 --- a/tensorflow/compiler/xla/service/hlo_dataflow_analysis_test.cc +++ b/tensorflow/compiler/xla/service/hlo_dataflow_analysis_test.cc @@ -28,6 +28,7 @@ limitations under the License. 
#include "tensorflow/compiler/xla/test_helpers.h" #include "tensorflow/compiler/xla/tests/hlo_test_base.h" #include "tensorflow/compiler/xla/xla_data.pb.h" +#include "tensorflow/core/lib/core/status_test_util.h" #include "tensorflow/core/platform/logging.h" #include "tensorflow/core/platform/test.h" @@ -1261,9 +1262,10 @@ TEST_P(HloDataflowAnalysisTest, MultipleEntryParameters_Sequential) { auto entry = module_->AddEntryComputation(builder.Build()); RunAnalysis(GetParam()); - SequentialHloOrdering::HloModuleSequence sequence; - sequence.insert({entry, {param0, negate, param1, exp, add}}); - SequentialHloOrdering ordering(module_.get(), sequence); + HloSchedule schedule(module_.get()); + schedule.set_sequence(entry, {param0, negate, param1, exp, add}); + TF_ASSERT_OK(schedule.Verify()); + SequentialHloOrdering ordering(schedule); // Entry parameters interfere as if they are defined simultaneously at // the very beginning. @@ -1339,14 +1341,16 @@ TEST_P(HloDataflowAnalysisTest, WhileParameters_Sequential) { bool ssa_form = GetParam(); RunAnalysis(ssa_form); - SequentialHloOrdering::HloModuleSequence sequence; - sequence.insert({entry, {param, xla_while}}); - sequence.insert({condition, {cond_param, cond_constant}}); + HloSchedule schedule(module_.get()); + schedule.set_sequence(entry, {param, xla_while}); + schedule.set_sequence(condition, {cond_param, cond_constant}); // Construct the order such that 'constant' and its use 'exp' are before // body_param. - sequence.insert({body, {constant, exp, body_param, add}}); + schedule.set_sequence( + body, {constant, exp, body_param, add, dead_constant, dead_negate}); + TF_ASSERT_OK(schedule.Verify()); - SequentialHloOrdering ordering(module_.get(), sequence); + SequentialHloOrdering ordering(schedule); // 'add' is live out of the body and will interfere with an later instructions // such as 'dead_constant' and 'dead_negate'. 
@@ -1476,11 +1480,10 @@ TEST_P(HloDataflowAnalysisTest, OverlappedValuesSequentialOrder) { auto entry = module_->AddEntryComputation(builder.Build()); RunAnalysis(GetParam()); - SequentialHloOrdering::HloModuleSequence sequence; - std::vector<const HloInstruction*> order = {param, negate, exp, add}; - sequence.emplace(entry, order); - - SequentialHloOrdering ordering(module_.get(), sequence); + HloSchedule schedule(module_.get()); + schedule.set_sequence(entry, {param, negate, exp, add}); + TF_ASSERT_OK(schedule.Verify()); + SequentialHloOrdering ordering(schedule); EXPECT_TRUE(InstructionsMayInterfere(ordering, param, negate)); EXPECT_FALSE(InstructionsMayInterfere(ordering, param, exp)); diff --git a/tensorflow/compiler/xla/service/hlo_ordering.cc b/tensorflow/compiler/xla/service/hlo_ordering.cc index 0581d5c404..2105f7a349 100644 --- a/tensorflow/compiler/xla/service/hlo_ordering.cc +++ b/tensorflow/compiler/xla/service/hlo_ordering.cc @@ -18,6 +18,7 @@ limitations under the License. #include <utility> #include <vector> +#include "absl/strings/str_cat.h" #include "absl/strings/str_format.h" #include "absl/strings/str_join.h" #include "tensorflow/compiler/xla/service/hlo_computation.h" @@ -252,6 +253,12 @@ bool HloOrdering::LiveRangeStrictlyBefore( VLOG(4) << a << " not defined before " << b; return false; } + + if (a.live_out_of_module()) { + VLOG(4) << a << " is live out of module and defined before " << b; + return false; + } + // All uses of 'a' must be before 'b' is defined. 
for (const HloUse& use : a.uses()) { if (dataflow.DoesNotUseOperandBuffer(a.instruction(), a.index(), @@ -264,6 +271,18 @@ bool HloOrdering::LiveRangeStrictlyBefore( return false; } } + + if (a.instruction()->parent() == b.instruction()->parent()) { + for (const HloPosition& position : a.positions()) { + if (position.instruction == + a.instruction()->parent()->root_instruction()) { + VLOG(4) << a << " is live out of computation and defined before " << b + << " which is in same computation"; + return false; + } + } + } + return true; } @@ -336,15 +355,24 @@ string DependencyHloOrdering::ToString() const { return ToStringHelper("DependencyHloOrdering"); } -SequentialHloOrdering::SequentialHloOrdering( - const HloModule* module, const HloModuleSequence& module_sequence) - : HloOrdering(module), module_sequence_(module_sequence) { +SequentialHloOrdering::SequentialHloOrdering(const HloSchedule& schedule) + : HloOrdering(schedule.module()), schedule_(schedule) { + Initialize(); +} + +SequentialHloOrdering::SequentialHloOrdering(HloSchedule&& schedule) + : HloOrdering(schedule.module()), schedule_(std::move(schedule)) { + Initialize(); +} + +void SequentialHloOrdering::Initialize() { // Create a map from instruction to its order position. 
- for (auto computation_order : module_sequence_) { - const std::vector<const HloInstruction*>& order = computation_order.second; + TF_DCHECK_OK(schedule_.Verify()); + for (const auto& computation_sequence : schedule_.sequences()) { + const std::vector<const HloInstruction*>& order = + computation_sequence.second.instructions(); for (int i = 0; i < order.size(); ++i) { - DCHECK_EQ(0, order_position_.count(order[i])); - order_position_.emplace(order[i], i); + InsertOrDie(&order_position_, order[i], i); } } } @@ -362,49 +390,13 @@ bool SequentialHloOrdering::ExecutesBeforeInSameComputation( const std::vector<const HloInstruction*>* SequentialHloOrdering::SequentialOrder( const HloComputation& computation) const { - auto find_it = module_sequence_.find(&computation); - return find_it == module_sequence_.end() ? nullptr : &find_it->second; + return schedule_.is_computation_scheduled(&computation) + ? &schedule_.sequence(&computation).instructions() + : nullptr; } string SequentialHloOrdering::ToString() const { - std::vector<string> pieces; - pieces.push_back("SequentialHloOrdering"); - for (auto* computation : module_->computations()) { - pieces.push_back( - absl::StrFormat("computation %s order:", computation->name())); - // Gather all instructions in the module sequence for this computation and - // sort them by their position. 
- std::vector<const HloInstruction*> instructions; - for (auto& instruction_position : order_position_) { - const HloInstruction* instruction = instruction_position.first; - if (instruction->parent() == computation) { - instructions.push_back(instruction); - } - } - std::sort(instructions.begin(), instructions.end(), - [this](const HloInstruction* a, const HloInstruction* b) { - return order_position_.at(a) < order_position_.at(b); - }); - for (auto instruction : instructions) { - pieces.push_back(absl::StrFormat(" %s", instruction->name())); - } - } - return absl::StrJoin(pieces, "\n"); -} - -std::ostream& operator<<( - std::ostream& out, - const SequentialHloOrdering::HloModuleSequence& module_sequence) { - for (auto computation_pair : module_sequence) { - const HloComputation* computation = computation_pair.first; - const std::vector<const HloInstruction*>& computation_sequence = - computation_pair.second; - out << "Computation " << computation->name() << ":\n"; - for (auto* instruction : computation_sequence) { - out << " " << instruction->name() << "\n"; - } - } - return out; + return absl::StrCat("SequentialHloOrdering\n", schedule_.ToString()); } } // namespace xla diff --git a/tensorflow/compiler/xla/service/hlo_ordering.h b/tensorflow/compiler/xla/service/hlo_ordering.h index 985f3fa64d..b21071c4b2 100644 --- a/tensorflow/compiler/xla/service/hlo_ordering.h +++ b/tensorflow/compiler/xla/service/hlo_ordering.h @@ -25,6 +25,7 @@ limitations under the License. 
#include "tensorflow/compiler/xla/service/hlo_dataflow_analysis.h" #include "tensorflow/compiler/xla/service/hlo_instruction.h" #include "tensorflow/compiler/xla/service/hlo_module.h" +#include "tensorflow/compiler/xla/service/hlo_schedule.h" #include "tensorflow/compiler/xla/service/hlo_value.h" #include "tensorflow/compiler/xla/types.h" #include "tensorflow/core/lib/gtl/flatmap.h" @@ -183,17 +184,8 @@ class DependencyHloOrdering : public PredecessorHloOrdering { // interference is reduced relative to DependencyHloOrdering. class SequentialHloOrdering : public HloOrdering { public: - // TODO(dimvar): HloModuleSequence is not a good name because it sounds like - // a sequence of modules, instead of a map of schedules for all computations - // in a module. We should change it at some point. - // - // A sequence of instructions for each computation in the module. - using HloModuleSequence = - tensorflow::gtl::FlatMap<const HloComputation*, - std::vector<const HloInstruction*>>; - - SequentialHloOrdering(const HloModule* module, - const HloModuleSequence& module_sequence); + SequentialHloOrdering(const HloSchedule& schedule); + SequentialHloOrdering(HloSchedule&& schedule); ~SequentialHloOrdering() override = default; // Returns the sequential instruction order for the given computation. 
@@ -203,10 +195,12 @@ class SequentialHloOrdering : public HloOrdering { string ToString() const override; protected: + void Initialize(); + bool ExecutesBeforeInSameComputation(const HloInstruction* a, const HloInstruction* b) const override; - const HloModuleSequence module_sequence_; + const HloSchedule schedule_; // The position of every instruction in the HLO module in its respective // computation sequence (a value of zero indicates the instruction is first in @@ -217,10 +211,6 @@ class SequentialHloOrdering : public HloOrdering { tensorflow::gtl::FlatMap<const HloInstruction*, int> order_position_; }; -std::ostream& operator<<( - std::ostream& out, - const SequentialHloOrdering::HloModuleSequence& module_sequence); - } // namespace xla #endif // TENSORFLOW_COMPILER_XLA_SERVICE_HLO_ORDERING_H_ diff --git a/tensorflow/compiler/xla/service/hlo_ordering_test.cc b/tensorflow/compiler/xla/service/hlo_ordering_test.cc index 126d3a2d9c..6b6005e7a5 100644 --- a/tensorflow/compiler/xla/service/hlo_ordering_test.cc +++ b/tensorflow/compiler/xla/service/hlo_ordering_test.cc @@ -23,11 +23,13 @@ limitations under the License. #include "tensorflow/compiler/xla/service/hlo_instruction.h" #include "tensorflow/compiler/xla/service/hlo_opcode.h" #include "tensorflow/compiler/xla/service/hlo_parser.h" +#include "tensorflow/compiler/xla/service/hlo_schedule.h" #include "tensorflow/compiler/xla/service/hlo_scheduling.h" #include "tensorflow/compiler/xla/shape_util.h" #include "tensorflow/compiler/xla/tests/hlo_test_base.h" #include "tensorflow/compiler/xla/types.h" #include "tensorflow/compiler/xla/xla_data.pb.h" +#include "tensorflow/core/lib/core/status_test_util.h" namespace xla { namespace { @@ -376,5 +378,104 @@ ENTRY root { dataflow->GetValueDefinedAt(add_3))); } +TEST_F(HloOrderingTest, + ValuesLiveOutOfModuleInterfereWithInstructionsAfterRoot) { + // Tests that values live out of the module should interfere with values + // defined after the root instruction. 
That is: + // + // %param = param(0) + // ROOT %root = negate(%param) + // %dead = Constant(123.0) + // + // %root should interfere with %dead. + auto module = CreateNewModule(); + const Shape scalar_shape = ShapeUtil::MakeShape(xla::F32, {}); + + auto builder = HloComputation::Builder(TestName()); + HloInstruction* param = builder.AddInstruction( + HloInstruction::CreateParameter(0, scalar_shape, "param")); + HloInstruction* root = builder.AddInstruction( + HloInstruction::CreateUnary(scalar_shape, HloOpcode::kNegate, param)); + HloInstruction* dead = builder.AddInstruction( + HloInstruction::CreateConstant(LiteralUtil::CreateR0<float>(123.0f))); + HloComputation* entry = + module->AddEntryComputation(builder.Build(/*root_instruction=*/root)); + + HloSchedule schedule(module.get()); + schedule.set_sequence(entry, {param, root, dead}); + TF_ASSERT_OK(schedule.Verify()); + SequentialHloOrdering ordering(schedule); + + TF_ASSERT_OK_AND_ASSIGN(auto dataflow, + HloDataflowAnalysis::Run(*module, /*ssa_form=*/true)); + + EXPECT_TRUE(ordering.ExecutesBefore(root, dead)); + EXPECT_FALSE(ordering.ExecutesBefore(dead, root)); + + EXPECT_FALSE(ordering.LiveRangeStrictlyBefore( + dataflow->GetValueDefinedAt(root), dataflow->GetValueDefinedAt(dead), + *dataflow)); + + EXPECT_TRUE(ordering.MayInterfere(dataflow->GetValueDefinedAt(root), + dataflow->GetValueDefinedAt(dead), + *dataflow)); +} + +TEST_F(HloOrderingTest, + ValuesLiveOutOfComputationInterfereWithInstructionsAfterRoot) { + // Tests that values live out of a computation should interfere with values + // defined after the root instruction of the computation. That is: + // + // subcomputation: + // %param = param(0) + // ROOT %root = negate(%param) + // %dead = Constant(123.0) + // + // entry computation: + // %c = constant(42.0) + // ROOT %call = call({%c}), subcomputation + // + // %root should interfere with %dead. 
+ auto module = CreateNewModule(); + const Shape scalar_shape = ShapeUtil::MakeShape(xla::F32, {}); + + auto subbuilder = HloComputation::Builder(TestName() + ".sub"); + HloInstruction* param = subbuilder.AddInstruction( + HloInstruction::CreateParameter(0, scalar_shape, "param")); + HloInstruction* root = subbuilder.AddInstruction( + HloInstruction::CreateUnary(scalar_shape, HloOpcode::kNegate, param)); + HloInstruction* dead = subbuilder.AddInstruction( + HloInstruction::CreateConstant(LiteralUtil::CreateR0<float>(123.0f))); + HloComputation* subcomputation = module->AddEmbeddedComputation( + subbuilder.Build(/*root_instruction=*/root)); + + auto builder = HloComputation::Builder(TestName()); + HloInstruction* c = builder.AddInstruction( + HloInstruction::CreateConstant(LiteralUtil::CreateR0<float>(42.0f))); + HloInstruction* call = builder.AddInstruction( + HloInstruction::CreateCall(scalar_shape, {c}, subcomputation)); + HloComputation* entry = module->AddEntryComputation(builder.Build()); + + HloSchedule schedule(module.get()); + schedule.set_sequence(subcomputation, {param, root, dead}); + schedule.set_sequence(entry, {c, call}); + TF_ASSERT_OK(schedule.Verify()); + SequentialHloOrdering ordering(schedule); + + TF_ASSERT_OK_AND_ASSIGN(auto dataflow, + HloDataflowAnalysis::Run(*module, /*ssa_form=*/true)); + + EXPECT_TRUE(ordering.ExecutesBefore(root, dead)); + EXPECT_FALSE(ordering.ExecutesBefore(dead, root)); + + EXPECT_FALSE(ordering.LiveRangeStrictlyBefore( + dataflow->GetValueDefinedAt(root), dataflow->GetValueDefinedAt(dead), + *dataflow)); + + EXPECT_TRUE(ordering.MayInterfere(dataflow->GetValueDefinedAt(root), + dataflow->GetValueDefinedAt(dead), + *dataflow)); +} + } // namespace } // namespace xla diff --git a/tensorflow/compiler/xla/service/hlo_rematerialization.cc b/tensorflow/compiler/xla/service/hlo_rematerialization.cc index c9629926ea..0a0a6a323e 100644 --- a/tensorflow/compiler/xla/service/hlo_rematerialization.cc +++ 
b/tensorflow/compiler/xla/service/hlo_rematerialization.cc @@ -962,8 +962,7 @@ StatusOr<int64> HloRematerialization::CalledComputationsMemoryUsage( } StatusOr<bool> HloRematerialization::RematerializeComputation( - HloComputation* computation, - SequentialHloOrdering::HloModuleSequence* sequence, + HloComputation* computation, HloSchedule* schedule, int64 memory_limit_bytes) { VLOG(1) << "Rematerializing computation " << computation->name() << " with limit " << HumanReadableNumBytes(memory_limit_bytes); @@ -971,7 +970,8 @@ StatusOr<bool> HloRematerialization::RematerializeComputation( << HumanReadableNumBytes(computation_peak_memory_.at(computation)); CHECK(!ContainsKey(rematerialized_computations_, computation)); - InstructionList instruction_list(sequence->at(computation)); + InstructionList instruction_list( + schedule->sequence(computation).instructions()); MemoryUsageTracker memory_tracker(computation, size_function_, *points_to_analysis_, instruction_list); bool changed = false; @@ -1145,7 +1145,7 @@ StatusOr<bool> HloRematerialization::RematerializeComputation( 0, memory_limit_bytes - memory_tracker.memory_usage()); TF_ASSIGN_OR_RETURN( bool subcomputation_changed, - RematerializeComputation(called_computation, sequence, + RematerializeComputation(called_computation, schedule, subcomputation_memory_limit_bytes)); changed |= subcomputation_changed; } @@ -1179,12 +1179,12 @@ StatusOr<bool> HloRematerialization::RematerializeComputation( computation_peak_memory_.at(computation) = peak_memory; // Update order to include rematerialized instructions. 
- auto& dst = sequence->at(computation); - dst.clear(); + HloInstructionSequence& sequence = schedule->GetOrCreateSequence(computation); + sequence.clear(); for (auto* item = instruction_list.first(); item != nullptr; item = instruction_list.next(item)) { const HloInstruction* instruction = item->instruction; - dst.push_back(instruction); + sequence.push_back(instruction); } rematerialized_computations_.insert(computation); @@ -1194,20 +1194,21 @@ StatusOr<bool> HloRematerialization::RematerializeComputation( return changed; } -StatusOr<bool> HloRematerialization::Run( - HloModule* module, SequentialHloOrdering::HloModuleSequence* sequence, - int64 memory_limit_bytes, RematerializationSizes* sizes, - CopyInsertion* copy_insertion) { - // The sequence is constructed entirely by this method. - TF_RET_CHECK(sequence->empty()); +StatusOr<bool> HloRematerialization::Run(HloModule* module, + HloSchedule* schedule, + int64 memory_limit_bytes, + RematerializationSizes* sizes, + CopyInsertion* copy_insertion) { + // The schedule is constructed entirely by this method. + TF_RET_CHECK(schedule->empty()); VLOG(1) << "HloRematerialization() with memory limit of " << HumanReadableNumBytes(memory_limit_bytes); XLA_VLOG_LINES(3, "Before HloRematerialization:\n" + module->ToString()); - // Create initial sequence of HLO instructions. - TF_ASSIGN_OR_RETURN(*sequence, ScheduleComputationsInModule( - *module, + // Create initial schedule of HLO instructions. + TF_ASSIGN_OR_RETURN(*schedule, + ScheduleModule(*module, [this](const BufferValue& buffer) { return size_function_(buffer.shape()); }, @@ -1217,16 +1218,7 @@ StatusOr<bool> HloRematerialization::Run( // ordering from the HLO schedule allows for more copies to be eliminated. // TODO(b/80249101): Instead of a separate copy elision pass, use the // ordering from the HLO schedule directly for copy insertion. - - // First create a copy of the schedule which contains HloInstruction unique - // ids instead of HloInstruction*. 
This is necessary for updating the - // schedule below. - // TODO(b/113175018): Remove this when the HLO schedule is self-contained - // and can update itself. - tensorflow::gtl::FlatMap<const HloComputation*, std::vector<int>> - id_sequence = ComputeIdSchedule(*sequence); - - SequentialHloOrdering ordering(module, *sequence); + SequentialHloOrdering ordering(*schedule); TF_RETURN_IF_ERROR( copy_insertion->RemoveUnnecessaryCopies(ordering, module)); @@ -1241,10 +1233,10 @@ StatusOr<bool> HloRematerialization::Run( // The passes above can add and remove copies, update the schedule to // account for these transformations. Newly added instructions will be // placed ASAP in the schedule. - TF_RETURN_IF_ERROR(UpdateSchedule(*module, id_sequence, sequence)); + TF_RETURN_IF_ERROR(schedule->Update()); TF_DCHECK_OK(copy_insertion->VerifyNoLiveRangeInterference( - SequentialHloOrdering(module, *sequence), module)); + SequentialHloOrdering(*schedule), module)); } TF_ASSIGN_OR_RETURN(points_to_analysis_, TuplePointsToAnalysis::Run(module)); @@ -1271,12 +1263,13 @@ StatusOr<bool> HloRematerialization::Run( // sequential context. call_graph_ = CallGraph::Build(module); TF_RETURN_IF_ERROR(call_graph_->VisitNodes( - [this, sequence](const CallGraphNode& node) -> Status { + [this, schedule](const CallGraphNode& node) -> Status { if (node.context() == CallContext::kSequential) { TF_ASSIGN_OR_RETURN( computation_peak_memory_[node.computation()], - ComputePeakMemory(node.computation(), - sequence->at(node.computation()))); + ComputePeakMemory( + node.computation(), + schedule->sequence(node.computation()).instructions())); } return Status::OK(); }, @@ -1295,7 +1288,7 @@ StatusOr<bool> HloRematerialization::Run( // Subcomputations called by the entry computation will also be // rematerialized. 
TF_ASSIGN_OR_RETURN(bool changed, RematerializeComputation( - module->entry_computation(), sequence, + module->entry_computation(), schedule, adjusted_memory_limit_bytes)); // Rematerialization can introduce dead code. This occurs if all uses of an @@ -1305,30 +1298,7 @@ StatusOr<bool> HloRematerialization::Run( // After DCE, the module sequence may include instructions which no longer // exist. - for (const auto* computation : module->MakeNonfusionComputations()) { - if (sequence->at(computation).size() != computation->instruction_count()) { - // A size mismatch between the computation instruction count and the size - // of the ordering of instructions can only be caused by DCE. Rebuild the - // order by removing the deleted instructions from the order. - tensorflow::gtl::FlatSet<const HloInstruction*> instruction_set; - for (const auto& instruction : computation->instructions()) { - instruction_set.insert(instruction); - } - // Move the old order into a temporary vector, then build new order - // inplace. 
- std::vector<const HloInstruction*>& order = sequence->at(computation); - std::vector<const HloInstruction*> old_order; - using std::swap; - swap(order, old_order); - std::copy_if(old_order.begin(), old_order.end(), - std::back_inserter(order), - [&instruction_set](const HloInstruction* instruction) { - return ContainsKey(instruction_set, instruction); - }); - TF_RET_CHECK(sequence->at(computation).size() == - computation->instruction_count()); - } - } + TF_RETURN_IF_ERROR(schedule->Update()); VLOG(1) << "Rematerialized " << instructions_rematerialized_ << " instructions in module " << module->name() << "; " << net_instructions_added_ << " net instructions added"; @@ -1366,11 +1336,10 @@ StatusOr<bool> HloRematerialization::Run( /* static */ StatusOr<bool> HloRematerialization::RematerializeAndSchedule( const HloRematerialization::ShapeSizeFunction& size_function, int64 memory_limit_bytes, HloModule* hlo_module, - MemorySchedulerAlgorithm scheduler_algorithm, - SequentialHloOrdering::HloModuleSequence* sequence, + MemorySchedulerAlgorithm scheduler_algorithm, HloSchedule* schedule, RematerializationSizes* sizes, CopyInsertion* copy_insertion) { HloRematerialization remat(scheduler_algorithm, size_function); - return remat.Run(hlo_module, sequence, memory_limit_bytes, sizes, + return remat.Run(hlo_module, schedule, memory_limit_bytes, sizes, copy_insertion); } diff --git a/tensorflow/compiler/xla/service/hlo_rematerialization.h b/tensorflow/compiler/xla/service/hlo_rematerialization.h index 2ec004350a..fa0414b472 100644 --- a/tensorflow/compiler/xla/service/hlo_rematerialization.h +++ b/tensorflow/compiler/xla/service/hlo_rematerialization.h @@ -21,6 +21,7 @@ #include "tensorflow/compiler/xla/service/hlo_computation.h" #include "tensorflow/compiler/xla/service/hlo_instruction.h" #include "tensorflow/compiler/xla/service/hlo_module.h" +#include "tensorflow/compiler/xla/service/hlo_schedule.h" #include "tensorflow/compiler/xla/service/hlo_scheduling.h" #include 
"tensorflow/compiler/xla/service/tuple_points_to_analysis.h" @@ -50,7 +51,7 @@ class HloRematerialization { // // hlo_module: HLO module to rematerialize instructions in. // - // sequence: Should point to an empty HloModuleSequence. Upon return + // schedule: Should point to an empty HloSchedule. Upon return // contains the HLO instruction order which was used for // rematerialization. This is the order in which HLO instructions should // be emitted to minimize memory use. @@ -75,8 +76,8 @@ class HloRematerialization { static StatusOr<bool> RematerializeAndSchedule( const ShapeSizeFunction& size_function, int64 memory_limit_bytes, HloModule* hlo_module, MemorySchedulerAlgorithm scheduler_algorithm, - SequentialHloOrdering::HloModuleSequence* sequence, - RematerializationSizes* sizes, CopyInsertion* copy_insertion = nullptr); + HloSchedule* schedule, RematerializationSizes* sizes, + CopyInsertion* copy_insertion = nullptr); protected: HloRematerialization(MemorySchedulerAlgorithm scheduler_algorithm, @@ -87,10 +88,9 @@ class HloRematerialization { // Runs rematerialization on the given module. Returns whether the module was // changed. memory_limit is the target maximum peak memory usage by the - // module. sequence should be an empty HloModuleSequence. Upon return sequence + // module. schedule should be an empty HloSchedule. Upon return schedule // contains the memory-minimizing order in which to emit the HLO instructions. - StatusOr<bool> Run(HloModule* module, - SequentialHloOrdering::HloModuleSequence* sequence, + StatusOr<bool> Run(HloModule* module, HloSchedule* schedule, int64 memory_limit, RematerializationSizes* sizes, CopyInsertion* copy_insertion); @@ -98,10 +98,9 @@ class HloRematerialization { // order in which the computation's instructions will be emitted in the // backend. Rematerialized instructions will be added to the HLO computation // and inserted into 'order'. 
- StatusOr<bool> RematerializeComputation( - HloComputation* computation, - SequentialHloOrdering::HloModuleSequence* sequence, - int64 computation_memory_limit); + StatusOr<bool> RematerializeComputation(HloComputation* computation, + HloSchedule* schedule, + int64 memory_limit_bytes); // Computes and returns the peak memory used by the given computation. The // peak memory is the maximum total size of all live HLO instruction values at diff --git a/tensorflow/compiler/xla/service/hlo_rematerialization_test.cc b/tensorflow/compiler/xla/service/hlo_rematerialization_test.cc index ac8c97d380..83cb113bfb 100644 --- a/tensorflow/compiler/xla/service/hlo_rematerialization_test.cc +++ b/tensorflow/compiler/xla/service/hlo_rematerialization_test.cc @@ -141,13 +141,13 @@ class HloRematerializationTest : public HloTestBase { return ShapeUtil::ByteSizeOf(shape, sizeof(void*)); } - StatusOr<bool> RunHloRematerialization( - int64 memory_limit_bytes, HloModule* module, - SequentialHloOrdering::HloModuleSequence* sequence) { + StatusOr<bool> RunHloRematerialization(int64 memory_limit_bytes, + HloModule* module, + HloSchedule* schedule) { TF_EXPECT_OK(verifier().Run(module).status()); return HloRematerialization::RematerializeAndSchedule( ByteSizeOf, memory_limit_bytes, module, DefaultMemoryScheduler, - sequence, /*sizes=*/nullptr); + schedule, /*sizes=*/nullptr); } // Various shapes used in the canned computations. @@ -170,12 +170,12 @@ TEST_F(HloRematerializationTest, SingleComputation) { const HloInstruction* concat = slice->operand(0); const HloInstruction* bcast = concat->operand(0); - SequentialHloOrdering::HloModuleSequence sequence; + HloSchedule schedule(module.get()); // Computation requires 16KB without rematerialization, but uses only 12KB // with rematerialization so pick a memory limit between these values (14KB). 
TF_ASSERT_OK_AND_ASSIGN(bool changed, RunHloRematerialization( /*memory_limit_bytes=*/14 * 1024, - module.get(), &sequence)); + module.get(), &schedule)); EXPECT_TRUE(changed); // Root should not have changed. @@ -187,9 +187,11 @@ TEST_F(HloRematerializationTest, SingleComputation) { // The rematerialized broadcast should be immediate before the concat in the // sequence. - EXPECT_EQ(sequence.at(computation)[computation->instruction_count() - 2], + EXPECT_EQ(schedule.sequence(computation) + .instructions()[computation->instruction_count() - 2], concat); - EXPECT_EQ(sequence.at(computation)[computation->instruction_count() - 3], + EXPECT_EQ(schedule.sequence(computation) + .instructions()[computation->instruction_count() - 3], remat_bcast); } @@ -203,10 +205,10 @@ TEST_F(HloRematerializationTest, SingleComputationNoRematerialization) { EXPECT_EQ(computation->instruction_count(), 8); - SequentialHloOrdering::HloModuleSequence sequence; + HloSchedule schedule(module.get()); TF_ASSERT_OK_AND_ASSIGN(bool changed, RunHloRematerialization( /*memory_limit_bytes=*/20 * 1024, - module.get(), &sequence)); + module.get(), &schedule)); // No instructions should have been materialized. EXPECT_FALSE(changed); @@ -242,10 +244,10 @@ TEST_F(HloRematerializationTest, RematerializeAroundWhile) { // The body computation uses 16KB and the entry computation uses 2KB at the // while so the peak memory use of the module is 18KB. Set the memory limit a // bit lower (17KB) to force rematerialization of the entry computation. - SequentialHloOrdering::HloModuleSequence sequence; + HloSchedule schedule(module.get()); TF_ASSERT_OK_AND_ASSIGN(bool changed, RunHloRematerialization( /*memory_limit_bytes=*/17 * 1024, - module.get(), &sequence)); + module.get(), &schedule)); EXPECT_TRUE(changed); // Only the entry computation should have a rematerialized instruction added. 
@@ -276,10 +278,10 @@ TEST_F(HloRematerializationTest, RematerializeEntryAndWhileBody) { EXPECT_EQ(entry_computation->instruction_count(), 7); EXPECT_EQ(body_computation->instruction_count(), 8); - SequentialHloOrdering::HloModuleSequence sequence; + HloSchedule schedule(module.get()); TF_ASSERT_OK_AND_ASSIGN(bool changed, RunHloRematerialization( /*memory_limit_bytes=*/15 * 1024, - module.get(), &sequence)); + module.get(), &schedule)); EXPECT_TRUE(changed); // Both computations should have rematerialized instructions added. @@ -316,10 +318,10 @@ TEST_F(HloRematerializationTest, RematerializeNestedComputations) { // If all computations are maximally rematerialized then peak memory usage is // ~12K so pick something slightly larger. - SequentialHloOrdering::HloModuleSequence sequence; + HloSchedule schedule(module.get()); TF_ASSERT_OK_AND_ASSIGN(bool changed, RunHloRematerialization( /*memory_limit_bytes=*/13 * 1024, - module.get(), &sequence)); + module.get(), &schedule)); EXPECT_TRUE(changed); // All computations should have rematerialized instructions added. @@ -382,14 +384,14 @@ TEST_F(HloRematerializationTest, RngNotRematerialized) { ASSERT_EQ(count_rngs(entry_computation), 1); const int64 original_instruction_count = entry_computation->instruction_count(); - SequentialHloOrdering::HloModuleSequence sequence; + HloSchedule schedule(module.get()); // Pick a memory limit some where between 24KB (initial peak memory including // parameter and output) and 20KB (peak memory possible with // rematerialization). TF_ASSERT_OK_AND_ASSIGN( bool changed, RunHloRematerialization( /*memory_limit_bytes=*/4 * ByteSizeOf(vec1024_shape_), - module.get(), &sequence)); + module.get(), &schedule)); EXPECT_TRUE(changed); // The rng should not have been rematerialized. 
EXPECT_EQ(count_rngs(entry_computation), 1); @@ -476,13 +478,13 @@ TEST_F(HloRematerializationTest, InstructionRematerializedMultipleTimes) { EXPECT_EQ(add_3->operand(0), bcast); EXPECT_EQ(add_4->operand(0), bcast); - SequentialHloOrdering::HloModuleSequence sequence; + HloSchedule schedule(module.get()); // Pick a memory limit some where between 24KB (initial peak memory including // parameter and output) and 20KB (peak memory possible with // rematerialization). TF_ASSERT_OK_AND_ASSIGN(bool changed, RunHloRematerialization( /*memory_limit_bytes=*/22 * 1024, - module.get(), &sequence)); + module.get(), &schedule)); EXPECT_TRUE(changed); // The broadcast should have been rematerialized 3 times. @@ -571,13 +573,13 @@ TEST_P(IndirectUseTest, IndirectUseNotRematerialized) { EXPECT_EQ(entry_computation->instruction_count(), 8); - SequentialHloOrdering::HloModuleSequence sequence; + HloSchedule schedule(module.get()); // Pick a memory limit some where between 24KB (initial peak memory including // parameter and output) and 20KB (peak memory possible with // rematerialization). TF_ASSERT_OK_AND_ASSIGN(bool changed, RunHloRematerialization( /*memory_limit_bytes=*/22 * 1024, - module.get(), &sequence)); + module.get(), &schedule)); // Rematerialization should only occur if the rematerializable instruction has // no indirect uses. if (indirectly_used) { diff --git a/tensorflow/compiler/xla/service/hlo_schedule.cc b/tensorflow/compiler/xla/service/hlo_schedule.cc new file mode 100644 index 0000000000..a65b33bf40 --- /dev/null +++ b/tensorflow/compiler/xla/service/hlo_schedule.cc @@ -0,0 +1,291 @@ +/* Copyright 2018 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. 
+You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + +#include "tensorflow/compiler/xla/service/hlo_schedule.h" + +#include <queue> +#include <vector> + +#include "absl/strings/str_format.h" +#include "absl/strings/str_join.h" +#include "tensorflow/compiler/xla/map_util.h" +#include "tensorflow/compiler/xla/status_macros.h" +#include "tensorflow/compiler/xla/util.h" +#include "tensorflow/core/lib/gtl/map_util.h" + +namespace xla { + +void HloSchedule::set_sequence( + const HloComputation* computation, + absl::Span<const HloInstruction* const> sequence) { + set_sequence(computation, HloInstructionSequence(sequence)); +} + +void HloSchedule::set_sequence(const HloComputation* computation, + HloInstructionSequence sequence) { + CHECK(computation->parent() == module_); + sequences_[computation->unique_id()] = std::move(sequence); +} + +HloInstructionSequence& HloSchedule::GetOrCreateSequence( + const HloComputation* computation) { + auto it = sequences_.find(computation->unique_id()); + if (it == sequences_.end()) { + // No sequence found for computation. Create and return an empty one. + CHECK(computation->parent() == module_); + return sequences_[computation->unique_id()]; + } else { + return it->second; + } +} + +const HloInstructionSequence& HloSchedule::sequence( + const HloComputation* computation) const { + return sequences_.at(computation->unique_id()); +} + +Status HloSchedule::UpdateComputationSchedule( + const HloComputation* computation) { + // Map from unique ID to HloInstruction pointer for instructions in the + // computation. 
+ tensorflow::gtl::FlatMap<int, const HloInstruction*> id_to_instruction; + for (const HloInstruction* instruction : computation->instructions()) { + InsertOrDie(&id_to_instruction, instruction->unique_id(), instruction); + } + + // Set of unique IDs of all HloInstructions in the schedule. + tensorflow::gtl::FlatSet<int> ids_in_schedule; + for (int id : sequences_.at(computation->unique_id()).ids()) { + InsertOrDie(&ids_in_schedule, id); + } + + // Map from HloInstruction X to newly added instructions (instruction is in + // computation, but not in schedule) which use X. If an instruction is not in + // the map, then it has no users which are newly added instructions. + tensorflow::gtl::FlatMap<const HloInstruction*, + std::vector<const HloInstruction*>> + new_instruction_uses; + + // For each newly added instruction, this is the count of the instruction's + // operands that have not yet been scheduled. When this value reaches zero, + // then the instruction may be placed in the schedule. + tensorflow::gtl::FlatMap<const HloInstruction*, int> + unscheduled_operand_count; + + // Create a worklist of newly added instructions which are ready to be added + // to the schedule. Initialize worklist with those that have zero operands. + std::queue<const HloInstruction*> worklist; + + for (const HloInstruction* instruction : computation->instructions()) { + if (ids_in_schedule.count(instruction->unique_id()) == 0) { + // This is a newly added instruction which is not in the schedule. + if (instruction->operands().empty()) { + worklist.push(instruction); + } else { + for (const HloInstruction* operand : instruction->operands()) { + new_instruction_uses[operand].push_back(instruction); + } + unscheduled_operand_count[instruction] = instruction->operand_count(); + } + } + } + + // Update the schedule with the newly added instructions, and remove any + // instructions no longer in the graph. 
+ HloInstructionSequence new_sequence; + + // Lambda which schedules all instructions on the worklist. + auto schedule_worklist = [&]() { + while (!worklist.empty()) { + const HloInstruction* instruction = worklist.front(); + worklist.pop(); + new_sequence.push_back(instruction); + std::vector<const HloInstruction*>* new_users = + tensorflow::gtl::FindOrNull(new_instruction_uses, instruction); + if (new_users != nullptr) { + // This just-scheduled instruction has users which are newly added to + // the module. Update the number of unscheduled operands and push the + // newly added instruction to the worklist if it is ready to + // schedule. + for (const HloInstruction* new_user : *new_users) { + unscheduled_operand_count.at(new_user)--; + CHECK_GE(unscheduled_operand_count.at(new_user), 0); + if (unscheduled_operand_count.at(new_user) == 0) { + worklist.push(new_user); + } + } + } + } + }; + + schedule_worklist(); + for (int id : sequences_.at(computation->unique_id()).ids()) { + auto it = id_to_instruction.find(id); + if (it == id_to_instruction.end()) { + // This instruction in the schedule is no longer in the module. Do not add + // it to the new schedule. + continue; + } + worklist.push(it->second); + schedule_worklist(); + } + + set_sequence(computation, std::move(new_sequence)); + return Status::OK(); +} + +Status HloSchedule::Update() { + // The schedule must contain a sequence for every non-fusion computation in + // the module, but can have sequences for computations which no longer exist + // (these are removed). 
+ std::vector<HloComputation*> nonfusion_computations = + module_->MakeNonfusionComputations(); + for (const HloComputation* computation : nonfusion_computations) { + TF_RET_CHECK(sequences_.count(computation->unique_id()) == 1) + << "Computation " << computation->name() << " not in HloSchedule."; + } + if (sequences_.size() > nonfusion_computations.size()) { + // Schedule contains some computations which have been removed from the + // HloModule. Remove them from the schedule as well. + tensorflow::gtl::FlatSet<int64> nonfusion_computations_ids; + for (const HloComputation* computation : nonfusion_computations) { + nonfusion_computations_ids.insert(computation->unique_id()); + } + for (auto it = sequences_.begin(); it != sequences_.end();) { + if (nonfusion_computations_ids.count(it->first) == 0) { + it = sequences_.erase(it); + } else { + it++; + } + } + } + CHECK_EQ(sequences_.size(), nonfusion_computations.size()); + + for (const HloComputation* computation : nonfusion_computations) { + TF_RETURN_IF_ERROR(UpdateComputationSchedule(computation)); + } + + TF_RETURN_IF_ERROR(Verify()); + return Status::OK(); +} + +Status HloSchedule::Verify() const { + VLOG(2) << "VerifySchedule()"; + XLA_VLOG_LINES(3, module_->ToString()); + XLA_VLOG_LINES(2, ToString()); + + // Verify schedule contains exactly the same set of non-fusion computations as + // module currently does. 
+ std::vector<HloComputation*> nonfusion_computations = + module_->MakeNonfusionComputations(); + TF_RET_CHECK(nonfusion_computations.size() == sequences_.size()) + << "Schedule has " << sequences_.size() << " sequences, but module has " + << nonfusion_computations.size() << " non-fusion computations"; + for (const HloComputation* computation : nonfusion_computations) { + TF_RET_CHECK(sequences_.count(computation->unique_id()) == 1) + << "Computation " << computation->name() + << " missing from HLO schedule."; + } + + // For each computation verify the set of instructions is the same and that + // each dependency and control edge is honored. + for (const HloComputation* computation : nonfusion_computations) { + tensorflow::gtl::FlatMap<const HloInstruction*, int> instruction_position; + int pos = 0; + for (const HloInstruction* instruction : + sequence(computation).instructions()) { + TF_RET_CHECK(instruction_position.insert({instruction, pos}).second) + << "Instruction " << instruction->name() + << " appears more than once in the schedule"; + pos++; + } + + TF_RET_CHECK(instruction_position.size() == + computation->instruction_count()); + for (const HloInstruction* instruction : computation->instructions()) { + TF_RET_CHECK(instruction_position.count(instruction) == 1) + << "Instruction " << instruction->name() << " is not in schedule"; + } + + for (const HloInstruction* instruction : computation->instructions()) { + for (const HloInstruction* operand : instruction->operands()) { + TF_RET_CHECK(instruction_position.at(operand) < + instruction_position.at(instruction)) + << "Instruction " << instruction->name() + << " is not scheduled after its operand " << operand->name(); + } + + for (const HloInstruction* pred : instruction->control_predecessors()) { + TF_RET_CHECK(instruction_position.at(pred) < + instruction_position.at(instruction)) + << "Instruction " << instruction->name() + << " is not scheduled after its control predecessor " + << pred->name(); + } + } + 
} + + return Status::OK(); +} + +namespace { + +// Returns the computation in the given module with the given unique ID. Returns +// nullptr if no such computation exists. +const HloComputation* IdToComputation(const HloModule* module, int64 id) { + for (const HloComputation* computation : module->computations()) { + if (computation->unique_id() == id) { + return computation; + } + } + return nullptr; +} + +} // namespace + +string HloSchedule::ToString() const { + std::vector<string> pieces; + + pieces.push_back("HloSchedule"); + for (const auto& id_sequence : sequences_) { + const HloComputation* computation = + IdToComputation(module_, id_sequence.first); + if (computation == nullptr) { + // The computation is not in the module and may have been deleted so it is + // not safe to dereference any HLO pointers. Just use the HLO unique ids + // stored in this object. + pieces.push_back( + absl::StrFormat("computation with id %d (no longer in HLO module):", + id_sequence.first)); + for (int id : id_sequence.second.ids()) { + pieces.push_back(absl::StrCat(" ", id)); + } + } else { + pieces.push_back(absl::StrFormat("computation %s:", computation->name())); + for (const HloInstruction* instruction : + id_sequence.second.instructions()) { + pieces.push_back(absl::StrCat(" ", instruction->name())); + } + } + } + return absl::StrJoin(pieces, "\n"); +} + +std::ostream& operator<<(std::ostream& out, const HloSchedule& schedule) { + out << schedule.ToString(); + return out; +} + +} // namespace xla diff --git a/tensorflow/compiler/xla/service/hlo_schedule.h b/tensorflow/compiler/xla/service/hlo_schedule.h new file mode 100644 index 0000000000..21c6988638 --- /dev/null +++ b/tensorflow/compiler/xla/service/hlo_schedule.h @@ -0,0 +1,151 @@ +/* Copyright 2018 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. 
+You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + +#ifndef TENSORFLOW_COMPILER_XLA_SERVICE_HLO_SCHEDULE_H_ +#define TENSORFLOW_COMPILER_XLA_SERVICE_HLO_SCHEDULE_H_ + +#include <vector> + +#include "absl/types/span.h" +#include "tensorflow/compiler/xla/service/hlo_computation.h" +#include "tensorflow/compiler/xla/service/hlo_instruction.h" +#include "tensorflow/compiler/xla/service/hlo_module.h" +#include "tensorflow/compiler/xla/service/hlo_schedule.h" +#include "tensorflow/compiler/xla/status.h" + +namespace xla { + +// Class representing a sequence of HLO instructions such as the sequential +// execution order of an HLO computation. +class HloInstructionSequence { + public: + HloInstructionSequence() = default; + HloInstructionSequence(absl::Span<const HloInstruction* const> instructions) { + for (const HloInstruction* instruction : instructions) { + push_back(instruction); + } + } + + // Adds the instruction to the end of the sequence. + void push_back(const HloInstruction* instruction) { + instruction_sequence_.push_back(instruction); + id_sequence_.push_back(instruction->unique_id()); + } + + // Clears the sequence of all instructions. + void clear() { + instruction_sequence_.clear(); + id_sequence_.clear(); + } + + int64 size() const { return instruction_sequence_.size(); } + + // Returns the sequence of HLO instructions. + const std::vector<const HloInstruction*>& instructions() const { + return instruction_sequence_; + } + + // Returns the unique IDs of the instructions in the sequence (in order). 
+ const std::vector<int>& ids() const { return id_sequence_; } + + private: + // The sequence as HloInstructions. + std::vector<const HloInstruction*> instruction_sequence_; + + // The sequence of HLO instructions, represented by their unique IDs. The + // sequence is stored as both HloInstructions and unique IDs because the + // sequence may be referenced after transformations to the HLO graph and HLO + // pointers can be invalidated or recycled in this process (see + // HloSchedule::Update). + std::vector<int> id_sequence_; +}; + +// A class representing a sequential schedule of instructions for an HLO +// module. A complete HLO schedule contains an instruction sequence for every +// non-fusion computation in the HLO module. +class HloSchedule { + public: + HloSchedule(const HloModule* module) : module_(module) {} + + // Returns a reference to the sequence for the given computation. + const HloInstructionSequence& sequence( + const HloComputation* computation) const; + + // Returns the sequence for the given computation. An empty sequence is + // created if none exists for the computation. + HloInstructionSequence& GetOrCreateSequence( + const HloComputation* computation); + + // Sets the sequence for the given computation to the given sequence. + void set_sequence(const HloComputation* computation, + absl::Span<const HloInstruction* const> sequence); + void set_sequence(const HloComputation* computation, + HloInstructionSequence sequence); + + // Returns a map from HloComputation unique ID to instruction sequence. The + // map contains all sequences in the schedule. + const tensorflow::gtl::FlatMap<int64, HloInstructionSequence>& sequences() + const { + return sequences_; + } + + // Returns true if the schedule has a sequence for the given computation. 
+ bool is_computation_scheduled(const HloComputation* computation) const { + return sequences_.count(computation->unique_id()) == 1; + } + + // Updates the schedule such that it is (again) a valid schedule for the + // module. This is used to update a schedule after the HLO module has been + // transformed in some way. In general, the only transformations to the module + // for which a schedule can be updated are the addition or removal of + // instructions and removal of computations. Updating the schedule after new + // dependencies between existing instructions in the module is not supported + // and may result in an error status returned. + // + // Instructions in the module which also exist in the given schedule will + // remain in the same order in the updated schedule. Instructions which exist + // in the module but not in the given schedule will be placed as early as + // possible in the updated schedule. + Status Update(); + + // Verifies that the given schedule is valid for the given module. + // Specifically, the schedule contains exactly the instructions in the + // non-fusion computations in the module and every dependency in the module is + // satisfied in the schedule. + Status Verify() const; + + string ToString() const; + + bool empty() const { return sequences_.empty(); } + + const HloModule* module() const { return module_; } + + private: + // Updates the instruction sequence for the given computation. + Status UpdateComputationSchedule(const HloComputation* computation); + + const HloModule* module_; + + // A map from computation unique ID to instruction sequence. Unique IDs are + // used rather than HloComputation pointers because HLO pointers are not + // unique across HLO transformations because pointers may be recycled. 
+ tensorflow::gtl::FlatMap<int64, HloInstructionSequence> sequences_; +}; + +std::ostream& operator<<(std::ostream& out, const HloSchedule& schedule); + +} // namespace xla + +#endif // TENSORFLOW_COMPILER_XLA_SERVICE_HLO_SCHEDULE_H_ diff --git a/tensorflow/compiler/xla/service/hlo_schedule_test.cc b/tensorflow/compiler/xla/service/hlo_schedule_test.cc new file mode 100644 index 0000000000..eb52582bb5 --- /dev/null +++ b/tensorflow/compiler/xla/service/hlo_schedule_test.cc @@ -0,0 +1,341 @@ +/* Copyright 2018 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. 
+==============================================================================*/ + +#include "tensorflow/compiler/xla/service/hlo_schedule.h" + +#include <memory> +#include <string> + +#include "absl/algorithm/container.h" +#include "tensorflow/compiler/xla/service/hlo_computation.h" +#include "tensorflow/compiler/xla/service/hlo_dce.h" +#include "tensorflow/compiler/xla/service/hlo_instruction.h" +#include "tensorflow/compiler/xla/service/hlo_opcode.h" +#include "tensorflow/compiler/xla/service/hlo_ordering.h" +#include "tensorflow/compiler/xla/service/hlo_parser.h" +#include "tensorflow/compiler/xla/service/hlo_scheduling.h" +#include "tensorflow/compiler/xla/shape_util.h" +#include "tensorflow/compiler/xla/tests/hlo_test_base.h" +#include "tensorflow/compiler/xla/types.h" +#include "tensorflow/compiler/xla/xla_data.pb.h" +#include "tensorflow/core/lib/core/status_test_util.h" + +namespace xla { +namespace { + +class HloScheduleTest : public HloTestBase {}; + +TEST_F(HloScheduleTest, UpdateScheduleUnchangedModule) { + // Updating the schedule of an unchanged HLO module should not affect the + // schedule at all. 
+ const string module_str = R"( +HloModule UpdateScheduleUnchanged + +ENTRY main { + a = f32[] parameter(0) + b = f32[] parameter(1) + c = f32[] constant(42.0) + sum = f32[] add(a, b) + neg = f32[] negate(c) + ROOT root = f32[] multiply(sum, neg) +} +)"; + TF_ASSERT_OK_AND_ASSIGN(std::unique_ptr<HloModule> module, + ParseHloString(module_str)); + TF_ASSERT_OK_AND_ASSIGN( + HloSchedule schedule, + ScheduleModule(*module, [](const BufferValue& buffer) { + return ShapeUtil::ByteSizeOf(buffer.shape()); + })); + const std::vector<const HloInstruction*>& entry_schedule = + schedule.sequence(module->entry_computation()).instructions(); + + EXPECT_EQ(entry_schedule.size(), 6); + + TF_ASSERT_OK(schedule.Update()); + TF_ASSERT_OK(schedule.Verify()); + + EXPECT_EQ(entry_schedule, + schedule.sequence(module->entry_computation()).instructions()); +} + +TEST_F(HloScheduleTest, UpdateScheduleWithNewInstructions) { + // Add some additional instructions to a module and verify the schedule can be + // updated. 
+ const string module_str = R"( +HloModule UpdateScheduleWithNewInstructions + +ENTRY main { + a = f32[] parameter(0) + b = f32[] parameter(1) + c = f32[] constant(42.0) + sum = f32[] add(a, b) + neg = f32[] negate(c) + ROOT root = f32[] multiply(sum, neg) +} +)"; + TF_ASSERT_OK_AND_ASSIGN(std::unique_ptr<HloModule> module, + ParseHloString(module_str)); + TF_ASSERT_OK_AND_ASSIGN( + HloSchedule schedule, + ScheduleModule(*module, [](const BufferValue& buffer) { + return ShapeUtil::ByteSizeOf(buffer.shape()); + })); + + HloComputation* entry = module->entry_computation(); + const Shape shape = entry->root_instruction()->shape(); + HloInstruction* constant = entry->AddInstruction( + HloInstruction::CreateConstant(LiteralUtil::CreateR0<float>(42.0))); + HloInstruction* sub = entry->AddInstruction(HloInstruction::CreateBinary( + shape, HloOpcode::kSubtract, constant, entry->root_instruction())); + entry->set_root_instruction(sub); + + auto in_schedule = [&](const HloInstruction* hlo) { + return absl::c_linear_search(schedule.sequence(entry).instructions(), hlo); + }; + + EXPECT_EQ(schedule.sequence(entry).size(), 6); + EXPECT_FALSE(in_schedule(constant)); + EXPECT_FALSE(in_schedule(sub)); + + ASSERT_IS_NOT_OK(schedule.Verify()); + TF_ASSERT_OK(schedule.Update()); + TF_ASSERT_OK(schedule.Verify()); + + EXPECT_EQ(schedule.sequence(entry).size(), 8); + EXPECT_TRUE(in_schedule(constant)); + EXPECT_TRUE(in_schedule(sub)); +} + +TEST_F(HloScheduleTest, UpdateScheduleWithAddedAndDeletedInstruction) { + // Add and delete some instructions from a module and verify that the schedule + // can be updated successfully. 
+ const string module_str = R"( +HloModule UpdateScheduleWithAddedAndDeletedInstruction + +ENTRY main { + a = f32[] parameter(0) + b = f32[] parameter(1) + c = f32[] constant(42.0) + sum = f32[] add(a, b) + neg = f32[] negate(c) + ROOT root = f32[] multiply(sum, neg) +} +)"; + + TF_ASSERT_OK_AND_ASSIGN(std::unique_ptr<HloModule> module, + ParseHloString(module_str)); + TF_ASSERT_OK_AND_ASSIGN( + HloSchedule schedule, + ScheduleModule(*module, [](const BufferValue& buffer) { + return ShapeUtil::ByteSizeOf(buffer.shape()); + })); + + // Set the entry root to some expression containing just a parameter and a + // constant. + HloComputation* entry = module->entry_computation(); + HloInstruction* constant = entry->AddInstruction( + HloInstruction::CreateConstant(LiteralUtil::CreateR0<float>(42.0))); + HloInstruction* new_root = entry->AddInstruction( + HloInstruction::CreateBinary(constant->shape(), HloOpcode::kSubtract, + constant, entry->parameter_instruction(0))); + entry->set_root_instruction(new_root); + + // DCE should remove everything but the parameters and the newly added code. + HloDCE dce; + TF_ASSERT_OK(dce.Run(module.get()).status()); + + EXPECT_EQ(schedule.sequence(entry).size(), 6); + + ASSERT_IS_NOT_OK(schedule.Verify()); + TF_ASSERT_OK(schedule.Update()); + TF_ASSERT_OK(schedule.Verify()); + + EXPECT_EQ(schedule.sequence(entry).size(), 4); +} + +TEST_F(HloScheduleTest, UpdateScheduleWithCompletelyReplacedModule) { + // Completely replace a module with an entirely new set of instructions and + // verify that the schedule can be updated successfully. 
+ const string module_str = R"( +HloModule UpdateScheduleWithCompletelyReplacedModule + +ENTRY main { + a = f32[] constant(42.0) + b = f32[] constant(123.0) + ROOT sum = f32[] add(a, b) +} +)"; + + TF_ASSERT_OK_AND_ASSIGN(std::unique_ptr<HloModule> module, + ParseHloString(module_str)); + TF_ASSERT_OK_AND_ASSIGN( + HloSchedule schedule, + ScheduleModule(*module, [](const BufferValue& buffer) { + return ShapeUtil::ByteSizeOf(buffer.shape()); + })); + + // Replace the entry computation with the negation of a constant. + HloComputation* entry = module->entry_computation(); + HloInstruction* constant = entry->AddInstruction( + HloInstruction::CreateConstant(LiteralUtil::CreateR0<float>(1.0))); + HloInstruction* new_root = entry->AddInstruction(HloInstruction::CreateUnary( + constant->shape(), HloOpcode::kNegate, constant)); + entry->set_root_instruction(new_root); + + // DCE the old instructions. + HloDCE dce; + TF_ASSERT_OK(dce.Run(module.get()).status()); + + EXPECT_EQ(schedule.sequence(entry).size(), 3); + + ASSERT_IS_NOT_OK(schedule.Verify()); + TF_ASSERT_OK(schedule.Update()); + TF_ASSERT_OK(schedule.Verify()); + + EXPECT_EQ(schedule.sequence(entry).size(), 2); +} + +TEST_F(HloScheduleTest, UpdateScheduleWithMultipleComputations) { + // Create changes to more than one computation in an HLO module and verify + // that the schedule can be updated. 
+ const string module_str = R"( +HloModule UpdateScheduleWithMultipleComputations + +%Body (param.1: (s32[], token[])) -> (s32[], token[]) { + %param.1 = (s32[], token[]) parameter(0) + %get-tuple-element.1 = s32[] get-tuple-element((s32[], token[]) %param.1), index=0 + %constant.1 = s32[] constant(1) + %add = s32[] add(s32[] %get-tuple-element.1, s32[] %constant.1) + %get-tuple-element.2 = token[] get-tuple-element((s32[], token[]) %param.1), index=1 + %after-all = token[] after-all(token[] %get-tuple-element.2) + ROOT %tuple = (s32[], token[]) tuple(s32[] %add, token[] %after-all) +} + +%Cond (param: (s32[], token[])) -> pred[] { + %param = (s32[], token[]) parameter(0) + %get-tuple-element = s32[] get-tuple-element((s32[], token[]) %param), index=0 + %constant = s32[] constant(42) + ROOT %less-than = pred[] less-than(s32[] %get-tuple-element, s32[] %constant) +} + +ENTRY %WhileLoop () -> s32[] { + %zero = s32[] constant(0) + %init_token = token[] after-all() + %init_tuple = (s32[], token[]) tuple(s32[] %zero, token[] %init_token) + %while = (s32[], token[]) while((s32[], token[]) %init_tuple), condition=%Cond, body=%Body + ROOT %root = s32[] get-tuple-element((s32[], token[]) %while), index=0 +} +)"; + + TF_ASSERT_OK_AND_ASSIGN(std::unique_ptr<HloModule> module, + ParseHloString(module_str)); + TF_ASSERT_OK_AND_ASSIGN( + HloSchedule schedule, + ScheduleModule(*module, [](const BufferValue& buffer) { + return ShapeUtil::ByteSizeOf(buffer.shape(), + /*pointer_size=*/sizeof(void*)); + })); + + const HloInstruction* xla_while = + module->entry_computation()->root_instruction()->operand(0); + HloComputation* body = xla_while->while_body(); + HloComputation* cond = xla_while->while_condition(); + + // Negate the root of the cond. 
+ cond->set_root_instruction(cond->AddInstruction( + HloInstruction::CreateUnary(ShapeUtil::MakeShape(PRED, {}), + HloOpcode::kNot, cond->root_instruction()))); + + // Replace the body with a computation which just passes through its + // parameter. + body->set_root_instruction(body->parameter_instruction(0)); + + // DCE the dead code in the body. + HloDCE dce; + TF_ASSERT_OK(dce.Run(module.get()).status()); + + EXPECT_EQ(schedule.sequence(body).size(), 7); + EXPECT_EQ(schedule.sequence(cond).size(), 4); + + ASSERT_IS_NOT_OK(schedule.Verify()); + TF_ASSERT_OK(schedule.Update()); + TF_ASSERT_OK(schedule.Verify()); + + EXPECT_EQ(schedule.sequence(body).size(), 1); + EXPECT_EQ(schedule.sequence(cond).size(), 5); +} + +TEST_F(HloScheduleTest, UpdateScheduleComputationRemoved) { + // Remove computations from a module and verify the schedule can be updated. + const string module_str = R"( +HloModule UpdateScheduleWithMultipleComputations + +%Body (param.1: (s32[], token[])) -> (s32[], token[]) { + %param.1 = (s32[], token[]) parameter(0) + %get-tuple-element.1 = s32[] get-tuple-element((s32[], token[]) %param.1), index=0 + %constant.1 = s32[] constant(1) + %add = s32[] add(s32[] %get-tuple-element.1, s32[] %constant.1) + %get-tuple-element.2 = token[] get-tuple-element((s32[], token[]) %param.1), index=1 + %after-all = token[] after-all(token[] %get-tuple-element.2) + ROOT %tuple = (s32[], token[]) tuple(s32[] %add, token[] %after-all) +} + +%Cond (param: (s32[], token[])) -> pred[] { + %param = (s32[], token[]) parameter(0) + %get-tuple-element = s32[] get-tuple-element((s32[], token[]) %param), index=0 + %constant = s32[] constant(42) + ROOT %less-than = pred[] less-than(s32[] %get-tuple-element, s32[] %constant) +} + +ENTRY %WhileLoop () -> s32[] { + %zero = s32[] constant(0) + %init_token = token[] after-all() + %init_tuple = (s32[], token[]) tuple(s32[] %zero, token[] %init_token) + %while = (s32[], token[]) while((s32[], token[]) %init_tuple), condition=%Cond, 
body=%Body + ROOT %root = s32[] get-tuple-element((s32[], token[]) %while), index=0 +} +)"; + + TF_ASSERT_OK_AND_ASSIGN(std::unique_ptr<HloModule> module, + ParseHloString(module_str)); + TF_ASSERT_OK_AND_ASSIGN( + HloSchedule schedule, + ScheduleModule(*module, [](const BufferValue& buffer) { + return ShapeUtil::ByteSizeOf(buffer.shape(), + /*pointer_size=*/sizeof(void*)); + })); + + HloInstruction* xla_while = + module->entry_computation()->root_instruction()->mutable_operand(0); + HloInstruction* init = xla_while->mutable_operand(0); + + // Replace the while with its init value. The conditional and body + // computations should then be dead. + TF_ASSERT_OK(xla_while->ReplaceAllUsesWith(init)); + + // DCE the dead code in the body. + HloDCE dce; + ASSERT_EQ(module->computation_count(), 3); + TF_ASSERT_OK(dce.Run(module.get()).status()); + ASSERT_EQ(module->computation_count(), 1); + + ASSERT_IS_NOT_OK(schedule.Verify()); + TF_ASSERT_OK(schedule.Update()); + TF_ASSERT_OK(schedule.Verify()); +} + +} // namespace +} // namespace xla diff --git a/tensorflow/compiler/xla/service/hlo_scheduling.cc b/tensorflow/compiler/xla/service/hlo_scheduling.cc index 0fc3b268c0..9bfb0af96c 100644 --- a/tensorflow/compiler/xla/service/hlo_scheduling.cc +++ b/tensorflow/compiler/xla/service/hlo_scheduling.cc @@ -70,7 +70,7 @@ class ListScheduler { public: // Construct and return a memory-minimizing sequence of HLO instructions // containing the given HLO computation. 
- static StatusOr<std::vector<const HloInstruction*>> Run( + static StatusOr<HloInstructionSequence> Run( const HloComputation& computation, const TuplePointsToAnalysis& points_to_analysis, const LogicalBuffer::SizeFunction& size_function, @@ -229,8 +229,8 @@ class ListScheduler { return {BytesFreedIfScheduled(entry), entry.instruction->user_count()}; } - std::vector<const HloInstruction*> CreateSchedule() { - std::vector<const HloInstruction*> schedule; + HloInstructionSequence CreateSchedule() { + HloInstructionSequence schedule; // Populate the ready list with instructions which have no operands or // control predecessors. @@ -374,7 +374,7 @@ int64 SumLogicalBufferSizes( return size; } -StatusOr<std::vector<const HloInstruction*>> ScheduleComputationHelper( +StatusOr<HloInstructionSequence> ScheduleComputationHelper( const HloComputation& computation, const TuplePointsToAnalysis& points_to_analysis, const LogicalBuffer::SizeFunction& size_function, @@ -392,7 +392,7 @@ StatusOr<std::vector<const HloInstruction*>> ScheduleComputationHelper( } // namespace -StatusOr<std::vector<const HloInstruction*>> DFSMemoryScheduler( +StatusOr<HloInstructionSequence> DFSMemoryScheduler( const HloComputation& computation, const TuplePointsToAnalysis& points_to_analysis, const LogicalBuffer::SizeFunction& size_function, @@ -443,7 +443,7 @@ StatusOr<std::vector<const HloInstruction*>> DFSMemoryScheduler( // Construct a total order based on DFS post-order, visiting operands in // decreasing cumulative extra user order, and next by cumulative size, with a // tiebreaker by name for determinism. 
- std::vector<const HloInstruction*> sequence; + HloInstructionSequence sequence; FunctionVisitor visitor([&sequence](HloInstruction* hlo) { sequence.push_back(hlo); return Status::OK(); @@ -463,7 +463,7 @@ StatusOr<std::vector<const HloInstruction*>> DFSMemoryScheduler( return sequence; } // namespace xla -StatusOr<std::vector<const HloInstruction*>> ListMemoryScheduler( +StatusOr<HloInstructionSequence> ListMemoryScheduler( const HloComputation& computation, const TuplePointsToAnalysis& points_to_analysis, const LogicalBuffer::SizeFunction& size_function, @@ -473,18 +473,16 @@ StatusOr<std::vector<const HloInstruction*>> ListMemoryScheduler( memory_by_computation); } -StatusOr<std::vector<const HloInstruction*>> PostOrderMemoryScheduler( +StatusOr<HloInstructionSequence> PostOrderMemoryScheduler( const HloComputation& computation, const TuplePointsToAnalysis& points_to_analysis, const LogicalBuffer::SizeFunction& size_function, const tensorflow::gtl::FlatMap<const HloComputation*, int64>& memory_by_computation) { - const auto& post_order = computation.MakeInstructionPostOrder(); - return std::vector<const HloInstruction*>{post_order.begin(), - post_order.end()}; + return HloInstructionSequence(computation.MakeInstructionPostOrder()); } -StatusOr<std::vector<const HloInstruction*>> DefaultMemoryScheduler( +StatusOr<HloInstructionSequence> DefaultMemoryScheduler( const HloComputation& computation, const TuplePointsToAnalysis& points_to_analysis, const LogicalBuffer::SizeFunction& size_function, @@ -499,7 +497,7 @@ StatusOr<std::vector<const HloInstruction*>> DefaultMemoryScheduler( // List wins for most of our benchmarks; postorder-based schedulers win for // some RNNs. 
TF_ASSIGN_OR_RETURN( - std::vector<const HloInstruction*> list_sequence, + HloInstructionSequence list_sequence, ListMemoryScheduler(computation, points_to_analysis, size_function, memory_by_computation)); TF_ASSIGN_OR_RETURN(const int64 list_memory, @@ -508,7 +506,7 @@ StatusOr<std::vector<const HloInstruction*>> DefaultMemoryScheduler( size_function, &memory_by_computation)); VLOG(2) << "Min-memory list sequence: " << HumanReadableNumBytes(list_memory); - TF_ASSIGN_OR_RETURN(std::vector<const HloInstruction*> dfs_sequence, + TF_ASSIGN_OR_RETURN(HloInstructionSequence dfs_sequence, DFSMemoryScheduler(computation, points_to_analysis, size_function, memory_by_computation)); TF_ASSIGN_OR_RETURN(const int64 dfs_memory, @@ -518,7 +516,7 @@ StatusOr<std::vector<const HloInstruction*>> DefaultMemoryScheduler( VLOG(2) << "Min-memory dfs sequence: " << HumanReadableNumBytes(dfs_memory); TF_ASSIGN_OR_RETURN( - std::vector<const HloInstruction*> post_order_sequence, + HloInstructionSequence post_order_sequence, PostOrderMemoryScheduler(computation, points_to_analysis, size_function, memory_by_computation)); TF_ASSIGN_OR_RETURN(const int64 post_order_memory, @@ -545,32 +543,35 @@ StatusOr<std::vector<const HloInstruction*>> DefaultMemoryScheduler( } } -StatusOr<SequentialHloOrdering::HloModuleSequence> ScheduleComputationsInModule( +StatusOr<HloSchedule> ScheduleModule( const HloModule& module, const LogicalBuffer::SizeFunction& size_function, const MemorySchedulerAlgorithm& algorithm) { - SequentialHloOrdering::HloModuleSequence sequence; + HloSchedule schedule(&module); TF_ASSIGN_OR_RETURN(std::unique_ptr<TuplePointsToAnalysis> points_to_analysis, TuplePointsToAnalysis::Run(&module)); tensorflow::gtl::FlatMap<const HloComputation*, int64> memory_by_computation; for (const auto* computation : module.MakeComputationPostOrder()) { if (!computation->IsFusionComputation()) { - TF_ASSIGN_OR_RETURN(auto one_computation_sequence, + TF_ASSIGN_OR_RETURN(HloInstructionSequence 
computation_sequence, ScheduleComputationHelper( *computation, *points_to_analysis, size_function, algorithm, memory_by_computation)); memory_by_computation[computation] = HeapSimulator::MinimumMemoryForComputation( - *computation, one_computation_sequence, *points_to_analysis, + *computation, computation_sequence, *points_to_analysis, size_function, &memory_by_computation) .ValueOrDie(); - sequence[computation] = std::move(one_computation_sequence); + schedule.set_sequence(computation, std::move(computation_sequence)); } } - VLOG(1) << "Module schedule:\n" << sequence; - return sequence; + VLOG(1) << "Module schedule:\n" << schedule; + + TF_RETURN_IF_ERROR(schedule.Verify()); + + return std::move(schedule); } -StatusOr<std::vector<const HloInstruction*>> ScheduleOneComputation( +StatusOr<HloInstructionSequence> ScheduleComputation( const HloComputation& computation, const LogicalBuffer::SizeFunction& size_function) { CHECK(!computation.IsFusionComputation()); @@ -581,187 +582,4 @@ StatusOr<std::vector<const HloInstruction*>> ScheduleOneComputation( size_function, nullptr, empty_map); } -tensorflow::gtl::FlatMap<const HloComputation*, std::vector<int>> -ComputeIdSchedule(const SequentialHloOrdering::HloModuleSequence& sequence) { - tensorflow::gtl::FlatMap<const HloComputation*, std::vector<int>> id_sequence; - for (const auto& computation_sequence : sequence) { - for (const HloInstruction* instruction : computation_sequence.second) { - id_sequence[computation_sequence.first].push_back( - instruction->unique_id()); - } - } - return id_sequence; -} - -Status UpdateSchedule( - const HloModule& module, - const tensorflow::gtl::FlatMap<const HloComputation*, std::vector<int>>& - id_sequence, - SequentialHloOrdering::HloModuleSequence* sequence) { - // Map from unique ID to HloInstruction pointer for instructions in the - // module. - tensorflow::gtl::FlatMap<int, const HloInstruction*> id_to_instruction; - // Set of all HloInstructions in the schedule. 
- tensorflow::gtl::FlatSet<int> ids_in_schedule; - std::vector<HloComputation*> nonfusion_computations = - module.MakeNonfusionComputations(); - for (const HloComputation* computation : nonfusion_computations) { - for (const HloInstruction* instruction : computation->instructions()) { - TF_RET_CHECK( - id_to_instruction.insert({instruction->unique_id(), instruction}) - .second); - } - for (int id : id_sequence.at(computation)) { - ids_in_schedule.insert(id); - } - } - - // Map from HloInstruction X to newly added instructions (instruction is in - // module, but not in schedule) which use X. If an instruction is not in the - // map, then it has no users which are newly added instructions. - tensorflow::gtl::FlatMap<const HloInstruction*, - std::vector<const HloInstruction*>> - new_instruction_uses; - - // For each newly added instruction, this is the count of the instruction's - // operands that have not yet been scheduled. When this value reaches zero, - // then the instruction may be placed in the schedule. - tensorflow::gtl::FlatMap<const HloInstruction*, int> - unscheduled_operand_count; - // For each computation, this is the set of newly added instructions which - // have no operands. These must be handled specially and are added to the - // beginning of the schedule. - tensorflow::gtl::FlatMap<const HloComputation*, - std::vector<const HloInstruction*>> - new_zero_operand_instructions; - for (const HloComputation* computation : nonfusion_computations) { - new_zero_operand_instructions[computation] = {}; - for (const HloInstruction* instruction : computation->instructions()) { - if (ids_in_schedule.count(instruction->unique_id()) == 0) { - // This is a newly added instruction which is not in the schedule. 
- for (const HloInstruction* operand : instruction->operands()) { - new_instruction_uses[operand].push_back(instruction); - } - if (instruction->operands().empty()) { - new_zero_operand_instructions[computation].push_back(instruction); - } - unscheduled_operand_count[instruction] = instruction->operand_count(); - } - } - } - - // Update the schedule with the newly added instructions, and remove any - // instructions no longer in the graph. - for (const HloComputation* computation : nonfusion_computations) { - std::vector<const HloInstruction*> old_computation_sequence = - std::move(sequence->at(computation)); - sequence->at(computation).clear(); - - // Create a worklist of newly added instructions which are ready to be added - // to the schedule. Initialize worklist with those that have zero operands. - std::queue<const HloInstruction*> worklist; - for (const HloInstruction* instruction : - new_zero_operand_instructions.at(computation)) { - worklist.push(instruction); - } - - // Lambda which schedules all instructions on the worklist. - auto schedule_worklist = [&]() { - while (!worklist.empty()) { - const HloInstruction* instruction = worklist.front(); - worklist.pop(); - sequence->at(computation).push_back(instruction); - std::vector<const HloInstruction*>* new_users = - tensorflow::gtl::FindOrNull(new_instruction_uses, instruction); - if (new_users != nullptr) { - // This just-scheduled instruction has users which are newly added to - // the module. Update the number of unscheduled operands and push the - // newly added instruction to the worklist if it is ready to - // schedule. 
- for (const HloInstruction* new_user : *new_users) { - unscheduled_operand_count.at(new_user)--; - CHECK_GE(unscheduled_operand_count.at(new_user), 0); - if (unscheduled_operand_count.at(new_user) == 0) { - worklist.push(new_user); - } - } - } - } - }; - - schedule_worklist(); - for (int id : id_sequence.at(computation)) { - auto it = id_to_instruction.find(id); - if (it == id_to_instruction.end()) { - // This instruction in the schedule is no longer in the module. - continue; - } - const HloInstruction* instruction = it->second; - worklist.push(instruction); - schedule_worklist(); - } - } - - TF_RETURN_IF_ERROR(VerifySchedule(module, *sequence)); - return Status::OK(); -} - -Status VerifySchedule( - const HloModule& module, - const SequentialHloOrdering::HloModuleSequence& sequence) { - VLOG(2) << "VerifySchedule()"; - XLA_VLOG_LINES(2, module.ToString()); - VLOG(2) << sequence; - - // Verify the set of computations in the sequence is exactly the set of - // computations in the module. - std::vector<HloComputation*> nonfusion_computations = - module.MakeNonfusionComputations(); - TF_RET_CHECK(nonfusion_computations.size() == sequence.size()); - tensorflow::gtl::FlatSet<const HloComputation*> computations_in_module( - module.computations().begin(), module.computations().end()); - for (const auto& computation_sequence : sequence) { - TF_RET_CHECK(computations_in_module.count(computation_sequence.first) == 1); - } - - // For each computation verify the set of instructions is the same and that - // each dependency and control edge is honored. 
- for (const HloComputation* computation : nonfusion_computations) { - tensorflow::gtl::FlatMap<const HloInstruction*, int> instruction_position; - int pos = 0; - for (const HloInstruction* instruction : sequence.at(computation)) { - TF_RET_CHECK(instruction_position.insert({instruction, pos}).second) - << "Instruction " << instruction->name() - << " appears more than once in the schedule"; - pos++; - } - - TF_RET_CHECK(instruction_position.size() == - computation->instruction_count()); - for (const HloInstruction* instruction : computation->instructions()) { - TF_RET_CHECK(instruction_position.count(instruction) == 1) - << "Instruction " << instruction->name() << " is not in schedule"; - } - - for (const HloInstruction* instruction : computation->instructions()) { - for (const HloInstruction* operand : instruction->operands()) { - TF_RET_CHECK(instruction_position.at(operand) < - instruction_position.at(instruction)) - << "Instruction " << instruction->name() - << " is not scheduled after its operand " << operand->name(); - } - - for (const HloInstruction* pred : instruction->control_predecessors()) { - TF_RET_CHECK(instruction_position.at(pred) < - instruction_position.at(instruction)) - << "Instruction " << instruction->name() - << " is not scheduled after its control predecessor " - << pred->name(); - } - } - } - - return Status::OK(); -} - } // namespace xla diff --git a/tensorflow/compiler/xla/service/hlo_scheduling.h b/tensorflow/compiler/xla/service/hlo_scheduling.h index d06b8d9a5c..54e32340ba 100644 --- a/tensorflow/compiler/xla/service/hlo_scheduling.h +++ b/tensorflow/compiler/xla/service/hlo_scheduling.h @@ -21,6 +21,7 @@ limitations under the License. 
#include "tensorflow/compiler/xla/service/hlo_instruction.h" #include "tensorflow/compiler/xla/service/hlo_module.h" #include "tensorflow/compiler/xla/service/hlo_ordering.h" +#include "tensorflow/compiler/xla/service/hlo_schedule.h" #include "tensorflow/compiler/xla/service/logical_buffer.h" #include "tensorflow/compiler/xla/service/tuple_points_to_analysis.h" #include "tensorflow/compiler/xla/statusor.h" @@ -32,14 +33,14 @@ namespace xla { // 'computation' that minimizes peak memory, given a points-to analysis result // that describes buffer aliasing, together with a target-specific size function // that maps a tensor's logical size to its padded size. -typedef std::function<StatusOr<std::vector<const HloInstruction*>>( +typedef std::function<StatusOr<HloInstructionSequence>( const HloComputation&, const TuplePointsToAnalysis&, const LogicalBuffer::SizeFunction&, const tensorflow::gtl::FlatMap<const HloComputation*, int64>&)> MemorySchedulerAlgorithm; // List scheduler -StatusOr<std::vector<const HloInstruction*>> ListMemoryScheduler( +StatusOr<HloInstructionSequence> ListMemoryScheduler( const HloComputation& computation, const TuplePointsToAnalysis& points_to_analysis, const LogicalBuffer::SizeFunction& size_function, @@ -47,7 +48,7 @@ StatusOr<std::vector<const HloInstruction*>> ListMemoryScheduler( memory_by_computation); // DFS-order scheduler -StatusOr<std::vector<const HloInstruction*>> DFSMemoryScheduler( +StatusOr<HloInstructionSequence> DFSMemoryScheduler( const HloComputation& computation, const TuplePointsToAnalysis& points_to_analysis, const LogicalBuffer::SizeFunction& size_function, @@ -55,7 +56,7 @@ StatusOr<std::vector<const HloInstruction*>> DFSMemoryScheduler( memory_by_computation); // Naive Post Order scheduler -StatusOr<std::vector<const HloInstruction*>> PostOrderMemoryScheduler( +StatusOr<HloInstructionSequence> PostOrderMemoryScheduler( const HloComputation& computation, const TuplePointsToAnalysis& points_to_analysis, const 
LogicalBuffer::SizeFunction& size_function, @@ -65,63 +66,26 @@ StatusOr<std::vector<const HloInstruction*>> PostOrderMemoryScheduler( // The default scheduling algorithm. Runs both the list scheduler // and the DFS scheduler, and chooses whichever returns a lower min-memory, // not accounting for fragmentation. -StatusOr<std::vector<const HloInstruction*>> DefaultMemoryScheduler( +StatusOr<HloInstructionSequence> DefaultMemoryScheduler( const HloComputation& computation, const TuplePointsToAnalysis& points_to_analysis, const LogicalBuffer::SizeFunction& size_function, const tensorflow::gtl::FlatMap<const HloComputation*, int64>& memory_by_computation); -// Returns an HloModuleSequence which seeks to minimize the memory required for +// Returns an HloSchedule which seeks to minimize the memory required for // the computation. size_function is the function returning the number of bytes // required for a LogicalBuffer. -StatusOr<SequentialHloOrdering::HloModuleSequence> ScheduleComputationsInModule( +StatusOr<HloSchedule> ScheduleModule( const HloModule& module, const LogicalBuffer::SizeFunction& size_function, const MemorySchedulerAlgorithm& algorithm = {}); // Computes the schedule for a single computation. // Currently only used by the GPU backend. -StatusOr<std::vector<const HloInstruction*>> ScheduleOneComputation( +StatusOr<HloInstructionSequence> ScheduleComputation( const HloComputation& computation, const LogicalBuffer::SizeFunction& size_function); -// Transforms the given schedule such that it is (again) a valid schedule for -// the module. This is used to update a schedule after the HLO module has been -// transformed in some way. In general, the only transformations to the module -// for which a schedule can be updated is the addition or removal of -// instructions to/from the module. Updating the schedule after new dependencies -// between existing instructions in the module is not supported and may result -// in an error status returned. 
-// -// Instructions in the module which also exist in the given schedule will remain -// in the same order in the updated schedule. Instructions which exist in the -// module but not in the given schedule will be placed as early as possible in -// the updated schedule. -// -// 'id_sequence' is a mirror of the given schedule 'sequence' but with -// HloInstruction ids rather than HloInstruction pointers. This should be -// constructed using ComputeIdSchedule below after the schedule is constructed -// but before the HLO module is transformed. -Status UpdateSchedule( - const HloModule& module, - const tensorflow::gtl::FlatMap<const HloComputation*, std::vector<int>>& - id_sequence, - SequentialHloOrdering::HloModuleSequence* sequence); - -// Constructs a copy of the given schedule but with HloInstruction unique ids -// rather than HloInstruction pointers. This is necessary for updating a -// schedule as HloInstruction points in the schedule may become invalid if -// instructions are removed from the module. Used by UpdateSchedule above.. -// TODO(b/113175018): Remove this function when HLO schedule is its own class. -tensorflow::gtl::FlatMap<const HloComputation*, std::vector<int>> -ComputeIdSchedule(const SequentialHloOrdering::HloModuleSequence& sequence); - -// Verifies that the given schedule is valid for the given module. Specifically, -// the schedule contains exactly the instructions in the module and every -// dependency in the module is satisfied in the schedule. 
-Status VerifySchedule(const HloModule& module, - const SequentialHloOrdering::HloModuleSequence& sequence); - } // namespace xla #endif // TENSORFLOW_COMPILER_XLA_SERVICE_HLO_SCHEDULING_H_ diff --git a/tensorflow/compiler/xla/service/hlo_scheduling_test.cc b/tensorflow/compiler/xla/service/hlo_scheduling_test.cc index d49d09d459..6afe51997e 100644 --- a/tensorflow/compiler/xla/service/hlo_scheduling_test.cc +++ b/tensorflow/compiler/xla/service/hlo_scheduling_test.cc @@ -18,6 +18,7 @@ limitations under the License. #include <memory> #include <string> +#include "absl/algorithm/container.h" #include "tensorflow/compiler/xla/service/heap_simulator.h" #include "tensorflow/compiler/xla/service/hlo_computation.h" #include "tensorflow/compiler/xla/service/hlo_dce.h" @@ -67,19 +68,20 @@ TEST_F(HloSchedulingTest, LastUseScheduledFirst) { module->AddEntryComputation(builder.Build()); TF_ASSERT_OK_AND_ASSIGN( - SequentialHloOrdering::HloModuleSequence sequence, - ScheduleComputationsInModule(*module, [](const BufferValue& buffer) { + HloSchedule schedule, + ScheduleModule(*module, [](const BufferValue& buffer) { return ShapeUtil::ByteSizeOf(buffer.shape()); })); // Verify that all instructions are in the sequence. - EXPECT_EQ(module->entry_computation()->instruction_count(), - sequence.at(module->entry_computation()).size()); + const std::vector<const HloInstruction*>& sequence = + schedule.sequence(module->entry_computation()).instructions(); + EXPECT_EQ(module->entry_computation()->instruction_count(), sequence.size()); // The first instruction should be the parameter and the last the root "sub". 
- EXPECT_EQ(param, sequence.at(module->entry_computation()).front()); - EXPECT_EQ(sub, sequence.at(module->entry_computation()).back()); + EXPECT_EQ(param, sequence.front()); + EXPECT_EQ(sub, sequence.back()); - SequentialHloOrdering ordering(module.get(), sequence); + SequentialHloOrdering ordering(schedule); EXPECT_TRUE(ordering.ExecutesBefore(add, negate)); } @@ -108,28 +110,26 @@ ENTRY root { return ShapeUtil::ByteSizeOf(buffer.shape(), /*pointer_size=*/8); }; TF_ASSERT_OK_AND_ASSIGN( - SequentialHloOrdering::HloModuleSequence sequence, - ScheduleComputationsInModule(*module, size_fn, ListMemoryScheduler)); + HloSchedule schedule, + ScheduleModule(*module, size_fn, ListMemoryScheduler)); // Verify that all instructions are in the sequence. - EXPECT_EQ(module->entry_computation()->instruction_count(), - sequence.at(module->entry_computation()).size()); + const std::vector<const HloInstruction*>& sequence = + schedule.sequence(module->entry_computation()).instructions(); + EXPECT_EQ(module->entry_computation()->instruction_count(), sequence.size()); std::unordered_map<string, const HloInstruction*> instructions_by_name; - for (const HloInstruction* instruction : - sequence.at(module->entry_computation())) { + for (const HloInstruction* instruction : sequence) { instructions_by_name[instruction->name()] = instruction; } // The first instruction should be the parameter and the last the root. - EXPECT_EQ(instructions_by_name.at("param"), - sequence.at(module->entry_computation()).front()); - EXPECT_EQ(instructions_by_name.at("result"), - sequence.at(module->entry_computation()).back()); + EXPECT_EQ(instructions_by_name.at("param"), sequence.front()); + EXPECT_EQ(instructions_by_name.at("result"), sequence.back()); // Instructions "d" and "e" will both be schedulable at the same time, but // instruction "d" allows us to free the buffer of "p1", so the list scheduler // should prefer it. 
- SequentialHloOrdering ordering(module.get(), sequence); + SequentialHloOrdering ordering(schedule); EXPECT_TRUE(ordering.ExecutesBefore(instructions_by_name.at("d"), instructions_by_name.at("e"))); } @@ -220,13 +220,13 @@ TEST_F(HloSchedulingTest, ListAccountsForSubcomputations) { return ShapeUtil::ByteSizeOf(buffer.shape()); }; TF_ASSERT_OK_AND_ASSIGN( - SequentialHloOrdering::HloModuleSequence sequence, - ScheduleComputationsInModule(*module, size_fn, ListMemoryScheduler)); + HloSchedule schedule, + ScheduleModule(*module, size_fn, ListMemoryScheduler)); // Verify that all instructions are in the sequence. auto entry_computation = module->entry_computation(); EXPECT_EQ(entry_computation->instruction_count(), - sequence.at(entry_computation).size()); - SequentialHloOrdering ordering(module.get(), sequence); + schedule.sequence(entry_computation).size()); + SequentialHloOrdering ordering(schedule); // This schedule is an example of List's greedy heuristics being suboptimal. // The while_loop is more expensive than transpose, so it would have been // better to schedule it first, instead of during the busy time. @@ -243,13 +243,13 @@ TEST_F(HloSchedulingTest, ListAccountsForSubcomputations) { // HeapSimulator doesn't account for subcomputations EXPECT_EQ(80, HeapSimulator::MinimumMemoryForComputation( - *entry_computation, sequence.at(entry_computation), + *entry_computation, schedule.sequence(entry_computation), *points_to_analysis, size_fn) .ValueOrDie()); // HeapSimulator accounts for subcomputations. The output buffer is aliased, // so we don't double count. 
EXPECT_EQ(64, HeapSimulator::MinimumMemoryForComputation( - *entry_computation, sequence.at(entry_computation), + *entry_computation, schedule.sequence(entry_computation), *points_to_analysis, size_fn, &memory_by_computation) .ValueOrDie()); } @@ -281,19 +281,18 @@ TEST_F(HloSchedulingTest, TuplesAreAccountedCorrectly) { auto module = CreateNewModule(); module->AddEntryComputation(builder.Build()); - TF_ASSERT_OK_AND_ASSIGN( - SequentialHloOrdering::HloModuleSequence sequence, - ScheduleComputationsInModule(*module, - [](const BufferValue& buffer) { - return ShapeUtil::ByteSizeOf( - buffer.shape(), TUPLE_SIZE); - }, - ListMemoryScheduler)); + TF_ASSERT_OK_AND_ASSIGN(HloSchedule schedule, + ScheduleModule(*module, + [](const BufferValue& buffer) { + return ShapeUtil::ByteSizeOf( + buffer.shape(), TUPLE_SIZE); + }, + ListMemoryScheduler)); // Verify that all instructions are in the sequence. EXPECT_EQ(module->entry_computation()->instruction_count(), - sequence.at(module->entry_computation()).size()); - SequentialHloOrdering ordering(module.get(), sequence); + schedule.sequence(module->entry_computation()).size()); + SequentialHloOrdering ordering(schedule); // tuple allocates the tuple buffer and doesn't free anything. // abs_abs2 uses the same buffer for input/output, so its bytes-freed is 0. // abs_abs2 should be scheduled before tuple by List. 
@@ -332,18 +331,18 @@ TEST_F(HloSchedulingTest, MultiOutputFusionAccountedCorrectly) { auto fusion = computation->CreateFusionInstruction( {tuple, mul, add}, HloInstruction::FusionKind::kLoop); - TF_ASSERT_OK_AND_ASSIGN(SequentialHloOrdering::HloModuleSequence sequence, - ScheduleComputationsInModule( - *module, - [](const BufferValue& buffer) { - return ShapeUtil::ByteSizeOf(buffer.shape(), 2); - }, - ListMemoryScheduler)); + TF_ASSERT_OK_AND_ASSIGN(HloSchedule schedule, + ScheduleModule(*module, + [](const BufferValue& buffer) { + return ShapeUtil::ByteSizeOf( + buffer.shape(), 2); + }, + ListMemoryScheduler)); // Verify that all instructions are in the sequence. EXPECT_EQ(module->entry_computation()->instruction_count(), - sequence.at(module->entry_computation()).size()); - SequentialHloOrdering ordering(module.get(), sequence); + schedule.sequence(module->entry_computation()).size()); + SequentialHloOrdering ordering(schedule); // fusion allocates memory for the tuple elements and doesn't free anything, // so it's more expensive than exp. EXPECT_TRUE(ordering.ExecutesBefore(exp, fusion)); @@ -391,12 +390,12 @@ TEST_F(HloSchedulingTest, HeapSimulatorAccountsForSubcomputations) { return ShapeUtil::ByteSizeOf(buffer.shape()); }; TF_ASSERT_OK_AND_ASSIGN( - SequentialHloOrdering::HloModuleSequence sequence, - ScheduleComputationsInModule(*module, size_fn, ListMemoryScheduler)); + HloSchedule schedule, + ScheduleModule(*module, size_fn, ListMemoryScheduler)); // Verify that all instructions are in the sequence. 
auto entry_computation = module->entry_computation(); - EXPECT_EQ(entry_computation->instruction_count(), - sequence.at(entry_computation).size()); + EXPECT_EQ(module->entry_computation()->instruction_count(), + schedule.sequence(module->entry_computation()).size()); tensorflow::gtl::FlatMap<const HloComputation*, int64> memory_by_computation; memory_by_computation[cond_computation] = 17; @@ -406,262 +405,16 @@ TEST_F(HloSchedulingTest, HeapSimulatorAccountsForSubcomputations) { // HeapSimulator doesn't account for subcomputations EXPECT_EQ(16, HeapSimulator::MinimumMemoryForComputation( - *entry_computation, sequence.at(entry_computation), + *entry_computation, schedule.sequence(entry_computation), *points_to_analysis, size_fn) .ValueOrDie()); // HeapSimulator accounts for subcomputations. Cond is the largest one. // The output buffer of the while is aliased. EXPECT_EQ(17, HeapSimulator::MinimumMemoryForComputation( - *entry_computation, sequence.at(entry_computation), + *entry_computation, schedule.sequence(entry_computation), *points_to_analysis, size_fn, &memory_by_computation) .ValueOrDie()); } -TEST_F(HloSchedulingTest, UpdateScheduleUnchangedModule) { - // Updating the schedule of an unchanged HLO module should not affect the - // schedule at all. 
- const string module_str = R"( -HloModule UpdateScheduleUnchanged - -ENTRY main { - a = f32[] parameter(0) - b = f32[] parameter(1) - c = f32[] constant(42.0) - sum = f32[] add(a, b) - neg = f32[] negate(c) - ROOT root = f32[] multiply(sum, neg) -} -)"; - TF_ASSERT_OK_AND_ASSIGN(std::unique_ptr<HloModule> module, - ParseHloString(module_str)); - TF_ASSERT_OK_AND_ASSIGN( - SequentialHloOrdering::HloModuleSequence sequence, - ScheduleComputationsInModule(*module, [](const BufferValue& buffer) { - return ShapeUtil::ByteSizeOf(buffer.shape()); - })); - tensorflow::gtl::FlatMap<const HloComputation*, std::vector<int>> - id_sequence = ComputeIdSchedule(sequence); - std::vector<const HloInstruction*> entry_schedule = sequence.begin()->second; - - EXPECT_EQ(entry_schedule.size(), 6); - - TF_ASSERT_OK(UpdateSchedule(*module, id_sequence, &sequence)); - TF_ASSERT_OK(VerifySchedule(*module, sequence)); - - EXPECT_EQ(entry_schedule, sequence.begin()->second); -} - -TEST_F(HloSchedulingTest, UpdateScheduleWithNewInstructions) { - // Add some additional instructions to a module and verify the schedule can be - // updated. 
- const string module_str = R"( -HloModule UpdateScheduleWithNewInstructions - -ENTRY main { - a = f32[] parameter(0) - b = f32[] parameter(1) - c = f32[] constant(42.0) - sum = f32[] add(a, b) - neg = f32[] negate(c) - ROOT root = f32[] multiply(sum, neg) -} -)"; - TF_ASSERT_OK_AND_ASSIGN(std::unique_ptr<HloModule> module, - ParseHloString(module_str)); - TF_ASSERT_OK_AND_ASSIGN( - SequentialHloOrdering::HloModuleSequence sequence, - ScheduleComputationsInModule(*module, [](const BufferValue& buffer) { - return ShapeUtil::ByteSizeOf(buffer.shape()); - })); - tensorflow::gtl::FlatMap<const HloComputation*, std::vector<int>> - id_sequence = ComputeIdSchedule(sequence); - - HloComputation* entry = module->entry_computation(); - const Shape shape = entry->root_instruction()->shape(); - HloInstruction* constant = entry->AddInstruction( - HloInstruction::CreateConstant(LiteralUtil::CreateR0<float>(42.0))); - HloInstruction* sub = entry->AddInstruction(HloInstruction::CreateBinary( - shape, HloOpcode::kSubtract, constant, entry->root_instruction())); - entry->set_root_instruction(sub); - - auto in_schedule = [&](const HloInstruction* hlo) { - return std::find(sequence.at(entry).begin(), sequence.at(entry).end(), - hlo) != sequence.at(entry).end(); - }; - - EXPECT_EQ(sequence.at(entry).size(), 6); - EXPECT_FALSE(in_schedule(constant)); - EXPECT_FALSE(in_schedule(sub)); - - TF_ASSERT_OK(UpdateSchedule(*module, id_sequence, &sequence)); - TF_ASSERT_OK(VerifySchedule(*module, sequence)); - - EXPECT_EQ(sequence.at(entry).size(), 8); - EXPECT_TRUE(in_schedule(constant)); - EXPECT_TRUE(in_schedule(sub)); -} - -TEST_F(HloSchedulingTest, UpdateScheduleWithAddedAndDeletedInstruction) { - // Add and delete some instructions from a module and verify that the schedule - // can be updated successfully. 
- const string module_str = R"( -HloModule UpdateScheduleWithAddedAndDeletedInstruction - -ENTRY main { - a = f32[] parameter(0) - b = f32[] parameter(1) - c = f32[] constant(42.0) - sum = f32[] add(a, b) - neg = f32[] negate(c) - ROOT root = f32[] multiply(sum, neg) -} -)"; - - TF_ASSERT_OK_AND_ASSIGN(std::unique_ptr<HloModule> module, - ParseHloString(module_str)); - TF_ASSERT_OK_AND_ASSIGN( - SequentialHloOrdering::HloModuleSequence sequence, - ScheduleComputationsInModule(*module, [](const BufferValue& buffer) { - return ShapeUtil::ByteSizeOf(buffer.shape()); - })); - tensorflow::gtl::FlatMap<const HloComputation*, std::vector<int>> - id_sequence = ComputeIdSchedule(sequence); - - // Set the entry root to some expression containing just a parameter and a - // constant. - HloComputation* entry = module->entry_computation(); - HloInstruction* constant = entry->AddInstruction( - HloInstruction::CreateConstant(LiteralUtil::CreateR0<float>(42.0))); - HloInstruction* new_root = entry->AddInstruction( - HloInstruction::CreateBinary(constant->shape(), HloOpcode::kSubtract, - constant, entry->parameter_instruction(0))); - entry->set_root_instruction(new_root); - - // DCE should remove everything but the parameters and the newly added code. - HloDCE dce; - TF_ASSERT_OK(dce.Run(module.get()).status()); - - EXPECT_EQ(sequence.at(entry).size(), 6); - - TF_ASSERT_OK(UpdateSchedule(*module, id_sequence, &sequence)); - TF_ASSERT_OK(VerifySchedule(*module, sequence)); - - EXPECT_EQ(sequence.at(entry).size(), 4); -} - -TEST_F(HloSchedulingTest, UpdateScheduleWithCompletelyReplacedModule) { - // Completely replace a module with an entirely new set of instructions and - // verify that the schedule can be updated successfully. 
- const string module_str = R"( -HloModule UpdateScheduleWithCompletelyReplacedModule - -ENTRY main { - a = f32[] constant(42.0) - b = f32[] constant(123.0) - ROOT sum = f32[] add(a, b) -} -)"; - - TF_ASSERT_OK_AND_ASSIGN(std::unique_ptr<HloModule> module, - ParseHloString(module_str)); - TF_ASSERT_OK_AND_ASSIGN( - SequentialHloOrdering::HloModuleSequence sequence, - ScheduleComputationsInModule(*module, [](const BufferValue& buffer) { - return ShapeUtil::ByteSizeOf(buffer.shape()); - })); - tensorflow::gtl::FlatMap<const HloComputation*, std::vector<int>> - id_sequence = ComputeIdSchedule(sequence); - - // Replace the entry computation with the negation of a constant. - HloComputation* entry = module->entry_computation(); - HloInstruction* constant = entry->AddInstruction( - HloInstruction::CreateConstant(LiteralUtil::CreateR0<float>(1.0))); - HloInstruction* new_root = entry->AddInstruction(HloInstruction::CreateUnary( - constant->shape(), HloOpcode::kNegate, constant)); - entry->set_root_instruction(new_root); - - // DCE the old instructions. - HloDCE dce; - TF_ASSERT_OK(dce.Run(module.get()).status()); - - EXPECT_EQ(sequence.at(entry).size(), 3); - - TF_ASSERT_OK(UpdateSchedule(*module, id_sequence, &sequence)); - TF_ASSERT_OK(VerifySchedule(*module, sequence)); - - EXPECT_EQ(sequence.at(entry).size(), 2); -} - -TEST_F(HloSchedulingTest, UpdateScheduleWithMultipleComputations) { - // Create changes to more than one computation in an HLO module and verify - // that the schedule can be updated. 
- const string module_str = R"( -HloModule UpdateScheduleWithMultipleComputations - -%Body (param.1: (s32[], token[])) -> (s32[], token[]) { - %param.1 = (s32[], token[]) parameter(0) - %get-tuple-element.1 = s32[] get-tuple-element((s32[], token[]) %param.1), index=0 - %constant.1 = s32[] constant(1) - %add = s32[] add(s32[] %get-tuple-element.1, s32[] %constant.1) - %get-tuple-element.2 = token[] get-tuple-element((s32[], token[]) %param.1), index=1 - %after-all = token[] after-all(token[] %get-tuple-element.2) - ROOT %tuple = (s32[], token[]) tuple(s32[] %add, token[] %after-all) -} - -%Cond (param: (s32[], token[])) -> pred[] { - %param = (s32[], token[]) parameter(0) - %get-tuple-element = s32[] get-tuple-element((s32[], token[]) %param), index=0 - %constant = s32[] constant(42) - ROOT %less-than = pred[] less-than(s32[] %get-tuple-element, s32[] %constant) -} - -ENTRY %WhileLoop () -> s32[] { - %zero = s32[] constant(0) - %init_token = token[] after-all() - %init_tuple = (s32[], token[]) tuple(s32[] %zero, token[] %init_token) - %while = (s32[], token[]) while((s32[], token[]) %init_tuple), condition=%Cond, body=%Body - ROOT %root = s32[] get-tuple-element((s32[], token[]) %while), index=0 -} -)"; - - TF_ASSERT_OK_AND_ASSIGN(std::unique_ptr<HloModule> module, - ParseHloString(module_str)); - TF_ASSERT_OK_AND_ASSIGN( - SequentialHloOrdering::HloModuleSequence sequence, - ScheduleComputationsInModule(*module, [](const BufferValue& buffer) { - return ShapeUtil::ByteSizeOf(buffer.shape(), - /*pointer_size=*/sizeof(void*)); - })); - tensorflow::gtl::FlatMap<const HloComputation*, std::vector<int>> - id_sequence = ComputeIdSchedule(sequence); - - const HloInstruction* xla_while = - module->entry_computation()->root_instruction()->operand(0); - HloComputation* body = xla_while->while_body(); - HloComputation* cond = xla_while->while_condition(); - - // Negate the root of the cond. 
- cond->set_root_instruction(cond->AddInstruction( - HloInstruction::CreateUnary(ShapeUtil::MakeShape(PRED, {}), - HloOpcode::kNot, cond->root_instruction()))); - - // Replace the body with a computation which just passes through its - // parameter. - body->set_root_instruction(body->parameter_instruction(0)); - - // DCE the dead code in the body. - HloDCE dce; - TF_ASSERT_OK(dce.Run(module.get()).status()); - - EXPECT_EQ(sequence.at(body).size(), 7); - EXPECT_EQ(sequence.at(cond).size(), 4); - - TF_ASSERT_OK(UpdateSchedule(*module, id_sequence, &sequence)); - TF_ASSERT_OK(VerifySchedule(*module, sequence)); - - EXPECT_EQ(sequence.at(body).size(), 1); - EXPECT_EQ(sequence.at(cond).size(), 5); -} - } // namespace } // namespace xla |