diff options
Diffstat (limited to 'tensorflow/compiler/xla/service/hlo_sharding.cc')
-rw-r--r-- | tensorflow/compiler/xla/service/hlo_sharding.cc | 74 |
1 files changed, 5 insertions, 69 deletions
diff --git a/tensorflow/compiler/xla/service/hlo_sharding.cc b/tensorflow/compiler/xla/service/hlo_sharding.cc index 7356663454..0d019d22f5 100644 --- a/tensorflow/compiler/xla/service/hlo_sharding.cc +++ b/tensorflow/compiler/xla/service/hlo_sharding.cc @@ -16,7 +16,6 @@ limitations under the License. #include "tensorflow/compiler/xla/service/hlo_sharding.h" #include "tensorflow/core/lib/core/errors.h" -#include "tensorflow/core/lib/strings/str_util.h" namespace xla { @@ -39,15 +38,6 @@ HloSharding HloSharding::Tile1D(const Shape& input_shape, int64 num_tiles) { } string HloSharding::ToString() const { - if (IsTuple()) { - std::vector<string> parts; - parts.reserve(tuple_elements_.size()); - for (const HloSharding& element : tuple_elements_) { - parts.push_back(element.ToString()); - } - return StrCat("{", tensorflow::str_util::Join(parts, ", "), "}"); - } - string result = StrCat("{", (replicated_ ? " replicated" : ""), (maximal_ ? " maximal" : "")); @@ -63,11 +53,6 @@ string HloSharding::ToString() const { } bool HloSharding::UsesDevice(int64 device) const { - if (IsTuple()) { - return std::any_of( - tuple_elements_.begin(), tuple_elements_.end(), - [&](const HloSharding& s) { return s.UsesDevice(device); }); - } const auto& devices = tile_assignment_; return replicated_ || std::find(devices.begin(), devices.end(), device) != devices.end(); @@ -76,7 +61,6 @@ bool HloSharding::UsesDevice(int64 device) const { std::vector<int64> HloSharding::TileIndexForDevice(int64 device) const { CHECK(!ShapeUtil::IsTuple(tile_shape_)); CHECK(!maximal_); - CHECK(!IsTuple()); std::vector<int64> ret_index; tile_assignment_.Each([&](tensorflow::gtl::ArraySlice<int64> index, int64 d) { if (d == device) { @@ -90,7 +74,6 @@ std::vector<int64> HloSharding::TileIndexForDevice(int64 device) const { int64 HloSharding::DeviceForTileIndex( tensorflow::gtl::ArraySlice<int64> index) const { CHECK(!replicated_); - CHECK(!IsTuple()); if (maximal_) { return *tile_assignment_.begin(); } @@ -99,7 +82,7 @@ int64 HloSharding::DeviceForTileIndex( } std::vector<int64> HloSharding::TileOffsetForDevice(int64 device) const { - CHECK(!IsTuple()); + CHECK(!ShapeUtil::IsTuple(tile_shape_)); std::vector<int64> index = TileIndexForDevice(device); if (maximal_) { @@ -114,7 +97,7 @@ std::vector<int64> HloSharding::TileOffsetForDevice(int64 device) const { } std::vector<int64> HloSharding::TileLimitForDevice(int64 device) const { - CHECK(!IsTuple()); + CHECK(!ShapeUtil::IsTuple(tile_shape_)); CHECK(!maximal_); // Maximal shardings do not have a valid tile shape. std::vector<int64> index = TileIndexForDevice(device); @@ -125,41 +108,13 @@ std::vector<int64> HloSharding::TileLimitForDevice(int64 device) const { } StatusOr<int64> HloSharding::UniqueDevice() const { - if (IsTuple()) { - if (tuple_elements_.empty()) { - return tensorflow::errors::InvalidArgument( - "UniqueDevice() called on empty tuple"); - } - std::vector<StatusOr<int64>> results; - std::transform(tuple_elements_.begin(), tuple_elements_.end(), - std::back_inserter(results), - [](const HloSharding& s) { return s.UniqueDevice(); }); - if (std::all_of(results.begin(), results.end(), - [&](const StatusOr<int64>& s) { - return s.ok() && results[0].ok() && - s.ValueOrDie() == results[0].ValueOrDie(); - })) { - return results[0]; - } else { - return tensorflow::errors::InvalidArgument( - "Tuple did not contain a unique device"); - } - } - if (!replicated_ && maximal_ && !IsTuple()) { + if (!replicated_ && maximal_) { return static_cast<int64>(*tile_assignment_.begin()); } return tensorflow::errors::InvalidArgument( "UniqueDevice() called on sharding that executes on multiple devices"); } -bool HloSharding::HasUniqueDevice() const { - if (IsTuple()) { - return UniqueDevice().status().ok(); - } else { - return !IsReplicated() && IsTileMaximal(); - } -} - Status HloSharding::Validate(const Shape& shape, int64 num_devices) const { if (replicated_) { return Status::OK(); @@ -238,19 +193,9 @@ Status HloSharding::Validate(const Shape& shape, int64 num_devices) const { /*static*/ StatusOr<HloSharding> HloSharding::FromProto( const OpSharding& proto) { - if (proto.type() == OpSharding::Type::OpSharding_Type_TUPLE) { - std::vector<HloSharding> tuple_shardings; - tuple_shardings.reserve(proto.tuple_shardings().size()); - for (const OpSharding& tuple_sharding_proto : proto.tuple_shardings()) { - TF_ASSIGN_OR_RETURN(HloSharding sharding, - HloSharding::FromProto(tuple_sharding_proto)); - tuple_shardings.push_back(sharding); - } - return HloSharding(tuple_shardings); - } else if (proto.type() == OpSharding::Type::OpSharding_Type_REPLICATED) { + if (proto.type() == OpSharding::Type::OpSharding_Type_REPLICATED) { return Replicate(); - } else if (proto.type() == OpSharding::Type::OpSharding_Type_MAXIMAL || - proto.tile_assignment_devices().size() == 1) { + } else if (proto.type() == OpSharding::Type::OpSharding_Type_MAXIMAL) { return HloSharding(proto.tile_assignment_devices(0)); } // Some versions of gcc cannot infer the TileAssignment constructor from a @@ -267,15 +212,6 @@ Status HloSharding::Validate(const Shape& shape, int64 num_devices) const { OpSharding HloSharding::ToProto() const { OpSharding result; - - if (IsTuple()) { - for (const HloSharding& element : tuple_elements_) { - *result.add_tuple_shardings() = element.ToProto(); - } - result.set_type(OpSharding::Type::OpSharding_Type_TUPLE); - return result; - } - *result.mutable_tile_shape() = tile_shape_; for (int64 dim : tile_assignment_.dimensions()) { result.add_tile_assignment_dimensions(dim); |