aboutsummaryrefslogtreecommitdiffhomepage
path: root/tensorflow/compiler/xla/service/hlo_rematerialization.h
diff options
context:
space:
mode:
authorGravatar Blake Hechtman <blakehechtman@google.com>2017-06-16 16:18:44 -0700
committerGravatar TensorFlower Gardener <gardener@tensorflow.org>2017-06-16 16:24:15 -0700
commit2473505b70f03838fef4ae6108387ba66d443d62 (patch)
tree5e1bd767e5e5cc1d8cf2e1d329e8bd7108b29c86 /tensorflow/compiler/xla/service/hlo_rematerialization.h
parentdf609c9de4ea0cae0fb1d41893b8071d67bd6bb2 (diff)
Adjust rematerialization cost function to only be the inverse of its benefit (total_memory/saved_memory).
PiperOrigin-RevId: 159290105
Diffstat (limited to 'tensorflow/compiler/xla/service/hlo_rematerialization.h')
-rw-r--r--tensorflow/compiler/xla/service/hlo_rematerialization.h6
1 files changed, 1 insertions, 5 deletions
diff --git a/tensorflow/compiler/xla/service/hlo_rematerialization.h b/tensorflow/compiler/xla/service/hlo_rematerialization.h
index 1693f93183..42c279d440 100644
--- a/tensorflow/compiler/xla/service/hlo_rematerialization.h
+++ b/tensorflow/compiler/xla/service/hlo_rematerialization.h
@@ -18,7 +18,6 @@
#include "tensorflow/compiler/xla/service/buffer_liveness.h"
#include "tensorflow/compiler/xla/service/call_graph.h"
#include "tensorflow/compiler/xla/service/hlo_computation.h"
-#include "tensorflow/compiler/xla/service/hlo_cost_analysis.h"
#include "tensorflow/compiler/xla/service/hlo_instruction.h"
#include "tensorflow/compiler/xla/service/hlo_module.h"
#include "tensorflow/compiler/xla/service/tuple_points_to_analysis.h"
@@ -61,7 +60,7 @@ class HloRematerialization {
protected:
HloRematerialization(const ShapeSizeFunction& size_function)
- : size_function_(size_function), cost_analysis_(size_function_) {}
+ : size_function_(size_function) {}
~HloRematerialization() {}
// Runs rematerialization on the given module. Returns whether the module was
@@ -100,9 +99,6 @@ class HloRematerialization {
// Call graph of the hlo_module.
std::unique_ptr<CallGraph> call_graph_;
- // Analysis used for computing the rematerialization cost of instructions.
- HloCostAnalysis cost_analysis_;
-
// The peak memory usage of each computation. The map contains only those
// computations called from sequential context
// (CallContext::kSequential). These values are updated as rematerialization