diff options
Diffstat (limited to 'tensorflow/compiler/xla/service/hlo_sharding_test.cc')
-rw-r--r-- | tensorflow/compiler/xla/service/hlo_sharding_test.cc | 68 |
1 files changed, 0 insertions, 68 deletions
diff --git a/tensorflow/compiler/xla/service/hlo_sharding_test.cc b/tensorflow/compiler/xla/service/hlo_sharding_test.cc index 00ea38480e..d0a20471a0 100644 --- a/tensorflow/compiler/xla/service/hlo_sharding_test.cc +++ b/tensorflow/compiler/xla/service/hlo_sharding_test.cc @@ -132,29 +132,6 @@ TEST_F(HloShardingTest, Tile) { } } -TEST_F(HloShardingTest, NestedTuple) { - // nested_tuple_shape = (f32[], (f32[3]), f32[4, 6]) - Shape nested_tuple_shape = ShapeUtil::MakeTupleShape({ - ShapeUtil::MakeShape(F32, {}), - ShapeUtil::MakeTupleShape({ShapeUtil::MakeShape(F32, {3})}), - ShapeUtil::MakeShape(F32, {4, 6}), - }); - - OpSharding proto; - proto.set_type(OpSharding::Type::OpSharding_Type_TUPLE); - *proto.add_tuple_shardings() = HloSharding::Replicate().ToProto(); - *proto.add_tuple_shardings() = HloSharding::AssignDevice(0).ToProto(); - *proto.add_tuple_shardings() = HloSharding::AssignDevice(1).ToProto(); - HloSharding tuple_sharding = - HloSharding::FromProto(proto).ConsumeValueOrDie(); - - ShapeTree<HloSharding> shape_tree = - tuple_sharding.GetTupleShardingsAsShapeTree(nested_tuple_shape); - EXPECT_EQ(shape_tree.element({0}), HloSharding::Replicate()); - EXPECT_EQ(shape_tree.element({1, 0}), HloSharding::AssignDevice(0)); - EXPECT_EQ(shape_tree.element({2}), HloSharding::AssignDevice(1)); -} - TEST_F(HloShardingTest, Hash) { auto hash_compare_equal = [](const HloSharding& a, const HloSharding& b) { if (a.Hash() != b.Hash()) { @@ -207,51 +184,6 @@ TEST_F(HloShardingTest, Hash) { MakeArray({2, 2}, {0, 3, 1, 2})); EXPECT_FALSE(hash_compare_equal(sharding1, sharding2)); } - - HloSharding default_sharding = HloSharding::Replicate(); - { - ShapeTree<HloSharding> shape_tree(ShapeUtil::MakeTupleShape({}), - default_sharding); - HloSharding sharding1 = HloSharding::Replicate(); - HloSharding sharding2 = HloSharding::Tuple(shape_tree); - EXPECT_FALSE(hash_compare_equal(sharding1, sharding2)); - } - - { - ShapeTree<HloSharding> shape_tree(ShapeUtil::MakeTupleShape({}), - default_sharding); - HloSharding sharding1 = HloSharding::Tuple(shape_tree); - HloSharding sharding2 = HloSharding::Tuple(shape_tree); - EXPECT_TRUE(hash_compare_equal(sharding1, sharding2)); - } - - { - ShapeTree<HloSharding> shape_tree1( - ShapeUtil::MakeTupleShape({ShapeUtil::MakeShape(F32, {4})}), - default_sharding); - *shape_tree1.mutable_element({0}) = HloSharding::Replicate(); - ShapeTree<HloSharding> shape_tree2( - ShapeUtil::MakeTupleShape({ShapeUtil::MakeShape(F32, {4})}), - default_sharding); - *shape_tree2.mutable_element({0}) = HloSharding::AssignDevice(0); - HloSharding sharding1 = HloSharding::Tuple(shape_tree1); - HloSharding sharding2 = HloSharding::Tuple(shape_tree2); - EXPECT_FALSE(hash_compare_equal(sharding1, sharding2)); - } - - { - ShapeTree<HloSharding> shape_tree1( - ShapeUtil::MakeTupleShape({ShapeUtil::MakeShape(F32, {4})}), - default_sharding); - *shape_tree1.mutable_element({0}) = HloSharding::AssignDevice(0); - ShapeTree<HloSharding> shape_tree2( - ShapeUtil::MakeTupleShape({ShapeUtil::MakeShape(F32, {4})}), - default_sharding); - *shape_tree2.mutable_element({0}) = HloSharding::AssignDevice(0); - HloSharding sharding1 = HloSharding::Tuple(shape_tree1); - HloSharding sharding2 = HloSharding::Tuple(shape_tree2); - EXPECT_TRUE(hash_compare_equal(sharding1, sharding2)); - } } } // namespace |