aboutsummaryrefslogtreecommitdiffhomepage
path: root/python/google/protobuf/internal/reflection_test.py
diff options
context:
space:
mode:
authorGravatar kenton@google.com <kenton@google.com@630680e5-0e50-0410-840e-4b1c322b438d>2009-12-18 02:11:36 +0000
committerGravatar kenton@google.com <kenton@google.com@630680e5-0e50-0410-840e-4b1c322b438d>2009-12-18 02:11:36 +0000
commitfccb146e3fe437b0df1e9c50d4b8e1080ddb4bd9 (patch)
tree9f2d9fe0267d96a54e541377ffeada3d0bff0d1d /python/google/protobuf/internal/reflection_test.py
parentd5cf7b55a6a1f959d1646785f63ca2b62da78079 (diff)
Massive roll-up of changes. See CHANGES.txt.
Diffstat (limited to 'python/google/protobuf/internal/reflection_test.py')
-rwxr-xr-xpython/google/protobuf/internal/reflection_test.py354
1 files changed, 307 insertions, 47 deletions
diff --git a/python/google/protobuf/internal/reflection_test.py b/python/google/protobuf/internal/reflection_test.py
index 86101774..2c9fa30b 100755
--- a/python/google/protobuf/internal/reflection_test.py
+++ b/python/google/protobuf/internal/reflection_test.py
@@ -38,6 +38,7 @@ pure-Python protocol compiler.
__author__ = 'robinson@google.com (Will Robinson)'
import operator
+import struct
import unittest
# TODO(robinson): When we split this test in two, only some of these imports
@@ -56,6 +57,51 @@ from google.protobuf.internal import test_util
from google.protobuf.internal import decoder
+class _MiniDecoder(object):
+ """Decodes a stream of values from a string.
+
+ Once upon a time we actually had a class called decoder.Decoder. Then we
+ got rid of it during a redesign that made decoding much, much faster overall.
+ But a couple tests in this file used it to check that the serialized form of
+ a message was correct. So, this class implements just the methods that were
+ used by said tests, so that we don't have to rewrite the tests.
+ """
+
+ def __init__(self, bytes):
+ self._bytes = bytes
+ self._pos = 0
+
+ def ReadVarint(self):
+ result, self._pos = decoder._DecodeVarint(self._bytes, self._pos)
+ return result
+
+ ReadInt32 = ReadVarint
+ ReadInt64 = ReadVarint
+ ReadUInt32 = ReadVarint
+ ReadUInt64 = ReadVarint
+
+ def ReadSInt64(self):
+ return wire_format.ZigZagDecode(self.ReadVarint())
+
+ ReadSInt32 = ReadSInt64
+
+ def ReadFieldNumberAndWireType(self):
+ return wire_format.UnpackTag(self.ReadVarint())
+
+ def ReadFloat(self):
+ result = struct.unpack("<f", self._bytes[self._pos:self._pos+4])[0]
+ self._pos += 4
+ return result
+
+ def ReadDouble(self):
+ result = struct.unpack("<d", self._bytes[self._pos:self._pos+8])[0]
+ self._pos += 8
+ return result
+
+ def EndOfStream(self):
+ return self._pos == len(self._bytes)
+
+
class ReflectionTest(unittest.TestCase):
def assertIs(self, values, others):
@@ -63,6 +109,97 @@ class ReflectionTest(unittest.TestCase):
for i in range(len(values)):
self.assertTrue(values[i] is others[i])
+ def testScalarConstructor(self):
+ # Constructor with only scalar types should succeed.
+ proto = unittest_pb2.TestAllTypes(
+ optional_int32=24,
+ optional_double=54.321,
+ optional_string='optional_string')
+
+ self.assertEqual(24, proto.optional_int32)
+ self.assertEqual(54.321, proto.optional_double)
+ self.assertEqual('optional_string', proto.optional_string)
+
+ def testRepeatedScalarConstructor(self):
+ # Constructor with only repeated scalar types should succeed.
+ proto = unittest_pb2.TestAllTypes(
+ repeated_int32=[1, 2, 3, 4],
+ repeated_double=[1.23, 54.321],
+ repeated_bool=[True, False, False],
+ repeated_string=["optional_string"])
+
+ self.assertEquals([1, 2, 3, 4], list(proto.repeated_int32))
+ self.assertEquals([1.23, 54.321], list(proto.repeated_double))
+ self.assertEquals([True, False, False], list(proto.repeated_bool))
+ self.assertEquals(["optional_string"], list(proto.repeated_string))
+
+ def testRepeatedCompositeConstructor(self):
+ # Constructor with only repeated composite types should succeed.
+ proto = unittest_pb2.TestAllTypes(
+ repeated_nested_message=[
+ unittest_pb2.TestAllTypes.NestedMessage(
+ bb=unittest_pb2.TestAllTypes.FOO),
+ unittest_pb2.TestAllTypes.NestedMessage(
+ bb=unittest_pb2.TestAllTypes.BAR)],
+ repeated_foreign_message=[
+ unittest_pb2.ForeignMessage(c=-43),
+ unittest_pb2.ForeignMessage(c=45324),
+ unittest_pb2.ForeignMessage(c=12)],
+ repeatedgroup=[
+ unittest_pb2.TestAllTypes.RepeatedGroup(),
+ unittest_pb2.TestAllTypes.RepeatedGroup(a=1),
+ unittest_pb2.TestAllTypes.RepeatedGroup(a=2)])
+
+ self.assertEquals(
+ [unittest_pb2.TestAllTypes.NestedMessage(
+ bb=unittest_pb2.TestAllTypes.FOO),
+ unittest_pb2.TestAllTypes.NestedMessage(
+ bb=unittest_pb2.TestAllTypes.BAR)],
+ list(proto.repeated_nested_message))
+ self.assertEquals(
+ [unittest_pb2.ForeignMessage(c=-43),
+ unittest_pb2.ForeignMessage(c=45324),
+ unittest_pb2.ForeignMessage(c=12)],
+ list(proto.repeated_foreign_message))
+ self.assertEquals(
+ [unittest_pb2.TestAllTypes.RepeatedGroup(),
+ unittest_pb2.TestAllTypes.RepeatedGroup(a=1),
+ unittest_pb2.TestAllTypes.RepeatedGroup(a=2)],
+ list(proto.repeatedgroup))
+
+ def testMixedConstructor(self):
+ # Constructor with only mixed types should succeed.
+ proto = unittest_pb2.TestAllTypes(
+ optional_int32=24,
+ optional_string='optional_string',
+ repeated_double=[1.23, 54.321],
+ repeated_bool=[True, False, False],
+ repeated_nested_message=[
+ unittest_pb2.TestAllTypes.NestedMessage(
+ bb=unittest_pb2.TestAllTypes.FOO),
+ unittest_pb2.TestAllTypes.NestedMessage(
+ bb=unittest_pb2.TestAllTypes.BAR)],
+ repeated_foreign_message=[
+ unittest_pb2.ForeignMessage(c=-43),
+ unittest_pb2.ForeignMessage(c=45324),
+ unittest_pb2.ForeignMessage(c=12)])
+
+ self.assertEqual(24, proto.optional_int32)
+ self.assertEqual('optional_string', proto.optional_string)
+ self.assertEquals([1.23, 54.321], list(proto.repeated_double))
+ self.assertEquals([True, False, False], list(proto.repeated_bool))
+ self.assertEquals(
+ [unittest_pb2.TestAllTypes.NestedMessage(
+ bb=unittest_pb2.TestAllTypes.FOO),
+ unittest_pb2.TestAllTypes.NestedMessage(
+ bb=unittest_pb2.TestAllTypes.BAR)],
+ list(proto.repeated_nested_message))
+ self.assertEquals(
+ [unittest_pb2.ForeignMessage(c=-43),
+ unittest_pb2.ForeignMessage(c=45324),
+ unittest_pb2.ForeignMessage(c=12)],
+ list(proto.repeated_foreign_message))
+
def testSimpleHasBits(self):
# Test a scalar.
proto = unittest_pb2.TestAllTypes()
@@ -218,12 +355,23 @@ class ReflectionTest(unittest.TestCase):
proto.optional_fixed32 = 1
proto.optional_int32 = 5
proto.optional_string = 'foo'
+ # Access sub-message but don't set it yet.
+ nested_message = proto.optional_nested_message
self.assertEqual(
[ (proto.DESCRIPTOR.fields_by_name['optional_int32' ], 5),
(proto.DESCRIPTOR.fields_by_name['optional_fixed32'], 1),
(proto.DESCRIPTOR.fields_by_name['optional_string' ], 'foo') ],
proto.ListFields())
+ proto.optional_nested_message.bb = 123
+ self.assertEqual(
+ [ (proto.DESCRIPTOR.fields_by_name['optional_int32' ], 5),
+ (proto.DESCRIPTOR.fields_by_name['optional_fixed32'], 1),
+ (proto.DESCRIPTOR.fields_by_name['optional_string' ], 'foo'),
+ (proto.DESCRIPTOR.fields_by_name['optional_nested_message' ],
+ nested_message) ],
+ proto.ListFields())
+
def testRepeatedListFields(self):
proto = unittest_pb2.TestAllTypes()
proto.repeated_fixed32.append(1)
@@ -234,6 +382,7 @@ class ReflectionTest(unittest.TestCase):
proto.repeated_string.append('baz')
proto.repeated_string.extend(str(x) for x in xrange(2))
proto.optional_int32 = 21
+ proto.repeated_bool # Access but don't set anything; should not be listed.
self.assertEqual(
[ (proto.DESCRIPTOR.fields_by_name['optional_int32' ], 21),
(proto.DESCRIPTOR.fields_by_name['repeated_int32' ], [5, 11]),
@@ -731,7 +880,6 @@ class ReflectionTest(unittest.TestCase):
extendee_proto.ClearExtension(extension)
extension_proto.foreign_message_int = 23
- self.assertTrue(not toplevel.HasField('submessage'))
self.assertTrue(extension_proto is not extendee_proto.Extensions[extension])
def testExtensionFailureModes(self):
@@ -957,57 +1105,75 @@ class ReflectionTest(unittest.TestCase):
empty_proto = unittest_pb2.TestAllExtensions()
self.assertEquals(proto, empty_proto)
+ def assertInitialized(self, proto):
+ self.assertTrue(proto.IsInitialized())
+ # Neither method should raise an exception.
+ proto.SerializeToString()
+ proto.SerializePartialToString()
+
+ def assertNotInitialized(self, proto):
+ self.assertFalse(proto.IsInitialized())
+ self.assertRaises(message.EncodeError, proto.SerializeToString)
+ # "Partial" serialization doesn't care if message is uninitialized.
+ proto.SerializePartialToString()
+
def testIsInitialized(self):
# Trivial cases - all optional fields and extensions.
proto = unittest_pb2.TestAllTypes()
- self.assertTrue(proto.IsInitialized())
+ self.assertInitialized(proto)
proto = unittest_pb2.TestAllExtensions()
- self.assertTrue(proto.IsInitialized())
+ self.assertInitialized(proto)
# The case of uninitialized required fields.
proto = unittest_pb2.TestRequired()
- self.assertFalse(proto.IsInitialized())
+ self.assertNotInitialized(proto)
proto.a = proto.b = proto.c = 2
- self.assertTrue(proto.IsInitialized())
+ self.assertInitialized(proto)
# The case of uninitialized submessage.
proto = unittest_pb2.TestRequiredForeign()
- self.assertTrue(proto.IsInitialized())
+ self.assertInitialized(proto)
proto.optional_message.a = 1
- self.assertFalse(proto.IsInitialized())
+ self.assertNotInitialized(proto)
proto.optional_message.b = 0
proto.optional_message.c = 0
- self.assertTrue(proto.IsInitialized())
+ self.assertInitialized(proto)
# Uninitialized repeated submessage.
message1 = proto.repeated_message.add()
- self.assertFalse(proto.IsInitialized())
+ self.assertNotInitialized(proto)
message1.a = message1.b = message1.c = 0
- self.assertTrue(proto.IsInitialized())
+ self.assertInitialized(proto)
# Uninitialized repeated group in an extension.
proto = unittest_pb2.TestAllExtensions()
extension = unittest_pb2.TestRequired.multi
message1 = proto.Extensions[extension].add()
message2 = proto.Extensions[extension].add()
- self.assertFalse(proto.IsInitialized())
+ self.assertNotInitialized(proto)
message1.a = 1
message1.b = 1
message1.c = 1
- self.assertFalse(proto.IsInitialized())
+ self.assertNotInitialized(proto)
message2.a = 2
message2.b = 2
message2.c = 2
- self.assertTrue(proto.IsInitialized())
+ self.assertInitialized(proto)
# Uninitialized nonrepeated message in an extension.
proto = unittest_pb2.TestAllExtensions()
extension = unittest_pb2.TestRequired.single
proto.Extensions[extension].a = 1
- self.assertFalse(proto.IsInitialized())
+ self.assertNotInitialized(proto)
proto.Extensions[extension].b = 2
proto.Extensions[extension].c = 3
- self.assertTrue(proto.IsInitialized())
+ self.assertInitialized(proto)
+
+ # Try passing an errors list.
+ errors = []
+ proto = unittest_pb2.TestRequired()
+ self.assertFalse(proto.IsInitialized(errors))
+ self.assertEqual(errors, ['a', 'b', 'c'])
def testStringUTF8Encoding(self):
proto = unittest_pb2.TestAllTypes()
@@ -1079,6 +1245,36 @@ class ReflectionTest(unittest.TestCase):
test_utf8_bytes, len(test_utf8_bytes) * '\xff')
self.assertRaises(UnicodeDecodeError, message2.MergeFromString, bytes)
+ def testEmptyNestedMessage(self):
+ proto = unittest_pb2.TestAllTypes()
+ proto.optional_nested_message.MergeFrom(
+ unittest_pb2.TestAllTypes.NestedMessage())
+ self.assertTrue(proto.HasField('optional_nested_message'))
+
+ proto = unittest_pb2.TestAllTypes()
+ proto.optional_nested_message.CopyFrom(
+ unittest_pb2.TestAllTypes.NestedMessage())
+ self.assertTrue(proto.HasField('optional_nested_message'))
+
+ proto = unittest_pb2.TestAllTypes()
+ proto.optional_nested_message.MergeFromString('')
+ self.assertTrue(proto.HasField('optional_nested_message'))
+
+ proto = unittest_pb2.TestAllTypes()
+ proto.optional_nested_message.ParseFromString('')
+ self.assertTrue(proto.HasField('optional_nested_message'))
+
+ serialized = proto.SerializeToString()
+ proto2 = unittest_pb2.TestAllTypes()
+ proto2.MergeFromString(serialized)
+ self.assertTrue(proto2.HasField('optional_nested_message'))
+
+ def testSetInParent(self):
+ proto = unittest_pb2.TestAllTypes()
+ self.assertFalse(proto.HasField('optionalgroup'))
+ proto.optionalgroup.SetInParent()
+ self.assertTrue(proto.HasField('optionalgroup'))
+
# Since we had so many tests for protocol buffer equality, we broke these out
# into separate TestCase classes.
@@ -1541,6 +1737,47 @@ class SerializationTest(unittest.TestCase):
second_proto.MergeFromString(serialized)
self.assertEqual(first_proto, second_proto)
+ def testSerializeNegativeValues(self):
+ first_proto = unittest_pb2.TestAllTypes()
+
+ first_proto.optional_int32 = -1
+ first_proto.optional_int64 = -(2 << 40)
+ first_proto.optional_sint32 = -3
+ first_proto.optional_sint64 = -(4 << 40)
+ first_proto.optional_sfixed32 = -5
+ first_proto.optional_sfixed64 = -(6 << 40)
+
+ second_proto = unittest_pb2.TestAllTypes.FromString(
+ first_proto.SerializeToString())
+
+ self.assertEqual(first_proto, second_proto)
+
+ def testParseTruncated(self):
+ first_proto = unittest_pb2.TestAllTypes()
+ test_util.SetAllFields(first_proto)
+ serialized = first_proto.SerializeToString()
+
+ for truncation_point in xrange(len(serialized) + 1):
+ try:
+ second_proto = unittest_pb2.TestAllTypes()
+ unknown_fields = unittest_pb2.TestEmptyMessage()
+ pos = second_proto._InternalParse(serialized, 0, truncation_point)
+ # If we didn't raise an error then we read exactly the amount expected.
+ self.assertEqual(truncation_point, pos)
+
+ # Parsing to unknown fields should not throw if parsing to known fields
+ # did not.
+ try:
+ pos2 = unknown_fields._InternalParse(serialized, 0, truncation_point)
+ self.assertEqual(truncation_point, pos2)
+ except message.DecodeError:
+ self.fail('Parsing unknown fields failed when parsing known fields '
+ 'did not.')
+ except message.DecodeError:
+ # Parsing unknown fields should also fail.
+ self.assertRaises(message.DecodeError, unknown_fields._InternalParse,
+ serialized, 0, truncation_point)
+
def testCanonicalSerializationOrder(self):
proto = more_messages_pb2.OutOfOrderFields()
# These are also their tag numbers. Even though we're setting these in
@@ -1553,7 +1790,7 @@ class SerializationTest(unittest.TestCase):
proto.optional_int32 = 1
serialized = proto.SerializeToString()
self.assertEqual(proto.ByteSize(), len(serialized))
- d = decoder.Decoder(serialized)
+ d = _MiniDecoder(serialized)
ReadTag = d.ReadFieldNumberAndWireType
self.assertEqual((1, wire_format.WIRETYPE_VARINT), ReadTag())
self.assertEqual(1, d.ReadInt32())
@@ -1709,7 +1946,7 @@ class SerializationTest(unittest.TestCase):
self._CheckRaises(
message.EncodeError,
proto.SerializeToString,
- 'Required field protobuf_unittest.TestRequired.a is not set.')
+ 'Message is missing required fields: a,b,c')
# Shouldn't raise exceptions.
partial = proto.SerializePartialToString()
@@ -1717,7 +1954,7 @@ class SerializationTest(unittest.TestCase):
self._CheckRaises(
message.EncodeError,
proto.SerializeToString,
- 'Required field protobuf_unittest.TestRequired.b is not set.')
+ 'Message is missing required fields: b,c')
# Shouldn't raise exceptions.
partial = proto.SerializePartialToString()
@@ -1725,7 +1962,7 @@ class SerializationTest(unittest.TestCase):
self._CheckRaises(
message.EncodeError,
proto.SerializeToString,
- 'Required field protobuf_unittest.TestRequired.c is not set.')
+ 'Message is missing required fields: c')
# Shouldn't raise exceptions.
partial = proto.SerializePartialToString()
@@ -1744,6 +1981,38 @@ class SerializationTest(unittest.TestCase):
self.assertEqual(2, proto2.b)
self.assertEqual(3, proto2.c)
+ def testSerializeUninitializedSubMessage(self):
+ proto = unittest_pb2.TestRequiredForeign()
+
+ # Sub-message doesn't exist yet, so this succeeds.
+ proto.SerializeToString()
+
+ proto.optional_message.a = 1
+ self._CheckRaises(
+ message.EncodeError,
+ proto.SerializeToString,
+ 'Message is missing required fields: '
+ 'optional_message.b,optional_message.c')
+
+ proto.optional_message.b = 2
+ proto.optional_message.c = 3
+ proto.SerializeToString()
+
+ proto.repeated_message.add().a = 1
+ proto.repeated_message.add().b = 2
+ self._CheckRaises(
+ message.EncodeError,
+ proto.SerializeToString,
+ 'Message is missing required fields: '
+ 'repeated_message[0].b,repeated_message[0].c,'
+ 'repeated_message[1].a,repeated_message[1].c')
+
+ proto.repeated_message[0].b = 2
+ proto.repeated_message[0].c = 3
+ proto.repeated_message[1].a = 1
+ proto.repeated_message[1].c = 3
+ proto.SerializeToString()
+
def testSerializeAllPackedFields(self):
first_proto = unittest_pb2.TestPackedTypes()
second_proto = unittest_pb2.TestPackedTypes()
@@ -1786,7 +2055,7 @@ class SerializationTest(unittest.TestCase):
proto.packed_float.append(2.0) # 4 bytes, will be before double
serialized = proto.SerializeToString()
self.assertEqual(proto.ByteSize(), len(serialized))
- d = decoder.Decoder(serialized)
+ d = _MiniDecoder(serialized)
ReadTag = d.ReadFieldNumberAndWireType
self.assertEqual((90, wire_format.WIRETYPE_LENGTH_DELIMITED), ReadTag())
self.assertEqual(1+1+1+2, d.ReadInt32())
@@ -1803,6 +2072,24 @@ class SerializationTest(unittest.TestCase):
self.assertEqual(1000.0, d.ReadDouble())
self.assertTrue(d.EndOfStream())
+ def testParsePackedFromUnpacked(self):
+ unpacked = unittest_pb2.TestUnpackedTypes()
+ test_util.SetAllUnpackedFields(unpacked)
+ packed = unittest_pb2.TestPackedTypes()
+ packed.MergeFromString(unpacked.SerializeToString())
+ expected = unittest_pb2.TestPackedTypes()
+ test_util.SetAllPackedFields(expected)
+ self.assertEqual(expected, packed)
+
+ def testParseUnpackedFromPacked(self):
+ packed = unittest_pb2.TestPackedTypes()
+ test_util.SetAllPackedFields(packed)
+ unpacked = unittest_pb2.TestUnpackedTypes()
+ unpacked.MergeFromString(packed.SerializeToString())
+ expected = unittest_pb2.TestUnpackedTypes()
+ test_util.SetAllUnpackedFields(expected)
+ self.assertEqual(expected, unpacked)
+
def testFieldNumbers(self):
proto = unittest_pb2.TestAllTypes()
self.assertEqual(unittest_pb2.TestAllTypes.NestedMessage.BB_FIELD_NUMBER, 1)
@@ -1944,33 +2231,6 @@ class OptionsTest(unittest.TestCase):
field_descriptor.label)
-class UtilityTest(unittest.TestCase):
-
- def testImergeSorted(self):
- ImergeSorted = reflection._ImergeSorted
- # Various types of emptiness.
- self.assertEqual([], list(ImergeSorted()))
- self.assertEqual([], list(ImergeSorted([])))
- self.assertEqual([], list(ImergeSorted([], [])))
-
- # One nonempty list.
- self.assertEqual([1, 2, 3], list(ImergeSorted([1, 2, 3])))
- self.assertEqual([1, 2, 3], list(ImergeSorted([1, 2, 3], [])))
- self.assertEqual([1, 2, 3], list(ImergeSorted([], [1, 2, 3])))
-
- # Merging some nonempty lists together.
- self.assertEqual([1, 2, 3], list(ImergeSorted([1, 3], [2])))
- self.assertEqual([1, 2, 3], list(ImergeSorted([1], [3], [2])))
- self.assertEqual([1, 2, 3], list(ImergeSorted([1], [3], [2], [])))
-
- # Elements repeated across component iterators.
- self.assertEqual([1, 2, 2, 3, 3],
- list(ImergeSorted([1, 2], [3], [2, 3])))
-
- # Elements repeated within an iterator.
- self.assertEqual([1, 2, 2, 3, 3],
- list(ImergeSorted([1, 2, 2], [3], [3])))
-
if __name__ == '__main__':
unittest.main()