-rw-r--r--  tensorflow/compiler/xla/service/hlo_matchers_test.cc    8
-rw-r--r--  tensorflow/compiler/xla/service/hlo_parser.cc           15
-rw-r--r--  tensorflow/compiler/xla/service/hlo_sharding.cc         97
-rw-r--r--  tensorflow/compiler/xla/service/hlo_sharding.h          56
-rw-r--r--  tensorflow/compiler/xla/service/hlo_sharding_test.cc   102
5 files changed, 79 insertions, 199 deletions
diff --git a/tensorflow/compiler/xla/service/hlo_matchers_test.cc b/tensorflow/compiler/xla/service/hlo_matchers_test.cc
index 7de59acc1e..7961aece54 100644
--- a/tensorflow/compiler/xla/service/hlo_matchers_test.cc
+++ b/tensorflow/compiler/xla/service/hlo_matchers_test.cc
@@ -157,9 +157,8 @@ TEST(HloMatchersTest, ShardingMatcher) {
Array<int64> assignment({2});
assignment.SetValues({0, 1});
auto sharding = HloSharding::Tuple(
- tuple_shape,
- {HloSharding::Tile(ShapeUtil::MakeShape(F32, {5}), assignment),
- HloSharding::AssignDevice(1), HloSharding::Replicate()});
+ tuple_shape, {HloSharding::Tile(assignment), HloSharding::AssignDevice(1),
+ HloSharding::Replicate()});
p2->set_sharding(sharding);
EXPECT_THAT(p0.get(), op::NoSharding());
@@ -172,8 +171,7 @@ TEST(HloMatchersTest, ShardingMatcher) {
EXPECT_THAT(
p2.get(),
- op::Sharding(
- "{{f32[5] devices=[2]0,1}, {maximal device=1}, {replicated}}"));
+ op::Sharding("{{devices=[2]0,1}, {maximal device=1}, {replicated}}"));
EXPECT_THAT(Explain(p0.get(), op::Sharding(HloSharding::AssignDevice(1))),
"%param.0 = f32[5]{0} parameter(0) has no sharding (expected: "
diff --git a/tensorflow/compiler/xla/service/hlo_parser.cc b/tensorflow/compiler/xla/service/hlo_parser.cc
index 93cc884e3a..de73b38dec 100644
--- a/tensorflow/compiler/xla/service/hlo_parser.cc
+++ b/tensorflow/compiler/xla/service/hlo_parser.cc
@@ -1383,7 +1383,6 @@ bool HloParser::ParseSingleSharding(OpSharding* sharding,
bool replicated = false;
std::vector<tensorflow::int64> devices;
std::vector<tensorflow::int64> tile_assignment_dimensions;
- Shape tile_shape;
while (lexer_.GetKind() != TokKind::kRbrace) {
switch (lexer_.GetKind()) {
case TokKind::kw_maximal:
@@ -1434,7 +1433,8 @@ bool HloParser::ParseSingleSharding(OpSharding* sharding,
break;
}
case TokKind::kShape:
- tile_shape = lexer_.GetShapeVal();
+ // TODO(b/112302613): Left here for backward compatibility to ignore the
+ // removed tile shape data.
lexer_.Lex();
break;
case TokKind::kRbrace:
@@ -1449,19 +1449,12 @@ bool HloParser::ParseSingleSharding(OpSharding* sharding,
return Error(loc,
"replicated shardings should not have any devices assigned");
}
- if (!ShapeUtil::Equal(tile_shape, Shape())) {
- return Error(loc,
- "replicated shardings should not have any tile shape set");
- }
sharding->set_type(OpSharding::Type::OpSharding_Type_REPLICATED);
} else if (maximal) {
if (devices.size() != 1) {
return Error(loc,
"maximal shardings should have exactly one device assigned");
}
- if (!ShapeUtil::Equal(tile_shape, Shape())) {
- return Error(loc, "maximal shardings should not have any tile shape set");
- }
sharding->set_type(OpSharding::Type::OpSharding_Type_MAXIMAL);
sharding->add_tile_assignment_devices(devices[0]);
} else {
@@ -1469,9 +1462,6 @@ bool HloParser::ParseSingleSharding(OpSharding* sharding,
return Error(
loc, "non-maximal shardings must have more than one device assigned");
}
- if (ShapeUtil::Equal(tile_shape, Shape())) {
- return Error(loc, "non-maximal shardings should have a tile shape set");
- }
if (tile_assignment_dimensions.empty()) {
return Error(
loc,
@@ -1479,7 +1469,6 @@ bool HloParser::ParseSingleSharding(OpSharding* sharding,
"dimensions");
}
sharding->set_type(OpSharding::Type::OpSharding_Type_OTHER);
- *sharding->mutable_tile_shape() = tile_shape;
for (tensorflow::int64 dim : tile_assignment_dimensions) {
sharding->add_tile_assignment_dimensions(dim);
}
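
With the tile shape removed from the grammar, ParseSingleSharding reduces to three accepted forms. Representative strings, sketched from the parser cases and the test expectations below (not an exhaustive grammar):

  {replicated}              // OpSharding_Type_REPLICATED; no devices allowed
  {maximal device=3}        // OpSharding_Type_MAXIMAL; exactly one device
  {devices=[2,2]0,1,2,3}    // OpSharding_Type_OTHER; assignment dims, then devices

A leading shape token such as f32[4,6] is still lexed but discarded, per the backward-compatibility TODO above.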
diff --git a/tensorflow/compiler/xla/service/hlo_sharding.cc b/tensorflow/compiler/xla/service/hlo_sharding.cc
index 6399f6ef3c..879fb3bbab 100644
--- a/tensorflow/compiler/xla/service/hlo_sharding.cc
+++ b/tensorflow/compiler/xla/service/hlo_sharding.cc
@@ -31,12 +31,9 @@ HloSharding HloSharding::Tile1D(const Shape& input_shape, int64 num_tiles) {
CHECK_EQ(1, ShapeUtil::Rank(input_shape));
CHECK_GT(num_tiles, 1);
std::vector<int64> dimensions(1, num_tiles);
- Shape tile_shape = input_shape;
- auto& tile_dimension = (*tile_shape.mutable_dimensions())[0];
- tile_dimension = CeilOfRatio(static_cast<int64>(tile_dimension), num_tiles);
Array<int64> assignment(dimensions);
std::iota(assignment.begin(), assignment.end(), 0);
- return HloSharding(tile_shape, assignment);
+ return HloSharding(assignment);
}
HloSharding HloSharding::Tuple(const ShapeTree<HloSharding>& sub_shardings) {
@@ -104,8 +101,7 @@ string HloSharding::ToString() const {
return StrCat(
"{maximal device=", static_cast<int64>(*tile_assignment_.begin()), "}");
} else {
- return StrCat("{", ShapeUtil::HumanString(tile_shape_), " ", "devices=[",
- Join(tile_assignment_.dimensions(), ","), "]",
+ return StrCat("{devices=[", Join(tile_assignment_.dimensions(), ","), "]",
Join(tile_assignment_, ","), "}");
}
}
@@ -145,7 +141,6 @@ std::map<int64, int64> HloSharding::UsedDevices(int64* count) const {
}
std::vector<int64> HloSharding::TileIndexForDevice(int64 device) const {
- CHECK(!ShapeUtil::IsTuple(tile_shape_));
CHECK(!maximal_);
CHECK(!IsTuple());
std::vector<int64> ret_index;
@@ -165,32 +160,43 @@ int64 HloSharding::DeviceForTileIndex(
if (maximal_) {
return *tile_assignment_.begin();
}
- CHECK_EQ(ShapeUtil::Rank(tile_shape_), tile_assignment_.dimensions().size());
return tile_assignment_(index);
}
-std::vector<int64> HloSharding::TileOffsetForDevice(int64 device) const {
+std::vector<int64> HloSharding::TileOffsetForDevice(const Shape& shape,
+ int64 device) const {
CHECK(!IsTuple());
- std::vector<int64> index = TileIndexForDevice(device);
if (maximal_) {
- // Index will always be all zeroes if we're maximal, and tile_shape_ is not
- // valid.
- return index;
+ return std::vector<int64>(shape.dimensions_size(), 0);
}
+
+ CHECK_EQ(shape.dimensions_size(), tile_assignment_.num_dimensions());
+ std::vector<int64> index = TileIndexForDevice(device);
for (int64 i = 0; i < index.size(); ++i) {
- index[i] *= tile_shape_.dimensions(i);
+ const int64 shape_dim = shape.dimensions(i);
+ index[i] = std::min(
+ index[i] * CeilOfRatio(shape_dim, tile_assignment_.dim(i)), shape_dim);
}
return index;
}
-std::vector<int64> HloSharding::TileLimitForDevice(int64 device) const {
+std::vector<int64> HloSharding::TileLimitForDevice(const Shape& shape,
+ int64 device) const {
CHECK(!IsTuple());
- CHECK(!maximal_); // Maximal shardings do not have a valid tile shape.
+ if (maximal_) {
+ return std::vector<int64>(shape.dimensions().begin(),
+ shape.dimensions().end());
+ }
+
+ CHECK_EQ(shape.dimensions_size(), tile_assignment_.num_dimensions());
std::vector<int64> index = TileIndexForDevice(device);
for (int64 i = 0; i < index.size(); ++i) {
- index[i] = (index[i] + 1) * tile_shape_.dimensions(i);
+ const int64 shape_dim = shape.dimensions(i);
+ index[i] = std::min(
+ (index[i] + 1) * CeilOfRatio(shape_dim, tile_assignment_.dim(i)),
+ shape_dim);
}
return index;
}
@@ -336,11 +342,12 @@ Status HloSharding::ValidateNonTuple(const Shape& shape,
return Status::OK();
}
- // The tile rank must be the same as the input rank.
- if (ShapeUtil::Rank(shape) != ShapeUtil::Rank(tile_shape_)) {
+ // The tile assignment tensor must have the same rank as the input.
+ if (ShapeUtil::Rank(shape) != tile_assignment_.num_dimensions()) {
return tensorflow::errors::InvalidArgument(
- "Tile rank is different to the input rank. sharding=", ToString(),
- ", input_shape=", ShapeUtil::HumanString(shape));
+ "Number of tile assignment dimensions is different to the input rank. "
+ "sharding=",
+ ToString(), ", input_shape=", ShapeUtil::HumanString(shape));
}
// The correct constructor has to be used to create tile maximal shardings.
@@ -350,20 +357,6 @@ Status HloSharding::ValidateNonTuple(const Shape& shape,
"sharding was intended, use HloSharding::Replicated(). If a device "
"placement was intended, use HloSharding::AssignDevice()");
}
-
- // The tile assignment tensor must contain enough element to cover the full
- // shape with tiles of the specified size.
- for (int64 i = 0, e = tile_assignment_.dimensions().size(); i != e; ++i) {
- int64 total_tile_size = tile_assignment_.dim(i) * tile_shape_.dimensions(i);
- if (shape.dimensions(i) > total_tile_size) {
- return tensorflow::errors::InvalidArgument(
- StrCat("Tile assignment tensor has too few element to cover the full "
- "shape. Dimension ",
- i, ", shape ", shape.dimensions(i), ", total size ",
- total_tile_size));
- }
- }
-
return Status::OK();
}
@@ -393,7 +386,7 @@ Status HloSharding::ValidateNonTuple(const Shape& shape,
proto.tile_assignment_dimensions().end()));
std::copy(proto.tile_assignment_devices().begin(),
proto.tile_assignment_devices().end(), tile_assignment.begin());
- return HloSharding(proto.tile_shape(), tile_assignment);
+ return HloSharding(tile_assignment);
}
OpSharding HloSharding::ToProto() const {
@@ -407,7 +400,6 @@ OpSharding HloSharding::ToProto() const {
return result;
}
- *result.mutable_tile_shape() = tile_shape_;
for (int64 dim : tile_assignment_.dimensions()) {
result.add_tile_assignment_dimensions(dim);
}
@@ -424,30 +416,16 @@ OpSharding HloSharding::ToProto() const {
return result;
}
-HloSharding HloSharding::TransformShardedTileShape(
- const Shape& new_shape,
- const std::function<int64(int64, int64)>& transform) const {
- CHECK(!IsTuple());
+Shape HloSharding::TileShape(const Shape& shape) const {
if (IsTileMaximal()) {
- return *this;
+ return shape;
}
- CHECK_EQ(ShapeUtil::Rank(new_shape), ShapeUtil::Rank(tile_shape()));
- Shape new_tile_shape;
- new_tile_shape.set_element_type(tile_shape().element_type());
- for (int64 i = 0; i < ShapeUtil::Rank(new_shape); ++i) {
- int64 dim;
- if (tile_assignment().dim(i) == 1) {
- dim = new_shape.dimensions(i);
- } else if (transform) {
- dim = transform(i, tile_shape().dimensions(i));
- } else {
- dim = tile_shape().dimensions(i);
- }
- new_tile_shape.add_dimensions(dim);
+ Shape result_shape = shape;
+ for (int64 i = 0; i < shape.dimensions_size(); ++i) {
+ (*result_shape.mutable_dimensions())[i] =
+ CeilOfRatio<int64>(shape.dimensions(i), tile_assignment_.dim(i));
}
- TF_CHECK_OK(
- LayoutUtil::CopyLayoutBetweenShapes(tile_shape_, &new_tile_shape));
- return HloSharding::Tile(new_tile_shape, tile_assignment());
+ return result_shape;
}
HloSharding HloSharding::GetSubSharding(const Shape& shape,
@@ -489,9 +467,6 @@ size_t HloSharding::Hash() const {
for (uint32 v : tile_assignment_) {
h = tensorflow::Hash64Combine(h, std::hash<uint32>{}(v));
}
- for (uint32 v : tile_shape_.dimensions()) {
- h = tensorflow::Hash64Combine(h, std::hash<uint32>{}(v));
- }
return h;
}
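
Tile extents are no longer stored; TileOffsetForDevice and TileLimitForDevice now derive them from the shape passed in. A worked example of the new math, using a hypothetical rank-2 shape (values chosen for illustration, not taken from the tests):

  // shape = f32[5,7], tile_assignment dimensions = [2,2]
  // per-tile extent per dimension: CeilOfRatio(5,2) = 3, CeilOfRatio(7,2) = 4
  // for the device at tile index {1,1}:
  //   TileOffsetForDevice(shape, device) = {min(1*3, 5), min(1*4, 7)} = {3, 4}
  //   TileLimitForDevice(shape, device)  = {min(2*3, 5), min(2*4, 7)} = {5, 7}
  // i.e. the last tile in each dimension is clamped to the shape bound.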
diff --git a/tensorflow/compiler/xla/service/hlo_sharding.h b/tensorflow/compiler/xla/service/hlo_sharding.h
index 28575c0e75..894783e5d1 100644
--- a/tensorflow/compiler/xla/service/hlo_sharding.h
+++ b/tensorflow/compiler/xla/service/hlo_sharding.h
@@ -48,22 +48,10 @@ class HloSharding {
// the input shape (one tile) assigned to a single device.
static HloSharding AssignDevice(int64 device_id);
- // Creates a new sharding which splits a shape into tiles each with shape
- // `tile_shape`. Each tile is assigned to one device, which is specified by
- // `tile_assignment`. Any tensor not a multiple of the tile size in any
- // dimension is implicitly padded to the tile size.
- //
- // e.g. Tile({2, 2}, {0, 1}) on a tensor of shape {3, 2} would look like:
- // 2 1 padding
- // <------><->
- // +----+----+
- // | 0 | 1 |
- // +----+----+
- //
- // Split into two tiles, one of which is implicitly padded by one.
- static HloSharding Tile(const Shape& tile_shape,
- const Array<int64>& tile_assignment) {
- return HloSharding(tile_shape, tile_assignment);
+ // Creates a new sharding which splits a shape into tiles amongst the devices
+ // specified by `tile_assignment`.
+ static HloSharding Tile(const Array<int64>& tile_assignment) {
+ return HloSharding(tile_assignment);
}
// Creates a new sharding which splits a one-dimensional input shape into
@@ -146,17 +134,18 @@ class HloSharding {
// REQUIRES: !IsTuple()
int64 DeviceForTileIndex(tensorflow::gtl::ArraySlice<int64> index) const;
- // Given a device ID, returns the offset within the input space of the
+ // Given a device ID, returns the offset within the specified shape of the
// tile that should be executed on the given core. This returns the lower
// extent of the tile in the input space.
// REQUIRES: !IsTuple()
- std::vector<int64> TileOffsetForDevice(int64 device) const;
+ std::vector<int64> TileOffsetForDevice(const Shape& shape,
+ int64 device) const;
- // Given a device ID, returns the limit within the input space of the
+ // Given a device ID, returns the limit within the specified shape of the
// tile that should be executed on the given core. This returns the upper
// extent of the tile in the input space.
// REQUIRES: !IsTuple()
- std::vector<int64> TileLimitForDevice(int64 device) const;
+ std::vector<int64> TileLimitForDevice(const Shape& shape, int64 device) const;
// Returns the single device this op operates on. If the sharding does not
// span a single device, the return value will be empty.
@@ -197,7 +186,6 @@ class HloSharding {
bool operator==(const HloSharding& other) const {
return replicated_ == other.replicated_ && maximal_ == other.maximal_ &&
- ShapeUtil::Compatible(tile_shape_, other.tile_shape_) &&
tile_assignment_ == other.tile_assignment_ &&
tuple_elements_ == other.tuple_elements_;
}
@@ -211,9 +199,6 @@ class HloSharding {
}
};
- // Gets the tile shape.
- // REQUIRES: !IsTileMaximal() && !IsTuple()
- const Shape& tile_shape() const { return tile_shape_; }
// Gets the tile assignment tensor.
// REQUIRES: !IsReplicated() && !IsTuple()
const Array<int64>& tile_assignment() const { return tile_assignment_; }
@@ -225,25 +210,15 @@ class HloSharding {
return tuple_elements_;
}
- // Return a new sharding that can apply to the given new shape.
- // If this sharding is tile-maximal, the returned sharding will be the same as
- // this sharding. If this sharding is not tile-maximal, the returned
- // sharding's tile size will differ:
- // - Non-sharded dimensions will be adapted to be the same as `new_shape`;
- // tile_dimension(i) = new_shape.dimensions(i);
- // - Sharded dimensions will be kept the same unless `transform` is supplied
- // in which case tile_dimension(i) = transform(i, tile_dimension(i));
- // REQUIRES: !IsTuple().
- HloSharding TransformShardedTileShape(
- const Shape& new_shape,
- const std::function<int64(int64, int64)>& transform = nullptr) const;
+ // Gets the tile shape.
+ // REQUIRES: !IsTuple()
+ Shape TileShape(const Shape& shape) const;
private:
HloSharding()
: replicated_(true),
maximal_(true),
tuple_(false),
- tile_shape_(),
tile_assignment_({0}) {}
// device_id values:
// -2: magic number to mean unassigned device, used by spatial partitioning
@@ -255,15 +230,13 @@ class HloSharding {
: replicated_(false),
maximal_(true),
tuple_(false),
- tile_shape_(),
tile_assignment_({1}, device_id) {}
- HloSharding(const Shape& tile_shape, const Array<int64>& tile_assignment)
+ explicit HloSharding(const Array<int64>& tile_assignment)
: replicated_(false),
maximal_(false),
tuple_(false),
- tile_shape_(tile_shape),
tile_assignment_(tile_assignment) {}
- HloSharding(const std::vector<HloSharding>& tuple_shardings)
+ explicit HloSharding(const std::vector<HloSharding>& tuple_shardings)
: replicated_(false),
maximal_(false),
tuple_(true),
@@ -286,7 +259,6 @@ class HloSharding {
bool replicated_;
bool maximal_;
bool tuple_;
- Shape tile_shape_;
Array<int64> tile_assignment_;
// Only non-empty when tuple_ is true, but because empty tuples are allowed
// may also be empty even then. This is a flattened list of all the leaf
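
A minimal sketch of the slimmed-down API end to end, with a hypothetical shape and a 2x2 device grid (assumes the XLA headers; Validate, TileShape, and the offset/limit queries are as declared above):

  Array2D<int64> assignment({{0, 1}, {2, 3}});
  HloSharding sharding = HloSharding::Tile(assignment);
  Shape shape = ShapeUtil::MakeShape(F32, {4, 6});
  TF_CHECK_OK(sharding.Validate(shape, /*num_devices=*/4));
  Shape per_tile = sharding.TileShape(shape);  // f32[2,3]: ceil(4/2) x ceil(6/2)
  auto offset = sharding.TileOffsetForDevice(shape, /*device=*/3);  // {2, 3}
  auto limit = sharding.TileLimitForDevice(shape, /*device=*/3);    // {4, 6}

Note that Validate still rejects a 1x1 tile assignment built this way: maximal shardings must come from AssignDevice() or Replicate(), per the constructor check above.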
diff --git a/tensorflow/compiler/xla/service/hlo_sharding_test.cc b/tensorflow/compiler/xla/service/hlo_sharding_test.cc
index aebda562d3..45fc300fca 100644
--- a/tensorflow/compiler/xla/service/hlo_sharding_test.cc
+++ b/tensorflow/compiler/xla/service/hlo_sharding_test.cc
@@ -39,7 +39,6 @@ Array<int64> MakeArray(tensorflow::gtl::ArraySlice<int64> dimensions,
class HloShardingTest : public HloTestBase {};
TEST_F(HloShardingTest, Replicate) {
- Shape tile_shape = ShapeUtil::MakeShape(U32, {4});
HloSharding sharding = HloSharding::Replicate();
EXPECT_TRUE(sharding.IsReplicated());
EXPECT_TRUE(sharding.IsTileMaximal());
@@ -79,37 +78,22 @@ TEST_F(HloShardingTest, DevicePlacement) {
TEST_F(HloShardingTest, Tile) {
{
// Test should fail because of a duplicate tile assignment.
- Shape tile_shape = ShapeUtil::MakeShape(U32, {2, 3});
- HloSharding sharding =
- HloSharding::Tile(tile_shape, MakeArray({2, 2}, {0, 0, 2, 3}));
+ HloSharding sharding = HloSharding::Tile(MakeArray({2, 2}, {0, 0, 2, 3}));
EXPECT_IS_NOT_OK(sharding.Validate(ShapeUtil::MakeShape(F32, {4, 6}),
/*num_devices=*/4));
}
{
// Test should fail because more devices are used than `num_devices`.
- Shape tile_shape = ShapeUtil::MakeShape(U32, {2, 3});
- HloSharding sharding =
- HloSharding::Tile(tile_shape, MakeArray({2, 2}, {0, 1, 2, 3}));
+ HloSharding sharding = HloSharding::Tile(MakeArray({2, 2}, {0, 1, 2, 3}));
EXPECT_IS_NOT_OK(sharding.Validate(ShapeUtil::MakeShape(U32, {4, 6}),
/*num_devices=*/2));
}
{
- // Test should fail because the total tiled size in dimension 0 is 4 but we
- // have 6 elements along that dimensions.
- Shape tile_shape = ShapeUtil::MakeShape(U32, {2, 3});
- HloSharding sharding =
- HloSharding::Tile(tile_shape, MakeArray({2, 2}, {0, 1, 2, 3}));
- EXPECT_IS_NOT_OK(sharding.Validate(ShapeUtil::MakeShape(F32, {6, 3}),
- /*num_devices=*/4));
- }
-
- {
// Test should pass.
- Shape tile_shape = ShapeUtil::MakeShape(U32, {2, 3});
- HloSharding sharding =
- HloSharding::Tile(tile_shape, MakeArray({2, 2}, {0, 3, 2, 1}));
+ Shape shape = ShapeUtil::MakeShape(U32, {4, 5});
+ HloSharding sharding = HloSharding::Tile(MakeArray({2, 2}, {0, 3, 2, 1}));
EXPECT_IS_OK(sharding.Validate(ShapeUtil::MakeShape(F32, {3, 5}),
/*num_devices=*/5));
@@ -118,10 +102,14 @@ TEST_F(HloShardingTest, Tile) {
EXPECT_EQ(2, sharding.DeviceForTileIndex({1, 0}));
EXPECT_EQ(1, sharding.DeviceForTileIndex({1, 1}));
- EXPECT_EQ(sharding.TileOffsetForDevice(0), (std::vector<int64>{0, 0}));
- EXPECT_EQ(sharding.TileOffsetForDevice(3), (std::vector<int64>{0, 3}));
- EXPECT_EQ(sharding.TileOffsetForDevice(2), (std::vector<int64>{2, 0}));
- EXPECT_EQ(sharding.TileOffsetForDevice(1), (std::vector<int64>{2, 3}));
+ EXPECT_EQ(sharding.TileOffsetForDevice(shape, 0),
+ (std::vector<int64>{0, 0}));
+ EXPECT_EQ(sharding.TileOffsetForDevice(shape, 3),
+ (std::vector<int64>{0, 3}));
+ EXPECT_EQ(sharding.TileOffsetForDevice(shape, 2),
+ (std::vector<int64>{2, 0}));
+ EXPECT_EQ(sharding.TileOffsetForDevice(shape, 1),
+ (std::vector<int64>{2, 3}));
EXPECT_FALSE(sharding.HasUniqueDevice());
}
@@ -135,8 +123,7 @@ TEST_F(HloShardingTest, NestedTuple) {
ShapeUtil::MakeShape(F32, {4, 6}),
});
- HloSharding tiled_sharding = HloSharding::Tile(
- ShapeUtil::MakeShape(F32, {4, 3}), Array<int64>({{0, 1}}));
+ HloSharding tiled_sharding = HloSharding::Tile(Array<int64>({{0, 1}}));
OpSharding proto;
proto.set_type(OpSharding::Type::OpSharding_Type_TUPLE);
*proto.add_tuple_shardings() = HloSharding::Replicate().ToProto();
@@ -187,32 +174,11 @@ TEST_F(HloShardingTest, Hash) {
}
{
- Shape tile_shape = ShapeUtil::MakeShape(U32, {2, 3});
- HloSharding sharding1 =
- HloSharding::Tile(tile_shape, MakeArray({2, 2}, {0, 3, 2, 1}));
- HloSharding sharding2 = HloSharding::Tile(ShapeUtil::MakeShape(U32, {2, 3}),
- MakeArray({2, 2}, {0, 3, 2, 1}));
- EXPECT_TRUE(hash_compare_equal(sharding1, sharding2));
- }
-
- {
- Shape tile_shape = ShapeUtil::MakeShape(U32, {2, 3});
- HloSharding sharding1 =
- HloSharding::Tile(tile_shape, MakeArray({2, 2}, {0, 3, 2, 1}));
- HloSharding sharding2 = HloSharding::Tile(ShapeUtil::MakeShape(U32, {2, 3}),
- MakeArray({2, 2}, {0, 3, 2, 1}));
+ HloSharding sharding1 = HloSharding::Tile(MakeArray({2, 2}, {0, 3, 2, 1}));
+ HloSharding sharding2 = HloSharding::Tile(MakeArray({2, 2}, {0, 3, 2, 1}));
EXPECT_TRUE(hash_compare_equal(sharding1, sharding2));
}
- {
- Shape tile_shape = ShapeUtil::MakeShape(U32, {2, 3});
- HloSharding sharding1 =
- HloSharding::Tile(tile_shape, MakeArray({2, 2}, {0, 3, 2, 1}));
- HloSharding sharding2 = HloSharding::Tile(ShapeUtil::MakeShape(U32, {2, 3}),
- MakeArray({2, 2}, {0, 3, 1, 2}));
- EXPECT_FALSE(hash_compare_equal(sharding1, sharding2));
- }
-
HloSharding default_sharding = HloSharding::Replicate();
{
ShapeTree<HloSharding> shape_tree(ShapeUtil::MakeTupleShape({}),
@@ -259,19 +225,6 @@ TEST_F(HloShardingTest, Hash) {
}
}
-TEST_F(HloShardingTest, TransformShardedTileShapeTest) {
- HloSharding sharding =
- HloSharding::Tile(ShapeUtil::MakeShape(F32, {3, 5, 7, 11}),
- Array4D<int64>({{{{0, 1}, {2, 3}}}}));
- HloSharding result = sharding.TransformShardedTileShape(
- ShapeUtil::MakeShape(F32, {13, 15, 17, 19}),
- [](int dim, int value) { return dim * 111; });
- HloSharding expected =
- HloSharding::Tile(ShapeUtil::MakeShape(F32, {13, 15, 222, 333}),
- Array4D<int64>({{{{0, 1}, {2, 3}}}}));
- EXPECT_EQ(result, expected);
-}
-
TEST_F(HloShardingTest, ToStringReplicatedTest) {
HloSharding sharding = HloSharding::Replicate();
EXPECT_EQ(sharding.ToString(), "{replicated}");
@@ -284,9 +237,8 @@ TEST_F(HloShardingTest, ToStringAssignDeviceTest) {
TEST_F(HloShardingTest, ToStringTiledTest) {
HloSharding sharding =
- HloSharding::Tile(ShapeUtil::MakeShape(S32, {7, 11, 13}),
- Array3D<int64>({{{2, 3}}, {{5, 7}}}));
- EXPECT_EQ(sharding.ToString(), "{s32[7,11,13] devices=[2,1,2]2,3,5,7}");
+ HloSharding::Tile(Array3D<int64>({{{2, 3}}, {{5, 7}}}));
+ EXPECT_EQ(sharding.ToString(), "{devices=[2,1,2]2,3,5,7}");
}
TEST_F(HloShardingTest, ToStringTupleTest) {
@@ -294,21 +246,18 @@ TEST_F(HloShardingTest, ToStringTupleTest) {
ShapeUtil::MakeTupleShape({ShapeUtil::MakeShape(F32, {3, 5}),
ShapeUtil::MakeShape(U32, {7, 25}),
ShapeUtil::MakeShape(S32, {9, 11})}),
- {HloSharding::Replicate(),
- HloSharding::Tile(ShapeUtil::MakeShape(U32, {7, 13}),
- Array2D<int64>({{3, 5}})),
+ {HloSharding::Replicate(), HloSharding::Tile(Array2D<int64>({{3, 5}})),
HloSharding::AssignDevice(3)});
EXPECT_EQ(sharding.ToString(),
- "{{replicated}, {u32[7,13] devices=[1,2]3,5}, {maximal device=3}}");
+ "{{replicated}, {devices=[1,2]3,5}, {maximal device=3}}");
}
TEST_F(HloShardingTest, OstreamTest) {
HloSharding sharding =
- HloSharding::Tile(ShapeUtil::MakeShape(F32, {3, 5, 7, 11}),
- Array4D<int64>({{{{0, 1}, {2, 3}}}}));
+ HloSharding::Tile(Array4D<int64>({{{{0, 1}, {2, 3}}}}));
std::ostringstream oss;
oss << sharding;
- EXPECT_EQ(oss.str(), "{f32[3,5,7,11] devices=[1,1,2,2]0,1,2,3}");
+ EXPECT_EQ(oss.str(), "{devices=[1,1,2,2]0,1,2,3}");
}
TEST_F(HloShardingTest, ParseHloString) {
@@ -319,8 +268,7 @@ TEST_F(HloShardingTest, ParseHloString) {
};
check(HloSharding::Replicate());
check(HloSharding::AssignDevice(2));
- check(HloSharding::Tile(ShapeUtil::MakeShape(F32, {3, 1, 3, 7}),
- Array4D<int64>({{{{0}, {1}}}})));
+ check(HloSharding::Tile(Array4D<int64>({{{{0}, {1}}}})));
// Empty tuple. One sharding is required for empty tuples, as we need to be
// able to assign sharding to them, even though they have no leaves.
check(HloSharding::Tuple(ShapeUtil::MakeTupleShape({}),
@@ -332,8 +280,7 @@ TEST_F(HloShardingTest, ParseHloString) {
ShapeUtil::MakeShape(F32, {3, 5, 7}),
ShapeUtil::MakeShape(F32, {3, 7})});
check(HloSharding::Tuple(
- tuple_shape, {HloSharding::Tile(ShapeUtil::MakeShape(F32, {3, 1, 3, 7}),
- Array4D<int64>({{{{0}, {1}}}})),
+ tuple_shape, {HloSharding::Tile(Array4D<int64>({{{{0}, {1}}}})),
HloSharding::Replicate(), HloSharding::AssignDevice(1)}));
}
{
@@ -343,8 +290,7 @@ TEST_F(HloShardingTest, ParseHloString) {
ShapeUtil::MakeTupleShape({ShapeUtil::MakeShape(F32, {3, 5, 7}),
ShapeUtil::MakeShape(F32, {3, 7})})});
std::vector<HloSharding> leaf_shardings = {
- HloSharding::Tile(ShapeUtil::MakeShape(F32, {3, 1, 3, 7}),
- Array4D<int64>({{{{0}, {1}}}})),
+ HloSharding::Tile(Array4D<int64>({{{{0}, {1}}}})),
HloSharding::Replicate(), HloSharding::AssignDevice(1)};
ShapeTree<HloSharding> sharding_tree(tuple_shape, HloSharding::Replicate());
// Assign leaf_shardings to sharding_tree leaves.