diff options
author | jieluo@google.com <jieluo@google.com@630680e5-0e50-0410-840e-4b1c322b438d> | 2014-08-12 21:10:30 +0000 |
---|---|---|
committer | jieluo@google.com <jieluo@google.com@630680e5-0e50-0410-840e-4b1c322b438d> | 2014-08-12 21:10:30 +0000 |
commit | bde4a3254a7de58911941b0fbf38e9dd992de973 (patch) | |
tree | 02b151c2ec6e9be2e9d5ea0efc406aabe6958ae7 /python/google/protobuf/internal/reflection_test.py | |
parent | d7339318a33c5f9e8b5dded4077223fbd4ebf229 (diff) |
down integrate python opensource to svn
Diffstat (limited to 'python/google/protobuf/internal/reflection_test.py')
-rwxr-xr-x | python/google/protobuf/internal/reflection_test.py | 419 |
1 files changed, 341 insertions, 78 deletions
diff --git a/python/google/protobuf/internal/reflection_test.py b/python/google/protobuf/internal/reflection_test.py index ed286461..b3c414c7 100755 --- a/python/google/protobuf/internal/reflection_test.py +++ b/python/google/protobuf/internal/reflection_test.py @@ -37,11 +37,12 @@ pure-Python protocol compiler. __author__ = 'robinson@google.com (Will Robinson)' +import copy import gc import operator import struct -import unittest +from google.apputils import basetest from google.protobuf import unittest_import_pb2 from google.protobuf import unittest_mset_pb2 from google.protobuf import unittest_pb2 @@ -49,6 +50,7 @@ from google.protobuf import descriptor_pb2 from google.protobuf import descriptor from google.protobuf import message from google.protobuf import reflection +from google.protobuf import text_format from google.protobuf.internal import api_implementation from google.protobuf.internal import more_extensions_pb2 from google.protobuf.internal import more_messages_pb2 @@ -102,7 +104,7 @@ class _MiniDecoder(object): return self._pos == len(self._bytes) -class ReflectionTest(unittest.TestCase): +class ReflectionTest(basetest.TestCase): def assertListsEqual(self, values, others): self.assertEqual(len(values), len(others)) @@ -533,7 +535,7 @@ class ReflectionTest(unittest.TestCase): self.assertEqual(0.0, proto.optional_double) self.assertEqual(False, proto.optional_bool) self.assertEqual('', proto.optional_string) - self.assertEqual('', proto.optional_bytes) + self.assertEqual(b'', proto.optional_bytes) self.assertEqual(41, proto.default_int32) self.assertEqual(42, proto.default_int64) @@ -549,7 +551,7 @@ class ReflectionTest(unittest.TestCase): self.assertEqual(52e3, proto.default_double) self.assertEqual(True, proto.default_bool) self.assertEqual('hello', proto.default_string) - self.assertEqual('world', proto.default_bytes) + self.assertEqual(b'world', proto.default_bytes) self.assertEqual(unittest_pb2.TestAllTypes.BAR, proto.default_nested_enum) self.assertEqual(unittest_pb2.FOREIGN_BAR, proto.default_foreign_enum) self.assertEqual(unittest_import_pb2.IMPORT_BAR, @@ -566,6 +568,17 @@ class ReflectionTest(unittest.TestCase): proto = unittest_pb2.TestAllTypes() self.assertRaises(ValueError, proto.ClearField, 'nonexistent_field') + def testClearRemovesChildren(self): + # Make sure there aren't any implementation bugs that are only partially + # clearing the message (which can happen in the more complex C++ + # implementation which has parallel message lists). + proto = unittest_pb2.TestRequiredForeign() + for i in range(10): + proto.repeated_message.add() + proto2 = unittest_pb2.TestRequiredForeign() + proto.CopyFrom(proto2) + self.assertRaises(IndexError, lambda: proto.repeated_message[5]) + def testDisallowedAssignments(self): # It's illegal to assign values directly to repeated fields # or to nonrepeated composite fields. Ensure that this fails. @@ -594,6 +607,30 @@ class ReflectionTest(unittest.TestCase): self.assertRaises(TypeError, setattr, proto, 'optional_string', 10) self.assertRaises(TypeError, setattr, proto, 'optional_bytes', 10) + def testIntegerTypes(self): + def TestGetAndDeserialize(field_name, value, expected_type): + proto = unittest_pb2.TestAllTypes() + setattr(proto, field_name, value) + self.assertTrue(isinstance(getattr(proto, field_name), expected_type)) + proto2 = unittest_pb2.TestAllTypes() + proto2.ParseFromString(proto.SerializeToString()) + self.assertTrue(isinstance(getattr(proto2, field_name), expected_type)) + + TestGetAndDeserialize('optional_int32', 1, int) + TestGetAndDeserialize('optional_int32', 1 << 30, int) + TestGetAndDeserialize('optional_uint32', 1 << 30, int) + if struct.calcsize('L') == 4: + # Python only has signed ints, so 32-bit python can't fit an uint32 + # in an int. + TestGetAndDeserialize('optional_uint32', 1 << 31, long) + else: + # 64-bit python can fit uint32 inside an int + TestGetAndDeserialize('optional_uint32', 1 << 31, int) + TestGetAndDeserialize('optional_int64', 1 << 30, long) + TestGetAndDeserialize('optional_int64', 1 << 60, long) + TestGetAndDeserialize('optional_uint64', 1 << 30, long) + TestGetAndDeserialize('optional_uint64', 1 << 60, long) + def testSingleScalarBoundsChecking(self): def TestMinAndMaxIntegers(field_name, expected_min, expected_max): pb = unittest_pb2.TestAllTypes() @@ -613,29 +650,6 @@ class ReflectionTest(unittest.TestCase): pb.optional_nested_enum = 1 self.assertEqual(1, pb.optional_nested_enum) - # Invalid enum values. - pb.optional_nested_enum = 0 - self.assertEqual(0, pb.optional_nested_enum) - - bytes_size_before = pb.ByteSize() - - pb.optional_nested_enum = 4 - self.assertEqual(4, pb.optional_nested_enum) - - pb.optional_nested_enum = 0 - self.assertEqual(0, pb.optional_nested_enum) - - # Make sure that setting the same enum field doesn't just add unknown - # fields (but overwrites them). - self.assertEqual(bytes_size_before, pb.ByteSize()) - - # Is the invalid value preserved after serialization? - serialized = pb.SerializeToString() - pb2 = unittest_pb2.TestAllTypes() - pb2.ParseFromString(serialized) - self.assertEqual(0, pb2.optional_nested_enum) - self.assertEqual(pb, pb2) - def testRepeatedScalarTypeSafety(self): proto = unittest_pb2.TestAllTypes() self.assertRaises(TypeError, proto.repeated_int32.append, 1.1) @@ -749,9 +763,9 @@ class ReflectionTest(unittest.TestCase): unittest_pb2.ForeignEnum.items()) proto = unittest_pb2.TestAllTypes() - self.assertEqual(['FOO', 'BAR', 'BAZ'], proto.NestedEnum.keys()) - self.assertEqual([1, 2, 3], proto.NestedEnum.values()) - self.assertEqual([('FOO', 1), ('BAR', 2), ('BAZ', 3)], + self.assertEqual(['FOO', 'BAR', 'BAZ', 'NEG'], proto.NestedEnum.keys()) + self.assertEqual([1, 2, 3, -1], proto.NestedEnum.values()) + self.assertEqual([('FOO', 1), ('BAR', 2), ('BAZ', 3), ('NEG', -1)], proto.NestedEnum.items()) def testRepeatedScalars(self): @@ -1155,6 +1169,14 @@ class ReflectionTest(unittest.TestCase): self.assertTrue(required is not extendee_proto.Extensions[extension]) self.assertTrue(not extendee_proto.HasExtension(extension)) + def testRegisteredExtensions(self): + self.assertTrue('protobuf_unittest.optional_int32_extension' in + unittest_pb2.TestAllExtensions._extensions_by_name) + self.assertTrue(1 in unittest_pb2.TestAllExtensions._extensions_by_number) + # Make sure extensions haven't been registered into types that shouldn't + # have any. + self.assertEquals(0, len(unittest_pb2.TestAllTypes._extensions_by_name)) + # If message A directly contains message B, and # a.HasField('b') is currently False, then mutating any # extension in B should change a.HasField('b') to True @@ -1451,6 +1473,19 @@ class ReflectionTest(unittest.TestCase): proto2 = unittest_pb2.TestAllExtensions() self.assertRaises(TypeError, proto1.CopyFrom, proto2) + def testDeepCopy(self): + proto1 = unittest_pb2.TestAllTypes() + proto1.optional_int32 = 1 + proto2 = copy.deepcopy(proto1) + self.assertEqual(1, proto2.optional_int32) + + proto1.repeated_int32.append(2) + proto1.repeated_int32.append(3) + container = copy.deepcopy(proto1.repeated_int32) + self.assertEqual([2, 3], container) + + # TODO(anuraag): Implement deepcopy for repeated composite / extension dict + def testClear(self): proto = unittest_pb2.TestAllTypes() # C++ implementation does not support lazy fields right now so leave it @@ -1496,11 +1531,23 @@ class ReflectionTest(unittest.TestCase): self.assertEqual(6, foreign.c) nested.bb = 15 foreign.c = 16 - self.assertTrue(not proto.HasField('optional_nested_message')) + self.assertFalse(proto.HasField('optional_nested_message')) self.assertEqual(0, proto.optional_nested_message.bb) - self.assertTrue(not proto.HasField('optional_foreign_message')) + self.assertFalse(proto.HasField('optional_foreign_message')) self.assertEqual(0, proto.optional_foreign_message.c) + def testOneOf(self): + proto = unittest_pb2.TestAllTypes() + proto.oneof_uint32 = 10 + proto.oneof_nested_message.bb = 11 + self.assertEqual(11, proto.oneof_nested_message.bb) + self.assertFalse(proto.HasField('oneof_uint32')) + nested = proto.oneof_nested_message + proto.oneof_string = 'abc' + self.assertEqual('abc', proto.oneof_string) + self.assertEqual(11, nested.bb) + self.assertFalse(proto.HasField('oneof_nested_message')) + def assertInitialized(self, proto): self.assertTrue(proto.IsInitialized()) # Neither method should raise an exception. @@ -1571,6 +1618,40 @@ class ReflectionTest(unittest.TestCase): self.assertFalse(proto.IsInitialized(errors)) self.assertEqual(errors, ['a', 'b', 'c']) + @basetest.unittest.skipIf( + api_implementation.Type() != 'cpp' or api_implementation.Version() != 2, + 'Errors are only available from the most recent C++ implementation.') + def testFileDescriptorErrors(self): + file_name = 'test_file_descriptor_errors.proto' + package_name = 'test_file_descriptor_errors.proto' + file_descriptor_proto = descriptor_pb2.FileDescriptorProto() + file_descriptor_proto.name = file_name + file_descriptor_proto.package = package_name + m1 = file_descriptor_proto.message_type.add() + m1.name = 'msg1' + # Compiles the proto into the C++ descriptor pool + descriptor.FileDescriptor( + file_name, + package_name, + serialized_pb=file_descriptor_proto.SerializeToString()) + # Add a FileDescriptorProto that has duplicate symbols + another_file_name = 'another_test_file_descriptor_errors.proto' + file_descriptor_proto.name = another_file_name + m2 = file_descriptor_proto.message_type.add() + m2.name = 'msg2' + with self.assertRaises(TypeError) as cm: + descriptor.FileDescriptor( + another_file_name, + package_name, + serialized_pb=file_descriptor_proto.SerializeToString()) + self.assertTrue(hasattr(cm, 'exception'), '%s not raised' % + getattr(cm.expected, '__name__', cm.expected)) + self.assertIn('test_file_descriptor_errors.proto', str(cm.exception)) + # Error message will say something about this definition being a + # duplicate, though we don't check the message exactly to avoid a + # dependency on the C++ logging code. + self.assertIn('test_file_descriptor_errors.msg1', str(cm.exception)) + def testStringUTF8Encoding(self): proto = unittest_pb2.TestAllTypes() @@ -1588,17 +1669,15 @@ class ReflectionTest(unittest.TestCase): proto.optional_string = str('Testing') self.assertEqual(proto.optional_string, unicode('Testing')) - if api_implementation.Type() == 'python': - # Values of type 'str' are also accepted as long as they can be - # encoded in UTF-8. - self.assertEqual(type(proto.optional_string), str) - # Try to assign a 'str' value which contains bytes that aren't 7-bit ASCII. self.assertRaises(ValueError, - setattr, proto, 'optional_string', str('a\x80a')) - # Assign a 'str' object which contains a UTF-8 encoded string. - self.assertRaises(ValueError, - setattr, proto, 'optional_string', 'Тест') + setattr, proto, 'optional_string', b'a\x80a') + if str is bytes: # PY2 + # Assign a 'str' object which contains a UTF-8 encoded string. + self.assertRaises(ValueError, + setattr, proto, 'optional_string', 'Тест') + else: + proto.optional_string = 'Тест' # No exception thrown. proto.optional_string = 'abc' @@ -1621,7 +1700,8 @@ class ReflectionTest(unittest.TestCase): self.assertEqual(proto.ByteSize(), len(serialized)) raw = unittest_mset_pb2.RawMessageSet() - raw.MergeFromString(serialized) + bytes_read = raw.MergeFromString(serialized) + self.assertEqual(len(serialized), bytes_read) message2 = unittest_mset_pb2.TestMessageSetExtension2() @@ -1632,7 +1712,8 @@ class ReflectionTest(unittest.TestCase): # Check the actual bytes on the wire. self.assertTrue( raw.item[0].message.endswith(test_utf8_bytes)) - message2.MergeFromString(raw.item[0].message) + bytes_read = message2.MergeFromString(raw.item[0].message) + self.assertEqual(len(raw.item[0].message), bytes_read) self.assertEqual(type(message2.str), unicode) self.assertEqual(message2.str, test_utf8) @@ -1643,17 +1724,22 @@ class ReflectionTest(unittest.TestCase): # MergeFromString and thus has no way to throw the exception. # # The pure Python API always returns objects of type 'unicode' (UTF-8 - # encoded), or 'str' (in 7 bit ASCII). - bytes = raw.item[0].message.replace( - test_utf8_bytes, len(test_utf8_bytes) * '\xff') + # encoded), or 'bytes' (in 7 bit ASCII). + badbytes = raw.item[0].message.replace( + test_utf8_bytes, len(test_utf8_bytes) * b'\xff') unicode_decode_failed = False try: - message2.MergeFromString(bytes) - except UnicodeDecodeError as e: + message2.MergeFromString(badbytes) + except UnicodeDecodeError: unicode_decode_failed = True string_field = message2.str - self.assertTrue(unicode_decode_failed or type(string_field) == str) + self.assertTrue(unicode_decode_failed or type(string_field) is bytes) + + def testBytesInTextFormat(self): + proto = unittest_pb2.TestAllTypes(optional_bytes=b'\x00\x7f\x80\xff') + self.assertEqual(u'optional_bytes: "\\000\\177\\200\\377"\n', + unicode(proto)) def testEmptyNestedMessage(self): proto = unittest_pb2.TestAllTypes() @@ -1667,16 +1753,19 @@ class ReflectionTest(unittest.TestCase): self.assertTrue(proto.HasField('optional_nested_message')) proto = unittest_pb2.TestAllTypes() - proto.optional_nested_message.MergeFromString('') + bytes_read = proto.optional_nested_message.MergeFromString(b'') + self.assertEqual(0, bytes_read) self.assertTrue(proto.HasField('optional_nested_message')) proto = unittest_pb2.TestAllTypes() - proto.optional_nested_message.ParseFromString('') + proto.optional_nested_message.ParseFromString(b'') self.assertTrue(proto.HasField('optional_nested_message')) serialized = proto.SerializeToString() proto2 = unittest_pb2.TestAllTypes() - proto2.MergeFromString(serialized) + self.assertEqual( + len(serialized), + proto2.MergeFromString(serialized)) self.assertTrue(proto2.HasField('optional_nested_message')) def testSetInParent(self): @@ -1690,7 +1779,7 @@ class ReflectionTest(unittest.TestCase): # into separate TestCase classes. -class TestAllTypesEqualityTest(unittest.TestCase): +class TestAllTypesEqualityTest(basetest.TestCase): def setUp(self): self.first_proto = unittest_pb2.TestAllTypes() @@ -1706,7 +1795,7 @@ class TestAllTypesEqualityTest(unittest.TestCase): self.assertEqual(self.first_proto, self.second_proto) -class FullProtosEqualityTest(unittest.TestCase): +class FullProtosEqualityTest(basetest.TestCase): """Equality tests using completely-full protos as a starting point.""" @@ -1792,7 +1881,7 @@ class FullProtosEqualityTest(unittest.TestCase): self.assertEqual(self.first_proto, self.second_proto) -class ExtensionEqualityTest(unittest.TestCase): +class ExtensionEqualityTest(basetest.TestCase): def testExtensionEquality(self): first_proto = unittest_pb2.TestAllExtensions() @@ -1825,7 +1914,7 @@ class ExtensionEqualityTest(unittest.TestCase): self.assertEqual(first_proto, second_proto) -class MutualRecursionEqualityTest(unittest.TestCase): +class MutualRecursionEqualityTest(basetest.TestCase): def testEqualityWithMutualRecursion(self): first_proto = unittest_pb2.TestMutualRecursionA() @@ -1837,7 +1926,7 @@ class MutualRecursionEqualityTest(unittest.TestCase): self.assertEqual(first_proto, second_proto) -class ByteSizeTest(unittest.TestCase): +class ByteSizeTest(basetest.TestCase): def setUp(self): self.proto = unittest_pb2.TestAllTypes() @@ -2133,14 +2222,16 @@ class ByteSizeTest(unittest.TestCase): # * Handling of empty submessages (with and without "has" # bits set). -class SerializationTest(unittest.TestCase): +class SerializationTest(basetest.TestCase): def testSerializeEmtpyMessage(self): first_proto = unittest_pb2.TestAllTypes() second_proto = unittest_pb2.TestAllTypes() serialized = first_proto.SerializeToString() self.assertEqual(first_proto.ByteSize(), len(serialized)) - second_proto.MergeFromString(serialized) + self.assertEqual( + len(serialized), + second_proto.MergeFromString(serialized)) self.assertEqual(first_proto, second_proto) def testSerializeAllFields(self): @@ -2149,7 +2240,9 @@ class SerializationTest(unittest.TestCase): test_util.SetAllFields(first_proto) serialized = first_proto.SerializeToString() self.assertEqual(first_proto.ByteSize(), len(serialized)) - second_proto.MergeFromString(serialized) + self.assertEqual( + len(serialized), + second_proto.MergeFromString(serialized)) self.assertEqual(first_proto, second_proto) def testSerializeAllExtensions(self): @@ -2157,7 +2250,19 @@ class SerializationTest(unittest.TestCase): second_proto = unittest_pb2.TestAllExtensions() test_util.SetAllExtensions(first_proto) serialized = first_proto.SerializeToString() - second_proto.MergeFromString(serialized) + self.assertEqual( + len(serialized), + second_proto.MergeFromString(serialized)) + self.assertEqual(first_proto, second_proto) + + def testSerializeWithOptionalGroup(self): + first_proto = unittest_pb2.TestAllTypes() + second_proto = unittest_pb2.TestAllTypes() + first_proto.optionalgroup.a = 242 + serialized = first_proto.SerializeToString() + self.assertEqual( + len(serialized), + second_proto.MergeFromString(serialized)) self.assertEqual(first_proto, second_proto) def testSerializeNegativeValues(self): @@ -2249,7 +2354,9 @@ class SerializationTest(unittest.TestCase): second_proto.optional_int32 = 100 second_proto.optional_nested_message.bb = 999 - second_proto.MergeFromString(serialized) + bytes_parsed = second_proto.MergeFromString(serialized) + self.assertEqual(len(serialized), bytes_parsed) + # Ensure that we append to repeated fields. self.assertEqual(['baz', 'foobar'], list(second_proto.repeated_string)) # Ensure that we overwrite nonrepeatd scalars. @@ -2274,20 +2381,28 @@ class SerializationTest(unittest.TestCase): raw = unittest_mset_pb2.RawMessageSet() self.assertEqual(False, raw.DESCRIPTOR.GetOptions().message_set_wire_format) - raw.MergeFromString(serialized) + self.assertEqual( + len(serialized), + raw.MergeFromString(serialized)) self.assertEqual(2, len(raw.item)) message1 = unittest_mset_pb2.TestMessageSetExtension1() - message1.MergeFromString(raw.item[0].message) + self.assertEqual( + len(raw.item[0].message), + message1.MergeFromString(raw.item[0].message)) self.assertEqual(123, message1.i) message2 = unittest_mset_pb2.TestMessageSetExtension2() - message2.MergeFromString(raw.item[1].message) + self.assertEqual( + len(raw.item[1].message), + message2.MergeFromString(raw.item[1].message)) self.assertEqual('foo', message2.str) # Deserialize using the MessageSet wire format. proto2 = unittest_mset_pb2.TestMessageSet() - proto2.MergeFromString(serialized) + self.assertEqual( + len(serialized), + proto2.MergeFromString(serialized)) self.assertEqual(123, proto2.Extensions[extension1].i) self.assertEqual('foo', proto2.Extensions[extension2].str) @@ -2327,7 +2442,9 @@ class SerializationTest(unittest.TestCase): # Parse message using the message set wire format. proto = unittest_mset_pb2.TestMessageSet() - proto.MergeFromString(serialized) + self.assertEqual( + len(serialized), + proto.MergeFromString(serialized)) # Check that the message parsed well. extension_message1 = unittest_mset_pb2.TestMessageSetExtension1 @@ -2345,7 +2462,9 @@ class SerializationTest(unittest.TestCase): proto2 = unittest_pb2.TestEmptyMessage() # Parsing this message should succeed. - proto2.MergeFromString(serialized) + self.assertEqual( + len(serialized), + proto2.MergeFromString(serialized)) # Now test with a int64 field set. proto = unittest_pb2.TestAllTypes() @@ -2355,7 +2474,9 @@ class SerializationTest(unittest.TestCase): # unknown. proto2 = unittest_pb2.TestEmptyMessage() # Parsing this message should succeed. - proto2.MergeFromString(serialized) + self.assertEqual( + len(serialized), + proto2.MergeFromString(serialized)) def _CheckRaises(self, exc_class, callable_obj, exception): """This method checks if the excpetion type and message are as expected.""" @@ -2406,11 +2527,15 @@ class SerializationTest(unittest.TestCase): partial = proto.SerializePartialToString() proto2 = unittest_pb2.TestRequired() - proto2.MergeFromString(serialized) + self.assertEqual( + len(serialized), + proto2.MergeFromString(serialized)) self.assertEqual(1, proto2.a) self.assertEqual(2, proto2.b) self.assertEqual(3, proto2.c) - proto2.ParseFromString(partial) + self.assertEqual( + len(partial), + proto2.MergeFromString(partial)) self.assertEqual(1, proto2.a) self.assertEqual(2, proto2.b) self.assertEqual(3, proto2.c) @@ -2478,7 +2603,9 @@ class SerializationTest(unittest.TestCase): second_proto.packed_double.extend([1.0, 2.0]) second_proto.packed_sint32.append(4) - second_proto.MergeFromString(serialized) + self.assertEqual( + len(serialized), + second_proto.MergeFromString(serialized)) self.assertEqual([3, 1, 2], second_proto.packed_int32) self.assertEqual([1.0, 2.0, 3.0], second_proto.packed_double) self.assertEqual([4], second_proto.packed_sint32) @@ -2511,7 +2638,10 @@ class SerializationTest(unittest.TestCase): unpacked = unittest_pb2.TestUnpackedTypes() test_util.SetAllUnpackedFields(unpacked) packed = unittest_pb2.TestPackedTypes() - packed.MergeFromString(unpacked.SerializeToString()) + serialized = unpacked.SerializeToString() + self.assertEqual( + len(serialized), + packed.MergeFromString(serialized)) expected = unittest_pb2.TestPackedTypes() test_util.SetAllPackedFields(expected) self.assertEqual(expected, packed) @@ -2520,7 +2650,10 @@ class SerializationTest(unittest.TestCase): packed = unittest_pb2.TestPackedTypes() test_util.SetAllPackedFields(packed) unpacked = unittest_pb2.TestUnpackedTypes() - unpacked.MergeFromString(packed.SerializeToString()) + serialized = packed.SerializeToString() + self.assertEqual( + len(serialized), + unpacked.MergeFromString(serialized)) expected = unittest_pb2.TestUnpackedTypes() test_util.SetAllUnpackedFields(expected) self.assertEqual(expected, unpacked) @@ -2572,7 +2705,7 @@ class SerializationTest(unittest.TestCase): optional_int32=1, optional_string='foo', optional_bool=True, - optional_bytes='bar', + optional_bytes=b'bar', optional_nested_message=unittest_pb2.TestAllTypes.NestedMessage(bb=1), optional_foreign_message=unittest_pb2.ForeignMessage(c=1), optional_nested_enum=unittest_pb2.TestAllTypes.FOO, @@ -2590,7 +2723,7 @@ class SerializationTest(unittest.TestCase): self.assertEqual(1, proto.optional_int32) self.assertEqual('foo', proto.optional_string) self.assertEqual(True, proto.optional_bool) - self.assertEqual('bar', proto.optional_bytes) + self.assertEqual(b'bar', proto.optional_bytes) self.assertEqual(1, proto.optional_nested_message.bb) self.assertEqual(1, proto.optional_foreign_message.c) self.assertEqual(unittest_pb2.TestAllTypes.FOO, @@ -2640,7 +2773,7 @@ class SerializationTest(unittest.TestCase): self.assertEqual(3, proto.repeated_int32[2]) -class OptionsTest(unittest.TestCase): +class OptionsTest(basetest.TestCase): def testMessageOptions(self): proto = unittest_mset_pb2.TestMessageSet() @@ -2667,5 +2800,135 @@ class OptionsTest(unittest.TestCase): +class ClassAPITest(basetest.TestCase): + + def testMakeClassWithNestedDescriptor(self): + leaf_desc = descriptor.Descriptor('leaf', 'package.parent.child.leaf', '', + containing_type=None, fields=[], + nested_types=[], enum_types=[], + extensions=[]) + child_desc = descriptor.Descriptor('child', 'package.parent.child', '', + containing_type=None, fields=[], + nested_types=[leaf_desc], enum_types=[], + extensions=[]) + sibling_desc = descriptor.Descriptor('sibling', 'package.parent.sibling', + '', containing_type=None, fields=[], + nested_types=[], enum_types=[], + extensions=[]) + parent_desc = descriptor.Descriptor('parent', 'package.parent', '', + containing_type=None, fields=[], + nested_types=[child_desc, sibling_desc], + enum_types=[], extensions=[]) + message_class = reflection.MakeClass(parent_desc) + self.assertIn('child', message_class.__dict__) + self.assertIn('sibling', message_class.__dict__) + self.assertIn('leaf', message_class.child.__dict__) + + def _GetSerializedFileDescriptor(self, name): + """Get a serialized representation of a test FileDescriptorProto. + + Args: + name: All calls to this must use a unique message name, to avoid + collisions in the cpp descriptor pool. + Returns: + A string containing the serialized form of a test FileDescriptorProto. + """ + file_descriptor_str = ( + 'message_type {' + ' name: "' + name + '"' + ' field {' + ' name: "flat"' + ' number: 1' + ' label: LABEL_REPEATED' + ' type: TYPE_UINT32' + ' }' + ' field {' + ' name: "bar"' + ' number: 2' + ' label: LABEL_OPTIONAL' + ' type: TYPE_MESSAGE' + ' type_name: "Bar"' + ' }' + ' nested_type {' + ' name: "Bar"' + ' field {' + ' name: "baz"' + ' number: 3' + ' label: LABEL_OPTIONAL' + ' type: TYPE_MESSAGE' + ' type_name: "Baz"' + ' }' + ' nested_type {' + ' name: "Baz"' + ' enum_type {' + ' name: "deep_enum"' + ' value {' + ' name: "VALUE_A"' + ' number: 0' + ' }' + ' }' + ' field {' + ' name: "deep"' + ' number: 4' + ' label: LABEL_OPTIONAL' + ' type: TYPE_UINT32' + ' }' + ' }' + ' }' + '}') + file_descriptor = descriptor_pb2.FileDescriptorProto() + text_format.Merge(file_descriptor_str, file_descriptor) + return file_descriptor.SerializeToString() + + def testParsingFlatClassWithExplicitClassDeclaration(self): + """Test that the generated class can parse a flat message.""" + file_descriptor = descriptor_pb2.FileDescriptorProto() + file_descriptor.ParseFromString(self._GetSerializedFileDescriptor('A')) + msg_descriptor = descriptor.MakeDescriptor( + file_descriptor.message_type[0]) + + class MessageClass(message.Message): + __metaclass__ = reflection.GeneratedProtocolMessageType + DESCRIPTOR = msg_descriptor + msg = MessageClass() + msg_str = ( + 'flat: 0 ' + 'flat: 1 ' + 'flat: 2 ') + text_format.Merge(msg_str, msg) + self.assertEqual(msg.flat, [0, 1, 2]) + + def testParsingFlatClass(self): + """Test that the generated class can parse a flat message.""" + file_descriptor = descriptor_pb2.FileDescriptorProto() + file_descriptor.ParseFromString(self._GetSerializedFileDescriptor('B')) + msg_descriptor = descriptor.MakeDescriptor( + file_descriptor.message_type[0]) + msg_class = reflection.MakeClass(msg_descriptor) + msg = msg_class() + msg_str = ( + 'flat: 0 ' + 'flat: 1 ' + 'flat: 2 ') + text_format.Merge(msg_str, msg) + self.assertEqual(msg.flat, [0, 1, 2]) + + def testParsingNestedClass(self): + """Test that the generated class can parse a nested message.""" + file_descriptor = descriptor_pb2.FileDescriptorProto() + file_descriptor.ParseFromString(self._GetSerializedFileDescriptor('C')) + msg_descriptor = descriptor.MakeDescriptor( + file_descriptor.message_type[0]) + msg_class = reflection.MakeClass(msg_descriptor) + msg = msg_class() + msg_str = ( + 'bar {' + ' baz {' + ' deep: 4' + ' }' + '}') + text_format.Merge(msg_str, msg) + self.assertEqual(msg.bar.baz.deep, 4) + if __name__ == '__main__': - unittest.main() + basetest.main() |