aboutsummaryrefslogtreecommitdiffhomepage
path: root/tensorflow/python/util/protobuf/compare.py
blob: 19f7128f4e63ae28dc1c69449173617697ec1899 (plain)
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
258
259
260
261
262
263
264
265
266
267
268
269
270
271
272
273
274
275
276
277
278
279
280
281
282
283
284
285
286
287
288
289
290
291
292
293
294
295
296
297
298
299
300
301
302
303
304
305
306
307
308
309
310
311
312
313
314
315
316
317
318
319
320
321
322
323
324
325
326
327
328
329
330
331
332
333
334
335
336
337
338
339
340
341
342
343
344
345
346
347
348
349
350
351
352
353
354
355
356
357
358
359
360
361
362
363
364
365
366
367
368
369
370
371
372
373
374
375
376
377
378
379
380
381
382
383
384
#!/usr/bin/python2.4

"""Utility functions for comparing proto2 messages in Python.

Proto2Cmp() is a cmp-style comparison function. It can be passed to sort(), etc.
See its docstring for details.

ClearDefaultValuedFields() recursively clears the fields that are set to their
default values. This is useful for comparing protocol buffers where the
semantics of unset fields and default valued fields are the same.

NormalizeRepeatedFields() sorts and optionally de-dupes repeated fields. This
is useful for treating repeated fields as sets instead of lists.

assertProto2Equal() and assertProto2SameElements() are useful for unit tests.
They produce much more helpful output than assertEqual() and friends for proto2
messages, e.g. this:

  outer {
    inner {
-     strings: "x"
?               ^
+     strings: "y"
?               ^
    }
  }

...compared to the default output from assertEqual() that looks like this:

AssertionError: <my.Msg object at 0x9fb353c> != <my.Msg object at 0x9fb35cc>

Call them inside your unit test's googletest.TestCase subclasses like this:

  from tensorflow.python.util.protobuf import compare

  class MyTest(googletest.TestCase):
    ...
    def testXXX(self):
      ...
      compare.assertProto2Equal(self, a, b)
      compare.assertProto2SameElements(self, a, c)

Alternatively:

  from tensorflow.python.util.protobuf import compare

  class MyTest(compare.Proto2Assertions, googletest.TestCase):
    ...
    def testXXX(self):
      ...
      self.assertProto2Equal(a, b)
      self.assertProto2SameElements(a, c)
"""

import copy

from google.protobuf import descriptor
from google.protobuf import message
from google.protobuf import text_format


def assertProto2Equal(self, a, b, check_initialized=True,
                      normalize_numbers=False, msg=None):
  """Asserts that two proto2 messages render to identical text format.

  Repeated fields are compared with the same semantics as
  unittest.TestCase.assertEqual(), i.e. both element order and duplicate
  elements matter. The comparison itself is done on the text-format output of
  both messages via self.assertMultiLineEqual(), which is what yields a
  readable line-by-line diff on failure.

  Note that when normalize_numbers is True, NormalizeNumberFields() is applied
  to the messages in place, so message objects passed in by the caller are
  modified.

  Args:
    self: googletest.TestCase
    a: proto2 PB instance, or text string representing one; a string is parsed
      into a fresh instance of b's class
    b: proto2 PB instance -- message.Message or subclass thereof
    check_initialized: boolean, whether to fail if either a or b isn't
      initialized
    normalize_numbers: boolean, whether to normalize types and precision of
      numbers before comparison.
    msg: if specified, is used as the error message on failure
  """
  left = a
  if isinstance(left, basestring):
    # Text input: parse it into a new message of the same type as b.
    left = text_format.Merge(left, b.__class__())
  right = b

  # Each message is fully processed (checked, then normalized) before the next
  # one is looked at.
  for proto in (left, right):
    if check_initialized:
      init_errors = proto.FindInitializationErrors()
      if init_errors:
        self.fail('Initialization errors: %s\n%s' % (init_errors, proto))
    if normalize_numbers:
      NormalizeNumberFields(proto)

  left_text = text_format.MessageToString(left)
  right_text = text_format.MessageToString(right)
  self.assertMultiLineEqual(left_text, right_text, msg=msg)


def assertProto2SameElements(self, a, b, number_matters=False,
                             check_initialized=True, normalize_numbers=False,
                             msg=None):
  """Asserts that two proto2 messages are equal up to repeated-field order.

  Repeated fields are normalized with NormalizeRepeatedFields() before the
  comparison, so element order never matters. How often each element appears
  (i.e. duplicates) matters only if number_matters is True.

  With the default arguments repeated fields therefore follow set semantics,
  matching googletest.TestCase.assertSameElements(): neither order nor number
  of a given element matters.

  The normalization and comparison run on copies (or on the message parsed
  from text), not on the objects passed in.

  Args:
    self: googletest.TestCase
    a: proto2 PB instance, or text string representing one
    b: proto2 PB instance -- message.Message or subclass thereof
    number_matters: boolean, whether number of each elements must match
    check_initialized: boolean, whether to fail if either a or b isn't
      initialized
    normalize_numbers: boolean, whether to normalize types and precision of
      numbers before comparison.
    msg: if specified, is used as the error message on failure
  """
  if isinstance(a, basestring):
    # Parsing text already produces a fresh message, no copy needed.
    left = text_format.Merge(a, b.__class__())
  else:
    left = copy.deepcopy(a)
  right = copy.deepcopy(b)

  drop_duplicates = not number_matters
  NormalizeRepeatedFields(left, dedupe=drop_duplicates)
  NormalizeRepeatedFields(right, dedupe=drop_duplicates)

  assertProto2Equal(
      self, left, right,
      check_initialized=check_initialized,
      normalize_numbers=normalize_numbers,
      msg=msg)


def assertProto2Contains(self, a, b,  # pylint: disable=invalid-name
                         number_matters=False, check_initialized=True,
                         msg=None):
  """Asserts that the fields set in a are also present in b.

  Handy for checking that expected fields are in b while letting the test
  spell the expectation out as a text-format string.

  The check works by merging a into a copy of b and asserting, with
  assertProto2SameElements(), that the merged result is still equivalent to b.

  Example:
    compare.assertProto2Contains('group { field: "value" }', test_pb2)

  Args:
    self: googletest.TestCase
    a: proto2 PB instance, or text string representing one
    b: proto2 PB instance
    number_matters: boolean, whether number of each field must match
    check_initialized: boolean, whether to fail if b isn't initialized
    msg: if specified, is used as the error message on failure
  """
  if isinstance(a, basestring):
    wanted = text_format.Merge(a, b.__class__())
  else:
    wanted = copy.deepcopy(a)

  # Start from everything b has and layer the wanted fields on top.
  merged = copy.deepcopy(b)
  merged.MergeFrom(wanted)

  assertProto2SameElements(
      self, merged, b,
      number_matters=number_matters,
      check_initialized=check_initialized,
      msg=msg)


def ClearDefaultValuedFields(pb):
  """Clears all fields in a proto2 message that are set to their default values.

  The result has more compact text / json / binary representation. It's also
  easier to compare to other protos if the choice whether fields are not set or
  set to their default values doesn't change the proto buffer's semantics.

  Modifies pb in place and recurses into nested messages:
    * optional scalar fields equal to their default value are cleared;
    * optional message fields that end up with no set fields are cleared;
    * elements of repeated message fields are cleaned but never removed;
    * for map fields, message-typed values are cleaned; entries are never
      removed and maps with non-message values are left untouched.

  Args:
    pb: A proto2 message.
  """
  # ListFields() is evaluated once up front, so clearing fields on pb inside
  # the loop does not disturb the iteration.
  for field, value in pb.ListFields():
    if field.type == field.TYPE_MESSAGE:
      if field.label == field.LABEL_REPEATED:
        entry_type = field.message_type
        if entry_type.has_options and entry_type.GetOptions().map_entry:
          # This is a map: iterating it yields keys, not messages, so look the
          # values up explicitly and only recurse if they have a message type.
          if entry_type.fields_by_number[2].type == field.TYPE_MESSAGE:
            for key in list(value):
              ClearDefaultValuedFields(value[key])
        else:
          for item in value:
            ClearDefaultValuedFields(item)
      else:
        ClearDefaultValuedFields(value)
        # A sub-message left without any set field carries no information.
        if field.label == field.LABEL_OPTIONAL and not value.ListFields():
          pb.ClearField(field.name)
    elif field.label == field.LABEL_OPTIONAL and value == field.default_value:
      pb.ClearField(field.name)


def NormalizeRepeatedFields(pb, dedupe=True):
  """Sorts all repeated fields and optionally removes duplicates.

  Modifies pb in place. Recurses into nested messages and groups, and into the
  values of map fields whose value type is a message. Uses Proto2Cmp for
  sorting.

  Args:
    pb: proto2 message
    dedupe: boolean, whether to remove duplicates

  Returns:
    the given pb, modified in place
  """
  for desc, values in pb.ListFields():
    # Compare the label by value, not identity: identity comparison of numeric
    # constants depends on how the interpreter happens to cache objects.
    if desc.label != descriptor.FieldDescriptor.LABEL_REPEATED:
      # Wrap singular fields so the code below can treat all fields alike;
      # sorting / de-duping the one-element wrapper is a no-op.
      values = [values]

    if (desc.type == descriptor.FieldDescriptor.TYPE_MESSAGE and
        desc.message_type.has_options and
        desc.message_type.GetOptions().map_entry):
      # This is a map, only recurse if the values have a message type. The
      # map entries themselves are neither sorted nor de-duped.
      if (desc.message_type.fields_by_number[2].type ==
          descriptor.FieldDescriptor.TYPE_MESSAGE):
        for v in values.itervalues():
          NormalizeRepeatedFields(v, dedupe=dedupe)
    else:
      if (desc.type == descriptor.FieldDescriptor.TYPE_MESSAGE or
          desc.type == descriptor.FieldDescriptor.TYPE_GROUP):
        # Normalize the elements first so that sorting and de-duping below
        # operate on already-normalized sub-messages.
        for v in values:
          # recursive step
          NormalizeRepeatedFields(v, dedupe=dedupe)

      values.sort(Proto2Cmp)

      if dedupe:
        # De-dupe in place. Can't use set, etc. because messages aren't
        # hashable.  This is a heavily discussed toy problem. the code below is
        # a simplified version of http://code.activestate.com/recipes/52560/
        # and it requires that values is sorted. Walking backwards keeps the
        # indices still to be visited valid while elements are deleted.
        for i in xrange(len(values) - 1, 0, -1):
          if values[i] == values[i - 1]:
            del values[i]

  return pb


def NormalizeNumberFields(pb):
  """Normalizes types and precisions of number fields in a protocol buffer.

  Due to subtleties in the python protocol buffer implementation, it is possible
  for values to have different types and precision depending on whether they
  were set and retrieved directly or deserialized from a protobuf. This function
  normalizes integer values to ints and longs based on width, rounds 32-bit
  floats to six decimal places to account for python always storing them as
  64-bit, and converts doubles to floating point (for when they're set to
  integers), rounded to seven decimal places.

  Modifies pb in place. Recurses into nested messages and groups, and into the
  values of map fields whose value type is a message.

  Args:
    pb: proto2 message

  Returns:
    the given pb, modified in place
  """
  for desc, values in pb.ListFields():
    # Compare the label by value, not identity: identity comparison of numeric
    # constants depends on how the interpreter happens to cache objects.
    is_repeated = desc.label == descriptor.FieldDescriptor.LABEL_REPEATED
    if not is_repeated:
      # Wrap singular fields so both cases can share the list-based code below.
      values = [values]

    # Stays None for field types that need no numeric normalization.
    normalized_values = None

    # We force 32-bit values to int and 64-bit values to long to make
    # alternate implementations where the distinction is more significant
    # (e.g. the C++ implementation) simpler.
    if desc.type in (descriptor.FieldDescriptor.TYPE_INT64,
                     descriptor.FieldDescriptor.TYPE_UINT64,
                     descriptor.FieldDescriptor.TYPE_SINT64):
      normalized_values = [long(x) for x in values]
    elif desc.type in (descriptor.FieldDescriptor.TYPE_INT32,
                       descriptor.FieldDescriptor.TYPE_UINT32,
                       descriptor.FieldDescriptor.TYPE_SINT32,
                       descriptor.FieldDescriptor.TYPE_ENUM):
      normalized_values = [int(x) for x in values]
    elif desc.type == descriptor.FieldDescriptor.TYPE_FLOAT:
      normalized_values = [round(x, 6) for x in values]
    elif desc.type == descriptor.FieldDescriptor.TYPE_DOUBLE:
      normalized_values = [round(float(x), 7) for x in values]

    if normalized_values is not None:
      if is_repeated:
        # Repeated fields can't be assigned to; replace their contents.
        pb.ClearField(desc.name)
        getattr(pb, desc.name).extend(normalized_values)
      else:
        setattr(pb, desc.name, normalized_values[0])

    if (desc.type == descriptor.FieldDescriptor.TYPE_MESSAGE or
        desc.type == descriptor.FieldDescriptor.TYPE_GROUP):
      if (desc.type == descriptor.FieldDescriptor.TYPE_MESSAGE and
          desc.message_type.has_options and
          desc.message_type.GetOptions().map_entry):
        # This is a map, only recurse if the values have a message type.
        if (desc.message_type.fields_by_number[2].type ==
            descriptor.FieldDescriptor.TYPE_MESSAGE):
          for v in values.itervalues():
            NormalizeNumberFields(v)
      else:
        for v in values:
          # recursive step
          NormalizeNumberFields(v)

  return pb


def _IsRepeatedContainer(value):
  """Returns whether value is an iterable that is not a string.

  Strings are iterable, but for comparison purposes they are primitives, so
  they are reported as non-containers.

  Args:
    value: any object

  Returns:
    True if iter(value) succeeds and value is not a string, False otherwise.
  """
  try:
    string_types = basestring  # Python 2: covers both str and unicode.
  except NameError:
    # basestring does not exist on Python 3; text and byte strings are the
    # equivalent pair there.
    string_types = (str, bytes)

  if isinstance(value, string_types):
    return False
  try:
    iter(value)
  except TypeError:
    return False
  return True


def Proto2Cmp(a, b):
  """Compares two proto2 objects field by field, in ascending tag order.

  Nested messages are compared recursively. Repeated fields are compared with
  list (not set) semantics, i.e. duplicates and order matter. If one repeated
  field is a prefix of the other, the longer one is greater; likewise a message
  lacking a field that the other one has set is the smaller one.

  This function is intended to be used as a python cmp function, e.g. in sort.

  Ordering fields by tag number has precedent in other google code, but it's
  still somewhat arbitrary. The main value is to provide *some* stable ordering
  for proto2 messages.

  This would be easier as a __cmp__ method or set of __le__, __gt__, etc methods
  in the proto2 Message class itself. That would take a little more care,
  though, and probably some significant debate over whether they should exist at
  all, so this was easier.

  Args:
    a, b: proto2 messages or primitives

  Returns: integer > 0 if a > b, < 0 if a < b, 0 if a == b
  """
  def ToComparable(obj):
    """Maps a message to {tag number: value}, a repeated field to
    {element index: value}, and returns anything else unchanged."""
    if isinstance(obj, message.Message):
      return dict((fd.number, val) for fd, val in obj.ListFields())
    if _IsRepeatedContainer(obj):
      return dict(enumerate(obj))
    return obj

  lhs = ToComparable(a)
  rhs = ToComparable(b)

  # base case: at least one side is a primitive, let cmp() decide.
  if not (isinstance(lhs, dict) and isinstance(rhs, dict)):
    return cmp(lhs, rhs)

  # this loop performs double duty: it compares two messages by tag value *or*
  # two repeated fields by element, in order. the magic is in ToComparable(),
  # which converts them both to the same easily comparable format.
  for key in sorted(set(lhs).union(rhs)):
    if key not in lhs:
      return -1  # b is greater
    if key not in rhs:
      return 1   # a is greater
    # recursive step
    outcome = Proto2Cmp(lhs[key], rhs[key])
    if outcome:
      return outcome

  # didn't find any values that differed, so they're equal!
  return 0


class Proto2Assertions(object):
  """Mixin that exposes the proto2 assertions as TestCase methods.

  List it before googletest.TestCase in the bases of a test class:

  class SomeTestCase(compare.Proto2Assertions, googletest.TestCase):
    ...
    def testSomething(self):
      ...
      self.assertProto2Equal(a, b)

  Every method simply forwards to the module-level function of the same name,
  passing the test case as its first argument. See module-level definitions
  for method documentation.
  """

  # pylint: disable=invalid-name
  def assertProto2Equal(self, *positional, **keywords):
    """Forwards to the module-level assertProto2Equal()."""
    return assertProto2Equal(self, *positional, **keywords)

  def assertProto2SameElements(self, *positional, **keywords):
    """Forwards to the module-level assertProto2SameElements()."""
    return assertProto2SameElements(self, *positional, **keywords)

  def assertProto2Contains(self, *positional, **keywords):
    """Forwards to the module-level assertProto2Contains()."""
    return assertProto2Contains(self, *positional, **keywords)