path: root/tensorflow/compiler/xla/service/hlo_sharding.h
author	A. Unique TensorFlower <gardener@tensorflow.org>	2018-06-20 16:40:21 -0700
committer	TensorFlower Gardener <gardener@tensorflow.org>	2018-06-20 16:43:36 -0700
commit	2d6d0351a5440db144ea42b8ae19b9ee7952a7a5 (patch)
tree	812707983c7505f6d9b2a12a3d09b3b9ce6eae54 /tensorflow/compiler/xla/service/hlo_sharding.h
parent	34a12dff9812d291dff494dae9abecc13b494b8a (diff)
Propagate dominant devices to kWhile computations.
PiperOrigin-RevId: 201439537
Diffstat (limited to 'tensorflow/compiler/xla/service/hlo_sharding.h')
-rw-r--r--	tensorflow/compiler/xla/service/hlo_sharding.h	31
1 file changed, 11 insertions(+), 20 deletions(-)
diff --git a/tensorflow/compiler/xla/service/hlo_sharding.h b/tensorflow/compiler/xla/service/hlo_sharding.h
index 1e843481c3..34324d2058 100644
--- a/tensorflow/compiler/xla/service/hlo_sharding.h
+++ b/tensorflow/compiler/xla/service/hlo_sharding.h
@@ -19,7 +19,9 @@ limitations under the License.
#ifndef TENSORFLOW_COMPILER_XLA_SERVICE_HLO_SHARDING_H_
#define TENSORFLOW_COMPILER_XLA_SERVICE_HLO_SHARDING_H_
+#include <map>
#include <string>
+#include <vector>
#include "tensorflow/compiler/xla/array.h"
#include "tensorflow/compiler/xla/literal_util.h"
@@ -118,6 +120,14 @@ class HloSharding {
// Returns true if the sharding defines an operation on the given device.
bool UsesDevice(int64 device) const;
+ // Retrieves a histogram of the devices used by the sharding. The returned
+ // map has the device number as key, and the occurrence count as value.
+ // If a sharding does not have a device, it will not be included in the
+ // histogram. The count argument, if not nullptr, will receive the total
+ // number of elements this sharding is made of (one for array, N leaves for
+ // tuples).
+ std::map<int64, int64> UsedDevices(int64* count) const;
+
// Returns the tile that should be executed on the given device.
// REQUIRES: !IsTuple()
std::vector<int64> TileIndexForDevice(int64 device) const;
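
The histogram returned by UsedDevices() is what a pass can use to pick the dominant device mentioned in the commit message. Below is a minimal sketch of how a caller might consume it; PickDominantDevice is a hypothetical helper written for illustration and is not part of this change:

#include <map>

#include "tensorflow/compiler/xla/service/hlo_sharding.h"

namespace xla {

// Hypothetical helper (not from this commit): returns the device that appears
// most often in the sharding's device histogram, or -1 if the sharding does
// not reference any device (e.g. a replicated sharding).
int64 PickDominantDevice(const HloSharding& sharding) {
  int64 total_leaves = 0;
  std::map<int64, int64> histogram = sharding.UsedDevices(&total_leaves);
  int64 best_device = -1;
  int64 best_count = 0;
  for (const auto& device_and_count : histogram) {
    if (device_and_count.second > best_count) {
      best_device = device_and_count.first;
      best_count = device_and_count.second;
    }
  }
  return best_device;
}

}  // namespace xla
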
@@ -179,26 +189,7 @@ class HloSharding {
}
bool operator!=(const HloSharding& other) const { return !(*this == other); }
- size_t Hash() const {
- if (tuple_) {
- size_t h = 0;
- for (const auto& element : tuple_elements_) {
- h = tensorflow::Hash64Combine(h, element.Hash());
- }
- return h;
- }
- if (replicated_) {
- return 0;
- }
- size_t h = 0;
- for (uint32 v : tile_assignment_) {
- h = tensorflow::Hash64Combine(h, std::hash<uint32>{}(v));
- }
- for (uint32 v : tile_shape_.dimensions()) {
- h = tensorflow::Hash64Combine(h, std::hash<uint32>{}(v));
- }
- return h;
- }
+ size_t Hash() const;
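
With the body removed from the header, Hash() is now only declared here; the definition presumably moves out of line (the destination file is not part of this diff). The Hasher functor in the context lines below keeps HloSharding usable as a key in hash-based containers; an illustrative sketch, not code from this commit:

#include <unordered_set>

#include "tensorflow/compiler/xla/service/hlo_sharding.h"

// Illustrative only: HloSharding::Hasher, together with operator==, lets
// shardings serve as keys in standard hash containers.
using ShardingSet =
    std::unordered_set<xla::HloSharding, xla::HloSharding::Hasher>;

// Returns true if an equal sharding was already present in the set.
bool SeenBefore(const xla::HloSharding& sharding, ShardingSet* seen) {
  // insert() returns an {iterator, inserted} pair; "inserted" is false when
  // an equal sharding is already in the set.
  return !seen->insert(sharding).second;
}
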
struct Hasher {
size_t operator()(const HloSharding& sharding) const {