path: root/tensorflow/compiler/xla/service/hlo_sharding.cc
Diffstat (limited to 'tensorflow/compiler/xla/service/hlo_sharding.cc')
-rw-r--r-- tensorflow/compiler/xla/service/hlo_sharding.cc | 74
1 file changed, 5 insertions(+), 69 deletions(-)
diff --git a/tensorflow/compiler/xla/service/hlo_sharding.cc b/tensorflow/compiler/xla/service/hlo_sharding.cc
index 7356663454..0d019d22f5 100644
--- a/tensorflow/compiler/xla/service/hlo_sharding.cc
+++ b/tensorflow/compiler/xla/service/hlo_sharding.cc
@@ -16,7 +16,6 @@ limitations under the License.
#include "tensorflow/compiler/xla/service/hlo_sharding.h"
#include "tensorflow/core/lib/core/errors.h"
-#include "tensorflow/core/lib/strings/str_util.h"
namespace xla {
@@ -39,15 +38,6 @@ HloSharding HloSharding::Tile1D(const Shape& input_shape, int64 num_tiles) {
}
string HloSharding::ToString() const {
- if (IsTuple()) {
- std::vector<string> parts;
- parts.reserve(tuple_elements_.size());
- for (const HloSharding& element : tuple_elements_) {
- parts.push_back(element.ToString());
- }
- return StrCat("{", tensorflow::str_util::Join(parts, ", "), "}");
- }
-
string result = StrCat("{", (replicated_ ? " replicated" : ""),
(maximal_ ? " maximal" : ""));
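With the tuple branch gone, ToString() renders only a single non-tuple sharding. A minimal standalone sketch of the prefix logic visible in this hunk (the free function and bool parameters are illustrative stand-ins for the private replicated_/maximal_ members; the tail of the string is built by code elided here):

    #include <string>

    // Mirrors the StrCat call above: "{" plus optional " replicated" /
    // " maximal" markers; the tile shape and assignment follow in the
    // elided part of the function.
    std::string ToStringPrefix(bool replicated, bool maximal) {
      return std::string("{") + (replicated ? " replicated" : "") +
             (maximal ? " maximal" : "");
    }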
@@ -63,11 +53,6 @@ string HloSharding::ToString() const {
}
bool HloSharding::UsesDevice(int64 device) const {
- if (IsTuple()) {
- return std::any_of(
- tuple_elements_.begin(), tuple_elements_.end(),
- [&](const HloSharding& s) { return s.UsesDevice(device); });
- }
const auto& devices = tile_assignment_;
return replicated_ ||
std::find(devices.begin(), devices.end(), device) != devices.end();
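The simplified UsesDevice() is now a plain membership test. A self-contained analogue, with std::vector standing in for the tile-assignment array type:

    #include <algorithm>
    #include <cstdint>
    #include <vector>

    // A replicated sharding runs on every device; otherwise the device
    // must appear somewhere in the tile assignment.
    bool UsesDevice(bool replicated,
                    const std::vector<int64_t>& tile_assignment,
                    int64_t device) {
      return replicated ||
             std::find(tile_assignment.begin(), tile_assignment.end(),
                       device) != tile_assignment.end();
    }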
@@ -76,7 +61,6 @@ bool HloSharding::UsesDevice(int64 device) const {
std::vector<int64> HloSharding::TileIndexForDevice(int64 device) const {
CHECK(!ShapeUtil::IsTuple(tile_shape_));
CHECK(!maximal_);
- CHECK(!IsTuple());
std::vector<int64> ret_index;
tile_assignment_.Each([&](tensorflow::gtl::ArraySlice<int64> index, int64 d) {
if (d == device) {
@@ -90,7 +74,6 @@ std::vector<int64> HloSharding::TileIndexForDevice(int64 device) const {
int64 HloSharding::DeviceForTileIndex(
tensorflow::gtl::ArraySlice<int64> index) const {
CHECK(!replicated_);
- CHECK(!IsTuple());
if (maximal_) {
return *tile_assignment_.begin();
}
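TileIndexForDevice() scans the assignment array for a device and records its multi-dimensional index; DeviceForTileIndex() is the inverse lookup, short-circuiting to the single device when the sharding is maximal. A flattened standalone analogue of the inverse lookup (the row-major layout is an assumption for illustration, not taken from the elided Array implementation):

    #include <cstdint>
    #include <vector>

    // Given a multi-dimensional tile index and the dimensions of the
    // assignment array, return the device stored there (row-major).
    int64_t DeviceAt(const std::vector<int64_t>& index,
                     const std::vector<int64_t>& dims,
                     const std::vector<int64_t>& devices) {
      int64_t flat = 0;
      for (size_t i = 0; i < dims.size(); ++i) {
        flat = flat * dims[i] + index[i];
      }
      return devices[flat];
    }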
@@ -99,7 +82,7 @@ int64 HloSharding::DeviceForTileIndex(
}
std::vector<int64> HloSharding::TileOffsetForDevice(int64 device) const {
- CHECK(!IsTuple());
+ CHECK(!ShapeUtil::IsTuple(tile_shape_));
std::vector<int64> index = TileIndexForDevice(device);
if (maximal_) {
@@ -114,7 +97,7 @@ std::vector<int64> HloSharding::TileOffsetForDevice(int64 device) const {
}
std::vector<int64> HloSharding::TileLimitForDevice(int64 device) const {
- CHECK(!IsTuple());
+ CHECK(!ShapeUtil::IsTuple(tile_shape_));
CHECK(!maximal_); // Maximal shardings do not have a valid tile shape.
std::vector<int64> index = TileIndexForDevice(device);
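Only the leading CHECKs of TileOffsetForDevice()/TileLimitForDevice() change in these two hunks; the arithmetic itself is elided. For intuition, a hypothetical sketch of what such an offset computation typically looks like (an assumption about the elided body, not a copy of it):

    #include <cstdint>
    #include <vector>

    // Hypothetical: per-dimension offset of a device's tile, assuming the
    // elided code multiplies the tile index by the tile extent; the limit
    // would add the extent again, clamped to the full shape.
    std::vector<int64_t> TileOffset(const std::vector<int64_t>& tile_index,
                                    const std::vector<int64_t>& tile_dims) {
      std::vector<int64_t> offset(tile_index.size());
      for (size_t i = 0; i < tile_index.size(); ++i) {
        offset[i] = tile_index[i] * tile_dims[i];
      }
      return offset;
    }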
@@ -125,41 +108,13 @@ std::vector<int64> HloSharding::TileLimitForDevice(int64 device) const {
}
StatusOr<int64> HloSharding::UniqueDevice() const {
- if (IsTuple()) {
- if (tuple_elements_.empty()) {
- return tensorflow::errors::InvalidArgument(
- "UniqueDevice() called on empty tuple");
- }
- std::vector<StatusOr<int64>> results;
- std::transform(tuple_elements_.begin(), tuple_elements_.end(),
- std::back_inserter(results),
- [](const HloSharding& s) { return s.UniqueDevice(); });
- if (std::all_of(results.begin(), results.end(),
- [&](const StatusOr<int64>& s) {
- return s.ok() && results[0].ok() &&
- s.ValueOrDie() == results[0].ValueOrDie();
- })) {
- return results[0];
- } else {
- return tensorflow::errors::InvalidArgument(
- "Tuple did not contain a unique device");
- }
- }
- if (!replicated_ && maximal_ && !IsTuple()) {
+ if (!replicated_ && maximal_) {
return static_cast<int64>(*tile_assignment_.begin());
}
return tensorflow::errors::InvalidArgument(
"UniqueDevice() called on sharding that executes on multiple devices");
}
-bool HloSharding::HasUniqueDevice() const {
- if (IsTuple()) {
- return UniqueDevice().status().ok();
- } else {
- return !IsReplicated() && IsTileMaximal();
- }
-}
-
Status HloSharding::Validate(const Shape& shape, int64 num_devices) const {
if (replicated_) {
return Status::OK();
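After the hunk above, only a non-replicated, maximal sharding reports a unique device, and the tuple-aware HasUniqueDevice() is removed from this file entirely. A self-contained analogue of the remaining logic, with std::optional standing in for StatusOr:

    #include <cstdint>
    #include <optional>
    #include <vector>

    // A non-replicated, maximal sharding names exactly one device: the
    // first entry of its tile assignment. Anything else spans multiple
    // devices, which the real code reports as InvalidArgument.
    std::optional<int64_t> UniqueDevice(bool replicated, bool maximal,
                                        const std::vector<int64_t>& devices) {
      if (!replicated && maximal && !devices.empty()) {
        return devices.front();
      }
      return std::nullopt;
    }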
@@ -238,19 +193,9 @@ Status HloSharding::Validate(const Shape& shape, int64 num_devices) const {
/*static*/ StatusOr<HloSharding> HloSharding::FromProto(
const OpSharding& proto) {
- if (proto.type() == OpSharding::Type::OpSharding_Type_TUPLE) {
- std::vector<HloSharding> tuple_shardings;
- tuple_shardings.reserve(proto.tuple_shardings().size());
- for (const OpSharding& tuple_sharding_proto : proto.tuple_shardings()) {
- TF_ASSIGN_OR_RETURN(HloSharding sharding,
- HloSharding::FromProto(tuple_sharding_proto));
- tuple_shardings.push_back(sharding);
- }
- return HloSharding(tuple_shardings);
- } else if (proto.type() == OpSharding::Type::OpSharding_Type_REPLICATED) {
+ if (proto.type() == OpSharding::Type::OpSharding_Type_REPLICATED) {
return Replicate();
- } else if (proto.type() == OpSharding::Type::OpSharding_Type_MAXIMAL ||
- proto.tile_assignment_devices().size() == 1) {
+ } else if (proto.type() == OpSharding::Type::OpSharding_Type_MAXIMAL) {
return HloSharding(proto.tile_assignment_devices(0));
}
// Some versions of gcc cannot infer the TileAssignment constructor from a
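A short usage sketch of the simplified dispatch, assuming an XLA build environment (the proto setter mirrors the proto.type() accessor used above; a TUPLE-typed proto now falls through to the tiled case instead of being handled specially):

    #include "tensorflow/compiler/xla/service/hlo_sharding.h"

    // Build a REPLICATED OpSharding and parse it back through the
    // simplified FromProto() shown above.
    xla::StatusOr<xla::HloSharding> ParseReplicated() {
      xla::OpSharding proto;
      proto.set_type(xla::OpSharding::Type::OpSharding_Type_REPLICATED);
      return xla::HloSharding::FromProto(proto);
    }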
@@ -267,15 +212,6 @@ Status HloSharding::Validate(const Shape& shape, int64 num_devices) const {
OpSharding HloSharding::ToProto() const {
OpSharding result;
-
- if (IsTuple()) {
- for (const HloSharding& element : tuple_elements_) {
- *result.add_tuple_shardings() = element.ToProto();
- }
- result.set_type(OpSharding::Type::OpSharding_Type_TUPLE);
- return result;
- }
-
*result.mutable_tile_shape() = tile_shape_;
for (int64 dim : tile_assignment_.dimensions()) {
result.add_tile_assignment_dimensions(dim);
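With the tuple branch removed from ToProto() as well, a proto round trip now covers only replicated, maximal, and tiled shardings. A hedged sketch of such a check, assuming the equality operator declared in hlo_sharding.h:

    #include "tensorflow/compiler/xla/service/hlo_sharding.h"

    // Serialize a replicated sharding and parse it back; Replicate() and
    // FromProto() are the entry points visible in this diff.
    bool RoundTripsReplicated() {
      xla::HloSharding original = xla::HloSharding::Replicate();
      xla::OpSharding proto = original.ToProto();
      auto restored = xla::HloSharding::FromProto(proto);
      return restored.ok() && restored.ValueOrDie() == original;
    }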