author     Mark Heffernan <meheff@google.com>               2018-09-06 08:56:46 -0700
committer  TensorFlower Gardener <gardener@tensorflow.org>  2018-09-06 09:01:16 -0700
commit     a41e270641f0613413e1929c9010f32882b4d26b (patch)
tree       e2b5bee9504f8e6319b1d40c8d56fb8f7fbf4cce
parent     35f28c57da8aad4a79503db955b11fed63b1fe34 (diff)
Add HloSchedule to HloModule.
Add HloSchedule as a field on HloModule. This will enable scheduling to be a normal HLO pass and enable some passes, such as copy insertion, to more easily use tighter instruction live ranges based on the schedule.

This change required adding HloSchedule to the "hlo" library because of circular dependencies. Nothing except for tests actually sets the schedule at the moment, but follow-up CLs will add a scheduling pass which will do so.

PiperOrigin-RevId: 211815293
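For context, a minimal sketch (not part of this change; AssignTrivialSchedule is a hypothetical name) of how such a follow-up scheduling pass could install a schedule through the new API. Every call used below (computations(), IsFusionComputation(), MakeInstructionPostOrder(), GetOrCreateSequence(), set_schedule()) appears in this diff or in existing HLO headers:

#include <utility>

#include "tensorflow/compiler/xla/service/hlo_module.h"
#include "tensorflow/compiler/xla/service/hlo_schedule.h"

namespace xla {

// Hypothetical pass body: order each non-fusion computation's instructions
// in post order and install the result as the module's schedule.
Status AssignTrivialSchedule(HloModule* module) {
  HloSchedule schedule(module);
  for (const HloComputation* computation : module->computations()) {
    if (computation->IsFusionComputation()) {
      continue;  // The schedule covers non-fusion computations only.
    }
    for (const HloInstruction* instruction :
         computation->MakeInstructionPostOrder()) {
      schedule.GetOrCreateSequence(computation).push_back(instruction);
    }
  }
  // set_schedule verifies the schedule against the module before storing it.
  return module->set_schedule(std::move(schedule));
}

}  // namespace xla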
-rw-r--r--  tensorflow/compiler/xla/service/BUILD                 30
-rw-r--r--  tensorflow/compiler/xla/service/gpu/BUILD              1
-rw-r--r--  tensorflow/compiler/xla/service/hlo.proto             26
-rw-r--r--  tensorflow/compiler/xla/service/hlo_computation.cc    12
-rw-r--r--  tensorflow/compiler/xla/service/hlo_computation.h      5
-rw-r--r--  tensorflow/compiler/xla/service/hlo_module.cc         33
-rw-r--r--  tensorflow/compiler/xla/service/hlo_module.h          20
-rw-r--r--  tensorflow/compiler/xla/service/hlo_module_test.cc    59
-rw-r--r--  tensorflow/compiler/xla/service/hlo_ordering.cc       17
-rw-r--r--  tensorflow/compiler/xla/service/hlo_ordering.h         4
-rw-r--r--  tensorflow/compiler/xla/service/hlo_parser.cc         33
-rw-r--r--  tensorflow/compiler/xla/service/hlo_parser_test.cc   104
-rw-r--r--  tensorflow/compiler/xla/service/hlo_proto_util.cc      3
-rw-r--r--  tensorflow/compiler/xla/service/hlo_schedule.cc       52
-rw-r--r--  tensorflow/compiler/xla/service/hlo_schedule.h        13
15 files changed, 346 insertions, 66 deletions
diff --git a/tensorflow/compiler/xla/service/BUILD b/tensorflow/compiler/xla/service/BUILD
index ab86dce510..b8ee6a093e 100644
--- a/tensorflow/compiler/xla/service/BUILD
+++ b/tensorflow/compiler/xla/service/BUILD
@@ -291,6 +291,7 @@ cc_library(
"hlo_instructions.cc",
"hlo_module.cc",
"hlo_opcode.cc",
+ "hlo_schedule.cc",
"hlo_sharding.cc",
],
hdrs = [
@@ -303,6 +304,7 @@ cc_library(
"hlo_instructions.h",
"hlo_module.h",
"hlo_opcode.h",
+ "hlo_schedule.h",
"hlo_sharding.h",
],
deps = [
@@ -331,6 +333,8 @@ cc_library(
"@com_google_absl//absl/container:inlined_vector",
"@com_google_absl//absl/memory",
"@com_google_absl//absl/strings",
+ "@com_google_absl//absl/strings:str_format",
+ "@com_google_absl//absl/types:optional",
"@com_google_absl//absl/types:span",
],
)
@@ -1037,7 +1041,6 @@ tf_cc_test(
":flatten_call_graph",
":hlo",
":hlo_ordering",
- ":hlo_schedule",
":hlo_scheduling",
"//tensorflow/compiler/xla:literal",
"//tensorflow/compiler/xla:shape_util",
@@ -1065,7 +1068,6 @@ cc_library(
":hlo",
":hlo_dataflow_analysis",
":hlo_proto",
- ":hlo_schedule",
":hlo_value",
"//tensorflow/compiler/xla:shape_util",
"//tensorflow/compiler/xla:status_macros",
@@ -1086,7 +1088,6 @@ tf_cc_test(
":hlo",
":hlo_dataflow_analysis",
":hlo_ordering",
- ":hlo_schedule",
":hlo_scheduling",
"//tensorflow/compiler/xla:shape_util",
"//tensorflow/compiler/xla:types",
@@ -1108,7 +1109,6 @@ cc_library(
":hlo",
":hlo_ordering",
":hlo_proto",
- ":hlo_schedule",
":tuple_points_to_analysis",
"//tensorflow/compiler/xla:statusor",
"//tensorflow/compiler/xla:util",
@@ -1177,22 +1177,6 @@ cc_library(
],
)
-cc_library(
- name = "hlo_schedule",
- srcs = ["hlo_schedule.cc"],
- hdrs = ["hlo_schedule.h"],
- deps = [
- ":hlo",
- "//tensorflow/compiler/xla:status",
- "//tensorflow/compiler/xla:status_macros",
- "//tensorflow/compiler/xla:util",
- "//tensorflow/core:lib_internal",
- "@com_google_absl//absl/strings",
- "@com_google_absl//absl/strings:str_format",
- "@com_google_absl//absl/types:span",
- ],
-)
-
tf_cc_test(
name = "hlo_schedule_test",
srcs = ["hlo_schedule_test.cc"],
@@ -1202,7 +1186,6 @@ tf_cc_test(
":hlo_dce",
":hlo_ordering",
":hlo_parser",
- ":hlo_schedule",
":hlo_scheduling",
"//tensorflow/compiler/xla:shape_util",
"//tensorflow/compiler/xla:types",
@@ -1222,7 +1205,6 @@ cc_library(
":heap_simulator",
":hlo",
":hlo_ordering",
- ":hlo_schedule",
":logical_buffer",
":tuple_points_to_analysis",
"//tensorflow/compiler/xla:shape_util",
@@ -1969,6 +1951,8 @@ tf_cc_test(
srcs = ["hlo_module_test.cc"],
deps = [
":hlo",
+ ":hlo_matchers",
+ ":hlo_parser",
"//tensorflow/compiler/xla:literal",
"//tensorflow/compiler/xla:shape_util",
"//tensorflow/compiler/xla:test",
@@ -1977,6 +1961,7 @@ tf_cc_test(
"//tensorflow/compiler/xla/tests:hlo_test_base",
"//tensorflow/compiler/xla/tests:xla_internal_test_main",
"//tensorflow/core:lib",
+ "//tensorflow/core:test",
"@com_google_absl//absl/memory",
"@com_google_absl//absl/types:span",
],
@@ -2413,7 +2398,6 @@ cc_library(
":hlo",
":hlo_dce",
":hlo_ordering",
- ":hlo_schedule",
":hlo_scheduling",
":logical_buffer",
":tuple_points_to_analysis",
diff --git a/tensorflow/compiler/xla/service/gpu/BUILD b/tensorflow/compiler/xla/service/gpu/BUILD
index 13ccff35f8..a68b7a1bef 100644
--- a/tensorflow/compiler/xla/service/gpu/BUILD
+++ b/tensorflow/compiler/xla/service/gpu/BUILD
@@ -813,7 +813,6 @@ cc_library(
"//tensorflow/compiler/xla/service:hlo",
"//tensorflow/compiler/xla/service:hlo_ordering",
"//tensorflow/compiler/xla/service:hlo_reachability",
- "//tensorflow/compiler/xla/service:hlo_schedule",
"//tensorflow/compiler/xla/service:hlo_scheduling",
"@com_google_absl//absl/memory",
],
diff --git a/tensorflow/compiler/xla/service/hlo.proto b/tensorflow/compiler/xla/service/hlo.proto
index 99d0cf50ca..93ec2c9438 100644
--- a/tensorflow/compiler/xla/service/hlo.proto
+++ b/tensorflow/compiler/xla/service/hlo.proto
@@ -199,6 +199,17 @@ message HloComputationProto {
int64 root_id = 6;
}
+// Serialization of an HLO schedule. An HLO schedule contains a total order of
+// instructions for each non-fusion computation in the module.
+message HloScheduleProto {
+ message InstructionSequence {
+ repeated int64 instruction_ids = 1;
+ }
+
+ // Map from computation id to sequence.
+ map<int64, InstructionSequence> sequences = 1;
+}
+
// Serialization of HloModule.
message HloModuleProto {
string name = 1;
@@ -214,16 +225,9 @@ message HloModuleProto {
// The id of this module.
int64 id = 5;
-}
-// Serialization of HloOrdering.
-message HloOrderingProto {
- // NOTE: currently only sequential orderings are serialized.
- message SequentialComputation {
- string computation_name = 1;
- repeated string instruction_names = 2;
- }
- repeated SequentialComputation sequential_computations = 1;
+ // The schedule for this module.
+ HloScheduleProto schedule = 7;
}
// Serialization of LogicalBuffer.
@@ -322,8 +326,10 @@ message BufferAssignmentProto {
// Grouping message that contains all of the information above.
message HloProto {
+ reserved 2;
+ reserved "hlo_ordering";
+
HloModuleProto hlo_module = 1;
- HloOrderingProto hlo_ordering = 2;
BufferAssignmentProto buffer_assignment = 3;
}
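Since the schedule now rides along inside HloModuleProto, a serialization round trip preserves it. A sketch under the assumption that `module` is already scheduled (RoundTrip is a hypothetical helper; ToProto(), CreateFromProto(), config(), and has_schedule() are all shown elsewhere in this diff):

#include <memory>
#include <utility>

#include "tensorflow/compiler/xla/service/hlo_module.h"
#include "tensorflow/compiler/xla/status_macros.h"
#include "tensorflow/compiler/xla/statusor.h"
#include "tensorflow/core/platform/logging.h"

namespace xla {

// Hypothetical helper: serialize a scheduled module and rebuild it.
StatusOr<std::unique_ptr<HloModule>> RoundTrip(const HloModule& module) {
  // ToProto() embeds an HloScheduleProto when the module has a schedule.
  HloModuleProto proto = module.ToProto();
  TF_ASSIGN_OR_RETURN(std::unique_ptr<HloModule> copy,
                      HloModule::CreateFromProto(proto, module.config()));
  // CreateFromProto verifies and re-installs the deserialized schedule.
  CHECK(copy->has_schedule());
  return std::move(copy);
}

}  // namespace xla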
diff --git a/tensorflow/compiler/xla/service/hlo_computation.cc b/tensorflow/compiler/xla/service/hlo_computation.cc
index fe7f2be888..233d2199d1 100644
--- a/tensorflow/compiler/xla/service/hlo_computation.cc
+++ b/tensorflow/compiler/xla/service/hlo_computation.cc
@@ -464,6 +464,14 @@ std::vector<HloComputation*> HloComputation::MakeEmbeddedComputationsList()
}
string HloComputation::ToString(const HloPrintOptions& options) const {
+ return ToString(options, MakeInstructionPostOrder());
+}
+
+string HloComputation::ToString(
+ const HloPrintOptions& options,
+ absl::Span<const HloInstruction* const> instruction_order) const {
+ CHECK_EQ(instruction_order.size(), instruction_count());
+
std::ostringstream s;
for (int i = 0; i < options.indent_amount(); i++) {
s << " ";
@@ -486,7 +494,9 @@ string HloComputation::ToString(const HloPrintOptions& options) const {
new_options.set_indent_amount(options.indent_amount() + 1)
.set_is_in_nested_computation(true);
CanonicalNameMap name_map;
- for (const HloInstruction* instruction : MakeInstructionPostOrder()) {
+ for (const HloInstruction* instruction : instruction_order) {
+ CHECK_EQ(this, instruction->parent());
+
for (int i = 0; i < new_options.indent_amount(); i++) {
s << " ";
}
diff --git a/tensorflow/compiler/xla/service/hlo_computation.h b/tensorflow/compiler/xla/service/hlo_computation.h
index fe2d3bbbe5..91c5234a6f 100644
--- a/tensorflow/compiler/xla/service/hlo_computation.h
+++ b/tensorflow/compiler/xla/service/hlo_computation.h
@@ -170,6 +170,11 @@ class HloComputation {
string ToString() const { return ToString(HloPrintOptions()); }
string ToString(const HloPrintOptions& options) const;
+ // Overload which accepts an order to emit the instructions in.
+ string ToString(
+ const HloPrintOptions& options,
+ absl::Span<const HloInstruction* const> instruction_order) const;
+
// Returns a serialized representation of this computation.
HloComputationProto ToProto() const;
diff --git a/tensorflow/compiler/xla/service/hlo_module.cc b/tensorflow/compiler/xla/service/hlo_module.cc
index 3a1bc4e328..cfe906d9c5 100644
--- a/tensorflow/compiler/xla/service/hlo_module.cc
+++ b/tensorflow/compiler/xla/service/hlo_module.cc
@@ -26,6 +26,7 @@ limitations under the License.
#include "absl/memory/memory.h"
#include "absl/strings/str_cat.h"
#include "tensorflow/compiler/xla/map_util.h"
+#include "tensorflow/compiler/xla/service/hlo_schedule.h"
#include "tensorflow/compiler/xla/shape_util.h"
#include "tensorflow/compiler/xla/types.h"
#include "tensorflow/core/lib/gtl/map_util.h"
@@ -50,6 +51,13 @@ StatusOr<HloInstruction*> HloModule::LaunderConstInstructionFromModule(
return const_cast<HloInstruction*>(hlo);
}
+Status HloModule::set_schedule(HloSchedule schedule) {
+ TF_RET_CHECK(schedule.module() == this);
+ TF_RETURN_IF_ERROR(schedule.Verify());
+ schedule_ = std::move(schedule);
+ return Status::OK();
+}
+
HloComputation* HloModule::AddComputationInternal(
std::unique_ptr<HloComputation> computation, bool is_entry,
bool uniquify_names) {
@@ -198,12 +206,23 @@ void HloModule::ReplaceComputations(
string HloModule::ToString(const HloPrintOptions& options) const {
std::ostringstream s;
- s << "HloModule " << name() << "\n\n";
+ s << "HloModule " << name();
+ if (has_schedule()) {
+ TF_CHECK_OK(schedule().Verify());
+ s << ", is_scheduled=true";
+ }
+ s << "\n\n";
for (const HloComputation* computation : MakeComputationPostOrder()) {
if (computation == entry_computation()) {
s << "ENTRY ";
}
- s << computation->ToString(options) << "\n\n";
+ if (has_schedule() && schedule().is_computation_scheduled(computation)) {
+ s << computation->ToString(
+ options, schedule().sequence(computation).instructions())
+ << "\n\n";
+ } else {
+ s << computation->ToString(options) << "\n\n";
+ }
}
return s.str();
}
@@ -221,6 +240,9 @@ HloModuleProto HloModule::ToProto() const {
}
proto.add_computations()->Swap(&computation_proto);
}
+ if (has_schedule()) {
+ *proto.mutable_schedule() = schedule().ToProto().ValueOrDie();
+ }
return proto;
}
@@ -309,6 +331,13 @@ StatusOr<std::unique_ptr<HloModule>> HloModule::CreateFromProto(
}
}
+ if (proto.has_schedule()) {
+ TF_ASSIGN_OR_RETURN(
+ HloSchedule schedule,
+ HloSchedule::CreateFromProto(module.get(), proto.schedule()));
+ TF_RETURN_IF_ERROR(module->set_schedule(std::move(schedule)));
+ }
+
return std::move(module);
}
diff --git a/tensorflow/compiler/xla/service/hlo_module.h b/tensorflow/compiler/xla/service/hlo_module.h
index 3c3371426b..26fd1b2438 100644
--- a/tensorflow/compiler/xla/service/hlo_module.h
+++ b/tensorflow/compiler/xla/service/hlo_module.h
@@ -25,6 +25,7 @@ limitations under the License.
#include <vector>
#include "absl/strings/string_view.h"
+#include "absl/types/optional.h"
#include "absl/types/span.h"
#include "tensorflow/compiler/xla/iterator_util.h"
#include "tensorflow/compiler/xla/service/hlo.pb.h"
@@ -32,6 +33,7 @@ limitations under the License.
#include "tensorflow/compiler/xla/service/hlo_computation.h"
#include "tensorflow/compiler/xla/service/hlo_instruction.h"
#include "tensorflow/compiler/xla/service/hlo_module_config.h"
+#include "tensorflow/compiler/xla/service/hlo_schedule.h"
#include "tensorflow/compiler/xla/service/name_uniquer.h"
#include "tensorflow/compiler/xla/types.h"
#include "tensorflow/core/lib/gtl/iterator_range.h"
@@ -235,6 +237,19 @@ class HloModule {
StatusOr<HloInstruction*> LaunderConstInstructionFromModule(
const HloInstruction* hlo);
+ // Sets the schedule of the module to the given schedule.
+ Status set_schedule(HloSchedule schedule);
+
+ // Clears the schedule of the module.
+ void clear_schedule() { schedule_.reset(); }
+
+ // Returns true if the module has a schedule set.
+ bool has_schedule() const { return schedule_.has_value(); }
+
+ // Returns the schedule of the module. CHECK fails if no schedule is set.
+ const HloSchedule& schedule() const { return *schedule_; }
+ HloSchedule& schedule() { return *schedule_; }
+
private:
HloComputation* AddComputationInternal(
std::unique_ptr<HloComputation> computation, bool is_entry,
@@ -262,6 +277,11 @@ class HloModule {
static std::atomic<int> next_unique_module_id_;
// A unique id to label modules with.
int unique_id_;
+
+ // The HloSchedule of the module. The schedule, if it exists, contains a
+ // sequential order of instructions for each non-fusion computation in the
+ // module.
+ absl::optional<HloSchedule> schedule_;
};
} // namespace xla
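Illustrative use of the accessors above (InspectSchedule is a hypothetical name; schedule() CHECK-fails on an unscheduled module, so has_schedule() guards it):

#include "tensorflow/compiler/xla/service/hlo_module.h"
#include "tensorflow/core/lib/core/status.h"

namespace xla {

// Hypothetical example of the new schedule accessors on HloModule.
void InspectSchedule(HloModule* module) {
  if (!module->has_schedule()) {
    return;  // schedule() would CHECK-fail here.
  }
  TF_CHECK_OK(module->schedule().Verify());
  // Drop the stored order, e.g. before a pass that reorders instructions.
  module->clear_schedule();
}

}  // namespace xla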
diff --git a/tensorflow/compiler/xla/service/hlo_module_test.cc b/tensorflow/compiler/xla/service/hlo_module_test.cc
index 4bc1bacd7d..400bd4d947 100644
--- a/tensorflow/compiler/xla/service/hlo_module_test.cc
+++ b/tensorflow/compiler/xla/service/hlo_module_test.cc
@@ -19,9 +19,12 @@ limitations under the License.
#include "tensorflow/compiler/xla/literal.h"
#include "tensorflow/compiler/xla/service/hlo_computation.h"
#include "tensorflow/compiler/xla/service/hlo_instruction.h"
+#include "tensorflow/compiler/xla/service/hlo_matchers.h"
+#include "tensorflow/compiler/xla/service/hlo_parser.h"
#include "tensorflow/compiler/xla/shape_util.h"
#include "tensorflow/compiler/xla/tests/hlo_test_base.h"
#include "tensorflow/compiler/xla/xla_data.pb.h"
+#include "tensorflow/core/lib/core/status_test_util.h"
#include "absl/types/span.h"
#include "tensorflow/compiler/xla/test.h"
@@ -30,6 +33,8 @@ namespace xla {
namespace {
+namespace op = ::xla::testing::opcode_matchers;
+
class HloModuleTest : public HloTestBase {
protected:
HloModuleTest() {}
@@ -194,6 +199,60 @@ TEST_F(HloModuleTest, UniqueModuleId) {
EXPECT_NE(module_a->unique_id(), module_b->unique_id());
}
+TEST_F(HloModuleTest, ProtoSerializationWithoutSchedule) {
+ const string text = R"(
+HloModule axpy_module
+
+ENTRY %axpy.v5 (alpha: f32[], x: f32[2,4], y: f32[2,4]) -> f32[2,4] {
+ %alpha = f32[] parameter(0)
+ %x = f32[2,4]{1,0} parameter(1)
+ %y = f32[2,4]{1,0} parameter(2)
+ %broadcast = f32[2,4]{1,0} broadcast(f32[] %alpha), dimensions={}
+ %multiply = f32[2,4]{1,0} multiply(f32[2,4]{1,0} %broadcast, f32[2,4]{1,0} %x)
+ ROOT %add = f32[2,4]{1,0} add(f32[2,4]{1,0} %multiply, f32[2,4]{1,0} %y)
+}
+)";
+ TF_ASSERT_OK_AND_ASSIGN(std::unique_ptr<HloModule> module,
+ ParseHloString(text));
+ ASSERT_FALSE(module->has_schedule());
+ TF_ASSERT_OK_AND_ASSIGN(
+ std::unique_ptr<HloModule> module_copy,
+ HloModule::CreateFromProto(module->ToProto(), module->config()));
+ ASSERT_FALSE(module_copy->has_schedule());
+}
+
+TEST_F(HloModuleTest, ProtoSerializationWithSchedule) {
+ const string text = R"(
+HloModule axpy_module, is_scheduled=true
+
+ENTRY %axpy.v5 (alpha: f32[], x: f32[2,4], y: f32[2,4]) -> f32[2,4] {
+ %alpha = f32[] parameter(0)
+ %x = f32[2,4]{1,0} parameter(1)
+ %y = f32[2,4]{1,0} parameter(2)
+ %broadcast = f32[2,4]{1,0} broadcast(f32[] %alpha), dimensions={}
+ %multiply = f32[2,4]{1,0} multiply(f32[2,4]{1,0} %broadcast, f32[2,4]{1,0} %x)
+ ROOT %add = f32[2,4]{1,0} add(f32[2,4]{1,0} %multiply, f32[2,4]{1,0} %y)
+}
+)";
+ TF_ASSERT_OK_AND_ASSIGN(std::unique_ptr<HloModule> module,
+ ParseHloString(text));
+ ASSERT_TRUE(module->has_schedule());
+ TF_ASSERT_OK_AND_ASSIGN(
+ std::unique_ptr<HloModule> module_copy,
+ HloModule::CreateFromProto(module->ToProto(), module->config()));
+ ASSERT_TRUE(module_copy->has_schedule());
+ TF_ASSERT_OK(module_copy->schedule().Verify());
+ EXPECT_EQ(module_copy->schedule().sequences().size(), 1);
+ ASSERT_TRUE(module_copy->schedule().is_computation_scheduled(
+ module_copy->entry_computation()));
+ EXPECT_THAT(
+ module_copy->schedule()
+ .sequence(module_copy->entry_computation())
+ .instructions(),
+ ::testing::ElementsAre(op::Parameter(), op::Parameter(), op::Parameter(),
+ op::Broadcast(), op::Multiply(), op::Add()));
+}
+
} // namespace
} // namespace xla
diff --git a/tensorflow/compiler/xla/service/hlo_ordering.cc b/tensorflow/compiler/xla/service/hlo_ordering.cc
index 2105f7a349..f1dc08bafa 100644
--- a/tensorflow/compiler/xla/service/hlo_ordering.cc
+++ b/tensorflow/compiler/xla/service/hlo_ordering.cc
@@ -293,23 +293,6 @@ bool HloOrdering::MayInterfere(const HloValue& a, const HloValue& b,
!LiveRangeStrictlyBefore(b, a, dataflow);
}
-HloOrderingProto HloOrdering::ToProto() const {
- HloOrderingProto proto;
- for (const auto& computation : module_->computations()) {
- const std::vector<const HloInstruction*>* sequence =
- SequentialOrder(*computation);
- if (sequence != nullptr) {
- HloOrderingProto::SequentialComputation* proto_computation =
- proto.add_sequential_computations();
- proto_computation->set_computation_name(computation->name());
- for (const HloInstruction* instruction : *sequence) {
- *proto_computation->add_instruction_names() = instruction->name();
- }
- }
- }
- return proto;
-}
-
PredecessorHloOrdering::PredecessorHloOrdering(const HloModule* module)
: HloOrdering(module) {}
diff --git a/tensorflow/compiler/xla/service/hlo_ordering.h b/tensorflow/compiler/xla/service/hlo_ordering.h
index b21071c4b2..b0361c3f02 100644
--- a/tensorflow/compiler/xla/service/hlo_ordering.h
+++ b/tensorflow/compiler/xla/service/hlo_ordering.h
@@ -72,10 +72,6 @@ class HloOrdering {
virtual string ToString() const = 0;
- // Returns the serialized representation of this ordering.
- // Only sequential computation orders are represented.
- HloOrderingProto ToProto() const;
-
protected:
// Returns true if instruction 'a' executes before instruction 'b'.
// Precondition: 'a' and 'b' are in the same computation.
diff --git a/tensorflow/compiler/xla/service/hlo_parser.cc b/tensorflow/compiler/xla/service/hlo_parser.cc
index 7c848ba7b4..c54360b063 100644
--- a/tensorflow/compiler/xla/service/hlo_parser.cc
+++ b/tensorflow/compiler/xla/service/hlo_parser.cc
@@ -26,6 +26,7 @@ limitations under the License.
#include "tensorflow/compiler/xla/service/hlo_domain_metadata.h"
#include "tensorflow/compiler/xla/service/hlo_instructions.h"
#include "tensorflow/compiler/xla/service/hlo_opcode.h"
+#include "tensorflow/compiler/xla/service/hlo_schedule.h"
#include "tensorflow/compiler/xla/service/hlo_sharding_metadata.h"
#include "tensorflow/compiler/xla/shape_util.h"
#include "tensorflow/compiler/xla/util.h"
@@ -44,6 +45,20 @@ using absl::StrJoin;
const double kF16max = 65504;
+// Creates and returns a schedule created using the order of the instructions in
+// the HloComputation::instructions() vectors in the module.
+HloSchedule ScheduleFromInstructionOrder(const HloModule* module) {
+ HloSchedule schedule(module);
+ for (const HloComputation* computation : module->computations()) {
+ if (!computation->IsFusionComputation()) {
+ for (const HloInstruction* instruction : computation->instructions()) {
+ schedule.GetOrCreateSequence(computation).push_back(instruction);
+ }
+ }
+ }
+ return schedule;
+}
+
// Parser for the HloModule::ToString() format text.
class HloParser {
public:
@@ -366,9 +381,25 @@ bool HloParser::ParseHloModule() {
return false;
}
+ absl::optional<bool> is_scheduled;
+ std::unordered_map<string, AttrConfig> attrs;
+ attrs["is_scheduled"] = {/*required=*/false, AttrTy::kBool, &is_scheduled};
+ if (!ParseAttributes(attrs)) {
+ return false;
+ }
+
module_ = absl::make_unique<HloModule>(name, config_);
- return ParseComputations();
+ if (!ParseComputations()) {
+ return false;
+ }
+
+ if (is_scheduled.has_value() && *is_scheduled) {
+ TF_CHECK_OK(
+ module_->set_schedule(ScheduleFromInstructionOrder(module_.get())));
+ }
+
+ return true;
}
// computations ::= (computation)+
diff --git a/tensorflow/compiler/xla/service/hlo_parser_test.cc b/tensorflow/compiler/xla/service/hlo_parser_test.cc
index 43e8736532..cca50fab54 100644
--- a/tensorflow/compiler/xla/service/hlo_parser_test.cc
+++ b/tensorflow/compiler/xla/service/hlo_parser_test.cc
@@ -1133,8 +1133,21 @@ ENTRY Computation {
}
)"
+ },
+// is_scheduled=true attribute
+{
+"ScheduledModule",
+R"(HloModule scheduled_module, is_scheduled=true
+
+ENTRY Sort {
+ keys = f32[1024]{0} parameter(0)
+ values = s32[1024]{0} parameter(1)
+ ROOT sorted = (f32[1024]{0}, s32[1024]{0}) sort(keys, values), dimensions={0}
}
- });
+
+)"
+}
+});
// clang-format on
}
@@ -1790,5 +1803,94 @@ TEST(HloParserSingleOpTest, ConvolutionTrivialFeatureGroupCount) {
EXPECT_EQ(convolution->feature_group_count(), 1);
}
+TEST_F(HloParserTest, IsScheduledIsFalse) {
+ const string text = R"(
+HloModule axpy_module, is_scheduled=false
+
+ENTRY %axpy.v5 (alpha: f32[], x: f32[2,4], y: f32[2,4]) -> f32[2,4] {
+ %alpha = f32[] parameter(0)
+ %broadcast = f32[2,4]{1,0} broadcast(f32[] %alpha), dimensions={}
+ %x = f32[2,4]{1,0} parameter(1)
+ %multiply = f32[2,4]{1,0} multiply(f32[2,4]{1,0} %broadcast, f32[2,4]{1,0} %x)
+ %y = f32[2,4]{1,0} parameter(2)
+ ROOT %add = f32[2,4]{1,0} add(f32[2,4]{1,0} %multiply, f32[2,4]{1,0} %y)
+}
+)";
+ TF_ASSERT_OK_AND_ASSIGN(std::unique_ptr<HloModule> module,
+ ParseHloString(text));
+ ASSERT_FALSE(module->has_schedule());
+}
+
+TEST_F(HloParserTest, IsScheduledNotPresent) {
+ const string text = R"(
+HloModule axpy_module
+
+ENTRY %axpy.v5 (alpha: f32[], x: f32[2,4], y: f32[2,4]) -> f32[2,4] {
+ %alpha = f32[] parameter(0)
+ %broadcast = f32[2,4]{1,0} broadcast(f32[] %alpha), dimensions={}
+ %x = f32[2,4]{1,0} parameter(1)
+ %multiply = f32[2,4]{1,0} multiply(f32[2,4]{1,0} %broadcast, f32[2,4]{1,0} %x)
+ %y = f32[2,4]{1,0} parameter(2)
+ ROOT %add = f32[2,4]{1,0} add(f32[2,4]{1,0} %multiply, f32[2,4]{1,0} %y)
+}
+)";
+ TF_ASSERT_OK_AND_ASSIGN(std::unique_ptr<HloModule> module,
+ ParseHloString(text));
+ ASSERT_FALSE(module->has_schedule());
+}
+
+TEST_F(HloParserTest, IsScheduledIsTrue) {
+ const string text = R"(
+HloModule axpy_module, is_scheduled=true
+
+ENTRY %axpy.v5 (alpha: f32[], x: f32[2,4], y: f32[2,4]) -> f32[2,4] {
+ %alpha = f32[] parameter(0)
+ %broadcast = f32[2,4]{1,0} broadcast(f32[] %alpha), dimensions={}
+ %x = f32[2,4]{1,0} parameter(1)
+ %multiply = f32[2,4]{1,0} multiply(f32[2,4]{1,0} %broadcast, f32[2,4]{1,0} %x)
+ %y = f32[2,4]{1,0} parameter(2)
+ ROOT %add = f32[2,4]{1,0} add(f32[2,4]{1,0} %multiply, f32[2,4]{1,0} %y)
+}
+)";
+ TF_ASSERT_OK_AND_ASSIGN(std::unique_ptr<HloModule> module,
+ ParseHloString(text));
+ ASSERT_TRUE(module->has_schedule());
+ TF_ASSERT_OK(module->schedule().Verify());
+ EXPECT_EQ(module->schedule().sequences().size(), 1);
+ ASSERT_TRUE(
+ module->schedule().is_computation_scheduled(module->entry_computation()));
+ EXPECT_THAT(
+ module->schedule().sequence(module->entry_computation()).instructions(),
+ ::testing::ElementsAre(op::Parameter(), op::Broadcast(), op::Parameter(),
+ op::Multiply(), op::Parameter(), op::Add()));
+}
+
+TEST_F(HloParserTest, IsScheduledIsTrueDifferentOrder) {
+ // As above but with a different schedule order.
+ const string text = R"(
+HloModule axpy_module, is_scheduled=true
+
+ENTRY %axpy.v5 (alpha: f32[], x: f32[2,4], y: f32[2,4]) -> f32[2,4] {
+ %alpha = f32[] parameter(0)
+ %x = f32[2,4]{1,0} parameter(1)
+ %y = f32[2,4]{1,0} parameter(2)
+ %broadcast = f32[2,4]{1,0} broadcast(f32[] %alpha), dimensions={}
+ %multiply = f32[2,4]{1,0} multiply(f32[2,4]{1,0} %broadcast, f32[2,4]{1,0} %x)
+ ROOT %add = f32[2,4]{1,0} add(f32[2,4]{1,0} %multiply, f32[2,4]{1,0} %y)
+}
+)";
+ TF_ASSERT_OK_AND_ASSIGN(std::unique_ptr<HloModule> module,
+ ParseHloString(text));
+ ASSERT_TRUE(module->has_schedule());
+ TF_ASSERT_OK(module->schedule().Verify());
+ EXPECT_EQ(module->schedule().sequences().size(), 1);
+ ASSERT_TRUE(
+ module->schedule().is_computation_scheduled(module->entry_computation()));
+ EXPECT_THAT(
+ module->schedule().sequence(module->entry_computation()).instructions(),
+ ::testing::ElementsAre(op::Parameter(), op::Parameter(), op::Parameter(),
+ op::Broadcast(), op::Multiply(), op::Add()));
+}
+
} // namespace
} // namespace xla
diff --git a/tensorflow/compiler/xla/service/hlo_proto_util.cc b/tensorflow/compiler/xla/service/hlo_proto_util.cc
index 3460679558..b9c0b0c4ee 100644
--- a/tensorflow/compiler/xla/service/hlo_proto_util.cc
+++ b/tensorflow/compiler/xla/service/hlo_proto_util.cc
@@ -23,11 +23,8 @@ namespace xla {
HloProto MakeHloProto(const HloModule& module,
const BufferAssignment& assignment) {
- HloOrderingProto proto_ordering =
- assignment.liveness().hlo_ordering().ToProto();
BufferAssignmentProto proto_assignment = assignment.ToProto();
HloProto proto = MakeHloProto(module);
- proto.mutable_hlo_ordering()->Swap(&proto_ordering);
proto.mutable_buffer_assignment()->Swap(&proto_assignment);
return proto;
}
diff --git a/tensorflow/compiler/xla/service/hlo_schedule.cc b/tensorflow/compiler/xla/service/hlo_schedule.cc
index a65b33bf40..3fc5dbeb02 100644
--- a/tensorflow/compiler/xla/service/hlo_schedule.cc
+++ b/tensorflow/compiler/xla/service/hlo_schedule.cc
@@ -21,12 +21,64 @@ limitations under the License.
#include "absl/strings/str_format.h"
#include "absl/strings/str_join.h"
#include "tensorflow/compiler/xla/map_util.h"
+#include "tensorflow/compiler/xla/service/hlo_module.h"
#include "tensorflow/compiler/xla/status_macros.h"
#include "tensorflow/compiler/xla/util.h"
#include "tensorflow/core/lib/gtl/map_util.h"
namespace xla {
+/* static */ StatusOr<HloSchedule> HloSchedule::CreateFromProto(
+ const HloModule* module, const HloScheduleProto& proto) {
+ tensorflow::gtl::FlatMap<int64, const HloComputation*> id_to_computation;
+ for (const HloComputation* computation : module->computations()) {
+ id_to_computation[computation->unique_id()] = computation;
+ }
+
+ HloSchedule schedule(module);
+ for (const auto& id_sequence : proto.sequences()) {
+ int64 computation_id = id_sequence.first;
+
+ auto comp_it = id_to_computation.find(computation_id);
+ TF_RET_CHECK(comp_it != id_to_computation.end())
+ << "No computation exists in HLO module with id " << computation_id;
+ const HloComputation* computation = comp_it->second;
+
+ tensorflow::gtl::FlatMap<int64, const HloInstruction*> id_to_instruction;
+ for (const HloInstruction* instruction : computation->instructions()) {
+ id_to_instruction[instruction->unique_id()] = instruction;
+ }
+
+ HloInstructionSequence& sequence =
+ schedule.GetOrCreateSequence(computation);
+ for (const int64 instruction_id : id_sequence.second.instruction_ids()) {
+ auto instr_it = id_to_instruction.find(instruction_id);
+ TF_RET_CHECK(instr_it != id_to_instruction.end())
+ << "No instruction exists in HLO computation " << computation->name()
+ << " with id " << instruction_id;
+ sequence.push_back(instr_it->second);
+ }
+ }
+ TF_RETURN_IF_ERROR(schedule.Verify());
+ return std::move(schedule);
+}
+
+StatusOr<HloScheduleProto> HloSchedule::ToProto() const {
+ TF_RETURN_IF_ERROR(Verify());
+ HloScheduleProto proto;
+ for (const auto& id_sequence : sequences_) {
+ int64 computation_id = id_sequence.first;
+ const HloInstructionSequence& sequence = id_sequence.second;
+ HloScheduleProto::InstructionSequence& proto_sequence =
+ (*proto.mutable_sequences())[computation_id];
+ proto_sequence.mutable_instruction_ids()->Reserve(sequence.size());
+ for (const int64 id : sequence.ids()) {
+ proto_sequence.add_instruction_ids(id);
+ }
+ }
+ return std::move(proto);
+}
+
void HloSchedule::set_sequence(
const HloComputation* computation,
absl::Span<const HloInstruction* const> sequence) {
diff --git a/tensorflow/compiler/xla/service/hlo_schedule.h b/tensorflow/compiler/xla/service/hlo_schedule.h
index 21c6988638..270fe6039f 100644
--- a/tensorflow/compiler/xla/service/hlo_schedule.h
+++ b/tensorflow/compiler/xla/service/hlo_schedule.h
@@ -21,18 +21,20 @@ limitations under the License.
#include "absl/types/span.h"
#include "tensorflow/compiler/xla/service/hlo_computation.h"
#include "tensorflow/compiler/xla/service/hlo_instruction.h"
-#include "tensorflow/compiler/xla/service/hlo_module.h"
#include "tensorflow/compiler/xla/service/hlo_schedule.h"
#include "tensorflow/compiler/xla/status.h"
namespace xla {
+class HloModule;
+
// Class representing a sequence of HLO instructions such as the sequential
// execution order of an HLO computation.
class HloInstructionSequence {
public:
HloInstructionSequence() = default;
- HloInstructionSequence(absl::Span<const HloInstruction* const> instructions) {
+ explicit HloInstructionSequence(
+ absl::Span<const HloInstruction* const> instructions) {
for (const HloInstruction* instruction : instructions) {
push_back(instruction);
}
@@ -77,7 +79,12 @@ class HloInstructionSequence {
// non-fusion computation in the HLO module.
class HloSchedule {
public:
- HloSchedule(const HloModule* module) : module_(module) {}
+ explicit HloSchedule(const HloModule* module) : module_(module) {}
+
+ // (De)Serialize an HloSchedule to/from an HloScheduleProto.
+ static StatusOr<HloSchedule> CreateFromProto(const HloModule* module,
+ const HloScheduleProto& proto);
+ StatusOr<HloScheduleProto> ToProto() const;
// Returns a reference to the sequence for the given computation.
const HloInstructionSequence& sequence(
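To round out the picture, a small sketch of reading a computation's sequence back through this API; PrintSequence is a hypothetical name, and is_computation_scheduled(), sequence(), and instructions() are the accessors exercised by the tests in this change:

#include "tensorflow/compiler/xla/service/hlo_schedule.h"
#include "tensorflow/core/platform/logging.h"

namespace xla {

// Hypothetical example: walk a scheduled computation in schedule order.
void PrintSequence(const HloSchedule& schedule,
                   const HloComputation* computation) {
  if (!schedule.is_computation_scheduled(computation)) {
    return;
  }
  for (const HloInstruction* instruction :
       schedule.sequence(computation).instructions()) {
    LOG(INFO) << instruction->name();
  }
}

}  // namespace xla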