aboutsummaryrefslogtreecommitdiffhomepage
path: root/tensorflow/core/grappler/optimizers/arithmetic_optimizer.cc
diff options
context:
space:
mode:
authorGravatar A. Unique TensorFlower <gardener@tensorflow.org>2018-05-14 10:43:08 -0700
committerGravatar TensorFlower Gardener <gardener@tensorflow.org>2018-05-14 10:45:35 -0700
commit5fb7401959391f7583087f404a48353ab21ef1ca (patch)
tree7413cfeec40ad33a8c4468219bbdf234fe9a6ed4 /tensorflow/core/grappler/optimizers/arithmetic_optimizer.cc
parent157c347f832413c29265e467cc733366b4b215a6 (diff)
Use utility methods to compute AttrValue hash code and check for equality.
PiperOrigin-RevId: 196531355
Diffstat (limited to 'tensorflow/core/grappler/optimizers/arithmetic_optimizer.cc')
-rw-r--r--tensorflow/core/grappler/optimizers/arithmetic_optimizer.cc38
1 files changed, 15 insertions, 23 deletions
diff --git a/tensorflow/core/grappler/optimizers/arithmetic_optimizer.cc b/tensorflow/core/grappler/optimizers/arithmetic_optimizer.cc
index cd7e742e5c..adef75f63e 100644
--- a/tensorflow/core/grappler/optimizers/arithmetic_optimizer.cc
+++ b/tensorflow/core/grappler/optimizers/arithmetic_optimizer.cc
@@ -23,6 +23,7 @@ limitations under the License.
#include <vector>
#include "tensorflow/core/framework/attr_value.pb.h"
+#include "tensorflow/core/framework/attr_value_util.h"
#include "tensorflow/core/framework/node_def.pb.h"
#include "tensorflow/core/framework/node_def_util.h"
#include "tensorflow/core/framework/op.h"
@@ -38,6 +39,7 @@ limitations under the License.
#include "tensorflow/core/grappler/utils/topological_sort.h"
#include "tensorflow/core/lib/core/errors.h"
#include "tensorflow/core/lib/core/stringpiece.h"
+#include "tensorflow/core/lib/hash/hash.h"
#include "tensorflow/core/lib/strings/str_util.h"
#include "tensorflow/core/lib/strings/strcat.h"
#include "tensorflow/core/platform/tensor_coding.h"
@@ -1784,7 +1786,7 @@ class SqrtDivToRsqrtMulStage : public ArithmeticOptimizerStage {
class UniqueNodes {
public:
NodeDef* FindOrAddRepresentative(NodeDef* node) {
- std::size_t sig = ComputeSignature(*node);
+ uint64 sig = ComputeSignature(*node);
std::vector<NodeDef*>& candidates = rep_[sig];
for (auto& candidate : candidates) {
if (SameNode(*candidate, *node)) {
@@ -1796,26 +1798,25 @@ class UniqueNodes {
}
private:
- std::size_t ComputeSignature(const NodeDef& node) const;
+ uint64 ComputeSignature(const NodeDef& node) const;
bool SameNode(const NodeDef& node1, const NodeDef& node2) const;
- std::unordered_map<std::size_t, std::vector<NodeDef*>> rep_;
+ std::unordered_map<uint64, std::vector<NodeDef*>> rep_;
};
-std::size_t UniqueNodes::ComputeSignature(const NodeDef& node) const {
- std::size_t h = std::hash<string>{}(node.op());
- h ^= std::hash<string>{}(node.device());
+uint64 UniqueNodes::ComputeSignature(const NodeDef& node) const {
+ uint64 h = Hash64(node.op());
+ h = Hash64Combine(Hash64(node.device()), h);
+
for (const auto& input : node.input()) {
int pos;
string node_name = ParseNodeName(input, &pos);
- h ^= std::hash<string>{}(node_name);
- h ^= static_cast<std::size_t>(pos);
+ h = Hash64CombineUnordered(Hash64(node_name), h);
+ h = Hash64CombineUnordered(std::hash<int>()(pos), h);
}
for (const auto& attr : node.attr()) {
- h ^= std::hash<string>{}(attr.first);
- string tmp;
- attr.second.AppendToString(&tmp);
- h ^= std::hash<string>{}(tmp);
+ h = Hash64CombineUnordered(Hash64(attr.first), h);
+ h = Hash64CombineUnordered(FastAttrValueHash(attr.second), h);
}
return h;
}
@@ -1871,17 +1872,8 @@ bool UniqueNodes::SameNode(const NodeDef& node1, const NodeDef& node2) const {
}
for (const auto& attr1 : node1.attr()) {
auto it = node2.attr().find(attr1.first);
- if (it == node2.attr().end()) {
- return false;
- }
- const auto& attr2 = *it;
- string val1;
- attr1.second.AppendToString(&val1);
- string val2;
- attr2.second.AppendToString(&val2);
- if (val1 != val2) {
- return false;
- }
+ if (it == node2.attr().end()) return false;
+ if (!FastAreAttrValuesEqual(attr1.second, it->second)) return false;
}
return true;