aboutsummaryrefslogtreecommitdiffhomepage
path: root/tensorflow/compiler/xla/service/hlo_sharding_test.cc
diff options
context:
space:
mode:
Diffstat (limited to 'tensorflow/compiler/xla/service/hlo_sharding_test.cc')
-rw-r--r--tensorflow/compiler/xla/service/hlo_sharding_test.cc15
1 files changed, 3 insertions, 12 deletions
diff --git a/tensorflow/compiler/xla/service/hlo_sharding_test.cc b/tensorflow/compiler/xla/service/hlo_sharding_test.cc
index 69ea4233e4..3bf0d25efb 100644
--- a/tensorflow/compiler/xla/service/hlo_sharding_test.cc
+++ b/tensorflow/compiler/xla/service/hlo_sharding_test.cc
@@ -88,7 +88,7 @@ TEST_F(HloShardingTest, Tile) {
}
{
- // Test should pass.
+ // Test should fail because of more devices used then `num_device`.
Shape tile_shape = ShapeUtil::MakeShape(U32, {2, 3});
HloSharding sharding =
HloSharding::Tile(tile_shape, MakeArray({2, 2}, {0, 1, 2, 3}));
@@ -97,17 +97,8 @@ TEST_F(HloShardingTest, Tile) {
}
{
- // Test should fail due to the tile being larger than the input space.
- Shape tile_shape = ShapeUtil::MakeShape(U32, {2, 3});
- HloSharding sharding =
- HloSharding::Tile(tile_shape, MakeArray({2, 2}, {0, 1, 2, 3}));
- EXPECT_IS_NOT_OK(sharding.Validate(ShapeUtil::MakeShape(F32, {2, 2}),
- /*num_devices=*/4));
- }
-
- {
- // Test should fail due to the tile not dividing the input space into 4
- // sections (even with padding).
+ // Test should fail because the total tiled size in dimension 0 is 4 but we
+ // have 6 elements along that dimensions.
Shape tile_shape = ShapeUtil::MakeShape(U32, {2, 3});
HloSharding sharding =
HloSharding::Tile(tile_shape, MakeArray({2, 2}, {0, 1, 2, 3}));