aboutsummaryrefslogtreecommitdiffhomepage
path: root/tensorflow/compiler/xla/service/hlo_rematerialization.cc
diff options
context:
space:
mode:
authorGravatar Chris Leary <leary@google.com>2017-12-15 15:13:31 -0800
committerGravatar TensorFlower Gardener <gardener@tensorflow.org>2017-12-15 15:17:37 -0800
commita81e83eea823ff1f3e6871eb24f85e7ca09dcf72 (patch)
tree6eb2208b6043b8a4310c76f259e4942ae5ff1cc5 /tensorflow/compiler/xla/service/hlo_rematerialization.cc
parent7f8e7437693d1051fa378047ae9ee75f91201cb5 (diff)
[XLA] Add a flag to control the HLO scheduling algorithm choice.
List scheduling is more easily rematerialized sometimes, this gives the ability to force list scheduling via the API. PiperOrigin-RevId: 179246142
Diffstat (limited to 'tensorflow/compiler/xla/service/hlo_rematerialization.cc')
-rw-r--r--tensorflow/compiler/xla/service/hlo_rematerialization.cc14
1 files changed, 8 insertions, 6 deletions
diff --git a/tensorflow/compiler/xla/service/hlo_rematerialization.cc b/tensorflow/compiler/xla/service/hlo_rematerialization.cc
index 1747790e63..c6b4dc0368 100644
--- a/tensorflow/compiler/xla/service/hlo_rematerialization.cc
+++ b/tensorflow/compiler/xla/service/hlo_rematerialization.cc
@@ -1213,11 +1213,12 @@ StatusOr<bool> HloRematerialization::Run(
XLA_VLOG_LINES(3, "Before HloRematerialization:\n" + module->ToString());
// Create initial sequence of HLO instructions.
- TF_ASSIGN_OR_RETURN(*sequence,
- CreateMemoryMinimizingSequence(
- *module, [this](const LogicalBuffer& buffer) {
- return size_function_(buffer.shape());
- }));
+ TF_ASSIGN_OR_RETURN(*sequence, CreateMemoryMinimizingSequence(
+ *module,
+ [this](const LogicalBuffer& buffer) {
+ return size_function_(buffer.shape());
+ },
+ scheduler_algorithm_));
// Compute peak memory usage of all computations in the module called in a
// sequential context.
call_graph_ = CallGraph::Build(module);
@@ -1318,9 +1319,10 @@ StatusOr<bool> HloRematerialization::Run(
/* static */ StatusOr<bool> HloRematerialization::RematerializeAndSchedule(
const HloRematerialization::ShapeSizeFunction& size_function,
int64 memory_limit_bytes, HloModule* hlo_module,
+ SchedulerAlgorithm scheduler_algorithm,
SequentialHloOrdering::HloModuleSequence* sequence,
RematerializationSizes* sizes) {
- HloRematerialization remat(size_function);
+ HloRematerialization remat(scheduler_algorithm, size_function);
return remat.Run(hlo_module, sequence, memory_limit_bytes, sizes);
}