aboutsummaryrefslogtreecommitdiffhomepage
diff options
context:
space:
mode:
authorGravatar A. Unique TensorFlower <gardener@tensorflow.org>2018-03-23 05:43:13 -0700
committerGravatar TensorFlower Gardener <gardener@tensorflow.org>2018-03-23 05:45:38 -0700
commitc1eef98b8fa5fa531c6ddd08a0e5f0f00f31b431 (patch)
tree4ae29f646375f4cde08e902a8a898776fb1e8b94
parent77ac1fb6fb6b4be8152d4d9972cd2ea3968001e8 (diff)
Fix HloSharding::ToString to be compatible with the HLO text parser
The incompatibility made it impossible to parse HLO modules with sharding attributes printed from a previous run of XLA. The previous format generated by this function was ambigious so changing the HLO parser only wouldn't be feasible. PiperOrigin-RevId: 190208069
-rw-r--r--tensorflow/compiler/xla/service/hlo_sharding.cc6
-rw-r--r--tensorflow/compiler/xla/service/hlo_sharding_test.cc32
2 files changed, 35 insertions, 3 deletions
diff --git a/tensorflow/compiler/xla/service/hlo_sharding.cc b/tensorflow/compiler/xla/service/hlo_sharding.cc
index 27b4626094..e8e45f1ee9 100644
--- a/tensorflow/compiler/xla/service/hlo_sharding.cc
+++ b/tensorflow/compiler/xla/service/hlo_sharding.cc
@@ -20,6 +20,7 @@ limitations under the License.
namespace xla {
+using ::tensorflow::str_util::Join;
using ::tensorflow::strings::StrCat;
HloSharding HloSharding::AssignDevice(int64 device_id) {
@@ -57,8 +58,9 @@ string HloSharding::ToString() const {
return StrCat(
"{maximal device=", static_cast<int64>(*tile_assignment_.begin()), "}");
} else {
- return StrCat("{", ShapeUtil::HumanString(tile_shape_), " ",
- "devices=", VectorString(tile_assignment_), "}");
+ return StrCat("{", ShapeUtil::HumanString(tile_shape_), " ", "devices=[",
+ Join(tile_assignment_.dimensions(), ","), "]",
+ Join(tile_assignment_, ","), "}");
}
}
diff --git a/tensorflow/compiler/xla/service/hlo_sharding_test.cc b/tensorflow/compiler/xla/service/hlo_sharding_test.cc
index 9887096eb5..69ea4233e4 100644
--- a/tensorflow/compiler/xla/service/hlo_sharding_test.cc
+++ b/tensorflow/compiler/xla/service/hlo_sharding_test.cc
@@ -282,13 +282,43 @@ TEST_F(HloShardingTest, TransformShardedTileShapeTest) {
EXPECT_EQ(result, expected);
}
+TEST_F(HloShardingTest, ToStringReplicatedTest) {
+ HloSharding sharding = HloSharding::Replicate();
+ EXPECT_EQ(sharding.ToString(), "{replicated}");
+}
+
+TEST_F(HloShardingTest, ToStringAssignDeviceTest) {
+ HloSharding sharding = HloSharding::AssignDevice(7);
+ EXPECT_EQ(sharding.ToString(), "{maximal device=7}");
+}
+
+TEST_F(HloShardingTest, ToStringTiledTest) {
+ HloSharding sharding =
+ HloSharding::Tile(ShapeUtil::MakeShape(S32, {7, 11, 13}),
+ Array3D<int64>({{{2, 3}}, {{5, 7}}}));
+ EXPECT_EQ(sharding.ToString(), "{s32[7,11,13] devices=[2,1,2]2,3,5,7}");
+}
+
+TEST_F(HloShardingTest, ToStringTupleTest) {
+ HloSharding sharding = HloSharding::Tuple(
+ ShapeUtil::MakeTupleShape({ShapeUtil::MakeShape(F32, {3, 5}),
+ ShapeUtil::MakeShape(U32, {7, 25}),
+ ShapeUtil::MakeShape(S32, {9, 11})}),
+ {HloSharding::Replicate(),
+ HloSharding::Tile(ShapeUtil::MakeShape(U32, {7, 13}),
+ Array2D<int64>({{3, 5}})),
+ HloSharding::AssignDevice(3)});
+ EXPECT_EQ(sharding.ToString(),
+ "{{replicated}, {u32[7,13] devices=[1,2]3,5}, {maximal device=3}}");
+}
+
TEST_F(HloShardingTest, OstreamTest) {
HloSharding sharding =
HloSharding::Tile(ShapeUtil::MakeShape(F32, {3, 5, 7, 11}),
Array4D<int64>({{{{0, 1}, {2, 3}}}}));
std::ostringstream oss;
oss << sharding;
- EXPECT_EQ(oss.str(), "{f32[3,5,7,11] devices=(0, 1, 2, 3)}");
+ EXPECT_EQ(oss.str(), "{f32[3,5,7,11] devices=[1,1,2,2]0,1,2,3}");
}
} // namespace