diff options
author | A. Unique TensorFlower <gardener@tensorflow.org> | 2017-11-09 02:32:44 -0800 |
---|---|---|
committer | Andrew Selle <aselle@andyselle.com> | 2017-11-10 16:14:38 -0800 |
commit | c38781f6ec9710c0102bdc9d95bf6176fd96d1ce (patch) | |
tree | f52551bc7beabfdf8007a5de9d5d8b6ec8cd625d /tensorflow/compiler/xla/service/hlo_sharding_test.cc | |
parent | 791c8bf0baf4198c5922dba08a74960ca6dac74f (diff) |
When sharding a tuple, we typically want to describe the data sharding
of each individual subtensor individually. Tuples are essentially just
containers - the tensors they contain should be able to be sharded
differently.
Tuples are hierarchically structured, but shardings were designed to
not contain the sharded type (the sharded type is inferred from the
output type of the instruction the sharding is applied to). Therefore,
shardings for tuples contain shardings for each subtensor as a
non-structured list.
This list is ordered as a preorder walk of the tuple shape, and of
course only the leaf nodes of the tuple shape are stored. The
structure is reapplied when the sharded instruction's shape is known.
PiperOrigin-RevId: 175132692
Diffstat (limited to 'tensorflow/compiler/xla/service/hlo_sharding_test.cc')
-rw-r--r-- | tensorflow/compiler/xla/service/hlo_sharding_test.cc | 68 |
1 files changed, 68 insertions, 0 deletions
diff --git a/tensorflow/compiler/xla/service/hlo_sharding_test.cc b/tensorflow/compiler/xla/service/hlo_sharding_test.cc index d0a20471a0..00ea38480e 100644 --- a/tensorflow/compiler/xla/service/hlo_sharding_test.cc +++ b/tensorflow/compiler/xla/service/hlo_sharding_test.cc @@ -132,6 +132,29 @@ TEST_F(HloShardingTest, Tile) { } } +TEST_F(HloShardingTest, NestedTuple) { + // nested_tuple_shape = (f32[], (f32[3]), f32[4, 6]) + Shape nested_tuple_shape = ShapeUtil::MakeTupleShape({ + ShapeUtil::MakeShape(F32, {}), + ShapeUtil::MakeTupleShape({ShapeUtil::MakeShape(F32, {3})}), + ShapeUtil::MakeShape(F32, {4, 6}), + }); + + OpSharding proto; + proto.set_type(OpSharding::Type::OpSharding_Type_TUPLE); + *proto.add_tuple_shardings() = HloSharding::Replicate().ToProto(); + *proto.add_tuple_shardings() = HloSharding::AssignDevice(0).ToProto(); + *proto.add_tuple_shardings() = HloSharding::AssignDevice(1).ToProto(); + HloSharding tuple_sharding = + HloSharding::FromProto(proto).ConsumeValueOrDie(); + + ShapeTree<HloSharding> shape_tree = + tuple_sharding.GetTupleShardingsAsShapeTree(nested_tuple_shape); + EXPECT_EQ(shape_tree.element({0}), HloSharding::Replicate()); + EXPECT_EQ(shape_tree.element({1, 0}), HloSharding::AssignDevice(0)); + EXPECT_EQ(shape_tree.element({2}), HloSharding::AssignDevice(1)); +} + TEST_F(HloShardingTest, Hash) { auto hash_compare_equal = [](const HloSharding& a, const HloSharding& b) { if (a.Hash() != b.Hash()) { @@ -184,6 +207,51 @@ TEST_F(HloShardingTest, Hash) { MakeArray({2, 2}, {0, 3, 1, 2})); EXPECT_FALSE(hash_compare_equal(sharding1, sharding2)); } + + HloSharding default_sharding = HloSharding::Replicate(); + { + ShapeTree<HloSharding> shape_tree(ShapeUtil::MakeTupleShape({}), + default_sharding); + HloSharding sharding1 = HloSharding::Replicate(); + HloSharding sharding2 = HloSharding::Tuple(shape_tree); + EXPECT_FALSE(hash_compare_equal(sharding1, sharding2)); + } + + { + ShapeTree<HloSharding> shape_tree(ShapeUtil::MakeTupleShape({}), + default_sharding); + HloSharding sharding1 = HloSharding::Tuple(shape_tree); + HloSharding sharding2 = HloSharding::Tuple(shape_tree); + EXPECT_TRUE(hash_compare_equal(sharding1, sharding2)); + } + + { + ShapeTree<HloSharding> shape_tree1( + ShapeUtil::MakeTupleShape({ShapeUtil::MakeShape(F32, {4})}), + default_sharding); + *shape_tree1.mutable_element({0}) = HloSharding::Replicate(); + ShapeTree<HloSharding> shape_tree2( + ShapeUtil::MakeTupleShape({ShapeUtil::MakeShape(F32, {4})}), + default_sharding); + *shape_tree2.mutable_element({0}) = HloSharding::AssignDevice(0); + HloSharding sharding1 = HloSharding::Tuple(shape_tree1); + HloSharding sharding2 = HloSharding::Tuple(shape_tree2); + EXPECT_FALSE(hash_compare_equal(sharding1, sharding2)); + } + + { + ShapeTree<HloSharding> shape_tree1( + ShapeUtil::MakeTupleShape({ShapeUtil::MakeShape(F32, {4})}), + default_sharding); + *shape_tree1.mutable_element({0}) = HloSharding::AssignDevice(0); + ShapeTree<HloSharding> shape_tree2( + ShapeUtil::MakeTupleShape({ShapeUtil::MakeShape(F32, {4})}), + default_sharding); + *shape_tree2.mutable_element({0}) = HloSharding::AssignDevice(0); + HloSharding sharding1 = HloSharding::Tuple(shape_tree1); + HloSharding sharding2 = HloSharding::Tuple(shape_tree2); + EXPECT_TRUE(hash_compare_equal(sharding1, sharding2)); + } } } // namespace |