aboutsummaryrefslogtreecommitdiffhomepage
path: root/python/google/protobuf
diff options
context:
space:
mode:
Diffstat (limited to 'python/google/protobuf')
-rwxr-xr-xpython/google/protobuf/__init__.py8
-rw-r--r--python/google/protobuf/compiler/__init__.py0
-rwxr-xr-xpython/google/protobuf/descriptor.py210
-rw-r--r--python/google/protobuf/descriptor_database.py49
-rw-r--r--python/google/protobuf/descriptor_pool.py484
-rwxr-xr-xpython/google/protobuf/internal/_parameterized.py52
-rw-r--r--python/google/protobuf/internal/any_test.proto15
-rwxr-xr-xpython/google/protobuf/internal/api_implementation.py74
-rwxr-xr-xpython/google/protobuf/internal/containers.py33
-rwxr-xr-xpython/google/protobuf/internal/decoder.py24
-rw-r--r--python/google/protobuf/internal/descriptor_database_test.py58
-rw-r--r--python/google/protobuf/internal/descriptor_pool_test.py478
-rw-r--r--python/google/protobuf/internal/descriptor_pool_test2.proto1
-rwxr-xr-xpython/google/protobuf/internal/descriptor_test.py381
-rwxr-xr-xpython/google/protobuf/internal/encoder.py117
-rw-r--r--python/google/protobuf/internal/factory_test2.proto5
-rw-r--r--python/google/protobuf/internal/file_options_test.proto43
-rwxr-xr-xpython/google/protobuf/internal/generator_test.py6
-rw-r--r--python/google/protobuf/internal/json_format_test.py310
-rw-r--r--python/google/protobuf/internal/message_factory_test.py113
-rwxr-xr-xpython/google/protobuf/internal/message_test.py523
-rw-r--r--python/google/protobuf/internal/more_extensions_dynamic.proto1
-rw-r--r--python/google/protobuf/internal/no_package.proto10
-rw-r--r--python/google/protobuf/internal/proto_builder_test.py1
-rwxr-xr-xpython/google/protobuf/internal/python_message.py203
-rw-r--r--python/google/protobuf/internal/python_protobuf.cc63
-rwxr-xr-xpython/google/protobuf/internal/reflection_test.py198
-rwxr-xr-xpython/google/protobuf/internal/service_reflection_test.py8
-rw-r--r--python/google/protobuf/internal/symbol_database_test.py38
-rwxr-xr-xpython/google/protobuf/internal/test_util.py210
-rw-r--r--python/google/protobuf/internal/testing_refleaks.py126
-rwxr-xr-xpython/google/protobuf/internal/text_encoding_test.py3
-rwxr-xr-xpython/google/protobuf/internal/text_format_test.py1065
-rwxr-xr-xpython/google/protobuf/internal/type_checkers.py32
-rwxr-xr-xpython/google/protobuf/internal/unknown_fields_test.py200
-rw-r--r--python/google/protobuf/internal/well_known_types.py159
-rw-r--r--python/google/protobuf/internal/well_known_types_test.py347
-rwxr-xr-xpython/google/protobuf/internal/wire_format_test.py3
-rw-r--r--python/google/protobuf/json_format.py904
-rwxr-xr-xpython/google/protobuf/message.py28
-rw-r--r--python/google/protobuf/message_factory.py38
-rw-r--r--python/google/protobuf/proto_api.h92
-rw-r--r--python/google/protobuf/pyext/__init__.py4
-rw-r--r--python/google/protobuf/pyext/cpp_message.py6
-rw-r--r--python/google/protobuf/pyext/descriptor.cc414
-rw-r--r--python/google/protobuf/pyext/descriptor.h8
-rw-r--r--python/google/protobuf/pyext/descriptor_containers.cc616
-rw-r--r--python/google/protobuf/pyext/descriptor_containers.h8
-rw-r--r--python/google/protobuf/pyext/descriptor_database.cc3
-rw-r--r--python/google/protobuf/pyext/descriptor_pool.cc270
-rw-r--r--python/google/protobuf/pyext/descriptor_pool.h39
-rw-r--r--python/google/protobuf/pyext/extension_dict.cc138
-rw-r--r--python/google/protobuf/pyext/extension_dict.h52
-rw-r--r--python/google/protobuf/pyext/map_container.cc166
-rw-r--r--python/google/protobuf/pyext/map_container.h36
-rw-r--r--python/google/protobuf/pyext/message.cc1033
-rw-r--r--python/google/protobuf/pyext/message.h132
-rw-r--r--python/google/protobuf/pyext/message_factory.cc283
-rw-r--r--python/google/protobuf/pyext/message_factory.h103
-rw-r--r--python/google/protobuf/pyext/message_module.cc138
-rw-r--r--python/google/protobuf/pyext/python.proto8
-rw-r--r--python/google/protobuf/pyext/repeated_composite_container.cc191
-rw-r--r--python/google/protobuf/pyext/repeated_composite_container.h28
-rw-r--r--python/google/protobuf/pyext/repeated_scalar_container.cc191
-rw-r--r--python/google/protobuf/pyext/repeated_scalar_container.h19
-rw-r--r--python/google/protobuf/pyext/safe_numerics.h164
-rw-r--r--python/google/protobuf/pyext/scoped_pyobject_ptr.h59
-rw-r--r--python/google/protobuf/pyext/thread_unsafe_shared_ptr.h104
-rw-r--r--python/google/protobuf/python_protobuf.h (renamed from python/google/protobuf/pyext/python_protobuf.h)0
-rwxr-xr-xpython/google/protobuf/reflection.py19
-rw-r--r--python/google/protobuf/symbol_database.py104
-rwxr-xr-xpython/google/protobuf/text_format.py1305
-rw-r--r--python/google/protobuf/util/__init__.py0
73 files changed, 9101 insertions, 3253 deletions
diff --git a/python/google/protobuf/__init__.py b/python/google/protobuf/__init__.py
index 533821c1..d4360727 100755
--- a/python/google/protobuf/__init__.py
+++ b/python/google/protobuf/__init__.py
@@ -30,4 +30,10 @@
# Copyright 2007 Google Inc. All Rights Reserved.
-__version__ = '3.0.0b2'
+__version__ = '3.5.2'
+
+if __name__ != '__main__':
+ try:
+ __import__('pkg_resources').declare_namespace(__name__)
+ except ImportError:
+ __path__ = __import__('pkgutil').extend_path(__path__, __name__)
diff --git a/python/google/protobuf/compiler/__init__.py b/python/google/protobuf/compiler/__init__.py
new file mode 100644
index 00000000..e69de29b
--- /dev/null
+++ b/python/google/protobuf/compiler/__init__.py
diff --git a/python/google/protobuf/descriptor.py b/python/google/protobuf/descriptor.py
index 5f613c88..8a9ba3da 100755
--- a/python/google/protobuf/descriptor.py
+++ b/python/google/protobuf/descriptor.py
@@ -34,6 +34,7 @@ file, in types that make this information accessible in Python.
__author__ = 'robinson@google.com (Will Robinson)'
+import threading
import six
from google.protobuf.internal import api_implementation
@@ -41,8 +42,8 @@ from google.protobuf.internal import api_implementation
_USE_C_DESCRIPTORS = False
if api_implementation.Type() == 'cpp':
# Used by MakeDescriptor in cpp mode
+ import binascii
import os
- import uuid
from google.protobuf.pyext import _message
_USE_C_DESCRIPTORS = getattr(_message, '_USE_C_DESCRIPTORS', False)
@@ -72,6 +73,24 @@ else:
DescriptorMetaclass = type
+class _Lock(object):
+ """Wrapper class of threading.Lock(), which is allowed by 'with'."""
+
+ def __new__(cls):
+ self = object.__new__(cls)
+ self._lock = threading.Lock() # pylint: disable=protected-access
+ return self
+
+ def __enter__(self):
+ self._lock.acquire()
+
+ def __exit__(self, exc_type, exc_value, exc_tb):
+ self._lock.release()
+
+
+_lock = threading.Lock()
+
+
class DescriptorBase(six.with_metaclass(DescriptorMetaclass)):
"""Descriptors base class.
@@ -92,16 +111,17 @@ class DescriptorBase(six.with_metaclass(DescriptorMetaclass)):
# subclasses" of this descriptor class.
_C_DESCRIPTOR_CLASS = ()
- def __init__(self, options, options_class_name):
+ def __init__(self, options, serialized_options, options_class_name):
"""Initialize the descriptor given its options message and the name of the
class of the options message. The name of the class is required in case
the options message is None and has to be created.
"""
self._options = options
self._options_class_name = options_class_name
+ self._serialized_options = serialized_options
# Does this descriptor have non-default options?
- self.has_options = options is not None
+ self.has_options = (options is not None) or (serialized_options is not None)
def _SetOptions(self, options, options_class_name):
"""Sets the descriptor's options
@@ -123,14 +143,23 @@ class DescriptorBase(six.with_metaclass(DescriptorMetaclass)):
"""
if self._options:
return self._options
+
from google.protobuf import descriptor_pb2
try:
- options_class = getattr(descriptor_pb2, self._options_class_name)
+ options_class = getattr(descriptor_pb2,
+ self._options_class_name)
except AttributeError:
raise RuntimeError('Unknown options class name %s!' %
(self._options_class_name))
- self._options = options_class()
- return self._options
+
+ with _lock:
+ if self._serialized_options is None:
+ self._options = options_class()
+ else:
+ self._options = _ParseOptions(options_class(),
+ self._serialized_options)
+
+ return self._options
class _NestedDescriptorBase(DescriptorBase):
@@ -138,7 +167,7 @@ class _NestedDescriptorBase(DescriptorBase):
def __init__(self, options, options_class_name, name, full_name,
file, containing_type, serialized_start=None,
- serialized_end=None):
+ serialized_end=None, serialized_options=None):
"""Constructor.
Args:
@@ -157,9 +186,10 @@ class _NestedDescriptorBase(DescriptorBase):
file.serialized_pb that describes this descriptor.
serialized_end: The end index (exclusive) in block in the
file.serialized_pb that describes this descriptor.
+ serialized_options: Protocol message serilized options or None.
"""
super(_NestedDescriptorBase, self).__init__(
- options, options_class_name)
+ options, serialized_options, options_class_name)
self.name = name
# TODO(falk): Add function to calculate full_name instead of having it in
@@ -171,13 +201,6 @@ class _NestedDescriptorBase(DescriptorBase):
self._serialized_start = serialized_start
self._serialized_end = serialized_end
- def GetTopLevelContainingType(self):
- """Returns the root if this is a nested type, or itself if its the root."""
- desc = self
- while desc.containing_type is not None:
- desc = desc.containing_type
- return desc
-
def CopyToProto(self, proto):
"""Copies this to the matching proto in descriptor_pb2.
@@ -257,8 +280,9 @@ class Descriptor(_NestedDescriptorBase):
def __new__(cls, name, full_name, filename, containing_type, fields,
nested_types, enum_types, extensions, options=None,
+ serialized_options=None,
is_extendable=True, extension_ranges=None, oneofs=None,
- file=None, serialized_start=None, serialized_end=None,
+ file=None, serialized_start=None, serialized_end=None, # pylint: disable=redefined-builtin
syntax=None):
_message.Message._CheckCalledFromGeneratedFile()
return _message.default_pool.FindMessageTypeByName(full_name)
@@ -268,9 +292,10 @@ class Descriptor(_NestedDescriptorBase):
# name of the argument.
def __init__(self, name, full_name, filename, containing_type, fields,
nested_types, enum_types, extensions, options=None,
+ serialized_options=None,
is_extendable=True, extension_ranges=None, oneofs=None,
- file=None, serialized_start=None, serialized_end=None,
- syntax=None): # pylint:disable=redefined-builtin
+ file=None, serialized_start=None, serialized_end=None, # pylint: disable=redefined-builtin
+ syntax=None):
"""Arguments to __init__() are as described in the description
of Descriptor fields above.
@@ -280,7 +305,7 @@ class Descriptor(_NestedDescriptorBase):
super(Descriptor, self).__init__(
options, 'MessageOptions', name, full_name, file,
containing_type, serialized_start=serialized_start,
- serialized_end=serialized_end)
+ serialized_end=serialized_end, serialized_options=serialized_options)
# We have fields in addition to fields_by_name and fields_by_number,
# so that:
@@ -349,7 +374,7 @@ class Descriptor(_NestedDescriptorBase):
Args:
proto: An empty descriptor_pb2.DescriptorProto.
"""
- # This function is overriden to give a better doc comment.
+ # This function is overridden to give a better doc comment.
super(Descriptor, self).CopyToProto(proto)
@@ -413,6 +438,8 @@ class FieldDescriptor(DescriptorBase):
containing_oneof: (OneofDescriptor) If the field is a member of a oneof
union, contains its descriptor. Otherwise, None.
+
+ file: (FileDescriptor) Reference to file descriptor.
"""
# Must be consistent with C++ FieldDescriptor::Type enum in
@@ -497,7 +524,9 @@ class FieldDescriptor(DescriptorBase):
def __new__(cls, name, full_name, index, number, type, cpp_type, label,
default_value, message_type, enum_type, containing_type,
is_extension, extension_scope, options=None,
- has_default_value=True, containing_oneof=None):
+ serialized_options=None,
+ has_default_value=True, containing_oneof=None, json_name=None,
+ file=None): # pylint: disable=redefined-builtin
_message.Message._CheckCalledFromGeneratedFile()
if is_extension:
return _message.default_pool.FindExtensionByName(full_name)
@@ -507,7 +536,9 @@ class FieldDescriptor(DescriptorBase):
def __init__(self, name, full_name, index, number, type, cpp_type, label,
default_value, message_type, enum_type, containing_type,
is_extension, extension_scope, options=None,
- has_default_value=True, containing_oneof=None):
+ serialized_options=None,
+ has_default_value=True, containing_oneof=None, json_name=None,
+ file=None): # pylint: disable=redefined-builtin
"""The arguments are as described in the description of FieldDescriptor
attributes above.
@@ -515,10 +546,16 @@ class FieldDescriptor(DescriptorBase):
(to deal with circular references between message types, for example).
Likewise for extension_scope.
"""
- super(FieldDescriptor, self).__init__(options, 'FieldOptions')
+ super(FieldDescriptor, self).__init__(
+ options, serialized_options, 'FieldOptions')
self.name = name
self.full_name = full_name
+ self.file = file
self._camelcase_name = None
+ if json_name is None:
+ self.json_name = _ToJsonName(name)
+ else:
+ self.json_name = json_name
self.index = index
self.number = number
self.type = type
@@ -596,13 +633,15 @@ class EnumDescriptor(_NestedDescriptorBase):
_C_DESCRIPTOR_CLASS = _message.EnumDescriptor
def __new__(cls, name, full_name, filename, values,
- containing_type=None, options=None, file=None,
+ containing_type=None, options=None,
+ serialized_options=None, file=None, # pylint: disable=redefined-builtin
serialized_start=None, serialized_end=None):
_message.Message._CheckCalledFromGeneratedFile()
return _message.default_pool.FindEnumTypeByName(full_name)
def __init__(self, name, full_name, filename, values,
- containing_type=None, options=None, file=None,
+ containing_type=None, options=None,
+ serialized_options=None, file=None, # pylint: disable=redefined-builtin
serialized_start=None, serialized_end=None):
"""Arguments are as described in the attribute description above.
@@ -612,7 +651,7 @@ class EnumDescriptor(_NestedDescriptorBase):
super(EnumDescriptor, self).__init__(
options, 'EnumOptions', name, full_name, file,
containing_type, serialized_start=serialized_start,
- serialized_end=serialized_end)
+ serialized_end=serialized_end, serialized_options=serialized_options)
self.values = values
for value in self.values:
@@ -626,7 +665,7 @@ class EnumDescriptor(_NestedDescriptorBase):
Args:
proto: An empty descriptor_pb2.EnumDescriptorProto.
"""
- # This function is overriden to give a better doc comment.
+ # This function is overridden to give a better doc comment.
super(EnumDescriptor, self).CopyToProto(proto)
@@ -648,7 +687,9 @@ class EnumValueDescriptor(DescriptorBase):
if _USE_C_DESCRIPTORS:
_C_DESCRIPTOR_CLASS = _message.EnumValueDescriptor
- def __new__(cls, name, index, number, type=None, options=None):
+ def __new__(cls, name, index, number,
+ type=None, # pylint: disable=redefined-builtin
+ options=None, serialized_options=None):
_message.Message._CheckCalledFromGeneratedFile()
# There is no way we can build a complete EnumValueDescriptor with the
# given parameters (the name of the Enum is not known, for example).
@@ -656,16 +697,19 @@ class EnumValueDescriptor(DescriptorBase):
# constructor, which will ignore it, so returning None is good enough.
return None
- def __init__(self, name, index, number, type=None, options=None):
+ def __init__(self, name, index, number,
+ type=None, # pylint: disable=redefined-builtin
+ options=None, serialized_options=None):
"""Arguments are as described in the attribute description above."""
- super(EnumValueDescriptor, self).__init__(options, 'EnumValueOptions')
+ super(EnumValueDescriptor, self).__init__(
+ options, serialized_options, 'EnumValueOptions')
self.name = name
self.index = index
self.number = number
self.type = type
-class OneofDescriptor(object):
+class OneofDescriptor(DescriptorBase):
"""Descriptor for a oneof field.
name: (str) Name of the oneof field.
@@ -682,12 +726,18 @@ class OneofDescriptor(object):
if _USE_C_DESCRIPTORS:
_C_DESCRIPTOR_CLASS = _message.OneofDescriptor
- def __new__(cls, name, full_name, index, containing_type, fields):
+ def __new__(
+ cls, name, full_name, index, containing_type, fields, options=None,
+ serialized_options=None):
_message.Message._CheckCalledFromGeneratedFile()
return _message.default_pool.FindOneofByName(full_name)
- def __init__(self, name, full_name, index, containing_type, fields):
+ def __init__(
+ self, name, full_name, index, containing_type, fields, options=None,
+ serialized_options=None):
"""Arguments are as described in the attribute description above."""
+ super(OneofDescriptor, self).__init__(
+ options, serialized_options, 'OneofOptions')
self.name = name
self.full_name = full_name
self.index = index
@@ -705,29 +755,40 @@ class ServiceDescriptor(_NestedDescriptorBase):
definition appears withing the .proto file.
methods: (list of MethodDescriptor) List of methods provided by this
service.
+ methods_by_name: (dict str -> MethodDescriptor) Same MethodDescriptor
+ objects as in |methods_by_name|, but indexed by "name" attribute in each
+ MethodDescriptor.
options: (descriptor_pb2.ServiceOptions) Service options message or
None to use default service options.
file: (FileDescriptor) Reference to file info.
"""
- def __init__(self, name, full_name, index, methods, options=None, file=None,
+ if _USE_C_DESCRIPTORS:
+ _C_DESCRIPTOR_CLASS = _message.ServiceDescriptor
+
+ def __new__(cls, name, full_name, index, methods, options=None,
+ serialized_options=None, file=None, # pylint: disable=redefined-builtin
+ serialized_start=None, serialized_end=None):
+ _message.Message._CheckCalledFromGeneratedFile() # pylint: disable=protected-access
+ return _message.default_pool.FindServiceByName(full_name)
+
+ def __init__(self, name, full_name, index, methods, options=None,
+ serialized_options=None, file=None, # pylint: disable=redefined-builtin
serialized_start=None, serialized_end=None):
super(ServiceDescriptor, self).__init__(
options, 'ServiceOptions', name, full_name, file,
None, serialized_start=serialized_start,
- serialized_end=serialized_end)
+ serialized_end=serialized_end, serialized_options=serialized_options)
self.index = index
self.methods = methods
+ self.methods_by_name = dict((m.name, m) for m in methods)
# Set the containing service for each method in this service.
for method in self.methods:
method.containing_service = self
def FindMethodByName(self, name):
"""Searches for the specified method, and returns its descriptor."""
- for method in self.methods:
- if name == method.name:
- return method
- return None
+ return self.methods_by_name.get(name, None)
def CopyToProto(self, proto):
"""Copies this to a descriptor_pb2.ServiceDescriptorProto.
@@ -735,7 +796,7 @@ class ServiceDescriptor(_NestedDescriptorBase):
Args:
proto: An empty descriptor_pb2.ServiceDescriptorProto.
"""
- # This function is overriden to give a better doc comment.
+ # This function is overridden to give a better doc comment.
super(ServiceDescriptor, self).CopyToProto(proto)
@@ -754,14 +815,23 @@ class MethodDescriptor(DescriptorBase):
None to use default method options.
"""
+ if _USE_C_DESCRIPTORS:
+ _C_DESCRIPTOR_CLASS = _message.MethodDescriptor
+
+ def __new__(cls, name, full_name, index, containing_service,
+ input_type, output_type, options=None, serialized_options=None):
+ _message.Message._CheckCalledFromGeneratedFile() # pylint: disable=protected-access
+ return _message.default_pool.FindMethodByName(full_name)
+
def __init__(self, name, full_name, index, containing_service,
- input_type, output_type, options=None):
+ input_type, output_type, options=None, serialized_options=None):
"""The arguments are as described in the description of MethodDescriptor
attributes above.
Note that containing_service may be None, and may be set later if necessary.
"""
- super(MethodDescriptor, self).__init__(options, 'MethodOptions')
+ super(MethodDescriptor, self).__init__(
+ options, serialized_options, 'MethodOptions')
self.name = name
self.full_name = full_name
self.index = index
@@ -783,9 +853,12 @@ class FileDescriptor(DescriptorBase):
serialized_pb: (str) Byte string of serialized
descriptor_pb2.FileDescriptorProto.
dependencies: List of other FileDescriptors this FileDescriptor depends on.
+ public_dependencies: A list of FileDescriptors, subset of the dependencies
+ above, which were declared as "public".
message_types_by_name: Dict of message names of their descriptors.
enum_types_by_name: Dict of enum names and their descriptors.
extensions_by_name: Dict of extension names and their descriptors.
+ services_by_name: Dict of services names and their descriptors.
pool: the DescriptorPool this descriptor belongs to. When not passed to the
constructor, the global default pool is used.
"""
@@ -793,8 +866,10 @@ class FileDescriptor(DescriptorBase):
if _USE_C_DESCRIPTORS:
_C_DESCRIPTOR_CLASS = _message.FileDescriptor
- def __new__(cls, name, package, options=None, serialized_pb=None,
- dependencies=None, syntax=None, pool=None):
+ def __new__(cls, name, package, options=None,
+ serialized_options=None, serialized_pb=None,
+ dependencies=None, public_dependencies=None,
+ syntax=None, pool=None):
# FileDescriptor() is called from various places, not only from generated
# files, to register dynamic proto files and messages.
if serialized_pb:
@@ -804,10 +879,13 @@ class FileDescriptor(DescriptorBase):
else:
return super(FileDescriptor, cls).__new__(cls)
- def __init__(self, name, package, options=None, serialized_pb=None,
- dependencies=None, syntax=None, pool=None):
+ def __init__(self, name, package, options=None,
+ serialized_options=None, serialized_pb=None,
+ dependencies=None, public_dependencies=None,
+ syntax=None, pool=None):
"""Constructor."""
- super(FileDescriptor, self).__init__(options, 'FileOptions')
+ super(FileDescriptor, self).__init__(
+ options, serialized_options, 'FileOptions')
if pool is None:
from google.protobuf import descriptor_pool
@@ -821,7 +899,9 @@ class FileDescriptor(DescriptorBase):
self.enum_types_by_name = {}
self.extensions_by_name = {}
+ self.services_by_name = {}
self.dependencies = (dependencies or [])
+ self.public_dependencies = (public_dependencies or [])
if (api_implementation.Type() == 'cpp' and
self.serialized_pb is not None):
@@ -867,6 +947,31 @@ def _ToCamelCase(name):
return ''.join(result)
+def _OptionsOrNone(descriptor_proto):
+ """Returns the value of the field `options`, or None if it is not set."""
+ if descriptor_proto.HasField('options'):
+ return descriptor_proto.options
+ else:
+ return None
+
+
+def _ToJsonName(name):
+ """Converts name to Json name and returns it."""
+ capitalize_next = False
+ result = []
+
+ for c in name:
+ if c == '_':
+ capitalize_next = True
+ elif capitalize_next:
+ result.append(c.upper())
+ capitalize_next = False
+ else:
+ result += c
+
+ return ''.join(result)
+
+
def MakeDescriptor(desc_proto, package='', build_file_if_cpp=True,
syntax=None):
"""Make a protobuf Descriptor given a DescriptorProto protobuf.
@@ -898,7 +1003,7 @@ def MakeDescriptor(desc_proto, package='', build_file_if_cpp=True,
# imported ones. We need to specify a file name so the descriptor pool
# accepts our FileDescriptorProto, but it is not important what that file
# name is actually set to.
- proto_name = str(uuid.uuid4())
+ proto_name = binascii.hexlify(os.urandom(16)).decode('ascii')
if package:
file_descriptor_proto.name = os.path.join(package.replace('.', '/'),
@@ -943,6 +1048,10 @@ def MakeDescriptor(desc_proto, package='', build_file_if_cpp=True,
full_name = '.'.join(full_message_name + [field_proto.name])
enum_desc = None
nested_desc = None
+ if field_proto.json_name:
+ json_name = field_proto.json_name
+ else:
+ json_name = None
if field_proto.HasField('type_name'):
type_name = field_proto.type_name
full_type_name = '.'.join(full_message_name +
@@ -957,10 +1066,11 @@ def MakeDescriptor(desc_proto, package='', build_file_if_cpp=True,
field_proto.number, field_proto.type,
FieldDescriptor.ProtoTypeToCppProtoType(field_proto.type),
field_proto.label, None, nested_desc, enum_desc, None, False, None,
- options=field_proto.options, has_default_value=False)
+ options=_OptionsOrNone(field_proto), has_default_value=False,
+ json_name=json_name)
fields.append(field)
desc_name = '.'.join(full_message_name)
return Descriptor(desc_proto.name, desc_name, None, None, fields,
list(nested_types.values()), list(enum_types.values()), [],
- options=desc_proto.options)
+ options=_OptionsOrNone(desc_proto))
diff --git a/python/google/protobuf/descriptor_database.py b/python/google/protobuf/descriptor_database.py
index 1333f996..8b7715cd 100644
--- a/python/google/protobuf/descriptor_database.py
+++ b/python/google/protobuf/descriptor_database.py
@@ -32,6 +32,8 @@
__author__ = 'matthewtoia@google.com (Matt Toia)'
+import warnings
+
class Error(Exception):
pass
@@ -54,9 +56,9 @@ class DescriptorDatabase(object):
Args:
file_desc_proto: The FileDescriptorProto to add.
Raises:
- DescriptorDatabaseException: if an attempt is made to add a proto
- with the same name but different definition than an exisiting
- proto in the database.
+ DescriptorDatabaseConflictingDefinitionError: if an attempt is made to
+ add a proto with the same name but different definition than an
+ exisiting proto in the database.
"""
proto_name = file_desc_proto.name
if proto_name not in self._file_desc_protos_by_file:
@@ -64,18 +66,20 @@ class DescriptorDatabase(object):
elif self._file_desc_protos_by_file[proto_name] != file_desc_proto:
raise DescriptorDatabaseConflictingDefinitionError(
'%s already added, but with different descriptor.' % proto_name)
+ else:
+ return
- # Add the top-level Message, Enum and Extension descriptors to the index.
+ # Add all the top-level descriptors to the index.
package = file_desc_proto.package
for message in file_desc_proto.message_type:
- self._file_desc_protos_by_symbol.update(
- (name, file_desc_proto) for name in _ExtractSymbols(message, package))
+ for name in _ExtractSymbols(message, package):
+ self._AddSymbol(name, file_desc_proto)
for enum in file_desc_proto.enum_type:
- self._file_desc_protos_by_symbol[
- '.'.join((package, enum.name))] = file_desc_proto
+ self._AddSymbol(('.'.join((package, enum.name))), file_desc_proto)
for extension in file_desc_proto.extension:
- self._file_desc_protos_by_symbol[
- '.'.join((package, extension.name))] = file_desc_proto
+ self._AddSymbol(('.'.join((package, extension.name))), file_desc_proto)
+ for service in file_desc_proto.service:
+ self._AddSymbol(('.'.join((package, service.name))), file_desc_proto)
def FindFileByName(self, name):
"""Finds the file descriptor proto by file name.
@@ -104,6 +108,7 @@ class DescriptorDatabase(object):
'some.package.name.Message'
'some.package.name.Message.NestedEnum'
+ 'some.package.name.Message.some_field'
The file descriptor proto containing the specified symbol must be added to
this database using the Add method or else an error will be raised.
@@ -117,8 +122,25 @@ class DescriptorDatabase(object):
Raises:
KeyError if no file contains the specified symbol.
"""
-
- return self._file_desc_protos_by_symbol[symbol]
+ try:
+ return self._file_desc_protos_by_symbol[symbol]
+ except KeyError:
+ # Fields, enum values, and nested extensions are not in
+ # _file_desc_protos_by_symbol. Try to find the top level
+ # descriptor. Non-existent nested symbol under a valid top level
+ # descriptor can also be found. The behavior is the same with
+ # protobuf C++.
+ top_level, _, _ = symbol.rpartition('.')
+ return self._file_desc_protos_by_symbol[top_level]
+
+ def _AddSymbol(self, name, file_desc_proto):
+ if name in self._file_desc_protos_by_symbol:
+ warn_msg = ('Conflict register for file "' + file_desc_proto.name +
+ '": ' + name +
+ ' is already defined in file "' +
+ self._file_desc_protos_by_symbol[name].name + '"')
+ warnings.warn(warn_msg, RuntimeWarning)
+ self._file_desc_protos_by_symbol[name] = file_desc_proto
def _ExtractSymbols(desc_proto, package):
@@ -131,8 +153,7 @@ def _ExtractSymbols(desc_proto, package):
Yields:
The fully qualified name found in the descriptor.
"""
-
- message_name = '.'.join((package, desc_proto.name))
+ message_name = package + '.' + desc_proto.name if package else desc_proto.name
yield message_name
for nested_type in desc_proto.nested_type:
for symbol in _ExtractSymbols(nested_type, message_name):
diff --git a/python/google/protobuf/descriptor_pool.py b/python/google/protobuf/descriptor_pool.py
index 3e80795c..8983f76f 100644
--- a/python/google/protobuf/descriptor_pool.py
+++ b/python/google/protobuf/descriptor_pool.py
@@ -57,12 +57,15 @@ directly instead of this class.
__author__ = 'matthewtoia@google.com (Matt Toia)'
+import collections
+import warnings
+
from google.protobuf import descriptor
from google.protobuf import descriptor_database
from google.protobuf import text_encoding
-_USE_C_DESCRIPTORS = descriptor._USE_C_DESCRIPTORS
+_USE_C_DESCRIPTORS = descriptor._USE_C_DESCRIPTORS # pylint: disable=protected-access
def _NormalizeFullyQualifiedName(name):
@@ -80,6 +83,22 @@ def _NormalizeFullyQualifiedName(name):
return name.lstrip('.')
+def _OptionsOrNone(descriptor_proto):
+ """Returns the value of the field `options`, or None if it is not set."""
+ if descriptor_proto.HasField('options'):
+ return descriptor_proto.options
+ else:
+ return None
+
+
+def _IsMessageSetExtension(field):
+ return (field.is_extension and
+ field.containing_type.has_options and
+ field.containing_type.GetOptions().message_set_wire_format and
+ field.type == descriptor.FieldDescriptor.TYPE_MESSAGE and
+ field.label == descriptor.FieldDescriptor.LABEL_OPTIONAL)
+
+
class DescriptorPool(object):
"""A collection of protobufs dynamically constructed by descriptor protos."""
@@ -106,7 +125,40 @@ class DescriptorPool(object):
self._descriptor_db = descriptor_db
self._descriptors = {}
self._enum_descriptors = {}
+ self._service_descriptors = {}
self._file_descriptors = {}
+ self._toplevel_extensions = {}
+ # TODO(jieluo): Remove _file_desc_by_toplevel_extension after
+ # maybe year 2020 for compatibility issue (with 3.4.1 only).
+ self._file_desc_by_toplevel_extension = {}
+ # We store extensions in two two-level mappings: The first key is the
+ # descriptor of the message being extended, the second key is the extension
+ # full name or its tag number.
+ self._extensions_by_name = collections.defaultdict(dict)
+ self._extensions_by_number = collections.defaultdict(dict)
+
+ def _CheckConflictRegister(self, desc):
+ """Check if the descriptor name conflicts with another of the same name.
+
+ Args:
+ desc: Descriptor of a message, enum, service or extension.
+ """
+ desc_name = desc.full_name
+ for register, descriptor_type in [
+ (self._descriptors, descriptor.Descriptor),
+ (self._enum_descriptors, descriptor.EnumDescriptor),
+ (self._service_descriptors, descriptor.ServiceDescriptor),
+ (self._toplevel_extensions, descriptor.FieldDescriptor)]:
+ if desc_name in register:
+ file_name = register[desc_name].file.name
+ if not isinstance(desc, descriptor_type) or (
+ file_name != desc.file.name):
+ warn_msg = ('Conflict register for file "' + desc.file.name +
+ '": ' + desc_name +
+ ' is already defined in file "' +
+ file_name + '"')
+ warnings.warn(warn_msg, RuntimeWarning)
+ return
def Add(self, file_desc_proto):
"""Adds the FileDescriptorProto and its types to this pool.
@@ -144,13 +196,15 @@ class DescriptorPool(object):
if not isinstance(desc, descriptor.Descriptor):
raise TypeError('Expected instance of descriptor.Descriptor.')
+ self._CheckConflictRegister(desc)
+
self._descriptors[desc.full_name] = desc
- self.AddFileDescriptor(desc.file)
+ self._AddFileDescriptor(desc.file)
def AddEnumDescriptor(self, enum_desc):
"""Adds an EnumDescriptor to the pool.
- This method also registers the FileDescriptor associated with the message.
+ This method also registers the FileDescriptor associated with the enum.
Args:
enum_desc: An EnumDescriptor.
@@ -159,8 +213,65 @@ class DescriptorPool(object):
if not isinstance(enum_desc, descriptor.EnumDescriptor):
raise TypeError('Expected instance of descriptor.EnumDescriptor.')
+ self._CheckConflictRegister(enum_desc)
self._enum_descriptors[enum_desc.full_name] = enum_desc
- self.AddFileDescriptor(enum_desc.file)
+ self._AddFileDescriptor(enum_desc.file)
+
+ def AddServiceDescriptor(self, service_desc):
+ """Adds a ServiceDescriptor to the pool.
+
+ Args:
+ service_desc: A ServiceDescriptor.
+ """
+
+ if not isinstance(service_desc, descriptor.ServiceDescriptor):
+ raise TypeError('Expected instance of descriptor.ServiceDescriptor.')
+
+ self._CheckConflictRegister(service_desc)
+ self._service_descriptors[service_desc.full_name] = service_desc
+
+ def AddExtensionDescriptor(self, extension):
+ """Adds a FieldDescriptor describing an extension to the pool.
+
+ Args:
+ extension: A FieldDescriptor.
+
+ Raises:
+ AssertionError: when another extension with the same number extends the
+ same message.
+ TypeError: when the specified extension is not a
+ descriptor.FieldDescriptor.
+ """
+ if not (isinstance(extension, descriptor.FieldDescriptor) and
+ extension.is_extension):
+ raise TypeError('Expected an extension descriptor.')
+
+ if extension.extension_scope is None:
+ self._CheckConflictRegister(extension)
+ self._toplevel_extensions[extension.full_name] = extension
+
+ try:
+ existing_desc = self._extensions_by_number[
+ extension.containing_type][extension.number]
+ except KeyError:
+ pass
+ else:
+ if extension is not existing_desc:
+ raise AssertionError(
+ 'Extensions "%s" and "%s" both try to extend message type "%s" '
+ 'with field number %d.' %
+ (extension.full_name, existing_desc.full_name,
+ extension.containing_type.full_name, extension.number))
+
+ self._extensions_by_number[extension.containing_type][
+ extension.number] = extension
+ self._extensions_by_name[extension.containing_type][
+ extension.full_name] = extension
+
+ # Also register MessageSet extensions with the type name.
+ if _IsMessageSetExtension(extension):
+ self._extensions_by_name[extension.containing_type][
+ extension.message_type.full_name] = extension
def AddFileDescriptor(self, file_desc):
"""Adds a FileDescriptor to the pool, non-recursively.
@@ -172,6 +283,24 @@ class DescriptorPool(object):
file_desc: A FileDescriptor.
"""
+ self._AddFileDescriptor(file_desc)
+ # TODO(jieluo): This is a temporary solution for FieldDescriptor.file.
+ # FieldDescriptor.file is added in code gen. Remove this solution after
+ # maybe 2020 for compatibility reason (with 3.4.1 only).
+ for extension in file_desc.extensions_by_name.values():
+ self._file_desc_by_toplevel_extension[
+ extension.full_name] = file_desc
+
+ def _AddFileDescriptor(self, file_desc):
+ """Adds a FileDescriptor to the pool, non-recursively.
+
+ If the FileDescriptor contains messages or enums, the caller must explicitly
+ register them.
+
+ Args:
+ file_desc: A FileDescriptor.
+ """
+
if not isinstance(file_desc, descriptor.FileDescriptor):
raise TypeError('Expected instance of descriptor.FileDescriptor.')
self._file_descriptors[file_desc.name] = file_desc
@@ -186,7 +315,7 @@ class DescriptorPool(object):
A FileDescriptor for the named file.
Raises:
- KeyError: if the file can not be found in the pool.
+ KeyError: if the file cannot be found in the pool.
"""
try:
@@ -215,7 +344,7 @@ class DescriptorPool(object):
A FileDescriptor that contains the specified symbol.
Raises:
- KeyError: if the file can not be found in the pool.
+ KeyError: if the file cannot be found in the pool.
"""
symbol = _NormalizeFullyQualifiedName(symbol)
@@ -230,15 +359,28 @@ class DescriptorPool(object):
pass
try:
- file_proto = self._internal_db.FindFileContainingSymbol(symbol)
- except KeyError as error:
- if self._descriptor_db:
- file_proto = self._descriptor_db.FindFileContainingSymbol(symbol)
- else:
- raise error
- if not file_proto:
+ return self._service_descriptors[symbol].file
+ except KeyError:
+ pass
+
+ try:
+ return self._FindFileContainingSymbolInDb(symbol)
+ except KeyError:
+ pass
+
+ try:
+ return self._file_desc_by_toplevel_extension[symbol]
+ except KeyError:
+ pass
+
+ # Try nested extensions inside a message.
+ message_name, _, extension_name = symbol.rpartition('.')
+ try:
+ message = self.FindMessageTypeByName(message_name)
+ assert message.extensions_by_name[extension_name]
+ return message.file
+ except KeyError:
raise KeyError('Cannot find a file containing %s' % symbol)
- return self._ConvertFileProtoToFileDescriptor(file_proto)
def FindMessageTypeByName(self, full_name):
"""Loads the named descriptor from the pool.
@@ -248,11 +390,14 @@ class DescriptorPool(object):
Returns:
The descriptor for the named type.
+
+ Raises:
+ KeyError: if the message cannot be found in the pool.
"""
full_name = _NormalizeFullyQualifiedName(full_name)
if full_name not in self._descriptors:
- self.FindFileContainingSymbol(full_name)
+ self._FindFileContainingSymbolInDb(full_name)
return self._descriptors[full_name]
def FindEnumTypeByName(self, full_name):
@@ -263,11 +408,14 @@ class DescriptorPool(object):
Returns:
The enum descriptor for the named type.
+
+ Raises:
+ KeyError: if the enum cannot be found in the pool.
"""
full_name = _NormalizeFullyQualifiedName(full_name)
if full_name not in self._enum_descriptors:
- self.FindFileContainingSymbol(full_name)
+ self._FindFileContainingSymbolInDb(full_name)
return self._enum_descriptors[full_name]
def FindFieldByName(self, full_name):
@@ -278,12 +426,32 @@ class DescriptorPool(object):
Returns:
The field descriptor for the named field.
+
+ Raises:
+ KeyError: if the field cannot be found in the pool.
"""
full_name = _NormalizeFullyQualifiedName(full_name)
message_name, _, field_name = full_name.rpartition('.')
message_descriptor = self.FindMessageTypeByName(message_name)
return message_descriptor.fields_by_name[field_name]
+ def FindOneofByName(self, full_name):
+ """Loads the named oneof descriptor from the pool.
+
+ Args:
+ full_name: The full name of the oneof descriptor to load.
+
+ Returns:
+ The oneof descriptor for the named oneof.
+
+ Raises:
+ KeyError: if the oneof cannot be found in the pool.
+ """
+ full_name = _NormalizeFullyQualifiedName(full_name)
+ message_name, _, oneof_name = full_name.rpartition('.')
+ message_descriptor = self.FindMessageTypeByName(message_name)
+ return message_descriptor.oneofs_by_name[oneof_name]
+
def FindExtensionByName(self, full_name):
"""Loads the named extension descriptor from the pool.
@@ -292,17 +460,101 @@ class DescriptorPool(object):
Returns:
A FieldDescriptor, describing the named extension.
+
+ Raises:
+ KeyError: if the extension cannot be found in the pool.
"""
full_name = _NormalizeFullyQualifiedName(full_name)
+ try:
+ # The proto compiler does not give any link between the FileDescriptor
+ # and top-level extensions unless the FileDescriptorProto is added to
+ # the DescriptorDatabase, but this can impact memory usage.
+ # So we registered these extensions by name explicitly.
+ return self._toplevel_extensions[full_name]
+ except KeyError:
+ pass
message_name, _, extension_name = full_name.rpartition('.')
try:
# Most extensions are nested inside a message.
scope = self.FindMessageTypeByName(message_name)
except KeyError:
# Some extensions are defined at file scope.
- scope = self.FindFileContainingSymbol(full_name)
+ scope = self._FindFileContainingSymbolInDb(full_name)
return scope.extensions_by_name[extension_name]
+ def FindExtensionByNumber(self, message_descriptor, number):
+ """Gets the extension of the specified message with the specified number.
+
+ Extensions have to be registered to this pool by calling
+ AddExtensionDescriptor.
+
+ Args:
+ message_descriptor: descriptor of the extended message.
+ number: integer, number of the extension field.
+
+ Returns:
+ A FieldDescriptor describing the extension.
+
+ Raises:
+ KeyError: when no extension with the given number is known for the
+ specified message.
+ """
+ return self._extensions_by_number[message_descriptor][number]
+
+ def FindAllExtensions(self, message_descriptor):
+ """Gets all the known extension of a given message.
+
+ Extensions have to be registered to this pool by calling
+ AddExtensionDescriptor.
+
+ Args:
+ message_descriptor: descriptor of the extended message.
+
+ Returns:
+ A list of FieldDescriptor describing the extensions.
+ """
+ return list(self._extensions_by_number[message_descriptor].values())
+
+ def FindServiceByName(self, full_name):
+ """Loads the named service descriptor from the pool.
+
+ Args:
+ full_name: The full name of the service descriptor to load.
+
+ Returns:
+ The service descriptor for the named service.
+
+ Raises:
+ KeyError: if the service cannot be found in the pool.
+ """
+ full_name = _NormalizeFullyQualifiedName(full_name)
+ if full_name not in self._service_descriptors:
+ self._FindFileContainingSymbolInDb(full_name)
+ return self._service_descriptors[full_name]
+
+ def _FindFileContainingSymbolInDb(self, symbol):
+ """Finds the file in descriptor DB containing the specified symbol.
+
+ Args:
+ symbol: The name of the symbol to search for.
+
+ Returns:
+ A FileDescriptor that contains the specified symbol.
+
+ Raises:
+ KeyError: if the file cannot be found in the descriptor database.
+ """
+ try:
+ file_proto = self._internal_db.FindFileContainingSymbol(symbol)
+ except KeyError as error:
+ if self._descriptor_db:
+ file_proto = self._descriptor_db.FindFileContainingSymbol(symbol)
+ else:
+ raise error
+ if not file_proto:
+ raise KeyError('Cannot find a file containing %s' % symbol)
+ return self._ConvertFileProtoToFileDescriptor(file_proto)
+
def _ConvertFileProtoToFileDescriptor(self, file_proto):
"""Creates a FileDescriptor from a proto or returns a cached copy.
@@ -319,78 +571,69 @@ class DescriptorPool(object):
if file_proto.name not in self._file_descriptors:
built_deps = list(self._GetDeps(file_proto.dependency))
direct_deps = [self.FindFileByName(n) for n in file_proto.dependency]
+ public_deps = [direct_deps[i] for i in file_proto.public_dependency]
file_descriptor = descriptor.FileDescriptor(
pool=self,
name=file_proto.name,
package=file_proto.package,
syntax=file_proto.syntax,
- options=file_proto.options,
+ options=_OptionsOrNone(file_proto),
serialized_pb=file_proto.SerializeToString(),
- dependencies=direct_deps)
- if _USE_C_DESCRIPTORS:
- # When using C++ descriptors, all objects defined in the file were added
- # to the C++ database when the FileDescriptor was built above.
- # Just add them to this descriptor pool.
- def _AddMessageDescriptor(message_desc):
- self._descriptors[message_desc.full_name] = message_desc
- for nested in message_desc.nested_types:
- _AddMessageDescriptor(nested)
- for enum_type in message_desc.enum_types:
- _AddEnumDescriptor(enum_type)
- def _AddEnumDescriptor(enum_desc):
- self._enum_descriptors[enum_desc.full_name] = enum_desc
- for message_type in file_descriptor.message_types_by_name.values():
- _AddMessageDescriptor(message_type)
- for enum_type in file_descriptor.enum_types_by_name.values():
- _AddEnumDescriptor(enum_type)
+ dependencies=direct_deps,
+ public_dependencies=public_deps)
+ scope = {}
+
+ # This loop extracts all the message and enum types from all the
+ # dependencies of the file_proto. This is necessary to create the
+ # scope of available message types when defining the passed in
+ # file proto.
+ for dependency in built_deps:
+ scope.update(self._ExtractSymbols(
+ dependency.message_types_by_name.values()))
+ scope.update((_PrefixWithDot(enum.full_name), enum)
+ for enum in dependency.enum_types_by_name.values())
+
+ for message_type in file_proto.message_type:
+ message_desc = self._ConvertMessageDescriptor(
+ message_type, file_proto.package, file_descriptor, scope,
+ file_proto.syntax)
+ file_descriptor.message_types_by_name[message_desc.name] = (
+ message_desc)
+
+ for enum_type in file_proto.enum_type:
+ file_descriptor.enum_types_by_name[enum_type.name] = (
+ self._ConvertEnumDescriptor(enum_type, file_proto.package,
+ file_descriptor, None, scope))
+
+ for index, extension_proto in enumerate(file_proto.extension):
+ extension_desc = self._MakeFieldDescriptor(
+ extension_proto, file_proto.package, index, file_descriptor,
+ is_extension=True)
+ extension_desc.containing_type = self._GetTypeFromScope(
+ file_descriptor.package, extension_proto.extendee, scope)
+ self._SetFieldType(extension_proto, extension_desc,
+ file_descriptor.package, scope)
+ file_descriptor.extensions_by_name[extension_desc.name] = (
+ extension_desc)
+
+ for desc_proto in file_proto.message_type:
+ self._SetAllFieldTypes(file_proto.package, desc_proto, scope)
+
+ if file_proto.package:
+ desc_proto_prefix = _PrefixWithDot(file_proto.package)
else:
- scope = {}
-
- # This loop extracts all the message and enum types from all the
- # dependencies of the file_proto. This is necessary to create the
- # scope of available message types when defining the passed in
- # file proto.
- for dependency in built_deps:
- scope.update(self._ExtractSymbols(
- dependency.message_types_by_name.values()))
- scope.update((_PrefixWithDot(enum.full_name), enum)
- for enum in dependency.enum_types_by_name.values())
-
- for message_type in file_proto.message_type:
- message_desc = self._ConvertMessageDescriptor(
- message_type, file_proto.package, file_descriptor, scope,
- file_proto.syntax)
- file_descriptor.message_types_by_name[message_desc.name] = (
- message_desc)
-
- for enum_type in file_proto.enum_type:
- file_descriptor.enum_types_by_name[enum_type.name] = (
- self._ConvertEnumDescriptor(enum_type, file_proto.package,
- file_descriptor, None, scope))
-
- for index, extension_proto in enumerate(file_proto.extension):
- extension_desc = self._MakeFieldDescriptor(
- extension_proto, file_proto.package, index, is_extension=True)
- extension_desc.containing_type = self._GetTypeFromScope(
- file_descriptor.package, extension_proto.extendee, scope)
- self._SetFieldType(extension_proto, extension_desc,
- file_descriptor.package, scope)
- file_descriptor.extensions_by_name[extension_desc.name] = (
- extension_desc)
-
- for desc_proto in file_proto.message_type:
- self._SetAllFieldTypes(file_proto.package, desc_proto, scope)
-
- if file_proto.package:
- desc_proto_prefix = _PrefixWithDot(file_proto.package)
- else:
- desc_proto_prefix = ''
+ desc_proto_prefix = ''
+
+ for desc_proto in file_proto.message_type:
+ desc = self._GetTypeFromScope(
+ desc_proto_prefix, desc_proto.name, scope)
+ file_descriptor.message_types_by_name[desc_proto.name] = desc
- for desc_proto in file_proto.message_type:
- desc = self._GetTypeFromScope(
- desc_proto_prefix, desc_proto.name, scope)
- file_descriptor.message_types_by_name[desc_proto.name] = desc
+ for index, service_proto in enumerate(file_proto.service):
+ file_descriptor.services_by_name[service_proto.name] = (
+ self._MakeServiceDescriptor(service_proto, index, scope,
+ file_proto.package, file_descriptor))
self.Add(file_proto)
self._file_descriptors[file_proto.name] = file_descriptor
@@ -406,6 +649,7 @@ class DescriptorPool(object):
package: The package the proto should be located in.
file_desc: The file containing this message.
scope: Dict mapping short and full symbols to message and enum types.
+ syntax: string indicating syntax of the file ("proto2" or "proto3")
Returns:
The added descriptor.
@@ -431,15 +675,15 @@ class DescriptorPool(object):
enums = [
self._ConvertEnumDescriptor(enum, desc_name, file_desc, None, scope)
for enum in desc_proto.enum_type]
- fields = [self._MakeFieldDescriptor(field, desc_name, index)
+ fields = [self._MakeFieldDescriptor(field, desc_name, index, file_desc)
for index, field in enumerate(desc_proto.field)]
extensions = [
- self._MakeFieldDescriptor(extension, desc_name, index,
+ self._MakeFieldDescriptor(extension, desc_name, index, file_desc,
is_extension=True)
for index, extension in enumerate(desc_proto.extension)]
oneofs = [
descriptor.OneofDescriptor(desc.name, '.'.join((desc_name, desc.name)),
- index, None, [])
+ index, None, [], desc.options)
for index, desc in enumerate(desc_proto.oneof_decl)]
extension_ranges = [(r.start, r.end) for r in desc_proto.extension_range]
if extension_ranges:
@@ -456,7 +700,7 @@ class DescriptorPool(object):
nested_types=nested,
enum_types=enums,
extensions=extensions,
- options=desc_proto.options,
+ options=_OptionsOrNone(desc_proto),
is_extendable=is_extendable,
extension_ranges=extension_ranges,
file=file_desc,
@@ -474,6 +718,7 @@ class DescriptorPool(object):
fields[field_index].containing_oneof = oneofs[oneof_index]
scope[_PrefixWithDot(desc_name)] = desc
+ self._CheckConflictRegister(desc)
self._descriptors[desc_name] = desc
return desc
@@ -510,13 +755,14 @@ class DescriptorPool(object):
file=file_desc,
values=values,
containing_type=containing_type,
- options=enum_proto.options)
+ options=_OptionsOrNone(enum_proto))
scope['.%s' % enum_name] = desc
+ self._CheckConflictRegister(desc)
self._enum_descriptors[enum_name] = desc
return desc
def _MakeFieldDescriptor(self, field_proto, message_name, index,
- is_extension=False):
+ file_desc, is_extension=False):
"""Creates a field descriptor from a FieldDescriptorProto.
For message and enum type fields, this method will do a look up
@@ -529,6 +775,7 @@ class DescriptorPool(object):
field_proto: The proto describing the field.
message_name: The name of the containing message.
index: Index of the field
+ file_desc: The file containing the field descriptor.
is_extension: Indication that this field is for an extension.
Returns:
@@ -555,7 +802,8 @@ class DescriptorPool(object):
default_value=None,
is_extension=is_extension,
extension_scope=None,
- options=field_proto.options)
+ options=_OptionsOrNone(field_proto),
+ file=file_desc)
def _SetAllFieldTypes(self, package, desc_proto, scope):
"""Sets all the descriptor's fields's types.
@@ -674,9 +922,69 @@ class DescriptorPool(object):
name=value_proto.name,
index=index,
number=value_proto.number,
- options=value_proto.options,
+ options=_OptionsOrNone(value_proto),
type=None)
+ def _MakeServiceDescriptor(self, service_proto, service_index, scope,
+ package, file_desc):
+ """Make a protobuf ServiceDescriptor given a ServiceDescriptorProto.
+
+ Args:
+ service_proto: The descriptor_pb2.ServiceDescriptorProto protobuf message.
+ service_index: The index of the service in the File.
+ scope: Dict mapping short and full symbols to message and enum types.
+ package: Optional package name for the new message EnumDescriptor.
+ file_desc: The file containing the service descriptor.
+
+ Returns:
+ The added descriptor.
+ """
+
+ if package:
+ service_name = '.'.join((package, service_proto.name))
+ else:
+ service_name = service_proto.name
+
+ methods = [self._MakeMethodDescriptor(method_proto, service_name, package,
+ scope, index)
+ for index, method_proto in enumerate(service_proto.method)]
+ desc = descriptor.ServiceDescriptor(name=service_proto.name,
+ full_name=service_name,
+ index=service_index,
+ methods=methods,
+ options=_OptionsOrNone(service_proto),
+ file=file_desc)
+ self._CheckConflictRegister(desc)
+ self._service_descriptors[service_name] = desc
+ return desc
+
+ def _MakeMethodDescriptor(self, method_proto, service_name, package, scope,
+ index):
+ """Creates a method descriptor from a MethodDescriptorProto.
+
+ Args:
+ method_proto: The proto describing the method.
+ service_name: The name of the containing service.
+ package: Optional package name to look up for types.
+ scope: Scope containing available types.
+ index: Index of the method in the service.
+
+ Returns:
+ An initialized MethodDescriptor object.
+ """
+ full_name = '.'.join((service_name, method_proto.name))
+ input_type = self._GetTypeFromScope(
+ package, method_proto.input_type, scope)
+ output_type = self._GetTypeFromScope(
+ package, method_proto.output_type, scope)
+ return descriptor.MethodDescriptor(name=method_proto.name,
+ full_name=full_name,
+ index=index,
+ containing_service=None,
+ input_type=input_type,
+ output_type=output_type,
+ options=_OptionsOrNone(method_proto))
+
def _ExtractSymbols(self, descriptors):
"""Pulls out all the symbols from descriptor protos.
diff --git a/python/google/protobuf/internal/_parameterized.py b/python/google/protobuf/internal/_parameterized.py
index dea3f199..f2c0b305 100755
--- a/python/google/protobuf/internal/_parameterized.py
+++ b/python/google/protobuf/internal/_parameterized.py
@@ -37,8 +37,8 @@ argument tuples.
A simple example:
- class AdditionExample(parameterized.ParameterizedTestCase):
- @parameterized.Parameters(
+ class AdditionExample(parameterized.TestCase):
+ @parameterized.parameters(
(1, 2, 3),
(4, 5, 9),
(1, 1, 3))
@@ -54,8 +54,8 @@ fail due to an assertion error (1 + 1 != 3).
Parameters for invididual test cases can be tuples (with positional parameters)
or dictionaries (with named parameters):
- class AdditionExample(parameterized.ParameterizedTestCase):
- @parameterized.Parameters(
+ class AdditionExample(parameterized.TestCase):
+ @parameterized.parameters(
{'op1': 1, 'op2': 2, 'result': 3},
{'op1': 4, 'op2': 5, 'result': 9},
)
@@ -77,13 +77,13 @@ stay the same across several invocations, object representations like
'<__main__.Foo object at 0x23d8610>'
are turned into '<__main__.Foo>'. For even more descriptive names,
-especially in test logs, you can use the NamedParameters decorator. In
+especially in test logs, you can use the named_parameters decorator. In
this case, only tuples are supported, and the first parameters has to
be a string (or an object that returns an apt name when converted via
str()):
- class NamedExample(parameterized.ParameterizedTestCase):
- @parameterized.NamedParameters(
+ class NamedExample(parameterized.TestCase):
+ @parameterized.named_parameters(
('Normal', 'aa', 'aaa', True),
('EmptyPrefix', '', 'abc', True),
('BothEmpty', '', '', True))
@@ -103,13 +103,13 @@ from the command line:
Parameterized Classes
=====================
If invocation arguments are shared across test methods in a single
-ParameterizedTestCase class, instead of decorating all test methods
+TestCase class, instead of decorating all test methods
individually, the class itself can be decorated:
- @parameterized.Parameters(
+ @parameterized.parameters(
(1, 2, 3)
(4, 5, 9))
- class ArithmeticTest(parameterized.ParameterizedTestCase):
+ class ArithmeticTest(parameterized.TestCase):
def testAdd(self, arg1, arg2, result):
self.assertEqual(arg1 + arg2, result)
@@ -122,8 +122,8 @@ If parameters should be shared across several test cases, or are dynamically
created from other sources, a single non-tuple iterable can be passed into
the decorator. This iterable will be used to obtain the test cases:
- class AdditionExample(parameterized.ParameterizedTestCase):
- @parameterized.Parameters(
+ class AdditionExample(parameterized.TestCase):
+ @parameterized.parameters(
c.op1, c.op2, c.result for c in testcases
)
def testAddition(self, op1, op2, result):
@@ -135,8 +135,8 @@ Single-Argument Test Methods
If a test method takes only one argument, the single argument does not need to
be wrapped into a tuple:
- class NegativeNumberExample(parameterized.ParameterizedTestCase):
- @parameterized.Parameters(
+ class NegativeNumberExample(parameterized.TestCase):
+ @parameterized.parameters(
-1, -3, -4, -5
)
def testIsNegative(self, arg):
@@ -212,7 +212,7 @@ class _ParameterizedTestIter(object):
def __call__(self, *args, **kwargs):
raise RuntimeError('You appear to be running a parameterized test case '
'without having inherited from parameterized.'
- 'ParameterizedTestCase. This is bad because none of '
+ 'TestCase. This is bad because none of '
'your test cases are actually being run.')
def __iter__(self):
@@ -306,7 +306,7 @@ def _ParameterDecorator(naming_type, testcases):
return _Apply
-def Parameters(*testcases):
+def parameters(*testcases): # pylint: disable=invalid-name
"""A decorator for creating parameterized tests.
See the module docstring for a usage example.
@@ -321,7 +321,7 @@ def Parameters(*testcases):
return _ParameterDecorator(_ARGUMENT_REPR, testcases)
-def NamedParameters(*testcases):
+def named_parameters(*testcases): # pylint: disable=invalid-name
"""A decorator for creating parameterized tests.
See the module docstring for a usage example. The first element of
@@ -347,8 +347,8 @@ class TestGeneratorMetaclass(type):
iterable conforms to the test pattern, the injected methods will be picked
up as tests by the unittest framework.
- In general, it is supposed to be used in conjuction with the
- Parameters decorator.
+ In general, it is supposed to be used in conjunction with the
+ parameters decorator.
"""
def __new__(mcs, class_name, bases, dct):
@@ -385,8 +385,8 @@ def _UpdateClassDictForParamTestCase(dct, id_suffix, name, iterator):
id_suffix[new_name] = getattr(func, '__x_extra_id__', '')
-class ParameterizedTestCase(unittest.TestCase):
- """Base class for test cases using the Parameters decorator."""
+class TestCase(unittest.TestCase):
+ """Base class for test cases using the parameters decorator."""
__metaclass__ = TestGeneratorMetaclass
def _OriginalName(self):
@@ -409,10 +409,10 @@ class ParameterizedTestCase(unittest.TestCase):
self._id_suffix.get(self._testMethodName, ''))
-def CoopParameterizedTestCase(other_base_class):
+def CoopTestCase(other_base_class):
"""Returns a new base class with a cooperative metaclass base.
- This enables the ParameterizedTestCase to be used in combination
+ This enables the TestCase to be used in combination
with other base classes that have custom metaclasses, such as
mox.MoxTestBase.
@@ -425,7 +425,7 @@ def CoopParameterizedTestCase(other_base_class):
from google3.testing.pybase import parameterized
- class ExampleTest(parameterized.CoopParameterizedTestCase(mox.MoxTestBase)):
+ class ExampleTest(parameterized.CoopTestCase(mox.MoxTestBase)):
...
Args:
@@ -439,5 +439,5 @@ def CoopParameterizedTestCase(other_base_class):
(other_base_class.__metaclass__,
TestGeneratorMetaclass), {})
return metaclass(
- 'CoopParameterizedTestCase',
- (other_base_class, ParameterizedTestCase), {})
+ 'CoopTestCase',
+ (other_base_class, TestCase), {})
diff --git a/python/google/protobuf/internal/any_test.proto b/python/google/protobuf/internal/any_test.proto
index cd641ca0..1a563fd9 100644
--- a/python/google/protobuf/internal/any_test.proto
+++ b/python/google/protobuf/internal/any_test.proto
@@ -30,13 +30,22 @@
// Author: jieluo@google.com (Jie Luo)
-syntax = "proto3";
+syntax = "proto2";
package google.protobuf.internal;
import "google/protobuf/any.proto";
message TestAny {
- google.protobuf.Any value = 1;
- int32 int_value = 2;
+ optional google.protobuf.Any value = 1;
+ optional int32 int_value = 2;
+ map<string,int32> map_value = 3;
+ extensions 10 to max;
+}
+
+message TestAnyExtension1 {
+ extend TestAny {
+ optional TestAnyExtension1 extension1 = 98418603;
+ }
+ optional int32 i = 15;
}
diff --git a/python/google/protobuf/internal/api_implementation.py b/python/google/protobuf/internal/api_implementation.py
index ffcf7511..ab9e7812 100755
--- a/python/google/protobuf/internal/api_implementation.py
+++ b/python/google/protobuf/internal/api_implementation.py
@@ -32,6 +32,7 @@
"""
import os
+import warnings
import sys
try:
@@ -60,10 +61,18 @@ if _api_version < 0: # Still unspecified?
del _use_fast_cpp_protos
_api_version = 2
except ImportError:
- if _proto_extension_modules_exist_in_build:
- if sys.version_info[0] >= 3: # Python 3 defaults to C++ impl v2.
- _api_version = 2
- # TODO(b/17427486): Make Python 2 default to C++ impl v2.
+ try:
+ # pylint: disable=g-import-not-at-top
+ from google.protobuf.internal import use_pure_python
+ del use_pure_python # Avoids a pylint error and namespace pollution.
+ except ImportError:
+ # TODO(b/74017912): It's unsafe to enable :use_fast_cpp_protos by default;
+ # it can cause data loss if you have any Python-only extensions to any
+ # message passed back and forth with C++ code.
+ #
+ # TODO(b/17427486): Once that bug is fixed, we want to make both Python 2
+ # and Python 3 default to `_api_version = 2` (C++ implementation V2).
+ pass
_default_implementation_type = (
'python' if _api_version <= 0 else 'cpp')
@@ -78,6 +87,11 @@ _implementation_type = os.getenv('PROTOCOL_BUFFERS_PYTHON_IMPLEMENTATION',
if _implementation_type != 'python':
_implementation_type = 'cpp'
+if 'PyPy' in sys.version and _implementation_type == 'cpp':
+ warnings.warn('PyPy does not work yet with cpp protocol buffers. '
+ 'Falling back to the python implementation.')
+ _implementation_type = 'python'
+
# This environment variable can be used to switch between the two
# 'cpp' implementations, overriding the compile-time constants in the
# _api_implementation module. Right now only '2' is supported. Any other
@@ -94,6 +108,27 @@ if _implementation_version_str != '2':
_implementation_version = int(_implementation_version_str)
+# Detect if serialization should be deterministic by default
+try:
+ # The presence of this module in a build allows the proto implementation to
+ # be upgraded merely via build deps.
+ #
+ # NOTE: Merely importing this automatically enables deterministic proto
+ # serialization for C++ code, but we still need to export it as a boolean so
+ # that we can do the same for `_implementation_type == 'python'`.
+ #
+ # NOTE2: It is possible for C++ code to enable deterministic serialization by
+ # default _without_ affecting Python code, if the C++ implementation is not in
+ # use by this module. That is intended behavior, so we don't actually expose
+ # this boolean outside of this module.
+ #
+ # pylint: disable=g-import-not-at-top,unused-import
+ from google.protobuf import enable_deterministic_proto_serialization
+ _python_deterministic_proto_serialization = True
+except ImportError:
+ _python_deterministic_proto_serialization = False
+
+
# Usage of this function is discouraged. Clients shouldn't care which
# implementation of the API is in use. Note that there is no guarantee
# that differences between APIs will be maintained.
@@ -105,3 +140,34 @@ def Type():
# See comment on 'Type' above.
def Version():
return _implementation_version
+
+
+# For internal use only
+def IsPythonDefaultSerializationDeterministic():
+ return _python_deterministic_proto_serialization
+
+# DO NOT USE: For migration and testing only. Will be removed when Proto3
+# defaults to preserve unknowns.
+if _implementation_type == 'cpp':
+ try:
+ # pylint: disable=g-import-not-at-top
+ from google.protobuf.pyext import _message
+
+ def GetPythonProto3PreserveUnknownsDefault():
+ return _message.GetPythonProto3PreserveUnknownsDefault()
+
+ def SetPythonProto3PreserveUnknownsDefault(preserve):
+ _message.SetPythonProto3PreserveUnknownsDefault(preserve)
+ except ImportError:
+ # Unrecognized cpp implementation. Skipping the unknown fields APIs.
+ pass
+else:
+ _python_proto3_preserve_unknowns_default = True
+
+ def GetPythonProto3PreserveUnknownsDefault():
+ return _python_proto3_preserve_unknowns_default
+
+ def SetPythonProto3PreserveUnknownsDefault(preserve):
+ global _python_proto3_preserve_unknowns_default
+ _python_proto3_preserve_unknowns_default = preserve
+
diff --git a/python/google/protobuf/internal/containers.py b/python/google/protobuf/internal/containers.py
index 97cdd848..c6a3692a 100755
--- a/python/google/protobuf/internal/containers.py
+++ b/python/google/protobuf/internal/containers.py
@@ -275,7 +275,7 @@ class RepeatedScalarFieldContainer(BaseContainer):
new_values = [self._type_checker.CheckValue(elem) for elem in elem_seq_iter]
if new_values:
self._values.extend(new_values)
- self._message_listener.Modified()
+ self._message_listener.Modified()
def MergeFrom(self, other):
"""Appends the contents of another repeated field of the same type to this
@@ -436,9 +436,11 @@ class ScalarMap(MutableMapping):
"""Simple, type-checked, dict-like container for holding repeated scalars."""
# Disallows assignment to other attributes.
- __slots__ = ['_key_checker', '_value_checker', '_values', '_message_listener']
+ __slots__ = ['_key_checker', '_value_checker', '_values', '_message_listener',
+ '_entry_descriptor']
- def __init__(self, message_listener, key_checker, value_checker):
+ def __init__(self, message_listener, key_checker, value_checker,
+ entry_descriptor):
"""
Args:
message_listener: A MessageListener implementation.
@@ -448,10 +450,12 @@ class ScalarMap(MutableMapping):
inserted into this container.
value_checker: A type_checkers.ValueChecker instance to run on values
inserted into this container.
+ entry_descriptor: The MessageDescriptor of a map entry: key and value.
"""
self._message_listener = message_listener
self._key_checker = key_checker
self._value_checker = value_checker
+ self._entry_descriptor = entry_descriptor
self._values = {}
def __getitem__(self, key):
@@ -513,6 +517,9 @@ class ScalarMap(MutableMapping):
self._values.clear()
self._message_listener.Modified()
+ def GetEntryClass(self):
+ return self._entry_descriptor._concrete_class
+
class MessageMap(MutableMapping):
@@ -520,9 +527,10 @@ class MessageMap(MutableMapping):
# Disallows assignment to other attributes.
__slots__ = ['_key_checker', '_values', '_message_listener',
- '_message_descriptor']
+ '_message_descriptor', '_entry_descriptor']
- def __init__(self, message_listener, message_descriptor, key_checker):
+ def __init__(self, message_listener, message_descriptor, key_checker,
+ entry_descriptor):
"""
Args:
message_listener: A MessageListener implementation.
@@ -532,17 +540,19 @@ class MessageMap(MutableMapping):
inserted into this container.
value_checker: A type_checkers.ValueChecker instance to run on values
inserted into this container.
+ entry_descriptor: The MessageDescriptor of a map entry: key and value.
"""
self._message_listener = message_listener
self._message_descriptor = message_descriptor
self._key_checker = key_checker
+ self._entry_descriptor = entry_descriptor
self._values = {}
def __getitem__(self, key):
+ key = self._key_checker.CheckValue(key)
try:
return self._values[key]
except KeyError:
- key = self._key_checker.CheckValue(key)
new_element = self._message_descriptor._concrete_class()
new_element._SetListener(self._message_listener)
self._values[key] = new_element
@@ -574,12 +584,14 @@ class MessageMap(MutableMapping):
return default
def __contains__(self, item):
+ item = self._key_checker.CheckValue(item)
return item in self._values
def __setitem__(self, key, value):
raise ValueError('May not set values directly, call my_map[key].foo = 5')
def __delitem__(self, key):
+ key = self._key_checker.CheckValue(key)
del self._values[key]
self._message_listener.Modified()
@@ -594,7 +606,11 @@ class MessageMap(MutableMapping):
def MergeFrom(self, other):
for key in other:
- self[key].MergeFrom(other[key])
+ # According to documentation: "When parsing from the wire or when merging,
+ # if there are duplicate map keys the last key seen is used".
+ if key in self:
+ del self[key]
+ self[key].CopyFrom(other[key])
# self._message_listener.Modified() not required here, because
# mutations to submessages already propagate.
@@ -609,3 +625,6 @@ class MessageMap(MutableMapping):
def clear(self):
self._values.clear()
self._message_listener.Modified()
+
+ def GetEntryClass(self):
+ return self._entry_descriptor._concrete_class
diff --git a/python/google/protobuf/internal/decoder.py b/python/google/protobuf/internal/decoder.py
index 31869e45..52b64915 100755
--- a/python/google/protobuf/internal/decoder.py
+++ b/python/google/protobuf/internal/decoder.py
@@ -131,9 +131,12 @@ def _VarintDecoder(mask, result_type):
return DecodeVarint
-def _SignedVarintDecoder(mask, result_type):
+def _SignedVarintDecoder(bits, result_type):
"""Like _VarintDecoder() but decodes signed values."""
+ signbit = 1 << (bits - 1)
+ mask = (1 << bits) - 1
+
def DecodeVarint(buffer, pos):
result = 0
shift = 0
@@ -142,11 +145,8 @@ def _SignedVarintDecoder(mask, result_type):
result |= ((b & 0x7f) << shift)
pos += 1
if not (b & 0x80):
- if result > 0x7fffffffffffffff:
- result -= (1 << 64)
- result |= ~mask
- else:
- result &= mask
+ result &= mask
+ result = (result ^ signbit) - signbit
result = result_type(result)
return (result, pos)
shift += 7
@@ -159,11 +159,11 @@ def _SignedVarintDecoder(mask, result_type):
# (e.g. the C++ implementation) simpler.
_DecodeVarint = _VarintDecoder((1 << 64) - 1, long)
-_DecodeSignedVarint = _SignedVarintDecoder((1 << 64) - 1, long)
+_DecodeSignedVarint = _SignedVarintDecoder(64, long)
# Use these versions for values which must be limited to 32 bits.
_DecodeVarint32 = _VarintDecoder((1 << 32) - 1, int)
-_DecodeSignedVarint32 = _SignedVarintDecoder((1 << 32) - 1, int)
+_DecodeSignedVarint32 = _SignedVarintDecoder(32, int)
def ReadTag(buffer, pos):
@@ -181,7 +181,7 @@ def ReadTag(buffer, pos):
while six.indexbytes(buffer, pos) & 0x80:
pos += 1
pos += 1
- return (buffer[start:pos], pos)
+ return (six.binary_type(buffer[start:pos]), pos)
# --------------------------------------------------------------------
@@ -642,10 +642,10 @@ def MessageDecoder(field_number, is_repeated, is_packed, key, new_default):
MESSAGE_SET_ITEM_TAG = encoder.TagBytes(1, wire_format.WIRETYPE_START_GROUP)
-def MessageSetItemDecoder(extensions_by_number):
+def MessageSetItemDecoder(descriptor):
"""Returns a decoder for a MessageSet item.
- The parameter is the _extensions_by_number map for the message class.
+ The parameter is the message Descriptor.
The message set message looks like this:
message MessageSet {
@@ -694,7 +694,7 @@ def MessageSetItemDecoder(extensions_by_number):
if message_start == -1:
raise _DecodeError('MessageSet item missing message.')
- extension = extensions_by_number.get(type_id)
+ extension = message.Extensions._FindExtensionByNumber(type_id)
if extension is not None:
value = field_dict.get(extension)
if value is None:
diff --git a/python/google/protobuf/internal/descriptor_database_test.py b/python/google/protobuf/internal/descriptor_database_test.py
index 1baff7d1..f97477b3 100644
--- a/python/google/protobuf/internal/descriptor_database_test.py
+++ b/python/google/protobuf/internal/descriptor_database_test.py
@@ -35,9 +35,12 @@
__author__ = 'matthewtoia@google.com (Matt Toia)'
try:
- import unittest2 as unittest
+ import unittest2 as unittest #PY26
except ImportError:
import unittest
+import warnings
+
+from google.protobuf import unittest_pb2
from google.protobuf import descriptor_pb2
from google.protobuf.internal import factory_test2_pb2
from google.protobuf import descriptor_database
@@ -53,16 +56,69 @@ class DescriptorDatabaseTest(unittest.TestCase):
self.assertEqual(file_desc_proto, db.FindFileByName(
'google/protobuf/internal/factory_test2.proto'))
+ # Can find message type.
self.assertEqual(file_desc_proto, db.FindFileContainingSymbol(
'google.protobuf.python.internal.Factory2Message'))
+ # Can find nested message type.
self.assertEqual(file_desc_proto, db.FindFileContainingSymbol(
'google.protobuf.python.internal.Factory2Message.NestedFactory2Message'))
+ # Can find enum type.
self.assertEqual(file_desc_proto, db.FindFileContainingSymbol(
'google.protobuf.python.internal.Factory2Enum'))
+ # Can find nested enum type.
self.assertEqual(file_desc_proto, db.FindFileContainingSymbol(
'google.protobuf.python.internal.Factory2Message.NestedFactory2Enum'))
self.assertEqual(file_desc_proto, db.FindFileContainingSymbol(
'google.protobuf.python.internal.MessageWithNestedEnumOnly.NestedEnum'))
+ # Can find field.
+ self.assertEqual(file_desc_proto, db.FindFileContainingSymbol(
+ 'google.protobuf.python.internal.Factory2Message.list_field'))
+ # Can find enum value.
+ self.assertEqual(file_desc_proto, db.FindFileContainingSymbol(
+ 'google.protobuf.python.internal.Factory2Enum.FACTORY_2_VALUE_0'))
+ # Can find top level extension.
+ self.assertEqual(file_desc_proto, db.FindFileContainingSymbol(
+ 'google.protobuf.python.internal.another_field'))
+ # Can find nested extension inside a message.
+ self.assertEqual(file_desc_proto, db.FindFileContainingSymbol(
+ 'google.protobuf.python.internal.Factory2Message.one_more_field'))
+
+ # Can find service.
+ file_desc_proto2 = descriptor_pb2.FileDescriptorProto.FromString(
+ unittest_pb2.DESCRIPTOR.serialized_pb)
+ db.Add(file_desc_proto2)
+ self.assertEqual(file_desc_proto2, db.FindFileContainingSymbol(
+ 'protobuf_unittest.TestService'))
+
+ # Non-existent field under a valid top level symbol can also be
+ # found. The behavior is the same with protobuf C++.
+ self.assertEqual(file_desc_proto2, db.FindFileContainingSymbol(
+ 'protobuf_unittest.TestAllTypes.none_field'))
+
+ self.assertRaises(KeyError,
+ db.FindFileContainingSymbol,
+ 'protobuf_unittest.NoneMessage')
+
+ def testConflictRegister(self):
+ db = descriptor_database.DescriptorDatabase()
+ unittest_fd = descriptor_pb2.FileDescriptorProto.FromString(
+ unittest_pb2.DESCRIPTOR.serialized_pb)
+ db.Add(unittest_fd)
+ conflict_fd = descriptor_pb2.FileDescriptorProto.FromString(
+ unittest_pb2.DESCRIPTOR.serialized_pb)
+ conflict_fd.name = 'other_file'
+ with warnings.catch_warnings(record=True) as w:
+ # Cause all warnings to always be triggered.
+ warnings.simplefilter('always')
+ db.Add(conflict_fd)
+ self.assertTrue(len(w))
+ self.assertIs(w[0].category, RuntimeWarning)
+ self.assertIn('Conflict register for file "other_file": ',
+ str(w[0].message))
+ self.assertIn('already defined in file '
+ '"google/protobuf/unittest.proto"',
+ str(w[0].message))
+
if __name__ == '__main__':
unittest.main()
diff --git a/python/google/protobuf/internal/descriptor_pool_test.py b/python/google/protobuf/internal/descriptor_pool_test.py
index f1d6bf99..2cbf7813 100644
--- a/python/google/protobuf/internal/descriptor_pool_test.py
+++ b/python/google/protobuf/internal/descriptor_pool_test.py
@@ -34,12 +34,16 @@
__author__ = 'matthewtoia@google.com (Matt Toia)'
+import copy
import os
+import sys
+import warnings
try:
- import unittest2 as unittest
+ import unittest2 as unittest #PY26
except ImportError:
import unittest
+
from google.protobuf import unittest_import_pb2
from google.protobuf import unittest_import_public_pb2
from google.protobuf import unittest_pb2
@@ -49,7 +53,8 @@ from google.protobuf.internal import descriptor_pool_test1_pb2
from google.protobuf.internal import descriptor_pool_test2_pb2
from google.protobuf.internal import factory_test1_pb2
from google.protobuf.internal import factory_test2_pb2
-from google.protobuf.internal import test_util
+from google.protobuf.internal import file_options_test_pb2
+from google.protobuf.internal import more_messages_pb2
from google.protobuf import descriptor
from google.protobuf import descriptor_database
from google.protobuf import descriptor_pool
@@ -57,19 +62,8 @@ from google.protobuf import message_factory
from google.protobuf import symbol_database
-class DescriptorPoolTest(unittest.TestCase):
-
- def CreatePool(self):
- return descriptor_pool.DescriptorPool()
- def setUp(self):
- self.pool = self.CreatePool()
- self.factory_test1_fd = descriptor_pb2.FileDescriptorProto.FromString(
- factory_test1_pb2.DESCRIPTOR.serialized_pb)
- self.factory_test2_fd = descriptor_pb2.FileDescriptorProto.FromString(
- factory_test2_pb2.DESCRIPTOR.serialized_pb)
- self.pool.Add(self.factory_test1_fd)
- self.pool.Add(self.factory_test2_fd)
+class DescriptorPoolTestBase(object):
def testFindFileByName(self):
name1 = 'google/protobuf/internal/factory_test1.proto'
@@ -107,6 +101,34 @@ class DescriptorPoolTest(unittest.TestCase):
self.assertEqual('google.protobuf.python.internal', file_desc2.package)
self.assertIn('Factory2Message', file_desc2.message_types_by_name)
+ # Tests top level extension.
+ file_desc3 = self.pool.FindFileContainingSymbol(
+ 'google.protobuf.python.internal.another_field')
+ self.assertIsInstance(file_desc3, descriptor.FileDescriptor)
+ self.assertEqual('google/protobuf/internal/factory_test2.proto',
+ file_desc3.name)
+
+ # Tests nested extension inside a message.
+ file_desc4 = self.pool.FindFileContainingSymbol(
+ 'google.protobuf.python.internal.Factory2Message.one_more_field')
+ self.assertIsInstance(file_desc4, descriptor.FileDescriptor)
+ self.assertEqual('google/protobuf/internal/factory_test2.proto',
+ file_desc4.name)
+
+ file_desc5 = self.pool.FindFileContainingSymbol(
+ 'protobuf_unittest.TestService')
+ self.assertIsInstance(file_desc5, descriptor.FileDescriptor)
+ self.assertEqual('google/protobuf/unittest.proto',
+ file_desc5.name)
+
+ # Tests the generated pool.
+ assert descriptor_pool.Default().FindFileContainingSymbol(
+ 'google.protobuf.python.internal.Factory2Message.one_more_field')
+ assert descriptor_pool.Default().FindFileContainingSymbol(
+ 'google.protobuf.python.internal.another_field')
+ assert descriptor_pool.Default().FindFileContainingSymbol(
+ 'protobuf_unittest.TestService')
+
def testFindFileContainingSymbolFailure(self):
with self.assertRaises(KeyError):
self.pool.FindFileContainingSymbol('Does not exist')
@@ -119,6 +141,7 @@ class DescriptorPoolTest(unittest.TestCase):
self.assertEqual('google.protobuf.python.internal.Factory1Message',
msg1.full_name)
self.assertEqual(None, msg1.containing_type)
+ self.assertFalse(msg1.has_options)
nested_msg1 = msg1.nested_types[0]
self.assertEqual('NestedFactory1Message', nested_msg1.name)
@@ -192,6 +215,27 @@ class DescriptorPoolTest(unittest.TestCase):
msg2.fields_by_name[name].containing_oneof)
self.assertIn(msg2.fields_by_name[name], msg2.oneofs[0].fields)
+ def testFindTypeErrors(self):
+ self.assertRaises(TypeError, self.pool.FindExtensionByNumber, '')
+
+ # TODO(jieluo): Fix python to raise correct errors.
+ if api_implementation.Type() == 'cpp':
+ self.assertRaises(TypeError, self.pool.FindMethodByName, 0)
+ self.assertRaises(KeyError, self.pool.FindMethodByName, '')
+ error_type = TypeError
+ else:
+ error_type = AttributeError
+ self.assertRaises(error_type, self.pool.FindMessageTypeByName, 0)
+ self.assertRaises(error_type, self.pool.FindFieldByName, 0)
+ self.assertRaises(error_type, self.pool.FindExtensionByName, 0)
+ self.assertRaises(error_type, self.pool.FindEnumTypeByName, 0)
+ self.assertRaises(error_type, self.pool.FindOneofByName, 0)
+ self.assertRaises(error_type, self.pool.FindServiceByName, 0)
+ self.assertRaises(error_type, self.pool.FindFileContainingSymbol, 0)
+ if api_implementation.Type() == 'python':
+ error_type = KeyError
+ self.assertRaises(error_type, self.pool.FindFileByName, 0)
+
def testFindMessageTypeByNameFailure(self):
with self.assertRaises(KeyError):
self.pool.FindMessageTypeByName('Does not exist')
@@ -202,6 +246,7 @@ class DescriptorPoolTest(unittest.TestCase):
self.assertIsInstance(enum1, descriptor.EnumDescriptor)
self.assertEqual(0, enum1.values_by_name['FACTORY_1_VALUE_0'].number)
self.assertEqual(1, enum1.values_by_name['FACTORY_1_VALUE_1'].number)
+ self.assertFalse(enum1.has_options)
nested_enum1 = self.pool.FindEnumTypeByName(
'google.protobuf.python.internal.Factory1Message.NestedFactory1Enum')
@@ -230,14 +275,38 @@ class DescriptorPoolTest(unittest.TestCase):
self.pool.FindEnumTypeByName('Does not exist')
def testFindFieldByName(self):
+ if isinstance(self, SecondaryDescriptorFromDescriptorDB):
+ if api_implementation.Type() == 'cpp':
+ # TODO(jieluo): Fix cpp extension to find field correctly
+ # when descriptor pool is using an underlying database.
+ return
field = self.pool.FindFieldByName(
'google.protobuf.python.internal.Factory1Message.list_value')
self.assertEqual(field.name, 'list_value')
self.assertEqual(field.label, field.LABEL_REPEATED)
+ self.assertFalse(field.has_options)
+
with self.assertRaises(KeyError):
self.pool.FindFieldByName('Does not exist')
+ def testFindOneofByName(self):
+ if isinstance(self, SecondaryDescriptorFromDescriptorDB):
+ if api_implementation.Type() == 'cpp':
+ # TODO(jieluo): Fix cpp extension to find oneof correctly
+ # when descriptor pool is using an underlying database.
+ return
+ oneof = self.pool.FindOneofByName(
+ 'google.protobuf.python.internal.Factory2Message.oneof_field')
+ self.assertEqual(oneof.name, 'oneof_field')
+ with self.assertRaises(KeyError):
+ self.pool.FindOneofByName('Does not exist')
+
def testFindExtensionByName(self):
+ if isinstance(self, SecondaryDescriptorFromDescriptorDB):
+ if api_implementation.Type() == 'cpp':
+ # TODO(jieluo): Fix cpp extension to find extension correctly
+ # when descriptor pool is using an underlying database.
+ return
# An extension defined in a message.
extension = self.pool.FindExtensionByName(
'google.protobuf.python.internal.Factory2Message.one_more_field')
@@ -250,6 +319,53 @@ class DescriptorPoolTest(unittest.TestCase):
with self.assertRaises(KeyError):
self.pool.FindFieldByName('Does not exist')
+ def testFindAllExtensions(self):
+ factory1_message = self.pool.FindMessageTypeByName(
+ 'google.protobuf.python.internal.Factory1Message')
+ factory2_message = self.pool.FindMessageTypeByName(
+ 'google.protobuf.python.internal.Factory2Message')
+ # An extension defined in a message.
+ one_more_field = factory2_message.extensions_by_name['one_more_field']
+ self.pool.AddExtensionDescriptor(one_more_field)
+ # An extension defined at file scope.
+ factory_test2 = self.pool.FindFileByName(
+ 'google/protobuf/internal/factory_test2.proto')
+ another_field = factory_test2.extensions_by_name['another_field']
+ self.pool.AddExtensionDescriptor(another_field)
+
+ extensions = self.pool.FindAllExtensions(factory1_message)
+ expected_extension_numbers = set([one_more_field, another_field])
+ self.assertEqual(expected_extension_numbers, set(extensions))
+ # Verify that mutating the returned list does not affect the pool.
+ extensions.append('unexpected_element')
+ # Get the extensions again, the returned value does not contain the
+ # 'unexpected_element'.
+ extensions = self.pool.FindAllExtensions(factory1_message)
+ self.assertEqual(expected_extension_numbers, set(extensions))
+
+ def testFindExtensionByNumber(self):
+ factory1_message = self.pool.FindMessageTypeByName(
+ 'google.protobuf.python.internal.Factory1Message')
+ factory2_message = self.pool.FindMessageTypeByName(
+ 'google.protobuf.python.internal.Factory2Message')
+ # An extension defined in a message.
+ one_more_field = factory2_message.extensions_by_name['one_more_field']
+ self.pool.AddExtensionDescriptor(one_more_field)
+ # An extension defined at file scope.
+ factory_test2 = self.pool.FindFileByName(
+ 'google/protobuf/internal/factory_test2.proto')
+ another_field = factory_test2.extensions_by_name['another_field']
+ self.pool.AddExtensionDescriptor(another_field)
+
+ # An extension defined in a message.
+ extension = self.pool.FindExtensionByNumber(factory1_message, 1001)
+ self.assertEqual(extension.name, 'one_more_field')
+ # An extension defined at file scope.
+ extension = self.pool.FindExtensionByNumber(factory1_message, 1002)
+ self.assertEqual(extension.name, 'another_field')
+ with self.assertRaises(KeyError):
+ extension = self.pool.FindExtensionByNumber(factory1_message, 1234567)
+
def testExtensionsAreNotFields(self):
with self.assertRaises(KeyError):
self.pool.FindFieldByName('google.protobuf.python.internal.another_field')
@@ -260,6 +376,12 @@ class DescriptorPoolTest(unittest.TestCase):
self.pool.FindExtensionByName(
'google.protobuf.python.internal.Factory1Message.list_value')
+ def testFindService(self):
+ service = self.pool.FindServiceByName('protobuf_unittest.TestService')
+ self.assertEqual(service.full_name, 'protobuf_unittest.TestService')
+ with self.assertRaises(KeyError):
+ self.pool.FindServiceByName('Does not exist')
+
def testUserDefinedDB(self):
db = descriptor_database.DescriptorDatabase()
self.pool = descriptor_pool.DescriptorPool(db)
@@ -268,21 +390,17 @@ class DescriptorPoolTest(unittest.TestCase):
self.testFindMessageTypeByName()
def testAddSerializedFile(self):
+ if isinstance(self, SecondaryDescriptorFromDescriptorDB):
+ if api_implementation.Type() == 'cpp':
+ # Cpp extension cannot call Add on a DescriptorPool
+ # that uses a DescriptorDatabase.
+ # TODO(jieluo): Fix python and cpp extension diff.
+ return
self.pool = descriptor_pool.DescriptorPool()
self.pool.AddSerializedFile(self.factory_test1_fd.SerializeToString())
self.pool.AddSerializedFile(self.factory_test2_fd.SerializeToString())
self.testFindMessageTypeByName()
- def testComplexNesting(self):
- test1_desc = descriptor_pb2.FileDescriptorProto.FromString(
- descriptor_pool_test1_pb2.DESCRIPTOR.serialized_pb)
- test2_desc = descriptor_pb2.FileDescriptorProto.FromString(
- descriptor_pool_test2_pb2.DESCRIPTOR.serialized_pb)
- self.pool.Add(test1_desc)
- self.pool.Add(test2_desc)
- TEST1_FILE.CheckFile(self, self.pool)
- TEST2_FILE.CheckFile(self, self.pool)
-
def testEnumDefaultValue(self):
"""Test the default value of enums which don't start at zero."""
@@ -301,6 +419,12 @@ class DescriptorPoolTest(unittest.TestCase):
self.assertIs(file_descriptor, descriptor_pool_test1_pb2.DESCRIPTOR)
_CheckDefaultValue(file_descriptor)
+ if isinstance(self, SecondaryDescriptorFromDescriptorDB):
+ if api_implementation.Type() == 'cpp':
+ # Cpp extension cannot call Add on a DescriptorPool
+ # that uses a DescriptorDatabase.
+ # TODO(jieluo): Fix python and cpp extension diff.
+ return
# Then check the dynamic pool and its internal DescriptorDatabase.
descriptor_proto = descriptor_pb2.FileDescriptorProto.FromString(
descriptor_pool_test1_pb2.DESCRIPTOR.serialized_pb)
@@ -348,26 +472,166 @@ class DescriptorPoolTest(unittest.TestCase):
unittest_pb2.TestAllTypes.DESCRIPTOR.full_name))
_CheckDefaultValues(message_class())
+ def testAddFileDescriptor(self):
+ if isinstance(self, SecondaryDescriptorFromDescriptorDB):
+ if api_implementation.Type() == 'cpp':
+ # Cpp extension cannot call Add on a DescriptorPool
+ # that uses a DescriptorDatabase.
+ # TODO(jieluo): Fix python and cpp extension diff.
+ return
+ file_desc = descriptor_pb2.FileDescriptorProto(name='some/file.proto')
+ self.pool.Add(file_desc)
+ self.pool.AddSerializedFile(file_desc.SerializeToString())
+
+ def testComplexNesting(self):
+ if isinstance(self, SecondaryDescriptorFromDescriptorDB):
+ if api_implementation.Type() == 'cpp':
+ # Cpp extension cannot call Add on a DescriptorPool
+ # that uses a DescriptorDatabase.
+ # TODO(jieluo): Fix python and cpp extension diff.
+ return
+ more_messages_desc = descriptor_pb2.FileDescriptorProto.FromString(
+ more_messages_pb2.DESCRIPTOR.serialized_pb)
+ test1_desc = descriptor_pb2.FileDescriptorProto.FromString(
+ descriptor_pool_test1_pb2.DESCRIPTOR.serialized_pb)
+ test2_desc = descriptor_pb2.FileDescriptorProto.FromString(
+ descriptor_pool_test2_pb2.DESCRIPTOR.serialized_pb)
+ self.pool.Add(more_messages_desc)
+ self.pool.Add(test1_desc)
+ self.pool.Add(test2_desc)
+ TEST1_FILE.CheckFile(self, self.pool)
+ TEST2_FILE.CheckFile(self, self.pool)
+
+ def testConflictRegister(self):
+ if isinstance(self, SecondaryDescriptorFromDescriptorDB):
+ if api_implementation.Type() == 'cpp':
+ # Cpp extension cannot call Add on a DescriptorPool
+ # that uses a DescriptorDatabase.
+ # TODO(jieluo): Fix python and cpp extension diff.
+ return
+ unittest_fd = descriptor_pb2.FileDescriptorProto.FromString(
+ unittest_pb2.DESCRIPTOR.serialized_pb)
+ conflict_fd = copy.deepcopy(unittest_fd)
+ conflict_fd.name = 'other_file'
+ if api_implementation.Type() == 'cpp':
+ try:
+ self.pool.Add(unittest_fd)
+ self.pool.Add(conflict_fd)
+ except TypeError:
+ pass
+ else:
+ with warnings.catch_warnings(record=True) as w:
+ # Cause all warnings to always be triggered.
+ warnings.simplefilter('always')
+ pool = copy.deepcopy(self.pool)
+ # No warnings to add the same descriptors.
+ file_descriptor = unittest_pb2.DESCRIPTOR
+ pool.AddDescriptor(
+ file_descriptor.message_types_by_name['TestAllTypes'])
+ pool.AddEnumDescriptor(
+ file_descriptor.enum_types_by_name['ForeignEnum'])
+ pool.AddServiceDescriptor(
+ file_descriptor.services_by_name['TestService'])
+ pool.AddExtensionDescriptor(
+ file_descriptor.extensions_by_name['optional_int32_extension'])
+ self.assertEqual(len(w), 0)
+ # Check warnings for conflict descriptors with the same name.
+ pool.Add(unittest_fd)
+ pool.Add(conflict_fd)
+ pool.FindFileByName(unittest_fd.name)
+ pool.FindFileByName(conflict_fd.name)
+ self.assertTrue(len(w))
+ self.assertIs(w[0].category, RuntimeWarning)
+ self.assertIn('Conflict register for file "other_file": ',
+ str(w[0].message))
+ self.assertIn('already defined in file '
+ '"google/protobuf/unittest.proto"',
+ str(w[0].message))
+
+
+class DefaultDescriptorPoolTest(DescriptorPoolTestBase, unittest.TestCase):
+
+ def setUp(self):
+ self.pool = descriptor_pool.Default()
+ self.factory_test1_fd = descriptor_pb2.FileDescriptorProto.FromString(
+ factory_test1_pb2.DESCRIPTOR.serialized_pb)
+ self.factory_test2_fd = descriptor_pb2.FileDescriptorProto.FromString(
+ factory_test2_pb2.DESCRIPTOR.serialized_pb)
+
+ def testFindMethods(self):
+ self.assertIs(
+ self.pool.FindFileByName('google/protobuf/unittest.proto'),
+ unittest_pb2.DESCRIPTOR)
+ self.assertIs(
+ self.pool.FindMessageTypeByName('protobuf_unittest.TestAllTypes'),
+ unittest_pb2.TestAllTypes.DESCRIPTOR)
+ self.assertIs(
+ self.pool.FindFieldByName(
+ 'protobuf_unittest.TestAllTypes.optional_int32'),
+ unittest_pb2.TestAllTypes.DESCRIPTOR.fields_by_name['optional_int32'])
+ self.assertIs(
+ self.pool.FindEnumTypeByName('protobuf_unittest.ForeignEnum'),
+ unittest_pb2.ForeignEnum.DESCRIPTOR)
+ self.assertIs(
+ self.pool.FindExtensionByName(
+ 'protobuf_unittest.optional_int32_extension'),
+ unittest_pb2.DESCRIPTOR.extensions_by_name['optional_int32_extension'])
+ self.assertIs(
+ self.pool.FindOneofByName('protobuf_unittest.TestAllTypes.oneof_field'),
+ unittest_pb2.TestAllTypes.DESCRIPTOR.oneofs_by_name['oneof_field'])
+ self.assertIs(
+ self.pool.FindServiceByName('protobuf_unittest.TestService'),
+ unittest_pb2.DESCRIPTOR.services_by_name['TestService'])
+
+
+class CreateDescriptorPoolTest(DescriptorPoolTestBase, unittest.TestCase):
+
+ def setUp(self):
+ self.pool = descriptor_pool.DescriptorPool()
+ self.factory_test1_fd = descriptor_pb2.FileDescriptorProto.FromString(
+ factory_test1_pb2.DESCRIPTOR.serialized_pb)
+ self.factory_test2_fd = descriptor_pb2.FileDescriptorProto.FromString(
+ factory_test2_pb2.DESCRIPTOR.serialized_pb)
+ self.pool.Add(self.factory_test1_fd)
+ self.pool.Add(self.factory_test2_fd)
+
+ self.pool.Add(descriptor_pb2.FileDescriptorProto.FromString(
+ unittest_import_public_pb2.DESCRIPTOR.serialized_pb))
+ self.pool.Add(descriptor_pb2.FileDescriptorProto.FromString(
+ unittest_import_pb2.DESCRIPTOR.serialized_pb))
+ self.pool.Add(descriptor_pb2.FileDescriptorProto.FromString(
+ unittest_pb2.DESCRIPTOR.serialized_pb))
-@unittest.skipIf(api_implementation.Type() != 'cpp',
- 'explicit tests of the C++ implementation')
-class CppDescriptorPoolTest(DescriptorPoolTest):
- # TODO(amauryfa): remove when descriptor_pool.DescriptorPool() creates true
- # C++ descriptor pool object for C++ implementation.
- def CreatePool(self):
- # pylint: disable=g-import-not-at-top
- from google.protobuf.pyext import _message
- return _message.DescriptorPool()
+class SecondaryDescriptorFromDescriptorDB(DescriptorPoolTestBase,
+ unittest.TestCase):
+
+ def setUp(self):
+ self.factory_test1_fd = descriptor_pb2.FileDescriptorProto.FromString(
+ factory_test1_pb2.DESCRIPTOR.serialized_pb)
+ self.factory_test2_fd = descriptor_pb2.FileDescriptorProto.FromString(
+ factory_test2_pb2.DESCRIPTOR.serialized_pb)
+ db = descriptor_database.DescriptorDatabase()
+ db.Add(self.factory_test1_fd)
+ db.Add(self.factory_test2_fd)
+ db.Add(descriptor_pb2.FileDescriptorProto.FromString(
+ unittest_import_public_pb2.DESCRIPTOR.serialized_pb))
+ db.Add(descriptor_pb2.FileDescriptorProto.FromString(
+ unittest_import_pb2.DESCRIPTOR.serialized_pb))
+ db.Add(descriptor_pb2.FileDescriptorProto.FromString(
+ unittest_pb2.DESCRIPTOR.serialized_pb))
+ self.pool = descriptor_pool.DescriptorPool(descriptor_db=db)
class ProtoFile(object):
- def __init__(self, name, package, messages, dependencies=None):
+ def __init__(self, name, package, messages, dependencies=None,
+ public_dependencies=None):
self.name = name
self.package = package
self.messages = messages
self.dependencies = dependencies or []
+ self.public_dependencies = public_dependencies or []
def CheckFile(self, test, pool):
file_desc = pool.FindFileByName(self.name)
@@ -375,6 +639,8 @@ class ProtoFile(object):
test.assertEqual(self.package, file_desc.package)
dependencies_names = [f.name for f in file_desc.dependencies]
test.assertEqual(self.dependencies, dependencies_names)
+ public_dependencies_names = [f.name for f in file_desc.public_dependencies]
+ test.assertEqual(self.public_dependencies, public_dependencies_names)
for name, msg_type in self.messages.items():
msg_type.CheckType(test, None, name, file_desc)
@@ -426,10 +692,10 @@ class MessageType(object):
subtype.CheckType(test, desc, name, file_desc)
for index, (name, field) in enumerate(self.field_list):
- field.CheckField(test, desc, name, index)
+ field.CheckField(test, desc, name, index, file_desc)
for index, (name, field) in enumerate(self.extensions):
- field.CheckField(test, desc, name, index)
+ field.CheckField(test, desc, name, index, file_desc)
class EnumField(object):
@@ -439,7 +705,7 @@ class EnumField(object):
self.type_name = type_name
self.default_value = default_value
- def CheckField(self, test, msg_desc, name, index):
+ def CheckField(self, test, msg_desc, name, index, file_desc):
field_desc = msg_desc.fields_by_name[name]
enum_desc = msg_desc.enum_types_by_name[self.type_name]
test.assertEqual(name, field_desc.name)
@@ -453,8 +719,10 @@ class EnumField(object):
test.assertTrue(field_desc.has_default_value)
test.assertEqual(enum_desc.values_by_name[self.default_value].number,
field_desc.default_value)
+ test.assertFalse(enum_desc.values_by_name[self.default_value].has_options)
test.assertEqual(msg_desc, field_desc.containing_type)
test.assertEqual(enum_desc, field_desc.enum_type)
+ test.assertEqual(file_desc, enum_desc.file)
class MessageField(object):
@@ -463,7 +731,7 @@ class MessageField(object):
self.number = number
self.type_name = type_name
- def CheckField(self, test, msg_desc, name, index):
+ def CheckField(self, test, msg_desc, name, index, file_desc):
field_desc = msg_desc.fields_by_name[name]
field_type_desc = msg_desc.nested_types_by_name[self.type_name]
test.assertEqual(name, field_desc.name)
@@ -477,6 +745,12 @@ class MessageField(object):
test.assertFalse(field_desc.has_default_value)
test.assertEqual(msg_desc, field_desc.containing_type)
test.assertEqual(field_type_desc, field_desc.message_type)
+ test.assertEqual(file_desc, field_desc.file)
+ # TODO(jieluo): Fix python and cpp extension diff for message field
+ # default value.
+ if api_implementation.Type() == 'cpp':
+ test.assertRaises(
+ NotImplementedError, getattr, field_desc, 'default_value')
class StringField(object):
@@ -485,7 +759,7 @@ class StringField(object):
self.number = number
self.default_value = default_value
- def CheckField(self, test, msg_desc, name, index):
+ def CheckField(self, test, msg_desc, name, index, file_desc):
field_desc = msg_desc.fields_by_name[name]
test.assertEqual(name, field_desc.name)
expected_field_full_name = '.'.join([msg_desc.full_name, name])
@@ -497,6 +771,7 @@ class StringField(object):
field_desc.cpp_type)
test.assertTrue(field_desc.has_default_value)
test.assertEqual(self.default_value, field_desc.default_value)
+ test.assertEqual(file_desc, field_desc.file)
class ExtensionField(object):
@@ -505,7 +780,7 @@ class ExtensionField(object):
self.number = number
self.extended_type = extended_type
- def CheckField(self, test, msg_desc, name, index):
+ def CheckField(self, test, msg_desc, name, index, file_desc):
field_desc = msg_desc.extensions_by_name[name]
test.assertEqual(name, field_desc.name)
expected_field_full_name = '.'.join([msg_desc.full_name, name])
@@ -520,6 +795,7 @@ class ExtensionField(object):
test.assertEqual(msg_desc, field_desc.extension_scope)
test.assertEqual(msg_desc, field_desc.message_type)
test.assertEqual(self.extended_type, field_desc.containing_type.name)
+ test.assertEqual(file_desc, field_desc.file)
class AddDescriptorTest(unittest.TestCase):
@@ -555,7 +831,7 @@ class AddDescriptorTest(unittest.TestCase):
prefix + 'protobuf_unittest.TestAllTypes.NestedMessage').name)
@unittest.skipIf(api_implementation.Type() == 'cpp',
- 'With the cpp implementation, Add() must be called first')
+ 'With the cpp implementation, Add() must be called first')
def testMessage(self):
self._TestMessage('')
self._TestMessage('.')
@@ -591,13 +867,24 @@ class AddDescriptorTest(unittest.TestCase):
prefix + 'protobuf_unittest.TestAllTypes.NestedEnum').name)
@unittest.skipIf(api_implementation.Type() == 'cpp',
- 'With the cpp implementation, Add() must be called first')
+ 'With the cpp implementation, Add() must be called first')
def testEnum(self):
self._TestEnum('')
self._TestEnum('.')
@unittest.skipIf(api_implementation.Type() == 'cpp',
- 'With the cpp implementation, Add() must be called first')
+ 'With the cpp implementation, Add() must be called first')
+ def testService(self):
+ pool = descriptor_pool.DescriptorPool()
+ with self.assertRaises(KeyError):
+ pool.FindServiceByName('protobuf_unittest.TestService')
+ pool.AddServiceDescriptor(unittest_pb2._TESTSERVICE)
+ self.assertEqual(
+ 'protobuf_unittest.TestService',
+ pool.FindServiceByName('protobuf_unittest.TestService').full_name)
+
+ @unittest.skipIf(api_implementation.Type() == 'cpp',
+ 'With the cpp implementation, Add() must be called first')
def testFile(self):
pool = descriptor_pool.DescriptorPool()
pool.AddFileDescriptor(unittest_pb2.DESCRIPTOR)
@@ -612,18 +899,9 @@ class AddDescriptorTest(unittest.TestCase):
pool.FindFileContainingSymbol(
'protobuf_unittest.TestAllTypes')
- def _GetDescriptorPoolClass(self):
- # Test with both implementations of descriptor pools.
- if api_implementation.Type() == 'cpp':
- # pylint: disable=g-import-not-at-top
- from google.protobuf.pyext import _message
- return _message.DescriptorPool
- else:
- return descriptor_pool.DescriptorPool
-
def testEmptyDescriptorPool(self):
- # Check that an empty DescriptorPool() contains no message.
- pool = self._GetDescriptorPoolClass()()
+ # Check that an empty DescriptorPool() contains no messages.
+ pool = descriptor_pool.DescriptorPool()
proto_file_name = descriptor_pb2.DESCRIPTOR.name
self.assertRaises(KeyError, pool.FindFileByName, proto_file_name)
# Add the above file to the pool
@@ -635,7 +913,7 @@ class AddDescriptorTest(unittest.TestCase):
def testCustomDescriptorPool(self):
# Create a new pool, and add a file descriptor.
- pool = self._GetDescriptorPoolClass()()
+ pool = descriptor_pool.DescriptorPool()
file_desc = descriptor_pb2.FileDescriptorProto(
name='some/file.proto', package='package')
file_desc.message_type.add(name='Message')
@@ -644,43 +922,55 @@ class AddDescriptorTest(unittest.TestCase):
'some/file.proto')
self.assertEqual(pool.FindMessageTypeByName('package.Message').name,
'Message')
-
-
-@unittest.skipIf(
- api_implementation.Type() != 'cpp',
- 'default_pool is only supported by the C++ implementation')
-class DefaultPoolTest(unittest.TestCase):
-
- def testFindMethods(self):
- # pylint: disable=g-import-not-at-top
- from google.protobuf.pyext import _message
- pool = _message.default_pool
- self.assertIs(
- pool.FindFileByName('google/protobuf/unittest.proto'),
- unittest_pb2.DESCRIPTOR)
- self.assertIs(
- pool.FindMessageTypeByName('protobuf_unittest.TestAllTypes'),
- unittest_pb2.TestAllTypes.DESCRIPTOR)
- self.assertIs(
- pool.FindFieldByName('protobuf_unittest.TestAllTypes.optional_int32'),
- unittest_pb2.TestAllTypes.DESCRIPTOR.fields_by_name['optional_int32'])
- self.assertIs(
- pool.FindExtensionByName('protobuf_unittest.optional_int32_extension'),
- unittest_pb2.DESCRIPTOR.extensions_by_name['optional_int32_extension'])
- self.assertIs(
- pool.FindEnumTypeByName('protobuf_unittest.ForeignEnum'),
- unittest_pb2.ForeignEnum.DESCRIPTOR)
- self.assertIs(
- pool.FindOneofByName('protobuf_unittest.TestAllTypes.oneof_field'),
- unittest_pb2.TestAllTypes.DESCRIPTOR.oneofs_by_name['oneof_field'])
-
- def testAddFileDescriptor(self):
- # pylint: disable=g-import-not-at-top
- from google.protobuf.pyext import _message
- pool = _message.default_pool
- file_desc = descriptor_pb2.FileDescriptorProto(name='some/file.proto')
- pool.Add(file_desc)
- pool.AddSerializedFile(file_desc.SerializeToString())
+ # Test no package
+ file_proto = descriptor_pb2.FileDescriptorProto(
+ name='some/filename/container.proto')
+ message_proto = file_proto.message_type.add(
+ name='TopMessage')
+ message_proto.field.add(
+ name='bb',
+ number=1,
+ type=descriptor_pb2.FieldDescriptorProto.TYPE_INT32,
+ label=descriptor_pb2.FieldDescriptorProto.LABEL_OPTIONAL)
+ enum_proto = file_proto.enum_type.add(name='TopEnum')
+ enum_proto.value.add(name='FOREIGN_FOO', number=4)
+ file_proto.service.add(name='TopService')
+ pool = descriptor_pool.DescriptorPool()
+ pool.Add(file_proto)
+ self.assertEqual('TopMessage',
+ pool.FindMessageTypeByName('TopMessage').name)
+ self.assertEqual('TopEnum', pool.FindEnumTypeByName('TopEnum').name)
+ self.assertEqual('TopService', pool.FindServiceByName('TopService').name)
+
+ def testFileDescriptorOptionsWithCustomDescriptorPool(self):
+ # Create a descriptor pool, and add a new FileDescriptorProto to it.
+ pool = descriptor_pool.DescriptorPool()
+ file_name = 'file_descriptor_options_with_custom_descriptor_pool.proto'
+ file_descriptor_proto = descriptor_pb2.FileDescriptorProto(name=file_name)
+ extension_id = file_options_test_pb2.foo_options
+ file_descriptor_proto.options.Extensions[extension_id].foo_name = 'foo'
+ pool.Add(file_descriptor_proto)
+ # The options set on the FileDescriptorProto should be available in the
+ # descriptor even if they contain extensions that cannot be deserialized
+ # using the pool.
+ file_descriptor = pool.FindFileByName(file_name)
+ options = file_descriptor.GetOptions()
+ self.assertEqual('foo', options.Extensions[extension_id].foo_name)
+ # The object returned by GetOptions() is cached.
+ self.assertIs(options, file_descriptor.GetOptions())
+
+ def testAddTypeError(self):
+ pool = descriptor_pool.DescriptorPool()
+ with self.assertRaises(TypeError):
+ pool.AddDescriptor(0)
+ with self.assertRaises(TypeError):
+ pool.AddEnumDescriptor(0)
+ with self.assertRaises(TypeError):
+ pool.AddServiceDescriptor(0)
+ with self.assertRaises(TypeError):
+ pool.AddExtensionDescriptor(0)
+ with self.assertRaises(TypeError):
+ pool.AddFileDescriptor(0)
TEST1_FILE = ProtoFile(
@@ -756,7 +1046,9 @@ TEST2_FILE = ProtoFile(
ExtensionField(1001, 'DescriptorPoolTest1')),
]),
},
- dependencies=['google/protobuf/internal/descriptor_pool_test1.proto'])
+ dependencies=['google/protobuf/internal/descriptor_pool_test1.proto',
+ 'google/protobuf/internal/more_messages.proto'],
+ public_dependencies=['google/protobuf/internal/more_messages.proto'])
if __name__ == '__main__':
diff --git a/python/google/protobuf/internal/descriptor_pool_test2.proto b/python/google/protobuf/internal/descriptor_pool_test2.proto
index e3fa660c..a218eccb 100644
--- a/python/google/protobuf/internal/descriptor_pool_test2.proto
+++ b/python/google/protobuf/internal/descriptor_pool_test2.proto
@@ -33,6 +33,7 @@ syntax = "proto2";
package google.protobuf.python.internal;
import "google/protobuf/internal/descriptor_pool_test1.proto";
+import public "google/protobuf/internal/more_messages.proto";
message DescriptorPoolTest3 {
diff --git a/python/google/protobuf/internal/descriptor_test.py b/python/google/protobuf/internal/descriptor_test.py
index fee09a56..02a43d15 100755
--- a/python/google/protobuf/internal/descriptor_test.py
+++ b/python/google/protobuf/internal/descriptor_test.py
@@ -37,9 +37,10 @@ __author__ = 'robinson@google.com (Will Robinson)'
import sys
try:
- import unittest2 as unittest
+ import unittest2 as unittest #PY26
except ImportError:
import unittest
+
from google.protobuf import unittest_custom_options_pb2
from google.protobuf import unittest_import_pb2
from google.protobuf import unittest_pb2
@@ -76,27 +77,24 @@ class DescriptorTest(unittest.TestCase):
enum_proto.value.add(name='FOREIGN_BAR', number=5)
enum_proto.value.add(name='FOREIGN_BAZ', number=6)
+ file_proto.message_type.add(name='ResponseMessage')
+ service_proto = file_proto.service.add(
+ name='Service')
+ method_proto = service_proto.method.add(
+ name='CallMethod',
+ input_type='.protobuf_unittest.NestedMessage',
+ output_type='.protobuf_unittest.ResponseMessage')
+
+ # Note: Calling DescriptorPool.Add() multiple times with the same file only
+ # works if the input is canonical; in particular, all type names must be
+ # fully qualified.
self.pool = self.GetDescriptorPool()
self.pool.Add(file_proto)
self.my_file = self.pool.FindFileByName(file_proto.name)
self.my_message = self.my_file.message_types_by_name[message_proto.name]
self.my_enum = self.my_message.enum_types_by_name[enum_proto.name]
-
- self.my_method = descriptor.MethodDescriptor(
- name='Bar',
- full_name='protobuf_unittest.TestService.Bar',
- index=0,
- containing_service=None,
- input_type=None,
- output_type=None)
- self.my_service = descriptor.ServiceDescriptor(
- name='TestServiceWithOptions',
- full_name='protobuf_unittest.TestServiceWithOptions',
- file=self.my_file,
- index=0,
- methods=[
- self.my_method
- ])
+ self.my_service = self.my_file.services_by_name[service_proto.name]
+ self.my_method = self.my_service.methods_by_name[method_proto.name]
def GetDescriptorPool(self):
return symbol_database.Default().pool
@@ -109,6 +107,12 @@ class DescriptorTest(unittest.TestCase):
self.my_message.enum_types_by_name[
'ForeignEnum'].values_by_number[4].name,
self.my_message.EnumValueName('ForeignEnum', 4))
+ with self.assertRaises(KeyError):
+ self.my_message.EnumValueName('ForeignEnum', 999)
+ with self.assertRaises(KeyError):
+ self.my_message.EnumValueName('NoneEnum', 999)
+ with self.assertRaises(TypeError):
+ self.my_message.EnumValueName()
def testEnumFixups(self):
self.assertEqual(self.my_enum, self.my_enum.values[0].type)
@@ -136,15 +140,18 @@ class DescriptorTest(unittest.TestCase):
def testSimpleCustomOptions(self):
file_descriptor = unittest_custom_options_pb2.DESCRIPTOR
- message_descriptor =\
- unittest_custom_options_pb2.TestMessageWithCustomOptions.DESCRIPTOR
- field_descriptor = message_descriptor.fields_by_name["field1"]
- enum_descriptor = message_descriptor.enum_types_by_name["AnEnum"]
- enum_value_descriptor =\
- message_descriptor.enum_values_by_name["ANENUM_VAL2"]
- service_descriptor =\
- unittest_custom_options_pb2.TestServiceWithCustomOptions.DESCRIPTOR
- method_descriptor = service_descriptor.FindMethodByName("Foo")
+ message_descriptor = (unittest_custom_options_pb2.
+ TestMessageWithCustomOptions.DESCRIPTOR)
+ field_descriptor = message_descriptor.fields_by_name['field1']
+ oneof_descriptor = message_descriptor.oneofs_by_name['AnOneof']
+ enum_descriptor = message_descriptor.enum_types_by_name['AnEnum']
+ enum_value_descriptor = (message_descriptor.
+ enum_values_by_name['ANENUM_VAL2'])
+ other_enum_value_descriptor = (message_descriptor.
+ enum_values_by_name['ANENUM_VAL1'])
+ service_descriptor = (unittest_custom_options_pb2.
+ TestServiceWithCustomOptions.DESCRIPTOR)
+ method_descriptor = service_descriptor.FindMethodByName('Foo')
file_options = file_descriptor.GetOptions()
file_opt1 = unittest_custom_options_pb2.file_opt1
@@ -157,6 +164,9 @@ class DescriptorTest(unittest.TestCase):
self.assertEqual(8765432109, field_options.Extensions[field_opt1])
field_opt2 = unittest_custom_options_pb2.field_opt2
self.assertEqual(42, field_options.Extensions[field_opt2])
+ oneof_options = oneof_descriptor.GetOptions()
+ oneof_opt1 = unittest_custom_options_pb2.oneof_opt1
+ self.assertEqual(-99, oneof_options.Extensions[oneof_opt1])
enum_options = enum_descriptor.GetOptions()
enum_opt1 = unittest_custom_options_pb2.enum_opt1
self.assertEqual(-789, enum_options.Extensions[enum_opt1])
@@ -176,6 +186,11 @@ class DescriptorTest(unittest.TestCase):
unittest_custom_options_pb2.DummyMessageContainingEnum.DESCRIPTOR)
self.assertTrue(file_descriptor.has_options)
self.assertFalse(message_descriptor.has_options)
+ self.assertTrue(field_descriptor.has_options)
+ self.assertTrue(oneof_descriptor.has_options)
+ self.assertTrue(enum_descriptor.has_options)
+ self.assertTrue(enum_value_descriptor.has_options)
+ self.assertFalse(other_enum_value_descriptor.has_options)
def testDifferentCustomOptionTypes(self):
kint32min = -2**31
@@ -398,6 +413,12 @@ class DescriptorTest(unittest.TestCase):
self.assertEqual(self.my_file.name, 'some/filename/some.proto')
self.assertEqual(self.my_file.package, 'protobuf_unittest')
self.assertEqual(self.my_file.pool, self.pool)
+ self.assertFalse(self.my_file.has_options)
+ self.assertEqual('proto2', self.my_file.syntax)
+ file_proto = descriptor_pb2.FileDescriptorProto()
+ self.my_file.CopyToProto(file_proto)
+ self.assertEqual(self.my_file.serialized_pb,
+ file_proto.SerializeToString())
# Generated modules also belong to the default pool.
self.assertEqual(unittest_pb2.DESCRIPTOR.pool, descriptor_pool.Default())
@@ -405,13 +426,31 @@ class DescriptorTest(unittest.TestCase):
api_implementation.Type() != 'cpp' or api_implementation.Version() != 2,
'Immutability of descriptors is only enforced in v2 implementation')
def testImmutableCppDescriptor(self):
+ file_descriptor = unittest_pb2.DESCRIPTOR
message_descriptor = unittest_pb2.TestAllTypes.DESCRIPTOR
+ field_descriptor = message_descriptor.fields_by_name['optional_int32']
+ enum_descriptor = message_descriptor.enum_types_by_name['NestedEnum']
+ oneof_descriptor = message_descriptor.oneofs_by_name['oneof_field']
with self.assertRaises(AttributeError):
message_descriptor.fields_by_name = None
with self.assertRaises(TypeError):
message_descriptor.fields_by_name['Another'] = None
with self.assertRaises(TypeError):
message_descriptor.fields.append(None)
+ with self.assertRaises(AttributeError):
+ field_descriptor.containing_type = message_descriptor
+ with self.assertRaises(AttributeError):
+ file_descriptor.has_options = False
+ with self.assertRaises(AttributeError):
+ field_descriptor.has_options = False
+ with self.assertRaises(AttributeError):
+ oneof_descriptor.has_options = False
+ with self.assertRaises(AttributeError):
+ enum_descriptor.has_options = False
+ with self.assertRaises(AttributeError) as e:
+ message_descriptor.has_options = True
+ self.assertEqual('attribute is not writable: has_options',
+ str(e.exception))
class NewDescriptorTest(DescriptorTest):
@@ -440,6 +479,12 @@ class GeneratedDescriptorTest(unittest.TestCase):
self.CheckDescriptorMapping(message_descriptor.fields_by_name)
self.CheckDescriptorMapping(message_descriptor.fields_by_number)
self.CheckDescriptorMapping(message_descriptor.fields_by_camelcase_name)
+ self.CheckDescriptorMapping(message_descriptor.enum_types_by_name)
+ self.CheckDescriptorMapping(message_descriptor.enum_values_by_name)
+ self.CheckDescriptorMapping(message_descriptor.oneofs_by_name)
+ self.CheckDescriptorMapping(message_descriptor.enum_types[0].values_by_name)
+ # Test extension range
+ self.assertEqual(message_descriptor.extension_ranges, [])
def CheckFieldDescriptor(self, field_descriptor):
# Basic properties
@@ -448,6 +493,7 @@ class GeneratedDescriptorTest(unittest.TestCase):
self.assertEqual(field_descriptor.full_name,
'protobuf_unittest.TestAllTypes.optional_int32')
self.assertEqual(field_descriptor.containing_type.name, 'TestAllTypes')
+ self.assertEqual(field_descriptor.file, unittest_pb2.DESCRIPTOR)
# Test equality and hashability
self.assertEqual(field_descriptor, field_descriptor)
self.assertEqual(
@@ -459,32 +505,73 @@ class GeneratedDescriptorTest(unittest.TestCase):
field_descriptor)
self.assertIn(field_descriptor, [field_descriptor])
self.assertIn(field_descriptor, {field_descriptor: None})
+ self.assertEqual(None, field_descriptor.extension_scope)
+ self.assertEqual(None, field_descriptor.enum_type)
+ if api_implementation.Type() == 'cpp':
+ # For test coverage only
+ self.assertEqual(field_descriptor.id, field_descriptor.id)
def CheckDescriptorSequence(self, sequence):
# Verifies that a property like 'messageDescriptor.fields' has all the
# properties of an immutable abc.Sequence.
+ self.assertNotEqual(sequence,
+ unittest_pb2.TestAllExtensions.DESCRIPTOR.fields)
+ self.assertNotEqual(sequence, [])
+ self.assertNotEqual(sequence, 1)
+ self.assertFalse(sequence == 1) # Only for cpp test coverage
+ self.assertEqual(sequence, sequence)
+ expected_list = list(sequence)
+ self.assertEqual(expected_list, sequence)
self.assertGreater(len(sequence), 0) # Sized
- self.assertEqual(len(sequence), len(list(sequence))) # Iterable
+ self.assertEqual(len(sequence), len(expected_list)) # Iterable
+ self.assertEqual(sequence[len(sequence) -1], sequence[-1])
item = sequence[0]
self.assertEqual(item, sequence[0])
self.assertIn(item, sequence) # Container
self.assertEqual(sequence.index(item), 0)
self.assertEqual(sequence.count(item), 1)
+ other_item = unittest_pb2.NestedTestAllTypes.DESCRIPTOR.fields[0]
+ self.assertNotIn(other_item, sequence)
+ self.assertEqual(sequence.count(other_item), 0)
+ self.assertRaises(ValueError, sequence.index, other_item)
+ self.assertRaises(ValueError, sequence.index, [])
reversed_iterator = reversed(sequence)
self.assertEqual(list(reversed_iterator), list(sequence)[::-1])
self.assertRaises(StopIteration, next, reversed_iterator)
+ expected_list[0] = 'change value'
+ self.assertNotEqual(expected_list, sequence)
+ # TODO(jieluo): Change __repr__ support for DescriptorSequence.
+ if api_implementation.Type() == 'python':
+ self.assertEqual(str(list(sequence)), str(sequence))
+ else:
+ self.assertEqual(str(sequence)[0], '<')
def CheckDescriptorMapping(self, mapping):
# Verifies that a property like 'messageDescriptor.fields' has all the
# properties of an immutable abc.Mapping.
+ self.assertNotEqual(
+ mapping, unittest_pb2.TestAllExtensions.DESCRIPTOR.fields_by_name)
+ self.assertNotEqual(mapping, {})
+ self.assertNotEqual(mapping, 1)
+ self.assertFalse(mapping == 1) # Only for cpp test coverage
+ excepted_dict = dict(mapping.items())
+ self.assertEqual(mapping, excepted_dict)
+ self.assertEqual(mapping, mapping)
self.assertGreater(len(mapping), 0) # Sized
- self.assertEqual(len(mapping), len(list(mapping))) # Iterable
+ self.assertEqual(len(mapping), len(excepted_dict)) # Iterable
if sys.version_info >= (3,):
key, item = next(iter(mapping.items()))
else:
key, item = mapping.items()[0]
self.assertIn(key, mapping) # Container
self.assertEqual(mapping.get(key), item)
+ with self.assertRaises(TypeError):
+ mapping.get()
+ # TODO(jieluo): Fix python and cpp extension diff.
+ if api_implementation.Type() == 'python':
+ self.assertRaises(TypeError, mapping.get, [])
+ else:
+ self.assertEqual(None, mapping.get([]))
# keys(), iterkeys() &co
item = (next(iter(mapping.keys())), next(iter(mapping.values())))
self.assertEqual(item, next(iter(mapping.items())))
@@ -495,6 +582,18 @@ class GeneratedDescriptorTest(unittest.TestCase):
CheckItems(mapping.keys(), mapping.iterkeys())
CheckItems(mapping.values(), mapping.itervalues())
CheckItems(mapping.items(), mapping.iteritems())
+ excepted_dict[key] = 'change value'
+ self.assertNotEqual(mapping, excepted_dict)
+ del excepted_dict[key]
+ excepted_dict['new_key'] = 'new'
+ self.assertNotEqual(mapping, excepted_dict)
+ self.assertRaises(KeyError, mapping.__getitem__, 'key_error')
+ self.assertRaises(KeyError, mapping.__getitem__, len(mapping) + 1)
+ # TODO(jieluo): Add __repr__ support for DescriptorMapping.
+ if api_implementation.Type() == 'python':
+ self.assertEqual(len(str(dict(mapping.items()))), len(str(mapping)))
+ else:
+ self.assertEqual(str(mapping)[0], '<')
def testDescriptor(self):
message_descriptor = unittest_pb2.TestAllTypes.DESCRIPTOR
@@ -504,13 +603,26 @@ class GeneratedDescriptorTest(unittest.TestCase):
field_descriptor = message_descriptor.fields_by_camelcase_name[
'optionalInt32']
self.CheckFieldDescriptor(field_descriptor)
+ enum_descriptor = unittest_pb2.DESCRIPTOR.enum_types_by_name[
+ 'ForeignEnum']
+ self.assertEqual(None, enum_descriptor.containing_type)
+ # Test extension range
+ self.assertEqual(
+ unittest_pb2.TestAllExtensions.DESCRIPTOR.extension_ranges,
+ [(1, 536870912)])
+ self.assertEqual(
+ unittest_pb2.TestMultipleExtensionRanges.DESCRIPTOR.extension_ranges,
+ [(42, 43), (4143, 4244), (65536, 536870912)])
def testCppDescriptorContainer(self):
- # Check that the collection is still valid even if the parent disappeared.
- enum = unittest_pb2.TestAllTypes.DESCRIPTOR.enum_types_by_name['NestedEnum']
- values = enum.values
- del enum
- self.assertEqual('FOO', values[0].name)
+ containing_file = unittest_pb2.DESCRIPTOR
+ self.CheckDescriptorSequence(containing_file.dependencies)
+ self.CheckDescriptorMapping(containing_file.message_types_by_name)
+ self.CheckDescriptorMapping(containing_file.enum_types_by_name)
+ self.CheckDescriptorMapping(containing_file.services_by_name)
+ self.CheckDescriptorMapping(containing_file.extensions_by_name)
+ self.CheckDescriptorMapping(
+ unittest_pb2.TestNestedExtension.DESCRIPTOR.extensions_by_name)
def testCppDescriptorContainer_Iterator(self):
# Same test with the iterator
@@ -519,6 +631,24 @@ class GeneratedDescriptorTest(unittest.TestCase):
del enum
self.assertEqual('FOO', next(values_iter).name)
+ def testServiceDescriptor(self):
+ service_descriptor = unittest_pb2.DESCRIPTOR.services_by_name['TestService']
+ self.assertEqual(service_descriptor.name, 'TestService')
+ self.assertEqual(service_descriptor.methods[0].name, 'Foo')
+ self.assertIs(service_descriptor.file, unittest_pb2.DESCRIPTOR)
+ self.assertEqual(service_descriptor.index, 0)
+ self.CheckDescriptorMapping(service_descriptor.methods_by_name)
+
+ def testOneofDescriptor(self):
+ message_descriptor = unittest_pb2.TestAllTypes.DESCRIPTOR
+ oneof_descriptor = message_descriptor.oneofs_by_name['oneof_field']
+ self.assertFalse(oneof_descriptor.has_options)
+ self.assertEqual(message_descriptor, oneof_descriptor.containing_type)
+ self.assertEqual('oneof_field', oneof_descriptor.name)
+ self.assertEqual('protobuf_unittest.TestAllTypes.oneof_field',
+ oneof_descriptor.full_name)
+ self.assertEqual(0, oneof_descriptor.index)
+
class DescriptorCopyToProtoTest(unittest.TestCase):
"""Tests for CopyTo functions of Descriptor."""
@@ -596,7 +726,7 @@ class DescriptorCopyToProtoTest(unittest.TestCase):
"""
self._InternalTestCopyToProto(
- unittest_pb2._FOREIGNENUM,
+ unittest_pb2.ForeignEnum.DESCRIPTOR,
descriptor_pb2.EnumDescriptorProto,
TEST_FOREIGN_ENUM_ASCII)
@@ -612,6 +742,19 @@ class DescriptorCopyToProtoTest(unittest.TestCase):
deprecated: true
>
>
+ field {
+ name: "deprecated_int32_in_oneof"
+ number: 2
+ label: LABEL_OPTIONAL
+ type: TYPE_INT32
+ options {
+ deprecated: true
+ }
+ oneof_index: 0
+ }
+ oneof_decl {
+ name: "oneof_fields"
+ }
"""
self._InternalTestCopyToProto(
@@ -655,49 +798,64 @@ class DescriptorCopyToProtoTest(unittest.TestCase):
descriptor_pb2.DescriptorProto,
TEST_MESSAGE_WITH_SEVERAL_EXTENSIONS_ASCII)
- # Disable this test so we can make changes to the proto file.
- # TODO(xiaofeng): Enable this test after cl/55530659 is submitted.
- #
- # def testCopyToProto_FileDescriptor(self):
- # UNITTEST_IMPORT_FILE_DESCRIPTOR_ASCII = ("""
- # name: 'google/protobuf/unittest_import.proto'
- # package: 'protobuf_unittest_import'
- # dependency: 'google/protobuf/unittest_import_public.proto'
- # message_type: <
- # name: 'ImportMessage'
- # field: <
- # name: 'd'
- # number: 1
- # label: 1 # Optional
- # type: 5 # TYPE_INT32
- # >
- # >
- # """ +
- # """enum_type: <
- # name: 'ImportEnum'
- # value: <
- # name: 'IMPORT_FOO'
- # number: 7
- # >
- # value: <
- # name: 'IMPORT_BAR'
- # number: 8
- # >
- # value: <
- # name: 'IMPORT_BAZ'
- # number: 9
- # >
- # >
- # options: <
- # java_package: 'com.google.protobuf.test'
- # optimize_for: 1 # SPEED
- # >
- # public_dependency: 0
- # """)
- # self._InternalTestCopyToProto(
- # unittest_import_pb2.DESCRIPTOR,
- # descriptor_pb2.FileDescriptorProto,
- # UNITTEST_IMPORT_FILE_DESCRIPTOR_ASCII)
+ def testCopyToProto_FileDescriptor(self):
+ UNITTEST_IMPORT_FILE_DESCRIPTOR_ASCII = ("""
+ name: 'google/protobuf/unittest_import.proto'
+ package: 'protobuf_unittest_import'
+ dependency: 'google/protobuf/unittest_import_public.proto'
+ message_type: <
+ name: 'ImportMessage'
+ field: <
+ name: 'd'
+ number: 1
+ label: 1 # Optional
+ type: 5 # TYPE_INT32
+ >
+ >
+ """ +
+ """enum_type: <
+ name: 'ImportEnum'
+ value: <
+ name: 'IMPORT_FOO'
+ number: 7
+ >
+ value: <
+ name: 'IMPORT_BAR'
+ number: 8
+ >
+ value: <
+ name: 'IMPORT_BAZ'
+ number: 9
+ >
+ >
+ enum_type: <
+ name: 'ImportEnumForMap'
+ value: <
+ name: 'UNKNOWN'
+ number: 0
+ >
+ value: <
+ name: 'FOO'
+ number: 1
+ >
+ value: <
+ name: 'BAR'
+ number: 2
+ >
+ >
+ options: <
+ java_package: 'com.google.protobuf.test'
+ optimize_for: 1 # SPEED
+ """ +
+ """
+ cc_enable_arenas: true
+ >
+ public_dependency: 0
+ """)
+ self._InternalTestCopyToProto(
+ unittest_import_pb2.DESCRIPTOR,
+ descriptor_pb2.FileDescriptorProto,
+ UNITTEST_IMPORT_FILE_DESCRIPTOR_ASCII)
def testCopyToProto_ServiceDescriptor(self):
TEST_SERVICE_ASCII = """
@@ -713,12 +871,47 @@ class DescriptorCopyToProtoTest(unittest.TestCase):
output_type: '.protobuf_unittest.BarResponse'
>
"""
- # TODO(rocking): enable this test after the proto descriptor change is
- # checked in.
- #self._InternalTestCopyToProto(
- # unittest_pb2.TestService.DESCRIPTOR,
- # descriptor_pb2.ServiceDescriptorProto,
- # TEST_SERVICE_ASCII)
+ self._InternalTestCopyToProto(
+ unittest_pb2.TestService.DESCRIPTOR,
+ descriptor_pb2.ServiceDescriptorProto,
+ TEST_SERVICE_ASCII)
+
+ @unittest.skipIf(
+ api_implementation.Type() == 'python',
+ 'It is not implemented in python.')
+ # TODO(jieluo): Add support for pure python or remove in c extension.
+ def testCopyToProto_MethodDescriptor(self):
+ expected_ascii = """
+ name: 'Foo'
+ input_type: '.protobuf_unittest.FooRequest'
+ output_type: '.protobuf_unittest.FooResponse'
+ """
+ method_descriptor = unittest_pb2.TestService.DESCRIPTOR.FindMethodByName(
+ 'Foo')
+ self._InternalTestCopyToProto(
+ method_descriptor,
+ descriptor_pb2.MethodDescriptorProto,
+ expected_ascii)
+
+ @unittest.skipIf(
+ api_implementation.Type() == 'python',
+ 'Pure python does not raise error.')
+ # TODO(jieluo): Fix pure python to check with the proto type.
+ def testCopyToProto_TypeError(self):
+ file_proto = descriptor_pb2.FileDescriptorProto()
+ self.assertRaises(TypeError,
+ unittest_pb2.TestEmptyMessage.DESCRIPTOR.CopyToProto,
+ file_proto)
+ self.assertRaises(TypeError,
+ unittest_pb2.ForeignEnum.DESCRIPTOR.CopyToProto,
+ file_proto)
+ self.assertRaises(TypeError,
+ unittest_pb2.TestService.DESCRIPTOR.CopyToProto,
+ file_proto)
+ proto = descriptor_pb2.DescriptorProto()
+ self.assertRaises(TypeError,
+ unittest_import_pb2.DESCRIPTOR.CopyToProto,
+ proto)
class MakeDescriptorTest(unittest.TestCase):
@@ -764,6 +957,11 @@ class MakeDescriptorTest(unittest.TestCase):
'Foo2.Sub.bar_field')
self.assertEqual(result.nested_types[0].fields[0].enum_type,
result.nested_types[0].enum_types[0])
+ self.assertFalse(result.has_options)
+ self.assertFalse(result.fields[0].has_options)
+ if api_implementation.Type() == 'cpp':
+ with self.assertRaises(AttributeError):
+ result.fields[0].has_options = False
def testMakeDescriptorWithUnsignedIntField(self):
file_descriptor_proto = descriptor_pb2.FileDescriptorProto()
@@ -816,6 +1014,23 @@ class MakeDescriptorTest(unittest.TestCase):
self.assertEqual(result.fields[index].camelcase_name,
camelcase_names[index])
+ def testJsonName(self):
+ descriptor_proto = descriptor_pb2.DescriptorProto()
+ descriptor_proto.name = 'TestJsonName'
+ names = ['field_name', 'fieldName', 'FieldName',
+ '_field_name', 'FIELD_NAME', 'json_name']
+ json_names = ['fieldName', 'fieldName', 'FieldName',
+ 'FieldName', 'FIELDNAME', '@type']
+ for index in range(len(names)):
+ field = descriptor_proto.field.add()
+ field.number = index + 1
+ field.name = names[index]
+ field.json_name = '@type'
+ result = descriptor.MakeDescriptor(descriptor_proto)
+ for index in range(len(json_names)):
+ self.assertEqual(result.fields[index].json_name,
+ json_names[index])
+
if __name__ == '__main__':
unittest.main()
diff --git a/python/google/protobuf/internal/encoder.py b/python/google/protobuf/internal/encoder.py
index 48ef2df3..0d1f49dd 100755
--- a/python/google/protobuf/internal/encoder.py
+++ b/python/google/protobuf/internal/encoder.py
@@ -340,7 +340,7 @@ def MessageSetItemSizer(field_number):
# Map is special: it needs custom logic to compute its size properly.
-def MapSizer(field_descriptor):
+def MapSizer(field_descriptor, is_message_map):
"""Returns a sizer for a map field."""
# Can't look at field_descriptor.message_type._concrete_class because it may
@@ -355,9 +355,12 @@ def MapSizer(field_descriptor):
# It's wasteful to create the messages and throw them away one second
# later since we'll do the same for the actual encode. But there's not an
# obvious way to avoid this within the current design without tons of code
- # duplication.
+ # duplication. For message map, value.ByteSize() should be called to
+ # update the status.
entry_msg = message_type._concrete_class(key=key, value=value)
total += message_sizer(entry_msg)
+ if is_message_map:
+ value.ByteSize()
return total
return FieldSize
@@ -369,7 +372,7 @@ def MapSizer(field_descriptor):
def _VarintEncoder():
"""Return an encoder for a basic varint value (does not include tag)."""
- def EncodeVarint(write, value):
+ def EncodeVarint(write, value, unused_deterministic=None):
bits = value & 0x7f
value >>= 7
while value:
@@ -385,7 +388,7 @@ def _SignedVarintEncoder():
"""Return an encoder for a basic signed varint value (does not include
tag)."""
- def EncodeSignedVarint(write, value):
+ def EncodeSignedVarint(write, value, unused_deterministic=None):
if value < 0:
value += (1 << 64)
bits = value & 0x7f
@@ -408,14 +411,15 @@ def _VarintBytes(value):
called at startup time so it doesn't need to be fast."""
pieces = []
- _EncodeVarint(pieces.append, value)
+ _EncodeVarint(pieces.append, value, True)
return b"".join(pieces)
def TagBytes(field_number, wire_type):
"""Encode the given tag and return the bytes. Only called at startup."""
- return _VarintBytes(wire_format.PackTag(field_number, wire_type))
+ return six.binary_type(
+ _VarintBytes(wire_format.PackTag(field_number, wire_type)))
# --------------------------------------------------------------------
# As with sizers (see above), we have a number of common encoder
@@ -437,27 +441,27 @@ def _SimpleEncoder(wire_type, encode_value, compute_value_size):
if is_packed:
tag_bytes = TagBytes(field_number, wire_format.WIRETYPE_LENGTH_DELIMITED)
local_EncodeVarint = _EncodeVarint
- def EncodePackedField(write, value):
+ def EncodePackedField(write, value, deterministic):
write(tag_bytes)
size = 0
for element in value:
size += compute_value_size(element)
- local_EncodeVarint(write, size)
+ local_EncodeVarint(write, size, deterministic)
for element in value:
- encode_value(write, element)
+ encode_value(write, element, deterministic)
return EncodePackedField
elif is_repeated:
tag_bytes = TagBytes(field_number, wire_type)
- def EncodeRepeatedField(write, value):
+ def EncodeRepeatedField(write, value, deterministic):
for element in value:
write(tag_bytes)
- encode_value(write, element)
+ encode_value(write, element, deterministic)
return EncodeRepeatedField
else:
tag_bytes = TagBytes(field_number, wire_type)
- def EncodeField(write, value):
+ def EncodeField(write, value, deterministic):
write(tag_bytes)
- return encode_value(write, value)
+ return encode_value(write, value, deterministic)
return EncodeField
return SpecificEncoder
@@ -471,27 +475,27 @@ def _ModifiedEncoder(wire_type, encode_value, compute_value_size, modify_value):
if is_packed:
tag_bytes = TagBytes(field_number, wire_format.WIRETYPE_LENGTH_DELIMITED)
local_EncodeVarint = _EncodeVarint
- def EncodePackedField(write, value):
+ def EncodePackedField(write, value, deterministic):
write(tag_bytes)
size = 0
for element in value:
size += compute_value_size(modify_value(element))
- local_EncodeVarint(write, size)
+ local_EncodeVarint(write, size, deterministic)
for element in value:
- encode_value(write, modify_value(element))
+ encode_value(write, modify_value(element), deterministic)
return EncodePackedField
elif is_repeated:
tag_bytes = TagBytes(field_number, wire_type)
- def EncodeRepeatedField(write, value):
+ def EncodeRepeatedField(write, value, deterministic):
for element in value:
write(tag_bytes)
- encode_value(write, modify_value(element))
+ encode_value(write, modify_value(element), deterministic)
return EncodeRepeatedField
else:
tag_bytes = TagBytes(field_number, wire_type)
- def EncodeField(write, value):
+ def EncodeField(write, value, deterministic):
write(tag_bytes)
- return encode_value(write, modify_value(value))
+ return encode_value(write, modify_value(value), deterministic)
return EncodeField
return SpecificEncoder
@@ -512,22 +516,22 @@ def _StructPackEncoder(wire_type, format):
if is_packed:
tag_bytes = TagBytes(field_number, wire_format.WIRETYPE_LENGTH_DELIMITED)
local_EncodeVarint = _EncodeVarint
- def EncodePackedField(write, value):
+ def EncodePackedField(write, value, deterministic):
write(tag_bytes)
- local_EncodeVarint(write, len(value) * value_size)
+ local_EncodeVarint(write, len(value) * value_size, deterministic)
for element in value:
write(local_struct_pack(format, element))
return EncodePackedField
elif is_repeated:
tag_bytes = TagBytes(field_number, wire_type)
- def EncodeRepeatedField(write, value):
+ def EncodeRepeatedField(write, value, unused_deterministic=None):
for element in value:
write(tag_bytes)
write(local_struct_pack(format, element))
return EncodeRepeatedField
else:
tag_bytes = TagBytes(field_number, wire_type)
- def EncodeField(write, value):
+ def EncodeField(write, value, unused_deterministic=None):
write(tag_bytes)
return write(local_struct_pack(format, value))
return EncodeField
@@ -578,9 +582,9 @@ def _FloatingPointEncoder(wire_type, format):
if is_packed:
tag_bytes = TagBytes(field_number, wire_format.WIRETYPE_LENGTH_DELIMITED)
local_EncodeVarint = _EncodeVarint
- def EncodePackedField(write, value):
+ def EncodePackedField(write, value, deterministic):
write(tag_bytes)
- local_EncodeVarint(write, len(value) * value_size)
+ local_EncodeVarint(write, len(value) * value_size, deterministic)
for element in value:
# This try/except block is going to be faster than any code that
# we could write to check whether element is finite.
@@ -591,7 +595,7 @@ def _FloatingPointEncoder(wire_type, format):
return EncodePackedField
elif is_repeated:
tag_bytes = TagBytes(field_number, wire_type)
- def EncodeRepeatedField(write, value):
+ def EncodeRepeatedField(write, value, unused_deterministic=None):
for element in value:
write(tag_bytes)
try:
@@ -601,7 +605,7 @@ def _FloatingPointEncoder(wire_type, format):
return EncodeRepeatedField
else:
tag_bytes = TagBytes(field_number, wire_type)
- def EncodeField(write, value):
+ def EncodeField(write, value, unused_deterministic=None):
write(tag_bytes)
try:
write(local_struct_pack(format, value))
@@ -647,9 +651,9 @@ def BoolEncoder(field_number, is_repeated, is_packed):
if is_packed:
tag_bytes = TagBytes(field_number, wire_format.WIRETYPE_LENGTH_DELIMITED)
local_EncodeVarint = _EncodeVarint
- def EncodePackedField(write, value):
+ def EncodePackedField(write, value, deterministic):
write(tag_bytes)
- local_EncodeVarint(write, len(value))
+ local_EncodeVarint(write, len(value), deterministic)
for element in value:
if element:
write(true_byte)
@@ -658,7 +662,7 @@ def BoolEncoder(field_number, is_repeated, is_packed):
return EncodePackedField
elif is_repeated:
tag_bytes = TagBytes(field_number, wire_format.WIRETYPE_VARINT)
- def EncodeRepeatedField(write, value):
+ def EncodeRepeatedField(write, value, unused_deterministic=None):
for element in value:
write(tag_bytes)
if element:
@@ -668,7 +672,7 @@ def BoolEncoder(field_number, is_repeated, is_packed):
return EncodeRepeatedField
else:
tag_bytes = TagBytes(field_number, wire_format.WIRETYPE_VARINT)
- def EncodeField(write, value):
+ def EncodeField(write, value, unused_deterministic=None):
write(tag_bytes)
if value:
return write(true_byte)
@@ -684,18 +688,18 @@ def StringEncoder(field_number, is_repeated, is_packed):
local_len = len
assert not is_packed
if is_repeated:
- def EncodeRepeatedField(write, value):
+ def EncodeRepeatedField(write, value, deterministic):
for element in value:
encoded = element.encode('utf-8')
write(tag)
- local_EncodeVarint(write, local_len(encoded))
+ local_EncodeVarint(write, local_len(encoded), deterministic)
write(encoded)
return EncodeRepeatedField
else:
- def EncodeField(write, value):
+ def EncodeField(write, value, deterministic):
encoded = value.encode('utf-8')
write(tag)
- local_EncodeVarint(write, local_len(encoded))
+ local_EncodeVarint(write, local_len(encoded), deterministic)
return write(encoded)
return EncodeField
@@ -708,16 +712,16 @@ def BytesEncoder(field_number, is_repeated, is_packed):
local_len = len
assert not is_packed
if is_repeated:
- def EncodeRepeatedField(write, value):
+ def EncodeRepeatedField(write, value, deterministic):
for element in value:
write(tag)
- local_EncodeVarint(write, local_len(element))
+ local_EncodeVarint(write, local_len(element), deterministic)
write(element)
return EncodeRepeatedField
else:
- def EncodeField(write, value):
+ def EncodeField(write, value, deterministic):
write(tag)
- local_EncodeVarint(write, local_len(value))
+ local_EncodeVarint(write, local_len(value), deterministic)
return write(value)
return EncodeField
@@ -729,16 +733,16 @@ def GroupEncoder(field_number, is_repeated, is_packed):
end_tag = TagBytes(field_number, wire_format.WIRETYPE_END_GROUP)
assert not is_packed
if is_repeated:
- def EncodeRepeatedField(write, value):
+ def EncodeRepeatedField(write, value, deterministic):
for element in value:
write(start_tag)
- element._InternalSerialize(write)
+ element._InternalSerialize(write, deterministic)
write(end_tag)
return EncodeRepeatedField
else:
- def EncodeField(write, value):
+ def EncodeField(write, value, deterministic):
write(start_tag)
- value._InternalSerialize(write)
+ value._InternalSerialize(write, deterministic)
return write(end_tag)
return EncodeField
@@ -750,17 +754,17 @@ def MessageEncoder(field_number, is_repeated, is_packed):
local_EncodeVarint = _EncodeVarint
assert not is_packed
if is_repeated:
- def EncodeRepeatedField(write, value):
+ def EncodeRepeatedField(write, value, deterministic):
for element in value:
write(tag)
- local_EncodeVarint(write, element.ByteSize())
- element._InternalSerialize(write)
+ local_EncodeVarint(write, element.ByteSize(), deterministic)
+ element._InternalSerialize(write, deterministic)
return EncodeRepeatedField
else:
- def EncodeField(write, value):
+ def EncodeField(write, value, deterministic):
write(tag)
- local_EncodeVarint(write, value.ByteSize())
- return value._InternalSerialize(write)
+ local_EncodeVarint(write, value.ByteSize(), deterministic)
+ return value._InternalSerialize(write, deterministic)
return EncodeField
@@ -787,10 +791,10 @@ def MessageSetItemEncoder(field_number):
end_bytes = TagBytes(1, wire_format.WIRETYPE_END_GROUP)
local_EncodeVarint = _EncodeVarint
- def EncodeField(write, value):
+ def EncodeField(write, value, deterministic):
write(start_bytes)
- local_EncodeVarint(write, value.ByteSize())
- value._InternalSerialize(write)
+ local_EncodeVarint(write, value.ByteSize(), deterministic)
+ value._InternalSerialize(write, deterministic)
return write(end_bytes)
return EncodeField
@@ -815,9 +819,10 @@ def MapEncoder(field_descriptor):
message_type = field_descriptor.message_type
encode_message = MessageEncoder(field_descriptor.number, False, False)
- def EncodeField(write, value):
- for key in value:
+ def EncodeField(write, value, deterministic):
+ value_keys = sorted(value.keys()) if deterministic else value
+ for key in value_keys:
entry_msg = message_type._concrete_class(key=key, value=value[key])
- encode_message(write, entry_msg)
+ encode_message(write, entry_msg, deterministic)
return EncodeField
diff --git a/python/google/protobuf/internal/factory_test2.proto b/python/google/protobuf/internal/factory_test2.proto
index bb1b54ad..5fcbc5ac 100644
--- a/python/google/protobuf/internal/factory_test2.proto
+++ b/python/google/protobuf/internal/factory_test2.proto
@@ -97,3 +97,8 @@ message MessageWithNestedEnumOnly {
extend Factory1Message {
optional string another_field = 1002;
}
+
+message MessageWithOption {
+ option no_standard_descriptor_accessor = true;
+ optional int32 field1 = 1;
+}
diff --git a/python/google/protobuf/internal/file_options_test.proto b/python/google/protobuf/internal/file_options_test.proto
new file mode 100644
index 00000000..4eceeb07
--- /dev/null
+++ b/python/google/protobuf/internal/file_options_test.proto
@@ -0,0 +1,43 @@
+// Protocol Buffers - Google's data interchange format
+// Copyright 2008 Google Inc. All rights reserved.
+// https://developers.google.com/protocol-buffers/
+//
+// Redistribution and use in source and binary forms, with or without
+// modification, are permitted provided that the following conditions are
+// met:
+//
+// * Redistributions of source code must retain the above copyright
+// notice, this list of conditions and the following disclaimer.
+// * Redistributions in binary form must reproduce the above
+// copyright notice, this list of conditions and the following disclaimer
+// in the documentation and/or other materials provided with the
+// distribution.
+// * Neither the name of Google Inc. nor the names of its
+// contributors may be used to endorse or promote products derived from
+// this software without specific prior written permission.
+//
+// THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS
+// "AS IS" AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT
+// LIMITED TO, THE IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR
+// A PARTICULAR PURPOSE ARE DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT
+// OWNER OR CONTRIBUTORS BE LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL,
+// SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT
+// LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE,
+// DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY
+// THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT
+// (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
+// OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
+
+syntax = "proto2";
+
+import "google/protobuf/descriptor.proto";
+
+package google.protobuf.python.internal;
+
+message FooOptions {
+ optional string foo_name = 1;
+}
+
+extend .google.protobuf.FileOptions {
+ optional FooOptions foo_options = 120436268;
+}
diff --git a/python/google/protobuf/internal/generator_test.py b/python/google/protobuf/internal/generator_test.py
index 9956da59..7f13f9da 100755
--- a/python/google/protobuf/internal/generator_test.py
+++ b/python/google/protobuf/internal/generator_test.py
@@ -42,9 +42,10 @@ further ensures that we can use Python protocol message objects as we expect.
__author__ = 'robinson@google.com (Will Robinson)'
try:
- import unittest2 as unittest
+ import unittest2 as unittest #PY26
except ImportError:
import unittest
+
from google.protobuf.internal import test_bad_identifiers_pb2
from google.protobuf import unittest_custom_options_pb2
from google.protobuf import unittest_import_pb2
@@ -226,7 +227,8 @@ class GeneratorTest(unittest.TestCase):
[unittest_import_pb2.DESCRIPTOR])
self.assertEqual(unittest_import_pb2.DESCRIPTOR.dependencies,
[unittest_import_public_pb2.DESCRIPTOR])
-
+ self.assertEqual(unittest_import_pb2.DESCRIPTOR.public_dependencies,
+ [unittest_import_public_pb2.DESCRIPTOR])
def testNoGenericServices(self):
self.assertTrue(hasattr(unittest_no_generic_services_pb2, "TestMessage"))
self.assertTrue(hasattr(unittest_no_generic_services_pb2, "FOO"))
diff --git a/python/google/protobuf/internal/json_format_test.py b/python/google/protobuf/internal/json_format_test.py
index 49e96a46..d891dce1 100644
--- a/python/google/protobuf/internal/json_format_test.py
+++ b/python/google/protobuf/internal/json_format_test.py
@@ -39,15 +39,18 @@ import math
import sys
try:
- import unittest2 as unittest
+ import unittest2 as unittest #PY26
except ImportError:
import unittest
+
from google.protobuf import any_pb2
from google.protobuf import duration_pb2
from google.protobuf import field_mask_pb2
from google.protobuf import struct_pb2
from google.protobuf import timestamp_pb2
from google.protobuf import wrappers_pb2
+from google.protobuf import unittest_mset_pb2
+from google.protobuf import unittest_pb2
from google.protobuf.internal import well_known_types
from google.protobuf import json_format
from google.protobuf.util import json_format_proto3_pb2
@@ -157,6 +160,98 @@ class JsonFormatTest(JsonFormatBase):
json_format.Parse(text, parsed_message)
self.assertEqual(message, parsed_message)
+ def testUnknownEnumToJsonAndBack(self):
+ text = '{\n "enumValue": 999\n}'
+ message = json_format_proto3_pb2.TestMessage()
+ message.enum_value = 999
+ self.assertEqual(json_format.MessageToJson(message),
+ text)
+ parsed_message = json_format_proto3_pb2.TestMessage()
+ json_format.Parse(text, parsed_message)
+ self.assertEqual(message, parsed_message)
+
+ def testExtensionToJsonAndBack(self):
+ message = unittest_mset_pb2.TestMessageSetContainer()
+ ext1 = unittest_mset_pb2.TestMessageSetExtension1.message_set_extension
+ ext2 = unittest_mset_pb2.TestMessageSetExtension2.message_set_extension
+ message.message_set.Extensions[ext1].i = 23
+ message.message_set.Extensions[ext2].str = 'foo'
+ message_text = json_format.MessageToJson(
+ message
+ )
+ parsed_message = unittest_mset_pb2.TestMessageSetContainer()
+ json_format.Parse(message_text, parsed_message)
+ self.assertEqual(message, parsed_message)
+
+ def testExtensionErrors(self):
+ self.CheckError('{"[extensionField]": {}}',
+ 'Message type proto3.TestMessage does not have extensions')
+
+ def testExtensionToDictAndBack(self):
+ message = unittest_mset_pb2.TestMessageSetContainer()
+ ext1 = unittest_mset_pb2.TestMessageSetExtension1.message_set_extension
+ ext2 = unittest_mset_pb2.TestMessageSetExtension2.message_set_extension
+ message.message_set.Extensions[ext1].i = 23
+ message.message_set.Extensions[ext2].str = 'foo'
+ message_dict = json_format.MessageToDict(
+ message
+ )
+ parsed_message = unittest_mset_pb2.TestMessageSetContainer()
+ json_format.ParseDict(message_dict, parsed_message)
+ self.assertEqual(message, parsed_message)
+
+ def testExtensionSerializationDictMatchesProto3Spec(self):
+ """See go/proto3-json-spec for spec.
+ """
+ message = unittest_mset_pb2.TestMessageSetContainer()
+ ext1 = unittest_mset_pb2.TestMessageSetExtension1.message_set_extension
+ ext2 = unittest_mset_pb2.TestMessageSetExtension2.message_set_extension
+ message.message_set.Extensions[ext1].i = 23
+ message.message_set.Extensions[ext2].str = 'foo'
+ message_dict = json_format.MessageToDict(
+ message
+ )
+ golden_dict = {
+ 'messageSet': {
+ '[protobuf_unittest.'
+ 'TestMessageSetExtension1.messageSetExtension]': {
+ 'i': 23,
+ },
+ '[protobuf_unittest.'
+ 'TestMessageSetExtension2.messageSetExtension]': {
+ 'str': u'foo',
+ },
+ },
+ }
+ self.assertEqual(golden_dict, message_dict)
+
+
+ def testExtensionSerializationJsonMatchesProto3Spec(self):
+ """See go/proto3-json-spec for spec.
+ """
+ message = unittest_mset_pb2.TestMessageSetContainer()
+ ext1 = unittest_mset_pb2.TestMessageSetExtension1.message_set_extension
+ ext2 = unittest_mset_pb2.TestMessageSetExtension2.message_set_extension
+ message.message_set.Extensions[ext1].i = 23
+ message.message_set.Extensions[ext2].str = 'foo'
+ message_text = json_format.MessageToJson(
+ message
+ )
+ ext1_text = ('protobuf_unittest.TestMessageSetExtension1.'
+ 'messageSetExtension')
+ ext2_text = ('protobuf_unittest.TestMessageSetExtension2.'
+ 'messageSetExtension')
+ golden_text = ('{"messageSet": {'
+ ' "[%s]": {'
+ ' "i": 23'
+ ' },'
+ ' "[%s]": {'
+ ' "str": "foo"'
+ ' }'
+ '}}') % (ext1_text, ext2_text)
+ self.assertEqual(json.loads(golden_text), json.loads(message_text))
+
+
def testJsonEscapeString(self):
message = json_format_proto3_pb2.TestMessage()
if sys.version_info[0] < 3:
@@ -204,8 +299,28 @@ class JsonFormatTest(JsonFormatBase):
parsed_message = json_format_proto3_pb2.TestMessage()
self.CheckParseBack(message, parsed_message)
+ def testIntegersRepresentedAsFloat(self):
+ message = json_format_proto3_pb2.TestMessage()
+ json_format.Parse('{"int32Value": -2.147483648e9}', message)
+ self.assertEqual(message.int32_value, -2147483648)
+ json_format.Parse('{"int32Value": 1e5}', message)
+ self.assertEqual(message.int32_value, 100000)
+ json_format.Parse('{"int32Value": 1.0}', message)
+ self.assertEqual(message.int32_value, 1)
+
def testMapFields(self):
- message = json_format_proto3_pb2.TestMap()
+ message = json_format_proto3_pb2.TestNestedMap()
+ self.assertEqual(
+ json.loads(json_format.MessageToJson(message, True)),
+ json.loads('{'
+ '"boolMap": {},'
+ '"int32Map": {},'
+ '"int64Map": {},'
+ '"uint32Map": {},'
+ '"uint64Map": {},'
+ '"stringMap": {},'
+ '"mapMap": {}'
+ '}'))
message.bool_map[True] = 1
message.bool_map[False] = 2
message.int32_map[1] = 2
@@ -218,17 +333,19 @@ class JsonFormatTest(JsonFormatBase):
message.uint64_map[2] = 3
message.string_map['1'] = 2
message.string_map['null'] = 3
+ message.map_map['1'].bool_map[True] = 3
self.assertEqual(
- json.loads(json_format.MessageToJson(message, True)),
+ json.loads(json_format.MessageToJson(message, False)),
json.loads('{'
'"boolMap": {"false": 2, "true": 1},'
'"int32Map": {"1": 2, "2": 3},'
'"int64Map": {"1": 2, "2": 3},'
'"uint32Map": {"1": 2, "2": 3},'
'"uint64Map": {"1": 2, "2": 3},'
- '"stringMap": {"1": 2, "null": 3}'
+ '"stringMap": {"1": 2, "null": 3},'
+ '"mapMap": {"1": {"boolMap": {"true": 3}}}'
'}'))
- parsed_message = json_format_proto3_pb2.TestMap()
+ parsed_message = json_format_proto3_pb2.TestNestedMap()
self.CheckParseBack(message, parsed_message)
def testOneofFields(self):
@@ -246,6 +363,23 @@ class JsonFormatTest(JsonFormatBase):
parsed_message = json_format_proto3_pb2.TestOneof()
self.CheckParseBack(message, parsed_message)
+ def testSurrogates(self):
+ # Test correct surrogate handling.
+ message = json_format_proto3_pb2.TestMessage()
+ json_format.Parse('{"stringValue": "\\uD83D\\uDE01"}', message)
+ self.assertEqual(message.string_value,
+ b'\xF0\x9F\x98\x81'.decode('utf-8', 'strict'))
+
+ # Error case: unpaired high surrogate.
+ self.CheckError(
+ '{"stringValue": "\\uD83D"}',
+ r'Invalid \\uXXXX escape|Unpaired.*surrogate')
+
+ # Unpaired low surrogate.
+ self.CheckError(
+ '{"stringValue": "\\uDE01"}',
+ r'Invalid \\uXXXX escape|Unpaired.*surrogate')
+
def testTimestampMessage(self):
message = json_format_proto3_pb2.TestTimestamp()
message.value.seconds = 0
@@ -410,6 +544,9 @@ class JsonFormatTest(JsonFormatBase):
' "value": "hello",'
' "repeatedValue": [11.1, false, null, null]'
'}'))
+ message.Clear()
+ json_format.Parse('{"value": null}', message)
+ self.assertEqual(message.value.WhichOneof('kind'), 'null_value')
def testListValueMessage(self):
message = json_format_proto3_pb2.TestListValue()
@@ -457,6 +594,22 @@ class JsonFormatTest(JsonFormatBase):
'}\n'))
parsed_message = json_format_proto3_pb2.TestAny()
self.CheckParseBack(message, parsed_message)
+ # Must print @type first
+ test_message = json_format_proto3_pb2.TestMessage(
+ bool_value=True,
+ int32_value=20,
+ int64_value=-20,
+ uint32_value=20,
+ uint64_value=20,
+ double_value=3.14,
+ string_value='foo')
+ message.Clear()
+ message.value.Pack(test_message)
+ self.assertEqual(
+ json_format.MessageToJson(message, False)[0:68],
+ '{\n'
+ ' "value": {\n'
+ ' "@type": "type.googleapis.com/proto3.TestMessage"')
def testWellKnownInAnyMessage(self):
message = any_pb2.Any()
@@ -566,6 +719,11 @@ class JsonFormatTest(JsonFormatBase):
'}',
parsed_message)
self.assertEqual(message, parsed_message)
+ # Null and {} should have different behavior for sub message.
+ self.assertFalse(parsed_message.HasField('message_value'))
+ json_format.Parse('{"messageValue": {}}', parsed_message)
+ self.assertTrue(parsed_message.HasField('message_value'))
+ # Null is not allowed to be used as an element in repeated field.
self.assertRaisesRegexp(
json_format.ParseError,
'Failed to parse repeatedInt32Value field: '
@@ -573,6 +731,9 @@ class JsonFormatTest(JsonFormatBase):
json_format.Parse,
'{"repeatedInt32Value":[1, null]}',
parsed_message)
+ self.CheckError('{"repeatedMessageValue":[null]}',
+ 'Failed to parse repeatedMessageValue field: null is not'
+ ' allowed to be used as an element in a repeated field.')
def testNanFloat(self):
message = json_format_proto3_pb2.TestMessage()
@@ -587,15 +748,26 @@ class JsonFormatTest(JsonFormatBase):
self.CheckError('',
r'Failed to load JSON: (Expecting value)|(No JSON).')
- def testParseBadEnumValue(self):
- self.CheckError(
- '{"enumValue": 1}',
- 'Enum value must be a string literal with double quotes. '
- 'Type "proto3.EnumType" has no value named 1.')
+ def testParseEnumValue(self):
+ message = json_format_proto3_pb2.TestMessage()
+ text = '{"enumValue": 0}'
+ json_format.Parse(text, message)
+ text = '{"enumValue": 1}'
+ json_format.Parse(text, message)
self.CheckError(
'{"enumValue": "baz"}',
- 'Enum value must be a string literal with double quotes. '
- 'Type "proto3.EnumType" has no value named baz.')
+ 'Failed to parse enumValue field: Invalid enum value baz '
+ 'for enum type proto3.EnumType.')
+ # Proto3 accepts numeric unknown enums.
+ text = '{"enumValue": 12345}'
+ json_format.Parse(text, message)
+ # Proto2 does not accept unknown enums.
+ message = unittest_pb2.TestAllTypes()
+ self.assertRaisesRegexp(
+ json_format.ParseError,
+ 'Failed to parse optionalNestedEnum field: Invalid enum value 12345 '
+ 'for enum type protobuf_unittest.TestAllTypes.NestedEnum.',
+ json_format.Parse, '{"optionalNestedEnum": 12345}', message)
def testParseBadIdentifer(self):
self.CheckError('{int32Value: 1}',
@@ -605,6 +777,19 @@ class JsonFormatTest(JsonFormatBase):
'Message type "proto3.TestMessage" has no field named '
'"unknownName".')
+ def testIgnoreUnknownField(self):
+ text = '{"unknownName": 1}'
+ parsed_message = json_format_proto3_pb2.TestMessage()
+ json_format.Parse(text, parsed_message, ignore_unknown_fields=True)
+ text = ('{\n'
+ ' "repeatedValue": [ {\n'
+ ' "@type": "type.googleapis.com/proto3.MessageType",\n'
+ ' "unknownName": 1\n'
+ ' }]\n'
+ '}\n')
+ parsed_message = json_format_proto3_pb2.TestAny()
+ json_format.Parse(text, parsed_message, ignore_unknown_fields=True)
+
def testDuplicateField(self):
# Duplicate key check is not supported for python2.6
if sys.version_info < (2, 7):
@@ -625,12 +810,12 @@ class JsonFormatTest(JsonFormatBase):
text = '{"int32Value": 0x12345}'
self.assertRaises(json_format.ParseError,
json_format.Parse, text, message)
+ self.CheckError('{"int32Value": 1.5}',
+ 'Failed to parse int32Value field: '
+ 'Couldn\'t parse integer: 1.5.')
self.CheckError('{"int32Value": 012345}',
(r'Failed to load JSON: Expecting \'?,\'? delimiter: '
r'line 1.'))
- self.CheckError('{"int32Value": 1.0}',
- 'Failed to parse int32Value field: '
- 'Couldn\'t parse integer: 1.0.')
self.CheckError('{"int32Value": " 1 "}',
'Failed to parse int32Value field: '
'Couldn\'t parse integer: " 1 ".')
@@ -640,9 +825,6 @@ class JsonFormatTest(JsonFormatBase):
self.CheckError('{"int32Value": 12345678901234567890}',
'Failed to parse int32Value field: Value out of range: '
'12345678901234567890.')
- self.CheckError('{"int32Value": 1e5}',
- 'Failed to parse int32Value field: '
- 'Couldn\'t parse integer: 100000.0.')
self.CheckError('{"uint32Value": -1}',
'Failed to parse uint32Value field: '
'Value out of range: -1.')
@@ -658,6 +840,11 @@ class JsonFormatTest(JsonFormatBase):
self.CheckError('{"bytesValue": "AQI*"}',
'Failed to parse bytesValue field: Incorrect padding.')
+ def testInvalidRepeated(self):
+ self.CheckError('{"repeatedInt32Value": 12345}',
+ (r'Failed to parse repeatedInt32Value field: repeated field'
+ r' repeatedInt32Value must be in \[\] which is 12345.'))
+
def testInvalidMap(self):
message = json_format_proto3_pb2.TestMap()
text = '{"int32Map": {"null": 2, "2": 3}}'
@@ -683,6 +870,12 @@ class JsonFormatTest(JsonFormatBase):
json_format.ParseError,
'Failed to load JSON: duplicate key a',
json_format.Parse, text, message)
+ text = r'{"stringMap": 0}'
+ self.assertRaisesRegexp(
+ json_format.ParseError,
+ 'Failed to parse stringMap field: Map field string_map must be '
+ 'in a dict which is 0.',
+ json_format.Parse, text, message)
def testInvalidTimestamp(self):
message = json_format_proto3_pb2.TestTimestamp()
@@ -706,7 +899,7 @@ class JsonFormatTest(JsonFormatBase):
text = '{"value": "0000-01-01T00:00:00Z"}'
self.assertRaisesRegexp(
json_format.ParseError,
- 'Failed to parse value field: year is out of range.',
+ 'Failed to parse value field: year (0 )?is out of range.',
json_format.Parse, text, message)
# Time bigger than maxinum time.
message.value.seconds = 253402300800
@@ -758,11 +951,88 @@ class JsonFormatTest(JsonFormatBase):
'Can not find message descriptor by type_url: '
'type.googleapis.com/MessageNotExist.',
json_format.Parse, text, message)
- # Only last part is to be used.
+ # Only last part is to be used: b/25630112
text = (r'{"@type": "incorrect.googleapis.com/google.protobuf.Int32Value",'
r'"value": 1234}')
json_format.Parse(text, message)
+ def testPreservingProtoFieldNames(self):
+ message = json_format_proto3_pb2.TestMessage()
+ message.int32_value = 12345
+ self.assertEqual('{\n "int32Value": 12345\n}',
+ json_format.MessageToJson(message))
+ self.assertEqual('{\n "int32_value": 12345\n}',
+ json_format.MessageToJson(message, False, True))
+ # When including_default_value_fields is True.
+ message = json_format_proto3_pb2.TestTimestamp()
+ self.assertEqual('{\n "repeatedValue": []\n}',
+ json_format.MessageToJson(message, True, False))
+ self.assertEqual('{\n "repeated_value": []\n}',
+ json_format.MessageToJson(message, True, True))
+
+ # Parsers accept both original proto field names and lowerCamelCase names.
+ message = json_format_proto3_pb2.TestMessage()
+ json_format.Parse('{"int32Value": 54321}', message)
+ self.assertEqual(54321, message.int32_value)
+ json_format.Parse('{"int32_value": 12345}', message)
+ self.assertEqual(12345, message.int32_value)
+
+ def testIndent(self):
+ message = json_format_proto3_pb2.TestMessage()
+ message.int32_value = 12345
+ self.assertEqual('{\n"int32Value": 12345\n}',
+ json_format.MessageToJson(message, indent=0))
+
+ def testFormatEnumsAsInts(self):
+ message = json_format_proto3_pb2.TestMessage()
+ message.enum_value = json_format_proto3_pb2.BAR
+ message.repeated_enum_value.append(json_format_proto3_pb2.FOO)
+ message.repeated_enum_value.append(json_format_proto3_pb2.BAR)
+ self.assertEqual(json.loads('{\n'
+ ' "enumValue": 1,\n'
+ ' "repeatedEnumValue": [0, 1]\n'
+ '}\n'),
+ json.loads(json_format.MessageToJson(
+ message, use_integers_for_enums=True)))
+
+ def testParseDict(self):
+ expected = 12345
+ js_dict = {'int32Value': expected}
+ message = json_format_proto3_pb2.TestMessage()
+ json_format.ParseDict(js_dict, message)
+ self.assertEqual(expected, message.int32_value)
+
+ def testMessageToDict(self):
+ message = json_format_proto3_pb2.TestMessage()
+ message.int32_value = 12345
+ expected = {'int32Value': 12345}
+ self.assertEqual(expected,
+ json_format.MessageToDict(message))
+
+ def testJsonName(self):
+ message = json_format_proto3_pb2.TestCustomJsonName()
+ message.value = 12345
+ self.assertEqual('{\n "@value": 12345\n}',
+ json_format.MessageToJson(message))
+ parsed_message = json_format_proto3_pb2.TestCustomJsonName()
+ self.CheckParseBack(message, parsed_message)
+
+ def testSortKeys(self):
+ # Testing sort_keys is not perfectly working, as by random luck we could
+ # get the output sorted. We just use a selection of names.
+ message = json_format_proto3_pb2.TestMessage(bool_value=True,
+ int32_value=1,
+ int64_value=3,
+ uint32_value=4,
+ string_value='bla')
+ self.assertEqual(
+ json_format.MessageToJson(message, sort_keys=True),
+ # We use json.dumps() instead of a hardcoded string due to differences
+ # between Python 2 and Python 3.
+ json.dumps({'boolValue': True, 'int32Value': 1, 'int64Value': '3',
+ 'uint32Value': 4, 'stringValue': 'bla'},
+ indent=2, sort_keys=True))
+
if __name__ == '__main__':
unittest.main()
diff --git a/python/google/protobuf/internal/message_factory_test.py b/python/google/protobuf/internal/message_factory_test.py
index 2fbe5ea7..6df52ed2 100644
--- a/python/google/protobuf/internal/message_factory_test.py
+++ b/python/google/protobuf/internal/message_factory_test.py
@@ -35,10 +35,12 @@
__author__ = 'matthewtoia@google.com (Matt Toia)'
try:
- import unittest2 as unittest
+ import unittest2 as unittest #PY26
except ImportError:
import unittest
+
from google.protobuf import descriptor_pb2
+from google.protobuf.internal import api_implementation
from google.protobuf.internal import factory_test1_pb2
from google.protobuf.internal import factory_test2_pb2
from google.protobuf import descriptor_database
@@ -105,30 +107,115 @@ class MessageFactoryTest(unittest.TestCase):
def testGetMessages(self):
# performed twice because multiple calls with the same input must be allowed
for _ in range(2):
- messages = message_factory.GetMessages([self.factory_test1_fd,
- self.factory_test2_fd])
+ # GetMessage should work regardless of the order the FileDescriptorProto
+ # are provided. In particular, the function should succeed when the files
+ # are not in the topological order of dependencies.
+
+ # Assuming factory_test2_fd depends on factory_test1_fd.
+ self.assertIn(self.factory_test1_fd.name,
+ self.factory_test2_fd.dependency)
+ # Get messages should work when a file comes before its dependencies:
+ # factory_test2_fd comes before factory_test1_fd.
+ messages = message_factory.GetMessages([self.factory_test2_fd,
+ self.factory_test1_fd])
self.assertTrue(
set(['google.protobuf.python.internal.Factory2Message',
'google.protobuf.python.internal.Factory1Message'],
).issubset(set(messages.keys())))
self._ExerciseDynamicClass(
messages['google.protobuf.python.internal.Factory2Message'])
- self.assertTrue(
- set(['google.protobuf.python.internal.Factory2Message.one_more_field',
- 'google.protobuf.python.internal.another_field'],
- ).issubset(
- set(messages['google.protobuf.python.internal.Factory1Message']
- ._extensions_by_name.keys())))
factory_msg1 = messages['google.protobuf.python.internal.Factory1Message']
+ self.assertTrue(set(
+ ['google.protobuf.python.internal.Factory2Message.one_more_field',
+ 'google.protobuf.python.internal.another_field'],).issubset(set(
+ ext.full_name
+ for ext in factory_msg1.DESCRIPTOR.file.pool.FindAllExtensions(
+ factory_msg1.DESCRIPTOR))))
msg1 = messages['google.protobuf.python.internal.Factory1Message']()
- ext1 = factory_msg1._extensions_by_name[
- 'google.protobuf.python.internal.Factory2Message.one_more_field']
- ext2 = factory_msg1._extensions_by_name[
- 'google.protobuf.python.internal.another_field']
+ ext1 = msg1.Extensions._FindExtensionByName(
+ 'google.protobuf.python.internal.Factory2Message.one_more_field')
+ ext2 = msg1.Extensions._FindExtensionByName(
+ 'google.protobuf.python.internal.another_field')
msg1.Extensions[ext1] = 'test1'
msg1.Extensions[ext2] = 'test2'
self.assertEqual('test1', msg1.Extensions[ext1])
self.assertEqual('test2', msg1.Extensions[ext2])
+ self.assertEqual(None,
+ msg1.Extensions._FindExtensionByNumber(12321))
+ if api_implementation.Type() == 'cpp':
+ # TODO(jieluo): Fix len to return the correct value.
+ # self.assertEqual(2, len(msg1.Extensions))
+ self.assertEqual(len(msg1.Extensions), len(msg1.Extensions))
+ self.assertRaises(TypeError,
+ msg1.Extensions._FindExtensionByName, 0)
+ self.assertRaises(TypeError,
+ msg1.Extensions._FindExtensionByNumber, '')
+ else:
+ self.assertEqual(None,
+ msg1.Extensions._FindExtensionByName(0))
+ self.assertEqual(None,
+ msg1.Extensions._FindExtensionByNumber(''))
+
+ def testDuplicateExtensionNumber(self):
+ pool = descriptor_pool.DescriptorPool()
+ factory = message_factory.MessageFactory(pool=pool)
+
+ # Add Container message.
+ f = descriptor_pb2.FileDescriptorProto()
+ f.name = 'google/protobuf/internal/container.proto'
+ f.package = 'google.protobuf.python.internal'
+ msg = f.message_type.add()
+ msg.name = 'Container'
+ rng = msg.extension_range.add()
+ rng.start = 1
+ rng.end = 10
+ pool.Add(f)
+ msgs = factory.GetMessages([f.name])
+ self.assertIn('google.protobuf.python.internal.Container', msgs)
+
+ # Extend container.
+ f = descriptor_pb2.FileDescriptorProto()
+ f.name = 'google/protobuf/internal/extension.proto'
+ f.package = 'google.protobuf.python.internal'
+ f.dependency.append('google/protobuf/internal/container.proto')
+ msg = f.message_type.add()
+ msg.name = 'Extension'
+ ext = msg.extension.add()
+ ext.name = 'extension_field'
+ ext.number = 2
+ ext.label = descriptor_pb2.FieldDescriptorProto.LABEL_OPTIONAL
+ ext.type_name = 'Extension'
+ ext.extendee = 'Container'
+ pool.Add(f)
+ msgs = factory.GetMessages([f.name])
+ self.assertIn('google.protobuf.python.internal.Extension', msgs)
+
+ # Add Duplicate extending the same field number.
+ f = descriptor_pb2.FileDescriptorProto()
+ f.name = 'google/protobuf/internal/duplicate.proto'
+ f.package = 'google.protobuf.python.internal'
+ f.dependency.append('google/protobuf/internal/container.proto')
+ msg = f.message_type.add()
+ msg.name = 'Duplicate'
+ ext = msg.extension.add()
+ ext.name = 'extension_field'
+ ext.number = 2
+ ext.label = descriptor_pb2.FieldDescriptorProto.LABEL_OPTIONAL
+ ext.type_name = 'Duplicate'
+ ext.extendee = 'Container'
+ pool.Add(f)
+
+ with self.assertRaises(Exception) as cm:
+ factory.GetMessages([f.name])
+
+ self.assertIn(str(cm.exception),
+ ['Extensions '
+ '"google.protobuf.python.internal.Duplicate.extension_field" and'
+ ' "google.protobuf.python.internal.Extension.extension_field"'
+ ' both try to extend message type'
+ ' "google.protobuf.python.internal.Container"'
+ ' with field number 2.',
+ 'Double registration of Extensions'])
if __name__ == '__main__':
diff --git a/python/google/protobuf/internal/message_test.py b/python/google/protobuf/internal/message_test.py
index d03f2d25..61a56a67 100755
--- a/python/google/protobuf/internal/message_test.py
+++ b/python/google/protobuf/internal/message_test.py
@@ -51,24 +51,37 @@ import operator
import pickle
import six
import sys
+import warnings
try:
- import unittest2 as unittest
+ import unittest2 as unittest # PY26
except ImportError:
import unittest
-from google.protobuf.internal import _parameterized
+try:
+ cmp # Python 2
+except NameError:
+ cmp = lambda x, y: (x > y) - (x < y) # Python 3
+
+from google.protobuf import map_proto2_unittest_pb2
from google.protobuf import map_unittest_pb2
from google.protobuf import unittest_pb2
from google.protobuf import unittest_proto3_arena_pb2
-from google.protobuf.internal import any_test_pb2
+from google.protobuf import descriptor_pb2
+from google.protobuf import descriptor_pool
+from google.protobuf import message_factory
+from google.protobuf import text_format
from google.protobuf.internal import api_implementation
+from google.protobuf.internal import encoder
from google.protobuf.internal import packed_field_test_pb2
from google.protobuf.internal import test_util
+from google.protobuf.internal import testing_refleaks
from google.protobuf import message
+from google.protobuf.internal import _parameterized
if six.PY3:
long = int
+
# Python pre-2.6 does not have isinf() or isnan() functions, so we have
# to provide our own.
def isnan(val):
@@ -83,10 +96,13 @@ def IsNegInf(val):
return isinf(val) and (val < 0)
-@_parameterized.Parameters(
- (unittest_pb2),
- (unittest_proto3_arena_pb2))
-class MessageTest(unittest.TestCase):
+BaseTestCase = testing_refleaks.BaseTestCase
+
+
+@_parameterized.named_parameters(
+ ('_proto2', unittest_pb2),
+ ('_proto3', unittest_proto3_arena_pb2))
+class MessageTest(BaseTestCase):
def testBadUtf8String(self, message_module):
if api_implementation.Type() != 'python':
@@ -127,6 +143,63 @@ class MessageTest(unittest.TestCase):
golden_copy = copy.deepcopy(golden_message)
self.assertEqual(golden_data, golden_copy.SerializeToString())
+ def testParseErrors(self, message_module):
+ msg = message_module.TestAllTypes()
+ self.assertRaises(TypeError, msg.FromString, 0)
+ self.assertRaises(Exception, msg.FromString, '0')
+ # TODO(jieluo): Fix cpp extension to raise error instead of warning.
+ # b/27494216
+ end_tag = encoder.TagBytes(1, 4)
+ if api_implementation.Type() == 'python':
+ with self.assertRaises(message.DecodeError) as context:
+ msg.FromString(end_tag)
+ self.assertEqual('Unexpected end-group tag.', str(context.exception))
+ else:
+ with warnings.catch_warnings(record=True) as w:
+ # Cause all warnings to always be triggered.
+ warnings.simplefilter('always')
+ msg.FromString(end_tag)
+ assert len(w) == 1
+ assert issubclass(w[-1].category, RuntimeWarning)
+ self.assertEqual('Unexpected end-group tag: Not all data was converted',
+ str(w[-1].message))
+
+ def testDeterminismParameters(self, message_module):
+ # This message is always deterministically serialized, even if determinism
+ # is disabled, so we can use it to verify that all the determinism
+ # parameters work correctly.
+ golden_data = (b'\xe2\x02\nOne string'
+ b'\xe2\x02\nTwo string'
+ b'\xe2\x02\nRed string'
+ b'\xe2\x02\x0bBlue string')
+ golden_message = message_module.TestAllTypes()
+ golden_message.repeated_string.extend([
+ 'One string',
+ 'Two string',
+ 'Red string',
+ 'Blue string',
+ ])
+ self.assertEqual(golden_data,
+ golden_message.SerializeToString(deterministic=None))
+ self.assertEqual(golden_data,
+ golden_message.SerializeToString(deterministic=False))
+ self.assertEqual(golden_data,
+ golden_message.SerializeToString(deterministic=True))
+
+ class BadArgError(Exception):
+ pass
+
+ class BadArg(object):
+
+ def __nonzero__(self):
+ raise BadArgError()
+
+ def __bool__(self):
+ raise BadArgError()
+
+ with self.assertRaises(BadArgError):
+ golden_message.SerializeToString(deterministic=BadArg())
+
def testPickleSupport(self, message_module):
golden_data = test_util.GoldenFileData('golden_message')
golden_message = message_module.TestAllTypes()
@@ -368,6 +441,7 @@ class MessageTest(unittest.TestCase):
self.assertEqual(message.repeated_int32[0], 1)
self.assertEqual(message.repeated_int32[1], 2)
self.assertEqual(message.repeated_int32[2], 3)
+ self.assertEqual(str(message.repeated_int32), str([1, 2, 3]))
message.repeated_float.append(1.1)
message.repeated_float.append(1.3)
@@ -384,6 +458,7 @@ class MessageTest(unittest.TestCase):
self.assertEqual(message.repeated_string[0], 'a')
self.assertEqual(message.repeated_string[1], 'b')
self.assertEqual(message.repeated_string[2], 'c')
+ self.assertEqual(str(message.repeated_string), str([u'a', u'b', u'c']))
message.repeated_bytes.append(b'a')
message.repeated_bytes.append(b'c')
@@ -392,6 +467,7 @@ class MessageTest(unittest.TestCase):
self.assertEqual(message.repeated_bytes[0], b'a')
self.assertEqual(message.repeated_bytes[1], b'b')
self.assertEqual(message.repeated_bytes[2], b'c')
+ self.assertEqual(str(message.repeated_bytes), str([b'a', b'b', b'c']))
def testSortingRepeatedScalarFieldsCustomComparator(self, message_module):
"""Check some different types with custom comparator."""
@@ -430,6 +506,8 @@ class MessageTest(unittest.TestCase):
self.assertEqual(message.repeated_nested_message[3].bb, 4)
self.assertEqual(message.repeated_nested_message[4].bb, 5)
self.assertEqual(message.repeated_nested_message[5].bb, 6)
+ self.assertEqual(str(message.repeated_nested_message),
+ '[bb: 1\n, bb: 2\n, bb: 3\n, bb: 4\n, bb: 5\n, bb: 6\n]')
def testSortingRepeatedCompositeFieldsStable(self, message_module):
"""Check passing a custom comparator to sort a repeated composite field."""
@@ -555,6 +633,18 @@ class MessageTest(unittest.TestCase):
self.assertIsInstance(m.repeated_nested_message,
collections.MutableSequence)
+ def testRepeatedFieldsNotHashable(self, message_module):
+ m = message_module.TestAllTypes()
+ with self.assertRaises(TypeError):
+ hash(m.repeated_int32)
+ with self.assertRaises(TypeError):
+ hash(m.repeated_nested_message)
+
+ def testRepeatedFieldInsideNestedMessage(self, message_module):
+ m = message_module.NestedTestAllTypes()
+ m.payload.repeated_int32.extend([])
+ self.assertTrue(m.HasField('payload'))
+
def ensureNestedMessageExists(self, msg, attribute):
"""Make sure that a nested message object exists.
@@ -567,6 +657,7 @@ class MessageTest(unittest.TestCase):
def testOneofGetCaseNonexistingField(self, message_module):
m = message_module.TestAllTypes()
self.assertRaises(ValueError, m.WhichOneof, 'no_such_oneof_field')
+ self.assertRaises(Exception, m.WhichOneof, 0)
def testOneofDefaultValues(self, message_module):
m = message_module.TestAllTypes()
@@ -942,6 +1033,8 @@ class MessageTest(unittest.TestCase):
m = message_module.TestAllTypes()
with self.assertRaises(IndexError) as _:
m.repeated_nested_message.pop()
+ with self.assertRaises(TypeError) as _:
+ m.repeated_nested_message.pop('0')
for i in range(5):
n = m.repeated_nested_message.add()
n.bb = i
@@ -950,9 +1043,42 @@ class MessageTest(unittest.TestCase):
self.assertEqual(2, m.repeated_nested_message.pop(1).bb)
self.assertEqual([1, 3], [n.bb for n in m.repeated_nested_message])
+ def testRepeatedCompareWithSelf(self, message_module):
+ m = message_module.TestAllTypes()
+ for i in range(5):
+ m.repeated_int32.insert(i, i)
+ n = m.repeated_nested_message.add()
+ n.bb = i
+ self.assertSequenceEqual(m.repeated_int32, m.repeated_int32)
+ self.assertEqual(m.repeated_nested_message, m.repeated_nested_message)
+
+ def testReleasedNestedMessages(self, message_module):
+ """A case that lead to a segfault when a message detached from its parent
+ container has itself a child container.
+ """
+ m = message_module.NestedTestAllTypes()
+ m = m.repeated_child.add()
+ m = m.child
+ m = m.repeated_child.add()
+ self.assertEqual(m.payload.optional_int32, 0)
+
+ def testSetRepeatedComposite(self, message_module):
+ m = message_module.TestAllTypes()
+ with self.assertRaises(AttributeError):
+ m.repeated_int32 = []
+ m.repeated_int32.append(1)
+ if api_implementation.Type() == 'cpp':
+ # For test coverage: cpp has a different path if composite
+ # field is in cache
+ with self.assertRaises(TypeError):
+ m.repeated_int32 = []
+ else:
+ with self.assertRaises(AttributeError):
+ m.repeated_int32 = []
+
# Class to test proto2-only features (required, extensions, etc.)
-class Proto2Test(unittest.TestCase):
+class Proto2Test(BaseTestCase):
def testFieldPresence(self):
message = unittest_pb2.TestAllTypes()
@@ -1002,18 +1128,46 @@ class Proto2Test(unittest.TestCase):
self.assertEqual(False, message.optional_bool)
self.assertEqual(0, message.optional_nested_message.bb)
- # TODO(tibell): The C++ implementations actually allows assignment
- # of unknown enum values to *scalar* fields (but not repeated
- # fields). Once checked enum fields becomes the default in the
- # Python implementation, the C++ implementation should follow suit.
def testAssignInvalidEnum(self):
- """It should not be possible to assign an invalid enum number to an
- enum field."""
+ """Assigning an invalid enum number is not allowed in proto2."""
m = unittest_pb2.TestAllTypes()
+ # Proto2 can not assign unknown enum.
with self.assertRaises(ValueError) as _:
m.optional_nested_enum = 1234567
self.assertRaises(ValueError, m.repeated_nested_enum.append, 1234567)
+ # Assignment is a different code path than append for the C++ impl.
+ m.repeated_nested_enum.append(2)
+ m.repeated_nested_enum[0] = 2
+ with self.assertRaises(ValueError):
+ m.repeated_nested_enum[0] = 123456
+
+ # Unknown enum value can be parsed but is ignored.
+ m2 = unittest_proto3_arena_pb2.TestAllTypes()
+ m2.optional_nested_enum = 1234567
+ m2.repeated_nested_enum.append(7654321)
+ serialized = m2.SerializeToString()
+
+ m3 = unittest_pb2.TestAllTypes()
+ m3.ParseFromString(serialized)
+ self.assertFalse(m3.HasField('optional_nested_enum'))
+ # 1 is the default value for optional_nested_enum.
+ self.assertEqual(1, m3.optional_nested_enum)
+ self.assertEqual(0, len(m3.repeated_nested_enum))
+ m2.Clear()
+ m2.ParseFromString(m3.SerializeToString())
+ self.assertEqual(1234567, m2.optional_nested_enum)
+ self.assertEqual(7654321, m2.repeated_nested_enum[0])
+
+ def testUnknownEnumMap(self):
+ m = map_proto2_unittest_pb2.TestEnumMap()
+ m.known_map_field[123] = 0
+ with self.assertRaises(ValueError):
+ m.unknown_map_field[1] = 123
+
+ def testExtensionsErrors(self):
+ msg = unittest_pb2.TestAllTypes()
+ self.assertRaises(AttributeError, getattr, msg, 'Extensions')
def testGoldenExtensions(self):
golden_data = test_util.GoldenFileData('golden_message')
@@ -1108,6 +1262,7 @@ class Proto2Test(unittest.TestCase):
optional_bytes=b'x',
optionalgroup={'a': 400},
optional_nested_message={'bb': 500},
+ optional_foreign_message={},
optional_nested_enum='BAZ',
repeatedgroup=[{'a': 600},
{'a': 700}],
@@ -1120,8 +1275,12 @@ class Proto2Test(unittest.TestCase):
self.assertEqual(300.5, message.optional_float)
self.assertEqual(b'x', message.optional_bytes)
self.assertEqual(400, message.optionalgroup.a)
- self.assertIsInstance(message.optional_nested_message, unittest_pb2.TestAllTypes.NestedMessage)
+ self.assertIsInstance(message.optional_nested_message,
+ unittest_pb2.TestAllTypes.NestedMessage)
self.assertEqual(500, message.optional_nested_message.bb)
+ self.assertTrue(message.HasField('optional_foreign_message'))
+ self.assertEqual(message.optional_foreign_message,
+ unittest_pb2.ForeignMessage())
self.assertEqual(unittest_pb2.TestAllTypes.BAZ,
message.optional_nested_enum)
self.assertEqual(2, len(message.repeatedgroup))
@@ -1157,8 +1316,9 @@ class Proto2Test(unittest.TestCase):
unittest_pb2.TestAllTypes(repeated_nested_enum='FOO')
+
# Class to test proto3-only features/behavior (updated field presence & enums)
-class Proto3Test(unittest.TestCase):
+class Proto3Test(BaseTestCase):
# Utility method for comparing equality with a map.
def assertMapIterEquals(self, map_iter, dict_value):
@@ -1232,6 +1392,7 @@ class Proto3Test(unittest.TestCase):
"""Assigning an unknown enum value is allowed and preserves the value."""
m = unittest_proto3_arena_pb2.TestAllTypes()
+ # Proto3 can assign unknown enums.
m.optional_nested_enum = 1234567
self.assertEqual(1234567, m.optional_nested_enum)
m.repeated_nested_enum.append(22334455)
@@ -1249,7 +1410,7 @@ class Proto3Test(unittest.TestCase):
# Map isn't really a proto3-only feature. But there is no proto2 equivalent
# of google/protobuf/map_unittest.proto right now, so it's not easy to
# test both with the same test like we do for the other proto2/proto3 tests.
- # (google/protobuf/map_protobuf_unittest.proto is very different in the set
+ # (google/protobuf/map_proto2_unittest.proto is very different in the set
# of messages and fields it contains).
def testScalarMapDefaults(self):
msg = map_unittest_pb2.TestMap()
@@ -1259,7 +1420,10 @@ class Proto3Test(unittest.TestCase):
self.assertFalse(-2**33 in msg.map_int64_int64)
self.assertFalse(123 in msg.map_uint32_uint32)
self.assertFalse(2**33 in msg.map_uint64_uint64)
+ self.assertFalse(123 in msg.map_int32_double)
+ self.assertFalse(False in msg.map_bool_bool)
self.assertFalse('abc' in msg.map_string_string)
+ self.assertFalse(111 in msg.map_int32_bytes)
self.assertFalse(888 in msg.map_int32_enum)
# Accessing an unset key returns the default.
@@ -1267,7 +1431,12 @@ class Proto3Test(unittest.TestCase):
self.assertEqual(0, msg.map_int64_int64[-2**33])
self.assertEqual(0, msg.map_uint32_uint32[123])
self.assertEqual(0, msg.map_uint64_uint64[2**33])
+ self.assertEqual(0.0, msg.map_int32_double[123])
+ self.assertTrue(isinstance(msg.map_int32_double[123], float))
+ self.assertEqual(False, msg.map_bool_bool[False])
+ self.assertTrue(isinstance(msg.map_bool_bool[False], bool))
self.assertEqual('', msg.map_string_string['abc'])
+ self.assertEqual(b'', msg.map_int32_bytes[111])
self.assertEqual(0, msg.map_int32_enum[888])
# It also sets the value in the map
@@ -1275,7 +1444,10 @@ class Proto3Test(unittest.TestCase):
self.assertTrue(-2**33 in msg.map_int64_int64)
self.assertTrue(123 in msg.map_uint32_uint32)
self.assertTrue(2**33 in msg.map_uint64_uint64)
+ self.assertTrue(123 in msg.map_int32_double)
+ self.assertTrue(False in msg.map_bool_bool)
self.assertTrue('abc' in msg.map_string_string)
+ self.assertTrue(111 in msg.map_int32_bytes)
self.assertTrue(888 in msg.map_int32_enum)
self.assertIsInstance(msg.map_string_string['abc'], six.text_type)
@@ -1299,12 +1471,17 @@ class Proto3Test(unittest.TestCase):
msg.map_int32_int32[5] = 15
self.assertEqual(15, msg.map_int32_int32.get(5))
+ self.assertEqual(15, msg.map_int32_int32.get(5))
+ with self.assertRaises(TypeError):
+ msg.map_int32_int32.get('')
self.assertIsNone(msg.map_int32_foreign_message.get(5))
self.assertEqual(10, msg.map_int32_foreign_message.get(5, 10))
submsg = msg.map_int32_foreign_message[5]
self.assertIs(submsg, msg.map_int32_foreign_message.get(5))
+ with self.assertRaises(TypeError):
+ msg.map_int32_foreign_message.get('')
def testScalarMap(self):
msg = map_unittest_pb2.TestMap()
@@ -1316,8 +1493,13 @@ class Proto3Test(unittest.TestCase):
msg.map_int64_int64[-2**33] = -2**34
msg.map_uint32_uint32[123] = 456
msg.map_uint64_uint64[2**33] = 2**34
+ msg.map_int32_float[2] = 1.2
+ msg.map_int32_double[1] = 3.3
msg.map_string_string['abc'] = '123'
+ msg.map_bool_bool[True] = True
msg.map_int32_enum[888] = 2
+ # Unknown numeric enum is supported in proto3.
+ msg.map_int32_enum[123] = 456
self.assertEqual([], msg.FindInitializationErrors())
@@ -1351,8 +1533,24 @@ class Proto3Test(unittest.TestCase):
self.assertEqual(-2**34, msg2.map_int64_int64[-2**33])
self.assertEqual(456, msg2.map_uint32_uint32[123])
self.assertEqual(2**34, msg2.map_uint64_uint64[2**33])
+ self.assertAlmostEqual(1.2, msg.map_int32_float[2])
+ self.assertEqual(3.3, msg.map_int32_double[1])
self.assertEqual('123', msg2.map_string_string['abc'])
+ self.assertEqual(True, msg2.map_bool_bool[True])
self.assertEqual(2, msg2.map_int32_enum[888])
+ self.assertEqual(456, msg2.map_int32_enum[123])
+ # TODO(jieluo): Add cpp extension support.
+ if api_implementation.Type() == 'python':
+ self.assertEqual('{-123: -456}',
+ str(msg2.map_int32_int32))
+
+ def testMapEntryAlwaysSerialized(self):
+ msg = map_unittest_pb2.TestMap()
+ msg.map_int32_int32[0] = 0
+ msg.map_string_string[''] = ''
+ self.assertEqual(msg.ByteSize(), 12)
+ self.assertEqual(b'\n\x04\x08\x00\x10\x00r\x04\n\x00\x12\x00',
+ msg.SerializeToString())
def testStringUnicodeConversionInMap(self):
msg = map_unittest_pb2.TestMap()
@@ -1405,6 +1603,40 @@ class Proto3Test(unittest.TestCase):
self.assertIn(123, msg2.map_int32_foreign_message)
self.assertIn(-456, msg2.map_int32_foreign_message)
self.assertEqual(2, len(msg2.map_int32_foreign_message))
+ # TODO(jieluo): Fix text format for message map.
+ # TODO(jieluo): Add cpp extension support.
+ if api_implementation.Type() == 'python':
+ self.assertEqual(15,
+ len(str(msg2.map_int32_foreign_message)))
+
+ def testNestedMessageMapItemDelete(self):
+ msg = map_unittest_pb2.TestMap()
+ msg.map_int32_all_types[1].optional_nested_message.bb = 1
+ del msg.map_int32_all_types[1]
+ msg.map_int32_all_types[2].optional_nested_message.bb = 2
+ self.assertEqual(1, len(msg.map_int32_all_types))
+ msg.map_int32_all_types[1].optional_nested_message.bb = 1
+ self.assertEqual(2, len(msg.map_int32_all_types))
+
+ serialized = msg.SerializeToString()
+ msg2 = map_unittest_pb2.TestMap()
+ msg2.ParseFromString(serialized)
+ keys = [1, 2]
+ # The loop triggers PyErr_Occurred() in c extension.
+ for key in keys:
+ del msg2.map_int32_all_types[key]
+
+ def testMapByteSize(self):
+ msg = map_unittest_pb2.TestMap()
+ msg.map_int32_int32[1] = 1
+ size = msg.ByteSize()
+ msg.map_int32_int32[1] = 128
+ self.assertEqual(msg.ByteSize(), size + 1)
+
+ msg.map_int32_foreign_message[19].c = 1
+ size = msg.ByteSize()
+ msg.map_int32_foreign_message[19].c = 128
+ self.assertEqual(msg.ByteSize(), size + 1)
def testMergeFrom(self):
msg = map_unittest_pb2.TestMap()
@@ -1418,6 +1650,8 @@ class Proto3Test(unittest.TestCase):
msg2.map_int32_int32[12] = 55
msg2.map_int64_int64[88] = 99
msg2.map_int32_foreign_message[222].c = 15
+ msg2.map_int32_foreign_message[222].d = 20
+ old_map_value = msg2.map_int32_foreign_message[222]
msg2.MergeFrom(msg)
@@ -1427,6 +1661,16 @@ class Proto3Test(unittest.TestCase):
self.assertEqual(99, msg2.map_int64_int64[88])
self.assertEqual(5, msg2.map_int32_foreign_message[111].c)
self.assertEqual(10, msg2.map_int32_foreign_message[222].c)
+ self.assertFalse(msg2.map_int32_foreign_message[222].HasField('d'))
+ if api_implementation.Type() != 'cpp':
+ # During the call to MergeFrom(), the C++ implementation will have
+ # deallocated the underlying message, but this is very difficult to detect
+ # properly. The line below is likely to cause a segmentation fault.
+ # With the Python implementation, old_map_value is just 'detached' from
+ # the main message. Using it will not crash of course, but since it still
+ # have a reference to the parent message I'm sure we can find interesting
+ # ways to cause inconsistencies.
+ self.assertEqual(15, old_map_value.c)
# Verify that there is only one entry per key, even though the MergeFrom
# may have internally created multiple entries for a single key in the
@@ -1447,6 +1691,51 @@ class Proto3Test(unittest.TestCase):
del msg2.map_int32_foreign_message[222]
self.assertFalse(222 in msg2.map_int32_foreign_message)
+ with self.assertRaises(TypeError):
+ del msg2.map_int32_foreign_message['']
+
+ def testMapMergeFrom(self):
+ msg = map_unittest_pb2.TestMap()
+ msg.map_int32_int32[12] = 34
+ msg.map_int32_int32[56] = 78
+ msg.map_int64_int64[22] = 33
+ msg.map_int32_foreign_message[111].c = 5
+ msg.map_int32_foreign_message[222].c = 10
+
+ msg2 = map_unittest_pb2.TestMap()
+ msg2.map_int32_int32[12] = 55
+ msg2.map_int64_int64[88] = 99
+ msg2.map_int32_foreign_message[222].c = 15
+ msg2.map_int32_foreign_message[222].d = 20
+
+ msg2.map_int32_int32.MergeFrom(msg.map_int32_int32)
+ self.assertEqual(34, msg2.map_int32_int32[12])
+ self.assertEqual(78, msg2.map_int32_int32[56])
+
+ msg2.map_int64_int64.MergeFrom(msg.map_int64_int64)
+ self.assertEqual(33, msg2.map_int64_int64[22])
+ self.assertEqual(99, msg2.map_int64_int64[88])
+
+ msg2.map_int32_foreign_message.MergeFrom(msg.map_int32_foreign_message)
+ self.assertEqual(5, msg2.map_int32_foreign_message[111].c)
+ self.assertEqual(10, msg2.map_int32_foreign_message[222].c)
+ self.assertFalse(msg2.map_int32_foreign_message[222].HasField('d'))
+
+ def testMergeFromBadType(self):
+ msg = map_unittest_pb2.TestMap()
+ with self.assertRaisesRegexp(
+ TypeError,
+ r'Parameter to MergeFrom\(\) must be instance of same class: expected '
+ r'.*TestMap got int\.'):
+ msg.MergeFrom(1)
+
+ def testCopyFromBadType(self):
+ msg = map_unittest_pb2.TestMap()
+ with self.assertRaisesRegexp(
+ TypeError,
+ r'Parameter to [A-Za-z]*From\(\) must be instance of same class: '
+ r'expected .*TestMap got int\.'):
+ msg.CopyFrom(1)
def testIntegerMapWithLongs(self):
msg = map_unittest_pb2.TestMap()
@@ -1565,6 +1854,98 @@ class Proto3Test(unittest.TestCase):
matching_dict = {2: 4, 3: 6, 4: 8}
self.assertMapIterEquals(msg.map_int32_int32.items(), matching_dict)
+ def testPython2Map(self):
+ if sys.version_info < (3,):
+ msg = map_unittest_pb2.TestMap()
+ msg.map_int32_int32[2] = 4
+ msg.map_int32_int32[3] = 6
+ msg.map_int32_int32[4] = 8
+ msg.map_int32_int32[5] = 10
+ map_int32 = msg.map_int32_int32
+ self.assertEqual(4, len(map_int32))
+ msg2 = map_unittest_pb2.TestMap()
+ msg2.ParseFromString(msg.SerializeToString())
+
+ def CheckItems(seq, iterator):
+ self.assertEqual(next(iterator), seq[0])
+ self.assertEqual(list(iterator), seq[1:])
+
+ CheckItems(map_int32.items(), map_int32.iteritems())
+ CheckItems(map_int32.keys(), map_int32.iterkeys())
+ CheckItems(map_int32.values(), map_int32.itervalues())
+
+ self.assertEqual(6, map_int32.get(3))
+ self.assertEqual(None, map_int32.get(999))
+ self.assertEqual(6, map_int32.pop(3))
+ self.assertEqual(0, map_int32.pop(3))
+ self.assertEqual(3, len(map_int32))
+ key, value = map_int32.popitem()
+ self.assertEqual(2 * key, value)
+ self.assertEqual(2, len(map_int32))
+ map_int32.clear()
+ self.assertEqual(0, len(map_int32))
+
+ with self.assertRaises(KeyError):
+ map_int32.popitem()
+
+ self.assertEqual(0, map_int32.setdefault(2))
+ self.assertEqual(1, len(map_int32))
+
+ map_int32.update(msg2.map_int32_int32)
+ self.assertEqual(4, len(map_int32))
+
+ with self.assertRaises(TypeError):
+ map_int32.update(msg2.map_int32_int32,
+ msg2.map_int32_int32)
+ with self.assertRaises(TypeError):
+ map_int32.update(0)
+ with self.assertRaises(TypeError):
+ map_int32.update(value=12)
+
+ def testMapItems(self):
+ # Map items used to have strange behaviors when use c extension. Because
+ # [] may reorder the map and invalidate any exsting iterators.
+ # TODO(jieluo): Check if [] reordering the map is a bug or intended
+ # behavior.
+ msg = map_unittest_pb2.TestMap()
+ msg.map_string_string['local_init_op'] = ''
+ msg.map_string_string['trainable_variables'] = ''
+ msg.map_string_string['variables'] = ''
+ msg.map_string_string['init_op'] = ''
+ msg.map_string_string['summaries'] = ''
+ items1 = msg.map_string_string.items()
+ items2 = msg.map_string_string.items()
+ self.assertEqual(items1, items2)
+
+ def testMapDeterministicSerialization(self):
+ golden_data = (b'r\x0c\n\x07init_op\x12\x01d'
+ b'r\n\n\x05item1\x12\x01e'
+ b'r\n\n\x05item2\x12\x01f'
+ b'r\n\n\x05item3\x12\x01g'
+ b'r\x0b\n\x05item4\x12\x02QQ'
+ b'r\x12\n\rlocal_init_op\x12\x01a'
+ b'r\x0e\n\tsummaries\x12\x01e'
+ b'r\x18\n\x13trainable_variables\x12\x01b'
+ b'r\x0e\n\tvariables\x12\x01c')
+ msg = map_unittest_pb2.TestMap()
+ msg.map_string_string['local_init_op'] = 'a'
+ msg.map_string_string['trainable_variables'] = 'b'
+ msg.map_string_string['variables'] = 'c'
+ msg.map_string_string['init_op'] = 'd'
+ msg.map_string_string['summaries'] = 'e'
+ msg.map_string_string['item1'] = 'e'
+ msg.map_string_string['item2'] = 'f'
+ msg.map_string_string['item3'] = 'g'
+ msg.map_string_string['item4'] = 'QQ'
+
+ # If deterministic serialization is not working correctly, this will be
+ # "flaky" depending on the exact python dict hash seed.
+ #
+ # Fortunately, there are enough items in this map that it is extremely
+ # unlikely to ever hit the "right" in-order combination, so the test
+ # itself should fail reliably.
+ self.assertEqual(golden_data, msg.SerializeToString(deterministic=True))
+
def testMapIterationClearMessage(self):
# Iterator needs to work even if message and map are deleted.
msg = map_unittest_pb2.TestMap()
@@ -1651,6 +2032,9 @@ class Proto3Test(unittest.TestCase):
del msg.map_int32_int32[4]
self.assertEqual(0, len(msg.map_int32_int32))
+ with self.assertRaises(KeyError):
+ del msg.map_int32_all_types[32]
+
def testMapsAreMapping(self):
msg = map_unittest_pb2.TestMap()
self.assertIsInstance(msg.map_int32_int32, collections.Mapping)
@@ -1659,6 +2043,14 @@ class Proto3Test(unittest.TestCase):
self.assertIsInstance(msg.map_int32_foreign_message,
collections.MutableMapping)
+ def testMapsCompare(self):
+ msg = map_unittest_pb2.TestMap()
+ msg.map_int32_int32[-123] = -456
+ self.assertEqual(msg.map_int32_int32, msg.map_int32_int32)
+ self.assertEqual(msg.map_int32_foreign_message,
+ msg.map_int32_foreign_message)
+ self.assertNotEqual(msg.map_int32_int32, 0)
+
def testMapFindInitializationErrorsSmokeTest(self):
msg = map_unittest_pb2.TestMap()
msg.map_string_string['abc'] = '123'
@@ -1666,40 +2058,9 @@ class Proto3Test(unittest.TestCase):
msg.map_string_foreign_message['foo'].c = 5
self.assertEqual(0, len(msg.FindInitializationErrors()))
- def testAnyMessage(self):
- # Creates and sets message.
- msg = any_test_pb2.TestAny()
- msg_descriptor = msg.DESCRIPTOR
- all_types = unittest_pb2.TestAllTypes()
- all_descriptor = all_types.DESCRIPTOR
- all_types.repeated_string.append(u'\u00fc\ua71f')
- # Packs to Any.
- msg.value.Pack(all_types)
- self.assertEqual(msg.value.type_url,
- 'type.googleapis.com/%s' % all_descriptor.full_name)
- self.assertEqual(msg.value.value,
- all_types.SerializeToString())
- # Tests Is() method.
- self.assertTrue(msg.value.Is(all_descriptor))
- self.assertFalse(msg.value.Is(msg_descriptor))
- # Unpacks Any.
- unpacked_message = unittest_pb2.TestAllTypes()
- self.assertTrue(msg.value.Unpack(unpacked_message))
- self.assertEqual(all_types, unpacked_message)
- # Unpacks to different type.
- self.assertFalse(msg.value.Unpack(msg))
- # Only Any messages have Pack method.
- try:
- msg.Pack(all_types)
- except AttributeError:
- pass
- else:
- raise AttributeError('%s should not have Pack method.' %
- msg_descriptor.full_name)
-
-class ValidTypeNamesTest(unittest.TestCase):
+class ValidTypeNamesTest(BaseTestCase):
def assertImportFromName(self, msg, base_name):
# Parse <type 'module.class_name'> to extra 'some.name' as a string.
@@ -1720,7 +2081,7 @@ class ValidTypeNamesTest(unittest.TestCase):
self.assertImportFromName(pb.repeated_int32, 'Scalar')
self.assertImportFromName(pb.repeated_nested_message, 'Composite')
-class PackedFieldTest(unittest.TestCase):
+class PackedFieldTest(BaseTestCase):
def setMessage(self, message):
message.repeated_int32.append(1)
@@ -1776,5 +2137,67 @@ class PackedFieldTest(unittest.TestCase):
b'\x70\x01')
self.assertEqual(golden_data, message.SerializeToString())
+
+@unittest.skipIf(api_implementation.Type() != 'cpp' or
+ sys.version_info < (2, 7),
+ 'explicit tests of the C++ implementation for PY27 and above')
+class OversizeProtosTest(BaseTestCase):
+
+ @classmethod
+ def setUpClass(cls):
+ # At the moment, reference cycles between DescriptorPool and Message classes
+ # are not detected and these objects are never freed.
+ # To avoid errors with ReferenceLeakChecker, we create the class only once.
+ file_desc = """
+ name: "f/f.msg2"
+ package: "f"
+ message_type {
+ name: "msg1"
+ field {
+ name: "payload"
+ number: 1
+ label: LABEL_OPTIONAL
+ type: TYPE_STRING
+ }
+ }
+ message_type {
+ name: "msg2"
+ field {
+ name: "field"
+ number: 1
+ label: LABEL_OPTIONAL
+ type: TYPE_MESSAGE
+ type_name: "msg1"
+ }
+ }
+ """
+ pool = descriptor_pool.DescriptorPool()
+ desc = descriptor_pb2.FileDescriptorProto()
+ text_format.Parse(file_desc, desc)
+ pool.Add(desc)
+ cls.proto_cls = message_factory.MessageFactory(pool).GetPrototype(
+ pool.FindMessageTypeByName('f.msg2'))
+
+ def setUp(self):
+ self.p = self.proto_cls()
+ self.p.field.payload = 'c' * (1024 * 1024 * 64 + 1)
+ self.p_serialized = self.p.SerializeToString()
+
+ def testAssertOversizeProto(self):
+ from google.protobuf.pyext._message import SetAllowOversizeProtos
+ SetAllowOversizeProtos(False)
+ q = self.proto_cls()
+ try:
+ q.ParseFromString(self.p_serialized)
+ except message.DecodeError as e:
+ self.assertEqual(str(e), 'Error parsing message')
+
+ def testSucceedOversizeProto(self):
+ from google.protobuf.pyext._message import SetAllowOversizeProtos
+ SetAllowOversizeProtos(True)
+ q = self.proto_cls()
+ q.ParseFromString(self.p_serialized)
+ self.assertEqual(self.p.field.payload, q.field.payload)
+
if __name__ == '__main__':
unittest.main()
diff --git a/python/google/protobuf/internal/more_extensions_dynamic.proto b/python/google/protobuf/internal/more_extensions_dynamic.proto
index 11f85ef6..98fcbcb6 100644
--- a/python/google/protobuf/internal/more_extensions_dynamic.proto
+++ b/python/google/protobuf/internal/more_extensions_dynamic.proto
@@ -47,4 +47,5 @@ message DynamicMessageType {
extend ExtendedMessage {
optional int32 dynamic_int32_extension = 100;
optional DynamicMessageType dynamic_message_extension = 101;
+ repeated DynamicMessageType repeated_dynamic_message_extension = 102;
}
diff --git a/python/google/protobuf/internal/no_package.proto b/python/google/protobuf/internal/no_package.proto
new file mode 100644
index 00000000..3546dcc3
--- /dev/null
+++ b/python/google/protobuf/internal/no_package.proto
@@ -0,0 +1,10 @@
+syntax = "proto2";
+
+enum NoPackageEnum {
+ NO_PACKAGE_VALUE_0 = 0;
+ NO_PACKAGE_VALUE_1 = 1;
+}
+
+message NoPackageMessage {
+ optional NoPackageEnum no_package_enum = 1;
+}
diff --git a/python/google/protobuf/internal/proto_builder_test.py b/python/google/protobuf/internal/proto_builder_test.py
index 822ad895..36dfbfde 100644
--- a/python/google/protobuf/internal/proto_builder_test.py
+++ b/python/google/protobuf/internal/proto_builder_test.py
@@ -40,6 +40,7 @@ try:
import unittest2 as unittest
except ImportError:
import unittest
+
from google.protobuf import descriptor_pb2
from google.protobuf import descriptor_pool
from google.protobuf import proto_builder
diff --git a/python/google/protobuf/internal/python_message.py b/python/google/protobuf/internal/python_message.py
index 87f60666..975e3b4d 100755
--- a/python/google/protobuf/internal/python_message.py
+++ b/python/google/protobuf/internal/python_message.py
@@ -51,14 +51,14 @@ this file*.
__author__ = 'robinson@google.com (Will Robinson)'
from io import BytesIO
-import sys
import struct
+import sys
import weakref
import six
-import six.moves.copyreg as copyreg
# We use "as" to avoid name collisions with variables.
+from google.protobuf.internal import api_implementation
from google.protobuf.internal import containers
from google.protobuf.internal import decoder
from google.protobuf.internal import encoder
@@ -69,7 +69,6 @@ from google.protobuf.internal import well_known_types
from google.protobuf.internal import wire_format
from google.protobuf import descriptor as descriptor_mod
from google.protobuf import message as message_mod
-from google.protobuf import symbol_database
from google.protobuf import text_format
_FieldDescriptor = descriptor_mod.FieldDescriptor
@@ -91,16 +90,12 @@ class GeneratedProtocolMessageType(type):
classes at runtime, as in this example:
mydescriptor = Descriptor(.....)
- class MyProtoClass(Message):
- __metaclass__ = GeneratedProtocolMessageType
- DESCRIPTOR = mydescriptor
+ factory = symbol_database.Default()
+ factory.pool.AddDescriptor(mydescriptor)
+ MyProtoClass = factory.GetPrototype(mydescriptor)
myproto_instance = MyProtoClass()
myproto.foo_field = 23
...
-
- The above example will not work for nested types. If you wish to include them,
- use reflection.MakeClass() instead of manually instantiating the class in
- order to create the appropriate class structure.
"""
# Must be consistent with the protocol-compiler code in
@@ -157,12 +152,10 @@ class GeneratedProtocolMessageType(type):
"""
descriptor = dictionary[GeneratedProtocolMessageType._DESCRIPTOR_KEY]
cls._decoders_by_tag = {}
- cls._extensions_by_name = {}
- cls._extensions_by_number = {}
if (descriptor.has_options and
descriptor.GetOptions().message_set_wire_format):
cls._decoders_by_tag[decoder.MESSAGE_SET_ITEM_TAG] = (
- decoder.MessageSetItemDecoder(cls._extensions_by_number), None)
+ decoder.MessageSetItemDecoder(descriptor), None)
# Attach stuff to each FieldDescriptor for quick lookup later on.
for field in descriptor.fields:
@@ -176,7 +169,6 @@ class GeneratedProtocolMessageType(type):
_AddStaticMethods(cls)
_AddMessageMethods(descriptor, cls)
_AddPrivateHelperMethods(descriptor, cls)
- copyreg.pickle(cls, lambda obj: (cls, (), obj.__getstate__()))
superclass = super(GeneratedProtocolMessageType, cls)
superclass.__init__(name, bases, dictionary)
@@ -297,7 +289,8 @@ def _AttachFieldHelpers(cls, field_descriptor):
if is_map_entry:
field_encoder = encoder.MapEncoder(field_descriptor)
- sizer = encoder.MapSizer(field_descriptor)
+ sizer = encoder.MapSizer(field_descriptor,
+ _IsMessageMapField(field_descriptor))
elif _IsMessageSetExtension(field_descriptor):
field_encoder = encoder.MessageSetItemEncoder(field_descriptor.number)
sizer = encoder.MessageSetItemSizer(field_descriptor.number)
@@ -378,13 +371,15 @@ def _GetInitializeDefaultForMap(field):
if _IsMessageMapField(field):
def MakeMessageMapDefault(message):
return containers.MessageMap(
- message._listener_for_children, value_field.message_type, key_checker)
+ message._listener_for_children, value_field.message_type, key_checker,
+ field.message_type)
return MakeMessageMapDefault
else:
value_checker = type_checkers.GetTypeChecker(value_field)
def MakePrimitiveMapDefault(message):
return containers.ScalarMap(
- message._listener_for_children, key_checker, value_checker)
+ message._listener_for_children, key_checker, value_checker,
+ field.message_type)
return MakePrimitiveMapDefault
def _DefaultValueConstructorForField(field):
@@ -490,6 +485,9 @@ def _AddInitMethod(message_descriptor, cls):
if field is None:
raise TypeError("%s() got an unexpected keyword argument '%s'" %
(message_descriptor.name, field_name))
+ if field_value is None:
+ # field=None is the same as no field at all.
+ continue
if field.label == _FieldDescriptor.LABEL_REPEATED:
copy = field._default_constructor(self)
if field.cpp_type == _FieldDescriptor.CPPTYPE_MESSAGE: # Composite
@@ -737,32 +735,21 @@ def _AddPropertiesForExtensions(descriptor, cls):
constant_name = extension_name.upper() + "_FIELD_NUMBER"
setattr(cls, constant_name, extension_field.number)
+ # TODO(amauryfa): Migrate all users of these attributes to functions like
+ # pool.FindExtensionByNumber(descriptor).
+ if descriptor.file is not None:
+ # TODO(amauryfa): Use cls.MESSAGE_FACTORY.pool when available.
+ pool = descriptor.file.pool
+ cls._extensions_by_number = pool._extensions_by_number[descriptor]
+ cls._extensions_by_name = pool._extensions_by_name[descriptor]
def _AddStaticMethods(cls):
# TODO(robinson): This probably needs to be thread-safe(?)
def RegisterExtension(extension_handle):
extension_handle.containing_type = cls.DESCRIPTOR
+ # TODO(amauryfa): Use cls.MESSAGE_FACTORY.pool when available.
+ cls.DESCRIPTOR.file.pool.AddExtensionDescriptor(extension_handle)
_AttachFieldHelpers(cls, extension_handle)
-
- # Try to insert our extension, failing if an extension with the same number
- # already exists.
- actual_handle = cls._extensions_by_number.setdefault(
- extension_handle.number, extension_handle)
- if actual_handle is not extension_handle:
- raise AssertionError(
- 'Extensions "%s" and "%s" both try to extend message type "%s" with '
- 'field number %d.' %
- (extension_handle.full_name, actual_handle.full_name,
- cls.DESCRIPTOR.full_name, extension_handle.number))
-
- cls._extensions_by_name[extension_handle.full_name] = extension_handle
-
- handle = extension_handle # avoid line wrapping
- if _IsMessageSetExtension(handle):
- # MessageSet extension. Also register under type name.
- cls._extensions_by_name[
- extension_handle.message_type.full_name] = extension_handle
-
cls.RegisterExtension = staticmethod(RegisterExtension)
def FromString(s):
@@ -889,17 +876,6 @@ def _AddClearExtensionMethod(cls):
cls.ClearExtension = ClearExtension
-def _AddClearMethod(message_descriptor, cls):
- """Helper for _AddMessageMethods()."""
- def Clear(self):
- # Clear fields.
- self._fields = {}
- self._unknown_fields = ()
- self._oneofs = {}
- self._Modified()
- cls.Clear = Clear
-
-
def _AddHasExtensionMethod(cls):
"""Helper for _AddMessageMethods()."""
def HasExtension(self, extension_handle):
@@ -917,7 +893,7 @@ def _AddHasExtensionMethod(cls):
def _InternalUnpackAny(msg):
"""Unpacks Any message and returns the unpacked message.
- This internal method is differnt from public Any Unpack method which takes
+ This internal method is different from public Any Unpack method which takes
the target message as argument. _InternalUnpackAny method does not have
target message type and need to find the message type in descriptor pool.
@@ -927,26 +903,33 @@ def _InternalUnpackAny(msg):
Returns:
The unpacked message.
"""
+ # TODO(amauryfa): Don't use the factory of generated messages.
+ # To make Any work with custom factories, use the message factory of the
+ # parent message.
+ # pylint: disable=g-import-not-at-top
+ from google.protobuf import symbol_database
+ factory = symbol_database.Default()
+
type_url = msg.type_url
- db = symbol_database.Default()
if not type_url:
return None
# TODO(haberman): For now we just strip the hostname. Better logic will be
# required.
- type_name = type_url.split("/")[-1]
- descriptor = db.pool.FindMessageTypeByName(type_name)
+ type_name = type_url.split('/')[-1]
+ descriptor = factory.pool.FindMessageTypeByName(type_name)
if descriptor is None:
return None
- message_class = db.GetPrototype(descriptor)
+ message_class = factory.GetPrototype(descriptor)
message = message_class()
message.ParseFromString(msg.value)
return message
+
def _AddEqualsMethod(message_descriptor, cls):
"""Helper for _AddMessageMethods()."""
def __eq__(self, other):
@@ -999,16 +982,6 @@ def _AddUnicodeMethod(unused_message_descriptor, cls):
cls.__unicode__ = __unicode__
-def _AddSetListenerMethod(cls):
- """Helper for _AddMessageMethods()."""
- def SetListener(self, listener):
- if listener is None:
- self._listener = message_listener_mod.NullMessageListener()
- else:
- self._listener = listener
- cls._SetListener = SetListener
-
-
def _BytesForNonRepeatedElement(value, field_number, field_type):
"""Returns the number of bytes needed to serialize a non-repeated element.
The returned byte count includes space for tag information and any
@@ -1037,11 +1010,16 @@ def _AddByteSizeMethod(message_descriptor, cls):
return self._cached_byte_size
size = 0
- for field_descriptor, field_value in self.ListFields():
- size += field_descriptor._sizer(field_value)
-
- for tag_bytes, value_bytes in self._unknown_fields:
- size += len(tag_bytes) + len(value_bytes)
+ descriptor = self.DESCRIPTOR
+ if descriptor.GetOptions().map_entry:
+ # Fields of map entry should always be serialized.
+ size = descriptor.fields_by_name['key']._sizer(self.key)
+ size += descriptor.fields_by_name['value']._sizer(self.value)
+ else:
+ for field_descriptor, field_value in self.ListFields():
+ size += field_descriptor._sizer(field_value)
+ for tag_bytes, value_bytes in self._unknown_fields:
+ size += len(tag_bytes) + len(value_bytes)
self._cached_byte_size = size
self._cached_byte_size_dirty = False
@@ -1054,32 +1032,46 @@ def _AddByteSizeMethod(message_descriptor, cls):
def _AddSerializeToStringMethod(message_descriptor, cls):
"""Helper for _AddMessageMethods()."""
- def SerializeToString(self):
+ def SerializeToString(self, **kwargs):
# Check if the message has all of its required fields set.
errors = []
if not self.IsInitialized():
raise message_mod.EncodeError(
'Message %s is missing required fields: %s' % (
self.DESCRIPTOR.full_name, ','.join(self.FindInitializationErrors())))
- return self.SerializePartialToString()
+ return self.SerializePartialToString(**kwargs)
cls.SerializeToString = SerializeToString
def _AddSerializePartialToStringMethod(message_descriptor, cls):
"""Helper for _AddMessageMethods()."""
- def SerializePartialToString(self):
+ def SerializePartialToString(self, **kwargs):
out = BytesIO()
- self._InternalSerialize(out.write)
+ self._InternalSerialize(out.write, **kwargs)
return out.getvalue()
cls.SerializePartialToString = SerializePartialToString
- def InternalSerialize(self, write_bytes):
- for field_descriptor, field_value in self.ListFields():
- field_descriptor._encoder(write_bytes, field_value)
- for tag_bytes, value_bytes in self._unknown_fields:
- write_bytes(tag_bytes)
- write_bytes(value_bytes)
+ def InternalSerialize(self, write_bytes, deterministic=None):
+ if deterministic is None:
+ deterministic = (
+ api_implementation.IsPythonDefaultSerializationDeterministic())
+ else:
+ deterministic = bool(deterministic)
+
+ descriptor = self.DESCRIPTOR
+ if descriptor.GetOptions().map_entry:
+ # Fields of map entry should always be serialized.
+ descriptor.fields_by_name['key']._encoder(
+ write_bytes, self.key, deterministic)
+ descriptor.fields_by_name['value']._encoder(
+ write_bytes, self.value, deterministic)
+ else:
+ for field_descriptor, field_value in self.ListFields():
+ field_descriptor._encoder(write_bytes, field_value, deterministic)
+ for tag_bytes, value_bytes in self._unknown_fields:
+ write_bytes(tag_bytes)
+ write_bytes(value_bytes)
cls._InternalSerialize = InternalSerialize
@@ -1117,7 +1109,8 @@ def _AddMergeFromStringMethod(message_descriptor, cls):
new_pos = local_SkipField(buffer, new_pos, end, tag_bytes)
if new_pos == -1:
return pos
- if not is_proto3:
+ if (not is_proto3 or
+ api_implementation.GetPythonProto3PreserveUnknownsDefault()):
if not unknown_field_list:
unknown_field_list = self._unknown_fields = []
unknown_field_list.append(
@@ -1234,7 +1227,7 @@ def _AddMergeFromMethod(cls):
if not isinstance(msg, cls):
raise TypeError(
"Parameter to MergeFrom() must be instance of same class: "
- "expected %s got %s." % (cls.__name__, type(msg).__name__))
+ 'expected %s got %s.' % (cls.__name__, msg.__class__.__name__))
assert msg is not self
self._Modified()
@@ -1288,6 +1281,38 @@ def _AddWhichOneofMethod(message_descriptor, cls):
cls.WhichOneof = WhichOneof
+def _AddReduceMethod(cls):
+ def __reduce__(self): # pylint: disable=invalid-name
+ return (type(self), (), self.__getstate__())
+ cls.__reduce__ = __reduce__
+
+
+def _Clear(self):
+ # Clear fields.
+ self._fields = {}
+ self._unknown_fields = ()
+ self._oneofs = {}
+ self._Modified()
+
+
+def _DiscardUnknownFields(self):
+ self._unknown_fields = []
+ for field, value in self.ListFields():
+ if field.cpp_type == _FieldDescriptor.CPPTYPE_MESSAGE:
+ if field.label == _FieldDescriptor.LABEL_REPEATED:
+ for sub_message in value:
+ sub_message.DiscardUnknownFields()
+ else:
+ value.DiscardUnknownFields()
+
+
+def _SetListener(self, listener):
+ if listener is None:
+ self._listener = message_listener_mod.NullMessageListener()
+ else:
+ self._listener = listener
+
+
def _AddMessageMethods(message_descriptor, cls):
"""Adds implementations of all Message methods to cls."""
_AddListFieldsMethod(message_descriptor, cls)
@@ -1296,12 +1321,10 @@ def _AddMessageMethods(message_descriptor, cls):
if message_descriptor.is_extendable:
_AddClearExtensionMethod(cls)
_AddHasExtensionMethod(cls)
- _AddClearMethod(message_descriptor, cls)
_AddEqualsMethod(message_descriptor, cls)
_AddStrMethod(message_descriptor, cls)
_AddReprMethod(message_descriptor, cls)
_AddUnicodeMethod(message_descriptor, cls)
- _AddSetListenerMethod(cls)
_AddByteSizeMethod(message_descriptor, cls)
_AddSerializeToStringMethod(message_descriptor, cls)
_AddSerializePartialToStringMethod(message_descriptor, cls)
@@ -1309,6 +1332,11 @@ def _AddMessageMethods(message_descriptor, cls):
_AddIsInitializedMethod(message_descriptor, cls)
_AddMergeFromMethod(cls)
_AddWhichOneofMethod(message_descriptor, cls)
+ _AddReduceMethod(cls)
+ # Adds methods which do not depend on cls.
+ cls.Clear = _Clear
+ cls.DiscardUnknownFields = _DiscardUnknownFields
+ cls._SetListener = _SetListener
def _AddPrivateHelperMethods(message_descriptor, cls):
@@ -1518,3 +1546,14 @@ class _ExtensionDict(object):
Extension field descriptor.
"""
return self._extended_message._extensions_by_name.get(name, None)
+
+ def _FindExtensionByNumber(self, number):
+ """Tries to find a known extension with the field number.
+
+ Args:
+ number: Extension field number.
+
+ Returns:
+ Extension field descriptor.
+ """
+ return self._extended_message._extensions_by_number.get(number, None)
diff --git a/python/google/protobuf/internal/python_protobuf.cc b/python/google/protobuf/internal/python_protobuf.cc
new file mode 100644
index 00000000..f90cc438
--- /dev/null
+++ b/python/google/protobuf/internal/python_protobuf.cc
@@ -0,0 +1,63 @@
+// Protocol Buffers - Google's data interchange format
+// Copyright 2008 Google Inc. All rights reserved.
+// https://developers.google.com/protocol-buffers/
+//
+// Redistribution and use in source and binary forms, with or without
+// modification, are permitted provided that the following conditions are
+// met:
+//
+// * Redistributions of source code must retain the above copyright
+// notice, this list of conditions and the following disclaimer.
+// * Redistributions in binary form must reproduce the above
+// copyright notice, this list of conditions and the following disclaimer
+// in the documentation and/or other materials provided with the
+// distribution.
+// * Neither the name of Google Inc. nor the names of its
+// contributors may be used to endorse or promote products derived from
+// this software without specific prior written permission.
+//
+// THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS
+// "AS IS" AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT
+// LIMITED TO, THE IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR
+// A PARTICULAR PURPOSE ARE DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT
+// OWNER OR CONTRIBUTORS BE LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL,
+// SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT
+// LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE,
+// DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY
+// THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT
+// (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
+// OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
+
+// Author: qrczak@google.com (Marcin Kowalczyk)
+
+#include <google/protobuf/python/python_protobuf.h>
+
+namespace google {
+namespace protobuf {
+namespace python {
+
+static const Message* GetCProtoInsidePyProtoStub(PyObject* msg) {
+ return NULL;
+}
+static Message* MutableCProtoInsidePyProtoStub(PyObject* msg) {
+ return NULL;
+}
+
+// This is initialized with a default, stub implementation.
+// If python-google.protobuf.cc is loaded, the function pointer is overridden
+// with a full implementation.
+const Message* (*GetCProtoInsidePyProtoPtr)(PyObject* msg) =
+ GetCProtoInsidePyProtoStub;
+Message* (*MutableCProtoInsidePyProtoPtr)(PyObject* msg) =
+ MutableCProtoInsidePyProtoStub;
+
+const Message* GetCProtoInsidePyProto(PyObject* msg) {
+ return GetCProtoInsidePyProtoPtr(msg);
+}
+Message* MutableCProtoInsidePyProto(PyObject* msg) {
+ return MutableCProtoInsidePyProtoPtr(msg);
+}
+
+} // namespace python
+} // namespace protobuf
+} // namespace google
diff --git a/python/google/protobuf/internal/reflection_test.py b/python/google/protobuf/internal/reflection_test.py
index 752f2f5d..0306ff46 100755
--- a/python/google/protobuf/internal/reflection_test.py
+++ b/python/google/protobuf/internal/reflection_test.py
@@ -42,9 +42,10 @@ import six
import struct
try:
- import unittest2 as unittest
+ import unittest2 as unittest #PY26
except ImportError:
import unittest
+
from google.protobuf import unittest_import_pb2
from google.protobuf import unittest_mset_pb2
from google.protobuf import unittest_pb2
@@ -59,9 +60,13 @@ from google.protobuf.internal import more_messages_pb2
from google.protobuf.internal import message_set_extensions_pb2
from google.protobuf.internal import wire_format
from google.protobuf.internal import test_util
+from google.protobuf.internal import testing_refleaks
from google.protobuf.internal import decoder
+BaseTestCase = testing_refleaks.BaseTestCase
+
+
class _MiniDecoder(object):
"""Decodes a stream of values from a string.
@@ -94,12 +99,12 @@ class _MiniDecoder(object):
return wire_format.UnpackTag(self.ReadVarint())
def ReadFloat(self):
- result = struct.unpack("<f", self._bytes[self._pos:self._pos+4])[0]
+ result = struct.unpack('<f', self._bytes[self._pos:self._pos+4])[0]
self._pos += 4
return result
def ReadDouble(self):
- result = struct.unpack("<d", self._bytes[self._pos:self._pos+8])[0]
+ result = struct.unpack('<d', self._bytes[self._pos:self._pos+8])[0]
self._pos += 8
return result
@@ -107,7 +112,7 @@ class _MiniDecoder(object):
return self._pos == len(self._bytes)
-class ReflectionTest(unittest.TestCase):
+class ReflectionTest(BaseTestCase):
def assertListsEqual(self, values, others):
self.assertEqual(len(values), len(others))
@@ -119,11 +124,13 @@ class ReflectionTest(unittest.TestCase):
proto = unittest_pb2.TestAllTypes(
optional_int32=24,
optional_double=54.321,
- optional_string='optional_string')
+ optional_string='optional_string',
+ optional_float=None)
self.assertEqual(24, proto.optional_int32)
self.assertEqual(54.321, proto.optional_double)
self.assertEqual('optional_string', proto.optional_string)
+ self.assertFalse(proto.HasField("optional_float"))
def testRepeatedScalarConstructor(self):
# Constructor with only repeated scalar types should succeed.
@@ -131,12 +138,14 @@ class ReflectionTest(unittest.TestCase):
repeated_int32=[1, 2, 3, 4],
repeated_double=[1.23, 54.321],
repeated_bool=[True, False, False],
- repeated_string=["optional_string"])
+ repeated_string=["optional_string"],
+ repeated_float=None)
self.assertEqual([1, 2, 3, 4], list(proto.repeated_int32))
self.assertEqual([1.23, 54.321], list(proto.repeated_double))
self.assertEqual([True, False, False], list(proto.repeated_bool))
self.assertEqual(["optional_string"], list(proto.repeated_string))
+ self.assertEqual([], list(proto.repeated_float))
def testRepeatedCompositeConstructor(self):
# Constructor with only repeated composite types should succeed.
@@ -187,7 +196,8 @@ class ReflectionTest(unittest.TestCase):
repeated_foreign_message=[
unittest_pb2.ForeignMessage(c=-43),
unittest_pb2.ForeignMessage(c=45324),
- unittest_pb2.ForeignMessage(c=12)])
+ unittest_pb2.ForeignMessage(c=12)],
+ optional_nested_message=None)
self.assertEqual(24, proto.optional_int32)
self.assertEqual('optional_string', proto.optional_string)
@@ -204,6 +214,7 @@ class ReflectionTest(unittest.TestCase):
unittest_pb2.ForeignMessage(c=45324),
unittest_pb2.ForeignMessage(c=12)],
list(proto.repeated_foreign_message))
+ self.assertFalse(proto.HasField("optional_nested_message"))
def testConstructorTypeError(self):
self.assertRaises(
@@ -609,10 +620,24 @@ class ReflectionTest(unittest.TestCase):
self.assertRaises(TypeError, setattr, proto, 'optional_int32', 'foo')
self.assertRaises(TypeError, setattr, proto, 'optional_string', 10)
self.assertRaises(TypeError, setattr, proto, 'optional_bytes', 10)
+ self.assertRaises(TypeError, setattr, proto, 'optional_bool', 'foo')
+ self.assertRaises(TypeError, setattr, proto, 'optional_float', 'foo')
+ self.assertRaises(TypeError, setattr, proto, 'optional_double', 'foo')
+ # TODO(jieluo): Fix type checking difference for python and c extension
+ if api_implementation.Type() == 'python':
+ self.assertRaises(TypeError, setattr, proto, 'optional_bool', 1.1)
+ else:
+ proto.optional_bool = 1.1
- def testIntegerTypes(self):
+ def assertIntegerTypes(self, integer_fn):
+ """Verifies setting of scalar integers.
+
+ Args:
+ integer_fn: A function to wrap the integers that will be assigned.
+ """
def TestGetAndDeserialize(field_name, value, expected_type):
proto = unittest_pb2.TestAllTypes()
+ value = integer_fn(value)
setattr(proto, field_name, value)
self.assertIsInstance(getattr(proto, field_name), expected_type)
proto2 = unittest_pb2.TestAllTypes()
@@ -624,12 +649,12 @@ class ReflectionTest(unittest.TestCase):
TestGetAndDeserialize('optional_uint32', 1 << 30, int)
try:
integer_64 = long
- except NameError: # Python3
+ except NameError: # Python3
integer_64 = int
if struct.calcsize('L') == 4:
# Python only has signed ints, so 32-bit python can't fit an uint32
# in an int.
- TestGetAndDeserialize('optional_uint32', 1 << 31, long)
+ TestGetAndDeserialize('optional_uint32', 1 << 31, integer_64)
else:
# 64-bit python can fit uint32 inside an int
TestGetAndDeserialize('optional_uint32', 1 << 31, int)
@@ -638,25 +663,62 @@ class ReflectionTest(unittest.TestCase):
TestGetAndDeserialize('optional_uint64', 1 << 30, integer_64)
TestGetAndDeserialize('optional_uint64', 1 << 60, integer_64)
- def testSingleScalarBoundsChecking(self):
+ def testIntegerTypes(self):
+ self.assertIntegerTypes(lambda x: x)
+
+ def testNonStandardIntegerTypes(self):
+ self.assertIntegerTypes(test_util.NonStandardInteger)
+
+ def testIllegalValuesForIntegers(self):
+ pb = unittest_pb2.TestAllTypes()
+
+ # Strings are illegal, even when the represent an integer.
+ with self.assertRaises(TypeError):
+ pb.optional_uint64 = '2'
+
+ # The exact error should propagate with a poorly written custom integer.
+ with self.assertRaisesRegexp(RuntimeError, 'my_error'):
+ pb.optional_uint64 = test_util.NonStandardInteger(5, 'my_error')
+
+ def assetIntegerBoundsChecking(self, integer_fn):
+ """Verifies bounds checking for scalar integer fields.
+
+ Args:
+ integer_fn: A function to wrap the integers that will be assigned.
+ """
def TestMinAndMaxIntegers(field_name, expected_min, expected_max):
pb = unittest_pb2.TestAllTypes()
+ expected_min = integer_fn(expected_min)
+ expected_max = integer_fn(expected_max)
setattr(pb, field_name, expected_min)
self.assertEqual(expected_min, getattr(pb, field_name))
setattr(pb, field_name, expected_max)
self.assertEqual(expected_max, getattr(pb, field_name))
- self.assertRaises(ValueError, setattr, pb, field_name, expected_min - 1)
- self.assertRaises(ValueError, setattr, pb, field_name, expected_max + 1)
+ self.assertRaises((ValueError, TypeError), setattr, pb, field_name,
+ expected_min - 1)
+ self.assertRaises((ValueError, TypeError), setattr, pb, field_name,
+ expected_max + 1)
TestMinAndMaxIntegers('optional_int32', -(1 << 31), (1 << 31) - 1)
TestMinAndMaxIntegers('optional_uint32', 0, 0xffffffff)
TestMinAndMaxIntegers('optional_int64', -(1 << 63), (1 << 63) - 1)
TestMinAndMaxIntegers('optional_uint64', 0, 0xffffffffffffffff)
+ # A bit of white-box testing since -1 is an int and not a long in C++ and
+ # so goes down a different path.
+ pb = unittest_pb2.TestAllTypes()
+ with self.assertRaises((ValueError, TypeError)):
+ pb.optional_uint64 = integer_fn(-(1 << 63))
pb = unittest_pb2.TestAllTypes()
- pb.optional_nested_enum = 1
+ pb.optional_nested_enum = integer_fn(1)
self.assertEqual(1, pb.optional_nested_enum)
+ def testSingleScalarBoundsChecking(self):
+ self.assetIntegerBoundsChecking(lambda x: x)
+
+ def testNonStandardSingleScalarBoundsChecking(self):
+ self.assetIntegerBoundsChecking(test_util.NonStandardInteger)
+
def testRepeatedScalarTypeSafety(self):
proto = unittest_pb2.TestAllTypes()
self.assertRaises(TypeError, proto.repeated_int32.append, 1.1)
@@ -668,6 +730,12 @@ class ReflectionTest(unittest.TestCase):
proto.repeated_int32[0] = 23
self.assertRaises(IndexError, proto.repeated_int32.__setitem__, 500, 23)
self.assertRaises(TypeError, proto.repeated_int32.__setitem__, 0, 'abc')
+ self.assertRaises(TypeError, proto.repeated_int32.__setitem__, 0, [])
+ self.assertRaises(TypeError, proto.repeated_int32.__setitem__,
+ 'index', 23)
+
+ proto.repeated_string.append('2')
+ self.assertRaises(TypeError, proto.repeated_string.__setitem__, 0, 10)
# Repeated enums tests.
#proto.repeated_nested_enum.append(0)
@@ -955,6 +1023,14 @@ class ReflectionTest(unittest.TestCase):
self.assertEqual(4, len(proto.repeated_nested_message))
self.assertEqual(n1, proto.repeated_nested_message[2])
self.assertEqual(n2, proto.repeated_nested_message[3])
+ self.assertRaises(TypeError,
+ proto.repeated_nested_message.extend, n1)
+ self.assertRaises(TypeError,
+ proto.repeated_nested_message.extend, [0])
+ wrong_message_type = unittest_pb2.TestAllTypes()
+ self.assertRaises(TypeError,
+ proto.repeated_nested_message.extend,
+ [wrong_message_type])
# Test clearing.
proto.ClearField('repeated_nested_message')
@@ -965,6 +1041,9 @@ class ReflectionTest(unittest.TestCase):
proto.repeated_nested_message.add(bb=23)
self.assertEqual(1, len(proto.repeated_nested_message))
self.assertEqual(23, proto.repeated_nested_message[0].bb)
+ self.assertRaises(TypeError, proto.repeated_nested_message.add, 23)
+ with self.assertRaises(Exception):
+ proto.repeated_nested_message[0] = 23
def testRepeatedCompositeRemove(self):
proto = unittest_pb2.TestAllTypes()
@@ -1175,12 +1254,18 @@ class ReflectionTest(unittest.TestCase):
self.assertTrue(not extendee_proto.HasExtension(extension))
def testRegisteredExtensions(self):
- self.assertTrue('protobuf_unittest.optional_int32_extension' in
- unittest_pb2.TestAllExtensions._extensions_by_name)
- self.assertTrue(1 in unittest_pb2.TestAllExtensions._extensions_by_number)
+ pool = unittest_pb2.DESCRIPTOR.pool
+ self.assertTrue(
+ pool.FindExtensionByNumber(
+ unittest_pb2.TestAllExtensions.DESCRIPTOR, 1))
+ self.assertIs(
+ pool.FindExtensionByName(
+ 'protobuf_unittest.optional_int32_extension').containing_type,
+ unittest_pb2.TestAllExtensions.DESCRIPTOR)
# Make sure extensions haven't been registered into types that shouldn't
# have any.
- self.assertEqual(0, len(unittest_pb2.TestAllTypes._extensions_by_name))
+ self.assertEqual(0, len(
+ pool.FindAllExtensions(unittest_pb2.TestAllTypes.DESCRIPTOR)))
# If message A directly contains message B, and
# a.HasField('b') is currently False, then mutating any
@@ -1492,7 +1577,14 @@ class ReflectionTest(unittest.TestCase):
container = copy.deepcopy(proto1.repeated_int32)
self.assertEqual([2, 3], container)
- # TODO(anuraag): Implement deepcopy for repeated composite / extension dict
+ message1 = proto1.repeated_nested_message.add()
+ message1.bb = 1
+ messages = copy.deepcopy(proto1.repeated_nested_message)
+ self.assertEqual(proto1.repeated_nested_message, messages)
+ message1.bb = 2
+ self.assertNotEqual(proto1.repeated_nested_message, messages)
+
+ # TODO(anuraag): Implement deepcopy for extension dict
def testClear(self):
proto = unittest_pb2.TestAllTypes()
@@ -1544,6 +1636,20 @@ class ReflectionTest(unittest.TestCase):
self.assertFalse(proto.HasField('optional_foreign_message'))
self.assertEqual(0, proto.optional_foreign_message.c)
+ def testDisconnectingInOneof(self):
+ m = unittest_pb2.TestOneof2() # This message has two messages in a oneof.
+ m.foo_message.qux_int = 5
+ sub_message = m.foo_message
+ # Accessing another message's field does not clear the first one
+ self.assertEqual(m.foo_lazy_message.qux_int, 0)
+ self.assertEqual(m.foo_message.qux_int, 5)
+ # But mutating another message in the oneof detaches the first one.
+ m.foo_lazy_message.qux_int = 6
+ self.assertEqual(m.foo_message.qux_int, 0)
+ # The reference we got above was detached and is still valid.
+ self.assertEqual(sub_message.qux_int, 5)
+ sub_message.qux_int = 7
+
def testOneOf(self):
proto = unittest_pb2.TestAllTypes()
proto.oneof_uint32 = 10
@@ -1562,8 +1668,11 @@ class ReflectionTest(unittest.TestCase):
proto.SerializeToString()
proto.SerializePartialToString()
- def assertNotInitialized(self, proto):
+ def assertNotInitialized(self, proto, error_size=None):
+ errors = []
self.assertFalse(proto.IsInitialized())
+ self.assertFalse(proto.IsInitialized(errors))
+ self.assertEqual(error_size, len(errors))
self.assertRaises(message.EncodeError, proto.SerializeToString)
# "Partial" serialization doesn't care if message is uninitialized.
proto.SerializePartialToString()
@@ -1577,7 +1686,7 @@ class ReflectionTest(unittest.TestCase):
# The case of uninitialized required fields.
proto = unittest_pb2.TestRequired()
- self.assertNotInitialized(proto)
+ self.assertNotInitialized(proto, 3)
proto.a = proto.b = proto.c = 2
self.assertInitialized(proto)
@@ -1585,14 +1694,14 @@ class ReflectionTest(unittest.TestCase):
proto = unittest_pb2.TestRequiredForeign()
self.assertInitialized(proto)
proto.optional_message.a = 1
- self.assertNotInitialized(proto)
+ self.assertNotInitialized(proto, 2)
proto.optional_message.b = 0
proto.optional_message.c = 0
self.assertInitialized(proto)
# Uninitialized repeated submessage.
message1 = proto.repeated_message.add()
- self.assertNotInitialized(proto)
+ self.assertNotInitialized(proto, 3)
message1.a = message1.b = message1.c = 0
self.assertInitialized(proto)
@@ -1601,11 +1710,11 @@ class ReflectionTest(unittest.TestCase):
extension = unittest_pb2.TestRequired.multi
message1 = proto.Extensions[extension].add()
message2 = proto.Extensions[extension].add()
- self.assertNotInitialized(proto)
+ self.assertNotInitialized(proto, 6)
message1.a = 1
message1.b = 1
message1.c = 1
- self.assertNotInitialized(proto)
+ self.assertNotInitialized(proto, 3)
message2.a = 2
message2.b = 2
message2.c = 2
@@ -1615,7 +1724,7 @@ class ReflectionTest(unittest.TestCase):
proto = unittest_pb2.TestAllExtensions()
extension = unittest_pb2.TestRequired.single
proto.Extensions[extension].a = 1
- self.assertNotInitialized(proto)
+ self.assertNotInitialized(proto, 2)
proto.Extensions[extension].b = 2
proto.Extensions[extension].c = 3
self.assertInitialized(proto)
@@ -1802,7 +1911,7 @@ class ReflectionTest(unittest.TestCase):
# into separate TestCase classes.
-class TestAllTypesEqualityTest(unittest.TestCase):
+class TestAllTypesEqualityTest(BaseTestCase):
def setUp(self):
self.first_proto = unittest_pb2.TestAllTypes()
@@ -1818,7 +1927,7 @@ class TestAllTypesEqualityTest(unittest.TestCase):
self.assertEqual(self.first_proto, self.second_proto)
-class FullProtosEqualityTest(unittest.TestCase):
+class FullProtosEqualityTest(BaseTestCase):
"""Equality tests using completely-full protos as a starting point."""
@@ -1904,7 +2013,7 @@ class FullProtosEqualityTest(unittest.TestCase):
self.assertEqual(self.first_proto, self.second_proto)
-class ExtensionEqualityTest(unittest.TestCase):
+class ExtensionEqualityTest(BaseTestCase):
def testExtensionEquality(self):
first_proto = unittest_pb2.TestAllExtensions()
@@ -1937,7 +2046,7 @@ class ExtensionEqualityTest(unittest.TestCase):
self.assertEqual(first_proto, second_proto)
-class MutualRecursionEqualityTest(unittest.TestCase):
+class MutualRecursionEqualityTest(BaseTestCase):
def testEqualityWithMutualRecursion(self):
first_proto = unittest_pb2.TestMutualRecursionA()
@@ -1949,7 +2058,7 @@ class MutualRecursionEqualityTest(unittest.TestCase):
self.assertEqual(first_proto, second_proto)
-class ByteSizeTest(unittest.TestCase):
+class ByteSizeTest(BaseTestCase):
def setUp(self):
self.proto = unittest_pb2.TestAllTypes()
@@ -2074,6 +2183,8 @@ class ByteSizeTest(unittest.TestCase):
foreign_message_1 = self.proto.repeated_nested_message.add()
foreign_message_1.bb = 9
self.assertEqual(2 + 1 + 2 + 1 + 1 + 1, self.Size())
+ repeated_nested_message = copy.deepcopy(
+ self.proto.repeated_nested_message)
# 2 bytes tag plus 1 byte length plus 1 byte bb tag 1 byte int.
del self.proto.repeated_nested_message[0]
@@ -2094,6 +2205,16 @@ class ByteSizeTest(unittest.TestCase):
del self.proto.repeated_nested_message[0]
self.assertEqual(0, self.Size())
+ self.assertEqual(2, len(repeated_nested_message))
+ del repeated_nested_message[0:1]
+ # TODO(jieluo): Fix cpp extension bug when delete repeated message.
+ if api_implementation.Type() == 'python':
+ self.assertEqual(1, len(repeated_nested_message))
+ del repeated_nested_message[-1]
+ # TODO(jieluo): Fix cpp extension bug when delete repeated message.
+ if api_implementation.Type() == 'python':
+ self.assertEqual(0, len(repeated_nested_message))
+
def testRepeatedGroups(self):
# 2-byte START_GROUP plus 2-byte END_GROUP.
group_0 = self.proto.repeatedgroup.add()
@@ -2110,6 +2231,10 @@ class ByteSizeTest(unittest.TestCase):
proto.Extensions[extension] = 23
# 1 byte for tag, 1 byte for value.
self.assertEqual(2, proto.ByteSize())
+ field = unittest_pb2.TestAllTypes.DESCRIPTOR.fields_by_name[
+ 'optional_int32']
+ with self.assertRaises(KeyError):
+ proto.Extensions[field] = 23
def testCacheInvalidationForNonrepeatedScalar(self):
# Test non-extension.
@@ -2245,7 +2370,7 @@ class ByteSizeTest(unittest.TestCase):
# * Handling of empty submessages (with and without "has"
# bits set).
-class SerializationTest(unittest.TestCase):
+class SerializationTest(BaseTestCase):
def testSerializeEmtpyMessage(self):
first_proto = unittest_pb2.TestAllTypes()
@@ -2806,7 +2931,7 @@ class SerializationTest(unittest.TestCase):
self.assertEqual(3, proto.repeated_int32[2])
-class OptionsTest(unittest.TestCase):
+class OptionsTest(BaseTestCase):
def testMessageOptions(self):
proto = message_set_extensions_pb2.TestMessageSet()
@@ -2833,7 +2958,7 @@ class OptionsTest(unittest.TestCase):
-class ClassAPITest(unittest.TestCase):
+class ClassAPITest(BaseTestCase):
@unittest.skipIf(
api_implementation.Type() == 'cpp' and api_implementation.Version() == 2,
@@ -2916,6 +3041,9 @@ class ClassAPITest(unittest.TestCase):
text_format.Merge(file_descriptor_str, file_descriptor)
return file_descriptor.SerializeToString()
+ @testing_refleaks.SkipReferenceLeakChecker('MakeDescriptor is not repeatable')
+ # This test can only run once; the second time, it raises errors about
+ # conflicting message descriptors.
def testParsingFlatClassWithExplicitClassDeclaration(self):
"""Test that the generated class can parse a flat message."""
# TODO(xiaofeng): This test fails with cpp implemetnation in the call
@@ -2940,6 +3068,7 @@ class ClassAPITest(unittest.TestCase):
text_format.Merge(msg_str, msg)
self.assertEqual(msg.flat, [0, 1, 2])
+ @testing_refleaks.SkipReferenceLeakChecker('MakeDescriptor is not repeatable')
def testParsingFlatClass(self):
"""Test that the generated class can parse a flat message."""
file_descriptor = descriptor_pb2.FileDescriptorProto()
@@ -2955,6 +3084,7 @@ class ClassAPITest(unittest.TestCase):
text_format.Merge(msg_str, msg)
self.assertEqual(msg.flat, [0, 1, 2])
+ @testing_refleaks.SkipReferenceLeakChecker('MakeDescriptor is not repeatable')
def testParsingNestedClass(self):
"""Test that the generated class can parse a nested message."""
file_descriptor = descriptor_pb2.FileDescriptorProto()
diff --git a/python/google/protobuf/internal/service_reflection_test.py b/python/google/protobuf/internal/service_reflection_test.py
index 98614b77..77239f44 100755
--- a/python/google/protobuf/internal/service_reflection_test.py
+++ b/python/google/protobuf/internal/service_reflection_test.py
@@ -34,10 +34,12 @@
__author__ = 'petar@google.com (Petar Petrov)'
+
try:
- import unittest2 as unittest
+ import unittest2 as unittest #PY26
except ImportError:
import unittest
+
from google.protobuf import unittest_pb2
from google.protobuf import service_reflection
from google.protobuf import service
@@ -80,6 +82,10 @@ class FooUnitTest(unittest.TestCase):
service_descriptor = unittest_pb2.TestService.GetDescriptor()
srvc.CallMethod(service_descriptor.methods[1], rpc_controller,
unittest_pb2.BarRequest(), MyCallback)
+ self.assertTrue(srvc.GetRequestClass(service_descriptor.methods[1]) is
+ unittest_pb2.BarRequest)
+ self.assertTrue(srvc.GetResponseClass(service_descriptor.methods[1]) is
+ unittest_pb2.BarResponse)
self.assertEqual('Method Bar not implemented.',
rpc_controller.failure_message)
self.assertEqual(None, self.callback_response)
diff --git a/python/google/protobuf/internal/symbol_database_test.py b/python/google/protobuf/internal/symbol_database_test.py
index 0cb935a8..af42681a 100644
--- a/python/google/protobuf/internal/symbol_database_test.py
+++ b/python/google/protobuf/internal/symbol_database_test.py
@@ -33,31 +33,35 @@
"""Tests for google.protobuf.symbol_database."""
try:
- import unittest2 as unittest
+ import unittest2 as unittest #PY26
except ImportError:
import unittest
+
from google.protobuf import unittest_pb2
from google.protobuf import descriptor
+from google.protobuf import descriptor_pool
from google.protobuf import symbol_database
+
class SymbolDatabaseTest(unittest.TestCase):
def _Database(self):
- # TODO(b/17734095): Remove this difference when the C++ implementation
- # supports multiple databases.
if descriptor._USE_C_DESCRIPTORS:
- return symbol_database.Default()
+ # The C++ implementation does not allow mixing descriptors from
+ # different pools.
+ db = symbol_database.SymbolDatabase(pool=descriptor_pool.Default())
else:
db = symbol_database.SymbolDatabase()
- # Register representative types from unittest_pb2.
- db.RegisterFileDescriptor(unittest_pb2.DESCRIPTOR)
- db.RegisterMessage(unittest_pb2.TestAllTypes)
- db.RegisterMessage(unittest_pb2.TestAllTypes.NestedMessage)
- db.RegisterMessage(unittest_pb2.TestAllTypes.OptionalGroup)
- db.RegisterMessage(unittest_pb2.TestAllTypes.RepeatedGroup)
- db.RegisterEnumDescriptor(unittest_pb2.ForeignEnum.DESCRIPTOR)
- db.RegisterEnumDescriptor(unittest_pb2.TestAllTypes.NestedEnum.DESCRIPTOR)
- return db
+ # Register representative types from unittest_pb2.
+ db.RegisterFileDescriptor(unittest_pb2.DESCRIPTOR)
+ db.RegisterMessage(unittest_pb2.TestAllTypes)
+ db.RegisterMessage(unittest_pb2.TestAllTypes.NestedMessage)
+ db.RegisterMessage(unittest_pb2.TestAllTypes.OptionalGroup)
+ db.RegisterMessage(unittest_pb2.TestAllTypes.RepeatedGroup)
+ db.RegisterEnumDescriptor(unittest_pb2.ForeignEnum.DESCRIPTOR)
+ db.RegisterEnumDescriptor(unittest_pb2.TestAllTypes.NestedEnum.DESCRIPTOR)
+ db.RegisterServiceDescriptor(unittest_pb2._TESTSERVICE)
+ return db
def testGetPrototype(self):
instance = self._Database().GetPrototype(
@@ -106,7 +110,13 @@ class SymbolDatabaseTest(unittest.TestCase):
self._Database().pool.FindMessageTypeByName(
'protobuf_unittest.TestAllTypes.NestedMessage').full_name)
- def testFindFindContainingSymbol(self):
+ def testFindServiceByName(self):
+ self.assertEqual(
+ 'protobuf_unittest.TestService',
+ self._Database().pool.FindServiceByName(
+ 'protobuf_unittest.TestService').full_name)
+
+ def testFindFileContainingSymbol(self):
# Lookup based on either enum or message.
self.assertEqual(
'google/protobuf/unittest.proto',
diff --git a/python/google/protobuf/internal/test_util.py b/python/google/protobuf/internal/test_util.py
index ac88fa81..a6e34ef5 100755
--- a/python/google/protobuf/internal/test_util.py
+++ b/python/google/protobuf/internal/test_util.py
@@ -36,11 +36,18 @@ This is intentionally modeled on C++ code in
__author__ = 'robinson@google.com (Will Robinson)'
+import numbers
+import operator
import os.path
from google.protobuf import unittest_import_pb2
from google.protobuf import unittest_pb2
-from google.protobuf import descriptor_pb2
+
+try:
+ long # Python 2
+except NameError:
+ long = int # Python 3
+
# Tests whether the given TestAllTypes message is proto2 or not.
# This is used to gate several fields/features that only exist
@@ -48,6 +55,7 @@ from google.protobuf import descriptor_pb2
def IsProto2(message):
return message.DESCRIPTOR.syntax == "proto2"
+
def SetAllNonLazyFields(message):
"""Sets every non-lazy field in the message to a unique value.
@@ -125,22 +133,37 @@ def SetAllNonLazyFields(message):
message.repeated_string_piece.append(u'224')
message.repeated_cord.append(u'225')
- # Add a second one of each field.
- message.repeated_int32.append(301)
- message.repeated_int64.append(302)
- message.repeated_uint32.append(303)
- message.repeated_uint64.append(304)
- message.repeated_sint32.append(305)
- message.repeated_sint64.append(306)
- message.repeated_fixed32.append(307)
- message.repeated_fixed64.append(308)
- message.repeated_sfixed32.append(309)
- message.repeated_sfixed64.append(310)
- message.repeated_float.append(311)
- message.repeated_double.append(312)
- message.repeated_bool.append(False)
- message.repeated_string.append(u'315')
- message.repeated_bytes.append(b'316')
+ # Add a second one of each field and set value by index.
+ message.repeated_int32.append(0)
+ message.repeated_int64.append(0)
+ message.repeated_uint32.append(0)
+ message.repeated_uint64.append(0)
+ message.repeated_sint32.append(0)
+ message.repeated_sint64.append(0)
+ message.repeated_fixed32.append(0)
+ message.repeated_fixed64.append(0)
+ message.repeated_sfixed32.append(0)
+ message.repeated_sfixed64.append(0)
+ message.repeated_float.append(0)
+ message.repeated_double.append(0)
+ message.repeated_bool.append(True)
+ message.repeated_string.append(u'0')
+ message.repeated_bytes.append(b'0')
+ message.repeated_int32[1] = 301
+ message.repeated_int64[1] = 302
+ message.repeated_uint32[1] = 303
+ message.repeated_uint64[1] = 304
+ message.repeated_sint32[1] = 305
+ message.repeated_sint64[1] = 306
+ message.repeated_fixed32[1] = 307
+ message.repeated_fixed64[1] = 308
+ message.repeated_sfixed32[1] = 309
+ message.repeated_sfixed64[1] = 310
+ message.repeated_float[1] = 311
+ message.repeated_double[1] = 312
+ message.repeated_bool[1] = False
+ message.repeated_string[1] = u'315'
+ message.repeated_bytes[1] = b'316'
if IsProto2(message):
message.repeatedgroup.add().a = 317
@@ -149,7 +172,8 @@ def SetAllNonLazyFields(message):
message.repeated_import_message.add().d = 320
message.repeated_lazy_message.add().bb = 327
- message.repeated_nested_enum.append(unittest_pb2.TestAllTypes.BAZ)
+ message.repeated_nested_enum.append(unittest_pb2.TestAllTypes.BAR)
+ message.repeated_nested_enum[1] = unittest_pb2.TestAllTypes.BAZ
message.repeated_foreign_enum.append(unittest_pb2.FOREIGN_BAZ)
if IsProto2(message):
message.repeated_import_enum.append(unittest_import_pb2.IMPORT_BAZ)
@@ -692,3 +716,153 @@ def SetAllUnpackedFields(message):
message.unpacked_bool.extend([True, False])
message.unpacked_enum.extend([unittest_pb2.FOREIGN_BAR,
unittest_pb2.FOREIGN_BAZ])
+
+
+class NonStandardInteger(numbers.Integral):
+ """An integer object that does not subclass int.
+
+ This is used to verify that both C++ and regular proto systems can handle
+ integer others than int and long and that they handle them in predictable
+ ways.
+
+ NonStandardInteger is the minimal legal specification for a custom Integral.
+ As such, it does not support 0 < x < 5 and it is not hashable.
+
+ Note: This is added here instead of relying on numpy or a similar library
+ with custom integers to limit dependencies.
+ """
+
+ def __init__(self, val, error_string_on_conversion=None):
+ assert isinstance(val, numbers.Integral)
+ if isinstance(val, NonStandardInteger):
+ val = val.val
+ self.val = val
+ self.error_string_on_conversion = error_string_on_conversion
+
+ def __long__(self):
+ if self.error_string_on_conversion:
+ raise RuntimeError(self.error_string_on_conversion)
+ return long(self.val)
+
+ def __abs__(self):
+ return NonStandardInteger(operator.abs(self.val))
+
+ def __add__(self, y):
+ return NonStandardInteger(operator.add(self.val, y))
+
+ def __div__(self, y):
+ return NonStandardInteger(operator.div(self.val, y))
+
+ def __eq__(self, y):
+ return operator.eq(self.val, y)
+
+ def __floordiv__(self, y):
+ return NonStandardInteger(operator.floordiv(self.val, y))
+
+ def __truediv__(self, y):
+ return NonStandardInteger(operator.truediv(self.val, y))
+
+ def __invert__(self):
+ return NonStandardInteger(operator.invert(self.val))
+
+ def __mod__(self, y):
+ return NonStandardInteger(operator.mod(self.val, y))
+
+ def __mul__(self, y):
+ return NonStandardInteger(operator.mul(self.val, y))
+
+ def __neg__(self):
+ return NonStandardInteger(operator.neg(self.val))
+
+ def __pos__(self):
+ return NonStandardInteger(operator.pos(self.val))
+
+ def __pow__(self, y):
+ return NonStandardInteger(operator.pow(self.val, y))
+
+ def __trunc__(self):
+ return int(self.val)
+
+ def __radd__(self, y):
+ return NonStandardInteger(operator.add(y, self.val))
+
+ def __rdiv__(self, y):
+ return NonStandardInteger(operator.div(y, self.val))
+
+ def __rmod__(self, y):
+ return NonStandardInteger(operator.mod(y, self.val))
+
+ def __rmul__(self, y):
+ return NonStandardInteger(operator.mul(y, self.val))
+
+ def __rpow__(self, y):
+ return NonStandardInteger(operator.pow(y, self.val))
+
+ def __rfloordiv__(self, y):
+ return NonStandardInteger(operator.floordiv(y, self.val))
+
+ def __rtruediv__(self, y):
+ return NonStandardInteger(operator.truediv(y, self.val))
+
+ def __lshift__(self, y):
+ return NonStandardInteger(operator.lshift(self.val, y))
+
+ def __rshift__(self, y):
+ return NonStandardInteger(operator.rshift(self.val, y))
+
+ def __rlshift__(self, y):
+ return NonStandardInteger(operator.lshift(y, self.val))
+
+ def __rrshift__(self, y):
+ return NonStandardInteger(operator.rshift(y, self.val))
+
+ def __le__(self, y):
+ if isinstance(y, NonStandardInteger):
+ y = y.val
+ return operator.le(self.val, y)
+
+ def __lt__(self, y):
+ if isinstance(y, NonStandardInteger):
+ y = y.val
+ return operator.lt(self.val, y)
+
+ def __and__(self, y):
+ return NonStandardInteger(operator.and_(self.val, y))
+
+ def __or__(self, y):
+ return NonStandardInteger(operator.or_(self.val, y))
+
+ def __xor__(self, y):
+ return NonStandardInteger(operator.xor(self.val, y))
+
+ def __rand__(self, y):
+ return NonStandardInteger(operator.and_(y, self.val))
+
+ def __ror__(self, y):
+ return NonStandardInteger(operator.or_(y, self.val))
+
+ def __rxor__(self, y):
+ return NonStandardInteger(operator.xor(y, self.val))
+
+ def __bool__(self):
+ return self.val
+
+ def __nonzero__(self):
+ return self.val
+
+ def __ceil__(self):
+ return self
+
+ def __floor__(self):
+ return self
+
+ def __int__(self):
+ if self.error_string_on_conversion:
+ raise RuntimeError(self.error_string_on_conversion)
+ return int(self.val)
+
+ def __round__(self):
+ return self
+
+ def __repr__(self):
+ return 'NonStandardInteger(%s)' % self.val
diff --git a/python/google/protobuf/internal/testing_refleaks.py b/python/google/protobuf/internal/testing_refleaks.py
new file mode 100644
index 00000000..8ce06519
--- /dev/null
+++ b/python/google/protobuf/internal/testing_refleaks.py
@@ -0,0 +1,126 @@
+# Protocol Buffers - Google's data interchange format
+# Copyright 2008 Google Inc. All rights reserved.
+# https://developers.google.com/protocol-buffers/
+#
+# Redistribution and use in source and binary forms, with or without
+# modification, are permitted provided that the following conditions are
+# met:
+#
+# * Redistributions of source code must retain the above copyright
+# notice, this list of conditions and the following disclaimer.
+# * Redistributions in binary form must reproduce the above
+# copyright notice, this list of conditions and the following disclaimer
+# in the documentation and/or other materials provided with the
+# distribution.
+# * Neither the name of Google Inc. nor the names of its
+# contributors may be used to endorse or promote products derived from
+# this software without specific prior written permission.
+#
+# THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS
+# "AS IS" AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT
+# LIMITED TO, THE IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR
+# A PARTICULAR PURPOSE ARE DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT
+# OWNER OR CONTRIBUTORS BE LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL,
+# SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT
+# LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE,
+# DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY
+# THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT
+# (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
+# OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
+
+"""A subclass of unittest.TestCase which checks for reference leaks.
+
+To use:
+- Use testing_refleak.BaseTestCase instead of unittest.TestCase
+- Configure and compile Python with --with-pydebug
+
+If sys.gettotalrefcount() is not available (because Python was built without
+the Py_DEBUG option), then this module is a no-op and tests will run normally.
+"""
+
+import gc
+import sys
+
+try:
+ import copy_reg as copyreg #PY26
+except ImportError:
+ import copyreg
+
+try:
+ import unittest2 as unittest #PY26
+except ImportError:
+ import unittest
+
+
+class LocalTestResult(unittest.TestResult):
+ """A TestResult which forwards events to a parent object, except for Skips."""
+
+ def __init__(self, parent_result):
+ unittest.TestResult.__init__(self)
+ self.parent_result = parent_result
+
+ def addError(self, test, error):
+ self.parent_result.addError(test, error)
+
+ def addFailure(self, test, error):
+ self.parent_result.addFailure(test, error)
+
+ def addSkip(self, test, reason):
+ pass
+
+
+class ReferenceLeakCheckerTestCase(unittest.TestCase):
+ """A TestCase which runs tests multiple times, collecting reference counts."""
+
+ NB_RUNS = 3
+
+ def run(self, result=None):
+ # python_message.py registers all Message classes to some pickle global
+ # registry, which makes the classes immortal.
+ # We save a copy of this registry, and reset it before we could references.
+ self._saved_pickle_registry = copyreg.dispatch_table.copy()
+
+ # Run the test twice, to warm up the instance attributes.
+ super(ReferenceLeakCheckerTestCase, self).run(result=result)
+ super(ReferenceLeakCheckerTestCase, self).run(result=result)
+
+ oldrefcount = 0
+ local_result = LocalTestResult(result)
+
+ refcount_deltas = []
+ for _ in range(self.NB_RUNS):
+ oldrefcount = self._getRefcounts()
+ super(ReferenceLeakCheckerTestCase, self).run(result=local_result)
+ newrefcount = self._getRefcounts()
+ refcount_deltas.append(newrefcount - oldrefcount)
+ print(refcount_deltas, self)
+
+ try:
+ self.assertEqual(refcount_deltas, [0] * self.NB_RUNS)
+ except Exception: # pylint: disable=broad-except
+ result.addError(self, sys.exc_info())
+
+ def _getRefcounts(self):
+ copyreg.dispatch_table.clear()
+ copyreg.dispatch_table.update(self._saved_pickle_registry)
+ # It is sometimes necessary to gc.collect() multiple times, to ensure
+ # that all objects can be collected.
+ gc.collect()
+ gc.collect()
+ gc.collect()
+ return sys.gettotalrefcount()
+
+
+if hasattr(sys, 'gettotalrefcount'):
+ BaseTestCase = ReferenceLeakCheckerTestCase
+ SkipReferenceLeakChecker = unittest.skip
+
+else:
+ # When PyDEBUG is not enabled, run the tests normally.
+ BaseTestCase = unittest.TestCase
+
+ def SkipReferenceLeakChecker(reason):
+ del reason # Don't skip, so don't need a reason.
+ def Same(func):
+ return func
+ return Same
diff --git a/python/google/protobuf/internal/text_encoding_test.py b/python/google/protobuf/internal/text_encoding_test.py
index 338a287b..c7d182c4 100755
--- a/python/google/protobuf/internal/text_encoding_test.py
+++ b/python/google/protobuf/internal/text_encoding_test.py
@@ -33,9 +33,10 @@
"""Tests for google.protobuf.text_encoding."""
try:
- import unittest2 as unittest
+ import unittest2 as unittest #PY26
except ImportError:
import unittest
+
from google.protobuf import text_encoding
TEST_VALUES = [
diff --git a/python/google/protobuf/internal/text_format_test.py b/python/google/protobuf/internal/text_format_test.py
index 0e14556c..237a2d50 100755
--- a/python/google/protobuf/internal/text_format_test.py
+++ b/python/google/protobuf/internal/text_format_test.py
@@ -1,4 +1,5 @@
#! /usr/bin/env python
+# -*- coding: utf-8 -*-
#
# Protocol Buffers - Google's data interchange format
# Copyright 2008 Google Inc. All rights reserved.
@@ -35,23 +36,29 @@
__author__ = 'kenton@google.com (Kenton Varda)'
+import math
import re
import six
import string
try:
- import unittest2 as unittest
+ import unittest2 as unittest # PY26, pylint: disable=g-import-not-at-top
except ImportError:
- import unittest
+ import unittest # pylint: disable=g-import-not-at-top
+
from google.protobuf.internal import _parameterized
+from google.protobuf import any_pb2
+from google.protobuf import any_test_pb2
from google.protobuf import map_unittest_pb2
from google.protobuf import unittest_mset_pb2
from google.protobuf import unittest_pb2
from google.protobuf import unittest_proto3_arena_pb2
from google.protobuf.internal import api_implementation
-from google.protobuf.internal import test_util
+from google.protobuf.internal import any_test_pb2 as test_extend_any
from google.protobuf.internal import message_set_extensions_pb2
+from google.protobuf.internal import test_util
+from google.protobuf import descriptor_pool
from google.protobuf import text_format
@@ -62,7 +69,7 @@ class SimpleTextFormatTests(unittest.TestCase):
# expects single characters. Therefore it's an error (in addition to being
# non-sensical in the first place) to try to specify a "quote mark" that is
# more than one character.
- def TestQuoteMarksAreSingleChars(self):
+ def testQuoteMarksAreSingleChars(self):
for quote in text_format._QUOTES:
self.assertEqual(1, len(quote))
@@ -89,13 +96,11 @@ class TextFormatBase(unittest.TestCase):
.replace('e-0','e-').replace('e-0','e-')
# Floating point fields are printed with .0 suffix even if they are
# actualy integer numbers.
- text = re.compile('\.0$', re.MULTILINE).sub('', text)
+ text = re.compile(r'\.0$', re.MULTILINE).sub('', text)
return text
-@_parameterized.Parameters(
- (unittest_pb2),
- (unittest_proto3_arena_pb2))
+@_parameterized.parameters((unittest_pb2), (unittest_proto3_arena_pb2))
class TextFormatTest(TextFormatBase):
def testPrintExotic(self, message_module):
@@ -119,8 +124,10 @@ class TextFormatTest(TextFormatBase):
'repeated_string: "\\303\\274\\352\\234\\237"\n')
def testPrintExoticUnicodeSubclass(self, message_module):
+
class UnicodeSub(six.text_type):
pass
+
message = message_module.TestAllTypes()
message.repeated_string.append(UnicodeSub(u'\u00fc\ua71f'))
self.CompareToGoldenText(
@@ -164,8 +171,8 @@ class TextFormatTest(TextFormatBase):
message.repeated_string.append('\000\001\a\b\f\n\r\t\v\\\'"')
message.repeated_string.append(u'\u00fc\ua71f')
self.CompareToGoldenText(
- self.RemoveRedundantZeros(
- text_format.MessageToString(message, as_one_line=True)),
+ self.RemoveRedundantZeros(text_format.MessageToString(
+ message, as_one_line=True)),
'repeated_int64: -9223372036854775808'
' repeated_uint64: 18446744073709551615'
' repeated_double: 123.456'
@@ -186,21 +193,23 @@ class TextFormatTest(TextFormatBase):
message.repeated_string.append(u'\u00fc\ua71f')
# Test as_utf8 = False.
- wire_text = text_format.MessageToString(
- message, as_one_line=True, as_utf8=False)
+ wire_text = text_format.MessageToString(message,
+ as_one_line=True,
+ as_utf8=False)
parsed_message = message_module.TestAllTypes()
r = text_format.Parse(wire_text, parsed_message)
self.assertIs(r, parsed_message)
self.assertEqual(message, parsed_message)
# Test as_utf8 = True.
- wire_text = text_format.MessageToString(
- message, as_one_line=True, as_utf8=True)
+ wire_text = text_format.MessageToString(message,
+ as_one_line=True,
+ as_utf8=True)
parsed_message = message_module.TestAllTypes()
r = text_format.Parse(wire_text, parsed_message)
self.assertIs(r, parsed_message)
self.assertEqual(message, parsed_message,
- '\n%s != %s' % (message, parsed_message))
+ '\n%s != %s' % (message, parsed_message))
def testPrintRawUtf8String(self, message_module):
message = message_module.TestAllTypes()
@@ -210,7 +219,7 @@ class TextFormatTest(TextFormatBase):
parsed_message = message_module.TestAllTypes()
text_format.Parse(text, parsed_message)
self.assertEqual(message, parsed_message,
- '\n%s != %s' % (message, parsed_message))
+ '\n%s != %s' % (message, parsed_message))
def testPrintFloatFormat(self, message_module):
# Check that float_format argument is passed to sub-message formatting.
@@ -231,14 +240,15 @@ class TextFormatTest(TextFormatBase):
message.payload.repeated_double.append(.000078900)
formatted_fields = ['optional_float: 1.25',
'optional_double: -3.45678901234568e-6',
- 'repeated_float: -5642',
- 'repeated_double: 7.89e-5']
+ 'repeated_float: -5642', 'repeated_double: 7.89e-5']
text_message = text_format.MessageToString(message, float_format='.15g')
self.CompareToGoldenText(
self.RemoveRedundantZeros(text_message),
- 'payload {{\n {0}\n {1}\n {2}\n {3}\n}}\n'.format(*formatted_fields))
+ 'payload {{\n {0}\n {1}\n {2}\n {3}\n}}\n'.format(
+ *formatted_fields))
# as_one_line=True is a separate code branch where float_format is passed.
- text_message = text_format.MessageToString(message, as_one_line=True,
+ text_message = text_format.MessageToString(message,
+ as_one_line=True,
float_format='.15g')
self.CompareToGoldenText(
self.RemoveRedundantZeros(text_message),
@@ -249,6 +259,36 @@ class TextFormatTest(TextFormatBase):
message.c = 123
self.assertEqual('c: 123\n', str(message))
+ def testPrintField(self, message_module):
+ message = message_module.TestAllTypes()
+ field = message.DESCRIPTOR.fields_by_name['optional_float']
+ value = message.optional_float
+ out = text_format.TextWriter(False)
+ text_format.PrintField(field, value, out)
+ self.assertEqual('optional_float: 0.0\n', out.getvalue())
+ out.close()
+ # Test Printer
+ out = text_format.TextWriter(False)
+ printer = text_format._Printer(out)
+ printer.PrintField(field, value)
+ self.assertEqual('optional_float: 0.0\n', out.getvalue())
+ out.close()
+
+ def testPrintFieldValue(self, message_module):
+ message = message_module.TestAllTypes()
+ field = message.DESCRIPTOR.fields_by_name['optional_float']
+ value = message.optional_float
+ out = text_format.TextWriter(False)
+ text_format.PrintFieldValue(field, value, out)
+ self.assertEqual('0.0', out.getvalue())
+ out.close()
+ # Test Printer
+ out = text_format.TextWriter(False)
+ printer = text_format._Printer(out)
+ printer.PrintFieldValue(field, value)
+ self.assertEqual('0.0', out.getvalue())
+ out.close()
+
def testParseAllFields(self, message_module):
message = message_module.TestAllTypes()
test_util.SetAllFields(message)
@@ -260,6 +300,33 @@ class TextFormatTest(TextFormatBase):
if message_module is unittest_pb2:
test_util.ExpectAllFieldsSet(self, message)
+ def testParseAndMergeUtf8(self, message_module):
+ message = message_module.TestAllTypes()
+ test_util.SetAllFields(message)
+ ascii_text = text_format.MessageToString(message)
+ ascii_text = ascii_text.encode('utf-8')
+
+ parsed_message = message_module.TestAllTypes()
+ text_format.Parse(ascii_text, parsed_message)
+ self.assertEqual(message, parsed_message)
+ if message_module is unittest_pb2:
+ test_util.ExpectAllFieldsSet(self, message)
+
+ parsed_message.Clear()
+ text_format.Merge(ascii_text, parsed_message)
+ self.assertEqual(message, parsed_message)
+ if message_module is unittest_pb2:
+ test_util.ExpectAllFieldsSet(self, message)
+
+ if six.PY2:
+ msg2 = message_module.TestAllTypes()
+ text = (u'optional_string: "café"')
+ text_format.Merge(text, msg2)
+ self.assertEqual(msg2.optional_string, u'café')
+ msg2.Clear()
+ text_format.Parse(text, msg2)
+ self.assertEqual(msg2.optional_string, u'café')
+
def testParseExotic(self, message_module):
message = message_module.TestAllTypes()
text = ('repeated_int64: -9223372036854775808\n'
@@ -280,8 +347,7 @@ class TextFormatTest(TextFormatBase):
self.assertEqual(123.456, message.repeated_double[0])
self.assertEqual(1.23e22, message.repeated_double[1])
self.assertEqual(1.23e-18, message.repeated_double[2])
- self.assertEqual(
- '\000\001\a\b\f\n\r\t\v\\\'"', message.repeated_string[0])
+ self.assertEqual('\000\001\a\b\f\n\r\t\v\\\'"', message.repeated_string[0])
self.assertEqual('foocorgegrault', message.repeated_string[1])
self.assertEqual(u'\u00fc\ua71f', message.repeated_string[2])
self.assertEqual(u'\u00fc', message.repeated_string[3])
@@ -304,6 +370,7 @@ class TextFormatTest(TextFormatBase):
def testParseRepeatedScalarShortFormat(self, message_module):
message = message_module.TestAllTypes()
text = ('repeated_int64: [100, 200];\n'
+ 'repeated_int64: []\n'
'repeated_int64: 300,\n'
'repeated_string: ["one", "two"];\n')
text_format.Parse(text, message)
@@ -314,6 +381,18 @@ class TextFormatTest(TextFormatBase):
self.assertEqual(u'one', message.repeated_string[0])
self.assertEqual(u'two', message.repeated_string[1])
+ def testParseRepeatedMessageShortFormat(self, message_module):
+ message = message_module.TestAllTypes()
+ text = ('repeated_nested_message: [{bb: 100}, {bb: 200}],\n'
+ 'repeated_nested_message: {bb: 300}\n'
+ 'repeated_nested_message [{bb: 400}];\n')
+ text_format.Parse(text, message)
+
+ self.assertEqual(100, message.repeated_nested_message[0].bb)
+ self.assertEqual(200, message.repeated_nested_message[1].bb)
+ self.assertEqual(300, message.repeated_nested_message[2].bb)
+ self.assertEqual(400, message.repeated_nested_message[3].bb)
+
def testParseEmptyText(self, message_module):
message = message_module.TestAllTypes()
text = ''
@@ -323,50 +402,39 @@ class TextFormatTest(TextFormatBase):
def testParseInvalidUtf8(self, message_module):
message = message_module.TestAllTypes()
text = 'repeated_string: "\\xc3\\xc3"'
- self.assertRaises(text_format.ParseError, text_format.Parse, text, message)
+ with self.assertRaises(text_format.ParseError) as e:
+ text_format.Parse(text, message)
+ self.assertEqual(e.exception.GetLine(), 1)
+ self.assertEqual(e.exception.GetColumn(), 28)
def testParseSingleWord(self, message_module):
message = message_module.TestAllTypes()
text = 'foo'
- six.assertRaisesRegex(self,
- text_format.ParseError,
- (r'1:1 : Message type "\w+.TestAllTypes" has no field named '
- r'"foo".'),
- text_format.Parse, text, message)
+ six.assertRaisesRegex(self, text_format.ParseError, (
+ r'1:1 : Message type "\w+.TestAllTypes" has no field named '
+ r'"foo".'), text_format.Parse, text, message)
def testParseUnknownField(self, message_module):
message = message_module.TestAllTypes()
text = 'unknown_field: 8\n'
- six.assertRaisesRegex(self,
- text_format.ParseError,
- (r'1:1 : Message type "\w+.TestAllTypes" has no field named '
- r'"unknown_field".'),
- text_format.Parse, text, message)
+ six.assertRaisesRegex(self, text_format.ParseError, (
+ r'1:1 : Message type "\w+.TestAllTypes" has no field named '
+ r'"unknown_field".'), text_format.Parse, text, message)
def testParseBadEnumValue(self, message_module):
message = message_module.TestAllTypes()
text = 'optional_nested_enum: BARR'
- six.assertRaisesRegex(self,
- text_format.ParseError,
- (r'1:23 : Enum type "\w+.TestAllTypes.NestedEnum" '
- r'has no value named BARR.'),
- text_format.Parse, text, message)
-
- message = message_module.TestAllTypes()
- text = 'optional_nested_enum: 100'
- six.assertRaisesRegex(self,
- text_format.ParseError,
- (r'1:23 : Enum type "\w+.TestAllTypes.NestedEnum" '
- r'has no value with number 100.'),
- text_format.Parse, text, message)
+ six.assertRaisesRegex(self, text_format.ParseError,
+ (r'1:23 : Enum type "\w+.TestAllTypes.NestedEnum" '
+ r'has no value named BARR.'), text_format.Parse,
+ text, message)
def testParseBadIntValue(self, message_module):
message = message_module.TestAllTypes()
text = 'optional_int32: bork'
- six.assertRaisesRegex(self,
- text_format.ParseError,
- ('1:17 : Couldn\'t parse integer: bork'),
- text_format.Parse, text, message)
+ six.assertRaisesRegex(self, text_format.ParseError,
+ ('1:17 : Couldn\'t parse integer: bork'),
+ text_format.Parse, text, message)
def testParseStringFieldUnescape(self, message_module):
message = message_module.TestAllTypes()
@@ -376,6 +444,7 @@ class TextFormatTest(TextFormatBase):
repeated_string: "\\\\xf\\\\x62"
repeated_string: "\\\\\xf\\\\\x62"
repeated_string: "\x5cx20"'''
+
text_format.Parse(text, message)
SLASH = '\\'
@@ -390,8 +459,7 @@ class TextFormatTest(TextFormatBase):
def testMergeDuplicateScalars(self, message_module):
message = message_module.TestAllTypes()
- text = ('optional_int32: 42 '
- 'optional_int32: 67')
+ text = ('optional_int32: 42 ' 'optional_int32: 67')
r = text_format.Merge(text, message)
self.assertIs(r, message)
self.assertEqual(67, message.optional_int32)
@@ -411,6 +479,19 @@ class TextFormatTest(TextFormatBase):
text_format.Parse(text_format.MessageToString(m), m2)
self.assertEqual('oneof_uint32', m2.WhichOneof('oneof_field'))
+ def testMergeMultipleOneof(self, message_module):
+ m_string = '\n'.join(['oneof_uint32: 11', 'oneof_string: "foo"'])
+ m2 = message_module.TestAllTypes()
+ text_format.Merge(m_string, m2)
+ self.assertEqual('oneof_string', m2.WhichOneof('oneof_field'))
+
+ def testParseMultipleOneof(self, message_module):
+ m_string = '\n'.join(['oneof_uint32: 11', 'oneof_string: "foo"'])
+ m2 = message_module.TestAllTypes()
+ with self.assertRaisesRegexp(text_format.ParseError,
+ ' is specified along with field '):
+ text_format.Parse(m_string, m2)
+
# These are tests that aren't fundamentally specific to proto2, but are at
# the moment because of differences between the proto2 and proto3 test schemas.
@@ -421,12 +502,13 @@ class OnlyWorksWithProto2RightNowTests(TextFormatBase):
message = unittest_pb2.TestAllTypes()
test_util.SetAllFields(message)
self.CompareToGoldenFile(
- self.RemoveRedundantZeros(
- text_format.MessageToString(message, pointy_brackets=True)),
+ self.RemoveRedundantZeros(text_format.MessageToString(
+ message, pointy_brackets=True)),
'text_format_unittest_data_pointy_oneof.txt')
def testParseGolden(self):
- golden_text = '\n'.join(self.ReadGolden('text_format_unittest_data.txt'))
+ golden_text = '\n'.join(self.ReadGolden(
+ 'text_format_unittest_data_oneof_implemented.txt'))
parsed_message = unittest_pb2.TestAllTypes()
r = text_format.Parse(golden_text, parsed_message)
self.assertIs(r, parsed_message)
@@ -442,34 +524,73 @@ class OnlyWorksWithProto2RightNowTests(TextFormatBase):
self.RemoveRedundantZeros(text_format.MessageToString(message)),
'text_format_unittest_data_oneof_implemented.txt')
- def testPrintAllFieldsPointy(self):
- message = unittest_pb2.TestAllTypes()
- test_util.SetAllFields(message)
- self.CompareToGoldenFile(
- self.RemoveRedundantZeros(
- text_format.MessageToString(message, pointy_brackets=True)),
- 'text_format_unittest_data_pointy_oneof.txt')
-
def testPrintInIndexOrder(self):
message = unittest_pb2.TestFieldOrderings()
- message.my_string = '115'
+ # Fields are listed in index order instead of field number.
+ message.my_string = 'str'
message.my_int = 101
message.my_float = 111
message.optional_nested_message.oo = 0
message.optional_nested_message.bb = 1
+ message.Extensions[unittest_pb2.my_extension_string] = 'ext_str0'
+ # Extensions are listed based on the order of extension number.
+ # Extension number 12.
+ message.Extensions[unittest_pb2.TestExtensionOrderings2.
+ test_ext_orderings2].my_string = 'ext_str2'
+ # Extension number 13.
+ message.Extensions[unittest_pb2.TestExtensionOrderings1.
+ test_ext_orderings1].my_string = 'ext_str1'
+ # Extension number 14.
+ message.Extensions[
+ unittest_pb2.TestExtensionOrderings2.TestExtensionOrderings3.
+ test_ext_orderings3].my_string = 'ext_str3'
+
+ # Print in index order.
self.CompareToGoldenText(
- self.RemoveRedundantZeros(text_format.MessageToString(
- message, use_index_order=True)),
- 'my_string: \"115\"\nmy_int: 101\nmy_float: 111\n'
- 'optional_nested_message {\n oo: 0\n bb: 1\n}\n')
+ self.RemoveRedundantZeros(
+ text_format.MessageToString(message, use_index_order=True)),
+ 'my_string: "str"\n'
+ 'my_int: 101\n'
+ 'my_float: 111\n'
+ 'optional_nested_message {\n'
+ ' oo: 0\n'
+ ' bb: 1\n'
+ '}\n'
+ '[protobuf_unittest.TestExtensionOrderings2.test_ext_orderings2] {\n'
+ ' my_string: "ext_str2"\n'
+ '}\n'
+ '[protobuf_unittest.TestExtensionOrderings1.test_ext_orderings1] {\n'
+ ' my_string: "ext_str1"\n'
+ '}\n'
+ '[protobuf_unittest.TestExtensionOrderings2.TestExtensionOrderings3'
+ '.test_ext_orderings3] {\n'
+ ' my_string: "ext_str3"\n'
+ '}\n'
+ '[protobuf_unittest.my_extension_string]: "ext_str0"\n')
+ # By default, print in field number order.
self.CompareToGoldenText(
- self.RemoveRedundantZeros(text_format.MessageToString(
- message)),
- 'my_int: 101\nmy_string: \"115\"\nmy_float: 111\n'
- 'optional_nested_message {\n bb: 1\n oo: 0\n}\n')
+ self.RemoveRedundantZeros(text_format.MessageToString(message)),
+ 'my_int: 101\n'
+ 'my_string: "str"\n'
+ '[protobuf_unittest.TestExtensionOrderings2.test_ext_orderings2] {\n'
+ ' my_string: "ext_str2"\n'
+ '}\n'
+ '[protobuf_unittest.TestExtensionOrderings1.test_ext_orderings1] {\n'
+ ' my_string: "ext_str1"\n'
+ '}\n'
+ '[protobuf_unittest.TestExtensionOrderings2.TestExtensionOrderings3'
+ '.test_ext_orderings3] {\n'
+ ' my_string: "ext_str3"\n'
+ '}\n'
+ '[protobuf_unittest.my_extension_string]: "ext_str0"\n'
+ 'my_float: 111\n'
+ 'optional_nested_message {\n'
+ ' bb: 1\n'
+ ' oo: 0\n'
+ '}\n')
def testMergeLinesGolden(self):
- opened = self.ReadGolden('text_format_unittest_data.txt')
+ opened = self.ReadGolden('text_format_unittest_data_oneof_implemented.txt')
parsed_message = unittest_pb2.TestAllTypes()
r = text_format.MergeLines(opened, parsed_message)
self.assertIs(r, parsed_message)
@@ -479,7 +600,7 @@ class OnlyWorksWithProto2RightNowTests(TextFormatBase):
self.assertEqual(message, parsed_message)
def testParseLinesGolden(self):
- opened = self.ReadGolden('text_format_unittest_data.txt')
+ opened = self.ReadGolden('text_format_unittest_data_oneof_implemented.txt')
parsed_message = unittest_pb2.TestAllTypes()
r = text_format.ParseLines(opened, parsed_message)
self.assertIs(r, parsed_message)
@@ -495,14 +616,13 @@ class OnlyWorksWithProto2RightNowTests(TextFormatBase):
message.map_int64_int64[-2**33] = -2**34
message.map_uint32_uint32[123] = 456
message.map_uint64_uint64[2**33] = 2**34
- message.map_string_string["abc"] = "123"
+ message.map_string_string['abc'] = '123'
message.map_int32_foreign_message[111].c = 5
# Maps are serialized to text format using their underlying repeated
# representation.
self.CompareToGoldenText(
- text_format.MessageToString(message),
- 'map_int32_int32 {\n'
+ text_format.MessageToString(message), 'map_int32_int32 {\n'
' key: -123\n'
' value: -456\n'
'}\n'
@@ -535,29 +655,24 @@ class OnlyWorksWithProto2RightNowTests(TextFormatBase):
message.map_string_string[letter] = 'dummy'
for letter in reversed(string.ascii_uppercase[0:13]):
message.map_string_string[letter] = 'dummy'
- golden = ''.join((
- 'map_string_string {\n key: "%c"\n value: "dummy"\n}\n' % (letter,)
- for letter in string.ascii_uppercase))
+ golden = ''.join(('map_string_string {\n key: "%c"\n value: "dummy"\n}\n'
+ % (letter,) for letter in string.ascii_uppercase))
self.CompareToGoldenText(text_format.MessageToString(message), golden)
- def testMapOrderSemantics(self):
- golden_lines = self.ReadGolden('map_test_data.txt')
- # The C++ implementation emits defaulted-value fields, while the Python
- # implementation does not. Adjusting for this is awkward, but it is
- # valuable to test against a common golden file.
- line_blacklist = (' key: 0\n',
- ' value: 0\n',
- ' key: false\n',
- ' value: false\n')
- golden_lines = [line for line in golden_lines if line not in line_blacklist]
+ # TODO(teboring): In c/137553523, not serializing default value for map entry
+ # message has been fixed. This test needs to be disabled in order to submit
+ # that cl. Add this back when c/137553523 has been submitted.
+ # def testMapOrderSemantics(self):
+ # golden_lines = self.ReadGolden('map_test_data.txt')
- message = map_unittest_pb2.TestMap()
- text_format.ParseLines(golden_lines, message)
- candidate = text_format.MessageToString(message)
- # The Python implementation emits "1.0" for the double value that the C++
- # implementation emits as "1".
- candidate = candidate.replace('1.0', '1', 2)
- self.assertMultiLineEqual(candidate, ''.join(golden_lines))
+ # message = map_unittest_pb2.TestMap()
+ # text_format.ParseLines(golden_lines, message)
+ # candidate = text_format.MessageToString(message)
+ # # The Python implementation emits "1.0" for the double value that the C++
+ # # implementation emits as "1".
+ # candidate = candidate.replace('1.0', '1', 2)
+ # candidate = candidate.replace('0.0', '0', 2)
+ # self.assertMultiLineEqual(candidate, ''.join(golden_lines))
# Tests of proto2-only features (MessageSet, extensions, etc.).
@@ -570,8 +685,7 @@ class Proto2Tests(TextFormatBase):
message.message_set.Extensions[ext1].i = 23
message.message_set.Extensions[ext2].str = 'foo'
self.CompareToGoldenText(
- text_format.MessageToString(message),
- 'message_set {\n'
+ text_format.MessageToString(message), 'message_set {\n'
' [protobuf_unittest.TestMessageSetExtension1] {\n'
' i: 23\n'
' }\n'
@@ -589,6 +703,24 @@ class Proto2Tests(TextFormatBase):
' text: \"bar\"\n'
'}\n')
+ def testPrintMessageSetByFieldNumber(self):
+ out = text_format.TextWriter(False)
+ message = unittest_mset_pb2.TestMessageSetContainer()
+ ext1 = unittest_mset_pb2.TestMessageSetExtension1.message_set_extension
+ ext2 = unittest_mset_pb2.TestMessageSetExtension2.message_set_extension
+ message.message_set.Extensions[ext1].i = 23
+ message.message_set.Extensions[ext2].str = 'foo'
+ text_format.PrintMessage(message, out, use_field_number=True)
+ self.CompareToGoldenText(out.getvalue(), '1 {\n'
+ ' 1545008 {\n'
+ ' 15: 23\n'
+ ' }\n'
+ ' 1547769 {\n'
+ ' 25: \"foo\"\n'
+ ' }\n'
+ '}\n')
+ out.close()
+
def testPrintMessageSetAsOneLine(self):
message = unittest_mset_pb2.TestMessageSetContainer()
ext1 = unittest_mset_pb2.TestMessageSetExtension1.message_set_extension
@@ -608,8 +740,7 @@ class Proto2Tests(TextFormatBase):
def testParseMessageSet(self):
message = unittest_pb2.TestAllTypes()
- text = ('repeated_uint64: 1\n'
- 'repeated_uint64: 2\n')
+ text = ('repeated_uint64: 1\n' 'repeated_uint64: 2\n')
text_format.Parse(text, message)
self.assertEqual(1, message.repeated_uint64[0])
self.assertEqual(2, message.repeated_uint64[1])
@@ -629,6 +760,62 @@ class Proto2Tests(TextFormatBase):
self.assertEqual(23, message.message_set.Extensions[ext1].i)
self.assertEqual('foo', message.message_set.Extensions[ext2].str)
+ def testExtensionInsideAnyMessage(self):
+ message = test_extend_any.TestAny()
+ text = ('value {\n'
+ ' [type.googleapis.com/google.protobuf.internal.TestAny] {\n'
+ ' [google.protobuf.internal.TestAnyExtension1.extension1] {\n'
+ ' i: 10\n'
+ ' }\n'
+ ' }\n'
+ '}\n')
+ text_format.Merge(text, message, descriptor_pool=descriptor_pool.Default())
+ self.CompareToGoldenText(
+ text_format.MessageToString(
+ message, descriptor_pool=descriptor_pool.Default()),
+ text)
+
+ def testParseMessageByFieldNumber(self):
+ message = unittest_pb2.TestAllTypes()
+ text = ('34: 1\n' 'repeated_uint64: 2\n')
+ text_format.Parse(text, message, allow_field_number=True)
+ self.assertEqual(1, message.repeated_uint64[0])
+ self.assertEqual(2, message.repeated_uint64[1])
+
+ message = unittest_mset_pb2.TestMessageSetContainer()
+ text = ('1 {\n'
+ ' 1545008 {\n'
+ ' 15: 23\n'
+ ' }\n'
+ ' 1547769 {\n'
+ ' 25: \"foo\"\n'
+ ' }\n'
+ '}\n')
+ text_format.Parse(text, message, allow_field_number=True)
+ ext1 = unittest_mset_pb2.TestMessageSetExtension1.message_set_extension
+ ext2 = unittest_mset_pb2.TestMessageSetExtension2.message_set_extension
+ self.assertEqual(23, message.message_set.Extensions[ext1].i)
+ self.assertEqual('foo', message.message_set.Extensions[ext2].str)
+
+ # Can't parse field number without set allow_field_number=True.
+ message = unittest_pb2.TestAllTypes()
+ text = '34:1\n'
+ six.assertRaisesRegex(self, text_format.ParseError, (
+ r'1:1 : Message type "\w+.TestAllTypes" has no field named '
+ r'"34".'), text_format.Parse, text, message)
+
+ # Can't parse if field number is not found.
+ text = '1234:1\n'
+ six.assertRaisesRegex(
+ self,
+ text_format.ParseError,
+ (r'1:1 : Message type "\w+.TestAllTypes" has no field named '
+ r'"1234".'),
+ text_format.Parse,
+ text,
+ message,
+ allow_field_number=True)
+
def testPrintAllExtensions(self):
message = unittest_pb2.TestAllExtensions()
test_util.SetAllExtensions(message)
@@ -669,15 +856,17 @@ class Proto2Tests(TextFormatBase):
text = ('message_set {\n'
' [unknown_extension] {\n'
' i: 23\n'
+ ' bin: "\xe0"'
' [nested_unknown_ext]: {\n'
' i: 23\n'
+ ' x: x\n'
' test: "test_string"\n'
' floaty_float: -0.315\n'
' num: -inf\n'
' multiline_str: "abc"\n'
' "def"\n'
' "xyz."\n'
- ' [nested_unknown_ext]: <\n'
+ ' [nested_unknown_ext.ext]: <\n'
' i: 23\n'
' i: 24\n'
' pointfloat: .3\n'
@@ -689,6 +878,10 @@ class Proto2Tests(TextFormatBase):
' }\n'
' }\n'
' [unknown_extension]: 5\n'
+ ' [unknown_extension_with_number_field] {\n'
+ ' 1: "some_field"\n'
+ ' 2: -0.451\n'
+ ' }\n'
'}\n')
text_format.Parse(text, message, allow_unknown_extension=True)
golden = 'message_set {\n}\n'
@@ -704,7 +897,9 @@ class Proto2Tests(TextFormatBase):
six.assertRaisesRegex(self,
text_format.ParseError,
'Invalid field value: }',
- text_format.Parse, malformed, message,
+ text_format.Parse,
+ malformed,
+ message,
allow_unknown_extension=True)
message = unittest_mset_pb2.TestMessageSetContainer()
@@ -716,7 +911,9 @@ class Proto2Tests(TextFormatBase):
six.assertRaisesRegex(self,
text_format.ParseError,
'Invalid field value: "',
- text_format.Parse, malformed, message,
+ text_format.Parse,
+ malformed,
+ message,
allow_unknown_extension=True)
message = unittest_mset_pb2.TestMessageSetContainer()
@@ -728,7 +925,9 @@ class Proto2Tests(TextFormatBase):
six.assertRaisesRegex(self,
text_format.ParseError,
'Invalid field value: "',
- text_format.Parse, malformed, message,
+ text_format.Parse,
+ malformed,
+ message,
allow_unknown_extension=True)
message = unittest_mset_pb2.TestMessageSetContainer()
@@ -740,24 +939,27 @@ class Proto2Tests(TextFormatBase):
six.assertRaisesRegex(self,
text_format.ParseError,
'5:1 : Expected ">".',
- text_format.Parse, malformed, message,
+ text_format.Parse,
+ malformed,
+ message,
allow_unknown_extension=True)
# Don't allow unknown fields with allow_unknown_extension=True.
message = unittest_mset_pb2.TestMessageSetContainer()
malformed = ('message_set {\n'
' unknown_field: true\n'
- ' \n' # Missing '>' here.
'}\n')
six.assertRaisesRegex(self,
text_format.ParseError,
('2:3 : Message type '
'"proto2_wireformat_unittest.TestMessageSet" has no'
' field named "unknown_field".'),
- text_format.Parse, malformed, message,
+ text_format.Parse,
+ malformed,
+ message,
allow_unknown_extension=True)
- # Parse known extension correcty.
+ # Parse known extension correctly.
message = unittest_mset_pb2.TestMessageSetContainer()
text = ('message_set {\n'
' [protobuf_unittest.TestMessageSetExtension1] {\n'
@@ -773,70 +975,87 @@ class Proto2Tests(TextFormatBase):
self.assertEqual(23, message.message_set.Extensions[ext1].i)
self.assertEqual('foo', message.message_set.Extensions[ext2].str)
+ def testParseBadIdentifier(self):
+ message = unittest_pb2.TestAllTypes()
+ text = ('optional_nested_message { "bb": 1 }')
+ with self.assertRaises(text_format.ParseError) as e:
+ text_format.Parse(text, message)
+ self.assertEqual(str(e.exception),
+ '1:27 : Expected identifier or number, got "bb".')
+
def testParseBadExtension(self):
message = unittest_pb2.TestAllExtensions()
text = '[unknown_extension]: 8\n'
- six.assertRaisesRegex(self,
- text_format.ParseError,
- '1:2 : Extension "unknown_extension" not registered.',
- text_format.Parse, text, message)
+ six.assertRaisesRegex(self, text_format.ParseError,
+ '1:2 : Extension "unknown_extension" not registered.',
+ text_format.Parse, text, message)
message = unittest_pb2.TestAllTypes()
- six.assertRaisesRegex(self,
- text_format.ParseError,
- ('1:2 : Message type "protobuf_unittest.TestAllTypes" does not have '
- 'extensions.'),
- text_format.Parse, text, message)
+ six.assertRaisesRegex(self, text_format.ParseError, (
+ '1:2 : Message type "protobuf_unittest.TestAllTypes" does not have '
+ 'extensions.'), text_format.Parse, text, message)
+
+ def testParseNumericUnknownEnum(self):
+ message = unittest_pb2.TestAllTypes()
+ text = 'optional_nested_enum: 100'
+ six.assertRaisesRegex(self, text_format.ParseError,
+ (r'1:23 : Enum type "\w+.TestAllTypes.NestedEnum" '
+ r'has no value with number 100.'), text_format.Parse,
+ text, message)
def testMergeDuplicateExtensionScalars(self):
message = unittest_pb2.TestAllExtensions()
text = ('[protobuf_unittest.optional_int32_extension]: 42 '
'[protobuf_unittest.optional_int32_extension]: 67')
text_format.Merge(text, message)
- self.assertEqual(
- 67,
- message.Extensions[unittest_pb2.optional_int32_extension])
+ self.assertEqual(67,
+ message.Extensions[unittest_pb2.optional_int32_extension])
def testParseDuplicateExtensionScalars(self):
message = unittest_pb2.TestAllExtensions()
text = ('[protobuf_unittest.optional_int32_extension]: 42 '
'[protobuf_unittest.optional_int32_extension]: 67')
- six.assertRaisesRegex(self,
- text_format.ParseError,
- ('1:96 : Message type "protobuf_unittest.TestAllExtensions" '
- 'should not have multiple '
- '"protobuf_unittest.optional_int32_extension" extensions.'),
- text_format.Parse, text, message)
+ six.assertRaisesRegex(self, text_format.ParseError, (
+ '1:96 : Message type "protobuf_unittest.TestAllExtensions" '
+ 'should not have multiple '
+ '"protobuf_unittest.optional_int32_extension" extensions.'),
+ text_format.Parse, text, message)
- def testParseDuplicateNestedMessageScalars(self):
+ def testParseDuplicateMessages(self):
message = unittest_pb2.TestAllTypes()
text = ('optional_nested_message { bb: 1 } '
'optional_nested_message { bb: 2 }')
- six.assertRaisesRegex(self,
- text_format.ParseError,
- ('1:65 : Message type "protobuf_unittest.TestAllTypes.NestedMessage" '
- 'should not have multiple "bb" fields.'),
- text_format.Parse, text, message)
+ six.assertRaisesRegex(self, text_format.ParseError, (
+ '1:59 : Message type "protobuf_unittest.TestAllTypes" '
+ 'should not have multiple "optional_nested_message" fields.'),
+ text_format.Parse, text,
+ message)
+
+ def testParseDuplicateExtensionMessages(self):
+ message = unittest_pb2.TestAllExtensions()
+ text = ('[protobuf_unittest.optional_nested_message_extension]: {} '
+ '[protobuf_unittest.optional_nested_message_extension]: {}')
+ six.assertRaisesRegex(self, text_format.ParseError, (
+ '1:114 : Message type "protobuf_unittest.TestAllExtensions" '
+ 'should not have multiple '
+ '"protobuf_unittest.optional_nested_message_extension" extensions.'),
+ text_format.Parse, text, message)
def testParseDuplicateScalars(self):
message = unittest_pb2.TestAllTypes()
- text = ('optional_int32: 42 '
- 'optional_int32: 67')
- six.assertRaisesRegex(self,
- text_format.ParseError,
- ('1:36 : Message type "protobuf_unittest.TestAllTypes" should not '
- 'have multiple "optional_int32" fields.'),
- text_format.Parse, text, message)
+ text = ('optional_int32: 42 ' 'optional_int32: 67')
+ six.assertRaisesRegex(self, text_format.ParseError, (
+ '1:36 : Message type "protobuf_unittest.TestAllTypes" should not '
+ 'have multiple "optional_int32" fields.'), text_format.Parse, text,
+ message)
def testParseGroupNotClosed(self):
message = unittest_pb2.TestAllTypes()
text = 'RepeatedGroup: <'
- six.assertRaisesRegex(self,
- text_format.ParseError, '1:16 : Expected ">".',
- text_format.Parse, text, message)
+ six.assertRaisesRegex(self, text_format.ParseError, '1:16 : Expected ">".',
+ text_format.Parse, text, message)
text = 'RepeatedGroup: {'
- six.assertRaisesRegex(self,
- text_format.ParseError, '1:16 : Expected "}".',
- text_format.Parse, text, message)
+ six.assertRaisesRegex(self, text_format.ParseError, '1:16 : Expected "}".',
+ text_format.Parse, text, message)
def testParseEmptyGroup(self):
message = unittest_pb2.TestAllTypes()
@@ -887,10 +1106,212 @@ class Proto2Tests(TextFormatBase):
self.assertEqual(-2**34, message.map_int64_int64[-2**33])
self.assertEqual(456, message.map_uint32_uint32[123])
self.assertEqual(2**34, message.map_uint64_uint64[2**33])
- self.assertEqual("123", message.map_string_string["abc"])
+ self.assertEqual('123', message.map_string_string['abc'])
self.assertEqual(5, message.map_int32_foreign_message[111].c)
+class Proto3Tests(unittest.TestCase):
+
+ def testPrintMessageExpandAny(self):
+ packed_message = unittest_pb2.OneString()
+ packed_message.data = 'string'
+ message = any_test_pb2.TestAny()
+ message.any_value.Pack(packed_message)
+ self.assertEqual(
+ text_format.MessageToString(message,
+ descriptor_pool=descriptor_pool.Default()),
+ 'any_value {\n'
+ ' [type.googleapis.com/protobuf_unittest.OneString] {\n'
+ ' data: "string"\n'
+ ' }\n'
+ '}\n')
+
+ def testTopAnyMessage(self):
+ packed_msg = unittest_pb2.OneString()
+ msg = any_pb2.Any()
+ msg.Pack(packed_msg)
+ text = text_format.MessageToString(msg)
+ other_msg = text_format.Parse(text, any_pb2.Any())
+ self.assertEqual(msg, other_msg)
+
+ def testPrintMessageExpandAnyRepeated(self):
+ packed_message = unittest_pb2.OneString()
+ message = any_test_pb2.TestAny()
+ packed_message.data = 'string0'
+ message.repeated_any_value.add().Pack(packed_message)
+ packed_message.data = 'string1'
+ message.repeated_any_value.add().Pack(packed_message)
+ self.assertEqual(
+ text_format.MessageToString(message),
+ 'repeated_any_value {\n'
+ ' [type.googleapis.com/protobuf_unittest.OneString] {\n'
+ ' data: "string0"\n'
+ ' }\n'
+ '}\n'
+ 'repeated_any_value {\n'
+ ' [type.googleapis.com/protobuf_unittest.OneString] {\n'
+ ' data: "string1"\n'
+ ' }\n'
+ '}\n')
+
+ def testPrintMessageExpandAnyDescriptorPoolMissingType(self):
+ packed_message = unittest_pb2.OneString()
+ packed_message.data = 'string'
+ message = any_test_pb2.TestAny()
+ message.any_value.Pack(packed_message)
+ empty_pool = descriptor_pool.DescriptorPool()
+ self.assertEqual(
+ text_format.MessageToString(message, descriptor_pool=empty_pool),
+ 'any_value {\n'
+ ' type_url: "type.googleapis.com/protobuf_unittest.OneString"\n'
+ ' value: "\\n\\006string"\n'
+ '}\n')
+
+ def testPrintMessageExpandAnyPointyBrackets(self):
+ packed_message = unittest_pb2.OneString()
+ packed_message.data = 'string'
+ message = any_test_pb2.TestAny()
+ message.any_value.Pack(packed_message)
+ self.assertEqual(
+ text_format.MessageToString(message,
+ pointy_brackets=True),
+ 'any_value <\n'
+ ' [type.googleapis.com/protobuf_unittest.OneString] <\n'
+ ' data: "string"\n'
+ ' >\n'
+ '>\n')
+
+ def testPrintMessageExpandAnyAsOneLine(self):
+ packed_message = unittest_pb2.OneString()
+ packed_message.data = 'string'
+ message = any_test_pb2.TestAny()
+ message.any_value.Pack(packed_message)
+ self.assertEqual(
+ text_format.MessageToString(message,
+ as_one_line=True),
+ 'any_value {'
+ ' [type.googleapis.com/protobuf_unittest.OneString]'
+ ' { data: "string" } '
+ '}')
+
+ def testPrintMessageExpandAnyAsOneLinePointyBrackets(self):
+ packed_message = unittest_pb2.OneString()
+ packed_message.data = 'string'
+ message = any_test_pb2.TestAny()
+ message.any_value.Pack(packed_message)
+ self.assertEqual(
+ text_format.MessageToString(message,
+ as_one_line=True,
+ pointy_brackets=True,
+ descriptor_pool=descriptor_pool.Default()),
+ 'any_value <'
+ ' [type.googleapis.com/protobuf_unittest.OneString]'
+ ' < data: "string" > '
+ '>')
+
+ def testUnknownEnums(self):
+ message = unittest_proto3_arena_pb2.TestAllTypes()
+ message2 = unittest_proto3_arena_pb2.TestAllTypes()
+ message.optional_nested_enum = 999
+ text_string = text_format.MessageToString(message)
+ text_format.Parse(text_string, message2)
+ self.assertEqual(999, message2.optional_nested_enum)
+
+ def testMergeExpandedAny(self):
+ message = any_test_pb2.TestAny()
+ text = ('any_value {\n'
+ ' [type.googleapis.com/protobuf_unittest.OneString] {\n'
+ ' data: "string"\n'
+ ' }\n'
+ '}\n')
+ text_format.Merge(text, message)
+ packed_message = unittest_pb2.OneString()
+ message.any_value.Unpack(packed_message)
+ self.assertEqual('string', packed_message.data)
+ message.Clear()
+ text_format.Parse(text, message)
+ packed_message = unittest_pb2.OneString()
+ message.any_value.Unpack(packed_message)
+ self.assertEqual('string', packed_message.data)
+
+ def testMergeExpandedAnyRepeated(self):
+ message = any_test_pb2.TestAny()
+ text = ('repeated_any_value {\n'
+ ' [type.googleapis.com/protobuf_unittest.OneString] {\n'
+ ' data: "string0"\n'
+ ' }\n'
+ '}\n'
+ 'repeated_any_value {\n'
+ ' [type.googleapis.com/protobuf_unittest.OneString] {\n'
+ ' data: "string1"\n'
+ ' }\n'
+ '}\n')
+ text_format.Merge(text, message)
+ packed_message = unittest_pb2.OneString()
+ message.repeated_any_value[0].Unpack(packed_message)
+ self.assertEqual('string0', packed_message.data)
+ message.repeated_any_value[1].Unpack(packed_message)
+ self.assertEqual('string1', packed_message.data)
+
+ def testMergeExpandedAnyPointyBrackets(self):
+ message = any_test_pb2.TestAny()
+ text = ('any_value {\n'
+ ' [type.googleapis.com/protobuf_unittest.OneString] <\n'
+ ' data: "string"\n'
+ ' >\n'
+ '}\n')
+ text_format.Merge(text, message)
+ packed_message = unittest_pb2.OneString()
+ message.any_value.Unpack(packed_message)
+ self.assertEqual('string', packed_message.data)
+
+ def testMergeAlternativeUrl(self):
+ message = any_test_pb2.TestAny()
+ text = ('any_value {\n'
+ ' [type.otherapi.com/protobuf_unittest.OneString] {\n'
+ ' data: "string"\n'
+ ' }\n'
+ '}\n')
+ text_format.Merge(text, message)
+ packed_message = unittest_pb2.OneString()
+ self.assertEqual('type.otherapi.com/protobuf_unittest.OneString',
+ message.any_value.type_url)
+
+ def testMergeExpandedAnyDescriptorPoolMissingType(self):
+ message = any_test_pb2.TestAny()
+ text = ('any_value {\n'
+ ' [type.googleapis.com/protobuf_unittest.OneString] {\n'
+ ' data: "string"\n'
+ ' }\n'
+ '}\n')
+ with self.assertRaises(text_format.ParseError) as e:
+ empty_pool = descriptor_pool.DescriptorPool()
+ text_format.Merge(text, message, descriptor_pool=empty_pool)
+ self.assertEqual(
+ str(e.exception),
+ 'Type protobuf_unittest.OneString not found in descriptor pool')
+
+ def testMergeUnexpandedAny(self):
+ text = ('any_value {\n'
+ ' type_url: "type.googleapis.com/protobuf_unittest.OneString"\n'
+ ' value: "\\n\\006string"\n'
+ '}\n')
+ message = any_test_pb2.TestAny()
+ text_format.Merge(text, message)
+ packed_message = unittest_pb2.OneString()
+ message.any_value.Unpack(packed_message)
+ self.assertEqual('string', packed_message.data)
+
+ def testMergeMissingAnyEndToken(self):
+ message = any_test_pb2.TestAny()
+ text = ('any_value {\n'
+ ' [type.googleapis.com/protobuf_unittest.OneString] {\n'
+ ' data: "string"\n')
+ with self.assertRaises(text_format.ParseError) as e:
+ text_format.Merge(text, message)
+ self.assertEqual(str(e.exception), '3:11 : Expected "}".')
+
+
class TokenizerTest(unittest.TestCase):
def testSimpleTokenCases(self):
@@ -900,140 +1321,336 @@ class TokenizerTest(unittest.TestCase):
'ID7 : "aa\\"bb"\n\n\n\n ID8: {A:inf B:-inf C:true D:false}\n'
'ID9: 22 ID10: -111111111111111111 ID11: -22\n'
'ID12: 2222222222222222222 ID13: 1.23456f ID14: 1.2e+2f '
- 'false_bool: 0 true_BOOL:t \n true_bool1: 1 false_BOOL1:f ')
- tokenizer = text_format._Tokenizer(text.splitlines())
- methods = [(tokenizer.ConsumeIdentifier, 'identifier1'),
- ':',
+ 'false_bool: 0 true_BOOL:t \n true_bool1: 1 false_BOOL1:f '
+ 'False_bool: False True_bool: True X:iNf Y:-inF Z:nAN')
+ tokenizer = text_format.Tokenizer(text.splitlines())
+ methods = [(tokenizer.ConsumeIdentifier, 'identifier1'), ':',
(tokenizer.ConsumeString, 'string1'),
- (tokenizer.ConsumeIdentifier, 'identifier2'),
- ':',
- (tokenizer.ConsumeInt32, 123),
- (tokenizer.ConsumeIdentifier, 'identifier3'),
- ':',
+ (tokenizer.ConsumeIdentifier, 'identifier2'), ':',
+ (tokenizer.ConsumeInteger, 123),
+ (tokenizer.ConsumeIdentifier, 'identifier3'), ':',
(tokenizer.ConsumeString, 'string'),
- (tokenizer.ConsumeIdentifier, 'identifiER_4'),
- ':',
+ (tokenizer.ConsumeIdentifier, 'identifiER_4'), ':',
(tokenizer.ConsumeFloat, 1.1e+2),
- (tokenizer.ConsumeIdentifier, 'ID5'),
- ':',
+ (tokenizer.ConsumeIdentifier, 'ID5'), ':',
(tokenizer.ConsumeFloat, -0.23),
- (tokenizer.ConsumeIdentifier, 'ID6'),
- ':',
+ (tokenizer.ConsumeIdentifier, 'ID6'), ':',
(tokenizer.ConsumeString, 'aaaa\'bbbb'),
- (tokenizer.ConsumeIdentifier, 'ID7'),
- ':',
+ (tokenizer.ConsumeIdentifier, 'ID7'), ':',
(tokenizer.ConsumeString, 'aa\"bb'),
- (tokenizer.ConsumeIdentifier, 'ID8'),
- ':',
- '{',
- (tokenizer.ConsumeIdentifier, 'A'),
- ':',
+ (tokenizer.ConsumeIdentifier, 'ID8'), ':', '{',
+ (tokenizer.ConsumeIdentifier, 'A'), ':',
(tokenizer.ConsumeFloat, float('inf')),
- (tokenizer.ConsumeIdentifier, 'B'),
- ':',
+ (tokenizer.ConsumeIdentifier, 'B'), ':',
(tokenizer.ConsumeFloat, -float('inf')),
- (tokenizer.ConsumeIdentifier, 'C'),
- ':',
+ (tokenizer.ConsumeIdentifier, 'C'), ':',
(tokenizer.ConsumeBool, True),
- (tokenizer.ConsumeIdentifier, 'D'),
- ':',
- (tokenizer.ConsumeBool, False),
- '}',
- (tokenizer.ConsumeIdentifier, 'ID9'),
- ':',
- (tokenizer.ConsumeUint32, 22),
- (tokenizer.ConsumeIdentifier, 'ID10'),
- ':',
- (tokenizer.ConsumeInt64, -111111111111111111),
- (tokenizer.ConsumeIdentifier, 'ID11'),
- ':',
- (tokenizer.ConsumeInt32, -22),
- (tokenizer.ConsumeIdentifier, 'ID12'),
- ':',
- (tokenizer.ConsumeUint64, 2222222222222222222),
- (tokenizer.ConsumeIdentifier, 'ID13'),
- ':',
+ (tokenizer.ConsumeIdentifier, 'D'), ':',
+ (tokenizer.ConsumeBool, False), '}',
+ (tokenizer.ConsumeIdentifier, 'ID9'), ':',
+ (tokenizer.ConsumeInteger, 22),
+ (tokenizer.ConsumeIdentifier, 'ID10'), ':',
+ (tokenizer.ConsumeInteger, -111111111111111111),
+ (tokenizer.ConsumeIdentifier, 'ID11'), ':',
+ (tokenizer.ConsumeInteger, -22),
+ (tokenizer.ConsumeIdentifier, 'ID12'), ':',
+ (tokenizer.ConsumeInteger, 2222222222222222222),
+ (tokenizer.ConsumeIdentifier, 'ID13'), ':',
(tokenizer.ConsumeFloat, 1.23456),
- (tokenizer.ConsumeIdentifier, 'ID14'),
- ':',
+ (tokenizer.ConsumeIdentifier, 'ID14'), ':',
(tokenizer.ConsumeFloat, 1.2e+2),
- (tokenizer.ConsumeIdentifier, 'false_bool'),
- ':',
+ (tokenizer.ConsumeIdentifier, 'false_bool'), ':',
(tokenizer.ConsumeBool, False),
- (tokenizer.ConsumeIdentifier, 'true_BOOL'),
- ':',
+ (tokenizer.ConsumeIdentifier, 'true_BOOL'), ':',
(tokenizer.ConsumeBool, True),
- (tokenizer.ConsumeIdentifier, 'true_bool1'),
- ':',
+ (tokenizer.ConsumeIdentifier, 'true_bool1'), ':',
(tokenizer.ConsumeBool, True),
- (tokenizer.ConsumeIdentifier, 'false_BOOL1'),
- ':',
- (tokenizer.ConsumeBool, False)]
+ (tokenizer.ConsumeIdentifier, 'false_BOOL1'), ':',
+ (tokenizer.ConsumeBool, False),
+ (tokenizer.ConsumeIdentifier, 'False_bool'), ':',
+ (tokenizer.ConsumeBool, False),
+ (tokenizer.ConsumeIdentifier, 'True_bool'), ':',
+ (tokenizer.ConsumeBool, True),
+ (tokenizer.ConsumeIdentifier, 'X'), ':',
+ (tokenizer.ConsumeFloat, float('inf')),
+ (tokenizer.ConsumeIdentifier, 'Y'), ':',
+ (tokenizer.ConsumeFloat, float('-inf')),
+ (tokenizer.ConsumeIdentifier, 'Z'), ':',
+ (tokenizer.ConsumeFloat, float('nan'))]
i = 0
while not tokenizer.AtEnd():
m = methods[i]
- if type(m) == str:
+ if isinstance(m, str):
token = tokenizer.token
self.assertEqual(token, m)
tokenizer.NextToken()
+ elif isinstance(m[1], float) and math.isnan(m[1]):
+ self.assertTrue(math.isnan(m[0]()))
else:
self.assertEqual(m[1], m[0]())
i += 1
- def testConsumeIntegers(self):
+ def testConsumeAbstractIntegers(self):
# This test only tests the failures in the integer parsing methods as well
# as the '0' special cases.
int64_max = (1 << 63) - 1
uint32_max = (1 << 32) - 1
text = '-1 %d %d' % (uint32_max + 1, int64_max + 1)
- tokenizer = text_format._Tokenizer(text.splitlines())
- self.assertRaises(text_format.ParseError, tokenizer.ConsumeUint32)
- self.assertRaises(text_format.ParseError, tokenizer.ConsumeUint64)
- self.assertEqual(-1, tokenizer.ConsumeInt32())
+ tokenizer = text_format.Tokenizer(text.splitlines())
+ self.assertEqual(-1, tokenizer.ConsumeInteger())
+
+ self.assertEqual(uint32_max + 1, tokenizer.ConsumeInteger())
- self.assertRaises(text_format.ParseError, tokenizer.ConsumeUint32)
- self.assertRaises(text_format.ParseError, tokenizer.ConsumeInt32)
- self.assertEqual(uint32_max + 1, tokenizer.ConsumeInt64())
+ self.assertEqual(int64_max + 1, tokenizer.ConsumeInteger())
+ self.assertTrue(tokenizer.AtEnd())
- self.assertRaises(text_format.ParseError, tokenizer.ConsumeInt64)
- self.assertEqual(int64_max + 1, tokenizer.ConsumeUint64())
+ text = '-0 0 0 1.2'
+ tokenizer = text_format.Tokenizer(text.splitlines())
+ self.assertEqual(0, tokenizer.ConsumeInteger())
+ self.assertEqual(0, tokenizer.ConsumeInteger())
+ self.assertEqual(True, tokenizer.TryConsumeInteger())
+ self.assertEqual(False, tokenizer.TryConsumeInteger())
+ with self.assertRaises(text_format.ParseError):
+ tokenizer.ConsumeInteger()
+ self.assertEqual(1.2, tokenizer.ConsumeFloat())
+ self.assertTrue(tokenizer.AtEnd())
+
+ def testConsumeIntegers(self):
+ # This test only tests the failures in the integer parsing methods as well
+ # as the '0' special cases.
+ int64_max = (1 << 63) - 1
+ uint32_max = (1 << 32) - 1
+ text = '-1 %d %d' % (uint32_max + 1, int64_max + 1)
+ tokenizer = text_format.Tokenizer(text.splitlines())
+ self.assertRaises(text_format.ParseError,
+ text_format._ConsumeUint32, tokenizer)
+ self.assertRaises(text_format.ParseError,
+ text_format._ConsumeUint64, tokenizer)
+ self.assertEqual(-1, text_format._ConsumeInt32(tokenizer))
+
+ self.assertRaises(text_format.ParseError,
+ text_format._ConsumeUint32, tokenizer)
+ self.assertRaises(text_format.ParseError,
+ text_format._ConsumeInt32, tokenizer)
+ self.assertEqual(uint32_max + 1, text_format._ConsumeInt64(tokenizer))
+
+ self.assertRaises(text_format.ParseError,
+ text_format._ConsumeInt64, tokenizer)
+ self.assertEqual(int64_max + 1, text_format._ConsumeUint64(tokenizer))
self.assertTrue(tokenizer.AtEnd())
text = '-0 -0 0 0'
- tokenizer = text_format._Tokenizer(text.splitlines())
- self.assertEqual(0, tokenizer.ConsumeUint32())
- self.assertEqual(0, tokenizer.ConsumeUint64())
- self.assertEqual(0, tokenizer.ConsumeUint32())
- self.assertEqual(0, tokenizer.ConsumeUint64())
+ tokenizer = text_format.Tokenizer(text.splitlines())
+ self.assertEqual(0, text_format._ConsumeUint32(tokenizer))
+ self.assertEqual(0, text_format._ConsumeUint64(tokenizer))
+ self.assertEqual(0, text_format._ConsumeUint32(tokenizer))
+ self.assertEqual(0, text_format._ConsumeUint64(tokenizer))
self.assertTrue(tokenizer.AtEnd())
def testConsumeByteString(self):
text = '"string1\''
- tokenizer = text_format._Tokenizer(text.splitlines())
+ tokenizer = text_format.Tokenizer(text.splitlines())
self.assertRaises(text_format.ParseError, tokenizer.ConsumeByteString)
text = 'string1"'
- tokenizer = text_format._Tokenizer(text.splitlines())
+ tokenizer = text_format.Tokenizer(text.splitlines())
self.assertRaises(text_format.ParseError, tokenizer.ConsumeByteString)
text = '\n"\\xt"'
- tokenizer = text_format._Tokenizer(text.splitlines())
+ tokenizer = text_format.Tokenizer(text.splitlines())
self.assertRaises(text_format.ParseError, tokenizer.ConsumeByteString)
text = '\n"\\"'
- tokenizer = text_format._Tokenizer(text.splitlines())
+ tokenizer = text_format.Tokenizer(text.splitlines())
self.assertRaises(text_format.ParseError, tokenizer.ConsumeByteString)
text = '\n"\\x"'
- tokenizer = text_format._Tokenizer(text.splitlines())
+ tokenizer = text_format.Tokenizer(text.splitlines())
self.assertRaises(text_format.ParseError, tokenizer.ConsumeByteString)
def testConsumeBool(self):
text = 'not-a-bool'
- tokenizer = text_format._Tokenizer(text.splitlines())
+ tokenizer = text_format.Tokenizer(text.splitlines())
self.assertRaises(text_format.ParseError, tokenizer.ConsumeBool)
+ def testSkipComment(self):
+ tokenizer = text_format.Tokenizer('# some comment'.splitlines())
+ self.assertTrue(tokenizer.AtEnd())
+ self.assertRaises(text_format.ParseError, tokenizer.ConsumeComment)
+
+ def testConsumeComment(self):
+ tokenizer = text_format.Tokenizer('# some comment'.splitlines(),
+ skip_comments=False)
+ self.assertFalse(tokenizer.AtEnd())
+ self.assertEqual('# some comment', tokenizer.ConsumeComment())
+ self.assertTrue(tokenizer.AtEnd())
+
+ def testConsumeTwoComments(self):
+ text = '# some comment\n# another comment'
+ tokenizer = text_format.Tokenizer(text.splitlines(), skip_comments=False)
+ self.assertEqual('# some comment', tokenizer.ConsumeComment())
+ self.assertFalse(tokenizer.AtEnd())
+ self.assertEqual('# another comment', tokenizer.ConsumeComment())
+ self.assertTrue(tokenizer.AtEnd())
+
+ def testConsumeTrailingComment(self):
+ text = 'some_number: 4\n# some comment'
+ tokenizer = text_format.Tokenizer(text.splitlines(), skip_comments=False)
+ self.assertRaises(text_format.ParseError, tokenizer.ConsumeComment)
+
+ self.assertEqual('some_number', tokenizer.ConsumeIdentifier())
+ self.assertEqual(tokenizer.token, ':')
+ tokenizer.NextToken()
+ self.assertRaises(text_format.ParseError, tokenizer.ConsumeComment)
+ self.assertEqual(4, tokenizer.ConsumeInteger())
+ self.assertFalse(tokenizer.AtEnd())
+
+ self.assertEqual('# some comment', tokenizer.ConsumeComment())
+ self.assertTrue(tokenizer.AtEnd())
+
+ def testConsumeLineComment(self):
+ tokenizer = text_format.Tokenizer('# some comment'.splitlines(),
+ skip_comments=False)
+ self.assertFalse(tokenizer.AtEnd())
+ self.assertEqual((False, '# some comment'),
+ tokenizer.ConsumeCommentOrTrailingComment())
+ self.assertTrue(tokenizer.AtEnd())
+
+ def testConsumeTwoLineComments(self):
+ text = '# some comment\n# another comment'
+ tokenizer = text_format.Tokenizer(text.splitlines(), skip_comments=False)
+ self.assertEqual((False, '# some comment'),
+ tokenizer.ConsumeCommentOrTrailingComment())
+ self.assertFalse(tokenizer.AtEnd())
+ self.assertEqual((False, '# another comment'),
+ tokenizer.ConsumeCommentOrTrailingComment())
+ self.assertTrue(tokenizer.AtEnd())
+
+ def testConsumeAndCheckTrailingComment(self):
+ text = 'some_number: 4 # some comment' # trailing comment on the same line
+ tokenizer = text_format.Tokenizer(text.splitlines(), skip_comments=False)
+ self.assertRaises(text_format.ParseError,
+ tokenizer.ConsumeCommentOrTrailingComment)
+
+ self.assertEqual('some_number', tokenizer.ConsumeIdentifier())
+ self.assertEqual(tokenizer.token, ':')
+ tokenizer.NextToken()
+ self.assertRaises(text_format.ParseError,
+ tokenizer.ConsumeCommentOrTrailingComment)
+ self.assertEqual(4, tokenizer.ConsumeInteger())
+ self.assertFalse(tokenizer.AtEnd())
+
+ self.assertEqual((True, '# some comment'),
+ tokenizer.ConsumeCommentOrTrailingComment())
+ self.assertTrue(tokenizer.AtEnd())
+
+ def testHashinComment(self):
+ text = 'some_number: 4 # some comment # not a new comment'
+ tokenizer = text_format.Tokenizer(text.splitlines(), skip_comments=False)
+ self.assertEqual('some_number', tokenizer.ConsumeIdentifier())
+ self.assertEqual(tokenizer.token, ':')
+ tokenizer.NextToken()
+ self.assertEqual(4, tokenizer.ConsumeInteger())
+ self.assertEqual((True, '# some comment # not a new comment'),
+ tokenizer.ConsumeCommentOrTrailingComment())
+ self.assertTrue(tokenizer.AtEnd())
+
+
+# Tests for pretty printer functionality.
+@_parameterized.parameters((unittest_pb2), (unittest_proto3_arena_pb2))
+class PrettyPrinterTest(TextFormatBase):
+
+ def testPrettyPrintNoMatch(self, message_module):
+
+ def printer(message, indent, as_one_line):
+ del message, indent, as_one_line
+ return None
+
+ message = message_module.TestAllTypes()
+ msg = message.repeated_nested_message.add()
+ msg.bb = 42
+ self.CompareToGoldenText(
+ text_format.MessageToString(
+ message, as_one_line=True, message_formatter=printer),
+ 'repeated_nested_message { bb: 42 }')
+
+ def testPrettyPrintOneLine(self, message_module):
+
+ def printer(m, indent, as_one_line):
+ del indent, as_one_line
+ if m.DESCRIPTOR == message_module.TestAllTypes.NestedMessage.DESCRIPTOR:
+ return 'My lucky number is %s' % m.bb
+
+ message = message_module.TestAllTypes()
+ msg = message.repeated_nested_message.add()
+ msg.bb = 42
+ self.CompareToGoldenText(
+ text_format.MessageToString(
+ message, as_one_line=True, message_formatter=printer),
+ 'repeated_nested_message { My lucky number is 42 }')
+
+ def testPrettyPrintMultiLine(self, message_module):
+
+ def printer(m, indent, as_one_line):
+ if m.DESCRIPTOR == message_module.TestAllTypes.NestedMessage.DESCRIPTOR:
+ line_deliminator = (' ' if as_one_line else '\n') + ' ' * indent
+ return 'My lucky number is:%s%s' % (line_deliminator, m.bb)
+ return None
+
+ message = message_module.TestAllTypes()
+ msg = message.repeated_nested_message.add()
+ msg.bb = 42
+ self.CompareToGoldenText(
+ text_format.MessageToString(
+ message, as_one_line=True, message_formatter=printer),
+ 'repeated_nested_message { My lucky number is: 42 }')
+ self.CompareToGoldenText(
+ text_format.MessageToString(
+ message, as_one_line=False, message_formatter=printer),
+ 'repeated_nested_message {\n My lucky number is:\n 42\n}\n')
+
+ def testPrettyPrintEntireMessage(self, message_module):
+
+ def printer(m, indent, as_one_line):
+ del indent, as_one_line
+ if m.DESCRIPTOR == message_module.TestAllTypes.DESCRIPTOR:
+ return 'The is the message!'
+ return None
+
+ message = message_module.TestAllTypes()
+ self.CompareToGoldenText(
+ text_format.MessageToString(
+ message, as_one_line=False, message_formatter=printer),
+ 'The is the message!\n')
+ self.CompareToGoldenText(
+ text_format.MessageToString(
+ message, as_one_line=True, message_formatter=printer),
+ 'The is the message!')
+
+ def testPrettyPrintMultipleParts(self, message_module):
+
+ def printer(m, indent, as_one_line):
+ del indent, as_one_line
+ if m.DESCRIPTOR == message_module.TestAllTypes.NestedMessage.DESCRIPTOR:
+ return 'My lucky number is %s' % m.bb
+ return None
+
+ message = message_module.TestAllTypes()
+ message.optional_int32 = 61
+ msg = message.repeated_nested_message.add()
+ msg.bb = 42
+ msg = message.repeated_nested_message.add()
+ msg.bb = 99
+ msg = message.optional_nested_message
+ msg.bb = 1
+ self.CompareToGoldenText(
+ text_format.MessageToString(
+ message, as_one_line=True, message_formatter=printer),
+ ('optional_int32: 61 '
+ 'optional_nested_message { My lucky number is 1 } '
+ 'repeated_nested_message { My lucky number is 42 } '
+ 'repeated_nested_message { My lucky number is 99 }'))
if __name__ == '__main__':
unittest.main()
diff --git a/python/google/protobuf/internal/type_checkers.py b/python/google/protobuf/internal/type_checkers.py
index f30ca6a8..4a76cd4e 100755
--- a/python/google/protobuf/internal/type_checkers.py
+++ b/python/google/protobuf/internal/type_checkers.py
@@ -45,6 +45,7 @@ TYPE_TO_DESERIALIZE_METHOD: A dictionary with field types and deserialization
__author__ = 'robinson@google.com (Will Robinson)'
+import numbers
import six
if six.PY3:
@@ -109,6 +110,16 @@ class TypeChecker(object):
return proposed_value
+class TypeCheckerWithDefault(TypeChecker):
+
+ def __init__(self, default_value, *acceptable_types):
+ TypeChecker.__init__(self, acceptable_types)
+ self._default_value = default_value
+
+ def DefaultValue(self):
+ return self._default_value
+
+
# IntValueChecker and its subclasses perform integer type-checks
# and bounds-checks.
class IntValueChecker(object):
@@ -116,11 +127,11 @@ class IntValueChecker(object):
"""Checker used for integer fields. Performs type-check and range check."""
def CheckValue(self, proposed_value):
- if not isinstance(proposed_value, six.integer_types):
+ if not isinstance(proposed_value, numbers.Integral):
message = ('%.1024r has type %s, but expected one of: %s' %
(proposed_value, type(proposed_value), six.integer_types))
raise TypeError(message)
- if not self._MIN <= proposed_value <= self._MAX:
+ if not self._MIN <= int(proposed_value) <= self._MAX:
raise ValueError('Value out of range: %d' % proposed_value)
# We force 32-bit values to int and 64-bit values to long to make
# alternate implementations where the distinction is more significant
@@ -140,11 +151,11 @@ class EnumValueChecker(object):
self._enum_type = enum_type
def CheckValue(self, proposed_value):
- if not isinstance(proposed_value, six.integer_types):
+ if not isinstance(proposed_value, numbers.Integral):
message = ('%.1024r has type %s, but expected one of: %s' %
(proposed_value, type(proposed_value), six.integer_types))
raise TypeError(message)
- if proposed_value not in self._enum_type.values_by_number:
+ if int(proposed_value) not in self._enum_type.values_by_number:
raise ValueError('Unknown enum value: %d' % proposed_value)
return proposed_value
@@ -212,12 +223,13 @@ _VALUE_CHECKERS = {
_FieldDescriptor.CPPTYPE_INT64: Int64ValueChecker(),
_FieldDescriptor.CPPTYPE_UINT32: Uint32ValueChecker(),
_FieldDescriptor.CPPTYPE_UINT64: Uint64ValueChecker(),
- _FieldDescriptor.CPPTYPE_DOUBLE: TypeChecker(
- float, int, long),
- _FieldDescriptor.CPPTYPE_FLOAT: TypeChecker(
- float, int, long),
- _FieldDescriptor.CPPTYPE_BOOL: TypeChecker(bool, int),
- _FieldDescriptor.CPPTYPE_STRING: TypeChecker(bytes),
+ _FieldDescriptor.CPPTYPE_DOUBLE: TypeCheckerWithDefault(
+ 0.0, numbers.Real),
+ _FieldDescriptor.CPPTYPE_FLOAT: TypeCheckerWithDefault(
+ 0.0, numbers.Real),
+ _FieldDescriptor.CPPTYPE_BOOL: TypeCheckerWithDefault(
+ False, bool, numbers.Integral),
+ _FieldDescriptor.CPPTYPE_STRING: TypeCheckerWithDefault(b'', bytes),
}
diff --git a/python/google/protobuf/internal/unknown_fields_test.py b/python/google/protobuf/internal/unknown_fields_test.py
index 9685b8b4..8b7de2e7 100755
--- a/python/google/protobuf/internal/unknown_fields_test.py
+++ b/python/google/protobuf/internal/unknown_fields_test.py
@@ -36,7 +36,7 @@
__author__ = 'bohdank@google.com (Bohdan Koval)'
try:
- import unittest2 as unittest
+ import unittest2 as unittest #PY26
except ImportError:
import unittest
from google.protobuf import unittest_mset_pb2
@@ -47,16 +47,23 @@ from google.protobuf.internal import encoder
from google.protobuf.internal import message_set_extensions_pb2
from google.protobuf.internal import missing_enum_values_pb2
from google.protobuf.internal import test_util
+from google.protobuf.internal import testing_refleaks
from google.protobuf.internal import type_checkers
-def SkipIfCppImplementation(func):
+BaseTestCase = testing_refleaks.BaseTestCase
+
+
+# CheckUnknownField() cannot be used by the C++ implementation because
+# some protect members are called. It is not a behavior difference
+# for python and C++ implementation.
+def SkipCheckUnknownFieldIfCppImplementation(func):
return unittest.skipIf(
api_implementation.Type() == 'cpp' and api_implementation.Version() == 2,
- 'C++ implementation does not expose unknown fields to Python')(func)
+ 'Addtional test for pure python involved protect members')(func)
-class UnknownFieldsTest(unittest.TestCase):
+class UnknownFieldsTest(BaseTestCase):
def setUp(self):
self.descriptor = unittest_pb2.TestAllTypes.DESCRIPTOR
@@ -73,11 +80,23 @@ class UnknownFieldsTest(unittest.TestCase):
# stdout.
self.assertTrue(data == self.all_fields_data)
- def testSerializeProto3(self):
- # Verify that proto3 doesn't preserve unknown fields.
+ def expectSerializeProto3(self, preserve):
message = unittest_proto3_arena_pb2.TestEmptyMessage()
message.ParseFromString(self.all_fields_data)
- self.assertEqual(0, len(message.SerializeToString()))
+ if preserve:
+ self.assertEqual(self.all_fields_data, message.SerializeToString())
+ else:
+ self.assertEqual(0, len(message.SerializeToString()))
+
+ def testSerializeProto3(self):
+ # Verify that proto3 unknown fields behavior.
+ default_preserve = (api_implementation
+ .GetPythonProto3PreserveUnknownsDefault())
+ self.expectSerializeProto3(default_preserve)
+ api_implementation.SetPythonProto3PreserveUnknownsDefault(
+ not default_preserve)
+ self.expectSerializeProto3(not default_preserve)
+ api_implementation.SetPythonProto3PreserveUnknownsDefault(default_preserve)
def testByteSize(self):
self.assertEqual(self.all_fields.ByteSize(), self.empty_message.ByteSize())
@@ -119,8 +138,28 @@ class UnknownFieldsTest(unittest.TestCase):
message.ParseFromString(self.all_fields.SerializeToString())
self.assertNotEqual(self.empty_message, message)
-
-class UnknownFieldsAccessorsTest(unittest.TestCase):
+ def testDiscardUnknownFields(self):
+ self.empty_message.DiscardUnknownFields()
+ self.assertEqual(b'', self.empty_message.SerializeToString())
+ # Test message field and repeated message field.
+ message = unittest_pb2.TestAllTypes()
+ other_message = unittest_pb2.TestAllTypes()
+ other_message.optional_string = 'discard'
+ message.optional_nested_message.ParseFromString(
+ other_message.SerializeToString())
+ message.repeated_nested_message.add().ParseFromString(
+ other_message.SerializeToString())
+ self.assertNotEqual(
+ b'', message.optional_nested_message.SerializeToString())
+ self.assertNotEqual(
+ b'', message.repeated_nested_message[0].SerializeToString())
+ message.DiscardUnknownFields()
+ self.assertEqual(b'', message.optional_nested_message.SerializeToString())
+ self.assertEqual(
+ b'', message.repeated_nested_message[0].SerializeToString())
+
+
+class UnknownFieldsAccessorsTest(BaseTestCase):
def setUp(self):
self.descriptor = unittest_pb2.TestAllTypes.DESCRIPTOR
@@ -129,60 +168,51 @@ class UnknownFieldsAccessorsTest(unittest.TestCase):
self.all_fields_data = self.all_fields.SerializeToString()
self.empty_message = unittest_pb2.TestEmptyMessage()
self.empty_message.ParseFromString(self.all_fields_data)
- if api_implementation.Type() != 'cpp':
- # _unknown_fields is an implementation detail.
- self.unknown_fields = self.empty_message._unknown_fields
- # All the tests that use GetField() check an implementation detail of the
- # Python implementation, which stores unknown fields as serialized strings.
- # These tests are skipped by the C++ implementation: it's enough to check that
- # the message is correctly serialized.
+ # CheckUnknownField() is an additional Pure Python check which checks
+ # a detail of unknown fields. It cannot be used by the C++
+ # implementation because some protect members are called.
+ # The test is added for historical reasons. It is not necessary as
+ # serialized string is checked.
- def GetField(self, name):
+ def CheckUnknownField(self, name, expected_value):
field_descriptor = self.descriptor.fields_by_name[name]
wire_type = type_checkers.FIELD_TYPE_TO_WIRE_TYPE[field_descriptor.type]
field_tag = encoder.TagBytes(field_descriptor.number, wire_type)
result_dict = {}
- for tag_bytes, value in self.unknown_fields:
+ for tag_bytes, value in self.empty_message._unknown_fields:
if tag_bytes == field_tag:
decoder = unittest_pb2.TestAllTypes._decoders_by_tag[tag_bytes][0]
decoder(value, 0, len(value), self.all_fields, result_dict)
- return result_dict[field_descriptor]
-
- @SkipIfCppImplementation
- def testEnum(self):
- value = self.GetField('optional_nested_enum')
- self.assertEqual(self.all_fields.optional_nested_enum, value)
-
- @SkipIfCppImplementation
- def testRepeatedEnum(self):
- value = self.GetField('repeated_nested_enum')
- self.assertEqual(self.all_fields.repeated_nested_enum, value)
-
- @SkipIfCppImplementation
- def testVarint(self):
- value = self.GetField('optional_int32')
- self.assertEqual(self.all_fields.optional_int32, value)
-
- @SkipIfCppImplementation
- def testFixed32(self):
- value = self.GetField('optional_fixed32')
- self.assertEqual(self.all_fields.optional_fixed32, value)
-
- @SkipIfCppImplementation
- def testFixed64(self):
- value = self.GetField('optional_fixed64')
- self.assertEqual(self.all_fields.optional_fixed64, value)
-
- @SkipIfCppImplementation
- def testLengthDelimited(self):
- value = self.GetField('optional_string')
- self.assertEqual(self.all_fields.optional_string, value)
-
- @SkipIfCppImplementation
- def testGroup(self):
- value = self.GetField('optionalgroup')
- self.assertEqual(self.all_fields.optionalgroup, value)
+ self.assertEqual(expected_value, result_dict[field_descriptor])
+
+ @SkipCheckUnknownFieldIfCppImplementation
+ def testCheckUnknownFieldValue(self):
+ # Test enum.
+ self.CheckUnknownField('optional_nested_enum',
+ self.all_fields.optional_nested_enum)
+ # Test repeated enum.
+ self.CheckUnknownField('repeated_nested_enum',
+ self.all_fields.repeated_nested_enum)
+
+ # Test varint.
+ self.CheckUnknownField('optional_int32',
+ self.all_fields.optional_int32)
+ # Test fixed32.
+ self.CheckUnknownField('optional_fixed32',
+ self.all_fields.optional_fixed32)
+
+ # Test fixed64.
+ self.CheckUnknownField('optional_fixed64',
+ self.all_fields.optional_fixed64)
+
+ # Test lengthd elimited.
+ self.CheckUnknownField('optional_string',
+ self.all_fields.optional_string)
+
+ # Test group.
+ self.CheckUnknownField('optionalgroup',
+ self.all_fields.optionalgroup)
def testCopyFrom(self):
message = unittest_pb2.TestEmptyMessage()
@@ -221,45 +251,44 @@ class UnknownFieldsAccessorsTest(unittest.TestCase):
self.assertEqual(message.SerializeToString(), self.all_fields_data)
-class UnknownEnumValuesTest(unittest.TestCase):
+class UnknownEnumValuesTest(BaseTestCase):
def setUp(self):
self.descriptor = missing_enum_values_pb2.TestEnumValues.DESCRIPTOR
self.message = missing_enum_values_pb2.TestEnumValues()
+ # TestEnumValues.ZERO = 0, but does not exist in the other NestedEnum.
self.message.optional_nested_enum = (
- missing_enum_values_pb2.TestEnumValues.ZERO)
+ missing_enum_values_pb2.TestEnumValues.ZERO)
self.message.repeated_nested_enum.extend([
- missing_enum_values_pb2.TestEnumValues.ZERO,
- missing_enum_values_pb2.TestEnumValues.ONE,
- ])
+ missing_enum_values_pb2.TestEnumValues.ZERO,
+ missing_enum_values_pb2.TestEnumValues.ONE,
+ ])
self.message.packed_nested_enum.extend([
- missing_enum_values_pb2.TestEnumValues.ZERO,
- missing_enum_values_pb2.TestEnumValues.ONE,
- ])
+ missing_enum_values_pb2.TestEnumValues.ZERO,
+ missing_enum_values_pb2.TestEnumValues.ONE,
+ ])
self.message_data = self.message.SerializeToString()
self.missing_message = missing_enum_values_pb2.TestMissingEnumValues()
self.missing_message.ParseFromString(self.message_data)
- if api_implementation.Type() != 'cpp':
- # _unknown_fields is an implementation detail.
- self.unknown_fields = self.missing_message._unknown_fields
- # All the tests that use GetField() check an implementation detail of the
- # Python implementation, which stores unknown fields as serialized strings.
- # These tests are skipped by the C++ implementation: it's enough to check that
- # the message is correctly serialized.
+ # CheckUnknownField() is an additional Pure Python check which checks
+ # a detail of unknown fields. It cannot be used by the C++
+ # implementation because some protect members are called.
+ # The test is added for historical reasons. It is not necessary as
+ # serialized string is checked.
- def GetField(self, name):
+ def CheckUnknownField(self, name, expected_value):
field_descriptor = self.descriptor.fields_by_name[name]
wire_type = type_checkers.FIELD_TYPE_TO_WIRE_TYPE[field_descriptor.type]
field_tag = encoder.TagBytes(field_descriptor.number, wire_type)
result_dict = {}
- for tag_bytes, value in self.unknown_fields:
+ for tag_bytes, value in self.missing_message._unknown_fields:
if tag_bytes == field_tag:
decoder = missing_enum_values_pb2.TestEnumValues._decoders_by_tag[
- tag_bytes][0]
+ tag_bytes][0]
decoder(value, 0, len(value), self.message, result_dict)
- return result_dict[field_descriptor]
+ self.assertEqual(expected_value, result_dict[field_descriptor])
def testUnknownParseMismatchEnumValue(self):
just_string = missing_enum_values_pb2.JustString()
@@ -274,21 +303,28 @@ class UnknownEnumValuesTest(unittest.TestCase):
# default value.
self.assertEqual(missing.optional_nested_enum, 0)
- @SkipIfCppImplementation
def testUnknownEnumValue(self):
self.assertFalse(self.missing_message.HasField('optional_nested_enum'))
- value = self.GetField('optional_nested_enum')
- self.assertEqual(self.message.optional_nested_enum, value)
+ self.assertEqual(self.missing_message.optional_nested_enum, 2)
+ # Clear does not do anything.
+ serialized = self.missing_message.SerializeToString()
+ self.missing_message.ClearField('optional_nested_enum')
+ self.assertEqual(self.missing_message.SerializeToString(), serialized)
- @SkipIfCppImplementation
def testUnknownRepeatedEnumValue(self):
- value = self.GetField('repeated_nested_enum')
- self.assertEqual(self.message.repeated_nested_enum, value)
+ self.assertEqual([], self.missing_message.repeated_nested_enum)
- @SkipIfCppImplementation
def testUnknownPackedEnumValue(self):
- value = self.GetField('packed_nested_enum')
- self.assertEqual(self.message.packed_nested_enum, value)
+ self.assertEqual([], self.missing_message.packed_nested_enum)
+
+ @SkipCheckUnknownFieldIfCppImplementation
+ def testCheckUnknownFieldValueForEnum(self):
+ self.CheckUnknownField('optional_nested_enum',
+ self.message.optional_nested_enum)
+ self.CheckUnknownField('repeated_nested_enum',
+ self.message.repeated_nested_enum)
+ self.CheckUnknownField('packed_nested_enum',
+ self.message.packed_nested_enum)
def testRoundTrip(self):
new_message = missing_enum_values_pb2.TestEnumValues()
diff --git a/python/google/protobuf/internal/well_known_types.py b/python/google/protobuf/internal/well_known_types.py
index d35fcc5f..37a65cfa 100644
--- a/python/google/protobuf/internal/well_known_types.py
+++ b/python/google/protobuf/internal/well_known_types.py
@@ -40,6 +40,7 @@ This files defines well known classes which need extra maintenance including:
__author__ = 'jieluo@google.com (Jie Luo)'
+import collections
from datetime import datetime
from datetime import timedelta
import six
@@ -53,6 +54,7 @@ _NANOS_PER_MICROSECOND = 1000
_MILLIS_PER_SECOND = 1000
_MICROS_PER_SECOND = 1000000
_SECONDS_PER_DAY = 24 * 3600
+_DURATION_SECONDS_MAX = 315576000000
class Error(Exception):
@@ -66,13 +68,14 @@ class ParseError(Error):
class Any(object):
"""Class for Any Message type."""
- def Pack(self, msg, type_url_prefix='type.googleapis.com/'):
+ def Pack(self, msg, type_url_prefix='type.googleapis.com/',
+ deterministic=None):
"""Packs the specified message into current Any message."""
if len(type_url_prefix) < 1 or type_url_prefix[-1] != '/':
self.type_url = '%s/%s' % (type_url_prefix, msg.DESCRIPTOR.full_name)
else:
self.type_url = '%s%s' % (type_url_prefix, msg.DESCRIPTOR.full_name)
- self.value = msg.SerializeToString()
+ self.value = msg.SerializeToString(deterministic=deterministic)
def Unpack(self, msg):
"""Unpacks the current Any message into specified message."""
@@ -82,10 +85,14 @@ class Any(object):
msg.ParseFromString(self.value)
return True
+ def TypeName(self):
+ """Returns the protobuf type name of the inner message."""
+ # Only last part is to be used: b/25630112
+ return self.type_url.split('/')[-1]
+
def Is(self, descriptor):
"""Checks if this Any represents the given protobuf type."""
- # Only last part is to be used: b/25630112
- return self.type_url.split('/')[-1] == descriptor.full_name
+ return self.TypeName() == descriptor.full_name
class Timestamp(object):
@@ -243,6 +250,7 @@ class Duration(object):
represent the exact Duration value. For example: "1s", "1.010s",
"1.000000100s", "-3.100s"
"""
+ _CheckDurationValid(self.seconds, self.nanos)
if self.seconds < 0 or self.nanos < 0:
result = '-'
seconds = - self.seconds + int((0 - self.nanos) // 1e9)
@@ -282,14 +290,17 @@ class Duration(object):
try:
pos = value.find('.')
if pos == -1:
- self.seconds = int(value[:-1])
- self.nanos = 0
+ seconds = int(value[:-1])
+ nanos = 0
else:
- self.seconds = int(value[:pos])
+ seconds = int(value[:pos])
if value[0] == '-':
- self.nanos = int(round(float('-0{0}'.format(value[pos: -1])) *1e9))
+ nanos = int(round(float('-0{0}'.format(value[pos: -1])) *1e9))
else:
- self.nanos = int(round(float('0{0}'.format(value[pos: -1])) *1e9))
+ nanos = int(round(float('0{0}'.format(value[pos: -1])) *1e9))
+ _CheckDurationValid(seconds, nanos)
+ self.seconds = seconds
+ self.nanos = nanos
except ValueError:
raise ParseError(
'Couldn\'t parse duration: {0}.'.format(value))
@@ -341,12 +352,12 @@ class Duration(object):
self.nanos, _NANOS_PER_MICROSECOND))
def FromTimedelta(self, td):
- """Convertd timedelta to Duration."""
+ """Converts timedelta to Duration."""
self._NormalizeDuration(td.seconds + td.days * _SECONDS_PER_DAY,
td.microseconds * _NANOS_PER_MICROSECOND)
def _NormalizeDuration(self, seconds, nanos):
- """Set Duration by seconds and nonas."""
+ """Set Duration by seconds and nanos."""
# Force nanos to be negative if the duration is negative.
if seconds < 0 and nanos > 0:
seconds += 1
@@ -355,6 +366,20 @@ class Duration(object):
self.nanos = nanos
+def _CheckDurationValid(seconds, nanos):
+ if seconds < -_DURATION_SECONDS_MAX or seconds > _DURATION_SECONDS_MAX:
+ raise Error(
+ 'Duration is not valid: Seconds {0} must be in range '
+ '[-315576000000, 315576000000].'.format(seconds))
+ if nanos <= -_NANOS_PER_SECOND or nanos >= _NANOS_PER_SECOND:
+ raise Error(
+ 'Duration is not valid: Nanos {0} must be in range '
+ '[-999999999, 999999999].'.format(nanos))
+ if (nanos < 0 and seconds > 0) or (nanos > 0 and seconds < 0):
+ raise Error(
+ 'Duration is not valid: Sign mismatch.')
+
+
def _RoundTowardZero(value, divider):
"""Truncates the remainder part after division."""
# For some languanges, the sign of the remainder is implementation
@@ -375,13 +400,16 @@ class FieldMask(object):
def ToJsonString(self):
"""Converts FieldMask to string according to proto3 JSON spec."""
- return ','.join(self.paths)
+ camelcase_paths = []
+ for path in self.paths:
+ camelcase_paths.append(_SnakeCaseToCamelCase(path))
+ return ','.join(camelcase_paths)
def FromJsonString(self, value):
"""Converts string to FieldMask according to proto3 JSON spec."""
self.Clear()
for path in value.split(','):
- self.paths.append(path)
+ self.paths.append(_CamelCaseToSnakeCase(path))
def IsValidForDescriptor(self, message_descriptor):
"""Checks whether the FieldMask is valid for Message Descriptor."""
@@ -450,7 +478,7 @@ def _IsValidPath(message_descriptor, path):
parts = path.split('.')
last = parts.pop()
for name in parts:
- field = message_descriptor.fields_by_name[name]
+ field = message_descriptor.fields_by_name.get(name)
if (field is None or
field.label == FieldDescriptor.LABEL_REPEATED or
field.type != FieldDescriptor.TYPE_MESSAGE):
@@ -468,6 +496,48 @@ def _CheckFieldMaskMessage(message):
message_descriptor.full_name))
+def _SnakeCaseToCamelCase(path_name):
+ """Converts a path name from snake_case to camelCase."""
+ result = []
+ after_underscore = False
+ for c in path_name:
+ if c.isupper():
+ raise Error('Fail to print FieldMask to Json string: Path name '
+ '{0} must not contain uppercase letters.'.format(path_name))
+ if after_underscore:
+ if c.islower():
+ result.append(c.upper())
+ after_underscore = False
+ else:
+ raise Error('Fail to print FieldMask to Json string: The '
+ 'character after a "_" must be a lowercase letter '
+ 'in path name {0}.'.format(path_name))
+ elif c == '_':
+ after_underscore = True
+ else:
+ result += c
+
+ if after_underscore:
+ raise Error('Fail to print FieldMask to Json string: Trailing "_" '
+ 'in path name {0}.'.format(path_name))
+ return ''.join(result)
+
+
+def _CamelCaseToSnakeCase(path_name):
+ """Converts a field name from camelCase to snake_case."""
+ result = []
+ for c in path_name:
+ if c == '_':
+ raise ParseError('Fail to parse FieldMask: Path name '
+ '{0} must not contain "_"s.'.format(path_name))
+ if c.isupper():
+ result += '_'
+ result += c.lower()
+ else:
+ result += c
+ return ''.join(result)
+
+
class _FieldMaskTree(object):
"""Represents a FieldMask in a tree structure.
@@ -582,9 +652,10 @@ def _MergeMessage(
raise ValueError('Error: Field {0} in message {1} is not a singular '
'message field and cannot have sub-fields.'.format(
name, source_descriptor.full_name))
- _MergeMessage(
- child, getattr(source, name), getattr(destination, name),
- replace_message, replace_repeated)
+ if source.HasField(name):
+ _MergeMessage(
+ child, getattr(source, name), getattr(destination, name),
+ replace_message, replace_repeated)
continue
if field.label == FieldDescriptor.LABEL_REPEATED:
if replace_repeated:
@@ -633,6 +704,12 @@ def _SetStructValue(struct_value, value):
struct_value.string_value = value
elif isinstance(value, _INT_OR_FLOAT):
struct_value.number_value = value
+ elif isinstance(value, dict):
+ struct_value.struct_value.Clear()
+ struct_value.struct_value.update(value)
+ elif isinstance(value, list):
+ struct_value.list_value.Clear()
+ struct_value.list_value.extend(value)
else:
raise ValueError('Unexpected type')
@@ -663,18 +740,49 @@ class Struct(object):
def __getitem__(self, key):
return _GetStructValue(self.fields[key])
+ def __contains__(self, item):
+ return item in self.fields
+
def __setitem__(self, key, value):
_SetStructValue(self.fields[key], value)
+ def __delitem__(self, key):
+ del self.fields[key]
+
+ def __len__(self):
+ return len(self.fields)
+
+ def __iter__(self):
+ return iter(self.fields)
+
+ def keys(self): # pylint: disable=invalid-name
+ return self.fields.keys()
+
+ def values(self): # pylint: disable=invalid-name
+ return [self[key] for key in self]
+
+ def items(self): # pylint: disable=invalid-name
+ return [(key, self[key]) for key in self]
+
def get_or_create_list(self, key):
"""Returns a list for this key, creating if it didn't exist already."""
+ if not self.fields[key].HasField('list_value'):
+ # Clear will mark list_value modified which will indeed create a list.
+ self.fields[key].list_value.Clear()
return self.fields[key].list_value
def get_or_create_struct(self, key):
"""Returns a struct for this key, creating if it didn't exist already."""
+ if not self.fields[key].HasField('struct_value'):
+ # Clear will mark struct_value modified which will indeed create a struct.
+ self.fields[key].struct_value.Clear()
return self.fields[key].struct_value
- # TODO(haberman): allow constructing/merging from dict.
+ def update(self, dictionary): # pylint: disable=invalid-name
+ for key, value in dictionary.items():
+ _SetStructValue(self.fields[key], value)
+
+collections.MutableMapping.register(Struct)
class ListValue(object):
@@ -697,17 +805,28 @@ class ListValue(object):
def __setitem__(self, index, value):
_SetStructValue(self.values.__getitem__(index), value)
+ def __delitem__(self, key):
+ del self.values[key]
+
def items(self):
for i in range(len(self)):
yield self[i]
def add_struct(self):
"""Appends and returns a struct value as the next value in the list."""
- return self.values.add().struct_value
+ struct_value = self.values.add().struct_value
+ # Clear will mark struct_value modified which will indeed create a struct.
+ struct_value.Clear()
+ return struct_value
def add_list(self):
"""Appends and returns a list value as the next value in the list."""
- return self.values.add().list_value
+ list_value = self.values.add().list_value
+ # Clear will mark list_value modified which will indeed create a list.
+ list_value.Clear()
+ return list_value
+
+collections.MutableSequence.register(ListValue)
WKTBASES = {
diff --git a/python/google/protobuf/internal/well_known_types_test.py b/python/google/protobuf/internal/well_known_types_test.py
index 6acbee22..965940b2 100644
--- a/python/google/protobuf/internal/well_known_types_test.py
+++ b/python/google/protobuf/internal/well_known_types_test.py
@@ -34,10 +34,11 @@
__author__ = 'jieluo@google.com (Jie Luo)'
+import collections
from datetime import datetime
try:
- import unittest2 as unittest
+ import unittest2 as unittest #PY26
except ImportError:
import unittest
@@ -100,11 +101,15 @@ class TimeUtilTest(TimeUtilTestBase):
message.FromJsonString('1970-01-01T00:00:00.1Z')
self.assertEqual(0, message.seconds)
self.assertEqual(100000000, message.nanos)
- # Parsing accpets offsets.
+ # Parsing accepts offsets.
message.FromJsonString('1970-01-01T00:00:00-08:00')
self.assertEqual(8 * 3600, message.seconds)
self.assertEqual(0, message.nanos)
+ # It is not easy to check with current time. For test coverage only.
+ message.GetCurrentTime()
+ self.assertNotEqual(8 * 3600, message.seconds)
+
def testDurationSerializeAndParse(self):
message = duration_pb2.Duration()
# Generated output should contain 3, 6, or 9 fractional digits.
@@ -268,6 +273,17 @@ class TimeUtilTest(TimeUtilTestBase):
def testInvalidTimestamp(self):
message = timestamp_pb2.Timestamp()
self.assertRaisesRegexp(
+ well_known_types.ParseError,
+ 'Failed to parse timestamp: missing valid timezone offset.',
+ message.FromJsonString,
+ '')
+ self.assertRaisesRegexp(
+ well_known_types.ParseError,
+ 'Failed to parse timestamp: invalid trailing data '
+ '1970-01-01T00:00:01Ztrail.',
+ message.FromJsonString,
+ '1970-01-01T00:00:01Ztrail')
+ self.assertRaisesRegexp(
ValueError,
'time data \'10000-01-01T00:00:00\' does not match'
' format \'%Y-%m-%dT%H:%M:%S\'',
@@ -284,7 +300,7 @@ class TimeUtilTest(TimeUtilTestBase):
'1972-01-01T01:00:00.01+08',)
self.assertRaisesRegexp(
ValueError,
- 'year is out of range',
+ 'year (0 )?is out of range',
message.FromJsonString,
'0000-01-01T00:00:00Z')
message.seconds = 253402300800
@@ -303,6 +319,38 @@ class TimeUtilTest(TimeUtilTestBase):
well_known_types.ParseError,
'Couldn\'t parse duration: 1...2s.',
message.FromJsonString, '1...2s')
+ text = '-315576000001.000000000s'
+ self.assertRaisesRegexp(
+ well_known_types.Error,
+ r'Duration is not valid\: Seconds -315576000001 must be in range'
+ r' \[-315576000000\, 315576000000\].',
+ message.FromJsonString, text)
+ text = '315576000001.000000000s'
+ self.assertRaisesRegexp(
+ well_known_types.Error,
+ r'Duration is not valid\: Seconds 315576000001 must be in range'
+ r' \[-315576000000\, 315576000000\].',
+ message.FromJsonString, text)
+ message.seconds = -315576000001
+ message.nanos = 0
+ self.assertRaisesRegexp(
+ well_known_types.Error,
+ r'Duration is not valid\: Seconds -315576000001 must be in range'
+ r' \[-315576000000\, 315576000000\].',
+ message.ToJsonString)
+ message.seconds = 0
+ message.nanos = 999999999 + 1
+ self.assertRaisesRegexp(
+ well_known_types.Error,
+ r'Duration is not valid\: Nanos 1000000000 must be in range'
+ r' \[-999999999\, 999999999\].',
+ message.ToJsonString)
+ message.seconds = -1
+ message.nanos = 1
+ self.assertRaisesRegexp(
+ well_known_types.Error,
+ r'Duration is not valid\: Sign mismatch.',
+ message.ToJsonString)
class FieldMaskTest(unittest.TestCase):
@@ -322,6 +370,20 @@ class FieldMaskTest(unittest.TestCase):
mask.FromJsonString('foo,bar')
self.assertEqual(['foo', 'bar'], mask.paths)
+ # Test camel case
+ mask.Clear()
+ mask.paths.append('foo_bar')
+ self.assertEqual('fooBar', mask.ToJsonString())
+ mask.paths.append('bar_quz')
+ self.assertEqual('fooBar,barQuz', mask.ToJsonString())
+
+ mask.FromJsonString('')
+ self.assertEqual('', mask.ToJsonString())
+ mask.FromJsonString('fooBar')
+ self.assertEqual(['foo_bar'], mask.paths)
+ mask.FromJsonString('fooBar,barQuz')
+ self.assertEqual(['foo_bar', 'bar_quz'], mask.paths)
+
def testDescriptorToFieldMask(self):
mask = field_mask_pb2.FieldMask()
msg_descriptor = unittest_pb2.TestAllTypes.DESCRIPTOR
@@ -330,10 +392,37 @@ class FieldMaskTest(unittest.TestCase):
self.assertTrue(mask.IsValidForDescriptor(msg_descriptor))
for field in msg_descriptor.fields:
self.assertTrue(field.name in mask.paths)
+
+ def testIsValidForDescriptor(self):
+ msg_descriptor = unittest_pb2.TestAllTypes.DESCRIPTOR
+ # Empty mask
+ mask = field_mask_pb2.FieldMask()
+ self.assertTrue(mask.IsValidForDescriptor(msg_descriptor))
+ # All fields from descriptor
+ mask.AllFieldsFromDescriptor(msg_descriptor)
+ self.assertTrue(mask.IsValidForDescriptor(msg_descriptor))
+ # Child under optional message
mask.paths.append('optional_nested_message.bb')
self.assertTrue(mask.IsValidForDescriptor(msg_descriptor))
+ # Repeated field is only allowed in the last position of path
mask.paths.append('repeated_nested_message.bb')
self.assertFalse(mask.IsValidForDescriptor(msg_descriptor))
+ # Invalid top level field
+ mask = field_mask_pb2.FieldMask()
+ mask.paths.append('xxx')
+ self.assertFalse(mask.IsValidForDescriptor(msg_descriptor))
+ # Invalid field in root
+ mask = field_mask_pb2.FieldMask()
+ mask.paths.append('xxx.zzz')
+ self.assertFalse(mask.IsValidForDescriptor(msg_descriptor))
+ # Invalid field in internal node
+ mask = field_mask_pb2.FieldMask()
+ mask.paths.append('optional_nested_message.xxx.zzz')
+ self.assertFalse(mask.IsValidForDescriptor(msg_descriptor))
+ # Invalid field in leaf
+ mask = field_mask_pb2.FieldMask()
+ mask.paths.append('optional_nested_message.xxx')
+ self.assertFalse(mask.IsValidForDescriptor(msg_descriptor))
def testCanonicalFrom(self):
mask = field_mask_pb2.FieldMask()
@@ -389,6 +478,9 @@ class FieldMaskTest(unittest.TestCase):
mask2.FromJsonString('foo.bar,bar')
out_mask.Union(mask1, mask2)
self.assertEqual('bar,foo.bar,quz', out_mask.ToJsonString())
+ src = unittest_pb2.TestAllTypes()
+ with self.assertRaises(ValueError):
+ out_mask.Union(src, mask2)
def testIntersect(self):
mask1 = field_mask_pb2.FieldMask()
@@ -502,22 +594,98 @@ class FieldMaskTest(unittest.TestCase):
nested_src.payload.repeated_int32.append(1234)
nested_dst.payload.repeated_int32.append(5678)
# Repeated fields will be appended by default.
- mask.FromJsonString('payload.repeated_int32')
+ mask.FromJsonString('payload.repeatedInt32')
mask.MergeMessage(nested_src, nested_dst)
self.assertEqual(2, len(nested_dst.payload.repeated_int32))
self.assertEqual(5678, nested_dst.payload.repeated_int32[0])
self.assertEqual(1234, nested_dst.payload.repeated_int32[1])
# Change the behavior to replace repeated fields.
- mask.FromJsonString('payload.repeated_int32')
+ mask.FromJsonString('payload.repeatedInt32')
mask.MergeMessage(nested_src, nested_dst, False, True)
self.assertEqual(1, len(nested_dst.payload.repeated_int32))
self.assertEqual(1234, nested_dst.payload.repeated_int32[0])
+ # Test Merge oneof field.
+ new_msg = unittest_pb2.TestOneof2()
+ dst = unittest_pb2.TestOneof2()
+ dst.foo_message.qux_int = 1
+ mask = field_mask_pb2.FieldMask()
+ mask.FromJsonString('fooMessage,fooLazyMessage.quxInt')
+ mask.MergeMessage(new_msg, dst)
+ self.assertTrue(dst.HasField('foo_message'))
+ self.assertFalse(dst.HasField('foo_lazy_message'))
+
+ def testMergeErrors(self):
+ src = unittest_pb2.TestAllTypes()
+ dst = unittest_pb2.TestAllTypes()
+ mask = field_mask_pb2.FieldMask()
+ test_util.SetAllFields(src)
+ mask.FromJsonString('optionalInt32.field')
+ with self.assertRaises(ValueError) as e:
+ mask.MergeMessage(src, dst)
+ self.assertEqual('Error: Field optional_int32 in message '
+ 'protobuf_unittest.TestAllTypes is not a singular '
+ 'message field and cannot have sub-fields.',
+ str(e.exception))
+
+ def testSnakeCaseToCamelCase(self):
+ self.assertEqual('fooBar',
+ well_known_types._SnakeCaseToCamelCase('foo_bar'))
+ self.assertEqual('FooBar',
+ well_known_types._SnakeCaseToCamelCase('_foo_bar'))
+ self.assertEqual('foo3Bar',
+ well_known_types._SnakeCaseToCamelCase('foo3_bar'))
+
+ # No uppercase letter is allowed.
+ self.assertRaisesRegexp(
+ well_known_types.Error,
+ 'Fail to print FieldMask to Json string: Path name Foo must '
+ 'not contain uppercase letters.',
+ well_known_types._SnakeCaseToCamelCase,
+ 'Foo')
+ # Any character after a "_" must be a lowercase letter.
+ # 1. "_" cannot be followed by another "_".
+ # 2. "_" cannot be followed by a digit.
+ # 3. "_" cannot appear as the last character.
+ self.assertRaisesRegexp(
+ well_known_types.Error,
+ 'Fail to print FieldMask to Json string: The character after a '
+ '"_" must be a lowercase letter in path name foo__bar.',
+ well_known_types._SnakeCaseToCamelCase,
+ 'foo__bar')
+ self.assertRaisesRegexp(
+ well_known_types.Error,
+ 'Fail to print FieldMask to Json string: The character after a '
+ '"_" must be a lowercase letter in path name foo_3bar.',
+ well_known_types._SnakeCaseToCamelCase,
+ 'foo_3bar')
+ self.assertRaisesRegexp(
+ well_known_types.Error,
+ 'Fail to print FieldMask to Json string: Trailing "_" in path '
+ 'name foo_bar_.',
+ well_known_types._SnakeCaseToCamelCase,
+ 'foo_bar_')
+
+ def testCamelCaseToSnakeCase(self):
+ self.assertEqual('foo_bar',
+ well_known_types._CamelCaseToSnakeCase('fooBar'))
+ self.assertEqual('_foo_bar',
+ well_known_types._CamelCaseToSnakeCase('FooBar'))
+ self.assertEqual('foo3_bar',
+ well_known_types._CamelCaseToSnakeCase('foo3Bar'))
+ self.assertRaisesRegexp(
+ well_known_types.ParseError,
+ 'Fail to parse FieldMask: Path name foo_bar must not contain "_"s.',
+ well_known_types._CamelCaseToSnakeCase,
+ 'foo_bar')
+
class StructTest(unittest.TestCase):
def testStruct(self):
struct = struct_pb2.Struct()
+ self.assertIsInstance(struct, collections.Mapping)
+ self.assertEqual(0, len(struct))
struct_class = struct.__class__
struct['key1'] = 5
@@ -525,56 +693,157 @@ class StructTest(unittest.TestCase):
struct['key3'] = True
struct.get_or_create_struct('key4')['subkey'] = 11.0
struct_list = struct.get_or_create_list('key5')
+ self.assertIsInstance(struct_list, collections.Sequence)
struct_list.extend([6, 'seven', True, False, None])
struct_list.add_struct()['subkey2'] = 9
+ struct['key6'] = {'subkey': {}}
+ struct['key7'] = [2, False]
+ self.assertEqual(7, len(struct))
self.assertTrue(isinstance(struct, well_known_types.Struct))
- self.assertEquals(5, struct['key1'])
- self.assertEquals('abc', struct['key2'])
+ self.assertEqual(5, struct['key1'])
+ self.assertEqual('abc', struct['key2'])
self.assertIs(True, struct['key3'])
- self.assertEquals(11, struct['key4']['subkey'])
+ self.assertEqual(11, struct['key4']['subkey'])
inner_struct = struct_class()
inner_struct['subkey2'] = 9
- self.assertEquals([6, 'seven', True, False, None, inner_struct],
- list(struct['key5'].items()))
+ self.assertEqual([6, 'seven', True, False, None, inner_struct],
+ list(struct['key5'].items()))
+ self.assertEqual({}, dict(struct['key6']['subkey'].fields))
+ self.assertEqual([2, False], list(struct['key7'].items()))
serialized = struct.SerializeToString()
-
struct2 = struct_pb2.Struct()
struct2.ParseFromString(serialized)
- self.assertEquals(struct, struct2)
+ self.assertEqual(struct, struct2)
+ for key, value in struct.items():
+ self.assertIn(key, struct)
+ self.assertIn(key, struct2)
+ self.assertEqual(value, struct2[key])
+
+ self.assertEqual(7, len(struct.keys()))
+ self.assertEqual(7, len(struct.values()))
+ for key in struct.keys():
+ self.assertIn(key, struct)
+ self.assertIn(key, struct2)
+ self.assertEqual(struct[key], struct2[key])
+
+ item = (next(iter(struct.keys())), next(iter(struct.values())))
+ self.assertEqual(item, next(iter(struct.items())))
self.assertTrue(isinstance(struct2, well_known_types.Struct))
- self.assertEquals(5, struct2['key1'])
- self.assertEquals('abc', struct2['key2'])
+ self.assertEqual(5, struct2['key1'])
+ self.assertEqual('abc', struct2['key2'])
self.assertIs(True, struct2['key3'])
- self.assertEquals(11, struct2['key4']['subkey'])
- self.assertEquals([6, 'seven', True, False, None, inner_struct],
- list(struct2['key5'].items()))
+ self.assertEqual(11, struct2['key4']['subkey'])
+ self.assertEqual([6, 'seven', True, False, None, inner_struct],
+ list(struct2['key5'].items()))
struct_list = struct2['key5']
- self.assertEquals(6, struct_list[0])
- self.assertEquals('seven', struct_list[1])
- self.assertEquals(True, struct_list[2])
- self.assertEquals(False, struct_list[3])
- self.assertEquals(None, struct_list[4])
- self.assertEquals(inner_struct, struct_list[5])
+ self.assertEqual(6, struct_list[0])
+ self.assertEqual('seven', struct_list[1])
+ self.assertEqual(True, struct_list[2])
+ self.assertEqual(False, struct_list[3])
+ self.assertEqual(None, struct_list[4])
+ self.assertEqual(inner_struct, struct_list[5])
struct_list[1] = 7
- self.assertEquals(7, struct_list[1])
+ self.assertEqual(7, struct_list[1])
struct_list.add_list().extend([1, 'two', True, False, None])
- self.assertEquals([1, 'two', True, False, None],
- list(struct_list[6].items()))
+ self.assertEqual([1, 'two', True, False, None],
+ list(struct_list[6].items()))
+ struct_list.extend([{'nested_struct': 30}, ['nested_list', 99], {}, []])
+ self.assertEqual(11, len(struct_list.values))
+ self.assertEqual(30, struct_list[7]['nested_struct'])
+ self.assertEqual('nested_list', struct_list[8][0])
+ self.assertEqual(99, struct_list[8][1])
+ self.assertEqual({}, dict(struct_list[9].fields))
+ self.assertEqual([], list(struct_list[10].items()))
+ struct_list[0] = {'replace': 'set'}
+ struct_list[1] = ['replace', 'set']
+ self.assertEqual('set', struct_list[0]['replace'])
+ self.assertEqual(['replace', 'set'], list(struct_list[1].items()))
text_serialized = str(struct)
struct3 = struct_pb2.Struct()
text_format.Merge(text_serialized, struct3)
- self.assertEquals(struct, struct3)
+ self.assertEqual(struct, struct3)
struct.get_or_create_struct('key3')['replace'] = 12
- self.assertEquals(12, struct['key3']['replace'])
+ self.assertEqual(12, struct['key3']['replace'])
+
+ # Tests empty list.
+ struct.get_or_create_list('empty_list')
+ empty_list = struct['empty_list']
+ self.assertEqual([], list(empty_list.items()))
+ list2 = struct_pb2.ListValue()
+ list2.add_list()
+ empty_list = list2[0]
+ self.assertEqual([], list(empty_list.items()))
+
+ # Tests empty struct.
+ struct.get_or_create_struct('empty_struct')
+ empty_struct = struct['empty_struct']
+ self.assertEqual({}, dict(empty_struct.fields))
+ list2.add_struct()
+ empty_struct = list2[1]
+ self.assertEqual({}, dict(empty_struct.fields))
+
+ self.assertEqual(9, len(struct))
+ del struct['key3']
+ del struct['key4']
+ self.assertEqual(7, len(struct))
+ self.assertEqual(6, len(struct['key5']))
+ del struct['key5'][1]
+ self.assertEqual(5, len(struct['key5']))
+ self.assertEqual([6, True, False, None, inner_struct],
+ list(struct['key5'].items()))
+
+ def testMergeFrom(self):
+ struct = struct_pb2.Struct()
+ struct_class = struct.__class__
+
+ dictionary = {
+ 'key1': 5,
+ 'key2': 'abc',
+ 'key3': True,
+ 'key4': {'subkey': 11.0},
+ 'key5': [6, 'seven', True, False, None, {'subkey2': 9}],
+ 'key6': [['nested_list', True]],
+ 'empty_struct': {},
+ 'empty_list': []
+ }
+ struct.update(dictionary)
+ self.assertEqual(5, struct['key1'])
+ self.assertEqual('abc', struct['key2'])
+ self.assertIs(True, struct['key3'])
+ self.assertEqual(11, struct['key4']['subkey'])
+ inner_struct = struct_class()
+ inner_struct['subkey2'] = 9
+ self.assertEqual([6, 'seven', True, False, None, inner_struct],
+ list(struct['key5'].items()))
+ self.assertEqual(2, len(struct['key6'][0].values))
+ self.assertEqual('nested_list', struct['key6'][0][0])
+ self.assertEqual(True, struct['key6'][0][1])
+ empty_list = struct['empty_list']
+ self.assertEqual([], list(empty_list.items()))
+ empty_struct = struct['empty_struct']
+ self.assertEqual({}, dict(empty_struct.fields))
+
+ # According to documentation: "When parsing from the wire or when merging,
+ # if there are duplicate map keys the last key seen is used".
+ duplicate = {
+ 'key4': {'replace': 20},
+ 'key5': [[False, 5]]
+ }
+ struct.update(duplicate)
+ self.assertEqual(1, len(struct['key4'].fields))
+ self.assertEqual(20, struct['key4']['replace'])
+ self.assertEqual(1, len(struct['key5'].values))
+ self.assertEqual(False, struct['key5'][0][0])
+ self.assertEqual(5, struct['key5'][0][1])
class AnyTest(unittest.TestCase):
@@ -610,6 +879,14 @@ class AnyTest(unittest.TestCase):
raise AttributeError('%s should not have Pack method.' %
msg_descriptor.full_name)
+ def testMessageName(self):
+ # Creates and sets message.
+ submessage = any_test_pb2.TestAny()
+ submessage.int_value = 12345
+ msg = any_pb2.Any()
+ msg.Pack(submessage)
+ self.assertEqual(msg.TypeName(), 'google.protobuf.internal.TestAny')
+
def testPackWithCustomTypeUrl(self):
submessage = any_test_pb2.TestAny()
submessage.int_value = 12345
@@ -631,6 +908,20 @@ class AnyTest(unittest.TestCase):
self.assertTrue(msg.Unpack(unpacked_message))
self.assertEqual(submessage, unpacked_message)
+ def testPackDeterministic(self):
+ submessage = any_test_pb2.TestAny()
+ for i in range(10):
+ submessage.map_value[str(i)] = i * 2
+ msg = any_pb2.Any()
+ msg.Pack(submessage, deterministic=True)
+ serialized = msg.SerializeToString(deterministic=True)
+ golden = (b'\n4type.googleapis.com/google.protobuf.internal.TestAny\x12F'
+ b'\x1a\x05\n\x010\x10\x00\x1a\x05\n\x011\x10\x02\x1a\x05\n\x01'
+ b'2\x10\x04\x1a\x05\n\x013\x10\x06\x1a\x05\n\x014\x10\x08\x1a'
+ b'\x05\n\x015\x10\n\x1a\x05\n\x016\x10\x0c\x1a\x05\n\x017\x10'
+ b'\x0e\x1a\x05\n\x018\x10\x10\x1a\x05\n\x019\x10\x12')
+ self.assertEqual(golden, serialized)
+
if __name__ == '__main__':
unittest.main()
diff --git a/python/google/protobuf/internal/wire_format_test.py b/python/google/protobuf/internal/wire_format_test.py
index f659d18e..da120f33 100755
--- a/python/google/protobuf/internal/wire_format_test.py
+++ b/python/google/protobuf/internal/wire_format_test.py
@@ -35,9 +35,10 @@
__author__ = 'robinson@google.com (Will Robinson)'
try:
- import unittest2 as unittest
+ import unittest2 as unittest #PY26
except ImportError:
import unittest
+
from google.protobuf import message
from google.protobuf.internal import wire_format
diff --git a/python/google/protobuf/json_format.py b/python/google/protobuf/json_format.py
index 23382bdb..8d338d3e 100644
--- a/python/google/protobuf/json_format.py
+++ b/python/google/protobuf/json_format.py
@@ -42,15 +42,28 @@ Simple usage example:
__author__ = 'jieluo@google.com (Jie Luo)'
+# pylint: disable=g-statement-before-imports,g-import-not-at-top
+try:
+ from collections import OrderedDict
+except ImportError:
+ from ordereddict import OrderedDict # PY26
+# pylint: enable=g-statement-before-imports,g-import-not-at-top
+
import base64
import json
import math
-import six
+
+from operator import methodcaller
+
+import re
import sys
+import six
+
from google.protobuf import descriptor
from google.protobuf import symbol_database
+
_TIMESTAMPFOMAT = '%Y-%m-%dT%H:%M:%S'
_INT_TYPES = frozenset([descriptor.FieldDescriptor.CPPTYPE_INT32,
descriptor.FieldDescriptor.CPPTYPE_UINT32,
@@ -64,6 +77,12 @@ _INFINITY = 'Infinity'
_NEG_INFINITY = '-Infinity'
_NAN = 'NaN'
+_UNPAIRED_SURROGATE_PATTERN = re.compile(six.u(
+ r'[\ud800-\udbff](?![\udc00-\udfff])|(?<![\ud800-\udbff])[\udc00-\udfff]'
+))
+
+_VALID_EXTENSION_NAME = re.compile(r'\[[a-zA-Z0-9\._]*\]$')
+
class Error(Exception):
"""Top-level module error for json_format."""
@@ -77,7 +96,12 @@ class ParseError(Error):
"""Thrown in case of parsing error."""
-def MessageToJson(message, including_default_value_fields=False):
+def MessageToJson(message,
+ including_default_value_fields=False,
+ preserving_proto_field_name=False,
+ indent=2,
+ sort_keys=False,
+ use_integers_for_enums=False):
"""Converts protobuf message to JSON format.
Args:
@@ -86,26 +110,50 @@ def MessageToJson(message, including_default_value_fields=False):
repeated fields, and map fields will always be serialized. If
False, only serialize non-empty fields. Singular message fields
and oneof fields are not affected by this option.
+ preserving_proto_field_name: If True, use the original proto field
+ names as defined in the .proto file. If False, convert the field
+ names to lowerCamelCase.
+ indent: The JSON object will be pretty-printed with this indent level.
+ An indent level of 0 or negative will only insert newlines.
+ sort_keys: If True, then the output will be sorted by field names.
+ use_integers_for_enums: If true, print integers instead of enum names.
Returns:
A string containing the JSON formatted protocol buffer message.
"""
- js = _MessageToJsonObject(message, including_default_value_fields)
- return json.dumps(js, indent=2)
+ printer = _Printer(including_default_value_fields,
+ preserving_proto_field_name,
+ use_integers_for_enums)
+ return printer.ToJsonString(message, indent, sort_keys)
-def _MessageToJsonObject(message, including_default_value_fields):
- """Converts message to an object according to Proto3 JSON Specification."""
- message_descriptor = message.DESCRIPTOR
- full_name = message_descriptor.full_name
- if _IsWrapperMessage(message_descriptor):
- return _WrapperMessageToJsonObject(message)
- if full_name in _WKTJSONMETHODS:
- return _WKTJSONMETHODS[full_name][0](
- message, including_default_value_fields)
- js = {}
- return _RegularMessageToJsonObject(
- message, js, including_default_value_fields)
+def MessageToDict(message,
+ including_default_value_fields=False,
+ preserving_proto_field_name=False,
+ use_integers_for_enums=False):
+ """Converts protobuf message to a dictionary.
+
+ When the dictionary is encoded to JSON, it conforms to proto3 JSON spec.
+
+ Args:
+ message: The protocol buffers message instance to serialize.
+ including_default_value_fields: If True, singular primitive fields,
+ repeated fields, and map fields will always be serialized. If
+ False, only serialize non-empty fields. Singular message fields
+ and oneof fields are not affected by this option.
+ preserving_proto_field_name: If True, use the original proto field
+ names as defined in the .proto file. If False, convert the field
+ names to lowerCamelCase.
+ use_integers_for_enums: If true, print integers instead of enum names.
+
+ Returns:
+ A dict representation of the protocol buffer message.
+ """
+ printer = _Printer(including_default_value_fields,
+ preserving_proto_field_name,
+ use_integers_for_enums)
+ # pylint: disable=protected-access
+ return printer._MessageToJsonObject(message)
def _IsMapEntry(field):
@@ -114,114 +162,208 @@ def _IsMapEntry(field):
field.message_type.GetOptions().map_entry)
-def _RegularMessageToJsonObject(message, js, including_default_value_fields):
- """Converts normal message according to Proto3 JSON Specification."""
- fields = message.ListFields()
- include_default = including_default_value_fields
+class _Printer(object):
+ """JSON format printer for protocol message."""
+
+ def __init__(self,
+ including_default_value_fields=False,
+ preserving_proto_field_name=False,
+ use_integers_for_enums=False):
+ self.including_default_value_fields = including_default_value_fields
+ self.preserving_proto_field_name = preserving_proto_field_name
+ self.use_integers_for_enums = use_integers_for_enums
+
+ def ToJsonString(self, message, indent, sort_keys):
+ js = self._MessageToJsonObject(message)
+ return json.dumps(js, indent=indent, sort_keys=sort_keys)
+
+ def _MessageToJsonObject(self, message):
+ """Converts message to an object according to Proto3 JSON Specification."""
+ message_descriptor = message.DESCRIPTOR
+ full_name = message_descriptor.full_name
+ if _IsWrapperMessage(message_descriptor):
+ return self._WrapperMessageToJsonObject(message)
+ if full_name in _WKTJSONMETHODS:
+ return methodcaller(_WKTJSONMETHODS[full_name][0], message)(self)
+ js = {}
+ return self._RegularMessageToJsonObject(message, js)
+
+ def _RegularMessageToJsonObject(self, message, js):
+ """Converts normal message according to Proto3 JSON Specification."""
+ fields = message.ListFields()
- try:
- for field, value in fields:
- name = field.camelcase_name
- if _IsMapEntry(field):
- # Convert a map field.
- v_field = field.message_type.fields_by_name['value']
- js_map = {}
- for key in value:
- if isinstance(key, bool):
- if key:
- recorded_key = 'true'
+ try:
+ for field, value in fields:
+ if self.preserving_proto_field_name:
+ name = field.name
+ else:
+ name = field.json_name
+ if _IsMapEntry(field):
+ # Convert a map field.
+ v_field = field.message_type.fields_by_name['value']
+ js_map = {}
+ for key in value:
+ if isinstance(key, bool):
+ if key:
+ recorded_key = 'true'
+ else:
+ recorded_key = 'false'
else:
- recorded_key = 'false'
+ recorded_key = key
+ js_map[recorded_key] = self._FieldToJsonObject(
+ v_field, value[key])
+ js[name] = js_map
+ elif field.label == descriptor.FieldDescriptor.LABEL_REPEATED:
+ # Convert a repeated field.
+ js[name] = [self._FieldToJsonObject(field, k)
+ for k in value]
+ elif field.is_extension:
+ f = field
+ if (f.containing_type.GetOptions().message_set_wire_format and
+ f.type == descriptor.FieldDescriptor.TYPE_MESSAGE and
+ f.label == descriptor.FieldDescriptor.LABEL_OPTIONAL):
+ f = f.message_type
+ name = '[%s.%s]' % (f.full_name, name)
+ js[name] = self._FieldToJsonObject(field, value)
+ else:
+ js[name] = self._FieldToJsonObject(field, value)
+
+ # Serialize default value if including_default_value_fields is True.
+ if self.including_default_value_fields:
+ message_descriptor = message.DESCRIPTOR
+ for field in message_descriptor.fields:
+ # Singular message fields and oneof fields will not be affected.
+ if ((field.label != descriptor.FieldDescriptor.LABEL_REPEATED and
+ field.cpp_type == descriptor.FieldDescriptor.CPPTYPE_MESSAGE) or
+ field.containing_oneof):
+ continue
+ if self.preserving_proto_field_name:
+ name = field.name
+ else:
+ name = field.json_name
+ if name in js:
+ # Skip the field which has been serailized already.
+ continue
+ if _IsMapEntry(field):
+ js[name] = {}
+ elif field.label == descriptor.FieldDescriptor.LABEL_REPEATED:
+ js[name] = []
else:
- recorded_key = key
- js_map[recorded_key] = _FieldToJsonObject(
- v_field, value[key], including_default_value_fields)
- js[name] = js_map
- elif field.label == descriptor.FieldDescriptor.LABEL_REPEATED:
- # Convert a repeated field.
- js[name] = [_FieldToJsonObject(field, k, include_default)
- for k in value]
+ js[name] = self._FieldToJsonObject(field, field.default_value)
+
+ except ValueError as e:
+ raise SerializeToJsonError(
+ 'Failed to serialize {0} field: {1}.'.format(field.name, e))
+
+ return js
+
+ def _FieldToJsonObject(self, field, value):
+ """Converts field value according to Proto3 JSON Specification."""
+ if field.cpp_type == descriptor.FieldDescriptor.CPPTYPE_MESSAGE:
+ return self._MessageToJsonObject(value)
+ elif field.cpp_type == descriptor.FieldDescriptor.CPPTYPE_ENUM:
+ if self.use_integers_for_enums:
+ return value
+ enum_value = field.enum_type.values_by_number.get(value, None)
+ if enum_value is not None:
+ return enum_value.name
else:
- js[name] = _FieldToJsonObject(field, value, include_default)
-
- # Serialize default value if including_default_value_fields is True.
- if including_default_value_fields:
- message_descriptor = message.DESCRIPTOR
- for field in message_descriptor.fields:
- # Singular message fields and oneof fields will not be affected.
- if ((field.label != descriptor.FieldDescriptor.LABEL_REPEATED and
- field.cpp_type == descriptor.FieldDescriptor.CPPTYPE_MESSAGE) or
- field.containing_oneof):
- continue
- name = field.camelcase_name
- if name in js:
- # Skip the field which has been serailized already.
- continue
- if _IsMapEntry(field):
- js[name] = {}
- elif field.label == descriptor.FieldDescriptor.LABEL_REPEATED:
- js[name] = []
+ if field.file.syntax == 'proto3':
+ return value
+ raise SerializeToJsonError('Enum field contains an integer value '
+ 'which can not mapped to an enum value.')
+ elif field.cpp_type == descriptor.FieldDescriptor.CPPTYPE_STRING:
+ if field.type == descriptor.FieldDescriptor.TYPE_BYTES:
+ # Use base64 Data encoding for bytes
+ return base64.b64encode(value).decode('utf-8')
+ else:
+ return value
+ elif field.cpp_type == descriptor.FieldDescriptor.CPPTYPE_BOOL:
+ return bool(value)
+ elif field.cpp_type in _INT64_TYPES:
+ return str(value)
+ elif field.cpp_type in _FLOAT_TYPES:
+ if math.isinf(value):
+ if value < 0.0:
+ return _NEG_INFINITY
else:
- js[name] = _FieldToJsonObject(field, field.default_value)
+ return _INFINITY
+ if math.isnan(value):
+ return _NAN
+ return value
+
+ def _AnyMessageToJsonObject(self, message):
+ """Converts Any message according to Proto3 JSON Specification."""
+ if not message.ListFields():
+ return {}
+ # Must print @type first, use OrderedDict instead of {}
+ js = OrderedDict()
+ type_url = message.type_url
+ js['@type'] = type_url
+ sub_message = _CreateMessageFromTypeUrl(type_url)
+ sub_message.ParseFromString(message.value)
+ message_descriptor = sub_message.DESCRIPTOR
+ full_name = message_descriptor.full_name
+ if _IsWrapperMessage(message_descriptor):
+ js['value'] = self._WrapperMessageToJsonObject(sub_message)
+ return js
+ if full_name in _WKTJSONMETHODS:
+ js['value'] = methodcaller(_WKTJSONMETHODS[full_name][0],
+ sub_message)(self)
+ return js
+ return self._RegularMessageToJsonObject(sub_message, js)
+
+ def _GenericMessageToJsonObject(self, message):
+ """Converts message according to Proto3 JSON Specification."""
+ # Duration, Timestamp and FieldMask have ToJsonString method to do the
+ # convert. Users can also call the method directly.
+ return message.ToJsonString()
+
+ def _ValueMessageToJsonObject(self, message):
+ """Converts Value message according to Proto3 JSON Specification."""
+ which = message.WhichOneof('kind')
+ # If the Value message is not set treat as null_value when serialize
+ # to JSON. The parse back result will be different from original message.
+ if which is None or which == 'null_value':
+ return None
+ if which == 'list_value':
+ return self._ListValueMessageToJsonObject(message.list_value)
+ if which == 'struct_value':
+ value = message.struct_value
+ else:
+ value = getattr(message, which)
+ oneof_descriptor = message.DESCRIPTOR.fields_by_name[which]
+ return self._FieldToJsonObject(oneof_descriptor, value)
- except ValueError as e:
- raise SerializeToJsonError(
- 'Failed to serialize {0} field: {1}.'.format(field.name, e))
+ def _ListValueMessageToJsonObject(self, message):
+ """Converts ListValue message according to Proto3 JSON Specification."""
+ return [self._ValueMessageToJsonObject(value)
+ for value in message.values]
- return js
+ def _StructMessageToJsonObject(self, message):
+ """Converts Struct message according to Proto3 JSON Specification."""
+ fields = message.fields
+ ret = {}
+ for key in fields:
+ ret[key] = self._ValueMessageToJsonObject(fields[key])
+ return ret
+ def _WrapperMessageToJsonObject(self, message):
+ return self._FieldToJsonObject(
+ message.DESCRIPTOR.fields_by_name['value'], message.value)
-def _FieldToJsonObject(
- field, value, including_default_value_fields=False):
- """Converts field value according to Proto3 JSON Specification."""
- if field.cpp_type == descriptor.FieldDescriptor.CPPTYPE_MESSAGE:
- return _MessageToJsonObject(value, including_default_value_fields)
- elif field.cpp_type == descriptor.FieldDescriptor.CPPTYPE_ENUM:
- enum_value = field.enum_type.values_by_number.get(value, None)
- if enum_value is not None:
- return enum_value.name
- else:
- raise SerializeToJsonError('Enum field contains an integer value '
- 'which can not mapped to an enum value.')
- elif field.cpp_type == descriptor.FieldDescriptor.CPPTYPE_STRING:
- if field.type == descriptor.FieldDescriptor.TYPE_BYTES:
- # Use base64 Data encoding for bytes
- return base64.b64encode(value).decode('utf-8')
- else:
- return value
- elif field.cpp_type == descriptor.FieldDescriptor.CPPTYPE_BOOL:
- return bool(value)
- elif field.cpp_type in _INT64_TYPES:
- return str(value)
- elif field.cpp_type in _FLOAT_TYPES:
- if math.isinf(value):
- if value < 0.0:
- return _NEG_INFINITY
- else:
- return _INFINITY
- if math.isnan(value):
- return _NAN
- return value
+def _IsWrapperMessage(message_descriptor):
+ return message_descriptor.file.name == 'google/protobuf/wrappers.proto'
-def _AnyMessageToJsonObject(message, including_default):
- """Converts Any message according to Proto3 JSON Specification."""
- if not message.ListFields():
- return {}
- js = {}
- type_url = message.type_url
- js['@type'] = type_url
- sub_message = _CreateMessageFromTypeUrl(type_url)
- sub_message.ParseFromString(message.value)
- message_descriptor = sub_message.DESCRIPTOR
- full_name = message_descriptor.full_name
- if _IsWrapperMessage(message_descriptor):
- js['value'] = _WrapperMessageToJsonObject(sub_message)
- return js
- if full_name in _WKTJSONMETHODS:
- js['value'] = _WKTJSONMETHODS[full_name][0](sub_message, including_default)
- return js
- return _RegularMessageToJsonObject(sub_message, js, including_default)
+
+def _DuplicateChecker(js):
+ result = {}
+ for name, value in js:
+ if name in result:
+ raise ParseError('Failed to load JSON: duplicate key {0}.'.format(name))
+ result[name] = value
+ return result
def _CreateMessageFromTypeUrl(type_url):
@@ -238,69 +380,13 @@ def _CreateMessageFromTypeUrl(type_url):
return message_class()
-def _GenericMessageToJsonObject(message, unused_including_default):
- """Converts message by ToJsonString according to Proto3 JSON Specification."""
- # Duration, Timestamp and FieldMask have ToJsonString method to do the
- # convert. Users can also call the method directly.
- return message.ToJsonString()
-
-
-def _ValueMessageToJsonObject(message, unused_including_default=False):
- """Converts Value message according to Proto3 JSON Specification."""
- which = message.WhichOneof('kind')
- # If the Value message is not set treat as null_value when serialize
- # to JSON. The parse back result will be different from original message.
- if which is None or which == 'null_value':
- return None
- if which == 'list_value':
- return _ListValueMessageToJsonObject(message.list_value)
- if which == 'struct_value':
- value = message.struct_value
- else:
- value = getattr(message, which)
- oneof_descriptor = message.DESCRIPTOR.fields_by_name[which]
- return _FieldToJsonObject(oneof_descriptor, value)
-
-
-def _ListValueMessageToJsonObject(message, unused_including_default=False):
- """Converts ListValue message according to Proto3 JSON Specification."""
- return [_ValueMessageToJsonObject(value)
- for value in message.values]
-
-
-def _StructMessageToJsonObject(message, unused_including_default=False):
- """Converts Struct message according to Proto3 JSON Specification."""
- fields = message.fields
- js = {}
- for key in fields.keys():
- js[key] = _ValueMessageToJsonObject(fields[key])
- return js
-
-
-def _IsWrapperMessage(message_descriptor):
- return message_descriptor.file.name == 'google/protobuf/wrappers.proto'
-
-
-def _WrapperMessageToJsonObject(message):
- return _FieldToJsonObject(
- message.DESCRIPTOR.fields_by_name['value'], message.value)
-
-
-def _DuplicateChecker(js):
- result = {}
- for name, value in js:
- if name in result:
- raise ParseError('Failed to load JSON: duplicate key {0}.'.format(name))
- result[name] = value
- return result
-
-
-def Parse(text, message):
+def Parse(text, message, ignore_unknown_fields=False):
"""Parses a JSON representation of a protocol message into a message.
Args:
text: Message JSON representation.
- message: A protocol beffer message to merge into.
+ message: A protocol buffer message to merge into.
+ ignore_unknown_fields: If True, do not raise errors for unknown fields.
Returns:
The same message passed as argument.
@@ -317,213 +403,255 @@ def Parse(text, message):
js = json.loads(text, object_pairs_hook=_DuplicateChecker)
except ValueError as e:
raise ParseError('Failed to load JSON: {0}.'.format(str(e)))
- _ConvertMessage(js, message)
- return message
+ return ParseDict(js, message, ignore_unknown_fields)
-def _ConvertFieldValuePair(js, message):
- """Convert field value pairs into regular message.
+def ParseDict(js_dict, message, ignore_unknown_fields=False):
+ """Parses a JSON dictionary representation into a message.
Args:
- js: A JSON object to convert the field value pairs.
- message: A regular protocol message to record the data.
+ js_dict: Dict representation of a JSON message.
+ message: A protocol buffer message to merge into.
+ ignore_unknown_fields: If True, do not raise errors for unknown fields.
- Raises:
- ParseError: In case of problems converting.
+ Returns:
+ The same message passed as argument.
"""
- names = []
- message_descriptor = message.DESCRIPTOR
- for name in js:
- try:
- field = message_descriptor.fields_by_camelcase_name.get(name, None)
- if not field:
- raise ParseError(
- 'Message type "{0}" has no field named "{1}".'.format(
- message_descriptor.full_name, name))
- if name in names:
- raise ParseError(
- 'Message type "{0}" should not have multiple "{1}" fields.'.format(
- message.DESCRIPTOR.full_name, name))
- names.append(name)
- # Check no other oneof field is parsed.
- if field.containing_oneof is not None:
- oneof_name = field.containing_oneof.name
- if oneof_name in names:
- raise ParseError('Message type "{0}" should not have multiple "{1}" '
- 'oneof fields.'.format(
- message.DESCRIPTOR.full_name, oneof_name))
- names.append(oneof_name)
-
- value = js[name]
- if value is None:
- message.ClearField(field.name)
- continue
-
- # Parse field value.
- if _IsMapEntry(field):
- message.ClearField(field.name)
- _ConvertMapFieldValue(value, message, field)
- elif field.label == descriptor.FieldDescriptor.LABEL_REPEATED:
- message.ClearField(field.name)
- if not isinstance(value, list):
- raise ParseError('repeated field {0} must be in [] which is '
- '{1}.'.format(name, value))
- if field.cpp_type == descriptor.FieldDescriptor.CPPTYPE_MESSAGE:
- # Repeated message field.
- for item in value:
- sub_message = getattr(message, field.name).add()
- # None is a null_value in Value.
- if (item is None and
- sub_message.DESCRIPTOR.full_name != 'google.protobuf.Value'):
- raise ParseError('null is not allowed to be used as an element'
- ' in a repeated field.')
- _ConvertMessage(item, sub_message)
- else:
- # Repeated scalar field.
- for item in value:
- if item is None:
- raise ParseError('null is not allowed to be used as an element'
- ' in a repeated field.')
- getattr(message, field.name).append(
- _ConvertScalarFieldValue(item, field))
- elif field.cpp_type == descriptor.FieldDescriptor.CPPTYPE_MESSAGE:
- sub_message = getattr(message, field.name)
- _ConvertMessage(value, sub_message)
- else:
- setattr(message, field.name, _ConvertScalarFieldValue(value, field))
- except ParseError as e:
- if field and field.containing_oneof is None:
- raise ParseError('Failed to parse {0} field: {1}'.format(name, e))
- else:
- raise ParseError(str(e))
- except ValueError as e:
- raise ParseError('Failed to parse {0} field: {1}.'.format(name, e))
- except TypeError as e:
- raise ParseError('Failed to parse {0} field: {1}.'.format(name, e))
+ parser = _Parser(ignore_unknown_fields)
+ parser.ConvertMessage(js_dict, message)
+ return message
-def _ConvertMessage(value, message):
- """Convert a JSON object into a message.
+_INT_OR_FLOAT = six.integer_types + (float,)
- Args:
- value: A JSON object.
- message: A WKT or regular protocol message to record the data.
- Raises:
- ParseError: In case of convert problems.
- """
- message_descriptor = message.DESCRIPTOR
- full_name = message_descriptor.full_name
- if _IsWrapperMessage(message_descriptor):
- _ConvertWrapperMessage(value, message)
- elif full_name in _WKTJSONMETHODS:
- _WKTJSONMETHODS[full_name][1](value, message)
- else:
- _ConvertFieldValuePair(value, message)
-
-
-def _ConvertAnyMessage(value, message):
- """Convert a JSON representation into Any message."""
- if isinstance(value, dict) and not value:
- return
- try:
- type_url = value['@type']
- except KeyError:
- raise ParseError('@type is missing when parsing any message.')
-
- sub_message = _CreateMessageFromTypeUrl(type_url)
- message_descriptor = sub_message.DESCRIPTOR
- full_name = message_descriptor.full_name
- if _IsWrapperMessage(message_descriptor):
- _ConvertWrapperMessage(value['value'], sub_message)
- elif full_name in _WKTJSONMETHODS:
- _WKTJSONMETHODS[full_name][1](value['value'], sub_message)
- else:
- del value['@type']
- _ConvertFieldValuePair(value, sub_message)
- # Sets Any message
- message.value = sub_message.SerializeToString()
- message.type_url = type_url
-
-
-def _ConvertGenericMessage(value, message):
- """Convert a JSON representation into message with FromJsonString."""
- # Durantion, Timestamp, FieldMask have FromJsonString method to do the
- # convert. Users can also call the method directly.
- message.FromJsonString(value)
+class _Parser(object):
+ """JSON format parser for protocol message."""
+ def __init__(self,
+ ignore_unknown_fields):
+ self.ignore_unknown_fields = ignore_unknown_fields
-_INT_OR_FLOAT = six.integer_types + (float,)
-
+ def ConvertMessage(self, value, message):
+ """Convert a JSON object into a message.
-def _ConvertValueMessage(value, message):
- """Convert a JSON representation into Value message."""
- if isinstance(value, dict):
- _ConvertStructMessage(value, message.struct_value)
- elif isinstance(value, list):
- _ConvertListValueMessage(value, message.list_value)
- elif value is None:
- message.null_value = 0
- elif isinstance(value, bool):
- message.bool_value = value
- elif isinstance(value, six.string_types):
- message.string_value = value
- elif isinstance(value, _INT_OR_FLOAT):
- message.number_value = value
- else:
- raise ParseError('Unexpected type for Value message.')
-
-
-def _ConvertListValueMessage(value, message):
- """Convert a JSON representation into ListValue message."""
- if not isinstance(value, list):
- raise ParseError(
- 'ListValue must be in [] which is {0}.'.format(value))
- message.ClearField('values')
- for item in value:
- _ConvertValueMessage(item, message.values.add())
-
-
-def _ConvertStructMessage(value, message):
- """Convert a JSON representation into Struct message."""
- if not isinstance(value, dict):
- raise ParseError(
- 'Struct must be in a dict which is {0}.'.format(value))
- for key in value:
- _ConvertValueMessage(value[key], message.fields[key])
- return
-
-
-def _ConvertWrapperMessage(value, message):
- """Convert a JSON representation into Wrapper message."""
- field = message.DESCRIPTOR.fields_by_name['value']
- setattr(message, 'value', _ConvertScalarFieldValue(value, field))
-
-
-def _ConvertMapFieldValue(value, message, field):
- """Convert map field value for a message map field.
+ Args:
+ value: A JSON object.
+ message: A WKT or regular protocol message to record the data.
- Args:
- value: A JSON object to convert the map field value.
- message: A protocol message to record the converted data.
- field: The descriptor of the map field to be converted.
+ Raises:
+ ParseError: In case of convert problems.
+ """
+ message_descriptor = message.DESCRIPTOR
+ full_name = message_descriptor.full_name
+ if _IsWrapperMessage(message_descriptor):
+ self._ConvertWrapperMessage(value, message)
+ elif full_name in _WKTJSONMETHODS:
+ methodcaller(_WKTJSONMETHODS[full_name][1], value, message)(self)
+ else:
+ self._ConvertFieldValuePair(value, message)
+
+ def _ConvertFieldValuePair(self, js, message):
+ """Convert field value pairs into regular message.
+
+ Args:
+ js: A JSON object to convert the field value pairs.
+ message: A regular protocol message to record the data.
+
+ Raises:
+ ParseError: In case of problems converting.
+ """
+ names = []
+ message_descriptor = message.DESCRIPTOR
+ fields_by_json_name = dict((f.json_name, f)
+ for f in message_descriptor.fields)
+ for name in js:
+ try:
+ field = fields_by_json_name.get(name, None)
+ if not field:
+ field = message_descriptor.fields_by_name.get(name, None)
+ if not field and _VALID_EXTENSION_NAME.match(name):
+ if not message_descriptor.is_extendable:
+ raise ParseError('Message type {0} does not have extensions'.format(
+ message_descriptor.full_name))
+ identifier = name[1:-1] # strip [] brackets
+ identifier = '.'.join(identifier.split('.')[:-1])
+ # pylint: disable=protected-access
+ field = message.Extensions._FindExtensionByName(identifier)
+ # pylint: enable=protected-access
+ if not field:
+ if self.ignore_unknown_fields:
+ continue
+ raise ParseError(
+ ('Message type "{0}" has no field named "{1}".\n'
+ ' Available Fields(except extensions): {2}').format(
+ message_descriptor.full_name, name,
+ message_descriptor.fields))
+ if name in names:
+ raise ParseError('Message type "{0}" should not have multiple '
+ '"{1}" fields.'.format(
+ message.DESCRIPTOR.full_name, name))
+ names.append(name)
+ # Check no other oneof field is parsed.
+ if field.containing_oneof is not None:
+ oneof_name = field.containing_oneof.name
+ if oneof_name in names:
+ raise ParseError('Message type "{0}" should not have multiple '
+ '"{1}" oneof fields.'.format(
+ message.DESCRIPTOR.full_name, oneof_name))
+ names.append(oneof_name)
+
+ value = js[name]
+ if value is None:
+ if (field.cpp_type == descriptor.FieldDescriptor.CPPTYPE_MESSAGE
+ and field.message_type.full_name == 'google.protobuf.Value'):
+ sub_message = getattr(message, field.name)
+ sub_message.null_value = 0
+ else:
+ message.ClearField(field.name)
+ continue
- Raises:
- ParseError: In case of convert problems.
- """
- if not isinstance(value, dict):
- raise ParseError(
- 'Map field {0} must be in a dict which is {1}.'.format(
- field.name, value))
- key_field = field.message_type.fields_by_name['key']
- value_field = field.message_type.fields_by_name['value']
- for key in value:
- key_value = _ConvertScalarFieldValue(key, key_field, True)
- if value_field.cpp_type == descriptor.FieldDescriptor.CPPTYPE_MESSAGE:
- _ConvertMessage(value[key], getattr(message, field.name)[key_value])
+ # Parse field value.
+ if _IsMapEntry(field):
+ message.ClearField(field.name)
+ self._ConvertMapFieldValue(value, message, field)
+ elif field.label == descriptor.FieldDescriptor.LABEL_REPEATED:
+ message.ClearField(field.name)
+ if not isinstance(value, list):
+ raise ParseError('repeated field {0} must be in [] which is '
+ '{1}.'.format(name, value))
+ if field.cpp_type == descriptor.FieldDescriptor.CPPTYPE_MESSAGE:
+ # Repeated message field.
+ for item in value:
+ sub_message = getattr(message, field.name).add()
+ # None is a null_value in Value.
+ if (item is None and
+ sub_message.DESCRIPTOR.full_name != 'google.protobuf.Value'):
+ raise ParseError('null is not allowed to be used as an element'
+ ' in a repeated field.')
+ self.ConvertMessage(item, sub_message)
+ else:
+ # Repeated scalar field.
+ for item in value:
+ if item is None:
+ raise ParseError('null is not allowed to be used as an element'
+ ' in a repeated field.')
+ getattr(message, field.name).append(
+ _ConvertScalarFieldValue(item, field))
+ elif field.cpp_type == descriptor.FieldDescriptor.CPPTYPE_MESSAGE:
+ if field.is_extension:
+ sub_message = message.Extensions[field]
+ else:
+ sub_message = getattr(message, field.name)
+ sub_message.SetInParent()
+ self.ConvertMessage(value, sub_message)
+ else:
+ setattr(message, field.name, _ConvertScalarFieldValue(value, field))
+ except ParseError as e:
+ if field and field.containing_oneof is None:
+ raise ParseError('Failed to parse {0} field: {1}'.format(name, e))
+ else:
+ raise ParseError(str(e))
+ except ValueError as e:
+ raise ParseError('Failed to parse {0} field: {1}.'.format(name, e))
+ except TypeError as e:
+ raise ParseError('Failed to parse {0} field: {1}.'.format(name, e))
+
+ def _ConvertAnyMessage(self, value, message):
+ """Convert a JSON representation into Any message."""
+ if isinstance(value, dict) and not value:
+ return
+ try:
+ type_url = value['@type']
+ except KeyError:
+ raise ParseError('@type is missing when parsing any message.')
+
+ sub_message = _CreateMessageFromTypeUrl(type_url)
+ message_descriptor = sub_message.DESCRIPTOR
+ full_name = message_descriptor.full_name
+ if _IsWrapperMessage(message_descriptor):
+ self._ConvertWrapperMessage(value['value'], sub_message)
+ elif full_name in _WKTJSONMETHODS:
+ methodcaller(
+ _WKTJSONMETHODS[full_name][1], value['value'], sub_message)(self)
else:
- getattr(message, field.name)[key_value] = _ConvertScalarFieldValue(
- value[key], value_field)
+ del value['@type']
+ self._ConvertFieldValuePair(value, sub_message)
+ # Sets Any message
+ message.value = sub_message.SerializeToString()
+ message.type_url = type_url
+
+ def _ConvertGenericMessage(self, value, message):
+ """Convert a JSON representation into message with FromJsonString."""
+ # Duration, Timestamp, FieldMask have a FromJsonString method to do the
+ # conversion. Users can also call the method directly.
+ message.FromJsonString(value)
+
+ def _ConvertValueMessage(self, value, message):
+ """Convert a JSON representation into Value message."""
+ if isinstance(value, dict):
+ self._ConvertStructMessage(value, message.struct_value)
+ elif isinstance(value, list):
+ self. _ConvertListValueMessage(value, message.list_value)
+ elif value is None:
+ message.null_value = 0
+ elif isinstance(value, bool):
+ message.bool_value = value
+ elif isinstance(value, six.string_types):
+ message.string_value = value
+ elif isinstance(value, _INT_OR_FLOAT):
+ message.number_value = value
+ else:
+ raise ParseError('Unexpected type for Value message.')
+
+ def _ConvertListValueMessage(self, value, message):
+ """Convert a JSON representation into ListValue message."""
+ if not isinstance(value, list):
+ raise ParseError(
+ 'ListValue must be in [] which is {0}.'.format(value))
+ message.ClearField('values')
+ for item in value:
+ self._ConvertValueMessage(item, message.values.add())
+
+ def _ConvertStructMessage(self, value, message):
+ """Convert a JSON representation into Struct message."""
+ if not isinstance(value, dict):
+ raise ParseError(
+ 'Struct must be in a dict which is {0}.'.format(value))
+ for key in value:
+ self._ConvertValueMessage(value[key], message.fields[key])
+ return
+
+ def _ConvertWrapperMessage(self, value, message):
+ """Convert a JSON representation into Wrapper message."""
+ field = message.DESCRIPTOR.fields_by_name['value']
+ setattr(message, 'value', _ConvertScalarFieldValue(value, field))
+
+ def _ConvertMapFieldValue(self, value, message, field):
+ """Convert map field value for a message map field.
+
+ Args:
+ value: A JSON object to convert the map field value.
+ message: A protocol message to record the converted data.
+ field: The descriptor of the map field to be converted.
+
+ Raises:
+ ParseError: In case of convert problems.
+ """
+ if not isinstance(value, dict):
+ raise ParseError(
+ 'Map field {0} must be in a dict which is {1}.'.format(
+ field.name, value))
+ key_field = field.message_type.fields_by_name['key']
+ value_field = field.message_type.fields_by_name['value']
+ for key in value:
+ key_value = _ConvertScalarFieldValue(key, key_field, True)
+ if value_field.cpp_type == descriptor.FieldDescriptor.CPPTYPE_MESSAGE:
+ self.ConvertMessage(value[key], getattr(
+ message, field.name)[key_value])
+ else:
+ getattr(message, field.name)[key_value] = _ConvertScalarFieldValue(
+ value[key], value_field)
def _ConvertScalarFieldValue(value, field, require_str=False):
@@ -550,15 +678,27 @@ def _ConvertScalarFieldValue(value, field, require_str=False):
if field.type == descriptor.FieldDescriptor.TYPE_BYTES:
return base64.b64decode(value)
else:
+ # Checking for unpaired surrogates appears to be unreliable,
+ # depending on the specific Python version, so we check manually.
+ if _UNPAIRED_SURROGATE_PATTERN.search(value):
+ raise ParseError('Unpaired surrogate')
return value
elif field.cpp_type == descriptor.FieldDescriptor.CPPTYPE_ENUM:
# Convert an enum value.
enum_value = field.enum_type.values_by_name.get(value, None)
if enum_value is None:
- raise ParseError(
- 'Enum value must be a string literal with double quotes. '
- 'Type "{0}" has no value named {1}.'.format(
- field.enum_type.full_name, value))
+ try:
+ number = int(value)
+ enum_value = field.enum_type.values_by_number.get(number, None)
+ except ValueError:
+ raise ParseError('Invalid enum value {0} for enum type {1}.'.format(
+ value, field.enum_type.full_name))
+ if enum_value is None:
+ if field.file.syntax == 'proto3':
+ # Proto3 accepts unknown enums.
+ return number
+ raise ParseError('Invalid enum value {0} for enum type {1}.'.format(
+ value, field.enum_type.full_name))
return enum_value.number
@@ -574,7 +714,7 @@ def _ConvertInteger(value):
Raises:
ParseError: If an integer couldn't be consumed.
"""
- if isinstance(value, float):
+ if isinstance(value, float) and not value.is_integer():
raise ParseError('Couldn\'t parse integer: {0}.'.format(value))
if isinstance(value, six.text_type) and value.find(' ') != -1:
@@ -628,18 +768,18 @@ def _ConvertBool(value, require_str):
return value
_WKTJSONMETHODS = {
- 'google.protobuf.Any': [_AnyMessageToJsonObject,
- _ConvertAnyMessage],
- 'google.protobuf.Duration': [_GenericMessageToJsonObject,
- _ConvertGenericMessage],
- 'google.protobuf.FieldMask': [_GenericMessageToJsonObject,
- _ConvertGenericMessage],
- 'google.protobuf.ListValue': [_ListValueMessageToJsonObject,
- _ConvertListValueMessage],
- 'google.protobuf.Struct': [_StructMessageToJsonObject,
- _ConvertStructMessage],
- 'google.protobuf.Timestamp': [_GenericMessageToJsonObject,
- _ConvertGenericMessage],
- 'google.protobuf.Value': [_ValueMessageToJsonObject,
- _ConvertValueMessage]
+ 'google.protobuf.Any': ['_AnyMessageToJsonObject',
+ '_ConvertAnyMessage'],
+ 'google.protobuf.Duration': ['_GenericMessageToJsonObject',
+ '_ConvertGenericMessage'],
+ 'google.protobuf.FieldMask': ['_GenericMessageToJsonObject',
+ '_ConvertGenericMessage'],
+ 'google.protobuf.ListValue': ['_ListValueMessageToJsonObject',
+ '_ConvertListValueMessage'],
+ 'google.protobuf.Struct': ['_StructMessageToJsonObject',
+ '_ConvertStructMessage'],
+ 'google.protobuf.Timestamp': ['_GenericMessageToJsonObject',
+ '_ConvertGenericMessage'],
+ 'google.protobuf.Value': ['_ValueMessageToJsonObject',
+ '_ConvertValueMessage']
}
diff --git a/python/google/protobuf/message.py b/python/google/protobuf/message.py
index de2f5697..eeb0d576 100755
--- a/python/google/protobuf/message.py
+++ b/python/google/protobuf/message.py
@@ -184,9 +184,15 @@ class Message(object):
self.Clear()
self.MergeFromString(serialized)
- def SerializeToString(self):
+ def SerializeToString(self, **kwargs):
"""Serializes the protocol message to a binary string.
+ Arguments:
+ **kwargs: Keyword arguments to the serialize method, accepts
+ the following keyword args:
+ deterministic: If true, requests deterministic serialization of the
+ protobuf, with predictable ordering of map keys.
+
Returns:
A binary string representation of the message if all of the required
fields in the message are set (i.e. the message is initialized).
@@ -196,12 +202,18 @@ class Message(object):
"""
raise NotImplementedError
- def SerializePartialToString(self):
+ def SerializePartialToString(self, **kwargs):
"""Serializes the protocol message to a binary string.
This method is similar to SerializeToString but doesn't check if the
message is initialized.
+ Arguments:
+ **kwargs: Keyword arguments to the serialize method, accepts
+ the following keyword args:
+ deterministic: If true, requests deterministic serialization of the
+ protobuf, with predictable ordering of map keys.
+
Returns:
A string representation of the partial message.
"""
@@ -225,10 +237,11 @@ class Message(object):
# """
def ListFields(self):
"""Returns a list of (FieldDescriptor, value) tuples for all
- fields in the message which are not empty. A singular field is non-empty
- if HasField() would return true, and a repeated field is non-empty if
- it contains at least one element. The fields are ordered by field
- number"""
+ fields in the message which are not empty. A message field is
+ non-empty if HasField() would return true. A singular primitive field
+ is non-empty if HasField() would return true in proto2 or it is non zero
+ in proto3. A repeated field is non-empty if it contains at least one
+ element. The fields are ordered by field number"""
raise NotImplementedError
def HasField(self, field_name):
@@ -255,6 +268,9 @@ class Message(object):
def ClearExtension(self, extension_handle):
raise NotImplementedError
+ def DiscardUnknownFields(self):
+ raise NotImplementedError
+
def ByteSize(self):
"""Returns the serialized size of this message.
Recursively calls ByteSize() on all contained messages.
diff --git a/python/google/protobuf/message_factory.py b/python/google/protobuf/message_factory.py
index 1b059d13..e4fb065e 100644
--- a/python/google/protobuf/message_factory.py
+++ b/python/google/protobuf/message_factory.py
@@ -66,7 +66,7 @@ class MessageFactory(object):
Returns:
A class describing the passed in descriptor.
"""
- if descriptor.full_name not in self._classes:
+ if descriptor not in self._classes:
descriptor_name = descriptor.name
if str is bytes: # PY2
descriptor_name = descriptor.name.encode('ascii', 'ignore')
@@ -75,16 +75,16 @@ class MessageFactory(object):
(message.Message,),
{'DESCRIPTOR': descriptor, '__module__': None})
# If module not set, it wrongly points to the reflection.py module.
- self._classes[descriptor.full_name] = result_class
+ self._classes[descriptor] = result_class
for field in descriptor.fields:
if field.message_type:
self.GetPrototype(field.message_type)
for extension in result_class.DESCRIPTOR.extensions:
- if extension.containing_type.full_name not in self._classes:
+ if extension.containing_type not in self._classes:
self.GetPrototype(extension.containing_type)
- extended_class = self._classes[extension.containing_type.full_name]
+ extended_class = self._classes[extension.containing_type]
extended_class.RegisterExtension(extension)
- return self._classes[descriptor.full_name]
+ return self._classes[descriptor]
def GetMessages(self, files):
"""Gets all the messages from a specified file.
@@ -103,13 +103,8 @@ class MessageFactory(object):
result = {}
for file_name in files:
file_desc = self.pool.FindFileByName(file_name)
- for name, msg in file_desc.message_types_by_name.items():
- if file_desc.package:
- full_name = '.'.join([file_desc.package, name])
- else:
- full_name = msg.name
- result[full_name] = self.GetPrototype(
- self.pool.FindMessageTypeByName(full_name))
+ for desc in file_desc.message_types_by_name.values():
+ result[desc.full_name] = self.GetPrototype(desc)
# While the extension FieldDescriptors are created by the descriptor pool,
# the python classes created in the factory need them to be registered
@@ -120,10 +115,10 @@ class MessageFactory(object):
# ignore the registration if the original was the same, or raise
# an error if they were different.
- for name, extension in file_desc.extensions_by_name.items():
- if extension.containing_type.full_name not in self._classes:
+ for extension in file_desc.extensions_by_name.values():
+ if extension.containing_type not in self._classes:
self.GetPrototype(extension.containing_type)
- extended_class = self._classes[extension.containing_type.full_name]
+ extended_class = self._classes[extension.containing_type]
extended_class.RegisterExtension(extension)
return result
@@ -135,13 +130,22 @@ def GetMessages(file_protos):
"""Builds a dictionary of all the messages available in a set of files.
Args:
- file_protos: A sequence of file protos to build messages out of.
+ file_protos: Iterable of FileDescriptorProto to build messages out of.
Returns:
A dictionary mapping proto names to the message classes. This will include
any dependent messages as well as any messages defined in the same file as
a specified message.
"""
- for file_proto in file_protos:
+ # The cpp implementation of the protocol buffer library requires to add the
+ # message in topological order of the dependency graph.
+ file_by_name = {file_proto.name: file_proto for file_proto in file_protos}
+ def _AddFile(file_proto):
+ for dependency in file_proto.dependency:
+ if dependency in file_by_name:
+ # Remove from elements to be visited, in order to cut cycles.
+ _AddFile(file_by_name.pop(dependency))
_FACTORY.pool.Add(file_proto)
+ while file_by_name:
+ _AddFile(file_by_name.popitem()[1])
return _FACTORY.GetMessages([file_proto.name for file_proto in file_protos])
diff --git a/python/google/protobuf/proto_api.h b/python/google/protobuf/proto_api.h
new file mode 100644
index 00000000..64d8dda9
--- /dev/null
+++ b/python/google/protobuf/proto_api.h
@@ -0,0 +1,92 @@
+// Protocol Buffers - Google's data interchange format
+// Copyright 2008 Google Inc. All rights reserved.
+// https://developers.google.com/protocol-buffers/
+//
+// Redistribution and use in source and binary forms, with or without
+// modification, are permitted provided that the following conditions are
+// met:
+//
+// * Redistributions of source code must retain the above copyright
+// notice, this list of conditions and the following disclaimer.
+// * Redistributions in binary form must reproduce the above
+// copyright notice, this list of conditions and the following disclaimer
+// in the documentation and/or other materials provided with the
+// distribution.
+// * Neither the name of Google Inc. nor the names of its
+// contributors may be used to endorse or promote products derived from
+// this software without specific prior written permission.
+//
+// THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS
+// "AS IS" AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT
+// LIMITED TO, THE IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR
+// A PARTICULAR PURPOSE ARE DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT
+// OWNER OR CONTRIBUTORS BE LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL,
+// SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT
+// LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE,
+// DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY
+// THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT
+// (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
+// OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
+
+// This file can be included by other C++ libraries, typically extension modules
+// which want to interact with the Python Messages coming from the "cpp"
+// implementation of protocol buffers.
+//
+// Usage:
+// Declare a (probably static) variable to hold the API:
+// const PyProto_API* py_proto_api;
+// In some initialization function, write:
+// py_proto_api = static_cast<const PyProto_API*>(PyCapsule_Import(
+// PyProtoAPICapsuleName(), 0));
+// if (!py_proto_api) { ...handle ImportError... }
+// Then use the methods of the returned class:
+// py_proto_api->GetMessagePointer(...);
+
+#ifndef PYTHON_GOOGLE_PROTOBUF_PROTO_API_H__
+#define PYTHON_GOOGLE_PROTOBUF_PROTO_API_H__
+
+#include <Python.h>
+
+namespace google {
+namespace protobuf {
+
+class Message;
+
+namespace python {
+
+// Note on the implementation:
+// This API is designed after
+// https://docs.python.org/3/extending/extending.html#providing-a-c-api-for-an-extension-module
+// The class below contains no mutable state, and all methods are "const";
+// we use a C++ class instead of a C struct with functions pointers just because
+// the code looks more readable.
+struct PyProto_API {
+ // The API object is created at initialization time and never freed.
+ // This destructor is never called.
+ virtual ~PyProto_API() {}
+
+ // Operations on Messages.
+
+ // If the passed object is a Python Message, returns its internal pointer.
+ // Otherwise, returns NULL with an exception set.
+ virtual const Message* GetMessagePointer(PyObject* msg) const = 0;
+
+ // If the passed object is a Python Message, returns a mutable pointer.
+ // Otherwise, returns NULL with an exception set.
+ // This function will succeed only if there are no other Python objects
+ // pointing to the message, like submessages or repeated containers.
+ // With the current implementation, only empty messages are in this case.
+ virtual Message* GetMutableMessagePointer(PyObject* msg) const = 0;
+};
+
+inline const char* PyProtoAPICapsuleName() {
+ static const char kCapsuleName[] =
+ "protobuf.python.google.protobuf.cpp._message.proto_API";
+ return kCapsuleName;
+}
+
+} // namespace python
+} // namespace protobuf
+} // namespace google
+
+#endif // PYTHON_GOOGLE_PROTOBUF_PROTO_API_H__
diff --git a/python/google/protobuf/pyext/__init__.py b/python/google/protobuf/pyext/__init__.py
index e69de29b..55856141 100644
--- a/python/google/protobuf/pyext/__init__.py
+++ b/python/google/protobuf/pyext/__init__.py
@@ -0,0 +1,4 @@
+try:
+ __import__('pkg_resources').declare_namespace(__name__)
+except ImportError:
+ __path__ = __import__('pkgutil').extend_path(__path__, __name__)
diff --git a/python/google/protobuf/pyext/cpp_message.py b/python/google/protobuf/pyext/cpp_message.py
index b215211e..fc8eb32d 100644
--- a/python/google/protobuf/pyext/cpp_message.py
+++ b/python/google/protobuf/pyext/cpp_message.py
@@ -48,9 +48,9 @@ class GeneratedProtocolMessageType(_message.MessageMeta):
classes at runtime, as in this example:
mydescriptor = Descriptor(.....)
- class MyProtoClass(Message):
- __metaclass__ = GeneratedProtocolMessageType
- DESCRIPTOR = mydescriptor
+ factory = symbol_database.Default()
+ factory.pool.AddDescriptor(mydescriptor)
+ MyProtoClass = factory.GetPrototype(mydescriptor)
myproto_instance = MyProtoClass()
myproto.foo_field = 23
...
diff --git a/python/google/protobuf/pyext/descriptor.cc b/python/google/protobuf/pyext/descriptor.cc
index a875a7be..8af0cb12 100644
--- a/python/google/protobuf/pyext/descriptor.cc
+++ b/python/google/protobuf/pyext/descriptor.cc
@@ -32,6 +32,7 @@
#include <Python.h>
#include <frameobject.h>
+#include <google/protobuf/stubs/hash.h>
#include <string>
#include <google/protobuf/io/coded_stream.h>
@@ -41,6 +42,7 @@
#include <google/protobuf/pyext/descriptor_containers.h>
#include <google/protobuf/pyext/descriptor_pool.h>
#include <google/protobuf/pyext/message.h>
+#include <google/protobuf/pyext/message_factory.h>
#include <google/protobuf/pyext/scoped_pyobject_ptr.h>
#if PY_MAJOR_VERSION >= 3
@@ -92,11 +94,10 @@ PyObject* PyString_FromCppString(const string& str) {
// TODO(amauryfa): Change the proto2 compiler to remove the assignments, and
// remove this hack.
bool _CalledFromGeneratedFile(int stacklevel) {
- PyThreadState *state = PyThreadState_GET();
- if (state == NULL) {
- return false;
- }
- PyFrameObject* frame = state->frame;
+#ifndef PYPY_VERSION
+ // This check is not critical and is somewhat difficult to implement correctly
+ // in PyPy.
+ PyFrameObject* frame = PyEval_GetFrame();
if (frame == NULL) {
return false;
}
@@ -106,10 +107,6 @@ bool _CalledFromGeneratedFile(int stacklevel) {
return false;
}
}
- if (frame->f_globals != frame->f_locals) {
- // Not at global module scope
- return false;
- }
if (frame->f_code->co_filename == NULL) {
return false;
@@ -122,6 +119,10 @@ bool _CalledFromGeneratedFile(int stacklevel) {
PyErr_Clear();
return false;
}
+ if ((filename_size < 3) || (strcmp(&filename[filename_size - 3], ".py") != 0)) {
+ // Cython's stack does not have .py file name and is not at global module scope.
+ return true;
+ }
if (filename_size < 7) {
// filename is too short.
return false;
@@ -130,6 +131,12 @@ bool _CalledFromGeneratedFile(int stacklevel) {
// Filename is not ending with _pb2.
return false;
}
+
+ if (frame->f_globals != frame->f_locals) {
+ // Not at global module scope
+ return false;
+ }
+#endif
return true;
}
@@ -172,50 +179,56 @@ template<>
const FileDescriptor* GetFileDescriptor(const OneofDescriptor* descriptor) {
return descriptor->containing_type()->file();
}
+template<>
+const FileDescriptor* GetFileDescriptor(const MethodDescriptor* descriptor) {
+ return descriptor->service()->file();
+}
// Converts options into a Python protobuf, and cache the result.
//
// This is a bit tricky because options can contain extension fields defined in
// the same proto file. In this case the options parsed from the serialized_pb
-// have unkown fields, and we need to parse them again.
+// have unknown fields, and we need to parse them again.
//
// Always returns a new reference.
template<class DescriptorClass>
static PyObject* GetOrBuildOptions(const DescriptorClass *descriptor) {
- // Options (and their extensions) are completely resolved in the proto file
- // containing the descriptor.
- PyDescriptorPool* pool = GetDescriptorPool_FromPool(
+ // Options are cached in the pool that owns the descriptor.
+ // First search in the cache.
+ PyDescriptorPool* caching_pool = GetDescriptorPool_FromPool(
GetFileDescriptor(descriptor)->pool());
-
hash_map<const void*, PyObject*>* descriptor_options =
- pool->descriptor_options;
- // First search in the cache.
+ caching_pool->descriptor_options;
if (descriptor_options->find(descriptor) != descriptor_options->end()) {
PyObject *value = (*descriptor_options)[descriptor];
Py_INCREF(value);
return value;
}
+ // Similar to the C++ implementation, we return an Options object from the
+ // default (generated) factory, so that client code know that they can use
+ // extensions from generated files:
+ // d.GetOptions().Extensions[some_pb2.extension]
+ //
+ // The consequence is that extensions not defined in the default pool won't
+ // be available. If needed, we could add an optional 'message_factory'
+ // parameter to the GetOptions() function.
+ PyMessageFactory* message_factory =
+ GetDefaultDescriptorPool()->py_message_factory;
+
// Build the Options object: get its Python class, and make a copy of the C++
// read-only instance.
const Message& options(descriptor->options());
const Descriptor *message_type = options.GetDescriptor();
- PyObject* message_class(cdescriptor_pool::GetMessageClass(
- pool, message_type));
- if (message_class == NULL) {
- // The Options message was not found in the current DescriptorPool.
- // In this case, there cannot be extensions to these options, and we can
- // try to use the basic pool instead.
- PyErr_Clear();
- message_class = cdescriptor_pool::GetMessageClass(
- GetDefaultDescriptorPool(), message_type);
- }
+ CMessageClass* message_class = message_factory::GetOrCreateMessageClass(
+ message_factory, message_type);
if (message_class == NULL) {
PyErr_Format(PyExc_TypeError, "Could not retrieve class for Options: %s",
message_type->full_name().c_str());
return NULL;
}
- ScopedPyObjectPtr value(PyEval_CallObject(message_class, NULL));
+ ScopedPyObjectPtr value(
+ PyEval_CallObject(message_class->AsPyObject(), NULL));
if (value == NULL) {
return NULL;
}
@@ -237,7 +250,8 @@ static PyObject* GetOrBuildOptions(const DescriptorClass *descriptor) {
options.SerializeToString(&serialized);
io::CodedInputStream input(
reinterpret_cast<const uint8*>(serialized.c_str()), serialized.size());
- input.SetExtensionRegistry(pool->pool, pool->message_factory);
+ input.SetExtensionRegistry(message_factory->pool->pool,
+ message_factory->message_factory);
bool success = cmsg->message->MergePartialFromCodedStream(&input);
if (!success) {
PyErr_Format(PyExc_ValueError, "Error parsing Options message");
@@ -247,7 +261,7 @@ static PyObject* GetOrBuildOptions(const DescriptorClass *descriptor) {
// Cache the result.
Py_INCREF(value.get());
- (*pool->descriptor_options)[descriptor] = value.get();
+ (*descriptor_options)[descriptor] = value.get();
return value.release();
}
@@ -433,11 +447,12 @@ static PyObject* GetConcreteClass(PyBaseDescriptor* self, void *closure) {
// which contains this descriptor.
// This might not be the one you expect! For example the returned object does
// not know about extensions defined in a custom pool.
- PyObject* concrete_class(cdescriptor_pool::GetMessageClass(
- GetDescriptorPool_FromPool(_GetDescriptor(self)->file()->pool()),
+ CMessageClass* concrete_class(message_factory::GetMessageClass(
+ GetDescriptorPool_FromPool(
+ _GetDescriptor(self)->file()->pool())->py_message_factory,
_GetDescriptor(self)));
Py_XINCREF(concrete_class);
- return concrete_class;
+ return concrete_class->AsPyObject();
}
static PyObject* GetFieldsByName(PyBaseDescriptor* self, void *closure) {
@@ -552,6 +567,11 @@ static int SetOptions(PyBaseDescriptor *self, PyObject *value,
return CheckCalledFromGeneratedFile("_options");
}
+static int SetSerializedOptions(PyBaseDescriptor *self, PyObject *value,
+ void *closure) {
+ return CheckCalledFromGeneratedFile("_serialized_options");
+}
+
static PyObject* CopyToProto(PyBaseDescriptor *self, PyObject *target) {
return CopyToPythonProto<DescriptorProto>(_GetDescriptor(self), target);
}
@@ -611,6 +631,8 @@ static PyGetSetDef Getters[] = {
{ "is_extendable", (getter)IsExtendable, (setter)NULL},
{ "has_options", (getter)GetHasOptions, (setter)SetHasOptions, "Has Options"},
{ "_options", (getter)NULL, (setter)SetOptions, "Options"},
+ { "_serialized_options", (getter)NULL, (setter)SetSerializedOptions,
+ "Serialized Options"},
{ "syntax", (getter)GetSyntax, (setter)NULL, "Syntax"},
{NULL}
};
@@ -693,6 +715,14 @@ static PyObject* GetCamelcaseName(PyBaseDescriptor* self, void *closure) {
return PyString_FromCppString(_GetDescriptor(self)->camelcase_name());
}
+static PyObject* GetJsonName(PyBaseDescriptor* self, void *closure) {
+ return PyString_FromCppString(_GetDescriptor(self)->json_name());
+}
+
+static PyObject* GetFile(PyBaseDescriptor *self, void *closure) {
+ return PyFileDescriptor_FromDescriptor(_GetDescriptor(self)->file());
+}
+
static PyObject* GetType(PyBaseDescriptor *self, void *closure) {
return PyInt_FromLong(_GetDescriptor(self)->type());
}
@@ -765,7 +795,7 @@ static PyObject* GetDefaultValue(PyBaseDescriptor *self, void *closure) {
break;
}
case FieldDescriptor::CPPTYPE_STRING: {
- string value = _GetDescriptor(self)->default_value_string();
+ const string& value = _GetDescriptor(self)->default_value_string();
result = ToStringObject(_GetDescriptor(self), value);
break;
}
@@ -877,11 +907,17 @@ static int SetOptions(PyBaseDescriptor *self, PyObject *value,
return CheckCalledFromGeneratedFile("_options");
}
+static int SetSerializedOptions(PyBaseDescriptor *self, PyObject *value,
+ void *closure) {
+ return CheckCalledFromGeneratedFile("_serialized_options");
+}
static PyGetSetDef Getters[] = {
{ "full_name", (getter)GetFullName, NULL, "Full name"},
{ "name", (getter)GetName, NULL, "Unqualified name"},
{ "camelcase_name", (getter)GetCamelcaseName, NULL, "Camelcase name"},
+ { "json_name", (getter)GetJsonName, NULL, "Json name"},
+ { "file", (getter)GetFile, NULL, "File Descriptor"},
{ "type", (getter)GetType, NULL, "C++ Type"},
{ "cpp_type", (getter)GetCppType, NULL, "C++ Type"},
{ "label", (getter)GetLabel, NULL, "Label"},
@@ -904,6 +940,8 @@ static PyGetSetDef Getters[] = {
"Containing oneof"},
{ "has_options", (getter)GetHasOptions, (setter)SetHasOptions, "Has Options"},
{ "_options", (getter)NULL, (setter)SetOptions, "Options"},
+ { "_serialized_options", (getter)NULL, (setter)SetSerializedOptions,
+ "Serialized Options"},
{NULL}
};
@@ -1033,6 +1071,11 @@ static int SetOptions(PyBaseDescriptor *self, PyObject *value,
return CheckCalledFromGeneratedFile("_options");
}
+static int SetSerializedOptions(PyBaseDescriptor *self, PyObject *value,
+ void *closure) {
+ return CheckCalledFromGeneratedFile("_serialized_options");
+}
+
static PyObject* CopyToProto(PyBaseDescriptor *self, PyObject *target) {
return CopyToPythonProto<EnumDescriptorProto>(_GetDescriptor(self), target);
}
@@ -1057,6 +1100,8 @@ static PyGetSetDef Getters[] = {
"Containing type"},
{ "has_options", (getter)GetHasOptions, (setter)SetHasOptions, "Has Options"},
{ "_options", (getter)NULL, (setter)SetOptions, "Options"},
+ { "_serialized_options", (getter)NULL, (setter)SetSerializedOptions,
+ "Serialized Options"},
{NULL}
};
@@ -1090,7 +1135,7 @@ PyTypeObject PyEnumDescriptor_Type = {
0, // tp_weaklistoffset
0, // tp_iter
0, // tp_iternext
- enum_descriptor::Methods, // tp_getset
+ enum_descriptor::Methods, // tp_methods
0, // tp_members
enum_descriptor::Getters, // tp_getset
&descriptor::PyBaseDescriptor_Type, // tp_base
@@ -1157,6 +1202,10 @@ static int SetOptions(PyBaseDescriptor *self, PyObject *value,
return CheckCalledFromGeneratedFile("_options");
}
+static int SetSerializedOptions(PyBaseDescriptor *self, PyObject *value,
+ void *closure) {
+ return CheckCalledFromGeneratedFile("_serialized_options");
+}
static PyGetSetDef Getters[] = {
{ "name", (getter)GetName, NULL, "name"},
@@ -1166,6 +1215,8 @@ static PyGetSetDef Getters[] = {
{ "has_options", (getter)GetHasOptions, (setter)SetHasOptions, "Has Options"},
{ "_options", (getter)NULL, (setter)SetOptions, "Options"},
+ { "_serialized_options", (getter)NULL, (setter)SetSerializedOptions,
+ "Serialized Options"},
{NULL}
};
@@ -1274,6 +1325,10 @@ static PyObject* GetExtensionsByName(PyFileDescriptor* self, void *closure) {
return NewFileExtensionsByName(_GetDescriptor(self));
}
+static PyObject* GetServicesByName(PyFileDescriptor* self, void *closure) {
+ return NewFileServicesByName(_GetDescriptor(self));
+}
+
static PyObject* GetDependencies(PyFileDescriptor* self, void *closure) {
return NewFileDependencies(_GetDescriptor(self));
}
@@ -1304,6 +1359,11 @@ static int SetOptions(PyFileDescriptor *self, PyObject *value,
return CheckCalledFromGeneratedFile("_options");
}
+static int SetSerializedOptions(PyFileDescriptor *self, PyObject *value,
+ void *closure) {
+ return CheckCalledFromGeneratedFile("_serialized_options");
+}
+
static PyObject* GetSyntax(PyFileDescriptor *self, void *closure) {
return PyString_InternFromString(
FileDescriptor::SyntaxName(_GetDescriptor(self)->syntax()));
@@ -1323,11 +1383,14 @@ static PyGetSetDef Getters[] = {
{ "enum_types_by_name", (getter)GetEnumTypesByName, NULL, "Enums by name"},
{ "extensions_by_name", (getter)GetExtensionsByName, NULL,
"Extensions by name"},
+ { "services_by_name", (getter)GetServicesByName, NULL, "Services by name"},
{ "dependencies", (getter)GetDependencies, NULL, "Dependencies"},
{ "public_dependencies", (getter)GetPublicDependencies, NULL, "Dependencies"},
{ "has_options", (getter)GetHasOptions, (setter)SetHasOptions, "Has Options"},
{ "_options", (getter)NULL, (setter)SetOptions, "Options"},
+ { "_serialized_options", (getter)NULL, (setter)SetSerializedOptions,
+ "Serialized Options"},
{ "syntax", (getter)GetSyntax, (setter)NULL, "Syntax"},
{NULL}
};
@@ -1451,16 +1514,52 @@ static PyObject* GetContainingType(PyBaseDescriptor *self, void *closure) {
}
}
+static PyObject* GetHasOptions(PyBaseDescriptor *self, void *closure) {
+ const OneofOptions& options(_GetDescriptor(self)->options());
+ if (&options != &OneofOptions::default_instance()) {
+ Py_RETURN_TRUE;
+ } else {
+ Py_RETURN_FALSE;
+ }
+}
+static int SetHasOptions(PyBaseDescriptor *self, PyObject *value,
+ void *closure) {
+ return CheckCalledFromGeneratedFile("has_options");
+}
+
+static PyObject* GetOptions(PyBaseDescriptor *self) {
+ return GetOrBuildOptions(_GetDescriptor(self));
+}
+
+static int SetOptions(PyBaseDescriptor *self, PyObject *value,
+ void *closure) {
+ return CheckCalledFromGeneratedFile("_options");
+}
+
+static int SetSerializedOptions(PyBaseDescriptor *self, PyObject *value,
+ void *closure) {
+ return CheckCalledFromGeneratedFile("_serialized_options");
+}
+
static PyGetSetDef Getters[] = {
{ "name", (getter)GetName, NULL, "Name"},
{ "full_name", (getter)GetFullName, NULL, "Full name"},
{ "index", (getter)GetIndex, NULL, "Index"},
{ "containing_type", (getter)GetContainingType, NULL, "Containing type"},
+ { "has_options", (getter)GetHasOptions, (setter)SetHasOptions, "Has Options"},
+ { "_options", (getter)NULL, (setter)SetOptions, "Options"},
+ { "_serialized_options", (getter)NULL, (setter)SetSerializedOptions,
+ "Serialized Options"},
{ "fields", (getter)GetFields, NULL, "Fields"},
{NULL}
};
+static PyMethodDef Methods[] = {
+ { "GetOptions", (PyCFunction)GetOptions, METH_NOARGS },
+ {NULL}
+};
+
} // namespace oneof_descriptor
PyTypeObject PyOneofDescriptor_Type = {
@@ -1491,7 +1590,7 @@ PyTypeObject PyOneofDescriptor_Type = {
0, // tp_weaklistoffset
0, // tp_iter
0, // tp_iternext
- 0, // tp_methods
+ oneof_descriptor::Methods, // tp_methods
0, // tp_members
oneof_descriptor::Getters, // tp_getset
&descriptor::PyBaseDescriptor_Type, // tp_base
@@ -1503,6 +1602,245 @@ PyObject* PyOneofDescriptor_FromDescriptor(
&PyOneofDescriptor_Type, oneof_descriptor, NULL);
}
+namespace service_descriptor {
+
+// Unchecked accessor to the C++ pointer.
+static const ServiceDescriptor* _GetDescriptor(
+ PyBaseDescriptor *self) {
+ return reinterpret_cast<const ServiceDescriptor*>(self->descriptor);
+}
+
+static PyObject* GetName(PyBaseDescriptor* self, void *closure) {
+ return PyString_FromCppString(_GetDescriptor(self)->name());
+}
+
+static PyObject* GetFullName(PyBaseDescriptor* self, void *closure) {
+ return PyString_FromCppString(_GetDescriptor(self)->full_name());
+}
+
+static PyObject* GetFile(PyBaseDescriptor *self, void *closure) {
+ return PyFileDescriptor_FromDescriptor(_GetDescriptor(self)->file());
+}
+
+static PyObject* GetIndex(PyBaseDescriptor *self, void *closure) {
+ return PyInt_FromLong(_GetDescriptor(self)->index());
+}
+
+static PyObject* GetMethods(PyBaseDescriptor* self, void *closure) {
+ return NewServiceMethodsSeq(_GetDescriptor(self));
+}
+
+static PyObject* GetMethodsByName(PyBaseDescriptor* self, void *closure) {
+ return NewServiceMethodsByName(_GetDescriptor(self));
+}
+
+static PyObject* FindMethodByName(PyBaseDescriptor *self, PyObject* arg) {
+ Py_ssize_t name_size;
+ char* name;
+ if (PyString_AsStringAndSize(arg, &name, &name_size) < 0) {
+ return NULL;
+ }
+
+ const MethodDescriptor* method_descriptor =
+ _GetDescriptor(self)->FindMethodByName(string(name, name_size));
+ if (method_descriptor == NULL) {
+ PyErr_Format(PyExc_KeyError, "Couldn't find method %.200s", name);
+ return NULL;
+ }
+
+ return PyMethodDescriptor_FromDescriptor(method_descriptor);
+}
+
+static PyObject* GetOptions(PyBaseDescriptor *self) {
+ return GetOrBuildOptions(_GetDescriptor(self));
+}
+
+static PyObject* CopyToProto(PyBaseDescriptor *self, PyObject *target) {
+ return CopyToPythonProto<ServiceDescriptorProto>(_GetDescriptor(self),
+ target);
+}
+
+static PyGetSetDef Getters[] = {
+ { "name", (getter)GetName, NULL, "Name", NULL},
+ { "full_name", (getter)GetFullName, NULL, "Full name", NULL},
+ { "file", (getter)GetFile, NULL, "File descriptor"},
+ { "index", (getter)GetIndex, NULL, "Index", NULL},
+
+ { "methods", (getter)GetMethods, NULL, "Methods", NULL},
+ { "methods_by_name", (getter)GetMethodsByName, NULL, "Methods by name", NULL},
+ {NULL}
+};
+
+static PyMethodDef Methods[] = {
+ { "GetOptions", (PyCFunction)GetOptions, METH_NOARGS },
+ { "CopyToProto", (PyCFunction)CopyToProto, METH_O, },
+ { "FindMethodByName", (PyCFunction)FindMethodByName, METH_O },
+ {NULL}
+};
+
+} // namespace service_descriptor
+
+PyTypeObject PyServiceDescriptor_Type = {
+ PyVarObject_HEAD_INIT(&PyType_Type, 0)
+ FULL_MODULE_NAME ".ServiceDescriptor", // tp_name
+ sizeof(PyBaseDescriptor), // tp_basicsize
+ 0, // tp_itemsize
+ 0, // tp_dealloc
+ 0, // tp_print
+ 0, // tp_getattr
+ 0, // tp_setattr
+ 0, // tp_compare
+ 0, // tp_repr
+ 0, // tp_as_number
+ 0, // tp_as_sequence
+ 0, // tp_as_mapping
+ 0, // tp_hash
+ 0, // tp_call
+ 0, // tp_str
+ 0, // tp_getattro
+ 0, // tp_setattro
+ 0, // tp_as_buffer
+ Py_TPFLAGS_DEFAULT, // tp_flags
+ "A Service Descriptor", // tp_doc
+ 0, // tp_traverse
+ 0, // tp_clear
+ 0, // tp_richcompare
+ 0, // tp_weaklistoffset
+ 0, // tp_iter
+ 0, // tp_iternext
+ service_descriptor::Methods, // tp_methods
+ 0, // tp_members
+ service_descriptor::Getters, // tp_getset
+ &descriptor::PyBaseDescriptor_Type, // tp_base
+};
+
+PyObject* PyServiceDescriptor_FromDescriptor(
+ const ServiceDescriptor* service_descriptor) {
+ return descriptor::NewInternedDescriptor(
+ &PyServiceDescriptor_Type, service_descriptor, NULL);
+}
+
+const ServiceDescriptor* PyServiceDescriptor_AsDescriptor(PyObject* obj) {
+ if (!PyObject_TypeCheck(obj, &PyServiceDescriptor_Type)) {
+ PyErr_SetString(PyExc_TypeError, "Not a ServiceDescriptor");
+ return NULL;
+ }
+ return reinterpret_cast<const ServiceDescriptor*>(
+ reinterpret_cast<PyBaseDescriptor*>(obj)->descriptor);
+}
+
+namespace method_descriptor {
+
+// Unchecked accessor to the C++ pointer.
+static const MethodDescriptor* _GetDescriptor(
+ PyBaseDescriptor *self) {
+ return reinterpret_cast<const MethodDescriptor*>(self->descriptor);
+}
+
+static PyObject* GetName(PyBaseDescriptor* self, void *closure) {
+ return PyString_FromCppString(_GetDescriptor(self)->name());
+}
+
+static PyObject* GetFullName(PyBaseDescriptor* self, void *closure) {
+ return PyString_FromCppString(_GetDescriptor(self)->full_name());
+}
+
+static PyObject* GetIndex(PyBaseDescriptor *self, void *closure) {
+ return PyInt_FromLong(_GetDescriptor(self)->index());
+}
+
+static PyObject* GetContainingService(PyBaseDescriptor *self, void *closure) {
+ const ServiceDescriptor* containing_service =
+ _GetDescriptor(self)->service();
+ return PyServiceDescriptor_FromDescriptor(containing_service);
+}
+
+static PyObject* GetInputType(PyBaseDescriptor *self, void *closure) {
+ const Descriptor* input_type = _GetDescriptor(self)->input_type();
+ return PyMessageDescriptor_FromDescriptor(input_type);
+}
+
+static PyObject* GetOutputType(PyBaseDescriptor *self, void *closure) {
+ const Descriptor* output_type = _GetDescriptor(self)->output_type();
+ return PyMessageDescriptor_FromDescriptor(output_type);
+}
+
+static PyObject* GetOptions(PyBaseDescriptor *self) {
+ return GetOrBuildOptions(_GetDescriptor(self));
+}
+
+static PyObject* CopyToProto(PyBaseDescriptor *self, PyObject *target) {
+ return CopyToPythonProto<MethodDescriptorProto>(_GetDescriptor(self), target);
+}
+
+static PyGetSetDef Getters[] = {
+ { "name", (getter)GetName, NULL, "Name", NULL},
+ { "full_name", (getter)GetFullName, NULL, "Full name", NULL},
+ { "index", (getter)GetIndex, NULL, "Index", NULL},
+ { "containing_service", (getter)GetContainingService, NULL,
+ "Containing service", NULL},
+ { "input_type", (getter)GetInputType, NULL, "Input type", NULL},
+ { "output_type", (getter)GetOutputType, NULL, "Output type", NULL},
+ {NULL}
+};
+
+static PyMethodDef Methods[] = {
+ { "GetOptions", (PyCFunction)GetOptions, METH_NOARGS, },
+ { "CopyToProto", (PyCFunction)CopyToProto, METH_O, },
+ {NULL}
+};
+
+} // namespace method_descriptor
+
+PyTypeObject PyMethodDescriptor_Type = {
+ PyVarObject_HEAD_INIT(&PyType_Type, 0)
+ FULL_MODULE_NAME ".MethodDescriptor", // tp_name
+ sizeof(PyBaseDescriptor), // tp_basicsize
+ 0, // tp_itemsize
+ 0, // tp_dealloc
+ 0, // tp_print
+ 0, // tp_getattr
+ 0, // tp_setattr
+ 0, // tp_compare
+ 0, // tp_repr
+ 0, // tp_as_number
+ 0, // tp_as_sequence
+ 0, // tp_as_mapping
+ 0, // tp_hash
+ 0, // tp_call
+ 0, // tp_str
+ 0, // tp_getattro
+ 0, // tp_setattro
+ 0, // tp_as_buffer
+ Py_TPFLAGS_DEFAULT, // tp_flags
+ "A Method Descriptor", // tp_doc
+ 0, // tp_traverse
+ 0, // tp_clear
+ 0, // tp_richcompare
+ 0, // tp_weaklistoffset
+ 0, // tp_iter
+ 0, // tp_iternext
+ method_descriptor::Methods, // tp_methods
+ 0, // tp_members
+ method_descriptor::Getters, // tp_getset
+ &descriptor::PyBaseDescriptor_Type, // tp_base
+};
+
+PyObject* PyMethodDescriptor_FromDescriptor(
+ const MethodDescriptor* method_descriptor) {
+ return descriptor::NewInternedDescriptor(
+ &PyMethodDescriptor_Type, method_descriptor, NULL);
+}
+
+const MethodDescriptor* PyMethodDescriptor_AsDescriptor(PyObject* obj) {
+ if (!PyObject_TypeCheck(obj, &PyMethodDescriptor_Type)) {
+ PyErr_SetString(PyExc_TypeError, "Not a MethodDescriptor");
+ return NULL;
+ }
+ return reinterpret_cast<const MethodDescriptor*>(
+ reinterpret_cast<PyBaseDescriptor*>(obj)->descriptor);
+}
+
// Add a enum values to a type dictionary.
static bool AddEnumValues(PyTypeObject *type,
const EnumDescriptor* enum_descriptor) {
@@ -1572,6 +1910,12 @@ bool InitDescriptor() {
if (PyType_Ready(&PyOneofDescriptor_Type) < 0)
return false;
+ if (PyType_Ready(&PyServiceDescriptor_Type) < 0)
+ return false;
+
+ if (PyType_Ready(&PyMethodDescriptor_Type) < 0)
+ return false;
+
if (!InitDescriptorMappingTypes())
return false;
diff --git a/python/google/protobuf/pyext/descriptor.h b/python/google/protobuf/pyext/descriptor.h
index eb99df18..f081df84 100644
--- a/python/google/protobuf/pyext/descriptor.h
+++ b/python/google/protobuf/pyext/descriptor.h
@@ -47,6 +47,8 @@ extern PyTypeObject PyEnumDescriptor_Type;
extern PyTypeObject PyEnumValueDescriptor_Type;
extern PyTypeObject PyFileDescriptor_Type;
extern PyTypeObject PyOneofDescriptor_Type;
+extern PyTypeObject PyServiceDescriptor_Type;
+extern PyTypeObject PyMethodDescriptor_Type;
// Wraps a Descriptor in a Python object.
// The C++ pointer is usually borrowed from the global DescriptorPool.
@@ -60,6 +62,10 @@ PyObject* PyEnumValueDescriptor_FromDescriptor(
PyObject* PyOneofDescriptor_FromDescriptor(const OneofDescriptor* descriptor);
PyObject* PyFileDescriptor_FromDescriptor(
const FileDescriptor* file_descriptor);
+PyObject* PyServiceDescriptor_FromDescriptor(
+ const ServiceDescriptor* descriptor);
+PyObject* PyMethodDescriptor_FromDescriptor(
+ const MethodDescriptor* descriptor);
// Alternate constructor of PyFileDescriptor, used when we already have a
// serialized FileDescriptorProto that can be cached.
@@ -74,6 +80,8 @@ const Descriptor* PyMessageDescriptor_AsDescriptor(PyObject* obj);
const FieldDescriptor* PyFieldDescriptor_AsDescriptor(PyObject* obj);
const EnumDescriptor* PyEnumDescriptor_AsDescriptor(PyObject* obj);
const FileDescriptor* PyFileDescriptor_AsDescriptor(PyObject* obj);
+const ServiceDescriptor* PyServiceDescriptor_AsDescriptor(PyObject* obj);
+const MethodDescriptor* PyMethodDescriptor_AsDescriptor(PyObject* obj);
// Returns the raw C++ pointer.
const void* PyDescriptor_AsVoidPtr(PyObject* obj);
diff --git a/python/google/protobuf/pyext/descriptor_containers.cc b/python/google/protobuf/pyext/descriptor_containers.cc
index e505d812..bc007f7e 100644
--- a/python/google/protobuf/pyext/descriptor_containers.cc
+++ b/python/google/protobuf/pyext/descriptor_containers.cc
@@ -608,6 +608,24 @@ static PyObject* GetItem(PyContainer* self, Py_ssize_t index) {
return _NewObj_ByIndex(self, index);
}
+static PyObject *
+SeqSubscript(PyContainer* self, PyObject* item) {
+ if (PyIndex_Check(item)) {
+ Py_ssize_t index;
+ index = PyNumber_AsSsize_t(item, PyExc_IndexError);
+ if (index == -1 && PyErr_Occurred())
+ return NULL;
+ return GetItem(self, index);
+ }
+ // Materialize the list and delegate the operation to it.
+ ScopedPyObjectPtr list(PyObject_CallFunctionObjArgs(
+ reinterpret_cast<PyObject*>(&PyList_Type), self, NULL));
+ if (list == NULL) {
+ return NULL;
+ }
+ return Py_TYPE(list.get())->tp_as_mapping->mp_subscript(list.get(), item);
+}
+
// Returns the position of the item in the sequence, of -1 if not found.
// This function never fails.
int Find(PyContainer* self, PyObject* item) {
@@ -703,14 +721,20 @@ static PyMethodDef SeqMethods[] = {
};
static PySequenceMethods SeqSequenceMethods = {
- (lenfunc)Length, // sq_length
- 0, // sq_concat
- 0, // sq_repeat
- (ssizeargfunc)GetItem, // sq_item
- 0, // sq_slice
- 0, // sq_ass_item
- 0, // sq_ass_slice
- (objobjproc)SeqContains, // sq_contains
+ (lenfunc)Length, // sq_length
+ 0, // sq_concat
+ 0, // sq_repeat
+ (ssizeargfunc)GetItem, // sq_item
+ 0, // sq_slice
+ 0, // sq_ass_item
+ 0, // sq_ass_slice
+ (objobjproc)SeqContains, // sq_contains
+};
+
+static PyMappingMethods SeqMappingMethods = {
+ (lenfunc)Length, // mp_length
+ (binaryfunc)SeqSubscript, // mp_subscript
+ 0, // mp_ass_subscript
};
PyTypeObject DescriptorSequence_Type = {
@@ -726,7 +750,7 @@ PyTypeObject DescriptorSequence_Type = {
(reprfunc)ContainerRepr, // tp_repr
0, // tp_as_number
&SeqSequenceMethods, // tp_as_sequence
- 0, // tp_as_mapping
+ &SeqMappingMethods, // tp_as_mapping
0, // tp_hash
0, // tp_call
0, // tp_str
@@ -933,55 +957,55 @@ static int Count(PyContainer* self) {
return GetDescriptor(self)->field_count();
}
-static ItemDescriptor GetByName(PyContainer* self, const string& name) {
+static const void* GetByName(PyContainer* self, const string& name) {
return GetDescriptor(self)->FindFieldByName(name);
}
-static ItemDescriptor GetByCamelcaseName(PyContainer* self,
+static const void* GetByCamelcaseName(PyContainer* self,
const string& name) {
return GetDescriptor(self)->FindFieldByCamelcaseName(name);
}
-static ItemDescriptor GetByNumber(PyContainer* self, int number) {
+static const void* GetByNumber(PyContainer* self, int number) {
return GetDescriptor(self)->FindFieldByNumber(number);
}
-static ItemDescriptor GetByIndex(PyContainer* self, int index) {
+static const void* GetByIndex(PyContainer* self, int index) {
return GetDescriptor(self)->field(index);
}
-static PyObject* NewObjectFromItem(ItemDescriptor item) {
- return PyFieldDescriptor_FromDescriptor(item);
+static PyObject* NewObjectFromItem(const void* item) {
+ return PyFieldDescriptor_FromDescriptor(static_cast<ItemDescriptor>(item));
}
-static const string& GetItemName(ItemDescriptor item) {
- return item->name();
+static const string& GetItemName(const void* item) {
+ return static_cast<ItemDescriptor>(item)->name();
}
-static const string& GetItemCamelcaseName(ItemDescriptor item) {
- return item->camelcase_name();
+static const string& GetItemCamelcaseName(const void* item) {
+ return static_cast<ItemDescriptor>(item)->camelcase_name();
}
-static int GetItemNumber(ItemDescriptor item) {
- return item->number();
+static int GetItemNumber(const void* item) {
+ return static_cast<ItemDescriptor>(item)->number();
}
-static int GetItemIndex(ItemDescriptor item) {
- return item->index();
+static int GetItemIndex(const void* item) {
+ return static_cast<ItemDescriptor>(item)->index();
}
static DescriptorContainerDef ContainerDef = {
"MessageFields",
- (CountMethod)Count,
- (GetByIndexMethod)GetByIndex,
- (GetByNameMethod)GetByName,
- (GetByCamelcaseNameMethod)GetByCamelcaseName,
- (GetByNumberMethod)GetByNumber,
- (NewObjectFromItemMethod)NewObjectFromItem,
- (GetItemNameMethod)GetItemName,
- (GetItemCamelcaseNameMethod)GetItemCamelcaseName,
- (GetItemNumberMethod)GetItemNumber,
- (GetItemIndexMethod)GetItemIndex,
+ Count,
+ GetByIndex,
+ GetByName,
+ GetByCamelcaseName,
+ GetByNumber,
+ NewObjectFromItem,
+ GetItemName,
+ GetItemCamelcaseName,
+ GetItemNumber,
+ GetItemIndex,
};
} // namespace fields
@@ -1011,38 +1035,38 @@ static int Count(PyContainer* self) {
return GetDescriptor(self)->nested_type_count();
}
-static ItemDescriptor GetByName(PyContainer* self, const string& name) {
+static const void* GetByName(PyContainer* self, const string& name) {
return GetDescriptor(self)->FindNestedTypeByName(name);
}
-static ItemDescriptor GetByIndex(PyContainer* self, int index) {
+static const void* GetByIndex(PyContainer* self, int index) {
return GetDescriptor(self)->nested_type(index);
}
-static PyObject* NewObjectFromItem(ItemDescriptor item) {
- return PyMessageDescriptor_FromDescriptor(item);
+static PyObject* NewObjectFromItem(const void* item) {
+ return PyMessageDescriptor_FromDescriptor(static_cast<ItemDescriptor>(item));
}
-static const string& GetItemName(ItemDescriptor item) {
- return item->name();
+static const string& GetItemName(const void* item) {
+ return static_cast<ItemDescriptor>(item)->name();
}
-static int GetItemIndex(ItemDescriptor item) {
- return item->index();
+static int GetItemIndex(const void* item) {
+ return static_cast<ItemDescriptor>(item)->index();
}
static DescriptorContainerDef ContainerDef = {
"MessageNestedTypes",
- (CountMethod)Count,
- (GetByIndexMethod)GetByIndex,
- (GetByNameMethod)GetByName,
- (GetByCamelcaseNameMethod)NULL,
- (GetByNumberMethod)NULL,
- (NewObjectFromItemMethod)NewObjectFromItem,
- (GetItemNameMethod)GetItemName,
- (GetItemCamelcaseNameMethod)NULL,
- (GetItemNumberMethod)NULL,
- (GetItemIndexMethod)GetItemIndex,
+ Count,
+ GetByIndex,
+ GetByName,
+ NULL,
+ NULL,
+ NewObjectFromItem,
+ GetItemName,
+ NULL,
+ NULL,
+ GetItemIndex,
};
} // namespace nested_types
@@ -1063,38 +1087,38 @@ static int Count(PyContainer* self) {
return GetDescriptor(self)->enum_type_count();
}
-static ItemDescriptor GetByName(PyContainer* self, const string& name) {
+static const void* GetByName(PyContainer* self, const string& name) {
return GetDescriptor(self)->FindEnumTypeByName(name);
}
-static ItemDescriptor GetByIndex(PyContainer* self, int index) {
+static const void* GetByIndex(PyContainer* self, int index) {
return GetDescriptor(self)->enum_type(index);
}
-static PyObject* NewObjectFromItem(ItemDescriptor item) {
- return PyEnumDescriptor_FromDescriptor(item);
+static PyObject* NewObjectFromItem(const void* item) {
+ return PyEnumDescriptor_FromDescriptor(static_cast<ItemDescriptor>(item));
}
-static const string& GetItemName(ItemDescriptor item) {
- return item->name();
+static const string& GetItemName(const void* item) {
+ return static_cast<ItemDescriptor>(item)->name();
}
-static int GetItemIndex(ItemDescriptor item) {
- return item->index();
+static int GetItemIndex(const void* item) {
+ return static_cast<ItemDescriptor>(item)->index();
}
static DescriptorContainerDef ContainerDef = {
"MessageNestedEnums",
- (CountMethod)Count,
- (GetByIndexMethod)GetByIndex,
- (GetByNameMethod)GetByName,
- (GetByCamelcaseNameMethod)NULL,
- (GetByNumberMethod)NULL,
- (NewObjectFromItemMethod)NewObjectFromItem,
- (GetItemNameMethod)GetItemName,
- (GetItemCamelcaseNameMethod)NULL,
- (GetItemNumberMethod)NULL,
- (GetItemIndexMethod)GetItemIndex,
+ Count,
+ GetByIndex,
+ GetByName,
+ NULL,
+ NULL,
+ NewObjectFromItem,
+ GetItemName,
+ NULL,
+ NULL,
+ GetItemIndex,
};
} // namespace enums
@@ -1126,11 +1150,11 @@ static int Count(PyContainer* self) {
return count;
}
-static ItemDescriptor GetByName(PyContainer* self, const string& name) {
+static const void* GetByName(PyContainer* self, const string& name) {
return GetDescriptor(self)->FindEnumValueByName(name);
}
-static ItemDescriptor GetByIndex(PyContainer* self, int index) {
+static const void* GetByIndex(PyContainer* self, int index) {
// This is not optimal, but the number of enums *types* in a given message
// is small. This function is only used when iterating over the mapping.
const EnumDescriptor* enum_type = NULL;
@@ -1149,26 +1173,27 @@ static ItemDescriptor GetByIndex(PyContainer* self, int index) {
return enum_type->value(index);
}
-static PyObject* NewObjectFromItem(ItemDescriptor item) {
- return PyEnumValueDescriptor_FromDescriptor(item);
+static PyObject* NewObjectFromItem(const void* item) {
+ return PyEnumValueDescriptor_FromDescriptor(
+ static_cast<ItemDescriptor>(item));
}
-static const string& GetItemName(ItemDescriptor item) {
- return item->name();
+static const string& GetItemName(const void* item) {
+ return static_cast<ItemDescriptor>(item)->name();
}
static DescriptorContainerDef ContainerDef = {
"MessageEnumValues",
- (CountMethod)Count,
- (GetByIndexMethod)GetByIndex,
- (GetByNameMethod)GetByName,
- (GetByCamelcaseNameMethod)NULL,
- (GetByNumberMethod)NULL,
- (NewObjectFromItemMethod)NewObjectFromItem,
- (GetItemNameMethod)GetItemName,
- (GetItemCamelcaseNameMethod)NULL,
- (GetItemNumberMethod)NULL,
- (GetItemIndexMethod)NULL,
+ Count,
+ GetByIndex,
+ GetByName,
+ NULL,
+ NULL,
+ NewObjectFromItem,
+ GetItemName,
+ NULL,
+ NULL,
+ NULL,
};
} // namespace enumvalues
@@ -1185,38 +1210,38 @@ static int Count(PyContainer* self) {
return GetDescriptor(self)->extension_count();
}
-static ItemDescriptor GetByName(PyContainer* self, const string& name) {
+static const void* GetByName(PyContainer* self, const string& name) {
return GetDescriptor(self)->FindExtensionByName(name);
}
-static ItemDescriptor GetByIndex(PyContainer* self, int index) {
+static const void* GetByIndex(PyContainer* self, int index) {
return GetDescriptor(self)->extension(index);
}
-static PyObject* NewObjectFromItem(ItemDescriptor item) {
- return PyFieldDescriptor_FromDescriptor(item);
+static PyObject* NewObjectFromItem(const void* item) {
+ return PyFieldDescriptor_FromDescriptor(static_cast<ItemDescriptor>(item));
}
-static const string& GetItemName(ItemDescriptor item) {
- return item->name();
+static const string& GetItemName(const void* item) {
+ return static_cast<ItemDescriptor>(item)->name();
}
-static int GetItemIndex(ItemDescriptor item) {
- return item->index();
+static int GetItemIndex(const void* item) {
+ return static_cast<ItemDescriptor>(item)->index();
}
static DescriptorContainerDef ContainerDef = {
"MessageExtensions",
- (CountMethod)Count,
- (GetByIndexMethod)GetByIndex,
- (GetByNameMethod)GetByName,
- (GetByCamelcaseNameMethod)NULL,
- (GetByNumberMethod)NULL,
- (NewObjectFromItemMethod)NewObjectFromItem,
- (GetItemNameMethod)GetItemName,
- (GetItemCamelcaseNameMethod)NULL,
- (GetItemNumberMethod)NULL,
- (GetItemIndexMethod)GetItemIndex,
+ Count,
+ GetByIndex,
+ GetByName,
+ NULL,
+ NULL,
+ NewObjectFromItem,
+ GetItemName,
+ NULL,
+ NULL,
+ GetItemIndex,
};
} // namespace extensions
@@ -1237,38 +1262,38 @@ static int Count(PyContainer* self) {
return GetDescriptor(self)->oneof_decl_count();
}
-static ItemDescriptor GetByName(PyContainer* self, const string& name) {
+static const void* GetByName(PyContainer* self, const string& name) {
return GetDescriptor(self)->FindOneofByName(name);
}
-static ItemDescriptor GetByIndex(PyContainer* self, int index) {
+static const void* GetByIndex(PyContainer* self, int index) {
return GetDescriptor(self)->oneof_decl(index);
}
-static PyObject* NewObjectFromItem(ItemDescriptor item) {
- return PyOneofDescriptor_FromDescriptor(item);
+static PyObject* NewObjectFromItem(const void* item) {
+ return PyOneofDescriptor_FromDescriptor(static_cast<ItemDescriptor>(item));
}
-static const string& GetItemName(ItemDescriptor item) {
- return item->name();
+static const string& GetItemName(const void* item) {
+ return static_cast<ItemDescriptor>(item)->name();
}
-static int GetItemIndex(ItemDescriptor item) {
- return item->index();
+static int GetItemIndex(const void* item) {
+ return static_cast<ItemDescriptor>(item)->index();
}
static DescriptorContainerDef ContainerDef = {
"MessageOneofs",
- (CountMethod)Count,
- (GetByIndexMethod)GetByIndex,
- (GetByNameMethod)GetByName,
- (GetByCamelcaseNameMethod)NULL,
- (GetByNumberMethod)NULL,
- (NewObjectFromItemMethod)NewObjectFromItem,
- (GetItemNameMethod)GetItemName,
- (GetItemCamelcaseNameMethod)NULL,
- (GetItemNumberMethod)NULL,
- (GetItemIndexMethod)GetItemIndex,
+ Count,
+ GetByIndex,
+ GetByName,
+ NULL,
+ NULL,
+ NewObjectFromItem,
+ GetItemName,
+ NULL,
+ NULL,
+ GetItemIndex,
};
} // namespace oneofs
@@ -1299,46 +1324,47 @@ static int Count(PyContainer* self) {
return GetDescriptor(self)->value_count();
}
-static ItemDescriptor GetByIndex(PyContainer* self, int index) {
+static const void* GetByIndex(PyContainer* self, int index) {
return GetDescriptor(self)->value(index);
}
-static ItemDescriptor GetByName(PyContainer* self, const string& name) {
+static const void* GetByName(PyContainer* self, const string& name) {
return GetDescriptor(self)->FindValueByName(name);
}
-static ItemDescriptor GetByNumber(PyContainer* self, int number) {
+static const void* GetByNumber(PyContainer* self, int number) {
return GetDescriptor(self)->FindValueByNumber(number);
}
-static PyObject* NewObjectFromItem(ItemDescriptor item) {
- return PyEnumValueDescriptor_FromDescriptor(item);
+static PyObject* NewObjectFromItem(const void* item) {
+ return PyEnumValueDescriptor_FromDescriptor(
+ static_cast<ItemDescriptor>(item));
}
-static const string& GetItemName(ItemDescriptor item) {
- return item->name();
+static const string& GetItemName(const void* item) {
+ return static_cast<ItemDescriptor>(item)->name();
}
-static int GetItemNumber(ItemDescriptor item) {
- return item->number();
+static int GetItemNumber(const void* item) {
+ return static_cast<ItemDescriptor>(item)->number();
}
-static int GetItemIndex(ItemDescriptor item) {
- return item->index();
+static int GetItemIndex(const void* item) {
+ return static_cast<ItemDescriptor>(item)->index();
}
static DescriptorContainerDef ContainerDef = {
"EnumValues",
- (CountMethod)Count,
- (GetByIndexMethod)GetByIndex,
- (GetByNameMethod)GetByName,
- (GetByCamelcaseNameMethod)NULL,
- (GetByNumberMethod)GetByNumber,
- (NewObjectFromItemMethod)NewObjectFromItem,
- (GetItemNameMethod)GetItemName,
- (GetItemCamelcaseNameMethod)NULL,
- (GetItemNumberMethod)GetItemNumber,
- (GetItemIndexMethod)GetItemIndex,
+ Count,
+ GetByIndex,
+ GetByName,
+ NULL,
+ GetByNumber,
+ NewObjectFromItem,
+ GetItemName,
+ NULL,
+ GetItemNumber,
+ GetItemIndex,
};
} // namespace enumvalues
@@ -1373,30 +1399,30 @@ static int Count(PyContainer* self) {
return GetDescriptor(self)->field_count();
}
-static ItemDescriptor GetByIndex(PyContainer* self, int index) {
+static const void* GetByIndex(PyContainer* self, int index) {
return GetDescriptor(self)->field(index);
}
-static PyObject* NewObjectFromItem(ItemDescriptor item) {
- return PyFieldDescriptor_FromDescriptor(item);
+static PyObject* NewObjectFromItem(const void* item) {
+ return PyFieldDescriptor_FromDescriptor(static_cast<ItemDescriptor>(item));
}
-static int GetItemIndex(ItemDescriptor item) {
- return item->index_in_oneof();
+static int GetItemIndex(const void* item) {
+ return static_cast<ItemDescriptor>(item)->index_in_oneof();
}
static DescriptorContainerDef ContainerDef = {
"OneofFields",
- (CountMethod)Count,
- (GetByIndexMethod)GetByIndex,
- (GetByNameMethod)NULL,
- (GetByCamelcaseNameMethod)NULL,
- (GetByNumberMethod)NULL,
- (NewObjectFromItemMethod)NewObjectFromItem,
- (GetItemNameMethod)NULL,
- (GetItemCamelcaseNameMethod)NULL,
- (GetItemNumberMethod)NULL,
- (GetItemIndexMethod)GetItemIndex,
+ Count,
+ GetByIndex,
+ NULL,
+ NULL,
+ NULL,
+ NewObjectFromItem,
+ NULL,
+ NULL,
+ NULL,
+ GetItemIndex,
};
} // namespace fields
@@ -1407,6 +1433,68 @@ PyObject* NewOneofFieldsSeq(ParentDescriptor descriptor) {
} // namespace oneof_descriptor
+namespace service_descriptor {
+
+typedef const ServiceDescriptor* ParentDescriptor;
+
+static ParentDescriptor GetDescriptor(PyContainer* self) {
+ return reinterpret_cast<ParentDescriptor>(self->descriptor);
+}
+
+namespace methods {
+
+typedef const MethodDescriptor* ItemDescriptor;
+
+static int Count(PyContainer* self) {
+ return GetDescriptor(self)->method_count();
+}
+
+static const void* GetByName(PyContainer* self, const string& name) {
+ return GetDescriptor(self)->FindMethodByName(name);
+}
+
+static const void* GetByIndex(PyContainer* self, int index) {
+ return GetDescriptor(self)->method(index);
+}
+
+static PyObject* NewObjectFromItem(const void* item) {
+ return PyMethodDescriptor_FromDescriptor(static_cast<ItemDescriptor>(item));
+}
+
+static const string& GetItemName(const void* item) {
+ return static_cast<ItemDescriptor>(item)->name();
+}
+
+static int GetItemIndex(const void* item) {
+ return static_cast<ItemDescriptor>(item)->index();
+}
+
+static DescriptorContainerDef ContainerDef = {
+ "ServiceMethods",
+ Count,
+ GetByIndex,
+ GetByName,
+ NULL,
+ NULL,
+ NewObjectFromItem,
+ GetItemName,
+ NULL,
+ NULL,
+ GetItemIndex,
+};
+
+} // namespace methods
+
+PyObject* NewServiceMethodsSeq(ParentDescriptor descriptor) {
+ return descriptor::NewSequence(&methods::ContainerDef, descriptor);
+}
+
+PyObject* NewServiceMethodsByName(ParentDescriptor descriptor) {
+ return descriptor::NewMappingByName(&methods::ContainerDef, descriptor);
+}
+
+} // namespace service_descriptor
+
namespace file_descriptor {
typedef const FileDescriptor* ParentDescriptor;
@@ -1423,43 +1511,43 @@ static int Count(PyContainer* self) {
return GetDescriptor(self)->message_type_count();
}
-static ItemDescriptor GetByName(PyContainer* self, const string& name) {
+static const void* GetByName(PyContainer* self, const string& name) {
return GetDescriptor(self)->FindMessageTypeByName(name);
}
-static ItemDescriptor GetByIndex(PyContainer* self, int index) {
+static const void* GetByIndex(PyContainer* self, int index) {
return GetDescriptor(self)->message_type(index);
}
-static PyObject* NewObjectFromItem(ItemDescriptor item) {
- return PyMessageDescriptor_FromDescriptor(item);
+static PyObject* NewObjectFromItem(const void* item) {
+ return PyMessageDescriptor_FromDescriptor(static_cast<ItemDescriptor>(item));
}
-static const string& GetItemName(ItemDescriptor item) {
- return item->name();
+static const string& GetItemName(const void* item) {
+ return static_cast<ItemDescriptor>(item)->name();
}
-static int GetItemIndex(ItemDescriptor item) {
- return item->index();
+static int GetItemIndex(const void* item) {
+ return static_cast<ItemDescriptor>(item)->index();
}
static DescriptorContainerDef ContainerDef = {
"FileMessages",
- (CountMethod)Count,
- (GetByIndexMethod)GetByIndex,
- (GetByNameMethod)GetByName,
- (GetByCamelcaseNameMethod)NULL,
- (GetByNumberMethod)NULL,
- (NewObjectFromItemMethod)NewObjectFromItem,
- (GetItemNameMethod)GetItemName,
- (GetItemCamelcaseNameMethod)NULL,
- (GetItemNumberMethod)NULL,
- (GetItemIndexMethod)GetItemIndex,
+ Count,
+ GetByIndex,
+ GetByName,
+ NULL,
+ NULL,
+ NewObjectFromItem,
+ GetItemName,
+ NULL,
+ NULL,
+ GetItemIndex,
};
} // namespace messages
-PyObject* NewFileMessageTypesByName(const FileDescriptor* descriptor) {
+PyObject* NewFileMessageTypesByName(ParentDescriptor descriptor) {
return descriptor::NewMappingByName(&messages::ContainerDef, descriptor);
}
@@ -1471,43 +1559,43 @@ static int Count(PyContainer* self) {
return GetDescriptor(self)->enum_type_count();
}
-static ItemDescriptor GetByName(PyContainer* self, const string& name) {
+static const void* GetByName(PyContainer* self, const string& name) {
return GetDescriptor(self)->FindEnumTypeByName(name);
}
-static ItemDescriptor GetByIndex(PyContainer* self, int index) {
+static const void* GetByIndex(PyContainer* self, int index) {
return GetDescriptor(self)->enum_type(index);
}
-static PyObject* NewObjectFromItem(ItemDescriptor item) {
- return PyEnumDescriptor_FromDescriptor(item);
+static PyObject* NewObjectFromItem(const void* item) {
+ return PyEnumDescriptor_FromDescriptor(static_cast<ItemDescriptor>(item));
}
-static const string& GetItemName(ItemDescriptor item) {
- return item->name();
+static const string& GetItemName(const void* item) {
+ return static_cast<ItemDescriptor>(item)->name();
}
-static int GetItemIndex(ItemDescriptor item) {
- return item->index();
+static int GetItemIndex(const void* item) {
+ return static_cast<ItemDescriptor>(item)->index();
}
static DescriptorContainerDef ContainerDef = {
"FileEnums",
- (CountMethod)Count,
- (GetByIndexMethod)GetByIndex,
- (GetByNameMethod)GetByName,
- (GetByCamelcaseNameMethod)NULL,
- (GetByNumberMethod)NULL,
- (NewObjectFromItemMethod)NewObjectFromItem,
- (GetItemNameMethod)GetItemName,
- (GetItemCamelcaseNameMethod)NULL,
- (GetItemNumberMethod)NULL,
- (GetItemIndexMethod)GetItemIndex,
+ Count,
+ GetByIndex,
+ GetByName,
+ NULL,
+ NULL,
+ NewObjectFromItem,
+ GetItemName,
+ NULL,
+ NULL,
+ GetItemIndex,
};
} // namespace enums
-PyObject* NewFileEnumTypesByName(const FileDescriptor* descriptor) {
+PyObject* NewFileEnumTypesByName(ParentDescriptor descriptor) {
return descriptor::NewMappingByName(&enums::ContainerDef, descriptor);
}
@@ -1519,46 +1607,94 @@ static int Count(PyContainer* self) {
return GetDescriptor(self)->extension_count();
}
-static ItemDescriptor GetByName(PyContainer* self, const string& name) {
+static const void* GetByName(PyContainer* self, const string& name) {
return GetDescriptor(self)->FindExtensionByName(name);
}
-static ItemDescriptor GetByIndex(PyContainer* self, int index) {
+static const void* GetByIndex(PyContainer* self, int index) {
return GetDescriptor(self)->extension(index);
}
-static PyObject* NewObjectFromItem(ItemDescriptor item) {
- return PyFieldDescriptor_FromDescriptor(item);
+static PyObject* NewObjectFromItem(const void* item) {
+ return PyFieldDescriptor_FromDescriptor(static_cast<ItemDescriptor>(item));
}
-static const string& GetItemName(ItemDescriptor item) {
- return item->name();
+static const string& GetItemName(const void* item) {
+ return static_cast<ItemDescriptor>(item)->name();
}
-static int GetItemIndex(ItemDescriptor item) {
- return item->index();
+static int GetItemIndex(const void* item) {
+ return static_cast<ItemDescriptor>(item)->index();
}
static DescriptorContainerDef ContainerDef = {
"FileExtensions",
- (CountMethod)Count,
- (GetByIndexMethod)GetByIndex,
- (GetByNameMethod)GetByName,
- (GetByCamelcaseNameMethod)NULL,
- (GetByNumberMethod)NULL,
- (NewObjectFromItemMethod)NewObjectFromItem,
- (GetItemNameMethod)GetItemName,
- (GetItemCamelcaseNameMethod)NULL,
- (GetItemNumberMethod)NULL,
- (GetItemIndexMethod)GetItemIndex,
+ Count,
+ GetByIndex,
+ GetByName,
+ NULL,
+ NULL,
+ NewObjectFromItem,
+ GetItemName,
+ NULL,
+ NULL,
+ GetItemIndex,
};
} // namespace extensions
-PyObject* NewFileExtensionsByName(const FileDescriptor* descriptor) {
+PyObject* NewFileExtensionsByName(ParentDescriptor descriptor) {
return descriptor::NewMappingByName(&extensions::ContainerDef, descriptor);
}
+namespace services {
+
+typedef const ServiceDescriptor* ItemDescriptor;
+
+static int Count(PyContainer* self) {
+ return GetDescriptor(self)->service_count();
+}
+
+static const void* GetByName(PyContainer* self, const string& name) {
+ return GetDescriptor(self)->FindServiceByName(name);
+}
+
+static const void* GetByIndex(PyContainer* self, int index) {
+ return GetDescriptor(self)->service(index);
+}
+
+static PyObject* NewObjectFromItem(const void* item) {
+ return PyServiceDescriptor_FromDescriptor(static_cast<ItemDescriptor>(item));
+}
+
+static const string& GetItemName(const void* item) {
+ return static_cast<ItemDescriptor>(item)->name();
+}
+
+static int GetItemIndex(const void* item) {
+ return static_cast<ItemDescriptor>(item)->index();
+}
+
+static DescriptorContainerDef ContainerDef = {
+ "FileServices",
+ Count,
+ GetByIndex,
+ GetByName,
+ NULL,
+ NULL,
+ NewObjectFromItem,
+ GetItemName,
+ NULL,
+ NULL,
+ GetItemIndex,
+};
+
+} // namespace services
+
+PyObject* NewFileServicesByName(const FileDescriptor* descriptor) {
+ return descriptor::NewMappingByName(&services::ContainerDef, descriptor);
+}
+
namespace dependencies {
typedef const FileDescriptor* ItemDescriptor;
@@ -1567,26 +1703,26 @@ static int Count(PyContainer* self) {
return GetDescriptor(self)->dependency_count();
}
-static ItemDescriptor GetByIndex(PyContainer* self, int index) {
+static const void* GetByIndex(PyContainer* self, int index) {
return GetDescriptor(self)->dependency(index);
}
-static PyObject* NewObjectFromItem(ItemDescriptor item) {
- return PyFileDescriptor_FromDescriptor(item);
+static PyObject* NewObjectFromItem(const void* item) {
+ return PyFileDescriptor_FromDescriptor(static_cast<ItemDescriptor>(item));
}
static DescriptorContainerDef ContainerDef = {
"FileDependencies",
- (CountMethod)Count,
- (GetByIndexMethod)GetByIndex,
- (GetByNameMethod)NULL,
- (GetByCamelcaseNameMethod)NULL,
- (GetByNumberMethod)NULL,
- (NewObjectFromItemMethod)NewObjectFromItem,
- (GetItemNameMethod)NULL,
- (GetItemCamelcaseNameMethod)NULL,
- (GetItemNumberMethod)NULL,
- (GetItemIndexMethod)NULL,
+ Count,
+ GetByIndex,
+ NULL,
+ NULL,
+ NULL,
+ NewObjectFromItem,
+ NULL,
+ NULL,
+ NULL,
+ NULL,
};
} // namespace dependencies
@@ -1603,26 +1739,26 @@ static int Count(PyContainer* self) {
return GetDescriptor(self)->public_dependency_count();
}
-static ItemDescriptor GetByIndex(PyContainer* self, int index) {
+static const void* GetByIndex(PyContainer* self, int index) {
return GetDescriptor(self)->public_dependency(index);
}
-static PyObject* NewObjectFromItem(ItemDescriptor item) {
- return PyFileDescriptor_FromDescriptor(item);
+static PyObject* NewObjectFromItem(const void* item) {
+ return PyFileDescriptor_FromDescriptor(static_cast<ItemDescriptor>(item));
}
static DescriptorContainerDef ContainerDef = {
"FilePublicDependencies",
- (CountMethod)Count,
- (GetByIndexMethod)GetByIndex,
- (GetByNameMethod)NULL,
- (GetByCamelcaseNameMethod)NULL,
- (GetByNumberMethod)NULL,
- (NewObjectFromItemMethod)NewObjectFromItem,
- (GetItemNameMethod)NULL,
- (GetItemCamelcaseNameMethod)NULL,
- (GetItemNumberMethod)NULL,
- (GetItemIndexMethod)NULL,
+ Count,
+ GetByIndex,
+ NULL,
+ NULL,
+ NULL,
+ NewObjectFromItem,
+ NULL,
+ NULL,
+ NULL,
+ NULL,
};
} // namespace public_dependencies
diff --git a/python/google/protobuf/pyext/descriptor_containers.h b/python/google/protobuf/pyext/descriptor_containers.h
index ce40747d..83de07b6 100644
--- a/python/google/protobuf/pyext/descriptor_containers.h
+++ b/python/google/protobuf/pyext/descriptor_containers.h
@@ -43,6 +43,7 @@ class Descriptor;
class FileDescriptor;
class EnumDescriptor;
class OneofDescriptor;
+class ServiceDescriptor;
namespace python {
@@ -89,10 +90,17 @@ PyObject* NewFileEnumTypesByName(const FileDescriptor* descriptor);
PyObject* NewFileExtensionsByName(const FileDescriptor* descriptor);
+PyObject* NewFileServicesByName(const FileDescriptor* descriptor);
+
PyObject* NewFileDependencies(const FileDescriptor* descriptor);
PyObject* NewFilePublicDependencies(const FileDescriptor* descriptor);
} // namespace file_descriptor
+namespace service_descriptor {
+PyObject* NewServiceMethodsSeq(const ServiceDescriptor* descriptor);
+PyObject* NewServiceMethodsByName(const ServiceDescriptor* descriptor);
+} // namespace service_descriptor
+
} // namespace python
} // namespace protobuf
diff --git a/python/google/protobuf/pyext/descriptor_database.cc b/python/google/protobuf/pyext/descriptor_database.cc
index 514722b4..daa40cc7 100644
--- a/python/google/protobuf/pyext/descriptor_database.cc
+++ b/python/google/protobuf/pyext/descriptor_database.cc
@@ -64,6 +64,9 @@ static bool GetFileDescriptorProto(PyObject* py_descriptor,
}
return false;
}
+ if (py_descriptor == Py_None) {
+ return false;
+ }
const Descriptor* filedescriptor_descriptor =
FileDescriptorProto::default_instance().GetDescriptor();
CMessage* message = reinterpret_cast<CMessage*>(py_descriptor);
diff --git a/python/google/protobuf/pyext/descriptor_pool.cc b/python/google/protobuf/pyext/descriptor_pool.cc
index 0bc76bc9..95882aeb 100644
--- a/python/google/protobuf/pyext/descriptor_pool.cc
+++ b/python/google/protobuf/pyext/descriptor_pool.cc
@@ -33,12 +33,13 @@
#include <Python.h>
#include <google/protobuf/descriptor.pb.h>
-#include <google/protobuf/dynamic_message.h>
#include <google/protobuf/pyext/descriptor.h>
#include <google/protobuf/pyext/descriptor_database.h>
#include <google/protobuf/pyext/descriptor_pool.h>
#include <google/protobuf/pyext/message.h>
+#include <google/protobuf/pyext/message_factory.h>
#include <google/protobuf/pyext/scoped_pyobject_ptr.h>
+#include <google/protobuf/stubs/hash.h>
#if PY_MAJOR_VERSION >= 3
#define PyString_FromStringAndSize PyUnicode_FromStringAndSize
@@ -73,18 +74,16 @@ static PyDescriptorPool* _CreateDescriptorPool() {
cpool->underlay = NULL;
cpool->database = NULL;
- DynamicMessageFactory* message_factory = new DynamicMessageFactory();
- // This option might be the default some day.
- message_factory->SetDelegateToGeneratedFactory(true);
- cpool->message_factory = message_factory;
-
- // TODO(amauryfa): Rewrite the SymbolDatabase in C so that it uses the same
- // storage.
- cpool->classes_by_descriptor =
- new PyDescriptorPool::ClassesByMessageMap();
cpool->descriptor_options =
new hash_map<const void*, PyObject *>();
+ cpool->py_message_factory = message_factory::NewMessageFactory(
+ &PyMessageFactory_Type, cpool);
+ if (cpool->py_message_factory == NULL) {
+ Py_DECREF(cpool);
+ return NULL;
+ }
+
return cpool;
}
@@ -150,27 +149,22 @@ static PyObject* New(PyTypeObject* type,
PyDescriptorPool_NewWithDatabase(database));
}
-static void Dealloc(PyDescriptorPool* self) {
- typedef PyDescriptorPool::ClassesByMessageMap::iterator iterator;
+static void Dealloc(PyObject* pself) {
+ PyDescriptorPool* self = reinterpret_cast<PyDescriptorPool*>(pself);
descriptor_pool_map.erase(self->pool);
- for (iterator it = self->classes_by_descriptor->begin();
- it != self->classes_by_descriptor->end(); ++it) {
- Py_DECREF(it->second);
- }
- delete self->classes_by_descriptor;
+ Py_CLEAR(self->py_message_factory);
for (hash_map<const void*, PyObject*>::iterator it =
self->descriptor_options->begin();
it != self->descriptor_options->end(); ++it) {
Py_DECREF(it->second);
}
delete self->descriptor_options;
- delete self->message_factory;
delete self->database;
delete self->pool;
Py_TYPE(self)->tp_free(reinterpret_cast<PyObject*>(self));
}
-PyObject* FindMessageByName(PyDescriptorPool* self, PyObject* arg) {
+static PyObject* FindMessageByName(PyObject* self, PyObject* arg) {
Py_ssize_t name_size;
char* name;
if (PyString_AsStringAndSize(arg, &name, &name_size) < 0) {
@@ -178,7 +172,8 @@ PyObject* FindMessageByName(PyDescriptorPool* self, PyObject* arg) {
}
const Descriptor* message_descriptor =
- self->pool->FindMessageTypeByName(string(name, name_size));
+ reinterpret_cast<PyDescriptorPool*>(self)->pool->FindMessageTypeByName(
+ string(name, name_size));
if (message_descriptor == NULL) {
PyErr_Format(PyExc_KeyError, "Couldn't find message %.200s", name);
@@ -188,37 +183,10 @@ PyObject* FindMessageByName(PyDescriptorPool* self, PyObject* arg) {
return PyMessageDescriptor_FromDescriptor(message_descriptor);
}
-// Add a message class to our database.
-int RegisterMessageClass(PyDescriptorPool* self,
- const Descriptor *message_descriptor,
- PyObject *message_class) {
- Py_INCREF(message_class);
- typedef PyDescriptorPool::ClassesByMessageMap::iterator iterator;
- std::pair<iterator, bool> ret = self->classes_by_descriptor->insert(
- std::make_pair(message_descriptor, message_class));
- if (!ret.second) {
- // Update case: DECREF the previous value.
- Py_DECREF(ret.first->second);
- ret.first->second = message_class;
- }
- return 0;
-}
-// Retrieve the message class added to our database.
-PyObject *GetMessageClass(PyDescriptorPool* self,
- const Descriptor *message_descriptor) {
- typedef PyDescriptorPool::ClassesByMessageMap::iterator iterator;
- iterator ret = self->classes_by_descriptor->find(message_descriptor);
- if (ret == self->classes_by_descriptor->end()) {
- PyErr_Format(PyExc_TypeError, "No message class registered for '%s'",
- message_descriptor->full_name().c_str());
- return NULL;
- } else {
- return ret->second;
- }
-}
-PyObject* FindFileByName(PyDescriptorPool* self, PyObject* arg) {
+
+static PyObject* FindFileByName(PyObject* self, PyObject* arg) {
Py_ssize_t name_size;
char* name;
if (PyString_AsStringAndSize(arg, &name, &name_size) < 0) {
@@ -226,13 +194,12 @@ PyObject* FindFileByName(PyDescriptorPool* self, PyObject* arg) {
}
const FileDescriptor* file_descriptor =
- self->pool->FindFileByName(string(name, name_size));
+ reinterpret_cast<PyDescriptorPool*>(self)->pool->FindFileByName(
+ string(name, name_size));
if (file_descriptor == NULL) {
- PyErr_Format(PyExc_KeyError, "Couldn't find file %.200s",
- name);
+ PyErr_Format(PyExc_KeyError, "Couldn't find file %.200s", name);
return NULL;
}
-
return PyFileDescriptor_FromDescriptor(file_descriptor);
}
@@ -254,6 +221,10 @@ PyObject* FindFieldByName(PyDescriptorPool* self, PyObject* arg) {
return PyFieldDescriptor_FromDescriptor(field_descriptor);
}
+static PyObject* FindFieldByNameMethod(PyObject* self, PyObject* arg) {
+ return FindFieldByName(reinterpret_cast<PyDescriptorPool*>(self), arg);
+}
+
PyObject* FindExtensionByName(PyDescriptorPool* self, PyObject* arg) {
Py_ssize_t name_size;
char* name;
@@ -271,6 +242,10 @@ PyObject* FindExtensionByName(PyDescriptorPool* self, PyObject* arg) {
return PyFieldDescriptor_FromDescriptor(field_descriptor);
}
+static PyObject* FindExtensionByNameMethod(PyObject* self, PyObject* arg) {
+ return FindExtensionByName(reinterpret_cast<PyDescriptorPool*>(self), arg);
+}
+
PyObject* FindEnumTypeByName(PyDescriptorPool* self, PyObject* arg) {
Py_ssize_t name_size;
char* name;
@@ -288,6 +263,10 @@ PyObject* FindEnumTypeByName(PyDescriptorPool* self, PyObject* arg) {
return PyEnumDescriptor_FromDescriptor(enum_descriptor);
}
+static PyObject* FindEnumTypeByNameMethod(PyObject* self, PyObject* arg) {
+ return FindEnumTypeByName(reinterpret_cast<PyDescriptorPool*>(self), arg);
+}
+
PyObject* FindOneofByName(PyDescriptorPool* self, PyObject* arg) {
Py_ssize_t name_size;
char* name;
@@ -305,7 +284,47 @@ PyObject* FindOneofByName(PyDescriptorPool* self, PyObject* arg) {
return PyOneofDescriptor_FromDescriptor(oneof_descriptor);
}
-PyObject* FindFileContainingSymbol(PyDescriptorPool* self, PyObject* arg) {
+static PyObject* FindOneofByNameMethod(PyObject* self, PyObject* arg) {
+ return FindOneofByName(reinterpret_cast<PyDescriptorPool*>(self), arg);
+}
+
+static PyObject* FindServiceByName(PyObject* self, PyObject* arg) {
+ Py_ssize_t name_size;
+ char* name;
+ if (PyString_AsStringAndSize(arg, &name, &name_size) < 0) {
+ return NULL;
+ }
+
+ const ServiceDescriptor* service_descriptor =
+ reinterpret_cast<PyDescriptorPool*>(self)->pool->FindServiceByName(
+ string(name, name_size));
+ if (service_descriptor == NULL) {
+ PyErr_Format(PyExc_KeyError, "Couldn't find service %.200s", name);
+ return NULL;
+ }
+
+ return PyServiceDescriptor_FromDescriptor(service_descriptor);
+}
+
+static PyObject* FindMethodByName(PyObject* self, PyObject* arg) {
+ Py_ssize_t name_size;
+ char* name;
+ if (PyString_AsStringAndSize(arg, &name, &name_size) < 0) {
+ return NULL;
+ }
+
+ const MethodDescriptor* method_descriptor =
+ reinterpret_cast<PyDescriptorPool*>(self)->pool->FindMethodByName(
+ string(name, name_size));
+ if (method_descriptor == NULL) {
+ PyErr_Format(PyExc_KeyError, "Couldn't find method %.200s", name);
+ return NULL;
+ }
+
+ return PyMethodDescriptor_FromDescriptor(method_descriptor);
+}
+
+static PyObject* FindFileContainingSymbol(PyObject* self, PyObject* arg) {
Py_ssize_t name_size;
char* name;
if (PyString_AsStringAndSize(arg, &name, &name_size) < 0) {
@@ -313,7 +332,8 @@ PyObject* FindFileContainingSymbol(PyDescriptorPool* self, PyObject* arg) {
}
const FileDescriptor* file_descriptor =
- self->pool->FindFileContainingSymbol(string(name, name_size));
+ reinterpret_cast<PyDescriptorPool*>(self)->pool->FindFileContainingSymbol(
+ string(name, name_size));
if (file_descriptor == NULL) {
PyErr_Format(PyExc_KeyError, "Couldn't find symbol %.200s", name);
return NULL;
@@ -322,6 +342,53 @@ PyObject* FindFileContainingSymbol(PyDescriptorPool* self, PyObject* arg) {
return PyFileDescriptor_FromDescriptor(file_descriptor);
}
+static PyObject* FindExtensionByNumber(PyObject* self, PyObject* args) {
+ PyObject* message_descriptor;
+ int number;
+ if (!PyArg_ParseTuple(args, "Oi", &message_descriptor, &number)) {
+ return NULL;
+ }
+ const Descriptor* descriptor = PyMessageDescriptor_AsDescriptor(
+ message_descriptor);
+ if (descriptor == NULL) {
+ return NULL;
+ }
+
+ const FieldDescriptor* extension_descriptor =
+ reinterpret_cast<PyDescriptorPool*>(self)->pool->FindExtensionByNumber(
+ descriptor, number);
+ if (extension_descriptor == NULL) {
+ PyErr_Format(PyExc_KeyError, "Couldn't find extension %d", number);
+ return NULL;
+ }
+
+ return PyFieldDescriptor_FromDescriptor(extension_descriptor);
+}
+
+static PyObject* FindAllExtensions(PyObject* self, PyObject* arg) {
+ const Descriptor* descriptor = PyMessageDescriptor_AsDescriptor(arg);
+ if (descriptor == NULL) {
+ return NULL;
+ }
+
+ std::vector<const FieldDescriptor*> extensions;
+ reinterpret_cast<PyDescriptorPool*>(self)->pool->FindAllExtensions(
+ descriptor, &extensions);
+
+ ScopedPyObjectPtr result(PyList_New(extensions.size()));
+ if (result == NULL) {
+ return NULL;
+ }
+ for (int i = 0; i < extensions.size(); i++) {
+ PyObject* extension = PyFieldDescriptor_FromDescriptor(extensions[i]);
+ if (extension == NULL) {
+ return NULL;
+ }
+ PyList_SET_ITEM(result.get(), i, extension); // Steals the reference.
+ }
+ return result.release();
+}
+
// These functions should not exist -- the only valid way to create
// descriptors is to call Add() or AddSerializedFile().
// But these AddDescriptor() functions were created in Python and some people
@@ -331,14 +398,15 @@ PyObject* FindFileContainingSymbol(PyDescriptorPool* self, PyObject* arg) {
// call a function that will just be a no-op?
// TODO(amauryfa): Need to investigate further.
-PyObject* AddFileDescriptor(PyDescriptorPool* self, PyObject* descriptor) {
+static PyObject* AddFileDescriptor(PyObject* self, PyObject* descriptor) {
const FileDescriptor* file_descriptor =
PyFileDescriptor_AsDescriptor(descriptor);
if (!file_descriptor) {
return NULL;
}
if (file_descriptor !=
- self->pool->FindFileByName(file_descriptor->name())) {
+ reinterpret_cast<PyDescriptorPool*>(self)->pool->FindFileByName(
+ file_descriptor->name())) {
PyErr_Format(PyExc_ValueError,
"The file descriptor %s does not belong to this pool",
file_descriptor->name().c_str());
@@ -347,14 +415,15 @@ PyObject* AddFileDescriptor(PyDescriptorPool* self, PyObject* descriptor) {
Py_RETURN_NONE;
}
-PyObject* AddDescriptor(PyDescriptorPool* self, PyObject* descriptor) {
+static PyObject* AddDescriptor(PyObject* self, PyObject* descriptor) {
const Descriptor* message_descriptor =
PyMessageDescriptor_AsDescriptor(descriptor);
if (!message_descriptor) {
return NULL;
}
if (message_descriptor !=
- self->pool->FindMessageTypeByName(message_descriptor->full_name())) {
+ reinterpret_cast<PyDescriptorPool*>(self)->pool->FindMessageTypeByName(
+ message_descriptor->full_name())) {
PyErr_Format(PyExc_ValueError,
"The message descriptor %s does not belong to this pool",
message_descriptor->full_name().c_str());
@@ -363,14 +432,15 @@ PyObject* AddDescriptor(PyDescriptorPool* self, PyObject* descriptor) {
Py_RETURN_NONE;
}
-PyObject* AddEnumDescriptor(PyDescriptorPool* self, PyObject* descriptor) {
+static PyObject* AddEnumDescriptor(PyObject* self, PyObject* descriptor) {
const EnumDescriptor* enum_descriptor =
PyEnumDescriptor_AsDescriptor(descriptor);
if (!enum_descriptor) {
return NULL;
}
if (enum_descriptor !=
- self->pool->FindEnumTypeByName(enum_descriptor->full_name())) {
+ reinterpret_cast<PyDescriptorPool*>(self)->pool->FindEnumTypeByName(
+ enum_descriptor->full_name())) {
PyErr_Format(PyExc_ValueError,
"The enum descriptor %s does not belong to this pool",
enum_descriptor->full_name().c_str());
@@ -379,8 +449,41 @@ PyObject* AddEnumDescriptor(PyDescriptorPool* self, PyObject* descriptor) {
Py_RETURN_NONE;
}
-// The code below loads new Descriptors from a serialized FileDescriptorProto.
+static PyObject* AddExtensionDescriptor(PyObject* self, PyObject* descriptor) {
+ const FieldDescriptor* extension_descriptor =
+ PyFieldDescriptor_AsDescriptor(descriptor);
+ if (!extension_descriptor) {
+ return NULL;
+ }
+ if (extension_descriptor !=
+ reinterpret_cast<PyDescriptorPool*>(self)->pool->FindExtensionByName(
+ extension_descriptor->full_name())) {
+ PyErr_Format(PyExc_ValueError,
+ "The extension descriptor %s does not belong to this pool",
+ extension_descriptor->full_name().c_str());
+ return NULL;
+ }
+ Py_RETURN_NONE;
+}
+static PyObject* AddServiceDescriptor(PyObject* self, PyObject* descriptor) {
+ const ServiceDescriptor* service_descriptor =
+ PyServiceDescriptor_AsDescriptor(descriptor);
+ if (!service_descriptor) {
+ return NULL;
+ }
+ if (service_descriptor !=
+ reinterpret_cast<PyDescriptorPool*>(self)->pool->FindServiceByName(
+ service_descriptor->full_name())) {
+ PyErr_Format(PyExc_ValueError,
+ "The service descriptor %s does not belong to this pool",
+ service_descriptor->full_name().c_str());
+ return NULL;
+ }
+ Py_RETURN_NONE;
+}
+
+// The code below loads new Descriptors from a serialized FileDescriptorProto.
// Collects errors that occur during proto file building to allow them to be
// propagated in the python exception instead of only living in ERROR logs.
@@ -407,7 +510,8 @@ class BuildFileErrorCollector : public DescriptorPool::ErrorCollector {
bool had_errors;
};
-PyObject* AddSerializedFile(PyDescriptorPool* self, PyObject* serialized_pb) {
+static PyObject* AddSerializedFile(PyObject* pself, PyObject* serialized_pb) {
+ PyDescriptorPool* self = reinterpret_cast<PyDescriptorPool*>(pself);
char* message_type;
Py_ssize_t message_len;
@@ -455,7 +559,7 @@ PyObject* AddSerializedFile(PyDescriptorPool* self, PyObject* serialized_pb) {
descriptor, serialized_pb);
}
-PyObject* Add(PyDescriptorPool* self, PyObject* file_descriptor_proto) {
+static PyObject* Add(PyObject* self, PyObject* file_descriptor_proto) {
ScopedPyObjectPtr serialized_pb(
PyObject_CallMethod(file_descriptor_proto, "SerializeToString", NULL));
if (serialized_pb == NULL) {
@@ -465,35 +569,47 @@ PyObject* Add(PyDescriptorPool* self, PyObject* file_descriptor_proto) {
}
static PyMethodDef Methods[] = {
- { "Add", (PyCFunction)Add, METH_O,
+ { "Add", Add, METH_O,
"Adds the FileDescriptorProto and its types to this pool." },
- { "AddSerializedFile", (PyCFunction)AddSerializedFile, METH_O,
+ { "AddSerializedFile", AddSerializedFile, METH_O,
"Adds a serialized FileDescriptorProto to this pool." },
// TODO(amauryfa): Understand why the Python implementation differs from
// this one, ask users to use another API and deprecate these functions.
- { "AddFileDescriptor", (PyCFunction)AddFileDescriptor, METH_O,
+ { "AddFileDescriptor", AddFileDescriptor, METH_O,
+ "No-op. Add() must have been called before." },
+ { "AddDescriptor", AddDescriptor, METH_O,
+ "No-op. Add() must have been called before." },
+ { "AddEnumDescriptor", AddEnumDescriptor, METH_O,
"No-op. Add() must have been called before." },
- { "AddDescriptor", (PyCFunction)AddDescriptor, METH_O,
+ { "AddExtensionDescriptor", AddExtensionDescriptor, METH_O,
"No-op. Add() must have been called before." },
- { "AddEnumDescriptor", (PyCFunction)AddEnumDescriptor, METH_O,
+ { "AddServiceDescriptor", AddServiceDescriptor, METH_O,
"No-op. Add() must have been called before." },
- { "FindFileByName", (PyCFunction)FindFileByName, METH_O,
+ { "FindFileByName", FindFileByName, METH_O,
"Searches for a file descriptor by its .proto name." },
- { "FindMessageTypeByName", (PyCFunction)FindMessageByName, METH_O,
+ { "FindMessageTypeByName", FindMessageByName, METH_O,
"Searches for a message descriptor by full name." },
- { "FindFieldByName", (PyCFunction)FindFieldByName, METH_O,
+ { "FindFieldByName", FindFieldByNameMethod, METH_O,
"Searches for a field descriptor by full name." },
- { "FindExtensionByName", (PyCFunction)FindExtensionByName, METH_O,
+ { "FindExtensionByName", FindExtensionByNameMethod, METH_O,
"Searches for extension descriptor by full name." },
- { "FindEnumTypeByName", (PyCFunction)FindEnumTypeByName, METH_O,
+ { "FindEnumTypeByName", FindEnumTypeByNameMethod, METH_O,
"Searches for enum type descriptor by full name." },
- { "FindOneofByName", (PyCFunction)FindOneofByName, METH_O,
+ { "FindOneofByName", FindOneofByNameMethod, METH_O,
"Searches for oneof descriptor by full name." },
+ { "FindServiceByName", FindServiceByName, METH_O,
+ "Searches for service descriptor by full name." },
+ { "FindMethodByName", FindMethodByName, METH_O,
+ "Searches for method descriptor by full name." },
- { "FindFileContainingSymbol", (PyCFunction)FindFileContainingSymbol, METH_O,
+ { "FindFileContainingSymbol", FindFileContainingSymbol, METH_O,
"Gets the FileDescriptor containing the specified symbol." },
+ { "FindExtensionByNumber", FindExtensionByNumber, METH_VARARGS,
+ "Gets the extension descriptor for the given number." },
+ { "FindAllExtensions", FindAllExtensions, METH_O,
+ "Gets all known extensions of the given message descriptor." },
{NULL}
};
@@ -504,7 +620,7 @@ PyTypeObject PyDescriptorPool_Type = {
FULL_MODULE_NAME ".DescriptorPool", // tp_name
sizeof(PyDescriptorPool), // tp_basicsize
0, // tp_itemsize
- (destructor)cdescriptor_pool::Dealloc, // tp_dealloc
+ cdescriptor_pool::Dealloc, // tp_dealloc
0, // tp_print
0, // tp_getattr
0, // tp_setattr
diff --git a/python/google/protobuf/pyext/descriptor_pool.h b/python/google/protobuf/pyext/descriptor_pool.h
index 16bc910c..53ee53dc 100644
--- a/python/google/protobuf/pyext/descriptor_pool.h
+++ b/python/google/protobuf/pyext/descriptor_pool.h
@@ -38,10 +38,13 @@
namespace google {
namespace protobuf {
-class MessageFactory;
-
namespace python {
+struct PyMessageFactory;
+
+// The (meta) type of all Messages classes.
+struct CMessageClass;
+
// Wraps operations to the global DescriptorPool which contains information
// about all messages and fields.
//
@@ -66,20 +69,10 @@ typedef struct PyDescriptorPool {
// This pointer is owned.
const DescriptorDatabase* database;
- // DynamicMessageFactory used to create C++ instances of messages.
- // This object cache the descriptors that were used, so the DescriptorPool
- // needs to get rid of it before it can delete itself.
- //
- // Note: A C++ MessageFactory is different from the Python MessageFactory.
- // The C++ one creates messages, when the Python one creates classes.
- MessageFactory* message_factory;
-
- // Make our own mapping to retrieve Python classes from C++ descriptors.
- //
- // Descriptor pointers stored here are owned by the DescriptorPool above.
- // Python references to classes are owned by this PyDescriptorPool.
- typedef hash_map<const Descriptor*, PyObject*> ClassesByMessageMap;
- ClassesByMessageMap* classes_by_descriptor;
+ // The preferred MessageFactory to be used by descriptors.
+ // TODO(amauryfa): Don't create the Factory from the DescriptorPool, but
+ // use the one passed while creating message classes. And remove this member.
+ PyMessageFactory* py_message_factory;
// Cache the options for any kind of descriptor.
// Descriptor pointers are owned by the DescriptorPool above.
@@ -92,24 +85,12 @@ extern PyTypeObject PyDescriptorPool_Type;
namespace cdescriptor_pool {
+
// Looks up a message by name.
// Returns a message Descriptor, or NULL if not found.
const Descriptor* FindMessageTypeByName(PyDescriptorPool* self,
const string& name);
-// Registers a new Python class for the given message descriptor.
-// On error, returns -1 with a Python exception set.
-int RegisterMessageClass(PyDescriptorPool* self,
- const Descriptor* message_descriptor,
- PyObject* message_class);
-
-// Retrieves the Python class registered with the given message descriptor.
-//
-// Returns a *borrowed* reference if found, otherwise returns NULL with an
-// exception set.
-PyObject* GetMessageClass(PyDescriptorPool* self,
- const Descriptor* message_descriptor);
-
// The functions below are also exposed as methods of the DescriptorPool type.
// Looks up a message by name. Returns a PyMessageDescriptor corresponding to
diff --git a/python/google/protobuf/pyext/extension_dict.cc b/python/google/protobuf/pyext/extension_dict.cc
index 555bd293..018b5c2c 100644
--- a/python/google/protobuf/pyext/extension_dict.cc
+++ b/python/google/protobuf/pyext/extension_dict.cc
@@ -32,19 +32,30 @@
// Author: tibell@google.com (Johan Tibell)
#include <google/protobuf/pyext/extension_dict.h>
+#include <memory>
#include <google/protobuf/stubs/logging.h>
#include <google/protobuf/stubs/common.h>
#include <google/protobuf/descriptor.h>
#include <google/protobuf/dynamic_message.h>
#include <google/protobuf/message.h>
+#include <google/protobuf/descriptor.pb.h>
#include <google/protobuf/pyext/descriptor.h>
-#include <google/protobuf/pyext/descriptor_pool.h>
#include <google/protobuf/pyext/message.h>
+#include <google/protobuf/pyext/message_factory.h>
#include <google/protobuf/pyext/repeated_composite_container.h>
#include <google/protobuf/pyext/repeated_scalar_container.h>
#include <google/protobuf/pyext/scoped_pyobject_ptr.h>
-#include <google/protobuf/stubs/shared_ptr.h>
+
+#if PY_MAJOR_VERSION >= 3
+ #if PY_VERSION_HEX < 0x03030000
+ #error "Python 3.0 - 3.2 are not supported."
+ #endif
+ #define PyString_AsStringAndSize(ob, charpp, sizep) \
+ (PyUnicode_Check(ob)? \
+ ((*(charpp) = PyUnicode_AsUTF8AndSize(ob, (sizep))) == NULL? -1: 0): \
+ PyBytes_AsStringAndSize(ob, (charpp), (sizep)))
+#endif
namespace google {
namespace protobuf {
@@ -60,35 +71,6 @@ PyObject* len(ExtensionDict* self) {
#endif
}
-// TODO(tibell): Use VisitCompositeField.
-int ReleaseExtension(ExtensionDict* self,
- PyObject* extension,
- const FieldDescriptor* descriptor) {
- if (descriptor->label() == FieldDescriptor::LABEL_REPEATED) {
- if (descriptor->cpp_type() == FieldDescriptor::CPPTYPE_MESSAGE) {
- if (repeated_composite_container::Release(
- reinterpret_cast<RepeatedCompositeContainer*>(
- extension)) < 0) {
- return -1;
- }
- } else {
- if (repeated_scalar_container::Release(
- reinterpret_cast<RepeatedScalarContainer*>(
- extension)) < 0) {
- return -1;
- }
- }
- } else if (descriptor->cpp_type() == FieldDescriptor::CPPTYPE_MESSAGE) {
- if (cmessage::ReleaseSubMessage(
- self->parent, descriptor,
- reinterpret_cast<CMessage*>(extension)) < 0) {
- return -1;
- }
- }
-
- return 0;
-}
-
PyObject* subscript(ExtensionDict* self, PyObject* key) {
const FieldDescriptor* descriptor = cmessage::GetExtensionDescriptor(key);
if (descriptor == NULL) {
@@ -119,6 +101,7 @@ PyObject* subscript(ExtensionDict* self, PyObject* key) {
if (descriptor->label() != FieldDescriptor::LABEL_REPEATED &&
descriptor->cpp_type() == FieldDescriptor::CPPTYPE_MESSAGE) {
+ // TODO(plabatut): consider building the class on the fly!
PyObject* sub_message = cmessage::InternalGetSubMessage(
self->parent, descriptor);
if (sub_message == NULL) {
@@ -130,9 +113,21 @@ PyObject* subscript(ExtensionDict* self, PyObject* key) {
if (descriptor->label() == FieldDescriptor::LABEL_REPEATED) {
if (descriptor->cpp_type() == FieldDescriptor::CPPTYPE_MESSAGE) {
- PyObject *message_class = cdescriptor_pool::GetMessageClass(
- cmessage::GetDescriptorPoolForMessage(self->parent),
+ // On the fly message class creation is needed to support the following
+ // situation:
+ // 1- add FileDescriptor to the pool that contains extensions of a message
+ // defined by another proto file. Do not create any message classes.
+ // 2- instantiate an extended message, and access the extension using
+ // the field descriptor.
+ // 3- the extension submessage fails to be returned, because no class has
+ // been created.
+ // It happens when deserializing text proto format, or when enumerating
+ // fields of a deserialized message.
+ CMessageClass* message_class = message_factory::GetOrCreateMessageClass(
+ cmessage::GetFactoryForMessage(self->parent),
descriptor->message_type());
+ ScopedPyObjectPtr message_class_handler(
+ reinterpret_cast<PyObject*>(message_class));
if (message_class == NULL) {
return NULL;
}
@@ -183,60 +178,51 @@ int ass_subscript(ExtensionDict* self, PyObject* key, PyObject* value) {
return 0;
}
-PyObject* ClearExtension(ExtensionDict* self, PyObject* extension) {
- const FieldDescriptor* descriptor =
- cmessage::GetExtensionDescriptor(extension);
- if (descriptor == NULL) {
+PyObject* _FindExtensionByName(ExtensionDict* self, PyObject* arg) {
+ char* name;
+ Py_ssize_t name_size;
+ if (PyString_AsStringAndSize(arg, &name, &name_size) < 0) {
return NULL;
}
- PyObject* value = PyDict_GetItem(self->values, extension);
- if (self->parent) {
- if (value != NULL) {
- if (ReleaseExtension(self, value, descriptor) < 0) {
- return NULL;
+
+ PyDescriptorPool* pool = cmessage::GetFactoryForMessage(self->parent)->pool;
+ const FieldDescriptor* message_extension =
+ pool->pool->FindExtensionByName(string(name, name_size));
+ if (message_extension == NULL) {
+ // Is is the name of a message set extension?
+ const Descriptor* message_descriptor = pool->pool->FindMessageTypeByName(
+ string(name, name_size));
+ if (message_descriptor && message_descriptor->extension_count() > 0) {
+ const FieldDescriptor* extension = message_descriptor->extension(0);
+ if (extension->is_extension() &&
+ extension->containing_type()->options().message_set_wire_format() &&
+ extension->type() == FieldDescriptor::TYPE_MESSAGE &&
+ extension->label() == FieldDescriptor::LABEL_OPTIONAL) {
+ message_extension = extension;
}
}
- if (ScopedPyObjectPtr(cmessage::ClearFieldByDescriptor(
- self->parent, descriptor)) == NULL) {
- return NULL;
- }
}
- if (PyDict_DelItem(self->values, extension) < 0) {
- PyErr_Clear();
+ if (message_extension == NULL) {
+ Py_RETURN_NONE;
}
- Py_RETURN_NONE;
-}
-PyObject* HasExtension(ExtensionDict* self, PyObject* extension) {
- const FieldDescriptor* descriptor =
- cmessage::GetExtensionDescriptor(extension);
- if (descriptor == NULL) {
- return NULL;
- }
- if (self->parent) {
- return cmessage::HasFieldByDescriptor(self->parent, descriptor);
- } else {
- int exists = PyDict_Contains(self->values, extension);
- if (exists < 0) {
- return NULL;
- }
- return PyBool_FromLong(exists);
- }
+ return PyFieldDescriptor_FromDescriptor(message_extension);
}
-PyObject* _FindExtensionByName(ExtensionDict* self, PyObject* name) {
- ScopedPyObjectPtr extensions_by_name(PyObject_GetAttrString(
- reinterpret_cast<PyObject*>(self->parent), "_extensions_by_name"));
- if (extensions_by_name == NULL) {
+PyObject* _FindExtensionByNumber(ExtensionDict* self, PyObject* arg) {
+ int64 number = PyLong_AsLong(arg);
+ if (number == -1 && PyErr_Occurred()) {
return NULL;
}
- PyObject* result = PyDict_GetItem(extensions_by_name.get(), name);
- if (result == NULL) {
+
+ PyDescriptorPool* pool = cmessage::GetFactoryForMessage(self->parent)->pool;
+ const FieldDescriptor* message_extension = pool->pool->FindExtensionByNumber(
+ self->parent->message->GetDescriptor(), number);
+ if (message_extension == NULL) {
Py_RETURN_NONE;
- } else {
- Py_INCREF(result);
- return result;
}
+
+ return PyFieldDescriptor_FromDescriptor(message_extension);
}
ExtensionDict* NewExtensionDict(CMessage *parent) {
@@ -267,10 +253,10 @@ static PyMappingMethods MpMethods = {
#define EDMETHOD(name, args, doc) { #name, (PyCFunction)name, args, doc }
static PyMethodDef Methods[] = {
- EDMETHOD(ClearExtension, METH_O, "Clears an extension from the object."),
- EDMETHOD(HasExtension, METH_O, "Checks if the object has an extension."),
EDMETHOD(_FindExtensionByName, METH_O,
"Finds an extension by name."),
+ EDMETHOD(_FindExtensionByNumber, METH_O,
+ "Finds an extension by field number."),
{ NULL, NULL }
};
diff --git a/python/google/protobuf/pyext/extension_dict.h b/python/google/protobuf/pyext/extension_dict.h
index d92cf956..0de2c4ee 100644
--- a/python/google/protobuf/pyext/extension_dict.h
+++ b/python/google/protobuf/pyext/extension_dict.h
@@ -37,9 +37,8 @@
#include <Python.h>
#include <memory>
-#ifndef _SHARED_PTR_H
-#include <google/protobuf/stubs/shared_ptr.h>
-#endif
+
+#include <google/protobuf/pyext/message.h>
namespace google {
namespace protobuf {
@@ -47,16 +46,8 @@ namespace protobuf {
class Message;
class FieldDescriptor;
-#ifdef _SHARED_PTR_H
-using std::shared_ptr;
-#else
-using internal::shared_ptr;
-#endif
-
namespace python {
-struct CMessage;
-
typedef struct ExtensionDict {
PyObject_HEAD;
@@ -64,7 +55,7 @@ typedef struct ExtensionDict {
// proto tree. Every Python container class holds a
// reference to it in order to keep it alive as long as there's a
// Python object that references any part of the tree.
- shared_ptr<Message> owner;
+ CMessage::OwnerRef owner;
// Weak reference to parent message. Used to make sure
// the parent is writable when an extension field is modified.
@@ -86,43 +77,6 @@ namespace extension_dict {
// Builds an Extensions dict for a specific message.
ExtensionDict* NewExtensionDict(CMessage *parent);
-// Gets the number of extension values in this ExtensionDict as a python object.
-//
-// Returns a new reference.
-PyObject* len(ExtensionDict* self);
-
-// Releases extensions referenced outside this dictionary to keep outside
-// references alive.
-//
-// Returns 0 on success, -1 on failure.
-int ReleaseExtension(ExtensionDict* self,
- PyObject* extension,
- const FieldDescriptor* descriptor);
-
-// Gets an extension from the dict for the given extension descriptor.
-//
-// Returns a new reference.
-PyObject* subscript(ExtensionDict* self, PyObject* key);
-
-// Assigns a value to an extension in the dict. Can only be used for singular
-// simple types.
-//
-// Returns 0 on success, -1 on failure.
-int ass_subscript(ExtensionDict* self, PyObject* key, PyObject* value);
-
-// Clears an extension from the dict. Will release the extension if there
-// is still an external reference left to it.
-//
-// Returns None on success.
-PyObject* ClearExtension(ExtensionDict* self,
- PyObject* extension);
-
-// Gets an extension from the dict given the extension name as opposed to
-// descriptor.
-//
-// Returns a new reference.
-PyObject* _FindExtensionByName(ExtensionDict* self, PyObject* name);
-
} // namespace extension_dict
} // namespace python
} // namespace protobuf
diff --git a/python/google/protobuf/pyext/map_container.cc b/python/google/protobuf/pyext/map_container.cc
index df9138a4..6d7ee285 100644
--- a/python/google/protobuf/pyext/map_container.cc
+++ b/python/google/protobuf/pyext/map_container.cc
@@ -32,13 +32,16 @@
#include <google/protobuf/pyext/map_container.h>
+#include <memory>
+
#include <google/protobuf/stubs/logging.h>
#include <google/protobuf/stubs/common.h>
-#include <google/protobuf/stubs/scoped_ptr.h>
#include <google/protobuf/map_field.h>
#include <google/protobuf/map.h>
#include <google/protobuf/message.h>
+#include <google/protobuf/pyext/message_factory.h>
#include <google/protobuf/pyext/message.h>
+#include <google/protobuf/pyext/repeated_composite_container.h>
#include <google/protobuf/pyext/scoped_pyobject_ptr.h>
#if PY_MAJOR_VERSION >= 3
@@ -70,7 +73,7 @@ class MapReflectionFriend {
struct MapIterator {
PyObject_HEAD;
- scoped_ptr< ::google::protobuf::MapIterator> iter;
+ std::unique_ptr<::google::protobuf::MapIterator> iter;
// A pointer back to the container, so we can notice changes to the version.
// We own a ref on this.
@@ -88,7 +91,7 @@ struct MapIterator {
// as this iterator does. This is solely for the benefit of the MapIterator
// destructor -- we should never actually access the iterator in this state
// except to delete it.
- shared_ptr<Message> owner;
+ CMessage::OwnerRef owner;
// The version of the map when we took the iterator to it.
//
@@ -324,6 +327,33 @@ PyObject* Clear(PyObject* _self) {
Py_RETURN_NONE;
}
+PyObject* GetEntryClass(PyObject* _self) {
+ MapContainer* self = GetMap(_self);
+ CMessageClass* message_class = message_factory::GetMessageClass(
+ cmessage::GetFactoryForMessage(self->parent),
+ self->parent_field_descriptor->message_type());
+ Py_XINCREF(message_class);
+ return reinterpret_cast<PyObject*>(message_class);
+}
+
+PyObject* MergeFrom(PyObject* _self, PyObject* arg) {
+ MapContainer* self = GetMap(_self);
+ MapContainer* other_map = GetMap(arg);
+ Message* message = self->GetMutableMessage();
+ const Message* other_message = other_map->message;
+ const Reflection* reflection = message->GetReflection();
+ const Reflection* other_reflection = other_message->GetReflection();
+ int count = other_reflection->FieldSize(
+ *other_message, other_map->parent_field_descriptor);
+ for (int i = 0 ; i < count; i ++) {
+ reflection->AddMessage(message, self->parent_field_descriptor)->MergeFrom(
+ other_reflection->GetRepeatedMessage(
+ *other_message, other_map->parent_field_descriptor, i));
+ }
+ self->version++;
+ Py_RETURN_NONE;
+}
+
PyObject* MapReflectionFriend::Contains(PyObject* _self, PyObject* key) {
MapContainer* self = GetMap(_self);
@@ -344,9 +374,10 @@ PyObject* MapReflectionFriend::Contains(PyObject* _self, PyObject* key) {
}
// Initializes the underlying Message object of "to" so it becomes a new parent
-// repeated scalar, and copies all the values from "from" to it. A child scalar
+// map container, and copies all the values from "from" to it. A child map
// container can be released by passing it as both from and to (e.g. making it
// the recipient of the new parent message and copying the values from itself).
+// In fact, this is the only supported use at the moment.
static int InitializeAndCopyToParentContainer(MapContainer* from,
MapContainer* to) {
// For now we require from == to, re-evaluate if we want to support deep copy
@@ -358,7 +389,7 @@ static int InitializeAndCopyToParentContainer(MapContainer* from,
// A somewhat roundabout way of copying just one field from old_message to
// new_message. This is the best we can do with what Reflection gives us.
Message* mutable_old = from->GetMutableMessage();
- vector<const FieldDescriptor*> fields;
+ std::vector<const FieldDescriptor*> fields;
fields.push_back(from->parent_field_descriptor);
// Move the map field into the new message.
@@ -395,12 +426,7 @@ PyObject *NewScalarMapContainer(
return NULL;
}
-#if PY_MAJOR_VERSION >= 3
- ScopedPyObjectPtr obj(PyType_GenericAlloc(
- reinterpret_cast<PyTypeObject *>(ScalarMapContainer_Type), 0));
-#else
- ScopedPyObjectPtr obj(PyType_GenericAlloc(&ScalarMapContainer_Type, 0));
-#endif
+ ScopedPyObjectPtr obj(PyType_GenericAlloc(ScalarMapContainer_Type, 0));
if (obj.get() == NULL) {
return PyErr_Format(PyExc_RuntimeError,
"Could not allocate new container.");
@@ -522,6 +548,10 @@ static PyMethodDef ScalarMapMethods[] = {
"Removes all elements from the map." },
{ "get", ScalarMapGet, METH_VARARGS,
"Gets the value for the given key if present, or otherwise a default" },
+ { "GetEntryClass", (PyCFunction)GetEntryClass, METH_NOARGS,
+ "Return the class used to build Entries of (key, value) pairs." },
+ { "MergeFrom", (PyCFunction)MergeFrom, METH_O,
+ "Merges a map into the current map." },
/*
{ "__deepcopy__", (PyCFunction)DeepCopy, METH_VARARGS,
"Makes a deep copy of the class." },
@@ -531,6 +561,7 @@ static PyMethodDef ScalarMapMethods[] = {
{NULL, NULL},
};
+PyTypeObject *ScalarMapContainer_Type;
#if PY_MAJOR_VERSION >= 3
static PyType_Slot ScalarMapContainer_Type_slots[] = {
{Py_tp_dealloc, (void *)ScalarMapDealloc},
@@ -549,7 +580,6 @@ static PyMethodDef ScalarMapMethods[] = {
Py_TPFLAGS_DEFAULT,
ScalarMapContainer_Type_slots
};
- PyObject *ScalarMapContainer_Type;
#else
static PyMappingMethods ScalarMapMappingMethods = {
MapReflectionFriend::Length, // mp_length
@@ -557,7 +587,7 @@ static PyMethodDef ScalarMapMethods[] = {
MapReflectionFriend::ScalarMapSetItem, // mp_ass_subscript
};
- PyTypeObject ScalarMapContainer_Type = {
+ PyTypeObject _ScalarMapContainer_Type = {
PyVarObject_HEAD_INIT(&PyType_Type, 0)
FULL_MODULE_NAME ".ScalarMapContainer", // tp_name
sizeof(MapContainer), // tp_basicsize
@@ -610,8 +640,7 @@ static PyObject* GetCMessage(MessageMapContainer* self, Message* message) {
PyObject* ret = PyDict_GetItem(self->message_dict, key.get());
if (ret == NULL) {
- CMessage* cmsg = cmessage::NewEmptyMessage(self->subclass_init,
- message->GetDescriptor());
+ CMessage* cmsg = cmessage::NewEmptyMessage(self->message_class);
ret = reinterpret_cast<PyObject*>(cmsg);
if (cmsg == NULL) {
@@ -634,17 +663,12 @@ static PyObject* GetCMessage(MessageMapContainer* self, Message* message) {
PyObject* NewMessageMapContainer(
CMessage* parent, const google::protobuf::FieldDescriptor* parent_field_descriptor,
- PyObject* concrete_class) {
+ CMessageClass* message_class) {
if (!CheckFieldBelongsToMessage(parent_field_descriptor, parent->message)) {
return NULL;
}
-#if PY_MAJOR_VERSION >= 3
- PyObject* obj = PyType_GenericAlloc(
- reinterpret_cast<PyTypeObject *>(MessageMapContainer_Type), 0);
-#else
- PyObject* obj = PyType_GenericAlloc(&MessageMapContainer_Type, 0);
-#endif
+ PyObject* obj = PyType_GenericAlloc(MessageMapContainer_Type, 0);
if (obj == NULL) {
return PyErr_Format(PyExc_RuntimeError,
"Could not allocate new container.");
@@ -669,8 +693,8 @@ PyObject* NewMessageMapContainer(
"Could not allocate message dict.");
}
- Py_INCREF(concrete_class);
- self->subclass_init = concrete_class;
+ Py_INCREF(message_class);
+ self->message_class = message_class;
if (self->key_field_descriptor == NULL ||
self->value_field_descriptor == NULL) {
@@ -705,8 +729,33 @@ int MapReflectionFriend::MessageMapSetItem(PyObject* _self, PyObject* key,
}
// Delete key from map.
- if (reflection->DeleteMapValue(message, self->parent_field_descriptor,
+ if (reflection->ContainsMapKey(*message, self->parent_field_descriptor,
map_key)) {
+ // Delete key from CMessage dict.
+ MapValueRef value;
+ reflection->InsertOrLookupMapValue(message, self->parent_field_descriptor,
+ map_key, &value);
+ ScopedPyObjectPtr key(PyLong_FromVoidPtr(value.MutableMessageValue()));
+
+ PyObject* cmsg_value = PyDict_GetItem(self->message_dict, key.get());
+ if (cmsg_value) {
+ // Need to keep CMessage stay alive if it is still referenced after
+ // deletion. Makes a new message and swaps values into CMessage
+ // instead of just removing.
+ CMessage* cmsg = reinterpret_cast<CMessage*>(cmsg_value);
+ Message* msg = cmsg->message;
+ cmsg->owner.reset(msg->New());
+ cmsg->message = cmsg->owner.get();
+ cmsg->parent = NULL;
+ msg->GetReflection()->Swap(msg, cmsg->message);
+ if (PyDict_DelItem(self->message_dict, key.get()) < 0) {
+ return -1;
+ }
+ }
+
+ // Delete key from map.
+ reflection->DeleteMapValue(message, self->parent_field_descriptor,
+ map_key);
return 0;
} else {
PyErr_Format(PyExc_KeyError, "Key not present in map");
@@ -763,6 +812,7 @@ static void MessageMapDealloc(PyObject* _self) {
MessageMapContainer* self = GetMessageMap(_self);
self->owner.reset();
Py_DECREF(self->message_dict);
+ Py_DECREF(self->message_class);
Py_TYPE(_self)->tp_free(_self);
}
@@ -775,6 +825,10 @@ static PyMethodDef MessageMapMethods[] = {
"Gets the value for the given key if present, or otherwise a default" },
{ "get_or_create", MapReflectionFriend::MessageMapGetItem, METH_O,
"Alias for getitem, useful to make explicit that the map is mutated." },
+ { "GetEntryClass", (PyCFunction)GetEntryClass, METH_NOARGS,
+ "Return the class used to build Entries of (key, value) pairs." },
+ { "MergeFrom", (PyCFunction)MergeFrom, METH_O,
+ "Merges a map into the current map." },
/*
{ "__deepcopy__", (PyCFunction)DeepCopy, METH_VARARGS,
"Makes a deep copy of the class." },
@@ -784,6 +838,7 @@ static PyMethodDef MessageMapMethods[] = {
{NULL, NULL},
};
+PyTypeObject *MessageMapContainer_Type;
#if PY_MAJOR_VERSION >= 3
static PyType_Slot MessageMapContainer_Type_slots[] = {
{Py_tp_dealloc, (void *)MessageMapDealloc},
@@ -802,8 +857,6 @@ static PyMethodDef MessageMapMethods[] = {
Py_TPFLAGS_DEFAULT,
MessageMapContainer_Type_slots
};
-
- PyObject *MessageMapContainer_Type;
#else
static PyMappingMethods MessageMapMappingMethods = {
MapReflectionFriend::Length, // mp_length
@@ -811,7 +864,7 @@ static PyMethodDef MessageMapMethods[] = {
MapReflectionFriend::MessageMapSetItem, // mp_ass_subscript
};
- PyTypeObject MessageMapContainer_Type = {
+ PyTypeObject _MessageMapContainer_Type = {
PyVarObject_HEAD_INIT(&PyType_Type, 0)
FULL_MODULE_NAME ".MessageMapContainer", // tp_name
sizeof(MessageMapContainer), // tp_basicsize
@@ -960,6 +1013,63 @@ PyTypeObject MapIterator_Type = {
0, // tp_init
};
+bool InitMapContainers() {
+ // ScalarMapContainer_Type derives from our MutableMapping type.
+ ScopedPyObjectPtr containers(PyImport_ImportModule(
+ "google.protobuf.internal.containers"));
+ if (containers == NULL) {
+ return false;
+ }
+
+ ScopedPyObjectPtr mutable_mapping(
+ PyObject_GetAttrString(containers.get(), "MutableMapping"));
+ if (mutable_mapping == NULL) {
+ return false;
+ }
+
+ if (!PyObject_TypeCheck(mutable_mapping.get(), &PyType_Type)) {
+ return false;
+ }
+
+ Py_INCREF(mutable_mapping.get());
+#if PY_MAJOR_VERSION >= 3
+ PyObject* bases = PyTuple_New(1);
+ PyTuple_SET_ITEM(bases, 0, mutable_mapping.get());
+
+ ScalarMapContainer_Type = reinterpret_cast<PyTypeObject*>(
+ PyType_FromSpecWithBases(&ScalarMapContainer_Type_spec, bases));
+#else
+ _ScalarMapContainer_Type.tp_base =
+ reinterpret_cast<PyTypeObject*>(mutable_mapping.get());
+
+ if (PyType_Ready(&_ScalarMapContainer_Type) < 0) {
+ return false;
+ }
+
+ ScalarMapContainer_Type = &_ScalarMapContainer_Type;
+#endif
+
+ if (PyType_Ready(&MapIterator_Type) < 0) {
+ return false;
+ }
+
+#if PY_MAJOR_VERSION >= 3
+ MessageMapContainer_Type = reinterpret_cast<PyTypeObject*>(
+ PyType_FromSpecWithBases(&MessageMapContainer_Type_spec, bases));
+#else
+ Py_INCREF(mutable_mapping.get());
+ _MessageMapContainer_Type.tp_base =
+ reinterpret_cast<PyTypeObject*>(mutable_mapping.get());
+
+ if (PyType_Ready(&_MessageMapContainer_Type) < 0) {
+ return false;
+ }
+
+ MessageMapContainer_Type = &_MessageMapContainer_Type;
+#endif
+ return true;
+}
+
} // namespace python
} // namespace protobuf
} // namespace google
diff --git a/python/google/protobuf/pyext/map_container.h b/python/google/protobuf/pyext/map_container.h
index ddf94be7..111fafbf 100644
--- a/python/google/protobuf/pyext/map_container.h
+++ b/python/google/protobuf/pyext/map_container.h
@@ -34,27 +34,19 @@
#include <Python.h>
#include <memory>
-#ifndef _SHARED_PTR_H
-#include <google/protobuf/stubs/shared_ptr.h>
-#endif
#include <google/protobuf/descriptor.h>
#include <google/protobuf/message.h>
+#include <google/protobuf/pyext/message.h>
namespace google {
namespace protobuf {
class Message;
-#ifdef _SHARED_PTR_H
-using std::shared_ptr;
-#else
-using internal::shared_ptr;
-#endif
-
namespace python {
-struct CMessage;
+struct CMessageClass;
// This struct is used directly for ScalarMap, and is the base class of
// MessageMapContainer, which is used for MessageMap.
@@ -65,7 +57,7 @@ struct MapContainer {
// proto tree. Every Python MapContainer holds a
// reference to it in order to keep it alive as long as there's a
// Python object that references any part of the tree.
- shared_ptr<Message> owner;
+ CMessage::OwnerRef owner;
// Pointer to the C++ Message that contains this container. The
// MapContainer does not own this pointer.
@@ -98,29 +90,21 @@ struct MapContainer {
int Release();
// Set the owner field of self and any children of self.
- void SetOwner(const shared_ptr<Message>& new_owner) {
- owner = new_owner;
- }
+ void SetOwner(const CMessage::OwnerRef& new_owner) { owner = new_owner; }
};
struct MessageMapContainer : public MapContainer {
- // A callable that is used to create new child messages.
- PyObject* subclass_init;
+ // The type used to create new child messages.
+ CMessageClass* message_class;
// A dict mapping Message* -> CMessage.
PyObject* message_dict;
};
-#if PY_MAJOR_VERSION >= 3
- extern PyObject *MessageMapContainer_Type;
- extern PyType_Spec MessageMapContainer_Type_spec;
- extern PyObject *ScalarMapContainer_Type;
- extern PyType_Spec ScalarMapContainer_Type_spec;
-#else
- extern PyTypeObject MessageMapContainer_Type;
- extern PyTypeObject ScalarMapContainer_Type;
-#endif
+bool InitMapContainers();
+extern PyTypeObject* MessageMapContainer_Type;
+extern PyTypeObject* ScalarMapContainer_Type;
extern PyTypeObject MapIterator_Type; // Both map types use the same iterator.
// Builds a MapContainer object, from a parent message and a
@@ -132,7 +116,7 @@ extern PyObject* NewScalarMapContainer(
// field descriptor.
extern PyObject* NewMessageMapContainer(
CMessage* parent, const FieldDescriptor* parent_field_descriptor,
- PyObject* concrete_class);
+ CMessageClass* message_class);
} // namespace python
} // namespace protobuf
diff --git a/python/google/protobuf/pyext/message.cc b/python/google/protobuf/pyext/message.cc
index 60ec9c1b..53736b9c 100644
--- a/python/google/protobuf/pyext/message.cc
+++ b/python/google/protobuf/pyext/message.cc
@@ -35,9 +35,6 @@
#include <map>
#include <memory>
-#ifndef _SHARED_PTR_H
-#include <google/protobuf/stubs/shared_ptr.h>
-#endif
#include <string>
#include <vector>
#include <structmember.h> // A Python header file.
@@ -52,6 +49,7 @@
#include <google/protobuf/stubs/common.h>
#include <google/protobuf/stubs/logging.h>
#include <google/protobuf/io/coded_stream.h>
+#include <google/protobuf/io/zero_copy_stream_impl_lite.h>
#include <google/protobuf/util/message_differencer.h>
#include <google/protobuf/descriptor.h>
#include <google/protobuf/message.h>
@@ -63,11 +61,11 @@
#include <google/protobuf/pyext/repeated_composite_container.h>
#include <google/protobuf/pyext/repeated_scalar_container.h>
#include <google/protobuf/pyext/map_container.h>
+#include <google/protobuf/pyext/message_factory.h>
+#include <google/protobuf/pyext/safe_numerics.h>
#include <google/protobuf/pyext/scoped_pyobject_ptr.h>
-#include <google/protobuf/stubs/strutil.h>
#if PY_MAJOR_VERSION >= 3
- #define PyInt_Check PyLong_Check
#define PyInt_AsLong PyLong_AsLong
#define PyInt_FromLong PyLong_FromLong
#define PyInt_FromSize_t PyLong_FromSize_t
@@ -91,42 +89,26 @@ namespace protobuf {
namespace python {
static PyObject* kDESCRIPTOR;
-static PyObject* k_extensions_by_name;
-static PyObject* k_extensions_by_number;
PyObject* EnumTypeWrapper_class;
static PyObject* PythonMessage_class;
static PyObject* kEmptyWeakref;
static PyObject* WKT_classes = NULL;
-// Defines the Metaclass of all Message classes.
-// It allows us to cache some C++ pointers in the class object itself, they are
-// faster to extract than from the type's dictionary.
-
-struct PyMessageMeta {
- // This is how CPython subclasses C structures: the base structure must be
- // the first member of the object.
- PyHeapTypeObject super;
-
- // C++ descriptor of this message.
- const Descriptor* message_descriptor;
-
- // Owned reference, used to keep the pointer above alive.
- PyObject* py_message_descriptor;
-
- // The Python DescriptorPool used to create the class. It is needed to resolve
- // fields descriptors, including extensions fields; its C++ MessageFactory is
- // used to instantiate submessages.
- // This can be different from DESCRIPTOR.file.pool, in the case of a custom
- // DescriptorPool which defines new extensions.
- // We own the reference, because it's important to keep the descriptors and
- // factory alive.
- PyDescriptorPool* py_descriptor_pool;
-};
-
namespace message_meta {
static int InsertEmptyWeakref(PyTypeObject* base);
+namespace {
+// Copied oveer from internal 'google/protobuf/stubs/strutil.h'.
+inline void UpperString(string * s) {
+ string::iterator end = s->end();
+ for (string::iterator i = s->begin(); i != end; ++i) {
+ // toupper() changes based on locale. We don't want this!
+ if ('a' <= *i && *i <= 'z') *i += 'A' - 'a';
+ }
+}
+}
+
// Add the number of a field descriptor to the containing message class.
// Equivalent to:
// _cls.<field>_FIELD_NUMBER = <number>
@@ -152,19 +134,6 @@ static bool AddFieldNumberToClass(
// Finalize the creation of the Message class.
static int AddDescriptors(PyObject* cls, const Descriptor* descriptor) {
- // If there are extension_ranges, the message is "extendable", and extension
- // classes will register themselves in this class.
- if (descriptor->extension_range_count() > 0) {
- ScopedPyObjectPtr by_name(PyDict_New());
- if (PyObject_SetAttr(cls, k_extensions_by_name, by_name.get()) < 0) {
- return -1;
- }
- ScopedPyObjectPtr by_number(PyDict_New());
- if (PyObject_SetAttr(cls, k_extensions_by_number, by_number.get()) < 0) {
- return -1;
- }
- }
-
// For each field set: cls.<field>_FIELD_NUMBER = <number>
for (int i = 0; i < descriptor->field_count(); ++i) {
if (!AddFieldNumberToClass(cls, descriptor->field(i))) {
@@ -173,10 +142,6 @@ static int AddDescriptors(PyObject* cls, const Descriptor* descriptor) {
}
// For each enum set cls.<enum name> = EnumTypeWrapper(<enum descriptor>).
- //
- // The enum descriptor we get from
- // <messagedescriptor>.enum_types_by_name[name]
- // which was built previously.
for (int i = 0; i < descriptor->enum_type_count(); ++i) {
const EnumDescriptor* enum_descriptor = descriptor->enum_type(i);
ScopedPyObjectPtr enum_type(
@@ -273,6 +238,12 @@ static PyObject* New(PyTypeObject* type,
return NULL;
}
+ // Messages have no __dict__
+ ScopedPyObjectPtr slots(PyTuple_New(0));
+ if (PyDict_SetItemString(dict, "__slots__", slots.get()) < 0) {
+ return NULL;
+ }
+
// Build the arguments to the base metaclass.
// We change the __bases__ classes.
ScopedPyObjectPtr new_args;
@@ -309,7 +280,7 @@ static PyObject* New(PyTypeObject* type,
if (result == NULL) {
return NULL;
}
- PyMessageMeta* newtype = reinterpret_cast<PyMessageMeta*>(result.get());
+ CMessageClass* newtype = reinterpret_cast<CMessageClass*>(result.get());
// Insert the empty weakref into the base classes.
if (InsertEmptyWeakref(
@@ -329,16 +300,19 @@ static PyObject* New(PyTypeObject* type,
newtype->message_descriptor = descriptor;
// TODO(amauryfa): Don't always use the canonical pool of the descriptor,
// use the MessageFactory optionally passed in the class dict.
- newtype->py_descriptor_pool = GetDescriptorPool_FromPool(
- descriptor->file()->pool());
- if (newtype->py_descriptor_pool == NULL) {
+ PyDescriptorPool* py_descriptor_pool =
+ GetDescriptorPool_FromPool(descriptor->file()->pool());
+ if (py_descriptor_pool == NULL) {
return NULL;
}
- Py_INCREF(newtype->py_descriptor_pool);
+ newtype->py_message_factory = py_descriptor_pool->py_message_factory;
+ Py_INCREF(newtype->py_message_factory);
- // Add the message to the DescriptorPool.
- if (cdescriptor_pool::RegisterMessageClass(newtype->py_descriptor_pool,
- descriptor, result.get()) < 0) {
+ // Register the message in the MessageFactory.
+ // TODO(amauryfa): Move this call to MessageFactory.GetPrototype() when the
+ // MessageFactory is fully implemented in C++.
+ if (message_factory::RegisterMessageClass(newtype->py_message_factory,
+ descriptor, newtype) < 0) {
return NULL;
}
@@ -349,9 +323,9 @@ static PyObject* New(PyTypeObject* type,
return result.release();
}
-static void Dealloc(PyMessageMeta *self) {
- Py_DECREF(self->py_message_descriptor);
- Py_DECREF(self->py_descriptor_pool);
+static void Dealloc(CMessageClass *self) {
+ Py_XDECREF(self->py_message_descriptor);
+ Py_XDECREF(self->py_message_factory);
Py_TYPE(self)->tp_free(reinterpret_cast<PyObject*>(self));
}
@@ -376,12 +350,67 @@ static int InsertEmptyWeakref(PyTypeObject *base_type) {
#endif // PY_MAJOR_VERSION >= 3
}
+// The _extensions_by_name dictionary is built on every access.
+// TODO(amauryfa): Migrate all users to pool.FindAllExtensions()
+static PyObject* GetExtensionsByName(CMessageClass *self, void *closure) {
+ const PyDescriptorPool* pool = self->py_message_factory->pool;
+
+ std::vector<const FieldDescriptor*> extensions;
+ pool->pool->FindAllExtensions(self->message_descriptor, &extensions);
+
+ ScopedPyObjectPtr result(PyDict_New());
+ for (int i = 0; i < extensions.size(); i++) {
+ ScopedPyObjectPtr extension(
+ PyFieldDescriptor_FromDescriptor(extensions[i]));
+ if (extension == NULL) {
+ return NULL;
+ }
+ if (PyDict_SetItemString(result.get(), extensions[i]->full_name().c_str(),
+ extension.get()) < 0) {
+ return NULL;
+ }
+ }
+ return result.release();
+}
+
+// The _extensions_by_number dictionary is built on every access.
+// TODO(amauryfa): Migrate all users to pool.FindExtensionByNumber()
+static PyObject* GetExtensionsByNumber(CMessageClass *self, void *closure) {
+ const PyDescriptorPool* pool = self->py_message_factory->pool;
+
+ std::vector<const FieldDescriptor*> extensions;
+ pool->pool->FindAllExtensions(self->message_descriptor, &extensions);
+
+ ScopedPyObjectPtr result(PyDict_New());
+ for (int i = 0; i < extensions.size(); i++) {
+ ScopedPyObjectPtr extension(
+ PyFieldDescriptor_FromDescriptor(extensions[i]));
+ if (extension == NULL) {
+ return NULL;
+ }
+ ScopedPyObjectPtr number(PyInt_FromLong(extensions[i]->number()));
+ if (number == NULL) {
+ return NULL;
+ }
+ if (PyDict_SetItem(result.get(), number.get(), extension.get()) < 0) {
+ return NULL;
+ }
+ }
+ return result.release();
+}
+
+static PyGetSetDef Getters[] = {
+ {"_extensions_by_name", (getter)GetExtensionsByName, NULL},
+ {"_extensions_by_number", (getter)GetExtensionsByNumber, NULL},
+ {NULL}
+};
+
} // namespace message_meta
-PyTypeObject PyMessageMeta_Type = {
+PyTypeObject CMessageClass_Type = {
PyVarObject_HEAD_INIT(&PyType_Type, 0)
FULL_MODULE_NAME ".MessageMeta", // tp_name
- sizeof(PyMessageMeta), // tp_basicsize
+ sizeof(CMessageClass), // tp_basicsize
0, // tp_itemsize
(destructor)message_meta::Dealloc, // tp_dealloc
0, // tp_print
@@ -408,7 +437,7 @@ PyTypeObject PyMessageMeta_Type = {
0, // tp_iternext
0, // tp_methods
0, // tp_members
- 0, // tp_getset
+ message_meta::Getters, // tp_getset
0, // tp_base
0, // tp_dict
0, // tp_descr_get
@@ -419,16 +448,16 @@ PyTypeObject PyMessageMeta_Type = {
message_meta::New, // tp_new
};
-static PyMessageMeta* CheckMessageClass(PyTypeObject* cls) {
- if (!PyObject_TypeCheck(cls, &PyMessageMeta_Type)) {
+static CMessageClass* CheckMessageClass(PyTypeObject* cls) {
+ if (!PyObject_TypeCheck(cls, &CMessageClass_Type)) {
PyErr_Format(PyExc_TypeError, "Class %s is not a Message", cls->tp_name);
return NULL;
}
- return reinterpret_cast<PyMessageMeta*>(cls);
+ return reinterpret_cast<CMessageClass*>(cls);
}
static const Descriptor* GetMessageDescriptor(PyTypeObject* cls) {
- PyMessageMeta* type = CheckMessageClass(cls);
+ CMessageClass* type = CheckMessageClass(cls);
if (type == NULL) {
return NULL;
}
@@ -544,23 +573,10 @@ int ForEachCompositeField(CMessage* self, Visitor visitor) {
// ---------------------------------------------------------------------
-// Constants used for integer type range checking.
-PyObject* kPythonZero;
-PyObject* kint32min_py;
-PyObject* kint32max_py;
-PyObject* kuint32max_py;
-PyObject* kint64min_py;
-PyObject* kint64max_py;
-PyObject* kuint64max_py;
-
PyObject* EncodeError_class;
PyObject* DecodeError_class;
PyObject* PickleError_class;
-// Constant PyString values used for GetAttr/GetItem.
-static PyObject* k_cdescriptor;
-static PyObject* kfull_name;
-
/* Is 64bit */
void FormatTypeError(PyObject* arg, char* expected_types) {
PyObject* repr = PyObject_Repr(arg);
@@ -574,68 +590,126 @@ void FormatTypeError(PyObject* arg, char* expected_types) {
}
}
-template<class T>
-bool CheckAndGetInteger(
- PyObject* arg, T* value, PyObject* min, PyObject* max) {
- bool is_long = PyLong_Check(arg);
-#if PY_MAJOR_VERSION < 3
- if (!PyInt_Check(arg) && !is_long) {
- FormatTypeError(arg, "int, long");
- return false;
+void OutOfRangeError(PyObject* arg) {
+ PyObject *s = PyObject_Str(arg);
+ if (s) {
+ PyErr_Format(PyExc_ValueError,
+ "Value out of range: %s",
+ PyString_AsString(s));
+ Py_DECREF(s);
}
- if (PyObject_Compare(min, arg) > 0 || PyObject_Compare(max, arg) < 0) {
-#else
- if (!is_long) {
- FormatTypeError(arg, "int");
+}
+
+template<class RangeType, class ValueType>
+bool VerifyIntegerCastAndRange(PyObject* arg, ValueType value) {
+ if (GOOGLE_PREDICT_FALSE(value == -1 && PyErr_Occurred())) {
+ if (PyErr_ExceptionMatches(PyExc_OverflowError)) {
+ // Replace it with the same ValueError as pure python protos instead of
+ // the default one.
+ PyErr_Clear();
+ OutOfRangeError(arg);
+ } // Otherwise propagate existing error.
return false;
- }
- if (PyObject_RichCompareBool(min, arg, Py_LE) != 1 ||
- PyObject_RichCompareBool(max, arg, Py_GE) != 1) {
-#endif
- if (!PyErr_Occurred()) {
- PyObject *s = PyObject_Str(arg);
- if (s) {
- PyErr_Format(PyExc_ValueError,
- "Value out of range: %s",
- PyString_AsString(s));
- Py_DECREF(s);
- }
}
- return false;
- }
+ if (GOOGLE_PREDICT_FALSE(!IsValidNumericCast<RangeType>(value))) {
+ OutOfRangeError(arg);
+ return false;
+ }
+ return true;
+}
+
+template<class T>
+bool CheckAndGetInteger(PyObject* arg, T* value) {
+ // The fast path.
#if PY_MAJOR_VERSION < 3
- if (!is_long) {
- *value = static_cast<T>(PyInt_AsLong(arg));
- } else // NOLINT
+ // For the typical case, offer a fast path.
+ if (GOOGLE_PREDICT_TRUE(PyInt_Check(arg))) {
+ long int_result = PyInt_AsLong(arg);
+ if (GOOGLE_PREDICT_TRUE(IsValidNumericCast<T>(int_result))) {
+ *value = static_cast<T>(int_result);
+ return true;
+ } else {
+ OutOfRangeError(arg);
+ return false;
+ }
+ }
#endif
- {
- if (min == kPythonZero) {
- *value = static_cast<T>(PyLong_AsUnsignedLongLong(arg));
+ // This effectively defines an integer as "an object that can be cast as
+ // an integer and can be used as an ordinal number".
+ // This definition includes everything that implements numbers.Integral
+ // and shouldn't cast the net too wide.
+ if (GOOGLE_PREDICT_FALSE(!PyIndex_Check(arg))) {
+ FormatTypeError(arg, "int, long");
+ return false;
+ }
+
+ // Now we have an integral number so we can safely use PyLong_ functions.
+ // We need to treat the signed and unsigned cases differently in case arg is
+ // holding a value above the maximum for signed longs.
+ if (std::numeric_limits<T>::min() == 0) {
+ // Unsigned case.
+ unsigned PY_LONG_LONG ulong_result;
+ if (PyLong_Check(arg)) {
+ ulong_result = PyLong_AsUnsignedLongLong(arg);
} else {
- *value = static_cast<T>(PyLong_AsLongLong(arg));
+ // Unlike PyLong_AsLongLong, PyLong_AsUnsignedLongLong is very
+ // picky about the exact type.
+ PyObject* casted = PyNumber_Long(arg);
+ if (GOOGLE_PREDICT_FALSE(casted == nullptr)) {
+ // Propagate existing error.
+ return false;
+ }
+ ulong_result = PyLong_AsUnsignedLongLong(casted);
+ Py_DECREF(casted);
+ }
+ if (VerifyIntegerCastAndRange<T, unsigned PY_LONG_LONG>(arg,
+ ulong_result)) {
+ *value = static_cast<T>(ulong_result);
+ } else {
+ return false;
+ }
+ } else {
+ // Signed case.
+ PY_LONG_LONG long_result;
+ PyNumberMethods *nb;
+ if ((nb = arg->ob_type->tp_as_number) != NULL && nb->nb_int != NULL) {
+ // PyLong_AsLongLong requires it to be a long or to have an __int__()
+ // method.
+ long_result = PyLong_AsLongLong(arg);
+ } else {
+ // Valid subclasses of numbers.Integral should have a __long__() method
+ // so fall back to that.
+ PyObject* casted = PyNumber_Long(arg);
+ if (GOOGLE_PREDICT_FALSE(casted == nullptr)) {
+ // Propagate existing error.
+ return false;
+ }
+ long_result = PyLong_AsLongLong(casted);
+ Py_DECREF(casted);
+ }
+ if (VerifyIntegerCastAndRange<T, PY_LONG_LONG>(arg, long_result)) {
+ *value = static_cast<T>(long_result);
+ } else {
+ return false;
}
}
+
return true;
}
// These are referenced by repeated_scalar_container, and must
// be explicitly instantiated.
-template bool CheckAndGetInteger<int32>(
- PyObject*, int32*, PyObject*, PyObject*);
-template bool CheckAndGetInteger<int64>(
- PyObject*, int64*, PyObject*, PyObject*);
-template bool CheckAndGetInteger<uint32>(
- PyObject*, uint32*, PyObject*, PyObject*);
-template bool CheckAndGetInteger<uint64>(
- PyObject*, uint64*, PyObject*, PyObject*);
+template bool CheckAndGetInteger<int32>(PyObject*, int32*);
+template bool CheckAndGetInteger<int64>(PyObject*, int64*);
+template bool CheckAndGetInteger<uint32>(PyObject*, uint32*);
+template bool CheckAndGetInteger<uint64>(PyObject*, uint64*);
bool CheckAndGetDouble(PyObject* arg, double* value) {
- if (!PyInt_Check(arg) && !PyLong_Check(arg) &&
- !PyFloat_Check(arg)) {
+ *value = PyFloat_AsDouble(arg);
+ if (GOOGLE_PREDICT_FALSE(*value == -1 && PyErr_Occurred())) {
FormatTypeError(arg, "int, long, float");
return false;
- }
- *value = PyFloat_AsDouble(arg);
+ }
return true;
}
@@ -649,11 +723,13 @@ bool CheckAndGetFloat(PyObject* arg, float* value) {
}
bool CheckAndGetBool(PyObject* arg, bool* value) {
- if (!PyInt_Check(arg) && !PyBool_Check(arg) && !PyLong_Check(arg)) {
+ long long_value = PyInt_AsLong(arg);
+ if (long_value == -1 && PyErr_Occurred()) {
FormatTypeError(arg, "int, long, bool");
return false;
}
- *value = static_cast<bool>(PyInt_AsLong(arg));
+ *value = static_cast<bool>(long_value);
+
return true;
}
@@ -711,7 +787,7 @@ PyObject* CheckString(PyObject* arg, const FieldDescriptor* descriptor) {
encoded_string = arg; // Already encoded.
Py_INCREF(encoded_string);
} else {
- encoded_string = PyUnicode_AsEncodedObject(arg, "utf-8", NULL);
+ encoded_string = PyUnicode_AsEncodedString(arg, "utf-8", NULL);
}
} else {
// In this case field type is "bytes".
@@ -751,7 +827,8 @@ bool CheckAndSetString(
return true;
}
-PyObject* ToStringObject(const FieldDescriptor* descriptor, string value) {
+PyObject* ToStringObject(const FieldDescriptor* descriptor,
+ const string& value) {
if (descriptor->type() != FieldDescriptor::TYPE_STRING) {
return PyBytes_FromStringAndSize(value.c_str(), value.length());
}
@@ -781,15 +858,9 @@ bool CheckFieldBelongsToMessage(const FieldDescriptor* field_descriptor,
namespace cmessage {
-PyDescriptorPool* GetDescriptorPoolForMessage(CMessage* message) {
- // No need to check the type: the type of instances of CMessage is always
- // an instance of PyMessageMeta. Let's prove it with a debug-only check.
+PyMessageFactory* GetFactoryForMessage(CMessage* message) {
GOOGLE_DCHECK(PyObject_TypeCheck(message, &CMessage_Type));
- return reinterpret_cast<PyMessageMeta*>(Py_TYPE(message))->py_descriptor_pool;
-}
-
-MessageFactory* GetFactoryForMessage(CMessage* message) {
- return GetDescriptorPoolForMessage(message)->message_factory;
+ return reinterpret_cast<CMessageClass*>(Py_TYPE(message))->py_message_factory;
}
static int MaybeReleaseOverlappingOneofField(
@@ -842,7 +913,8 @@ static Message* GetMutableMessage(
return NULL;
}
return reflection->MutableMessage(
- parent_message, parent_field, GetFactoryForMessage(parent));
+ parent_message, parent_field,
+ GetFactoryForMessage(parent)->message_factory);
}
struct FixupMessageReference : public ChildVisitor {
@@ -990,28 +1062,17 @@ int InternalDeleteRepeatedField(
int min, max;
length = reflection->FieldSize(*message, field_descriptor);
- if (PyInt_Check(slice) || PyLong_Check(slice)) {
- from = to = PyLong_AsLong(slice);
- if (from < 0) {
- from = to = length + from;
- }
- step = 1;
- min = max = from;
-
- // Range check.
- if (from < 0 || from >= length) {
- PyErr_Format(PyExc_IndexError, "list assignment index out of range");
- return -1;
- }
- } else if (PySlice_Check(slice)) {
+ if (PySlice_Check(slice)) {
from = to = step = slice_length = 0;
- PySlice_GetIndicesEx(
#if PY_MAJOR_VERSION < 3
+ PySlice_GetIndicesEx(
reinterpret_cast<PySliceObject*>(slice),
+ length, &from, &to, &step, &slice_length);
#else
+ PySlice_GetIndicesEx(
slice,
-#endif
length, &from, &to, &step, &slice_length);
+#endif
if (from < to) {
min = from;
max = to - 1;
@@ -1020,8 +1081,23 @@ int InternalDeleteRepeatedField(
max = from;
}
} else {
- PyErr_SetString(PyExc_TypeError, "list indices must be integers");
- return -1;
+ from = to = PyLong_AsLong(slice);
+ if (from == -1 && PyErr_Occurred()) {
+ PyErr_SetString(PyExc_TypeError, "list indices must be integers");
+ return -1;
+ }
+
+ if (from < 0) {
+ from = to = length + from;
+ }
+ step = 1;
+ min = max = from;
+
+ // Range check.
+ if (from < 0 || from >= length) {
+ PyErr_Format(PyExc_IndexError, "list assignment index out of range");
+ return -1;
+ }
}
Py_ssize_t i = from;
@@ -1070,7 +1146,12 @@ int InternalDeleteRepeatedField(
}
// Initializes fields of a message. Used in constructors.
-int InitAttributes(CMessage* self, PyObject* kwargs) {
+int InitAttributes(CMessage* self, PyObject* args, PyObject* kwargs) {
+ if (args != NULL && PyTuple_Size(args) != 0) {
+ PyErr_SetString(PyExc_TypeError, "No positional arguments allowed");
+ return -1;
+ }
+
if (kwargs == NULL) {
return 0;
}
@@ -1090,8 +1171,12 @@ int InitAttributes(CMessage* self, PyObject* kwargs) {
PyString_AsString(name));
return -1;
}
+ if (value == Py_None) {
+ // field=None is the same as no field at all.
+ continue;
+ }
if (descriptor->is_map()) {
- ScopedPyObjectPtr map(GetAttr(self, name));
+ ScopedPyObjectPtr map(GetAttr(reinterpret_cast<PyObject*>(self), name));
const FieldDescriptor* value_descriptor =
descriptor->message_type()->FindFieldByName("value");
if (value_descriptor->cpp_type() == FieldDescriptor::CPPTYPE_MESSAGE) {
@@ -1119,7 +1204,8 @@ int InitAttributes(CMessage* self, PyObject* kwargs) {
}
}
} else if (descriptor->label() == FieldDescriptor::LABEL_REPEATED) {
- ScopedPyObjectPtr container(GetAttr(self, name));
+ ScopedPyObjectPtr container(
+ GetAttr(reinterpret_cast<PyObject*>(self), name));
if (container == NULL) {
return -1;
}
@@ -1186,13 +1272,16 @@ int InitAttributes(CMessage* self, PyObject* kwargs) {
}
}
} else if (descriptor->cpp_type() == FieldDescriptor::CPPTYPE_MESSAGE) {
- ScopedPyObjectPtr message(GetAttr(self, name));
+ ScopedPyObjectPtr message(
+ GetAttr(reinterpret_cast<PyObject*>(self), name));
if (message == NULL) {
return -1;
}
CMessage* cmessage = reinterpret_cast<CMessage*>(message.get());
if (PyDict_Check(value)) {
- if (InitAttributes(cmessage, value) < 0) {
+ // Make the message exist even if the dict is empty.
+ AssureWritable(cmessage);
+ if (InitAttributes(cmessage, NULL, value) < 0) {
return -1;
}
} else {
@@ -1209,8 +1298,8 @@ int InitAttributes(CMessage* self, PyObject* kwargs) {
return -1;
}
}
- if (SetAttr(self, name, (new_val.get() == NULL) ? value : new_val.get()) <
- 0) {
+ if (SetAttr(reinterpret_cast<PyObject*>(self), name,
+ (new_val.get() == NULL) ? value : new_val.get()) < 0) {
return -1;
}
}
@@ -1220,13 +1309,15 @@ int InitAttributes(CMessage* self, PyObject* kwargs) {
// Allocates an incomplete Python Message: the caller must fill self->message,
// self->owner and eventually self->parent.
-CMessage* NewEmptyMessage(PyObject* type, const Descriptor *descriptor) {
+CMessage* NewEmptyMessage(CMessageClass* type) {
CMessage* self = reinterpret_cast<CMessage*>(
- PyType_GenericAlloc(reinterpret_cast<PyTypeObject*>(type), 0));
+ PyType_GenericAlloc(&type->super.ht_type, 0));
if (self == NULL) {
return NULL;
}
+ // Use "placement new" syntax to initialize the C++ object.
+ new (&self->owner) CMessage::OwnerRef(NULL);
self->message = NULL;
self->parent = NULL;
self->parent_field_descriptor = NULL;
@@ -1242,7 +1333,7 @@ CMessage* NewEmptyMessage(PyObject* type, const Descriptor *descriptor) {
// Creates a new C++ message and takes ownership.
static PyObject* New(PyTypeObject* cls,
PyObject* unused_args, PyObject* unused_kwargs) {
- PyMessageMeta* type = CheckMessageClass(cls);
+ CMessageClass* type = CheckMessageClass(cls);
if (type == NULL) {
return NULL;
}
@@ -1251,15 +1342,14 @@ static PyObject* New(PyTypeObject* cls,
if (message_descriptor == NULL) {
return NULL;
}
- const Message* default_message = type->py_descriptor_pool->message_factory
+ const Message* default_message = type->py_message_factory->message_factory
->GetPrototype(message_descriptor);
if (default_message == NULL) {
PyErr_SetString(PyExc_TypeError, message_descriptor->full_name().c_str());
return NULL;
}
- CMessage* self = NewEmptyMessage(reinterpret_cast<PyObject*>(type),
- message_descriptor);
+ CMessage* self = NewEmptyMessage(type);
if (self == NULL) {
return NULL;
}
@@ -1271,12 +1361,7 @@ static PyObject* New(PyTypeObject* cls,
// The __init__ method of Message classes.
// It initializes fields from keywords passed to the constructor.
static int Init(CMessage* self, PyObject* args, PyObject* kwargs) {
- if (PyTuple_Size(args) != 0) {
- PyErr_SetString(PyExc_TypeError, "No positional arguments allowed");
- return -1;
- }
-
- return InitAttributes(self, kwargs);
+ return InitAttributes(self, args, kwargs);
}
// ---------------------------------------------------------------------
@@ -1318,6 +1403,9 @@ struct ClearWeakReferences : public ChildVisitor {
};
static void Dealloc(CMessage* self) {
+ if (self->weakreflist) {
+ PyObject_ClearWeakRefs(reinterpret_cast<PyObject*>(self));
+ }
// Null out all weak references from children to this message.
GOOGLE_CHECK_EQ(0, ForEachCompositeField(self, ClearWeakReferences()));
if (self->extensions) {
@@ -1326,7 +1414,7 @@ static void Dealloc(CMessage* self) {
Py_CLEAR(self->extensions);
Py_CLEAR(self->composite_fields);
- self->owner.reset();
+ self->owner.~ThreadUnsafeSharedPtr<Message>();
Py_TYPE(self)->tp_free(reinterpret_cast<PyObject*>(self));
}
@@ -1467,36 +1555,25 @@ PyObject* HasField(CMessage* self, PyObject* arg) {
if (message->GetReflection()->HasField(*message, field_descriptor)) {
Py_RETURN_TRUE;
}
- if (!message->GetReflection()->SupportsUnknownEnumValues() &&
- field_descriptor->cpp_type() == FieldDescriptor::CPPTYPE_ENUM) {
- // Special case: Python HasField() differs in semantics from C++
- // slightly: we return HasField('enum_field') == true if there is
- // an unknown enum value present. To implement this we have to
- // look in the UnknownFieldSet.
- const UnknownFieldSet& unknown_field_set =
- message->GetReflection()->GetUnknownFields(*message);
- for (int i = 0; i < unknown_field_set.field_count(); ++i) {
- if (unknown_field_set.field(i).number() == field_descriptor->number()) {
- Py_RETURN_TRUE;
- }
- }
- }
+
Py_RETURN_FALSE;
}
PyObject* ClearExtension(CMessage* self, PyObject* extension) {
+ const FieldDescriptor* descriptor = GetExtensionDescriptor(extension);
+ if (descriptor == NULL) {
+ return NULL;
+ }
if (self->extensions != NULL) {
- return extension_dict::ClearExtension(self->extensions, extension);
- } else {
- const FieldDescriptor* descriptor = GetExtensionDescriptor(extension);
- if (descriptor == NULL) {
- return NULL;
- }
- if (ScopedPyObjectPtr(ClearFieldByDescriptor(self, descriptor)) == NULL) {
- return NULL;
+ PyObject* value = PyDict_GetItem(self->extensions->values, extension);
+ if (value != NULL) {
+ if (InternalReleaseFieldByDescriptor(self, descriptor, value) < 0) {
+ return NULL;
+ }
+ PyDict_DelItem(self->extensions->values, extension);
}
}
- Py_RETURN_NONE;
+ return ClearFieldByDescriptor(self, descriptor);
}
PyObject* HasExtension(CMessage* self, PyObject* extension) {
@@ -1539,9 +1616,10 @@ PyObject* HasExtension(CMessage* self, PyObject* extension) {
// * Clear the weak references from the released container to the
// parent.
-struct SetOwnerVisitor : public ChildVisitor {
+class SetOwnerVisitor : public ChildVisitor {
+ public:
// new_owner must outlive this object.
- explicit SetOwnerVisitor(const shared_ptr<Message>& new_owner)
+ explicit SetOwnerVisitor(const CMessage::OwnerRef& new_owner)
: new_owner_(new_owner) {}
int VisitRepeatedCompositeContainer(RepeatedCompositeContainer* container) {
@@ -1565,11 +1643,11 @@ struct SetOwnerVisitor : public ChildVisitor {
}
private:
- const shared_ptr<Message>& new_owner_;
+ const CMessage::OwnerRef& new_owner_;
};
// Change the owner of this CMessage and all its children, recursively.
-int SetOwner(CMessage* self, const shared_ptr<Message>& new_owner) {
+int SetOwner(CMessage* self, const CMessage::OwnerRef& new_owner) {
self->owner = new_owner;
if (ForEachCompositeField(self, SetOwnerVisitor(new_owner)) == -1)
return -1;
@@ -1582,7 +1660,7 @@ int SetOwner(CMessage* self, const shared_ptr<Message>& new_owner) {
Message* ReleaseMessage(CMessage* self,
const Descriptor* descriptor,
const FieldDescriptor* field_descriptor) {
- MessageFactory* message_factory = GetFactoryForMessage(self);
+ MessageFactory* message_factory = GetFactoryForMessage(self)->message_factory;
Message* released_message = self->message->GetReflection()->ReleaseMessage(
self->message, field_descriptor, message_factory);
// ReleaseMessage will return NULL which differs from
@@ -1602,7 +1680,7 @@ int ReleaseSubMessage(CMessage* self,
const FieldDescriptor* field_descriptor,
CMessage* child_cmessage) {
// Release the Message
- shared_ptr<Message> released_message(ReleaseMessage(
+ CMessage::OwnerRef released_message(ReleaseMessage(
self, child_cmessage->message->GetDescriptor(), field_descriptor));
child_cmessage->message = released_message.get();
child_cmessage->owner.swap(released_message);
@@ -1619,23 +1697,20 @@ struct ReleaseChild : public ChildVisitor {
parent_(parent) {}
int VisitRepeatedCompositeContainer(RepeatedCompositeContainer* container) {
- return repeated_composite_container::Release(
- reinterpret_cast<RepeatedCompositeContainer*>(container));
+ return repeated_composite_container::Release(container);
}
int VisitRepeatedScalarContainer(RepeatedScalarContainer* container) {
- return repeated_scalar_container::Release(
- reinterpret_cast<RepeatedScalarContainer*>(container));
+ return repeated_scalar_container::Release(container);
}
int VisitMapContainer(MapContainer* container) {
- return reinterpret_cast<MapContainer*>(container)->Release();
+ return container->Release();
}
int VisitCMessage(CMessage* cmessage,
const FieldDescriptor* field_descriptor) {
- return ReleaseSubMessage(parent_, field_descriptor,
- reinterpret_cast<CMessage*>(cmessage));
+ return ReleaseSubMessage(parent_, field_descriptor, cmessage);
}
CMessage* parent_;
@@ -1653,12 +1728,13 @@ int InternalReleaseFieldByDescriptor(
PyObject* ClearFieldByDescriptor(
CMessage* self,
- const FieldDescriptor* descriptor) {
- if (!CheckFieldBelongsToMessage(descriptor, self->message)) {
+ const FieldDescriptor* field_descriptor) {
+ if (!CheckFieldBelongsToMessage(field_descriptor, self->message)) {
return NULL;
}
AssureWritable(self);
- self->message->GetReflection()->ClearField(self->message, descriptor);
+ Message* message = self->message;
+ message->GetReflection()->ClearField(message, field_descriptor);
Py_RETURN_NONE;
}
@@ -1694,27 +1770,17 @@ PyObject* ClearField(CMessage* self, PyObject* arg) {
arg = arg_in_oneof.get();
}
- PyObject* composite_field = self->composite_fields ?
- PyDict_GetItem(self->composite_fields, arg) : NULL;
-
- // Only release the field if there's a possibility that there are
- // references to it.
- if (composite_field != NULL) {
- if (InternalReleaseFieldByDescriptor(self, field_descriptor,
- composite_field) < 0) {
- return NULL;
+ // Release the field if it exists in the dict of composite fields.
+ if (self->composite_fields) {
+ PyObject* value = PyDict_GetItem(self->composite_fields, arg);
+ if (value != NULL) {
+ if (InternalReleaseFieldByDescriptor(self, field_descriptor, value) < 0) {
+ return NULL;
+ }
+ PyDict_DelItem(self->composite_fields, arg);
}
- PyDict_DelItem(self->composite_fields, arg);
- }
- message->GetReflection()->ClearField(message, field_descriptor);
- if (field_descriptor->cpp_type() == FieldDescriptor::CPPTYPE_ENUM &&
- !message->GetReflection()->SupportsUnknownEnumValues()) {
- UnknownFieldSet* unknown_field_set =
- message->GetReflection()->MutableUnknownFields(message);
- unknown_field_set->DeleteByNumber(field_descriptor->number());
}
-
- Py_RETURN_NONE;
+ return ClearFieldByDescriptor(self, field_descriptor);
}
PyObject* Clear(CMessage* self) {
@@ -1739,8 +1805,25 @@ static string GetMessageName(CMessage* self) {
}
}
-static PyObject* SerializeToString(CMessage* self, PyObject* args) {
- if (!self->message->IsInitialized()) {
+static PyObject* InternalSerializeToString(
+ CMessage* self, PyObject* args, PyObject* kwargs,
+ bool require_initialized) {
+ // Parse the "deterministic" kwarg; defaults to False.
+ static char* kwlist[] = { "deterministic", 0 };
+ PyObject* deterministic_obj = Py_None;
+ if (!PyArg_ParseTupleAndKeywords(args, kwargs, "|O", kwlist,
+ &deterministic_obj)) {
+ return NULL;
+ }
+ // Preemptively convert to a bool first, so we don't need to back out of
+ // allocating memory if this raises an exception.
+ // NOTE: This is unused later if deterministic == Py_None, but that's fine.
+ int deterministic = PyObject_IsTrue(deterministic_obj);
+ if (deterministic < 0) {
+ return NULL;
+ }
+
+ if (require_initialized && !self->message->IsInitialized()) {
ScopedPyObjectPtr errors(FindInitializationErrors(self));
if (errors == NULL) {
return NULL;
@@ -1778,24 +1861,36 @@ static PyObject* SerializeToString(CMessage* self, PyObject* args) {
GetMessageName(self).c_str(), PyString_AsString(joined.get()));
return NULL;
}
- int size = self->message->ByteSize();
- if (size <= 0) {
+
+ // Ok, arguments parsed and errors checked, now encode to a string
+ const size_t size = self->message->ByteSizeLong();
+ if (size == 0) {
return PyBytes_FromString("");
}
PyObject* result = PyBytes_FromStringAndSize(NULL, size);
if (result == NULL) {
return NULL;
}
- char* buffer = PyBytes_AS_STRING(result);
- self->message->SerializeWithCachedSizesToArray(
- reinterpret_cast<uint8*>(buffer));
+ io::ArrayOutputStream out(PyBytes_AS_STRING(result), size);
+ io::CodedOutputStream coded_out(&out);
+ if (deterministic_obj != Py_None) {
+ coded_out.SetSerializationDeterministic(deterministic);
+ }
+ self->message->SerializeWithCachedSizes(&coded_out);
+ GOOGLE_CHECK(!coded_out.HadError());
return result;
}
-static PyObject* SerializePartialToString(CMessage* self) {
- string contents;
- self->message->SerializePartialToString(&contents);
- return PyBytes_FromStringAndSize(contents.c_str(), contents.size());
+static PyObject* SerializeToString(
+ CMessage* self, PyObject* args, PyObject* kwargs) {
+ return InternalSerializeToString(self, args, kwargs,
+ /*require_initialized=*/true);
+}
+
+static PyObject* SerializePartialToString(
+ CMessage* self, PyObject* args, PyObject* kwargs) {
+ return InternalSerializeToString(self, args, kwargs,
+ /*require_initialized=*/false);
}
// Formats proto fields for ascii dumps using python formatting functions where
@@ -1851,8 +1946,12 @@ static PyObject* ToStr(CMessage* self) {
PyObject* MergeFrom(CMessage* self, PyObject* arg) {
CMessage* other_message;
- if (!PyObject_TypeCheck(reinterpret_cast<PyObject *>(arg), &CMessage_Type)) {
- PyErr_SetString(PyExc_TypeError, "Must be a message");
+ if (!PyObject_TypeCheck(arg, &CMessage_Type)) {
+ PyErr_Format(PyExc_TypeError,
+ "Parameter to MergeFrom() must be instance of same class: "
+ "expected %s got %s.",
+ self->message->GetDescriptor()->full_name().c_str(),
+ Py_TYPE(arg)->tp_name);
return NULL;
}
@@ -1860,8 +1959,8 @@ PyObject* MergeFrom(CMessage* self, PyObject* arg) {
if (other_message->message->GetDescriptor() !=
self->message->GetDescriptor()) {
PyErr_Format(PyExc_TypeError,
- "Tried to merge from a message with a different type. "
- "to: %s, from: %s",
+ "Parameter to MergeFrom() must be instance of same class: "
+ "expected %s got %s.",
self->message->GetDescriptor()->full_name().c_str(),
other_message->message->GetDescriptor()->full_name().c_str());
return NULL;
@@ -1879,8 +1978,12 @@ PyObject* MergeFrom(CMessage* self, PyObject* arg) {
static PyObject* CopyFrom(CMessage* self, PyObject* arg) {
CMessage* other_message;
- if (!PyObject_TypeCheck(reinterpret_cast<PyObject *>(arg), &CMessage_Type)) {
- PyErr_SetString(PyExc_TypeError, "Must be a message");
+ if (!PyObject_TypeCheck(arg, &CMessage_Type)) {
+ PyErr_Format(PyExc_TypeError,
+ "Parameter to CopyFrom() must be instance of same class: "
+ "expected %s got %s.",
+ self->message->GetDescriptor()->full_name().c_str(),
+ Py_TYPE(arg)->tp_name);
return NULL;
}
@@ -1893,8 +1996,8 @@ static PyObject* CopyFrom(CMessage* self, PyObject* arg) {
if (other_message->message->GetDescriptor() !=
self->message->GetDescriptor()) {
PyErr_Format(PyExc_TypeError,
- "Tried to copy from a message with a different type. "
- "to: %s, from: %s",
+ "Parameter to CopyFrom() must be instance of same class: "
+ "expected %s got %s.",
self->message->GetDescriptor()->full_name().c_str(),
other_message->message->GetDescriptor()->full_name().c_str());
return NULL;
@@ -1911,6 +2014,34 @@ static PyObject* CopyFrom(CMessage* self, PyObject* arg) {
Py_RETURN_NONE;
}
+// Protobuf has a 64MB limit built in, this variable will override this. Please
+// do not enable this unless you fully understand the implications: protobufs
+// must all be kept in memory at the same time, so if they grow too big you may
+// get OOM errors. The protobuf APIs do not provide any tools for processing
+// protobufs in chunks. If you have protos this big you should break them up if
+// it is at all convenient to do so.
+#ifdef PROTOBUF_PYTHON_ALLOW_OVERSIZE_PROTOS
+static bool allow_oversize_protos = true;
+#else
+static bool allow_oversize_protos = false;
+#endif
+
+// Provide a method in the module to set allow_oversize_protos to a boolean
+// value. This method returns the newly value of allow_oversize_protos.
+PyObject* SetAllowOversizeProtos(PyObject* m, PyObject* arg) {
+ if (!arg || !PyBool_Check(arg)) {
+ PyErr_SetString(PyExc_TypeError,
+ "Argument to SetAllowOversizeProtos must be boolean");
+ return NULL;
+ }
+ allow_oversize_protos = PyObject_IsTrue(arg);
+ if (allow_oversize_protos) {
+ Py_RETURN_TRUE;
+ } else {
+ Py_RETURN_FALSE;
+ }
+}
+
static PyObject* MergeFromString(CMessage* self, PyObject* arg) {
const void* data;
Py_ssize_t data_length;
@@ -1921,19 +2052,18 @@ static PyObject* MergeFromString(CMessage* self, PyObject* arg) {
AssureWritable(self);
io::CodedInputStream input(
reinterpret_cast<const uint8*>(data), data_length);
-#if PROTOBUF_PYTHON_ALLOW_OVERSIZE_PROTOS
- // Protobuf has a 64MB limit built in, this code will override this. Please do
- // not enable this unless you fully understand the implications: protobufs
- // must all be kept in memory at the same time, so if they grow too big you
- // may get OOM errors. The protobuf APIs do not provide any tools for
- // processing protobufs in chunks. If you have protos this big you should
- // break them up if it is at all convenient to do so.
- input.SetTotalBytesLimit(INT_MAX, INT_MAX);
-#endif // PROTOBUF_PYTHON_ALLOW_OVERSIZE_PROTOS
- PyDescriptorPool* pool = GetDescriptorPoolForMessage(self);
- input.SetExtensionRegistry(pool->pool, pool->message_factory);
+ if (allow_oversize_protos) {
+ input.SetTotalBytesLimit(INT_MAX, INT_MAX);
+ }
+ PyMessageFactory* factory = GetFactoryForMessage(self);
+ input.SetExtensionRegistry(factory->pool->pool, factory->message_factory);
bool success = self->message->MergePartialFromCodedStream(&input);
if (success) {
+ if (!input.ConsumedEntireMessage()) {
+ // TODO(jieluo): Raise error and return NULL instead.
+ // b/27494216
+ PyErr_Warn(NULL, "Unexpected end-group tag: Not all data was converted");
+ }
return PyInt_FromLong(input.CurrentPosition());
} else {
PyErr_Format(DecodeError_class, "Error parsing message");
@@ -1952,75 +2082,29 @@ static PyObject* ByteSize(CMessage* self, PyObject* args) {
return PyLong_FromLong(self->message->ByteSize());
}
-static PyObject* RegisterExtension(PyObject* cls,
- PyObject* extension_handle) {
+PyObject* RegisterExtension(PyObject* cls, PyObject* extension_handle) {
const FieldDescriptor* descriptor =
GetExtensionDescriptor(extension_handle);
if (descriptor == NULL) {
return NULL;
}
-
- ScopedPyObjectPtr extensions_by_name(
- PyObject_GetAttr(cls, k_extensions_by_name));
- if (extensions_by_name == NULL) {
- PyErr_SetString(PyExc_TypeError, "no extensions_by_name on class");
+ if (!PyObject_TypeCheck(cls, &CMessageClass_Type)) {
+ PyErr_Format(PyExc_TypeError, "Expected a message class, got %s",
+ cls->ob_type->tp_name);
return NULL;
}
- ScopedPyObjectPtr full_name(PyObject_GetAttr(extension_handle, kfull_name));
- if (full_name == NULL) {
+ CMessageClass *message_class = reinterpret_cast<CMessageClass*>(cls);
+ if (message_class == NULL) {
return NULL;
}
-
// If the extension was already registered, check that it is the same.
- PyObject* existing_extension =
- PyDict_GetItem(extensions_by_name.get(), full_name.get());
- if (existing_extension != NULL) {
- const FieldDescriptor* existing_extension_descriptor =
- GetExtensionDescriptor(existing_extension);
- if (existing_extension_descriptor != descriptor) {
- PyErr_SetString(PyExc_ValueError, "Double registration of Extensions");
- return NULL;
- }
- // Nothing else to do.
- Py_RETURN_NONE;
- }
-
- if (PyDict_SetItem(extensions_by_name.get(), full_name.get(),
- extension_handle) < 0) {
- return NULL;
- }
-
- // Also store a mapping from extension number to implementing class.
- ScopedPyObjectPtr extensions_by_number(
- PyObject_GetAttr(cls, k_extensions_by_number));
- if (extensions_by_number == NULL) {
- PyErr_SetString(PyExc_TypeError, "no extensions_by_number on class");
- return NULL;
- }
- ScopedPyObjectPtr number(PyObject_GetAttrString(extension_handle, "number"));
- if (number == NULL) {
- return NULL;
- }
- if (PyDict_SetItem(extensions_by_number.get(), number.get(),
- extension_handle) < 0) {
+ const FieldDescriptor* existing_extension =
+ message_class->py_message_factory->pool->pool->FindExtensionByNumber(
+ descriptor->containing_type(), descriptor->number());
+ if (existing_extension != NULL && existing_extension != descriptor) {
+ PyErr_SetString(PyExc_ValueError, "Double registration of Extensions");
return NULL;
}
-
- // Check if it's a message set
- if (descriptor->is_extension() &&
- descriptor->containing_type()->options().message_set_wire_format() &&
- descriptor->type() == FieldDescriptor::TYPE_MESSAGE &&
- descriptor->label() == FieldDescriptor::LABEL_OPTIONAL) {
- ScopedPyObjectPtr message_name(PyString_FromStringAndSize(
- descriptor->message_type()->full_name().c_str(),
- descriptor->message_type()->full_name().size()));
- if (message_name == NULL) {
- return NULL;
- }
- PyDict_SetItem(extensions_by_name.get(), message_name.get(),
- extension_handle);
- }
-
Py_RETURN_NONE;
}
@@ -2057,7 +2141,7 @@ static PyObject* WhichOneof(CMessage* self, PyObject* arg) {
static PyObject* GetExtensionDict(CMessage* self, void *closure);
static PyObject* ListFields(CMessage* self) {
- vector<const FieldDescriptor*> fields;
+ std::vector<const FieldDescriptor*> fields;
self->message->GetReflection()->ListFields(*self->message, &fields);
// Normally, the list will be exactly the size of the fields.
@@ -2087,8 +2171,8 @@ static PyObject* ListFields(CMessage* self) {
// is no message class and we cannot retrieve the value.
// TODO(amauryfa): consider building the class on the fly!
if (fields[i]->message_type() != NULL &&
- cdescriptor_pool::GetMessageClass(
- GetDescriptorPoolForMessage(self),
+ message_factory::GetMessageClass(
+ GetFactoryForMessage(self),
fields[i]->message_type()) == NULL) {
PyErr_Clear();
continue;
@@ -2121,7 +2205,8 @@ static PyObject* ListFields(CMessage* self) {
return NULL;
}
- PyObject* field_value = GetAttr(self, py_field_name.get());
+ PyObject* field_value =
+ GetAttr(reinterpret_cast<PyObject*>(self), py_field_name.get());
if (field_value == NULL) {
PyErr_SetObject(PyExc_ValueError, py_field_name.get());
return NULL;
@@ -2132,13 +2217,23 @@ static PyObject* ListFields(CMessage* self) {
PyList_SET_ITEM(all_fields.get(), actual_size, t.release());
++actual_size;
}
- Py_SIZE(all_fields.get()) = actual_size;
+ if (static_cast<size_t>(actual_size) != fields.size() &&
+ (PyList_SetSlice(all_fields.get(), actual_size, fields.size(), NULL) <
+ 0)) {
+ return NULL;
+ }
return all_fields.release();
}
+static PyObject* DiscardUnknownFields(CMessage* self) {
+ AssureWritable(self);
+ self->message->DiscardUnknownFields();
+ Py_RETURN_NONE;
+}
+
PyObject* FindInitializationErrors(CMessage* self) {
Message* message = self->message;
- vector<string> errors;
+ std::vector<string> errors;
message->FindInitializationErrors(&errors);
PyObject* error_list = PyList_New(errors.size());
@@ -2235,32 +2330,16 @@ PyObject* InternalGetScalar(const Message* message,
break;
}
case FieldDescriptor::CPPTYPE_STRING: {
- string value = reflection->GetString(*message, field_descriptor);
+ string scratch;
+ const string& value =
+ reflection->GetStringReference(*message, field_descriptor, &scratch);
result = ToStringObject(field_descriptor, value);
break;
}
case FieldDescriptor::CPPTYPE_ENUM: {
- if (!message->GetReflection()->SupportsUnknownEnumValues() &&
- !message->GetReflection()->HasField(*message, field_descriptor)) {
- // Look for the value in the unknown fields.
- const UnknownFieldSet& unknown_field_set =
- message->GetReflection()->GetUnknownFields(*message);
- for (int i = 0; i < unknown_field_set.field_count(); ++i) {
- if (unknown_field_set.field(i).number() ==
- field_descriptor->number() &&
- unknown_field_set.field(i).type() ==
- google::protobuf::UnknownField::TYPE_VARINT) {
- result = PyInt_FromLong(unknown_field_set.field(i).varint());
- break;
- }
- }
- }
-
- if (result == NULL) {
- const EnumValueDescriptor* enum_value =
- message->GetReflection()->GetEnum(*message, field_descriptor);
- result = PyInt_FromLong(enum_value->number());
- }
+ const EnumValueDescriptor* enum_value =
+ message->GetReflection()->GetEnum(*message, field_descriptor);
+ result = PyInt_FromLong(enum_value->number());
break;
}
default:
@@ -2275,18 +2354,19 @@ PyObject* InternalGetScalar(const Message* message,
PyObject* InternalGetSubMessage(
CMessage* self, const FieldDescriptor* field_descriptor) {
const Reflection* reflection = self->message->GetReflection();
- PyDescriptorPool* pool = GetDescriptorPoolForMessage(self);
+ PyMessageFactory* factory = GetFactoryForMessage(self);
const Message& sub_message = reflection->GetMessage(
- *self->message, field_descriptor, pool->message_factory);
+ *self->message, field_descriptor, factory->message_factory);
- PyObject *message_class = cdescriptor_pool::GetMessageClass(
- pool, field_descriptor->message_type());
+ CMessageClass* message_class = message_factory::GetOrCreateMessageClass(
+ factory, field_descriptor->message_type());
+ ScopedPyObjectPtr message_class_handler(
+ reinterpret_cast<PyObject*>(message_class));
if (message_class == NULL) {
return NULL;
}
- CMessage* cmsg = cmessage::NewEmptyMessage(message_class,
- sub_message.GetDescriptor());
+ CMessage* cmsg = cmessage::NewEmptyMessage(message_class);
if (cmsg == NULL) {
return NULL;
}
@@ -2471,7 +2551,10 @@ PyObject* Reduce(CMessage* self) {
if (state == NULL) {
return NULL;
}
- ScopedPyObjectPtr serialized(SerializePartialToString(self));
+ string contents;
+ self->message->SerializePartialToString(&contents);
+ ScopedPyObjectPtr serialized(
+ PyBytes_FromStringAndSize(contents.c_str(), contents.size()));
if (serialized == NULL) {
return NULL;
}
@@ -2531,11 +2614,24 @@ static PyObject* GetExtensionDict(CMessage* self, void *closure) {
return NULL;
}
+static PyObject* GetExtensionsByName(CMessage *self, void *closure) {
+ return message_meta::GetExtensionsByName(
+ reinterpret_cast<CMessageClass*>(Py_TYPE(self)), closure);
+}
+
+static PyObject* GetExtensionsByNumber(CMessage *self, void *closure) {
+ return message_meta::GetExtensionsByNumber(
+ reinterpret_cast<CMessageClass*>(Py_TYPE(self)), closure);
+}
+
static PyGetSetDef Getters[] = {
{"Extensions", (getter)GetExtensionDict, NULL, "Extension dict"},
+ {"_extensions_by_name", (getter)GetExtensionsByName, NULL},
+ {"_extensions_by_number", (getter)GetExtensionsByNumber, NULL},
{NULL}
};
+
static PyMethodDef Methods[] = {
{ "__deepcopy__", (PyCFunction)DeepCopy, METH_VARARGS,
"Makes a deep copy of the class." },
@@ -2555,6 +2651,8 @@ static PyMethodDef Methods[] = {
"Clears a message field." },
{ "CopyFrom", (PyCFunction)CopyFrom, METH_O,
"Copies a protocol message into the current message." },
+ { "DiscardUnknownFields", (PyCFunction)DiscardUnknownFields, METH_NOARGS,
+ "Discards the unknown fields." },
{ "FindInitializationErrors", (PyCFunction)FindInitializationErrors,
METH_NOARGS,
"Finds unset required fields." },
@@ -2577,9 +2675,10 @@ static PyMethodDef Methods[] = {
{ "RegisterExtension", (PyCFunction)RegisterExtension, METH_O | METH_CLASS,
"Registers an extension with the current message." },
{ "SerializePartialToString", (PyCFunction)SerializePartialToString,
- METH_NOARGS,
+ METH_VARARGS | METH_KEYWORDS,
"Serializes the message to a string, even if it isn't initialized." },
- { "SerializeToString", (PyCFunction)SerializeToString, METH_NOARGS,
+ { "SerializeToString", (PyCFunction)SerializeToString,
+ METH_VARARGS | METH_KEYWORDS,
"Serializes the message to a string, only for initialized messages." },
{ "SetInParent", (PyCFunction)SetInParent, METH_NOARGS,
"Sets the has bit of the given field in its parent message." },
@@ -2605,7 +2704,8 @@ static bool SetCompositeField(
return PyDict_SetItem(self->composite_fields, name, value) == 0;
}
-PyObject* GetAttr(CMessage* self, PyObject* name) {
+PyObject* GetAttr(PyObject* pself, PyObject* name) {
+ CMessage* self = reinterpret_cast<CMessage*>(pself);
PyObject* value = self->composite_fields ?
PyDict_GetItem(self->composite_fields, name) : NULL;
if (value != NULL) {
@@ -2624,8 +2724,8 @@ PyObject* GetAttr(CMessage* self, PyObject* name) {
const Descriptor* entry_type = field_descriptor->message_type();
const FieldDescriptor* value_type = entry_type->FindFieldByName("value");
if (value_type->cpp_type() == FieldDescriptor::CPPTYPE_MESSAGE) {
- PyObject* value_class = cdescriptor_pool::GetMessageClass(
- GetDescriptorPoolForMessage(self), value_type->message_type());
+ CMessageClass* value_class = message_factory::GetMessageClass(
+ GetFactoryForMessage(self), value_type->message_type());
if (value_class == NULL) {
return NULL;
}
@@ -2647,8 +2747,8 @@ PyObject* GetAttr(CMessage* self, PyObject* name) {
if (field_descriptor->label() == FieldDescriptor::LABEL_REPEATED) {
PyObject* py_container = NULL;
if (field_descriptor->cpp_type() == FieldDescriptor::CPPTYPE_MESSAGE) {
- PyObject *message_class = cdescriptor_pool::GetMessageClass(
- GetDescriptorPoolForMessage(self), field_descriptor->message_type());
+ CMessageClass* message_class = message_factory::GetMessageClass(
+ GetFactoryForMessage(self), field_descriptor->message_type());
if (message_class == NULL) {
return NULL;
}
@@ -2683,7 +2783,8 @@ PyObject* GetAttr(CMessage* self, PyObject* name) {
return InternalGetScalar(self->message, field_descriptor);
}
-int SetAttr(CMessage* self, PyObject* name, PyObject* value) {
+int SetAttr(PyObject* pself, PyObject* name, PyObject* value) {
+ CMessage* self = reinterpret_cast<CMessage*>(pself);
if (self->composite_fields && PyDict_Contains(self->composite_fields, name)) {
PyErr_SetString(PyExc_TypeError, "Can't set composite field");
return -1;
@@ -2711,7 +2812,7 @@ int SetAttr(CMessage* self, PyObject* name, PyObject* value) {
PyErr_Format(PyExc_AttributeError,
"Assignment not allowed "
- "(no field \"%s\"in protocol message object).",
+ "(no field \"%s\" in protocol message object).",
PyString_AsString(name));
return -1;
}
@@ -2719,7 +2820,7 @@ int SetAttr(CMessage* self, PyObject* name, PyObject* value) {
} // namespace cmessage
PyTypeObject CMessage_Type = {
- PyVarObject_HEAD_INIT(&PyMessageMeta_Type, 0)
+ PyVarObject_HEAD_INIT(&CMessageClass_Type, 0)
FULL_MODULE_NAME ".CMessage", // tp_name
sizeof(CMessage), // tp_basicsize
0, // tp_itemsize
@@ -2728,22 +2829,22 @@ PyTypeObject CMessage_Type = {
0, // tp_getattr
0, // tp_setattr
0, // tp_compare
- 0, // tp_repr
+ (reprfunc)cmessage::ToStr, // tp_repr
0, // tp_as_number
0, // tp_as_sequence
0, // tp_as_mapping
PyObject_HashNotImplemented, // tp_hash
0, // tp_call
(reprfunc)cmessage::ToStr, // tp_str
- (getattrofunc)cmessage::GetAttr, // tp_getattro
- (setattrofunc)cmessage::SetAttr, // tp_setattro
+ cmessage::GetAttr, // tp_getattro
+ cmessage::SetAttr, // tp_setattro
0, // tp_as_buffer
Py_TPFLAGS_DEFAULT | Py_TPFLAGS_BASETYPE, // tp_flags
"A ProtocolMessage", // tp_doc
0, // tp_traverse
0, // tp_clear
(richcmpfunc)cmessage::RichCompare, // tp_richcompare
- 0, // tp_weaklistoffset
+ offsetof(CMessage, weakreflist), // tp_weaklistoffset
0, // tp_iter
0, // tp_iternext
cmessage::Methods, // tp_methods
@@ -2765,17 +2866,38 @@ const Message* (*GetCProtoInsidePyProtoPtr)(PyObject* msg);
Message* (*MutableCProtoInsidePyProtoPtr)(PyObject* msg);
static const Message* GetCProtoInsidePyProtoImpl(PyObject* msg) {
+ const Message* message = PyMessage_GetMessagePointer(msg);
+ if (message == NULL) {
+ PyErr_Clear();
+ return NULL;
+ }
+ return message;
+}
+
+static Message* MutableCProtoInsidePyProtoImpl(PyObject* msg) {
+ Message* message = PyMessage_GetMutableMessagePointer(msg);
+ if (message == NULL) {
+ PyErr_Clear();
+ return NULL;
+ }
+ return message;
+}
+
+const Message* PyMessage_GetMessagePointer(PyObject* msg) {
if (!PyObject_TypeCheck(msg, &CMessage_Type)) {
+ PyErr_SetString(PyExc_TypeError, "Not a Message instance");
return NULL;
}
CMessage* cmsg = reinterpret_cast<CMessage*>(msg);
return cmsg->message;
}
-static Message* MutableCProtoInsidePyProtoImpl(PyObject* msg) {
+Message* PyMessage_GetMutableMessagePointer(PyObject* msg) {
if (!PyObject_TypeCheck(msg, &CMessage_Type)) {
+ PyErr_SetString(PyExc_TypeError, "Not a Message instance");
return NULL;
}
+
CMessage* cmsg = reinterpret_cast<CMessage*>(msg);
if ((cmsg->composite_fields && PyDict_Size(cmsg->composite_fields) != 0) ||
(cmsg->extensions != NULL &&
@@ -2784,36 +2906,20 @@ static Message* MutableCProtoInsidePyProtoImpl(PyObject* msg) {
// the underlying C++ message back to the CMessage (e.g. removed repeated
// composite containers). We only allow direct mutation of the underlying
// C++ message if there is no child data in the CMessage.
+ PyErr_SetString(PyExc_ValueError,
+ "Cannot reliably get a mutable pointer "
+ "to a message with extra references");
return NULL;
}
cmessage::AssureWritable(cmsg);
return cmsg->message;
}
-static const char module_docstring[] =
-"python-proto2 is a module that can be used to enhance proto2 Python API\n"
-"performance.\n"
-"\n"
-"It provides access to the protocol buffers C++ reflection API that\n"
-"implements the basic protocol buffer functions.";
-
void InitGlobals() {
// TODO(gps): Check all return values in this function for NULL and propagate
// the error (MemoryError) on up to result in an import failure. These should
// also be freed and reset to NULL during finalization.
- kPythonZero = PyInt_FromLong(0);
- kint32min_py = PyInt_FromLong(kint32min);
- kint32max_py = PyInt_FromLong(kint32max);
- kuint32max_py = PyLong_FromLongLong(kuint32max);
- kint64min_py = PyLong_FromLongLong(kint64min);
- kint64max_py = PyLong_FromLongLong(kint64max);
- kuint64max_py = PyLong_FromUnsignedLongLong(kuint64max);
-
kDESCRIPTOR = PyString_FromString("DESCRIPTOR");
- k_cdescriptor = PyString_FromString("_cdescriptor");
- kfull_name = PyString_FromString("full_name");
- k_extensions_by_name = PyString_FromString("_extensions_by_name");
- k_extensions_by_number = PyString_FromString("_extensions_by_number");
PyObject *dummy_obj = PySet_New(NULL);
kEmptyWeakref = PyWeakref_NewRef(dummy_obj, NULL);
@@ -2831,15 +2937,20 @@ bool InitProto2MessageModule(PyObject *m) {
return false;
}
+ // Initialize types and globals in message_factory.cc
+ if (!InitMessageFactory()) {
+ return false;
+ }
+
// Initialize constants defined in this file.
InitGlobals();
- PyMessageMeta_Type.tp_base = &PyType_Type;
- if (PyType_Ready(&PyMessageMeta_Type) < 0) {
+ CMessageClass_Type.tp_base = &PyType_Type;
+ if (PyType_Ready(&CMessageClass_Type) < 0) {
return false;
}
PyModule_AddObject(m, "MessageMeta",
- reinterpret_cast<PyObject*>(&PyMessageMeta_Type));
+ reinterpret_cast<PyObject*>(&CMessageClass_Type));
if (PyType_Ready(&CMessage_Type) < 0) {
return false;
@@ -2848,25 +2959,6 @@ bool InitProto2MessageModule(PyObject *m) {
// DESCRIPTOR is set on each protocol buffer message class elsewhere, but set
// it here as well to document that subclasses need to set it.
PyDict_SetItem(CMessage_Type.tp_dict, kDESCRIPTOR, Py_None);
- // Subclasses with message extensions will override _extensions_by_name and
- // _extensions_by_number with fresh mutable dictionaries in AddDescriptors.
- // All other classes can share this same immutable mapping.
- ScopedPyObjectPtr empty_dict(PyDict_New());
- if (empty_dict == NULL) {
- return false;
- }
- ScopedPyObjectPtr immutable_dict(PyDictProxy_New(empty_dict.get()));
- if (immutable_dict == NULL) {
- return false;
- }
- if (PyDict_SetItem(CMessage_Type.tp_dict,
- k_extensions_by_name, immutable_dict.get()) < 0) {
- return false;
- }
- if (PyDict_SetItem(CMessage_Type.tp_dict,
- k_extensions_by_number, immutable_dict.get()) < 0) {
- return false;
- }
PyModule_AddObject(m, "Message", reinterpret_cast<PyObject*>(&CMessage_Type));
@@ -2912,69 +3004,15 @@ bool InitProto2MessageModule(PyObject *m) {
}
// Initialize Map container types.
- {
- // ScalarMapContainer_Type derives from our MutableMapping type.
- ScopedPyObjectPtr containers(PyImport_ImportModule(
- "google.protobuf.internal.containers"));
- if (containers == NULL) {
- return false;
- }
-
- ScopedPyObjectPtr mutable_mapping(
- PyObject_GetAttrString(containers.get(), "MutableMapping"));
- if (mutable_mapping == NULL) {
- return false;
- }
-
- if (!PyObject_TypeCheck(mutable_mapping.get(), &PyType_Type)) {
- return false;
- }
-
- Py_INCREF(mutable_mapping.get());
-#if PY_MAJOR_VERSION >= 3
- PyObject* bases = PyTuple_New(1);
- PyTuple_SET_ITEM(bases, 0, mutable_mapping.get());
-
- ScalarMapContainer_Type =
- PyType_FromSpecWithBases(&ScalarMapContainer_Type_spec, bases);
- PyModule_AddObject(m, "ScalarMapContainer", ScalarMapContainer_Type);
-#else
- ScalarMapContainer_Type.tp_base =
- reinterpret_cast<PyTypeObject*>(mutable_mapping.get());
-
- if (PyType_Ready(&ScalarMapContainer_Type) < 0) {
- return false;
- }
-
- PyModule_AddObject(m, "ScalarMapContainer",
- reinterpret_cast<PyObject*>(&ScalarMapContainer_Type));
-#endif
-
- if (PyType_Ready(&MapIterator_Type) < 0) {
- return false;
- }
-
- PyModule_AddObject(m, "MapIterator",
- reinterpret_cast<PyObject*>(&MapIterator_Type));
-
-
-#if PY_MAJOR_VERSION >= 3
- MessageMapContainer_Type =
- PyType_FromSpecWithBases(&MessageMapContainer_Type_spec, bases);
- PyModule_AddObject(m, "MessageMapContainer", MessageMapContainer_Type);
-#else
- Py_INCREF(mutable_mapping.get());
- MessageMapContainer_Type.tp_base =
- reinterpret_cast<PyTypeObject*>(mutable_mapping.get());
-
- if (PyType_Ready(&MessageMapContainer_Type) < 0) {
- return false;
- }
-
- PyModule_AddObject(m, "MessageMapContainer",
- reinterpret_cast<PyObject*>(&MessageMapContainer_Type));
-#endif
+ if (!InitMapContainers()) {
+ return false;
}
+ PyModule_AddObject(m, "ScalarMapContainer",
+ reinterpret_cast<PyObject*>(ScalarMapContainer_Type));
+ PyModule_AddObject(m, "MessageMapContainer",
+ reinterpret_cast<PyObject*>(MessageMapContainer_Type));
+ PyModule_AddObject(m, "MapIterator",
+ reinterpret_cast<PyObject*>(&MapIterator_Type));
if (PyType_Ready(&ExtensionDict_Type) < 0) {
return false;
@@ -3009,6 +3047,10 @@ bool InitProto2MessageModule(PyObject *m) {
&PyFileDescriptor_Type));
PyModule_AddObject(m, "OneofDescriptor", reinterpret_cast<PyObject*>(
&PyOneofDescriptor_Type));
+ PyModule_AddObject(m, "ServiceDescriptor", reinterpret_cast<PyObject*>(
+ &PyServiceDescriptor_Type));
+ PyModule_AddObject(m, "MethodDescriptor", reinterpret_cast<PyObject*>(
+ &PyMethodDescriptor_Type));
PyObject* enum_type_wrapper = PyImport_ImportModule(
"google.protobuf.internal.enum_type_wrapper");
@@ -3045,47 +3087,4 @@ bool InitProto2MessageModule(PyObject *m) {
} // namespace python
} // namespace protobuf
-
-
-#if PY_MAJOR_VERSION >= 3
-static struct PyModuleDef _module = {
- PyModuleDef_HEAD_INIT,
- "_message",
- google::protobuf::python::module_docstring,
- -1,
- NULL,
- NULL,
- NULL,
- NULL,
- NULL
-};
-#define INITFUNC PyInit__message
-#define INITFUNC_ERRORVAL NULL
-#else // Python 2
-#define INITFUNC init_message
-#define INITFUNC_ERRORVAL
-#endif
-
-extern "C" {
- PyMODINIT_FUNC INITFUNC(void) {
- PyObject* m;
-#if PY_MAJOR_VERSION >= 3
- m = PyModule_Create(&_module);
-#else
- m = Py_InitModule3("_message", NULL, google::protobuf::python::module_docstring);
-#endif
- if (m == NULL) {
- return INITFUNC_ERRORVAL;
- }
-
- if (!google::protobuf::python::InitProto2MessageModule(m)) {
- Py_DECREF(m);
- return INITFUNC_ERRORVAL;
- }
-
-#if PY_MAJOR_VERSION >= 3
- return m;
-#endif
- }
-}
} // namespace google
diff --git a/python/google/protobuf/pyext/message.h b/python/google/protobuf/pyext/message.h
index cc0012e9..d754e62a 100644
--- a/python/google/protobuf/pyext/message.h
+++ b/python/google/protobuf/pyext/message.h
@@ -37,11 +37,11 @@
#include <Python.h>
#include <memory>
-#ifndef _SHARED_PTR_H
-#include <google/protobuf/stubs/shared_ptr.h>
-#endif
#include <string>
+#include <google/protobuf/stubs/common.h>
+#include <google/protobuf/pyext/thread_unsafe_shared_ptr.h>
+
namespace google {
namespace protobuf {
@@ -52,17 +52,10 @@ class Descriptor;
class DescriptorPool;
class MessageFactory;
-#ifdef _SHARED_PTR_H
-using std::shared_ptr;
-using std::string;
-#else
-using internal::shared_ptr;
-#endif
-
namespace python {
struct ExtensionDict;
-struct PyDescriptorPool;
+struct PyMessageFactory;
typedef struct CMessage {
PyObject_HEAD;
@@ -71,7 +64,9 @@ typedef struct CMessage {
// proto tree. Every Python CMessage holds a reference to it in
// order to keep it alive as long as there's a Python object that
// references any part of the tree.
- shared_ptr<Message> owner;
+
+ typedef ThreadUnsafeSharedPtr<Message> OwnerRef;
+ OwnerRef owner;
// Weak reference to a parent CMessage object. This is NULL for any top-level
// message and is set for any child message (i.e. a child submessage or a
@@ -112,24 +107,48 @@ typedef struct CMessage {
// Similar to composite_fields, acting as a cache, but also contains the
// required extension dict logic.
ExtensionDict* extensions;
+
+ // Implements the "weakref" protocol for this object.
+ PyObject* weakreflist;
} CMessage;
+extern PyTypeObject CMessageClass_Type;
extern PyTypeObject CMessage_Type;
+
+// The (meta) type of all Messages classes.
+// It allows us to cache some C++ pointers in the class object itself, they are
+// faster to extract than from the type's dictionary.
+
+struct CMessageClass {
+ // This is how CPython subclasses C structures: the base structure must be
+ // the first member of the object.
+ PyHeapTypeObject super;
+
+ // C++ descriptor of this message.
+ const Descriptor* message_descriptor;
+
+ // Owned reference, used to keep the pointer above alive.
+ PyObject* py_message_descriptor;
+
+ // The Python MessageFactory used to create the class. It is needed to resolve
+ // fields descriptors, including extensions fields; its C++ MessageFactory is
+ // used to instantiate submessages.
+ // We own the reference, because it's important to keep the factory alive.
+ PyMessageFactory* py_message_factory;
+
+ PyObject* AsPyObject() {
+ return reinterpret_cast<PyObject*>(this);
+ }
+};
+
+
namespace cmessage {
// Internal function to create a new empty Message Python object, but with empty
// pointers to the C++ objects.
// The caller must fill self->message, self->owner and eventually self->parent.
-CMessage* NewEmptyMessage(PyObject* type, const Descriptor* descriptor);
-
-// Release a submessage from its proto tree, making it a new top-level messgae.
-// A new message will be created if this is a read-only default instance.
-//
-// Corresponds to reflection api method ReleaseMessage.
-int ReleaseSubMessage(CMessage* self,
- const FieldDescriptor* field_descriptor,
- CMessage* child_cmessage);
+CMessage* NewEmptyMessage(CMessageClass* type);
// Retrieves the C++ descriptor of a Python Extension descriptor.
// On error, return NULL with an exception set.
@@ -206,37 +225,44 @@ PyObject* HasFieldByDescriptor(
PyObject* HasField(CMessage* self, PyObject* arg);
// Initializes values of fields on a newly constructed message.
-int InitAttributes(CMessage* self, PyObject* kwargs);
+// Note that positional arguments are disallowed: 'args' must be NULL or the
+// empty tuple.
+int InitAttributes(CMessage* self, PyObject* args, PyObject* kwargs);
PyObject* MergeFrom(CMessage* self, PyObject* arg);
-// Retrieves an attribute named 'name' from CMessage 'self'. Returns
-// the attribute value on success, or NULL on failure.
+// This method does not do anything beyond checking that no other extension
+// has been registered with the same field number on this class.
+PyObject* RegisterExtension(PyObject* cls, PyObject* extension_handle);
+
+// Retrieves an attribute named 'name' from 'self', which is interpreted as a
+// CMessage. Returns the attribute value on success, or null on failure.
//
// Returns a new reference.
-PyObject* GetAttr(CMessage* self, PyObject* name);
+PyObject* GetAttr(PyObject* self, PyObject* name);
-// Set the value of the attribute named 'name', for CMessage 'self',
-// to the value 'value'. Returns -1 on failure.
-int SetAttr(CMessage* self, PyObject* name, PyObject* value);
+// Set the value of the attribute named 'name', for 'self', which is interpreted
+// as a CMessage, to the value 'value'. Returns -1 on failure.
+int SetAttr(PyObject* self, PyObject* name, PyObject* value);
PyObject* FindInitializationErrors(CMessage* self);
// Set the owner field of self and any children of self, recursively.
// Used when self is being released and thus has a new owner (the
// released Message.)
-int SetOwner(CMessage* self, const shared_ptr<Message>& new_owner);
+int SetOwner(CMessage* self, const CMessage::OwnerRef& new_owner);
int AssureWritable(CMessage* self);
-// Returns the "best" DescriptorPool for the given message.
-// This is often equivalent to message.DESCRIPTOR.pool, but not always, when
-// the message class was created from a MessageFactory using a custom pool which
-// uses the generated pool as an underlay.
+// Returns the message factory for the given message.
+// This is equivalent to message.MESSAGE_FACTORY
//
-// The returned pool is suitable for finding fields and building submessages,
+// The returned factory is suitable for finding fields and building submessages,
// even in the case of extensions.
-PyDescriptorPool* GetDescriptorPoolForMessage(CMessage* message);
+// Returns a *borrowed* reference, and never fails because we pass a CMessage.
+PyMessageFactory* GetFactoryForMessage(CMessage* message);
+
+PyObject* SetAllowOversizeProtos(PyObject* m, PyObject* arg);
} // namespace cmessage
@@ -249,25 +275,25 @@ PyDescriptorPool* GetDescriptorPoolForMessage(CMessage* message);
#define GOOGLE_CHECK_GET_INT32(arg, value, err) \
int32 value; \
- if (!CheckAndGetInteger(arg, &value, kint32min_py, kint32max_py)) { \
+ if (!CheckAndGetInteger(arg, &value)) { \
return err; \
}
#define GOOGLE_CHECK_GET_INT64(arg, value, err) \
int64 value; \
- if (!CheckAndGetInteger(arg, &value, kint64min_py, kint64max_py)) { \
+ if (!CheckAndGetInteger(arg, &value)) { \
return err; \
}
#define GOOGLE_CHECK_GET_UINT32(arg, value, err) \
uint32 value; \
- if (!CheckAndGetInteger(arg, &value, kPythonZero, kuint32max_py)) { \
+ if (!CheckAndGetInteger(arg, &value)) { \
return err; \
}
#define GOOGLE_CHECK_GET_UINT64(arg, value, err) \
uint64 value; \
- if (!CheckAndGetInteger(arg, &value, kPythonZero, kuint64max_py)) { \
+ if (!CheckAndGetInteger(arg, &value)) { \
return err; \
}
@@ -290,20 +316,11 @@ PyDescriptorPool* GetDescriptorPoolForMessage(CMessage* message);
}
-extern PyObject* kPythonZero;
-extern PyObject* kint32min_py;
-extern PyObject* kint32max_py;
-extern PyObject* kuint32max_py;
-extern PyObject* kint64min_py;
-extern PyObject* kint64max_py;
-extern PyObject* kuint64max_py;
-
#define FULL_MODULE_NAME "google.protobuf.pyext._message"
void FormatTypeError(PyObject* arg, char* expected_types);
template<class T>
-bool CheckAndGetInteger(
- PyObject* arg, T* value, PyObject* min, PyObject* max);
+bool CheckAndGetInteger(PyObject* arg, T* value);
bool CheckAndGetDouble(PyObject* arg, double* value);
bool CheckAndGetFloat(PyObject* arg, float* value);
bool CheckAndGetBool(PyObject* arg, bool* value);
@@ -314,7 +331,8 @@ bool CheckAndSetString(
const Reflection* reflection,
bool append,
int index);
-PyObject* ToStringObject(const FieldDescriptor* descriptor, string value);
+PyObject* ToStringObject(const FieldDescriptor* descriptor,
+ const string& value);
// Check if the passed field descriptor belongs to the given message.
// If not, return false and set a Python exception (a KeyError)
@@ -323,6 +341,20 @@ bool CheckFieldBelongsToMessage(const FieldDescriptor* field_descriptor,
extern PyObject* PickleError_class;
+const Message* PyMessage_GetMessagePointer(PyObject* msg);
+Message* PyMessage_GetMutableMessagePointer(PyObject* msg);
+
+bool InitProto2MessageModule(PyObject *m);
+
+#if LANG_CXX11
+// These are referenced by repeated_scalar_container, and must
+// be explicitly instantiated.
+extern template bool CheckAndGetInteger<int32>(PyObject*, int32*);
+extern template bool CheckAndGetInteger<int64>(PyObject*, int64*);
+extern template bool CheckAndGetInteger<uint32>(PyObject*, uint32*);
+extern template bool CheckAndGetInteger<uint64>(PyObject*, uint64*);
+#endif
+
} // namespace python
} // namespace protobuf
diff --git a/python/google/protobuf/pyext/message_factory.cc b/python/google/protobuf/pyext/message_factory.cc
new file mode 100644
index 00000000..bacc76a6
--- /dev/null
+++ b/python/google/protobuf/pyext/message_factory.cc
@@ -0,0 +1,283 @@
+// Protocol Buffers - Google's data interchange format
+// Copyright 2008 Google Inc. All rights reserved.
+// https://developers.google.com/protocol-buffers/
+//
+// Redistribution and use in source and binary forms, with or without
+// modification, are permitted provided that the following conditions are
+// met:
+//
+// * Redistributions of source code must retain the above copyright
+// notice, this list of conditions and the following disclaimer.
+// * Redistributions in binary form must reproduce the above
+// copyright notice, this list of conditions and the following disclaimer
+// in the documentation and/or other materials provided with the
+// distribution.
+// * Neither the name of Google Inc. nor the names of its
+// contributors may be used to endorse or promote products derived from
+// this software without specific prior written permission.
+//
+// THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS
+// "AS IS" AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT
+// LIMITED TO, THE IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR
+// A PARTICULAR PURPOSE ARE DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT
+// OWNER OR CONTRIBUTORS BE LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL,
+// SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT
+// LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE,
+// DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY
+// THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT
+// (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
+// OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
+
+#include <Python.h>
+
+#include <google/protobuf/dynamic_message.h>
+#include <google/protobuf/pyext/descriptor.h>
+#include <google/protobuf/pyext/message.h>
+#include <google/protobuf/pyext/message_factory.h>
+#include <google/protobuf/pyext/scoped_pyobject_ptr.h>
+
+#if PY_MAJOR_VERSION >= 3
+ #if PY_VERSION_HEX < 0x03030000
+ #error "Python 3.0 - 3.2 are not supported."
+ #endif
+ #define PyString_AsStringAndSize(ob, charpp, sizep) \
+ (PyUnicode_Check(ob)? \
+ ((*(charpp) = PyUnicode_AsUTF8AndSize(ob, (sizep))) == NULL? -1: 0): \
+ PyBytes_AsStringAndSize(ob, (charpp), (sizep)))
+#endif
+
+namespace google {
+namespace protobuf {
+namespace python {
+
+namespace message_factory {
+
+PyMessageFactory* NewMessageFactory(PyTypeObject* type, PyDescriptorPool* pool) {
+ PyMessageFactory* factory = reinterpret_cast<PyMessageFactory*>(
+ PyType_GenericAlloc(type, 0));
+ if (factory == NULL) {
+ return NULL;
+ }
+
+ DynamicMessageFactory* message_factory = new DynamicMessageFactory();
+ // This option might be the default some day.
+ message_factory->SetDelegateToGeneratedFactory(true);
+ factory->message_factory = message_factory;
+
+ factory->pool = pool;
+ // TODO(amauryfa): When the MessageFactory is not created from the
+ // DescriptorPool this reference should be owned, not borrowed.
+ // Py_INCREF(pool);
+
+ factory->classes_by_descriptor = new PyMessageFactory::ClassesByMessageMap();
+
+ return factory;
+}
+
+PyObject* New(PyTypeObject* type, PyObject* args, PyObject* kwargs) {
+ static char* kwlist[] = {"pool", 0};
+ PyObject* pool = NULL;
+ if (!PyArg_ParseTupleAndKeywords(args, kwargs, "|O", kwlist, &pool)) {
+ return NULL;
+ }
+ ScopedPyObjectPtr owned_pool;
+ if (pool == NULL || pool == Py_None) {
+ owned_pool.reset(PyObject_CallFunction(
+ reinterpret_cast<PyObject*>(&PyDescriptorPool_Type), NULL));
+ if (owned_pool == NULL) {
+ return NULL;
+ }
+ pool = owned_pool.get();
+ } else {
+ if (!PyObject_TypeCheck(pool, &PyDescriptorPool_Type)) {
+ PyErr_Format(PyExc_TypeError, "Expected a DescriptorPool, got %s",
+ pool->ob_type->tp_name);
+ return NULL;
+ }
+ }
+
+ return reinterpret_cast<PyObject*>(
+ NewMessageFactory(type, reinterpret_cast<PyDescriptorPool*>(pool)));
+}
+
+static void Dealloc(PyObject* pself) {
+ PyMessageFactory* self = reinterpret_cast<PyMessageFactory*>(pself);
+
+ // TODO(amauryfa): When the MessageFactory is not created from the
+ // DescriptorPool this reference should be owned, not borrowed.
+ // Py_CLEAR(self->pool);
+ typedef PyMessageFactory::ClassesByMessageMap::iterator iterator;
+ for (iterator it = self->classes_by_descriptor->begin();
+ it != self->classes_by_descriptor->end(); ++it) {
+ Py_DECREF(it->second);
+ }
+ delete self->classes_by_descriptor;
+ delete self->message_factory;
+ Py_TYPE(self)->tp_free(pself);
+}
+
+// Add a message class to our database.
+int RegisterMessageClass(PyMessageFactory* self,
+ const Descriptor* message_descriptor,
+ CMessageClass* message_class) {
+ Py_INCREF(message_class);
+ typedef PyMessageFactory::ClassesByMessageMap::iterator iterator;
+ std::pair<iterator, bool> ret = self->classes_by_descriptor->insert(
+ std::make_pair(message_descriptor, message_class));
+ if (!ret.second) {
+ // Update case: DECREF the previous value.
+ Py_DECREF(ret.first->second);
+ ret.first->second = message_class;
+ }
+ return 0;
+}
+
+CMessageClass* GetOrCreateMessageClass(PyMessageFactory* self,
+ const Descriptor* descriptor) {
+ // This is the same implementation as MessageFactory.GetPrototype().
+
+ // Do not create a MessageClass that already exists.
+ hash_map<const Descriptor*, CMessageClass*>::iterator it =
+ self->classes_by_descriptor->find(descriptor);
+ if (it != self->classes_by_descriptor->end()) {
+ Py_INCREF(it->second);
+ return it->second;
+ }
+ ScopedPyObjectPtr py_descriptor(
+ PyMessageDescriptor_FromDescriptor(descriptor));
+ if (py_descriptor == NULL) {
+ return NULL;
+ }
+ // Create a new message class.
+ ScopedPyObjectPtr args(Py_BuildValue(
+ "s(){sOsOsO}", descriptor->name().c_str(),
+ "DESCRIPTOR", py_descriptor.get(),
+ "__module__", Py_None,
+ "message_factory", self));
+ if (args == NULL) {
+ return NULL;
+ }
+ ScopedPyObjectPtr message_class(PyObject_CallObject(
+ reinterpret_cast<PyObject*>(&CMessageClass_Type), args.get()));
+ if (message_class == NULL) {
+ return NULL;
+ }
+ // Create messages class for the messages used by the fields, and registers
+ // all extensions for these messages during the recursion.
+ for (int field_idx = 0; field_idx < descriptor->field_count(); field_idx++) {
+ const Descriptor* sub_descriptor =
+ descriptor->field(field_idx)->message_type();
+ // It is NULL if the field type is not a message.
+ if (sub_descriptor != NULL) {
+ CMessageClass* result = GetOrCreateMessageClass(self, sub_descriptor);
+ if (result == NULL) {
+ return NULL;
+ }
+ Py_DECREF(result);
+ }
+ }
+
+ // Register extensions defined in this message.
+ for (int ext_idx = 0 ; ext_idx < descriptor->extension_count() ; ext_idx++) {
+ const FieldDescriptor* extension = descriptor->extension(ext_idx);
+ ScopedPyObjectPtr py_extended_class(
+ GetOrCreateMessageClass(self, extension->containing_type())
+ ->AsPyObject());
+ if (py_extended_class == NULL) {
+ return NULL;
+ }
+ ScopedPyObjectPtr py_extension(PyFieldDescriptor_FromDescriptor(extension));
+ if (py_extension == NULL) {
+ return NULL;
+ }
+ ScopedPyObjectPtr result(cmessage::RegisterExtension(
+ py_extended_class.get(), py_extension.get()));
+ if (result == NULL) {
+ return NULL;
+ }
+ }
+ return reinterpret_cast<CMessageClass*>(message_class.release());
+}
+
+// Retrieve the message class added to our database.
+CMessageClass* GetMessageClass(PyMessageFactory* self,
+ const Descriptor* message_descriptor) {
+ typedef PyMessageFactory::ClassesByMessageMap::iterator iterator;
+ iterator ret = self->classes_by_descriptor->find(message_descriptor);
+ if (ret == self->classes_by_descriptor->end()) {
+ PyErr_Format(PyExc_TypeError, "No message class registered for '%s'",
+ message_descriptor->full_name().c_str());
+ return NULL;
+ } else {
+ return ret->second;
+ }
+}
+
+static PyMethodDef Methods[] = {
+ {NULL}};
+
+static PyObject* GetPool(PyMessageFactory* self, void* closure) {
+ Py_INCREF(self->pool);
+ return reinterpret_cast<PyObject*>(self->pool);
+}
+
+static PyGetSetDef Getters[] = {
+ {"pool", (getter)GetPool, NULL, "DescriptorPool"},
+ {NULL}
+};
+
+} // namespace message_factory
+
+PyTypeObject PyMessageFactory_Type = {
+ PyVarObject_HEAD_INIT(&PyType_Type, 0) FULL_MODULE_NAME
+ ".MessageFactory", // tp_name
+ sizeof(PyMessageFactory), // tp_basicsize
+ 0, // tp_itemsize
+ message_factory::Dealloc, // tp_dealloc
+ 0, // tp_print
+ 0, // tp_getattr
+ 0, // tp_setattr
+ 0, // tp_compare
+ 0, // tp_repr
+ 0, // tp_as_number
+ 0, // tp_as_sequence
+ 0, // tp_as_mapping
+ 0, // tp_hash
+ 0, // tp_call
+ 0, // tp_str
+ 0, // tp_getattro
+ 0, // tp_setattro
+ 0, // tp_as_buffer
+ Py_TPFLAGS_DEFAULT | Py_TPFLAGS_BASETYPE, // tp_flags
+ "A static Message Factory", // tp_doc
+ 0, // tp_traverse
+ 0, // tp_clear
+ 0, // tp_richcompare
+ 0, // tp_weaklistoffset
+ 0, // tp_iter
+ 0, // tp_iternext
+ message_factory::Methods, // tp_methods
+ 0, // tp_members
+ message_factory::Getters, // tp_getset
+ 0, // tp_base
+ 0, // tp_dict
+ 0, // tp_descr_get
+ 0, // tp_descr_set
+ 0, // tp_dictoffset
+ 0, // tp_init
+ 0, // tp_alloc
+ message_factory::New, // tp_new
+ PyObject_Del, // tp_free
+};
+
+bool InitMessageFactory() {
+ if (PyType_Ready(&PyMessageFactory_Type) < 0) {
+ return false;
+ }
+
+ return true;
+}
+
+} // namespace python
+} // namespace protobuf
+} // namespace google
diff --git a/python/google/protobuf/pyext/message_factory.h b/python/google/protobuf/pyext/message_factory.h
new file mode 100644
index 00000000..36092f7e
--- /dev/null
+++ b/python/google/protobuf/pyext/message_factory.h
@@ -0,0 +1,103 @@
+// Protocol Buffers - Google's data interchange format
+// Copyright 2008 Google Inc. All rights reserved.
+// https://developers.google.com/protocol-buffers/
+//
+// Redistribution and use in source and binary forms, with or without
+// modification, are permitted provided that the following conditions are
+// met:
+//
+// * Redistributions of source code must retain the above copyright
+// notice, this list of conditions and the following disclaimer.
+// * Redistributions in binary form must reproduce the above
+// copyright notice, this list of conditions and the following disclaimer
+// in the documentation and/or other materials provided with the
+// distribution.
+// * Neither the name of Google Inc. nor the names of its
+// contributors may be used to endorse or promote products derived from
+// this software without specific prior written permission.
+//
+// THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS
+// "AS IS" AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT
+// LIMITED TO, THE IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR
+// A PARTICULAR PURPOSE ARE DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT
+// OWNER OR CONTRIBUTORS BE LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL,
+// SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT
+// LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE,
+// DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY
+// THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT
+// (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
+// OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
+
+#ifndef GOOGLE_PROTOBUF_PYTHON_CPP_MESSAGE_FACTORY_H__
+#define GOOGLE_PROTOBUF_PYTHON_CPP_MESSAGE_FACTORY_H__
+
+#include <Python.h>
+
+#include <google/protobuf/stubs/hash.h>
+#include <google/protobuf/descriptor.h>
+#include <google/protobuf/pyext/descriptor_pool.h>
+
+namespace google {
+namespace protobuf {
+class MessageFactory;
+
+namespace python {
+
+// The (meta) type of all Messages classes.
+struct CMessageClass;
+
+struct PyMessageFactory {
+ PyObject_HEAD
+
+ // DynamicMessageFactory used to create C++ instances of messages.
+ // This object cache the descriptors that were used, so the DescriptorPool
+ // needs to get rid of it before it can delete itself.
+ //
+ // Note: A C++ MessageFactory is different from the PyMessageFactory.
+ // The C++ one creates messages, when the Python one creates classes.
+ MessageFactory* message_factory;
+
+ // borrowed reference to a Python DescriptorPool.
+ // TODO(amauryfa): invert the dependency: the MessageFactory owns the
+ // DescriptorPool, not the opposite.
+ PyDescriptorPool* pool;
+
+ // Make our own mapping to retrieve Python classes from C++ descriptors.
+ //
+ // Descriptor pointers stored here are owned by the DescriptorPool above.
+ // Python references to classes are owned by this PyDescriptorPool.
+ typedef hash_map<const Descriptor*, CMessageClass*> ClassesByMessageMap;
+ ClassesByMessageMap* classes_by_descriptor;
+};
+
+extern PyTypeObject PyMessageFactory_Type;
+
+namespace message_factory {
+
+// Creates a new MessageFactory instance.
+PyMessageFactory* NewMessageFactory(PyTypeObject* type, PyDescriptorPool* pool);
+
+// Registers a new Python class for the given message descriptor.
+// On error, returns -1 with a Python exception set.
+int RegisterMessageClass(PyMessageFactory* self,
+ const Descriptor* message_descriptor,
+ CMessageClass* message_class);
+// Retrieves the Python class registered with the given message descriptor, or
+// fail with a TypeError. Returns a *borrowed* reference.
+CMessageClass* GetMessageClass(PyMessageFactory* self,
+ const Descriptor* message_descriptor);
+// Retrieves the Python class registered with the given message descriptor.
+// The class is created if not done yet. Returns a *new* reference.
+CMessageClass* GetOrCreateMessageClass(PyMessageFactory* self,
+ const Descriptor* message_descriptor);
+} // namespace message_factory
+
+// Initialize objects used by this module.
+// On error, returns false with a Python exception set.
+bool InitMessageFactory();
+
+} // namespace python
+} // namespace protobuf
+
+} // namespace google
+#endif // GOOGLE_PROTOBUF_PYTHON_CPP_MESSAGE_FACTORY_H__
diff --git a/python/google/protobuf/pyext/message_module.cc b/python/google/protobuf/pyext/message_module.cc
new file mode 100644
index 00000000..f5c8f295
--- /dev/null
+++ b/python/google/protobuf/pyext/message_module.cc
@@ -0,0 +1,138 @@
+// Protocol Buffers - Google's data interchange format
+// Copyright 2008 Google Inc. All rights reserved.
+// https://developers.google.com/protocol-buffers/
+//
+// Redistribution and use in source and binary forms, with or without
+// modification, are permitted provided that the following conditions are
+// met:
+//
+// * Redistributions of source code must retain the above copyright
+// notice, this list of conditions and the following disclaimer.
+// * Redistributions in binary form must reproduce the above
+// copyright notice, this list of conditions and the following disclaimer
+// in the documentation and/or other materials provided with the
+// distribution.
+// * Neither the name of Google Inc. nor the names of its
+// contributors may be used to endorse or promote products derived from
+// this software without specific prior written permission.
+//
+// THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS
+// "AS IS" AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT
+// LIMITED TO, THE IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR
+// A PARTICULAR PURPOSE ARE DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT
+// OWNER OR CONTRIBUTORS BE LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL,
+// SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT
+// LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE,
+// DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY
+// THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT
+// (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
+// OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
+
+#include <Python.h>
+
+#include <google/protobuf/pyext/message.h>
+#include <google/protobuf/proto_api.h>
+
+#include <google/protobuf/message_lite.h>
+
+namespace {
+
+// C++ API. Clients get at this via proto_api.h
+struct ApiImplementation : google::protobuf::python::PyProto_API {
+ const google::protobuf::Message*
+ GetMessagePointer(PyObject* msg) const override {
+ return google::protobuf::python::PyMessage_GetMessagePointer(msg);
+ }
+ google::protobuf::Message*
+ GetMutableMessagePointer(PyObject* msg) const override {
+ return google::protobuf::python::PyMessage_GetMutableMessagePointer(msg);
+ }
+};
+
+} // namespace
+
+static PyObject* GetPythonProto3PreserveUnknownsDefault(
+ PyObject* /*m*/, PyObject* /*args*/) {
+ if (google::protobuf::internal::GetProto3PreserveUnknownsDefault()) {
+ Py_RETURN_TRUE;
+ } else {
+ Py_RETURN_FALSE;
+ }
+}
+
+static PyObject* SetPythonProto3PreserveUnknownsDefault(
+ PyObject* /*m*/, PyObject* arg) {
+ if (!arg || !PyBool_Check(arg)) {
+ PyErr_SetString(
+ PyExc_TypeError,
+ "Argument to SetPythonProto3PreserveUnknownsDefault must be boolean");
+ return NULL;
+ }
+ google::protobuf::internal::SetProto3PreserveUnknownsDefault(PyObject_IsTrue(arg));
+ Py_RETURN_NONE;
+}
+
+static const char module_docstring[] =
+"python-proto2 is a module that can be used to enhance proto2 Python API\n"
+"performance.\n"
+"\n"
+"It provides access to the protocol buffers C++ reflection API that\n"
+"implements the basic protocol buffer functions.";
+
+static PyMethodDef ModuleMethods[] = {
+ {"SetAllowOversizeProtos",
+ (PyCFunction)google::protobuf::python::cmessage::SetAllowOversizeProtos,
+ METH_O, "Enable/disable oversize proto parsing."},
+ // DO NOT USE: For migration and testing only.
+ {"GetPythonProto3PreserveUnknownsDefault",
+ (PyCFunction)GetPythonProto3PreserveUnknownsDefault,
+ METH_NOARGS, "Get Proto3 preserve unknowns default."},
+ // DO NOT USE: For migration and testing only.
+ {"SetPythonProto3PreserveUnknownsDefault",
+ (PyCFunction)SetPythonProto3PreserveUnknownsDefault,
+ METH_O, "Enable/disable proto3 unknowns preservation."},
+ { NULL, NULL}
+};
+
+#if PY_MAJOR_VERSION >= 3
+static struct PyModuleDef _module = {
+ PyModuleDef_HEAD_INIT,
+ "_message",
+ module_docstring,
+ -1,
+ ModuleMethods, /* m_methods */
+ NULL,
+ NULL,
+ NULL,
+ NULL
+};
+#define INITFUNC PyInit__message
+#define INITFUNC_ERRORVAL NULL
+#else // Python 2
+#define INITFUNC init_message
+#define INITFUNC_ERRORVAL
+#endif
+
+extern "C" {
+ PyMODINIT_FUNC INITFUNC(void) {
+ PyObject* m;
+#if PY_MAJOR_VERSION >= 3
+ m = PyModule_Create(&_module);
+#else
+ m = Py_InitModule3("_message", ModuleMethods,
+ module_docstring);
+#endif
+ if (m == NULL) {
+ return INITFUNC_ERRORVAL;
+ }
+
+ if (!google::protobuf::python::InitProto2MessageModule(m)) {
+ Py_DECREF(m);
+ return INITFUNC_ERRORVAL;
+ }
+
+#if PY_MAJOR_VERSION >= 3
+ return m;
+#endif
+ }
+}
diff --git a/python/google/protobuf/pyext/python.proto b/python/google/protobuf/pyext/python.proto
index cce645d7..2e50df74 100644
--- a/python/google/protobuf/pyext/python.proto
+++ b/python/google/protobuf/pyext/python.proto
@@ -58,11 +58,11 @@ message ForeignMessage {
repeated int32 d = 2;
}
-message TestAllExtensions {
+message TestAllExtensions { // extension begin
extensions 1 to max;
-}
+} // extension end
-extend TestAllExtensions {
+extend TestAllExtensions { // extension begin
optional TestAllTypes.NestedMessage optional_nested_message_extension = 1;
repeated TestAllTypes.NestedMessage repeated_nested_message_extension = 2;
-}
+} // extension end
diff --git a/python/google/protobuf/pyext/repeated_composite_container.cc b/python/google/protobuf/pyext/repeated_composite_container.cc
index b01123b4..5874d5de 100644
--- a/python/google/protobuf/pyext/repeated_composite_container.cc
+++ b/python/google/protobuf/pyext/repeated_composite_container.cc
@@ -34,9 +34,6 @@
#include <google/protobuf/pyext/repeated_composite_container.h>
#include <memory>
-#ifndef _SHARED_PTR_H
-#include <google/protobuf/stubs/shared_ptr.h>
-#endif
#include <google/protobuf/stubs/logging.h>
#include <google/protobuf/stubs/common.h>
@@ -46,7 +43,9 @@
#include <google/protobuf/pyext/descriptor.h>
#include <google/protobuf/pyext/descriptor_pool.h>
#include <google/protobuf/pyext/message.h>
+#include <google/protobuf/pyext/message_factory.h>
#include <google/protobuf/pyext/scoped_pyobject_ptr.h>
+#include <google/protobuf/reflection.h>
#if PY_MAJOR_VERSION >= 3
#define PyInt_Check PyLong_Check
@@ -79,7 +78,10 @@ namespace repeated_composite_container {
// ---------------------------------------------------------------------
// len()
-static Py_ssize_t Length(RepeatedCompositeContainer* self) {
+static Py_ssize_t Length(PyObject* pself) {
+ RepeatedCompositeContainer* self =
+ reinterpret_cast<RepeatedCompositeContainer*>(pself);
+
Message* message = self->message;
if (message != NULL) {
return message->GetReflection()->FieldSize(*message,
@@ -100,15 +102,14 @@ static int UpdateChildMessages(RepeatedCompositeContainer* self) {
// A MergeFrom on a parent message could have caused extra messages to be
// added in the underlying protobuf so add them to our list. They can never
// be removed in such a way so there's no need to worry about that.
- Py_ssize_t message_length = Length(self);
+ Py_ssize_t message_length = Length(reinterpret_cast<PyObject*>(self));
Py_ssize_t child_length = PyList_GET_SIZE(self->child_messages);
Message* message = self->message;
const Reflection* reflection = message->GetReflection();
for (Py_ssize_t i = child_length; i < message_length; ++i) {
const Message& sub_message = reflection->GetRepeatedMessage(
*(self->message), self->parent_field_descriptor, i);
- CMessage* cmsg = cmessage::NewEmptyMessage(self->subclass_init,
- sub_message.GetDescriptor());
+ CMessage* cmsg = cmessage::NewEmptyMessage(self->child_message_class);
ScopedPyObjectPtr py_cmsg(reinterpret_cast<PyObject*>(cmsg));
if (cmsg == NULL) {
return -1;
@@ -137,18 +138,20 @@ static PyObject* AddToAttached(RepeatedCompositeContainer* self,
if (cmessage::AssureWritable(self->parent) == -1)
return NULL;
Message* message = self->message;
+
Message* sub_message =
- message->GetReflection()->AddMessage(message,
- self->parent_field_descriptor);
- CMessage* cmsg = cmessage::NewEmptyMessage(self->subclass_init,
- sub_message->GetDescriptor());
+ message->GetReflection()->AddMessage(
+ message,
+ self->parent_field_descriptor,
+ self->child_message_class->py_message_factory->message_factory);
+ CMessage* cmsg = cmessage::NewEmptyMessage(self->child_message_class);
if (cmsg == NULL)
return NULL;
cmsg->owner = self->owner;
cmsg->message = sub_message;
cmsg->parent = self->parent;
- if (cmessage::InitAttributes(cmsg, kwargs) < 0) {
+ if (cmessage::InitAttributes(cmsg, args, kwargs) < 0) {
Py_DECREF(cmsg);
return NULL;
}
@@ -168,7 +171,7 @@ static PyObject* AddToReleased(RepeatedCompositeContainer* self,
// Create a new Message detached from the rest.
PyObject* py_cmsg = PyEval_CallObjectWithKeywords(
- self->subclass_init, NULL, kwargs);
+ self->child_message_class->AsPyObject(), args, kwargs);
if (py_cmsg == NULL)
return NULL;
@@ -188,6 +191,10 @@ PyObject* Add(RepeatedCompositeContainer* self,
return AddToAttached(self, args, kwargs);
}
+static PyObject* AddMethod(PyObject* self, PyObject* args, PyObject* kwargs) {
+ return Add(reinterpret_cast<RepeatedCompositeContainer*>(self), args, kwargs);
+}
+
// ---------------------------------------------------------------------
// extend()
@@ -223,6 +230,10 @@ PyObject* Extend(RepeatedCompositeContainer* self, PyObject* value) {
Py_RETURN_NONE;
}
+static PyObject* ExtendMethod(PyObject* self, PyObject* value) {
+ return Extend(reinterpret_cast<RepeatedCompositeContainer*>(self), value);
+}
+
PyObject* MergeFrom(RepeatedCompositeContainer* self, PyObject* other) {
if (UpdateChildMessages(self) < 0) {
return NULL;
@@ -230,6 +241,10 @@ PyObject* MergeFrom(RepeatedCompositeContainer* self, PyObject* other) {
return Extend(self, other);
}
+static PyObject* MergeFromMethod(PyObject* self, PyObject* other) {
+ return MergeFrom(reinterpret_cast<RepeatedCompositeContainer*>(self), other);
+}
+
PyObject* Subscript(RepeatedCompositeContainer* self, PyObject* slice) {
if (UpdateChildMessages(self) < 0) {
return NULL;
@@ -239,6 +254,10 @@ PyObject* Subscript(RepeatedCompositeContainer* self, PyObject* slice) {
return PyObject_GetItem(self->child_messages, slice);
}
+static PyObject* SubscriptMethod(PyObject* self, PyObject* slice) {
+ return Subscript(reinterpret_cast<RepeatedCompositeContainer*>(self), slice);
+}
+
int AssignSubscript(RepeatedCompositeContainer* self,
PyObject* slice,
PyObject* value) {
@@ -262,15 +281,16 @@ int AssignSubscript(RepeatedCompositeContainer* self,
Py_ssize_t from;
Py_ssize_t to;
Py_ssize_t step;
- Py_ssize_t length = Length(self);
+ Py_ssize_t length = Length(reinterpret_cast<PyObject*>(self));
Py_ssize_t slicelength;
if (PySlice_Check(slice)) {
#if PY_MAJOR_VERSION >= 3
if (PySlice_GetIndicesEx(slice,
+ length, &from, &to, &step, &slicelength) == -1) {
#else
if (PySlice_GetIndicesEx(reinterpret_cast<PySliceObject*>(slice),
-#endif
length, &from, &to, &step, &slicelength) == -1) {
+#endif
return -1;
}
return PySequence_DelSlice(self->child_messages, from, to);
@@ -286,7 +306,16 @@ int AssignSubscript(RepeatedCompositeContainer* self,
return 0;
}
-static PyObject* Remove(RepeatedCompositeContainer* self, PyObject* value) {
+static int AssignSubscriptMethod(PyObject* self, PyObject* slice,
+ PyObject* value) {
+ return AssignSubscript(reinterpret_cast<RepeatedCompositeContainer*>(self),
+ slice, value);
+}
+
+static PyObject* Remove(PyObject* pself, PyObject* value) {
+ RepeatedCompositeContainer* self =
+ reinterpret_cast<RepeatedCompositeContainer*>(pself);
+
if (UpdateChildMessages(self) < 0) {
return NULL;
}
@@ -301,9 +330,10 @@ static PyObject* Remove(RepeatedCompositeContainer* self, PyObject* value) {
Py_RETURN_NONE;
}
-static PyObject* RichCompare(RepeatedCompositeContainer* self,
- PyObject* other,
- int opid) {
+static PyObject* RichCompare(PyObject* pself, PyObject* other, int opid) {
+ RepeatedCompositeContainer* self =
+ reinterpret_cast<RepeatedCompositeContainer*>(pself);
+
if (UpdateChildMessages(self) < 0) {
return NULL;
}
@@ -336,6 +366,19 @@ static PyObject* RichCompare(RepeatedCompositeContainer* self,
}
}
+static PyObject* ToStr(PyObject* pself) {
+ ScopedPyObjectPtr full_slice(PySlice_New(NULL, NULL, NULL));
+ if (full_slice == NULL) {
+ return NULL;
+ }
+ ScopedPyObjectPtr list(Subscript(
+ reinterpret_cast<RepeatedCompositeContainer*>(pself), full_slice.get()));
+ if (list == NULL) {
+ return NULL;
+ }
+ return PyObject_Repr(list.get());
+}
+
// ---------------------------------------------------------------------
// sort()
@@ -343,7 +386,7 @@ static void ReorderAttached(RepeatedCompositeContainer* self) {
Message* message = self->message;
const Reflection* reflection = message->GetReflection();
const FieldDescriptor* descriptor = self->parent_field_descriptor;
- const Py_ssize_t length = Length(self);
+ const Py_ssize_t length = Length(reinterpret_cast<PyObject*>(self));
// Since Python protobuf objects are never arena-allocated, adding and
// removing message pointers to the underlying array is just updating
@@ -366,7 +409,7 @@ static int SortPythonMessages(RepeatedCompositeContainer* self,
ScopedPyObjectPtr m(PyObject_GetAttrString(self->child_messages, "sort"));
if (m == NULL)
return -1;
- if (PyObject_Call(m.get(), args, kwds) == NULL)
+ if (ScopedPyObjectPtr(PyObject_Call(m.get(), args, kwds)) == NULL)
return -1;
if (self->message != NULL) {
ReorderAttached(self);
@@ -374,9 +417,10 @@ static int SortPythonMessages(RepeatedCompositeContainer* self,
return 0;
}
-static PyObject* Sort(RepeatedCompositeContainer* self,
- PyObject* args,
- PyObject* kwds) {
+static PyObject* Sort(PyObject* pself, PyObject* args, PyObject* kwds) {
+ RepeatedCompositeContainer* self =
+ reinterpret_cast<RepeatedCompositeContainer*>(pself);
+
// Support the old sort_function argument for backwards
// compatibility.
if (kwds != NULL) {
@@ -400,11 +444,14 @@ static PyObject* Sort(RepeatedCompositeContainer* self,
// ---------------------------------------------------------------------
-static PyObject* Item(RepeatedCompositeContainer* self, Py_ssize_t index) {
+static PyObject* Item(PyObject* pself, Py_ssize_t index) {
+ RepeatedCompositeContainer* self =
+ reinterpret_cast<RepeatedCompositeContainer*>(pself);
+
if (UpdateChildMessages(self) < 0) {
return NULL;
}
- Py_ssize_t length = Length(self);
+ Py_ssize_t length = Length(pself);
if (index < 0) {
index = length + index;
}
@@ -416,17 +463,17 @@ static PyObject* Item(RepeatedCompositeContainer* self, Py_ssize_t index) {
return item;
}
-static PyObject* Pop(RepeatedCompositeContainer* self,
- PyObject* args) {
+static PyObject* Pop(PyObject* pself, PyObject* args) {
+ RepeatedCompositeContainer* self =
+ reinterpret_cast<RepeatedCompositeContainer*>(pself);
+
Py_ssize_t index = -1;
if (!PyArg_ParseTuple(args, "|n", &index)) {
return NULL;
}
- PyObject* item = Item(self, index);
+ PyObject* item = Item(pself, index);
if (item == NULL) {
- PyErr_Format(PyExc_IndexError,
- "list index (%zd) out of range",
- index);
+ PyErr_Format(PyExc_IndexError, "list index (%zd) out of range", index);
return NULL;
}
ScopedPyObjectPtr py_index(PyLong_FromSsize_t(index));
@@ -444,7 +491,7 @@ void ReleaseLastTo(CMessage* parent,
GOOGLE_CHECK_NOTNULL(field);
GOOGLE_CHECK_NOTNULL(target);
- shared_ptr<Message> released_message(
+ CMessage::OwnerRef released_message(
parent->message->GetReflection()->ReleaseLast(parent->message, field));
// TODO(tibell): Deal with proto1.
@@ -487,8 +534,37 @@ int Release(RepeatedCompositeContainer* self) {
return 0;
}
+PyObject* DeepCopy(PyObject* pself, PyObject* arg) {
+ RepeatedCompositeContainer* self =
+ reinterpret_cast<RepeatedCompositeContainer*>(pself);
+
+ ScopedPyObjectPtr cloneObj(
+ PyType_GenericAlloc(&RepeatedCompositeContainer_Type, 0));
+ if (cloneObj == NULL) {
+ return NULL;
+ }
+ RepeatedCompositeContainer* clone =
+ reinterpret_cast<RepeatedCompositeContainer*>(cloneObj.get());
+
+ Message* new_message = self->message->New();
+ clone->parent = NULL;
+ clone->parent_field_descriptor = self->parent_field_descriptor;
+ clone->message = new_message;
+ clone->owner.reset(new_message);
+ Py_INCREF(self->child_message_class);
+ clone->child_message_class = self->child_message_class;
+ clone->child_messages = PyList_New(0);
+
+ new_message->GetReflection()
+ ->GetMutableRepeatedFieldRef<Message>(new_message,
+ self->parent_field_descriptor)
+ .MergeFrom(self->message->GetReflection()->GetRepeatedFieldRef<Message>(
+ *self->message, self->parent_field_descriptor));
+ return cloneObj.release();
+}
+
int SetOwner(RepeatedCompositeContainer* self,
- const shared_ptr<Message>& new_owner) {
+ const CMessage::OwnerRef& new_owner) {
GOOGLE_CHECK_ATTACHED(self);
self->owner = new_owner;
@@ -506,7 +582,7 @@ int SetOwner(RepeatedCompositeContainer* self,
PyObject *NewContainer(
CMessage* parent,
const FieldDescriptor* parent_field_descriptor,
- PyObject *concrete_class) {
+ CMessageClass* concrete_class) {
if (!CheckFieldBelongsToMessage(parent_field_descriptor, parent->message)) {
return NULL;
}
@@ -523,47 +599,52 @@ PyObject *NewContainer(
self->parent_field_descriptor = parent_field_descriptor;
self->owner = parent->owner;
Py_INCREF(concrete_class);
- self->subclass_init = concrete_class;
+ self->child_message_class = concrete_class;
self->child_messages = PyList_New(0);
return reinterpret_cast<PyObject*>(self);
}
-static void Dealloc(RepeatedCompositeContainer* self) {
+static void Dealloc(PyObject* pself) {
+ RepeatedCompositeContainer* self =
+ reinterpret_cast<RepeatedCompositeContainer*>(pself);
+
Py_CLEAR(self->child_messages);
- Py_CLEAR(self->subclass_init);
+ Py_CLEAR(self->child_message_class);
// TODO(tibell): Do we need to call delete on these objects to make
// sure their destructors are called?
self->owner.reset();
- Py_TYPE(self)->tp_free(reinterpret_cast<PyObject*>(self));
+ Py_TYPE(self)->tp_free(pself);
}
static PySequenceMethods SqMethods = {
- (lenfunc)Length, /* sq_length */
- 0, /* sq_concat */
- 0, /* sq_repeat */
- (ssizeargfunc)Item /* sq_item */
+ Length, /* sq_length */
+ 0, /* sq_concat */
+ 0, /* sq_repeat */
+ Item /* sq_item */
};
static PyMappingMethods MpMethods = {
- (lenfunc)Length, /* mp_length */
- (binaryfunc)Subscript, /* mp_subscript */
- (objobjargproc)AssignSubscript,/* mp_ass_subscript */
+ Length, /* mp_length */
+ SubscriptMethod, /* mp_subscript */
+ AssignSubscriptMethod, /* mp_ass_subscript */
};
static PyMethodDef Methods[] = {
- { "add", (PyCFunction) Add, METH_VARARGS | METH_KEYWORDS,
+ { "__deepcopy__", DeepCopy, METH_VARARGS,
+ "Makes a deep copy of the class." },
+ { "add", (PyCFunction)AddMethod, METH_VARARGS | METH_KEYWORDS,
"Adds an object to the repeated container." },
- { "extend", (PyCFunction) Extend, METH_O,
+ { "extend", ExtendMethod, METH_O,
"Adds objects to the repeated container." },
- { "pop", (PyCFunction)Pop, METH_VARARGS,
+ { "pop", Pop, METH_VARARGS,
"Removes an object from the repeated container and returns it." },
- { "remove", (PyCFunction) Remove, METH_O,
+ { "remove", Remove, METH_O,
"Removes an object from the repeated container." },
- { "sort", (PyCFunction) Sort, METH_VARARGS | METH_KEYWORDS,
+ { "sort", (PyCFunction)Sort, METH_VARARGS | METH_KEYWORDS,
"Sorts the repeated container." },
- { "MergeFrom", (PyCFunction) MergeFrom, METH_O,
+ { "MergeFrom", MergeFromMethod, METH_O,
"Adds objects to the repeated container." },
{ NULL, NULL }
};
@@ -575,12 +656,12 @@ PyTypeObject RepeatedCompositeContainer_Type = {
FULL_MODULE_NAME ".RepeatedCompositeContainer", // tp_name
sizeof(RepeatedCompositeContainer), // tp_basicsize
0, // tp_itemsize
- (destructor)repeated_composite_container::Dealloc, // tp_dealloc
+ repeated_composite_container::Dealloc, // tp_dealloc
0, // tp_print
0, // tp_getattr
0, // tp_setattr
0, // tp_compare
- 0, // tp_repr
+ repeated_composite_container::ToStr, // tp_repr
0, // tp_as_number
&repeated_composite_container::SqMethods, // tp_as_sequence
&repeated_composite_container::MpMethods, // tp_as_mapping
@@ -594,7 +675,7 @@ PyTypeObject RepeatedCompositeContainer_Type = {
"A Repeated scalar container", // tp_doc
0, // tp_traverse
0, // tp_clear
- (richcmpfunc)repeated_composite_container::RichCompare, // tp_richcompare
+ repeated_composite_container::RichCompare, // tp_richcompare
0, // tp_weaklistoffset
0, // tp_iter
0, // tp_iternext
diff --git a/python/google/protobuf/pyext/repeated_composite_container.h b/python/google/protobuf/pyext/repeated_composite_container.h
index 58d37b02..e5e946aa 100644
--- a/python/google/protobuf/pyext/repeated_composite_container.h
+++ b/python/google/protobuf/pyext/repeated_composite_container.h
@@ -37,27 +37,20 @@
#include <Python.h>
#include <memory>
-#ifndef _SHARED_PTR_H
-#include <google/protobuf/stubs/shared_ptr.h>
-#endif
#include <string>
#include <vector>
+#include <google/protobuf/pyext/message.h>
+
namespace google {
namespace protobuf {
class FieldDescriptor;
class Message;
-#ifdef _SHARED_PTR_H
-using std::shared_ptr;
-#else
-using internal::shared_ptr;
-#endif
-
namespace python {
-struct CMessage;
+struct CMessageClass;
// A RepeatedCompositeContainer can be in one of two states: attached
// or released.
@@ -76,7 +69,7 @@ typedef struct RepeatedCompositeContainer {
// proto tree. Every Python RepeatedCompositeContainer holds a
// reference to it in order to keep it alive as long as there's a
// Python object that references any part of the tree.
- shared_ptr<Message> owner;
+ CMessage::OwnerRef owner;
// Weak reference to parent object. May be NULL. Used to make sure
// the parent is writable before modifying the
@@ -94,8 +87,8 @@ typedef struct RepeatedCompositeContainer {
// calling Clear() or ClearField() on the parent.
Message* message;
- // A callable that is used to create new child messages.
- PyObject* subclass_init;
+ // The type used to create new child messages.
+ CMessageClass* child_message_class;
// A list of child messages.
PyObject* child_messages;
@@ -110,7 +103,7 @@ namespace repeated_composite_container {
PyObject *NewContainer(
CMessage* parent,
const FieldDescriptor* parent_field_descriptor,
- PyObject *concrete_class);
+ CMessageClass *child_message_class);
// Appends a new CMessage to the container and returns it. The
// CMessage is initialized using the content of kwargs.
@@ -147,11 +140,6 @@ int AssignSubscript(RepeatedCompositeContainer* self,
PyObject* slice,
PyObject* value);
-// Releases the messages in the container to the given message.
-//
-// Returns 0 on success, -1 on failure.
-int ReleaseToMessage(RepeatedCompositeContainer* self, Message* new_message);
-
// Releases the messages in the container to a new message.
//
// Returns 0 on success, -1 on failure.
@@ -159,7 +147,7 @@ int Release(RepeatedCompositeContainer* self);
// Returns 0 on success, -1 on failure.
int SetOwner(RepeatedCompositeContainer* self,
- const shared_ptr<Message>& new_owner);
+ const CMessage::OwnerRef& new_owner);
// Removes the last element of the repeated message field 'field' on
// the Message 'parent', and transfers the ownership of the released
diff --git a/python/google/protobuf/pyext/repeated_scalar_container.cc b/python/google/protobuf/pyext/repeated_scalar_container.cc
index 95da85f8..de3b6e14 100644
--- a/python/google/protobuf/pyext/repeated_scalar_container.cc
+++ b/python/google/protobuf/pyext/repeated_scalar_container.cc
@@ -34,9 +34,6 @@
#include <google/protobuf/pyext/repeated_scalar_container.h>
#include <memory>
-#ifndef _SHARED_PTR_H
-#include <google/protobuf/stubs/shared_ptr.h>
-#endif
#include <google/protobuf/stubs/common.h>
#include <google/protobuf/stubs/logging.h>
@@ -77,15 +74,18 @@ static int InternalAssignRepeatedField(
return 0;
}
-static Py_ssize_t Len(RepeatedScalarContainer* self) {
+static Py_ssize_t Len(PyObject* pself) {
+ RepeatedScalarContainer* self =
+ reinterpret_cast<RepeatedScalarContainer*>(pself);
Message* message = self->message;
return message->GetReflection()->FieldSize(*message,
self->parent_field_descriptor);
}
-static int AssignItem(RepeatedScalarContainer* self,
- Py_ssize_t index,
- PyObject* arg) {
+static int AssignItem(PyObject* pself, Py_ssize_t index, PyObject* arg) {
+ RepeatedScalarContainer* self =
+ reinterpret_cast<RepeatedScalarContainer*>(pself);
+
cmessage::AssureWritable(self->parent);
Message* message = self->message;
const FieldDescriptor* field_descriptor = self->parent_field_descriptor;
@@ -188,7 +188,10 @@ static int AssignItem(RepeatedScalarContainer* self,
return 0;
}
-static PyObject* Item(RepeatedScalarContainer* self, Py_ssize_t index) {
+static PyObject* Item(PyObject* pself, Py_ssize_t index) {
+ RepeatedScalarContainer* self =
+ reinterpret_cast<RepeatedScalarContainer*>(pself);
+
Message* message = self->message;
const FieldDescriptor* field_descriptor = self->parent_field_descriptor;
const Reflection* reflection = message->GetReflection();
@@ -256,27 +259,12 @@ static PyObject* Item(RepeatedScalarContainer* self, Py_ssize_t index) {
break;
}
case FieldDescriptor::CPPTYPE_STRING: {
- string value = reflection->GetRepeatedString(
- *message, field_descriptor, index);
+ string scratch;
+ const string& value = reflection->GetRepeatedStringReference(
+ *message, field_descriptor, index, &scratch);
result = ToStringObject(field_descriptor, value);
break;
}
- case FieldDescriptor::CPPTYPE_MESSAGE: {
- PyObject* py_cmsg = PyObject_CallObject(reinterpret_cast<PyObject*>(
- &CMessage_Type), NULL);
- if (py_cmsg == NULL) {
- return NULL;
- }
- CMessage* cmsg = reinterpret_cast<CMessage*>(py_cmsg);
- const Message& msg = reflection->GetRepeatedMessage(
- *message, field_descriptor, index);
- cmsg->owner = self->owner;
- cmsg->parent = self->parent;
- cmsg->message = const_cast<Message*>(&msg);
- cmsg->read_only = false;
- result = reinterpret_cast<PyObject*>(py_cmsg);
- break;
- }
default:
PyErr_Format(
PyExc_SystemError,
@@ -287,7 +275,7 @@ static PyObject* Item(RepeatedScalarContainer* self, Py_ssize_t index) {
return result;
}
-static PyObject* Subscript(RepeatedScalarContainer* self, PyObject* slice) {
+static PyObject* Subscript(PyObject* pself, PyObject* slice) {
Py_ssize_t from;
Py_ssize_t to;
Py_ssize_t step;
@@ -302,13 +290,14 @@ static PyObject* Subscript(RepeatedScalarContainer* self, PyObject* slice) {
if (PyLong_Check(slice)) {
from = to = PyLong_AsLong(slice);
} else if (PySlice_Check(slice)) {
- length = Len(self);
+ length = Len(pself);
#if PY_MAJOR_VERSION >= 3
if (PySlice_GetIndicesEx(slice,
+ length, &from, &to, &step, &slicelength) == -1) {
#else
if (PySlice_GetIndicesEx(reinterpret_cast<PySliceObject*>(slice),
-#endif
length, &from, &to, &step, &slicelength) == -1) {
+#endif
return NULL;
}
return_list = true;
@@ -318,7 +307,7 @@ static PyObject* Subscript(RepeatedScalarContainer* self, PyObject* slice) {
}
if (!return_list) {
- return Item(self, from);
+ return Item(pself, from);
}
PyObject* list = PyList_New(0);
@@ -333,7 +322,7 @@ static PyObject* Subscript(RepeatedScalarContainer* self, PyObject* slice) {
if (index < 0 || index >= length) {
break;
}
- ScopedPyObjectPtr s(Item(self, index));
+ ScopedPyObjectPtr s(Item(pself, index));
PyList_Append(list, s.get());
}
} else {
@@ -344,7 +333,7 @@ static PyObject* Subscript(RepeatedScalarContainer* self, PyObject* slice) {
if (index < 0 || index >= length) {
break;
}
- ScopedPyObjectPtr s(Item(self, index));
+ ScopedPyObjectPtr s(Item(pself, index));
PyList_Append(list, s.get());
}
}
@@ -431,9 +420,14 @@ PyObject* Append(RepeatedScalarContainer* self, PyObject* item) {
Py_RETURN_NONE;
}
-static int AssSubscript(RepeatedScalarContainer* self,
- PyObject* slice,
- PyObject* value) {
+static PyObject* AppendMethod(PyObject* self, PyObject* item) {
+ return Append(reinterpret_cast<RepeatedScalarContainer*>(self), item);
+}
+
+static int AssSubscript(PyObject* pself, PyObject* slice, PyObject* value) {
+ RepeatedScalarContainer* self =
+ reinterpret_cast<RepeatedScalarContainer*>(pself);
+
Py_ssize_t from;
Py_ssize_t to;
Py_ssize_t step;
@@ -449,7 +443,7 @@ static int AssSubscript(RepeatedScalarContainer* self,
#if PY_MAJOR_VERSION < 3
if (PyInt_Check(slice)) {
from = to = PyInt_AsLong(slice);
- } else
+ } else // NOLINT
#endif
if (PyLong_Check(slice)) {
from = to = PyLong_AsLong(slice);
@@ -458,10 +452,11 @@ static int AssSubscript(RepeatedScalarContainer* self,
length = reflection->FieldSize(*message, field_descriptor);
#if PY_MAJOR_VERSION >= 3
if (PySlice_GetIndicesEx(slice,
+ length, &from, &to, &step, &slicelength) == -1) {
#else
if (PySlice_GetIndicesEx(reinterpret_cast<PySliceObject*>(slice),
-#endif
length, &from, &to, &step, &slicelength) == -1) {
+#endif
return -1;
}
create_list = true;
@@ -476,14 +471,14 @@ static int AssSubscript(RepeatedScalarContainer* self,
}
if (!create_list) {
- return AssignItem(self, from, value);
+ return AssignItem(pself, from, value);
}
ScopedPyObjectPtr full_slice(PySlice_New(NULL, NULL, NULL));
if (full_slice == NULL) {
return -1;
}
- ScopedPyObjectPtr new_list(Subscript(self, full_slice.get()));
+ ScopedPyObjectPtr new_list(Subscript(pself, full_slice.get()));
if (new_list == NULL) {
return -1;
}
@@ -522,14 +517,17 @@ PyObject* Extend(RepeatedScalarContainer* self, PyObject* value) {
Py_RETURN_NONE;
}
-static PyObject* Insert(RepeatedScalarContainer* self, PyObject* args) {
+static PyObject* Insert(PyObject* pself, PyObject* args) {
+ RepeatedScalarContainer* self =
+ reinterpret_cast<RepeatedScalarContainer*>(pself);
+
Py_ssize_t index;
PyObject* value;
if (!PyArg_ParseTuple(args, "lO", &index, &value)) {
return NULL;
}
ScopedPyObjectPtr full_slice(PySlice_New(NULL, NULL, NULL));
- ScopedPyObjectPtr new_list(Subscript(self, full_slice.get()));
+ ScopedPyObjectPtr new_list(Subscript(pself, full_slice.get()));
if (PyList_Insert(new_list.get(), index, value) < 0) {
return NULL;
}
@@ -540,10 +538,10 @@ static PyObject* Insert(RepeatedScalarContainer* self, PyObject* args) {
Py_RETURN_NONE;
}
-static PyObject* Remove(RepeatedScalarContainer* self, PyObject* value) {
+static PyObject* Remove(PyObject* pself, PyObject* value) {
Py_ssize_t match_index = -1;
- for (Py_ssize_t i = 0; i < Len(self); ++i) {
- ScopedPyObjectPtr elem(Item(self, i));
+ for (Py_ssize_t i = 0; i < Len(pself); ++i) {
+ ScopedPyObjectPtr elem(Item(pself, i));
if (PyObject_RichCompareBool(elem.get(), value, Py_EQ)) {
match_index = i;
break;
@@ -553,15 +551,17 @@ static PyObject* Remove(RepeatedScalarContainer* self, PyObject* value) {
PyErr_SetString(PyExc_ValueError, "remove(x): x not in container");
return NULL;
}
- if (AssignItem(self, match_index, NULL) < 0) {
+ if (AssignItem(pself, match_index, NULL) < 0) {
return NULL;
}
Py_RETURN_NONE;
}
-static PyObject* RichCompare(RepeatedScalarContainer* self,
- PyObject* other,
- int opid) {
+static PyObject* ExtendMethod(PyObject* self, PyObject* value) {
+ return Extend(reinterpret_cast<RepeatedScalarContainer*>(self), value);
+}
+
+static PyObject* RichCompare(PyObject* pself, PyObject* other, int opid) {
if (opid != Py_EQ && opid != Py_NE) {
Py_INCREF(Py_NotImplemented);
return Py_NotImplemented;
@@ -578,28 +578,25 @@ static PyObject* RichCompare(RepeatedScalarContainer* self,
ScopedPyObjectPtr other_list_deleter;
if (PyObject_TypeCheck(other, &RepeatedScalarContainer_Type)) {
- other_list_deleter.reset(Subscript(
- reinterpret_cast<RepeatedScalarContainer*>(other), full_slice.get()));
+ other_list_deleter.reset(Subscript(other, full_slice.get()));
other = other_list_deleter.get();
}
- ScopedPyObjectPtr list(Subscript(self, full_slice.get()));
+ ScopedPyObjectPtr list(Subscript(pself, full_slice.get()));
if (list == NULL) {
return NULL;
}
return PyObject_RichCompare(list.get(), other, opid);
}
-PyObject* Reduce(RepeatedScalarContainer* unused_self) {
+PyObject* Reduce(PyObject* unused_self, PyObject* unused_other) {
PyErr_Format(
PickleError_class,
"can't pickle repeated message fields, convert to list first");
return NULL;
}
-static PyObject* Sort(RepeatedScalarContainer* self,
- PyObject* args,
- PyObject* kwds) {
+static PyObject* Sort(PyObject* pself, PyObject* args, PyObject* kwds) {
// Support the old sort_function argument for backwards
// compatibility.
if (kwds != NULL) {
@@ -618,7 +615,7 @@ static PyObject* Sort(RepeatedScalarContainer* self,
if (full_slice == NULL) {
return NULL;
}
- ScopedPyObjectPtr list(Subscript(self, full_slice.get()));
+ ScopedPyObjectPtr list(Subscript(pself, full_slice.get()));
if (list == NULL) {
return NULL;
}
@@ -630,32 +627,42 @@ static PyObject* Sort(RepeatedScalarContainer* self,
if (res == NULL) {
return NULL;
}
- int ret = InternalAssignRepeatedField(self, list.get());
+ int ret = InternalAssignRepeatedField(
+ reinterpret_cast<RepeatedScalarContainer*>(pself), list.get());
if (ret < 0) {
return NULL;
}
Py_RETURN_NONE;
}
-static PyObject* Pop(RepeatedScalarContainer* self,
- PyObject* args) {
+static PyObject* Pop(PyObject* pself, PyObject* args) {
Py_ssize_t index = -1;
if (!PyArg_ParseTuple(args, "|n", &index)) {
return NULL;
}
- PyObject* item = Item(self, index);
+ PyObject* item = Item(pself, index);
if (item == NULL) {
- PyErr_Format(PyExc_IndexError,
- "list index (%zd) out of range",
- index);
+ PyErr_Format(PyExc_IndexError, "list index (%zd) out of range", index);
return NULL;
}
- if (AssignItem(self, index, NULL) < 0) {
+ if (AssignItem(pself, index, NULL) < 0) {
return NULL;
}
return item;
}
+static PyObject* ToStr(PyObject* pself) {
+ ScopedPyObjectPtr full_slice(PySlice_New(NULL, NULL, NULL));
+ if (full_slice == NULL) {
+ return NULL;
+ }
+ ScopedPyObjectPtr list(Subscript(pself, full_slice.get()));
+ if (list == NULL) {
+ return NULL;
+ }
+ return PyObject_Repr(list.get());
+}
+
// The private constructor of RepeatedScalarContainer objects.
PyObject *NewContainer(
CMessage* parent, const FieldDescriptor* parent_field_descriptor) {
@@ -688,7 +695,8 @@ static int InitializeAndCopyToParentContainer(
if (full_slice == NULL) {
return -1;
}
- ScopedPyObjectPtr values(Subscript(from, full_slice.get()));
+ ScopedPyObjectPtr values(
+ Subscript(reinterpret_cast<PyObject*>(from), full_slice.get()));
if (values == NULL) {
return -1;
}
@@ -707,7 +715,10 @@ int Release(RepeatedScalarContainer* self) {
return InitializeAndCopyToParentContainer(self, self);
}
-PyObject* DeepCopy(RepeatedScalarContainer* self, PyObject* arg) {
+PyObject* DeepCopy(PyObject* pself, PyObject* arg) {
+ RepeatedScalarContainer* self =
+ reinterpret_cast<RepeatedScalarContainer*>(pself);
+
RepeatedScalarContainer* clone = reinterpret_cast<RepeatedScalarContainer*>(
PyType_GenericAlloc(&RepeatedScalarContainer_Type, 0));
if (clone == NULL) {
@@ -721,45 +732,47 @@ PyObject* DeepCopy(RepeatedScalarContainer* self, PyObject* arg) {
return reinterpret_cast<PyObject*>(clone);
}
-static void Dealloc(RepeatedScalarContainer* self) {
+static void Dealloc(PyObject* pself) {
+ RepeatedScalarContainer* self =
+ reinterpret_cast<RepeatedScalarContainer*>(pself);
self->owner.reset();
- Py_TYPE(self)->tp_free(reinterpret_cast<PyObject*>(self));
+ Py_TYPE(self)->tp_free(pself);
}
void SetOwner(RepeatedScalarContainer* self,
- const shared_ptr<Message>& new_owner) {
+ const CMessage::OwnerRef& new_owner) {
self->owner = new_owner;
}
static PySequenceMethods SqMethods = {
- (lenfunc)Len, /* sq_length */
- 0, /* sq_concat */
- 0, /* sq_repeat */
- (ssizeargfunc)Item, /* sq_item */
- 0, /* sq_slice */
- (ssizeobjargproc)AssignItem /* sq_ass_item */
+ Len, /* sq_length */
+ 0, /* sq_concat */
+ 0, /* sq_repeat */
+ Item, /* sq_item */
+ 0, /* sq_slice */
+ AssignItem /* sq_ass_item */
};
static PyMappingMethods MpMethods = {
- (lenfunc)Len, /* mp_length */
- (binaryfunc)Subscript, /* mp_subscript */
- (objobjargproc)AssSubscript, /* mp_ass_subscript */
+ Len, /* mp_length */
+ Subscript, /* mp_subscript */
+ AssSubscript, /* mp_ass_subscript */
};
static PyMethodDef Methods[] = {
- { "__deepcopy__", (PyCFunction)DeepCopy, METH_VARARGS,
+ { "__deepcopy__", DeepCopy, METH_VARARGS,
"Makes a deep copy of the class." },
- { "__reduce__", (PyCFunction)Reduce, METH_NOARGS,
+ { "__reduce__", Reduce, METH_NOARGS,
"Outputs picklable representation of the repeated field." },
- { "append", (PyCFunction)Append, METH_O,
+ { "append", AppendMethod, METH_O,
"Appends an object to the repeated container." },
- { "extend", (PyCFunction)Extend, METH_O,
- "Appends objects to the repeated container." },
- { "insert", (PyCFunction)Insert, METH_VARARGS,
+ { "extend", ExtendMethod, METH_O,
"Appends objects to the repeated container." },
- { "pop", (PyCFunction)Pop, METH_VARARGS,
+ { "insert", Insert, METH_VARARGS,
+ "Inserts an object at the specified position in the container." },
+ { "pop", Pop, METH_VARARGS,
"Removes an object from the repeated container and returns it." },
- { "remove", (PyCFunction)Remove, METH_O,
+ { "remove", Remove, METH_O,
"Removes an object from the repeated container." },
{ "sort", (PyCFunction)Sort, METH_VARARGS | METH_KEYWORDS,
"Sorts the repeated container."},
@@ -773,12 +786,12 @@ PyTypeObject RepeatedScalarContainer_Type = {
FULL_MODULE_NAME ".RepeatedScalarContainer", // tp_name
sizeof(RepeatedScalarContainer), // tp_basicsize
0, // tp_itemsize
- (destructor)repeated_scalar_container::Dealloc, // tp_dealloc
+ repeated_scalar_container::Dealloc, // tp_dealloc
0, // tp_print
0, // tp_getattr
0, // tp_setattr
0, // tp_compare
- 0, // tp_repr
+ repeated_scalar_container::ToStr, // tp_repr
0, // tp_as_number
&repeated_scalar_container::SqMethods, // tp_as_sequence
&repeated_scalar_container::MpMethods, // tp_as_mapping
@@ -792,7 +805,7 @@ PyTypeObject RepeatedScalarContainer_Type = {
"A Repeated scalar container", // tp_doc
0, // tp_traverse
0, // tp_clear
- (richcmpfunc)repeated_scalar_container::RichCompare, // tp_richcompare
+ repeated_scalar_container::RichCompare, // tp_richcompare
0, // tp_weaklistoffset
0, // tp_iter
0, // tp_iternext
diff --git a/python/google/protobuf/pyext/repeated_scalar_container.h b/python/google/protobuf/pyext/repeated_scalar_container.h
index 555e621c..559dec98 100644
--- a/python/google/protobuf/pyext/repeated_scalar_container.h
+++ b/python/google/protobuf/pyext/repeated_scalar_container.h
@@ -37,27 +37,14 @@
#include <Python.h>
#include <memory>
-#ifndef _SHARED_PTR_H
-#include <google/protobuf/stubs/shared_ptr.h>
-#endif
#include <google/protobuf/descriptor.h>
+#include <google/protobuf/pyext/message.h>
namespace google {
namespace protobuf {
-
-class Message;
-
-#ifdef _SHARED_PTR_H
-using std::shared_ptr;
-#else
-using internal::shared_ptr;
-#endif
-
namespace python {
-struct CMessage;
-
typedef struct RepeatedScalarContainer {
PyObject_HEAD;
@@ -65,7 +52,7 @@ typedef struct RepeatedScalarContainer {
// proto tree. Every Python RepeatedScalarContainer holds a
// reference to it in order to keep it alive as long as there's a
// Python object that references any part of the tree.
- shared_ptr<Message> owner;
+ CMessage::OwnerRef owner;
// Pointer to the C++ Message that contains this container. The
// RepeatedScalarContainer does not own this pointer.
@@ -112,7 +99,7 @@ PyObject* Extend(RepeatedScalarContainer* self, PyObject* value);
// Set the owner field of self and any children of self.
void SetOwner(RepeatedScalarContainer* self,
- const shared_ptr<Message>& new_owner);
+ const CMessage::OwnerRef& new_owner);
} // namespace repeated_scalar_container
} // namespace python
diff --git a/python/google/protobuf/pyext/safe_numerics.h b/python/google/protobuf/pyext/safe_numerics.h
new file mode 100644
index 00000000..639ba2c8
--- /dev/null
+++ b/python/google/protobuf/pyext/safe_numerics.h
@@ -0,0 +1,164 @@
+// Protocol Buffers - Google's data interchange format
+// Copyright 2008 Google Inc. All rights reserved.
+// https://developers.google.com/protocol-buffers/
+//
+// Redistribution and use in source and binary forms, with or without
+// modification, are permitted provided that the following conditions are
+// met:
+//
+// * Redistributions of source code must retain the above copyright
+// notice, this list of conditions and the following disclaimer.
+// * Redistributions in binary form must reproduce the above
+// copyright notice, this list of conditions and the following disclaimer
+// in the documentation and/or other materials provided with the
+// distribution.
+// * Neither the name of Google Inc. nor the names of its
+// contributors may be used to endorse or promote products derived from
+// this software without specific prior written permission.
+//
+// THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS
+// "AS IS" AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT
+// LIMITED TO, THE IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR
+// A PARTICULAR PURPOSE ARE DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT
+// OWNER OR CONTRIBUTORS BE LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL,
+// SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT
+// LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE,
+// DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY
+// THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT
+// (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
+// OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
+
+#ifndef GOOGLE_PROTOBUF_PYTHON_CPP_SAFE_NUMERICS_H__
+#define GOOGLE_PROTOBUF_PYTHON_CPP_SAFE_NUMERICS_H__
+// Copied from chromium with only changes to the namespace.
+
+#include <limits>
+
+#include <google/protobuf/stubs/logging.h>
+#include <google/protobuf/stubs/common.h>
+
+namespace google {
+namespace protobuf {
+namespace python {
+
+template <bool SameSize, bool DestLarger,
+ bool DestIsSigned, bool SourceIsSigned>
+struct IsValidNumericCastImpl;
+
+#define BASE_NUMERIC_CAST_CASE_SPECIALIZATION(A, B, C, D, Code) \
+template <> struct IsValidNumericCastImpl<A, B, C, D> { \
+ template <class Source, class DestBounds> static inline bool Test( \
+ Source source, DestBounds min, DestBounds max) { \
+ return Code; \
+ } \
+}
+
+#define BASE_NUMERIC_CAST_CASE_SAME_SIZE(DestSigned, SourceSigned, Code) \
+ BASE_NUMERIC_CAST_CASE_SPECIALIZATION( \
+ true, true, DestSigned, SourceSigned, Code); \
+ BASE_NUMERIC_CAST_CASE_SPECIALIZATION( \
+ true, false, DestSigned, SourceSigned, Code)
+
+#define BASE_NUMERIC_CAST_CASE_SOURCE_LARGER(DestSigned, SourceSigned, Code) \
+ BASE_NUMERIC_CAST_CASE_SPECIALIZATION( \
+ false, false, DestSigned, SourceSigned, Code); \
+
+#define BASE_NUMERIC_CAST_CASE_DEST_LARGER(DestSigned, SourceSigned, Code) \
+ BASE_NUMERIC_CAST_CASE_SPECIALIZATION( \
+ false, true, DestSigned, SourceSigned, Code); \
+
+// The three top level cases are:
+// - Same size
+// - Source larger
+// - Dest larger
+// And for each of those three cases, we handle the 4 different possibilities
+// of signed and unsigned. This gives 12 cases to handle, which we enumerate
+// below.
+//
+// The last argument in each of the macros is the actual comparison code. It
+// has three arguments available, source (the value), and min/max which are
+// the ranges of the destination.
+
+
+// These are the cases where both types have the same size.
+
+// Both signed.
+BASE_NUMERIC_CAST_CASE_SAME_SIZE(true, true, true);
+// Both unsigned.
+BASE_NUMERIC_CAST_CASE_SAME_SIZE(false, false, true);
+// Dest unsigned, Source signed.
+BASE_NUMERIC_CAST_CASE_SAME_SIZE(false, true, source >= 0);
+// Dest signed, Source unsigned.
+// This cast is OK because Dest's max must be less than Source's.
+BASE_NUMERIC_CAST_CASE_SAME_SIZE(true, false,
+ source <= static_cast<Source>(max));
+
+
+// These are the cases where Source is larger.
+
+// Both unsigned.
+BASE_NUMERIC_CAST_CASE_SOURCE_LARGER(false, false, source <= max);
+// Both signed.
+BASE_NUMERIC_CAST_CASE_SOURCE_LARGER(true, true,
+ source >= min && source <= max);
+// Dest is unsigned, Source is signed.
+BASE_NUMERIC_CAST_CASE_SOURCE_LARGER(false, true,
+ source >= 0 && source <= max);
+// Dest is signed, Source is unsigned.
+// This cast is OK because Dest's max must be less than Source's.
+BASE_NUMERIC_CAST_CASE_SOURCE_LARGER(true, false,
+ source <= static_cast<Source>(max));
+
+
+// These are the cases where Dest is larger.
+
+// Both unsigned.
+BASE_NUMERIC_CAST_CASE_DEST_LARGER(false, false, true);
+// Both signed.
+BASE_NUMERIC_CAST_CASE_DEST_LARGER(true, true, true);
+// Dest is unsigned, Source is signed.
+BASE_NUMERIC_CAST_CASE_DEST_LARGER(false, true, source >= 0);
+// Dest is signed, Source is unsigned.
+BASE_NUMERIC_CAST_CASE_DEST_LARGER(true, false, true);
+
+#undef BASE_NUMERIC_CAST_CASE_SPECIALIZATION
+#undef BASE_NUMERIC_CAST_CASE_SAME_SIZE
+#undef BASE_NUMERIC_CAST_CASE_SOURCE_LARGER
+#undef BASE_NUMERIC_CAST_CASE_DEST_LARGER
+
+
+// The main test for whether the conversion will under or overflow.
+template <class Dest, class Source>
+inline bool IsValidNumericCast(Source source) {
+ typedef std::numeric_limits<Source> SourceLimits;
+ typedef std::numeric_limits<Dest> DestLimits;
+ GOOGLE_COMPILE_ASSERT(SourceLimits::is_specialized, argument_must_be_numeric);
+ GOOGLE_COMPILE_ASSERT(SourceLimits::is_integer, argument_must_be_integral);
+ GOOGLE_COMPILE_ASSERT(DestLimits::is_specialized, result_must_be_numeric);
+ GOOGLE_COMPILE_ASSERT(DestLimits::is_integer, result_must_be_integral);
+
+ return IsValidNumericCastImpl<
+ sizeof(Dest) == sizeof(Source),
+ (sizeof(Dest) > sizeof(Source)),
+ DestLimits::is_signed,
+ SourceLimits::is_signed>::Test(
+ source,
+ DestLimits::min(),
+ DestLimits::max());
+}
+
+// checked_numeric_cast<> is analogous to static_cast<> for numeric types,
+// except that it CHECKs that the specified numeric conversion will not
+// overflow or underflow. Floating point arguments are not currently allowed
+// (this is COMPILE_ASSERTd), though this could be supported if necessary.
+template <class Dest, class Source>
+inline Dest checked_numeric_cast(Source source) {
+ GOOGLE_CHECK(IsValidNumericCast<Dest>(source));
+ return static_cast<Dest>(source);
+}
+
+} // namespace python
+} // namespace protobuf
+
+} // namespace google
+#endif // GOOGLE_PROTOBUF_PYTHON_CPP_SAFE_NUMERICS_H__
diff --git a/python/google/protobuf/pyext/scoped_pyobject_ptr.h b/python/google/protobuf/pyext/scoped_pyobject_ptr.h
index a128cd4c..a2afa7f1 100644
--- a/python/google/protobuf/pyext/scoped_pyobject_ptr.h
+++ b/python/google/protobuf/pyext/scoped_pyobject_ptr.h
@@ -36,61 +36,70 @@
#include <google/protobuf/stubs/common.h>
#include <Python.h>
-
namespace google {
-class ScopedPyObjectPtr {
+namespace protobuf {
+namespace python {
+
+// Owns a python object and decrements the reference count on destruction.
+// This class is not threadsafe.
+template <typename PyObjectStruct>
+class ScopedPythonPtr {
public:
- // Constructor. Defaults to initializing with NULL.
- // There is no way to create an uninitialized ScopedPyObjectPtr.
- explicit ScopedPyObjectPtr(PyObject* p = NULL) : ptr_(p) { }
+ // Takes the ownership of the specified object to ScopedPythonPtr.
+ // The reference count of the specified py_object is not incremented.
+ explicit ScopedPythonPtr(PyObjectStruct* py_object = NULL)
+ : ptr_(py_object) {}
- // Destructor. If there is a PyObject object, delete it.
- ~ScopedPyObjectPtr() {
- Py_XDECREF(ptr_);
- }
+ // If a PyObject is owned, decrement its reference count.
+ ~ScopedPythonPtr() { Py_XDECREF(ptr_); }
- // Reset. Deletes the current owned object, if any.
- // Then takes ownership of a new object, if given.
+ // Deletes the current owned object, if any.
+ // Then takes ownership of a new object without incrementing the reference
+ // count.
// This function must be called with a reference that you own.
// this->reset(this->get()) is wrong!
// this->reset(this->release()) is OK.
- PyObject* reset(PyObject* p = NULL) {
+ PyObjectStruct* reset(PyObjectStruct* p = NULL) {
Py_XDECREF(ptr_);
ptr_ = p;
return ptr_;
}
- // Releases ownership of the object.
+ // Releases ownership of the object without decrementing the reference count.
// The caller now owns the returned reference.
- PyObject* release() {
+ PyObjectStruct* release() {
PyObject* p = ptr_;
ptr_ = NULL;
return p;
}
- PyObject* operator->() const {
+ PyObjectStruct* operator->() const {
assert(ptr_ != NULL);
return ptr_;
}
- PyObject* get() const { return ptr_; }
+ PyObjectStruct* get() const { return ptr_; }
- Py_ssize_t refcnt() const { return Py_REFCNT(ptr_); }
+ PyObject* as_pyobject() const { return reinterpret_cast<PyObject*>(ptr_); }
+ // Increments the reference count fo the current object.
+ // Should not be called when no object is held.
void inc() const { Py_INCREF(ptr_); }
- // Comparison operators.
- // These return whether a ScopedPyObjectPtr and a raw pointer
- // refer to the same object, not just to two different but equal
- // objects.
- bool operator==(const PyObject* p) const { return ptr_ == p; }
- bool operator!=(const PyObject* p) const { return ptr_ != p; }
+ // True when a ScopedPyObjectPtr and a raw pointer refer to the same object.
+ // Comparison operators are non reflexive.
+ bool operator==(const PyObjectStruct* p) const { return ptr_ == p; }
+ bool operator!=(const PyObjectStruct* p) const { return ptr_ != p; }
private:
- PyObject* ptr_;
+ PyObjectStruct* ptr_;
- GOOGLE_DISALLOW_EVIL_CONSTRUCTORS(ScopedPyObjectPtr);
+ GOOGLE_DISALLOW_EVIL_CONSTRUCTORS(ScopedPythonPtr);
};
+typedef ScopedPythonPtr<PyObject> ScopedPyObjectPtr;
+
+} // namespace python
+} // namespace protobuf
} // namespace google
#endif // GOOGLE_PROTOBUF_PYTHON_CPP_SCOPED_PYOBJECT_PTR_H__
diff --git a/python/google/protobuf/pyext/thread_unsafe_shared_ptr.h b/python/google/protobuf/pyext/thread_unsafe_shared_ptr.h
new file mode 100644
index 00000000..ad804b5f
--- /dev/null
+++ b/python/google/protobuf/pyext/thread_unsafe_shared_ptr.h
@@ -0,0 +1,104 @@
+// Protocol Buffers - Google's data interchange format
+// Copyright 2008 Google Inc. All rights reserved.
+// https://developers.google.com/protocol-buffers/
+//
+// Redistribution and use in source and binary forms, with or without
+// modification, are permitted provided that the following conditions are
+// met:
+//
+// * Redistributions of source code must retain the above copyright
+// notice, this list of conditions and the following disclaimer.
+// * Redistributions in binary form must reproduce the above
+// copyright notice, this list of conditions and the following disclaimer
+// in the documentation and/or other materials provided with the
+// distribution.
+// * Neither the name of Google Inc. nor the names of its
+// contributors may be used to endorse or promote products derived from
+// this software without specific prior written permission.
+//
+// THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS
+// "AS IS" AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT
+// LIMITED TO, THE IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR
+// A PARTICULAR PURPOSE ARE DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT
+// OWNER OR CONTRIBUTORS BE LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL,
+// SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT
+// LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE,
+// DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY
+// THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT
+// (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
+// OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
+
+// ThreadUnsafeSharedPtr<T> is the same as shared_ptr<T> without the locking
+// overhread (and thread-safety).
+#ifndef GOOGLE_PROTOBUF_PYTHON_CPP_THREAD_UNSAFE_SHARED_PTR_H__
+#define GOOGLE_PROTOBUF_PYTHON_CPP_THREAD_UNSAFE_SHARED_PTR_H__
+
+#include <algorithm>
+#include <utility>
+
+#include <google/protobuf/stubs/logging.h>
+#include <google/protobuf/stubs/common.h>
+
+namespace google {
+namespace protobuf {
+namespace python {
+
+template <typename T>
+class ThreadUnsafeSharedPtr {
+ public:
+ // Takes ownership.
+ explicit ThreadUnsafeSharedPtr(T* ptr)
+ : ptr_(ptr), refcount_(ptr ? new RefcountT(1) : nullptr) {
+ }
+
+ ThreadUnsafeSharedPtr(const ThreadUnsafeSharedPtr& other)
+ : ThreadUnsafeSharedPtr(nullptr) {
+ *this = other;
+ }
+
+ ThreadUnsafeSharedPtr& operator=(const ThreadUnsafeSharedPtr& other) {
+ if (other.refcount_ == refcount_) {
+ return *this;
+ }
+ this->~ThreadUnsafeSharedPtr();
+ ptr_ = other.ptr_;
+ refcount_ = other.refcount_;
+ if (refcount_) {
+ ++*refcount_;
+ }
+ return *this;
+ }
+
+ ~ThreadUnsafeSharedPtr() {
+ if (refcount_ == nullptr) {
+ GOOGLE_DCHECK(ptr_ == nullptr);
+ return;
+ }
+ if (--*refcount_ == 0) {
+ delete refcount_;
+ delete ptr_;
+ }
+ }
+
+ void reset(T* ptr = nullptr) { *this = ThreadUnsafeSharedPtr(ptr); }
+
+ T* get() { return ptr_; }
+ const T* get() const { return ptr_; }
+
+ void swap(ThreadUnsafeSharedPtr& other) {
+ using std::swap;
+ swap(ptr_, other.ptr_);
+ swap(refcount_, other.refcount_);
+ }
+
+ private:
+ typedef int RefcountT;
+ T* ptr_;
+ RefcountT* refcount_;
+};
+
+} // namespace python
+} // namespace protobuf
+
+} // namespace google
+#endif // GOOGLE_PROTOBUF_PYTHON_CPP_THREAD_UNSAFE_SHARED_PTR_H__
diff --git a/python/google/protobuf/pyext/python_protobuf.h b/python/google/protobuf/python_protobuf.h
index beb6e460..beb6e460 100644
--- a/python/google/protobuf/pyext/python_protobuf.h
+++ b/python/google/protobuf/python_protobuf.h
diff --git a/python/google/protobuf/reflection.py b/python/google/protobuf/reflection.py
index 0c757264..f4ce8caf 100755
--- a/python/google/protobuf/reflection.py
+++ b/python/google/protobuf/reflection.py
@@ -58,15 +58,11 @@ else:
from google.protobuf.internal import python_message as message_impl
# The type of all Message classes.
-# Part of the public interface.
-#
-# Used by generated files, but clients can also use it at runtime:
-# mydescriptor = pool.FindDescriptor(.....)
-# class MyProtoClass(Message):
-# __metaclass__ = GeneratedProtocolMessageType
-# DESCRIPTOR = mydescriptor
+# Part of the public interface, but normally only used by message factories.
GeneratedProtocolMessageType = message_impl.GeneratedProtocolMessageType
+MESSAGE_CLASS_CACHE = {}
+
def ParseMessage(descriptor, byte_str):
"""Generate a new Message instance from this Descriptor and a byte string.
@@ -110,11 +106,16 @@ def MakeClass(descriptor):
Returns:
The Message class object described by the descriptor.
"""
+ if descriptor in MESSAGE_CLASS_CACHE:
+ return MESSAGE_CLASS_CACHE[descriptor]
+
attributes = {}
for name, nested_type in descriptor.nested_types_by_name.items():
attributes[name] = MakeClass(nested_type)
attributes[GeneratedProtocolMessageType._DESCRIPTOR_KEY] = descriptor
- return GeneratedProtocolMessageType(str(descriptor.name), (message.Message,),
- attributes)
+ result = GeneratedProtocolMessageType(
+ str(descriptor.name), (message.Message,), attributes)
+ MESSAGE_CLASS_CACHE[descriptor] = result
+ return result
diff --git a/python/google/protobuf/symbol_database.py b/python/google/protobuf/symbol_database.py
index 87760f26..5ad869f4 100644
--- a/python/google/protobuf/symbol_database.py
+++ b/python/google/protobuf/symbol_database.py
@@ -30,11 +30,9 @@
"""A database of Python protocol buffer generated symbols.
-SymbolDatabase makes it easy to create new instances of a registered type, given
-only the type's protocol buffer symbol name. Once all symbols are registered,
-they can be accessed using either the MessageFactory interface which
-SymbolDatabase exposes, or the DescriptorPool interface of the underlying
-pool.
+SymbolDatabase is the MessageFactory for messages generated at compile time,
+and makes it easy to create new instances of a registered type, given only the
+type's protocol buffer symbol name.
Example usage:
@@ -61,27 +59,17 @@ Example usage:
from google.protobuf import descriptor_pool
+from google.protobuf import message_factory
-class SymbolDatabase(object):
- """A database of Python generated symbols.
-
- SymbolDatabase also models message_factory.MessageFactory.
-
- The symbol database can be used to keep a global registry of all protocol
- buffer types used within a program.
- """
-
- def __init__(self, pool=None):
- """Constructor."""
-
- self._symbols = {}
- self._symbols_by_file = {}
- self.pool = pool or descriptor_pool.Default()
+class SymbolDatabase(message_factory.MessageFactory):
+ """A database of Python generated symbols."""
def RegisterMessage(self, message):
"""Registers the given message type in the local database.
+ Calls to GetSymbol() and GetMessages() will return messages registered here.
+
Args:
message: a message.Message, to be registered.
@@ -90,13 +78,18 @@ class SymbolDatabase(object):
"""
desc = message.DESCRIPTOR
- self._symbols[desc.full_name] = message
- if desc.file.name not in self._symbols_by_file:
- self._symbols_by_file[desc.file.name] = {}
- self._symbols_by_file[desc.file.name][desc.full_name] = message
- self.pool.AddDescriptor(desc)
+ self._classes[desc] = message
+ self.RegisterMessageDescriptor(desc)
return message
+ def RegisterMessageDescriptor(self, message_descriptor):
+ """Registers the given message descriptor in the local database.
+
+ Args:
+ message_descriptor: a descriptor.MessageDescriptor.
+ """
+ self.pool.AddDescriptor(message_descriptor)
+
def RegisterEnumDescriptor(self, enum_descriptor):
"""Registers the given enum descriptor in the local database.
@@ -109,6 +102,17 @@ class SymbolDatabase(object):
self.pool.AddEnumDescriptor(enum_descriptor)
return enum_descriptor
+ def RegisterServiceDescriptor(self, service_descriptor):
+ """Registers the given service descriptor in the local database.
+
+ Args:
+ service_descriptor: a descriptor.ServiceDescriptor.
+
+ Returns:
+ The provided descriptor.
+ """
+ self.pool.AddServiceDescriptor(service_descriptor)
+
def RegisterFileDescriptor(self, file_descriptor):
"""Registers the given file descriptor in the local database.
@@ -136,47 +140,47 @@ class SymbolDatabase(object):
KeyError: if the symbol could not be found.
"""
- return self._symbols[symbol]
-
- def GetPrototype(self, descriptor):
- """Builds a proto2 message class based on the passed in descriptor.
-
- Passing a descriptor with a fully qualified name matching a previous
- invocation will cause the same class to be returned.
-
- Args:
- descriptor: The descriptor to build from.
-
- Returns:
- A class describing the passed in descriptor.
- """
-
- return self.GetSymbol(descriptor.full_name)
+ return self._classes[self.pool.FindMessageTypeByName(symbol)]
def GetMessages(self, files):
- """Gets all the messages from a specified file.
-
- This will find and resolve dependencies, failing if they are not registered
- in the symbol database.
+ # TODO(amauryfa): Fix the differences with MessageFactory.
+ """Gets all registered messages from a specified file.
+ Only messages already created and registered will be returned; (this is the
+ case for imported _pb2 modules)
+ But unlike MessageFactory, this version also returns already defined nested
+ messages, but does not register any message extensions.
Args:
files: The file names to extract messages from.
Returns:
- A dictionary mapping proto names to the message classes. This will include
- any dependent messages as well as any messages defined in the same file as
- a specified message.
+ A dictionary mapping proto names to the message classes.
Raises:
KeyError: if a file could not be found.
"""
+ def _GetAllMessages(desc):
+ """Walk a message Descriptor and recursively yields all message names."""
+ yield desc
+ for msg_desc in desc.nested_types:
+ for nested_desc in _GetAllMessages(msg_desc):
+ yield nested_desc
+
result = {}
- for f in files:
- result.update(self._symbols_by_file[f])
+ for file_name in files:
+ file_desc = self.pool.FindFileByName(file_name)
+ for msg_desc in file_desc.message_types_by_name.values():
+ for desc in _GetAllMessages(msg_desc):
+ try:
+ result[desc.full_name] = self._classes[desc]
+ except KeyError:
+ # This descriptor has no registered class, skip it.
+ pass
return result
+
_DEFAULT = SymbolDatabase(pool=descriptor_pool.Default())
diff --git a/python/google/protobuf/text_format.py b/python/google/protobuf/text_format.py
index 8d256076..2cbd21bc 100755
--- a/python/google/protobuf/text_format.py
+++ b/python/google/protobuf/text_format.py
@@ -48,15 +48,15 @@ import re
import six
if six.PY3:
- long = int
+ long = int # pylint: disable=redefined-builtin,invalid-name
+# pylint: disable=g-import-not-at-top
from google.protobuf.internal import type_checkers
from google.protobuf import descriptor
from google.protobuf import text_encoding
-__all__ = ['MessageToString', 'PrintMessage', 'PrintField',
- 'PrintFieldValue', 'Merge']
-
+__all__ = ['MessageToString', 'PrintMessage', 'PrintField', 'PrintFieldValue',
+ 'Merge']
_INTEGER_CHECKERS = (type_checkers.Uint32ValueChecker(),
type_checkers.Int32ValueChecker(),
@@ -67,6 +67,7 @@ _FLOAT_NAN = re.compile('nanf?', re.IGNORECASE)
_FLOAT_TYPES = frozenset([descriptor.FieldDescriptor.CPPTYPE_FLOAT,
descriptor.FieldDescriptor.CPPTYPE_DOUBLE])
_QUOTES = frozenset(("'", '"'))
+_ANY_FULL_TYPE_NAME = 'google.protobuf.Any'
class Error(Exception):
@@ -74,10 +75,30 @@ class Error(Exception):
class ParseError(Error):
- """Thrown in case of text parsing error."""
+ """Thrown in case of text parsing or tokenizing error."""
+
+ def __init__(self, message=None, line=None, column=None):
+ if message is not None and line is not None:
+ loc = str(line)
+ if column is not None:
+ loc += ':{0}'.format(column)
+ message = '{0} : {1}'.format(loc, message)
+ if message is not None:
+ super(ParseError, self).__init__(message)
+ else:
+ super(ParseError, self).__init__()
+ self._line = line
+ self._column = column
+
+ def GetLine(self):
+ return self._line
+
+ def GetColumn(self):
+ return self._column
class TextWriter(object):
+
def __init__(self, as_utf8):
if six.PY2:
self._writer = io.BytesIO()
@@ -97,9 +118,16 @@ class TextWriter(object):
return self._writer.getvalue()
-def MessageToString(message, as_utf8=False, as_one_line=False,
- pointy_brackets=False, use_index_order=False,
- float_format=None):
+def MessageToString(message,
+ as_utf8=False,
+ as_one_line=False,
+ pointy_brackets=False,
+ use_index_order=False,
+ float_format=None,
+ use_field_number=False,
+ descriptor_pool=None,
+ indent=0,
+ message_formatter=None):
"""Convert protobuf message to text format.
Floating point values can be formatted compactly with 15 digits of
@@ -113,20 +141,28 @@ def MessageToString(message, as_utf8=False, as_one_line=False,
as_one_line: Don't introduce newlines between fields.
pointy_brackets: If True, use angle brackets instead of curly braces for
nesting.
- use_index_order: If True, print fields of a proto message using the order
- defined in source code instead of the field number. By default, use the
- field number order.
+ use_index_order: If True, fields of a proto message will be printed using
+ the order defined in source code instead of the field number, extensions
+ will be printed at the end of the message and their relative order is
+ determined by the extension number. By default, use the field number
+ order.
float_format: If set, use this to specify floating point number formatting
(per the "Format Specification Mini-Language"); otherwise, str() is used.
+ use_field_number: If True, print field numbers instead of names.
+ descriptor_pool: A DescriptorPool used to resolve Any types.
+ indent: The indent level, in terms of spaces, for pretty print.
+ message_formatter: A function(message, indent, as_one_line): unicode|None
+ to custom format selected sub-messages (usually based on message type).
+ Use to pretty print parts of the protobuf for easier diffing.
Returns:
A string of the text formatted protocol buffer message.
"""
out = TextWriter(as_utf8)
- PrintMessage(message, out, as_utf8=as_utf8, as_one_line=as_one_line,
- pointy_brackets=pointy_brackets,
- use_index_order=use_index_order,
- float_format=float_format)
+ printer = _Printer(out, indent, as_utf8, as_one_line, pointy_brackets,
+ use_index_order, float_format, use_field_number,
+ descriptor_pool, message_formatter)
+ printer.PrintMessage(message)
result = out.getvalue()
out.close()
if as_one_line:
@@ -140,142 +176,310 @@ def _IsMapEntry(field):
field.message_type.GetOptions().map_entry)
-def PrintMessage(message, out, indent=0, as_utf8=False, as_one_line=False,
- pointy_brackets=False, use_index_order=False,
- float_format=None):
- fields = message.ListFields()
- if use_index_order:
- fields.sort(key=lambda x: x[0].index)
- for field, value in fields:
- if _IsMapEntry(field):
- for key in sorted(value):
- # This is slow for maps with submessage entires because it copies the
- # entire tree. Unfortunately this would take significant refactoring
- # of this file to work around.
- #
- # TODO(haberman): refactor and optimize if this becomes an issue.
- entry_submsg = field.message_type._concrete_class(
- key=key, value=value[key])
- PrintField(field, entry_submsg, out, indent, as_utf8, as_one_line,
- pointy_brackets=pointy_brackets,
- use_index_order=use_index_order, float_format=float_format)
- elif field.label == descriptor.FieldDescriptor.LABEL_REPEATED:
- for element in value:
- PrintField(field, element, out, indent, as_utf8, as_one_line,
- pointy_brackets=pointy_brackets,
- use_index_order=use_index_order,
- float_format=float_format)
- else:
- PrintField(field, value, out, indent, as_utf8, as_one_line,
- pointy_brackets=pointy_brackets,
- use_index_order=use_index_order,
- float_format=float_format)
+def PrintMessage(message,
+ out,
+ indent=0,
+ as_utf8=False,
+ as_one_line=False,
+ pointy_brackets=False,
+ use_index_order=False,
+ float_format=None,
+ use_field_number=False,
+ descriptor_pool=None,
+ message_formatter=None):
+ printer = _Printer(out, indent, as_utf8, as_one_line, pointy_brackets,
+ use_index_order, float_format, use_field_number,
+ descriptor_pool, message_formatter)
+ printer.PrintMessage(message)
+
+
+def PrintField(field,
+ value,
+ out,
+ indent=0,
+ as_utf8=False,
+ as_one_line=False,
+ pointy_brackets=False,
+ use_index_order=False,
+ float_format=None,
+ message_formatter=None):
+ """Print a single field name/value pair."""
+ printer = _Printer(out, indent, as_utf8, as_one_line, pointy_brackets,
+ use_index_order, float_format, message_formatter)
+ printer.PrintField(field, value)
+
+
+def PrintFieldValue(field,
+ value,
+ out,
+ indent=0,
+ as_utf8=False,
+ as_one_line=False,
+ pointy_brackets=False,
+ use_index_order=False,
+ float_format=None,
+ message_formatter=None):
+ """Print a single field value (not including name)."""
+ printer = _Printer(out, indent, as_utf8, as_one_line, pointy_brackets,
+ use_index_order, float_format, message_formatter)
+ printer.PrintFieldValue(field, value)
+
+def _BuildMessageFromTypeName(type_name, descriptor_pool):
+ """Returns a protobuf message instance.
-def PrintField(field, value, out, indent=0, as_utf8=False, as_one_line=False,
- pointy_brackets=False, use_index_order=False, float_format=None):
- """Print a single field name/value pair. For repeated fields, the value
- should be a single element.
+ Args:
+ type_name: Fully-qualified protobuf message type name string.
+ descriptor_pool: DescriptorPool instance.
+
+ Returns:
+ A Message instance of type matching type_name, or None if the a Descriptor
+ wasn't found matching type_name.
"""
+ # pylint: disable=g-import-not-at-top
+ if descriptor_pool is None:
+ from google.protobuf import descriptor_pool as pool_mod
+ descriptor_pool = pool_mod.Default()
+ from google.protobuf import symbol_database
+ database = symbol_database.Default()
+ try:
+ message_descriptor = descriptor_pool.FindMessageTypeByName(type_name)
+ except KeyError:
+ return None
+ message_type = database.GetPrototype(message_descriptor)
+ return message_type()
+
+
+class _Printer(object):
+ """Text format printer for protocol message."""
+
+ def __init__(self,
+ out,
+ indent=0,
+ as_utf8=False,
+ as_one_line=False,
+ pointy_brackets=False,
+ use_index_order=False,
+ float_format=None,
+ use_field_number=False,
+ descriptor_pool=None,
+ message_formatter=None):
+ """Initialize the Printer.
+
+ Floating point values can be formatted compactly with 15 digits of
+ precision (which is the most that IEEE 754 "double" can guarantee)
+ using float_format='.15g'. To ensure that converting to text and back to a
+ proto will result in an identical value, float_format='.17g' should be used.
- out.write(' ' * indent)
- if field.is_extension:
- out.write('[')
- if (field.containing_type.GetOptions().message_set_wire_format and
- field.type == descriptor.FieldDescriptor.TYPE_MESSAGE and
- field.label == descriptor.FieldDescriptor.LABEL_OPTIONAL):
- out.write(field.message_type.full_name)
+ Args:
+ out: To record the text format result.
+ indent: The indent level for pretty print.
+ as_utf8: Produce text output in UTF8 format.
+ as_one_line: Don't introduce newlines between fields.
+ pointy_brackets: If True, use angle brackets instead of curly braces for
+ nesting.
+ use_index_order: If True, print fields of a proto message using the order
+ defined in source code instead of the field number. By default, use the
+ field number order.
+ float_format: If set, use this to specify floating point number formatting
+ (per the "Format Specification Mini-Language"); otherwise, str() is
+ used.
+ use_field_number: If True, print field numbers instead of names.
+ descriptor_pool: A DescriptorPool used to resolve Any types.
+ message_formatter: A function(message, indent, as_one_line): unicode|None
+ to custom format selected sub-messages (usually based on message type).
+ Use to pretty print parts of the protobuf for easier diffing.
+ """
+ self.out = out
+ self.indent = indent
+ self.as_utf8 = as_utf8
+ self.as_one_line = as_one_line
+ self.pointy_brackets = pointy_brackets
+ self.use_index_order = use_index_order
+ self.float_format = float_format
+ self.use_field_number = use_field_number
+ self.descriptor_pool = descriptor_pool
+ self.message_formatter = message_formatter
+
+ def _TryPrintAsAnyMessage(self, message):
+ """Serializes if message is a google.protobuf.Any field."""
+ packed_message = _BuildMessageFromTypeName(message.TypeName(),
+ self.descriptor_pool)
+ if packed_message:
+ packed_message.MergeFromString(message.value)
+ self.out.write('%s[%s]' % (self.indent * ' ', message.type_url))
+ self._PrintMessageFieldValue(packed_message)
+ self.out.write(' ' if self.as_one_line else '\n')
+ return True
else:
- out.write(field.full_name)
- out.write(']')
- elif field.type == descriptor.FieldDescriptor.TYPE_GROUP:
- # For groups, use the capitalized name.
- out.write(field.message_type.name)
- else:
- out.write(field.name)
-
- if field.cpp_type != descriptor.FieldDescriptor.CPPTYPE_MESSAGE:
- # The colon is optional in this case, but our cross-language golden files
- # don't include it.
- out.write(': ')
+ return False
- PrintFieldValue(field, value, out, indent, as_utf8, as_one_line,
- pointy_brackets=pointy_brackets,
- use_index_order=use_index_order,
- float_format=float_format)
- if as_one_line:
- out.write(' ')
- else:
- out.write('\n')
+ def _TryCustomFormatMessage(self, message):
+ formatted = self.message_formatter(message, self.indent, self.as_one_line)
+ if formatted is None:
+ return False
+ out = self.out
+ out.write(' ' * self.indent)
+ out.write(formatted)
+ out.write(' ' if self.as_one_line else '\n')
+ return True
-def PrintFieldValue(field, value, out, indent=0, as_utf8=False,
- as_one_line=False, pointy_brackets=False,
- use_index_order=False,
- float_format=None):
- """Print a single field value (not including name). For repeated fields,
- the value should be a single element."""
+ def PrintMessage(self, message):
+ """Convert protobuf message to text format.
- if pointy_brackets:
- openb = '<'
- closeb = '>'
- else:
- openb = '{'
- closeb = '}'
-
- if field.cpp_type == descriptor.FieldDescriptor.CPPTYPE_MESSAGE:
- if as_one_line:
- out.write(' %s ' % openb)
- PrintMessage(value, out, indent, as_utf8, as_one_line,
- pointy_brackets=pointy_brackets,
- use_index_order=use_index_order,
- float_format=float_format)
- out.write(closeb)
+ Args:
+ message: The protocol buffers message.
+ """
+ if self.message_formatter and self._TryCustomFormatMessage(message):
+ return
+ if (message.DESCRIPTOR.full_name == _ANY_FULL_TYPE_NAME and
+ self._TryPrintAsAnyMessage(message)):
+ return
+ fields = message.ListFields()
+ if self.use_index_order:
+ fields.sort(
+ key=lambda x: x[0].number if x[0].is_extension else x[0].index)
+ for field, value in fields:
+ if _IsMapEntry(field):
+ for key in sorted(value):
+ # This is slow for maps with submessage entries because it copies the
+ # entire tree. Unfortunately this would take significant refactoring
+ # of this file to work around.
+ #
+ # TODO(haberman): refactor and optimize if this becomes an issue.
+ entry_submsg = value.GetEntryClass()(key=key, value=value[key])
+ self.PrintField(field, entry_submsg)
+ elif field.label == descriptor.FieldDescriptor.LABEL_REPEATED:
+ for element in value:
+ self.PrintField(field, element)
+ else:
+ self.PrintField(field, value)
+
+ def PrintField(self, field, value):
+ """Print a single field name/value pair."""
+ out = self.out
+ out.write(' ' * self.indent)
+ if self.use_field_number:
+ out.write(str(field.number))
else:
- out.write(' %s\n' % openb)
- PrintMessage(value, out, indent + 2, as_utf8, as_one_line,
- pointy_brackets=pointy_brackets,
- use_index_order=use_index_order,
- float_format=float_format)
- out.write(' ' * indent + closeb)
- elif field.cpp_type == descriptor.FieldDescriptor.CPPTYPE_ENUM:
- enum_value = field.enum_type.values_by_number.get(value, None)
- if enum_value is not None:
- out.write(enum_value.name)
+ if field.is_extension:
+ out.write('[')
+ if (field.containing_type.GetOptions().message_set_wire_format and
+ field.type == descriptor.FieldDescriptor.TYPE_MESSAGE and
+ field.label == descriptor.FieldDescriptor.LABEL_OPTIONAL):
+ out.write(field.message_type.full_name)
+ else:
+ out.write(field.full_name)
+ out.write(']')
+ elif field.type == descriptor.FieldDescriptor.TYPE_GROUP:
+ # For groups, use the capitalized name.
+ out.write(field.message_type.name)
+ else:
+ out.write(field.name)
+
+ if field.cpp_type != descriptor.FieldDescriptor.CPPTYPE_MESSAGE:
+ # The colon is optional in this case, but our cross-language golden files
+ # don't include it.
+ out.write(': ')
+
+ self.PrintFieldValue(field, value)
+ if self.as_one_line:
+ out.write(' ')
else:
- out.write(str(value))
- elif field.cpp_type == descriptor.FieldDescriptor.CPPTYPE_STRING:
- out.write('\"')
- if isinstance(value, six.text_type):
- out_value = value.encode('utf-8')
+ out.write('\n')
+
+ def _PrintMessageFieldValue(self, value):
+ if self.pointy_brackets:
+ openb = '<'
+ closeb = '>'
else:
- out_value = value
- if field.type == descriptor.FieldDescriptor.TYPE_BYTES:
- # We need to escape non-UTF8 chars in TYPE_BYTES field.
- out_as_utf8 = False
+ openb = '{'
+ closeb = '}'
+
+ if self.as_one_line:
+ self.out.write(' %s ' % openb)
+ self.PrintMessage(value)
+ self.out.write(closeb)
else:
- out_as_utf8 = as_utf8
- out.write(text_encoding.CEscape(out_value, out_as_utf8))
- out.write('\"')
- elif field.cpp_type == descriptor.FieldDescriptor.CPPTYPE_BOOL:
- if value:
- out.write('true')
+ self.out.write(' %s\n' % openb)
+ self.indent += 2
+ self.PrintMessage(value)
+ self.indent -= 2
+ self.out.write(' ' * self.indent + closeb)
+
+ def PrintFieldValue(self, field, value):
+ """Print a single field value (not including name).
+
+ For repeated fields, the value should be a single element.
+
+ Args:
+ field: The descriptor of the field to be printed.
+ value: The value of the field.
+ """
+ out = self.out
+ if field.cpp_type == descriptor.FieldDescriptor.CPPTYPE_MESSAGE:
+ self._PrintMessageFieldValue(value)
+ elif field.cpp_type == descriptor.FieldDescriptor.CPPTYPE_ENUM:
+ enum_value = field.enum_type.values_by_number.get(value, None)
+ if enum_value is not None:
+ out.write(enum_value.name)
+ else:
+ out.write(str(value))
+ elif field.cpp_type == descriptor.FieldDescriptor.CPPTYPE_STRING:
+ out.write('\"')
+ if isinstance(value, six.text_type):
+ out_value = value.encode('utf-8')
+ else:
+ out_value = value
+ if field.type == descriptor.FieldDescriptor.TYPE_BYTES:
+ # We need to escape non-UTF8 chars in TYPE_BYTES field.
+ out_as_utf8 = False
+ else:
+ out_as_utf8 = self.as_utf8
+ out.write(text_encoding.CEscape(out_value, out_as_utf8))
+ out.write('\"')
+ elif field.cpp_type == descriptor.FieldDescriptor.CPPTYPE_BOOL:
+ if value:
+ out.write('true')
+ else:
+ out.write('false')
+ elif field.cpp_type in _FLOAT_TYPES and self.float_format is not None:
+ out.write('{1:{0}}'.format(self.float_format, value))
else:
- out.write('false')
- elif field.cpp_type in _FLOAT_TYPES and float_format is not None:
- out.write('{1:{0}}'.format(float_format, value))
- else:
- out.write(str(value))
+ out.write(str(value))
+
+
+def Parse(text,
+ message,
+ allow_unknown_extension=False,
+ allow_field_number=False,
+ descriptor_pool=None):
+ """Parses a text representation of a protocol message into a message.
+ NOTE: for historical reasons this function does not clear the input
+ message. This is different from what the binary msg.ParseFrom(...) does.
-def Parse(text, message, allow_unknown_extension=False):
- """Parses an text representation of a protocol message into a message.
+ Example
+ a = MyProto()
+ a.repeated_field.append('test')
+ b = MyProto()
+
+ text_format.Parse(repr(a), b)
+ text_format.Parse(repr(a), b) # repeated_field contains ["test", "test"]
+
+ # Binary version:
+ b.ParseFromString(a.SerializeToString()) # repeated_field is now "test"
+
+ Caller is responsible for clearing the message as needed.
Args:
text: Message text representation.
message: A protocol buffer message to merge into.
allow_unknown_extension: if True, skip over missing extensions and keep
parsing
+ allow_field_number: if True, both field number and field name are allowed.
+ descriptor_pool: A DescriptorPool used to resolve Any types.
Returns:
The same message passed as argument.
@@ -284,12 +488,23 @@ def Parse(text, message, allow_unknown_extension=False):
ParseError: On text parsing problems.
"""
if not isinstance(text, str):
- text = text.decode('utf-8')
- return ParseLines(text.split('\n'), message, allow_unknown_extension)
+ if six.PY3:
+ text = text.decode('utf-8')
+ else:
+ text = text.encode('utf-8')
+ return ParseLines(text.split('\n'),
+ message,
+ allow_unknown_extension,
+ allow_field_number,
+ descriptor_pool=descriptor_pool)
-def Merge(text, message, allow_unknown_extension=False):
- """Parses an text representation of a protocol message into a message.
+def Merge(text,
+ message,
+ allow_unknown_extension=False,
+ allow_field_number=False,
+ descriptor_pool=None):
+ """Parses a text representation of a protocol message into a message.
Like Parse(), but allows repeated values for a non-repeated field, and uses
the last one.
@@ -299,6 +514,8 @@ def Merge(text, message, allow_unknown_extension=False):
message: A protocol buffer message to merge into.
allow_unknown_extension: if True, skip over missing extensions and keep
parsing
+ allow_field_number: if True, both field number and field name are allowed.
+ descriptor_pool: A DescriptorPool used to resolve Any types.
Returns:
The same message passed as argument.
@@ -306,17 +523,33 @@ def Merge(text, message, allow_unknown_extension=False):
Raises:
ParseError: On text parsing problems.
"""
- return MergeLines(text.split('\n'), message, allow_unknown_extension)
-
-
-def ParseLines(lines, message, allow_unknown_extension=False):
- """Parses an text representation of a protocol message into a message.
+ if not isinstance(text, str):
+ if six.PY3:
+ text = text.decode('utf-8')
+ else:
+ text = text.encode('utf-8')
+ return MergeLines(
+ text.split('\n'),
+ message,
+ allow_unknown_extension,
+ allow_field_number,
+ descriptor_pool=descriptor_pool)
+
+
+def ParseLines(lines,
+ message,
+ allow_unknown_extension=False,
+ allow_field_number=False,
+ descriptor_pool=None):
+ """Parses a text representation of a protocol message into a message.
Args:
lines: An iterable of lines of a message's text representation.
message: A protocol buffer message to merge into.
allow_unknown_extension: if True, skip over missing extensions and keep
parsing
+ allow_field_number: if True, both field number and field name are allowed.
+ descriptor_pool: A DescriptorPool used to resolve Any types.
Returns:
The same message passed as argument.
@@ -324,18 +557,26 @@ def ParseLines(lines, message, allow_unknown_extension=False):
Raises:
ParseError: On text parsing problems.
"""
- _ParseOrMerge(lines, message, False, allow_unknown_extension)
- return message
+ parser = _Parser(allow_unknown_extension,
+ allow_field_number,
+ descriptor_pool=descriptor_pool)
+ return parser.ParseLines(lines, message)
-def MergeLines(lines, message, allow_unknown_extension=False):
- """Parses an text representation of a protocol message into a message.
+def MergeLines(lines,
+ message,
+ allow_unknown_extension=False,
+ allow_field_number=False,
+ descriptor_pool=None):
+ """Parses a text representation of a protocol message into a message.
Args:
lines: An iterable of lines of a message's text representation.
message: A protocol buffer message to merge into.
allow_unknown_extension: if True, skip over missing extensions and keep
parsing
+ allow_field_number: if True, both field number and field name are allowed.
+ descriptor_pool: A DescriptorPool used to resolve Any types.
Returns:
The same message passed as argument.
@@ -343,108 +584,220 @@ def MergeLines(lines, message, allow_unknown_extension=False):
Raises:
ParseError: On text parsing problems.
"""
- _ParseOrMerge(lines, message, True, allow_unknown_extension)
- return message
+ parser = _Parser(allow_unknown_extension,
+ allow_field_number,
+ descriptor_pool=descriptor_pool)
+ return parser.MergeLines(lines, message)
+
+
+class _Parser(object):
+ """Text format parser for protocol message."""
+
+ def __init__(self,
+ allow_unknown_extension=False,
+ allow_field_number=False,
+ descriptor_pool=None):
+ self.allow_unknown_extension = allow_unknown_extension
+ self.allow_field_number = allow_field_number
+ self.descriptor_pool = descriptor_pool
+
+ def ParseFromString(self, text, message):
+ """Parses a text representation of a protocol message into a message."""
+ if not isinstance(text, str):
+ text = text.decode('utf-8')
+ return self.ParseLines(text.split('\n'), message)
+
+ def ParseLines(self, lines, message):
+ """Parses a text representation of a protocol message into a message."""
+ self._allow_multiple_scalars = False
+ self._ParseOrMerge(lines, message)
+ return message
+
+ def MergeFromString(self, text, message):
+ """Merges a text representation of a protocol message into a message."""
+ return self._MergeLines(text.split('\n'), message)
+
+ def MergeLines(self, lines, message):
+ """Merges a text representation of a protocol message into a message."""
+ self._allow_multiple_scalars = True
+ self._ParseOrMerge(lines, message)
+ return message
+
+ def _ParseOrMerge(self, lines, message):
+ """Converts a text representation of a protocol message into a message.
+ Args:
+ lines: Lines of a message's text representation.
+ message: A protocol buffer message to merge into.
-def _ParseOrMerge(lines,
- message,
- allow_multiple_scalars,
- allow_unknown_extension=False):
- """Converts an text representation of a protocol message into a message.
+ Raises:
+ ParseError: On text parsing problems.
+ """
+ tokenizer = Tokenizer(lines)
+ while not tokenizer.AtEnd():
+ self._MergeField(tokenizer, message)
- Args:
- lines: Lines of a message's text representation.
- message: A protocol buffer message to merge into.
- allow_multiple_scalars: Determines if repeated values for a non-repeated
- field are permitted, e.g., the string "foo: 1 foo: 2" for a
- required/optional field named "foo".
- allow_unknown_extension: if True, skip over missing extensions and keep
- parsing
+ def _MergeField(self, tokenizer, message):
+ """Merges a single protocol message field into a message.
- Raises:
- ParseError: On text parsing problems.
- """
- tokenizer = _Tokenizer(lines)
- while not tokenizer.AtEnd():
- _MergeField(tokenizer, message, allow_multiple_scalars,
- allow_unknown_extension)
+ Args:
+ tokenizer: A tokenizer to parse the field name and values.
+ message: A protocol message to record the data.
+ Raises:
+ ParseError: In case of text parsing problems.
+ """
+ message_descriptor = message.DESCRIPTOR
+ if (message_descriptor.full_name == _ANY_FULL_TYPE_NAME and
+ tokenizer.TryConsume('[')):
+ type_url_prefix, packed_type_name = self._ConsumeAnyTypeUrl(tokenizer)
+ tokenizer.Consume(']')
+ tokenizer.TryConsume(':')
+ if tokenizer.TryConsume('<'):
+ expanded_any_end_token = '>'
+ else:
+ tokenizer.Consume('{')
+ expanded_any_end_token = '}'
+ expanded_any_sub_message = _BuildMessageFromTypeName(packed_type_name,
+ self.descriptor_pool)
+ if not expanded_any_sub_message:
+ raise ParseError('Type %s not found in descriptor pool' %
+ packed_type_name)
+ while not tokenizer.TryConsume(expanded_any_end_token):
+ if tokenizer.AtEnd():
+ raise tokenizer.ParseErrorPreviousToken('Expected "%s".' %
+ (expanded_any_end_token,))
+ self._MergeField(tokenizer, expanded_any_sub_message)
+ message.Pack(expanded_any_sub_message,
+ type_url_prefix=type_url_prefix)
+ return
-def _MergeField(tokenizer,
- message,
- allow_multiple_scalars,
- allow_unknown_extension=False):
- """Merges a single protocol message field into a message.
+ if tokenizer.TryConsume('['):
+ name = [tokenizer.ConsumeIdentifier()]
+ while tokenizer.TryConsume('.'):
+ name.append(tokenizer.ConsumeIdentifier())
+ name = '.'.join(name)
- Args:
- tokenizer: A tokenizer to parse the field name and values.
- message: A protocol message to record the data.
- allow_multiple_scalars: Determines if repeated values for a non-repeated
- field are permitted, e.g., the string "foo: 1 foo: 2" for a
- required/optional field named "foo".
- allow_unknown_extension: if True, skip over missing extensions and keep
- parsing
+ if not message_descriptor.is_extendable:
+ raise tokenizer.ParseErrorPreviousToken(
+ 'Message type "%s" does not have extensions.' %
+ message_descriptor.full_name)
+ # pylint: disable=protected-access
+ field = message.Extensions._FindExtensionByName(name)
+ # pylint: enable=protected-access
+ if not field:
+ if self.allow_unknown_extension:
+ field = None
+ else:
+ raise tokenizer.ParseErrorPreviousToken(
+ 'Extension "%s" not registered. '
+ 'Did you import the _pb2 module which defines it? '
+ 'If you are trying to place the extension in the MessageSet '
+ 'field of another message that is in an Any or MessageSet field, '
+ 'that message\'s _pb2 module must be imported as well' % name)
+ elif message_descriptor != field.containing_type:
+ raise tokenizer.ParseErrorPreviousToken(
+ 'Extension "%s" does not extend message type "%s".' %
+ (name, message_descriptor.full_name))
- Raises:
- ParseError: In case of text parsing problems.
- """
- message_descriptor = message.DESCRIPTOR
- if (hasattr(message_descriptor, 'syntax') and
- message_descriptor.syntax == 'proto3'):
- # Proto3 doesn't represent presence so we can't test if multiple
- # scalars have occurred. We have to allow them.
- allow_multiple_scalars = True
- if tokenizer.TryConsume('['):
+ tokenizer.Consume(']')
+
+ else:
+ name = tokenizer.ConsumeIdentifierOrNumber()
+ if self.allow_field_number and name.isdigit():
+ number = ParseInteger(name, True, True)
+ field = message_descriptor.fields_by_number.get(number, None)
+ if not field and message_descriptor.is_extendable:
+ field = message.Extensions._FindExtensionByNumber(number)
+ else:
+ field = message_descriptor.fields_by_name.get(name, None)
+
+ # Group names are expected to be capitalized as they appear in the
+ # .proto file, which actually matches their type names, not their field
+ # names.
+ if not field:
+ field = message_descriptor.fields_by_name.get(name.lower(), None)
+ if field and field.type != descriptor.FieldDescriptor.TYPE_GROUP:
+ field = None
+
+ if (field and field.type == descriptor.FieldDescriptor.TYPE_GROUP and
+ field.message_type.name != name):
+ field = None
+
+ if not field:
+ raise tokenizer.ParseErrorPreviousToken(
+ 'Message type "%s" has no field named "%s".' %
+ (message_descriptor.full_name, name))
+
+ if field:
+ if not self._allow_multiple_scalars and field.containing_oneof:
+ # Check if there's a different field set in this oneof.
+ # Note that we ignore the case if the same field was set before, and we
+ # apply _allow_multiple_scalars to non-scalar fields as well.
+ which_oneof = message.WhichOneof(field.containing_oneof.name)
+ if which_oneof is not None and which_oneof != field.name:
+ raise tokenizer.ParseErrorPreviousToken(
+ 'Field "%s" is specified along with field "%s", another member '
+ 'of oneof "%s" for message type "%s".' %
+ (field.name, which_oneof, field.containing_oneof.name,
+ message_descriptor.full_name))
+
+ if field.cpp_type == descriptor.FieldDescriptor.CPPTYPE_MESSAGE:
+ tokenizer.TryConsume(':')
+ merger = self._MergeMessageField
+ else:
+ tokenizer.Consume(':')
+ merger = self._MergeScalarField
+
+ if (field.label == descriptor.FieldDescriptor.LABEL_REPEATED and
+ tokenizer.TryConsume('[')):
+ # Short repeated format, e.g. "foo: [1, 2, 3]"
+ if not tokenizer.TryConsume(']'):
+ while True:
+ merger(tokenizer, message, field)
+ if tokenizer.TryConsume(']'):
+ break
+ tokenizer.Consume(',')
+
+ else:
+ merger(tokenizer, message, field)
+
+ else: # Proto field is unknown.
+ assert self.allow_unknown_extension
+ _SkipFieldContents(tokenizer)
+
+ # For historical reasons, fields may optionally be separated by commas or
+ # semicolons.
+ if not tokenizer.TryConsume(','):
+ tokenizer.TryConsume(';')
+
+ def _ConsumeAnyTypeUrl(self, tokenizer):
+ """Consumes a google.protobuf.Any type URL and returns the type name."""
+ # Consume "type.googleapis.com/".
+ prefix = [tokenizer.ConsumeIdentifier()]
+ tokenizer.Consume('.')
+ prefix.append(tokenizer.ConsumeIdentifier())
+ tokenizer.Consume('.')
+ prefix.append(tokenizer.ConsumeIdentifier())
+ tokenizer.Consume('/')
+ # Consume the fully-qualified type name.
name = [tokenizer.ConsumeIdentifier()]
while tokenizer.TryConsume('.'):
name.append(tokenizer.ConsumeIdentifier())
- name = '.'.join(name)
-
- if not message_descriptor.is_extendable:
- raise tokenizer.ParseErrorPreviousToken(
- 'Message type "%s" does not have extensions.' %
- message_descriptor.full_name)
- # pylint: disable=protected-access
- field = message.Extensions._FindExtensionByName(name)
- # pylint: enable=protected-access
- if not field:
- if allow_unknown_extension:
- field = None
- else:
- raise tokenizer.ParseErrorPreviousToken(
- 'Extension "%s" not registered.' % name)
- elif message_descriptor != field.containing_type:
- raise tokenizer.ParseErrorPreviousToken(
- 'Extension "%s" does not extend message type "%s".' % (
- name, message_descriptor.full_name))
+ return '.'.join(prefix), '.'.join(name)
- tokenizer.Consume(']')
+ def _MergeMessageField(self, tokenizer, message, field):
+ """Merges a single scalar field into a message.
- else:
- name = tokenizer.ConsumeIdentifier()
- field = message_descriptor.fields_by_name.get(name, None)
-
- # Group names are expected to be capitalized as they appear in the
- # .proto file, which actually matches their type names, not their field
- # names.
- if not field:
- field = message_descriptor.fields_by_name.get(name.lower(), None)
- if field and field.type != descriptor.FieldDescriptor.TYPE_GROUP:
- field = None
-
- if (field and field.type == descriptor.FieldDescriptor.TYPE_GROUP and
- field.message_type.name != name):
- field = None
-
- if not field:
- raise tokenizer.ParseErrorPreviousToken(
- 'Message type "%s" has no field named "%s".' % (
- message_descriptor.full_name, name))
-
- if field and field.cpp_type == descriptor.FieldDescriptor.CPPTYPE_MESSAGE:
+ Args:
+ tokenizer: A tokenizer to parse the field value.
+ message: The message of which field is a member.
+ field: The descriptor of the field to be merged.
+
+ Raises:
+ ParseError: In case of text parsing problems.
+ """
is_map_entry = _IsMapEntry(field)
- tokenizer.TryConsume(':')
if tokenizer.TryConsume('<'):
end_token = '>'
@@ -456,21 +809,32 @@ def _MergeField(tokenizer,
if field.is_extension:
sub_message = message.Extensions[field].add()
elif is_map_entry:
- sub_message = field.message_type._concrete_class()
+ sub_message = getattr(message, field.name).GetEntryClass()()
else:
sub_message = getattr(message, field.name).add()
else:
if field.is_extension:
+ if (not self._allow_multiple_scalars and
+ message.HasExtension(field)):
+ raise tokenizer.ParseErrorPreviousToken(
+ 'Message type "%s" should not have multiple "%s" extensions.' %
+ (message.DESCRIPTOR.full_name, field.full_name))
sub_message = message.Extensions[field]
else:
+ # Also apply _allow_multiple_scalars to message field.
+ # TODO(jieluo): Change to _allow_singular_overwrites.
+ if (not self._allow_multiple_scalars and
+ message.HasField(field.name)):
+ raise tokenizer.ParseErrorPreviousToken(
+ 'Message type "%s" should not have multiple "%s" fields.' %
+ (message.DESCRIPTOR.full_name, field.name))
sub_message = getattr(message, field.name)
sub_message.SetInParent()
while not tokenizer.TryConsume(end_token):
if tokenizer.AtEnd():
- raise tokenizer.ParseErrorPreviousToken('Expected "%s".' % (end_token))
- _MergeField(tokenizer, sub_message, allow_multiple_scalars,
- allow_unknown_extension)
+ raise tokenizer.ParseErrorPreviousToken('Expected "%s".' % (end_token,))
+ self._MergeField(tokenizer, sub_message)
if is_map_entry:
value_cpptype = field.message_type.fields_by_name['value'].cpp_type
@@ -479,26 +843,81 @@ def _MergeField(tokenizer,
value.MergeFrom(sub_message.value)
else:
getattr(message, field.name)[sub_message.key] = sub_message.value
- elif field:
- tokenizer.Consume(':')
- if (field.label == descriptor.FieldDescriptor.LABEL_REPEATED and
- tokenizer.TryConsume('[')):
- # Short repeated format, e.g. "foo: [1, 2, 3]"
- while True:
- _MergeScalarField(tokenizer, message, field, allow_multiple_scalars)
- if tokenizer.TryConsume(']'):
- break
- tokenizer.Consume(',')
+
+ @staticmethod
+ def _IsProto3Syntax(message):
+ message_descriptor = message.DESCRIPTOR
+ return (hasattr(message_descriptor, 'syntax') and
+ message_descriptor.syntax == 'proto3')
+
+ def _MergeScalarField(self, tokenizer, message, field):
+ """Merges a single scalar field into a message.
+
+ Args:
+ tokenizer: A tokenizer to parse the field value.
+ message: A protocol message to record the data.
+ field: The descriptor of the field to be merged.
+
+ Raises:
+ ParseError: In case of text parsing problems.
+ RuntimeError: On runtime errors.
+ """
+ _ = self.allow_unknown_extension
+ value = None
+
+ if field.type in (descriptor.FieldDescriptor.TYPE_INT32,
+ descriptor.FieldDescriptor.TYPE_SINT32,
+ descriptor.FieldDescriptor.TYPE_SFIXED32):
+ value = _ConsumeInt32(tokenizer)
+ elif field.type in (descriptor.FieldDescriptor.TYPE_INT64,
+ descriptor.FieldDescriptor.TYPE_SINT64,
+ descriptor.FieldDescriptor.TYPE_SFIXED64):
+ value = _ConsumeInt64(tokenizer)
+ elif field.type in (descriptor.FieldDescriptor.TYPE_UINT32,
+ descriptor.FieldDescriptor.TYPE_FIXED32):
+ value = _ConsumeUint32(tokenizer)
+ elif field.type in (descriptor.FieldDescriptor.TYPE_UINT64,
+ descriptor.FieldDescriptor.TYPE_FIXED64):
+ value = _ConsumeUint64(tokenizer)
+ elif field.type in (descriptor.FieldDescriptor.TYPE_FLOAT,
+ descriptor.FieldDescriptor.TYPE_DOUBLE):
+ value = tokenizer.ConsumeFloat()
+ elif field.type == descriptor.FieldDescriptor.TYPE_BOOL:
+ value = tokenizer.ConsumeBool()
+ elif field.type == descriptor.FieldDescriptor.TYPE_STRING:
+ value = tokenizer.ConsumeString()
+ elif field.type == descriptor.FieldDescriptor.TYPE_BYTES:
+ value = tokenizer.ConsumeByteString()
+ elif field.type == descriptor.FieldDescriptor.TYPE_ENUM:
+ value = tokenizer.ConsumeEnum(field)
else:
- _MergeScalarField(tokenizer, message, field, allow_multiple_scalars)
- else: # Proto field is unknown.
- assert allow_unknown_extension
- _SkipFieldContents(tokenizer)
+ raise RuntimeError('Unknown field type %d' % field.type)
- # For historical reasons, fields may optionally be separated by commas or
- # semicolons.
- if not tokenizer.TryConsume(','):
- tokenizer.TryConsume(';')
+ if field.label == descriptor.FieldDescriptor.LABEL_REPEATED:
+ if field.is_extension:
+ message.Extensions[field].append(value)
+ else:
+ getattr(message, field.name).append(value)
+ else:
+ # Proto3 doesn't represent presence so we can't test if multiple scalars
+ # have occurred. We have to allow them.
+ can_check_presence = not self._IsProto3Syntax(message)
+ if field.is_extension:
+ if (not self._allow_multiple_scalars and can_check_presence and
+ message.HasExtension(field)):
+ raise tokenizer.ParseErrorPreviousToken(
+ 'Message type "%s" should not have multiple "%s" extensions.' %
+ (message.DESCRIPTOR.full_name, field.full_name))
+ else:
+ message.Extensions[field] = value
+ else:
+ if (not self._allow_multiple_scalars and can_check_presence and
+ message.HasField(field.name)):
+ raise tokenizer.ParseErrorPreviousToken(
+ 'Message type "%s" should not have multiple "%s" fields.' %
+ (message.DESCRIPTOR.full_name, field.name))
+ else:
+ setattr(message, field.name, value)
def _SkipFieldContents(tokenizer):
@@ -533,7 +952,7 @@ def _SkipField(tokenizer):
tokenizer.ConsumeIdentifier()
tokenizer.Consume(']')
else:
- tokenizer.ConsumeIdentifier()
+ tokenizer.ConsumeIdentifierOrNumber()
_SkipFieldContents(tokenizer)
@@ -571,88 +990,20 @@ def _SkipFieldValue(tokenizer):
Raises:
ParseError: In case an invalid field value is found.
"""
- # String tokens can come in multiple adjacent string literals.
+ # String/bytes tokens can come in multiple adjacent string literals.
# If we can consume one, consume as many as we can.
- if tokenizer.TryConsumeString():
- while tokenizer.TryConsumeString():
+ if tokenizer.TryConsumeByteString():
+ while tokenizer.TryConsumeByteString():
pass
return
if (not tokenizer.TryConsumeIdentifier() and
- not tokenizer.TryConsumeInt64() and
- not tokenizer.TryConsumeUint64() and
+ not _TryConsumeInt64(tokenizer) and not _TryConsumeUint64(tokenizer) and
not tokenizer.TryConsumeFloat()):
raise ParseError('Invalid field value: ' + tokenizer.token)
-def _MergeScalarField(tokenizer, message, field, allow_multiple_scalars):
- """Merges a single protocol message scalar field into a message.
-
- Args:
- tokenizer: A tokenizer to parse the field value.
- message: A protocol message to record the data.
- field: The descriptor of the field to be merged.
- allow_multiple_scalars: Determines if repeated values for a non-repeated
- field are permitted, e.g., the string "foo: 1 foo: 2" for a
- required/optional field named "foo".
-
- Raises:
- ParseError: In case of text parsing problems.
- RuntimeError: On runtime errors.
- """
- value = None
-
- if field.type in (descriptor.FieldDescriptor.TYPE_INT32,
- descriptor.FieldDescriptor.TYPE_SINT32,
- descriptor.FieldDescriptor.TYPE_SFIXED32):
- value = tokenizer.ConsumeInt32()
- elif field.type in (descriptor.FieldDescriptor.TYPE_INT64,
- descriptor.FieldDescriptor.TYPE_SINT64,
- descriptor.FieldDescriptor.TYPE_SFIXED64):
- value = tokenizer.ConsumeInt64()
- elif field.type in (descriptor.FieldDescriptor.TYPE_UINT32,
- descriptor.FieldDescriptor.TYPE_FIXED32):
- value = tokenizer.ConsumeUint32()
- elif field.type in (descriptor.FieldDescriptor.TYPE_UINT64,
- descriptor.FieldDescriptor.TYPE_FIXED64):
- value = tokenizer.ConsumeUint64()
- elif field.type in (descriptor.FieldDescriptor.TYPE_FLOAT,
- descriptor.FieldDescriptor.TYPE_DOUBLE):
- value = tokenizer.ConsumeFloat()
- elif field.type == descriptor.FieldDescriptor.TYPE_BOOL:
- value = tokenizer.ConsumeBool()
- elif field.type == descriptor.FieldDescriptor.TYPE_STRING:
- value = tokenizer.ConsumeString()
- elif field.type == descriptor.FieldDescriptor.TYPE_BYTES:
- value = tokenizer.ConsumeByteString()
- elif field.type == descriptor.FieldDescriptor.TYPE_ENUM:
- value = tokenizer.ConsumeEnum(field)
- else:
- raise RuntimeError('Unknown field type %d' % field.type)
-
- if field.label == descriptor.FieldDescriptor.LABEL_REPEATED:
- if field.is_extension:
- message.Extensions[field].append(value)
- else:
- getattr(message, field.name).append(value)
- else:
- if field.is_extension:
- if not allow_multiple_scalars and message.HasExtension(field):
- raise tokenizer.ParseErrorPreviousToken(
- 'Message type "%s" should not have multiple "%s" extensions.' %
- (message.DESCRIPTOR.full_name, field.full_name))
- else:
- message.Extensions[field] = value
- else:
- if not allow_multiple_scalars and message.HasField(field.name):
- raise tokenizer.ParseErrorPreviousToken(
- 'Message type "%s" should not have multiple "%s" fields.' %
- (message.DESCRIPTOR.full_name, field.name))
- else:
- setattr(message, field.name, value)
-
-
-class _Tokenizer(object):
+class Tokenizer(object):
"""Protocol buffer text representation tokenizer.
This class handles the lower level string parsing by splitting it into
@@ -661,17 +1012,20 @@ class _Tokenizer(object):
It was directly ported from the Java protocol buffer API.
"""
- _WHITESPACE = re.compile('(\\s|(#.*$))+', re.MULTILINE)
+ _WHITESPACE = re.compile(r'\s+')
+ _COMMENT = re.compile(r'(\s*#.*$)', re.MULTILINE)
+ _WHITESPACE_OR_COMMENT = re.compile(r'(\s|(#.*$))+', re.MULTILINE)
_TOKEN = re.compile('|'.join([
- r'[a-zA-Z_][0-9a-zA-Z_+-]*', # an identifier
+ r'[a-zA-Z_][0-9a-zA-Z_+-]*', # an identifier
r'([0-9+-]|(\.[0-9]))[0-9a-zA-Z_.+-]*', # a number
- ] + [ # quoted str for each quote mark
+ ] + [ # quoted str for each quote mark
r'{qt}([^{qt}\n\\]|\\.)*({qt}|\\?$)'.format(qt=mark) for mark in _QUOTES
]))
- _IDENTIFIER = re.compile(r'\w+')
+ _IDENTIFIER = re.compile(r'[^\d\W]\w*')
+ _IDENTIFIER_OR_NUMBER = re.compile(r'\w+')
- def __init__(self, lines):
+ def __init__(self, lines, skip_comments=True):
self._position = 0
self._line = -1
self._column = 0
@@ -682,6 +1036,9 @@ class _Tokenizer(object):
self._previous_line = 0
self._previous_column = 0
self._more_lines = True
+ self._skip_comments = skip_comments
+ self._whitespace_pattern = (skip_comments and self._WHITESPACE_OR_COMMENT
+ or self._WHITESPACE)
self._SkipWhitespace()
self.NextToken()
@@ -711,7 +1068,7 @@ class _Tokenizer(object):
def _SkipWhitespace(self):
while True:
self._PopLine()
- match = self._WHITESPACE.match(self._current_line, self._column)
+ match = self._whitespace_pattern.match(self._current_line, self._column)
if not match:
break
length = len(match.group(0))
@@ -741,7 +1098,30 @@ class _Tokenizer(object):
ParseError: If the text couldn't be consumed.
"""
if not self.TryConsume(token):
- raise self._ParseError('Expected "%s".' % token)
+ raise self.ParseError('Expected "%s".' % token)
+
+ def ConsumeComment(self):
+ result = self.token
+ if not self._COMMENT.match(result):
+ raise self.ParseError('Expected comment.')
+ self.NextToken()
+ return result
+
+ def ConsumeCommentOrTrailingComment(self):
+ """Consumes a comment, returns a 2-tuple (trailing bool, comment str)."""
+
+ # Tokenizer initializes _previous_line and _previous_column to 0. As the
+ # tokenizer starts, it looks like there is a previous token on the line.
+ just_started = self._line == 0 and self._column == 0
+
+ before_parsing = self._previous_line
+ comment = self.ConsumeComment()
+
+ # A trailing comment is a comment on the same line than the previous token.
+ trailing = (self._previous_line == before_parsing
+ and not just_started)
+
+ return trailing, comment
def TryConsumeIdentifier(self):
try:
@@ -761,85 +1141,55 @@ class _Tokenizer(object):
"""
result = self.token
if not self._IDENTIFIER.match(result):
- raise self._ParseError('Expected identifier.')
- self.NextToken()
- return result
-
- def ConsumeInt32(self):
- """Consumes a signed 32bit integer number.
-
- Returns:
- The integer parsed.
-
- Raises:
- ParseError: If a signed 32bit integer couldn't be consumed.
- """
- try:
- result = ParseInteger(self.token, is_signed=True, is_long=False)
- except ValueError as e:
- raise self._ParseError(str(e))
- self.NextToken()
- return result
-
- def ConsumeUint32(self):
- """Consumes an unsigned 32bit integer number.
-
- Returns:
- The integer parsed.
-
- Raises:
- ParseError: If an unsigned 32bit integer couldn't be consumed.
- """
- try:
- result = ParseInteger(self.token, is_signed=False, is_long=False)
- except ValueError as e:
- raise self._ParseError(str(e))
+ raise self.ParseError('Expected identifier.')
self.NextToken()
return result
- def TryConsumeInt64(self):
+ def TryConsumeIdentifierOrNumber(self):
try:
- self.ConsumeInt64()
+ self.ConsumeIdentifierOrNumber()
return True
except ParseError:
return False
- def ConsumeInt64(self):
- """Consumes a signed 64bit integer number.
+ def ConsumeIdentifierOrNumber(self):
+ """Consumes protocol message field identifier.
Returns:
- The integer parsed.
+ Identifier string.
Raises:
- ParseError: If a signed 64bit integer couldn't be consumed.
+ ParseError: If an identifier couldn't be consumed.
"""
- try:
- result = ParseInteger(self.token, is_signed=True, is_long=True)
- except ValueError as e:
- raise self._ParseError(str(e))
+ result = self.token
+ if not self._IDENTIFIER_OR_NUMBER.match(result):
+ raise self.ParseError('Expected identifier or number, got %s.' % result)
self.NextToken()
return result
- def TryConsumeUint64(self):
+ def TryConsumeInteger(self):
try:
- self.ConsumeUint64()
+ # Note: is_long only affects value type, not whether an error is raised.
+ self.ConsumeInteger()
return True
except ParseError:
return False
- def ConsumeUint64(self):
- """Consumes an unsigned 64bit integer number.
+ def ConsumeInteger(self, is_long=False):
+ """Consumes an integer number.
+ Args:
+ is_long: True if the value should be returned as a long integer.
Returns:
The integer parsed.
Raises:
- ParseError: If an unsigned 64bit integer couldn't be consumed.
+ ParseError: If an integer couldn't be consumed.
"""
try:
- result = ParseInteger(self.token, is_signed=False, is_long=True)
+ result = _ParseAbstractInteger(self.token, is_long=is_long)
except ValueError as e:
- raise self._ParseError(str(e))
+ raise self.ParseError(str(e))
self.NextToken()
return result
@@ -862,7 +1212,7 @@ class _Tokenizer(object):
try:
result = ParseFloat(self.token)
except ValueError as e:
- raise self._ParseError(str(e))
+ raise self.ParseError(str(e))
self.NextToken()
return result
@@ -878,13 +1228,13 @@ class _Tokenizer(object):
try:
result = ParseBool(self.token)
except ValueError as e:
- raise self._ParseError(str(e))
+ raise self.ParseError(str(e))
self.NextToken()
return result
- def TryConsumeString(self):
+ def TryConsumeByteString(self):
try:
- self.ConsumeString()
+ self.ConsumeByteString()
return True
except ParseError:
return False
@@ -932,15 +1282,15 @@ class _Tokenizer(object):
"""
text = self.token
if len(text) < 1 or text[0] not in _QUOTES:
- raise self._ParseError('Expected string but found: %r' % (text,))
+ raise self.ParseError('Expected string but found: %r' % (text,))
if len(text) < 2 or text[-1] != text[0]:
- raise self._ParseError('String missing ending quote: %r' % (text,))
+ raise self.ParseError('String missing ending quote: %r' % (text,))
try:
result = text_encoding.CUnescape(text[1:-1])
except ValueError as e:
- raise self._ParseError(str(e))
+ raise self.ParseError(str(e))
self.NextToken()
return result
@@ -948,7 +1298,7 @@ class _Tokenizer(object):
try:
result = ParseEnum(field, self.token)
except ValueError as e:
- raise self._ParseError(str(e))
+ raise self.ParseError(str(e))
self.NextToken()
return result
@@ -961,16 +1311,15 @@ class _Tokenizer(object):
Returns:
A ParseError instance.
"""
- return ParseError('%d:%d : %s' % (
- self._previous_line + 1, self._previous_column + 1, message))
+ return ParseError(message, self._previous_line + 1,
+ self._previous_column + 1)
- def _ParseError(self, message):
+ def ParseError(self, message):
"""Creates and *returns* a ParseError for the current token."""
- return ParseError('%d:%d : %s' % (
- self._line + 1, self._column + 1, message))
+ return ParseError(message, self._line + 1, self._column + 1)
def _StringParseError(self, e):
- return self._ParseError('Couldn\'t parse string: ' + str(e))
+ return self.ParseError('Couldn\'t parse string: ' + str(e))
def NextToken(self):
"""Reads the next meaningful token."""
@@ -985,12 +1334,124 @@ class _Tokenizer(object):
return
match = self._TOKEN.match(self._current_line, self._column)
+ if not match and not self._skip_comments:
+ match = self._COMMENT.match(self._current_line, self._column)
if match:
token = match.group(0)
self.token = token
else:
self.token = self._current_line[self._column]
+# Aliased so it can still be accessed by current visibility violators.
+# TODO(dbarnett): Migrate violators to textformat_tokenizer.
+_Tokenizer = Tokenizer # pylint: disable=invalid-name
+
+
+def _ConsumeInt32(tokenizer):
+ """Consumes a signed 32bit integer number from tokenizer.
+
+ Args:
+ tokenizer: A tokenizer used to parse the number.
+
+ Returns:
+ The integer parsed.
+
+ Raises:
+ ParseError: If a signed 32bit integer couldn't be consumed.
+ """
+ return _ConsumeInteger(tokenizer, is_signed=True, is_long=False)
+
+
+def _ConsumeUint32(tokenizer):
+ """Consumes an unsigned 32bit integer number from tokenizer.
+
+ Args:
+ tokenizer: A tokenizer used to parse the number.
+
+ Returns:
+ The integer parsed.
+
+ Raises:
+ ParseError: If an unsigned 32bit integer couldn't be consumed.
+ """
+ return _ConsumeInteger(tokenizer, is_signed=False, is_long=False)
+
+
+def _TryConsumeInt64(tokenizer):
+ try:
+ _ConsumeInt64(tokenizer)
+ return True
+ except ParseError:
+ return False
+
+
+def _ConsumeInt64(tokenizer):
+ """Consumes a signed 32bit integer number from tokenizer.
+
+ Args:
+ tokenizer: A tokenizer used to parse the number.
+
+ Returns:
+ The integer parsed.
+
+ Raises:
+ ParseError: If a signed 32bit integer couldn't be consumed.
+ """
+ return _ConsumeInteger(tokenizer, is_signed=True, is_long=True)
+
+
+def _TryConsumeUint64(tokenizer):
+ try:
+ _ConsumeUint64(tokenizer)
+ return True
+ except ParseError:
+ return False
+
+
+def _ConsumeUint64(tokenizer):
+ """Consumes an unsigned 64bit integer number from tokenizer.
+
+ Args:
+ tokenizer: A tokenizer used to parse the number.
+
+ Returns:
+ The integer parsed.
+
+ Raises:
+ ParseError: If an unsigned 64bit integer couldn't be consumed.
+ """
+ return _ConsumeInteger(tokenizer, is_signed=False, is_long=True)
+
+
+def _TryConsumeInteger(tokenizer, is_signed=False, is_long=False):
+ try:
+ _ConsumeInteger(tokenizer, is_signed=is_signed, is_long=is_long)
+ return True
+ except ParseError:
+ return False
+
+
+def _ConsumeInteger(tokenizer, is_signed=False, is_long=False):
+ """Consumes an integer number from tokenizer.
+
+ Args:
+ tokenizer: A tokenizer used to parse the number.
+ is_signed: True if a signed integer must be parsed.
+ is_long: True if a long integer must be parsed.
+
+ Returns:
+ The integer parsed.
+
+ Raises:
+ ParseError: If an integer with given characteristics couldn't be consumed.
+ """
+ try:
+ result = ParseInteger(tokenizer.token, is_signed=is_signed, is_long=is_long)
+ except ValueError as e:
+ raise tokenizer.ParseError(str(e))
+ tokenizer.NextToken()
+ return result
+
def ParseInteger(text, is_signed=False, is_long=False):
"""Parses an integer.
@@ -1007,22 +1468,39 @@ def ParseInteger(text, is_signed=False, is_long=False):
ValueError: Thrown Iff the text is not a valid integer.
"""
# Do the actual parsing. Exception handling is propagated to caller.
+ result = _ParseAbstractInteger(text, is_long=is_long)
+
+ # Check if the integer is sane. Exceptions handled by callers.
+ checker = _INTEGER_CHECKERS[2 * int(is_long) + int(is_signed)]
+ checker.CheckValue(result)
+ return result
+
+
+def _ParseAbstractInteger(text, is_long=False):
+ """Parses an integer without checking size/signedness.
+
+ Args:
+ text: The text to parse.
+ is_long: True if the value should be returned as a long integer.
+
+ Returns:
+ The integer value.
+
+ Raises:
+ ValueError: Thrown Iff the text is not a valid integer.
+ """
+ # Do the actual parsing. Exception handling is propagated to caller.
try:
# We force 32-bit values to int and 64-bit values to long to make
# alternate implementations where the distinction is more significant
# (e.g. the C++ implementation) simpler.
if is_long:
- result = long(text, 0)
+ return long(text, 0)
else:
- result = int(text, 0)
+ return int(text, 0)
except ValueError:
raise ValueError('Couldn\'t parse integer: %s' % text)
- # Check if the integer is sane. Exceptions handled by callers.
- checker = _INTEGER_CHECKERS[2 * int(is_long) + int(is_signed)]
- checker.CheckValue(result)
- return result
-
def ParseFloat(text):
"""Parse a floating point number.
@@ -1068,9 +1546,9 @@ def ParseBool(text):
Raises:
ValueError: If text is not a valid boolean.
"""
- if text in ('true', 't', '1'):
+ if text in ('true', 't', '1', 'True'):
return True
- elif text in ('false', 'f', '0'):
+ elif text in ('false', 'f', '0', 'False'):
return False
else:
raise ValueError('Expected "true" or "false".')
@@ -1099,14 +1577,17 @@ def ParseEnum(field, value):
# Identifier.
enum_value = enum_descriptor.values_by_name.get(value, None)
if enum_value is None:
- raise ValueError(
- 'Enum type "%s" has no value named %s.' % (
- enum_descriptor.full_name, value))
+ raise ValueError('Enum type "%s" has no value named %s.' %
+ (enum_descriptor.full_name, value))
else:
# Numeric value.
+ if hasattr(field.file, 'syntax'):
+ # Attribute is checked for compatibility.
+ if field.file.syntax == 'proto3':
+ # Proto3 accept numeric unknown enums.
+ return number
enum_value = enum_descriptor.values_by_number.get(number, None)
if enum_value is None:
- raise ValueError(
- 'Enum type "%s" has no value with number %d.' % (
- enum_descriptor.full_name, number))
+ raise ValueError('Enum type "%s" has no value with number %d.' %
+ (enum_descriptor.full_name, number))
return enum_value.number
diff --git a/python/google/protobuf/util/__init__.py b/python/google/protobuf/util/__init__.py
new file mode 100644
index 00000000..e69de29b
--- /dev/null
+++ b/python/google/protobuf/util/__init__.py