diff options
author | kenton@google.com <kenton@google.com@630680e5-0e50-0410-840e-4b1c322b438d> | 2009-12-18 02:11:36 +0000 |
---|---|---|
committer | kenton@google.com <kenton@google.com@630680e5-0e50-0410-840e-4b1c322b438d> | 2009-12-18 02:11:36 +0000 |
commit | fccb146e3fe437b0df1e9c50d4b8e1080ddb4bd9 (patch) | |
tree | 9f2d9fe0267d96a54e541377ffeada3d0bff0d1d /python/google/protobuf/internal/reflection_test.py | |
parent | d5cf7b55a6a1f959d1646785f63ca2b62da78079 (diff) |
Massive roll-up of changes. See CHANGES.txt.
Diffstat (limited to 'python/google/protobuf/internal/reflection_test.py')
-rwxr-xr-x | python/google/protobuf/internal/reflection_test.py | 354 |
1 files changed, 307 insertions, 47 deletions
diff --git a/python/google/protobuf/internal/reflection_test.py b/python/google/protobuf/internal/reflection_test.py index 86101774..2c9fa30b 100755 --- a/python/google/protobuf/internal/reflection_test.py +++ b/python/google/protobuf/internal/reflection_test.py @@ -38,6 +38,7 @@ pure-Python protocol compiler. __author__ = 'robinson@google.com (Will Robinson)' import operator +import struct import unittest # TODO(robinson): When we split this test in two, only some of these imports @@ -56,6 +57,51 @@ from google.protobuf.internal import test_util from google.protobuf.internal import decoder +class _MiniDecoder(object): + """Decodes a stream of values from a string. + + Once upon a time we actually had a class called decoder.Decoder. Then we + got rid of it during a redesign that made decoding much, much faster overall. + But a couple tests in this file used it to check that the serialized form of + a message was correct. So, this class implements just the methods that were + used by said tests, so that we don't have to rewrite the tests. + """ + + def __init__(self, bytes): + self._bytes = bytes + self._pos = 0 + + def ReadVarint(self): + result, self._pos = decoder._DecodeVarint(self._bytes, self._pos) + return result + + ReadInt32 = ReadVarint + ReadInt64 = ReadVarint + ReadUInt32 = ReadVarint + ReadUInt64 = ReadVarint + + def ReadSInt64(self): + return wire_format.ZigZagDecode(self.ReadVarint()) + + ReadSInt32 = ReadSInt64 + + def ReadFieldNumberAndWireType(self): + return wire_format.UnpackTag(self.ReadVarint()) + + def ReadFloat(self): + result = struct.unpack("<f", self._bytes[self._pos:self._pos+4])[0] + self._pos += 4 + return result + + def ReadDouble(self): + result = struct.unpack("<d", self._bytes[self._pos:self._pos+8])[0] + self._pos += 8 + return result + + def EndOfStream(self): + return self._pos == len(self._bytes) + + class ReflectionTest(unittest.TestCase): def assertIs(self, values, others): @@ -63,6 +109,97 @@ class ReflectionTest(unittest.TestCase): for i in range(len(values)): self.assertTrue(values[i] is others[i]) + def testScalarConstructor(self): + # Constructor with only scalar types should succeed. + proto = unittest_pb2.TestAllTypes( + optional_int32=24, + optional_double=54.321, + optional_string='optional_string') + + self.assertEqual(24, proto.optional_int32) + self.assertEqual(54.321, proto.optional_double) + self.assertEqual('optional_string', proto.optional_string) + + def testRepeatedScalarConstructor(self): + # Constructor with only repeated scalar types should succeed. + proto = unittest_pb2.TestAllTypes( + repeated_int32=[1, 2, 3, 4], + repeated_double=[1.23, 54.321], + repeated_bool=[True, False, False], + repeated_string=["optional_string"]) + + self.assertEquals([1, 2, 3, 4], list(proto.repeated_int32)) + self.assertEquals([1.23, 54.321], list(proto.repeated_double)) + self.assertEquals([True, False, False], list(proto.repeated_bool)) + self.assertEquals(["optional_string"], list(proto.repeated_string)) + + def testRepeatedCompositeConstructor(self): + # Constructor with only repeated composite types should succeed. + proto = unittest_pb2.TestAllTypes( + repeated_nested_message=[ + unittest_pb2.TestAllTypes.NestedMessage( + bb=unittest_pb2.TestAllTypes.FOO), + unittest_pb2.TestAllTypes.NestedMessage( + bb=unittest_pb2.TestAllTypes.BAR)], + repeated_foreign_message=[ + unittest_pb2.ForeignMessage(c=-43), + unittest_pb2.ForeignMessage(c=45324), + unittest_pb2.ForeignMessage(c=12)], + repeatedgroup=[ + unittest_pb2.TestAllTypes.RepeatedGroup(), + unittest_pb2.TestAllTypes.RepeatedGroup(a=1), + unittest_pb2.TestAllTypes.RepeatedGroup(a=2)]) + + self.assertEquals( + [unittest_pb2.TestAllTypes.NestedMessage( + bb=unittest_pb2.TestAllTypes.FOO), + unittest_pb2.TestAllTypes.NestedMessage( + bb=unittest_pb2.TestAllTypes.BAR)], + list(proto.repeated_nested_message)) + self.assertEquals( + [unittest_pb2.ForeignMessage(c=-43), + unittest_pb2.ForeignMessage(c=45324), + unittest_pb2.ForeignMessage(c=12)], + list(proto.repeated_foreign_message)) + self.assertEquals( + [unittest_pb2.TestAllTypes.RepeatedGroup(), + unittest_pb2.TestAllTypes.RepeatedGroup(a=1), + unittest_pb2.TestAllTypes.RepeatedGroup(a=2)], + list(proto.repeatedgroup)) + + def testMixedConstructor(self): + # Constructor with only mixed types should succeed. + proto = unittest_pb2.TestAllTypes( + optional_int32=24, + optional_string='optional_string', + repeated_double=[1.23, 54.321], + repeated_bool=[True, False, False], + repeated_nested_message=[ + unittest_pb2.TestAllTypes.NestedMessage( + bb=unittest_pb2.TestAllTypes.FOO), + unittest_pb2.TestAllTypes.NestedMessage( + bb=unittest_pb2.TestAllTypes.BAR)], + repeated_foreign_message=[ + unittest_pb2.ForeignMessage(c=-43), + unittest_pb2.ForeignMessage(c=45324), + unittest_pb2.ForeignMessage(c=12)]) + + self.assertEqual(24, proto.optional_int32) + self.assertEqual('optional_string', proto.optional_string) + self.assertEquals([1.23, 54.321], list(proto.repeated_double)) + self.assertEquals([True, False, False], list(proto.repeated_bool)) + self.assertEquals( + [unittest_pb2.TestAllTypes.NestedMessage( + bb=unittest_pb2.TestAllTypes.FOO), + unittest_pb2.TestAllTypes.NestedMessage( + bb=unittest_pb2.TestAllTypes.BAR)], + list(proto.repeated_nested_message)) + self.assertEquals( + [unittest_pb2.ForeignMessage(c=-43), + unittest_pb2.ForeignMessage(c=45324), + unittest_pb2.ForeignMessage(c=12)], + list(proto.repeated_foreign_message)) + def testSimpleHasBits(self): # Test a scalar. proto = unittest_pb2.TestAllTypes() @@ -218,12 +355,23 @@ class ReflectionTest(unittest.TestCase): proto.optional_fixed32 = 1 proto.optional_int32 = 5 proto.optional_string = 'foo' + # Access sub-message but don't set it yet. + nested_message = proto.optional_nested_message self.assertEqual( [ (proto.DESCRIPTOR.fields_by_name['optional_int32' ], 5), (proto.DESCRIPTOR.fields_by_name['optional_fixed32'], 1), (proto.DESCRIPTOR.fields_by_name['optional_string' ], 'foo') ], proto.ListFields()) + proto.optional_nested_message.bb = 123 + self.assertEqual( + [ (proto.DESCRIPTOR.fields_by_name['optional_int32' ], 5), + (proto.DESCRIPTOR.fields_by_name['optional_fixed32'], 1), + (proto.DESCRIPTOR.fields_by_name['optional_string' ], 'foo'), + (proto.DESCRIPTOR.fields_by_name['optional_nested_message' ], + nested_message) ], + proto.ListFields()) + def testRepeatedListFields(self): proto = unittest_pb2.TestAllTypes() proto.repeated_fixed32.append(1) @@ -234,6 +382,7 @@ class ReflectionTest(unittest.TestCase): proto.repeated_string.append('baz') proto.repeated_string.extend(str(x) for x in xrange(2)) proto.optional_int32 = 21 + proto.repeated_bool # Access but don't set anything; should not be listed. self.assertEqual( [ (proto.DESCRIPTOR.fields_by_name['optional_int32' ], 21), (proto.DESCRIPTOR.fields_by_name['repeated_int32' ], [5, 11]), @@ -731,7 +880,6 @@ class ReflectionTest(unittest.TestCase): extendee_proto.ClearExtension(extension) extension_proto.foreign_message_int = 23 - self.assertTrue(not toplevel.HasField('submessage')) self.assertTrue(extension_proto is not extendee_proto.Extensions[extension]) def testExtensionFailureModes(self): @@ -957,57 +1105,75 @@ class ReflectionTest(unittest.TestCase): empty_proto = unittest_pb2.TestAllExtensions() self.assertEquals(proto, empty_proto) + def assertInitialized(self, proto): + self.assertTrue(proto.IsInitialized()) + # Neither method should raise an exception. + proto.SerializeToString() + proto.SerializePartialToString() + + def assertNotInitialized(self, proto): + self.assertFalse(proto.IsInitialized()) + self.assertRaises(message.EncodeError, proto.SerializeToString) + # "Partial" serialization doesn't care if message is uninitialized. + proto.SerializePartialToString() + def testIsInitialized(self): # Trivial cases - all optional fields and extensions. proto = unittest_pb2.TestAllTypes() - self.assertTrue(proto.IsInitialized()) + self.assertInitialized(proto) proto = unittest_pb2.TestAllExtensions() - self.assertTrue(proto.IsInitialized()) + self.assertInitialized(proto) # The case of uninitialized required fields. proto = unittest_pb2.TestRequired() - self.assertFalse(proto.IsInitialized()) + self.assertNotInitialized(proto) proto.a = proto.b = proto.c = 2 - self.assertTrue(proto.IsInitialized()) + self.assertInitialized(proto) # The case of uninitialized submessage. proto = unittest_pb2.TestRequiredForeign() - self.assertTrue(proto.IsInitialized()) + self.assertInitialized(proto) proto.optional_message.a = 1 - self.assertFalse(proto.IsInitialized()) + self.assertNotInitialized(proto) proto.optional_message.b = 0 proto.optional_message.c = 0 - self.assertTrue(proto.IsInitialized()) + self.assertInitialized(proto) # Uninitialized repeated submessage. message1 = proto.repeated_message.add() - self.assertFalse(proto.IsInitialized()) + self.assertNotInitialized(proto) message1.a = message1.b = message1.c = 0 - self.assertTrue(proto.IsInitialized()) + self.assertInitialized(proto) # Uninitialized repeated group in an extension. proto = unittest_pb2.TestAllExtensions() extension = unittest_pb2.TestRequired.multi message1 = proto.Extensions[extension].add() message2 = proto.Extensions[extension].add() - self.assertFalse(proto.IsInitialized()) + self.assertNotInitialized(proto) message1.a = 1 message1.b = 1 message1.c = 1 - self.assertFalse(proto.IsInitialized()) + self.assertNotInitialized(proto) message2.a = 2 message2.b = 2 message2.c = 2 - self.assertTrue(proto.IsInitialized()) + self.assertInitialized(proto) # Uninitialized nonrepeated message in an extension. proto = unittest_pb2.TestAllExtensions() extension = unittest_pb2.TestRequired.single proto.Extensions[extension].a = 1 - self.assertFalse(proto.IsInitialized()) + self.assertNotInitialized(proto) proto.Extensions[extension].b = 2 proto.Extensions[extension].c = 3 - self.assertTrue(proto.IsInitialized()) + self.assertInitialized(proto) + + # Try passing an errors list. + errors = [] + proto = unittest_pb2.TestRequired() + self.assertFalse(proto.IsInitialized(errors)) + self.assertEqual(errors, ['a', 'b', 'c']) def testStringUTF8Encoding(self): proto = unittest_pb2.TestAllTypes() @@ -1079,6 +1245,36 @@ class ReflectionTest(unittest.TestCase): test_utf8_bytes, len(test_utf8_bytes) * '\xff') self.assertRaises(UnicodeDecodeError, message2.MergeFromString, bytes) + def testEmptyNestedMessage(self): + proto = unittest_pb2.TestAllTypes() + proto.optional_nested_message.MergeFrom( + unittest_pb2.TestAllTypes.NestedMessage()) + self.assertTrue(proto.HasField('optional_nested_message')) + + proto = unittest_pb2.TestAllTypes() + proto.optional_nested_message.CopyFrom( + unittest_pb2.TestAllTypes.NestedMessage()) + self.assertTrue(proto.HasField('optional_nested_message')) + + proto = unittest_pb2.TestAllTypes() + proto.optional_nested_message.MergeFromString('') + self.assertTrue(proto.HasField('optional_nested_message')) + + proto = unittest_pb2.TestAllTypes() + proto.optional_nested_message.ParseFromString('') + self.assertTrue(proto.HasField('optional_nested_message')) + + serialized = proto.SerializeToString() + proto2 = unittest_pb2.TestAllTypes() + proto2.MergeFromString(serialized) + self.assertTrue(proto2.HasField('optional_nested_message')) + + def testSetInParent(self): + proto = unittest_pb2.TestAllTypes() + self.assertFalse(proto.HasField('optionalgroup')) + proto.optionalgroup.SetInParent() + self.assertTrue(proto.HasField('optionalgroup')) + # Since we had so many tests for protocol buffer equality, we broke these out # into separate TestCase classes. @@ -1541,6 +1737,47 @@ class SerializationTest(unittest.TestCase): second_proto.MergeFromString(serialized) self.assertEqual(first_proto, second_proto) + def testSerializeNegativeValues(self): + first_proto = unittest_pb2.TestAllTypes() + + first_proto.optional_int32 = -1 + first_proto.optional_int64 = -(2 << 40) + first_proto.optional_sint32 = -3 + first_proto.optional_sint64 = -(4 << 40) + first_proto.optional_sfixed32 = -5 + first_proto.optional_sfixed64 = -(6 << 40) + + second_proto = unittest_pb2.TestAllTypes.FromString( + first_proto.SerializeToString()) + + self.assertEqual(first_proto, second_proto) + + def testParseTruncated(self): + first_proto = unittest_pb2.TestAllTypes() + test_util.SetAllFields(first_proto) + serialized = first_proto.SerializeToString() + + for truncation_point in xrange(len(serialized) + 1): + try: + second_proto = unittest_pb2.TestAllTypes() + unknown_fields = unittest_pb2.TestEmptyMessage() + pos = second_proto._InternalParse(serialized, 0, truncation_point) + # If we didn't raise an error then we read exactly the amount expected. + self.assertEqual(truncation_point, pos) + + # Parsing to unknown fields should not throw if parsing to known fields + # did not. + try: + pos2 = unknown_fields._InternalParse(serialized, 0, truncation_point) + self.assertEqual(truncation_point, pos2) + except message.DecodeError: + self.fail('Parsing unknown fields failed when parsing known fields ' + 'did not.') + except message.DecodeError: + # Parsing unknown fields should also fail. + self.assertRaises(message.DecodeError, unknown_fields._InternalParse, + serialized, 0, truncation_point) + def testCanonicalSerializationOrder(self): proto = more_messages_pb2.OutOfOrderFields() # These are also their tag numbers. Even though we're setting these in @@ -1553,7 +1790,7 @@ class SerializationTest(unittest.TestCase): proto.optional_int32 = 1 serialized = proto.SerializeToString() self.assertEqual(proto.ByteSize(), len(serialized)) - d = decoder.Decoder(serialized) + d = _MiniDecoder(serialized) ReadTag = d.ReadFieldNumberAndWireType self.assertEqual((1, wire_format.WIRETYPE_VARINT), ReadTag()) self.assertEqual(1, d.ReadInt32()) @@ -1709,7 +1946,7 @@ class SerializationTest(unittest.TestCase): self._CheckRaises( message.EncodeError, proto.SerializeToString, - 'Required field protobuf_unittest.TestRequired.a is not set.') + 'Message is missing required fields: a,b,c') # Shouldn't raise exceptions. partial = proto.SerializePartialToString() @@ -1717,7 +1954,7 @@ class SerializationTest(unittest.TestCase): self._CheckRaises( message.EncodeError, proto.SerializeToString, - 'Required field protobuf_unittest.TestRequired.b is not set.') + 'Message is missing required fields: b,c') # Shouldn't raise exceptions. partial = proto.SerializePartialToString() @@ -1725,7 +1962,7 @@ class SerializationTest(unittest.TestCase): self._CheckRaises( message.EncodeError, proto.SerializeToString, - 'Required field protobuf_unittest.TestRequired.c is not set.') + 'Message is missing required fields: c') # Shouldn't raise exceptions. partial = proto.SerializePartialToString() @@ -1744,6 +1981,38 @@ class SerializationTest(unittest.TestCase): self.assertEqual(2, proto2.b) self.assertEqual(3, proto2.c) + def testSerializeUninitializedSubMessage(self): + proto = unittest_pb2.TestRequiredForeign() + + # Sub-message doesn't exist yet, so this succeeds. + proto.SerializeToString() + + proto.optional_message.a = 1 + self._CheckRaises( + message.EncodeError, + proto.SerializeToString, + 'Message is missing required fields: ' + 'optional_message.b,optional_message.c') + + proto.optional_message.b = 2 + proto.optional_message.c = 3 + proto.SerializeToString() + + proto.repeated_message.add().a = 1 + proto.repeated_message.add().b = 2 + self._CheckRaises( + message.EncodeError, + proto.SerializeToString, + 'Message is missing required fields: ' + 'repeated_message[0].b,repeated_message[0].c,' + 'repeated_message[1].a,repeated_message[1].c') + + proto.repeated_message[0].b = 2 + proto.repeated_message[0].c = 3 + proto.repeated_message[1].a = 1 + proto.repeated_message[1].c = 3 + proto.SerializeToString() + def testSerializeAllPackedFields(self): first_proto = unittest_pb2.TestPackedTypes() second_proto = unittest_pb2.TestPackedTypes() @@ -1786,7 +2055,7 @@ class SerializationTest(unittest.TestCase): proto.packed_float.append(2.0) # 4 bytes, will be before double serialized = proto.SerializeToString() self.assertEqual(proto.ByteSize(), len(serialized)) - d = decoder.Decoder(serialized) + d = _MiniDecoder(serialized) ReadTag = d.ReadFieldNumberAndWireType self.assertEqual((90, wire_format.WIRETYPE_LENGTH_DELIMITED), ReadTag()) self.assertEqual(1+1+1+2, d.ReadInt32()) @@ -1803,6 +2072,24 @@ class SerializationTest(unittest.TestCase): self.assertEqual(1000.0, d.ReadDouble()) self.assertTrue(d.EndOfStream()) + def testParsePackedFromUnpacked(self): + unpacked = unittest_pb2.TestUnpackedTypes() + test_util.SetAllUnpackedFields(unpacked) + packed = unittest_pb2.TestPackedTypes() + packed.MergeFromString(unpacked.SerializeToString()) + expected = unittest_pb2.TestPackedTypes() + test_util.SetAllPackedFields(expected) + self.assertEqual(expected, packed) + + def testParseUnpackedFromPacked(self): + packed = unittest_pb2.TestPackedTypes() + test_util.SetAllPackedFields(packed) + unpacked = unittest_pb2.TestUnpackedTypes() + unpacked.MergeFromString(packed.SerializeToString()) + expected = unittest_pb2.TestUnpackedTypes() + test_util.SetAllUnpackedFields(expected) + self.assertEqual(expected, unpacked) + def testFieldNumbers(self): proto = unittest_pb2.TestAllTypes() self.assertEqual(unittest_pb2.TestAllTypes.NestedMessage.BB_FIELD_NUMBER, 1) @@ -1944,33 +2231,6 @@ class OptionsTest(unittest.TestCase): field_descriptor.label) -class UtilityTest(unittest.TestCase): - - def testImergeSorted(self): - ImergeSorted = reflection._ImergeSorted - # Various types of emptiness. - self.assertEqual([], list(ImergeSorted())) - self.assertEqual([], list(ImergeSorted([]))) - self.assertEqual([], list(ImergeSorted([], []))) - - # One nonempty list. - self.assertEqual([1, 2, 3], list(ImergeSorted([1, 2, 3]))) - self.assertEqual([1, 2, 3], list(ImergeSorted([1, 2, 3], []))) - self.assertEqual([1, 2, 3], list(ImergeSorted([], [1, 2, 3]))) - - # Merging some nonempty lists together. - self.assertEqual([1, 2, 3], list(ImergeSorted([1, 3], [2]))) - self.assertEqual([1, 2, 3], list(ImergeSorted([1], [3], [2]))) - self.assertEqual([1, 2, 3], list(ImergeSorted([1], [3], [2], []))) - - # Elements repeated across component iterators. - self.assertEqual([1, 2, 2, 3, 3], - list(ImergeSorted([1, 2], [3], [2, 3]))) - - # Elements repeated within an iterator. - self.assertEqual([1, 2, 2, 3, 3], - list(ImergeSorted([1, 2, 2], [3], [3]))) - if __name__ == '__main__': unittest.main() |