From b55a20fa2c669b181f47ea9219b8e74d1263da19 Mon Sep 17 00:00:00 2001 From: "xiaofeng@google.com" Date: Sat, 22 Sep 2012 02:40:50 +0000 Subject: Down-integrate from internal branch --- python/google/protobuf/internal/message_test.py | 181 ++++++++++++++++++++++-- 1 file changed, 166 insertions(+), 15 deletions(-) (limited to 'python/google/protobuf/internal/message_test.py') diff --git a/python/google/protobuf/internal/message_test.py b/python/google/protobuf/internal/message_test.py index 65174373..53e9d507 100755 --- a/python/google/protobuf/internal/message_test.py +++ b/python/google/protobuf/internal/message_test.py @@ -45,10 +45,15 @@ __author__ = 'gps@google.com (Gregory P. Smith)' import copy import math +import operator +import pickle + import unittest from google.protobuf import unittest_import_pb2 from google.protobuf import unittest_pb2 +from google.protobuf.internal import api_implementation from google.protobuf.internal import test_util +from google.protobuf import message # Python pre-2.6 does not have isinf() or isnan() functions, so we have # to provide our own. @@ -70,9 +75,9 @@ class MessageTest(unittest.TestCase): golden_message = unittest_pb2.TestAllTypes() golden_message.ParseFromString(golden_data) test_util.ExpectAllFieldsSet(self, golden_message) - self.assertTrue(golden_message.SerializeToString() == golden_data) + self.assertEqual(golden_data, golden_message.SerializeToString()) golden_copy = copy.deepcopy(golden_message) - self.assertTrue(golden_copy.SerializeToString() == golden_data) + self.assertEqual(golden_data, golden_copy.SerializeToString()) def testGoldenExtensions(self): golden_data = test_util.GoldenFile('golden_message').read() @@ -81,9 +86,9 @@ class MessageTest(unittest.TestCase): all_set = unittest_pb2.TestAllExtensions() test_util.SetAllExtensions(all_set) self.assertEquals(all_set, golden_message) - self.assertTrue(golden_message.SerializeToString() == golden_data) + self.assertEqual(golden_data, golden_message.SerializeToString()) golden_copy = copy.deepcopy(golden_message) - self.assertTrue(golden_copy.SerializeToString() == golden_data) + self.assertEqual(golden_data, golden_copy.SerializeToString()) def testGoldenPackedMessage(self): golden_data = test_util.GoldenFile('golden_packed_fields_message').read() @@ -92,9 +97,9 @@ class MessageTest(unittest.TestCase): all_set = unittest_pb2.TestPackedTypes() test_util.SetAllPackedFields(all_set) self.assertEquals(all_set, golden_message) - self.assertTrue(all_set.SerializeToString() == golden_data) + self.assertEqual(golden_data, all_set.SerializeToString()) golden_copy = copy.deepcopy(golden_message) - self.assertTrue(golden_copy.SerializeToString() == golden_data) + self.assertEqual(golden_data, golden_copy.SerializeToString()) def testGoldenPackedExtensions(self): golden_data = test_util.GoldenFile('golden_packed_fields_message').read() @@ -103,9 +108,28 @@ class MessageTest(unittest.TestCase): all_set = unittest_pb2.TestPackedExtensions() test_util.SetAllPackedExtensions(all_set) self.assertEquals(all_set, golden_message) - self.assertTrue(all_set.SerializeToString() == golden_data) + self.assertEqual(golden_data, all_set.SerializeToString()) golden_copy = copy.deepcopy(golden_message) - self.assertTrue(golden_copy.SerializeToString() == golden_data) + self.assertEqual(golden_data, golden_copy.SerializeToString()) + + def testPickleSupport(self): + golden_data = test_util.GoldenFile('golden_message').read() + golden_message = unittest_pb2.TestAllTypes() + golden_message.ParseFromString(golden_data) + pickled_message = pickle.dumps(golden_message) + + unpickled_message = pickle.loads(pickled_message) + self.assertEquals(unpickled_message, golden_message) + + def testPickleIncompleteProto(self): + golden_message = unittest_pb2.TestRequired(a=1) + pickled_message = pickle.dumps(golden_message) + + unpickled_message = pickle.loads(pickled_message) + self.assertEquals(unpickled_message, golden_message) + self.assertEquals(unpickled_message.a, 1) + # This is still an incomplete proto - so serializing should fail + self.assertRaises(message.EncodeError, unpickled_message.SerializeToString) def testPositiveInfinity(self): golden_data = ('\x5D\x00\x00\x80\x7F' @@ -118,7 +142,7 @@ class MessageTest(unittest.TestCase): self.assertTrue(IsPosInf(golden_message.optional_double)) self.assertTrue(IsPosInf(golden_message.repeated_float[0])) self.assertTrue(IsPosInf(golden_message.repeated_double[0])) - self.assertTrue(golden_message.SerializeToString() == golden_data) + self.assertEqual(golden_data, golden_message.SerializeToString()) def testNegativeInfinity(self): golden_data = ('\x5D\x00\x00\x80\xFF' @@ -131,7 +155,7 @@ class MessageTest(unittest.TestCase): self.assertTrue(IsNegInf(golden_message.optional_double)) self.assertTrue(IsNegInf(golden_message.repeated_float[0])) self.assertTrue(IsNegInf(golden_message.repeated_double[0])) - self.assertTrue(golden_message.SerializeToString() == golden_data) + self.assertEqual(golden_data, golden_message.SerializeToString()) def testNotANumber(self): golden_data = ('\x5D\x00\x00\xC0\x7F' @@ -144,7 +168,18 @@ class MessageTest(unittest.TestCase): self.assertTrue(isnan(golden_message.optional_double)) self.assertTrue(isnan(golden_message.repeated_float[0])) self.assertTrue(isnan(golden_message.repeated_double[0])) - self.assertTrue(golden_message.SerializeToString() == golden_data) + + # The protocol buffer may serialize to any one of multiple different + # representations of a NaN. Rather than verify a specific representation, + # verify the serialized string can be converted into a correctly + # behaving protocol buffer. + serialized = golden_message.SerializeToString() + message = unittest_pb2.TestAllTypes() + message.ParseFromString(serialized) + self.assertTrue(isnan(message.optional_float)) + self.assertTrue(isnan(message.optional_double)) + self.assertTrue(isnan(message.repeated_float[0])) + self.assertTrue(isnan(message.repeated_double[0])) def testPositiveInfinityPacked(self): golden_data = ('\xA2\x06\x04\x00\x00\x80\x7F' @@ -153,7 +188,7 @@ class MessageTest(unittest.TestCase): golden_message.ParseFromString(golden_data) self.assertTrue(IsPosInf(golden_message.packed_float[0])) self.assertTrue(IsPosInf(golden_message.packed_double[0])) - self.assertTrue(golden_message.SerializeToString() == golden_data) + self.assertEqual(golden_data, golden_message.SerializeToString()) def testNegativeInfinityPacked(self): golden_data = ('\xA2\x06\x04\x00\x00\x80\xFF' @@ -162,7 +197,7 @@ class MessageTest(unittest.TestCase): golden_message.ParseFromString(golden_data) self.assertTrue(IsNegInf(golden_message.packed_float[0])) self.assertTrue(IsNegInf(golden_message.packed_double[0])) - self.assertTrue(golden_message.SerializeToString() == golden_data) + self.assertEqual(golden_data, golden_message.SerializeToString()) def testNotANumberPacked(self): golden_data = ('\xA2\x06\x04\x00\x00\xC0\x7F' @@ -171,7 +206,12 @@ class MessageTest(unittest.TestCase): golden_message.ParseFromString(golden_data) self.assertTrue(isnan(golden_message.packed_float[0])) self.assertTrue(isnan(golden_message.packed_double[0])) - self.assertTrue(golden_message.SerializeToString() == golden_data) + + serialized = golden_message.SerializeToString() + message = unittest_pb2.TestPackedTypes() + message.ParseFromString(serialized) + self.assertTrue(isnan(message.packed_float[0])) + self.assertTrue(isnan(message.packed_double[0])) def testExtremeFloatValues(self): message = unittest_pb2.TestAllTypes() @@ -218,7 +258,7 @@ class MessageTest(unittest.TestCase): message.ParseFromString(message.SerializeToString()) self.assertTrue(message.optional_float == -kMostNegExponentOneSigBit) - def testExtremeFloatValues(self): + def testExtremeDoubleValues(self): message = unittest_pb2.TestAllTypes() # Most positive exponent, no significand bits set. @@ -338,6 +378,117 @@ class MessageTest(unittest.TestCase): self.assertEqual(message.repeated_nested_message[4].bb, 5) self.assertEqual(message.repeated_nested_message[5].bb, 6) + def testRepeatedCompositeFieldSortArguments(self): + """Check sorting a repeated composite field using list.sort() arguments.""" + message = unittest_pb2.TestAllTypes() + + get_bb = operator.attrgetter('bb') + cmp_bb = lambda a, b: cmp(a.bb, b.bb) + message.repeated_nested_message.add().bb = 1 + message.repeated_nested_message.add().bb = 3 + message.repeated_nested_message.add().bb = 2 + message.repeated_nested_message.add().bb = 6 + message.repeated_nested_message.add().bb = 5 + message.repeated_nested_message.add().bb = 4 + message.repeated_nested_message.sort(key=get_bb) + self.assertEqual([k.bb for k in message.repeated_nested_message], + [1, 2, 3, 4, 5, 6]) + message.repeated_nested_message.sort(key=get_bb, reverse=True) + self.assertEqual([k.bb for k in message.repeated_nested_message], + [6, 5, 4, 3, 2, 1]) + message.repeated_nested_message.sort(sort_function=cmp_bb) + self.assertEqual([k.bb for k in message.repeated_nested_message], + [1, 2, 3, 4, 5, 6]) + message.repeated_nested_message.sort(cmp=cmp_bb, reverse=True) + self.assertEqual([k.bb for k in message.repeated_nested_message], + [6, 5, 4, 3, 2, 1]) + + def testRepeatedScalarFieldSortArguments(self): + """Check sorting a scalar field using list.sort() arguments.""" + message = unittest_pb2.TestAllTypes() + + abs_cmp = lambda a, b: cmp(abs(a), abs(b)) + message.repeated_int32.append(-3) + message.repeated_int32.append(-2) + message.repeated_int32.append(-1) + message.repeated_int32.sort(key=abs) + self.assertEqual(list(message.repeated_int32), [-1, -2, -3]) + message.repeated_int32.sort(key=abs, reverse=True) + self.assertEqual(list(message.repeated_int32), [-3, -2, -1]) + message.repeated_int32.sort(sort_function=abs_cmp) + self.assertEqual(list(message.repeated_int32), [-1, -2, -3]) + message.repeated_int32.sort(cmp=abs_cmp, reverse=True) + self.assertEqual(list(message.repeated_int32), [-3, -2, -1]) + + len_cmp = lambda a, b: cmp(len(a), len(b)) + message.repeated_string.append('aaa') + message.repeated_string.append('bb') + message.repeated_string.append('c') + message.repeated_string.sort(key=len) + self.assertEqual(list(message.repeated_string), ['c', 'bb', 'aaa']) + message.repeated_string.sort(key=len, reverse=True) + self.assertEqual(list(message.repeated_string), ['aaa', 'bb', 'c']) + message.repeated_string.sort(sort_function=len_cmp) + self.assertEqual(list(message.repeated_string), ['c', 'bb', 'aaa']) + message.repeated_string.sort(cmp=len_cmp, reverse=True) + self.assertEqual(list(message.repeated_string), ['aaa', 'bb', 'c']) + + def testParsingMerge(self): + """Check the merge behavior when a required or optional field appears + multiple times in the input.""" + messages = [ + unittest_pb2.TestAllTypes(), + unittest_pb2.TestAllTypes(), + unittest_pb2.TestAllTypes() ] + messages[0].optional_int32 = 1 + messages[1].optional_int64 = 2 + messages[2].optional_int32 = 3 + messages[2].optional_string = 'hello' + + merged_message = unittest_pb2.TestAllTypes() + merged_message.optional_int32 = 3 + merged_message.optional_int64 = 2 + merged_message.optional_string = 'hello' + + generator = unittest_pb2.TestParsingMerge.RepeatedFieldsGenerator() + generator.field1.extend(messages) + generator.field2.extend(messages) + generator.field3.extend(messages) + generator.ext1.extend(messages) + generator.ext2.extend(messages) + generator.group1.add().field1.MergeFrom(messages[0]) + generator.group1.add().field1.MergeFrom(messages[1]) + generator.group1.add().field1.MergeFrom(messages[2]) + generator.group2.add().field1.MergeFrom(messages[0]) + generator.group2.add().field1.MergeFrom(messages[1]) + generator.group2.add().field1.MergeFrom(messages[2]) + + data = generator.SerializeToString() + parsing_merge = unittest_pb2.TestParsingMerge() + parsing_merge.ParseFromString(data) + + # Required and optional fields should be merged. + self.assertEqual(parsing_merge.required_all_types, merged_message) + self.assertEqual(parsing_merge.optional_all_types, merged_message) + self.assertEqual(parsing_merge.optionalgroup.optional_group_all_types, + merged_message) + self.assertEqual(parsing_merge.Extensions[ + unittest_pb2.TestParsingMerge.optional_ext], + merged_message) + + # Repeated fields should not be merged. + self.assertEqual(len(parsing_merge.repeated_all_types), 3) + self.assertEqual(len(parsing_merge.repeatedgroup), 3) + self.assertEqual(len(parsing_merge.Extensions[ + unittest_pb2.TestParsingMerge.repeated_ext]), 3) + + + def testSortEmptyRepeatedCompositeContainer(self): + """Exercise a scenario that has led to segfaults in the past. + """ + m = unittest_pb2.TestAllTypes() + m.repeated_nested_message.sort() + if __name__ == '__main__': unittest.main() -- cgit v1.2.3