aboutsummaryrefslogtreecommitdiffhomepage
path: root/python/google/protobuf/internal/message_test.py
diff options
context:
space:
mode:
authorGravatar jieluo@google.com <jieluo@google.com@630680e5-0e50-0410-840e-4b1c322b438d>2014-08-12 21:10:30 +0000
committerGravatar jieluo@google.com <jieluo@google.com@630680e5-0e50-0410-840e-4b1c322b438d>2014-08-12 21:10:30 +0000
commitbde4a3254a7de58911941b0fbf38e9dd992de973 (patch)
tree02b151c2ec6e9be2e9d5ea0efc406aabe6958ae7 /python/google/protobuf/internal/message_test.py
parentd7339318a33c5f9e8b5dded4077223fbd4ebf229 (diff)
down integrate python opensource to svn
Diffstat (limited to 'python/google/protobuf/internal/message_test.py')
-rwxr-xr-xpython/google/protobuf/internal/message_test.py274
1 files changed, 228 insertions, 46 deletions
diff --git a/python/google/protobuf/internal/message_test.py b/python/google/protobuf/internal/message_test.py
index 53e9d507..f4c4ae00 100755
--- a/python/google/protobuf/internal/message_test.py
+++ b/python/google/protobuf/internal/message_test.py
@@ -47,9 +47,9 @@ import copy
import math
import operator
import pickle
+import sys
-import unittest
-from google.protobuf import unittest_import_pb2
+from google.apputils import basetest
from google.protobuf import unittest_pb2
from google.protobuf.internal import api_implementation
from google.protobuf.internal import test_util
@@ -68,10 +68,23 @@ def IsPosInf(val):
def IsNegInf(val):
return isinf(val) and (val < 0)
-class MessageTest(unittest.TestCase):
+
+class MessageTest(basetest.TestCase):
+
+ def testBadUtf8String(self):
+ if api_implementation.Type() != 'python':
+ self.skipTest("Skipping testBadUtf8String, currently only the python "
+ "api implementation raises UnicodeDecodeError when a "
+ "string field contains bad utf-8.")
+ bad_utf8_data = test_util.GoldenFileData('bad_utf8_string')
+ with self.assertRaises(UnicodeDecodeError) as context:
+ unittest_pb2.TestAllTypes.FromString(bad_utf8_data)
+ self.assertIn('field: protobuf_unittest.TestAllTypes.optional_string',
+ str(context.exception))
def testGoldenMessage(self):
- golden_data = test_util.GoldenFile('golden_message').read()
+ golden_data = test_util.GoldenFileData(
+ 'golden_message_oneof_implemented')
golden_message = unittest_pb2.TestAllTypes()
golden_message.ParseFromString(golden_data)
test_util.ExpectAllFieldsSet(self, golden_message)
@@ -80,7 +93,7 @@ class MessageTest(unittest.TestCase):
self.assertEqual(golden_data, golden_copy.SerializeToString())
def testGoldenExtensions(self):
- golden_data = test_util.GoldenFile('golden_message').read()
+ golden_data = test_util.GoldenFileData('golden_message')
golden_message = unittest_pb2.TestAllExtensions()
golden_message.ParseFromString(golden_data)
all_set = unittest_pb2.TestAllExtensions()
@@ -91,7 +104,7 @@ class MessageTest(unittest.TestCase):
self.assertEqual(golden_data, golden_copy.SerializeToString())
def testGoldenPackedMessage(self):
- golden_data = test_util.GoldenFile('golden_packed_fields_message').read()
+ golden_data = test_util.GoldenFileData('golden_packed_fields_message')
golden_message = unittest_pb2.TestPackedTypes()
golden_message.ParseFromString(golden_data)
all_set = unittest_pb2.TestPackedTypes()
@@ -102,7 +115,7 @@ class MessageTest(unittest.TestCase):
self.assertEqual(golden_data, golden_copy.SerializeToString())
def testGoldenPackedExtensions(self):
- golden_data = test_util.GoldenFile('golden_packed_fields_message').read()
+ golden_data = test_util.GoldenFileData('golden_packed_fields_message')
golden_message = unittest_pb2.TestPackedExtensions()
golden_message.ParseFromString(golden_data)
all_set = unittest_pb2.TestPackedExtensions()
@@ -113,7 +126,7 @@ class MessageTest(unittest.TestCase):
self.assertEqual(golden_data, golden_copy.SerializeToString())
def testPickleSupport(self):
- golden_data = test_util.GoldenFile('golden_message').read()
+ golden_data = test_util.GoldenFileData('golden_message')
golden_message = unittest_pb2.TestAllTypes()
golden_message.ParseFromString(golden_data)
pickled_message = pickle.dumps(golden_message)
@@ -121,6 +134,7 @@ class MessageTest(unittest.TestCase):
unpickled_message = pickle.loads(pickled_message)
self.assertEquals(unpickled_message, golden_message)
+
def testPickleIncompleteProto(self):
golden_message = unittest_pb2.TestRequired(a=1)
pickled_message = pickle.dumps(golden_message)
@@ -132,10 +146,10 @@ class MessageTest(unittest.TestCase):
self.assertRaises(message.EncodeError, unpickled_message.SerializeToString)
def testPositiveInfinity(self):
- golden_data = ('\x5D\x00\x00\x80\x7F'
- '\x61\x00\x00\x00\x00\x00\x00\xF0\x7F'
- '\xCD\x02\x00\x00\x80\x7F'
- '\xD1\x02\x00\x00\x00\x00\x00\x00\xF0\x7F')
+ golden_data = (b'\x5D\x00\x00\x80\x7F'
+ b'\x61\x00\x00\x00\x00\x00\x00\xF0\x7F'
+ b'\xCD\x02\x00\x00\x80\x7F'
+ b'\xD1\x02\x00\x00\x00\x00\x00\x00\xF0\x7F')
golden_message = unittest_pb2.TestAllTypes()
golden_message.ParseFromString(golden_data)
self.assertTrue(IsPosInf(golden_message.optional_float))
@@ -145,10 +159,10 @@ class MessageTest(unittest.TestCase):
self.assertEqual(golden_data, golden_message.SerializeToString())
def testNegativeInfinity(self):
- golden_data = ('\x5D\x00\x00\x80\xFF'
- '\x61\x00\x00\x00\x00\x00\x00\xF0\xFF'
- '\xCD\x02\x00\x00\x80\xFF'
- '\xD1\x02\x00\x00\x00\x00\x00\x00\xF0\xFF')
+ golden_data = (b'\x5D\x00\x00\x80\xFF'
+ b'\x61\x00\x00\x00\x00\x00\x00\xF0\xFF'
+ b'\xCD\x02\x00\x00\x80\xFF'
+ b'\xD1\x02\x00\x00\x00\x00\x00\x00\xF0\xFF')
golden_message = unittest_pb2.TestAllTypes()
golden_message.ParseFromString(golden_data)
self.assertTrue(IsNegInf(golden_message.optional_float))
@@ -158,10 +172,10 @@ class MessageTest(unittest.TestCase):
self.assertEqual(golden_data, golden_message.SerializeToString())
def testNotANumber(self):
- golden_data = ('\x5D\x00\x00\xC0\x7F'
- '\x61\x00\x00\x00\x00\x00\x00\xF8\x7F'
- '\xCD\x02\x00\x00\xC0\x7F'
- '\xD1\x02\x00\x00\x00\x00\x00\x00\xF8\x7F')
+ golden_data = (b'\x5D\x00\x00\xC0\x7F'
+ b'\x61\x00\x00\x00\x00\x00\x00\xF8\x7F'
+ b'\xCD\x02\x00\x00\xC0\x7F'
+ b'\xD1\x02\x00\x00\x00\x00\x00\x00\xF8\x7F')
golden_message = unittest_pb2.TestAllTypes()
golden_message.ParseFromString(golden_data)
self.assertTrue(isnan(golden_message.optional_float))
@@ -182,8 +196,8 @@ class MessageTest(unittest.TestCase):
self.assertTrue(isnan(message.repeated_double[0]))
def testPositiveInfinityPacked(self):
- golden_data = ('\xA2\x06\x04\x00\x00\x80\x7F'
- '\xAA\x06\x08\x00\x00\x00\x00\x00\x00\xF0\x7F')
+ golden_data = (b'\xA2\x06\x04\x00\x00\x80\x7F'
+ b'\xAA\x06\x08\x00\x00\x00\x00\x00\x00\xF0\x7F')
golden_message = unittest_pb2.TestPackedTypes()
golden_message.ParseFromString(golden_data)
self.assertTrue(IsPosInf(golden_message.packed_float[0]))
@@ -191,8 +205,8 @@ class MessageTest(unittest.TestCase):
self.assertEqual(golden_data, golden_message.SerializeToString())
def testNegativeInfinityPacked(self):
- golden_data = ('\xA2\x06\x04\x00\x00\x80\xFF'
- '\xAA\x06\x08\x00\x00\x00\x00\x00\x00\xF0\xFF')
+ golden_data = (b'\xA2\x06\x04\x00\x00\x80\xFF'
+ b'\xAA\x06\x08\x00\x00\x00\x00\x00\x00\xF0\xFF')
golden_message = unittest_pb2.TestPackedTypes()
golden_message.ParseFromString(golden_data)
self.assertTrue(IsNegInf(golden_message.packed_float[0]))
@@ -200,8 +214,8 @@ class MessageTest(unittest.TestCase):
self.assertEqual(golden_data, golden_message.SerializeToString())
def testNotANumberPacked(self):
- golden_data = ('\xA2\x06\x04\x00\x00\xC0\x7F'
- '\xAA\x06\x08\x00\x00\x00\x00\x00\x00\xF8\x7F')
+ golden_data = (b'\xA2\x06\x04\x00\x00\xC0\x7F'
+ b'\xAA\x06\x08\x00\x00\x00\x00\x00\x00\xF8\x7F')
golden_message = unittest_pb2.TestPackedTypes()
golden_message.ParseFromString(golden_data)
self.assertTrue(isnan(golden_message.packed_float[0]))
@@ -303,6 +317,26 @@ class MessageTest(unittest.TestCase):
message.ParseFromString(message.SerializeToString())
self.assertTrue(message.optional_double == -kMostNegExponentOneSigBit)
+ def testFloatPrinting(self):
+ message = unittest_pb2.TestAllTypes()
+ message.optional_float = 2.0
+ self.assertEqual(str(message), 'optional_float: 2.0\n')
+
+ def testHighPrecisionFloatPrinting(self):
+ message = unittest_pb2.TestAllTypes()
+ message.optional_double = 0.12345678912345678
+ if sys.version_info.major >= 3:
+ self.assertEqual(str(message), 'optional_double: 0.12345678912345678\n')
+ else:
+ self.assertEqual(str(message), 'optional_double: 0.123456789123\n')
+
+ def testUnknownFieldPrinting(self):
+ populated = unittest_pb2.TestAllTypes()
+ test_util.SetAllNonLazyFields(populated)
+ empty = unittest_pb2.TestEmptyMessage()
+ empty.ParseFromString(populated.SerializeToString())
+ self.assertEqual(str(empty), '')
+
def testSortingRepeatedScalarFieldsDefaultComparator(self):
"""Check some different types with the default comparator."""
message = unittest_pb2.TestAllTypes()
@@ -332,13 +366,13 @@ class MessageTest(unittest.TestCase):
self.assertEqual(message.repeated_string[1], 'b')
self.assertEqual(message.repeated_string[2], 'c')
- message.repeated_bytes.append('a')
- message.repeated_bytes.append('c')
- message.repeated_bytes.append('b')
+ message.repeated_bytes.append(b'a')
+ message.repeated_bytes.append(b'c')
+ message.repeated_bytes.append(b'b')
message.repeated_bytes.sort()
- self.assertEqual(message.repeated_bytes[0], 'a')
- self.assertEqual(message.repeated_bytes[1], 'b')
- self.assertEqual(message.repeated_bytes[2], 'c')
+ self.assertEqual(message.repeated_bytes[0], b'a')
+ self.assertEqual(message.repeated_bytes[1], b'b')
+ self.assertEqual(message.repeated_bytes[2], b'c')
def testSortingRepeatedScalarFieldsCustomComparator(self):
"""Check some different types with custom comparator."""
@@ -347,7 +381,7 @@ class MessageTest(unittest.TestCase):
message.repeated_int32.append(-3)
message.repeated_int32.append(-2)
message.repeated_int32.append(-1)
- message.repeated_int32.sort(lambda x,y: cmp(abs(x), abs(y)))
+ message.repeated_int32.sort(key=abs)
self.assertEqual(message.repeated_int32[0], -1)
self.assertEqual(message.repeated_int32[1], -2)
self.assertEqual(message.repeated_int32[2], -3)
@@ -355,7 +389,7 @@ class MessageTest(unittest.TestCase):
message.repeated_string.append('aaa')
message.repeated_string.append('bb')
message.repeated_string.append('c')
- message.repeated_string.sort(lambda x,y: cmp(len(x), len(y)))
+ message.repeated_string.sort(key=len)
self.assertEqual(message.repeated_string[0], 'c')
self.assertEqual(message.repeated_string[1], 'bb')
self.assertEqual(message.repeated_string[2], 'aaa')
@@ -370,7 +404,7 @@ class MessageTest(unittest.TestCase):
message.repeated_nested_message.add().bb = 6
message.repeated_nested_message.add().bb = 5
message.repeated_nested_message.add().bb = 4
- message.repeated_nested_message.sort(lambda x,y: cmp(x.bb, y.bb))
+ message.repeated_nested_message.sort(key=operator.attrgetter('bb'))
self.assertEqual(message.repeated_nested_message[0].bb, 1)
self.assertEqual(message.repeated_nested_message[1].bb, 2)
self.assertEqual(message.repeated_nested_message[2].bb, 3)
@@ -396,6 +430,7 @@ class MessageTest(unittest.TestCase):
message.repeated_nested_message.sort(key=get_bb, reverse=True)
self.assertEqual([k.bb for k in message.repeated_nested_message],
[6, 5, 4, 3, 2, 1])
+ if sys.version_info.major >= 3: return # No cmp sorting in PY3.
message.repeated_nested_message.sort(sort_function=cmp_bb)
self.assertEqual([k.bb for k in message.repeated_nested_message],
[1, 2, 3, 4, 5, 6])
@@ -407,7 +442,6 @@ class MessageTest(unittest.TestCase):
"""Check sorting a scalar field using list.sort() arguments."""
message = unittest_pb2.TestAllTypes()
- abs_cmp = lambda a, b: cmp(abs(a), abs(b))
message.repeated_int32.append(-3)
message.repeated_int32.append(-2)
message.repeated_int32.append(-1)
@@ -415,12 +449,13 @@ class MessageTest(unittest.TestCase):
self.assertEqual(list(message.repeated_int32), [-1, -2, -3])
message.repeated_int32.sort(key=abs, reverse=True)
self.assertEqual(list(message.repeated_int32), [-3, -2, -1])
- message.repeated_int32.sort(sort_function=abs_cmp)
- self.assertEqual(list(message.repeated_int32), [-1, -2, -3])
- message.repeated_int32.sort(cmp=abs_cmp, reverse=True)
- self.assertEqual(list(message.repeated_int32), [-3, -2, -1])
+ if sys.version_info.major < 3: # No cmp sorting in PY3.
+ abs_cmp = lambda a, b: cmp(abs(a), abs(b))
+ message.repeated_int32.sort(sort_function=abs_cmp)
+ self.assertEqual(list(message.repeated_int32), [-1, -2, -3])
+ message.repeated_int32.sort(cmp=abs_cmp, reverse=True)
+ self.assertEqual(list(message.repeated_int32), [-3, -2, -1])
- len_cmp = lambda a, b: cmp(len(a), len(b))
message.repeated_string.append('aaa')
message.repeated_string.append('bb')
message.repeated_string.append('c')
@@ -428,10 +463,47 @@ class MessageTest(unittest.TestCase):
self.assertEqual(list(message.repeated_string), ['c', 'bb', 'aaa'])
message.repeated_string.sort(key=len, reverse=True)
self.assertEqual(list(message.repeated_string), ['aaa', 'bb', 'c'])
- message.repeated_string.sort(sort_function=len_cmp)
- self.assertEqual(list(message.repeated_string), ['c', 'bb', 'aaa'])
- message.repeated_string.sort(cmp=len_cmp, reverse=True)
- self.assertEqual(list(message.repeated_string), ['aaa', 'bb', 'c'])
+ if sys.version_info.major < 3: # No cmp sorting in PY3.
+ len_cmp = lambda a, b: cmp(len(a), len(b))
+ message.repeated_string.sort(sort_function=len_cmp)
+ self.assertEqual(list(message.repeated_string), ['c', 'bb', 'aaa'])
+ message.repeated_string.sort(cmp=len_cmp, reverse=True)
+ self.assertEqual(list(message.repeated_string), ['aaa', 'bb', 'c'])
+
+ def testRepeatedFieldsComparable(self):
+ m1 = unittest_pb2.TestAllTypes()
+ m2 = unittest_pb2.TestAllTypes()
+ m1.repeated_int32.append(0)
+ m1.repeated_int32.append(1)
+ m1.repeated_int32.append(2)
+ m2.repeated_int32.append(0)
+ m2.repeated_int32.append(1)
+ m2.repeated_int32.append(2)
+ m1.repeated_nested_message.add().bb = 1
+ m1.repeated_nested_message.add().bb = 2
+ m1.repeated_nested_message.add().bb = 3
+ m2.repeated_nested_message.add().bb = 1
+ m2.repeated_nested_message.add().bb = 2
+ m2.repeated_nested_message.add().bb = 3
+
+ if sys.version_info.major >= 3: return # No cmp() in PY3.
+
+ # These comparisons should not raise errors.
+ _ = m1 < m2
+ _ = m1.repeated_nested_message < m2.repeated_nested_message
+
+ # Make sure cmp always works. If it wasn't defined, these would be
+ # id() comparisons and would all fail.
+ self.assertEqual(cmp(m1, m2), 0)
+ self.assertEqual(cmp(m1.repeated_int32, m2.repeated_int32), 0)
+ self.assertEqual(cmp(m1.repeated_int32, [0, 1, 2]), 0)
+ self.assertEqual(cmp(m1.repeated_nested_message,
+ m2.repeated_nested_message), 0)
+ with self.assertRaises(TypeError):
+ # Can't compare repeated composite containers to lists.
+ cmp(m1.repeated_nested_message, m2.repeated_nested_message[:])
+
+ # TODO(anuraag): Implement extensiondict comparison in C++ and then add test
def testParsingMerge(self):
"""Check the merge behavior when a required or optional field appears
@@ -482,6 +554,87 @@ class MessageTest(unittest.TestCase):
self.assertEqual(len(parsing_merge.Extensions[
unittest_pb2.TestParsingMerge.repeated_ext]), 3)
+ def ensureNestedMessageExists(self, msg, attribute):
+ """Make sure that a nested message object exists.
+
+ As soon as a nested message attribute is accessed, it will be present in the
+ _fields dict, without being marked as actually being set.
+ """
+ getattr(msg, attribute)
+ self.assertFalse(msg.HasField(attribute))
+
+ def testOneofGetCaseNonexistingField(self):
+ m = unittest_pb2.TestAllTypes()
+ self.assertRaises(ValueError, m.WhichOneof, 'no_such_oneof_field')
+
+ def testOneofSemantics(self):
+ m = unittest_pb2.TestAllTypes()
+ self.assertIs(None, m.WhichOneof('oneof_field'))
+
+ m.oneof_uint32 = 11
+ self.assertEqual('oneof_uint32', m.WhichOneof('oneof_field'))
+ self.assertTrue(m.HasField('oneof_uint32'))
+
+ m.oneof_string = u'foo'
+ self.assertEqual('oneof_string', m.WhichOneof('oneof_field'))
+ self.assertFalse(m.HasField('oneof_uint32'))
+ self.assertTrue(m.HasField('oneof_string'))
+
+ m.oneof_nested_message.bb = 11
+ self.assertEqual('oneof_nested_message', m.WhichOneof('oneof_field'))
+ self.assertFalse(m.HasField('oneof_string'))
+ self.assertTrue(m.HasField('oneof_nested_message'))
+
+ m.oneof_bytes = b'bb'
+ self.assertEqual('oneof_bytes', m.WhichOneof('oneof_field'))
+ self.assertFalse(m.HasField('oneof_nested_message'))
+ self.assertTrue(m.HasField('oneof_bytes'))
+
+ def testOneofCompositeFieldReadAccess(self):
+ m = unittest_pb2.TestAllTypes()
+ m.oneof_uint32 = 11
+
+ self.ensureNestedMessageExists(m, 'oneof_nested_message')
+ self.assertEqual('oneof_uint32', m.WhichOneof('oneof_field'))
+ self.assertEqual(11, m.oneof_uint32)
+
+ def testOneofHasField(self):
+ m = unittest_pb2.TestAllTypes()
+ self.assertFalse(m.HasField('oneof_field'))
+ m.oneof_uint32 = 11
+ self.assertTrue(m.HasField('oneof_field'))
+ m.oneof_bytes = b'bb'
+ self.assertTrue(m.HasField('oneof_field'))
+ m.ClearField('oneof_bytes')
+ self.assertFalse(m.HasField('oneof_field'))
+
+ def testOneofClearField(self):
+ m = unittest_pb2.TestAllTypes()
+ m.oneof_uint32 = 11
+ m.ClearField('oneof_field')
+ self.assertFalse(m.HasField('oneof_field'))
+ self.assertFalse(m.HasField('oneof_uint32'))
+ self.assertIs(None, m.WhichOneof('oneof_field'))
+
+ def testOneofClearSetField(self):
+ m = unittest_pb2.TestAllTypes()
+ m.oneof_uint32 = 11
+ m.ClearField('oneof_uint32')
+ self.assertFalse(m.HasField('oneof_field'))
+ self.assertFalse(m.HasField('oneof_uint32'))
+ self.assertIs(None, m.WhichOneof('oneof_field'))
+
+ def testOneofClearUnsetField(self):
+ m = unittest_pb2.TestAllTypes()
+ m.oneof_uint32 = 11
+ self.ensureNestedMessageExists(m, 'oneof_nested_message')
+ m.ClearField('oneof_nested_message')
+ self.assertEqual(11, m.oneof_uint32)
+ self.assertTrue(m.HasField('oneof_field'))
+ self.assertTrue(m.HasField('oneof_uint32'))
+ self.assertEqual('oneof_uint32', m.WhichOneof('oneof_field'))
+
+
def testSortEmptyRepeatedCompositeContainer(self):
"""Exercise a scenario that has led to segfaults in the past.
@@ -489,6 +642,35 @@ class MessageTest(unittest.TestCase):
m = unittest_pb2.TestAllTypes()
m.repeated_nested_message.sort()
+ def testHasFieldOnRepeatedField(self):
+ """Using HasField on a repeated field should raise an exception.
+ """
+ m = unittest_pb2.TestAllTypes()
+ with self.assertRaises(ValueError) as _:
+ m.HasField('repeated_int32')
+
+
+class ValidTypeNamesTest(basetest.TestCase):
+
+ def assertImportFromName(self, msg, base_name):
+ # Parse <type 'module.class_name'> to extra 'some.name' as a string.
+ tp_name = str(type(msg)).split("'")[1]
+ valid_names = ('Repeated%sContainer' % base_name,
+ 'Repeated%sFieldContainer' % base_name)
+ self.assertTrue(any(tp_name.endswith(v) for v in valid_names),
+ '%r does end with any of %r' % (tp_name, valid_names))
+
+ parts = tp_name.split('.')
+ class_name = parts[-1]
+ module_name = '.'.join(parts[:-1])
+ __import__(module_name, fromlist=[class_name])
+
+ def testTypeNamesCanBeImported(self):
+ # If import doesn't work, pickling won't work either.
+ pb = unittest_pb2.TestAllTypes()
+ self.assertImportFromName(pb.repeated_int32, 'Scalar')
+ self.assertImportFromName(pb.repeated_nested_message, 'Composite')
+
if __name__ == '__main__':
- unittest.main()
+ basetest.main()