aboutsummaryrefslogtreecommitdiffhomepage
path: root/python
diff options
context:
space:
mode:
authorGravatar Adam Cozzette <acozzette@google.com>2017-09-12 10:32:01 -0700
committerGravatar Adam Cozzette <acozzette@google.com>2017-09-14 10:03:57 -0700
commit13fd045dbb2b4dacea32be162a41d5a4b0d1802f (patch)
treec219e7eb18b82523e36c6748861c403a14ea66ae /python
parentd1bc27caef8377a710370189675cb0958443e8f1 (diff)
Integrated internal changes from Google
Diffstat (limited to 'python')
-rw-r--r--python/google/protobuf/descriptor_database.py13
-rw-r--r--python/google/protobuf/descriptor_pool.py27
-rwxr-xr-xpython/google/protobuf/internal/api_implementation.py39
-rw-r--r--python/google/protobuf/internal/descriptor_database_test.py34
-rw-r--r--python/google/protobuf/internal/descriptor_pool_test.py272
-rwxr-xr-xpython/google/protobuf/internal/descriptor_test.py299
-rwxr-xr-xpython/google/protobuf/internal/encoder.py2
-rw-r--r--python/google/protobuf/internal/json_format_test.py60
-rw-r--r--python/google/protobuf/internal/message_factory_test.py16
-rwxr-xr-xpython/google/protobuf/internal/message_test.py208
-rwxr-xr-xpython/google/protobuf/internal/python_message.py37
-rwxr-xr-xpython/google/protobuf/internal/reflection_test.py66
-rwxr-xr-xpython/google/protobuf/internal/service_reflection_test.py4
-rwxr-xr-xpython/google/protobuf/internal/test_util.py54
-rwxr-xr-xpython/google/protobuf/internal/text_format_test.py60
-rwxr-xr-xpython/google/protobuf/internal/unknown_fields_test.py158
-rw-r--r--python/google/protobuf/internal/well_known_types.py28
-rw-r--r--python/google/protobuf/internal/well_known_types_test.py142
-rw-r--r--python/google/protobuf/pyext/map_container.cc19
-rw-r--r--python/google/protobuf/pyext/message.cc129
-rw-r--r--python/google/protobuf/pyext/message_module.cc33
-rw-r--r--python/google/protobuf/pyext/python.proto8
-rw-r--r--python/google/protobuf/pyext/repeated_scalar_container.cc16
-rwxr-xr-xpython/setup.py1
24 files changed, 1316 insertions, 409 deletions
diff --git a/python/google/protobuf/descriptor_database.py b/python/google/protobuf/descriptor_database.py
index eb45e127..b8f5140b 100644
--- a/python/google/protobuf/descriptor_database.py
+++ b/python/google/protobuf/descriptor_database.py
@@ -107,6 +107,7 @@ class DescriptorDatabase(object):
'some.package.name.Message'
'some.package.name.Message.NestedEnum'
+ 'some.package.name.Message.some_field'
The file descriptor proto containing the specified symbol must be added to
this database using the Add method or else an error will be raised.
@@ -120,8 +121,16 @@ class DescriptorDatabase(object):
Raises:
KeyError if no file contains the specified symbol.
"""
-
- return self._file_desc_protos_by_symbol[symbol]
+ try:
+ return self._file_desc_protos_by_symbol[symbol]
+ except KeyError:
+ # Fields, enum values, and nested extensions are not in
+ # _file_desc_protos_by_symbol. Try to find the top level
+ # descriptor. Non-existent nested symbol under a valid top level
+ # descriptor can also be found. The behavior is the same with
+ # protobuf C++.
+ top_level, _, _ = symbol.rpartition('.')
+ return self._file_desc_protos_by_symbol[top_level]
def _ExtractSymbols(desc_proto, package):
diff --git a/python/google/protobuf/descriptor_pool.py b/python/google/protobuf/descriptor_pool.py
index 3dbe0fd0..cb7146b6 100644
--- a/python/google/protobuf/descriptor_pool.py
+++ b/python/google/protobuf/descriptor_pool.py
@@ -127,9 +127,6 @@ class DescriptorPool(object):
self._service_descriptors = {}
self._file_descriptors = {}
self._toplevel_extensions = {}
- # TODO(jieluo): Remove _file_desc_by_toplevel_extension when
- # FieldDescriptor.file is added in code gen.
- self._file_desc_by_toplevel_extension = {}
# We store extensions in two two-level mappings: The first key is the
# descriptor of the message being extended, the second key is the extension
# full name or its tag number.
@@ -255,11 +252,6 @@ class DescriptorPool(object):
"""
self._AddFileDescriptor(file_desc)
- # TODO(jieluo): This is a temporary solution for FieldDescriptor.file.
- # Remove it when FieldDescriptor.file is added in code gen.
- for extension in file_desc.extensions_by_name.values():
- self._file_desc_by_toplevel_extension[
- extension.full_name] = file_desc
def _AddFileDescriptor(self, file_desc):
"""Adds a FileDescriptor to the pool, non-recursively.
@@ -339,7 +331,7 @@ class DescriptorPool(object):
pass
try:
- return self._file_desc_by_toplevel_extension[symbol]
+ return self._toplevel_extensions[symbol].file
except KeyError:
pass
@@ -405,6 +397,23 @@ class DescriptorPool(object):
message_descriptor = self.FindMessageTypeByName(message_name)
return message_descriptor.fields_by_name[field_name]
+ def FindOneofByName(self, full_name):
+ """Loads the named oneof descriptor from the pool.
+
+ Args:
+ full_name: The full name of the oneof descriptor to load.
+
+ Returns:
+ The oneof descriptor for the named oneof.
+
+ Raises:
+ KeyError: if the oneof cannot be found in the pool.
+ """
+ full_name = _NormalizeFullyQualifiedName(full_name)
+ message_name, _, oneof_name = full_name.rpartition('.')
+ message_descriptor = self.FindMessageTypeByName(message_name)
+ return message_descriptor.oneofs_by_name[oneof_name]
+
def FindExtensionByName(self, full_name):
"""Loads the named extension descriptor from the pool.
diff --git a/python/google/protobuf/internal/api_implementation.py b/python/google/protobuf/internal/api_implementation.py
index 422af590..bce71bb8 100755
--- a/python/google/protobuf/internal/api_implementation.py
+++ b/python/google/protobuf/internal/api_implementation.py
@@ -61,10 +61,15 @@ if _api_version < 0: # Still unspecified?
del _use_fast_cpp_protos
_api_version = 2
except ImportError:
- if _proto_extension_modules_exist_in_build:
- if sys.version_info[0] >= 3: # Python 3 defaults to C++ impl v2.
- _api_version = 2
- # TODO(b/17427486): Make Python 2 default to C++ impl v2.
+ try:
+ # pylint: disable=g-import-not-at-top
+ from google.protobuf.internal import use_pure_python
+ del use_pure_python # Avoids a pylint error and namespace pollution.
+ except ImportError:
+ if _proto_extension_modules_exist_in_build:
+ if sys.version_info[0] >= 3: # Python 3 defaults to C++ impl v2.
+ _api_version = 2
+ # TODO(b/17427486): Make Python 2 default to C++ impl v2.
_default_implementation_type = (
'python' if _api_version <= 0 else 'cpp')
@@ -137,3 +142,29 @@ def Version():
# For internal use only
def IsPythonDefaultSerializationDeterministic():
return _python_deterministic_proto_serialization
+
+# DO NOT USE: For migration and testing only. Will be removed when Proto3
+# defaults to preserve unknowns.
+if _implementation_type == 'cpp':
+ try:
+ # pylint: disable=g-import-not-at-top
+ from google.protobuf.pyext import _message
+
+ def GetPythonProto3PreserveUnknownsDefault():
+ return _message.GetPythonProto3PreserveUnknownsDefault()
+
+ def SetPythonProto3PreserveUnknownsDefault(preserve):
+ _message.SetPythonProto3PreserveUnknownsDefault(preserve)
+ except ImportError:
+ # Unrecognized cpp implementation. Skipping the unknown fields APIs.
+ pass
+else:
+ _python_proto3_preserve_unknowns_default = False
+
+ def GetPythonProto3PreserveUnknownsDefault():
+ return _python_proto3_preserve_unknowns_default
+
+ def SetPythonProto3PreserveUnknownsDefault(preserve):
+ global _python_proto3_preserve_unknowns_default
+ _python_proto3_preserve_unknowns_default = preserve
+
diff --git a/python/google/protobuf/internal/descriptor_database_test.py b/python/google/protobuf/internal/descriptor_database_test.py
index 5225a458..1f1a3db9 100644
--- a/python/google/protobuf/internal/descriptor_database_test.py
+++ b/python/google/protobuf/internal/descriptor_database_test.py
@@ -39,6 +39,7 @@ try:
except ImportError:
import unittest
+from google.protobuf import unittest_pb2
from google.protobuf import descriptor_pb2
from google.protobuf.internal import factory_test2_pb2
from google.protobuf import descriptor_database
@@ -54,16 +55,49 @@ class DescriptorDatabaseTest(unittest.TestCase):
self.assertEqual(file_desc_proto, db.FindFileByName(
'google/protobuf/internal/factory_test2.proto'))
+ # Can find message type.
self.assertEqual(file_desc_proto, db.FindFileContainingSymbol(
'google.protobuf.python.internal.Factory2Message'))
+ # Can find nested message type.
self.assertEqual(file_desc_proto, db.FindFileContainingSymbol(
'google.protobuf.python.internal.Factory2Message.NestedFactory2Message'))
+ # Can find enum type.
self.assertEqual(file_desc_proto, db.FindFileContainingSymbol(
'google.protobuf.python.internal.Factory2Enum'))
+ # Can find nested enum type.
self.assertEqual(file_desc_proto, db.FindFileContainingSymbol(
'google.protobuf.python.internal.Factory2Message.NestedFactory2Enum'))
self.assertEqual(file_desc_proto, db.FindFileContainingSymbol(
'google.protobuf.python.internal.MessageWithNestedEnumOnly.NestedEnum'))
+ # Can find field.
+ self.assertEqual(file_desc_proto, db.FindFileContainingSymbol(
+ 'google.protobuf.python.internal.Factory2Message.list_field'))
+ # Can find enum value.
+ self.assertEqual(file_desc_proto, db.FindFileContainingSymbol(
+ 'google.protobuf.python.internal.Factory2Enum.FACTORY_2_VALUE_0'))
+ # Can find top level extension.
+ self.assertEqual(file_desc_proto, db.FindFileContainingSymbol(
+ 'google.protobuf.python.internal.another_field'))
+ # Can find nested extension inside a message.
+ self.assertEqual(file_desc_proto, db.FindFileContainingSymbol(
+ 'google.protobuf.python.internal.Factory2Message.one_more_field'))
+
+ # Can find service.
+ file_desc_proto2 = descriptor_pb2.FileDescriptorProto.FromString(
+ unittest_pb2.DESCRIPTOR.serialized_pb)
+ db.Add(file_desc_proto2)
+ self.assertEqual(file_desc_proto2, db.FindFileContainingSymbol(
+ 'protobuf_unittest.TestService'))
+
+ # Non-existent field under a valid top level symbol can also be
+ # found. The behavior is the same with protobuf C++.
+ self.assertEqual(file_desc_proto2, db.FindFileContainingSymbol(
+ 'protobuf_unittest.TestAllTypes.none_field'))
+
+ self.assertRaises(KeyError,
+ db.FindFileContainingSymbol,
+ 'protobuf_unittest.NoneMessage')
+
if __name__ == '__main__':
unittest.main()
diff --git a/python/google/protobuf/internal/descriptor_pool_test.py b/python/google/protobuf/internal/descriptor_pool_test.py
index 6015e6f8..15c857bb 100644
--- a/python/google/protobuf/internal/descriptor_pool_test.py
+++ b/python/google/protobuf/internal/descriptor_pool_test.py
@@ -60,26 +60,8 @@ from google.protobuf import message_factory
from google.protobuf import symbol_database
-class DescriptorPoolTest(unittest.TestCase):
- def setUp(self):
- # TODO(jieluo): Should make the pool which is created by
- # serialized_pb same with generated pool.
- # TODO(jieluo): More test coverage for the generated pool.
- self.pool = descriptor_pool.DescriptorPool()
- self.factory_test1_fd = descriptor_pb2.FileDescriptorProto.FromString(
- factory_test1_pb2.DESCRIPTOR.serialized_pb)
- self.factory_test2_fd = descriptor_pb2.FileDescriptorProto.FromString(
- factory_test2_pb2.DESCRIPTOR.serialized_pb)
- self.pool.Add(self.factory_test1_fd)
- self.pool.Add(self.factory_test2_fd)
-
- self.pool.Add(descriptor_pb2.FileDescriptorProto.FromString(
- unittest_import_public_pb2.DESCRIPTOR.serialized_pb))
- self.pool.Add(descriptor_pb2.FileDescriptorProto.FromString(
- unittest_import_pb2.DESCRIPTOR.serialized_pb))
- self.pool.Add(descriptor_pb2.FileDescriptorProto.FromString(
- unittest_pb2.DESCRIPTOR.serialized_pb))
+class DescriptorPoolTestBase(object):
def testFindFileByName(self):
name1 = 'google/protobuf/internal/factory_test1.proto'
@@ -137,14 +119,6 @@ class DescriptorPoolTest(unittest.TestCase):
self.assertEqual('google/protobuf/unittest.proto',
file_desc5.name)
- # Tests the generated pool.
- assert descriptor_pool.Default().FindFileContainingSymbol(
- 'google.protobuf.python.internal.Factory2Message.one_more_field')
- assert descriptor_pool.Default().FindFileContainingSymbol(
- 'google.protobuf.python.internal.another_field')
- assert descriptor_pool.Default().FindFileContainingSymbol(
- 'protobuf_unittest.TestService')
-
def testFindFileContainingSymbolFailure(self):
with self.assertRaises(KeyError):
self.pool.FindFileContainingSymbol('Does not exist')
@@ -231,6 +205,27 @@ class DescriptorPoolTest(unittest.TestCase):
msg2.fields_by_name[name].containing_oneof)
self.assertIn(msg2.fields_by_name[name], msg2.oneofs[0].fields)
+ def testFindTypeErrors(self):
+ self.assertRaises(TypeError, self.pool.FindExtensionByNumber, '')
+
+ # TODO(jieluo): Fix python to raise correct errors.
+ if api_implementation.Type() == 'cpp':
+ self.assertRaises(TypeError, self.pool.FindMethodByName, 0)
+ self.assertRaises(KeyError, self.pool.FindMethodByName, '')
+ error_type = TypeError
+ else:
+ error_type = AttributeError
+ self.assertRaises(error_type, self.pool.FindMessageTypeByName, 0)
+ self.assertRaises(error_type, self.pool.FindFieldByName, 0)
+ self.assertRaises(error_type, self.pool.FindExtensionByName, 0)
+ self.assertRaises(error_type, self.pool.FindEnumTypeByName, 0)
+ self.assertRaises(error_type, self.pool.FindOneofByName, 0)
+ self.assertRaises(error_type, self.pool.FindServiceByName, 0)
+ self.assertRaises(error_type, self.pool.FindFileContainingSymbol, 0)
+ if api_implementation.Type() == 'python':
+ error_type = KeyError
+ self.assertRaises(error_type, self.pool.FindFileByName, 0)
+
def testFindMessageTypeByNameFailure(self):
with self.assertRaises(KeyError):
self.pool.FindMessageTypeByName('Does not exist')
@@ -270,6 +265,11 @@ class DescriptorPoolTest(unittest.TestCase):
self.pool.FindEnumTypeByName('Does not exist')
def testFindFieldByName(self):
+ if isinstance(self, SecondaryDescriptorFromDescriptorDB):
+ if api_implementation.Type() == 'cpp':
+ # TODO(jieluo): Fix cpp extension to find field correctly
+ # when descriptor pool is using an underlying database.
+ return
field = self.pool.FindFieldByName(
'google.protobuf.python.internal.Factory1Message.list_value')
self.assertEqual(field.name, 'list_value')
@@ -279,7 +279,24 @@ class DescriptorPoolTest(unittest.TestCase):
with self.assertRaises(KeyError):
self.pool.FindFieldByName('Does not exist')
+ def testFindOneofByName(self):
+ if isinstance(self, SecondaryDescriptorFromDescriptorDB):
+ if api_implementation.Type() == 'cpp':
+ # TODO(jieluo): Fix cpp extension to find oneof correctly
+ # when descriptor pool is using an underlying database.
+ return
+ oneof = self.pool.FindOneofByName(
+ 'google.protobuf.python.internal.Factory2Message.oneof_field')
+ self.assertEqual(oneof.name, 'oneof_field')
+ with self.assertRaises(KeyError):
+ self.pool.FindOneofByName('Does not exist')
+
def testFindExtensionByName(self):
+ if isinstance(self, SecondaryDescriptorFromDescriptorDB):
+ if api_implementation.Type() == 'cpp':
+ # TODO(jieluo): Fix cpp extension to find extension correctly
+ # when descriptor pool is using an underlying database.
+ return
# An extension defined in a message.
extension = self.pool.FindExtensionByName(
'google.protobuf.python.internal.Factory2Message.one_more_field')
@@ -352,6 +369,8 @@ class DescriptorPoolTest(unittest.TestCase):
def testFindService(self):
service = self.pool.FindServiceByName('protobuf_unittest.TestService')
self.assertEqual(service.full_name, 'protobuf_unittest.TestService')
+ with self.assertRaises(KeyError):
+ self.pool.FindServiceByName('Does not exist')
def testUserDefinedDB(self):
db = descriptor_database.DescriptorDatabase()
@@ -361,24 +380,17 @@ class DescriptorPoolTest(unittest.TestCase):
self.testFindMessageTypeByName()
def testAddSerializedFile(self):
+ if isinstance(self, SecondaryDescriptorFromDescriptorDB):
+ if api_implementation.Type() == 'cpp':
+ # Cpp extension cannot call Add on a DescriptorPool
+ # that uses a DescriptorDatabase.
+ # TODO(jieluo): Fix python and cpp extension diff.
+ return
self.pool = descriptor_pool.DescriptorPool()
self.pool.AddSerializedFile(self.factory_test1_fd.SerializeToString())
self.pool.AddSerializedFile(self.factory_test2_fd.SerializeToString())
self.testFindMessageTypeByName()
- def testComplexNesting(self):
- more_messages_desc = descriptor_pb2.FileDescriptorProto.FromString(
- more_messages_pb2.DESCRIPTOR.serialized_pb)
- test1_desc = descriptor_pb2.FileDescriptorProto.FromString(
- descriptor_pool_test1_pb2.DESCRIPTOR.serialized_pb)
- test2_desc = descriptor_pb2.FileDescriptorProto.FromString(
- descriptor_pool_test2_pb2.DESCRIPTOR.serialized_pb)
- self.pool.Add(more_messages_desc)
- self.pool.Add(test1_desc)
- self.pool.Add(test2_desc)
- TEST1_FILE.CheckFile(self, self.pool)
- TEST2_FILE.CheckFile(self, self.pool)
-
def testEnumDefaultValue(self):
"""Test the default value of enums which don't start at zero."""
@@ -397,6 +409,12 @@ class DescriptorPoolTest(unittest.TestCase):
self.assertIs(file_descriptor, descriptor_pool_test1_pb2.DESCRIPTOR)
_CheckDefaultValue(file_descriptor)
+ if isinstance(self, SecondaryDescriptorFromDescriptorDB):
+ if api_implementation.Type() == 'cpp':
+ # Cpp extension cannot call Add on a DescriptorPool
+ # that uses a DescriptorDatabase.
+ # TODO(jieluo): Fix python and cpp extension diff.
+ return
# Then check the dynamic pool and its internal DescriptorDatabase.
descriptor_proto = descriptor_pb2.FileDescriptorProto.FromString(
descriptor_pool_test1_pb2.DESCRIPTOR.serialized_pb)
@@ -444,6 +462,110 @@ class DescriptorPoolTest(unittest.TestCase):
unittest_pb2.TestAllTypes.DESCRIPTOR.full_name))
_CheckDefaultValues(message_class())
+ def testAddFileDescriptor(self):
+ if isinstance(self, SecondaryDescriptorFromDescriptorDB):
+ if api_implementation.Type() == 'cpp':
+ # Cpp extension cannot call Add on a DescriptorPool
+ # that uses a DescriptorDatabase.
+ # TODO(jieluo): Fix python and cpp extension diff.
+ return
+ file_desc = descriptor_pb2.FileDescriptorProto(name='some/file.proto')
+ self.pool.Add(file_desc)
+ self.pool.AddSerializedFile(file_desc.SerializeToString())
+
+ def testComplexNesting(self):
+ if isinstance(self, SecondaryDescriptorFromDescriptorDB):
+ if api_implementation.Type() == 'cpp':
+ # Cpp extension cannot call Add on a DescriptorPool
+ # that uses a DescriptorDatabase.
+ # TODO(jieluo): Fix python and cpp extension diff.
+ return
+ more_messages_desc = descriptor_pb2.FileDescriptorProto.FromString(
+ more_messages_pb2.DESCRIPTOR.serialized_pb)
+ test1_desc = descriptor_pb2.FileDescriptorProto.FromString(
+ descriptor_pool_test1_pb2.DESCRIPTOR.serialized_pb)
+ test2_desc = descriptor_pb2.FileDescriptorProto.FromString(
+ descriptor_pool_test2_pb2.DESCRIPTOR.serialized_pb)
+ self.pool.Add(more_messages_desc)
+ self.pool.Add(test1_desc)
+ self.pool.Add(test2_desc)
+ TEST1_FILE.CheckFile(self, self.pool)
+ TEST2_FILE.CheckFile(self, self.pool)
+
+
+class DefaultDescriptorPoolTest(DescriptorPoolTestBase, unittest.TestCase):
+
+ def setUp(self):
+ self.pool = descriptor_pool.Default()
+ self.factory_test1_fd = descriptor_pb2.FileDescriptorProto.FromString(
+ factory_test1_pb2.DESCRIPTOR.serialized_pb)
+ self.factory_test2_fd = descriptor_pb2.FileDescriptorProto.FromString(
+ factory_test2_pb2.DESCRIPTOR.serialized_pb)
+
+ def testFindMethods(self):
+ self.assertIs(
+ self.pool.FindFileByName('google/protobuf/unittest.proto'),
+ unittest_pb2.DESCRIPTOR)
+ self.assertIs(
+ self.pool.FindMessageTypeByName('protobuf_unittest.TestAllTypes'),
+ unittest_pb2.TestAllTypes.DESCRIPTOR)
+ self.assertIs(
+ self.pool.FindFieldByName(
+ 'protobuf_unittest.TestAllTypes.optional_int32'),
+ unittest_pb2.TestAllTypes.DESCRIPTOR.fields_by_name['optional_int32'])
+ self.assertIs(
+ self.pool.FindEnumTypeByName('protobuf_unittest.ForeignEnum'),
+ unittest_pb2.ForeignEnum.DESCRIPTOR)
+ self.assertIs(
+ self.pool.FindExtensionByName(
+ 'protobuf_unittest.optional_int32_extension'),
+ unittest_pb2.DESCRIPTOR.extensions_by_name['optional_int32_extension'])
+ self.assertIs(
+ self.pool.FindOneofByName('protobuf_unittest.TestAllTypes.oneof_field'),
+ unittest_pb2.TestAllTypes.DESCRIPTOR.oneofs_by_name['oneof_field'])
+ self.assertIs(
+ self.pool.FindServiceByName('protobuf_unittest.TestService'),
+ unittest_pb2.DESCRIPTOR.services_by_name['TestService'])
+
+
+class CreateDescriptorPoolTest(DescriptorPoolTestBase, unittest.TestCase):
+
+ def setUp(self):
+ self.pool = descriptor_pool.DescriptorPool()
+ self.factory_test1_fd = descriptor_pb2.FileDescriptorProto.FromString(
+ factory_test1_pb2.DESCRIPTOR.serialized_pb)
+ self.factory_test2_fd = descriptor_pb2.FileDescriptorProto.FromString(
+ factory_test2_pb2.DESCRIPTOR.serialized_pb)
+ self.pool.Add(self.factory_test1_fd)
+ self.pool.Add(self.factory_test2_fd)
+
+ self.pool.Add(descriptor_pb2.FileDescriptorProto.FromString(
+ unittest_import_public_pb2.DESCRIPTOR.serialized_pb))
+ self.pool.Add(descriptor_pb2.FileDescriptorProto.FromString(
+ unittest_import_pb2.DESCRIPTOR.serialized_pb))
+ self.pool.Add(descriptor_pb2.FileDescriptorProto.FromString(
+ unittest_pb2.DESCRIPTOR.serialized_pb))
+
+
+class SecondaryDescriptorFromDescriptorDB(DescriptorPoolTestBase,
+ unittest.TestCase):
+
+ def setUp(self):
+ self.factory_test1_fd = descriptor_pb2.FileDescriptorProto.FromString(
+ factory_test1_pb2.DESCRIPTOR.serialized_pb)
+ self.factory_test2_fd = descriptor_pb2.FileDescriptorProto.FromString(
+ factory_test2_pb2.DESCRIPTOR.serialized_pb)
+ db = descriptor_database.DescriptorDatabase()
+ db.Add(self.factory_test1_fd)
+ db.Add(self.factory_test2_fd)
+ db.Add(descriptor_pb2.FileDescriptorProto.FromString(
+ unittest_import_public_pb2.DESCRIPTOR.serialized_pb))
+ db.Add(descriptor_pb2.FileDescriptorProto.FromString(
+ unittest_import_pb2.DESCRIPTOR.serialized_pb))
+ db.Add(descriptor_pb2.FileDescriptorProto.FromString(
+ unittest_pb2.DESCRIPTOR.serialized_pb))
+ self.pool = descriptor_pool.DescriptorPool(descriptor_db=db)
+
class ProtoFile(object):
@@ -568,6 +690,11 @@ class MessageField(object):
test.assertEqual(msg_desc, field_desc.containing_type)
test.assertEqual(field_type_desc, field_desc.message_type)
test.assertEqual(file_desc, field_desc.file)
+ # TODO(jieluo): Fix python and cpp extension diff for message field
+ # default value.
+ if api_implementation.Type() == 'cpp':
+ test.assertRaises(
+ NotImplementedError, getattr, field_desc, 'default_value')
class StringField(object):
@@ -739,6 +866,25 @@ class AddDescriptorTest(unittest.TestCase):
'some/file.proto')
self.assertEqual(pool.FindMessageTypeByName('package.Message').name,
'Message')
+ # Test no package
+ file_proto = descriptor_pb2.FileDescriptorProto(
+ name='some/filename/container.proto')
+ message_proto = file_proto.message_type.add(
+ name='TopMessage')
+ message_proto.field.add(
+ name='bb',
+ number=1,
+ type=descriptor_pb2.FieldDescriptorProto.TYPE_INT32,
+ label=descriptor_pb2.FieldDescriptorProto.LABEL_OPTIONAL)
+ enum_proto = file_proto.enum_type.add(name='TopEnum')
+ enum_proto.value.add(name='FOREIGN_FOO', number=4)
+ file_proto.service.add(name='TopService')
+ pool = descriptor_pool.DescriptorPool()
+ pool.Add(file_proto)
+ self.assertEqual('TopMessage',
+ pool.FindMessageTypeByName('TopMessage').name)
+ self.assertEqual('TopEnum', pool.FindEnumTypeByName('TopEnum').name)
+ self.assertEqual('TopService', pool.FindServiceByName('TopService').name)
def testFileDescriptorOptionsWithCustomDescriptorPool(self):
# Create a descriptor pool, and add a new FileDescriptorProto to it.
@@ -757,40 +903,18 @@ class AddDescriptorTest(unittest.TestCase):
# The object returned by GetOptions() is cached.
self.assertIs(options, file_descriptor.GetOptions())
-
-class DefaultPoolTest(unittest.TestCase):
-
- def testFindMethods(self):
- pool = descriptor_pool.Default()
- self.assertIs(
- pool.FindFileByName('google/protobuf/unittest.proto'),
- unittest_pb2.DESCRIPTOR)
- self.assertIs(
- pool.FindMessageTypeByName('protobuf_unittest.TestAllTypes'),
- unittest_pb2.TestAllTypes.DESCRIPTOR)
- self.assertIs(
- pool.FindFieldByName('protobuf_unittest.TestAllTypes.optional_int32'),
- unittest_pb2.TestAllTypes.DESCRIPTOR.fields_by_name['optional_int32'])
- self.assertIs(
- pool.FindEnumTypeByName('protobuf_unittest.ForeignEnum'),
- unittest_pb2.ForeignEnum.DESCRIPTOR)
- if api_implementation.Type() != 'cpp':
- self.skipTest('Only the C++ implementation correctly indexes all types')
- self.assertIs(
- pool.FindExtensionByName('protobuf_unittest.optional_int32_extension'),
- unittest_pb2.DESCRIPTOR.extensions_by_name['optional_int32_extension'])
- self.assertIs(
- pool.FindOneofByName('protobuf_unittest.TestAllTypes.oneof_field'),
- unittest_pb2.TestAllTypes.DESCRIPTOR.oneofs_by_name['oneof_field'])
- self.assertIs(
- pool.FindServiceByName('protobuf_unittest.TestService'),
- unittest_pb2.DESCRIPTOR.services_by_name['TestService'])
-
- def testAddFileDescriptor(self):
- pool = descriptor_pool.Default()
- file_desc = descriptor_pb2.FileDescriptorProto(name='some/file.proto')
- pool.Add(file_desc)
- pool.AddSerializedFile(file_desc.SerializeToString())
+ def testAddTypeError(self):
+ pool = descriptor_pool.DescriptorPool()
+ with self.assertRaises(TypeError):
+ pool.AddDescriptor(0)
+ with self.assertRaises(TypeError):
+ pool.AddEnumDescriptor(0)
+ with self.assertRaises(TypeError):
+ pool.AddServiceDescriptor(0)
+ with self.assertRaises(TypeError):
+ pool.AddExtensionDescriptor(0)
+ with self.assertRaises(TypeError):
+ pool.AddFileDescriptor(0)
TEST1_FILE = ProtoFile(
diff --git a/python/google/protobuf/internal/descriptor_test.py b/python/google/protobuf/internal/descriptor_test.py
index c0010081..f1d5934e 100755
--- a/python/google/protobuf/internal/descriptor_test.py
+++ b/python/google/protobuf/internal/descriptor_test.py
@@ -107,6 +107,12 @@ class DescriptorTest(unittest.TestCase):
self.my_message.enum_types_by_name[
'ForeignEnum'].values_by_number[4].name,
self.my_message.EnumValueName('ForeignEnum', 4))
+ with self.assertRaises(KeyError):
+ self.my_message.EnumValueName('ForeignEnum', 999)
+ with self.assertRaises(KeyError):
+ self.my_message.EnumValueName('NoneEnum', 999)
+ with self.assertRaises(TypeError):
+ self.my_message.EnumValueName()
def testEnumFixups(self):
self.assertEqual(self.my_enum, self.my_enum.values[0].type)
@@ -134,15 +140,17 @@ class DescriptorTest(unittest.TestCase):
def testSimpleCustomOptions(self):
file_descriptor = unittest_custom_options_pb2.DESCRIPTOR
- message_descriptor =\
- unittest_custom_options_pb2.TestMessageWithCustomOptions.DESCRIPTOR
+ message_descriptor = (unittest_custom_options_pb2.
+ TestMessageWithCustomOptions.DESCRIPTOR)
field_descriptor = message_descriptor.fields_by_name['field1']
oneof_descriptor = message_descriptor.oneofs_by_name['AnOneof']
enum_descriptor = message_descriptor.enum_types_by_name['AnEnum']
- enum_value_descriptor =\
- message_descriptor.enum_values_by_name['ANENUM_VAL2']
- service_descriptor =\
- unittest_custom_options_pb2.TestServiceWithCustomOptions.DESCRIPTOR
+ enum_value_descriptor = (message_descriptor.
+ enum_values_by_name['ANENUM_VAL2'])
+ other_enum_value_descriptor = (message_descriptor.
+ enum_values_by_name['ANENUM_VAL1'])
+ service_descriptor = (unittest_custom_options_pb2.
+ TestServiceWithCustomOptions.DESCRIPTOR)
method_descriptor = service_descriptor.FindMethodByName('Foo')
file_options = file_descriptor.GetOptions()
@@ -178,6 +186,11 @@ class DescriptorTest(unittest.TestCase):
unittest_custom_options_pb2.DummyMessageContainingEnum.DESCRIPTOR)
self.assertTrue(file_descriptor.has_options)
self.assertFalse(message_descriptor.has_options)
+ self.assertTrue(field_descriptor.has_options)
+ self.assertTrue(oneof_descriptor.has_options)
+ self.assertTrue(enum_descriptor.has_options)
+ self.assertTrue(enum_value_descriptor.has_options)
+ self.assertFalse(other_enum_value_descriptor.has_options)
def testDifferentCustomOptionTypes(self):
kint32min = -2**31
@@ -400,6 +413,12 @@ class DescriptorTest(unittest.TestCase):
self.assertEqual(self.my_file.name, 'some/filename/some.proto')
self.assertEqual(self.my_file.package, 'protobuf_unittest')
self.assertEqual(self.my_file.pool, self.pool)
+ self.assertFalse(self.my_file.has_options)
+ self.assertEqual('proto2', self.my_file.syntax)
+ file_proto = descriptor_pb2.FileDescriptorProto()
+ self.my_file.CopyToProto(file_proto)
+ self.assertEqual(self.my_file.serialized_pb,
+ file_proto.SerializeToString())
# Generated modules also belong to the default pool.
self.assertEqual(unittest_pb2.DESCRIPTOR.pool, descriptor_pool.Default())
@@ -407,13 +426,31 @@ class DescriptorTest(unittest.TestCase):
api_implementation.Type() != 'cpp' or api_implementation.Version() != 2,
'Immutability of descriptors is only enforced in v2 implementation')
def testImmutableCppDescriptor(self):
+ file_descriptor = unittest_pb2.DESCRIPTOR
message_descriptor = unittest_pb2.TestAllTypes.DESCRIPTOR
+ field_descriptor = message_descriptor.fields_by_name['optional_int32']
+ enum_descriptor = message_descriptor.enum_types_by_name['NestedEnum']
+ oneof_descriptor = message_descriptor.oneofs_by_name['oneof_field']
with self.assertRaises(AttributeError):
message_descriptor.fields_by_name = None
with self.assertRaises(TypeError):
message_descriptor.fields_by_name['Another'] = None
with self.assertRaises(TypeError):
message_descriptor.fields.append(None)
+ with self.assertRaises(AttributeError):
+ field_descriptor.containing_type = message_descriptor
+ with self.assertRaises(AttributeError):
+ file_descriptor.has_options = False
+ with self.assertRaises(AttributeError):
+ field_descriptor.has_options = False
+ with self.assertRaises(AttributeError):
+ oneof_descriptor.has_options = False
+ with self.assertRaises(AttributeError):
+ enum_descriptor.has_options = False
+ with self.assertRaises(AttributeError) as e:
+ message_descriptor.has_options = True
+ self.assertEqual('attribute is not writable: has_options',
+ str(e.exception))
class NewDescriptorTest(DescriptorTest):
@@ -442,6 +479,12 @@ class GeneratedDescriptorTest(unittest.TestCase):
self.CheckDescriptorMapping(message_descriptor.fields_by_name)
self.CheckDescriptorMapping(message_descriptor.fields_by_number)
self.CheckDescriptorMapping(message_descriptor.fields_by_camelcase_name)
+ self.CheckDescriptorMapping(message_descriptor.enum_types_by_name)
+ self.CheckDescriptorMapping(message_descriptor.enum_values_by_name)
+ self.CheckDescriptorMapping(message_descriptor.oneofs_by_name)
+ self.CheckDescriptorMapping(message_descriptor.enum_types[0].values_by_name)
+ # Test extension range
+ self.assertEqual(message_descriptor.extension_ranges, [])
def CheckFieldDescriptor(self, field_descriptor):
# Basic properties
@@ -450,6 +493,7 @@ class GeneratedDescriptorTest(unittest.TestCase):
self.assertEqual(field_descriptor.full_name,
'protobuf_unittest.TestAllTypes.optional_int32')
self.assertEqual(field_descriptor.containing_type.name, 'TestAllTypes')
+ self.assertEqual(field_descriptor.file, unittest_pb2.DESCRIPTOR)
# Test equality and hashability
self.assertEqual(field_descriptor, field_descriptor)
self.assertEqual(
@@ -461,32 +505,73 @@ class GeneratedDescriptorTest(unittest.TestCase):
field_descriptor)
self.assertIn(field_descriptor, [field_descriptor])
self.assertIn(field_descriptor, {field_descriptor: None})
+ self.assertEqual(None, field_descriptor.extension_scope)
+ self.assertEqual(None, field_descriptor.enum_type)
+ if api_implementation.Type() == 'cpp':
+ # For test coverage only
+ self.assertEqual(field_descriptor.id, field_descriptor.id)
def CheckDescriptorSequence(self, sequence):
# Verifies that a property like 'messageDescriptor.fields' has all the
# properties of an immutable abc.Sequence.
+ self.assertNotEqual(sequence,
+ unittest_pb2.TestAllExtensions.DESCRIPTOR.fields)
+ self.assertNotEqual(sequence, [])
+ self.assertNotEqual(sequence, 1)
+ self.assertFalse(sequence == 1) # Only for cpp test coverage
+ self.assertEqual(sequence, sequence)
+ expected_list = list(sequence)
+ self.assertEqual(expected_list, sequence)
self.assertGreater(len(sequence), 0) # Sized
- self.assertEqual(len(sequence), len(list(sequence))) # Iterable
+ self.assertEqual(len(sequence), len(expected_list)) # Iterable
+ self.assertEqual(sequence[len(sequence) -1], sequence[-1])
item = sequence[0]
self.assertEqual(item, sequence[0])
self.assertIn(item, sequence) # Container
self.assertEqual(sequence.index(item), 0)
self.assertEqual(sequence.count(item), 1)
+ other_item = unittest_pb2.NestedTestAllTypes.DESCRIPTOR.fields[0]
+ self.assertNotIn(other_item, sequence)
+ self.assertEqual(sequence.count(other_item), 0)
+ self.assertRaises(ValueError, sequence.index, other_item)
+ self.assertRaises(ValueError, sequence.index, [])
reversed_iterator = reversed(sequence)
self.assertEqual(list(reversed_iterator), list(sequence)[::-1])
self.assertRaises(StopIteration, next, reversed_iterator)
+ expected_list[0] = 'change value'
+ self.assertNotEqual(expected_list, sequence)
+ # TODO(jieluo): Change __repr__ support for DescriptorSequence.
+ if api_implementation.Type() == 'python':
+ self.assertEqual(str(list(sequence)), str(sequence))
+ else:
+ self.assertEqual(str(sequence)[0], '<')
def CheckDescriptorMapping(self, mapping):
# Verifies that a property like 'messageDescriptor.fields' has all the
# properties of an immutable abc.Mapping.
+ self.assertNotEqual(
+ mapping, unittest_pb2.TestAllExtensions.DESCRIPTOR.fields_by_name)
+ self.assertNotEqual(mapping, {})
+ self.assertNotEqual(mapping, 1)
+ self.assertFalse(mapping == 1) # Only for cpp test coverage
+ excepted_dict = dict(mapping.items())
+ self.assertEqual(mapping, excepted_dict)
+ self.assertEqual(mapping, mapping)
self.assertGreater(len(mapping), 0) # Sized
- self.assertEqual(len(mapping), len(list(mapping))) # Iterable
+ self.assertEqual(len(mapping), len(excepted_dict)) # Iterable
if sys.version_info >= (3,):
key, item = next(iter(mapping.items()))
else:
key, item = mapping.items()[0]
self.assertIn(key, mapping) # Container
self.assertEqual(mapping.get(key), item)
+ with self.assertRaises(TypeError):
+ mapping.get()
+ # TODO(jieluo): Fix python and cpp extension diff.
+ if api_implementation.Type() == 'python':
+ self.assertRaises(TypeError, mapping.get, [])
+ else:
+ self.assertEqual(None, mapping.get([]))
# keys(), iterkeys() &co
item = (next(iter(mapping.keys())), next(iter(mapping.values())))
self.assertEqual(item, next(iter(mapping.items())))
@@ -497,6 +582,18 @@ class GeneratedDescriptorTest(unittest.TestCase):
CheckItems(mapping.keys(), mapping.iterkeys())
CheckItems(mapping.values(), mapping.itervalues())
CheckItems(mapping.items(), mapping.iteritems())
+ excepted_dict[key] = 'change value'
+ self.assertNotEqual(mapping, excepted_dict)
+ del excepted_dict[key]
+ excepted_dict['new_key'] = 'new'
+ self.assertNotEqual(mapping, excepted_dict)
+ self.assertRaises(KeyError, mapping.__getitem__, 'key_error')
+ self.assertRaises(KeyError, mapping.__getitem__, len(mapping) + 1)
+ # TODO(jieluo): Add __repr__ support for DescriptorMapping.
+ if api_implementation.Type() == 'python':
+ self.assertEqual(len(str(dict(mapping.items()))), len(str(mapping)))
+ else:
+ self.assertEqual(str(mapping)[0], '<')
def testDescriptor(self):
message_descriptor = unittest_pb2.TestAllTypes.DESCRIPTOR
@@ -506,13 +603,26 @@ class GeneratedDescriptorTest(unittest.TestCase):
field_descriptor = message_descriptor.fields_by_camelcase_name[
'optionalInt32']
self.CheckFieldDescriptor(field_descriptor)
+ enum_descriptor = unittest_pb2.DESCRIPTOR.enum_types_by_name[
+ 'ForeignEnum']
+ self.assertEqual(None, enum_descriptor.containing_type)
+ # Test extension range
+ self.assertEqual(
+ unittest_pb2.TestAllExtensions.DESCRIPTOR.extension_ranges,
+ [(1, 536870912)])
+ self.assertEqual(
+ unittest_pb2.TestMultipleExtensionRanges.DESCRIPTOR.extension_ranges,
+ [(42, 43), (4143, 4244), (65536, 536870912)])
def testCppDescriptorContainer(self):
- # Check that the collection is still valid even if the parent disappeared.
- enum = unittest_pb2.TestAllTypes.DESCRIPTOR.enum_types_by_name['NestedEnum']
- values = enum.values
- del enum
- self.assertEqual('FOO', values[0].name)
+ containing_file = unittest_pb2.DESCRIPTOR
+ self.CheckDescriptorSequence(containing_file.dependencies)
+ self.CheckDescriptorMapping(containing_file.message_types_by_name)
+ self.CheckDescriptorMapping(containing_file.enum_types_by_name)
+ self.CheckDescriptorMapping(containing_file.services_by_name)
+ self.CheckDescriptorMapping(containing_file.extensions_by_name)
+ self.CheckDescriptorMapping(
+ unittest_pb2.TestNestedExtension.DESCRIPTOR.extensions_by_name)
def testCppDescriptorContainer_Iterator(self):
# Same test with the iterator
@@ -526,6 +636,18 @@ class GeneratedDescriptorTest(unittest.TestCase):
self.assertEqual(service_descriptor.name, 'TestService')
self.assertEqual(service_descriptor.methods[0].name, 'Foo')
self.assertIs(service_descriptor.file, unittest_pb2.DESCRIPTOR)
+ self.assertEqual(service_descriptor.index, 0)
+ self.CheckDescriptorMapping(service_descriptor.methods_by_name)
+
+ def testOneofDescriptor(self):
+ message_descriptor = unittest_pb2.TestAllTypes.DESCRIPTOR
+ oneof_descriptor = message_descriptor.oneofs_by_name['oneof_field']
+ self.assertFalse(oneof_descriptor.has_options)
+ self.assertEqual(message_descriptor, oneof_descriptor.containing_type)
+ self.assertEqual('oneof_field', oneof_descriptor.name)
+ self.assertEqual('protobuf_unittest.TestAllTypes.oneof_field',
+ oneof_descriptor.full_name)
+ self.assertEqual(0, oneof_descriptor.index)
class DescriptorCopyToProtoTest(unittest.TestCase):
@@ -663,49 +785,64 @@ class DescriptorCopyToProtoTest(unittest.TestCase):
descriptor_pb2.DescriptorProto,
TEST_MESSAGE_WITH_SEVERAL_EXTENSIONS_ASCII)
- # Disable this test so we can make changes to the proto file.
- # TODO(xiaofeng): Enable this test after cl/55530659 is submitted.
- #
- # def testCopyToProto_FileDescriptor(self):
- # UNITTEST_IMPORT_FILE_DESCRIPTOR_ASCII = ("""
- # name: 'google/protobuf/unittest_import.proto'
- # package: 'protobuf_unittest_import'
- # dependency: 'google/protobuf/unittest_import_public.proto'
- # message_type: <
- # name: 'ImportMessage'
- # field: <
- # name: 'd'
- # number: 1
- # label: 1 # Optional
- # type: 5 # TYPE_INT32
- # >
- # >
- # """ +
- # """enum_type: <
- # name: 'ImportEnum'
- # value: <
- # name: 'IMPORT_FOO'
- # number: 7
- # >
- # value: <
- # name: 'IMPORT_BAR'
- # number: 8
- # >
- # value: <
- # name: 'IMPORT_BAZ'
- # number: 9
- # >
- # >
- # options: <
- # java_package: 'com.google.protobuf.test'
- # optimize_for: 1 # SPEED
- # >
- # public_dependency: 0
- # """)
- # self._InternalTestCopyToProto(
- # unittest_import_pb2.DESCRIPTOR,
- # descriptor_pb2.FileDescriptorProto,
- # UNITTEST_IMPORT_FILE_DESCRIPTOR_ASCII)
+ def testCopyToProto_FileDescriptor(self):
+ UNITTEST_IMPORT_FILE_DESCRIPTOR_ASCII = ("""
+ name: 'google/protobuf/unittest_import.proto'
+ package: 'protobuf_unittest_import'
+ dependency: 'google/protobuf/unittest_import_public.proto'
+ message_type: <
+ name: 'ImportMessage'
+ field: <
+ name: 'd'
+ number: 1
+ label: 1 # Optional
+ type: 5 # TYPE_INT32
+ >
+ >
+ """ +
+ """enum_type: <
+ name: 'ImportEnum'
+ value: <
+ name: 'IMPORT_FOO'
+ number: 7
+ >
+ value: <
+ name: 'IMPORT_BAR'
+ number: 8
+ >
+ value: <
+ name: 'IMPORT_BAZ'
+ number: 9
+ >
+ >
+ enum_type: <
+ name: 'ImportEnumForMap'
+ value: <
+ name: 'UNKNOWN'
+ number: 0
+ >
+ value: <
+ name: 'FOO'
+ number: 1
+ >
+ value: <
+ name: 'BAR'
+ number: 2
+ >
+ >
+ options: <
+ java_package: 'com.google.protobuf.test'
+ optimize_for: 1 # SPEED
+ """ +
+ """
+ cc_enable_arenas: true
+ >
+ public_dependency: 0
+ """)
+ self._InternalTestCopyToProto(
+ unittest_import_pb2.DESCRIPTOR,
+ descriptor_pb2.FileDescriptorProto,
+ UNITTEST_IMPORT_FILE_DESCRIPTOR_ASCII)
def testCopyToProto_ServiceDescriptor(self):
TEST_SERVICE_ASCII = """
@@ -721,12 +858,47 @@ class DescriptorCopyToProtoTest(unittest.TestCase):
output_type: '.protobuf_unittest.BarResponse'
>
"""
- # TODO(rocking): enable this test after the proto descriptor change is
- # checked in.
- #self._InternalTestCopyToProto(
- # unittest_pb2.TestService.DESCRIPTOR,
- # descriptor_pb2.ServiceDescriptorProto,
- # TEST_SERVICE_ASCII)
+ self._InternalTestCopyToProto(
+ unittest_pb2.TestService.DESCRIPTOR,
+ descriptor_pb2.ServiceDescriptorProto,
+ TEST_SERVICE_ASCII)
+
+ @unittest.skipIf(
+ api_implementation.Type() == 'python',
+ 'It is not implemented in python.')
+ # TODO(jieluo): Add support for pure python or remove in c extension.
+ def testCopyToProto_MethodDescriptor(self):
+ expected_ascii = """
+ name: 'Foo'
+ input_type: '.protobuf_unittest.FooRequest'
+ output_type: '.protobuf_unittest.FooResponse'
+ """
+ method_descriptor = unittest_pb2.TestService.DESCRIPTOR.FindMethodByName(
+ 'Foo')
+ self._InternalTestCopyToProto(
+ method_descriptor,
+ descriptor_pb2.MethodDescriptorProto,
+ expected_ascii)
+
+ @unittest.skipIf(
+ api_implementation.Type() == 'python',
+ 'Pure python does not raise error.')
+ # TODO(jieluo): Fix pure python to check with the proto type.
+ def testCopyToProto_TypeError(self):
+ file_proto = descriptor_pb2.FileDescriptorProto()
+ self.assertRaises(TypeError,
+ unittest_pb2.TestEmptyMessage.DESCRIPTOR.CopyToProto,
+ file_proto)
+ self.assertRaises(TypeError,
+ unittest_pb2.ForeignEnum.DESCRIPTOR.CopyToProto,
+ file_proto)
+ self.assertRaises(TypeError,
+ unittest_pb2.TestService.DESCRIPTOR.CopyToProto,
+ file_proto)
+ proto = descriptor_pb2.DescriptorProto()
+ self.assertRaises(TypeError,
+ unittest_import_pb2.DESCRIPTOR.CopyToProto,
+ proto)
class MakeDescriptorTest(unittest.TestCase):
@@ -774,6 +946,9 @@ class MakeDescriptorTest(unittest.TestCase):
result.nested_types[0].enum_types[0])
self.assertFalse(result.has_options)
self.assertFalse(result.fields[0].has_options)
+ if api_implementation.Type() == 'cpp':
+ with self.assertRaises(AttributeError):
+ result.fields[0].has_options = False
def testMakeDescriptorWithUnsignedIntField(self):
file_descriptor_proto = descriptor_pb2.FileDescriptorProto()
diff --git a/python/google/protobuf/internal/encoder.py b/python/google/protobuf/internal/encoder.py
index ebec42e5..8c6a1189 100755
--- a/python/google/protobuf/internal/encoder.py
+++ b/python/google/protobuf/internal/encoder.py
@@ -819,7 +819,7 @@ def MapEncoder(field_descriptor):
encode_message = MessageEncoder(field_descriptor.number, False, False)
def EncodeField(write, value, deterministic):
- value_keys = sorted(value.keys()) if deterministic else value.keys()
+ value_keys = sorted(value.keys()) if deterministic else value
for key in value_keys:
entry_msg = message_type._concrete_class(key=key, value=value[key])
encode_message(write, entry_msg, deterministic)
diff --git a/python/google/protobuf/internal/json_format_test.py b/python/google/protobuf/internal/json_format_test.py
index 077b64db..b2cf7622 100644
--- a/python/google/protobuf/internal/json_format_test.py
+++ b/python/google/protobuf/internal/json_format_test.py
@@ -159,6 +159,16 @@ class JsonFormatTest(JsonFormatBase):
json_format.Parse(text, parsed_message)
self.assertEqual(message, parsed_message)
+ def testUnknownEnumToJsonError(self):
+ message = json_format_proto3_pb2.TestMessage()
+ message.enum_value = 999
+ # TODO(jieluo): should accept numeric unknown enum for proto3.
+ with self.assertRaises(json_format.SerializeToJsonError) as e:
+ json_format.MessageToJson(message)
+ self.assertEqual(str(e.exception),
+ 'Enum field contains an integer value which can '
+ 'not mapped to an enum value.')
+
def testExtensionToJsonAndBack(self):
message = unittest_mset_pb2.TestMessageSetContainer()
ext1 = unittest_mset_pb2.TestMessageSetExtension1.message_set_extension
@@ -172,6 +182,10 @@ class JsonFormatTest(JsonFormatBase):
json_format.Parse(message_text, parsed_message)
self.assertEqual(message, parsed_message)
+ def testExtensionErrors(self):
+ self.CheckError('{"[extensionField]": {}}',
+ 'Message type proto3.TestMessage does not have extensions')
+
def testExtensionToDictAndBack(self):
message = unittest_mset_pb2.TestMessageSetContainer()
ext1 = unittest_mset_pb2.TestMessageSetExtension1.message_set_extension
@@ -294,7 +308,18 @@ class JsonFormatTest(JsonFormatBase):
self.assertEqual(message.int32_value, 1)
def testMapFields(self):
- message = json_format_proto3_pb2.TestMap()
+ message = json_format_proto3_pb2.TestNestedMap()
+ self.assertEqual(
+ json.loads(json_format.MessageToJson(message, True)),
+ json.loads('{'
+ '"boolMap": {},'
+ '"int32Map": {},'
+ '"int64Map": {},'
+ '"uint32Map": {},'
+ '"uint64Map": {},'
+ '"stringMap": {},'
+ '"mapMap": {}'
+ '}'))
message.bool_map[True] = 1
message.bool_map[False] = 2
message.int32_map[1] = 2
@@ -307,17 +332,19 @@ class JsonFormatTest(JsonFormatBase):
message.uint64_map[2] = 3
message.string_map['1'] = 2
message.string_map['null'] = 3
+ message.map_map['1'].bool_map[True] = 3
self.assertEqual(
- json.loads(json_format.MessageToJson(message, True)),
+ json.loads(json_format.MessageToJson(message, False)),
json.loads('{'
'"boolMap": {"false": 2, "true": 1},'
'"int32Map": {"1": 2, "2": 3},'
'"int64Map": {"1": 2, "2": 3},'
'"uint32Map": {"1": 2, "2": 3},'
'"uint64Map": {"1": 2, "2": 3},'
- '"stringMap": {"1": 2, "null": 3}'
+ '"stringMap": {"1": 2, "null": 3},'
+ '"mapMap": {"1": {"boolMap": {"true": 3}}}'
'}'))
- parsed_message = json_format_proto3_pb2.TestMap()
+ parsed_message = json_format_proto3_pb2.TestNestedMap()
self.CheckParseBack(message, parsed_message)
def testOneofFields(self):
@@ -703,6 +730,9 @@ class JsonFormatTest(JsonFormatBase):
json_format.Parse,
'{"repeatedInt32Value":[1, null]}',
parsed_message)
+ self.CheckError('{"repeatedMessageValue":[null]}',
+ 'Failed to parse repeatedMessageValue field: null is not'
+ ' allowed to be used as an element in a repeated field.')
def testNanFloat(self):
message = json_format_proto3_pb2.TestMessage()
@@ -727,6 +757,11 @@ class JsonFormatTest(JsonFormatBase):
'{"enumValue": "baz"}',
'Failed to parse enumValue field: Invalid enum value baz '
'for enum type proto3.EnumType.')
+ # TODO(jieluo): fix json format to accept numeric unknown enum for proto3.
+ self.CheckError(
+ '{"enumValue": 12345}',
+ 'Failed to parse enumValue field: Invalid enum value 12345 '
+ 'for enum type proto3.EnumType.')
def testParseBadIdentifer(self):
self.CheckError('{int32Value: 1}',
@@ -799,6 +834,11 @@ class JsonFormatTest(JsonFormatBase):
self.CheckError('{"bytesValue": "AQI*"}',
'Failed to parse bytesValue field: Incorrect padding.')
+ def testInvalidRepeated(self):
+ self.CheckError('{"repeatedInt32Value": 12345}',
+ (r'Failed to parse repeatedInt32Value field: repeated field'
+ r' repeatedInt32Value must be in \[\] which is 12345.'))
+
def testInvalidMap(self):
message = json_format_proto3_pb2.TestMap()
text = '{"int32Map": {"null": 2, "2": 3}}'
@@ -824,6 +864,12 @@ class JsonFormatTest(JsonFormatBase):
json_format.ParseError,
'Failed to load JSON: duplicate key a',
json_format.Parse, text, message)
+ text = r'{"stringMap": 0}'
+ self.assertRaisesRegexp(
+ json_format.ParseError,
+ 'Failed to parse stringMap field: Map field string_map must be '
+ 'in a dict which is 0.',
+ json_format.Parse, text, message)
def testInvalidTimestamp(self):
message = json_format_proto3_pb2.TestTimestamp()
@@ -911,6 +957,12 @@ class JsonFormatTest(JsonFormatBase):
json_format.MessageToJson(message))
self.assertEqual('{\n "int32_value": 12345\n}',
json_format.MessageToJson(message, False, True))
+ # When including_default_value_fields is True.
+ message = json_format_proto3_pb2.TestTimestamp()
+ self.assertEqual('{\n "repeatedValue": []\n}',
+ json_format.MessageToJson(message, True, False))
+ self.assertEqual('{\n "repeated_value": []\n}',
+ json_format.MessageToJson(message, True, True))
# Parsers accept both original proto field names and lowerCamelCase names.
message = json_format_proto3_pb2.TestMessage()
diff --git a/python/google/protobuf/internal/message_factory_test.py b/python/google/protobuf/internal/message_factory_test.py
index aa7af2d0..a1b6bb81 100644
--- a/python/google/protobuf/internal/message_factory_test.py
+++ b/python/google/protobuf/internal/message_factory_test.py
@@ -40,6 +40,7 @@ except ImportError:
import unittest
from google.protobuf import descriptor_pb2
+from google.protobuf.internal import api_implementation
from google.protobuf.internal import factory_test1_pb2
from google.protobuf.internal import factory_test2_pb2
from google.protobuf import descriptor_database
@@ -130,6 +131,21 @@ class MessageFactoryTest(unittest.TestCase):
msg1.Extensions[ext2] = 'test2'
self.assertEqual('test1', msg1.Extensions[ext1])
self.assertEqual('test2', msg1.Extensions[ext2])
+ self.assertEqual(None,
+ msg1.Extensions._FindExtensionByNumber(12321))
+ if api_implementation.Type() == 'cpp':
+ # TODO(jieluo): Fix len to return the correct value.
+ # self.assertEqual(2, len(msg1.Extensions))
+ self.assertEqual(len(msg1.Extensions), len(msg1.Extensions))
+ self.assertRaises(TypeError,
+ msg1.Extensions._FindExtensionByName, 0)
+ self.assertRaises(TypeError,
+ msg1.Extensions._FindExtensionByNumber, '')
+ else:
+ self.assertEqual(None,
+ msg1.Extensions._FindExtensionByName(0))
+ self.assertEqual(None,
+ msg1.Extensions._FindExtensionByNumber(''))
def testDuplicateExtensionNumber(self):
pool = descriptor_pool.DescriptorPool()
diff --git a/python/google/protobuf/internal/message_test.py b/python/google/protobuf/internal/message_test.py
index dda72cdd..4622f10f 100755
--- a/python/google/protobuf/internal/message_test.py
+++ b/python/google/protobuf/internal/message_test.py
@@ -61,6 +61,7 @@ try:
except NameError:
cmp = lambda x, y: (x > y) - (x < y) # Python 3
+from google.protobuf import map_proto2_unittest_pb2
from google.protobuf import map_unittest_pb2
from google.protobuf import unittest_pb2
from google.protobuf import unittest_proto3_arena_pb2
@@ -69,6 +70,7 @@ from google.protobuf import descriptor_pool
from google.protobuf import message_factory
from google.protobuf import text_format
from google.protobuf.internal import api_implementation
+from google.protobuf.internal import encoder
from google.protobuf.internal import packed_field_test_pb2
from google.protobuf.internal import test_util
from google.protobuf.internal import testing_refleaks
@@ -140,6 +142,18 @@ class MessageTest(BaseTestCase):
golden_copy = copy.deepcopy(golden_message)
self.assertEqual(golden_data, golden_copy.SerializeToString())
+ def testParseErrors(self, message_module):
+ msg = message_module.TestAllTypes()
+ self.assertRaises(TypeError, msg.FromString, 0)
+ self.assertRaises(Exception, msg.FromString, '0')
+ # TODO(jieluo): Fix cpp extension to check unexpected end-group tag.
+ # b/27494216
+ if api_implementation.Type() == 'python':
+ end_tag = encoder.TagBytes(1, 4)
+ with self.assertRaises(message.DecodeError) as context:
+ msg.FromString(end_tag)
+ self.assertEqual('Unexpected end-group tag.', str(context.exception))
+
def testDeterminismParameters(self, message_module):
# This message is always deterministically serialized, even if determinism
# is disabled, so we can use it to verify that all the determinism
@@ -609,6 +623,13 @@ class MessageTest(BaseTestCase):
self.assertIsInstance(m.repeated_nested_message,
collections.MutableSequence)
+ def testRepeatedFieldsNotHashable(self, message_module):
+ m = message_module.TestAllTypes()
+ with self.assertRaises(TypeError):
+ hash(m.repeated_int32)
+ with self.assertRaises(TypeError):
+ hash(m.repeated_nested_message)
+
def testRepeatedFieldInsideNestedMessage(self, message_module):
m = message_module.NestedTestAllTypes()
m.payload.repeated_int32.extend([])
@@ -626,6 +647,7 @@ class MessageTest(BaseTestCase):
def testOneofGetCaseNonexistingField(self, message_module):
m = message_module.TestAllTypes()
self.assertRaises(ValueError, m.WhichOneof, 'no_such_oneof_field')
+ self.assertRaises(Exception, m.WhichOneof, 0)
def testOneofDefaultValues(self, message_module):
m = message_module.TestAllTypes()
@@ -1001,6 +1023,8 @@ class MessageTest(BaseTestCase):
m = message_module.TestAllTypes()
with self.assertRaises(IndexError) as _:
m.repeated_nested_message.pop()
+ with self.assertRaises(TypeError) as _:
+ m.repeated_nested_message.pop('0')
for i in range(5):
n = m.repeated_nested_message.add()
n.bb = i
@@ -1009,6 +1033,39 @@ class MessageTest(BaseTestCase):
self.assertEqual(2, m.repeated_nested_message.pop(1).bb)
self.assertEqual([1, 3], [n.bb for n in m.repeated_nested_message])
+ def testRepeatedCompareWithSelf(self, message_module):
+ m = message_module.TestAllTypes()
+ for i in range(5):
+ m.repeated_int32.insert(i, i)
+ n = m.repeated_nested_message.add()
+ n.bb = i
+ self.assertSequenceEqual(m.repeated_int32, m.repeated_int32)
+ self.assertEqual(m.repeated_nested_message, m.repeated_nested_message)
+
+ def testReleasedNestedMessages(self, message_module):
+ """A case that lead to a segfault when a message detached from its parent
+ container has itself a child container.
+ """
+ m = message_module.NestedTestAllTypes()
+ m = m.repeated_child.add()
+ m = m.child
+ m = m.repeated_child.add()
+ self.assertEqual(m.payload.optional_int32, 0)
+
+ def testSetRepeatedComposite(self, message_module):
+ m = message_module.TestAllTypes()
+ with self.assertRaises(AttributeError):
+ m.repeated_int32 = []
+ m.repeated_int32.append(1)
+ if api_implementation.Type() == 'cpp':
+ # For test coverage: cpp has a different path if composite
+ # field is in cache
+ with self.assertRaises(TypeError):
+ m.repeated_int32 = []
+ else:
+ with self.assertRaises(AttributeError):
+ m.repeated_int32 = []
+
# Class to test proto2-only features (required, extensions, etc.)
class Proto2Test(BaseTestCase):
@@ -1061,18 +1118,46 @@ class Proto2Test(BaseTestCase):
self.assertEqual(False, message.optional_bool)
self.assertEqual(0, message.optional_nested_message.bb)
- # TODO(tibell): The C++ implementations actually allows assignment
- # of unknown enum values to *scalar* fields (but not repeated
- # fields). Once checked enum fields becomes the default in the
- # Python implementation, the C++ implementation should follow suit.
def testAssignInvalidEnum(self):
- """It should not be possible to assign an invalid enum number to an
- enum field."""
+ """Assigning an invalid enum number is not allowed in proto2."""
m = unittest_pb2.TestAllTypes()
+ # Proto2 can not assign unknown enum.
with self.assertRaises(ValueError) as _:
m.optional_nested_enum = 1234567
self.assertRaises(ValueError, m.repeated_nested_enum.append, 1234567)
+ # Assignment is a different code path than append for the C++ impl.
+ m.repeated_nested_enum.append(2)
+ m.repeated_nested_enum[0] = 2
+ with self.assertRaises(ValueError):
+ m.repeated_nested_enum[0] = 123456
+
+ # Unknown enum value can be parsed but is ignored.
+ m2 = unittest_proto3_arena_pb2.TestAllTypes()
+ m2.optional_nested_enum = 1234567
+ m2.repeated_nested_enum.append(7654321)
+ serialized = m2.SerializeToString()
+
+ m3 = unittest_pb2.TestAllTypes()
+ m3.ParseFromString(serialized)
+ self.assertFalse(m3.HasField('optional_nested_enum'))
+ # 1 is the default value for optional_nested_enum.
+ self.assertEqual(1, m3.optional_nested_enum)
+ self.assertEqual(0, len(m3.repeated_nested_enum))
+ m2.Clear()
+ m2.ParseFromString(m3.SerializeToString())
+ self.assertEqual(1234567, m2.optional_nested_enum)
+ self.assertEqual(7654321, m2.repeated_nested_enum[0])
+
+ def testUnknownEnumMap(self):
+ m = map_proto2_unittest_pb2.TestEnumMap()
+ m.known_map_field[123] = 0
+ with self.assertRaises(ValueError):
+ m.unknown_map_field[1] = 123
+
+ def testExtensionsErrors(self):
+ msg = unittest_pb2.TestAllTypes()
+ self.assertRaises(AttributeError, getattr, msg, 'Extensions')
def testGoldenExtensions(self):
golden_data = test_util.GoldenFileData('golden_message')
@@ -1297,6 +1382,7 @@ class Proto3Test(BaseTestCase):
"""Assigning an unknown enum value is allowed and preserves the value."""
m = unittest_proto3_arena_pb2.TestAllTypes()
+ # Proto3 can assign unknown enums.
m.optional_nested_enum = 1234567
self.assertEqual(1234567, m.optional_nested_enum)
m.repeated_nested_enum.append(22334455)
@@ -1311,18 +1397,10 @@ class Proto3Test(BaseTestCase):
self.assertEqual(1234567, m2.optional_nested_enum)
self.assertEqual(7654321, m2.repeated_nested_enum[0])
- # ParseFromString in Proto2 should accept unknown enums too.
- m3 = unittest_pb2.TestAllTypes()
- m3.ParseFromString(serialized)
- m2.Clear()
- m2.ParseFromString(m3.SerializeToString())
- self.assertEqual(1234567, m2.optional_nested_enum)
- self.assertEqual(7654321, m2.repeated_nested_enum[0])
-
# Map isn't really a proto3-only feature. But there is no proto2 equivalent
# of google/protobuf/map_unittest.proto right now, so it's not easy to
# test both with the same test like we do for the other proto2/proto3 tests.
- # (google/protobuf/map_protobuf_unittest.proto is very different in the set
+ # (google/protobuf/map_proto2_unittest.proto is very different in the set
# of messages and fields it contains).
def testScalarMapDefaults(self):
msg = map_unittest_pb2.TestMap()
@@ -1383,12 +1461,21 @@ class Proto3Test(BaseTestCase):
msg.map_int32_int32[5] = 15
self.assertEqual(15, msg.map_int32_int32.get(5))
+ self.assertEqual(15, msg.map_int32_int32.get(5))
+ with self.assertRaises(TypeError):
+ msg.map_int32_int32.get('')
self.assertIsNone(msg.map_int32_foreign_message.get(5))
self.assertEqual(10, msg.map_int32_foreign_message.get(5, 10))
submsg = msg.map_int32_foreign_message[5]
self.assertIs(submsg, msg.map_int32_foreign_message.get(5))
+ # TODO(jieluo): Fix python and cpp extension diff.
+ if api_implementation.Type() == 'cpp':
+ with self.assertRaises(TypeError):
+ msg.map_int32_foreign_message.get('')
+ else:
+ self.assertEqual(None, msg.map_int32_foreign_message.get(''))
def testScalarMap(self):
msg = map_unittest_pb2.TestMap()
@@ -1400,8 +1487,13 @@ class Proto3Test(BaseTestCase):
msg.map_int64_int64[-2**33] = -2**34
msg.map_uint32_uint32[123] = 456
msg.map_uint64_uint64[2**33] = 2**34
+ msg.map_int32_float[2] = 1.2
+ msg.map_int32_double[1] = 3.3
msg.map_string_string['abc'] = '123'
+ msg.map_bool_bool[True] = True
msg.map_int32_enum[888] = 2
+ # Unknown numeric enum is supported in proto3.
+ msg.map_int32_enum[123] = 456
self.assertEqual([], msg.FindInitializationErrors())
@@ -1435,8 +1527,24 @@ class Proto3Test(BaseTestCase):
self.assertEqual(-2**34, msg2.map_int64_int64[-2**33])
self.assertEqual(456, msg2.map_uint32_uint32[123])
self.assertEqual(2**34, msg2.map_uint64_uint64[2**33])
+ self.assertAlmostEqual(1.2, msg.map_int32_float[2])
+ self.assertEqual(3.3, msg.map_int32_double[1])
self.assertEqual('123', msg2.map_string_string['abc'])
+ self.assertEqual(True, msg2.map_bool_bool[True])
self.assertEqual(2, msg2.map_int32_enum[888])
+ self.assertEqual(456, msg2.map_int32_enum[123])
+ # TODO(jieluo): Add cpp extension support.
+ if api_implementation.Type() == 'python':
+ self.assertEqual('{-123: -456}',
+ str(msg2.map_int32_int32))
+
+ def testMapEntryAlwaysSerialized(self):
+ msg = map_unittest_pb2.TestMap()
+ msg.map_int32_int32[0] = 0
+ msg.map_string_string[''] = ''
+ self.assertEqual(msg.ByteSize(), 12)
+ self.assertEqual(b'\n\x04\x08\x00\x10\x00r\x04\n\x00\x12\x00',
+ msg.SerializeToString())
def testStringUnicodeConversionInMap(self):
msg = map_unittest_pb2.TestMap()
@@ -1489,6 +1597,11 @@ class Proto3Test(BaseTestCase):
self.assertIn(123, msg2.map_int32_foreign_message)
self.assertIn(-456, msg2.map_int32_foreign_message)
self.assertEqual(2, len(msg2.map_int32_foreign_message))
+ # TODO(jieluo): Fix text format for message map.
+ # TODO(jieluo): Add cpp extension support.
+ if api_implementation.Type() == 'python':
+ self.assertEqual(15,
+ len(str(msg2.map_int32_foreign_message)))
def testNestedMessageMapItemDelete(self):
msg = map_unittest_pb2.TestMap()
@@ -1572,6 +1685,12 @@ class Proto3Test(BaseTestCase):
del msg2.map_int32_foreign_message[222]
self.assertFalse(222 in msg2.map_int32_foreign_message)
+ if api_implementation.Type() == 'cpp':
+ with self.assertRaises(TypeError):
+ del msg2.map_int32_foreign_message['']
+ else:
+ with self.assertRaises(KeyError):
+ del msg2.map_int32_foreign_message['']
def testMergeFromBadType(self):
msg = map_unittest_pb2.TestMap()
@@ -1706,6 +1825,54 @@ class Proto3Test(BaseTestCase):
matching_dict = {2: 4, 3: 6, 4: 8}
self.assertMapIterEquals(msg.map_int32_int32.items(), matching_dict)
+ def testPython2Map(self):
+ if sys.version_info < (3,):
+ msg = map_unittest_pb2.TestMap()
+ msg.map_int32_int32[2] = 4
+ msg.map_int32_int32[3] = 6
+ msg.map_int32_int32[4] = 8
+ msg.map_int32_int32[5] = 10
+ map_int32 = msg.map_int32_int32
+ self.assertEqual(4, len(map_int32))
+ msg2 = map_unittest_pb2.TestMap()
+ msg2.ParseFromString(msg.SerializeToString())
+
+ def CheckItems(seq, iterator):
+ self.assertEqual(next(iterator), seq[0])
+ self.assertEqual(list(iterator), seq[1:])
+
+ CheckItems(map_int32.items(), map_int32.iteritems())
+ CheckItems(map_int32.keys(), map_int32.iterkeys())
+ CheckItems(map_int32.values(), map_int32.itervalues())
+
+ self.assertEqual(6, map_int32.get(3))
+ self.assertEqual(None, map_int32.get(999))
+ self.assertEqual(6, map_int32.pop(3))
+ self.assertEqual(0, map_int32.pop(3))
+ self.assertEqual(3, len(map_int32))
+ key, value = map_int32.popitem()
+ self.assertEqual(2 * key, value)
+ self.assertEqual(2, len(map_int32))
+ map_int32.clear()
+ self.assertEqual(0, len(map_int32))
+
+ with self.assertRaises(KeyError):
+ map_int32.popitem()
+
+ self.assertEqual(0, map_int32.setdefault(2))
+ self.assertEqual(1, len(map_int32))
+
+ map_int32.update(msg2.map_int32_int32)
+ self.assertEqual(4, len(map_int32))
+
+ with self.assertRaises(TypeError):
+ map_int32.update(msg2.map_int32_int32,
+ msg2.map_int32_int32)
+ with self.assertRaises(TypeError):
+ map_int32.update(0)
+ with self.assertRaises(TypeError):
+ map_int32.update(value=12)
+
def testMapItems(self):
# Map items used to have strange behaviors when use c extension. Because
# [] may reorder the map and invalidate any exsting iterators.
@@ -1836,6 +2003,9 @@ class Proto3Test(BaseTestCase):
del msg.map_int32_int32[4]
self.assertEqual(0, len(msg.map_int32_int32))
+ with self.assertRaises(KeyError):
+ del msg.map_int32_all_types[32]
+
def testMapsAreMapping(self):
msg = map_unittest_pb2.TestMap()
self.assertIsInstance(msg.map_int32_int32, collections.Mapping)
@@ -1844,6 +2014,14 @@ class Proto3Test(BaseTestCase):
self.assertIsInstance(msg.map_int32_foreign_message,
collections.MutableMapping)
+ def testMapsCompare(self):
+ msg = map_unittest_pb2.TestMap()
+ msg.map_int32_int32[-123] = -456
+ self.assertEqual(msg.map_int32_int32, msg.map_int32_int32)
+ self.assertEqual(msg.map_int32_foreign_message,
+ msg.map_int32_foreign_message)
+ self.assertNotEqual(msg.map_int32_int32, 0)
+
def testMapFindInitializationErrorsSmokeTest(self):
msg = map_unittest_pb2.TestMap()
msg.map_string_string['abc'] = '123'
diff --git a/python/google/protobuf/internal/python_message.py b/python/google/protobuf/internal/python_message.py
index c363d843..975e3b4d 100755
--- a/python/google/protobuf/internal/python_message.py
+++ b/python/google/protobuf/internal/python_message.py
@@ -1010,11 +1010,16 @@ def _AddByteSizeMethod(message_descriptor, cls):
return self._cached_byte_size
size = 0
- for field_descriptor, field_value in self.ListFields():
- size += field_descriptor._sizer(field_value)
-
- for tag_bytes, value_bytes in self._unknown_fields:
- size += len(tag_bytes) + len(value_bytes)
+ descriptor = self.DESCRIPTOR
+ if descriptor.GetOptions().map_entry:
+ # Fields of map entry should always be serialized.
+ size = descriptor.fields_by_name['key']._sizer(self.key)
+ size += descriptor.fields_by_name['value']._sizer(self.value)
+ else:
+ for field_descriptor, field_value in self.ListFields():
+ size += field_descriptor._sizer(field_value)
+ for tag_bytes, value_bytes in self._unknown_fields:
+ size += len(tag_bytes) + len(value_bytes)
self._cached_byte_size = size
self._cached_byte_size_dirty = False
@@ -1053,11 +1058,20 @@ def _AddSerializePartialToStringMethod(message_descriptor, cls):
api_implementation.IsPythonDefaultSerializationDeterministic())
else:
deterministic = bool(deterministic)
- for field_descriptor, field_value in self.ListFields():
- field_descriptor._encoder(write_bytes, field_value, deterministic)
- for tag_bytes, value_bytes in self._unknown_fields:
- write_bytes(tag_bytes)
- write_bytes(value_bytes)
+
+ descriptor = self.DESCRIPTOR
+ if descriptor.GetOptions().map_entry:
+ # Fields of map entry should always be serialized.
+ descriptor.fields_by_name['key']._encoder(
+ write_bytes, self.key, deterministic)
+ descriptor.fields_by_name['value']._encoder(
+ write_bytes, self.value, deterministic)
+ else:
+ for field_descriptor, field_value in self.ListFields():
+ field_descriptor._encoder(write_bytes, field_value, deterministic)
+ for tag_bytes, value_bytes in self._unknown_fields:
+ write_bytes(tag_bytes)
+ write_bytes(value_bytes)
cls._InternalSerialize = InternalSerialize
@@ -1095,7 +1109,8 @@ def _AddMergeFromStringMethod(message_descriptor, cls):
new_pos = local_SkipField(buffer, new_pos, end, tag_bytes)
if new_pos == -1:
return pos
- if not is_proto3:
+ if (not is_proto3 or
+ api_implementation.GetPythonProto3PreserveUnknownsDefault()):
if not unknown_field_list:
unknown_field_list = self._unknown_fields = []
unknown_field_list.append(
diff --git a/python/google/protobuf/internal/reflection_test.py b/python/google/protobuf/internal/reflection_test.py
index 5ab5225e..0306ff46 100755
--- a/python/google/protobuf/internal/reflection_test.py
+++ b/python/google/protobuf/internal/reflection_test.py
@@ -40,7 +40,6 @@ import gc
import operator
import six
import struct
-import sys
try:
import unittest2 as unittest #PY26
@@ -621,6 +620,14 @@ class ReflectionTest(BaseTestCase):
self.assertRaises(TypeError, setattr, proto, 'optional_int32', 'foo')
self.assertRaises(TypeError, setattr, proto, 'optional_string', 10)
self.assertRaises(TypeError, setattr, proto, 'optional_bytes', 10)
+ self.assertRaises(TypeError, setattr, proto, 'optional_bool', 'foo')
+ self.assertRaises(TypeError, setattr, proto, 'optional_float', 'foo')
+ self.assertRaises(TypeError, setattr, proto, 'optional_double', 'foo')
+ # TODO(jieluo): Fix type checking difference for python and c extension
+ if api_implementation.Type() == 'python':
+ self.assertRaises(TypeError, setattr, proto, 'optional_bool', 1.1)
+ else:
+ proto.optional_bool = 1.1
def assertIntegerTypes(self, integer_fn):
"""Verifies setting of scalar integers.
@@ -687,8 +694,10 @@ class ReflectionTest(BaseTestCase):
self.assertEqual(expected_min, getattr(pb, field_name))
setattr(pb, field_name, expected_max)
self.assertEqual(expected_max, getattr(pb, field_name))
- self.assertRaises(Exception, setattr, pb, field_name, expected_min - 1)
- self.assertRaises(Exception, setattr, pb, field_name, expected_max + 1)
+ self.assertRaises((ValueError, TypeError), setattr, pb, field_name,
+ expected_min - 1)
+ self.assertRaises((ValueError, TypeError), setattr, pb, field_name,
+ expected_max + 1)
TestMinAndMaxIntegers('optional_int32', -(1 << 31), (1 << 31) - 1)
TestMinAndMaxIntegers('optional_uint32', 0, 0xffffffff)
@@ -697,7 +706,7 @@ class ReflectionTest(BaseTestCase):
# A bit of white-box testing since -1 is an int and not a long in C++ and
# so goes down a different path.
pb = unittest_pb2.TestAllTypes()
- with self.assertRaises(Exception):
+ with self.assertRaises((ValueError, TypeError)):
pb.optional_uint64 = integer_fn(-(1 << 63))
pb = unittest_pb2.TestAllTypes()
@@ -721,6 +730,12 @@ class ReflectionTest(BaseTestCase):
proto.repeated_int32[0] = 23
self.assertRaises(IndexError, proto.repeated_int32.__setitem__, 500, 23)
self.assertRaises(TypeError, proto.repeated_int32.__setitem__, 0, 'abc')
+ self.assertRaises(TypeError, proto.repeated_int32.__setitem__, 0, [])
+ self.assertRaises(TypeError, proto.repeated_int32.__setitem__,
+ 'index', 23)
+
+ proto.repeated_string.append('2')
+ self.assertRaises(TypeError, proto.repeated_string.__setitem__, 0, 10)
# Repeated enums tests.
#proto.repeated_nested_enum.append(0)
@@ -1008,6 +1023,14 @@ class ReflectionTest(BaseTestCase):
self.assertEqual(4, len(proto.repeated_nested_message))
self.assertEqual(n1, proto.repeated_nested_message[2])
self.assertEqual(n2, proto.repeated_nested_message[3])
+ self.assertRaises(TypeError,
+ proto.repeated_nested_message.extend, n1)
+ self.assertRaises(TypeError,
+ proto.repeated_nested_message.extend, [0])
+ wrong_message_type = unittest_pb2.TestAllTypes()
+ self.assertRaises(TypeError,
+ proto.repeated_nested_message.extend,
+ [wrong_message_type])
# Test clearing.
proto.ClearField('repeated_nested_message')
@@ -1019,6 +1042,8 @@ class ReflectionTest(BaseTestCase):
self.assertEqual(1, len(proto.repeated_nested_message))
self.assertEqual(23, proto.repeated_nested_message[0].bb)
self.assertRaises(TypeError, proto.repeated_nested_message.add, 23)
+ with self.assertRaises(Exception):
+ proto.repeated_nested_message[0] = 23
def testRepeatedCompositeRemove(self):
proto = unittest_pb2.TestAllTypes()
@@ -1643,8 +1668,11 @@ class ReflectionTest(BaseTestCase):
proto.SerializeToString()
proto.SerializePartialToString()
- def assertNotInitialized(self, proto):
+ def assertNotInitialized(self, proto, error_size=None):
+ errors = []
self.assertFalse(proto.IsInitialized())
+ self.assertFalse(proto.IsInitialized(errors))
+ self.assertEqual(error_size, len(errors))
self.assertRaises(message.EncodeError, proto.SerializeToString)
# "Partial" serialization doesn't care if message is uninitialized.
proto.SerializePartialToString()
@@ -1658,7 +1686,7 @@ class ReflectionTest(BaseTestCase):
# The case of uninitialized required fields.
proto = unittest_pb2.TestRequired()
- self.assertNotInitialized(proto)
+ self.assertNotInitialized(proto, 3)
proto.a = proto.b = proto.c = 2
self.assertInitialized(proto)
@@ -1666,14 +1694,14 @@ class ReflectionTest(BaseTestCase):
proto = unittest_pb2.TestRequiredForeign()
self.assertInitialized(proto)
proto.optional_message.a = 1
- self.assertNotInitialized(proto)
+ self.assertNotInitialized(proto, 2)
proto.optional_message.b = 0
proto.optional_message.c = 0
self.assertInitialized(proto)
# Uninitialized repeated submessage.
message1 = proto.repeated_message.add()
- self.assertNotInitialized(proto)
+ self.assertNotInitialized(proto, 3)
message1.a = message1.b = message1.c = 0
self.assertInitialized(proto)
@@ -1682,11 +1710,11 @@ class ReflectionTest(BaseTestCase):
extension = unittest_pb2.TestRequired.multi
message1 = proto.Extensions[extension].add()
message2 = proto.Extensions[extension].add()
- self.assertNotInitialized(proto)
+ self.assertNotInitialized(proto, 6)
message1.a = 1
message1.b = 1
message1.c = 1
- self.assertNotInitialized(proto)
+ self.assertNotInitialized(proto, 3)
message2.a = 2
message2.b = 2
message2.c = 2
@@ -1696,7 +1724,7 @@ class ReflectionTest(BaseTestCase):
proto = unittest_pb2.TestAllExtensions()
extension = unittest_pb2.TestRequired.single
proto.Extensions[extension].a = 1
- self.assertNotInitialized(proto)
+ self.assertNotInitialized(proto, 2)
proto.Extensions[extension].b = 2
proto.Extensions[extension].c = 3
self.assertInitialized(proto)
@@ -2155,6 +2183,8 @@ class ByteSizeTest(BaseTestCase):
foreign_message_1 = self.proto.repeated_nested_message.add()
foreign_message_1.bb = 9
self.assertEqual(2 + 1 + 2 + 1 + 1 + 1, self.Size())
+ repeated_nested_message = copy.deepcopy(
+ self.proto.repeated_nested_message)
# 2 bytes tag plus 1 byte length plus 1 byte bb tag 1 byte int.
del self.proto.repeated_nested_message[0]
@@ -2175,6 +2205,16 @@ class ByteSizeTest(BaseTestCase):
del self.proto.repeated_nested_message[0]
self.assertEqual(0, self.Size())
+ self.assertEqual(2, len(repeated_nested_message))
+ del repeated_nested_message[0:1]
+ # TODO(jieluo): Fix cpp extension bug when delete repeated message.
+ if api_implementation.Type() == 'python':
+ self.assertEqual(1, len(repeated_nested_message))
+ del repeated_nested_message[-1]
+ # TODO(jieluo): Fix cpp extension bug when delete repeated message.
+ if api_implementation.Type() == 'python':
+ self.assertEqual(0, len(repeated_nested_message))
+
def testRepeatedGroups(self):
# 2-byte START_GROUP plus 2-byte END_GROUP.
group_0 = self.proto.repeatedgroup.add()
@@ -2191,6 +2231,10 @@ class ByteSizeTest(BaseTestCase):
proto.Extensions[extension] = 23
# 1 byte for tag, 1 byte for value.
self.assertEqual(2, proto.ByteSize())
+ field = unittest_pb2.TestAllTypes.DESCRIPTOR.fields_by_name[
+ 'optional_int32']
+ with self.assertRaises(KeyError):
+ proto.Extensions[field] = 23
def testCacheInvalidationForNonrepeatedScalar(self):
# Test non-extension.
diff --git a/python/google/protobuf/internal/service_reflection_test.py b/python/google/protobuf/internal/service_reflection_test.py
index 62900b1d..77239f44 100755
--- a/python/google/protobuf/internal/service_reflection_test.py
+++ b/python/google/protobuf/internal/service_reflection_test.py
@@ -82,6 +82,10 @@ class FooUnitTest(unittest.TestCase):
service_descriptor = unittest_pb2.TestService.GetDescriptor()
srvc.CallMethod(service_descriptor.methods[1], rpc_controller,
unittest_pb2.BarRequest(), MyCallback)
+ self.assertTrue(srvc.GetRequestClass(service_descriptor.methods[1]) is
+ unittest_pb2.BarRequest)
+ self.assertTrue(srvc.GetResponseClass(service_descriptor.methods[1]) is
+ unittest_pb2.BarResponse)
self.assertEqual('Method Bar not implemented.',
rpc_controller.failure_message)
self.assertEqual(None, self.callback_response)
diff --git a/python/google/protobuf/internal/test_util.py b/python/google/protobuf/internal/test_util.py
index 9434b7b1..a6e34ef5 100755
--- a/python/google/protobuf/internal/test_util.py
+++ b/python/google/protobuf/internal/test_util.py
@@ -44,9 +44,9 @@ from google.protobuf import unittest_import_pb2
from google.protobuf import unittest_pb2
try:
- long # Python 2
+ long # Python 2
except NameError:
- long = int # Python 3
+ long = int # Python 3
# Tests whether the given TestAllTypes message is proto2 or not.
@@ -133,22 +133,37 @@ def SetAllNonLazyFields(message):
message.repeated_string_piece.append(u'224')
message.repeated_cord.append(u'225')
- # Add a second one of each field.
- message.repeated_int32.append(301)
- message.repeated_int64.append(302)
- message.repeated_uint32.append(303)
- message.repeated_uint64.append(304)
- message.repeated_sint32.append(305)
- message.repeated_sint64.append(306)
- message.repeated_fixed32.append(307)
- message.repeated_fixed64.append(308)
- message.repeated_sfixed32.append(309)
- message.repeated_sfixed64.append(310)
- message.repeated_float.append(311)
- message.repeated_double.append(312)
- message.repeated_bool.append(False)
- message.repeated_string.append(u'315')
- message.repeated_bytes.append(b'316')
+ # Add a second one of each field and set value by index.
+ message.repeated_int32.append(0)
+ message.repeated_int64.append(0)
+ message.repeated_uint32.append(0)
+ message.repeated_uint64.append(0)
+ message.repeated_sint32.append(0)
+ message.repeated_sint64.append(0)
+ message.repeated_fixed32.append(0)
+ message.repeated_fixed64.append(0)
+ message.repeated_sfixed32.append(0)
+ message.repeated_sfixed64.append(0)
+ message.repeated_float.append(0)
+ message.repeated_double.append(0)
+ message.repeated_bool.append(True)
+ message.repeated_string.append(u'0')
+ message.repeated_bytes.append(b'0')
+ message.repeated_int32[1] = 301
+ message.repeated_int64[1] = 302
+ message.repeated_uint32[1] = 303
+ message.repeated_uint64[1] = 304
+ message.repeated_sint32[1] = 305
+ message.repeated_sint64[1] = 306
+ message.repeated_fixed32[1] = 307
+ message.repeated_fixed64[1] = 308
+ message.repeated_sfixed32[1] = 309
+ message.repeated_sfixed64[1] = 310
+ message.repeated_float[1] = 311
+ message.repeated_double[1] = 312
+ message.repeated_bool[1] = False
+ message.repeated_string[1] = u'315'
+ message.repeated_bytes[1] = b'316'
if IsProto2(message):
message.repeatedgroup.add().a = 317
@@ -157,7 +172,8 @@ def SetAllNonLazyFields(message):
message.repeated_import_message.add().d = 320
message.repeated_lazy_message.add().bb = 327
- message.repeated_nested_enum.append(unittest_pb2.TestAllTypes.BAZ)
+ message.repeated_nested_enum.append(unittest_pb2.TestAllTypes.BAR)
+ message.repeated_nested_enum[1] = unittest_pb2.TestAllTypes.BAZ
message.repeated_foreign_enum.append(unittest_pb2.FOREIGN_BAZ)
if IsProto2(message):
message.repeated_import_enum.append(unittest_import_pb2.IMPORT_BAZ)
diff --git a/python/google/protobuf/internal/text_format_test.py b/python/google/protobuf/internal/text_format_test.py
index 424b29cc..1214c3ea 100755
--- a/python/google/protobuf/internal/text_format_test.py
+++ b/python/google/protobuf/internal/text_format_test.py
@@ -35,6 +35,7 @@
__author__ = 'kenton@google.com (Kenton Varda)'
+import math
import re
import six
import string
@@ -53,8 +54,8 @@ from google.protobuf import unittest_pb2
from google.protobuf import unittest_proto3_arena_pb2
from google.protobuf.internal import api_implementation
from google.protobuf.internal import any_test_pb2 as test_extend_any
-from google.protobuf.internal import test_util
from google.protobuf.internal import message_set_extensions_pb2
+from google.protobuf.internal import test_util
from google.protobuf import descriptor_pool
from google.protobuf import text_format
@@ -371,7 +372,10 @@ class TextFormatTest(TextFormatBase):
def testParseInvalidUtf8(self, message_module):
message = message_module.TestAllTypes()
text = 'repeated_string: "\\xc3\\xc3"'
- self.assertRaises(text_format.ParseError, text_format.Parse, text, message)
+ with self.assertRaises(text_format.ParseError) as e:
+ text_format.Parse(text, message)
+ self.assertEqual(e.exception.GetLine(), 1)
+ self.assertEqual(e.exception.GetColumn(), 28)
def testParseSingleWord(self, message_module):
message = message_module.TestAllTypes()
@@ -784,13 +788,14 @@ class Proto2Tests(TextFormatBase):
' bin: "\xe0"'
' [nested_unknown_ext]: {\n'
' i: 23\n'
+ ' x: x\n'
' test: "test_string"\n'
' floaty_float: -0.315\n'
' num: -inf\n'
' multiline_str: "abc"\n'
' "def"\n'
' "xyz."\n'
- ' [nested_unknown_ext]: <\n'
+ ' [nested_unknown_ext.ext]: <\n'
' i: 23\n'
' i: 24\n'
' pointfloat: .3\n'
@@ -896,6 +901,14 @@ class Proto2Tests(TextFormatBase):
self.assertEqual(23, message.message_set.Extensions[ext1].i)
self.assertEqual('foo', message.message_set.Extensions[ext2].str)
+ def testParseBadIdentifier(self):
+ message = unittest_pb2.TestAllTypes()
+ text = ('optional_nested_message { "bb": 1 }')
+ with self.assertRaises(text_format.ParseError) as e:
+ text_format.Parse(text, message)
+ self.assertEqual(str(e.exception),
+ '1:27 : Expected identifier or number, got "bb".')
+
def testParseBadExtension(self):
message = unittest_pb2.TestAllExtensions()
text = '[unknown_extension]: 8\n'
@@ -1095,6 +1108,19 @@ class Proto3Tests(unittest.TestCase):
' < data: "string" > '
'>')
+ def testUnknownEnums(self):
+ message = unittest_proto3_arena_pb2.TestAllTypes()
+ message2 = unittest_proto3_arena_pb2.TestAllTypes()
+ message.optional_nested_enum = 999
+ text_string = text_format.MessageToString(message)
+ # TODO(jieluo): proto3 should support numeric unknown enum.
+ with self.assertRaises(text_format.ParseError) as e:
+ text_format.Parse(text_string, message2)
+ self.assertEqual(999, message2.optional_nested_enum)
+ self.assertEqual(str(e.exception),
+ '1:23 : Enum type "proto3_arena_unittest.TestAllTypes.'
+ 'NestedEnum" has no value with number 999.')
+
def testMergeExpandedAny(self):
message = any_test_pb2.TestAny()
text = ('any_value {\n'
@@ -1180,6 +1206,15 @@ class Proto3Tests(unittest.TestCase):
message.any_value.Unpack(packed_message)
self.assertEqual('string', packed_message.data)
+ def testMergeMissingAnyEndToken(self):
+ message = any_test_pb2.TestAny()
+ text = ('any_value {\n'
+ ' [type.googleapis.com/protobuf_unittest.OneString] {\n'
+ ' data: "string"\n')
+ with self.assertRaises(text_format.ParseError) as e:
+ text_format.Merge(text, message)
+ self.assertEqual(str(e.exception), '3:11 : Expected "}".')
+
class TokenizerTest(unittest.TestCase):
@@ -1191,7 +1226,7 @@ class TokenizerTest(unittest.TestCase):
'ID9: 22 ID10: -111111111111111111 ID11: -22\n'
'ID12: 2222222222222222222 ID13: 1.23456f ID14: 1.2e+2f '
'false_bool: 0 true_BOOL:t \n true_bool1: 1 false_BOOL1:f '
- 'False_bool: False True_bool: True')
+ 'False_bool: False True_bool: True X:iNf Y:-inF Z:nAN')
tokenizer = text_format.Tokenizer(text.splitlines())
methods = [(tokenizer.ConsumeIdentifier, 'identifier1'), ':',
(tokenizer.ConsumeString, 'string1'),
@@ -1239,7 +1274,13 @@ class TokenizerTest(unittest.TestCase):
(tokenizer.ConsumeIdentifier, 'False_bool'), ':',
(tokenizer.ConsumeBool, False),
(tokenizer.ConsumeIdentifier, 'True_bool'), ':',
- (tokenizer.ConsumeBool, True)]
+ (tokenizer.ConsumeBool, True),
+ (tokenizer.ConsumeIdentifier, 'X'), ':',
+ (tokenizer.ConsumeFloat, float('inf')),
+ (tokenizer.ConsumeIdentifier, 'Y'), ':',
+ (tokenizer.ConsumeFloat, float('-inf')),
+ (tokenizer.ConsumeIdentifier, 'Z'), ':',
+ (tokenizer.ConsumeFloat, float('nan'))]
i = 0
while not tokenizer.AtEnd():
@@ -1248,6 +1289,8 @@ class TokenizerTest(unittest.TestCase):
token = tokenizer.token
self.assertEqual(token, m)
tokenizer.NextToken()
+ elif isinstance(m[1], float) and math.isnan(m[1]):
+ self.assertTrue(math.isnan(m[0]()))
else:
self.assertEqual(m[1], m[0]())
i += 1
@@ -1266,10 +1309,15 @@ class TokenizerTest(unittest.TestCase):
self.assertEqual(int64_max + 1, tokenizer.ConsumeInteger())
self.assertTrue(tokenizer.AtEnd())
- text = '-0 0'
+ text = '-0 0 0 1.2'
tokenizer = text_format.Tokenizer(text.splitlines())
self.assertEqual(0, tokenizer.ConsumeInteger())
self.assertEqual(0, tokenizer.ConsumeInteger())
+ self.assertEqual(True, tokenizer.TryConsumeInteger())
+ self.assertEqual(False, tokenizer.TryConsumeInteger())
+ with self.assertRaises(text_format.ParseError):
+ tokenizer.ConsumeInteger()
+ self.assertEqual(1.2, tokenizer.ConsumeFloat())
self.assertTrue(tokenizer.AtEnd())
def testConsumeIntegers(self):
diff --git a/python/google/protobuf/internal/unknown_fields_test.py b/python/google/protobuf/internal/unknown_fields_test.py
index d614eaa8..9bdb6f27 100755
--- a/python/google/protobuf/internal/unknown_fields_test.py
+++ b/python/google/protobuf/internal/unknown_fields_test.py
@@ -54,10 +54,13 @@ from google.protobuf.internal import type_checkers
BaseTestCase = testing_refleaks.BaseTestCase
-def SkipIfCppImplementation(func):
+# CheckUnknownField() cannot be used by the C++ implementation because
+# some protect members are called. It is not a behavior difference
+# for python and C++ implementation.
+def SkipCheckUnknownFieldIfCppImplementation(func):
return unittest.skipIf(
api_implementation.Type() == 'cpp' and api_implementation.Version() == 2,
- 'C++ implementation does not expose unknown fields to Python')(func)
+ 'Addtional test for pure python involved protect members')(func)
class UnknownFieldsTest(BaseTestCase):
@@ -77,11 +80,24 @@ class UnknownFieldsTest(BaseTestCase):
# stdout.
self.assertTrue(data == self.all_fields_data)
- def testSerializeProto3(self):
- # Verify that proto3 doesn't preserve unknown fields.
+ def expectSerializeProto3(self, preserve):
message = unittest_proto3_arena_pb2.TestEmptyMessage()
message.ParseFromString(self.all_fields_data)
- self.assertEqual(0, len(message.SerializeToString()))
+ if preserve:
+ self.assertEqual(self.all_fields_data, message.SerializeToString())
+ else:
+ self.assertEqual(0, len(message.SerializeToString()))
+
+ def testSerializeProto3(self):
+ # Verify that proto3 unknown fields behavior.
+ default_preserve = (api_implementation
+ .GetPythonProto3PreserveUnknownsDefault())
+ self.assertEqual(False, default_preserve)
+ self.expectSerializeProto3(default_preserve)
+ api_implementation.SetPythonProto3PreserveUnknownsDefault(
+ not default_preserve)
+ self.expectSerializeProto3(not default_preserve)
+ api_implementation.SetPythonProto3PreserveUnknownsDefault(default_preserve)
def testByteSize(self):
self.assertEqual(self.all_fields.ByteSize(), self.empty_message.ByteSize())
@@ -154,12 +170,13 @@ class UnknownFieldsAccessorsTest(BaseTestCase):
self.empty_message = unittest_pb2.TestEmptyMessage()
self.empty_message.ParseFromString(self.all_fields_data)
- # GetUnknownField() checks a detail of the Python implementation, which stores
- # unknown fields as serialized strings. It cannot be used by the C++
- # implementation: it's enough to check that the message is correctly
- # serialized.
+ # CheckUnknownField() is an additional Pure Python check which checks
+ # a detail of unknown fields. It cannot be used by the C++
+ # implementation because some protect members are called.
+ # The test is added for historical reasons. It is not necessary as
+ # serialized string is checked.
- def GetUnknownField(self, name):
+ def CheckUnknownField(self, name, expected_value):
field_descriptor = self.descriptor.fields_by_name[name]
wire_type = type_checkers.FIELD_TYPE_TO_WIRE_TYPE[field_descriptor.type]
field_tag = encoder.TagBytes(field_descriptor.number, wire_type)
@@ -168,42 +185,35 @@ class UnknownFieldsAccessorsTest(BaseTestCase):
if tag_bytes == field_tag:
decoder = unittest_pb2.TestAllTypes._decoders_by_tag[tag_bytes][0]
decoder(value, 0, len(value), self.all_fields, result_dict)
- return result_dict[field_descriptor]
-
- @SkipIfCppImplementation
- def testEnum(self):
- value = self.GetUnknownField('optional_nested_enum')
- self.assertEqual(self.all_fields.optional_nested_enum, value)
-
- @SkipIfCppImplementation
- def testRepeatedEnum(self):
- value = self.GetUnknownField('repeated_nested_enum')
- self.assertEqual(self.all_fields.repeated_nested_enum, value)
-
- @SkipIfCppImplementation
- def testVarint(self):
- value = self.GetUnknownField('optional_int32')
- self.assertEqual(self.all_fields.optional_int32, value)
-
- @SkipIfCppImplementation
- def testFixed32(self):
- value = self.GetUnknownField('optional_fixed32')
- self.assertEqual(self.all_fields.optional_fixed32, value)
-
- @SkipIfCppImplementation
- def testFixed64(self):
- value = self.GetUnknownField('optional_fixed64')
- self.assertEqual(self.all_fields.optional_fixed64, value)
-
- @SkipIfCppImplementation
- def testLengthDelimited(self):
- value = self.GetUnknownField('optional_string')
- self.assertEqual(self.all_fields.optional_string, value)
-
- @SkipIfCppImplementation
- def testGroup(self):
- value = self.GetUnknownField('optionalgroup')
- self.assertEqual(self.all_fields.optionalgroup, value)
+ self.assertEqual(expected_value, result_dict[field_descriptor])
+
+ @SkipCheckUnknownFieldIfCppImplementation
+ def testCheckUnknownFieldValue(self):
+ # Test enum.
+ self.CheckUnknownField('optional_nested_enum',
+ self.all_fields.optional_nested_enum)
+ # Test repeated enum.
+ self.CheckUnknownField('repeated_nested_enum',
+ self.all_fields.repeated_nested_enum)
+
+ # Test varint.
+ self.CheckUnknownField('optional_int32',
+ self.all_fields.optional_int32)
+ # Test fixed32.
+ self.CheckUnknownField('optional_fixed32',
+ self.all_fields.optional_fixed32)
+
+ # Test fixed64.
+ self.CheckUnknownField('optional_fixed64',
+ self.all_fields.optional_fixed64)
+
+ # Test lengthd elimited.
+ self.CheckUnknownField('optional_string',
+ self.all_fields.optional_string)
+
+ # Test group.
+ self.CheckUnknownField('optionalgroup',
+ self.all_fields.optionalgroup)
def testCopyFrom(self):
message = unittest_pb2.TestEmptyMessage()
@@ -263,12 +273,13 @@ class UnknownEnumValuesTest(BaseTestCase):
self.missing_message = missing_enum_values_pb2.TestMissingEnumValues()
self.missing_message.ParseFromString(self.message_data)
- # GetUnknownField() checks a detail of the Python implementation, which stores
- # unknown fields as serialized strings. It cannot be used by the C++
- # implementation: it's enough to check that the message is correctly
- # serialized.
+ # CheckUnknownField() is an additional Pure Python check which checks
+ # a detail of unknown fields. It cannot be used by the C++
+ # implementation because some protect members are called.
+ # The test is added for historical reasons. It is not necessary as
+ # serialized string is checked.
- def GetUnknownField(self, name):
+ def CheckUnknownField(self, name, expected_value):
field_descriptor = self.descriptor.fields_by_name[name]
wire_type = type_checkers.FIELD_TYPE_TO_WIRE_TYPE[field_descriptor.type]
field_tag = encoder.TagBytes(field_descriptor.number, wire_type)
@@ -278,7 +289,7 @@ class UnknownEnumValuesTest(BaseTestCase):
decoder = missing_enum_values_pb2.TestEnumValues._decoders_by_tag[
tag_bytes][0]
decoder(value, 0, len(value), self.message, result_dict)
- return result_dict[field_descriptor]
+ self.assertEqual(expected_value, result_dict[field_descriptor])
def testUnknownParseMismatchEnumValue(self):
just_string = missing_enum_values_pb2.JustString()
@@ -294,38 +305,27 @@ class UnknownEnumValuesTest(BaseTestCase):
self.assertEqual(missing.optional_nested_enum, 0)
def testUnknownEnumValue(self):
- if api_implementation.Type() == 'cpp':
- # The CPP implementation of protos (wrongly) allows unknown enum values
- # for proto2.
- self.assertTrue(self.missing_message.HasField('optional_nested_enum'))
- self.assertEqual(self.message.optional_nested_enum,
- self.missing_message.optional_nested_enum)
- else:
- # On the other hand, the Python implementation considers unknown values
- # as unknown fields. This is the correct behavior.
- self.assertFalse(self.missing_message.HasField('optional_nested_enum'))
- value = self.GetUnknownField('optional_nested_enum')
- self.assertEqual(self.message.optional_nested_enum, value)
- self.missing_message.ClearField('optional_nested_enum')
self.assertFalse(self.missing_message.HasField('optional_nested_enum'))
+ self.assertEqual(self.missing_message.optional_nested_enum, 2)
+ # Clear does not do anything.
+ serialized = self.missing_message.SerializeToString()
+ self.missing_message.ClearField('optional_nested_enum')
+ self.assertEqual(self.missing_message.SerializeToString(), serialized)
def testUnknownRepeatedEnumValue(self):
- if api_implementation.Type() == 'cpp':
- # For repeated enums, both implementations agree.
- self.assertEqual([], self.missing_message.repeated_nested_enum)
- else:
- self.assertEqual([], self.missing_message.repeated_nested_enum)
- value = self.GetUnknownField('repeated_nested_enum')
- self.assertEqual(self.message.repeated_nested_enum, value)
+ self.assertEqual([], self.missing_message.repeated_nested_enum)
def testUnknownPackedEnumValue(self):
- if api_implementation.Type() == 'cpp':
- # For repeated enums, both implementations agree.
- self.assertEqual([], self.missing_message.packed_nested_enum)
- else:
- self.assertEqual([], self.missing_message.packed_nested_enum)
- value = self.GetUnknownField('packed_nested_enum')
- self.assertEqual(self.message.packed_nested_enum, value)
+ self.assertEqual([], self.missing_message.packed_nested_enum)
+
+ @SkipCheckUnknownFieldIfCppImplementation
+ def testCheckUnknownFieldValueForEnum(self):
+ self.CheckUnknownField('optional_nested_enum',
+ self.message.optional_nested_enum)
+ self.CheckUnknownField('repeated_nested_enum',
+ self.message.repeated_nested_enum)
+ self.CheckUnknownField('packed_nested_enum',
+ self.message.packed_nested_enum)
def testRoundTrip(self):
new_message = missing_enum_values_pb2.TestEnumValues()
diff --git a/python/google/protobuf/internal/well_known_types.py b/python/google/protobuf/internal/well_known_types.py
index d0c7ffda..57b96998 100644
--- a/python/google/protobuf/internal/well_known_types.py
+++ b/python/google/protobuf/internal/well_known_types.py
@@ -473,7 +473,7 @@ def _IsValidPath(message_descriptor, path):
parts = path.split('.')
last = parts.pop()
for name in parts:
- field = message_descriptor.fields_by_name[name]
+ field = message_descriptor.fields_by_name.get(name)
if (field is None or
field.label == FieldDescriptor.LABEL_REPEATED or
field.type != FieldDescriptor.TYPE_MESSAGE):
@@ -698,6 +698,12 @@ def _SetStructValue(struct_value, value):
struct_value.string_value = value
elif isinstance(value, _INT_OR_FLOAT):
struct_value.number_value = value
+ elif isinstance(value, dict):
+ struct_value.struct_value.Clear()
+ struct_value.struct_value.update(value)
+ elif isinstance(value, list):
+ struct_value.list_value.Clear()
+ struct_value.list_value.extend(value)
else:
raise ValueError('Unexpected type')
@@ -733,13 +739,21 @@ class Struct(object):
def get_or_create_list(self, key):
"""Returns a list for this key, creating if it didn't exist already."""
+ if not self.fields[key].HasField('list_value'):
+ # Clear will mark list_value modified which will indeed create a list.
+ self.fields[key].list_value.Clear()
return self.fields[key].list_value
def get_or_create_struct(self, key):
"""Returns a struct for this key, creating if it didn't exist already."""
+ if not self.fields[key].HasField('struct_value'):
+ # Clear will mark struct_value modified which will indeed create a struct.
+ self.fields[key].struct_value.Clear()
return self.fields[key].struct_value
- # TODO(haberman): allow constructing/merging from dict.
+ def update(self, dictionary): # pylint: disable=invalid-name
+ for key, value in dictionary.items():
+ _SetStructValue(self.fields[key], value)
class ListValue(object):
@@ -768,11 +782,17 @@ class ListValue(object):
def add_struct(self):
"""Appends and returns a struct value as the next value in the list."""
- return self.values.add().struct_value
+ struct_value = self.values.add().struct_value
+ # Clear will mark struct_value modified which will indeed create a struct.
+ struct_value.Clear()
+ return struct_value
def add_list(self):
"""Appends and returns a list value as the next value in the list."""
- return self.values.add().list_value
+ list_value = self.values.add().list_value
+ # Clear will mark list_value modified which will indeed create a list.
+ list_value.Clear()
+ return list_value
WKTBASES = {
diff --git a/python/google/protobuf/internal/well_known_types_test.py b/python/google/protobuf/internal/well_known_types_test.py
index 123a537c..70975da1 100644
--- a/python/google/protobuf/internal/well_known_types_test.py
+++ b/python/google/protobuf/internal/well_known_types_test.py
@@ -105,6 +105,10 @@ class TimeUtilTest(TimeUtilTestBase):
self.assertEqual(8 * 3600, message.seconds)
self.assertEqual(0, message.nanos)
+ # It is not easy to check with current time. For test coverage only.
+ message.GetCurrentTime()
+ self.assertNotEqual(8 * 3600, message.seconds)
+
def testDurationSerializeAndParse(self):
message = duration_pb2.Duration()
# Generated output should contain 3, 6, or 9 fractional digits.
@@ -268,6 +272,17 @@ class TimeUtilTest(TimeUtilTestBase):
def testInvalidTimestamp(self):
message = timestamp_pb2.Timestamp()
self.assertRaisesRegexp(
+ well_known_types.ParseError,
+ 'Failed to parse timestamp: missing valid timezone offset.',
+ message.FromJsonString,
+ '')
+ self.assertRaisesRegexp(
+ well_known_types.ParseError,
+ 'Failed to parse timestamp: invalid trailing data '
+ '1970-01-01T00:00:01Ztrail.',
+ message.FromJsonString,
+ '1970-01-01T00:00:01Ztrail')
+ self.assertRaisesRegexp(
ValueError,
'time data \'10000-01-01T00:00:00\' does not match'
' format \'%Y-%m-%dT%H:%M:%S\'',
@@ -322,6 +337,13 @@ class TimeUtilTest(TimeUtilTestBase):
r'Duration is not valid\: Seconds -315576000001 must be in range'
r' \[-315576000000\, 315576000000\].',
message.ToJsonString)
+ message.seconds = 0
+ message.nanos = 999999999 + 1
+ self.assertRaisesRegexp(
+ well_known_types.Error,
+ r'Duration is not valid\: Nanos 1000000000 must be in range'
+ r' \[-999999999\, 999999999\].',
+ message.ToJsonString)
class FieldMaskTest(unittest.TestCase):
@@ -363,10 +385,37 @@ class FieldMaskTest(unittest.TestCase):
self.assertTrue(mask.IsValidForDescriptor(msg_descriptor))
for field in msg_descriptor.fields:
self.assertTrue(field.name in mask.paths)
+
+ def testIsValidForDescriptor(self):
+ msg_descriptor = unittest_pb2.TestAllTypes.DESCRIPTOR
+ # Empty mask
+ mask = field_mask_pb2.FieldMask()
+ self.assertTrue(mask.IsValidForDescriptor(msg_descriptor))
+ # All fields from descriptor
+ mask.AllFieldsFromDescriptor(msg_descriptor)
+ self.assertTrue(mask.IsValidForDescriptor(msg_descriptor))
+ # Child under optional message
mask.paths.append('optional_nested_message.bb')
self.assertTrue(mask.IsValidForDescriptor(msg_descriptor))
+ # Repeated field is only allowed in the last position of path
mask.paths.append('repeated_nested_message.bb')
self.assertFalse(mask.IsValidForDescriptor(msg_descriptor))
+ # Invalid top level field
+ mask = field_mask_pb2.FieldMask()
+ mask.paths.append('xxx')
+ self.assertFalse(mask.IsValidForDescriptor(msg_descriptor))
+ # Invalid field in root
+ mask = field_mask_pb2.FieldMask()
+ mask.paths.append('xxx.zzz')
+ self.assertFalse(mask.IsValidForDescriptor(msg_descriptor))
+ # Invalid field in internal node
+ mask = field_mask_pb2.FieldMask()
+ mask.paths.append('optional_nested_message.xxx.zzz')
+ self.assertFalse(mask.IsValidForDescriptor(msg_descriptor))
+ # Invalid field in leaf
+ mask = field_mask_pb2.FieldMask()
+ mask.paths.append('optional_nested_message.xxx')
+ self.assertFalse(mask.IsValidForDescriptor(msg_descriptor))
def testCanonicalFrom(self):
mask = field_mask_pb2.FieldMask()
@@ -422,6 +471,9 @@ class FieldMaskTest(unittest.TestCase):
mask2.FromJsonString('foo.bar,bar')
out_mask.Union(mask1, mask2)
self.assertEqual('bar,foo.bar,quz', out_mask.ToJsonString())
+ src = unittest_pb2.TestAllTypes()
+ with self.assertRaises(ValueError):
+ out_mask.Union(src, mask2)
def testIntersect(self):
mask1 = field_mask_pb2.FieldMask()
@@ -546,6 +598,19 @@ class FieldMaskTest(unittest.TestCase):
self.assertEqual(1, len(nested_dst.payload.repeated_int32))
self.assertEqual(1234, nested_dst.payload.repeated_int32[0])
+ def testMergeErrors(self):
+ src = unittest_pb2.TestAllTypes()
+ dst = unittest_pb2.TestAllTypes()
+ mask = field_mask_pb2.FieldMask()
+ test_util.SetAllFields(src)
+ mask.FromJsonString('optionalInt32.field')
+ with self.assertRaises(ValueError) as e:
+ mask.MergeMessage(src, dst)
+ self.assertEqual('Error: Field optional_int32 in message '
+ 'protobuf_unittest.TestAllTypes is not a singular '
+ 'message field and cannot have sub-fields.',
+ str(e.exception))
+
def testSnakeCaseToCamelCase(self):
self.assertEqual('fooBar',
well_known_types._SnakeCaseToCamelCase('foo_bar'))
@@ -611,6 +676,8 @@ class StructTest(unittest.TestCase):
struct_list = struct.get_or_create_list('key5')
struct_list.extend([6, 'seven', True, False, None])
struct_list.add_struct()['subkey2'] = 9
+ struct['key6'] = {'subkey': {}}
+ struct['key7'] = [2, False]
self.assertTrue(isinstance(struct, well_known_types.Struct))
self.assertEqual(5, struct['key1'])
@@ -621,9 +688,10 @@ class StructTest(unittest.TestCase):
inner_struct['subkey2'] = 9
self.assertEqual([6, 'seven', True, False, None, inner_struct],
list(struct['key5'].items()))
+ self.assertEqual({}, dict(struct['key6']['subkey'].fields))
+ self.assertEqual([2, False], list(struct['key7'].items()))
serialized = struct.SerializeToString()
-
struct2 = struct_pb2.Struct()
struct2.ParseFromString(serialized)
@@ -651,6 +719,17 @@ class StructTest(unittest.TestCase):
struct_list.add_list().extend([1, 'two', True, False, None])
self.assertEqual([1, 'two', True, False, None],
list(struct_list[6].items()))
+ struct_list.extend([{'nested_struct': 30}, ['nested_list', 99], {}, []])
+ self.assertEqual(11, len(struct_list.values))
+ self.assertEqual(30, struct_list[7]['nested_struct'])
+ self.assertEqual('nested_list', struct_list[8][0])
+ self.assertEqual(99, struct_list[8][1])
+ self.assertEqual({}, dict(struct_list[9].fields))
+ self.assertEqual([], list(struct_list[10].items()))
+ struct_list[0] = {'replace': 'set'}
+ struct_list[1] = ['replace', 'set']
+ self.assertEqual('set', struct_list[0]['replace'])
+ self.assertEqual(['replace', 'set'], list(struct_list[1].items()))
text_serialized = str(struct)
struct3 = struct_pb2.Struct()
@@ -660,6 +739,67 @@ class StructTest(unittest.TestCase):
struct.get_or_create_struct('key3')['replace'] = 12
self.assertEqual(12, struct['key3']['replace'])
+ # Tests empty list.
+ struct.get_or_create_list('empty_list')
+ empty_list = struct['empty_list']
+ self.assertEqual([], list(empty_list.items()))
+ list2 = struct_pb2.ListValue()
+ list2.add_list()
+ empty_list = list2[0]
+ self.assertEqual([], list(empty_list.items()))
+
+ # Tests empty struct.
+ struct.get_or_create_struct('empty_struct')
+ empty_struct = struct['empty_struct']
+ self.assertEqual({}, dict(empty_struct.fields))
+ list2.add_struct()
+ empty_struct = list2[1]
+ self.assertEqual({}, dict(empty_struct.fields))
+
+ def testMergeFrom(self):
+ struct = struct_pb2.Struct()
+ struct_class = struct.__class__
+
+ dictionary = {
+ 'key1': 5,
+ 'key2': 'abc',
+ 'key3': True,
+ 'key4': {'subkey': 11.0},
+ 'key5': [6, 'seven', True, False, None, {'subkey2': 9}],
+ 'key6': [['nested_list', True]],
+ 'empty_struct': {},
+ 'empty_list': []
+ }
+ struct.update(dictionary)
+ self.assertEqual(5, struct['key1'])
+ self.assertEqual('abc', struct['key2'])
+ self.assertIs(True, struct['key3'])
+ self.assertEqual(11, struct['key4']['subkey'])
+ inner_struct = struct_class()
+ inner_struct['subkey2'] = 9
+ self.assertEqual([6, 'seven', True, False, None, inner_struct],
+ list(struct['key5'].items()))
+ self.assertEqual(2, len(struct['key6'][0].values))
+ self.assertEqual('nested_list', struct['key6'][0][0])
+ self.assertEqual(True, struct['key6'][0][1])
+ empty_list = struct['empty_list']
+ self.assertEqual([], list(empty_list.items()))
+ empty_struct = struct['empty_struct']
+ self.assertEqual({}, dict(empty_struct.fields))
+
+ # According to documentation: "When parsing from the wire or when merging,
+ # if there are duplicate map keys the last key seen is used".
+ duplicate = {
+ 'key4': {'replace': 20},
+ 'key5': [[False, 5]]
+ }
+ struct.update(duplicate)
+ self.assertEqual(1, len(struct['key4'].fields))
+ self.assertEqual(20, struct['key4']['replace'])
+ self.assertEqual(1, len(struct['key5'].values))
+ self.assertEqual(False, struct['key5'][0][0])
+ self.assertEqual(5, struct['key5'][0][1])
+
class AnyTest(unittest.TestCase):
diff --git a/python/google/protobuf/pyext/map_container.cc b/python/google/protobuf/pyext/map_container.cc
index 43be0701..abd15b77 100644
--- a/python/google/protobuf/pyext/map_container.cc
+++ b/python/google/protobuf/pyext/map_container.cc
@@ -720,14 +720,17 @@ int MapReflectionFriend::MessageMapSetItem(PyObject* _self, PyObject* key,
map_key, &value);
ScopedPyObjectPtr key(PyLong_FromVoidPtr(value.MutableMessageValue()));
- // PyDict_DelItem will have key error if the key is not in the map. We do
- // not want to call PyErr_Clear() which may clear other errors. Thus
- // PyDict_Contains() check is called before delete.
- int contains = PyDict_Contains(self->message_dict, key.get());
- if (contains < 0) {
- return -1;
- }
- if (contains) {
+ PyObject* cmsg_value = PyDict_GetItem(self->message_dict, key.get());
+ if (cmsg_value) {
+ // Need to keep CMessage stay alive if it is still referenced after
+ // deletion. Makes a new message and swaps values into CMessage
+ // instead of just removing.
+ CMessage* cmsg = reinterpret_cast<CMessage*>(cmsg_value);
+ Message* msg = cmsg->message;
+ cmsg->owner.reset(msg->New());
+ cmsg->message = cmsg->owner.get();
+ cmsg->parent = NULL;
+ msg->GetReflection()->Swap(msg, cmsg->message);
if (PyDict_DelItem(self->message_dict, key.get()) < 0) {
return -1;
}
diff --git a/python/google/protobuf/pyext/message.cc b/python/google/protobuf/pyext/message.cc
index 43482c54..0f54506b 100644
--- a/python/google/protobuf/pyext/message.cc
+++ b/python/google/protobuf/pyext/message.cc
@@ -605,19 +605,21 @@ void OutOfRangeError(PyObject* arg) {
template<class RangeType, class ValueType>
bool VerifyIntegerCastAndRange(PyObject* arg, ValueType value) {
- if GOOGLE_PREDICT_FALSE(value == -1 && PyErr_Occurred()) {
- if (PyErr_ExceptionMatches(PyExc_OverflowError)) {
- // Replace it with the same ValueError as pure python protos instead of
- // the default one.
- PyErr_Clear();
+ if
+ GOOGLE_PREDICT_FALSE(value == -1 && PyErr_Occurred()) {
+ if (PyErr_ExceptionMatches(PyExc_OverflowError)) {
+ // Replace it with the same ValueError as pure python protos instead of
+ // the default one.
+ PyErr_Clear();
+ OutOfRangeError(arg);
+ } // Otherwise propagate existing error.
+ return false;
+ }
+ if
+ GOOGLE_PREDICT_FALSE(!IsValidNumericCast<RangeType>(value)) {
OutOfRangeError(arg);
- } // Otherwise propagate existing error.
- return false;
- }
- if GOOGLE_PREDICT_FALSE(!IsValidNumericCast<RangeType>(value)) {
- OutOfRangeError(arg);
- return false;
- }
+ return false;
+ }
return true;
}
@@ -626,25 +628,29 @@ bool CheckAndGetInteger(PyObject* arg, T* value) {
// The fast path.
#if PY_MAJOR_VERSION < 3
// For the typical case, offer a fast path.
- if GOOGLE_PREDICT_TRUE(PyInt_Check(arg)) {
- long int_result = PyInt_AsLong(arg);
- if GOOGLE_PREDICT_TRUE(IsValidNumericCast<T>(int_result)) {
- *value = static_cast<T>(int_result);
- return true;
- } else {
- OutOfRangeError(arg);
- return false;
+ if
+ GOOGLE_PREDICT_TRUE(PyInt_Check(arg)) {
+ long int_result = PyInt_AsLong(arg);
+ if
+ GOOGLE_PREDICT_TRUE(IsValidNumericCast<T>(int_result)) {
+ *value = static_cast<T>(int_result);
+ return true;
+ }
+ else {
+ OutOfRangeError(arg);
+ return false;
+ }
}
- }
#endif
// This effectively defines an integer as "an object that can be cast as
// an integer and can be used as an ordinal number".
// This definition includes everything that implements numbers.Integral
// and shouldn't cast the net too wide.
- if GOOGLE_PREDICT_FALSE(!PyIndex_Check(arg)) {
- FormatTypeError(arg, "int, long");
- return false;
- }
+ if
+ GOOGLE_PREDICT_FALSE(!PyIndex_Check(arg)) {
+ FormatTypeError(arg, "int, long");
+ return false;
+ }
// Now we have an integral number so we can safely use PyLong_ functions.
// We need to treat the signed and unsigned cases differently in case arg is
@@ -658,10 +664,11 @@ bool CheckAndGetInteger(PyObject* arg, T* value) {
// Unlike PyLong_AsLongLong, PyLong_AsUnsignedLongLong is very
// picky about the exact type.
PyObject* casted = PyNumber_Long(arg);
- if GOOGLE_PREDICT_FALSE(casted == NULL) {
- // Propagate existing error.
- return false;
- }
+ if
+ GOOGLE_PREDICT_FALSE(casted == NULL) {
+ // Propagate existing error.
+ return false;
+ }
ulong_result = PyLong_AsUnsignedLongLong(casted);
Py_DECREF(casted);
}
@@ -683,10 +690,11 @@ bool CheckAndGetInteger(PyObject* arg, T* value) {
// Valid subclasses of numbers.Integral should have a __long__() method
// so fall back to that.
PyObject* casted = PyNumber_Long(arg);
- if GOOGLE_PREDICT_FALSE(casted == NULL) {
- // Propagate existing error.
- return false;
- }
+ if
+ GOOGLE_PREDICT_FALSE(casted == NULL) {
+ // Propagate existing error.
+ return false;
+ }
long_result = PyLong_AsLongLong(casted);
Py_DECREF(casted);
}
@@ -709,10 +717,11 @@ template bool CheckAndGetInteger<uint64>(PyObject*, uint64*);
bool CheckAndGetDouble(PyObject* arg, double* value) {
*value = PyFloat_AsDouble(arg);
- if GOOGLE_PREDICT_FALSE(*value == -1 && PyErr_Occurred()) {
- FormatTypeError(arg, "int, long, float");
- return false;
- }
+ if
+ GOOGLE_PREDICT_FALSE(*value == -1 && PyErr_Occurred()) {
+ FormatTypeError(arg, "int, long, float");
+ return false;
+ }
return true;
}
@@ -1553,20 +1562,7 @@ PyObject* HasField(CMessage* self, PyObject* arg) {
if (message->GetReflection()->HasField(*message, field_descriptor)) {
Py_RETURN_TRUE;
}
- if (!message->GetReflection()->SupportsUnknownEnumValues() &&
- field_descriptor->cpp_type() == FieldDescriptor::CPPTYPE_ENUM) {
- // Special case: Python HasField() differs in semantics from C++
- // slightly: we return HasField('enum_field') == true if there is
- // an unknown enum value present. To implement this we have to
- // look in the UnknownFieldSet.
- const UnknownFieldSet& unknown_field_set =
- message->GetReflection()->GetUnknownFields(*message);
- for (int i = 0; i < unknown_field_set.field_count(); ++i) {
- if (unknown_field_set.field(i).number() == field_descriptor->number()) {
- Py_RETURN_TRUE;
- }
- }
- }
+
Py_RETURN_FALSE;
}
@@ -1745,12 +1741,6 @@ PyObject* ClearFieldByDescriptor(
AssureWritable(self);
Message* message = self->message;
message->GetReflection()->ClearField(message, field_descriptor);
- if (field_descriptor->cpp_type() == FieldDescriptor::CPPTYPE_ENUM &&
- !message->GetReflection()->SupportsUnknownEnumValues()) {
- UnknownFieldSet* unknown_field_set =
- message->GetReflection()->MutableUnknownFields(message);
- unknown_field_set->DeleteByNumber(field_descriptor->number());
- }
Py_RETURN_NONE;
}
@@ -2345,27 +2335,9 @@ PyObject* InternalGetScalar(const Message* message,
break;
}
case FieldDescriptor::CPPTYPE_ENUM: {
- if (!message->GetReflection()->SupportsUnknownEnumValues() &&
- !message->GetReflection()->HasField(*message, field_descriptor)) {
- // Look for the value in the unknown fields.
- const UnknownFieldSet& unknown_field_set =
- message->GetReflection()->GetUnknownFields(*message);
- for (int i = 0; i < unknown_field_set.field_count(); ++i) {
- if (unknown_field_set.field(i).number() ==
- field_descriptor->number() &&
- unknown_field_set.field(i).type() ==
- google::protobuf::UnknownField::TYPE_VARINT) {
- result = PyInt_FromLong(unknown_field_set.field(i).varint());
- break;
- }
- }
- }
-
- if (result == NULL) {
- const EnumValueDescriptor* enum_value =
- message->GetReflection()->GetEnum(*message, field_descriptor);
- result = PyInt_FromLong(enum_value->number());
- }
+ const EnumValueDescriptor* enum_value =
+ message->GetReflection()->GetEnum(*message, field_descriptor);
+ result = PyInt_FromLong(enum_value->number());
break;
}
default:
@@ -3087,5 +3059,4 @@ bool InitProto2MessageModule(PyObject *m) {
} // namespace python
} // namespace protobuf
-
} // namespace google
diff --git a/python/google/protobuf/pyext/message_module.cc b/python/google/protobuf/pyext/message_module.cc
index d90d9de3..7c4df47f 100644
--- a/python/google/protobuf/pyext/message_module.cc
+++ b/python/google/protobuf/pyext/message_module.cc
@@ -28,8 +28,33 @@
// (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
// OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
+#include <Python.h>
+
#include <google/protobuf/pyext/message.h>
+#include <google/protobuf/message_lite.h>
+
+static PyObject* GetPythonProto3PreserveUnknownsDefault(
+ PyObject* /*m*/, PyObject* /*args*/) {
+ if (google::protobuf::internal::GetProto3PreserveUnknownsDefault()) {
+ Py_RETURN_TRUE;
+ } else {
+ Py_RETURN_FALSE;
+ }
+}
+
+static PyObject* SetPythonProto3PreserveUnknownsDefault(
+ PyObject* /*m*/, PyObject* arg) {
+ if (!arg || !PyBool_Check(arg)) {
+ PyErr_SetString(
+ PyExc_TypeError,
+ "Argument to SetPythonProto3PreserveUnknownsDefault must be boolean");
+ return NULL;
+ }
+ google::protobuf::internal::SetProto3PreserveUnknownsDefault(PyObject_IsTrue(arg));
+ Py_RETURN_NONE;
+}
+
static const char module_docstring[] =
"python-proto2 is a module that can be used to enhance proto2 Python API\n"
"performance.\n"
@@ -41,6 +66,14 @@ static PyMethodDef ModuleMethods[] = {
{"SetAllowOversizeProtos",
(PyCFunction)google::protobuf::python::cmessage::SetAllowOversizeProtos,
METH_O, "Enable/disable oversize proto parsing."},
+ // DO NOT USE: For migration and testing only.
+ {"GetPythonProto3PreserveUnknownsDefault",
+ (PyCFunction)GetPythonProto3PreserveUnknownsDefault,
+ METH_NOARGS, "Get Proto3 preserve unknowns default."},
+ // DO NOT USE: For migration and testing only.
+ {"SetPythonProto3PreserveUnknownsDefault",
+ (PyCFunction)SetPythonProto3PreserveUnknownsDefault,
+ METH_O, "Enable/disable proto3 unknowns preservation."},
{ NULL, NULL}
};
diff --git a/python/google/protobuf/pyext/python.proto b/python/google/protobuf/pyext/python.proto
index cce645d7..2e50df74 100644
--- a/python/google/protobuf/pyext/python.proto
+++ b/python/google/protobuf/pyext/python.proto
@@ -58,11 +58,11 @@ message ForeignMessage {
repeated int32 d = 2;
}
-message TestAllExtensions {
+message TestAllExtensions { // extension begin
extensions 1 to max;
-}
+} // extension end
-extend TestAllExtensions {
+extend TestAllExtensions { // extension begin
optional TestAllTypes.NestedMessage optional_nested_message_extension = 1;
repeated TestAllTypes.NestedMessage repeated_nested_message_extension = 2;
-}
+} // extension end
diff --git a/python/google/protobuf/pyext/repeated_scalar_container.cc b/python/google/protobuf/pyext/repeated_scalar_container.cc
index 54998800..5a7832cd 100644
--- a/python/google/protobuf/pyext/repeated_scalar_container.cc
+++ b/python/google/protobuf/pyext/repeated_scalar_container.cc
@@ -261,22 +261,6 @@ static PyObject* Item(RepeatedScalarContainer* self, Py_ssize_t index) {
result = ToStringObject(field_descriptor, value);
break;
}
- case FieldDescriptor::CPPTYPE_MESSAGE: {
- PyObject* py_cmsg = PyObject_CallObject(reinterpret_cast<PyObject*>(
- &CMessage_Type), NULL);
- if (py_cmsg == NULL) {
- return NULL;
- }
- CMessage* cmsg = reinterpret_cast<CMessage*>(py_cmsg);
- const Message& msg = reflection->GetRepeatedMessage(
- *message, field_descriptor, index);
- cmsg->owner = self->owner;
- cmsg->parent = self->parent;
- cmsg->message = const_cast<Message*>(&msg);
- cmsg->read_only = false;
- result = reinterpret_cast<PyObject*>(py_cmsg);
- break;
- }
default:
PyErr_Format(
PyExc_SystemError,
diff --git a/python/setup.py b/python/setup.py
index 70b7de5c..efb74fe7 100755
--- a/python/setup.py
+++ b/python/setup.py
@@ -78,6 +78,7 @@ def generate_proto(source, require = True):
def GenerateUnittestProtos():
generate_proto("../src/google/protobuf/any_test.proto", False)
+ generate_proto("../src/google/protobuf/map_proto2_unittest.proto", False)
generate_proto("../src/google/protobuf/map_unittest.proto", False)
generate_proto("../src/google/protobuf/test_messages_proto3.proto", False)
generate_proto("../src/google/protobuf/test_messages_proto2.proto", False)