aboutsummaryrefslogtreecommitdiffhomepage
path: root/tensorflow/compiler/xla/service/hlo_sharding.cc
diff options
context:
space:
mode:
authorGravatar A. Unique TensorFlower <gardener@tensorflow.org>2018-08-29 05:30:49 -0700
committerGravatar TensorFlower Gardener <gardener@tensorflow.org>2018-08-29 05:34:52 -0700
commit8dfb7532f8278e53a86a847ba6aa9c441f7b021b (patch)
treeff945473983cfce5fbcdc6aabe34f0f823a00ae7 /tensorflow/compiler/xla/service/hlo_sharding.cc
parenta15f91d27e79cfe8d8b63ecd8121e19929924df6 (diff)
Improve the implementation of HloSharding::GetSubSharding
The new implementation iterates the ShapeIndex and the flat HloSharding directly instead of constructing a ShapeTree what provides a significant performance benefit for large tuple shapes. PiperOrigin-RevId: 210705719
Diffstat (limited to 'tensorflow/compiler/xla/service/hlo_sharding.cc')
-rw-r--r--tensorflow/compiler/xla/service/hlo_sharding.cc23
1 files changed, 17 insertions, 6 deletions
diff --git a/tensorflow/compiler/xla/service/hlo_sharding.cc b/tensorflow/compiler/xla/service/hlo_sharding.cc
index 980dae07ce..1235259764 100644
--- a/tensorflow/compiler/xla/service/hlo_sharding.cc
+++ b/tensorflow/compiler/xla/service/hlo_sharding.cc
@@ -429,12 +429,23 @@ Shape HloSharding::TileShape(const Shape& shape) const {
HloSharding HloSharding::GetSubSharding(const Shape& shape,
const ShapeIndex& index) const {
CHECK(IsTuple());
-
- Shape sub_shape = ShapeUtil::GetSubshape(shape, index);
- ShapeTree<HloSharding> sub_shape_tree(sub_shape, Replicate());
- sub_shape_tree.CopySubtreeFrom(GetAsShapeTree(shape), index, {});
- return ShapeUtil::IsTuple(sub_shape) ? Tuple(sub_shape_tree)
- : sub_shape_tree.element(ShapeIndex({}));
+ int64 sharding_index = 0;
+ const Shape* sub_shape = &shape;
+ for (int64 idx : index) {
+ for (int64 i = 0; i < idx; ++i) {
+ sharding_index +=
+ ShapeUtil::GetLeafCount(ShapeUtil::GetSubshape(*sub_shape, {i}));
+ }
+ sub_shape = &ShapeUtil::GetSubshape(*sub_shape, {idx});
+ }
+ if (ShapeUtil::IsTuple(*sub_shape)) {
+ auto begin_it = tuple_elements_.begin() + sharding_index;
+ std::vector<HloSharding> sub_shardings(
+ begin_it, begin_it + ShapeUtil::GetLeafCount(*sub_shape));
+ return HloSharding::Tuple(*sub_shape, sub_shardings);
+ } else {
+ return tuple_elements_[sharding_index];
+ }
}
absl::optional<HloSharding> HloSharding::ExtractSingleSharding() const {