Diffstat (limited to 'python/google')
-rwxr-xr-x  python/google/protobuf/__init__.py                            |   2
-rwxr-xr-x  python/google/protobuf/descriptor.py                          |   9
-rw-r--r--  python/google/protobuf/descriptor_database.py                 |   3
-rw-r--r--  python/google/protobuf/descriptor_pool.py                     |  19
-rwxr-xr-x  python/google/protobuf/internal/api_implementation.py         |  26
-rw-r--r--  python/google/protobuf/internal/descriptor_pool_test.py       |  46
-rwxr-xr-x  python/google/protobuf/internal/descriptor_test.py            |   6
-rwxr-xr-x  python/google/protobuf/internal/encoder.py                    | 107
-rw-r--r--  python/google/protobuf/internal/factory_test2.proto           |   5
-rw-r--r--  python/google/protobuf/internal/json_format_test.py           | 103
-rwxr-xr-x  python/google/protobuf/internal/message_test.py               |  95
-rw-r--r--  python/google/protobuf/internal/more_extensions_dynamic.proto |   1
-rwxr-xr-x  python/google/protobuf/internal/python_message.py             |  18
-rwxr-xr-x  python/google/protobuf/internal/text_format_test.py           | 151
-rw-r--r--  python/google/protobuf/internal/well_known_types.py           |   4
-rw-r--r--  python/google/protobuf/internal/well_known_types_test.py      |   2
-rw-r--r--  python/google/protobuf/json_format.py                         |  50
-rwxr-xr-x  python/google/protobuf/message.py                             |  16
-rw-r--r--  python/google/protobuf/message_factory.py                     |  14
-rw-r--r--  python/google/protobuf/pyext/descriptor.cc                    |  10
-rw-r--r--  python/google/protobuf/pyext/map_container.cc                 |  24
-rw-r--r--  python/google/protobuf/pyext/message.cc                       |  62
-rw-r--r--  python/google/protobuf/pyext/message_factory.cc               |  11
-rw-r--r--  python/google/protobuf/pyext/repeated_composite_container.cc  |  22
-rw-r--r--  python/google/protobuf/pyext/repeated_scalar_container.cc     |  14
-rw-r--r--  python/google/protobuf/symbol_database.py                     |  26
-rwxr-xr-x  python/google/protobuf/text_format.py                         | 107
27 files changed, 749 insertions, 204 deletions
diff --git a/python/google/protobuf/__init__.py b/python/google/protobuf/__init__.py
index 622dfb3d..d26da0df 100755
--- a/python/google/protobuf/__init__.py
+++ b/python/google/protobuf/__init__.py
@@ -30,7 +30,7 @@
# Copyright 2007 Google Inc. All Rights Reserved.
-__version__ = '3.3.2'
+__version__ = '3.4.0'
if __name__ != '__main__':
try:
diff --git a/python/google/protobuf/descriptor.py b/python/google/protobuf/descriptor.py
index e1f2e3b7..b1f3ca38 100755
--- a/python/google/protobuf/descriptor.py
+++ b/python/google/protobuf/descriptor.py
@@ -406,6 +406,8 @@ class FieldDescriptor(DescriptorBase):
containing_oneof: (OneofDescriptor) If the field is a member of a oneof
union, contains its descriptor. Otherwise, None.
+
+ file: (FileDescriptor) Reference to file descriptor.
"""
# Must be consistent with C++ FieldDescriptor::Type enum in
@@ -490,7 +492,8 @@ class FieldDescriptor(DescriptorBase):
def __new__(cls, name, full_name, index, number, type, cpp_type, label,
default_value, message_type, enum_type, containing_type,
is_extension, extension_scope, options=None,
- has_default_value=True, containing_oneof=None, json_name=None):
+ has_default_value=True, containing_oneof=None, json_name=None,
+ file=None):
_message.Message._CheckCalledFromGeneratedFile()
if is_extension:
return _message.default_pool.FindExtensionByName(full_name)
@@ -500,7 +503,8 @@ class FieldDescriptor(DescriptorBase):
def __init__(self, name, full_name, index, number, type, cpp_type, label,
default_value, message_type, enum_type, containing_type,
is_extension, extension_scope, options=None,
- has_default_value=True, containing_oneof=None, json_name=None):
+ has_default_value=True, containing_oneof=None, json_name=None,
+ file=None):
"""The arguments are as described in the description of FieldDescriptor
attributes above.
@@ -511,6 +515,7 @@ class FieldDescriptor(DescriptorBase):
super(FieldDescriptor, self).__init__(options, 'FieldOptions')
self.name = name
self.full_name = full_name
+ self.file = file
self._camelcase_name = None
if json_name is None:
self.json_name = _ToJsonName(name)
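
Note: a minimal sketch (not part of the patch) of reading the new FieldDescriptor.file
back-reference; the module and field names below are illustrative placeholders.

    # Assumes a generated module `my_pb2` with a message `MyMessage` that has a
    # field `my_field` (hypothetical names).
    from my_pb2 import MyMessage

    field_desc = MyMessage.DESCRIPTOR.fields_by_name['my_field']
    # The field now carries a reference to the FileDescriptor it was declared in,
    # matching what Descriptor and EnumDescriptor already expose.
    assert field_desc.file is MyMessage.DESCRIPTOR.file
    print(field_desc.file.name)   # e.g. 'my.proto'
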
diff --git a/python/google/protobuf/descriptor_database.py b/python/google/protobuf/descriptor_database.py
index 40bcdd72..eb45e127 100644
--- a/python/google/protobuf/descriptor_database.py
+++ b/python/google/protobuf/descriptor_database.py
@@ -134,8 +134,7 @@ def _ExtractSymbols(desc_proto, package):
Yields:
The fully qualified name found in the descriptor.
"""
-
- message_name = '.'.join((package, desc_proto.name))
+ message_name = package + '.' + desc_proto.name if package else desc_proto.name
yield message_name
for nested_type in desc_proto.nested_type:
for symbol in _ExtractSymbols(nested_type, message_name):
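
Note: a small sketch (not part of the patch) of the naming fix above; with an empty
package the old join produced a spurious leading dot.

    package, name = '', 'Msg'
    old_symbol = '.'.join((package, name))                  # '.Msg'
    new_symbol = package + '.' + name if package else name  # 'Msg'
    assert (old_symbol, new_symbol) == ('.Msg', 'Msg')
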
diff --git a/python/google/protobuf/descriptor_pool.py b/python/google/protobuf/descriptor_pool.py
index 29e9f1e3..3dbe0fd0 100644
--- a/python/google/protobuf/descriptor_pool.py
+++ b/python/google/protobuf/descriptor_pool.py
@@ -329,6 +329,11 @@ class DescriptorPool(object):
pass
try:
+ return self._service_descriptors[symbol].file
+ except KeyError:
+ pass
+
+ try:
return self._FindFileContainingSymbolInDb(symbol)
except KeyError:
pass
@@ -344,7 +349,6 @@ class DescriptorPool(object):
message = self.FindMessageTypeByName(message_name)
assert message.extensions_by_name[extension_name]
return message.file
-
except KeyError:
raise KeyError('Cannot find a file containing %s' % symbol)
@@ -557,7 +561,8 @@ class DescriptorPool(object):
for index, extension_proto in enumerate(file_proto.extension):
extension_desc = self._MakeFieldDescriptor(
- extension_proto, file_proto.package, index, is_extension=True)
+ extension_proto, file_proto.package, index, file_descriptor,
+ is_extension=True)
extension_desc.containing_type = self._GetTypeFromScope(
file_descriptor.package, extension_proto.extendee, scope)
self._SetFieldType(extension_proto, extension_desc,
@@ -623,10 +628,10 @@ class DescriptorPool(object):
enums = [
self._ConvertEnumDescriptor(enum, desc_name, file_desc, None, scope)
for enum in desc_proto.enum_type]
- fields = [self._MakeFieldDescriptor(field, desc_name, index)
+ fields = [self._MakeFieldDescriptor(field, desc_name, index, file_desc)
for index, field in enumerate(desc_proto.field)]
extensions = [
- self._MakeFieldDescriptor(extension, desc_name, index,
+ self._MakeFieldDescriptor(extension, desc_name, index, file_desc,
is_extension=True)
for index, extension in enumerate(desc_proto.extension)]
oneofs = [
@@ -708,7 +713,7 @@ class DescriptorPool(object):
return desc
def _MakeFieldDescriptor(self, field_proto, message_name, index,
- is_extension=False):
+ file_desc, is_extension=False):
"""Creates a field descriptor from a FieldDescriptorProto.
For message and enum type fields, this method will do a look up
@@ -721,6 +726,7 @@ class DescriptorPool(object):
field_proto: The proto describing the field.
message_name: The name of the containing message.
index: Index of the field
+ file_desc: The file containing the field descriptor.
is_extension: Indication that this field is for an extension.
Returns:
@@ -747,7 +753,8 @@ class DescriptorPool(object):
default_value=None,
is_extension=is_extension,
extension_scope=None,
- options=_OptionsOrNone(field_proto))
+ options=_OptionsOrNone(field_proto),
+ file=file_desc)
def _SetAllFieldTypes(self, package, desc_proto, scope):
"""Sets all the descriptor's fields's types.
diff --git a/python/google/protobuf/internal/api_implementation.py b/python/google/protobuf/internal/api_implementation.py
index 460a4a6c..422af590 100755
--- a/python/google/protobuf/internal/api_implementation.py
+++ b/python/google/protobuf/internal/api_implementation.py
@@ -100,6 +100,27 @@ if _implementation_version_str != '2':
_implementation_version = int(_implementation_version_str)
+# Detect if serialization should be deterministic by default
+try:
+ # The presence of this module in a build allows the proto implementation to
+ # be upgraded merely via build deps.
+ #
+ # NOTE: Merely importing this automatically enables deterministic proto
+ # serialization for C++ code, but we still need to export it as a boolean so
+ # that we can do the same for `_implementation_type == 'python'`.
+ #
+ # NOTE2: It is possible for C++ code to enable deterministic serialization by
+ # default _without_ affecting Python code, if the C++ implementation is not in
+ # use by this module. That is intended behavior, so we don't actually expose
+ # this boolean outside of this module.
+ #
+ # pylint: disable=g-import-not-at-top,unused-import
+ from google.protobuf import enable_deterministic_proto_serialization
+ _python_deterministic_proto_serialization = True
+except ImportError:
+ _python_deterministic_proto_serialization = False
+
+
# Usage of this function is discouraged. Clients shouldn't care which
# implementation of the API is in use. Note that there is no guarantee
# that differences between APIs will be maintained.
@@ -111,3 +132,8 @@ def Type():
# See comment on 'Type' above.
def Version():
return _implementation_version
+
+
+# For internal use only
+def IsPythonDefaultSerializationDeterministic():
+ return _python_deterministic_proto_serialization
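
Note: a minimal sketch (not part of the patch) of how the pure-Python serializer is
expected to consult this flag (mirroring the python_message.py change later in this
diff); resolve_deterministic is a hypothetical helper, not a protobuf API.

    from google.protobuf.internal import api_implementation

    def resolve_deterministic(deterministic=None):
        # None means "use the build-wide default"; anything else is coerced to
        # bool, which is where a misbehaving argument would raise (see the
        # BadArg test added to message_test.py below).
        if deterministic is None:
            return api_implementation.IsPythonDefaultSerializationDeterministic()
        return bool(deterministic)
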
diff --git a/python/google/protobuf/internal/descriptor_pool_test.py b/python/google/protobuf/internal/descriptor_pool_test.py
index c1733a48..6015e6f8 100644
--- a/python/google/protobuf/internal/descriptor_pool_test.py
+++ b/python/google/protobuf/internal/descriptor_pool_test.py
@@ -131,11 +131,19 @@ class DescriptorPoolTest(unittest.TestCase):
self.assertEqual('google/protobuf/internal/factory_test2.proto',
file_desc4.name)
+ file_desc5 = self.pool.FindFileContainingSymbol(
+ 'protobuf_unittest.TestService')
+ self.assertIsInstance(file_desc5, descriptor.FileDescriptor)
+ self.assertEqual('google/protobuf/unittest.proto',
+ file_desc5.name)
+
# Tests the generated pool.
assert descriptor_pool.Default().FindFileContainingSymbol(
'google.protobuf.python.internal.Factory2Message.one_more_field')
assert descriptor_pool.Default().FindFileContainingSymbol(
'google.protobuf.python.internal.another_field')
+ assert descriptor_pool.Default().FindFileContainingSymbol(
+ 'protobuf_unittest.TestService')
def testFindFileContainingSymbolFailure(self):
with self.assertRaises(KeyError):
@@ -506,10 +514,10 @@ class MessageType(object):
subtype.CheckType(test, desc, name, file_desc)
for index, (name, field) in enumerate(self.field_list):
- field.CheckField(test, desc, name, index)
+ field.CheckField(test, desc, name, index, file_desc)
for index, (name, field) in enumerate(self.extensions):
- field.CheckField(test, desc, name, index)
+ field.CheckField(test, desc, name, index, file_desc)
class EnumField(object):
@@ -519,7 +527,7 @@ class EnumField(object):
self.type_name = type_name
self.default_value = default_value
- def CheckField(self, test, msg_desc, name, index):
+ def CheckField(self, test, msg_desc, name, index, file_desc):
field_desc = msg_desc.fields_by_name[name]
enum_desc = msg_desc.enum_types_by_name[self.type_name]
test.assertEqual(name, field_desc.name)
@@ -536,6 +544,7 @@ class EnumField(object):
test.assertFalse(enum_desc.values_by_name[self.default_value].has_options)
test.assertEqual(msg_desc, field_desc.containing_type)
test.assertEqual(enum_desc, field_desc.enum_type)
+ test.assertEqual(file_desc, enum_desc.file)
class MessageField(object):
@@ -544,7 +553,7 @@ class MessageField(object):
self.number = number
self.type_name = type_name
- def CheckField(self, test, msg_desc, name, index):
+ def CheckField(self, test, msg_desc, name, index, file_desc):
field_desc = msg_desc.fields_by_name[name]
field_type_desc = msg_desc.nested_types_by_name[self.type_name]
test.assertEqual(name, field_desc.name)
@@ -558,6 +567,7 @@ class MessageField(object):
test.assertFalse(field_desc.has_default_value)
test.assertEqual(msg_desc, field_desc.containing_type)
test.assertEqual(field_type_desc, field_desc.message_type)
+ test.assertEqual(file_desc, field_desc.file)
class StringField(object):
@@ -566,7 +576,7 @@ class StringField(object):
self.number = number
self.default_value = default_value
- def CheckField(self, test, msg_desc, name, index):
+ def CheckField(self, test, msg_desc, name, index, file_desc):
field_desc = msg_desc.fields_by_name[name]
test.assertEqual(name, field_desc.name)
expected_field_full_name = '.'.join([msg_desc.full_name, name])
@@ -578,6 +588,7 @@ class StringField(object):
field_desc.cpp_type)
test.assertTrue(field_desc.has_default_value)
test.assertEqual(self.default_value, field_desc.default_value)
+ test.assertEqual(file_desc, field_desc.file)
class ExtensionField(object):
@@ -586,7 +597,7 @@ class ExtensionField(object):
self.number = number
self.extended_type = extended_type
- def CheckField(self, test, msg_desc, name, index):
+ def CheckField(self, test, msg_desc, name, index, file_desc):
field_desc = msg_desc.extensions_by_name[name]
test.assertEqual(name, field_desc.name)
expected_field_full_name = '.'.join([msg_desc.full_name, name])
@@ -601,6 +612,7 @@ class ExtensionField(object):
test.assertEqual(msg_desc, field_desc.extension_scope)
test.assertEqual(msg_desc, field_desc.message_type)
test.assertEqual(self.extended_type, field_desc.containing_type.name)
+ test.assertEqual(file_desc, field_desc.file)
class AddDescriptorTest(unittest.TestCase):
@@ -746,15 +758,10 @@ class AddDescriptorTest(unittest.TestCase):
self.assertIs(options, file_descriptor.GetOptions())
-@unittest.skipIf(
- api_implementation.Type() != 'cpp',
- 'default_pool is only supported by the C++ implementation')
class DefaultPoolTest(unittest.TestCase):
def testFindMethods(self):
- # pylint: disable=g-import-not-at-top
- from google.protobuf.pyext import _message
- pool = _message.default_pool
+ pool = descriptor_pool.Default()
self.assertIs(
pool.FindFileByName('google/protobuf/unittest.proto'),
unittest_pb2.DESCRIPTOR)
@@ -765,19 +772,22 @@ class DefaultPoolTest(unittest.TestCase):
pool.FindFieldByName('protobuf_unittest.TestAllTypes.optional_int32'),
unittest_pb2.TestAllTypes.DESCRIPTOR.fields_by_name['optional_int32'])
self.assertIs(
- pool.FindExtensionByName('protobuf_unittest.optional_int32_extension'),
- unittest_pb2.DESCRIPTOR.extensions_by_name['optional_int32_extension'])
- self.assertIs(
pool.FindEnumTypeByName('protobuf_unittest.ForeignEnum'),
unittest_pb2.ForeignEnum.DESCRIPTOR)
+ if api_implementation.Type() != 'cpp':
+ self.skipTest('Only the C++ implementation correctly indexes all types')
+ self.assertIs(
+ pool.FindExtensionByName('protobuf_unittest.optional_int32_extension'),
+ unittest_pb2.DESCRIPTOR.extensions_by_name['optional_int32_extension'])
self.assertIs(
pool.FindOneofByName('protobuf_unittest.TestAllTypes.oneof_field'),
unittest_pb2.TestAllTypes.DESCRIPTOR.oneofs_by_name['oneof_field'])
+ self.assertIs(
+ pool.FindServiceByName('protobuf_unittest.TestService'),
+ unittest_pb2.DESCRIPTOR.services_by_name['TestService'])
def testAddFileDescriptor(self):
- # pylint: disable=g-import-not-at-top
- from google.protobuf.pyext import _message
- pool = _message.default_pool
+ pool = descriptor_pool.Default()
file_desc = descriptor_pb2.FileDescriptorProto(name='some/file.proto')
pool.Add(file_desc)
pool.AddSerializedFile(file_desc.SerializeToString())
diff --git a/python/google/protobuf/internal/descriptor_test.py b/python/google/protobuf/internal/descriptor_test.py
index 1f148ab9..c0010081 100755
--- a/python/google/protobuf/internal/descriptor_test.py
+++ b/python/google/protobuf/internal/descriptor_test.py
@@ -521,6 +521,12 @@ class GeneratedDescriptorTest(unittest.TestCase):
del enum
self.assertEqual('FOO', next(values_iter).name)
+ def testServiceDescriptor(self):
+ service_descriptor = unittest_pb2.DESCRIPTOR.services_by_name['TestService']
+ self.assertEqual(service_descriptor.name, 'TestService')
+ self.assertEqual(service_descriptor.methods[0].name, 'Foo')
+ self.assertIs(service_descriptor.file, unittest_pb2.DESCRIPTOR)
+
class DescriptorCopyToProtoTest(unittest.TestCase):
"""Tests for CopyTo functions of Descriptor."""
diff --git a/python/google/protobuf/internal/encoder.py b/python/google/protobuf/internal/encoder.py
index 80e59cab..ebec42e5 100755
--- a/python/google/protobuf/internal/encoder.py
+++ b/python/google/protobuf/internal/encoder.py
@@ -372,7 +372,7 @@ def MapSizer(field_descriptor, is_message_map):
def _VarintEncoder():
"""Return an encoder for a basic varint value (does not include tag)."""
- def EncodeVarint(write, value):
+ def EncodeVarint(write, value, unused_deterministic):
bits = value & 0x7f
value >>= 7
while value:
@@ -388,7 +388,7 @@ def _SignedVarintEncoder():
"""Return an encoder for a basic signed varint value (does not include
tag)."""
- def EncodeSignedVarint(write, value):
+ def EncodeSignedVarint(write, value, unused_deterministic):
if value < 0:
value += (1 << 64)
bits = value & 0x7f
@@ -411,7 +411,7 @@ def _VarintBytes(value):
called at startup time so it doesn't need to be fast."""
pieces = []
- _EncodeVarint(pieces.append, value)
+ _EncodeVarint(pieces.append, value, True)
return b"".join(pieces)
@@ -440,27 +440,27 @@ def _SimpleEncoder(wire_type, encode_value, compute_value_size):
if is_packed:
tag_bytes = TagBytes(field_number, wire_format.WIRETYPE_LENGTH_DELIMITED)
local_EncodeVarint = _EncodeVarint
- def EncodePackedField(write, value):
+ def EncodePackedField(write, value, deterministic):
write(tag_bytes)
size = 0
for element in value:
size += compute_value_size(element)
- local_EncodeVarint(write, size)
+ local_EncodeVarint(write, size, deterministic)
for element in value:
- encode_value(write, element)
+ encode_value(write, element, deterministic)
return EncodePackedField
elif is_repeated:
tag_bytes = TagBytes(field_number, wire_type)
- def EncodeRepeatedField(write, value):
+ def EncodeRepeatedField(write, value, deterministic):
for element in value:
write(tag_bytes)
- encode_value(write, element)
+ encode_value(write, element, deterministic)
return EncodeRepeatedField
else:
tag_bytes = TagBytes(field_number, wire_type)
- def EncodeField(write, value):
+ def EncodeField(write, value, deterministic):
write(tag_bytes)
- return encode_value(write, value)
+ return encode_value(write, value, deterministic)
return EncodeField
return SpecificEncoder
@@ -474,27 +474,27 @@ def _ModifiedEncoder(wire_type, encode_value, compute_value_size, modify_value):
if is_packed:
tag_bytes = TagBytes(field_number, wire_format.WIRETYPE_LENGTH_DELIMITED)
local_EncodeVarint = _EncodeVarint
- def EncodePackedField(write, value):
+ def EncodePackedField(write, value, deterministic):
write(tag_bytes)
size = 0
for element in value:
size += compute_value_size(modify_value(element))
- local_EncodeVarint(write, size)
+ local_EncodeVarint(write, size, deterministic)
for element in value:
- encode_value(write, modify_value(element))
+ encode_value(write, modify_value(element), deterministic)
return EncodePackedField
elif is_repeated:
tag_bytes = TagBytes(field_number, wire_type)
- def EncodeRepeatedField(write, value):
+ def EncodeRepeatedField(write, value, deterministic):
for element in value:
write(tag_bytes)
- encode_value(write, modify_value(element))
+ encode_value(write, modify_value(element), deterministic)
return EncodeRepeatedField
else:
tag_bytes = TagBytes(field_number, wire_type)
- def EncodeField(write, value):
+ def EncodeField(write, value, deterministic):
write(tag_bytes)
- return encode_value(write, modify_value(value))
+ return encode_value(write, modify_value(value), deterministic)
return EncodeField
return SpecificEncoder
@@ -515,22 +515,22 @@ def _StructPackEncoder(wire_type, format):
if is_packed:
tag_bytes = TagBytes(field_number, wire_format.WIRETYPE_LENGTH_DELIMITED)
local_EncodeVarint = _EncodeVarint
- def EncodePackedField(write, value):
+ def EncodePackedField(write, value, deterministic):
write(tag_bytes)
- local_EncodeVarint(write, len(value) * value_size)
+ local_EncodeVarint(write, len(value) * value_size, deterministic)
for element in value:
write(local_struct_pack(format, element))
return EncodePackedField
elif is_repeated:
tag_bytes = TagBytes(field_number, wire_type)
- def EncodeRepeatedField(write, value):
+ def EncodeRepeatedField(write, value, unused_deterministic):
for element in value:
write(tag_bytes)
write(local_struct_pack(format, element))
return EncodeRepeatedField
else:
tag_bytes = TagBytes(field_number, wire_type)
- def EncodeField(write, value):
+ def EncodeField(write, value, unused_deterministic):
write(tag_bytes)
return write(local_struct_pack(format, value))
return EncodeField
@@ -581,9 +581,9 @@ def _FloatingPointEncoder(wire_type, format):
if is_packed:
tag_bytes = TagBytes(field_number, wire_format.WIRETYPE_LENGTH_DELIMITED)
local_EncodeVarint = _EncodeVarint
- def EncodePackedField(write, value):
+ def EncodePackedField(write, value, deterministic):
write(tag_bytes)
- local_EncodeVarint(write, len(value) * value_size)
+ local_EncodeVarint(write, len(value) * value_size, deterministic)
for element in value:
# This try/except block is going to be faster than any code that
# we could write to check whether element is finite.
@@ -594,7 +594,7 @@ def _FloatingPointEncoder(wire_type, format):
return EncodePackedField
elif is_repeated:
tag_bytes = TagBytes(field_number, wire_type)
- def EncodeRepeatedField(write, value):
+ def EncodeRepeatedField(write, value, unused_deterministic):
for element in value:
write(tag_bytes)
try:
@@ -604,7 +604,7 @@ def _FloatingPointEncoder(wire_type, format):
return EncodeRepeatedField
else:
tag_bytes = TagBytes(field_number, wire_type)
- def EncodeField(write, value):
+ def EncodeField(write, value, unused_deterministic):
write(tag_bytes)
try:
write(local_struct_pack(format, value))
@@ -650,9 +650,9 @@ def BoolEncoder(field_number, is_repeated, is_packed):
if is_packed:
tag_bytes = TagBytes(field_number, wire_format.WIRETYPE_LENGTH_DELIMITED)
local_EncodeVarint = _EncodeVarint
- def EncodePackedField(write, value):
+ def EncodePackedField(write, value, deterministic):
write(tag_bytes)
- local_EncodeVarint(write, len(value))
+ local_EncodeVarint(write, len(value), deterministic)
for element in value:
if element:
write(true_byte)
@@ -661,7 +661,7 @@ def BoolEncoder(field_number, is_repeated, is_packed):
return EncodePackedField
elif is_repeated:
tag_bytes = TagBytes(field_number, wire_format.WIRETYPE_VARINT)
- def EncodeRepeatedField(write, value):
+ def EncodeRepeatedField(write, value, unused_deterministic):
for element in value:
write(tag_bytes)
if element:
@@ -671,7 +671,7 @@ def BoolEncoder(field_number, is_repeated, is_packed):
return EncodeRepeatedField
else:
tag_bytes = TagBytes(field_number, wire_format.WIRETYPE_VARINT)
- def EncodeField(write, value):
+ def EncodeField(write, value, unused_deterministic):
write(tag_bytes)
if value:
return write(true_byte)
@@ -687,18 +687,18 @@ def StringEncoder(field_number, is_repeated, is_packed):
local_len = len
assert not is_packed
if is_repeated:
- def EncodeRepeatedField(write, value):
+ def EncodeRepeatedField(write, value, deterministic):
for element in value:
encoded = element.encode('utf-8')
write(tag)
- local_EncodeVarint(write, local_len(encoded))
+ local_EncodeVarint(write, local_len(encoded), deterministic)
write(encoded)
return EncodeRepeatedField
else:
- def EncodeField(write, value):
+ def EncodeField(write, value, deterministic):
encoded = value.encode('utf-8')
write(tag)
- local_EncodeVarint(write, local_len(encoded))
+ local_EncodeVarint(write, local_len(encoded), deterministic)
return write(encoded)
return EncodeField
@@ -711,16 +711,16 @@ def BytesEncoder(field_number, is_repeated, is_packed):
local_len = len
assert not is_packed
if is_repeated:
- def EncodeRepeatedField(write, value):
+ def EncodeRepeatedField(write, value, deterministic):
for element in value:
write(tag)
- local_EncodeVarint(write, local_len(element))
+ local_EncodeVarint(write, local_len(element), deterministic)
write(element)
return EncodeRepeatedField
else:
- def EncodeField(write, value):
+ def EncodeField(write, value, deterministic):
write(tag)
- local_EncodeVarint(write, local_len(value))
+ local_EncodeVarint(write, local_len(value), deterministic)
return write(value)
return EncodeField
@@ -732,16 +732,16 @@ def GroupEncoder(field_number, is_repeated, is_packed):
end_tag = TagBytes(field_number, wire_format.WIRETYPE_END_GROUP)
assert not is_packed
if is_repeated:
- def EncodeRepeatedField(write, value):
+ def EncodeRepeatedField(write, value, deterministic):
for element in value:
write(start_tag)
- element._InternalSerialize(write)
+ element._InternalSerialize(write, deterministic)
write(end_tag)
return EncodeRepeatedField
else:
- def EncodeField(write, value):
+ def EncodeField(write, value, deterministic):
write(start_tag)
- value._InternalSerialize(write)
+ value._InternalSerialize(write, deterministic)
return write(end_tag)
return EncodeField
@@ -753,17 +753,17 @@ def MessageEncoder(field_number, is_repeated, is_packed):
local_EncodeVarint = _EncodeVarint
assert not is_packed
if is_repeated:
- def EncodeRepeatedField(write, value):
+ def EncodeRepeatedField(write, value, deterministic):
for element in value:
write(tag)
- local_EncodeVarint(write, element.ByteSize())
- element._InternalSerialize(write)
+ local_EncodeVarint(write, element.ByteSize(), deterministic)
+ element._InternalSerialize(write, deterministic)
return EncodeRepeatedField
else:
- def EncodeField(write, value):
+ def EncodeField(write, value, deterministic):
write(tag)
- local_EncodeVarint(write, value.ByteSize())
- return value._InternalSerialize(write)
+ local_EncodeVarint(write, value.ByteSize(), deterministic)
+ return value._InternalSerialize(write, deterministic)
return EncodeField
@@ -790,10 +790,10 @@ def MessageSetItemEncoder(field_number):
end_bytes = TagBytes(1, wire_format.WIRETYPE_END_GROUP)
local_EncodeVarint = _EncodeVarint
- def EncodeField(write, value):
+ def EncodeField(write, value, deterministic):
write(start_bytes)
- local_EncodeVarint(write, value.ByteSize())
- value._InternalSerialize(write)
+ local_EncodeVarint(write, value.ByteSize(), deterministic)
+ value._InternalSerialize(write, deterministic)
return write(end_bytes)
return EncodeField
@@ -818,9 +818,10 @@ def MapEncoder(field_descriptor):
message_type = field_descriptor.message_type
encode_message = MessageEncoder(field_descriptor.number, False, False)
- def EncodeField(write, value):
- for key in value:
+ def EncodeField(write, value, deterministic):
+ value_keys = sorted(value.keys()) if deterministic else value.keys()
+ for key in value_keys:
entry_msg = message_type._concrete_class(key=key, value=value[key])
- encode_message(write, entry_msg)
+ encode_message(write, entry_msg, deterministic)
return EncodeField
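
Note: a short end-to-end sketch (not part of the patch) of the deterministic map
serialization wired up above; map_unittest_pb2.TestMap is the message the tests in
this diff use.

    from google.protobuf import map_unittest_pb2

    msg = map_unittest_pb2.TestMap()
    msg.map_string_string['b'] = '2'
    msg.map_string_string['a'] = '1'

    # With deterministic=True the map entries are written in sorted key order,
    # so repeated serializations of equal messages are byte-for-byte identical.
    assert (msg.SerializeToString(deterministic=True) ==
            msg.SerializeToString(deterministic=True))
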
diff --git a/python/google/protobuf/internal/factory_test2.proto b/python/google/protobuf/internal/factory_test2.proto
index bb1b54ad..5fcbc5ac 100644
--- a/python/google/protobuf/internal/factory_test2.proto
+++ b/python/google/protobuf/internal/factory_test2.proto
@@ -97,3 +97,8 @@ message MessageWithNestedEnumOnly {
extend Factory1Message {
optional string another_field = 1002;
}
+
+message MessageWithOption {
+ option no_standard_descriptor_accessor = true;
+ optional int32 field1 = 1;
+}
diff --git a/python/google/protobuf/internal/json_format_test.py b/python/google/protobuf/internal/json_format_test.py
index 5ed65622..077b64db 100644
--- a/python/google/protobuf/internal/json_format_test.py
+++ b/python/google/protobuf/internal/json_format_test.py
@@ -49,6 +49,7 @@ from google.protobuf import field_mask_pb2
from google.protobuf import struct_pb2
from google.protobuf import timestamp_pb2
from google.protobuf import wrappers_pb2
+from google.protobuf import unittest_mset_pb2
from google.protobuf.internal import well_known_types
from google.protobuf import json_format
from google.protobuf.util import json_format_proto3_pb2
@@ -158,6 +159,84 @@ class JsonFormatTest(JsonFormatBase):
json_format.Parse(text, parsed_message)
self.assertEqual(message, parsed_message)
+ def testExtensionToJsonAndBack(self):
+ message = unittest_mset_pb2.TestMessageSetContainer()
+ ext1 = unittest_mset_pb2.TestMessageSetExtension1.message_set_extension
+ ext2 = unittest_mset_pb2.TestMessageSetExtension2.message_set_extension
+ message.message_set.Extensions[ext1].i = 23
+ message.message_set.Extensions[ext2].str = 'foo'
+ message_text = json_format.MessageToJson(
+ message
+ )
+ parsed_message = unittest_mset_pb2.TestMessageSetContainer()
+ json_format.Parse(message_text, parsed_message)
+ self.assertEqual(message, parsed_message)
+
+ def testExtensionToDictAndBack(self):
+ message = unittest_mset_pb2.TestMessageSetContainer()
+ ext1 = unittest_mset_pb2.TestMessageSetExtension1.message_set_extension
+ ext2 = unittest_mset_pb2.TestMessageSetExtension2.message_set_extension
+ message.message_set.Extensions[ext1].i = 23
+ message.message_set.Extensions[ext2].str = 'foo'
+ message_dict = json_format.MessageToDict(
+ message
+ )
+ parsed_message = unittest_mset_pb2.TestMessageSetContainer()
+ json_format.ParseDict(message_dict, parsed_message)
+ self.assertEqual(message, parsed_message)
+
+ def testExtensionSerializationDictMatchesProto3Spec(self):
+ """See go/proto3-json-spec for spec.
+ """
+ message = unittest_mset_pb2.TestMessageSetContainer()
+ ext1 = unittest_mset_pb2.TestMessageSetExtension1.message_set_extension
+ ext2 = unittest_mset_pb2.TestMessageSetExtension2.message_set_extension
+ message.message_set.Extensions[ext1].i = 23
+ message.message_set.Extensions[ext2].str = 'foo'
+ message_dict = json_format.MessageToDict(
+ message
+ )
+ golden_dict = {
+ 'messageSet': {
+ '[protobuf_unittest.'
+ 'TestMessageSetExtension1.messageSetExtension]': {
+ 'i': 23,
+ },
+ '[protobuf_unittest.'
+ 'TestMessageSetExtension2.messageSetExtension]': {
+ 'str': u'foo',
+ },
+ },
+ }
+ self.assertEqual(golden_dict, message_dict)
+
+
+ def testExtensionSerializationJsonMatchesProto3Spec(self):
+ """See go/proto3-json-spec for spec.
+ """
+ message = unittest_mset_pb2.TestMessageSetContainer()
+ ext1 = unittest_mset_pb2.TestMessageSetExtension1.message_set_extension
+ ext2 = unittest_mset_pb2.TestMessageSetExtension2.message_set_extension
+ message.message_set.Extensions[ext1].i = 23
+ message.message_set.Extensions[ext2].str = 'foo'
+ message_text = json_format.MessageToJson(
+ message
+ )
+ ext1_text = ('protobuf_unittest.TestMessageSetExtension1.'
+ 'messageSetExtension')
+ ext2_text = ('protobuf_unittest.TestMessageSetExtension2.'
+ 'messageSetExtension')
+ golden_text = ('{"messageSet": {'
+ ' "[%s]": {'
+ ' "i": 23'
+ ' },'
+ ' "[%s]": {'
+ ' "str": "foo"'
+ ' }'
+ '}}') % (ext1_text, ext2_text)
+ self.assertEqual(json.loads(golden_text), json.loads(message_text))
+
+
def testJsonEscapeString(self):
message = json_format_proto3_pb2.TestMessage()
if sys.version_info[0] < 3:
@@ -768,7 +847,7 @@ class JsonFormatTest(JsonFormatBase):
text = '{"value": "0000-01-01T00:00:00Z"}'
self.assertRaisesRegexp(
json_format.ParseError,
- 'Failed to parse value field: year is out of range.',
+ 'Failed to parse value field: year (0 )?is out of range.',
json_format.Parse, text, message)
# Time bigger than maxinum time.
message.value.seconds = 253402300800
@@ -840,6 +919,12 @@ class JsonFormatTest(JsonFormatBase):
json_format.Parse('{"int32_value": 12345}', message)
self.assertEqual(12345, message.int32_value)
+ def testIndent(self):
+ message = json_format_proto3_pb2.TestMessage()
+ message.int32_value = 12345
+ self.assertEqual('{\n"int32Value": 12345\n}',
+ json_format.MessageToJson(message, indent=0))
+
def testParseDict(self):
expected = 12345
js_dict = {'int32Value': expected}
@@ -862,6 +947,22 @@ class JsonFormatTest(JsonFormatBase):
parsed_message = json_format_proto3_pb2.TestCustomJsonName()
self.CheckParseBack(message, parsed_message)
+ def testSortKeys(self):
+ # Testing sort_keys is not perfectly working, as by random luck we could
+ # get the output sorted. We just use a selection of names.
+ message = json_format_proto3_pb2.TestMessage(bool_value=True,
+ int32_value=1,
+ int64_value=3,
+ uint32_value=4,
+ string_value='bla')
+ self.assertEqual(
+ json_format.MessageToJson(message, sort_keys=True),
+ # We use json.dumps() instead of a hardcoded string due to differences
+ # between Python 2 and Python 3.
+ json.dumps({'boolValue': True, 'int32Value': 1, 'int64Value': '3',
+ 'uint32Value': 4, 'stringValue': 'bla'},
+ indent=2, sort_keys=True))
+
if __name__ == '__main__':
unittest.main()
diff --git a/python/google/protobuf/internal/message_test.py b/python/google/protobuf/internal/message_test.py
index 3373b21f..dda72cdd 100755
--- a/python/google/protobuf/internal/message_test.py
+++ b/python/google/protobuf/internal/message_test.py
@@ -140,6 +140,42 @@ class MessageTest(BaseTestCase):
golden_copy = copy.deepcopy(golden_message)
self.assertEqual(golden_data, golden_copy.SerializeToString())
+ def testDeterminismParameters(self, message_module):
+ # This message is always deterministically serialized, even if determinism
+ # is disabled, so we can use it to verify that all the determinism
+ # parameters work correctly.
+ golden_data = (b'\xe2\x02\nOne string'
+ b'\xe2\x02\nTwo string'
+ b'\xe2\x02\nRed string'
+ b'\xe2\x02\x0bBlue string')
+ golden_message = message_module.TestAllTypes()
+ golden_message.repeated_string.extend([
+ 'One string',
+ 'Two string',
+ 'Red string',
+ 'Blue string',
+ ])
+ self.assertEqual(golden_data,
+ golden_message.SerializeToString(deterministic=None))
+ self.assertEqual(golden_data,
+ golden_message.SerializeToString(deterministic=False))
+ self.assertEqual(golden_data,
+ golden_message.SerializeToString(deterministic=True))
+
+ class BadArgError(Exception):
+ pass
+
+ class BadArg(object):
+
+ def __nonzero__(self):
+ raise BadArgError()
+
+ def __bool__(self):
+ raise BadArgError()
+
+ with self.assertRaises(BadArgError):
+ golden_message.SerializeToString(deterministic=BadArg())
+
def testPickleSupport(self, message_module):
golden_data = test_util.GoldenFileData('golden_message')
golden_message = message_module.TestAllTypes()
@@ -381,6 +417,7 @@ class MessageTest(BaseTestCase):
self.assertEqual(message.repeated_int32[0], 1)
self.assertEqual(message.repeated_int32[1], 2)
self.assertEqual(message.repeated_int32[2], 3)
+ self.assertEqual(str(message.repeated_int32), str([1, 2, 3]))
message.repeated_float.append(1.1)
message.repeated_float.append(1.3)
@@ -397,6 +434,7 @@ class MessageTest(BaseTestCase):
self.assertEqual(message.repeated_string[0], 'a')
self.assertEqual(message.repeated_string[1], 'b')
self.assertEqual(message.repeated_string[2], 'c')
+ self.assertEqual(str(message.repeated_string), str([u'a', u'b', u'c']))
message.repeated_bytes.append(b'a')
message.repeated_bytes.append(b'c')
@@ -405,6 +443,7 @@ class MessageTest(BaseTestCase):
self.assertEqual(message.repeated_bytes[0], b'a')
self.assertEqual(message.repeated_bytes[1], b'b')
self.assertEqual(message.repeated_bytes[2], b'c')
+ self.assertEqual(str(message.repeated_bytes), str([b'a', b'b', b'c']))
def testSortingRepeatedScalarFieldsCustomComparator(self, message_module):
"""Check some different types with custom comparator."""
@@ -443,6 +482,8 @@ class MessageTest(BaseTestCase):
self.assertEqual(message.repeated_nested_message[3].bb, 4)
self.assertEqual(message.repeated_nested_message[4].bb, 5)
self.assertEqual(message.repeated_nested_message[5].bb, 6)
+ self.assertEqual(str(message.repeated_nested_message),
+ '[bb: 1\n, bb: 2\n, bb: 3\n, bb: 4\n, bb: 5\n, bb: 6\n]')
def testSortingRepeatedCompositeFieldsStable(self, message_module):
"""Check passing a custom comparator to sort a repeated composite field."""
@@ -1270,6 +1311,14 @@ class Proto3Test(BaseTestCase):
self.assertEqual(1234567, m2.optional_nested_enum)
self.assertEqual(7654321, m2.repeated_nested_enum[0])
+ # ParseFromString in Proto2 should accept unknown enums too.
+ m3 = unittest_pb2.TestAllTypes()
+ m3.ParseFromString(serialized)
+ m2.Clear()
+ m2.ParseFromString(m3.SerializeToString())
+ self.assertEqual(1234567, m2.optional_nested_enum)
+ self.assertEqual(7654321, m2.repeated_nested_enum[0])
+
# Map isn't really a proto3-only feature. But there is no proto2 equivalent
# of google/protobuf/map_unittest.proto right now, so it's not easy to
# test both with the same test like we do for the other proto2/proto3 tests.
@@ -1441,6 +1490,23 @@ class Proto3Test(BaseTestCase):
self.assertIn(-456, msg2.map_int32_foreign_message)
self.assertEqual(2, len(msg2.map_int32_foreign_message))
+ def testNestedMessageMapItemDelete(self):
+ msg = map_unittest_pb2.TestMap()
+ msg.map_int32_all_types[1].optional_nested_message.bb = 1
+ del msg.map_int32_all_types[1]
+ msg.map_int32_all_types[2].optional_nested_message.bb = 2
+ self.assertEqual(1, len(msg.map_int32_all_types))
+ msg.map_int32_all_types[1].optional_nested_message.bb = 1
+ self.assertEqual(2, len(msg.map_int32_all_types))
+
+ serialized = msg.SerializeToString()
+ msg2 = map_unittest_pb2.TestMap()
+ msg2.ParseFromString(serialized)
+ keys = [1, 2]
+ # The loop triggers PyErr_Occurred() in c extension.
+ for key in keys:
+ del msg2.map_int32_all_types[key]
+
def testMapByteSize(self):
msg = map_unittest_pb2.TestMap()
msg.map_int32_int32[1] = 1
@@ -1655,6 +1721,35 @@ class Proto3Test(BaseTestCase):
items2 = msg.map_string_string.items()
self.assertEqual(items1, items2)
+ def testMapDeterministicSerialization(self):
+ golden_data = (b'r\x0c\n\x07init_op\x12\x01d'
+ b'r\n\n\x05item1\x12\x01e'
+ b'r\n\n\x05item2\x12\x01f'
+ b'r\n\n\x05item3\x12\x01g'
+ b'r\x0b\n\x05item4\x12\x02QQ'
+ b'r\x12\n\rlocal_init_op\x12\x01a'
+ b'r\x0e\n\tsummaries\x12\x01e'
+ b'r\x18\n\x13trainable_variables\x12\x01b'
+ b'r\x0e\n\tvariables\x12\x01c')
+ msg = map_unittest_pb2.TestMap()
+ msg.map_string_string['local_init_op'] = 'a'
+ msg.map_string_string['trainable_variables'] = 'b'
+ msg.map_string_string['variables'] = 'c'
+ msg.map_string_string['init_op'] = 'd'
+ msg.map_string_string['summaries'] = 'e'
+ msg.map_string_string['item1'] = 'e'
+ msg.map_string_string['item2'] = 'f'
+ msg.map_string_string['item3'] = 'g'
+ msg.map_string_string['item4'] = 'QQ'
+
+ # If deterministic serialization is not working correctly, this will be
+ # "flaky" depending on the exact python dict hash seed.
+ #
+ # Fortunately, there are enough items in this map that it is extremely
+ # unlikely to ever hit the "right" in-order combination, so the test
+ # itself should fail reliably.
+ self.assertEqual(golden_data, msg.SerializeToString(deterministic=True))
+
def testMapIterationClearMessage(self):
# Iterator needs to work even if message and map are deleted.
msg = map_unittest_pb2.TestMap()
diff --git a/python/google/protobuf/internal/more_extensions_dynamic.proto b/python/google/protobuf/internal/more_extensions_dynamic.proto
index 11f85ef6..98fcbcb6 100644
--- a/python/google/protobuf/internal/more_extensions_dynamic.proto
+++ b/python/google/protobuf/internal/more_extensions_dynamic.proto
@@ -47,4 +47,5 @@ message DynamicMessageType {
extend ExtendedMessage {
optional int32 dynamic_int32_extension = 100;
optional DynamicMessageType dynamic_message_extension = 101;
+ repeated DynamicMessageType repeated_dynamic_message_extension = 102;
}
diff --git a/python/google/protobuf/internal/python_message.py b/python/google/protobuf/internal/python_message.py
index cb97cb28..c363d843 100755
--- a/python/google/protobuf/internal/python_message.py
+++ b/python/google/protobuf/internal/python_message.py
@@ -58,6 +58,7 @@ import weakref
import six
# We use "as" to avoid name collisions with variables.
+from google.protobuf.internal import api_implementation
from google.protobuf.internal import containers
from google.protobuf.internal import decoder
from google.protobuf.internal import encoder
@@ -1026,29 +1027,34 @@ def _AddByteSizeMethod(message_descriptor, cls):
def _AddSerializeToStringMethod(message_descriptor, cls):
"""Helper for _AddMessageMethods()."""
- def SerializeToString(self):
+ def SerializeToString(self, **kwargs):
# Check if the message has all of its required fields set.
errors = []
if not self.IsInitialized():
raise message_mod.EncodeError(
'Message %s is missing required fields: %s' % (
self.DESCRIPTOR.full_name, ','.join(self.FindInitializationErrors())))
- return self.SerializePartialToString()
+ return self.SerializePartialToString(**kwargs)
cls.SerializeToString = SerializeToString
def _AddSerializePartialToStringMethod(message_descriptor, cls):
"""Helper for _AddMessageMethods()."""
- def SerializePartialToString(self):
+ def SerializePartialToString(self, **kwargs):
out = BytesIO()
- self._InternalSerialize(out.write)
+ self._InternalSerialize(out.write, **kwargs)
return out.getvalue()
cls.SerializePartialToString = SerializePartialToString
- def InternalSerialize(self, write_bytes):
+ def InternalSerialize(self, write_bytes, deterministic=None):
+ if deterministic is None:
+ deterministic = (
+ api_implementation.IsPythonDefaultSerializationDeterministic())
+ else:
+ deterministic = bool(deterministic)
for field_descriptor, field_value in self.ListFields():
- field_descriptor._encoder(write_bytes, field_value)
+ field_descriptor._encoder(write_bytes, field_value, deterministic)
for tag_bytes, value_bytes in self._unknown_fields:
write_bytes(tag_bytes)
write_bytes(value_bytes)
diff --git a/python/google/protobuf/internal/text_format_test.py b/python/google/protobuf/internal/text_format_test.py
index 188310b2..424b29cc 100755
--- a/python/google/protobuf/internal/text_format_test.py
+++ b/python/google/protobuf/internal/text_format_test.py
@@ -452,16 +452,18 @@ class TextFormatTest(TextFormatBase):
text_format.Parse(text_format.MessageToString(m), m2)
self.assertEqual('oneof_uint32', m2.WhichOneof('oneof_field'))
+ def testMergeMultipleOneof(self, message_module):
+ m_string = '\n'.join(['oneof_uint32: 11', 'oneof_string: "foo"'])
+ m2 = message_module.TestAllTypes()
+ text_format.Merge(m_string, m2)
+ self.assertEqual('oneof_string', m2.WhichOneof('oneof_field'))
+
def testParseMultipleOneof(self, message_module):
m_string = '\n'.join(['oneof_uint32: 11', 'oneof_string: "foo"'])
m2 = message_module.TestAllTypes()
- if message_module is unittest_pb2:
- with self.assertRaisesRegexp(text_format.ParseError,
- ' is specified along with field '):
- text_format.Parse(m_string, m2)
- else:
+ with self.assertRaisesRegexp(text_format.ParseError,
+ ' is specified along with field '):
text_format.Parse(m_string, m2)
- self.assertEqual('oneof_string', m2.WhichOneof('oneof_field'))
# These are tests that aren't fundamentally specific to proto2, but are at
@@ -1026,8 +1028,7 @@ class Proto3Tests(unittest.TestCase):
packed_message.data = 'string1'
message.repeated_any_value.add().Pack(packed_message)
self.assertEqual(
- text_format.MessageToString(message,
- descriptor_pool=descriptor_pool.Default()),
+ text_format.MessageToString(message),
'repeated_any_value {\n'
' [type.googleapis.com/protobuf_unittest.OneString] {\n'
' data: "string0"\n'
@@ -1039,18 +1040,6 @@ class Proto3Tests(unittest.TestCase):
' }\n'
'}\n')
- def testPrintMessageExpandAnyNoDescriptorPool(self):
- packed_message = unittest_pb2.OneString()
- packed_message.data = 'string'
- message = any_test_pb2.TestAny()
- message.any_value.Pack(packed_message)
- self.assertEqual(
- text_format.MessageToString(message, descriptor_pool=None),
- 'any_value {\n'
- ' type_url: "type.googleapis.com/protobuf_unittest.OneString"\n'
- ' value: "\\n\\006string"\n'
- '}\n')
-
def testPrintMessageExpandAnyDescriptorPoolMissingType(self):
packed_message = unittest_pb2.OneString()
packed_message.data = 'string'
@@ -1071,8 +1060,7 @@ class Proto3Tests(unittest.TestCase):
message.any_value.Pack(packed_message)
self.assertEqual(
text_format.MessageToString(message,
- pointy_brackets=True,
- descriptor_pool=descriptor_pool.Default()),
+ pointy_brackets=True),
'any_value <\n'
' [type.googleapis.com/protobuf_unittest.OneString] <\n'
' data: "string"\n'
@@ -1086,8 +1074,7 @@ class Proto3Tests(unittest.TestCase):
message.any_value.Pack(packed_message)
self.assertEqual(
text_format.MessageToString(message,
- as_one_line=True,
- descriptor_pool=descriptor_pool.Default()),
+ as_one_line=True),
'any_value {'
' [type.googleapis.com/protobuf_unittest.OneString]'
' { data: "string" } '
@@ -1115,12 +1102,12 @@ class Proto3Tests(unittest.TestCase):
' data: "string"\n'
' }\n'
'}\n')
- text_format.Merge(text, message, descriptor_pool=descriptor_pool.Default())
+ text_format.Merge(text, message)
packed_message = unittest_pb2.OneString()
message.any_value.Unpack(packed_message)
self.assertEqual('string', packed_message.data)
message.Clear()
- text_format.Parse(text, message, descriptor_pool=descriptor_pool.Default())
+ text_format.Parse(text, message)
packed_message = unittest_pb2.OneString()
message.any_value.Unpack(packed_message)
self.assertEqual('string', packed_message.data)
@@ -1137,7 +1124,7 @@ class Proto3Tests(unittest.TestCase):
' data: "string1"\n'
' }\n'
'}\n')
- text_format.Merge(text, message, descriptor_pool=descriptor_pool.Default())
+ text_format.Merge(text, message)
packed_message = unittest_pb2.OneString()
message.repeated_any_value[0].Unpack(packed_message)
self.assertEqual('string0', packed_message.data)
@@ -1151,22 +1138,22 @@ class Proto3Tests(unittest.TestCase):
' data: "string"\n'
' >\n'
'}\n')
- text_format.Merge(text, message, descriptor_pool=descriptor_pool.Default())
+ text_format.Merge(text, message)
packed_message = unittest_pb2.OneString()
message.any_value.Unpack(packed_message)
self.assertEqual('string', packed_message.data)
- def testMergeExpandedAnyNoDescriptorPool(self):
+ def testMergeAlternativeUrl(self):
message = any_test_pb2.TestAny()
text = ('any_value {\n'
- ' [type.googleapis.com/protobuf_unittest.OneString] {\n'
+ ' [type.otherapi.com/protobuf_unittest.OneString] {\n'
' data: "string"\n'
' }\n'
'}\n')
- with self.assertRaises(text_format.ParseError) as e:
- text_format.Merge(text, message, descriptor_pool=None)
- self.assertEqual(str(e.exception),
- 'Descriptor pool required to parse expanded Any field')
+ text_format.Merge(text, message)
+ packed_message = unittest_pb2.OneString()
+ self.assertEqual('type.otherapi.com/protobuf_unittest.OneString',
+ message.any_value.type_url)
def testMergeExpandedAnyDescriptorPoolMissingType(self):
message = any_test_pb2.TestAny()
@@ -1425,5 +1412,101 @@ class TokenizerTest(unittest.TestCase):
tokenizer.ConsumeCommentOrTrailingComment())
self.assertTrue(tokenizer.AtEnd())
+
+# Tests for pretty printer functionality.
+@_parameterized.Parameters((unittest_pb2), (unittest_proto3_arena_pb2))
+class PrettyPrinterTest(TextFormatBase):
+
+ def testPrettyPrintNoMatch(self, message_module):
+
+ def printer(message, indent, as_one_line):
+ del message, indent, as_one_line
+ return None
+
+ message = message_module.TestAllTypes()
+ msg = message.repeated_nested_message.add()
+ msg.bb = 42
+ self.CompareToGoldenText(
+ text_format.MessageToString(
+ message, as_one_line=True, message_formatter=printer),
+ 'repeated_nested_message { bb: 42 }')
+
+ def testPrettyPrintOneLine(self, message_module):
+
+ def printer(m, indent, as_one_line):
+ del indent, as_one_line
+ if m.DESCRIPTOR == message_module.TestAllTypes.NestedMessage.DESCRIPTOR:
+ return 'My lucky number is %s' % m.bb
+
+ message = message_module.TestAllTypes()
+ msg = message.repeated_nested_message.add()
+ msg.bb = 42
+ self.CompareToGoldenText(
+ text_format.MessageToString(
+ message, as_one_line=True, message_formatter=printer),
+ 'repeated_nested_message { My lucky number is 42 }')
+
+ def testPrettyPrintMultiLine(self, message_module):
+
+ def printer(m, indent, as_one_line):
+ if m.DESCRIPTOR == message_module.TestAllTypes.NestedMessage.DESCRIPTOR:
+ line_deliminator = (' ' if as_one_line else '\n') + ' ' * indent
+ return 'My lucky number is:%s%s' % (line_deliminator, m.bb)
+ return None
+
+ message = message_module.TestAllTypes()
+ msg = message.repeated_nested_message.add()
+ msg.bb = 42
+ self.CompareToGoldenText(
+ text_format.MessageToString(
+ message, as_one_line=True, message_formatter=printer),
+ 'repeated_nested_message { My lucky number is: 42 }')
+ self.CompareToGoldenText(
+ text_format.MessageToString(
+ message, as_one_line=False, message_formatter=printer),
+ 'repeated_nested_message {\n My lucky number is:\n 42\n}\n')
+
+ def testPrettyPrintEntireMessage(self, message_module):
+
+ def printer(m, indent, as_one_line):
+ del indent, as_one_line
+ if m.DESCRIPTOR == message_module.TestAllTypes.DESCRIPTOR:
+ return 'The is the message!'
+ return None
+
+ message = message_module.TestAllTypes()
+ self.CompareToGoldenText(
+ text_format.MessageToString(
+ message, as_one_line=False, message_formatter=printer),
+ 'The is the message!\n')
+ self.CompareToGoldenText(
+ text_format.MessageToString(
+ message, as_one_line=True, message_formatter=printer),
+ 'The is the message!')
+
+ def testPrettyPrintMultipleParts(self, message_module):
+
+ def printer(m, indent, as_one_line):
+ del indent, as_one_line
+ if m.DESCRIPTOR == message_module.TestAllTypes.NestedMessage.DESCRIPTOR:
+ return 'My lucky number is %s' % m.bb
+ return None
+
+ message = message_module.TestAllTypes()
+ message.optional_int32 = 61
+ msg = message.repeated_nested_message.add()
+ msg.bb = 42
+ msg = message.repeated_nested_message.add()
+ msg.bb = 99
+ msg = message.optional_nested_message
+ msg.bb = 1
+ self.CompareToGoldenText(
+ text_format.MessageToString(
+ message, as_one_line=True, message_formatter=printer),
+ ('optional_int32: 61 '
+ 'optional_nested_message { My lucky number is 1 } '
+ 'repeated_nested_message { My lucky number is 42 } '
+ 'repeated_nested_message { My lucky number is 99 }'))
+
if __name__ == '__main__':
unittest.main()
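
Note: a hedged usage sketch (not part of the patch) of the message_formatter hook the
tests above exercise; the formatter receives (message, indent, as_one_line) and returns
either a replacement string or None to fall back to the default printer.

    from google.protobuf import text_format
    from google.protobuf import unittest_pb2

    def formatter(message, indent, as_one_line):
        # Custom-print NestedMessage; return None for every other message type.
        if message.DESCRIPTOR is unittest_pb2.TestAllTypes.NestedMessage.DESCRIPTOR:
            return 'bb=%d' % message.bb
        return None

    msg = unittest_pb2.TestAllTypes()
    msg.optional_nested_message.bb = 7
    print(text_format.MessageToString(msg, message_formatter=formatter))
    # optional_nested_message {
    #   bb=7
    # }
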
diff --git a/python/google/protobuf/internal/well_known_types.py b/python/google/protobuf/internal/well_known_types.py
index d631abee..d0c7ffda 100644
--- a/python/google/protobuf/internal/well_known_types.py
+++ b/python/google/protobuf/internal/well_known_types.py
@@ -350,12 +350,12 @@ class Duration(object):
self.nanos, _NANOS_PER_MICROSECOND))
def FromTimedelta(self, td):
- """Convertd timedelta to Duration."""
+ """Converts timedelta to Duration."""
self._NormalizeDuration(td.seconds + td.days * _SECONDS_PER_DAY,
td.microseconds * _NANOS_PER_MICROSECOND)
def _NormalizeDuration(self, seconds, nanos):
- """Set Duration by seconds and nonas."""
+ """Set Duration by seconds and nanos."""
# Force nanos to be negative if the duration is negative.
if seconds < 0 and nanos > 0:
seconds += 1
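
Note: a small usage example (not part of the patch) of the Duration helper whose
docstring is corrected above.

    import datetime
    from google.protobuf import duration_pb2

    d = duration_pb2.Duration()
    d.FromTimedelta(datetime.timedelta(days=1, microseconds=250))
    assert (d.seconds, d.nanos) == (86400, 250000)
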
diff --git a/python/google/protobuf/internal/well_known_types_test.py b/python/google/protobuf/internal/well_known_types_test.py
index 077f630f..123a537c 100644
--- a/python/google/protobuf/internal/well_known_types_test.py
+++ b/python/google/protobuf/internal/well_known_types_test.py
@@ -284,7 +284,7 @@ class TimeUtilTest(TimeUtilTestBase):
'1972-01-01T01:00:00.01+08',)
self.assertRaisesRegexp(
ValueError,
- 'year is out of range',
+ 'year (0 )?is out of range',
message.FromJsonString,
'0000-01-01T00:00:00Z')
message.seconds = 253402300800
diff --git a/python/google/protobuf/json_format.py b/python/google/protobuf/json_format.py
index 937285b0..801eed60 100644
--- a/python/google/protobuf/json_format.py
+++ b/python/google/protobuf/json_format.py
@@ -74,6 +74,9 @@ _UNPAIRED_SURROGATE_PATTERN = re.compile(six.u(
r'[\ud800-\udbff](?![\udc00-\udfff])|(?<![\ud800-\udbff])[\udc00-\udfff]'
))
+_VALID_EXTENSION_NAME = re.compile(r'\[[a-zA-Z0-9\._]*\]$')
+
+
class Error(Exception):
"""Top-level module error for json_format."""
@@ -88,7 +91,9 @@ class ParseError(Error):
def MessageToJson(message,
including_default_value_fields=False,
- preserving_proto_field_name=False):
+ preserving_proto_field_name=False,
+ indent=2,
+ sort_keys=False):
"""Converts protobuf message to JSON format.
Args:
@@ -100,19 +105,24 @@ def MessageToJson(message,
preserving_proto_field_name: If True, use the original proto field
names as defined in the .proto file. If False, convert the field
names to lowerCamelCase.
+ indent: The JSON object will be pretty-printed with this indent level.
+ An indent level of 0 or negative will only insert newlines.
+ sort_keys: If True, then the output will be sorted by field names.
Returns:
A string containing the JSON formatted protocol buffer message.
"""
printer = _Printer(including_default_value_fields,
preserving_proto_field_name)
- return printer.ToJsonString(message)
+ return printer.ToJsonString(message, indent, sort_keys)
def MessageToDict(message,
including_default_value_fields=False,
preserving_proto_field_name=False):
- """Converts protobuf message to a JSON dictionary.
+ """Converts protobuf message to a dictionary.
+
+ When the dictionary is encoded to JSON, it conforms to proto3 JSON spec.
Args:
message: The protocol buffers message instance to serialize.
@@ -125,7 +135,7 @@ def MessageToDict(message,
names to lowerCamelCase.
Returns:
- A dict representation of the JSON formatted protocol buffer message.
+ A dict representation of the protocol buffer message.
"""
printer = _Printer(including_default_value_fields,
preserving_proto_field_name)
@@ -148,9 +158,9 @@ class _Printer(object):
self.including_default_value_fields = including_default_value_fields
self.preserving_proto_field_name = preserving_proto_field_name
- def ToJsonString(self, message):
+ def ToJsonString(self, message, indent, sort_keys):
js = self._MessageToJsonObject(message)
- return json.dumps(js, indent=2)
+ return json.dumps(js, indent=indent, sort_keys=sort_keys)
def _MessageToJsonObject(self, message):
"""Converts message to an object according to Proto3 JSON Specification."""
@@ -192,6 +202,14 @@ class _Printer(object):
# Convert a repeated field.
js[name] = [self._FieldToJsonObject(field, k)
for k in value]
+ elif field.is_extension:
+ f = field
+ if (f.containing_type.GetOptions().message_set_wire_format and
+ f.type == descriptor.FieldDescriptor.TYPE_MESSAGE and
+ f.label == descriptor.FieldDescriptor.LABEL_OPTIONAL):
+ f = f.message_type
+ name = '[%s.%s]' % (f.full_name, name)
+ js[name] = self._FieldToJsonObject(field, value)
else:
js[name] = self._FieldToJsonObject(field, value)
@@ -433,12 +451,23 @@ class _Parser(object):
field = fields_by_json_name.get(name, None)
if not field:
field = message_descriptor.fields_by_name.get(name, None)
+ if not field and _VALID_EXTENSION_NAME.match(name):
+ if not message_descriptor.is_extendable:
+ raise ParseError('Message type {0} does not have extensions'.format(
+ message_descriptor.full_name))
+ identifier = name[1:-1] # strip [] brackets
+ identifier = '.'.join(identifier.split('.')[:-1])
+ # pylint: disable=protected-access
+ field = message.Extensions._FindExtensionByName(identifier)
+ # pylint: enable=protected-access
if not field:
if self.ignore_unknown_fields:
continue
raise ParseError(
- 'Message type "{0}" has no field named "{1}".'.format(
- message_descriptor.full_name, name))
+ ('Message type "{0}" has no field named "{1}".\n'
+ ' Available Fields(except extensions): {2}').format(
+ message_descriptor.full_name, name,
+ message_descriptor.fields))
if name in names:
raise ParseError('Message type "{0}" should not have multiple '
'"{1}" fields.'.format(
@@ -491,7 +520,10 @@ class _Parser(object):
getattr(message, field.name).append(
_ConvertScalarFieldValue(item, field))
elif field.cpp_type == descriptor.FieldDescriptor.CPPTYPE_MESSAGE:
- sub_message = getattr(message, field.name)
+ if field.is_extension:
+ sub_message = message.Extensions[field]
+ else:
+ sub_message = getattr(message, field.name)
sub_message.SetInParent()
self.ConvertMessage(value, sub_message)
else:
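
Note: a short usage sketch (not part of the patch) of the new indent and sort_keys
keyword arguments; json_format_proto3_pb2.TestMessage is the message used by the
tests earlier in this diff.

    from google.protobuf import json_format
    from google.protobuf.util import json_format_proto3_pb2

    msg = json_format_proto3_pb2.TestMessage(int32_value=1, bool_value=True)
    # indent=0 only inserts newlines; sort_keys orders the output by JSON field name.
    print(json_format.MessageToJson(msg, indent=0, sort_keys=True))
    # {
    # "boolValue": true,
    # "int32Value": 1
    # }
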
diff --git a/python/google/protobuf/message.py b/python/google/protobuf/message.py
index aab250e4..eeb0d576 100755
--- a/python/google/protobuf/message.py
+++ b/python/google/protobuf/message.py
@@ -184,9 +184,15 @@ class Message(object):
self.Clear()
self.MergeFromString(serialized)
- def SerializeToString(self):
+ def SerializeToString(self, **kwargs):
"""Serializes the protocol message to a binary string.
+ Arguments:
+ **kwargs: Keyword arguments to the serialize method, accepts
+ the following keyword args:
+ deterministic: If true, requests deterministic serialization of the
+ protobuf, with predictable ordering of map keys.
+
Returns:
A binary string representation of the message if all of the required
fields in the message are set (i.e. the message is initialized).
@@ -196,12 +202,18 @@ class Message(object):
"""
raise NotImplementedError
- def SerializePartialToString(self):
+ def SerializePartialToString(self, **kwargs):
"""Serializes the protocol message to a binary string.
This method is similar to SerializeToString but doesn't check if the
message is initialized.
+ Arguments:
+ **kwargs: Keyword arguments to the serialize method, accepts
+ the following keyword args:
+ deterministic: If true, requests deterministic serialization of the
+ protobuf, with predictable ordering of map keys.
+
Returns:
A string representation of the partial message.
"""
diff --git a/python/google/protobuf/message_factory.py b/python/google/protobuf/message_factory.py
index 8ab1c513..15740280 100644
--- a/python/google/protobuf/message_factory.py
+++ b/python/google/protobuf/message_factory.py
@@ -66,7 +66,7 @@ class MessageFactory(object):
Returns:
A class describing the passed in descriptor.
"""
- if descriptor.full_name not in self._classes:
+ if descriptor not in self._classes:
descriptor_name = descriptor.name
if str is bytes: # PY2
descriptor_name = descriptor.name.encode('ascii', 'ignore')
@@ -75,16 +75,16 @@ class MessageFactory(object):
(message.Message,),
{'DESCRIPTOR': descriptor, '__module__': None})
# If module not set, it wrongly points to the reflection.py module.
- self._classes[descriptor.full_name] = result_class
+ self._classes[descriptor] = result_class
for field in descriptor.fields:
if field.message_type:
self.GetPrototype(field.message_type)
for extension in result_class.DESCRIPTOR.extensions:
- if extension.containing_type.full_name not in self._classes:
+ if extension.containing_type not in self._classes:
self.GetPrototype(extension.containing_type)
- extended_class = self._classes[extension.containing_type.full_name]
+ extended_class = self._classes[extension.containing_type]
extended_class.RegisterExtension(extension)
- return self._classes[descriptor.full_name]
+ return self._classes[descriptor]
def GetMessages(self, files):
"""Gets all the messages from a specified file.
@@ -116,9 +116,9 @@ class MessageFactory(object):
# an error if they were different.
for extension in file_desc.extensions_by_name.values():
- if extension.containing_type.full_name not in self._classes:
+ if extension.containing_type not in self._classes:
self.GetPrototype(extension.containing_type)
- extended_class = self._classes[extension.containing_type.full_name]
+ extended_class = self._classes[extension.containing_type]
extended_class.RegisterExtension(extension)
return result
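The class cache is now keyed by descriptor object rather than by full name, so descriptors that happen to share a full name in different pools no longer collide. A self-contained sketch using a private pool (all names illustrative):

    from google.protobuf import descriptor_pb2, descriptor_pool, message_factory

    # Build a tiny descriptor in a private pool.
    pool = descriptor_pool.DescriptorPool()
    file_proto = descriptor_pb2.FileDescriptorProto()
    file_proto.name = 'example.proto'
    file_proto.package = 'example'
    msg_proto = file_proto.message_type.add()
    msg_proto.name = 'Foo'
    field = msg_proto.field.add()
    field.name = 'value'
    field.number = 1
    field.type = descriptor_pb2.FieldDescriptorProto.TYPE_INT32
    field.label = descriptor_pb2.FieldDescriptorProto.LABEL_OPTIONAL
    pool.Add(file_proto)

    factory = message_factory.MessageFactory(pool)
    desc = pool.FindMessageTypeByName('example.Foo')
    # Repeated lookups for the same descriptor object return the same class.
    assert factory.GetPrototype(desc) is factory.GetPrototype(desc)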
diff --git a/python/google/protobuf/pyext/descriptor.cc b/python/google/protobuf/pyext/descriptor.cc
index f13e1bc1..9634ea05 100644
--- a/python/google/protobuf/pyext/descriptor.cc
+++ b/python/google/protobuf/pyext/descriptor.cc
@@ -709,6 +709,10 @@ static PyObject* GetJsonName(PyBaseDescriptor* self, void *closure) {
return PyString_FromCppString(_GetDescriptor(self)->json_name());
}
+static PyObject* GetFile(PyBaseDescriptor *self, void *closure) {
+ return PyFileDescriptor_FromDescriptor(_GetDescriptor(self)->file());
+}
+
static PyObject* GetType(PyBaseDescriptor *self, void *closure) {
return PyInt_FromLong(_GetDescriptor(self)->type());
}
@@ -899,6 +903,7 @@ static PyGetSetDef Getters[] = {
{ "name", (getter)GetName, NULL, "Unqualified name"},
{ "camelcase_name", (getter)GetCamelcaseName, NULL, "Camelcase name"},
{ "json_name", (getter)GetJsonName, NULL, "Json name"},
+ { "file", (getter)GetFile, NULL, "File Descriptor"},
{ "type", (getter)GetType, NULL, "C++ Type"},
{ "cpp_type", (getter)GetCppType, NULL, "C++ Type"},
{ "label", (getter)GetLabel, NULL, "Label"},
@@ -1570,6 +1575,10 @@ static PyObject* GetFullName(PyBaseDescriptor* self, void *closure) {
return PyString_FromCppString(_GetDescriptor(self)->full_name());
}
+static PyObject* GetFile(PyBaseDescriptor *self, void *closure) {
+ return PyFileDescriptor_FromDescriptor(_GetDescriptor(self)->file());
+}
+
static PyObject* GetIndex(PyBaseDescriptor *self, void *closure) {
return PyInt_FromLong(_GetDescriptor(self)->index());
}
@@ -1611,6 +1620,7 @@ static PyObject* CopyToProto(PyBaseDescriptor *self, PyObject *target) {
static PyGetSetDef Getters[] = {
{ "name", (getter)GetName, NULL, "Name", NULL},
{ "full_name", (getter)GetFullName, NULL, "Full name", NULL},
+ { "file", (getter)GetFile, NULL, "File descriptor"},
{ "index", (getter)GetIndex, NULL, "Index", NULL},
{ "methods", (getter)GetMethods, NULL, "Methods", NULL},
diff --git a/python/google/protobuf/pyext/map_container.cc b/python/google/protobuf/pyext/map_container.cc
index 088ddf93..43be0701 100644
--- a/python/google/protobuf/pyext/map_container.cc
+++ b/python/google/protobuf/pyext/map_container.cc
@@ -712,8 +712,30 @@ int MapReflectionFriend::MessageMapSetItem(PyObject* _self, PyObject* key,
}
// Delete key from map.
- if (reflection->DeleteMapValue(message, self->parent_field_descriptor,
+ if (reflection->ContainsMapKey(*message, self->parent_field_descriptor,
map_key)) {
+ // Delete key from CMessage dict.
+ MapValueRef value;
+ reflection->InsertOrLookupMapValue(message, self->parent_field_descriptor,
+ map_key, &value);
+ ScopedPyObjectPtr key(PyLong_FromVoidPtr(value.MutableMessageValue()));
+
+    // PyDict_DelItem raises a KeyError if the key is not in the dict. We do
+    // not want to call PyErr_Clear(), which may clear other errors, so the
+    // PyDict_Contains() check is done before deleting.
+ int contains = PyDict_Contains(self->message_dict, key.get());
+ if (contains < 0) {
+ return -1;
+ }
+ if (contains) {
+ if (PyDict_DelItem(self->message_dict, key.get()) < 0) {
+ return -1;
+ }
+ }
+
+ // Delete key from map.
+ reflection->DeleteMapValue(message, self->parent_field_descriptor,
+ map_key);
return 0;
} else {
PyErr_Format(PyExc_KeyError, "Key not present in map");
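The change above keeps the cached Python wrapper in sync when a message-valued map entry is deleted in the C++ implementation. A hedged sketch, assuming a hypothetical message with `map<string, Bar> items = 1;`:

    msg = foo_pb2.Foo()          # hypothetical generated class
    msg.items['a'].value = 1     # creates the entry and its Python wrapper

    # Deleting the entry now also drops the cached wrapper object.
    del msg.items['a']

    try:
        del msg.items['missing']
    except KeyError:
        pass                     # absent keys still raise KeyError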
diff --git a/python/google/protobuf/pyext/message.cc b/python/google/protobuf/pyext/message.cc
index 49113c7c..702c5d03 100644
--- a/python/google/protobuf/pyext/message.cc
+++ b/python/google/protobuf/pyext/message.cc
@@ -52,6 +52,7 @@
#include <google/protobuf/stubs/common.h>
#include <google/protobuf/stubs/logging.h>
#include <google/protobuf/io/coded_stream.h>
+#include <google/protobuf/io/zero_copy_stream_impl_lite.h>
#include <google/protobuf/util/message_differencer.h>
#include <google/protobuf/descriptor.h>
#include <google/protobuf/message.h>
@@ -1810,8 +1811,25 @@ static string GetMessageName(CMessage* self) {
}
}
-static PyObject* SerializeToString(CMessage* self, PyObject* args) {
- if (!self->message->IsInitialized()) {
+static PyObject* InternalSerializeToString(
+ CMessage* self, PyObject* args, PyObject* kwargs,
+ bool require_initialized) {
+ // Parse the "deterministic" kwarg; defaults to False.
+ static char* kwlist[] = { "deterministic", 0 };
+ PyObject* deterministic_obj = Py_None;
+ if (!PyArg_ParseTupleAndKeywords(args, kwargs, "|O", kwlist,
+ &deterministic_obj)) {
+ return NULL;
+ }
+ // Preemptively convert to a bool first, so we don't need to back out of
+ // allocating memory if this raises an exception.
+ // NOTE: This is unused later if deterministic == Py_None, but that's fine.
+ int deterministic = PyObject_IsTrue(deterministic_obj);
+ if (deterministic < 0) {
+ return NULL;
+ }
+
+ if (require_initialized && !self->message->IsInitialized()) {
ScopedPyObjectPtr errors(FindInitializationErrors(self));
if (errors == NULL) {
return NULL;
@@ -1849,24 +1867,36 @@ static PyObject* SerializeToString(CMessage* self, PyObject* args) {
GetMessageName(self).c_str(), PyString_AsString(joined.get()));
return NULL;
}
- int size = self->message->ByteSize();
- if (size <= 0) {
+
+ // Ok, arguments parsed and errors checked, now encode to a string
+ const size_t size = self->message->ByteSizeLong();
+ if (size == 0) {
return PyBytes_FromString("");
}
PyObject* result = PyBytes_FromStringAndSize(NULL, size);
if (result == NULL) {
return NULL;
}
- char* buffer = PyBytes_AS_STRING(result);
- self->message->SerializeWithCachedSizesToArray(
- reinterpret_cast<uint8*>(buffer));
+ io::ArrayOutputStream out(PyBytes_AS_STRING(result), size);
+ io::CodedOutputStream coded_out(&out);
+ if (deterministic_obj != Py_None) {
+ coded_out.SetSerializationDeterministic(deterministic);
+ }
+ self->message->SerializeWithCachedSizes(&coded_out);
+ GOOGLE_CHECK(!coded_out.HadError());
return result;
}
-static PyObject* SerializePartialToString(CMessage* self) {
- string contents;
- self->message->SerializePartialToString(&contents);
- return PyBytes_FromStringAndSize(contents.c_str(), contents.size());
+static PyObject* SerializeToString(
+ CMessage* self, PyObject* args, PyObject* kwargs) {
+ return InternalSerializeToString(self, args, kwargs,
+ /*require_initialized=*/true);
+}
+
+static PyObject* SerializePartialToString(
+ CMessage* self, PyObject* args, PyObject* kwargs) {
+ return InternalSerializeToString(self, args, kwargs,
+ /*require_initialized=*/false);
}
// Formats proto fields for ascii dumps using python formatting functions where
@@ -2537,7 +2567,10 @@ PyObject* Reduce(CMessage* self) {
if (state == NULL) {
return NULL;
}
- ScopedPyObjectPtr serialized(SerializePartialToString(self));
+ string contents;
+ self->message->SerializePartialToString(&contents);
+ ScopedPyObjectPtr serialized(
+ PyBytes_FromStringAndSize(contents.c_str(), contents.size()));
if (serialized == NULL) {
return NULL;
}
@@ -2658,9 +2691,10 @@ static PyMethodDef Methods[] = {
{ "RegisterExtension", (PyCFunction)RegisterExtension, METH_O | METH_CLASS,
"Registers an extension with the current message." },
{ "SerializePartialToString", (PyCFunction)SerializePartialToString,
- METH_NOARGS,
+ METH_VARARGS | METH_KEYWORDS,
"Serializes the message to a string, even if it isn't initialized." },
- { "SerializeToString", (PyCFunction)SerializeToString, METH_NOARGS,
+ { "SerializeToString", (PyCFunction)SerializeToString,
+ METH_VARARGS | METH_KEYWORDS,
"Serializes the message to a string, only for initialized messages." },
{ "SetInParent", (PyCFunction)SetInParent, METH_NOARGS,
"Sets the has bit of the given field in its parent message." },
diff --git a/python/google/protobuf/pyext/message_factory.cc b/python/google/protobuf/pyext/message_factory.cc
index e0b45bf2..571bae2b 100644
--- a/python/google/protobuf/pyext/message_factory.cc
+++ b/python/google/protobuf/pyext/message_factory.cc
@@ -133,11 +133,7 @@ int RegisterMessageClass(PyMessageFactory* self,
CMessageClass* GetOrCreateMessageClass(PyMessageFactory* self,
const Descriptor* descriptor) {
// This is the same implementation as MessageFactory.GetPrototype().
- ScopedPyObjectPtr py_descriptor(
- PyMessageDescriptor_FromDescriptor(descriptor));
- if (py_descriptor == NULL) {
- return NULL;
- }
+
// Do not create a MessageClass that already exists.
hash_map<const Descriptor*, CMessageClass*>::iterator it =
self->classes_by_descriptor->find(descriptor);
@@ -145,6 +141,11 @@ CMessageClass* GetOrCreateMessageClass(PyMessageFactory* self,
Py_INCREF(it->second);
return it->second;
}
+ ScopedPyObjectPtr py_descriptor(
+ PyMessageDescriptor_FromDescriptor(descriptor));
+ if (py_descriptor == NULL) {
+ return NULL;
+ }
// Create a new message class.
ScopedPyObjectPtr args(Py_BuildValue(
"s(){sOsOsO}", descriptor->name().c_str(),
diff --git a/python/google/protobuf/pyext/repeated_composite_container.cc b/python/google/protobuf/pyext/repeated_composite_container.cc
index 06a976de..5ad71db5 100644
--- a/python/google/protobuf/pyext/repeated_composite_container.cc
+++ b/python/google/protobuf/pyext/repeated_composite_container.cc
@@ -46,6 +46,7 @@
#include <google/protobuf/pyext/descriptor.h>
#include <google/protobuf/pyext/descriptor_pool.h>
#include <google/protobuf/pyext/message.h>
+#include <google/protobuf/pyext/message_factory.h>
#include <google/protobuf/pyext/scoped_pyobject_ptr.h>
#include <google/protobuf/reflection.h>
@@ -137,9 +138,12 @@ static PyObject* AddToAttached(RepeatedCompositeContainer* self,
if (cmessage::AssureWritable(self->parent) == -1)
return NULL;
Message* message = self->message;
+
Message* sub_message =
- message->GetReflection()->AddMessage(message,
- self->parent_field_descriptor);
+ message->GetReflection()->AddMessage(
+ message,
+ self->parent_field_descriptor,
+ self->child_message_class->py_message_factory->message_factory);
CMessage* cmsg = cmessage::NewEmptyMessage(self->child_message_class);
if (cmsg == NULL)
return NULL;
@@ -336,6 +340,18 @@ static PyObject* RichCompare(RepeatedCompositeContainer* self,
}
}
+static PyObject* ToStr(RepeatedCompositeContainer* self) {
+ ScopedPyObjectPtr full_slice(PySlice_New(NULL, NULL, NULL));
+ if (full_slice == NULL) {
+ return NULL;
+ }
+ ScopedPyObjectPtr list(Subscript(self, full_slice.get()));
+ if (list == NULL) {
+ return NULL;
+ }
+ return PyObject_Repr(list.get());
+}
+
// ---------------------------------------------------------------------
// sort()
@@ -608,7 +624,7 @@ PyTypeObject RepeatedCompositeContainer_Type = {
0, // tp_getattr
0, // tp_setattr
0, // tp_compare
- 0, // tp_repr
+ (reprfunc)repeated_composite_container::ToStr, // tp_repr
0, // tp_as_number
&repeated_composite_container::SqMethods, // tp_as_sequence
&repeated_composite_container::MpMethods, // tp_as_mapping
diff --git a/python/google/protobuf/pyext/repeated_scalar_container.cc b/python/google/protobuf/pyext/repeated_scalar_container.cc
index 9fedb952..54998800 100644
--- a/python/google/protobuf/pyext/repeated_scalar_container.cc
+++ b/python/google/protobuf/pyext/repeated_scalar_container.cc
@@ -659,6 +659,18 @@ static PyObject* Pop(RepeatedScalarContainer* self,
return item;
}
+static PyObject* ToStr(RepeatedScalarContainer* self) {
+ ScopedPyObjectPtr full_slice(PySlice_New(NULL, NULL, NULL));
+ if (full_slice == NULL) {
+ return NULL;
+ }
+ ScopedPyObjectPtr list(Subscript(self, full_slice.get()));
+ if (list == NULL) {
+ return NULL;
+ }
+ return PyObject_Repr(list.get());
+}
+
// The private constructor of RepeatedScalarContainer objects.
PyObject *NewContainer(
CMessage* parent, const FieldDescriptor* parent_field_descriptor) {
@@ -781,7 +793,7 @@ PyTypeObject RepeatedScalarContainer_Type = {
0, // tp_getattr
0, // tp_setattr
0, // tp_compare
- 0, // tp_repr
+ (reprfunc)repeated_scalar_container::ToStr, // tp_repr
0, // tp_as_number
&repeated_scalar_container::SqMethods, // tp_as_sequence
&repeated_scalar_container::MpMethods, // tp_as_mapping
diff --git a/python/google/protobuf/symbol_database.py b/python/google/protobuf/symbol_database.py
index 07341efa..5ad869f4 100644
--- a/python/google/protobuf/symbol_database.py
+++ b/python/google/protobuf/symbol_database.py
@@ -78,10 +78,18 @@ class SymbolDatabase(message_factory.MessageFactory):
"""
desc = message.DESCRIPTOR
- self._classes[desc.full_name] = message
- self.pool.AddDescriptor(desc)
+ self._classes[desc] = message
+ self.RegisterMessageDescriptor(desc)
return message
+ def RegisterMessageDescriptor(self, message_descriptor):
+ """Registers the given message descriptor in the local database.
+
+ Args:
+ message_descriptor: a descriptor.MessageDescriptor.
+ """
+ self.pool.AddDescriptor(message_descriptor)
+
def RegisterEnumDescriptor(self, enum_descriptor):
"""Registers the given enum descriptor in the local database.
@@ -132,7 +140,7 @@ class SymbolDatabase(message_factory.MessageFactory):
KeyError: if the symbol could not be found.
"""
- return self._classes[symbol]
+ return self._classes[self.pool.FindMessageTypeByName(symbol)]
def GetMessages(self, files):
# TODO(amauryfa): Fix the differences with MessageFactory.
@@ -153,20 +161,20 @@ class SymbolDatabase(message_factory.MessageFactory):
KeyError: if a file could not be found.
"""
- def _GetAllMessageNames(desc):
+ def _GetAllMessages(desc):
"""Walk a message Descriptor and recursively yields all message names."""
- yield desc.full_name
+ yield desc
for msg_desc in desc.nested_types:
- for full_name in _GetAllMessageNames(msg_desc):
- yield full_name
+ for nested_desc in _GetAllMessages(msg_desc):
+ yield nested_desc
result = {}
for file_name in files:
file_desc = self.pool.FindFileByName(file_name)
for msg_desc in file_desc.message_types_by_name.values():
- for full_name in _GetAllMessageNames(msg_desc):
+ for desc in _GetAllMessages(msg_desc):
try:
- result[full_name] = self._classes[full_name]
+ result[desc.full_name] = self._classes[desc]
except KeyError:
# This descriptor has no registered class, skip it.
pass
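Registration and lookup now go through descriptors: GetSymbol resolves the name in the pool first and then maps the descriptor to its class, so behavior for generated types is unchanged. For example:

    from google.protobuf import symbol_database
    from google.protobuf import timestamp_pb2

    db = symbol_database.Default()
    cls = db.GetSymbol('google.protobuf.Timestamp')
    assert cls is timestamp_pb2.Timestamp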
diff --git a/python/google/protobuf/text_format.py b/python/google/protobuf/text_format.py
index c216e097..aaca78ad 100755
--- a/python/google/protobuf/text_format.py
+++ b/python/google/protobuf/text_format.py
@@ -126,7 +126,8 @@ def MessageToString(message,
float_format=None,
use_field_number=False,
descriptor_pool=None,
- indent=0):
+ indent=0,
+ message_formatter=None):
"""Convert protobuf message to text format.
Floating point values can be formatted compactly with 15 digits of
@@ -148,6 +149,9 @@ def MessageToString(message,
use_field_number: If True, print field numbers instead of names.
descriptor_pool: A DescriptorPool used to resolve Any types.
indent: The indent level, in terms of spaces, for pretty print.
+ message_formatter: A function(message, indent, as_one_line): unicode|None
+ to custom format selected sub-messages (usually based on message type).
+ Use to pretty print parts of the protobuf for easier diffing.
Returns:
A string of the text formatted protocol buffer message.
@@ -155,7 +159,7 @@ def MessageToString(message,
out = TextWriter(as_utf8)
printer = _Printer(out, indent, as_utf8, as_one_line, pointy_brackets,
use_index_order, float_format, use_field_number,
- descriptor_pool)
+ descriptor_pool, message_formatter)
printer.PrintMessage(message)
result = out.getvalue()
out.close()
@@ -179,10 +183,11 @@ def PrintMessage(message,
use_index_order=False,
float_format=None,
use_field_number=False,
- descriptor_pool=None):
+ descriptor_pool=None,
+ message_formatter=None):
printer = _Printer(out, indent, as_utf8, as_one_line, pointy_brackets,
use_index_order, float_format, use_field_number,
- descriptor_pool)
+ descriptor_pool, message_formatter)
printer.PrintMessage(message)
@@ -194,10 +199,11 @@ def PrintField(field,
as_one_line=False,
pointy_brackets=False,
use_index_order=False,
- float_format=None):
+ float_format=None,
+ message_formatter=None):
"""Print a single field name/value pair."""
printer = _Printer(out, indent, as_utf8, as_one_line, pointy_brackets,
- use_index_order, float_format)
+ use_index_order, float_format, message_formatter)
printer.PrintField(field, value)
@@ -209,10 +215,11 @@ def PrintFieldValue(field,
as_one_line=False,
pointy_brackets=False,
use_index_order=False,
- float_format=None):
+ float_format=None,
+ message_formatter=None):
"""Print a single field value (not including name)."""
printer = _Printer(out, indent, as_utf8, as_one_line, pointy_brackets,
- use_index_order, float_format)
+ use_index_order, float_format, message_formatter)
printer.PrintFieldValue(field, value)
@@ -228,6 +235,9 @@ def _BuildMessageFromTypeName(type_name, descriptor_pool):
wasn't found matching type_name.
"""
# pylint: disable=g-import-not-at-top
+ if descriptor_pool is None:
+ from google.protobuf import descriptor_pool as pool_mod
+ descriptor_pool = pool_mod.Default()
from google.protobuf import symbol_database
database = symbol_database.Default()
try:
@@ -250,7 +260,8 @@ class _Printer(object):
use_index_order=False,
float_format=None,
use_field_number=False,
- descriptor_pool=None):
+ descriptor_pool=None,
+ message_formatter=None):
"""Initialize the Printer.
Floating point values can be formatted compactly with 15 digits of
@@ -273,6 +284,9 @@ class _Printer(object):
used.
use_field_number: If True, print field numbers instead of names.
descriptor_pool: A DescriptorPool used to resolve Any types.
+ message_formatter: A function(message, indent, as_one_line): unicode|None
+ to custom format selected sub-messages (usually based on message type).
+ Use to pretty print parts of the protobuf for easier diffing.
"""
self.out = out
self.indent = indent
@@ -283,6 +297,7 @@ class _Printer(object):
self.float_format = float_format
self.use_field_number = use_field_number
self.descriptor_pool = descriptor_pool
+ self.message_formatter = message_formatter
def _TryPrintAsAnyMessage(self, message):
"""Serializes if message is a google.protobuf.Any field."""
@@ -297,14 +312,27 @@ class _Printer(object):
else:
return False
+ def _TryCustomFormatMessage(self, message):
+ formatted = self.message_formatter(message, self.indent, self.as_one_line)
+ if formatted is None:
+ return False
+
+ out = self.out
+ out.write(' ' * self.indent)
+ out.write(formatted)
+ out.write(' ' if self.as_one_line else '\n')
+ return True
+
def PrintMessage(self, message):
"""Convert protobuf message to text format.
Args:
message: The protocol buffers message.
"""
+ if self.message_formatter and self._TryCustomFormatMessage(message):
+ return
if (message.DESCRIPTOR.full_name == _ANY_FULL_TYPE_NAME and
- self.descriptor_pool and self._TryPrintAsAnyMessage(message)):
+ self._TryPrintAsAnyMessage(message)):
return
fields = message.ListFields()
if self.use_index_order:
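A short sketch of the new message_formatter hook: the callable returns a string to replace the default rendering of a message (or sub-message), or None to fall back to normal printing.

    from google.protobuf import text_format
    from google.protobuf import timestamp_pb2

    def format_timestamps(message, indent, as_one_line):
        if message.DESCRIPTOR.full_name == 'google.protobuf.Timestamp':
            return 'seconds: %d  # custom-formatted' % message.seconds
        return None  # use the default formatting for everything else

    ts = timestamp_pb2.Timestamp(seconds=120)
    print(text_format.MessageToString(ts, message_formatter=format_timestamps))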
@@ -426,6 +454,22 @@ def Parse(text,
descriptor_pool=None):
"""Parses a text representation of a protocol message into a message.
+ NOTE: for historical reasons this function does not clear the input
+  message. This is different from what the binary msg.ParseFromString() does.
+
+  Example:
+ a = MyProto()
+ a.repeated_field.append('test')
+ b = MyProto()
+
+ text_format.Parse(repr(a), b)
+ text_format.Parse(repr(a), b) # repeated_field contains ["test", "test"]
+
+ # Binary version:
+    b.ParseFromString(a.SerializeToString()) # repeated_field is now ["test"]
+
+ Caller is responsible for clearing the message as needed.
+
Args:
text: Message text representation.
message: A protocol buffer message to merge into.
@@ -593,11 +637,6 @@ class _Parser(object):
ParseError: In case of text parsing problems.
"""
message_descriptor = message.DESCRIPTOR
- if (hasattr(message_descriptor, 'syntax') and
- message_descriptor.syntax == 'proto3'):
- # Proto3 doesn't represent presence so we can't test if multiple
- # scalars have occurred. We have to allow them.
- self._allow_multiple_scalars = True
if tokenizer.TryConsume('['):
name = [tokenizer.ConsumeIdentifier()]
while tokenizer.TryConsume('.'):
@@ -616,7 +655,11 @@ class _Parser(object):
field = None
else:
raise tokenizer.ParseErrorPreviousToken(
- 'Extension "%s" not registered.' % name)
+ 'Extension "%s" not registered. '
+ 'Did you import the _pb2 module which defines it? '
+ 'If you are trying to place the extension in the MessageSet '
+ 'field of another message that is in an Any or MessageSet field, '
+          'that message\'s _pb2 module must be imported as well.' % name)
elif message_descriptor != field.containing_type:
raise tokenizer.ParseErrorPreviousToken(
'Extension "%s" does not extend message type "%s".' %
@@ -695,17 +738,17 @@ class _Parser(object):
def _ConsumeAnyTypeUrl(self, tokenizer):
"""Consumes a google.protobuf.Any type URL and returns the type name."""
# Consume "type.googleapis.com/".
- tokenizer.ConsumeIdentifier()
+ prefix = [tokenizer.ConsumeIdentifier()]
tokenizer.Consume('.')
- tokenizer.ConsumeIdentifier()
+ prefix.append(tokenizer.ConsumeIdentifier())
tokenizer.Consume('.')
- tokenizer.ConsumeIdentifier()
+ prefix.append(tokenizer.ConsumeIdentifier())
tokenizer.Consume('/')
# Consume the fully-qualified type name.
name = [tokenizer.ConsumeIdentifier()]
while tokenizer.TryConsume('.'):
name.append(tokenizer.ConsumeIdentifier())
- return '.'.join(name)
+ return '.'.join(prefix), '.'.join(name)
def _MergeMessageField(self, tokenizer, message, field):
"""Merges a single scalar field into a message.
@@ -728,7 +771,7 @@ class _Parser(object):
if (field.message_type.full_name == _ANY_FULL_TYPE_NAME and
tokenizer.TryConsume('[')):
- packed_type_name = self._ConsumeAnyTypeUrl(tokenizer)
+ type_url_prefix, packed_type_name = self._ConsumeAnyTypeUrl(tokenizer)
tokenizer.Consume(']')
tokenizer.TryConsume(':')
if tokenizer.TryConsume('<'):
@@ -736,8 +779,6 @@ class _Parser(object):
else:
tokenizer.Consume('{')
expanded_any_end_token = '}'
- if not self.descriptor_pool:
- raise ParseError('Descriptor pool required to parse expanded Any field')
expanded_any_sub_message = _BuildMessageFromTypeName(packed_type_name,
self.descriptor_pool)
if not expanded_any_sub_message:
@@ -752,7 +793,8 @@ class _Parser(object):
any_message = getattr(message, field.name).add()
else:
any_message = getattr(message, field.name)
- any_message.Pack(expanded_any_sub_message)
+ any_message.Pack(expanded_any_sub_message,
+ type_url_prefix=type_url_prefix)
elif field.label == descriptor.FieldDescriptor.LABEL_REPEATED:
if field.is_extension:
sub_message = message.Extensions[field].add()
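With the prefix captured above, the re-packed Any keeps the type URL exactly as written in the text input. A sketch assuming a hypothetical `Wrapper` message containing `google.protobuf.Any payload = 1;` (Timestamp's _pb2 module is imported so the packed type can be resolved):

    from google.protobuf import text_format
    from google.protobuf import timestamp_pb2  # makes Timestamp resolvable

    text = 'payload { [type.googleapis.com/google.protobuf.Timestamp] { seconds: 1 } }'
    msg = wrapper_pb2.Wrapper()  # hypothetical generated class
    text_format.Parse(text, msg)

    # The full type URL, including its prefix, is preserved on the packed Any.
    assert msg.payload.type_url == 'type.googleapis.com/google.protobuf.Timestamp'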
@@ -780,6 +822,12 @@ class _Parser(object):
else:
getattr(message, field.name)[sub_message.key] = sub_message.value
+ @staticmethod
+ def _IsProto3Syntax(message):
+ message_descriptor = message.DESCRIPTOR
+ return (hasattr(message_descriptor, 'syntax') and
+ message_descriptor.syntax == 'proto3')
+
def _MergeScalarField(self, tokenizer, message, field):
"""Merges a single scalar field into a message.
@@ -829,15 +877,20 @@ class _Parser(object):
else:
getattr(message, field.name).append(value)
else:
+ # Proto3 doesn't represent presence so we can't test if multiple scalars
+ # have occurred. We have to allow them.
+ can_check_presence = not self._IsProto3Syntax(message)
if field.is_extension:
- if not self._allow_multiple_scalars and message.HasExtension(field):
+ if (not self._allow_multiple_scalars and can_check_presence and
+ message.HasExtension(field)):
raise tokenizer.ParseErrorPreviousToken(
'Message type "%s" should not have multiple "%s" extensions.' %
(message.DESCRIPTOR.full_name, field.full_name))
else:
message.Extensions[field] = value
else:
- if not self._allow_multiple_scalars and message.HasField(field.name):
+ if (not self._allow_multiple_scalars and can_check_presence and
+ message.HasField(field.name)):
raise tokenizer.ParseErrorPreviousToken(
'Message type "%s" should not have multiple "%s" fields.' %
(message.DESCRIPTOR.full_name, field.name))
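The presence check is now decided per message rather than by mutating parser state. For a proto3 message (no scalar field presence) duplicated scalars are accepted and the last value wins; for proto2 the duplicate still raises ParseError. Sketch, with `foo_pb2.Foo` standing in for a proto3 message declaring `int32 value = 1;`:

    from google.protobuf import text_format

    msg = foo_pb2.Foo()                      # hypothetical proto3 message
    text_format.Parse('value: 1 value: 2', msg)
    assert msg.value == 2                    # last occurrence wins

    # The same input against a proto2 message still raises
    # text_format.ParseError: ... should not have multiple "value" fields.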
@@ -1088,7 +1141,7 @@ class Tokenizer(object):
"""
result = self.token
if not self._IDENTIFIER_OR_NUMBER.match(result):
- raise self.ParseError('Expected identifier or number.')
+ raise self.ParseError('Expected identifier or number, got %s.' % result)
self.NextToken()
return result