1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
258
259
260
261
262
263
264
265
266
267
268
269
270
271
272
273
274
275
276
277
278
279
280
281
282
283
284
285
286
287
288
289
290
291
292
293
294
295
296
297
298
299
300
301
302
303
304
305
306
307
308
309
310
311
312
313
314
315
316
317
318
319
320
321
322
323
324
325
326
327
328
329
330
331
332
333
334
335
336
337
338
339
340
341
342
343
344
345
346
347
348
349
350
351
352
353
354
355
356
357
358
359
360
361
362
363
364
365
366
367
368
369
370
371
372
373
374
375
376
377
378
379
380
381
382
383
384
|
#!/usr/bin/python2.4
"""Utility functions for comparing proto2 messages in Python.

Proto2Cmp() is a cmp-style comparison function. It can be passed to sort(), etc.
See its docstring for details.

ClearDefaultValuedFields() recursively clears the fields that are set to their
default values. This is useful for comparing protocol buffers where the
semantics of unset fields and default valued fields are the same.

NormalizeRepeatedFields() sorts and optionally de-dupes repeated fields. This
is useful for treating repeated fields as sets instead of lists.

assertProto2Equal() and assertProto2SameElements() are useful for unit tests.
They produce much more helpful output than assertEqual() and friends for proto2
messages, e.g. this:

  outer {
    inner {
  -   strings: "x"
  ?             ^
  +   strings: "y"
  ?             ^
    }
  }

...compared to the default output from assertEqual() that looks like this:

  AssertionError: <my.Msg object at 0x9fb353c> != <my.Msg object at 0x9fb35cc>

Call them inside your unit test's googletest.TestCase subclasses like this:

  from tensorflow.python.util.protobuf import compare

  class MyTest(googletest.TestCase):
    ...
    def testXXX(self):
      ...
      compare.assertProto2Equal(self, a, b)
      compare.assertProto2SameElements(self, a, c)

Alternatively:

  from tensorflow.python.util.protobuf import compare

  class MyTest(compare.Proto2Assertions, googletest.TestCase):
    ...
    def testXXX(self):
      ...
      self.assertProto2Equal(a, b)
      self.assertProto2SameElements(a, c)
"""
import copy
import functools

from google.protobuf import descriptor
from google.protobuf import message
from google.protobuf import text_format
def assertProto2Equal(self, a, b, check_initialized=True,
                      normalize_numbers=False, msg=None):
    """Fails with a useful error if a and b aren't equal.

    Comparison of repeated fields matches the semantics of
    unittest.TestCase.assertEqual(), i.e. order and extra duplicate fields
    matter.

    The messages are compared through their text-format representation, so
    a failure shows a line-by-line diff of the two protos.

    Note: when normalize_numbers is True, NormalizeNumberFields() is applied
    directly to b (and to a, if a was passed as a message rather than as
    text), so those messages are modified in place.

    Args:
      self: googletest.TestCase
      a: proto2 PB instance, or text string representing one
      b: proto2 PB instance -- message.Message or subclass thereof
      check_initialized: boolean, whether to fail if either a or b isn't
        initialized
      normalize_numbers: boolean, whether to normalize types and precision of
        numbers before comparison.
      msg: if specified, is used as the error message on failure
    """
    # (str, unicode) on Python 2 and (str, str) on Python 3: a portable
    # spelling of the Python 2-only `basestring` builtin.
    if isinstance(a, (str, type(u''))):
        # Parse the text into a fresh message of the same type as b.
        a = text_format.Merge(a, b.__class__())

    for pb in a, b:
        if check_initialized:
            errors = pb.FindInitializationErrors()
            if errors:
                self.fail('Initialization errors: %s\n%s' % (errors, pb))
        if normalize_numbers:
            NormalizeNumberFields(pb)

    self.assertMultiLineEqual(text_format.MessageToString(a),
                              text_format.MessageToString(b),
                              msg=msg)
def assertProto2SameElements(self, a, b, number_matters=False,
                             check_initialized=True, normalize_numbers=False,
                             msg=None):
    """Fails with a useful error if a and b aren't equivalent.

    When comparing repeated fields, order doesn't matter and the number of
    times each element appears (i.e. duplicates) only matters if
    number_matters is True.

    By default, comparison of repeated fields follows set semantics and
    matches googletest.TestCase.assertSameElements(): neither order nor
    number of a given element matters.

    The comparison works on copies, so the repeated-field normalization done
    here never reorders the caller's messages.

    Args:
      self: googletest.TestCase
      a: proto2 PB instance, or text string representing one
      b: proto2 PB instance -- message.Message or subclass thereof
      number_matters: boolean, whether number of each elements must match
      check_initialized: boolean, whether to fail if either a or b isn't
        initialized
      normalize_numbers: boolean, whether to normalize types and precision of
        numbers before comparison.
      msg: if specified, is used as the error message on failure
    """
    # (str, unicode) on Python 2 and (str, str) on Python 3: a portable
    # spelling of the Python 2-only `basestring` builtin.
    if isinstance(a, (str, type(u''))):
        # Parsing yields a brand new message, so no copy is needed.
        a = text_format.Merge(a, b.__class__())
    else:
        a = copy.deepcopy(a)
    b = copy.deepcopy(b)

    for pb in a, b:
        # Duplicates are only dropped when their count is irrelevant.
        NormalizeRepeatedFields(pb, dedupe=not number_matters)

    assertProto2Equal(
        self, a, b, check_initialized=check_initialized,
        normalize_numbers=normalize_numbers, msg=msg)
def assertProto2Contains(self, a, b,  # pylint: disable=invalid-name
                         number_matters=False, check_initialized=True,
                         msg=None):
    """Fails with a useful error if fields in a are not in b.

    Useful to test if expected fields are in b, allows tests to define
    expected fields in string format.

    The check merges a into a copy of b and asserts that the result still has
    the same elements as b, i.e. that a contributed nothing b did not already
    contain.

    Example:
      compare.assertProto2Contains('group { field: "value" }', test_pb2)

    Args:
      self: googletest.TestCase
      a: proto2 PB instance, or text string representing one
      b: proto2 PB instance
      number_matters: boolean, whether number of each field must match
      check_initialized: boolean, whether to fail if b isn't initialized
      msg: if specified, is used as the error message on failure
    """
    # (str, unicode) on Python 2 and (str, str) on Python 3: a portable
    # spelling of the Python 2-only `basestring` builtin.
    if isinstance(a, (str, type(u''))):
        a = text_format.Merge(a, b.__class__())

    # a is only read from below (MergeFrom copies out of its argument), so
    # only b needs to be copied to keep the caller's messages untouched.
    completed_a = copy.deepcopy(b)
    completed_a.MergeFrom(a)

    assertProto2SameElements(self, completed_a, b,
                             number_matters=number_matters,
                             check_initialized=check_initialized, msg=msg)
def ClearDefaultValuedFields(pb):
    """Clears all proto2 message fields that are set to their default values.

    The result has more compact text / json / binary representation. It's
    also easier to compare to other protos if the choice whether fields are
    not set or set to their default values doesn't change the proto buffer's
    semantics.

    Modifies pb in place and recurses into nested messages, repeated messages
    and message-valued maps. Only optional fields are ever cleared; an
    optional sub-message is cleared once it has no set fields left.

    Args:
      pb: A proto2 message.
    """
    for field, value in pb.ListFields():
        if field.type == field.TYPE_MESSAGE:
            if field.label == field.LABEL_REPEATED:
                entry = field.message_type
                if entry.has_options and entry.GetOptions().map_entry:
                    # Map field: the container is a mapping, so iterating it
                    # directly would visit keys rather than messages. Only
                    # the values can be messages, and only when the entry
                    # type's value field (number 2) is message-typed.
                    value_type = entry.fields_by_number[2].type
                    items = (value.values()
                             if value_type == field.TYPE_MESSAGE else ())
                else:
                    items = value
                for item in items:
                    ClearDefaultValuedFields(item)
            else:
                ClearDefaultValuedFields(value)
                # Drop an optional sub-message that ended up empty.
                if (field.label == field.LABEL_OPTIONAL and
                        not value.ListFields()):
                    pb.ClearField(field.name)
        elif (field.label == field.LABEL_OPTIONAL and
              value == field.default_value):
            pb.ClearField(field.name)
def NormalizeRepeatedFields(pb, dedupe=True):
    """Sorts all repeated fields and optionally removes duplicates.

    Modifies pb in place. Recurses into nested messages, groups and
    message-valued maps. Uses Proto2Cmp for sorting. Map fields themselves
    are neither sorted nor de-duplicated.

    Args:
      pb: proto2 message
      dedupe: boolean, whether to remove duplicates

    Returns:
      the given pb, modified in place
    """
    field_types = descriptor.FieldDescriptor

    for desc, values in pb.ListFields():
        # Wrap singular values so one code path handles both labels. Labels
        # are plain ints, so compare by value rather than by identity.
        if desc.label != field_types.LABEL_REPEATED:
            values = [values]

        if (desc.type == field_types.TYPE_MESSAGE and
                desc.message_type.has_options and
                desc.message_type.GetOptions().map_entry):
            # This is a map, only recurse if the values have a message type.
            if (desc.message_type.fields_by_number[2].type ==
                    field_types.TYPE_MESSAGE):
                for v in values.values():
                    NormalizeRepeatedFields(v, dedupe=dedupe)
        else:
            if desc.type in (field_types.TYPE_MESSAGE,
                             field_types.TYPE_GROUP):
                for v in values:
                    # recursive step
                    NormalizeRepeatedFields(v, dedupe=dedupe)

            # sort() only takes a cmp function positionally on Python 2;
            # cmp_to_key works on both Python 2.7 and Python 3.
            values.sort(key=functools.cmp_to_key(Proto2Cmp))

            if dedupe:
                # De-dupe in place. Can't use set, etc. because messages
                # aren't hashable. This is a simplified version of
                # http://code.activestate.com/recipes/52560/ and it requires
                # that values is sorted. Walking backwards keeps the
                # remaining indices valid while deleting.
                for i in range(len(values) - 1, 0, -1):
                    if values[i] == values[i - 1]:
                        del values[i]

    return pb
def NormalizeNumberFields(pb):
    """Normalizes types and precisions of number fields in a protocol buffer.

    Due to subtleties in the python protocol buffer implementation, it is
    possible for values to have different types and precision depending on
    whether they were set and retrieved directly or deserialized from a
    protobuf. This function:

      * converts int64 / uint64 / sint64 values to long (int on Python 3,
        which has no separate long type);
      * converts int32 / uint32 / sint32 and enum values to int;
      * rounds 32-bit floats to six decimal places, to account for python
        always storing them as 64-bit;
      * converts doubles to float, for when they're set to integers, and
        rounds them to seven decimal places.

    Modifies pb in place. Recurses into nested messages, groups and
    message-valued maps.

    Args:
      pb: proto2 message

    Returns:
      the given pb, modified in place
    """
    field_types = descriptor.FieldDescriptor
    try:
        long_type = long  # Python 2.
    except NameError:
        long_type = int  # Python 3: int already has arbitrary precision.

    for desc, values in pb.ListFields():
        # Labels are plain ints, so compare by value rather than by identity.
        is_repeated = desc.label == field_types.LABEL_REPEATED
        if not is_repeated:
            # Wrap singular values so one code path handles both labels.
            values = [values]

        normalized_values = None

        # We force 32-bit values to int and 64-bit values to long to make
        # alternate implementations where the distinction is more significant
        # (e.g. the C++ implementation) simpler.
        if desc.type in (field_types.TYPE_INT64,
                         field_types.TYPE_UINT64,
                         field_types.TYPE_SINT64):
            normalized_values = [long_type(x) for x in values]
        elif desc.type in (field_types.TYPE_INT32,
                           field_types.TYPE_UINT32,
                           field_types.TYPE_SINT32,
                           field_types.TYPE_ENUM):
            normalized_values = [int(x) for x in values]
        elif desc.type == field_types.TYPE_FLOAT:
            normalized_values = [round(x, 6) for x in values]
        elif desc.type == field_types.TYPE_DOUBLE:
            normalized_values = [round(float(x), 7) for x in values]

        if normalized_values is not None:
            if is_repeated:
                # Replace the whole repeated field with the normalized list.
                pb.ClearField(desc.name)
                getattr(pb, desc.name).extend(normalized_values)
            else:
                setattr(pb, desc.name, normalized_values[0])

        if desc.type in (field_types.TYPE_MESSAGE, field_types.TYPE_GROUP):
            if (desc.type == field_types.TYPE_MESSAGE and
                    desc.message_type.has_options and
                    desc.message_type.GetOptions().map_entry):
                # This is a map, only recurse if the values have a message
                # type.
                if (desc.message_type.fields_by_number[2].type ==
                        field_types.TYPE_MESSAGE):
                    for v in values.values():
                        NormalizeNumberFields(v)
            else:
                for v in values:
                    # recursive step
                    NormalizeNumberFields(v)

    return pb
def _IsRepeatedContainer(value):
if isinstance(value, basestring):
return False
try:
iter(value)
return True
except TypeError:
return False
def Proto2Cmp(a, b):
    """Compares two proto2 objects field by field, in ascending tag order.

    Recurses into nested messages. Uses list (not set) semantics for
    comparing repeated fields, i.e. duplicates and order matter. If one field
    is a prefix of the other, the longer field is greater.

    This function is intended to be used as a python cmp function, e.g. in
    sort (wrapped in functools.cmp_to_key() where sort only accepts a key).

    Ordering fields by tag number has precedent in other google code, but
    it's still somewhat arbitrary. The main value is to provide *some* stable
    ordering for proto2 messages.

    This would be easier as a __cmp__ method or set of __le__, __gt__, etc
    methods in the proto2 Message class itself. That would take a little more
    care, though, and probably some significant debate over whether they
    should exist at all, so this was easier.

    Args:
      a, b: proto2 messages or primitives

    Returns:
      integer > 0 if a > b, < 0 if a < b, 0 if a == b
    """
    def Format(pb):
        """Converts pb to a form that can be compared key by key.

        Returns a dictionary that maps tag number (for messages) or element
        index (for repeated fields) to value, or just pb unchanged if it's
        neither.
        """
        if isinstance(pb, message.Message):
            return dict((desc.number, value)
                        for desc, value in pb.ListFields())
        elif _IsRepeatedContainer(pb):
            return dict(enumerate(pb))
        else:
            return pb

    a, b = Format(a), Format(b)

    # base case
    if not isinstance(a, dict) or not isinstance(b, dict):
        # Same sign as cmp(a, b), which was removed in Python 3.
        return (a > b) - (a < b)

    # this loop performs double duty: it compares two messages by tag value
    # *or* two repeated fields by element, in order. the magic is in the
    # Format() function, which converts them both to the same easily
    # comparable format. Building the key set from the dicts themselves
    # avoids concatenating keys(), which are not lists on Python 3.
    for tag in sorted(set(a) | set(b)):
        if tag not in a:
            return -1  # b is greater
        elif tag not in b:
            return 1  # a is greater
        else:
            # recursive step
            cmped = Proto2Cmp(a[tag], b[tag])
            if cmped != 0:
                return cmped

    # didn't find any values that differed, so they're equal!
    return 0
class Proto2Assertions(object):
    """Mixin that exposes the module-level proto2 assertions as methods.

    Combine it with googletest.TestCase so the assertions can be called on
    self instead of through the module:

      class SomeTestCase(compare.Proto2Assertions, googletest.TestCase):
        ...
        def testSomething(self):
          ...
          self.assertProto2Equal(a, b)

    Every method forwards all of its arguments, unchanged, to the
    module-level function of the same name; see those functions for the
    accepted arguments.
    """

    # pylint: disable=invalid-name

    def assertProto2Equal(self, *positional, **keywords):
        """Forwards to the module-level assertProto2Equal()."""
        return assertProto2Equal(self, *positional, **keywords)

    def assertProto2SameElements(self, *positional, **keywords):
        """Forwards to the module-level assertProto2SameElements()."""
        return assertProto2SameElements(self, *positional, **keywords)

    def assertProto2Contains(self, *positional, **keywords):
        """Forwards to the module-level assertProto2Contains()."""
        return assertProto2Contains(self, *positional, **keywords)
|