diff options
-rw-r--r-- | tensorflow/compiler/xla/service/BUILD | 30 | ||||
-rw-r--r-- | tensorflow/compiler/xla/service/gpu/BUILD | 1 | ||||
-rw-r--r-- | tensorflow/compiler/xla/service/hlo.proto | 26 | ||||
-rw-r--r-- | tensorflow/compiler/xla/service/hlo_computation.cc | 12 | ||||
-rw-r--r-- | tensorflow/compiler/xla/service/hlo_computation.h | 5 | ||||
-rw-r--r-- | tensorflow/compiler/xla/service/hlo_module.cc | 33 | ||||
-rw-r--r-- | tensorflow/compiler/xla/service/hlo_module.h | 20 | ||||
-rw-r--r-- | tensorflow/compiler/xla/service/hlo_module_test.cc | 59 | ||||
-rw-r--r-- | tensorflow/compiler/xla/service/hlo_ordering.cc | 17 | ||||
-rw-r--r-- | tensorflow/compiler/xla/service/hlo_ordering.h | 4 | ||||
-rw-r--r-- | tensorflow/compiler/xla/service/hlo_parser.cc | 33 | ||||
-rw-r--r-- | tensorflow/compiler/xla/service/hlo_parser_test.cc | 104 | ||||
-rw-r--r-- | tensorflow/compiler/xla/service/hlo_proto_util.cc | 3 | ||||
-rw-r--r-- | tensorflow/compiler/xla/service/hlo_schedule.cc | 52 | ||||
-rw-r--r-- | tensorflow/compiler/xla/service/hlo_schedule.h | 13 |
15 files changed, 346 insertions, 66 deletions
diff --git a/tensorflow/compiler/xla/service/BUILD b/tensorflow/compiler/xla/service/BUILD index ab86dce510..b8ee6a093e 100644 --- a/tensorflow/compiler/xla/service/BUILD +++ b/tensorflow/compiler/xla/service/BUILD @@ -291,6 +291,7 @@ cc_library( "hlo_instructions.cc", "hlo_module.cc", "hlo_opcode.cc", + "hlo_schedule.cc", "hlo_sharding.cc", ], hdrs = [ @@ -303,6 +304,7 @@ cc_library( "hlo_instructions.h", "hlo_module.h", "hlo_opcode.h", + "hlo_schedule.h", "hlo_sharding.h", ], deps = [ @@ -331,6 +333,8 @@ cc_library( "@com_google_absl//absl/container:inlined_vector", "@com_google_absl//absl/memory", "@com_google_absl//absl/strings", + "@com_google_absl//absl/strings:str_format", + "@com_google_absl//absl/types:optional", "@com_google_absl//absl/types:span", ], ) @@ -1037,7 +1041,6 @@ tf_cc_test( ":flatten_call_graph", ":hlo", ":hlo_ordering", - ":hlo_schedule", ":hlo_scheduling", "//tensorflow/compiler/xla:literal", "//tensorflow/compiler/xla:shape_util", @@ -1065,7 +1068,6 @@ cc_library( ":hlo", ":hlo_dataflow_analysis", ":hlo_proto", - ":hlo_schedule", ":hlo_value", "//tensorflow/compiler/xla:shape_util", "//tensorflow/compiler/xla:status_macros", @@ -1086,7 +1088,6 @@ tf_cc_test( ":hlo", ":hlo_dataflow_analysis", ":hlo_ordering", - ":hlo_schedule", ":hlo_scheduling", "//tensorflow/compiler/xla:shape_util", "//tensorflow/compiler/xla:types", @@ -1108,7 +1109,6 @@ cc_library( ":hlo", ":hlo_ordering", ":hlo_proto", - ":hlo_schedule", ":tuple_points_to_analysis", "//tensorflow/compiler/xla:statusor", "//tensorflow/compiler/xla:util", @@ -1177,22 +1177,6 @@ cc_library( ], ) -cc_library( - name = "hlo_schedule", - srcs = ["hlo_schedule.cc"], - hdrs = ["hlo_schedule.h"], - deps = [ - ":hlo", - "//tensorflow/compiler/xla:status", - "//tensorflow/compiler/xla:status_macros", - "//tensorflow/compiler/xla:util", - "//tensorflow/core:lib_internal", - "@com_google_absl//absl/strings", - "@com_google_absl//absl/strings:str_format", - "@com_google_absl//absl/types:span", - ], -) - tf_cc_test( name = "hlo_schedule_test", srcs = ["hlo_schedule_test.cc"], @@ -1202,7 +1186,6 @@ tf_cc_test( ":hlo_dce", ":hlo_ordering", ":hlo_parser", - ":hlo_schedule", ":hlo_scheduling", "//tensorflow/compiler/xla:shape_util", "//tensorflow/compiler/xla:types", @@ -1222,7 +1205,6 @@ cc_library( ":heap_simulator", ":hlo", ":hlo_ordering", - ":hlo_schedule", ":logical_buffer", ":tuple_points_to_analysis", "//tensorflow/compiler/xla:shape_util", @@ -1969,6 +1951,8 @@ tf_cc_test( srcs = ["hlo_module_test.cc"], deps = [ ":hlo", + ":hlo_matchers", + ":hlo_parser", "//tensorflow/compiler/xla:literal", "//tensorflow/compiler/xla:shape_util", "//tensorflow/compiler/xla:test", @@ -1977,6 +1961,7 @@ tf_cc_test( "//tensorflow/compiler/xla/tests:hlo_test_base", "//tensorflow/compiler/xla/tests:xla_internal_test_main", "//tensorflow/core:lib", + "//tensorflow/core:test", "@com_google_absl//absl/memory", "@com_google_absl//absl/types:span", ], @@ -2413,7 +2398,6 @@ cc_library( ":hlo", ":hlo_dce", ":hlo_ordering", - ":hlo_schedule", ":hlo_scheduling", ":logical_buffer", ":tuple_points_to_analysis", diff --git a/tensorflow/compiler/xla/service/gpu/BUILD b/tensorflow/compiler/xla/service/gpu/BUILD index 13ccff35f8..a68b7a1bef 100644 --- a/tensorflow/compiler/xla/service/gpu/BUILD +++ b/tensorflow/compiler/xla/service/gpu/BUILD @@ -813,7 +813,6 @@ cc_library( "//tensorflow/compiler/xla/service:hlo", "//tensorflow/compiler/xla/service:hlo_ordering", "//tensorflow/compiler/xla/service:hlo_reachability", - "//tensorflow/compiler/xla/service:hlo_schedule", "//tensorflow/compiler/xla/service:hlo_scheduling", "@com_google_absl//absl/memory", ], diff --git a/tensorflow/compiler/xla/service/hlo.proto b/tensorflow/compiler/xla/service/hlo.proto index 99d0cf50ca..93ec2c9438 100644 --- a/tensorflow/compiler/xla/service/hlo.proto +++ b/tensorflow/compiler/xla/service/hlo.proto @@ -199,6 +199,17 @@ message HloComputationProto { int64 root_id = 6; } +// Serialization of an HLO schedule. An HLO schedule contains a total order of +// instructions for each non-fusion computation in the module. +message HloScheduleProto { + message InstructionSequence { + repeated int64 instruction_ids = 1; + } + + // Map from computation id to sequence. + map<int64, InstructionSequence> sequences = 1; +} + // Serialization of HloModule. message HloModuleProto { string name = 1; @@ -214,16 +225,9 @@ message HloModuleProto { // The id of this module. int64 id = 5; -} -// Serialization of HloOrdering. -message HloOrderingProto { - // NOTE: currently only sequential orderings are serialized. - message SequentialComputation { - string computation_name = 1; - repeated string instruction_names = 2; - } - repeated SequentialComputation sequential_computations = 1; + // The schedule for this module. + HloScheduleProto schedule = 7; } // Serialization of LogicalBuffer. @@ -322,8 +326,10 @@ message BufferAssignmentProto { // Grouping message that contains all of the information above. message HloProto { + reserved 2; + reserved "hlo_ordering"; + HloModuleProto hlo_module = 1; - HloOrderingProto hlo_ordering = 2; BufferAssignmentProto buffer_assignment = 3; } diff --git a/tensorflow/compiler/xla/service/hlo_computation.cc b/tensorflow/compiler/xla/service/hlo_computation.cc index fe7f2be888..233d2199d1 100644 --- a/tensorflow/compiler/xla/service/hlo_computation.cc +++ b/tensorflow/compiler/xla/service/hlo_computation.cc @@ -464,6 +464,14 @@ std::vector<HloComputation*> HloComputation::MakeEmbeddedComputationsList() } string HloComputation::ToString(const HloPrintOptions& options) const { + return ToString(options, MakeInstructionPostOrder()); +} + +string HloComputation::ToString( + const HloPrintOptions& options, + absl::Span<const HloInstruction* const> instruction_order) const { + CHECK_EQ(instruction_order.size(), instruction_count()); + std::ostringstream s; for (int i = 0; i < options.indent_amount(); i++) { s << " "; @@ -486,7 +494,9 @@ string HloComputation::ToString(const HloPrintOptions& options) const { new_options.set_indent_amount(options.indent_amount() + 1) .set_is_in_nested_computation(true); CanonicalNameMap name_map; - for (const HloInstruction* instruction : MakeInstructionPostOrder()) { + for (const HloInstruction* instruction : instruction_order) { + CHECK_EQ(this, instruction->parent()); + for (int i = 0; i < new_options.indent_amount(); i++) { s << " "; } diff --git a/tensorflow/compiler/xla/service/hlo_computation.h b/tensorflow/compiler/xla/service/hlo_computation.h index fe2d3bbbe5..91c5234a6f 100644 --- a/tensorflow/compiler/xla/service/hlo_computation.h +++ b/tensorflow/compiler/xla/service/hlo_computation.h @@ -170,6 +170,11 @@ class HloComputation { string ToString() const { return ToString(HloPrintOptions()); } string ToString(const HloPrintOptions& options) const; + // Overload which accepts an order to emit the instructions in. + string ToString( + const HloPrintOptions& options, + absl::Span<const HloInstruction* const> instruction_order) const; + // Returns a serialized representation of this computation. HloComputationProto ToProto() const; diff --git a/tensorflow/compiler/xla/service/hlo_module.cc b/tensorflow/compiler/xla/service/hlo_module.cc index 3a1bc4e328..cfe906d9c5 100644 --- a/tensorflow/compiler/xla/service/hlo_module.cc +++ b/tensorflow/compiler/xla/service/hlo_module.cc @@ -26,6 +26,7 @@ limitations under the License. #include "absl/memory/memory.h" #include "absl/strings/str_cat.h" #include "tensorflow/compiler/xla/map_util.h" +#include "tensorflow/compiler/xla/service/hlo_schedule.h" #include "tensorflow/compiler/xla/shape_util.h" #include "tensorflow/compiler/xla/types.h" #include "tensorflow/core/lib/gtl/map_util.h" @@ -50,6 +51,13 @@ StatusOr<HloInstruction*> HloModule::LaunderConstInstructionFromModule( return const_cast<HloInstruction*>(hlo); } +Status HloModule::set_schedule(HloSchedule schedule) { + TF_RET_CHECK(schedule.module() == this); + TF_RETURN_IF_ERROR(schedule.Verify()); + schedule_ = std::move(schedule); + return Status::OK(); +} + HloComputation* HloModule::AddComputationInternal( std::unique_ptr<HloComputation> computation, bool is_entry, bool uniquify_names) { @@ -198,12 +206,23 @@ void HloModule::ReplaceComputations( string HloModule::ToString(const HloPrintOptions& options) const { std::ostringstream s; - s << "HloModule " << name() << "\n\n"; + s << "HloModule " << name(); + if (has_schedule()) { + TF_CHECK_OK(schedule().Verify()); + s << ", is_scheduled=true"; + } + s << "\n\n"; for (const HloComputation* computation : MakeComputationPostOrder()) { if (computation == entry_computation()) { s << "ENTRY "; } - s << computation->ToString(options) << "\n\n"; + if (has_schedule() && schedule().is_computation_scheduled(computation)) { + s << computation->ToString( + options, schedule().sequence(computation).instructions()) + << "\n\n"; + } else { + s << computation->ToString(options) << "\n\n"; + } } return s.str(); } @@ -221,6 +240,9 @@ HloModuleProto HloModule::ToProto() const { } proto.add_computations()->Swap(&computation_proto); } + if (has_schedule()) { + *proto.mutable_schedule() = schedule().ToProto().ValueOrDie(); + } return proto; } @@ -309,6 +331,13 @@ StatusOr<std::unique_ptr<HloModule>> HloModule::CreateFromProto( } } + if (proto.has_schedule()) { + TF_ASSIGN_OR_RETURN( + HloSchedule schedule, + HloSchedule::CreateFromProto(module.get(), proto.schedule())); + TF_RETURN_IF_ERROR(module->set_schedule(std::move(schedule))); + } + return std::move(module); } diff --git a/tensorflow/compiler/xla/service/hlo_module.h b/tensorflow/compiler/xla/service/hlo_module.h index 3c3371426b..26fd1b2438 100644 --- a/tensorflow/compiler/xla/service/hlo_module.h +++ b/tensorflow/compiler/xla/service/hlo_module.h @@ -25,6 +25,7 @@ limitations under the License. #include <vector> #include "absl/strings/string_view.h" +#include "absl/types/optional.h" #include "absl/types/span.h" #include "tensorflow/compiler/xla/iterator_util.h" #include "tensorflow/compiler/xla/service/hlo.pb.h" @@ -32,6 +33,7 @@ limitations under the License. #include "tensorflow/compiler/xla/service/hlo_computation.h" #include "tensorflow/compiler/xla/service/hlo_instruction.h" #include "tensorflow/compiler/xla/service/hlo_module_config.h" +#include "tensorflow/compiler/xla/service/hlo_schedule.h" #include "tensorflow/compiler/xla/service/name_uniquer.h" #include "tensorflow/compiler/xla/types.h" #include "tensorflow/core/lib/gtl/iterator_range.h" @@ -235,6 +237,19 @@ class HloModule { StatusOr<HloInstruction*> LaunderConstInstructionFromModule( const HloInstruction* hlo); + // Sets the schedule of the module to the given schedule. + Status set_schedule(HloSchedule schedule); + + // Clears the schedule of the module. + void clear_schedule() { schedule_.reset(); } + + // Returns true if the module has a schedule set. + bool has_schedule() const { return schedule_.has_value(); } + + // Returns the schedue of the module. CHECK fails if no schedule is set. + const HloSchedule& schedule() const { return *schedule_; } + HloSchedule& schedule() { return *schedule_; } + private: HloComputation* AddComputationInternal( std::unique_ptr<HloComputation> computation, bool is_entry, @@ -262,6 +277,11 @@ class HloModule { static std::atomic<int> next_unique_module_id_; // A unique id to label modules with. int unique_id_; + + // The HloSchedule of the module. The schedule if it exists contains a + // sequential order of instructions for each non-fusion computation in the + // module. + absl::optional<HloSchedule> schedule_; }; } // namespace xla diff --git a/tensorflow/compiler/xla/service/hlo_module_test.cc b/tensorflow/compiler/xla/service/hlo_module_test.cc index 4bc1bacd7d..400bd4d947 100644 --- a/tensorflow/compiler/xla/service/hlo_module_test.cc +++ b/tensorflow/compiler/xla/service/hlo_module_test.cc @@ -19,9 +19,12 @@ limitations under the License. #include "tensorflow/compiler/xla/literal.h" #include "tensorflow/compiler/xla/service/hlo_computation.h" #include "tensorflow/compiler/xla/service/hlo_instruction.h" +#include "tensorflow/compiler/xla/service/hlo_matchers.h" +#include "tensorflow/compiler/xla/service/hlo_parser.h" #include "tensorflow/compiler/xla/shape_util.h" #include "tensorflow/compiler/xla/tests/hlo_test_base.h" #include "tensorflow/compiler/xla/xla_data.pb.h" +#include "tensorflow/core/lib/core/status_test_util.h" #include "absl/types/span.h" #include "tensorflow/compiler/xla/test.h" @@ -30,6 +33,8 @@ namespace xla { namespace { +namespace op = ::xla::testing::opcode_matchers; + class HloModuleTest : public HloTestBase { protected: HloModuleTest() {} @@ -194,6 +199,60 @@ TEST_F(HloModuleTest, UniqueModuleId) { EXPECT_NE(module_a->unique_id(), module_b->unique_id()); } +TEST_F(HloModuleTest, ProtoSerializationWithoutSchedule) { + const string text = R"( +HloModule axpy_module + +ENTRY %axpy.v5 (alpha: f32[], x: f32[2,4], y: f32[2,4]) -> f32[2,4] { + %alpha = f32[] parameter(0) + %x = f32[2,4]{1,0} parameter(1) + %y = f32[2,4]{1,0} parameter(2) + %broadcast = f32[2,4]{1,0} broadcast(f32[] %alpha), dimensions={} + %multiply = f32[2,4]{1,0} multiply(f32[2,4]{1,0} %broadcast, f32[2,4]{1,0} %x) + ROOT %add = f32[2,4]{1,0} add(f32[2,4]{1,0} %multiply, f32[2,4]{1,0} %y) +} +)"; + TF_ASSERT_OK_AND_ASSIGN(std::unique_ptr<HloModule> module, + ParseHloString(text)); + ASSERT_FALSE(module->has_schedule()); + TF_ASSERT_OK_AND_ASSIGN( + std::unique_ptr<HloModule> module_copy, + HloModule::CreateFromProto(module->ToProto(), module->config())); + ASSERT_FALSE(module_copy->has_schedule()); +} + +TEST_F(HloModuleTest, ProtoSerializationWithSchedule) { + const string text = R"( +HloModule axpy_module, is_scheduled=true + +ENTRY %axpy.v5 (alpha: f32[], x: f32[2,4], y: f32[2,4]) -> f32[2,4] { + %alpha = f32[] parameter(0) + %x = f32[2,4]{1,0} parameter(1) + %y = f32[2,4]{1,0} parameter(2) + %broadcast = f32[2,4]{1,0} broadcast(f32[] %alpha), dimensions={} + %multiply = f32[2,4]{1,0} multiply(f32[2,4]{1,0} %broadcast, f32[2,4]{1,0} %x) + ROOT %add = f32[2,4]{1,0} add(f32[2,4]{1,0} %multiply, f32[2,4]{1,0} %y) +} +)"; + TF_ASSERT_OK_AND_ASSIGN(std::unique_ptr<HloModule> module, + ParseHloString(text)); + ASSERT_TRUE(module->has_schedule()); + TF_ASSERT_OK_AND_ASSIGN( + std::unique_ptr<HloModule> module_copy, + HloModule::CreateFromProto(module->ToProto(), module->config())); + ASSERT_TRUE(module_copy->has_schedule()); + TF_ASSERT_OK(module_copy->schedule().Verify()); + EXPECT_EQ(module_copy->schedule().sequences().size(), 1); + ASSERT_TRUE(module_copy->schedule().is_computation_scheduled( + module_copy->entry_computation())); + EXPECT_THAT( + module_copy->schedule() + .sequence(module_copy->entry_computation()) + .instructions(), + ::testing::ElementsAre(op::Parameter(), op::Parameter(), op::Parameter(), + op::Broadcast(), op::Multiply(), op::Add())); +} + } // namespace } // namespace xla diff --git a/tensorflow/compiler/xla/service/hlo_ordering.cc b/tensorflow/compiler/xla/service/hlo_ordering.cc index 2105f7a349..f1dc08bafa 100644 --- a/tensorflow/compiler/xla/service/hlo_ordering.cc +++ b/tensorflow/compiler/xla/service/hlo_ordering.cc @@ -293,23 +293,6 @@ bool HloOrdering::MayInterfere(const HloValue& a, const HloValue& b, !LiveRangeStrictlyBefore(b, a, dataflow); } -HloOrderingProto HloOrdering::ToProto() const { - HloOrderingProto proto; - for (const auto& computation : module_->computations()) { - const std::vector<const HloInstruction*>* sequence = - SequentialOrder(*computation); - if (sequence != nullptr) { - HloOrderingProto::SequentialComputation* proto_computation = - proto.add_sequential_computations(); - proto_computation->set_computation_name(computation->name()); - for (const HloInstruction* instruction : *sequence) { - *proto_computation->add_instruction_names() = instruction->name(); - } - } - } - return proto; -} - PredecessorHloOrdering::PredecessorHloOrdering(const HloModule* module) : HloOrdering(module) {} diff --git a/tensorflow/compiler/xla/service/hlo_ordering.h b/tensorflow/compiler/xla/service/hlo_ordering.h index b21071c4b2..b0361c3f02 100644 --- a/tensorflow/compiler/xla/service/hlo_ordering.h +++ b/tensorflow/compiler/xla/service/hlo_ordering.h @@ -72,10 +72,6 @@ class HloOrdering { virtual string ToString() const = 0; - // Returns the serialized representation of this ordering. - // Only sequential computation orders are represented. - HloOrderingProto ToProto() const; - protected: // Returns true if instruction 'a' executes before instruction 'b'. // Precondition: 'a' and 'b' are in the same computation. diff --git a/tensorflow/compiler/xla/service/hlo_parser.cc b/tensorflow/compiler/xla/service/hlo_parser.cc index 7c848ba7b4..c54360b063 100644 --- a/tensorflow/compiler/xla/service/hlo_parser.cc +++ b/tensorflow/compiler/xla/service/hlo_parser.cc @@ -26,6 +26,7 @@ limitations under the License. #include "tensorflow/compiler/xla/service/hlo_domain_metadata.h" #include "tensorflow/compiler/xla/service/hlo_instructions.h" #include "tensorflow/compiler/xla/service/hlo_opcode.h" +#include "tensorflow/compiler/xla/service/hlo_schedule.h" #include "tensorflow/compiler/xla/service/hlo_sharding_metadata.h" #include "tensorflow/compiler/xla/shape_util.h" #include "tensorflow/compiler/xla/util.h" @@ -44,6 +45,20 @@ using absl::StrJoin; const double kF16max = 65504; +// Creates and returns a schedule created using the order of the instructions in +// the HloComputation::instructions() vectors in the module. +HloSchedule ScheduleFromInstructionOrder(const HloModule* module) { + HloSchedule schedule(module); + for (const HloComputation* computation : module->computations()) { + if (!computation->IsFusionComputation()) { + for (const HloInstruction* instruction : computation->instructions()) { + schedule.GetOrCreateSequence(computation).push_back(instruction); + } + } + } + return schedule; +} + // Parser for the HloModule::ToString() format text. class HloParser { public: @@ -366,9 +381,25 @@ bool HloParser::ParseHloModule() { return false; } + absl::optional<bool> is_scheduled; + std::unordered_map<string, AttrConfig> attrs; + attrs["is_scheduled"] = {/*required=*/false, AttrTy::kBool, &is_scheduled}; + if (!ParseAttributes(attrs)) { + return false; + } + module_ = absl::make_unique<HloModule>(name, config_); - return ParseComputations(); + if (!ParseComputations()) { + return false; + } + + if (is_scheduled.has_value() && *is_scheduled) { + TF_CHECK_OK( + module_->set_schedule(ScheduleFromInstructionOrder(module_.get()))); + } + + return true; } // computations ::= (computation)+ diff --git a/tensorflow/compiler/xla/service/hlo_parser_test.cc b/tensorflow/compiler/xla/service/hlo_parser_test.cc index 43e8736532..cca50fab54 100644 --- a/tensorflow/compiler/xla/service/hlo_parser_test.cc +++ b/tensorflow/compiler/xla/service/hlo_parser_test.cc @@ -1133,8 +1133,21 @@ ENTRY Computation { } )" + }, +// is_scheduled=true attribute +{ +"ScheduledModule", +R"(HloModule scheduled_module, is_scheduled=true + +ENTRY Sort { + keys = f32[1024]{0} parameter(0) + values = s32[1024]{0} parameter(1) + ROOT sorted = (f32[1024]{0}, s32[1024]{0}) sort(keys, values), dimensions={0} } - }); + +)" +} +}); // clang-format on } @@ -1790,5 +1803,94 @@ TEST(HloParserSingleOpTest, ConvolutionTrivialFeatureGroupCount) { EXPECT_EQ(convolution->feature_group_count(), 1); } +TEST_F(HloParserTest, IsScheduledIsFalse) { + const string text = R"( +HloModule axpy_module, is_scheduled=false + +ENTRY %axpy.v5 (alpha: f32[], x: f32[2,4], y: f32[2,4]) -> f32[2,4] { + %alpha = f32[] parameter(0) + %broadcast = f32[2,4]{1,0} broadcast(f32[] %alpha), dimensions={} + %x = f32[2,4]{1,0} parameter(1) + %multiply = f32[2,4]{1,0} multiply(f32[2,4]{1,0} %broadcast, f32[2,4]{1,0} %x) + %y = f32[2,4]{1,0} parameter(2) + ROOT %add = f32[2,4]{1,0} add(f32[2,4]{1,0} %multiply, f32[2,4]{1,0} %y) +} +)"; + TF_ASSERT_OK_AND_ASSIGN(std::unique_ptr<HloModule> module, + ParseHloString(text)); + ASSERT_FALSE(module->has_schedule()); +} + +TEST_F(HloParserTest, IsScheduledNotPresent) { + const string text = R"( +HloModule axpy_module + +ENTRY %axpy.v5 (alpha: f32[], x: f32[2,4], y: f32[2,4]) -> f32[2,4] { + %alpha = f32[] parameter(0) + %broadcast = f32[2,4]{1,0} broadcast(f32[] %alpha), dimensions={} + %x = f32[2,4]{1,0} parameter(1) + %multiply = f32[2,4]{1,0} multiply(f32[2,4]{1,0} %broadcast, f32[2,4]{1,0} %x) + %y = f32[2,4]{1,0} parameter(2) + ROOT %add = f32[2,4]{1,0} add(f32[2,4]{1,0} %multiply, f32[2,4]{1,0} %y) +} +)"; + TF_ASSERT_OK_AND_ASSIGN(std::unique_ptr<HloModule> module, + ParseHloString(text)); + ASSERT_FALSE(module->has_schedule()); +} + +TEST_F(HloParserTest, IsScheduledIsTrue) { + const string text = R"( +HloModule axpy_module, is_scheduled=true + +ENTRY %axpy.v5 (alpha: f32[], x: f32[2,4], y: f32[2,4]) -> f32[2,4] { + %alpha = f32[] parameter(0) + %broadcast = f32[2,4]{1,0} broadcast(f32[] %alpha), dimensions={} + %x = f32[2,4]{1,0} parameter(1) + %multiply = f32[2,4]{1,0} multiply(f32[2,4]{1,0} %broadcast, f32[2,4]{1,0} %x) + %y = f32[2,4]{1,0} parameter(2) + ROOT %add = f32[2,4]{1,0} add(f32[2,4]{1,0} %multiply, f32[2,4]{1,0} %y) +} +)"; + TF_ASSERT_OK_AND_ASSIGN(std::unique_ptr<HloModule> module, + ParseHloString(text)); + ASSERT_TRUE(module->has_schedule()); + TF_ASSERT_OK(module->schedule().Verify()); + EXPECT_EQ(module->schedule().sequences().size(), 1); + ASSERT_TRUE( + module->schedule().is_computation_scheduled(module->entry_computation())); + EXPECT_THAT( + module->schedule().sequence(module->entry_computation()).instructions(), + ::testing::ElementsAre(op::Parameter(), op::Broadcast(), op::Parameter(), + op::Multiply(), op::Parameter(), op::Add())); +} + +TEST_F(HloParserTest, IsScheduledIsTrueDifferentOrder) { + // As above but in with a different schedule order. + const string text = R"( +HloModule axpy_module, is_scheduled=true + +ENTRY %axpy.v5 (alpha: f32[], x: f32[2,4], y: f32[2,4]) -> f32[2,4] { + %alpha = f32[] parameter(0) + %x = f32[2,4]{1,0} parameter(1) + %y = f32[2,4]{1,0} parameter(2) + %broadcast = f32[2,4]{1,0} broadcast(f32[] %alpha), dimensions={} + %multiply = f32[2,4]{1,0} multiply(f32[2,4]{1,0} %broadcast, f32[2,4]{1,0} %x) + ROOT %add = f32[2,4]{1,0} add(f32[2,4]{1,0} %multiply, f32[2,4]{1,0} %y) +} +)"; + TF_ASSERT_OK_AND_ASSIGN(std::unique_ptr<HloModule> module, + ParseHloString(text)); + ASSERT_TRUE(module->has_schedule()); + TF_ASSERT_OK(module->schedule().Verify()); + EXPECT_EQ(module->schedule().sequences().size(), 1); + ASSERT_TRUE( + module->schedule().is_computation_scheduled(module->entry_computation())); + EXPECT_THAT( + module->schedule().sequence(module->entry_computation()).instructions(), + ::testing::ElementsAre(op::Parameter(), op::Parameter(), op::Parameter(), + op::Broadcast(), op::Multiply(), op::Add())); +} + } // namespace } // namespace xla diff --git a/tensorflow/compiler/xla/service/hlo_proto_util.cc b/tensorflow/compiler/xla/service/hlo_proto_util.cc index 3460679558..b9c0b0c4ee 100644 --- a/tensorflow/compiler/xla/service/hlo_proto_util.cc +++ b/tensorflow/compiler/xla/service/hlo_proto_util.cc @@ -23,11 +23,8 @@ namespace xla { HloProto MakeHloProto(const HloModule& module, const BufferAssignment& assignment) { - HloOrderingProto proto_ordering = - assignment.liveness().hlo_ordering().ToProto(); BufferAssignmentProto proto_assignment = assignment.ToProto(); HloProto proto = MakeHloProto(module); - proto.mutable_hlo_ordering()->Swap(&proto_ordering); proto.mutable_buffer_assignment()->Swap(&proto_assignment); return proto; } diff --git a/tensorflow/compiler/xla/service/hlo_schedule.cc b/tensorflow/compiler/xla/service/hlo_schedule.cc index a65b33bf40..3fc5dbeb02 100644 --- a/tensorflow/compiler/xla/service/hlo_schedule.cc +++ b/tensorflow/compiler/xla/service/hlo_schedule.cc @@ -21,12 +21,64 @@ limitations under the License. #include "absl/strings/str_format.h" #include "absl/strings/str_join.h" #include "tensorflow/compiler/xla/map_util.h" +#include "tensorflow/compiler/xla/service/hlo_module.h" #include "tensorflow/compiler/xla/status_macros.h" #include "tensorflow/compiler/xla/util.h" #include "tensorflow/core/lib/gtl/map_util.h" namespace xla { +/* static */ StatusOr<HloSchedule> HloSchedule::CreateFromProto( + const HloModule* module, const HloScheduleProto& proto) { + tensorflow::gtl::FlatMap<int64, const HloComputation*> id_to_computation; + for (const HloComputation* computation : module->computations()) { + id_to_computation[computation->unique_id()] = computation; + } + + HloSchedule schedule(module); + for (const auto& id_sequence : proto.sequences()) { + int64 computation_id = id_sequence.first; + + auto comp_it = id_to_computation.find(computation_id); + TF_RET_CHECK(comp_it != id_to_computation.end()) + << "No computation exists in HLO module with id " << computation_id; + const HloComputation* computation = comp_it->second; + + tensorflow::gtl::FlatMap<int64, const HloInstruction*> id_to_instruction; + for (const HloInstruction* instruction : computation->instructions()) { + id_to_instruction[instruction->unique_id()] = instruction; + } + + HloInstructionSequence& sequence = + schedule.GetOrCreateSequence(computation); + for (const int64 instruction_id : id_sequence.second.instruction_ids()) { + auto instr_it = id_to_instruction.find(instruction_id); + TF_RET_CHECK(instr_it != id_to_instruction.end()) + << "No instruction exists in HLO computation " << computation->name() + << " with id " << instruction_id; + sequence.push_back(instr_it->second); + } + } + TF_RETURN_IF_ERROR(schedule.Verify()); + return std::move(schedule); +} + +StatusOr<HloScheduleProto> HloSchedule::ToProto() const { + TF_RETURN_IF_ERROR(Verify()); + HloScheduleProto proto; + for (const auto& id_sequence : sequences_) { + int64 computation_id = id_sequence.first; + const HloInstructionSequence& sequence = id_sequence.second; + HloScheduleProto::InstructionSequence& proto_sequence = + (*proto.mutable_sequences())[computation_id]; + proto_sequence.mutable_instruction_ids()->Reserve(sequence.size()); + for (const int64 id : sequence.ids()) { + proto_sequence.add_instruction_ids(id); + } + } + return std::move(proto); +} + void HloSchedule::set_sequence( const HloComputation* computation, absl::Span<const HloInstruction* const> sequence) { diff --git a/tensorflow/compiler/xla/service/hlo_schedule.h b/tensorflow/compiler/xla/service/hlo_schedule.h index 21c6988638..270fe6039f 100644 --- a/tensorflow/compiler/xla/service/hlo_schedule.h +++ b/tensorflow/compiler/xla/service/hlo_schedule.h @@ -21,18 +21,20 @@ limitations under the License. #include "absl/types/span.h" #include "tensorflow/compiler/xla/service/hlo_computation.h" #include "tensorflow/compiler/xla/service/hlo_instruction.h" -#include "tensorflow/compiler/xla/service/hlo_module.h" #include "tensorflow/compiler/xla/service/hlo_schedule.h" #include "tensorflow/compiler/xla/status.h" namespace xla { +class HloModule; + // Class representing a sequence of HLO instructions such as the sequential // execution order of an HLO computation. class HloInstructionSequence { public: HloInstructionSequence() = default; - HloInstructionSequence(absl::Span<const HloInstruction* const> instructions) { + explicit HloInstructionSequence( + absl::Span<const HloInstruction* const> instructions) { for (const HloInstruction* instruction : instructions) { push_back(instruction); } @@ -77,7 +79,12 @@ class HloInstructionSequence { // non-fusion computation in the HLO module. class HloSchedule { public: - HloSchedule(const HloModule* module) : module_(module) {} + explicit HloSchedule(const HloModule* module) : module_(module) {} + + // (De)Serialize an HloSchedule to/from a HloScheduleProto. + static StatusOr<HloSchedule> CreateFromProto(const HloModule* module, + const HloScheduleProto& proto); + StatusOr<HloScheduleProto> ToProto() const; // Returns a reference to the sequence for the given computation. const HloInstructionSequence& sequence( |