diff options
Diffstat (limited to 'python/google/protobuf/internal/python_message.py')
-rwxr-xr-x | python/google/protobuf/internal/python_message.py | 151 |
1 files changed, 124 insertions, 27 deletions
diff --git a/python/google/protobuf/internal/python_message.py b/python/google/protobuf/internal/python_message.py index 4bea57ac..9ee352d6 100755 --- a/python/google/protobuf/internal/python_message.py +++ b/python/google/protobuf/internal/python_message.py @@ -28,6 +28,10 @@ # (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE # OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. +# Keep it Python2.5 compatible for GAE. +# +# Copyright 2007 Google Inc. All Rights Reserved. +# # This code is meant to work on Python 2.4 and above only. # # TODO(robinson): Helpers for verbose, common checks like seeing if a @@ -50,11 +54,16 @@ this file*. __author__ = 'robinson@google.com (Will Robinson)' -try: - from cStringIO import StringIO -except ImportError: - from StringIO import StringIO -import copy_reg +import sys +if sys.version_info[0] < 3: + try: + from cStringIO import StringIO as BytesIO + except ImportError: + from StringIO import StringIO as BytesIO + import copy_reg as copyreg +else: + from io import BytesIO + import copyreg import struct import weakref @@ -98,8 +107,8 @@ def InitMessage(descriptor, cls): _AddPropertiesForExtensions(descriptor, cls) _AddStaticMethods(cls) _AddMessageMethods(descriptor, cls) - _AddPrivateHelperMethods(cls) - copy_reg.pickle(cls, lambda obj: (cls, (), obj.__getstate__())) + _AddPrivateHelperMethods(descriptor, cls) + copyreg.pickle(cls, lambda obj: (cls, (), obj.__getstate__())) # Stateless helpers for GeneratedProtocolMessageType below. @@ -176,7 +185,8 @@ def _AddSlots(message_descriptor, dictionary): '_is_present_in_parent', '_listener', '_listener_for_children', - '__weakref__'] + '__weakref__', + '_oneofs'] def _IsMessageSetExtension(field): @@ -272,7 +282,7 @@ def _DefaultValueConstructorForField(field): message._listener_for_children, field.message_type) return MakeRepeatedMessageDefault else: - type_checker = type_checkers.GetTypeChecker(field.cpp_type, field.type) + type_checker = type_checkers.GetTypeChecker(field) def MakeRepeatedScalarDefault(message): return containers.RepeatedScalarFieldContainer( message._listener_for_children, type_checker) @@ -301,6 +311,10 @@ def _AddInitMethod(message_descriptor, cls): self._cached_byte_size = 0 self._cached_byte_size_dirty = len(kwargs) > 0 self._fields = {} + # Contains a mapping from oneof field descriptors to the descriptor + # of the currently set field in that oneof field. + self._oneofs = {} + # _unknown_fields is () when empty for efficiency, and will be turned into # a list if fields are added. self._unknown_fields = () @@ -440,7 +454,7 @@ def _AddPropertiesForNonRepeatedScalarField(field, cls): """ proto_field_name = field.name property_name = _PropertyName(proto_field_name) - type_checker = type_checkers.GetTypeChecker(field.cpp_type, field.type) + type_checker = type_checkers.GetTypeChecker(field) default_value = field.default_value valid_values = set() @@ -450,14 +464,21 @@ def _AddPropertiesForNonRepeatedScalarField(field, cls): return self._fields.get(field, default_value) getter.__module__ = None getter.__doc__ = 'Getter for %s.' % proto_field_name - def setter(self, new_value): - type_checker.CheckValue(new_value) - self._fields[field] = new_value + def field_setter(self, new_value): + # pylint: disable=protected-access + self._fields[field] = type_checker.CheckValue(new_value) # Check _cached_byte_size_dirty inline to improve performance, since scalar # setters are called frequently. if not self._cached_byte_size_dirty: self._Modified() + if field.containing_oneof is not None: + def setter(self, new_value): + field_setter(self, new_value) + self._UpdateOneofState(field) + else: + setter = field_setter + setter.__module__ = None setter.__doc__ = 'Setter for %s.' % proto_field_name @@ -493,7 +514,10 @@ def _AddPropertiesForNonRepeatedCompositeField(field, cls): if field_value is None: # Construct a new object to represent this field. field_value = message_type._concrete_class() # use field.message_type? - field_value._SetListener(self._listener_for_children) + field_value._SetListener( + _OneofListener(self, field) + if field.containing_oneof is not None + else self._listener_for_children) # Atomically check if another thread has preempted us and, if not, swap # in the new object we just created. If someone has preempted us, we @@ -589,6 +613,9 @@ def _AddHasFieldMethod(message_descriptor, cls): for field in message_descriptor.fields: if field.label != _FieldDescriptor.LABEL_REPEATED: singular_fields[field.name] = field + # Fields inside oneofs are never repeated (enforced by the compiler). + for field in message_descriptor.oneofs: + singular_fields[field.name] = field def HasField(self, field_name): try: @@ -597,11 +624,18 @@ def _AddHasFieldMethod(message_descriptor, cls): raise ValueError( 'Protocol message has no singular "%s" field.' % field_name) - if field.cpp_type == _FieldDescriptor.CPPTYPE_MESSAGE: - value = self._fields.get(field) - return value is not None and value._is_present_in_parent + if isinstance(field, descriptor_mod.OneofDescriptor): + try: + return HasField(self, self._oneofs[field].name) + except KeyError: + return False else: - return field in self._fields + if field.cpp_type == _FieldDescriptor.CPPTYPE_MESSAGE: + value = self._fields.get(field) + return value is not None and value._is_present_in_parent + else: + return field in self._fields + cls.HasField = HasField @@ -611,7 +645,14 @@ def _AddClearFieldMethod(message_descriptor, cls): try: field = message_descriptor.fields_by_name[field_name] except KeyError: - raise ValueError('Protocol message has no "%s" field.' % field_name) + try: + field = message_descriptor.oneofs_by_name[field_name] + if field in self._oneofs: + field = self._oneofs[field] + else: + return + except KeyError: + raise ValueError('Protocol message has no "%s" field.' % field_name) if field in self._fields: # Note: If the field is a sub-message, its listener will still point @@ -619,6 +660,9 @@ def _AddClearFieldMethod(message_descriptor, cls): # will call _Modified() and invalidate our byte size. Big deal. del self._fields[field] + if self._oneofs.get(field.containing_oneof, None) is field: + del self._oneofs[field.containing_oneof] + # Always call _Modified() -- even if nothing was changed, this is # a mutating method, and thus calling it should cause the field to become # present in the parent message. @@ -773,7 +817,7 @@ def _AddSerializePartialToStringMethod(message_descriptor, cls): """Helper for _AddMessageMethods().""" def SerializePartialToString(self): - out = StringIO() + out = BytesIO() self._InternalSerialize(out.write) return out.getvalue() cls.SerializePartialToString = SerializePartialToString @@ -796,7 +840,8 @@ def _AddMergeFromStringMethod(message_descriptor, cls): # The only reason _InternalParse would return early is if it # encountered an end-group tag. raise message_mod.DecodeError('Unexpected end-group tag.') - except IndexError: + except (IndexError, TypeError): + # Now ord(buf[p:p+1]) == ord('') gets TypeError. raise message_mod.DecodeError('Truncated message.') except struct.error, e: raise message_mod.DecodeError(e) @@ -857,7 +902,7 @@ def _AddIsInitializedMethod(message_descriptor, cls): errors.extend(self.FindInitializationErrors()) return False - for field, value in self._fields.iteritems(): + for field, value in list(self._fields.items()): # dict can change size! if field.cpp_type == _FieldDescriptor.CPPTYPE_MESSAGE: if field.label == _FieldDescriptor.LABEL_REPEATED: for element in value: @@ -953,6 +998,24 @@ def _AddMergeFromMethod(cls): cls.MergeFrom = MergeFrom +def _AddWhichOneofMethod(message_descriptor, cls): + def WhichOneof(self, oneof_name): + """Returns the name of the currently set field inside a oneof, or None.""" + try: + field = message_descriptor.oneofs_by_name[oneof_name] + except KeyError: + raise ValueError( + 'Protocol message has no oneof "%s" field.' % oneof_name) + + nested_field = self._oneofs.get(field, None) + if nested_field is not None and self.HasField(nested_field.name): + return nested_field.name + else: + return None + + cls.WhichOneof = WhichOneof + + def _AddMessageMethods(message_descriptor, cls): """Adds implementations of all Message methods to cls.""" _AddListFieldsMethod(message_descriptor, cls) @@ -972,9 +1035,9 @@ def _AddMessageMethods(message_descriptor, cls): _AddMergeFromStringMethod(message_descriptor, cls) _AddIsInitializedMethod(message_descriptor, cls) _AddMergeFromMethod(cls) + _AddWhichOneofMethod(message_descriptor, cls) - -def _AddPrivateHelperMethods(cls): +def _AddPrivateHelperMethods(message_descriptor, cls): """Adds implementation of private helper methods to cls.""" def Modified(self): @@ -992,8 +1055,20 @@ def _AddPrivateHelperMethods(cls): self._is_present_in_parent = True self._listener.Modified() + def _UpdateOneofState(self, field): + """Sets field as the active field in its containing oneof. + + Will also delete currently active field in the oneof, if it is different + from the argument. Does not mark the message as modified. + """ + other_field = self._oneofs.setdefault(field.containing_oneof, field) + if other_field is not field: + del self._fields[other_field] + self._oneofs[field.containing_oneof] = field + cls._Modified = Modified cls.SetInParent = Modified + cls._UpdateOneofState = _UpdateOneofState class _Listener(object): @@ -1042,6 +1117,27 @@ class _Listener(object): pass +class _OneofListener(_Listener): + """Special listener implementation for setting composite oneof fields.""" + + def __init__(self, parent_message, field): + """Args: + parent_message: The message whose _Modified() method we should call when + we receive Modified() messages. + field: The descriptor of the field being set in the parent message. + """ + super(_OneofListener, self).__init__(parent_message) + self._field = field + + def Modified(self): + """Also updates the state of the containing oneof in the parent message.""" + try: + self._parent_message_weakref._UpdateOneofState(self._field) + super(_OneofListener, self).Modified() + except ReferenceError: + pass + + # TODO(robinson): Move elsewhere? This file is getting pretty ridiculous... # TODO(robinson): Unify error handling of "unknown extension" crap. # TODO(robinson): Support iteritems()-style iteration over all @@ -1133,9 +1229,10 @@ class _ExtensionDict(object): # It's slightly wasteful to lookup the type checker each time, # but we expect this to be a vanishingly uncommon case anyway. type_checker = type_checkers.GetTypeChecker( - extension_handle.cpp_type, extension_handle.type) - type_checker.CheckValue(value) - self._extended_message._fields[extension_handle] = value + extension_handle) + # pylint: disable=protected-access + self._extended_message._fields[extension_handle] = ( + type_checker.CheckValue(value)) self._extended_message._Modified() def _FindExtensionByName(self, name): |