diff options
author | Jisi Liu <jisi.liu@gmail.com> | 2016-04-28 14:34:59 -0700 |
---|---|---|
committer | Jisi Liu <jisi.liu@gmail.com> | 2016-04-28 14:34:59 -0700 |
commit | cf14183bcd5485b4a71541599ddce0b35eb71352 (patch) | |
tree | 12f6e5eb731d7a70cdac4cdafc8b3131629413e2 /python/google/protobuf/internal | |
parent | f00300d7f04f1c38a7d70e271f9232b94dd0e326 (diff) |
Down integrate from Google internal.
Diffstat (limited to 'python/google/protobuf/internal')
13 files changed, 313 insertions, 69 deletions
diff --git a/python/google/protobuf/internal/api_implementation.py b/python/google/protobuf/internal/api_implementation.py index ffcf7511..460a4a6c 100755 --- a/python/google/protobuf/internal/api_implementation.py +++ b/python/google/protobuf/internal/api_implementation.py @@ -32,6 +32,7 @@ """ import os +import warnings import sys try: @@ -78,6 +79,11 @@ _implementation_type = os.getenv('PROTOCOL_BUFFERS_PYTHON_IMPLEMENTATION', if _implementation_type != 'python': _implementation_type = 'cpp' +if 'PyPy' in sys.version and _implementation_type == 'cpp': + warnings.warn('PyPy does not work yet with cpp protocol buffers. ' + 'Falling back to the python implementation.') + _implementation_type = 'python' + # This environment variable can be used to switch between the two # 'cpp' implementations, overriding the compile-time constants in the # _api_implementation module. Right now only '2' is supported. Any other diff --git a/python/google/protobuf/internal/descriptor_pool_test.py b/python/google/protobuf/internal/descriptor_pool_test.py index 4b1811d8..6a13e0bc 100644 --- a/python/google/protobuf/internal/descriptor_pool_test.py +++ b/python/google/protobuf/internal/descriptor_pool_test.py @@ -51,6 +51,7 @@ from google.protobuf.internal import descriptor_pool_test1_pb2 from google.protobuf.internal import descriptor_pool_test2_pb2 from google.protobuf.internal import factory_test1_pb2 from google.protobuf.internal import factory_test2_pb2 +from google.protobuf.internal import more_messages_pb2 from google.protobuf import descriptor from google.protobuf import descriptor_database from google.protobuf import descriptor_pool @@ -60,11 +61,8 @@ from google.protobuf import symbol_database class DescriptorPoolTest(unittest.TestCase): - def CreatePool(self): - return descriptor_pool.DescriptorPool() - def setUp(self): - self.pool = self.CreatePool() + self.pool = descriptor_pool.DescriptorPool() self.factory_test1_fd = descriptor_pb2.FileDescriptorProto.FromString( factory_test1_pb2.DESCRIPTOR.serialized_pb) self.factory_test2_fd = descriptor_pb2.FileDescriptorProto.FromString( @@ -275,10 +273,13 @@ class DescriptorPoolTest(unittest.TestCase): self.testFindMessageTypeByName() def testComplexNesting(self): + more_messages_desc = descriptor_pb2.FileDescriptorProto.FromString( + more_messages_pb2.DESCRIPTOR.serialized_pb) test1_desc = descriptor_pb2.FileDescriptorProto.FromString( descriptor_pool_test1_pb2.DESCRIPTOR.serialized_pb) test2_desc = descriptor_pb2.FileDescriptorProto.FromString( descriptor_pool_test2_pb2.DESCRIPTOR.serialized_pb) + self.pool.Add(more_messages_desc) self.pool.Add(test1_desc) self.pool.Add(test2_desc) TEST1_FILE.CheckFile(self, self.pool) @@ -350,25 +351,15 @@ class DescriptorPoolTest(unittest.TestCase): _CheckDefaultValues(message_class()) -@unittest.skipIf(api_implementation.Type() != 'cpp', - 'explicit tests of the C++ implementation') -class CppDescriptorPoolTest(DescriptorPoolTest): - # TODO(amauryfa): remove when descriptor_pool.DescriptorPool() creates true - # C++ descriptor pool object for C++ implementation. - - def CreatePool(self): - # pylint: disable=g-import-not-at-top - from google.protobuf.pyext import _message - return _message.DescriptorPool() - - class ProtoFile(object): - def __init__(self, name, package, messages, dependencies=None): + def __init__(self, name, package, messages, dependencies=None, + public_dependencies=None): self.name = name self.package = package self.messages = messages self.dependencies = dependencies or [] + self.public_dependencies = public_dependencies or [] def CheckFile(self, test, pool): file_desc = pool.FindFileByName(self.name) @@ -376,6 +367,8 @@ class ProtoFile(object): test.assertEqual(self.package, file_desc.package) dependencies_names = [f.name for f in file_desc.dependencies] test.assertEqual(self.dependencies, dependencies_names) + public_dependencies_names = [f.name for f in file_desc.public_dependencies] + test.assertEqual(self.public_dependencies, public_dependencies_names) for name, msg_type in self.messages.items(): msg_type.CheckType(test, None, name, file_desc) @@ -613,18 +606,9 @@ class AddDescriptorTest(unittest.TestCase): pool.FindFileContainingSymbol( 'protobuf_unittest.TestAllTypes') - def _GetDescriptorPoolClass(self): - # Test with both implementations of descriptor pools. - if api_implementation.Type() == 'cpp': - # pylint: disable=g-import-not-at-top - from google.protobuf.pyext import _message - return _message.DescriptorPool - else: - return descriptor_pool.DescriptorPool - def testEmptyDescriptorPool(self): - # Check that an empty DescriptorPool() contains no message. - pool = self._GetDescriptorPoolClass()() + # Check that an empty DescriptorPool() contains no messages. + pool = descriptor_pool.DescriptorPool() proto_file_name = descriptor_pb2.DESCRIPTOR.name self.assertRaises(KeyError, pool.FindFileByName, proto_file_name) # Add the above file to the pool @@ -636,7 +620,7 @@ class AddDescriptorTest(unittest.TestCase): def testCustomDescriptorPool(self): # Create a new pool, and add a file descriptor. - pool = self._GetDescriptorPoolClass()() + pool = descriptor_pool.DescriptorPool() file_desc = descriptor_pb2.FileDescriptorProto( name='some/file.proto', package='package') file_desc.message_type.add(name='Message') @@ -757,7 +741,9 @@ TEST2_FILE = ProtoFile( ExtensionField(1001, 'DescriptorPoolTest1')), ]), }, - dependencies=['google/protobuf/internal/descriptor_pool_test1.proto']) + dependencies=['google/protobuf/internal/descriptor_pool_test1.proto', + 'google/protobuf/internal/more_messages.proto'], + public_dependencies=['google/protobuf/internal/more_messages.proto']) if __name__ == '__main__': diff --git a/python/google/protobuf/internal/descriptor_pool_test2.proto b/python/google/protobuf/internal/descriptor_pool_test2.proto index e3fa660c..a218eccb 100644 --- a/python/google/protobuf/internal/descriptor_pool_test2.proto +++ b/python/google/protobuf/internal/descriptor_pool_test2.proto @@ -33,6 +33,7 @@ syntax = "proto2"; package google.protobuf.python.internal; import "google/protobuf/internal/descriptor_pool_test1.proto"; +import public "google/protobuf/internal/more_messages.proto"; message DescriptorPoolTest3 { diff --git a/python/google/protobuf/internal/message_factory_test.py b/python/google/protobuf/internal/message_factory_test.py index 54b1f688..7bb7d1ac 100644 --- a/python/google/protobuf/internal/message_factory_test.py +++ b/python/google/protobuf/internal/message_factory_test.py @@ -131,6 +131,60 @@ class MessageFactoryTest(unittest.TestCase): self.assertEqual('test1', msg1.Extensions[ext1]) self.assertEqual('test2', msg1.Extensions[ext2]) + def testDuplicateExtensionNumber(self): + pool = descriptor_pool.DescriptorPool() + factory = message_factory.MessageFactory(pool=pool) + + # Add Container message. + f = descriptor_pb2.FileDescriptorProto() + f.name = 'google/protobuf/internal/container.proto' + f.package = 'google.protobuf.python.internal' + msg = f.message_type.add() + msg.name = 'Container' + rng = msg.extension_range.add() + rng.start = 1 + rng.end = 10 + pool.Add(f) + msgs = factory.GetMessages([f.name]) + self.assertIn('google.protobuf.python.internal.Container', msgs) + + # Extend container. + f = descriptor_pb2.FileDescriptorProto() + f.name = 'google/protobuf/internal/extension.proto' + f.package = 'google.protobuf.python.internal' + f.dependency.append('google/protobuf/internal/container.proto') + msg = f.message_type.add() + msg.name = 'Extension' + ext = msg.extension.add() + ext.name = 'extension_field' + ext.number = 2 + ext.label = descriptor_pb2.FieldDescriptorProto.LABEL_OPTIONAL + ext.type_name = 'Extension' + ext.extendee = 'Container' + pool.Add(f) + msgs = factory.GetMessages([f.name]) + self.assertIn('google.protobuf.python.internal.Extension', msgs) + + # Add Duplicate extending the same field number. + f = descriptor_pb2.FileDescriptorProto() + f.name = 'google/protobuf/internal/duplicate.proto' + f.package = 'google.protobuf.python.internal' + f.dependency.append('google/protobuf/internal/container.proto') + msg = f.message_type.add() + msg.name = 'Duplicate' + ext = msg.extension.add() + ext.name = 'extension_field' + ext.number = 2 + ext.label = descriptor_pb2.FieldDescriptorProto.LABEL_OPTIONAL + ext.type_name = 'Duplicate' + ext.extendee = 'Container' + pool.Add(f) + + with self.assertRaises(Exception) as cm: + factory.GetMessages([f.name]) + + self.assertIsInstance(cm.exception, (AssertionError, ValueError)) + if __name__ == '__main__': unittest.main() diff --git a/python/google/protobuf/internal/message_test.py b/python/google/protobuf/internal/message_test.py index 1232ccc9..4ee31d8e 100755 --- a/python/google/protobuf/internal/message_test.py +++ b/python/google/protobuf/internal/message_test.py @@ -57,18 +57,18 @@ try: except ImportError: import unittest -from google.protobuf.internal import _parameterized +from google.protobuf import map_unittest_pb2 +from google.protobuf import unittest_pb2 +from google.protobuf import unittest_proto3_arena_pb2 from google.protobuf import descriptor_pb2 from google.protobuf import descriptor_pool -from google.protobuf import map_unittest_pb2 from google.protobuf import message_factory from google.protobuf import text_format -from google.protobuf import unittest_pb2 -from google.protobuf import unittest_proto3_arena_pb2 from google.protobuf.internal import api_implementation from google.protobuf.internal import packed_field_test_pb2 from google.protobuf.internal import test_util from google.protobuf import message +from google.protobuf.internal import _parameterized if six.PY3: long = int @@ -1265,7 +1265,10 @@ class Proto3Test(unittest.TestCase): self.assertFalse(-2**33 in msg.map_int64_int64) self.assertFalse(123 in msg.map_uint32_uint32) self.assertFalse(2**33 in msg.map_uint64_uint64) + self.assertFalse(123 in msg.map_int32_double) + self.assertFalse(False in msg.map_bool_bool) self.assertFalse('abc' in msg.map_string_string) + self.assertFalse(111 in msg.map_int32_bytes) self.assertFalse(888 in msg.map_int32_enum) # Accessing an unset key returns the default. @@ -1273,7 +1276,12 @@ class Proto3Test(unittest.TestCase): self.assertEqual(0, msg.map_int64_int64[-2**33]) self.assertEqual(0, msg.map_uint32_uint32[123]) self.assertEqual(0, msg.map_uint64_uint64[2**33]) + self.assertEqual(0.0, msg.map_int32_double[123]) + self.assertTrue(isinstance(msg.map_int32_double[123], float)) + self.assertEqual(False, msg.map_bool_bool[False]) + self.assertTrue(isinstance(msg.map_bool_bool[False], bool)) self.assertEqual('', msg.map_string_string['abc']) + self.assertEqual(b'', msg.map_int32_bytes[111]) self.assertEqual(0, msg.map_int32_enum[888]) # It also sets the value in the map @@ -1281,7 +1289,10 @@ class Proto3Test(unittest.TestCase): self.assertTrue(-2**33 in msg.map_int64_int64) self.assertTrue(123 in msg.map_uint32_uint32) self.assertTrue(2**33 in msg.map_uint64_uint64) + self.assertTrue(123 in msg.map_int32_double) + self.assertTrue(False in msg.map_bool_bool) self.assertTrue('abc' in msg.map_string_string) + self.assertTrue(111 in msg.map_int32_bytes) self.assertTrue(888 in msg.map_int32_enum) self.assertIsInstance(msg.map_string_string['abc'], six.text_type) @@ -1587,6 +1598,21 @@ class Proto3Test(unittest.TestCase): matching_dict = {2: 4, 3: 6, 4: 8} self.assertMapIterEquals(msg.map_int32_int32.items(), matching_dict) + def testMapItems(self): + # Map items used to have strange behaviors when use c extension. Because + # [] may reorder the map and invalidate any exsting iterators. + # TODO(jieluo): Check if [] reordering the map is a bug or intended + # behavior. + msg = map_unittest_pb2.TestMap() + msg.map_string_string['local_init_op'] = '' + msg.map_string_string['trainable_variables'] = '' + msg.map_string_string['variables'] = '' + msg.map_string_string['init_op'] = '' + msg.map_string_string['summaries'] = '' + items1 = msg.map_string_string.items() + items2 = msg.map_string_string.items() + self.assertEqual(items1, items2) + def testMapIterationClearMessage(self): # Iterator needs to work even if message and map are deleted. msg = map_unittest_pb2.TestMap() diff --git a/python/google/protobuf/internal/proto_builder_test.py b/python/google/protobuf/internal/proto_builder_test.py index 822ad895..36dfbfde 100644 --- a/python/google/protobuf/internal/proto_builder_test.py +++ b/python/google/protobuf/internal/proto_builder_test.py @@ -40,6 +40,7 @@ try: import unittest2 as unittest except ImportError: import unittest + from google.protobuf import descriptor_pb2 from google.protobuf import descriptor_pool from google.protobuf import proto_builder diff --git a/python/google/protobuf/internal/python_message.py b/python/google/protobuf/internal/python_message.py index 87f60666..f8f73dd2 100755 --- a/python/google/protobuf/internal/python_message.py +++ b/python/google/protobuf/internal/python_message.py @@ -56,7 +56,14 @@ import struct import weakref import six -import six.moves.copyreg as copyreg +try: + import six.moves.copyreg as copyreg +except ImportError: + # On some platforms, for example gMac, we run native Python because there is + # nothing like hermetic Python. This means lesser control on the system and + # the six.moves package may be missing (is missing on 20150321 on gMac). Be + # extra conservative and try to load the old replacement if it fails. + import copy_reg as copyreg # We use "as" to avoid name collisions with variables. from google.protobuf.internal import containers @@ -490,6 +497,9 @@ def _AddInitMethod(message_descriptor, cls): if field is None: raise TypeError("%s() got an unexpected keyword argument '%s'" % (message_descriptor.name, field_name)) + if field_value is None: + # field=None is the same as no field at all. + continue if field.label == _FieldDescriptor.LABEL_REPEATED: copy = field._default_constructor(self) if field.cpp_type == _FieldDescriptor.CPPTYPE_MESSAGE: # Composite @@ -889,17 +899,6 @@ def _AddClearExtensionMethod(cls): cls.ClearExtension = ClearExtension -def _AddClearMethod(message_descriptor, cls): - """Helper for _AddMessageMethods().""" - def Clear(self): - # Clear fields. - self._fields = {} - self._unknown_fields = () - self._oneofs = {} - self._Modified() - cls.Clear = Clear - - def _AddHasExtensionMethod(cls): """Helper for _AddMessageMethods().""" def HasExtension(self, extension_handle): @@ -999,16 +998,6 @@ def _AddUnicodeMethod(unused_message_descriptor, cls): cls.__unicode__ = __unicode__ -def _AddSetListenerMethod(cls): - """Helper for _AddMessageMethods().""" - def SetListener(self, listener): - if listener is None: - self._listener = message_listener_mod.NullMessageListener() - else: - self._listener = listener - cls._SetListener = SetListener - - def _BytesForNonRepeatedElement(value, field_number, field_type): """Returns the number of bytes needed to serialize a non-repeated element. The returned byte count includes space for tag information and any @@ -1288,6 +1277,32 @@ def _AddWhichOneofMethod(message_descriptor, cls): cls.WhichOneof = WhichOneof +def _Clear(self): + # Clear fields. + self._fields = {} + self._unknown_fields = () + self._oneofs = {} + self._Modified() + + +def _DiscardUnknownFields(self): + self._unknown_fields = [] + for field, value in self.ListFields(): + if field.cpp_type == _FieldDescriptor.CPPTYPE_MESSAGE: + if field.label == _FieldDescriptor.LABEL_REPEATED: + for sub_message in value: + sub_message.DiscardUnknownFields() + else: + value.DiscardUnknownFields() + + +def _SetListener(self, listener): + if listener is None: + self._listener = message_listener_mod.NullMessageListener() + else: + self._listener = listener + + def _AddMessageMethods(message_descriptor, cls): """Adds implementations of all Message methods to cls.""" _AddListFieldsMethod(message_descriptor, cls) @@ -1296,12 +1311,10 @@ def _AddMessageMethods(message_descriptor, cls): if message_descriptor.is_extendable: _AddClearExtensionMethod(cls) _AddHasExtensionMethod(cls) - _AddClearMethod(message_descriptor, cls) _AddEqualsMethod(message_descriptor, cls) _AddStrMethod(message_descriptor, cls) _AddReprMethod(message_descriptor, cls) _AddUnicodeMethod(message_descriptor, cls) - _AddSetListenerMethod(cls) _AddByteSizeMethod(message_descriptor, cls) _AddSerializeToStringMethod(message_descriptor, cls) _AddSerializePartialToStringMethod(message_descriptor, cls) @@ -1309,6 +1322,10 @@ def _AddMessageMethods(message_descriptor, cls): _AddIsInitializedMethod(message_descriptor, cls) _AddMergeFromMethod(cls) _AddWhichOneofMethod(message_descriptor, cls) + # Adds methods which do not depend on cls. + cls.Clear = _Clear + cls.DiscardUnknownFields = _DiscardUnknownFields + cls._SetListener = _SetListener def _AddPrivateHelperMethods(message_descriptor, cls): @@ -1518,3 +1535,14 @@ class _ExtensionDict(object): Extension field descriptor. """ return self._extended_message._extensions_by_name.get(name, None) + + def _FindExtensionByNumber(self, number): + """Tries to find a known extension with the field number. + + Args: + number: Extension field number. + + Returns: + Extension field descriptor. + """ + return self._extended_message._extensions_by_number.get(number, None) diff --git a/python/google/protobuf/internal/reflection_test.py b/python/google/protobuf/internal/reflection_test.py index 9e61ea0e..6dc2fffe 100755 --- a/python/google/protobuf/internal/reflection_test.py +++ b/python/google/protobuf/internal/reflection_test.py @@ -120,11 +120,13 @@ class ReflectionTest(unittest.TestCase): proto = unittest_pb2.TestAllTypes( optional_int32=24, optional_double=54.321, - optional_string='optional_string') + optional_string='optional_string', + optional_float=None) self.assertEqual(24, proto.optional_int32) self.assertEqual(54.321, proto.optional_double) self.assertEqual('optional_string', proto.optional_string) + self.assertFalse(proto.HasField("optional_float")) def testRepeatedScalarConstructor(self): # Constructor with only repeated scalar types should succeed. @@ -132,12 +134,14 @@ class ReflectionTest(unittest.TestCase): repeated_int32=[1, 2, 3, 4], repeated_double=[1.23, 54.321], repeated_bool=[True, False, False], - repeated_string=["optional_string"]) + repeated_string=["optional_string"], + repeated_float=None) self.assertEqual([1, 2, 3, 4], list(proto.repeated_int32)) self.assertEqual([1.23, 54.321], list(proto.repeated_double)) self.assertEqual([True, False, False], list(proto.repeated_bool)) self.assertEqual(["optional_string"], list(proto.repeated_string)) + self.assertEqual([], list(proto.repeated_float)) def testRepeatedCompositeConstructor(self): # Constructor with only repeated composite types should succeed. @@ -188,7 +192,8 @@ class ReflectionTest(unittest.TestCase): repeated_foreign_message=[ unittest_pb2.ForeignMessage(c=-43), unittest_pb2.ForeignMessage(c=45324), - unittest_pb2.ForeignMessage(c=12)]) + unittest_pb2.ForeignMessage(c=12)], + optional_nested_message=None) self.assertEqual(24, proto.optional_int32) self.assertEqual('optional_string', proto.optional_string) @@ -205,6 +210,7 @@ class ReflectionTest(unittest.TestCase): unittest_pb2.ForeignMessage(c=45324), unittest_pb2.ForeignMessage(c=12)], list(proto.repeated_foreign_message)) + self.assertFalse(proto.HasField("optional_nested_message")) def testConstructorTypeError(self): self.assertRaises( diff --git a/python/google/protobuf/internal/text_format_test.py b/python/google/protobuf/internal/text_format_test.py index 8ce0a44f..ab2bf05b 100755 --- a/python/google/protobuf/internal/text_format_test.py +++ b/python/google/protobuf/internal/text_format_test.py @@ -250,6 +250,36 @@ class TextFormatTest(TextFormatBase): message.c = 123 self.assertEqual('c: 123\n', str(message)) + def testPrintField(self, message_module): + message = message_module.TestAllTypes() + field = message.DESCRIPTOR.fields_by_name['optional_float'] + value = message.optional_float + out = text_format.TextWriter(False) + text_format.PrintField(field, value, out) + self.assertEqual('optional_float: 0.0\n', out.getvalue()) + out.close() + # Test Printer + out = text_format.TextWriter(False) + printer = text_format._Printer(out) + printer.PrintField(field, value) + self.assertEqual('optional_float: 0.0\n', out.getvalue()) + out.close() + + def testPrintFieldValue(self, message_module): + message = message_module.TestAllTypes() + field = message.DESCRIPTOR.fields_by_name['optional_float'] + value = message.optional_float + out = text_format.TextWriter(False) + text_format.PrintFieldValue(field, value, out) + self.assertEqual('0.0', out.getvalue()) + out.close() + # Test Printer + out = text_format.TextWriter(False) + printer = text_format._Printer(out) + printer.PrintFieldValue(field, value) + self.assertEqual('0.0', out.getvalue()) + out.close() + def testParseAllFields(self, message_module): message = message_module.TestAllTypes() test_util.SetAllFields(message) @@ -616,6 +646,26 @@ class Proto2Tests(TextFormatBase): ' text: \"bar\"\n' '}\n') + def testPrintMessageSetByFieldNumber(self): + out = text_format.TextWriter(False) + message = unittest_mset_pb2.TestMessageSetContainer() + ext1 = unittest_mset_pb2.TestMessageSetExtension1.message_set_extension + ext2 = unittest_mset_pb2.TestMessageSetExtension2.message_set_extension + message.message_set.Extensions[ext1].i = 23 + message.message_set.Extensions[ext2].str = 'foo' + text_format.PrintMessage(message, out, use_field_number=True) + self.CompareToGoldenText( + out.getvalue(), + '1 {\n' + ' 1545008 {\n' + ' 15: 23\n' + ' }\n' + ' 1547769 {\n' + ' 25: \"foo\"\n' + ' }\n' + '}\n') + out.close() + def testPrintMessageSetAsOneLine(self): message = unittest_mset_pb2.TestMessageSetContainer() ext1 = unittest_mset_pb2.TestMessageSetExtension1.message_set_extension @@ -656,6 +706,48 @@ class Proto2Tests(TextFormatBase): self.assertEqual(23, message.message_set.Extensions[ext1].i) self.assertEqual('foo', message.message_set.Extensions[ext2].str) + def testParseMessageByFieldNumber(self): + message = unittest_pb2.TestAllTypes() + text = ('34: 1\n' + 'repeated_uint64: 2\n') + text_format.Parse(text, message, allow_field_number=True) + self.assertEqual(1, message.repeated_uint64[0]) + self.assertEqual(2, message.repeated_uint64[1]) + + message = unittest_mset_pb2.TestMessageSetContainer() + text = ('1 {\n' + ' 1545008 {\n' + ' 15: 23\n' + ' }\n' + ' 1547769 {\n' + ' 25: \"foo\"\n' + ' }\n' + '}\n') + text_format.Parse(text, message, allow_field_number=True) + ext1 = unittest_mset_pb2.TestMessageSetExtension1.message_set_extension + ext2 = unittest_mset_pb2.TestMessageSetExtension2.message_set_extension + self.assertEqual(23, message.message_set.Extensions[ext1].i) + self.assertEqual('foo', message.message_set.Extensions[ext2].str) + + # Can't parse field number without set allow_field_number=True. + message = unittest_pb2.TestAllTypes() + text = '34:1\n' + six.assertRaisesRegex( + self, + text_format.ParseError, + (r'1:1 : Message type "\w+.TestAllTypes" has no field named ' + r'"34".'), + text_format.Parse, text, message) + + # Can't parse if field number is not found. + text = '1234:1\n' + six.assertRaisesRegex( + self, + text_format.ParseError, + (r'1:1 : Message type "\w+.TestAllTypes" has no field named ' + r'"1234".'), + text_format.Parse, text, message, allow_field_number=True) + def testPrintAllExtensions(self): message = unittest_pb2.TestAllExtensions() test_util.SetAllExtensions(message) @@ -696,6 +788,7 @@ class Proto2Tests(TextFormatBase): text = ('message_set {\n' ' [unknown_extension] {\n' ' i: 23\n' + ' bin: "\xe0"' ' [nested_unknown_ext]: {\n' ' i: 23\n' ' test: "test_string"\n' diff --git a/python/google/protobuf/internal/type_checkers.py b/python/google/protobuf/internal/type_checkers.py index f30ca6a8..1be3ad9a 100755 --- a/python/google/protobuf/internal/type_checkers.py +++ b/python/google/protobuf/internal/type_checkers.py @@ -109,6 +109,16 @@ class TypeChecker(object): return proposed_value +class TypeCheckerWithDefault(TypeChecker): + + def __init__(self, default_value, *acceptable_types): + TypeChecker.__init__(self, acceptable_types) + self._default_value = default_value + + def DefaultValue(self): + return self._default_value + + # IntValueChecker and its subclasses perform integer type-checks # and bounds-checks. class IntValueChecker(object): @@ -212,12 +222,13 @@ _VALUE_CHECKERS = { _FieldDescriptor.CPPTYPE_INT64: Int64ValueChecker(), _FieldDescriptor.CPPTYPE_UINT32: Uint32ValueChecker(), _FieldDescriptor.CPPTYPE_UINT64: Uint64ValueChecker(), - _FieldDescriptor.CPPTYPE_DOUBLE: TypeChecker( - float, int, long), - _FieldDescriptor.CPPTYPE_FLOAT: TypeChecker( - float, int, long), - _FieldDescriptor.CPPTYPE_BOOL: TypeChecker(bool, int), - _FieldDescriptor.CPPTYPE_STRING: TypeChecker(bytes), + _FieldDescriptor.CPPTYPE_DOUBLE: TypeCheckerWithDefault( + 0.0, float, int, long), + _FieldDescriptor.CPPTYPE_FLOAT: TypeCheckerWithDefault( + 0.0, float, int, long), + _FieldDescriptor.CPPTYPE_BOOL: TypeCheckerWithDefault( + False, bool, int), + _FieldDescriptor.CPPTYPE_STRING: TypeCheckerWithDefault(b'', bytes), } diff --git a/python/google/protobuf/internal/unknown_fields_test.py b/python/google/protobuf/internal/unknown_fields_test.py index bb2748e4..84073f1c 100755 --- a/python/google/protobuf/internal/unknown_fields_test.py +++ b/python/google/protobuf/internal/unknown_fields_test.py @@ -119,6 +119,26 @@ class UnknownFieldsTest(unittest.TestCase): message.ParseFromString(self.all_fields.SerializeToString()) self.assertNotEqual(self.empty_message, message) + def testDiscardUnknownFields(self): + self.empty_message.DiscardUnknownFields() + self.assertEqual(b'', self.empty_message.SerializeToString()) + # Test message field and repeated message field. + message = unittest_pb2.TestAllTypes() + other_message = unittest_pb2.TestAllTypes() + other_message.optional_string = 'discard' + message.optional_nested_message.ParseFromString( + other_message.SerializeToString()) + message.repeated_nested_message.add().ParseFromString( + other_message.SerializeToString()) + self.assertNotEqual( + b'', message.optional_nested_message.SerializeToString()) + self.assertNotEqual( + b'', message.repeated_nested_message[0].SerializeToString()) + message.DiscardUnknownFields() + self.assertEqual(b'', message.optional_nested_message.SerializeToString()) + self.assertEqual( + b'', message.repeated_nested_message[0].SerializeToString()) + class UnknownFieldsAccessorsTest(unittest.TestCase): diff --git a/python/google/protobuf/internal/well_known_types.py b/python/google/protobuf/internal/well_known_types.py index d35fcc5f..7c5dffd0 100644 --- a/python/google/protobuf/internal/well_known_types.py +++ b/python/google/protobuf/internal/well_known_types.py @@ -82,10 +82,14 @@ class Any(object): msg.ParseFromString(self.value) return True + def TypeName(self): + """Returns the protobuf type name of the inner message.""" + # Only last part is to be used: b/25630112 + return self.type_url.split('/')[-1] + def Is(self, descriptor): """Checks if this Any represents the given protobuf type.""" - # Only last part is to be used: b/25630112 - return self.type_url.split('/')[-1] == descriptor.full_name + return self.TypeName() == descriptor.full_name class Timestamp(object): diff --git a/python/google/protobuf/internal/well_known_types_test.py b/python/google/protobuf/internal/well_known_types_test.py index 18329205..2f32ac99 100644 --- a/python/google/protobuf/internal/well_known_types_test.py +++ b/python/google/protobuf/internal/well_known_types_test.py @@ -610,6 +610,14 @@ class AnyTest(unittest.TestCase): raise AttributeError('%s should not have Pack method.' % msg_descriptor.full_name) + def testMessageName(self): + # Creates and sets message. + submessage = any_test_pb2.TestAny() + submessage.int_value = 12345 + msg = any_pb2.Any() + msg.Pack(submessage) + self.assertEqual(msg.TypeName(), 'google.protobuf.internal.TestAny') + def testPackWithCustomTypeUrl(self): submessage = any_test_pb2.TestAny() submessage.int_value = 12345 |