diff options
author | A. Unique TensorFlower <gardener@tensorflow.org> | 2018-04-24 07:38:49 -0700 |
---|---|---|
committer | TensorFlower Gardener <gardener@tensorflow.org> | 2018-04-24 07:41:40 -0700 |
commit | 1ce99cfa52b19a40cff8a9ae983a0a7f04eb2bf1 (patch) | |
tree | 77318f028cc6086a234801c86fd79125121c68b5 /tensorflow/compiler/xla | |
parent | 5eb233d0686636a7bacc5b8813c079b6b9aa483c (diff) |
Softens the requirements in the HLO sharding validation
The goal is to support tiled shardings where the last N tile have no data.
PiperOrigin-RevId: 194085302
Diffstat (limited to 'tensorflow/compiler/xla')
-rw-r--r-- | tensorflow/compiler/xla/service/hlo_sharding.cc | 39 | ||||
-rw-r--r-- | tensorflow/compiler/xla/service/hlo_sharding_test.cc | 15 |
2 files changed, 16 insertions, 38 deletions
diff --git a/tensorflow/compiler/xla/service/hlo_sharding.cc b/tensorflow/compiler/xla/service/hlo_sharding.cc index 1b42349b0b..994de44123 100644 --- a/tensorflow/compiler/xla/service/hlo_sharding.cc +++ b/tensorflow/compiler/xla/service/hlo_sharding.cc @@ -256,37 +256,24 @@ Status HloSharding::ValidateNonTuple(const Shape& shape, ", input_shape=", ShapeUtil::HumanString(shape)); } - // The tile shape must not be the same as the input shape without maximal_ - // also set. If this is the case, we're not actually sharded and the correct - // constructor should have been used. - if (ShapeUtil::Equal(shape, tile_shape_)) { + // The correct constructor have to be used to create tile maximal shardings. + if (tile_assignment_.num_elements() == 1) { return tensorflow::errors::InvalidArgument( - "Tile shape is the same as the input shape. If a replicated sharding " - "was intended, use HloSharding::Replicated(). If a device placement " - "was intended, use HloSharding::AssignDevice()"); + "Tile assignment only contains a single device. If a replicated " + "sharding was intended, use HloSharding::Replicated(). If a device " + "placement was intended, use HloSharding::AssignDevice()"); } - // The tile shape must not be greater than the input shape in any dimension. - for (int64 i = 0, e = ShapeUtil::Rank(shape); i != e; ++i) { - auto tile_dim = tile_shape_.dimensions(i); - auto shape_dim = shape.dimensions(i); - if (tile_dim > shape_dim) { - return tensorflow::errors::InvalidArgument( - StrCat("Tile is larger than input shape (dimension ", i, ", ", - tile_dim, " > ", shape_dim)); - } - } - - // The tile assignment tensor must be exactly dimensioned to ceil(shape[dim] - // tile[dim]) for every dimension contained within tile. + // The tile assignment tensor must contain enough element to cover the full + // shape with tiles of the specified size. for (int64 i = 0, e = tile_assignment_.dimensions().size(); i != e; ++i) { - int64 expected_dim = - CeilOfRatio(shape.dimensions(i), tile_shape_.dimensions(i)); - if (tile_assignment_.dimensions()[i] != expected_dim) { + int64 total_tile_size = tile_assignment_.dim(i) * tile_shape_.dimensions(i); + if (shape.dimensions(i) > total_tile_size) { return tensorflow::errors::InvalidArgument( - StrCat("Tile assignment tensor has incorrect shape. Dimension ", i, - " expected ", expected_dim, " but got ", - tile_assignment_.dimensions()[i])); + StrCat("Tile assignment tensor has too few element to cover the full " + "shape. Dimension ", + i, ", shape ", shape.dimensions(i), ", total size ", + total_tile_size)); } } diff --git a/tensorflow/compiler/xla/service/hlo_sharding_test.cc b/tensorflow/compiler/xla/service/hlo_sharding_test.cc index 69ea4233e4..3bf0d25efb 100644 --- a/tensorflow/compiler/xla/service/hlo_sharding_test.cc +++ b/tensorflow/compiler/xla/service/hlo_sharding_test.cc @@ -88,7 +88,7 @@ TEST_F(HloShardingTest, Tile) { } { - // Test should pass. + // Test should fail because of more devices used then `num_device`. Shape tile_shape = ShapeUtil::MakeShape(U32, {2, 3}); HloSharding sharding = HloSharding::Tile(tile_shape, MakeArray({2, 2}, {0, 1, 2, 3})); @@ -97,17 +97,8 @@ TEST_F(HloShardingTest, Tile) { } { - // Test should fail due to the tile being larger than the input space. - Shape tile_shape = ShapeUtil::MakeShape(U32, {2, 3}); - HloSharding sharding = - HloSharding::Tile(tile_shape, MakeArray({2, 2}, {0, 1, 2, 3})); - EXPECT_IS_NOT_OK(sharding.Validate(ShapeUtil::MakeShape(F32, {2, 2}), - /*num_devices=*/4)); - } - - { - // Test should fail due to the tile not dividing the input space into 4 - // sections (even with padding). + // Test should fail because the total tiled size in dimension 0 is 4 but we + // have 6 elements along that dimensions. Shape tile_shape = ShapeUtil::MakeShape(U32, {2, 3}); HloSharding sharding = HloSharding::Tile(tile_shape, MakeArray({2, 2}, {0, 1, 2, 3})); |