diff options
author | 2012-09-22 02:40:50 +0000 | |
---|---|---|
committer | 2012-09-22 02:40:50 +0000 | |
commit | b55a20fa2c669b181f47ea9219b8e74d1263da19 (patch) | |
tree | 3936a0e7c22196587a6d8397372de41434fe2129 /python/google/protobuf/internal | |
parent | 9ced30caf94bb4e7e9629c199679ff44e8ca7389 (diff) |
Down-integrate from internal branch
Diffstat (limited to 'python/google/protobuf/internal')
22 files changed, 2003 insertions, 94 deletions
diff --git a/python/google/protobuf/internal/api_implementation.py b/python/google/protobuf/internal/api_implementation.py index b3e412e2..ce02a329 100755 --- a/python/google/protobuf/internal/api_implementation.py +++ b/python/google/protobuf/internal/api_implementation.py @@ -56,9 +56,32 @@ if _implementation_type != 'python': # _implementation_type = 'python' +# This environment variable can be used to switch between the two +# 'cpp' implementations. Right now only 1 and 2 are valid values. Any +# other value will be ignored. +_implementation_version_str = os.getenv( + 'PROTOCOL_BUFFERS_PYTHON_IMPLEMENTATION_VERSION', + '1') + + +if _implementation_version_str not in ('1', '2'): + raise ValueError( + "unsupported PROTOCOL_BUFFERS_PYTHON_IMPLEMENTATION_VERSION: '" + + _implementation_version_str + "' (supported versions: 1, 2)" + ) + + +_implementation_version = int(_implementation_version_str) + + + # Usage of this function is discouraged. Clients shouldn't care which # implementation of the API is in use. Note that there is no guarantee # that differences between APIs will be maintained. # Please don't use this function if possible. def Type(): return _implementation_type + +# See comment on 'Type' above. +def Version(): + return _implementation_version diff --git a/python/google/protobuf/internal/containers.py b/python/google/protobuf/internal/containers.py index 097a3c26..34b35f8a 100755 --- a/python/google/protobuf/internal/containers.py +++ b/python/google/protobuf/internal/containers.py @@ -78,8 +78,13 @@ class BaseContainer(object): def __repr__(self): return repr(self._values) - def sort(self, sort_function=cmp): - self._values.sort(sort_function) + def sort(self, *args, **kwargs): + # Continue to support the old sort_function keyword argument. + # This is expected to be a rare occurrence, so use LBYL to avoid + # the overhead of actually catching KeyError. + if 'sort_function' in kwargs: + kwargs['cmp'] = kwargs.pop('sort_function') + self._values.sort(*args, **kwargs) class RepeatedScalarFieldContainer(BaseContainer): @@ -235,6 +240,11 @@ class RepeatedCompositeFieldContainer(BaseContainer): """ self.extend(other._values) + def remove(self, elem): + """Removes an item from the list. Similar to list.remove().""" + self._values.remove(elem) + self._message_listener.Modified() + def __getslice__(self, start, stop): """Retrieves the subset of items from between the specified indices.""" return self._values[start:stop] diff --git a/python/google/protobuf/internal/cpp_message.py b/python/google/protobuf/internal/cpp_message.py index 3f426502..23ab9ba4 100755 --- a/python/google/protobuf/internal/cpp_message.py +++ b/python/google/protobuf/internal/cpp_message.py @@ -34,8 +34,10 @@ Descriptor objects at runtime backed by the protocol buffer C++ API. __author__ = 'petar@google.com (Petar Petrov)' +import copy_reg import operator from google.protobuf.internal import _net_proto2___python +from google.protobuf.internal import enum_type_wrapper from google.protobuf import message @@ -156,10 +158,12 @@ class RepeatedScalarContainer(object): def __hash__(self): raise TypeError('unhashable object') - def sort(self, sort_function=cmp): - values = self[slice(None, None, None)] - values.sort(sort_function) - self._cmsg.AssignRepeatedScalar(self._cfield_descriptor, values) + def sort(self, *args, **kwargs): + # Maintain compatibility with the previous interface. + if 'sort_function' in kwargs: + kwargs['cmp'] = kwargs.pop('sort_function') + self._cmsg.AssignRepeatedScalar(self._cfield_descriptor, + sorted(self, *args, **kwargs)) def RepeatedScalarProperty(cdescriptor): @@ -202,6 +206,12 @@ class RepeatedCompositeContainer(object): for message in elem_seq: self.add().MergeFrom(message) + def remove(self, value): + # TODO(protocol-devel): This is inefficient as it needs to generate a + # message pointer for each message only to do index(). Move this to a C++ + # extension function. + self.__delitem__(self[slice(None, None, None)].index(value)) + def MergeFrom(self, other): for message in other[:]: self.add().MergeFrom(message) @@ -236,27 +246,29 @@ class RepeatedCompositeContainer(object): def __hash__(self): raise TypeError('unhashable object') - def sort(self, sort_function=cmp): - messages = [] - for index in range(len(self)): - # messages[i][0] is where the i-th element of the new array has to come - # from. - # messages[i][1] is where the i-th element of the old array has to go. - messages.append([index, 0, self[index]]) - messages.sort(lambda x,y: sort_function(x[2], y[2])) + def sort(self, cmp=None, key=None, reverse=False, **kwargs): + # Maintain compatibility with the old interface. + if cmp is None and 'sort_function' in kwargs: + cmp = kwargs.pop('sort_function') - # Remember which position each elements has to move to. - for i in range(len(messages)): - messages[messages[i][0]][1] = i + # The cmp function, if provided, is passed the results of the key function, + # so we only need to wrap one of them. + if key is None: + index_key = self.__getitem__ + else: + index_key = lambda i: key(self[i]) + + # Sort the list of current indexes by the underlying object. + indexes = range(len(self)) + indexes.sort(cmp=cmp, key=index_key, reverse=reverse) # Apply the transposition. - for i in range(len(messages)): - from_position = messages[i][0] - if i == from_position: + for dest, src in enumerate(indexes): + if dest == src: continue - self._cmsg.SwapRepeatedFieldElements( - self._cfield_descriptor, i, from_position) - messages[messages[i][1]][0] = from_position + self._cmsg.SwapRepeatedFieldElements(self._cfield_descriptor, dest, src) + # Don't swap the same value twice. + indexes[src] = src def RepeatedCompositeProperty(cdescriptor, message_type): @@ -359,11 +371,12 @@ class ExtensionDict(object): return None -def NewMessage(message_descriptor, dictionary): +def NewMessage(bases, message_descriptor, dictionary): """Creates a new protocol message *class*.""" _AddClassAttributesForNestedExtensions(message_descriptor, dictionary) _AddEnumValues(message_descriptor, dictionary) _AddDescriptors(message_descriptor, dictionary) + return bases def InitMessage(message_descriptor, cls): @@ -372,6 +385,7 @@ def InitMessage(message_descriptor, cls): _AddInitMethod(message_descriptor, cls) _AddMessageMethods(message_descriptor, cls) _AddPropertiesForExtensions(message_descriptor, cls) + copy_reg.pickle(cls, lambda obj: (cls, (), obj.__getstate__())) def _AddDescriptors(message_descriptor, dictionary): @@ -387,7 +401,7 @@ def _AddDescriptors(message_descriptor, dictionary): field.full_name) dictionary['__slots__'] = list(dictionary['__descriptors'].iterkeys()) + [ - '_cmsg', '_owner', '_composite_fields', 'Extensions'] + '_cmsg', '_owner', '_composite_fields', 'Extensions', '_HACK_REFCOUNTS'] def _AddEnumValues(message_descriptor, dictionary): @@ -398,6 +412,7 @@ def _AddEnumValues(message_descriptor, dictionary): dictionary: Class dictionary that should be populated. """ for enum_type in message_descriptor.enum_types: + dictionary[enum_type.name] = enum_type_wrapper.EnumTypeWrapper(enum_type) for enum_value in enum_type.values: dictionary[enum_value.name] = enum_value.number @@ -439,28 +454,35 @@ def _AddInitMethod(message_descriptor, cls): def Init(self, **kwargs): """Message constructor.""" cmessage = kwargs.pop('__cmessage', None) - if cmessage is None: - self._cmsg = NewCMessage(message_descriptor.full_name) - else: + if cmessage: self._cmsg = cmessage + else: + self._cmsg = NewCMessage(message_descriptor.full_name) # Keep a reference to the owner, as the owner keeps a reference to the # underlying protocol buffer message. owner = kwargs.pop('__owner', None) - if owner is not None: + if owner: self._owner = owner - self.Extensions = ExtensionDict(self) + if message_descriptor.is_extendable: + self.Extensions = ExtensionDict(self) + else: + # Reference counting in the C++ code is broken and depends on + # the Extensions reference to keep this object alive during unit + # tests (see b/4856052). Remove this once b/4945904 is fixed. + self._HACK_REFCOUNTS = self self._composite_fields = {} for field_name, field_value in kwargs.iteritems(): field_cdescriptor = self.__descriptors.get(field_name, None) - if field_cdescriptor is None: + if not field_cdescriptor: raise ValueError('Protocol message has no "%s" field.' % field_name) if field_cdescriptor.label == _LABEL_REPEATED: if field_cdescriptor.cpp_type == _CPPTYPE_MESSAGE: + field_name = getattr(self, field_name) for val in field_value: - getattr(self, field_name).add().MergeFrom(val) + field_name.add().MergeFrom(val) else: getattr(self, field_name).extend(field_value) elif field_cdescriptor.cpp_type == _CPPTYPE_MESSAGE: @@ -497,12 +519,34 @@ def _AddMessageMethods(message_descriptor, cls): return self._cmsg.HasField(field_name) def ClearField(self, field_name): + child_cmessage = None if field_name in self._composite_fields: + child_field = self._composite_fields[field_name] del self._composite_fields[field_name] - self._cmsg.ClearField(field_name) + + child_cdescriptor = self.__descriptors[field_name] + # TODO(anuraag): Support clearing repeated message fields as well. + if (child_cdescriptor.label != _LABEL_REPEATED and + child_cdescriptor.cpp_type == _CPPTYPE_MESSAGE): + child_field._owner = None + child_cmessage = child_field._cmsg + + if child_cmessage is not None: + self._cmsg.ClearField(field_name, child_cmessage) + else: + self._cmsg.ClearField(field_name) def Clear(self): - return self._cmsg.Clear() + cmessages_to_release = [] + for field_name, child_field in self._composite_fields.iteritems(): + child_cdescriptor = self.__descriptors[field_name] + # TODO(anuraag): Support clearing repeated message fields as well. + if (child_cdescriptor.label != _LABEL_REPEATED and + child_cdescriptor.cpp_type == _CPPTYPE_MESSAGE): + child_field._owner = None + cmessages_to_release.append((child_cdescriptor, child_field._cmsg)) + self._composite_fields.clear() + self._cmsg.Clear(cmessages_to_release) def IsInitialized(self, errors=None): if self._cmsg.IsInitialized(): @@ -514,8 +558,8 @@ def _AddMessageMethods(message_descriptor, cls): def SerializeToString(self): if not self.IsInitialized(): raise message.EncodeError( - 'Message is missing required fields: ' + - ','.join(self.FindInitializationErrors())) + 'Message %s is missing required fields: %s' % ( + self._cmsg.full_name, ','.join(self.FindInitializationErrors()))) return self._cmsg.SerializeToString() def SerializePartialToString(self): @@ -534,7 +578,8 @@ def _AddMessageMethods(message_descriptor, cls): def MergeFrom(self, msg): if not isinstance(msg, cls): raise TypeError( - "Parameter to MergeFrom() must be instance of same class.") + "Parameter to MergeFrom() must be instance of same class: " + "expected %s got %s." % (cls.__name__, type(msg).__name__)) self._cmsg.MergeFrom(msg._cmsg) def CopyFrom(self, msg): @@ -581,6 +626,8 @@ def _AddMessageMethods(message_descriptor, cls): raise TypeError('unhashable object') def __unicode__(self): + # Lazy import to prevent circular import when text_format imports this file. + from google.protobuf import text_format return text_format.MessageToString(self, as_utf8=True).decode('utf-8') # Attach the local methods to the message class. diff --git a/python/google/protobuf/internal/decoder.py b/python/google/protobuf/internal/decoder.py index 55f746f5..cb6f5729 100755 --- a/python/google/protobuf/internal/decoder.py +++ b/python/google/protobuf/internal/decoder.py @@ -576,6 +576,7 @@ def MessageSetItemDecoder(extensions_by_number): local_SkipField = SkipField def DecodeItem(buffer, pos, end, message, field_dict): + message_set_item_start = pos type_id = -1 message_start = -1 message_end = -1 @@ -614,6 +615,11 @@ def MessageSetItemDecoder(extensions_by_number): # The only reason _InternalParse would return early is if it encountered # an end-group tag. raise _DecodeError('Unexpected end-group tag.') + else: + if not message._unknown_fields: + message._unknown_fields = [] + message._unknown_fields.append((MESSAGE_SET_ITEM_TAG, + buffer[message_set_item_start:pos])) return pos diff --git a/python/google/protobuf/internal/descriptor_database_test.py b/python/google/protobuf/internal/descriptor_database_test.py new file mode 100644 index 00000000..d0ca7892 --- /dev/null +++ b/python/google/protobuf/internal/descriptor_database_test.py @@ -0,0 +1,63 @@ +#! /usr/bin/python +# +# Protocol Buffers - Google's data interchange format +# Copyright 2008 Google Inc. All rights reserved. +# http://code.google.com/p/protobuf/ +# +# Redistribution and use in source and binary forms, with or without +# modification, are permitted provided that the following conditions are +# met: +# +# * Redistributions of source code must retain the above copyright +# notice, this list of conditions and the following disclaimer. +# * Redistributions in binary form must reproduce the above +# copyright notice, this list of conditions and the following disclaimer +# in the documentation and/or other materials provided with the +# distribution. +# * Neither the name of Google Inc. nor the names of its +# contributors may be used to endorse or promote products derived from +# this software without specific prior written permission. +# +# THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS +# "AS IS" AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT +# LIMITED TO, THE IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR +# A PARTICULAR PURPOSE ARE DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT +# OWNER OR CONTRIBUTORS BE LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, +# SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT +# LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, +# DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY +# THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT +# (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE +# OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. + +"""Tests for google.protobuf.descriptor_database.""" + +__author__ = 'matthewtoia@google.com (Matt Toia)' + +import unittest +from google.protobuf import descriptor_pb2 +from google.protobuf.internal import factory_test2_pb2 +from google.protobuf import descriptor_database + + +class DescriptorDatabaseTest(unittest.TestCase): + + def testAdd(self): + db = descriptor_database.DescriptorDatabase() + file_desc_proto = descriptor_pb2.FileDescriptorProto.FromString( + factory_test2_pb2.DESCRIPTOR.serialized_pb) + db.Add(file_desc_proto) + + self.assertEquals(file_desc_proto, db.FindFileByName( + 'net/proto2/python/internal/factory_test2.proto')) + self.assertEquals(file_desc_proto, db.FindFileContainingSymbol( + 'net.proto2.python.internal.Factory2Message')) + self.assertEquals(file_desc_proto, db.FindFileContainingSymbol( + 'net.proto2.python.internal.Factory2Message.NestedFactory2Message')) + self.assertEquals(file_desc_proto, db.FindFileContainingSymbol( + 'net.proto2.python.internal.Factory2Enum')) + self.assertEquals(file_desc_proto, db.FindFileContainingSymbol( + 'net.proto2.python.internal.Factory2Message.NestedFactory2Enum')) + +if __name__ == '__main__': + unittest.main() diff --git a/python/google/protobuf/internal/descriptor_pool_test.py b/python/google/protobuf/internal/descriptor_pool_test.py new file mode 100644 index 00000000..a615d787 --- /dev/null +++ b/python/google/protobuf/internal/descriptor_pool_test.py @@ -0,0 +1,220 @@ +#! /usr/bin/python +# +# Protocol Buffers - Google's data interchange format +# Copyright 2008 Google Inc. All rights reserved. +# http://code.google.com/p/protobuf/ +# +# Redistribution and use in source and binary forms, with or without +# modification, are permitted provided that the following conditions are +# met: +# +# * Redistributions of source code must retain the above copyright +# notice, this list of conditions and the following disclaimer. +# * Redistributions in binary form must reproduce the above +# copyright notice, this list of conditions and the following disclaimer +# in the documentation and/or other materials provided with the +# distribution. +# * Neither the name of Google Inc. nor the names of its +# contributors may be used to endorse or promote products derived from +# this software without specific prior written permission. +# +# THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS +# "AS IS" AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT +# LIMITED TO, THE IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR +# A PARTICULAR PURPOSE ARE DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT +# OWNER OR CONTRIBUTORS BE LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, +# SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT +# LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, +# DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY +# THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT +# (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE +# OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. + +"""Tests for google.protobuf.descriptor_pool.""" + +__author__ = 'matthewtoia@google.com (Matt Toia)' + +import unittest +from google.protobuf import descriptor_pb2 +from google.protobuf.internal import factory_test1_pb2 +from google.protobuf.internal import factory_test2_pb2 +from google.protobuf import descriptor +from google.protobuf import descriptor_database +from google.protobuf import descriptor_pool + + +class DescriptorPoolTest(unittest.TestCase): + + def setUp(self): + self.pool = descriptor_pool.DescriptorPool() + self.factory_test1_fd = descriptor_pb2.FileDescriptorProto.FromString( + factory_test1_pb2.DESCRIPTOR.serialized_pb) + self.factory_test2_fd = descriptor_pb2.FileDescriptorProto.FromString( + factory_test2_pb2.DESCRIPTOR.serialized_pb) + self.pool.Add(self.factory_test1_fd) + self.pool.Add(self.factory_test2_fd) + + def testFindFileByName(self): + name1 = 'net/proto2/python/internal/factory_test1.proto' + file_desc1 = self.pool.FindFileByName(name1) + self.assertIsInstance(file_desc1, descriptor.FileDescriptor) + self.assertEquals(name1, file_desc1.name) + self.assertEquals('net.proto2.python.internal', file_desc1.package) + self.assertIn('Factory1Message', file_desc1.message_types_by_name) + + name2 = 'net/proto2/python/internal/factory_test2.proto' + file_desc2 = self.pool.FindFileByName(name2) + self.assertIsInstance(file_desc2, descriptor.FileDescriptor) + self.assertEquals(name2, file_desc2.name) + self.assertEquals('net.proto2.python.internal', file_desc2.package) + self.assertIn('Factory2Message', file_desc2.message_types_by_name) + + def testFindFileByNameFailure(self): + try: + self.pool.FindFileByName('Does not exist') + self.fail('Expected KeyError') + except KeyError: + pass + + def testFindFileContainingSymbol(self): + file_desc1 = self.pool.FindFileContainingSymbol( + 'net.proto2.python.internal.Factory1Message') + self.assertIsInstance(file_desc1, descriptor.FileDescriptor) + self.assertEquals('net/proto2/python/internal/factory_test1.proto', + file_desc1.name) + self.assertEquals('net.proto2.python.internal', file_desc1.package) + self.assertIn('Factory1Message', file_desc1.message_types_by_name) + + file_desc2 = self.pool.FindFileContainingSymbol( + 'net.proto2.python.internal.Factory2Message') + self.assertIsInstance(file_desc2, descriptor.FileDescriptor) + self.assertEquals('net/proto2/python/internal/factory_test2.proto', + file_desc2.name) + self.assertEquals('net.proto2.python.internal', file_desc2.package) + self.assertIn('Factory2Message', file_desc2.message_types_by_name) + + def testFindFileContainingSymbolFailure(self): + try: + self.pool.FindFileContainingSymbol('Does not exist') + self.fail('Expected KeyError') + except KeyError: + pass + + def testFindMessageTypeByName(self): + msg1 = self.pool.FindMessageTypeByName( + 'net.proto2.python.internal.Factory1Message') + self.assertIsInstance(msg1, descriptor.Descriptor) + self.assertEquals('Factory1Message', msg1.name) + self.assertEquals('net.proto2.python.internal.Factory1Message', + msg1.full_name) + self.assertEquals(None, msg1.containing_type) + + nested_msg1 = msg1.nested_types[0] + self.assertEquals('NestedFactory1Message', nested_msg1.name) + self.assertEquals(msg1, nested_msg1.containing_type) + + nested_enum1 = msg1.enum_types[0] + self.assertEquals('NestedFactory1Enum', nested_enum1.name) + self.assertEquals(msg1, nested_enum1.containing_type) + + self.assertEquals(nested_msg1, msg1.fields_by_name[ + 'nested_factory_1_message'].message_type) + self.assertEquals(nested_enum1, msg1.fields_by_name[ + 'nested_factory_1_enum'].enum_type) + + msg2 = self.pool.FindMessageTypeByName( + 'net.proto2.python.internal.Factory2Message') + self.assertIsInstance(msg2, descriptor.Descriptor) + self.assertEquals('Factory2Message', msg2.name) + self.assertEquals('net.proto2.python.internal.Factory2Message', + msg2.full_name) + self.assertIsNone(msg2.containing_type) + + nested_msg2 = msg2.nested_types[0] + self.assertEquals('NestedFactory2Message', nested_msg2.name) + self.assertEquals(msg2, nested_msg2.containing_type) + + nested_enum2 = msg2.enum_types[0] + self.assertEquals('NestedFactory2Enum', nested_enum2.name) + self.assertEquals(msg2, nested_enum2.containing_type) + + self.assertEquals(nested_msg2, msg2.fields_by_name[ + 'nested_factory_2_message'].message_type) + self.assertEquals(nested_enum2, msg2.fields_by_name[ + 'nested_factory_2_enum'].enum_type) + + self.assertTrue(msg2.fields_by_name['int_with_default'].has_default) + self.assertEquals( + 1776, msg2.fields_by_name['int_with_default'].default_value) + + self.assertTrue(msg2.fields_by_name['double_with_default'].has_default) + self.assertEquals( + 9.99, msg2.fields_by_name['double_with_default'].default_value) + + self.assertTrue(msg2.fields_by_name['string_with_default'].has_default) + self.assertEquals( + 'hello world', msg2.fields_by_name['string_with_default'].default_value) + + self.assertTrue(msg2.fields_by_name['bool_with_default'].has_default) + self.assertFalse(msg2.fields_by_name['bool_with_default'].default_value) + + self.assertTrue(msg2.fields_by_name['enum_with_default'].has_default) + self.assertEquals( + 1, msg2.fields_by_name['enum_with_default'].default_value) + + msg3 = self.pool.FindMessageTypeByName( + 'net.proto2.python.internal.Factory2Message.NestedFactory2Message') + self.assertEquals(nested_msg2, msg3) + + def testFindMessageTypeByNameFailure(self): + try: + self.pool.FindMessageTypeByName('Does not exist') + self.fail('Expected KeyError') + except KeyError: + pass + + def testFindEnumTypeByName(self): + enum1 = self.pool.FindEnumTypeByName( + 'net.proto2.python.internal.Factory1Enum') + self.assertIsInstance(enum1, descriptor.EnumDescriptor) + self.assertEquals(0, enum1.values_by_name['FACTORY_1_VALUE_0'].number) + self.assertEquals(1, enum1.values_by_name['FACTORY_1_VALUE_1'].number) + + nested_enum1 = self.pool.FindEnumTypeByName( + 'net.proto2.python.internal.Factory1Message.NestedFactory1Enum') + self.assertIsInstance(nested_enum1, descriptor.EnumDescriptor) + self.assertEquals( + 0, nested_enum1.values_by_name['NESTED_FACTORY_1_VALUE_0'].number) + self.assertEquals( + 1, nested_enum1.values_by_name['NESTED_FACTORY_1_VALUE_1'].number) + + enum2 = self.pool.FindEnumTypeByName( + 'net.proto2.python.internal.Factory2Enum') + self.assertIsInstance(enum2, descriptor.EnumDescriptor) + self.assertEquals(0, enum2.values_by_name['FACTORY_2_VALUE_0'].number) + self.assertEquals(1, enum2.values_by_name['FACTORY_2_VALUE_1'].number) + + nested_enum2 = self.pool.FindEnumTypeByName( + 'net.proto2.python.internal.Factory2Message.NestedFactory2Enum') + self.assertIsInstance(nested_enum2, descriptor.EnumDescriptor) + self.assertEquals( + 0, nested_enum2.values_by_name['NESTED_FACTORY_2_VALUE_0'].number) + self.assertEquals( + 1, nested_enum2.values_by_name['NESTED_FACTORY_2_VALUE_1'].number) + + def testFindEnumTypeByNameFailure(self): + try: + self.pool.FindEnumTypeByName('Does not exist') + self.fail('Expected KeyError') + except KeyError: + pass + + def testUserDefinedDB(self): + db = descriptor_database.DescriptorDatabase() + self.pool = descriptor_pool.DescriptorPool(db) + db.Add(self.factory_test1_fd) + db.Add(self.factory_test2_fd) + self.testFindMessageTypeByName() + +if __name__ == '__main__': + unittest.main() diff --git a/python/google/protobuf/internal/descriptor_test.py b/python/google/protobuf/internal/descriptor_test.py index 05c27452..c74f882e 100755 --- a/python/google/protobuf/internal/descriptor_test.py +++ b/python/google/protobuf/internal/descriptor_test.py @@ -35,6 +35,7 @@ __author__ = 'robinson@google.com (Will Robinson)' import unittest +from google.protobuf import unittest_custom_options_pb2 from google.protobuf import unittest_import_pb2 from google.protobuf import unittest_pb2 from google.protobuf import descriptor_pb2 @@ -101,6 +102,15 @@ class DescriptorTest(unittest.TestCase): self.my_method ]) + def testEnumValueName(self): + self.assertEqual(self.my_message.EnumValueName('ForeignEnum', 4), + 'FOREIGN_FOO') + + self.assertEqual( + self.my_message.enum_types_by_name[ + 'ForeignEnum'].values_by_number[4].name, + self.my_message.EnumValueName('ForeignEnum', 4)) + def testEnumFixups(self): self.assertEqual(self.my_enum, self.my_enum.values[0].type) @@ -125,6 +135,257 @@ class DescriptorTest(unittest.TestCase): self.assertEqual(self.my_service.GetOptions(), descriptor_pb2.ServiceOptions()) + def testSimpleCustomOptions(self): + file_descriptor = unittest_custom_options_pb2.DESCRIPTOR + message_descriptor =\ + unittest_custom_options_pb2.TestMessageWithCustomOptions.DESCRIPTOR + field_descriptor = message_descriptor.fields_by_name["field1"] + enum_descriptor = message_descriptor.enum_types_by_name["AnEnum"] + enum_value_descriptor =\ + message_descriptor.enum_values_by_name["ANENUM_VAL2"] + service_descriptor =\ + unittest_custom_options_pb2.TestServiceWithCustomOptions.DESCRIPTOR + method_descriptor = service_descriptor.FindMethodByName("Foo") + + file_options = file_descriptor.GetOptions() + file_opt1 = unittest_custom_options_pb2.file_opt1 + self.assertEqual(9876543210, file_options.Extensions[file_opt1]) + message_options = message_descriptor.GetOptions() + message_opt1 = unittest_custom_options_pb2.message_opt1 + self.assertEqual(-56, message_options.Extensions[message_opt1]) + field_options = field_descriptor.GetOptions() + field_opt1 = unittest_custom_options_pb2.field_opt1 + self.assertEqual(8765432109, field_options.Extensions[field_opt1]) + field_opt2 = unittest_custom_options_pb2.field_opt2 + self.assertEqual(42, field_options.Extensions[field_opt2]) + enum_options = enum_descriptor.GetOptions() + enum_opt1 = unittest_custom_options_pb2.enum_opt1 + self.assertEqual(-789, enum_options.Extensions[enum_opt1]) + enum_value_options = enum_value_descriptor.GetOptions() + enum_value_opt1 = unittest_custom_options_pb2.enum_value_opt1 + self.assertEqual(123, enum_value_options.Extensions[enum_value_opt1]) + + service_options = service_descriptor.GetOptions() + service_opt1 = unittest_custom_options_pb2.service_opt1 + self.assertEqual(-9876543210, service_options.Extensions[service_opt1]) + method_options = method_descriptor.GetOptions() + method_opt1 = unittest_custom_options_pb2.method_opt1 + self.assertEqual(unittest_custom_options_pb2.METHODOPT1_VAL2, + method_options.Extensions[method_opt1]) + + def testDifferentCustomOptionTypes(self): + kint32min = -2**31 + kint64min = -2**63 + kint32max = 2**31 - 1 + kint64max = 2**63 - 1 + kuint32max = 2**32 - 1 + kuint64max = 2**64 - 1 + + message_descriptor =\ + unittest_custom_options_pb2.CustomOptionMinIntegerValues.DESCRIPTOR + message_options = message_descriptor.GetOptions() + self.assertEqual(False, message_options.Extensions[ + unittest_custom_options_pb2.bool_opt]) + self.assertEqual(kint32min, message_options.Extensions[ + unittest_custom_options_pb2.int32_opt]) + self.assertEqual(kint64min, message_options.Extensions[ + unittest_custom_options_pb2.int64_opt]) + self.assertEqual(0, message_options.Extensions[ + unittest_custom_options_pb2.uint32_opt]) + self.assertEqual(0, message_options.Extensions[ + unittest_custom_options_pb2.uint64_opt]) + self.assertEqual(kint32min, message_options.Extensions[ + unittest_custom_options_pb2.sint32_opt]) + self.assertEqual(kint64min, message_options.Extensions[ + unittest_custom_options_pb2.sint64_opt]) + self.assertEqual(0, message_options.Extensions[ + unittest_custom_options_pb2.fixed32_opt]) + self.assertEqual(0, message_options.Extensions[ + unittest_custom_options_pb2.fixed64_opt]) + self.assertEqual(kint32min, message_options.Extensions[ + unittest_custom_options_pb2.sfixed32_opt]) + self.assertEqual(kint64min, message_options.Extensions[ + unittest_custom_options_pb2.sfixed64_opt]) + + message_descriptor =\ + unittest_custom_options_pb2.CustomOptionMaxIntegerValues.DESCRIPTOR + message_options = message_descriptor.GetOptions() + self.assertEqual(True, message_options.Extensions[ + unittest_custom_options_pb2.bool_opt]) + self.assertEqual(kint32max, message_options.Extensions[ + unittest_custom_options_pb2.int32_opt]) + self.assertEqual(kint64max, message_options.Extensions[ + unittest_custom_options_pb2.int64_opt]) + self.assertEqual(kuint32max, message_options.Extensions[ + unittest_custom_options_pb2.uint32_opt]) + self.assertEqual(kuint64max, message_options.Extensions[ + unittest_custom_options_pb2.uint64_opt]) + self.assertEqual(kint32max, message_options.Extensions[ + unittest_custom_options_pb2.sint32_opt]) + self.assertEqual(kint64max, message_options.Extensions[ + unittest_custom_options_pb2.sint64_opt]) + self.assertEqual(kuint32max, message_options.Extensions[ + unittest_custom_options_pb2.fixed32_opt]) + self.assertEqual(kuint64max, message_options.Extensions[ + unittest_custom_options_pb2.fixed64_opt]) + self.assertEqual(kint32max, message_options.Extensions[ + unittest_custom_options_pb2.sfixed32_opt]) + self.assertEqual(kint64max, message_options.Extensions[ + unittest_custom_options_pb2.sfixed64_opt]) + + message_descriptor =\ + unittest_custom_options_pb2.CustomOptionOtherValues.DESCRIPTOR + message_options = message_descriptor.GetOptions() + self.assertEqual(-100, message_options.Extensions[ + unittest_custom_options_pb2.int32_opt]) + self.assertAlmostEqual(12.3456789, message_options.Extensions[ + unittest_custom_options_pb2.float_opt], 6) + self.assertAlmostEqual(1.234567890123456789, message_options.Extensions[ + unittest_custom_options_pb2.double_opt]) + self.assertEqual("Hello, \"World\"", message_options.Extensions[ + unittest_custom_options_pb2.string_opt]) + self.assertEqual("Hello\0World", message_options.Extensions[ + unittest_custom_options_pb2.bytes_opt]) + dummy_enum = unittest_custom_options_pb2.DummyMessageContainingEnum + self.assertEqual( + dummy_enum.TEST_OPTION_ENUM_TYPE2, + message_options.Extensions[unittest_custom_options_pb2.enum_opt]) + + message_descriptor =\ + unittest_custom_options_pb2.SettingRealsFromPositiveInts.DESCRIPTOR + message_options = message_descriptor.GetOptions() + self.assertAlmostEqual(12, message_options.Extensions[ + unittest_custom_options_pb2.float_opt], 6) + self.assertAlmostEqual(154, message_options.Extensions[ + unittest_custom_options_pb2.double_opt]) + + message_descriptor =\ + unittest_custom_options_pb2.SettingRealsFromNegativeInts.DESCRIPTOR + message_options = message_descriptor.GetOptions() + self.assertAlmostEqual(-12, message_options.Extensions[ + unittest_custom_options_pb2.float_opt], 6) + self.assertAlmostEqual(-154, message_options.Extensions[ + unittest_custom_options_pb2.double_opt]) + + def testComplexExtensionOptions(self): + descriptor =\ + unittest_custom_options_pb2.VariousComplexOptions.DESCRIPTOR + options = descriptor.GetOptions() + self.assertEqual(42, options.Extensions[ + unittest_custom_options_pb2.complex_opt1].foo) + self.assertEqual(324, options.Extensions[ + unittest_custom_options_pb2.complex_opt1].Extensions[ + unittest_custom_options_pb2.quux]) + self.assertEqual(876, options.Extensions[ + unittest_custom_options_pb2.complex_opt1].Extensions[ + unittest_custom_options_pb2.corge].qux) + self.assertEqual(987, options.Extensions[ + unittest_custom_options_pb2.complex_opt2].baz) + self.assertEqual(654, options.Extensions[ + unittest_custom_options_pb2.complex_opt2].Extensions[ + unittest_custom_options_pb2.grault]) + self.assertEqual(743, options.Extensions[ + unittest_custom_options_pb2.complex_opt2].bar.foo) + self.assertEqual(1999, options.Extensions[ + unittest_custom_options_pb2.complex_opt2].bar.Extensions[ + unittest_custom_options_pb2.quux]) + self.assertEqual(2008, options.Extensions[ + unittest_custom_options_pb2.complex_opt2].bar.Extensions[ + unittest_custom_options_pb2.corge].qux) + self.assertEqual(741, options.Extensions[ + unittest_custom_options_pb2.complex_opt2].Extensions[ + unittest_custom_options_pb2.garply].foo) + self.assertEqual(1998, options.Extensions[ + unittest_custom_options_pb2.complex_opt2].Extensions[ + unittest_custom_options_pb2.garply].Extensions[ + unittest_custom_options_pb2.quux]) + self.assertEqual(2121, options.Extensions[ + unittest_custom_options_pb2.complex_opt2].Extensions[ + unittest_custom_options_pb2.garply].Extensions[ + unittest_custom_options_pb2.corge].qux) + self.assertEqual(1971, options.Extensions[ + unittest_custom_options_pb2.ComplexOptionType2 + .ComplexOptionType4.complex_opt4].waldo) + self.assertEqual(321, options.Extensions[ + unittest_custom_options_pb2.complex_opt2].fred.waldo) + self.assertEqual(9, options.Extensions[ + unittest_custom_options_pb2.complex_opt3].qux) + self.assertEqual(22, options.Extensions[ + unittest_custom_options_pb2.complex_opt3].complexoptiontype5.plugh) + self.assertEqual(24, options.Extensions[ + unittest_custom_options_pb2.complexopt6].xyzzy) + + # Check that aggregate options were parsed and saved correctly in + # the appropriate descriptors. + def testAggregateOptions(self): + file_descriptor = unittest_custom_options_pb2.DESCRIPTOR + message_descriptor =\ + unittest_custom_options_pb2.AggregateMessage.DESCRIPTOR + field_descriptor = message_descriptor.fields_by_name["fieldname"] + enum_descriptor = unittest_custom_options_pb2.AggregateEnum.DESCRIPTOR + enum_value_descriptor = enum_descriptor.values_by_name["VALUE"] + service_descriptor =\ + unittest_custom_options_pb2.AggregateService.DESCRIPTOR + method_descriptor = service_descriptor.FindMethodByName("Method") + + # Tests for the different types of data embedded in fileopt + file_options = file_descriptor.GetOptions().Extensions[ + unittest_custom_options_pb2.fileopt] + self.assertEqual(100, file_options.i) + self.assertEqual("FileAnnotation", file_options.s) + self.assertEqual("NestedFileAnnotation", file_options.sub.s) + self.assertEqual("FileExtensionAnnotation", file_options.file.Extensions[ + unittest_custom_options_pb2.fileopt].s) + self.assertEqual("EmbeddedMessageSetElement", file_options.mset.Extensions[ + unittest_custom_options_pb2.AggregateMessageSetElement + .message_set_extension].s) + + # Simple tests for all the other types of annotations + self.assertEqual( + "MessageAnnotation", + message_descriptor.GetOptions().Extensions[ + unittest_custom_options_pb2.msgopt].s) + self.assertEqual( + "FieldAnnotation", + field_descriptor.GetOptions().Extensions[ + unittest_custom_options_pb2.fieldopt].s) + self.assertEqual( + "EnumAnnotation", + enum_descriptor.GetOptions().Extensions[ + unittest_custom_options_pb2.enumopt].s) + self.assertEqual( + "EnumValueAnnotation", + enum_value_descriptor.GetOptions().Extensions[ + unittest_custom_options_pb2.enumvalopt].s) + self.assertEqual( + "ServiceAnnotation", + service_descriptor.GetOptions().Extensions[ + unittest_custom_options_pb2.serviceopt].s) + self.assertEqual( + "MethodAnnotation", + method_descriptor.GetOptions().Extensions[ + unittest_custom_options_pb2.methodopt].s) + + def testNestedOptions(self): + nested_message =\ + unittest_custom_options_pb2.NestedOptionType.NestedMessage.DESCRIPTOR + self.assertEqual(1001, nested_message.GetOptions().Extensions[ + unittest_custom_options_pb2.message_opt1]) + nested_field = nested_message.fields_by_name["nested_field"] + self.assertEqual(1002, nested_field.GetOptions().Extensions[ + unittest_custom_options_pb2.field_opt1]) + outer_message =\ + unittest_custom_options_pb2.NestedOptionType.DESCRIPTOR + nested_enum = outer_message.enum_types_by_name["NestedEnum"] + self.assertEqual(1003, nested_enum.GetOptions().Extensions[ + unittest_custom_options_pb2.enum_opt1]) + nested_enum_value = outer_message.enum_values_by_name["NESTED_ENUM_VALUE"] + self.assertEqual(1004, nested_enum_value.GetOptions().Extensions[ + unittest_custom_options_pb2.enum_value_opt1]) + nested_extension = outer_message.extensions_by_name["nested_extension"] + self.assertEqual(1005, nested_extension.GetOptions().Extensions[ + unittest_custom_options_pb2.field_opt2]) + def testFileDescriptorReferences(self): self.assertEqual(self.my_enum.file, self.my_file) self.assertEqual(self.my_message.file, self.my_file) @@ -273,6 +534,7 @@ class DescriptorCopyToProtoTest(unittest.TestCase): UNITTEST_IMPORT_FILE_DESCRIPTOR_ASCII = (""" name: 'google/protobuf/unittest_import.proto' package: 'protobuf_unittest_import' + dependency: 'google/protobuf/unittest_import_public.proto' message_type: < name: 'ImportMessage' field: < @@ -302,6 +564,7 @@ class DescriptorCopyToProtoTest(unittest.TestCase): java_package: 'com.google.protobuf.test' optimize_for: 1 # SPEED > + public_dependency: 0 """) self._InternalTestCopyToProto( @@ -330,5 +593,21 @@ class DescriptorCopyToProtoTest(unittest.TestCase): TEST_SERVICE_ASCII) +class MakeDescriptorTest(unittest.TestCase): + def testMakeDescriptorWithUnsignedIntField(self): + file_descriptor_proto = descriptor_pb2.FileDescriptorProto() + file_descriptor_proto.name = 'Foo' + message_type = file_descriptor_proto.message_type.add() + message_type.name = file_descriptor_proto.name + field = message_type.field.add() + field.number = 1 + field.name = 'uint64_field' + field.label = descriptor.FieldDescriptor.LABEL_REQUIRED + field.type = descriptor.FieldDescriptor.TYPE_UINT64 + result = descriptor.MakeDescriptor(message_type) + self.assertEqual(result.fields[0].cpp_type, + descriptor.FieldDescriptor.CPPTYPE_UINT64) + + if __name__ == '__main__': unittest.main() diff --git a/python/google/protobuf/internal/enum_type_wrapper.py b/python/google/protobuf/internal/enum_type_wrapper.py new file mode 100644 index 00000000..7b28645a --- /dev/null +++ b/python/google/protobuf/internal/enum_type_wrapper.py @@ -0,0 +1,89 @@ +# Protocol Buffers - Google's data interchange format +# Copyright 2008 Google Inc. All rights reserved. +# http://code.google.com/p/protobuf/ +# +# Redistribution and use in source and binary forms, with or without +# modification, are permitted provided that the following conditions are +# met: +# +# * Redistributions of source code must retain the above copyright +# notice, this list of conditions and the following disclaimer. +# * Redistributions in binary form must reproduce the above +# copyright notice, this list of conditions and the following disclaimer +# in the documentation and/or other materials provided with the +# distribution. +# * Neither the name of Google Inc. nor the names of its +# contributors may be used to endorse or promote products derived from +# this software without specific prior written permission. +# +# THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS +# "AS IS" AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT +# LIMITED TO, THE IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR +# A PARTICULAR PURPOSE ARE DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT +# OWNER OR CONTRIBUTORS BE LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, +# SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT +# LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, +# DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY +# THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT +# (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE +# OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. + +"""A simple wrapper around enum types to expose utility functions. + +Instances are created as properties with the same name as the enum they wrap +on proto classes. For usage, see: + reflection_test.py +""" + +__author__ = 'rabsatt@google.com (Kevin Rabsatt)' + + +class EnumTypeWrapper(object): + """A utility for finding the names of enum values.""" + + DESCRIPTOR = None + + def __init__(self, enum_type): + """Inits EnumTypeWrapper with an EnumDescriptor.""" + self._enum_type = enum_type + self.DESCRIPTOR = enum_type; + + def Name(self, number): + """Returns a string containing the name of an enum value.""" + if number in self._enum_type.values_by_number: + return self._enum_type.values_by_number[number].name + raise ValueError('Enum %s has no name defined for value %d' % ( + self._enum_type.name, number)) + + def Value(self, name): + """Returns the value coresponding to the given enum name.""" + if name in self._enum_type.values_by_name: + return self._enum_type.values_by_name[name].number + raise ValueError('Enum %s has no value defined for name %s' % ( + self._enum_type.name, name)) + + def keys(self): + """Return a list of the string names in the enum. + + These are returned in the order they were defined in the .proto file. + """ + + return [value_descriptor.name + for value_descriptor in self._enum_type.values] + + def values(self): + """Return a list of the integer values in the enum. + + These are returned in the order they were defined in the .proto file. + """ + + return [value_descriptor.number + for value_descriptor in self._enum_type.values] + + def items(self): + """Return a list of the (name, value) pairs of the enum. + + These are returned in the order they were defined in the .proto file. + """ + return [(value_descriptor.name, value_descriptor.number) + for value_descriptor in self._enum_type.values] diff --git a/python/google/protobuf/internal/factory_test1.proto b/python/google/protobuf/internal/factory_test1.proto new file mode 100644 index 00000000..9f55e037 --- /dev/null +++ b/python/google/protobuf/internal/factory_test1.proto @@ -0,0 +1,55 @@ +// Protocol Buffers - Google's data interchange format +// Copyright 2008 Google Inc. All rights reserved. +// http://code.google.com/p/protobuf/ +// +// Redistribution and use in source and binary forms, with or without +// modification, are permitted provided that the following conditions are +// met: +// +// * Redistributions of source code must retain the above copyright +// notice, this list of conditions and the following disclaimer. +// * Redistributions in binary form must reproduce the above +// copyright notice, this list of conditions and the following disclaimer +// in the documentation and/or other materials provided with the +// distribution. +// * Neither the name of Google Inc. nor the names of its +// contributors may be used to endorse or promote products derived from +// this software without specific prior written permission. +// +// THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS +// "AS IS" AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT +// LIMITED TO, THE IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR +// A PARTICULAR PURPOSE ARE DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT +// OWNER OR CONTRIBUTORS BE LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, +// SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT +// LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, +// DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY +// THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT +// (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE +// OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. + +// Author: matthewtoia@google.com (Matt Toia) + + +package google.protobuf.python.internal; + + +enum Factory1Enum { + FACTORY_1_VALUE_0 = 0; + FACTORY_1_VALUE_1 = 1; +} + +message Factory1Message { + optional Factory1Enum factory_1_enum = 1; + enum NestedFactory1Enum { + NESTED_FACTORY_1_VALUE_0 = 0; + NESTED_FACTORY_1_VALUE_1 = 1; + } + optional NestedFactory1Enum nested_factory_1_enum = 2; + message NestedFactory1Message { + optional string value = 1; + } + optional NestedFactory1Message nested_factory_1_message = 3; + optional int32 scalar_value = 4; + repeated string list_value = 5; +} diff --git a/python/google/protobuf/internal/factory_test2.proto b/python/google/protobuf/internal/factory_test2.proto new file mode 100644 index 00000000..d3ce4d7f --- /dev/null +++ b/python/google/protobuf/internal/factory_test2.proto @@ -0,0 +1,77 @@ +// Protocol Buffers - Google's data interchange format +// Copyright 2008 Google Inc. All rights reserved. +// http://code.google.com/p/protobuf/ +// +// Redistribution and use in source and binary forms, with or without +// modification, are permitted provided that the following conditions are +// met: +// +// * Redistributions of source code must retain the above copyright +// notice, this list of conditions and the following disclaimer. +// * Redistributions in binary form must reproduce the above +// copyright notice, this list of conditions and the following disclaimer +// in the documentation and/or other materials provided with the +// distribution. +// * Neither the name of Google Inc. nor the names of its +// contributors may be used to endorse or promote products derived from +// this software without specific prior written permission. +// +// THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS +// "AS IS" AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT +// LIMITED TO, THE IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR +// A PARTICULAR PURPOSE ARE DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT +// OWNER OR CONTRIBUTORS BE LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, +// SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT +// LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, +// DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY +// THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT +// (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE +// OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. + +// Author: matthewtoia@google.com (Matt Toia) + + +package google.protobuf.python.internal; + +import "google/protobuf/internal/factory_test1.proto"; + + +enum Factory2Enum { + FACTORY_2_VALUE_0 = 0; + FACTORY_2_VALUE_1 = 1; +} + +message Factory2Message { + required int32 mandatory = 1; + optional Factory2Enum factory_2_enum = 2; + enum NestedFactory2Enum { + NESTED_FACTORY_2_VALUE_0 = 0; + NESTED_FACTORY_2_VALUE_1 = 1; + } + optional NestedFactory2Enum nested_factory_2_enum = 3; + message NestedFactory2Message { + optional string value = 1; + } + optional NestedFactory2Message nested_factory_2_message = 4; + optional Factory1Message factory_1_message = 5; + optional Factory1Enum factory_1_enum = 6; + optional Factory1Message.NestedFactory1Enum nested_factory_1_enum = 7; + optional Factory1Message.NestedFactory1Message nested_factory_1_message = 8; + optional Factory2Message circular_message = 9; + optional string scalar_value = 10; + repeated string list_value = 11; + repeated group Grouped = 12 { + optional string part_1 = 13; + optional string part_2 = 14; + } + optional LoopMessage loop = 15; + optional int32 int_with_default = 16 [default = 1776]; + optional double double_with_default = 17 [default = 9.99]; + optional string string_with_default = 18 [default = "hello world"]; + optional bool bool_with_default = 19 [default = false]; + optional Factory2Enum enum_with_default = 20 [default = FACTORY_2_VALUE_1]; +} + +message LoopMessage { + optional Factory2Message loop = 1; +} diff --git a/python/google/protobuf/internal/generator_test.py b/python/google/protobuf/internal/generator_test.py index b3f7d9b1..8343aba1 100755 --- a/python/google/protobuf/internal/generator_test.py +++ b/python/google/protobuf/internal/generator_test.py @@ -42,8 +42,10 @@ further ensures that we can use Python protocol message objects as we expect. __author__ = 'robinson@google.com (Will Robinson)' import unittest +from google.protobuf.internal import test_bad_identifiers_pb2 from google.protobuf import unittest_custom_options_pb2 from google.protobuf import unittest_import_pb2 +from google.protobuf import unittest_import_public_pb2 from google.protobuf import unittest_mset_pb2 from google.protobuf import unittest_pb2 from google.protobuf import unittest_no_generic_services_pb2 @@ -239,6 +241,29 @@ class GeneratorTest(unittest.TestCase): unittest_pb2._TESTALLTYPES_NESTEDMESSAGE.name in file_type.message_types_by_name) + def testPublicImports(self): + # Test public imports as embedded message. + all_type_proto = unittest_pb2.TestAllTypes() + self.assertEqual(0, all_type_proto.optional_public_import_message.e) + + # PublicImportMessage is actually defined in unittest_import_public_pb2 + # module, and is public imported by unittest_import_pb2 module. + public_import_proto = unittest_import_pb2.PublicImportMessage() + self.assertEqual(0, public_import_proto.e) + self.assertTrue(unittest_import_public_pb2.PublicImportMessage is + unittest_import_pb2.PublicImportMessage) + + def testBadIdentifiers(self): + # We're just testing that the code was imported without problems. + message = test_bad_identifiers_pb2.TestBadIdentifiers() + self.assertEqual(message.Extensions[test_bad_identifiers_pb2.message], + "foo") + self.assertEqual(message.Extensions[test_bad_identifiers_pb2.descriptor], + "bar") + self.assertEqual(message.Extensions[test_bad_identifiers_pb2.reflection], + "baz") + self.assertEqual(message.Extensions[test_bad_identifiers_pb2.service], + "qux") if __name__ == '__main__': unittest.main() diff --git a/python/google/protobuf/internal/message_cpp_test.py b/python/google/protobuf/internal/message_cpp_test.py new file mode 100644 index 00000000..0d84b320 --- /dev/null +++ b/python/google/protobuf/internal/message_cpp_test.py @@ -0,0 +1,45 @@ +#! /usr/bin/python +# +# Protocol Buffers - Google's data interchange format +# Copyright 2008 Google Inc. All rights reserved. +# http://code.google.com/p/protobuf/ +# +# Redistribution and use in source and binary forms, with or without +# modification, are permitted provided that the following conditions are +# met: +# +# * Redistributions of source code must retain the above copyright +# notice, this list of conditions and the following disclaimer. +# * Redistributions in binary form must reproduce the above +# copyright notice, this list of conditions and the following disclaimer +# in the documentation and/or other materials provided with the +# distribution. +# * Neither the name of Google Inc. nor the names of its +# contributors may be used to endorse or promote products derived from +# this software without specific prior written permission. +# +# THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS +# "AS IS" AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT +# LIMITED TO, THE IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR +# A PARTICULAR PURPOSE ARE DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT +# OWNER OR CONTRIBUTORS BE LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, +# SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT +# LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, +# DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY +# THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT +# (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE +# OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. + +"""Tests for google.protobuf.internal.message_cpp.""" + +__author__ = 'shahms@google.com (Shahms King)' + +import os +os.environ['PROTOCOL_BUFFERS_PYTHON_IMPLEMENTATION'] = 'cpp' + +import unittest +from google.protobuf.internal.message_test import * + + +if __name__ == '__main__': + unittest.main() diff --git a/python/google/protobuf/internal/message_factory_test.py b/python/google/protobuf/internal/message_factory_test.py new file mode 100644 index 00000000..0bc9be99 --- /dev/null +++ b/python/google/protobuf/internal/message_factory_test.py @@ -0,0 +1,113 @@ +#! /usr/bin/python +# +# Protocol Buffers - Google's data interchange format +# Copyright 2008 Google Inc. All rights reserved. +# http://code.google.com/p/protobuf/ +# +# Redistribution and use in source and binary forms, with or without +# modification, are permitted provided that the following conditions are +# met: +# +# * Redistributions of source code must retain the above copyright +# notice, this list of conditions and the following disclaimer. +# * Redistributions in binary form must reproduce the above +# copyright notice, this list of conditions and the following disclaimer +# in the documentation and/or other materials provided with the +# distribution. +# * Neither the name of Google Inc. nor the names of its +# contributors may be used to endorse or promote products derived from +# this software without specific prior written permission. +# +# THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS +# "AS IS" AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT +# LIMITED TO, THE IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR +# A PARTICULAR PURPOSE ARE DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT +# OWNER OR CONTRIBUTORS BE LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, +# SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT +# LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, +# DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY +# THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT +# (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE +# OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. + +"""Tests for google.protobuf.message_factory.""" + +__author__ = 'matthewtoia@google.com (Matt Toia)' + +import unittest +from google.protobuf import descriptor_pb2 +from google.protobuf.internal import factory_test1_pb2 +from google.protobuf.internal import factory_test2_pb2 +from google.protobuf import descriptor_database +from google.protobuf import descriptor_pool +from google.protobuf import message_factory + + +class MessageFactoryTest(unittest.TestCase): + + def setUp(self): + self.factory_test1_fd = descriptor_pb2.FileDescriptorProto.FromString( + factory_test1_pb2.DESCRIPTOR.serialized_pb) + self.factory_test2_fd = descriptor_pb2.FileDescriptorProto.FromString( + factory_test2_pb2.DESCRIPTOR.serialized_pb) + + def _ExerciseDynamicClass(self, cls): + msg = cls() + msg.mandatory = 42 + msg.nested_factory_2_enum = 0 + msg.nested_factory_2_message.value = 'nested message value' + msg.factory_1_message.factory_1_enum = 1 + msg.factory_1_message.nested_factory_1_enum = 0 + msg.factory_1_message.nested_factory_1_message.value = ( + 'nested message value') + msg.factory_1_message.scalar_value = 22 + msg.factory_1_message.list_value.extend(['one', 'two', 'three']) + msg.factory_1_message.list_value.append('four') + msg.factory_1_enum = 1 + msg.nested_factory_1_enum = 0 + msg.nested_factory_1_message.value = 'nested message value' + msg.circular_message.mandatory = 1 + msg.circular_message.circular_message.mandatory = 2 + msg.circular_message.scalar_value = 'one deep' + msg.scalar_value = 'zero deep' + msg.list_value.extend(['four', 'three', 'two']) + msg.list_value.append('one') + msg.grouped.add() + msg.grouped[0].part_1 = 'hello' + msg.grouped[0].part_2 = 'world' + msg.grouped.add(part_1='testing', part_2='123') + msg.loop.loop.mandatory = 2 + msg.loop.loop.loop.loop.mandatory = 4 + serialized = msg.SerializeToString() + converted = factory_test2_pb2.Factory2Message.FromString(serialized) + reserialized = converted.SerializeToString() + self.assertEquals(serialized, reserialized) + result = cls.FromString(reserialized) + self.assertEquals(msg, result) + + def testGetPrototype(self): + db = descriptor_database.DescriptorDatabase() + pool = descriptor_pool.DescriptorPool(db) + db.Add(self.factory_test1_fd) + db.Add(self.factory_test2_fd) + factory = message_factory.MessageFactory() + cls = factory.GetPrototype(pool.FindMessageTypeByName( + 'net.proto2.python.internal.Factory2Message')) + self.assertIsNot(cls, factory_test2_pb2.Factory2Message) + self._ExerciseDynamicClass(cls) + cls2 = factory.GetPrototype(pool.FindMessageTypeByName( + 'net.proto2.python.internal.Factory2Message')) + self.assertIs(cls, cls2) + + def testGetMessages(self): + messages = message_factory.GetMessages([self.factory_test2_fd, + self.factory_test1_fd]) + self.assertContainsSubset( + ['net.proto2.python.internal.Factory2Message', + 'net.proto2.python.internal.Factory1Message'], + messages.keys()) + self._ExerciseDynamicClass( + messages['net.proto2.python.internal.Factory2Message']) + +if __name__ == '__main__': + unittest.main() diff --git a/python/google/protobuf/internal/message_test.py b/python/google/protobuf/internal/message_test.py index 65174373..53e9d507 100755 --- a/python/google/protobuf/internal/message_test.py +++ b/python/google/protobuf/internal/message_test.py @@ -45,10 +45,15 @@ __author__ = 'gps@google.com (Gregory P. Smith)' import copy import math +import operator +import pickle + import unittest from google.protobuf import unittest_import_pb2 from google.protobuf import unittest_pb2 +from google.protobuf.internal import api_implementation from google.protobuf.internal import test_util +from google.protobuf import message # Python pre-2.6 does not have isinf() or isnan() functions, so we have # to provide our own. @@ -70,9 +75,9 @@ class MessageTest(unittest.TestCase): golden_message = unittest_pb2.TestAllTypes() golden_message.ParseFromString(golden_data) test_util.ExpectAllFieldsSet(self, golden_message) - self.assertTrue(golden_message.SerializeToString() == golden_data) + self.assertEqual(golden_data, golden_message.SerializeToString()) golden_copy = copy.deepcopy(golden_message) - self.assertTrue(golden_copy.SerializeToString() == golden_data) + self.assertEqual(golden_data, golden_copy.SerializeToString()) def testGoldenExtensions(self): golden_data = test_util.GoldenFile('golden_message').read() @@ -81,9 +86,9 @@ class MessageTest(unittest.TestCase): all_set = unittest_pb2.TestAllExtensions() test_util.SetAllExtensions(all_set) self.assertEquals(all_set, golden_message) - self.assertTrue(golden_message.SerializeToString() == golden_data) + self.assertEqual(golden_data, golden_message.SerializeToString()) golden_copy = copy.deepcopy(golden_message) - self.assertTrue(golden_copy.SerializeToString() == golden_data) + self.assertEqual(golden_data, golden_copy.SerializeToString()) def testGoldenPackedMessage(self): golden_data = test_util.GoldenFile('golden_packed_fields_message').read() @@ -92,9 +97,9 @@ class MessageTest(unittest.TestCase): all_set = unittest_pb2.TestPackedTypes() test_util.SetAllPackedFields(all_set) self.assertEquals(all_set, golden_message) - self.assertTrue(all_set.SerializeToString() == golden_data) + self.assertEqual(golden_data, all_set.SerializeToString()) golden_copy = copy.deepcopy(golden_message) - self.assertTrue(golden_copy.SerializeToString() == golden_data) + self.assertEqual(golden_data, golden_copy.SerializeToString()) def testGoldenPackedExtensions(self): golden_data = test_util.GoldenFile('golden_packed_fields_message').read() @@ -103,9 +108,28 @@ class MessageTest(unittest.TestCase): all_set = unittest_pb2.TestPackedExtensions() test_util.SetAllPackedExtensions(all_set) self.assertEquals(all_set, golden_message) - self.assertTrue(all_set.SerializeToString() == golden_data) + self.assertEqual(golden_data, all_set.SerializeToString()) golden_copy = copy.deepcopy(golden_message) - self.assertTrue(golden_copy.SerializeToString() == golden_data) + self.assertEqual(golden_data, golden_copy.SerializeToString()) + + def testPickleSupport(self): + golden_data = test_util.GoldenFile('golden_message').read() + golden_message = unittest_pb2.TestAllTypes() + golden_message.ParseFromString(golden_data) + pickled_message = pickle.dumps(golden_message) + + unpickled_message = pickle.loads(pickled_message) + self.assertEquals(unpickled_message, golden_message) + + def testPickleIncompleteProto(self): + golden_message = unittest_pb2.TestRequired(a=1) + pickled_message = pickle.dumps(golden_message) + + unpickled_message = pickle.loads(pickled_message) + self.assertEquals(unpickled_message, golden_message) + self.assertEquals(unpickled_message.a, 1) + # This is still an incomplete proto - so serializing should fail + self.assertRaises(message.EncodeError, unpickled_message.SerializeToString) def testPositiveInfinity(self): golden_data = ('\x5D\x00\x00\x80\x7F' @@ -118,7 +142,7 @@ class MessageTest(unittest.TestCase): self.assertTrue(IsPosInf(golden_message.optional_double)) self.assertTrue(IsPosInf(golden_message.repeated_float[0])) self.assertTrue(IsPosInf(golden_message.repeated_double[0])) - self.assertTrue(golden_message.SerializeToString() == golden_data) + self.assertEqual(golden_data, golden_message.SerializeToString()) def testNegativeInfinity(self): golden_data = ('\x5D\x00\x00\x80\xFF' @@ -131,7 +155,7 @@ class MessageTest(unittest.TestCase): self.assertTrue(IsNegInf(golden_message.optional_double)) self.assertTrue(IsNegInf(golden_message.repeated_float[0])) self.assertTrue(IsNegInf(golden_message.repeated_double[0])) - self.assertTrue(golden_message.SerializeToString() == golden_data) + self.assertEqual(golden_data, golden_message.SerializeToString()) def testNotANumber(self): golden_data = ('\x5D\x00\x00\xC0\x7F' @@ -144,7 +168,18 @@ class MessageTest(unittest.TestCase): self.assertTrue(isnan(golden_message.optional_double)) self.assertTrue(isnan(golden_message.repeated_float[0])) self.assertTrue(isnan(golden_message.repeated_double[0])) - self.assertTrue(golden_message.SerializeToString() == golden_data) + + # The protocol buffer may serialize to any one of multiple different + # representations of a NaN. Rather than verify a specific representation, + # verify the serialized string can be converted into a correctly + # behaving protocol buffer. + serialized = golden_message.SerializeToString() + message = unittest_pb2.TestAllTypes() + message.ParseFromString(serialized) + self.assertTrue(isnan(message.optional_float)) + self.assertTrue(isnan(message.optional_double)) + self.assertTrue(isnan(message.repeated_float[0])) + self.assertTrue(isnan(message.repeated_double[0])) def testPositiveInfinityPacked(self): golden_data = ('\xA2\x06\x04\x00\x00\x80\x7F' @@ -153,7 +188,7 @@ class MessageTest(unittest.TestCase): golden_message.ParseFromString(golden_data) self.assertTrue(IsPosInf(golden_message.packed_float[0])) self.assertTrue(IsPosInf(golden_message.packed_double[0])) - self.assertTrue(golden_message.SerializeToString() == golden_data) + self.assertEqual(golden_data, golden_message.SerializeToString()) def testNegativeInfinityPacked(self): golden_data = ('\xA2\x06\x04\x00\x00\x80\xFF' @@ -162,7 +197,7 @@ class MessageTest(unittest.TestCase): golden_message.ParseFromString(golden_data) self.assertTrue(IsNegInf(golden_message.packed_float[0])) self.assertTrue(IsNegInf(golden_message.packed_double[0])) - self.assertTrue(golden_message.SerializeToString() == golden_data) + self.assertEqual(golden_data, golden_message.SerializeToString()) def testNotANumberPacked(self): golden_data = ('\xA2\x06\x04\x00\x00\xC0\x7F' @@ -171,7 +206,12 @@ class MessageTest(unittest.TestCase): golden_message.ParseFromString(golden_data) self.assertTrue(isnan(golden_message.packed_float[0])) self.assertTrue(isnan(golden_message.packed_double[0])) - self.assertTrue(golden_message.SerializeToString() == golden_data) + + serialized = golden_message.SerializeToString() + message = unittest_pb2.TestPackedTypes() + message.ParseFromString(serialized) + self.assertTrue(isnan(message.packed_float[0])) + self.assertTrue(isnan(message.packed_double[0])) def testExtremeFloatValues(self): message = unittest_pb2.TestAllTypes() @@ -218,7 +258,7 @@ class MessageTest(unittest.TestCase): message.ParseFromString(message.SerializeToString()) self.assertTrue(message.optional_float == -kMostNegExponentOneSigBit) - def testExtremeFloatValues(self): + def testExtremeDoubleValues(self): message = unittest_pb2.TestAllTypes() # Most positive exponent, no significand bits set. @@ -338,6 +378,117 @@ class MessageTest(unittest.TestCase): self.assertEqual(message.repeated_nested_message[4].bb, 5) self.assertEqual(message.repeated_nested_message[5].bb, 6) + def testRepeatedCompositeFieldSortArguments(self): + """Check sorting a repeated composite field using list.sort() arguments.""" + message = unittest_pb2.TestAllTypes() + + get_bb = operator.attrgetter('bb') + cmp_bb = lambda a, b: cmp(a.bb, b.bb) + message.repeated_nested_message.add().bb = 1 + message.repeated_nested_message.add().bb = 3 + message.repeated_nested_message.add().bb = 2 + message.repeated_nested_message.add().bb = 6 + message.repeated_nested_message.add().bb = 5 + message.repeated_nested_message.add().bb = 4 + message.repeated_nested_message.sort(key=get_bb) + self.assertEqual([k.bb for k in message.repeated_nested_message], + [1, 2, 3, 4, 5, 6]) + message.repeated_nested_message.sort(key=get_bb, reverse=True) + self.assertEqual([k.bb for k in message.repeated_nested_message], + [6, 5, 4, 3, 2, 1]) + message.repeated_nested_message.sort(sort_function=cmp_bb) + self.assertEqual([k.bb for k in message.repeated_nested_message], + [1, 2, 3, 4, 5, 6]) + message.repeated_nested_message.sort(cmp=cmp_bb, reverse=True) + self.assertEqual([k.bb for k in message.repeated_nested_message], + [6, 5, 4, 3, 2, 1]) + + def testRepeatedScalarFieldSortArguments(self): + """Check sorting a scalar field using list.sort() arguments.""" + message = unittest_pb2.TestAllTypes() + + abs_cmp = lambda a, b: cmp(abs(a), abs(b)) + message.repeated_int32.append(-3) + message.repeated_int32.append(-2) + message.repeated_int32.append(-1) + message.repeated_int32.sort(key=abs) + self.assertEqual(list(message.repeated_int32), [-1, -2, -3]) + message.repeated_int32.sort(key=abs, reverse=True) + self.assertEqual(list(message.repeated_int32), [-3, -2, -1]) + message.repeated_int32.sort(sort_function=abs_cmp) + self.assertEqual(list(message.repeated_int32), [-1, -2, -3]) + message.repeated_int32.sort(cmp=abs_cmp, reverse=True) + self.assertEqual(list(message.repeated_int32), [-3, -2, -1]) + + len_cmp = lambda a, b: cmp(len(a), len(b)) + message.repeated_string.append('aaa') + message.repeated_string.append('bb') + message.repeated_string.append('c') + message.repeated_string.sort(key=len) + self.assertEqual(list(message.repeated_string), ['c', 'bb', 'aaa']) + message.repeated_string.sort(key=len, reverse=True) + self.assertEqual(list(message.repeated_string), ['aaa', 'bb', 'c']) + message.repeated_string.sort(sort_function=len_cmp) + self.assertEqual(list(message.repeated_string), ['c', 'bb', 'aaa']) + message.repeated_string.sort(cmp=len_cmp, reverse=True) + self.assertEqual(list(message.repeated_string), ['aaa', 'bb', 'c']) + + def testParsingMerge(self): + """Check the merge behavior when a required or optional field appears + multiple times in the input.""" + messages = [ + unittest_pb2.TestAllTypes(), + unittest_pb2.TestAllTypes(), + unittest_pb2.TestAllTypes() ] + messages[0].optional_int32 = 1 + messages[1].optional_int64 = 2 + messages[2].optional_int32 = 3 + messages[2].optional_string = 'hello' + + merged_message = unittest_pb2.TestAllTypes() + merged_message.optional_int32 = 3 + merged_message.optional_int64 = 2 + merged_message.optional_string = 'hello' + + generator = unittest_pb2.TestParsingMerge.RepeatedFieldsGenerator() + generator.field1.extend(messages) + generator.field2.extend(messages) + generator.field3.extend(messages) + generator.ext1.extend(messages) + generator.ext2.extend(messages) + generator.group1.add().field1.MergeFrom(messages[0]) + generator.group1.add().field1.MergeFrom(messages[1]) + generator.group1.add().field1.MergeFrom(messages[2]) + generator.group2.add().field1.MergeFrom(messages[0]) + generator.group2.add().field1.MergeFrom(messages[1]) + generator.group2.add().field1.MergeFrom(messages[2]) + + data = generator.SerializeToString() + parsing_merge = unittest_pb2.TestParsingMerge() + parsing_merge.ParseFromString(data) + + # Required and optional fields should be merged. + self.assertEqual(parsing_merge.required_all_types, merged_message) + self.assertEqual(parsing_merge.optional_all_types, merged_message) + self.assertEqual(parsing_merge.optionalgroup.optional_group_all_types, + merged_message) + self.assertEqual(parsing_merge.Extensions[ + unittest_pb2.TestParsingMerge.optional_ext], + merged_message) + + # Repeated fields should not be merged. + self.assertEqual(len(parsing_merge.repeated_all_types), 3) + self.assertEqual(len(parsing_merge.repeatedgroup), 3) + self.assertEqual(len(parsing_merge.Extensions[ + unittest_pb2.TestParsingMerge.repeated_ext]), 3) + + + def testSortEmptyRepeatedCompositeContainer(self): + """Exercise a scenario that has led to segfaults in the past. + """ + m = unittest_pb2.TestAllTypes() + m.repeated_nested_message.sort() + if __name__ == '__main__': unittest.main() diff --git a/python/google/protobuf/internal/more_extensions_dynamic.proto b/python/google/protobuf/internal/more_extensions_dynamic.proto new file mode 100644 index 00000000..df98ac4b --- /dev/null +++ b/python/google/protobuf/internal/more_extensions_dynamic.proto @@ -0,0 +1,49 @@ +// Protocol Buffers - Google's data interchange format +// Copyright 2008 Google Inc. All rights reserved. +// http://code.google.com/p/protobuf/ +// +// Redistribution and use in source and binary forms, with or without +// modification, are permitted provided that the following conditions are +// met: +// +// * Redistributions of source code must retain the above copyright +// notice, this list of conditions and the following disclaimer. +// * Redistributions in binary form must reproduce the above +// copyright notice, this list of conditions and the following disclaimer +// in the documentation and/or other materials provided with the +// distribution. +// * Neither the name of Google Inc. nor the names of its +// contributors may be used to endorse or promote products derived from +// this software without specific prior written permission. +// +// THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS +// "AS IS" AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT +// LIMITED TO, THE IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR +// A PARTICULAR PURPOSE ARE DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT +// OWNER OR CONTRIBUTORS BE LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, +// SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT +// LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, +// DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY +// THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT +// (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE +// OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. + +// Author: jasonh@google.com (Jason Hsueh) +// +// This file is used to test a corner case in the CPP implementation where the +// generated C++ type is available for the extendee, but the extension is +// defined in a file whose C++ type is not in the binary. + + +import "google/protobuf/internal/more_extensions.proto"; + +package google.protobuf.internal; + +message DynamicMessageType { + optional int32 a = 1; +} + +extend ExtendedMessage { + optional int32 dynamic_int32_extension = 100; + optional DynamicMessageType dynamic_message_extension = 101; +} diff --git a/python/google/protobuf/internal/python_message.py b/python/google/protobuf/internal/python_message.py index 66fca918..4bea57ac 100755 --- a/python/google/protobuf/internal/python_message.py +++ b/python/google/protobuf/internal/python_message.py @@ -54,6 +54,7 @@ try: from cStringIO import StringIO except ImportError: from StringIO import StringIO +import copy_reg import struct import weakref @@ -61,6 +62,7 @@ import weakref from google.protobuf.internal import containers from google.protobuf.internal import decoder from google.protobuf.internal import encoder +from google.protobuf.internal import enum_type_wrapper from google.protobuf.internal import message_listener as message_listener_mod from google.protobuf.internal import type_checkers from google.protobuf.internal import wire_format @@ -71,9 +73,10 @@ from google.protobuf import text_format _FieldDescriptor = descriptor_mod.FieldDescriptor -def NewMessage(descriptor, dictionary): +def NewMessage(bases, descriptor, dictionary): _AddClassAttributesForNestedExtensions(descriptor, dictionary) _AddSlots(descriptor, dictionary) + return bases def InitMessage(descriptor, cls): @@ -96,6 +99,7 @@ def InitMessage(descriptor, cls): _AddStaticMethods(cls) _AddMessageMethods(descriptor, cls) _AddPrivateHelperMethods(cls) + copy_reg.pickle(cls, lambda obj: (cls, (), obj.__getstate__())) # Stateless helpers for GeneratedProtocolMessageType below. @@ -145,6 +149,10 @@ def _VerifyExtensionHandle(message, extension_handle): if not extension_handle.is_extension: raise KeyError('"%s" is not an extension.' % extension_handle.full_name) + if not extension_handle.containing_type: + raise KeyError('"%s" is missing a containing_type.' + % extension_handle.full_name) + if extension_handle.containing_type is not message.DESCRIPTOR: raise KeyError('Extension "%s" extends message type "%s", but this ' 'message is of type "%s".' % @@ -164,6 +172,7 @@ def _AddSlots(message_descriptor, dictionary): dictionary['__slots__'] = ['_cached_byte_size', '_cached_byte_size_dirty', '_fields', + '_unknown_fields', '_is_present_in_parent', '_listener', '_listener_for_children', @@ -224,11 +233,14 @@ def _AddClassAttributesForNestedExtensions(descriptor, dictionary): def _AddEnumValues(descriptor, cls): """Sets class-level attributes for all enum fields defined in this message. + Also exporting a class-level object that can name enum values. + Args: descriptor: Descriptor object for this message type. cls: Class we're constructing for this message type. """ for enum_type in descriptor.enum_types: + setattr(cls, enum_type.name, enum_type_wrapper.EnumTypeWrapper(enum_type)) for enum_value in enum_type.values: setattr(cls, enum_value.name, enum_value.number) @@ -248,7 +260,7 @@ def _DefaultValueConstructorForField(field): """ if field.label == _FieldDescriptor.LABEL_REPEATED: - if field.default_value != []: + if field.has_default_value and field.default_value != []: raise ValueError('Repeated field default value not empty list: %s' % ( field.default_value)) if field.cpp_type == _FieldDescriptor.CPPTYPE_MESSAGE: @@ -276,6 +288,8 @@ def _DefaultValueConstructorForField(field): return MakeSubMessageDefault def MakeScalarDefault(message): + # TODO(protobuf-team): This may be broken since there may not be + # default_value. Combine with has_default_value somehow. return field.default_value return MakeScalarDefault @@ -287,6 +301,9 @@ def _AddInitMethod(message_descriptor, cls): self._cached_byte_size = 0 self._cached_byte_size_dirty = len(kwargs) > 0 self._fields = {} + # _unknown_fields is () when empty for efficiency, and will be turned into + # a list if fields are added. + self._unknown_fields = () self._is_present_in_parent = False self._listener = message_listener_mod.NullMessageListener() self._listener_for_children = _Listener(self) @@ -428,6 +445,8 @@ def _AddPropertiesForNonRepeatedScalarField(field, cls): valid_values = set() def getter(self): + # TODO(protobuf-team): This may be broken since there may not be + # default_value. Combine with has_default_value somehow. return self._fields.get(field, default_value) getter.__module__ = None getter.__doc__ = 'Getter for %s.' % proto_field_name @@ -462,13 +481,18 @@ def _AddPropertiesForNonRepeatedCompositeField(field, cls): # for non-repeated scalars. proto_field_name = field.name property_name = _PropertyName(proto_field_name) + + # TODO(komarek): Can anyone explain to me why we cache the message_type this + # way, instead of referring to field.message_type inside of getter(self)? + # What if someone sets message_type later on (which makes for simpler + # dyanmic proto descriptor and class creation code). message_type = field.message_type def getter(self): field_value = self._fields.get(field) if field_value is None: # Construct a new object to represent this field. - field_value = message_type._concrete_class() + field_value = message_type._concrete_class() # use field.message_type? field_value._SetListener(self._listener_for_children) # Atomically check if another thread has preempted us and, if not, swap @@ -620,6 +644,7 @@ def _AddClearMethod(message_descriptor, cls): def Clear(self): # Clear fields. self._fields = {} + self._unknown_fields = () self._Modified() cls.Clear = Clear @@ -649,7 +674,16 @@ def _AddEqualsMethod(message_descriptor, cls): if self is other: return True - return self.ListFields() == other.ListFields() + if not self.ListFields() == other.ListFields(): + return False + + # Sort unknown fields because their order shouldn't affect equality test. + unknown_fields = list(self._unknown_fields) + unknown_fields.sort() + other_unknown_fields = list(other._unknown_fields) + other_unknown_fields.sort() + + return unknown_fields == other_unknown_fields cls.__eq__ = __eq__ @@ -710,6 +744,9 @@ def _AddByteSizeMethod(message_descriptor, cls): for field_descriptor, field_value in self.ListFields(): size += field_descriptor._sizer(field_value) + for tag_bytes, value_bytes in self._unknown_fields: + size += len(tag_bytes) + len(value_bytes) + self._cached_byte_size = size self._cached_byte_size_dirty = False self._listener_for_children.dirty = False @@ -726,8 +763,8 @@ def _AddSerializeToStringMethod(message_descriptor, cls): errors = [] if not self.IsInitialized(): raise message_mod.EncodeError( - 'Message is missing required fields: ' + - ','.join(self.FindInitializationErrors())) + 'Message %s is missing required fields: %s' % ( + self.DESCRIPTOR.full_name, ','.join(self.FindInitializationErrors()))) return self.SerializePartialToString() cls.SerializeToString = SerializeToString @@ -744,6 +781,9 @@ def _AddSerializePartialToStringMethod(message_descriptor, cls): def InternalSerialize(self, write_bytes): for field_descriptor, field_value in self.ListFields(): field_descriptor._encoder(write_bytes, field_value) + for tag_bytes, value_bytes in self._unknown_fields: + write_bytes(tag_bytes) + write_bytes(value_bytes) cls._InternalSerialize = InternalSerialize @@ -770,13 +810,18 @@ def _AddMergeFromStringMethod(message_descriptor, cls): def InternalParse(self, buffer, pos, end): self._Modified() field_dict = self._fields + unknown_field_list = self._unknown_fields while pos != end: (tag_bytes, new_pos) = local_ReadTag(buffer, pos) field_decoder = decoders_by_tag.get(tag_bytes) if field_decoder is None: + value_start_pos = new_pos new_pos = local_SkipField(buffer, new_pos, end, tag_bytes) if new_pos == -1: return pos + if not unknown_field_list: + unknown_field_list = self._unknown_fields = [] + unknown_field_list.append((tag_bytes, buffer[value_start_pos:new_pos])) pos = new_pos else: pos = field_decoder(buffer, new_pos, end, self, field_dict) @@ -873,7 +918,8 @@ def _AddMergeFromMethod(cls): def MergeFrom(self, msg): if not isinstance(msg, cls): raise TypeError( - "Parameter to MergeFrom() must be instance of same class.") + "Parameter to MergeFrom() must be instance of same class: " + "expected %s got %s." % (cls.__name__, type(msg).__name__)) assert msg is not self self._Modified() @@ -898,6 +944,12 @@ def _AddMergeFromMethod(cls): field_value.MergeFrom(value) else: self._fields[field] = value + + if msg._unknown_fields: + if not self._unknown_fields: + self._unknown_fields = [] + self._unknown_fields.extend(msg._unknown_fields) + cls.MergeFrom = MergeFrom diff --git a/python/google/protobuf/internal/reflection_cpp_generated_test.py b/python/google/protobuf/internal/reflection_cpp_generated_test.py new file mode 100755 index 00000000..2a0a5124 --- /dev/null +++ b/python/google/protobuf/internal/reflection_cpp_generated_test.py @@ -0,0 +1,91 @@ +#! /usr/bin/python +# -*- coding: utf-8 -*- +# +# Protocol Buffers - Google's data interchange format +# Copyright 2008 Google Inc. All rights reserved. +# http://code.google.com/p/protobuf/ +# +# Redistribution and use in source and binary forms, with or without +# modification, are permitted provided that the following conditions are +# met: +# +# * Redistributions of source code must retain the above copyright +# notice, this list of conditions and the following disclaimer. +# * Redistributions in binary form must reproduce the above +# copyright notice, this list of conditions and the following disclaimer +# in the documentation and/or other materials provided with the +# distribution. +# * Neither the name of Google Inc. nor the names of its +# contributors may be used to endorse or promote products derived from +# this software without specific prior written permission. +# +# THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS +# "AS IS" AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT +# LIMITED TO, THE IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR +# A PARTICULAR PURPOSE ARE DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT +# OWNER OR CONTRIBUTORS BE LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, +# SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT +# LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, +# DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY +# THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT +# (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE +# OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. + +"""Unittest for reflection.py, which tests the generated C++ implementation.""" + +__author__ = 'jasonh@google.com (Jason Hsueh)' + +import os +os.environ['PROTOCOL_BUFFERS_PYTHON_IMPLEMENTATION'] = 'cpp' + +import unittest +from google.protobuf.internal import api_implementation +from google.protobuf.internal import more_extensions_dynamic_pb2 +from google.protobuf.internal import more_extensions_pb2 +from google.protobuf.internal.reflection_test import * + + +class ReflectionCppTest(unittest.TestCase): + def testImplementationSetting(self): + self.assertEqual('cpp', api_implementation.Type()) + + def testExtensionOfGeneratedTypeInDynamicFile(self): + """Tests that a file built dynamically can extend a generated C++ type. + + The C++ implementation uses a DescriptorPool that has the generated + DescriptorPool as an underlay. Typically, a type can only find + extensions in its own pool. With the python C-extension, the generated C++ + extendee may be available, but not the extension. This tests that the + C-extension implements the correct special handling to make such extensions + available. + """ + pb1 = more_extensions_pb2.ExtendedMessage() + # Test that basic accessors work. + self.assertFalse( + pb1.HasExtension(more_extensions_dynamic_pb2.dynamic_int32_extension)) + self.assertFalse( + pb1.HasExtension(more_extensions_dynamic_pb2.dynamic_message_extension)) + pb1.Extensions[more_extensions_dynamic_pb2.dynamic_int32_extension] = 17 + pb1.Extensions[more_extensions_dynamic_pb2.dynamic_message_extension].a = 24 + self.assertTrue( + pb1.HasExtension(more_extensions_dynamic_pb2.dynamic_int32_extension)) + self.assertTrue( + pb1.HasExtension(more_extensions_dynamic_pb2.dynamic_message_extension)) + + # Now serialize the data and parse to a new message. + pb2 = more_extensions_pb2.ExtendedMessage() + pb2.MergeFromString(pb1.SerializeToString()) + + self.assertTrue( + pb2.HasExtension(more_extensions_dynamic_pb2.dynamic_int32_extension)) + self.assertTrue( + pb2.HasExtension(more_extensions_dynamic_pb2.dynamic_message_extension)) + self.assertEqual( + 17, pb2.Extensions[more_extensions_dynamic_pb2.dynamic_int32_extension]) + self.assertEqual( + 24, + pb2.Extensions[more_extensions_dynamic_pb2.dynamic_message_extension].a) + + +if __name__ == '__main__': + unittest.main() diff --git a/python/google/protobuf/internal/reflection_test.py b/python/google/protobuf/internal/reflection_test.py index 7b9d3398..ed286461 100755 --- a/python/google/protobuf/internal/reflection_test.py +++ b/python/google/protobuf/internal/reflection_test.py @@ -37,6 +37,7 @@ pure-Python protocol compiler. __author__ = 'robinson@google.com (Will Robinson)' +import gc import operator import struct @@ -318,15 +319,6 @@ class ReflectionTest(unittest.TestCase): # ...and ensure that the scalar field has returned to its default. self.assertEqual(0, getattr(composite_field, scalar_field_name)) - # Finally, ensure that modifications to the old composite field object - # don't have any effect on the parent. Possible only with the pure-python - # implementation of the API. - # - # (NOTE that when we clear the composite field in the parent, we actually - # don't recursively clear down the tree. Instead, we just disconnect the - # cleared composite from the tree.) - if api_implementation.Type() != 'python': - return self.assertTrue(old_composite_field is not composite_field) setattr(old_composite_field, scalar_field_name, new_val) self.assertTrue(not composite_field.HasField(scalar_field_name)) @@ -348,8 +340,6 @@ class ReflectionTest(unittest.TestCase): nested.bb = 23 def testDisconnectingNestedMessageBeforeSettingField(self): - if api_implementation.Type() != 'python': - return proto = unittest_pb2.TestAllTypes() nested = proto.optional_nested_message proto.ClearField('optional_nested_message') # Should disconnect from parent @@ -358,6 +348,64 @@ class ReflectionTest(unittest.TestCase): self.assertTrue(not proto.HasField('optional_nested_message')) self.assertEqual(0, proto.optional_nested_message.bb) + def testGetDefaultMessageAfterDisconnectingDefaultMessage(self): + proto = unittest_pb2.TestAllTypes() + nested = proto.optional_nested_message + proto.ClearField('optional_nested_message') + del proto + del nested + # Force a garbage collect so that the underlying CMessages are freed along + # with the Messages they point to. This is to make sure we're not deleting + # default message instances. + gc.collect() + proto = unittest_pb2.TestAllTypes() + nested = proto.optional_nested_message + + def testDisconnectingNestedMessageAfterSettingField(self): + proto = unittest_pb2.TestAllTypes() + nested = proto.optional_nested_message + nested.bb = 5 + self.assertTrue(proto.HasField('optional_nested_message')) + proto.ClearField('optional_nested_message') # Should disconnect from parent + self.assertEqual(5, nested.bb) + self.assertEqual(0, proto.optional_nested_message.bb) + self.assertTrue(nested is not proto.optional_nested_message) + nested.bb = 23 + self.assertTrue(not proto.HasField('optional_nested_message')) + self.assertEqual(0, proto.optional_nested_message.bb) + + def testDisconnectingNestedMessageBeforeGettingField(self): + proto = unittest_pb2.TestAllTypes() + self.assertTrue(not proto.HasField('optional_nested_message')) + proto.ClearField('optional_nested_message') + self.assertTrue(not proto.HasField('optional_nested_message')) + + def testDisconnectingNestedMessageAfterMerge(self): + # This test exercises the code path that does not use ReleaseMessage(). + # The underlying fear is that if we use ReleaseMessage() incorrectly, + # we will have memory leaks. It's hard to check that that doesn't happen, + # but at least we can exercise that code path to make sure it works. + proto1 = unittest_pb2.TestAllTypes() + proto2 = unittest_pb2.TestAllTypes() + proto2.optional_nested_message.bb = 5 + proto1.MergeFrom(proto2) + self.assertTrue(proto1.HasField('optional_nested_message')) + proto1.ClearField('optional_nested_message') + self.assertTrue(not proto1.HasField('optional_nested_message')) + + def testDisconnectingLazyNestedMessage(self): + # This test exercises releasing a nested message that is lazy. This test + # only exercises real code in the C++ implementation as Python does not + # support lazy parsing, but the current C++ implementation results in + # memory corruption and a crash. + if api_implementation.Type() != 'python': + return + proto = unittest_pb2.TestAllTypes() + proto.optional_lazy_message.bb = 5 + proto.ClearField('optional_lazy_message') + del proto + gc.collect() + def testHasBitsWhenModifyingRepeatedFields(self): # Test nesting when we add an element to a repeated field in a submessage. proto = unittest_pb2.TestNestedMessageHasBits() @@ -635,6 +683,77 @@ class ReflectionTest(unittest.TestCase): self.assertEqual(3, proto.BAZ) self.assertEqual(3, unittest_pb2.TestAllTypes.BAZ) + def testEnum_Name(self): + self.assertEqual('FOREIGN_FOO', + unittest_pb2.ForeignEnum.Name(unittest_pb2.FOREIGN_FOO)) + self.assertEqual('FOREIGN_BAR', + unittest_pb2.ForeignEnum.Name(unittest_pb2.FOREIGN_BAR)) + self.assertEqual('FOREIGN_BAZ', + unittest_pb2.ForeignEnum.Name(unittest_pb2.FOREIGN_BAZ)) + self.assertRaises(ValueError, + unittest_pb2.ForeignEnum.Name, 11312) + + proto = unittest_pb2.TestAllTypes() + self.assertEqual('FOO', + proto.NestedEnum.Name(proto.FOO)) + self.assertEqual('FOO', + unittest_pb2.TestAllTypes.NestedEnum.Name(proto.FOO)) + self.assertEqual('BAR', + proto.NestedEnum.Name(proto.BAR)) + self.assertEqual('BAR', + unittest_pb2.TestAllTypes.NestedEnum.Name(proto.BAR)) + self.assertEqual('BAZ', + proto.NestedEnum.Name(proto.BAZ)) + self.assertEqual('BAZ', + unittest_pb2.TestAllTypes.NestedEnum.Name(proto.BAZ)) + self.assertRaises(ValueError, + proto.NestedEnum.Name, 11312) + self.assertRaises(ValueError, + unittest_pb2.TestAllTypes.NestedEnum.Name, 11312) + + def testEnum_Value(self): + self.assertEqual(unittest_pb2.FOREIGN_FOO, + unittest_pb2.ForeignEnum.Value('FOREIGN_FOO')) + self.assertEqual(unittest_pb2.FOREIGN_BAR, + unittest_pb2.ForeignEnum.Value('FOREIGN_BAR')) + self.assertEqual(unittest_pb2.FOREIGN_BAZ, + unittest_pb2.ForeignEnum.Value('FOREIGN_BAZ')) + self.assertRaises(ValueError, + unittest_pb2.ForeignEnum.Value, 'FO') + + proto = unittest_pb2.TestAllTypes() + self.assertEqual(proto.FOO, + proto.NestedEnum.Value('FOO')) + self.assertEqual(proto.FOO, + unittest_pb2.TestAllTypes.NestedEnum.Value('FOO')) + self.assertEqual(proto.BAR, + proto.NestedEnum.Value('BAR')) + self.assertEqual(proto.BAR, + unittest_pb2.TestAllTypes.NestedEnum.Value('BAR')) + self.assertEqual(proto.BAZ, + proto.NestedEnum.Value('BAZ')) + self.assertEqual(proto.BAZ, + unittest_pb2.TestAllTypes.NestedEnum.Value('BAZ')) + self.assertRaises(ValueError, + proto.NestedEnum.Value, 'Foo') + self.assertRaises(ValueError, + unittest_pb2.TestAllTypes.NestedEnum.Value, 'Foo') + + def testEnum_KeysAndValues(self): + self.assertEqual(['FOREIGN_FOO', 'FOREIGN_BAR', 'FOREIGN_BAZ'], + unittest_pb2.ForeignEnum.keys()) + self.assertEqual([4, 5, 6], + unittest_pb2.ForeignEnum.values()) + self.assertEqual([('FOREIGN_FOO', 4), ('FOREIGN_BAR', 5), + ('FOREIGN_BAZ', 6)], + unittest_pb2.ForeignEnum.items()) + + proto = unittest_pb2.TestAllTypes() + self.assertEqual(['FOO', 'BAR', 'BAZ'], proto.NestedEnum.keys()) + self.assertEqual([1, 2, 3], proto.NestedEnum.values()) + self.assertEqual([('FOO', 1), ('BAR', 2), ('BAZ', 3)], + proto.NestedEnum.items()) + def testRepeatedScalars(self): proto = unittest_pb2.TestAllTypes() @@ -826,6 +945,35 @@ class ReflectionTest(unittest.TestCase): self.assertEqual(1, len(proto.repeated_nested_message)) self.assertEqual(23, proto.repeated_nested_message[0].bb) + def testRepeatedCompositeRemove(self): + proto = unittest_pb2.TestAllTypes() + + self.assertEqual(0, len(proto.repeated_nested_message)) + m0 = proto.repeated_nested_message.add() + # Need to set some differentiating variable so m0 != m1 != m2: + m0.bb = len(proto.repeated_nested_message) + m1 = proto.repeated_nested_message.add() + m1.bb = len(proto.repeated_nested_message) + self.assertTrue(m0 != m1) + m2 = proto.repeated_nested_message.add() + m2.bb = len(proto.repeated_nested_message) + self.assertListsEqual([m0, m1, m2], proto.repeated_nested_message) + + self.assertEqual(3, len(proto.repeated_nested_message)) + proto.repeated_nested_message.remove(m0) + self.assertEqual(2, len(proto.repeated_nested_message)) + self.assertEqual(m1, proto.repeated_nested_message[0]) + self.assertEqual(m2, proto.repeated_nested_message[1]) + + # Removing m0 again or removing None should raise error + self.assertRaises(ValueError, proto.repeated_nested_message.remove, m0) + self.assertRaises(ValueError, proto.repeated_nested_message.remove, None) + self.assertEqual(2, len(proto.repeated_nested_message)) + + proto.repeated_nested_message.remove(m2) + self.assertEqual(1, len(proto.repeated_nested_message)) + self.assertEqual(m1, proto.repeated_nested_message[0]) + def testHandWrittenReflection(self): # Hand written extensions are only supported by the pure-Python # implementation of the API. @@ -856,6 +1004,68 @@ class ReflectionTest(unittest.TestCase): self.assertEqual(23, myproto_instance.foo_field) self.assertTrue(myproto_instance.HasField('foo_field')) + def testDescriptorProtoSupport(self): + # Hand written descriptors/reflection are only supported by the pure-Python + # implementation of the API. + if api_implementation.Type() != 'python': + return + + def AddDescriptorField(proto, field_name, field_type): + AddDescriptorField.field_index += 1 + new_field = proto.field.add() + new_field.name = field_name + new_field.type = field_type + new_field.number = AddDescriptorField.field_index + new_field.label = descriptor_pb2.FieldDescriptorProto.LABEL_OPTIONAL + + AddDescriptorField.field_index = 0 + + desc_proto = descriptor_pb2.DescriptorProto() + desc_proto.name = 'Car' + fdp = descriptor_pb2.FieldDescriptorProto + AddDescriptorField(desc_proto, 'name', fdp.TYPE_STRING) + AddDescriptorField(desc_proto, 'year', fdp.TYPE_INT64) + AddDescriptorField(desc_proto, 'automatic', fdp.TYPE_BOOL) + AddDescriptorField(desc_proto, 'price', fdp.TYPE_DOUBLE) + # Add a repeated field + AddDescriptorField.field_index += 1 + new_field = desc_proto.field.add() + new_field.name = 'owners' + new_field.type = fdp.TYPE_STRING + new_field.number = AddDescriptorField.field_index + new_field.label = descriptor_pb2.FieldDescriptorProto.LABEL_REPEATED + + desc = descriptor.MakeDescriptor(desc_proto) + self.assertTrue(desc.fields_by_name.has_key('name')) + self.assertTrue(desc.fields_by_name.has_key('year')) + self.assertTrue(desc.fields_by_name.has_key('automatic')) + self.assertTrue(desc.fields_by_name.has_key('price')) + self.assertTrue(desc.fields_by_name.has_key('owners')) + + class CarMessage(message.Message): + __metaclass__ = reflection.GeneratedProtocolMessageType + DESCRIPTOR = desc + + prius = CarMessage() + prius.name = 'prius' + prius.year = 2010 + prius.automatic = True + prius.price = 25134.75 + prius.owners.extend(['bob', 'susan']) + + serialized_prius = prius.SerializeToString() + new_prius = reflection.ParseMessage(desc, serialized_prius) + self.assertTrue(new_prius is not prius) + self.assertEqual(prius, new_prius) + + # these are unnecessary assuming message equality works as advertised but + # explicitly check to be safe since we're mucking about in metaclass foo + self.assertEqual(prius.name, new_prius.name) + self.assertEqual(prius.year, new_prius.year) + self.assertEqual(prius.automatic, new_prius.automatic) + self.assertEqual(prius.price, new_prius.price) + self.assertEqual(prius.owners, new_prius.owners) + def testTopLevelExtensionsForOptionalScalar(self): extendee_proto = unittest_pb2.TestAllExtensions() extension = unittest_pb2.optional_int32_extension @@ -1243,7 +1453,12 @@ class ReflectionTest(unittest.TestCase): def testClear(self): proto = unittest_pb2.TestAllTypes() - test_util.SetAllFields(proto) + # C++ implementation does not support lazy fields right now so leave it + # out for now. + if api_implementation.Type() == 'python': + test_util.SetAllFields(proto) + else: + test_util.SetAllNonLazyFields(proto) # Clear the message. proto.Clear() self.assertEquals(proto.ByteSize(), 0) @@ -1259,6 +1474,33 @@ class ReflectionTest(unittest.TestCase): empty_proto = unittest_pb2.TestAllExtensions() self.assertEquals(proto, empty_proto) + def testDisconnectingBeforeClear(self): + proto = unittest_pb2.TestAllTypes() + nested = proto.optional_nested_message + proto.Clear() + self.assertTrue(nested is not proto.optional_nested_message) + nested.bb = 23 + self.assertTrue(not proto.HasField('optional_nested_message')) + self.assertEqual(0, proto.optional_nested_message.bb) + + proto = unittest_pb2.TestAllTypes() + nested = proto.optional_nested_message + nested.bb = 5 + foreign = proto.optional_foreign_message + foreign.c = 6 + + proto.Clear() + self.assertTrue(nested is not proto.optional_nested_message) + self.assertTrue(foreign is not proto.optional_foreign_message) + self.assertEqual(5, nested.bb) + self.assertEqual(6, foreign.c) + nested.bb = 15 + foreign.c = 16 + self.assertTrue(not proto.HasField('optional_nested_message')) + self.assertEqual(0, proto.optional_nested_message.bb) + self.assertTrue(not proto.HasField('optional_foreign_message')) + self.assertEqual(0, proto.optional_foreign_message.c) + def assertInitialized(self, proto): self.assertTrue(proto.IsInitialized()) # Neither method should raise an exception. @@ -1408,7 +1650,7 @@ class ReflectionTest(unittest.TestCase): unicode_decode_failed = False try: message2.MergeFromString(bytes) - except UnicodeDecodeError, e: + except UnicodeDecodeError as e: unicode_decode_failed = True string_field = message2.str self.assertTrue(unicode_decode_failed or type(string_field) == str) @@ -2119,7 +2361,7 @@ class SerializationTest(unittest.TestCase): """This method checks if the excpetion type and message are as expected.""" try: callable_obj() - except exc_class, ex: + except exc_class as ex: # Check if the exception message is the right one. self.assertEqual(exception, str(ex)) return @@ -2131,15 +2373,22 @@ class SerializationTest(unittest.TestCase): self._CheckRaises( message.EncodeError, proto.SerializeToString, - 'Message is missing required fields: a,b,c') + 'Message protobuf_unittest.TestRequired is missing required fields: ' + 'a,b,c') # Shouldn't raise exceptions. partial = proto.SerializePartialToString() + proto2 = unittest_pb2.TestRequired() + self.assertFalse(proto2.HasField('a')) + # proto2 ParseFromString does not check that required fields are set. + proto2.ParseFromString(partial) + self.assertFalse(proto2.HasField('a')) + proto.a = 1 self._CheckRaises( message.EncodeError, proto.SerializeToString, - 'Message is missing required fields: b,c') + 'Message protobuf_unittest.TestRequired is missing required fields: b,c') # Shouldn't raise exceptions. partial = proto.SerializePartialToString() @@ -2147,7 +2396,7 @@ class SerializationTest(unittest.TestCase): self._CheckRaises( message.EncodeError, proto.SerializeToString, - 'Message is missing required fields: c') + 'Message protobuf_unittest.TestRequired is missing required fields: c') # Shouldn't raise exceptions. partial = proto.SerializePartialToString() @@ -2176,7 +2425,8 @@ class SerializationTest(unittest.TestCase): self._CheckRaises( message.EncodeError, proto.SerializeToString, - 'Message is missing required fields: ' + 'Message protobuf_unittest.TestRequiredForeign ' + 'is missing required fields: ' 'optional_message.b,optional_message.c') proto.optional_message.b = 2 @@ -2188,7 +2438,7 @@ class SerializationTest(unittest.TestCase): self._CheckRaises( message.EncodeError, proto.SerializeToString, - 'Message is missing required fields: ' + 'Message protobuf_unittest.TestRequiredForeign is missing required fields: ' 'repeated_message[0].b,repeated_message[0].c,' 'repeated_message[1].a,repeated_message[1].c') diff --git a/python/google/protobuf/internal/test_bad_identifiers.proto b/python/google/protobuf/internal/test_bad_identifiers.proto new file mode 100644 index 00000000..6a82299a --- /dev/null +++ b/python/google/protobuf/internal/test_bad_identifiers.proto @@ -0,0 +1,52 @@ +// Protocol Buffers - Google's data interchange format +// Copyright 2008 Google Inc. All rights reserved. +// http://code.google.com/p/protobuf/ +// +// Redistribution and use in source and binary forms, with or without +// modification, are permitted provided that the following conditions are +// met: +// +// * Redistributions of source code must retain the above copyright +// notice, this list of conditions and the following disclaimer. +// * Redistributions in binary form must reproduce the above +// copyright notice, this list of conditions and the following disclaimer +// in the documentation and/or other materials provided with the +// distribution. +// * Neither the name of Google Inc. nor the names of its +// contributors may be used to endorse or promote products derived from +// this software without specific prior written permission. +// +// THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS +// "AS IS" AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT +// LIMITED TO, THE IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR +// A PARTICULAR PURPOSE ARE DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT +// OWNER OR CONTRIBUTORS BE LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, +// SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT +// LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, +// DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY +// THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT +// (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE +// OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. + +// Author: kenton@google.com (Kenton Varda) + + +package protobuf_unittest; + +option py_generic_services = true; + +message TestBadIdentifiers { + extensions 100 to max; +} + +// Make sure these reasonable extension names don't conflict with internal +// variables. +extend TestBadIdentifiers { + optional string message = 100 [default="foo"]; + optional string descriptor = 101 [default="bar"]; + optional string reflection = 102 [default="baz"]; + optional string service = 103 [default="qux"]; +} + +message AnotherMessage {} +service AnotherService {} diff --git a/python/google/protobuf/internal/test_util.py b/python/google/protobuf/internal/test_util.py index 1df16194..be8ae7be 100755 --- a/python/google/protobuf/internal/test_util.py +++ b/python/google/protobuf/internal/test_util.py @@ -42,8 +42,8 @@ from google.protobuf import unittest_import_pb2 from google.protobuf import unittest_pb2 -def SetAllFields(message): - """Sets every field in the message to a unique value. +def SetAllNonLazyFields(message): + """Sets every non-lazy field in the message to a unique value. Args: message: A unittest_pb2.TestAllTypes instance. @@ -79,6 +79,7 @@ def SetAllFields(message): message.optional_nested_message.bb = 118 message.optional_foreign_message.c = 119 message.optional_import_message.d = 120 + message.optional_public_import_message.e = 126 message.optional_nested_enum = unittest_pb2.TestAllTypes.BAZ message.optional_foreign_enum = unittest_pb2.FOREIGN_BAZ @@ -111,6 +112,7 @@ def SetAllFields(message): message.repeated_nested_message.add().bb = 218 message.repeated_foreign_message.add().c = 219 message.repeated_import_message.add().d = 220 + message.repeated_lazy_message.add().bb = 227 message.repeated_nested_enum.append(unittest_pb2.TestAllTypes.BAR) message.repeated_foreign_enum.append(unittest_pb2.FOREIGN_BAR) @@ -140,6 +142,7 @@ def SetAllFields(message): message.repeated_nested_message.add().bb = 318 message.repeated_foreign_message.add().c = 319 message.repeated_import_message.add().d = 320 + message.repeated_lazy_message.add().bb = 327 message.repeated_nested_enum.append(unittest_pb2.TestAllTypes.BAZ) message.repeated_foreign_enum.append(unittest_pb2.FOREIGN_BAZ) @@ -176,6 +179,11 @@ def SetAllFields(message): message.default_cord = '425' +def SetAllFields(message): + SetAllNonLazyFields(message) + message.optional_lazy_message.bb = 127 + + def SetAllExtensions(message): """Sets every extension in the message to a unique value. @@ -211,6 +219,8 @@ def SetAllExtensions(message): extensions[pb2.optional_nested_message_extension].bb = 118 extensions[pb2.optional_foreign_message_extension].c = 119 extensions[pb2.optional_import_message_extension].d = 120 + extensions[pb2.optional_public_import_message_extension].e = 126 + extensions[pb2.optional_lazy_message_extension].bb = 127 extensions[pb2.optional_nested_enum_extension] = pb2.TestAllTypes.BAZ extensions[pb2.optional_nested_enum_extension] = pb2.TestAllTypes.BAZ @@ -244,6 +254,7 @@ def SetAllExtensions(message): extensions[pb2.repeated_nested_message_extension].add().bb = 218 extensions[pb2.repeated_foreign_message_extension].add().c = 219 extensions[pb2.repeated_import_message_extension].add().d = 220 + extensions[pb2.repeated_lazy_message_extension].add().bb = 227 extensions[pb2.repeated_nested_enum_extension].append(pb2.TestAllTypes.BAR) extensions[pb2.repeated_foreign_enum_extension].append(pb2.FOREIGN_BAR) @@ -273,6 +284,7 @@ def SetAllExtensions(message): extensions[pb2.repeated_nested_message_extension].add().bb = 318 extensions[pb2.repeated_foreign_message_extension].add().c = 319 extensions[pb2.repeated_import_message_extension].add().d = 320 + extensions[pb2.repeated_lazy_message_extension].add().bb = 327 extensions[pb2.repeated_nested_enum_extension].append(pb2.TestAllTypes.BAZ) extensions[pb2.repeated_foreign_enum_extension].append(pb2.FOREIGN_BAZ) @@ -407,6 +419,8 @@ def ExpectAllFieldsSet(test_case, message): test_case.assertEqual(118, message.optional_nested_message.bb) test_case.assertEqual(119, message.optional_foreign_message.c) test_case.assertEqual(120, message.optional_import_message.d) + test_case.assertEqual(126, message.optional_public_import_message.e) + test_case.assertEqual(127, message.optional_lazy_message.bb) test_case.assertEqual(unittest_pb2.TestAllTypes.BAZ, message.optional_nested_enum) @@ -464,6 +478,7 @@ def ExpectAllFieldsSet(test_case, message): test_case.assertEqual(218, message.repeated_nested_message[0].bb) test_case.assertEqual(219, message.repeated_foreign_message[0].c) test_case.assertEqual(220, message.repeated_import_message[0].d) + test_case.assertEqual(227, message.repeated_lazy_message[0].bb) test_case.assertEqual(unittest_pb2.TestAllTypes.BAR, message.repeated_nested_enum[0]) @@ -492,6 +507,7 @@ def ExpectAllFieldsSet(test_case, message): test_case.assertEqual(318, message.repeated_nested_message[1].bb) test_case.assertEqual(319, message.repeated_foreign_message[1].c) test_case.assertEqual(320, message.repeated_import_message[1].d) + test_case.assertEqual(327, message.repeated_lazy_message[1].bb) test_case.assertEqual(unittest_pb2.TestAllTypes.BAZ, message.repeated_nested_enum[1]) diff --git a/python/google/protobuf/internal/text_format_test.py b/python/google/protobuf/internal/text_format_test.py index 73d97d18..23b50eb5 100755 --- a/python/google/protobuf/internal/text_format_test.py +++ b/python/google/protobuf/internal/text_format_test.py @@ -94,6 +94,28 @@ class TextFormatTest(unittest.TestCase): ' }\n' '}\n') + def testPrintBadEnumValue(self): + message = unittest_pb2.TestAllTypes() + message.optional_nested_enum = 100 + message.optional_foreign_enum = 101 + message.optional_import_enum = 102 + self.CompareToGoldenText( + text_format.MessageToString(message), + 'optional_nested_enum: 100\n' + 'optional_foreign_enum: 101\n' + 'optional_import_enum: 102\n') + + def testPrintBadEnumValueExtensions(self): + message = unittest_pb2.TestAllExtensions() + message.Extensions[unittest_pb2.optional_nested_enum_extension] = 100 + message.Extensions[unittest_pb2.optional_foreign_enum_extension] = 101 + message.Extensions[unittest_pb2.optional_import_enum_extension] = 102 + self.CompareToGoldenText( + text_format.MessageToString(message), + '[protobuf_unittest.optional_nested_enum_extension]: 100\n' + '[protobuf_unittest.optional_foreign_enum_extension]: 101\n' + '[protobuf_unittest.optional_import_enum_extension]: 102\n') + def testPrintExotic(self): message = unittest_pb2.TestAllTypes() message.repeated_int64.append(-9223372036854775808) @@ -399,6 +421,14 @@ class TextFormatTest(unittest.TestCase): 'has no value with number 100.'), text_format.Merge, text, message) + def testMergeBadIntValue(self): + message = unittest_pb2.TestAllTypes() + text = 'optional_int32: bork' + self.assertRaisesWithMessage( + text_format.ParseError, + ('1:17 : Couldn\'t parse integer: bork'), + text_format.Merge, text, message) + def assertRaisesWithMessage(self, e_class, e, func, *args, **kwargs): """Same as assertRaises, but also compares the exception message.""" if hasattr(e_class, '__name__'): @@ -408,7 +438,7 @@ class TextFormatTest(unittest.TestCase): try: func(*args, **kwargs) - except e_class, expr: + except e_class as expr: if str(expr) != e: msg = '%s raised, but with wrong message: "%s" instead of "%s"' raise self.failureException(msg % (exc_name, @@ -427,7 +457,7 @@ class TokenizerTest(unittest.TestCase): 'identifiER_4 : 1.1e+2 ID5:-0.23 ID6:\'aaaa\\\'bbbb\'\n' 'ID7 : "aa\\"bb"\n\n\n\n ID8: {A:inf B:-inf C:true D:false}\n' 'ID9: 22 ID10: -111111111111111111 ID11: -22\n' - 'ID12: 2222222222222222222 ' + 'ID12: 2222222222222222222 ID13: 1.23456f ID14: 1.2e+2f ' 'false_bool: 0 true_BOOL:t \n true_bool1: 1 false_BOOL1:f ' ) tokenizer = text_format._Tokenizer(text) methods = [(tokenizer.ConsumeIdentifier, 'identifier1'), @@ -456,10 +486,10 @@ class TokenizerTest(unittest.TestCase): '{', (tokenizer.ConsumeIdentifier, 'A'), ':', - (tokenizer.ConsumeFloat, text_format._INFINITY), + (tokenizer.ConsumeFloat, float('inf')), (tokenizer.ConsumeIdentifier, 'B'), ':', - (tokenizer.ConsumeFloat, -text_format._INFINITY), + (tokenizer.ConsumeFloat, -float('inf')), (tokenizer.ConsumeIdentifier, 'C'), ':', (tokenizer.ConsumeBool, True), @@ -479,6 +509,12 @@ class TokenizerTest(unittest.TestCase): (tokenizer.ConsumeIdentifier, 'ID12'), ':', (tokenizer.ConsumeUint64, 2222222222222222222), + (tokenizer.ConsumeIdentifier, 'ID13'), + ':', + (tokenizer.ConsumeFloat, 1.23456), + (tokenizer.ConsumeIdentifier, 'ID14'), + ':', + (tokenizer.ConsumeFloat, 1.2e+2), (tokenizer.ConsumeIdentifier, 'false_bool'), ':', (tokenizer.ConsumeBool, False), @@ -556,16 +592,6 @@ class TokenizerTest(unittest.TestCase): tokenizer = text_format._Tokenizer(text) self.assertRaises(text_format.ParseError, tokenizer.ConsumeBool) - def testInfNan(self): - # Make sure our infinity and NaN definitions are sound. - self.assertEquals(float, type(text_format._INFINITY)) - self.assertEquals(float, type(text_format._NAN)) - self.assertTrue(text_format._NAN != text_format._NAN) - - inf_times_zero = text_format._INFINITY * 0 - self.assertTrue(inf_times_zero != inf_times_zero) - self.assertTrue(text_format._INFINITY > 0) - if __name__ == '__main__': unittest.main() diff --git a/python/google/protobuf/internal/unknown_fields_test.py b/python/google/protobuf/internal/unknown_fields_test.py new file mode 100755 index 00000000..84984b40 --- /dev/null +++ b/python/google/protobuf/internal/unknown_fields_test.py @@ -0,0 +1,170 @@ +#! /usr/bin/python +# -*- coding: utf-8 -*- +# +# Protocol Buffers - Google's data interchange format +# Copyright 2008 Google Inc. All rights reserved. +# http://code.google.com/p/protobuf/ +# +# Redistribution and use in source and binary forms, with or without +# modification, are permitted provided that the following conditions are +# met: +# +# * Redistributions of source code must retain the above copyright +# notice, this list of conditions and the following disclaimer. +# * Redistributions in binary form must reproduce the above +# copyright notice, this list of conditions and the following disclaimer +# in the documentation and/or other materials provided with the +# distribution. +# * Neither the name of Google Inc. nor the names of its +# contributors may be used to endorse or promote products derived from +# this software without specific prior written permission. +# +# THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS +# "AS IS" AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT +# LIMITED TO, THE IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR +# A PARTICULAR PURPOSE ARE DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT +# OWNER OR CONTRIBUTORS BE LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, +# SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT +# LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, +# DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY +# THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT +# (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE +# OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. + +"""Test for preservation of unknown fields in the pure Python implementation.""" + +__author__ = 'bohdank@google.com (Bohdan Koval)' + +import unittest +from google.protobuf import unittest_mset_pb2 +from google.protobuf import unittest_pb2 +from google.protobuf.internal import encoder +from google.protobuf.internal import test_util +from google.protobuf.internal import type_checkers + + +class UnknownFieldsTest(unittest.TestCase): + + def setUp(self): + self.descriptor = unittest_pb2.TestAllTypes.DESCRIPTOR + self.all_fields = unittest_pb2.TestAllTypes() + test_util.SetAllFields(self.all_fields) + self.all_fields_data = self.all_fields.SerializeToString() + self.empty_message = unittest_pb2.TestEmptyMessage() + self.empty_message.ParseFromString(self.all_fields_data) + self.unknown_fields = self.empty_message._unknown_fields + + def GetField(self, name): + field_descriptor = self.descriptor.fields_by_name[name] + wire_type = type_checkers.FIELD_TYPE_TO_WIRE_TYPE[field_descriptor.type] + field_tag = encoder.TagBytes(field_descriptor.number, wire_type) + for tag_bytes, value in self.unknown_fields: + if tag_bytes == field_tag: + decoder = unittest_pb2.TestAllTypes._decoders_by_tag[tag_bytes] + result_dict = {} + decoder(value, 0, len(value), self.all_fields, result_dict) + return result_dict[field_descriptor] + + def testVarint(self): + value = self.GetField('optional_int32') + self.assertEqual(self.all_fields.optional_int32, value) + + def testFixed32(self): + value = self.GetField('optional_fixed32') + self.assertEqual(self.all_fields.optional_fixed32, value) + + def testFixed64(self): + value = self.GetField('optional_fixed64') + self.assertEqual(self.all_fields.optional_fixed64, value) + + def testLengthDelimited(self): + value = self.GetField('optional_string') + self.assertEqual(self.all_fields.optional_string, value) + + def testGroup(self): + value = self.GetField('optionalgroup') + self.assertEqual(self.all_fields.optionalgroup, value) + + def testSerialize(self): + data = self.empty_message.SerializeToString() + + # Don't use assertEqual because we don't want to dump raw binary data to + # stdout. + self.assertTrue(data == self.all_fields_data) + + def testCopyFrom(self): + message = unittest_pb2.TestEmptyMessage() + message.CopyFrom(self.empty_message) + self.assertEqual(self.unknown_fields, message._unknown_fields) + + def testMergeFrom(self): + message = unittest_pb2.TestAllTypes() + message.optional_int32 = 1 + message.optional_uint32 = 2 + source = unittest_pb2.TestEmptyMessage() + source.ParseFromString(message.SerializeToString()) + + message.ClearField('optional_int32') + message.optional_int64 = 3 + message.optional_uint32 = 4 + destination = unittest_pb2.TestEmptyMessage() + destination.ParseFromString(message.SerializeToString()) + unknown_fields = destination._unknown_fields[:] + + destination.MergeFrom(source) + self.assertEqual(unknown_fields + source._unknown_fields, + destination._unknown_fields) + + def testClear(self): + self.empty_message.Clear() + self.assertEqual(0, len(self.empty_message._unknown_fields)) + + def testByteSize(self): + self.assertEqual(self.all_fields.ByteSize(), self.empty_message.ByteSize()) + + def testUnknownExtensions(self): + message = unittest_pb2.TestEmptyMessageWithExtensions() + message.ParseFromString(self.all_fields_data) + self.assertEqual(self.empty_message._unknown_fields, + message._unknown_fields) + + def testListFields(self): + # Make sure ListFields doesn't return unknown fields. + self.assertEqual(0, len(self.empty_message.ListFields())) + + def testSerializeMessageSetWireFormatUnknownExtension(self): + # Create a message using the message set wire format with an unknown + # message. + raw = unittest_mset_pb2.RawMessageSet() + + # Add an unknown extension. + item = raw.item.add() + item.type_id = 1545009 + message1 = unittest_mset_pb2.TestMessageSetExtension1() + message1.i = 12345 + item.message = message1.SerializeToString() + + serialized = raw.SerializeToString() + + # Parse message using the message set wire format. + proto = unittest_mset_pb2.TestMessageSet() + proto.MergeFromString(serialized) + + # Verify that the unknown extension is serialized unchanged + reserialized = proto.SerializeToString() + new_raw = unittest_mset_pb2.RawMessageSet() + new_raw.MergeFromString(reserialized) + self.assertEqual(raw, new_raw) + + def testEquals(self): + message = unittest_pb2.TestEmptyMessage() + message.ParseFromString(self.all_fields_data) + self.assertEqual(self.empty_message, message) + + self.all_fields.ClearField('optional_string') + message.ParseFromString(self.all_fields.SerializeToString()) + self.assertNotEqual(self.empty_message, message) + + +if __name__ == '__main__': + unittest.main() |