author    A. Unique TensorFlower <gardener@tensorflow.org>  2018-08-08 03:28:34 -0700
committer TensorFlower Gardener <gardener@tensorflow.org>   2018-08-08 03:32:42 -0700
commit    15f1fa14a3fd4b63e18539836f6036fef024fce7 (patch)
tree      c18d09d0b8c450d129a775b7804de855f2c82bb4 /tensorflow
parent    de537122fbd1a49a44bd71e3a24c7b4d4d23c24c (diff)
Remove tile shape from HloSharding
The tile shape can be deduced from the tile assignment and the HLO shape. By not storing it in the sharding, we give the compiler more flexibility to decide the data layout.

PiperOrigin-RevId: 207860794
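As a rough illustration of the deduction described above (the operand shape and device grid below are made up for the example, not taken from this commit):

    // Sketch: the per-tile shape is recoverable from the HLO shape plus the
    // tile assignment alone. For a 2x2 device grid over an f32[4,7] operand,
    // ceil(4/2) = 2 and ceil(7/2) = 4, so each tile covers f32[2,4].
    Array<int64> assignment({2, 2});
    std::iota(assignment.begin(), assignment.end(), 0);  // devices 0,1,2,3
    HloSharding sharding = HloSharding::Tile(assignment);
    Shape tile = sharding.TileShape(ShapeUtil::MakeShape(F32, {4, 7}));
    // tile is f32[2,4]; nothing about the tile shape is stored in the sharding.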
Diffstat (limited to 'tensorflow')
-rw-r--r--  tensorflow/compiler/xla/service/hlo_matchers_test.cc  |   8
-rw-r--r--  tensorflow/compiler/xla/service/hlo_parser.cc         |  15
-rw-r--r--  tensorflow/compiler/xla/service/hlo_sharding.cc       |  97
-rw-r--r--  tensorflow/compiler/xla/service/hlo_sharding.h        |  56
-rw-r--r--  tensorflow/compiler/xla/service/hlo_sharding_test.cc  | 102
5 files changed, 79 insertions, 199 deletions
diff --git a/tensorflow/compiler/xla/service/hlo_matchers_test.cc b/tensorflow/compiler/xla/service/hlo_matchers_test.cc
index 7de59acc1e..7961aece54 100644
--- a/tensorflow/compiler/xla/service/hlo_matchers_test.cc
+++ b/tensorflow/compiler/xla/service/hlo_matchers_test.cc
@@ -157,9 +157,8 @@ TEST(HloMatchersTest, ShardingMatcher) {
Array<int64> assignment({2});
assignment.SetValues({0, 1});
auto sharding = HloSharding::Tuple(
- tuple_shape,
- {HloSharding::Tile(ShapeUtil::MakeShape(F32, {5}), assignment),
- HloSharding::AssignDevice(1), HloSharding::Replicate()});
+ tuple_shape, {HloSharding::Tile(assignment), HloSharding::AssignDevice(1),
+ HloSharding::Replicate()});
p2->set_sharding(sharding);
EXPECT_THAT(p0.get(), op::NoSharding());
@@ -172,8 +171,7 @@ TEST(HloMatchersTest, ShardingMatcher) {
EXPECT_THAT(
p2.get(),
- op::Sharding(
- "{{f32[5] devices=[2]0,1}, {maximal device=1}, {replicated}}"));
+ op::Sharding("{{devices=[2]0,1}, {maximal device=1}, {replicated}}"));
EXPECT_THAT(Explain(p0.get(), op::Sharding(HloSharding::AssignDevice(1))),
"%param.0 = f32[5]{0} parameter(0) has no sharding (expected: "
diff --git a/tensorflow/compiler/xla/service/hlo_parser.cc b/tensorflow/compiler/xla/service/hlo_parser.cc
index 93cc884e3a..de73b38dec 100644
--- a/tensorflow/compiler/xla/service/hlo_parser.cc
+++ b/tensorflow/compiler/xla/service/hlo_parser.cc
@@ -1383,7 +1383,6 @@ bool HloParser::ParseSingleSharding(OpSharding* sharding,
bool replicated = false;
std::vector<tensorflow::int64> devices;
std::vector<tensorflow::int64> tile_assignment_dimensions;
- Shape tile_shape;
while (lexer_.GetKind() != TokKind::kRbrace) {
switch (lexer_.GetKind()) {
case TokKind::kw_maximal:
@@ -1434,7 +1433,8 @@ bool HloParser::ParseSingleSharding(OpSharding* sharding,
break;
}
case TokKind::kShape:
- tile_shape = lexer_.GetShapeVal();
+ // TODO(b/112302613): Left here for backward compatibility to ignore the
+ // removed tile shape data.
lexer_.Lex();
break;
case TokKind::kRbrace:
@@ -1449,19 +1449,12 @@ bool HloParser::ParseSingleSharding(OpSharding* sharding,
return Error(loc,
"replicated shardings should not have any devices assigned");
}
- if (!ShapeUtil::Equal(tile_shape, Shape())) {
- return Error(loc,
- "replicated shardings should not have any tile shape set");
- }
sharding->set_type(OpSharding::Type::OpSharding_Type_REPLICATED);
} else if (maximal) {
if (devices.size() != 1) {
return Error(loc,
"maximal shardings should have exactly one device assigned");
}
- if (!ShapeUtil::Equal(tile_shape, Shape())) {
- return Error(loc, "maximal shardings should not have any tile shape set");
- }
sharding->set_type(OpSharding::Type::OpSharding_Type_MAXIMAL);
sharding->add_tile_assignment_devices(devices[0]);
} else {
@@ -1469,9 +1462,6 @@ bool HloParser::ParseSingleSharding(OpSharding* sharding,
return Error(
loc, "non-maximal shardings must have more than one device assigned");
}
- if (ShapeUtil::Equal(tile_shape, Shape())) {
- return Error(loc, "non-maximal shardings should have a tile shape set");
- }
if (tile_assignment_dimensions.empty()) {
return Error(
loc,
@@ -1479,7 +1469,6 @@ bool HloParser::ParseSingleSharding(OpSharding* sharding,
"dimensions");
}
sharding->set_type(OpSharding::Type::OpSharding_Type_OTHER);
- *sharding->mutable_tile_shape() = tile_shape;
for (tensorflow::int64 dim : tile_assignment_dimensions) {
sharding->add_tile_assignment_dimensions(dim);
}
diff --git a/tensorflow/compiler/xla/service/hlo_sharding.cc b/tensorflow/compiler/xla/service/hlo_sharding.cc
index 6399f6ef3c..879fb3bbab 100644
--- a/tensorflow/compiler/xla/service/hlo_sharding.cc
+++ b/tensorflow/compiler/xla/service/hlo_sharding.cc
@@ -31,12 +31,9 @@ HloSharding HloSharding::Tile1D(const Shape& input_shape, int64 num_tiles) {
CHECK_EQ(1, ShapeUtil::Rank(input_shape));
CHECK_GT(num_tiles, 1);
std::vector<int64> dimensions(1, num_tiles);
- Shape tile_shape = input_shape;
- auto& tile_dimension = (*tile_shape.mutable_dimensions())[0];
- tile_dimension = CeilOfRatio(static_cast<int64>(tile_dimension), num_tiles);
Array<int64> assignment(dimensions);
std::iota(assignment.begin(), assignment.end(), 0);
- return HloSharding(tile_shape, assignment);
+ return HloSharding(assignment);
}
HloSharding HloSharding::Tuple(const ShapeTree<HloSharding>& sub_shardings) {
@@ -104,8 +101,7 @@ string HloSharding::ToString() const {
return StrCat(
"{maximal device=", static_cast<int64>(*tile_assignment_.begin()), "}");
} else {
- return StrCat("{", ShapeUtil::HumanString(tile_shape_), " ", "devices=[",
- Join(tile_assignment_.dimensions(), ","), "]",
+ return StrCat("{devices=[", Join(tile_assignment_.dimensions(), ","), "]",
Join(tile_assignment_, ","), "}");
}
}
@@ -145,7 +141,6 @@ std::map<int64, int64> HloSharding::UsedDevices(int64* count) const {
}
std::vector<int64> HloSharding::TileIndexForDevice(int64 device) const {
- CHECK(!ShapeUtil::IsTuple(tile_shape_));
CHECK(!maximal_);
CHECK(!IsTuple());
std::vector<int64> ret_index;
@@ -165,32 +160,43 @@ int64 HloSharding::DeviceForTileIndex(
if (maximal_) {
return *tile_assignment_.begin();
}
- CHECK_EQ(ShapeUtil::Rank(tile_shape_), tile_assignment_.dimensions().size());
return tile_assignment_(index);
}
-std::vector<int64> HloSharding::TileOffsetForDevice(int64 device) const {
+std::vector<int64> HloSharding::TileOffsetForDevice(const Shape& shape,
+ int64 device) const {
CHECK(!IsTuple());
- std::vector<int64> index = TileIndexForDevice(device);
if (maximal_) {
- // Index will always be all zeroes if we're maximal, and tile_shape_ is not
- // valid.
- return index;
+ return std::vector<int64>(shape.dimensions_size(), 0);
}
+
+ CHECK_EQ(shape.dimensions_size(), tile_assignment_.num_dimensions());
+ std::vector<int64> index = TileIndexForDevice(device);
for (int64 i = 0; i < index.size(); ++i) {
- index[i] *= tile_shape_.dimensions(i);
+ const int64 shape_dim = shape.dimensions(i);
+ index[i] = std::min(
+ index[i] * CeilOfRatio(shape_dim, tile_assignment_.dim(i)), shape_dim);
}
return index;
}
-std::vector<int64> HloSharding::TileLimitForDevice(int64 device) const {
+std::vector<int64> HloSharding::TileLimitForDevice(const Shape& shape,
+ int64 device) const {
CHECK(!IsTuple());
- CHECK(!maximal_); // Maximal shardings do not have a valid tile shape.
+ if (maximal_) {
+ return std::vector<int64>(shape.dimensions().begin(),
+ shape.dimensions().end());
+ }
+
+ CHECK_EQ(shape.dimensions_size(), tile_assignment_.num_dimensions());
std::vector<int64> index = TileIndexForDevice(device);
for (int64 i = 0; i < index.size(); ++i) {
- index[i] = (index[i] + 1) * tile_shape_.dimensions(i);
+ const int64 shape_dim = shape.dimensions(i);
+ index[i] = std::min(
+ (index[i] + 1) * CeilOfRatio(shape_dim, tile_assignment_.dim(i)),
+ shape_dim);
}
return index;
}
@@ -336,11 +342,12 @@ Status HloSharding::ValidateNonTuple(const Shape& shape,
return Status::OK();
}
- // The tile rank must be the same as the input rank.
- if (ShapeUtil::Rank(shape) != ShapeUtil::Rank(tile_shape_)) {
+ // The tile assignment tensor must have the same rank as the input.
+ if (ShapeUtil::Rank(shape) != tile_assignment_.num_dimensions()) {
return tensorflow::errors::InvalidArgument(
- "Tile rank is different to the input rank. sharding=", ToString(),
- ", input_shape=", ShapeUtil::HumanString(shape));
+ "Number of tile assignment dimensions is different to the input rank. "
+ "sharding=",
+ ToString(), ", input_shape=", ShapeUtil::HumanString(shape));
}
// The correct constructor have to be used to create tile maximal shardings.
@@ -350,20 +357,6 @@ Status HloSharding::ValidateNonTuple(const Shape& shape,
"sharding was intended, use HloSharding::Replicated(). If a device "
"placement was intended, use HloSharding::AssignDevice()");
}
-
- // The tile assignment tensor must contain enough element to cover the full
- // shape with tiles of the specified size.
- for (int64 i = 0, e = tile_assignment_.dimensions().size(); i != e; ++i) {
- int64 total_tile_size = tile_assignment_.dim(i) * tile_shape_.dimensions(i);
- if (shape.dimensions(i) > total_tile_size) {
- return tensorflow::errors::InvalidArgument(
- StrCat("Tile assignment tensor has too few element to cover the full "
- "shape. Dimension ",
- i, ", shape ", shape.dimensions(i), ", total size ",
- total_tile_size));
- }
- }
-
return Status::OK();
}
@@ -393,7 +386,7 @@ Status HloSharding::ValidateNonTuple(const Shape& shape,
proto.tile_assignment_dimensions().end()));
std::copy(proto.tile_assignment_devices().begin(),
proto.tile_assignment_devices().end(), tile_assignment.begin());
- return HloSharding(proto.tile_shape(), tile_assignment);
+ return HloSharding(tile_assignment);
}
OpSharding HloSharding::ToProto() const {
@@ -407,7 +400,6 @@ OpSharding HloSharding::ToProto() const {
return result;
}
- *result.mutable_tile_shape() = tile_shape_;
for (int64 dim : tile_assignment_.dimensions()) {
result.add_tile_assignment_dimensions(dim);
}
@@ -424,30 +416,16 @@ OpSharding HloSharding::ToProto() const {
return result;
}
-HloSharding HloSharding::TransformShardedTileShape(
- const Shape& new_shape,
- const std::function<int64(int64, int64)>& transform) const {
- CHECK(!IsTuple());
+Shape HloSharding::TileShape(const Shape& shape) const {
if (IsTileMaximal()) {
- return *this;
+ return shape;
}
- CHECK_EQ(ShapeUtil::Rank(new_shape), ShapeUtil::Rank(tile_shape()));
- Shape new_tile_shape;
- new_tile_shape.set_element_type(tile_shape().element_type());
- for (int64 i = 0; i < ShapeUtil::Rank(new_shape); ++i) {
- int64 dim;
- if (tile_assignment().dim(i) == 1) {
- dim = new_shape.dimensions(i);
- } else if (transform) {
- dim = transform(i, tile_shape().dimensions(i));
- } else {
- dim = tile_shape().dimensions(i);
- }
- new_tile_shape.add_dimensions(dim);
+ Shape result_shape = shape;
+ for (int64 i = 0; i < shape.dimensions_size(); ++i) {
+ (*result_shape.mutable_dimensions())[i] =
+ CeilOfRatio<int64>(shape.dimensions(i), tile_assignment_.dim(i));
}
- TF_CHECK_OK(
- LayoutUtil::CopyLayoutBetweenShapes(tile_shape_, &new_tile_shape));
- return HloSharding::Tile(new_tile_shape, tile_assignment());
+ return result_shape;
}
HloSharding HloSharding::GetSubSharding(const Shape& shape,
@@ -489,9 +467,6 @@ size_t HloSharding::Hash() const {
for (uint32 v : tile_assignment_) {
h = tensorflow::Hash64Combine(h, std::hash<uint32>{}(v));
}
- for (uint32 v : tile_shape_.dimensions()) {
- h = tensorflow::Hash64Combine(h, std::hash<uint32>{}(v));
- }
return h;
}
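A small worked sketch of the offset/limit computations above (the operand shape is hypothetical):

    // With a [2,2] tile assignment {{0,1},{2,3}} over an f32[5,6] operand,
    // the per-tile extent is ceil(5/2)=3 by ceil(6/2)=3, so:
    //   device 0 -> offset {0,0}, limit {3,3}
    //   device 1 -> offset {0,3}, limit {3,6}
    //   device 2 -> offset {3,0}, limit {5,3}   // limit clamped to dim size 5
    //   device 3 -> offset {3,3}, limit {5,6}
    Array2D<int64> assignment({{0, 1}, {2, 3}});
    HloSharding sharding = HloSharding::Tile(assignment);
    Shape shape = ShapeUtil::MakeShape(F32, {5, 6});
    std::vector<int64> off = sharding.TileOffsetForDevice(shape, /*device=*/2);  // {3, 0}
    std::vector<int64> lim = sharding.TileLimitForDevice(shape, /*device=*/2);   // {5, 3}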
diff --git a/tensorflow/compiler/xla/service/hlo_sharding.h b/tensorflow/compiler/xla/service/hlo_sharding.h
index 28575c0e75..894783e5d1 100644
--- a/tensorflow/compiler/xla/service/hlo_sharding.h
+++ b/tensorflow/compiler/xla/service/hlo_sharding.h
@@ -48,22 +48,10 @@ class HloSharding {
// the input shape (one tile) assigned to a single device.
static HloSharding AssignDevice(int64 device_id);
- // Creates a new sharding which splits a shape into tiles each with shape
- // `tile_shape`. Each tile is assigned to one device, which is specified by
- // `tile_assignment`. Any tensor not a multiple of the tile size in any
- // dimension is implicitly padded to the tile size.
- //
- // e.g. Tile({2, 2}, {0, 1}) on a tensor of shape {3, 2} would look like:
- // 2 1 padding
- // <------><->
- // +----+----+
- // | 0 | 1 |
- // +----+----+
- //
- // Split into two tiles, one of which is implicitly padded by one.
- static HloSharding Tile(const Shape& tile_shape,
- const Array<int64>& tile_assignment) {
- return HloSharding(tile_shape, tile_assignment);
+ // Creates a new sharding which splits a shape into tiles amongst the devices
+ // specified by `tile_assignment`.
+ static HloSharding Tile(const Array<int64>& tile_assignment) {
+ return HloSharding(tile_assignment);
}
// Creates a new sharding which splits a one-dimensional input shape into
@@ -146,17 +134,18 @@ class HloSharding {
// REQUIRES: !IsTuple()
int64 DeviceForTileIndex(tensorflow::gtl::ArraySlice<int64> index) const;
- // Given a device ID, returns the offset within the input space of the
+ // Given a device ID, returns the offset within the specified shape of the
// tile that should be executed on the given core. This returns the lower
// extent of the tile in the input space.
// REQUIRES: !IsTuple()
- std::vector<int64> TileOffsetForDevice(int64 device) const;
+ std::vector<int64> TileOffsetForDevice(const Shape& shape,
+ int64 device) const;
- // Given a device ID, returns the limit within the input space of the
+ // Given a device ID, returns the limit within the specified shape of the
// tile that should be executed on the given core. This returns the upper
// extent of the tile in the input space.
// REQUIRES: !IsTuple()
- std::vector<int64> TileLimitForDevice(int64 device) const;
+ std::vector<int64> TileLimitForDevice(const Shape& shape, int64 device) const;
// Returns the single device this op operates on. If the sharding does not
// span a single device, the return value will be empty.
@@ -197,7 +186,6 @@ class HloSharding {
bool operator==(const HloSharding& other) const {
return replicated_ == other.replicated_ && maximal_ == other.maximal_ &&
- ShapeUtil::Compatible(tile_shape_, other.tile_shape_) &&
tile_assignment_ == other.tile_assignment_ &&
tuple_elements_ == other.tuple_elements_;
}
@@ -211,9 +199,6 @@ class HloSharding {
}
};
- // Gets the tile shape.
- // REQUIRES: !IsTileMaximal() && !IsTuple()
- const Shape& tile_shape() const { return tile_shape_; }
// Gets the tile assignment tensor.
// REQUIRES: !IsReplicated() && !IsTuple()
const Array<int64>& tile_assignment() const { return tile_assignment_; }
@@ -225,25 +210,15 @@ class HloSharding {
return tuple_elements_;
}
- // Return a new sharding that can apply to the given new shape.
- // If this sharding is tile-maximal, the returned sharding will be the same as
- // this sharding. If this sharding is not tile-maximal, the returned
- // sharding's tile size will differ:
- // - Non-sharded dimensions will be adapted to be the same as `new_shape`;
- // tile_dimension(i) = new_shape.dimensions(i);
- // - Sharded dimensions will be kept the same unless `transform` is supplied
- // in which case tile_dimension(i) = transform(i, tile_dimension(i));
- // REQUIRES: !IsTuple().
- HloSharding TransformShardedTileShape(
- const Shape& new_shape,
- const std::function<int64(int64, int64)>& transform = nullptr) const;
+ // Gets the tile shape.
+ // REQUIRES: !IsTuple()
+ Shape TileShape(const Shape& shape) const;
private:
HloSharding()
: replicated_(true),
maximal_(true),
tuple_(false),
- tile_shape_(),
tile_assignment_({0}) {}
// device_id values:
// -2: magic number to mean unassigned device, used by spatial partitioning
@@ -255,15 +230,13 @@ class HloSharding {
: replicated_(false),
maximal_(true),
tuple_(false),
- tile_shape_(),
tile_assignment_({1}, device_id) {}
- HloSharding(const Shape& tile_shape, const Array<int64>& tile_assignment)
+ explicit HloSharding(const Array<int64>& tile_assignment)
: replicated_(false),
maximal_(false),
tuple_(false),
- tile_shape_(tile_shape),
tile_assignment_(tile_assignment) {}
- HloSharding(const std::vector<HloSharding>& tuple_shardings)
+ explicit HloSharding(const std::vector<HloSharding>& tuple_shardings)
: replicated_(false),
maximal_(false),
tuple_(true),
@@ -286,7 +259,6 @@ class HloSharding {
bool replicated_;
bool maximal_;
bool tuple_;
- Shape tile_shape_;
Array<int64> tile_assignment_;
// Only non-empty when tuple_ is true, but because empty tuples are allowed
// may also be empty even then. This is a flattened list of all the leaf
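A rough sketch of how a caller of the removed tile_shape()/TransformShardedTileShape API might now obtain the same information (the operand shape below is illustrative):

    // Before this change the tile shape was read back from the sharding:
    //   Shape tile = sharding.tile_shape();
    // Now it is recomputed on demand from the operand's shape:
    Shape operand = ShapeUtil::MakeShape(F32, {13, 15, 17, 19});  // example shape
    Shape tile = sharding.TileShape(operand);
    // Per dimension, tile.dimensions(i) ==
    //     CeilOfRatio(operand.dimensions(i), sharding.tile_assignment().dim(i)).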
diff --git a/tensorflow/compiler/xla/service/hlo_sharding_test.cc b/tensorflow/compiler/xla/service/hlo_sharding_test.cc
index aebda562d3..45fc300fca 100644
--- a/tensorflow/compiler/xla/service/hlo_sharding_test.cc
+++ b/tensorflow/compiler/xla/service/hlo_sharding_test.cc
@@ -39,7 +39,6 @@ Array<int64> MakeArray(tensorflow::gtl::ArraySlice<int64> dimensions,
class HloShardingTest : public HloTestBase {};
TEST_F(HloShardingTest, Replicate) {
- Shape tile_shape = ShapeUtil::MakeShape(U32, {4});
HloSharding sharding = HloSharding::Replicate();
EXPECT_TRUE(sharding.IsReplicated());
EXPECT_TRUE(sharding.IsTileMaximal());
@@ -79,37 +78,22 @@ TEST_F(HloShardingTest, DevicePlacement) {
TEST_F(HloShardingTest, Tile) {
{
// Test should fail because of a duplicate tile assignment.
- Shape tile_shape = ShapeUtil::MakeShape(U32, {2, 3});
- HloSharding sharding =
- HloSharding::Tile(tile_shape, MakeArray({2, 2}, {0, 0, 2, 3}));
+ HloSharding sharding = HloSharding::Tile(MakeArray({2, 2}, {0, 0, 2, 3}));
EXPECT_IS_NOT_OK(sharding.Validate(ShapeUtil::MakeShape(F32, {4, 6}),
/*num_devices=*/4));
}
{
// Test should fail because of more devices used then `num_device`.
- Shape tile_shape = ShapeUtil::MakeShape(U32, {2, 3});
- HloSharding sharding =
- HloSharding::Tile(tile_shape, MakeArray({2, 2}, {0, 1, 2, 3}));
+ HloSharding sharding = HloSharding::Tile(MakeArray({2, 2}, {0, 1, 2, 3}));
EXPECT_IS_NOT_OK(sharding.Validate(ShapeUtil::MakeShape(U32, {4, 6}),
/*num_devices=*/2));
}
{
- // Test should fail because the total tiled size in dimension 0 is 4 but we
- // have 6 elements along that dimensions.
- Shape tile_shape = ShapeUtil::MakeShape(U32, {2, 3});
- HloSharding sharding =
- HloSharding::Tile(tile_shape, MakeArray({2, 2}, {0, 1, 2, 3}));
- EXPECT_IS_NOT_OK(sharding.Validate(ShapeUtil::MakeShape(F32, {6, 3}),
- /*num_devices=*/4));
- }
-
- {
// Test should pass.
- Shape tile_shape = ShapeUtil::MakeShape(U32, {2, 3});
- HloSharding sharding =
- HloSharding::Tile(tile_shape, MakeArray({2, 2}, {0, 3, 2, 1}));
+ Shape shape = ShapeUtil::MakeShape(U32, {4, 5});
+ HloSharding sharding = HloSharding::Tile(MakeArray({2, 2}, {0, 3, 2, 1}));
EXPECT_IS_OK(sharding.Validate(ShapeUtil::MakeShape(F32, {3, 5}),
/*num_devices=*/5));
@@ -118,10 +102,14 @@ TEST_F(HloShardingTest, Tile) {
EXPECT_EQ(2, sharding.DeviceForTileIndex({1, 0}));
EXPECT_EQ(1, sharding.DeviceForTileIndex({1, 1}));
- EXPECT_EQ(sharding.TileOffsetForDevice(0), (std::vector<int64>{0, 0}));
- EXPECT_EQ(sharding.TileOffsetForDevice(3), (std::vector<int64>{0, 3}));
- EXPECT_EQ(sharding.TileOffsetForDevice(2), (std::vector<int64>{2, 0}));
- EXPECT_EQ(sharding.TileOffsetForDevice(1), (std::vector<int64>{2, 3}));
+ EXPECT_EQ(sharding.TileOffsetForDevice(shape, 0),
+ (std::vector<int64>{0, 0}));
+ EXPECT_EQ(sharding.TileOffsetForDevice(shape, 3),
+ (std::vector<int64>{0, 3}));
+ EXPECT_EQ(sharding.TileOffsetForDevice(shape, 2),
+ (std::vector<int64>{2, 0}));
+ EXPECT_EQ(sharding.TileOffsetForDevice(shape, 1),
+ (std::vector<int64>{2, 3}));
EXPECT_FALSE(sharding.HasUniqueDevice());
}
@@ -135,8 +123,7 @@ TEST_F(HloShardingTest, NestedTuple) {
ShapeUtil::MakeShape(F32, {4, 6}),
});
- HloSharding tiled_sharding = HloSharding::Tile(
- ShapeUtil::MakeShape(F32, {4, 3}), Array<int64>({{0, 1}}));
+ HloSharding tiled_sharding = HloSharding::Tile(Array<int64>({{0, 1}}));
OpSharding proto;
proto.set_type(OpSharding::Type::OpSharding_Type_TUPLE);
*proto.add_tuple_shardings() = HloSharding::Replicate().ToProto();
@@ -187,32 +174,11 @@ TEST_F(HloShardingTest, Hash) {
}
{
- Shape tile_shape = ShapeUtil::MakeShape(U32, {2, 3});
- HloSharding sharding1 =
- HloSharding::Tile(tile_shape, MakeArray({2, 2}, {0, 3, 2, 1}));
- HloSharding sharding2 = HloSharding::Tile(ShapeUtil::MakeShape(U32, {2, 3}),
- MakeArray({2, 2}, {0, 3, 2, 1}));
- EXPECT_TRUE(hash_compare_equal(sharding1, sharding2));
- }
-
- {
- Shape tile_shape = ShapeUtil::MakeShape(U32, {2, 3});
- HloSharding sharding1 =
- HloSharding::Tile(tile_shape, MakeArray({2, 2}, {0, 3, 2, 1}));
- HloSharding sharding2 = HloSharding::Tile(ShapeUtil::MakeShape(U32, {2, 3}),
- MakeArray({2, 2}, {0, 3, 2, 1}));
+ HloSharding sharding1 = HloSharding::Tile(MakeArray({2, 2}, {0, 3, 2, 1}));
+ HloSharding sharding2 = HloSharding::Tile(MakeArray({2, 2}, {0, 3, 2, 1}));
EXPECT_TRUE(hash_compare_equal(sharding1, sharding2));
}
- {
- Shape tile_shape = ShapeUtil::MakeShape(U32, {2, 3});
- HloSharding sharding1 =
- HloSharding::Tile(tile_shape, MakeArray({2, 2}, {0, 3, 2, 1}));
- HloSharding sharding2 = HloSharding::Tile(ShapeUtil::MakeShape(U32, {2, 3}),
- MakeArray({2, 2}, {0, 3, 1, 2}));
- EXPECT_FALSE(hash_compare_equal(sharding1, sharding2));
- }
-
HloSharding default_sharding = HloSharding::Replicate();
{
ShapeTree<HloSharding> shape_tree(ShapeUtil::MakeTupleShape({}),
@@ -259,19 +225,6 @@ TEST_F(HloShardingTest, Hash) {
}
}
-TEST_F(HloShardingTest, TransformShardedTileShapeTest) {
- HloSharding sharding =
- HloSharding::Tile(ShapeUtil::MakeShape(F32, {3, 5, 7, 11}),
- Array4D<int64>({{{{0, 1}, {2, 3}}}}));
- HloSharding result = sharding.TransformShardedTileShape(
- ShapeUtil::MakeShape(F32, {13, 15, 17, 19}),
- [](int dim, int value) { return dim * 111; });
- HloSharding expected =
- HloSharding::Tile(ShapeUtil::MakeShape(F32, {13, 15, 222, 333}),
- Array4D<int64>({{{{0, 1}, {2, 3}}}}));
- EXPECT_EQ(result, expected);
-}
-
TEST_F(HloShardingTest, ToStringReplicatedTest) {
HloSharding sharding = HloSharding::Replicate();
EXPECT_EQ(sharding.ToString(), "{replicated}");
@@ -284,9 +237,8 @@ TEST_F(HloShardingTest, ToStringAssignDeviceTest) {
TEST_F(HloShardingTest, ToStringTiledTest) {
HloSharding sharding =
- HloSharding::Tile(ShapeUtil::MakeShape(S32, {7, 11, 13}),
- Array3D<int64>({{{2, 3}}, {{5, 7}}}));
- EXPECT_EQ(sharding.ToString(), "{s32[7,11,13] devices=[2,1,2]2,3,5,7}");
+ HloSharding::Tile(Array3D<int64>({{{2, 3}}, {{5, 7}}}));
+ EXPECT_EQ(sharding.ToString(), "{devices=[2,1,2]2,3,5,7}");
}
TEST_F(HloShardingTest, ToStringTupleTest) {
@@ -294,21 +246,18 @@ TEST_F(HloShardingTest, ToStringTupleTest) {
ShapeUtil::MakeTupleShape({ShapeUtil::MakeShape(F32, {3, 5}),
ShapeUtil::MakeShape(U32, {7, 25}),
ShapeUtil::MakeShape(S32, {9, 11})}),
- {HloSharding::Replicate(),
- HloSharding::Tile(ShapeUtil::MakeShape(U32, {7, 13}),
- Array2D<int64>({{3, 5}})),
+ {HloSharding::Replicate(), HloSharding::Tile(Array2D<int64>({{3, 5}})),
HloSharding::AssignDevice(3)});
EXPECT_EQ(sharding.ToString(),
- "{{replicated}, {u32[7,13] devices=[1,2]3,5}, {maximal device=3}}");
+ "{{replicated}, {devices=[1,2]3,5}, {maximal device=3}}");
}
TEST_F(HloShardingTest, OstreamTest) {
HloSharding sharding =
- HloSharding::Tile(ShapeUtil::MakeShape(F32, {3, 5, 7, 11}),
- Array4D<int64>({{{{0, 1}, {2, 3}}}}));
+ HloSharding::Tile(Array4D<int64>({{{{0, 1}, {2, 3}}}}));
std::ostringstream oss;
oss << sharding;
- EXPECT_EQ(oss.str(), "{f32[3,5,7,11] devices=[1,1,2,2]0,1,2,3}");
+ EXPECT_EQ(oss.str(), "{devices=[1,1,2,2]0,1,2,3}");
}
TEST_F(HloShardingTest, ParseHloString) {
@@ -319,8 +268,7 @@ TEST_F(HloShardingTest, ParseHloString) {
};
check(HloSharding::Replicate());
check(HloSharding::AssignDevice(2));
- check(HloSharding::Tile(ShapeUtil::MakeShape(F32, {3, 1, 3, 7}),
- Array4D<int64>({{{{0}, {1}}}})));
+ check(HloSharding::Tile(Array4D<int64>({{{{0}, {1}}}})));
// Empty tuple. One sharding is required for empty tuples, as we need to be
// able to assign sharding to them, even though they have no leaves.
check(HloSharding::Tuple(ShapeUtil::MakeTupleShape({}),
@@ -332,8 +280,7 @@ TEST_F(HloShardingTest, ParseHloString) {
ShapeUtil::MakeShape(F32, {3, 5, 7}),
ShapeUtil::MakeShape(F32, {3, 7})});
check(HloSharding::Tuple(
- tuple_shape, {HloSharding::Tile(ShapeUtil::MakeShape(F32, {3, 1, 3, 7}),
- Array4D<int64>({{{{0}, {1}}}})),
+ tuple_shape, {HloSharding::Tile(Array4D<int64>({{{{0}, {1}}}})),
HloSharding::Replicate(), HloSharding::AssignDevice(1)}));
}
{
@@ -343,8 +290,7 @@ TEST_F(HloShardingTest, ParseHloString) {
ShapeUtil::MakeTupleShape({ShapeUtil::MakeShape(F32, {3, 5, 7}),
ShapeUtil::MakeShape(F32, {3, 7})})});
std::vector<HloSharding> leaf_shardings = {
- HloSharding::Tile(ShapeUtil::MakeShape(F32, {3, 1, 3, 7}),
- Array4D<int64>({{{{0}, {1}}}})),
+ HloSharding::Tile(Array4D<int64>({{{{0}, {1}}}})),
HloSharding::Replicate(), HloSharding::AssignDevice(1)};
ShapeTree<HloSharding> sharding_tree(tuple_shape, HloSharding::Replicate());
// Assign leaf_shardings to sharding_tree leaves.
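The deleted TransformShardedTileShapeTest has no direct replacement in this change; a hypothetical TileShape-based check along the same lines might look like:

    TEST_F(HloShardingTest, TileShapeSketch) {
      // Hypothetical test, not part of this commit: a [1,1,2,2] assignment over
      // f32[13,15,17,19] yields per-tile dims ceil(17/2)=9 and ceil(19/2)=10.
      HloSharding sharding =
          HloSharding::Tile(Array4D<int64>({{{{0, 1}, {2, 3}}}}));
      Shape tile = sharding.TileShape(ShapeUtil::MakeShape(F32, {13, 15, 17, 19}));
      EXPECT_TRUE(
          ShapeUtil::Equal(tile, ShapeUtil::MakeShape(F32, {13, 15, 9, 10})));
    }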