author    Bo Yang <teboring@google.com>  2015-05-21 14:28:59 -0700
committer Bo Yang <teboring@google.com>  2015-05-21 19:32:02 -0700
commit    5db217305f37a79eeccd70f000088a06ec82fcec (patch)
tree      be53dcf0c0b47ef9178ab8a6fa5c1946ee84a28f /python
parent    56095026ccc2f755a6fdb296e30c3ddec8f556a2 (diff)
down-integrate internal changes
Diffstat (limited to 'python')
-rwxr-xr-x  python/google/protobuf/descriptor.py                             46
-rw-r--r--  python/google/protobuf/descriptor_pool.py                        39
-rwxr-xr-x  python/google/protobuf/internal/containers.py                   297
-rwxr-xr-x  python/google/protobuf/internal/decoder.py                       44
-rw-r--r--  python/google/protobuf/internal/descriptor_database_test.py       1
-rw-r--r--  python/google/protobuf/internal/descriptor_pool_test.py          45
-rwxr-xr-x  python/google/protobuf/internal/descriptor_test.py                2
-rwxr-xr-x  python/google/protobuf/internal/encoder.py                       55
-rwxr-xr-x  python/google/protobuf/internal/generator_test.py                 1
-rw-r--r--  python/google/protobuf/internal/message_factory_test.py           1
-rwxr-xr-x  python/google/protobuf/internal/message_test.py                 521
-rw-r--r--  python/google/protobuf/internal/proto_builder_test.py            21
-rwxr-xr-x  python/google/protobuf/internal/python_message.py               159
-rwxr-xr-x  python/google/protobuf/internal/reflection_test.py                6
-rwxr-xr-x  python/google/protobuf/internal/service_reflection_test.py        3
-rw-r--r--  python/google/protobuf/internal/symbol_database_test.py           1
-rwxr-xr-x  python/google/protobuf/internal/test_util.py                     24
-rwxr-xr-x  python/google/protobuf/internal/text_encoding_test.py             1
-rwxr-xr-x  python/google/protobuf/internal/text_format_test.py             141
-rwxr-xr-x  python/google/protobuf/internal/type_checkers.py                  9
-rwxr-xr-x  python/google/protobuf/internal/unknown_fields_test.py            1
-rwxr-xr-x  python/google/protobuf/internal/wire_format_test.py               1
-rw-r--r--  python/google/protobuf/proto_builder.py                          20
-rw-r--r--  python/google/protobuf/pyext/descriptor.cc                      287
-rw-r--r--  python/google/protobuf/pyext/descriptor.h                        21
-rw-r--r--  python/google/protobuf/pyext/descriptor_containers.cc            26
-rw-r--r--  python/google/protobuf/pyext/descriptor_containers.h              5
-rw-r--r--  python/google/protobuf/pyext/descriptor_pool.cc                 185
-rw-r--r--  python/google/protobuf/pyext/descriptor_pool.h                    8
-rw-r--r--  python/google/protobuf/pyext/extension_dict.cc                   17
-rw-r--r--  python/google/protobuf/pyext/message.cc                         432
-rw-r--r--  python/google/protobuf/pyext/message.h                           13
-rw-r--r--  python/google/protobuf/pyext/message_map_container.cc           540
-rw-r--r--  python/google/protobuf/pyext/message_map_container.h            117
-rw-r--r--  python/google/protobuf/pyext/repeated_composite_container.cc     56
-rw-r--r--  python/google/protobuf/pyext/repeated_composite_container.h      10
-rw-r--r--  python/google/protobuf/pyext/repeated_scalar_container.cc         8
-rw-r--r--  python/google/protobuf/pyext/scalar_map_container.cc            514
-rw-r--r--  python/google/protobuf/pyext/scalar_map_container.h             110
-rwxr-xr-x  python/google/protobuf/reflection.py                              1
-rwxr-xr-x  python/google/protobuf/text_format.py                            36
-rwxr-xr-x  python/setup.py                                                   1
42 files changed, 3282 insertions, 544 deletions
diff --git a/python/google/protobuf/descriptor.py b/python/google/protobuf/descriptor.py
index f7a58ca0..970b1a88 100755
--- a/python/google/protobuf/descriptor.py
+++ b/python/google/protobuf/descriptor.py
@@ -245,9 +245,6 @@ class Descriptor(_NestedDescriptorBase):
is_extendable: Does this type define any extension ranges?
- options: (descriptor_pb2.MessageOptions) Protocol message options or None
- to use default message options.
-
oneofs: (list of OneofDescriptor) The list of descriptors for oneof fields
in this message.
oneofs_by_name: (dict str -> OneofDescriptor) Same objects as in |oneofs|,
@@ -265,7 +262,7 @@ class Descriptor(_NestedDescriptorBase):
file=None, serialized_start=None, serialized_end=None,
syntax=None):
_message.Message._CheckCalledFromGeneratedFile()
- return _message.Message._GetMessageDescriptor(full_name)
+ return _message.default_pool.FindMessageTypeByName(full_name)
# NOTE(tmarek): The file argument redefining a builtin is nothing we can
# fix right now since we don't know how many clients already rely on the
@@ -495,9 +492,9 @@ class FieldDescriptor(DescriptorBase):
has_default_value=True, containing_oneof=None):
_message.Message._CheckCalledFromGeneratedFile()
if is_extension:
- return _message.Message._GetExtensionDescriptor(full_name)
+ return _message.default_pool.FindExtensionByName(full_name)
else:
- return _message.Message._GetFieldDescriptor(full_name)
+ return _message.default_pool.FindFieldByName(full_name)
def __init__(self, name, full_name, index, number, type, cpp_type, label,
default_value, message_type, enum_type, containing_type,
@@ -528,14 +525,9 @@ class FieldDescriptor(DescriptorBase):
self.containing_oneof = containing_oneof
if api_implementation.Type() == 'cpp':
if is_extension:
- # pylint: disable=protected-access
- self._cdescriptor = (
- _message.Message._GetExtensionDescriptor(full_name))
- # pylint: enable=protected-access
+ self._cdescriptor = _message.default_pool.FindExtensionByName(full_name)
else:
- # pylint: disable=protected-access
- self._cdescriptor = _message.Message._GetFieldDescriptor(full_name)
- # pylint: enable=protected-access
+ self._cdescriptor = _message.default_pool.FindFieldByName(full_name)
else:
self._cdescriptor = None
@@ -592,7 +584,7 @@ class EnumDescriptor(_NestedDescriptorBase):
containing_type=None, options=None, file=None,
serialized_start=None, serialized_end=None):
_message.Message._CheckCalledFromGeneratedFile()
- return _message.Message._GetEnumDescriptor(full_name)
+ return _message.default_pool.FindEnumTypeByName(full_name)
def __init__(self, name, full_name, filename, values,
containing_type=None, options=None, file=None,
@@ -677,7 +669,7 @@ class OneofDescriptor(object):
def __new__(cls, name, full_name, index, containing_type, fields):
_message.Message._CheckCalledFromGeneratedFile()
- return _message.Message._GetOneofDescriptor(full_name)
+ return _message.default_pool.FindOneofByName(full_name)
def __init__(self, name, full_name, index, containing_type, fields):
"""Arguments are as described in the attribute description above."""
@@ -788,12 +780,8 @@ class FileDescriptor(DescriptorBase):
dependencies=None, syntax=None):
# FileDescriptor() is called from various places, not only from generated
# files, to register dynamic proto files and messages.
- # TODO(amauryfa): Expose BuildFile() as a public function and make this
- # constructor an implementation detail.
if serialized_pb:
- # pylint: disable=protected-access
- return _message.Message._BuildFile(serialized_pb)
- # pylint: enable=protected-access
+ return _message.default_pool.AddSerializedFile(serialized_pb)
else:
return super(FileDescriptor, cls).__new__(cls)
@@ -814,9 +802,7 @@ class FileDescriptor(DescriptorBase):
if (api_implementation.Type() == 'cpp' and
self.serialized_pb is not None):
- # pylint: disable=protected-access
- _message.Message._BuildFile(self.serialized_pb)
- # pylint: enable=protected-access
+ _message.default_pool.AddSerializedFile(self.serialized_pb)
def CopyToProto(self, proto):
"""Copies this to a descriptor_pb2.FileDescriptorProto.
@@ -864,10 +850,10 @@ def MakeDescriptor(desc_proto, package='', build_file_if_cpp=True,
file_descriptor_proto = descriptor_pb2.FileDescriptorProto()
file_descriptor_proto.message_type.add().MergeFrom(desc_proto)
- # Generate a random name for this proto file to prevent conflicts with
- # any imported ones. We need to specify a file name so BuildFile accepts
- # our FileDescriptorProto, but it is not important what that file name
- # is actually set to.
+ # Generate a random name for this proto file to prevent conflicts with any
+ # imported ones. We need to specify a file name so the descriptor pool
+ # accepts our FileDescriptorProto, but it is not important what that file
+ # name is actually set to.
proto_name = str(uuid.uuid4())
if package:
@@ -877,10 +863,8 @@ def MakeDescriptor(desc_proto, package='', build_file_if_cpp=True,
else:
file_descriptor_proto.name = proto_name + '.proto'
- # pylint: disable=protected-access
- result = _message.Message._BuildFile(
- file_descriptor_proto.SerializeToString())
- # pylint: enable=protected-access
+ _message.default_pool.Add(file_descriptor_proto)
+ result = _message.default_pool.FindFileByName(file_descriptor_proto.name)
if _USE_C_DESCRIPTORS:
return result.message_types_by_name[desc_proto.name]
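
For illustration, a minimal sketch (not from this patch) of the pool-based lookup that the descriptor.py changes rely on: generated-code constructors now resolve descriptors through a descriptor pool (_message.default_pool under the C++ implementation), and MakeDescriptor registers its synthesized FileDescriptorProto via pool.Add() before looking it up. The file, package, and message names below are invented:

    from google.protobuf import descriptor_pb2
    from google.protobuf import descriptor_pool

    # Build a throwaway FileDescriptorProto with one message and one field.
    file_proto = descriptor_pb2.FileDescriptorProto()
    file_proto.name = 'example/fake_file.proto'
    file_proto.package = 'example'
    msg_proto = file_proto.message_type.add()
    msg_proto.name = 'FakeMessage'
    field_proto = msg_proto.field.add()
    field_proto.name = 'id'
    field_proto.number = 1
    field_proto.type = descriptor_pb2.FieldDescriptorProto.TYPE_INT32
    field_proto.label = descriptor_pb2.FieldDescriptorProto.LABEL_OPTIONAL

    # Register the proto, then resolve descriptors by name, mirroring the new
    # default_pool.Add()/FindFileByName() sequence in MakeDescriptor.
    pool = descriptor_pool.DescriptorPool()
    pool.Add(file_proto)
    file_desc = pool.FindFileByName('example/fake_file.proto')
    msg_desc = pool.FindMessageTypeByName('example.FakeMessage')
    assert msg_desc.fields_by_name['id'].number == 1
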
diff --git a/python/google/protobuf/descriptor_pool.py b/python/google/protobuf/descriptor_pool.py
index 7e7701f8..1244ba7c 100644
--- a/python/google/protobuf/descriptor_pool.py
+++ b/python/google/protobuf/descriptor_pool.py
@@ -113,6 +113,20 @@ class DescriptorPool(object):
self._internal_db.Add(file_desc_proto)
+ def AddSerializedFile(self, serialized_file_desc_proto):
+ """Adds the FileDescriptorProto and its types to this pool.
+
+ Args:
+ serialized_file_desc_proto: A bytes string, serialization of the
+ FileDescriptorProto to add.
+ """
+
+ # pylint: disable=g-import-not-at-top
+ from google.protobuf import descriptor_pb2
+ file_desc_proto = descriptor_pb2.FileDescriptorProto.FromString(
+ serialized_file_desc_proto)
+ self.Add(file_desc_proto)
+
def AddDescriptor(self, desc):
"""Adds a Descriptor to the pool, non-recursively.
@@ -320,17 +334,17 @@ class DescriptorPool(object):
file_descriptor, None, scope))
for index, extension_proto in enumerate(file_proto.extension):
- extension_desc = self.MakeFieldDescriptor(
+ extension_desc = self._MakeFieldDescriptor(
extension_proto, file_proto.package, index, is_extension=True)
extension_desc.containing_type = self._GetTypeFromScope(
file_descriptor.package, extension_proto.extendee, scope)
- self.SetFieldType(extension_proto, extension_desc,
+ self._SetFieldType(extension_proto, extension_desc,
file_descriptor.package, scope)
file_descriptor.extensions_by_name[extension_desc.name] = (
extension_desc)
for desc_proto in file_proto.message_type:
- self.SetAllFieldTypes(file_proto.package, desc_proto, scope)
+ self._SetAllFieldTypes(file_proto.package, desc_proto, scope)
if file_proto.package:
desc_proto_prefix = _PrefixWithDot(file_proto.package)
@@ -381,10 +395,11 @@ class DescriptorPool(object):
enums = [
self._ConvertEnumDescriptor(enum, desc_name, file_desc, None, scope)
for enum in desc_proto.enum_type]
- fields = [self.MakeFieldDescriptor(field, desc_name, index)
+ fields = [self._MakeFieldDescriptor(field, desc_name, index)
for index, field in enumerate(desc_proto.field)]
extensions = [
- self.MakeFieldDescriptor(extension, desc_name, index, is_extension=True)
+ self._MakeFieldDescriptor(extension, desc_name, index,
+ is_extension=True)
for index, extension in enumerate(desc_proto.extension)]
oneofs = [
descriptor.OneofDescriptor(desc.name, '.'.join((desc_name, desc.name)),
@@ -464,8 +479,8 @@ class DescriptorPool(object):
self._enum_descriptors[enum_name] = desc
return desc
- def MakeFieldDescriptor(self, field_proto, message_name, index,
- is_extension=False):
+ def _MakeFieldDescriptor(self, field_proto, message_name, index,
+ is_extension=False):
"""Creates a field descriptor from a FieldDescriptorProto.
For message and enum type fields, this method will do a look up
@@ -506,7 +521,7 @@ class DescriptorPool(object):
extension_scope=None,
options=field_proto.options)
- def SetAllFieldTypes(self, package, desc_proto, scope):
+ def _SetAllFieldTypes(self, package, desc_proto, scope):
"""Sets all the descriptor's fields's types.
This method also sets the containing types on any extensions.
@@ -527,18 +542,18 @@ class DescriptorPool(object):
nested_package = '.'.join([package, desc_proto.name])
for field_proto, field_desc in zip(desc_proto.field, main_desc.fields):
- self.SetFieldType(field_proto, field_desc, nested_package, scope)
+ self._SetFieldType(field_proto, field_desc, nested_package, scope)
for extension_proto, extension_desc in (
zip(desc_proto.extension, main_desc.extensions)):
extension_desc.containing_type = self._GetTypeFromScope(
nested_package, extension_proto.extendee, scope)
- self.SetFieldType(extension_proto, extension_desc, nested_package, scope)
+ self._SetFieldType(extension_proto, extension_desc, nested_package, scope)
for nested_type in desc_proto.nested_type:
- self.SetAllFieldTypes(nested_package, nested_type, scope)
+ self._SetAllFieldTypes(nested_package, nested_type, scope)
- def SetFieldType(self, field_proto, field_desc, package, scope):
+ def _SetFieldType(self, field_proto, field_desc, package, scope):
"""Sets the field's type, cpp_type, message_type and enum_type.
Args:
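
For illustration, a minimal sketch (not from this patch's test suite) of the new DescriptorPool.AddSerializedFile(), which accepts the serialized bytes of a FileDescriptorProto rather than the parsed message; the file and message names are invented:

    from google.protobuf import descriptor_pb2
    from google.protobuf import descriptor_pool

    file_proto = descriptor_pb2.FileDescriptorProto(
        name='example/serialized_demo.proto', package='example')
    file_proto.message_type.add(name='Empty')

    pool = descriptor_pool.DescriptorPool()
    # Equivalent to FileDescriptorProto.FromString() followed by pool.Add().
    pool.AddSerializedFile(file_proto.SerializeToString())
    assert pool.FindMessageTypeByName('example.Empty') is not None
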
diff --git a/python/google/protobuf/internal/containers.py b/python/google/protobuf/internal/containers.py
index d976f9e1..72c2fa01 100755
--- a/python/google/protobuf/internal/containers.py
+++ b/python/google/protobuf/internal/containers.py
@@ -41,6 +41,146 @@ are:
__author__ = 'petar@google.com (Petar Petrov)'
+import sys
+
+if sys.version_info[0] < 3:
+ # We would use collections.MutableMapping all the time, but in Python 2 it
+ # doesn't define __slots__. This causes two significant problems:
+ #
+ # 1. we can't disallow arbitrary attribute assignment, even if our derived
+ # classes *do* define __slots__.
+ #
+ # 2. we can't safely derive a C type from it without __slots__ defined (the
+ # interpreter expects to find a dict at tp_dictoffset, which we can't
+ # robustly provide. And we don't want an instance dict anyway.
+ #
+ # So this is the Python 2.7 definition of Mapping/MutableMapping functions
+ # verbatim, except that:
+ # 1. We declare __slots__.
+ # 2. We don't declare this as a virtual base class. The classes defined
+ # in collections are the interesting base classes, not us.
+ #
+ # Note: deriving from object is critical. It is the only thing that makes
+ # this a true type, allowing us to derive from it in C++ cleanly and making
+ # __slots__ properly disallow arbitrary element assignment.
+ from collections import Mapping as _Mapping
+
+ class Mapping(object):
+ __slots__ = ()
+
+ def get(self, key, default=None):
+ try:
+ return self[key]
+ except KeyError:
+ return default
+
+ def __contains__(self, key):
+ try:
+ self[key]
+ except KeyError:
+ return False
+ else:
+ return True
+
+ def iterkeys(self):
+ return iter(self)
+
+ def itervalues(self):
+ for key in self:
+ yield self[key]
+
+ def iteritems(self):
+ for key in self:
+ yield (key, self[key])
+
+ def keys(self):
+ return list(self)
+
+ def items(self):
+ return [(key, self[key]) for key in self]
+
+ def values(self):
+ return [self[key] for key in self]
+
+ # Mappings are not hashable by default, but subclasses can change this
+ __hash__ = None
+
+ def __eq__(self, other):
+ if not isinstance(other, _Mapping):
+ return NotImplemented
+ return dict(self.items()) == dict(other.items())
+
+ def __ne__(self, other):
+ return not (self == other)
+
+ class MutableMapping(Mapping):
+ __slots__ = ()
+
+ __marker = object()
+
+ def pop(self, key, default=__marker):
+ try:
+ value = self[key]
+ except KeyError:
+ if default is self.__marker:
+ raise
+ return default
+ else:
+ del self[key]
+ return value
+
+ def popitem(self):
+ try:
+ key = next(iter(self))
+ except StopIteration:
+ raise KeyError
+ value = self[key]
+ del self[key]
+ return key, value
+
+ def clear(self):
+ try:
+ while True:
+ self.popitem()
+ except KeyError:
+ pass
+
+ def update(*args, **kwds):
+ if len(args) > 2:
+ raise TypeError("update() takes at most 2 positional "
+ "arguments ({} given)".format(len(args)))
+ elif not args:
+ raise TypeError("update() takes at least 1 argument (0 given)")
+ self = args[0]
+ other = args[1] if len(args) >= 2 else ()
+
+ if isinstance(other, Mapping):
+ for key in other:
+ self[key] = other[key]
+ elif hasattr(other, "keys"):
+ for key in other.keys():
+ self[key] = other[key]
+ else:
+ for key, value in other:
+ self[key] = value
+ for key, value in kwds.items():
+ self[key] = value
+
+ def setdefault(self, key, default=None):
+ try:
+ return self[key]
+ except KeyError:
+ self[key] = default
+ return default
+
+ _Mapping.register(Mapping)
+
+else:
+ # In Python 3 we can just use MutableMapping directly, because it defines
+ # __slots__.
+ from collections import MutableMapping
+
+
class BaseContainer(object):
"""Base container class."""
@@ -286,3 +426,160 @@ class RepeatedCompositeFieldContainer(BaseContainer):
raise TypeError('Can only compare repeated composite fields against '
'other repeated composite fields.')
return self._values == other._values
+
+
+class ScalarMap(MutableMapping):
+
+ """Simple, type-checked, dict-like container for holding repeated scalars."""
+
+ # Disallows assignment to other attributes.
+ __slots__ = ['_key_checker', '_value_checker', '_values', '_message_listener']
+
+ def __init__(self, message_listener, key_checker, value_checker):
+ """
+ Args:
+ message_listener: A MessageListener implementation.
+ The ScalarMap will call this object's Modified() method when it
+ is modified.
+ key_checker: A type_checkers.ValueChecker instance to run on keys
+ inserted into this container.
+ value_checker: A type_checkers.ValueChecker instance to run on values
+ inserted into this container.
+ """
+ self._message_listener = message_listener
+ self._key_checker = key_checker
+ self._value_checker = value_checker
+ self._values = {}
+
+ def __getitem__(self, key):
+ try:
+ return self._values[key]
+ except KeyError:
+ key = self._key_checker.CheckValue(key)
+ val = self._value_checker.DefaultValue()
+ self._values[key] = val
+ return val
+
+ def __contains__(self, item):
+ return item in self._values
+
+ # We need to override this explicitly, because our defaultdict-like behavior
+ # will make the default implementation (from our base class) always insert
+ # the key.
+ def get(self, key, default=None):
+ if key in self:
+ return self[key]
+ else:
+ return default
+
+ def __setitem__(self, key, value):
+ checked_key = self._key_checker.CheckValue(key)
+ checked_value = self._value_checker.CheckValue(value)
+ self._values[checked_key] = checked_value
+ self._message_listener.Modified()
+
+ def __delitem__(self, key):
+ del self._values[key]
+ self._message_listener.Modified()
+
+ def __len__(self):
+ return len(self._values)
+
+ def __iter__(self):
+ return iter(self._values)
+
+ def MergeFrom(self, other):
+ self._values.update(other._values)
+ self._message_listener.Modified()
+
+ # This is defined in the abstract base, but we can do it much more cheaply.
+ def clear(self):
+ self._values.clear()
+ self._message_listener.Modified()
+
+
+class MessageMap(MutableMapping):
+
+ """Simple, type-checked, dict-like container for with submessage values."""
+
+ # Disallows assignment to other attributes.
+ __slots__ = ['_key_checker', '_values', '_message_listener',
+ '_message_descriptor']
+
+ def __init__(self, message_listener, message_descriptor, key_checker):
+ """
+ Args:
+ message_listener: A MessageListener implementation.
+ The MessageMap will call this object's Modified() method when it
+ is modified.
+ message_descriptor: The message descriptor for the map's value type,
+ used to create new entries in this container.
+ key_checker: A type_checkers.ValueChecker instance to run on keys
+ inserted into this container.
+ """
+ self._message_listener = message_listener
+ self._message_descriptor = message_descriptor
+ self._key_checker = key_checker
+ self._values = {}
+
+ def __getitem__(self, key):
+ try:
+ return self._values[key]
+ except KeyError:
+ key = self._key_checker.CheckValue(key)
+ new_element = self._message_descriptor._concrete_class()
+ new_element._SetListener(self._message_listener)
+ self._values[key] = new_element
+ self._message_listener.Modified()
+
+ return new_element
+
+ def get_or_create(self, key):
+ """get_or_create() is an alias for getitem (ie. map[key]).
+
+ Args:
+ key: The key to get or create in the map.
+
+ This is useful in cases where you want to be explicit that the call is
+ mutating the map. This can avoid lint errors for statements like this
+ that otherwise would appear to be pointless statements:
+
+ msg.my_map[key]
+ """
+ return self[key]
+
+ # We need to override this explicitly, because our defaultdict-like behavior
+ # will make the default implementation (from our base class) always insert
+ # the key.
+ def get(self, key, default=None):
+ if key in self:
+ return self[key]
+ else:
+ return default
+
+ def __contains__(self, item):
+ return item in self._values
+
+ def __setitem__(self, key, value):
+ raise ValueError('May not set values directly, call my_map[key].foo = 5')
+
+ def __delitem__(self, key):
+ del self._values[key]
+ self._message_listener.Modified()
+
+ def __len__(self):
+ return len(self._values)
+
+ def __iter__(self):
+ return iter(self._values)
+
+ def MergeFrom(self, other):
+ for key in other:
+ self[key].MergeFrom(other[key])
+ # self._message_listener.Modified() not required here, because
+ # mutations to submessages already propagate.
+
+ # This is defined in the abstract base, but we can do it much more cheaply.
+ def clear(self):
+ self._values.clear()
+ self._message_listener.Modified()
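
For illustration, a minimal standalone sketch (not from this patch) of the access semantics the new map containers implement, and why both ScalarMap and MessageMap override get(): __getitem__ inserts a default like defaultdict, while get() must leave the map untouched. DemoScalarMap is an invented stand-in, not a protobuf type:

    class DemoScalarMap(dict):
      """Dict subclass mimicking ScalarMap's defaultdict-like behavior."""

      def __missing__(self, key):
        self[key] = 0          # [] on a missing key inserts the type default
        return 0

      def get(self, key, default=None):
        # Overridden so that get() never inserts, unlike __getitem__.
        return self[key] if key in self else default

    m = DemoScalarMap()
    assert m.get(5) is None and 5 not in m   # get() did not insert
    assert m[5] == 0 and 5 in m              # [] inserted the default
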
diff --git a/python/google/protobuf/internal/decoder.py b/python/google/protobuf/internal/decoder.py
index 0f500606..3837eaea 100755
--- a/python/google/protobuf/internal/decoder.py
+++ b/python/google/protobuf/internal/decoder.py
@@ -733,6 +733,50 @@ def MessageSetItemDecoder(extensions_by_number):
return DecodeItem
# --------------------------------------------------------------------
+
+def MapDecoder(field_descriptor, new_default, is_message_map):
+ """Returns a decoder for a map field."""
+
+ key = field_descriptor
+ tag_bytes = encoder.TagBytes(field_descriptor.number,
+ wire_format.WIRETYPE_LENGTH_DELIMITED)
+ tag_len = len(tag_bytes)
+ local_DecodeVarint = _DecodeVarint
+ # Can't read _concrete_class yet; might not be initialized.
+ message_type = field_descriptor.message_type
+
+ def DecodeMap(buffer, pos, end, message, field_dict):
+ submsg = message_type._concrete_class()
+ value = field_dict.get(key)
+ if value is None:
+ value = field_dict.setdefault(key, new_default(message))
+ while 1:
+ # Read length.
+ (size, pos) = local_DecodeVarint(buffer, pos)
+ new_pos = pos + size
+ if new_pos > end:
+ raise _DecodeError('Truncated message.')
+ # Read sub-message.
+ submsg.Clear()
+ if submsg._InternalParse(buffer, pos, new_pos) != new_pos:
+ # The only reason _InternalParse would return early is if it
+ # encountered an end-group tag.
+ raise _DecodeError('Unexpected end-group tag.')
+
+ if is_message_map:
+ value[submsg.key].MergeFrom(submsg.value)
+ else:
+ value[submsg.key] = submsg.value
+
+ # Predict that the next tag is another copy of the same repeated field.
+ pos = new_pos + tag_len
+ if buffer[new_pos:pos] != tag_bytes or new_pos == end:
+ # Prediction failed. Return.
+ return new_pos
+
+ return DecodeMap
+
+# --------------------------------------------------------------------
# Optimization is not as heavy here because calls to SkipField() are rare,
# except for handling end-group tags.
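
For illustration, a hand-decoded map entry (not from this patch): on the wire a map field is just a repeated, length-delimited MapEntry submessage with the key as field 1 and the value as field 2, which is the loop DecodeMap runs. The field number and values below are invented, and single-byte varints are assumed for brevity:

    # map<int32, int32> on field number 1 holding the single entry {2: 3}:
    #   0x0a        tag: field 1, wire type 2 (length-delimited)
    #   0x04        entry payload is 4 bytes
    #   0x08 0x02   key   (field 1, varint) = 2
    #   0x10 0x03   value (field 2, varint) = 3
    buf = bytearray(b'\x0a\x04\x08\x02\x10\x03')

    entries = {}
    pos = 0
    while pos < len(buf):
      assert buf[pos] == 0x0A              # tag of the map field itself
      size = buf[pos + 1]                  # one-byte varint length
      payload = buf[pos + 2:pos + 2 + size]
      assert payload[0] == 0x08 and payload[2] == 0x10
      entries[payload[1]] = payload[3]     # key and value are one-byte varints
      pos += 2 + size

    assert entries == {2: 3}
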
diff --git a/python/google/protobuf/internal/descriptor_database_test.py b/python/google/protobuf/internal/descriptor_database_test.py
index 56fe14e9..8416e157 100644
--- a/python/google/protobuf/internal/descriptor_database_test.py
+++ b/python/google/protobuf/internal/descriptor_database_test.py
@@ -35,7 +35,6 @@
__author__ = 'matthewtoia@google.com (Matt Toia)'
import unittest
-
from google.protobuf import descriptor_pb2
from google.protobuf.internal import factory_test2_pb2
from google.protobuf import descriptor_database
diff --git a/python/google/protobuf/internal/descriptor_pool_test.py b/python/google/protobuf/internal/descriptor_pool_test.py
index 7d145f42..d159cc62 100644
--- a/python/google/protobuf/internal/descriptor_pool_test.py
+++ b/python/google/protobuf/internal/descriptor_pool_test.py
@@ -37,6 +37,7 @@ __author__ = 'matthewtoia@google.com (Matt Toia)'
import os
import unittest
+import unittest
from google.protobuf import unittest_pb2
from google.protobuf import descriptor_pb2
from google.protobuf.internal import api_implementation
@@ -226,6 +227,13 @@ class DescriptorPoolTest(unittest.TestCase):
db.Add(self.factory_test2_fd)
self.testFindMessageTypeByName()
+ def testAddSerializedFile(self):
+ db = descriptor_database.DescriptorDatabase()
+ self.pool = descriptor_pool.DescriptorPool(db)
+ self.pool.AddSerializedFile(self.factory_test1_fd.SerializeToString())
+ self.pool.AddSerializedFile(self.factory_test2_fd.SerializeToString())
+ self.testFindMessageTypeByName()
+
def testComplexNesting(self):
test1_desc = descriptor_pb2.FileDescriptorProto.FromString(
descriptor_pool_test1_pb2.DESCRIPTOR.serialized_pb)
@@ -510,6 +518,43 @@ class AddDescriptorTest(unittest.TestCase):
'protobuf_unittest.TestAllTypes')
+@unittest.skipIf(
+ api_implementation.Type() != 'cpp',
+ 'default_pool is only supported by the C++ implementation')
+class DefaultPoolTest(unittest.TestCase):
+
+ def testFindMethods(self):
+ # pylint: disable=g-import-not-at-top
+ from google.protobuf.pyext import _message
+ pool = _message.default_pool
+ self.assertIs(
+ pool.FindFileByName('google/protobuf/unittest.proto'),
+ unittest_pb2.DESCRIPTOR)
+ self.assertIs(
+ pool.FindMessageTypeByName('protobuf_unittest.TestAllTypes'),
+ unittest_pb2.TestAllTypes.DESCRIPTOR)
+ self.assertIs(
+ pool.FindFieldByName('protobuf_unittest.TestAllTypes.optional_int32'),
+ unittest_pb2.TestAllTypes.DESCRIPTOR.fields_by_name['optional_int32'])
+ self.assertIs(
+ pool.FindExtensionByName('protobuf_unittest.optional_int32_extension'),
+ unittest_pb2.DESCRIPTOR.extensions_by_name['optional_int32_extension'])
+ self.assertIs(
+ pool.FindEnumTypeByName('protobuf_unittest.ForeignEnum'),
+ unittest_pb2.ForeignEnum.DESCRIPTOR)
+ self.assertIs(
+ pool.FindOneofByName('protobuf_unittest.TestAllTypes.oneof_field'),
+ unittest_pb2.TestAllTypes.DESCRIPTOR.oneofs_by_name['oneof_field'])
+
+ def testAddFileDescriptor(self):
+ # pylint: disable=g-import-not-at-top
+ from google.protobuf.pyext import _message
+ pool = _message.default_pool
+ file_desc = descriptor_pb2.FileDescriptorProto(name='some/file.proto')
+ pool.Add(file_desc)
+ pool.AddSerializedFile(file_desc.SerializeToString())
+
+
TEST1_FILE = ProtoFile(
'google/protobuf/internal/descriptor_pool_test1.proto',
'google.protobuf.python.internal',
diff --git a/python/google/protobuf/internal/descriptor_test.py b/python/google/protobuf/internal/descriptor_test.py
index 335caee6..26866f3a 100755
--- a/python/google/protobuf/internal/descriptor_test.py
+++ b/python/google/protobuf/internal/descriptor_test.py
@@ -35,8 +35,8 @@
__author__ = 'robinson@google.com (Will Robinson)'
import sys
-import unittest
+import unittest
from google.protobuf import unittest_custom_options_pb2
from google.protobuf import unittest_import_pb2
from google.protobuf import unittest_pb2
diff --git a/python/google/protobuf/internal/encoder.py b/python/google/protobuf/internal/encoder.py
index 38a5138a..752f4eab 100755
--- a/python/google/protobuf/internal/encoder.py
+++ b/python/google/protobuf/internal/encoder.py
@@ -314,7 +314,7 @@ def MessageSizer(field_number, is_repeated, is_packed):
# --------------------------------------------------------------------
-# MessageSet is special.
+# MessageSet is special: it needs custom logic to compute its size properly.
def MessageSetItemSizer(field_number):
@@ -339,6 +339,32 @@ def MessageSetItemSizer(field_number):
return FieldSize
+# --------------------------------------------------------------------
+# Map is special: it needs custom logic to compute its size properly.
+
+
+def MapSizer(field_descriptor):
+ """Returns a sizer for a map field."""
+
+ # Can't look at field_descriptor.message_type._concrete_class because it may
+ # not have been initialized yet.
+ message_type = field_descriptor.message_type
+ message_sizer = MessageSizer(field_descriptor.number, False, False)
+
+ def FieldSize(map_value):
+ total = 0
+ for key in map_value:
+ value = map_value[key]
+ # It's wasteful to create the messages and throw them away one second
+ # later since we'll do the same for the actual encode. But there's not an
+ # obvious way to avoid this within the current design without tons of code
+ # duplication.
+ entry_msg = message_type._concrete_class(key=key, value=value)
+ total += message_sizer(entry_msg)
+ return total
+
+ return FieldSize
+
# ====================================================================
# Encoders!
@@ -786,3 +812,30 @@ def MessageSetItemEncoder(field_number):
return write(end_bytes)
return EncodeField
+
+
+# --------------------------------------------------------------------
+# As before, Map is special.
+
+
+def MapEncoder(field_descriptor):
+ """Encoder for extensions of MessageSet.
+
+ Maps always have a wire format like this:
+ message MapEntry {
+ key_type key = 1;
+ value_type value = 2;
+ }
+ repeated MapEntry map = N;
+ """
+ # Can't look at field_descriptor.message_type._concrete_class because it may
+ # not have been initialized yet.
+ message_type = field_descriptor.message_type
+ encode_message = MessageEncoder(field_descriptor.number, False, False)
+
+ def EncodeField(write, value):
+ for key in value:
+ entry_msg = message_type._concrete_class(key=key, value=value[key])
+ encode_message(write, entry_msg)
+
+ return EncodeField
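
For illustration, the per-entry byte cost that MapSizer effectively sums, worked out for a map<int32, int32> field (not from this patch; non-negative values and a field number below 16 are assumed so every tag and varint is one byte):

    def varint_size(value):
      # Size in bytes of a non-negative varint.
      size = 1
      while value >= 0x80:
        value >>= 7
        size += 1
      return size

    def int32_map_entry_size(key, value):
      # key is field 1 and value is field 2 inside the MapEntry submessage,
      # so each costs one tag byte plus its varint.
      payload = (1 + varint_size(key)) + (1 + varint_size(value))
      # One tag byte for the map field itself, plus the entry length prefix.
      return 1 + varint_size(payload) + payload

    # Matches the 6-byte hand-encoded entry b'\x0a\x04\x08\x02\x10\x03' above.
    assert int32_map_entry_size(2, 3) == 6
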
diff --git a/python/google/protobuf/internal/generator_test.py b/python/google/protobuf/internal/generator_test.py
index ccc5860b..5c07cbe6 100755
--- a/python/google/protobuf/internal/generator_test.py
+++ b/python/google/protobuf/internal/generator_test.py
@@ -42,7 +42,6 @@ further ensures that we can use Python protocol message objects as we expect.
__author__ = 'robinson@google.com (Will Robinson)'
import unittest
-
from google.protobuf.internal import test_bad_identifiers_pb2
from google.protobuf import unittest_custom_options_pb2
from google.protobuf import unittest_import_pb2
diff --git a/python/google/protobuf/internal/message_factory_test.py b/python/google/protobuf/internal/message_factory_test.py
index 626c3fc9..b8694f96 100644
--- a/python/google/protobuf/internal/message_factory_test.py
+++ b/python/google/protobuf/internal/message_factory_test.py
@@ -35,7 +35,6 @@
__author__ = 'matthewtoia@google.com (Matt Toia)'
import unittest
-
from google.protobuf import descriptor_pb2
from google.protobuf.internal import factory_test1_pb2
from google.protobuf.internal import factory_test2_pb2
diff --git a/python/google/protobuf/internal/message_test.py b/python/google/protobuf/internal/message_test.py
index 4ecaa1c7..320ff0d2 100755
--- a/python/google/protobuf/internal/message_test.py
+++ b/python/google/protobuf/internal/message_test.py
@@ -50,7 +50,9 @@ import pickle
import sys
import unittest
+import unittest
from google.protobuf.internal import _parameterized
+from google.protobuf import map_unittest_pb2
from google.protobuf import unittest_pb2
from google.protobuf import unittest_proto3_arena_pb2
from google.protobuf.internal import api_implementation
@@ -125,10 +127,17 @@ class MessageTest(unittest.TestCase):
self.assertEqual(unpickled_message, golden_message)
def testPositiveInfinity(self, message_module):
- golden_data = (b'\x5D\x00\x00\x80\x7F'
- b'\x61\x00\x00\x00\x00\x00\x00\xF0\x7F'
- b'\xCD\x02\x00\x00\x80\x7F'
- b'\xD1\x02\x00\x00\x00\x00\x00\x00\xF0\x7F')
+ if message_module is unittest_pb2:
+ golden_data = (b'\x5D\x00\x00\x80\x7F'
+ b'\x61\x00\x00\x00\x00\x00\x00\xF0\x7F'
+ b'\xCD\x02\x00\x00\x80\x7F'
+ b'\xD1\x02\x00\x00\x00\x00\x00\x00\xF0\x7F')
+ else:
+ golden_data = (b'\x5D\x00\x00\x80\x7F'
+ b'\x61\x00\x00\x00\x00\x00\x00\xF0\x7F'
+ b'\xCA\x02\x04\x00\x00\x80\x7F'
+ b'\xD2\x02\x08\x00\x00\x00\x00\x00\x00\xF0\x7F')
+
golden_message = message_module.TestAllTypes()
golden_message.ParseFromString(golden_data)
self.assertTrue(IsPosInf(golden_message.optional_float))
@@ -138,10 +147,17 @@ class MessageTest(unittest.TestCase):
self.assertEqual(golden_data, golden_message.SerializeToString())
def testNegativeInfinity(self, message_module):
- golden_data = (b'\x5D\x00\x00\x80\xFF'
- b'\x61\x00\x00\x00\x00\x00\x00\xF0\xFF'
- b'\xCD\x02\x00\x00\x80\xFF'
- b'\xD1\x02\x00\x00\x00\x00\x00\x00\xF0\xFF')
+ if message_module is unittest_pb2:
+ golden_data = (b'\x5D\x00\x00\x80\xFF'
+ b'\x61\x00\x00\x00\x00\x00\x00\xF0\xFF'
+ b'\xCD\x02\x00\x00\x80\xFF'
+ b'\xD1\x02\x00\x00\x00\x00\x00\x00\xF0\xFF')
+ else:
+ golden_data = (b'\x5D\x00\x00\x80\xFF'
+ b'\x61\x00\x00\x00\x00\x00\x00\xF0\xFF'
+ b'\xCA\x02\x04\x00\x00\x80\xFF'
+ b'\xD2\x02\x08\x00\x00\x00\x00\x00\x00\xF0\xFF')
+
golden_message = message_module.TestAllTypes()
golden_message.ParseFromString(golden_data)
self.assertTrue(IsNegInf(golden_message.optional_float))
@@ -1034,64 +1050,132 @@ class Proto2Test(unittest.TestCase):
self.assertEqual(len(parsing_merge.Extensions[
unittest_pb2.TestParsingMerge.repeated_ext]), 3)
+ def testPythonicInit(self):
+ message = unittest_pb2.TestAllTypes(
+ optional_int32=100,
+ optional_fixed32=200,
+ optional_float=300.5,
+ optional_bytes=b'x',
+ optionalgroup={'a': 400},
+ optional_nested_message={'bb': 500},
+ optional_nested_enum='BAZ',
+ repeatedgroup=[{'a': 600},
+ {'a': 700}],
+ repeated_nested_enum=['FOO', unittest_pb2.TestAllTypes.BAR],
+ default_int32=800,
+ oneof_string='y')
+ self.assertTrue(isinstance(message, unittest_pb2.TestAllTypes))
+ self.assertEqual(100, message.optional_int32)
+ self.assertEqual(200, message.optional_fixed32)
+ self.assertEqual(300.5, message.optional_float)
+ self.assertEqual(b'x', message.optional_bytes)
+ self.assertEqual(400, message.optionalgroup.a)
+ self.assertTrue(isinstance(message.optional_nested_message,
+ unittest_pb2.TestAllTypes.NestedMessage))
+ self.assertEqual(500, message.optional_nested_message.bb)
+ self.assertEqual(unittest_pb2.TestAllTypes.BAZ,
+ message.optional_nested_enum)
+ self.assertEqual(2, len(message.repeatedgroup))
+ self.assertEqual(600, message.repeatedgroup[0].a)
+ self.assertEqual(700, message.repeatedgroup[1].a)
+ self.assertEqual(2, len(message.repeated_nested_enum))
+ self.assertEqual(unittest_pb2.TestAllTypes.FOO,
+ message.repeated_nested_enum[0])
+ self.assertEqual(unittest_pb2.TestAllTypes.BAR,
+ message.repeated_nested_enum[1])
+ self.assertEqual(800, message.default_int32)
+ self.assertEqual('y', message.oneof_string)
+ self.assertFalse(message.HasField('optional_int64'))
+ self.assertEqual(0, len(message.repeated_float))
+ self.assertEqual(42, message.default_int64)
+
+ message = unittest_pb2.TestAllTypes(optional_nested_enum=u'BAZ')
+ self.assertEqual(unittest_pb2.TestAllTypes.BAZ,
+ message.optional_nested_enum)
+
+ with self.assertRaises(ValueError):
+ unittest_pb2.TestAllTypes(
+ optional_nested_message={'INVALID_NESTED_FIELD': 17})
+
+ with self.assertRaises(TypeError):
+ unittest_pb2.TestAllTypes(
+ optional_nested_message={'bb': 'INVALID_VALUE_TYPE'})
+
+ with self.assertRaises(ValueError):
+ unittest_pb2.TestAllTypes(optional_nested_enum='INVALID_LABEL')
+
+ with self.assertRaises(ValueError):
+ unittest_pb2.TestAllTypes(repeated_nested_enum='FOO')
+
# Class to test proto3-only features/behavior (updated field presence & enums)
class Proto3Test(unittest.TestCase):
+ # Utility method for comparing equality with a map.
+ def assertMapIterEquals(self, map_iter, dict_value):
+ # Avoid mutating caller's copy.
+ dict_value = dict(dict_value)
+
+ for k, v in map_iter:
+ self.assertEqual(v, dict_value[k])
+ del dict_value[k]
+
+ self.assertEqual({}, dict_value)
+
def testFieldPresence(self):
message = unittest_proto3_arena_pb2.TestAllTypes()
# We can't test presence of non-repeated, non-submessage fields.
with self.assertRaises(ValueError):
- message.HasField("optional_int32")
+ message.HasField('optional_int32')
with self.assertRaises(ValueError):
- message.HasField("optional_float")
+ message.HasField('optional_float')
with self.assertRaises(ValueError):
- message.HasField("optional_string")
+ message.HasField('optional_string')
with self.assertRaises(ValueError):
- message.HasField("optional_bool")
+ message.HasField('optional_bool')
# But we can still test presence of submessage fields.
- self.assertFalse(message.HasField("optional_nested_message"))
+ self.assertFalse(message.HasField('optional_nested_message'))
# As with proto2, we can't test presence of fields that don't exist, or
# repeated fields.
with self.assertRaises(ValueError):
- message.HasField("field_doesnt_exist")
+ message.HasField('field_doesnt_exist')
with self.assertRaises(ValueError):
- message.HasField("repeated_int32")
+ message.HasField('repeated_int32')
with self.assertRaises(ValueError):
- message.HasField("repeated_nested_message")
+ message.HasField('repeated_nested_message')
# Fields should default to their type-specific default.
self.assertEqual(0, message.optional_int32)
self.assertEqual(0, message.optional_float)
- self.assertEqual("", message.optional_string)
+ self.assertEqual('', message.optional_string)
self.assertEqual(False, message.optional_bool)
self.assertEqual(0, message.optional_nested_message.bb)
# Setting a submessage should still return proper presence information.
message.optional_nested_message.bb = 0
- self.assertTrue(message.HasField("optional_nested_message"))
+ self.assertTrue(message.HasField('optional_nested_message'))
# Set the fields to non-default values.
message.optional_int32 = 5
message.optional_float = 1.1
- message.optional_string = "abc"
+ message.optional_string = 'abc'
message.optional_bool = True
message.optional_nested_message.bb = 15
# Clearing the fields unsets them and resets their value to default.
- message.ClearField("optional_int32")
- message.ClearField("optional_float")
- message.ClearField("optional_string")
- message.ClearField("optional_bool")
- message.ClearField("optional_nested_message")
+ message.ClearField('optional_int32')
+ message.ClearField('optional_float')
+ message.ClearField('optional_string')
+ message.ClearField('optional_bool')
+ message.ClearField('optional_nested_message')
self.assertEqual(0, message.optional_int32)
self.assertEqual(0, message.optional_float)
- self.assertEqual("", message.optional_string)
+ self.assertEqual('', message.optional_string)
self.assertEqual(False, message.optional_bool)
self.assertEqual(0, message.optional_nested_message.bb)
@@ -1113,6 +1197,393 @@ class Proto3Test(unittest.TestCase):
self.assertEqual(1234567, m2.optional_nested_enum)
self.assertEqual(7654321, m2.repeated_nested_enum[0])
+ # Map isn't really a proto3-only feature. But there is no proto2 equivalent
+ # of google/protobuf/map_unittest.proto right now, so it's not easy to
+ # test both with the same test like we do for the other proto2/proto3 tests.
+ # (google/protobuf/map_protobuf_unittest.proto is very different in the set
+ # of messages and fields it contains).
+ def testScalarMapDefaults(self):
+ msg = map_unittest_pb2.TestMap()
+
+ # Scalars start out unset.
+ self.assertFalse(-123 in msg.map_int32_int32)
+ self.assertFalse(-2**33 in msg.map_int64_int64)
+ self.assertFalse(123 in msg.map_uint32_uint32)
+ self.assertFalse(2**33 in msg.map_uint64_uint64)
+ self.assertFalse('abc' in msg.map_string_string)
+ self.assertFalse(888 in msg.map_int32_enum)
+
+ # Accessing an unset key returns the default.
+ self.assertEqual(0, msg.map_int32_int32[-123])
+ self.assertEqual(0, msg.map_int64_int64[-2**33])
+ self.assertEqual(0, msg.map_uint32_uint32[123])
+ self.assertEqual(0, msg.map_uint64_uint64[2**33])
+ self.assertEqual('', msg.map_string_string['abc'])
+ self.assertEqual(0, msg.map_int32_enum[888])
+
+ # It also sets the value in the map
+ self.assertTrue(-123 in msg.map_int32_int32)
+ self.assertTrue(-2**33 in msg.map_int64_int64)
+ self.assertTrue(123 in msg.map_uint32_uint32)
+ self.assertTrue(2**33 in msg.map_uint64_uint64)
+ self.assertTrue('abc' in msg.map_string_string)
+ self.assertTrue(888 in msg.map_int32_enum)
+
+ self.assertTrue(isinstance(msg.map_string_string['abc'], unicode))
+
+ # Accessing an unset key still throws TypeError if the type of the key
+ # is incorrect.
+ with self.assertRaises(TypeError):
+ msg.map_string_string[123]
+
+ self.assertFalse(123 in msg.map_string_string)
+
+ def testMapGet(self):
+ # Need to test that get() properly returns the default, even though the dict
+ # has defaultdict-like semantics.
+ msg = map_unittest_pb2.TestMap()
+
+ self.assertIsNone(msg.map_int32_int32.get(5))
+ self.assertEquals(10, msg.map_int32_int32.get(5, 10))
+ self.assertIsNone(msg.map_int32_int32.get(5))
+
+ msg.map_int32_int32[5] = 15
+ self.assertEquals(15, msg.map_int32_int32.get(5))
+
+ self.assertIsNone(msg.map_int32_foreign_message.get(5))
+ self.assertEquals(10, msg.map_int32_foreign_message.get(5, 10))
+
+ submsg = msg.map_int32_foreign_message[5]
+ self.assertIs(submsg, msg.map_int32_foreign_message.get(5))
+
+ def testScalarMap(self):
+ msg = map_unittest_pb2.TestMap()
+
+ self.assertEqual(0, len(msg.map_int32_int32))
+ self.assertFalse(5 in msg.map_int32_int32)
+
+ msg.map_int32_int32[-123] = -456
+ msg.map_int64_int64[-2**33] = -2**34
+ msg.map_uint32_uint32[123] = 456
+ msg.map_uint64_uint64[2**33] = 2**34
+ msg.map_string_string['abc'] = '123'
+ msg.map_int32_enum[888] = 2
+
+ self.assertEqual([], msg.FindInitializationErrors())
+
+ self.assertEqual(1, len(msg.map_string_string))
+
+ # Bad key.
+ with self.assertRaises(TypeError):
+ msg.map_string_string[123] = '123'
+
+ # Verify that trying to assign a bad key doesn't actually add a member to
+ # the map.
+ self.assertEqual(1, len(msg.map_string_string))
+
+ # Bad value.
+ with self.assertRaises(TypeError):
+ msg.map_string_string['123'] = 123
+
+ serialized = msg.SerializeToString()
+ msg2 = map_unittest_pb2.TestMap()
+ msg2.ParseFromString(serialized)
+
+ # Bad key.
+ with self.assertRaises(TypeError):
+ msg2.map_string_string[123] = '123'
+
+ # Bad value.
+ with self.assertRaises(TypeError):
+ msg2.map_string_string['123'] = 123
+
+ self.assertEqual(-456, msg2.map_int32_int32[-123])
+ self.assertEqual(-2**34, msg2.map_int64_int64[-2**33])
+ self.assertEqual(456, msg2.map_uint32_uint32[123])
+ self.assertEqual(2**34, msg2.map_uint64_uint64[2**33])
+ self.assertEqual('123', msg2.map_string_string['abc'])
+ self.assertEqual(2, msg2.map_int32_enum[888])
+
+ def testStringUnicodeConversionInMap(self):
+ msg = map_unittest_pb2.TestMap()
+
+ unicode_obj = u'\u1234'
+ bytes_obj = unicode_obj.encode('utf8')
+
+ msg.map_string_string[bytes_obj] = bytes_obj
+
+ (key, value) = msg.map_string_string.items()[0]
+
+ self.assertEqual(key, unicode_obj)
+ self.assertEqual(value, unicode_obj)
+
+ self.assertTrue(isinstance(key, unicode))
+ self.assertTrue(isinstance(value, unicode))
+
+ def testMessageMap(self):
+ msg = map_unittest_pb2.TestMap()
+
+ self.assertEqual(0, len(msg.map_int32_foreign_message))
+ self.assertFalse(5 in msg.map_int32_foreign_message)
+
+ msg.map_int32_foreign_message[123]
+ # get_or_create() is an alias for getitem.
+ msg.map_int32_foreign_message.get_or_create(-456)
+
+ self.assertEqual(2, len(msg.map_int32_foreign_message))
+ self.assertIn(123, msg.map_int32_foreign_message)
+ self.assertIn(-456, msg.map_int32_foreign_message)
+ self.assertEqual(2, len(msg.map_int32_foreign_message))
+
+ # Bad key.
+ with self.assertRaises(TypeError):
+ msg.map_int32_foreign_message['123']
+
+ # Can't assign directly to submessage.
+ with self.assertRaises(ValueError):
+ msg.map_int32_foreign_message[999] = msg.map_int32_foreign_message[123]
+
+ # Verify that trying to assign a bad key doesn't actually add a member to
+ # the map.
+ self.assertEqual(2, len(msg.map_int32_foreign_message))
+
+ serialized = msg.SerializeToString()
+ msg2 = map_unittest_pb2.TestMap()
+ msg2.ParseFromString(serialized)
+
+ self.assertEqual(2, len(msg2.map_int32_foreign_message))
+ self.assertIn(123, msg2.map_int32_foreign_message)
+ self.assertIn(-456, msg2.map_int32_foreign_message)
+ self.assertEqual(2, len(msg2.map_int32_foreign_message))
+
+ def testMergeFrom(self):
+ msg = map_unittest_pb2.TestMap()
+ msg.map_int32_int32[12] = 34
+ msg.map_int32_int32[56] = 78
+ msg.map_int64_int64[22] = 33
+ msg.map_int32_foreign_message[111].c = 5
+ msg.map_int32_foreign_message[222].c = 10
+
+ msg2 = map_unittest_pb2.TestMap()
+ msg2.map_int32_int32[12] = 55
+ msg2.map_int64_int64[88] = 99
+ msg2.map_int32_foreign_message[222].c = 15
+
+ msg2.MergeFrom(msg)
+
+ self.assertEqual(34, msg2.map_int32_int32[12])
+ self.assertEqual(78, msg2.map_int32_int32[56])
+ self.assertEqual(33, msg2.map_int64_int64[22])
+ self.assertEqual(99, msg2.map_int64_int64[88])
+ self.assertEqual(5, msg2.map_int32_foreign_message[111].c)
+ self.assertEqual(10, msg2.map_int32_foreign_message[222].c)
+
+ # Verify that there is only one entry per key, even though the MergeFrom
+ # may have internally created multiple entries for a single key in the
+ # list representation.
+ as_dict = {}
+ for key in msg2.map_int32_foreign_message:
+ self.assertFalse(key in as_dict)
+ as_dict[key] = msg2.map_int32_foreign_message[key].c
+
+ self.assertEqual({111: 5, 222: 10}, as_dict)
+
+ # Special case: test that delete of item really removes the item, even if
+ # there might have physically been duplicate keys due to the previous merge.
+ # This is only a special case for the C++ implementation which stores the
+ # map as an array.
+ del msg2.map_int32_int32[12]
+ self.assertFalse(12 in msg2.map_int32_int32)
+
+ del msg2.map_int32_foreign_message[222]
+ self.assertFalse(222 in msg2.map_int32_foreign_message)
+
+ def testIntegerMapWithLongs(self):
+ msg = map_unittest_pb2.TestMap()
+ msg.map_int32_int32[long(-123)] = long(-456)
+ msg.map_int64_int64[long(-2**33)] = long(-2**34)
+ msg.map_uint32_uint32[long(123)] = long(456)
+ msg.map_uint64_uint64[long(2**33)] = long(2**34)
+
+ serialized = msg.SerializeToString()
+ msg2 = map_unittest_pb2.TestMap()
+ msg2.ParseFromString(serialized)
+
+ self.assertEqual(-456, msg2.map_int32_int32[-123])
+ self.assertEqual(-2**34, msg2.map_int64_int64[-2**33])
+ self.assertEqual(456, msg2.map_uint32_uint32[123])
+ self.assertEqual(2**34, msg2.map_uint64_uint64[2**33])
+
+ def testMapAssignmentCausesPresence(self):
+ msg = map_unittest_pb2.TestMapSubmessage()
+ msg.test_map.map_int32_int32[123] = 456
+
+ serialized = msg.SerializeToString()
+ msg2 = map_unittest_pb2.TestMapSubmessage()
+ msg2.ParseFromString(serialized)
+
+ self.assertEqual(msg, msg2)
+
+ # Now test that various mutations of the map properly invalidate the
+ # cached size of the submessage.
+ msg.test_map.map_int32_int32[888] = 999
+ serialized = msg.SerializeToString()
+ msg2.ParseFromString(serialized)
+ self.assertEqual(msg, msg2)
+
+ msg.test_map.map_int32_int32.clear()
+ serialized = msg.SerializeToString()
+ msg2.ParseFromString(serialized)
+ self.assertEqual(msg, msg2)
+
+ def testMapAssignmentCausesPresenceForSubmessages(self):
+ msg = map_unittest_pb2.TestMapSubmessage()
+ msg.test_map.map_int32_foreign_message[123].c = 5
+
+ serialized = msg.SerializeToString()
+ msg2 = map_unittest_pb2.TestMapSubmessage()
+ msg2.ParseFromString(serialized)
+
+ self.assertEqual(msg, msg2)
+
+ # Now test that various mutations of the map properly invalidate the
+ # cached size of the submessage.
+ msg.test_map.map_int32_foreign_message[888].c = 7
+ serialized = msg.SerializeToString()
+ msg2.ParseFromString(serialized)
+ self.assertEqual(msg, msg2)
+
+ msg.test_map.map_int32_foreign_message[888].MergeFrom(
+ msg.test_map.map_int32_foreign_message[123])
+ serialized = msg.SerializeToString()
+ msg2.ParseFromString(serialized)
+ self.assertEqual(msg, msg2)
+
+ msg.test_map.map_int32_foreign_message.clear()
+ serialized = msg.SerializeToString()
+ msg2.ParseFromString(serialized)
+ self.assertEqual(msg, msg2)
+
+ def testModifyMapWhileIterating(self):
+ msg = map_unittest_pb2.TestMap()
+
+ string_string_iter = iter(msg.map_string_string)
+ int32_foreign_iter = iter(msg.map_int32_foreign_message)
+
+ msg.map_string_string['abc'] = '123'
+ msg.map_int32_foreign_message[5].c = 5
+
+ with self.assertRaises(RuntimeError):
+ for key in string_string_iter:
+ pass
+
+ with self.assertRaises(RuntimeError):
+ for key in int32_foreign_iter:
+ pass
+
+ def testSubmessageMap(self):
+ msg = map_unittest_pb2.TestMap()
+
+ submsg = msg.map_int32_foreign_message[111]
+ self.assertIs(submsg, msg.map_int32_foreign_message[111])
+ self.assertTrue(isinstance(submsg, unittest_pb2.ForeignMessage))
+
+ submsg.c = 5
+
+ serialized = msg.SerializeToString()
+ msg2 = map_unittest_pb2.TestMap()
+ msg2.ParseFromString(serialized)
+
+ self.assertEqual(5, msg2.map_int32_foreign_message[111].c)
+
+ # Doesn't allow direct submessage assignment.
+ with self.assertRaises(ValueError):
+ msg.map_int32_foreign_message[88] = unittest_pb2.ForeignMessage()
+
+ def testMapIteration(self):
+ msg = map_unittest_pb2.TestMap()
+
+ for k, v in msg.map_int32_int32.iteritems():
+ # Should not be reached.
+ self.assertTrue(False)
+
+ msg.map_int32_int32[2] = 4
+ msg.map_int32_int32[3] = 6
+ msg.map_int32_int32[4] = 8
+ self.assertEqual(3, len(msg.map_int32_int32))
+
+ matching_dict = {2: 4, 3: 6, 4: 8}
+ self.assertMapIterEquals(msg.map_int32_int32.iteritems(), matching_dict)
+
+ def testMapIterationClearMessage(self):
+ # Iterator needs to work even if message and map are deleted.
+ msg = map_unittest_pb2.TestMap()
+
+ msg.map_int32_int32[2] = 4
+ msg.map_int32_int32[3] = 6
+ msg.map_int32_int32[4] = 8
+
+ it = msg.map_int32_int32.iteritems()
+ del msg
+
+ matching_dict = {2: 4, 3: 6, 4: 8}
+ self.assertMapIterEquals(it, matching_dict)
+
+ def testMapConstruction(self):
+ msg = map_unittest_pb2.TestMap(map_int32_int32={1: 2, 3: 4})
+ self.assertEqual(2, msg.map_int32_int32[1])
+ self.assertEqual(4, msg.map_int32_int32[3])
+
+ msg = map_unittest_pb2.TestMap(
+ map_int32_foreign_message={3: unittest_pb2.ForeignMessage(c=5)})
+ self.assertEqual(5, msg.map_int32_foreign_message[3].c)
+
+ def testMapValidAfterFieldCleared(self):
+ # Map needs to work even if field is cleared.
+ # For the C++ implementation this tests the correctness of
+ # ScalarMapContainer::Release()
+ msg = map_unittest_pb2.TestMap()
+ map = msg.map_int32_int32
+
+ map[2] = 4
+ map[3] = 6
+ map[4] = 8
+
+ msg.ClearField('map_int32_int32')
+ matching_dict = {2: 4, 3: 6, 4: 8}
+ self.assertMapIterEquals(map.iteritems(), matching_dict)
+
+ def testMapIterValidAfterFieldCleared(self):
+ # Map iterator needs to work even if field is cleared.
+ # For the C++ implementation this tests the correctness of
+ # ScalarMapContainer::Release()
+ msg = map_unittest_pb2.TestMap()
+
+ msg.map_int32_int32[2] = 4
+ msg.map_int32_int32[3] = 6
+ msg.map_int32_int32[4] = 8
+
+ it = msg.map_int32_int32.iteritems()
+
+ msg.ClearField('map_int32_int32')
+ matching_dict = {2: 4, 3: 6, 4: 8}
+ self.assertMapIterEquals(it, matching_dict)
+
+ def testMapDelete(self):
+ msg = map_unittest_pb2.TestMap()
+
+ self.assertEqual(0, len(msg.map_int32_int32))
+
+ msg.map_int32_int32[4] = 6
+ self.assertEqual(1, len(msg.map_int32_int32))
+
+ with self.assertRaises(KeyError):
+ del msg.map_int32_int32[88]
+
+ del msg.map_int32_int32[4]
+ self.assertEqual(0, len(msg.map_int32_int32))
+
+
class ValidTypeNamesTest(unittest.TestCase):
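
For illustration, a minimal sketch (not from this patch) of the constructor conveniences these new tests cover: nested messages may be passed as dicts and enum values as string labels. descriptor_pb2 is used here only because it ships with every build; the field contents are invented:

    from google.protobuf import descriptor_pb2

    field = descriptor_pb2.FieldDescriptorProto(
        name='id',
        number=1,
        type='TYPE_INT32',           # enum given as its string label
        label='LABEL_OPTIONAL',
        options={'packed': False})   # nested message given as a dict
    assert field.type == descriptor_pb2.FieldDescriptorProto.TYPE_INT32
    assert field.options.HasField('packed') and not field.options.packed
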
diff --git a/python/google/protobuf/internal/proto_builder_test.py b/python/google/protobuf/internal/proto_builder_test.py
index b1e57f35..edaf3fa3 100644
--- a/python/google/protobuf/internal/proto_builder_test.py
+++ b/python/google/protobuf/internal/proto_builder_test.py
@@ -32,6 +32,7 @@
"""Tests for google.protobuf.proto_builder."""
+import collections
import unittest
from google.protobuf import descriptor_pb2
@@ -43,10 +44,11 @@ from google.protobuf import text_format
class ProtoBuilderTest(unittest.TestCase):
def setUp(self):
- self._fields = {
- 'foo': descriptor_pb2.FieldDescriptorProto.TYPE_INT64,
- 'bar': descriptor_pb2.FieldDescriptorProto.TYPE_STRING,
- }
+ self.ordered_fields = collections.OrderedDict([
+ ('foo', descriptor_pb2.FieldDescriptorProto.TYPE_INT64),
+ ('bar', descriptor_pb2.FieldDescriptorProto.TYPE_STRING),
+ ])
+ self._fields = dict(self.ordered_fields)
def testMakeSimpleProtoClass(self):
"""Test that we can create a proto class."""
@@ -59,6 +61,17 @@ class ProtoBuilderTest(unittest.TestCase):
self.assertMultiLineEqual(
'bar: "asdf"\nfoo: 12345\n', text_format.MessageToString(proto))
+ def testOrderedFields(self):
+ """Test that the field order is maintained when given an OrderedDict."""
+ proto_cls = proto_builder.MakeSimpleProtoClass(
+ self.ordered_fields,
+ full_name='net.proto2.python.public.proto_builder_test.OrderedTest')
+ proto = proto_cls()
+ proto.foo = 12345
+ proto.bar = 'asdf'
+ self.assertMultiLineEqual(
+ 'foo: 12345\nbar: "asdf"\n', text_format.MessageToString(proto))
+
def testMakeSameProtoClassTwice(self):
"""Test that the DescriptorPool is used."""
pool = descriptor_pool.DescriptorPool()
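
For illustration, a minimal sketch (not from this patch) of the ordering behavior the new testOrderedFields covers: passing a collections.OrderedDict to proto_builder.MakeSimpleProtoClass preserves field declaration order. The full_name below is invented:

    import collections

    from google.protobuf import descriptor_pb2
    from google.protobuf import proto_builder

    fields = collections.OrderedDict([
        ('foo', descriptor_pb2.FieldDescriptorProto.TYPE_INT64),
        ('bar', descriptor_pb2.FieldDescriptorProto.TYPE_STRING),
    ])
    proto_cls = proto_builder.MakeSimpleProtoClass(
        fields, full_name='example.OrderedDemo')
    msg = proto_cls(foo=1, bar='x')
    # Fields keep the OrderedDict's order rather than sorted order.
    assert [f.name for f in msg.DESCRIPTOR.fields] == ['foo', 'bar']
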
diff --git a/python/google/protobuf/internal/python_message.py b/python/google/protobuf/internal/python_message.py
index 54f584ae..ca9f7675 100755
--- a/python/google/protobuf/internal/python_message.py
+++ b/python/google/protobuf/internal/python_message.py
@@ -61,9 +61,11 @@ if sys.version_info[0] < 3:
except ImportError:
from StringIO import StringIO as BytesIO
import copy_reg as copyreg
+ _basestring = basestring
else:
from io import BytesIO
import copyreg
+ _basestring = str
import struct
import weakref
@@ -77,6 +79,7 @@ from google.protobuf.internal import type_checkers
from google.protobuf.internal import wire_format
from google.protobuf import descriptor as descriptor_mod
from google.protobuf import message as message_mod
+from google.protobuf import symbol_database
from google.protobuf import text_format
_FieldDescriptor = descriptor_mod.FieldDescriptor
@@ -101,6 +104,7 @@ def InitMessage(descriptor, cls):
for field in descriptor.fields:
_AttachFieldHelpers(cls, field)
+ descriptor._concrete_class = cls # pylint: disable=protected-access
_AddEnumValues(descriptor, cls)
_AddInitMethod(descriptor, cls)
_AddPropertiesForFields(descriptor, cls)
@@ -198,12 +202,37 @@ def _IsMessageSetExtension(field):
field.label == _FieldDescriptor.LABEL_OPTIONAL)
+def _IsMapField(field):
+ return (field.type == _FieldDescriptor.TYPE_MESSAGE and
+ field.message_type.has_options and
+ field.message_type.GetOptions().map_entry)
+
+
+def _IsMessageMapField(field):
+ value_type = field.message_type.fields_by_name["value"]
+ return value_type.cpp_type == _FieldDescriptor.CPPTYPE_MESSAGE
+
+
def _AttachFieldHelpers(cls, field_descriptor):
is_repeated = (field_descriptor.label == _FieldDescriptor.LABEL_REPEATED)
- is_packed = (field_descriptor.has_options and
- field_descriptor.GetOptions().packed)
-
- if _IsMessageSetExtension(field_descriptor):
+ is_packable = (is_repeated and
+ wire_format.IsTypePackable(field_descriptor.type))
+ if not is_packable:
+ is_packed = False
+ elif field_descriptor.containing_type.syntax == "proto2":
+ is_packed = (field_descriptor.has_options and
+ field_descriptor.GetOptions().packed)
+ else:
+ has_packed_false = (field_descriptor.has_options and
+ field_descriptor.GetOptions().HasField("packed") and
+ field_descriptor.GetOptions().packed == False)
+ is_packed = not has_packed_false
+ is_map_entry = _IsMapField(field_descriptor)
+
+ if is_map_entry:
+ field_encoder = encoder.MapEncoder(field_descriptor)
+ sizer = encoder.MapSizer(field_descriptor)
+ elif _IsMessageSetExtension(field_descriptor):
field_encoder = encoder.MessageSetItemEncoder(field_descriptor.number)
sizer = encoder.MessageSetItemSizer(field_descriptor.number)
else:
@@ -228,9 +257,16 @@ def _AttachFieldHelpers(cls, field_descriptor):
if field_descriptor.containing_oneof is not None:
oneof_descriptor = field_descriptor
- field_decoder = type_checkers.TYPE_TO_DECODER[decode_type](
- field_descriptor.number, is_repeated, is_packed,
- field_descriptor, field_descriptor._default_constructor)
+ if is_map_entry:
+ is_message_map = _IsMessageMapField(field_descriptor)
+
+ field_decoder = decoder.MapDecoder(
+ field_descriptor, _GetInitializeDefaultForMap(field_descriptor),
+ is_message_map)
+ else:
+ field_decoder = type_checkers.TYPE_TO_DECODER[decode_type](
+ field_descriptor.number, is_repeated, is_packed,
+ field_descriptor, field_descriptor._default_constructor)
cls._decoders_by_tag[tag_bytes] = (field_decoder, oneof_descriptor)
@@ -265,6 +301,26 @@ def _AddEnumValues(descriptor, cls):
setattr(cls, enum_value.name, enum_value.number)
+def _GetInitializeDefaultForMap(field):
+ if field.label != _FieldDescriptor.LABEL_REPEATED:
+ raise ValueError('map_entry set on non-repeated field %s' % (
+ field.name))
+ fields_by_name = field.message_type.fields_by_name
+ key_checker = type_checkers.GetTypeChecker(fields_by_name['key'])
+
+ value_field = fields_by_name['value']
+ if _IsMessageMapField(field):
+ def MakeMessageMapDefault(message):
+ return containers.MessageMap(
+ message._listener_for_children, value_field.message_type, key_checker)
+ return MakeMessageMapDefault
+ else:
+ value_checker = type_checkers.GetTypeChecker(value_field)
+ def MakePrimitiveMapDefault(message):
+ return containers.ScalarMap(
+ message._listener_for_children, key_checker, value_checker)
+ return MakePrimitiveMapDefault
+
def _DefaultValueConstructorForField(field):
"""Returns a function which returns a default value for a field.
@@ -279,6 +335,9 @@ def _DefaultValueConstructorForField(field):
value may refer back to |message| via a weak reference.
"""
+ if _IsMapField(field):
+ return _GetInitializeDefaultForMap(field)
+
if field.label == _FieldDescriptor.LABEL_REPEATED:
if field.has_default_value and field.default_value != []:
raise ValueError('Repeated field default value not empty list: %s' % (
@@ -329,7 +388,22 @@ def _ReraiseTypeErrorWithFieldName(message_name, field_name):
def _AddInitMethod(message_descriptor, cls):
"""Adds an __init__ method to cls."""
- fields = message_descriptor.fields
+
+ def _GetIntegerEnumValue(enum_type, value):
+ """Convert a string or integer enum value to an integer.
+
+ If the value is a string, it is converted to the enum value in
+ enum_type with the same name. If the value is not a string, it's
+ returned as-is. (No conversion or bounds-checking is done.)
+ """
+ if isinstance(value, _basestring):
+ try:
+ return enum_type.values_by_name[value].number
+ except KeyError:
+ raise ValueError('Enum type %s: unknown label "%s"' % (
+ enum_type.full_name, value))
+ return value
+
def init(self, **kwargs):
self._cached_byte_size = 0
self._cached_byte_size_dirty = len(kwargs) > 0
@@ -352,19 +426,37 @@ def _AddInitMethod(message_descriptor, cls):
if field.label == _FieldDescriptor.LABEL_REPEATED:
copy = field._default_constructor(self)
if field.cpp_type == _FieldDescriptor.CPPTYPE_MESSAGE: # Composite
- for val in field_value:
- copy.add().MergeFrom(val)
+ if _IsMapField(field):
+ if _IsMessageMapField(field):
+ for key in field_value:
+ copy[key].MergeFrom(field_value[key])
+ else:
+ copy.update(field_value)
+ else:
+ for val in field_value:
+ if isinstance(val, dict):
+ copy.add(**val)
+ else:
+ copy.add().MergeFrom(val)
else: # Scalar
+ if field.cpp_type == _FieldDescriptor.CPPTYPE_ENUM:
+ field_value = [_GetIntegerEnumValue(field.enum_type, val)
+ for val in field_value]
copy.extend(field_value)
self._fields[field] = copy
elif field.cpp_type == _FieldDescriptor.CPPTYPE_MESSAGE:
copy = field._default_constructor(self)
+ new_val = field_value
+ if isinstance(field_value, dict):
+ new_val = field.message_type._concrete_class(**field_value)
try:
- copy.MergeFrom(field_value)
+ copy.MergeFrom(new_val)
except TypeError:
_ReraiseTypeErrorWithFieldName(message_descriptor.name, field_name)
self._fields[field] = copy
else:
+ if field.cpp_type == _FieldDescriptor.CPPTYPE_ENUM:
+ field_value = _GetIntegerEnumValue(field.enum_type, field_value)
try:
setattr(self, field_name, field_value)
except TypeError:
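
Taken together, the constructor now accepts dicts for message-typed fields (singular and repeated), string labels for enum-typed fields, and dicts for map fields. A short sketch using unittest_pb2.TestAllTypes, the message these tests exercise:

    from google.protobuf import unittest_pb2

    msg = unittest_pb2.TestAllTypes(
        optional_int32=1,
        optional_nested_message={'bb': 2},               # dict -> sub-message
        optional_nested_enum='FOO',                      # enum label -> number
        repeated_nested_message=[{'bb': 3}, {'bb': 4}],  # dicts via copy.add(**val)
        repeated_nested_enum=['FOO', 'BAR'])

    assert msg.optional_nested_message.bb == 2
    assert msg.repeated_nested_enum[1] == unittest_pb2.TestAllTypes.BAR
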
@@ -758,6 +850,26 @@ def _AddHasExtensionMethod(cls):
return extension_handle in self._fields
cls.HasExtension = HasExtension
+def _UnpackAny(msg):
+ type_url = msg.type_url
+ db = symbol_database.Default()
+
+ if not type_url:
+ return None
+
+ # TODO(haberman): For now we just strip the hostname. Better logic will be
+ # required.
+ type_name = type_url.split("/")[-1]
+ descriptor = db.pool.FindMessageTypeByName(type_name)
+
+ if descriptor is None:
+ return None
+
+ message_class = db.GetPrototype(descriptor)
+ message = message_class()
+
+ message.ParseFromString(msg.value)
+ return message
def _AddEqualsMethod(message_descriptor, cls):
"""Helper for _AddMessageMethods()."""
@@ -769,6 +881,12 @@ def _AddEqualsMethod(message_descriptor, cls):
if self is other:
return True
+ if self.DESCRIPTOR.full_name == "google.protobuf.Any":
+ any_a = _UnpackAny(self)
+ any_b = _UnpackAny(other)
+ if any_a and any_b:
+ return any_a == any_b
+
if not self.ListFields() == other.ListFields():
return False
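
The effect of the Any special case: two google.protobuf.Any messages compare equal when their unpacked payloads do, even if the host part of type_url differs, because _UnpackAny() keeps only the portion after the last "/". A sketch assuming the generated any_pb2 module is available and that unittest_pb2's types are registered with the default symbol database:

    from google.protobuf import any_pb2
    from google.protobuf import unittest_pb2

    payload = unittest_pb2.ForeignMessage(c=5).SerializeToString()

    a = any_pb2.Any(type_url='type.googleapis.com/protobuf_unittest.ForeignMessage',
                    value=payload)
    b = any_pb2.Any(type_url='example.com/protobuf_unittest.ForeignMessage',
                    value=payload)

    assert a == b  # compared via the unpacked ForeignMessage payloads
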
@@ -961,6 +1079,9 @@ def _AddIsInitializedMethod(message_descriptor, cls):
for field, value in list(self._fields.items()): # dict can change size!
if field.cpp_type == _FieldDescriptor.CPPTYPE_MESSAGE:
if field.label == _FieldDescriptor.LABEL_REPEATED:
+ if (field.message_type.has_options and
+ field.message_type.GetOptions().map_entry):
+ continue
for element in value:
if not element.IsInitialized():
if errors is not None:
@@ -996,16 +1117,26 @@ def _AddIsInitializedMethod(message_descriptor, cls):
else:
name = field.name
- if field.label == _FieldDescriptor.LABEL_REPEATED:
+ if _IsMapField(field):
+ if _IsMessageMapField(field):
+ for key in value:
+ element = value[key]
+            prefix = "%s[%s]." % (name, key)
+ sub_errors = element.FindInitializationErrors()
+ errors += [prefix + error for error in sub_errors]
+ else:
+ # ScalarMaps can't have any initialization errors.
+ pass
+ elif field.label == _FieldDescriptor.LABEL_REPEATED:
for i in xrange(len(value)):
element = value[i]
prefix = "%s[%d]." % (name, i)
sub_errors = element.FindInitializationErrors()
- errors += [ prefix + error for error in sub_errors ]
+ errors += [prefix + error for error in sub_errors]
else:
prefix = name + "."
sub_errors = value.FindInitializationErrors()
- errors += [ prefix + error for error in sub_errors ]
+ errors += [prefix + error for error in sub_errors]
return errors
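
For message-valued maps, initialization errors are reported with a key-based prefix. A hypothetical illustration, assuming a proto2 message Outer with "map<int32, Inner> items = 1;" where Inner declares "required int32 x = 1;" (these types are not part of this change):

    from outer_pb2 import Outer  # hypothetical proto2 messages described above

    msg = Outer()
    msg.items[7]  # creates an empty Inner whose required field x is unset
    assert msg.FindInitializationErrors() == ['items[7].x']
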
diff --git a/python/google/protobuf/internal/reflection_test.py b/python/google/protobuf/internal/reflection_test.py
index ae79c78b..4eca4989 100755
--- a/python/google/protobuf/internal/reflection_test.py
+++ b/python/google/protobuf/internal/reflection_test.py
@@ -39,8 +39,8 @@ import copy
import gc
import operator
import struct
-import unittest
+import unittest
from google.protobuf import unittest_import_pb2
from google.protobuf import unittest_mset_pb2
from google.protobuf import unittest_pb2
@@ -1798,8 +1798,8 @@ class ReflectionTest(unittest.TestCase):
def testBadArguments(self):
# Some of these assertions used to segfault.
from google.protobuf.pyext import _message
- self.assertRaises(TypeError, _message.Message._GetFieldDescriptor, 3)
- self.assertRaises(TypeError, _message.Message._GetExtensionDescriptor, 42)
+ self.assertRaises(TypeError, _message.default_pool.FindFieldByName, 3)
+ self.assertRaises(TypeError, _message.default_pool.FindExtensionByName, 42)
self.assertRaises(TypeError,
unittest_pb2.TestAllTypes().__getattribute__, 42)
diff --git a/python/google/protobuf/internal/service_reflection_test.py b/python/google/protobuf/internal/service_reflection_test.py
index de462124..9967255a 100755
--- a/python/google/protobuf/internal/service_reflection_test.py
+++ b/python/google/protobuf/internal/service_reflection_test.py
@@ -35,7 +35,6 @@
__author__ = 'petar@google.com (Petar Petrov)'
import unittest
-
from google.protobuf import unittest_pb2
from google.protobuf import service_reflection
from google.protobuf import service
@@ -81,7 +80,7 @@ class FooUnitTest(unittest.TestCase):
self.assertEqual('Method Bar not implemented.',
rpc_controller.failure_message)
self.assertEqual(None, self.callback_response)
-
+
class MyServiceImpl(unittest_pb2.TestService):
def Foo(self, rpc_controller, request, done):
self.foo_called = True
diff --git a/python/google/protobuf/internal/symbol_database_test.py b/python/google/protobuf/internal/symbol_database_test.py
index c888aff7..80b83bc2 100644
--- a/python/google/protobuf/internal/symbol_database_test.py
+++ b/python/google/protobuf/internal/symbol_database_test.py
@@ -33,7 +33,6 @@
"""Tests for google.protobuf.symbol_database."""
import unittest
-
from google.protobuf import unittest_pb2
from google.protobuf import symbol_database
diff --git a/python/google/protobuf/internal/test_util.py b/python/google/protobuf/internal/test_util.py
index d84e3836..0cbdbad9 100755
--- a/python/google/protobuf/internal/test_util.py
+++ b/python/google/protobuf/internal/test_util.py
@@ -75,7 +75,8 @@ def SetAllNonLazyFields(message):
message.optional_string = u'115'
message.optional_bytes = b'116'
- message.optionalgroup.a = 117
+ if IsProto2(message):
+ message.optionalgroup.a = 117
message.optional_nested_message.bb = 118
message.optional_foreign_message.c = 119
message.optional_import_message.d = 120
@@ -109,7 +110,8 @@ def SetAllNonLazyFields(message):
message.repeated_string.append(u'215')
message.repeated_bytes.append(b'216')
- message.repeatedgroup.add().a = 217
+ if IsProto2(message):
+ message.repeatedgroup.add().a = 217
message.repeated_nested_message.add().bb = 218
message.repeated_foreign_message.add().c = 219
message.repeated_import_message.add().d = 220
@@ -140,7 +142,8 @@ def SetAllNonLazyFields(message):
message.repeated_string.append(u'315')
message.repeated_bytes.append(b'316')
- message.repeatedgroup.add().a = 317
+ if IsProto2(message):
+ message.repeatedgroup.add().a = 317
message.repeated_nested_message.add().bb = 318
message.repeated_foreign_message.add().c = 319
message.repeated_import_message.add().d = 320
@@ -396,7 +399,8 @@ def ExpectAllFieldsSet(test_case, message):
test_case.assertTrue(message.HasField('optional_string'))
test_case.assertTrue(message.HasField('optional_bytes'))
- test_case.assertTrue(message.HasField('optionalgroup'))
+ if IsProto2(message):
+ test_case.assertTrue(message.HasField('optionalgroup'))
test_case.assertTrue(message.HasField('optional_nested_message'))
test_case.assertTrue(message.HasField('optional_foreign_message'))
test_case.assertTrue(message.HasField('optional_import_message'))
@@ -430,7 +434,8 @@ def ExpectAllFieldsSet(test_case, message):
test_case.assertEqual('115', message.optional_string)
test_case.assertEqual(b'116', message.optional_bytes)
- test_case.assertEqual(117, message.optionalgroup.a)
+ if IsProto2(message):
+ test_case.assertEqual(117, message.optionalgroup.a)
test_case.assertEqual(118, message.optional_nested_message.bb)
test_case.assertEqual(119, message.optional_foreign_message.c)
test_case.assertEqual(120, message.optional_import_message.d)
@@ -463,7 +468,8 @@ def ExpectAllFieldsSet(test_case, message):
test_case.assertEqual(2, len(message.repeated_string))
test_case.assertEqual(2, len(message.repeated_bytes))
- test_case.assertEqual(2, len(message.repeatedgroup))
+ if IsProto2(message):
+ test_case.assertEqual(2, len(message.repeatedgroup))
test_case.assertEqual(2, len(message.repeated_nested_message))
test_case.assertEqual(2, len(message.repeated_foreign_message))
test_case.assertEqual(2, len(message.repeated_import_message))
@@ -491,7 +497,8 @@ def ExpectAllFieldsSet(test_case, message):
test_case.assertEqual('215', message.repeated_string[0])
test_case.assertEqual(b'216', message.repeated_bytes[0])
- test_case.assertEqual(217, message.repeatedgroup[0].a)
+ if IsProto2(message):
+ test_case.assertEqual(217, message.repeatedgroup[0].a)
test_case.assertEqual(218, message.repeated_nested_message[0].bb)
test_case.assertEqual(219, message.repeated_foreign_message[0].c)
test_case.assertEqual(220, message.repeated_import_message[0].d)
@@ -521,7 +528,8 @@ def ExpectAllFieldsSet(test_case, message):
test_case.assertEqual('315', message.repeated_string[1])
test_case.assertEqual(b'316', message.repeated_bytes[1])
- test_case.assertEqual(317, message.repeatedgroup[1].a)
+ if IsProto2(message):
+ test_case.assertEqual(317, message.repeatedgroup[1].a)
test_case.assertEqual(318, message.repeated_nested_message[1].bb)
test_case.assertEqual(319, message.repeated_foreign_message[1].c)
test_case.assertEqual(320, message.repeated_import_message[1].d)
diff --git a/python/google/protobuf/internal/text_encoding_test.py b/python/google/protobuf/internal/text_encoding_test.py
index 27896b94..5df13b78 100755
--- a/python/google/protobuf/internal/text_encoding_test.py
+++ b/python/google/protobuf/internal/text_encoding_test.py
@@ -33,7 +33,6 @@
"""Tests for google.protobuf.text_encoding."""
import unittest
-
from google.protobuf import text_encoding
TEST_VALUES = [
diff --git a/python/google/protobuf/internal/text_format_test.py b/python/google/protobuf/internal/text_format_test.py
index bf7e06ee..06bd1ee5 100755
--- a/python/google/protobuf/internal/text_format_test.py
+++ b/python/google/protobuf/internal/text_format_test.py
@@ -37,8 +37,10 @@ __author__ = 'kenton@google.com (Kenton Varda)'
import re
import unittest
+import unittest
from google.protobuf.internal import _parameterized
+from google.protobuf import map_unittest_pb2
from google.protobuf import unittest_mset_pb2
from google.protobuf import unittest_pb2
from google.protobuf import unittest_proto3_arena_pb2
@@ -309,31 +311,6 @@ class TextFormatTest(TextFormatBase):
r'"unknown_field".'),
text_format.Parse, text, message)
- def testParseGroupNotClosed(self, message_module):
- message = message_module.TestAllTypes()
- text = 'RepeatedGroup: <'
- self.assertRaisesRegexp(
- text_format.ParseError, '1:16 : Expected ">".',
- text_format.Parse, text, message)
-
- text = 'RepeatedGroup: {'
- self.assertRaisesRegexp(
- text_format.ParseError, '1:16 : Expected "}".',
- text_format.Parse, text, message)
-
- def testParseEmptyGroup(self, message_module):
- message = message_module.TestAllTypes()
- text = 'OptionalGroup: {}'
- text_format.Parse(text, message)
- self.assertTrue(message.HasField('optionalgroup'))
-
- message.Clear()
-
- message = message_module.TestAllTypes()
- text = 'OptionalGroup: <>'
- text_format.Parse(text, message)
- self.assertTrue(message.HasField('optionalgroup'))
-
def testParseBadEnumValue(self, message_module):
message = message_module.TestAllTypes()
text = 'optional_nested_enum: BARR'
@@ -408,6 +385,14 @@ class TextFormatTest(TextFormatBase):
# Ideally the schemas would be made more similar so these tests could pass.
class OnlyWorksWithProto2RightNowTests(TextFormatBase):
+ def testPrintAllFieldsPointy(self, message_module):
+ message = unittest_pb2.TestAllTypes()
+ test_util.SetAllFields(message)
+ self.CompareToGoldenFile(
+ self.RemoveRedundantZeros(
+ text_format.MessageToString(message, pointy_brackets=True)),
+ 'text_format_unittest_data_pointy_oneof.txt')
+
def testParseGolden(self):
golden_text = '\n'.join(self.ReadGolden('text_format_unittest_data.txt'))
parsed_message = unittest_pb2.TestAllTypes()
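
The new golden test above drives text_format's existing pointy_brackets option; a minimal sketch of the two output styles, using the same unittest proto:

    from google.protobuf import text_format
    from google.protobuf import unittest_pb2

    msg = unittest_pb2.TestAllTypes()
    msg.optional_nested_message.bb = 1

    text_format.MessageToString(msg)
    # 'optional_nested_message {\n  bb: 1\n}\n'

    text_format.MessageToString(msg, pointy_brackets=True)
    # 'optional_nested_message <\n  bb: 1\n>\n'
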
@@ -471,8 +456,49 @@ class OnlyWorksWithProto2RightNowTests(TextFormatBase):
test_util.SetAllFields(message)
self.assertEquals(message, parsed_message)
+ def testPrintMap(self):
+ message = map_unittest_pb2.TestMap()
-# Tests of proto2-only features (MessageSet and extensions).
+ message.map_int32_int32[-123] = -456
+ message.map_int64_int64[-2**33] = -2**34
+ message.map_uint32_uint32[123] = 456
+ message.map_uint64_uint64[2**33] = 2**34
+ message.map_string_string["abc"] = "123"
+ message.map_int32_foreign_message[111].c = 5
+
+ # Maps are serialized to text format using their underlying repeated
+ # representation.
+ self.CompareToGoldenText(
+ text_format.MessageToString(message),
+ 'map_int32_int32 {\n'
+ ' key: -123\n'
+ ' value: -456\n'
+ '}\n'
+ 'map_int64_int64 {\n'
+ ' key: -8589934592\n'
+ ' value: -17179869184\n'
+ '}\n'
+ 'map_uint32_uint32 {\n'
+ ' key: 123\n'
+ ' value: 456\n'
+ '}\n'
+ 'map_uint64_uint64 {\n'
+ ' key: 8589934592\n'
+ ' value: 17179869184\n'
+ '}\n'
+ 'map_string_string {\n'
+ ' key: "abc"\n'
+ ' value: "123"\n'
+ '}\n'
+ 'map_int32_foreign_message {\n'
+ ' key: 111\n'
+ ' value {\n'
+ ' c: 5\n'
+ ' }\n'
+ '}\n')
+
+
+# Tests of proto2-only features (MessageSet, extensions, etc.).
class Proto2Tests(TextFormatBase):
def testPrintMessageSet(self):
@@ -620,6 +646,69 @@ class Proto2Tests(TextFormatBase):
'have multiple "optional_int32" fields.'),
text_format.Parse, text, message)
+ def testParseGroupNotClosed(self):
+ message = unittest_pb2.TestAllTypes()
+ text = 'RepeatedGroup: <'
+ self.assertRaisesRegexp(
+ text_format.ParseError, '1:16 : Expected ">".',
+ text_format.Parse, text, message)
+ text = 'RepeatedGroup: {'
+ self.assertRaisesRegexp(
+ text_format.ParseError, '1:16 : Expected "}".',
+ text_format.Parse, text, message)
+
+ def testParseEmptyGroup(self):
+ message = unittest_pb2.TestAllTypes()
+ text = 'OptionalGroup: {}'
+ text_format.Parse(text, message)
+ self.assertTrue(message.HasField('optionalgroup'))
+
+ message.Clear()
+
+ message = unittest_pb2.TestAllTypes()
+ text = 'OptionalGroup: <>'
+ text_format.Parse(text, message)
+ self.assertTrue(message.HasField('optionalgroup'))
+
+ # Maps aren't really proto2-only, but our test schema only has maps for
+ # proto2.
+ def testParseMap(self):
+ text = ('map_int32_int32 {\n'
+ ' key: -123\n'
+ ' value: -456\n'
+ '}\n'
+ 'map_int64_int64 {\n'
+ ' key: -8589934592\n'
+ ' value: -17179869184\n'
+ '}\n'
+ 'map_uint32_uint32 {\n'
+ ' key: 123\n'
+ ' value: 456\n'
+ '}\n'
+ 'map_uint64_uint64 {\n'
+ ' key: 8589934592\n'
+ ' value: 17179869184\n'
+ '}\n'
+ 'map_string_string {\n'
+ ' key: "abc"\n'
+ ' value: "123"\n'
+ '}\n'
+ 'map_int32_foreign_message {\n'
+ ' key: 111\n'
+ ' value {\n'
+ ' c: 5\n'
+ ' }\n'
+ '}\n')
+ message = map_unittest_pb2.TestMap()
+ text_format.Parse(text, message)
+
+ self.assertEqual(-456, message.map_int32_int32[-123])
+ self.assertEqual(-2**34, message.map_int64_int64[-2**33])
+ self.assertEqual(456, message.map_uint32_uint32[123])
+ self.assertEqual(2**34, message.map_uint64_uint64[2**33])
+ self.assertEqual("123", message.map_string_string["abc"])
+ self.assertEqual(5, message.map_int32_foreign_message[111].c)
+
class TokenizerTest(unittest.TestCase):
diff --git a/python/google/protobuf/internal/type_checkers.py b/python/google/protobuf/internal/type_checkers.py
index 76c056c4..f20e526a 100755
--- a/python/google/protobuf/internal/type_checkers.py
+++ b/python/google/protobuf/internal/type_checkers.py
@@ -129,6 +129,9 @@ class IntValueChecker(object):
proposed_value = self._TYPE(proposed_value)
return proposed_value
+ def DefaultValue(self):
+ return 0
+
class EnumValueChecker(object):
@@ -146,6 +149,9 @@ class EnumValueChecker(object):
raise ValueError('Unknown enum value: %d' % proposed_value)
return proposed_value
+ def DefaultValue(self):
+ return self._enum_type.values[0].number
+
class UnicodeValueChecker(object):
@@ -171,6 +177,9 @@ class UnicodeValueChecker(object):
(proposed_value))
return proposed_value
+ def DefaultValue(self):
+ return u""
+
class Int32ValueChecker(IntValueChecker):
# We're sure to use ints instead of longs here since comparison may be more
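
The new DefaultValue() hooks are what the ScalarMap container consults: reading a missing key yields the field type's default value rather than raising KeyError, much like collections.defaultdict. A small sketch with the TestMap message used elsewhere in this change:

    from google.protobuf import map_unittest_pb2

    msg = map_unittest_pb2.TestMap()
    assert msg.map_int32_int32[42] == 0             # IntValueChecker.DefaultValue()
    assert msg.map_string_string['missing'] == u''  # UnicodeValueChecker.DefaultValue()
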
diff --git a/python/google/protobuf/internal/unknown_fields_test.py b/python/google/protobuf/internal/unknown_fields_test.py
index 9337ae8a..1b81ae79 100755
--- a/python/google/protobuf/internal/unknown_fields_test.py
+++ b/python/google/protobuf/internal/unknown_fields_test.py
@@ -36,7 +36,6 @@
__author__ = 'bohdank@google.com (Bohdan Koval)'
import unittest
-
from google.protobuf import unittest_mset_pb2
from google.protobuf import unittest_pb2
from google.protobuf import unittest_proto3_arena_pb2
diff --git a/python/google/protobuf/internal/wire_format_test.py b/python/google/protobuf/internal/wire_format_test.py
index 5cd7fcb9..78dc1167 100755
--- a/python/google/protobuf/internal/wire_format_test.py
+++ b/python/google/protobuf/internal/wire_format_test.py
@@ -35,7 +35,6 @@
__author__ = 'robinson@google.com (Will Robinson)'
import unittest
-
from google.protobuf import message
from google.protobuf.internal import wire_format
diff --git a/python/google/protobuf/proto_builder.py b/python/google/protobuf/proto_builder.py
index 1fa28f1a..7489cf63 100644
--- a/python/google/protobuf/proto_builder.py
+++ b/python/google/protobuf/proto_builder.py
@@ -30,6 +30,7 @@
"""Dynamic Protobuf class creator."""
+import collections
import hashlib
import os
@@ -59,7 +60,9 @@ def MakeSimpleProtoClass(fields, full_name, pool=None):
Note: this doesn't validate field names!
Args:
- fields: dict of {name: field_type} mappings for each field in the proto.
+ fields: dict of {name: field_type} mappings for each field in the proto. If
+ this is an OrderedDict the order will be maintained, otherwise the
+ fields will be sorted by name.
full_name: str, the fully-qualified name of the proto type.
pool: optional DescriptorPool instance.
Returns:
@@ -73,12 +76,19 @@ def MakeSimpleProtoClass(fields, full_name, pool=None):
# The factory's DescriptorPool doesn't know about this class yet.
pass
+ # Get a list of (name, field_type) tuples from the fields dict. If fields was
+  # an OrderedDict we keep the order, but otherwise we sort the fields to ensure
+ # consistent ordering.
+ field_items = fields.items()
+ if not isinstance(fields, collections.OrderedDict):
+ field_items = sorted(field_items)
+
# Use a consistent file name that is unlikely to conflict with any imported
# proto files.
fields_hash = hashlib.sha1()
- for f_name, f_type in sorted(fields.items()):
- fields_hash.update(f_name.encode('utf8'))
- fields_hash.update(str(f_type).encode('utf8'))
+ for f_name, f_type in field_items:
+ fields_hash.update(f_name.encode('utf-8'))
+ fields_hash.update(str(f_type).encode('utf-8'))
proto_file_name = fields_hash.hexdigest() + '.proto'
package, name = full_name.rsplit('.', 1)
@@ -87,7 +97,7 @@ def MakeSimpleProtoClass(fields, full_name, pool=None):
file_proto.package = package
desc_proto = file_proto.message_type.add()
desc_proto.name = name
- for f_number, (f_name, f_type) in enumerate(sorted(fields.items()), 1):
+ for f_number, (f_name, f_type) in enumerate(field_items, 1):
field_proto = desc_proto.field.add()
field_proto.name = f_name
field_proto.number = f_number
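
A usage sketch of the new ordering contract: with an OrderedDict the field numbers follow insertion order, otherwise the fields are sorted by name before numbering. The type constants come from descriptor_pb2:

    import collections

    from google.protobuf import descriptor_pb2
    from google.protobuf import proto_builder

    fields = collections.OrderedDict([
        ('name', descriptor_pb2.FieldDescriptorProto.TYPE_STRING),
        ('id', descriptor_pb2.FieldDescriptorProto.TYPE_INT64),
    ])
    Row = proto_builder.MakeSimpleProtoClass(fields, 'ordered.example.Row')
    row = Row(name='abc', id=123)  # field numbers: name=1, id=2
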
diff --git a/python/google/protobuf/pyext/descriptor.cc b/python/google/protobuf/pyext/descriptor.cc
index e77d0bb9..2160757b 100644
--- a/python/google/protobuf/pyext/descriptor.cc
+++ b/python/google/protobuf/pyext/descriptor.cc
@@ -43,8 +43,6 @@
#include <google/protobuf/pyext/message.h>
#include <google/protobuf/pyext/scoped_pyobject_ptr.h>
-#define C(str) const_cast<char*>(str)
-
#if PY_MAJOR_VERSION >= 3
#define PyString_FromStringAndSize PyUnicode_FromStringAndSize
#define PyString_Check PyUnicode_Check
@@ -257,8 +255,14 @@ namespace descriptor {
// Creates or retrieve a Python descriptor of the specified type.
// Objects are interned: the same descriptor will return the same object if it
// was kept alive.
+// 'was_created' is an optional pointer to a bool, and is set to true if a new
+// object was allocated.
// Always return a new reference.
-PyObject* NewInternedDescriptor(PyTypeObject* type, const void* descriptor) {
+PyObject* NewInternedDescriptor(PyTypeObject* type, const void* descriptor,
+ bool* was_created) {
+ if (was_created) {
+ *was_created = false;
+ }
if (descriptor == NULL) {
PyErr_BadInternalCall();
return NULL;
@@ -283,6 +287,9 @@ PyObject* NewInternedDescriptor(PyTypeObject* type, const void* descriptor) {
GetDescriptorPool()->interned_descriptors->insert(
std::make_pair(descriptor, reinterpret_cast<PyObject*>(py_descriptor)));
+ if (was_created) {
+ *was_created = true;
+ }
return reinterpret_cast<PyObject*>(py_descriptor);
}
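
Because descriptors are interned, wrapping the same C++ pointer twice hands back the same Python object; the new was_created flag lets callers such as the serialized_pb cache below detect the first wrap. The effect is observable from Python (TestAllTypes is just a convenient generated type):

    from google.protobuf import unittest_pb2

    d1 = unittest_pb2.TestAllTypes.DESCRIPTOR
    d2 = unittest_pb2.TestAllTypes().DESCRIPTOR
    assert d1 is d2  # same interned descriptor object

    f1 = d1.fields_by_name['optional_int32']
    f2 = d2.fields_by_name['optional_int32']
    assert f1 is f2  # field descriptors are interned as well
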
@@ -298,9 +305,7 @@ static PyGetSetDef Getters[] = {
PyTypeObject PyBaseDescriptor_Type = {
PyVarObject_HEAD_INIT(&PyType_Type, 0)
- // Keep the fully qualified _message symbol in a line for opensource.
- "google.protobuf.internal._message."
- "DescriptorBase", // tp_name
+ FULL_MODULE_NAME ".DescriptorBase", // tp_name
sizeof(PyBaseDescriptor), // tp_basicsize
0, // tp_itemsize
(destructor)Dealloc, // tp_dealloc
@@ -357,7 +362,7 @@ static PyObject* GetFullName(PyBaseDescriptor* self, void *closure) {
}
static PyObject* GetFile(PyBaseDescriptor *self, void *closure) {
- return PyFileDescriptor_New(_GetDescriptor(self)->file());
+ return PyFileDescriptor_FromDescriptor(_GetDescriptor(self)->file());
}
static PyObject* GetConcreteClass(PyBaseDescriptor* self, void *closure) {
@@ -367,17 +372,6 @@ static PyObject* GetConcreteClass(PyBaseDescriptor* self, void *closure) {
return concrete_class;
}
-static int SetConcreteClass(PyBaseDescriptor *self, PyObject *value,
- void *closure) {
- // This attribute is also set from reflection.py. Check that it's actually a
- // no-op.
- if (value != cdescriptor_pool::GetMessageClass(
- GetDescriptorPool(), _GetDescriptor(self))) {
- PyErr_SetString(PyExc_AttributeError, "Cannot change _concrete_class");
- }
- return 0;
-}
-
static PyObject* GetFieldsByName(PyBaseDescriptor* self, void *closure) {
return NewMessageFieldsByName(_GetDescriptor(self));
}
@@ -452,7 +446,7 @@ static PyObject* GetContainingType(PyBaseDescriptor *self, void *closure) {
const Descriptor* containing_type =
_GetDescriptor(self)->containing_type();
if (containing_type) {
- return PyMessageDescriptor_New(containing_type);
+ return PyMessageDescriptor_FromDescriptor(containing_type);
} else {
Py_RETURN_NONE;
}
@@ -515,29 +509,34 @@ static PyObject* GetSyntax(PyBaseDescriptor *self, void *closure) {
}
static PyGetSetDef Getters[] = {
- { C("name"), (getter)GetName, NULL, "Last name", NULL},
- { C("full_name"), (getter)GetFullName, NULL, "Full name", NULL},
- { C("_concrete_class"), (getter)GetConcreteClass, (setter)SetConcreteClass, "concrete class", NULL},
- { C("file"), (getter)GetFile, NULL, "File descriptor", NULL},
-
- { C("fields"), (getter)GetFieldsSeq, NULL, "Fields sequence", NULL},
- { C("fields_by_name"), (getter)GetFieldsByName, NULL, "Fields by name", NULL},
- { C("fields_by_number"), (getter)GetFieldsByNumber, NULL, "Fields by number", NULL},
- { C("nested_types"), (getter)GetNestedTypesSeq, NULL, "Nested types sequence", NULL},
- { C("nested_types_by_name"), (getter)GetNestedTypesByName, NULL, "Nested types by name", NULL},
- { C("extensions"), (getter)GetExtensions, NULL, "Extensions Sequence", NULL},
- { C("extensions_by_name"), (getter)GetExtensionsByName, NULL, "Extensions by name", NULL},
- { C("extension_ranges"), (getter)GetExtensionRanges, NULL, "Extension ranges", NULL},
- { C("enum_types"), (getter)GetEnumsSeq, NULL, "Enum sequence", NULL},
- { C("enum_types_by_name"), (getter)GetEnumTypesByName, NULL, "Enum types by name", NULL},
- { C("enum_values_by_name"), (getter)GetEnumValuesByName, NULL, "Enum values by name", NULL},
- { C("oneofs_by_name"), (getter)GetOneofsByName, NULL, "Oneofs by name", NULL},
- { C("oneofs"), (getter)GetOneofsSeq, NULL, "Oneofs by name", NULL},
- { C("containing_type"), (getter)GetContainingType, (setter)SetContainingType, "Containing type", NULL},
- { C("is_extendable"), (getter)IsExtendable, (setter)NULL, NULL, NULL},
- { C("has_options"), (getter)GetHasOptions, (setter)SetHasOptions, "Has Options", NULL},
- { C("_options"), (getter)NULL, (setter)SetOptions, "Options", NULL},
- { C("syntax"), (getter)GetSyntax, (setter)NULL, "Syntax", NULL},
+ { "name", (getter)GetName, NULL, "Last name"},
+ { "full_name", (getter)GetFullName, NULL, "Full name"},
+ { "_concrete_class", (getter)GetConcreteClass, NULL, "concrete class"},
+ { "file", (getter)GetFile, NULL, "File descriptor"},
+
+ { "fields", (getter)GetFieldsSeq, NULL, "Fields sequence"},
+ { "fields_by_name", (getter)GetFieldsByName, NULL, "Fields by name"},
+ { "fields_by_number", (getter)GetFieldsByNumber, NULL, "Fields by number"},
+ { "nested_types", (getter)GetNestedTypesSeq, NULL, "Nested types sequence"},
+ { "nested_types_by_name", (getter)GetNestedTypesByName, NULL,
+ "Nested types by name"},
+ { "extensions", (getter)GetExtensions, NULL, "Extensions Sequence"},
+ { "extensions_by_name", (getter)GetExtensionsByName, NULL,
+ "Extensions by name"},
+ { "extension_ranges", (getter)GetExtensionRanges, NULL, "Extension ranges"},
+ { "enum_types", (getter)GetEnumsSeq, NULL, "Enum sequence"},
+ { "enum_types_by_name", (getter)GetEnumTypesByName, NULL,
+ "Enum types by name"},
+ { "enum_values_by_name", (getter)GetEnumValuesByName, NULL,
+ "Enum values by name"},
+ { "oneofs_by_name", (getter)GetOneofsByName, NULL, "Oneofs by name"},
+ { "oneofs", (getter)GetOneofsSeq, NULL, "Oneofs by name"},
+ { "containing_type", (getter)GetContainingType, (setter)SetContainingType,
+ "Containing type"},
+ { "is_extendable", (getter)IsExtendable, (setter)NULL},
+ { "has_options", (getter)GetHasOptions, (setter)SetHasOptions, "Has Options"},
+ { "_options", (getter)NULL, (setter)SetOptions, "Options"},
+ { "syntax", (getter)GetSyntax, (setter)NULL, "Syntax"},
{NULL}
};
@@ -552,9 +551,7 @@ static PyMethodDef Methods[] = {
PyTypeObject PyMessageDescriptor_Type = {
PyVarObject_HEAD_INIT(&PyType_Type, 0)
- // Keep the fully qualified _message symbol in a line for opensource.
- C("google.protobuf.internal._message."
- "MessageDescriptor"), // tp_name
+ FULL_MODULE_NAME ".MessageDescriptor", // tp_name
sizeof(PyBaseDescriptor), // tp_basicsize
0, // tp_itemsize
0, // tp_dealloc
@@ -573,7 +570,7 @@ PyTypeObject PyMessageDescriptor_Type = {
0, // tp_setattro
0, // tp_as_buffer
Py_TPFLAGS_DEFAULT, // tp_flags
- C("A Message Descriptor"), // tp_doc
+ "A Message Descriptor", // tp_doc
0, // tp_traverse
0, // tp_clear
0, // tp_richcompare
@@ -586,10 +583,10 @@ PyTypeObject PyMessageDescriptor_Type = {
&descriptor::PyBaseDescriptor_Type, // tp_base
};
-PyObject* PyMessageDescriptor_New(
+PyObject* PyMessageDescriptor_FromDescriptor(
const Descriptor* message_descriptor) {
return descriptor::NewInternedDescriptor(
- &PyMessageDescriptor_Type, message_descriptor);
+ &PyMessageDescriptor_Type, message_descriptor, NULL);
}
const Descriptor* PyMessageDescriptor_AsDescriptor(PyObject* obj) {
@@ -715,7 +712,7 @@ static PyObject* GetCDescriptor(PyObject *self, void *closure) {
static PyObject *GetEnumType(PyBaseDescriptor *self, void *closure) {
const EnumDescriptor* enum_type = _GetDescriptor(self)->enum_type();
if (enum_type) {
- return PyEnumDescriptor_New(enum_type);
+ return PyEnumDescriptor_FromDescriptor(enum_type);
} else {
Py_RETURN_NONE;
}
@@ -728,7 +725,7 @@ static int SetEnumType(PyBaseDescriptor *self, PyObject *value, void *closure) {
static PyObject *GetMessageType(PyBaseDescriptor *self, void *closure) {
const Descriptor* message_type = _GetDescriptor(self)->message_type();
if (message_type) {
- return PyMessageDescriptor_New(message_type);
+ return PyMessageDescriptor_FromDescriptor(message_type);
} else {
Py_RETURN_NONE;
}
@@ -743,7 +740,7 @@ static PyObject* GetContainingType(PyBaseDescriptor *self, void *closure) {
const Descriptor* containing_type =
_GetDescriptor(self)->containing_type();
if (containing_type) {
- return PyMessageDescriptor_New(containing_type);
+ return PyMessageDescriptor_FromDescriptor(containing_type);
} else {
Py_RETURN_NONE;
}
@@ -758,7 +755,7 @@ static PyObject* GetExtensionScope(PyBaseDescriptor *self, void *closure) {
const Descriptor* extension_scope =
_GetDescriptor(self)->extension_scope();
if (extension_scope) {
- return PyMessageDescriptor_New(extension_scope);
+ return PyMessageDescriptor_FromDescriptor(extension_scope);
} else {
Py_RETURN_NONE;
}
@@ -768,7 +765,7 @@ static PyObject* GetContainingOneof(PyBaseDescriptor *self, void *closure) {
const OneofDescriptor* containing_oneof =
_GetDescriptor(self)->containing_oneof();
if (containing_oneof) {
- return PyOneofDescriptor_New(containing_oneof);
+ return PyOneofDescriptor_FromDescriptor(containing_oneof);
} else {
Py_RETURN_NONE;
}
@@ -803,26 +800,30 @@ static int SetOptions(PyBaseDescriptor *self, PyObject *value,
static PyGetSetDef Getters[] = {
- { C("full_name"), (getter)GetFullName, NULL, "Full name", NULL},
- { C("name"), (getter)GetName, NULL, "Unqualified name", NULL},
- { C("type"), (getter)GetType, NULL, "C++ Type", NULL},
- { C("cpp_type"), (getter)GetCppType, NULL, "C++ Type", NULL},
- { C("label"), (getter)GetLabel, NULL, "Label", NULL},
- { C("number"), (getter)GetNumber, NULL, "Number", NULL},
- { C("index"), (getter)GetIndex, NULL, "Index", NULL},
- { C("default_value"), (getter)GetDefaultValue, NULL, "Default Value", NULL},
- { C("has_default_value"), (getter)HasDefaultValue, NULL, NULL, NULL},
- { C("is_extension"), (getter)IsExtension, NULL, "ID", NULL},
- { C("id"), (getter)GetID, NULL, "ID", NULL},
- { C("_cdescriptor"), (getter)GetCDescriptor, NULL, "HAACK REMOVE ME", NULL},
-
- { C("message_type"), (getter)GetMessageType, (setter)SetMessageType, "Message type", NULL},
- { C("enum_type"), (getter)GetEnumType, (setter)SetEnumType, "Enum type", NULL},
- { C("containing_type"), (getter)GetContainingType, (setter)SetContainingType, "Containing type", NULL},
- { C("extension_scope"), (getter)GetExtensionScope, (setter)NULL, "Extension scope", NULL},
- { C("containing_oneof"), (getter)GetContainingOneof, (setter)SetContainingOneof, "Containing oneof", NULL},
- { C("has_options"), (getter)GetHasOptions, (setter)SetHasOptions, "Has Options", NULL},
- { C("_options"), (getter)NULL, (setter)SetOptions, "Options", NULL},
+ { "full_name", (getter)GetFullName, NULL, "Full name"},
+ { "name", (getter)GetName, NULL, "Unqualified name"},
+ { "type", (getter)GetType, NULL, "C++ Type"},
+ { "cpp_type", (getter)GetCppType, NULL, "C++ Type"},
+ { "label", (getter)GetLabel, NULL, "Label"},
+ { "number", (getter)GetNumber, NULL, "Number"},
+ { "index", (getter)GetIndex, NULL, "Index"},
+ { "default_value", (getter)GetDefaultValue, NULL, "Default Value"},
+ { "has_default_value", (getter)HasDefaultValue},
+ { "is_extension", (getter)IsExtension, NULL, "ID"},
+ { "id", (getter)GetID, NULL, "ID"},
+ { "_cdescriptor", (getter)GetCDescriptor, NULL, "HAACK REMOVE ME"},
+
+ { "message_type", (getter)GetMessageType, (setter)SetMessageType,
+ "Message type"},
+ { "enum_type", (getter)GetEnumType, (setter)SetEnumType, "Enum type"},
+ { "containing_type", (getter)GetContainingType, (setter)SetContainingType,
+ "Containing type"},
+ { "extension_scope", (getter)GetExtensionScope, (setter)NULL,
+ "Extension scope"},
+ { "containing_oneof", (getter)GetContainingOneof, (setter)SetContainingOneof,
+ "Containing oneof"},
+ { "has_options", (getter)GetHasOptions, (setter)SetHasOptions, "Has Options"},
+ { "_options", (getter)NULL, (setter)SetOptions, "Options"},
{NULL}
};
@@ -835,8 +836,7 @@ static PyMethodDef Methods[] = {
PyTypeObject PyFieldDescriptor_Type = {
PyVarObject_HEAD_INIT(&PyType_Type, 0)
- C("google.protobuf.internal."
- "_message.FieldDescriptor"), // tp_name
+ FULL_MODULE_NAME ".FieldDescriptor", // tp_name
sizeof(PyBaseDescriptor), // tp_basicsize
0, // tp_itemsize
0, // tp_dealloc
@@ -855,7 +855,7 @@ PyTypeObject PyFieldDescriptor_Type = {
0, // tp_setattro
0, // tp_as_buffer
Py_TPFLAGS_DEFAULT, // tp_flags
- C("A Field Descriptor"), // tp_doc
+ "A Field Descriptor", // tp_doc
0, // tp_traverse
0, // tp_clear
0, // tp_richcompare
@@ -868,10 +868,10 @@ PyTypeObject PyFieldDescriptor_Type = {
&descriptor::PyBaseDescriptor_Type, // tp_base
};
-PyObject* PyFieldDescriptor_New(
+PyObject* PyFieldDescriptor_FromDescriptor(
const FieldDescriptor* field_descriptor) {
return descriptor::NewInternedDescriptor(
- &PyFieldDescriptor_Type, field_descriptor);
+ &PyFieldDescriptor_Type, field_descriptor, NULL);
}
const FieldDescriptor* PyFieldDescriptor_AsDescriptor(PyObject* obj) {
@@ -900,7 +900,7 @@ static PyObject* GetName(PyBaseDescriptor *self, void *closure) {
}
static PyObject* GetFile(PyBaseDescriptor *self, void *closure) {
- return PyFileDescriptor_New(_GetDescriptor(self)->file());
+ return PyFileDescriptor_FromDescriptor(_GetDescriptor(self)->file());
}
static PyObject* GetEnumvaluesByName(PyBaseDescriptor* self, void *closure) {
@@ -919,7 +919,7 @@ static PyObject* GetContainingType(PyBaseDescriptor *self, void *closure) {
const Descriptor* containing_type =
_GetDescriptor(self)->containing_type();
if (containing_type) {
- return PyMessageDescriptor_New(containing_type);
+ return PyMessageDescriptor_FromDescriptor(containing_type);
} else {
Py_RETURN_NONE;
}
@@ -964,16 +964,19 @@ static PyMethodDef Methods[] = {
};
static PyGetSetDef Getters[] = {
- { C("full_name"), (getter)GetFullName, NULL, "Full name", NULL},
- { C("name"), (getter)GetName, NULL, "last name", NULL},
- { C("file"), (getter)GetFile, NULL, "File descriptor", NULL},
- { C("values"), (getter)GetEnumvaluesSeq, NULL, "values", NULL},
- { C("values_by_name"), (getter)GetEnumvaluesByName, NULL, "Enumvalues by name", NULL},
- { C("values_by_number"), (getter)GetEnumvaluesByNumber, NULL, "Enumvalues by number", NULL},
-
- { C("containing_type"), (getter)GetContainingType, (setter)SetContainingType, "Containing type", NULL},
- { C("has_options"), (getter)GetHasOptions, (setter)SetHasOptions, "Has Options", NULL},
- { C("_options"), (getter)NULL, (setter)SetOptions, "Options", NULL},
+ { "full_name", (getter)GetFullName, NULL, "Full name"},
+ { "name", (getter)GetName, NULL, "last name"},
+ { "file", (getter)GetFile, NULL, "File descriptor"},
+ { "values", (getter)GetEnumvaluesSeq, NULL, "values"},
+ { "values_by_name", (getter)GetEnumvaluesByName, NULL,
+ "Enum values by name"},
+ { "values_by_number", (getter)GetEnumvaluesByNumber, NULL,
+ "Enum values by number"},
+
+ { "containing_type", (getter)GetContainingType, (setter)SetContainingType,
+ "Containing type"},
+ { "has_options", (getter)GetHasOptions, (setter)SetHasOptions, "Has Options"},
+ { "_options", (getter)NULL, (setter)SetOptions, "Options"},
{NULL}
};
@@ -981,9 +984,7 @@ static PyGetSetDef Getters[] = {
PyTypeObject PyEnumDescriptor_Type = {
PyVarObject_HEAD_INIT(&PyType_Type, 0)
- // Keep the fully qualified _message symbol in a line for opensource.
- C("google.protobuf.internal._message."
- "EnumDescriptor"), // tp_name
+ FULL_MODULE_NAME ".EnumDescriptor", // tp_name
sizeof(PyBaseDescriptor), // tp_basicsize
0, // tp_itemsize
0, // tp_dealloc
@@ -1002,7 +1003,7 @@ PyTypeObject PyEnumDescriptor_Type = {
0, // tp_setattro
0, // tp_as_buffer
Py_TPFLAGS_DEFAULT, // tp_flags
- C("A Enum Descriptor"), // tp_doc
+ "A Enum Descriptor", // tp_doc
0, // tp_traverse
0, // tp_clear
0, // tp_richcompare
@@ -1015,10 +1016,10 @@ PyTypeObject PyEnumDescriptor_Type = {
&descriptor::PyBaseDescriptor_Type, // tp_base
};
-PyObject* PyEnumDescriptor_New(
+PyObject* PyEnumDescriptor_FromDescriptor(
const EnumDescriptor* enum_descriptor) {
return descriptor::NewInternedDescriptor(
- &PyEnumDescriptor_Type, enum_descriptor);
+ &PyEnumDescriptor_Type, enum_descriptor, NULL);
}
namespace enumvalue_descriptor {
@@ -1042,7 +1043,7 @@ static PyObject* GetIndex(PyBaseDescriptor *self, void *closure) {
}
static PyObject* GetType(PyBaseDescriptor *self, void *closure) {
- return PyEnumDescriptor_New(_GetDescriptor(self)->type());
+ return PyEnumDescriptor_FromDescriptor(_GetDescriptor(self)->type());
}
static PyObject* GetHasOptions(PyBaseDescriptor *self, void *closure) {
@@ -1069,13 +1070,13 @@ static int SetOptions(PyBaseDescriptor *self, PyObject *value,
static PyGetSetDef Getters[] = {
- { C("name"), (getter)GetName, NULL, "name", NULL},
- { C("number"), (getter)GetNumber, NULL, "number", NULL},
- { C("index"), (getter)GetIndex, NULL, "index", NULL},
- { C("type"), (getter)GetType, NULL, "index", NULL},
+ { "name", (getter)GetName, NULL, "name"},
+ { "number", (getter)GetNumber, NULL, "number"},
+ { "index", (getter)GetIndex, NULL, "index"},
+  { "type", (getter)GetType, NULL, "type"},
- { C("has_options"), (getter)GetHasOptions, (setter)SetHasOptions, "Has Options", NULL},
- { C("_options"), (getter)NULL, (setter)SetOptions, "Options", NULL},
+ { "has_options", (getter)GetHasOptions, (setter)SetHasOptions, "Has Options"},
+ { "_options", (getter)NULL, (setter)SetOptions, "Options"},
{NULL}
};
@@ -1088,8 +1089,7 @@ static PyMethodDef Methods[] = {
PyTypeObject PyEnumValueDescriptor_Type = {
PyVarObject_HEAD_INIT(&PyType_Type, 0)
- C("google.protobuf.internal."
- "_message.EnumValueDescriptor"), // tp_name
+ FULL_MODULE_NAME ".EnumValueDescriptor", // tp_name
sizeof(PyBaseDescriptor), // tp_basicsize
0, // tp_itemsize
0, // tp_dealloc
@@ -1108,7 +1108,7 @@ PyTypeObject PyEnumValueDescriptor_Type = {
0, // tp_setattro
0, // tp_as_buffer
Py_TPFLAGS_DEFAULT, // tp_flags
- C("A EnumValue Descriptor"), // tp_doc
+ "A EnumValue Descriptor", // tp_doc
0, // tp_traverse
0, // tp_clear
0, // tp_richcompare
@@ -1121,10 +1121,10 @@ PyTypeObject PyEnumValueDescriptor_Type = {
&descriptor::PyBaseDescriptor_Type, // tp_base
};
-PyObject* PyEnumValueDescriptor_New(
+PyObject* PyEnumValueDescriptor_FromDescriptor(
const EnumValueDescriptor* enumvalue_descriptor) {
return descriptor::NewInternedDescriptor(
- &PyEnumValueDescriptor_Type, enumvalue_descriptor);
+ &PyEnumValueDescriptor_Type, enumvalue_descriptor, NULL);
}
namespace file_descriptor {
@@ -1218,18 +1218,20 @@ static PyObject* CopyToProto(PyFileDescriptor *self, PyObject *target) {
}
static PyGetSetDef Getters[] = {
- { C("name"), (getter)GetName, NULL, "name", NULL},
- { C("package"), (getter)GetPackage, NULL, "package", NULL},
- { C("serialized_pb"), (getter)GetSerializedPb, NULL, NULL, NULL},
- { C("message_types_by_name"), (getter)GetMessageTypesByName, NULL, "Messages by name", NULL},
- { C("enum_types_by_name"), (getter)GetEnumTypesByName, NULL, "Enums by name", NULL},
- { C("extensions_by_name"), (getter)GetExtensionsByName, NULL, "Extensions by name", NULL},
- { C("dependencies"), (getter)GetDependencies, NULL, "Dependencies", NULL},
- { C("public_dependencies"), (getter)GetPublicDependencies, NULL, "Dependencies", NULL},
-
- { C("has_options"), (getter)GetHasOptions, (setter)SetHasOptions, "Has Options", NULL},
- { C("_options"), (getter)NULL, (setter)SetOptions, "Options", NULL},
- { C("syntax"), (getter)GetSyntax, (setter)NULL, "Syntax", NULL},
+ { "name", (getter)GetName, NULL, "name"},
+ { "package", (getter)GetPackage, NULL, "package"},
+ { "serialized_pb", (getter)GetSerializedPb},
+ { "message_types_by_name", (getter)GetMessageTypesByName, NULL,
+ "Messages by name"},
+ { "enum_types_by_name", (getter)GetEnumTypesByName, NULL, "Enums by name"},
+ { "extensions_by_name", (getter)GetExtensionsByName, NULL,
+ "Extensions by name"},
+ { "dependencies", (getter)GetDependencies, NULL, "Dependencies"},
+ { "public_dependencies", (getter)GetPublicDependencies, NULL, "Dependencies"},
+
+ { "has_options", (getter)GetHasOptions, (setter)SetHasOptions, "Has Options"},
+ { "_options", (getter)NULL, (setter)SetOptions, "Options"},
+ { "syntax", (getter)GetSyntax, (setter)NULL, "Syntax"},
{NULL}
};
@@ -1243,11 +1245,10 @@ static PyMethodDef Methods[] = {
PyTypeObject PyFileDescriptor_Type = {
PyVarObject_HEAD_INIT(&PyType_Type, 0)
- C("google.protobuf.internal."
- "_message.FileDescriptor"), // tp_name
+ FULL_MODULE_NAME ".FileDescriptor", // tp_name
sizeof(PyFileDescriptor), // tp_basicsize
0, // tp_itemsize
- (destructor)file_descriptor::Dealloc, // tp_dealloc
+ (destructor)file_descriptor::Dealloc, // tp_dealloc
0, // tp_print
0, // tp_getattr
0, // tp_setattr
@@ -1263,7 +1264,7 @@ PyTypeObject PyFileDescriptor_Type = {
0, // tp_setattro
0, // tp_as_buffer
Py_TPFLAGS_DEFAULT, // tp_flags
- C("A File Descriptor"), // tp_doc
+ "A File Descriptor", // tp_doc
0, // tp_traverse
0, // tp_clear
0, // tp_richcompare
@@ -1284,23 +1285,28 @@ PyTypeObject PyFileDescriptor_Type = {
PyObject_Del, // tp_free
};
-PyObject* PyFileDescriptor_New(const FileDescriptor* file_descriptor) {
- return descriptor::NewInternedDescriptor(
- &PyFileDescriptor_Type, file_descriptor);
+PyObject* PyFileDescriptor_FromDescriptor(
+ const FileDescriptor* file_descriptor) {
+ return PyFileDescriptor_FromDescriptorWithSerializedPb(file_descriptor,
+ NULL);
}
-PyObject* PyFileDescriptor_NewWithPb(
+PyObject* PyFileDescriptor_FromDescriptorWithSerializedPb(
const FileDescriptor* file_descriptor, PyObject *serialized_pb) {
- PyObject* py_descriptor = PyFileDescriptor_New(file_descriptor);
+ bool was_created;
+ PyObject* py_descriptor = descriptor::NewInternedDescriptor(
+ &PyFileDescriptor_Type, file_descriptor, &was_created);
if (py_descriptor == NULL) {
return NULL;
}
- if (serialized_pb != NULL) {
+ if (was_created) {
PyFileDescriptor* cfile_descriptor =
reinterpret_cast<PyFileDescriptor*>(py_descriptor);
Py_XINCREF(serialized_pb);
cfile_descriptor->serialized_pb = serialized_pb;
}
+ // TODO(amauryfa): In the case of a cached object, check that serialized_pb
+ // is the same as before.
return py_descriptor;
}
@@ -1333,19 +1339,19 @@ static PyObject* GetContainingType(PyBaseDescriptor *self, void *closure) {
const Descriptor* containing_type =
_GetDescriptor(self)->containing_type();
if (containing_type) {
- return PyMessageDescriptor_New(containing_type);
+ return PyMessageDescriptor_FromDescriptor(containing_type);
} else {
Py_RETURN_NONE;
}
}
static PyGetSetDef Getters[] = {
- { C("name"), (getter)GetName, NULL, "Name", NULL},
- { C("full_name"), (getter)GetFullName, NULL, "Full name", NULL},
- { C("index"), (getter)GetIndex, NULL, "Index", NULL},
+ { "name", (getter)GetName, NULL, "Name"},
+ { "full_name", (getter)GetFullName, NULL, "Full name"},
+ { "index", (getter)GetIndex, NULL, "Index"},
- { C("containing_type"), (getter)GetContainingType, NULL, "Containing type", NULL},
- { C("fields"), (getter)GetFields, NULL, "Fields", NULL},
+ { "containing_type", (getter)GetContainingType, NULL, "Containing type"},
+ { "fields", (getter)GetFields, NULL, "Fields"},
{NULL}
};
@@ -1353,8 +1359,7 @@ static PyGetSetDef Getters[] = {
PyTypeObject PyOneofDescriptor_Type = {
PyVarObject_HEAD_INIT(&PyType_Type, 0)
- C("google.protobuf.internal."
- "_message.OneofDescriptor"), // tp_name
+ FULL_MODULE_NAME ".OneofDescriptor", // tp_name
sizeof(PyBaseDescriptor), // tp_basicsize
0, // tp_itemsize
0, // tp_dealloc
@@ -1373,7 +1378,7 @@ PyTypeObject PyOneofDescriptor_Type = {
0, // tp_setattro
0, // tp_as_buffer
Py_TPFLAGS_DEFAULT, // tp_flags
- C("A Oneof Descriptor"), // tp_doc
+ "A Oneof Descriptor", // tp_doc
0, // tp_traverse
0, // tp_clear
0, // tp_richcompare
@@ -1386,10 +1391,10 @@ PyTypeObject PyOneofDescriptor_Type = {
&descriptor::PyBaseDescriptor_Type, // tp_base
};
-PyObject* PyOneofDescriptor_New(
+PyObject* PyOneofDescriptor_FromDescriptor(
const OneofDescriptor* oneof_descriptor) {
return descriptor::NewInternedDescriptor(
- &PyOneofDescriptor_Type, oneof_descriptor);
+ &PyOneofDescriptor_Type, oneof_descriptor, NULL);
}
// Add a enum values to a type dictionary.
diff --git a/python/google/protobuf/pyext/descriptor.h b/python/google/protobuf/pyext/descriptor.h
index ba6e7298..b2550406 100644
--- a/python/google/protobuf/pyext/descriptor.h
+++ b/python/google/protobuf/pyext/descriptor.h
@@ -48,21 +48,24 @@ extern PyTypeObject PyEnumValueDescriptor_Type;
extern PyTypeObject PyFileDescriptor_Type;
extern PyTypeObject PyOneofDescriptor_Type;
-// Return a new reference to a Descriptor object.
+// Wraps a Descriptor in a Python object.
// The C++ pointer is usually borrowed from the global DescriptorPool.
// In any case, it must stay alive as long as the Python object.
-PyObject* PyMessageDescriptor_New(const Descriptor* descriptor);
-PyObject* PyFieldDescriptor_New(const FieldDescriptor* descriptor);
-PyObject* PyEnumDescriptor_New(const EnumDescriptor* descriptor);
-PyObject* PyEnumValueDescriptor_New(const EnumValueDescriptor* descriptor);
-PyObject* PyOneofDescriptor_New(const OneofDescriptor* descriptor);
-PyObject* PyFileDescriptor_New(const FileDescriptor* file_descriptor);
+// Returns a new reference.
+PyObject* PyMessageDescriptor_FromDescriptor(const Descriptor* descriptor);
+PyObject* PyFieldDescriptor_FromDescriptor(const FieldDescriptor* descriptor);
+PyObject* PyEnumDescriptor_FromDescriptor(const EnumDescriptor* descriptor);
+PyObject* PyEnumValueDescriptor_FromDescriptor(
+ const EnumValueDescriptor* descriptor);
+PyObject* PyOneofDescriptor_FromDescriptor(const OneofDescriptor* descriptor);
+PyObject* PyFileDescriptor_FromDescriptor(
+ const FileDescriptor* file_descriptor);
// Alternate constructor of PyFileDescriptor, used when we already have a
// serialized FileDescriptorProto that can be cached.
// Returns a new reference.
-PyObject* PyFileDescriptor_NewWithPb(const FileDescriptor* file_descriptor,
- PyObject* serialized_pb);
+PyObject* PyFileDescriptor_FromDescriptorWithSerializedPb(
+ const FileDescriptor* file_descriptor, PyObject* serialized_pb);
// Return the C++ descriptor pointer.
// This function checks the parameter type; on error, return NULL with a Python
diff --git a/python/google/protobuf/pyext/descriptor_containers.cc b/python/google/protobuf/pyext/descriptor_containers.cc
index 06edebf8..92e11e31 100644
--- a/python/google/protobuf/pyext/descriptor_containers.cc
+++ b/python/google/protobuf/pyext/descriptor_containers.cc
@@ -898,7 +898,7 @@ static ItemDescriptor GetByIndex(PyContainer* self, int index) {
}
static PyObject* NewObjectFromItem(ItemDescriptor item) {
- return PyFieldDescriptor_New(item);
+ return PyFieldDescriptor_FromDescriptor(item);
}
static const string& GetItemName(ItemDescriptor item) {
@@ -956,7 +956,7 @@ static ItemDescriptor GetByIndex(PyContainer* self, int index) {
}
static PyObject* NewObjectFromItem(ItemDescriptor item) {
- return PyMessageDescriptor_New(item);
+ return PyMessageDescriptor_FromDescriptor(item);
}
static const string& GetItemName(ItemDescriptor item) {
@@ -1006,7 +1006,7 @@ static ItemDescriptor GetByIndex(PyContainer* self, int index) {
}
static PyObject* NewObjectFromItem(ItemDescriptor item) {
- return PyEnumDescriptor_New(item);
+ return PyEnumDescriptor_FromDescriptor(item);
}
static const string& GetItemName(ItemDescriptor item) {
@@ -1082,7 +1082,7 @@ static ItemDescriptor GetByIndex(PyContainer* self, int index) {
}
static PyObject* NewObjectFromItem(ItemDescriptor item) {
- return PyEnumValueDescriptor_New(item);
+ return PyEnumValueDescriptor_FromDescriptor(item);
}
static const string& GetItemName(ItemDescriptor item) {
@@ -1124,7 +1124,7 @@ static ItemDescriptor GetByIndex(PyContainer* self, int index) {
}
static PyObject* NewObjectFromItem(ItemDescriptor item) {
- return PyFieldDescriptor_New(item);
+ return PyFieldDescriptor_FromDescriptor(item);
}
static const string& GetItemName(ItemDescriptor item) {
@@ -1174,7 +1174,7 @@ static ItemDescriptor GetByIndex(PyContainer* self, int index) {
}
static PyObject* NewObjectFromItem(ItemDescriptor item) {
- return PyOneofDescriptor_New(item);
+ return PyOneofDescriptor_FromDescriptor(item);
}
static const string& GetItemName(ItemDescriptor item) {
@@ -1238,7 +1238,7 @@ static ItemDescriptor GetByNumber(PyContainer* self, int number) {
}
static PyObject* NewObjectFromItem(ItemDescriptor item) {
- return PyEnumValueDescriptor_New(item);
+ return PyEnumValueDescriptor_FromDescriptor(item);
}
static const string& GetItemName(ItemDescriptor item) {
@@ -1302,7 +1302,7 @@ static ItemDescriptor GetByIndex(PyContainer* self, int index) {
}
static PyObject* NewObjectFromItem(ItemDescriptor item) {
- return PyFieldDescriptor_New(item);
+ return PyFieldDescriptor_FromDescriptor(item);
}
static int GetItemIndex(ItemDescriptor item) {
@@ -1354,7 +1354,7 @@ static ItemDescriptor GetByIndex(PyContainer* self, int index) {
}
static PyObject* NewObjectFromItem(ItemDescriptor item) {
- return PyMessageDescriptor_New(item);
+ return PyMessageDescriptor_FromDescriptor(item);
}
static const string& GetItemName(ItemDescriptor item) {
@@ -1400,7 +1400,7 @@ static ItemDescriptor GetByIndex(PyContainer* self, int index) {
}
static PyObject* NewObjectFromItem(ItemDescriptor item) {
- return PyEnumDescriptor_New(item);
+ return PyEnumDescriptor_FromDescriptor(item);
}
static const string& GetItemName(ItemDescriptor item) {
@@ -1446,7 +1446,7 @@ static ItemDescriptor GetByIndex(PyContainer* self, int index) {
}
static PyObject* NewObjectFromItem(ItemDescriptor item) {
- return PyFieldDescriptor_New(item);
+ return PyFieldDescriptor_FromDescriptor(item);
}
static const string& GetItemName(ItemDescriptor item) {
@@ -1488,7 +1488,7 @@ static ItemDescriptor GetByIndex(PyContainer* self, int index) {
}
static PyObject* NewObjectFromItem(ItemDescriptor item) {
- return PyFileDescriptor_New(item);
+ return PyFileDescriptor_FromDescriptor(item);
}
static DescriptorContainerDef ContainerDef = {
@@ -1522,7 +1522,7 @@ static ItemDescriptor GetByIndex(PyContainer* self, int index) {
}
static PyObject* NewObjectFromItem(ItemDescriptor item) {
- return PyFileDescriptor_New(item);
+ return PyFileDescriptor_FromDescriptor(item);
}
static DescriptorContainerDef ContainerDef = {
diff --git a/python/google/protobuf/pyext/descriptor_containers.h b/python/google/protobuf/pyext/descriptor_containers.h
index d81537de..8fbdaff9 100644
--- a/python/google/protobuf/pyext/descriptor_containers.h
+++ b/python/google/protobuf/pyext/descriptor_containers.h
@@ -28,6 +28,9 @@
// (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
// OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
+#ifndef GOOGLE_PROTOBUF_PYTHON_CPP_DESCRIPTOR_CONTAINERS_H__
+#define GOOGLE_PROTOBUF_PYTHON_CPP_DESCRIPTOR_CONTAINERS_H__
+
// Mappings and Sequences of descriptors.
// They implement containers like fields_by_name, EnumDescriptor.values...
// See descriptor_containers.cc for more description.
@@ -92,4 +95,6 @@ PyObject* NewFilePublicDependencies(const FileDescriptor* descriptor);
} // namespace python
} // namespace protobuf
+
} // namespace google
+#endif // GOOGLE_PROTOBUF_PYTHON_CPP_DESCRIPTOR_CONTAINERS_H__
diff --git a/python/google/protobuf/pyext/descriptor_pool.cc b/python/google/protobuf/pyext/descriptor_pool.cc
index bc3077bc..ecd90847 100644
--- a/python/google/protobuf/pyext/descriptor_pool.cc
+++ b/python/google/protobuf/pyext/descriptor_pool.cc
@@ -35,10 +35,9 @@
#include <google/protobuf/descriptor.pb.h>
#include <google/protobuf/pyext/descriptor_pool.h>
#include <google/protobuf/pyext/descriptor.h>
+#include <google/protobuf/pyext/message.h>
#include <google/protobuf/pyext/scoped_pyobject_ptr.h>
-#define C(str) const_cast<char*>(str)
-
#if PY_MAJOR_VERSION >= 3
#define PyString_FromStringAndSize PyUnicode_FromStringAndSize
#if PY_VERSION_HEX < 0x03030000
@@ -108,11 +107,11 @@ PyObject* FindMessageByName(PyDescriptorPool* self, PyObject* arg) {
self->pool->FindMessageTypeByName(string(name, name_size));
if (message_descriptor == NULL) {
- PyErr_Format(PyExc_TypeError, "Couldn't find message %.200s", name);
+ PyErr_Format(PyExc_KeyError, "Couldn't find message %.200s", name);
return NULL;
}
- return PyMessageDescriptor_New(message_descriptor);
+ return PyMessageDescriptor_FromDescriptor(message_descriptor);
}
// Add a message class to our database.
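
The same TypeError-to-KeyError change is applied to every Find*ByName lookup below, so a failed name lookup in the C++-backed pool now behaves like a failed dict lookup (passing a non-string argument still raises TypeError, which is what the updated reflection_test asserts). A sketch against the C++ implementation's default pool:

    from google.protobuf.pyext import _message  # C++ implementation only

    try:
        _message.default_pool.FindFieldByName('no.such.field')
    except KeyError as err:
        print(err)  # KeyError: "Couldn't find field no.such.field"
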
@@ -158,6 +157,24 @@ PyObject *GetMessageClass(PyDescriptorPool* self,
}
}
+PyObject* FindFileByName(PyDescriptorPool* self, PyObject* arg) {
+ Py_ssize_t name_size;
+ char* name;
+ if (PyString_AsStringAndSize(arg, &name, &name_size) < 0) {
+ return NULL;
+ }
+
+ const FileDescriptor* file_descriptor =
+ self->pool->FindFileByName(string(name, name_size));
+ if (file_descriptor == NULL) {
+ PyErr_Format(PyExc_KeyError, "Couldn't find file %.200s",
+ name);
+ return NULL;
+ }
+
+ return PyFileDescriptor_FromDescriptor(file_descriptor);
+}
+
PyObject* FindFieldByName(PyDescriptorPool* self, PyObject* arg) {
Py_ssize_t name_size;
char* name;
@@ -168,12 +185,12 @@ PyObject* FindFieldByName(PyDescriptorPool* self, PyObject* arg) {
const FieldDescriptor* field_descriptor =
self->pool->FindFieldByName(string(name, name_size));
if (field_descriptor == NULL) {
- PyErr_Format(PyExc_TypeError, "Couldn't find field %.200s",
+ PyErr_Format(PyExc_KeyError, "Couldn't find field %.200s",
name);
return NULL;
}
- return PyFieldDescriptor_New(field_descriptor);
+ return PyFieldDescriptor_FromDescriptor(field_descriptor);
}
PyObject* FindExtensionByName(PyDescriptorPool* self, PyObject* arg) {
@@ -186,11 +203,11 @@ PyObject* FindExtensionByName(PyDescriptorPool* self, PyObject* arg) {
const FieldDescriptor* field_descriptor =
self->pool->FindExtensionByName(string(name, name_size));
if (field_descriptor == NULL) {
- PyErr_Format(PyExc_TypeError, "Couldn't find field %.200s", name);
+ PyErr_Format(PyExc_KeyError, "Couldn't find extension field %.200s", name);
return NULL;
}
- return PyFieldDescriptor_New(field_descriptor);
+ return PyFieldDescriptor_FromDescriptor(field_descriptor);
}
PyObject* FindEnumTypeByName(PyDescriptorPool* self, PyObject* arg) {
@@ -203,11 +220,11 @@ PyObject* FindEnumTypeByName(PyDescriptorPool* self, PyObject* arg) {
const EnumDescriptor* enum_descriptor =
self->pool->FindEnumTypeByName(string(name, name_size));
if (enum_descriptor == NULL) {
- PyErr_Format(PyExc_TypeError, "Couldn't find enum %.200s", name);
+ PyErr_Format(PyExc_KeyError, "Couldn't find enum %.200s", name);
return NULL;
}
- return PyEnumDescriptor_New(enum_descriptor);
+ return PyEnumDescriptor_FromDescriptor(enum_descriptor);
}
PyObject* FindOneofByName(PyDescriptorPool* self, PyObject* arg) {
@@ -220,70 +237,13 @@ PyObject* FindOneofByName(PyDescriptorPool* self, PyObject* arg) {
const OneofDescriptor* oneof_descriptor =
self->pool->FindOneofByName(string(name, name_size));
if (oneof_descriptor == NULL) {
- PyErr_Format(PyExc_TypeError, "Couldn't find oneof %.200s", name);
+ PyErr_Format(PyExc_KeyError, "Couldn't find oneof %.200s", name);
return NULL;
}
- return PyOneofDescriptor_New(oneof_descriptor);
+ return PyOneofDescriptor_FromDescriptor(oneof_descriptor);
}
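All of the lookup methods above now raise KeyError (rather than TypeError) when a name is unknown, and hand descriptors back through the *_FromDescriptor constructors. A minimal Python sketch of the new failure mode, assuming the C++ extension module is built (the message name is deliberately bogus):

    from google.protobuf.pyext import _message

    pool = _message.default_pool
    try:
        pool.FindMessageTypeByName('no.such.Message')
    except KeyError as err:
        print('lookup failed:', err)
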
-static PyMethodDef Methods[] = {
- { C("FindFieldByName"),
- (PyCFunction)FindFieldByName,
- METH_O,
- C("Searches for a field descriptor by full name.") },
- { C("FindExtensionByName"),
- (PyCFunction)FindExtensionByName,
- METH_O,
- C("Searches for extension descriptor by full name.") },
- {NULL}
-};
-
-} // namespace cdescriptor_pool
-
-PyTypeObject PyDescriptorPool_Type = {
- PyVarObject_HEAD_INIT(&PyType_Type, 0)
- C("google.protobuf.internal."
- "_message.DescriptorPool"), // tp_name
- sizeof(PyDescriptorPool), // tp_basicsize
- 0, // tp_itemsize
- (destructor)cdescriptor_pool::Dealloc, // tp_dealloc
- 0, // tp_print
- 0, // tp_getattr
- 0, // tp_setattr
- 0, // tp_compare
- 0, // tp_repr
- 0, // tp_as_number
- 0, // tp_as_sequence
- 0, // tp_as_mapping
- 0, // tp_hash
- 0, // tp_call
- 0, // tp_str
- 0, // tp_getattro
- 0, // tp_setattro
- 0, // tp_as_buffer
- Py_TPFLAGS_DEFAULT, // tp_flags
- C("A Descriptor Pool"), // tp_doc
- 0, // tp_traverse
- 0, // tp_clear
- 0, // tp_richcompare
- 0, // tp_weaklistoffset
- 0, // tp_iter
- 0, // tp_iternext
- cdescriptor_pool::Methods, // tp_methods
- 0, // tp_members
- 0, // tp_getset
- 0, // tp_base
- 0, // tp_dict
- 0, // tp_descr_get
- 0, // tp_descr_set
- 0, // tp_dictoffset
- 0, // tp_init
- 0, // tp_alloc
- 0, // tp_new
- PyObject_Del, // tp_free
-};
-
// The code below loads new Descriptors from a serialized FileDescriptorProto.
@@ -301,6 +261,7 @@ class BuildFileErrorCollector : public DescriptorPool::ErrorCollector {
if (!had_errors) {
error_message +=
("Invalid proto descriptor for file \"" + filename + "\":\n");
+ had_errors = true;
}
// As this only happens on failure and will result in the program not
// running at all, no effort is made to optimize this string manipulation.
@@ -311,7 +272,7 @@ class BuildFileErrorCollector : public DescriptorPool::ErrorCollector {
bool had_errors;
};
-PyObject* Python_BuildFile(PyObject* ignored, PyObject* serialized_pb) {
+PyObject* AddSerializedFile(PyDescriptorPool* self, PyObject* serialized_pb) {
char* message_type;
Py_ssize_t message_len;
@@ -330,13 +291,14 @@ PyObject* Python_BuildFile(PyObject* ignored, PyObject* serialized_pb) {
const FileDescriptor* generated_file =
DescriptorPool::generated_pool()->FindFileByName(file_proto.name());
if (generated_file != NULL) {
- return PyFileDescriptor_NewWithPb(generated_file, serialized_pb);
+ return PyFileDescriptor_FromDescriptorWithSerializedPb(
+ generated_file, serialized_pb);
}
BuildFileErrorCollector error_collector;
const FileDescriptor* descriptor =
- GetDescriptorPool()->pool->BuildFileCollectingErrors(file_proto,
- &error_collector);
+ self->pool->BuildFileCollectingErrors(file_proto,
+ &error_collector);
if (descriptor == NULL) {
PyErr_Format(PyExc_TypeError,
"Couldn't build proto file into descriptor pool!\n%s",
@@ -344,9 +306,84 @@ PyObject* Python_BuildFile(PyObject* ignored, PyObject* serialized_pb) {
return NULL;
}
- return PyFileDescriptor_NewWithPb(descriptor, serialized_pb);
+ return PyFileDescriptor_FromDescriptorWithSerializedPb(
+ descriptor, serialized_pb);
+}
+
+PyObject* Add(PyDescriptorPool* self, PyObject* file_descriptor_proto) {
+ ScopedPyObjectPtr serialized_pb(
+ PyObject_CallMethod(file_descriptor_proto, "SerializeToString", NULL));
+ if (serialized_pb == NULL) {
+ return NULL;
+ }
+ return AddSerializedFile(self, serialized_pb);
}
+static PyMethodDef Methods[] = {
+ { "Add", (PyCFunction)Add, METH_O,
+ "Adds the FileDescriptorProto and its types to this pool." },
+ { "AddSerializedFile", (PyCFunction)AddSerializedFile, METH_O,
+ "Adds a serialized FileDescriptorProto to this pool." },
+
+ { "FindFileByName", (PyCFunction)FindFileByName, METH_O,
+ "Searches for a file descriptor by its .proto name." },
+ { "FindMessageTypeByName", (PyCFunction)FindMessageByName, METH_O,
+ "Searches for a message descriptor by full name." },
+ { "FindFieldByName", (PyCFunction)FindFieldByName, METH_O,
+ "Searches for a field descriptor by full name." },
+ { "FindExtensionByName", (PyCFunction)FindExtensionByName, METH_O,
+ "Searches for extension descriptor by full name." },
+ { "FindEnumTypeByName", (PyCFunction)FindEnumTypeByName, METH_O,
+ "Searches for enum type descriptor by full name." },
+ { "FindOneofByName", (PyCFunction)FindOneofByName, METH_O,
+ "Searches for oneof descriptor by full name." },
+ {NULL}
+};
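Add() simply serializes the FileDescriptorProto and feeds it to AddSerializedFile(), which builds the file through BuildFileCollectingErrors() above. A hedged sketch of driving these methods from Python; the file, package and message names are made up:

    from google.protobuf import descriptor_pb2
    from google.protobuf.pyext import _message

    file_proto = descriptor_pb2.FileDescriptorProto()
    file_proto.name = 'example/pair.proto'      # hypothetical
    file_proto.package = 'example'
    message_proto = file_proto.message_type.add()
    message_proto.name = 'Pair'
    field_proto = message_proto.field.add()
    field_proto.name = 'key'
    field_proto.number = 1
    field_proto.type = descriptor_pb2.FieldDescriptorProto.TYPE_STRING
    field_proto.label = descriptor_pb2.FieldDescriptorProto.LABEL_OPTIONAL

    pool = _message.default_pool
    pool.Add(file_proto)   # same effect: pool.AddSerializedFile(file_proto.SerializeToString())
    print(pool.FindMessageTypeByName('example.Pair').full_name)
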
+
+} // namespace cdescriptor_pool
+
+PyTypeObject PyDescriptorPool_Type = {
+ PyVarObject_HEAD_INIT(&PyType_Type, 0)
+ FULL_MODULE_NAME ".DescriptorPool", // tp_name
+ sizeof(PyDescriptorPool), // tp_basicsize
+ 0, // tp_itemsize
+ (destructor)cdescriptor_pool::Dealloc, // tp_dealloc
+ 0, // tp_print
+ 0, // tp_getattr
+ 0, // tp_setattr
+ 0, // tp_compare
+ 0, // tp_repr
+ 0, // tp_as_number
+ 0, // tp_as_sequence
+ 0, // tp_as_mapping
+ 0, // tp_hash
+ 0, // tp_call
+ 0, // tp_str
+ 0, // tp_getattro
+ 0, // tp_setattro
+ 0, // tp_as_buffer
+ Py_TPFLAGS_DEFAULT, // tp_flags
+ "A Descriptor Pool", // tp_doc
+ 0, // tp_traverse
+ 0, // tp_clear
+ 0, // tp_richcompare
+ 0, // tp_weaklistoffset
+ 0, // tp_iter
+ 0, // tp_iternext
+ cdescriptor_pool::Methods, // tp_methods
+ 0, // tp_members
+ 0, // tp_getset
+ 0, // tp_base
+ 0, // tp_dict
+ 0, // tp_descr_get
+ 0, // tp_descr_set
+ 0, // tp_dictoffset
+ 0, // tp_init
+ 0, // tp_alloc
+ 0, // tp_new
+ PyObject_Del, // tp_free
+};
+
static PyDescriptorPool* global_cdescriptor_pool = NULL;
bool InitDescriptorPool() {
diff --git a/python/google/protobuf/pyext/descriptor_pool.h b/python/google/protobuf/pyext/descriptor_pool.h
index 4e494b89..efb1abeb 100644
--- a/python/google/protobuf/pyext/descriptor_pool.h
+++ b/python/google/protobuf/pyext/descriptor_pool.h
@@ -95,6 +95,8 @@ const Descriptor* FindMessageTypeByName(PyDescriptorPool* self,
const Descriptor* RegisterMessageClass(
PyDescriptorPool* self, PyObject* message_class, PyObject* descriptor);
+// The functions below are also exposed as methods of the DescriptorPool type.
+
// Retrieves the Python class registered with the given message descriptor.
//
// Returns a *borrowed* reference if found, otherwise returns NULL with an
@@ -134,12 +136,8 @@ PyObject* FindOneofByName(PyDescriptorPool* self, PyObject* arg);
} // namespace cdescriptor_pool
-// Implement the Python "_BuildFile" method, it takes a serialized
-// FileDescriptorProto, and adds it to the C++ DescriptorPool.
-// It returns a new FileDescriptor object, or NULL when an exception is raised.
-PyObject* Python_BuildFile(PyObject* ignored, PyObject* args);
-
// Retrieve the global descriptor pool owned by the _message module.
+// Returns a *borrowed* reference.
PyDescriptorPool* GetDescriptorPool();
// Initialize objects used by this module.
diff --git a/python/google/protobuf/pyext/extension_dict.cc b/python/google/protobuf/pyext/extension_dict.cc
index 8e38fc42..b8d18f8d 100644
--- a/python/google/protobuf/pyext/extension_dict.cc
+++ b/python/google/protobuf/pyext/extension_dict.cc
@@ -51,16 +51,6 @@ namespace python {
namespace extension_dict {
-// TODO(tibell): Always use self->message for clarity, just like in
-// RepeatedCompositeContainer.
-static Message* GetMessage(ExtensionDict* self) {
- if (self->parent != NULL) {
- return self->parent->message;
- } else {
- return self->message;
- }
-}
-
PyObject* len(ExtensionDict* self) {
#if PY_MAJOR_VERSION >= 3
return PyLong_FromLong(PyDict_Size(self->values));
@@ -89,7 +79,7 @@ int ReleaseExtension(ExtensionDict* self,
}
} else if (descriptor->cpp_type() == FieldDescriptor::CPPTYPE_MESSAGE) {
if (cmessage::ReleaseSubMessage(
- GetMessage(self), descriptor,
+ self->parent, descriptor,
reinterpret_cast<CMessage*>(extension)) < 0) {
return -1;
}
@@ -109,7 +99,7 @@ PyObject* subscript(ExtensionDict* self, PyObject* key) {
if (descriptor->label() != FieldDescriptor::LABEL_REPEATED &&
descriptor->cpp_type() != FieldDescriptor::CPPTYPE_MESSAGE) {
- return cmessage::InternalGetScalar(self->parent, descriptor);
+ return cmessage::InternalGetScalar(self->parent->message, descriptor);
}
PyObject* value = PyDict_GetItem(self->values, key);
@@ -266,8 +256,7 @@ static PyMethodDef Methods[] = {
PyTypeObject ExtensionDict_Type = {
PyVarObject_HEAD_INIT(&PyType_Type, 0)
- "google.protobuf.internal."
- "cpp._message.ExtensionDict", // tp_name
+ FULL_MODULE_NAME ".ExtensionDict", // tp_name
sizeof(ExtensionDict), // tp_basicsize
0, // tp_itemsize
(destructor)extension_dict::dealloc, // tp_dealloc
diff --git a/python/google/protobuf/pyext/message.cc b/python/google/protobuf/pyext/message.cc
index a2b357b2..a4843e8d 100644
--- a/python/google/protobuf/pyext/message.cc
+++ b/python/google/protobuf/pyext/message.cc
@@ -59,6 +59,8 @@
#include <google/protobuf/pyext/extension_dict.h>
#include <google/protobuf/pyext/repeated_composite_container.h>
#include <google/protobuf/pyext/repeated_scalar_container.h>
+#include <google/protobuf/pyext/message_map_container.h>
+#include <google/protobuf/pyext/scalar_map_container.h>
#include <google/protobuf/pyext/scoped_pyobject_ptr.h>
#include <google/protobuf/stubs/strutil.h>
@@ -93,9 +95,9 @@ static const FieldDescriptor* GetFieldDescriptor(
static const Descriptor* GetMessageDescriptor(PyTypeObject* cls);
static string GetMessageName(CMessage* self);
int InternalReleaseFieldByDescriptor(
+ CMessage* self,
const FieldDescriptor* field_descriptor,
- PyObject* composite_field,
- Message* parent_message);
+ PyObject* composite_field);
} // namespace cmessage
// ---------------------------------------------------------------------
@@ -127,10 +129,29 @@ static int VisitCompositeField(const FieldDescriptor* descriptor,
Visitor visitor) {
if (descriptor->label() == FieldDescriptor::LABEL_REPEATED) {
if (descriptor->cpp_type() == FieldDescriptor::CPPTYPE_MESSAGE) {
- RepeatedCompositeContainer* container =
- reinterpret_cast<RepeatedCompositeContainer*>(child);
- if (visitor.VisitRepeatedCompositeContainer(container) == -1)
- return -1;
+ if (descriptor->is_map()) {
+ const Descriptor* entry_type = descriptor->message_type();
+ const FieldDescriptor* value_type =
+ entry_type->FindFieldByName("value");
+ if (value_type->cpp_type() == FieldDescriptor::CPPTYPE_MESSAGE) {
+ MessageMapContainer* container =
+ reinterpret_cast<MessageMapContainer*>(child);
+ if (visitor.VisitMessageMapContainer(container) == -1) {
+ return -1;
+ }
+ } else {
+ ScalarMapContainer* container =
+ reinterpret_cast<ScalarMapContainer*>(child);
+ if (visitor.VisitScalarMapContainer(container) == -1) {
+ return -1;
+ }
+ }
+ } else {
+ RepeatedCompositeContainer* container =
+ reinterpret_cast<RepeatedCompositeContainer*>(child);
+ if (visitor.VisitRepeatedCompositeContainer(container) == -1)
+ return -1;
+ }
} else {
RepeatedScalarContainer* container =
reinterpret_cast<RepeatedScalarContainer*>(child);
@@ -444,7 +465,7 @@ static int MaybeReleaseOverlappingOneofField(
}
if (InternalReleaseFieldByDescriptor(
- existing_field, child_message, message) < 0) {
+ cmessage, existing_field, child_message) < 0) {
return -1;
}
return PyDict_DelItemString(cmessage->composite_fields, field_name);
@@ -483,6 +504,16 @@ struct FixupMessageReference : public ChildVisitor {
return 0;
}
+ int VisitScalarMapContainer(ScalarMapContainer* container) {
+ container->message = message_;
+ return 0;
+ }
+
+ int VisitMessageMapContainer(MessageMapContainer* container) {
+ container->message = message_;
+ return 0;
+ }
+
private:
Message* message_;
};
@@ -500,6 +531,9 @@ int AssureWritable(CMessage* self) {
self->message->GetDescriptor());
self->message = prototype->New();
self->owner.reset(self->message);
+    // Cascade the new owner to any existing children: even if this message is
+    // empty, some submessages or repeated containers might already exist.
+ SetOwner(self, self->owner);
} else {
// Otherwise, we need a mutable child message.
if (AssureWritable(self->parent) == -1)
@@ -520,8 +554,9 @@ int AssureWritable(CMessage* self) {
// When a CMessage is made writable its Message pointer is updated
// to point to a new mutable Message. When that happens we need to
// update any references to the old, read-only CMessage. There are
- // three places such references occur: RepeatedScalarContainer,
- // RepeatedCompositeContainer, and ExtensionDict.
+ // five places such references occur: RepeatedScalarContainer,
+ // RepeatedCompositeContainer, ScalarMapContainer, MessageMapContainer,
+ // and ExtensionDict.
if (self->extensions != NULL)
self->extensions->message = self->message;
if (ForEachCompositeField(self, FixupMessageReference(self->message)) == -1)
@@ -583,15 +618,43 @@ const FieldDescriptor* GetExtensionDescriptor(PyObject* extension) {
return PyFieldDescriptor_AsDescriptor(extension);
}
+// If value is a string, convert it into an enum value based on the labels in
+// descriptor, otherwise simply return value. Always returns a new reference.
+static PyObject* GetIntegerEnumValue(const FieldDescriptor& descriptor,
+ PyObject* value) {
+ if (PyString_Check(value) || PyUnicode_Check(value)) {
+ const EnumDescriptor* enum_descriptor = descriptor.enum_type();
+ if (enum_descriptor == NULL) {
+ PyErr_SetString(PyExc_TypeError, "not an enum field");
+ return NULL;
+ }
+ char* enum_label;
+ Py_ssize_t size;
+ if (PyString_AsStringAndSize(value, &enum_label, &size) < 0) {
+ return NULL;
+ }
+ const EnumValueDescriptor* enum_value_descriptor =
+ enum_descriptor->FindValueByName(string(enum_label, size));
+ if (enum_value_descriptor == NULL) {
+ PyErr_SetString(PyExc_ValueError, "unknown enum label");
+ return NULL;
+ }
+ return PyInt_FromLong(enum_value_descriptor->number());
+ }
+ Py_INCREF(value);
+ return value;
+}
+
// If cmessage_list is not NULL, this function releases values into the
// container CMessages instead of just removing. Repeated composite container
// needs to do this to make sure CMessages stay alive if they're still
// referenced after deletion. Repeated scalar container doesn't need to worry.
int InternalDeleteRepeatedField(
- Message* message,
+ CMessage* self,
const FieldDescriptor* field_descriptor,
PyObject* slice,
PyObject* cmessage_list) {
+ Message* message = self->message;
Py_ssize_t length, from, to, step, slice_length;
const Reflection* reflection = message->GetReflection();
int min, max;
@@ -665,7 +728,7 @@ int InternalDeleteRepeatedField(
CMessage* last_cmessage = reinterpret_cast<CMessage*>(
PyList_GET_ITEM(cmessage_list, PyList_GET_SIZE(cmessage_list) - 1));
repeated_composite_container::ReleaseLastTo(
- field_descriptor, message, last_cmessage);
+ self, field_descriptor, last_cmessage);
if (PySequence_DelItem(cmessage_list, -1) < 0) {
return -1;
}
@@ -696,16 +759,90 @@ int InitAttributes(CMessage* self, PyObject* kwargs) {
PyString_AsString(name));
return -1;
}
- if (descriptor->label() == FieldDescriptor::LABEL_REPEATED) {
+ if (descriptor->is_map()) {
+ ScopedPyObjectPtr map(GetAttr(self, name));
+ const FieldDescriptor* value_descriptor =
+ descriptor->message_type()->FindFieldByName("value");
+ if (value_descriptor->cpp_type() == FieldDescriptor::CPPTYPE_MESSAGE) {
+ Py_ssize_t map_pos = 0;
+ PyObject* map_key;
+ PyObject* map_value;
+ while (PyDict_Next(value, &map_pos, &map_key, &map_value)) {
+ ScopedPyObjectPtr function_return;
+ function_return.reset(PyObject_GetItem(map.get(), map_key));
+ if (function_return.get() == NULL) {
+ return -1;
+ }
+ ScopedPyObjectPtr ok(PyObject_CallMethod(
+ function_return.get(), "MergeFrom", "O", map_value));
+ if (ok.get() == NULL) {
+ return -1;
+ }
+ }
+ } else {
+ ScopedPyObjectPtr function_return;
+ function_return.reset(
+ PyObject_CallMethod(map.get(), "update", "O", value));
+ if (function_return.get() == NULL) {
+ return -1;
+ }
+ }
+ } else if (descriptor->label() == FieldDescriptor::LABEL_REPEATED) {
ScopedPyObjectPtr container(GetAttr(self, name));
if (container == NULL) {
return -1;
}
if (descriptor->cpp_type() == FieldDescriptor::CPPTYPE_MESSAGE) {
- if (repeated_composite_container::Extend(
- reinterpret_cast<RepeatedCompositeContainer*>(container.get()),
- value)
- == NULL) {
+ RepeatedCompositeContainer* rc_container =
+ reinterpret_cast<RepeatedCompositeContainer*>(container.get());
+ ScopedPyObjectPtr iter(PyObject_GetIter(value));
+ if (iter == NULL) {
+ PyErr_SetString(PyExc_TypeError, "Value must be iterable");
+ return -1;
+ }
+ ScopedPyObjectPtr next;
+ while ((next.reset(PyIter_Next(iter))) != NULL) {
+ PyObject* kwargs = (PyDict_Check(next) ? next.get() : NULL);
+ ScopedPyObjectPtr new_msg(
+ repeated_composite_container::Add(rc_container, NULL, kwargs));
+ if (new_msg == NULL) {
+ return -1;
+ }
+ if (kwargs == NULL) {
+            // next was not a dict; it's a message that we need to merge.
+ ScopedPyObjectPtr merged(
+ MergeFrom(reinterpret_cast<CMessage*>(new_msg.get()), next));
+ if (merged == NULL) {
+ return -1;
+ }
+ }
+ }
+ if (PyErr_Occurred()) {
+ // Check to see how PyIter_Next() exited.
+ return -1;
+ }
+ } else if (descriptor->cpp_type() == FieldDescriptor::CPPTYPE_ENUM) {
+ RepeatedScalarContainer* rs_container =
+ reinterpret_cast<RepeatedScalarContainer*>(container.get());
+ ScopedPyObjectPtr iter(PyObject_GetIter(value));
+ if (iter == NULL) {
+ PyErr_SetString(PyExc_TypeError, "Value must be iterable");
+ return -1;
+ }
+ ScopedPyObjectPtr next;
+ while ((next.reset(PyIter_Next(iter))) != NULL) {
+ ScopedPyObjectPtr enum_value(GetIntegerEnumValue(*descriptor, next));
+ if (enum_value == NULL) {
+ return -1;
+ }
+ ScopedPyObjectPtr new_msg(
+ repeated_scalar_container::Append(rs_container, enum_value));
+ if (new_msg == NULL) {
+ return -1;
+ }
+ }
+ if (PyErr_Occurred()) {
+ // Check to see how PyIter_Next() exited.
return -1;
}
} else {
@@ -721,12 +858,26 @@ int InitAttributes(CMessage* self, PyObject* kwargs) {
if (message == NULL) {
return -1;
}
- if (MergeFrom(reinterpret_cast<CMessage*>(message.get()),
- value) == NULL) {
- return -1;
+ CMessage* cmessage = reinterpret_cast<CMessage*>(message.get());
+ if (PyDict_Check(value)) {
+ if (InitAttributes(cmessage, value) < 0) {
+ return -1;
+ }
+ } else {
+ ScopedPyObjectPtr merged(MergeFrom(cmessage, value));
+ if (merged == NULL) {
+ return -1;
+ }
}
} else {
- if (SetAttr(self, name, value) < 0) {
+ ScopedPyObjectPtr new_val;
+ if (descriptor->cpp_type() == FieldDescriptor::CPPTYPE_ENUM) {
+ new_val.reset(GetIntegerEnumValue(*descriptor, value));
+ if (new_val == NULL) {
+ return -1;
+ }
+ }
+ if (SetAttr(self, name, (new_val == NULL) ? value : new_val) < 0) {
return -1;
}
}
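Taken together, the InitAttributes() branches above mean the generated constructor now accepts dicts for map fields, iterables of dicts or messages for repeated message fields, enum labels given as strings, and dicts for singular submessages. A hedged sketch, assuming a hypothetical generated module foo_pb2 whose Foo message has fields weights (map<string, int32>), children (repeated Child), color (an enum with a RED label) and child (Child):

    import foo_pb2  # hypothetical generated module

    msg = foo_pb2.Foo(
        weights={'a': 1, 'b': 2},                           # map field from a dict
        children=[{'name': 'x'}, foo_pb2.Child(name='y')],  # dicts or message instances
        color='RED',                                        # enum label resolved by name
        child={'name': 'z'},                                # singular submessage from a dict
    )
    assert msg.weights['b'] == 2
    assert msg.children[0].name == 'x' and msg.children[1].name == 'y'
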
@@ -789,7 +940,6 @@ static PyObject* New(PyTypeObject* type,
}
self->message = default_message->New();
self->owner.reset(self->message);
-
return reinterpret_cast<PyObject*>(self);
}
@@ -830,6 +980,16 @@ struct ClearWeakReferences : public ChildVisitor {
return 0;
}
+ int VisitScalarMapContainer(ScalarMapContainer* container) {
+ container->parent = NULL;
+ return 0;
+ }
+
+ int VisitMessageMapContainer(MessageMapContainer* container) {
+ container->parent = NULL;
+ return 0;
+ }
+
int VisitCMessage(CMessage* cmessage,
const FieldDescriptor* field_descriptor) {
cmessage->parent = NULL;
@@ -1064,6 +1224,16 @@ struct SetOwnerVisitor : public ChildVisitor {
return 0;
}
+ int VisitScalarMapContainer(ScalarMapContainer* container) {
+ scalar_map_container::SetOwner(container, new_owner_);
+ return 0;
+ }
+
+ int VisitMessageMapContainer(MessageMapContainer* container) {
+ message_map_container::SetOwner(container, new_owner_);
+ return 0;
+ }
+
int VisitCMessage(CMessage* cmessage,
const FieldDescriptor* field_descriptor) {
return SetOwner(cmessage, new_owner_);
@@ -1084,11 +1254,11 @@ int SetOwner(CMessage* self, const shared_ptr<Message>& new_owner) {
// Releases the message specified by 'field' and returns the
// pointer. If the field does not exist a new message is created using
// 'descriptor'. The caller takes ownership of the returned pointer.
-Message* ReleaseMessage(Message* message,
+Message* ReleaseMessage(CMessage* self,
const Descriptor* descriptor,
const FieldDescriptor* field_descriptor) {
- Message* released_message = message->GetReflection()->ReleaseMessage(
- message, field_descriptor, message_factory);
+ Message* released_message = self->message->GetReflection()->ReleaseMessage(
+ self->message, field_descriptor, message_factory);
// ReleaseMessage will return NULL which differs from
// child_cmessage->message, if the field does not exist. In this case,
// the latter points to the default instance via a const_cast<>, so we
@@ -1102,12 +1272,12 @@ Message* ReleaseMessage(Message* message,
return released_message;
}
-int ReleaseSubMessage(Message* message,
+int ReleaseSubMessage(CMessage* self,
const FieldDescriptor* field_descriptor,
CMessage* child_cmessage) {
// Release the Message
shared_ptr<Message> released_message(ReleaseMessage(
- message, child_cmessage->message->GetDescriptor(), field_descriptor));
+ self, child_cmessage->message->GetDescriptor(), field_descriptor));
child_cmessage->message = released_message.get();
child_cmessage->owner.swap(released_message);
child_cmessage->parent = NULL;
@@ -1119,8 +1289,8 @@ int ReleaseSubMessage(Message* message,
struct ReleaseChild : public ChildVisitor {
// message must outlive this object.
- explicit ReleaseChild(Message* parent_message) :
- parent_message_(parent_message) {}
+ explicit ReleaseChild(CMessage* parent) :
+ parent_(parent) {}
int VisitRepeatedCompositeContainer(RepeatedCompositeContainer* container) {
return repeated_composite_container::Release(
@@ -1132,23 +1302,33 @@ struct ReleaseChild : public ChildVisitor {
reinterpret_cast<RepeatedScalarContainer*>(container));
}
+ int VisitScalarMapContainer(ScalarMapContainer* container) {
+ return scalar_map_container::Release(
+ reinterpret_cast<ScalarMapContainer*>(container));
+ }
+
+ int VisitMessageMapContainer(MessageMapContainer* container) {
+ return message_map_container::Release(
+ reinterpret_cast<MessageMapContainer*>(container));
+ }
+
int VisitCMessage(CMessage* cmessage,
const FieldDescriptor* field_descriptor) {
- return ReleaseSubMessage(parent_message_, field_descriptor,
+ return ReleaseSubMessage(parent_, field_descriptor,
reinterpret_cast<CMessage*>(cmessage));
}
- Message* parent_message_;
+ CMessage* parent_;
};
int InternalReleaseFieldByDescriptor(
+ CMessage* self,
const FieldDescriptor* field_descriptor,
- PyObject* composite_field,
- Message* parent_message) {
+ PyObject* composite_field) {
return VisitCompositeField(
field_descriptor,
composite_field,
- ReleaseChild(parent_message));
+ ReleaseChild(self));
}
PyObject* ClearFieldByDescriptor(
@@ -1200,8 +1380,8 @@ PyObject* ClearField(CMessage* self, PyObject* arg) {
// Only release the field if there's a possibility that there are
// references to it.
if (composite_field != NULL) {
- if (InternalReleaseFieldByDescriptor(field_descriptor,
- composite_field, message) < 0) {
+ if (InternalReleaseFieldByDescriptor(self, field_descriptor,
+ composite_field) < 0) {
return NULL;
}
PyDict_DelItem(self->composite_fields, arg);
@@ -1219,7 +1399,7 @@ PyObject* ClearField(CMessage* self, PyObject* arg) {
PyObject* Clear(CMessage* self) {
AssureWritable(self);
- if (ForEachCompositeField(self, ReleaseChild(self->message)) == -1)
+ if (ForEachCompositeField(self, ReleaseChild(self)) == -1)
return NULL;
// The old ExtensionDict still aliases this CMessage, but all its
@@ -1582,7 +1762,8 @@ static PyObject* ListFields(CMessage* self) {
}
if (fields[i]->is_extension()) {
- ScopedPyObjectPtr extension_field(PyFieldDescriptor_New(fields[i]));
+ ScopedPyObjectPtr extension_field(
+ PyFieldDescriptor_FromDescriptor(fields[i]));
if (extension_field == NULL) {
return NULL;
}
@@ -1616,7 +1797,8 @@ static PyObject* ListFields(CMessage* self) {
PyErr_SetString(PyExc_ValueError, "bad string");
return NULL;
}
- ScopedPyObjectPtr field_descriptor(PyFieldDescriptor_New(fields[i]));
+ ScopedPyObjectPtr field_descriptor(
+ PyFieldDescriptor_FromDescriptor(fields[i]));
if (field_descriptor == NULL) {
return NULL;
}
@@ -1683,10 +1865,8 @@ static PyObject* RichCompare(CMessage* self, PyObject* other, int opid) {
}
}
-PyObject* InternalGetScalar(
- CMessage* self,
- const FieldDescriptor* field_descriptor) {
- Message* message = self->message;
+PyObject* InternalGetScalar(const Message* message,
+ const FieldDescriptor* field_descriptor) {
const Reflection* reflection = message->GetReflection();
if (!CheckFieldBelongsToMessage(field_descriptor, message)) {
@@ -1739,12 +1919,12 @@ PyObject* InternalGetScalar(
if (!message->GetReflection()->SupportsUnknownEnumValues() &&
!message->GetReflection()->HasField(*message, field_descriptor)) {
// Look for the value in the unknown fields.
- UnknownFieldSet* unknown_field_set =
- message->GetReflection()->MutableUnknownFields(message);
- for (int i = 0; i < unknown_field_set->field_count(); ++i) {
- if (unknown_field_set->field(i).number() ==
+ const UnknownFieldSet& unknown_field_set =
+ message->GetReflection()->GetUnknownFields(*message);
+ for (int i = 0; i < unknown_field_set.field_count(); ++i) {
+ if (unknown_field_set.field(i).number() ==
field_descriptor->number()) {
- result = PyInt_FromLong(unknown_field_set->field(i).varint());
+ result = PyInt_FromLong(unknown_field_set.field(i).varint());
break;
}
}
@@ -1793,21 +1973,16 @@ PyObject* InternalGetSubMessage(
return reinterpret_cast<PyObject*>(cmsg);
}
-int InternalSetScalar(
- CMessage* self,
+int InternalSetNonOneofScalar(
+ Message* message,
const FieldDescriptor* field_descriptor,
PyObject* arg) {
- Message* message = self->message;
const Reflection* reflection = message->GetReflection();
if (!CheckFieldBelongsToMessage(field_descriptor, message)) {
return -1;
}
- if (MaybeReleaseOverlappingOneofField(self, field_descriptor) < 0) {
- return -1;
- }
-
switch (field_descriptor->cpp_type()) {
case FieldDescriptor::CPPTYPE_INT32: {
GOOGLE_CHECK_GET_INT32(arg, value, -1);
@@ -1878,6 +2053,21 @@ int InternalSetScalar(
return 0;
}
+int InternalSetScalar(
+ CMessage* self,
+ const FieldDescriptor* field_descriptor,
+ PyObject* arg) {
+ if (!CheckFieldBelongsToMessage(field_descriptor, self->message)) {
+ return -1;
+ }
+
+ if (MaybeReleaseOverlappingOneofField(self, field_descriptor) < 0) {
+ return -1;
+ }
+
+ return InternalSetNonOneofScalar(self->message, field_descriptor, arg);
+}
+
PyObject* FromString(PyTypeObject* cls, PyObject* serialized) {
PyObject* py_cmsg = PyObject_CallObject(
reinterpret_cast<PyObject*>(cls), NULL);
@@ -1955,7 +2145,8 @@ static PyObject* AddDescriptors(PyObject* cls, PyObject* descriptor) {
// which was built previously.
for (int i = 0; i < message_descriptor->enum_type_count(); ++i) {
const EnumDescriptor* enum_descriptor = message_descriptor->enum_type(i);
- ScopedPyObjectPtr enum_type(PyEnumDescriptor_New(enum_descriptor));
+ ScopedPyObjectPtr enum_type(
+ PyEnumDescriptor_FromDescriptor(enum_descriptor));
if (enum_type == NULL) {
return NULL;
}
@@ -1993,7 +2184,7 @@ static PyObject* AddDescriptors(PyObject* cls, PyObject* descriptor) {
// which was defined previously.
for (int i = 0; i < message_descriptor->extension_count(); ++i) {
const google::protobuf::FieldDescriptor* field = message_descriptor->extension(i);
- ScopedPyObjectPtr extension_field(PyFieldDescriptor_New(field));
+ ScopedPyObjectPtr extension_field(PyFieldDescriptor_FromDescriptor(field));
if (extension_field == NULL) {
return NULL;
}
@@ -2097,26 +2288,6 @@ PyObject* SetState(CMessage* self, PyObject* state) {
}
// CMessage static methods:
-PyObject* _GetMessageDescriptor(PyObject* unused, PyObject* arg) {
- return cdescriptor_pool::FindMessageByName(GetDescriptorPool(), arg);
-}
-
-PyObject* _GetFieldDescriptor(PyObject* unused, PyObject* arg) {
- return cdescriptor_pool::FindFieldByName(GetDescriptorPool(), arg);
-}
-
-PyObject* _GetExtensionDescriptor(PyObject* unused, PyObject* arg) {
- return cdescriptor_pool::FindExtensionByName(GetDescriptorPool(), arg);
-}
-
-PyObject* _GetEnumDescriptor(PyObject* unused, PyObject* arg) {
- return cdescriptor_pool::FindEnumTypeByName(GetDescriptorPool(), arg);
-}
-
-PyObject* _GetOneofDescriptor(PyObject* unused, PyObject* arg) {
- return cdescriptor_pool::FindOneofByName(GetDescriptorPool(), arg);
-}
-
PyObject* _CheckCalledFromGeneratedFile(PyObject* unused,
PyObject* unused_arg) {
if (!_CalledFromGeneratedFile(1)) {
@@ -2188,21 +2359,6 @@ static PyMethodDef Methods[] = {
"or None if no field is set." },
// Static Methods.
- { "_BuildFile", (PyCFunction)Python_BuildFile, METH_O | METH_STATIC,
- "Registers a new protocol buffer file in the global C++ descriptor pool." },
- { "_GetMessageDescriptor", (PyCFunction)_GetMessageDescriptor,
- METH_O | METH_STATIC, "Finds a message descriptor in the message pool." },
- { "_GetFieldDescriptor", (PyCFunction)_GetFieldDescriptor,
- METH_O | METH_STATIC, "Finds a field descriptor in the message pool." },
- { "_GetExtensionDescriptor", (PyCFunction)_GetExtensionDescriptor,
- METH_O | METH_STATIC,
- "Finds a extension descriptor in the message pool." },
- { "_GetEnumDescriptor", (PyCFunction)_GetEnumDescriptor,
- METH_O | METH_STATIC,
- "Finds an enum descriptor in the message pool." },
- { "_GetOneofDescriptor", (PyCFunction)_GetOneofDescriptor,
- METH_O | METH_STATIC,
- "Finds an oneof descriptor in the message pool." },
{ "_CheckCalledFromGeneratedFile", (PyCFunction)_CheckCalledFromGeneratedFile,
METH_NOARGS | METH_STATIC,
"Raises TypeError if the caller is not in a _pb2.py file."},
@@ -2234,6 +2390,31 @@ PyObject* GetAttr(CMessage* self, PyObject* name) {
reinterpret_cast<PyObject*>(self), name);
}
+ if (field_descriptor->is_map()) {
+ PyObject* py_container = NULL;
+ const Descriptor* entry_type = field_descriptor->message_type();
+ const FieldDescriptor* value_type = entry_type->FindFieldByName("value");
+ if (value_type->cpp_type() == FieldDescriptor::CPPTYPE_MESSAGE) {
+ PyObject* value_class = cdescriptor_pool::GetMessageClass(
+ GetDescriptorPool(), value_type->message_type());
+ if (value_class == NULL) {
+ return NULL;
+ }
+ py_container = message_map_container::NewContainer(self, field_descriptor,
+ value_class);
+ } else {
+ py_container = scalar_map_container::NewContainer(self, field_descriptor);
+ }
+ if (py_container == NULL) {
+ return NULL;
+ }
+ if (!SetCompositeField(self, name, py_container)) {
+ Py_DECREF(py_container);
+ return NULL;
+ }
+ return py_container;
+ }
+
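GetAttr() above is what backs plain attribute access on map fields: a scalar-valued map comes back as a ScalarMapContainer and a message-valued map as a MessageMapContainer, with the container cached in composite_fields so repeated access yields the same object. A hedged sketch, reusing the hypothetical foo_pb2.Foo with weights (map<string, int32>) and kids (map<int32, Child>):

    import foo_pb2  # hypothetical generated module

    msg = foo_pb2.Foo()
    msg.weights['a'] = 3          # ScalarMapContainer: dict-like assignment
    print(len(msg.weights), 'a' in msg.weights)

    child = msg.kids[42]          # MessageMapContainer: indexing inserts an empty entry
    child.name = 'first'
    print(42 in msg.kids)
    assert msg.kids is msg.kids   # same cached container object
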
if (field_descriptor->label() == FieldDescriptor::LABEL_REPEATED) {
PyObject* py_container = NULL;
if (field_descriptor->cpp_type() == FieldDescriptor::CPPTYPE_MESSAGE) {
@@ -2267,7 +2448,7 @@ PyObject* GetAttr(CMessage* self, PyObject* name) {
return sub_message;
}
- return InternalGetScalar(self, field_descriptor);
+ return InternalGetScalar(self->message, field_descriptor);
}
int SetAttr(CMessage* self, PyObject* name, PyObject* value) {
@@ -2304,9 +2485,7 @@ int SetAttr(CMessage* self, PyObject* name, PyObject* value) {
PyTypeObject CMessage_Type = {
PyVarObject_HEAD_INIT(&PyType_Type, 0)
- // Keep the fully qualified _message symbol in a line for opensource.
- "google.protobuf.pyext._message."
- "CMessage", // tp_name
+ FULL_MODULE_NAME ".CMessage", // tp_name
sizeof(CMessage), // tp_basicsize
0, // tp_itemsize
(destructor)cmessage::Dealloc, // tp_dealloc
@@ -2401,7 +2580,7 @@ void InitGlobals() {
k_extensions_by_name = PyString_FromString("_extensions_by_name");
k_extensions_by_number = PyString_FromString("_extensions_by_number");
- message_factory = new DynamicMessageFactory(GetDescriptorPool()->pool);
+ message_factory = new DynamicMessageFactory();
message_factory->SetDelegateToGeneratedFactory(true);
}
@@ -2469,6 +2648,61 @@ bool InitProto2MessageModule(PyObject *m) {
reinterpret_cast<PyObject*>(
&RepeatedCompositeContainer_Type));
+ // ScalarMapContainer_Type derives from our MutableMapping type.
+ PyObject* containers =
+ PyImport_ImportModule("google.protobuf.internal.containers");
+ if (containers == NULL) {
+ return false;
+ }
+
+ PyObject* mutable_mapping =
+ PyObject_GetAttrString(containers, "MutableMapping");
+ Py_DECREF(containers);
+
+ if (mutable_mapping == NULL) {
+ return false;
+ }
+
+ if (!PyObject_TypeCheck(mutable_mapping, &PyType_Type)) {
+ Py_DECREF(mutable_mapping);
+ return false;
+ }
+
+ ScalarMapContainer_Type.tp_base =
+ reinterpret_cast<PyTypeObject*>(mutable_mapping);
+
+ if (PyType_Ready(&ScalarMapContainer_Type) < 0) {
+ return false;
+ }
+
+ PyModule_AddObject(m, "ScalarMapContainer",
+ reinterpret_cast<PyObject*>(&ScalarMapContainer_Type));
+
+ if (PyType_Ready(&ScalarMapIterator_Type) < 0) {
+ return false;
+ }
+
+ PyModule_AddObject(m, "ScalarMapIterator",
+ reinterpret_cast<PyObject*>(&ScalarMapIterator_Type));
+
+ Py_INCREF(mutable_mapping);
+ MessageMapContainer_Type.tp_base =
+ reinterpret_cast<PyTypeObject*>(mutable_mapping);
+
+ if (PyType_Ready(&MessageMapContainer_Type) < 0) {
+ return false;
+ }
+
+ PyModule_AddObject(m, "MessageMapContainer",
+ reinterpret_cast<PyObject*>(&MessageMapContainer_Type));
+
+ if (PyType_Ready(&MessageMapIterator_Type) < 0) {
+ return false;
+ }
+
+ PyModule_AddObject(m, "MessageMapIterator",
+ reinterpret_cast<PyObject*>(&MessageMapIterator_Type));
+
ExtensionDict_Type.tp_hash = PyObject_HashNotImplemented;
if (PyType_Ready(&ExtensionDict_Type) < 0) {
return false;
@@ -2478,6 +2712,12 @@ bool InitProto2MessageModule(PyObject *m) {
m, "ExtensionDict",
reinterpret_cast<PyObject*>(&ExtensionDict_Type));
+ // Expose the DescriptorPool used to hold all descriptors added from generated
+ // pb2.py files.
+ Py_INCREF(GetDescriptorPool()); // PyModule_AddObject steals a reference.
+ PyModule_AddObject(
+ m, "default_pool", reinterpret_cast<PyObject*>(GetDescriptorPool()));
+
// This implementation provides full Descriptor types, we advertise it so that
// descriptor.py can use them in replacement of the Python classes.
PyModule_AddIntConstant(m, "_USE_C_DESCRIPTORS", 1);
diff --git a/python/google/protobuf/pyext/message.h b/python/google/protobuf/pyext/message.h
index 2f2da795..7360b207 100644
--- a/python/google/protobuf/pyext/message.h
+++ b/python/google/protobuf/pyext/message.h
@@ -120,7 +120,7 @@ CMessage* NewEmptyMessage(PyObject* type, const Descriptor* descriptor);
// A new message will be created if this is a read-only default instance.
//
// Corresponds to reflection api method ReleaseMessage.
-int ReleaseSubMessage(Message* message,
+int ReleaseSubMessage(CMessage* self,
const FieldDescriptor* field_descriptor,
CMessage* child_cmessage);
@@ -144,7 +144,7 @@ PyObject* InternalGetSubMessage(
// by slice will be removed from cmessage_list by this function.
//
// Corresponds to reflection api method RemoveLast.
-int InternalDeleteRepeatedField(Message* message,
+int InternalDeleteRepeatedField(CMessage* self,
const FieldDescriptor* field_descriptor,
PyObject* slice, PyObject* cmessage_list);
@@ -153,10 +153,15 @@ int InternalSetScalar(CMessage* self,
const FieldDescriptor* field_descriptor,
PyObject* value);
+// Sets the specified scalar value on the message. Requires that the field is
+// not part of a oneof.
+int InternalSetNonOneofScalar(Message* message,
+ const FieldDescriptor* field_descriptor,
+ PyObject* arg);
+
// Retrieves the specified scalar value from the message.
//
// Returns a new python reference.
-PyObject* InternalGetScalar(CMessage* self,
+PyObject* InternalGetScalar(const Message* message,
const FieldDescriptor* field_descriptor);
// Clears the message, removing all contained data. Extension dictionary and
@@ -279,7 +284,7 @@ extern PyObject* kint64min_py;
extern PyObject* kint64max_py;
extern PyObject* kuint64max_py;
-#define C(str) const_cast<char*>(str)
+#define FULL_MODULE_NAME "google.protobuf.pyext._message"
void FormatTypeError(PyObject* arg, char* expected_types);
template<class T>
diff --git a/python/google/protobuf/pyext/message_map_container.cc b/python/google/protobuf/pyext/message_map_container.cc
new file mode 100644
index 00000000..ab8d8fb9
--- /dev/null
+++ b/python/google/protobuf/pyext/message_map_container.cc
@@ -0,0 +1,540 @@
+// Protocol Buffers - Google's data interchange format
+// Copyright 2008 Google Inc. All rights reserved.
+// https://developers.google.com/protocol-buffers/
+//
+// Redistribution and use in source and binary forms, with or without
+// modification, are permitted provided that the following conditions are
+// met:
+//
+// * Redistributions of source code must retain the above copyright
+// notice, this list of conditions and the following disclaimer.
+// * Redistributions in binary form must reproduce the above
+// copyright notice, this list of conditions and the following disclaimer
+// in the documentation and/or other materials provided with the
+// distribution.
+// * Neither the name of Google Inc. nor the names of its
+// contributors may be used to endorse or promote products derived from
+// this software without specific prior written permission.
+//
+// THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS
+// "AS IS" AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT
+// LIMITED TO, THE IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR
+// A PARTICULAR PURPOSE ARE DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT
+// OWNER OR CONTRIBUTORS BE LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL,
+// SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT
+// LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE,
+// DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY
+// THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT
+// (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
+// OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
+
+// Author: haberman@google.com (Josh Haberman)
+
+#include <google/protobuf/pyext/message_map_container.h>
+
+#include <google/protobuf/stubs/common.h>
+#include <google/protobuf/message.h>
+#include <google/protobuf/pyext/message.h>
+#include <google/protobuf/pyext/scoped_pyobject_ptr.h>
+
+namespace google {
+namespace protobuf {
+namespace python {
+
+struct MessageMapIterator {
+ PyObject_HEAD;
+
+ // This dict contains the full contents of what we want to iterate over.
+ // There's no way to avoid building this, because the list representation
+ // (which is canonical) can contain duplicate keys. So at the very least we
+ // need a set that lets us skip duplicate keys. And at the point that we're
+ // doing that, we might as well just build the actual dict we're iterating
+ // over and use dict's built-in iterator.
+ PyObject* dict;
+
+ // An iterator on dict.
+ PyObject* iter;
+
+ // A pointer back to the container, so we can notice changes to the version.
+ MessageMapContainer* container;
+
+ // The version of the map when we took the iterator to it.
+ //
+ // We store this so that if the map is modified during iteration we can throw
+ // an error.
+ uint64 version;
+};
+
+static MessageMapIterator* GetIter(PyObject* obj) {
+ return reinterpret_cast<MessageMapIterator*>(obj);
+}
+
+namespace message_map_container {
+
+static MessageMapContainer* GetMap(PyObject* obj) {
+ return reinterpret_cast<MessageMapContainer*>(obj);
+}
+
+// The private constructor of MessageMapContainer objects.
+PyObject* NewContainer(CMessage* parent,
+ const google::protobuf::FieldDescriptor* parent_field_descriptor,
+ PyObject* concrete_class) {
+ if (!CheckFieldBelongsToMessage(parent_field_descriptor, parent->message)) {
+ return NULL;
+ }
+
+ PyObject* obj = PyType_GenericAlloc(&MessageMapContainer_Type, 0);
+ if (obj == NULL) {
+ return PyErr_Format(PyExc_RuntimeError,
+ "Could not allocate new container.");
+ }
+
+ MessageMapContainer* self = GetMap(obj);
+
+ self->message = parent->message;
+ self->parent = parent;
+ self->parent_field_descriptor = parent_field_descriptor;
+ self->owner = parent->owner;
+ self->version = 0;
+
+ self->key_field_descriptor =
+ parent_field_descriptor->message_type()->FindFieldByName("key");
+ self->value_field_descriptor =
+ parent_field_descriptor->message_type()->FindFieldByName("value");
+
+ self->message_dict = PyDict_New();
+ if (self->message_dict == NULL) {
+ return PyErr_Format(PyExc_RuntimeError,
+ "Could not allocate message dict.");
+ }
+
+ Py_INCREF(concrete_class);
+ self->subclass_init = concrete_class;
+
+ if (self->key_field_descriptor == NULL ||
+ self->value_field_descriptor == NULL) {
+ Py_DECREF(obj);
+ return PyErr_Format(PyExc_KeyError,
+ "Map entry descriptor did not have key/value fields");
+ }
+
+ return obj;
+}
+
+// Initializes the underlying Message object of "to" so it becomes a new parent
+// map container, and copies all the entries from "from" to it. A child map
+// container can be released by passing it as both from and to (e.g. making it
+// the recipient of the new parent message and copying the values from itself).
+static int InitializeAndCopyToParentContainer(
+ MessageMapContainer* from,
+ MessageMapContainer* to) {
+ // For now we require from == to, re-evaluate if we want to support deep copy
+ // as in repeated_composite_container.cc.
+ GOOGLE_DCHECK(from == to);
+ Message* old_message = from->message;
+ Message* new_message = old_message->New();
+ to->parent = NULL;
+ to->parent_field_descriptor = from->parent_field_descriptor;
+ to->message = new_message;
+ to->owner.reset(new_message);
+
+ vector<const FieldDescriptor*> fields;
+ fields.push_back(from->parent_field_descriptor);
+ old_message->GetReflection()->SwapFields(old_message, new_message, fields);
+ return 0;
+}
+
+static PyObject* GetCMessage(MessageMapContainer* self, Message* entry) {
+ // Get or create the CMessage object corresponding to this message.
+ Message* message = entry->GetReflection()->MutableMessage(
+ entry, self->value_field_descriptor);
+ ScopedPyObjectPtr key(PyLong_FromVoidPtr(message));
+ PyObject* ret = PyDict_GetItem(self->message_dict, key);
+
+ if (ret == NULL) {
+ CMessage* cmsg = cmessage::NewEmptyMessage(self->subclass_init,
+ message->GetDescriptor());
+ ret = reinterpret_cast<PyObject*>(cmsg);
+
+ if (cmsg == NULL) {
+ return NULL;
+ }
+ cmsg->owner = self->owner;
+ cmsg->message = message;
+ cmsg->parent = self->parent;
+
+ if (PyDict_SetItem(self->message_dict, key, ret) < 0) {
+ Py_DECREF(ret);
+ return NULL;
+ }
+ } else {
+ Py_INCREF(ret);
+ }
+
+ return ret;
+}
+
+int Release(MessageMapContainer* self) {
+ InitializeAndCopyToParentContainer(self, self);
+ return 0;
+}
+
+void SetOwner(MessageMapContainer* self,
+ const shared_ptr<Message>& new_owner) {
+ self->owner = new_owner;
+}
+
+Py_ssize_t Length(PyObject* _self) {
+ MessageMapContainer* self = GetMap(_self);
+ google::protobuf::Message* message = self->message;
+ return message->GetReflection()->FieldSize(*message,
+ self->parent_field_descriptor);
+}
+
+int MapKeyMatches(MessageMapContainer* self, const Message* entry,
+ PyObject* key) {
+ // TODO(haberman): do we need more strict type checking?
+ ScopedPyObjectPtr entry_key(
+ cmessage::InternalGetScalar(entry, self->key_field_descriptor));
+ int ret = PyObject_RichCompareBool(key, entry_key, Py_EQ);
+ return ret;
+}
+
+int SetItem(PyObject *_self, PyObject *key, PyObject *v) {
+ if (v) {
+ PyErr_Format(PyExc_ValueError,
+ "Direct assignment of submessage not allowed");
+ return -1;
+ }
+
+ // Now we know that this is a delete, not a set.
+
+ MessageMapContainer* self = GetMap(_self);
+ cmessage::AssureWritable(self->parent);
+
+ Message* message = self->message;
+ const Reflection* reflection = message->GetReflection();
+ size_t size =
+ reflection->FieldSize(*message, self->parent_field_descriptor);
+
+ // Right now the Reflection API doesn't support map lookup, so we implement it
+ // via linear search. We need to search from the end because the underlying
+ // representation can have duplicates if a user calls MergeFrom(); the last
+ // one needs to win.
+ //
+ // TODO(haberman): add lookup API to Reflection API.
+ bool found = false;
+ for (int i = size - 1; i >= 0; i--) {
+ Message* entry = reflection->MutableRepeatedMessage(
+ message, self->parent_field_descriptor, i);
+ int matches = MapKeyMatches(self, entry, key);
+ if (matches < 0) return -1;
+ if (matches) {
+ found = true;
+ if (i != size - 1) {
+ reflection->SwapElements(message, self->parent_field_descriptor, i,
+ size - 1);
+ }
+ reflection->RemoveLast(message, self->parent_field_descriptor);
+
+ // Can't exit now, the repeated field representation of maps allows
+ // duplicate keys, and we have to be sure to remove all of them.
+ }
+ }
+
+ if (!found) {
+ PyErr_Format(PyExc_KeyError, "Key not present in map");
+ return -1;
+ }
+
+ self->version++;
+
+ return 0;
+}
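Because a map is stored as a repeated entry message under the hood, MergeFrom() can leave duplicate keys behind; the helpers above therefore scan from the end so the most recently added entry wins, and deletion keeps scanning so that every duplicate is removed. A hedged sketch with the hypothetical kids (map<int32, Child>) field:

    import foo_pb2  # hypothetical generated module

    a = foo_pb2.Foo()
    a.kids[7].name = 'old'
    b = foo_pb2.Foo()
    b.kids[7].name = 'new'

    a.MergeFrom(b)           # the underlying repeated entries may now hold key 7 twice
    print(a.kids[7].name)    # 'new': lookup scans from the end, so the later entry wins
    del a.kids[7]            # deletion removes every duplicate entry for key 7
    print(7 in a.kids)       # False
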
+
+PyObject* GetIterator(PyObject *_self) {
+ MessageMapContainer* self = GetMap(_self);
+
+ ScopedPyObjectPtr obj(PyType_GenericAlloc(&MessageMapIterator_Type, 0));
+ if (obj == NULL) {
+ return PyErr_Format(PyExc_KeyError, "Could not allocate iterator");
+ }
+
+ MessageMapIterator* iter = GetIter(obj);
+
+ Py_INCREF(self);
+ iter->container = self;
+ iter->version = self->version;
+ iter->dict = PyDict_New();
+ if (iter->dict == NULL) {
+ return PyErr_Format(PyExc_RuntimeError,
+ "Could not allocate dict for iterator.");
+ }
+
+ // Build the entire map into a dict right now. Start from the beginning so
+ // that later entries win in the case of duplicates.
+ Message* message = self->message;
+ const Reflection* reflection = message->GetReflection();
+
+ // Right now the Reflection API doesn't support map lookup, so we implement it
+ // via linear search. We need to search from the end because the underlying
+ // representation can have duplicates if a user calls MergeFrom(); the last
+ // one needs to win.
+ //
+ // TODO(haberman): add lookup API to Reflection API.
+ size_t size =
+ reflection->FieldSize(*message, self->parent_field_descriptor);
+ for (int i = size - 1; i >= 0; i--) {
+ Message* entry = reflection->MutableRepeatedMessage(
+ message, self->parent_field_descriptor, i);
+ ScopedPyObjectPtr key(
+ cmessage::InternalGetScalar(entry, self->key_field_descriptor));
+ if (PyDict_SetItem(iter->dict, key.get(), GetCMessage(self, entry)) < 0) {
+ return PyErr_Format(PyExc_RuntimeError,
+ "SetItem failed in iterator construction.");
+ }
+ }
+
+ iter->iter = PyObject_GetIter(iter->dict);
+
+ return obj.release();
+}
+
+PyObject* GetItem(PyObject* _self, PyObject* key) {
+ MessageMapContainer* self = GetMap(_self);
+ cmessage::AssureWritable(self->parent);
+ Message* message = self->message;
+ const Reflection* reflection = message->GetReflection();
+
+ // Right now the Reflection API doesn't support map lookup, so we implement it
+ // via linear search. We need to search from the end because the underlying
+ // representation can have duplicates if a user calls MergeFrom(); the last
+ // one needs to win.
+ //
+ // TODO(haberman): add lookup API to Reflection API.
+ size_t size =
+ reflection->FieldSize(*message, self->parent_field_descriptor);
+ for (int i = size - 1; i >= 0; i--) {
+ Message* entry = reflection->MutableRepeatedMessage(
+ message, self->parent_field_descriptor, i);
+ int matches = MapKeyMatches(self, entry, key);
+ if (matches < 0) return NULL;
+ if (matches) {
+ return GetCMessage(self, entry);
+ }
+ }
+
+ // Key is not already present; insert a new entry.
+ Message* entry =
+ reflection->AddMessage(message, self->parent_field_descriptor);
+
+ self->version++;
+
+ if (cmessage::InternalSetNonOneofScalar(entry, self->key_field_descriptor,
+ key) < 0) {
+ reflection->RemoveLast(message, self->parent_field_descriptor);
+ return NULL;
+ }
+
+ return GetCMessage(self, entry);
+}
+
+PyObject* Contains(PyObject* _self, PyObject* key) {
+ MessageMapContainer* self = GetMap(_self);
+ Message* message = self->message;
+ const Reflection* reflection = message->GetReflection();
+
+ // Right now the Reflection API doesn't support map lookup, so we implement it
+ // via linear search.
+ //
+ // TODO(haberman): add lookup API to Reflection API.
+ size_t size =
+ reflection->FieldSize(*message, self->parent_field_descriptor);
+ for (int i = 0; i < size; i++) {
+ Message* entry = reflection->MutableRepeatedMessage(
+ message, self->parent_field_descriptor, i);
+ int matches = MapKeyMatches(self, entry, key);
+ if (matches < 0) return NULL;
+ if (matches) {
+ Py_RETURN_TRUE;
+ }
+ }
+
+ Py_RETURN_FALSE;
+}
+
+PyObject* Clear(PyObject* _self) {
+ MessageMapContainer* self = GetMap(_self);
+ cmessage::AssureWritable(self->parent);
+ Message* message = self->message;
+ const Reflection* reflection = message->GetReflection();
+
+ self->version++;
+ reflection->ClearField(message, self->parent_field_descriptor);
+
+ Py_RETURN_NONE;
+}
+
+PyObject* Get(PyObject* self, PyObject* args) {
+ PyObject* key;
+ PyObject* default_value = NULL;
+  if (!PyArg_ParseTuple(args, "O|O", &key, &default_value)) {
+ return NULL;
+ }
+
+ ScopedPyObjectPtr is_present(Contains(self, key));
+ if (is_present.get() == NULL) {
+ return NULL;
+ }
+
+ if (PyObject_IsTrue(is_present.get())) {
+ return GetItem(self, key);
+ } else {
+ if (default_value != NULL) {
+ Py_INCREF(default_value);
+ return default_value;
+ } else {
+ Py_RETURN_NONE;
+ }
+ }
+}
+
+static PyMappingMethods MpMethods = {
+ Length, // mp_length
+ GetItem, // mp_subscript
+ SetItem, // mp_ass_subscript
+};
+
+static void Dealloc(PyObject* _self) {
+ MessageMapContainer* self = GetMap(_self);
+ self->owner.reset();
+ Py_DECREF(self->message_dict);
+ Py_TYPE(_self)->tp_free(_self);
+}
+
+static PyMethodDef Methods[] = {
+ { "__contains__", (PyCFunction)Contains, METH_O,
+ "Tests whether the map contains this element."},
+ { "clear", (PyCFunction)Clear, METH_NOARGS,
+ "Removes all elements from the map."},
+ { "get", Get, METH_VARARGS,
+ "Gets the value for the given key if present, or otherwise a default" },
+ { "get_or_create", GetItem, METH_O,
+ "Alias for getitem, useful to make explicit that the map is mutated." },
+ /*
+ { "__deepcopy__", (PyCFunction)DeepCopy, METH_VARARGS,
+ "Makes a deep copy of the class." },
+ { "__reduce__", (PyCFunction)Reduce, METH_NOARGS,
+ "Outputs picklable representation of the repeated field." },
+ */
+ {NULL, NULL},
+};
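The method table above is the entire Python-facing surface of a message-valued map: get() mirrors dict.get(), get_or_create() is an explicit alias of __getitem__ (which inserts a default entry for a missing key), and SetItem() rejects direct submessage assignment. A hedged sketch, again with a hypothetical kids (map<int32, Child>) field:

    import foo_pb2  # hypothetical generated module

    msg = foo_pb2.Foo()
    entry = msg.kids.get_or_create(1)   # inserts an empty Child and returns it
    entry.name = 'first'

    print(msg.kids.get(2))              # None: key 2 is absent and no default was given
    print(1 in msg.kids)                # True, via __contains__

    try:
        msg.kids[1] = foo_pb2.Child()   # direct assignment of a submessage
    except ValueError:
        print('mutate the message returned by get_or_create() instead')

    del msg.kids[1]                     # __delitem__ (SetItem with no value) removes the entry
    msg.kids.clear()
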
+
+} // namespace message_map_container
+
+namespace message_map_iterator {
+
+static void Dealloc(PyObject* _self) {
+ MessageMapIterator* self = GetIter(_self);
+ Py_DECREF(self->dict);
+ Py_DECREF(self->iter);
+ Py_DECREF(self->container);
+ Py_TYPE(_self)->tp_free(_self);
+}
+
+PyObject* IterNext(PyObject* _self) {
+ MessageMapIterator* self = GetIter(_self);
+
+ // This won't catch mutations to the map performed by MergeFrom(); no easy way
+ // to address that.
+ if (self->version != self->container->version) {
+ return PyErr_Format(PyExc_RuntimeError,
+ "Map modified during iteration.");
+ }
+
+ return PyIter_Next(self->iter);
+}
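IterNext() compares the version recorded when the iterator was created against the container's current version, so any mutation made through the container while iterating raises RuntimeError (mutations performed behind its back via MergeFrom() are explicitly not caught). A hedged sketch:

    import foo_pb2  # hypothetical generated module

    msg = foo_pb2.Foo()
    msg.kids[1].name = 'a'
    msg.kids[2].name = 'b'

    try:
        for key in msg.kids:
            del msg.kids[key]   # bumps the container's version
    except RuntimeError as err:
        print(err)              # "Map modified during iteration."
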
+
+} // namespace message_map_iterator
+
+PyTypeObject MessageMapContainer_Type = {
+ PyVarObject_HEAD_INIT(&PyType_Type, 0)
+ FULL_MODULE_NAME ".MessageMapContainer", // tp_name
+ sizeof(MessageMapContainer), // tp_basicsize
+ 0, // tp_itemsize
+ message_map_container::Dealloc, // tp_dealloc
+ 0, // tp_print
+ 0, // tp_getattr
+ 0, // tp_setattr
+ 0, // tp_compare
+ 0, // tp_repr
+ 0, // tp_as_number
+ 0, // tp_as_sequence
+ &message_map_container::MpMethods, // tp_as_mapping
+ 0, // tp_hash
+ 0, // tp_call
+ 0, // tp_str
+ 0, // tp_getattro
+ 0, // tp_setattro
+ 0, // tp_as_buffer
+ Py_TPFLAGS_DEFAULT, // tp_flags
+  "A map container for messages",          //  tp_doc
+ 0, // tp_traverse
+ 0, // tp_clear
+ 0, // tp_richcompare
+ 0, // tp_weaklistoffset
+ message_map_container::GetIterator, // tp_iter
+ 0, // tp_iternext
+ message_map_container::Methods, // tp_methods
+ 0, // tp_members
+ 0, // tp_getset
+ 0, // tp_base
+ 0, // tp_dict
+ 0, // tp_descr_get
+ 0, // tp_descr_set
+ 0, // tp_dictoffset
+ 0, // tp_init
+};
+
+PyTypeObject MessageMapIterator_Type = {
+ PyVarObject_HEAD_INIT(&PyType_Type, 0)
+ FULL_MODULE_NAME ".MessageMapIterator", // tp_name
+ sizeof(MessageMapIterator), // tp_basicsize
+ 0, // tp_itemsize
+ message_map_iterator::Dealloc, // tp_dealloc
+ 0, // tp_print
+ 0, // tp_getattr
+ 0, // tp_setattr
+ 0, // tp_compare
+ 0, // tp_repr
+ 0, // tp_as_number
+ 0, // tp_as_sequence
+ 0, // tp_as_mapping
+ 0, // tp_hash
+ 0, // tp_call
+ 0, // tp_str
+ 0, // tp_getattro
+ 0, // tp_setattro
+ 0, // tp_as_buffer
+ Py_TPFLAGS_DEFAULT, // tp_flags
+  "A message map iterator",                //  tp_doc
+ 0, // tp_traverse
+ 0, // tp_clear
+ 0, // tp_richcompare
+ 0, // tp_weaklistoffset
+ PyObject_SelfIter, // tp_iter
+ message_map_iterator::IterNext, // tp_iternext
+ 0, // tp_methods
+ 0, // tp_members
+ 0, // tp_getset
+ 0, // tp_base
+ 0, // tp_dict
+ 0, // tp_descr_get
+ 0, // tp_descr_set
+ 0, // tp_dictoffset
+ 0, // tp_init
+};
+
+} // namespace python
+} // namespace protobuf
+} // namespace google
diff --git a/python/google/protobuf/pyext/message_map_container.h b/python/google/protobuf/pyext/message_map_container.h
new file mode 100644
index 00000000..4ca0aecc
--- /dev/null
+++ b/python/google/protobuf/pyext/message_map_container.h
@@ -0,0 +1,117 @@
+// Protocol Buffers - Google's data interchange format
+// Copyright 2008 Google Inc. All rights reserved.
+// https://developers.google.com/protocol-buffers/
+//
+// Redistribution and use in source and binary forms, with or without
+// modification, are permitted provided that the following conditions are
+// met:
+//
+// * Redistributions of source code must retain the above copyright
+// notice, this list of conditions and the following disclaimer.
+// * Redistributions in binary form must reproduce the above
+// copyright notice, this list of conditions and the following disclaimer
+// in the documentation and/or other materials provided with the
+// distribution.
+// * Neither the name of Google Inc. nor the names of its
+// contributors may be used to endorse or promote products derived from
+// this software without specific prior written permission.
+//
+// THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS
+// "AS IS" AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT
+// LIMITED TO, THE IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR
+// A PARTICULAR PURPOSE ARE DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT
+// OWNER OR CONTRIBUTORS BE LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL,
+// SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT
+// LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE,
+// DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY
+// THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT
+// (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
+// OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
+
+#ifndef GOOGLE_PROTOBUF_PYTHON_CPP_MESSAGE_MAP_CONTAINER_H__
+#define GOOGLE_PROTOBUF_PYTHON_CPP_MESSAGE_MAP_CONTAINER_H__
+
+#include <Python.h>
+
+#include <memory>
+#ifndef _SHARED_PTR_H
+#include <google/protobuf/stubs/shared_ptr.h>
+#endif
+
+#include <google/protobuf/descriptor.h>
+
+namespace google {
+namespace protobuf {
+
+class Message;
+
+using internal::shared_ptr;
+
+namespace python {
+
+struct CMessage;
+
+struct MessageMapContainer {
+ PyObject_HEAD;
+
+ // This is the top-level C++ Message object that owns the whole
+ // proto tree. Every Python MessageMapContainer holds a
+ // reference to it in order to keep it alive as long as there's a
+ // Python object that references any part of the tree.
+ shared_ptr<Message> owner;
+
+ // Pointer to the C++ Message that contains this container. The
+ // MessageMapContainer does not own this pointer.
+ Message* message;
+
+  // Weak reference to a parent CMessage object (i.e. may be NULL).
+ //
+ // Used to make sure all ancestors are also mutable when first
+ // modifying the container.
+ CMessage* parent;
+
+ // Pointer to the parent's descriptor that describes this
+ // field. Used together with the parent's message when making a
+ // default message instance mutable.
+ // The pointer is owned by the global DescriptorPool.
+ const FieldDescriptor* parent_field_descriptor;
+ const FieldDescriptor* key_field_descriptor;
+ const FieldDescriptor* value_field_descriptor;
+
+ // A callable that is used to create new child messages.
+ PyObject* subclass_init;
+
+ // A dict mapping Message* -> CMessage.
+ PyObject* message_dict;
+
+ // We bump this whenever we perform a mutation, to invalidate existing
+ // iterators.
+ uint64 version;
+};
+
+extern PyTypeObject MessageMapContainer_Type;
+extern PyTypeObject MessageMapIterator_Type;
+
+namespace message_map_container {
+
+// Builds a MessageMapContainer object from a parent message and a
+// field descriptor.
+extern PyObject* NewContainer(CMessage* parent,
+ const FieldDescriptor* parent_field_descriptor,
+ PyObject* concrete_class);
+
+// Releases the messages in the container to a new message.
+//
+// Returns 0 on success, -1 on failure.
+int Release(MessageMapContainer* self);
+
+// Set the owner field of self and any children of self.
+void SetOwner(MessageMapContainer* self,
+ const shared_ptr<Message>& new_owner);
+
+} // namespace message_map_container
+} // namespace python
+} // namespace protobuf
+
+} // namespace google
+#endif // GOOGLE_PROTOBUF_PYTHON_CPP_MESSAGE_MAP_CONTAINER_H__
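
The header above is the per-instance state behind message-valued map fields: the `owner` shared_ptr keeps the whole C++ tree alive, the weak `parent` pointer lets mutations call back into AssureWritable, the key/value field descriptors describe the synthesized map-entry type, and `version` invalidates outstanding iterators. A minimal sketch of how this surfaces in Python, assuming the generated `map_unittest_pb2` module (from the `map_unittest.proto` added to setup.py below) and its `TestMap.map_int32_foreign_message` field:

    # Illustrative sketch only; assumes map_unittest_pb2 has been generated.
    from google.protobuf import map_unittest_pb2

    msg = map_unittest_pb2.TestMap()
    # Reading a missing key of a message-valued map creates the entry and
    # returns a mutable sub-message; the container's `owner` keeps the
    # underlying C++ tree alive while this reference exists.
    sub = msg.map_int32_foreign_message[7]
    sub.c = 42
    assert len(msg.map_int32_foreign_message) == 1
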
diff --git a/python/google/protobuf/pyext/repeated_composite_container.cc b/python/google/protobuf/pyext/repeated_composite_container.cc
index 0fe98e73..86b75d0f 100644
--- a/python/google/protobuf/pyext/repeated_composite_container.cc
+++ b/python/google/protobuf/pyext/repeated_composite_container.cc
@@ -367,8 +367,8 @@ int AssignSubscript(RepeatedCompositeContainer* self,
}
// Delete from the underlying Message, if any.
- if (self->message != NULL) {
- if (cmessage::InternalDeleteRepeatedField(self->message,
+ if (self->parent != NULL) {
+ if (cmessage::InternalDeleteRepeatedField(self->parent,
self->parent_field_descriptor,
slice,
self->child_messages) < 0) {
@@ -572,47 +572,35 @@ static PyObject* Pop(RepeatedCompositeContainer* self,
return item;
}
-// The caller takes ownership of the returned Message.
-Message* ReleaseLast(const FieldDescriptor* field,
- const Descriptor* type,
- Message* message) {
+// Releases the last message of the repeated field 'field' on the parent
+// message and transfers its ownership to target.
+void ReleaseLastTo(CMessage* parent,
+ const FieldDescriptor* field,
+ CMessage* target) {
+ GOOGLE_CHECK_NOTNULL(parent);
GOOGLE_CHECK_NOTNULL(field);
- GOOGLE_CHECK_NOTNULL(type);
- GOOGLE_CHECK_NOTNULL(message);
+ GOOGLE_CHECK_NOTNULL(target);
- Message* released_message = message->GetReflection()->ReleaseLast(
- message, field);
+ shared_ptr<Message> released_message(
+ parent->message->GetReflection()->ReleaseLast(parent->message, field));
// TODO(tibell): Deal with proto1.
// ReleaseMessage will return NULL which differs from
// child_cmessage->message, if the field does not exist. In this case,
// the latter points to the default instance via a const_cast<>, so we
// have to reset it to a new mutable object since we are taking ownership.
- if (released_message == NULL) {
+ if (released_message.get() == NULL) {
const Message* prototype =
- cmessage::GetMessageFactory()->GetPrototype(type);
+ cmessage::GetMessageFactory()->GetPrototype(
+ target->message->GetDescriptor());
GOOGLE_CHECK_NOTNULL(prototype);
- return prototype->New();
- } else {
- return released_message;
+ released_message.reset(prototype->New());
}
-}
-// Release field of message and transfer the ownership to cmessage.
-void ReleaseLastTo(const FieldDescriptor* field,
- Message* message,
- CMessage* cmessage) {
- GOOGLE_CHECK_NOTNULL(field);
- GOOGLE_CHECK_NOTNULL(message);
- GOOGLE_CHECK_NOTNULL(cmessage);
-
- shared_ptr<Message> released_message(
- ReleaseLast(field, cmessage->message->GetDescriptor(), message));
- cmessage->parent = NULL;
- cmessage->parent_field_descriptor = NULL;
- cmessage->message = released_message.get();
- cmessage->read_only = false;
- cmessage::SetOwner(cmessage, released_message);
+ target->parent = NULL;
+ target->parent_field_descriptor = NULL;
+ target->message = released_message.get();
+ target->read_only = false;
+ cmessage::SetOwner(target, released_message);
}
// Called to release a container using
@@ -635,7 +623,7 @@ int Release(RepeatedCompositeContainer* self) {
for (Py_ssize_t i = size - 1; i >= 0; --i) {
CMessage* child_cmessage = reinterpret_cast<CMessage*>(
PyList_GET_ITEM(self->child_messages, i));
- ReleaseLastTo(field, message, child_cmessage);
+ ReleaseLastTo(self->parent, field, child_cmessage);
}
// Detach from containing message.
@@ -732,9 +720,7 @@ static PyMethodDef Methods[] = {
PyTypeObject RepeatedCompositeContainer_Type = {
PyVarObject_HEAD_INIT(&PyType_Type, 0)
- // Keep the fully qualified _message symbol in a line for opensource.
- "google.protobuf.pyext._message."
- "RepeatedCompositeContainer", // tp_name
+ FULL_MODULE_NAME ".RepeatedCompositeContainer", // tp_name
sizeof(RepeatedCompositeContainer), // tp_basicsize
0, // tp_itemsize
(destructor)repeated_composite_container::Dealloc, // tp_dealloc
diff --git a/python/google/protobuf/pyext/repeated_composite_container.h b/python/google/protobuf/pyext/repeated_composite_container.h
index ce7cee0f..e0f21360 100644
--- a/python/google/protobuf/pyext/repeated_composite_container.h
+++ b/python/google/protobuf/pyext/repeated_composite_container.h
@@ -161,13 +161,13 @@ int SetOwner(RepeatedCompositeContainer* self,
const shared_ptr<Message>& new_owner);
// Removes the last element of the repeated message field 'field' on
-// the Message 'message', and transfers the ownership of the released
-// Message to 'cmessage'.
+// the Message 'parent', and transfers the ownership of the released
+// Message to 'target'.
//
// Corresponds to reflection api method ReleaseMessage.
-void ReleaseLastTo(const FieldDescriptor* field,
- Message* message,
- CMessage* cmessage);
+void ReleaseLastTo(CMessage* parent,
+ const FieldDescriptor* field,
+ CMessage* target);
} // namespace repeated_composite_container
} // namespace python
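
ReleaseLastTo() now takes the parent CMessage rather than a bare Message*, so detaching a child resets its parent, read_only flag, and owner in one place. The visible effect in Python, sketched below on the assumption that the generated `unittest_pb2` module is importable: a reference to a repeated sub-message stays usable after the parent clears the field, because ownership of the released Message is transferred to the child.

    # Illustrative sketch only; assumes unittest_pb2 has been generated.
    from google.protobuf import unittest_pb2

    parent = unittest_pb2.TestAllTypes()
    child = parent.repeated_nested_message.add()
    child.bb = 1

    # Clearing the field walks the children through ReleaseLastTo(); the
    # Python reference keeps working because the child now owns its Message.
    parent.ClearField('repeated_nested_message')
    assert child.bb == 1
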
diff --git a/python/google/protobuf/pyext/repeated_scalar_container.cc b/python/google/protobuf/pyext/repeated_scalar_container.cc
index 110a4c85..fd196836 100644
--- a/python/google/protobuf/pyext/repeated_scalar_container.cc
+++ b/python/google/protobuf/pyext/repeated_scalar_container.cc
@@ -102,7 +102,7 @@ static int AssignItem(RepeatedScalarContainer* self,
if (arg == NULL) {
ScopedPyObjectPtr py_index(PyLong_FromLong(index));
- return cmessage::InternalDeleteRepeatedField(message, field_descriptor,
+ return cmessage::InternalDeleteRepeatedField(self->parent, field_descriptor,
py_index, NULL);
}
@@ -470,7 +470,7 @@ static int AssSubscript(RepeatedScalarContainer* self,
if (value == NULL) {
return cmessage::InternalDeleteRepeatedField(
- message, field_descriptor, slice, NULL);
+ self->parent, field_descriptor, slice, NULL);
}
if (!create_list) {
@@ -769,9 +769,7 @@ static PyMethodDef Methods[] = {
PyTypeObject RepeatedScalarContainer_Type = {
PyVarObject_HEAD_INIT(&PyType_Type, 0)
- // Keep the fully qualified _message symbol in a line for opensource.
- "google.protobuf.pyext._message."
- "RepeatedScalarContainer", // tp_name
+ FULL_MODULE_NAME ".RepeatedScalarContainer", // tp_name
sizeof(RepeatedScalarContainer), // tp_basicsize
0, // tp_itemsize
(destructor)repeated_scalar_container::Dealloc, // tp_dealloc
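
Deletion on repeated scalar fields now goes through `self->parent` (the owning CMessage) instead of the raw Message pointer. From Python this is just the usual `del` protocol; a small sketch, again assuming `unittest_pb2`:

    # Illustrative sketch only; assumes unittest_pb2 has been generated.
    from google.protobuf import unittest_pb2

    msg = unittest_pb2.TestAllTypes()
    msg.repeated_int32.extend([1, 2, 3, 4])

    del msg.repeated_int32[1]     # item deletion -> AssignItem(..., NULL)
    del msg.repeated_int32[1:3]   # slice deletion -> AssSubscript(..., NULL)
    assert list(msg.repeated_int32) == [1]
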
diff --git a/python/google/protobuf/pyext/scalar_map_container.cc b/python/google/protobuf/pyext/scalar_map_container.cc
new file mode 100644
index 00000000..6f731d27
--- /dev/null
+++ b/python/google/protobuf/pyext/scalar_map_container.cc
@@ -0,0 +1,514 @@
+// Protocol Buffers - Google's data interchange format
+// Copyright 2008 Google Inc. All rights reserved.
+// https://developers.google.com/protocol-buffers/
+//
+// Redistribution and use in source and binary forms, with or without
+// modification, are permitted provided that the following conditions are
+// met:
+//
+// * Redistributions of source code must retain the above copyright
+// notice, this list of conditions and the following disclaimer.
+// * Redistributions in binary form must reproduce the above
+// copyright notice, this list of conditions and the following disclaimer
+// in the documentation and/or other materials provided with the
+// distribution.
+// * Neither the name of Google Inc. nor the names of its
+// contributors may be used to endorse or promote products derived from
+// this software without specific prior written permission.
+//
+// THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS
+// "AS IS" AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT
+// LIMITED TO, THE IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR
+// A PARTICULAR PURPOSE ARE DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT
+// OWNER OR CONTRIBUTORS BE LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL,
+// SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT
+// LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE,
+// DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY
+// THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT
+// (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
+// OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
+
+// Author: haberman@google.com (Josh Haberman)
+
+#include <google/protobuf/pyext/scalar_map_container.h>
+
+#include <google/protobuf/stubs/common.h>
+#include <google/protobuf/message.h>
+#include <google/protobuf/pyext/message.h>
+#include <google/protobuf/pyext/scoped_pyobject_ptr.h>
+
+namespace google {
+namespace protobuf {
+namespace python {
+
+struct ScalarMapIterator {
+ PyObject_HEAD;
+
+ // This dict contains the full contents of what we want to iterate over.
+ // There's no way to avoid building this, because the list representation
+ // (which is canonical) can contain duplicate keys. So at the very least we
+ // need a set that lets us skip duplicate keys. And at the point that we're
+ // doing that, we might as well just build the actual dict we're iterating
+ // over and use dict's built-in iterator.
+ PyObject* dict;
+
+ // An iterator on dict.
+ PyObject* iter;
+
+ // A pointer back to the container, so we can notice changes to the version.
+ ScalarMapContainer* container;
+
+ // The version of the map when we took the iterator to it.
+ //
+ // We store this so that if the map is modified during iteration we can throw
+ // an error.
+ uint64 version;
+};
+
+static ScalarMapIterator* GetIter(PyObject* obj) {
+ return reinterpret_cast<ScalarMapIterator*>(obj);
+}
+
+namespace scalar_map_container {
+
+static ScalarMapContainer* GetMap(PyObject* obj) {
+ return reinterpret_cast<ScalarMapContainer*>(obj);
+}
+
+// The private constructor of ScalarMapContainer objects.
+PyObject *NewContainer(
+ CMessage* parent, const google::protobuf::FieldDescriptor* parent_field_descriptor) {
+ if (!CheckFieldBelongsToMessage(parent_field_descriptor, parent->message)) {
+ return NULL;
+ }
+
+ ScopedPyObjectPtr obj(PyType_GenericAlloc(&ScalarMapContainer_Type, 0));
+ if (obj.get() == NULL) {
+ return PyErr_Format(PyExc_RuntimeError,
+ "Could not allocate new container.");
+ }
+
+ ScalarMapContainer* self = GetMap(obj);
+
+ self->message = parent->message;
+ self->parent = parent;
+ self->parent_field_descriptor = parent_field_descriptor;
+ self->owner = parent->owner;
+ self->version = 0;
+
+ self->key_field_descriptor =
+ parent_field_descriptor->message_type()->FindFieldByName("key");
+ self->value_field_descriptor =
+ parent_field_descriptor->message_type()->FindFieldByName("value");
+
+ if (self->key_field_descriptor == NULL ||
+ self->value_field_descriptor == NULL) {
+ return PyErr_Format(PyExc_KeyError,
+ "Map entry descriptor did not have key/value fields");
+ }
+
+ return obj.release();
+}
+
+// Initializes the underlying Message object of "to" so it becomes a new parent
+// map container, and copies all the values from "from" to it. A child map
+// container can be released by passing it as both from and to (e.g. making it
+// the recipient of the new parent message and copying the values from itself).
+static int InitializeAndCopyToParentContainer(
+ ScalarMapContainer* from,
+ ScalarMapContainer* to) {
+ // For now we require from == to, re-evaluate if we want to support deep copy
+ // as in repeated_scalar_container.cc.
+ GOOGLE_DCHECK(from == to);
+ Message* old_message = from->message;
+ Message* new_message = old_message->New();
+ to->parent = NULL;
+ to->parent_field_descriptor = from->parent_field_descriptor;
+ to->message = new_message;
+ to->owner.reset(new_message);
+
+ vector<const FieldDescriptor*> fields;
+ fields.push_back(from->parent_field_descriptor);
+ old_message->GetReflection()->SwapFields(old_message, new_message, fields);
+ return 0;
+}
+
+int Release(ScalarMapContainer* self) {
+ return InitializeAndCopyToParentContainer(self, self);
+}
+
+void SetOwner(ScalarMapContainer* self,
+ const shared_ptr<Message>& new_owner) {
+ self->owner = new_owner;
+}
+
+Py_ssize_t Length(PyObject* _self) {
+ ScalarMapContainer* self = GetMap(_self);
+ google::protobuf::Message* message = self->message;
+ return message->GetReflection()->FieldSize(*message,
+ self->parent_field_descriptor);
+}
+
+int MapKeyMatches(ScalarMapContainer* self, const Message* entry,
+ PyObject* key) {
+ // TODO(haberman): do we need more strict type checking?
+ ScopedPyObjectPtr entry_key(
+ cmessage::InternalGetScalar(entry, self->key_field_descriptor));
+ int ret = PyObject_RichCompareBool(key, entry_key, Py_EQ);
+ return ret;
+}
+
+PyObject* GetItem(PyObject* _self, PyObject* key) {
+ ScalarMapContainer* self = GetMap(_self);
+
+ Message* message = self->message;
+ const Reflection* reflection = message->GetReflection();
+
+ // Right now the Reflection API doesn't support map lookup, so we implement it
+ // via linear search.
+ //
+ // TODO(haberman): add lookup API to Reflection API.
+ size_t size = reflection->FieldSize(*message, self->parent_field_descriptor);
+ for (int i = size - 1; i >= 0; i--) {
+ const Message& entry = reflection->GetRepeatedMessage(
+ *message, self->parent_field_descriptor, i);
+ int matches = MapKeyMatches(self, &entry, key);
+ if (matches < 0) return NULL;
+ if (matches) {
+ return cmessage::InternalGetScalar(&entry, self->value_field_descriptor);
+ }
+ }
+
+ // Need to add a new entry.
+ Message* entry =
+ reflection->AddMessage(message, self->parent_field_descriptor);
+ PyObject* ret = NULL;
+
+ if (cmessage::InternalSetNonOneofScalar(entry, self->key_field_descriptor,
+ key) >= 0) {
+ ret = cmessage::InternalGetScalar(entry, self->value_field_descriptor);
+ }
+
+ self->version++;
+
+ // If there was a type error above, it set the Python exception.
+ return ret;
+}
+
+int SetItem(PyObject *_self, PyObject *key, PyObject *v) {
+ ScalarMapContainer* self = GetMap(_self);
+ cmessage::AssureWritable(self->parent);
+
+ Message* message = self->message;
+ const Reflection* reflection = message->GetReflection();
+ size_t size =
+ reflection->FieldSize(*message, self->parent_field_descriptor);
+ self->version++;
+
+ if (v) {
+ // Set item.
+ //
+ // Right now the Reflection API doesn't support map lookup, so we implement
+ // it via linear search.
+ //
+ // TODO(haberman): add lookup API to Reflection API.
+ for (int i = size - 1; i >= 0; i--) {
+ Message* entry = reflection->MutableRepeatedMessage(
+ message, self->parent_field_descriptor, i);
+ int matches = MapKeyMatches(self, entry, key);
+ if (matches < 0) return -1;
+ if (matches) {
+ return cmessage::InternalSetNonOneofScalar(
+ entry, self->value_field_descriptor, v);
+ }
+ }
+
+ // Key is not already present; insert a new entry.
+ Message* entry =
+ reflection->AddMessage(message, self->parent_field_descriptor);
+
+ if (cmessage::InternalSetNonOneofScalar(entry, self->key_field_descriptor,
+ key) < 0 ||
+ cmessage::InternalSetNonOneofScalar(entry, self->value_field_descriptor,
+ v) < 0) {
+ reflection->RemoveLast(message, self->parent_field_descriptor);
+ return -1;
+ }
+
+ return 0;
+ } else {
+ bool found = false;
+ for (int i = size - 1; i >= 0; i--) {
+ Message* entry = reflection->MutableRepeatedMessage(
+ message, self->parent_field_descriptor, i);
+ int matches = MapKeyMatches(self, entry, key);
+ if (matches < 0) return -1;
+ if (matches) {
+ found = true;
+ if (i != size - 1) {
+ reflection->SwapElements(message, self->parent_field_descriptor, i,
+ size - 1);
+ }
+ reflection->RemoveLast(message, self->parent_field_descriptor);
+
+        // Can't exit now: the repeated field representation of maps allows
+        // duplicate keys, and we have to be sure to remove all of them.
+ }
+ }
+
+ if (found) {
+ return 0;
+ } else {
+ PyErr_Format(PyExc_KeyError, "Key not present in map");
+ return -1;
+ }
+ }
+}
+
+PyObject* GetIterator(PyObject *_self) {
+ ScalarMapContainer* self = GetMap(_self);
+
+ ScopedPyObjectPtr obj(PyType_GenericAlloc(&ScalarMapIterator_Type, 0));
+ if (obj == NULL) {
+ return PyErr_Format(PyExc_KeyError, "Could not allocate iterator");
+ }
+
+ ScalarMapIterator* iter = GetIter(obj.get());
+
+ Py_INCREF(self);
+ iter->container = self;
+ iter->version = self->version;
+ iter->dict = PyDict_New();
+ if (iter->dict == NULL) {
+ return PyErr_Format(PyExc_RuntimeError,
+ "Could not allocate dict for iterator.");
+ }
+
+ // Build the entire map into a dict right now. Start from the beginning so
+ // that later entries win in the case of duplicates.
+ Message* message = self->message;
+ const Reflection* reflection = message->GetReflection();
+
+  // Right now the Reflection API doesn't support map lookup, so we iterate
+  // over the underlying repeated field. The representation can contain
+  // duplicate keys if a user calls MergeFrom(); we walk it front to back and
+  // let PyDict_SetItem overwrite, so the last entry for a key wins.
+ //
+ // TODO(haberman): add lookup API to Reflection API.
+ size_t size =
+ reflection->FieldSize(*message, self->parent_field_descriptor);
+ for (int i = 0; i < size; i++) {
+ Message* entry = reflection->MutableRepeatedMessage(
+ message, self->parent_field_descriptor, i);
+ ScopedPyObjectPtr key(
+ cmessage::InternalGetScalar(entry, self->key_field_descriptor));
+ ScopedPyObjectPtr val(
+ cmessage::InternalGetScalar(entry, self->value_field_descriptor));
+ if (PyDict_SetItem(iter->dict, key.get(), val.get()) < 0) {
+ return PyErr_Format(PyExc_RuntimeError,
+ "SetItem failed in iterator construction.");
+ }
+ }
+
+  iter->iter = PyObject_GetIter(iter->dict);
+
+ return obj.release();
+}
+
+PyObject* Clear(PyObject* _self) {
+ ScalarMapContainer* self = GetMap(_self);
+ cmessage::AssureWritable(self->parent);
+ Message* message = self->message;
+ const Reflection* reflection = message->GetReflection();
+
+ reflection->ClearField(message, self->parent_field_descriptor);
+
+ Py_RETURN_NONE;
+}
+
+PyObject* Contains(PyObject* _self, PyObject* key) {
+ ScalarMapContainer* self = GetMap(_self);
+
+ Message* message = self->message;
+ const Reflection* reflection = message->GetReflection();
+
+ // Right now the Reflection API doesn't support map lookup, so we implement it
+ // via linear search.
+ //
+ // TODO(haberman): add lookup API to Reflection API.
+ size_t size = reflection->FieldSize(*message, self->parent_field_descriptor);
+ for (int i = size - 1; i >= 0; i--) {
+ const Message& entry = reflection->GetRepeatedMessage(
+ *message, self->parent_field_descriptor, i);
+ int matches = MapKeyMatches(self, &entry, key);
+ if (matches < 0) return NULL;
+ if (matches) {
+ Py_RETURN_TRUE;
+ }
+ }
+
+ Py_RETURN_FALSE;
+}
+
+PyObject* Get(PyObject* self, PyObject* args) {
+ PyObject* key;
+ PyObject* default_value = NULL;
+ if (PyArg_ParseTuple(args, "O|O", &key, &default_value) < 0) {
+ return NULL;
+ }
+
+ ScopedPyObjectPtr is_present(Contains(self, key));
+ if (is_present.get() == NULL) {
+ return NULL;
+ }
+
+ if (PyObject_IsTrue(is_present.get())) {
+ return GetItem(self, key);
+ } else {
+ if (default_value != NULL) {
+ Py_INCREF(default_value);
+ return default_value;
+ } else {
+ Py_RETURN_NONE;
+ }
+ }
+}
+
+static PyMappingMethods MpMethods = {
+ Length, // mp_length
+ GetItem, // mp_subscript
+ SetItem, // mp_ass_subscript
+};
+
+static void Dealloc(PyObject* _self) {
+ ScalarMapContainer* self = GetMap(_self);
+ self->owner.reset();
+ Py_TYPE(_self)->tp_free(_self);
+}
+
+static PyMethodDef Methods[] = {
+ { "__contains__", Contains, METH_O,
+ "Tests whether a key is a member of the map." },
+ { "clear", (PyCFunction)Clear, METH_NOARGS,
+ "Removes all elements from the map." },
+ { "get", Get, METH_VARARGS,
+ "Gets the value for the given key if present, or otherwise a default" },
+ /*
+ { "__deepcopy__", (PyCFunction)DeepCopy, METH_VARARGS,
+ "Makes a deep copy of the class." },
+ { "__reduce__", (PyCFunction)Reduce, METH_NOARGS,
+ "Outputs picklable representation of the repeated field." },
+ */
+ {NULL, NULL},
+};
+
+} // namespace scalar_map_container
+
+namespace scalar_map_iterator {
+
+static void Dealloc(PyObject* _self) {
+ ScalarMapIterator* self = GetIter(_self);
+ Py_DECREF(self->dict);
+ Py_DECREF(self->iter);
+ Py_DECREF(self->container);
+ Py_TYPE(_self)->tp_free(_self);
+}
+
+PyObject* IterNext(PyObject* _self) {
+ ScalarMapIterator* self = GetIter(_self);
+
+ // This won't catch mutations to the map performed by MergeFrom(); no easy way
+ // to address that.
+ if (self->version != self->container->version) {
+ return PyErr_Format(PyExc_RuntimeError,
+ "Map modified during iteration.");
+ }
+
+ return PyIter_Next(self->iter);
+}
+
+} // namespace scalar_map_iterator
+
+PyTypeObject ScalarMapContainer_Type = {
+ PyVarObject_HEAD_INIT(&PyType_Type, 0)
+ FULL_MODULE_NAME ".ScalarMapContainer", // tp_name
+ sizeof(ScalarMapContainer), // tp_basicsize
+ 0, // tp_itemsize
+ scalar_map_container::Dealloc, // tp_dealloc
+ 0, // tp_print
+ 0, // tp_getattr
+ 0, // tp_setattr
+ 0, // tp_compare
+ 0, // tp_repr
+ 0, // tp_as_number
+ 0, // tp_as_sequence
+ &scalar_map_container::MpMethods, // tp_as_mapping
+ 0, // tp_hash
+ 0, // tp_call
+ 0, // tp_str
+ 0, // tp_getattro
+ 0, // tp_setattro
+ 0, // tp_as_buffer
+ Py_TPFLAGS_DEFAULT, // tp_flags
+ "A scalar map container", // tp_doc
+ 0, // tp_traverse
+ 0, // tp_clear
+ 0, // tp_richcompare
+ 0, // tp_weaklistoffset
+ scalar_map_container::GetIterator, // tp_iter
+ 0, // tp_iternext
+ scalar_map_container::Methods, // tp_methods
+ 0, // tp_members
+ 0, // tp_getset
+ 0, // tp_base
+ 0, // tp_dict
+ 0, // tp_descr_get
+ 0, // tp_descr_set
+ 0, // tp_dictoffset
+ 0, // tp_init
+};
+
+PyTypeObject ScalarMapIterator_Type = {
+ PyVarObject_HEAD_INIT(&PyType_Type, 0)
+ FULL_MODULE_NAME ".ScalarMapIterator", // tp_name
+ sizeof(ScalarMapIterator), // tp_basicsize
+ 0, // tp_itemsize
+ scalar_map_iterator::Dealloc, // tp_dealloc
+ 0, // tp_print
+ 0, // tp_getattr
+ 0, // tp_setattr
+ 0, // tp_compare
+ 0, // tp_repr
+ 0, // tp_as_number
+ 0, // tp_as_sequence
+ 0, // tp_as_mapping
+ 0, // tp_hash
+ 0, // tp_call
+ 0, // tp_str
+ 0, // tp_getattro
+ 0, // tp_setattro
+ 0, // tp_as_buffer
+ Py_TPFLAGS_DEFAULT, // tp_flags
+ "A scalar map iterator", // tp_doc
+ 0, // tp_traverse
+ 0, // tp_clear
+ 0, // tp_richcompare
+ 0, // tp_weaklistoffset
+ PyObject_SelfIter, // tp_iter
+ scalar_map_iterator::IterNext, // tp_iternext
+ 0, // tp_methods
+ 0, // tp_members
+ 0, // tp_getset
+ 0, // tp_base
+ 0, // tp_dict
+ 0, // tp_descr_get
+ 0, // tp_descr_set
+ 0, // tp_dictoffset
+ 0, // tp_init
+};
+
+} // namespace python
+} // namespace protobuf
+} // namespace google
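
Together these methods give scalar map fields dict-like behavior on top of the repeated map-entry representation: `[]` does a linear search and inserts a default entry on a miss, `get()` is built on `__contains__`, deletion removes every duplicate of a key, and each mutation bumps `version`. A behavior sketch, assuming `map_unittest_pb2` and its `TestMap.map_int32_int32` field:

    # Illustrative sketch only; assumes map_unittest_pb2 has been generated.
    from google.protobuf import map_unittest_pb2

    msg = map_unittest_pb2.TestMap()
    msg.map_int32_int32[5] = 10                  # SetItem()
    assert 5 in msg.map_int32_int32              # Contains()
    assert msg.map_int32_int32.get(6, 0) == 0    # Get() with a default

    # As in the C++ map API, reading a missing key inserts it with the
    # field type's default value (see GetItem() above).
    assert msg.map_int32_int32[6] == 0
    assert len(msg.map_int32_int32) == 2

    del msg.map_int32_int32[5]                   # SetItem() with v == NULL
    msg.map_int32_int32.clear()                  # Clear()
    assert len(msg.map_int32_int32) == 0
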
diff --git a/python/google/protobuf/pyext/scalar_map_container.h b/python/google/protobuf/pyext/scalar_map_container.h
new file mode 100644
index 00000000..254e6e98
--- /dev/null
+++ b/python/google/protobuf/pyext/scalar_map_container.h
@@ -0,0 +1,110 @@
+// Protocol Buffers - Google's data interchange format
+// Copyright 2008 Google Inc. All rights reserved.
+// https://developers.google.com/protocol-buffers/
+//
+// Redistribution and use in source and binary forms, with or without
+// modification, are permitted provided that the following conditions are
+// met:
+//
+// * Redistributions of source code must retain the above copyright
+// notice, this list of conditions and the following disclaimer.
+// * Redistributions in binary form must reproduce the above
+// copyright notice, this list of conditions and the following disclaimer
+// in the documentation and/or other materials provided with the
+// distribution.
+// * Neither the name of Google Inc. nor the names of its
+// contributors may be used to endorse or promote products derived from
+// this software without specific prior written permission.
+//
+// THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS
+// "AS IS" AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT
+// LIMITED TO, THE IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR
+// A PARTICULAR PURPOSE ARE DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT
+// OWNER OR CONTRIBUTORS BE LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL,
+// SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT
+// LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE,
+// DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY
+// THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT
+// (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
+// OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
+
+#ifndef GOOGLE_PROTOBUF_PYTHON_CPP_SCALAR_MAP_CONTAINER_H__
+#define GOOGLE_PROTOBUF_PYTHON_CPP_SCALAR_MAP_CONTAINER_H__
+
+#include <Python.h>
+
+#include <memory>
+#ifndef _SHARED_PTR_H
+#include <google/protobuf/stubs/shared_ptr.h>
+#endif
+
+#include <google/protobuf/descriptor.h>
+
+namespace google {
+namespace protobuf {
+
+class Message;
+
+using internal::shared_ptr;
+
+namespace python {
+
+struct CMessage;
+
+struct ScalarMapContainer {
+ PyObject_HEAD;
+
+ // This is the top-level C++ Message object that owns the whole
+ // proto tree. Every Python ScalarMapContainer holds a
+ // reference to it in order to keep it alive as long as there's a
+ // Python object that references any part of the tree.
+ shared_ptr<Message> owner;
+
+ // Pointer to the C++ Message that contains this container. The
+ // ScalarMapContainer does not own this pointer.
+ Message* message;
+
+  // Weak reference to a parent CMessage object (i.e. it may be NULL).
+ //
+ // Used to make sure all ancestors are also mutable when first
+ // modifying the container.
+ CMessage* parent;
+
+ // Pointer to the parent's descriptor that describes this
+ // field. Used together with the parent's message when making a
+ // default message instance mutable.
+ // The pointer is owned by the global DescriptorPool.
+ const FieldDescriptor* parent_field_descriptor;
+ const FieldDescriptor* key_field_descriptor;
+ const FieldDescriptor* value_field_descriptor;
+
+ // We bump this whenever we perform a mutation, to invalidate existing
+ // iterators.
+ uint64 version;
+};
+
+extern PyTypeObject ScalarMapContainer_Type;
+extern PyTypeObject ScalarMapIterator_Type;
+
+namespace scalar_map_container {
+
+// Builds a ScalarMapContainer object from a parent message and a
+// field descriptor.
+extern PyObject *NewContainer(
+ CMessage* parent, const FieldDescriptor* parent_field_descriptor);
+
+// Releases the messages in the container to a new message.
+//
+// Returns 0 on success, -1 on failure.
+int Release(ScalarMapContainer* self);
+
+// Set the owner field of self and any children of self.
+void SetOwner(ScalarMapContainer* self,
+ const shared_ptr<Message>& new_owner);
+
+} // namespace scalar_map_container
+} // namespace python
+} // namespace protobuf
+
+} // namespace google
+#endif // GOOGLE_PROTOBUF_PYTHON_CPP_SCALAR_MAP_CONTAINER_H__
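
The `version` counter is the whole iterator-invalidation scheme: GetIterator() snapshots the map into a dict and records the current version, and IterNext() refuses to continue once the container has been mutated (mutations made through MergeFrom() slip past the counter, as noted in IterNext()). A sketch of the resulting behavior, assuming `map_unittest_pb2` as above:

    # Illustrative sketch only; assumes map_unittest_pb2 has been generated.
    from google.protobuf import map_unittest_pb2

    msg = map_unittest_pb2.TestMap()
    msg.map_int32_int32[1] = 1
    msg.map_int32_int32[2] = 2

    try:
        for key in msg.map_int32_int32:
            # Any mutation bumps the container's version, so the next call
            # to IterNext() raises instead of yielding stale entries.
            msg.map_int32_int32[3] = 3
    except RuntimeError as err:
        assert 'modified during iteration' in str(err)
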
diff --git a/python/google/protobuf/reflection.py b/python/google/protobuf/reflection.py
index 55e653a0..82fca661 100755
--- a/python/google/protobuf/reflection.py
+++ b/python/google/protobuf/reflection.py
@@ -144,7 +144,6 @@ class GeneratedProtocolMessageType(type):
_InitMessage(descriptor, cls)
superclass = super(GeneratedProtocolMessageType, cls)
superclass.__init__(name, bases, dictionary)
- setattr(descriptor, '_concrete_class', cls)
def ParseMessage(descriptor, byte_str):
diff --git a/python/google/protobuf/text_format.py b/python/google/protobuf/text_format.py
index a47ce3e3..8cbd6822 100755
--- a/python/google/protobuf/text_format.py
+++ b/python/google/protobuf/text_format.py
@@ -100,6 +100,10 @@ def MessageToString(message, as_utf8=False, as_one_line=False,
return result.rstrip()
return result
+def _IsMapEntry(field):
+ return (field.type == descriptor.FieldDescriptor.TYPE_MESSAGE and
+ field.message_type.has_options and
+ field.message_type.GetOptions().map_entry)
def PrintMessage(message, out, indent=0, as_utf8=False, as_one_line=False,
pointy_brackets=False, use_index_order=False,
@@ -108,7 +112,19 @@ def PrintMessage(message, out, indent=0, as_utf8=False, as_one_line=False,
if use_index_order:
fields.sort(key=lambda x: x[0].index)
for field, value in fields:
- if field.label == descriptor.FieldDescriptor.LABEL_REPEATED:
+ if _IsMapEntry(field):
+ for key in value:
+      # This is slow for maps with submessage entries because it copies the
+ # entire tree. Unfortunately this would take significant refactoring
+ # of this file to work around.
+ #
+ # TODO(haberman): refactor and optimize if this becomes an issue.
+ entry_submsg = field.message_type._concrete_class(
+ key=key, value=value[key])
+ PrintField(field, entry_submsg, out, indent, as_utf8, as_one_line,
+ pointy_brackets=pointy_brackets,
+ use_index_order=use_index_order, float_format=float_format)
+ elif field.label == descriptor.FieldDescriptor.LABEL_REPEATED:
for element in value:
PrintField(field, element, out, indent, as_utf8, as_one_line,
pointy_brackets=pointy_brackets,
@@ -367,6 +383,7 @@ def _MergeField(tokenizer, message, allow_multiple_scalars):
message_descriptor.full_name, name))
if field.cpp_type == descriptor.FieldDescriptor.CPPTYPE_MESSAGE:
+ is_map_entry = _IsMapEntry(field)
tokenizer.TryConsume(':')
if tokenizer.TryConsume('<'):
@@ -378,6 +395,8 @@ def _MergeField(tokenizer, message, allow_multiple_scalars):
if field.label == descriptor.FieldDescriptor.LABEL_REPEATED:
if field.is_extension:
sub_message = message.Extensions[field].add()
+ elif is_map_entry:
+ sub_message = field.message_type._concrete_class()
else:
sub_message = getattr(message, field.name).add()
else:
@@ -391,6 +410,14 @@ def _MergeField(tokenizer, message, allow_multiple_scalars):
if tokenizer.AtEnd():
raise tokenizer.ParseErrorPreviousToken('Expected "%s".' % (end_token))
_MergeField(tokenizer, sub_message, allow_multiple_scalars)
+
+ if is_map_entry:
+ value_cpptype = field.message_type.fields_by_name['value'].cpp_type
+ if value_cpptype == descriptor.FieldDescriptor.CPPTYPE_MESSAGE:
+ value = getattr(message, field.name)[sub_message.key]
+ value.MergeFrom(sub_message.value)
+ else:
+ getattr(message, field.name)[sub_message.key] = sub_message.value
else:
_MergeScalarField(tokenizer, message, field, allow_multiple_scalars)
@@ -701,13 +728,16 @@ class _Tokenizer(object):
String literals (whether bytes or text) can come in multiple adjacent
tokens which are automatically concatenated, like in C or Python. This
method only consumes one token.
+
+ Raises:
+      ParseError: If the token is not a well-formed string literal.
"""
text = self.token
if len(text) < 1 or text[0] not in ('\'', '"'):
- raise self._ParseError('Expected string but found: "%r"' % text)
+ raise self._ParseError('Expected string but found: %r' % (text,))
if len(text) < 2 or text[-1] != text[0]:
- raise self._ParseError('String missing ending quote.')
+ raise self._ParseError('String missing ending quote: %r' % (text,))
try:
result = text_encoding.CUnescape(text[1:-1])
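
text_format handles map fields by synthesizing the hidden map-entry message: printing builds `field.message_type._concrete_class(key=..., value=...)` for each entry, and parsing merges every textual `key`/`value` pair back into the map. A round-trip sketch, assuming `map_unittest_pb2` as above:

    # Illustrative sketch only; assumes map_unittest_pb2 has been generated.
    from google.protobuf import text_format
    from google.protobuf import map_unittest_pb2

    msg = map_unittest_pb2.TestMap()
    msg.map_int32_int32[2] = 4

    text = text_format.MessageToString(msg)
    # Each entry prints as a key/value sub-message, e.g.:
    #   map_int32_int32 {
    #     key: 2
    #     value: 4
    #   }

    parsed = map_unittest_pb2.TestMap()
    text_format.Merge(text, parsed)
    assert parsed.map_int32_int32[2] == 4
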
diff --git a/python/setup.py b/python/setup.py
index a1365fba..5c321f50 100755
--- a/python/setup.py
+++ b/python/setup.py
@@ -91,6 +91,7 @@ def GenerateUnittestProtos():
if not os.path.exists("../.git"):
return
+ generate_proto("../src/google/protobuf/map_unittest.proto")
generate_proto("../src/google/protobuf/unittest.proto")
generate_proto("../src/google/protobuf/unittest_custom_options.proto")
generate_proto("../src/google/protobuf/unittest_import.proto")