aboutsummaryrefslogtreecommitdiffhomepage
path: root/tensorflow/compiler/xla/service/hlo_sharding_test.cc
diff options
context:
space:
mode:
Diffstat (limited to 'tensorflow/compiler/xla/service/hlo_sharding_test.cc')
-rw-r--r--tensorflow/compiler/xla/service/hlo_sharding_test.cc68
1 files changed, 68 insertions, 0 deletions
diff --git a/tensorflow/compiler/xla/service/hlo_sharding_test.cc b/tensorflow/compiler/xla/service/hlo_sharding_test.cc
index d0a20471a0..00ea38480e 100644
--- a/tensorflow/compiler/xla/service/hlo_sharding_test.cc
+++ b/tensorflow/compiler/xla/service/hlo_sharding_test.cc
@@ -132,6 +132,29 @@ TEST_F(HloShardingTest, Tile) {
}
}
+TEST_F(HloShardingTest, NestedTuple) {
+ // nested_tuple_shape = (f32[], (f32[3]), f32[4, 6])
+ Shape nested_tuple_shape = ShapeUtil::MakeTupleShape({
+ ShapeUtil::MakeShape(F32, {}),
+ ShapeUtil::MakeTupleShape({ShapeUtil::MakeShape(F32, {3})}),
+ ShapeUtil::MakeShape(F32, {4, 6}),
+ });
+
+ OpSharding proto;
+ proto.set_type(OpSharding::Type::OpSharding_Type_TUPLE);
+ *proto.add_tuple_shardings() = HloSharding::Replicate().ToProto();
+ *proto.add_tuple_shardings() = HloSharding::AssignDevice(0).ToProto();
+ *proto.add_tuple_shardings() = HloSharding::AssignDevice(1).ToProto();
+ HloSharding tuple_sharding =
+ HloSharding::FromProto(proto).ConsumeValueOrDie();
+
+ ShapeTree<HloSharding> shape_tree =
+ tuple_sharding.GetTupleShardingsAsShapeTree(nested_tuple_shape);
+ EXPECT_EQ(shape_tree.element({0}), HloSharding::Replicate());
+ EXPECT_EQ(shape_tree.element({1, 0}), HloSharding::AssignDevice(0));
+ EXPECT_EQ(shape_tree.element({2}), HloSharding::AssignDevice(1));
+}
+
TEST_F(HloShardingTest, Hash) {
auto hash_compare_equal = [](const HloSharding& a, const HloSharding& b) {
if (a.Hash() != b.Hash()) {
@@ -184,6 +207,51 @@ TEST_F(HloShardingTest, Hash) {
MakeArray({2, 2}, {0, 3, 1, 2}));
EXPECT_FALSE(hash_compare_equal(sharding1, sharding2));
}
+
+ HloSharding default_sharding = HloSharding::Replicate();
+ {
+ ShapeTree<HloSharding> shape_tree(ShapeUtil::MakeTupleShape({}),
+ default_sharding);
+ HloSharding sharding1 = HloSharding::Replicate();
+ HloSharding sharding2 = HloSharding::Tuple(shape_tree);
+ EXPECT_FALSE(hash_compare_equal(sharding1, sharding2));
+ }
+
+ {
+ ShapeTree<HloSharding> shape_tree(ShapeUtil::MakeTupleShape({}),
+ default_sharding);
+ HloSharding sharding1 = HloSharding::Tuple(shape_tree);
+ HloSharding sharding2 = HloSharding::Tuple(shape_tree);
+ EXPECT_TRUE(hash_compare_equal(sharding1, sharding2));
+ }
+
+ {
+ ShapeTree<HloSharding> shape_tree1(
+ ShapeUtil::MakeTupleShape({ShapeUtil::MakeShape(F32, {4})}),
+ default_sharding);
+ *shape_tree1.mutable_element({0}) = HloSharding::Replicate();
+ ShapeTree<HloSharding> shape_tree2(
+ ShapeUtil::MakeTupleShape({ShapeUtil::MakeShape(F32, {4})}),
+ default_sharding);
+ *shape_tree2.mutable_element({0}) = HloSharding::AssignDevice(0);
+ HloSharding sharding1 = HloSharding::Tuple(shape_tree1);
+ HloSharding sharding2 = HloSharding::Tuple(shape_tree2);
+ EXPECT_FALSE(hash_compare_equal(sharding1, sharding2));
+ }
+
+ {
+ ShapeTree<HloSharding> shape_tree1(
+ ShapeUtil::MakeTupleShape({ShapeUtil::MakeShape(F32, {4})}),
+ default_sharding);
+ *shape_tree1.mutable_element({0}) = HloSharding::AssignDevice(0);
+ ShapeTree<HloSharding> shape_tree2(
+ ShapeUtil::MakeTupleShape({ShapeUtil::MakeShape(F32, {4})}),
+ default_sharding);
+ *shape_tree2.mutable_element({0}) = HloSharding::AssignDevice(0);
+ HloSharding sharding1 = HloSharding::Tuple(shape_tree1);
+ HloSharding sharding2 = HloSharding::Tuple(shape_tree2);
+ EXPECT_TRUE(hash_compare_equal(sharding1, sharding2));
+ }
}
} // namespace