aboutsummaryrefslogtreecommitdiffhomepage
diff options
context:
space:
mode:
-rw-r--r--tensorflow/compiler/xla/service/BUILD30
-rw-r--r--tensorflow/compiler/xla/service/gpu/BUILD1
-rw-r--r--tensorflow/compiler/xla/service/hlo.proto26
-rw-r--r--tensorflow/compiler/xla/service/hlo_computation.cc12
-rw-r--r--tensorflow/compiler/xla/service/hlo_computation.h5
-rw-r--r--tensorflow/compiler/xla/service/hlo_module.cc33
-rw-r--r--tensorflow/compiler/xla/service/hlo_module.h20
-rw-r--r--tensorflow/compiler/xla/service/hlo_module_test.cc59
-rw-r--r--tensorflow/compiler/xla/service/hlo_ordering.cc17
-rw-r--r--tensorflow/compiler/xla/service/hlo_ordering.h4
-rw-r--r--tensorflow/compiler/xla/service/hlo_parser.cc33
-rw-r--r--tensorflow/compiler/xla/service/hlo_parser_test.cc104
-rw-r--r--tensorflow/compiler/xla/service/hlo_proto_util.cc3
-rw-r--r--tensorflow/compiler/xla/service/hlo_schedule.cc52
-rw-r--r--tensorflow/compiler/xla/service/hlo_schedule.h13
15 files changed, 346 insertions, 66 deletions
diff --git a/tensorflow/compiler/xla/service/BUILD b/tensorflow/compiler/xla/service/BUILD
index ab86dce510..b8ee6a093e 100644
--- a/tensorflow/compiler/xla/service/BUILD
+++ b/tensorflow/compiler/xla/service/BUILD
@@ -291,6 +291,7 @@ cc_library(
"hlo_instructions.cc",
"hlo_module.cc",
"hlo_opcode.cc",
+ "hlo_schedule.cc",
"hlo_sharding.cc",
],
hdrs = [
@@ -303,6 +304,7 @@ cc_library(
"hlo_instructions.h",
"hlo_module.h",
"hlo_opcode.h",
+ "hlo_schedule.h",
"hlo_sharding.h",
],
deps = [
@@ -331,6 +333,8 @@ cc_library(
"@com_google_absl//absl/container:inlined_vector",
"@com_google_absl//absl/memory",
"@com_google_absl//absl/strings",
+ "@com_google_absl//absl/strings:str_format",
+ "@com_google_absl//absl/types:optional",
"@com_google_absl//absl/types:span",
],
)
@@ -1037,7 +1041,6 @@ tf_cc_test(
":flatten_call_graph",
":hlo",
":hlo_ordering",
- ":hlo_schedule",
":hlo_scheduling",
"//tensorflow/compiler/xla:literal",
"//tensorflow/compiler/xla:shape_util",
@@ -1065,7 +1068,6 @@ cc_library(
":hlo",
":hlo_dataflow_analysis",
":hlo_proto",
- ":hlo_schedule",
":hlo_value",
"//tensorflow/compiler/xla:shape_util",
"//tensorflow/compiler/xla:status_macros",
@@ -1086,7 +1088,6 @@ tf_cc_test(
":hlo",
":hlo_dataflow_analysis",
":hlo_ordering",
- ":hlo_schedule",
":hlo_scheduling",
"//tensorflow/compiler/xla:shape_util",
"//tensorflow/compiler/xla:types",
@@ -1108,7 +1109,6 @@ cc_library(
":hlo",
":hlo_ordering",
":hlo_proto",
- ":hlo_schedule",
":tuple_points_to_analysis",
"//tensorflow/compiler/xla:statusor",
"//tensorflow/compiler/xla:util",
@@ -1177,22 +1177,6 @@ cc_library(
],
)
-cc_library(
- name = "hlo_schedule",
- srcs = ["hlo_schedule.cc"],
- hdrs = ["hlo_schedule.h"],
- deps = [
- ":hlo",
- "//tensorflow/compiler/xla:status",
- "//tensorflow/compiler/xla:status_macros",
- "//tensorflow/compiler/xla:util",
- "//tensorflow/core:lib_internal",
- "@com_google_absl//absl/strings",
- "@com_google_absl//absl/strings:str_format",
- "@com_google_absl//absl/types:span",
- ],
-)
-
tf_cc_test(
name = "hlo_schedule_test",
srcs = ["hlo_schedule_test.cc"],
@@ -1202,7 +1186,6 @@ tf_cc_test(
":hlo_dce",
":hlo_ordering",
":hlo_parser",
- ":hlo_schedule",
":hlo_scheduling",
"//tensorflow/compiler/xla:shape_util",
"//tensorflow/compiler/xla:types",
@@ -1222,7 +1205,6 @@ cc_library(
":heap_simulator",
":hlo",
":hlo_ordering",
- ":hlo_schedule",
":logical_buffer",
":tuple_points_to_analysis",
"//tensorflow/compiler/xla:shape_util",
@@ -1969,6 +1951,8 @@ tf_cc_test(
srcs = ["hlo_module_test.cc"],
deps = [
":hlo",
+ ":hlo_matchers",
+ ":hlo_parser",
"//tensorflow/compiler/xla:literal",
"//tensorflow/compiler/xla:shape_util",
"//tensorflow/compiler/xla:test",
@@ -1977,6 +1961,7 @@ tf_cc_test(
"//tensorflow/compiler/xla/tests:hlo_test_base",
"//tensorflow/compiler/xla/tests:xla_internal_test_main",
"//tensorflow/core:lib",
+ "//tensorflow/core:test",
"@com_google_absl//absl/memory",
"@com_google_absl//absl/types:span",
],
@@ -2413,7 +2398,6 @@ cc_library(
":hlo",
":hlo_dce",
":hlo_ordering",
- ":hlo_schedule",
":hlo_scheduling",
":logical_buffer",
":tuple_points_to_analysis",
diff --git a/tensorflow/compiler/xla/service/gpu/BUILD b/tensorflow/compiler/xla/service/gpu/BUILD
index 13ccff35f8..a68b7a1bef 100644
--- a/tensorflow/compiler/xla/service/gpu/BUILD
+++ b/tensorflow/compiler/xla/service/gpu/BUILD
@@ -813,7 +813,6 @@ cc_library(
"//tensorflow/compiler/xla/service:hlo",
"//tensorflow/compiler/xla/service:hlo_ordering",
"//tensorflow/compiler/xla/service:hlo_reachability",
- "//tensorflow/compiler/xla/service:hlo_schedule",
"//tensorflow/compiler/xla/service:hlo_scheduling",
"@com_google_absl//absl/memory",
],
diff --git a/tensorflow/compiler/xla/service/hlo.proto b/tensorflow/compiler/xla/service/hlo.proto
index 99d0cf50ca..93ec2c9438 100644
--- a/tensorflow/compiler/xla/service/hlo.proto
+++ b/tensorflow/compiler/xla/service/hlo.proto
@@ -199,6 +199,17 @@ message HloComputationProto {
int64 root_id = 6;
}
+// Serialization of an HLO schedule. An HLO schedule contains a total order of
+// instructions for each non-fusion computation in the module.
+message HloScheduleProto {
+ message InstructionSequence {
+ repeated int64 instruction_ids = 1;
+ }
+
+ // Map from computation id to sequence.
+ map<int64, InstructionSequence> sequences = 1;
+}
+
// Serialization of HloModule.
message HloModuleProto {
string name = 1;
@@ -214,16 +225,9 @@ message HloModuleProto {
// The id of this module.
int64 id = 5;
-}
-// Serialization of HloOrdering.
-message HloOrderingProto {
- // NOTE: currently only sequential orderings are serialized.
- message SequentialComputation {
- string computation_name = 1;
- repeated string instruction_names = 2;
- }
- repeated SequentialComputation sequential_computations = 1;
+ // The schedule for this module.
+ HloScheduleProto schedule = 7;
}
// Serialization of LogicalBuffer.
@@ -322,8 +326,10 @@ message BufferAssignmentProto {
// Grouping message that contains all of the information above.
message HloProto {
+ reserved 2;
+ reserved "hlo_ordering";
+
HloModuleProto hlo_module = 1;
- HloOrderingProto hlo_ordering = 2;
BufferAssignmentProto buffer_assignment = 3;
}
diff --git a/tensorflow/compiler/xla/service/hlo_computation.cc b/tensorflow/compiler/xla/service/hlo_computation.cc
index fe7f2be888..233d2199d1 100644
--- a/tensorflow/compiler/xla/service/hlo_computation.cc
+++ b/tensorflow/compiler/xla/service/hlo_computation.cc
@@ -464,6 +464,14 @@ std::vector<HloComputation*> HloComputation::MakeEmbeddedComputationsList()
}
string HloComputation::ToString(const HloPrintOptions& options) const {
+ return ToString(options, MakeInstructionPostOrder());
+}
+
+string HloComputation::ToString(
+ const HloPrintOptions& options,
+ absl::Span<const HloInstruction* const> instruction_order) const {
+ CHECK_EQ(instruction_order.size(), instruction_count());
+
std::ostringstream s;
for (int i = 0; i < options.indent_amount(); i++) {
s << " ";
@@ -486,7 +494,9 @@ string HloComputation::ToString(const HloPrintOptions& options) const {
new_options.set_indent_amount(options.indent_amount() + 1)
.set_is_in_nested_computation(true);
CanonicalNameMap name_map;
- for (const HloInstruction* instruction : MakeInstructionPostOrder()) {
+ for (const HloInstruction* instruction : instruction_order) {
+ CHECK_EQ(this, instruction->parent());
+
for (int i = 0; i < new_options.indent_amount(); i++) {
s << " ";
}
diff --git a/tensorflow/compiler/xla/service/hlo_computation.h b/tensorflow/compiler/xla/service/hlo_computation.h
index fe2d3bbbe5..91c5234a6f 100644
--- a/tensorflow/compiler/xla/service/hlo_computation.h
+++ b/tensorflow/compiler/xla/service/hlo_computation.h
@@ -170,6 +170,11 @@ class HloComputation {
string ToString() const { return ToString(HloPrintOptions()); }
string ToString(const HloPrintOptions& options) const;
+ // Overload which accepts an order to emit the instructions in.
+ string ToString(
+ const HloPrintOptions& options,
+ absl::Span<const HloInstruction* const> instruction_order) const;
+
// Returns a serialized representation of this computation.
HloComputationProto ToProto() const;
diff --git a/tensorflow/compiler/xla/service/hlo_module.cc b/tensorflow/compiler/xla/service/hlo_module.cc
index 3a1bc4e328..cfe906d9c5 100644
--- a/tensorflow/compiler/xla/service/hlo_module.cc
+++ b/tensorflow/compiler/xla/service/hlo_module.cc
@@ -26,6 +26,7 @@ limitations under the License.
#include "absl/memory/memory.h"
#include "absl/strings/str_cat.h"
#include "tensorflow/compiler/xla/map_util.h"
+#include "tensorflow/compiler/xla/service/hlo_schedule.h"
#include "tensorflow/compiler/xla/shape_util.h"
#include "tensorflow/compiler/xla/types.h"
#include "tensorflow/core/lib/gtl/map_util.h"
@@ -50,6 +51,13 @@ StatusOr<HloInstruction*> HloModule::LaunderConstInstructionFromModule(
return const_cast<HloInstruction*>(hlo);
}
+Status HloModule::set_schedule(HloSchedule schedule) {
+ TF_RET_CHECK(schedule.module() == this);
+ TF_RETURN_IF_ERROR(schedule.Verify());
+ schedule_ = std::move(schedule);
+ return Status::OK();
+}
+
HloComputation* HloModule::AddComputationInternal(
std::unique_ptr<HloComputation> computation, bool is_entry,
bool uniquify_names) {
@@ -198,12 +206,23 @@ void HloModule::ReplaceComputations(
string HloModule::ToString(const HloPrintOptions& options) const {
std::ostringstream s;
- s << "HloModule " << name() << "\n\n";
+ s << "HloModule " << name();
+ if (has_schedule()) {
+ TF_CHECK_OK(schedule().Verify());
+ s << ", is_scheduled=true";
+ }
+ s << "\n\n";
for (const HloComputation* computation : MakeComputationPostOrder()) {
if (computation == entry_computation()) {
s << "ENTRY ";
}
- s << computation->ToString(options) << "\n\n";
+ if (has_schedule() && schedule().is_computation_scheduled(computation)) {
+ s << computation->ToString(
+ options, schedule().sequence(computation).instructions())
+ << "\n\n";
+ } else {
+ s << computation->ToString(options) << "\n\n";
+ }
}
return s.str();
}
@@ -221,6 +240,9 @@ HloModuleProto HloModule::ToProto() const {
}
proto.add_computations()->Swap(&computation_proto);
}
+ if (has_schedule()) {
+ *proto.mutable_schedule() = schedule().ToProto().ValueOrDie();
+ }
return proto;
}
@@ -309,6 +331,13 @@ StatusOr<std::unique_ptr<HloModule>> HloModule::CreateFromProto(
}
}
+ if (proto.has_schedule()) {
+ TF_ASSIGN_OR_RETURN(
+ HloSchedule schedule,
+ HloSchedule::CreateFromProto(module.get(), proto.schedule()));
+ TF_RETURN_IF_ERROR(module->set_schedule(std::move(schedule)));
+ }
+
return std::move(module);
}
diff --git a/tensorflow/compiler/xla/service/hlo_module.h b/tensorflow/compiler/xla/service/hlo_module.h
index 3c3371426b..26fd1b2438 100644
--- a/tensorflow/compiler/xla/service/hlo_module.h
+++ b/tensorflow/compiler/xla/service/hlo_module.h
@@ -25,6 +25,7 @@ limitations under the License.
#include <vector>
#include "absl/strings/string_view.h"
+#include "absl/types/optional.h"
#include "absl/types/span.h"
#include "tensorflow/compiler/xla/iterator_util.h"
#include "tensorflow/compiler/xla/service/hlo.pb.h"
@@ -32,6 +33,7 @@ limitations under the License.
#include "tensorflow/compiler/xla/service/hlo_computation.h"
#include "tensorflow/compiler/xla/service/hlo_instruction.h"
#include "tensorflow/compiler/xla/service/hlo_module_config.h"
+#include "tensorflow/compiler/xla/service/hlo_schedule.h"
#include "tensorflow/compiler/xla/service/name_uniquer.h"
#include "tensorflow/compiler/xla/types.h"
#include "tensorflow/core/lib/gtl/iterator_range.h"
@@ -235,6 +237,19 @@ class HloModule {
StatusOr<HloInstruction*> LaunderConstInstructionFromModule(
const HloInstruction* hlo);
+ // Sets the schedule of the module to the given schedule.
+ Status set_schedule(HloSchedule schedule);
+
+ // Clears the schedule of the module.
+ void clear_schedule() { schedule_.reset(); }
+
+ // Returns true if the module has a schedule set.
+ bool has_schedule() const { return schedule_.has_value(); }
+
+ // Returns the schedue of the module. CHECK fails if no schedule is set.
+ const HloSchedule& schedule() const { return *schedule_; }
+ HloSchedule& schedule() { return *schedule_; }
+
private:
HloComputation* AddComputationInternal(
std::unique_ptr<HloComputation> computation, bool is_entry,
@@ -262,6 +277,11 @@ class HloModule {
static std::atomic<int> next_unique_module_id_;
// A unique id to label modules with.
int unique_id_;
+
+ // The HloSchedule of the module. The schedule if it exists contains a
+ // sequential order of instructions for each non-fusion computation in the
+ // module.
+ absl::optional<HloSchedule> schedule_;
};
} // namespace xla
diff --git a/tensorflow/compiler/xla/service/hlo_module_test.cc b/tensorflow/compiler/xla/service/hlo_module_test.cc
index 4bc1bacd7d..400bd4d947 100644
--- a/tensorflow/compiler/xla/service/hlo_module_test.cc
+++ b/tensorflow/compiler/xla/service/hlo_module_test.cc
@@ -19,9 +19,12 @@ limitations under the License.
#include "tensorflow/compiler/xla/literal.h"
#include "tensorflow/compiler/xla/service/hlo_computation.h"
#include "tensorflow/compiler/xla/service/hlo_instruction.h"
+#include "tensorflow/compiler/xla/service/hlo_matchers.h"
+#include "tensorflow/compiler/xla/service/hlo_parser.h"
#include "tensorflow/compiler/xla/shape_util.h"
#include "tensorflow/compiler/xla/tests/hlo_test_base.h"
#include "tensorflow/compiler/xla/xla_data.pb.h"
+#include "tensorflow/core/lib/core/status_test_util.h"
#include "absl/types/span.h"
#include "tensorflow/compiler/xla/test.h"
@@ -30,6 +33,8 @@ namespace xla {
namespace {
+namespace op = ::xla::testing::opcode_matchers;
+
class HloModuleTest : public HloTestBase {
protected:
HloModuleTest() {}
@@ -194,6 +199,60 @@ TEST_F(HloModuleTest, UniqueModuleId) {
EXPECT_NE(module_a->unique_id(), module_b->unique_id());
}
+TEST_F(HloModuleTest, ProtoSerializationWithoutSchedule) {
+ const string text = R"(
+HloModule axpy_module
+
+ENTRY %axpy.v5 (alpha: f32[], x: f32[2,4], y: f32[2,4]) -> f32[2,4] {
+ %alpha = f32[] parameter(0)
+ %x = f32[2,4]{1,0} parameter(1)
+ %y = f32[2,4]{1,0} parameter(2)
+ %broadcast = f32[2,4]{1,0} broadcast(f32[] %alpha), dimensions={}
+ %multiply = f32[2,4]{1,0} multiply(f32[2,4]{1,0} %broadcast, f32[2,4]{1,0} %x)
+ ROOT %add = f32[2,4]{1,0} add(f32[2,4]{1,0} %multiply, f32[2,4]{1,0} %y)
+}
+)";
+ TF_ASSERT_OK_AND_ASSIGN(std::unique_ptr<HloModule> module,
+ ParseHloString(text));
+ ASSERT_FALSE(module->has_schedule());
+ TF_ASSERT_OK_AND_ASSIGN(
+ std::unique_ptr<HloModule> module_copy,
+ HloModule::CreateFromProto(module->ToProto(), module->config()));
+ ASSERT_FALSE(module_copy->has_schedule());
+}
+
+TEST_F(HloModuleTest, ProtoSerializationWithSchedule) {
+ const string text = R"(
+HloModule axpy_module, is_scheduled=true
+
+ENTRY %axpy.v5 (alpha: f32[], x: f32[2,4], y: f32[2,4]) -> f32[2,4] {
+ %alpha = f32[] parameter(0)
+ %x = f32[2,4]{1,0} parameter(1)
+ %y = f32[2,4]{1,0} parameter(2)
+ %broadcast = f32[2,4]{1,0} broadcast(f32[] %alpha), dimensions={}
+ %multiply = f32[2,4]{1,0} multiply(f32[2,4]{1,0} %broadcast, f32[2,4]{1,0} %x)
+ ROOT %add = f32[2,4]{1,0} add(f32[2,4]{1,0} %multiply, f32[2,4]{1,0} %y)
+}
+)";
+ TF_ASSERT_OK_AND_ASSIGN(std::unique_ptr<HloModule> module,
+ ParseHloString(text));
+ ASSERT_TRUE(module->has_schedule());
+ TF_ASSERT_OK_AND_ASSIGN(
+ std::unique_ptr<HloModule> module_copy,
+ HloModule::CreateFromProto(module->ToProto(), module->config()));
+ ASSERT_TRUE(module_copy->has_schedule());
+ TF_ASSERT_OK(module_copy->schedule().Verify());
+ EXPECT_EQ(module_copy->schedule().sequences().size(), 1);
+ ASSERT_TRUE(module_copy->schedule().is_computation_scheduled(
+ module_copy->entry_computation()));
+ EXPECT_THAT(
+ module_copy->schedule()
+ .sequence(module_copy->entry_computation())
+ .instructions(),
+ ::testing::ElementsAre(op::Parameter(), op::Parameter(), op::Parameter(),
+ op::Broadcast(), op::Multiply(), op::Add()));
+}
+
} // namespace
} // namespace xla
diff --git a/tensorflow/compiler/xla/service/hlo_ordering.cc b/tensorflow/compiler/xla/service/hlo_ordering.cc
index 2105f7a349..f1dc08bafa 100644
--- a/tensorflow/compiler/xla/service/hlo_ordering.cc
+++ b/tensorflow/compiler/xla/service/hlo_ordering.cc
@@ -293,23 +293,6 @@ bool HloOrdering::MayInterfere(const HloValue& a, const HloValue& b,
!LiveRangeStrictlyBefore(b, a, dataflow);
}
-HloOrderingProto HloOrdering::ToProto() const {
- HloOrderingProto proto;
- for (const auto& computation : module_->computations()) {
- const std::vector<const HloInstruction*>* sequence =
- SequentialOrder(*computation);
- if (sequence != nullptr) {
- HloOrderingProto::SequentialComputation* proto_computation =
- proto.add_sequential_computations();
- proto_computation->set_computation_name(computation->name());
- for (const HloInstruction* instruction : *sequence) {
- *proto_computation->add_instruction_names() = instruction->name();
- }
- }
- }
- return proto;
-}
-
PredecessorHloOrdering::PredecessorHloOrdering(const HloModule* module)
: HloOrdering(module) {}
diff --git a/tensorflow/compiler/xla/service/hlo_ordering.h b/tensorflow/compiler/xla/service/hlo_ordering.h
index b21071c4b2..b0361c3f02 100644
--- a/tensorflow/compiler/xla/service/hlo_ordering.h
+++ b/tensorflow/compiler/xla/service/hlo_ordering.h
@@ -72,10 +72,6 @@ class HloOrdering {
virtual string ToString() const = 0;
- // Returns the serialized representation of this ordering.
- // Only sequential computation orders are represented.
- HloOrderingProto ToProto() const;
-
protected:
// Returns true if instruction 'a' executes before instruction 'b'.
// Precondition: 'a' and 'b' are in the same computation.
diff --git a/tensorflow/compiler/xla/service/hlo_parser.cc b/tensorflow/compiler/xla/service/hlo_parser.cc
index 7c848ba7b4..c54360b063 100644
--- a/tensorflow/compiler/xla/service/hlo_parser.cc
+++ b/tensorflow/compiler/xla/service/hlo_parser.cc
@@ -26,6 +26,7 @@ limitations under the License.
#include "tensorflow/compiler/xla/service/hlo_domain_metadata.h"
#include "tensorflow/compiler/xla/service/hlo_instructions.h"
#include "tensorflow/compiler/xla/service/hlo_opcode.h"
+#include "tensorflow/compiler/xla/service/hlo_schedule.h"
#include "tensorflow/compiler/xla/service/hlo_sharding_metadata.h"
#include "tensorflow/compiler/xla/shape_util.h"
#include "tensorflow/compiler/xla/util.h"
@@ -44,6 +45,20 @@ using absl::StrJoin;
const double kF16max = 65504;
+// Creates and returns a schedule created using the order of the instructions in
+// the HloComputation::instructions() vectors in the module.
+HloSchedule ScheduleFromInstructionOrder(const HloModule* module) {
+ HloSchedule schedule(module);
+ for (const HloComputation* computation : module->computations()) {
+ if (!computation->IsFusionComputation()) {
+ for (const HloInstruction* instruction : computation->instructions()) {
+ schedule.GetOrCreateSequence(computation).push_back(instruction);
+ }
+ }
+ }
+ return schedule;
+}
+
// Parser for the HloModule::ToString() format text.
class HloParser {
public:
@@ -366,9 +381,25 @@ bool HloParser::ParseHloModule() {
return false;
}
+ absl::optional<bool> is_scheduled;
+ std::unordered_map<string, AttrConfig> attrs;
+ attrs["is_scheduled"] = {/*required=*/false, AttrTy::kBool, &is_scheduled};
+ if (!ParseAttributes(attrs)) {
+ return false;
+ }
+
module_ = absl::make_unique<HloModule>(name, config_);
- return ParseComputations();
+ if (!ParseComputations()) {
+ return false;
+ }
+
+ if (is_scheduled.has_value() && *is_scheduled) {
+ TF_CHECK_OK(
+ module_->set_schedule(ScheduleFromInstructionOrder(module_.get())));
+ }
+
+ return true;
}
// computations ::= (computation)+
diff --git a/tensorflow/compiler/xla/service/hlo_parser_test.cc b/tensorflow/compiler/xla/service/hlo_parser_test.cc
index 43e8736532..cca50fab54 100644
--- a/tensorflow/compiler/xla/service/hlo_parser_test.cc
+++ b/tensorflow/compiler/xla/service/hlo_parser_test.cc
@@ -1133,8 +1133,21 @@ ENTRY Computation {
}
)"
+ },
+// is_scheduled=true attribute
+{
+"ScheduledModule",
+R"(HloModule scheduled_module, is_scheduled=true
+
+ENTRY Sort {
+ keys = f32[1024]{0} parameter(0)
+ values = s32[1024]{0} parameter(1)
+ ROOT sorted = (f32[1024]{0}, s32[1024]{0}) sort(keys, values), dimensions={0}
}
- });
+
+)"
+}
+});
// clang-format on
}
@@ -1790,5 +1803,94 @@ TEST(HloParserSingleOpTest, ConvolutionTrivialFeatureGroupCount) {
EXPECT_EQ(convolution->feature_group_count(), 1);
}
+TEST_F(HloParserTest, IsScheduledIsFalse) {
+ const string text = R"(
+HloModule axpy_module, is_scheduled=false
+
+ENTRY %axpy.v5 (alpha: f32[], x: f32[2,4], y: f32[2,4]) -> f32[2,4] {
+ %alpha = f32[] parameter(0)
+ %broadcast = f32[2,4]{1,0} broadcast(f32[] %alpha), dimensions={}
+ %x = f32[2,4]{1,0} parameter(1)
+ %multiply = f32[2,4]{1,0} multiply(f32[2,4]{1,0} %broadcast, f32[2,4]{1,0} %x)
+ %y = f32[2,4]{1,0} parameter(2)
+ ROOT %add = f32[2,4]{1,0} add(f32[2,4]{1,0} %multiply, f32[2,4]{1,0} %y)
+}
+)";
+ TF_ASSERT_OK_AND_ASSIGN(std::unique_ptr<HloModule> module,
+ ParseHloString(text));
+ ASSERT_FALSE(module->has_schedule());
+}
+
+TEST_F(HloParserTest, IsScheduledNotPresent) {
+ const string text = R"(
+HloModule axpy_module
+
+ENTRY %axpy.v5 (alpha: f32[], x: f32[2,4], y: f32[2,4]) -> f32[2,4] {
+ %alpha = f32[] parameter(0)
+ %broadcast = f32[2,4]{1,0} broadcast(f32[] %alpha), dimensions={}
+ %x = f32[2,4]{1,0} parameter(1)
+ %multiply = f32[2,4]{1,0} multiply(f32[2,4]{1,0} %broadcast, f32[2,4]{1,0} %x)
+ %y = f32[2,4]{1,0} parameter(2)
+ ROOT %add = f32[2,4]{1,0} add(f32[2,4]{1,0} %multiply, f32[2,4]{1,0} %y)
+}
+)";
+ TF_ASSERT_OK_AND_ASSIGN(std::unique_ptr<HloModule> module,
+ ParseHloString(text));
+ ASSERT_FALSE(module->has_schedule());
+}
+
+TEST_F(HloParserTest, IsScheduledIsTrue) {
+ const string text = R"(
+HloModule axpy_module, is_scheduled=true
+
+ENTRY %axpy.v5 (alpha: f32[], x: f32[2,4], y: f32[2,4]) -> f32[2,4] {
+ %alpha = f32[] parameter(0)
+ %broadcast = f32[2,4]{1,0} broadcast(f32[] %alpha), dimensions={}
+ %x = f32[2,4]{1,0} parameter(1)
+ %multiply = f32[2,4]{1,0} multiply(f32[2,4]{1,0} %broadcast, f32[2,4]{1,0} %x)
+ %y = f32[2,4]{1,0} parameter(2)
+ ROOT %add = f32[2,4]{1,0} add(f32[2,4]{1,0} %multiply, f32[2,4]{1,0} %y)
+}
+)";
+ TF_ASSERT_OK_AND_ASSIGN(std::unique_ptr<HloModule> module,
+ ParseHloString(text));
+ ASSERT_TRUE(module->has_schedule());
+ TF_ASSERT_OK(module->schedule().Verify());
+ EXPECT_EQ(module->schedule().sequences().size(), 1);
+ ASSERT_TRUE(
+ module->schedule().is_computation_scheduled(module->entry_computation()));
+ EXPECT_THAT(
+ module->schedule().sequence(module->entry_computation()).instructions(),
+ ::testing::ElementsAre(op::Parameter(), op::Broadcast(), op::Parameter(),
+ op::Multiply(), op::Parameter(), op::Add()));
+}
+
+TEST_F(HloParserTest, IsScheduledIsTrueDifferentOrder) {
+ // As above but in with a different schedule order.
+ const string text = R"(
+HloModule axpy_module, is_scheduled=true
+
+ENTRY %axpy.v5 (alpha: f32[], x: f32[2,4], y: f32[2,4]) -> f32[2,4] {
+ %alpha = f32[] parameter(0)
+ %x = f32[2,4]{1,0} parameter(1)
+ %y = f32[2,4]{1,0} parameter(2)
+ %broadcast = f32[2,4]{1,0} broadcast(f32[] %alpha), dimensions={}
+ %multiply = f32[2,4]{1,0} multiply(f32[2,4]{1,0} %broadcast, f32[2,4]{1,0} %x)
+ ROOT %add = f32[2,4]{1,0} add(f32[2,4]{1,0} %multiply, f32[2,4]{1,0} %y)
+}
+)";
+ TF_ASSERT_OK_AND_ASSIGN(std::unique_ptr<HloModule> module,
+ ParseHloString(text));
+ ASSERT_TRUE(module->has_schedule());
+ TF_ASSERT_OK(module->schedule().Verify());
+ EXPECT_EQ(module->schedule().sequences().size(), 1);
+ ASSERT_TRUE(
+ module->schedule().is_computation_scheduled(module->entry_computation()));
+ EXPECT_THAT(
+ module->schedule().sequence(module->entry_computation()).instructions(),
+ ::testing::ElementsAre(op::Parameter(), op::Parameter(), op::Parameter(),
+ op::Broadcast(), op::Multiply(), op::Add()));
+}
+
} // namespace
} // namespace xla
diff --git a/tensorflow/compiler/xla/service/hlo_proto_util.cc b/tensorflow/compiler/xla/service/hlo_proto_util.cc
index 3460679558..b9c0b0c4ee 100644
--- a/tensorflow/compiler/xla/service/hlo_proto_util.cc
+++ b/tensorflow/compiler/xla/service/hlo_proto_util.cc
@@ -23,11 +23,8 @@ namespace xla {
HloProto MakeHloProto(const HloModule& module,
const BufferAssignment& assignment) {
- HloOrderingProto proto_ordering =
- assignment.liveness().hlo_ordering().ToProto();
BufferAssignmentProto proto_assignment = assignment.ToProto();
HloProto proto = MakeHloProto(module);
- proto.mutable_hlo_ordering()->Swap(&proto_ordering);
proto.mutable_buffer_assignment()->Swap(&proto_assignment);
return proto;
}
diff --git a/tensorflow/compiler/xla/service/hlo_schedule.cc b/tensorflow/compiler/xla/service/hlo_schedule.cc
index a65b33bf40..3fc5dbeb02 100644
--- a/tensorflow/compiler/xla/service/hlo_schedule.cc
+++ b/tensorflow/compiler/xla/service/hlo_schedule.cc
@@ -21,12 +21,64 @@ limitations under the License.
#include "absl/strings/str_format.h"
#include "absl/strings/str_join.h"
#include "tensorflow/compiler/xla/map_util.h"
+#include "tensorflow/compiler/xla/service/hlo_module.h"
#include "tensorflow/compiler/xla/status_macros.h"
#include "tensorflow/compiler/xla/util.h"
#include "tensorflow/core/lib/gtl/map_util.h"
namespace xla {
+/* static */ StatusOr<HloSchedule> HloSchedule::CreateFromProto(
+ const HloModule* module, const HloScheduleProto& proto) {
+ tensorflow::gtl::FlatMap<int64, const HloComputation*> id_to_computation;
+ for (const HloComputation* computation : module->computations()) {
+ id_to_computation[computation->unique_id()] = computation;
+ }
+
+ HloSchedule schedule(module);
+ for (const auto& id_sequence : proto.sequences()) {
+ int64 computation_id = id_sequence.first;
+
+ auto comp_it = id_to_computation.find(computation_id);
+ TF_RET_CHECK(comp_it != id_to_computation.end())
+ << "No computation exists in HLO module with id " << computation_id;
+ const HloComputation* computation = comp_it->second;
+
+ tensorflow::gtl::FlatMap<int64, const HloInstruction*> id_to_instruction;
+ for (const HloInstruction* instruction : computation->instructions()) {
+ id_to_instruction[instruction->unique_id()] = instruction;
+ }
+
+ HloInstructionSequence& sequence =
+ schedule.GetOrCreateSequence(computation);
+ for (const int64 instruction_id : id_sequence.second.instruction_ids()) {
+ auto instr_it = id_to_instruction.find(instruction_id);
+ TF_RET_CHECK(instr_it != id_to_instruction.end())
+ << "No instruction exists in HLO computation " << computation->name()
+ << " with id " << instruction_id;
+ sequence.push_back(instr_it->second);
+ }
+ }
+ TF_RETURN_IF_ERROR(schedule.Verify());
+ return std::move(schedule);
+}
+
+StatusOr<HloScheduleProto> HloSchedule::ToProto() const {
+ TF_RETURN_IF_ERROR(Verify());
+ HloScheduleProto proto;
+ for (const auto& id_sequence : sequences_) {
+ int64 computation_id = id_sequence.first;
+ const HloInstructionSequence& sequence = id_sequence.second;
+ HloScheduleProto::InstructionSequence& proto_sequence =
+ (*proto.mutable_sequences())[computation_id];
+ proto_sequence.mutable_instruction_ids()->Reserve(sequence.size());
+ for (const int64 id : sequence.ids()) {
+ proto_sequence.add_instruction_ids(id);
+ }
+ }
+ return std::move(proto);
+}
+
void HloSchedule::set_sequence(
const HloComputation* computation,
absl::Span<const HloInstruction* const> sequence) {
diff --git a/tensorflow/compiler/xla/service/hlo_schedule.h b/tensorflow/compiler/xla/service/hlo_schedule.h
index 21c6988638..270fe6039f 100644
--- a/tensorflow/compiler/xla/service/hlo_schedule.h
+++ b/tensorflow/compiler/xla/service/hlo_schedule.h
@@ -21,18 +21,20 @@ limitations under the License.
#include "absl/types/span.h"
#include "tensorflow/compiler/xla/service/hlo_computation.h"
#include "tensorflow/compiler/xla/service/hlo_instruction.h"
-#include "tensorflow/compiler/xla/service/hlo_module.h"
#include "tensorflow/compiler/xla/service/hlo_schedule.h"
#include "tensorflow/compiler/xla/status.h"
namespace xla {
+class HloModule;
+
// Class representing a sequence of HLO instructions such as the sequential
// execution order of an HLO computation.
class HloInstructionSequence {
public:
HloInstructionSequence() = default;
- HloInstructionSequence(absl::Span<const HloInstruction* const> instructions) {
+ explicit HloInstructionSequence(
+ absl::Span<const HloInstruction* const> instructions) {
for (const HloInstruction* instruction : instructions) {
push_back(instruction);
}
@@ -77,7 +79,12 @@ class HloInstructionSequence {
// non-fusion computation in the HLO module.
class HloSchedule {
public:
- HloSchedule(const HloModule* module) : module_(module) {}
+ explicit HloSchedule(const HloModule* module) : module_(module) {}
+
+ // (De)Serialize an HloSchedule to/from a HloScheduleProto.
+ static StatusOr<HloSchedule> CreateFromProto(const HloModule* module,
+ const HloScheduleProto& proto);
+ StatusOr<HloScheduleProto> ToProto() const;
// Returns a reference to the sequence for the given computation.
const HloInstructionSequence& sequence(