diff options
Diffstat (limited to 'python/google')
27 files changed, 749 insertions, 204 deletions
diff --git a/python/google/protobuf/__init__.py b/python/google/protobuf/__init__.py index 622dfb3d..d26da0df 100755 --- a/python/google/protobuf/__init__.py +++ b/python/google/protobuf/__init__.py @@ -30,7 +30,7 @@ # Copyright 2007 Google Inc. All Rights Reserved. -__version__ = '3.3.2' +__version__ = '3.4.0' if __name__ != '__main__': try: diff --git a/python/google/protobuf/descriptor.py b/python/google/protobuf/descriptor.py index e1f2e3b7..b1f3ca38 100755 --- a/python/google/protobuf/descriptor.py +++ b/python/google/protobuf/descriptor.py @@ -406,6 +406,8 @@ class FieldDescriptor(DescriptorBase): containing_oneof: (OneofDescriptor) If the field is a member of a oneof union, contains its descriptor. Otherwise, None. + + file: (FileDescriptor) Reference to file descriptor. """ # Must be consistent with C++ FieldDescriptor::Type enum in @@ -490,7 +492,8 @@ class FieldDescriptor(DescriptorBase): def __new__(cls, name, full_name, index, number, type, cpp_type, label, default_value, message_type, enum_type, containing_type, is_extension, extension_scope, options=None, - has_default_value=True, containing_oneof=None, json_name=None): + has_default_value=True, containing_oneof=None, json_name=None, + file=None): _message.Message._CheckCalledFromGeneratedFile() if is_extension: return _message.default_pool.FindExtensionByName(full_name) @@ -500,7 +503,8 @@ class FieldDescriptor(DescriptorBase): def __init__(self, name, full_name, index, number, type, cpp_type, label, default_value, message_type, enum_type, containing_type, is_extension, extension_scope, options=None, - has_default_value=True, containing_oneof=None, json_name=None): + has_default_value=True, containing_oneof=None, json_name=None, + file=None): """The arguments are as described in the description of FieldDescriptor attributes above. @@ -511,6 +515,7 @@ class FieldDescriptor(DescriptorBase): super(FieldDescriptor, self).__init__(options, 'FieldOptions') self.name = name self.full_name = full_name + self.file = file self._camelcase_name = None if json_name is None: self.json_name = _ToJsonName(name) diff --git a/python/google/protobuf/descriptor_database.py b/python/google/protobuf/descriptor_database.py index 40bcdd72..eb45e127 100644 --- a/python/google/protobuf/descriptor_database.py +++ b/python/google/protobuf/descriptor_database.py @@ -134,8 +134,7 @@ def _ExtractSymbols(desc_proto, package): Yields: The fully qualified name found in the descriptor. """ - - message_name = '.'.join((package, desc_proto.name)) + message_name = package + '.' + desc_proto.name if package else desc_proto.name yield message_name for nested_type in desc_proto.nested_type: for symbol in _ExtractSymbols(nested_type, message_name): diff --git a/python/google/protobuf/descriptor_pool.py b/python/google/protobuf/descriptor_pool.py index 29e9f1e3..3dbe0fd0 100644 --- a/python/google/protobuf/descriptor_pool.py +++ b/python/google/protobuf/descriptor_pool.py @@ -329,6 +329,11 @@ class DescriptorPool(object): pass try: + return self._service_descriptors[symbol].file + except KeyError: + pass + + try: return self._FindFileContainingSymbolInDb(symbol) except KeyError: pass @@ -344,7 +349,6 @@ class DescriptorPool(object): message = self.FindMessageTypeByName(message_name) assert message.extensions_by_name[extension_name] return message.file - except KeyError: raise KeyError('Cannot find a file containing %s' % symbol) @@ -557,7 +561,8 @@ class DescriptorPool(object): for index, extension_proto in enumerate(file_proto.extension): extension_desc = self._MakeFieldDescriptor( - extension_proto, file_proto.package, index, is_extension=True) + extension_proto, file_proto.package, index, file_descriptor, + is_extension=True) extension_desc.containing_type = self._GetTypeFromScope( file_descriptor.package, extension_proto.extendee, scope) self._SetFieldType(extension_proto, extension_desc, @@ -623,10 +628,10 @@ class DescriptorPool(object): enums = [ self._ConvertEnumDescriptor(enum, desc_name, file_desc, None, scope) for enum in desc_proto.enum_type] - fields = [self._MakeFieldDescriptor(field, desc_name, index) + fields = [self._MakeFieldDescriptor(field, desc_name, index, file_desc) for index, field in enumerate(desc_proto.field)] extensions = [ - self._MakeFieldDescriptor(extension, desc_name, index, + self._MakeFieldDescriptor(extension, desc_name, index, file_desc, is_extension=True) for index, extension in enumerate(desc_proto.extension)] oneofs = [ @@ -708,7 +713,7 @@ class DescriptorPool(object): return desc def _MakeFieldDescriptor(self, field_proto, message_name, index, - is_extension=False): + file_desc, is_extension=False): """Creates a field descriptor from a FieldDescriptorProto. For message and enum type fields, this method will do a look up @@ -721,6 +726,7 @@ class DescriptorPool(object): field_proto: The proto describing the field. message_name: The name of the containing message. index: Index of the field + file_desc: The file containing the field descriptor. is_extension: Indication that this field is for an extension. Returns: @@ -747,7 +753,8 @@ class DescriptorPool(object): default_value=None, is_extension=is_extension, extension_scope=None, - options=_OptionsOrNone(field_proto)) + options=_OptionsOrNone(field_proto), + file=file_desc) def _SetAllFieldTypes(self, package, desc_proto, scope): """Sets all the descriptor's fields's types. diff --git a/python/google/protobuf/internal/api_implementation.py b/python/google/protobuf/internal/api_implementation.py index 460a4a6c..422af590 100755 --- a/python/google/protobuf/internal/api_implementation.py +++ b/python/google/protobuf/internal/api_implementation.py @@ -100,6 +100,27 @@ if _implementation_version_str != '2': _implementation_version = int(_implementation_version_str) +# Detect if serialization should be deterministic by default +try: + # The presence of this module in a build allows the proto implementation to + # be upgraded merely via build deps. + # + # NOTE: Merely importing this automatically enables deterministic proto + # serialization for C++ code, but we still need to export it as a boolean so + # that we can do the same for `_implementation_type == 'python'`. + # + # NOTE2: It is possible for C++ code to enable deterministic serialization by + # default _without_ affecting Python code, if the C++ implementation is not in + # use by this module. That is intended behavior, so we don't actually expose + # this boolean outside of this module. + # + # pylint: disable=g-import-not-at-top,unused-import + from google.protobuf import enable_deterministic_proto_serialization + _python_deterministic_proto_serialization = True +except ImportError: + _python_deterministic_proto_serialization = False + + # Usage of this function is discouraged. Clients shouldn't care which # implementation of the API is in use. Note that there is no guarantee # that differences between APIs will be maintained. @@ -111,3 +132,8 @@ def Type(): # See comment on 'Type' above. def Version(): return _implementation_version + + +# For internal use only +def IsPythonDefaultSerializationDeterministic(): + return _python_deterministic_proto_serialization diff --git a/python/google/protobuf/internal/descriptor_pool_test.py b/python/google/protobuf/internal/descriptor_pool_test.py index c1733a48..6015e6f8 100644 --- a/python/google/protobuf/internal/descriptor_pool_test.py +++ b/python/google/protobuf/internal/descriptor_pool_test.py @@ -131,11 +131,19 @@ class DescriptorPoolTest(unittest.TestCase): self.assertEqual('google/protobuf/internal/factory_test2.proto', file_desc4.name) + file_desc5 = self.pool.FindFileContainingSymbol( + 'protobuf_unittest.TestService') + self.assertIsInstance(file_desc5, descriptor.FileDescriptor) + self.assertEqual('google/protobuf/unittest.proto', + file_desc5.name) + # Tests the generated pool. assert descriptor_pool.Default().FindFileContainingSymbol( 'google.protobuf.python.internal.Factory2Message.one_more_field') assert descriptor_pool.Default().FindFileContainingSymbol( 'google.protobuf.python.internal.another_field') + assert descriptor_pool.Default().FindFileContainingSymbol( + 'protobuf_unittest.TestService') def testFindFileContainingSymbolFailure(self): with self.assertRaises(KeyError): @@ -506,10 +514,10 @@ class MessageType(object): subtype.CheckType(test, desc, name, file_desc) for index, (name, field) in enumerate(self.field_list): - field.CheckField(test, desc, name, index) + field.CheckField(test, desc, name, index, file_desc) for index, (name, field) in enumerate(self.extensions): - field.CheckField(test, desc, name, index) + field.CheckField(test, desc, name, index, file_desc) class EnumField(object): @@ -519,7 +527,7 @@ class EnumField(object): self.type_name = type_name self.default_value = default_value - def CheckField(self, test, msg_desc, name, index): + def CheckField(self, test, msg_desc, name, index, file_desc): field_desc = msg_desc.fields_by_name[name] enum_desc = msg_desc.enum_types_by_name[self.type_name] test.assertEqual(name, field_desc.name) @@ -536,6 +544,7 @@ class EnumField(object): test.assertFalse(enum_desc.values_by_name[self.default_value].has_options) test.assertEqual(msg_desc, field_desc.containing_type) test.assertEqual(enum_desc, field_desc.enum_type) + test.assertEqual(file_desc, enum_desc.file) class MessageField(object): @@ -544,7 +553,7 @@ class MessageField(object): self.number = number self.type_name = type_name - def CheckField(self, test, msg_desc, name, index): + def CheckField(self, test, msg_desc, name, index, file_desc): field_desc = msg_desc.fields_by_name[name] field_type_desc = msg_desc.nested_types_by_name[self.type_name] test.assertEqual(name, field_desc.name) @@ -558,6 +567,7 @@ class MessageField(object): test.assertFalse(field_desc.has_default_value) test.assertEqual(msg_desc, field_desc.containing_type) test.assertEqual(field_type_desc, field_desc.message_type) + test.assertEqual(file_desc, field_desc.file) class StringField(object): @@ -566,7 +576,7 @@ class StringField(object): self.number = number self.default_value = default_value - def CheckField(self, test, msg_desc, name, index): + def CheckField(self, test, msg_desc, name, index, file_desc): field_desc = msg_desc.fields_by_name[name] test.assertEqual(name, field_desc.name) expected_field_full_name = '.'.join([msg_desc.full_name, name]) @@ -578,6 +588,7 @@ class StringField(object): field_desc.cpp_type) test.assertTrue(field_desc.has_default_value) test.assertEqual(self.default_value, field_desc.default_value) + test.assertEqual(file_desc, field_desc.file) class ExtensionField(object): @@ -586,7 +597,7 @@ class ExtensionField(object): self.number = number self.extended_type = extended_type - def CheckField(self, test, msg_desc, name, index): + def CheckField(self, test, msg_desc, name, index, file_desc): field_desc = msg_desc.extensions_by_name[name] test.assertEqual(name, field_desc.name) expected_field_full_name = '.'.join([msg_desc.full_name, name]) @@ -601,6 +612,7 @@ class ExtensionField(object): test.assertEqual(msg_desc, field_desc.extension_scope) test.assertEqual(msg_desc, field_desc.message_type) test.assertEqual(self.extended_type, field_desc.containing_type.name) + test.assertEqual(file_desc, field_desc.file) class AddDescriptorTest(unittest.TestCase): @@ -746,15 +758,10 @@ class AddDescriptorTest(unittest.TestCase): self.assertIs(options, file_descriptor.GetOptions()) -@unittest.skipIf( - api_implementation.Type() != 'cpp', - 'default_pool is only supported by the C++ implementation') class DefaultPoolTest(unittest.TestCase): def testFindMethods(self): - # pylint: disable=g-import-not-at-top - from google.protobuf.pyext import _message - pool = _message.default_pool + pool = descriptor_pool.Default() self.assertIs( pool.FindFileByName('google/protobuf/unittest.proto'), unittest_pb2.DESCRIPTOR) @@ -765,19 +772,22 @@ class DefaultPoolTest(unittest.TestCase): pool.FindFieldByName('protobuf_unittest.TestAllTypes.optional_int32'), unittest_pb2.TestAllTypes.DESCRIPTOR.fields_by_name['optional_int32']) self.assertIs( - pool.FindExtensionByName('protobuf_unittest.optional_int32_extension'), - unittest_pb2.DESCRIPTOR.extensions_by_name['optional_int32_extension']) - self.assertIs( pool.FindEnumTypeByName('protobuf_unittest.ForeignEnum'), unittest_pb2.ForeignEnum.DESCRIPTOR) + if api_implementation.Type() != 'cpp': + self.skipTest('Only the C++ implementation correctly indexes all types') + self.assertIs( + pool.FindExtensionByName('protobuf_unittest.optional_int32_extension'), + unittest_pb2.DESCRIPTOR.extensions_by_name['optional_int32_extension']) self.assertIs( pool.FindOneofByName('protobuf_unittest.TestAllTypes.oneof_field'), unittest_pb2.TestAllTypes.DESCRIPTOR.oneofs_by_name['oneof_field']) + self.assertIs( + pool.FindServiceByName('protobuf_unittest.TestService'), + unittest_pb2.DESCRIPTOR.services_by_name['TestService']) def testAddFileDescriptor(self): - # pylint: disable=g-import-not-at-top - from google.protobuf.pyext import _message - pool = _message.default_pool + pool = descriptor_pool.Default() file_desc = descriptor_pb2.FileDescriptorProto(name='some/file.proto') pool.Add(file_desc) pool.AddSerializedFile(file_desc.SerializeToString()) diff --git a/python/google/protobuf/internal/descriptor_test.py b/python/google/protobuf/internal/descriptor_test.py index 1f148ab9..c0010081 100755 --- a/python/google/protobuf/internal/descriptor_test.py +++ b/python/google/protobuf/internal/descriptor_test.py @@ -521,6 +521,12 @@ class GeneratedDescriptorTest(unittest.TestCase): del enum self.assertEqual('FOO', next(values_iter).name) + def testServiceDescriptor(self): + service_descriptor = unittest_pb2.DESCRIPTOR.services_by_name['TestService'] + self.assertEqual(service_descriptor.name, 'TestService') + self.assertEqual(service_descriptor.methods[0].name, 'Foo') + self.assertIs(service_descriptor.file, unittest_pb2.DESCRIPTOR) + class DescriptorCopyToProtoTest(unittest.TestCase): """Tests for CopyTo functions of Descriptor.""" diff --git a/python/google/protobuf/internal/encoder.py b/python/google/protobuf/internal/encoder.py index 80e59cab..ebec42e5 100755 --- a/python/google/protobuf/internal/encoder.py +++ b/python/google/protobuf/internal/encoder.py @@ -372,7 +372,7 @@ def MapSizer(field_descriptor, is_message_map): def _VarintEncoder(): """Return an encoder for a basic varint value (does not include tag).""" - def EncodeVarint(write, value): + def EncodeVarint(write, value, unused_deterministic): bits = value & 0x7f value >>= 7 while value: @@ -388,7 +388,7 @@ def _SignedVarintEncoder(): """Return an encoder for a basic signed varint value (does not include tag).""" - def EncodeSignedVarint(write, value): + def EncodeSignedVarint(write, value, unused_deterministic): if value < 0: value += (1 << 64) bits = value & 0x7f @@ -411,7 +411,7 @@ def _VarintBytes(value): called at startup time so it doesn't need to be fast.""" pieces = [] - _EncodeVarint(pieces.append, value) + _EncodeVarint(pieces.append, value, True) return b"".join(pieces) @@ -440,27 +440,27 @@ def _SimpleEncoder(wire_type, encode_value, compute_value_size): if is_packed: tag_bytes = TagBytes(field_number, wire_format.WIRETYPE_LENGTH_DELIMITED) local_EncodeVarint = _EncodeVarint - def EncodePackedField(write, value): + def EncodePackedField(write, value, deterministic): write(tag_bytes) size = 0 for element in value: size += compute_value_size(element) - local_EncodeVarint(write, size) + local_EncodeVarint(write, size, deterministic) for element in value: - encode_value(write, element) + encode_value(write, element, deterministic) return EncodePackedField elif is_repeated: tag_bytes = TagBytes(field_number, wire_type) - def EncodeRepeatedField(write, value): + def EncodeRepeatedField(write, value, deterministic): for element in value: write(tag_bytes) - encode_value(write, element) + encode_value(write, element, deterministic) return EncodeRepeatedField else: tag_bytes = TagBytes(field_number, wire_type) - def EncodeField(write, value): + def EncodeField(write, value, deterministic): write(tag_bytes) - return encode_value(write, value) + return encode_value(write, value, deterministic) return EncodeField return SpecificEncoder @@ -474,27 +474,27 @@ def _ModifiedEncoder(wire_type, encode_value, compute_value_size, modify_value): if is_packed: tag_bytes = TagBytes(field_number, wire_format.WIRETYPE_LENGTH_DELIMITED) local_EncodeVarint = _EncodeVarint - def EncodePackedField(write, value): + def EncodePackedField(write, value, deterministic): write(tag_bytes) size = 0 for element in value: size += compute_value_size(modify_value(element)) - local_EncodeVarint(write, size) + local_EncodeVarint(write, size, deterministic) for element in value: - encode_value(write, modify_value(element)) + encode_value(write, modify_value(element), deterministic) return EncodePackedField elif is_repeated: tag_bytes = TagBytes(field_number, wire_type) - def EncodeRepeatedField(write, value): + def EncodeRepeatedField(write, value, deterministic): for element in value: write(tag_bytes) - encode_value(write, modify_value(element)) + encode_value(write, modify_value(element), deterministic) return EncodeRepeatedField else: tag_bytes = TagBytes(field_number, wire_type) - def EncodeField(write, value): + def EncodeField(write, value, deterministic): write(tag_bytes) - return encode_value(write, modify_value(value)) + return encode_value(write, modify_value(value), deterministic) return EncodeField return SpecificEncoder @@ -515,22 +515,22 @@ def _StructPackEncoder(wire_type, format): if is_packed: tag_bytes = TagBytes(field_number, wire_format.WIRETYPE_LENGTH_DELIMITED) local_EncodeVarint = _EncodeVarint - def EncodePackedField(write, value): + def EncodePackedField(write, value, deterministic): write(tag_bytes) - local_EncodeVarint(write, len(value) * value_size) + local_EncodeVarint(write, len(value) * value_size, deterministic) for element in value: write(local_struct_pack(format, element)) return EncodePackedField elif is_repeated: tag_bytes = TagBytes(field_number, wire_type) - def EncodeRepeatedField(write, value): + def EncodeRepeatedField(write, value, unused_deterministic): for element in value: write(tag_bytes) write(local_struct_pack(format, element)) return EncodeRepeatedField else: tag_bytes = TagBytes(field_number, wire_type) - def EncodeField(write, value): + def EncodeField(write, value, unused_deterministic): write(tag_bytes) return write(local_struct_pack(format, value)) return EncodeField @@ -581,9 +581,9 @@ def _FloatingPointEncoder(wire_type, format): if is_packed: tag_bytes = TagBytes(field_number, wire_format.WIRETYPE_LENGTH_DELIMITED) local_EncodeVarint = _EncodeVarint - def EncodePackedField(write, value): + def EncodePackedField(write, value, deterministic): write(tag_bytes) - local_EncodeVarint(write, len(value) * value_size) + local_EncodeVarint(write, len(value) * value_size, deterministic) for element in value: # This try/except block is going to be faster than any code that # we could write to check whether element is finite. @@ -594,7 +594,7 @@ def _FloatingPointEncoder(wire_type, format): return EncodePackedField elif is_repeated: tag_bytes = TagBytes(field_number, wire_type) - def EncodeRepeatedField(write, value): + def EncodeRepeatedField(write, value, unused_deterministic): for element in value: write(tag_bytes) try: @@ -604,7 +604,7 @@ def _FloatingPointEncoder(wire_type, format): return EncodeRepeatedField else: tag_bytes = TagBytes(field_number, wire_type) - def EncodeField(write, value): + def EncodeField(write, value, unused_deterministic): write(tag_bytes) try: write(local_struct_pack(format, value)) @@ -650,9 +650,9 @@ def BoolEncoder(field_number, is_repeated, is_packed): if is_packed: tag_bytes = TagBytes(field_number, wire_format.WIRETYPE_LENGTH_DELIMITED) local_EncodeVarint = _EncodeVarint - def EncodePackedField(write, value): + def EncodePackedField(write, value, deterministic): write(tag_bytes) - local_EncodeVarint(write, len(value)) + local_EncodeVarint(write, len(value), deterministic) for element in value: if element: write(true_byte) @@ -661,7 +661,7 @@ def BoolEncoder(field_number, is_repeated, is_packed): return EncodePackedField elif is_repeated: tag_bytes = TagBytes(field_number, wire_format.WIRETYPE_VARINT) - def EncodeRepeatedField(write, value): + def EncodeRepeatedField(write, value, unused_deterministic): for element in value: write(tag_bytes) if element: @@ -671,7 +671,7 @@ def BoolEncoder(field_number, is_repeated, is_packed): return EncodeRepeatedField else: tag_bytes = TagBytes(field_number, wire_format.WIRETYPE_VARINT) - def EncodeField(write, value): + def EncodeField(write, value, unused_deterministic): write(tag_bytes) if value: return write(true_byte) @@ -687,18 +687,18 @@ def StringEncoder(field_number, is_repeated, is_packed): local_len = len assert not is_packed if is_repeated: - def EncodeRepeatedField(write, value): + def EncodeRepeatedField(write, value, deterministic): for element in value: encoded = element.encode('utf-8') write(tag) - local_EncodeVarint(write, local_len(encoded)) + local_EncodeVarint(write, local_len(encoded), deterministic) write(encoded) return EncodeRepeatedField else: - def EncodeField(write, value): + def EncodeField(write, value, deterministic): encoded = value.encode('utf-8') write(tag) - local_EncodeVarint(write, local_len(encoded)) + local_EncodeVarint(write, local_len(encoded), deterministic) return write(encoded) return EncodeField @@ -711,16 +711,16 @@ def BytesEncoder(field_number, is_repeated, is_packed): local_len = len assert not is_packed if is_repeated: - def EncodeRepeatedField(write, value): + def EncodeRepeatedField(write, value, deterministic): for element in value: write(tag) - local_EncodeVarint(write, local_len(element)) + local_EncodeVarint(write, local_len(element), deterministic) write(element) return EncodeRepeatedField else: - def EncodeField(write, value): + def EncodeField(write, value, deterministic): write(tag) - local_EncodeVarint(write, local_len(value)) + local_EncodeVarint(write, local_len(value), deterministic) return write(value) return EncodeField @@ -732,16 +732,16 @@ def GroupEncoder(field_number, is_repeated, is_packed): end_tag = TagBytes(field_number, wire_format.WIRETYPE_END_GROUP) assert not is_packed if is_repeated: - def EncodeRepeatedField(write, value): + def EncodeRepeatedField(write, value, deterministic): for element in value: write(start_tag) - element._InternalSerialize(write) + element._InternalSerialize(write, deterministic) write(end_tag) return EncodeRepeatedField else: - def EncodeField(write, value): + def EncodeField(write, value, deterministic): write(start_tag) - value._InternalSerialize(write) + value._InternalSerialize(write, deterministic) return write(end_tag) return EncodeField @@ -753,17 +753,17 @@ def MessageEncoder(field_number, is_repeated, is_packed): local_EncodeVarint = _EncodeVarint assert not is_packed if is_repeated: - def EncodeRepeatedField(write, value): + def EncodeRepeatedField(write, value, deterministic): for element in value: write(tag) - local_EncodeVarint(write, element.ByteSize()) - element._InternalSerialize(write) + local_EncodeVarint(write, element.ByteSize(), deterministic) + element._InternalSerialize(write, deterministic) return EncodeRepeatedField else: - def EncodeField(write, value): + def EncodeField(write, value, deterministic): write(tag) - local_EncodeVarint(write, value.ByteSize()) - return value._InternalSerialize(write) + local_EncodeVarint(write, value.ByteSize(), deterministic) + return value._InternalSerialize(write, deterministic) return EncodeField @@ -790,10 +790,10 @@ def MessageSetItemEncoder(field_number): end_bytes = TagBytes(1, wire_format.WIRETYPE_END_GROUP) local_EncodeVarint = _EncodeVarint - def EncodeField(write, value): + def EncodeField(write, value, deterministic): write(start_bytes) - local_EncodeVarint(write, value.ByteSize()) - value._InternalSerialize(write) + local_EncodeVarint(write, value.ByteSize(), deterministic) + value._InternalSerialize(write, deterministic) return write(end_bytes) return EncodeField @@ -818,9 +818,10 @@ def MapEncoder(field_descriptor): message_type = field_descriptor.message_type encode_message = MessageEncoder(field_descriptor.number, False, False) - def EncodeField(write, value): - for key in value: + def EncodeField(write, value, deterministic): + value_keys = sorted(value.keys()) if deterministic else value.keys() + for key in value_keys: entry_msg = message_type._concrete_class(key=key, value=value[key]) - encode_message(write, entry_msg) + encode_message(write, entry_msg, deterministic) return EncodeField diff --git a/python/google/protobuf/internal/factory_test2.proto b/python/google/protobuf/internal/factory_test2.proto index bb1b54ad..5fcbc5ac 100644 --- a/python/google/protobuf/internal/factory_test2.proto +++ b/python/google/protobuf/internal/factory_test2.proto @@ -97,3 +97,8 @@ message MessageWithNestedEnumOnly { extend Factory1Message { optional string another_field = 1002; } + +message MessageWithOption { + option no_standard_descriptor_accessor = true; + optional int32 field1 = 1; +} diff --git a/python/google/protobuf/internal/json_format_test.py b/python/google/protobuf/internal/json_format_test.py index 5ed65622..077b64db 100644 --- a/python/google/protobuf/internal/json_format_test.py +++ b/python/google/protobuf/internal/json_format_test.py @@ -49,6 +49,7 @@ from google.protobuf import field_mask_pb2 from google.protobuf import struct_pb2 from google.protobuf import timestamp_pb2 from google.protobuf import wrappers_pb2 +from google.protobuf import unittest_mset_pb2 from google.protobuf.internal import well_known_types from google.protobuf import json_format from google.protobuf.util import json_format_proto3_pb2 @@ -158,6 +159,84 @@ class JsonFormatTest(JsonFormatBase): json_format.Parse(text, parsed_message) self.assertEqual(message, parsed_message) + def testExtensionToJsonAndBack(self): + message = unittest_mset_pb2.TestMessageSetContainer() + ext1 = unittest_mset_pb2.TestMessageSetExtension1.message_set_extension + ext2 = unittest_mset_pb2.TestMessageSetExtension2.message_set_extension + message.message_set.Extensions[ext1].i = 23 + message.message_set.Extensions[ext2].str = 'foo' + message_text = json_format.MessageToJson( + message + ) + parsed_message = unittest_mset_pb2.TestMessageSetContainer() + json_format.Parse(message_text, parsed_message) + self.assertEqual(message, parsed_message) + + def testExtensionToDictAndBack(self): + message = unittest_mset_pb2.TestMessageSetContainer() + ext1 = unittest_mset_pb2.TestMessageSetExtension1.message_set_extension + ext2 = unittest_mset_pb2.TestMessageSetExtension2.message_set_extension + message.message_set.Extensions[ext1].i = 23 + message.message_set.Extensions[ext2].str = 'foo' + message_dict = json_format.MessageToDict( + message + ) + parsed_message = unittest_mset_pb2.TestMessageSetContainer() + json_format.ParseDict(message_dict, parsed_message) + self.assertEqual(message, parsed_message) + + def testExtensionSerializationDictMatchesProto3Spec(self): + """See go/proto3-json-spec for spec. + """ + message = unittest_mset_pb2.TestMessageSetContainer() + ext1 = unittest_mset_pb2.TestMessageSetExtension1.message_set_extension + ext2 = unittest_mset_pb2.TestMessageSetExtension2.message_set_extension + message.message_set.Extensions[ext1].i = 23 + message.message_set.Extensions[ext2].str = 'foo' + message_dict = json_format.MessageToDict( + message + ) + golden_dict = { + 'messageSet': { + '[protobuf_unittest.' + 'TestMessageSetExtension1.messageSetExtension]': { + 'i': 23, + }, + '[protobuf_unittest.' + 'TestMessageSetExtension2.messageSetExtension]': { + 'str': u'foo', + }, + }, + } + self.assertEqual(golden_dict, message_dict) + + + def testExtensionSerializationJsonMatchesProto3Spec(self): + """See go/proto3-json-spec for spec. + """ + message = unittest_mset_pb2.TestMessageSetContainer() + ext1 = unittest_mset_pb2.TestMessageSetExtension1.message_set_extension + ext2 = unittest_mset_pb2.TestMessageSetExtension2.message_set_extension + message.message_set.Extensions[ext1].i = 23 + message.message_set.Extensions[ext2].str = 'foo' + message_text = json_format.MessageToJson( + message + ) + ext1_text = ('protobuf_unittest.TestMessageSetExtension1.' + 'messageSetExtension') + ext2_text = ('protobuf_unittest.TestMessageSetExtension2.' + 'messageSetExtension') + golden_text = ('{"messageSet": {' + ' "[%s]": {' + ' "i": 23' + ' },' + ' "[%s]": {' + ' "str": "foo"' + ' }' + '}}') % (ext1_text, ext2_text) + self.assertEqual(json.loads(golden_text), json.loads(message_text)) + + def testJsonEscapeString(self): message = json_format_proto3_pb2.TestMessage() if sys.version_info[0] < 3: @@ -768,7 +847,7 @@ class JsonFormatTest(JsonFormatBase): text = '{"value": "0000-01-01T00:00:00Z"}' self.assertRaisesRegexp( json_format.ParseError, - 'Failed to parse value field: year is out of range.', + 'Failed to parse value field: year (0 )?is out of range.', json_format.Parse, text, message) # Time bigger than maxinum time. message.value.seconds = 253402300800 @@ -840,6 +919,12 @@ class JsonFormatTest(JsonFormatBase): json_format.Parse('{"int32_value": 12345}', message) self.assertEqual(12345, message.int32_value) + def testIndent(self): + message = json_format_proto3_pb2.TestMessage() + message.int32_value = 12345 + self.assertEqual('{\n"int32Value": 12345\n}', + json_format.MessageToJson(message, indent=0)) + def testParseDict(self): expected = 12345 js_dict = {'int32Value': expected} @@ -862,6 +947,22 @@ class JsonFormatTest(JsonFormatBase): parsed_message = json_format_proto3_pb2.TestCustomJsonName() self.CheckParseBack(message, parsed_message) + def testSortKeys(self): + # Testing sort_keys is not perfectly working, as by random luck we could + # get the output sorted. We just use a selection of names. + message = json_format_proto3_pb2.TestMessage(bool_value=True, + int32_value=1, + int64_value=3, + uint32_value=4, + string_value='bla') + self.assertEqual( + json_format.MessageToJson(message, sort_keys=True), + # We use json.dumps() instead of a hardcoded string due to differences + # between Python 2 and Python 3. + json.dumps({'boolValue': True, 'int32Value': 1, 'int64Value': '3', + 'uint32Value': 4, 'stringValue': 'bla'}, + indent=2, sort_keys=True)) + if __name__ == '__main__': unittest.main() diff --git a/python/google/protobuf/internal/message_test.py b/python/google/protobuf/internal/message_test.py index 3373b21f..dda72cdd 100755 --- a/python/google/protobuf/internal/message_test.py +++ b/python/google/protobuf/internal/message_test.py @@ -140,6 +140,42 @@ class MessageTest(BaseTestCase): golden_copy = copy.deepcopy(golden_message) self.assertEqual(golden_data, golden_copy.SerializeToString()) + def testDeterminismParameters(self, message_module): + # This message is always deterministically serialized, even if determinism + # is disabled, so we can use it to verify that all the determinism + # parameters work correctly. + golden_data = (b'\xe2\x02\nOne string' + b'\xe2\x02\nTwo string' + b'\xe2\x02\nRed string' + b'\xe2\x02\x0bBlue string') + golden_message = message_module.TestAllTypes() + golden_message.repeated_string.extend([ + 'One string', + 'Two string', + 'Red string', + 'Blue string', + ]) + self.assertEqual(golden_data, + golden_message.SerializeToString(deterministic=None)) + self.assertEqual(golden_data, + golden_message.SerializeToString(deterministic=False)) + self.assertEqual(golden_data, + golden_message.SerializeToString(deterministic=True)) + + class BadArgError(Exception): + pass + + class BadArg(object): + + def __nonzero__(self): + raise BadArgError() + + def __bool__(self): + raise BadArgError() + + with self.assertRaises(BadArgError): + golden_message.SerializeToString(deterministic=BadArg()) + def testPickleSupport(self, message_module): golden_data = test_util.GoldenFileData('golden_message') golden_message = message_module.TestAllTypes() @@ -381,6 +417,7 @@ class MessageTest(BaseTestCase): self.assertEqual(message.repeated_int32[0], 1) self.assertEqual(message.repeated_int32[1], 2) self.assertEqual(message.repeated_int32[2], 3) + self.assertEqual(str(message.repeated_int32), str([1, 2, 3])) message.repeated_float.append(1.1) message.repeated_float.append(1.3) @@ -397,6 +434,7 @@ class MessageTest(BaseTestCase): self.assertEqual(message.repeated_string[0], 'a') self.assertEqual(message.repeated_string[1], 'b') self.assertEqual(message.repeated_string[2], 'c') + self.assertEqual(str(message.repeated_string), str([u'a', u'b', u'c'])) message.repeated_bytes.append(b'a') message.repeated_bytes.append(b'c') @@ -405,6 +443,7 @@ class MessageTest(BaseTestCase): self.assertEqual(message.repeated_bytes[0], b'a') self.assertEqual(message.repeated_bytes[1], b'b') self.assertEqual(message.repeated_bytes[2], b'c') + self.assertEqual(str(message.repeated_bytes), str([b'a', b'b', b'c'])) def testSortingRepeatedScalarFieldsCustomComparator(self, message_module): """Check some different types with custom comparator.""" @@ -443,6 +482,8 @@ class MessageTest(BaseTestCase): self.assertEqual(message.repeated_nested_message[3].bb, 4) self.assertEqual(message.repeated_nested_message[4].bb, 5) self.assertEqual(message.repeated_nested_message[5].bb, 6) + self.assertEqual(str(message.repeated_nested_message), + '[bb: 1\n, bb: 2\n, bb: 3\n, bb: 4\n, bb: 5\n, bb: 6\n]') def testSortingRepeatedCompositeFieldsStable(self, message_module): """Check passing a custom comparator to sort a repeated composite field.""" @@ -1270,6 +1311,14 @@ class Proto3Test(BaseTestCase): self.assertEqual(1234567, m2.optional_nested_enum) self.assertEqual(7654321, m2.repeated_nested_enum[0]) + # ParseFromString in Proto2 should accept unknown enums too. + m3 = unittest_pb2.TestAllTypes() + m3.ParseFromString(serialized) + m2.Clear() + m2.ParseFromString(m3.SerializeToString()) + self.assertEqual(1234567, m2.optional_nested_enum) + self.assertEqual(7654321, m2.repeated_nested_enum[0]) + # Map isn't really a proto3-only feature. But there is no proto2 equivalent # of google/protobuf/map_unittest.proto right now, so it's not easy to # test both with the same test like we do for the other proto2/proto3 tests. @@ -1441,6 +1490,23 @@ class Proto3Test(BaseTestCase): self.assertIn(-456, msg2.map_int32_foreign_message) self.assertEqual(2, len(msg2.map_int32_foreign_message)) + def testNestedMessageMapItemDelete(self): + msg = map_unittest_pb2.TestMap() + msg.map_int32_all_types[1].optional_nested_message.bb = 1 + del msg.map_int32_all_types[1] + msg.map_int32_all_types[2].optional_nested_message.bb = 2 + self.assertEqual(1, len(msg.map_int32_all_types)) + msg.map_int32_all_types[1].optional_nested_message.bb = 1 + self.assertEqual(2, len(msg.map_int32_all_types)) + + serialized = msg.SerializeToString() + msg2 = map_unittest_pb2.TestMap() + msg2.ParseFromString(serialized) + keys = [1, 2] + # The loop triggers PyErr_Occurred() in c extension. + for key in keys: + del msg2.map_int32_all_types[key] + def testMapByteSize(self): msg = map_unittest_pb2.TestMap() msg.map_int32_int32[1] = 1 @@ -1655,6 +1721,35 @@ class Proto3Test(BaseTestCase): items2 = msg.map_string_string.items() self.assertEqual(items1, items2) + def testMapDeterministicSerialization(self): + golden_data = (b'r\x0c\n\x07init_op\x12\x01d' + b'r\n\n\x05item1\x12\x01e' + b'r\n\n\x05item2\x12\x01f' + b'r\n\n\x05item3\x12\x01g' + b'r\x0b\n\x05item4\x12\x02QQ' + b'r\x12\n\rlocal_init_op\x12\x01a' + b'r\x0e\n\tsummaries\x12\x01e' + b'r\x18\n\x13trainable_variables\x12\x01b' + b'r\x0e\n\tvariables\x12\x01c') + msg = map_unittest_pb2.TestMap() + msg.map_string_string['local_init_op'] = 'a' + msg.map_string_string['trainable_variables'] = 'b' + msg.map_string_string['variables'] = 'c' + msg.map_string_string['init_op'] = 'd' + msg.map_string_string['summaries'] = 'e' + msg.map_string_string['item1'] = 'e' + msg.map_string_string['item2'] = 'f' + msg.map_string_string['item3'] = 'g' + msg.map_string_string['item4'] = 'QQ' + + # If deterministic serialization is not working correctly, this will be + # "flaky" depending on the exact python dict hash seed. + # + # Fortunately, there are enough items in this map that it is extremely + # unlikely to ever hit the "right" in-order combination, so the test + # itself should fail reliably. + self.assertEqual(golden_data, msg.SerializeToString(deterministic=True)) + def testMapIterationClearMessage(self): # Iterator needs to work even if message and map are deleted. msg = map_unittest_pb2.TestMap() diff --git a/python/google/protobuf/internal/more_extensions_dynamic.proto b/python/google/protobuf/internal/more_extensions_dynamic.proto index 11f85ef6..98fcbcb6 100644 --- a/python/google/protobuf/internal/more_extensions_dynamic.proto +++ b/python/google/protobuf/internal/more_extensions_dynamic.proto @@ -47,4 +47,5 @@ message DynamicMessageType { extend ExtendedMessage { optional int32 dynamic_int32_extension = 100; optional DynamicMessageType dynamic_message_extension = 101; + repeated DynamicMessageType repeated_dynamic_message_extension = 102; } diff --git a/python/google/protobuf/internal/python_message.py b/python/google/protobuf/internal/python_message.py index cb97cb28..c363d843 100755 --- a/python/google/protobuf/internal/python_message.py +++ b/python/google/protobuf/internal/python_message.py @@ -58,6 +58,7 @@ import weakref import six # We use "as" to avoid name collisions with variables. +from google.protobuf.internal import api_implementation from google.protobuf.internal import containers from google.protobuf.internal import decoder from google.protobuf.internal import encoder @@ -1026,29 +1027,34 @@ def _AddByteSizeMethod(message_descriptor, cls): def _AddSerializeToStringMethod(message_descriptor, cls): """Helper for _AddMessageMethods().""" - def SerializeToString(self): + def SerializeToString(self, **kwargs): # Check if the message has all of its required fields set. errors = [] if not self.IsInitialized(): raise message_mod.EncodeError( 'Message %s is missing required fields: %s' % ( self.DESCRIPTOR.full_name, ','.join(self.FindInitializationErrors()))) - return self.SerializePartialToString() + return self.SerializePartialToString(**kwargs) cls.SerializeToString = SerializeToString def _AddSerializePartialToStringMethod(message_descriptor, cls): """Helper for _AddMessageMethods().""" - def SerializePartialToString(self): + def SerializePartialToString(self, **kwargs): out = BytesIO() - self._InternalSerialize(out.write) + self._InternalSerialize(out.write, **kwargs) return out.getvalue() cls.SerializePartialToString = SerializePartialToString - def InternalSerialize(self, write_bytes): + def InternalSerialize(self, write_bytes, deterministic=None): + if deterministic is None: + deterministic = ( + api_implementation.IsPythonDefaultSerializationDeterministic()) + else: + deterministic = bool(deterministic) for field_descriptor, field_value in self.ListFields(): - field_descriptor._encoder(write_bytes, field_value) + field_descriptor._encoder(write_bytes, field_value, deterministic) for tag_bytes, value_bytes in self._unknown_fields: write_bytes(tag_bytes) write_bytes(value_bytes) diff --git a/python/google/protobuf/internal/text_format_test.py b/python/google/protobuf/internal/text_format_test.py index 188310b2..424b29cc 100755 --- a/python/google/protobuf/internal/text_format_test.py +++ b/python/google/protobuf/internal/text_format_test.py @@ -452,16 +452,18 @@ class TextFormatTest(TextFormatBase): text_format.Parse(text_format.MessageToString(m), m2) self.assertEqual('oneof_uint32', m2.WhichOneof('oneof_field')) + def testMergeMultipleOneof(self, message_module): + m_string = '\n'.join(['oneof_uint32: 11', 'oneof_string: "foo"']) + m2 = message_module.TestAllTypes() + text_format.Merge(m_string, m2) + self.assertEqual('oneof_string', m2.WhichOneof('oneof_field')) + def testParseMultipleOneof(self, message_module): m_string = '\n'.join(['oneof_uint32: 11', 'oneof_string: "foo"']) m2 = message_module.TestAllTypes() - if message_module is unittest_pb2: - with self.assertRaisesRegexp(text_format.ParseError, - ' is specified along with field '): - text_format.Parse(m_string, m2) - else: + with self.assertRaisesRegexp(text_format.ParseError, + ' is specified along with field '): text_format.Parse(m_string, m2) - self.assertEqual('oneof_string', m2.WhichOneof('oneof_field')) # These are tests that aren't fundamentally specific to proto2, but are at @@ -1026,8 +1028,7 @@ class Proto3Tests(unittest.TestCase): packed_message.data = 'string1' message.repeated_any_value.add().Pack(packed_message) self.assertEqual( - text_format.MessageToString(message, - descriptor_pool=descriptor_pool.Default()), + text_format.MessageToString(message), 'repeated_any_value {\n' ' [type.googleapis.com/protobuf_unittest.OneString] {\n' ' data: "string0"\n' @@ -1039,18 +1040,6 @@ class Proto3Tests(unittest.TestCase): ' }\n' '}\n') - def testPrintMessageExpandAnyNoDescriptorPool(self): - packed_message = unittest_pb2.OneString() - packed_message.data = 'string' - message = any_test_pb2.TestAny() - message.any_value.Pack(packed_message) - self.assertEqual( - text_format.MessageToString(message, descriptor_pool=None), - 'any_value {\n' - ' type_url: "type.googleapis.com/protobuf_unittest.OneString"\n' - ' value: "\\n\\006string"\n' - '}\n') - def testPrintMessageExpandAnyDescriptorPoolMissingType(self): packed_message = unittest_pb2.OneString() packed_message.data = 'string' @@ -1071,8 +1060,7 @@ class Proto3Tests(unittest.TestCase): message.any_value.Pack(packed_message) self.assertEqual( text_format.MessageToString(message, - pointy_brackets=True, - descriptor_pool=descriptor_pool.Default()), + pointy_brackets=True), 'any_value <\n' ' [type.googleapis.com/protobuf_unittest.OneString] <\n' ' data: "string"\n' @@ -1086,8 +1074,7 @@ class Proto3Tests(unittest.TestCase): message.any_value.Pack(packed_message) self.assertEqual( text_format.MessageToString(message, - as_one_line=True, - descriptor_pool=descriptor_pool.Default()), + as_one_line=True), 'any_value {' ' [type.googleapis.com/protobuf_unittest.OneString]' ' { data: "string" } ' @@ -1115,12 +1102,12 @@ class Proto3Tests(unittest.TestCase): ' data: "string"\n' ' }\n' '}\n') - text_format.Merge(text, message, descriptor_pool=descriptor_pool.Default()) + text_format.Merge(text, message) packed_message = unittest_pb2.OneString() message.any_value.Unpack(packed_message) self.assertEqual('string', packed_message.data) message.Clear() - text_format.Parse(text, message, descriptor_pool=descriptor_pool.Default()) + text_format.Parse(text, message) packed_message = unittest_pb2.OneString() message.any_value.Unpack(packed_message) self.assertEqual('string', packed_message.data) @@ -1137,7 +1124,7 @@ class Proto3Tests(unittest.TestCase): ' data: "string1"\n' ' }\n' '}\n') - text_format.Merge(text, message, descriptor_pool=descriptor_pool.Default()) + text_format.Merge(text, message) packed_message = unittest_pb2.OneString() message.repeated_any_value[0].Unpack(packed_message) self.assertEqual('string0', packed_message.data) @@ -1151,22 +1138,22 @@ class Proto3Tests(unittest.TestCase): ' data: "string"\n' ' >\n' '}\n') - text_format.Merge(text, message, descriptor_pool=descriptor_pool.Default()) + text_format.Merge(text, message) packed_message = unittest_pb2.OneString() message.any_value.Unpack(packed_message) self.assertEqual('string', packed_message.data) - def testMergeExpandedAnyNoDescriptorPool(self): + def testMergeAlternativeUrl(self): message = any_test_pb2.TestAny() text = ('any_value {\n' - ' [type.googleapis.com/protobuf_unittest.OneString] {\n' + ' [type.otherapi.com/protobuf_unittest.OneString] {\n' ' data: "string"\n' ' }\n' '}\n') - with self.assertRaises(text_format.ParseError) as e: - text_format.Merge(text, message, descriptor_pool=None) - self.assertEqual(str(e.exception), - 'Descriptor pool required to parse expanded Any field') + text_format.Merge(text, message) + packed_message = unittest_pb2.OneString() + self.assertEqual('type.otherapi.com/protobuf_unittest.OneString', + message.any_value.type_url) def testMergeExpandedAnyDescriptorPoolMissingType(self): message = any_test_pb2.TestAny() @@ -1425,5 +1412,101 @@ class TokenizerTest(unittest.TestCase): tokenizer.ConsumeCommentOrTrailingComment()) self.assertTrue(tokenizer.AtEnd()) + +# Tests for pretty printer functionality. +@_parameterized.Parameters((unittest_pb2), (unittest_proto3_arena_pb2)) +class PrettyPrinterTest(TextFormatBase): + + def testPrettyPrintNoMatch(self, message_module): + + def printer(message, indent, as_one_line): + del message, indent, as_one_line + return None + + message = message_module.TestAllTypes() + msg = message.repeated_nested_message.add() + msg.bb = 42 + self.CompareToGoldenText( + text_format.MessageToString( + message, as_one_line=True, message_formatter=printer), + 'repeated_nested_message { bb: 42 }') + + def testPrettyPrintOneLine(self, message_module): + + def printer(m, indent, as_one_line): + del indent, as_one_line + if m.DESCRIPTOR == message_module.TestAllTypes.NestedMessage.DESCRIPTOR: + return 'My lucky number is %s' % m.bb + + message = message_module.TestAllTypes() + msg = message.repeated_nested_message.add() + msg.bb = 42 + self.CompareToGoldenText( + text_format.MessageToString( + message, as_one_line=True, message_formatter=printer), + 'repeated_nested_message { My lucky number is 42 }') + + def testPrettyPrintMultiLine(self, message_module): + + def printer(m, indent, as_one_line): + if m.DESCRIPTOR == message_module.TestAllTypes.NestedMessage.DESCRIPTOR: + line_deliminator = (' ' if as_one_line else '\n') + ' ' * indent + return 'My lucky number is:%s%s' % (line_deliminator, m.bb) + return None + + message = message_module.TestAllTypes() + msg = message.repeated_nested_message.add() + msg.bb = 42 + self.CompareToGoldenText( + text_format.MessageToString( + message, as_one_line=True, message_formatter=printer), + 'repeated_nested_message { My lucky number is: 42 }') + self.CompareToGoldenText( + text_format.MessageToString( + message, as_one_line=False, message_formatter=printer), + 'repeated_nested_message {\n My lucky number is:\n 42\n}\n') + + def testPrettyPrintEntireMessage(self, message_module): + + def printer(m, indent, as_one_line): + del indent, as_one_line + if m.DESCRIPTOR == message_module.TestAllTypes.DESCRIPTOR: + return 'The is the message!' + return None + + message = message_module.TestAllTypes() + self.CompareToGoldenText( + text_format.MessageToString( + message, as_one_line=False, message_formatter=printer), + 'The is the message!\n') + self.CompareToGoldenText( + text_format.MessageToString( + message, as_one_line=True, message_formatter=printer), + 'The is the message!') + + def testPrettyPrintMultipleParts(self, message_module): + + def printer(m, indent, as_one_line): + del indent, as_one_line + if m.DESCRIPTOR == message_module.TestAllTypes.NestedMessage.DESCRIPTOR: + return 'My lucky number is %s' % m.bb + return None + + message = message_module.TestAllTypes() + message.optional_int32 = 61 + msg = message.repeated_nested_message.add() + msg.bb = 42 + msg = message.repeated_nested_message.add() + msg.bb = 99 + msg = message.optional_nested_message + msg.bb = 1 + self.CompareToGoldenText( + text_format.MessageToString( + message, as_one_line=True, message_formatter=printer), + ('optional_int32: 61 ' + 'optional_nested_message { My lucky number is 1 } ' + 'repeated_nested_message { My lucky number is 42 } ' + 'repeated_nested_message { My lucky number is 99 }')) + if __name__ == '__main__': unittest.main() diff --git a/python/google/protobuf/internal/well_known_types.py b/python/google/protobuf/internal/well_known_types.py index d631abee..d0c7ffda 100644 --- a/python/google/protobuf/internal/well_known_types.py +++ b/python/google/protobuf/internal/well_known_types.py @@ -350,12 +350,12 @@ class Duration(object): self.nanos, _NANOS_PER_MICROSECOND)) def FromTimedelta(self, td): - """Convertd timedelta to Duration.""" + """Converts timedelta to Duration.""" self._NormalizeDuration(td.seconds + td.days * _SECONDS_PER_DAY, td.microseconds * _NANOS_PER_MICROSECOND) def _NormalizeDuration(self, seconds, nanos): - """Set Duration by seconds and nonas.""" + """Set Duration by seconds and nanos.""" # Force nanos to be negative if the duration is negative. if seconds < 0 and nanos > 0: seconds += 1 diff --git a/python/google/protobuf/internal/well_known_types_test.py b/python/google/protobuf/internal/well_known_types_test.py index 077f630f..123a537c 100644 --- a/python/google/protobuf/internal/well_known_types_test.py +++ b/python/google/protobuf/internal/well_known_types_test.py @@ -284,7 +284,7 @@ class TimeUtilTest(TimeUtilTestBase): '1972-01-01T01:00:00.01+08',) self.assertRaisesRegexp( ValueError, - 'year is out of range', + 'year (0 )?is out of range', message.FromJsonString, '0000-01-01T00:00:00Z') message.seconds = 253402300800 diff --git a/python/google/protobuf/json_format.py b/python/google/protobuf/json_format.py index 937285b0..801eed60 100644 --- a/python/google/protobuf/json_format.py +++ b/python/google/protobuf/json_format.py @@ -74,6 +74,9 @@ _UNPAIRED_SURROGATE_PATTERN = re.compile(six.u( r'[\ud800-\udbff](?![\udc00-\udfff])|(?<![\ud800-\udbff])[\udc00-\udfff]' )) +_VALID_EXTENSION_NAME = re.compile(r'\[[a-zA-Z0-9\._]*\]$') + + class Error(Exception): """Top-level module error for json_format.""" @@ -88,7 +91,9 @@ class ParseError(Error): def MessageToJson(message, including_default_value_fields=False, - preserving_proto_field_name=False): + preserving_proto_field_name=False, + indent=2, + sort_keys=False): """Converts protobuf message to JSON format. Args: @@ -100,19 +105,24 @@ def MessageToJson(message, preserving_proto_field_name: If True, use the original proto field names as defined in the .proto file. If False, convert the field names to lowerCamelCase. + indent: The JSON object will be pretty-printed with this indent level. + An indent level of 0 or negative will only insert newlines. + sort_keys: If True, then the output will be sorted by field names. Returns: A string containing the JSON formatted protocol buffer message. """ printer = _Printer(including_default_value_fields, preserving_proto_field_name) - return printer.ToJsonString(message) + return printer.ToJsonString(message, indent, sort_keys) def MessageToDict(message, including_default_value_fields=False, preserving_proto_field_name=False): - """Converts protobuf message to a JSON dictionary. + """Converts protobuf message to a dictionary. + + When the dictionary is encoded to JSON, it conforms to proto3 JSON spec. Args: message: The protocol buffers message instance to serialize. @@ -125,7 +135,7 @@ def MessageToDict(message, names to lowerCamelCase. Returns: - A dict representation of the JSON formatted protocol buffer message. + A dict representation of the protocol buffer message. """ printer = _Printer(including_default_value_fields, preserving_proto_field_name) @@ -148,9 +158,9 @@ class _Printer(object): self.including_default_value_fields = including_default_value_fields self.preserving_proto_field_name = preserving_proto_field_name - def ToJsonString(self, message): + def ToJsonString(self, message, indent, sort_keys): js = self._MessageToJsonObject(message) - return json.dumps(js, indent=2) + return json.dumps(js, indent=indent, sort_keys=sort_keys) def _MessageToJsonObject(self, message): """Converts message to an object according to Proto3 JSON Specification.""" @@ -192,6 +202,14 @@ class _Printer(object): # Convert a repeated field. js[name] = [self._FieldToJsonObject(field, k) for k in value] + elif field.is_extension: + f = field + if (f.containing_type.GetOptions().message_set_wire_format and + f.type == descriptor.FieldDescriptor.TYPE_MESSAGE and + f.label == descriptor.FieldDescriptor.LABEL_OPTIONAL): + f = f.message_type + name = '[%s.%s]' % (f.full_name, name) + js[name] = self._FieldToJsonObject(field, value) else: js[name] = self._FieldToJsonObject(field, value) @@ -433,12 +451,23 @@ class _Parser(object): field = fields_by_json_name.get(name, None) if not field: field = message_descriptor.fields_by_name.get(name, None) + if not field and _VALID_EXTENSION_NAME.match(name): + if not message_descriptor.is_extendable: + raise ParseError('Message type {0} does not have extensions'.format( + message_descriptor.full_name)) + identifier = name[1:-1] # strip [] brackets + identifier = '.'.join(identifier.split('.')[:-1]) + # pylint: disable=protected-access + field = message.Extensions._FindExtensionByName(identifier) + # pylint: enable=protected-access if not field: if self.ignore_unknown_fields: continue raise ParseError( - 'Message type "{0}" has no field named "{1}".'.format( - message_descriptor.full_name, name)) + ('Message type "{0}" has no field named "{1}".\n' + ' Available Fields(except extensions): {2}').format( + message_descriptor.full_name, name, + message_descriptor.fields)) if name in names: raise ParseError('Message type "{0}" should not have multiple ' '"{1}" fields.'.format( @@ -491,7 +520,10 @@ class _Parser(object): getattr(message, field.name).append( _ConvertScalarFieldValue(item, field)) elif field.cpp_type == descriptor.FieldDescriptor.CPPTYPE_MESSAGE: - sub_message = getattr(message, field.name) + if field.is_extension: + sub_message = message.Extensions[field] + else: + sub_message = getattr(message, field.name) sub_message.SetInParent() self.ConvertMessage(value, sub_message) else: diff --git a/python/google/protobuf/message.py b/python/google/protobuf/message.py index aab250e4..eeb0d576 100755 --- a/python/google/protobuf/message.py +++ b/python/google/protobuf/message.py @@ -184,9 +184,15 @@ class Message(object): self.Clear() self.MergeFromString(serialized) - def SerializeToString(self): + def SerializeToString(self, **kwargs): """Serializes the protocol message to a binary string. + Arguments: + **kwargs: Keyword arguments to the serialize method, accepts + the following keyword args: + deterministic: If true, requests deterministic serialization of the + protobuf, with predictable ordering of map keys. + Returns: A binary string representation of the message if all of the required fields in the message are set (i.e. the message is initialized). @@ -196,12 +202,18 @@ class Message(object): """ raise NotImplementedError - def SerializePartialToString(self): + def SerializePartialToString(self, **kwargs): """Serializes the protocol message to a binary string. This method is similar to SerializeToString but doesn't check if the message is initialized. + Arguments: + **kwargs: Keyword arguments to the serialize method, accepts + the following keyword args: + deterministic: If true, requests deterministic serialization of the + protobuf, with predictable ordering of map keys. + Returns: A string representation of the partial message. """ diff --git a/python/google/protobuf/message_factory.py b/python/google/protobuf/message_factory.py index 8ab1c513..15740280 100644 --- a/python/google/protobuf/message_factory.py +++ b/python/google/protobuf/message_factory.py @@ -66,7 +66,7 @@ class MessageFactory(object): Returns: A class describing the passed in descriptor. """ - if descriptor.full_name not in self._classes: + if descriptor not in self._classes: descriptor_name = descriptor.name if str is bytes: # PY2 descriptor_name = descriptor.name.encode('ascii', 'ignore') @@ -75,16 +75,16 @@ class MessageFactory(object): (message.Message,), {'DESCRIPTOR': descriptor, '__module__': None}) # If module not set, it wrongly points to the reflection.py module. - self._classes[descriptor.full_name] = result_class + self._classes[descriptor] = result_class for field in descriptor.fields: if field.message_type: self.GetPrototype(field.message_type) for extension in result_class.DESCRIPTOR.extensions: - if extension.containing_type.full_name not in self._classes: + if extension.containing_type not in self._classes: self.GetPrototype(extension.containing_type) - extended_class = self._classes[extension.containing_type.full_name] + extended_class = self._classes[extension.containing_type] extended_class.RegisterExtension(extension) - return self._classes[descriptor.full_name] + return self._classes[descriptor] def GetMessages(self, files): """Gets all the messages from a specified file. @@ -116,9 +116,9 @@ class MessageFactory(object): # an error if they were different. for extension in file_desc.extensions_by_name.values(): - if extension.containing_type.full_name not in self._classes: + if extension.containing_type not in self._classes: self.GetPrototype(extension.containing_type) - extended_class = self._classes[extension.containing_type.full_name] + extended_class = self._classes[extension.containing_type] extended_class.RegisterExtension(extension) return result diff --git a/python/google/protobuf/pyext/descriptor.cc b/python/google/protobuf/pyext/descriptor.cc index f13e1bc1..9634ea05 100644 --- a/python/google/protobuf/pyext/descriptor.cc +++ b/python/google/protobuf/pyext/descriptor.cc @@ -709,6 +709,10 @@ static PyObject* GetJsonName(PyBaseDescriptor* self, void *closure) { return PyString_FromCppString(_GetDescriptor(self)->json_name()); } +static PyObject* GetFile(PyBaseDescriptor *self, void *closure) { + return PyFileDescriptor_FromDescriptor(_GetDescriptor(self)->file()); +} + static PyObject* GetType(PyBaseDescriptor *self, void *closure) { return PyInt_FromLong(_GetDescriptor(self)->type()); } @@ -899,6 +903,7 @@ static PyGetSetDef Getters[] = { { "name", (getter)GetName, NULL, "Unqualified name"}, { "camelcase_name", (getter)GetCamelcaseName, NULL, "Camelcase name"}, { "json_name", (getter)GetJsonName, NULL, "Json name"}, + { "file", (getter)GetFile, NULL, "File Descriptor"}, { "type", (getter)GetType, NULL, "C++ Type"}, { "cpp_type", (getter)GetCppType, NULL, "C++ Type"}, { "label", (getter)GetLabel, NULL, "Label"}, @@ -1570,6 +1575,10 @@ static PyObject* GetFullName(PyBaseDescriptor* self, void *closure) { return PyString_FromCppString(_GetDescriptor(self)->full_name()); } +static PyObject* GetFile(PyBaseDescriptor *self, void *closure) { + return PyFileDescriptor_FromDescriptor(_GetDescriptor(self)->file()); +} + static PyObject* GetIndex(PyBaseDescriptor *self, void *closure) { return PyInt_FromLong(_GetDescriptor(self)->index()); } @@ -1611,6 +1620,7 @@ static PyObject* CopyToProto(PyBaseDescriptor *self, PyObject *target) { static PyGetSetDef Getters[] = { { "name", (getter)GetName, NULL, "Name", NULL}, { "full_name", (getter)GetFullName, NULL, "Full name", NULL}, + { "file", (getter)GetFile, NULL, "File descriptor"}, { "index", (getter)GetIndex, NULL, "Index", NULL}, { "methods", (getter)GetMethods, NULL, "Methods", NULL}, diff --git a/python/google/protobuf/pyext/map_container.cc b/python/google/protobuf/pyext/map_container.cc index 088ddf93..43be0701 100644 --- a/python/google/protobuf/pyext/map_container.cc +++ b/python/google/protobuf/pyext/map_container.cc @@ -712,8 +712,30 @@ int MapReflectionFriend::MessageMapSetItem(PyObject* _self, PyObject* key, } // Delete key from map. - if (reflection->DeleteMapValue(message, self->parent_field_descriptor, + if (reflection->ContainsMapKey(*message, self->parent_field_descriptor, map_key)) { + // Delete key from CMessage dict. + MapValueRef value; + reflection->InsertOrLookupMapValue(message, self->parent_field_descriptor, + map_key, &value); + ScopedPyObjectPtr key(PyLong_FromVoidPtr(value.MutableMessageValue())); + + // PyDict_DelItem will have key error if the key is not in the map. We do + // not want to call PyErr_Clear() which may clear other errors. Thus + // PyDict_Contains() check is called before delete. + int contains = PyDict_Contains(self->message_dict, key.get()); + if (contains < 0) { + return -1; + } + if (contains) { + if (PyDict_DelItem(self->message_dict, key.get()) < 0) { + return -1; + } + } + + // Delete key from map. + reflection->DeleteMapValue(message, self->parent_field_descriptor, + map_key); return 0; } else { PyErr_Format(PyExc_KeyError, "Key not present in map"); diff --git a/python/google/protobuf/pyext/message.cc b/python/google/protobuf/pyext/message.cc index 49113c7c..702c5d03 100644 --- a/python/google/protobuf/pyext/message.cc +++ b/python/google/protobuf/pyext/message.cc @@ -52,6 +52,7 @@ #include <google/protobuf/stubs/common.h> #include <google/protobuf/stubs/logging.h> #include <google/protobuf/io/coded_stream.h> +#include <google/protobuf/io/zero_copy_stream_impl_lite.h> #include <google/protobuf/util/message_differencer.h> #include <google/protobuf/descriptor.h> #include <google/protobuf/message.h> @@ -1810,8 +1811,25 @@ static string GetMessageName(CMessage* self) { } } -static PyObject* SerializeToString(CMessage* self, PyObject* args) { - if (!self->message->IsInitialized()) { +static PyObject* InternalSerializeToString( + CMessage* self, PyObject* args, PyObject* kwargs, + bool require_initialized) { + // Parse the "deterministic" kwarg; defaults to False. + static char* kwlist[] = { "deterministic", 0 }; + PyObject* deterministic_obj = Py_None; + if (!PyArg_ParseTupleAndKeywords(args, kwargs, "|O", kwlist, + &deterministic_obj)) { + return NULL; + } + // Preemptively convert to a bool first, so we don't need to back out of + // allocating memory if this raises an exception. + // NOTE: This is unused later if deterministic == Py_None, but that's fine. + int deterministic = PyObject_IsTrue(deterministic_obj); + if (deterministic < 0) { + return NULL; + } + + if (require_initialized && !self->message->IsInitialized()) { ScopedPyObjectPtr errors(FindInitializationErrors(self)); if (errors == NULL) { return NULL; @@ -1849,24 +1867,36 @@ static PyObject* SerializeToString(CMessage* self, PyObject* args) { GetMessageName(self).c_str(), PyString_AsString(joined.get())); return NULL; } - int size = self->message->ByteSize(); - if (size <= 0) { + + // Ok, arguments parsed and errors checked, now encode to a string + const size_t size = self->message->ByteSizeLong(); + if (size == 0) { return PyBytes_FromString(""); } PyObject* result = PyBytes_FromStringAndSize(NULL, size); if (result == NULL) { return NULL; } - char* buffer = PyBytes_AS_STRING(result); - self->message->SerializeWithCachedSizesToArray( - reinterpret_cast<uint8*>(buffer)); + io::ArrayOutputStream out(PyBytes_AS_STRING(result), size); + io::CodedOutputStream coded_out(&out); + if (deterministic_obj != Py_None) { + coded_out.SetSerializationDeterministic(deterministic); + } + self->message->SerializeWithCachedSizes(&coded_out); + GOOGLE_CHECK(!coded_out.HadError()); return result; } -static PyObject* SerializePartialToString(CMessage* self) { - string contents; - self->message->SerializePartialToString(&contents); - return PyBytes_FromStringAndSize(contents.c_str(), contents.size()); +static PyObject* SerializeToString( + CMessage* self, PyObject* args, PyObject* kwargs) { + return InternalSerializeToString(self, args, kwargs, + /*require_initialized=*/true); +} + +static PyObject* SerializePartialToString( + CMessage* self, PyObject* args, PyObject* kwargs) { + return InternalSerializeToString(self, args, kwargs, + /*require_initialized=*/false); } // Formats proto fields for ascii dumps using python formatting functions where @@ -2537,7 +2567,10 @@ PyObject* Reduce(CMessage* self) { if (state == NULL) { return NULL; } - ScopedPyObjectPtr serialized(SerializePartialToString(self)); + string contents; + self->message->SerializePartialToString(&contents); + ScopedPyObjectPtr serialized( + PyBytes_FromStringAndSize(contents.c_str(), contents.size())); if (serialized == NULL) { return NULL; } @@ -2658,9 +2691,10 @@ static PyMethodDef Methods[] = { { "RegisterExtension", (PyCFunction)RegisterExtension, METH_O | METH_CLASS, "Registers an extension with the current message." }, { "SerializePartialToString", (PyCFunction)SerializePartialToString, - METH_NOARGS, + METH_VARARGS | METH_KEYWORDS, "Serializes the message to a string, even if it isn't initialized." }, - { "SerializeToString", (PyCFunction)SerializeToString, METH_NOARGS, + { "SerializeToString", (PyCFunction)SerializeToString, + METH_VARARGS | METH_KEYWORDS, "Serializes the message to a string, only for initialized messages." }, { "SetInParent", (PyCFunction)SetInParent, METH_NOARGS, "Sets the has bit of the given field in its parent message." }, diff --git a/python/google/protobuf/pyext/message_factory.cc b/python/google/protobuf/pyext/message_factory.cc index e0b45bf2..571bae2b 100644 --- a/python/google/protobuf/pyext/message_factory.cc +++ b/python/google/protobuf/pyext/message_factory.cc @@ -133,11 +133,7 @@ int RegisterMessageClass(PyMessageFactory* self, CMessageClass* GetOrCreateMessageClass(PyMessageFactory* self, const Descriptor* descriptor) { // This is the same implementation as MessageFactory.GetPrototype(). - ScopedPyObjectPtr py_descriptor( - PyMessageDescriptor_FromDescriptor(descriptor)); - if (py_descriptor == NULL) { - return NULL; - } + // Do not create a MessageClass that already exists. hash_map<const Descriptor*, CMessageClass*>::iterator it = self->classes_by_descriptor->find(descriptor); @@ -145,6 +141,11 @@ CMessageClass* GetOrCreateMessageClass(PyMessageFactory* self, Py_INCREF(it->second); return it->second; } + ScopedPyObjectPtr py_descriptor( + PyMessageDescriptor_FromDescriptor(descriptor)); + if (py_descriptor == NULL) { + return NULL; + } // Create a new message class. ScopedPyObjectPtr args(Py_BuildValue( "s(){sOsOsO}", descriptor->name().c_str(), diff --git a/python/google/protobuf/pyext/repeated_composite_container.cc b/python/google/protobuf/pyext/repeated_composite_container.cc index 06a976de..5ad71db5 100644 --- a/python/google/protobuf/pyext/repeated_composite_container.cc +++ b/python/google/protobuf/pyext/repeated_composite_container.cc @@ -46,6 +46,7 @@ #include <google/protobuf/pyext/descriptor.h> #include <google/protobuf/pyext/descriptor_pool.h> #include <google/protobuf/pyext/message.h> +#include <google/protobuf/pyext/message_factory.h> #include <google/protobuf/pyext/scoped_pyobject_ptr.h> #include <google/protobuf/reflection.h> @@ -137,9 +138,12 @@ static PyObject* AddToAttached(RepeatedCompositeContainer* self, if (cmessage::AssureWritable(self->parent) == -1) return NULL; Message* message = self->message; + Message* sub_message = - message->GetReflection()->AddMessage(message, - self->parent_field_descriptor); + message->GetReflection()->AddMessage( + message, + self->parent_field_descriptor, + self->child_message_class->py_message_factory->message_factory); CMessage* cmsg = cmessage::NewEmptyMessage(self->child_message_class); if (cmsg == NULL) return NULL; @@ -336,6 +340,18 @@ static PyObject* RichCompare(RepeatedCompositeContainer* self, } } +static PyObject* ToStr(RepeatedCompositeContainer* self) { + ScopedPyObjectPtr full_slice(PySlice_New(NULL, NULL, NULL)); + if (full_slice == NULL) { + return NULL; + } + ScopedPyObjectPtr list(Subscript(self, full_slice.get())); + if (list == NULL) { + return NULL; + } + return PyObject_Repr(list.get()); +} + // --------------------------------------------------------------------- // sort() @@ -608,7 +624,7 @@ PyTypeObject RepeatedCompositeContainer_Type = { 0, // tp_getattr 0, // tp_setattr 0, // tp_compare - 0, // tp_repr + (reprfunc)repeated_composite_container::ToStr, // tp_repr 0, // tp_as_number &repeated_composite_container::SqMethods, // tp_as_sequence &repeated_composite_container::MpMethods, // tp_as_mapping diff --git a/python/google/protobuf/pyext/repeated_scalar_container.cc b/python/google/protobuf/pyext/repeated_scalar_container.cc index 9fedb952..54998800 100644 --- a/python/google/protobuf/pyext/repeated_scalar_container.cc +++ b/python/google/protobuf/pyext/repeated_scalar_container.cc @@ -659,6 +659,18 @@ static PyObject* Pop(RepeatedScalarContainer* self, return item; } +static PyObject* ToStr(RepeatedScalarContainer* self) { + ScopedPyObjectPtr full_slice(PySlice_New(NULL, NULL, NULL)); + if (full_slice == NULL) { + return NULL; + } + ScopedPyObjectPtr list(Subscript(self, full_slice.get())); + if (list == NULL) { + return NULL; + } + return PyObject_Repr(list.get()); +} + // The private constructor of RepeatedScalarContainer objects. PyObject *NewContainer( CMessage* parent, const FieldDescriptor* parent_field_descriptor) { @@ -781,7 +793,7 @@ PyTypeObject RepeatedScalarContainer_Type = { 0, // tp_getattr 0, // tp_setattr 0, // tp_compare - 0, // tp_repr + (reprfunc)repeated_scalar_container::ToStr, // tp_repr 0, // tp_as_number &repeated_scalar_container::SqMethods, // tp_as_sequence &repeated_scalar_container::MpMethods, // tp_as_mapping diff --git a/python/google/protobuf/symbol_database.py b/python/google/protobuf/symbol_database.py index 07341efa..5ad869f4 100644 --- a/python/google/protobuf/symbol_database.py +++ b/python/google/protobuf/symbol_database.py @@ -78,10 +78,18 @@ class SymbolDatabase(message_factory.MessageFactory): """ desc = message.DESCRIPTOR - self._classes[desc.full_name] = message - self.pool.AddDescriptor(desc) + self._classes[desc] = message + self.RegisterMessageDescriptor(desc) return message + def RegisterMessageDescriptor(self, message_descriptor): + """Registers the given message descriptor in the local database. + + Args: + message_descriptor: a descriptor.MessageDescriptor. + """ + self.pool.AddDescriptor(message_descriptor) + def RegisterEnumDescriptor(self, enum_descriptor): """Registers the given enum descriptor in the local database. @@ -132,7 +140,7 @@ class SymbolDatabase(message_factory.MessageFactory): KeyError: if the symbol could not be found. """ - return self._classes[symbol] + return self._classes[self.pool.FindMessageTypeByName(symbol)] def GetMessages(self, files): # TODO(amauryfa): Fix the differences with MessageFactory. @@ -153,20 +161,20 @@ class SymbolDatabase(message_factory.MessageFactory): KeyError: if a file could not be found. """ - def _GetAllMessageNames(desc): + def _GetAllMessages(desc): """Walk a message Descriptor and recursively yields all message names.""" - yield desc.full_name + yield desc for msg_desc in desc.nested_types: - for full_name in _GetAllMessageNames(msg_desc): - yield full_name + for nested_desc in _GetAllMessages(msg_desc): + yield nested_desc result = {} for file_name in files: file_desc = self.pool.FindFileByName(file_name) for msg_desc in file_desc.message_types_by_name.values(): - for full_name in _GetAllMessageNames(msg_desc): + for desc in _GetAllMessages(msg_desc): try: - result[full_name] = self._classes[full_name] + result[desc.full_name] = self._classes[desc] except KeyError: # This descriptor has no registered class, skip it. pass diff --git a/python/google/protobuf/text_format.py b/python/google/protobuf/text_format.py index c216e097..aaca78ad 100755 --- a/python/google/protobuf/text_format.py +++ b/python/google/protobuf/text_format.py @@ -126,7 +126,8 @@ def MessageToString(message, float_format=None, use_field_number=False, descriptor_pool=None, - indent=0): + indent=0, + message_formatter=None): """Convert protobuf message to text format. Floating point values can be formatted compactly with 15 digits of @@ -148,6 +149,9 @@ def MessageToString(message, use_field_number: If True, print field numbers instead of names. descriptor_pool: A DescriptorPool used to resolve Any types. indent: The indent level, in terms of spaces, for pretty print. + message_formatter: A function(message, indent, as_one_line): unicode|None + to custom format selected sub-messages (usually based on message type). + Use to pretty print parts of the protobuf for easier diffing. Returns: A string of the text formatted protocol buffer message. @@ -155,7 +159,7 @@ def MessageToString(message, out = TextWriter(as_utf8) printer = _Printer(out, indent, as_utf8, as_one_line, pointy_brackets, use_index_order, float_format, use_field_number, - descriptor_pool) + descriptor_pool, message_formatter) printer.PrintMessage(message) result = out.getvalue() out.close() @@ -179,10 +183,11 @@ def PrintMessage(message, use_index_order=False, float_format=None, use_field_number=False, - descriptor_pool=None): + descriptor_pool=None, + message_formatter=None): printer = _Printer(out, indent, as_utf8, as_one_line, pointy_brackets, use_index_order, float_format, use_field_number, - descriptor_pool) + descriptor_pool, message_formatter) printer.PrintMessage(message) @@ -194,10 +199,11 @@ def PrintField(field, as_one_line=False, pointy_brackets=False, use_index_order=False, - float_format=None): + float_format=None, + message_formatter=None): """Print a single field name/value pair.""" printer = _Printer(out, indent, as_utf8, as_one_line, pointy_brackets, - use_index_order, float_format) + use_index_order, float_format, message_formatter) printer.PrintField(field, value) @@ -209,10 +215,11 @@ def PrintFieldValue(field, as_one_line=False, pointy_brackets=False, use_index_order=False, - float_format=None): + float_format=None, + message_formatter=None): """Print a single field value (not including name).""" printer = _Printer(out, indent, as_utf8, as_one_line, pointy_brackets, - use_index_order, float_format) + use_index_order, float_format, message_formatter) printer.PrintFieldValue(field, value) @@ -228,6 +235,9 @@ def _BuildMessageFromTypeName(type_name, descriptor_pool): wasn't found matching type_name. """ # pylint: disable=g-import-not-at-top + if descriptor_pool is None: + from google.protobuf import descriptor_pool as pool_mod + descriptor_pool = pool_mod.Default() from google.protobuf import symbol_database database = symbol_database.Default() try: @@ -250,7 +260,8 @@ class _Printer(object): use_index_order=False, float_format=None, use_field_number=False, - descriptor_pool=None): + descriptor_pool=None, + message_formatter=None): """Initialize the Printer. Floating point values can be formatted compactly with 15 digits of @@ -273,6 +284,9 @@ class _Printer(object): used. use_field_number: If True, print field numbers instead of names. descriptor_pool: A DescriptorPool used to resolve Any types. + message_formatter: A function(message, indent, as_one_line): unicode|None + to custom format selected sub-messages (usually based on message type). + Use to pretty print parts of the protobuf for easier diffing. """ self.out = out self.indent = indent @@ -283,6 +297,7 @@ class _Printer(object): self.float_format = float_format self.use_field_number = use_field_number self.descriptor_pool = descriptor_pool + self.message_formatter = message_formatter def _TryPrintAsAnyMessage(self, message): """Serializes if message is a google.protobuf.Any field.""" @@ -297,14 +312,27 @@ class _Printer(object): else: return False + def _TryCustomFormatMessage(self, message): + formatted = self.message_formatter(message, self.indent, self.as_one_line) + if formatted is None: + return False + + out = self.out + out.write(' ' * self.indent) + out.write(formatted) + out.write(' ' if self.as_one_line else '\n') + return True + def PrintMessage(self, message): """Convert protobuf message to text format. Args: message: The protocol buffers message. """ + if self.message_formatter and self._TryCustomFormatMessage(message): + return if (message.DESCRIPTOR.full_name == _ANY_FULL_TYPE_NAME and - self.descriptor_pool and self._TryPrintAsAnyMessage(message)): + self._TryPrintAsAnyMessage(message)): return fields = message.ListFields() if self.use_index_order: @@ -426,6 +454,22 @@ def Parse(text, descriptor_pool=None): """Parses a text representation of a protocol message into a message. + NOTE: for historical reasons this function does not clear the input + message. This is different from what the binary msg.ParseFrom(...) does. + + Example + a = MyProto() + a.repeated_field.append('test') + b = MyProto() + + text_format.Parse(repr(a), b) + text_format.Parse(repr(a), b) # repeated_field contains ["test", "test"] + + # Binary version: + b.ParseFromString(a.SerializeToString()) # repeated_field is now "test" + + Caller is responsible for clearing the message as needed. + Args: text: Message text representation. message: A protocol buffer message to merge into. @@ -593,11 +637,6 @@ class _Parser(object): ParseError: In case of text parsing problems. """ message_descriptor = message.DESCRIPTOR - if (hasattr(message_descriptor, 'syntax') and - message_descriptor.syntax == 'proto3'): - # Proto3 doesn't represent presence so we can't test if multiple - # scalars have occurred. We have to allow them. - self._allow_multiple_scalars = True if tokenizer.TryConsume('['): name = [tokenizer.ConsumeIdentifier()] while tokenizer.TryConsume('.'): @@ -616,7 +655,11 @@ class _Parser(object): field = None else: raise tokenizer.ParseErrorPreviousToken( - 'Extension "%s" not registered.' % name) + 'Extension "%s" not registered. ' + 'Did you import the _pb2 module which defines it? ' + 'If you are trying to place the extension in the MessageSet ' + 'field of another message that is in an Any or MessageSet field, ' + 'that message\'s _pb2 module must be imported as well' % name) elif message_descriptor != field.containing_type: raise tokenizer.ParseErrorPreviousToken( 'Extension "%s" does not extend message type "%s".' % @@ -695,17 +738,17 @@ class _Parser(object): def _ConsumeAnyTypeUrl(self, tokenizer): """Consumes a google.protobuf.Any type URL and returns the type name.""" # Consume "type.googleapis.com/". - tokenizer.ConsumeIdentifier() + prefix = [tokenizer.ConsumeIdentifier()] tokenizer.Consume('.') - tokenizer.ConsumeIdentifier() + prefix.append(tokenizer.ConsumeIdentifier()) tokenizer.Consume('.') - tokenizer.ConsumeIdentifier() + prefix.append(tokenizer.ConsumeIdentifier()) tokenizer.Consume('/') # Consume the fully-qualified type name. name = [tokenizer.ConsumeIdentifier()] while tokenizer.TryConsume('.'): name.append(tokenizer.ConsumeIdentifier()) - return '.'.join(name) + return '.'.join(prefix), '.'.join(name) def _MergeMessageField(self, tokenizer, message, field): """Merges a single scalar field into a message. @@ -728,7 +771,7 @@ class _Parser(object): if (field.message_type.full_name == _ANY_FULL_TYPE_NAME and tokenizer.TryConsume('[')): - packed_type_name = self._ConsumeAnyTypeUrl(tokenizer) + type_url_prefix, packed_type_name = self._ConsumeAnyTypeUrl(tokenizer) tokenizer.Consume(']') tokenizer.TryConsume(':') if tokenizer.TryConsume('<'): @@ -736,8 +779,6 @@ class _Parser(object): else: tokenizer.Consume('{') expanded_any_end_token = '}' - if not self.descriptor_pool: - raise ParseError('Descriptor pool required to parse expanded Any field') expanded_any_sub_message = _BuildMessageFromTypeName(packed_type_name, self.descriptor_pool) if not expanded_any_sub_message: @@ -752,7 +793,8 @@ class _Parser(object): any_message = getattr(message, field.name).add() else: any_message = getattr(message, field.name) - any_message.Pack(expanded_any_sub_message) + any_message.Pack(expanded_any_sub_message, + type_url_prefix=type_url_prefix) elif field.label == descriptor.FieldDescriptor.LABEL_REPEATED: if field.is_extension: sub_message = message.Extensions[field].add() @@ -780,6 +822,12 @@ class _Parser(object): else: getattr(message, field.name)[sub_message.key] = sub_message.value + @staticmethod + def _IsProto3Syntax(message): + message_descriptor = message.DESCRIPTOR + return (hasattr(message_descriptor, 'syntax') and + message_descriptor.syntax == 'proto3') + def _MergeScalarField(self, tokenizer, message, field): """Merges a single scalar field into a message. @@ -829,15 +877,20 @@ class _Parser(object): else: getattr(message, field.name).append(value) else: + # Proto3 doesn't represent presence so we can't test if multiple scalars + # have occurred. We have to allow them. + can_check_presence = not self._IsProto3Syntax(message) if field.is_extension: - if not self._allow_multiple_scalars and message.HasExtension(field): + if (not self._allow_multiple_scalars and can_check_presence and + message.HasExtension(field)): raise tokenizer.ParseErrorPreviousToken( 'Message type "%s" should not have multiple "%s" extensions.' % (message.DESCRIPTOR.full_name, field.full_name)) else: message.Extensions[field] = value else: - if not self._allow_multiple_scalars and message.HasField(field.name): + if (not self._allow_multiple_scalars and can_check_presence and + message.HasField(field.name)): raise tokenizer.ParseErrorPreviousToken( 'Message type "%s" should not have multiple "%s" fields.' % (message.DESCRIPTOR.full_name, field.name)) @@ -1088,7 +1141,7 @@ class Tokenizer(object): """ result = self.token if not self._IDENTIFIER_OR_NUMBER.match(result): - raise self.ParseError('Expected identifier or number.') + raise self.ParseError('Expected identifier or number, got %s.' % result) self.NextToken() return result |