diff options
Diffstat (limited to 'tensorflow/compiler/xla/service/hlo_rematerialization.cc')
-rw-r--r-- | tensorflow/compiler/xla/service/hlo_rematerialization.cc | 11 |
1 files changed, 6 insertions, 5 deletions
diff --git a/tensorflow/compiler/xla/service/hlo_rematerialization.cc b/tensorflow/compiler/xla/service/hlo_rematerialization.cc
index 62c07d7fac..cf0be30c7a 100644
--- a/tensorflow/compiler/xla/service/hlo_rematerialization.cc
+++ b/tensorflow/compiler/xla/service/hlo_rematerialization.cc
@@ -1203,7 +1203,7 @@ StatusOr<bool> HloRematerialization::RematerializeComputation(
 StatusOr<bool> HloRematerialization::Run(
     HloModule* module, SequentialHloOrdering::HloModuleSequence* sequence,
     int64 memory_limit_bytes, RematerializationSizes* sizes,
-    bool run_copy_elision) {
+    CopyInsertion* copy_insertion) {
   // The sequence is constructed entirely by this method.
   TF_RET_CHECK(sequence->empty());
 
@@ -1238,13 +1238,14 @@ StatusOr<bool> HloRematerialization::Run(
                                        return size_function_(buffer.shape());
                                      },
                                      scheduler_algorithm_));
-  if (run_copy_elision) {
+  if (copy_insertion) {
     // We run a separate pass of copy elision here because the sequential
     // ordering from the HLO schedule allows for more copies to be eliminated.
     // TODO(b/80249101): Instead of a separate copy elision pass, use the
     // ordering from the HLO schedule directly for copy insertion.
     SequentialHloOrdering ordering(module, *sequence);
-    TF_RETURN_IF_ERROR(RemoveUnnecessaryCopies(ordering, {}, module));
+    TF_RETURN_IF_ERROR(
+        copy_insertion->RemoveUnnecessaryCopies(ordering, module));
   }
 
   // Compute peak memory usage of all computations in the module called in a
@@ -1349,10 +1350,10 @@ StatusOr<bool> HloRematerialization::Run(
     int64 memory_limit_bytes, HloModule* hlo_module,
     MemorySchedulerAlgorithm scheduler_algorithm,
     SequentialHloOrdering::HloModuleSequence* sequence,
-    RematerializationSizes* sizes, bool run_copy_elision) {
+    RematerializationSizes* sizes, CopyInsertion* copy_insertion) {
   HloRematerialization remat(scheduler_algorithm, size_function);
   return remat.Run(hlo_module, sequence, memory_limit_bytes, sizes,
-                   run_copy_elision);
+                   copy_insertion);
 }
 
 }  // namespace xla