diff options
author | A. Unique TensorFlower <gardener@tensorflow.org> | 2018-05-14 10:43:08 -0700 |
---|---|---|
committer | TensorFlower Gardener <gardener@tensorflow.org> | 2018-05-14 10:45:35 -0700 |
commit | 5fb7401959391f7583087f404a48353ab21ef1ca (patch) | |
tree | 7413cfeec40ad33a8c4468219bbdf234fe9a6ed4 /tensorflow/core/grappler/optimizers/arithmetic_optimizer.cc | |
parent | 157c347f832413c29265e467cc733366b4b215a6 (diff) |
Use utility methods to compute AttrValue hash code and check for equality.
PiperOrigin-RevId: 196531355
Diffstat (limited to 'tensorflow/core/grappler/optimizers/arithmetic_optimizer.cc')
-rw-r--r-- | tensorflow/core/grappler/optimizers/arithmetic_optimizer.cc | 38 |
1 files changed, 15 insertions, 23 deletions
diff --git a/tensorflow/core/grappler/optimizers/arithmetic_optimizer.cc b/tensorflow/core/grappler/optimizers/arithmetic_optimizer.cc index cd7e742e5c..adef75f63e 100644 --- a/tensorflow/core/grappler/optimizers/arithmetic_optimizer.cc +++ b/tensorflow/core/grappler/optimizers/arithmetic_optimizer.cc @@ -23,6 +23,7 @@ limitations under the License. #include <vector> #include "tensorflow/core/framework/attr_value.pb.h" +#include "tensorflow/core/framework/attr_value_util.h" #include "tensorflow/core/framework/node_def.pb.h" #include "tensorflow/core/framework/node_def_util.h" #include "tensorflow/core/framework/op.h" @@ -38,6 +39,7 @@ limitations under the License. #include "tensorflow/core/grappler/utils/topological_sort.h" #include "tensorflow/core/lib/core/errors.h" #include "tensorflow/core/lib/core/stringpiece.h" +#include "tensorflow/core/lib/hash/hash.h" #include "tensorflow/core/lib/strings/str_util.h" #include "tensorflow/core/lib/strings/strcat.h" #include "tensorflow/core/platform/tensor_coding.h" @@ -1784,7 +1786,7 @@ class SqrtDivToRsqrtMulStage : public ArithmeticOptimizerStage { class UniqueNodes { public: NodeDef* FindOrAddRepresentative(NodeDef* node) { - std::size_t sig = ComputeSignature(*node); + uint64 sig = ComputeSignature(*node); std::vector<NodeDef*>& candidates = rep_[sig]; for (auto& candidate : candidates) { if (SameNode(*candidate, *node)) { @@ -1796,26 +1798,25 @@ class UniqueNodes { } private: - std::size_t ComputeSignature(const NodeDef& node) const; + uint64 ComputeSignature(const NodeDef& node) const; bool SameNode(const NodeDef& node1, const NodeDef& node2) const; - std::unordered_map<std::size_t, std::vector<NodeDef*>> rep_; + std::unordered_map<uint64, std::vector<NodeDef*>> rep_; }; -std::size_t UniqueNodes::ComputeSignature(const NodeDef& node) const { - std::size_t h = std::hash<string>{}(node.op()); - h ^= std::hash<string>{}(node.device()); +uint64 UniqueNodes::ComputeSignature(const NodeDef& node) const { + uint64 h = Hash64(node.op()); + h = Hash64Combine(Hash64(node.device()), h); + for (const auto& input : node.input()) { int pos; string node_name = ParseNodeName(input, &pos); - h ^= std::hash<string>{}(node_name); - h ^= static_cast<std::size_t>(pos); + h = Hash64CombineUnordered(Hash64(node_name), h); + h = Hash64CombineUnordered(std::hash<int>()(pos), h); } for (const auto& attr : node.attr()) { - h ^= std::hash<string>{}(attr.first); - string tmp; - attr.second.AppendToString(&tmp); - h ^= std::hash<string>{}(tmp); + h = Hash64CombineUnordered(Hash64(attr.first), h); + h = Hash64CombineUnordered(FastAttrValueHash(attr.second), h); } return h; } @@ -1871,17 +1872,8 @@ bool UniqueNodes::SameNode(const NodeDef& node1, const NodeDef& node2) const { } for (const auto& attr1 : node1.attr()) { auto it = node2.attr().find(attr1.first); - if (it == node2.attr().end()) { - return false; - } - const auto& attr2 = *it; - string val1; - attr1.second.AppendToString(&val1); - string val2; - attr2.second.AppendToString(&val2); - if (val1 != val2) { - return false; - } + if (it == node2.attr().end()) return false; + if (!FastAreAttrValuesEqual(attr1.second, it->second)) return false; } return true; |