aboutsummaryrefslogtreecommitdiffhomepage
path: root/tensorflow/compiler/xla/service/hlo_rematerialization.h
diff options
context:
space:
mode:
Diffstat (limited to 'tensorflow/compiler/xla/service/hlo_rematerialization.h')
-rw-r--r--tensorflow/compiler/xla/service/hlo_rematerialization.h10
1 files changed, 6 insertions, 4 deletions
diff --git a/tensorflow/compiler/xla/service/hlo_rematerialization.h b/tensorflow/compiler/xla/service/hlo_rematerialization.h
index 59b4cf5dcc..2ec004350a 100644
--- a/tensorflow/compiler/xla/service/hlo_rematerialization.h
+++ b/tensorflow/compiler/xla/service/hlo_rematerialization.h
@@ -17,6 +17,7 @@
#include "tensorflow/compiler/xla/service/buffer_liveness.h"
#include "tensorflow/compiler/xla/service/call_graph.h"
+#include "tensorflow/compiler/xla/service/copy_insertion.h"
#include "tensorflow/compiler/xla/service/hlo_computation.h"
#include "tensorflow/compiler/xla/service/hlo_instruction.h"
#include "tensorflow/compiler/xla/service/hlo_module.h"
@@ -57,8 +58,9 @@ class HloRematerialization {
// sizes: Optional outparam that indicates the peak memory usage of the HLO
// module before/after rematerialization.
//
- // run_copy_elision: Enable copy elision. This pass is used to eliminate
- // copies that were inserted before HLO scheduling.
+ // copy_insertion: If non-null, run copy elision after scheduling. This
+ // pass is used to eliminate copies that were inserted by copy insertion
+ // before HLO scheduling.
//
// TODO(b/80249101): Remove the 'run_copy_elision' parameter when copy
// insertion is integrated with HLO scheduling.
@@ -74,7 +76,7 @@ class HloRematerialization {
const ShapeSizeFunction& size_function, int64 memory_limit_bytes,
HloModule* hlo_module, MemorySchedulerAlgorithm scheduler_algorithm,
SequentialHloOrdering::HloModuleSequence* sequence,
- RematerializationSizes* sizes, bool run_copy_elision = true);
+ RematerializationSizes* sizes, CopyInsertion* copy_insertion = nullptr);
protected:
HloRematerialization(MemorySchedulerAlgorithm scheduler_algorithm,
@@ -90,7 +92,7 @@ class HloRematerialization {
StatusOr<bool> Run(HloModule* module,
SequentialHloOrdering::HloModuleSequence* sequence,
int64 memory_limit, RematerializationSizes* sizes,
- bool run_copy_elision);
+ CopyInsertion* copy_insertion);
// Rematerializes instructions within the given computation. 'order' is the
// order in which the computation's instructions will be emitted in the