diff options (stat)
-rw-r--r--  tensorflow/compiler/xla/service/hlo_graph_dumper.cc | 8 ++++----
-rw-r--r--  tensorflow/compiler/xla/service/hlo_sharding.cc     | 3 ---
-rw-r--r--  tensorflow/compiler/xla/service/hlo_sharding.h      | 9 +++++++++
3 files changed, 13 insertions(+), 7 deletions(-)
diff --git a/tensorflow/compiler/xla/service/hlo_graph_dumper.cc b/tensorflow/compiler/xla/service/hlo_graph_dumper.cc
index a2cb21c09b..efdeb6c64f 100644
--- a/tensorflow/compiler/xla/service/hlo_graph_dumper.cc
+++ b/tensorflow/compiler/xla/service/hlo_graph_dumper.cc
@@ -427,7 +427,8 @@ class HloDotDumper {
 
   // When coloring by sharding information, we track the sharding string
   // representation to color association, by round-robin the color schemes.
-  std::unordered_map<string, ColorScheme> sharding_colors_;
+  std::unordered_map<HloSharding, ColorScheme, HloSharding::Hasher>
+      sharding_colors_;
   int64 next_shard_color_ = 0;
 };
 
@@ -882,14 +883,13 @@ ColorScheme HloDotDumper::GetInstructionColor(const HloInstruction* instr) {
     if (!instr->has_sharding()) {
       return kDashedBorder;
     }
-    string shard_str = instr->sharding().ToString();
-    auto it = sharding_colors_.find(shard_str);
+    auto it = sharding_colors_.find(instr->sharding());
     if (it != sharding_colors_.end()) {
       return it->second;
     }
     ColorScheme color = static_cast<ColorScheme>(
         kBlue + (next_shard_color_++ % (kDashedBorder - kBlue)));
-    sharding_colors_.emplace(shard_str, color);
+    sharding_colors_.emplace(instr->sharding(), color);
     return color;
   }
   const auto kParameterColor = kOrange;
diff --git a/tensorflow/compiler/xla/service/hlo_sharding.cc b/tensorflow/compiler/xla/service/hlo_sharding.cc
index 7f7e3f7dab..7708422ce1 100644
--- a/tensorflow/compiler/xla/service/hlo_sharding.cc
+++ b/tensorflow/compiler/xla/service/hlo_sharding.cc
@@ -49,9 +49,6 @@ string HloSharding::ToString() const {
     return StrCat("{", tensorflow::str_util::Join(parts, ", "), "}");
   }
-  string result = StrCat("{", (replicated_ ? " replicated" : ""),
-                         (maximal_ ? " maximal" : ""));
-
   if (replicated_) {
     return "{replicated}";
   } else if (maximal_) {
     return StrCat(
diff --git a/tensorflow/compiler/xla/service/hlo_sharding.h b/tensorflow/compiler/xla/service/hlo_sharding.h
index 2b8e757f42..e8bb06c8f7 100644
--- a/tensorflow/compiler/xla/service/hlo_sharding.h
+++ b/tensorflow/compiler/xla/service/hlo_sharding.h
@@ -99,6 +99,9 @@ class HloSharding {
   static bool IsReservedDevice(int64 device) { return device < 0; }
 
   OpSharding ToProto() const;
+
+  // Note that this string canonically has outer curly braces, e.g.
+  // "{replicated}".
   string ToString() const;
 
   // Validate that this sharding can be applied to a tensor with shape `shape`.
@@ -208,6 +211,12 @@ class HloSharding {
     return h;
   }
 
+  struct Hasher {
+    size_t operator()(const HloSharding& sharding) const {
+      return sharding.Hash();
+    }
+  };
+
   // Gets the tile shape.
   // REQUIRES: !IsTileMaximal() && !IsTuple()
   const Shape& tile_shape() const { return tile_shape_; }