diff options
author | 2018-07-03 17:53:12 -0700 | |
---|---|---|
committer | 2018-07-03 17:56:11 -0700 | |
commit | ae5cdb7028df3215d04bdb84c4e87b79f2b62b97 (patch) | |
tree | 80b4aca9181cfa1db7580c7bc3f6df96f0e831ae /tensorflow/compiler/xla/service/hlo_rematerialization.cc | |
parent | 774a01fa39cb1f31d4484d9cb82a700d0f70c4e3 (diff) |
Automated g4 rollback of changelist 203171335
PiperOrigin-RevId: 203211687
Diffstat (limited to 'tensorflow/compiler/xla/service/hlo_rematerialization.cc')
-rw-r--r-- | tensorflow/compiler/xla/service/hlo_rematerialization.cc | 22 |
1 file changed, 10 insertions, 12 deletions
```diff
diff --git a/tensorflow/compiler/xla/service/hlo_rematerialization.cc b/tensorflow/compiler/xla/service/hlo_rematerialization.cc
index 0b222f4348..59a8800a7d 100644
--- a/tensorflow/compiler/xla/service/hlo_rematerialization.cc
+++ b/tensorflow/compiler/xla/service/hlo_rematerialization.cc
@@ -1202,17 +1202,14 @@ StatusOr<bool> HloRematerialization::RematerializeComputation(
 
 StatusOr<bool> HloRematerialization::Run(
     HloModule* module, SequentialHloOrdering::HloModuleSequence* sequence,
-    int64 memory_limit_bytes, RematerializationSizes* sizes) {
+    int64 memory_limit_bytes, RematerializationSizes* sizes,
+    bool run_copy_elision) {
   // The sequence is constructed entirely by this method.
   TF_RET_CHECK(sequence->empty());
 
   VLOG(1) << "HloRematerialization() with memory limit of "
           << HumanReadableNumBytes(memory_limit_bytes);
 
-  if (copy_insertion_) {
-    TF_RETURN_IF_ERROR(copy_insertion_->Run(module).status());
-  }
-
   TF_ASSIGN_OR_RETURN(points_to_analysis_, TuplePointsToAnalysis::Run(module));
 
   // Adjust memory limit to account for the output of the entry
@@ -1241,12 +1238,13 @@ StatusOr<bool> HloRematerialization::Run(
                             return size_function_(buffer.shape());
                           },
                           scheduler_algorithm_));
-  if (copy_insertion_) {
+  if (run_copy_elision) {
     // We run a separate pass of copy elision here because the sequential
     // ordering from the HLO schedule allows for more copies to be eliminated.
+    // TODO(b/80249101): Instead of a separate copy elision pass, use the
+    // ordering from the HLO schedule directly for copy insertion.
     SequentialHloOrdering ordering(module, *sequence);
-    TF_RETURN_IF_ERROR(
-        copy_insertion_->RemoveUnnecessaryCopies(ordering, module));
+    TF_RETURN_IF_ERROR(RemoveUnnecessaryCopies(ordering, module));
   }
 
   // Compute peak memory usage of all computations in the module called in a
@@ -1351,10 +1349,10 @@ StatusOr<bool> HloRematerialization::Run(
     int64 memory_limit_bytes, HloModule* hlo_module,
     MemorySchedulerAlgorithm scheduler_algorithm,
     SequentialHloOrdering::HloModuleSequence* sequence,
-    RematerializationSizes* sizes, CopyInsertion* copy_insertion) {
-  HloRematerialization remat(std::move(scheduler_algorithm), size_function,
-                             copy_insertion);
-  return remat.Run(hlo_module, sequence, memory_limit_bytes, sizes);
+    RematerializationSizes* sizes, bool run_copy_elision) {
+  HloRematerialization remat(scheduler_algorithm, size_function);
+  return remat.Run(hlo_module, sequence, memory_limit_bytes, sizes,
+                   run_copy_elision);
 }
 
 }  // namespace xla
```