diff options
Diffstat (limited to 'tensorflow/compiler/xla/service/hlo_sharding.cc')
-rw-r--r-- | tensorflow/compiler/xla/service/hlo_sharding.cc | 21 |
1 files changed, 21 insertions, 0 deletions
diff --git a/tensorflow/compiler/xla/service/hlo_sharding.cc b/tensorflow/compiler/xla/service/hlo_sharding.cc index 268b4727bc..393944c20f 100644 --- a/tensorflow/compiler/xla/service/hlo_sharding.cc +++ b/tensorflow/compiler/xla/service/hlo_sharding.cc @@ -60,6 +60,9 @@ HloSharding HloSharding::Tuple( const Shape& tuple_shape, tensorflow::gtl::ArraySlice<HloSharding> shardings) { CHECK(ShapeUtil::IsTuple(tuple_shape)) << ShapeUtil::HumanString(tuple_shape); + for (auto& sharding : shardings) { + CHECK(!sharding.IsTuple()) << sharding.ToString(); + } std::vector<HloSharding> flattened_list(shardings.begin(), shardings.end()); CHECK_EQ(flattened_list.size(), RequiredLeaves(tuple_shape)) << "Flat list has " << flattened_list.size() << ", required " @@ -67,6 +70,24 @@ HloSharding HloSharding::Tuple( return HloSharding(flattened_list); } +HloSharding HloSharding::SingleTuple(const Shape& tuple_shape, + const HloSharding& sharding) { + CHECK(ShapeUtil::IsTuple(tuple_shape)) << ShapeUtil::HumanString(tuple_shape); + CHECK(!sharding.IsTuple()) << sharding.ToString(); + int64 leaf_count = ShapeUtil::GetLeafCount(tuple_shape); + std::vector<HloSharding> flattened_list; + flattened_list.reserve(leaf_count); + for (int64 i = 0; i < leaf_count; ++i) { + flattened_list.push_back(sharding); + } + return HloSharding(flattened_list); +} + +HloSharding HloSharding::Single(const Shape& shape, + const HloSharding& sharding) { + return ShapeUtil::IsTuple(shape) ? SingleTuple(shape, sharding) : sharding; +} + string HloSharding::ToString() const { if (IsTuple()) { std::vector<string> parts; |