diff options
author | 2018-03-23 05:43:13 -0700 | |
---|---|---|
committer | 2018-03-23 05:45:38 -0700 | |
commit | c1eef98b8fa5fa531c6ddd08a0e5f0f00f31b431 (patch) | |
tree | 4ae29f646375f4cde08e902a8a898776fb1e8b94 | |
parent | 77ac1fb6fb6b4be8152d4d9972cd2ea3968001e8 (diff) |
Fix HloSharding::ToString to be compatible with the HLO text parser
The incompatibility made it impossible to parse HLO modules with
sharding attributes printed from a previous run of XLA. The previous
format generated by this function was ambigious so changing the HLO
parser only wouldn't be feasible.
PiperOrigin-RevId: 190208069
-rw-r--r-- | tensorflow/compiler/xla/service/hlo_sharding.cc | 6 | ||||
-rw-r--r-- | tensorflow/compiler/xla/service/hlo_sharding_test.cc | 32 |
2 files changed, 35 insertions, 3 deletions
diff --git a/tensorflow/compiler/xla/service/hlo_sharding.cc b/tensorflow/compiler/xla/service/hlo_sharding.cc index 27b4626094..e8e45f1ee9 100644 --- a/tensorflow/compiler/xla/service/hlo_sharding.cc +++ b/tensorflow/compiler/xla/service/hlo_sharding.cc @@ -20,6 +20,7 @@ limitations under the License. namespace xla { +using ::tensorflow::str_util::Join; using ::tensorflow::strings::StrCat; HloSharding HloSharding::AssignDevice(int64 device_id) { @@ -57,8 +58,9 @@ string HloSharding::ToString() const { return StrCat( "{maximal device=", static_cast<int64>(*tile_assignment_.begin()), "}"); } else { - return StrCat("{", ShapeUtil::HumanString(tile_shape_), " ", - "devices=", VectorString(tile_assignment_), "}"); + return StrCat("{", ShapeUtil::HumanString(tile_shape_), " ", "devices=[", + Join(tile_assignment_.dimensions(), ","), "]", + Join(tile_assignment_, ","), "}"); } } diff --git a/tensorflow/compiler/xla/service/hlo_sharding_test.cc b/tensorflow/compiler/xla/service/hlo_sharding_test.cc index 9887096eb5..69ea4233e4 100644 --- a/tensorflow/compiler/xla/service/hlo_sharding_test.cc +++ b/tensorflow/compiler/xla/service/hlo_sharding_test.cc @@ -282,13 +282,43 @@ TEST_F(HloShardingTest, TransformShardedTileShapeTest) { EXPECT_EQ(result, expected); } +TEST_F(HloShardingTest, ToStringReplicatedTest) { + HloSharding sharding = HloSharding::Replicate(); + EXPECT_EQ(sharding.ToString(), "{replicated}"); +} + +TEST_F(HloShardingTest, ToStringAssignDeviceTest) { + HloSharding sharding = HloSharding::AssignDevice(7); + EXPECT_EQ(sharding.ToString(), "{maximal device=7}"); +} + +TEST_F(HloShardingTest, ToStringTiledTest) { + HloSharding sharding = + HloSharding::Tile(ShapeUtil::MakeShape(S32, {7, 11, 13}), + Array3D<int64>({{{2, 3}}, {{5, 7}}})); + EXPECT_EQ(sharding.ToString(), "{s32[7,11,13] devices=[2,1,2]2,3,5,7}"); +} + +TEST_F(HloShardingTest, ToStringTupleTest) { + HloSharding sharding = HloSharding::Tuple( + ShapeUtil::MakeTupleShape({ShapeUtil::MakeShape(F32, {3, 5}), + ShapeUtil::MakeShape(U32, {7, 25}), + ShapeUtil::MakeShape(S32, {9, 11})}), + {HloSharding::Replicate(), + HloSharding::Tile(ShapeUtil::MakeShape(U32, {7, 13}), + Array2D<int64>({{3, 5}})), + HloSharding::AssignDevice(3)}); + EXPECT_EQ(sharding.ToString(), + "{{replicated}, {u32[7,13] devices=[1,2]3,5}, {maximal device=3}}"); +} + TEST_F(HloShardingTest, OstreamTest) { HloSharding sharding = HloSharding::Tile(ShapeUtil::MakeShape(F32, {3, 5, 7, 11}), Array4D<int64>({{{{0, 1}, {2, 3}}}})); std::ostringstream oss; oss << sharding; - EXPECT_EQ(oss.str(), "{f32[3,5,7,11] devices=(0, 1, 2, 3)}"); + EXPECT_EQ(oss.str(), "{f32[3,5,7,11] devices=[1,1,2,2]0,1,2,3}"); } } // namespace |