aboutsummaryrefslogtreecommitdiffhomepage
path: root/tensorflow/compiler/xla/service/hlo_domain_test.cc
diff options
context:
space:
mode:
Diffstat (limited to 'tensorflow/compiler/xla/service/hlo_domain_test.cc')
-rw-r--r--tensorflow/compiler/xla/service/hlo_domain_test.cc21
1 files changed, 8 insertions, 13 deletions
diff --git a/tensorflow/compiler/xla/service/hlo_domain_test.cc b/tensorflow/compiler/xla/service/hlo_domain_test.cc
index 00b2c860a7..ffc18a0f88 100644
--- a/tensorflow/compiler/xla/service/hlo_domain_test.cc
+++ b/tensorflow/compiler/xla/service/hlo_domain_test.cc
@@ -97,12 +97,6 @@ class OpNameMetadata : public DomainMetadata {
string ToString() const override { return opname_; }
- Status NormalizeInstructions(
- const DomainMetadata::Domain& domain) const override {
- // For the purposes of this test, nothing to do.
- return Status::OK();
- }
-
static tensorflow::StringPiece KindName() { return "opname"; }
private:
@@ -124,7 +118,8 @@ std::unique_ptr<HloInstruction> OpNameDomainCreator(HloInstruction* instruction,
std::move(user_side_metadata));
}
-Status OpNameDomainNormalizer(const DomainMetadata::Domain& domain) {
+Status OpNameDomainNormalizer(const DomainMetadata::Domain& domain,
+ const DomainMetadata* metadata) {
// Nothing to do for the particular use this test make of the OpName domains.
return Status::OK();
}
@@ -159,7 +154,7 @@ ENTRY entry {
EXPECT_FALSE(HasDomainEdge(module, "e", "d"));
HloDomainRemover remover(ShardingMetadata::KindName(),
- NormalizeShardingDomain);
+ ShardingMetadata::NormalizeShardingDomain);
TF_ASSERT_OK_AND_ASSIGN(bool remover_changed, remover.Run(module));
EXPECT_TRUE(remover_changed);
@@ -227,7 +222,7 @@ ENTRY entry {
EXPECT_FALSE(HasDomainEdge(module, "e", "d"));
HloDomainRemover remover(ShardingMetadata::KindName(),
- NormalizeShardingDomain);
+ ShardingMetadata::NormalizeShardingDomain);
TF_ASSERT_OK_AND_ASSIGN(bool remover_changed, remover.Run(module));
EXPECT_TRUE(remover_changed);
@@ -277,7 +272,7 @@ ENTRY entry {
LOG(INFO) << "Original module:\n" << module->ToString();
HloDomainRemover remover(ShardingMetadata::KindName(),
- NormalizeShardingDomain);
+ ShardingMetadata::NormalizeShardingDomain);
TF_ASSERT_OK_AND_ASSIGN(bool remover_changed, remover.Run(module));
EXPECT_FALSE(remover_changed);
@@ -324,7 +319,7 @@ ENTRY entry {
EXPECT_FALSE(HasDomainEdge(module, "e", "d"));
HloDomainRemover sharding_remover(ShardingMetadata::KindName(),
- NormalizeShardingDomain);
+ ShardingMetadata::NormalizeShardingDomain);
TF_ASSERT_OK_AND_ASSIGN(bool sharding_remover_changed,
sharding_remover.Run(module));
EXPECT_TRUE(sharding_remover_changed);
@@ -411,7 +406,7 @@ ENTRY entry {
}
HloDomainRemover remover(ShardingMetadata::KindName(),
- NormalizeShardingDomain);
+ ShardingMetadata::NormalizeShardingDomain);
TF_ASSERT_OK_AND_ASSIGN(bool remover_changed, remover.Run(module));
EXPECT_TRUE(remover_changed);
@@ -465,7 +460,7 @@ ENTRY entry {
TF_EXPECT_OK(module->entry_computation()->RemoveInstruction(tuple));
HloDomainRemover remover(ShardingMetadata::KindName(),
- NormalizeShardingDomain);
+ ShardingMetadata::NormalizeShardingDomain);
TF_ASSERT_OK_AND_ASSIGN(bool remover_changed, remover.Run(module));
EXPECT_TRUE(remover_changed);