diff options
Diffstat (limited to 'tensorflow/compiler/xla/service/hlo_sharding.h')
-rw-r--r-- | tensorflow/compiler/xla/service/hlo_sharding.h | 83 |
1 files changed, 7 insertions, 76 deletions
diff --git a/tensorflow/compiler/xla/service/hlo_sharding.h b/tensorflow/compiler/xla/service/hlo_sharding.h index f8ef2a3d05..d7ada30c70 100644 --- a/tensorflow/compiler/xla/service/hlo_sharding.h +++ b/tensorflow/compiler/xla/service/hlo_sharding.h @@ -24,7 +24,6 @@ limitations under the License. #include "tensorflow/compiler/xla/array.h" #include "tensorflow/compiler/xla/literal_util.h" #include "tensorflow/compiler/xla/protobuf_util.h" -#include "tensorflow/compiler/xla/shape_tree.h" #include "tensorflow/compiler/xla/xla_data.pb.h" #include "tensorflow/core/lib/gtl/array_slice.h" #include "tensorflow/core/lib/hash/hash.h" @@ -68,18 +67,6 @@ class HloSharding { // `num_tiles` tiles. static HloSharding Tile1D(const Shape& input_shape, int64 num_tiles); - // Creates a new sharding for a tuple type. The given ShapeTree must have - // elements for every leaf shape contained in the tuple. - static HloSharding Tuple(const ShapeTree<HloSharding>& sub_shardings) { - std::vector<HloSharding> flattened_list; - flattened_list.reserve( - std::distance(sub_shardings.leaf_begin(), sub_shardings.leaf_end())); - for (const auto& index_to_sharding : sub_shardings.leaves()) { - flattened_list.push_back(index_to_sharding.second); - } - return HloSharding(flattened_list); - } - // Create a new sharding from a protobuf OpSharding. static StatusOr<HloSharding> FromProto(const OpSharding& proto); @@ -89,89 +76,47 @@ class HloSharding { // Validate that this sharding can be applied to a tensor with shape `shape`. Status Validate(const Shape& shape, int64 num_devices) const; - // Returns true if the sharding has tuple type. - bool IsTuple() const { return tuple_; } - // Returns true if the sharding is trivial: replicate on all devices. - bool IsReplicated() const { - if (!IsTuple()) { - return replicated_; - } - return std::all_of(tuple_elements_.begin(), tuple_elements_.end(), - [](const HloSharding& s) { return s.IsReplicated(); }); - } + bool IsReplicated() const { return replicated_; } // Returns true if the tile size is the same as the input size. - bool IsTileMaximal() const { - if (!IsTuple()) { - return maximal_; - } - return std::all_of(tuple_elements_.begin(), tuple_elements_.end(), - [](const HloSharding& s) { return s.IsTileMaximal(); }); - } + bool IsTileMaximal() const { return maximal_; } // Returns true if the sharding defines an operation on the given device. bool UsesDevice(int64 device) const; // Returns the tile that should be executed on the given device. - // REQUIRES: !IsTuple() std::vector<int64> TileIndexForDevice(int64 device) const; // Returns the device that should execute the given tile. // It is an error to call this if is_replicated() is true. - // REQUIRES: !IsTuple() int64 DeviceForTileIndex(tensorflow::gtl::ArraySlice<int64> index) const; // Given a device ID, returns the offset within the input space of the // tile that should be executed on the given core. This returns the lower // extent of the tile in the input space. - // REQUIRES: !IsTuple() std::vector<int64> TileOffsetForDevice(int64 device) const; // Given a device ID, returns the limit within the input space of the // tile that should be executed on the given core. This returns the upper // extent of the tile in the input space. - // REQUIRES: !IsTuple() std::vector<int64> TileLimitForDevice(int64 device) const; // Returns the single device this op operates on. - // REQUIRES: !IsTuple&& !Replicated() && IsTileMaximal() + // Requires !Replicated() && IsTileMaximal(). StatusOr<int64> UniqueDevice() const; // Returns true if this op only uses a single device. - bool HasUniqueDevice() const; - - // Returns the ShapeTree containing the shardings for each element of this - // tuple. Only the leaf elements are populated. This creates a new ShapeTree - // object so is not cheap. REQUIRES: IsTuple() - ShapeTree<HloSharding> GetTupleShardingsAsShapeTree( - const Shape& tuple_shape) const { - ShapeTree<HloSharding> result(tuple_shape, HloSharding::Replicate()); - CHECK_EQ(std::distance(result.leaf_begin(), result.leaf_end()), - tuple_elements_.size()); - auto it = tuple_elements_.begin(); - for (auto& index_to_sharding : result.leaves()) { - index_to_sharding.second = *it++; - } - return result; - } + bool HasUniqueDevice() const { return !IsReplicated() && IsTileMaximal(); } bool operator==(const HloSharding& other) const { return replicated_ == other.replicated_ && maximal_ == other.maximal_ && protobuf_util::ProtobufEquals(tile_shape_, other.tile_shape_) && - tile_assignment_ == other.tile_assignment_ && - tuple_elements_ == other.tuple_elements_; + tile_assignment_ == other.tile_assignment_; } bool operator!=(const HloSharding& other) const { return !(*this == other); } size_t Hash() const { - if (!tuple_) { - size_t h = 0; - for (const auto& element : tuple_elements_) { - h = tensorflow::Hash64Combine(h, element.Hash()); - } - return h; - } if (replicated_) { return 0; } @@ -186,47 +131,33 @@ class HloSharding { } // Gets the tile shape. - // REQUIRES: !IsTileMaximal() && !IsTuple() + // It is an error to call this if IsTileMaximal() is true. const Shape& tile_shape() const { return tile_shape_; } // Gets the tile assignment tensor. - // REQUIRES: !IsReplicated() && !IsTuple() + // It is an error to call this if IsReplicated() is true. const Array<int64>& tile_assignment() const { return tile_assignment_; } private: HloSharding() : replicated_(true), maximal_(true), - tuple_(false), tile_shape_(), tile_assignment_({0}) {} explicit HloSharding(int64 device_id) : replicated_(false), maximal_(true), - tuple_(false), tile_shape_(), tile_assignment_({1}, device_id) {} HloSharding(const Shape& tile_shape, const Array<int64>& tile_assignment) : replicated_(false), maximal_(false), - tuple_(false), tile_shape_(tile_shape), tile_assignment_(tile_assignment) {} - HloSharding(const std::vector<HloSharding>& tuple_shardings) - : replicated_(false), - maximal_(false), - tuple_(true), - tile_assignment_({0}), - tuple_elements_(tuple_shardings) {} bool replicated_; bool maximal_; - bool tuple_; Shape tile_shape_; Array<int64> tile_assignment_; - // Only non-empty when tuple_ is true, but because empty tuples are allowed - // may also be empty even then. This is a flattened list of all the leaf - // shardings in a tuple shape, by pre-order walk (ShapeTree iterator order). - std::vector<HloSharding> tuple_elements_; }; } // namespace xla |