diff options
author | 2017-06-16 16:18:44 -0700 | |
---|---|---|
committer | 2017-06-16 16:24:15 -0700 | |
commit | 2473505b70f03838fef4ae6108387ba66d443d62 (patch) | |
tree | 5e1bd767e5e5cc1d8cf2e1d329e8bd7108b29c86 /tensorflow/compiler/xla/service/hlo_rematerialization.h | |
parent | df609c9de4ea0cae0fb1d41893b8071d67bd6bb2 (diff) |
Adjust rematerialization cost function to only be the inverse of its benefit (total_memory/saved_memory).
PiperOrigin-RevId: 159290105
Diffstat (limited to 'tensorflow/compiler/xla/service/hlo_rematerialization.h')
-rw-r--r-- | tensorflow/compiler/xla/service/hlo_rematerialization.h | 6 |
1 files changed, 1 insertions, 5 deletions
diff --git a/tensorflow/compiler/xla/service/hlo_rematerialization.h b/tensorflow/compiler/xla/service/hlo_rematerialization.h index 1693f93183..42c279d440 100644 --- a/tensorflow/compiler/xla/service/hlo_rematerialization.h +++ b/tensorflow/compiler/xla/service/hlo_rematerialization.h @@ -18,7 +18,6 @@ #include "tensorflow/compiler/xla/service/buffer_liveness.h" #include "tensorflow/compiler/xla/service/call_graph.h" #include "tensorflow/compiler/xla/service/hlo_computation.h" -#include "tensorflow/compiler/xla/service/hlo_cost_analysis.h" #include "tensorflow/compiler/xla/service/hlo_instruction.h" #include "tensorflow/compiler/xla/service/hlo_module.h" #include "tensorflow/compiler/xla/service/tuple_points_to_analysis.h" @@ -61,7 +60,7 @@ class HloRematerialization { protected: HloRematerialization(const ShapeSizeFunction& size_function) - : size_function_(size_function), cost_analysis_(size_function_) {} + : size_function_(size_function) {} ~HloRematerialization() {} // Runs rematerialization on the given module. Returns whether the module was @@ -100,9 +99,6 @@ class HloRematerialization { // Call graph of the hlo_module. std::unique_ptr<CallGraph> call_graph_; - // Analysis used for computing the rematerialization cost of instructions. - HloCostAnalysis cost_analysis_; - // The peak memory usage of each computation. The map contains only those // computations called from sequential context // (CallContext::kSequential). These values are updated as rematerialization |