diff options
author | 2018-08-29 05:30:49 -0700 | |
---|---|---|
committer | 2018-08-29 05:34:52 -0700 | |
commit | 8dfb7532f8278e53a86a847ba6aa9c441f7b021b (patch) | |
tree | ff945473983cfce5fbcdc6aabe34f0f823a00ae7 /tensorflow/compiler/xla/service/hlo_sharding.cc | |
parent | a15f91d27e79cfe8d8b63ecd8121e19929924df6 (diff) |
Improve the implementation of HloSharding::GetSubSharding
The new implementation iterates the ShapeIndex and the flat HloSharding
directly instead of constructing a ShapeTree what provides a significant
performance benefit for large tuple shapes.
PiperOrigin-RevId: 210705719
Diffstat (limited to 'tensorflow/compiler/xla/service/hlo_sharding.cc')
-rw-r--r-- | tensorflow/compiler/xla/service/hlo_sharding.cc | 23 |
1 files changed, 17 insertions, 6 deletions
diff --git a/tensorflow/compiler/xla/service/hlo_sharding.cc b/tensorflow/compiler/xla/service/hlo_sharding.cc index 980dae07ce..1235259764 100644 --- a/tensorflow/compiler/xla/service/hlo_sharding.cc +++ b/tensorflow/compiler/xla/service/hlo_sharding.cc @@ -429,12 +429,23 @@ Shape HloSharding::TileShape(const Shape& shape) const { HloSharding HloSharding::GetSubSharding(const Shape& shape, const ShapeIndex& index) const { CHECK(IsTuple()); - - Shape sub_shape = ShapeUtil::GetSubshape(shape, index); - ShapeTree<HloSharding> sub_shape_tree(sub_shape, Replicate()); - sub_shape_tree.CopySubtreeFrom(GetAsShapeTree(shape), index, {}); - return ShapeUtil::IsTuple(sub_shape) ? Tuple(sub_shape_tree) - : sub_shape_tree.element(ShapeIndex({})); + int64 sharding_index = 0; + const Shape* sub_shape = &shape; + for (int64 idx : index) { + for (int64 i = 0; i < idx; ++i) { + sharding_index += + ShapeUtil::GetLeafCount(ShapeUtil::GetSubshape(*sub_shape, {i})); + } + sub_shape = &ShapeUtil::GetSubshape(*sub_shape, {idx}); + } + if (ShapeUtil::IsTuple(*sub_shape)) { + auto begin_it = tuple_elements_.begin() + sharding_index; + std::vector<HloSharding> sub_shardings( + begin_it, begin_it + ShapeUtil::GetLeafCount(*sub_shape)); + return HloSharding::Tuple(*sub_shape, sub_shardings); + } else { + return tuple_elements_[sharding_index]; + } } absl::optional<HloSharding> HloSharding::ExtractSingleSharding() const { |