diff options
author | 2018-07-31 23:18:58 -0700 | |
---|---|---|
committer | 2018-07-31 23:23:48 -0700 | |
commit | 26ba623dccacfb2f913951e12089a8340e6a11ac (patch) | |
tree | bfa67abb9eb1d408bb1a4a52b8c7c5c3d470b00f /tensorflow/compiler/xla/service/hlo_sharding.cc | |
parent | 92279f8bfa6ce2124439aabfa6db84d722dc2b66 (diff) |
Cleanup the sharding unique device API.
PiperOrigin-RevId: 206885051
Diffstat (limited to 'tensorflow/compiler/xla/service/hlo_sharding.cc')
-rw-r--r-- | tensorflow/compiler/xla/service/hlo_sharding.cc | 49 |
1 files changed, 20 insertions, 29 deletions
diff --git a/tensorflow/compiler/xla/service/hlo_sharding.cc b/tensorflow/compiler/xla/service/hlo_sharding.cc
index 393944c20f..6399f6ef3c 100644
--- a/tensorflow/compiler/xla/service/hlo_sharding.cc
+++ b/tensorflow/compiler/xla/service/hlo_sharding.cc
@@ -127,15 +127,15 @@ std::map<int64, int64> HloSharding::UsedDevices(int64* count) const {
   if (IsTuple()) {
     for (auto& tuple_element_sharding : tuple_elements()) {
       auto unique_device = tuple_element_sharding.UniqueDevice();
-      if (unique_device.ok()) {
-        device_map[unique_device.ValueOrDie()] += 1;
+      if (unique_device) {
+        device_map[*unique_device] += 1;
       }
     }
     element_count = tuple_elements().size();
   } else {
     auto unique_device = UniqueDevice();
-    if (unique_device.ok()) {
-      device_map[unique_device.ValueOrDie()] += 1;
+    if (unique_device) {
+      device_map[*unique_device] += 1;
     }
   }
   if (count != nullptr) {
@@ -238,40 +238,31 @@ StatusOr<HloSharding> HloSharding::GetTupleSharding(const Shape& shape) const {
   return Tuple(ShapeTree<HloSharding>(shape, *this));
 }
 
-StatusOr<int64> HloSharding::UniqueDevice() const {
+tensorflow::gtl::optional<int64> HloSharding::UniqueDevice() const {
   if (IsTuple()) {
     if (tuple_elements_.empty()) {
-      return tensorflow::errors::InvalidArgument(
-          "UniqueDevice() called on empty tuple");
+      return tensorflow::gtl::nullopt;
     }
-    std::vector<StatusOr<int64>> results;
-    std::transform(tuple_elements_.begin(), tuple_elements_.end(),
-                   std::back_inserter(results),
-                   [](const HloSharding& s) { return s.UniqueDevice(); });
-    if (std::all_of(results.begin(), results.end(),
-                    [&](const StatusOr<int64>& s) {
-                      return s.ok() && results[0].ok() &&
-                             s.ValueOrDie() == results[0].ValueOrDie();
-                    })) {
-      return results[0];
-    } else {
-      return tensorflow::errors::InvalidArgument(
-          "Tuple did not contain a unique device");
+    tensorflow::gtl::optional<int64> unique_device;
+    for (auto& tuple_sharding : tuple_elements_) {
+      auto device = tuple_sharding.UniqueDevice();
+      if (!device || (unique_device && *device != *unique_device)) {
+        return tensorflow::gtl::nullopt;
+      }
+      unique_device = device;
     }
+    return unique_device;
   }
-  if (!replicated_ && maximal_ && !IsTuple()) {
+  if (!replicated_ && maximal_) {
     return static_cast<int64>(*tile_assignment_.begin());
   }
-  return tensorflow::errors::InvalidArgument(
-      "UniqueDevice() called on sharding that executes on multiple devices");
+  return tensorflow::gtl::nullopt;
 }
 
-bool HloSharding::HasUniqueDevice() const {
-  if (IsTuple()) {
-    return UniqueDevice().status().ok();
-  } else {
-    return !IsReplicated() && IsTileMaximal();
-  }
+int64 HloSharding::GetUniqueDevice() const {
+  auto device = UniqueDevice();
+  CHECK(device) << "Sharding does not have a unique device: " << *this;
+  return *device;
 }
 
 Status HloSharding::ValidateTuple(const Shape& shape, int64 num_devices) const {