From 12116d724cbefe4f49d85ba7dc96b6419ae9c865 Mon Sep 17 00:00:00 2001 From: Igor Ganichev Date: Mon, 25 Sep 2017 23:49:12 -0700 Subject: Move SerializeToStringDeterministic to a header file It is a useful function for computing hash values. PiperOrigin-RevId: 170013135 --- tensorflow/core/framework/attr_value_util.cc | 27 ++++++--------------------- tensorflow/core/lib/hash/hash.cc | 11 +++++++++++ tensorflow/core/lib/hash/hash.h | 10 ++++++++++ 3 files changed, 27 insertions(+), 21 deletions(-) diff --git a/tensorflow/core/framework/attr_value_util.cc b/tensorflow/core/framework/attr_value_util.cc index 51957ecbfc..5aba091840 100644 --- a/tensorflow/core/framework/attr_value_util.cc +++ b/tensorflow/core/framework/attr_value_util.cc @@ -405,21 +405,6 @@ void SetAttrValue(gtl::ArraySlice value, AttrValue* out) { } } -// Wrapper around protocol buffer serialization that requests deterministic -// serialization, in particular for Map fields, which serialize in a random -// order by default. Returns true on success. -template <typename T> -static bool DeterministicSerialization(const T& t, string* result) { - const int size = t.ByteSize(); - *result = string(size, '\0'); - ::tensorflow::protobuf::io::ArrayOutputStream array_stream(&(*result)[0], - size); - ::tensorflow::protobuf::io::CodedOutputStream output_stream(&array_stream); - output_stream.SetSerializationDeterministic(true); - t.SerializeWithCachedSizes(&output_stream); - return !output_stream.HadError() && size == output_stream.ByteCount(); -} - bool AreAttrValuesEqual(const AttrValue& a, const AttrValue& b) { // There are multiple equivalent representations of attr values containing // TensorProtos. 
Compare them by constructing Tensors and serializing them @@ -442,8 +427,8 @@ bool AreAttrValuesEqual(const AttrValue& a, const AttrValue& b) { bt.AsProtoTensorContent(&bp); string a_str, b_str; - DeterministicSerialization(ap, &a_str); - DeterministicSerialization(bp, &b_str); + SerializeToStringDeterministic(ap, &a_str); + SerializeToStringDeterministic(bp, &b_str); return a_str == b_str; } @@ -470,8 +455,8 @@ bool AreAttrValuesEqual(const AttrValue& a, const AttrValue& b) { // All other fields in AttrValue have deterministic representations. // It is safe to compare their serialized strings. string a_str, b_str; - DeterministicSerialization(a, &a_str); - DeterministicSerialization(b, &b_str); + SerializeToStringDeterministic(a, &a_str); + SerializeToStringDeterministic(b, &b_str); return a_str == b_str; } @@ -486,7 +471,7 @@ uint64 AttrValueHash(const AttrValue& a) { TensorProto p; tensor.AsProtoTensorContent(&p); string s; - DeterministicSerialization(p, &s); + SerializeToStringDeterministic(p, &s); return Hash64(s); } if (a.has_func()) { @@ -502,7 +487,7 @@ uint64 AttrValueHash(const AttrValue& a) { // If `a` is not a tensor or func, get a hash of serialized string. 
string s; - DeterministicSerialization(a, &s); + SerializeToStringDeterministic(a, &s); return Hash64(s); } diff --git a/tensorflow/core/lib/hash/hash.cc b/tensorflow/core/lib/hash/hash.cc index dc9d300d00..ed9b4df37a 100644 --- a/tensorflow/core/lib/hash/hash.cc +++ b/tensorflow/core/lib/hash/hash.cc @@ -126,4 +126,15 @@ uint64 Hash64(const char* data, size_t n, uint64 seed) { return h; } +bool SerializeToStringDeterministic(const protobuf::MessageLite& msg, + string* result) { + const size_t size = msg.ByteSizeLong(); + *result = string(size, '\0'); + protobuf::io::ArrayOutputStream array_stream(&(*result)[0], size); + protobuf::io::CodedOutputStream output_stream(&array_stream); + output_stream.SetSerializationDeterministic(true); + msg.SerializeWithCachedSizes(&output_stream); + return !output_stream.HadError() && size == output_stream.ByteCount(); +} + } // namespace tensorflow diff --git a/tensorflow/core/lib/hash/hash.h b/tensorflow/core/lib/hash/hash.h index 77b8031598..0fb12966af 100644 --- a/tensorflow/core/lib/hash/hash.h +++ b/tensorflow/core/lib/hash/hash.h @@ -24,6 +24,7 @@ limitations under the License. #include #include "tensorflow/core/lib/core/stringpiece.h" +#include "tensorflow/core/platform/protobuf.h" #include "tensorflow/core/platform/types.h" namespace tensorflow { @@ -84,6 +85,15 @@ struct hash> { } }; +// Wrapper around protocol buffer serialization that requests deterministic +// serialization, in particular for Map fields, which serialize in a random +// order by default. Returns true on success. +// Serialization is guaranteed to be deterministic for a given binary only. +// See the following for more details: +// https://github.com/google/protobuf/blob/a1bb147e96b6f74db6cdf3c3fcb00492472dbbfa/src/google/protobuf/io/coded_stream.h#L834 +bool SerializeToStringDeterministic(const protobuf::MessageLite& msg, + string* result); + } // namespace tensorflow #endif // TENSORFLOW_LIB_HASH_HASH_H_ -- cgit v1.2.3