author    Mark Heffernan <meheff@google.com>  2018-09-08 09:23:24 -0700
committer TensorFlower Gardener <gardener@tensorflow.org>  2018-09-08 09:36:42 -0700
commit a6bb25c05c15e39d04baf6dac30200db367e1ef2 (patch)
tree   d5140caaeff44e59360ef86cac78c15925c1b0f7 /tensorflow
parent 4136bd49d92c80de3c6ae03ffdb2524b36e96fa8 (diff)
Make scheduling and rematerialization HLO passes.
Now that HloSchedule is a field on the HLO module, scheduling can be done as an HLO pass. Similarly, rematerialization, which requires a schedule, can also be a pass that simply gets the schedule from the module. Also, as a cleanup, hoist calls to CopyInsertion out of rematerialization.

PiperOrigin-RevId: 212119795
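For reference, a minimal sketch of how the two new passes could be chained after this change. The pipeline name and memory limit are illustrative only; the constructor signatures follow the headers modified in this commit.

// Illustrative wiring of the new passes into a pass pipeline (not part of
// this commit). Assumes `module` is an HloModule* and this code runs inside
// a function returning Status.
HloPassPipeline pipeline("schedule-and-rematerialize");
pipeline.AddPass<HloMemoryScheduler>(
    [](const BufferValue& buffer) {
      return ShapeUtil::ByteSizeOf(buffer.shape());
    });
pipeline.AddPass<HloRematerialization>(
    [](const Shape& shape) { return ShapeUtil::ByteSizeOf(shape); },
    /*memory_limit_bytes=*/16 * 1024, /*sizes=*/nullptr);
TF_RETURN_IF_ERROR(pipeline.Run(module).status());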
Diffstat (limited to 'tensorflow')
-rw-r--r--  tensorflow/compiler/xla/service/BUILD | 24
-rw-r--r--  tensorflow/compiler/xla/service/buffer_assignment.cc | 1
-rw-r--r--  tensorflow/compiler/xla/service/buffer_assignment_test.cc | 2
-rw-r--r--  tensorflow/compiler/xla/service/cpu/BUILD | 2
-rw-r--r--  tensorflow/compiler/xla/service/cpu/cpu_compiler.cc | 2
-rw-r--r--  tensorflow/compiler/xla/service/gpu/BUILD | 2
-rw-r--r--  tensorflow/compiler/xla/service/gpu/gpu_hlo_schedule.cc | 2
-rw-r--r--  tensorflow/compiler/xla/service/hlo_memory_scheduler.cc (renamed from tensorflow/compiler/xla/service/hlo_scheduling.cc) | 20
-rw-r--r--  tensorflow/compiler/xla/service/hlo_memory_scheduler.h (renamed from tensorflow/compiler/xla/service/hlo_scheduling.h) | 38
-rw-r--r--  tensorflow/compiler/xla/service/hlo_memory_scheduler_test.cc (renamed from tensorflow/compiler/xla/service/hlo_scheduling_test.cc) | 28
-rw-r--r--  tensorflow/compiler/xla/service/hlo_ordering_test.cc | 1
-rw-r--r--  tensorflow/compiler/xla/service/hlo_rematerialization.cc | 88
-rw-r--r--  tensorflow/compiler/xla/service/hlo_rematerialization.h | 83
-rw-r--r--  tensorflow/compiler/xla/service/hlo_rematerialization_test.cc | 75
-rw-r--r--  tensorflow/compiler/xla/service/hlo_schedule_test.cc | 2
-rw-r--r--  tensorflow/compiler/xla/service/hlo_verifier.cc | 5
16 files changed, 188 insertions, 187 deletions
diff --git a/tensorflow/compiler/xla/service/BUILD b/tensorflow/compiler/xla/service/BUILD
index e784663ff6..6ace6d3271 100644
--- a/tensorflow/compiler/xla/service/BUILD
+++ b/tensorflow/compiler/xla/service/BUILD
@@ -1012,8 +1012,8 @@ cc_library(
":buffer_value_containers",
":heap_simulator",
":hlo",
+ ":hlo_memory_scheduler",
":hlo_proto",
- ":hlo_scheduling",
":logical_buffer",
":tuple_points_to_analysis",
"//tensorflow/compiler/xla:shape_util",
@@ -1041,8 +1041,8 @@ tf_cc_test(
":cpu_plugin",
":flatten_call_graph",
":hlo",
+ ":hlo_memory_scheduler",
":hlo_ordering",
- ":hlo_scheduling",
"//tensorflow/compiler/xla:literal",
"//tensorflow/compiler/xla:shape_util",
"//tensorflow/compiler/xla:test",
@@ -1088,8 +1088,8 @@ tf_cc_test(
deps = [
":hlo",
":hlo_dataflow_analysis",
+ ":hlo_memory_scheduler",
":hlo_ordering",
- ":hlo_scheduling",
"//tensorflow/compiler/xla:shape_util",
"//tensorflow/compiler/xla:types",
"//tensorflow/compiler/xla:xla_data_proto",
@@ -1185,9 +1185,9 @@ tf_cc_test(
":heap_simulator",
":hlo",
":hlo_dce",
+ ":hlo_memory_scheduler",
":hlo_ordering",
":hlo_parser",
- ":hlo_scheduling",
"//tensorflow/compiler/xla:shape_util",
"//tensorflow/compiler/xla:types",
"//tensorflow/compiler/xla:xla_data_proto",
@@ -1199,13 +1199,14 @@ tf_cc_test(
)
cc_library(
- name = "hlo_scheduling",
- srcs = ["hlo_scheduling.cc"],
- hdrs = ["hlo_scheduling.h"],
+ name = "hlo_memory_scheduler",
+ srcs = ["hlo_memory_scheduler.cc"],
+ hdrs = ["hlo_memory_scheduler.h"],
deps = [
":heap_simulator",
":hlo",
":hlo_ordering",
+ ":hlo_pass",
":logical_buffer",
":tuple_points_to_analysis",
"//tensorflow/compiler/xla:shape_util",
@@ -1219,15 +1220,15 @@ cc_library(
)
tf_cc_test(
- name = "hlo_scheduling_test",
- srcs = ["hlo_scheduling_test.cc"],
+ name = "hlo_memory_scheduler_test",
+ srcs = ["hlo_memory_scheduler_test.cc"],
deps = [
":heap_simulator",
":hlo",
":hlo_dce",
+ ":hlo_memory_scheduler",
":hlo_ordering",
":hlo_parser",
- ":hlo_scheduling",
"//tensorflow/compiler/xla:shape_util",
"//tensorflow/compiler/xla:types",
"//tensorflow/compiler/xla:xla_data_proto",
@@ -2394,12 +2395,11 @@ cc_library(
":buffer_liveness",
":buffer_value",
":call_graph",
- ":copy_insertion",
":flatten_call_graph",
":hlo",
":hlo_dce",
+ ":hlo_memory_scheduler",
":hlo_ordering",
- ":hlo_scheduling",
":logical_buffer",
":tuple_points_to_analysis",
"//tensorflow/compiler/xla:shape_util",
diff --git a/tensorflow/compiler/xla/service/buffer_assignment.cc b/tensorflow/compiler/xla/service/buffer_assignment.cc
index 0f0af57626..65fa951afe 100644
--- a/tensorflow/compiler/xla/service/buffer_assignment.cc
+++ b/tensorflow/compiler/xla/service/buffer_assignment.cc
@@ -30,7 +30,6 @@ limitations under the License.
#include "tensorflow/compiler/xla/service/heap_simulator.h"
#include "tensorflow/compiler/xla/service/hlo.pb.h"
#include "tensorflow/compiler/xla/service/hlo_opcode.h"
-#include "tensorflow/compiler/xla/service/hlo_scheduling.h"
#include "tensorflow/compiler/xla/shape_util.h"
#include "tensorflow/compiler/xla/status_macros.h"
#include "tensorflow/compiler/xla/types.h"
diff --git a/tensorflow/compiler/xla/service/buffer_assignment_test.cc b/tensorflow/compiler/xla/service/buffer_assignment_test.cc
index 5a231c173d..c30abd1d3e 100644
--- a/tensorflow/compiler/xla/service/buffer_assignment_test.cc
+++ b/tensorflow/compiler/xla/service/buffer_assignment_test.cc
@@ -30,11 +30,11 @@ limitations under the License.
#include "tensorflow/compiler/xla/service/flatten_call_graph.h"
#include "tensorflow/compiler/xla/service/hlo_computation.h"
#include "tensorflow/compiler/xla/service/hlo_instruction.h"
+#include "tensorflow/compiler/xla/service/hlo_memory_scheduler.h"
#include "tensorflow/compiler/xla/service/hlo_opcode.h"
#include "tensorflow/compiler/xla/service/hlo_ordering.h"
#include "tensorflow/compiler/xla/service/hlo_parser.h"
#include "tensorflow/compiler/xla/service/hlo_schedule.h"
-#include "tensorflow/compiler/xla/service/hlo_scheduling.h"
#include "tensorflow/compiler/xla/shape_util.h"
#include "tensorflow/compiler/xla/test.h"
#include "tensorflow/compiler/xla/test_helpers.h"
diff --git a/tensorflow/compiler/xla/service/cpu/BUILD b/tensorflow/compiler/xla/service/cpu/BUILD
index 2368ac8c6a..039cbbff6c 100644
--- a/tensorflow/compiler/xla/service/cpu/BUILD
+++ b/tensorflow/compiler/xla/service/cpu/BUILD
@@ -122,7 +122,7 @@ cc_library(
"//tensorflow/compiler/xla/service:hlo_pass_pipeline",
"//tensorflow/compiler/xla/service:hlo_proto",
"//tensorflow/compiler/xla/service:hlo_proto_util",
- "//tensorflow/compiler/xla/service:hlo_scheduling",
+ "//tensorflow/compiler/xla/service:hlo_memory_scheduler",
"//tensorflow/compiler/xla/service:hlo_subcomputation_unification",
"//tensorflow/compiler/xla/service:hlo_verifier",
"//tensorflow/compiler/xla/service:indexed_array_analysis",
diff --git a/tensorflow/compiler/xla/service/cpu/cpu_compiler.cc b/tensorflow/compiler/xla/service/cpu/cpu_compiler.cc
index e7b6075994..18fc144efe 100644
--- a/tensorflow/compiler/xla/service/cpu/cpu_compiler.cc
+++ b/tensorflow/compiler/xla/service/cpu/cpu_compiler.cc
@@ -77,12 +77,12 @@ limitations under the License.
#include "tensorflow/compiler/xla/service/hlo_dce.h"
#include "tensorflow/compiler/xla/service/hlo_element_type_converter.h"
#include "tensorflow/compiler/xla/service/hlo_instruction.h"
+#include "tensorflow/compiler/xla/service/hlo_memory_scheduler.h"
#include "tensorflow/compiler/xla/service/hlo_opcode.h"
#include "tensorflow/compiler/xla/service/hlo_ordering.h"
#include "tensorflow/compiler/xla/service/hlo_pass_fix.h"
#include "tensorflow/compiler/xla/service/hlo_pass_pipeline.h"
#include "tensorflow/compiler/xla/service/hlo_proto_util.h"
-#include "tensorflow/compiler/xla/service/hlo_scheduling.h"
#include "tensorflow/compiler/xla/service/hlo_subcomputation_unification.h"
#include "tensorflow/compiler/xla/service/hlo_verifier.h"
#include "tensorflow/compiler/xla/service/indexed_array_analysis.h"
diff --git a/tensorflow/compiler/xla/service/gpu/BUILD b/tensorflow/compiler/xla/service/gpu/BUILD
index 6791e15ee0..569381f5b0 100644
--- a/tensorflow/compiler/xla/service/gpu/BUILD
+++ b/tensorflow/compiler/xla/service/gpu/BUILD
@@ -813,9 +813,9 @@ cc_library(
"//tensorflow/compiler/xla:types",
"//tensorflow/compiler/xla/service:buffer_value",
"//tensorflow/compiler/xla/service:hlo",
+ "//tensorflow/compiler/xla/service:hlo_memory_scheduler",
"//tensorflow/compiler/xla/service:hlo_ordering",
"//tensorflow/compiler/xla/service:hlo_reachability",
- "//tensorflow/compiler/xla/service:hlo_scheduling",
"@com_google_absl//absl/memory",
],
)
diff --git a/tensorflow/compiler/xla/service/gpu/gpu_hlo_schedule.cc b/tensorflow/compiler/xla/service/gpu/gpu_hlo_schedule.cc
index ea9376e101..02a0d028c1 100644
--- a/tensorflow/compiler/xla/service/gpu/gpu_hlo_schedule.cc
+++ b/tensorflow/compiler/xla/service/gpu/gpu_hlo_schedule.cc
@@ -21,9 +21,9 @@ limitations under the License.
#include "absl/memory/memory.h"
#include "tensorflow/compiler/xla/service/buffer_value.h"
+#include "tensorflow/compiler/xla/service/hlo_memory_scheduler.h"
#include "tensorflow/compiler/xla/service/hlo_reachability.h"
#include "tensorflow/compiler/xla/service/hlo_schedule.h"
-#include "tensorflow/compiler/xla/service/hlo_scheduling.h"
#include "tensorflow/compiler/xla/types.h"
namespace xla {
diff --git a/tensorflow/compiler/xla/service/hlo_scheduling.cc b/tensorflow/compiler/xla/service/hlo_memory_scheduler.cc
index 9bfb0af96c..c7ec88d450 100644
--- a/tensorflow/compiler/xla/service/hlo_scheduling.cc
+++ b/tensorflow/compiler/xla/service/hlo_memory_scheduler.cc
@@ -13,7 +13,7 @@ See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/
-#include "tensorflow/compiler/xla/service/hlo_scheduling.h"
+#include "tensorflow/compiler/xla/service/hlo_memory_scheduler.h"
#include <map>
#include <queue>
@@ -582,4 +582,22 @@ StatusOr<HloInstructionSequence> ScheduleComputation(
size_function, nullptr, empty_map);
}
+HloMemoryScheduler::HloMemoryScheduler(
+ const LogicalBuffer::SizeFunction& size_function,
+ const MemorySchedulerAlgorithm& algorithm)
+ : size_function_(size_function), algorithm_(algorithm) {}
+
+StatusOr<bool> HloMemoryScheduler::Run(HloModule* module) {
+ TF_ASSIGN_OR_RETURN(HloSchedule schedule,
+ ScheduleModule(*module, size_function_, algorithm_));
+ TF_RETURN_IF_ERROR(module->set_schedule(std::move(schedule)));
+ return true;
+}
+
+StatusOr<bool> HloDescheduler::Run(HloModule* module) {
+ bool changed = module->has_schedule();
+ module->clear_schedule();
+ return changed;
+}
+
} // namespace xla
diff --git a/tensorflow/compiler/xla/service/hlo_scheduling.h b/tensorflow/compiler/xla/service/hlo_memory_scheduler.h
index 54e32340ba..5e02868eba 100644
--- a/tensorflow/compiler/xla/service/hlo_scheduling.h
+++ b/tensorflow/compiler/xla/service/hlo_memory_scheduler.h
@@ -13,14 +13,15 @@ See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/
-#ifndef TENSORFLOW_COMPILER_XLA_SERVICE_HLO_SCHEDULING_H_
-#define TENSORFLOW_COMPILER_XLA_SERVICE_HLO_SCHEDULING_H_
+#ifndef TENSORFLOW_COMPILER_XLA_SERVICE_HLO_MEMORY_SCHEDULER_H_
+#define TENSORFLOW_COMPILER_XLA_SERVICE_HLO_MEMORY_SCHEDULER_H_
#include <vector>
#include "tensorflow/compiler/xla/service/hlo_instruction.h"
#include "tensorflow/compiler/xla/service/hlo_module.h"
#include "tensorflow/compiler/xla/service/hlo_ordering.h"
+#include "tensorflow/compiler/xla/service/hlo_pass_interface.h"
#include "tensorflow/compiler/xla/service/hlo_schedule.h"
#include "tensorflow/compiler/xla/service/logical_buffer.h"
#include "tensorflow/compiler/xla/service/tuple_points_to_analysis.h"
@@ -86,6 +87,37 @@ StatusOr<HloInstructionSequence> ScheduleComputation(
const HloComputation& computation,
const LogicalBuffer::SizeFunction& size_function);
+// A pass which schedules the HLO instructions in a module. The HloModule's
+// schedule field is set to the resulting HloSchedule using
+// HloModule::set_schedule.
+class HloMemoryScheduler : public HloPassInterface {
+ public:
+ // size_function is the function returning the number of bytes required for a
+ // LogicalBuffer. algorithm is the memory scheduling algorithm to use. If not
+ // specified, then DefaultMemoryScheduler is used.
+ HloMemoryScheduler(const LogicalBuffer::SizeFunction& size_function,
+ const MemorySchedulerAlgorithm& algorithm = {});
+ ~HloMemoryScheduler() override = default;
+ absl::string_view name() const override { return "hlo-memory-scheduler"; }
+
+ StatusOr<bool> Run(HloModule* module) override;
+
+ private:
+ LogicalBuffer::SizeFunction size_function_;
+ MemorySchedulerAlgorithm algorithm_;
+};
+
+// A trivial pass which clears the schedule currently set on the
+// HloModule. After this pass runs, HloModule::has_schedule will return false.
+class HloDescheduler : public HloPassInterface {
+ public:
+ HloDescheduler() = default;
+ ~HloDescheduler() override = default;
+ absl::string_view name() const override { return "hlo-descheduler"; }
+
+ StatusOr<bool> Run(HloModule* module) override;
+};
+
} // namespace xla
-#endif // TENSORFLOW_COMPILER_XLA_SERVICE_HLO_SCHEDULING_H_
+#endif // TENSORFLOW_COMPILER_XLA_SERVICE_HLO_MEMORY_SCHEDULER_H_
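A short usage sketch of the two passes declared above, run directly rather than through a pipeline; `module` is assumed to be an HloModule* and the size function body is just an example:

// Attach a memory-minimizing schedule to the module.
HloMemoryScheduler scheduler(
    [](const BufferValue& buffer) {
      return ShapeUtil::ByteSizeOf(buffer.shape());
    });  // algorithm defaults to DefaultMemoryScheduler
TF_ASSIGN_OR_RETURN(bool scheduled, scheduler.Run(module));
// module->schedule() is now available to subsequent passes.

// Clear the schedule again once it is no longer needed.
HloDescheduler descheduler;
TF_ASSIGN_OR_RETURN(bool had_schedule, descheduler.Run(module));
// module->has_schedule() is false after the descheduler runs.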
diff --git a/tensorflow/compiler/xla/service/hlo_scheduling_test.cc b/tensorflow/compiler/xla/service/hlo_memory_scheduler_test.cc
index 6afe51997e..1b9e9bfc77 100644
--- a/tensorflow/compiler/xla/service/hlo_scheduling_test.cc
+++ b/tensorflow/compiler/xla/service/hlo_memory_scheduler_test.cc
@@ -13,7 +13,7 @@ See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/
-#include "tensorflow/compiler/xla/service/hlo_scheduling.h"
+#include "tensorflow/compiler/xla/service/hlo_memory_scheduler.h"
#include <memory>
#include <string>
@@ -67,22 +67,34 @@ TEST_F(HloSchedulingTest, LastUseScheduledFirst) {
auto module = CreateNewModule();
module->AddEntryComputation(builder.Build());
- TF_ASSERT_OK_AND_ASSIGN(
- HloSchedule schedule,
- ScheduleModule(*module, [](const BufferValue& buffer) {
- return ShapeUtil::ByteSizeOf(buffer.shape());
- }));
+ HloMemoryScheduler scheduler([](const BufferValue& buffer) {
+ return ShapeUtil::ByteSizeOf(buffer.shape());
+ });
+ ASSERT_FALSE(module->has_schedule());
+ TF_ASSERT_OK_AND_ASSIGN(bool changed, scheduler.Run(module.get()));
+ EXPECT_TRUE(changed);
+ ASSERT_TRUE(module->has_schedule());
+ TF_ASSERT_OK(module->schedule().Verify());
+
// Verify that all instructions are in the sequence.
const std::vector<const HloInstruction*>& sequence =
- schedule.sequence(module->entry_computation()).instructions();
+ module->schedule().sequence(module->entry_computation()).instructions();
EXPECT_EQ(module->entry_computation()->instruction_count(), sequence.size());
// The first instruction should be the parameter and the last the root "sub".
EXPECT_EQ(param, sequence.front());
EXPECT_EQ(sub, sequence.back());
- SequentialHloOrdering ordering(schedule);
+ SequentialHloOrdering ordering(module->schedule());
EXPECT_TRUE(ordering.ExecutesBefore(add, negate));
+
+ // Clear the schedule using the descheduling pass.
+ HloDescheduler descheduler;
+ EXPECT_TRUE(module->has_schedule());
+ TF_ASSERT_OK_AND_ASSIGN(bool descheduler_changed,
+ descheduler.Run(module.get()));
+ EXPECT_TRUE(descheduler_changed);
+ EXPECT_FALSE(module->has_schedule());
}
TEST_F(HloSchedulingTest, ListSchedulerHandlesAliasing) {
diff --git a/tensorflow/compiler/xla/service/hlo_ordering_test.cc b/tensorflow/compiler/xla/service/hlo_ordering_test.cc
index 6b6005e7a5..00970bcda3 100644
--- a/tensorflow/compiler/xla/service/hlo_ordering_test.cc
+++ b/tensorflow/compiler/xla/service/hlo_ordering_test.cc
@@ -24,7 +24,6 @@ limitations under the License.
#include "tensorflow/compiler/xla/service/hlo_opcode.h"
#include "tensorflow/compiler/xla/service/hlo_parser.h"
#include "tensorflow/compiler/xla/service/hlo_schedule.h"
-#include "tensorflow/compiler/xla/service/hlo_scheduling.h"
#include "tensorflow/compiler/xla/shape_util.h"
#include "tensorflow/compiler/xla/tests/hlo_test_base.h"
#include "tensorflow/compiler/xla/types.h"
diff --git a/tensorflow/compiler/xla/service/hlo_rematerialization.cc b/tensorflow/compiler/xla/service/hlo_rematerialization.cc
index 0a0a6a323e..bd6dd79b67 100644
--- a/tensorflow/compiler/xla/service/hlo_rematerialization.cc
+++ b/tensorflow/compiler/xla/service/hlo_rematerialization.cc
@@ -27,15 +27,14 @@ limitations under the License.
#include "tensorflow/compiler/xla/map_util.h"
#include "tensorflow/compiler/xla/primitive_util.h"
#include "tensorflow/compiler/xla/service/buffer_value.h"
-#include "tensorflow/compiler/xla/service/copy_insertion.h"
#include "tensorflow/compiler/xla/service/flatten_call_graph.h"
#include "tensorflow/compiler/xla/service/hlo_computation.h"
#include "tensorflow/compiler/xla/service/hlo_dce.h"
#include "tensorflow/compiler/xla/service/hlo_instruction.h"
+#include "tensorflow/compiler/xla/service/hlo_memory_scheduler.h"
#include "tensorflow/compiler/xla/service/hlo_module.h"
#include "tensorflow/compiler/xla/service/hlo_opcode.h"
#include "tensorflow/compiler/xla/service/hlo_ordering.h"
-#include "tensorflow/compiler/xla/service/hlo_scheduling.h"
#include "tensorflow/compiler/xla/service/logical_buffer.h"
#include "tensorflow/compiler/xla/status_macros.h"
#include "tensorflow/compiler/xla/statusor.h"
@@ -1194,51 +1193,12 @@ StatusOr<bool> HloRematerialization::RematerializeComputation(
return changed;
}
-StatusOr<bool> HloRematerialization::Run(HloModule* module,
- HloSchedule* schedule,
- int64 memory_limit_bytes,
- RematerializationSizes* sizes,
- CopyInsertion* copy_insertion) {
- // The schedule is constructed entirely by this method.
- TF_RET_CHECK(schedule->empty());
-
+StatusOr<bool> HloRematerialization::Run(HloModule* module) {
VLOG(1) << "HloRematerialization() with memory limit of "
- << HumanReadableNumBytes(memory_limit_bytes);
+ << HumanReadableNumBytes(memory_limit_bytes_);
XLA_VLOG_LINES(3, "Before HloRematerialization:\n" + module->ToString());
- // Create initial schedule of HLO instructions.
- TF_ASSIGN_OR_RETURN(*schedule,
- ScheduleModule(*module,
- [this](const BufferValue& buffer) {
- return size_function_(buffer.shape());
- },
- scheduler_algorithm_));
- if (copy_insertion) {
- // We run a separate pass of copy elision here because the sequential
- // ordering from the HLO schedule allows for more copies to be eliminated.
- // TODO(b/80249101): Instead of a separate copy elision pass, use the
- // ordering from the HLO schedule directly for copy insertion.
- SequentialHloOrdering ordering(*schedule);
- TF_RETURN_IF_ERROR(
- copy_insertion->RemoveUnnecessaryCopies(ordering, module));
-
- // RemoveUnnecessaryCopies only considers interference when determining
- // whether it is legal to remove a copy. However, copies in the graph may be
- // necessary for other reason such as preventing a constant from being live
- // out of the graph. So run AddSpecialCaseCopies to re-insert these copies.
- // TODO(b/80249101): Break copy insertion into several passes and run each
- // one once in the regular HLO pipeline.
- TF_RETURN_IF_ERROR(copy_insertion->AddSpecialCaseCopies(module));
-
- // The passes above can add and remove copies, update the schedule to
- // account for these transformations. Newly added instructions will be
- // placed ASAP in the schedule.
- TF_RETURN_IF_ERROR(schedule->Update());
-
- TF_DCHECK_OK(copy_insertion->VerifyNoLiveRangeInterference(
- SequentialHloOrdering(*schedule), module));
- }
-
+ TF_RET_CHECK(module->has_schedule());
TF_ASSIGN_OR_RETURN(points_to_analysis_, TuplePointsToAnalysis::Run(module));
// Adjust memory limit to account for the output of the entry
@@ -1254,7 +1214,7 @@ StatusOr<bool> HloRematerialization::Run(HloModule* module,
});
const int64 adjusted_memory_limit_bytes =
- memory_limit_bytes - module_output_size;
+ memory_limit_bytes_ - module_output_size;
VLOG(1) << "Adjusted memory limit accounting for output ("
<< HumanReadableNumBytes(module_output_size)
<< "): " << HumanReadableNumBytes(adjusted_memory_limit_bytes);
@@ -1263,13 +1223,14 @@ StatusOr<bool> HloRematerialization::Run(HloModule* module,
// sequential context.
call_graph_ = CallGraph::Build(module);
TF_RETURN_IF_ERROR(call_graph_->VisitNodes(
- [this, schedule](const CallGraphNode& node) -> Status {
+ [this, module](const CallGraphNode& node) -> Status {
if (node.context() == CallContext::kSequential) {
TF_ASSIGN_OR_RETURN(
computation_peak_memory_[node.computation()],
- ComputePeakMemory(
- node.computation(),
- schedule->sequence(node.computation()).instructions()));
+ ComputePeakMemory(node.computation(),
+ module->schedule()
+ .sequence(node.computation())
+ .instructions()));
}
return Status::OK();
},
@@ -1287,9 +1248,10 @@ StatusOr<bool> HloRematerialization::Run(HloModule* module,
// Subcomputations called by the entry computation will also be
// rematerialized.
- TF_ASSIGN_OR_RETURN(bool changed, RematerializeComputation(
- module->entry_computation(), schedule,
- adjusted_memory_limit_bytes));
+ TF_ASSIGN_OR_RETURN(
+ bool changed,
+ RematerializeComputation(module->entry_computation(), &module->schedule(),
+ adjusted_memory_limit_bytes));
// Rematerialization can introduce dead code. This occurs if all uses of an
// instruction are replaced with rematerializations of the instruction.
@@ -1298,7 +1260,7 @@ StatusOr<bool> HloRematerialization::Run(HloModule* module,
// After DCE, the module sequence may include instructions which no longer
// exist.
- TF_RETURN_IF_ERROR(schedule->Update());
+ TF_RETURN_IF_ERROR(module->schedule().Update());
VLOG(1) << "Rematerialized " << instructions_rematerialized_
<< " instructions in module " << module->name() << "; "
<< net_instructions_added_ << " net instructions added";
@@ -1315,32 +1277,22 @@ StatusOr<bool> HloRematerialization::Run(HloModule* module,
<< HumanReadableNumBytes(reduced_peak_memory) << " ("
<< reduced_peak_memory << " bytes)";
- if (sizes != nullptr) {
- sizes->before_bytes = before_peak_memory;
- sizes->after_bytes = current_peak_memory;
+ if (sizes_ != nullptr) {
+ sizes_->before_bytes = before_peak_memory;
+ sizes_->after_bytes = current_peak_memory;
}
XLA_VLOG_LINES(3, "After HloRematerialization:\n" + module->ToString());
- if (current_peak_memory > memory_limit_bytes) {
+ if (current_peak_memory > memory_limit_bytes_) {
LOG(WARNING) << absl::StrFormat(
"Can't reduce memory use below %s (%d bytes) by rematerialization; "
"only reduced to %s (%d bytes)",
- HumanReadableNumBytes(memory_limit_bytes), memory_limit_bytes,
+ HumanReadableNumBytes(memory_limit_bytes_), memory_limit_bytes_,
HumanReadableNumBytes(current_peak_memory), current_peak_memory);
}
return changed;
}
-/* static */ StatusOr<bool> HloRematerialization::RematerializeAndSchedule(
- const HloRematerialization::ShapeSizeFunction& size_function,
- int64 memory_limit_bytes, HloModule* hlo_module,
- MemorySchedulerAlgorithm scheduler_algorithm, HloSchedule* schedule,
- RematerializationSizes* sizes, CopyInsertion* copy_insertion) {
- HloRematerialization remat(scheduler_algorithm, size_function);
- return remat.Run(hlo_module, schedule, memory_limit_bytes, sizes,
- copy_insertion);
-}
-
} // namespace xla
diff --git a/tensorflow/compiler/xla/service/hlo_rematerialization.h b/tensorflow/compiler/xla/service/hlo_rematerialization.h
index fa0414b472..e2aaf18b3e 100644
--- a/tensorflow/compiler/xla/service/hlo_rematerialization.h
+++ b/tensorflow/compiler/xla/service/hlo_rematerialization.h
@@ -17,17 +17,23 @@
#include "tensorflow/compiler/xla/service/buffer_liveness.h"
#include "tensorflow/compiler/xla/service/call_graph.h"
-#include "tensorflow/compiler/xla/service/copy_insertion.h"
#include "tensorflow/compiler/xla/service/hlo_computation.h"
#include "tensorflow/compiler/xla/service/hlo_instruction.h"
+#include "tensorflow/compiler/xla/service/hlo_memory_scheduler.h"
#include "tensorflow/compiler/xla/service/hlo_module.h"
#include "tensorflow/compiler/xla/service/hlo_schedule.h"
-#include "tensorflow/compiler/xla/service/hlo_scheduling.h"
#include "tensorflow/compiler/xla/service/tuple_points_to_analysis.h"
namespace xla {
-class HloRematerialization {
+// HLO pass which rematerializes instructions to reduce peak memory use, where
+// memory use is defined as the total size of all live HLO instruction
+// values. Parameters and constants are included in memory use estimates.
+//
+// CSE will undo the effects of this optimization and should not be run after
+// this pass. In general, this pass should be run very late, immediately before
+// code generation.
+class HloRematerialization : public HloPassInterface {
public:
using ShapeSizeFunction = std::function<int64(const Shape&)>;
@@ -38,10 +44,7 @@ class HloRematerialization {
int64 after_bytes;
};
- // Rematerialize HLO instructions in the given module to reduce peak memory
- // use below memory_limit_bytes where memory use is defined as the total size
- // of all live HLO instruction values. Parameters and constants are included
- // in memory use estimates. Method parameters:
+ // Constructor parameters:
//
// size_function: Function which returns the size in bytes of the top-level
// buffer of the given shape.
@@ -49,51 +52,27 @@ class HloRematerialization {
// memory_limit_bytes: The threshold number of bytes to reduce memory use to
// via rematerialization.
//
- // hlo_module: HLO module to rematerialize instructions in.
- //
- // schedule: Should point to an empty HloSchedule. Upon return
- // contains the HLO instruction order which was used for
- // rematerialization. This is the order in which HLO instructions should
- // be emitted to minimize memory use.
- //
- // sizes: Optional outparam that indicates the peak memory usage of the HLO
- // module before/after rematerialization.
- //
- // copy_insertion: If non-null, run copy elision after scheduling. This
- // pass is used to eliminate copies that were inserted by copy insertion
- // before HLO scheduling.
- //
- // TODO(b/80249101): Remove the 'run_copy_elision' parameter when copy
- // insertion is integrated with HLO scheduling.
- //
- // Returns whether any instructions were rematerialized. If memory use is
- // already below the given limit then no instructions are rematerialized and
- // false is returned.
- //
- // CSE will undo the effects of this optimization and should not be run after
- // this pass. In general, this pass should be run very late immediately before
- // code generation.
- static StatusOr<bool> RematerializeAndSchedule(
- const ShapeSizeFunction& size_function, int64 memory_limit_bytes,
- HloModule* hlo_module, MemorySchedulerAlgorithm scheduler_algorithm,
- HloSchedule* schedule, RematerializationSizes* sizes,
- CopyInsertion* copy_insertion = nullptr);
-
- protected:
- HloRematerialization(MemorySchedulerAlgorithm scheduler_algorithm,
- const ShapeSizeFunction& size_function)
- : scheduler_algorithm_(scheduler_algorithm),
- size_function_(size_function) {}
+ // sizes: Pointer to data structure which records the peak memory usage of
+ // the HLO module before/after rematerialization. Values are set during
+ // Run(). Can be nullptr.
+ HloRematerialization(const ShapeSizeFunction& size_function,
+ int64 memory_limit_bytes, RematerializationSizes* sizes)
+ : size_function_(size_function),
+ memory_limit_bytes_(memory_limit_bytes),
+ sizes_(sizes) {}
~HloRematerialization() {}
+ absl::string_view name() const override { return "rematerialization"; }
+
// Runs rematerialization on the given module. Returns whether the module was
- // changed. memory_limit is the target maximum peak memory usage by the
- // module. schedule should be an empty HloSchedule. Upon return sequence
- // contains the memory-minimizing order in which to emit the HLO instructions.
- StatusOr<bool> Run(HloModule* module, HloSchedule* schedule,
- int64 memory_limit, RematerializationSizes* sizes,
- CopyInsertion* copy_insertion);
+ // changed. Requires that the module has a schedule set
+ // (HloModule::has_schedule() is true) before running. Returns whether any
+ // instructions were rematerialized. If memory use is already below the limit
+ // specified in the constructor then no instructions are rematerialized and
+ // false is returned.
+ StatusOr<bool> Run(HloModule* module) override;
+ protected:
// Rematerializes instructions within the given computation. 'order' is the
// order in which the computation's instructions will be emitted in the
// backend. Rematerialized instructions will be added to the HLO computation
@@ -121,6 +100,14 @@ class HloRematerialization {
// Function which computes the size of the top-level buffer of a shape.
const ShapeSizeFunction size_function_;
+ // The threshold number of bytes to reduce memory use to via
+ // rematerialization.
+ const int64 memory_limit_bytes_;
+
+ // Pointer to data structure which records the peak memory usage of the HLO
+ // module before/after rematerialization
+ RematerializationSizes* sizes_;
+
// Call graph of the hlo_module.
std::unique_ptr<CallGraph> call_graph_;
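With the constructor above, the memory limit and size function are fixed at construction time and Run relies on the schedule already attached to the module. A hedged sketch (the 14KB limit and the variable names are illustrative):

// Requires module->has_schedule() to be true, e.g. after HloMemoryScheduler.
HloRematerialization::RematerializationSizes sizes;
HloRematerialization remat(
    [](const Shape& shape) { return ShapeUtil::ByteSizeOf(shape); },
    /*memory_limit_bytes=*/14 * 1024, &sizes);
TF_ASSIGN_OR_RETURN(bool changed, remat.Run(module));
// sizes.before_bytes and sizes.after_bytes record peak memory use before and
// after rematerialization.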
diff --git a/tensorflow/compiler/xla/service/hlo_rematerialization_test.cc b/tensorflow/compiler/xla/service/hlo_rematerialization_test.cc
index 83cb113bfb..4b611fe450 100644
--- a/tensorflow/compiler/xla/service/hlo_rematerialization_test.cc
+++ b/tensorflow/compiler/xla/service/hlo_rematerialization_test.cc
@@ -142,12 +142,15 @@ class HloRematerializationTest : public HloTestBase {
}
StatusOr<bool> RunHloRematerialization(int64 memory_limit_bytes,
- HloModule* module,
- HloSchedule* schedule) {
+ HloModule* module) {
TF_EXPECT_OK(verifier().Run(module).status());
- return HloRematerialization::RematerializeAndSchedule(
- ByteSizeOf, memory_limit_bytes, module, DefaultMemoryScheduler,
- schedule, /*sizes=*/nullptr);
+ HloMemoryScheduler scheduler(
+ [](const BufferValue& buffer) { return ByteSizeOf(buffer.shape()); },
+ DefaultMemoryScheduler);
+ TF_EXPECT_OK(scheduler.Run(module).status());
+ HloRematerialization remat(ByteSizeOf, memory_limit_bytes,
+ /*sizes=*/nullptr);
+ return remat.Run(module);
}
// Various shapes used in the canned computations.
@@ -170,12 +173,11 @@ TEST_F(HloRematerializationTest, SingleComputation) {
const HloInstruction* concat = slice->operand(0);
const HloInstruction* bcast = concat->operand(0);
- HloSchedule schedule(module.get());
// Computation requires 16KB without rematerialization, but uses only 12KB
// with rematerialization so pick a memory limit between these values (14KB).
- TF_ASSERT_OK_AND_ASSIGN(bool changed, RunHloRematerialization(
- /*memory_limit_bytes=*/14 * 1024,
- module.get(), &schedule));
+ TF_ASSERT_OK_AND_ASSIGN(bool changed,
+ RunHloRematerialization(
+ /*memory_limit_bytes=*/14 * 1024, module.get()));
EXPECT_TRUE(changed);
// Root should not have changed.
@@ -187,10 +189,12 @@ TEST_F(HloRematerializationTest, SingleComputation) {
// The rematerialized broadcast should be immediate before the concat in the
// sequence.
- EXPECT_EQ(schedule.sequence(computation)
+ EXPECT_EQ(module->schedule()
+ .sequence(computation)
.instructions()[computation->instruction_count() - 2],
concat);
- EXPECT_EQ(schedule.sequence(computation)
+ EXPECT_EQ(module->schedule()
+ .sequence(computation)
.instructions()[computation->instruction_count() - 3],
remat_bcast);
}
@@ -205,10 +209,9 @@ TEST_F(HloRematerializationTest, SingleComputationNoRematerialization) {
EXPECT_EQ(computation->instruction_count(), 8);
- HloSchedule schedule(module.get());
- TF_ASSERT_OK_AND_ASSIGN(bool changed, RunHloRematerialization(
- /*memory_limit_bytes=*/20 * 1024,
- module.get(), &schedule));
+ TF_ASSERT_OK_AND_ASSIGN(bool changed,
+ RunHloRematerialization(
+ /*memory_limit_bytes=*/20 * 1024, module.get()));
// No instructions should have been materialized.
EXPECT_FALSE(changed);
@@ -244,10 +247,9 @@ TEST_F(HloRematerializationTest, RematerializeAroundWhile) {
// The body computation uses 16KB and the entry computation uses 2KB at the
// while so the peak memory use of the module is 18KB. Set the memory limit a
// bit lower (17KB) to force rematerialization of the entry computation.
- HloSchedule schedule(module.get());
- TF_ASSERT_OK_AND_ASSIGN(bool changed, RunHloRematerialization(
- /*memory_limit_bytes=*/17 * 1024,
- module.get(), &schedule));
+ TF_ASSERT_OK_AND_ASSIGN(bool changed,
+ RunHloRematerialization(
+ /*memory_limit_bytes=*/17 * 1024, module.get()));
EXPECT_TRUE(changed);
// Only the entry computation should have a rematerialized instruction added.
@@ -278,10 +280,9 @@ TEST_F(HloRematerializationTest, RematerializeEntryAndWhileBody) {
EXPECT_EQ(entry_computation->instruction_count(), 7);
EXPECT_EQ(body_computation->instruction_count(), 8);
- HloSchedule schedule(module.get());
- TF_ASSERT_OK_AND_ASSIGN(bool changed, RunHloRematerialization(
- /*memory_limit_bytes=*/15 * 1024,
- module.get(), &schedule));
+ TF_ASSERT_OK_AND_ASSIGN(bool changed,
+ RunHloRematerialization(
+ /*memory_limit_bytes=*/15 * 1024, module.get()));
EXPECT_TRUE(changed);
// Both computations should have rematerialized instructions added.
@@ -318,10 +319,9 @@ TEST_F(HloRematerializationTest, RematerializeNestedComputations) {
// If all computations are maximally rematerialized then peak memory usage is
// ~12K so pick something slightly larger.
- HloSchedule schedule(module.get());
- TF_ASSERT_OK_AND_ASSIGN(bool changed, RunHloRematerialization(
- /*memory_limit_bytes=*/13 * 1024,
- module.get(), &schedule));
+ TF_ASSERT_OK_AND_ASSIGN(bool changed,
+ RunHloRematerialization(
+ /*memory_limit_bytes=*/13 * 1024, module.get()));
EXPECT_TRUE(changed);
// All computations should have rematerialized instructions added.
@@ -384,14 +384,13 @@ TEST_F(HloRematerializationTest, RngNotRematerialized) {
ASSERT_EQ(count_rngs(entry_computation), 1);
const int64 original_instruction_count =
entry_computation->instruction_count();
- HloSchedule schedule(module.get());
// Pick a memory limit some where between 24KB (initial peak memory including
// parameter and output) and 20KB (peak memory possible with
// rematerialization).
TF_ASSERT_OK_AND_ASSIGN(
- bool changed, RunHloRematerialization(
- /*memory_limit_bytes=*/4 * ByteSizeOf(vec1024_shape_),
- module.get(), &schedule));
+ bool changed,
+ RunHloRematerialization(
+ /*memory_limit_bytes=*/4 * ByteSizeOf(vec1024_shape_), module.get()));
EXPECT_TRUE(changed);
// The rng should not have been rematerialized.
EXPECT_EQ(count_rngs(entry_computation), 1);
@@ -478,13 +477,12 @@ TEST_F(HloRematerializationTest, InstructionRematerializedMultipleTimes) {
EXPECT_EQ(add_3->operand(0), bcast);
EXPECT_EQ(add_4->operand(0), bcast);
- HloSchedule schedule(module.get());
// Pick a memory limit some where between 24KB (initial peak memory including
// parameter and output) and 20KB (peak memory possible with
// rematerialization).
- TF_ASSERT_OK_AND_ASSIGN(bool changed, RunHloRematerialization(
- /*memory_limit_bytes=*/22 * 1024,
- module.get(), &schedule));
+ TF_ASSERT_OK_AND_ASSIGN(bool changed,
+ RunHloRematerialization(
+ /*memory_limit_bytes=*/22 * 1024, module.get()));
EXPECT_TRUE(changed);
// The broadcast should have been rematerialized 3 times.
@@ -573,13 +571,12 @@ TEST_P(IndirectUseTest, IndirectUseNotRematerialized) {
EXPECT_EQ(entry_computation->instruction_count(), 8);
- HloSchedule schedule(module.get());
// Pick a memory limit some where between 24KB (initial peak memory including
// parameter and output) and 20KB (peak memory possible with
// rematerialization).
- TF_ASSERT_OK_AND_ASSIGN(bool changed, RunHloRematerialization(
- /*memory_limit_bytes=*/22 * 1024,
- module.get(), &schedule));
+ TF_ASSERT_OK_AND_ASSIGN(bool changed,
+ RunHloRematerialization(
+ /*memory_limit_bytes=*/22 * 1024, module.get()));
// Rematerialization should only occur if the rematerializable instruction has
// no indirect uses.
if (indirectly_used) {
diff --git a/tensorflow/compiler/xla/service/hlo_schedule_test.cc b/tensorflow/compiler/xla/service/hlo_schedule_test.cc
index eb52582bb5..1424569ac1 100644
--- a/tensorflow/compiler/xla/service/hlo_schedule_test.cc
+++ b/tensorflow/compiler/xla/service/hlo_schedule_test.cc
@@ -22,10 +22,10 @@ limitations under the License.
#include "tensorflow/compiler/xla/service/hlo_computation.h"
#include "tensorflow/compiler/xla/service/hlo_dce.h"
#include "tensorflow/compiler/xla/service/hlo_instruction.h"
+#include "tensorflow/compiler/xla/service/hlo_memory_scheduler.h"
#include "tensorflow/compiler/xla/service/hlo_opcode.h"
#include "tensorflow/compiler/xla/service/hlo_ordering.h"
#include "tensorflow/compiler/xla/service/hlo_parser.h"
-#include "tensorflow/compiler/xla/service/hlo_scheduling.h"
#include "tensorflow/compiler/xla/shape_util.h"
#include "tensorflow/compiler/xla/tests/hlo_test_base.h"
#include "tensorflow/compiler/xla/types.h"
diff --git a/tensorflow/compiler/xla/service/hlo_verifier.cc b/tensorflow/compiler/xla/service/hlo_verifier.cc
index 069586a738..50f39cbcb5 100644
--- a/tensorflow/compiler/xla/service/hlo_verifier.cc
+++ b/tensorflow/compiler/xla/service/hlo_verifier.cc
@@ -1123,6 +1123,11 @@ StatusOr<bool> HloVerifier::Run(HloModule* module) {
TF_RETURN_IF_ERROR(VerifyEntryAndExitShapes(*module));
+ // If the module has a schedule, it must be valid.
+ if (module->has_schedule()) {
+ TF_RETURN_IF_ERROR(module->schedule().Verify());
+ }
+
return false;
}