about | summary | refs | log | tree | commit | diff | homepage
path: root/python/google/protobuf/internal/encoder.py
diff options
context:
space:
mode:
author kenton@google.com <kenton@google.com@630680e5-0e50-0410-840e-4b1c322b438d> 2009-12-18 02:11:36 +0000
committer kenton@google.com <kenton@google.com@630680e5-0e50-0410-840e-4b1c322b438d> 2009-12-18 02:11:36 +0000
commit fccb146e3fe437b0df1e9c50d4b8e1080ddb4bd9 (patch)
tree 9f2d9fe0267d96a54e541377ffeada3d0bff0d1d /python/google/protobuf/internal/encoder.py
parent d5cf7b55a6a1f959d1646785f63ca2b62da78079 (diff)
Massive roll-up of changes. See CHANGES.txt.
Diffstat (limited to 'python/google/protobuf/internal/encoder.py')
-rwxr-xr-x python/google/protobuf/internal/encoder.py 888
1 files changed, 647 insertions, 241 deletions
diff --git a/python/google/protobuf/internal/encoder.py b/python/google/protobuf/internal/encoder.py
index 3ec3b2b1..aa05d5b3 100755
--- a/python/google/protobuf/internal/encoder.py
+++ b/python/google/protobuf/internal/encoder.py
@@ -28,253 +28,659 @@
# (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
# OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
-"""Class for encoding protocol message primitives.
+"""Code for encoding protocol message primitives.
Contains the logic for encoding every logical protocol field type
into one of the 5 physical wire types.
+
+This code is designed to push the Python interpreter's performance to the
+limits.
+
+The basic idea is that at startup time, for every field (i.e. every
+FieldDescriptor) we construct two functions: a "sizer" and an "encoder". The
+sizer takes a value of this field's type and computes its byte size. The
+encoder takes a writer function and a value. It encodes the value into byte
+strings and invokes the writer function to write those strings. Typically the
+writer function is the write() method of a cStringIO object.
+
+We try to do as much work as possible when constructing the encoder and the
+sizer rather than when calling them. In particular:
+* We copy any needed global functions to local variables, so that we do not need
+ to do costly global table lookups at runtime.
+* Similarly, we try to do any attribute lookups at startup time if possible.
+* Every field's tag is encoded to bytes at startup, since it can't change at
+ runtime.
+* Whatever component of the field size we can compute at startup, we do.
+* We *avoid* sharing code if doing so would make the code slower and not sharing
+ does not burden us too much. For example, encoders for repeated fields do
+ not just call the encoders for singular fields in a loop because this would
+ add an extra function call overhead for every loop iteration; instead, we
+ manually inline the single-value encoder into the loop.
+* If a Python function lacks a return statement, Python actually generates
+ instructions to pop the result of the last statement off the stack, push
+ None onto the stack, and then return that. If we really don't care what
+ value is returned, then we can save two instructions by returning the
+ result of the last statement. It looks funny but it helps.
+* We assume that type and bounds checking has happened at a higher level.
"""
-__author__ = 'robinson@google.com (Will Robinson)'
+__author__ = 'kenton@google.com (Kenton Varda)'
import struct
-from google.protobuf import message
from google.protobuf.internal import wire_format
-from google.protobuf.internal import output_stream
-
-
-# Note that much of this code is ported from //net/proto/ProtocolBuffer, and
-# that the interface is strongly inspired by WireFormat from the C++ proto2
-# implementation.
-
-
-class Encoder(object):
-
- """Encodes logical protocol buffer fields to the wire format."""
-
- def __init__(self):
- self._stream = output_stream.OutputStream()
-
- def ToString(self):
- """Returns all values encoded in this object as a string."""
- return self._stream.ToString()
-
- # Append*NoTag methods. These are necessary for serializing packed
- # repeated fields. The Append*() methods call these methods to do
- # the actual serialization.
- def AppendInt32NoTag(self, value):
- """Appends a 32-bit integer to our buffer, varint-encoded."""
- self._stream.AppendVarint32(value)
-
- def AppendInt64NoTag(self, value):
- """Appends a 64-bit integer to our buffer, varint-encoded."""
- self._stream.AppendVarint64(value)
-
- def AppendUInt32NoTag(self, unsigned_value):
- """Appends an unsigned 32-bit integer to our buffer, varint-encoded."""
- self._stream.AppendVarUInt32(unsigned_value)
-
- def AppendUInt64NoTag(self, unsigned_value):
- """Appends an unsigned 64-bit integer to our buffer, varint-encoded."""
- self._stream.AppendVarUInt64(unsigned_value)
-
- def AppendSInt32NoTag(self, value):
- """Appends a 32-bit integer to our buffer, zigzag-encoded and then
- varint-encoded.
- """
- zigzag_value = wire_format.ZigZagEncode(value)
- self._stream.AppendVarUInt32(zigzag_value)
-
- def AppendSInt64NoTag(self, value):
- """Appends a 64-bit integer to our buffer, zigzag-encoded and then
- varint-encoded.
- """
- zigzag_value = wire_format.ZigZagEncode(value)
- self._stream.AppendVarUInt64(zigzag_value)
-
- def AppendFixed32NoTag(self, unsigned_value):
- """Appends an unsigned 32-bit integer to our buffer, in little-endian
- byte-order.
- """
- self._stream.AppendLittleEndian32(unsigned_value)
-
- def AppendFixed64NoTag(self, unsigned_value):
- """Appends an unsigned 64-bit integer to our buffer, in little-endian
- byte-order.
- """
- self._stream.AppendLittleEndian64(unsigned_value)
-
- def AppendSFixed32NoTag(self, value):
- """Appends a signed 32-bit integer to our buffer, in little-endian
- byte-order.
- """
- sign = (value & 0x80000000) and -1 or 0
- if value >> 32 != sign:
- raise message.EncodeError('SFixed32 out of range: %d' % value)
- self._stream.AppendLittleEndian32(value & 0xffffffff)
-
- def AppendSFixed64NoTag(self, value):
- """Appends a signed 64-bit integer to our buffer, in little-endian
- byte-order.
- """
- sign = (value & 0x8000000000000000) and -1 or 0
- if value >> 64 != sign:
- raise message.EncodeError('SFixed64 out of range: %d' % value)
- self._stream.AppendLittleEndian64(value & 0xffffffffffffffff)
-
- def AppendFloatNoTag(self, value):
- """Appends a floating-point number to our buffer."""
- self._stream.AppendRawBytes(
- struct.pack(wire_format.FORMAT_FLOAT_LITTLE_ENDIAN, value))
-
- def AppendDoubleNoTag(self, value):
- """Appends a double-precision floating-point number to our buffer."""
- self._stream.AppendRawBytes(
- struct.pack(wire_format.FORMAT_DOUBLE_LITTLE_ENDIAN, value))
-
- def AppendBoolNoTag(self, value):
- """Appends a boolean to our buffer."""
- self.AppendInt32NoTag(value)
-
- def AppendEnumNoTag(self, value):
- """Appends an enum value to our buffer."""
- self.AppendInt32NoTag(value)
-
-
- # All the Append*() methods below first append a tag+type pair to the buffer
- # before appending the specified value.
-
- def AppendInt32(self, field_number, value):
- """Appends a 32-bit integer to our buffer, varint-encoded."""
- self.AppendTag(field_number, wire_format.WIRETYPE_VARINT)
- self.AppendInt32NoTag(value)
-
- def AppendInt64(self, field_number, value):
- """Appends a 64-bit integer to our buffer, varint-encoded."""
- self.AppendTag(field_number, wire_format.WIRETYPE_VARINT)
- self.AppendInt64NoTag(value)
-
- def AppendUInt32(self, field_number, unsigned_value):
- """Appends an unsigned 32-bit integer to our buffer, varint-encoded."""
- self.AppendTag(field_number, wire_format.WIRETYPE_VARINT)
- self.AppendUInt32NoTag(unsigned_value)
-
- def AppendUInt64(self, field_number, unsigned_value):
- """Appends an unsigned 64-bit integer to our buffer, varint-encoded."""
- self.AppendTag(field_number, wire_format.WIRETYPE_VARINT)
- self.AppendUInt64NoTag(unsigned_value)
-
- def AppendSInt32(self, field_number, value):
- """Appends a 32-bit integer to our buffer, zigzag-encoded and then
- varint-encoded.
- """
- self.AppendTag(field_number, wire_format.WIRETYPE_VARINT)
- self.AppendSInt32NoTag(value)
-
- def AppendSInt64(self, field_number, value):
- """Appends a 64-bit integer to our buffer, zigzag-encoded and then
- varint-encoded.
- """
- self.AppendTag(field_number, wire_format.WIRETYPE_VARINT)
- self.AppendSInt64NoTag(value)
-
- def AppendFixed32(self, field_number, unsigned_value):
- """Appends an unsigned 32-bit integer to our buffer, in little-endian
- byte-order.
- """
- self.AppendTag(field_number, wire_format.WIRETYPE_FIXED32)
- self.AppendFixed32NoTag(unsigned_value)
-
- def AppendFixed64(self, field_number, unsigned_value):
- """Appends an unsigned 64-bit integer to our buffer, in little-endian
- byte-order.
- """
- self.AppendTag(field_number, wire_format.WIRETYPE_FIXED64)
- self.AppendFixed64NoTag(unsigned_value)
-
- def AppendSFixed32(self, field_number, value):
- """Appends a signed 32-bit integer to our buffer, in little-endian
- byte-order.
- """
- self.AppendTag(field_number, wire_format.WIRETYPE_FIXED32)
- self.AppendSFixed32NoTag(value)
-
- def AppendSFixed64(self, field_number, value):
- """Appends a signed 64-bit integer to our buffer, in little-endian
- byte-order.
- """
- self.AppendTag(field_number, wire_format.WIRETYPE_FIXED64)
- self.AppendSFixed64NoTag(value)
-
- def AppendFloat(self, field_number, value):
- """Appends a floating-point number to our buffer."""
- self.AppendTag(field_number, wire_format.WIRETYPE_FIXED32)
- self.AppendFloatNoTag(value)
-
- def AppendDouble(self, field_number, value):
- """Appends a double-precision floating-point number to our buffer."""
- self.AppendTag(field_number, wire_format.WIRETYPE_FIXED64)
- self.AppendDoubleNoTag(value)
-
- def AppendBool(self, field_number, value):
- """Appends a boolean to our buffer."""
- self.AppendInt32(field_number, value)
-
- def AppendEnum(self, field_number, value):
- """Appends an enum value to our buffer."""
- self.AppendInt32(field_number, value)
-
- def AppendString(self, field_number, value):
- """Appends a length-prefixed unicode string, encoded as UTF-8 to our buffer,
- with the length varint-encoded.
- """
- self.AppendBytes(field_number, value.encode('utf-8'))
-
- def AppendBytes(self, field_number, value):
- """Appends a length-prefixed sequence of bytes to our buffer, with the
- length varint-encoded.
- """
- self.AppendTag(field_number, wire_format.WIRETYPE_LENGTH_DELIMITED)
- self._stream.AppendVarUInt32(len(value))
- self._stream.AppendRawBytes(value)
-
- # TODO(robinson): For AppendGroup() and AppendMessage(), we'd really like to
- # avoid the extra string copy here. We can do so if we widen the Message
- # interface to be able to serialize to a stream in addition to a string. The
- # challenge when thinking ahead to the Python/C API implementation of Message
- # is finding a stream-like Python thing to which we can write raw bytes
- # from C. I'm not sure such a thing exists(?). (array.array is pretty much
- # what we want, but it's not directly exposed in the Python/C API).
-
- def AppendGroup(self, field_number, group):
- """Appends a group to our buffer.
- """
- self.AppendTag(field_number, wire_format.WIRETYPE_START_GROUP)
- self._stream.AppendRawBytes(group.SerializeToString())
- self.AppendTag(field_number, wire_format.WIRETYPE_END_GROUP)
-
- def AppendMessage(self, field_number, msg):
- """Appends a nested message to our buffer.
- """
- self.AppendTag(field_number, wire_format.WIRETYPE_LENGTH_DELIMITED)
- self._stream.AppendVarUInt32(msg.ByteSize())
- self._stream.AppendRawBytes(msg.SerializeToString())
-
- def AppendMessageSetItem(self, field_number, msg):
- """Appends an item using the message set wire format.
-
- The message set message looks like this:
- message MessageSet {
- repeated group Item = 1 {
- required int32 type_id = 2;
- required string message = 3;
- }
+
+
+def _VarintSize(value):
+ """Compute the size of a varint value."""
+ if value <= 0x7f: return 1
+ if value <= 0x3fff: return 2
+ if value <= 0x1fffff: return 3
+ if value <= 0xfffffff: return 4
+ if value <= 0x7ffffffff: return 5
+ if value <= 0x3ffffffffff: return 6
+ if value <= 0x1ffffffffffff: return 7
+ if value <= 0xffffffffffffff: return 8
+ if value <= 0x7fffffffffffffff: return 9
+ return 10
+
+
+def _SignedVarintSize(value):
+ """Compute the size of a signed varint value."""
+ if value < 0: return 10
+ if value <= 0x7f: return 1
+ if value <= 0x3fff: return 2
+ if value <= 0x1fffff: return 3
+ if value <= 0xfffffff: return 4
+ if value <= 0x7ffffffff: return 5
+ if value <= 0x3ffffffffff: return 6
+ if value <= 0x1ffffffffffff: return 7
+ if value <= 0xffffffffffffff: return 8
+ if value <= 0x7fffffffffffffff: return 9
+ return 10
+
+
+def _TagSize(field_number):
+ """Returns the number of bytes required to serialize a tag with this field
+ number."""
+ # Just pass in type 0, since the type won't affect the tag+type size.
+ return _VarintSize(wire_format.PackTag(field_number, 0))
+
+
+# --------------------------------------------------------------------
+# In this section we define some generic sizers. Each of these functions
+# takes parameters specific to a particular field type, e.g. int32 or fixed64.
+# It returns another function which in turn takes parameters specific to a
+# particular field, e.g. the field number and whether it is repeated or packed.
+# Look at the next section to see how these are used.
+
+
+def _SimpleSizer(compute_value_size):
+ """Return a constructor for a sizer which uses the function compute_value_size
+ to compute the size of each value. Typically compute_value_size is _VarintSize."""
+
+ def SpecificSizer(field_number, is_repeated, is_packed):
+ tag_size = _TagSize(field_number)
+ if is_packed:
+ local_VarintSize = _VarintSize
+ def PackedFieldSize(value):
+ result = 0
+ for element in value:
+ result += compute_value_size(element)
+ return result + local_VarintSize(result) + tag_size
+ return PackedFieldSize
+ elif is_repeated:
+ def RepeatedFieldSize(value):
+ result = tag_size * len(value)
+ for element in value:
+ result += compute_value_size(element)
+ return result
+ return RepeatedFieldSize
+ else:
+ def FieldSize(value):
+ return tag_size + compute_value_size(value)
+ return FieldSize
+
+ return SpecificSizer
+
+
+def _ModifiedSizer(compute_value_size, modify_value):
+ """Like _SimpleSizer, but modify_value is invoked on each value before it is
+ passed to compute_value_size. modify_value is typically ZigZagEncode."""
+
+ def SpecificSizer(field_number, is_repeated, is_packed):
+ tag_size = _TagSize(field_number)
+ if is_packed:
+ local_VarintSize = _VarintSize
+ def PackedFieldSize(value):
+ result = 0
+ for element in value:
+ result += compute_value_size(modify_value(element))
+ return result + local_VarintSize(result) + tag_size
+ return PackedFieldSize
+ elif is_repeated:
+ def RepeatedFieldSize(value):
+ result = tag_size * len(value)
+ for element in value:
+ result += compute_value_size(modify_value(element))
+ return result
+ return RepeatedFieldSize
+ else:
+ def FieldSize(value):
+ return tag_size + compute_value_size(modify_value(value))
+ return FieldSize
+
+ return SpecificSizer
+
+
+def _FixedSizer(value_size):
+ """Like _SimpleSizer except for a fixed-size field. The input is the size
+ of one value."""
+
+ def SpecificSizer(field_number, is_repeated, is_packed):
+ tag_size = _TagSize(field_number)
+ if is_packed:
+ local_VarintSize = _VarintSize
+ def PackedFieldSize(value):
+ result = len(value) * value_size
+ return result + local_VarintSize(result) + tag_size
+ return PackedFieldSize
+ elif is_repeated:
+ element_size = value_size + tag_size
+ def RepeatedFieldSize(value):
+ return len(value) * element_size
+ return RepeatedFieldSize
+ else:
+ field_size = value_size + tag_size
+ def FieldSize(value):
+ return field_size
+ return FieldSize
+
+ return SpecificSizer
+
+
+# ====================================================================
+# Here we declare a sizer constructor for each field type. Each "sizer
+# constructor" is a function that takes (field_number, is_repeated, is_packed)
+# as parameters and returns a sizer, which in turn takes a field value as
+# a parameter and returns its encoded size.
+
+
+Int32Sizer = Int64Sizer = EnumSizer = _SimpleSizer(_SignedVarintSize)
+
+UInt32Sizer = UInt64Sizer = _SimpleSizer(_VarintSize)
+
+SInt32Sizer = SInt64Sizer = _ModifiedSizer(
+ _SignedVarintSize, wire_format.ZigZagEncode)
+
+Fixed32Sizer = SFixed32Sizer = FloatSizer = _FixedSizer(4)
+Fixed64Sizer = SFixed64Sizer = DoubleSizer = _FixedSizer(8)
+
+BoolSizer = _FixedSizer(1)
+
+
+def StringSizer(field_number, is_repeated, is_packed):
+ """Returns a sizer for a string field."""
+
+ tag_size = _TagSize(field_number)
+ local_VarintSize = _VarintSize
+ local_len = len
+ assert not is_packed
+ if is_repeated:
+ def RepeatedFieldSize(value):
+ result = tag_size * len(value)
+ for element in value:
+ l = local_len(element.encode('utf-8'))
+ result += local_VarintSize(l) + l
+ return result
+ return RepeatedFieldSize
+ else:
+ def FieldSize(value):
+ l = local_len(value.encode('utf-8'))
+ return tag_size + local_VarintSize(l) + l
+ return FieldSize
+
+
+def BytesSizer(field_number, is_repeated, is_packed):
+ """Returns a sizer for a bytes field."""
+
+ tag_size = _TagSize(field_number)
+ local_VarintSize = _VarintSize
+ local_len = len
+ assert not is_packed
+ if is_repeated:
+ def RepeatedFieldSize(value):
+ result = tag_size * len(value)
+ for element in value:
+ l = local_len(element)
+ result += local_VarintSize(l) + l
+ return result
+ return RepeatedFieldSize
+ else:
+ def FieldSize(value):
+ l = local_len(value)
+ return tag_size + local_VarintSize(l) + l
+ return FieldSize
+
+
+def GroupSizer(field_number, is_repeated, is_packed):
+ """Returns a sizer for a group field."""
+
+ tag_size = _TagSize(field_number) * 2
+ assert not is_packed
+ if is_repeated:
+ def RepeatedFieldSize(value):
+ result = tag_size * len(value)
+ for element in value:
+ result += element.ByteSize()
+ return result
+ return RepeatedFieldSize
+ else:
+ def FieldSize(value):
+ return tag_size + value.ByteSize()
+ return FieldSize
+
+
+def MessageSizer(field_number, is_repeated, is_packed):
+ """Returns a sizer for a message field."""
+
+ tag_size = _TagSize(field_number)
+ local_VarintSize = _VarintSize
+ assert not is_packed
+ if is_repeated:
+ def RepeatedFieldSize(value):
+ result = tag_size * len(value)
+ for element in value:
+ l = element.ByteSize()
+ result += local_VarintSize(l) + l
+ return result
+ return RepeatedFieldSize
+ else:
+ def FieldSize(value):
+ l = value.ByteSize()
+ return tag_size + local_VarintSize(l) + l
+ return FieldSize
+
+
+# --------------------------------------------------------------------
+# MessageSet is special.
+
+
+def MessageSetItemSizer(field_number):
+ """Returns a sizer for extensions of MessageSet.
+
+ The message set message looks like this:
+ message MessageSet {
+ repeated group Item = 1 {
+ required int32 type_id = 2;
+ required string message = 3;
+ }
+ }
+ """
+ static_size = (_TagSize(1) * 2 + _TagSize(2) + _VarintSize(field_number) +
+ _TagSize(3))
+ local_VarintSize = _VarintSize
+
+ def FieldSize(value):
+ l = value.ByteSize()
+ return static_size + local_VarintSize(l) + l
+
+ return FieldSize
+
+
+# ====================================================================
+# Encoders!
+
+
+def _VarintEncoder():
+ """Return an encoder for a basic varint value (does not include tag)."""
+
+ local_chr = chr
+ def EncodeVarint(write, value):
+ bits = value & 0x7f
+ value >>= 7
+ while value:
+ write(local_chr(0x80|bits))
+ bits = value & 0x7f
+ value >>= 7
+ return write(local_chr(bits))
+
+ return EncodeVarint
+
+
+def _SignedVarintEncoder():
+ """Return an encoder for a basic signed varint value (does not include
+ tag)."""
+
+ local_chr = chr
+ def EncodeSignedVarint(write, value):
+ if value < 0:
+ value += (1 << 64)
+ bits = value & 0x7f
+ value >>= 7
+ while value:
+ write(local_chr(0x80|bits))
+ bits = value & 0x7f
+ value >>= 7
+ return write(local_chr(bits))
+
+ return EncodeSignedVarint
+
+
+_EncodeVarint = _VarintEncoder()
+_EncodeSignedVarint = _SignedVarintEncoder()
+
+
+def _VarintBytes(value):
+ """Encode the given integer as a varint and return the bytes. This is only
+ called at startup time so it doesn't need to be fast."""
+
+ pieces = []
+ _EncodeVarint(pieces.append, value)
+ return "".join(pieces)
+
+
+def TagBytes(field_number, wire_type):
+ """Encode the given tag and return the bytes. Only called at startup."""
+
+ return _VarintBytes(wire_format.PackTag(field_number, wire_type))
+
+# --------------------------------------------------------------------
+# As with sizers (see above), we have a number of common encoder
+# implementations.
+
+
+def _SimpleEncoder(wire_type, encode_value, compute_value_size):
+ """Return a constructor for an encoder for fields of a particular type.
+
+ Args:
+ wire_type: The field's wire type, for encoding tags.
+ encode_value: A function which encodes an individual value, e.g.
+ _EncodeVarint().
+ compute_value_size: A function which computes the size of an individual
+ value, e.g. _VarintSize().
+ """
+
+ def SpecificEncoder(field_number, is_repeated, is_packed):
+ if is_packed:
+ tag_bytes = TagBytes(field_number, wire_format.WIRETYPE_LENGTH_DELIMITED)
+ local_EncodeVarint = _EncodeVarint
+ def EncodePackedField(write, value):
+ write(tag_bytes)
+ size = 0
+ for element in value:
+ size += compute_value_size(element)
+ local_EncodeVarint(write, size)
+ for element in value:
+ encode_value(write, element)
+ return EncodePackedField
+ elif is_repeated:
+ tag_bytes = TagBytes(field_number, wire_type)
+ def EncodeRepeatedField(write, value):
+ for element in value:
+ write(tag_bytes)
+ encode_value(write, element)
+ return EncodeRepeatedField
+ else:
+ tag_bytes = TagBytes(field_number, wire_type)
+ def EncodeField(write, value):
+ write(tag_bytes)
+ return encode_value(write, value)
+ return EncodeField
+
+ return SpecificEncoder
+
+
+def _ModifiedEncoder(wire_type, encode_value, compute_value_size, modify_value):
+ """Like _SimpleEncoder but additionally invokes modify_value on every value
+ before passing it to encode_value. Usually modify_value is ZigZagEncode."""
+
+ def SpecificEncoder(field_number, is_repeated, is_packed):
+ if is_packed:
+ tag_bytes = TagBytes(field_number, wire_format.WIRETYPE_LENGTH_DELIMITED)
+ local_EncodeVarint = _EncodeVarint
+ def EncodePackedField(write, value):
+ write(tag_bytes)
+ size = 0
+ for element in value:
+ size += compute_value_size(modify_value(element))
+ local_EncodeVarint(write, size)
+ for element in value:
+ encode_value(write, modify_value(element))
+ return EncodePackedField
+ elif is_repeated:
+ tag_bytes = TagBytes(field_number, wire_type)
+ def EncodeRepeatedField(write, value):
+ for element in value:
+ write(tag_bytes)
+ encode_value(write, modify_value(element))
+ return EncodeRepeatedField
+ else:
+ tag_bytes = TagBytes(field_number, wire_type)
+ def EncodeField(write, value):
+ write(tag_bytes)
+ return encode_value(write, modify_value(value))
+ return EncodeField
+
+ return SpecificEncoder
+
+
+def _StructPackEncoder(wire_type, format):
+ """Return a constructor for an encoder for a fixed-width field.
+
+ Args:
+ wire_type: The field's wire type, for encoding tags.
+ format: The format string to pass to struct.pack().
+ """
+
+ value_size = struct.calcsize(format)
+
+ def SpecificEncoder(field_number, is_repeated, is_packed):
+ local_struct_pack = struct.pack
+ if is_packed:
+ tag_bytes = TagBytes(field_number, wire_format.WIRETYPE_LENGTH_DELIMITED)
+ local_EncodeVarint = _EncodeVarint
+ def EncodePackedField(write, value):
+ write(tag_bytes)
+ local_EncodeVarint(write, len(value) * value_size)
+ for element in value:
+ write(local_struct_pack(format, element))
+ return EncodePackedField
+ elif is_repeated:
+ tag_bytes = TagBytes(field_number, wire_type)
+ def EncodeRepeatedField(write, value):
+ for element in value:
+ write(tag_bytes)
+ write(local_struct_pack(format, element))
+ return EncodeRepeatedField
+ else:
+ tag_bytes = TagBytes(field_number, wire_type)
+ def EncodeField(write, value):
+ write(tag_bytes)
+ return write(local_struct_pack(format, value))
+ return EncodeField
+
+ return SpecificEncoder
+
+
+# ====================================================================
+# Here we declare an encoder constructor for each field type. These work
+# very similarly to sizer constructors, described earlier.
+
+
+Int32Encoder = Int64Encoder = EnumEncoder = _SimpleEncoder(
+ wire_format.WIRETYPE_VARINT, _EncodeSignedVarint, _SignedVarintSize)
+
+UInt32Encoder = UInt64Encoder = _SimpleEncoder(
+ wire_format.WIRETYPE_VARINT, _EncodeVarint, _VarintSize)
+
+SInt32Encoder = SInt64Encoder = _ModifiedEncoder(
+ wire_format.WIRETYPE_VARINT, _EncodeVarint, _VarintSize,
+ wire_format.ZigZagEncode)
+
+# Note that Python conveniently guarantees that when using the '<' prefix on
+# formats, they will also have the same size across all platforms (as opposed
+# to without the prefix, where their sizes depend on the C compiler's basic
+# type sizes).
+Fixed32Encoder = _StructPackEncoder(wire_format.WIRETYPE_FIXED32, '<I')
+Fixed64Encoder = _StructPackEncoder(wire_format.WIRETYPE_FIXED64, '<Q')
+SFixed32Encoder = _StructPackEncoder(wire_format.WIRETYPE_FIXED32, '<i')
+SFixed64Encoder = _StructPackEncoder(wire_format.WIRETYPE_FIXED64, '<q')
+FloatEncoder = _StructPackEncoder(wire_format.WIRETYPE_FIXED32, '<f')
+DoubleEncoder = _StructPackEncoder(wire_format.WIRETYPE_FIXED64, '<d')
+
+
+def BoolEncoder(field_number, is_repeated, is_packed):
+ """Returns an encoder for a boolean field."""
+
+ false_byte = chr(0)
+ true_byte = chr(1)
+ if is_packed:
+ tag_bytes = TagBytes(field_number, wire_format.WIRETYPE_LENGTH_DELIMITED)
+ local_EncodeVarint = _EncodeVarint
+ def EncodePackedField(write, value):
+ write(tag_bytes)
+ local_EncodeVarint(write, len(value))
+ for element in value:
+ if element:
+ write(true_byte)
+ else:
+ write(false_byte)
+ return EncodePackedField
+ elif is_repeated:
+ tag_bytes = TagBytes(field_number, wire_format.WIRETYPE_VARINT)
+ def EncodeRepeatedField(write, value):
+ for element in value:
+ write(tag_bytes)
+ if element:
+ write(true_byte)
+ else:
+ write(false_byte)
+ return EncodeRepeatedField
+ else:
+ tag_bytes = TagBytes(field_number, wire_format.WIRETYPE_VARINT)
+ def EncodeField(write, value):
+ write(tag_bytes)
+ if value:
+ return write(true_byte)
+ return write(false_byte)
+ return EncodeField
+
+
+def StringEncoder(field_number, is_repeated, is_packed):
+ """Returns an encoder for a string field."""
+
+ tag = TagBytes(field_number, wire_format.WIRETYPE_LENGTH_DELIMITED)
+ local_EncodeVarint = _EncodeVarint
+ local_len = len
+ assert not is_packed
+ if is_repeated:
+ def EncodeRepeatedField(write, value):
+ for element in value:
+ encoded = element.encode('utf-8')
+ write(tag)
+ local_EncodeVarint(write, local_len(encoded))
+ write(encoded)
+ return EncodeRepeatedField
+ else:
+ def EncodeField(write, value):
+ encoded = value.encode('utf-8')
+ write(tag)
+ local_EncodeVarint(write, local_len(encoded))
+ return write(encoded)
+ return EncodeField
+
+
+def BytesEncoder(field_number, is_repeated, is_packed):
+ """Returns an encoder for a bytes field."""
+
+ tag = TagBytes(field_number, wire_format.WIRETYPE_LENGTH_DELIMITED)
+ local_EncodeVarint = _EncodeVarint
+ local_len = len
+ assert not is_packed
+ if is_repeated:
+ def EncodeRepeatedField(write, value):
+ for element in value:
+ write(tag)
+ local_EncodeVarint(write, local_len(element))
+ write(element)
+ return EncodeRepeatedField
+ else:
+ def EncodeField(write, value):
+ write(tag)
+ local_EncodeVarint(write, local_len(value))
+ return write(value)
+ return EncodeField
+
+
+def GroupEncoder(field_number, is_repeated, is_packed):
+ """Returns an encoder for a group field."""
+
+ start_tag = TagBytes(field_number, wire_format.WIRETYPE_START_GROUP)
+ end_tag = TagBytes(field_number, wire_format.WIRETYPE_END_GROUP)
+ assert not is_packed
+ if is_repeated:
+ def EncodeRepeatedField(write, value):
+ for element in value:
+ write(start_tag)
+ element._InternalSerialize(write)
+ write(end_tag)
+ return EncodeRepeatedField
+ else:
+ def EncodeField(write, value):
+ write(start_tag)
+ value._InternalSerialize(write)
+ return write(end_tag)
+ return EncodeField
+
+
+def MessageEncoder(field_number, is_repeated, is_packed):
+ """Returns an encoder for a message field."""
+
+ tag = TagBytes(field_number, wire_format.WIRETYPE_LENGTH_DELIMITED)
+ local_EncodeVarint = _EncodeVarint
+ assert not is_packed
+ if is_repeated:
+ def EncodeRepeatedField(write, value):
+ for element in value:
+ write(tag)
+ local_EncodeVarint(write, element.ByteSize())
+ element._InternalSerialize(write)
+ return EncodeRepeatedField
+ else:
+ def EncodeField(write, value):
+ write(tag)
+ local_EncodeVarint(write, value.ByteSize())
+ return value._InternalSerialize(write)
+ return EncodeField
+
+
+# --------------------------------------------------------------------
+# As before, MessageSet is special.
+
+
+def MessageSetItemEncoder(field_number):
+ """Returns an encoder for extensions of MessageSet.
+
+ The message set message looks like this:
+ message MessageSet {
+ repeated group Item = 1 {
+ required int32 type_id = 2;
+ required string message = 3;
}
- """
- self.AppendTag(1, wire_format.WIRETYPE_START_GROUP)
- self.AppendInt32(2, field_number)
- self.AppendMessage(3, msg)
- self.AppendTag(1, wire_format.WIRETYPE_END_GROUP)
-
- def AppendTag(self, field_number, wire_type):
- """Appends a tag containing field number and wire type information."""
- self._stream.AppendVarUInt32(wire_format.PackTag(field_number, wire_type))
+ }
+ """
+ start_bytes = "".join([
+ TagBytes(1, wire_format.WIRETYPE_START_GROUP),
+ TagBytes(2, wire_format.WIRETYPE_VARINT),
+ _VarintBytes(field_number),
+ TagBytes(3, wire_format.WIRETYPE_LENGTH_DELIMITED)])
+ end_bytes = TagBytes(1, wire_format.WIRETYPE_END_GROUP)
+ local_EncodeVarint = _EncodeVarint
+
+ def EncodeField(write, value):
+ write(start_bytes)
+ local_EncodeVarint(write, value.ByteSize())
+ value._InternalSerialize(write)
+ return write(end_bytes)
+
+ return EncodeField