diff options
author | 2018-07-07 06:54:44 -0700 | |
---|---|---|
committer | 2018-07-07 21:40:28 -0700 | |
commit | 35287be3bb7daa0448af064f5d005a25201d6853 (patch) | |
tree | 81a52537771013064cd727780c285e8da14fd39a /tensorflow/compiler/xla/service/hlo_sharding.cc | |
parent | 6d5b8b7cae669372df3f756f827c40b08a0d14a9 (diff) |
Build fully connected graph which edges across called computations.
Restructured sharding passes to propagate sharding on pass-through instructions which now the placer does not assign anymore (GTEs, tuples, bitcast, parameters, ...).
PiperOrigin-RevId: 203591020
Diffstat (limited to 'tensorflow/compiler/xla/service/hlo_sharding.cc')
-rw-r--r-- | tensorflow/compiler/xla/service/hlo_sharding.cc | 21 |
1 files changed, 21 insertions, 0 deletions
diff --git a/tensorflow/compiler/xla/service/hlo_sharding.cc b/tensorflow/compiler/xla/service/hlo_sharding.cc index 268b4727bc..393944c20f 100644 --- a/tensorflow/compiler/xla/service/hlo_sharding.cc +++ b/tensorflow/compiler/xla/service/hlo_sharding.cc @@ -60,6 +60,9 @@ HloSharding HloSharding::Tuple( const Shape& tuple_shape, tensorflow::gtl::ArraySlice<HloSharding> shardings) { CHECK(ShapeUtil::IsTuple(tuple_shape)) << ShapeUtil::HumanString(tuple_shape); + for (auto& sharding : shardings) { + CHECK(!sharding.IsTuple()) << sharding.ToString(); + } std::vector<HloSharding> flattened_list(shardings.begin(), shardings.end()); CHECK_EQ(flattened_list.size(), RequiredLeaves(tuple_shape)) << "Flat list has " << flattened_list.size() << ", required " @@ -67,6 +70,24 @@ HloSharding HloSharding::Tuple( return HloSharding(flattened_list); } +HloSharding HloSharding::SingleTuple(const Shape& tuple_shape, + const HloSharding& sharding) { + CHECK(ShapeUtil::IsTuple(tuple_shape)) << ShapeUtil::HumanString(tuple_shape); + CHECK(!sharding.IsTuple()) << sharding.ToString(); + int64 leaf_count = ShapeUtil::GetLeafCount(tuple_shape); + std::vector<HloSharding> flattened_list; + flattened_list.reserve(leaf_count); + for (int64 i = 0; i < leaf_count; ++i) { + flattened_list.push_back(sharding); + } + return HloSharding(flattened_list); +} + +HloSharding HloSharding::Single(const Shape& shape, + const HloSharding& sharding) { + return ShapeUtil::IsTuple(shape) ? SingleTuple(shape, sharding) : sharding; +} + string HloSharding::ToString() const { if (IsTuple()) { std::vector<string> parts; |