aboutsummaryrefslogtreecommitdiffhomepage
diff options
context:
space:
mode:
authorGravatar Justin Lebar <jlebar@google.com>2018-05-29 19:07:32 -0700
committerGravatar TensorFlower Gardener <gardener@tensorflow.org>2018-05-29 19:10:10 -0700
commit28cec60df3397ed16c9897a2d1e26eea622ad3be (patch)
treeaec4277a03b80d59f6552a70baf83d9c1f9e7e52
parent02ba49573008c22758fb90c8e26dde24406c1584 (diff)
[XLA] Minor HloSharding cleanups.
Delete dead code in HloSharding::ToString(), and add and use proper hasher struct. PiperOrigin-RevId: 198493972
-rw-r--r--tensorflow/compiler/xla/service/hlo_graph_dumper.cc8
-rw-r--r--tensorflow/compiler/xla/service/hlo_sharding.cc3
-rw-r--r--tensorflow/compiler/xla/service/hlo_sharding.h9
3 files changed, 13 insertions, 7 deletions
diff --git a/tensorflow/compiler/xla/service/hlo_graph_dumper.cc b/tensorflow/compiler/xla/service/hlo_graph_dumper.cc
index a2cb21c09b..efdeb6c64f 100644
--- a/tensorflow/compiler/xla/service/hlo_graph_dumper.cc
+++ b/tensorflow/compiler/xla/service/hlo_graph_dumper.cc
@@ -427,7 +427,8 @@ class HloDotDumper {
// When coloring by sharding information, we track the sharding string
// representation to color association, by round-robin the color schemes.
- std::unordered_map<string, ColorScheme> sharding_colors_;
+ std::unordered_map<HloSharding, ColorScheme, HloSharding::Hasher>
+ sharding_colors_;
int64 next_shard_color_ = 0;
};
@@ -882,14 +883,13 @@ ColorScheme HloDotDumper::GetInstructionColor(const HloInstruction* instr) {
if (!instr->has_sharding()) {
return kDashedBorder;
}
- string shard_str = instr->sharding().ToString();
- auto it = sharding_colors_.find(shard_str);
+ auto it = sharding_colors_.find(instr->sharding());
if (it != sharding_colors_.end()) {
return it->second;
}
ColorScheme color = static_cast<ColorScheme>(
kBlue + (next_shard_color_++ % (kDashedBorder - kBlue)));
- sharding_colors_.emplace(shard_str, color);
+ sharding_colors_.emplace(instr->sharding(), color);
return color;
}
const auto kParameterColor = kOrange;
diff --git a/tensorflow/compiler/xla/service/hlo_sharding.cc b/tensorflow/compiler/xla/service/hlo_sharding.cc
index 7f7e3f7dab..7708422ce1 100644
--- a/tensorflow/compiler/xla/service/hlo_sharding.cc
+++ b/tensorflow/compiler/xla/service/hlo_sharding.cc
@@ -49,9 +49,6 @@ string HloSharding::ToString() const {
return StrCat("{", tensorflow::str_util::Join(parts, ", "), "}");
}
- string result = StrCat("{", (replicated_ ? " replicated" : ""),
- (maximal_ ? " maximal" : ""));
-
if (replicated_) {
return "{replicated}";
} else if (maximal_) {
diff --git a/tensorflow/compiler/xla/service/hlo_sharding.h b/tensorflow/compiler/xla/service/hlo_sharding.h
index 2b8e757f42..e8bb06c8f7 100644
--- a/tensorflow/compiler/xla/service/hlo_sharding.h
+++ b/tensorflow/compiler/xla/service/hlo_sharding.h
@@ -99,6 +99,9 @@ class HloSharding {
static bool IsReservedDevice(int64 device) { return device < 0; }
OpSharding ToProto() const;
+
+ // Note that this string canonically has outer curly braces, e.g.
+ // "{replicated}".
string ToString() const;
// Validate that this sharding can be applied to a tensor with shape `shape`.
@@ -208,6 +211,12 @@ class HloSharding {
return h;
}
+ struct Hasher {
+ size_t operator()(const HloSharding& sharding) const {
+ return sharding.Hash();
+ }
+ };
+
// Gets the tile shape.
// REQUIRES: !IsTileMaximal() && !IsTuple()
const Shape& tile_shape() const { return tile_shape_; }