diff options
Diffstat (limited to 'python')
27 files changed, 1357 insertions, 418 deletions
diff --git a/python/compatibility_tests/v2.5.0/tests/google/protobuf/internal/message_test.py b/python/compatibility_tests/v2.5.0/tests/google/protobuf/internal/message_test.py index 53e9d507..e71b295b 100755 --- a/python/compatibility_tests/v2.5.0/tests/google/protobuf/internal/message_test.py +++ b/python/compatibility_tests/v2.5.0/tests/google/protobuf/internal/message_test.py @@ -55,6 +55,11 @@ from google.protobuf.internal import api_implementation from google.protobuf.internal import test_util from google.protobuf import message +try: + cmp # Python 2 +except NameError: + cmp = lambda x, y: (x > y) - (x < y) # Python 3 + # Python pre-2.6 does not have isinf() or isnan() functions, so we have # to provide our own. def isnan(val): diff --git a/python/google/protobuf/descriptor_database.py b/python/google/protobuf/descriptor_database.py index eb45e127..b8f5140b 100644 --- a/python/google/protobuf/descriptor_database.py +++ b/python/google/protobuf/descriptor_database.py @@ -107,6 +107,7 @@ class DescriptorDatabase(object): 'some.package.name.Message' 'some.package.name.Message.NestedEnum' + 'some.package.name.Message.some_field' The file descriptor proto containing the specified symbol must be added to this database using the Add method or else an error will be raised. @@ -120,8 +121,16 @@ class DescriptorDatabase(object): Raises: KeyError if no file contains the specified symbol. """ - - return self._file_desc_protos_by_symbol[symbol] + try: + return self._file_desc_protos_by_symbol[symbol] + except KeyError: + # Fields, enum values, and nested extensions are not in + # _file_desc_protos_by_symbol. Try to find the top level + # descriptor. Non-existent nested symbol under a valid top level + # descriptor can also be found. The behavior is the same with + # protobuf C++. + top_level, _, _ = symbol.rpartition('.') + return self._file_desc_protos_by_symbol[top_level] def _ExtractSymbols(desc_proto, package): diff --git a/python/google/protobuf/descriptor_pool.py b/python/google/protobuf/descriptor_pool.py index 3dbe0fd0..cb7146b6 100644 --- a/python/google/protobuf/descriptor_pool.py +++ b/python/google/protobuf/descriptor_pool.py @@ -127,9 +127,6 @@ class DescriptorPool(object): self._service_descriptors = {} self._file_descriptors = {} self._toplevel_extensions = {} - # TODO(jieluo): Remove _file_desc_by_toplevel_extension when - # FieldDescriptor.file is added in code gen. - self._file_desc_by_toplevel_extension = {} # We store extensions in two two-level mappings: The first key is the # descriptor of the message being extended, the second key is the extension # full name or its tag number. @@ -255,11 +252,6 @@ class DescriptorPool(object): """ self._AddFileDescriptor(file_desc) - # TODO(jieluo): This is a temporary solution for FieldDescriptor.file. - # Remove it when FieldDescriptor.file is added in code gen. - for extension in file_desc.extensions_by_name.values(): - self._file_desc_by_toplevel_extension[ - extension.full_name] = file_desc def _AddFileDescriptor(self, file_desc): """Adds a FileDescriptor to the pool, non-recursively. @@ -339,7 +331,7 @@ class DescriptorPool(object): pass try: - return self._file_desc_by_toplevel_extension[symbol] + return self._toplevel_extensions[symbol].file except KeyError: pass @@ -405,6 +397,23 @@ class DescriptorPool(object): message_descriptor = self.FindMessageTypeByName(message_name) return message_descriptor.fields_by_name[field_name] + def FindOneofByName(self, full_name): + """Loads the named oneof descriptor from the pool. + + Args: + full_name: The full name of the oneof descriptor to load. + + Returns: + The oneof descriptor for the named oneof. + + Raises: + KeyError: if the oneof cannot be found in the pool. + """ + full_name = _NormalizeFullyQualifiedName(full_name) + message_name, _, oneof_name = full_name.rpartition('.') + message_descriptor = self.FindMessageTypeByName(message_name) + return message_descriptor.oneofs_by_name[oneof_name] + def FindExtensionByName(self, full_name): """Loads the named extension descriptor from the pool. diff --git a/python/google/protobuf/internal/api_implementation.py b/python/google/protobuf/internal/api_implementation.py index 422af590..bce71bb8 100755 --- a/python/google/protobuf/internal/api_implementation.py +++ b/python/google/protobuf/internal/api_implementation.py @@ -61,10 +61,15 @@ if _api_version < 0: # Still unspecified? del _use_fast_cpp_protos _api_version = 2 except ImportError: - if _proto_extension_modules_exist_in_build: - if sys.version_info[0] >= 3: # Python 3 defaults to C++ impl v2. - _api_version = 2 - # TODO(b/17427486): Make Python 2 default to C++ impl v2. + try: + # pylint: disable=g-import-not-at-top + from google.protobuf.internal import use_pure_python + del use_pure_python # Avoids a pylint error and namespace pollution. + except ImportError: + if _proto_extension_modules_exist_in_build: + if sys.version_info[0] >= 3: # Python 3 defaults to C++ impl v2. + _api_version = 2 + # TODO(b/17427486): Make Python 2 default to C++ impl v2. _default_implementation_type = ( 'python' if _api_version <= 0 else 'cpp') @@ -137,3 +142,29 @@ def Version(): # For internal use only def IsPythonDefaultSerializationDeterministic(): return _python_deterministic_proto_serialization + +# DO NOT USE: For migration and testing only. Will be removed when Proto3 +# defaults to preserve unknowns. +if _implementation_type == 'cpp': + try: + # pylint: disable=g-import-not-at-top + from google.protobuf.pyext import _message + + def GetPythonProto3PreserveUnknownsDefault(): + return _message.GetPythonProto3PreserveUnknownsDefault() + + def SetPythonProto3PreserveUnknownsDefault(preserve): + _message.SetPythonProto3PreserveUnknownsDefault(preserve) + except ImportError: + # Unrecognized cpp implementation. Skipping the unknown fields APIs. + pass +else: + _python_proto3_preserve_unknowns_default = False + + def GetPythonProto3PreserveUnknownsDefault(): + return _python_proto3_preserve_unknowns_default + + def SetPythonProto3PreserveUnknownsDefault(preserve): + global _python_proto3_preserve_unknowns_default + _python_proto3_preserve_unknowns_default = preserve + diff --git a/python/google/protobuf/internal/descriptor_database_test.py b/python/google/protobuf/internal/descriptor_database_test.py index 5225a458..1f1a3db9 100644 --- a/python/google/protobuf/internal/descriptor_database_test.py +++ b/python/google/protobuf/internal/descriptor_database_test.py @@ -39,6 +39,7 @@ try: except ImportError: import unittest +from google.protobuf import unittest_pb2 from google.protobuf import descriptor_pb2 from google.protobuf.internal import factory_test2_pb2 from google.protobuf import descriptor_database @@ -54,16 +55,49 @@ class DescriptorDatabaseTest(unittest.TestCase): self.assertEqual(file_desc_proto, db.FindFileByName( 'google/protobuf/internal/factory_test2.proto')) + # Can find message type. self.assertEqual(file_desc_proto, db.FindFileContainingSymbol( 'google.protobuf.python.internal.Factory2Message')) + # Can find nested message type. self.assertEqual(file_desc_proto, db.FindFileContainingSymbol( 'google.protobuf.python.internal.Factory2Message.NestedFactory2Message')) + # Can find enum type. self.assertEqual(file_desc_proto, db.FindFileContainingSymbol( 'google.protobuf.python.internal.Factory2Enum')) + # Can find nested enum type. self.assertEqual(file_desc_proto, db.FindFileContainingSymbol( 'google.protobuf.python.internal.Factory2Message.NestedFactory2Enum')) self.assertEqual(file_desc_proto, db.FindFileContainingSymbol( 'google.protobuf.python.internal.MessageWithNestedEnumOnly.NestedEnum')) + # Can find field. + self.assertEqual(file_desc_proto, db.FindFileContainingSymbol( + 'google.protobuf.python.internal.Factory2Message.list_field')) + # Can find enum value. + self.assertEqual(file_desc_proto, db.FindFileContainingSymbol( + 'google.protobuf.python.internal.Factory2Enum.FACTORY_2_VALUE_0')) + # Can find top level extension. + self.assertEqual(file_desc_proto, db.FindFileContainingSymbol( + 'google.protobuf.python.internal.another_field')) + # Can find nested extension inside a message. + self.assertEqual(file_desc_proto, db.FindFileContainingSymbol( + 'google.protobuf.python.internal.Factory2Message.one_more_field')) + + # Can find service. + file_desc_proto2 = descriptor_pb2.FileDescriptorProto.FromString( + unittest_pb2.DESCRIPTOR.serialized_pb) + db.Add(file_desc_proto2) + self.assertEqual(file_desc_proto2, db.FindFileContainingSymbol( + 'protobuf_unittest.TestService')) + + # Non-existent field under a valid top level symbol can also be + # found. The behavior is the same with protobuf C++. + self.assertEqual(file_desc_proto2, db.FindFileContainingSymbol( + 'protobuf_unittest.TestAllTypes.none_field')) + + self.assertRaises(KeyError, + db.FindFileContainingSymbol, + 'protobuf_unittest.NoneMessage') + if __name__ == '__main__': unittest.main() diff --git a/python/google/protobuf/internal/descriptor_pool_test.py b/python/google/protobuf/internal/descriptor_pool_test.py index 6015e6f8..15c857bb 100644 --- a/python/google/protobuf/internal/descriptor_pool_test.py +++ b/python/google/protobuf/internal/descriptor_pool_test.py @@ -60,26 +60,8 @@ from google.protobuf import message_factory from google.protobuf import symbol_database -class DescriptorPoolTest(unittest.TestCase): - def setUp(self): - # TODO(jieluo): Should make the pool which is created by - # serialized_pb same with generated pool. - # TODO(jieluo): More test coverage for the generated pool. - self.pool = descriptor_pool.DescriptorPool() - self.factory_test1_fd = descriptor_pb2.FileDescriptorProto.FromString( - factory_test1_pb2.DESCRIPTOR.serialized_pb) - self.factory_test2_fd = descriptor_pb2.FileDescriptorProto.FromString( - factory_test2_pb2.DESCRIPTOR.serialized_pb) - self.pool.Add(self.factory_test1_fd) - self.pool.Add(self.factory_test2_fd) - - self.pool.Add(descriptor_pb2.FileDescriptorProto.FromString( - unittest_import_public_pb2.DESCRIPTOR.serialized_pb)) - self.pool.Add(descriptor_pb2.FileDescriptorProto.FromString( - unittest_import_pb2.DESCRIPTOR.serialized_pb)) - self.pool.Add(descriptor_pb2.FileDescriptorProto.FromString( - unittest_pb2.DESCRIPTOR.serialized_pb)) +class DescriptorPoolTestBase(object): def testFindFileByName(self): name1 = 'google/protobuf/internal/factory_test1.proto' @@ -137,14 +119,6 @@ class DescriptorPoolTest(unittest.TestCase): self.assertEqual('google/protobuf/unittest.proto', file_desc5.name) - # Tests the generated pool. - assert descriptor_pool.Default().FindFileContainingSymbol( - 'google.protobuf.python.internal.Factory2Message.one_more_field') - assert descriptor_pool.Default().FindFileContainingSymbol( - 'google.protobuf.python.internal.another_field') - assert descriptor_pool.Default().FindFileContainingSymbol( - 'protobuf_unittest.TestService') - def testFindFileContainingSymbolFailure(self): with self.assertRaises(KeyError): self.pool.FindFileContainingSymbol('Does not exist') @@ -231,6 +205,27 @@ class DescriptorPoolTest(unittest.TestCase): msg2.fields_by_name[name].containing_oneof) self.assertIn(msg2.fields_by_name[name], msg2.oneofs[0].fields) + def testFindTypeErrors(self): + self.assertRaises(TypeError, self.pool.FindExtensionByNumber, '') + + # TODO(jieluo): Fix python to raise correct errors. + if api_implementation.Type() == 'cpp': + self.assertRaises(TypeError, self.pool.FindMethodByName, 0) + self.assertRaises(KeyError, self.pool.FindMethodByName, '') + error_type = TypeError + else: + error_type = AttributeError + self.assertRaises(error_type, self.pool.FindMessageTypeByName, 0) + self.assertRaises(error_type, self.pool.FindFieldByName, 0) + self.assertRaises(error_type, self.pool.FindExtensionByName, 0) + self.assertRaises(error_type, self.pool.FindEnumTypeByName, 0) + self.assertRaises(error_type, self.pool.FindOneofByName, 0) + self.assertRaises(error_type, self.pool.FindServiceByName, 0) + self.assertRaises(error_type, self.pool.FindFileContainingSymbol, 0) + if api_implementation.Type() == 'python': + error_type = KeyError + self.assertRaises(error_type, self.pool.FindFileByName, 0) + def testFindMessageTypeByNameFailure(self): with self.assertRaises(KeyError): self.pool.FindMessageTypeByName('Does not exist') @@ -270,6 +265,11 @@ class DescriptorPoolTest(unittest.TestCase): self.pool.FindEnumTypeByName('Does not exist') def testFindFieldByName(self): + if isinstance(self, SecondaryDescriptorFromDescriptorDB): + if api_implementation.Type() == 'cpp': + # TODO(jieluo): Fix cpp extension to find field correctly + # when descriptor pool is using an underlying database. + return field = self.pool.FindFieldByName( 'google.protobuf.python.internal.Factory1Message.list_value') self.assertEqual(field.name, 'list_value') @@ -279,7 +279,24 @@ class DescriptorPoolTest(unittest.TestCase): with self.assertRaises(KeyError): self.pool.FindFieldByName('Does not exist') + def testFindOneofByName(self): + if isinstance(self, SecondaryDescriptorFromDescriptorDB): + if api_implementation.Type() == 'cpp': + # TODO(jieluo): Fix cpp extension to find oneof correctly + # when descriptor pool is using an underlying database. + return + oneof = self.pool.FindOneofByName( + 'google.protobuf.python.internal.Factory2Message.oneof_field') + self.assertEqual(oneof.name, 'oneof_field') + with self.assertRaises(KeyError): + self.pool.FindOneofByName('Does not exist') + def testFindExtensionByName(self): + if isinstance(self, SecondaryDescriptorFromDescriptorDB): + if api_implementation.Type() == 'cpp': + # TODO(jieluo): Fix cpp extension to find extension correctly + # when descriptor pool is using an underlying database. + return # An extension defined in a message. extension = self.pool.FindExtensionByName( 'google.protobuf.python.internal.Factory2Message.one_more_field') @@ -352,6 +369,8 @@ class DescriptorPoolTest(unittest.TestCase): def testFindService(self): service = self.pool.FindServiceByName('protobuf_unittest.TestService') self.assertEqual(service.full_name, 'protobuf_unittest.TestService') + with self.assertRaises(KeyError): + self.pool.FindServiceByName('Does not exist') def testUserDefinedDB(self): db = descriptor_database.DescriptorDatabase() @@ -361,24 +380,17 @@ class DescriptorPoolTest(unittest.TestCase): self.testFindMessageTypeByName() def testAddSerializedFile(self): + if isinstance(self, SecondaryDescriptorFromDescriptorDB): + if api_implementation.Type() == 'cpp': + # Cpp extension cannot call Add on a DescriptorPool + # that uses a DescriptorDatabase. + # TODO(jieluo): Fix python and cpp extension diff. + return self.pool = descriptor_pool.DescriptorPool() self.pool.AddSerializedFile(self.factory_test1_fd.SerializeToString()) self.pool.AddSerializedFile(self.factory_test2_fd.SerializeToString()) self.testFindMessageTypeByName() - def testComplexNesting(self): - more_messages_desc = descriptor_pb2.FileDescriptorProto.FromString( - more_messages_pb2.DESCRIPTOR.serialized_pb) - test1_desc = descriptor_pb2.FileDescriptorProto.FromString( - descriptor_pool_test1_pb2.DESCRIPTOR.serialized_pb) - test2_desc = descriptor_pb2.FileDescriptorProto.FromString( - descriptor_pool_test2_pb2.DESCRIPTOR.serialized_pb) - self.pool.Add(more_messages_desc) - self.pool.Add(test1_desc) - self.pool.Add(test2_desc) - TEST1_FILE.CheckFile(self, self.pool) - TEST2_FILE.CheckFile(self, self.pool) - def testEnumDefaultValue(self): """Test the default value of enums which don't start at zero.""" @@ -397,6 +409,12 @@ class DescriptorPoolTest(unittest.TestCase): self.assertIs(file_descriptor, descriptor_pool_test1_pb2.DESCRIPTOR) _CheckDefaultValue(file_descriptor) + if isinstance(self, SecondaryDescriptorFromDescriptorDB): + if api_implementation.Type() == 'cpp': + # Cpp extension cannot call Add on a DescriptorPool + # that uses a DescriptorDatabase. + # TODO(jieluo): Fix python and cpp extension diff. + return # Then check the dynamic pool and its internal DescriptorDatabase. descriptor_proto = descriptor_pb2.FileDescriptorProto.FromString( descriptor_pool_test1_pb2.DESCRIPTOR.serialized_pb) @@ -444,6 +462,110 @@ class DescriptorPoolTest(unittest.TestCase): unittest_pb2.TestAllTypes.DESCRIPTOR.full_name)) _CheckDefaultValues(message_class()) + def testAddFileDescriptor(self): + if isinstance(self, SecondaryDescriptorFromDescriptorDB): + if api_implementation.Type() == 'cpp': + # Cpp extension cannot call Add on a DescriptorPool + # that uses a DescriptorDatabase. + # TODO(jieluo): Fix python and cpp extension diff. + return + file_desc = descriptor_pb2.FileDescriptorProto(name='some/file.proto') + self.pool.Add(file_desc) + self.pool.AddSerializedFile(file_desc.SerializeToString()) + + def testComplexNesting(self): + if isinstance(self, SecondaryDescriptorFromDescriptorDB): + if api_implementation.Type() == 'cpp': + # Cpp extension cannot call Add on a DescriptorPool + # that uses a DescriptorDatabase. + # TODO(jieluo): Fix python and cpp extension diff. + return + more_messages_desc = descriptor_pb2.FileDescriptorProto.FromString( + more_messages_pb2.DESCRIPTOR.serialized_pb) + test1_desc = descriptor_pb2.FileDescriptorProto.FromString( + descriptor_pool_test1_pb2.DESCRIPTOR.serialized_pb) + test2_desc = descriptor_pb2.FileDescriptorProto.FromString( + descriptor_pool_test2_pb2.DESCRIPTOR.serialized_pb) + self.pool.Add(more_messages_desc) + self.pool.Add(test1_desc) + self.pool.Add(test2_desc) + TEST1_FILE.CheckFile(self, self.pool) + TEST2_FILE.CheckFile(self, self.pool) + + +class DefaultDescriptorPoolTest(DescriptorPoolTestBase, unittest.TestCase): + + def setUp(self): + self.pool = descriptor_pool.Default() + self.factory_test1_fd = descriptor_pb2.FileDescriptorProto.FromString( + factory_test1_pb2.DESCRIPTOR.serialized_pb) + self.factory_test2_fd = descriptor_pb2.FileDescriptorProto.FromString( + factory_test2_pb2.DESCRIPTOR.serialized_pb) + + def testFindMethods(self): + self.assertIs( + self.pool.FindFileByName('google/protobuf/unittest.proto'), + unittest_pb2.DESCRIPTOR) + self.assertIs( + self.pool.FindMessageTypeByName('protobuf_unittest.TestAllTypes'), + unittest_pb2.TestAllTypes.DESCRIPTOR) + self.assertIs( + self.pool.FindFieldByName( + 'protobuf_unittest.TestAllTypes.optional_int32'), + unittest_pb2.TestAllTypes.DESCRIPTOR.fields_by_name['optional_int32']) + self.assertIs( + self.pool.FindEnumTypeByName('protobuf_unittest.ForeignEnum'), + unittest_pb2.ForeignEnum.DESCRIPTOR) + self.assertIs( + self.pool.FindExtensionByName( + 'protobuf_unittest.optional_int32_extension'), + unittest_pb2.DESCRIPTOR.extensions_by_name['optional_int32_extension']) + self.assertIs( + self.pool.FindOneofByName('protobuf_unittest.TestAllTypes.oneof_field'), + unittest_pb2.TestAllTypes.DESCRIPTOR.oneofs_by_name['oneof_field']) + self.assertIs( + self.pool.FindServiceByName('protobuf_unittest.TestService'), + unittest_pb2.DESCRIPTOR.services_by_name['TestService']) + + +class CreateDescriptorPoolTest(DescriptorPoolTestBase, unittest.TestCase): + + def setUp(self): + self.pool = descriptor_pool.DescriptorPool() + self.factory_test1_fd = descriptor_pb2.FileDescriptorProto.FromString( + factory_test1_pb2.DESCRIPTOR.serialized_pb) + self.factory_test2_fd = descriptor_pb2.FileDescriptorProto.FromString( + factory_test2_pb2.DESCRIPTOR.serialized_pb) + self.pool.Add(self.factory_test1_fd) + self.pool.Add(self.factory_test2_fd) + + self.pool.Add(descriptor_pb2.FileDescriptorProto.FromString( + unittest_import_public_pb2.DESCRIPTOR.serialized_pb)) + self.pool.Add(descriptor_pb2.FileDescriptorProto.FromString( + unittest_import_pb2.DESCRIPTOR.serialized_pb)) + self.pool.Add(descriptor_pb2.FileDescriptorProto.FromString( + unittest_pb2.DESCRIPTOR.serialized_pb)) + + +class SecondaryDescriptorFromDescriptorDB(DescriptorPoolTestBase, + unittest.TestCase): + + def setUp(self): + self.factory_test1_fd = descriptor_pb2.FileDescriptorProto.FromString( + factory_test1_pb2.DESCRIPTOR.serialized_pb) + self.factory_test2_fd = descriptor_pb2.FileDescriptorProto.FromString( + factory_test2_pb2.DESCRIPTOR.serialized_pb) + db = descriptor_database.DescriptorDatabase() + db.Add(self.factory_test1_fd) + db.Add(self.factory_test2_fd) + db.Add(descriptor_pb2.FileDescriptorProto.FromString( + unittest_import_public_pb2.DESCRIPTOR.serialized_pb)) + db.Add(descriptor_pb2.FileDescriptorProto.FromString( + unittest_import_pb2.DESCRIPTOR.serialized_pb)) + db.Add(descriptor_pb2.FileDescriptorProto.FromString( + unittest_pb2.DESCRIPTOR.serialized_pb)) + self.pool = descriptor_pool.DescriptorPool(descriptor_db=db) + class ProtoFile(object): @@ -568,6 +690,11 @@ class MessageField(object): test.assertEqual(msg_desc, field_desc.containing_type) test.assertEqual(field_type_desc, field_desc.message_type) test.assertEqual(file_desc, field_desc.file) + # TODO(jieluo): Fix python and cpp extension diff for message field + # default value. + if api_implementation.Type() == 'cpp': + test.assertRaises( + NotImplementedError, getattr, field_desc, 'default_value') class StringField(object): @@ -739,6 +866,25 @@ class AddDescriptorTest(unittest.TestCase): 'some/file.proto') self.assertEqual(pool.FindMessageTypeByName('package.Message').name, 'Message') + # Test no package + file_proto = descriptor_pb2.FileDescriptorProto( + name='some/filename/container.proto') + message_proto = file_proto.message_type.add( + name='TopMessage') + message_proto.field.add( + name='bb', + number=1, + type=descriptor_pb2.FieldDescriptorProto.TYPE_INT32, + label=descriptor_pb2.FieldDescriptorProto.LABEL_OPTIONAL) + enum_proto = file_proto.enum_type.add(name='TopEnum') + enum_proto.value.add(name='FOREIGN_FOO', number=4) + file_proto.service.add(name='TopService') + pool = descriptor_pool.DescriptorPool() + pool.Add(file_proto) + self.assertEqual('TopMessage', + pool.FindMessageTypeByName('TopMessage').name) + self.assertEqual('TopEnum', pool.FindEnumTypeByName('TopEnum').name) + self.assertEqual('TopService', pool.FindServiceByName('TopService').name) def testFileDescriptorOptionsWithCustomDescriptorPool(self): # Create a descriptor pool, and add a new FileDescriptorProto to it. @@ -757,40 +903,18 @@ class AddDescriptorTest(unittest.TestCase): # The object returned by GetOptions() is cached. self.assertIs(options, file_descriptor.GetOptions()) - -class DefaultPoolTest(unittest.TestCase): - - def testFindMethods(self): - pool = descriptor_pool.Default() - self.assertIs( - pool.FindFileByName('google/protobuf/unittest.proto'), - unittest_pb2.DESCRIPTOR) - self.assertIs( - pool.FindMessageTypeByName('protobuf_unittest.TestAllTypes'), - unittest_pb2.TestAllTypes.DESCRIPTOR) - self.assertIs( - pool.FindFieldByName('protobuf_unittest.TestAllTypes.optional_int32'), - unittest_pb2.TestAllTypes.DESCRIPTOR.fields_by_name['optional_int32']) - self.assertIs( - pool.FindEnumTypeByName('protobuf_unittest.ForeignEnum'), - unittest_pb2.ForeignEnum.DESCRIPTOR) - if api_implementation.Type() != 'cpp': - self.skipTest('Only the C++ implementation correctly indexes all types') - self.assertIs( - pool.FindExtensionByName('protobuf_unittest.optional_int32_extension'), - unittest_pb2.DESCRIPTOR.extensions_by_name['optional_int32_extension']) - self.assertIs( - pool.FindOneofByName('protobuf_unittest.TestAllTypes.oneof_field'), - unittest_pb2.TestAllTypes.DESCRIPTOR.oneofs_by_name['oneof_field']) - self.assertIs( - pool.FindServiceByName('protobuf_unittest.TestService'), - unittest_pb2.DESCRIPTOR.services_by_name['TestService']) - - def testAddFileDescriptor(self): - pool = descriptor_pool.Default() - file_desc = descriptor_pb2.FileDescriptorProto(name='some/file.proto') - pool.Add(file_desc) - pool.AddSerializedFile(file_desc.SerializeToString()) + def testAddTypeError(self): + pool = descriptor_pool.DescriptorPool() + with self.assertRaises(TypeError): + pool.AddDescriptor(0) + with self.assertRaises(TypeError): + pool.AddEnumDescriptor(0) + with self.assertRaises(TypeError): + pool.AddServiceDescriptor(0) + with self.assertRaises(TypeError): + pool.AddExtensionDescriptor(0) + with self.assertRaises(TypeError): + pool.AddFileDescriptor(0) TEST1_FILE = ProtoFile( diff --git a/python/google/protobuf/internal/descriptor_test.py b/python/google/protobuf/internal/descriptor_test.py index c0010081..f1d5934e 100755 --- a/python/google/protobuf/internal/descriptor_test.py +++ b/python/google/protobuf/internal/descriptor_test.py @@ -107,6 +107,12 @@ class DescriptorTest(unittest.TestCase): self.my_message.enum_types_by_name[ 'ForeignEnum'].values_by_number[4].name, self.my_message.EnumValueName('ForeignEnum', 4)) + with self.assertRaises(KeyError): + self.my_message.EnumValueName('ForeignEnum', 999) + with self.assertRaises(KeyError): + self.my_message.EnumValueName('NoneEnum', 999) + with self.assertRaises(TypeError): + self.my_message.EnumValueName() def testEnumFixups(self): self.assertEqual(self.my_enum, self.my_enum.values[0].type) @@ -134,15 +140,17 @@ class DescriptorTest(unittest.TestCase): def testSimpleCustomOptions(self): file_descriptor = unittest_custom_options_pb2.DESCRIPTOR - message_descriptor =\ - unittest_custom_options_pb2.TestMessageWithCustomOptions.DESCRIPTOR + message_descriptor = (unittest_custom_options_pb2. + TestMessageWithCustomOptions.DESCRIPTOR) field_descriptor = message_descriptor.fields_by_name['field1'] oneof_descriptor = message_descriptor.oneofs_by_name['AnOneof'] enum_descriptor = message_descriptor.enum_types_by_name['AnEnum'] - enum_value_descriptor =\ - message_descriptor.enum_values_by_name['ANENUM_VAL2'] - service_descriptor =\ - unittest_custom_options_pb2.TestServiceWithCustomOptions.DESCRIPTOR + enum_value_descriptor = (message_descriptor. + enum_values_by_name['ANENUM_VAL2']) + other_enum_value_descriptor = (message_descriptor. + enum_values_by_name['ANENUM_VAL1']) + service_descriptor = (unittest_custom_options_pb2. + TestServiceWithCustomOptions.DESCRIPTOR) method_descriptor = service_descriptor.FindMethodByName('Foo') file_options = file_descriptor.GetOptions() @@ -178,6 +186,11 @@ class DescriptorTest(unittest.TestCase): unittest_custom_options_pb2.DummyMessageContainingEnum.DESCRIPTOR) self.assertTrue(file_descriptor.has_options) self.assertFalse(message_descriptor.has_options) + self.assertTrue(field_descriptor.has_options) + self.assertTrue(oneof_descriptor.has_options) + self.assertTrue(enum_descriptor.has_options) + self.assertTrue(enum_value_descriptor.has_options) + self.assertFalse(other_enum_value_descriptor.has_options) def testDifferentCustomOptionTypes(self): kint32min = -2**31 @@ -400,6 +413,12 @@ class DescriptorTest(unittest.TestCase): self.assertEqual(self.my_file.name, 'some/filename/some.proto') self.assertEqual(self.my_file.package, 'protobuf_unittest') self.assertEqual(self.my_file.pool, self.pool) + self.assertFalse(self.my_file.has_options) + self.assertEqual('proto2', self.my_file.syntax) + file_proto = descriptor_pb2.FileDescriptorProto() + self.my_file.CopyToProto(file_proto) + self.assertEqual(self.my_file.serialized_pb, + file_proto.SerializeToString()) # Generated modules also belong to the default pool. self.assertEqual(unittest_pb2.DESCRIPTOR.pool, descriptor_pool.Default()) @@ -407,13 +426,31 @@ class DescriptorTest(unittest.TestCase): api_implementation.Type() != 'cpp' or api_implementation.Version() != 2, 'Immutability of descriptors is only enforced in v2 implementation') def testImmutableCppDescriptor(self): + file_descriptor = unittest_pb2.DESCRIPTOR message_descriptor = unittest_pb2.TestAllTypes.DESCRIPTOR + field_descriptor = message_descriptor.fields_by_name['optional_int32'] + enum_descriptor = message_descriptor.enum_types_by_name['NestedEnum'] + oneof_descriptor = message_descriptor.oneofs_by_name['oneof_field'] with self.assertRaises(AttributeError): message_descriptor.fields_by_name = None with self.assertRaises(TypeError): message_descriptor.fields_by_name['Another'] = None with self.assertRaises(TypeError): message_descriptor.fields.append(None) + with self.assertRaises(AttributeError): + field_descriptor.containing_type = message_descriptor + with self.assertRaises(AttributeError): + file_descriptor.has_options = False + with self.assertRaises(AttributeError): + field_descriptor.has_options = False + with self.assertRaises(AttributeError): + oneof_descriptor.has_options = False + with self.assertRaises(AttributeError): + enum_descriptor.has_options = False + with self.assertRaises(AttributeError) as e: + message_descriptor.has_options = True + self.assertEqual('attribute is not writable: has_options', + str(e.exception)) class NewDescriptorTest(DescriptorTest): @@ -442,6 +479,12 @@ class GeneratedDescriptorTest(unittest.TestCase): self.CheckDescriptorMapping(message_descriptor.fields_by_name) self.CheckDescriptorMapping(message_descriptor.fields_by_number) self.CheckDescriptorMapping(message_descriptor.fields_by_camelcase_name) + self.CheckDescriptorMapping(message_descriptor.enum_types_by_name) + self.CheckDescriptorMapping(message_descriptor.enum_values_by_name) + self.CheckDescriptorMapping(message_descriptor.oneofs_by_name) + self.CheckDescriptorMapping(message_descriptor.enum_types[0].values_by_name) + # Test extension range + self.assertEqual(message_descriptor.extension_ranges, []) def CheckFieldDescriptor(self, field_descriptor): # Basic properties @@ -450,6 +493,7 @@ class GeneratedDescriptorTest(unittest.TestCase): self.assertEqual(field_descriptor.full_name, 'protobuf_unittest.TestAllTypes.optional_int32') self.assertEqual(field_descriptor.containing_type.name, 'TestAllTypes') + self.assertEqual(field_descriptor.file, unittest_pb2.DESCRIPTOR) # Test equality and hashability self.assertEqual(field_descriptor, field_descriptor) self.assertEqual( @@ -461,32 +505,73 @@ class GeneratedDescriptorTest(unittest.TestCase): field_descriptor) self.assertIn(field_descriptor, [field_descriptor]) self.assertIn(field_descriptor, {field_descriptor: None}) + self.assertEqual(None, field_descriptor.extension_scope) + self.assertEqual(None, field_descriptor.enum_type) + if api_implementation.Type() == 'cpp': + # For test coverage only + self.assertEqual(field_descriptor.id, field_descriptor.id) def CheckDescriptorSequence(self, sequence): # Verifies that a property like 'messageDescriptor.fields' has all the # properties of an immutable abc.Sequence. + self.assertNotEqual(sequence, + unittest_pb2.TestAllExtensions.DESCRIPTOR.fields) + self.assertNotEqual(sequence, []) + self.assertNotEqual(sequence, 1) + self.assertFalse(sequence == 1) # Only for cpp test coverage + self.assertEqual(sequence, sequence) + expected_list = list(sequence) + self.assertEqual(expected_list, sequence) self.assertGreater(len(sequence), 0) # Sized - self.assertEqual(len(sequence), len(list(sequence))) # Iterable + self.assertEqual(len(sequence), len(expected_list)) # Iterable + self.assertEqual(sequence[len(sequence) -1], sequence[-1]) item = sequence[0] self.assertEqual(item, sequence[0]) self.assertIn(item, sequence) # Container self.assertEqual(sequence.index(item), 0) self.assertEqual(sequence.count(item), 1) + other_item = unittest_pb2.NestedTestAllTypes.DESCRIPTOR.fields[0] + self.assertNotIn(other_item, sequence) + self.assertEqual(sequence.count(other_item), 0) + self.assertRaises(ValueError, sequence.index, other_item) + self.assertRaises(ValueError, sequence.index, []) reversed_iterator = reversed(sequence) self.assertEqual(list(reversed_iterator), list(sequence)[::-1]) self.assertRaises(StopIteration, next, reversed_iterator) + expected_list[0] = 'change value' + self.assertNotEqual(expected_list, sequence) + # TODO(jieluo): Change __repr__ support for DescriptorSequence. + if api_implementation.Type() == 'python': + self.assertEqual(str(list(sequence)), str(sequence)) + else: + self.assertEqual(str(sequence)[0], '<') def CheckDescriptorMapping(self, mapping): # Verifies that a property like 'messageDescriptor.fields' has all the # properties of an immutable abc.Mapping. + self.assertNotEqual( + mapping, unittest_pb2.TestAllExtensions.DESCRIPTOR.fields_by_name) + self.assertNotEqual(mapping, {}) + self.assertNotEqual(mapping, 1) + self.assertFalse(mapping == 1) # Only for cpp test coverage + excepted_dict = dict(mapping.items()) + self.assertEqual(mapping, excepted_dict) + self.assertEqual(mapping, mapping) self.assertGreater(len(mapping), 0) # Sized - self.assertEqual(len(mapping), len(list(mapping))) # Iterable + self.assertEqual(len(mapping), len(excepted_dict)) # Iterable if sys.version_info >= (3,): key, item = next(iter(mapping.items())) else: key, item = mapping.items()[0] self.assertIn(key, mapping) # Container self.assertEqual(mapping.get(key), item) + with self.assertRaises(TypeError): + mapping.get() + # TODO(jieluo): Fix python and cpp extension diff. + if api_implementation.Type() == 'python': + self.assertRaises(TypeError, mapping.get, []) + else: + self.assertEqual(None, mapping.get([])) # keys(), iterkeys() &co item = (next(iter(mapping.keys())), next(iter(mapping.values()))) self.assertEqual(item, next(iter(mapping.items()))) @@ -497,6 +582,18 @@ class GeneratedDescriptorTest(unittest.TestCase): CheckItems(mapping.keys(), mapping.iterkeys()) CheckItems(mapping.values(), mapping.itervalues()) CheckItems(mapping.items(), mapping.iteritems()) + excepted_dict[key] = 'change value' + self.assertNotEqual(mapping, excepted_dict) + del excepted_dict[key] + excepted_dict['new_key'] = 'new' + self.assertNotEqual(mapping, excepted_dict) + self.assertRaises(KeyError, mapping.__getitem__, 'key_error') + self.assertRaises(KeyError, mapping.__getitem__, len(mapping) + 1) + # TODO(jieluo): Add __repr__ support for DescriptorMapping. + if api_implementation.Type() == 'python': + self.assertEqual(len(str(dict(mapping.items()))), len(str(mapping))) + else: + self.assertEqual(str(mapping)[0], '<') def testDescriptor(self): message_descriptor = unittest_pb2.TestAllTypes.DESCRIPTOR @@ -506,13 +603,26 @@ class GeneratedDescriptorTest(unittest.TestCase): field_descriptor = message_descriptor.fields_by_camelcase_name[ 'optionalInt32'] self.CheckFieldDescriptor(field_descriptor) + enum_descriptor = unittest_pb2.DESCRIPTOR.enum_types_by_name[ + 'ForeignEnum'] + self.assertEqual(None, enum_descriptor.containing_type) + # Test extension range + self.assertEqual( + unittest_pb2.TestAllExtensions.DESCRIPTOR.extension_ranges, + [(1, 536870912)]) + self.assertEqual( + unittest_pb2.TestMultipleExtensionRanges.DESCRIPTOR.extension_ranges, + [(42, 43), (4143, 4244), (65536, 536870912)]) def testCppDescriptorContainer(self): - # Check that the collection is still valid even if the parent disappeared. - enum = unittest_pb2.TestAllTypes.DESCRIPTOR.enum_types_by_name['NestedEnum'] - values = enum.values - del enum - self.assertEqual('FOO', values[0].name) + containing_file = unittest_pb2.DESCRIPTOR + self.CheckDescriptorSequence(containing_file.dependencies) + self.CheckDescriptorMapping(containing_file.message_types_by_name) + self.CheckDescriptorMapping(containing_file.enum_types_by_name) + self.CheckDescriptorMapping(containing_file.services_by_name) + self.CheckDescriptorMapping(containing_file.extensions_by_name) + self.CheckDescriptorMapping( + unittest_pb2.TestNestedExtension.DESCRIPTOR.extensions_by_name) def testCppDescriptorContainer_Iterator(self): # Same test with the iterator @@ -526,6 +636,18 @@ class GeneratedDescriptorTest(unittest.TestCase): self.assertEqual(service_descriptor.name, 'TestService') self.assertEqual(service_descriptor.methods[0].name, 'Foo') self.assertIs(service_descriptor.file, unittest_pb2.DESCRIPTOR) + self.assertEqual(service_descriptor.index, 0) + self.CheckDescriptorMapping(service_descriptor.methods_by_name) + + def testOneofDescriptor(self): + message_descriptor = unittest_pb2.TestAllTypes.DESCRIPTOR + oneof_descriptor = message_descriptor.oneofs_by_name['oneof_field'] + self.assertFalse(oneof_descriptor.has_options) + self.assertEqual(message_descriptor, oneof_descriptor.containing_type) + self.assertEqual('oneof_field', oneof_descriptor.name) + self.assertEqual('protobuf_unittest.TestAllTypes.oneof_field', + oneof_descriptor.full_name) + self.assertEqual(0, oneof_descriptor.index) class DescriptorCopyToProtoTest(unittest.TestCase): @@ -663,49 +785,64 @@ class DescriptorCopyToProtoTest(unittest.TestCase): descriptor_pb2.DescriptorProto, TEST_MESSAGE_WITH_SEVERAL_EXTENSIONS_ASCII) - # Disable this test so we can make changes to the proto file. - # TODO(xiaofeng): Enable this test after cl/55530659 is submitted. - # - # def testCopyToProto_FileDescriptor(self): - # UNITTEST_IMPORT_FILE_DESCRIPTOR_ASCII = (""" - # name: 'google/protobuf/unittest_import.proto' - # package: 'protobuf_unittest_import' - # dependency: 'google/protobuf/unittest_import_public.proto' - # message_type: < - # name: 'ImportMessage' - # field: < - # name: 'd' - # number: 1 - # label: 1 # Optional - # type: 5 # TYPE_INT32 - # > - # > - # """ + - # """enum_type: < - # name: 'ImportEnum' - # value: < - # name: 'IMPORT_FOO' - # number: 7 - # > - # value: < - # name: 'IMPORT_BAR' - # number: 8 - # > - # value: < - # name: 'IMPORT_BAZ' - # number: 9 - # > - # > - # options: < - # java_package: 'com.google.protobuf.test' - # optimize_for: 1 # SPEED - # > - # public_dependency: 0 - # """) - # self._InternalTestCopyToProto( - # unittest_import_pb2.DESCRIPTOR, - # descriptor_pb2.FileDescriptorProto, - # UNITTEST_IMPORT_FILE_DESCRIPTOR_ASCII) + def testCopyToProto_FileDescriptor(self): + UNITTEST_IMPORT_FILE_DESCRIPTOR_ASCII = (""" + name: 'google/protobuf/unittest_import.proto' + package: 'protobuf_unittest_import' + dependency: 'google/protobuf/unittest_import_public.proto' + message_type: < + name: 'ImportMessage' + field: < + name: 'd' + number: 1 + label: 1 # Optional + type: 5 # TYPE_INT32 + > + > + """ + + """enum_type: < + name: 'ImportEnum' + value: < + name: 'IMPORT_FOO' + number: 7 + > + value: < + name: 'IMPORT_BAR' + number: 8 + > + value: < + name: 'IMPORT_BAZ' + number: 9 + > + > + enum_type: < + name: 'ImportEnumForMap' + value: < + name: 'UNKNOWN' + number: 0 + > + value: < + name: 'FOO' + number: 1 + > + value: < + name: 'BAR' + number: 2 + > + > + options: < + java_package: 'com.google.protobuf.test' + optimize_for: 1 # SPEED + """ + + """ + cc_enable_arenas: true + > + public_dependency: 0 + """) + self._InternalTestCopyToProto( + unittest_import_pb2.DESCRIPTOR, + descriptor_pb2.FileDescriptorProto, + UNITTEST_IMPORT_FILE_DESCRIPTOR_ASCII) def testCopyToProto_ServiceDescriptor(self): TEST_SERVICE_ASCII = """ @@ -721,12 +858,47 @@ class DescriptorCopyToProtoTest(unittest.TestCase): output_type: '.protobuf_unittest.BarResponse' > """ - # TODO(rocking): enable this test after the proto descriptor change is - # checked in. - #self._InternalTestCopyToProto( - # unittest_pb2.TestService.DESCRIPTOR, - # descriptor_pb2.ServiceDescriptorProto, - # TEST_SERVICE_ASCII) + self._InternalTestCopyToProto( + unittest_pb2.TestService.DESCRIPTOR, + descriptor_pb2.ServiceDescriptorProto, + TEST_SERVICE_ASCII) + + @unittest.skipIf( + api_implementation.Type() == 'python', + 'It is not implemented in python.') + # TODO(jieluo): Add support for pure python or remove in c extension. + def testCopyToProto_MethodDescriptor(self): + expected_ascii = """ + name: 'Foo' + input_type: '.protobuf_unittest.FooRequest' + output_type: '.protobuf_unittest.FooResponse' + """ + method_descriptor = unittest_pb2.TestService.DESCRIPTOR.FindMethodByName( + 'Foo') + self._InternalTestCopyToProto( + method_descriptor, + descriptor_pb2.MethodDescriptorProto, + expected_ascii) + + @unittest.skipIf( + api_implementation.Type() == 'python', + 'Pure python does not raise error.') + # TODO(jieluo): Fix pure python to check with the proto type. + def testCopyToProto_TypeError(self): + file_proto = descriptor_pb2.FileDescriptorProto() + self.assertRaises(TypeError, + unittest_pb2.TestEmptyMessage.DESCRIPTOR.CopyToProto, + file_proto) + self.assertRaises(TypeError, + unittest_pb2.ForeignEnum.DESCRIPTOR.CopyToProto, + file_proto) + self.assertRaises(TypeError, + unittest_pb2.TestService.DESCRIPTOR.CopyToProto, + file_proto) + proto = descriptor_pb2.DescriptorProto() + self.assertRaises(TypeError, + unittest_import_pb2.DESCRIPTOR.CopyToProto, + proto) class MakeDescriptorTest(unittest.TestCase): @@ -774,6 +946,9 @@ class MakeDescriptorTest(unittest.TestCase): result.nested_types[0].enum_types[0]) self.assertFalse(result.has_options) self.assertFalse(result.fields[0].has_options) + if api_implementation.Type() == 'cpp': + with self.assertRaises(AttributeError): + result.fields[0].has_options = False def testMakeDescriptorWithUnsignedIntField(self): file_descriptor_proto = descriptor_pb2.FileDescriptorProto() diff --git a/python/google/protobuf/internal/encoder.py b/python/google/protobuf/internal/encoder.py index ebec42e5..8c6a1189 100755 --- a/python/google/protobuf/internal/encoder.py +++ b/python/google/protobuf/internal/encoder.py @@ -819,7 +819,7 @@ def MapEncoder(field_descriptor): encode_message = MessageEncoder(field_descriptor.number, False, False) def EncodeField(write, value, deterministic): - value_keys = sorted(value.keys()) if deterministic else value.keys() + value_keys = sorted(value.keys()) if deterministic else value for key in value_keys: entry_msg = message_type._concrete_class(key=key, value=value[key]) encode_message(write, entry_msg, deterministic) diff --git a/python/google/protobuf/internal/json_format_test.py b/python/google/protobuf/internal/json_format_test.py index 077b64db..b2cf7622 100644 --- a/python/google/protobuf/internal/json_format_test.py +++ b/python/google/protobuf/internal/json_format_test.py @@ -159,6 +159,16 @@ class JsonFormatTest(JsonFormatBase): json_format.Parse(text, parsed_message) self.assertEqual(message, parsed_message) + def testUnknownEnumToJsonError(self): + message = json_format_proto3_pb2.TestMessage() + message.enum_value = 999 + # TODO(jieluo): should accept numeric unknown enum for proto3. + with self.assertRaises(json_format.SerializeToJsonError) as e: + json_format.MessageToJson(message) + self.assertEqual(str(e.exception), + 'Enum field contains an integer value which can ' + 'not mapped to an enum value.') + def testExtensionToJsonAndBack(self): message = unittest_mset_pb2.TestMessageSetContainer() ext1 = unittest_mset_pb2.TestMessageSetExtension1.message_set_extension @@ -172,6 +182,10 @@ class JsonFormatTest(JsonFormatBase): json_format.Parse(message_text, parsed_message) self.assertEqual(message, parsed_message) + def testExtensionErrors(self): + self.CheckError('{"[extensionField]": {}}', + 'Message type proto3.TestMessage does not have extensions') + def testExtensionToDictAndBack(self): message = unittest_mset_pb2.TestMessageSetContainer() ext1 = unittest_mset_pb2.TestMessageSetExtension1.message_set_extension @@ -294,7 +308,18 @@ class JsonFormatTest(JsonFormatBase): self.assertEqual(message.int32_value, 1) def testMapFields(self): - message = json_format_proto3_pb2.TestMap() + message = json_format_proto3_pb2.TestNestedMap() + self.assertEqual( + json.loads(json_format.MessageToJson(message, True)), + json.loads('{' + '"boolMap": {},' + '"int32Map": {},' + '"int64Map": {},' + '"uint32Map": {},' + '"uint64Map": {},' + '"stringMap": {},' + '"mapMap": {}' + '}')) message.bool_map[True] = 1 message.bool_map[False] = 2 message.int32_map[1] = 2 @@ -307,17 +332,19 @@ class JsonFormatTest(JsonFormatBase): message.uint64_map[2] = 3 message.string_map['1'] = 2 message.string_map['null'] = 3 + message.map_map['1'].bool_map[True] = 3 self.assertEqual( - json.loads(json_format.MessageToJson(message, True)), + json.loads(json_format.MessageToJson(message, False)), json.loads('{' '"boolMap": {"false": 2, "true": 1},' '"int32Map": {"1": 2, "2": 3},' '"int64Map": {"1": 2, "2": 3},' '"uint32Map": {"1": 2, "2": 3},' '"uint64Map": {"1": 2, "2": 3},' - '"stringMap": {"1": 2, "null": 3}' + '"stringMap": {"1": 2, "null": 3},' + '"mapMap": {"1": {"boolMap": {"true": 3}}}' '}')) - parsed_message = json_format_proto3_pb2.TestMap() + parsed_message = json_format_proto3_pb2.TestNestedMap() self.CheckParseBack(message, parsed_message) def testOneofFields(self): @@ -703,6 +730,9 @@ class JsonFormatTest(JsonFormatBase): json_format.Parse, '{"repeatedInt32Value":[1, null]}', parsed_message) + self.CheckError('{"repeatedMessageValue":[null]}', + 'Failed to parse repeatedMessageValue field: null is not' + ' allowed to be used as an element in a repeated field.') def testNanFloat(self): message = json_format_proto3_pb2.TestMessage() @@ -727,6 +757,11 @@ class JsonFormatTest(JsonFormatBase): '{"enumValue": "baz"}', 'Failed to parse enumValue field: Invalid enum value baz ' 'for enum type proto3.EnumType.') + # TODO(jieluo): fix json format to accept numeric unknown enum for proto3. + self.CheckError( + '{"enumValue": 12345}', + 'Failed to parse enumValue field: Invalid enum value 12345 ' + 'for enum type proto3.EnumType.') def testParseBadIdentifer(self): self.CheckError('{int32Value: 1}', @@ -799,6 +834,11 @@ class JsonFormatTest(JsonFormatBase): self.CheckError('{"bytesValue": "AQI*"}', 'Failed to parse bytesValue field: Incorrect padding.') + def testInvalidRepeated(self): + self.CheckError('{"repeatedInt32Value": 12345}', + (r'Failed to parse repeatedInt32Value field: repeated field' + r' repeatedInt32Value must be in \[\] which is 12345.')) + def testInvalidMap(self): message = json_format_proto3_pb2.TestMap() text = '{"int32Map": {"null": 2, "2": 3}}' @@ -824,6 +864,12 @@ class JsonFormatTest(JsonFormatBase): json_format.ParseError, 'Failed to load JSON: duplicate key a', json_format.Parse, text, message) + text = r'{"stringMap": 0}' + self.assertRaisesRegexp( + json_format.ParseError, + 'Failed to parse stringMap field: Map field string_map must be ' + 'in a dict which is 0.', + json_format.Parse, text, message) def testInvalidTimestamp(self): message = json_format_proto3_pb2.TestTimestamp() @@ -911,6 +957,12 @@ class JsonFormatTest(JsonFormatBase): json_format.MessageToJson(message)) self.assertEqual('{\n "int32_value": 12345\n}', json_format.MessageToJson(message, False, True)) + # When including_default_value_fields is True. + message = json_format_proto3_pb2.TestTimestamp() + self.assertEqual('{\n "repeatedValue": []\n}', + json_format.MessageToJson(message, True, False)) + self.assertEqual('{\n "repeated_value": []\n}', + json_format.MessageToJson(message, True, True)) # Parsers accept both original proto field names and lowerCamelCase names. message = json_format_proto3_pb2.TestMessage() diff --git a/python/google/protobuf/internal/message_factory_test.py b/python/google/protobuf/internal/message_factory_test.py index 4caa2443..a1b6bb81 100644 --- a/python/google/protobuf/internal/message_factory_test.py +++ b/python/google/protobuf/internal/message_factory_test.py @@ -40,6 +40,7 @@ except ImportError: import unittest from google.protobuf import descriptor_pb2 +from google.protobuf.internal import api_implementation from google.protobuf.internal import factory_test1_pb2 from google.protobuf.internal import factory_test2_pb2 from google.protobuf import descriptor_database @@ -130,6 +131,21 @@ class MessageFactoryTest(unittest.TestCase): msg1.Extensions[ext2] = 'test2' self.assertEqual('test1', msg1.Extensions[ext1]) self.assertEqual('test2', msg1.Extensions[ext2]) + self.assertEqual(None, + msg1.Extensions._FindExtensionByNumber(12321)) + if api_implementation.Type() == 'cpp': + # TODO(jieluo): Fix len to return the correct value. + # self.assertEqual(2, len(msg1.Extensions)) + self.assertEqual(len(msg1.Extensions), len(msg1.Extensions)) + self.assertRaises(TypeError, + msg1.Extensions._FindExtensionByName, 0) + self.assertRaises(TypeError, + msg1.Extensions._FindExtensionByNumber, '') + else: + self.assertEqual(None, + msg1.Extensions._FindExtensionByName(0)) + self.assertEqual(None, + msg1.Extensions._FindExtensionByNumber('')) def testDuplicateExtensionNumber(self): pool = descriptor_pool.DescriptorPool() @@ -183,7 +199,14 @@ class MessageFactoryTest(unittest.TestCase): with self.assertRaises(Exception) as cm: factory.GetMessages([f.name]) - self.assertIsInstance(cm.exception, (AssertionError, ValueError)) + self.assertIn(str(cm.exception), + ['Extensions ' + '"google.protobuf.python.internal.Duplicate.extension_field" and' + ' "google.protobuf.python.internal.Extension.extension_field"' + ' both try to extend message type' + ' "google.protobuf.python.internal.Container"' + ' with field number 2.', + 'Double registration of Extensions']) if __name__ == '__main__': diff --git a/python/google/protobuf/internal/message_test.py b/python/google/protobuf/internal/message_test.py index 29a515b2..4622f10f 100755 --- a/python/google/protobuf/internal/message_test.py +++ b/python/google/protobuf/internal/message_test.py @@ -53,10 +53,15 @@ import six import sys try: - import unittest2 as unittest #PY26 + import unittest2 as unittest # PY26 except ImportError: import unittest +try: + cmp # Python 2 +except NameError: + cmp = lambda x, y: (x > y) - (x < y) # Python 3 +from google.protobuf import map_proto2_unittest_pb2 from google.protobuf import map_unittest_pb2 from google.protobuf import unittest_pb2 from google.protobuf import unittest_proto3_arena_pb2 @@ -65,6 +70,7 @@ from google.protobuf import descriptor_pool from google.protobuf import message_factory from google.protobuf import text_format from google.protobuf.internal import api_implementation +from google.protobuf.internal import encoder from google.protobuf.internal import packed_field_test_pb2 from google.protobuf.internal import test_util from google.protobuf.internal import testing_refleaks @@ -136,6 +142,18 @@ class MessageTest(BaseTestCase): golden_copy = copy.deepcopy(golden_message) self.assertEqual(golden_data, golden_copy.SerializeToString()) + def testParseErrors(self, message_module): + msg = message_module.TestAllTypes() + self.assertRaises(TypeError, msg.FromString, 0) + self.assertRaises(Exception, msg.FromString, '0') + # TODO(jieluo): Fix cpp extension to check unexpected end-group tag. + # b/27494216 + if api_implementation.Type() == 'python': + end_tag = encoder.TagBytes(1, 4) + with self.assertRaises(message.DecodeError) as context: + msg.FromString(end_tag) + self.assertEqual('Unexpected end-group tag.', str(context.exception)) + def testDeterminismParameters(self, message_module): # This message is always deterministically serialized, even if determinism # is disabled, so we can use it to verify that all the determinism @@ -605,6 +623,13 @@ class MessageTest(BaseTestCase): self.assertIsInstance(m.repeated_nested_message, collections.MutableSequence) + def testRepeatedFieldsNotHashable(self, message_module): + m = message_module.TestAllTypes() + with self.assertRaises(TypeError): + hash(m.repeated_int32) + with self.assertRaises(TypeError): + hash(m.repeated_nested_message) + def testRepeatedFieldInsideNestedMessage(self, message_module): m = message_module.NestedTestAllTypes() m.payload.repeated_int32.extend([]) @@ -622,6 +647,7 @@ class MessageTest(BaseTestCase): def testOneofGetCaseNonexistingField(self, message_module): m = message_module.TestAllTypes() self.assertRaises(ValueError, m.WhichOneof, 'no_such_oneof_field') + self.assertRaises(Exception, m.WhichOneof, 0) def testOneofDefaultValues(self, message_module): m = message_module.TestAllTypes() @@ -997,6 +1023,8 @@ class MessageTest(BaseTestCase): m = message_module.TestAllTypes() with self.assertRaises(IndexError) as _: m.repeated_nested_message.pop() + with self.assertRaises(TypeError) as _: + m.repeated_nested_message.pop('0') for i in range(5): n = m.repeated_nested_message.add() n.bb = i @@ -1005,6 +1033,39 @@ class MessageTest(BaseTestCase): self.assertEqual(2, m.repeated_nested_message.pop(1).bb) self.assertEqual([1, 3], [n.bb for n in m.repeated_nested_message]) + def testRepeatedCompareWithSelf(self, message_module): + m = message_module.TestAllTypes() + for i in range(5): + m.repeated_int32.insert(i, i) + n = m.repeated_nested_message.add() + n.bb = i + self.assertSequenceEqual(m.repeated_int32, m.repeated_int32) + self.assertEqual(m.repeated_nested_message, m.repeated_nested_message) + + def testReleasedNestedMessages(self, message_module): + """A case that lead to a segfault when a message detached from its parent + container has itself a child container. + """ + m = message_module.NestedTestAllTypes() + m = m.repeated_child.add() + m = m.child + m = m.repeated_child.add() + self.assertEqual(m.payload.optional_int32, 0) + + def testSetRepeatedComposite(self, message_module): + m = message_module.TestAllTypes() + with self.assertRaises(AttributeError): + m.repeated_int32 = [] + m.repeated_int32.append(1) + if api_implementation.Type() == 'cpp': + # For test coverage: cpp has a different path if composite + # field is in cache + with self.assertRaises(TypeError): + m.repeated_int32 = [] + else: + with self.assertRaises(AttributeError): + m.repeated_int32 = [] + # Class to test proto2-only features (required, extensions, etc.) class Proto2Test(BaseTestCase): @@ -1057,18 +1118,46 @@ class Proto2Test(BaseTestCase): self.assertEqual(False, message.optional_bool) self.assertEqual(0, message.optional_nested_message.bb) - # TODO(tibell): The C++ implementations actually allows assignment - # of unknown enum values to *scalar* fields (but not repeated - # fields). Once checked enum fields becomes the default in the - # Python implementation, the C++ implementation should follow suit. def testAssignInvalidEnum(self): - """It should not be possible to assign an invalid enum number to an - enum field.""" + """Assigning an invalid enum number is not allowed in proto2.""" m = unittest_pb2.TestAllTypes() + # Proto2 can not assign unknown enum. with self.assertRaises(ValueError) as _: m.optional_nested_enum = 1234567 self.assertRaises(ValueError, m.repeated_nested_enum.append, 1234567) + # Assignment is a different code path than append for the C++ impl. + m.repeated_nested_enum.append(2) + m.repeated_nested_enum[0] = 2 + with self.assertRaises(ValueError): + m.repeated_nested_enum[0] = 123456 + + # Unknown enum value can be parsed but is ignored. + m2 = unittest_proto3_arena_pb2.TestAllTypes() + m2.optional_nested_enum = 1234567 + m2.repeated_nested_enum.append(7654321) + serialized = m2.SerializeToString() + + m3 = unittest_pb2.TestAllTypes() + m3.ParseFromString(serialized) + self.assertFalse(m3.HasField('optional_nested_enum')) + # 1 is the default value for optional_nested_enum. + self.assertEqual(1, m3.optional_nested_enum) + self.assertEqual(0, len(m3.repeated_nested_enum)) + m2.Clear() + m2.ParseFromString(m3.SerializeToString()) + self.assertEqual(1234567, m2.optional_nested_enum) + self.assertEqual(7654321, m2.repeated_nested_enum[0]) + + def testUnknownEnumMap(self): + m = map_proto2_unittest_pb2.TestEnumMap() + m.known_map_field[123] = 0 + with self.assertRaises(ValueError): + m.unknown_map_field[1] = 123 + + def testExtensionsErrors(self): + msg = unittest_pb2.TestAllTypes() + self.assertRaises(AttributeError, getattr, msg, 'Extensions') def testGoldenExtensions(self): golden_data = test_util.GoldenFileData('golden_message') @@ -1293,6 +1382,7 @@ class Proto3Test(BaseTestCase): """Assigning an unknown enum value is allowed and preserves the value.""" m = unittest_proto3_arena_pb2.TestAllTypes() + # Proto3 can assign unknown enums. m.optional_nested_enum = 1234567 self.assertEqual(1234567, m.optional_nested_enum) m.repeated_nested_enum.append(22334455) @@ -1307,18 +1397,10 @@ class Proto3Test(BaseTestCase): self.assertEqual(1234567, m2.optional_nested_enum) self.assertEqual(7654321, m2.repeated_nested_enum[0]) - # ParseFromString in Proto2 should accept unknown enums too. - m3 = unittest_pb2.TestAllTypes() - m3.ParseFromString(serialized) - m2.Clear() - m2.ParseFromString(m3.SerializeToString()) - self.assertEqual(1234567, m2.optional_nested_enum) - self.assertEqual(7654321, m2.repeated_nested_enum[0]) - # Map isn't really a proto3-only feature. But there is no proto2 equivalent # of google/protobuf/map_unittest.proto right now, so it's not easy to # test both with the same test like we do for the other proto2/proto3 tests. - # (google/protobuf/map_protobuf_unittest.proto is very different in the set + # (google/protobuf/map_proto2_unittest.proto is very different in the set # of messages and fields it contains). def testScalarMapDefaults(self): msg = map_unittest_pb2.TestMap() @@ -1379,12 +1461,21 @@ class Proto3Test(BaseTestCase): msg.map_int32_int32[5] = 15 self.assertEqual(15, msg.map_int32_int32.get(5)) + self.assertEqual(15, msg.map_int32_int32.get(5)) + with self.assertRaises(TypeError): + msg.map_int32_int32.get('') self.assertIsNone(msg.map_int32_foreign_message.get(5)) self.assertEqual(10, msg.map_int32_foreign_message.get(5, 10)) submsg = msg.map_int32_foreign_message[5] self.assertIs(submsg, msg.map_int32_foreign_message.get(5)) + # TODO(jieluo): Fix python and cpp extension diff. + if api_implementation.Type() == 'cpp': + with self.assertRaises(TypeError): + msg.map_int32_foreign_message.get('') + else: + self.assertEqual(None, msg.map_int32_foreign_message.get('')) def testScalarMap(self): msg = map_unittest_pb2.TestMap() @@ -1396,8 +1487,13 @@ class Proto3Test(BaseTestCase): msg.map_int64_int64[-2**33] = -2**34 msg.map_uint32_uint32[123] = 456 msg.map_uint64_uint64[2**33] = 2**34 + msg.map_int32_float[2] = 1.2 + msg.map_int32_double[1] = 3.3 msg.map_string_string['abc'] = '123' + msg.map_bool_bool[True] = True msg.map_int32_enum[888] = 2 + # Unknown numeric enum is supported in proto3. + msg.map_int32_enum[123] = 456 self.assertEqual([], msg.FindInitializationErrors()) @@ -1431,8 +1527,24 @@ class Proto3Test(BaseTestCase): self.assertEqual(-2**34, msg2.map_int64_int64[-2**33]) self.assertEqual(456, msg2.map_uint32_uint32[123]) self.assertEqual(2**34, msg2.map_uint64_uint64[2**33]) + self.assertAlmostEqual(1.2, msg.map_int32_float[2]) + self.assertEqual(3.3, msg.map_int32_double[1]) self.assertEqual('123', msg2.map_string_string['abc']) + self.assertEqual(True, msg2.map_bool_bool[True]) self.assertEqual(2, msg2.map_int32_enum[888]) + self.assertEqual(456, msg2.map_int32_enum[123]) + # TODO(jieluo): Add cpp extension support. + if api_implementation.Type() == 'python': + self.assertEqual('{-123: -456}', + str(msg2.map_int32_int32)) + + def testMapEntryAlwaysSerialized(self): + msg = map_unittest_pb2.TestMap() + msg.map_int32_int32[0] = 0 + msg.map_string_string[''] = '' + self.assertEqual(msg.ByteSize(), 12) + self.assertEqual(b'\n\x04\x08\x00\x10\x00r\x04\n\x00\x12\x00', + msg.SerializeToString()) def testStringUnicodeConversionInMap(self): msg = map_unittest_pb2.TestMap() @@ -1485,6 +1597,11 @@ class Proto3Test(BaseTestCase): self.assertIn(123, msg2.map_int32_foreign_message) self.assertIn(-456, msg2.map_int32_foreign_message) self.assertEqual(2, len(msg2.map_int32_foreign_message)) + # TODO(jieluo): Fix text format for message map. + # TODO(jieluo): Add cpp extension support. + if api_implementation.Type() == 'python': + self.assertEqual(15, + len(str(msg2.map_int32_foreign_message))) def testNestedMessageMapItemDelete(self): msg = map_unittest_pb2.TestMap() @@ -1568,6 +1685,12 @@ class Proto3Test(BaseTestCase): del msg2.map_int32_foreign_message[222] self.assertFalse(222 in msg2.map_int32_foreign_message) + if api_implementation.Type() == 'cpp': + with self.assertRaises(TypeError): + del msg2.map_int32_foreign_message[''] + else: + with self.assertRaises(KeyError): + del msg2.map_int32_foreign_message[''] def testMergeFromBadType(self): msg = map_unittest_pb2.TestMap() @@ -1702,6 +1825,54 @@ class Proto3Test(BaseTestCase): matching_dict = {2: 4, 3: 6, 4: 8} self.assertMapIterEquals(msg.map_int32_int32.items(), matching_dict) + def testPython2Map(self): + if sys.version_info < (3,): + msg = map_unittest_pb2.TestMap() + msg.map_int32_int32[2] = 4 + msg.map_int32_int32[3] = 6 + msg.map_int32_int32[4] = 8 + msg.map_int32_int32[5] = 10 + map_int32 = msg.map_int32_int32 + self.assertEqual(4, len(map_int32)) + msg2 = map_unittest_pb2.TestMap() + msg2.ParseFromString(msg.SerializeToString()) + + def CheckItems(seq, iterator): + self.assertEqual(next(iterator), seq[0]) + self.assertEqual(list(iterator), seq[1:]) + + CheckItems(map_int32.items(), map_int32.iteritems()) + CheckItems(map_int32.keys(), map_int32.iterkeys()) + CheckItems(map_int32.values(), map_int32.itervalues()) + + self.assertEqual(6, map_int32.get(3)) + self.assertEqual(None, map_int32.get(999)) + self.assertEqual(6, map_int32.pop(3)) + self.assertEqual(0, map_int32.pop(3)) + self.assertEqual(3, len(map_int32)) + key, value = map_int32.popitem() + self.assertEqual(2 * key, value) + self.assertEqual(2, len(map_int32)) + map_int32.clear() + self.assertEqual(0, len(map_int32)) + + with self.assertRaises(KeyError): + map_int32.popitem() + + self.assertEqual(0, map_int32.setdefault(2)) + self.assertEqual(1, len(map_int32)) + + map_int32.update(msg2.map_int32_int32) + self.assertEqual(4, len(map_int32)) + + with self.assertRaises(TypeError): + map_int32.update(msg2.map_int32_int32, + msg2.map_int32_int32) + with self.assertRaises(TypeError): + map_int32.update(0) + with self.assertRaises(TypeError): + map_int32.update(value=12) + def testMapItems(self): # Map items used to have strange behaviors when use c extension. Because # [] may reorder the map and invalidate any exsting iterators. @@ -1832,6 +2003,9 @@ class Proto3Test(BaseTestCase): del msg.map_int32_int32[4] self.assertEqual(0, len(msg.map_int32_int32)) + with self.assertRaises(KeyError): + del msg.map_int32_all_types[32] + def testMapsAreMapping(self): msg = map_unittest_pb2.TestMap() self.assertIsInstance(msg.map_int32_int32, collections.Mapping) @@ -1840,6 +2014,14 @@ class Proto3Test(BaseTestCase): self.assertIsInstance(msg.map_int32_foreign_message, collections.MutableMapping) + def testMapsCompare(self): + msg = map_unittest_pb2.TestMap() + msg.map_int32_int32[-123] = -456 + self.assertEqual(msg.map_int32_int32, msg.map_int32_int32) + self.assertEqual(msg.map_int32_foreign_message, + msg.map_int32_foreign_message) + self.assertNotEqual(msg.map_int32_int32, 0) + def testMapFindInitializationErrorsSmokeTest(self): msg = map_unittest_pb2.TestMap() msg.map_string_string['abc'] = '123' @@ -1927,8 +2109,9 @@ class PackedFieldTest(BaseTestCase): self.assertEqual(golden_data, message.SerializeToString()) -@unittest.skipIf(api_implementation.Type() != 'cpp', - 'explicit tests of the C++ implementation') +@unittest.skipIf(api_implementation.Type() != 'cpp' or + sys.version_info < (2, 7), + 'explicit tests of the C++ implementation for PY27 and above') class OversizeProtosTest(BaseTestCase): @classmethod diff --git a/python/google/protobuf/internal/python_message.py b/python/google/protobuf/internal/python_message.py index c363d843..975e3b4d 100755 --- a/python/google/protobuf/internal/python_message.py +++ b/python/google/protobuf/internal/python_message.py @@ -1010,11 +1010,16 @@ def _AddByteSizeMethod(message_descriptor, cls): return self._cached_byte_size size = 0 - for field_descriptor, field_value in self.ListFields(): - size += field_descriptor._sizer(field_value) - - for tag_bytes, value_bytes in self._unknown_fields: - size += len(tag_bytes) + len(value_bytes) + descriptor = self.DESCRIPTOR + if descriptor.GetOptions().map_entry: + # Fields of map entry should always be serialized. + size = descriptor.fields_by_name['key']._sizer(self.key) + size += descriptor.fields_by_name['value']._sizer(self.value) + else: + for field_descriptor, field_value in self.ListFields(): + size += field_descriptor._sizer(field_value) + for tag_bytes, value_bytes in self._unknown_fields: + size += len(tag_bytes) + len(value_bytes) self._cached_byte_size = size self._cached_byte_size_dirty = False @@ -1053,11 +1058,20 @@ def _AddSerializePartialToStringMethod(message_descriptor, cls): api_implementation.IsPythonDefaultSerializationDeterministic()) else: deterministic = bool(deterministic) - for field_descriptor, field_value in self.ListFields(): - field_descriptor._encoder(write_bytes, field_value, deterministic) - for tag_bytes, value_bytes in self._unknown_fields: - write_bytes(tag_bytes) - write_bytes(value_bytes) + + descriptor = self.DESCRIPTOR + if descriptor.GetOptions().map_entry: + # Fields of map entry should always be serialized. + descriptor.fields_by_name['key']._encoder( + write_bytes, self.key, deterministic) + descriptor.fields_by_name['value']._encoder( + write_bytes, self.value, deterministic) + else: + for field_descriptor, field_value in self.ListFields(): + field_descriptor._encoder(write_bytes, field_value, deterministic) + for tag_bytes, value_bytes in self._unknown_fields: + write_bytes(tag_bytes) + write_bytes(value_bytes) cls._InternalSerialize = InternalSerialize @@ -1095,7 +1109,8 @@ def _AddMergeFromStringMethod(message_descriptor, cls): new_pos = local_SkipField(buffer, new_pos, end, tag_bytes) if new_pos == -1: return pos - if not is_proto3: + if (not is_proto3 or + api_implementation.GetPythonProto3PreserveUnknownsDefault()): if not unknown_field_list: unknown_field_list = self._unknown_fields = [] unknown_field_list.append( diff --git a/python/google/protobuf/internal/reflection_test.py b/python/google/protobuf/internal/reflection_test.py index 55b0d72e..0306ff46 100755 --- a/python/google/protobuf/internal/reflection_test.py +++ b/python/google/protobuf/internal/reflection_test.py @@ -620,6 +620,14 @@ class ReflectionTest(BaseTestCase): self.assertRaises(TypeError, setattr, proto, 'optional_int32', 'foo') self.assertRaises(TypeError, setattr, proto, 'optional_string', 10) self.assertRaises(TypeError, setattr, proto, 'optional_bytes', 10) + self.assertRaises(TypeError, setattr, proto, 'optional_bool', 'foo') + self.assertRaises(TypeError, setattr, proto, 'optional_float', 'foo') + self.assertRaises(TypeError, setattr, proto, 'optional_double', 'foo') + # TODO(jieluo): Fix type checking difference for python and c extension + if api_implementation.Type() == 'python': + self.assertRaises(TypeError, setattr, proto, 'optional_bool', 1.1) + else: + proto.optional_bool = 1.1 def assertIntegerTypes(self, integer_fn): """Verifies setting of scalar integers. @@ -686,8 +694,10 @@ class ReflectionTest(BaseTestCase): self.assertEqual(expected_min, getattr(pb, field_name)) setattr(pb, field_name, expected_max) self.assertEqual(expected_max, getattr(pb, field_name)) - self.assertRaises(ValueError, setattr, pb, field_name, expected_min - 1) - self.assertRaises(ValueError, setattr, pb, field_name, expected_max + 1) + self.assertRaises((ValueError, TypeError), setattr, pb, field_name, + expected_min - 1) + self.assertRaises((ValueError, TypeError), setattr, pb, field_name, + expected_max + 1) TestMinAndMaxIntegers('optional_int32', -(1 << 31), (1 << 31) - 1) TestMinAndMaxIntegers('optional_uint32', 0, 0xffffffff) @@ -696,7 +706,7 @@ class ReflectionTest(BaseTestCase): # A bit of white-box testing since -1 is an int and not a long in C++ and # so goes down a different path. pb = unittest_pb2.TestAllTypes() - with self.assertRaises(ValueError): + with self.assertRaises((ValueError, TypeError)): pb.optional_uint64 = integer_fn(-(1 << 63)) pb = unittest_pb2.TestAllTypes() @@ -720,6 +730,12 @@ class ReflectionTest(BaseTestCase): proto.repeated_int32[0] = 23 self.assertRaises(IndexError, proto.repeated_int32.__setitem__, 500, 23) self.assertRaises(TypeError, proto.repeated_int32.__setitem__, 0, 'abc') + self.assertRaises(TypeError, proto.repeated_int32.__setitem__, 0, []) + self.assertRaises(TypeError, proto.repeated_int32.__setitem__, + 'index', 23) + + proto.repeated_string.append('2') + self.assertRaises(TypeError, proto.repeated_string.__setitem__, 0, 10) # Repeated enums tests. #proto.repeated_nested_enum.append(0) @@ -1007,6 +1023,14 @@ class ReflectionTest(BaseTestCase): self.assertEqual(4, len(proto.repeated_nested_message)) self.assertEqual(n1, proto.repeated_nested_message[2]) self.assertEqual(n2, proto.repeated_nested_message[3]) + self.assertRaises(TypeError, + proto.repeated_nested_message.extend, n1) + self.assertRaises(TypeError, + proto.repeated_nested_message.extend, [0]) + wrong_message_type = unittest_pb2.TestAllTypes() + self.assertRaises(TypeError, + proto.repeated_nested_message.extend, + [wrong_message_type]) # Test clearing. proto.ClearField('repeated_nested_message') @@ -1018,6 +1042,8 @@ class ReflectionTest(BaseTestCase): self.assertEqual(1, len(proto.repeated_nested_message)) self.assertEqual(23, proto.repeated_nested_message[0].bb) self.assertRaises(TypeError, proto.repeated_nested_message.add, 23) + with self.assertRaises(Exception): + proto.repeated_nested_message[0] = 23 def testRepeatedCompositeRemove(self): proto = unittest_pb2.TestAllTypes() @@ -1642,8 +1668,11 @@ class ReflectionTest(BaseTestCase): proto.SerializeToString() proto.SerializePartialToString() - def assertNotInitialized(self, proto): + def assertNotInitialized(self, proto, error_size=None): + errors = [] self.assertFalse(proto.IsInitialized()) + self.assertFalse(proto.IsInitialized(errors)) + self.assertEqual(error_size, len(errors)) self.assertRaises(message.EncodeError, proto.SerializeToString) # "Partial" serialization doesn't care if message is uninitialized. proto.SerializePartialToString() @@ -1657,7 +1686,7 @@ class ReflectionTest(BaseTestCase): # The case of uninitialized required fields. proto = unittest_pb2.TestRequired() - self.assertNotInitialized(proto) + self.assertNotInitialized(proto, 3) proto.a = proto.b = proto.c = 2 self.assertInitialized(proto) @@ -1665,14 +1694,14 @@ class ReflectionTest(BaseTestCase): proto = unittest_pb2.TestRequiredForeign() self.assertInitialized(proto) proto.optional_message.a = 1 - self.assertNotInitialized(proto) + self.assertNotInitialized(proto, 2) proto.optional_message.b = 0 proto.optional_message.c = 0 self.assertInitialized(proto) # Uninitialized repeated submessage. message1 = proto.repeated_message.add() - self.assertNotInitialized(proto) + self.assertNotInitialized(proto, 3) message1.a = message1.b = message1.c = 0 self.assertInitialized(proto) @@ -1681,11 +1710,11 @@ class ReflectionTest(BaseTestCase): extension = unittest_pb2.TestRequired.multi message1 = proto.Extensions[extension].add() message2 = proto.Extensions[extension].add() - self.assertNotInitialized(proto) + self.assertNotInitialized(proto, 6) message1.a = 1 message1.b = 1 message1.c = 1 - self.assertNotInitialized(proto) + self.assertNotInitialized(proto, 3) message2.a = 2 message2.b = 2 message2.c = 2 @@ -1695,7 +1724,7 @@ class ReflectionTest(BaseTestCase): proto = unittest_pb2.TestAllExtensions() extension = unittest_pb2.TestRequired.single proto.Extensions[extension].a = 1 - self.assertNotInitialized(proto) + self.assertNotInitialized(proto, 2) proto.Extensions[extension].b = 2 proto.Extensions[extension].c = 3 self.assertInitialized(proto) @@ -2154,6 +2183,8 @@ class ByteSizeTest(BaseTestCase): foreign_message_1 = self.proto.repeated_nested_message.add() foreign_message_1.bb = 9 self.assertEqual(2 + 1 + 2 + 1 + 1 + 1, self.Size()) + repeated_nested_message = copy.deepcopy( + self.proto.repeated_nested_message) # 2 bytes tag plus 1 byte length plus 1 byte bb tag 1 byte int. del self.proto.repeated_nested_message[0] @@ -2174,6 +2205,16 @@ class ByteSizeTest(BaseTestCase): del self.proto.repeated_nested_message[0] self.assertEqual(0, self.Size()) + self.assertEqual(2, len(repeated_nested_message)) + del repeated_nested_message[0:1] + # TODO(jieluo): Fix cpp extension bug when delete repeated message. + if api_implementation.Type() == 'python': + self.assertEqual(1, len(repeated_nested_message)) + del repeated_nested_message[-1] + # TODO(jieluo): Fix cpp extension bug when delete repeated message. + if api_implementation.Type() == 'python': + self.assertEqual(0, len(repeated_nested_message)) + def testRepeatedGroups(self): # 2-byte START_GROUP plus 2-byte END_GROUP. group_0 = self.proto.repeatedgroup.add() @@ -2190,6 +2231,10 @@ class ByteSizeTest(BaseTestCase): proto.Extensions[extension] = 23 # 1 byte for tag, 1 byte for value. self.assertEqual(2, proto.ByteSize()) + field = unittest_pb2.TestAllTypes.DESCRIPTOR.fields_by_name[ + 'optional_int32'] + with self.assertRaises(KeyError): + proto.Extensions[field] = 23 def testCacheInvalidationForNonrepeatedScalar(self): # Test non-extension. diff --git a/python/google/protobuf/internal/service_reflection_test.py b/python/google/protobuf/internal/service_reflection_test.py index 62900b1d..77239f44 100755 --- a/python/google/protobuf/internal/service_reflection_test.py +++ b/python/google/protobuf/internal/service_reflection_test.py @@ -82,6 +82,10 @@ class FooUnitTest(unittest.TestCase): service_descriptor = unittest_pb2.TestService.GetDescriptor() srvc.CallMethod(service_descriptor.methods[1], rpc_controller, unittest_pb2.BarRequest(), MyCallback) + self.assertTrue(srvc.GetRequestClass(service_descriptor.methods[1]) is + unittest_pb2.BarRequest) + self.assertTrue(srvc.GetResponseClass(service_descriptor.methods[1]) is + unittest_pb2.BarResponse) self.assertEqual('Method Bar not implemented.', rpc_controller.failure_message) self.assertEqual(None, self.callback_response) diff --git a/python/google/protobuf/internal/test_util.py b/python/google/protobuf/internal/test_util.py index 269d0e2d..a6e34ef5 100755 --- a/python/google/protobuf/internal/test_util.py +++ b/python/google/protobuf/internal/test_util.py @@ -39,11 +39,15 @@ __author__ = 'robinson@google.com (Will Robinson)' import numbers import operator import os.path -import sys from google.protobuf import unittest_import_pb2 from google.protobuf import unittest_pb2 -from google.protobuf import descriptor_pb2 + +try: + long # Python 2 +except NameError: + long = int # Python 3 + # Tests whether the given TestAllTypes message is proto2 or not. # This is used to gate several fields/features that only exist @@ -51,6 +55,7 @@ from google.protobuf import descriptor_pb2 def IsProto2(message): return message.DESCRIPTOR.syntax == "proto2" + def SetAllNonLazyFields(message): """Sets every non-lazy field in the message to a unique value. @@ -128,22 +133,37 @@ def SetAllNonLazyFields(message): message.repeated_string_piece.append(u'224') message.repeated_cord.append(u'225') - # Add a second one of each field. - message.repeated_int32.append(301) - message.repeated_int64.append(302) - message.repeated_uint32.append(303) - message.repeated_uint64.append(304) - message.repeated_sint32.append(305) - message.repeated_sint64.append(306) - message.repeated_fixed32.append(307) - message.repeated_fixed64.append(308) - message.repeated_sfixed32.append(309) - message.repeated_sfixed64.append(310) - message.repeated_float.append(311) - message.repeated_double.append(312) - message.repeated_bool.append(False) - message.repeated_string.append(u'315') - message.repeated_bytes.append(b'316') + # Add a second one of each field and set value by index. + message.repeated_int32.append(0) + message.repeated_int64.append(0) + message.repeated_uint32.append(0) + message.repeated_uint64.append(0) + message.repeated_sint32.append(0) + message.repeated_sint64.append(0) + message.repeated_fixed32.append(0) + message.repeated_fixed64.append(0) + message.repeated_sfixed32.append(0) + message.repeated_sfixed64.append(0) + message.repeated_float.append(0) + message.repeated_double.append(0) + message.repeated_bool.append(True) + message.repeated_string.append(u'0') + message.repeated_bytes.append(b'0') + message.repeated_int32[1] = 301 + message.repeated_int64[1] = 302 + message.repeated_uint32[1] = 303 + message.repeated_uint64[1] = 304 + message.repeated_sint32[1] = 305 + message.repeated_sint64[1] = 306 + message.repeated_fixed32[1] = 307 + message.repeated_fixed64[1] = 308 + message.repeated_sfixed32[1] = 309 + message.repeated_sfixed64[1] = 310 + message.repeated_float[1] = 311 + message.repeated_double[1] = 312 + message.repeated_bool[1] = False + message.repeated_string[1] = u'315' + message.repeated_bytes[1] = b'316' if IsProto2(message): message.repeatedgroup.add().a = 317 @@ -152,7 +172,8 @@ def SetAllNonLazyFields(message): message.repeated_import_message.add().d = 320 message.repeated_lazy_message.add().bb = 327 - message.repeated_nested_enum.append(unittest_pb2.TestAllTypes.BAZ) + message.repeated_nested_enum.append(unittest_pb2.TestAllTypes.BAR) + message.repeated_nested_enum[1] = unittest_pb2.TestAllTypes.BAZ message.repeated_foreign_enum.append(unittest_pb2.FOREIGN_BAZ) if IsProto2(message): message.repeated_import_enum.append(unittest_import_pb2.IMPORT_BAZ) @@ -707,8 +728,8 @@ class NonStandardInteger(numbers.Integral): NonStandardInteger is the minimal legal specification for a custom Integral. As such, it does not support 0 < x < 5 and it is not hashable. - Note: This is added here instead of relying on numpy or a similar library with - custom integers to limit dependencies. + Note: This is added here instead of relying on numpy or a similar library + with custom integers to limit dependencies. """ def __init__(self, val, error_string_on_conversion=None): @@ -845,4 +866,3 @@ class NonStandardInteger(numbers.Integral): def __repr__(self): return 'NonStandardInteger(%s)' % self.val - diff --git a/python/google/protobuf/internal/text_format_test.py b/python/google/protobuf/internal/text_format_test.py index 424b29cc..1214c3ea 100755 --- a/python/google/protobuf/internal/text_format_test.py +++ b/python/google/protobuf/internal/text_format_test.py @@ -35,6 +35,7 @@ __author__ = 'kenton@google.com (Kenton Varda)' +import math import re import six import string @@ -53,8 +54,8 @@ from google.protobuf import unittest_pb2 from google.protobuf import unittest_proto3_arena_pb2 from google.protobuf.internal import api_implementation from google.protobuf.internal import any_test_pb2 as test_extend_any -from google.protobuf.internal import test_util from google.protobuf.internal import message_set_extensions_pb2 +from google.protobuf.internal import test_util from google.protobuf import descriptor_pool from google.protobuf import text_format @@ -371,7 +372,10 @@ class TextFormatTest(TextFormatBase): def testParseInvalidUtf8(self, message_module): message = message_module.TestAllTypes() text = 'repeated_string: "\\xc3\\xc3"' - self.assertRaises(text_format.ParseError, text_format.Parse, text, message) + with self.assertRaises(text_format.ParseError) as e: + text_format.Parse(text, message) + self.assertEqual(e.exception.GetLine(), 1) + self.assertEqual(e.exception.GetColumn(), 28) def testParseSingleWord(self, message_module): message = message_module.TestAllTypes() @@ -784,13 +788,14 @@ class Proto2Tests(TextFormatBase): ' bin: "\xe0"' ' [nested_unknown_ext]: {\n' ' i: 23\n' + ' x: x\n' ' test: "test_string"\n' ' floaty_float: -0.315\n' ' num: -inf\n' ' multiline_str: "abc"\n' ' "def"\n' ' "xyz."\n' - ' [nested_unknown_ext]: <\n' + ' [nested_unknown_ext.ext]: <\n' ' i: 23\n' ' i: 24\n' ' pointfloat: .3\n' @@ -896,6 +901,14 @@ class Proto2Tests(TextFormatBase): self.assertEqual(23, message.message_set.Extensions[ext1].i) self.assertEqual('foo', message.message_set.Extensions[ext2].str) + def testParseBadIdentifier(self): + message = unittest_pb2.TestAllTypes() + text = ('optional_nested_message { "bb": 1 }') + with self.assertRaises(text_format.ParseError) as e: + text_format.Parse(text, message) + self.assertEqual(str(e.exception), + '1:27 : Expected identifier or number, got "bb".') + def testParseBadExtension(self): message = unittest_pb2.TestAllExtensions() text = '[unknown_extension]: 8\n' @@ -1095,6 +1108,19 @@ class Proto3Tests(unittest.TestCase): ' < data: "string" > ' '>') + def testUnknownEnums(self): + message = unittest_proto3_arena_pb2.TestAllTypes() + message2 = unittest_proto3_arena_pb2.TestAllTypes() + message.optional_nested_enum = 999 + text_string = text_format.MessageToString(message) + # TODO(jieluo): proto3 should support numeric unknown enum. + with self.assertRaises(text_format.ParseError) as e: + text_format.Parse(text_string, message2) + self.assertEqual(999, message2.optional_nested_enum) + self.assertEqual(str(e.exception), + '1:23 : Enum type "proto3_arena_unittest.TestAllTypes.' + 'NestedEnum" has no value with number 999.') + def testMergeExpandedAny(self): message = any_test_pb2.TestAny() text = ('any_value {\n' @@ -1180,6 +1206,15 @@ class Proto3Tests(unittest.TestCase): message.any_value.Unpack(packed_message) self.assertEqual('string', packed_message.data) + def testMergeMissingAnyEndToken(self): + message = any_test_pb2.TestAny() + text = ('any_value {\n' + ' [type.googleapis.com/protobuf_unittest.OneString] {\n' + ' data: "string"\n') + with self.assertRaises(text_format.ParseError) as e: + text_format.Merge(text, message) + self.assertEqual(str(e.exception), '3:11 : Expected "}".') + class TokenizerTest(unittest.TestCase): @@ -1191,7 +1226,7 @@ class TokenizerTest(unittest.TestCase): 'ID9: 22 ID10: -111111111111111111 ID11: -22\n' 'ID12: 2222222222222222222 ID13: 1.23456f ID14: 1.2e+2f ' 'false_bool: 0 true_BOOL:t \n true_bool1: 1 false_BOOL1:f ' - 'False_bool: False True_bool: True') + 'False_bool: False True_bool: True X:iNf Y:-inF Z:nAN') tokenizer = text_format.Tokenizer(text.splitlines()) methods = [(tokenizer.ConsumeIdentifier, 'identifier1'), ':', (tokenizer.ConsumeString, 'string1'), @@ -1239,7 +1274,13 @@ class TokenizerTest(unittest.TestCase): (tokenizer.ConsumeIdentifier, 'False_bool'), ':', (tokenizer.ConsumeBool, False), (tokenizer.ConsumeIdentifier, 'True_bool'), ':', - (tokenizer.ConsumeBool, True)] + (tokenizer.ConsumeBool, True), + (tokenizer.ConsumeIdentifier, 'X'), ':', + (tokenizer.ConsumeFloat, float('inf')), + (tokenizer.ConsumeIdentifier, 'Y'), ':', + (tokenizer.ConsumeFloat, float('-inf')), + (tokenizer.ConsumeIdentifier, 'Z'), ':', + (tokenizer.ConsumeFloat, float('nan'))] i = 0 while not tokenizer.AtEnd(): @@ -1248,6 +1289,8 @@ class TokenizerTest(unittest.TestCase): token = tokenizer.token self.assertEqual(token, m) tokenizer.NextToken() + elif isinstance(m[1], float) and math.isnan(m[1]): + self.assertTrue(math.isnan(m[0]())) else: self.assertEqual(m[1], m[0]()) i += 1 @@ -1266,10 +1309,15 @@ class TokenizerTest(unittest.TestCase): self.assertEqual(int64_max + 1, tokenizer.ConsumeInteger()) self.assertTrue(tokenizer.AtEnd()) - text = '-0 0' + text = '-0 0 0 1.2' tokenizer = text_format.Tokenizer(text.splitlines()) self.assertEqual(0, tokenizer.ConsumeInteger()) self.assertEqual(0, tokenizer.ConsumeInteger()) + self.assertEqual(True, tokenizer.TryConsumeInteger()) + self.assertEqual(False, tokenizer.TryConsumeInteger()) + with self.assertRaises(text_format.ParseError): + tokenizer.ConsumeInteger() + self.assertEqual(1.2, tokenizer.ConsumeFloat()) self.assertTrue(tokenizer.AtEnd()) def testConsumeIntegers(self): diff --git a/python/google/protobuf/internal/unknown_fields_test.py b/python/google/protobuf/internal/unknown_fields_test.py index d614eaa8..9bdb6f27 100755 --- a/python/google/protobuf/internal/unknown_fields_test.py +++ b/python/google/protobuf/internal/unknown_fields_test.py @@ -54,10 +54,13 @@ from google.protobuf.internal import type_checkers BaseTestCase = testing_refleaks.BaseTestCase -def SkipIfCppImplementation(func): +# CheckUnknownField() cannot be used by the C++ implementation because +# some protect members are called. It is not a behavior difference +# for python and C++ implementation. +def SkipCheckUnknownFieldIfCppImplementation(func): return unittest.skipIf( api_implementation.Type() == 'cpp' and api_implementation.Version() == 2, - 'C++ implementation does not expose unknown fields to Python')(func) + 'Addtional test for pure python involved protect members')(func) class UnknownFieldsTest(BaseTestCase): @@ -77,11 +80,24 @@ class UnknownFieldsTest(BaseTestCase): # stdout. self.assertTrue(data == self.all_fields_data) - def testSerializeProto3(self): - # Verify that proto3 doesn't preserve unknown fields. + def expectSerializeProto3(self, preserve): message = unittest_proto3_arena_pb2.TestEmptyMessage() message.ParseFromString(self.all_fields_data) - self.assertEqual(0, len(message.SerializeToString())) + if preserve: + self.assertEqual(self.all_fields_data, message.SerializeToString()) + else: + self.assertEqual(0, len(message.SerializeToString())) + + def testSerializeProto3(self): + # Verify that proto3 unknown fields behavior. + default_preserve = (api_implementation + .GetPythonProto3PreserveUnknownsDefault()) + self.assertEqual(False, default_preserve) + self.expectSerializeProto3(default_preserve) + api_implementation.SetPythonProto3PreserveUnknownsDefault( + not default_preserve) + self.expectSerializeProto3(not default_preserve) + api_implementation.SetPythonProto3PreserveUnknownsDefault(default_preserve) def testByteSize(self): self.assertEqual(self.all_fields.ByteSize(), self.empty_message.ByteSize()) @@ -154,12 +170,13 @@ class UnknownFieldsAccessorsTest(BaseTestCase): self.empty_message = unittest_pb2.TestEmptyMessage() self.empty_message.ParseFromString(self.all_fields_data) - # GetUnknownField() checks a detail of the Python implementation, which stores - # unknown fields as serialized strings. It cannot be used by the C++ - # implementation: it's enough to check that the message is correctly - # serialized. + # CheckUnknownField() is an additional Pure Python check which checks + # a detail of unknown fields. It cannot be used by the C++ + # implementation because some protect members are called. + # The test is added for historical reasons. It is not necessary as + # serialized string is checked. - def GetUnknownField(self, name): + def CheckUnknownField(self, name, expected_value): field_descriptor = self.descriptor.fields_by_name[name] wire_type = type_checkers.FIELD_TYPE_TO_WIRE_TYPE[field_descriptor.type] field_tag = encoder.TagBytes(field_descriptor.number, wire_type) @@ -168,42 +185,35 @@ class UnknownFieldsAccessorsTest(BaseTestCase): if tag_bytes == field_tag: decoder = unittest_pb2.TestAllTypes._decoders_by_tag[tag_bytes][0] decoder(value, 0, len(value), self.all_fields, result_dict) - return result_dict[field_descriptor] - - @SkipIfCppImplementation - def testEnum(self): - value = self.GetUnknownField('optional_nested_enum') - self.assertEqual(self.all_fields.optional_nested_enum, value) - - @SkipIfCppImplementation - def testRepeatedEnum(self): - value = self.GetUnknownField('repeated_nested_enum') - self.assertEqual(self.all_fields.repeated_nested_enum, value) - - @SkipIfCppImplementation - def testVarint(self): - value = self.GetUnknownField('optional_int32') - self.assertEqual(self.all_fields.optional_int32, value) - - @SkipIfCppImplementation - def testFixed32(self): - value = self.GetUnknownField('optional_fixed32') - self.assertEqual(self.all_fields.optional_fixed32, value) - - @SkipIfCppImplementation - def testFixed64(self): - value = self.GetUnknownField('optional_fixed64') - self.assertEqual(self.all_fields.optional_fixed64, value) - - @SkipIfCppImplementation - def testLengthDelimited(self): - value = self.GetUnknownField('optional_string') - self.assertEqual(self.all_fields.optional_string, value) - - @SkipIfCppImplementation - def testGroup(self): - value = self.GetUnknownField('optionalgroup') - self.assertEqual(self.all_fields.optionalgroup, value) + self.assertEqual(expected_value, result_dict[field_descriptor]) + + @SkipCheckUnknownFieldIfCppImplementation + def testCheckUnknownFieldValue(self): + # Test enum. + self.CheckUnknownField('optional_nested_enum', + self.all_fields.optional_nested_enum) + # Test repeated enum. + self.CheckUnknownField('repeated_nested_enum', + self.all_fields.repeated_nested_enum) + + # Test varint. + self.CheckUnknownField('optional_int32', + self.all_fields.optional_int32) + # Test fixed32. + self.CheckUnknownField('optional_fixed32', + self.all_fields.optional_fixed32) + + # Test fixed64. + self.CheckUnknownField('optional_fixed64', + self.all_fields.optional_fixed64) + + # Test lengthd elimited. + self.CheckUnknownField('optional_string', + self.all_fields.optional_string) + + # Test group. + self.CheckUnknownField('optionalgroup', + self.all_fields.optionalgroup) def testCopyFrom(self): message = unittest_pb2.TestEmptyMessage() @@ -263,12 +273,13 @@ class UnknownEnumValuesTest(BaseTestCase): self.missing_message = missing_enum_values_pb2.TestMissingEnumValues() self.missing_message.ParseFromString(self.message_data) - # GetUnknownField() checks a detail of the Python implementation, which stores - # unknown fields as serialized strings. It cannot be used by the C++ - # implementation: it's enough to check that the message is correctly - # serialized. + # CheckUnknownField() is an additional Pure Python check which checks + # a detail of unknown fields. It cannot be used by the C++ + # implementation because some protect members are called. + # The test is added for historical reasons. It is not necessary as + # serialized string is checked. - def GetUnknownField(self, name): + def CheckUnknownField(self, name, expected_value): field_descriptor = self.descriptor.fields_by_name[name] wire_type = type_checkers.FIELD_TYPE_TO_WIRE_TYPE[field_descriptor.type] field_tag = encoder.TagBytes(field_descriptor.number, wire_type) @@ -278,7 +289,7 @@ class UnknownEnumValuesTest(BaseTestCase): decoder = missing_enum_values_pb2.TestEnumValues._decoders_by_tag[ tag_bytes][0] decoder(value, 0, len(value), self.message, result_dict) - return result_dict[field_descriptor] + self.assertEqual(expected_value, result_dict[field_descriptor]) def testUnknownParseMismatchEnumValue(self): just_string = missing_enum_values_pb2.JustString() @@ -294,38 +305,27 @@ class UnknownEnumValuesTest(BaseTestCase): self.assertEqual(missing.optional_nested_enum, 0) def testUnknownEnumValue(self): - if api_implementation.Type() == 'cpp': - # The CPP implementation of protos (wrongly) allows unknown enum values - # for proto2. - self.assertTrue(self.missing_message.HasField('optional_nested_enum')) - self.assertEqual(self.message.optional_nested_enum, - self.missing_message.optional_nested_enum) - else: - # On the other hand, the Python implementation considers unknown values - # as unknown fields. This is the correct behavior. - self.assertFalse(self.missing_message.HasField('optional_nested_enum')) - value = self.GetUnknownField('optional_nested_enum') - self.assertEqual(self.message.optional_nested_enum, value) - self.missing_message.ClearField('optional_nested_enum') self.assertFalse(self.missing_message.HasField('optional_nested_enum')) + self.assertEqual(self.missing_message.optional_nested_enum, 2) + # Clear does not do anything. + serialized = self.missing_message.SerializeToString() + self.missing_message.ClearField('optional_nested_enum') + self.assertEqual(self.missing_message.SerializeToString(), serialized) def testUnknownRepeatedEnumValue(self): - if api_implementation.Type() == 'cpp': - # For repeated enums, both implementations agree. - self.assertEqual([], self.missing_message.repeated_nested_enum) - else: - self.assertEqual([], self.missing_message.repeated_nested_enum) - value = self.GetUnknownField('repeated_nested_enum') - self.assertEqual(self.message.repeated_nested_enum, value) + self.assertEqual([], self.missing_message.repeated_nested_enum) def testUnknownPackedEnumValue(self): - if api_implementation.Type() == 'cpp': - # For repeated enums, both implementations agree. - self.assertEqual([], self.missing_message.packed_nested_enum) - else: - self.assertEqual([], self.missing_message.packed_nested_enum) - value = self.GetUnknownField('packed_nested_enum') - self.assertEqual(self.message.packed_nested_enum, value) + self.assertEqual([], self.missing_message.packed_nested_enum) + + @SkipCheckUnknownFieldIfCppImplementation + def testCheckUnknownFieldValueForEnum(self): + self.CheckUnknownField('optional_nested_enum', + self.message.optional_nested_enum) + self.CheckUnknownField('repeated_nested_enum', + self.message.repeated_nested_enum) + self.CheckUnknownField('packed_nested_enum', + self.message.packed_nested_enum) def testRoundTrip(self): new_message = missing_enum_values_pb2.TestEnumValues() diff --git a/python/google/protobuf/internal/well_known_types.py b/python/google/protobuf/internal/well_known_types.py index d0c7ffda..57b96998 100644 --- a/python/google/protobuf/internal/well_known_types.py +++ b/python/google/protobuf/internal/well_known_types.py @@ -473,7 +473,7 @@ def _IsValidPath(message_descriptor, path): parts = path.split('.') last = parts.pop() for name in parts: - field = message_descriptor.fields_by_name[name] + field = message_descriptor.fields_by_name.get(name) if (field is None or field.label == FieldDescriptor.LABEL_REPEATED or field.type != FieldDescriptor.TYPE_MESSAGE): @@ -698,6 +698,12 @@ def _SetStructValue(struct_value, value): struct_value.string_value = value elif isinstance(value, _INT_OR_FLOAT): struct_value.number_value = value + elif isinstance(value, dict): + struct_value.struct_value.Clear() + struct_value.struct_value.update(value) + elif isinstance(value, list): + struct_value.list_value.Clear() + struct_value.list_value.extend(value) else: raise ValueError('Unexpected type') @@ -733,13 +739,21 @@ class Struct(object): def get_or_create_list(self, key): """Returns a list for this key, creating if it didn't exist already.""" + if not self.fields[key].HasField('list_value'): + # Clear will mark list_value modified which will indeed create a list. + self.fields[key].list_value.Clear() return self.fields[key].list_value def get_or_create_struct(self, key): """Returns a struct for this key, creating if it didn't exist already.""" + if not self.fields[key].HasField('struct_value'): + # Clear will mark struct_value modified which will indeed create a struct. + self.fields[key].struct_value.Clear() return self.fields[key].struct_value - # TODO(haberman): allow constructing/merging from dict. + def update(self, dictionary): # pylint: disable=invalid-name + for key, value in dictionary.items(): + _SetStructValue(self.fields[key], value) class ListValue(object): @@ -768,11 +782,17 @@ class ListValue(object): def add_struct(self): """Appends and returns a struct value as the next value in the list.""" - return self.values.add().struct_value + struct_value = self.values.add().struct_value + # Clear will mark struct_value modified which will indeed create a struct. + struct_value.Clear() + return struct_value def add_list(self): """Appends and returns a list value as the next value in the list.""" - return self.values.add().list_value + list_value = self.values.add().list_value + # Clear will mark list_value modified which will indeed create a list. + list_value.Clear() + return list_value WKTBASES = { diff --git a/python/google/protobuf/internal/well_known_types_test.py b/python/google/protobuf/internal/well_known_types_test.py index 123a537c..70975da1 100644 --- a/python/google/protobuf/internal/well_known_types_test.py +++ b/python/google/protobuf/internal/well_known_types_test.py @@ -105,6 +105,10 @@ class TimeUtilTest(TimeUtilTestBase): self.assertEqual(8 * 3600, message.seconds) self.assertEqual(0, message.nanos) + # It is not easy to check with current time. For test coverage only. + message.GetCurrentTime() + self.assertNotEqual(8 * 3600, message.seconds) + def testDurationSerializeAndParse(self): message = duration_pb2.Duration() # Generated output should contain 3, 6, or 9 fractional digits. @@ -268,6 +272,17 @@ class TimeUtilTest(TimeUtilTestBase): def testInvalidTimestamp(self): message = timestamp_pb2.Timestamp() self.assertRaisesRegexp( + well_known_types.ParseError, + 'Failed to parse timestamp: missing valid timezone offset.', + message.FromJsonString, + '') + self.assertRaisesRegexp( + well_known_types.ParseError, + 'Failed to parse timestamp: invalid trailing data ' + '1970-01-01T00:00:01Ztrail.', + message.FromJsonString, + '1970-01-01T00:00:01Ztrail') + self.assertRaisesRegexp( ValueError, 'time data \'10000-01-01T00:00:00\' does not match' ' format \'%Y-%m-%dT%H:%M:%S\'', @@ -322,6 +337,13 @@ class TimeUtilTest(TimeUtilTestBase): r'Duration is not valid\: Seconds -315576000001 must be in range' r' \[-315576000000\, 315576000000\].', message.ToJsonString) + message.seconds = 0 + message.nanos = 999999999 + 1 + self.assertRaisesRegexp( + well_known_types.Error, + r'Duration is not valid\: Nanos 1000000000 must be in range' + r' \[-999999999\, 999999999\].', + message.ToJsonString) class FieldMaskTest(unittest.TestCase): @@ -363,10 +385,37 @@ class FieldMaskTest(unittest.TestCase): self.assertTrue(mask.IsValidForDescriptor(msg_descriptor)) for field in msg_descriptor.fields: self.assertTrue(field.name in mask.paths) + + def testIsValidForDescriptor(self): + msg_descriptor = unittest_pb2.TestAllTypes.DESCRIPTOR + # Empty mask + mask = field_mask_pb2.FieldMask() + self.assertTrue(mask.IsValidForDescriptor(msg_descriptor)) + # All fields from descriptor + mask.AllFieldsFromDescriptor(msg_descriptor) + self.assertTrue(mask.IsValidForDescriptor(msg_descriptor)) + # Child under optional message mask.paths.append('optional_nested_message.bb') self.assertTrue(mask.IsValidForDescriptor(msg_descriptor)) + # Repeated field is only allowed in the last position of path mask.paths.append('repeated_nested_message.bb') self.assertFalse(mask.IsValidForDescriptor(msg_descriptor)) + # Invalid top level field + mask = field_mask_pb2.FieldMask() + mask.paths.append('xxx') + self.assertFalse(mask.IsValidForDescriptor(msg_descriptor)) + # Invalid field in root + mask = field_mask_pb2.FieldMask() + mask.paths.append('xxx.zzz') + self.assertFalse(mask.IsValidForDescriptor(msg_descriptor)) + # Invalid field in internal node + mask = field_mask_pb2.FieldMask() + mask.paths.append('optional_nested_message.xxx.zzz') + self.assertFalse(mask.IsValidForDescriptor(msg_descriptor)) + # Invalid field in leaf + mask = field_mask_pb2.FieldMask() + mask.paths.append('optional_nested_message.xxx') + self.assertFalse(mask.IsValidForDescriptor(msg_descriptor)) def testCanonicalFrom(self): mask = field_mask_pb2.FieldMask() @@ -422,6 +471,9 @@ class FieldMaskTest(unittest.TestCase): mask2.FromJsonString('foo.bar,bar') out_mask.Union(mask1, mask2) self.assertEqual('bar,foo.bar,quz', out_mask.ToJsonString()) + src = unittest_pb2.TestAllTypes() + with self.assertRaises(ValueError): + out_mask.Union(src, mask2) def testIntersect(self): mask1 = field_mask_pb2.FieldMask() @@ -546,6 +598,19 @@ class FieldMaskTest(unittest.TestCase): self.assertEqual(1, len(nested_dst.payload.repeated_int32)) self.assertEqual(1234, nested_dst.payload.repeated_int32[0]) + def testMergeErrors(self): + src = unittest_pb2.TestAllTypes() + dst = unittest_pb2.TestAllTypes() + mask = field_mask_pb2.FieldMask() + test_util.SetAllFields(src) + mask.FromJsonString('optionalInt32.field') + with self.assertRaises(ValueError) as e: + mask.MergeMessage(src, dst) + self.assertEqual('Error: Field optional_int32 in message ' + 'protobuf_unittest.TestAllTypes is not a singular ' + 'message field and cannot have sub-fields.', + str(e.exception)) + def testSnakeCaseToCamelCase(self): self.assertEqual('fooBar', well_known_types._SnakeCaseToCamelCase('foo_bar')) @@ -611,6 +676,8 @@ class StructTest(unittest.TestCase): struct_list = struct.get_or_create_list('key5') struct_list.extend([6, 'seven', True, False, None]) struct_list.add_struct()['subkey2'] = 9 + struct['key6'] = {'subkey': {}} + struct['key7'] = [2, False] self.assertTrue(isinstance(struct, well_known_types.Struct)) self.assertEqual(5, struct['key1']) @@ -621,9 +688,10 @@ class StructTest(unittest.TestCase): inner_struct['subkey2'] = 9 self.assertEqual([6, 'seven', True, False, None, inner_struct], list(struct['key5'].items())) + self.assertEqual({}, dict(struct['key6']['subkey'].fields)) + self.assertEqual([2, False], list(struct['key7'].items())) serialized = struct.SerializeToString() - struct2 = struct_pb2.Struct() struct2.ParseFromString(serialized) @@ -651,6 +719,17 @@ class StructTest(unittest.TestCase): struct_list.add_list().extend([1, 'two', True, False, None]) self.assertEqual([1, 'two', True, False, None], list(struct_list[6].items())) + struct_list.extend([{'nested_struct': 30}, ['nested_list', 99], {}, []]) + self.assertEqual(11, len(struct_list.values)) + self.assertEqual(30, struct_list[7]['nested_struct']) + self.assertEqual('nested_list', struct_list[8][0]) + self.assertEqual(99, struct_list[8][1]) + self.assertEqual({}, dict(struct_list[9].fields)) + self.assertEqual([], list(struct_list[10].items())) + struct_list[0] = {'replace': 'set'} + struct_list[1] = ['replace', 'set'] + self.assertEqual('set', struct_list[0]['replace']) + self.assertEqual(['replace', 'set'], list(struct_list[1].items())) text_serialized = str(struct) struct3 = struct_pb2.Struct() @@ -660,6 +739,67 @@ class StructTest(unittest.TestCase): struct.get_or_create_struct('key3')['replace'] = 12 self.assertEqual(12, struct['key3']['replace']) + # Tests empty list. + struct.get_or_create_list('empty_list') + empty_list = struct['empty_list'] + self.assertEqual([], list(empty_list.items())) + list2 = struct_pb2.ListValue() + list2.add_list() + empty_list = list2[0] + self.assertEqual([], list(empty_list.items())) + + # Tests empty struct. + struct.get_or_create_struct('empty_struct') + empty_struct = struct['empty_struct'] + self.assertEqual({}, dict(empty_struct.fields)) + list2.add_struct() + empty_struct = list2[1] + self.assertEqual({}, dict(empty_struct.fields)) + + def testMergeFrom(self): + struct = struct_pb2.Struct() + struct_class = struct.__class__ + + dictionary = { + 'key1': 5, + 'key2': 'abc', + 'key3': True, + 'key4': {'subkey': 11.0}, + 'key5': [6, 'seven', True, False, None, {'subkey2': 9}], + 'key6': [['nested_list', True]], + 'empty_struct': {}, + 'empty_list': [] + } + struct.update(dictionary) + self.assertEqual(5, struct['key1']) + self.assertEqual('abc', struct['key2']) + self.assertIs(True, struct['key3']) + self.assertEqual(11, struct['key4']['subkey']) + inner_struct = struct_class() + inner_struct['subkey2'] = 9 + self.assertEqual([6, 'seven', True, False, None, inner_struct], + list(struct['key5'].items())) + self.assertEqual(2, len(struct['key6'][0].values)) + self.assertEqual('nested_list', struct['key6'][0][0]) + self.assertEqual(True, struct['key6'][0][1]) + empty_list = struct['empty_list'] + self.assertEqual([], list(empty_list.items())) + empty_struct = struct['empty_struct'] + self.assertEqual({}, dict(empty_struct.fields)) + + # According to documentation: "When parsing from the wire or when merging, + # if there are duplicate map keys the last key seen is used". + duplicate = { + 'key4': {'replace': 20}, + 'key5': [[False, 5]] + } + struct.update(duplicate) + self.assertEqual(1, len(struct['key4'].fields)) + self.assertEqual(20, struct['key4']['replace']) + self.assertEqual(1, len(struct['key5'].values)) + self.assertEqual(False, struct['key5'][0][0]) + self.assertEqual(5, struct['key5'][0][1]) + class AnyTest(unittest.TestCase): diff --git a/python/google/protobuf/pyext/map_container.cc b/python/google/protobuf/pyext/map_container.cc index 43be0701..abd15b77 100644 --- a/python/google/protobuf/pyext/map_container.cc +++ b/python/google/protobuf/pyext/map_container.cc @@ -720,14 +720,17 @@ int MapReflectionFriend::MessageMapSetItem(PyObject* _self, PyObject* key, map_key, &value); ScopedPyObjectPtr key(PyLong_FromVoidPtr(value.MutableMessageValue())); - // PyDict_DelItem will have key error if the key is not in the map. We do - // not want to call PyErr_Clear() which may clear other errors. Thus - // PyDict_Contains() check is called before delete. - int contains = PyDict_Contains(self->message_dict, key.get()); - if (contains < 0) { - return -1; - } - if (contains) { + PyObject* cmsg_value = PyDict_GetItem(self->message_dict, key.get()); + if (cmsg_value) { + // Need to keep CMessage stay alive if it is still referenced after + // deletion. Makes a new message and swaps values into CMessage + // instead of just removing. + CMessage* cmsg = reinterpret_cast<CMessage*>(cmsg_value); + Message* msg = cmsg->message; + cmsg->owner.reset(msg->New()); + cmsg->message = cmsg->owner.get(); + cmsg->parent = NULL; + msg->GetReflection()->Swap(msg, cmsg->message); if (PyDict_DelItem(self->message_dict, key.get()) < 0) { return -1; } diff --git a/python/google/protobuf/pyext/message.cc b/python/google/protobuf/pyext/message.cc index 702c5d03..0f54506b 100644 --- a/python/google/protobuf/pyext/message.cc +++ b/python/google/protobuf/pyext/message.cc @@ -67,7 +67,6 @@ #include <google/protobuf/pyext/message_factory.h> #include <google/protobuf/pyext/safe_numerics.h> #include <google/protobuf/pyext/scoped_pyobject_ptr.h> -#include <google/protobuf/stubs/strutil.h> #if PY_MAJOR_VERSION >= 3 #define PyInt_AsLong PyLong_AsLong @@ -102,6 +101,17 @@ namespace message_meta { static int InsertEmptyWeakref(PyTypeObject* base); +namespace { +// Copied oveer from internal 'google/protobuf/stubs/strutil.h'. +inline void UpperString(string * s) { + string::iterator end = s->end(); + for (string::iterator i = s->begin(); i != end; ++i) { + // toupper() changes based on locale. We don't want this! + if ('a' <= *i && *i <= 'z') *i += 'A' - 'a'; + } +} +} + // Add the number of a field descriptor to the containing message class. // Equivalent to: // _cls.<field>_FIELD_NUMBER = <number> @@ -595,19 +605,21 @@ void OutOfRangeError(PyObject* arg) { template<class RangeType, class ValueType> bool VerifyIntegerCastAndRange(PyObject* arg, ValueType value) { - if GOOGLE_PREDICT_FALSE(value == -1 && PyErr_Occurred()) { - if (PyErr_ExceptionMatches(PyExc_OverflowError)) { - // Replace it with the same ValueError as pure python protos instead of - // the default one. - PyErr_Clear(); + if + GOOGLE_PREDICT_FALSE(value == -1 && PyErr_Occurred()) { + if (PyErr_ExceptionMatches(PyExc_OverflowError)) { + // Replace it with the same ValueError as pure python protos instead of + // the default one. + PyErr_Clear(); + OutOfRangeError(arg); + } // Otherwise propagate existing error. + return false; + } + if + GOOGLE_PREDICT_FALSE(!IsValidNumericCast<RangeType>(value)) { OutOfRangeError(arg); - } // Otherwise propagate existing error. - return false; - } - if GOOGLE_PREDICT_FALSE(!IsValidNumericCast<RangeType>(value)) { - OutOfRangeError(arg); - return false; - } + return false; + } return true; } @@ -616,25 +628,29 @@ bool CheckAndGetInteger(PyObject* arg, T* value) { // The fast path. #if PY_MAJOR_VERSION < 3 // For the typical case, offer a fast path. - if GOOGLE_PREDICT_TRUE(PyInt_Check(arg)) { - long int_result = PyInt_AsLong(arg); - if GOOGLE_PREDICT_TRUE(IsValidNumericCast<T>(int_result)) { - *value = static_cast<T>(int_result); - return true; - } else { - OutOfRangeError(arg); - return false; + if + GOOGLE_PREDICT_TRUE(PyInt_Check(arg)) { + long int_result = PyInt_AsLong(arg); + if + GOOGLE_PREDICT_TRUE(IsValidNumericCast<T>(int_result)) { + *value = static_cast<T>(int_result); + return true; + } + else { + OutOfRangeError(arg); + return false; + } } - } #endif // This effectively defines an integer as "an object that can be cast as // an integer and can be used as an ordinal number". // This definition includes everything that implements numbers.Integral // and shouldn't cast the net too wide. - if GOOGLE_PREDICT_FALSE(!PyIndex_Check(arg)) { - FormatTypeError(arg, "int, long"); - return false; - } + if + GOOGLE_PREDICT_FALSE(!PyIndex_Check(arg)) { + FormatTypeError(arg, "int, long"); + return false; + } // Now we have an integral number so we can safely use PyLong_ functions. // We need to treat the signed and unsigned cases differently in case arg is @@ -648,10 +664,11 @@ bool CheckAndGetInteger(PyObject* arg, T* value) { // Unlike PyLong_AsLongLong, PyLong_AsUnsignedLongLong is very // picky about the exact type. PyObject* casted = PyNumber_Long(arg); - if GOOGLE_PREDICT_FALSE(casted == NULL) { - // Propagate existing error. - return false; - } + if + GOOGLE_PREDICT_FALSE(casted == NULL) { + // Propagate existing error. + return false; + } ulong_result = PyLong_AsUnsignedLongLong(casted); Py_DECREF(casted); } @@ -673,10 +690,11 @@ bool CheckAndGetInteger(PyObject* arg, T* value) { // Valid subclasses of numbers.Integral should have a __long__() method // so fall back to that. PyObject* casted = PyNumber_Long(arg); - if GOOGLE_PREDICT_FALSE(casted == NULL) { - // Propagate existing error. - return false; - } + if + GOOGLE_PREDICT_FALSE(casted == NULL) { + // Propagate existing error. + return false; + } long_result = PyLong_AsLongLong(casted); Py_DECREF(casted); } @@ -699,10 +717,11 @@ template bool CheckAndGetInteger<uint64>(PyObject*, uint64*); bool CheckAndGetDouble(PyObject* arg, double* value) { *value = PyFloat_AsDouble(arg); - if GOOGLE_PREDICT_FALSE(*value == -1 && PyErr_Occurred()) { - FormatTypeError(arg, "int, long, float"); - return false; - } + if + GOOGLE_PREDICT_FALSE(*value == -1 && PyErr_Occurred()) { + FormatTypeError(arg, "int, long, float"); + return false; + } return true; } @@ -1543,20 +1562,7 @@ PyObject* HasField(CMessage* self, PyObject* arg) { if (message->GetReflection()->HasField(*message, field_descriptor)) { Py_RETURN_TRUE; } - if (!message->GetReflection()->SupportsUnknownEnumValues() && - field_descriptor->cpp_type() == FieldDescriptor::CPPTYPE_ENUM) { - // Special case: Python HasField() differs in semantics from C++ - // slightly: we return HasField('enum_field') == true if there is - // an unknown enum value present. To implement this we have to - // look in the UnknownFieldSet. - const UnknownFieldSet& unknown_field_set = - message->GetReflection()->GetUnknownFields(*message); - for (int i = 0; i < unknown_field_set.field_count(); ++i) { - if (unknown_field_set.field(i).number() == field_descriptor->number()) { - Py_RETURN_TRUE; - } - } - } + Py_RETURN_FALSE; } @@ -1735,12 +1741,6 @@ PyObject* ClearFieldByDescriptor( AssureWritable(self); Message* message = self->message; message->GetReflection()->ClearField(message, field_descriptor); - if (field_descriptor->cpp_type() == FieldDescriptor::CPPTYPE_ENUM && - !message->GetReflection()->SupportsUnknownEnumValues()) { - UnknownFieldSet* unknown_field_set = - message->GetReflection()->MutableUnknownFields(message); - unknown_field_set->DeleteByNumber(field_descriptor->number()); - } Py_RETURN_NONE; } @@ -2335,27 +2335,9 @@ PyObject* InternalGetScalar(const Message* message, break; } case FieldDescriptor::CPPTYPE_ENUM: { - if (!message->GetReflection()->SupportsUnknownEnumValues() && - !message->GetReflection()->HasField(*message, field_descriptor)) { - // Look for the value in the unknown fields. - const UnknownFieldSet& unknown_field_set = - message->GetReflection()->GetUnknownFields(*message); - for (int i = 0; i < unknown_field_set.field_count(); ++i) { - if (unknown_field_set.field(i).number() == - field_descriptor->number() && - unknown_field_set.field(i).type() == - google::protobuf::UnknownField::TYPE_VARINT) { - result = PyInt_FromLong(unknown_field_set.field(i).varint()); - break; - } - } - } - - if (result == NULL) { - const EnumValueDescriptor* enum_value = - message->GetReflection()->GetEnum(*message, field_descriptor); - result = PyInt_FromLong(enum_value->number()); - } + const EnumValueDescriptor* enum_value = + message->GetReflection()->GetEnum(*message, field_descriptor); + result = PyInt_FromLong(enum_value->number()); break; } default: @@ -3077,5 +3059,4 @@ bool InitProto2MessageModule(PyObject *m) { } // namespace python } // namespace protobuf - } // namespace google diff --git a/python/google/protobuf/pyext/message_module.cc b/python/google/protobuf/pyext/message_module.cc index d90d9de3..7c4df47f 100644 --- a/python/google/protobuf/pyext/message_module.cc +++ b/python/google/protobuf/pyext/message_module.cc @@ -28,8 +28,33 @@ // (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE // OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. +#include <Python.h> + #include <google/protobuf/pyext/message.h> +#include <google/protobuf/message_lite.h> + +static PyObject* GetPythonProto3PreserveUnknownsDefault( + PyObject* /*m*/, PyObject* /*args*/) { + if (google::protobuf::internal::GetProto3PreserveUnknownsDefault()) { + Py_RETURN_TRUE; + } else { + Py_RETURN_FALSE; + } +} + +static PyObject* SetPythonProto3PreserveUnknownsDefault( + PyObject* /*m*/, PyObject* arg) { + if (!arg || !PyBool_Check(arg)) { + PyErr_SetString( + PyExc_TypeError, + "Argument to SetPythonProto3PreserveUnknownsDefault must be boolean"); + return NULL; + } + google::protobuf::internal::SetProto3PreserveUnknownsDefault(PyObject_IsTrue(arg)); + Py_RETURN_NONE; +} + static const char module_docstring[] = "python-proto2 is a module that can be used to enhance proto2 Python API\n" "performance.\n" @@ -41,6 +66,14 @@ static PyMethodDef ModuleMethods[] = { {"SetAllowOversizeProtos", (PyCFunction)google::protobuf::python::cmessage::SetAllowOversizeProtos, METH_O, "Enable/disable oversize proto parsing."}, + // DO NOT USE: For migration and testing only. + {"GetPythonProto3PreserveUnknownsDefault", + (PyCFunction)GetPythonProto3PreserveUnknownsDefault, + METH_NOARGS, "Get Proto3 preserve unknowns default."}, + // DO NOT USE: For migration and testing only. + {"SetPythonProto3PreserveUnknownsDefault", + (PyCFunction)SetPythonProto3PreserveUnknownsDefault, + METH_O, "Enable/disable proto3 unknowns preservation."}, { NULL, NULL} }; diff --git a/python/google/protobuf/pyext/python.proto b/python/google/protobuf/pyext/python.proto index cce645d7..2e50df74 100644 --- a/python/google/protobuf/pyext/python.proto +++ b/python/google/protobuf/pyext/python.proto @@ -58,11 +58,11 @@ message ForeignMessage { repeated int32 d = 2; } -message TestAllExtensions { +message TestAllExtensions { // extension begin extensions 1 to max; -} +} // extension end -extend TestAllExtensions { +extend TestAllExtensions { // extension begin optional TestAllTypes.NestedMessage optional_nested_message_extension = 1; repeated TestAllTypes.NestedMessage repeated_nested_message_extension = 2; -} +} // extension end diff --git a/python/google/protobuf/pyext/repeated_scalar_container.cc b/python/google/protobuf/pyext/repeated_scalar_container.cc index 54998800..5a7832cd 100644 --- a/python/google/protobuf/pyext/repeated_scalar_container.cc +++ b/python/google/protobuf/pyext/repeated_scalar_container.cc @@ -261,22 +261,6 @@ static PyObject* Item(RepeatedScalarContainer* self, Py_ssize_t index) { result = ToStringObject(field_descriptor, value); break; } - case FieldDescriptor::CPPTYPE_MESSAGE: { - PyObject* py_cmsg = PyObject_CallObject(reinterpret_cast<PyObject*>( - &CMessage_Type), NULL); - if (py_cmsg == NULL) { - return NULL; - } - CMessage* cmsg = reinterpret_cast<CMessage*>(py_cmsg); - const Message& msg = reflection->GetRepeatedMessage( - *message, field_descriptor, index); - cmsg->owner = self->owner; - cmsg->parent = self->parent; - cmsg->message = const_cast<Message*>(&msg); - cmsg->read_only = false; - result = reinterpret_cast<PyObject*>(py_cmsg); - break; - } default: PyErr_Format( PyExc_SystemError, diff --git a/python/mox.py b/python/mox.py index 257468e5..43db0219 100755 --- a/python/mox.py +++ b/python/mox.py @@ -778,7 +778,7 @@ class Comparator: rhs: any python object """ - raise NotImplementedError, 'method must be implemented by a subclass.' + raise NotImplementedError('method must be implemented by a subclass.') def __eq__(self, rhs): return self.equals(rhs) diff --git a/python/setup.py b/python/setup.py index 70b7de5c..efb74fe7 100755 --- a/python/setup.py +++ b/python/setup.py @@ -78,6 +78,7 @@ def generate_proto(source, require = True): def GenerateUnittestProtos(): generate_proto("../src/google/protobuf/any_test.proto", False) + generate_proto("../src/google/protobuf/map_proto2_unittest.proto", False) generate_proto("../src/google/protobuf/map_unittest.proto", False) generate_proto("../src/google/protobuf/test_messages_proto3.proto", False) generate_proto("../src/google/protobuf/test_messages_proto2.proto", False) diff --git a/python/tox.ini b/python/tox.ini index baa96dba..38a81b4f 100644 --- a/python/tox.ini +++ b/python/tox.ini @@ -1,6 +1,6 @@ [tox] envlist = - py{26,27,33,34}-{cpp,python} + py{27,33,34,35,36}-{cpp,python} [testenv] usedevelop=true |