aboutsummaryrefslogtreecommitdiffhomepage
path: root/tensorflow/compiler/xla/service/hlo_domain_isolator.cc
diff options
context:
space:
mode:
authorGravatar A. Unique TensorFlower <gardener@tensorflow.org>2018-08-23 06:09:54 -0700
committerGravatar TensorFlower Gardener <gardener@tensorflow.org>2018-08-23 06:16:01 -0700
commit9289302ad3d7941ddb9ce2d0dff56b333cbcf208 (patch)
tree8b90a81e214344e9926ee96564fc4f4bc6de530e /tensorflow/compiler/xla/service/hlo_domain_isolator.cc
parentd2c36c1b3e84e4e29c2853aa421ada45ff5fd396 (diff)
Reduce the memory usage of sharding domains
Previously the domain instructions inserted before and after an `n` element tuple required `O(n^2)` memory (and compute) because every operand and user had its own domain instruction with a tuple sharding and tuple shape for the exit domains what constructed `n` HloSharding and `n` Shape proto per domain. After this change we keep track of the domain instructions inserted and if we already have a domain instruction with the correct operand and metadata then we re-use it instead of creating a new one. Additionally we change HloInstruction and ShardingMetadata to store a std::shared_ptr to HloSharding so the same instance can be shared by many instructions. This CL doesn't update all uses to remove all of the duplicated HloShardings but handles the most wastful cases to reduce memory usage. PiperOrigin-RevId: 209924260
Diffstat (limited to 'tensorflow/compiler/xla/service/hlo_domain_isolator.cc')
-rw-r--r--tensorflow/compiler/xla/service/hlo_domain_isolator.cc26
1 files changed, 3 insertions, 23 deletions
diff --git a/tensorflow/compiler/xla/service/hlo_domain_isolator.cc b/tensorflow/compiler/xla/service/hlo_domain_isolator.cc
index af904647f8..72185698c9 100644
--- a/tensorflow/compiler/xla/service/hlo_domain_isolator.cc
+++ b/tensorflow/compiler/xla/service/hlo_domain_isolator.cc
@@ -31,31 +31,10 @@ class HloDomainIsolator::RunContext {
StatusOr<bool> Run();
private:
- // Inserts a kDomain instruction between operand and instruction in case
- // the attribute (ie, sharding) values change between root and instruction.
- // Returns the newly inserted kDomain instruction, or nullptr if no kDomain
- // instruction was necessary.
- StatusOr<HloInstruction*> CreateDomain(HloInstruction* instruction,
- HloInstruction* root,
- HloInstruction* operand);
-
HloModule* module_;
HloDomainIsolator* isolator_;
};
-StatusOr<HloInstruction*> HloDomainIsolator::RunContext::CreateDomain(
- HloInstruction* instruction, HloInstruction* root,
- HloInstruction* operand) {
- HloInstruction* domain = nullptr;
- std::unique_ptr<HloInstruction> domain_instruction =
- isolator_->creator_(instruction, root, operand);
- if (domain_instruction != nullptr) {
- domain = operand->parent()->AddInstruction(std::move(domain_instruction));
- TF_RETURN_IF_ERROR(operand->ReplaceUseWith(instruction, domain));
- }
- return domain;
-}
-
StatusOr<bool> HloDomainIsolator::RunContext::Run() {
hlo_graph_dumper::MaybeDumpHloModule(*module_, "Before Domain Isolator");
@@ -76,10 +55,11 @@ StatusOr<bool> HloDomainIsolator::RunContext::Run() {
root = root->mutable_operand(0);
}
// Check whether a kDomain is necessary between instruction and operand.
- TF_ASSIGN_OR_RETURN(HloInstruction * domain,
- CreateDomain(instruction, root, operand));
+ HloInstruction* domain =
+ isolator_->creator_(instruction, root, operand);
if (domain != nullptr) {
VLOG(4) << "New domain: " << domain->ToString();
+ TF_RETURN_IF_ERROR(operand->ReplaceUseWith(instruction, domain));
++added_domains;
}
}