diff options
author | A. Unique TensorFlower <gardener@tensorflow.org> | 2018-08-22 09:22:30 -0700 |
---|---|---|
committer | TensorFlower Gardener <gardener@tensorflow.org> | 2018-08-22 09:26:39 -0700 |
commit | 567189980f7a1c2aa09a5170bd8d01a6ec37d303 (patch) | |
tree | 6990c7bf6f36a6ca82343b4d5fbaae6703e866c5 | |
parent | cffdccdc642c1fe852e61cb3236aa00ee53c92bf (diff) |
Fix domain isolation for the case when multiple domain type is involved
Previously when we were stacking domains we inserted the new domain
instructions between the upper most domain and its operand. This caused
issues if that domain had more then one user with different atribute for
the domain inserted at the second pass because we could have ended up
with edges between different domains.
After this change we insert the new domains between the lower most
domain and its user ensuring that the domain separates every instruction
with different attributes.
PiperOrigin-RevId: 209776741
5 files changed, 89 insertions, 26 deletions
diff --git a/tensorflow/compiler/xla/service/hlo_domain_isolator.cc b/tensorflow/compiler/xla/service/hlo_domain_isolator.cc index 78955db0da..af904647f8 100644 --- a/tensorflow/compiler/xla/service/hlo_domain_isolator.cc +++ b/tensorflow/compiler/xla/service/hlo_domain_isolator.cc @@ -31,12 +31,12 @@ class HloDomainIsolator::RunContext { StatusOr<bool> Run(); private: - // Inserts a kDomain instruction between parent and operand, in case - // the attribute (ie, sharding) values change between instruction and operand. + // Inserts a kDomain instruction between operand and instruction in case + // the attribute (ie, sharding) values change between root and instruction. // Returns the newly inserted kDomain instruction, or nullptr if no kDomain // instruction was necessary. StatusOr<HloInstruction*> CreateDomain(HloInstruction* instruction, - HloInstruction* parent, + HloInstruction* root, HloInstruction* operand); HloModule* module_; @@ -44,14 +44,14 @@ class HloDomainIsolator::RunContext { }; StatusOr<HloInstruction*> HloDomainIsolator::RunContext::CreateDomain( - HloInstruction* instruction, HloInstruction* parent, + HloInstruction* instruction, HloInstruction* root, HloInstruction* operand) { HloInstruction* domain = nullptr; std::unique_ptr<HloInstruction> domain_instruction = - isolator_->creator_(instruction, operand); + isolator_->creator_(instruction, root, operand); if (domain_instruction != nullptr) { domain = operand->parent()->AddInstruction(std::move(domain_instruction)); - TF_RETURN_IF_ERROR(operand->ReplaceUseWith(parent, domain)); + TF_RETURN_IF_ERROR(operand->ReplaceUseWith(instruction, domain)); } return domain; } @@ -71,14 +71,13 @@ StatusOr<bool> HloDomainIsolator::RunContext::Run() { // When applying multiple domains, we could end up stacking more than // one in one edge, so here we want to build the effective // (kDomain-less) instruction->operand edge. - HloInstruction* parent = instruction; - while (operand->opcode() == HloOpcode::kDomain) { - parent = operand; - operand = operand->mutable_operand(0); + HloInstruction* root = operand; + while (root->opcode() == HloOpcode::kDomain) { + root = root->mutable_operand(0); } // Check whether a kDomain is necessary between instruction and operand. TF_ASSIGN_OR_RETURN(HloInstruction * domain, - CreateDomain(instruction, parent, operand)); + CreateDomain(instruction, root, operand)); if (domain != nullptr) { VLOG(4) << "New domain: " << domain->ToString(); ++added_domains; diff --git a/tensorflow/compiler/xla/service/hlo_domain_isolator.h b/tensorflow/compiler/xla/service/hlo_domain_isolator.h index eded3e78ee..bb1537766c 100644 --- a/tensorflow/compiler/xla/service/hlo_domain_isolator.h +++ b/tensorflow/compiler/xla/service/hlo_domain_isolator.h @@ -34,10 +34,12 @@ class HloDomainIsolator : public HloPassInterface { public: // Creates a new kDomain instruction for the edge between the use instruction // (the first HloInstruction argument), and the operand instruction (the - // second HloInstruction argument). + // third HloInstruction argument) if the interesting attribute of the + // instruction differes from the attribute of the root (the second + // HloInstruction argument). // Returns nullptr in case no domain separation is necessary. using DomainCreator = std::function<std::unique_ptr<HloInstruction>( - HloInstruction*, HloInstruction*)>; + HloInstruction*, HloInstruction*, HloInstruction*)>; explicit HloDomainIsolator(DomainCreator creator); diff --git a/tensorflow/compiler/xla/service/hlo_domain_test.cc b/tensorflow/compiler/xla/service/hlo_domain_test.cc index 7d48be15cf..2654929bf0 100644 --- a/tensorflow/compiler/xla/service/hlo_domain_test.cc +++ b/tensorflow/compiler/xla/service/hlo_domain_test.cc @@ -106,12 +106,13 @@ class OpNameMetadata : public DomainMetadata { // Creator function for OpNameMetadata domains. std::unique_ptr<HloInstruction> OpNameDomainCreator(HloInstruction* instruction, + HloInstruction* root, HloInstruction* operand) { - if (instruction->metadata().op_name() == operand->metadata().op_name()) { + if (instruction->metadata().op_name() == root->metadata().op_name()) { return nullptr; } std::unique_ptr<DomainMetadata> operand_side_metadata = - absl::make_unique<OpNameMetadata>(operand->metadata().op_name()); + absl::make_unique<OpNameMetadata>(root->metadata().op_name()); std::unique_ptr<DomainMetadata> user_side_metadata = absl::make_unique<OpNameMetadata>(instruction->metadata().op_name()); return HloInstruction::CreateDomain(operand->shape(), operand, @@ -524,5 +525,64 @@ ENTRY entry { tpl->sharding()); } +TEST_F(HloDomainTest, MultiDomainMultiUser) { + const char* const hlo_string = R"( + HloModule Module + +ENTRY %entry (p0: (f32[4], f32[4])) -> (f32[4], f32[4], f32[4]) { + %p0 = (f32[4], f32[4]) parameter(0) + %a = f32[4]{0} get-tuple-element(%p0), index=0 + %domain = f32[4] domain(%a), + domain={kind="sharding", entry={maximal device=1}, exit={maximal device=0}} + %b = f32[4] get-tuple-element(%p0), index=1 + %domain.1 = f32[4] domain(%b), + domain={kind="sharding", entry={maximal device=1}, exit={maximal device=0}} + %c = f32[4] add(%domain, %domain.1), sharding={maximal device=1} + %domain.2 = f32[4] domain(%c), + domain={kind="sharding", entry={maximal device=0}, exit={maximal device=1}} + %d = f32[4] subtract(%domain, %c), + sharding={maximal device=1}, metadata={op_name="D"} + %domain.3 = f32[4] domain(%d), + domain={kind="sharding", entry={maximal device=0}, exit={maximal device=1}} + %e = f32[4] multiply(%c, %d), + sharding={maximal device=1}, metadata={op_name="D"} + %f = f32[4] add(f32[4]{0} %e, f32[4]{0} %c), sharding={maximal device=1} + %domain.4 = f32[4]{0} domain(%f), + domain={kind="sharding", entry={maximal device=0}, exit={maximal device=1}} + ROOT %g = (f32[4], f32[4], f32[4]) tuple(%domain.2, %domain.3, %domain.4) +})"; + + TF_ASSERT_OK_AND_ASSIGN(HloModule * module, ParseModule(hlo_string)); + LOG(INFO) << "Original module:\n" << module->ToString(); + + HloDomainIsolator opname_isolator(OpNameDomainCreator); + TF_ASSERT_OK_AND_ASSIGN(bool opname_isolator_changed, + opname_isolator.Run(module)); + EXPECT_TRUE(opname_isolator_changed); + + EXPECT_TRUE(HasDomainEdge(module, "c", "a")); + EXPECT_TRUE(HasDomainEdge(module, "c", "b")); + EXPECT_TRUE(HasDomainEdge(module, "d", "a")); + EXPECT_TRUE(HasDomainEdge(module, "d", "c")); + EXPECT_FALSE(HasDomainEdge(module, "e", "d")); + + HloDomainRemover sharding_remover(ShardingMetadata::KindName(), + ShardingMetadata::NormalizeShardingDomain); + TF_ASSERT_OK_AND_ASSIGN(bool sharding_remover_changed, + sharding_remover.Run(module)); + EXPECT_TRUE(sharding_remover_changed); + + HloDomainRemover opname_remover(OpNameMetadata::KindName(), + OpNameDomainNormalizer); + TF_ASSERT_OK_AND_ASSIGN(bool opname_remover_changed, + opname_remover.Run(module)); + EXPECT_TRUE(opname_remover_changed); + + EXPECT_FALSE(HasDomainEdge(module, "c", "a")); + EXPECT_FALSE(HasDomainEdge(module, "c", "b")); + EXPECT_FALSE(HasDomainEdge(module, "d", "a")); + EXPECT_FALSE(HasDomainEdge(module, "d", "c")); +} + } // namespace } // namespace xla diff --git a/tensorflow/compiler/xla/service/hlo_sharding_metadata.cc b/tensorflow/compiler/xla/service/hlo_sharding_metadata.cc index 4e19557f82..6f0353ee5f 100644 --- a/tensorflow/compiler/xla/service/hlo_sharding_metadata.cc +++ b/tensorflow/compiler/xla/service/hlo_sharding_metadata.cc @@ -284,18 +284,19 @@ Status ApplyDomainSharding(const DomainMetadata::Domain& domain, // The kDomain instruction will be created only if the sharding differ between // the instruction and the operand. std::unique_ptr<HloInstruction> CreateDomain(HloInstruction* instruction, + HloInstruction* root, HloInstruction* operand) { const HloSharding* instruction_sharding = instruction->has_sharding() ? &instruction->sharding() : nullptr; - const HloSharding* operand_sharding = - operand->has_sharding() ? &operand->sharding() : nullptr; + const HloSharding* root_sharding = + root->has_sharding() ? &root->sharding() : nullptr; // No need for domain if they both have no sharding. - if (instruction_sharding == nullptr && operand_sharding == nullptr) { + if (instruction_sharding == nullptr && root_sharding == nullptr) { return nullptr; } // No need for domain if they match. - if (instruction_sharding != nullptr && operand_sharding != nullptr && - ShardingMatches(*instruction_sharding, *operand_sharding)) { + if (instruction_sharding != nullptr && root_sharding != nullptr && + ShardingMatches(*instruction_sharding, *root_sharding)) { return nullptr; } std::unique_ptr<HloSharding> real_instruction_sharding; @@ -303,8 +304,8 @@ std::unique_ptr<HloInstruction> CreateDomain(HloInstruction* instruction, if (instruction_sharding != nullptr) { real_instruction_sharding = CloneShardingForDomain(*instruction_sharding); } - if (operand_sharding != nullptr) { - real_operand_sharding = CloneShardingForDomain(*operand_sharding); + if (root_sharding != nullptr) { + real_operand_sharding = CloneShardingForDomain(*root_sharding); } VLOG(3) << "Creating domain:"; VLOG(3) << " Instruction: " << instruction->name(); @@ -417,8 +418,9 @@ Status ShardingMetadata::NormalizeShardingDomain( } std::unique_ptr<HloInstruction> CreateShardingDomain( - HloInstruction* instruction, HloInstruction* operand) { - return CreateDomain(instruction, operand); + HloInstruction* instruction, HloInstruction* root, + HloInstruction* operand) { + return CreateDomain(instruction, root, operand); } } // namespace xla diff --git a/tensorflow/compiler/xla/service/hlo_sharding_metadata.h b/tensorflow/compiler/xla/service/hlo_sharding_metadata.h index 5e01fc0e22..dc258e4094 100644 --- a/tensorflow/compiler/xla/service/hlo_sharding_metadata.h +++ b/tensorflow/compiler/xla/service/hlo_sharding_metadata.h @@ -60,10 +60,10 @@ class ShardingMetadata : public DomainMetadata { // Given an HLO graph edge between instruction and one of its operands, creates // a ShardingMetadata based kDomain instruction if the sharding between -// instruction and operand changes. Returns nullptr if there is no need for a +// instruction and parent changes. Returns nullptr if there is no need for a // domain separation. std::unique_ptr<HloInstruction> CreateShardingDomain( - HloInstruction* instruction, HloInstruction* operand); + HloInstruction* instruction, HloInstruction* root, HloInstruction* operand); } // namespace xla |