author	A. Unique TensorFlower <gardener@tensorflow.org>	2018-08-23 06:09:54 -0700
committer	TensorFlower Gardener <gardener@tensorflow.org>	2018-08-23 06:16:01 -0700
commit	9289302ad3d7941ddb9ce2d0dff56b333cbcf208 (patch)
tree	8b90a81e214344e9926ee96564fc4f4bc6de530e /tensorflow/compiler/xla/service/hlo_instruction.h
parent	d2c36c1b3e84e4e29c2853aa421ada45ff5fd396 (diff)
Reduce the memory usage of sharding domains
Previously, the domain instructions inserted before and after an `n`-element tuple required `O(n^2)` memory (and compute): every operand and user had its own domain instruction, and for the exit domains each of those carried a tuple sharding and tuple shape, constructing `n` HloShardings and `n` Shape protos per domain.

After this change we keep track of the domain instructions already inserted, and if a domain instruction with the correct operand and metadata already exists we re-use it instead of creating a new one. Additionally, HloInstruction and ShardingMetadata now store a std::shared_ptr to the HloSharding so the same instance can be shared by many instructions.

This CL doesn't update all uses to remove every duplicated HloSharding, but it handles the most wasteful cases to reduce memory usage.

PiperOrigin-RevId: 209924260
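[Editor's note] The reuse strategy described in the commit message can be illustrated with a small, self-contained C++ sketch. The Sharding, Instruction, and DomainInserter names below are hypothetical stand-ins, not the real XLA classes; the sketch only demonstrates the two ideas the message names: caching domain instructions so each operand gets one instead of one per user, and aliasing a single sharding object through std::shared_ptr instead of copying it per instruction.

// Minimal sketch with stand-in types (NOT the real XLA classes).
#include <iostream>
#include <map>
#include <memory>
#include <string>
#include <utility>

struct Sharding {
  std::string description;  // Stands in for a potentially large tuple sharding.
};

struct Instruction {
  // Shared, immutable sharding: many instructions point at one object.
  std::shared_ptr<const Sharding> sharding;
  const Instruction* operand = nullptr;
};

class DomainInserter {
 public:
  explicit DomainInserter(std::shared_ptr<const Sharding> sharding)
      : sharding_(std::move(sharding)) {}

  // Returns the cached domain instruction for `operand`, creating it only on
  // the first request. Without the cache, every user of `operand` would get
  // its own domain instruction and its own copy of the sharding.
  Instruction* GetOrCreateDomain(const Instruction* operand) {
    auto it = cache_.find(operand);
    if (it != cache_.end()) return it->second.get();
    auto domain = std::make_unique<Instruction>();
    domain->operand = operand;
    domain->sharding = sharding_;  // Aliases the shared object; no copy.
    Instruction* raw = domain.get();
    cache_.emplace(operand, std::move(domain));
    return raw;
  }

 private:
  std::shared_ptr<const Sharding> sharding_;
  std::map<const Instruction*, std::unique_ptr<Instruction>> cache_;
};

int main() {
  auto sharding = std::make_shared<const Sharding>(Sharding{"tuple sharding"});
  DomainInserter inserter(sharding);
  Instruction a, b;
  // Two users of the same operand receive the same domain instruction.
  Instruction* d1 = inserter.GetOrCreateDomain(&a);
  Instruction* d2 = inserter.GetOrCreateDomain(&a);
  Instruction* d3 = inserter.GetOrCreateDomain(&b);
  std::cout << (d1 == d2) << " " << (d1 == d3) << "\n";  // prints "1 0"
  return 0;
}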
Diffstat (limited to 'tensorflow/compiler/xla/service/hlo_instruction.h')
-rw-r--r--	tensorflow/compiler/xla/service/hlo_instruction.h | 12
1 file changed, 10 insertions, 2 deletions
diff --git a/tensorflow/compiler/xla/service/hlo_instruction.h b/tensorflow/compiler/xla/service/hlo_instruction.h
index 121a9e55f6..432bb464f3 100644
--- a/tensorflow/compiler/xla/service/hlo_instruction.h
+++ b/tensorflow/compiler/xla/service/hlo_instruction.h
@@ -1037,6 +1037,8 @@ class HloInstruction {
     CHECK(has_sharding());
     return *sharding_;
   }
+  std::shared_ptr<const HloSharding> sharding_ptr() const { return sharding_; }
+
   // Returns the sharding applied to this operator, or default_ if none exists.
   const HloSharding& sharding_or_default(const HloSharding& default_) const {
     return sharding_ ? *sharding_ : default_;
@@ -1051,7 +1053,10 @@
   // Sets the sharding of this operator. Should only be called by HloModule or
   // HloComputation methods.
   void set_sharding(const HloSharding& sharding) {
-    sharding_ = absl::make_unique<HloSharding>(sharding);
+    sharding_ = std::make_shared<const HloSharding>(sharding);
+  }
+  void set_sharding(std::shared_ptr<const HloSharding> sharding) {
+    sharding_ = std::move(sharding);
   }
   void set_single_sharding(const HloSharding& sharding);
   // Sets a sharding that assigns the current instruction to device.
@@ -1652,7 +1657,10 @@
   bool copy_elision_allowed_ = true;
   // The sharding, if one exists.
-  std::unique_ptr<HloSharding> sharding_;
+  // Uses std::shared_ptr to allow reuse of the same sharding object between
+  // HloInstructions and other components as HloSharding can be very large for
+  // many element tuples.
+  std::shared_ptr<const HloSharding> sharding_;
   // Fields used by the kDomain instruction.
   std::unique_ptr<DomainMetadata> operand_side_metadata_;
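[Editor's note] The two set_sharding overloads in the hunks above differ in ownership semantics: the const-reference overload still allocates a fresh HloSharding, while the new shared_ptr overload lets callers propagate one instance across many instructions (e.g. via sharding_ptr()). A self-contained sketch of that difference, using a stand-in MiniInstruction class rather than the real HloInstruction:

// Stand-in types mirroring only the two overloads shown in the diff above.
#include <cassert>
#include <memory>
#include <utility>

struct HloSharding {
  int devices = 0;  // Stands in for a potentially large tuple sharding.
};

class MiniInstruction {
 public:
  // Copying overload: each call allocates a fresh HloSharding.
  void set_sharding(const HloSharding& sharding) {
    sharding_ = std::make_shared<const HloSharding>(sharding);
  }
  // Sharing overload: takes ownership of an existing shared instance.
  void set_sharding(std::shared_ptr<const HloSharding> sharding) {
    sharding_ = std::move(sharding);
  }
  std::shared_ptr<const HloSharding> sharding_ptr() const { return sharding_; }

 private:
  std::shared_ptr<const HloSharding> sharding_;
};

int main() {
  auto shared = std::make_shared<const HloSharding>(HloSharding{8});
  MiniInstruction a, b, c;
  a.set_sharding(shared);            // aliases `shared`, no copy
  b.set_sharding(a.sharding_ptr());  // propagates the same instance
  c.set_sharding(*shared);           // const-ref overload: distinct copy
  assert(a.sharding_ptr() == b.sharding_ptr());
  assert(a.sharding_ptr() != c.sharding_ptr());
  return 0;
}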