aboutsummaryrefslogtreecommitdiffhomepage
path: root/python/google/protobuf
diff options
context:
space:
mode:
authorGravatar kenton@google.com <kenton@google.com@630680e5-0e50-0410-840e-4b1c322b438d>2009-07-29 01:13:20 +0000
committerGravatar kenton@google.com <kenton@google.com@630680e5-0e50-0410-840e-4b1c322b438d>2009-07-29 01:13:20 +0000
commit80b1d62bfcea65c59e2160da71dad84b1bd19cef (patch)
tree5423b830c53174fec83a7ea01ff0877e11c1ddb6 /python/google/protobuf
parentd2fd0638c309113ccae3731a58e30419f522269a (diff)
Submit recent changes from internal branch, including "lite mode" for
C++ and Java. See CHANGES.txt for more details.
Diffstat (limited to 'python/google/protobuf')
-rwxr-xr-xpython/google/protobuf/internal/containers.py8
-rwxr-xr-xpython/google/protobuf/internal/decoder.py4
-rwxr-xr-xpython/google/protobuf/internal/decoder_test.py18
-rwxr-xr-xpython/google/protobuf/internal/encoder.py6
-rwxr-xr-xpython/google/protobuf/internal/encoder_test.py12
-rwxr-xr-xpython/google/protobuf/internal/message_test.py53
-rwxr-xr-xpython/google/protobuf/internal/reflection_test.py88
-rwxr-xr-xpython/google/protobuf/internal/test_util.py198
-rwxr-xr-xpython/google/protobuf/internal/text_format_test.py281
-rwxr-xr-xpython/google/protobuf/internal/type_checkers.py6
-rwxr-xr-xpython/google/protobuf/internal/wire_format.py2
-rwxr-xr-xpython/google/protobuf/message.py3
-rwxr-xr-xpython/google/protobuf/reflection.py66
-rwxr-xr-xpython/google/protobuf/service.py5
-rwxr-xr-xpython/google/protobuf/text_format.py527
15 files changed, 1235 insertions, 42 deletions
diff --git a/python/google/protobuf/internal/containers.py b/python/google/protobuf/internal/containers.py
index fa1e3402..d8a825df 100755
--- a/python/google/protobuf/internal/containers.py
+++ b/python/google/protobuf/internal/containers.py
@@ -112,9 +112,11 @@ class RepeatedScalarFieldContainer(BaseContainer):
return
orig_empty = len(self._values) == 0
+ new_values = []
for elem in elem_seq:
self._type_checker.CheckValue(elem)
- self._values.extend(elem_seq)
+ new_values.append(elem)
+ self._values.extend(new_values)
self._message_listener.ByteSizeDirty()
if orig_empty:
self._message_listener.TransitionToNonempty()
@@ -139,9 +141,11 @@ class RepeatedScalarFieldContainer(BaseContainer):
def __setslice__(self, start, stop, values):
"""Sets the subset of items from between the specified indices."""
+ new_values = []
for value in values:
self._type_checker.CheckValue(value)
- self._values[start:stop] = list(values)
+ new_values.append(value)
+ self._values[start:stop] = new_values
self._message_listener.ByteSizeDirty()
def __delitem__(self, key):
diff --git a/python/google/protobuf/internal/decoder.py b/python/google/protobuf/internal/decoder.py
index 0bee6101..83d6fe0c 100755
--- a/python/google/protobuf/internal/decoder.py
+++ b/python/google/protobuf/internal/decoder.py
@@ -135,12 +135,12 @@ class Decoder(object):
def ReadFloat(self):
"""Reads and returns a 4-byte floating-point number."""
serialized = self._stream.ReadBytes(4)
- return struct.unpack('f', serialized)[0]
+ return struct.unpack(wire_format.FORMAT_FLOAT_LITTLE_ENDIAN, serialized)[0]
def ReadDouble(self):
"""Reads and returns an 8-byte floating-point number."""
serialized = self._stream.ReadBytes(8)
- return struct.unpack('d', serialized)[0]
+ return struct.unpack(wire_format.FORMAT_DOUBLE_LITTLE_ENDIAN, serialized)[0]
def ReadBool(self):
"""Reads and returns a bool."""
diff --git a/python/google/protobuf/internal/decoder_test.py b/python/google/protobuf/internal/decoder_test.py
index 9bae888c..98e46472 100755
--- a/python/google/protobuf/internal/decoder_test.py
+++ b/python/google/protobuf/internal/decoder_test.py
@@ -36,12 +36,12 @@ __author__ = 'robinson@google.com (Will Robinson)'
import struct
import unittest
-from google.protobuf.internal import wire_format
-from google.protobuf.internal import encoder
from google.protobuf.internal import decoder
-import logging
+from google.protobuf.internal import encoder
from google.protobuf.internal import input_stream
+from google.protobuf.internal import wire_format
from google.protobuf import message
+import logging
import mox
@@ -110,6 +110,10 @@ class DecoderTest(unittest.TestCase):
self.mox.VerifyAll()
self.mox.ResetAll()
+ VAL = 1.125 # Perfectly representable as a float (no rounding error).
+ LITTLE_FLOAT_VAL = '\x00\x00\x90?'
+ LITTLE_DOUBLE_VAL = '\x00\x00\x00\x00\x00\x00\xf2?'
+
def testReadScalars(self):
test_string = 'I can feel myself getting sutpider.'
scalar_tests = [
@@ -125,10 +129,10 @@ class DecoderTest(unittest.TestCase):
'ReadLittleEndian32', long(0xffffffff)],
['sfixed64', decoder.Decoder.ReadSFixed64, long(-1),
'ReadLittleEndian64', 0xffffffffffffffff],
- ['float', decoder.Decoder.ReadFloat, 0.0,
- 'ReadBytes', struct.pack('f', 0.0), 4],
- ['double', decoder.Decoder.ReadDouble, 0.0,
- 'ReadBytes', struct.pack('d', 0.0), 8],
+ ['float', decoder.Decoder.ReadFloat, self.VAL,
+ 'ReadBytes', self.LITTLE_FLOAT_VAL, 4],
+ ['double', decoder.Decoder.ReadDouble, self.VAL,
+ 'ReadBytes', self.LITTLE_DOUBLE_VAL, 8],
['bool', decoder.Decoder.ReadBool, True, 'ReadVarUInt32', 1],
['enum', decoder.Decoder.ReadEnum, 23, 'ReadVarUInt32', 23],
['string', decoder.Decoder.ReadString,
diff --git a/python/google/protobuf/internal/encoder.py b/python/google/protobuf/internal/encoder.py
index eed8c8bd..3ec3b2b1 100755
--- a/python/google/protobuf/internal/encoder.py
+++ b/python/google/protobuf/internal/encoder.py
@@ -123,11 +123,13 @@ class Encoder(object):
def AppendFloatNoTag(self, value):
"""Appends a floating-point number to our buffer."""
- self._stream.AppendRawBytes(struct.pack('f', value))
+ self._stream.AppendRawBytes(
+ struct.pack(wire_format.FORMAT_FLOAT_LITTLE_ENDIAN, value))
def AppendDoubleNoTag(self, value):
"""Appends a double-precision floating-point number to our buffer."""
- self._stream.AppendRawBytes(struct.pack('d', value))
+ self._stream.AppendRawBytes(
+ struct.pack(wire_format.FORMAT_DOUBLE_LITTLE_ENDIAN, value))
def AppendBoolNoTag(self, value):
"""Appends a boolean to our buffer."""
diff --git a/python/google/protobuf/internal/encoder_test.py b/python/google/protobuf/internal/encoder_test.py
index 83a21c39..bf75ea80 100755
--- a/python/google/protobuf/internal/encoder_test.py
+++ b/python/google/protobuf/internal/encoder_test.py
@@ -123,6 +123,10 @@ class EncoderTest(unittest.TestCase):
self.mox.VerifyAll()
self.mox.ResetAll()
+ VAL = 1.125 # Perfectly representable as a float (no rounding error).
+ LITTLE_FLOAT_VAL = '\x00\x00\x90?'
+ LITTLE_DOUBLE_VAL = '\x00\x00\x00\x00\x00\x00\xf2?'
+
def testAppendScalars(self):
utf8_bytes = '\xd0\xa2\xd0\xb5\xd1\x81\xd1\x82'
utf8_string = unicode(utf8_bytes, 'utf-8')
@@ -144,9 +148,9 @@ class EncoderTest(unittest.TestCase):
['sfixed64', self.encoder.AppendSFixed64, 'AppendLittleEndian64',
wire_format.WIRETYPE_FIXED64, -1, 0xffffffffffffffff],
['float', self.encoder.AppendFloat, 'AppendRawBytes',
- wire_format.WIRETYPE_FIXED32, 0.0, struct.pack('f', 0.0)],
+ wire_format.WIRETYPE_FIXED32, self.VAL, self.LITTLE_FLOAT_VAL],
['double', self.encoder.AppendDouble, 'AppendRawBytes',
- wire_format.WIRETYPE_FIXED64, 0.0, struct.pack('d', 0.0)],
+ wire_format.WIRETYPE_FIXED64, self.VAL, self.LITTLE_DOUBLE_VAL],
['bool', self.encoder.AppendBool, 'AppendVarint32',
wire_format.WIRETYPE_VARINT, False],
['enum', self.encoder.AppendEnum, 'AppendVarint32',
@@ -185,9 +189,9 @@ class EncoderTest(unittest.TestCase):
['sfixed64', self.encoder.AppendSFixed64NoTag,
'AppendLittleEndian64', None, 0],
['float', self.encoder.AppendFloatNoTag,
- 'AppendRawBytes', None, 0.0, struct.pack('f', 0.0)],
+ 'AppendRawBytes', None, self.VAL, self.LITTLE_FLOAT_VAL],
['double', self.encoder.AppendDoubleNoTag,
- 'AppendRawBytes', None, 0.0, struct.pack('d', 0.0)],
+ 'AppendRawBytes', None, self.VAL, self.LITTLE_DOUBLE_VAL],
['bool', self.encoder.AppendBoolNoTag, 'AppendVarint32', None, 0],
['enum', self.encoder.AppendEnumNoTag, 'AppendVarint32', None, 0],
['sint32', self.encoder.AppendSInt32NoTag,
diff --git a/python/google/protobuf/internal/message_test.py b/python/google/protobuf/internal/message_test.py
new file mode 100755
index 00000000..df344cf0
--- /dev/null
+++ b/python/google/protobuf/internal/message_test.py
@@ -0,0 +1,53 @@
+#! /usr/bin/python
+#
+# Protocol Buffers - Google's data interchange format
+# Copyright 2008 Google Inc. All rights reserved.
+# http://code.google.com/p/protobuf/
+#
+# Redistribution and use in source and binary forms, with or without
+# modification, are permitted provided that the following conditions are
+# met:
+#
+# * Redistributions of source code must retain the above copyright
+# notice, this list of conditions and the following disclaimer.
+# * Redistributions in binary form must reproduce the above
+# copyright notice, this list of conditions and the following disclaimer
+# in the documentation and/or other materials provided with the
+# distribution.
+# * Neither the name of Google Inc. nor the names of its
+# contributors may be used to endorse or promote products derived from
+# this software without specific prior written permission.
+#
+# THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS
+# "AS IS" AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT
+# LIMITED TO, THE IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR
+# A PARTICULAR PURPOSE ARE DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT
+# OWNER OR CONTRIBUTORS BE LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL,
+# SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT
+# LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE,
+# DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY
+# THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT
+# (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
+# OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
+
+"""Tests python protocol buffers against the golden message."""
+
+__author__ = 'gps@google.com (Gregory P. Smith)'
+
+import unittest
+from google.protobuf import unittest_import_pb2
+from google.protobuf import unittest_pb2
+from google.protobuf.internal import test_util
+
+
+class MessageTest(test_util.GoldenMessageTestCase):
+
+ def testGoldenMessage(self):
+ golden_data = test_util.GoldenFile('golden_message').read()
+ golden_message = unittest_pb2.TestAllTypes()
+ golden_message.ParseFromString(golden_data)
+ self.ExpectAllFieldsSet(golden_message)
+
+
+if __name__ == '__main__':
+ unittest.main()
diff --git a/python/google/protobuf/internal/reflection_test.py b/python/google/protobuf/internal/reflection_test.py
index e2da769a..86101774 100755
--- a/python/google/protobuf/internal/reflection_test.py
+++ b/python/google/protobuf/internal/reflection_test.py
@@ -232,13 +232,14 @@ class ReflectionTest(unittest.TestCase):
proto.repeated_string.extend(['foo', 'bar'])
proto.repeated_string.extend([])
proto.repeated_string.append('baz')
+ proto.repeated_string.extend(str(x) for x in xrange(2))
proto.optional_int32 = 21
self.assertEqual(
[ (proto.DESCRIPTOR.fields_by_name['optional_int32' ], 21),
(proto.DESCRIPTOR.fields_by_name['repeated_int32' ], [5, 11]),
(proto.DESCRIPTOR.fields_by_name['repeated_fixed32'], [1]),
(proto.DESCRIPTOR.fields_by_name['repeated_string' ],
- ['foo', 'bar', 'baz']) ],
+ ['foo', 'bar', 'baz', '0', '1']) ],
proto.ListFields())
def testSingularListExtensions(self):
@@ -447,6 +448,10 @@ class ReflectionTest(unittest.TestCase):
self.assertEqual([25, 20, 15], proto.repeated_int32[1:4])
self.assertEqual([5, 25, 20, 15, 30], proto.repeated_int32[:])
+ # Test slice assignment with an iterator
+ proto.repeated_int32[1:4] = (i for i in xrange(3))
+ self.assertEqual([5, 0, 1, 2, 30], proto.repeated_int32)
+
# Test slice assignment.
proto.repeated_int32[1:4] = [35, 40, 45]
self.assertEqual([5, 35, 40, 45, 30], proto.repeated_int32)
@@ -1739,13 +1744,14 @@ class SerializationTest(unittest.TestCase):
self.assertEqual(2, proto2.b)
self.assertEqual(3, proto2.c)
- def testSerializedAllPackedFields(self):
+ def testSerializeAllPackedFields(self):
first_proto = unittest_pb2.TestPackedTypes()
second_proto = unittest_pb2.TestPackedTypes()
test_util.SetAllPackedFields(first_proto)
serialized = first_proto.SerializeToString()
self.assertEqual(first_proto.ByteSize(), len(serialized))
- second_proto.MergeFromString(serialized)
+ bytes_read = second_proto.MergeFromString(serialized)
+ self.assertEqual(second_proto.ByteSize(), bytes_read)
self.assertEqual(first_proto, second_proto)
def testSerializeAllPackedExtensions(self):
@@ -1753,7 +1759,8 @@ class SerializationTest(unittest.TestCase):
second_proto = unittest_pb2.TestPackedExtensions()
test_util.SetAllPackedExtensions(first_proto)
serialized = first_proto.SerializeToString()
- second_proto.MergeFromString(serialized)
+ bytes_read = second_proto.MergeFromString(serialized)
+ self.assertEqual(second_proto.ByteSize(), bytes_read)
self.assertEqual(first_proto, second_proto)
def testMergePackedFromStringWhenSomeFieldsAlreadySet(self):
@@ -1838,6 +1845,79 @@ class SerializationTest(unittest.TestCase):
self.assertEqual(unittest_pb2.REPEATED_NESTED_ENUM_EXTENSION_FIELD_NUMBER,
51)
+ def testInitKwargs(self):
+ proto = unittest_pb2.TestAllTypes(
+ optional_int32=1,
+ optional_string='foo',
+ optional_bool=True,
+ optional_bytes='bar',
+ optional_nested_message=unittest_pb2.TestAllTypes.NestedMessage(bb=1),
+ optional_foreign_message=unittest_pb2.ForeignMessage(c=1),
+ optional_nested_enum=unittest_pb2.TestAllTypes.FOO,
+ optional_foreign_enum=unittest_pb2.FOREIGN_FOO,
+ repeated_int32=[1, 2, 3])
+ self.assertTrue(proto.IsInitialized())
+ self.assertTrue(proto.HasField('optional_int32'))
+ self.assertTrue(proto.HasField('optional_string'))
+ self.assertTrue(proto.HasField('optional_bool'))
+ self.assertTrue(proto.HasField('optional_bytes'))
+ self.assertTrue(proto.HasField('optional_nested_message'))
+ self.assertTrue(proto.HasField('optional_foreign_message'))
+ self.assertTrue(proto.HasField('optional_nested_enum'))
+ self.assertTrue(proto.HasField('optional_foreign_enum'))
+ self.assertEqual(1, proto.optional_int32)
+ self.assertEqual('foo', proto.optional_string)
+ self.assertEqual(True, proto.optional_bool)
+ self.assertEqual('bar', proto.optional_bytes)
+ self.assertEqual(1, proto.optional_nested_message.bb)
+ self.assertEqual(1, proto.optional_foreign_message.c)
+ self.assertEqual(unittest_pb2.TestAllTypes.FOO,
+ proto.optional_nested_enum)
+ self.assertEqual(unittest_pb2.FOREIGN_FOO, proto.optional_foreign_enum)
+ self.assertEqual([1, 2, 3], proto.repeated_int32)
+
+ def testInitArgsUnknownFieldName(self):
+ def InitalizeEmptyMessageWithExtraKeywordArg():
+ unused_proto = unittest_pb2.TestEmptyMessage(unknown='unknown')
+ self._CheckRaises(ValueError,
+ InitalizeEmptyMessageWithExtraKeywordArg,
+ 'Protocol message has no "unknown" field.')
+
+ def testInitRequiredKwargs(self):
+ proto = unittest_pb2.TestRequired(a=1, b=1, c=1)
+ self.assertTrue(proto.IsInitialized())
+ self.assertTrue(proto.HasField('a'))
+ self.assertTrue(proto.HasField('b'))
+ self.assertTrue(proto.HasField('c'))
+ self.assertTrue(not proto.HasField('dummy2'))
+ self.assertEqual(1, proto.a)
+ self.assertEqual(1, proto.b)
+ self.assertEqual(1, proto.c)
+
+ def testInitRequiredForeignKwargs(self):
+ proto = unittest_pb2.TestRequiredForeign(
+ optional_message=unittest_pb2.TestRequired(a=1, b=1, c=1))
+ self.assertTrue(proto.IsInitialized())
+ self.assertTrue(proto.HasField('optional_message'))
+ self.assertTrue(proto.optional_message.IsInitialized())
+ self.assertTrue(proto.optional_message.HasField('a'))
+ self.assertTrue(proto.optional_message.HasField('b'))
+ self.assertTrue(proto.optional_message.HasField('c'))
+ self.assertTrue(not proto.optional_message.HasField('dummy2'))
+ self.assertEqual(unittest_pb2.TestRequired(a=1, b=1, c=1),
+ proto.optional_message)
+ self.assertEqual(1, proto.optional_message.a)
+ self.assertEqual(1, proto.optional_message.b)
+ self.assertEqual(1, proto.optional_message.c)
+
+ def testInitRepeatedKwargs(self):
+ proto = unittest_pb2.TestAllTypes(repeated_int32=[1, 2, 3])
+ self.assertTrue(proto.IsInitialized())
+ self.assertEqual(1, proto.repeated_int32[0])
+ self.assertEqual(2, proto.repeated_int32[1])
+ self.assertEqual(3, proto.repeated_int32[2])
+
+
class OptionsTest(unittest.TestCase):
def testMessageOptions(self):
diff --git a/python/google/protobuf/internal/test_util.py b/python/google/protobuf/internal/test_util.py
index 2d50bc4a..1a0da552 100755
--- a/python/google/protobuf/internal/test_util.py
+++ b/python/google/protobuf/internal/test_util.py
@@ -38,6 +38,7 @@ __author__ = 'robinson@google.com (Will Robinson)'
import os.path
+import unittest
from google.protobuf import unittest_import_pb2
from google.protobuf import unittest_pb2
@@ -351,6 +352,200 @@ def ExpectAllFieldsAndExtensionsInOrder(serialized):
if expected != serialized:
raise ValueError('Expected %r, found %r' % (expected, serialized))
+
+class GoldenMessageTestCase(unittest.TestCase):
+ """This adds methods to TestCase useful for verifying our Golden Message."""
+
+ def ExpectAllFieldsSet(self, message):
+ """Check that all fields have the correct values after Set*Fields() is called."""
+ self.assertTrue(message.HasField('optional_int32'))
+ self.assertTrue(message.HasField('optional_int64'))
+ self.assertTrue(message.HasField('optional_uint32'))
+ self.assertTrue(message.HasField('optional_uint64'))
+ self.assertTrue(message.HasField('optional_sint32'))
+ self.assertTrue(message.HasField('optional_sint64'))
+ self.assertTrue(message.HasField('optional_fixed32'))
+ self.assertTrue(message.HasField('optional_fixed64'))
+ self.assertTrue(message.HasField('optional_sfixed32'))
+ self.assertTrue(message.HasField('optional_sfixed64'))
+ self.assertTrue(message.HasField('optional_float'))
+ self.assertTrue(message.HasField('optional_double'))
+ self.assertTrue(message.HasField('optional_bool'))
+ self.assertTrue(message.HasField('optional_string'))
+ self.assertTrue(message.HasField('optional_bytes'))
+
+ self.assertTrue(message.HasField('optionalgroup'))
+ self.assertTrue(message.HasField('optional_nested_message'))
+ self.assertTrue(message.HasField('optional_foreign_message'))
+ self.assertTrue(message.HasField('optional_import_message'))
+
+ self.assertTrue(message.optionalgroup.HasField('a'))
+ self.assertTrue(message.optional_nested_message.HasField('bb'))
+ self.assertTrue(message.optional_foreign_message.HasField('c'))
+ self.assertTrue(message.optional_import_message.HasField('d'))
+
+ self.assertTrue(message.HasField('optional_nested_enum'))
+ self.assertTrue(message.HasField('optional_foreign_enum'))
+ self.assertTrue(message.HasField('optional_import_enum'))
+
+ self.assertTrue(message.HasField('optional_string_piece'))
+ self.assertTrue(message.HasField('optional_cord'))
+
+ self.assertEqual(101, message.optional_int32)
+ self.assertEqual(102, message.optional_int64)
+ self.assertEqual(103, message.optional_uint32)
+ self.assertEqual(104, message.optional_uint64)
+ self.assertEqual(105, message.optional_sint32)
+ self.assertEqual(106, message.optional_sint64)
+ self.assertEqual(107, message.optional_fixed32)
+ self.assertEqual(108, message.optional_fixed64)
+ self.assertEqual(109, message.optional_sfixed32)
+ self.assertEqual(110, message.optional_sfixed64)
+ self.assertEqual(111, message.optional_float)
+ self.assertEqual(112, message.optional_double)
+ self.assertEqual(True, message.optional_bool)
+ self.assertEqual('115', message.optional_string)
+ self.assertEqual('116', message.optional_bytes)
+
+ self.assertEqual(117, message.optionalgroup.a);
+ self.assertEqual(118, message.optional_nested_message.bb)
+ self.assertEqual(119, message.optional_foreign_message.c)
+ self.assertEqual(120, message.optional_import_message.d)
+
+ self.assertEqual(unittest_pb2.TestAllTypes.BAZ,
+ message.optional_nested_enum)
+ self.assertEqual(unittest_pb2.FOREIGN_BAZ, message.optional_foreign_enum)
+ self.assertEqual(unittest_import_pb2.IMPORT_BAZ,
+ message.optional_import_enum)
+
+ # -----------------------------------------------------------------
+
+ self.assertEqual(2, len(message.repeated_int32))
+ self.assertEqual(2, len(message.repeated_int64))
+ self.assertEqual(2, len(message.repeated_uint32))
+ self.assertEqual(2, len(message.repeated_uint64))
+ self.assertEqual(2, len(message.repeated_sint32))
+ self.assertEqual(2, len(message.repeated_sint64))
+ self.assertEqual(2, len(message.repeated_fixed32))
+ self.assertEqual(2, len(message.repeated_fixed64))
+ self.assertEqual(2, len(message.repeated_sfixed32))
+ self.assertEqual(2, len(message.repeated_sfixed64))
+ self.assertEqual(2, len(message.repeated_float))
+ self.assertEqual(2, len(message.repeated_double))
+ self.assertEqual(2, len(message.repeated_bool))
+ self.assertEqual(2, len(message.repeated_string))
+ self.assertEqual(2, len(message.repeated_bytes))
+
+ self.assertEqual(2, len(message.repeatedgroup))
+ self.assertEqual(2, len(message.repeated_nested_message))
+ self.assertEqual(2, len(message.repeated_foreign_message))
+ self.assertEqual(2, len(message.repeated_import_message))
+ self.assertEqual(2, len(message.repeated_nested_enum))
+ self.assertEqual(2, len(message.repeated_foreign_enum))
+ self.assertEqual(2, len(message.repeated_import_enum))
+
+ self.assertEqual(2, len(message.repeated_string_piece))
+ self.assertEqual(2, len(message.repeated_cord))
+
+ self.assertEqual(201, message.repeated_int32[0])
+ self.assertEqual(202, message.repeated_int64[0])
+ self.assertEqual(203, message.repeated_uint32[0])
+ self.assertEqual(204, message.repeated_uint64[0])
+ self.assertEqual(205, message.repeated_sint32[0])
+ self.assertEqual(206, message.repeated_sint64[0])
+ self.assertEqual(207, message.repeated_fixed32[0])
+ self.assertEqual(208, message.repeated_fixed64[0])
+ self.assertEqual(209, message.repeated_sfixed32[0])
+ self.assertEqual(210, message.repeated_sfixed64[0])
+ self.assertEqual(211, message.repeated_float[0])
+ self.assertEqual(212, message.repeated_double[0])
+ self.assertEqual(True, message.repeated_bool[0])
+ self.assertEqual('215', message.repeated_string[0])
+ self.assertEqual('216', message.repeated_bytes[0])
+
+ self.assertEqual(217, message.repeatedgroup[0].a)
+ self.assertEqual(218, message.repeated_nested_message[0].bb)
+ self.assertEqual(219, message.repeated_foreign_message[0].c)
+ self.assertEqual(220, message.repeated_import_message[0].d)
+
+ self.assertEqual(unittest_pb2.TestAllTypes.BAR,
+ message.repeated_nested_enum[0])
+ self.assertEqual(unittest_pb2.FOREIGN_BAR,
+ message.repeated_foreign_enum[0])
+ self.assertEqual(unittest_import_pb2.IMPORT_BAR,
+ message.repeated_import_enum[0])
+
+ self.assertEqual(301, message.repeated_int32[1])
+ self.assertEqual(302, message.repeated_int64[1])
+ self.assertEqual(303, message.repeated_uint32[1])
+ self.assertEqual(304, message.repeated_uint64[1])
+ self.assertEqual(305, message.repeated_sint32[1])
+ self.assertEqual(306, message.repeated_sint64[1])
+ self.assertEqual(307, message.repeated_fixed32[1])
+ self.assertEqual(308, message.repeated_fixed64[1])
+ self.assertEqual(309, message.repeated_sfixed32[1])
+ self.assertEqual(310, message.repeated_sfixed64[1])
+ self.assertEqual(311, message.repeated_float[1])
+ self.assertEqual(312, message.repeated_double[1])
+ self.assertEqual(False, message.repeated_bool[1])
+ self.assertEqual('315', message.repeated_string[1])
+ self.assertEqual('316', message.repeated_bytes[1])
+
+ self.assertEqual(317, message.repeatedgroup[1].a)
+ self.assertEqual(318, message.repeated_nested_message[1].bb)
+ self.assertEqual(319, message.repeated_foreign_message[1].c)
+ self.assertEqual(320, message.repeated_import_message[1].d)
+
+ self.assertEqual(unittest_pb2.TestAllTypes.BAZ,
+ message.repeated_nested_enum[1])
+ self.assertEqual(unittest_pb2.FOREIGN_BAZ,
+ message.repeated_foreign_enum[1])
+ self.assertEqual(unittest_import_pb2.IMPORT_BAZ,
+ message.repeated_import_enum[1])
+
+ # -----------------------------------------------------------------
+
+ self.assertTrue(message.HasField('default_int32'))
+ self.assertTrue(message.HasField('default_int64'))
+ self.assertTrue(message.HasField('default_uint32'))
+ self.assertTrue(message.HasField('default_uint64'))
+ self.assertTrue(message.HasField('default_sint32'))
+ self.assertTrue(message.HasField('default_sint64'))
+ self.assertTrue(message.HasField('default_fixed32'))
+ self.assertTrue(message.HasField('default_fixed64'))
+ self.assertTrue(message.HasField('default_sfixed32'))
+ self.assertTrue(message.HasField('default_sfixed64'))
+ self.assertTrue(message.HasField('default_float'))
+ self.assertTrue(message.HasField('default_double'))
+ self.assertTrue(message.HasField('default_bool'))
+ self.assertTrue(message.HasField('default_string'))
+ self.assertTrue(message.HasField('default_bytes'))
+
+ self.assertTrue(message.HasField('default_nested_enum'))
+ self.assertTrue(message.HasField('default_foreign_enum'))
+ self.assertTrue(message.HasField('default_import_enum'))
+
+ self.assertEqual(401, message.default_int32)
+ self.assertEqual(402, message.default_int64)
+ self.assertEqual(403, message.default_uint32)
+ self.assertEqual(404, message.default_uint64)
+ self.assertEqual(405, message.default_sint32)
+ self.assertEqual(406, message.default_sint64)
+ self.assertEqual(407, message.default_fixed32)
+ self.assertEqual(408, message.default_fixed64)
+ self.assertEqual(409, message.default_sfixed32)
+ self.assertEqual(410, message.default_sfixed64)
+ self.assertEqual(411, message.default_float)
+ self.assertEqual(412, message.default_double)
+ self.assertEqual(False, message.default_bool)
+ self.assertEqual('415', message.default_string)
+ self.assertEqual('416', message.default_bytes)
+
+ self.assertEqual(unittest_pb2.TestAllTypes.FOO, message.default_nested_enum)
+ self.assertEqual(unittest_pb2.FOREIGN_FOO, message.default_foreign_enum)
+ self.assertEqual(unittest_import_pb2.IMPORT_FOO,
+ message.default_import_enum)
+
def GoldenFile(filename):
"""Finds the given golden file and returns a file object representing it."""
@@ -359,7 +554,8 @@ def GoldenFile(filename):
while os.path.exists(path):
if os.path.exists(os.path.join(path, 'src/google/protobuf')):
# Found it. Load the golden file from the testdata directory.
- return file(os.path.join(path, 'src/google/protobuf/testdata', filename))
+ full_path = os.path.join(path, 'src/google/protobuf/testdata', filename)
+ return open(full_path, 'rb')
path = os.path.join(path, '..')
raise RuntimeError(
diff --git a/python/google/protobuf/internal/text_format_test.py b/python/google/protobuf/internal/text_format_test.py
index 871590e7..0cf27186 100755
--- a/python/google/protobuf/internal/text_format_test.py
+++ b/python/google/protobuf/internal/text_format_test.py
@@ -42,11 +42,16 @@ from google.protobuf.internal import test_util
from google.protobuf import unittest_pb2
from google.protobuf import unittest_mset_pb2
-class TextFormatTest(unittest.TestCase):
- def CompareToGoldenFile(self, text, golden_filename):
+
+class TextFormatTest(test_util.GoldenMessageTestCase):
+ def ReadGolden(self, golden_filename):
f = test_util.GoldenFile(golden_filename)
golden_lines = f.readlines()
f.close()
+ return golden_lines
+
+ def CompareToGoldenFile(self, text, golden_filename):
+ golden_lines = self.ReadGolden(golden_filename)
self.CompareToGoldenLines(text, golden_lines)
def CompareToGoldenText(self, text, golden_text):
@@ -117,6 +122,276 @@ class TextFormatTest(unittest.TestCase):
return text.replace('e+0','e+').replace('e+0','e+') \
.replace('e-0','e-').replace('e-0','e-')
+ def testMergeGolden(self):
+   # ReadGolden() returns readlines() output, so every element already ends
+   # with its own newline; concatenate without inserting a second one.
+   golden_text = ''.join(self.ReadGolden('text_format_unittest_data.txt'))
+   parsed_message = unittest_pb2.TestAllTypes()
+   text_format.Merge(golden_text, parsed_message)
+
+   # The golden file is expected to describe the message SetAllFields() builds.
+   message = unittest_pb2.TestAllTypes()
+   test_util.SetAllFields(message)
+   self.assertEqual(message, parsed_message)
+
+ def testMergeGoldenExtensions(self):
+   # ReadGolden() returns readlines() output, so every element already ends
+   # with its own newline; concatenate without inserting a second one.
+   golden_text = ''.join(self.ReadGolden(
+       'text_format_unittest_extensions_data.txt'))
+   parsed_message = unittest_pb2.TestAllExtensions()
+   text_format.Merge(golden_text, parsed_message)
+
+   # The golden file is expected to describe the message SetAllExtensions()
+   # builds.
+   message = unittest_pb2.TestAllExtensions()
+   test_util.SetAllExtensions(message)
+   self.assertEqual(message, parsed_message)
+
+ def testMergeAllFields(self):
+   # Round trip: print a fully populated message, then parse the text back.
+   message = unittest_pb2.TestAllTypes()
+   test_util.SetAllFields(message)
+
+   parsed_message = unittest_pb2.TestAllTypes()
+   text_format.Merge(text_format.MessageToString(message), parsed_message)
+
+   self.assertEqual(message, parsed_message)
+   self.ExpectAllFieldsSet(message)
+
+ def testMergeAllExtensions(self):
+   # Round trip: print a message with every extension set, then parse it back.
+   message = unittest_pb2.TestAllExtensions()
+   test_util.SetAllExtensions(message)
+
+   parsed_message = unittest_pb2.TestAllExtensions()
+   text_format.Merge(text_format.MessageToString(message), parsed_message)
+
+   self.assertEqual(message, parsed_message)
+
+ def testMergeMessageSet(self):
+   # Repeated scalar field: each occurrence in the text appends one element.
+   message = unittest_pb2.TestAllTypes()
+   text = ('repeated_uint64: 1\n'
+           'repeated_uint64: 2\n')
+   text_format.Merge(text, message)
+   self.assertEqual(1, message.repeated_uint64[0])
+   self.assertEqual(2, message.repeated_uint64[1])
+
+   # Message set container: members are written as bracketed names.
+   message = unittest_mset_pb2.TestMessageSetContainer()
+   text = ('message_set {\n'
+           ' [protobuf_unittest.TestMessageSetExtension1] {\n'
+           ' i: 23\n'
+           ' }\n'
+           ' [protobuf_unittest.TestMessageSetExtension2] {\n'
+           ' str: \"foo\"\n'
+           ' }\n'
+           '}\n')
+   text_format.Merge(text, message)
+   ext1 = unittest_mset_pb2.TestMessageSetExtension1.message_set_extension
+   ext2 = unittest_mset_pb2.TestMessageSetExtension2.message_set_extension
+   self.assertEqual(23, message.message_set.Extensions[ext1].i)
+   self.assertEqual('foo', message.message_set.Extensions[ext2].str)
+
+ def testMergeExotic(self):
+   # Boundary integers, exponent-notation doubles and an escape-heavy string.
+   lines = ['repeated_int64: -9223372036854775808\n',
+            'repeated_uint64: 18446744073709551615\n',
+            'repeated_double: 123.456\n',
+            'repeated_double: 1.23e+22\n',
+            'repeated_double: 1.23e-18\n',
+            'repeated_string: \n',
+            '\"\\000\\001\\007\\010\\014\\n\\r\\t\\013\\\\\\\'\\\"\"\n']
+   parsed = unittest_pb2.TestAllTypes()
+   text_format.Merge(''.join(lines), parsed)
+
+   self.assertEqual(-9223372036854775808, parsed.repeated_int64[0])
+   self.assertEqual(18446744073709551615, parsed.repeated_uint64[0])
+   doubles = parsed.repeated_double
+   self.assertEqual(123.456, doubles[0])
+   self.assertEqual(1.23e22, doubles[1])
+   self.assertEqual(1.23e-18, doubles[2])
+   self.assertEqual(
+       '\000\001\a\b\f\n\r\t\v\\\'\"', parsed.repeated_string[0])
+
+ def testMergeUnknownField(self):
+   # A field name the message type does not declare must be rejected, with
+   # the error positioned at the offending name.
+   expected_error = (
+       '1:1 : Message type "protobuf_unittest.TestAllTypes" has no field named '
+       '"unknown_field".')
+   self.assertRaisesWithMessage(
+       text_format.ParseError, expected_error,
+       text_format.Merge, 'unknown_field: 8\n', unittest_pb2.TestAllTypes())
+
+ def testMergeBadExtension(self):
+   # A bracketed name that matches no known extension must be rejected.
+   expected_error = '1:2 : Extension "unknown_extension" not registered.'
+   self.assertRaisesWithMessage(
+       text_format.ParseError, expected_error,
+       text_format.Merge, '[unknown_extension]: 8\n',
+       unittest_pb2.TestAllTypes())
+
+ def testMergeGroupNotClosed(self):
+   # Both bracket styles must report the closing delimiter that is missing.
+   message = unittest_pb2.TestAllTypes()
+   cases = (('RepeatedGroup: <', '1:16 : Expected ">".'),
+            ('RepeatedGroup: {', '1:16 : Expected "}".'))
+   for text, expected_error in cases:
+     self.assertRaisesWithMessage(
+         text_format.ParseError, expected_error,
+         text_format.Merge, text, message)
+
+ def testMergeBadEnumValue(self):
+   # Enum values may be given by name or by number; both forms are validated.
+   enum_prefix = '1:23 : Enum type "protobuf_unittest.TestAllTypes.NestedEnum" '
+
+   # Unknown symbolic name.
+   self.assertRaisesWithMessage(
+       text_format.ParseError,
+       enum_prefix + 'has no value named BARR.',
+       text_format.Merge, 'optional_nested_enum: BARR',
+       unittest_pb2.TestAllTypes())
+
+   # Unknown numeric value.
+   self.assertRaisesWithMessage(
+       text_format.ParseError,
+       enum_prefix + 'has no value with number 100.',
+       text_format.Merge, 'optional_nested_enum: 100',
+       unittest_pb2.TestAllTypes())
+
+ def assertRaisesWithMessage(self, e_class, e, func, *args, **kwargs):
+   """Like assertRaises, but additionally checks the exception text.
+
+   Args:
+     e_class: Exception class that func is expected to raise.
+     e: Exact string that str() of the raised exception must equal.
+     func: Callable to invoke.
+     *args: Positional arguments forwarded to func.
+     **kwargs: Keyword arguments forwarded to func.
+
+   Raises:
+     self.failureException: If func does not raise e_class, or raises it
+       with a different message.
+   """
+   exc_name = getattr(e_class, '__name__', None)
+   if exc_name is None:
+     exc_name = str(e_class)
+
+   try:
+     func(*args, **kwargs)
+   except e_class, expr:
+     actual = str(expr)
+     if actual == e:
+       return
+     msg = '%s raised, but with wrong message: "%s" instead of "%s"'
+     raise self.failureException(msg % (exc_name,
+                                        actual.encode('string_escape'),
+                                        e.encode('string_escape')))
+   # Reaching this point means func returned without raising.
+   raise self.failureException('%s not raised' % exc_name)
+
+
+class TokenizerTest(unittest.TestCase):
+  """Unit tests for the private text_format._Tokenizer class."""
+
+  def testSimpleTokenCases(self):
+    text = ('identifier1:"string1"\n \n\n'
+            'identifier2 : \n \n123 \n identifier3 :\'string\'\n'
+            'identifiER_4 : 1.1e+2 ID5:-0.23 ID6:\'aaaa\\\'bbbb\'\n'
+            'ID7 : "aa\\"bb"\n\n\n\n ID8: {A:inf B:-inf C:true D:false}\n'
+            'ID9: 22 ID10: -111111111111111111 ID11: -22\n'
+            'ID12: 2222222222222222222')
+    tokenizer = text_format._Tokenizer(text)
+    # Expected token stream. A bare string is a punctuation token compared
+    # against tokenizer.token; a tuple is (consume method, expected result).
+    methods = [(tokenizer.ConsumeIdentifier, 'identifier1'),
+               ':',
+               (tokenizer.ConsumeString, 'string1'),
+               (tokenizer.ConsumeIdentifier, 'identifier2'),
+               ':',
+               (tokenizer.ConsumeInt32, 123),
+               (tokenizer.ConsumeIdentifier, 'identifier3'),
+               ':',
+               (tokenizer.ConsumeString, 'string'),
+               (tokenizer.ConsumeIdentifier, 'identifiER_4'),
+               ':',
+               (tokenizer.ConsumeFloat, 1.1e+2),
+               (tokenizer.ConsumeIdentifier, 'ID5'),
+               ':',
+               (tokenizer.ConsumeFloat, -0.23),
+               (tokenizer.ConsumeIdentifier, 'ID6'),
+               ':',
+               (tokenizer.ConsumeString, 'aaaa\'bbbb'),
+               (tokenizer.ConsumeIdentifier, 'ID7'),
+               ':',
+               (tokenizer.ConsumeString, 'aa\"bb'),
+               (tokenizer.ConsumeIdentifier, 'ID8'),
+               ':',
+               '{',
+               (tokenizer.ConsumeIdentifier, 'A'),
+               ':',
+               (tokenizer.ConsumeFloat, float('inf')),
+               (tokenizer.ConsumeIdentifier, 'B'),
+               ':',
+               (tokenizer.ConsumeFloat, float('-inf')),
+               (tokenizer.ConsumeIdentifier, 'C'),
+               ':',
+               (tokenizer.ConsumeBool, True),
+               (tokenizer.ConsumeIdentifier, 'D'),
+               ':',
+               (tokenizer.ConsumeBool, False),
+               '}',
+               (tokenizer.ConsumeIdentifier, 'ID9'),
+               ':',
+               (tokenizer.ConsumeUint32, 22),
+               (tokenizer.ConsumeIdentifier, 'ID10'),
+               ':',
+               (tokenizer.ConsumeInt64, -111111111111111111),
+               (tokenizer.ConsumeIdentifier, 'ID11'),
+               ':',
+               (tokenizer.ConsumeInt32, -22),
+               (tokenizer.ConsumeIdentifier, 'ID12'),
+               ':',
+               (tokenizer.ConsumeUint64, 2222222222222222222)]
+
+    # NOTE(review): the loop is driven by AtEnd(); if AtEnd() turns true
+    # while the final token is still pending, the last entry above is never
+    # checked -- confirm against the tokenizer's AtEnd() semantics.
+    i = 0
+    while not tokenizer.AtEnd():
+      m = methods[i]
+      if isinstance(m, str):
+        token = tokenizer.token
+        self.assertEqual(token, m)
+        tokenizer.NextToken()
+      else:
+        self.assertEqual(m[1], m[0]())
+      i += 1
+
+  def testConsumeIntegers(self):
+    # Covers the failure paths of the integer consume methods (a failed
+    # consume must leave the token in place for the next attempt) and the
+    # handling of '0' / '-0'.
+    int64_max = (1 << 63) - 1
+    uint32_max = (1 << 32) - 1
+    text = '-1 %d %d' % (uint32_max + 1, int64_max + 1)
+    tokenizer = text_format._Tokenizer(text)
+    self.assertRaises(text_format.ParseError, tokenizer.ConsumeUint32)
+    self.assertRaises(text_format.ParseError, tokenizer.ConsumeUint64)
+    self.assertEqual(-1, tokenizer.ConsumeInt32())
+
+    self.assertRaises(text_format.ParseError, tokenizer.ConsumeUint32)
+    self.assertRaises(text_format.ParseError, tokenizer.ConsumeInt32)
+    self.assertEqual(uint32_max + 1, tokenizer.ConsumeInt64())
+
+    self.assertRaises(text_format.ParseError, tokenizer.ConsumeInt64)
+    self.assertEqual(int64_max + 1, tokenizer.ConsumeUint64())
+    self.assertTrue(tokenizer.AtEnd())
+
+    text = '-0 -0 0 0'
+    tokenizer = text_format._Tokenizer(text)
+    self.assertEqual(0, tokenizer.ConsumeUint32())
+    self.assertEqual(0, tokenizer.ConsumeUint64())
+    self.assertEqual(0, tokenizer.ConsumeUint32())
+    self.assertEqual(0, tokenizer.ConsumeUint64())
+    self.assertTrue(tokenizer.AtEnd())
+
+  def testConsumeByteString(self):
+    # Opening and closing quote characters differ.
+    text = '"string1\''
+    tokenizer = text_format._Tokenizer(text)
+    self.assertRaises(text_format.ParseError, tokenizer.ConsumeByteString)
+
+    # No opening quote.
+    text = 'string1"'
+    tokenizer = text_format._Tokenizer(text)
+    self.assertRaises(text_format.ParseError, tokenizer.ConsumeByteString)
+
+    # Hex escape followed by a non-hex character.
+    text = '\n"\\xt"'
+    tokenizer = text_format._Tokenizer(text)
+    self.assertRaises(text_format.ParseError, tokenizer.ConsumeByteString)
+
+    # Backslash directly before the closing quote.
+    text = '\n"\\"'
+    tokenizer = text_format._Tokenizer(text)
+    self.assertRaises(text_format.ParseError, tokenizer.ConsumeByteString)
+
+    # Hex escape with no digits at all.
+    text = '\n"\\x"'
+    tokenizer = text_format._Tokenizer(text)
+    self.assertRaises(text_format.ParseError, tokenizer.ConsumeByteString)
+
+  def testConsumeBool(self):
+    text = 'not-a-bool'
+    tokenizer = text_format._Tokenizer(text)
+    self.assertRaises(text_format.ParseError, tokenizer.ConsumeBool)
+
+
if __name__ == '__main__':
unittest.main()
-
diff --git a/python/google/protobuf/internal/type_checkers.py b/python/google/protobuf/internal/type_checkers.py
index c009627f..a3bc57ff 100755
--- a/python/google/protobuf/internal/type_checkers.py
+++ b/python/google/protobuf/internal/type_checkers.py
@@ -122,8 +122,10 @@ class UnicodeValueChecker(object):
try:
unicode(proposed_value, 'ascii')
except UnicodeDecodeError:
- raise ValueError('%.1024r isn\'t in 7-bit ASCII encoding.'
- % (proposed_value))
+ raise ValueError('%.1024r has type str, but isn\'t in 7-bit ASCII '
+ 'encoding. Non-ASCII strings must be converted to '
+ 'unicode objects before being added.' %
+ (proposed_value))
class Int32ValueChecker(IntValueChecker):
diff --git a/python/google/protobuf/internal/wire_format.py b/python/google/protobuf/internal/wire_format.py
index 950267f9..da6464de 100755
--- a/python/google/protobuf/internal/wire_format.py
+++ b/python/google/protobuf/internal/wire_format.py
@@ -64,6 +64,8 @@ UINT64_MAX = (1 << 64) - 1
# "struct" format strings that will encode/decode the specified formats.
FORMAT_UINT32_LITTLE_ENDIAN = '<I'
FORMAT_UINT64_LITTLE_ENDIAN = '<Q'
+FORMAT_FLOAT_LITTLE_ENDIAN = '<f'
+FORMAT_DOUBLE_LITTLE_ENDIAN = '<d'
# We'll have to provide alternate implementations of AppendLittleEndian*() on
diff --git a/python/google/protobuf/message.py b/python/google/protobuf/message.py
index 4da024ca..9a88bdc8 100755
--- a/python/google/protobuf/message.py
+++ b/python/google/protobuf/message.py
@@ -36,7 +36,6 @@
__author__ = 'robinson@google.com (Will Robinson)'
-from google.protobuf import text_format
class Error(Exception): pass
class DecodeError(Error): pass
@@ -76,7 +75,7 @@ class Message(object):
return not self == other_msg
def __str__(self):
- return text_format.MessageToString(self)
+ raise NotImplementedError
def MergeFrom(self, other_msg):
"""Merges the contents of the specified message into current message.
diff --git a/python/google/protobuf/reflection.py b/python/google/protobuf/reflection.py
index 5ab7a1b1..d65d8b67 100755
--- a/python/google/protobuf/reflection.py
+++ b/python/google/protobuf/reflection.py
@@ -62,6 +62,7 @@ from google.protobuf.internal import type_checkers
from google.protobuf.internal import wire_format
from google.protobuf import descriptor as descriptor_mod
from google.protobuf import message as message_mod
+from google.protobuf import text_format
_FieldDescriptor = descriptor_mod.FieldDescriptor
@@ -291,7 +292,7 @@ def _DefaultValueForField(message, field):
def _AddInitMethod(message_descriptor, cls):
"""Adds an __init__ method to cls."""
fields = message_descriptor.fields
- def init(self):
+ def init(self, **kwargs):
self._cached_byte_size = 0
self._cached_byte_size_dirty = False
self._listener = message_listener_mod.NullMessageListener()
@@ -306,12 +307,30 @@ def _AddInitMethod(message_descriptor, cls):
if field.label != _FieldDescriptor.LABEL_REPEATED:
setattr(self, _HasFieldName(field.name), False)
self.Extensions = _ExtensionDict(self, cls._known_extensions)
+ for field_name, field_value in kwargs.iteritems():
+ field = _GetFieldByName(message_descriptor, field_name)
+ _MergeFieldOrExtension(self, field, field_value)
init.__module__ = None
init.__doc__ = None
cls.__init__ = init
+def _GetFieldByName(message_descriptor, field_name):
+  """Looks up the descriptor of a named field of a message type.
+
+  Args:
+    message_descriptor: A Descriptor describing all fields in message.
+    field_name: The name of the field to retrieve.
+  Returns:
+    The field descriptor associated with the field name.
+  Raises:
+    ValueError: If the message type declares no field with that name.
+  """
+  try:
+    field = message_descriptor.fields_by_name[field_name]
+  except KeyError:
+    raise ValueError('Protocol message has no "%s" field.' % field_name)
+  return field
+
+
def _AddPropertiesForFields(descriptor, cls):
"""Adds properties for all fields in this protocol message type."""
for field in descriptor.fields:
@@ -543,10 +562,7 @@ def _AddHasFieldMethod(cls):
def _AddClearFieldMethod(cls):
"""Helper for _AddMessageMethods()."""
def ClearField(self, field_name):
- try:
- field = self.DESCRIPTOR.fields_by_name[field_name]
- except KeyError:
- raise ValueError('Protocol message has no "%s" field.' % field_name)
+ field = _GetFieldByName(self.DESCRIPTOR, field_name)
proto_field_name = field.name
python_field_name = _ValueFieldName(proto_field_name)
has_field_name = _HasFieldName(proto_field_name)
@@ -629,6 +645,13 @@ def _AddEqualsMethod(message_descriptor, cls):
cls.__eq__ = __eq__
+def _AddStrMethod(message_descriptor, cls):
+  """Helper for _AddMessageMethods().
+
+  Installs a __str__ method on cls that renders the message through
+  text_format.MessageToString(). message_descriptor is not used here.
+  """
+  def __str__(self):
+    return text_format.MessageToString(self)
+  setattr(cls, '__str__', __str__)
+
+
def _AddSetListenerMethod(cls):
"""Helper for _AddMessageMethods()."""
def SetListener(self, listener):
@@ -1090,7 +1113,7 @@ def _DeserializeOneEntity(message_descriptor, message, decoder):
content_start = decoder.Position()
while decoder.Position() - content_start < length:
element_list.append(_DeserializeScalarFromDecoder(field_type, decoder))
- return decoder.Position() - content_start
+ return decoder.Position() - initial_position
else:
# Repeated composite.
composite = element_list.add()
@@ -1275,6 +1298,7 @@ def _AddMessageMethods(message_descriptor, cls):
_AddClearMethod(cls)
_AddHasExtensionMethod(cls)
_AddEqualsMethod(message_descriptor, cls)
+ _AddStrMethod(message_descriptor, cls)
_AddSetListenerMethod(cls)
_AddByteSizeMethod(message_descriptor, cls)
_AddSerializeToStringMethod(message_descriptor, cls)
@@ -1436,6 +1460,20 @@ class _ExtensionDict(object):
if extension.label != _FieldDescriptor.LABEL_REPEATED)
self._has_bits = dict.fromkeys(keys, False)
+ self._extensions_by_number = dict(
+ (f.number, f) for f in self._known_extensions.itervalues())
+
+ self._extensions_by_name = {}
+ for extension in self._known_extensions.itervalues():
+ if (extension.containing_type.GetOptions().message_set_wire_format and
+ extension.type == descriptor_mod.FieldDescriptor.TYPE_MESSAGE and
+ extension.message_type == extension.extension_scope and
+ extension.label == descriptor_mod.FieldDescriptor.LABEL_OPTIONAL):
+ extension_name = extension.message_type.full_name
+ else:
+ extension_name = extension.full_name
+ self._extensions_by_name[extension_name] = extension
+
def __getitem__(self, extension_handle):
"""Returns the current value of the given extension handle."""
# We don't care as much about keeping critical sections short in the
@@ -1609,7 +1647,15 @@ class _ExtensionDict(object):
Returns: A dict mapping field_number to (handle, field_descriptor),
for *all* registered extensions for this dict.
"""
- # TODO(robinson): Precompute and store this away. Note that we'll have to
- # be careful when we move away from having _known_extensions as a
- # deep-copied member of this object.
- return dict((f.number, f) for f in self._known_extensions.itervalues())
+ return self._extensions_by_number
+
+ def _FindExtensionByName(self, name):
+   """Looks up a known extension by name.
+
+   Args:
+     name: Extension full name.
+
+   Returns:
+     The extension's field descriptor, or None if no known extension is
+     stored under that name.
+   """
+   return self._extensions_by_name.get(name)
diff --git a/python/google/protobuf/service.py b/python/google/protobuf/service.py
index 9ec42fe3..dd136c9a 100755
--- a/python/google/protobuf/service.py
+++ b/python/google/protobuf/service.py
@@ -67,8 +67,6 @@ class Service(object):
and "done" will later be called with the response value.
In the blocking case, RpcException will be raised on error.
- Asynchronous calls must check status via the Failed method of the
- RpcController.
Preconditions:
* method_descriptor.service == GetDescriptor
@@ -82,6 +80,9 @@ class Service(object):
Postconditions:
* "done" will be called when the method is complete. This may be
before CallMethod() returns or it may be at some point in the future.
+ * If the RPC failed, the response value passed to "done" will be None.
+ Further details about the failure can be found by querying the
+ RpcController.
"""
raise NotImplementedError
diff --git a/python/google/protobuf/text_format.py b/python/google/protobuf/text_format.py
index 596ef946..1cddce6c 100755
--- a/python/google/protobuf/text_format.py
+++ b/python/google/protobuf/text_format.py
@@ -33,10 +33,19 @@
__author__ = 'kenton@google.com (Kenton Varda)'
import cStringIO
+import re
+from collections import deque
+from google.protobuf.internal import type_checkers
from google.protobuf import descriptor
-__all__ = [ 'MessageToString', 'PrintMessage', 'PrintField', 'PrintFieldValue' ]
+__all__ = [ 'MessageToString', 'PrintMessage', 'PrintField',
+ 'PrintFieldValue', 'Merge' ]
+
+
+class ParseError(Exception):
+  """Raised when the ASCII (text format) input cannot be parsed."""
+
def MessageToString(message):
out = cStringIO.StringIO()
@@ -45,6 +54,7 @@ def MessageToString(message):
out.close()
return result
+
def PrintMessage(message, out, indent = 0):
for field, value in message.ListFields():
if field.label == descriptor.FieldDescriptor.LABEL_REPEATED:
@@ -53,6 +63,7 @@ def PrintMessage(message, out, indent = 0):
else:
PrintField(field, value, out, indent)
+
def PrintField(field, value, out, indent = 0):
"""Print a single field name/value pair. For repeated fields, the value
should be a single element."""
@@ -82,6 +93,7 @@ def PrintField(field, value, out, indent = 0):
PrintFieldValue(field, value, out, indent)
out.write('\n')
+
def PrintFieldValue(field, value, out, indent = 0):
"""Print a single field value (not including name). For repeated fields,
the value should be a single element."""
@@ -104,6 +116,507 @@ def PrintFieldValue(field, value, out, indent = 0):
else:
out.write(str(value))
+
+def Merge(text, message):
+  """Parses a text-format message and merges its fields into message.
+
+  Args:
+    text: Message ASCII representation.
+    message: A protocol buffer message to merge into.
+
+  Raises:
+    ParseError: On ASCII parsing problems.
+  """
+  tokenizer = _Tokenizer(text)
+  while True:
+    if tokenizer.AtEnd():
+      break
+    # One call per top-level field found in the text.
+    _MergeField(tokenizer, message)
+
+
+def _MergeField(tokenizer, message):
+  """Parses one field (name plus value) and merges it into message.
+
+  The name is either a bracketed, dotted extension name or a plain
+  identifier. Message-typed fields recurse into the sub-message; all other
+  fields are handed to _MergeScalarField().
+
+  Args:
+    tokenizer: A tokenizer to parse the field name and values.
+    message: A protocol message to record the data.
+
+  Raises:
+    ParseError: In case of ASCII parsing problems.
+  """
+  field_types = descriptor.FieldDescriptor
+  message_descriptor = message.DESCRIPTOR
+
+  if tokenizer.TryConsume('['):
+    # Extension: '[' identifier ('.' identifier)* ']'.
+    parts = [tokenizer.ConsumeIdentifier()]
+    while tokenizer.TryConsume('.'):
+      parts.append(tokenizer.ConsumeIdentifier())
+    name = '.'.join(parts)
+
+    field = message.Extensions._FindExtensionByName(name)
+    if not field:
+      raise tokenizer.ParseErrorPreviousToken(
+          'Extension "%s" not registered.' % name)
+    if message_descriptor != field.containing_type:
+      raise tokenizer.ParseErrorPreviousToken(
+          'Extension "%s" does not extend message type "%s".' % (
+              name, message_descriptor.full_name))
+    tokenizer.Consume(']')
+  else:
+    name = tokenizer.ConsumeIdentifier()
+    fields_by_name = message_descriptor.fields_by_name
+    field = fields_by_name.get(name)
+
+    # Group names are expected to be capitalized as they appear in the
+    # .proto file, which actually matches their type names, not their field
+    # names.
+    if not field:
+      field = fields_by_name.get(name.lower())
+      if field and field.type != field_types.TYPE_GROUP:
+        field = None
+
+    # A group matches only when spelled exactly like its type name.
+    if (field and field.type == field_types.TYPE_GROUP and
+        field.message_type.name != name):
+      field = None
+
+    if not field:
+      raise tokenizer.ParseErrorPreviousToken(
+          'Message type "%s" has no field named "%s".' % (
+              message_descriptor.full_name, name))
+
+  if field.cpp_type != field_types.CPPTYPE_MESSAGE:
+    _MergeScalarField(tokenizer, message, field)
+    return
+
+  # Message-typed value: the ':' is optional, and the body is delimited by
+  # either <...> or {...}.
+  tokenizer.TryConsume(':')
+  if tokenizer.TryConsume('<'):
+    end_token = '>'
+  else:
+    tokenizer.Consume('{')
+    end_token = '}'
+
+  if field.is_extension:
+    target = message.Extensions[field]
+  else:
+    target = getattr(message, field.name)
+  if field.label == field_types.LABEL_REPEATED:
+    sub_message = target.add()
+  else:
+    sub_message = target
+
+  while not tokenizer.TryConsume(end_token):
+    if tokenizer.AtEnd():
+      raise tokenizer.ParseErrorPreviousToken('Expected "%s".' % (end_token))
+    _MergeField(tokenizer, sub_message)
+
+
+def _MergeScalarField(tokenizer, message, field):
+  """Parses one scalar value and stores it in the given field of message.
+
+  Args:
+    tokenizer: A tokenizer to parse the field value.
+    message: A protocol message to record the data.
+    field: The descriptor of the field to be merged.
+
+  Raises:
+    ParseError: In case of ASCII parsing problems.
+    RuntimeError: On runtime errors.
+  """
+  field_types = descriptor.FieldDescriptor
+  tokenizer.Consume(':')
+
+  # Maps each non-enum scalar field type to the tokenizer method that reads
+  # a value of that type.
+  consumers = {
+      field_types.TYPE_INT32: tokenizer.ConsumeInt32,
+      field_types.TYPE_SINT32: tokenizer.ConsumeInt32,
+      field_types.TYPE_SFIXED32: tokenizer.ConsumeInt32,
+      field_types.TYPE_INT64: tokenizer.ConsumeInt64,
+      field_types.TYPE_SINT64: tokenizer.ConsumeInt64,
+      field_types.TYPE_SFIXED64: tokenizer.ConsumeInt64,
+      field_types.TYPE_UINT32: tokenizer.ConsumeUint32,
+      field_types.TYPE_FIXED32: tokenizer.ConsumeUint32,
+      field_types.TYPE_UINT64: tokenizer.ConsumeUint64,
+      field_types.TYPE_FIXED64: tokenizer.ConsumeUint64,
+      field_types.TYPE_FLOAT: tokenizer.ConsumeFloat,
+      field_types.TYPE_DOUBLE: tokenizer.ConsumeFloat,
+      field_types.TYPE_BOOL: tokenizer.ConsumeBool,
+      field_types.TYPE_STRING: tokenizer.ConsumeString,
+      field_types.TYPE_BYTES: tokenizer.ConsumeByteString,
+  }
+
+  if field.type in consumers:
+    value = consumers[field.type]()
+  elif field.type == field_types.TYPE_ENUM:
+    # Enum can be specified by a number (the enum value), or by
+    # a string literal (the enum name).
+    enum_descriptor = field.enum_type
+    if tokenizer.LookingAtInteger():
+      number = tokenizer.ConsumeInt32()
+      enum_value = enum_descriptor.values_by_number.get(number, None)
+      if enum_value is None:
+        raise tokenizer.ParseErrorPreviousToken(
+            'Enum type "%s" has no value with number %d.' % (
+                enum_descriptor.full_name, number))
+    else:
+      identifier = tokenizer.ConsumeIdentifier()
+      enum_value = enum_descriptor.values_by_name.get(identifier, None)
+      if enum_value is None:
+        raise tokenizer.ParseErrorPreviousToken(
+            'Enum type "%s" has no value named %s.' % (
+                enum_descriptor.full_name, identifier))
+    value = enum_value.number
+  else:
+    raise RuntimeError('Unknown field type %d' % field.type)
+
+  # Repeated fields append; singular fields are overwritten.
+  is_repeated = field.label == field_types.LABEL_REPEATED
+  if field.is_extension:
+    if is_repeated:
+      message.Extensions[field].append(value)
+    else:
+      message.Extensions[field] = value
+  else:
+    if is_repeated:
+      getattr(message, field.name).append(value)
+    else:
+      setattr(message, field.name, value)
+
+
+class _Tokenizer(object):
+  """Protocol buffer ASCII representation tokenizer.
+
+  This class handles the lower level string parsing by splitting it into
+  meaningful tokens.
+
+  It was directly ported from the Java protocol buffer API.
+
+  The token currently being looked at is exposed as the public attribute
+  "token"; it is the empty string once the whole input has been consumed.
+  """
+
+  # Whitespace, or a '#' comment running to the end of the line.
+  _WHITESPACE = re.compile('(\\s|(#.*$))+', re.MULTILINE)
+  # Each quoted-string alternative stops at its own (unescaped) quote
+  # character, so the other quote character may appear inside the string.
+  _TOKEN = re.compile(
+      '[a-zA-Z_][0-9a-zA-Z_+-]*|'           # an identifier
+      '[0-9+-][0-9a-zA-Z_.+-]*|'            # a number
+      '\"([^\"\n\\\\]|\\\\.)*(\"|\\\\?$)|'  # a double-quoted string
+      '\'([^\'\n\\\\]|\\\\.)*(\'|\\\\?$)')  # a single-quoted string
+  _IDENTIFIER = re.compile('\w+')
+  # Indexed by 2 * is_long + is_signed; see _ParseInteger().
+  _INTEGER_CHECKERS = [type_checkers.Uint32ValueChecker(),
+                       type_checkers.Int32ValueChecker(),
+                       type_checkers.Uint64ValueChecker(),
+                       type_checkers.Int64ValueChecker()]
+  _FLOAT_INFINITY = re.compile('-?inf(inity)?f?', re.IGNORECASE)
+  _FLOAT_NAN = re.compile("nanf?", re.IGNORECASE)
+
+  def __init__(self, text_message):
+    """Splits text_message into lines and loads the first token.
+
+    Args:
+      text_message: The complete ASCII message text to tokenize.
+    """
+    self._text_message = text_message
+
+    self._position = 0
+    self._line = -1
+    self._column = 0
+    self._token_start = None
+    self.token = ''
+    self._lines = deque(text_message.split('\n'))
+    self._current_line = ''
+    self._previous_line = 0
+    self._previous_column = 0
+    self._SkipWhitespace()
+    self.NextToken()
+
+  def AtEnd(self):
+    """Checks the end of the text was reached.
+
+    The end is reached only once the last token has been consumed, i.e. when
+    there is no current token left. Exhausting the raw text is not enough:
+    the final token may still be waiting in self.token.
+
+    Returns:
+      True iff the end was reached.
+    """
+    return not self.token
+
+  def _PopLine(self):
+    """Advances to the next non-empty line if the current one is used up."""
+    while not self._current_line:
+      if not self._lines:
+        self._current_line = ''
+        return
+      self._line += 1
+      self._column = 0
+      self._current_line = self._lines.popleft()
+
+  def _SkipWhitespace(self):
+    """Drops leading whitespace and comments, crossing line boundaries."""
+    while True:
+      self._PopLine()
+      match = re.match(self._WHITESPACE, self._current_line)
+      if not match:
+        break
+      length = len(match.group(0))
+      self._current_line = self._current_line[length:]
+      self._column += length
+
+  def TryConsume(self, token):
+    """Tries to consume a given piece of text.
+
+    Args:
+      token: Text to consume.
+
+    Returns:
+      True iff the text was consumed.
+    """
+    if self.token == token:
+      self.NextToken()
+      return True
+    return False
+
+  def Consume(self, token):
+    """Consumes a piece of text.
+
+    Args:
+      token: Text to consume.
+
+    Raises:
+      ParseError: If the text couldn't be consumed.
+    """
+    if not self.TryConsume(token):
+      raise self._ParseError('Expected "%s".' % token)
+
+  def LookingAtInteger(self):
+    """Checks if the current token is an integer.
+
+    Only the first character is inspected (a digit, '-' or '+').
+
+    Returns:
+      True iff the current token is an integer.
+    """
+    if not self.token:
+      return False
+    c = self.token[0]
+    return (c >= '0' and c <= '9') or c == '-' or c == '+'
+
+  def ConsumeIdentifier(self):
+    """Consumes protocol message field identifier.
+
+    Returns:
+      Identifier string.
+
+    Raises:
+      ParseError: If an identifier couldn't be consumed.
+    """
+    result = self.token
+    if not re.match(self._IDENTIFIER, result):
+      raise self._ParseError('Expected identifier.')
+    self.NextToken()
+    return result
+
+  def ConsumeInt32(self):
+    """Consumes a signed 32bit integer number.
+
+    Returns:
+      The integer parsed.
+
+    Raises:
+      ParseError: If a signed 32bit integer couldn't be consumed.
+    """
+    try:
+      result = self._ParseInteger(self.token, is_signed=True, is_long=False)
+    except ValueError, e:
+      raise self._IntegerParseError(e)
+    self.NextToken()
+    return result
+
+  def ConsumeUint32(self):
+    """Consumes an unsigned 32bit integer number.
+
+    Returns:
+      The integer parsed.
+
+    Raises:
+      ParseError: If an unsigned 32bit integer couldn't be consumed.
+    """
+    try:
+      result = self._ParseInteger(self.token, is_signed=False, is_long=False)
+    except ValueError, e:
+      raise self._IntegerParseError(e)
+    self.NextToken()
+    return result
+
+  def ConsumeInt64(self):
+    """Consumes a signed 64bit integer number.
+
+    Returns:
+      The integer parsed.
+
+    Raises:
+      ParseError: If a signed 64bit integer couldn't be consumed.
+    """
+    try:
+      result = self._ParseInteger(self.token, is_signed=True, is_long=True)
+    except ValueError, e:
+      raise self._IntegerParseError(e)
+    self.NextToken()
+    return result
+
+  def ConsumeUint64(self):
+    """Consumes an unsigned 64bit integer number.
+
+    Returns:
+      The integer parsed.
+
+    Raises:
+      ParseError: If an unsigned 64bit integer couldn't be consumed.
+    """
+    try:
+      result = self._ParseInteger(self.token, is_signed=False, is_long=True)
+    except ValueError, e:
+      raise self._IntegerParseError(e)
+    self.NextToken()
+    return result
+
+  def ConsumeFloat(self):
+    """Consumes a floating point number.
+
+    Returns:
+      The number parsed.
+
+    Raises:
+      ParseError: If a floating point number couldn't be consumed.
+    """
+    text = self.token
+    if re.match(self._FLOAT_INFINITY, text):
+      self.NextToken()
+      if text.startswith('-'):
+        return float('-inf')
+      return float('inf')
+
+    if re.match(self._FLOAT_NAN, text):
+      self.NextToken()
+      return float('nan')
+
+    try:
+      result = float(text)
+    except ValueError, e:
+      raise self._FloatParseError(e)
+    self.NextToken()
+    return result
+
+  def ConsumeBool(self):
+    """Consumes a boolean value.
+
+    Returns:
+      The bool parsed.
+
+    Raises:
+      ParseError: If a boolean value couldn't be consumed.
+    """
+    if self.token == 'true':
+      self.NextToken()
+      return True
+    elif self.token == 'false':
+      self.NextToken()
+      return False
+    else:
+      raise self._ParseError('Expected "true" or "false".')
+
+  def ConsumeString(self):
+    """Consumes a string value.
+
+    The unescaped bytes are decoded as UTF-8.
+
+    Returns:
+      The string parsed.
+
+    Raises:
+      ParseError: If a string value couldn't be consumed.
+    """
+    return unicode(self.ConsumeByteString(), 'utf-8')
+
+  def ConsumeByteString(self):
+    """Consumes a byte array value.
+
+    Returns:
+      The array parsed (as a string).
+
+    Raises:
+      ParseError: If a byte array value couldn't be consumed.
+    """
+    text = self.token
+    if len(text) < 1 or text[0] not in ('\'', '"'):
+      raise self._ParseError('Expected string.')
+
+    # The closing quote must be the same character as the opening one.
+    if len(text) < 2 or text[-1] != text[0]:
+      raise self._ParseError('String missing ending quote.')
+
+    try:
+      result = _CUnescape(text[1:-1])
+    except ValueError, e:
+      raise self._ParseError(str(e))
+    self.NextToken()
+    return result
+
+  def _ParseInteger(self, text, is_signed=False, is_long=False):
+    """Parses an integer.
+
+    A leading "0x"/"0X" selects base 16 and a leading "0" selects base 8
+    (after an optional '-'); anything else is parsed as base 10.
+
+    Args:
+      text: The text to parse.
+      is_signed: True if a signed integer must be parsed.
+      is_long: True if a long integer must be parsed.
+
+    Returns:
+      The integer value.
+
+    Raises:
+      ValueError: Thrown Iff the text is not a valid integer.
+    """
+    pos = 0
+    if text.startswith('-'):
+      pos += 1
+
+    base = 10
+    if text.startswith('0x', pos) or text.startswith('0X', pos):
+      base = 16
+    elif text.startswith('0', pos):
+      base = 8
+
+    # Do the actual parsing. Exception handling is propagated to caller.
+    result = int(text, base)
+
+    # Check if the integer is sane. Exceptions handled by callers.
+    checker = self._INTEGER_CHECKERS[2 * int(is_long) + int(is_signed)]
+    checker.CheckValue(result)
+    return result
+
+  def ParseErrorPreviousToken(self, message):
+    """Creates and *returns* a ParseError for the previously read token.
+
+    Args:
+      message: A message to set for the exception.
+
+    Returns:
+      A ParseError instance.
+    """
+    return ParseError('%d:%d : %s' % (
+        self._previous_line + 1, self._previous_column + 1, message))
+
+  def _ParseError(self, message):
+    """Creates and *returns* a ParseError for the current token."""
+    return ParseError('%d:%d : %s' % (
+        self._line + 1, self._column + 1, message))
+
+  def _IntegerParseError(self, e):
+    """Wraps an integer conversion failure in a ParseError (returned)."""
+    return self._ParseError('Couldn\'t parse integer: ' + str(e))
+
+  def _FloatParseError(self, e):
+    """Wraps a float conversion failure in a ParseError (returned)."""
+    return self._ParseError('Couldn\'t parse number: ' + str(e))
+
+  def NextToken(self):
+    """Reads the next meaningful token."""
+    self._previous_line = self._line
+    self._previous_column = self._column
+    # No text left at all: mark the end of the token stream. This tests the
+    # raw text rather than AtEnd(), which is about the current token.
+    if not self._lines and not self._current_line:
+      self.token = ''
+      return
+    self._column += len(self.token)
+
+    # Make sure there is data to work on.
+    self._PopLine()
+
+    match = re.match(self._TOKEN, self._current_line)
+    if match:
+      token = match.group(0)
+      self._current_line = self._current_line[len(token):]
+      self.token = token
+    else:
+      # Anything else (punctuation such as ':' or '{') is a one-character
+      # token.
+      self.token = self._current_line[0]
+      self._current_line = self._current_line[1:]
+    self._SkipWhitespace()
+
+
# text.encode('string_escape') does not seem to satisfy our needs as it
# encodes unprintable characters using two-digit hex escapes whereas our
# C++ unescaping function allows hex escapes to be any length. So,
@@ -123,3 +636,15 @@ def _CEscape(text):
if o >= 127 or o < 32: return "\\%03o" % o # necessary escapes
return c
return "".join([escape(c) for c in text])
+
+
+# Matches a run of backslashes followed by "x" and exactly one hex digit
+# (i.e. the digit is not followed by another hex digit).
+_CUNESCAPE_HEX = re.compile('(\\\\+)x([0-9a-fA-F])(?![0-9a-fA-F])')
+
+
+def _CUnescape(text):
+  """Decodes C-style backslash escapes in text.
+
+  Single-digit hex escapes are first padded to two digits, then the whole
+  string is decoded with the 'string_escape' codec, so every escape is
+  interpreted exactly once.
+
+  Args:
+    text: The escaped string, without its surrounding quotes.
+
+  Returns:
+    The unescaped string.
+
+  Raises:
+    ValueError: If 'string_escape' decoding rejects the text, e.g. for a
+      hex escape without digits or a trailing backslash.
+  """
+  def ReplaceHex(m):
+    # An even number of backslashes means the backslash before the "x" is
+    # itself escaped, so this is literal text and not a hex escape.
+    if len(m.group(1)) % 2 == 0:
+      return m.group(0)
+    return m.group(1) + 'x0' + m.group(2)
+  # This is required because the 'string_escape' encoding doesn't
+  # allow single-digit hex escapes (like '\xf').
+  result = _CUNESCAPE_HEX.sub(ReplaceHex, text)
+  return result.decode('string_escape')