about summary refs log tree commit diff homepage
path: root/python/google
diff options
context:
space:
mode:
author kenton@google.com <kenton@google.com@630680e5-0e50-0410-840e-4b1c322b438d> 2009-12-18 02:11:36 +0000
committer kenton@google.com <kenton@google.com@630680e5-0e50-0410-840e-4b1c322b438d> 2009-12-18 02:11:36 +0000
commit fccb146e3fe437b0df1e9c50d4b8e1080ddb4bd9 (patch)
tree 9f2d9fe0267d96a54e541377ffeada3d0bff0d1d /python/google
parent d5cf7b55a6a1f959d1646785f63ca2b62da78079 (diff)
Massive roll-up of changes. See CHANGES.txt.
Diffstat (limited to 'python/google')
-rwxr-xr-x python/google/protobuf/descriptor.py 211
-rwxr-xr-x python/google/protobuf/internal/containers.py 69
-rwxr-xr-x python/google/protobuf/internal/decoder.py 770
-rwxr-xr-x python/google/protobuf/internal/decoder_test.py 256
-rwxr-xr-x python/google/protobuf/internal/descriptor_test.py 227
-rwxr-xr-x python/google/protobuf/internal/encoder.py 888
-rwxr-xr-x python/google/protobuf/internal/encoder_test.py 286
-rwxr-xr-x python/google/protobuf/internal/generator_test.py 107
-rwxr-xr-x python/google/protobuf/internal/input_stream.py 338
-rwxr-xr-x python/google/protobuf/internal/input_stream_test.py 314
-rwxr-xr-x python/google/protobuf/internal/message_listener.py 41
-rwxr-xr-x python/google/protobuf/internal/message_test.py 42
-rwxr-xr-x python/google/protobuf/internal/output_stream.py 125
-rwxr-xr-x python/google/protobuf/internal/output_stream_test.py 178
-rwxr-xr-x python/google/protobuf/internal/reflection_test.py 354
-rwxr-xr-x python/google/protobuf/internal/test_util.py 467
-rwxr-xr-x python/google/protobuf/internal/text_format_test.py 25
-rwxr-xr-x python/google/protobuf/internal/type_checkers.py 127
-rwxr-xr-x python/google/protobuf/internal/wire_format.py 27
-rwxr-xr-x python/google/protobuf/message.py 11
-rwxr-xr-x python/google/protobuf/reflection.py 1487
-rwxr-xr-x python/google/protobuf/text_format.py 5
22 files changed, 3032 insertions, 3323 deletions
diff --git a/python/google/protobuf/descriptor.py b/python/google/protobuf/descriptor.py
index 8e3fc2eb..aa4ab969 100755
--- a/python/google/protobuf/descriptor.py
+++ b/python/google/protobuf/descriptor.py
@@ -44,12 +44,24 @@ file, in types that make this information accessible in Python.
__author__ = 'robinson@google.com (Will Robinson)'
+
+class Error(Exception):
+ """Base error for this module."""
+
+
class DescriptorBase(object):
"""Descriptors base class.
This class is the base of all descriptor classes. It provides common options
related functionaility.
+
+ Attributes:
+ has_options: True if the descriptor has non-default options. Usually it
+ is not necessary to read this -- just call GetOptions() which will
+ happily return the default instance. However, it's sometimes useful
+ for efficiency, and also useful inside the protobuf implementation to
+ avoid some bootstrapping issues.
"""
def __init__(self, options, options_class_name):
@@ -60,6 +72,9 @@ class DescriptorBase(object):
self._options = options
self._options_class_name = options_class_name
+ # Does this descriptor have non-default options?
+ self.has_options = options is not None
+
def GetOptions(self):
"""Retrieves descriptor options.
@@ -78,7 +93,70 @@ class DescriptorBase(object):
return self._options
-class Descriptor(DescriptorBase):
+class _NestedDescriptorBase(DescriptorBase):
+ """Common class for descriptors that can be nested."""
+
+ def __init__(self, options, options_class_name, name, full_name,
+ file, containing_type, serialized_start=None,
+ serialized_end=None):
+ """Constructor.
+
+ Args:
+ options: Protocol message options or None
+ to use default message options.
+ options_class_name: (str) The class name of the above options.
+
+ name: (str) Name of this protocol message type.
+ full_name: (str) Fully-qualified name of this protocol message type,
+ which will include protocol "package" name and the name of any
+ enclosing types.
+ file: (FileDescriptor) Reference to file info.
+ containing_type: if provided, this is a nested descriptor, with this
+ descriptor as parent, otherwise None.
+ serialized_start: The start index (inclusive) of the block in
+ file.serialized_pb that describes this descriptor.
+ serialized_end: The end index (exclusive) of the block in
+ file.serialized_pb that describes this descriptor.
+ """
+ super(_NestedDescriptorBase, self).__init__(
+ options, options_class_name)
+
+ self.name = name
+ # TODO(falk): Add function to calculate full_name instead of having it in
+ # memory?
+ self.full_name = full_name
+ self.file = file
+ self.containing_type = containing_type
+
+ self._serialized_start = serialized_start
+ self._serialized_end = serialized_end
+
+ def GetTopLevelContainingType(self):
+ """Returns the root if this is a nested type, or itself if it's the root."""
+ desc = self
+ while desc.containing_type is not None:
+ desc = desc.containing_type
+ return desc
+
+ def CopyToProto(self, proto):
+ """Copies this to the matching proto in descriptor_pb2.
+
+ Args:
+ proto: An empty proto instance from descriptor_pb2.
+
+ Raises:
+ Error: If self couldn't be serialized, due to too few constructor arguments.
+ """
+ if (self.file is not None and
+ self._serialized_start is not None and
+ self._serialized_end is not None):
+ proto.ParseFromString(self.file.serialized_pb[
+ self._serialized_start:self._serialized_end])
+ else:
+ raise Error('Descriptor does not contain serialization.')
+
+
+class Descriptor(_NestedDescriptorBase):
"""Descriptor for a protocol message type.
@@ -89,10 +167,8 @@ class Descriptor(DescriptorBase):
which will include protocol "package" name and the name of any
enclosing types.
- filename: (str) Name of the .proto file containing this message.
-
containing_type: (Descriptor) Reference to the descriptor of the
- type containing us, or None if we have no containing type.
+ type containing us, or None if this is top-level.
fields: (list of FieldDescriptors) Field descriptors for all
fields in this type.
@@ -123,20 +199,28 @@ class Descriptor(DescriptorBase):
objects as |extensions|, but indexed by "name" attribute of each
FieldDescriptor.
+ is_extendable: Does this type define any extension ranges?
+
options: (descriptor_pb2.MessageOptions) Protocol message options or None
to use default message options.
+
+ file: (FileDescriptor) Reference to file descriptor.
"""
- def __init__(self, name, full_name, filename, containing_type,
- fields, nested_types, enum_types, extensions, options=None):
+ def __init__(self, name, full_name, filename, containing_type, fields,
+ nested_types, enum_types, extensions, options=None,
+ is_extendable=True, extension_ranges=None, file=None,
+ serialized_start=None, serialized_end=None):
"""Arguments to __init__() are as described in the description
of Descriptor fields above.
+
+ Note that filename is an obsolete argument, that is not used anymore.
+ Please use file.name to access this as an attribute.
"""
- super(Descriptor, self).__init__(options, 'MessageOptions')
- self.name = name
- self.full_name = full_name
- self.filename = filename
- self.containing_type = containing_type
+ super(Descriptor, self).__init__(
+ options, 'MessageOptions', name, full_name, file,
+ containing_type, serialized_start=serialized_start,
+ serialized_end=serialized_start)
# We have fields in addition to fields_by_name and fields_by_number,
# so that:
@@ -163,6 +247,20 @@ class Descriptor(DescriptorBase):
for extension in self.extensions:
extension.extension_scope = self
self.extensions_by_name = dict((f.name, f) for f in extensions)
+ self.is_extendable = is_extendable
+ self.extension_ranges = extension_ranges
+
+ self._serialized_start = serialized_start
+ self._serialized_end = serialized_end
+
+ def CopyToProto(self, proto):
+ """Copies this to a descriptor_pb2.DescriptorProto.
+
+ Args:
+ proto: An empty descriptor_pb2.DescriptorProto.
+ """
+ # This function is overridden to give a better doc comment.
+ super(Descriptor, self).CopyToProto(proto)
# TODO(robinson): We should have aggressive checking here,
@@ -195,6 +293,8 @@ class FieldDescriptor(DescriptorBase):
label: (One of the LABEL_* constants below) Tells whether this
field is optional, required, or repeated.
+ has_default_value: (bool) True if this field has a default value defined,
+ otherwise false.
default_value: (Varies) Default value of this field. Only
meaningful for non-repeated scalar fields. Repeated fields
should always set this to [], and non-repeated composite
@@ -272,7 +372,8 @@ class FieldDescriptor(DescriptorBase):
def __init__(self, name, full_name, index, number, type, cpp_type, label,
default_value, message_type, enum_type, containing_type,
- is_extension, extension_scope, options=None):
+ is_extension, extension_scope, options=None,
+ has_default_value=True):
"""The arguments are as described in the description of FieldDescriptor
attributes above.
@@ -288,6 +389,7 @@ class FieldDescriptor(DescriptorBase):
self.type = type
self.cpp_type = cpp_type
self.label = label
+ self.has_default_value = has_default_value
self.default_value = default_value
self.containing_type = containing_type
self.message_type = message_type
@@ -296,7 +398,7 @@ class FieldDescriptor(DescriptorBase):
self.extension_scope = extension_scope
-class EnumDescriptor(DescriptorBase):
+class EnumDescriptor(_NestedDescriptorBase):
"""Descriptor for an enum defined in a .proto file.
@@ -305,7 +407,6 @@ class EnumDescriptor(DescriptorBase):
name: (str) Name of the enum type.
full_name: (str) Full name of the type, including package name
and any enclosing type(s).
- filename: (str) Name of the .proto file in which this appears.
values: (list of EnumValueDescriptors) List of the values
in this enum.
@@ -317,23 +418,41 @@ class EnumDescriptor(DescriptorBase):
type of this enum, or None if this is an enum defined at the
top level in a .proto file. Set by Descriptor's constructor
if we're passed into one.
+ file: (FileDescriptor) Reference to file descriptor.
options: (descriptor_pb2.EnumOptions) Enum options message or
None to use default enum options.
"""
def __init__(self, name, full_name, filename, values,
- containing_type=None, options=None):
- """Arguments are as described in the attribute description above."""
- super(EnumDescriptor, self).__init__(options, 'EnumOptions')
- self.name = name
- self.full_name = full_name
- self.filename = filename
+ containing_type=None, options=None, file=None,
+ serialized_start=None, serialized_end=None):
+ """Arguments are as described in the attribute description above.
+
+ Note that filename is an obsolete argument, that is not used anymore.
+ Please use file.name to access this as an attribute.
+ """
+ super(EnumDescriptor, self).__init__(
+ options, 'EnumOptions', name, full_name, file,
+ containing_type, serialized_start=serialized_start,
+ serialized_end=serialized_start)
+
self.values = values
for value in self.values:
value.type = self
self.values_by_name = dict((v.name, v) for v in values)
self.values_by_number = dict((v.number, v) for v in values)
- self.containing_type = containing_type
+
+ self._serialized_start = serialized_start
+ self._serialized_end = serialized_end
+
+ def CopyToProto(self, proto):
+ """Copies this to a descriptor_pb2.EnumDescriptorProto.
+
+ Args:
+ proto: An empty descriptor_pb2.EnumDescriptorProto.
+ """
+ # This function is overridden to give a better doc comment.
+ super(EnumDescriptor, self).CopyToProto(proto)
class EnumValueDescriptor(DescriptorBase):
@@ -360,7 +479,7 @@ class EnumValueDescriptor(DescriptorBase):
self.type = type
-class ServiceDescriptor(DescriptorBase):
+class ServiceDescriptor(_NestedDescriptorBase):
"""Descriptor for a service.
@@ -372,12 +491,15 @@ class ServiceDescriptor(DescriptorBase):
service.
options: (descriptor_pb2.ServiceOptions) Service options message or
None to use default service options.
+ file: (FileDescriptor) Reference to file info.
"""
- def __init__(self, name, full_name, index, methods, options=None):
- super(ServiceDescriptor, self).__init__(options, 'ServiceOptions')
- self.name = name
- self.full_name = full_name
+ def __init__(self, name, full_name, index, methods, options=None, file=None,
+ serialized_start=None, serialized_end=None):
+ super(ServiceDescriptor, self).__init__(
+ options, 'ServiceOptions', name, full_name, file,
+ None, serialized_start=serialized_start,
+ serialized_end=serialized_end)
self.index = index
self.methods = methods
# Set the containing service for each method in this service.
@@ -391,6 +513,15 @@ class ServiceDescriptor(DescriptorBase):
return method
return None
+ def CopyToProto(self, proto):
+ """Copies this to a descriptor_pb2.ServiceDescriptorProto.
+
+ Args:
+ proto: An empty descriptor_pb2.ServiceDescriptorProto.
+ """
+ # This function is overridden to give a better doc comment.
+ super(ServiceDescriptor, self).CopyToProto(proto)
+
class MethodDescriptor(DescriptorBase):
@@ -423,6 +554,32 @@ class MethodDescriptor(DescriptorBase):
self.output_type = output_type
+class FileDescriptor(DescriptorBase):
+ """Descriptor for a file. Mimics the descriptor_pb2.FileDescriptorProto.
+
+ name: name of file, relative to root of source tree.
+ package: name of the package
+ serialized_pb: (str) Byte string of serialized
+ descriptor_pb2.FileDescriptorProto.
+ """
+
+ def __init__(self, name, package, options=None, serialized_pb=None):
+ """Constructor."""
+ super(FileDescriptor, self).__init__(options, 'FileOptions')
+
+ self.name = name
+ self.package = package
+ self.serialized_pb = serialized_pb
+
+ def CopyToProto(self, proto):
+ """Copies this to a descriptor_pb2.FileDescriptorProto.
+
+ Args:
+ proto: An empty descriptor_pb2.FileDescriptorProto.
+ """
+ proto.ParseFromString(self.serialized_pb)
+
+
def _ParseOptions(message, string):
"""Parses serialized options.
@@ -430,4 +587,4 @@ def _ParseOptions(message, string):
proto2 files. It must not be used outside proto2.
"""
message.ParseFromString(string)
- return message;
+ return message
diff --git a/python/google/protobuf/internal/containers.py b/python/google/protobuf/internal/containers.py
index d8a825df..5cc7d6d0 100755
--- a/python/google/protobuf/internal/containers.py
+++ b/python/google/protobuf/internal/containers.py
@@ -54,8 +54,7 @@ class BaseContainer(object):
Args:
message_listener: A MessageListener implementation.
The RepeatedScalarFieldContainer will call this object's
- TransitionToNonempty() method when it transitions from being empty to
- being nonempty.
+ Modified() method when it is modified.
"""
self._message_listener = message_listener
self._values = []
@@ -73,6 +72,9 @@ class BaseContainer(object):
# The concrete classes should define __eq__.
return not self == other
+ def __repr__(self):
+ return repr(self._values)
+
class RepeatedScalarFieldContainer(BaseContainer):
@@ -86,8 +88,7 @@ class RepeatedScalarFieldContainer(BaseContainer):
Args:
message_listener: A MessageListener implementation.
The RepeatedScalarFieldContainer will call this object's
- TransitionToNonempty() method when it transitions from being empty to
- being nonempty.
+ Modified() method when it is modified.
type_checker: A type_checkers.ValueChecker instance to run on elements
inserted into this container.
"""
@@ -96,44 +97,47 @@ class RepeatedScalarFieldContainer(BaseContainer):
def append(self, value):
"""Appends an item to the list. Similar to list.append()."""
- self.insert(len(self._values), value)
+ self._type_checker.CheckValue(value)
+ self._values.append(value)
+ if not self._message_listener.dirty:
+ self._message_listener.Modified()
def insert(self, key, value):
"""Inserts the item at the specified position. Similar to list.insert()."""
self._type_checker.CheckValue(value)
self._values.insert(key, value)
- self._message_listener.ByteSizeDirty()
- if len(self._values) == 1:
- self._message_listener.TransitionToNonempty()
+ if not self._message_listener.dirty:
+ self._message_listener.Modified()
def extend(self, elem_seq):
"""Extends by appending the given sequence. Similar to list.extend()."""
if not elem_seq:
return
- orig_empty = len(self._values) == 0
new_values = []
for elem in elem_seq:
self._type_checker.CheckValue(elem)
new_values.append(elem)
self._values.extend(new_values)
- self._message_listener.ByteSizeDirty()
- if orig_empty:
- self._message_listener.TransitionToNonempty()
+ self._message_listener.Modified()
+
+ def MergeFrom(self, other):
+ """Appends the contents of another repeated field of the same type to this
+ one. We do not check the types of the individual fields.
+ """
+ self._values.extend(other._values)
+ self._message_listener.Modified()
def remove(self, elem):
"""Removes an item from the list. Similar to list.remove()."""
self._values.remove(elem)
- self._message_listener.ByteSizeDirty()
+ self._message_listener.Modified()
def __setitem__(self, key, value):
"""Sets the item on the specified position."""
- # No need to call TransitionToNonempty(), since if we're able to
- # set the element at this index, we were already nonempty before
- # this method was called.
- self._message_listener.ByteSizeDirty()
self._type_checker.CheckValue(value)
self._values[key] = value
+ self._message_listener.Modified()
def __getslice__(self, start, stop):
"""Retrieves the subset of items from between the specified indices."""
@@ -146,17 +150,17 @@ class RepeatedScalarFieldContainer(BaseContainer):
self._type_checker.CheckValue(value)
new_values.append(value)
self._values[start:stop] = new_values
- self._message_listener.ByteSizeDirty()
+ self._message_listener.Modified()
def __delitem__(self, key):
"""Deletes the item at the specified position."""
del self._values[key]
- self._message_listener.ByteSizeDirty()
+ self._message_listener.Modified()
def __delslice__(self, start, stop):
"""Deletes the subset of items from between the specified indices."""
del self._values[start:stop]
- self._message_listener.ByteSizeDirty()
+ self._message_listener.Modified()
def __eq__(self, other):
"""Compares the current instance with another one."""
@@ -186,8 +190,7 @@ class RepeatedCompositeFieldContainer(BaseContainer):
Args:
message_listener: A MessageListener implementation.
The RepeatedCompositeFieldContainer will call this object's
- TransitionToNonempty() method when it transitions from being empty to
- being nonempty.
+ Modified() method when it is modified.
message_descriptor: A Descriptor instance describing the protocol type
that should be present in this container. We'll use the
_concrete_class field of this descriptor when the client calls add().
@@ -199,10 +202,24 @@ class RepeatedCompositeFieldContainer(BaseContainer):
new_element = self._message_descriptor._concrete_class()
new_element._SetListener(self._message_listener)
self._values.append(new_element)
- self._message_listener.ByteSizeDirty()
- self._message_listener.TransitionToNonempty()
+ if not self._message_listener.dirty:
+ self._message_listener.Modified()
return new_element
+ def MergeFrom(self, other):
+ """Appends the contents of another repeated field of the same type to this
+ one, copying each individual message.
+ """
+ message_class = self._message_descriptor._concrete_class
+ listener = self._message_listener
+ values = self._values
+ for message in other._values:
+ new_element = message_class()
+ new_element._SetListener(listener)
+ new_element.MergeFrom(message)
+ values.append(new_element)
+ listener.Modified()
+
def __getslice__(self, start, stop):
"""Retrieves the subset of items from between the specified indices."""
return self._values[start:stop]
@@ -210,12 +227,12 @@ class RepeatedCompositeFieldContainer(BaseContainer):
def __delitem__(self, key):
"""Deletes the item at the specified position."""
del self._values[key]
- self._message_listener.ByteSizeDirty()
+ self._message_listener.Modified()
def __delslice__(self, start, stop):
"""Deletes the subset of items from between the specified indices."""
del self._values[start:stop]
- self._message_listener.ByteSizeDirty()
+ self._message_listener.Modified()
def __eq__(self, other):
"""Compares the current instance with another one."""
diff --git a/python/google/protobuf/internal/decoder.py b/python/google/protobuf/internal/decoder.py
index 83d6fe0c..461a30c0 100755
--- a/python/google/protobuf/internal/decoder.py
+++ b/python/google/protobuf/internal/decoder.py
@@ -28,182 +28,614 @@
# (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
# OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
-"""Class for decoding protocol buffer primitives.
-
-Contains the logic for decoding every logical protocol field type
-from one of the 5 physical wire types.
+"""Code for decoding protocol buffer primitives.
+
+This code is very similar to encoder.py -- read the docs for that module first.
+
+A "decoder" is a function with the signature:
+ Decode(buffer, pos, end, message, field_dict)
+The arguments are:
+ buffer: The string containing the encoded message.
+ pos: The current position in the string.
+ end: The position in the string where the current message ends. May be
+ less than len(buffer) if we're reading a sub-message.
+ message: The message object into which we're parsing.
+ field_dict: message._fields (avoids a hashtable lookup).
+The decoder reads the field and stores it into field_dict, returning the new
+buffer position. A decoder for a repeated field may proactively decode all of
+the elements of that field, if they appear consecutively.
+
+Note that decoders may throw any of the following:
+ IndexError: Indicates a truncated message.
+ struct.error: Unpacking of a fixed-width field failed.
+ message.DecodeError: Other errors.
+
+Decoders are expected to raise an exception if they are called with pos > end.
+This allows callers to be lax about bounds checking: it's fine to read past
+"end" as long as you are sure that someone else will notice and throw an
+exception later on.
+
+Something up the call stack is expected to catch IndexError and struct.error
+and convert them to message.DecodeError.
+
+Decoders are constructed using decoder constructors with the signature:
+ MakeDecoder(field_number, is_repeated, is_packed, key, new_default)
+The arguments are:
+ field_number: The field number of the field we want to decode.
+ is_repeated: Is the field a repeated field? (bool)
+ is_packed: Is the field a packed field? (bool)
+ key: The key to use when looking up the field within field_dict.
+ (This is actually the FieldDescriptor but nothing in this
+ file should depend on that.)
+ new_default: A function which takes a message object as a parameter and
+ returns a new instance of the default value for this field.
+ (This is called for repeated fields and sub-messages, when an
+ instance does not already exist.)
+
+As with encoders, we define a decoder constructor for every type of field.
+Then, for every field of every message class we construct an actual decoder.
+That decoder goes into a dict indexed by tag, so when we decode a message
+we repeatedly read a tag, look up the corresponding decoder, and invoke it.
"""
-__author__ = 'robinson@google.com (Will Robinson)'
+__author__ = 'kenton@google.com (Kenton Varda)'
import struct
-from google.protobuf import message
-from google.protobuf.internal import input_stream
+from google.protobuf.internal import encoder
from google.protobuf.internal import wire_format
+from google.protobuf import message
+# This is not for optimization, but rather to avoid conflicts with local
+# variables named "message".
+_DecodeError = message.DecodeError
+
+
+def _VarintDecoder(mask):
+ """Return a decoder for a basic varint value (does not include tag).
+
+ Decoded values will be bitwise-anded with the given mask before being
+ returned, e.g. to limit them to 32 bits. The returned decoder does not
+ take the usual "end" parameter -- the caller is expected to do bounds checking
+ after the fact (often the caller can defer such checking until later). The
+ decoder returns a (value, new_pos) pair.
+ """
+
+ local_ord = ord
+ def DecodeVarint(buffer, pos):
+ result = 0
+ shift = 0
+ while 1:
+ b = local_ord(buffer[pos])
+ result |= ((b & 0x7f) << shift)
+ pos += 1
+ if not (b & 0x80):
+ result &= mask
+ return (result, pos)
+ shift += 7
+ if shift >= 64:
+ raise _DecodeError('Too many bytes when decoding varint.')
+ return DecodeVarint
+
+
+def _SignedVarintDecoder(mask):
+ """Like _VarintDecoder() but decodes signed values."""
+
+ local_ord = ord
+ def DecodeVarint(buffer, pos):
+ result = 0
+ shift = 0
+ while 1:
+ b = local_ord(buffer[pos])
+ result |= ((b & 0x7f) << shift)
+ pos += 1
+ if not (b & 0x80):
+ if result > 0x7fffffffffffffff:
+ result -= (1 << 64)
+ result |= ~mask
+ else:
+ result &= mask
+ return (result, pos)
+ shift += 7
+ if shift >= 64:
+ raise _DecodeError('Too many bytes when decoding varint.')
+ return DecodeVarint
+
+
+_DecodeVarint = _VarintDecoder((1 << 64) - 1)
+_DecodeSignedVarint = _SignedVarintDecoder((1 << 64) - 1)
+
+# Use these versions for values which must be limited to 32 bits.
+_DecodeVarint32 = _VarintDecoder((1 << 32) - 1)
+_DecodeSignedVarint32 = _SignedVarintDecoder((1 << 32) - 1)
+
+
+def ReadTag(buffer, pos):
+ """Read a tag from the buffer, and return a (tag_bytes, new_pos) tuple.
+
+ We return the raw bytes of the tag rather than decoding them. The raw
+ bytes can then be used to look up the proper decoder. This effectively allows
+ us to trade some work that would be done in pure-python (decoding a varint)
+ for work that is done in C (searching for a byte string in a hash table).
+ In a low-level language it would be much cheaper to decode the varint and
+ use that, but not in Python.
+ """
+
+ start = pos
+ while ord(buffer[pos]) & 0x80:
+ pos += 1
+ pos += 1
+ return (buffer[start:pos], pos)
+
+
+# --------------------------------------------------------------------
+
+
+def _SimpleDecoder(wire_type, decode_value):
+ """Return a constructor for a decoder for fields of a particular type.
+
+ Args:
+ wire_type: The field's wire type.
+ decode_value: A function which decodes an individual value, e.g.
+ _DecodeVarint()
+ """
+
+ def SpecificDecoder(field_number, is_repeated, is_packed, key, new_default):
+ if is_packed:
+ local_DecodeVarint = _DecodeVarint
+ def DecodePackedField(buffer, pos, end, message, field_dict):
+ value = field_dict.get(key)
+ if value is None:
+ value = field_dict.setdefault(key, new_default(message))
+ (endpoint, pos) = local_DecodeVarint(buffer, pos)
+ endpoint += pos
+ if endpoint > end:
+ raise _DecodeError('Truncated message.')
+ while pos < endpoint:
+ (element, pos) = decode_value(buffer, pos)
+ value.append(element)
+ if pos > endpoint:
+ del value[-1] # Discard corrupt value.
+ raise _DecodeError('Packed element was truncated.')
+ return pos
+ return DecodePackedField
+ elif is_repeated:
+ tag_bytes = encoder.TagBytes(field_number, wire_type)
+ tag_len = len(tag_bytes)
+ def DecodeRepeatedField(buffer, pos, end, message, field_dict):
+ value = field_dict.get(key)
+ if value is None:
+ value = field_dict.setdefault(key, new_default(message))
+ while 1:
+ (element, new_pos) = decode_value(buffer, pos)
+ value.append(element)
+ # Predict that the next tag is another copy of the same repeated
+ # field.
+ pos = new_pos + tag_len
+ if buffer[new_pos:pos] != tag_bytes or new_pos >= end:
+ # Prediction failed. Return.
+ if new_pos > end:
+ raise _DecodeError('Truncated message.')
+ return new_pos
+ return DecodeRepeatedField
+ else:
+ def DecodeField(buffer, pos, end, message, field_dict):
+ (field_dict[key], pos) = decode_value(buffer, pos)
+ if pos > end:
+ del field_dict[key] # Discard corrupt value.
+ raise _DecodeError('Truncated message.')
+ return pos
+ return DecodeField
+
+ return SpecificDecoder
+
+
+def _ModifiedDecoder(wire_type, decode_value, modify_value):
+ """Like _SimpleDecoder but additionally invokes modify_value on every value
+ before storing it. Usually modify_value is ZigZagDecode.
+ """
+
+ # Reusing _SimpleDecoder is slightly slower than copying a bunch of code, but
+ # not enough to make a significant difference.
+
+ def InnerDecode(buffer, pos):
+ (result, new_pos) = decode_value(buffer, pos)
+ return (modify_value(result), new_pos)
+ return _SimpleDecoder(wire_type, InnerDecode)
+
+
+def _StructPackDecoder(wire_type, format):
+ """Return a constructor for a decoder for a fixed-width field.
+
+ Args:
+ wire_type: The field's wire type.
+ format: The format string to pass to struct.unpack().
+ """
+
+ value_size = struct.calcsize(format)
+ local_unpack = struct.unpack
+
+ # Reusing _SimpleDecoder is slightly slower than copying a bunch of code, but
+ # not enough to make a significant difference.
+
+ # Note that we expect someone up-stack to catch struct.error and convert
+ # it to _DecodeError -- this way we don't have to set up exception-
+ # handling blocks every time we parse one value.
+
+ def InnerDecode(buffer, pos):
+ new_pos = pos + value_size
+ result = local_unpack(format, buffer[pos:new_pos])[0]
+ return (result, new_pos)
+ return _SimpleDecoder(wire_type, InnerDecode)
+
+
+# --------------------------------------------------------------------
+
+
+Int32Decoder = EnumDecoder = _SimpleDecoder(
+ wire_format.WIRETYPE_VARINT, _DecodeSignedVarint32)
+
+Int64Decoder = _SimpleDecoder(
+ wire_format.WIRETYPE_VARINT, _DecodeSignedVarint)
+
+UInt32Decoder = _SimpleDecoder(wire_format.WIRETYPE_VARINT, _DecodeVarint32)
+UInt64Decoder = _SimpleDecoder(wire_format.WIRETYPE_VARINT, _DecodeVarint)
+
+SInt32Decoder = _ModifiedDecoder(
+ wire_format.WIRETYPE_VARINT, _DecodeVarint32, wire_format.ZigZagDecode)
+SInt64Decoder = _ModifiedDecoder(
+ wire_format.WIRETYPE_VARINT, _DecodeVarint, wire_format.ZigZagDecode)
+
+# Note that Python conveniently guarantees that when using the '<' prefix on
+# formats, they will also have the same size across all platforms (as opposed
+# to without the prefix, where their sizes depend on the C compiler's basic
+# type sizes).
+Fixed32Decoder = _StructPackDecoder(wire_format.WIRETYPE_FIXED32, '<I')
+Fixed64Decoder = _StructPackDecoder(wire_format.WIRETYPE_FIXED64, '<Q')
+SFixed32Decoder = _StructPackDecoder(wire_format.WIRETYPE_FIXED32, '<i')
+SFixed64Decoder = _StructPackDecoder(wire_format.WIRETYPE_FIXED64, '<q')
+FloatDecoder = _StructPackDecoder(wire_format.WIRETYPE_FIXED32, '<f')
+DoubleDecoder = _StructPackDecoder(wire_format.WIRETYPE_FIXED64, '<d')
+
+BoolDecoder = _ModifiedDecoder(
+ wire_format.WIRETYPE_VARINT, _DecodeVarint, bool)
+
+
+def StringDecoder(field_number, is_repeated, is_packed, key, new_default):
+ """Returns a decoder for a string field."""
+
+ local_DecodeVarint = _DecodeVarint
+ local_unicode = unicode
+
+ assert not is_packed
+ if is_repeated:
+ tag_bytes = encoder.TagBytes(field_number,
+ wire_format.WIRETYPE_LENGTH_DELIMITED)
+ tag_len = len(tag_bytes)
+ def DecodeRepeatedField(buffer, pos, end, message, field_dict):
+ value = field_dict.get(key)
+ if value is None:
+ value = field_dict.setdefault(key, new_default(message))
+ while 1:
+ (size, pos) = local_DecodeVarint(buffer, pos)
+ new_pos = pos + size
+ if new_pos > end:
+ raise _DecodeError('Truncated string.')
+ value.append(local_unicode(buffer[pos:new_pos], 'utf-8'))
+ # Predict that the next tag is another copy of the same repeated field.
+ pos = new_pos + tag_len
+ if buffer[new_pos:pos] != tag_bytes or new_pos == end:
+ # Prediction failed. Return.
+ return new_pos
+ return DecodeRepeatedField
+ else:
+ def DecodeField(buffer, pos, end, message, field_dict):
+ (size, pos) = local_DecodeVarint(buffer, pos)
+ new_pos = pos + size
+ if new_pos > end:
+ raise _DecodeError('Truncated string.')
+ field_dict[key] = local_unicode(buffer[pos:new_pos], 'utf-8')
+ return new_pos
+ return DecodeField
+
+
+def BytesDecoder(field_number, is_repeated, is_packed, key, new_default):
+ """Returns a decoder for a bytes field."""
+
+ local_DecodeVarint = _DecodeVarint
+
+ assert not is_packed
+ if is_repeated:
+ tag_bytes = encoder.TagBytes(field_number,
+ wire_format.WIRETYPE_LENGTH_DELIMITED)
+ tag_len = len(tag_bytes)
+ def DecodeRepeatedField(buffer, pos, end, message, field_dict):
+ value = field_dict.get(key)
+ if value is None:
+ value = field_dict.setdefault(key, new_default(message))
+ while 1:
+ (size, pos) = local_DecodeVarint(buffer, pos)
+ new_pos = pos + size
+ if new_pos > end:
+ raise _DecodeError('Truncated string.')
+ value.append(buffer[pos:new_pos])
+ # Predict that the next tag is another copy of the same repeated field.
+ pos = new_pos + tag_len
+ if buffer[new_pos:pos] != tag_bytes or new_pos == end:
+ # Prediction failed. Return.
+ return new_pos
+ return DecodeRepeatedField
+ else:
+ def DecodeField(buffer, pos, end, message, field_dict):
+ (size, pos) = local_DecodeVarint(buffer, pos)
+ new_pos = pos + size
+ if new_pos > end:
+ raise _DecodeError('Truncated string.')
+ field_dict[key] = buffer[pos:new_pos]
+ return new_pos
+ return DecodeField
+
+
+def GroupDecoder(field_number, is_repeated, is_packed, key, new_default):
+ """Returns a decoder for a group field."""
+
+ end_tag_bytes = encoder.TagBytes(field_number,
+ wire_format.WIRETYPE_END_GROUP)
+ end_tag_len = len(end_tag_bytes)
+
+ assert not is_packed
+ if is_repeated:
+ tag_bytes = encoder.TagBytes(field_number,
+ wire_format.WIRETYPE_START_GROUP)
+ tag_len = len(tag_bytes)
+ def DecodeRepeatedField(buffer, pos, end, message, field_dict):
+ value = field_dict.get(key)
+ if value is None:
+ value = field_dict.setdefault(key, new_default(message))
+ while 1:
+ value = field_dict.get(key)
+ if value is None:
+ value = field_dict.setdefault(key, new_default(message))
+ # Read sub-message.
+ pos = value.add()._InternalParse(buffer, pos, end)
+ # Read end tag.
+ new_pos = pos+end_tag_len
+ if buffer[pos:new_pos] != end_tag_bytes or new_pos > end:
+ raise _DecodeError('Missing group end tag.')
+ # Predict that the next tag is another copy of the same repeated field.
+ pos = new_pos + tag_len
+ if buffer[new_pos:pos] != tag_bytes or new_pos == end:
+ # Prediction failed. Return.
+ return new_pos
+ return DecodeRepeatedField
+ else:
+ def DecodeField(buffer, pos, end, message, field_dict):
+ value = field_dict.get(key)
+ if value is None:
+ value = field_dict.setdefault(key, new_default(message))
+ # Read sub-message.
+ pos = value._InternalParse(buffer, pos, end)
+ # Read end tag.
+ new_pos = pos+end_tag_len
+ if buffer[pos:new_pos] != end_tag_bytes or new_pos > end:
+ raise _DecodeError('Missing group end tag.')
+ return new_pos
+ return DecodeField
+
+
+def MessageDecoder(field_number, is_repeated, is_packed, key, new_default):
+ """Returns a decoder for a message field."""
+
+ local_DecodeVarint = _DecodeVarint
+
+ assert not is_packed
+ if is_repeated:
+ tag_bytes = encoder.TagBytes(field_number,
+ wire_format.WIRETYPE_LENGTH_DELIMITED)
+ tag_len = len(tag_bytes)
+ def DecodeRepeatedField(buffer, pos, end, message, field_dict):
+ value = field_dict.get(key)
+ if value is None:
+ value = field_dict.setdefault(key, new_default(message))
+ while 1:
+ value = field_dict.get(key)
+ if value is None:
+ value = field_dict.setdefault(key, new_default(message))
+ # Read length.
+ (size, pos) = local_DecodeVarint(buffer, pos)
+ new_pos = pos + size
+ if new_pos > end:
+ raise _DecodeError('Truncated message.')
+ # Read sub-message.
+ if value.add()._InternalParse(buffer, pos, new_pos) != new_pos:
+ # The only reason _InternalParse would return early is if it
+ # encountered an end-group tag.
+ raise _DecodeError('Unexpected end-group tag.')
+ # Predict that the next tag is another copy of the same repeated field.
+ pos = new_pos + tag_len
+ if buffer[new_pos:pos] != tag_bytes or new_pos == end:
+ # Prediction failed. Return.
+ return new_pos
+ return DecodeRepeatedField
+ else:
+ def DecodeField(buffer, pos, end, message, field_dict):
+ value = field_dict.get(key)
+ if value is None:
+ value = field_dict.setdefault(key, new_default(message))
+ # Read length.
+ (size, pos) = local_DecodeVarint(buffer, pos)
+ new_pos = pos + size
+ if new_pos > end:
+ raise _DecodeError('Truncated message.')
+ # Read sub-message.
+ if value._InternalParse(buffer, pos, new_pos) != new_pos:
+ # The only reason _InternalParse would return early is if it encountered
+ # an end-group tag.
+ raise _DecodeError('Unexpected end-group tag.')
+ return new_pos
+ return DecodeField
+
+
+# --------------------------------------------------------------------
+
+MESSAGE_SET_ITEM_TAG = encoder.TagBytes(1, wire_format.WIRETYPE_START_GROUP)
+
+def MessageSetItemDecoder(extensions_by_number):
+ """Returns a decoder for a MessageSet item.
+
+ The parameter is the _extensions_by_number map for the message class.
+
+ The message set message looks like this:
+ message MessageSet {
+ repeated group Item = 1 {
+ required int32 type_id = 2;
+ required string message = 3;
+ }
+ }
+ """
+
+ type_id_tag_bytes = encoder.TagBytes(2, wire_format.WIRETYPE_VARINT)
+ message_tag_bytes = encoder.TagBytes(3, wire_format.WIRETYPE_LENGTH_DELIMITED)
+ item_end_tag_bytes = encoder.TagBytes(1, wire_format.WIRETYPE_END_GROUP)
+
+ local_ReadTag = ReadTag
+ local_DecodeVarint = _DecodeVarint
+ local_SkipField = SkipField
+
+ def DecodeItem(buffer, pos, end, message, field_dict):
+ type_id = -1
+ message_start = -1
+ message_end = -1
+
+ # Technically, type_id and message can appear in any order, so we need
+ # a little loop here.
+ while 1:
+ (tag_bytes, pos) = local_ReadTag(buffer, pos)
+ if tag_bytes == type_id_tag_bytes:
+ (type_id, pos) = local_DecodeVarint(buffer, pos)
+ elif tag_bytes == message_tag_bytes:
+ (size, message_start) = local_DecodeVarint(buffer, pos)
+ pos = message_end = message_start + size
+ elif tag_bytes == item_end_tag_bytes:
+ break
+ else:
+ pos = SkipField(buffer, pos, end, tag_bytes)
+ if pos == -1:
+ raise _DecodeError('Missing group end tag.')
+
+ if pos > end:
+ raise _DecodeError('Truncated message.')
+
+ if type_id == -1:
+ raise _DecodeError('MessageSet item missing type_id.')
+ if message_start == -1:
+ raise _DecodeError('MessageSet item missing message.')
+
+ extension = extensions_by_number.get(type_id)
+ if extension is not None:
+ value = field_dict.get(extension)
+ if value is None:
+ value = field_dict.setdefault(
+ extension, extension.message_type._concrete_class())
+ if value._InternalParse(buffer, message_start,message_end) != message_end:
+ # The only reason _InternalParse would return early is if it encountered
+ # an end-group tag.
+ raise _DecodeError('Unexpected end-group tag.')
+
+ return pos
+
+ return DecodeItem
+
+# --------------------------------------------------------------------
+# Optimization is not as heavy here because calls to SkipField() are rare,
+# except for handling end-group tags.
+
+def _SkipVarint(buffer, pos, end):
+ """Skip a varint value. Returns the new position."""
+
+ while ord(buffer[pos]) & 0x80:
+ pos += 1
+ pos += 1
+ if pos > end:
+ raise _DecodeError('Truncated message.')
+ return pos
+
+def _SkipFixed64(buffer, pos, end):
+ """Skip a fixed64 value. Returns the new position."""
+
+ pos += 8
+ if pos > end:
+ raise _DecodeError('Truncated message.')
+ return pos
+
+def _SkipLengthDelimited(buffer, pos, end):
+ """Skip a length-delimited value. Returns the new position."""
+
+ (size, pos) = _DecodeVarint(buffer, pos)
+ pos += size
+ if pos > end:
+ raise _DecodeError('Truncated message.')
+ return pos
+
+def _SkipGroup(buffer, pos, end):
+ """Skip sub-group. Returns the new position."""
+
+ while 1:
+ (tag_bytes, pos) = ReadTag(buffer, pos)
+ new_pos = SkipField(buffer, pos, end, tag_bytes)
+ if new_pos == -1:
+ return pos
+ pos = new_pos
+
+def _EndGroup(buffer, pos, end):
+ """Skipping an END_GROUP tag returns -1 to tell the parent loop to break."""
+
+ return -1
+
+def _SkipFixed32(buffer, pos, end):
+ """Skip a fixed32 value. Returns the new position."""
+
+ pos += 4
+ if pos > end:
+ raise _DecodeError('Truncated message.')
+ return pos
+
+def _RaiseInvalidWireType(buffer, pos, end):
+ """Skip function for unknown wire types. Raises an exception."""
+
+ raise _DecodeError('Tag had invalid wire type.')
+
+def _FieldSkipper():
+ """Constructs the SkipField function."""
+
+ WIRETYPE_TO_SKIPPER = [
+ _SkipVarint,
+ _SkipFixed64,
+ _SkipLengthDelimited,
+ _SkipGroup,
+ _EndGroup,
+ _SkipFixed32,
+ _RaiseInvalidWireType,
+ _RaiseInvalidWireType,
+ ]
+
+ wiretype_mask = wire_format.TAG_TYPE_MASK
+ local_ord = ord
+
+ def SkipField(buffer, pos, end, tag_bytes):
+ """Skips a field with the specified tag.
+
+ |pos| should point to the byte immediately after the tag.
+
+ Returns:
+ The new position (after the tag value), or -1 if the tag is an end-group
+ tag (in which case the calling loop should break).
+ """
-# Note that much of this code is ported from //net/proto/ProtocolBuffer, and
-# that the interface is strongly inspired by WireFormat from the C++ proto2
-# implementation.
-
-
-class Decoder(object):
-
- """Decodes logical protocol buffer fields from the wire."""
+ # The wire type is always in the first byte since varints are little-endian.
+ wire_type = local_ord(tag_bytes[0]) & wiretype_mask
+ return WIRETYPE_TO_SKIPPER[wire_type](buffer, pos, end)
- def __init__(self, s):
- """Initializes the decoder to read from s.
+ return SkipField
- Args:
- s: An immutable sequence of bytes, which must be accessible
- via the Python buffer() primitive (i.e., buffer(s)).
- """
- self._stream = input_stream.InputStream(s)
-
- def EndOfStream(self):
- """Returns true iff we've reached the end of the bytes we're reading."""
- return self._stream.EndOfStream()
-
- def Position(self):
- """Returns the 0-indexed position in |s|."""
- return self._stream.Position()
-
- def ReadFieldNumberAndWireType(self):
- """Reads a tag from the wire. Returns a (field_number, wire_type) pair."""
- tag_and_type = self.ReadUInt32()
- return wire_format.UnpackTag(tag_and_type)
-
- def SkipBytes(self, bytes):
- """Skips the specified number of bytes on the wire."""
- self._stream.SkipBytes(bytes)
-
- # Note that the Read*() methods below are not exactly symmetrical with the
- # corresponding Encoder.Append*() methods. Those Encoder methods first
- # encode a tag, but the Read*() methods below assume that the tag has already
- # been read, and that the client wishes to read a field of the specified type
- # starting at the current position.
-
- def ReadInt32(self):
- """Reads and returns a signed, varint-encoded, 32-bit integer."""
- return self._stream.ReadVarint32()
-
- def ReadInt64(self):
- """Reads and returns a signed, varint-encoded, 64-bit integer."""
- return self._stream.ReadVarint64()
-
- def ReadUInt32(self):
- """Reads and returns an signed, varint-encoded, 32-bit integer."""
- return self._stream.ReadVarUInt32()
-
- def ReadUInt64(self):
- """Reads and returns an signed, varint-encoded,64-bit integer."""
- return self._stream.ReadVarUInt64()
-
- def ReadSInt32(self):
- """Reads and returns a signed, zigzag-encoded, varint-encoded,
- 32-bit integer."""
- return wire_format.ZigZagDecode(self._stream.ReadVarUInt32())
-
- def ReadSInt64(self):
- """Reads and returns a signed, zigzag-encoded, varint-encoded,
- 64-bit integer."""
- return wire_format.ZigZagDecode(self._stream.ReadVarUInt64())
-
- def ReadFixed32(self):
- """Reads and returns an unsigned, fixed-width, 32-bit integer."""
- return self._stream.ReadLittleEndian32()
-
- def ReadFixed64(self):
- """Reads and returns an unsigned, fixed-width, 64-bit integer."""
- return self._stream.ReadLittleEndian64()
-
- def ReadSFixed32(self):
- """Reads and returns a signed, fixed-width, 32-bit integer."""
- value = self._stream.ReadLittleEndian32()
- if value >= (1 << 31):
- value -= (1 << 32)
- return value
-
- def ReadSFixed64(self):
- """Reads and returns a signed, fixed-width, 64-bit integer."""
- value = self._stream.ReadLittleEndian64()
- if value >= (1 << 63):
- value -= (1 << 64)
- return value
-
- def ReadFloat(self):
- """Reads and returns a 4-byte floating-point number."""
- serialized = self._stream.ReadBytes(4)
- return struct.unpack(wire_format.FORMAT_FLOAT_LITTLE_ENDIAN, serialized)[0]
-
- def ReadDouble(self):
- """Reads and returns an 8-byte floating-point number."""
- serialized = self._stream.ReadBytes(8)
- return struct.unpack(wire_format.FORMAT_DOUBLE_LITTLE_ENDIAN, serialized)[0]
-
- def ReadBool(self):
- """Reads and returns a bool."""
- i = self._stream.ReadVarUInt32()
- return bool(i)
-
- def ReadEnum(self):
- """Reads and returns an enum value."""
- return self._stream.ReadVarUInt32()
-
- def ReadString(self):
- """Reads and returns a length-delimited string."""
- bytes = self.ReadBytes()
- return unicode(bytes, 'utf-8')
-
- def ReadBytes(self):
- """Reads and returns a length-delimited byte sequence."""
- length = self._stream.ReadVarUInt32()
- return self._stream.ReadBytes(length)
-
- def ReadMessageInto(self, msg):
- """Calls msg.MergeFromString() to merge
- length-delimited serialized message data into |msg|.
-
- REQUIRES: The decoder must be positioned at the serialized "length"
- prefix to a length-delmiited serialized message.
-
- POSTCONDITION: The decoder is positioned just after the
- serialized message, and we have merged those serialized
- contents into |msg|.
- """
- length = self._stream.ReadVarUInt32()
- sub_buffer = self._stream.GetSubBuffer(length)
- num_bytes_used = msg.MergeFromString(sub_buffer)
- if num_bytes_used != length:
- raise message.DecodeError(
- 'Submessage told to deserialize from %d-byte encoding, '
- 'but used only %d bytes' % (length, num_bytes_used))
- self._stream.SkipBytes(num_bytes_used)
-
- def ReadGroupInto(self, expected_field_number, group):
- """Calls group.MergeFromString() to merge
- END_GROUP-delimited serialized message data into |group|.
- We'll raise an exception if we don't find an END_GROUP
- tag immediately after the serialized message contents.
-
- REQUIRES: The decoder is positioned just after the START_GROUP
- tag for this group.
-
- POSTCONDITION: The decoder is positioned just after the
- END_GROUP tag for this group, and we have merged
- the contents of the group into |group|.
- """
- sub_buffer = self._stream.GetSubBuffer() # No a priori length limit.
- num_bytes_used = group.MergeFromString(sub_buffer)
- if num_bytes_used < 0:
- raise message.DecodeError('Group message reported negative bytes read.')
- self._stream.SkipBytes(num_bytes_used)
- field_number, field_type = self.ReadFieldNumberAndWireType()
- if field_type != wire_format.WIRETYPE_END_GROUP:
- raise message.DecodeError('Group message did not end with an END_GROUP.')
- if field_number != expected_field_number:
- raise message.DecodeError('END_GROUP tag had field '
- 'number %d, was expecting field number %d' % (
- field_number, expected_field_number))
- # We're now positioned just after the END_GROUP tag. Perfect.
+SkipField = _FieldSkipper()
diff --git a/python/google/protobuf/internal/decoder_test.py b/python/google/protobuf/internal/decoder_test.py
deleted file mode 100755
index 98e46472..00000000
--- a/python/google/protobuf/internal/decoder_test.py
+++ /dev/null
@@ -1,256 +0,0 @@
-#! /usr/bin/python
-#
-# Protocol Buffers - Google's data interchange format
-# Copyright 2008 Google Inc. All rights reserved.
-# http://code.google.com/p/protobuf/
-#
-# Redistribution and use in source and binary forms, with or without
-# modification, are permitted provided that the following conditions are
-# met:
-#
-# * Redistributions of source code must retain the above copyright
-# notice, this list of conditions and the following disclaimer.
-# * Redistributions in binary form must reproduce the above
-# copyright notice, this list of conditions and the following disclaimer
-# in the documentation and/or other materials provided with the
-# distribution.
-# * Neither the name of Google Inc. nor the names of its
-# contributors may be used to endorse or promote products derived from
-# this software without specific prior written permission.
-#
-# THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS
-# "AS IS" AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT
-# LIMITED TO, THE IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR
-# A PARTICULAR PURPOSE ARE DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT
-# OWNER OR CONTRIBUTORS BE LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL,
-# SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT
-# LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE,
-# DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY
-# THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT
-# (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
-# OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
-
-"""Test for google.protobuf.internal.decoder."""
-
-__author__ = 'robinson@google.com (Will Robinson)'
-
-import struct
-import unittest
-from google.protobuf.internal import decoder
-from google.protobuf.internal import encoder
-from google.protobuf.internal import input_stream
-from google.protobuf.internal import wire_format
-from google.protobuf import message
-import logging
-import mox
-
-
-class DecoderTest(unittest.TestCase):
-
- def setUp(self):
- self.mox = mox.Mox()
- self.mock_stream = self.mox.CreateMock(input_stream.InputStream)
- self.mock_message = self.mox.CreateMock(message.Message)
-
- def testReadFieldNumberAndWireType(self):
- # Test field numbers that will require various varint sizes.
- for expected_field_number in (1, 15, 16, 2047, 2048):
- for expected_wire_type in range(6): # Highest-numbered wiretype is 5.
- e = encoder.Encoder()
- e.AppendTag(expected_field_number, expected_wire_type)
- s = e.ToString()
- d = decoder.Decoder(s)
- field_number, wire_type = d.ReadFieldNumberAndWireType()
- self.assertEqual(expected_field_number, field_number)
- self.assertEqual(expected_wire_type, wire_type)
-
- def ReadScalarTestHelper(self, test_name, decoder_method, expected_result,
- expected_stream_method_name,
- stream_method_return, *args):
- """Helper for testReadScalars below.
-
- Calls one of the Decoder.Read*() methods and ensures that the results are
- as expected.
-
- Args:
- test_name: Name of this test, used for logging only.
- decoder_method: Unbound decoder.Decoder method to call.
- expected_result: Value we expect returned from decoder_method().
- expected_stream_method_name: (string) Name of the InputStream
- method we expect Decoder to call to actually read the value
- on the wire.
- stream_method_return: Value our mocked-out stream method should
- return to the decoder.
- args: Additional arguments that we expect to be passed to the
- stream method.
- """
- logging.info('Testing %s scalar input.\n'
- 'Calling %r(), and expecting that to call the '
- 'stream method %s(%r), which will return %r. Finally, '
- 'expecting the Decoder method to return %r'% (
- test_name, decoder_method,
- expected_stream_method_name, args, stream_method_return,
- expected_result))
-
- d = decoder.Decoder('')
- d._stream = self.mock_stream
- if decoder_method in (decoder.Decoder.ReadString,
- decoder.Decoder.ReadBytes):
- self.mock_stream.ReadVarUInt32().AndReturn(len(stream_method_return))
- # We have to use names instead of methods to work around some
- # mox weirdness. (ResetAll() is overzealous).
- expected_stream_method = getattr(self.mock_stream,
- expected_stream_method_name)
- expected_stream_method(*args).AndReturn(stream_method_return)
-
- self.mox.ReplayAll()
- result = decoder_method(d)
- self.assertEqual(expected_result, result)
- self.assert_(isinstance(result, type(expected_result)))
- self.mox.VerifyAll()
- self.mox.ResetAll()
-
- VAL = 1.125 # Perfectly representable as a float (no rounding error).
- LITTLE_FLOAT_VAL = '\x00\x00\x90?'
- LITTLE_DOUBLE_VAL = '\x00\x00\x00\x00\x00\x00\xf2?'
-
- def testReadScalars(self):
- test_string = 'I can feel myself getting sutpider.'
- scalar_tests = [
- ['int32', decoder.Decoder.ReadInt32, 0, 'ReadVarint32', 0],
- ['int64', decoder.Decoder.ReadInt64, 0, 'ReadVarint64', 0],
- ['uint32', decoder.Decoder.ReadUInt32, 0, 'ReadVarUInt32', 0],
- ['uint64', decoder.Decoder.ReadUInt64, 0, 'ReadVarUInt64', 0],
- ['fixed32', decoder.Decoder.ReadFixed32, 0xffffffff,
- 'ReadLittleEndian32', 0xffffffff],
- ['fixed64', decoder.Decoder.ReadFixed64, 0xffffffffffffffff,
- 'ReadLittleEndian64', 0xffffffffffffffff],
- ['sfixed32', decoder.Decoder.ReadSFixed32, long(-1),
- 'ReadLittleEndian32', long(0xffffffff)],
- ['sfixed64', decoder.Decoder.ReadSFixed64, long(-1),
- 'ReadLittleEndian64', 0xffffffffffffffff],
- ['float', decoder.Decoder.ReadFloat, self.VAL,
- 'ReadBytes', self.LITTLE_FLOAT_VAL, 4],
- ['double', decoder.Decoder.ReadDouble, self.VAL,
- 'ReadBytes', self.LITTLE_DOUBLE_VAL, 8],
- ['bool', decoder.Decoder.ReadBool, True, 'ReadVarUInt32', 1],
- ['enum', decoder.Decoder.ReadEnum, 23, 'ReadVarUInt32', 23],
- ['string', decoder.Decoder.ReadString,
- unicode(test_string, 'utf-8'), 'ReadBytes', test_string,
- len(test_string)],
- ['utf8-string', decoder.Decoder.ReadString,
- unicode(test_string, 'utf-8'), 'ReadBytes', test_string,
- len(test_string)],
- ['bytes', decoder.Decoder.ReadBytes,
- test_string, 'ReadBytes', test_string, len(test_string)],
- # We test zigzag decoding routines more extensively below.
- ['sint32', decoder.Decoder.ReadSInt32, -1, 'ReadVarUInt32', 1],
- ['sint64', decoder.Decoder.ReadSInt64, -1, 'ReadVarUInt64', 1],
- ]
- # Ensure that we're testing different Decoder methods and using
- # different test names in all test cases above.
- self.assertEqual(len(scalar_tests), len(set(t[0] for t in scalar_tests)))
- self.assert_(len(scalar_tests) >= len(set(t[1] for t in scalar_tests)))
- for args in scalar_tests:
- self.ReadScalarTestHelper(*args)
-
- def testReadMessageInto(self):
- length = 23
- def Test(simulate_error):
- d = decoder.Decoder('')
- d._stream = self.mock_stream
- self.mock_stream.ReadVarUInt32().AndReturn(length)
- sub_buffer = object()
- self.mock_stream.GetSubBuffer(length).AndReturn(sub_buffer)
-
- if simulate_error:
- self.mock_message.MergeFromString(sub_buffer).AndReturn(length - 1)
- self.mox.ReplayAll()
- self.assertRaises(
- message.DecodeError, d.ReadMessageInto, self.mock_message)
- else:
- self.mock_message.MergeFromString(sub_buffer).AndReturn(length)
- self.mock_stream.SkipBytes(length)
- self.mox.ReplayAll()
- d.ReadMessageInto(self.mock_message)
-
- self.mox.VerifyAll()
- self.mox.ResetAll()
-
- Test(simulate_error=False)
- Test(simulate_error=True)
-
- def testReadGroupInto_Success(self):
- # Test both the empty and nonempty cases.
- for num_bytes in (5, 0):
- field_number = expected_field_number = 10
- d = decoder.Decoder('')
- d._stream = self.mock_stream
- sub_buffer = object()
- self.mock_stream.GetSubBuffer().AndReturn(sub_buffer)
- self.mock_message.MergeFromString(sub_buffer).AndReturn(num_bytes)
- self.mock_stream.SkipBytes(num_bytes)
- self.mock_stream.ReadVarUInt32().AndReturn(wire_format.PackTag(
- field_number, wire_format.WIRETYPE_END_GROUP))
- self.mox.ReplayAll()
- d.ReadGroupInto(expected_field_number, self.mock_message)
- self.mox.VerifyAll()
- self.mox.ResetAll()
-
- def ReadGroupInto_FailureTestHelper(self, bytes_read):
- d = decoder.Decoder('')
- d._stream = self.mock_stream
- sub_buffer = object()
- self.mock_stream.GetSubBuffer().AndReturn(sub_buffer)
- self.mock_message.MergeFromString(sub_buffer).AndReturn(bytes_read)
- return d
-
- def testReadGroupInto_NegativeBytesReported(self):
- expected_field_number = 10
- d = self.ReadGroupInto_FailureTestHelper(bytes_read=-1)
- self.mox.ReplayAll()
- self.assertRaises(message.DecodeError,
- d.ReadGroupInto, expected_field_number,
- self.mock_message)
- self.mox.VerifyAll()
-
- def testReadGroupInto_NoEndGroupTag(self):
- field_number = expected_field_number = 10
- num_bytes = 5
- d = self.ReadGroupInto_FailureTestHelper(bytes_read=num_bytes)
- self.mock_stream.SkipBytes(num_bytes)
- # Right field number, wrong wire type.
- self.mock_stream.ReadVarUInt32().AndReturn(wire_format.PackTag(
- field_number, wire_format.WIRETYPE_LENGTH_DELIMITED))
- self.mox.ReplayAll()
- self.assertRaises(message.DecodeError,
- d.ReadGroupInto, expected_field_number,
- self.mock_message)
- self.mox.VerifyAll()
-
- def testReadGroupInto_WrongFieldNumberInEndGroupTag(self):
- expected_field_number = 10
- field_number = expected_field_number + 1
- num_bytes = 5
- d = self.ReadGroupInto_FailureTestHelper(bytes_read=num_bytes)
- self.mock_stream.SkipBytes(num_bytes)
- # Wrong field number, right wire type.
- self.mock_stream.ReadVarUInt32().AndReturn(wire_format.PackTag(
- field_number, wire_format.WIRETYPE_END_GROUP))
- self.mox.ReplayAll()
- self.assertRaises(message.DecodeError,
- d.ReadGroupInto, expected_field_number,
- self.mock_message)
- self.mox.VerifyAll()
-
- def testSkipBytes(self):
- d = decoder.Decoder('')
- num_bytes = 1024
- self.mock_stream.SkipBytes(num_bytes)
- d._stream = self.mock_stream
- self.mox.ReplayAll()
- d.SkipBytes(num_bytes)
- self.mox.VerifyAll()
-
-if __name__ == '__main__':
- unittest.main()
diff --git a/python/google/protobuf/internal/descriptor_test.py b/python/google/protobuf/internal/descriptor_test.py
index eb9f2be8..05c27452 100755
--- a/python/google/protobuf/internal/descriptor_test.py
+++ b/python/google/protobuf/internal/descriptor_test.py
@@ -35,16 +35,30 @@
__author__ = 'robinson@google.com (Will Robinson)'
import unittest
+from google.protobuf import unittest_import_pb2
+from google.protobuf import unittest_pb2
from google.protobuf import descriptor_pb2
from google.protobuf import descriptor
+from google.protobuf import text_format
+
+
+TEST_EMPTY_MESSAGE_DESCRIPTOR_ASCII = """
+name: 'TestEmptyMessage'
+"""
+
class DescriptorTest(unittest.TestCase):
def setUp(self):
+ self.my_file = descriptor.FileDescriptor(
+ name='some/filename/some.proto',
+ package='protobuf_unittest'
+ )
self.my_enum = descriptor.EnumDescriptor(
name='ForeignEnum',
full_name='protobuf_unittest.ForeignEnum',
- filename='ForeignEnum',
+ filename=None,
+ file=self.my_file,
values=[
descriptor.EnumValueDescriptor(name='FOREIGN_FOO', index=0, number=4),
descriptor.EnumValueDescriptor(name='FOREIGN_BAR', index=1, number=5),
@@ -53,7 +67,8 @@ class DescriptorTest(unittest.TestCase):
self.my_message = descriptor.Descriptor(
name='NestedMessage',
full_name='protobuf_unittest.TestAllTypes.NestedMessage',
- filename='some/filename/some.proto',
+ filename=None,
+ file=self.my_file,
containing_type=None,
fields=[
descriptor.FieldDescriptor(
@@ -61,7 +76,7 @@ class DescriptorTest(unittest.TestCase):
full_name='protobuf_unittest.TestAllTypes.NestedMessage.bb',
index=0, number=1,
type=5, cpp_type=1, label=1,
- default_value=0,
+ has_default_value=False, default_value=0,
message_type=None, enum_type=None, containing_type=None,
is_extension=False, extension_scope=None),
],
@@ -80,6 +95,7 @@ class DescriptorTest(unittest.TestCase):
self.my_service = descriptor.ServiceDescriptor(
name='TestServiceWithOptions',
full_name='protobuf_unittest.TestServiceWithOptions',
+ file=self.my_file,
index=0,
methods=[
self.my_method
@@ -109,5 +125,210 @@ class DescriptorTest(unittest.TestCase):
self.assertEqual(self.my_service.GetOptions(),
descriptor_pb2.ServiceOptions())
+ def testFileDescriptorReferences(self):
+ self.assertEqual(self.my_enum.file, self.my_file)
+ self.assertEqual(self.my_message.file, self.my_file)
+
+ def testFileDescriptor(self):
+ self.assertEqual(self.my_file.name, 'some/filename/some.proto')
+ self.assertEqual(self.my_file.package, 'protobuf_unittest')
+
+
+class DescriptorCopyToProtoTest(unittest.TestCase):
+ """Tests for CopyTo functions of Descriptor."""
+
+ def _AssertProtoEqual(self, actual_proto, expected_class, expected_ascii):
+ expected_proto = expected_class()
+ text_format.Merge(expected_ascii, expected_proto)
+
+ self.assertEqual(
+ actual_proto, expected_proto,
+ 'Not equal,\nActual:\n%s\nExpected:\n%s\n'
+ % (str(actual_proto), str(expected_proto)))
+
+ def _InternalTestCopyToProto(self, desc, expected_proto_class,
+ expected_proto_ascii):
+ actual = expected_proto_class()
+ desc.CopyToProto(actual)
+ self._AssertProtoEqual(
+ actual, expected_proto_class, expected_proto_ascii)
+
+ def testCopyToProto_EmptyMessage(self):
+ self._InternalTestCopyToProto(
+ unittest_pb2.TestEmptyMessage.DESCRIPTOR,
+ descriptor_pb2.DescriptorProto,
+ TEST_EMPTY_MESSAGE_DESCRIPTOR_ASCII)
+
+ def testCopyToProto_NestedMessage(self):
+ TEST_NESTED_MESSAGE_ASCII = """
+ name: 'NestedMessage'
+ field: <
+ name: 'bb'
+ number: 1
+ label: 1 # Optional
+ type: 5 # TYPE_INT32
+ >
+ """
+
+ self._InternalTestCopyToProto(
+ unittest_pb2.TestAllTypes.NestedMessage.DESCRIPTOR,
+ descriptor_pb2.DescriptorProto,
+ TEST_NESTED_MESSAGE_ASCII)
+
+ def testCopyToProto_ForeignNestedMessage(self):
+ TEST_FOREIGN_NESTED_ASCII = """
+ name: 'TestForeignNested'
+ field: <
+ name: 'foreign_nested'
+ number: 1
+ label: 1 # Optional
+ type: 11 # TYPE_MESSAGE
+ type_name: '.protobuf_unittest.TestAllTypes.NestedMessage'
+ >
+ """
+
+ self._InternalTestCopyToProto(
+ unittest_pb2.TestForeignNested.DESCRIPTOR,
+ descriptor_pb2.DescriptorProto,
+ TEST_FOREIGN_NESTED_ASCII)
+
+ def testCopyToProto_ForeignEnum(self):
+ TEST_FOREIGN_ENUM_ASCII = """
+ name: 'ForeignEnum'
+ value: <
+ name: 'FOREIGN_FOO'
+ number: 4
+ >
+ value: <
+ name: 'FOREIGN_BAR'
+ number: 5
+ >
+ value: <
+ name: 'FOREIGN_BAZ'
+ number: 6
+ >
+ """
+
+ self._InternalTestCopyToProto(
+ unittest_pb2._FOREIGNENUM,
+ descriptor_pb2.EnumDescriptorProto,
+ TEST_FOREIGN_ENUM_ASCII)
+
+ def testCopyToProto_Options(self):
+ TEST_DEPRECATED_FIELDS_ASCII = """
+ name: 'TestDeprecatedFields'
+ field: <
+ name: 'deprecated_int32'
+ number: 1
+ label: 1 # Optional
+ type: 5 # TYPE_INT32
+ options: <
+ deprecated: true
+ >
+ >
+ """
+
+ self._InternalTestCopyToProto(
+ unittest_pb2.TestDeprecatedFields.DESCRIPTOR,
+ descriptor_pb2.DescriptorProto,
+ TEST_DEPRECATED_FIELDS_ASCII)
+
+ def testCopyToProto_AllExtensions(self):
+ TEST_EMPTY_MESSAGE_WITH_EXTENSIONS_ASCII = """
+ name: 'TestEmptyMessageWithExtensions'
+ extension_range: <
+ start: 1
+ end: 536870912
+ >
+ """
+
+ self._InternalTestCopyToProto(
+ unittest_pb2.TestEmptyMessageWithExtensions.DESCRIPTOR,
+ descriptor_pb2.DescriptorProto,
+ TEST_EMPTY_MESSAGE_WITH_EXTENSIONS_ASCII)
+
+ def testCopyToProto_SeveralExtensions(self):
+ TEST_MESSAGE_WITH_SEVERAL_EXTENSIONS_ASCII = """
+ name: 'TestMultipleExtensionRanges'
+ extension_range: <
+ start: 42
+ end: 43
+ >
+ extension_range: <
+ start: 4143
+ end: 4244
+ >
+ extension_range: <
+ start: 65536
+ end: 536870912
+ >
+ """
+
+ self._InternalTestCopyToProto(
+ unittest_pb2.TestMultipleExtensionRanges.DESCRIPTOR,
+ descriptor_pb2.DescriptorProto,
+ TEST_MESSAGE_WITH_SEVERAL_EXTENSIONS_ASCII)
+
+ def testCopyToProto_FileDescriptor(self):
+ UNITTEST_IMPORT_FILE_DESCRIPTOR_ASCII = ("""
+ name: 'google/protobuf/unittest_import.proto'
+ package: 'protobuf_unittest_import'
+ message_type: <
+ name: 'ImportMessage'
+ field: <
+ name: 'd'
+ number: 1
+ label: 1 # Optional
+ type: 5 # TYPE_INT32
+ >
+ >
+ """ +
+ """enum_type: <
+ name: 'ImportEnum'
+ value: <
+ name: 'IMPORT_FOO'
+ number: 7
+ >
+ value: <
+ name: 'IMPORT_BAR'
+ number: 8
+ >
+ value: <
+ name: 'IMPORT_BAZ'
+ number: 9
+ >
+ >
+ options: <
+ java_package: 'com.google.protobuf.test'
+ optimize_for: 1 # SPEED
+ >
+ """)
+
+ self._InternalTestCopyToProto(
+ unittest_import_pb2.DESCRIPTOR,
+ descriptor_pb2.FileDescriptorProto,
+ UNITTEST_IMPORT_FILE_DESCRIPTOR_ASCII)
+
+ def testCopyToProto_ServiceDescriptor(self):
+ TEST_SERVICE_ASCII = """
+ name: 'TestService'
+ method: <
+ name: 'Foo'
+ input_type: '.protobuf_unittest.FooRequest'
+ output_type: '.protobuf_unittest.FooResponse'
+ >
+ method: <
+ name: 'Bar'
+ input_type: '.protobuf_unittest.BarRequest'
+ output_type: '.protobuf_unittest.BarResponse'
+ >
+ """
+
+ self._InternalTestCopyToProto(
+ unittest_pb2.TestService.DESCRIPTOR,
+ descriptor_pb2.ServiceDescriptorProto,
+ TEST_SERVICE_ASCII)
+
+
if __name__ == '__main__':
unittest.main()
diff --git a/python/google/protobuf/internal/encoder.py b/python/google/protobuf/internal/encoder.py
index 3ec3b2b1..aa05d5b3 100755
--- a/python/google/protobuf/internal/encoder.py
+++ b/python/google/protobuf/internal/encoder.py
@@ -28,253 +28,659 @@
# (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
# OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
-"""Class for encoding protocol message primitives.
+"""Code for encoding protocol message primitives.
Contains the logic for encoding every logical protocol field type
into one of the 5 physical wire types.
+
+This code is designed to push the Python interpreter's performance to the
+limits.
+
+The basic idea is that at startup time, for every field (i.e. every
+FieldDescriptor) we construct two functions: a "sizer" and an "encoder". The
+sizer takes a value of this field's type and computes its byte size. The
+encoder takes a writer function and a value. It encodes the value into byte
+strings and invokes the writer function to write those strings. Typically the
+writer function is the write() method of a cStringIO.
+
+We try to do as much work as possible when constructing the encoder and the
+sizer rather than when calling them. In particular:
+* We copy any needed global functions to local variables, so that we do not need
+ to do costly global table lookups at runtime.
+* Similarly, we try to do any attribute lookups at startup time if possible.
+* Every field's tag is encoded to bytes at startup, since it can't change at
+ runtime.
+* Whatever component of the field size we can compute at startup, we do.
+* We *avoid* sharing code if doing so would make the code slower and not sharing
+ does not burden us too much. For example, encoders for repeated fields do
+ not just call the encoders for singular fields in a loop because this would
+ add an extra function call overhead for every loop iteration; instead, we
+ manually inline the single-value encoder into the loop.
+* If a Python function lacks a return statement, Python actually generates
+ instructions to pop the result of the last statement off the stack, push
+ None onto the stack, and then return that. If we really don't care what
+ value is returned, then we can save two instructions by returning the
+ result of the last statement. It looks funny but it helps.
+* We assume that type and bounds checking has happened at a higher level.
"""
-__author__ = 'robinson@google.com (Will Robinson)'
+__author__ = 'kenton@google.com (Kenton Varda)'
import struct
-from google.protobuf import message
from google.protobuf.internal import wire_format
-from google.protobuf.internal import output_stream
-
-
-# Note that much of this code is ported from //net/proto/ProtocolBuffer, and
-# that the interface is strongly inspired by WireFormat from the C++ proto2
-# implementation.
-
-
-class Encoder(object):
-
- """Encodes logical protocol buffer fields to the wire format."""
-
- def __init__(self):
- self._stream = output_stream.OutputStream()
-
- def ToString(self):
- """Returns all values encoded in this object as a string."""
- return self._stream.ToString()
-
- # Append*NoTag methods. These are necessary for serializing packed
- # repeated fields. The Append*() methods call these methods to do
- # the actual serialization.
- def AppendInt32NoTag(self, value):
- """Appends a 32-bit integer to our buffer, varint-encoded."""
- self._stream.AppendVarint32(value)
-
- def AppendInt64NoTag(self, value):
- """Appends a 64-bit integer to our buffer, varint-encoded."""
- self._stream.AppendVarint64(value)
-
- def AppendUInt32NoTag(self, unsigned_value):
- """Appends an unsigned 32-bit integer to our buffer, varint-encoded."""
- self._stream.AppendVarUInt32(unsigned_value)
-
- def AppendUInt64NoTag(self, unsigned_value):
- """Appends an unsigned 64-bit integer to our buffer, varint-encoded."""
- self._stream.AppendVarUInt64(unsigned_value)
-
- def AppendSInt32NoTag(self, value):
- """Appends a 32-bit integer to our buffer, zigzag-encoded and then
- varint-encoded.
- """
- zigzag_value = wire_format.ZigZagEncode(value)
- self._stream.AppendVarUInt32(zigzag_value)
-
- def AppendSInt64NoTag(self, value):
- """Appends a 64-bit integer to our buffer, zigzag-encoded and then
- varint-encoded.
- """
- zigzag_value = wire_format.ZigZagEncode(value)
- self._stream.AppendVarUInt64(zigzag_value)
-
- def AppendFixed32NoTag(self, unsigned_value):
- """Appends an unsigned 32-bit integer to our buffer, in little-endian
- byte-order.
- """
- self._stream.AppendLittleEndian32(unsigned_value)
-
- def AppendFixed64NoTag(self, unsigned_value):
- """Appends an unsigned 64-bit integer to our buffer, in little-endian
- byte-order.
- """
- self._stream.AppendLittleEndian64(unsigned_value)
-
- def AppendSFixed32NoTag(self, value):
- """Appends a signed 32-bit integer to our buffer, in little-endian
- byte-order.
- """
- sign = (value & 0x80000000) and -1 or 0
- if value >> 32 != sign:
- raise message.EncodeError('SFixed32 out of range: %d' % value)
- self._stream.AppendLittleEndian32(value & 0xffffffff)
-
- def AppendSFixed64NoTag(self, value):
- """Appends a signed 64-bit integer to our buffer, in little-endian
- byte-order.
- """
- sign = (value & 0x8000000000000000) and -1 or 0
- if value >> 64 != sign:
- raise message.EncodeError('SFixed64 out of range: %d' % value)
- self._stream.AppendLittleEndian64(value & 0xffffffffffffffff)
-
- def AppendFloatNoTag(self, value):
- """Appends a floating-point number to our buffer."""
- self._stream.AppendRawBytes(
- struct.pack(wire_format.FORMAT_FLOAT_LITTLE_ENDIAN, value))
-
- def AppendDoubleNoTag(self, value):
- """Appends a double-precision floating-point number to our buffer."""
- self._stream.AppendRawBytes(
- struct.pack(wire_format.FORMAT_DOUBLE_LITTLE_ENDIAN, value))
-
- def AppendBoolNoTag(self, value):
- """Appends a boolean to our buffer."""
- self.AppendInt32NoTag(value)
-
- def AppendEnumNoTag(self, value):
- """Appends an enum value to our buffer."""
- self.AppendInt32NoTag(value)
-
-
- # All the Append*() methods below first append a tag+type pair to the buffer
- # before appending the specified value.
-
- def AppendInt32(self, field_number, value):
- """Appends a 32-bit integer to our buffer, varint-encoded."""
- self.AppendTag(field_number, wire_format.WIRETYPE_VARINT)
- self.AppendInt32NoTag(value)
-
- def AppendInt64(self, field_number, value):
- """Appends a 64-bit integer to our buffer, varint-encoded."""
- self.AppendTag(field_number, wire_format.WIRETYPE_VARINT)
- self.AppendInt64NoTag(value)
-
- def AppendUInt32(self, field_number, unsigned_value):
- """Appends an unsigned 32-bit integer to our buffer, varint-encoded."""
- self.AppendTag(field_number, wire_format.WIRETYPE_VARINT)
- self.AppendUInt32NoTag(unsigned_value)
-
- def AppendUInt64(self, field_number, unsigned_value):
- """Appends an unsigned 64-bit integer to our buffer, varint-encoded."""
- self.AppendTag(field_number, wire_format.WIRETYPE_VARINT)
- self.AppendUInt64NoTag(unsigned_value)
-
- def AppendSInt32(self, field_number, value):
- """Appends a 32-bit integer to our buffer, zigzag-encoded and then
- varint-encoded.
- """
- self.AppendTag(field_number, wire_format.WIRETYPE_VARINT)
- self.AppendSInt32NoTag(value)
-
- def AppendSInt64(self, field_number, value):
- """Appends a 64-bit integer to our buffer, zigzag-encoded and then
- varint-encoded.
- """
- self.AppendTag(field_number, wire_format.WIRETYPE_VARINT)
- self.AppendSInt64NoTag(value)
-
- def AppendFixed32(self, field_number, unsigned_value):
- """Appends an unsigned 32-bit integer to our buffer, in little-endian
- byte-order.
- """
- self.AppendTag(field_number, wire_format.WIRETYPE_FIXED32)
- self.AppendFixed32NoTag(unsigned_value)
-
- def AppendFixed64(self, field_number, unsigned_value):
- """Appends an unsigned 64-bit integer to our buffer, in little-endian
- byte-order.
- """
- self.AppendTag(field_number, wire_format.WIRETYPE_FIXED64)
- self.AppendFixed64NoTag(unsigned_value)
-
- def AppendSFixed32(self, field_number, value):
- """Appends a signed 32-bit integer to our buffer, in little-endian
- byte-order.
- """
- self.AppendTag(field_number, wire_format.WIRETYPE_FIXED32)
- self.AppendSFixed32NoTag(value)
-
- def AppendSFixed64(self, field_number, value):
- """Appends a signed 64-bit integer to our buffer, in little-endian
- byte-order.
- """
- self.AppendTag(field_number, wire_format.WIRETYPE_FIXED64)
- self.AppendSFixed64NoTag(value)
-
- def AppendFloat(self, field_number, value):
- """Appends a floating-point number to our buffer."""
- self.AppendTag(field_number, wire_format.WIRETYPE_FIXED32)
- self.AppendFloatNoTag(value)
-
- def AppendDouble(self, field_number, value):
- """Appends a double-precision floating-point number to our buffer."""
- self.AppendTag(field_number, wire_format.WIRETYPE_FIXED64)
- self.AppendDoubleNoTag(value)
-
- def AppendBool(self, field_number, value):
- """Appends a boolean to our buffer."""
- self.AppendInt32(field_number, value)
-
- def AppendEnum(self, field_number, value):
- """Appends an enum value to our buffer."""
- self.AppendInt32(field_number, value)
-
- def AppendString(self, field_number, value):
- """Appends a length-prefixed unicode string, encoded as UTF-8 to our buffer,
- with the length varint-encoded.
- """
- self.AppendBytes(field_number, value.encode('utf-8'))
-
- def AppendBytes(self, field_number, value):
- """Appends a length-prefixed sequence of bytes to our buffer, with the
- length varint-encoded.
- """
- self.AppendTag(field_number, wire_format.WIRETYPE_LENGTH_DELIMITED)
- self._stream.AppendVarUInt32(len(value))
- self._stream.AppendRawBytes(value)
-
- # TODO(robinson): For AppendGroup() and AppendMessage(), we'd really like to
- # avoid the extra string copy here. We can do so if we widen the Message
- # interface to be able to serialize to a stream in addition to a string. The
- # challenge when thinking ahead to the Python/C API implementation of Message
- # is finding a stream-like Python thing to which we can write raw bytes
- # from C. I'm not sure such a thing exists(?). (array.array is pretty much
- # what we want, but it's not directly exposed in the Python/C API).
-
- def AppendGroup(self, field_number, group):
- """Appends a group to our buffer.
- """
- self.AppendTag(field_number, wire_format.WIRETYPE_START_GROUP)
- self._stream.AppendRawBytes(group.SerializeToString())
- self.AppendTag(field_number, wire_format.WIRETYPE_END_GROUP)
-
- def AppendMessage(self, field_number, msg):
- """Appends a nested message to our buffer.
- """
- self.AppendTag(field_number, wire_format.WIRETYPE_LENGTH_DELIMITED)
- self._stream.AppendVarUInt32(msg.ByteSize())
- self._stream.AppendRawBytes(msg.SerializeToString())
-
- def AppendMessageSetItem(self, field_number, msg):
- """Appends an item using the message set wire format.
-
- The message set message looks like this:
- message MessageSet {
- repeated group Item = 1 {
- required int32 type_id = 2;
- required string message = 3;
- }
+
+
+def _VarintSize(value):
+ """Compute the size of a varint value."""
+ if value <= 0x7f: return 1
+ if value <= 0x3fff: return 2
+ if value <= 0x1fffff: return 3
+ if value <= 0xfffffff: return 4
+ if value <= 0x7ffffffff: return 5
+ if value <= 0x3ffffffffff: return 6
+ if value <= 0x1ffffffffffff: return 7
+ if value <= 0xffffffffffffff: return 8
+ if value <= 0x7fffffffffffffff: return 9
+ return 10
+
+
+def _SignedVarintSize(value):
+ """Compute the size of a signed varint value."""
+ if value < 0: return 10
+ if value <= 0x7f: return 1
+ if value <= 0x3fff: return 2
+ if value <= 0x1fffff: return 3
+ if value <= 0xfffffff: return 4
+ if value <= 0x7ffffffff: return 5
+ if value <= 0x3ffffffffff: return 6
+ if value <= 0x1ffffffffffff: return 7
+ if value <= 0xffffffffffffff: return 8
+ if value <= 0x7fffffffffffffff: return 9
+ return 10
+
+
+def _TagSize(field_number):
+ """Returns the number of bytes required to serialize a tag with this field
+ number."""
+ # Just pass in type 0, since the type won't affect the tag+type size.
+ return _VarintSize(wire_format.PackTag(field_number, 0))
+
+
+# --------------------------------------------------------------------
+# In this section we define some generic sizers. Each of these functions
+# takes parameters specific to a particular field type, e.g. int32 or fixed64.
+# It returns another function which in turn takes parameters specific to a
+# particular field, e.g. the field number and whether it is repeated or packed.
+# Look at the next section to see how these are used.
+
+
+def _SimpleSizer(compute_value_size):
+ """Return a constructor for a sizer which uses the function
+ compute_value_size to compute the size of each value. Typically
+ compute_value_size is _VarintSize."""
+
+ def SpecificSizer(field_number, is_repeated, is_packed):
+ tag_size = _TagSize(field_number)
+ if is_packed:
+ local_VarintSize = _VarintSize
+ def PackedFieldSize(value):
+ result = 0
+ for element in value:
+ result += compute_value_size(element)
+ return result + local_VarintSize(result) + tag_size
+ return PackedFieldSize
+ elif is_repeated:
+ def RepeatedFieldSize(value):
+ result = tag_size * len(value)
+ for element in value:
+ result += compute_value_size(element)
+ return result
+ return RepeatedFieldSize
+ else:
+ def FieldSize(value):
+ return tag_size + compute_value_size(value)
+ return FieldSize
+
+ return SpecificSizer
+
+
+def _ModifiedSizer(compute_value_size, modify_value):
+ """Like _SimpleSizer, but modify_value is invoked on each value before it is
+ passed to compute_value_size. modify_value is typically ZigZagEncode."""
+
+ def SpecificSizer(field_number, is_repeated, is_packed):
+ tag_size = _TagSize(field_number)
+ if is_packed:
+ local_VarintSize = _VarintSize
+ def PackedFieldSize(value):
+ result = 0
+ for element in value:
+ result += compute_value_size(modify_value(element))
+ return result + local_VarintSize(result) + tag_size
+ return PackedFieldSize
+ elif is_repeated:
+ def RepeatedFieldSize(value):
+ result = tag_size * len(value)
+ for element in value:
+ result += compute_value_size(modify_value(element))
+ return result
+ return RepeatedFieldSize
+ else:
+ def FieldSize(value):
+ return tag_size + compute_value_size(modify_value(value))
+ return FieldSize
+
+ return SpecificSizer
+
+
+def _FixedSizer(value_size):
+ """Like _SimpleSizer except for a fixed-size field. The input is the size
+ of one value."""
+
+ def SpecificSizer(field_number, is_repeated, is_packed):
+ tag_size = _TagSize(field_number)
+ if is_packed:
+ local_VarintSize = _VarintSize
+ def PackedFieldSize(value):
+ result = len(value) * value_size
+ return result + local_VarintSize(result) + tag_size
+ return PackedFieldSize
+ elif is_repeated:
+ element_size = value_size + tag_size
+ def RepeatedFieldSize(value):
+ return len(value) * element_size
+ return RepeatedFieldSize
+ else:
+ field_size = value_size + tag_size
+ def FieldSize(value):
+ return field_size
+ return FieldSize
+
+ return SpecificSizer
+
+
+# ====================================================================
+# Here we declare a sizer constructor for each field type. Each "sizer
+# constructor" is a function that takes (field_number, is_repeated, is_packed)
+# as parameters and returns a sizer, which in turn takes a field value as
+# a parameter and returns its encoded size.
+
+
+Int32Sizer = Int64Sizer = EnumSizer = _SimpleSizer(_SignedVarintSize)
+
+UInt32Sizer = UInt64Sizer = _SimpleSizer(_VarintSize)
+
+SInt32Sizer = SInt64Sizer = _ModifiedSizer(
+ _SignedVarintSize, wire_format.ZigZagEncode)
+
+Fixed32Sizer = SFixed32Sizer = FloatSizer = _FixedSizer(4)
+Fixed64Sizer = SFixed64Sizer = DoubleSizer = _FixedSizer(8)
+
+BoolSizer = _FixedSizer(1)
+
+
+def StringSizer(field_number, is_repeated, is_packed):
+ """Returns a sizer for a string field."""
+
+ tag_size = _TagSize(field_number)
+ local_VarintSize = _VarintSize
+ local_len = len
+ assert not is_packed
+ if is_repeated:
+ def RepeatedFieldSize(value):
+ result = tag_size * len(value)
+ for element in value:
+ l = local_len(element.encode('utf-8'))
+ result += local_VarintSize(l) + l
+ return result
+ return RepeatedFieldSize
+ else:
+ def FieldSize(value):
+ l = local_len(value.encode('utf-8'))
+ return tag_size + local_VarintSize(l) + l
+ return FieldSize
+
+
+def BytesSizer(field_number, is_repeated, is_packed):
+ """Returns a sizer for a bytes field."""
+
+ tag_size = _TagSize(field_number)
+ local_VarintSize = _VarintSize
+ local_len = len
+ assert not is_packed
+ if is_repeated:
+ def RepeatedFieldSize(value):
+ result = tag_size * len(value)
+ for element in value:
+ l = local_len(element)
+ result += local_VarintSize(l) + l
+ return result
+ return RepeatedFieldSize
+ else:
+ def FieldSize(value):
+ l = local_len(value)
+ return tag_size + local_VarintSize(l) + l
+ return FieldSize
+
+
+def GroupSizer(field_number, is_repeated, is_packed):
+ """Returns a sizer for a group field."""
+
+ tag_size = _TagSize(field_number) * 2
+ assert not is_packed
+ if is_repeated:
+ def RepeatedFieldSize(value):
+ result = tag_size * len(value)
+ for element in value:
+ result += element.ByteSize()
+ return result
+ return RepeatedFieldSize
+ else:
+ def FieldSize(value):
+ return tag_size + value.ByteSize()
+ return FieldSize
+
+
+def MessageSizer(field_number, is_repeated, is_packed):
+ """Returns a sizer for a message field."""
+
+ tag_size = _TagSize(field_number)
+ local_VarintSize = _VarintSize
+ assert not is_packed
+ if is_repeated:
+ def RepeatedFieldSize(value):
+ result = tag_size * len(value)
+ for element in value:
+ l = element.ByteSize()
+ result += local_VarintSize(l) + l
+ return result
+ return RepeatedFieldSize
+ else:
+ def FieldSize(value):
+ l = value.ByteSize()
+ return tag_size + local_VarintSize(l) + l
+ return FieldSize
+
+
+# --------------------------------------------------------------------
+# MessageSet is special.
+
+
+def MessageSetItemSizer(field_number):
+ """Returns a sizer for extensions of MessageSet.
+
+ The message set message looks like this:
+ message MessageSet {
+ repeated group Item = 1 {
+ required int32 type_id = 2;
+ required string message = 3;
+ }
+ }
+ """
+ static_size = (_TagSize(1) * 2 + _TagSize(2) + _VarintSize(field_number) +
+ _TagSize(3))
+ local_VarintSize = _VarintSize
+
+ def FieldSize(value):
+ l = value.ByteSize()
+ return static_size + local_VarintSize(l) + l
+
+ return FieldSize
+
+
+# ====================================================================
+# Encoders!
+
+
+def _VarintEncoder():
+ """Return an encoder for a basic varint value (does not include tag)."""
+
+ local_chr = chr
+ def EncodeVarint(write, value):
+ bits = value & 0x7f
+ value >>= 7
+ while value:
+ write(local_chr(0x80|bits))
+ bits = value & 0x7f
+ value >>= 7
+ return write(local_chr(bits))
+
+ return EncodeVarint
+
+
+def _SignedVarintEncoder():
+ """Return an encoder for a basic signed varint value (does not include
+ tag)."""
+
+ local_chr = chr
+ def EncodeSignedVarint(write, value):
+ if value < 0:
+ value += (1 << 64)
+ bits = value & 0x7f
+ value >>= 7
+ while value:
+ write(local_chr(0x80|bits))
+ bits = value & 0x7f
+ value >>= 7
+ return write(local_chr(bits))
+
+ return EncodeSignedVarint
+
+
+_EncodeVarint = _VarintEncoder()
+_EncodeSignedVarint = _SignedVarintEncoder()
+
+
+def _VarintBytes(value):
+ """Encode the given integer as a varint and return the bytes. This is only
+ called at startup time so it doesn't need to be fast."""
+
+ pieces = []
+ _EncodeVarint(pieces.append, value)
+ return "".join(pieces)
+
+
+def TagBytes(field_number, wire_type):
+ """Encode the given tag and return the bytes. Only called at startup."""
+
+ return _VarintBytes(wire_format.PackTag(field_number, wire_type))
+
+# --------------------------------------------------------------------
+# As with sizers (see above), we have a number of common encoder
+# implementations.
+
+
+def _SimpleEncoder(wire_type, encode_value, compute_value_size):
+ """Return a constructor for an encoder for fields of a particular type.
+
+ Args:
+ wire_type: The field's wire type, for encoding tags.
+ encode_value: A function which encodes an individual value, e.g.
+ _EncodeVarint().
+ compute_value_size: A function which computes the size of an individual
+ value, e.g. _VarintSize().
+ """
+
+ def SpecificEncoder(field_number, is_repeated, is_packed):
+ if is_packed:
+ tag_bytes = TagBytes(field_number, wire_format.WIRETYPE_LENGTH_DELIMITED)
+ local_EncodeVarint = _EncodeVarint
+ def EncodePackedField(write, value):
+ write(tag_bytes)
+ size = 0
+ for element in value:
+ size += compute_value_size(element)
+ local_EncodeVarint(write, size)
+ for element in value:
+ encode_value(write, element)
+ return EncodePackedField
+ elif is_repeated:
+ tag_bytes = TagBytes(field_number, wire_type)
+ def EncodeRepeatedField(write, value):
+ for element in value:
+ write(tag_bytes)
+ encode_value(write, element)
+ return EncodeRepeatedField
+ else:
+ tag_bytes = TagBytes(field_number, wire_type)
+ def EncodeField(write, value):
+ write(tag_bytes)
+ return encode_value(write, value)
+ return EncodeField
+
+ return SpecificEncoder
+
+
+def _ModifiedEncoder(wire_type, encode_value, compute_value_size, modify_value):
+ """Like _SimpleEncoder but additionally invokes modify_value on every value
+ before passing it to encode_value. Usually modify_value is ZigZagEncode."""
+
+ def SpecificEncoder(field_number, is_repeated, is_packed):
+ if is_packed:
+ tag_bytes = TagBytes(field_number, wire_format.WIRETYPE_LENGTH_DELIMITED)
+ local_EncodeVarint = _EncodeVarint
+ def EncodePackedField(write, value):
+ write(tag_bytes)
+ size = 0
+ for element in value:
+ size += compute_value_size(modify_value(element))
+ local_EncodeVarint(write, size)
+ for element in value:
+ encode_value(write, modify_value(element))
+ return EncodePackedField
+ elif is_repeated:
+ tag_bytes = TagBytes(field_number, wire_type)
+ def EncodeRepeatedField(write, value):
+ for element in value:
+ write(tag_bytes)
+ encode_value(write, modify_value(element))
+ return EncodeRepeatedField
+ else:
+ tag_bytes = TagBytes(field_number, wire_type)
+ def EncodeField(write, value):
+ write(tag_bytes)
+ return encode_value(write, modify_value(value))
+ return EncodeField
+
+ return SpecificEncoder
+
+
+def _StructPackEncoder(wire_type, format):
+ """Return a constructor for an encoder for a fixed-width field.
+
+ Args:
+ wire_type: The field's wire type, for encoding tags.
+ format: The format string to pass to struct.pack().
+ """
+
+ value_size = struct.calcsize(format)
+
+ def SpecificEncoder(field_number, is_repeated, is_packed):
+ local_struct_pack = struct.pack
+ if is_packed:
+ tag_bytes = TagBytes(field_number, wire_format.WIRETYPE_LENGTH_DELIMITED)
+ local_EncodeVarint = _EncodeVarint
+ def EncodePackedField(write, value):
+ write(tag_bytes)
+ local_EncodeVarint(write, len(value) * value_size)
+ for element in value:
+ write(local_struct_pack(format, element))
+ return EncodePackedField
+ elif is_repeated:
+ tag_bytes = TagBytes(field_number, wire_type)
+ def EncodeRepeatedField(write, value):
+ for element in value:
+ write(tag_bytes)
+ write(local_struct_pack(format, element))
+ return EncodeRepeatedField
+ else:
+ tag_bytes = TagBytes(field_number, wire_type)
+ def EncodeField(write, value):
+ write(tag_bytes)
+ return write(local_struct_pack(format, value))
+ return EncodeField
+
+ return SpecificEncoder
+
+
+# ====================================================================
+# Here we declare an encoder constructor for each field type. These work
+# very similarly to sizer constructors, described earlier.
+
+
+Int32Encoder = Int64Encoder = EnumEncoder = _SimpleEncoder(
+ wire_format.WIRETYPE_VARINT, _EncodeSignedVarint, _SignedVarintSize)
+
+UInt32Encoder = UInt64Encoder = _SimpleEncoder(
+ wire_format.WIRETYPE_VARINT, _EncodeVarint, _VarintSize)
+
+SInt32Encoder = SInt64Encoder = _ModifiedEncoder(
+ wire_format.WIRETYPE_VARINT, _EncodeVarint, _VarintSize,
+ wire_format.ZigZagEncode)
+
+# Note that Python conveniently guarantees that when using the '<' prefix on
+# formats, they will also have the same size across all platforms (as opposed
+# to without the prefix, where their sizes depend on the C compiler's basic
+# type sizes).
+Fixed32Encoder = _StructPackEncoder(wire_format.WIRETYPE_FIXED32, '<I')
+Fixed64Encoder = _StructPackEncoder(wire_format.WIRETYPE_FIXED64, '<Q')
+SFixed32Encoder = _StructPackEncoder(wire_format.WIRETYPE_FIXED32, '<i')
+SFixed64Encoder = _StructPackEncoder(wire_format.WIRETYPE_FIXED64, '<q')
+FloatEncoder = _StructPackEncoder(wire_format.WIRETYPE_FIXED32, '<f')
+DoubleEncoder = _StructPackEncoder(wire_format.WIRETYPE_FIXED64, '<d')
+
+
+def BoolEncoder(field_number, is_repeated, is_packed):
+ """Returns an encoder for a boolean field."""
+
+ false_byte = chr(0)
+ true_byte = chr(1)
+ if is_packed:
+ tag_bytes = TagBytes(field_number, wire_format.WIRETYPE_LENGTH_DELIMITED)
+ local_EncodeVarint = _EncodeVarint
+ def EncodePackedField(write, value):
+ write(tag_bytes)
+ local_EncodeVarint(write, len(value))
+ for element in value:
+ if element:
+ write(true_byte)
+ else:
+ write(false_byte)
+ return EncodePackedField
+ elif is_repeated:
+ tag_bytes = TagBytes(field_number, wire_format.WIRETYPE_VARINT)
+ def EncodeRepeatedField(write, value):
+ for element in value:
+ write(tag_bytes)
+ if element:
+ write(true_byte)
+ else:
+ write(false_byte)
+ return EncodeRepeatedField
+ else:
+ tag_bytes = TagBytes(field_number, wire_format.WIRETYPE_VARINT)
+ def EncodeField(write, value):
+ write(tag_bytes)
+ if value:
+ return write(true_byte)
+ return write(false_byte)
+ return EncodeField
+
+
+def StringEncoder(field_number, is_repeated, is_packed):
+ """Returns an encoder for a string field."""
+
+ tag = TagBytes(field_number, wire_format.WIRETYPE_LENGTH_DELIMITED)
+ local_EncodeVarint = _EncodeVarint
+ local_len = len
+ assert not is_packed
+ if is_repeated:
+ def EncodeRepeatedField(write, value):
+ for element in value:
+ encoded = element.encode('utf-8')
+ write(tag)
+ local_EncodeVarint(write, local_len(encoded))
+ write(encoded)
+ return EncodeRepeatedField
+ else:
+ def EncodeField(write, value):
+ encoded = value.encode('utf-8')
+ write(tag)
+ local_EncodeVarint(write, local_len(encoded))
+ return write(encoded)
+ return EncodeField
+
+
+def BytesEncoder(field_number, is_repeated, is_packed):
+ """Returns an encoder for a bytes field."""
+
+ tag = TagBytes(field_number, wire_format.WIRETYPE_LENGTH_DELIMITED)
+ local_EncodeVarint = _EncodeVarint
+ local_len = len
+ assert not is_packed
+ if is_repeated:
+ def EncodeRepeatedField(write, value):
+ for element in value:
+ write(tag)
+ local_EncodeVarint(write, local_len(element))
+ write(element)
+ return EncodeRepeatedField
+ else:
+ def EncodeField(write, value):
+ write(tag)
+ local_EncodeVarint(write, local_len(value))
+ return write(value)
+ return EncodeField
+
+
+def GroupEncoder(field_number, is_repeated, is_packed):
+ """Returns an encoder for a group field."""
+
+ start_tag = TagBytes(field_number, wire_format.WIRETYPE_START_GROUP)
+ end_tag = TagBytes(field_number, wire_format.WIRETYPE_END_GROUP)
+ assert not is_packed
+ if is_repeated:
+ def EncodeRepeatedField(write, value):
+ for element in value:
+ write(start_tag)
+ element._InternalSerialize(write)
+ write(end_tag)
+ return EncodeRepeatedField
+ else:
+ def EncodeField(write, value):
+ write(start_tag)
+ value._InternalSerialize(write)
+ return write(end_tag)
+ return EncodeField
+
+
+def MessageEncoder(field_number, is_repeated, is_packed):
+ """Returns an encoder for a message field."""
+
+ tag = TagBytes(field_number, wire_format.WIRETYPE_LENGTH_DELIMITED)
+ local_EncodeVarint = _EncodeVarint
+ assert not is_packed
+ if is_repeated:
+ def EncodeRepeatedField(write, value):
+ for element in value:
+ write(tag)
+ local_EncodeVarint(write, element.ByteSize())
+ element._InternalSerialize(write)
+ return EncodeRepeatedField
+ else:
+ def EncodeField(write, value):
+ write(tag)
+ local_EncodeVarint(write, value.ByteSize())
+ return value._InternalSerialize(write)
+ return EncodeField
+
+
+# --------------------------------------------------------------------
+# As before, MessageSet is special.
+
+
+def MessageSetItemEncoder(field_number):
+ """Encoder for extensions of MessageSet.
+
+ The message set message looks like this:
+ message MessageSet {
+ repeated group Item = 1 {
+ required int32 type_id = 2;
+ required string message = 3;
}
- """
- self.AppendTag(1, wire_format.WIRETYPE_START_GROUP)
- self.AppendInt32(2, field_number)
- self.AppendMessage(3, msg)
- self.AppendTag(1, wire_format.WIRETYPE_END_GROUP)
-
- def AppendTag(self, field_number, wire_type):
- """Appends a tag containing field number and wire type information."""
- self._stream.AppendVarUInt32(wire_format.PackTag(field_number, wire_type))
+ }
+ """
+ start_bytes = "".join([
+ TagBytes(1, wire_format.WIRETYPE_START_GROUP),
+ TagBytes(2, wire_format.WIRETYPE_VARINT),
+ _VarintBytes(field_number),
+ TagBytes(3, wire_format.WIRETYPE_LENGTH_DELIMITED)])
+ end_bytes = TagBytes(1, wire_format.WIRETYPE_END_GROUP)
+ local_EncodeVarint = _EncodeVarint
+
+ def EncodeField(write, value):
+ write(start_bytes)
+ local_EncodeVarint(write, value.ByteSize())
+ value._InternalSerialize(write)
+ return write(end_bytes)
+
+ return EncodeField
diff --git a/python/google/protobuf/internal/encoder_test.py b/python/google/protobuf/internal/encoder_test.py
deleted file mode 100755
index bf75ea80..00000000
--- a/python/google/protobuf/internal/encoder_test.py
+++ /dev/null
@@ -1,286 +0,0 @@
-#! /usr/bin/python
-#
-# Protocol Buffers - Google's data interchange format
-# Copyright 2008 Google Inc. All rights reserved.
-# http://code.google.com/p/protobuf/
-#
-# Redistribution and use in source and binary forms, with or without
-# modification, are permitted provided that the following conditions are
-# met:
-#
-# * Redistributions of source code must retain the above copyright
-# notice, this list of conditions and the following disclaimer.
-# * Redistributions in binary form must reproduce the above
-# copyright notice, this list of conditions and the following disclaimer
-# in the documentation and/or other materials provided with the
-# distribution.
-# * Neither the name of Google Inc. nor the names of its
-# contributors may be used to endorse or promote products derived from
-# this software without specific prior written permission.
-#
-# THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS
-# "AS IS" AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT
-# LIMITED TO, THE IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR
-# A PARTICULAR PURPOSE ARE DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT
-# OWNER OR CONTRIBUTORS BE LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL,
-# SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT
-# LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE,
-# DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY
-# THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT
-# (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
-# OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
-
-"""Test for google.protobuf.internal.encoder."""
-
-__author__ = 'robinson@google.com (Will Robinson)'
-
-import struct
-import logging
-import unittest
-from google.protobuf.internal import wire_format
-from google.protobuf.internal import encoder
-from google.protobuf.internal import output_stream
-from google.protobuf import message
-import mox
-
-
-class EncoderTest(unittest.TestCase):
-
- def setUp(self):
- self.mox = mox.Mox()
- self.encoder = encoder.Encoder()
- self.mock_stream = self.mox.CreateMock(output_stream.OutputStream)
- self.mock_message = self.mox.CreateMock(message.Message)
- self.encoder._stream = self.mock_stream
-
- def PackTag(self, field_number, wire_type):
- return wire_format.PackTag(field_number, wire_type)
-
- def AppendScalarTestHelper(self, test_name, encoder_method,
- expected_stream_method_name,
- wire_type, field_value,
- expected_value=None, expected_length=None,
- is_tag_test=True):
- """Helper for testAppendScalars.
-
- Calls one of the Encoder methods, and ensures that the Encoder
- in turn makes the expected calls into its OutputStream.
-
- Args:
- test_name: Name of this test, used only for logging.
- encoder_method: Callable on self.encoder. This is the Encoder
- method we're testing. If is_tag_test=True, the encoder method
- accepts a field_number and field_value. if is_tag_test=False,
- the encoder method accepts a field_value.
- expected_stream_method_name: (string) Name of the OutputStream
- method we expect Encoder to call to actually put the value
- on the wire.
- wire_type: The WIRETYPE_* constant we expect encoder to
- use in the specified encoder_method.
- field_value: The value we're trying to encode. Passed
- into encoder_method.
- expected_value: The value we expect Encoder to pass into
- the OutputStream method. If None, we expect field_value
- to pass through unmodified.
- expected_length: The length we expect Encoder to pass to the
- AppendVarUInt32 method. If None we expect the length of the
- field_value.
- is_tag_test: A Boolean. If True (the default), we append the
- the packed field number and wire_type to the stream before
- the field value.
- """
- if expected_value is None:
- expected_value = field_value
-
- logging.info('Testing %s scalar output.\n'
- 'Calling %r(%r), and expecting that to call the '
- 'stream method %s(%r).' % (
- test_name, encoder_method, field_value,
- expected_stream_method_name, expected_value))
-
- if is_tag_test:
- field_number = 10
- # Should first append the field number and type information.
- self.mock_stream.AppendVarUInt32(self.PackTag(field_number, wire_type))
- # If we're length-delimited, we should then append the length.
- if wire_type == wire_format.WIRETYPE_LENGTH_DELIMITED:
- if expected_length is None:
- expected_length = len(field_value)
- self.mock_stream.AppendVarUInt32(expected_length)
-
- # Should then append the value itself.
- # We have to use names instead of methods to work around some
- # mox weirdness. (ResetAll() is overzealous).
- expected_stream_method = getattr(self.mock_stream,
- expected_stream_method_name)
- expected_stream_method(expected_value)
-
- self.mox.ReplayAll()
- if is_tag_test:
- encoder_method(field_number, field_value)
- else:
- encoder_method(field_value)
- self.mox.VerifyAll()
- self.mox.ResetAll()
-
- VAL = 1.125 # Perfectly representable as a float (no rounding error).
- LITTLE_FLOAT_VAL = '\x00\x00\x90?'
- LITTLE_DOUBLE_VAL = '\x00\x00\x00\x00\x00\x00\xf2?'
-
- def testAppendScalars(self):
- utf8_bytes = '\xd0\xa2\xd0\xb5\xd1\x81\xd1\x82'
- utf8_string = unicode(utf8_bytes, 'utf-8')
- scalar_tests = [
- ['int32', self.encoder.AppendInt32, 'AppendVarint32',
- wire_format.WIRETYPE_VARINT, 0],
- ['int64', self.encoder.AppendInt64, 'AppendVarint64',
- wire_format.WIRETYPE_VARINT, 0],
- ['uint32', self.encoder.AppendUInt32, 'AppendVarUInt32',
- wire_format.WIRETYPE_VARINT, 0],
- ['uint64', self.encoder.AppendUInt64, 'AppendVarUInt64',
- wire_format.WIRETYPE_VARINT, 0],
- ['fixed32', self.encoder.AppendFixed32, 'AppendLittleEndian32',
- wire_format.WIRETYPE_FIXED32, 0],
- ['fixed64', self.encoder.AppendFixed64, 'AppendLittleEndian64',
- wire_format.WIRETYPE_FIXED64, 0],
- ['sfixed32', self.encoder.AppendSFixed32, 'AppendLittleEndian32',
- wire_format.WIRETYPE_FIXED32, -1, 0xffffffff],
- ['sfixed64', self.encoder.AppendSFixed64, 'AppendLittleEndian64',
- wire_format.WIRETYPE_FIXED64, -1, 0xffffffffffffffff],
- ['float', self.encoder.AppendFloat, 'AppendRawBytes',
- wire_format.WIRETYPE_FIXED32, self.VAL, self.LITTLE_FLOAT_VAL],
- ['double', self.encoder.AppendDouble, 'AppendRawBytes',
- wire_format.WIRETYPE_FIXED64, self.VAL, self.LITTLE_DOUBLE_VAL],
- ['bool', self.encoder.AppendBool, 'AppendVarint32',
- wire_format.WIRETYPE_VARINT, False],
- ['enum', self.encoder.AppendEnum, 'AppendVarint32',
- wire_format.WIRETYPE_VARINT, 0],
- ['string', self.encoder.AppendString, 'AppendRawBytes',
- wire_format.WIRETYPE_LENGTH_DELIMITED,
- "You're in a maze of twisty little passages, all alike."],
- ['utf8-string', self.encoder.AppendString, 'AppendRawBytes',
- wire_format.WIRETYPE_LENGTH_DELIMITED, utf8_string,
- utf8_bytes, len(utf8_bytes)],
- # We test zigzag encoding routines more extensively below.
- ['sint32', self.encoder.AppendSInt32, 'AppendVarUInt32',
- wire_format.WIRETYPE_VARINT, -1, 1],
- ['sint64', self.encoder.AppendSInt64, 'AppendVarUInt64',
- wire_format.WIRETYPE_VARINT, -1, 1],
- ]
- # Ensure that we're testing different Encoder methods and using
- # different test names in all test cases above.
- self.assertEqual(len(scalar_tests), len(set(t[0] for t in scalar_tests)))
- self.assert_(len(scalar_tests) >= len(set(t[1] for t in scalar_tests)))
- for args in scalar_tests:
- self.AppendScalarTestHelper(*args)
-
- def testAppendScalarsWithoutTags(self):
- scalar_no_tag_tests = [
- ['int32', self.encoder.AppendInt32NoTag, 'AppendVarint32', None, 0],
- ['int64', self.encoder.AppendInt64NoTag, 'AppendVarint64', None, 0],
- ['uint32', self.encoder.AppendUInt32NoTag, 'AppendVarUInt32', None, 0],
- ['uint64', self.encoder.AppendUInt64NoTag, 'AppendVarUInt64', None, 0],
- ['fixed32', self.encoder.AppendFixed32NoTag,
- 'AppendLittleEndian32', None, 0],
- ['fixed64', self.encoder.AppendFixed64NoTag,
- 'AppendLittleEndian64', None, 0],
- ['sfixed32', self.encoder.AppendSFixed32NoTag,
- 'AppendLittleEndian32', None, 0],
- ['sfixed64', self.encoder.AppendSFixed64NoTag,
- 'AppendLittleEndian64', None, 0],
- ['float', self.encoder.AppendFloatNoTag,
- 'AppendRawBytes', None, self.VAL, self.LITTLE_FLOAT_VAL],
- ['double', self.encoder.AppendDoubleNoTag,
- 'AppendRawBytes', None, self.VAL, self.LITTLE_DOUBLE_VAL],
- ['bool', self.encoder.AppendBoolNoTag, 'AppendVarint32', None, 0],
- ['enum', self.encoder.AppendEnumNoTag, 'AppendVarint32', None, 0],
- ['sint32', self.encoder.AppendSInt32NoTag,
- 'AppendVarUInt32', None, -1, 1],
- ['sint64', self.encoder.AppendSInt64NoTag,
- 'AppendVarUInt64', None, -1, 1],
- ]
-
- self.assertEqual(len(scalar_no_tag_tests),
- len(set(t[0] for t in scalar_no_tag_tests)))
- self.assert_(len(scalar_no_tag_tests) >=
- len(set(t[1] for t in scalar_no_tag_tests)))
- for args in scalar_no_tag_tests:
- # For no tag tests, the wire_type is not used, so we put in None.
- self.AppendScalarTestHelper(is_tag_test=False, *args)
-
- def testAppendGroup(self):
- field_number = 23
- # Should first append the start-group marker.
- self.mock_stream.AppendVarUInt32(
- self.PackTag(field_number, wire_format.WIRETYPE_START_GROUP))
- # Should then serialize itself.
- self.mock_message.SerializeToString().AndReturn('foo')
- self.mock_stream.AppendRawBytes('foo')
- # Should finally append the end-group marker.
- self.mock_stream.AppendVarUInt32(
- self.PackTag(field_number, wire_format.WIRETYPE_END_GROUP))
-
- self.mox.ReplayAll()
- self.encoder.AppendGroup(field_number, self.mock_message)
- self.mox.VerifyAll()
-
- def testAppendMessage(self):
- field_number = 23
- byte_size = 42
- # Should first append the field number and type information.
- self.mock_stream.AppendVarUInt32(
- self.PackTag(field_number, wire_format.WIRETYPE_LENGTH_DELIMITED))
- # Should then append its length.
- self.mock_message.ByteSize().AndReturn(byte_size)
- self.mock_stream.AppendVarUInt32(byte_size)
- # Should then serialize itself to the encoder.
- self.mock_message.SerializeToString().AndReturn('foo')
- self.mock_stream.AppendRawBytes('foo')
-
- self.mox.ReplayAll()
- self.encoder.AppendMessage(field_number, self.mock_message)
- self.mox.VerifyAll()
-
- def testAppendMessageSetItem(self):
- field_number = 23
- byte_size = 42
- # Should first append the field number and type information.
- self.mock_stream.AppendVarUInt32(
- self.PackTag(1, wire_format.WIRETYPE_START_GROUP))
- self.mock_stream.AppendVarUInt32(
- self.PackTag(2, wire_format.WIRETYPE_VARINT))
- self.mock_stream.AppendVarint32(field_number)
- self.mock_stream.AppendVarUInt32(
- self.PackTag(3, wire_format.WIRETYPE_LENGTH_DELIMITED))
- # Should then append its length.
- self.mock_message.ByteSize().AndReturn(byte_size)
- self.mock_stream.AppendVarUInt32(byte_size)
- # Should then serialize itself to the encoder.
- self.mock_message.SerializeToString().AndReturn('foo')
- self.mock_stream.AppendRawBytes('foo')
- self.mock_stream.AppendVarUInt32(
- self.PackTag(1, wire_format.WIRETYPE_END_GROUP))
-
- self.mox.ReplayAll()
- self.encoder.AppendMessageSetItem(field_number, self.mock_message)
- self.mox.VerifyAll()
-
- def testAppendSFixed(self):
- # Most of our bounds-checking is done in output_stream.py,
- # but encoder.py is responsible for transforming signed
- # fixed-width integers into unsigned ones, so we test here
- # to ensure that we're not losing any entropy when we do
- # that conversion.
- field_number = 10
- self.assertRaises(message.EncodeError, self.encoder.AppendSFixed32,
- 10, wire_format.UINT32_MAX + 1)
- self.assertRaises(message.EncodeError, self.encoder.AppendSFixed32,
- 10, -(1 << 32))
- self.assertRaises(message.EncodeError, self.encoder.AppendSFixed64,
- 10, wire_format.UINT64_MAX + 1)
- self.assertRaises(message.EncodeError, self.encoder.AppendSFixed64,
- 10, -(1 << 64))
-
-
-if __name__ == '__main__':
- unittest.main()
diff --git a/python/google/protobuf/internal/generator_test.py b/python/google/protobuf/internal/generator_test.py
index 11fcfa0c..dd27c9a3 100755
--- a/python/google/protobuf/internal/generator_test.py
+++ b/python/google/protobuf/internal/generator_test.py
@@ -35,15 +35,20 @@
# indirect testing of the protocol compiler output.
"""Unittest that directly tests the output of the pure-Python protocol
-compiler. See //net/proto2/internal/reflection_test.py for a test which
+compiler. See //google/protobuf/reflection_test.py for a test which
further ensures that we can use Python protocol message objects as we expect.
"""
__author__ = 'robinson@google.com (Will Robinson)'
import unittest
+from google.protobuf import unittest_import_pb2
from google.protobuf import unittest_mset_pb2
from google.protobuf import unittest_pb2
+from google.protobuf import unittest_no_generic_services_pb2
+
+
+MAX_EXTENSION = 536870912
class GeneratorTest(unittest.TestCase):
@@ -71,6 +76,31 @@ class GeneratorTest(unittest.TestCase):
self.assertEqual(3, proto.BAZ)
self.assertEqual(3, unittest_pb2.TestAllTypes.BAZ)
+ def testExtremeDefaultValues(self):
+ message = unittest_pb2.TestExtremeDefaultValues()
+ self.assertEquals(float('inf'), message.inf_double)
+ self.assertEquals(float('-inf'), message.neg_inf_double)
+ self.assert_(message.nan_double != message.nan_double)
+ self.assertEquals(float('inf'), message.inf_float)
+ self.assertEquals(float('-inf'), message.neg_inf_float)
+ self.assert_(message.nan_float != message.nan_float)
+
+ def testHasDefaultValues(self):
+ desc = unittest_pb2.TestAllTypes.DESCRIPTOR
+
+ expected_has_default_by_name = {
+ 'optional_int32': False,
+ 'repeated_int32': False,
+ 'optional_nested_message': False,
+ 'default_int32': True,
+ }
+
+ has_default_by_name = dict(
+ [(f.name, f.has_default_value)
+ for f in desc.fields
+ if f.name in expected_has_default_by_name])
+ self.assertEqual(expected_has_default_by_name, has_default_by_name)
+
def testContainingTypeBehaviorForExtensions(self):
self.assertEqual(unittest_pb2.optional_int32_extension.containing_type,
unittest_pb2.TestAllExtensions.DESCRIPTOR)
@@ -95,6 +125,81 @@ class GeneratorTest(unittest.TestCase):
proto = unittest_mset_pb2.TestMessageSet()
self.assertTrue(proto.DESCRIPTOR.GetOptions().message_set_wire_format)
+ def testNestedTypes(self):
+ self.assertEquals(
+ set(unittest_pb2.TestAllTypes.DESCRIPTOR.nested_types),
+ set([
+ unittest_pb2.TestAllTypes.NestedMessage.DESCRIPTOR,
+ unittest_pb2.TestAllTypes.OptionalGroup.DESCRIPTOR,
+ unittest_pb2.TestAllTypes.RepeatedGroup.DESCRIPTOR,
+ ]))
+ self.assertEqual(unittest_pb2.TestEmptyMessage.DESCRIPTOR.nested_types, [])
+ self.assertEqual(
+ unittest_pb2.TestAllTypes.NestedMessage.DESCRIPTOR.nested_types, [])
+
+ def testContainingType(self):
+ self.assertTrue(
+ unittest_pb2.TestEmptyMessage.DESCRIPTOR.containing_type is None)
+ self.assertTrue(
+ unittest_pb2.TestAllTypes.DESCRIPTOR.containing_type is None)
+ self.assertEqual(
+ unittest_pb2.TestAllTypes.NestedMessage.DESCRIPTOR.containing_type,
+ unittest_pb2.TestAllTypes.DESCRIPTOR)
+ self.assertEqual(
+ unittest_pb2.TestAllTypes.NestedMessage.DESCRIPTOR.containing_type,
+ unittest_pb2.TestAllTypes.DESCRIPTOR)
+ self.assertEqual(
+ unittest_pb2.TestAllTypes.RepeatedGroup.DESCRIPTOR.containing_type,
+ unittest_pb2.TestAllTypes.DESCRIPTOR)
+
+ def testContainingTypeInEnumDescriptor(self):
+ self.assertTrue(unittest_pb2._FOREIGNENUM.containing_type is None)
+ self.assertEqual(unittest_pb2._TESTALLTYPES_NESTEDENUM.containing_type,
+ unittest_pb2.TestAllTypes.DESCRIPTOR)
+
+ def testPackage(self):
+ self.assertEqual(
+ unittest_pb2.TestAllTypes.DESCRIPTOR.file.package,
+ 'protobuf_unittest')
+ desc = unittest_pb2.TestAllTypes.NestedMessage.DESCRIPTOR
+ self.assertEqual(desc.file.package, 'protobuf_unittest')
+ self.assertEqual(
+ unittest_import_pb2.ImportMessage.DESCRIPTOR.file.package,
+ 'protobuf_unittest_import')
+
+ self.assertEqual(
+ unittest_pb2._FOREIGNENUM.file.package, 'protobuf_unittest')
+ self.assertEqual(
+ unittest_pb2._TESTALLTYPES_NESTEDENUM.file.package,
+ 'protobuf_unittest')
+ self.assertEqual(
+ unittest_import_pb2._IMPORTENUM.file.package,
+ 'protobuf_unittest_import')
+
+ def testExtensionRange(self):
+ self.assertEqual(
+ unittest_pb2.TestAllTypes.DESCRIPTOR.extension_ranges, [])
+ self.assertEqual(
+ unittest_pb2.TestAllExtensions.DESCRIPTOR.extension_ranges,
+ [(1, MAX_EXTENSION)])
+ self.assertEqual(
+ unittest_pb2.TestMultipleExtensionRanges.DESCRIPTOR.extension_ranges,
+ [(42, 43), (4143, 4244), (65536, MAX_EXTENSION)])
+
+ def testFileDescriptor(self):
+ self.assertEqual(unittest_pb2.DESCRIPTOR.name,
+ 'google/protobuf/unittest.proto')
+ self.assertEqual(unittest_pb2.DESCRIPTOR.package, 'protobuf_unittest')
+ self.assertFalse(unittest_pb2.DESCRIPTOR.serialized_pb is None)
+
+ def testNoGenericServices(self):
+ # unittest_no_generic_services.proto should contain defs for everything
+ # except services.
+ self.assertTrue(hasattr(unittest_no_generic_services_pb2, "TestMessage"))
+ self.assertTrue(hasattr(unittest_no_generic_services_pb2, "FOO"))
+ self.assertTrue(hasattr(unittest_no_generic_services_pb2, "test_extension"))
+ self.assertFalse(hasattr(unittest_no_generic_services_pb2, "TestService"))
+
if __name__ == '__main__':
unittest.main()
diff --git a/python/google/protobuf/internal/input_stream.py b/python/google/protobuf/internal/input_stream.py
deleted file mode 100755
index 7bda17e3..00000000
--- a/python/google/protobuf/internal/input_stream.py
+++ /dev/null
@@ -1,338 +0,0 @@
-# Protocol Buffers - Google's data interchange format
-# Copyright 2008 Google Inc. All rights reserved.
-# http://code.google.com/p/protobuf/
-#
-# Redistribution and use in source and binary forms, with or without
-# modification, are permitted provided that the following conditions are
-# met:
-#
-# * Redistributions of source code must retain the above copyright
-# notice, this list of conditions and the following disclaimer.
-# * Redistributions in binary form must reproduce the above
-# copyright notice, this list of conditions and the following disclaimer
-# in the documentation and/or other materials provided with the
-# distribution.
-# * Neither the name of Google Inc. nor the names of its
-# contributors may be used to endorse or promote products derived from
-# this software without specific prior written permission.
-#
-# THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS
-# "AS IS" AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT
-# LIMITED TO, THE IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR
-# A PARTICULAR PURPOSE ARE DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT
-# OWNER OR CONTRIBUTORS BE LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL,
-# SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT
-# LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE,
-# DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY
-# THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT
-# (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
-# OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
-
-"""InputStream is the primitive interface for reading bits from the wire.
-
-All protocol buffer deserialization can be expressed in terms of
-the InputStream primitives provided here.
-"""
-
-__author__ = 'robinson@google.com (Will Robinson)'
-
-import array
-import struct
-from google.protobuf import message
-from google.protobuf.internal import wire_format
-
-
-# Note that much of this code is ported from //net/proto/ProtocolBuffer, and
-# that the interface is strongly inspired by CodedInputStream from the C++
-# proto2 implementation.
-
-
-class InputStreamBuffer(object):
-
- """Contains all logic for reading bits, and dealing with stream position.
-
- If an InputStream method ever raises an exception, the stream is left
- in an indeterminate state and is not safe for further use.
- """
-
- def __init__(self, s):
- # What we really want is something like array('B', s), where elements we
- # read from the array are already given to us as one-byte integers. BUT
- # using array() instead of buffer() would force full string copies to result
- # from each GetSubBuffer() call.
- #
- # So, if the N serialized bytes of a single protocol buffer object are
- # split evenly between 2 child messages, and so on recursively, using
- # array('B', s) instead of buffer() would incur an additional N*logN bytes
- # copied during deserialization.
- #
- # The higher constant overhead of having to ord() for every byte we read
- # from the buffer in _ReadVarintHelper() could definitely lead to worse
- # performance in many real-world scenarios, even if the asymptotic
- # complexity is better. However, our real answer is that the mythical
- # Python/C extension module output mode for the protocol compiler will
- # be blazing-fast and will eliminate most use of this class anyway.
- self._buffer = buffer(s)
- self._pos = 0
-
- def EndOfStream(self):
- """Returns true iff we're at the end of the stream.
- If this returns true, then a call to any other InputStream method
- will raise an exception.
- """
- return self._pos >= len(self._buffer)
-
- def Position(self):
- """Returns the current position in the stream, or equivalently, the
- number of bytes read so far.
- """
- return self._pos
-
- def GetSubBuffer(self, size=None):
- """Returns a sequence-like object that represents a portion of our
- underlying sequence.
-
- Position 0 in the returned object corresponds to self.Position()
- in this stream.
-
- If size is specified, then the returned object ends after the
- next "size" bytes in this stream. If size is not specified,
- then the returned object ends at the end of this stream.
-
- We guarantee that the returned object R supports the Python buffer
- interface (and thus that the call buffer(R) will work).
-
- Note that the returned buffer is read-only.
-
- The intended use for this method is for nested-message and nested-group
- deserialization, where we want to make a recursive MergeFromString()
- call on the portion of the original sequence that contains the serialized
- nested message. (And we'd like to do so without making unnecessary string
- copies).
-
- REQUIRES: size is nonnegative.
- """
- # Note that buffer() doesn't perform any actual string copy.
- if size is None:
- return buffer(self._buffer, self._pos)
- else:
- if size < 0:
- raise message.DecodeError('Negative size %d' % size)
- return buffer(self._buffer, self._pos, size)
-
- def SkipBytes(self, num_bytes):
- """Skip num_bytes bytes ahead, or go to the end of the stream, whichever
- comes first.
-
- REQUIRES: num_bytes is nonnegative.
- """
- if num_bytes < 0:
- raise message.DecodeError('Negative num_bytes %d' % num_bytes)
- self._pos += num_bytes
- self._pos = min(self._pos, len(self._buffer))
-
- def ReadBytes(self, size):
- """Reads up to 'size' bytes from the stream, stopping early
- only if we reach the end of the stream. Returns the bytes read
- as a string.
- """
- if size < 0:
- raise message.DecodeError('Negative size %d' % size)
- s = (self._buffer[self._pos : self._pos + size])
- self._pos += len(s) # Only advance by the number of bytes actually read.
- return s
-
- def ReadLittleEndian32(self):
- """Interprets the next 4 bytes of the stream as a little-endian
- encoded, unsiged 32-bit integer, and returns that integer.
- """
- try:
- i = struct.unpack(wire_format.FORMAT_UINT32_LITTLE_ENDIAN,
- self._buffer[self._pos : self._pos + 4])
- self._pos += 4
- return i[0] # unpack() result is a 1-element tuple.
- except struct.error, e:
- raise message.DecodeError(e)
-
- def ReadLittleEndian64(self):
- """Interprets the next 8 bytes of the stream as a little-endian
- encoded, unsiged 64-bit integer, and returns that integer.
- """
- try:
- i = struct.unpack(wire_format.FORMAT_UINT64_LITTLE_ENDIAN,
- self._buffer[self._pos : self._pos + 8])
- self._pos += 8
- return i[0] # unpack() result is a 1-element tuple.
- except struct.error, e:
- raise message.DecodeError(e)
-
- def ReadVarint32(self):
- """Reads a varint from the stream, interprets this varint
- as a signed, 32-bit integer, and returns the integer.
- """
- i = self.ReadVarint64()
- if not wire_format.INT32_MIN <= i <= wire_format.INT32_MAX:
- raise message.DecodeError('Value out of range for int32: %d' % i)
- return int(i)
-
- def ReadVarUInt32(self):
- """Reads a varint from the stream, interprets this varint
- as an unsigned, 32-bit integer, and returns the integer.
- """
- i = self.ReadVarUInt64()
- if i > wire_format.UINT32_MAX:
- raise message.DecodeError('Value out of range for uint32: %d' % i)
- return i
-
- def ReadVarint64(self):
- """Reads a varint from the stream, interprets this varint
- as a signed, 64-bit integer, and returns the integer.
- """
- i = self.ReadVarUInt64()
- if i > wire_format.INT64_MAX:
- i -= (1 << 64)
- return i
-
- def ReadVarUInt64(self):
- """Reads a varint from the stream, interprets this varint
- as an unsigned, 64-bit integer, and returns the integer.
- """
- i = self._ReadVarintHelper()
- if not 0 <= i <= wire_format.UINT64_MAX:
- raise message.DecodeError('Value out of range for uint64: %d' % i)
- return i
-
- def _ReadVarintHelper(self):
- """Helper for the various varint-reading methods above.
- Reads an unsigned, varint-encoded integer from the stream and
- returns this integer.
-
- Does no bounds checking except to ensure that we read at most as many bytes
- as could possibly be present in a varint-encoded 64-bit number.
- """
- result = 0
- shift = 0
- while 1:
- if shift >= 64:
- raise message.DecodeError('Too many bytes when decoding varint.')
- try:
- b = ord(self._buffer[self._pos])
- except IndexError:
- raise message.DecodeError('Truncated varint.')
- self._pos += 1
- result |= ((b & 0x7f) << shift)
- shift += 7
- if not (b & 0x80):
- return result
-
-
-class InputStreamArray(object):
-
- """Contains all logic for reading bits, and dealing with stream position.
-
- If an InputStream method ever raises an exception, the stream is left
- in an indeterminate state and is not safe for further use.
-
- This alternative to InputStreamBuffer is used in environments where buffer()
- is unavailble, such as Google App Engine.
- """
-
- def __init__(self, s):
- self._buffer = array.array('B', s)
- self._pos = 0
-
- def EndOfStream(self):
- return self._pos >= len(self._buffer)
-
- def Position(self):
- return self._pos
-
- def GetSubBuffer(self, size=None):
- if size is None:
- return self._buffer[self._pos : ].tostring()
- else:
- if size < 0:
- raise message.DecodeError('Negative size %d' % size)
- return self._buffer[self._pos : self._pos + size].tostring()
-
- def SkipBytes(self, num_bytes):
- if num_bytes < 0:
- raise message.DecodeError('Negative num_bytes %d' % num_bytes)
- self._pos += num_bytes
- self._pos = min(self._pos, len(self._buffer))
-
- def ReadBytes(self, size):
- if size < 0:
- raise message.DecodeError('Negative size %d' % size)
- s = self._buffer[self._pos : self._pos + size].tostring()
- self._pos += len(s) # Only advance by the number of bytes actually read.
- return s
-
- def ReadLittleEndian32(self):
- try:
- i = struct.unpack(wire_format.FORMAT_UINT32_LITTLE_ENDIAN,
- self._buffer[self._pos : self._pos + 4])
- self._pos += 4
- return i[0] # unpack() result is a 1-element tuple.
- except struct.error, e:
- raise message.DecodeError(e)
-
- def ReadLittleEndian64(self):
- try:
- i = struct.unpack(wire_format.FORMAT_UINT64_LITTLE_ENDIAN,
- self._buffer[self._pos : self._pos + 8])
- self._pos += 8
- return i[0] # unpack() result is a 1-element tuple.
- except struct.error, e:
- raise message.DecodeError(e)
-
- def ReadVarint32(self):
- i = self.ReadVarint64()
- if not wire_format.INT32_MIN <= i <= wire_format.INT32_MAX:
- raise message.DecodeError('Value out of range for int32: %d' % i)
- return int(i)
-
- def ReadVarUInt32(self):
- i = self.ReadVarUInt64()
- if i > wire_format.UINT32_MAX:
- raise message.DecodeError('Value out of range for uint32: %d' % i)
- return i
-
- def ReadVarint64(self):
- i = self.ReadVarUInt64()
- if i > wire_format.INT64_MAX:
- i -= (1 << 64)
- return i
-
- def ReadVarUInt64(self):
- i = self._ReadVarintHelper()
- if not 0 <= i <= wire_format.UINT64_MAX:
- raise message.DecodeError('Value out of range for uint64: %d' % i)
- return i
-
- def _ReadVarintHelper(self):
- result = 0
- shift = 0
- while 1:
- if shift >= 64:
- raise message.DecodeError('Too many bytes when decoding varint.')
- try:
- b = self._buffer[self._pos]
- except IndexError:
- raise message.DecodeError('Truncated varint.')
- self._pos += 1
- result |= ((b & 0x7f) << shift)
- shift += 7
- if not (b & 0x80):
- return result
-
-
-try:
- buffer('')
- InputStream = InputStreamBuffer
-except NotImplementedError:
- # Google App Engine: dev_appserver.py
- InputStream = InputStreamArray
-except RuntimeError:
- # Google App Engine: production
- InputStream = InputStreamArray
diff --git a/python/google/protobuf/internal/input_stream_test.py b/python/google/protobuf/internal/input_stream_test.py
deleted file mode 100755
index ecec7f7d..00000000
--- a/python/google/protobuf/internal/input_stream_test.py
+++ /dev/null
@@ -1,314 +0,0 @@
-#! /usr/bin/python
-#
-# Protocol Buffers - Google's data interchange format
-# Copyright 2008 Google Inc. All rights reserved.
-# http://code.google.com/p/protobuf/
-#
-# Redistribution and use in source and binary forms, with or without
-# modification, are permitted provided that the following conditions are
-# met:
-#
-# * Redistributions of source code must retain the above copyright
-# notice, this list of conditions and the following disclaimer.
-# * Redistributions in binary form must reproduce the above
-# copyright notice, this list of conditions and the following disclaimer
-# in the documentation and/or other materials provided with the
-# distribution.
-# * Neither the name of Google Inc. nor the names of its
-# contributors may be used to endorse or promote products derived from
-# this software without specific prior written permission.
-#
-# THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS
-# "AS IS" AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT
-# LIMITED TO, THE IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR
-# A PARTICULAR PURPOSE ARE DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT
-# OWNER OR CONTRIBUTORS BE LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL,
-# SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT
-# LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE,
-# DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY
-# THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT
-# (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
-# OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
-
-"""Test for google.protobuf.internal.input_stream."""
-
-__author__ = 'robinson@google.com (Will Robinson)'
-
-import unittest
-from google.protobuf import message
-from google.protobuf.internal import wire_format
-from google.protobuf.internal import input_stream
-
-
-class InputStreamBufferTest(unittest.TestCase):
-
- def setUp(self):
- self.__original_input_stream = input_stream.InputStream
- input_stream.InputStream = input_stream.InputStreamBuffer
-
- def tearDown(self):
- input_stream.InputStream = self.__original_input_stream
-
- def testEndOfStream(self):
- stream = input_stream.InputStream('abcd')
- self.assertFalse(stream.EndOfStream())
- self.assertEqual('abcd', stream.ReadBytes(10))
- self.assertTrue(stream.EndOfStream())
-
- def testPosition(self):
- stream = input_stream.InputStream('abcd')
- self.assertEqual(0, stream.Position())
- self.assertEqual(0, stream.Position()) # No side-effects.
- stream.ReadBytes(1)
- self.assertEqual(1, stream.Position())
- stream.ReadBytes(1)
- self.assertEqual(2, stream.Position())
- stream.ReadBytes(10)
- self.assertEqual(4, stream.Position()) # Can't go past end of stream.
-
- def testGetSubBuffer(self):
- stream = input_stream.InputStream('abcd')
- # Try leaving out the size.
- self.assertEqual('abcd', str(stream.GetSubBuffer()))
- stream.SkipBytes(1)
- # GetSubBuffer() always starts at current size.
- self.assertEqual('bcd', str(stream.GetSubBuffer()))
- # Try 0-size.
- self.assertEqual('', str(stream.GetSubBuffer(0)))
- # Negative sizes should raise an error.
- self.assertRaises(message.DecodeError, stream.GetSubBuffer, -1)
- # Positive sizes should work as expected.
- self.assertEqual('b', str(stream.GetSubBuffer(1)))
- self.assertEqual('bc', str(stream.GetSubBuffer(2)))
- # Sizes longer than remaining bytes in the buffer should
- # return the whole remaining buffer.
- self.assertEqual('bcd', str(stream.GetSubBuffer(1000)))
-
- def testSkipBytes(self):
- stream = input_stream.InputStream('')
- # Skipping bytes when at the end of stream
- # should have no effect.
- stream.SkipBytes(0)
- stream.SkipBytes(1)
- stream.SkipBytes(2)
- self.assertTrue(stream.EndOfStream())
- self.assertEqual(0, stream.Position())
-
- # Try skipping within a stream.
- stream = input_stream.InputStream('abcd')
- self.assertEqual(0, stream.Position())
- stream.SkipBytes(1)
- self.assertEqual(1, stream.Position())
- stream.SkipBytes(10) # Can't skip past the end.
- self.assertEqual(4, stream.Position())
-
- # Ensure that a negative skip raises an exception.
- stream = input_stream.InputStream('abcd')
- stream.SkipBytes(1)
- self.assertRaises(message.DecodeError, stream.SkipBytes, -1)
-
- def testReadBytes(self):
- s = 'abcd'
- # Also test going past the total stream length.
- for i in range(len(s) + 10):
- stream = input_stream.InputStream(s)
- self.assertEqual(s[:i], stream.ReadBytes(i))
- self.assertEqual(min(i, len(s)), stream.Position())
- stream = input_stream.InputStream(s)
- self.assertRaises(message.DecodeError, stream.ReadBytes, -1)
-
- def EnsureFailureOnEmptyStream(self, input_stream_method):
- """Helper for integer-parsing tests below.
- Ensures that the given InputStream method raises a DecodeError
- if called on a stream with no bytes remaining.
- """
- stream = input_stream.InputStream('')
- self.assertRaises(message.DecodeError, input_stream_method, stream)
-
- def testReadLittleEndian32(self):
- self.EnsureFailureOnEmptyStream(input_stream.InputStream.ReadLittleEndian32)
- s = ''
- # Read 0.
- s += '\x00\x00\x00\x00'
- # Read 1.
- s += '\x01\x00\x00\x00'
- # Read a bunch of different bytes.
- s += '\x01\x02\x03\x04'
- # Read max unsigned 32-bit int.
- s += '\xff\xff\xff\xff'
- # Try a read with fewer than 4 bytes left in the stream.
- s += '\x00\x00\x00'
- stream = input_stream.InputStream(s)
- self.assertEqual(0, stream.ReadLittleEndian32())
- self.assertEqual(4, stream.Position())
- self.assertEqual(1, stream.ReadLittleEndian32())
- self.assertEqual(8, stream.Position())
- self.assertEqual(0x04030201, stream.ReadLittleEndian32())
- self.assertEqual(12, stream.Position())
- self.assertEqual(wire_format.UINT32_MAX, stream.ReadLittleEndian32())
- self.assertEqual(16, stream.Position())
- self.assertRaises(message.DecodeError, stream.ReadLittleEndian32)
-
- def testReadLittleEndian64(self):
- self.EnsureFailureOnEmptyStream(input_stream.InputStream.ReadLittleEndian64)
- s = ''
- # Read 0.
- s += '\x00\x00\x00\x00\x00\x00\x00\x00'
- # Read 1.
- s += '\x01\x00\x00\x00\x00\x00\x00\x00'
- # Read a bunch of different bytes.
- s += '\x01\x02\x03\x04\x05\x06\x07\x08'
- # Read max unsigned 64-bit int.
- s += '\xff\xff\xff\xff\xff\xff\xff\xff'
- # Try a read with fewer than 8 bytes left in the stream.
- s += '\x00\x00\x00'
- stream = input_stream.InputStream(s)
- self.assertEqual(0, stream.ReadLittleEndian64())
- self.assertEqual(8, stream.Position())
- self.assertEqual(1, stream.ReadLittleEndian64())
- self.assertEqual(16, stream.Position())
- self.assertEqual(0x0807060504030201, stream.ReadLittleEndian64())
- self.assertEqual(24, stream.Position())
- self.assertEqual(wire_format.UINT64_MAX, stream.ReadLittleEndian64())
- self.assertEqual(32, stream.Position())
- self.assertRaises(message.DecodeError, stream.ReadLittleEndian64)
-
- def ReadVarintSuccessTestHelper(self, varints_and_ints, read_method):
- """Helper for tests below that test successful reads of various varints.
-
- Args:
- varints_and_ints: Iterable of (str, integer) pairs, where the string
- gives the wire encoding and the integer gives the value we expect
- to be returned by the read_method upon encountering this string.
- read_method: Unbound InputStream method that is capable of reading
- the encoded strings provided in the first elements of varints_and_ints.
- """
- s = ''.join(s for s, i in varints_and_ints)
- stream = input_stream.InputStream(s)
- expected_pos = 0
- self.assertEqual(expected_pos, stream.Position())
- for s, expected_int in varints_and_ints:
- self.assertEqual(expected_int, read_method(stream))
- expected_pos += len(s)
- self.assertEqual(expected_pos, stream.Position())
-
- def testReadVarint32Success(self):
- varints_and_ints = [
- ('\x00', 0),
- ('\x01', 1),
- ('\x7f', 127),
- ('\x80\x01', 128),
- ('\xff\xff\xff\xff\xff\xff\xff\xff\xff\x01', -1),
- ('\xff\xff\xff\xff\x07', wire_format.INT32_MAX),
- ('\x80\x80\x80\x80\xf8\xff\xff\xff\xff\x01', wire_format.INT32_MIN),
- ]
- self.ReadVarintSuccessTestHelper(varints_and_ints,
- input_stream.InputStream.ReadVarint32)
-
- def testReadVarint32Failure(self):
- self.EnsureFailureOnEmptyStream(input_stream.InputStream.ReadVarint32)
-
- # Try and fail to read INT32_MAX + 1.
- s = '\x80\x80\x80\x80\x08'
- stream = input_stream.InputStream(s)
- self.assertRaises(message.DecodeError, stream.ReadVarint32)
-
- # Try and fail to read INT32_MIN - 1.
- s = '\xfe\xff\xff\xff\xf7\xff\xff\xff\xff\x01'
- stream = input_stream.InputStream(s)
- self.assertRaises(message.DecodeError, stream.ReadVarint32)
-
- # Try and fail to read something that looks like
- # a varint with more than 10 bytes.
- s = '\xff\xff\xff\xff\xff\xff\xff\xff\xff\xff\x00'
- stream = input_stream.InputStream(s)
- self.assertRaises(message.DecodeError, stream.ReadVarint32)
-
- def testReadVarUInt32Success(self):
- varints_and_ints = [
- ('\x00', 0),
- ('\x01', 1),
- ('\x7f', 127),
- ('\x80\x01', 128),
- ('\xff\xff\xff\xff\x0f', wire_format.UINT32_MAX),
- ]
- self.ReadVarintSuccessTestHelper(varints_and_ints,
- input_stream.InputStream.ReadVarUInt32)
-
- def testReadVarUInt32Failure(self):
- self.EnsureFailureOnEmptyStream(input_stream.InputStream.ReadVarUInt32)
- # Try and fail to read UINT32_MAX + 1
- s = '\x80\x80\x80\x80\x10'
- stream = input_stream.InputStream(s)
- self.assertRaises(message.DecodeError, stream.ReadVarUInt32)
-
- # Try and fail to read something that looks like
- # a varint with more than 10 bytes.
- s = '\xff\xff\xff\xff\xff\xff\xff\xff\xff\xff\x00'
- stream = input_stream.InputStream(s)
- self.assertRaises(message.DecodeError, stream.ReadVarUInt32)
-
- def testReadVarint64Success(self):
- varints_and_ints = [
- ('\x00', 0),
- ('\x01', 1),
- ('\xff\xff\xff\xff\xff\xff\xff\xff\xff\x01', -1),
- ('\x7f', 127),
- ('\x80\x01', 128),
- ('\xff\xff\xff\xff\xff\xff\xff\xff\x7f', wire_format.INT64_MAX),
- ('\x80\x80\x80\x80\x80\x80\x80\x80\x80\x01', wire_format.INT64_MIN),
- ]
- self.ReadVarintSuccessTestHelper(varints_and_ints,
- input_stream.InputStream.ReadVarint64)
-
- def testReadVarint64Failure(self):
- self.EnsureFailureOnEmptyStream(input_stream.InputStream.ReadVarint64)
- # Try and fail to read something with the mythical 64th bit set.
- s = '\x80\x80\x80\x80\x80\x80\x80\x80\x80\x02'
- stream = input_stream.InputStream(s)
- self.assertRaises(message.DecodeError, stream.ReadVarint64)
-
- # Try and fail to read something that looks like
- # a varint with more than 10 bytes.
- s = '\xff\xff\xff\xff\xff\xff\xff\xff\xff\xff\x00'
- stream = input_stream.InputStream(s)
- self.assertRaises(message.DecodeError, stream.ReadVarint64)
-
- def testReadVarUInt64Success(self):
- varints_and_ints = [
- ('\x00', 0),
- ('\x01', 1),
- ('\x7f', 127),
- ('\x80\x01', 128),
- ('\x80\x80\x80\x80\x80\x80\x80\x80\x80\x01', 1 << 63),
- ]
- self.ReadVarintSuccessTestHelper(varints_and_ints,
- input_stream.InputStream.ReadVarUInt64)
-
- def testReadVarUInt64Failure(self):
- self.EnsureFailureOnEmptyStream(input_stream.InputStream.ReadVarUInt64)
- # Try and fail to read something with the mythical 64th bit set.
- s = '\x80\x80\x80\x80\x80\x80\x80\x80\x80\x02'
- stream = input_stream.InputStream(s)
- self.assertRaises(message.DecodeError, stream.ReadVarUInt64)
-
- # Try and fail to read something that looks like
- # a varint with more than 10 bytes.
- s = '\xff\xff\xff\xff\xff\xff\xff\xff\xff\xff\x00'
- stream = input_stream.InputStream(s)
- self.assertRaises(message.DecodeError, stream.ReadVarUInt64)
-
-
-class InputStreamArrayTest(InputStreamBufferTest):
-
- def setUp(self):
- # Test InputStreamArray against the same tests in InputStreamBuffer
- self.__original_input_stream = input_stream.InputStream
- input_stream.InputStream = input_stream.InputStreamArray
-
- def tearDown(self):
- input_stream.InputStream = self.__original_input_stream
-
-
-if __name__ == '__main__':
- unittest.main()
diff --git a/python/google/protobuf/internal/message_listener.py b/python/google/protobuf/internal/message_listener.py
index 43978952..1080234d 100755
--- a/python/google/protobuf/internal/message_listener.py
+++ b/python/google/protobuf/internal/message_listener.py
@@ -39,22 +39,34 @@ __author__ = 'robinson@google.com (Will Robinson)'
class MessageListener(object):
- """Listens for transitions to nonempty and for invalidations of cached
- byte sizes. Meant to be registered via Message._SetListener().
+ """Listens for modifications made to a message. Meant to be registered via
+ Message._SetListener().
+
+ Attributes:
+ dirty: If True, then calling Modified() would be a no-op. This can be
+ used to avoid these calls entirely in the common case.
"""
- def TransitionToNonempty(self):
- """Called the *first* time that this message becomes nonempty.
- Implementations are free (but not required) to call this method multiple
- times after the message has become nonempty.
- """
- raise NotImplementedError
+ def Modified(self):
+ """Called every time the message is modified in such a way that the parent
+ message may need to be updated. This currently means either:
+ (a) The message was modified for the first time, so the parent message
+ should henceforth mark the message as present.
+ (b) The message's cached byte size became dirty -- i.e. the message was
+ modified for the first time after a previous call to ByteSize().
+ Therefore the parent should also mark its byte size as dirty.
+ Note that (a) implies (b), since new objects start out with a clean cached
+ size (zero). However, we document (a) explicitly because it is important.
+
+ Modified() will *only* be called in response to one of these two events --
+ not every time the sub-message is modified.
- def ByteSizeDirty(self):
- """Called *every* time the cached byte size value
- for this object is invalidated (transitions from being
- "clean" to "dirty").
+ Note that if the listener's |dirty| attribute is true, then calling
+ Modified at the moment would be a no-op, so it can be skipped. Performance-
+ sensitive callers should check this attribute directly before calling since
+ it will be true most of the time.
"""
+
raise NotImplementedError
@@ -62,8 +74,5 @@ class NullMessageListener(object):
"""No-op MessageListener implementation."""
- def TransitionToNonempty(self):
- pass
-
- def ByteSizeDirty(self):
+ def Modified(self):
pass
diff --git a/python/google/protobuf/internal/message_test.py b/python/google/protobuf/internal/message_test.py
index df344cf0..73a9a3a3 100755
--- a/python/google/protobuf/internal/message_test.py
+++ b/python/google/protobuf/internal/message_test.py
@@ -30,7 +30,16 @@
# (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
# OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
-"""Tests python protocol buffers against the golden message."""
+"""Tests python protocol buffers against the golden message.
+
+Note that the golden messages exercise every known field type, thus this
+test ends up exercising and verifying nearly all of the parsing and
+serialization code in the whole library.
+
+TODO(kenton): Merge with wire_format_test? It doesn't make a whole lot of
+sense to call this a test of the "message" module, which only declares an
+abstract interface.
+"""
__author__ = 'gps@google.com (Gregory P. Smith)'
@@ -40,14 +49,41 @@ from google.protobuf import unittest_pb2
from google.protobuf.internal import test_util
-class MessageTest(test_util.GoldenMessageTestCase):
+class MessageTest(unittest.TestCase):
def testGoldenMessage(self):
golden_data = test_util.GoldenFile('golden_message').read()
golden_message = unittest_pb2.TestAllTypes()
golden_message.ParseFromString(golden_data)
- self.ExpectAllFieldsSet(golden_message)
+ test_util.ExpectAllFieldsSet(self, golden_message)
+ self.assertTrue(golden_message.SerializeToString() == golden_data)
+
+ def testGoldenExtensions(self):
+ golden_data = test_util.GoldenFile('golden_message').read()
+ golden_message = unittest_pb2.TestAllExtensions()
+ golden_message.ParseFromString(golden_data)
+ all_set = unittest_pb2.TestAllExtensions()
+ test_util.SetAllExtensions(all_set)
+ self.assertEquals(all_set, golden_message)
+ self.assertTrue(golden_message.SerializeToString() == golden_data)
+
+ def testGoldenPackedMessage(self):
+ golden_data = test_util.GoldenFile('golden_packed_fields_message').read()
+ golden_message = unittest_pb2.TestPackedTypes()
+ golden_message.ParseFromString(golden_data)
+ all_set = unittest_pb2.TestPackedTypes()
+ test_util.SetAllPackedFields(all_set)
+ self.assertEquals(all_set, golden_message)
+ self.assertTrue(all_set.SerializeToString() == golden_data)
+ def testGoldenPackedExtensions(self):
+ golden_data = test_util.GoldenFile('golden_packed_fields_message').read()
+ golden_message = unittest_pb2.TestPackedExtensions()
+ golden_message.ParseFromString(golden_data)
+ all_set = unittest_pb2.TestPackedExtensions()
+ test_util.SetAllPackedExtensions(all_set)
+ self.assertEquals(all_set, golden_message)
+ self.assertTrue(all_set.SerializeToString() == golden_data)
if __name__ == '__main__':
unittest.main()
diff --git a/python/google/protobuf/internal/output_stream.py b/python/google/protobuf/internal/output_stream.py
deleted file mode 100755
index 6c2d6f6b..00000000
--- a/python/google/protobuf/internal/output_stream.py
+++ /dev/null
@@ -1,125 +0,0 @@
-# Protocol Buffers - Google's data interchange format
-# Copyright 2008 Google Inc. All rights reserved.
-# http://code.google.com/p/protobuf/
-#
-# Redistribution and use in source and binary forms, with or without
-# modification, are permitted provided that the following conditions are
-# met:
-#
-# * Redistributions of source code must retain the above copyright
-# notice, this list of conditions and the following disclaimer.
-# * Redistributions in binary form must reproduce the above
-# copyright notice, this list of conditions and the following disclaimer
-# in the documentation and/or other materials provided with the
-# distribution.
-# * Neither the name of Google Inc. nor the names of its
-# contributors may be used to endorse or promote products derived from
-# this software without specific prior written permission.
-#
-# THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS
-# "AS IS" AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT
-# LIMITED TO, THE IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR
-# A PARTICULAR PURPOSE ARE DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT
-# OWNER OR CONTRIBUTORS BE LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL,
-# SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT
-# LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE,
-# DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY
-# THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT
-# (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
-# OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
-
-"""OutputStream is the primitive interface for sticking bits on the wire.
-
-All protocol buffer serialization can be expressed in terms of
-the OutputStream primitives provided here.
-"""
-
-__author__ = 'robinson@google.com (Will Robinson)'
-
-import array
-import struct
-from google.protobuf import message
-from google.protobuf.internal import wire_format
-
-
-
-# Note that much of this code is ported from //net/proto/ProtocolBuffer, and
-# that the interface is strongly inspired by CodedOutputStream from the C++
-# proto2 implementation.
-
-
-class OutputStream(object):
-
- """Contains all logic for writing bits, and ToString() to get the result."""
-
- def __init__(self):
- self._buffer = array.array('B')
-
- def AppendRawBytes(self, raw_bytes):
- """Appends raw_bytes to our internal buffer."""
- self._buffer.fromstring(raw_bytes)
-
- def AppendLittleEndian32(self, unsigned_value):
- """Appends an unsigned 32-bit integer to the internal buffer,
- in little-endian byte order.
- """
- if not 0 <= unsigned_value <= wire_format.UINT32_MAX:
- raise message.EncodeError(
- 'Unsigned 32-bit out of range: %d' % unsigned_value)
- self._buffer.fromstring(struct.pack(
- wire_format.FORMAT_UINT32_LITTLE_ENDIAN, unsigned_value))
-
- def AppendLittleEndian64(self, unsigned_value):
- """Appends an unsigned 64-bit integer to the internal buffer,
- in little-endian byte order.
- """
- if not 0 <= unsigned_value <= wire_format.UINT64_MAX:
- raise message.EncodeError(
- 'Unsigned 64-bit out of range: %d' % unsigned_value)
- self._buffer.fromstring(struct.pack(
- wire_format.FORMAT_UINT64_LITTLE_ENDIAN, unsigned_value))
-
- def AppendVarint32(self, value):
- """Appends a signed 32-bit integer to the internal buffer,
- encoded as a varint. (Note that a negative varint32 will
- always require 10 bytes of space.)
- """
- if not wire_format.INT32_MIN <= value <= wire_format.INT32_MAX:
- raise message.EncodeError('Value out of range: %d' % value)
- self.AppendVarint64(value)
-
- def AppendVarUInt32(self, value):
- """Appends an unsigned 32-bit integer to the internal buffer,
- encoded as a varint.
- """
- if not 0 <= value <= wire_format.UINT32_MAX:
- raise message.EncodeError('Value out of range: %d' % value)
- self.AppendVarUInt64(value)
-
- def AppendVarint64(self, value):
- """Appends a signed 64-bit integer to the internal buffer,
- encoded as a varint.
- """
- if not wire_format.INT64_MIN <= value <= wire_format.INT64_MAX:
- raise message.EncodeError('Value out of range: %d' % value)
- if value < 0:
- value += (1 << 64)
- self.AppendVarUInt64(value)
-
- def AppendVarUInt64(self, unsigned_value):
- """Appends an unsigned 64-bit integer to the internal buffer,
- encoded as a varint.
- """
- if not 0 <= unsigned_value <= wire_format.UINT64_MAX:
- raise message.EncodeError('Value out of range: %d' % unsigned_value)
- while True:
- bits = unsigned_value & 0x7f
- unsigned_value >>= 7
- if not unsigned_value:
- self._buffer.append(bits)
- break
- self._buffer.append(0x80|bits)
-
- def ToString(self):
- """Returns a string containing the bytes in our internal buffer."""
- return self._buffer.tostring()
diff --git a/python/google/protobuf/internal/output_stream_test.py b/python/google/protobuf/internal/output_stream_test.py
deleted file mode 100755
index df92eecd..00000000
--- a/python/google/protobuf/internal/output_stream_test.py
+++ /dev/null
@@ -1,178 +0,0 @@
-#! /usr/bin/python
-#
-# Protocol Buffers - Google's data interchange format
-# Copyright 2008 Google Inc. All rights reserved.
-# http://code.google.com/p/protobuf/
-#
-# Redistribution and use in source and binary forms, with or without
-# modification, are permitted provided that the following conditions are
-# met:
-#
-# * Redistributions of source code must retain the above copyright
-# notice, this list of conditions and the following disclaimer.
-# * Redistributions in binary form must reproduce the above
-# copyright notice, this list of conditions and the following disclaimer
-# in the documentation and/or other materials provided with the
-# distribution.
-# * Neither the name of Google Inc. nor the names of its
-# contributors may be used to endorse or promote products derived from
-# this software without specific prior written permission.
-#
-# THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS
-# "AS IS" AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT
-# LIMITED TO, THE IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR
-# A PARTICULAR PURPOSE ARE DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT
-# OWNER OR CONTRIBUTORS BE LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL,
-# SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT
-# LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE,
-# DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY
-# THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT
-# (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
-# OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
-
-"""Test for google.protobuf.internal.output_stream."""
-
-__author__ = 'robinson@google.com (Will Robinson)'
-
-import unittest
-from google.protobuf import message
-from google.protobuf.internal import output_stream
-from google.protobuf.internal import wire_format
-
-
-class OutputStreamTest(unittest.TestCase):
-
- def setUp(self):
- self.stream = output_stream.OutputStream()
-
- def testAppendRawBytes(self):
- # Empty string.
- self.stream.AppendRawBytes('')
- self.assertEqual('', self.stream.ToString())
-
- # Nonempty string.
- self.stream.AppendRawBytes('abc')
- self.assertEqual('abc', self.stream.ToString())
-
- # Ensure that we're actually appending.
- self.stream.AppendRawBytes('def')
- self.assertEqual('abcdef', self.stream.ToString())
-
- def AppendNumericTestHelper(self, append_fn, values_and_strings):
- """For each (value, expected_string) pair in values_and_strings,
- calls an OutputStream.Append*(value) method on an OutputStream and ensures
- that the string written to that stream matches expected_string.
-
- Args:
- append_fn: Unbound OutputStream method that takes an integer or
- long value as input.
- values_and_strings: Iterable of (value, expected_string) pairs.
- """
- for conversion in (int, long):
- for value, string in values_and_strings:
- stream = output_stream.OutputStream()
- expected_string = ''
- append_fn(stream, conversion(value))
- expected_string += string
- self.assertEqual(expected_string, stream.ToString())
-
- def AppendOverflowTestHelper(self, append_fn, value):
- """Calls an OutputStream.Append*(value) method and asserts
- that the method raises message.EncodeError.
-
- Args:
- append_fn: Unbound OutputStream method that takes an integer or
- long value as input.
- value: Value to pass to append_fn which should cause an
- message.EncodeError.
- """
- stream = output_stream.OutputStream()
- self.assertRaises(message.EncodeError, append_fn, stream, value)
-
- def testAppendLittleEndian32(self):
- append_fn = output_stream.OutputStream.AppendLittleEndian32
- values_and_expected_strings = [
- (0, '\x00\x00\x00\x00'),
- (1, '\x01\x00\x00\x00'),
- ((1 << 32) - 1, '\xff\xff\xff\xff'),
- ]
- self.AppendNumericTestHelper(append_fn, values_and_expected_strings)
-
- self.AppendOverflowTestHelper(append_fn, 1 << 32)
- self.AppendOverflowTestHelper(append_fn, -1)
-
- def testAppendLittleEndian64(self):
- append_fn = output_stream.OutputStream.AppendLittleEndian64
- values_and_expected_strings = [
- (0, '\x00\x00\x00\x00\x00\x00\x00\x00'),
- (1, '\x01\x00\x00\x00\x00\x00\x00\x00'),
- ((1 << 64) - 1, '\xff\xff\xff\xff\xff\xff\xff\xff'),
- ]
- self.AppendNumericTestHelper(append_fn, values_and_expected_strings)
-
- self.AppendOverflowTestHelper(append_fn, 1 << 64)
- self.AppendOverflowTestHelper(append_fn, -1)
-
- def testAppendVarint32(self):
- append_fn = output_stream.OutputStream.AppendVarint32
- values_and_expected_strings = [
- (0, '\x00'),
- (1, '\x01'),
- (127, '\x7f'),
- (128, '\x80\x01'),
- (-1, '\xff\xff\xff\xff\xff\xff\xff\xff\xff\x01'),
- (wire_format.INT32_MAX, '\xff\xff\xff\xff\x07'),
- (wire_format.INT32_MIN, '\x80\x80\x80\x80\xf8\xff\xff\xff\xff\x01'),
- ]
- self.AppendNumericTestHelper(append_fn, values_and_expected_strings)
-
- self.AppendOverflowTestHelper(append_fn, wire_format.INT32_MAX + 1)
- self.AppendOverflowTestHelper(append_fn, wire_format.INT32_MIN - 1)
-
- def testAppendVarUInt32(self):
- append_fn = output_stream.OutputStream.AppendVarUInt32
- values_and_expected_strings = [
- (0, '\x00'),
- (1, '\x01'),
- (127, '\x7f'),
- (128, '\x80\x01'),
- (wire_format.UINT32_MAX, '\xff\xff\xff\xff\x0f'),
- ]
- self.AppendNumericTestHelper(append_fn, values_and_expected_strings)
-
- self.AppendOverflowTestHelper(append_fn, -1)
- self.AppendOverflowTestHelper(append_fn, wire_format.UINT32_MAX + 1)
-
- def testAppendVarint64(self):
- append_fn = output_stream.OutputStream.AppendVarint64
- values_and_expected_strings = [
- (0, '\x00'),
- (1, '\x01'),
- (127, '\x7f'),
- (128, '\x80\x01'),
- (-1, '\xff\xff\xff\xff\xff\xff\xff\xff\xff\x01'),
- (wire_format.INT64_MAX, '\xff\xff\xff\xff\xff\xff\xff\xff\x7f'),
- (wire_format.INT64_MIN, '\x80\x80\x80\x80\x80\x80\x80\x80\x80\x01'),
- ]
- self.AppendNumericTestHelper(append_fn, values_and_expected_strings)
-
- self.AppendOverflowTestHelper(append_fn, wire_format.INT64_MAX + 1)
- self.AppendOverflowTestHelper(append_fn, wire_format.INT64_MIN - 1)
-
- def testAppendVarUInt64(self):
- append_fn = output_stream.OutputStream.AppendVarUInt64
- values_and_expected_strings = [
- (0, '\x00'),
- (1, '\x01'),
- (127, '\x7f'),
- (128, '\x80\x01'),
- (wire_format.UINT64_MAX, '\xff\xff\xff\xff\xff\xff\xff\xff\xff\x01'),
- ]
- self.AppendNumericTestHelper(append_fn, values_and_expected_strings)
-
- self.AppendOverflowTestHelper(append_fn, -1)
- self.AppendOverflowTestHelper(append_fn, wire_format.UINT64_MAX + 1)
-
-
-if __name__ == '__main__':
- unittest.main()
diff --git a/python/google/protobuf/internal/reflection_test.py b/python/google/protobuf/internal/reflection_test.py
index 86101774..2c9fa30b 100755
--- a/python/google/protobuf/internal/reflection_test.py
+++ b/python/google/protobuf/internal/reflection_test.py
@@ -38,6 +38,7 @@ pure-Python protocol compiler.
__author__ = 'robinson@google.com (Will Robinson)'
import operator
+import struct
import unittest
# TODO(robinson): When we split this test in two, only some of these imports
@@ -56,6 +57,51 @@ from google.protobuf.internal import test_util
from google.protobuf.internal import decoder
+class _MiniDecoder(object):
+ """Decodes a stream of values from a string.
+
+ Once upon a time we actually had a class called decoder.Decoder. Then we
+ got rid of it during a redesign that made decoding much, much faster overall.
+ But a couple tests in this file used it to check that the serialized form of
+ a message was correct. So, this class implements just the methods that were
+ used by said tests, so that we don't have to rewrite the tests.
+ """
+
+ def __init__(self, bytes):
+ self._bytes = bytes
+ self._pos = 0
+
+ def ReadVarint(self):
+ result, self._pos = decoder._DecodeVarint(self._bytes, self._pos)
+ return result
+
+ ReadInt32 = ReadVarint
+ ReadInt64 = ReadVarint
+ ReadUInt32 = ReadVarint
+ ReadUInt64 = ReadVarint
+
+ def ReadSInt64(self):
+ return wire_format.ZigZagDecode(self.ReadVarint())
+
+ ReadSInt32 = ReadSInt64
+
+ def ReadFieldNumberAndWireType(self):
+ return wire_format.UnpackTag(self.ReadVarint())
+
+ def ReadFloat(self):
+ result = struct.unpack("<f", self._bytes[self._pos:self._pos+4])[0]
+ self._pos += 4
+ return result
+
+ def ReadDouble(self):
+ result = struct.unpack("<d", self._bytes[self._pos:self._pos+8])[0]
+ self._pos += 8
+ return result
+
+ def EndOfStream(self):
+ return self._pos == len(self._bytes)
+
+
class ReflectionTest(unittest.TestCase):
def assertIs(self, values, others):
@@ -63,6 +109,97 @@ class ReflectionTest(unittest.TestCase):
for i in range(len(values)):
self.assertTrue(values[i] is others[i])
+ def testScalarConstructor(self):
+ # Constructor with only scalar types should succeed.
+ proto = unittest_pb2.TestAllTypes(
+ optional_int32=24,
+ optional_double=54.321,
+ optional_string='optional_string')
+
+ # Each keyword argument must read back unchanged from the field of the
+ # same name.
+ self.assertEqual(24, proto.optional_int32)
+ self.assertEqual(54.321, proto.optional_double)
+ self.assertEqual('optional_string', proto.optional_string)
+
+ def testRepeatedScalarConstructor(self):
+ # Constructor with only repeated scalar types should succeed.
+ proto = unittest_pb2.TestAllTypes(
+ repeated_int32=[1, 2, 3, 4],
+ repeated_double=[1.23, 54.321],
+ repeated_bool=[True, False, False],
+ repeated_string=["optional_string"])
+
+ # The repeated fields are copied into plain lists before comparing, so the
+ # check is on element values and order only.
+ self.assertEquals([1, 2, 3, 4], list(proto.repeated_int32))
+ self.assertEquals([1.23, 54.321], list(proto.repeated_double))
+ self.assertEquals([True, False, False], list(proto.repeated_bool))
+ self.assertEquals(["optional_string"], list(proto.repeated_string))
+
+ def testRepeatedCompositeConstructor(self):
+ # Constructor with only repeated composite types should succeed.
+ # Three kinds of composite element are covered: a nested message type, a
+ # foreign (top-level) message type, and a group.
+ proto = unittest_pb2.TestAllTypes(
+ repeated_nested_message=[
+ unittest_pb2.TestAllTypes.NestedMessage(
+ bb=unittest_pb2.TestAllTypes.FOO),
+ unittest_pb2.TestAllTypes.NestedMessage(
+ bb=unittest_pb2.TestAllTypes.BAR)],
+ repeated_foreign_message=[
+ unittest_pb2.ForeignMessage(c=-43),
+ unittest_pb2.ForeignMessage(c=45324),
+ unittest_pb2.ForeignMessage(c=12)],
+ repeatedgroup=[
+ unittest_pb2.TestAllTypes.RepeatedGroup(),
+ unittest_pb2.TestAllTypes.RepeatedGroup(a=1),
+ unittest_pb2.TestAllTypes.RepeatedGroup(a=2)])
+
+ # The expected lists are built from fresh message instances, so these
+ # comparisons rely on message equality, not object identity.
+ self.assertEquals(
+ [unittest_pb2.TestAllTypes.NestedMessage(
+ bb=unittest_pb2.TestAllTypes.FOO),
+ unittest_pb2.TestAllTypes.NestedMessage(
+ bb=unittest_pb2.TestAllTypes.BAR)],
+ list(proto.repeated_nested_message))
+ self.assertEquals(
+ [unittest_pb2.ForeignMessage(c=-43),
+ unittest_pb2.ForeignMessage(c=45324),
+ unittest_pb2.ForeignMessage(c=12)],
+ list(proto.repeated_foreign_message))
+ self.assertEquals(
+ [unittest_pb2.TestAllTypes.RepeatedGroup(),
+ unittest_pb2.TestAllTypes.RepeatedGroup(a=1),
+ unittest_pb2.TestAllTypes.RepeatedGroup(a=2)],
+ list(proto.repeatedgroup))
+
+ def testMixedConstructor(self):
+ # Constructor with a mix of scalar, repeated scalar, and repeated
+ # composite fields should succeed.
+ proto = unittest_pb2.TestAllTypes(
+ optional_int32=24,
+ optional_string='optional_string',
+ repeated_double=[1.23, 54.321],
+ repeated_bool=[True, False, False],
+ repeated_nested_message=[
+ unittest_pb2.TestAllTypes.NestedMessage(
+ bb=unittest_pb2.TestAllTypes.FOO),
+ unittest_pb2.TestAllTypes.NestedMessage(
+ bb=unittest_pb2.TestAllTypes.BAR)],
+ repeated_foreign_message=[
+ unittest_pb2.ForeignMessage(c=-43),
+ unittest_pb2.ForeignMessage(c=45324),
+ unittest_pb2.ForeignMessage(c=12)])
+
+ # Every field passed to the constructor is read back and compared.
+ self.assertEqual(24, proto.optional_int32)
+ self.assertEqual('optional_string', proto.optional_string)
+ self.assertEquals([1.23, 54.321], list(proto.repeated_double))
+ self.assertEquals([True, False, False], list(proto.repeated_bool))
+ self.assertEquals(
+ [unittest_pb2.TestAllTypes.NestedMessage(
+ bb=unittest_pb2.TestAllTypes.FOO),
+ unittest_pb2.TestAllTypes.NestedMessage(
+ bb=unittest_pb2.TestAllTypes.BAR)],
+ list(proto.repeated_nested_message))
+ self.assertEquals(
+ [unittest_pb2.ForeignMessage(c=-43),
+ unittest_pb2.ForeignMessage(c=45324),
+ unittest_pb2.ForeignMessage(c=12)],
+ list(proto.repeated_foreign_message))
+
def testSimpleHasBits(self):
# Test a scalar.
proto = unittest_pb2.TestAllTypes()
@@ -218,12 +355,23 @@ class ReflectionTest(unittest.TestCase):
proto.optional_fixed32 = 1
proto.optional_int32 = 5
proto.optional_string = 'foo'
+ # Access sub-message but don't set it yet.
+ nested_message = proto.optional_nested_message
self.assertEqual(
[ (proto.DESCRIPTOR.fields_by_name['optional_int32' ], 5),
(proto.DESCRIPTOR.fields_by_name['optional_fixed32'], 1),
(proto.DESCRIPTOR.fields_by_name['optional_string' ], 'foo') ],
proto.ListFields())
+ proto.optional_nested_message.bb = 123
+ self.assertEqual(
+ [ (proto.DESCRIPTOR.fields_by_name['optional_int32' ], 5),
+ (proto.DESCRIPTOR.fields_by_name['optional_fixed32'], 1),
+ (proto.DESCRIPTOR.fields_by_name['optional_string' ], 'foo'),
+ (proto.DESCRIPTOR.fields_by_name['optional_nested_message' ],
+ nested_message) ],
+ proto.ListFields())
+
def testRepeatedListFields(self):
proto = unittest_pb2.TestAllTypes()
proto.repeated_fixed32.append(1)
@@ -234,6 +382,7 @@ class ReflectionTest(unittest.TestCase):
proto.repeated_string.append('baz')
proto.repeated_string.extend(str(x) for x in xrange(2))
proto.optional_int32 = 21
+ proto.repeated_bool # Access but don't set anything; should not be listed.
self.assertEqual(
[ (proto.DESCRIPTOR.fields_by_name['optional_int32' ], 21),
(proto.DESCRIPTOR.fields_by_name['repeated_int32' ], [5, 11]),
@@ -731,7 +880,6 @@ class ReflectionTest(unittest.TestCase):
extendee_proto.ClearExtension(extension)
extension_proto.foreign_message_int = 23
- self.assertTrue(not toplevel.HasField('submessage'))
self.assertTrue(extension_proto is not extendee_proto.Extensions[extension])
def testExtensionFailureModes(self):
@@ -957,57 +1105,75 @@ class ReflectionTest(unittest.TestCase):
empty_proto = unittest_pb2.TestAllExtensions()
self.assertEquals(proto, empty_proto)
+ def assertInitialized(self, proto):
+ # Helper: asserts that `proto` reports itself as initialized and that both
+ # the strict and the partial serializer accept it.
+ self.assertTrue(proto.IsInitialized())
+ # Neither method should raise an exception.
+ proto.SerializeToString()
+ proto.SerializePartialToString()
+
+ def assertNotInitialized(self, proto):
+ # Helper: asserts that `proto` reports itself as uninitialized and that
+ # strict serialization rejects it with message.EncodeError.
+ self.assertFalse(proto.IsInitialized())
+ self.assertRaises(message.EncodeError, proto.SerializeToString)
+ # "Partial" serialization doesn't care if message is uninitialized.
+ proto.SerializePartialToString()
+
def testIsInitialized(self):
# Trivial cases - all optional fields and extensions.
proto = unittest_pb2.TestAllTypes()
- self.assertTrue(proto.IsInitialized())
+ self.assertInitialized(proto)
proto = unittest_pb2.TestAllExtensions()
- self.assertTrue(proto.IsInitialized())
+ self.assertInitialized(proto)
# The case of uninitialized required fields.
proto = unittest_pb2.TestRequired()
- self.assertFalse(proto.IsInitialized())
+ self.assertNotInitialized(proto)
proto.a = proto.b = proto.c = 2
- self.assertTrue(proto.IsInitialized())
+ self.assertInitialized(proto)
# The case of uninitialized submessage.
proto = unittest_pb2.TestRequiredForeign()
- self.assertTrue(proto.IsInitialized())
+ self.assertInitialized(proto)
proto.optional_message.a = 1
- self.assertFalse(proto.IsInitialized())
+ self.assertNotInitialized(proto)
proto.optional_message.b = 0
proto.optional_message.c = 0
- self.assertTrue(proto.IsInitialized())
+ self.assertInitialized(proto)
# Uninitialized repeated submessage.
message1 = proto.repeated_message.add()
- self.assertFalse(proto.IsInitialized())
+ self.assertNotInitialized(proto)
message1.a = message1.b = message1.c = 0
- self.assertTrue(proto.IsInitialized())
+ self.assertInitialized(proto)
# Uninitialized repeated group in an extension.
proto = unittest_pb2.TestAllExtensions()
extension = unittest_pb2.TestRequired.multi
message1 = proto.Extensions[extension].add()
message2 = proto.Extensions[extension].add()
- self.assertFalse(proto.IsInitialized())
+ self.assertNotInitialized(proto)
message1.a = 1
message1.b = 1
message1.c = 1
- self.assertFalse(proto.IsInitialized())
+ self.assertNotInitialized(proto)
message2.a = 2
message2.b = 2
message2.c = 2
- self.assertTrue(proto.IsInitialized())
+ self.assertInitialized(proto)
# Uninitialized nonrepeated message in an extension.
proto = unittest_pb2.TestAllExtensions()
extension = unittest_pb2.TestRequired.single
proto.Extensions[extension].a = 1
- self.assertFalse(proto.IsInitialized())
+ self.assertNotInitialized(proto)
proto.Extensions[extension].b = 2
proto.Extensions[extension].c = 3
- self.assertTrue(proto.IsInitialized())
+ self.assertInitialized(proto)
+
+ # Try passing an errors list.
+ errors = []
+ proto = unittest_pb2.TestRequired()
+ self.assertFalse(proto.IsInitialized(errors))
+ self.assertEqual(errors, ['a', 'b', 'c'])
def testStringUTF8Encoding(self):
proto = unittest_pb2.TestAllTypes()
@@ -1079,6 +1245,36 @@ class ReflectionTest(unittest.TestCase):
test_utf8_bytes, len(test_utf8_bytes) * '\xff')
self.assertRaises(UnicodeDecodeError, message2.MergeFromString, bytes)
+ def testEmptyNestedMessage(self):
+ # Each of the four operations below puts empty content into the nested
+ # message; in every case the parent must then report the field as present.
+
+ # 1. MergeFrom() with an empty message.
+ proto = unittest_pb2.TestAllTypes()
+ proto.optional_nested_message.MergeFrom(
+ unittest_pb2.TestAllTypes.NestedMessage())
+ self.assertTrue(proto.HasField('optional_nested_message'))
+
+ # 2. CopyFrom() with an empty message.
+ proto = unittest_pb2.TestAllTypes()
+ proto.optional_nested_message.CopyFrom(
+ unittest_pb2.TestAllTypes.NestedMessage())
+ self.assertTrue(proto.HasField('optional_nested_message'))
+
+ # 3. MergeFromString() with an empty string.
+ proto = unittest_pb2.TestAllTypes()
+ proto.optional_nested_message.MergeFromString('')
+ self.assertTrue(proto.HasField('optional_nested_message'))
+
+ # 4. ParseFromString() with an empty string.
+ proto = unittest_pb2.TestAllTypes()
+ proto.optional_nested_message.ParseFromString('')
+ self.assertTrue(proto.HasField('optional_nested_message'))
+
+ # Round trip the message from case 4: the presence of the (empty) nested
+ # message must survive serialization and parsing.
+ serialized = proto.SerializeToString()
+ proto2 = unittest_pb2.TestAllTypes()
+ proto2.MergeFromString(serialized)
+ self.assertTrue(proto2.HasField('optional_nested_message'))
+
+ def testSetInParent(self):
+ # SetInParent() on a sub-message (here a group) must mark the field as
+ # present in the parent even though nothing inside it was assigned.
+ proto = unittest_pb2.TestAllTypes()
+ self.assertFalse(proto.HasField('optionalgroup'))
+ proto.optionalgroup.SetInParent()
+ self.assertTrue(proto.HasField('optionalgroup'))
+
# Since we had so many tests for protocol buffer equality, we broke these out
# into separate TestCase classes.
@@ -1541,6 +1737,47 @@ class SerializationTest(unittest.TestCase):
second_proto.MergeFromString(serialized)
self.assertEqual(first_proto, second_proto)
+ def testSerializeNegativeValues(self):
+ # Round-trips negative values through every signed integer field kind
+ # set here (int, sint and sfixed, in 32- and 64-bit widths).
+ first_proto = unittest_pb2.TestAllTypes()
+
+ # The 64-bit fields use values shifted left by 40 bits so that they do not
+ # fit in 32 bits.
+ first_proto.optional_int32 = -1
+ first_proto.optional_int64 = -(2 << 40)
+ first_proto.optional_sint32 = -3
+ first_proto.optional_sint64 = -(4 << 40)
+ first_proto.optional_sfixed32 = -5
+ first_proto.optional_sfixed64 = -(6 << 40)
+
+ second_proto = unittest_pb2.TestAllTypes.FromString(
+ first_proto.SerializeToString())
+
+ self.assertEqual(first_proto, second_proto)
+
+ def testParseTruncated(self):
+ # Parses every prefix of a fully populated serialized message and checks
+ # that parsing into TestAllTypes (known fields) and into TestEmptyMessage
+ # (the "unknown fields" case, per the comments below) either both succeed
+ # or both raise message.DecodeError.
+ first_proto = unittest_pb2.TestAllTypes()
+ test_util.SetAllFields(first_proto)
+ serialized = first_proto.SerializeToString()
+
+ # The "+ 1" makes the last iteration cover the complete, untruncated
+ # buffer.
+ for truncation_point in xrange(len(serialized) + 1):
+ try:
+ second_proto = unittest_pb2.TestAllTypes()
+ unknown_fields = unittest_pb2.TestEmptyMessage()
+ # NOTE(review): the arguments look like (buffer, start, end) and the
+ # return value like the end position reached -- confirm against
+ # _InternalParse itself, which is not visible here.
+ pos = second_proto._InternalParse(serialized, 0, truncation_point)
+ # If we didn't raise an error then we read exactly the amount expected.
+ self.assertEqual(truncation_point, pos)
+
+ # Parsing to unknown fields should not throw if parsing to known fields
+ # did not.
+ try:
+ pos2 = unknown_fields._InternalParse(serialized, 0, truncation_point)
+ self.assertEqual(truncation_point, pos2)
+ except message.DecodeError:
+ self.fail('Parsing unknown fields failed when parsing known fields '
+ 'did not.')
+ except message.DecodeError:
+ # Parsing unknown fields should also fail.
+ self.assertRaises(message.DecodeError, unknown_fields._InternalParse,
+ serialized, 0, truncation_point)
+
def testCanonicalSerializationOrder(self):
proto = more_messages_pb2.OutOfOrderFields()
# These are also their tag numbers. Even though we're setting these in
@@ -1553,7 +1790,7 @@ class SerializationTest(unittest.TestCase):
proto.optional_int32 = 1
serialized = proto.SerializeToString()
self.assertEqual(proto.ByteSize(), len(serialized))
- d = decoder.Decoder(serialized)
+ d = _MiniDecoder(serialized)
ReadTag = d.ReadFieldNumberAndWireType
self.assertEqual((1, wire_format.WIRETYPE_VARINT), ReadTag())
self.assertEqual(1, d.ReadInt32())
@@ -1709,7 +1946,7 @@ class SerializationTest(unittest.TestCase):
self._CheckRaises(
message.EncodeError,
proto.SerializeToString,
- 'Required field protobuf_unittest.TestRequired.a is not set.')
+ 'Message is missing required fields: a,b,c')
# Shouldn't raise exceptions.
partial = proto.SerializePartialToString()
@@ -1717,7 +1954,7 @@ class SerializationTest(unittest.TestCase):
self._CheckRaises(
message.EncodeError,
proto.SerializeToString,
- 'Required field protobuf_unittest.TestRequired.b is not set.')
+ 'Message is missing required fields: b,c')
# Shouldn't raise exceptions.
partial = proto.SerializePartialToString()
@@ -1725,7 +1962,7 @@ class SerializationTest(unittest.TestCase):
self._CheckRaises(
message.EncodeError,
proto.SerializeToString,
- 'Required field protobuf_unittest.TestRequired.c is not set.')
+ 'Message is missing required fields: c')
# Shouldn't raise exceptions.
partial = proto.SerializePartialToString()
@@ -1744,6 +1981,38 @@ class SerializationTest(unittest.TestCase):
self.assertEqual(2, proto2.b)
self.assertEqual(3, proto2.c)
+ def testSerializeUninitializedSubMessage(self):
+ # Checks that SerializeToString() rejects a message whose sub-messages are
+ # missing required fields, and pins the exact error text: missing fields
+ # are listed as comma-separated paths, with a dotted prefix for a singular
+ # sub-message and an indexed prefix for a repeated one.
+ proto = unittest_pb2.TestRequiredForeign()
+
+ # Sub-message doesn't exist yet, so this succeeds.
+ proto.SerializeToString()
+
+ # Setting only `a` leaves the sub-message's `b` and `c` unset.
+ proto.optional_message.a = 1
+ self._CheckRaises(
+ message.EncodeError,
+ proto.SerializeToString,
+ 'Message is missing required fields: '
+ 'optional_message.b,optional_message.c')
+
+ proto.optional_message.b = 2
+ proto.optional_message.c = 3
+ proto.SerializeToString()
+
+ # Two repeated elements, each missing a different pair of fields.
+ proto.repeated_message.add().a = 1
+ proto.repeated_message.add().b = 2
+ self._CheckRaises(
+ message.EncodeError,
+ proto.SerializeToString,
+ 'Message is missing required fields: '
+ 'repeated_message[0].b,repeated_message[0].c,'
+ 'repeated_message[1].a,repeated_message[1].c')
+
+ # Once every element is complete, serialization succeeds again.
+ proto.repeated_message[0].b = 2
+ proto.repeated_message[0].c = 3
+ proto.repeated_message[1].a = 1
+ proto.repeated_message[1].c = 3
+ proto.SerializeToString()
+
def testSerializeAllPackedFields(self):
first_proto = unittest_pb2.TestPackedTypes()
second_proto = unittest_pb2.TestPackedTypes()
@@ -1786,7 +2055,7 @@ class SerializationTest(unittest.TestCase):
proto.packed_float.append(2.0) # 4 bytes, will be before double
serialized = proto.SerializeToString()
self.assertEqual(proto.ByteSize(), len(serialized))
- d = decoder.Decoder(serialized)
+ d = _MiniDecoder(serialized)
ReadTag = d.ReadFieldNumberAndWireType
self.assertEqual((90, wire_format.WIRETYPE_LENGTH_DELIMITED), ReadTag())
self.assertEqual(1+1+1+2, d.ReadInt32())
@@ -1803,6 +2072,24 @@ class SerializationTest(unittest.TestCase):
self.assertEqual(1000.0, d.ReadDouble())
self.assertTrue(d.EndOfStream())
+ def testParsePackedFromUnpacked(self):
+ # Data serialized from TestUnpackedTypes must parse into TestPackedTypes.
+ # NOTE(review): the equality check only holds if SetAllUnpackedFields and
+ # SetAllPackedFields populate matching values; SetAllUnpackedFields is not
+ # visible here -- confirm.
+ unpacked = unittest_pb2.TestUnpackedTypes()
+ test_util.SetAllUnpackedFields(unpacked)
+ packed = unittest_pb2.TestPackedTypes()
+ packed.MergeFromString(unpacked.SerializeToString())
+ expected = unittest_pb2.TestPackedTypes()
+ test_util.SetAllPackedFields(expected)
+ self.assertEqual(expected, packed)
+
+ def testParseUnpackedFromPacked(self):
+ # Data serialized from TestPackedTypes must parse into TestUnpackedTypes.
+ # NOTE(review): the equality check only holds if SetAllPackedFields and
+ # SetAllUnpackedFields populate matching values; SetAllUnpackedFields is
+ # not visible here -- confirm.
+ packed = unittest_pb2.TestPackedTypes()
+ test_util.SetAllPackedFields(packed)
+ unpacked = unittest_pb2.TestUnpackedTypes()
+ unpacked.MergeFromString(packed.SerializeToString())
+ expected = unittest_pb2.TestUnpackedTypes()
+ test_util.SetAllUnpackedFields(expected)
+ self.assertEqual(expected, unpacked)
+
def testFieldNumbers(self):
proto = unittest_pb2.TestAllTypes()
self.assertEqual(unittest_pb2.TestAllTypes.NestedMessage.BB_FIELD_NUMBER, 1)
@@ -1944,33 +2231,6 @@ class OptionsTest(unittest.TestCase):
field_descriptor.label)
-class UtilityTest(unittest.TestCase):
-
- def testImergeSorted(self):
- ImergeSorted = reflection._ImergeSorted
- # Various types of emptiness.
- self.assertEqual([], list(ImergeSorted()))
- self.assertEqual([], list(ImergeSorted([])))
- self.assertEqual([], list(ImergeSorted([], [])))
-
- # One nonempty list.
- self.assertEqual([1, 2, 3], list(ImergeSorted([1, 2, 3])))
- self.assertEqual([1, 2, 3], list(ImergeSorted([1, 2, 3], [])))
- self.assertEqual([1, 2, 3], list(ImergeSorted([], [1, 2, 3])))
-
- # Merging some nonempty lists together.
- self.assertEqual([1, 2, 3], list(ImergeSorted([1, 3], [2])))
- self.assertEqual([1, 2, 3], list(ImergeSorted([1], [3], [2])))
- self.assertEqual([1, 2, 3], list(ImergeSorted([1], [3], [2], [])))
-
- # Elements repeated across component iterators.
- self.assertEqual([1, 2, 2, 3, 3],
- list(ImergeSorted([1, 2], [3], [2, 3])))
-
- # Elements repeated within an iterator.
- self.assertEqual([1, 2, 2, 3, 3],
- list(ImergeSorted([1, 2, 2], [3], [3])))
-
if __name__ == '__main__':
unittest.main()
diff --git a/python/google/protobuf/internal/test_util.py b/python/google/protobuf/internal/test_util.py
index 1a0da552..1df16194 100755
--- a/python/google/protobuf/internal/test_util.py
+++ b/python/google/protobuf/internal/test_util.py
@@ -31,14 +31,13 @@
"""Utilities for Python proto2 tests.
This is intentionally modeled on C++ code in
-//net/proto2/internal/test_util.*.
+//google/protobuf/test_util.*.
"""
__author__ = 'robinson@google.com (Will Robinson)'
import os.path
-import unittest
from google.protobuf import unittest_import_pb2
from google.protobuf import unittest_pb2
@@ -353,198 +352,198 @@ def ExpectAllFieldsAndExtensionsInOrder(serialized):
raise ValueError('Expected %r, found %r' % (expected, serialized))
-class GoldenMessageTestCase(unittest.TestCase):
- """This adds methods to TestCase useful for verifying our Golden Message."""
-
- def ExpectAllFieldsSet(self, message):
- """Check all fields for correct values have after Set*Fields() is called."""
- self.assertTrue(message.HasField('optional_int32'))
- self.assertTrue(message.HasField('optional_int64'))
- self.assertTrue(message.HasField('optional_uint32'))
- self.assertTrue(message.HasField('optional_uint64'))
- self.assertTrue(message.HasField('optional_sint32'))
- self.assertTrue(message.HasField('optional_sint64'))
- self.assertTrue(message.HasField('optional_fixed32'))
- self.assertTrue(message.HasField('optional_fixed64'))
- self.assertTrue(message.HasField('optional_sfixed32'))
- self.assertTrue(message.HasField('optional_sfixed64'))
- self.assertTrue(message.HasField('optional_float'))
- self.assertTrue(message.HasField('optional_double'))
- self.assertTrue(message.HasField('optional_bool'))
- self.assertTrue(message.HasField('optional_string'))
- self.assertTrue(message.HasField('optional_bytes'))
-
- self.assertTrue(message.HasField('optionalgroup'))
- self.assertTrue(message.HasField('optional_nested_message'))
- self.assertTrue(message.HasField('optional_foreign_message'))
- self.assertTrue(message.HasField('optional_import_message'))
-
- self.assertTrue(message.optionalgroup.HasField('a'))
- self.assertTrue(message.optional_nested_message.HasField('bb'))
- self.assertTrue(message.optional_foreign_message.HasField('c'))
- self.assertTrue(message.optional_import_message.HasField('d'))
-
- self.assertTrue(message.HasField('optional_nested_enum'))
- self.assertTrue(message.HasField('optional_foreign_enum'))
- self.assertTrue(message.HasField('optional_import_enum'))
-
- self.assertTrue(message.HasField('optional_string_piece'))
- self.assertTrue(message.HasField('optional_cord'))
-
- self.assertEqual(101, message.optional_int32)
- self.assertEqual(102, message.optional_int64)
- self.assertEqual(103, message.optional_uint32)
- self.assertEqual(104, message.optional_uint64)
- self.assertEqual(105, message.optional_sint32)
- self.assertEqual(106, message.optional_sint64)
- self.assertEqual(107, message.optional_fixed32)
- self.assertEqual(108, message.optional_fixed64)
- self.assertEqual(109, message.optional_sfixed32)
- self.assertEqual(110, message.optional_sfixed64)
- self.assertEqual(111, message.optional_float)
- self.assertEqual(112, message.optional_double)
- self.assertEqual(True, message.optional_bool)
- self.assertEqual('115', message.optional_string)
- self.assertEqual('116', message.optional_bytes)
-
- self.assertEqual(117, message.optionalgroup.a);
- self.assertEqual(118, message.optional_nested_message.bb)
- self.assertEqual(119, message.optional_foreign_message.c)
- self.assertEqual(120, message.optional_import_message.d)
-
- self.assertEqual(unittest_pb2.TestAllTypes.BAZ,
- message.optional_nested_enum)
- self.assertEqual(unittest_pb2.FOREIGN_BAZ, message.optional_foreign_enum)
- self.assertEqual(unittest_import_pb2.IMPORT_BAZ,
- message.optional_import_enum)
-
- # -----------------------------------------------------------------
-
- self.assertEqual(2, len(message.repeated_int32))
- self.assertEqual(2, len(message.repeated_int64))
- self.assertEqual(2, len(message.repeated_uint32))
- self.assertEqual(2, len(message.repeated_uint64))
- self.assertEqual(2, len(message.repeated_sint32))
- self.assertEqual(2, len(message.repeated_sint64))
- self.assertEqual(2, len(message.repeated_fixed32))
- self.assertEqual(2, len(message.repeated_fixed64))
- self.assertEqual(2, len(message.repeated_sfixed32))
- self.assertEqual(2, len(message.repeated_sfixed64))
- self.assertEqual(2, len(message.repeated_float))
- self.assertEqual(2, len(message.repeated_double))
- self.assertEqual(2, len(message.repeated_bool))
- self.assertEqual(2, len(message.repeated_string))
- self.assertEqual(2, len(message.repeated_bytes))
-
- self.assertEqual(2, len(message.repeatedgroup))
- self.assertEqual(2, len(message.repeated_nested_message))
- self.assertEqual(2, len(message.repeated_foreign_message))
- self.assertEqual(2, len(message.repeated_import_message))
- self.assertEqual(2, len(message.repeated_nested_enum))
- self.assertEqual(2, len(message.repeated_foreign_enum))
- self.assertEqual(2, len(message.repeated_import_enum))
-
- self.assertEqual(2, len(message.repeated_string_piece))
- self.assertEqual(2, len(message.repeated_cord))
-
- self.assertEqual(201, message.repeated_int32[0])
- self.assertEqual(202, message.repeated_int64[0])
- self.assertEqual(203, message.repeated_uint32[0])
- self.assertEqual(204, message.repeated_uint64[0])
- self.assertEqual(205, message.repeated_sint32[0])
- self.assertEqual(206, message.repeated_sint64[0])
- self.assertEqual(207, message.repeated_fixed32[0])
- self.assertEqual(208, message.repeated_fixed64[0])
- self.assertEqual(209, message.repeated_sfixed32[0])
- self.assertEqual(210, message.repeated_sfixed64[0])
- self.assertEqual(211, message.repeated_float[0])
- self.assertEqual(212, message.repeated_double[0])
- self.assertEqual(True, message.repeated_bool[0])
- self.assertEqual('215', message.repeated_string[0])
- self.assertEqual('216', message.repeated_bytes[0])
-
- self.assertEqual(217, message.repeatedgroup[0].a)
- self.assertEqual(218, message.repeated_nested_message[0].bb)
- self.assertEqual(219, message.repeated_foreign_message[0].c)
- self.assertEqual(220, message.repeated_import_message[0].d)
-
- self.assertEqual(unittest_pb2.TestAllTypes.BAR,
- message.repeated_nested_enum[0])
- self.assertEqual(unittest_pb2.FOREIGN_BAR,
- message.repeated_foreign_enum[0])
- self.assertEqual(unittest_import_pb2.IMPORT_BAR,
- message.repeated_import_enum[0])
-
- self.assertEqual(301, message.repeated_int32[1])
- self.assertEqual(302, message.repeated_int64[1])
- self.assertEqual(303, message.repeated_uint32[1])
- self.assertEqual(304, message.repeated_uint64[1])
- self.assertEqual(305, message.repeated_sint32[1])
- self.assertEqual(306, message.repeated_sint64[1])
- self.assertEqual(307, message.repeated_fixed32[1])
- self.assertEqual(308, message.repeated_fixed64[1])
- self.assertEqual(309, message.repeated_sfixed32[1])
- self.assertEqual(310, message.repeated_sfixed64[1])
- self.assertEqual(311, message.repeated_float[1])
- self.assertEqual(312, message.repeated_double[1])
- self.assertEqual(False, message.repeated_bool[1])
- self.assertEqual('315', message.repeated_string[1])
- self.assertEqual('316', message.repeated_bytes[1])
-
- self.assertEqual(317, message.repeatedgroup[1].a)
- self.assertEqual(318, message.repeated_nested_message[1].bb)
- self.assertEqual(319, message.repeated_foreign_message[1].c)
- self.assertEqual(320, message.repeated_import_message[1].d)
-
- self.assertEqual(unittest_pb2.TestAllTypes.BAZ,
- message.repeated_nested_enum[1])
- self.assertEqual(unittest_pb2.FOREIGN_BAZ,
- message.repeated_foreign_enum[1])
- self.assertEqual(unittest_import_pb2.IMPORT_BAZ,
- message.repeated_import_enum[1])
-
- # -----------------------------------------------------------------
-
- self.assertTrue(message.HasField('default_int32'))
- self.assertTrue(message.HasField('default_int64'))
- self.assertTrue(message.HasField('default_uint32'))
- self.assertTrue(message.HasField('default_uint64'))
- self.assertTrue(message.HasField('default_sint32'))
- self.assertTrue(message.HasField('default_sint64'))
- self.assertTrue(message.HasField('default_fixed32'))
- self.assertTrue(message.HasField('default_fixed64'))
- self.assertTrue(message.HasField('default_sfixed32'))
- self.assertTrue(message.HasField('default_sfixed64'))
- self.assertTrue(message.HasField('default_float'))
- self.assertTrue(message.HasField('default_double'))
- self.assertTrue(message.HasField('default_bool'))
- self.assertTrue(message.HasField('default_string'))
- self.assertTrue(message.HasField('default_bytes'))
-
- self.assertTrue(message.HasField('default_nested_enum'))
- self.assertTrue(message.HasField('default_foreign_enum'))
- self.assertTrue(message.HasField('default_import_enum'))
-
- self.assertEqual(401, message.default_int32)
- self.assertEqual(402, message.default_int64)
- self.assertEqual(403, message.default_uint32)
- self.assertEqual(404, message.default_uint64)
- self.assertEqual(405, message.default_sint32)
- self.assertEqual(406, message.default_sint64)
- self.assertEqual(407, message.default_fixed32)
- self.assertEqual(408, message.default_fixed64)
- self.assertEqual(409, message.default_sfixed32)
- self.assertEqual(410, message.default_sfixed64)
- self.assertEqual(411, message.default_float)
- self.assertEqual(412, message.default_double)
- self.assertEqual(False, message.default_bool)
- self.assertEqual('415', message.default_string)
- self.assertEqual('416', message.default_bytes)
-
- self.assertEqual(unittest_pb2.TestAllTypes.FOO, message.default_nested_enum)
- self.assertEqual(unittest_pb2.FOREIGN_FOO, message.default_foreign_enum)
- self.assertEqual(unittest_import_pb2.IMPORT_FOO,
- message.default_import_enum)
+def ExpectAllFieldsSet(test_case, message):
+ """Check that all fields have correct values after Set*Fields() is called.
+
+ Args:
+ test_case: Object whose assertTrue() and assertEqual() methods are used
+ to report failures (e.g. a unittest.TestCase).
+ message: The message to inspect; presumably a unittest_pb2.TestAllTypes
+ instance that was populated by SetAllFields() -- confirm.
+ """
+ # Every optional field must be reported as present...
+ test_case.assertTrue(message.HasField('optional_int32'))
+ test_case.assertTrue(message.HasField('optional_int64'))
+ test_case.assertTrue(message.HasField('optional_uint32'))
+ test_case.assertTrue(message.HasField('optional_uint64'))
+ test_case.assertTrue(message.HasField('optional_sint32'))
+ test_case.assertTrue(message.HasField('optional_sint64'))
+ test_case.assertTrue(message.HasField('optional_fixed32'))
+ test_case.assertTrue(message.HasField('optional_fixed64'))
+ test_case.assertTrue(message.HasField('optional_sfixed32'))
+ test_case.assertTrue(message.HasField('optional_sfixed64'))
+ test_case.assertTrue(message.HasField('optional_float'))
+ test_case.assertTrue(message.HasField('optional_double'))
+ test_case.assertTrue(message.HasField('optional_bool'))
+ test_case.assertTrue(message.HasField('optional_string'))
+ test_case.assertTrue(message.HasField('optional_bytes'))
+
+ test_case.assertTrue(message.HasField('optionalgroup'))
+ test_case.assertTrue(message.HasField('optional_nested_message'))
+ test_case.assertTrue(message.HasField('optional_foreign_message'))
+ test_case.assertTrue(message.HasField('optional_import_message'))
+
+ test_case.assertTrue(message.optionalgroup.HasField('a'))
+ test_case.assertTrue(message.optional_nested_message.HasField('bb'))
+ test_case.assertTrue(message.optional_foreign_message.HasField('c'))
+ test_case.assertTrue(message.optional_import_message.HasField('d'))
+
+ test_case.assertTrue(message.HasField('optional_nested_enum'))
+ test_case.assertTrue(message.HasField('optional_foreign_enum'))
+ test_case.assertTrue(message.HasField('optional_import_enum'))
+
+ test_case.assertTrue(message.HasField('optional_string_piece'))
+ test_case.assertTrue(message.HasField('optional_cord'))
+
+ # ...and must hold its expected value (101 through 120 for the numeric,
+ # string and sub-message fields).
+ test_case.assertEqual(101, message.optional_int32)
+ test_case.assertEqual(102, message.optional_int64)
+ test_case.assertEqual(103, message.optional_uint32)
+ test_case.assertEqual(104, message.optional_uint64)
+ test_case.assertEqual(105, message.optional_sint32)
+ test_case.assertEqual(106, message.optional_sint64)
+ test_case.assertEqual(107, message.optional_fixed32)
+ test_case.assertEqual(108, message.optional_fixed64)
+ test_case.assertEqual(109, message.optional_sfixed32)
+ test_case.assertEqual(110, message.optional_sfixed64)
+ test_case.assertEqual(111, message.optional_float)
+ test_case.assertEqual(112, message.optional_double)
+ test_case.assertEqual(True, message.optional_bool)
+ test_case.assertEqual('115', message.optional_string)
+ test_case.assertEqual('116', message.optional_bytes)
+
+ test_case.assertEqual(117, message.optionalgroup.a)
+ test_case.assertEqual(118, message.optional_nested_message.bb)
+ test_case.assertEqual(119, message.optional_foreign_message.c)
+ test_case.assertEqual(120, message.optional_import_message.d)
+
+ test_case.assertEqual(unittest_pb2.TestAllTypes.BAZ,
+ message.optional_nested_enum)
+ test_case.assertEqual(unittest_pb2.FOREIGN_BAZ,
+ message.optional_foreign_enum)
+ test_case.assertEqual(unittest_import_pb2.IMPORT_BAZ,
+ message.optional_import_enum)
+
+ # -----------------------------------------------------------------
+
+ # Every repeated field must hold exactly two elements...
+ test_case.assertEqual(2, len(message.repeated_int32))
+ test_case.assertEqual(2, len(message.repeated_int64))
+ test_case.assertEqual(2, len(message.repeated_uint32))
+ test_case.assertEqual(2, len(message.repeated_uint64))
+ test_case.assertEqual(2, len(message.repeated_sint32))
+ test_case.assertEqual(2, len(message.repeated_sint64))
+ test_case.assertEqual(2, len(message.repeated_fixed32))
+ test_case.assertEqual(2, len(message.repeated_fixed64))
+ test_case.assertEqual(2, len(message.repeated_sfixed32))
+ test_case.assertEqual(2, len(message.repeated_sfixed64))
+ test_case.assertEqual(2, len(message.repeated_float))
+ test_case.assertEqual(2, len(message.repeated_double))
+ test_case.assertEqual(2, len(message.repeated_bool))
+ test_case.assertEqual(2, len(message.repeated_string))
+ test_case.assertEqual(2, len(message.repeated_bytes))
+
+ test_case.assertEqual(2, len(message.repeatedgroup))
+ test_case.assertEqual(2, len(message.repeated_nested_message))
+ test_case.assertEqual(2, len(message.repeated_foreign_message))
+ test_case.assertEqual(2, len(message.repeated_import_message))
+ test_case.assertEqual(2, len(message.repeated_nested_enum))
+ test_case.assertEqual(2, len(message.repeated_foreign_enum))
+ test_case.assertEqual(2, len(message.repeated_import_enum))
+
+ test_case.assertEqual(2, len(message.repeated_string_piece))
+ test_case.assertEqual(2, len(message.repeated_cord))
+
+ # ...with the first element taken from the 2xx series...
+ test_case.assertEqual(201, message.repeated_int32[0])
+ test_case.assertEqual(202, message.repeated_int64[0])
+ test_case.assertEqual(203, message.repeated_uint32[0])
+ test_case.assertEqual(204, message.repeated_uint64[0])
+ test_case.assertEqual(205, message.repeated_sint32[0])
+ test_case.assertEqual(206, message.repeated_sint64[0])
+ test_case.assertEqual(207, message.repeated_fixed32[0])
+ test_case.assertEqual(208, message.repeated_fixed64[0])
+ test_case.assertEqual(209, message.repeated_sfixed32[0])
+ test_case.assertEqual(210, message.repeated_sfixed64[0])
+ test_case.assertEqual(211, message.repeated_float[0])
+ test_case.assertEqual(212, message.repeated_double[0])
+ test_case.assertEqual(True, message.repeated_bool[0])
+ test_case.assertEqual('215', message.repeated_string[0])
+ test_case.assertEqual('216', message.repeated_bytes[0])
+
+ test_case.assertEqual(217, message.repeatedgroup[0].a)
+ test_case.assertEqual(218, message.repeated_nested_message[0].bb)
+ test_case.assertEqual(219, message.repeated_foreign_message[0].c)
+ test_case.assertEqual(220, message.repeated_import_message[0].d)
+
+ test_case.assertEqual(unittest_pb2.TestAllTypes.BAR,
+ message.repeated_nested_enum[0])
+ test_case.assertEqual(unittest_pb2.FOREIGN_BAR,
+ message.repeated_foreign_enum[0])
+ test_case.assertEqual(unittest_import_pb2.IMPORT_BAR,
+ message.repeated_import_enum[0])
+
+ # ...and the second element from the 3xx series.
+ test_case.assertEqual(301, message.repeated_int32[1])
+ test_case.assertEqual(302, message.repeated_int64[1])
+ test_case.assertEqual(303, message.repeated_uint32[1])
+ test_case.assertEqual(304, message.repeated_uint64[1])
+ test_case.assertEqual(305, message.repeated_sint32[1])
+ test_case.assertEqual(306, message.repeated_sint64[1])
+ test_case.assertEqual(307, message.repeated_fixed32[1])
+ test_case.assertEqual(308, message.repeated_fixed64[1])
+ test_case.assertEqual(309, message.repeated_sfixed32[1])
+ test_case.assertEqual(310, message.repeated_sfixed64[1])
+ test_case.assertEqual(311, message.repeated_float[1])
+ test_case.assertEqual(312, message.repeated_double[1])
+ test_case.assertEqual(False, message.repeated_bool[1])
+ test_case.assertEqual('315', message.repeated_string[1])
+ test_case.assertEqual('316', message.repeated_bytes[1])
+
+ test_case.assertEqual(317, message.repeatedgroup[1].a)
+ test_case.assertEqual(318, message.repeated_nested_message[1].bb)
+ test_case.assertEqual(319, message.repeated_foreign_message[1].c)
+ test_case.assertEqual(320, message.repeated_import_message[1].d)
+
+ test_case.assertEqual(unittest_pb2.TestAllTypes.BAZ,
+ message.repeated_nested_enum[1])
+ test_case.assertEqual(unittest_pb2.FOREIGN_BAZ,
+ message.repeated_foreign_enum[1])
+ test_case.assertEqual(unittest_import_pb2.IMPORT_BAZ,
+ message.repeated_import_enum[1])
+
+ # -----------------------------------------------------------------
+
+ # The default_* fields must be reported as present too...
+ test_case.assertTrue(message.HasField('default_int32'))
+ test_case.assertTrue(message.HasField('default_int64'))
+ test_case.assertTrue(message.HasField('default_uint32'))
+ test_case.assertTrue(message.HasField('default_uint64'))
+ test_case.assertTrue(message.HasField('default_sint32'))
+ test_case.assertTrue(message.HasField('default_sint64'))
+ test_case.assertTrue(message.HasField('default_fixed32'))
+ test_case.assertTrue(message.HasField('default_fixed64'))
+ test_case.assertTrue(message.HasField('default_sfixed32'))
+ test_case.assertTrue(message.HasField('default_sfixed64'))
+ test_case.assertTrue(message.HasField('default_float'))
+ test_case.assertTrue(message.HasField('default_double'))
+ test_case.assertTrue(message.HasField('default_bool'))
+ test_case.assertTrue(message.HasField('default_string'))
+ test_case.assertTrue(message.HasField('default_bytes'))
+
+ test_case.assertTrue(message.HasField('default_nested_enum'))
+ test_case.assertTrue(message.HasField('default_foreign_enum'))
+ test_case.assertTrue(message.HasField('default_import_enum'))
+
+ # ...and must hold the 4xx series of values.
+ test_case.assertEqual(401, message.default_int32)
+ test_case.assertEqual(402, message.default_int64)
+ test_case.assertEqual(403, message.default_uint32)
+ test_case.assertEqual(404, message.default_uint64)
+ test_case.assertEqual(405, message.default_sint32)
+ test_case.assertEqual(406, message.default_sint64)
+ test_case.assertEqual(407, message.default_fixed32)
+ test_case.assertEqual(408, message.default_fixed64)
+ test_case.assertEqual(409, message.default_sfixed32)
+ test_case.assertEqual(410, message.default_sfixed64)
+ test_case.assertEqual(411, message.default_float)
+ test_case.assertEqual(412, message.default_double)
+ test_case.assertEqual(False, message.default_bool)
+ test_case.assertEqual('415', message.default_string)
+ test_case.assertEqual('416', message.default_bytes)
+
+ test_case.assertEqual(unittest_pb2.TestAllTypes.FOO,
+ message.default_nested_enum)
+ test_case.assertEqual(unittest_pb2.FOREIGN_FOO,
+ message.default_foreign_enum)
+ test_case.assertEqual(unittest_import_pb2.IMPORT_FOO,
+ message.default_import_enum)
def GoldenFile(filename):
"""Finds the given golden file and returns a file object representing it."""
@@ -570,21 +569,21 @@ def SetAllPackedFields(message):
Args:
message: A unittest_pb2.TestPackedTypes instance.
"""
- message.packed_int32.extend([101, 102])
- message.packed_int64.extend([103, 104])
- message.packed_uint32.extend([105, 106])
- message.packed_uint64.extend([107, 108])
- message.packed_sint32.extend([109, 110])
- message.packed_sint64.extend([111, 112])
- message.packed_fixed32.extend([113, 114])
- message.packed_fixed64.extend([115, 116])
- message.packed_sfixed32.extend([117, 118])
- message.packed_sfixed64.extend([119, 120])
- message.packed_float.extend([121.0, 122.0])
- message.packed_double.extend([122.0, 123.0])
+ message.packed_int32.extend([601, 701])
+ message.packed_int64.extend([602, 702])
+ message.packed_uint32.extend([603, 703])
+ message.packed_uint64.extend([604, 704])
+ message.packed_sint32.extend([605, 705])
+ message.packed_sint64.extend([606, 706])
+ message.packed_fixed32.extend([607, 707])
+ message.packed_fixed64.extend([608, 708])
+ message.packed_sfixed32.extend([609, 709])
+ message.packed_sfixed64.extend([610, 710])
+ message.packed_float.extend([611.0, 711.0])
+ message.packed_double.extend([612.0, 712.0])
message.packed_bool.extend([True, False])
- message.packed_enum.extend([unittest_pb2.FOREIGN_FOO,
- unittest_pb2.FOREIGN_BAR])
+ message.packed_enum.extend([unittest_pb2.FOREIGN_BAR,
+ unittest_pb2.FOREIGN_BAZ])
def SetAllPackedExtensions(message):
@@ -596,17 +595,41 @@ def SetAllPackedExtensions(message):
extensions = message.Extensions
pb2 = unittest_pb2
- extensions[pb2.packed_int32_extension].append(101)
- extensions[pb2.packed_int64_extension].append(102)
- extensions[pb2.packed_uint32_extension].append(103)
- extensions[pb2.packed_uint64_extension].append(104)
- extensions[pb2.packed_sint32_extension].append(105)
- extensions[pb2.packed_sint64_extension].append(106)
- extensions[pb2.packed_fixed32_extension].append(107)
- extensions[pb2.packed_fixed64_extension].append(108)
- extensions[pb2.packed_sfixed32_extension].append(109)
- extensions[pb2.packed_sfixed64_extension].append(110)
- extensions[pb2.packed_float_extension].append(111.0)
- extensions[pb2.packed_double_extension].append(112.0)
- extensions[pb2.packed_bool_extension].append(True)
- extensions[pb2.packed_enum_extension].append(pb2.FOREIGN_BAZ)
+ extensions[pb2.packed_int32_extension].extend([601, 701])
+ extensions[pb2.packed_int64_extension].extend([602, 702])
+ extensions[pb2.packed_uint32_extension].extend([603, 703])
+ extensions[pb2.packed_uint64_extension].extend([604, 704])
+ extensions[pb2.packed_sint32_extension].extend([605, 705])
+ extensions[pb2.packed_sint64_extension].extend([606, 706])
+ extensions[pb2.packed_fixed32_extension].extend([607, 707])
+ extensions[pb2.packed_fixed64_extension].extend([608, 708])
+ extensions[pb2.packed_sfixed32_extension].extend([609, 709])
+ extensions[pb2.packed_sfixed64_extension].extend([610, 710])
+ extensions[pb2.packed_float_extension].extend([611.0, 711.0])
+ extensions[pb2.packed_double_extension].extend([612.0, 712.0])
+ extensions[pb2.packed_bool_extension].extend([True, False])
+ extensions[pb2.packed_enum_extension].extend([unittest_pb2.FOREIGN_BAR,
+ unittest_pb2.FOREIGN_BAZ])
+
+
+def SetAllUnpackedFields(message):
+ """Sets every field in the message to a unique value.
+
+ Args:
+ message: A unittest_pb2.TestUnpackedTypes instance.
+ """
+ message.unpacked_int32.extend([601, 701])
+ message.unpacked_int64.extend([602, 702])
+ message.unpacked_uint32.extend([603, 703])
+ message.unpacked_uint64.extend([604, 704])
+ message.unpacked_sint32.extend([605, 705])
+ message.unpacked_sint64.extend([606, 706])
+ message.unpacked_fixed32.extend([607, 707])
+ message.unpacked_fixed64.extend([608, 708])
+ message.unpacked_sfixed32.extend([609, 709])
+ message.unpacked_sfixed64.extend([610, 710])
+ message.unpacked_float.extend([611.0, 711.0])
+ message.unpacked_double.extend([612.0, 712.0])
+ message.unpacked_bool.extend([True, False])
+ message.unpacked_enum.extend([unittest_pb2.FOREIGN_BAR,
+ unittest_pb2.FOREIGN_BAZ])
diff --git a/python/google/protobuf/internal/text_format_test.py b/python/google/protobuf/internal/text_format_test.py
index 0cf27186..0208139e 100755
--- a/python/google/protobuf/internal/text_format_test.py
+++ b/python/google/protobuf/internal/text_format_test.py
@@ -43,7 +43,7 @@ from google.protobuf import unittest_pb2
from google.protobuf import unittest_mset_pb2
-class TextFormatTest(test_util.GoldenMessageTestCase):
+class TextFormatTest(unittest.TestCase):
def ReadGolden(self, golden_filename):
f = test_util.GoldenFile(golden_filename)
golden_lines = f.readlines()
@@ -149,7 +149,7 @@ class TextFormatTest(test_util.GoldenMessageTestCase):
parsed_message = unittest_pb2.TestAllTypes()
text_format.Merge(ascii_text, parsed_message)
self.assertEqual(message, parsed_message)
- self.ExpectAllFieldsSet(message)
+ test_util.ExpectAllFieldsSet(self, message)
def testMergeAllExtensions(self):
message = unittest_pb2.TestAllExtensions()
@@ -212,12 +212,18 @@ class TextFormatTest(test_util.GoldenMessageTestCase):
text_format.Merge, text, message)
def testMergeBadExtension(self):
- message = unittest_pb2.TestAllTypes()
+ message = unittest_pb2.TestAllExtensions()
text = '[unknown_extension]: 8\n'
self.assertRaisesWithMessage(
text_format.ParseError,
'1:2 : Extension "unknown_extension" not registered.',
text_format.Merge, text, message)
+ message = unittest_pb2.TestAllTypes()
+ self.assertRaisesWithMessage(
+ text_format.ParseError,
+ ('1:2 : Message type "protobuf_unittest.TestAllTypes" does not have '
+ 'extensions.'),
+ text_format.Merge, text, message)
def testMergeGroupNotClosed(self):
message = unittest_pb2.TestAllTypes()
@@ -231,6 +237,19 @@ class TextFormatTest(test_util.GoldenMessageTestCase):
text_format.ParseError, '1:16 : Expected "}".',
text_format.Merge, text, message)
+ def testMergeEmptyGroup(self):
+ message = unittest_pb2.TestAllTypes()
+ text = 'OptionalGroup: {}'
+ text_format.Merge(text, message)
+ self.assertTrue(message.HasField('optionalgroup'))
+
+ message.Clear()
+
+ message = unittest_pb2.TestAllTypes()
+ text = 'OptionalGroup: <>'
+ text_format.Merge(text, message)
+ self.assertTrue(message.HasField('optionalgroup'))
+
def testMergeBadEnumValue(self):
message = unittest_pb2.TestAllTypes()
text = 'optional_nested_enum: BARR'
diff --git a/python/google/protobuf/internal/type_checkers.py b/python/google/protobuf/internal/type_checkers.py
index a3bc57ff..2b3cd4de 100755
--- a/python/google/protobuf/internal/type_checkers.py
+++ b/python/google/protobuf/internal/type_checkers.py
@@ -192,47 +192,72 @@ TYPE_TO_BYTE_SIZE_FN = {
}
-# Maps from field type to an unbound Encoder method F, such that
-# F(encoder, field_number, value) will append the serialization
-# of a value of this type to the encoder.
-_Encoder = encoder.Encoder
-TYPE_TO_SERIALIZE_METHOD = {
- _FieldDescriptor.TYPE_DOUBLE: _Encoder.AppendDouble,
- _FieldDescriptor.TYPE_FLOAT: _Encoder.AppendFloat,
- _FieldDescriptor.TYPE_INT64: _Encoder.AppendInt64,
- _FieldDescriptor.TYPE_UINT64: _Encoder.AppendUInt64,
- _FieldDescriptor.TYPE_INT32: _Encoder.AppendInt32,
- _FieldDescriptor.TYPE_FIXED64: _Encoder.AppendFixed64,
- _FieldDescriptor.TYPE_FIXED32: _Encoder.AppendFixed32,
- _FieldDescriptor.TYPE_BOOL: _Encoder.AppendBool,
- _FieldDescriptor.TYPE_STRING: _Encoder.AppendString,
- _FieldDescriptor.TYPE_GROUP: _Encoder.AppendGroup,
- _FieldDescriptor.TYPE_MESSAGE: _Encoder.AppendMessage,
- _FieldDescriptor.TYPE_BYTES: _Encoder.AppendBytes,
- _FieldDescriptor.TYPE_UINT32: _Encoder.AppendUInt32,
- _FieldDescriptor.TYPE_ENUM: _Encoder.AppendEnum,
- _FieldDescriptor.TYPE_SFIXED32: _Encoder.AppendSFixed32,
- _FieldDescriptor.TYPE_SFIXED64: _Encoder.AppendSFixed64,
- _FieldDescriptor.TYPE_SINT32: _Encoder.AppendSInt32,
- _FieldDescriptor.TYPE_SINT64: _Encoder.AppendSInt64,
+# Maps from field types to encoder constructors.
+TYPE_TO_ENCODER = {
+ _FieldDescriptor.TYPE_DOUBLE: encoder.DoubleEncoder,
+ _FieldDescriptor.TYPE_FLOAT: encoder.FloatEncoder,
+ _FieldDescriptor.TYPE_INT64: encoder.Int64Encoder,
+ _FieldDescriptor.TYPE_UINT64: encoder.UInt64Encoder,
+ _FieldDescriptor.TYPE_INT32: encoder.Int32Encoder,
+ _FieldDescriptor.TYPE_FIXED64: encoder.Fixed64Encoder,
+ _FieldDescriptor.TYPE_FIXED32: encoder.Fixed32Encoder,
+ _FieldDescriptor.TYPE_BOOL: encoder.BoolEncoder,
+ _FieldDescriptor.TYPE_STRING: encoder.StringEncoder,
+ _FieldDescriptor.TYPE_GROUP: encoder.GroupEncoder,
+ _FieldDescriptor.TYPE_MESSAGE: encoder.MessageEncoder,
+ _FieldDescriptor.TYPE_BYTES: encoder.BytesEncoder,
+ _FieldDescriptor.TYPE_UINT32: encoder.UInt32Encoder,
+ _FieldDescriptor.TYPE_ENUM: encoder.EnumEncoder,
+ _FieldDescriptor.TYPE_SFIXED32: encoder.SFixed32Encoder,
+ _FieldDescriptor.TYPE_SFIXED64: encoder.SFixed64Encoder,
+ _FieldDescriptor.TYPE_SINT32: encoder.SInt32Encoder,
+ _FieldDescriptor.TYPE_SINT64: encoder.SInt64Encoder,
}
-TYPE_TO_NOTAG_SERIALIZE_METHOD = {
- _FieldDescriptor.TYPE_DOUBLE: _Encoder.AppendDoubleNoTag,
- _FieldDescriptor.TYPE_FLOAT: _Encoder.AppendFloatNoTag,
- _FieldDescriptor.TYPE_INT64: _Encoder.AppendInt64NoTag,
- _FieldDescriptor.TYPE_UINT64: _Encoder.AppendUInt64NoTag,
- _FieldDescriptor.TYPE_INT32: _Encoder.AppendInt32NoTag,
- _FieldDescriptor.TYPE_FIXED64: _Encoder.AppendFixed64NoTag,
- _FieldDescriptor.TYPE_FIXED32: _Encoder.AppendFixed32NoTag,
- _FieldDescriptor.TYPE_BOOL: _Encoder.AppendBoolNoTag,
- _FieldDescriptor.TYPE_UINT32: _Encoder.AppendUInt32NoTag,
- _FieldDescriptor.TYPE_ENUM: _Encoder.AppendEnumNoTag,
- _FieldDescriptor.TYPE_SFIXED32: _Encoder.AppendSFixed32NoTag,
- _FieldDescriptor.TYPE_SFIXED64: _Encoder.AppendSFixed64NoTag,
- _FieldDescriptor.TYPE_SINT32: _Encoder.AppendSInt32NoTag,
- _FieldDescriptor.TYPE_SINT64: _Encoder.AppendSInt64NoTag,
+# Maps from field types to sizer constructors.
+TYPE_TO_SIZER = {
+ _FieldDescriptor.TYPE_DOUBLE: encoder.DoubleSizer,
+ _FieldDescriptor.TYPE_FLOAT: encoder.FloatSizer,
+ _FieldDescriptor.TYPE_INT64: encoder.Int64Sizer,
+ _FieldDescriptor.TYPE_UINT64: encoder.UInt64Sizer,
+ _FieldDescriptor.TYPE_INT32: encoder.Int32Sizer,
+ _FieldDescriptor.TYPE_FIXED64: encoder.Fixed64Sizer,
+ _FieldDescriptor.TYPE_FIXED32: encoder.Fixed32Sizer,
+ _FieldDescriptor.TYPE_BOOL: encoder.BoolSizer,
+ _FieldDescriptor.TYPE_STRING: encoder.StringSizer,
+ _FieldDescriptor.TYPE_GROUP: encoder.GroupSizer,
+ _FieldDescriptor.TYPE_MESSAGE: encoder.MessageSizer,
+ _FieldDescriptor.TYPE_BYTES: encoder.BytesSizer,
+ _FieldDescriptor.TYPE_UINT32: encoder.UInt32Sizer,
+ _FieldDescriptor.TYPE_ENUM: encoder.EnumSizer,
+ _FieldDescriptor.TYPE_SFIXED32: encoder.SFixed32Sizer,
+ _FieldDescriptor.TYPE_SFIXED64: encoder.SFixed64Sizer,
+ _FieldDescriptor.TYPE_SINT32: encoder.SInt32Sizer,
+ _FieldDescriptor.TYPE_SINT64: encoder.SInt64Sizer,
+ }
+
+
+# Maps from field type to a decoder constructor.
+TYPE_TO_DECODER = {
+ _FieldDescriptor.TYPE_DOUBLE: decoder.DoubleDecoder,
+ _FieldDescriptor.TYPE_FLOAT: decoder.FloatDecoder,
+ _FieldDescriptor.TYPE_INT64: decoder.Int64Decoder,
+ _FieldDescriptor.TYPE_UINT64: decoder.UInt64Decoder,
+ _FieldDescriptor.TYPE_INT32: decoder.Int32Decoder,
+ _FieldDescriptor.TYPE_FIXED64: decoder.Fixed64Decoder,
+ _FieldDescriptor.TYPE_FIXED32: decoder.Fixed32Decoder,
+ _FieldDescriptor.TYPE_BOOL: decoder.BoolDecoder,
+ _FieldDescriptor.TYPE_STRING: decoder.StringDecoder,
+ _FieldDescriptor.TYPE_GROUP: decoder.GroupDecoder,
+ _FieldDescriptor.TYPE_MESSAGE: decoder.MessageDecoder,
+ _FieldDescriptor.TYPE_BYTES: decoder.BytesDecoder,
+ _FieldDescriptor.TYPE_UINT32: decoder.UInt32Decoder,
+ _FieldDescriptor.TYPE_ENUM: decoder.EnumDecoder,
+ _FieldDescriptor.TYPE_SFIXED32: decoder.SFixed32Decoder,
+ _FieldDescriptor.TYPE_SFIXED64: decoder.SFixed64Decoder,
+ _FieldDescriptor.TYPE_SINT32: decoder.SInt32Decoder,
+ _FieldDescriptor.TYPE_SINT64: decoder.SInt64Decoder,
}
# Maps from field type to expected wiretype.
@@ -259,29 +284,3 @@ FIELD_TYPE_TO_WIRE_TYPE = {
_FieldDescriptor.TYPE_SINT32: wire_format.WIRETYPE_VARINT,
_FieldDescriptor.TYPE_SINT64: wire_format.WIRETYPE_VARINT,
}
-
-
-# Maps from field type to an unbound Decoder method F,
-# such that F(decoder) will read a field of the requested type.
-#
-# Note that Message and Group are intentionally missing here.
-# They're handled by _RecursivelyMerge().
-_Decoder = decoder.Decoder
-TYPE_TO_DESERIALIZE_METHOD = {
- _FieldDescriptor.TYPE_DOUBLE: _Decoder.ReadDouble,
- _FieldDescriptor.TYPE_FLOAT: _Decoder.ReadFloat,
- _FieldDescriptor.TYPE_INT64: _Decoder.ReadInt64,
- _FieldDescriptor.TYPE_UINT64: _Decoder.ReadUInt64,
- _FieldDescriptor.TYPE_INT32: _Decoder.ReadInt32,
- _FieldDescriptor.TYPE_FIXED64: _Decoder.ReadFixed64,
- _FieldDescriptor.TYPE_FIXED32: _Decoder.ReadFixed32,
- _FieldDescriptor.TYPE_BOOL: _Decoder.ReadBool,
- _FieldDescriptor.TYPE_STRING: _Decoder.ReadString,
- _FieldDescriptor.TYPE_BYTES: _Decoder.ReadBytes,
- _FieldDescriptor.TYPE_UINT32: _Decoder.ReadUInt32,
- _FieldDescriptor.TYPE_ENUM: _Decoder.ReadEnum,
- _FieldDescriptor.TYPE_SFIXED32: _Decoder.ReadSFixed32,
- _FieldDescriptor.TYPE_SFIXED64: _Decoder.ReadSFixed64,
- _FieldDescriptor.TYPE_SINT32: _Decoder.ReadSInt32,
- _FieldDescriptor.TYPE_SINT64: _Decoder.ReadSInt64,
- }
diff --git a/python/google/protobuf/internal/wire_format.py b/python/google/protobuf/internal/wire_format.py
index da6464de..c941fe1a 100755
--- a/python/google/protobuf/internal/wire_format.py
+++ b/python/google/protobuf/internal/wire_format.py
@@ -33,16 +33,17 @@
__author__ = 'robinson@google.com (Will Robinson)'
import struct
+from google.protobuf import descriptor
from google.protobuf import message
TAG_TYPE_BITS = 3 # Number of bits used to hold type info in a proto tag.
-_TAG_TYPE_MASK = (1 << TAG_TYPE_BITS) - 1 # 0x7
+TAG_TYPE_MASK = (1 << TAG_TYPE_BITS) - 1 # 0x7
# These numbers identify the wire type of a protocol buffer value.
# We use the least-significant TAG_TYPE_BITS bits of the varint-encoded
# tag-and-type to store one of these WIRETYPE_* constants.
-# These values must match WireType enum in //net/proto2/public/wire_format.h.
+# These values must match WireType enum in google/protobuf/wire_format.h.
WIRETYPE_VARINT = 0
WIRETYPE_FIXED64 = 1
WIRETYPE_LENGTH_DELIMITED = 2
@@ -93,7 +94,7 @@ def UnpackTag(tag):
"""The inverse of PackTag(). Given an unsigned 32-bit number,
returns a (field_number, wire_type) tuple.
"""
- return (tag >> TAG_TYPE_BITS), (tag & _TAG_TYPE_MASK)
+ return (tag >> TAG_TYPE_BITS), (tag & TAG_TYPE_MASK)
def ZigZagEncode(value):
@@ -245,3 +246,23 @@ def _VarUInt64ByteSizeNoTag(uint64):
if uint64 > UINT64_MAX:
raise message.EncodeError('Value out of range: %d' % uint64)
return 10
+
+
+NON_PACKABLE_TYPES = (
+ descriptor.FieldDescriptor.TYPE_STRING,
+ descriptor.FieldDescriptor.TYPE_GROUP,
+ descriptor.FieldDescriptor.TYPE_MESSAGE,
+ descriptor.FieldDescriptor.TYPE_BYTES
+)
+
+
+def IsTypePackable(field_type):
+ """Return true iff packable = true is valid for fields of this type.
+
+ Args:
+ field_type: a FieldDescriptor::Type value.
+
+ Returns:
+ True iff fields of this type are packable.
+ """
+ return field_type not in NON_PACKABLE_TYPES
diff --git a/python/google/protobuf/message.py b/python/google/protobuf/message.py
index 9a88bdc8..f8398474 100755
--- a/python/google/protobuf/message.py
+++ b/python/google/protobuf/message.py
@@ -99,7 +99,7 @@ class Message(object):
Args:
other_msg: Message to copy into the current one.
"""
- if self == other_msg:
+ if self is other_msg:
return
self.Clear()
self.MergeFrom(other_msg)
@@ -108,6 +108,15 @@ class Message(object):
"""Clears all data that was set in the message."""
raise NotImplementedError
+ def SetInParent(self):
+ """Mark this as present in the parent.
+
+ This normally happens automatically when you assign a field of a
+ sub-message, but sometimes you want to make the sub-message
+ present while keeping it empty. If you find yourself using this,
+ you may want to reconsider your design."""
+ raise NotImplementedError
+
def IsInitialized(self):
"""Checks if the message is initialized.
diff --git a/python/google/protobuf/reflection.py b/python/google/protobuf/reflection.py
index d65d8b67..5b238031 100755
--- a/python/google/protobuf/reflection.py
+++ b/python/google/protobuf/reflection.py
@@ -50,9 +50,13 @@ this file*.
__author__ = 'robinson@google.com (Will Robinson)'
-import heapq
-import threading
+try:
+ from cStringIO import StringIO
+except ImportError:
+ from StringIO import StringIO
+import struct
import weakref
+
# We use "as" to avoid name collisions with variables.
from google.protobuf.internal import containers
from google.protobuf.internal import decoder
@@ -139,14 +143,26 @@ class GeneratedProtocolMessageType(type):
type.
"""
descriptor = dictionary[GeneratedProtocolMessageType._DESCRIPTOR_KEY]
+
+ cls._decoders_by_tag = {}
+ cls._extensions_by_name = {}
+ cls._extensions_by_number = {}
+ if (descriptor.has_options and
+ descriptor.GetOptions().message_set_wire_format):
+ cls._decoders_by_tag[decoder.MESSAGE_SET_ITEM_TAG] = (
+ decoder.MessageSetItemDecoder(cls._extensions_by_number))
+
# We act as a "friend" class of the descriptor, setting
# its _concrete_class attribute the first time we use a
# given descriptor to initialize a concrete protocol message
- # class.
+ # class. We also attach stuff to each FieldDescriptor for quick
+ # lookup later on.
concrete_class_attr_name = '_concrete_class'
if not hasattr(descriptor, concrete_class_attr_name):
setattr(descriptor, concrete_class_attr_name, cls)
- cls._known_extensions = []
+ for field in descriptor.fields:
+ _AttachFieldHelpers(cls, field)
+
_AddEnumValues(descriptor, cls)
_AddInitMethod(descriptor, cls)
_AddPropertiesForFields(descriptor, cls)
@@ -184,30 +200,33 @@ def _PropertyName(proto_field_name):
# return proto_field_name + "_"
# return proto_field_name
# """
+ # Kenton says: The above is a BAD IDEA. People rely on being able to use
+ # getattr() and setattr() to reflectively manipulate field values. If we
+ # rename the properties, then every such user has to also make sure to apply
+ # the same transformation. Note that currently if you name a field "yield",
+ # you can still access it just fine using getattr/setattr -- it's not even
+ # that cumbersome to do so.
+ # TODO(kenton): Remove this method entirely if/when everyone agrees with my
+ # position.
return proto_field_name
-def _ValueFieldName(proto_field_name):
- """Returns the name of the (internal) instance attribute which objects
- should use to store the current value for a given protocol message field.
-
- Args:
- proto_field_name: The protocol message field name, exactly
- as it appears (or would appear) in a .proto file.
- """
- return '_value_' + proto_field_name
+def _VerifyExtensionHandle(message, extension_handle):
+ """Verify that the given extension handle is valid."""
+ if not isinstance(extension_handle, _FieldDescriptor):
+ raise KeyError('HasExtension() expects an extension handle, got: %s' %
+ extension_handle)
-def _HasFieldName(proto_field_name):
- """Returns the name of the (internal) instance attribute which
- objects should use to store a boolean telling whether this field
- is explicitly set or not.
+ if not extension_handle.is_extension:
+ raise KeyError('"%s" is not an extension.' % extension_handle.full_name)
- Args:
- proto_field_name: The protocol message field name, exactly
- as it appears (or would appear) in a .proto file.
- """
- return '_has_' + proto_field_name
+ if extension_handle.containing_type is not message.DESCRIPTOR:
+ raise KeyError('Extension "%s" extends message type "%s", but this '
+ 'message is of type "%s".' %
+ (extension_handle.full_name,
+ extension_handle.containing_type.full_name,
+ message.DESCRIPTOR.full_name))
def _AddSlots(message_descriptor, dictionary):
@@ -218,16 +237,57 @@ def _AddSlots(message_descriptor, dictionary):
message_descriptor: A Descriptor instance describing this message type.
dictionary: Class dictionary to which we'll add a '__slots__' entry.
"""
- field_names = [_ValueFieldName(f.name) for f in message_descriptor.fields]
- field_names.extend(_HasFieldName(f.name) for f in message_descriptor.fields
- if f.label != _FieldDescriptor.LABEL_REPEATED)
- field_names.extend(('Extensions',
- '_cached_byte_size',
- '_cached_byte_size_dirty',
- '_called_transition_to_nonempty',
- '_listener',
- '_lock', '__weakref__'))
- dictionary['__slots__'] = field_names
+ dictionary['__slots__'] = ['_cached_byte_size',
+ '_cached_byte_size_dirty',
+ '_fields',
+ '_is_present_in_parent',
+ '_listener',
+ '_listener_for_children',
+ '__weakref__']
+
+
+def _IsMessageSetExtension(field):
+ return (field.is_extension and
+ field.containing_type.has_options and
+ field.containing_type.GetOptions().message_set_wire_format and
+ field.type == _FieldDescriptor.TYPE_MESSAGE and
+ field.message_type == field.extension_scope and
+ field.label == _FieldDescriptor.LABEL_OPTIONAL)
+
+
+def _AttachFieldHelpers(cls, field_descriptor):
+ is_repeated = (field_descriptor.label == _FieldDescriptor.LABEL_REPEATED)
+ is_packed = (field_descriptor.has_options and
+ field_descriptor.GetOptions().packed)
+
+ if _IsMessageSetExtension(field_descriptor):
+ field_encoder = encoder.MessageSetItemEncoder(field_descriptor.number)
+ sizer = encoder.MessageSetItemSizer(field_descriptor.number)
+ else:
+ field_encoder = type_checkers.TYPE_TO_ENCODER[field_descriptor.type](
+ field_descriptor.number, is_repeated, is_packed)
+ sizer = type_checkers.TYPE_TO_SIZER[field_descriptor.type](
+ field_descriptor.number, is_repeated, is_packed)
+
+ field_descriptor._encoder = field_encoder
+ field_descriptor._sizer = sizer
+ field_descriptor._default_constructor = _DefaultValueConstructorForField(
+ field_descriptor)
+
+ def AddDecoder(wiretype, is_packed):
+ tag_bytes = encoder.TagBytes(field_descriptor.number, wiretype)
+ cls._decoders_by_tag[tag_bytes] = (
+ type_checkers.TYPE_TO_DECODER[field_descriptor.type](
+ field_descriptor.number, is_repeated, is_packed,
+ field_descriptor, field_descriptor._default_constructor))
+
+ AddDecoder(type_checkers.FIELD_TYPE_TO_WIRE_TYPE[field_descriptor.type],
+ False)
+
+ if is_repeated and wire_format.IsTypePackable(field_descriptor.type):
+ # To support wire compatibility of adding packed = true, add a decoder for
+ # packed values regardless of the field's options.
+ AddDecoder(wire_format.WIRETYPE_LENGTH_DELIMITED, True)
def _AddClassAttributesForNestedExtensions(descriptor, dictionary):
@@ -249,44 +309,51 @@ def _AddEnumValues(descriptor, cls):
setattr(cls, enum_value.name, enum_value.number)
-def _DefaultValueForField(message, field):
- """Returns a default value for a field.
+def _DefaultValueConstructorForField(field):
+ """Returns a function which returns a default value for a field.
Args:
+ field: FieldDescriptor object for this field.
+
+ The returned function has one argument:
message: Message instance containing this field, or a weakref proxy
of same.
- field: FieldDescriptor object for this field.
- Returns: A default value for this field. May refer back to |message|
- via a weak reference.
+ That function in turn returns a default value for this field. The default
+ value may refer back to |message| via a weak reference.
"""
- # TODO(robinson): Only the repeated fields need a reference to 'message' (so
- # that they can set the 'has' bit on the containing Message when someone
- # append()s a value). We could special-case this, and avoid an extra
- # function call on __init__() and Clear() for non-repeated fields.
-
- # TODO(robinson): Find a better place for the default value assertion in this
- # function. No need to repeat them every time the client calls Clear('foo').
- # (We should probably just assert these things once and as early as possible,
- # by tightening checking in the descriptor classes.)
+
if field.label == _FieldDescriptor.LABEL_REPEATED:
if field.default_value != []:
raise ValueError('Repeated field default value not empty list: %s' % (
field.default_value))
- listener = _Listener(message, None)
if field.cpp_type == _FieldDescriptor.CPPTYPE_MESSAGE:
# We can't look at _concrete_class yet since it might not have
# been set. (Depends on order in which we initialize the classes).
- return containers.RepeatedCompositeFieldContainer(
- listener, field.message_type)
+ message_type = field.message_type
+ def MakeRepeatedMessageDefault(message):
+ return containers.RepeatedCompositeFieldContainer(
+ message._listener_for_children, field.message_type)
+ return MakeRepeatedMessageDefault
else:
- return containers.RepeatedScalarFieldContainer(
- listener, type_checkers.GetTypeChecker(field.cpp_type, field.type))
+ type_checker = type_checkers.GetTypeChecker(field.cpp_type, field.type)
+ def MakeRepeatedScalarDefault(message):
+ return containers.RepeatedScalarFieldContainer(
+ message._listener_for_children, type_checker)
+ return MakeRepeatedScalarDefault
if field.cpp_type == _FieldDescriptor.CPPTYPE_MESSAGE:
- assert field.default_value is None
+ # _concrete_class may not yet be initialized.
+ message_type = field.message_type
+ def MakeSubMessageDefault(message):
+ result = message_type._concrete_class()
+ result._SetListener(message._listener_for_children)
+ return result
+ return MakeSubMessageDefault
- return field.default_value
+ def MakeScalarDefault(message):
+ return field.default_value
+ return MakeScalarDefault
def _AddInitMethod(message_descriptor, cls):
@@ -295,21 +362,29 @@ def _AddInitMethod(message_descriptor, cls):
def init(self, **kwargs):
self._cached_byte_size = 0
self._cached_byte_size_dirty = False
+ self._fields = {}
+ self._is_present_in_parent = False
self._listener = message_listener_mod.NullMessageListener()
- self._called_transition_to_nonempty = False
- # TODO(robinson): We should only create a lock if we really need one
- # in this class.
- self._lock = threading.Lock()
- for field in fields:
- default_value = _DefaultValueForField(self, field)
- python_field_name = _ValueFieldName(field.name)
- setattr(self, python_field_name, default_value)
- if field.label != _FieldDescriptor.LABEL_REPEATED:
- setattr(self, _HasFieldName(field.name), False)
- self.Extensions = _ExtensionDict(self, cls._known_extensions)
+ self._listener_for_children = _Listener(self)
for field_name, field_value in kwargs.iteritems():
field = _GetFieldByName(message_descriptor, field_name)
- _MergeFieldOrExtension(self, field, field_value)
+ if field is None:
+ raise TypeError("%s() got an unexpected keyword argument '%s'" %
+ (message_descriptor.name, field_name))
+ if field.label == _FieldDescriptor.LABEL_REPEATED:
+ copy = field._default_constructor(self)
+ if field.cpp_type == _FieldDescriptor.CPPTYPE_MESSAGE: # Composite
+ for val in field_value:
+ copy.add().MergeFrom(val)
+ else: # Scalar
+ copy.extend(field_value)
+ self._fields[field] = copy
+ elif field.cpp_type == _FieldDescriptor.CPPTYPE_MESSAGE:
+ copy = field._default_constructor(self)
+ copy.MergeFrom(field_value)
+ self._fields[field] = copy
+ else:
+ self._fields[field] = field_value
init.__module__ = None
init.__doc__ = None
@@ -336,6 +411,11 @@ def _AddPropertiesForFields(descriptor, cls):
for field in descriptor.fields:
_AddPropertiesForField(field, cls)
+ if descriptor.is_extendable:
+ # _ExtensionDict is just an adaptor with no state so we allocate a new one
+ # every time it is accessed.
+ cls.Extensions = property(lambda self: _ExtensionDict(self))
+
def _AddPropertiesForField(field, cls):
"""Adds a public property for a protocol message field.
@@ -377,11 +457,22 @@ def _AddPropertiesForRepeatedField(field, cls):
cls: The class we're constructing.
"""
proto_field_name = field.name
- python_field_name = _ValueFieldName(proto_field_name)
property_name = _PropertyName(proto_field_name)
def getter(self):
- return getattr(self, python_field_name)
+ field_value = self._fields.get(field)
+ if field_value is None:
+ # Construct a new object to represent this field.
+ field_value = field._default_constructor(self)
+
+ # Atomically check if another thread has preempted us and, if not, swap
+ # in the new object we just created. If someone has preempted us, we
+ # take that object and discard ours.
+ # WARNING: We are relying on setdefault() being atomic. This is true
+ # in CPython but we haven't investigated others. This warning appears
+ # in several other locations in this file.
+ field_value = self._fields.setdefault(field, field_value)
+ return field_value
getter.__module__ = None
getter.__doc__ = 'Getter for %s.' % proto_field_name
@@ -407,21 +498,21 @@ def _AddPropertiesForNonRepeatedScalarField(field, cls):
cls: The class we're constructing.
"""
proto_field_name = field.name
- python_field_name = _ValueFieldName(proto_field_name)
- has_field_name = _HasFieldName(proto_field_name)
property_name = _PropertyName(proto_field_name)
type_checker = type_checkers.GetTypeChecker(field.cpp_type, field.type)
+ default_value = field.default_value
def getter(self):
- return getattr(self, python_field_name)
+ return self._fields.get(field, default_value)
getter.__module__ = None
getter.__doc__ = 'Getter for %s.' % proto_field_name
def setter(self, new_value):
type_checker.CheckValue(new_value)
- setattr(self, has_field_name, True)
- self._MarkByteSizeDirty()
- self._MaybeCallTransitionToNonemptyCallback()
- setattr(self, python_field_name, new_value)
+ self._fields[field] = new_value
+ # Check _cached_byte_size_dirty inline to improve performance, since scalar
+ # setters are called frequently.
+ if not self._cached_byte_size_dirty:
+ self._Modified()
setter.__module__ = None
setter.__doc__ = 'Setter for %s.' % proto_field_name
@@ -444,25 +535,23 @@ def _AddPropertiesForNonRepeatedCompositeField(field, cls):
# TODO(robinson): Remove duplication with similar method
# for non-repeated scalars.
proto_field_name = field.name
- python_field_name = _ValueFieldName(proto_field_name)
- has_field_name = _HasFieldName(proto_field_name)
property_name = _PropertyName(proto_field_name)
message_type = field.message_type
def getter(self):
- # TODO(robinson): Appropriately scary note about double-checked locking.
- field_value = getattr(self, python_field_name)
+ field_value = self._fields.get(field)
if field_value is None:
- self._lock.acquire()
- try:
- field_value = getattr(self, python_field_name)
- if field_value is None:
- field_class = message_type._concrete_class
- field_value = field_class()
- field_value._SetListener(_Listener(self, has_field_name))
- setattr(self, python_field_name, field_value)
- finally:
- self._lock.release()
+ # Construct a new object to represent this field.
+ field_value = message_type._concrete_class()
+ field_value._SetListener(self._listener_for_children)
+
+ # Atomically check if another thread has preempted us and, if not, swap
+ # in the new object we just created. If someone has preempted us, we
+ # take that object and discard ours.
+ # WARNING: We are relying on setdefault() being atomic. This is true
+ # in CPython but we haven't investigated others. This warning appears
+ # in several other locations in this file.
+ field_value = self._fields.setdefault(field, field_value)
return field_value
getter.__module__ = None
getter.__doc__ = 'Getter for %s.' % proto_field_name
@@ -490,7 +579,27 @@ def _AddStaticMethods(cls):
# TODO(robinson): This probably needs to be thread-safe(?)
def RegisterExtension(extension_handle):
extension_handle.containing_type = cls.DESCRIPTOR
- cls._known_extensions.append(extension_handle)
+ _AttachFieldHelpers(cls, extension_handle)
+
+ # Try to insert our extension, failing if an extension with the same number
+ # already exists.
+ actual_handle = cls._extensions_by_number.setdefault(
+ extension_handle.number, extension_handle)
+ if actual_handle is not extension_handle:
+ raise AssertionError(
+ 'Extensions "%s" and "%s" both try to extend message type "%s" with '
+ 'field number %d.' %
+ (extension_handle.full_name, actual_handle.full_name,
+ cls.DESCRIPTOR.full_name, extension_handle.number))
+
+ cls._extensions_by_name[extension_handle.full_name] = extension_handle
+
+ handle = extension_handle # avoid line wrapping
+ if _IsMessageSetExtension(handle):
+ # MessageSet extension. Also register under type name.
+ cls._extensions_by_name[
+ extension_handle.message_type.full_name] = extension_handle
+
cls.RegisterExtension = staticmethod(RegisterExtension)
def FromString(s):
@@ -500,115 +609,107 @@ def _AddStaticMethods(cls):
cls.FromString = staticmethod(FromString)
+def _IsPresent(item):
+ """Given a (FieldDescriptor, value) tuple from _fields, return true if the
+ value should be included in the list returned by ListFields()."""
+
+ if item[0].label == _FieldDescriptor.LABEL_REPEATED:
+ return bool(item[1])
+ elif item[0].cpp_type == _FieldDescriptor.CPPTYPE_MESSAGE:
+ return item[1]._is_present_in_parent
+ else:
+ return True
+
+
def _AddListFieldsMethod(message_descriptor, cls):
"""Helper for _AddMessageMethods()."""
- # Ensure that we always list in ascending field-number order.
- # For non-extension fields, we can do the sort once, here, at import-time.
- # For extensions, we sort on each ListFields() call, though
- # we could do better if we have to.
- fields = sorted(message_descriptor.fields, key=lambda f: f.number)
- has_field_names = (_HasFieldName(f.name) for f in fields)
- value_field_names = (_ValueFieldName(f.name) for f in fields)
- triplets = zip(has_field_names, value_field_names, fields)
-
def ListFields(self):
- # We need to list all extension and non-extension fields
- # together, in sorted order by field number.
-
- # Step 0: Get an iterator over all "set" non-extension fields,
- # sorted by field number.
- # This iterator yields (field_number, field_descriptor, value) tuples.
- def SortedSetFieldsIter():
- # Note that triplets is already sorted by field number.
- for has_field_name, value_field_name, field_descriptor in triplets:
- if field_descriptor.label == _FieldDescriptor.LABEL_REPEATED:
- value = getattr(self, _ValueFieldName(field_descriptor.name))
- if len(value) > 0:
- yield (field_descriptor.number, field_descriptor, value)
- elif getattr(self, _HasFieldName(field_descriptor.name)):
- value = getattr(self, _ValueFieldName(field_descriptor.name))
- yield (field_descriptor.number, field_descriptor, value)
- sorted_fields = SortedSetFieldsIter()
-
- # Step 1: Get an iterator over all "set" extension fields,
- # sorted by field number.
- # This iterator ALSO yields (field_number, field_descriptor, value) tuples.
- # TODO(robinson): It's not necessary to repeat this with each
- # serialization call. We can do better.
- sorted_extension_fields = sorted(
- [(f.number, f, v) for f, v in self.Extensions._ListSetExtensions()])
-
- # Step 2: Create a composite iterator that merges the extension-
- # and non-extension fields, and that still yields fields in
- # sorted order.
- all_set_fields = _ImergeSorted(sorted_fields, sorted_extension_fields)
-
- # Step 3: Strip off the field numbers and return.
- return [field[1:] for field in all_set_fields]
+ all_fields = [item for item in self._fields.iteritems() if _IsPresent(item)]
+ all_fields.sort(key = lambda item: item[0].number)
+ return all_fields
cls.ListFields = ListFields
-def _AddHasFieldMethod(cls):
+
+def _AddHasFieldMethod(message_descriptor, cls):
"""Helper for _AddMessageMethods()."""
+
+ singular_fields = {}
+ for field in message_descriptor.fields:
+ if field.label != _FieldDescriptor.LABEL_REPEATED:
+ singular_fields[field.name] = field
+
def HasField(self, field_name):
try:
- return getattr(self, _HasFieldName(field_name))
- except AttributeError:
- raise ValueError('Protocol message has no "%s" field.' % field_name)
+ field = singular_fields[field_name]
+ except KeyError:
+ raise ValueError(
+ 'Protocol message has no singular "%s" field.' % field_name)
+
+ if field.cpp_type == _FieldDescriptor.CPPTYPE_MESSAGE:
+ value = self._fields.get(field)
+ return value is not None and value._is_present_in_parent
+ else:
+ return field in self._fields
cls.HasField = HasField
-def _AddClearFieldMethod(cls):
+def _AddClearFieldMethod(message_descriptor, cls):
"""Helper for _AddMessageMethods()."""
def ClearField(self, field_name):
- field = _GetFieldByName(self.DESCRIPTOR, field_name)
- proto_field_name = field.name
- python_field_name = _ValueFieldName(proto_field_name)
- has_field_name = _HasFieldName(proto_field_name)
- default_value = _DefaultValueForField(self, field)
- if field.label == _FieldDescriptor.LABEL_REPEATED:
- self._MarkByteSizeDirty()
- else:
- if field.cpp_type == _FieldDescriptor.CPPTYPE_MESSAGE:
- old_field_value = getattr(self, python_field_name)
- if old_field_value is not None:
- # Snip the old object out of the object tree.
- old_field_value._SetListener(None)
- if getattr(self, has_field_name):
- setattr(self, has_field_name, False)
- # Set dirty bit on ourself and parents only if
- # we're actually changing state.
- self._MarkByteSizeDirty()
- setattr(self, python_field_name, default_value)
+ try:
+ field = message_descriptor.fields_by_name[field_name]
+ except KeyError:
+ raise ValueError('Protocol message has no "%s" field.' % field_name)
+
+ if field in self._fields:
+ # Note: If the field is a sub-message, its listener will still point
+ # at us. That's fine, because the worst that can happen is that it
+ # will call _Modified() and invalidate our byte size. Big deal.
+ del self._fields[field]
+
+ # Always call _Modified() -- even if nothing was changed, this is
+ # a mutating method, and thus calling it should cause the field to become
+ # present in the parent message.
+ self._Modified()
+
cls.ClearField = ClearField
def _AddClearExtensionMethod(cls):
"""Helper for _AddMessageMethods()."""
def ClearExtension(self, extension_handle):
- self.Extensions._ClearExtension(extension_handle)
+ _VerifyExtensionHandle(self, extension_handle)
+
+ # Similar to ClearField(), above.
+ if extension_handle in self._fields:
+ del self._fields[extension_handle]
+ self._Modified()
cls.ClearExtension = ClearExtension
-def _AddClearMethod(cls):
+def _AddClearMethod(message_descriptor, cls):
"""Helper for _AddMessageMethods()."""
def Clear(self):
# Clear fields.
- fields = self.DESCRIPTOR.fields
- for field in fields:
- self.ClearField(field.name)
- # Clear extensions.
- extensions = self.Extensions._ListSetExtensions()
- for extension in extensions:
- self.ClearExtension(extension[0])
+ self._fields = {}
+ self._Modified()
cls.Clear = Clear
def _AddHasExtensionMethod(cls):
"""Helper for _AddMessageMethods()."""
def HasExtension(self, extension_handle):
- return self.Extensions._HasExtension(extension_handle)
+ _VerifyExtensionHandle(self, extension_handle)
+ if extension_handle.label == _FieldDescriptor.LABEL_REPEATED:
+ raise KeyError('"%s" is repeated.' % extension_handle.full_name)
+
+ if extension_handle.cpp_type == _FieldDescriptor.CPPTYPE_MESSAGE:
+ value = self._fields.get(extension_handle)
+ return value is not None and value._is_present_in_parent
+ else:
+ return extension_handle in self._fields
cls.HasExtension = HasExtension
@@ -622,26 +723,8 @@ def _AddEqualsMethod(message_descriptor, cls):
if self is other:
return True
- # Compare all fields contained directly in this message.
- for field_descriptor in message_descriptor.fields:
- label = field_descriptor.label
- property_name = _PropertyName(field_descriptor.name)
- # Non-repeated field equality requires matching "has" bits as well
- # as having an equal value.
- if label != _FieldDescriptor.LABEL_REPEATED:
- self_has = self.HasField(property_name)
- other_has = other.HasField(property_name)
- if self_has != other_has:
- return False
- if not self_has:
- # If the "has" bit for this field is False, we must stop here.
- # Otherwise we will recurse forever on recursively-defined protos.
- continue
- if getattr(self, property_name) != getattr(other, property_name):
- return False
+ return self.ListFields() == other.ListFields()
- # Compare the extensions present in both messages.
- return self.Extensions == other.Extensions
cls.__eq__ = __eq__
@@ -685,618 +768,202 @@ def _BytesForNonRepeatedElement(value, field_number, field_type):
def _AddByteSizeMethod(message_descriptor, cls):
"""Helper for _AddMessageMethods()."""
- def BytesForField(message, field, value):
- """Returns the number of bytes required to serialize a single field
- in message. The field may be repeated or not, composite or not.
-
- Args:
- message: The Message instance containing a field of the given type.
- field: A FieldDescriptor describing the field of interest.
- value: The value whose byte size we're interested in.
-
- Returns: The number of bytes required to serialize the current value
- of "field" in "message", including space for tags and any other
- necessary information.
- """
-
- if _MessageSetField(field):
- return wire_format.MessageSetItemByteSize(field.number, value)
-
- field_number, field_type = field.number, field.type
-
- # Repeated fields.
- if field.label == _FieldDescriptor.LABEL_REPEATED:
- elements = value
- else:
- elements = [value]
-
- if field.GetOptions().packed:
- content_size = _ContentBytesForPackedField(message, field, elements)
- if content_size:
- tag_size = wire_format.TagByteSize(field_number)
- length_size = wire_format.Int32ByteSizeNoTag(content_size)
- return tag_size + length_size + content_size
- else:
- return 0
- else:
- return sum(_BytesForNonRepeatedElement(element, field_number, field_type)
- for element in elements)
-
- def _ContentBytesForPackedField(self, field, value):
- """Returns the number of bytes required to serialize the actual
- content of a packed field (not including the tag or the encoding
- of the length.
-
- Args:
- self: The Message instance containing a field of the given type.
- field: A FieldDescriptor describing the field of interest.
- value: The value whose byte size we're interested in.
-
- Returns: The number of bytes required to serialize the current value
- of the packed "field" in "message", excluding space for tags and the
- length encoding.
- """
- size = sum(_BytesForNonRepeatedElement(element, field.number, field.type)
- for element in value)
- # In the packed case, there are no per element tags.
- return size - wire_format.TagByteSize(field.number) * len(value)
-
- fields = message_descriptor.fields
- has_field_names = (_HasFieldName(f.name) for f in fields)
- zipped = zip(has_field_names, fields)
-
def ByteSize(self):
if not self._cached_byte_size_dirty:
return self._cached_byte_size
size = 0
- # Hardcoded fields first.
- for has_field_name, field in zipped:
- if (field.label == _FieldDescriptor.LABEL_REPEATED
- or getattr(self, has_field_name)):
- value = getattr(self, _ValueFieldName(field.name))
- size += BytesForField(self, field, value)
- # Extensions next.
- for field, value in self.Extensions._ListSetExtensions():
- size += BytesForField(self, field, value)
+ for field_descriptor, field_value in self.ListFields():
+ size += field_descriptor._sizer(field_value)
self._cached_byte_size = size
self._cached_byte_size_dirty = False
+ self._listener_for_children.dirty = False
return size
- cls._ContentBytesForPackedField = _ContentBytesForPackedField
cls.ByteSize = ByteSize
-def _MessageSetField(field_descriptor):
- """Checks if a field should be serialized using the message set wire format.
-
- Args:
- field_descriptor: Descriptor of the field.
-
- Returns:
- True if the field should be serialized using the message set wire format,
- false otherwise.
- """
- return (field_descriptor.is_extension and
- field_descriptor.label != _FieldDescriptor.LABEL_REPEATED and
- field_descriptor.cpp_type == _FieldDescriptor.CPPTYPE_MESSAGE and
- field_descriptor.containing_type.GetOptions().message_set_wire_format)
-
-
-def _SerializeValueToEncoder(value, field_number, field_descriptor, encoder):
- """Appends the serialization of a single value to encoder.
-
- Args:
- value: Value to serialize.
- field_number: Field number of this value.
- field_descriptor: Descriptor of the field to serialize.
- encoder: encoder.Encoder object to which we should serialize this value.
- """
- if _MessageSetField(field_descriptor):
- encoder.AppendMessageSetItem(field_number, value)
- return
-
- try:
- method = type_checkers.TYPE_TO_SERIALIZE_METHOD[field_descriptor.type]
- method(encoder, field_number, value)
- except KeyError:
- raise message_mod.EncodeError('Unrecognized field type: %d' %
- field_descriptor.type)
-
-
-def _ImergeSorted(*streams):
- """Merges N sorted iterators into a single sorted iterator.
- Each element in streams must be an iterable that yields
- its elements in sorted order, and the elements contained
- in each stream must all be comparable.
-
- There may be repeated elements in the component streams or
- across the streams; the repeated elements will all be repeated
- in the merged iterator as well.
-
- I believe that the heapq module at HEAD in the Python
- sources has a method like this, but for now we roll our own.
- """
- iters = [iter(stream) for stream in streams]
- heap = []
- for index, it in enumerate(iters):
- try:
- heap.append((it.next(), index))
- except StopIteration:
- pass
- heapq.heapify(heap)
-
- while heap:
- smallest_value, idx = heap[0]
- yield smallest_value
- try:
- next_element = iters[idx].next()
- heapq.heapreplace(heap, (next_element, idx))
- except StopIteration:
- heapq.heappop(heap)
-
-
def _AddSerializeToStringMethod(message_descriptor, cls):
"""Helper for _AddMessageMethods()."""
def SerializeToString(self):
# Check if the message has all of its required fields set.
errors = []
- if not _InternalIsInitialized(self, errors):
- raise message_mod.EncodeError('\n'.join(errors))
+ if not self.IsInitialized():
+ raise message_mod.EncodeError(
+ 'Message is missing required fields: ' +
+ ','.join(self.FindInitializationErrors()))
return self.SerializePartialToString()
cls.SerializeToString = SerializeToString
def _AddSerializePartialToStringMethod(message_descriptor, cls):
"""Helper for _AddMessageMethods()."""
- Encoder = encoder.Encoder
def SerializePartialToString(self):
- encoder = Encoder()
- # We need to serialize all extension and non-extension fields
- # together, in sorted order by field number.
- for field_descriptor, field_value in self.ListFields():
- if field_descriptor.label == _FieldDescriptor.LABEL_REPEATED:
- repeated_value = field_value
- else:
- repeated_value = [field_value]
- if field_descriptor.GetOptions().packed:
- # First, write the field number and WIRETYPE_LENGTH_DELIMITED.
- field_number = field_descriptor.number
- encoder.AppendTag(field_number, wire_format.WIRETYPE_LENGTH_DELIMITED)
- # Next, write the number of bytes.
- content_bytes = self._ContentBytesForPackedField(
- field_descriptor, field_value)
- encoder.AppendInt32NoTag(content_bytes)
- # Finally, write the actual values.
- try:
- method = type_checkers.TYPE_TO_NOTAG_SERIALIZE_METHOD[
- field_descriptor.type]
- for value in repeated_value:
- method(encoder, value)
- except KeyError:
- raise message_mod.EncodeError('Unrecognized field type: %d' %
- field_descriptor.type)
- else:
- for element in repeated_value:
- _SerializeValueToEncoder(element, field_descriptor.number,
- field_descriptor, encoder)
- return encoder.ToString()
-
+ out = StringIO()
+ self._InternalSerialize(out.write)
+ return out.getvalue()
cls.SerializePartialToString = SerializePartialToString
+ def InternalSerialize(self, write_bytes):
+ for field_descriptor, field_value in self.ListFields():
+ field_descriptor._encoder(write_bytes, field_value)
+ cls._InternalSerialize = InternalSerialize
-def _WireTypeForFieldType(field_type):
- """Given a field type, returns the expected wire type."""
- try:
- return type_checkers.FIELD_TYPE_TO_WIRE_TYPE[field_type]
- except KeyError:
- raise message_mod.DecodeError('Unknown field type: %d' % field_type)
-
-
-def _WireTypeForField(field_descriptor):
- """Given a field descriptor, returns the expected wire type."""
- if field_descriptor.GetOptions().packed:
- return wire_format.WIRETYPE_LENGTH_DELIMITED
- else:
- return _WireTypeForFieldType(field_descriptor.type)
-
-
-def _RecursivelyMerge(field_number, field_type, decoder, message):
- """Decodes a message from decoder into message.
- message is either a group or a nested message within some containing
- protocol message. If it's a group, we use the group protocol to
- deserialize, and if it's a nested message, we use the nested-message
- protocol.
-
- Args:
- field_number: The field number of message in its enclosing protocol buffer.
- field_type: The field type of message. Must be either TYPE_MESSAGE
- or TYPE_GROUP.
- decoder: Decoder to read from.
- message: Message to deserialize into.
- """
- if field_type == _FieldDescriptor.TYPE_MESSAGE:
- decoder.ReadMessageInto(message)
- elif field_type == _FieldDescriptor.TYPE_GROUP:
- decoder.ReadGroupInto(field_number, message)
- else:
- raise message_mod.DecodeError('Unexpected field type: %d' % field_type)
-
-
-def _DeserializeScalarFromDecoder(field_type, decoder):
- """Deserializes a scalar of the requested type from decoder. field_type must
- be a scalar (non-group, non-message) FieldDescriptor.FIELD_* constant.
- """
- try:
- method = type_checkers.TYPE_TO_DESERIALIZE_METHOD[field_type]
- return method(decoder)
- except KeyError:
- raise message_mod.DecodeError('Unrecognized field type: %d' % field_type)
-
-
-def _SkipField(field_number, wire_type, decoder):
- """Skips a field with the specified wire type.
-
- Args:
- field_number: Tag number of the field to skip.
- wire_type: Wire type of the field to skip.
- decoder: Decoder used to deserialize the messsage. It must be positioned
- just after reading the the tag and wire type of the field.
- """
- if wire_type == wire_format.WIRETYPE_VARINT:
- decoder.ReadUInt64()
- elif wire_type == wire_format.WIRETYPE_FIXED64:
- decoder.ReadFixed64()
- elif wire_type == wire_format.WIRETYPE_LENGTH_DELIMITED:
- decoder.SkipBytes(decoder.ReadInt32())
- elif wire_type == wire_format.WIRETYPE_START_GROUP:
- _SkipGroup(field_number, decoder)
- elif wire_type == wire_format.WIRETYPE_END_GROUP:
- pass
- elif wire_type == wire_format.WIRETYPE_FIXED32:
- decoder.ReadFixed32()
- else:
- raise message_mod.DecodeError('Unexpected wire type: %d' % wire_type)
-
-
-def _SkipGroup(group_number, decoder):
- """Skips a nested group from the decoder.
-
- Args:
- group_number: Tag number of the group to skip.
- decoder: Decoder used to deserialize the message. It must be positioned
- exactly at the beginning of the message that should be skipped.
- """
- while True:
- field_number, wire_type = decoder.ReadFieldNumberAndWireType()
- if (wire_type == wire_format.WIRETYPE_END_GROUP and
- field_number == group_number):
- return
- _SkipField(field_number, wire_type, decoder)
-
-
-def _DeserializeMessageSetItem(message, decoder):
- """Deserializes a message using the message set wire format.
-
- Args:
- message: Message to be parsed to.
- decoder: The decoder to be used to deserialize encoded data. Note that the
- decoder should be positioned just after reading the START_GROUP tag that
- began the messageset item.
- """
- field_number, wire_type = decoder.ReadFieldNumberAndWireType()
- if wire_type != wire_format.WIRETYPE_VARINT or field_number != 2:
- raise message_mod.DecodeError(
- 'Incorrect message set wire format. '
- 'wire_type: %d, field_number: %d' % (wire_type, field_number))
-
- type_id = decoder.ReadInt32()
- field_number, wire_type = decoder.ReadFieldNumberAndWireType()
- if wire_type != wire_format.WIRETYPE_LENGTH_DELIMITED or field_number != 3:
- raise message_mod.DecodeError(
- 'Incorrect message set wire format. '
- 'wire_type: %d, field_number: %d' % (wire_type, field_number))
-
- extension_dict = message.Extensions
- extensions_by_number = extension_dict._AllExtensionsByNumber()
- if type_id not in extensions_by_number:
- _SkipField(field_number, wire_type, decoder)
- return
-
- field_descriptor = extensions_by_number[type_id]
- value = extension_dict[field_descriptor]
- decoder.ReadMessageInto(value)
- # Read the END_GROUP tag.
- field_number, wire_type = decoder.ReadFieldNumberAndWireType()
- if wire_type != wire_format.WIRETYPE_END_GROUP or field_number != 1:
- raise message_mod.DecodeError(
- 'Incorrect message set wire format. '
- 'wire_type: %d, field_number: %d' % (wire_type, field_number))
-
-
-def _DeserializeOneEntity(message_descriptor, message, decoder):
- """Deserializes the next wire entity from decoder into message.
-
- The next wire entity is either a scalar or a nested message, an
- element in a repeated field (the wire encoding in this case is the
- same), or a packed repeated field (in this case, the entire repeated
- field is read by a single call to _DeserializeOneEntity).
-
- Args:
- message_descriptor: A Descriptor instance describing all fields
- in message.
- message: The Message instance into which we're decoding our fields.
- decoder: The Decoder we're using to deserialize encoded data.
-
- Returns: The number of bytes read from decoder during this method.
- """
- initial_position = decoder.Position()
- field_number, wire_type = decoder.ReadFieldNumberAndWireType()
- extension_dict = message.Extensions
- extensions_by_number = extension_dict._AllExtensionsByNumber()
- if field_number in message_descriptor.fields_by_number:
- # Non-extension field.
- field_descriptor = message_descriptor.fields_by_number[field_number]
- value = getattr(message, _PropertyName(field_descriptor.name))
- def nonextension_setter_fn(scalar):
- setattr(message, _PropertyName(field_descriptor.name), scalar)
- scalar_setter_fn = nonextension_setter_fn
- elif field_number in extensions_by_number:
- # Extension field.
- field_descriptor = extensions_by_number[field_number]
- value = extension_dict[field_descriptor]
- def extension_setter_fn(scalar):
- extension_dict[field_descriptor] = scalar
- scalar_setter_fn = extension_setter_fn
- elif wire_type == wire_format.WIRETYPE_END_GROUP:
- # We assume we're being parsed as the group that's ended.
- return 0
- elif (wire_type == wire_format.WIRETYPE_START_GROUP and
- field_number == 1 and
- message_descriptor.GetOptions().message_set_wire_format):
- # A Message Set item.
- _DeserializeMessageSetItem(message, decoder)
- return decoder.Position() - initial_position
- else:
- _SkipField(field_number, wire_type, decoder)
- return decoder.Position() - initial_position
-
- # If we reach this point, we've identified the field as either
- # hardcoded or extension, and set |field_descriptor|, |scalar_setter_fn|,
- # and |value| appropriately. Now actually deserialize the thing.
- #
- # field_descriptor: Describes the field we're deserializing.
- # value: The value currently stored in the field to deserialize.
- # Used only if the field is composite and/or repeated.
- # scalar_setter_fn: A function F such that F(scalar) will
- # set a nonrepeated scalar value for this field. Used only
- # if this field is a nonrepeated scalar.
-
- field_number = field_descriptor.number
- expected_wire_type = _WireTypeForField(field_descriptor)
- if wire_type != expected_wire_type:
- # Need to fill in uninterpreted_bytes. Work for the next CL.
- raise RuntimeError('TODO(robinson): Wiretype mismatches not handled.')
-
- property_name = _PropertyName(field_descriptor.name)
- label = field_descriptor.label
- field_type = field_descriptor.type
- cpp_type = field_descriptor.cpp_type
-
- # Nonrepeated scalar. Just set the field directly.
- if (label != _FieldDescriptor.LABEL_REPEATED
- and cpp_type != _FieldDescriptor.CPPTYPE_MESSAGE):
- scalar_setter_fn(_DeserializeScalarFromDecoder(field_type, decoder))
- return decoder.Position() - initial_position
-
- # Nonrepeated composite. Recursively deserialize.
- if label != _FieldDescriptor.LABEL_REPEATED:
- composite = value
- _RecursivelyMerge(field_number, field_type, decoder, composite)
- return decoder.Position() - initial_position
-
- # Now we know we're dealing with a repeated field of some kind.
- element_list = value
-
- if cpp_type != _FieldDescriptor.CPPTYPE_MESSAGE:
- # Repeated scalar.
- if not field_descriptor.GetOptions().packed:
- element_list.append(_DeserializeScalarFromDecoder(field_type, decoder))
- return decoder.Position() - initial_position
- else:
- # Packed repeated field.
- length = _DeserializeScalarFromDecoder(
- _FieldDescriptor.TYPE_INT32, decoder)
- content_start = decoder.Position()
- while decoder.Position() - content_start < length:
- element_list.append(_DeserializeScalarFromDecoder(field_type, decoder))
- return decoder.Position() - initial_position
- else:
- # Repeated composite.
- composite = element_list.add()
- _RecursivelyMerge(field_number, field_type, decoder, composite)
- return decoder.Position() - initial_position
-
-
-def _FieldOrExtensionValues(message, field_or_extension):
- """Retrieves the list of values for the specified field or extension.
-
- The target field or extension can be optional, required or repeated, but it
- must have value(s) set. The assumption is that the target field or extension
- is set (e.g. _HasFieldOrExtension holds true).
- Args:
- message: Message which contains the target field or extension.
- field_or_extension: Field or extension for which the list of values is
- required. Must be an instance of FieldDescriptor.
+def _AddMergeFromStringMethod(message_descriptor, cls):
+ """Helper for _AddMessageMethods()."""
+ def MergeFromString(self, serialized):
+ length = len(serialized)
+ try:
+ if self._InternalParse(serialized, 0, length) != length:
+ # The only reason _InternalParse would return early is if it
+ # encountered an end-group tag.
+ raise message_mod.DecodeError('Unexpected end-group tag.')
+ except IndexError:
+ raise message_mod.DecodeError('Truncated message.')
+ except struct.error, e:
+ raise message_mod.DecodeError(e)
+ return length # Return this for legacy reasons.
+ cls.MergeFromString = MergeFromString
- Returns:
- A list of values for the specified field or extension. This list will only
- contain a single element if the field is non-repeated.
- """
- if field_or_extension.is_extension:
- value = message.Extensions[field_or_extension]
- else:
- value = getattr(message, _ValueFieldName(field_or_extension.name))
- if field_or_extension.label != _FieldDescriptor.LABEL_REPEATED:
- return [value]
- else:
- # In this case value is a list or repeated values.
- return value
+ local_ReadTag = decoder.ReadTag
+ local_SkipField = decoder.SkipField
+ decoders_by_tag = cls._decoders_by_tag
+
+ def InternalParse(self, buffer, pos, end):
+ self._Modified()
+ field_dict = self._fields
+ while pos != end:
+ (tag_bytes, new_pos) = local_ReadTag(buffer, pos)
+ field_decoder = decoders_by_tag.get(tag_bytes)
+ if field_decoder is None:
+ new_pos = local_SkipField(buffer, new_pos, end, tag_bytes)
+ if new_pos == -1:
+ return pos
+ pos = new_pos
+ else:
+ pos = field_decoder(buffer, new_pos, end, self, field_dict)
+ return pos
+ cls._InternalParse = InternalParse
-def _HasFieldOrExtension(message, field_or_extension):
- """Checks if a message has the specified field or extension set.
+def _AddIsInitializedMethod(message_descriptor, cls):
+ """Adds the IsInitialized and FindInitializationErrors methods to the
+ protocol message class."""
- The field or extension specified can be optional, required or repeated. If
- it is repeated, this function returns True. Otherwise it checks the has bit
- of the field or extension.
+ required_fields = [field for field in message_descriptor.fields
+ if field.label == _FieldDescriptor.LABEL_REQUIRED]
- Args:
- message: Message which contains the target field or extension.
- field_or_extension: Field or extension to check. This must be a
- FieldDescriptor instance.
+ def IsInitialized(self, errors=None):
+ """Checks if all required fields of a message are set.
- Returns:
- True if the message has a value set for the specified field or extension,
- or if the field or extension is repeated.
- """
- if field_or_extension.label == _FieldDescriptor.LABEL_REPEATED:
- return True
- if field_or_extension.is_extension:
- return message.HasExtension(field_or_extension)
- else:
- return message.HasField(field_or_extension.name)
+ Args:
+ errors: A list which, if provided, will be populated with the field
+ paths of all missing required fields.
+ Returns:
+ True iff the specified message has all required fields set.
+ """
-def _IsFieldOrExtensionInitialized(message, field, errors=None):
- """Checks if a message field or extension is initialized.
+ # Performance is critical so we avoid HasField() and ListFields().
- Args:
- message: The message which contains the field or extension.
- field: Field or extension to check. This must be a FieldDescriptor instance.
- errors: Errors will be appended to it, if set to a meaningful value.
+ for field in required_fields:
+ if (field not in self._fields or
+ (field.cpp_type == _FieldDescriptor.CPPTYPE_MESSAGE and
+ not self._fields[field]._is_present_in_parent)):
+ if errors is not None:
+ errors.extend(self.FindInitializationErrors())
+ return False
- Returns:
- True if the field/extension can be considered initialized.
- """
- # If the field is required and is not set, it isn't initialized.
- if field.label == _FieldDescriptor.LABEL_REQUIRED:
- if not _HasFieldOrExtension(message, field):
- if errors is not None:
- errors.append('Required field %s is not set.' % field.full_name)
- return False
+ for field, value in self._fields.iteritems():
+ if field.cpp_type == _FieldDescriptor.CPPTYPE_MESSAGE:
+ if field.label == _FieldDescriptor.LABEL_REPEATED:
+ for element in value:
+ if not element.IsInitialized():
+ if errors is not None:
+ errors.extend(self.FindInitializationErrors())
+ return False
+ elif value._is_present_in_parent and not value.IsInitialized():
+ if errors is not None:
+ errors.extend(self.FindInitializationErrors())
+ return False
- # If the field is optional and is not set, or if it
- # isn't a submessage then the field is initialized.
- if field.label == _FieldDescriptor.LABEL_OPTIONAL:
- if not _HasFieldOrExtension(message, field):
- return True
- if field.cpp_type != _FieldDescriptor.CPPTYPE_MESSAGE:
return True
- # The field is set and is either a single or a repeated submessage.
- messages = _FieldOrExtensionValues(message, field)
- # If all submessages in this field are initialized, the field is
- # considered initialized.
- for message in messages:
- if not _InternalIsInitialized(message, errors):
- return False
- return True
-
+ cls.IsInitialized = IsInitialized
-def _InternalIsInitialized(message, errors=None):
- """Checks if all required fields of a message are set.
-
- Args:
- message: The message to check.
- errors: If set, initialization errors will be appended to it.
+ def FindInitializationErrors(self):
+ """Finds required fields which are not initialized.
- Returns:
- True iff the specified message has all required fields set.
- """
- fields_and_extensions = []
- fields_and_extensions.extend(message.DESCRIPTOR.fields)
- fields_and_extensions.extend(
- [extension[0] for extension in message.Extensions._ListSetExtensions()])
- for field_or_extension in fields_and_extensions:
- if not _IsFieldOrExtensionInitialized(message, field_or_extension, errors):
- return False
- return True
-
-
-def _AddMergeFromStringMethod(message_descriptor, cls):
- """Helper for _AddMessageMethods()."""
- Decoder = decoder.Decoder
- def MergeFromString(self, serialized):
- decoder = Decoder(serialized)
- byte_count = 0
- while not decoder.EndOfStream():
- bytes_read = _DeserializeOneEntity(message_descriptor, self, decoder)
- if not bytes_read:
- break
- byte_count += bytes_read
- return byte_count
- cls.MergeFromString = MergeFromString
-
-
-def _AddIsInitializedMethod(cls):
- """Adds the IsInitialized method to the protocol message class."""
- cls.IsInitialized = _InternalIsInitialized
+ Returns:
+ A list of strings. Each string is a path to an uninitialized field from
+ the top-level message, e.g. "foo.bar[5].baz".
+ """
+ errors = [] # simplify things
-def _MergeFieldOrExtension(destination_msg, field, value):
- """Merges a specified message field into another message."""
- property_name = _PropertyName(field.name)
- is_extension = field.is_extension
+ for field in required_fields:
+ if not self.HasField(field.name):
+ errors.append(field.name)
- if not is_extension:
- destination = getattr(destination_msg, property_name)
- elif (field.label == _FieldDescriptor.LABEL_REPEATED or
- field.cpp_type == _FieldDescriptor.CPPTYPE_MESSAGE):
- destination = destination_msg.Extensions[field]
+ for field, value in self.ListFields():
+ if field.cpp_type == _FieldDescriptor.CPPTYPE_MESSAGE:
+ if field.is_extension:
+ name = "(%s)" % field.full_name
+ else:
+ name = field.name
- # Case 1 - a composite field.
- if field.cpp_type == _FieldDescriptor.CPPTYPE_MESSAGE:
- if field.label == _FieldDescriptor.LABEL_REPEATED:
- for v in value:
- destination.add().MergeFrom(v)
- else:
- destination.MergeFrom(value)
- return
+ if field.label == _FieldDescriptor.LABEL_REPEATED:
+ for i in xrange(len(value)):
+ element = value[i]
+ prefix = "%s[%d]." % (name, i)
+ sub_errors = element.FindInitializationErrors()
+ errors += [ prefix + error for error in sub_errors ]
+ else:
+ prefix = name + "."
+ sub_errors = value.FindInitializationErrors()
+ errors += [ prefix + error for error in sub_errors ]
- # Case 2 - a repeated field.
- if field.label == _FieldDescriptor.LABEL_REPEATED:
- for v in value:
- destination.append(v)
- return
+ return errors
- # Case 3 - a singular field.
- if is_extension:
- destination_msg.Extensions[field] = value
- else:
- setattr(destination_msg, property_name, value)
+ cls.FindInitializationErrors = FindInitializationErrors
def _AddMergeFromMethod(cls):
+ LABEL_REPEATED = _FieldDescriptor.LABEL_REPEATED
+ CPPTYPE_MESSAGE = _FieldDescriptor.CPPTYPE_MESSAGE
+
def MergeFrom(self, msg):
assert msg is not self
- for field in msg.ListFields():
- _MergeFieldOrExtension(self, field[0], field[1])
+ self._Modified()
+
+ fields = self._fields
+
+ for field, value in msg._fields.iteritems():
+ if field.label == LABEL_REPEATED or field.cpp_type == CPPTYPE_MESSAGE:
+ field_value = fields.get(field)
+ if field_value is None:
+ # Construct a new object to represent this field.
+ field_value = field._default_constructor(self)
+ fields[field] = field_value
+ field_value.MergeFrom(value)
+ else:
+ self._fields[field] = value
cls.MergeFrom = MergeFrom
def _AddMessageMethods(message_descriptor, cls):
"""Adds implementations of all Message methods to cls."""
_AddListFieldsMethod(message_descriptor, cls)
- _AddHasFieldMethod(cls)
- _AddClearFieldMethod(cls)
- _AddClearExtensionMethod(cls)
- _AddClearMethod(cls)
- _AddHasExtensionMethod(cls)
+ _AddHasFieldMethod(message_descriptor, cls)
+ _AddClearFieldMethod(message_descriptor, cls)
+ if message_descriptor.is_extendable:
+ _AddClearExtensionMethod(cls)
+ _AddHasExtensionMethod(cls)
+ _AddClearMethod(message_descriptor, cls)
_AddEqualsMethod(message_descriptor, cls)
_AddStrMethod(message_descriptor, cls)
_AddSetListenerMethod(cls)
@@ -1304,31 +971,30 @@ def _AddMessageMethods(message_descriptor, cls):
_AddSerializeToStringMethod(message_descriptor, cls)
_AddSerializePartialToStringMethod(message_descriptor, cls)
_AddMergeFromStringMethod(message_descriptor, cls)
- _AddIsInitializedMethod(cls)
+ _AddIsInitializedMethod(message_descriptor, cls)
_AddMergeFromMethod(cls)
def _AddPrivateHelperMethods(cls):
"""Adds implementation of private helper methods to cls."""
- def MaybeCallTransitionToNonemptyCallback(self):
- """Calls self._listener.TransitionToNonempty() the first time this
- method is called. On all subsequent calls, this is a no-op.
- """
- if not self._called_transition_to_nonempty:
- self._listener.TransitionToNonempty()
- self._called_transition_to_nonempty = True
- cls._MaybeCallTransitionToNonemptyCallback = (
- MaybeCallTransitionToNonemptyCallback)
-
- def MarkByteSizeDirty(self):
+ def Modified(self):
"""Sets the _cached_byte_size_dirty bit to true,
and propagates this to our listener iff this was a state change.
"""
+
+ # Note: Some callers check _cached_byte_size_dirty before calling
+ # _Modified() as an extra optimization. So, if this method is ever
+ # changed such that it does stuff even when _cached_byte_size_dirty is
+ # already true, the callers need to be updated.
if not self._cached_byte_size_dirty:
self._cached_byte_size_dirty = True
- self._listener.ByteSizeDirty()
- cls._MarkByteSizeDirty = MarkByteSizeDirty
+ self._listener_for_children.dirty = True
+ self._is_present_in_parent = True
+ self._listener.Modified()
+
+ cls._Modified = Modified
+ cls.SetInParent = Modified
class _Listener(object):
@@ -1338,22 +1004,17 @@ class _Listener(object):
In order to support semantics like:
- foo.bar.baz = 23
+ foo.bar.baz.qux = 23
assert foo.HasField('bar')
...child objects must have back references to their parents.
This helper class is at the heart of this support.
"""
- def __init__(self, parent_message, has_field_name):
+ def __init__(self, parent_message):
"""Args:
- parent_message: The message whose _MaybeCallTransitionToNonemptyCallback()
- and _MarkByteSizeDirty() methods we should call when we receive
- TransitionToNonempty() and ByteSizeDirty() messages.
- has_field_name: The name of the "has" field that we should set in
- the parent message when we receive a TransitionToNonempty message,
- or None if there's no "has" field to set. (This will be the case
- for child objects in "repeated" fields).
+ parent_message: The message whose _Modified() method we should call when
+ we receive Modified() messages.
"""
# This listener establishes a back reference from a child (contained) object
# to its parent (containing) object. We make this a weak reference to avoid
@@ -1363,36 +1024,27 @@ class _Listener(object):
self._parent_message_weakref = parent_message
else:
self._parent_message_weakref = weakref.proxy(parent_message)
- self._has_field_name = has_field_name
- def TransitionToNonempty(self):
+ # As an optimization, we also indicate directly on the listener whether
+ # or not the parent message is dirty. This way we can avoid traversing
+ # up the tree in the common case.
+ self.dirty = False
+
+ def Modified(self):
+ if self.dirty:
+ return
try:
- if self._has_field_name is not None:
- setattr(self._parent_message_weakref, self._has_field_name, True)
# Propagate the signal to our parents iff this is the first field set.
- self._parent_message_weakref._MaybeCallTransitionToNonemptyCallback()
+ self._parent_message_weakref._Modified()
except ReferenceError:
# We can get here if a client has kept a reference to a child object,
# and is now setting a field on it, but the child's parent has been
# garbage-collected. This is not an error.
pass
- def ByteSizeDirty(self):
- try:
- self._parent_message_weakref._MarkByteSizeDirty()
- except ReferenceError:
- # Same as above.
- pass
-
# TODO(robinson): Move elsewhere? This file is getting pretty ridiculous...
# TODO(robinson): Unify error handling of "unknown extension" crap.
-# TODO(robinson): There's so much similarity between the way that
-# extensions behave and the way that normal fields behave that it would
-# be really nice to unify more code. It's not immediately obvious
-# how to do this, though, and I'd rather get the full functionality
-# implemented (and, crucially, get all the tests and specs fleshed out
-# and passing), and then come back to this thorny unification problem.
# TODO(robinson): Support iteritems()-style iteration over all
# extensions with the "has" bits turned on?
class _ExtensionDict(object):
@@ -1404,250 +1056,85 @@ class _ExtensionDict(object):
FieldDescriptors.
"""
- class _ExtensionListener(object):
+ def __init__(self, extended_message):
+ """extended_message: Message instance for which we are the Extensions dict.
+ """
- """Adapts an _ExtensionDict to behave as a MessageListener."""
+ self._extended_message = extended_message
- def __init__(self, extension_dict, handle_id):
- self._extension_dict = extension_dict
- self._handle_id = handle_id
+ def __getitem__(self, extension_handle):
+ """Returns the current value of the given extension handle."""
- def TransitionToNonempty(self):
- self._extension_dict._SubmessageTransitionedToNonempty(self._handle_id)
+ _VerifyExtensionHandle(self._extended_message, extension_handle)
- def ByteSizeDirty(self):
- self._extension_dict._SubmessageByteSizeBecameDirty()
+ result = self._extended_message._fields.get(extension_handle)
+ if result is not None:
+ return result
- # TODO(robinson): Somewhere, we need to blow up if people
- # try to register two extensions with the same field number.
- # (And we need a test for this of course).
+ if extension_handle.label == _FieldDescriptor.LABEL_REPEATED:
+ result = extension_handle._default_constructor(self._extended_message)
+ elif extension_handle.cpp_type == _FieldDescriptor.CPPTYPE_MESSAGE:
+ result = extension_handle.message_type._concrete_class()
+ try:
+ result._SetListener(self._extended_message._listener_for_children)
+ except ReferenceError:
+ pass
+ else:
+ # Singular scalar -- just return the default without inserting into the
+ # dict.
+ return extension_handle.default_value
- def __init__(self, extended_message, known_extensions):
- """extended_message: Message instance for which we are the Extensions dict.
- known_extensions: Iterable of known extension handles.
- These must be FieldDescriptors.
- """
- # We keep a weak reference to extended_message, since
- # it has a reference to this instance in turn.
- self._extended_message = weakref.proxy(extended_message)
- # We make a deep copy of known_extensions to avoid any
- # thread-safety concerns, since the argument passed in
- # is the global (class-level) dict of known extensions for
- # this type of message, which could be modified at any time
- # via a RegisterExtension() call.
- #
- # This dict maps from handle id to handle (a FieldDescriptor).
- #
- # XXX
- # TODO(robinson): This isn't good enough. The client could
- # instantiate an object in module A, then afterward import
- # module B and pass the instance to B.Foo(). If B imports
- # an extender of this proto and then tries to use it, B
- # will get a KeyError, even though the extension *is* registered
- # at the time of use.
- # XXX
- self._known_extensions = dict((id(e), e) for e in known_extensions)
- # Read lock around self._values, which may be modified by multiple
- # concurrent readers in the conceptually "const" __getitem__ method.
- # So, we grab this lock in every "read-only" method to ensure
- # that concurrent read access is safe without external locking.
- self._lock = threading.Lock()
- # Maps from extension handle ID to current value of that extension.
- self._values = {}
- # Maps from extension handle ID to a boolean "has" bit, but only
- # for non-repeated extension fields.
- keys = (id for id, extension in self._known_extensions.iteritems()
- if extension.label != _FieldDescriptor.LABEL_REPEATED)
- self._has_bits = dict.fromkeys(keys, False)
-
- self._extensions_by_number = dict(
- (f.number, f) for f in self._known_extensions.itervalues())
-
- self._extensions_by_name = {}
- for extension in self._known_extensions.itervalues():
- if (extension.containing_type.GetOptions().message_set_wire_format and
- extension.type == descriptor_mod.FieldDescriptor.TYPE_MESSAGE and
- extension.message_type == extension.extension_scope and
- extension.label == descriptor_mod.FieldDescriptor.LABEL_OPTIONAL):
- extension_name = extension.message_type.full_name
- else:
- extension_name = extension.full_name
- self._extensions_by_name[extension_name] = extension
+ # Atomically check if another thread has preempted us and, if not, swap
+ # in the new object we just created. If someone has preempted us, we
+ # take that object and discard ours.
+ # WARNING: We are relying on setdefault() being atomic. This is true
+ # in CPython but we haven't investigated others. This warning appears
+ # in several other locations in this file.
+ result = self._extended_message._fields.setdefault(
+ extension_handle, result)
- def __getitem__(self, extension_handle):
- """Returns the current value of the given extension handle."""
- # We don't care as much about keeping critical sections short in the
- # extension support, since it's presumably much less of a common case.
- self._lock.acquire()
- try:
- handle_id = id(extension_handle)
- if handle_id not in self._known_extensions:
- raise KeyError('Extension not known to this class')
- if handle_id not in self._values:
- self._AddMissingHandle(extension_handle, handle_id)
- return self._values[handle_id]
- finally:
- self._lock.release()
+ return result
def __eq__(self, other):
- # We have to grab read locks since we're accessing _values
- # in a "const" method. See the comment in the constructor.
- if self is other:
- return True
- self._lock.acquire()
- try:
- other._lock.acquire()
- try:
- if self._has_bits != other._has_bits:
- return False
- # If there's a "has" bit, then only compare values where it is true.
- for k, v in self._values.iteritems():
- if self._has_bits.get(k, False) and v != other._values[k]:
- return False
- return True
- finally:
- other._lock.release()
- finally:
- self._lock.release()
+ if not isinstance(other, self.__class__):
+ return False
+
+ my_fields = self._extended_message.ListFields()
+ other_fields = other._extended_message.ListFields()
+
+ # Keep extensions only; ListFields() yields (descriptor, value) pairs.
+ my_fields = [ field for field in my_fields if field[0].is_extension ]
+ other_fields = [ field for field in other_fields if field[0].is_extension ]
+
+ return my_fields == other_fields
def __ne__(self, other):
return not self == other
# Note that this is only meaningful for non-repeated, scalar extension
- # fields. Note also that we may have to call
- # MaybeCallTransitionToNonemptyCallback() when we do successfully set a field
- # this way, to set any necssary "has" bits in the ancestors of the extended
- # message.
+ # fields. Note also that we may have to call _Modified() when we do
+ # successfully set a field this way, to set any necessary "has" bits in the
+ # ancestors of the extended message.
def __setitem__(self, extension_handle, value):
"""If extension_handle specifies a non-repeated, scalar extension
field, sets the value of that field.
"""
- handle_id = id(extension_handle)
- if handle_id not in self._known_extensions:
- raise KeyError('Extension not known to this class')
- field = extension_handle # Just shorten the name.
- if (field.label == _FieldDescriptor.LABEL_OPTIONAL
- and field.cpp_type != _FieldDescriptor.CPPTYPE_MESSAGE):
- # It's slightly wasteful to lookup the type checker each time,
- # but we expect this to be a vanishingly uncommon case anyway.
- type_checker = type_checkers.GetTypeChecker(field.cpp_type, field.type)
- type_checker.CheckValue(value)
- self._values[handle_id] = value
- self._has_bits[handle_id] = True
- self._extended_message._MarkByteSizeDirty()
- self._extended_message._MaybeCallTransitionToNonemptyCallback()
- else:
- raise TypeError('Extension is repeated and/or a composite type.')
-
- def _AddMissingHandle(self, extension_handle, handle_id):
- """Helper internal to ExtensionDict."""
- # Special handling for non-repeated message extensions, which (like
- # normal fields of this kind) are initialized lazily.
- # REQUIRES: _lock already held.
- cpp_type = extension_handle.cpp_type
- label = extension_handle.label
- if (cpp_type == _FieldDescriptor.CPPTYPE_MESSAGE
- and label != _FieldDescriptor.LABEL_REPEATED):
- self._AddMissingNonRepeatedCompositeHandle(extension_handle, handle_id)
- else:
- self._values[handle_id] = _DefaultValueForField(
- self._extended_message, extension_handle)
-
- def _AddMissingNonRepeatedCompositeHandle(self, extension_handle, handle_id):
- """Helper internal to ExtensionDict."""
- # REQUIRES: _lock already held.
- value = extension_handle.message_type._concrete_class()
- value._SetListener(_ExtensionDict._ExtensionListener(self, handle_id))
- self._values[handle_id] = value
-
- def _SubmessageTransitionedToNonempty(self, handle_id):
- """Called when a submessage with a given handle id first transitions to
- being nonempty. Called by _ExtensionListener.
- """
- assert handle_id in self._has_bits
- self._has_bits[handle_id] = True
- self._extended_message._MaybeCallTransitionToNonemptyCallback()
- def _SubmessageByteSizeBecameDirty(self):
- """Called whenever a submessage's cached byte size becomes invalid
- (goes from being "clean" to being "dirty"). Called by _ExtensionListener.
- """
- self._extended_message._MarkByteSizeDirty()
-
- # We may wish to widen the public interface of Message.Extensions
- # to expose some of this private functionality in the future.
- # For now, we make all this functionality module-private and just
- # implement what we need for serialization/deserialization,
- # HasField()/ClearField(), etc.
-
- def _HasExtension(self, extension_handle):
- """Method for internal use by this module.
- Returns true iff we "have" this extension in the sense of the
- "has" bit being set.
- """
- handle_id = id(extension_handle)
- # Note that this is different from the other checks.
- if handle_id not in self._has_bits:
- raise KeyError('Extension not known to this class, or is repeated field.')
- return self._has_bits[handle_id]
-
- # Intentionally pretty similar to ClearField() above.
- def _ClearExtension(self, extension_handle):
- """Method for internal use by this module.
- Clears the specified extension, unsetting its "has" bit.
- """
- handle_id = id(extension_handle)
- if handle_id not in self._known_extensions:
- raise KeyError('Extension not known to this class')
- default_value = _DefaultValueForField(self._extended_message,
- extension_handle)
- if extension_handle.label == _FieldDescriptor.LABEL_REPEATED:
- self._extended_message._MarkByteSizeDirty()
- else:
- cpp_type = extension_handle.cpp_type
- if cpp_type == _FieldDescriptor.CPPTYPE_MESSAGE:
- if handle_id in self._values:
- # Future modifications to this object shouldn't set any
- # "has" bits here.
- self._values[handle_id]._SetListener(None)
- if self._has_bits[handle_id]:
- self._has_bits[handle_id] = False
- self._extended_message._MarkByteSizeDirty()
- if handle_id in self._values:
- del self._values[handle_id]
-
- def _ListSetExtensions(self):
- """Method for internal use by this module.
-
- Returns an sequence of all extensions that are currently "set"
- in this extension dict. A "set" extension is a repeated extension,
- or a non-repeated extension with its "has" bit set.
-
- The returned sequence contains (field_descriptor, value) pairs,
- where value is the current value of the extension with the given
- field descriptor.
-
- The sequence values are in arbitrary order.
- """
- self._lock.acquire() # Read-only methods must lock around self._values.
- try:
- set_extensions = []
- for handle_id, value in self._values.iteritems():
- handle = self._known_extensions[handle_id]
- if (handle.label == _FieldDescriptor.LABEL_REPEATED
- or self._has_bits[handle_id]):
- set_extensions.append((handle, value))
- return set_extensions
- finally:
- self._lock.release()
-
- def _AllExtensionsByNumber(self):
- """Method for internal use by this module.
-
- Returns: A dict mapping field_number to (handle, field_descriptor),
- for *all* registered extensions for this dict.
- """
- return self._extensions_by_number
+ _VerifyExtensionHandle(self._extended_message, extension_handle)
+
+ if (extension_handle.label == _FieldDescriptor.LABEL_REPEATED or
+ extension_handle.cpp_type == _FieldDescriptor.CPPTYPE_MESSAGE):
+ raise TypeError(
+ 'Cannot assign to extension "%s" because it is a repeated or '
+ 'composite type.' % extension_handle.full_name)
+
+ # It's slightly wasteful to lookup the type checker each time,
+ # but we expect this to be a vanishingly uncommon case anyway.
+ type_checker = type_checkers.GetTypeChecker(
+ extension_handle.cpp_type, extension_handle.type)
+ type_checker.CheckValue(value)
+ self._extended_message._fields[extension_handle] = value
+ self._extended_message._Modified()
def _FindExtensionByName(self, name):
"""Tries to find a known extension with the specified name.
@@ -1658,4 +1145,4 @@ class _ExtensionDict(object):
Returns:
Extension field descriptor.
"""
- return self._extensions_by_name.get(name, None)
+ return self._extended_message._extensions_by_name.get(name, None)
diff --git a/python/google/protobuf/text_format.py b/python/google/protobuf/text_format.py
index 1cddce6c..889aa836 100755
--- a/python/google/protobuf/text_format.py
+++ b/python/google/protobuf/text_format.py
@@ -149,6 +149,10 @@ def _MergeField(tokenizer, message):
name.append(tokenizer.ConsumeIdentifier())
name = '.'.join(name)
+ if not message_descriptor.is_extendable:
+ raise tokenizer.ParseErrorPreviousToken(
+ 'Message type "%s" does not have extensions.' %
+ message_descriptor.full_name)
field = message.Extensions._FindExtensionByName(name)
if not field:
raise tokenizer.ParseErrorPreviousToken(
@@ -198,6 +202,7 @@ def _MergeField(tokenizer, message):
sub_message = message.Extensions[field]
else:
sub_message = getattr(message, field.name)
+ sub_message.SetInParent()
while not tokenizer.TryConsume(end_token):
if tokenizer.AtEnd():