about | summary | refs | log | tree | commit | diff | homepage
path: root/tensorflow/compiler/xla/service/hlo_rematerialization.cc
diff options
context:
space:
mode:
Diffstat (limited to 'tensorflow/compiler/xla/service/hlo_rematerialization.cc')
-rw-r--r-- tensorflow/compiler/xla/service/hlo_rematerialization.cc | 11
1 files changed, 6 insertions, 5 deletions
diff --git a/tensorflow/compiler/xla/service/hlo_rematerialization.cc b/tensorflow/compiler/xla/service/hlo_rematerialization.cc
index 62c07d7fac..cf0be30c7a 100644
--- a/tensorflow/compiler/xla/service/hlo_rematerialization.cc
+++ b/tensorflow/compiler/xla/service/hlo_rematerialization.cc
@@ -1203,7 +1203,7 @@ StatusOr<bool> HloRematerialization::RematerializeComputation(
StatusOr<bool> HloRematerialization::Run(
HloModule* module, SequentialHloOrdering::HloModuleSequence* sequence,
int64 memory_limit_bytes, RematerializationSizes* sizes,
- bool run_copy_elision) {
+ CopyInsertion* copy_insertion) {
// The sequence is constructed entirely by this method.
TF_RET_CHECK(sequence->empty());
@@ -1238,13 +1238,14 @@ StatusOr<bool> HloRematerialization::Run(
return size_function_(buffer.shape());
},
scheduler_algorithm_));
- if (run_copy_elision) {
+ if (copy_insertion) {
// We run a separate pass of copy elision here because the sequential
// ordering from the HLO schedule allows for more copies to be eliminated.
// TODO(b/80249101): Instead of a separate copy elision pass, use the
// ordering from the HLO schedule directly for copy insertion.
SequentialHloOrdering ordering(module, *sequence);
- TF_RETURN_IF_ERROR(RemoveUnnecessaryCopies(ordering, {}, module));
+ TF_RETURN_IF_ERROR(
+ copy_insertion->RemoveUnnecessaryCopies(ordering, module));
}
// Compute peak memory usage of all computations in the module called in a
@@ -1349,10 +1350,10 @@ StatusOr<bool> HloRematerialization::Run(
int64 memory_limit_bytes, HloModule* hlo_module,
MemorySchedulerAlgorithm scheduler_algorithm,
SequentialHloOrdering::HloModuleSequence* sequence,
- RematerializationSizes* sizes, bool run_copy_elision) {
+ RematerializationSizes* sizes, CopyInsertion* copy_insertion) {
HloRematerialization remat(scheduler_algorithm, size_function);
return remat.Run(hlo_module, sequence, memory_limit_bytes, sizes,
- run_copy_elision);
+ copy_insertion);
}
} // namespace xla