diff options
author | 2018-06-28 02:32:37 -0700 | |
---|---|---|
committer | 2018-06-28 02:35:00 -0700 | |
commit | 06aac645329c45bb2c6d0fb1539816c3c0fb98e4 (patch) | |
tree | 4e0f130e80bf7d1f97a70a6043f0ded0d649143d | |
parent | 99dc8c88c465490975eb6933383a7195a5cae9a9 (diff) |
Fixed ShardingMetadata dump of null sharding from None to {}, to make it
compatible with hlo string syntax.
PiperOrigin-RevId: 202445509
-rw-r--r-- | tensorflow/compiler/xla/service/hlo_domain_test.cc | 21 | ||||
-rw-r--r-- | tensorflow/compiler/xla/service/hlo_sharding_metadata.cc | 2 |
2 files changed, 22 insertions, 1 deletions
diff --git a/tensorflow/compiler/xla/service/hlo_domain_test.cc b/tensorflow/compiler/xla/service/hlo_domain_test.cc index ff356bdd6d..abc5b1c8ef 100644 --- a/tensorflow/compiler/xla/service/hlo_domain_test.cc +++ b/tensorflow/compiler/xla/service/hlo_domain_test.cc @@ -430,5 +430,26 @@ ENTRY entry { HloSharding::AssignDevice(0)})); } +// Tests that text dumps of domain instructions can be parsed back, in the +// specific case of null shardings. +TEST_F(HloDomainTest, DumpParseNullSharding) { + auto builder = HloComputation::Builder(TestName()); + Shape shape = ShapeUtil::MakeShape(F32, {}); + auto sharding_md_0 = MakeUnique<ShardingMetadata>(nullptr); + auto sharding_md_1 = MakeUnique<ShardingMetadata>(nullptr); + HloInstruction* param = + builder.AddInstruction(HloInstruction::CreateParameter(0, shape, "p")); + HloInstruction* domain = builder.AddInstruction(HloInstruction::CreateDomain( + shape, param, std::move(sharding_md_0), std::move(sharding_md_1))); + builder.AddInstruction( + HloInstruction::CreateBinary(shape, HloOpcode::kAdd, domain, domain)); + + auto module = CreateNewModule(); + module->AddEntryComputation(builder.Build()); + + auto hlo_string = module->ToString(); + ASSERT_TRUE(ParseModule(hlo_string).status().ok()); +} + } // namespace } // namespace xla diff --git a/tensorflow/compiler/xla/service/hlo_sharding_metadata.cc b/tensorflow/compiler/xla/service/hlo_sharding_metadata.cc index 748273a43c..39036e205e 100644 --- a/tensorflow/compiler/xla/service/hlo_sharding_metadata.cc +++ b/tensorflow/compiler/xla/service/hlo_sharding_metadata.cc @@ -377,7 +377,7 @@ bool ShardingMetadata::Matches(const DomainMetadata& other) const { } string ShardingMetadata::ToString() const { - return sharding_ != nullptr ? sharding_->ToString() : "None"; + return sharding_ != nullptr ? sharding_->ToString() : "{}"; } Status ShardingMetadata::NormalizeInstructions( |