aboutsummaryrefslogtreecommitdiffhomepage
path: root/tensorflow/compiler/xla/service/hlo_rematerialization.h
diff options
context:
space:
mode:
authorGravatar Mark Heffernan <meheff@google.com>2017-03-24 14:30:31 -0800
committerGravatar TensorFlower Gardener <gardener@tensorflow.org>2017-03-24 15:50:42 -0700
commit060e002e70e1abf04144a107fde939bda4051ac5 (patch)
tree896c50a4def912c0edd0c5b4ebf9e770b390dec6 /tensorflow/compiler/xla/service/hlo_rematerialization.h
parentb7de7cb58908c517febbb085eb5dbaa1a92cf3a2 (diff)
[XLA] Rematerialize subcomputations.
Extend HLO rematerialization to rematerialize subcomputations in addition to the entry computations. Outer nesting levels of computations are rematerialized before inner nesting levels because inner subcomputations may be while bodies where rematerialization is more expensive. Also Also fix latent bug in call_graph dealing with fusion instructions, and extend HloInstruction::Clone to accept a string suffix (eg, "remat") for the clone name. Change: 151179956
Diffstat (limited to 'tensorflow/compiler/xla/service/hlo_rematerialization.h')
-rw-r--r--tensorflow/compiler/xla/service/hlo_rematerialization.h56
1 files changed, 42 insertions, 14 deletions
diff --git a/tensorflow/compiler/xla/service/hlo_rematerialization.h b/tensorflow/compiler/xla/service/hlo_rematerialization.h
index 580a805ef0..86e1998b89 100644
--- a/tensorflow/compiler/xla/service/hlo_rematerialization.h
+++ b/tensorflow/compiler/xla/service/hlo_rematerialization.h
@@ -16,7 +16,9 @@
#define TENSORFLOW_COMPILER_XLA_SERVICE_HLO_REMATERIALIZATION_H_
#include "tensorflow/compiler/xla/service/buffer_liveness.h"
+#include "tensorflow/compiler/xla/service/call_graph.h"
#include "tensorflow/compiler/xla/service/hlo_computation.h"
+#include "tensorflow/compiler/xla/service/hlo_cost_analysis.h"
#include "tensorflow/compiler/xla/service/hlo_instruction.h"
#include "tensorflow/compiler/xla/service/hlo_module.h"
@@ -26,10 +28,10 @@ class HloRematerialization {
public:
using ShapeSizeFunction = std::function<int64(const Shape&)>;
- // Rematerialize HLO instructions in the entry computation of the given module
- // to reduce maximum memory use below memory_limit_bytes where memory use is
- // defined as the total size of all live HLO instruction values. Parameters
- // and constants are included in memory use estimates. Method parameters:
+ // Rematerialize HLO instructions in the given module to reduce peak memory
+ // use below memory_limit_bytes where memory use is defined as the total size
+ // of all live HLO instruction values. Parameters and constants are included
+ // in memory use estimates. Method parameters:
//
// size_function: Function which returns the size in bytes of the top-level
// buffer of the given shape.
@@ -57,29 +59,55 @@ class HloRematerialization {
SequentialHloOrdering::HloModuleSequence* sequence);
protected:
- HloRematerialization(const ShapeSizeFunction& size_function,
- int64 memory_limit_bytes)
- : size_function_(size_function),
- memory_limit_bytes_(memory_limit_bytes) {}
+ HloRematerialization(const ShapeSizeFunction& size_function)
+ : size_function_(size_function), cost_analysis_(size_function_) {}
~HloRematerialization() {}
// Runs rematerialization on the given module. Returns whether the module was
- // changed.
+ // changed. memory_limit is the target maximum peak memory usage by the
+ // module. sequence should be an empty HloModuleSequence. Upon return sequence
+ // contains the memory-minimizing order in which to emit the HLO instructions.
StatusOr<bool> Run(HloModule* module,
- SequentialHloOrdering::HloModuleSequence* sequence);
+ SequentialHloOrdering::HloModuleSequence* sequence,
+ int64 memory_limit);
// Rematerializes instructions within the given computation. 'order' is the
// order in which the computation's instructions will be emitted in the
// backend. Rematerialized instructions will be added to the HLO computation
// and inserted into 'order'.
StatusOr<bool> RematerializeComputation(
- HloComputation* computation, std::vector<const HloInstruction*>* order);
+ HloComputation* computation,
+ SequentialHloOrdering::HloModuleSequence* sequence,
+ int64 computation_memory_limit);
- // Returns the total size of the shape (including nested elements) in bytes.
- int64 TotalSizeBytes(const Shape& shape);
+ // Computes and returns the peak memory used by the given computation. The
+ // peak memory is the maximum total size of all live HLO instruction values at
+ // any program point. 'order' is the order in which the HLO instructions will
+ // be emitted which is used to determine lifespans of HLO values.
+ StatusOr<int64> ComputePeakMemory(
+ const HloComputation* computation,
+ const std::vector<const HloInstruction*>& order) const;
+ // Returns the peak memory usage of the called computations for the given
+ // instruction. Zero is returned if the instruction calls no computations.
+ StatusOr<int64> CalledComputationsMemoryUsage(
+ const HloInstruction* instruction) const;
+
+ // Function which computes the size of the top-level buffer of a shape.
const ShapeSizeFunction size_function_;
- const int64 memory_limit_bytes_;
+
+ // Call graph of the hlo_module.
+ std::unique_ptr<CallGraph> call_graph_;
+
+ // Analysis used for computing the rematerialization cost of instructions.
+ HloCostAnalysis cost_analysis_;
+
+ // The peak memory usage of each computation. The map contains only those
+ // computations called from sequential context
+ // (CallContext::kSequential). These values are updated as rematerialization
+ // occurs.
+ tensorflow::gtl::FlatMap<const HloComputation*, int64>
+ computation_peak_memory_;
};
} // namespace xla