diff options
author | 2017-12-15 15:13:31 -0800 | |
---|---|---|
committer | 2017-12-15 15:17:37 -0800 | |
commit | a81e83eea823ff1f3e6871eb24f85e7ca09dcf72 (patch) | |
tree | 6eb2208b6043b8a4310c76f259e4942ae5ff1cc5 /tensorflow/compiler/xla/service/hlo_rematerialization.cc | |
parent | 7f8e7437693d1051fa378047ae9ee75f91201cb5 (diff) |
[XLA] Add a flag to control the HLO scheduling algorithm choice.
List scheduling is more easily rematerialized sometimes, this
gives the ability to force list scheduling via the API.
PiperOrigin-RevId: 179246142
Diffstat (limited to 'tensorflow/compiler/xla/service/hlo_rematerialization.cc')
-rw-r--r-- | tensorflow/compiler/xla/service/hlo_rematerialization.cc | 14 |
1 files changed, 8 insertions, 6 deletions
diff --git a/tensorflow/compiler/xla/service/hlo_rematerialization.cc b/tensorflow/compiler/xla/service/hlo_rematerialization.cc index 1747790e63..c6b4dc0368 100644 --- a/tensorflow/compiler/xla/service/hlo_rematerialization.cc +++ b/tensorflow/compiler/xla/service/hlo_rematerialization.cc @@ -1213,11 +1213,12 @@ StatusOr<bool> HloRematerialization::Run( XLA_VLOG_LINES(3, "Before HloRematerialization:\n" + module->ToString()); // Create initial sequence of HLO instructions. - TF_ASSIGN_OR_RETURN(*sequence, - CreateMemoryMinimizingSequence( - *module, [this](const LogicalBuffer& buffer) { - return size_function_(buffer.shape()); - })); + TF_ASSIGN_OR_RETURN(*sequence, CreateMemoryMinimizingSequence( + *module, + [this](const LogicalBuffer& buffer) { + return size_function_(buffer.shape()); + }, + scheduler_algorithm_)); // Compute peak memory usage of all computations in the module called in a // sequential context. call_graph_ = CallGraph::Build(module); @@ -1318,9 +1319,10 @@ StatusOr<bool> HloRematerialization::Run( /* static */ StatusOr<bool> HloRematerialization::RematerializeAndSchedule( const HloRematerialization::ShapeSizeFunction& size_function, int64 memory_limit_bytes, HloModule* hlo_module, + SchedulerAlgorithm scheduler_algorithm, SequentialHloOrdering::HloModuleSequence* sequence, RematerializationSizes* sizes) { - HloRematerialization remat(size_function); + HloRematerialization remat(scheduler_algorithm, size_function); return remat.Run(hlo_module, sequence, memory_limit_bytes, sizes); } |