aboutsummaryrefslogtreecommitdiffhomepage
path: root/tensorflow/compiler/xla/service/hlo_sharding.cc
diff options
context:
space:
mode:
Diffstat (limited to 'tensorflow/compiler/xla/service/hlo_sharding.cc')
-rw-r--r--tensorflow/compiler/xla/service/hlo_sharding.cc21
1 file changed, 21 insertions, 0 deletions
diff --git a/tensorflow/compiler/xla/service/hlo_sharding.cc b/tensorflow/compiler/xla/service/hlo_sharding.cc
index 268b4727bc..393944c20f 100644
--- a/tensorflow/compiler/xla/service/hlo_sharding.cc
+++ b/tensorflow/compiler/xla/service/hlo_sharding.cc
@@ -60,6 +60,9 @@ HloSharding HloSharding::Tuple(
const Shape& tuple_shape,
tensorflow::gtl::ArraySlice<HloSharding> shardings) {
CHECK(ShapeUtil::IsTuple(tuple_shape)) << ShapeUtil::HumanString(tuple_shape);
+ for (auto& sharding : shardings) {
+ CHECK(!sharding.IsTuple()) << sharding.ToString();
+ }
std::vector<HloSharding> flattened_list(shardings.begin(), shardings.end());
CHECK_EQ(flattened_list.size(), RequiredLeaves(tuple_shape))
<< "Flat list has " << flattened_list.size() << ", required "
@@ -67,6 +70,24 @@ HloSharding HloSharding::Tuple(
return HloSharding(flattened_list);
}
+HloSharding HloSharding::SingleTuple(const Shape& tuple_shape,
+ const HloSharding& sharding) {
+ CHECK(ShapeUtil::IsTuple(tuple_shape)) << ShapeUtil::HumanString(tuple_shape);
+ CHECK(!sharding.IsTuple()) << sharding.ToString();
+ int64 leaf_count = ShapeUtil::GetLeafCount(tuple_shape);
+ std::vector<HloSharding> flattened_list;
+ flattened_list.reserve(leaf_count);
+ for (int64 i = 0; i < leaf_count; ++i) {
+ flattened_list.push_back(sharding);
+ }
+ return HloSharding(flattened_list);
+}
+
+HloSharding HloSharding::Single(const Shape& shape,
+ const HloSharding& sharding) {
+ return ShapeUtil::IsTuple(shape) ? SingleTuple(shape, sharding) : sharding;
+}
+
string HloSharding::ToString() const {
if (IsTuple()) {
std::vector<string> parts;