aboutsummaryrefslogtreecommitdiffhomepage
path: root/tensorflow/compiler/xla/service/hlo_sharding.cc
diff options
context:
space:
mode:
authorGravatar A. Unique TensorFlower <gardener@tensorflow.org>2018-07-07 06:54:44 -0700
committerGravatar TensorFlower Gardener <gardener@tensorflow.org>2018-07-07 21:40:28 -0700
commit35287be3bb7daa0448af064f5d005a25201d6853 (patch)
tree81a52537771013064cd727780c285e8da14fd39a /tensorflow/compiler/xla/service/hlo_sharding.cc
parent6d5b8b7cae669372df3f756f827c40b08a0d14a9 (diff)
Build fully connected graph which edges across called computations.
Restructured sharding passes to propagate sharding on pass-through instructions which now the placer does not assign anymore (GTEs, tuples, bitcast, parameters, ...). PiperOrigin-RevId: 203591020
Diffstat (limited to 'tensorflow/compiler/xla/service/hlo_sharding.cc')
-rw-r--r--tensorflow/compiler/xla/service/hlo_sharding.cc21
1 files changed, 21 insertions, 0 deletions
diff --git a/tensorflow/compiler/xla/service/hlo_sharding.cc b/tensorflow/compiler/xla/service/hlo_sharding.cc
index 268b4727bc..393944c20f 100644
--- a/tensorflow/compiler/xla/service/hlo_sharding.cc
+++ b/tensorflow/compiler/xla/service/hlo_sharding.cc
@@ -60,6 +60,9 @@ HloSharding HloSharding::Tuple(
const Shape& tuple_shape,
tensorflow::gtl::ArraySlice<HloSharding> shardings) {
CHECK(ShapeUtil::IsTuple(tuple_shape)) << ShapeUtil::HumanString(tuple_shape);
+ for (auto& sharding : shardings) {
+ CHECK(!sharding.IsTuple()) << sharding.ToString();
+ }
std::vector<HloSharding> flattened_list(shardings.begin(), shardings.end());
CHECK_EQ(flattened_list.size(), RequiredLeaves(tuple_shape))
<< "Flat list has " << flattened_list.size() << ", required "
@@ -67,6 +70,24 @@ HloSharding HloSharding::Tuple(
return HloSharding(flattened_list);
}
+HloSharding HloSharding::SingleTuple(const Shape& tuple_shape,
+ const HloSharding& sharding) {
+ CHECK(ShapeUtil::IsTuple(tuple_shape)) << ShapeUtil::HumanString(tuple_shape);
+ CHECK(!sharding.IsTuple()) << sharding.ToString();
+ int64 leaf_count = ShapeUtil::GetLeafCount(tuple_shape);
+ std::vector<HloSharding> flattened_list;
+ flattened_list.reserve(leaf_count);
+ for (int64 i = 0; i < leaf_count; ++i) {
+ flattened_list.push_back(sharding);
+ }
+ return HloSharding(flattened_list);
+}
+
+HloSharding HloSharding::Single(const Shape& shape,
+ const HloSharding& sharding) {
+ return ShapeUtil::IsTuple(shape) ? SingleTuple(shape, sharding) : sharding;
+}
+
string HloSharding::ToString() const {
if (IsTuple()) {
std::vector<string> parts;