diff options
author | A. Unique TensorFlower <gardener@tensorflow.org> | 2018-06-20 16:40:21 -0700 |
---|---|---|
committer | TensorFlower Gardener <gardener@tensorflow.org> | 2018-06-20 16:43:36 -0700 |
commit | 2d6d0351a5440db144ea42b8ae19b9ee7952a7a5 (patch) | |
tree | 812707983c7505f6d9b2a12a3d09b3b9ce6eae54 /tensorflow/compiler/xla/service/hlo_sharding.h | |
parent | 34a12dff9812d291dff494dae9abecc13b494b8a (diff) |
Propagate dominant devices to kWhile computations.
PiperOrigin-RevId: 201439537
Diffstat (limited to 'tensorflow/compiler/xla/service/hlo_sharding.h')
-rw-r--r-- | tensorflow/compiler/xla/service/hlo_sharding.h | 31 |
1 files changed, 11 insertions, 20 deletions
diff --git a/tensorflow/compiler/xla/service/hlo_sharding.h b/tensorflow/compiler/xla/service/hlo_sharding.h index 1e843481c3..34324d2058 100644 --- a/tensorflow/compiler/xla/service/hlo_sharding.h +++ b/tensorflow/compiler/xla/service/hlo_sharding.h @@ -19,7 +19,9 @@ limitations under the License. #ifndef TENSORFLOW_COMPILER_XLA_SERVICE_HLO_SHARDING_H_ #define TENSORFLOW_COMPILER_XLA_SERVICE_HLO_SHARDING_H_ +#include <map> #include <string> +#include <vector> #include "tensorflow/compiler/xla/array.h" #include "tensorflow/compiler/xla/literal_util.h" @@ -118,6 +120,14 @@ class HloSharding { // Returns true if the sharding defines an operation on the given device. bool UsesDevice(int64 device) const; + // Retrieves an histogram of the devices used by the sharding. The returned + // map has the device number as key, and the occurrence count as value. + // If a sharding does not have a device, it will not be incuded in the + // histogram. The count argument, if not nullptr, will receive the total + // number of elements this sharding is made of (one for array, N leaves for + // tuples). + std::map<int64, int64> UsedDevices(int64* count) const; + // Returns the tile that should be executed on the given device. // REQUIRES: !IsTuple() std::vector<int64> TileIndexForDevice(int64 device) const; @@ -179,26 +189,7 @@ class HloSharding { } bool operator!=(const HloSharding& other) const { return !(*this == other); } - size_t Hash() const { - if (!tuple_) { - size_t h = 0; - for (const auto& element : tuple_elements_) { - h = tensorflow::Hash64Combine(h, element.Hash()); - } - return h; - } - if (replicated_) { - return 0; - } - size_t h = 0; - for (uint32 v : tile_assignment_) { - h = tensorflow::Hash64Combine(h, std::hash<uint32>{}(v)); - } - for (uint32 v : tile_shape_.dimensions()) { - h = tensorflow::Hash64Combine(h, std::hash<uint32>{}(v)); - } - return h; - } + size_t Hash() const; struct Hasher { size_t operator()(const HloSharding& sharding) const { |