diff options
author | Bo Yang <teboring@google.com> | 2015-05-21 14:28:59 -0700 |
---|---|---|
committer | Bo Yang <teboring@google.com> | 2015-05-21 19:32:02 -0700 |
commit | 5db217305f37a79eeccd70f000088a06ec82fcec (patch) | |
tree | be53dcf0c0b47ef9178ab8a6fa5c1946ee84a28f /python/google/protobuf/internal | |
parent | 56095026ccc2f755a6fdb296e30c3ddec8f556a2 (diff) |
down-integrate internal changes
Diffstat (limited to 'python/google/protobuf/internal')
20 files changed, 1243 insertions, 91 deletions
diff --git a/python/google/protobuf/internal/containers.py b/python/google/protobuf/internal/containers.py index d976f9e1..72c2fa01 100755 --- a/python/google/protobuf/internal/containers.py +++ b/python/google/protobuf/internal/containers.py @@ -41,6 +41,146 @@ are: __author__ = 'petar@google.com (Petar Petrov)' +import sys + +if sys.version_info[0] < 3: + # We would use collections.MutableMapping all the time, but in Python 2 it + # doesn't define __slots__. This causes two significant problems: + # + # 1. we can't disallow arbitrary attribute assignment, even if our derived + # classes *do* define __slots__. + # + # 2. we can't safely derive a C type from it without __slots__ defined (the + # interpreter expects to find a dict at tp_dictoffset, which we can't + # robustly provide. And we don't want an instance dict anyway. + # + # So this is the Python 2.7 definition of Mapping/MutableMapping functions + # verbatim, except that: + # 1. We declare __slots__. + # 2. We don't declare this as a virtual base class. The classes defined + # in collections are the interesting base classes, not us. + # + # Note: deriving from object is critical. It is the only thing that makes + # this a true type, allowing us to derive from it in C++ cleanly and making + # __slots__ properly disallow arbitrary element assignment. + from collections import Mapping as _Mapping + + class Mapping(object): + __slots__ = () + + def get(self, key, default=None): + try: + return self[key] + except KeyError: + return default + + def __contains__(self, key): + try: + self[key] + except KeyError: + return False + else: + return True + + def iterkeys(self): + return iter(self) + + def itervalues(self): + for key in self: + yield self[key] + + def iteritems(self): + for key in self: + yield (key, self[key]) + + def keys(self): + return list(self) + + def items(self): + return [(key, self[key]) for key in self] + + def values(self): + return [self[key] for key in self] + + # Mappings are not hashable by default, but subclasses can change this + __hash__ = None + + def __eq__(self, other): + if not isinstance(other, _Mapping): + return NotImplemented + return dict(self.items()) == dict(other.items()) + + def __ne__(self, other): + return not (self == other) + + class MutableMapping(Mapping): + __slots__ = () + + __marker = object() + + def pop(self, key, default=__marker): + try: + value = self[key] + except KeyError: + if default is self.__marker: + raise + return default + else: + del self[key] + return value + + def popitem(self): + try: + key = next(iter(self)) + except StopIteration: + raise KeyError + value = self[key] + del self[key] + return key, value + + def clear(self): + try: + while True: + self.popitem() + except KeyError: + pass + + def update(*args, **kwds): + if len(args) > 2: + raise TypeError("update() takes at most 2 positional " + "arguments ({} given)".format(len(args))) + elif not args: + raise TypeError("update() takes at least 1 argument (0 given)") + self = args[0] + other = args[1] if len(args) >= 2 else () + + if isinstance(other, Mapping): + for key in other: + self[key] = other[key] + elif hasattr(other, "keys"): + for key in other.keys(): + self[key] = other[key] + else: + for key, value in other: + self[key] = value + for key, value in kwds.items(): + self[key] = value + + def setdefault(self, key, default=None): + try: + return self[key] + except KeyError: + self[key] = default + return default + + _Mapping.register(Mapping) + +else: + # In Python 3 we can just use MutableMapping directly, because it defines + # __slots__. + from collections import MutableMapping + + class BaseContainer(object): """Base container class.""" @@ -286,3 +426,160 @@ class RepeatedCompositeFieldContainer(BaseContainer): raise TypeError('Can only compare repeated composite fields against ' 'other repeated composite fields.') return self._values == other._values + + +class ScalarMap(MutableMapping): + + """Simple, type-checked, dict-like container for holding repeated scalars.""" + + # Disallows assignment to other attributes. + __slots__ = ['_key_checker', '_value_checker', '_values', '_message_listener'] + + def __init__(self, message_listener, key_checker, value_checker): + """ + Args: + message_listener: A MessageListener implementation. + The ScalarMap will call this object's Modified() method when it + is modified. + key_checker: A type_checkers.ValueChecker instance to run on keys + inserted into this container. + value_checker: A type_checkers.ValueChecker instance to run on values + inserted into this container. + """ + self._message_listener = message_listener + self._key_checker = key_checker + self._value_checker = value_checker + self._values = {} + + def __getitem__(self, key): + try: + return self._values[key] + except KeyError: + key = self._key_checker.CheckValue(key) + val = self._value_checker.DefaultValue() + self._values[key] = val + return val + + def __contains__(self, item): + return item in self._values + + # We need to override this explicitly, because our defaultdict-like behavior + # will make the default implementation (from our base class) always insert + # the key. + def get(self, key, default=None): + if key in self: + return self[key] + else: + return default + + def __setitem__(self, key, value): + checked_key = self._key_checker.CheckValue(key) + checked_value = self._value_checker.CheckValue(value) + self._values[checked_key] = checked_value + self._message_listener.Modified() + + def __delitem__(self, key): + del self._values[key] + self._message_listener.Modified() + + def __len__(self): + return len(self._values) + + def __iter__(self): + return iter(self._values) + + def MergeFrom(self, other): + self._values.update(other._values) + self._message_listener.Modified() + + # This is defined in the abstract base, but we can do it much more cheaply. + def clear(self): + self._values.clear() + self._message_listener.Modified() + + +class MessageMap(MutableMapping): + + """Simple, type-checked, dict-like container for with submessage values.""" + + # Disallows assignment to other attributes. + __slots__ = ['_key_checker', '_values', '_message_listener', + '_message_descriptor'] + + def __init__(self, message_listener, message_descriptor, key_checker): + """ + Args: + message_listener: A MessageListener implementation. + The ScalarMap will call this object's Modified() method when it + is modified. + key_checker: A type_checkers.ValueChecker instance to run on keys + inserted into this container. + value_checker: A type_checkers.ValueChecker instance to run on values + inserted into this container. + """ + self._message_listener = message_listener + self._message_descriptor = message_descriptor + self._key_checker = key_checker + self._values = {} + + def __getitem__(self, key): + try: + return self._values[key] + except KeyError: + key = self._key_checker.CheckValue(key) + new_element = self._message_descriptor._concrete_class() + new_element._SetListener(self._message_listener) + self._values[key] = new_element + self._message_listener.Modified() + + return new_element + + def get_or_create(self, key): + """get_or_create() is an alias for getitem (ie. map[key]). + + Args: + key: The key to get or create in the map. + + This is useful in cases where you want to be explicit that the call is + mutating the map. This can avoid lint errors for statements like this + that otherwise would appear to be pointless statements: + + msg.my_map[key] + """ + return self[key] + + # We need to override this explicitly, because our defaultdict-like behavior + # will make the default implementation (from our base class) always insert + # the key. + def get(self, key, default=None): + if key in self: + return self[key] + else: + return default + + def __contains__(self, item): + return item in self._values + + def __setitem__(self, key, value): + raise ValueError('May not set values directly, call my_map[key].foo = 5') + + def __delitem__(self, key): + del self._values[key] + self._message_listener.Modified() + + def __len__(self): + return len(self._values) + + def __iter__(self): + return iter(self._values) + + def MergeFrom(self, other): + for key in other: + self[key].MergeFrom(other[key]) + # self._message_listener.Modified() not required here, because + # mutations to submessages already propagate. + + # This is defined in the abstract base, but we can do it much more cheaply. + def clear(self): + self._values.clear() + self._message_listener.Modified() diff --git a/python/google/protobuf/internal/decoder.py b/python/google/protobuf/internal/decoder.py index 0f500606..3837eaea 100755 --- a/python/google/protobuf/internal/decoder.py +++ b/python/google/protobuf/internal/decoder.py @@ -733,6 +733,50 @@ def MessageSetItemDecoder(extensions_by_number): return DecodeItem # -------------------------------------------------------------------- + +def MapDecoder(field_descriptor, new_default, is_message_map): + """Returns a decoder for a map field.""" + + key = field_descriptor + tag_bytes = encoder.TagBytes(field_descriptor.number, + wire_format.WIRETYPE_LENGTH_DELIMITED) + tag_len = len(tag_bytes) + local_DecodeVarint = _DecodeVarint + # Can't read _concrete_class yet; might not be initialized. + message_type = field_descriptor.message_type + + def DecodeMap(buffer, pos, end, message, field_dict): + submsg = message_type._concrete_class() + value = field_dict.get(key) + if value is None: + value = field_dict.setdefault(key, new_default(message)) + while 1: + # Read length. + (size, pos) = local_DecodeVarint(buffer, pos) + new_pos = pos + size + if new_pos > end: + raise _DecodeError('Truncated message.') + # Read sub-message. + submsg.Clear() + if submsg._InternalParse(buffer, pos, new_pos) != new_pos: + # The only reason _InternalParse would return early is if it + # encountered an end-group tag. + raise _DecodeError('Unexpected end-group tag.') + + if is_message_map: + value[submsg.key].MergeFrom(submsg.value) + else: + value[submsg.key] = submsg.value + + # Predict that the next tag is another copy of the same repeated field. + pos = new_pos + tag_len + if buffer[new_pos:pos] != tag_bytes or new_pos == end: + # Prediction failed. Return. + return new_pos + + return DecodeMap + +# -------------------------------------------------------------------- # Optimization is not as heavy here because calls to SkipField() are rare, # except for handling end-group tags. diff --git a/python/google/protobuf/internal/descriptor_database_test.py b/python/google/protobuf/internal/descriptor_database_test.py index 56fe14e9..8416e157 100644 --- a/python/google/protobuf/internal/descriptor_database_test.py +++ b/python/google/protobuf/internal/descriptor_database_test.py @@ -35,7 +35,6 @@ __author__ = 'matthewtoia@google.com (Matt Toia)' import unittest - from google.protobuf import descriptor_pb2 from google.protobuf.internal import factory_test2_pb2 from google.protobuf import descriptor_database diff --git a/python/google/protobuf/internal/descriptor_pool_test.py b/python/google/protobuf/internal/descriptor_pool_test.py index 7d145f42..d159cc62 100644 --- a/python/google/protobuf/internal/descriptor_pool_test.py +++ b/python/google/protobuf/internal/descriptor_pool_test.py @@ -37,6 +37,7 @@ __author__ = 'matthewtoia@google.com (Matt Toia)' import os import unittest +import unittest from google.protobuf import unittest_pb2 from google.protobuf import descriptor_pb2 from google.protobuf.internal import api_implementation @@ -226,6 +227,13 @@ class DescriptorPoolTest(unittest.TestCase): db.Add(self.factory_test2_fd) self.testFindMessageTypeByName() + def testAddSerializedFile(self): + db = descriptor_database.DescriptorDatabase() + self.pool = descriptor_pool.DescriptorPool(db) + self.pool.AddSerializedFile(self.factory_test1_fd.SerializeToString()) + self.pool.AddSerializedFile(self.factory_test2_fd.SerializeToString()) + self.testFindMessageTypeByName() + def testComplexNesting(self): test1_desc = descriptor_pb2.FileDescriptorProto.FromString( descriptor_pool_test1_pb2.DESCRIPTOR.serialized_pb) @@ -510,6 +518,43 @@ class AddDescriptorTest(unittest.TestCase): 'protobuf_unittest.TestAllTypes') +@unittest.skipIf( + api_implementation.Type() != 'cpp', + 'default_pool is only supported by the C++ implementation') +class DefaultPoolTest(unittest.TestCase): + + def testFindMethods(self): + # pylint: disable=g-import-not-at-top + from google.protobuf.pyext import _message + pool = _message.default_pool + self.assertIs( + pool.FindFileByName('google/protobuf/unittest.proto'), + unittest_pb2.DESCRIPTOR) + self.assertIs( + pool.FindMessageTypeByName('protobuf_unittest.TestAllTypes'), + unittest_pb2.TestAllTypes.DESCRIPTOR) + self.assertIs( + pool.FindFieldByName('protobuf_unittest.TestAllTypes.optional_int32'), + unittest_pb2.TestAllTypes.DESCRIPTOR.fields_by_name['optional_int32']) + self.assertIs( + pool.FindExtensionByName('protobuf_unittest.optional_int32_extension'), + unittest_pb2.DESCRIPTOR.extensions_by_name['optional_int32_extension']) + self.assertIs( + pool.FindEnumTypeByName('protobuf_unittest.ForeignEnum'), + unittest_pb2.ForeignEnum.DESCRIPTOR) + self.assertIs( + pool.FindOneofByName('protobuf_unittest.TestAllTypes.oneof_field'), + unittest_pb2.TestAllTypes.DESCRIPTOR.oneofs_by_name['oneof_field']) + + def testAddFileDescriptor(self): + # pylint: disable=g-import-not-at-top + from google.protobuf.pyext import _message + pool = _message.default_pool + file_desc = descriptor_pb2.FileDescriptorProto(name='some/file.proto') + pool.Add(file_desc) + pool.AddSerializedFile(file_desc.SerializeToString()) + + TEST1_FILE = ProtoFile( 'google/protobuf/internal/descriptor_pool_test1.proto', 'google.protobuf.python.internal', diff --git a/python/google/protobuf/internal/descriptor_test.py b/python/google/protobuf/internal/descriptor_test.py index 335caee6..26866f3a 100755 --- a/python/google/protobuf/internal/descriptor_test.py +++ b/python/google/protobuf/internal/descriptor_test.py @@ -35,8 +35,8 @@ __author__ = 'robinson@google.com (Will Robinson)' import sys -import unittest +import unittest from google.protobuf import unittest_custom_options_pb2 from google.protobuf import unittest_import_pb2 from google.protobuf import unittest_pb2 diff --git a/python/google/protobuf/internal/encoder.py b/python/google/protobuf/internal/encoder.py index 38a5138a..752f4eab 100755 --- a/python/google/protobuf/internal/encoder.py +++ b/python/google/protobuf/internal/encoder.py @@ -314,7 +314,7 @@ def MessageSizer(field_number, is_repeated, is_packed): # -------------------------------------------------------------------- -# MessageSet is special. +# MessageSet is special: it needs custom logic to compute its size properly. def MessageSetItemSizer(field_number): @@ -339,6 +339,32 @@ def MessageSetItemSizer(field_number): return FieldSize +# -------------------------------------------------------------------- +# Map is special: it needs custom logic to compute its size properly. + + +def MapSizer(field_descriptor): + """Returns a sizer for a map field.""" + + # Can't look at field_descriptor.message_type._concrete_class because it may + # not have been initialized yet. + message_type = field_descriptor.message_type + message_sizer = MessageSizer(field_descriptor.number, False, False) + + def FieldSize(map_value): + total = 0 + for key in map_value: + value = map_value[key] + # It's wasteful to create the messages and throw them away one second + # later since we'll do the same for the actual encode. But there's not an + # obvious way to avoid this within the current design without tons of code + # duplication. + entry_msg = message_type._concrete_class(key=key, value=value) + total += message_sizer(entry_msg) + return total + + return FieldSize + # ==================================================================== # Encoders! @@ -786,3 +812,30 @@ def MessageSetItemEncoder(field_number): return write(end_bytes) return EncodeField + + +# -------------------------------------------------------------------- +# As before, Map is special. + + +def MapEncoder(field_descriptor): + """Encoder for extensions of MessageSet. + + Maps always have a wire format like this: + message MapEntry { + key_type key = 1; + value_type value = 2; + } + repeated MapEntry map = N; + """ + # Can't look at field_descriptor.message_type._concrete_class because it may + # not have been initialized yet. + message_type = field_descriptor.message_type + encode_message = MessageEncoder(field_descriptor.number, False, False) + + def EncodeField(write, value): + for key in value: + entry_msg = message_type._concrete_class(key=key, value=value[key]) + encode_message(write, entry_msg) + + return EncodeField diff --git a/python/google/protobuf/internal/generator_test.py b/python/google/protobuf/internal/generator_test.py index ccc5860b..5c07cbe6 100755 --- a/python/google/protobuf/internal/generator_test.py +++ b/python/google/protobuf/internal/generator_test.py @@ -42,7 +42,6 @@ further ensures that we can use Python protocol message objects as we expect. __author__ = 'robinson@google.com (Will Robinson)' import unittest - from google.protobuf.internal import test_bad_identifiers_pb2 from google.protobuf import unittest_custom_options_pb2 from google.protobuf import unittest_import_pb2 diff --git a/python/google/protobuf/internal/message_factory_test.py b/python/google/protobuf/internal/message_factory_test.py index 626c3fc9..b8694f96 100644 --- a/python/google/protobuf/internal/message_factory_test.py +++ b/python/google/protobuf/internal/message_factory_test.py @@ -35,7 +35,6 @@ __author__ = 'matthewtoia@google.com (Matt Toia)' import unittest - from google.protobuf import descriptor_pb2 from google.protobuf.internal import factory_test1_pb2 from google.protobuf.internal import factory_test2_pb2 diff --git a/python/google/protobuf/internal/message_test.py b/python/google/protobuf/internal/message_test.py index 4ecaa1c7..320ff0d2 100755 --- a/python/google/protobuf/internal/message_test.py +++ b/python/google/protobuf/internal/message_test.py @@ -50,7 +50,9 @@ import pickle import sys import unittest +import unittest from google.protobuf.internal import _parameterized +from google.protobuf import map_unittest_pb2 from google.protobuf import unittest_pb2 from google.protobuf import unittest_proto3_arena_pb2 from google.protobuf.internal import api_implementation @@ -125,10 +127,17 @@ class MessageTest(unittest.TestCase): self.assertEqual(unpickled_message, golden_message) def testPositiveInfinity(self, message_module): - golden_data = (b'\x5D\x00\x00\x80\x7F' - b'\x61\x00\x00\x00\x00\x00\x00\xF0\x7F' - b'\xCD\x02\x00\x00\x80\x7F' - b'\xD1\x02\x00\x00\x00\x00\x00\x00\xF0\x7F') + if message_module is unittest_pb2: + golden_data = (b'\x5D\x00\x00\x80\x7F' + b'\x61\x00\x00\x00\x00\x00\x00\xF0\x7F' + b'\xCD\x02\x00\x00\x80\x7F' + b'\xD1\x02\x00\x00\x00\x00\x00\x00\xF0\x7F') + else: + golden_data = (b'\x5D\x00\x00\x80\x7F' + b'\x61\x00\x00\x00\x00\x00\x00\xF0\x7F' + b'\xCA\x02\x04\x00\x00\x80\x7F' + b'\xD2\x02\x08\x00\x00\x00\x00\x00\x00\xF0\x7F') + golden_message = message_module.TestAllTypes() golden_message.ParseFromString(golden_data) self.assertTrue(IsPosInf(golden_message.optional_float)) @@ -138,10 +147,17 @@ class MessageTest(unittest.TestCase): self.assertEqual(golden_data, golden_message.SerializeToString()) def testNegativeInfinity(self, message_module): - golden_data = (b'\x5D\x00\x00\x80\xFF' - b'\x61\x00\x00\x00\x00\x00\x00\xF0\xFF' - b'\xCD\x02\x00\x00\x80\xFF' - b'\xD1\x02\x00\x00\x00\x00\x00\x00\xF0\xFF') + if message_module is unittest_pb2: + golden_data = (b'\x5D\x00\x00\x80\xFF' + b'\x61\x00\x00\x00\x00\x00\x00\xF0\xFF' + b'\xCD\x02\x00\x00\x80\xFF' + b'\xD1\x02\x00\x00\x00\x00\x00\x00\xF0\xFF') + else: + golden_data = (b'\x5D\x00\x00\x80\xFF' + b'\x61\x00\x00\x00\x00\x00\x00\xF0\xFF' + b'\xCA\x02\x04\x00\x00\x80\xFF' + b'\xD2\x02\x08\x00\x00\x00\x00\x00\x00\xF0\xFF') + golden_message = message_module.TestAllTypes() golden_message.ParseFromString(golden_data) self.assertTrue(IsNegInf(golden_message.optional_float)) @@ -1034,64 +1050,132 @@ class Proto2Test(unittest.TestCase): self.assertEqual(len(parsing_merge.Extensions[ unittest_pb2.TestParsingMerge.repeated_ext]), 3) + def testPythonicInit(self): + message = unittest_pb2.TestAllTypes( + optional_int32=100, + optional_fixed32=200, + optional_float=300.5, + optional_bytes=b'x', + optionalgroup={'a': 400}, + optional_nested_message={'bb': 500}, + optional_nested_enum='BAZ', + repeatedgroup=[{'a': 600}, + {'a': 700}], + repeated_nested_enum=['FOO', unittest_pb2.TestAllTypes.BAR], + default_int32=800, + oneof_string='y') + self.assertTrue(isinstance(message, unittest_pb2.TestAllTypes)) + self.assertEqual(100, message.optional_int32) + self.assertEqual(200, message.optional_fixed32) + self.assertEqual(300.5, message.optional_float) + self.assertEqual(b'x', message.optional_bytes) + self.assertEqual(400, message.optionalgroup.a) + self.assertTrue(isinstance(message.optional_nested_message, + unittest_pb2.TestAllTypes.NestedMessage)) + self.assertEqual(500, message.optional_nested_message.bb) + self.assertEqual(unittest_pb2.TestAllTypes.BAZ, + message.optional_nested_enum) + self.assertEqual(2, len(message.repeatedgroup)) + self.assertEqual(600, message.repeatedgroup[0].a) + self.assertEqual(700, message.repeatedgroup[1].a) + self.assertEqual(2, len(message.repeated_nested_enum)) + self.assertEqual(unittest_pb2.TestAllTypes.FOO, + message.repeated_nested_enum[0]) + self.assertEqual(unittest_pb2.TestAllTypes.BAR, + message.repeated_nested_enum[1]) + self.assertEqual(800, message.default_int32) + self.assertEqual('y', message.oneof_string) + self.assertFalse(message.HasField('optional_int64')) + self.assertEqual(0, len(message.repeated_float)) + self.assertEqual(42, message.default_int64) + + message = unittest_pb2.TestAllTypes(optional_nested_enum=u'BAZ') + self.assertEqual(unittest_pb2.TestAllTypes.BAZ, + message.optional_nested_enum) + + with self.assertRaises(ValueError): + unittest_pb2.TestAllTypes( + optional_nested_message={'INVALID_NESTED_FIELD': 17}) + + with self.assertRaises(TypeError): + unittest_pb2.TestAllTypes( + optional_nested_message={'bb': 'INVALID_VALUE_TYPE'}) + + with self.assertRaises(ValueError): + unittest_pb2.TestAllTypes(optional_nested_enum='INVALID_LABEL') + + with self.assertRaises(ValueError): + unittest_pb2.TestAllTypes(repeated_nested_enum='FOO') + # Class to test proto3-only features/behavior (updated field presence & enums) class Proto3Test(unittest.TestCase): + # Utility method for comparing equality with a map. + def assertMapIterEquals(self, map_iter, dict_value): + # Avoid mutating caller's copy. + dict_value = dict(dict_value) + + for k, v in map_iter: + self.assertEqual(v, dict_value[k]) + del dict_value[k] + + self.assertEqual({}, dict_value) + def testFieldPresence(self): message = unittest_proto3_arena_pb2.TestAllTypes() # We can't test presence of non-repeated, non-submessage fields. with self.assertRaises(ValueError): - message.HasField("optional_int32") + message.HasField('optional_int32') with self.assertRaises(ValueError): - message.HasField("optional_float") + message.HasField('optional_float') with self.assertRaises(ValueError): - message.HasField("optional_string") + message.HasField('optional_string') with self.assertRaises(ValueError): - message.HasField("optional_bool") + message.HasField('optional_bool') # But we can still test presence of submessage fields. - self.assertFalse(message.HasField("optional_nested_message")) + self.assertFalse(message.HasField('optional_nested_message')) # As with proto2, we can't test presence of fields that don't exist, or # repeated fields. with self.assertRaises(ValueError): - message.HasField("field_doesnt_exist") + message.HasField('field_doesnt_exist') with self.assertRaises(ValueError): - message.HasField("repeated_int32") + message.HasField('repeated_int32') with self.assertRaises(ValueError): - message.HasField("repeated_nested_message") + message.HasField('repeated_nested_message') # Fields should default to their type-specific default. self.assertEqual(0, message.optional_int32) self.assertEqual(0, message.optional_float) - self.assertEqual("", message.optional_string) + self.assertEqual('', message.optional_string) self.assertEqual(False, message.optional_bool) self.assertEqual(0, message.optional_nested_message.bb) # Setting a submessage should still return proper presence information. message.optional_nested_message.bb = 0 - self.assertTrue(message.HasField("optional_nested_message")) + self.assertTrue(message.HasField('optional_nested_message')) # Set the fields to non-default values. message.optional_int32 = 5 message.optional_float = 1.1 - message.optional_string = "abc" + message.optional_string = 'abc' message.optional_bool = True message.optional_nested_message.bb = 15 # Clearing the fields unsets them and resets their value to default. - message.ClearField("optional_int32") - message.ClearField("optional_float") - message.ClearField("optional_string") - message.ClearField("optional_bool") - message.ClearField("optional_nested_message") + message.ClearField('optional_int32') + message.ClearField('optional_float') + message.ClearField('optional_string') + message.ClearField('optional_bool') + message.ClearField('optional_nested_message') self.assertEqual(0, message.optional_int32) self.assertEqual(0, message.optional_float) - self.assertEqual("", message.optional_string) + self.assertEqual('', message.optional_string) self.assertEqual(False, message.optional_bool) self.assertEqual(0, message.optional_nested_message.bb) @@ -1113,6 +1197,393 @@ class Proto3Test(unittest.TestCase): self.assertEqual(1234567, m2.optional_nested_enum) self.assertEqual(7654321, m2.repeated_nested_enum[0]) + # Map isn't really a proto3-only feature. But there is no proto2 equivalent + # of google/protobuf/map_unittest.proto right now, so it's not easy to + # test both with the same test like we do for the other proto2/proto3 tests. + # (google/protobuf/map_protobuf_unittest.proto is very different in the set + # of messages and fields it contains). + def testScalarMapDefaults(self): + msg = map_unittest_pb2.TestMap() + + # Scalars start out unset. + self.assertFalse(-123 in msg.map_int32_int32) + self.assertFalse(-2**33 in msg.map_int64_int64) + self.assertFalse(123 in msg.map_uint32_uint32) + self.assertFalse(2**33 in msg.map_uint64_uint64) + self.assertFalse('abc' in msg.map_string_string) + self.assertFalse(888 in msg.map_int32_enum) + + # Accessing an unset key returns the default. + self.assertEqual(0, msg.map_int32_int32[-123]) + self.assertEqual(0, msg.map_int64_int64[-2**33]) + self.assertEqual(0, msg.map_uint32_uint32[123]) + self.assertEqual(0, msg.map_uint64_uint64[2**33]) + self.assertEqual('', msg.map_string_string['abc']) + self.assertEqual(0, msg.map_int32_enum[888]) + + # It also sets the value in the map + self.assertTrue(-123 in msg.map_int32_int32) + self.assertTrue(-2**33 in msg.map_int64_int64) + self.assertTrue(123 in msg.map_uint32_uint32) + self.assertTrue(2**33 in msg.map_uint64_uint64) + self.assertTrue('abc' in msg.map_string_string) + self.assertTrue(888 in msg.map_int32_enum) + + self.assertTrue(isinstance(msg.map_string_string['abc'], unicode)) + + # Accessing an unset key still throws TypeError of the type of the key + # is incorrect. + with self.assertRaises(TypeError): + msg.map_string_string[123] + + self.assertFalse(123 in msg.map_string_string) + + def testMapGet(self): + # Need to test that get() properly returns the default, even though the dict + # has defaultdict-like semantics. + msg = map_unittest_pb2.TestMap() + + self.assertIsNone(msg.map_int32_int32.get(5)) + self.assertEquals(10, msg.map_int32_int32.get(5, 10)) + self.assertIsNone(msg.map_int32_int32.get(5)) + + msg.map_int32_int32[5] = 15 + self.assertEquals(15, msg.map_int32_int32.get(5)) + + self.assertIsNone(msg.map_int32_foreign_message.get(5)) + self.assertEquals(10, msg.map_int32_foreign_message.get(5, 10)) + + submsg = msg.map_int32_foreign_message[5] + self.assertIs(submsg, msg.map_int32_foreign_message.get(5)) + + def testScalarMap(self): + msg = map_unittest_pb2.TestMap() + + self.assertEqual(0, len(msg.map_int32_int32)) + self.assertFalse(5 in msg.map_int32_int32) + + msg.map_int32_int32[-123] = -456 + msg.map_int64_int64[-2**33] = -2**34 + msg.map_uint32_uint32[123] = 456 + msg.map_uint64_uint64[2**33] = 2**34 + msg.map_string_string['abc'] = '123' + msg.map_int32_enum[888] = 2 + + self.assertEqual([], msg.FindInitializationErrors()) + + self.assertEqual(1, len(msg.map_string_string)) + + # Bad key. + with self.assertRaises(TypeError): + msg.map_string_string[123] = '123' + + # Verify that trying to assign a bad key doesn't actually add a member to + # the map. + self.assertEqual(1, len(msg.map_string_string)) + + # Bad value. + with self.assertRaises(TypeError): + msg.map_string_string['123'] = 123 + + serialized = msg.SerializeToString() + msg2 = map_unittest_pb2.TestMap() + msg2.ParseFromString(serialized) + + # Bad key. + with self.assertRaises(TypeError): + msg2.map_string_string[123] = '123' + + # Bad value. + with self.assertRaises(TypeError): + msg2.map_string_string['123'] = 123 + + self.assertEqual(-456, msg2.map_int32_int32[-123]) + self.assertEqual(-2**34, msg2.map_int64_int64[-2**33]) + self.assertEqual(456, msg2.map_uint32_uint32[123]) + self.assertEqual(2**34, msg2.map_uint64_uint64[2**33]) + self.assertEqual('123', msg2.map_string_string['abc']) + self.assertEqual(2, msg2.map_int32_enum[888]) + + def testStringUnicodeConversionInMap(self): + msg = map_unittest_pb2.TestMap() + + unicode_obj = u'\u1234' + bytes_obj = unicode_obj.encode('utf8') + + msg.map_string_string[bytes_obj] = bytes_obj + + (key, value) = msg.map_string_string.items()[0] + + self.assertEqual(key, unicode_obj) + self.assertEqual(value, unicode_obj) + + self.assertTrue(isinstance(key, unicode)) + self.assertTrue(isinstance(value, unicode)) + + def testMessageMap(self): + msg = map_unittest_pb2.TestMap() + + self.assertEqual(0, len(msg.map_int32_foreign_message)) + self.assertFalse(5 in msg.map_int32_foreign_message) + + msg.map_int32_foreign_message[123] + # get_or_create() is an alias for getitem. + msg.map_int32_foreign_message.get_or_create(-456) + + self.assertEqual(2, len(msg.map_int32_foreign_message)) + self.assertIn(123, msg.map_int32_foreign_message) + self.assertIn(-456, msg.map_int32_foreign_message) + self.assertEqual(2, len(msg.map_int32_foreign_message)) + + # Bad key. + with self.assertRaises(TypeError): + msg.map_int32_foreign_message['123'] + + # Can't assign directly to submessage. + with self.assertRaises(ValueError): + msg.map_int32_foreign_message[999] = msg.map_int32_foreign_message[123] + + # Verify that trying to assign a bad key doesn't actually add a member to + # the map. + self.assertEqual(2, len(msg.map_int32_foreign_message)) + + serialized = msg.SerializeToString() + msg2 = map_unittest_pb2.TestMap() + msg2.ParseFromString(serialized) + + self.assertEqual(2, len(msg2.map_int32_foreign_message)) + self.assertIn(123, msg2.map_int32_foreign_message) + self.assertIn(-456, msg2.map_int32_foreign_message) + self.assertEqual(2, len(msg2.map_int32_foreign_message)) + + def testMergeFrom(self): + msg = map_unittest_pb2.TestMap() + msg.map_int32_int32[12] = 34 + msg.map_int32_int32[56] = 78 + msg.map_int64_int64[22] = 33 + msg.map_int32_foreign_message[111].c = 5 + msg.map_int32_foreign_message[222].c = 10 + + msg2 = map_unittest_pb2.TestMap() + msg2.map_int32_int32[12] = 55 + msg2.map_int64_int64[88] = 99 + msg2.map_int32_foreign_message[222].c = 15 + + msg2.MergeFrom(msg) + + self.assertEqual(34, msg2.map_int32_int32[12]) + self.assertEqual(78, msg2.map_int32_int32[56]) + self.assertEqual(33, msg2.map_int64_int64[22]) + self.assertEqual(99, msg2.map_int64_int64[88]) + self.assertEqual(5, msg2.map_int32_foreign_message[111].c) + self.assertEqual(10, msg2.map_int32_foreign_message[222].c) + + # Verify that there is only one entry per key, even though the MergeFrom + # may have internally created multiple entries for a single key in the + # list representation. + as_dict = {} + for key in msg2.map_int32_foreign_message: + self.assertFalse(key in as_dict) + as_dict[key] = msg2.map_int32_foreign_message[key].c + + self.assertEqual({111: 5, 222: 10}, as_dict) + + # Special case: test that delete of item really removes the item, even if + # there might have physically been duplicate keys due to the previous merge. + # This is only a special case for the C++ implementation which stores the + # map as an array. + del msg2.map_int32_int32[12] + self.assertFalse(12 in msg2.map_int32_int32) + + del msg2.map_int32_foreign_message[222] + self.assertFalse(222 in msg2.map_int32_foreign_message) + + def testIntegerMapWithLongs(self): + msg = map_unittest_pb2.TestMap() + msg.map_int32_int32[long(-123)] = long(-456) + msg.map_int64_int64[long(-2**33)] = long(-2**34) + msg.map_uint32_uint32[long(123)] = long(456) + msg.map_uint64_uint64[long(2**33)] = long(2**34) + + serialized = msg.SerializeToString() + msg2 = map_unittest_pb2.TestMap() + msg2.ParseFromString(serialized) + + self.assertEqual(-456, msg2.map_int32_int32[-123]) + self.assertEqual(-2**34, msg2.map_int64_int64[-2**33]) + self.assertEqual(456, msg2.map_uint32_uint32[123]) + self.assertEqual(2**34, msg2.map_uint64_uint64[2**33]) + + def testMapAssignmentCausesPresence(self): + msg = map_unittest_pb2.TestMapSubmessage() + msg.test_map.map_int32_int32[123] = 456 + + serialized = msg.SerializeToString() + msg2 = map_unittest_pb2.TestMapSubmessage() + msg2.ParseFromString(serialized) + + self.assertEqual(msg, msg2) + + # Now test that various mutations of the map properly invalidate the + # cached size of the submessage. + msg.test_map.map_int32_int32[888] = 999 + serialized = msg.SerializeToString() + msg2.ParseFromString(serialized) + self.assertEqual(msg, msg2) + + msg.test_map.map_int32_int32.clear() + serialized = msg.SerializeToString() + msg2.ParseFromString(serialized) + self.assertEqual(msg, msg2) + + def testMapAssignmentCausesPresenceForSubmessages(self): + msg = map_unittest_pb2.TestMapSubmessage() + msg.test_map.map_int32_foreign_message[123].c = 5 + + serialized = msg.SerializeToString() + msg2 = map_unittest_pb2.TestMapSubmessage() + msg2.ParseFromString(serialized) + + self.assertEqual(msg, msg2) + + # Now test that various mutations of the map properly invalidate the + # cached size of the submessage. + msg.test_map.map_int32_foreign_message[888].c = 7 + serialized = msg.SerializeToString() + msg2.ParseFromString(serialized) + self.assertEqual(msg, msg2) + + msg.test_map.map_int32_foreign_message[888].MergeFrom( + msg.test_map.map_int32_foreign_message[123]) + serialized = msg.SerializeToString() + msg2.ParseFromString(serialized) + self.assertEqual(msg, msg2) + + msg.test_map.map_int32_foreign_message.clear() + serialized = msg.SerializeToString() + msg2.ParseFromString(serialized) + self.assertEqual(msg, msg2) + + def testModifyMapWhileIterating(self): + msg = map_unittest_pb2.TestMap() + + string_string_iter = iter(msg.map_string_string) + int32_foreign_iter = iter(msg.map_int32_foreign_message) + + msg.map_string_string['abc'] = '123' + msg.map_int32_foreign_message[5].c = 5 + + with self.assertRaises(RuntimeError): + for key in string_string_iter: + pass + + with self.assertRaises(RuntimeError): + for key in int32_foreign_iter: + pass + + def testSubmessageMap(self): + msg = map_unittest_pb2.TestMap() + + submsg = msg.map_int32_foreign_message[111] + self.assertIs(submsg, msg.map_int32_foreign_message[111]) + self.assertTrue(isinstance(submsg, unittest_pb2.ForeignMessage)) + + submsg.c = 5 + + serialized = msg.SerializeToString() + msg2 = map_unittest_pb2.TestMap() + msg2.ParseFromString(serialized) + + self.assertEqual(5, msg2.map_int32_foreign_message[111].c) + + # Doesn't allow direct submessage assignment. + with self.assertRaises(ValueError): + msg.map_int32_foreign_message[88] = unittest_pb2.ForeignMessage() + + def testMapIteration(self): + msg = map_unittest_pb2.TestMap() + + for k, v in msg.map_int32_int32.iteritems(): + # Should not be reached. + self.assertTrue(False) + + msg.map_int32_int32[2] = 4 + msg.map_int32_int32[3] = 6 + msg.map_int32_int32[4] = 8 + self.assertEqual(3, len(msg.map_int32_int32)) + + matching_dict = {2: 4, 3: 6, 4: 8} + self.assertMapIterEquals(msg.map_int32_int32.iteritems(), matching_dict) + + def testMapIterationClearMessage(self): + # Iterator needs to work even if message and map are deleted. + msg = map_unittest_pb2.TestMap() + + msg.map_int32_int32[2] = 4 + msg.map_int32_int32[3] = 6 + msg.map_int32_int32[4] = 8 + + it = msg.map_int32_int32.iteritems() + del msg + + matching_dict = {2: 4, 3: 6, 4: 8} + self.assertMapIterEquals(it, matching_dict) + + def testMapConstruction(self): + msg = map_unittest_pb2.TestMap(map_int32_int32={1: 2, 3: 4}) + self.assertEqual(2, msg.map_int32_int32[1]) + self.assertEqual(4, msg.map_int32_int32[3]) + + msg = map_unittest_pb2.TestMap( + map_int32_foreign_message={3: unittest_pb2.ForeignMessage(c=5)}) + self.assertEqual(5, msg.map_int32_foreign_message[3].c) + + def testMapValidAfterFieldCleared(self): + # Map needs to work even if field is cleared. + # For the C++ implementation this tests the correctness of + # ScalarMapContainer::Release() + msg = map_unittest_pb2.TestMap() + map = msg.map_int32_int32 + + map[2] = 4 + map[3] = 6 + map[4] = 8 + + msg.ClearField('map_int32_int32') + matching_dict = {2: 4, 3: 6, 4: 8} + self.assertMapIterEquals(map.iteritems(), matching_dict) + + def testMapIterValidAfterFieldCleared(self): + # Map iterator needs to work even if field is cleared. + # For the C++ implementation this tests the correctness of + # ScalarMapContainer::Release() + msg = map_unittest_pb2.TestMap() + + msg.map_int32_int32[2] = 4 + msg.map_int32_int32[3] = 6 + msg.map_int32_int32[4] = 8 + + it = msg.map_int32_int32.iteritems() + + msg.ClearField('map_int32_int32') + matching_dict = {2: 4, 3: 6, 4: 8} + self.assertMapIterEquals(it, matching_dict) + + def testMapDelete(self): + msg = map_unittest_pb2.TestMap() + + self.assertEqual(0, len(msg.map_int32_int32)) + + msg.map_int32_int32[4] = 6 + self.assertEqual(1, len(msg.map_int32_int32)) + + with self.assertRaises(KeyError): + del msg.map_int32_int32[88] + + del msg.map_int32_int32[4] + self.assertEqual(0, len(msg.map_int32_int32)) + + class ValidTypeNamesTest(unittest.TestCase): diff --git a/python/google/protobuf/internal/proto_builder_test.py b/python/google/protobuf/internal/proto_builder_test.py index b1e57f35..edaf3fa3 100644 --- a/python/google/protobuf/internal/proto_builder_test.py +++ b/python/google/protobuf/internal/proto_builder_test.py @@ -32,6 +32,7 @@ """Tests for google.protobuf.proto_builder.""" +import collections import unittest from google.protobuf import descriptor_pb2 @@ -43,10 +44,11 @@ from google.protobuf import text_format class ProtoBuilderTest(unittest.TestCase): def setUp(self): - self._fields = { - 'foo': descriptor_pb2.FieldDescriptorProto.TYPE_INT64, - 'bar': descriptor_pb2.FieldDescriptorProto.TYPE_STRING, - } + self.ordered_fields = collections.OrderedDict([ + ('foo', descriptor_pb2.FieldDescriptorProto.TYPE_INT64), + ('bar', descriptor_pb2.FieldDescriptorProto.TYPE_STRING), + ]) + self._fields = dict(self.ordered_fields) def testMakeSimpleProtoClass(self): """Test that we can create a proto class.""" @@ -59,6 +61,17 @@ class ProtoBuilderTest(unittest.TestCase): self.assertMultiLineEqual( 'bar: "asdf"\nfoo: 12345\n', text_format.MessageToString(proto)) + def testOrderedFields(self): + """Test that the field order is maintained when given an OrderedDict.""" + proto_cls = proto_builder.MakeSimpleProtoClass( + self.ordered_fields, + full_name='net.proto2.python.public.proto_builder_test.OrderedTest') + proto = proto_cls() + proto.foo = 12345 + proto.bar = 'asdf' + self.assertMultiLineEqual( + 'foo: 12345\nbar: "asdf"\n', text_format.MessageToString(proto)) + def testMakeSameProtoClassTwice(self): """Test that the DescriptorPool is used.""" pool = descriptor_pool.DescriptorPool() diff --git a/python/google/protobuf/internal/python_message.py b/python/google/protobuf/internal/python_message.py index 54f584ae..ca9f7675 100755 --- a/python/google/protobuf/internal/python_message.py +++ b/python/google/protobuf/internal/python_message.py @@ -61,9 +61,11 @@ if sys.version_info[0] < 3: except ImportError: from StringIO import StringIO as BytesIO import copy_reg as copyreg + _basestring = basestring else: from io import BytesIO import copyreg + _basestring = str import struct import weakref @@ -77,6 +79,7 @@ from google.protobuf.internal import type_checkers from google.protobuf.internal import wire_format from google.protobuf import descriptor as descriptor_mod from google.protobuf import message as message_mod +from google.protobuf import symbol_database from google.protobuf import text_format _FieldDescriptor = descriptor_mod.FieldDescriptor @@ -101,6 +104,7 @@ def InitMessage(descriptor, cls): for field in descriptor.fields: _AttachFieldHelpers(cls, field) + descriptor._concrete_class = cls # pylint: disable=protected-access _AddEnumValues(descriptor, cls) _AddInitMethod(descriptor, cls) _AddPropertiesForFields(descriptor, cls) @@ -198,12 +202,37 @@ def _IsMessageSetExtension(field): field.label == _FieldDescriptor.LABEL_OPTIONAL) +def _IsMapField(field): + return (field.type == _FieldDescriptor.TYPE_MESSAGE and + field.message_type.has_options and + field.message_type.GetOptions().map_entry) + + +def _IsMessageMapField(field): + value_type = field.message_type.fields_by_name["value"] + return value_type.cpp_type == _FieldDescriptor.CPPTYPE_MESSAGE + + def _AttachFieldHelpers(cls, field_descriptor): is_repeated = (field_descriptor.label == _FieldDescriptor.LABEL_REPEATED) - is_packed = (field_descriptor.has_options and - field_descriptor.GetOptions().packed) - - if _IsMessageSetExtension(field_descriptor): + is_packable = (is_repeated and + wire_format.IsTypePackable(field_descriptor.type)) + if not is_packable: + is_packed = False + elif field_descriptor.containing_type.syntax == "proto2": + is_packed = (field_descriptor.has_options and + field_descriptor.GetOptions().packed) + else: + has_packed_false = (field_descriptor.has_options and + field_descriptor.GetOptions().HasField("packed") and + field_descriptor.GetOptions().packed == False) + is_packed = not has_packed_false + is_map_entry = _IsMapField(field_descriptor) + + if is_map_entry: + field_encoder = encoder.MapEncoder(field_descriptor) + sizer = encoder.MapSizer(field_descriptor) + elif _IsMessageSetExtension(field_descriptor): field_encoder = encoder.MessageSetItemEncoder(field_descriptor.number) sizer = encoder.MessageSetItemSizer(field_descriptor.number) else: @@ -228,9 +257,16 @@ def _AttachFieldHelpers(cls, field_descriptor): if field_descriptor.containing_oneof is not None: oneof_descriptor = field_descriptor - field_decoder = type_checkers.TYPE_TO_DECODER[decode_type]( - field_descriptor.number, is_repeated, is_packed, - field_descriptor, field_descriptor._default_constructor) + if is_map_entry: + is_message_map = _IsMessageMapField(field_descriptor) + + field_decoder = decoder.MapDecoder( + field_descriptor, _GetInitializeDefaultForMap(field_descriptor), + is_message_map) + else: + field_decoder = type_checkers.TYPE_TO_DECODER[decode_type]( + field_descriptor.number, is_repeated, is_packed, + field_descriptor, field_descriptor._default_constructor) cls._decoders_by_tag[tag_bytes] = (field_decoder, oneof_descriptor) @@ -265,6 +301,26 @@ def _AddEnumValues(descriptor, cls): setattr(cls, enum_value.name, enum_value.number) +def _GetInitializeDefaultForMap(field): + if field.label != _FieldDescriptor.LABEL_REPEATED: + raise ValueError('map_entry set on non-repeated field %s' % ( + field.name)) + fields_by_name = field.message_type.fields_by_name + key_checker = type_checkers.GetTypeChecker(fields_by_name['key']) + + value_field = fields_by_name['value'] + if _IsMessageMapField(field): + def MakeMessageMapDefault(message): + return containers.MessageMap( + message._listener_for_children, value_field.message_type, key_checker) + return MakeMessageMapDefault + else: + value_checker = type_checkers.GetTypeChecker(value_field) + def MakePrimitiveMapDefault(message): + return containers.ScalarMap( + message._listener_for_children, key_checker, value_checker) + return MakePrimitiveMapDefault + def _DefaultValueConstructorForField(field): """Returns a function which returns a default value for a field. @@ -279,6 +335,9 @@ def _DefaultValueConstructorForField(field): value may refer back to |message| via a weak reference. """ + if _IsMapField(field): + return _GetInitializeDefaultForMap(field) + if field.label == _FieldDescriptor.LABEL_REPEATED: if field.has_default_value and field.default_value != []: raise ValueError('Repeated field default value not empty list: %s' % ( @@ -329,7 +388,22 @@ def _ReraiseTypeErrorWithFieldName(message_name, field_name): def _AddInitMethod(message_descriptor, cls): """Adds an __init__ method to cls.""" - fields = message_descriptor.fields + + def _GetIntegerEnumValue(enum_type, value): + """Convert a string or integer enum value to an integer. + + If the value is a string, it is converted to the enum value in + enum_type with the same name. If the value is not a string, it's + returned as-is. (No conversion or bounds-checking is done.) + """ + if isinstance(value, _basestring): + try: + return enum_type.values_by_name[value].number + except KeyError: + raise ValueError('Enum type %s: unknown label "%s"' % ( + enum_type.full_name, value)) + return value + def init(self, **kwargs): self._cached_byte_size = 0 self._cached_byte_size_dirty = len(kwargs) > 0 @@ -352,19 +426,37 @@ def _AddInitMethod(message_descriptor, cls): if field.label == _FieldDescriptor.LABEL_REPEATED: copy = field._default_constructor(self) if field.cpp_type == _FieldDescriptor.CPPTYPE_MESSAGE: # Composite - for val in field_value: - copy.add().MergeFrom(val) + if _IsMapField(field): + if _IsMessageMapField(field): + for key in field_value: + copy[key].MergeFrom(field_value[key]) + else: + copy.update(field_value) + else: + for val in field_value: + if isinstance(val, dict): + copy.add(**val) + else: + copy.add().MergeFrom(val) else: # Scalar + if field.cpp_type == _FieldDescriptor.CPPTYPE_ENUM: + field_value = [_GetIntegerEnumValue(field.enum_type, val) + for val in field_value] copy.extend(field_value) self._fields[field] = copy elif field.cpp_type == _FieldDescriptor.CPPTYPE_MESSAGE: copy = field._default_constructor(self) + new_val = field_value + if isinstance(field_value, dict): + new_val = field.message_type._concrete_class(**field_value) try: - copy.MergeFrom(field_value) + copy.MergeFrom(new_val) except TypeError: _ReraiseTypeErrorWithFieldName(message_descriptor.name, field_name) self._fields[field] = copy else: + if field.cpp_type == _FieldDescriptor.CPPTYPE_ENUM: + field_value = _GetIntegerEnumValue(field.enum_type, field_value) try: setattr(self, field_name, field_value) except TypeError: @@ -758,6 +850,26 @@ def _AddHasExtensionMethod(cls): return extension_handle in self._fields cls.HasExtension = HasExtension +def _UnpackAny(msg): + type_url = msg.type_url + db = symbol_database.Default() + + if not type_url: + return None + + # TODO(haberman): For now we just strip the hostname. Better logic will be + # required. + type_name = type_url.split("/")[-1] + descriptor = db.pool.FindMessageTypeByName(type_name) + + if descriptor is None: + return None + + message_class = db.GetPrototype(descriptor) + message = message_class() + + message.ParseFromString(msg.value) + return message def _AddEqualsMethod(message_descriptor, cls): """Helper for _AddMessageMethods().""" @@ -769,6 +881,12 @@ def _AddEqualsMethod(message_descriptor, cls): if self is other: return True + if self.DESCRIPTOR.full_name == "google.protobuf.Any": + any_a = _UnpackAny(self) + any_b = _UnpackAny(other) + if any_a and any_b: + return any_a == any_b + if not self.ListFields() == other.ListFields(): return False @@ -961,6 +1079,9 @@ def _AddIsInitializedMethod(message_descriptor, cls): for field, value in list(self._fields.items()): # dict can change size! if field.cpp_type == _FieldDescriptor.CPPTYPE_MESSAGE: if field.label == _FieldDescriptor.LABEL_REPEATED: + if (field.message_type.has_options and + field.message_type.GetOptions().map_entry): + continue for element in value: if not element.IsInitialized(): if errors is not None: @@ -996,16 +1117,26 @@ def _AddIsInitializedMethod(message_descriptor, cls): else: name = field.name - if field.label == _FieldDescriptor.LABEL_REPEATED: + if _IsMapField(field): + if _IsMessageMapField(field): + for key in value: + element = value[key] + prefix = "%s[%d]." % (name, key) + sub_errors = element.FindInitializationErrors() + errors += [prefix + error for error in sub_errors] + else: + # ScalarMaps can't have any initialization errors. + pass + elif field.label == _FieldDescriptor.LABEL_REPEATED: for i in xrange(len(value)): element = value[i] prefix = "%s[%d]." % (name, i) sub_errors = element.FindInitializationErrors() - errors += [ prefix + error for error in sub_errors ] + errors += [prefix + error for error in sub_errors] else: prefix = name + "." sub_errors = value.FindInitializationErrors() - errors += [ prefix + error for error in sub_errors ] + errors += [prefix + error for error in sub_errors] return errors diff --git a/python/google/protobuf/internal/reflection_test.py b/python/google/protobuf/internal/reflection_test.py index ae79c78b..4eca4989 100755 --- a/python/google/protobuf/internal/reflection_test.py +++ b/python/google/protobuf/internal/reflection_test.py @@ -39,8 +39,8 @@ import copy import gc import operator import struct -import unittest +import unittest from google.protobuf import unittest_import_pb2 from google.protobuf import unittest_mset_pb2 from google.protobuf import unittest_pb2 @@ -1798,8 +1798,8 @@ class ReflectionTest(unittest.TestCase): def testBadArguments(self): # Some of these assertions used to segfault. from google.protobuf.pyext import _message - self.assertRaises(TypeError, _message.Message._GetFieldDescriptor, 3) - self.assertRaises(TypeError, _message.Message._GetExtensionDescriptor, 42) + self.assertRaises(TypeError, _message.default_pool.FindFieldByName, 3) + self.assertRaises(TypeError, _message.default_pool.FindExtensionByName, 42) self.assertRaises(TypeError, unittest_pb2.TestAllTypes().__getattribute__, 42) diff --git a/python/google/protobuf/internal/service_reflection_test.py b/python/google/protobuf/internal/service_reflection_test.py index de462124..9967255a 100755 --- a/python/google/protobuf/internal/service_reflection_test.py +++ b/python/google/protobuf/internal/service_reflection_test.py @@ -35,7 +35,6 @@ __author__ = 'petar@google.com (Petar Petrov)' import unittest - from google.protobuf import unittest_pb2 from google.protobuf import service_reflection from google.protobuf import service @@ -81,7 +80,7 @@ class FooUnitTest(unittest.TestCase): self.assertEqual('Method Bar not implemented.', rpc_controller.failure_message) self.assertEqual(None, self.callback_response) - + class MyServiceImpl(unittest_pb2.TestService): def Foo(self, rpc_controller, request, done): self.foo_called = True diff --git a/python/google/protobuf/internal/symbol_database_test.py b/python/google/protobuf/internal/symbol_database_test.py index c888aff7..80b83bc2 100644 --- a/python/google/protobuf/internal/symbol_database_test.py +++ b/python/google/protobuf/internal/symbol_database_test.py @@ -33,7 +33,6 @@ """Tests for google.protobuf.symbol_database.""" import unittest - from google.protobuf import unittest_pb2 from google.protobuf import symbol_database diff --git a/python/google/protobuf/internal/test_util.py b/python/google/protobuf/internal/test_util.py index d84e3836..0cbdbad9 100755 --- a/python/google/protobuf/internal/test_util.py +++ b/python/google/protobuf/internal/test_util.py @@ -75,7 +75,8 @@ def SetAllNonLazyFields(message): message.optional_string = u'115' message.optional_bytes = b'116' - message.optionalgroup.a = 117 + if IsProto2(message): + message.optionalgroup.a = 117 message.optional_nested_message.bb = 118 message.optional_foreign_message.c = 119 message.optional_import_message.d = 120 @@ -109,7 +110,8 @@ def SetAllNonLazyFields(message): message.repeated_string.append(u'215') message.repeated_bytes.append(b'216') - message.repeatedgroup.add().a = 217 + if IsProto2(message): + message.repeatedgroup.add().a = 217 message.repeated_nested_message.add().bb = 218 message.repeated_foreign_message.add().c = 219 message.repeated_import_message.add().d = 220 @@ -140,7 +142,8 @@ def SetAllNonLazyFields(message): message.repeated_string.append(u'315') message.repeated_bytes.append(b'316') - message.repeatedgroup.add().a = 317 + if IsProto2(message): + message.repeatedgroup.add().a = 317 message.repeated_nested_message.add().bb = 318 message.repeated_foreign_message.add().c = 319 message.repeated_import_message.add().d = 320 @@ -396,7 +399,8 @@ def ExpectAllFieldsSet(test_case, message): test_case.assertTrue(message.HasField('optional_string')) test_case.assertTrue(message.HasField('optional_bytes')) - test_case.assertTrue(message.HasField('optionalgroup')) + if IsProto2(message): + test_case.assertTrue(message.HasField('optionalgroup')) test_case.assertTrue(message.HasField('optional_nested_message')) test_case.assertTrue(message.HasField('optional_foreign_message')) test_case.assertTrue(message.HasField('optional_import_message')) @@ -430,7 +434,8 @@ def ExpectAllFieldsSet(test_case, message): test_case.assertEqual('115', message.optional_string) test_case.assertEqual(b'116', message.optional_bytes) - test_case.assertEqual(117, message.optionalgroup.a) + if IsProto2(message): + test_case.assertEqual(117, message.optionalgroup.a) test_case.assertEqual(118, message.optional_nested_message.bb) test_case.assertEqual(119, message.optional_foreign_message.c) test_case.assertEqual(120, message.optional_import_message.d) @@ -463,7 +468,8 @@ def ExpectAllFieldsSet(test_case, message): test_case.assertEqual(2, len(message.repeated_string)) test_case.assertEqual(2, len(message.repeated_bytes)) - test_case.assertEqual(2, len(message.repeatedgroup)) + if IsProto2(message): + test_case.assertEqual(2, len(message.repeatedgroup)) test_case.assertEqual(2, len(message.repeated_nested_message)) test_case.assertEqual(2, len(message.repeated_foreign_message)) test_case.assertEqual(2, len(message.repeated_import_message)) @@ -491,7 +497,8 @@ def ExpectAllFieldsSet(test_case, message): test_case.assertEqual('215', message.repeated_string[0]) test_case.assertEqual(b'216', message.repeated_bytes[0]) - test_case.assertEqual(217, message.repeatedgroup[0].a) + if IsProto2(message): + test_case.assertEqual(217, message.repeatedgroup[0].a) test_case.assertEqual(218, message.repeated_nested_message[0].bb) test_case.assertEqual(219, message.repeated_foreign_message[0].c) test_case.assertEqual(220, message.repeated_import_message[0].d) @@ -521,7 +528,8 @@ def ExpectAllFieldsSet(test_case, message): test_case.assertEqual('315', message.repeated_string[1]) test_case.assertEqual(b'316', message.repeated_bytes[1]) - test_case.assertEqual(317, message.repeatedgroup[1].a) + if IsProto2(message): + test_case.assertEqual(317, message.repeatedgroup[1].a) test_case.assertEqual(318, message.repeated_nested_message[1].bb) test_case.assertEqual(319, message.repeated_foreign_message[1].c) test_case.assertEqual(320, message.repeated_import_message[1].d) diff --git a/python/google/protobuf/internal/text_encoding_test.py b/python/google/protobuf/internal/text_encoding_test.py index 27896b94..5df13b78 100755 --- a/python/google/protobuf/internal/text_encoding_test.py +++ b/python/google/protobuf/internal/text_encoding_test.py @@ -33,7 +33,6 @@ """Tests for google.protobuf.text_encoding.""" import unittest - from google.protobuf import text_encoding TEST_VALUES = [ diff --git a/python/google/protobuf/internal/text_format_test.py b/python/google/protobuf/internal/text_format_test.py index bf7e06ee..06bd1ee5 100755 --- a/python/google/protobuf/internal/text_format_test.py +++ b/python/google/protobuf/internal/text_format_test.py @@ -37,8 +37,10 @@ __author__ = 'kenton@google.com (Kenton Varda)' import re import unittest +import unittest from google.protobuf.internal import _parameterized +from google.protobuf import map_unittest_pb2 from google.protobuf import unittest_mset_pb2 from google.protobuf import unittest_pb2 from google.protobuf import unittest_proto3_arena_pb2 @@ -309,31 +311,6 @@ class TextFormatTest(TextFormatBase): r'"unknown_field".'), text_format.Parse, text, message) - def testParseGroupNotClosed(self, message_module): - message = message_module.TestAllTypes() - text = 'RepeatedGroup: <' - self.assertRaisesRegexp( - text_format.ParseError, '1:16 : Expected ">".', - text_format.Parse, text, message) - - text = 'RepeatedGroup: {' - self.assertRaisesRegexp( - text_format.ParseError, '1:16 : Expected "}".', - text_format.Parse, text, message) - - def testParseEmptyGroup(self, message_module): - message = message_module.TestAllTypes() - text = 'OptionalGroup: {}' - text_format.Parse(text, message) - self.assertTrue(message.HasField('optionalgroup')) - - message.Clear() - - message = message_module.TestAllTypes() - text = 'OptionalGroup: <>' - text_format.Parse(text, message) - self.assertTrue(message.HasField('optionalgroup')) - def testParseBadEnumValue(self, message_module): message = message_module.TestAllTypes() text = 'optional_nested_enum: BARR' @@ -408,6 +385,14 @@ class TextFormatTest(TextFormatBase): # Ideally the schemas would be made more similar so these tests could pass. class OnlyWorksWithProto2RightNowTests(TextFormatBase): + def testPrintAllFieldsPointy(self, message_module): + message = unittest_pb2.TestAllTypes() + test_util.SetAllFields(message) + self.CompareToGoldenFile( + self.RemoveRedundantZeros( + text_format.MessageToString(message, pointy_brackets=True)), + 'text_format_unittest_data_pointy_oneof.txt') + def testParseGolden(self): golden_text = '\n'.join(self.ReadGolden('text_format_unittest_data.txt')) parsed_message = unittest_pb2.TestAllTypes() @@ -471,8 +456,49 @@ class OnlyWorksWithProto2RightNowTests(TextFormatBase): test_util.SetAllFields(message) self.assertEquals(message, parsed_message) + def testPrintMap(self): + message = map_unittest_pb2.TestMap() -# Tests of proto2-only features (MessageSet and extensions). + message.map_int32_int32[-123] = -456 + message.map_int64_int64[-2**33] = -2**34 + message.map_uint32_uint32[123] = 456 + message.map_uint64_uint64[2**33] = 2**34 + message.map_string_string["abc"] = "123" + message.map_int32_foreign_message[111].c = 5 + + # Maps are serialized to text format using their underlying repeated + # representation. + self.CompareToGoldenText( + text_format.MessageToString(message), + 'map_int32_int32 {\n' + ' key: -123\n' + ' value: -456\n' + '}\n' + 'map_int64_int64 {\n' + ' key: -8589934592\n' + ' value: -17179869184\n' + '}\n' + 'map_uint32_uint32 {\n' + ' key: 123\n' + ' value: 456\n' + '}\n' + 'map_uint64_uint64 {\n' + ' key: 8589934592\n' + ' value: 17179869184\n' + '}\n' + 'map_string_string {\n' + ' key: "abc"\n' + ' value: "123"\n' + '}\n' + 'map_int32_foreign_message {\n' + ' key: 111\n' + ' value {\n' + ' c: 5\n' + ' }\n' + '}\n') + + +# Tests of proto2-only features (MessageSet, extensions, etc.). class Proto2Tests(TextFormatBase): def testPrintMessageSet(self): @@ -620,6 +646,69 @@ class Proto2Tests(TextFormatBase): 'have multiple "optional_int32" fields.'), text_format.Parse, text, message) + def testParseGroupNotClosed(self): + message = unittest_pb2.TestAllTypes() + text = 'RepeatedGroup: <' + self.assertRaisesRegexp( + text_format.ParseError, '1:16 : Expected ">".', + text_format.Parse, text, message) + text = 'RepeatedGroup: {' + self.assertRaisesRegexp( + text_format.ParseError, '1:16 : Expected "}".', + text_format.Parse, text, message) + + def testParseEmptyGroup(self): + message = unittest_pb2.TestAllTypes() + text = 'OptionalGroup: {}' + text_format.Parse(text, message) + self.assertTrue(message.HasField('optionalgroup')) + + message.Clear() + + message = unittest_pb2.TestAllTypes() + text = 'OptionalGroup: <>' + text_format.Parse(text, message) + self.assertTrue(message.HasField('optionalgroup')) + + # Maps aren't really proto2-only, but our test schema only has maps for + # proto2. + def testParseMap(self): + text = ('map_int32_int32 {\n' + ' key: -123\n' + ' value: -456\n' + '}\n' + 'map_int64_int64 {\n' + ' key: -8589934592\n' + ' value: -17179869184\n' + '}\n' + 'map_uint32_uint32 {\n' + ' key: 123\n' + ' value: 456\n' + '}\n' + 'map_uint64_uint64 {\n' + ' key: 8589934592\n' + ' value: 17179869184\n' + '}\n' + 'map_string_string {\n' + ' key: "abc"\n' + ' value: "123"\n' + '}\n' + 'map_int32_foreign_message {\n' + ' key: 111\n' + ' value {\n' + ' c: 5\n' + ' }\n' + '}\n') + message = map_unittest_pb2.TestMap() + text_format.Parse(text, message) + + self.assertEqual(-456, message.map_int32_int32[-123]) + self.assertEqual(-2**34, message.map_int64_int64[-2**33]) + self.assertEqual(456, message.map_uint32_uint32[123]) + self.assertEqual(2**34, message.map_uint64_uint64[2**33]) + self.assertEqual("123", message.map_string_string["abc"]) + self.assertEqual(5, message.map_int32_foreign_message[111].c) + class TokenizerTest(unittest.TestCase): diff --git a/python/google/protobuf/internal/type_checkers.py b/python/google/protobuf/internal/type_checkers.py index 76c056c4..f20e526a 100755 --- a/python/google/protobuf/internal/type_checkers.py +++ b/python/google/protobuf/internal/type_checkers.py @@ -129,6 +129,9 @@ class IntValueChecker(object): proposed_value = self._TYPE(proposed_value) return proposed_value + def DefaultValue(self): + return 0 + class EnumValueChecker(object): @@ -146,6 +149,9 @@ class EnumValueChecker(object): raise ValueError('Unknown enum value: %d' % proposed_value) return proposed_value + def DefaultValue(self): + return self._enum_type.values[0].number + class UnicodeValueChecker(object): @@ -171,6 +177,9 @@ class UnicodeValueChecker(object): (proposed_value)) return proposed_value + def DefaultValue(self): + return u"" + class Int32ValueChecker(IntValueChecker): # We're sure to use ints instead of longs here since comparison may be more diff --git a/python/google/protobuf/internal/unknown_fields_test.py b/python/google/protobuf/internal/unknown_fields_test.py index 9337ae8a..1b81ae79 100755 --- a/python/google/protobuf/internal/unknown_fields_test.py +++ b/python/google/protobuf/internal/unknown_fields_test.py @@ -36,7 +36,6 @@ __author__ = 'bohdank@google.com (Bohdan Koval)' import unittest - from google.protobuf import unittest_mset_pb2 from google.protobuf import unittest_pb2 from google.protobuf import unittest_proto3_arena_pb2 diff --git a/python/google/protobuf/internal/wire_format_test.py b/python/google/protobuf/internal/wire_format_test.py index 5cd7fcb9..78dc1167 100755 --- a/python/google/protobuf/internal/wire_format_test.py +++ b/python/google/protobuf/internal/wire_format_test.py @@ -35,7 +35,6 @@ __author__ = 'robinson@google.com (Will Robinson)' import unittest - from google.protobuf import message from google.protobuf.internal import wire_format |