aboutsummaryrefslogtreecommitdiffhomepage
path: root/tensorflow/compiler/xla/service/hlo_sharding_test.cc
diff options
context:
space:
mode:
authorGravatar A. Unique TensorFlower <gardener@tensorflow.org>2017-11-09 02:32:44 -0800
committerGravatar Andrew Selle <aselle@andyselle.com>2017-11-10 16:14:38 -0800
commitc38781f6ec9710c0102bdc9d95bf6176fd96d1ce (patch)
treef52551bc7beabfdf8007a5de9d5d8b6ec8cd625d /tensorflow/compiler/xla/service/hlo_sharding_test.cc
parent791c8bf0baf4198c5922dba08a74960ca6dac74f (diff)
When sharding a tuple, we typically want to describe the data sharding
of each subtensor individually. Tuples are essentially just containers - the tensors they contain should be able to be sharded differently. Tuples are hierarchically structured, but shardings were designed not to contain the sharded type (the sharded type is inferred from the output type of the instruction the sharding is applied to). Therefore, shardings for tuples contain shardings for each subtensor as a non-structured list. This list is ordered as a preorder walk of the tuple shape, and of course only the leaf nodes of the tuple shape are stored. The structure is reapplied when the sharded instruction's shape is known. PiperOrigin-RevId: 175132692
Diffstat (limited to 'tensorflow/compiler/xla/service/hlo_sharding_test.cc')
-rw-r--r--tensorflow/compiler/xla/service/hlo_sharding_test.cc68
1 file changed, 68 insertions, 0 deletions
diff --git a/tensorflow/compiler/xla/service/hlo_sharding_test.cc b/tensorflow/compiler/xla/service/hlo_sharding_test.cc
index d0a20471a0..00ea38480e 100644
--- a/tensorflow/compiler/xla/service/hlo_sharding_test.cc
+++ b/tensorflow/compiler/xla/service/hlo_sharding_test.cc
@@ -132,6 +132,29 @@ TEST_F(HloShardingTest, Tile) {
}
}
+TEST_F(HloShardingTest, NestedTuple) {
+ // nested_tuple_shape = (f32[], (f32[3]), f32[4, 6])
+ Shape nested_tuple_shape = ShapeUtil::MakeTupleShape({
+ ShapeUtil::MakeShape(F32, {}),
+ ShapeUtil::MakeTupleShape({ShapeUtil::MakeShape(F32, {3})}),
+ ShapeUtil::MakeShape(F32, {4, 6}),
+ });
+
+ OpSharding proto;
+ proto.set_type(OpSharding::Type::OpSharding_Type_TUPLE);
+ *proto.add_tuple_shardings() = HloSharding::Replicate().ToProto();
+ *proto.add_tuple_shardings() = HloSharding::AssignDevice(0).ToProto();
+ *proto.add_tuple_shardings() = HloSharding::AssignDevice(1).ToProto();
+ HloSharding tuple_sharding =
+ HloSharding::FromProto(proto).ConsumeValueOrDie();
+
+ ShapeTree<HloSharding> shape_tree =
+ tuple_sharding.GetTupleShardingsAsShapeTree(nested_tuple_shape);
+ EXPECT_EQ(shape_tree.element({0}), HloSharding::Replicate());
+ EXPECT_EQ(shape_tree.element({1, 0}), HloSharding::AssignDevice(0));
+ EXPECT_EQ(shape_tree.element({2}), HloSharding::AssignDevice(1));
+}
+
TEST_F(HloShardingTest, Hash) {
auto hash_compare_equal = [](const HloSharding& a, const HloSharding& b) {
if (a.Hash() != b.Hash()) {
@@ -184,6 +207,51 @@ TEST_F(HloShardingTest, Hash) {
MakeArray({2, 2}, {0, 3, 1, 2}));
EXPECT_FALSE(hash_compare_equal(sharding1, sharding2));
}
+
+ HloSharding default_sharding = HloSharding::Replicate();
+ {
+ ShapeTree<HloSharding> shape_tree(ShapeUtil::MakeTupleShape({}),
+ default_sharding);
+ HloSharding sharding1 = HloSharding::Replicate();
+ HloSharding sharding2 = HloSharding::Tuple(shape_tree);
+ EXPECT_FALSE(hash_compare_equal(sharding1, sharding2));
+ }
+
+ {
+ ShapeTree<HloSharding> shape_tree(ShapeUtil::MakeTupleShape({}),
+ default_sharding);
+ HloSharding sharding1 = HloSharding::Tuple(shape_tree);
+ HloSharding sharding2 = HloSharding::Tuple(shape_tree);
+ EXPECT_TRUE(hash_compare_equal(sharding1, sharding2));
+ }
+
+ {
+ ShapeTree<HloSharding> shape_tree1(
+ ShapeUtil::MakeTupleShape({ShapeUtil::MakeShape(F32, {4})}),
+ default_sharding);
+ *shape_tree1.mutable_element({0}) = HloSharding::Replicate();
+ ShapeTree<HloSharding> shape_tree2(
+ ShapeUtil::MakeTupleShape({ShapeUtil::MakeShape(F32, {4})}),
+ default_sharding);
+ *shape_tree2.mutable_element({0}) = HloSharding::AssignDevice(0);
+ HloSharding sharding1 = HloSharding::Tuple(shape_tree1);
+ HloSharding sharding2 = HloSharding::Tuple(shape_tree2);
+ EXPECT_FALSE(hash_compare_equal(sharding1, sharding2));
+ }
+
+ {
+ ShapeTree<HloSharding> shape_tree1(
+ ShapeUtil::MakeTupleShape({ShapeUtil::MakeShape(F32, {4})}),
+ default_sharding);
+ *shape_tree1.mutable_element({0}) = HloSharding::AssignDevice(0);
+ ShapeTree<HloSharding> shape_tree2(
+ ShapeUtil::MakeTupleShape({ShapeUtil::MakeShape(F32, {4})}),
+ default_sharding);
+ *shape_tree2.mutable_element({0}) = HloSharding::AssignDevice(0);
+ HloSharding sharding1 = HloSharding::Tuple(shape_tree1);
+ HloSharding sharding2 = HloSharding::Tuple(shape_tree2);
+ EXPECT_TRUE(hash_compare_equal(sharding1, sharding2));
+ }
}
} // namespace