aboutsummaryrefslogtreecommitdiffhomepage
path: root/python/google/protobuf/internal/python_message.py
diff options
context:
space:
mode:
authorGravatar Jisi Liu <jisi.liu@gmail.com>2015-02-25 16:39:11 -0800
committerGravatar Jisi Liu <jisi.liu@gmail.com>2015-02-25 16:39:11 -0800
commitada65567852b96fdb4d070c0c3f86ca7b77824f9 (patch)
treea506994ce921ace3e6f88ca130a17af7f85c3d0f /python/google/protobuf/internal/python_message.py
parent581be24606a925d038f382dc4c86256e2d29e001 (diff)
Down integrate from Google internal.
Change-Id: I34d301133eea9c6f3a822c47d1f91e136fd33145
Diffstat (limited to 'python/google/protobuf/internal/python_message.py')
-rwxr-xr-xpython/google/protobuf/internal/python_message.py81
1 files changed, 58 insertions, 23 deletions
diff --git a/python/google/protobuf/internal/python_message.py b/python/google/protobuf/internal/python_message.py
index 6fda6ae0..6ad0f90d 100755
--- a/python/google/protobuf/internal/python_message.py
+++ b/python/google/protobuf/internal/python_message.py
@@ -219,12 +219,20 @@ def _AttachFieldHelpers(cls, field_descriptor):
def AddDecoder(wiretype, is_packed):
tag_bytes = encoder.TagBytes(field_descriptor.number, wiretype)
- cls._decoders_by_tag[tag_bytes] = (
- type_checkers.TYPE_TO_DECODER[field_descriptor.type](
- field_descriptor.number, is_repeated, is_packed,
- field_descriptor, field_descriptor._default_constructor),
- field_descriptor if field_descriptor.containing_oneof is not None
- else None)
+ decode_type = field_descriptor.type
+ if (decode_type == _FieldDescriptor.TYPE_ENUM and
+ type_checkers.SupportsOpenEnums(field_descriptor)):
+ decode_type = _FieldDescriptor.TYPE_INT32
+
+ oneof_descriptor = None
+ if field_descriptor.containing_oneof is not None:
+ oneof_descriptor = field_descriptor
+
+ field_decoder = type_checkers.TYPE_TO_DECODER[decode_type](
+ field_descriptor.number, is_repeated, is_packed,
+ field_descriptor, field_descriptor._default_constructor)
+
+ cls._decoders_by_tag[tag_bytes] = (field_decoder, oneof_descriptor)
AddDecoder(type_checkers.FIELD_TYPE_TO_WIRE_TYPE[field_descriptor.type],
False)
@@ -296,6 +304,8 @@ def _DefaultValueConstructorForField(field):
def MakeSubMessageDefault(message):
result = message_type._concrete_class()
result._SetListener(message._listener_for_children)
+ if field.containing_oneof:
+ message._UpdateOneofState(field)
return result
return MakeSubMessageDefault
@@ -476,6 +486,7 @@ def _AddPropertiesForNonRepeatedScalarField(field, cls):
type_checker = type_checkers.GetTypeChecker(field)
default_value = field.default_value
valid_values = set()
+ is_proto3 = field.containing_type.syntax == "proto3"
def getter(self):
# TODO(protobuf-team): This may be broken since there may not be
@@ -483,15 +494,24 @@ def _AddPropertiesForNonRepeatedScalarField(field, cls):
return self._fields.get(field, default_value)
getter.__module__ = None
getter.__doc__ = 'Getter for %s.' % proto_field_name
+
+ clear_when_set_to_default = is_proto3 and not field.containing_oneof
+
def field_setter(self, new_value):
# pylint: disable=protected-access
- self._fields[field] = type_checker.CheckValue(new_value)
+ # Testing the value for truthiness captures all of the proto3 defaults
+ # (0, 0.0, enum 0, and False).
+ new_value = type_checker.CheckValue(new_value)
+ if clear_when_set_to_default and not new_value:
+ self._fields.pop(field, None)
+ else:
+ self._fields[field] = new_value
# Check _cached_byte_size_dirty inline to improve performance, since scalar
# setters are called frequently.
if not self._cached_byte_size_dirty:
self._Modified()
- if field.containing_oneof is not None:
+ if field.containing_oneof:
def setter(self, new_value):
field_setter(self, new_value)
self._UpdateOneofState(field)
@@ -624,24 +644,35 @@ def _AddListFieldsMethod(message_descriptor, cls):
cls.ListFields = ListFields
+_Proto3HasError = 'Protocol message has no non-repeated submessage field "%s"'
+_Proto2HasError = 'Protocol message has no non-repeated field "%s"'
def _AddHasFieldMethod(message_descriptor, cls):
"""Helper for _AddMessageMethods()."""
- singular_fields = {}
+ is_proto3 = (message_descriptor.syntax == "proto3")
+ error_msg = _Proto3HasError if is_proto3 else _Proto2HasError
+
+ hassable_fields = {}
for field in message_descriptor.fields:
- if field.label != _FieldDescriptor.LABEL_REPEATED:
- singular_fields[field.name] = field
- # Fields inside oneofs are never repeated (enforced by the compiler).
- for field in message_descriptor.oneofs:
- singular_fields[field.name] = field
+ if field.label == _FieldDescriptor.LABEL_REPEATED:
+ continue
+ # For proto3, only submessages and fields inside a oneof have presence.
+ if (is_proto3 and field.cpp_type != _FieldDescriptor.CPPTYPE_MESSAGE and
+ not field.containing_oneof):
+ continue
+ hassable_fields[field.name] = field
+
+ if not is_proto3:
+ # Fields inside oneofs are never repeated (enforced by the compiler).
+ for oneof in message_descriptor.oneofs:
+ hassable_fields[oneof.name] = oneof
def HasField(self, field_name):
try:
- field = singular_fields[field_name]
+ field = hassable_fields[field_name]
except KeyError:
- raise ValueError(
- 'Protocol message has no singular "%s" field.' % field_name)
+ raise ValueError(error_msg % field_name)
if isinstance(field, descriptor_mod.OneofDescriptor):
try:
@@ -871,6 +902,7 @@ def _AddMergeFromStringMethod(message_descriptor, cls):
local_ReadTag = decoder.ReadTag
local_SkipField = decoder.SkipField
decoders_by_tag = cls._decoders_by_tag
+ is_proto3 = message_descriptor.syntax == "proto3"
def InternalParse(self, buffer, pos, end):
self._Modified()
@@ -884,9 +916,11 @@ def _AddMergeFromStringMethod(message_descriptor, cls):
new_pos = local_SkipField(buffer, new_pos, end, tag_bytes)
if new_pos == -1:
return pos
- if not unknown_field_list:
- unknown_field_list = self._unknown_fields = []
- unknown_field_list.append((tag_bytes, buffer[value_start_pos:new_pos]))
+ if not is_proto3:
+ if not unknown_field_list:
+ unknown_field_list = self._unknown_fields = []
+ unknown_field_list.append(
+ (tag_bytes, buffer[value_start_pos:new_pos]))
pos = new_pos
else:
pos = field_decoder(buffer, new_pos, end, self, field_dict)
@@ -1008,6 +1042,8 @@ def _AddMergeFromMethod(cls):
# Construct a new object to represent this field.
field_value = field._default_constructor(self)
fields[field] = field_value
+ if field.containing_oneof:
+ self._UpdateOneofState(field)
field_value.MergeFrom(value)
else:
self._fields[field] = value
@@ -1252,11 +1288,10 @@ class _ExtensionDict(object):
# It's slightly wasteful to lookup the type checker each time,
# but we expect this to be a vanishingly uncommon case anyway.
- type_checker = type_checkers.GetTypeChecker(
- extension_handle)
+ type_checker = type_checkers.GetTypeChecker(extension_handle)
# pylint: disable=protected-access
self._extended_message._fields[extension_handle] = (
- type_checker.CheckValue(value))
+ type_checker.CheckValue(value))
self._extended_message._Modified()
def _FindExtensionByName(self, name):