-rw-r--r-- | tensorflow/compiler/xla/service/hlo_matchers_test.cc |   8
-rw-r--r-- | tensorflow/compiler/xla/service/hlo_parser.cc        |  15
-rw-r--r-- | tensorflow/compiler/xla/service/hlo_sharding.cc      |  97
-rw-r--r-- | tensorflow/compiler/xla/service/hlo_sharding.h       |  56
-rw-r--r-- | tensorflow/compiler/xla/service/hlo_sharding_test.cc | 102
5 files changed, 79 insertions, 199 deletions
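Note on the semantics of this change: the explicitly stored tile shape is removed from HloSharding. Tile shapes are instead derived on demand from an operand shape and the tile assignment via CeilOfRatio, and tile offsets/limits are clamped to the shape, so trailing tiles may simply be smaller instead of relying on implicit padding. A minimal standalone sketch of that geometry (plain C++; std::vector stands in for xla::Shape and the Array<int64> tile assignment, and TileOffset/TileLimit are hypothetical names mirroring the TileOffsetForDevice/TileLimitForDevice methods changed below):

// Standalone sketch, not XLA code: per-dimension tile size is
// ceil(dim / tiles_along_dim); offsets and limits are clamped to the shape.
#include <algorithm>
#include <cstdint>
#include <iostream>
#include <vector>

int64_t CeilOfRatio(int64_t a, int64_t b) { return (a + b - 1) / b; }

// `tile_index` is a device's position in the tile assignment grid;
// `tiles` is the number of tiles along each dimension.
std::vector<int64_t> TileOffset(const std::vector<int64_t>& shape,
                                const std::vector<int64_t>& tiles,
                                const std::vector<int64_t>& tile_index) {
  std::vector<int64_t> offset(shape.size());
  for (size_t i = 0; i < shape.size(); ++i) {
    offset[i] =
        std::min(tile_index[i] * CeilOfRatio(shape[i], tiles[i]), shape[i]);
  }
  return offset;
}

std::vector<int64_t> TileLimit(const std::vector<int64_t>& shape,
                               const std::vector<int64_t>& tiles,
                               const std::vector<int64_t>& tile_index) {
  std::vector<int64_t> limit(shape.size());
  for (size_t i = 0; i < shape.size(); ++i) {
    limit[i] = std::min((tile_index[i] + 1) * CeilOfRatio(shape[i], tiles[i]),
                        shape[i]);
  }
  return limit;
}

int main() {
  // A [2,2] tile assignment over a {3,5} shape: the derived tile shape is
  // {2,3}, and the bottom-right tile is clamped to [2,3) x [3,5).
  std::vector<int64_t> shape = {3, 5}, tiles = {2, 2};
  for (int64_t r = 0; r < 2; ++r) {
    for (int64_t c = 0; c < 2; ++c) {
      auto lo = TileOffset(shape, tiles, {r, c});
      auto hi = TileLimit(shape, tiles, {r, c});
      std::cout << "tile {" << r << "," << c << "}: [" << lo[0] << "," << hi[0]
                << ") x [" << lo[1] << "," << hi[1] << ")\n";
    }
  }
}

The min() clamping in the sketch is exactly what the new TileOffsetForDevice/TileLimitForDevice in hlo_sharding.cc perform; it also matches the expectations in the updated hlo_sharding_test.cc, where a {4,5} shape with a 2x2 assignment yields offsets {0,0}, {0,3}, {2,0}, {2,3}.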
diff --git a/tensorflow/compiler/xla/service/hlo_matchers_test.cc b/tensorflow/compiler/xla/service/hlo_matchers_test.cc
index 7de59acc1e..7961aece54 100644
--- a/tensorflow/compiler/xla/service/hlo_matchers_test.cc
+++ b/tensorflow/compiler/xla/service/hlo_matchers_test.cc
@@ -157,9 +157,8 @@ TEST(HloMatchersTest, ShardingMatcher) {
   Array<int64> assignment({2});
   assignment.SetValues({0, 1});
   auto sharding = HloSharding::Tuple(
-      tuple_shape,
-      {HloSharding::Tile(ShapeUtil::MakeShape(F32, {5}), assignment),
-       HloSharding::AssignDevice(1), HloSharding::Replicate()});
+      tuple_shape, {HloSharding::Tile(assignment), HloSharding::AssignDevice(1),
+                    HloSharding::Replicate()});
   p2->set_sharding(sharding);
 
   EXPECT_THAT(p0.get(), op::NoSharding());
@@ -172,8 +171,7 @@ TEST(HloMatchersTest, ShardingMatcher) {
 
   EXPECT_THAT(
       p2.get(),
-      op::Sharding(
-          "{{f32[5] devices=[2]0,1}, {maximal device=1}, {replicated}}"));
+      op::Sharding("{{devices=[2]0,1}, {maximal device=1}, {replicated}}"));
 
   EXPECT_THAT(Explain(p0.get(), op::Sharding(HloSharding::AssignDevice(1))),
               "%param.0 = f32[5]{0} parameter(0) has no sharding (expected: "
diff --git a/tensorflow/compiler/xla/service/hlo_parser.cc b/tensorflow/compiler/xla/service/hlo_parser.cc
index 93cc884e3a..de73b38dec 100644
--- a/tensorflow/compiler/xla/service/hlo_parser.cc
+++ b/tensorflow/compiler/xla/service/hlo_parser.cc
@@ -1383,7 +1383,6 @@ bool HloParser::ParseSingleSharding(OpSharding* sharding,
   bool replicated = false;
   std::vector<tensorflow::int64> devices;
   std::vector<tensorflow::int64> tile_assignment_dimensions;
-  Shape tile_shape;
   while (lexer_.GetKind() != TokKind::kRbrace) {
     switch (lexer_.GetKind()) {
       case TokKind::kw_maximal:
@@ -1434,7 +1433,8 @@ bool HloParser::ParseSingleSharding(OpSharding* sharding,
         break;
       }
       case TokKind::kShape:
-        tile_shape = lexer_.GetShapeVal();
+        // TODO(b/112302613): Left here for backward compatibility to ignore
+        // the removed tile shape data.
         lexer_.Lex();
         break;
       case TokKind::kRbrace:
@@ -1449,19 +1449,12 @@ bool HloParser::ParseSingleSharding(OpSharding* sharding,
       return Error(loc,
                    "replicated shardings should not have any devices assigned");
     }
-    if (!ShapeUtil::Equal(tile_shape, Shape())) {
-      return Error(loc,
-                   "replicated shardings should not have any tile shape set");
-    }
     sharding->set_type(OpSharding::Type::OpSharding_Type_REPLICATED);
   } else if (maximal) {
     if (devices.size() != 1) {
      return Error(loc,
                   "maximal shardings should have exactly one device assigned");
    }
-    if (!ShapeUtil::Equal(tile_shape, Shape())) {
-      return Error(loc, "maximal shardings should not have any tile shape set");
-    }
    sharding->set_type(OpSharding::Type::OpSharding_Type_MAXIMAL);
    sharding->add_tile_assignment_devices(devices[0]);
  } else {
@@ -1469,9 +1462,6 @@ bool HloParser::ParseSingleSharding(OpSharding* sharding,
      return Error(
          loc, "non-maximal shardings must have more than one device assigned");
    }
-    if (ShapeUtil::Equal(tile_shape, Shape())) {
-      return Error(loc, "non-maximal shardings should have a tile shape set");
-    }
    if (tile_assignment_dimensions.empty()) {
      return Error(
          loc,
@@ -1479,7 +1469,6 @@ bool HloParser::ParseSingleSharding(OpSharding* sharding,
          "dimensions");
    }
    sharding->set_type(OpSharding::Type::OpSharding_Type_OTHER);
-    *sharding->mutable_tile_shape() = tile_shape;
    for (tensorflow::int64 dim : tile_assignment_dimensions) {
      sharding->add_tile_assignment_dimensions(dim);
    }
diff --git a/tensorflow/compiler/xla/service/hlo_sharding.cc b/tensorflow/compiler/xla/service/hlo_sharding.cc
index 6399f6ef3c..879fb3bbab 100644
--- a/tensorflow/compiler/xla/service/hlo_sharding.cc
+++ b/tensorflow/compiler/xla/service/hlo_sharding.cc
@@ -31,12 +31,9 @@ HloSharding HloSharding::Tile1D(const Shape& input_shape, int64 num_tiles) {
   CHECK_EQ(1, ShapeUtil::Rank(input_shape));
   CHECK_GT(num_tiles, 1);
   std::vector<int64> dimensions(1, num_tiles);
-  Shape tile_shape = input_shape;
-  auto& tile_dimension = (*tile_shape.mutable_dimensions())[0];
-  tile_dimension = CeilOfRatio(static_cast<int64>(tile_dimension), num_tiles);
   Array<int64> assignment(dimensions);
   std::iota(assignment.begin(), assignment.end(), 0);
-  return HloSharding(tile_shape, assignment);
+  return HloSharding(assignment);
 }
 
 HloSharding HloSharding::Tuple(const ShapeTree<HloSharding>& sub_shardings) {
@@ -104,8 +101,7 @@ string HloSharding::ToString() const {
     return StrCat(
         "{maximal device=", static_cast<int64>(*tile_assignment_.begin()), "}");
   } else {
-    return StrCat("{", ShapeUtil::HumanString(tile_shape_), " ", "devices=[",
-                  Join(tile_assignment_.dimensions(), ","), "]",
+    return StrCat("{devices=[", Join(tile_assignment_.dimensions(), ","), "]",
                   Join(tile_assignment_, ","), "}");
   }
 }
@@ -145,7 +141,6 @@ std::map<int64, int64> HloSharding::UsedDevices(int64* count) const {
 }
 
 std::vector<int64> HloSharding::TileIndexForDevice(int64 device) const {
-  CHECK(!ShapeUtil::IsTuple(tile_shape_));
   CHECK(!maximal_);
   CHECK(!IsTuple());
   std::vector<int64> ret_index;
@@ -165,32 +160,43 @@ int64 HloSharding::DeviceForTileIndex(
   if (maximal_) {
     return *tile_assignment_.begin();
   }
-  CHECK_EQ(ShapeUtil::Rank(tile_shape_), tile_assignment_.dimensions().size());
   return tile_assignment_(index);
 }
 
-std::vector<int64> HloSharding::TileOffsetForDevice(int64 device) const {
+std::vector<int64> HloSharding::TileOffsetForDevice(const Shape& shape,
+                                                    int64 device) const {
   CHECK(!IsTuple());
 
-  std::vector<int64> index = TileIndexForDevice(device);
   if (maximal_) {
-    // Index will always be all zeroes if we're maximal, and tile_shape_ is not
-    // valid.
-    return index;
+    return std::vector<int64>(shape.dimensions_size(), 0);
   }
+
+  CHECK_EQ(shape.dimensions_size(), tile_assignment_.num_dimensions());
+  std::vector<int64> index = TileIndexForDevice(device);
   for (int64 i = 0; i < index.size(); ++i) {
-    index[i] *= tile_shape_.dimensions(i);
+    const int64 shape_dim = shape.dimensions(i);
+    index[i] = std::min(
+        index[i] * CeilOfRatio(shape_dim, tile_assignment_.dim(i)), shape_dim);
   }
   return index;
 }
 
-std::vector<int64> HloSharding::TileLimitForDevice(int64 device) const {
+std::vector<int64> HloSharding::TileLimitForDevice(const Shape& shape,
+                                                   int64 device) const {
   CHECK(!IsTuple());
-  CHECK(!maximal_);  // Maximal shardings do not have a valid tile shape.
+
+  if (maximal_) {
+    return std::vector<int64>(shape.dimensions().begin(),
+                              shape.dimensions().end());
+  }
+
+  CHECK_EQ(shape.dimensions_size(), tile_assignment_.num_dimensions());
   std::vector<int64> index = TileIndexForDevice(device);
   for (int64 i = 0; i < index.size(); ++i) {
-    index[i] = (index[i] + 1) * tile_shape_.dimensions(i);
+    const int64 shape_dim = shape.dimensions(i);
+    index[i] = std::min(
+        (index[i] + 1) * CeilOfRatio(shape_dim, tile_assignment_.dim(i)),
+        shape_dim);
   }
   return index;
 }
@@ -336,11 +342,12 @@ Status HloSharding::ValidateNonTuple(const Shape& shape,
     return Status::OK();
   }
 
-  // The tile rank must be the same as the input rank.
-  if (ShapeUtil::Rank(shape) != ShapeUtil::Rank(tile_shape_)) {
+  // The tile assignment tensor must have the same rank as the input.
+  if (ShapeUtil::Rank(shape) != tile_assignment_.num_dimensions()) {
     return tensorflow::errors::InvalidArgument(
-        "Tile rank is different to the input rank. sharding=", ToString(),
-        ", input_shape=", ShapeUtil::HumanString(shape));
+        "Number of tile assignment dimensions is different to the input rank. "
+        "sharding=",
+        ToString(), ", input_shape=", ShapeUtil::HumanString(shape));
   }
 
   // The correct constructor have to be used to create tile maximal shardings.
@@ -350,20 +357,6 @@ Status HloSharding::ValidateNonTuple(const Shape& shape,
         "sharding was intended, use HloSharding::Replicated(). If a device "
        "placement was intended, use HloSharding::AssignDevice()");
  }
-
-  // The tile assignment tensor must contain enough element to cover the full
-  // shape with tiles of the specified size.
-  for (int64 i = 0, e = tile_assignment_.dimensions().size(); i != e; ++i) {
-    int64 total_tile_size = tile_assignment_.dim(i) * tile_shape_.dimensions(i);
-    if (shape.dimensions(i) > total_tile_size) {
-      return tensorflow::errors::InvalidArgument(
-          StrCat("Tile assignment tensor has too few element to cover the full "
-                 "shape. Dimension ",
-                 i, ", shape ", shape.dimensions(i), ", total size ",
-                 total_tile_size));
-    }
-  }
-
   return Status::OK();
 }
 
@@ -393,7 +386,7 @@ Status HloSharding::ValidateNonTuple(const Shape& shape,
           proto.tile_assignment_dimensions().end()));
   std::copy(proto.tile_assignment_devices().begin(),
             proto.tile_assignment_devices().end(), tile_assignment.begin());
-  return HloSharding(proto.tile_shape(), tile_assignment);
+  return HloSharding(tile_assignment);
 }
 
 OpSharding HloSharding::ToProto() const {
@@ -407,7 +400,6 @@ OpSharding HloSharding::ToProto() const {
     return result;
   }
 
-  *result.mutable_tile_shape() = tile_shape_;
   for (int64 dim : tile_assignment_.dimensions()) {
     result.add_tile_assignment_dimensions(dim);
   }
@@ -424,30 +416,16 @@ OpSharding HloSharding::ToProto() const {
   return result;
 }
 
-HloSharding HloSharding::TransformShardedTileShape(
-    const Shape& new_shape,
-    const std::function<int64(int64, int64)>& transform) const {
-  CHECK(!IsTuple());
+Shape HloSharding::TileShape(const Shape& shape) const {
   if (IsTileMaximal()) {
-    return *this;
+    return shape;
   }
-  CHECK_EQ(ShapeUtil::Rank(new_shape), ShapeUtil::Rank(tile_shape()));
-  Shape new_tile_shape;
-  new_tile_shape.set_element_type(tile_shape().element_type());
-  for (int64 i = 0; i < ShapeUtil::Rank(new_shape); ++i) {
-    int64 dim;
-    if (tile_assignment().dim(i) == 1) {
-      dim = new_shape.dimensions(i);
-    } else if (transform) {
-      dim = transform(i, tile_shape().dimensions(i));
-    } else {
-      dim = tile_shape().dimensions(i);
-    }
-    new_tile_shape.add_dimensions(dim);
+  Shape result_shape = shape;
+  for (int64 i = 0; i < shape.dimensions_size(); ++i) {
+    (*result_shape.mutable_dimensions())[i] =
+        CeilOfRatio<int64>(shape.dimensions(i), tile_assignment_.dim(i));
   }
-  TF_CHECK_OK(
-      LayoutUtil::CopyLayoutBetweenShapes(tile_shape_, &new_tile_shape));
-  return HloSharding::Tile(new_tile_shape, tile_assignment());
+  return result_shape;
 }
 
 HloSharding HloSharding::GetSubSharding(const Shape& shape,
@@ -489,9 +467,6 @@ size_t HloSharding::Hash() const {
   for (uint32 v : tile_assignment_) {
     h = tensorflow::Hash64Combine(h, std::hash<uint32>{}(v));
   }
-  for (uint32 v : tile_shape_.dimensions()) {
-    h = tensorflow::Hash64Combine(h, std::hash<uint32>{}(v));
-  }
   return h;
 }
diff --git a/tensorflow/compiler/xla/service/hlo_sharding.h b/tensorflow/compiler/xla/service/hlo_sharding.h
index 28575c0e75..894783e5d1 100644
--- a/tensorflow/compiler/xla/service/hlo_sharding.h
+++ b/tensorflow/compiler/xla/service/hlo_sharding.h
@@ -48,22 +48,10 @@ class HloSharding {
   // the input shape (one tile) assigned to a single device.
   static HloSharding AssignDevice(int64 device_id);
 
-  // Creates a new sharding which splits a shape into tiles each with shape
-  // `tile_shape`. Each tile is assigned to one device, which is specified by
-  // `tile_assignment`. Any tensor not a multiple of the tile size in any
-  // dimension is implicitly padded to the tile size.
-  //
-  // e.g. Tile({2, 2}, {0, 1}) on a tensor of shape {3, 2} would look like:
-  //      2     1 padding
-  //   <------><->
-  //   +----+----+
-  //   | 0  |  1 |
-  //   +----+----+
-  //
-  // Split into two tiles, one of which is implicitly padded by one.
-  static HloSharding Tile(const Shape& tile_shape,
-                          const Array<int64>& tile_assignment) {
-    return HloSharding(tile_shape, tile_assignment);
+  // Creates a new sharding which splits a shape into tiles amongst the devices
+  // specified by `tile_assignment`.
+  static HloSharding Tile(const Array<int64>& tile_assignment) {
+    return HloSharding(tile_assignment);
   }
 
   // Creates a new sharding which splits a one-dimensional input shape into
@@ -146,17 +134,18 @@ class HloSharding {
   // REQUIRES: !IsTuple()
   int64 DeviceForTileIndex(tensorflow::gtl::ArraySlice<int64> index) const;
 
-  // Given a device ID, returns the offset within the input space of the
+  // Given a device ID, returns the offset within the specified shape of the
   // tile that should be executed on the given core. This returns the lower
   // extent of the tile in the input space.
   // REQUIRES: !IsTuple()
-  std::vector<int64> TileOffsetForDevice(int64 device) const;
+  std::vector<int64> TileOffsetForDevice(const Shape& shape,
+                                         int64 device) const;
 
-  // Given a device ID, returns the limit within the input space of the
+  // Given a device ID, returns the limit within the specified shape of the
   // tile that should be executed on the given core. This returns the upper
   // extent of the tile in the input space.
   // REQUIRES: !IsTuple()
-  std::vector<int64> TileLimitForDevice(int64 device) const;
+  std::vector<int64> TileLimitForDevice(const Shape& shape, int64 device) const;
 
   // Returns the single device this op operates on. If the sharding does not
   // span a single device, the return value will be empty.
@@ -197,7 +186,6 @@ class HloSharding {
   bool operator==(const HloSharding& other) const {
     return replicated_ == other.replicated_ && maximal_ == other.maximal_ &&
-           ShapeUtil::Compatible(tile_shape_, other.tile_shape_) &&
            tile_assignment_ == other.tile_assignment_ &&
            tuple_elements_ == other.tuple_elements_;
   }
@@ -211,9 +199,6 @@ class HloSharding {
     }
   };
 
-  // Gets the tile shape.
-  // REQUIRES: !IsTileMaximal() && !IsTuple()
-  const Shape& tile_shape() const { return tile_shape_; }
   // Gets the tile assignment tensor.
   // REQUIRES: !IsReplicated() && !IsTuple()
   const Array<int64>& tile_assignment() const { return tile_assignment_; }
@@ -225,25 +210,15 @@ class HloSharding {
     return tuple_elements_;
   }
 
-  // Return a new sharding that can apply to the given new shape.
-  // If this sharding is tile-maximal, the returned sharding will be the same as
-  // this sharding. If this sharding is not tile-maximal, the returned
-  // sharding's tile size will differ:
-  //   - Non-sharded dimensions will be adapted to be the same as `new_shape`;
-  //     tile_dimension(i) = new_shape.dimensions(i);
-  //   - Sharded dimensions will be kept the same unless `transform` is supplied
-  //     in which case tile_dimension(i) = transform(i, tile_dimension(i));
-  // REQUIRES: !IsTuple().
-  HloSharding TransformShardedTileShape(
-      const Shape& new_shape,
-      const std::function<int64(int64, int64)>& transform = nullptr) const;
+  // Gets the tile shape.
+  // REQUIRES: !IsTuple()
+  Shape TileShape(const Shape& shape) const;
 
  private:
   HloSharding()
       : replicated_(true),
         maximal_(true),
         tuple_(false),
-        tile_shape_(),
         tile_assignment_({0}) {}
   // device_id values:
   // -2: magic number to mean unassigned device, used by spatial partitioning
@@ -255,15 +230,13 @@ class HloSharding {
       : replicated_(false),
         maximal_(true),
         tuple_(false),
-        tile_shape_(),
         tile_assignment_({1}, device_id) {}
-  HloSharding(const Shape& tile_shape, const Array<int64>& tile_assignment)
+  explicit HloSharding(const Array<int64>& tile_assignment)
      : replicated_(false),
        maximal_(false),
        tuple_(false),
-        tile_shape_(tile_shape),
        tile_assignment_(tile_assignment) {}
-  HloSharding(const std::vector<HloSharding>& tuple_shardings)
+  explicit HloSharding(const std::vector<HloSharding>& tuple_shardings)
      : replicated_(false),
        maximal_(false),
        tuple_(true),
@@ -286,7 +259,6 @@ class HloSharding {
   bool replicated_;
   bool maximal_;
   bool tuple_;
-  Shape tile_shape_;
   Array<int64> tile_assignment_;
   // Only non-empty when tuple_ is true, but because empty tuples are allowed
   // may also be empty even then. This is a flattened list of all the leaf
diff --git a/tensorflow/compiler/xla/service/hlo_sharding_test.cc b/tensorflow/compiler/xla/service/hlo_sharding_test.cc
index aebda562d3..45fc300fca 100644
--- a/tensorflow/compiler/xla/service/hlo_sharding_test.cc
+++ b/tensorflow/compiler/xla/service/hlo_sharding_test.cc
@@ -39,7 +39,6 @@ Array<int64> MakeArray(tensorflow::gtl::ArraySlice<int64> dimensions,
 class HloShardingTest : public HloTestBase {};
 
 TEST_F(HloShardingTest, Replicate) {
-  Shape tile_shape = ShapeUtil::MakeShape(U32, {4});
   HloSharding sharding = HloSharding::Replicate();
   EXPECT_TRUE(sharding.IsReplicated());
   EXPECT_TRUE(sharding.IsTileMaximal());
@@ -79,37 +78,22 @@ TEST_F(HloShardingTest, DevicePlacement) {
 TEST_F(HloShardingTest, Tile) {
   {
     // Test should fail because of a duplicate tile assignment.
-    Shape tile_shape = ShapeUtil::MakeShape(U32, {2, 3});
-    HloSharding sharding =
-        HloSharding::Tile(tile_shape, MakeArray({2, 2}, {0, 0, 2, 3}));
+    HloSharding sharding = HloSharding::Tile(MakeArray({2, 2}, {0, 0, 2, 3}));
     EXPECT_IS_NOT_OK(sharding.Validate(ShapeUtil::MakeShape(F32, {4, 6}),
                                        /*num_devices=*/4));
   }
 
   {
     // Test should fail because of more devices used then `num_device`.
-    Shape tile_shape = ShapeUtil::MakeShape(U32, {2, 3});
-    HloSharding sharding =
-        HloSharding::Tile(tile_shape, MakeArray({2, 2}, {0, 1, 2, 3}));
+    HloSharding sharding = HloSharding::Tile(MakeArray({2, 2}, {0, 1, 2, 3}));
     EXPECT_IS_NOT_OK(sharding.Validate(ShapeUtil::MakeShape(U32, {4, 6}),
                                        /*num_devices=*/2));
   }
 
   {
-    // Test should fail because the total tiled size in dimension 0 is 4 but we
-    // have 6 elements along that dimensions.
-    Shape tile_shape = ShapeUtil::MakeShape(U32, {2, 3});
-    HloSharding sharding =
-        HloSharding::Tile(tile_shape, MakeArray({2, 2}, {0, 1, 2, 3}));
-    EXPECT_IS_NOT_OK(sharding.Validate(ShapeUtil::MakeShape(F32, {6, 3}),
-                                       /*num_devices=*/4));
-  }
-
-  {
     // Test should pass.
-    Shape tile_shape = ShapeUtil::MakeShape(U32, {2, 3});
-    HloSharding sharding =
-        HloSharding::Tile(tile_shape, MakeArray({2, 2}, {0, 3, 2, 1}));
+    Shape shape = ShapeUtil::MakeShape(U32, {4, 5});
+    HloSharding sharding = HloSharding::Tile(MakeArray({2, 2}, {0, 3, 2, 1}));
     EXPECT_IS_OK(sharding.Validate(ShapeUtil::MakeShape(F32, {3, 5}),
                                    /*num_devices=*/5));
@@ -118,10 +102,14 @@ TEST_F(HloShardingTest, Tile) {
     EXPECT_EQ(2, sharding.DeviceForTileIndex({1, 0}));
     EXPECT_EQ(1, sharding.DeviceForTileIndex({1, 1}));
 
-    EXPECT_EQ(sharding.TileOffsetForDevice(0), (std::vector<int64>{0, 0}));
-    EXPECT_EQ(sharding.TileOffsetForDevice(3), (std::vector<int64>{0, 3}));
-    EXPECT_EQ(sharding.TileOffsetForDevice(2), (std::vector<int64>{2, 0}));
-    EXPECT_EQ(sharding.TileOffsetForDevice(1), (std::vector<int64>{2, 3}));
+    EXPECT_EQ(sharding.TileOffsetForDevice(shape, 0),
+              (std::vector<int64>{0, 0}));
+    EXPECT_EQ(sharding.TileOffsetForDevice(shape, 3),
+              (std::vector<int64>{0, 3}));
+    EXPECT_EQ(sharding.TileOffsetForDevice(shape, 2),
+              (std::vector<int64>{2, 0}));
+    EXPECT_EQ(sharding.TileOffsetForDevice(shape, 1),
+              (std::vector<int64>{2, 3}));
 
     EXPECT_FALSE(sharding.HasUniqueDevice());
   }
@@ -135,8 +123,7 @@ TEST_F(HloShardingTest, NestedTuple) {
       ShapeUtil::MakeShape(F32, {4, 6}),
   });
 
-  HloSharding tiled_sharding = HloSharding::Tile(
-      ShapeUtil::MakeShape(F32, {4, 3}), Array<int64>({{0, 1}}));
+  HloSharding tiled_sharding = HloSharding::Tile(Array<int64>({{0, 1}}));
   OpSharding proto;
   proto.set_type(OpSharding::Type::OpSharding_Type_TUPLE);
   *proto.add_tuple_shardings() = HloSharding::Replicate().ToProto();
@@ -187,32 +174,11 @@ TEST_F(HloShardingTest, Hash) {
   }
 
   {
-    Shape tile_shape = ShapeUtil::MakeShape(U32, {2, 3});
-    HloSharding sharding1 =
-        HloSharding::Tile(tile_shape, MakeArray({2, 2}, {0, 3, 2, 1}));
-    HloSharding sharding2 = HloSharding::Tile(ShapeUtil::MakeShape(U32, {2, 3}),
-                                              MakeArray({2, 2}, {0, 3, 2, 1}));
-    EXPECT_TRUE(hash_compare_equal(sharding1, sharding2));
-  }
-
-  {
-    Shape tile_shape = ShapeUtil::MakeShape(U32, {2, 3});
-    HloSharding sharding1 =
-        HloSharding::Tile(tile_shape, MakeArray({2, 2}, {0, 3, 2, 1}));
-    HloSharding sharding2 = HloSharding::Tile(ShapeUtil::MakeShape(U32, {2, 3}),
-                                              MakeArray({2, 2}, {0, 3, 2, 1}));
+    HloSharding sharding1 = HloSharding::Tile(MakeArray({2, 2}, {0, 3, 2, 1}));
+    HloSharding sharding2 = HloSharding::Tile(MakeArray({2, 2}, {0, 3, 2, 1}));
     EXPECT_TRUE(hash_compare_equal(sharding1, sharding2));
   }
 
-  {
-    Shape tile_shape = ShapeUtil::MakeShape(U32, {2, 3});
-    HloSharding sharding1 =
-        HloSharding::Tile(tile_shape, MakeArray({2, 2}, {0, 3, 2, 1}));
-    HloSharding sharding2 = HloSharding::Tile(ShapeUtil::MakeShape(U32, {2, 3}),
-                                              MakeArray({2, 2}, {0, 3, 1, 2}));
-    EXPECT_FALSE(hash_compare_equal(sharding1, sharding2));
-  }
-
   HloSharding default_sharding = HloSharding::Replicate();
   {
     ShapeTree<HloSharding> shape_tree(ShapeUtil::MakeTupleShape({}),
@@ -259,19 +225,6 @@ TEST_F(HloShardingTest, Hash) {
   }
 }
 
-TEST_F(HloShardingTest, TransformShardedTileShapeTest) {
-  HloSharding sharding =
-      HloSharding::Tile(ShapeUtil::MakeShape(F32, {3, 5, 7, 11}),
-                        Array4D<int64>({{{{0, 1}, {2, 3}}}}));
-  HloSharding result = sharding.TransformShardedTileShape(
-      ShapeUtil::MakeShape(F32, {13, 15, 17, 19}),
-      [](int dim, int value) { return dim * 111; });
-  HloSharding expected =
-      HloSharding::Tile(ShapeUtil::MakeShape(F32, {13, 15, 222, 333}),
-                        Array4D<int64>({{{{0, 1}, {2, 3}}}}));
-  EXPECT_EQ(result, expected);
-}
-
 TEST_F(HloShardingTest, ToStringReplicatedTest) {
   HloSharding sharding = HloSharding::Replicate();
   EXPECT_EQ(sharding.ToString(), "{replicated}");
@@ -284,9 +237,8 @@ TEST_F(HloShardingTest, ToStringAssignDeviceTest) {
 
 TEST_F(HloShardingTest, ToStringTiledTest) {
   HloSharding sharding =
-      HloSharding::Tile(ShapeUtil::MakeShape(S32, {7, 11, 13}),
-                        Array3D<int64>({{{2, 3}}, {{5, 7}}}));
-  EXPECT_EQ(sharding.ToString(), "{s32[7,11,13] devices=[2,1,2]2,3,5,7}");
+      HloSharding::Tile(Array3D<int64>({{{2, 3}}, {{5, 7}}}));
+  EXPECT_EQ(sharding.ToString(), "{devices=[2,1,2]2,3,5,7}");
 }
 
 TEST_F(HloShardingTest, ToStringTupleTest) {
@@ -294,21 +246,18 @@ TEST_F(HloShardingTest, ToStringTupleTest) {
       ShapeUtil::MakeTupleShape({ShapeUtil::MakeShape(F32, {3, 5}),
                                  ShapeUtil::MakeShape(U32, {7, 25}),
                                  ShapeUtil::MakeShape(S32, {9, 11})}),
-      {HloSharding::Replicate(),
-       HloSharding::Tile(ShapeUtil::MakeShape(U32, {7, 13}),
-                         Array2D<int64>({{3, 5}})),
+      {HloSharding::Replicate(), HloSharding::Tile(Array2D<int64>({{3, 5}})),
        HloSharding::AssignDevice(3)});
   EXPECT_EQ(sharding.ToString(),
-            "{{replicated}, {u32[7,13] devices=[1,2]3,5}, {maximal device=3}}");
+            "{{replicated}, {devices=[1,2]3,5}, {maximal device=3}}");
 }
 
 TEST_F(HloShardingTest, OstreamTest) {
   HloSharding sharding =
-      HloSharding::Tile(ShapeUtil::MakeShape(F32, {3, 5, 7, 11}),
-                        Array4D<int64>({{{{0, 1}, {2, 3}}}}));
+      HloSharding::Tile(Array4D<int64>({{{{0, 1}, {2, 3}}}}));
   std::ostringstream oss;
   oss << sharding;
-  EXPECT_EQ(oss.str(), "{f32[3,5,7,11] devices=[1,1,2,2]0,1,2,3}");
+  EXPECT_EQ(oss.str(), "{devices=[1,1,2,2]0,1,2,3}");
 }
 
 TEST_F(HloShardingTest, ParseHloString) {
@@ -319,8 +268,7 @@ TEST_F(HloShardingTest, ParseHloString) {
   };
   check(HloSharding::Replicate());
   check(HloSharding::AssignDevice(2));
-  check(HloSharding::Tile(ShapeUtil::MakeShape(F32, {3, 1, 3, 7}),
-                          Array4D<int64>({{{{0}, {1}}}})));
+  check(HloSharding::Tile(Array4D<int64>({{{{0}, {1}}}})));
   // Empty tuple. One sharding is required for empty tuples, as we need to be
   // able to assign sharding to them, even though they have no leaves.
   check(HloSharding::Tuple(ShapeUtil::MakeTupleShape({}),
@@ -332,8 +280,7 @@ TEST_F(HloShardingTest, ParseHloString) {
         ShapeUtil::MakeShape(F32, {3, 5, 7}), ShapeUtil::MakeShape(F32, {3, 7})});
     check(HloSharding::Tuple(
-        tuple_shape, {HloSharding::Tile(ShapeUtil::MakeShape(F32, {3, 1, 3, 7}),
-                                        Array4D<int64>({{{{0}, {1}}}})),
+        tuple_shape, {HloSharding::Tile(Array4D<int64>({{{{0}, {1}}}})),
                       HloSharding::Replicate(), HloSharding::AssignDevice(1)}));
   }
   {
@@ -343,8 +290,7 @@ TEST_F(HloShardingTest, ParseHloString) {
     ShapeUtil::MakeTupleShape({ShapeUtil::MakeShape(F32, {3, 5, 7}),
                                ShapeUtil::MakeShape(F32, {3, 7})})});
     std::vector<HloSharding> leaf_shardings = {
-        HloSharding::Tile(ShapeUtil::MakeShape(F32, {3, 1, 3, 7}),
-                          Array4D<int64>({{{{0}, {1}}}})),
+        HloSharding::Tile(Array4D<int64>({{{{0}, {1}}}})),
         HloSharding::Replicate(), HloSharding::AssignDevice(1)};
     ShapeTree<HloSharding> sharding_tree(tuple_shape, HloSharding::Replicate());
     // Assign leaf_shardings to sharding_tree leaves.
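Migration note: the removed TransformShardedTileShape adapted a stored tile shape to a new operand shape; under this change, callers instead derive the tile shape for whatever shape they care about via the new TileShape(shape). A standalone sketch of that computation (plain C++, not the XLA API; std::vector stands in for xla::Shape and the tile counts come from the assignment's dimensions):

// Sketch of the new TileShape semantics: each dimension is divided by the
// number of tiles along it, rounding up, with no stored tile shape involved.
#include <cstdint>
#include <iostream>
#include <vector>

int64_t CeilOfRatio(int64_t a, int64_t b) { return (a + b - 1) / b; }

std::vector<int64_t> TileShape(const std::vector<int64_t>& shape,
                               const std::vector<int64_t>& tiles) {
  std::vector<int64_t> result(shape.size());
  for (size_t i = 0; i < shape.size(); ++i) {
    result[i] = CeilOfRatio(shape[i], tiles[i]);
  }
  return result;
}

int main() {
  // devices=[1,1,2,2] over f32[3,5,7,11], as in OstreamTest above:
  // the derived tile shape is {3, 5, 4, 6}.
  std::vector<int64_t> tile = TileShape({3, 5, 7, 11}, {1, 1, 2, 2});
  for (int64_t d : tile) std::cout << d << " ";  // prints: 3 5 4 6
  std::cout << "\n";
}

This is also why the parser and ToString changes above drop the shape from the sharding syntax: "{f32[3,5,7,11] devices=[1,1,2,2]0,1,2,3}" becomes "{devices=[1,1,2,2]0,1,2,3}", and any shape token is now parsed and ignored purely for backward compatibility.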