diff options
author | A. Unique TensorFlower <gardener@tensorflow.org> | 2018-08-23 06:09:54 -0700 |
---|---|---|
committer | TensorFlower Gardener <gardener@tensorflow.org> | 2018-08-23 06:16:01 -0700 |
commit | 9289302ad3d7941ddb9ce2d0dff56b333cbcf208 (patch) | |
tree | 8b90a81e214344e9926ee96564fc4f4bc6de530e /tensorflow/compiler/xla/service/hlo_domain_isolator.cc | |
parent | d2c36c1b3e84e4e29c2853aa421ada45ff5fd396 (diff) |
Reduce the memory usage of sharding domains
Previously the domain instructions inserted before and after an `n`
element tuple required `O(n^2)` memory (and compute) because every
operand and user had its own domain instruction with a tuple sharding
and tuple shape for the exit domains what constructed `n` HloSharding
and `n` Shape proto per domain.
After this change we keep track of the domain instructions inserted and
if we already have a domain instruction with the correct operand and
metadata then we re-use it instead of creating a new one.
Additionally we change HloInstruction and ShardingMetadata to store a
std::shared_ptr to HloSharding so the same instance can be shared by
many instructions. This CL doesn't update all uses to remove all of the
duplicated HloShardings but handles the most wastful cases to reduce
memory usage.
PiperOrigin-RevId: 209924260
Diffstat (limited to 'tensorflow/compiler/xla/service/hlo_domain_isolator.cc')
-rw-r--r-- | tensorflow/compiler/xla/service/hlo_domain_isolator.cc | 26 |
1 files changed, 3 insertions, 23 deletions
diff --git a/tensorflow/compiler/xla/service/hlo_domain_isolator.cc b/tensorflow/compiler/xla/service/hlo_domain_isolator.cc index af904647f8..72185698c9 100644 --- a/tensorflow/compiler/xla/service/hlo_domain_isolator.cc +++ b/tensorflow/compiler/xla/service/hlo_domain_isolator.cc @@ -31,31 +31,10 @@ class HloDomainIsolator::RunContext { StatusOr<bool> Run(); private: - // Inserts a kDomain instruction between operand and instruction in case - // the attribute (ie, sharding) values change between root and instruction. - // Returns the newly inserted kDomain instruction, or nullptr if no kDomain - // instruction was necessary. - StatusOr<HloInstruction*> CreateDomain(HloInstruction* instruction, - HloInstruction* root, - HloInstruction* operand); - HloModule* module_; HloDomainIsolator* isolator_; }; -StatusOr<HloInstruction*> HloDomainIsolator::RunContext::CreateDomain( - HloInstruction* instruction, HloInstruction* root, - HloInstruction* operand) { - HloInstruction* domain = nullptr; - std::unique_ptr<HloInstruction> domain_instruction = - isolator_->creator_(instruction, root, operand); - if (domain_instruction != nullptr) { - domain = operand->parent()->AddInstruction(std::move(domain_instruction)); - TF_RETURN_IF_ERROR(operand->ReplaceUseWith(instruction, domain)); - } - return domain; -} - StatusOr<bool> HloDomainIsolator::RunContext::Run() { hlo_graph_dumper::MaybeDumpHloModule(*module_, "Before Domain Isolator"); @@ -76,10 +55,11 @@ StatusOr<bool> HloDomainIsolator::RunContext::Run() { root = root->mutable_operand(0); } // Check whether a kDomain is necessary between instruction and operand. - TF_ASSIGN_OR_RETURN(HloInstruction * domain, - CreateDomain(instruction, root, operand)); + HloInstruction* domain = + isolator_->creator_(instruction, root, operand); if (domain != nullptr) { VLOG(4) << "New domain: " << domain->ToString(); + TF_RETURN_IF_ERROR(operand->ReplaceUseWith(instruction, domain)); ++added_domains; } } |