Diffstat (limited to 'tensorflow/compiler/xla/service/hlo_sharding.h')
 tensorflow/compiler/xla/service/hlo_sharding.h | 83 ++-------
 1 file changed, 7 insertions(+), 76 deletions(-)
diff --git a/tensorflow/compiler/xla/service/hlo_sharding.h b/tensorflow/compiler/xla/service/hlo_sharding.h
index f8ef2a3d05..d7ada30c70 100644
--- a/tensorflow/compiler/xla/service/hlo_sharding.h
+++ b/tensorflow/compiler/xla/service/hlo_sharding.h
@@ -24,7 +24,6 @@ limitations under the License.
#include "tensorflow/compiler/xla/array.h"
#include "tensorflow/compiler/xla/literal_util.h"
#include "tensorflow/compiler/xla/protobuf_util.h"
-#include "tensorflow/compiler/xla/shape_tree.h"
#include "tensorflow/compiler/xla/xla_data.pb.h"
#include "tensorflow/core/lib/gtl/array_slice.h"
#include "tensorflow/core/lib/hash/hash.h"
@@ -68,18 +67,6 @@ class HloSharding {
// `num_tiles` tiles.
static HloSharding Tile1D(const Shape& input_shape, int64 num_tiles);
- // Creates a new sharding for a tuple type. The given ShapeTree must have
- // elements for every leaf shape contained in the tuple.
- static HloSharding Tuple(const ShapeTree<HloSharding>& sub_shardings) {
- std::vector<HloSharding> flattened_list;
- flattened_list.reserve(
- std::distance(sub_shardings.leaf_begin(), sub_shardings.leaf_end()));
- for (const auto& index_to_sharding : sub_shardings.leaves()) {
- flattened_list.push_back(index_to_sharding.second);
- }
- return HloSharding(flattened_list);
- }
-
// Create a new sharding from a protobuf OpSharding.
static StatusOr<HloSharding> FromProto(const OpSharding& proto);
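
For reference, the factory removed above flattened the leaf shardings of a ShapeTree in pre-order. A minimal pre-change usage sketch follows; it assumes the ShapeUtil::MakeTupleShape, ShapeTree::mutable_element, HloSharding::Replicate, and HloSharding::AssignDevice APIs from elsewhere in the codebase, the tuple shape and device assignment are hypothetical, and this code no longer compiles after this commit:

#include "tensorflow/compiler/xla/service/hlo_sharding.h"
#include "tensorflow/compiler/xla/shape_tree.h"
#include "tensorflow/compiler/xla/shape_util.h"

namespace xla {

void TupleShardingExampleBeforeThisChange() {
  // A two-element tuple; every leaf shape needs its own sharding.
  const Shape tuple_shape = ShapeUtil::MakeTupleShape(
      {ShapeUtil::MakeShape(F32, {4}), ShapeUtil::MakeShape(F32, {8})});

  // Start fully replicated, then pin the second leaf to device 0.
  ShapeTree<HloSharding> sub_shardings(tuple_shape, HloSharding::Replicate());
  *sub_shardings.mutable_element({1}) = HloSharding::AssignDevice(0);

  // Flattened the leaf shardings in pre-order into one HloSharding.
  const HloSharding sharding = HloSharding::Tuple(sub_shardings);
}

}  // namespace xla
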
@@ -89,89 +76,47 @@ class HloSharding {
// Validate that this sharding can be applied to a tensor with shape `shape`.
Status Validate(const Shape& shape, int64 num_devices) const;
- // Returns true if the sharding has tuple type.
- bool IsTuple() const { return tuple_; }
-
// Returns true if the sharding is trivial: replicate on all devices.
- bool IsReplicated() const {
- if (!IsTuple()) {
- return replicated_;
- }
- return std::all_of(tuple_elements_.begin(), tuple_elements_.end(),
- [](const HloSharding& s) { return s.IsReplicated(); });
- }
+ bool IsReplicated() const { return replicated_; }
// Returns true if the tile size is the same as the input size.
- bool IsTileMaximal() const {
- if (!IsTuple()) {
- return maximal_;
- }
- return std::all_of(tuple_elements_.begin(), tuple_elements_.end(),
- [](const HloSharding& s) { return s.IsTileMaximal(); });
- }
+ bool IsTileMaximal() const { return maximal_; }
// Returns true if the sharding defines an operation on the given device.
bool UsesDevice(int64 device) const;
// Returns the tile that should be executed on the given device.
- // REQUIRES: !IsTuple()
std::vector<int64> TileIndexForDevice(int64 device) const;
// Returns the device that should execute the given tile.
// It is an error to call this if is_replicated() is true.
- // REQUIRES: !IsTuple()
int64 DeviceForTileIndex(tensorflow::gtl::ArraySlice<int64> index) const;
// Given a device ID, returns the offset within the input space of the
// tile that should be executed on the given core. This returns the lower
// extent of the tile in the input space.
- // REQUIRES: !IsTuple()
std::vector<int64> TileOffsetForDevice(int64 device) const;
// Given a device ID, returns the limit within the input space of the
// tile that should be executed on the given core. This returns the upper
// extent of the tile in the input space.
- // REQUIRES: !IsTuple()
std::vector<int64> TileLimitForDevice(int64 device) const;
// Returns the single device this op operates on.
- // REQUIRES: !IsTuple&& !Replicated() && IsTileMaximal()
+ // Requires !Replicated() && IsTileMaximal().
StatusOr<int64> UniqueDevice() const;
// Returns true if this op only uses a single device.
- bool HasUniqueDevice() const;
-
- // Returns the ShapeTree containing the shardings for each element of this
- // tuple. Only the leaf elements are populated. This creates a new ShapeTree
- // object so is not cheap. REQUIRES: IsTuple()
- ShapeTree<HloSharding> GetTupleShardingsAsShapeTree(
- const Shape& tuple_shape) const {
- ShapeTree<HloSharding> result(tuple_shape, HloSharding::Replicate());
- CHECK_EQ(std::distance(result.leaf_begin(), result.leaf_end()),
- tuple_elements_.size());
- auto it = tuple_elements_.begin();
- for (auto& index_to_sharding : result.leaves()) {
- index_to_sharding.second = *it++;
- }
- return result;
- }
+ bool HasUniqueDevice() const { return !IsReplicated() && IsTileMaximal(); }
bool operator==(const HloSharding& other) const {
return replicated_ == other.replicated_ && maximal_ == other.maximal_ &&
protobuf_util::ProtobufEquals(tile_shape_, other.tile_shape_) &&
- tile_assignment_ == other.tile_assignment_ &&
- tuple_elements_ == other.tuple_elements_;
+ tile_assignment_ == other.tile_assignment_;
}
bool operator!=(const HloSharding& other) const { return !(*this == other); }
size_t Hash() const {
- if (!tuple_) {
- size_t h = 0;
- for (const auto& element : tuple_elements_) {
- h = tensorflow::Hash64Combine(h, element.Hash());
- }
- return h;
- }
if (replicated_) {
return 0;
}
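
With the tuple case gone, HasUniqueDevice() reduces to the predicate !IsReplicated() && IsTileMaximal(), which is exactly the state produced by assigning a single device. A short sketch of the intended contract, assuming the Replicate and AssignDevice factories declared earlier in this header (not shown in these hunks):

#include "tensorflow/compiler/xla/service/hlo_sharding.h"

namespace xla {

void UniqueDeviceExample() {
  // AssignDevice produces a maximal, non-replicated sharding.
  const HloSharding single = HloSharding::AssignDevice(/*device_id=*/3);
  CHECK(single.IsTileMaximal());
  CHECK(!single.IsReplicated());
  CHECK(single.HasUniqueDevice());  // !IsReplicated() && IsTileMaximal()
  CHECK_EQ(single.UniqueDevice().ValueOrDie(), 3);

  // A replicated sharding runs everywhere, so no single device is unique.
  CHECK(!HloSharding::Replicate().HasUniqueDevice());
}

}  // namespace xla
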
@@ -186,47 +131,33 @@ class HloSharding {
}
// Gets the tile shape.
- // REQUIRES: !IsTileMaximal() && !IsTuple()
+ // It is an error to call this if IsTileMaximal() is true.
const Shape& tile_shape() const { return tile_shape_; }
// Gets the tile assignment tensor.
- // REQUIRES: !IsReplicated() && !IsTuple()
+ // It is an error to call this if IsReplicated() is true.
const Array<int64>& tile_assignment() const { return tile_assignment_; }
private:
HloSharding()
: replicated_(true),
maximal_(true),
- tuple_(false),
tile_shape_(),
tile_assignment_({0}) {}
explicit HloSharding(int64 device_id)
: replicated_(false),
maximal_(true),
- tuple_(false),
tile_shape_(),
tile_assignment_({1}, device_id) {}
HloSharding(const Shape& tile_shape, const Array<int64>& tile_assignment)
: replicated_(false),
maximal_(false),
- tuple_(false),
tile_shape_(tile_shape),
tile_assignment_(tile_assignment) {}
- HloSharding(const std::vector<HloSharding>& tuple_shardings)
- : replicated_(false),
- maximal_(false),
- tuple_(true),
- tile_assignment_({0}),
- tuple_elements_(tuple_shardings) {}
bool replicated_;
bool maximal_;
- bool tuple_;
Shape tile_shape_;
Array<int64> tile_assignment_;
- // Only non-empty when tuple_ is true, but because empty tuples are allowed
- // may also be empty even then. This is a flattened list of all the leaf
- // shardings in a tuple shape, by pre-order walk (ShapeTree iterator order).
- std::vector<HloSharding> tuple_elements_;
};
} // namespace xla
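
For the non-trivial case, the tile query methods kept by this change map a device to its region of the input space. A sketch using the Tile1D factory declared in this header; the shape, tile count, and expected results are illustrative and follow from the comments on TileIndexForDevice, TileOffsetForDevice, and TileLimitForDevice above:

#include "tensorflow/compiler/xla/service/hlo_sharding.h"
#include "tensorflow/compiler/xla/shape_util.h"

namespace xla {

void TiledShardingExample() {
  // Tile an f32[8] input along its leftmost (only) dimension into 2 tiles.
  const Shape input_shape = ShapeUtil::MakeShape(F32, {8});
  const HloSharding tiled = HloSharding::Tile1D(input_shape, /*num_tiles=*/2);

  // Device 1 owns the second f32[4] tile, i.e. input elements [4, 8).
  std::vector<int64> index = tiled.TileIndexForDevice(1);    // expect {1}
  std::vector<int64> offset = tiled.TileOffsetForDevice(1);  // expect {4}
  std::vector<int64> limit = tiled.TileLimitForDevice(1);    // expect {8}
}

}  // namespace xla
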