author    Blake Hechtman <blakehechtman@google.com>  2017-06-16 16:18:44 -0700
committer TensorFlower Gardener <gardener@tensorflow.org>  2017-06-16 16:24:15 -0700
commit    2473505b70f03838fef4ae6108387ba66d443d62 (patch)
tree      5e1bd767e5e5cc1d8cf2e1d329e8bd7108b29c86
parent    df609c9de4ea0cae0fb1d41893b8071d67bd6bb2 (diff)
Adjust rematerialization cost function to only be the inverse of its benefit (total_memory/saved_memory).
PiperOrigin-RevId: 159290105
-rw-r--r--  tensorflow/compiler/xla/service/hlo_rematerialization.cc  |  57
-rw-r--r--  tensorflow/compiler/xla/service/hlo_rematerialization.h   |   6
2 files changed, 15 insertions, 48 deletions
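
At a glance: the old cost divided an estimate of recomputation work (flop count + transcendental count + elements accessed, taken from HloCostAnalysis) by the memory saved; the new cost drops the compute term entirely and keeps only the inverse of the benefit, memory_limit_bytes / memory_reduced. A minimal side-by-side sketch of the two formulas follows; the standalone helper names are hypothetical, not XLA's actual signatures.

#include <cstdint>

// Old heuristic: estimated recompute cost per byte saved. The factor of 256
// keeps integer division from truncating small ratios to zero.
int64_t OldRematCost(int64_t flops, int64_t transcendentals,
                     int64_t elements_accessed, int64_t memory_reduced) {
  return 256 * (flops + transcendentals + elements_accessed) / memory_reduced;
}

// New heuristic: inverse of the benefit. Compute cost is ignored, so the
// candidate that frees the most memory always scores lowest (best).
int64_t NewRematCost(int64_t memory_limit_bytes, int64_t memory_reduced) {
  return memory_limit_bytes / memory_reduced;
}
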
diff --git a/tensorflow/compiler/xla/service/hlo_rematerialization.cc b/tensorflow/compiler/xla/service/hlo_rematerialization.cc
index 2c1b0fff4e..fb6d8674b6 100644
--- a/tensorflow/compiler/xla/service/hlo_rematerialization.cc
+++ b/tensorflow/compiler/xla/service/hlo_rematerialization.cc
@@ -58,9 +58,8 @@ bool IsRematerializable(const HloInstruction* instruction) {
return false;
}
- // Don't rematerialize instructions with side effects, those with a cost that
- // might not be captured by HloCostAnalysis, or instructions which cannot be
- // cloned safely.
+ // Don't rematerialize instructions with side effects or instructions which
+ // cannot be cloned safely.
switch (instruction->opcode()) {
case HloOpcode::kCall:
case HloOpcode::kConstant:
@@ -802,23 +801,14 @@ bool MemoryUsageTracker::Check() const {
// Computes and returns the cost of rematerializing the given instruction.
// Cost per rematerialized instruction is defined as:
//
-// (flop_count + transcendental_count + element_count) / memory_reduced
+// memory_limit_bytes / memory_reduced
//
-// flop_count: from HloCostAnalysis
-// transcendental_count: from HloCostAnalysis
-// element_count: number of elements accessed in operands and output of
-// instruction
-// memory_reduced: The memory usage reduced by rematerializing the
-// instruction.
-//
-// This is a rough estimate of the extra execution time per byte saved by
-// rematerializing this instruction for its remaining uses. In general, we
-// want the most memory saving for the least latency penalty which is captured
-// by this heuristic.
+// The idea is to choose the operation that saves the most memory when
+// rematerialized, and not to worry about its compute cost, since running
+// out of memory is more harmful than taking longer to get the answer.
int64 RematerializationCost(const HloInstruction* instruction,
const MemoryUsageTracker& memory_tracker,
- const HloCostAnalysis& cost_analysis,
- int64 memory_reduced) {
+ int64 memory_reduced, int64 memory_limit_bytes) {
// If none of the users of 'instruction' have been placed in the sequence (as
// tracked by memory_tracker), then rematerialization of 'instruction' is a
// zero-cost move of 'instruction' in the sequence.
@@ -830,22 +820,8 @@ int64 RematerializationCost(const HloInstruction* instruction,
}
CHECK_GT(memory_reduced, 0);
- const int64 bytes_accessed = cost_analysis.bytes_accessed(*instruction);
- const int64 elements_accessed =
- ShapeUtil::IsTuple(instruction->shape())
- ? bytes_accessed
- : bytes_accessed / ShapeUtil::ByteSizeOfPrimitiveType(
- instruction->shape().element_type());
-
- // Multiply by 256 to improve precision of cost. Without this factor,
- // many instructions such as many elementwise instructions would have
- // zero cost because the bytes reduced can be several times greater than
- // the element count.
- return 256 *
- (cost_analysis.flop_count(*instruction) +
- cost_analysis.transcendental_count(*instruction) +
- elements_accessed) /
- memory_reduced;
+ // Return the inverse of the benefit of rematerialization.
+ return memory_limit_bytes / memory_reduced;
}
// Selects and returns the best candidate instruction for rematerialization.
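
A worked example of the new formula, with numbers invented for illustration:

// Worked example (illustrative values only):
//   memory_limit_bytes = 1 GiB = 1073741824
//   candidate A frees 64 MiB -> cost = 1073741824 / 67108864 =  16
//   candidate B frees  4 MiB -> cost = 1073741824 /  4194304 = 256
// Lower cost is better, so A is chosen regardless of its flop count.
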
@@ -856,8 +832,8 @@ int64 RematerializationCost(const HloInstruction* instruction,
HloInstruction* PickRematerializationCandidate(
const MemoryUsageTracker& memory_tracker,
const InstructionList& instruction_list,
- const HloCostAnalysis& cost_analysis,
- const tensorflow::gtl::FlatSet<const HloInstruction*>& blacklist) {
+ const tensorflow::gtl::FlatSet<const HloInstruction*>& blacklist,
+ int64 memory_limit_bytes) {
HloInstruction* best = nullptr;
int64 best_cost = 0;
@@ -891,12 +867,12 @@ HloInstruction* PickRematerializationCandidate(
if (memory_reduced <= 0) {
VLOG(5) << "candidate " << candidate->name()
- << " memory reduced = " << memory_reduced << " <= 0";
+ << " memory reduced = " << memory_reduced << " <= 0";
continue;
}
const int cost = RematerializationCost(candidate, memory_tracker,
- cost_analysis, memory_reduced);
+ memory_reduced, memory_limit_bytes);
VLOG(5) << "candidate " << candidate->name() << ", memory reduced "
<< memory_reduced << ", cost per byte " << cost;
@@ -1011,7 +987,7 @@ StatusOr<bool> HloRematerialization::RematerializeComputation(
<< ", limit is " << HumanReadableNumBytes(memory_limit_bytes);
HloInstruction* best = PickRematerializationCandidate(
- memory_tracker, instruction_list, cost_analysis_, blacklist);
+ memory_tracker, instruction_list, blacklist, memory_limit_bytes);
if (best == nullptr) {
VLOG(3) << "Unable to find rematerialization candidate at program "
@@ -1211,11 +1187,6 @@ StatusOr<bool> HloRematerialization::Run(
VLOG(1) << "Peak memory usage of module (before): "
<< HumanReadableNumBytes(before_peak_memory);
- // Run cost analysis. Operation cost is used in the heuristic for selecting
- // instructions for rematerialization.
- TF_RETURN_IF_ERROR(
- module->entry_computation()->root_instruction()->Accept(&cost_analysis_));
-
// Subcomputations called by the entry computation will also be
// rematerialized.
TF_ASSIGN_OR_RETURN(bool changed, RematerializeComputation(
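
Taken together, the .cc changes reduce candidate selection to a greedy scan for the most memory-saving rematerializable instruction. Below is a self-contained sketch of that loop, using a simplified Candidate struct in place of XLA's InstructionList and MemoryUsageTracker; all names here are hypothetical and it mirrors only the shape of PickRematerializationCandidate after this change, not its exact code (the real function also consults a blacklist and handles zero-cost moves).

#include <cstdint>
#include <vector>

struct Candidate {
  const char* name;
  int64_t memory_reduced;  // bytes that would be freed by rematerializing
};

// Skip candidates that free no memory, score the rest by
// memory_limit_bytes / memory_reduced, and keep the lowest score.
// Returns nullptr if no candidate frees memory.
const Candidate* PickCandidate(const std::vector<Candidate>& candidates,
                               int64_t memory_limit_bytes) {
  const Candidate* best = nullptr;
  int64_t best_cost = 0;
  for (const Candidate& c : candidates) {
    if (c.memory_reduced <= 0) continue;  // no benefit from rematerializing
    const int64_t cost = memory_limit_bytes / c.memory_reduced;
    if (best == nullptr || cost < best_cost) {
      best = &c;
      best_cost = cost;
    }
  }
  return best;
}
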
diff --git a/tensorflow/compiler/xla/service/hlo_rematerialization.h b/tensorflow/compiler/xla/service/hlo_rematerialization.h
index 1693f93183..42c279d440 100644
--- a/tensorflow/compiler/xla/service/hlo_rematerialization.h
+++ b/tensorflow/compiler/xla/service/hlo_rematerialization.h
@@ -18,7 +18,6 @@
#include "tensorflow/compiler/xla/service/buffer_liveness.h"
#include "tensorflow/compiler/xla/service/call_graph.h"
#include "tensorflow/compiler/xla/service/hlo_computation.h"
-#include "tensorflow/compiler/xla/service/hlo_cost_analysis.h"
#include "tensorflow/compiler/xla/service/hlo_instruction.h"
#include "tensorflow/compiler/xla/service/hlo_module.h"
#include "tensorflow/compiler/xla/service/tuple_points_to_analysis.h"
@@ -61,7 +60,7 @@ class HloRematerialization {
protected:
HloRematerialization(const ShapeSizeFunction& size_function)
- : size_function_(size_function), cost_analysis_(size_function_) {}
+ : size_function_(size_function) {}
~HloRematerialization() {}
// Runs rematerialization on the given module. Returns whether the module was
@@ -100,9 +99,6 @@ class HloRematerialization {
// Call graph of the hlo_module.
std::unique_ptr<CallGraph> call_graph_;
- // Analysis used for computing the rematerialization cost of instructions.
- HloCostAnalysis cost_analysis_;
-
// The peak memory usage of each computation. The map contains only those
// computations called from sequential context
// (CallContext::kSequential). These values are updated as rematerialization