aboutsummaryrefslogtreecommitdiffhomepage
path: root/python/google/protobuf/internal/python_message.py
diff options
context:
space:
mode:
authorGravatar Dan O'Reilly <oreilldf@gmail.com>2015-08-12 23:57:46 -0400
committerGravatar Dan O'Reilly <oreilldf@gmail.com>2015-08-12 23:57:46 -0400
commite47cdd5a559f488ba52756927ce68f4cf93874fa (patch)
tree8ce2723e822808baf58e96f569c86035717ea351 /python/google/protobuf/internal/python_message.py
parentdaeaa6a28b81195f24d89222e649d79c9555af8b (diff)
parent38a56ee4b19d72c2e9d81a08b018704d1addf561 (diff)
Merge remote-tracking branch 'upstream/master' into py2_py3_straddle
Conflicts: python/google/protobuf/descriptor_pool.py python/google/protobuf/internal/api_implementation_default_test.py python/google/protobuf/internal/cpp_message.py python/google/protobuf/internal/descriptor_database_test.py python/google/protobuf/internal/descriptor_pool_test.py python/google/protobuf/internal/descriptor_python_test.py python/google/protobuf/internal/descriptor_test.py python/google/protobuf/internal/generator_test.py python/google/protobuf/internal/message_factory_python_test.py python/google/protobuf/internal/message_factory_test.py python/google/protobuf/internal/message_test.py python/google/protobuf/internal/proto_builder_test.py python/google/protobuf/internal/python_message.py python/google/protobuf/internal/reflection_test.py python/google/protobuf/internal/service_reflection_test.py python/google/protobuf/internal/symbol_database_test.py python/google/protobuf/internal/text_encoding_test.py python/google/protobuf/internal/text_format_test.py python/google/protobuf/internal/unknown_fields_test.py python/google/protobuf/internal/wire_format_test.py python/google/protobuf/pyext/descriptor_cpp2_test.py python/google/protobuf/pyext/message_factory_cpp2_test.py python/google/protobuf/pyext/reflection_cpp2_generated_test.py python/setup.py ruby/lib/google/protobuf/message_exts.rb
Diffstat (limited to 'python/google/protobuf/internal/python_message.py')
-rwxr-xr-xpython/google/protobuf/internal/python_message.py235
1 files changed, 200 insertions, 35 deletions
diff --git a/python/google/protobuf/internal/python_message.py b/python/google/protobuf/internal/python_message.py
index 58c65db9..bb06beb3 100755
--- a/python/google/protobuf/internal/python_message.py
+++ b/python/google/protobuf/internal/python_message.py
@@ -59,6 +59,7 @@ import weakref
import six
import six.moves.copyreg as copyreg
+import six.string_types
# We use "as" to avoid name collisions with variables.
from google.protobuf.internal import containers
@@ -70,6 +71,7 @@ from google.protobuf.internal import type_checkers
from google.protobuf.internal import wire_format
from google.protobuf import descriptor as descriptor_mod
from google.protobuf import message as message_mod
+from google.protobuf import symbol_database
from google.protobuf import text_format
_FieldDescriptor = descriptor_mod.FieldDescriptor
@@ -94,6 +96,7 @@ def InitMessage(descriptor, cls):
for field in descriptor.fields:
_AttachFieldHelpers(cls, field)
+ descriptor._concrete_class = cls # pylint: disable=protected-access
_AddEnumValues(descriptor, cls)
_AddInitMethod(descriptor, cls)
_AddPropertiesForFields(descriptor, cls)
@@ -191,12 +194,37 @@ def _IsMessageSetExtension(field):
field.label == _FieldDescriptor.LABEL_OPTIONAL)
+def _IsMapField(field):
+ return (field.type == _FieldDescriptor.TYPE_MESSAGE and
+ field.message_type.has_options and
+ field.message_type.GetOptions().map_entry)
+
+
+def _IsMessageMapField(field):
+ value_type = field.message_type.fields_by_name["value"]
+ return value_type.cpp_type == _FieldDescriptor.CPPTYPE_MESSAGE
+
+
def _AttachFieldHelpers(cls, field_descriptor):
is_repeated = (field_descriptor.label == _FieldDescriptor.LABEL_REPEATED)
- is_packed = (field_descriptor.has_options and
- field_descriptor.GetOptions().packed)
-
- if _IsMessageSetExtension(field_descriptor):
+ is_packable = (is_repeated and
+ wire_format.IsTypePackable(field_descriptor.type))
+ if not is_packable:
+ is_packed = False
+ elif field_descriptor.containing_type.syntax == "proto2":
+ is_packed = (field_descriptor.has_options and
+ field_descriptor.GetOptions().packed)
+ else:
+ has_packed_false = (field_descriptor.has_options and
+ field_descriptor.GetOptions().HasField("packed") and
+ field_descriptor.GetOptions().packed == False)
+ is_packed = not has_packed_false
+ is_map_entry = _IsMapField(field_descriptor)
+
+ if is_map_entry:
+ field_encoder = encoder.MapEncoder(field_descriptor)
+ sizer = encoder.MapSizer(field_descriptor)
+ elif _IsMessageSetExtension(field_descriptor):
field_encoder = encoder.MessageSetItemEncoder(field_descriptor.number)
sizer = encoder.MessageSetItemSizer(field_descriptor.number)
else:
@@ -212,12 +240,27 @@ def _AttachFieldHelpers(cls, field_descriptor):
def AddDecoder(wiretype, is_packed):
tag_bytes = encoder.TagBytes(field_descriptor.number, wiretype)
- cls._decoders_by_tag[tag_bytes] = (
- type_checkers.TYPE_TO_DECODER[field_descriptor.type](
- field_descriptor.number, is_repeated, is_packed,
- field_descriptor, field_descriptor._default_constructor),
- field_descriptor if field_descriptor.containing_oneof is not None
- else None)
+ decode_type = field_descriptor.type
+ if (decode_type == _FieldDescriptor.TYPE_ENUM and
+ type_checkers.SupportsOpenEnums(field_descriptor)):
+ decode_type = _FieldDescriptor.TYPE_INT32
+
+ oneof_descriptor = None
+ if field_descriptor.containing_oneof is not None:
+ oneof_descriptor = field_descriptor
+
+ if is_map_entry:
+ is_message_map = _IsMessageMapField(field_descriptor)
+
+ field_decoder = decoder.MapDecoder(
+ field_descriptor, _GetInitializeDefaultForMap(field_descriptor),
+ is_message_map)
+ else:
+ field_decoder = type_checkers.TYPE_TO_DECODER[decode_type](
+ field_descriptor.number, is_repeated, is_packed,
+ field_descriptor, field_descriptor._default_constructor)
+
+ cls._decoders_by_tag[tag_bytes] = (field_decoder, oneof_descriptor)
AddDecoder(type_checkers.FIELD_TYPE_TO_WIRE_TYPE[field_descriptor.type],
False)
@@ -250,6 +293,26 @@ def _AddEnumValues(descriptor, cls):
setattr(cls, enum_value.name, enum_value.number)
+def _GetInitializeDefaultForMap(field):
+ if field.label != _FieldDescriptor.LABEL_REPEATED:
+ raise ValueError('map_entry set on non-repeated field %s' % (
+ field.name))
+ fields_by_name = field.message_type.fields_by_name
+ key_checker = type_checkers.GetTypeChecker(fields_by_name['key'])
+
+ value_field = fields_by_name['value']
+ if _IsMessageMapField(field):
+ def MakeMessageMapDefault(message):
+ return containers.MessageMap(
+ message._listener_for_children, value_field.message_type, key_checker)
+ return MakeMessageMapDefault
+ else:
+ value_checker = type_checkers.GetTypeChecker(value_field)
+ def MakePrimitiveMapDefault(message):
+ return containers.ScalarMap(
+ message._listener_for_children, key_checker, value_checker)
+ return MakePrimitiveMapDefault
+
def _DefaultValueConstructorForField(field):
"""Returns a function which returns a default value for a field.
@@ -264,6 +327,9 @@ def _DefaultValueConstructorForField(field):
value may refer back to |message| via a weak reference.
"""
+ if _IsMapField(field):
+ return _GetInitializeDefaultForMap(field)
+
if field.label == _FieldDescriptor.LABEL_REPEATED:
if field.has_default_value and field.default_value != []:
raise ValueError('Repeated field default value not empty list: %s' % (
@@ -289,6 +355,8 @@ def _DefaultValueConstructorForField(field):
def MakeSubMessageDefault(message):
result = message_type._concrete_class()
result._SetListener(message._listener_for_children)
+ if field.containing_oneof:
+ message._UpdateOneofState(field)
return result
return MakeSubMessageDefault
@@ -312,7 +380,22 @@ def _ReraiseTypeErrorWithFieldName(message_name, field_name):
def _AddInitMethod(message_descriptor, cls):
"""Adds an __init__ method to cls."""
- fields = message_descriptor.fields
+
+ def _GetIntegerEnumValue(enum_type, value):
+ """Convert a string or integer enum value to an integer.
+
+ If the value is a string, it is converted to the enum value in
+ enum_type with the same name. If the value is not a string, it's
+ returned as-is. (No conversion or bounds-checking is done.)
+ """
+ if isinstance(value, six.string_types):
+ try:
+ return enum_type.values_by_name[value].number
+ except KeyError:
+ raise ValueError('Enum type %s: unknown label "%s"' % (
+ enum_type.full_name, value))
+ return value
+
def init(self, **kwargs):
self._cached_byte_size = 0
self._cached_byte_size_dirty = len(kwargs) > 0
@@ -335,19 +418,37 @@ def _AddInitMethod(message_descriptor, cls):
if field.label == _FieldDescriptor.LABEL_REPEATED:
copy = field._default_constructor(self)
if field.cpp_type == _FieldDescriptor.CPPTYPE_MESSAGE: # Composite
- for val in field_value:
- copy.add().MergeFrom(val)
+ if _IsMapField(field):
+ if _IsMessageMapField(field):
+ for key in field_value:
+ copy[key].MergeFrom(field_value[key])
+ else:
+ copy.update(field_value)
+ else:
+ for val in field_value:
+ if isinstance(val, dict):
+ copy.add(**val)
+ else:
+ copy.add().MergeFrom(val)
else: # Scalar
+ if field.cpp_type == _FieldDescriptor.CPPTYPE_ENUM:
+ field_value = [_GetIntegerEnumValue(field.enum_type, val)
+ for val in field_value]
copy.extend(field_value)
self._fields[field] = copy
elif field.cpp_type == _FieldDescriptor.CPPTYPE_MESSAGE:
copy = field._default_constructor(self)
+ new_val = field_value
+ if isinstance(field_value, dict):
+ new_val = field.message_type._concrete_class(**field_value)
try:
- copy.MergeFrom(field_value)
+ copy.MergeFrom(new_val)
except TypeError:
_ReraiseTypeErrorWithFieldName(message_descriptor.name, field_name)
self._fields[field] = copy
else:
+ if field.cpp_type == _FieldDescriptor.CPPTYPE_ENUM:
+ field_value = _GetIntegerEnumValue(field.enum_type, field_value)
try:
setattr(self, field_name, field_value)
except TypeError:
@@ -469,6 +570,7 @@ def _AddPropertiesForNonRepeatedScalarField(field, cls):
type_checker = type_checkers.GetTypeChecker(field)
default_value = field.default_value
valid_values = set()
+ is_proto3 = field.containing_type.syntax == "proto3"
def getter(self):
# TODO(protobuf-team): This may be broken since there may not be
@@ -476,15 +578,24 @@ def _AddPropertiesForNonRepeatedScalarField(field, cls):
return self._fields.get(field, default_value)
getter.__module__ = None
getter.__doc__ = 'Getter for %s.' % proto_field_name
+
+ clear_when_set_to_default = is_proto3 and not field.containing_oneof
+
def field_setter(self, new_value):
# pylint: disable=protected-access
- self._fields[field] = type_checker.CheckValue(new_value)
+ # Testing the value for truthiness captures all of the proto3 defaults
+ # (0, 0.0, enum 0, and False).
+ new_value = type_checker.CheckValue(new_value)
+ if clear_when_set_to_default and not new_value:
+ self._fields.pop(field, None)
+ else:
+ self._fields[field] = new_value
# Check _cached_byte_size_dirty inline to improve performance, since scalar
# setters are called frequently.
if not self._cached_byte_size_dirty:
self._Modified()
- if field.containing_oneof is not None:
+ if field.containing_oneof:
def setter(self, new_value):
field_setter(self, new_value)
self._UpdateOneofState(field)
@@ -617,24 +728,35 @@ def _AddListFieldsMethod(message_descriptor, cls):
cls.ListFields = ListFields
+_Proto3HasError = 'Protocol message has no non-repeated submessage field "%s"'
+_Proto2HasError = 'Protocol message has no non-repeated field "%s"'
def _AddHasFieldMethod(message_descriptor, cls):
"""Helper for _AddMessageMethods()."""
- singular_fields = {}
+ is_proto3 = (message_descriptor.syntax == "proto3")
+ error_msg = _Proto3HasError if is_proto3 else _Proto2HasError
+
+ hassable_fields = {}
for field in message_descriptor.fields:
- if field.label != _FieldDescriptor.LABEL_REPEATED:
- singular_fields[field.name] = field
- # Fields inside oneofs are never repeated (enforced by the compiler).
- for field in message_descriptor.oneofs:
- singular_fields[field.name] = field
+ if field.label == _FieldDescriptor.LABEL_REPEATED:
+ continue
+ # For proto3, only submessages and fields inside a oneof have presence.
+ if (is_proto3 and field.cpp_type != _FieldDescriptor.CPPTYPE_MESSAGE and
+ not field.containing_oneof):
+ continue
+ hassable_fields[field.name] = field
+
+ if not is_proto3:
+ # Fields inside oneofs are never repeated (enforced by the compiler).
+ for oneof in message_descriptor.oneofs:
+ hassable_fields[oneof.name] = oneof
def HasField(self, field_name):
try:
- field = singular_fields[field_name]
+ field = hassable_fields[field_name]
except KeyError:
- raise ValueError(
- 'Protocol message has no singular "%s" field.' % field_name)
+ raise ValueError(error_msg % field_name)
if isinstance(field, descriptor_mod.OneofDescriptor):
try:
@@ -720,6 +842,26 @@ def _AddHasExtensionMethod(cls):
return extension_handle in self._fields
cls.HasExtension = HasExtension
+def _UnpackAny(msg):
+ type_url = msg.type_url
+ db = symbol_database.Default()
+
+ if not type_url:
+ return None
+
+ # TODO(haberman): For now we just strip the hostname. Better logic will be
+ # required.
+ type_name = type_url.split("/")[-1]
+ descriptor = db.pool.FindMessageTypeByName(type_name)
+
+ if descriptor is None:
+ return None
+
+ message_class = db.GetPrototype(descriptor)
+ message = message_class()
+
+ message.ParseFromString(msg.value)
+ return message
def _AddEqualsMethod(message_descriptor, cls):
"""Helper for _AddMessageMethods()."""
@@ -731,6 +873,12 @@ def _AddEqualsMethod(message_descriptor, cls):
if self is other:
return True
+ if self.DESCRIPTOR.full_name == "google.protobuf.Any":
+ any_a = _UnpackAny(self)
+ any_b = _UnpackAny(other)
+ if any_a and any_b:
+ return any_a == any_b
+
if not self.ListFields() == other.ListFields():
return False
@@ -864,6 +1012,7 @@ def _AddMergeFromStringMethod(message_descriptor, cls):
local_ReadTag = decoder.ReadTag
local_SkipField = decoder.SkipField
decoders_by_tag = cls._decoders_by_tag
+ is_proto3 = message_descriptor.syntax == "proto3"
def InternalParse(self, buffer, pos, end):
self._Modified()
@@ -877,9 +1026,11 @@ def _AddMergeFromStringMethod(message_descriptor, cls):
new_pos = local_SkipField(buffer, new_pos, end, tag_bytes)
if new_pos == -1:
return pos
- if not unknown_field_list:
- unknown_field_list = self._unknown_fields = []
- unknown_field_list.append((tag_bytes, buffer[value_start_pos:new_pos]))
+ if not is_proto3:
+ if not unknown_field_list:
+ unknown_field_list = self._unknown_fields = []
+ unknown_field_list.append(
+ (tag_bytes, buffer[value_start_pos:new_pos]))
pos = new_pos
else:
pos = field_decoder(buffer, new_pos, end, self, field_dict)
@@ -920,6 +1071,9 @@ def _AddIsInitializedMethod(message_descriptor, cls):
for field, value in list(self._fields.items()): # dict can change size!
if field.cpp_type == _FieldDescriptor.CPPTYPE_MESSAGE:
if field.label == _FieldDescriptor.LABEL_REPEATED:
+ if (field.message_type.has_options and
+ field.message_type.GetOptions().map_entry):
+ continue
for element in value:
if not element.IsInitialized():
if errors is not None:
@@ -955,16 +1109,26 @@ def _AddIsInitializedMethod(message_descriptor, cls):
else:
name = field.name
- if field.label == _FieldDescriptor.LABEL_REPEATED:
- for i in range(len(value)):
+ if _IsMapField(field):
+ if _IsMessageMapField(field):
+ for key in value:
+ element = value[key]
+ prefix = "%s[%d]." % (name, key)
+ sub_errors = element.FindInitializationErrors()
+ errors += [prefix + error for error in sub_errors]
+ else:
+ # ScalarMaps can't have any initialization errors.
+ pass
+ elif field.label == _FieldDescriptor.LABEL_REPEATED:
+ for i in xrange(len(value)):
element = value[i]
prefix = "%s[%d]." % (name, i)
sub_errors = element.FindInitializationErrors()
- errors += [ prefix + error for error in sub_errors ]
+ errors += [prefix + error for error in sub_errors]
else:
prefix = name + "."
sub_errors = value.FindInitializationErrors()
- errors += [ prefix + error for error in sub_errors ]
+ errors += [prefix + error for error in sub_errors]
return errors
@@ -1001,6 +1165,8 @@ def _AddMergeFromMethod(cls):
# Construct a new object to represent this field.
field_value = field._default_constructor(self)
fields[field] = field_value
+ if field.containing_oneof:
+ self._UpdateOneofState(field)
field_value.MergeFrom(value)
else:
self._fields[field] = value
@@ -1245,11 +1411,10 @@ class _ExtensionDict(object):
# It's slightly wasteful to lookup the type checker each time,
# but we expect this to be a vanishingly uncommon case anyway.
- type_checker = type_checkers.GetTypeChecker(
- extension_handle)
+ type_checker = type_checkers.GetTypeChecker(extension_handle)
# pylint: disable=protected-access
self._extended_message._fields[extension_handle] = (
- type_checker.CheckValue(value))
+ type_checker.CheckValue(value))
self._extended_message._Modified()
def _FindExtensionByName(self, name):