diff options
author | Feng Xiao <xfxyjwf@gmail.com> | 2015-08-22 18:25:48 -0700 |
---|---|---|
committer | Feng Xiao <xfxyjwf@gmail.com> | 2015-08-22 18:25:48 -0700 |
commit | eee38b0c018b3279f77d03dff796f440f40d3516 (patch) | |
tree | 7ff0978e30238d493fc7899b75abeb6d66939f07 /python/google/protobuf/internal | |
parent | c3bc155aceda36ecb01cde2367a3b427f2d7ce40 (diff) |
Down-integrate from google3.
Diffstat (limited to 'python/google/protobuf/internal')
-rwxr-xr-x | python/google/protobuf/internal/containers.py | 11 | ||||
-rwxr-xr-x | python/google/protobuf/internal/generator_test.py | 3 | ||||
-rwxr-xr-x | python/google/protobuf/internal/message_test.py | 115 | ||||
-rw-r--r-- | python/google/protobuf/internal/packed_field_test.proto | 73 | ||||
-rwxr-xr-x | python/google/protobuf/internal/python_message.py | 151 | ||||
-rwxr-xr-x | python/google/protobuf/internal/reflection_test.py | 48 | ||||
-rwxr-xr-x | python/google/protobuf/internal/test_util.py | 3 | ||||
-rwxr-xr-x | python/google/protobuf/internal/text_format_test.py | 31 | ||||
-rwxr-xr-x | python/google/protobuf/internal/unknown_fields_test.py | 72 |
9 files changed, 407 insertions, 100 deletions
diff --git a/python/google/protobuf/internal/containers.py b/python/google/protobuf/internal/containers.py index 72c2fa01..9c8275eb 100755 --- a/python/google/protobuf/internal/containers.py +++ b/python/google/protobuf/internal/containers.py @@ -41,6 +41,7 @@ are: __author__ = 'petar@google.com (Petar Petrov)' +import collections import sys if sys.version_info[0] < 3: @@ -63,7 +64,6 @@ if sys.version_info[0] < 3: # Note: deriving from object is critical. It is the only thing that makes # this a true type, allowing us to derive from it in C++ cleanly and making # __slots__ properly disallow arbitrary element assignment. - from collections import Mapping as _Mapping class Mapping(object): __slots__ = () @@ -106,7 +106,7 @@ if sys.version_info[0] < 3: __hash__ = None def __eq__(self, other): - if not isinstance(other, _Mapping): + if not isinstance(other, collections.Mapping): return NotImplemented return dict(self.items()) == dict(other.items()) @@ -173,12 +173,13 @@ if sys.version_info[0] < 3: self[key] = default return default - _Mapping.register(Mapping) + collections.Mapping.register(Mapping) + collections.MutableMapping.register(MutableMapping) else: # In Python 3 we can just use MutableMapping directly, because it defines # __slots__. - from collections import MutableMapping + MutableMapping = collections.MutableMapping class BaseContainer(object): @@ -336,6 +337,8 @@ class RepeatedScalarFieldContainer(BaseContainer): # We are presumably comparing against some other sequence type. 
return other == self._values +collections.MutableSequence.register(BaseContainer) + class RepeatedCompositeFieldContainer(BaseContainer): diff --git a/python/google/protobuf/internal/generator_test.py b/python/google/protobuf/internal/generator_test.py index 5c07cbe6..c30f633d 100755 --- a/python/google/protobuf/internal/generator_test.py +++ b/python/google/protobuf/internal/generator_test.py @@ -47,6 +47,7 @@ from google.protobuf import unittest_custom_options_pb2 from google.protobuf import unittest_import_pb2 from google.protobuf import unittest_import_public_pb2 from google.protobuf import unittest_mset_pb2 +from google.protobuf import unittest_mset_wire_format_pb2 from google.protobuf import unittest_no_generic_services_pb2 from google.protobuf import unittest_pb2 from google.protobuf import service @@ -142,7 +143,7 @@ class GeneratorTest(unittest.TestCase): self.assertTrue(not non_extension_descriptor.is_extension) def testOptions(self): - proto = unittest_mset_pb2.TestMessageSet() + proto = unittest_mset_wire_format_pb2.TestMessageSet() self.assertTrue(proto.DESCRIPTOR.GetOptions().message_set_wire_format) def testMessageWithCustomOptions(self): diff --git a/python/google/protobuf/internal/message_test.py b/python/google/protobuf/internal/message_test.py index 320ff0d2..62abf1be 100755 --- a/python/google/protobuf/internal/message_test.py +++ b/python/google/protobuf/internal/message_test.py @@ -43,6 +43,7 @@ abstract interface. __author__ = 'gps@google.com (Gregory P. 
Smith)' +import collections import copy import math import operator @@ -56,6 +57,7 @@ from google.protobuf import map_unittest_pb2 from google.protobuf import unittest_pb2 from google.protobuf import unittest_proto3_arena_pb2 from google.protobuf.internal import api_implementation +from google.protobuf.internal import packed_field_test_pb2 from google.protobuf.internal import test_util from google.protobuf import message @@ -421,6 +423,31 @@ class MessageTest(unittest.TestCase): self.assertEqual(message.repeated_nested_message[4].bb, 5) self.assertEqual(message.repeated_nested_message[5].bb, 6) + def testSortingRepeatedCompositeFieldsStable(self, message_module): + """Check passing a custom comparator to sort a repeated composite field.""" + message = message_module.TestAllTypes() + + message.repeated_nested_message.add().bb = 21 + message.repeated_nested_message.add().bb = 20 + message.repeated_nested_message.add().bb = 13 + message.repeated_nested_message.add().bb = 33 + message.repeated_nested_message.add().bb = 11 + message.repeated_nested_message.add().bb = 24 + message.repeated_nested_message.add().bb = 10 + message.repeated_nested_message.sort(key=lambda z: z.bb // 10) + self.assertEquals( + [13, 11, 10, 21, 20, 24, 33], + [n.bb for n in message.repeated_nested_message]) + + # Make sure that for the C++ implementation, the underlying fields + # are actually reordered. 
+ pb = message.SerializeToString() + message.Clear() + message.MergeFromString(pb) + self.assertEquals( + [13, 11, 10, 21, 20, 24, 33], + [n.bb for n in message.repeated_nested_message]) + def testRepeatedCompositeFieldSortArguments(self, message_module): """Check sorting a repeated composite field using list.sort() arguments.""" message = message_module.TestAllTypes() @@ -514,6 +541,12 @@ class MessageTest(unittest.TestCase): # TODO(anuraag): Implement extensiondict comparison in C++ and then add test + def testRepeatedFieldsAreSequences(self, message_module): + m = message_module.TestAllTypes() + self.assertIsInstance(m.repeated_int32, collections.MutableSequence) + self.assertIsInstance(m.repeated_nested_message, + collections.MutableSequence) + def ensureNestedMessageExists(self, msg, attribute): """Make sure that a nested message object exists. @@ -556,6 +589,18 @@ class MessageTest(unittest.TestCase): self.assertFalse(m.HasField('oneof_uint32')) self.assertTrue(m.HasField('oneof_string')) + # Read nested message accessor without accessing submessage. + m.oneof_nested_message + self.assertEqual('oneof_string', m.WhichOneof('oneof_field')) + self.assertTrue(m.HasField('oneof_string')) + self.assertFalse(m.HasField('oneof_nested_message')) + + # Read accessor of nested message without accessing submessage. 
+ m.oneof_nested_message.bb + self.assertEqual('oneof_string', m.WhichOneof('oneof_field')) + self.assertTrue(m.HasField('oneof_string')) + self.assertFalse(m.HasField('oneof_nested_message')) + m.oneof_nested_message.bb = 11 self.assertEqual('oneof_nested_message', m.WhichOneof('oneof_field')) self.assertFalse(m.HasField('oneof_string')) @@ -1583,6 +1628,21 @@ class Proto3Test(unittest.TestCase): del msg.map_int32_int32[4] self.assertEqual(0, len(msg.map_int32_int32)) + def testMapsAreMapping(self): + msg = map_unittest_pb2.TestMap() + self.assertIsInstance(msg.map_int32_int32, collections.Mapping) + self.assertIsInstance(msg.map_int32_int32, collections.MutableMapping) + self.assertIsInstance(msg.map_int32_foreign_message, collections.Mapping) + self.assertIsInstance(msg.map_int32_foreign_message, + collections.MutableMapping) + + def testMapFindInitializationErrorsSmokeTest(self): + msg = map_unittest_pb2.TestMap() + msg.map_string_string['abc'] = '123' + msg.map_int32_int32[35] = 64 + msg.map_string_foreign_message['foo'].c = 5 + self.assertEqual(0, len(msg.FindInitializationErrors())) + class ValidTypeNamesTest(unittest.TestCase): @@ -1606,6 +1666,61 @@ class ValidTypeNamesTest(unittest.TestCase): self.assertImportFromName(pb.repeated_int32, 'Scalar') self.assertImportFromName(pb.repeated_nested_message, 'Composite') +class PackedFieldTest(unittest.TestCase): + + def setMessage(self, message): + message.repeated_int32.append(1) + message.repeated_int64.append(1) + message.repeated_uint32.append(1) + message.repeated_uint64.append(1) + message.repeated_sint32.append(1) + message.repeated_sint64.append(1) + message.repeated_fixed32.append(1) + message.repeated_fixed64.append(1) + message.repeated_sfixed32.append(1) + message.repeated_sfixed64.append(1) + message.repeated_float.append(1.0) + message.repeated_double.append(1.0) + message.repeated_bool.append(True) + message.repeated_nested_enum.append(1) + + def testPackedFields(self): + message = 
packed_field_test_pb2.TestPackedTypes() + self.setMessage(message) + golden_data = (b'\x0A\x01\x01' + b'\x12\x01\x01' + b'\x1A\x01\x01' + b'\x22\x01\x01' + b'\x2A\x01\x02' + b'\x32\x01\x02' + b'\x3A\x04\x01\x00\x00\x00' + b'\x42\x08\x01\x00\x00\x00\x00\x00\x00\x00' + b'\x4A\x04\x01\x00\x00\x00' + b'\x52\x08\x01\x00\x00\x00\x00\x00\x00\x00' + b'\x5A\x04\x00\x00\x80\x3f' + b'\x62\x08\x00\x00\x00\x00\x00\x00\xf0\x3f' + b'\x6A\x01\x01' + b'\x72\x01\x01') + self.assertEqual(golden_data, message.SerializeToString()) + + def testUnpackedFields(self): + message = packed_field_test_pb2.TestUnpackedTypes() + self.setMessage(message) + golden_data = (b'\x08\x01' + b'\x10\x01' + b'\x18\x01' + b'\x20\x01' + b'\x28\x02' + b'\x30\x02' + b'\x3D\x01\x00\x00\x00' + b'\x41\x01\x00\x00\x00\x00\x00\x00\x00' + b'\x4D\x01\x00\x00\x00' + b'\x51\x01\x00\x00\x00\x00\x00\x00\x00' + b'\x5D\x00\x00\x80\x3f' + b'\x61\x00\x00\x00\x00\x00\x00\xf0\x3f' + b'\x68\x01' + b'\x70\x01') + self.assertEqual(golden_data, message.SerializeToString()) if __name__ == '__main__': unittest.main() diff --git a/python/google/protobuf/internal/packed_field_test.proto b/python/google/protobuf/internal/packed_field_test.proto index e69de29b..0dfdc10a 100644 --- a/python/google/protobuf/internal/packed_field_test.proto +++ b/python/google/protobuf/internal/packed_field_test.proto @@ -0,0 +1,73 @@ +// Protocol Buffers - Google's data interchange format +// Copyright 2008 Google Inc. All rights reserved. +// https://developers.google.com/protocol-buffers/ +// +// Redistribution and use in source and binary forms, with or without +// modification, are permitted provided that the following conditions are +// met: +// +// * Redistributions of source code must retain the above copyright +// notice, this list of conditions and the following disclaimer. 
+// * Redistributions in binary form must reproduce the above +// copyright notice, this list of conditions and the following disclaimer +// in the documentation and/or other materials provided with the +// distribution. +// * Neither the name of Google Inc. nor the names of its +// contributors may be used to endorse or promote products derived from +// this software without specific prior written permission. +// +// THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS +// "AS IS" AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT +// LIMITED TO, THE IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR +// A PARTICULAR PURPOSE ARE DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT +// OWNER OR CONTRIBUTORS BE LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, +// SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT +// LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, +// DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY +// THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT +// (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE +// OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. 
+ +syntax = "proto3"; + +package google.protobuf.python.internal; + +message TestPackedTypes { + enum NestedEnum { + FOO = 0; + BAR = 1; + BAZ = 2; + } + + repeated int32 repeated_int32 = 1; + repeated int64 repeated_int64 = 2; + repeated uint32 repeated_uint32 = 3; + repeated uint64 repeated_uint64 = 4; + repeated sint32 repeated_sint32 = 5; + repeated sint64 repeated_sint64 = 6; + repeated fixed32 repeated_fixed32 = 7; + repeated fixed64 repeated_fixed64 = 8; + repeated sfixed32 repeated_sfixed32 = 9; + repeated sfixed64 repeated_sfixed64 = 10; + repeated float repeated_float = 11; + repeated double repeated_double = 12; + repeated bool repeated_bool = 13; + repeated NestedEnum repeated_nested_enum = 14; +} + +message TestUnpackedTypes { + repeated int32 repeated_int32 = 1 [packed = false]; + repeated int64 repeated_int64 = 2 [packed = false]; + repeated uint32 repeated_uint32 = 3 [packed = false]; + repeated uint64 repeated_uint64 = 4 [packed = false]; + repeated sint32 repeated_sint32 = 5 [packed = false]; + repeated sint64 repeated_sint64 = 6 [packed = false]; + repeated fixed32 repeated_fixed32 = 7 [packed = false]; + repeated fixed64 repeated_fixed64 = 8 [packed = false]; + repeated sfixed32 repeated_sfixed32 = 9 [packed = false]; + repeated sfixed64 repeated_sfixed64 = 10 [packed = false]; + repeated float repeated_float = 11 [packed = false]; + repeated double repeated_double = 12 [packed = false]; + repeated bool repeated_bool = 13 [packed = false]; + repeated TestPackedTypes.NestedEnum repeated_nested_enum = 14 [packed = false]; +} diff --git a/python/google/protobuf/internal/python_message.py b/python/google/protobuf/internal/python_message.py index ca9f7675..a3e98467 100755 --- a/python/google/protobuf/internal/python_message.py +++ b/python/google/protobuf/internal/python_message.py @@ -85,34 +85,108 @@ from google.protobuf import text_format _FieldDescriptor = descriptor_mod.FieldDescriptor -def NewMessage(bases, descriptor, dictionary): - 
_AddClassAttributesForNestedExtensions(descriptor, dictionary) - _AddSlots(descriptor, dictionary) - return bases - - -def InitMessage(descriptor, cls): - cls._decoders_by_tag = {} - cls._extensions_by_name = {} - cls._extensions_by_number = {} - if (descriptor.has_options and - descriptor.GetOptions().message_set_wire_format): - cls._decoders_by_tag[decoder.MESSAGE_SET_ITEM_TAG] = ( - decoder.MessageSetItemDecoder(cls._extensions_by_number), None) - - # Attach stuff to each FieldDescriptor for quick lookup later on. - for field in descriptor.fields: - _AttachFieldHelpers(cls, field) +class GeneratedProtocolMessageType(type): + + """Metaclass for protocol message classes created at runtime from Descriptors. + + We add implementations for all methods described in the Message class. We + also create properties to allow getting/setting all fields in the protocol + message. Finally, we create slots to prevent users from accidentally + "setting" nonexistent fields in the protocol message, which then wouldn't get + serialized / deserialized properly. + + The protocol compiler currently uses this metaclass to create protocol + message classes at runtime. Clients can also manually create their own + classes at runtime, as in this example: + + mydescriptor = Descriptor(.....) + class MyProtoClass(Message): + __metaclass__ = GeneratedProtocolMessageType + DESCRIPTOR = mydescriptor + myproto_instance = MyProtoClass() + myproto.foo_field = 23 + ... + + The above example will not work for nested types. If you wish to include them, + use reflection.MakeClass() instead of manually instantiating the class in + order to create the appropriate class structure. + """ + + # Must be consistent with the protocol-compiler code in + # proto2/compiler/internal/generator.*. + _DESCRIPTOR_KEY = 'DESCRIPTOR' + + def __new__(cls, name, bases, dictionary): + """Custom allocation for runtime-generated class types. 
+ + We override __new__ because this is apparently the only place + where we can meaningfully set __slots__ on the class we're creating(?). + (The interplay between metaclasses and slots is not very well-documented). + + Args: + name: Name of the class (ignored, but required by the + metaclass protocol). + bases: Base classes of the class we're constructing. + (Should be message.Message). We ignore this field, but + it's required by the metaclass protocol + dictionary: The class dictionary of the class we're + constructing. dictionary[_DESCRIPTOR_KEY] must contain + a Descriptor object describing this protocol message + type. + + Returns: + Newly-allocated class. + """ + descriptor = dictionary[GeneratedProtocolMessageType._DESCRIPTOR_KEY] + _AddClassAttributesForNestedExtensions(descriptor, dictionary) + _AddSlots(descriptor, dictionary) + + superclass = super(GeneratedProtocolMessageType, cls) + new_class = superclass.__new__(cls, name, bases, dictionary) + return new_class + + def __init__(cls, name, bases, dictionary): + """Here we perform the majority of our work on the class. + We add enum getters, an __init__ method, implementations + of all Message methods, and properties for all fields + in the protocol type. - descriptor._concrete_class = cls # pylint: disable=protected-access - _AddEnumValues(descriptor, cls) - _AddInitMethod(descriptor, cls) - _AddPropertiesForFields(descriptor, cls) - _AddPropertiesForExtensions(descriptor, cls) - _AddStaticMethods(cls) - _AddMessageMethods(descriptor, cls) - _AddPrivateHelperMethods(descriptor, cls) - copyreg.pickle(cls, lambda obj: (cls, (), obj.__getstate__())) + Args: + name: Name of the class (ignored, but required by the + metaclass protocol). + bases: Base classes of the class we're constructing. + (Should be message.Message). We ignore this field, but + it's required by the metaclass protocol + dictionary: The class dictionary of the class we're + constructing. 
dictionary[_DESCRIPTOR_KEY] must contain + a Descriptor object describing this protocol message + type. + """ + descriptor = dictionary[GeneratedProtocolMessageType._DESCRIPTOR_KEY] + cls._decoders_by_tag = {} + cls._extensions_by_name = {} + cls._extensions_by_number = {} + if (descriptor.has_options and + descriptor.GetOptions().message_set_wire_format): + cls._decoders_by_tag[decoder.MESSAGE_SET_ITEM_TAG] = ( + decoder.MessageSetItemDecoder(cls._extensions_by_number), None) + + # Attach stuff to each FieldDescriptor for quick lookup later on. + for field in descriptor.fields: + _AttachFieldHelpers(cls, field) + + descriptor._concrete_class = cls # pylint: disable=protected-access + _AddEnumValues(descriptor, cls) + _AddInitMethod(descriptor, cls) + _AddPropertiesForFields(descriptor, cls) + _AddPropertiesForExtensions(descriptor, cls) + _AddStaticMethods(cls) + _AddMessageMethods(descriptor, cls) + _AddPrivateHelperMethods(descriptor, cls) + copyreg.pickle(cls, lambda obj: (cls, (), obj.__getstate__())) + + superclass = super(GeneratedProtocolMessageType, cls) + superclass.__init__(name, bases, dictionary) # Stateless helpers for GeneratedProtocolMessageType below. @@ -362,9 +436,10 @@ def _DefaultValueConstructorForField(field): message_type = field.message_type def MakeSubMessageDefault(message): result = message_type._concrete_class() - result._SetListener(message._listener_for_children) - if field.containing_oneof: - message._UpdateOneofState(field) + result._SetListener( + _OneofListener(message, field) + if field.containing_oneof is not None + else message._listener_for_children) return result return MakeSubMessageDefault @@ -634,21 +709,11 @@ def _AddPropertiesForNonRepeatedCompositeField(field, cls): proto_field_name = field.name property_name = _PropertyName(proto_field_name) - # TODO(komarek): Can anyone explain to me why we cache the message_type this - # way, instead of referring to field.message_type inside of getter(self)? 
- # What if someone sets message_type later on (which makes for simpler - # dyanmic proto descriptor and class creation code). - message_type = field.message_type - def getter(self): field_value = self._fields.get(field) if field_value is None: # Construct a new object to represent this field. - field_value = message_type._concrete_class() # use field.message_type? - field_value._SetListener( - _OneofListener(self, field) - if field.containing_oneof is not None - else self._listener_for_children) + field_value = field._default_constructor(self) # Atomically check if another thread has preempted us and, if not, swap # in the new object we just created. If someone has preempted us, we @@ -1121,7 +1186,7 @@ def _AddIsInitializedMethod(message_descriptor, cls): if _IsMessageMapField(field): for key in value: element = value[key] - prefix = "%s[%d]." % (name, key) + prefix = "%s[%s]." % (name, key) sub_errors = element.FindInitializationErrors() errors += [prefix + error for error in sub_errors] else: @@ -1173,8 +1238,6 @@ def _AddMergeFromMethod(cls): # Construct a new object to represent this field. 
field_value = field._default_constructor(self) fields[field] = field_value - if field.containing_oneof: - self._UpdateOneofState(field) field_value.MergeFrom(value) else: self._fields[field] = value diff --git a/python/google/protobuf/internal/reflection_test.py b/python/google/protobuf/internal/reflection_test.py index 4eca4989..ef1ced4e 100755 --- a/python/google/protobuf/internal/reflection_test.py +++ b/python/google/protobuf/internal/reflection_test.py @@ -52,6 +52,7 @@ from google.protobuf import text_format from google.protobuf.internal import api_implementation from google.protobuf.internal import more_extensions_pb2 from google.protobuf.internal import more_messages_pb2 +from google.protobuf.internal import message_set_extensions_pb2 from google.protobuf.internal import wire_format from google.protobuf.internal import test_util from google.protobuf.internal import decoder @@ -1682,8 +1683,8 @@ class ReflectionTest(unittest.TestCase): proto.optional_string = 'abc' def testStringUTF8Serialization(self): - proto = unittest_mset_pb2.TestMessageSet() - extension_message = unittest_mset_pb2.TestMessageSetExtension2 + proto = message_set_extensions_pb2.TestMessageSet() + extension_message = message_set_extensions_pb2.TestMessageSetExtension2 extension = extension_message.message_set_extension test_utf8 = u'Тест' @@ -1703,15 +1704,14 @@ class ReflectionTest(unittest.TestCase): bytes_read = raw.MergeFromString(serialized) self.assertEqual(len(serialized), bytes_read) - message2 = unittest_mset_pb2.TestMessageSetExtension2() + message2 = message_set_extensions_pb2.TestMessageSetExtension2() self.assertEqual(1, len(raw.item)) # Check that the type_id is the same as the tag ID in the .proto file. - self.assertEqual(raw.item[0].type_id, 1547769) + self.assertEqual(raw.item[0].type_id, 98418634) # Check the actual bytes on the wire. 
- self.assertTrue( - raw.item[0].message.endswith(test_utf8_bytes)) + self.assertTrue(raw.item[0].message.endswith(test_utf8_bytes)) bytes_read = message2.MergeFromString(raw.item[0].message) self.assertEqual(len(raw.item[0].message), bytes_read) @@ -2395,9 +2395,9 @@ class SerializationTest(unittest.TestCase): self.assertEqual(42, second_proto.optional_nested_message.bb) def testMessageSetWireFormat(self): - proto = unittest_mset_pb2.TestMessageSet() - extension_message1 = unittest_mset_pb2.TestMessageSetExtension1 - extension_message2 = unittest_mset_pb2.TestMessageSetExtension2 + proto = message_set_extensions_pb2.TestMessageSet() + extension_message1 = message_set_extensions_pb2.TestMessageSetExtension1 + extension_message2 = message_set_extensions_pb2.TestMessageSetExtension2 extension1 = extension_message1.message_set_extension extension2 = extension_message2.message_set_extension proto.Extensions[extension1].i = 123 @@ -2415,20 +2415,20 @@ class SerializationTest(unittest.TestCase): raw.MergeFromString(serialized)) self.assertEqual(2, len(raw.item)) - message1 = unittest_mset_pb2.TestMessageSetExtension1() + message1 = message_set_extensions_pb2.TestMessageSetExtension1() self.assertEqual( len(raw.item[0].message), message1.MergeFromString(raw.item[0].message)) self.assertEqual(123, message1.i) - message2 = unittest_mset_pb2.TestMessageSetExtension2() + message2 = message_set_extensions_pb2.TestMessageSetExtension2() self.assertEqual( len(raw.item[1].message), message2.MergeFromString(raw.item[1].message)) self.assertEqual('foo', message2.str) # Deserialize using the MessageSet wire format. - proto2 = unittest_mset_pb2.TestMessageSet() + proto2 = message_set_extensions_pb2.TestMessageSet() self.assertEqual( len(serialized), proto2.MergeFromString(serialized)) @@ -2446,37 +2446,37 @@ class SerializationTest(unittest.TestCase): # Add an item. 
item = raw.item.add() - item.type_id = 1545008 - extension_message1 = unittest_mset_pb2.TestMessageSetExtension1 - message1 = unittest_mset_pb2.TestMessageSetExtension1() + item.type_id = 98418603 + extension_message1 = message_set_extensions_pb2.TestMessageSetExtension1 + message1 = message_set_extensions_pb2.TestMessageSetExtension1() message1.i = 12345 item.message = message1.SerializeToString() # Add a second, unknown extension. item = raw.item.add() - item.type_id = 1545009 - extension_message1 = unittest_mset_pb2.TestMessageSetExtension1 - message1 = unittest_mset_pb2.TestMessageSetExtension1() + item.type_id = 98418604 + extension_message1 = message_set_extensions_pb2.TestMessageSetExtension1 + message1 = message_set_extensions_pb2.TestMessageSetExtension1() message1.i = 12346 item.message = message1.SerializeToString() # Add another unknown extension. item = raw.item.add() - item.type_id = 1545010 - message1 = unittest_mset_pb2.TestMessageSetExtension2() + item.type_id = 98418605 + message1 = message_set_extensions_pb2.TestMessageSetExtension2() message1.str = 'foo' item.message = message1.SerializeToString() serialized = raw.SerializeToString() # Parse message using the message set wire format. - proto = unittest_mset_pb2.TestMessageSet() + proto = message_set_extensions_pb2.TestMessageSet() self.assertEqual( len(serialized), proto.MergeFromString(serialized)) # Check that the message parsed well. 
- extension_message1 = unittest_mset_pb2.TestMessageSetExtension1 + extension_message1 = message_set_extensions_pb2.TestMessageSetExtension1 extension1 = extension_message1.message_set_extension self.assertEquals(12345, proto.Extensions[extension1].i) @@ -2805,7 +2805,7 @@ class SerializationTest(unittest.TestCase): class OptionsTest(unittest.TestCase): def testMessageOptions(self): - proto = unittest_mset_pb2.TestMessageSet() + proto = message_set_extensions_pb2.TestMessageSet() self.assertEqual(True, proto.DESCRIPTOR.GetOptions().message_set_wire_format) proto = unittest_pb2.TestAllTypes() @@ -2824,7 +2824,7 @@ class OptionsTest(unittest.TestCase): proto.packed_double.append(3.0) for field_descriptor, _ in proto.ListFields(): self.assertEqual(True, field_descriptor.GetOptions().packed) - self.assertEqual(reflection._FieldDescriptor.LABEL_REPEATED, + self.assertEqual(descriptor.FieldDescriptor.LABEL_REPEATED, field_descriptor.label) diff --git a/python/google/protobuf/internal/test_util.py b/python/google/protobuf/internal/test_util.py index fec65382..ac88fa81 100755 --- a/python/google/protobuf/internal/test_util.py +++ b/python/google/protobuf/internal/test_util.py @@ -604,7 +604,8 @@ def GoldenFile(filename): # Search internally. path = '.' - full_path = os.path.join(path, 'third_party/py/google/protobuf/testdata', filename) + full_path = os.path.join(path, 'third_party/py/google/protobuf/testdata', + filename) if os.path.exists(full_path): # Found it. Load the golden file from the testdata directory. 
return open(full_path, 'rb') diff --git a/python/google/protobuf/internal/text_format_test.py b/python/google/protobuf/internal/text_format_test.py index 06bd1ee5..00e67654 100755 --- a/python/google/protobuf/internal/text_format_test.py +++ b/python/google/protobuf/internal/text_format_test.py @@ -35,6 +35,7 @@ __author__ = 'kenton@google.com (Kenton Varda)' import re +import string import unittest import unittest @@ -497,6 +498,36 @@ class OnlyWorksWithProto2RightNowTests(TextFormatBase): ' }\n' '}\n') + def testMapOrderEnforcement(self): + message = map_unittest_pb2.TestMap() + for letter in string.ascii_uppercase[13:26]: + message.map_string_string[letter] = 'dummy' + for letter in reversed(string.ascii_uppercase[0:13]): + message.map_string_string[letter] = 'dummy' + golden = ''.join(( + 'map_string_string {\n key: "%c"\n value: "dummy"\n}\n' % (letter,) + for letter in string.ascii_uppercase)) + self.CompareToGoldenText(text_format.MessageToString(message), golden) + + def testMapOrderSemantics(self): + golden_lines = self.ReadGolden('map_test_data.txt') + # The C++ implementation emits defaulted-value fields, while the Python + # implementation does not. Adjusting for this is awkward, but it is + # valuable to test against a common golden file. + line_blacklist = (' key: 0\n', + ' value: 0\n', + ' key: false\n', + ' value: false\n') + golden_lines = [line for line in golden_lines if line not in line_blacklist] + + message = map_unittest_pb2.TestMap() + text_format.ParseLines(golden_lines, message) + candidate = text_format.MessageToString(message) + # The Python implementation emits "1.0" for the double value that the C++ + # implementation emits as "1". + candidate = candidate.replace('1.0', '1', 2) + self.assertMultiLineEqual(candidate, ''.join(golden_lines)) + # Tests of proto2-only features (MessageSet, extensions, etc.). 
class Proto2Tests(TextFormatBase): diff --git a/python/google/protobuf/internal/unknown_fields_test.py b/python/google/protobuf/internal/unknown_fields_test.py index 1b81ae79..0dda805b 100755 --- a/python/google/protobuf/internal/unknown_fields_test.py +++ b/python/google/protobuf/internal/unknown_fields_test.py @@ -41,11 +41,18 @@ from google.protobuf import unittest_pb2 from google.protobuf import unittest_proto3_arena_pb2 from google.protobuf.internal import api_implementation from google.protobuf.internal import encoder +from google.protobuf.internal import message_set_extensions_pb2 from google.protobuf.internal import missing_enum_values_pb2 from google.protobuf.internal import test_util from google.protobuf.internal import type_checkers +def SkipIfCppImplementation(func): + return unittest.skipIf( + api_implementation.Type() == 'cpp' and api_implementation.Version() == 2, + 'C++ implementation does not expose unknown fields to Python')(func) + + class UnknownFieldsTest(unittest.TestCase): def setUp(self): @@ -83,15 +90,15 @@ class UnknownFieldsTest(unittest.TestCase): # Add an unknown extension. item = raw.item.add() - item.type_id = 1545009 - message1 = unittest_mset_pb2.TestMessageSetExtension1() + item.type_id = 98418603 + message1 = message_set_extensions_pb2.TestMessageSetExtension1() message1.i = 12345 item.message = message1.SerializeToString() serialized = raw.SerializeToString() # Parse message using the message set wire format. - proto = unittest_mset_pb2.TestMessageSet() + proto = message_set_extensions_pb2.TestMessageSet() proto.MergeFromString(serialized) # Verify that the unknown extension is serialized unchanged @@ -100,13 +107,6 @@ class UnknownFieldsTest(unittest.TestCase): new_raw.MergeFromString(reserialized) self.assertEqual(raw, new_raw) - # C++ implementation for proto2 does not currently take into account unknown - # fields when checking equality. - # - # TODO(haberman): fix this. 
- @unittest.skipIf( - api_implementation.Type() == 'cpp' and api_implementation.Version() == 2, - 'C++ implementation does not expose unknown fields to Python') def testEquals(self): message = unittest_pb2.TestEmptyMessage() message.ParseFromString(self.all_fields_data) @@ -117,9 +117,6 @@ class UnknownFieldsTest(unittest.TestCase): self.assertNotEqual(self.empty_message, message) -@unittest.skipIf( - api_implementation.Type() == 'cpp' and api_implementation.Version() == 2, - 'C++ implementation does not expose unknown fields to Python') class UnknownFieldsAccessorsTest(unittest.TestCase): def setUp(self): @@ -129,7 +126,14 @@ class UnknownFieldsAccessorsTest(unittest.TestCase): self.all_fields_data = self.all_fields.SerializeToString() self.empty_message = unittest_pb2.TestEmptyMessage() self.empty_message.ParseFromString(self.all_fields_data) - self.unknown_fields = self.empty_message._unknown_fields + if api_implementation.Type() != 'cpp': + # _unknown_fields is an implementation detail. + self.unknown_fields = self.empty_message._unknown_fields + + # All the tests that use GetField() check an implementation detail of the + # Python implementation, which stores unknown fields as serialized strings. + # These tests are skipped by the C++ implementation: it's enough to check that + # the message is correctly serialized. 
def GetField(self, name): field_descriptor = self.descriptor.fields_by_name[name] @@ -142,30 +146,37 @@ class UnknownFieldsAccessorsTest(unittest.TestCase): decoder(value, 0, len(value), self.all_fields, result_dict) return result_dict[field_descriptor] + @SkipIfCppImplementation def testEnum(self): value = self.GetField('optional_nested_enum') self.assertEqual(self.all_fields.optional_nested_enum, value) + @SkipIfCppImplementation def testRepeatedEnum(self): value = self.GetField('repeated_nested_enum') self.assertEqual(self.all_fields.repeated_nested_enum, value) + @SkipIfCppImplementation def testVarint(self): value = self.GetField('optional_int32') self.assertEqual(self.all_fields.optional_int32, value) + @SkipIfCppImplementation def testFixed32(self): value = self.GetField('optional_fixed32') self.assertEqual(self.all_fields.optional_fixed32, value) + @SkipIfCppImplementation def testFixed64(self): value = self.GetField('optional_fixed64') self.assertEqual(self.all_fields.optional_fixed64, value) + @SkipIfCppImplementation def testLengthDelimited(self): value = self.GetField('optional_string') self.assertEqual(self.all_fields.optional_string, value) + @SkipIfCppImplementation def testGroup(self): value = self.GetField('optionalgroup') self.assertEqual(self.all_fields.optionalgroup, value) @@ -173,7 +184,7 @@ class UnknownFieldsAccessorsTest(unittest.TestCase): def testCopyFrom(self): message = unittest_pb2.TestEmptyMessage() message.CopyFrom(self.empty_message) - self.assertEqual(self.unknown_fields, message._unknown_fields) + self.assertEqual(message.SerializeToString(), self.all_fields_data) def testMergeFrom(self): message = unittest_pb2.TestAllTypes() @@ -187,27 +198,26 @@ class UnknownFieldsAccessorsTest(unittest.TestCase): message.optional_uint32 = 4 destination = unittest_pb2.TestEmptyMessage() destination.ParseFromString(message.SerializeToString()) - unknown_fields = destination._unknown_fields[:] destination.MergeFrom(source) - 
- self.assertEqual(unknown_fields + source._unknown_fields, - destination._unknown_fields) + # Check that the fields were correctly merged, even when stored in the + # unknown fields set. + message.ParseFromString(destination.SerializeToString()) + self.assertEqual(message.optional_int32, 1) + self.assertEqual(message.optional_uint32, 2) + self.assertEqual(message.optional_int64, 3) def testClear(self): self.empty_message.Clear() - self.assertEqual(0, len(self.empty_message._unknown_fields)) + # All cleared, even unknown fields. + self.assertEqual(self.empty_message.SerializeToString(), b'') def testUnknownExtensions(self): message = unittest_pb2.TestEmptyMessageWithExtensions() message.ParseFromString(self.all_fields_data) - self.assertEqual(self.empty_message._unknown_fields, - message._unknown_fields) - + self.assertEqual(message.SerializeToString(), self.all_fields_data) -@unittest.skipIf( - api_implementation.Type() == 'cpp' and api_implementation.Version() == 2, - 'C++ implementation does not expose unknown fields to Python') class UnknownEnumValuesTest(unittest.TestCase): def setUp(self): @@ -227,7 +237,14 @@ class UnknownEnumValuesTest(unittest.TestCase): self.message_data = self.message.SerializeToString() self.missing_message = missing_enum_values_pb2.TestMissingEnumValues() self.missing_message.ParseFromString(self.message_data) - self.unknown_fields = self.missing_message._unknown_fields + if api_implementation.Type() != 'cpp': + # _unknown_fields is an implementation detail. + self.unknown_fields = self.missing_message._unknown_fields + + # All the tests that use GetField() check an implementation detail of the + # Python implementation, which stores unknown fields as serialized strings. + # These tests are skipped by the C++ implementation: it's enough to check that + # the message is correctly serialized. 
def GetField(self, name): field_descriptor = self.descriptor.fields_by_name[name] @@ -241,15 +258,18 @@ class UnknownEnumValuesTest(unittest.TestCase): decoder(value, 0, len(value), self.message, result_dict) return result_dict[field_descriptor] + @SkipIfCppImplementation def testUnknownEnumValue(self): self.assertFalse(self.missing_message.HasField('optional_nested_enum')) value = self.GetField('optional_nested_enum') self.assertEqual(self.message.optional_nested_enum, value) + @SkipIfCppImplementation def testUnknownRepeatedEnumValue(self): value = self.GetField('repeated_nested_enum') self.assertEqual(self.message.repeated_nested_enum, value) + @SkipIfCppImplementation def testUnknownPackedEnumValue(self): value = self.GetField('packed_nested_enum') self.assertEqual(self.message.packed_nested_enum, value) |