diff options
author | A. Unique TensorFlower <gardener@tensorflow.org> | 2018-08-23 06:09:54 -0700 |
---|---|---|
committer | TensorFlower Gardener <gardener@tensorflow.org> | 2018-08-23 06:16:01 -0700 |
commit | 9289302ad3d7941ddb9ce2d0dff56b333cbcf208 (patch) | |
tree | 8b90a81e214344e9926ee96564fc4f4bc6de530e /tensorflow/compiler/xla/service/hlo_instruction.h | |
parent | d2c36c1b3e84e4e29c2853aa421ada45ff5fd396 (diff) |
Reduce the memory usage of sharding domains
Previously the domain instructions inserted before and after an `n`
element tuple required `O(n^2)` memory (and compute) because every
operand and user had its own domain instruction with a tuple sharding
and tuple shape for the exit domains, which constructed `n` HloSharding
and `n` Shape protos per domain.
After this change we keep track of the domain instructions inserted and
if we already have a domain instruction with the correct operand and
metadata then we re-use it instead of creating a new one.
Additionally we change HloInstruction and ShardingMetadata to store a
std::shared_ptr to HloSharding so the same instance can be shared by
many instructions. This CL doesn't update all uses to remove all of the
duplicated HloShardings but handles the most wasteful cases to reduce
memory usage.
PiperOrigin-RevId: 209924260
Diffstat (limited to 'tensorflow/compiler/xla/service/hlo_instruction.h')
-rw-r--r-- | tensorflow/compiler/xla/service/hlo_instruction.h | 12 |
1 file changed, 10 insertions, 2 deletions
diff --git a/tensorflow/compiler/xla/service/hlo_instruction.h b/tensorflow/compiler/xla/service/hlo_instruction.h index 121a9e55f6..432bb464f3 100644 --- a/tensorflow/compiler/xla/service/hlo_instruction.h +++ b/tensorflow/compiler/xla/service/hlo_instruction.h @@ -1037,6 +1037,8 @@ class HloInstruction { CHECK(has_sharding()); return *sharding_; } + std::shared_ptr<const HloSharding> sharding_ptr() const { return sharding_; } + // Returns the sharding applied to this operator, or default_ if none exists. const HloSharding& sharding_or_default(const HloSharding& default_) const { return sharding_ ? *sharding_ : default_; @@ -1051,7 +1053,10 @@ class HloInstruction { // Sets the sharding of this operator. Should only be called by HloModule or // HloComputation methods. void set_sharding(const HloSharding& sharding) { - sharding_ = absl::make_unique<HloSharding>(sharding); + sharding_ = std::make_shared<const HloSharding>(sharding); + } + void set_sharding(std::shared_ptr<const HloSharding> sharding) { + sharding_ = std::move(sharding); } void set_single_sharding(const HloSharding& sharding); // Sets a sharding that assigns the current instruction to device. @@ -1652,7 +1657,10 @@ class HloInstruction { bool copy_elision_allowed_ = true; // The sharding, if one exists. - std::unique_ptr<HloSharding> sharding_; + // Uses std::shared_ptr to allow reuse of the same sharding object between + // HloInstructions and other components as HloSharding can be very large for + // many element tuples. + std::shared_ptr<const HloSharding> sharding_; // Fields used by the kDomain instruction. std::unique_ptr<DomainMetadata> operand_side_metadata_; |