diff options
author | jieluo@google.com <jieluo@google.com@630680e5-0e50-0410-840e-4b1c322b438d> | 2014-08-12 21:10:30 +0000 |
---|---|---|
committer | jieluo@google.com <jieluo@google.com@630680e5-0e50-0410-840e-4b1c322b438d> | 2014-08-12 21:10:30 +0000 |
commit | bde4a3254a7de58911941b0fbf38e9dd992de973 (patch) | |
tree | 02b151c2ec6e9be2e9d5ea0efc406aabe6958ae7 /python/google/protobuf/internal/message_test.py | |
parent | d7339318a33c5f9e8b5dded4077223fbd4ebf229 (diff) |
down integrate python opensource to svn
Diffstat (limited to 'python/google/protobuf/internal/message_test.py')
-rwxr-xr-x | python/google/protobuf/internal/message_test.py | 274 |
1 files changed, 228 insertions, 46 deletions
diff --git a/python/google/protobuf/internal/message_test.py b/python/google/protobuf/internal/message_test.py index 53e9d507..f4c4ae00 100755 --- a/python/google/protobuf/internal/message_test.py +++ b/python/google/protobuf/internal/message_test.py @@ -47,9 +47,9 @@ import copy import math import operator import pickle +import sys -import unittest -from google.protobuf import unittest_import_pb2 +from google.apputils import basetest from google.protobuf import unittest_pb2 from google.protobuf.internal import api_implementation from google.protobuf.internal import test_util @@ -68,10 +68,23 @@ def IsPosInf(val): def IsNegInf(val): return isinf(val) and (val < 0) -class MessageTest(unittest.TestCase): + +class MessageTest(basetest.TestCase): + + def testBadUtf8String(self): + if api_implementation.Type() != 'python': + self.skipTest("Skipping testBadUtf8String, currently only the python " + "api implementation raises UnicodeDecodeError when a " + "string field contains bad utf-8.") + bad_utf8_data = test_util.GoldenFileData('bad_utf8_string') + with self.assertRaises(UnicodeDecodeError) as context: + unittest_pb2.TestAllTypes.FromString(bad_utf8_data) + self.assertIn('field: protobuf_unittest.TestAllTypes.optional_string', + str(context.exception)) def testGoldenMessage(self): - golden_data = test_util.GoldenFile('golden_message').read() + golden_data = test_util.GoldenFileData( + 'golden_message_oneof_implemented') golden_message = unittest_pb2.TestAllTypes() golden_message.ParseFromString(golden_data) test_util.ExpectAllFieldsSet(self, golden_message) @@ -80,7 +93,7 @@ class MessageTest(unittest.TestCase): self.assertEqual(golden_data, golden_copy.SerializeToString()) def testGoldenExtensions(self): - golden_data = test_util.GoldenFile('golden_message').read() + golden_data = test_util.GoldenFileData('golden_message') golden_message = unittest_pb2.TestAllExtensions() golden_message.ParseFromString(golden_data) all_set = unittest_pb2.TestAllExtensions() @@ -91,7 +104,7 @@ class MessageTest(unittest.TestCase): self.assertEqual(golden_data, golden_copy.SerializeToString()) def testGoldenPackedMessage(self): - golden_data = test_util.GoldenFile('golden_packed_fields_message').read() + golden_data = test_util.GoldenFileData('golden_packed_fields_message') golden_message = unittest_pb2.TestPackedTypes() golden_message.ParseFromString(golden_data) all_set = unittest_pb2.TestPackedTypes() @@ -102,7 +115,7 @@ class MessageTest(unittest.TestCase): self.assertEqual(golden_data, golden_copy.SerializeToString()) def testGoldenPackedExtensions(self): - golden_data = test_util.GoldenFile('golden_packed_fields_message').read() + golden_data = test_util.GoldenFileData('golden_packed_fields_message') golden_message = unittest_pb2.TestPackedExtensions() golden_message.ParseFromString(golden_data) all_set = unittest_pb2.TestPackedExtensions() @@ -113,7 +126,7 @@ class MessageTest(unittest.TestCase): self.assertEqual(golden_data, golden_copy.SerializeToString()) def testPickleSupport(self): - golden_data = test_util.GoldenFile('golden_message').read() + golden_data = test_util.GoldenFileData('golden_message') golden_message = unittest_pb2.TestAllTypes() golden_message.ParseFromString(golden_data) pickled_message = pickle.dumps(golden_message) @@ -121,6 +134,7 @@ class MessageTest(unittest.TestCase): unpickled_message = pickle.loads(pickled_message) self.assertEquals(unpickled_message, golden_message) + def testPickleIncompleteProto(self): golden_message = unittest_pb2.TestRequired(a=1) pickled_message = pickle.dumps(golden_message) @@ -132,10 +146,10 @@ class MessageTest(unittest.TestCase): self.assertRaises(message.EncodeError, unpickled_message.SerializeToString) def testPositiveInfinity(self): - golden_data = ('\x5D\x00\x00\x80\x7F' - '\x61\x00\x00\x00\x00\x00\x00\xF0\x7F' - '\xCD\x02\x00\x00\x80\x7F' - '\xD1\x02\x00\x00\x00\x00\x00\x00\xF0\x7F') + golden_data = (b'\x5D\x00\x00\x80\x7F' + b'\x61\x00\x00\x00\x00\x00\x00\xF0\x7F' + b'\xCD\x02\x00\x00\x80\x7F' + b'\xD1\x02\x00\x00\x00\x00\x00\x00\xF0\x7F') golden_message = unittest_pb2.TestAllTypes() golden_message.ParseFromString(golden_data) self.assertTrue(IsPosInf(golden_message.optional_float)) @@ -145,10 +159,10 @@ class MessageTest(unittest.TestCase): self.assertEqual(golden_data, golden_message.SerializeToString()) def testNegativeInfinity(self): - golden_data = ('\x5D\x00\x00\x80\xFF' - '\x61\x00\x00\x00\x00\x00\x00\xF0\xFF' - '\xCD\x02\x00\x00\x80\xFF' - '\xD1\x02\x00\x00\x00\x00\x00\x00\xF0\xFF') + golden_data = (b'\x5D\x00\x00\x80\xFF' + b'\x61\x00\x00\x00\x00\x00\x00\xF0\xFF' + b'\xCD\x02\x00\x00\x80\xFF' + b'\xD1\x02\x00\x00\x00\x00\x00\x00\xF0\xFF') golden_message = unittest_pb2.TestAllTypes() golden_message.ParseFromString(golden_data) self.assertTrue(IsNegInf(golden_message.optional_float)) @@ -158,10 +172,10 @@ class MessageTest(unittest.TestCase): self.assertEqual(golden_data, golden_message.SerializeToString()) def testNotANumber(self): - golden_data = ('\x5D\x00\x00\xC0\x7F' - '\x61\x00\x00\x00\x00\x00\x00\xF8\x7F' - '\xCD\x02\x00\x00\xC0\x7F' - '\xD1\x02\x00\x00\x00\x00\x00\x00\xF8\x7F') + golden_data = (b'\x5D\x00\x00\xC0\x7F' + b'\x61\x00\x00\x00\x00\x00\x00\xF8\x7F' + b'\xCD\x02\x00\x00\xC0\x7F' + b'\xD1\x02\x00\x00\x00\x00\x00\x00\xF8\x7F') golden_message = unittest_pb2.TestAllTypes() golden_message.ParseFromString(golden_data) self.assertTrue(isnan(golden_message.optional_float)) @@ -182,8 +196,8 @@ class MessageTest(unittest.TestCase): self.assertTrue(isnan(message.repeated_double[0])) def testPositiveInfinityPacked(self): - golden_data = ('\xA2\x06\x04\x00\x00\x80\x7F' - '\xAA\x06\x08\x00\x00\x00\x00\x00\x00\xF0\x7F') + golden_data = (b'\xA2\x06\x04\x00\x00\x80\x7F' + b'\xAA\x06\x08\x00\x00\x00\x00\x00\x00\xF0\x7F') golden_message = unittest_pb2.TestPackedTypes() golden_message.ParseFromString(golden_data) self.assertTrue(IsPosInf(golden_message.packed_float[0])) @@ -191,8 +205,8 @@ class MessageTest(unittest.TestCase): self.assertEqual(golden_data, golden_message.SerializeToString()) def testNegativeInfinityPacked(self): - golden_data = ('\xA2\x06\x04\x00\x00\x80\xFF' - '\xAA\x06\x08\x00\x00\x00\x00\x00\x00\xF0\xFF') + golden_data = (b'\xA2\x06\x04\x00\x00\x80\xFF' + b'\xAA\x06\x08\x00\x00\x00\x00\x00\x00\xF0\xFF') golden_message = unittest_pb2.TestPackedTypes() golden_message.ParseFromString(golden_data) self.assertTrue(IsNegInf(golden_message.packed_float[0])) @@ -200,8 +214,8 @@ class MessageTest(unittest.TestCase): self.assertEqual(golden_data, golden_message.SerializeToString()) def testNotANumberPacked(self): - golden_data = ('\xA2\x06\x04\x00\x00\xC0\x7F' - '\xAA\x06\x08\x00\x00\x00\x00\x00\x00\xF8\x7F') + golden_data = (b'\xA2\x06\x04\x00\x00\xC0\x7F' + b'\xAA\x06\x08\x00\x00\x00\x00\x00\x00\xF8\x7F') golden_message = unittest_pb2.TestPackedTypes() golden_message.ParseFromString(golden_data) self.assertTrue(isnan(golden_message.packed_float[0])) @@ -303,6 +317,26 @@ class MessageTest(unittest.TestCase): message.ParseFromString(message.SerializeToString()) self.assertTrue(message.optional_double == -kMostNegExponentOneSigBit) + def testFloatPrinting(self): + message = unittest_pb2.TestAllTypes() + message.optional_float = 2.0 + self.assertEqual(str(message), 'optional_float: 2.0\n') + + def testHighPrecisionFloatPrinting(self): + message = unittest_pb2.TestAllTypes() + message.optional_double = 0.12345678912345678 + if sys.version_info.major >= 3: + self.assertEqual(str(message), 'optional_double: 0.12345678912345678\n') + else: + self.assertEqual(str(message), 'optional_double: 0.123456789123\n') + + def testUnknownFieldPrinting(self): + populated = unittest_pb2.TestAllTypes() + test_util.SetAllNonLazyFields(populated) + empty = unittest_pb2.TestEmptyMessage() + empty.ParseFromString(populated.SerializeToString()) + self.assertEqual(str(empty), '') + def testSortingRepeatedScalarFieldsDefaultComparator(self): """Check some different types with the default comparator.""" message = unittest_pb2.TestAllTypes() @@ -332,13 +366,13 @@ class MessageTest(unittest.TestCase): self.assertEqual(message.repeated_string[1], 'b') self.assertEqual(message.repeated_string[2], 'c') - message.repeated_bytes.append('a') - message.repeated_bytes.append('c') - message.repeated_bytes.append('b') + message.repeated_bytes.append(b'a') + message.repeated_bytes.append(b'c') + message.repeated_bytes.append(b'b') message.repeated_bytes.sort() - self.assertEqual(message.repeated_bytes[0], 'a') - self.assertEqual(message.repeated_bytes[1], 'b') - self.assertEqual(message.repeated_bytes[2], 'c') + self.assertEqual(message.repeated_bytes[0], b'a') + self.assertEqual(message.repeated_bytes[1], b'b') + self.assertEqual(message.repeated_bytes[2], b'c') def testSortingRepeatedScalarFieldsCustomComparator(self): """Check some different types with custom comparator.""" @@ -347,7 +381,7 @@ class MessageTest(unittest.TestCase): message.repeated_int32.append(-3) message.repeated_int32.append(-2) message.repeated_int32.append(-1) - message.repeated_int32.sort(lambda x,y: cmp(abs(x), abs(y))) + message.repeated_int32.sort(key=abs) self.assertEqual(message.repeated_int32[0], -1) self.assertEqual(message.repeated_int32[1], -2) self.assertEqual(message.repeated_int32[2], -3) @@ -355,7 +389,7 @@ class MessageTest(unittest.TestCase): message.repeated_string.append('aaa') message.repeated_string.append('bb') message.repeated_string.append('c') - message.repeated_string.sort(lambda x,y: cmp(len(x), len(y))) + message.repeated_string.sort(key=len) self.assertEqual(message.repeated_string[0], 'c') self.assertEqual(message.repeated_string[1], 'bb') self.assertEqual(message.repeated_string[2], 'aaa') @@ -370,7 +404,7 @@ class MessageTest(unittest.TestCase): message.repeated_nested_message.add().bb = 6 message.repeated_nested_message.add().bb = 5 message.repeated_nested_message.add().bb = 4 - message.repeated_nested_message.sort(lambda x,y: cmp(x.bb, y.bb)) + message.repeated_nested_message.sort(key=operator.attrgetter('bb')) self.assertEqual(message.repeated_nested_message[0].bb, 1) self.assertEqual(message.repeated_nested_message[1].bb, 2) self.assertEqual(message.repeated_nested_message[2].bb, 3) @@ -396,6 +430,7 @@ class MessageTest(unittest.TestCase): message.repeated_nested_message.sort(key=get_bb, reverse=True) self.assertEqual([k.bb for k in message.repeated_nested_message], [6, 5, 4, 3, 2, 1]) + if sys.version_info.major >= 3: return # No cmp sorting in PY3. message.repeated_nested_message.sort(sort_function=cmp_bb) self.assertEqual([k.bb for k in message.repeated_nested_message], [1, 2, 3, 4, 5, 6]) @@ -407,7 +442,6 @@ class MessageTest(unittest.TestCase): """Check sorting a scalar field using list.sort() arguments.""" message = unittest_pb2.TestAllTypes() - abs_cmp = lambda a, b: cmp(abs(a), abs(b)) message.repeated_int32.append(-3) message.repeated_int32.append(-2) message.repeated_int32.append(-1) @@ -415,12 +449,13 @@ class MessageTest(unittest.TestCase): self.assertEqual(list(message.repeated_int32), [-1, -2, -3]) message.repeated_int32.sort(key=abs, reverse=True) self.assertEqual(list(message.repeated_int32), [-3, -2, -1]) - message.repeated_int32.sort(sort_function=abs_cmp) - self.assertEqual(list(message.repeated_int32), [-1, -2, -3]) - message.repeated_int32.sort(cmp=abs_cmp, reverse=True) - self.assertEqual(list(message.repeated_int32), [-3, -2, -1]) + if sys.version_info.major < 3: # No cmp sorting in PY3. + abs_cmp = lambda a, b: cmp(abs(a), abs(b)) + message.repeated_int32.sort(sort_function=abs_cmp) + self.assertEqual(list(message.repeated_int32), [-1, -2, -3]) + message.repeated_int32.sort(cmp=abs_cmp, reverse=True) + self.assertEqual(list(message.repeated_int32), [-3, -2, -1]) - len_cmp = lambda a, b: cmp(len(a), len(b)) message.repeated_string.append('aaa') message.repeated_string.append('bb') message.repeated_string.append('c') @@ -428,10 +463,47 @@ class MessageTest(unittest.TestCase): self.assertEqual(list(message.repeated_string), ['c', 'bb', 'aaa']) message.repeated_string.sort(key=len, reverse=True) self.assertEqual(list(message.repeated_string), ['aaa', 'bb', 'c']) - message.repeated_string.sort(sort_function=len_cmp) - self.assertEqual(list(message.repeated_string), ['c', 'bb', 'aaa']) - message.repeated_string.sort(cmp=len_cmp, reverse=True) - self.assertEqual(list(message.repeated_string), ['aaa', 'bb', 'c']) + if sys.version_info.major < 3: # No cmp sorting in PY3. + len_cmp = lambda a, b: cmp(len(a), len(b)) + message.repeated_string.sort(sort_function=len_cmp) + self.assertEqual(list(message.repeated_string), ['c', 'bb', 'aaa']) + message.repeated_string.sort(cmp=len_cmp, reverse=True) + self.assertEqual(list(message.repeated_string), ['aaa', 'bb', 'c']) + + def testRepeatedFieldsComparable(self): + m1 = unittest_pb2.TestAllTypes() + m2 = unittest_pb2.TestAllTypes() + m1.repeated_int32.append(0) + m1.repeated_int32.append(1) + m1.repeated_int32.append(2) + m2.repeated_int32.append(0) + m2.repeated_int32.append(1) + m2.repeated_int32.append(2) + m1.repeated_nested_message.add().bb = 1 + m1.repeated_nested_message.add().bb = 2 + m1.repeated_nested_message.add().bb = 3 + m2.repeated_nested_message.add().bb = 1 + m2.repeated_nested_message.add().bb = 2 + m2.repeated_nested_message.add().bb = 3 + + if sys.version_info.major >= 3: return # No cmp() in PY3. + + # These comparisons should not raise errors. + _ = m1 < m2 + _ = m1.repeated_nested_message < m2.repeated_nested_message + + # Make sure cmp always works. If it wasn't defined, these would be + # id() comparisons and would all fail. + self.assertEqual(cmp(m1, m2), 0) + self.assertEqual(cmp(m1.repeated_int32, m2.repeated_int32), 0) + self.assertEqual(cmp(m1.repeated_int32, [0, 1, 2]), 0) + self.assertEqual(cmp(m1.repeated_nested_message, + m2.repeated_nested_message), 0) + with self.assertRaises(TypeError): + # Can't compare repeated composite containers to lists. + cmp(m1.repeated_nested_message, m2.repeated_nested_message[:]) + + # TODO(anuraag): Implement extensiondict comparison in C++ and then add test def testParsingMerge(self): """Check the merge behavior when a required or optional field appears @@ -482,6 +554,87 @@ class MessageTest(unittest.TestCase): self.assertEqual(len(parsing_merge.Extensions[ unittest_pb2.TestParsingMerge.repeated_ext]), 3) + def ensureNestedMessageExists(self, msg, attribute): + """Make sure that a nested message object exists. + + As soon as a nested message attribute is accessed, it will be present in the + _fields dict, without being marked as actually being set. + """ + getattr(msg, attribute) + self.assertFalse(msg.HasField(attribute)) + + def testOneofGetCaseNonexistingField(self): + m = unittest_pb2.TestAllTypes() + self.assertRaises(ValueError, m.WhichOneof, 'no_such_oneof_field') + + def testOneofSemantics(self): + m = unittest_pb2.TestAllTypes() + self.assertIs(None, m.WhichOneof('oneof_field')) + + m.oneof_uint32 = 11 + self.assertEqual('oneof_uint32', m.WhichOneof('oneof_field')) + self.assertTrue(m.HasField('oneof_uint32')) + + m.oneof_string = u'foo' + self.assertEqual('oneof_string', m.WhichOneof('oneof_field')) + self.assertFalse(m.HasField('oneof_uint32')) + self.assertTrue(m.HasField('oneof_string')) + + m.oneof_nested_message.bb = 11 + self.assertEqual('oneof_nested_message', m.WhichOneof('oneof_field')) + self.assertFalse(m.HasField('oneof_string')) + self.assertTrue(m.HasField('oneof_nested_message')) + + m.oneof_bytes = b'bb' + self.assertEqual('oneof_bytes', m.WhichOneof('oneof_field')) + self.assertFalse(m.HasField('oneof_nested_message')) + self.assertTrue(m.HasField('oneof_bytes')) + + def testOneofCompositeFieldReadAccess(self): + m = unittest_pb2.TestAllTypes() + m.oneof_uint32 = 11 + + self.ensureNestedMessageExists(m, 'oneof_nested_message') + self.assertEqual('oneof_uint32', m.WhichOneof('oneof_field')) + self.assertEqual(11, m.oneof_uint32) + + def testOneofHasField(self): + m = unittest_pb2.TestAllTypes() + self.assertFalse(m.HasField('oneof_field')) + m.oneof_uint32 = 11 + self.assertTrue(m.HasField('oneof_field')) + m.oneof_bytes = b'bb' + self.assertTrue(m.HasField('oneof_field')) + m.ClearField('oneof_bytes') + self.assertFalse(m.HasField('oneof_field')) + + def testOneofClearField(self): + m = unittest_pb2.TestAllTypes() + m.oneof_uint32 = 11 + m.ClearField('oneof_field') + self.assertFalse(m.HasField('oneof_field')) + self.assertFalse(m.HasField('oneof_uint32')) + self.assertIs(None, m.WhichOneof('oneof_field')) + + def testOneofClearSetField(self): + m = unittest_pb2.TestAllTypes() + m.oneof_uint32 = 11 + m.ClearField('oneof_uint32') + self.assertFalse(m.HasField('oneof_field')) + self.assertFalse(m.HasField('oneof_uint32')) + self.assertIs(None, m.WhichOneof('oneof_field')) + + def testOneofClearUnsetField(self): + m = unittest_pb2.TestAllTypes() + m.oneof_uint32 = 11 + self.ensureNestedMessageExists(m, 'oneof_nested_message') + m.ClearField('oneof_nested_message') + self.assertEqual(11, m.oneof_uint32) + self.assertTrue(m.HasField('oneof_field')) + self.assertTrue(m.HasField('oneof_uint32')) + self.assertEqual('oneof_uint32', m.WhichOneof('oneof_field')) + + def testSortEmptyRepeatedCompositeContainer(self): """Exercise a scenario that has led to segfaults in the past. @@ -489,6 +642,35 @@ class MessageTest(unittest.TestCase): m = unittest_pb2.TestAllTypes() m.repeated_nested_message.sort() + def testHasFieldOnRepeatedField(self): + """Using HasField on a repeated field should raise an exception. + """ + m = unittest_pb2.TestAllTypes() + with self.assertRaises(ValueError) as _: + m.HasField('repeated_int32') + + +class ValidTypeNamesTest(basetest.TestCase): + + def assertImportFromName(self, msg, base_name): + # Parse <type 'module.class_name'> to extra 'some.name' as a string. + tp_name = str(type(msg)).split("'")[1] + valid_names = ('Repeated%sContainer' % base_name, + 'Repeated%sFieldContainer' % base_name) + self.assertTrue(any(tp_name.endswith(v) for v in valid_names), + '%r does end with any of %r' % (tp_name, valid_names)) + + parts = tp_name.split('.') + class_name = parts[-1] + module_name = '.'.join(parts[:-1]) + __import__(module_name, fromlist=[class_name]) + + def testTypeNamesCanBeImported(self): + # If import doesn't work, pickling won't work either. + pb = unittest_pb2.TestAllTypes() + self.assertImportFromName(pb.repeated_int32, 'Scalar') + self.assertImportFromName(pb.repeated_nested_message, 'Composite') + if __name__ == '__main__': - unittest.main() + basetest.main() |