aboutsummaryrefslogtreecommitdiffhomepage
diff options
context:
space:
mode:
authorGravatar A. Unique TensorFlower <gardener@tensorflow.org>2018-08-22 09:22:30 -0700
committerGravatar TensorFlower Gardener <gardener@tensorflow.org>2018-08-22 09:26:39 -0700
commit567189980f7a1c2aa09a5170bd8d01a6ec37d303 (patch)
tree6990c7bf6f36a6ca82343b4d5fbaae6703e866c5
parentcffdccdc642c1fe852e61cb3236aa00ee53c92bf (diff)
Fix domain isolation for the case when multiple domain type is involved
Previously when we were stacking domains we inserted the new domain instructions between the upper most domain and its operand. This caused issues if that domain had more than one user with a different attribute for the domain inserted at the second pass, because we could have ended up with edges between different domains. After this change we insert the new domains between the lower most domain and its user, ensuring that the domain separates every instruction with different attributes. PiperOrigin-RevId: 209776741
-rw-r--r--tensorflow/compiler/xla/service/hlo_domain_isolator.cc21
-rw-r--r--tensorflow/compiler/xla/service/hlo_domain_isolator.h6
-rw-r--r--tensorflow/compiler/xla/service/hlo_domain_test.cc64
-rw-r--r--tensorflow/compiler/xla/service/hlo_sharding_metadata.cc20
-rw-r--r--tensorflow/compiler/xla/service/hlo_sharding_metadata.h4
5 files changed, 89 insertions, 26 deletions
diff --git a/tensorflow/compiler/xla/service/hlo_domain_isolator.cc b/tensorflow/compiler/xla/service/hlo_domain_isolator.cc
index 78955db0da..af904647f8 100644
--- a/tensorflow/compiler/xla/service/hlo_domain_isolator.cc
+++ b/tensorflow/compiler/xla/service/hlo_domain_isolator.cc
@@ -31,12 +31,12 @@ class HloDomainIsolator::RunContext {
StatusOr<bool> Run();
private:
- // Inserts a kDomain instruction between parent and operand, in case
- // the attribute (ie, sharding) values change between instruction and operand.
+ // Inserts a kDomain instruction between operand and instruction in case
+ // the attribute (ie, sharding) values change between root and instruction.
// Returns the newly inserted kDomain instruction, or nullptr if no kDomain
// instruction was necessary.
StatusOr<HloInstruction*> CreateDomain(HloInstruction* instruction,
- HloInstruction* parent,
+ HloInstruction* root,
HloInstruction* operand);
HloModule* module_;
@@ -44,14 +44,14 @@ class HloDomainIsolator::RunContext {
};
StatusOr<HloInstruction*> HloDomainIsolator::RunContext::CreateDomain(
- HloInstruction* instruction, HloInstruction* parent,
+ HloInstruction* instruction, HloInstruction* root,
HloInstruction* operand) {
HloInstruction* domain = nullptr;
std::unique_ptr<HloInstruction> domain_instruction =
- isolator_->creator_(instruction, operand);
+ isolator_->creator_(instruction, root, operand);
if (domain_instruction != nullptr) {
domain = operand->parent()->AddInstruction(std::move(domain_instruction));
- TF_RETURN_IF_ERROR(operand->ReplaceUseWith(parent, domain));
+ TF_RETURN_IF_ERROR(operand->ReplaceUseWith(instruction, domain));
}
return domain;
}
@@ -71,14 +71,13 @@ StatusOr<bool> HloDomainIsolator::RunContext::Run() {
// When applying multiple domains, we could end up stacking more than
// one in one edge, so here we want to build the effective
// (kDomain-less) instruction->operand edge.
- HloInstruction* parent = instruction;
- while (operand->opcode() == HloOpcode::kDomain) {
- parent = operand;
- operand = operand->mutable_operand(0);
+ HloInstruction* root = operand;
+ while (root->opcode() == HloOpcode::kDomain) {
+ root = root->mutable_operand(0);
}
// Check whether a kDomain is necessary between instruction and operand.
TF_ASSIGN_OR_RETURN(HloInstruction * domain,
- CreateDomain(instruction, parent, operand));
+ CreateDomain(instruction, root, operand));
if (domain != nullptr) {
VLOG(4) << "New domain: " << domain->ToString();
++added_domains;
diff --git a/tensorflow/compiler/xla/service/hlo_domain_isolator.h b/tensorflow/compiler/xla/service/hlo_domain_isolator.h
index eded3e78ee..bb1537766c 100644
--- a/tensorflow/compiler/xla/service/hlo_domain_isolator.h
+++ b/tensorflow/compiler/xla/service/hlo_domain_isolator.h
@@ -34,10 +34,12 @@ class HloDomainIsolator : public HloPassInterface {
public:
// Creates a new kDomain instruction for the edge between the use instruction
// (the first HloInstruction argument), and the operand instruction (the
- // second HloInstruction argument).
+ // third HloInstruction argument) if the interesting attribute of the
+ // instruction differs from the attribute of the root (the second
+ // HloInstruction argument).
// Returns nullptr in case no domain separation is necessary.
using DomainCreator = std::function<std::unique_ptr<HloInstruction>(
- HloInstruction*, HloInstruction*)>;
+ HloInstruction*, HloInstruction*, HloInstruction*)>;
explicit HloDomainIsolator(DomainCreator creator);
diff --git a/tensorflow/compiler/xla/service/hlo_domain_test.cc b/tensorflow/compiler/xla/service/hlo_domain_test.cc
index 7d48be15cf..2654929bf0 100644
--- a/tensorflow/compiler/xla/service/hlo_domain_test.cc
+++ b/tensorflow/compiler/xla/service/hlo_domain_test.cc
@@ -106,12 +106,13 @@ class OpNameMetadata : public DomainMetadata {
// Creator function for OpNameMetadata domains.
std::unique_ptr<HloInstruction> OpNameDomainCreator(HloInstruction* instruction,
+ HloInstruction* root,
HloInstruction* operand) {
- if (instruction->metadata().op_name() == operand->metadata().op_name()) {
+ if (instruction->metadata().op_name() == root->metadata().op_name()) {
return nullptr;
}
std::unique_ptr<DomainMetadata> operand_side_metadata =
- absl::make_unique<OpNameMetadata>(operand->metadata().op_name());
+ absl::make_unique<OpNameMetadata>(root->metadata().op_name());
std::unique_ptr<DomainMetadata> user_side_metadata =
absl::make_unique<OpNameMetadata>(instruction->metadata().op_name());
return HloInstruction::CreateDomain(operand->shape(), operand,
@@ -524,5 +525,64 @@ ENTRY entry {
tpl->sharding());
}
+TEST_F(HloDomainTest, MultiDomainMultiUser) {
+ const char* const hlo_string = R"(
+ HloModule Module
+
+ENTRY %entry (p0: (f32[4], f32[4])) -> (f32[4], f32[4], f32[4]) {
+ %p0 = (f32[4], f32[4]) parameter(0)
+ %a = f32[4]{0} get-tuple-element(%p0), index=0
+ %domain = f32[4] domain(%a),
+ domain={kind="sharding", entry={maximal device=1}, exit={maximal device=0}}
+ %b = f32[4] get-tuple-element(%p0), index=1
+ %domain.1 = f32[4] domain(%b),
+ domain={kind="sharding", entry={maximal device=1}, exit={maximal device=0}}
+ %c = f32[4] add(%domain, %domain.1), sharding={maximal device=1}
+ %domain.2 = f32[4] domain(%c),
+ domain={kind="sharding", entry={maximal device=0}, exit={maximal device=1}}
+ %d = f32[4] subtract(%domain, %c),
+ sharding={maximal device=1}, metadata={op_name="D"}
+ %domain.3 = f32[4] domain(%d),
+ domain={kind="sharding", entry={maximal device=0}, exit={maximal device=1}}
+ %e = f32[4] multiply(%c, %d),
+ sharding={maximal device=1}, metadata={op_name="D"}
+ %f = f32[4] add(f32[4]{0} %e, f32[4]{0} %c), sharding={maximal device=1}
+ %domain.4 = f32[4]{0} domain(%f),
+ domain={kind="sharding", entry={maximal device=0}, exit={maximal device=1}}
+ ROOT %g = (f32[4], f32[4], f32[4]) tuple(%domain.2, %domain.3, %domain.4)
+})";
+
+ TF_ASSERT_OK_AND_ASSIGN(HloModule * module, ParseModule(hlo_string));
+ LOG(INFO) << "Original module:\n" << module->ToString();
+
+ HloDomainIsolator opname_isolator(OpNameDomainCreator);
+ TF_ASSERT_OK_AND_ASSIGN(bool opname_isolator_changed,
+ opname_isolator.Run(module));
+ EXPECT_TRUE(opname_isolator_changed);
+
+ EXPECT_TRUE(HasDomainEdge(module, "c", "a"));
+ EXPECT_TRUE(HasDomainEdge(module, "c", "b"));
+ EXPECT_TRUE(HasDomainEdge(module, "d", "a"));
+ EXPECT_TRUE(HasDomainEdge(module, "d", "c"));
+ EXPECT_FALSE(HasDomainEdge(module, "e", "d"));
+
+ HloDomainRemover sharding_remover(ShardingMetadata::KindName(),
+ ShardingMetadata::NormalizeShardingDomain);
+ TF_ASSERT_OK_AND_ASSIGN(bool sharding_remover_changed,
+ sharding_remover.Run(module));
+ EXPECT_TRUE(sharding_remover_changed);
+
+ HloDomainRemover opname_remover(OpNameMetadata::KindName(),
+ OpNameDomainNormalizer);
+ TF_ASSERT_OK_AND_ASSIGN(bool opname_remover_changed,
+ opname_remover.Run(module));
+ EXPECT_TRUE(opname_remover_changed);
+
+ EXPECT_FALSE(HasDomainEdge(module, "c", "a"));
+ EXPECT_FALSE(HasDomainEdge(module, "c", "b"));
+ EXPECT_FALSE(HasDomainEdge(module, "d", "a"));
+ EXPECT_FALSE(HasDomainEdge(module, "d", "c"));
+}
+
} // namespace
} // namespace xla
diff --git a/tensorflow/compiler/xla/service/hlo_sharding_metadata.cc b/tensorflow/compiler/xla/service/hlo_sharding_metadata.cc
index 4e19557f82..6f0353ee5f 100644
--- a/tensorflow/compiler/xla/service/hlo_sharding_metadata.cc
+++ b/tensorflow/compiler/xla/service/hlo_sharding_metadata.cc
@@ -284,18 +284,19 @@ Status ApplyDomainSharding(const DomainMetadata::Domain& domain,
// The kDomain instruction will be created only if the sharding differ between
// the instruction and the operand.
std::unique_ptr<HloInstruction> CreateDomain(HloInstruction* instruction,
+ HloInstruction* root,
HloInstruction* operand) {
const HloSharding* instruction_sharding =
instruction->has_sharding() ? &instruction->sharding() : nullptr;
- const HloSharding* operand_sharding =
- operand->has_sharding() ? &operand->sharding() : nullptr;
+ const HloSharding* root_sharding =
+ root->has_sharding() ? &root->sharding() : nullptr;
// No need for domain if they both have no sharding.
- if (instruction_sharding == nullptr && operand_sharding == nullptr) {
+ if (instruction_sharding == nullptr && root_sharding == nullptr) {
return nullptr;
}
// No need for domain if they match.
- if (instruction_sharding != nullptr && operand_sharding != nullptr &&
- ShardingMatches(*instruction_sharding, *operand_sharding)) {
+ if (instruction_sharding != nullptr && root_sharding != nullptr &&
+ ShardingMatches(*instruction_sharding, *root_sharding)) {
return nullptr;
}
std::unique_ptr<HloSharding> real_instruction_sharding;
@@ -303,8 +304,8 @@ std::unique_ptr<HloInstruction> CreateDomain(HloInstruction* instruction,
if (instruction_sharding != nullptr) {
real_instruction_sharding = CloneShardingForDomain(*instruction_sharding);
}
- if (operand_sharding != nullptr) {
- real_operand_sharding = CloneShardingForDomain(*operand_sharding);
+ if (root_sharding != nullptr) {
+ real_operand_sharding = CloneShardingForDomain(*root_sharding);
}
VLOG(3) << "Creating domain:";
VLOG(3) << " Instruction: " << instruction->name();
@@ -417,8 +418,9 @@ Status ShardingMetadata::NormalizeShardingDomain(
}
std::unique_ptr<HloInstruction> CreateShardingDomain(
- HloInstruction* instruction, HloInstruction* operand) {
- return CreateDomain(instruction, operand);
+ HloInstruction* instruction, HloInstruction* root,
+ HloInstruction* operand) {
+ return CreateDomain(instruction, root, operand);
}
} // namespace xla
diff --git a/tensorflow/compiler/xla/service/hlo_sharding_metadata.h b/tensorflow/compiler/xla/service/hlo_sharding_metadata.h
index 5e01fc0e22..dc258e4094 100644
--- a/tensorflow/compiler/xla/service/hlo_sharding_metadata.h
+++ b/tensorflow/compiler/xla/service/hlo_sharding_metadata.h
@@ -60,10 +60,10 @@ class ShardingMetadata : public DomainMetadata {
// Given an HLO graph edge between instruction and one of its operands, creates
// a ShardingMetadata based kDomain instruction if the sharding between
-// instruction and operand changes. Returns nullptr if there is no need for a
+// instruction and parent changes. Returns nullptr if there is no need for a
// domain separation.
std::unique_ptr<HloInstruction> CreateShardingDomain(
- HloInstruction* instruction, HloInstruction* operand);
+ HloInstruction* instruction, HloInstruction* root, HloInstruction* operand);
} // namespace xla