diff options
Diffstat (limited to 'tensorflow/compiler/xla/service/hlo_sharding_test.cc')
-rw-r--r-- | tensorflow/compiler/xla/service/hlo_sharding_test.cc | 68 |
1 files changed, 68 insertions, 0 deletions
diff --git a/tensorflow/compiler/xla/service/hlo_sharding_test.cc b/tensorflow/compiler/xla/service/hlo_sharding_test.cc index d0a20471a0..00ea38480e 100644 --- a/tensorflow/compiler/xla/service/hlo_sharding_test.cc +++ b/tensorflow/compiler/xla/service/hlo_sharding_test.cc @@ -132,6 +132,29 @@ TEST_F(HloShardingTest, Tile) { } } +TEST_F(HloShardingTest, NestedTuple) { + // nested_tuple_shape = (f32[], (f32[3]), f32[4, 6]) + Shape nested_tuple_shape = ShapeUtil::MakeTupleShape({ + ShapeUtil::MakeShape(F32, {}), + ShapeUtil::MakeTupleShape({ShapeUtil::MakeShape(F32, {3})}), + ShapeUtil::MakeShape(F32, {4, 6}), + }); + + OpSharding proto; + proto.set_type(OpSharding::Type::OpSharding_Type_TUPLE); + *proto.add_tuple_shardings() = HloSharding::Replicate().ToProto(); + *proto.add_tuple_shardings() = HloSharding::AssignDevice(0).ToProto(); + *proto.add_tuple_shardings() = HloSharding::AssignDevice(1).ToProto(); + HloSharding tuple_sharding = + HloSharding::FromProto(proto).ConsumeValueOrDie(); + + ShapeTree<HloSharding> shape_tree = + tuple_sharding.GetTupleShardingsAsShapeTree(nested_tuple_shape); + EXPECT_EQ(shape_tree.element({0}), HloSharding::Replicate()); + EXPECT_EQ(shape_tree.element({1, 0}), HloSharding::AssignDevice(0)); + EXPECT_EQ(shape_tree.element({2}), HloSharding::AssignDevice(1)); +} + TEST_F(HloShardingTest, Hash) { auto hash_compare_equal = [](const HloSharding& a, const HloSharding& b) { if (a.Hash() != b.Hash()) { @@ -184,6 +207,51 @@ TEST_F(HloShardingTest, Hash) { MakeArray({2, 2}, {0, 3, 1, 2})); EXPECT_FALSE(hash_compare_equal(sharding1, sharding2)); } + + HloSharding default_sharding = HloSharding::Replicate(); + { + ShapeTree<HloSharding> shape_tree(ShapeUtil::MakeTupleShape({}), + default_sharding); + HloSharding sharding1 = HloSharding::Replicate(); + HloSharding sharding2 = HloSharding::Tuple(shape_tree); + EXPECT_FALSE(hash_compare_equal(sharding1, sharding2)); + } + + { + ShapeTree<HloSharding> shape_tree(ShapeUtil::MakeTupleShape({}), + default_sharding); + HloSharding sharding1 = HloSharding::Tuple(shape_tree); + HloSharding sharding2 = HloSharding::Tuple(shape_tree); + EXPECT_TRUE(hash_compare_equal(sharding1, sharding2)); + } + + { + ShapeTree<HloSharding> shape_tree1( + ShapeUtil::MakeTupleShape({ShapeUtil::MakeShape(F32, {4})}), + default_sharding); + *shape_tree1.mutable_element({0}) = HloSharding::Replicate(); + ShapeTree<HloSharding> shape_tree2( + ShapeUtil::MakeTupleShape({ShapeUtil::MakeShape(F32, {4})}), + default_sharding); + *shape_tree2.mutable_element({0}) = HloSharding::AssignDevice(0); + HloSharding sharding1 = HloSharding::Tuple(shape_tree1); + HloSharding sharding2 = HloSharding::Tuple(shape_tree2); + EXPECT_FALSE(hash_compare_equal(sharding1, sharding2)); + } + + { + ShapeTree<HloSharding> shape_tree1( + ShapeUtil::MakeTupleShape({ShapeUtil::MakeShape(F32, {4})}), + default_sharding); + *shape_tree1.mutable_element({0}) = HloSharding::AssignDevice(0); + ShapeTree<HloSharding> shape_tree2( + ShapeUtil::MakeTupleShape({ShapeUtil::MakeShape(F32, {4})}), + default_sharding); + *shape_tree2.mutable_element({0}) = HloSharding::AssignDevice(0); + HloSharding sharding1 = HloSharding::Tuple(shape_tree1); + HloSharding sharding2 = HloSharding::Tuple(shape_tree2); + EXPECT_TRUE(hash_compare_equal(sharding1, sharding2)); + } } } // namespace |