diff options
author | Adam Cozzette <acozzette@google.com> | 2016-11-17 16:48:38 -0800 |
---|---|---|
committer | Adam Cozzette <acozzette@google.com> | 2016-11-17 16:59:59 -0800 |
commit | 5a76e633ea9b5adb215e93fdc11e1c0c08b3fc74 (patch) | |
tree | 0276f81f8848a05d84cd7e287b43d665e30f04e3 /python | |
parent | e28286fa05d8327fd6c5aa70cfb3be558f0932b8 (diff) |
Integrated internal changes from Google
Diffstat (limited to 'python')
19 files changed, 982 insertions, 320 deletions
diff --git a/python/google/protobuf/descriptor_pool.py b/python/google/protobuf/descriptor_pool.py index 5f43ee5f..28b7e843 100644 --- a/python/google/protobuf/descriptor_pool.py +++ b/python/google/protobuf/descriptor_pool.py @@ -57,6 +57,8 @@ directly instead of this class. __author__ = 'matthewtoia@google.com (Matt Toia)' +import collections + from google.protobuf import descriptor from google.protobuf import descriptor_database from google.protobuf import text_encoding @@ -88,6 +90,14 @@ def _OptionsOrNone(descriptor_proto): return None +def _IsMessageSetExtension(field): + return (field.is_extension and + field.containing_type.has_options and + field.containing_type.GetOptions().message_set_wire_format and + field.type == descriptor.FieldDescriptor.TYPE_MESSAGE and + field.label == descriptor.FieldDescriptor.LABEL_OPTIONAL) + + class DescriptorPool(object): """A collection of protobufs dynamically constructed by descriptor protos.""" @@ -115,6 +125,12 @@ class DescriptorPool(object): self._descriptors = {} self._enum_descriptors = {} self._file_descriptors = {} + self._toplevel_extensions = {} + # We store extensions in two two-level mappings: The first key is the + # descriptor of the message being extended, the second key is the extension + # full name or its tag number. + self._extensions_by_name = collections.defaultdict(dict) + self._extensions_by_number = collections.defaultdict(dict) def Add(self, file_desc_proto): """Adds the FileDescriptorProto and its types to this pool. @@ -170,6 +186,48 @@ class DescriptorPool(object): self._enum_descriptors[enum_desc.full_name] = enum_desc self.AddFileDescriptor(enum_desc.file) + def AddExtensionDescriptor(self, extension): + """Adds a FieldDescriptor describing an extension to the pool. + + Args: + extension: A FieldDescriptor. + + Raises: + AssertionError: when another extension with the same number extends the + same message. + TypeError: when the specified extension is not a + descriptor.FieldDescriptor. + """ + if not (isinstance(extension, descriptor.FieldDescriptor) and + extension.is_extension): + raise TypeError('Expected an extension descriptor.') + + if extension.extension_scope is None: + self._toplevel_extensions[extension.full_name] = extension + + try: + existing_desc = self._extensions_by_number[ + extension.containing_type][extension.number] + except KeyError: + pass + else: + if extension is not existing_desc: + raise AssertionError( + 'Extensions "%s" and "%s" both try to extend message type "%s" ' + 'with field number %d.' % + (extension.full_name, existing_desc.full_name, + extension.containing_type.full_name, extension.number)) + + self._extensions_by_number[extension.containing_type][ + extension.number] = extension + self._extensions_by_name[extension.containing_type][ + extension.full_name] = extension + + # Also register MessageSet extensions with the type name. + if _IsMessageSetExtension(extension): + self._extensions_by_name[extension.containing_type][ + extension.message_type.full_name] = extension + def AddFileDescriptor(self, file_desc): """Adds a FileDescriptor to the pool, non-recursively. @@ -302,6 +360,14 @@ class DescriptorPool(object): A FieldDescriptor, describing the named extension. """ full_name = _NormalizeFullyQualifiedName(full_name) + try: + # The proto compiler does not give any link between the FileDescriptor + # and top-level extensions unless the FileDescriptorProto is added to + # the DescriptorDatabase, but this can impact memory usage. + # So we registered these extensions by name explicitly. + return self._toplevel_extensions[full_name] + except KeyError: + pass message_name, _, extension_name = full_name.rpartition('.') try: # Most extensions are nested inside a message. @@ -311,6 +377,39 @@ class DescriptorPool(object): scope = self.FindFileContainingSymbol(full_name) return scope.extensions_by_name[extension_name] + def FindExtensionByNumber(self, message_descriptor, number): + """Gets the extension of the specified message with the specified number. + + Extensions have to be registered to this pool by calling + AddExtensionDescriptor. + + Args: + message_descriptor: descriptor of the extended message. + number: integer, number of the extension field. + + Returns: + A FieldDescriptor describing the extension. + + Raise: + KeyError: when no extension with the given number is known for the + specified message. + """ + return self._extensions_by_number[message_descriptor][number] + + def FindAllExtensions(self, message_descriptor): + """Gets all the known extension of a given message. + + Extensions have to be registered to this pool by calling + AddExtensionDescriptor. + + Args: + message_descriptor: descriptor of the extended message. + + Returns: + A list of FieldDescriptor describing the extensions. + """ + return self._extensions_by_number[message_descriptor].values() + def _ConvertFileProtoToFileDescriptor(self, file_proto): """Creates a FileDescriptor from a proto or returns a cached copy. diff --git a/python/google/protobuf/internal/decoder.py b/python/google/protobuf/internal/decoder.py index 31869e45..c5f73dc1 100755 --- a/python/google/protobuf/internal/decoder.py +++ b/python/google/protobuf/internal/decoder.py @@ -642,10 +642,10 @@ def MessageDecoder(field_number, is_repeated, is_packed, key, new_default): MESSAGE_SET_ITEM_TAG = encoder.TagBytes(1, wire_format.WIRETYPE_START_GROUP) -def MessageSetItemDecoder(extensions_by_number): +def MessageSetItemDecoder(descriptor): """Returns a decoder for a MessageSet item. - The parameter is the _extensions_by_number map for the message class. + The parameter is the message Descriptor. The message set message looks like this: message MessageSet { @@ -694,7 +694,7 @@ def MessageSetItemDecoder(extensions_by_number): if message_start == -1: raise _DecodeError('MessageSet item missing message.') - extension = extensions_by_number.get(type_id) + extension = message.Extensions._FindExtensionByNumber(type_id) if extension is not None: value = field_dict.get(extension) if value is None: diff --git a/python/google/protobuf/internal/descriptor_pool_test.py b/python/google/protobuf/internal/descriptor_pool_test.py index d4de2d81..cb6abe6c 100644 --- a/python/google/protobuf/internal/descriptor_pool_test.py +++ b/python/google/protobuf/internal/descriptor_pool_test.py @@ -254,6 +254,53 @@ class DescriptorPoolTest(unittest.TestCase): with self.assertRaises(KeyError): self.pool.FindFieldByName('Does not exist') + def testFindAllExtensions(self): + factory1_message = self.pool.FindMessageTypeByName( + 'google.protobuf.python.internal.Factory1Message') + factory2_message = self.pool.FindMessageTypeByName( + 'google.protobuf.python.internal.Factory2Message') + # An extension defined in a message. + one_more_field = factory2_message.extensions_by_name['one_more_field'] + self.pool.AddExtensionDescriptor(one_more_field) + # An extension defined at file scope. + factory_test2 = self.pool.FindFileByName( + 'google/protobuf/internal/factory_test2.proto') + another_field = factory_test2.extensions_by_name['another_field'] + self.pool.AddExtensionDescriptor(another_field) + + extensions = self.pool.FindAllExtensions(factory1_message) + expected_extension_numbers = {one_more_field, another_field} + self.assertEqual(expected_extension_numbers, set(extensions)) + # Verify that mutating the returned list does not affect the pool. + extensions.append('unexpected_element') + # Get the extensions again, the returned value does not contain the + # 'unexpected_element'. + extensions = self.pool.FindAllExtensions(factory1_message) + self.assertEqual(expected_extension_numbers, set(extensions)) + + def testFindExtensionByNumber(self): + factory1_message = self.pool.FindMessageTypeByName( + 'google.protobuf.python.internal.Factory1Message') + factory2_message = self.pool.FindMessageTypeByName( + 'google.protobuf.python.internal.Factory2Message') + # An extension defined in a message. + one_more_field = factory2_message.extensions_by_name['one_more_field'] + self.pool.AddExtensionDescriptor(one_more_field) + # An extension defined at file scope. + factory_test2 = self.pool.FindFileByName( + 'google/protobuf/internal/factory_test2.proto') + another_field = factory_test2.extensions_by_name['another_field'] + self.pool.AddExtensionDescriptor(another_field) + + # An extension defined in a message. + extension = self.pool.FindExtensionByNumber(factory1_message, 1001) + self.assertEqual(extension.name, 'one_more_field') + # An extension defined at file scope. + extension = self.pool.FindExtensionByNumber(factory1_message, 1002) + self.assertEqual(extension.name, 'another_field') + with self.assertRaises(KeyError): + extension = self.pool.FindExtensionByNumber(factory1_message, 1234567) + def testExtensionsAreNotFields(self): with self.assertRaises(KeyError): self.pool.FindFieldByName('google.protobuf.python.internal.another_field') diff --git a/python/google/protobuf/internal/message_factory_test.py b/python/google/protobuf/internal/message_factory_test.py index 7bb7d1ac..4caa2443 100644 --- a/python/google/protobuf/internal/message_factory_test.py +++ b/python/google/protobuf/internal/message_factory_test.py @@ -114,18 +114,18 @@ class MessageFactoryTest(unittest.TestCase): ).issubset(set(messages.keys()))) self._ExerciseDynamicClass( messages['google.protobuf.python.internal.Factory2Message']) - self.assertTrue( - set(['google.protobuf.python.internal.Factory2Message.one_more_field', - 'google.protobuf.python.internal.another_field'], - ).issubset( - set(messages['google.protobuf.python.internal.Factory1Message'] - ._extensions_by_name.keys()))) factory_msg1 = messages['google.protobuf.python.internal.Factory1Message'] + self.assertTrue(set( + ['google.protobuf.python.internal.Factory2Message.one_more_field', + 'google.protobuf.python.internal.another_field'],).issubset(set( + ext.full_name + for ext in factory_msg1.DESCRIPTOR.file.pool.FindAllExtensions( + factory_msg1.DESCRIPTOR)))) msg1 = messages['google.protobuf.python.internal.Factory1Message']() - ext1 = factory_msg1._extensions_by_name[ - 'google.protobuf.python.internal.Factory2Message.one_more_field'] - ext2 = factory_msg1._extensions_by_name[ - 'google.protobuf.python.internal.another_field'] + ext1 = msg1.Extensions._FindExtensionByName( + 'google.protobuf.python.internal.Factory2Message.one_more_field') + ext2 = msg1.Extensions._FindExtensionByName( + 'google.protobuf.python.internal.another_field') msg1.Extensions[ext1] = 'test1' msg1.Extensions[ext2] = 'test2' self.assertEqual('test1', msg1.Extensions[ext1]) diff --git a/python/google/protobuf/internal/python_message.py b/python/google/protobuf/internal/python_message.py index dc6565d4..c1bd1f9c 100755 --- a/python/google/protobuf/internal/python_message.py +++ b/python/google/protobuf/internal/python_message.py @@ -51,8 +51,8 @@ this file*. __author__ = 'robinson@google.com (Will Robinson)' from io import BytesIO -import sys import struct +import sys import weakref import six @@ -162,12 +162,10 @@ class GeneratedProtocolMessageType(type): """ descriptor = dictionary[GeneratedProtocolMessageType._DESCRIPTOR_KEY] cls._decoders_by_tag = {} - cls._extensions_by_name = {} - cls._extensions_by_number = {} if (descriptor.has_options and descriptor.GetOptions().message_set_wire_format): cls._decoders_by_tag[decoder.MESSAGE_SET_ITEM_TAG] = ( - decoder.MessageSetItemDecoder(cls._extensions_by_number), None) + decoder.MessageSetItemDecoder(descriptor), None) # Attach stuff to each FieldDescriptor for quick lookup later on. for field in descriptor.fields: @@ -747,32 +745,21 @@ def _AddPropertiesForExtensions(descriptor, cls): constant_name = extension_name.upper() + "_FIELD_NUMBER" setattr(cls, constant_name, extension_field.number) + # TODO(amauryfa): Migrate all users of these attributes to functions like + # pool.FindExtensionByNumber(descriptor). + if descriptor.file is not None: + # TODO(amauryfa): Use cls.MESSAGE_FACTORY.pool when available. + pool = descriptor.file.pool + cls._extensions_by_number = pool._extensions_by_number[descriptor] + cls._extensions_by_name = pool._extensions_by_name[descriptor] def _AddStaticMethods(cls): # TODO(robinson): This probably needs to be thread-safe(?) def RegisterExtension(extension_handle): extension_handle.containing_type = cls.DESCRIPTOR + # TODO(amauryfa): Use cls.MESSAGE_FACTORY.pool when available. + cls.DESCRIPTOR.file.pool.AddExtensionDescriptor(extension_handle) _AttachFieldHelpers(cls, extension_handle) - - # Try to insert our extension, failing if an extension with the same number - # already exists. - actual_handle = cls._extensions_by_number.setdefault( - extension_handle.number, extension_handle) - if actual_handle is not extension_handle: - raise AssertionError( - 'Extensions "%s" and "%s" both try to extend message type "%s" with ' - 'field number %d.' % - (extension_handle.full_name, actual_handle.full_name, - cls.DESCRIPTOR.full_name, extension_handle.number)) - - cls._extensions_by_name[extension_handle.full_name] = extension_handle - - handle = extension_handle # avoid line wrapping - if _IsMessageSetExtension(handle): - # MessageSet extension. Also register under type name. - cls._extensions_by_name[ - extension_handle.message_type.full_name] = extension_handle - cls.RegisterExtension = staticmethod(RegisterExtension) def FromString(s): @@ -1230,7 +1217,7 @@ def _AddMergeFromMethod(cls): if not isinstance(msg, cls): raise TypeError( "Parameter to MergeFrom() must be instance of same class: " - "expected %s got %s." % (cls.__name__, type(msg).__name__)) + 'expected %s got %s.' % (cls.__name__, msg.__class__.__name__)) assert msg is not self self._Modified() diff --git a/python/google/protobuf/internal/reflection_test.py b/python/google/protobuf/internal/reflection_test.py index dad79c37..0e881015 100755 --- a/python/google/protobuf/internal/reflection_test.py +++ b/python/google/protobuf/internal/reflection_test.py @@ -99,12 +99,12 @@ class _MiniDecoder(object): return wire_format.UnpackTag(self.ReadVarint()) def ReadFloat(self): - result = struct.unpack("<f", self._bytes[self._pos:self._pos+4])[0] + result = struct.unpack('<f', self._bytes[self._pos:self._pos+4])[0] self._pos += 4 return result def ReadDouble(self): - result = struct.unpack("<d", self._bytes[self._pos:self._pos+8])[0] + result = struct.unpack('<d', self._bytes[self._pos:self._pos+8])[0] self._pos += 8 return result @@ -621,9 +621,15 @@ class ReflectionTest(BaseTestCase): self.assertRaises(TypeError, setattr, proto, 'optional_string', 10) self.assertRaises(TypeError, setattr, proto, 'optional_bytes', 10) - def testIntegerTypes(self): + def assertIntegerTypes(self, integer_fn): + """Verifies setting of scalar integers. + + Args: + integer_fn: A function to wrap the integers that will be assigned. + """ def TestGetAndDeserialize(field_name, value, expected_type): proto = unittest_pb2.TestAllTypes() + value = integer_fn(value) setattr(proto, field_name, value) self.assertIsInstance(getattr(proto, field_name), expected_type) proto2 = unittest_pb2.TestAllTypes() @@ -635,7 +641,7 @@ class ReflectionTest(BaseTestCase): TestGetAndDeserialize('optional_uint32', 1 << 30, int) try: integer_64 = long - except NameError: # Python3 + except NameError: # Python3 integer_64 = int if struct.calcsize('L') == 4: # Python only has signed ints, so 32-bit python can't fit an uint32 @@ -649,9 +655,33 @@ class ReflectionTest(BaseTestCase): TestGetAndDeserialize('optional_uint64', 1 << 30, integer_64) TestGetAndDeserialize('optional_uint64', 1 << 60, integer_64) - def testSingleScalarBoundsChecking(self): + def testIntegerTypes(self): + self.assertIntegerTypes(lambda x: x) + + def testNonStandardIntegerTypes(self): + self.assertIntegerTypes(test_util.NonStandardInteger) + + def testIllegalValuesForIntegers(self): + pb = unittest_pb2.TestAllTypes() + + # Strings are illegal, even when the represent an integer. + with self.assertRaises(TypeError): + pb.optional_uint64 = '2' + + # The exact error should propagate with a poorly written custom integer. + with self.assertRaisesRegexp(RuntimeError, 'my_error'): + pb.optional_uint64 = test_util.NonStandardInteger(5, 'my_error') + + def assetIntegerBoundsChecking(self, integer_fn): + """Verifies bounds checking for scalar integer fields. + + Args: + integer_fn: A function to wrap the integers that will be assigned. + """ def TestMinAndMaxIntegers(field_name, expected_min, expected_max): pb = unittest_pb2.TestAllTypes() + expected_min = integer_fn(expected_min) + expected_max = integer_fn(expected_max) setattr(pb, field_name, expected_min) self.assertEqual(expected_min, getattr(pb, field_name)) setattr(pb, field_name, expected_max) @@ -663,11 +693,22 @@ class ReflectionTest(BaseTestCase): TestMinAndMaxIntegers('optional_uint32', 0, 0xffffffff) TestMinAndMaxIntegers('optional_int64', -(1 << 63), (1 << 63) - 1) TestMinAndMaxIntegers('optional_uint64', 0, 0xffffffffffffffff) + # A bit of white-box testing since -1 is an int and not a long in C++ and + # so goes down a different path. + pb = unittest_pb2.TestAllTypes() + with self.assertRaises(ValueError): + pb.optional_uint64 = integer_fn(-(1 << 63)) pb = unittest_pb2.TestAllTypes() - pb.optional_nested_enum = 1 + pb.optional_nested_enum = integer_fn(1) self.assertEqual(1, pb.optional_nested_enum) + def testSingleScalarBoundsChecking(self): + self.assetIntegerBoundsChecking(lambda x: x) + + def testNonStandardSingleScalarBoundsChecking(self): + self.assetIntegerBoundsChecking(test_util.NonStandardInteger) + def testRepeatedScalarTypeSafety(self): proto = unittest_pb2.TestAllTypes() self.assertRaises(TypeError, proto.repeated_int32.append, 1.1) @@ -1187,12 +1228,18 @@ class ReflectionTest(BaseTestCase): self.assertTrue(not extendee_proto.HasExtension(extension)) def testRegisteredExtensions(self): - self.assertTrue('protobuf_unittest.optional_int32_extension' in - unittest_pb2.TestAllExtensions._extensions_by_name) - self.assertTrue(1 in unittest_pb2.TestAllExtensions._extensions_by_number) + pool = unittest_pb2.DESCRIPTOR.pool + self.assertTrue( + pool.FindExtensionByNumber( + unittest_pb2.TestAllExtensions.DESCRIPTOR, 1)) + self.assertIs( + pool.FindExtensionByName( + 'protobuf_unittest.optional_int32_extension').containing_type, + unittest_pb2.TestAllExtensions.DESCRIPTOR) # Make sure extensions haven't been registered into types that shouldn't # have any. - self.assertEqual(0, len(unittest_pb2.TestAllTypes._extensions_by_name)) + self.assertEqual(0, len( + pool.FindAllExtensions(unittest_pb2.TestAllTypes.DESCRIPTOR))) # If message A directly contains message B, and # a.HasField('b') is currently False, then mutating any diff --git a/python/google/protobuf/internal/test_util.py b/python/google/protobuf/internal/test_util.py index 2c805599..269d0e2d 100755 --- a/python/google/protobuf/internal/test_util.py +++ b/python/google/protobuf/internal/test_util.py @@ -36,8 +36,9 @@ This is intentionally modeled on C++ code in __author__ = 'robinson@google.com (Will Robinson)' +import numbers +import operator import os.path - import sys from google.protobuf import unittest_import_pb2 @@ -694,3 +695,154 @@ def SetAllUnpackedFields(message): message.unpacked_bool.extend([True, False]) message.unpacked_enum.extend([unittest_pb2.FOREIGN_BAR, unittest_pb2.FOREIGN_BAZ]) + + +class NonStandardInteger(numbers.Integral): + """An integer object that does not subclass int. + + This is used to verify that both C++ and regular proto systems can handle + integer others than int and long and that they handle them in predictable + ways. + + NonStandardInteger is the minimal legal specification for a custom Integral. + As such, it does not support 0 < x < 5 and it is not hashable. + + Note: This is added here instead of relying on numpy or a similar library with + custom integers to limit dependencies. + """ + + def __init__(self, val, error_string_on_conversion=None): + assert isinstance(val, numbers.Integral) + if isinstance(val, NonStandardInteger): + val = val.val + self.val = val + self.error_string_on_conversion = error_string_on_conversion + + def __long__(self): + if self.error_string_on_conversion: + raise RuntimeError(self.error_string_on_conversion) + return long(self.val) + + def __abs__(self): + return NonStandardInteger(operator.abs(self.val)) + + def __add__(self, y): + return NonStandardInteger(operator.add(self.val, y)) + + def __div__(self, y): + return NonStandardInteger(operator.div(self.val, y)) + + def __eq__(self, y): + return operator.eq(self.val, y) + + def __floordiv__(self, y): + return NonStandardInteger(operator.floordiv(self.val, y)) + + def __truediv__(self, y): + return NonStandardInteger(operator.truediv(self.val, y)) + + def __invert__(self): + return NonStandardInteger(operator.invert(self.val)) + + def __mod__(self, y): + return NonStandardInteger(operator.mod(self.val, y)) + + def __mul__(self, y): + return NonStandardInteger(operator.mul(self.val, y)) + + def __neg__(self): + return NonStandardInteger(operator.neg(self.val)) + + def __pos__(self): + return NonStandardInteger(operator.pos(self.val)) + + def __pow__(self, y): + return NonStandardInteger(operator.pow(self.val, y)) + + def __trunc__(self): + return int(self.val) + + def __radd__(self, y): + return NonStandardInteger(operator.add(y, self.val)) + + def __rdiv__(self, y): + return NonStandardInteger(operator.div(y, self.val)) + + def __rmod__(self, y): + return NonStandardInteger(operator.mod(y, self.val)) + + def __rmul__(self, y): + return NonStandardInteger(operator.mul(y, self.val)) + + def __rpow__(self, y): + return NonStandardInteger(operator.pow(y, self.val)) + + def __rfloordiv__(self, y): + return NonStandardInteger(operator.floordiv(y, self.val)) + + def __rtruediv__(self, y): + return NonStandardInteger(operator.truediv(y, self.val)) + + def __lshift__(self, y): + return NonStandardInteger(operator.lshift(self.val, y)) + + def __rshift__(self, y): + return NonStandardInteger(operator.rshift(self.val, y)) + + def __rlshift__(self, y): + return NonStandardInteger(operator.lshift(y, self.val)) + + def __rrshift__(self, y): + return NonStandardInteger(operator.rshift(y, self.val)) + + def __le__(self, y): + if isinstance(y, NonStandardInteger): + y = y.val + return operator.le(self.val, y) + + def __lt__(self, y): + if isinstance(y, NonStandardInteger): + y = y.val + return operator.lt(self.val, y) + + def __and__(self, y): + return NonStandardInteger(operator.and_(self.val, y)) + + def __or__(self, y): + return NonStandardInteger(operator.or_(self.val, y)) + + def __xor__(self, y): + return NonStandardInteger(operator.xor(self.val, y)) + + def __rand__(self, y): + return NonStandardInteger(operator.and_(y, self.val)) + + def __ror__(self, y): + return NonStandardInteger(operator.or_(y, self.val)) + + def __rxor__(self, y): + return NonStandardInteger(operator.xor(y, self.val)) + + def __bool__(self): + return self.val + + def __nonzero__(self): + return self.val + + def __ceil__(self): + return self + + def __floor__(self): + return self + + def __int__(self): + if self.error_string_on_conversion: + raise RuntimeError(self.error_string_on_conversion) + return int(self.val) + + def __round__(self): + return self + + def __repr__(self): + return 'NonStandardInteger(%s)' % self.val + diff --git a/python/google/protobuf/internal/testing_refleaks.py b/python/google/protobuf/internal/testing_refleaks.py index c461a9f4..8ce06519 100644 --- a/python/google/protobuf/internal/testing_refleaks.py +++ b/python/google/protobuf/internal/testing_refleaks.py @@ -124,5 +124,3 @@ else: def Same(func): return func return Same - - diff --git a/python/google/protobuf/internal/text_format_test.py b/python/google/protobuf/internal/text_format_test.py index ab481ab4..176cbd15 100755 --- a/python/google/protobuf/internal/text_format_test.py +++ b/python/google/protobuf/internal/text_format_test.py @@ -582,22 +582,20 @@ class OnlyWorksWithProto2RightNowTests(TextFormatBase): % (letter,) for letter in string.ascii_uppercase)) self.CompareToGoldenText(text_format.MessageToString(message), golden) - def testMapOrderSemantics(self): - golden_lines = self.ReadGolden('map_test_data.txt') - # The C++ implementation emits defaulted-value fields, while the Python - # implementation does not. Adjusting for this is awkward, but it is - # valuable to test against a common golden file. - line_blacklist = (' key: 0\n', ' value: 0\n', ' key: false\n', - ' value: false\n') - golden_lines = [line for line in golden_lines if line not in line_blacklist] - - message = map_unittest_pb2.TestMap() - text_format.ParseLines(golden_lines, message) - candidate = text_format.MessageToString(message) - # The Python implementation emits "1.0" for the double value that the C++ - # implementation emits as "1". - candidate = candidate.replace('1.0', '1', 2) - self.assertMultiLineEqual(candidate, ''.join(golden_lines)) + # TODO(teboring): In c/137553523, not serializing default value for map entry + # message has been fixed. This test needs to be disabled in order to submit + # that cl. Add this back when c/137553523 has been submitted. + # def testMapOrderSemantics(self): + # golden_lines = self.ReadGolden('map_test_data.txt') + + # message = map_unittest_pb2.TestMap() + # text_format.ParseLines(golden_lines, message) + # candidate = text_format.MessageToString(message) + # # The Python implementation emits "1.0" for the double value that the C++ + # # implementation emits as "1". + # candidate = candidate.replace('1.0', '1', 2) + # candidate = candidate.replace('0.0', '0', 2) + # self.assertMultiLineEqual(candidate, ''.join(golden_lines)) # Tests of proto2-only features (MessageSet, extensions, etc.). diff --git a/python/google/protobuf/internal/type_checkers.py b/python/google/protobuf/internal/type_checkers.py index 1be3ad9a..4a76cd4e 100755 --- a/python/google/protobuf/internal/type_checkers.py +++ b/python/google/protobuf/internal/type_checkers.py @@ -45,6 +45,7 @@ TYPE_TO_DESERIALIZE_METHOD: A dictionary with field types and deserialization __author__ = 'robinson@google.com (Will Robinson)' +import numbers import six if six.PY3: @@ -126,11 +127,11 @@ class IntValueChecker(object): """Checker used for integer fields. Performs type-check and range check.""" def CheckValue(self, proposed_value): - if not isinstance(proposed_value, six.integer_types): + if not isinstance(proposed_value, numbers.Integral): message = ('%.1024r has type %s, but expected one of: %s' % (proposed_value, type(proposed_value), six.integer_types)) raise TypeError(message) - if not self._MIN <= proposed_value <= self._MAX: + if not self._MIN <= int(proposed_value) <= self._MAX: raise ValueError('Value out of range: %d' % proposed_value) # We force 32-bit values to int and 64-bit values to long to make # alternate implementations where the distinction is more significant @@ -150,11 +151,11 @@ class EnumValueChecker(object): self._enum_type = enum_type def CheckValue(self, proposed_value): - if not isinstance(proposed_value, six.integer_types): + if not isinstance(proposed_value, numbers.Integral): message = ('%.1024r has type %s, but expected one of: %s' % (proposed_value, type(proposed_value), six.integer_types)) raise TypeError(message) - if proposed_value not in self._enum_type.values_by_number: + if int(proposed_value) not in self._enum_type.values_by_number: raise ValueError('Unknown enum value: %d' % proposed_value) return proposed_value @@ -223,11 +224,11 @@ _VALUE_CHECKERS = { _FieldDescriptor.CPPTYPE_UINT32: Uint32ValueChecker(), _FieldDescriptor.CPPTYPE_UINT64: Uint64ValueChecker(), _FieldDescriptor.CPPTYPE_DOUBLE: TypeCheckerWithDefault( - 0.0, float, int, long), + 0.0, numbers.Real), _FieldDescriptor.CPPTYPE_FLOAT: TypeCheckerWithDefault( - 0.0, float, int, long), + 0.0, numbers.Real), _FieldDescriptor.CPPTYPE_BOOL: TypeCheckerWithDefault( - False, bool, int), + False, bool, numbers.Integral), _FieldDescriptor.CPPTYPE_STRING: TypeCheckerWithDefault(b'', bytes), } diff --git a/python/google/protobuf/json_format.py b/python/google/protobuf/json_format.py index c42371d0..d02cb091 100644 --- a/python/google/protobuf/json_format.py +++ b/python/google/protobuf/json_format.py @@ -43,9 +43,9 @@ Simple usage example: __author__ = 'jieluo@google.com (Jie Luo)' try: - from collections import OrderedDict + from collections import OrderedDict except ImportError: - from ordereddict import OrderedDict #PY26 + from ordereddict import OrderedDict #PY26 import base64 import json import math diff --git a/python/google/protobuf/pyext/descriptor_pool.cc b/python/google/protobuf/pyext/descriptor_pool.cc index a42e5431..fa66bf9a 100644 --- a/python/google/protobuf/pyext/descriptor_pool.cc +++ b/python/google/protobuf/pyext/descriptor_pool.cc @@ -319,6 +319,51 @@ PyObject* FindFileContainingSymbol(PyDescriptorPool* self, PyObject* arg) { return PyFileDescriptor_FromDescriptor(file_descriptor); } +PyObject* FindExtensionByNumber(PyDescriptorPool* self, PyObject* args) { + PyObject* message_descriptor; + int number; + if (!PyArg_ParseTuple(args, "Oi", &message_descriptor, &number)) { + return NULL; + } + const Descriptor* descriptor = PyMessageDescriptor_AsDescriptor( + message_descriptor); + if (descriptor == NULL) { + return NULL; + } + + const FieldDescriptor* extension_descriptor = + self->pool->FindExtensionByNumber(descriptor, number); + if (extension_descriptor == NULL) { + PyErr_Format(PyExc_KeyError, "Couldn't find extension %d", number); + return NULL; + } + + return PyFieldDescriptor_FromDescriptor(extension_descriptor); +} + +PyObject* FindAllExtensions(PyDescriptorPool* self, PyObject* arg) { + const Descriptor* descriptor = PyMessageDescriptor_AsDescriptor(arg); + if (descriptor == NULL) { + return NULL; + } + + std::vector<const FieldDescriptor*> extensions; + self->pool->FindAllExtensions(descriptor, &extensions); + + ScopedPyObjectPtr result(PyList_New(extensions.size())); + if (result == NULL) { + return NULL; + } + for (int i = 0; i < extensions.size(); i++) { + PyObject* extension = PyFieldDescriptor_FromDescriptor(extensions[i]); + if (extension == NULL) { + return NULL; + } + PyList_SET_ITEM(result.get(), i, extension); // Steals the reference. + } + return result.release(); +} + // These functions should not exist -- the only valid way to create // descriptors is to call Add() or AddSerializedFile(). // But these AddDescriptor() functions were created in Python and some people @@ -376,6 +421,22 @@ PyObject* AddEnumDescriptor(PyDescriptorPool* self, PyObject* descriptor) { Py_RETURN_NONE; } +PyObject* AddExtensionDescriptor(PyDescriptorPool* self, PyObject* descriptor) { + const FieldDescriptor* extension_descriptor = + PyFieldDescriptor_AsDescriptor(descriptor); + if (!extension_descriptor) { + return NULL; + } + if (extension_descriptor != + self->pool->FindExtensionByName(extension_descriptor->full_name())) { + PyErr_Format(PyExc_ValueError, + "The extension descriptor %s does not belong to this pool", + extension_descriptor->full_name().c_str()); + return NULL; + } + Py_RETURN_NONE; +} + // The code below loads new Descriptors from a serialized FileDescriptorProto. @@ -475,6 +536,8 @@ static PyMethodDef Methods[] = { "No-op. Add() must have been called before." }, { "AddEnumDescriptor", (PyCFunction)AddEnumDescriptor, METH_O, "No-op. Add() must have been called before." }, + { "AddExtensionDescriptor", (PyCFunction)AddExtensionDescriptor, METH_O, + "No-op. Add() must have been called before." }, { "FindFileByName", (PyCFunction)FindFileByName, METH_O, "Searches for a file descriptor by its .proto name." }, @@ -495,6 +558,10 @@ static PyMethodDef Methods[] = { { "FindFileContainingSymbol", (PyCFunction)FindFileContainingSymbol, METH_O, "Gets the FileDescriptor containing the specified symbol." }, + { "FindExtensionByNumber", (PyCFunction)FindExtensionByNumber, METH_VARARGS, + "Gets the extension descriptor for the given number." }, + { "FindAllExtensions", (PyCFunction)FindAllExtensions, METH_O, + "Gets all known extensions of the given message descriptor." }, {NULL} }; diff --git a/python/google/protobuf/pyext/extension_dict.cc b/python/google/protobuf/pyext/extension_dict.cc index dbb7bca0..9423c1d8 100644 --- a/python/google/protobuf/pyext/extension_dict.cc +++ b/python/google/protobuf/pyext/extension_dict.cc @@ -38,6 +38,7 @@ #include <google/protobuf/descriptor.h> #include <google/protobuf/dynamic_message.h> #include <google/protobuf/message.h> +#include <google/protobuf/descriptor.pb.h> #include <google/protobuf/pyext/descriptor.h> #include <google/protobuf/pyext/message.h> #include <google/protobuf/pyext/message_factory.h> @@ -46,6 +47,16 @@ #include <google/protobuf/pyext/scoped_pyobject_ptr.h> #include <google/protobuf/stubs/shared_ptr.h> +#if PY_MAJOR_VERSION >= 3 + #if PY_VERSION_HEX < 0x03030000 + #error "Python 3.0 - 3.2 are not supported." + #endif + #define PyString_AsStringAndSize(ob, charpp, sizep) \ + (PyUnicode_Check(ob)? \ + ((*(charpp) = PyUnicode_AsUTF8AndSize(ob, (sizep))) == NULL? -1: 0): \ + PyBytes_AsStringAndSize(ob, (charpp), (sizep))) +#endif + namespace google { namespace protobuf { namespace python { @@ -90,6 +101,7 @@ PyObject* subscript(ExtensionDict* self, PyObject* key) { if (descriptor->label() != FieldDescriptor::LABEL_REPEATED && descriptor->cpp_type() == FieldDescriptor::CPPTYPE_MESSAGE) { + // TODO(plabatut): consider building the class on the fly! PyObject* sub_message = cmessage::InternalGetSubMessage( self->parent, descriptor); if (sub_message == NULL) { @@ -101,7 +113,17 @@ PyObject* subscript(ExtensionDict* self, PyObject* key) { if (descriptor->label() == FieldDescriptor::LABEL_REPEATED) { if (descriptor->cpp_type() == FieldDescriptor::CPPTYPE_MESSAGE) { - CMessageClass* message_class = message_factory::GetMessageClass( + // On the fly message class creation is needed to support the following + // situation: + // 1- add FileDescriptor to the pool that contains extensions of a message + // defined by another proto file. Do not create any message classes. + // 2- instantiate an extended message, and access the extension using + // the field descriptor. + // 3- the extension submessage fails to be returned, because no class has + // been created. + // It happens when deserializing text proto format, or when enumerating + // fields of a deserialized message. + CMessageClass* message_class = message_factory::GetOrCreateMessageClass( cmessage::GetFactoryForMessage(self->parent), descriptor->message_type()); if (message_class == NULL) { @@ -154,34 +176,51 @@ int ass_subscript(ExtensionDict* self, PyObject* key, PyObject* value) { return 0; } -PyObject* _FindExtensionByName(ExtensionDict* self, PyObject* name) { - ScopedPyObjectPtr extensions_by_name(PyObject_GetAttrString( - reinterpret_cast<PyObject*>(self->parent), "_extensions_by_name")); - if (extensions_by_name == NULL) { +PyObject* _FindExtensionByName(ExtensionDict* self, PyObject* arg) { + char* name; + Py_ssize_t name_size; + if (PyString_AsStringAndSize(arg, &name, &name_size) < 0) { return NULL; } - PyObject* result = PyDict_GetItem(extensions_by_name.get(), name); - if (result == NULL) { + + PyDescriptorPool* pool = cmessage::GetFactoryForMessage(self->parent)->pool; + const FieldDescriptor* message_extension = + pool->pool->FindExtensionByName(string(name, name_size)); + if (message_extension == NULL) { + // Is is the name of a message set extension? + const Descriptor* message_descriptor = pool->pool->FindMessageTypeByName( + string(name, name_size)); + if (message_descriptor && message_descriptor->extension_count() > 0) { + const FieldDescriptor* extension = message_descriptor->extension(0); + if (extension->is_extension() && + extension->containing_type()->options().message_set_wire_format() && + extension->type() == FieldDescriptor::TYPE_MESSAGE && + extension->label() == FieldDescriptor::LABEL_OPTIONAL) { + message_extension = extension; + } + } + } + if (message_extension == NULL) { Py_RETURN_NONE; - } else { - Py_INCREF(result); - return result; } + + return PyFieldDescriptor_FromDescriptor(message_extension); } -PyObject* _FindExtensionByNumber(ExtensionDict* self, PyObject* number) { - ScopedPyObjectPtr extensions_by_number(PyObject_GetAttrString( - reinterpret_cast<PyObject*>(self->parent), "_extensions_by_number")); - if (extensions_by_number == NULL) { +PyObject* _FindExtensionByNumber(ExtensionDict* self, PyObject* arg) { + int64 number = PyLong_AsLong(arg); + if (number == -1 && PyErr_Occurred()) { return NULL; } - PyObject* result = PyDict_GetItem(extensions_by_number.get(), number); - if (result == NULL) { + + PyDescriptorPool* pool = cmessage::GetFactoryForMessage(self->parent)->pool; + const FieldDescriptor* message_extension = pool->pool->FindExtensionByNumber( + self->parent->message->GetDescriptor(), number); + if (message_extension == NULL) { Py_RETURN_NONE; - } else { - Py_INCREF(result); - return result; } + + return PyFieldDescriptor_FromDescriptor(message_extension); } ExtensionDict* NewExtensionDict(CMessage *parent) { diff --git a/python/google/protobuf/pyext/map_container.cc b/python/google/protobuf/pyext/map_container.cc index 318c2e7c..088ddf93 100644 --- a/python/google/protobuf/pyext/map_container.cc +++ b/python/google/protobuf/pyext/map_container.cc @@ -374,7 +374,7 @@ static int InitializeAndCopyToParentContainer(MapContainer* from, // A somewhat roundabout way of copying just one field from old_message to // new_message. This is the best we can do with what Reflection gives us. Message* mutable_old = from->GetMutableMessage(); - vector<const FieldDescriptor*> fields; + std::vector<const FieldDescriptor*> fields; fields.push_back(from->parent_field_descriptor); // Move the map field into the new message. diff --git a/python/google/protobuf/pyext/message.cc b/python/google/protobuf/pyext/message.cc index 7ff99aea..5967a587 100644 --- a/python/google/protobuf/pyext/message.cc +++ b/python/google/protobuf/pyext/message.cc @@ -64,11 +64,11 @@ #include <google/protobuf/pyext/repeated_scalar_container.h> #include <google/protobuf/pyext/map_container.h> #include <google/protobuf/pyext/message_factory.h> +#include <google/protobuf/pyext/safe_numerics.h> #include <google/protobuf/pyext/scoped_pyobject_ptr.h> #include <google/protobuf/stubs/strutil.h> #if PY_MAJOR_VERSION >= 3 - #define PyInt_Check PyLong_Check #define PyInt_AsLong PyLong_AsLong #define PyInt_FromLong PyLong_FromLong #define PyInt_FromSize_t PyLong_FromSize_t @@ -92,8 +92,6 @@ namespace protobuf { namespace python { static PyObject* kDESCRIPTOR; -static PyObject* k_extensions_by_name; -static PyObject* k_extensions_by_number; PyObject* EnumTypeWrapper_class; static PyObject* PythonMessage_class; static PyObject* kEmptyWeakref; @@ -128,19 +126,6 @@ static bool AddFieldNumberToClass( // Finalize the creation of the Message class. static int AddDescriptors(PyObject* cls, const Descriptor* descriptor) { - // If there are extension_ranges, the message is "extendable", and extension - // classes will register themselves in this class. - if (descriptor->extension_range_count() > 0) { - ScopedPyObjectPtr by_name(PyDict_New()); - if (PyObject_SetAttr(cls, k_extensions_by_name, by_name.get()) < 0) { - return -1; - } - ScopedPyObjectPtr by_number(PyDict_New()); - if (PyObject_SetAttr(cls, k_extensions_by_number, by_number.get()) < 0) { - return -1; - } - } - // For each field set: cls.<field>_FIELD_NUMBER = <number> for (int i = 0; i < descriptor->field_count(); ++i) { if (!AddFieldNumberToClass(cls, descriptor->field(i))) { @@ -357,6 +342,61 @@ static int InsertEmptyWeakref(PyTypeObject *base_type) { #endif // PY_MAJOR_VERSION >= 3 } +// The _extensions_by_name dictionary is built on every access. +// TODO(amauryfa): Migrate all users to pool.FindAllExtensions() +static PyObject* GetExtensionsByName(CMessageClass *self, void *closure) { + const PyDescriptorPool* pool = self->py_message_factory->pool; + + std::vector<const FieldDescriptor*> extensions; + pool->pool->FindAllExtensions(self->message_descriptor, &extensions); + + ScopedPyObjectPtr result(PyDict_New()); + for (int i = 0; i < extensions.size(); i++) { + ScopedPyObjectPtr extension( + PyFieldDescriptor_FromDescriptor(extensions[i])); + if (extension == NULL) { + return NULL; + } + if (PyDict_SetItemString(result.get(), extensions[i]->full_name().c_str(), + extension.get()) < 0) { + return NULL; + } + } + return result.release(); +} + +// The _extensions_by_number dictionary is built on every access. +// TODO(amauryfa): Migrate all users to pool.FindExtensionByNumber() +static PyObject* GetExtensionsByNumber(CMessageClass *self, void *closure) { + const PyDescriptorPool* pool = self->py_message_factory->pool; + + std::vector<const FieldDescriptor*> extensions; + pool->pool->FindAllExtensions(self->message_descriptor, &extensions); + + ScopedPyObjectPtr result(PyDict_New()); + for (int i = 0; i < extensions.size(); i++) { + ScopedPyObjectPtr extension( + PyFieldDescriptor_FromDescriptor(extensions[i])); + if (extension == NULL) { + return NULL; + } + ScopedPyObjectPtr number(PyInt_FromLong(extensions[i]->number())); + if (number == NULL) { + return NULL; + } + if (PyDict_SetItem(result.get(), number.get(), extension.get()) < 0) { + return NULL; + } + } + return result.release(); +} + +static PyGetSetDef Getters[] = { + {"_extensions_by_name", (getter)GetExtensionsByName, NULL}, + {"_extensions_by_number", (getter)GetExtensionsByNumber, NULL}, + {NULL} +}; + } // namespace message_meta PyTypeObject CMessageClass_Type = { @@ -389,7 +429,7 @@ PyTypeObject CMessageClass_Type = { 0, // tp_iternext 0, // tp_methods 0, // tp_members - 0, // tp_getset + message_meta::Getters, // tp_getset 0, // tp_base 0, // tp_dict 0, // tp_descr_get @@ -525,23 +565,10 @@ int ForEachCompositeField(CMessage* self, Visitor visitor) { // --------------------------------------------------------------------- -// Constants used for integer type range checking. -PyObject* kPythonZero; -PyObject* kint32min_py; -PyObject* kint32max_py; -PyObject* kuint32max_py; -PyObject* kint64min_py; -PyObject* kint64max_py; -PyObject* kuint64max_py; - PyObject* EncodeError_class; PyObject* DecodeError_class; PyObject* PickleError_class; -// Constant PyString values used for GetAttr/GetItem. -static PyObject* k_cdescriptor; -static PyObject* kfull_name; - /* Is 64bit */ void FormatTypeError(PyObject* arg, char* expected_types) { PyObject* repr = PyObject_Repr(arg); @@ -555,68 +582,126 @@ void FormatTypeError(PyObject* arg, char* expected_types) { } } -template<class T> -bool CheckAndGetInteger( - PyObject* arg, T* value, PyObject* min, PyObject* max) { - bool is_long = PyLong_Check(arg); -#if PY_MAJOR_VERSION < 3 - if (!PyInt_Check(arg) && !is_long) { - FormatTypeError(arg, "int, long"); - return false; +void OutOfRangeError(PyObject* arg) { + PyObject *s = PyObject_Str(arg); + if (s) { + PyErr_Format(PyExc_ValueError, + "Value out of range: %s", + PyString_AsString(s)); + Py_DECREF(s); } - if (PyObject_Compare(min, arg) > 0 || PyObject_Compare(max, arg) < 0) { -#else - if (!is_long) { - FormatTypeError(arg, "int"); +} + +template<class RangeType, class ValueType> +bool VerifyIntegerCastAndRange(PyObject* arg, ValueType value) { + if GOOGLE_PREDICT_FALSE(value == -1 && PyErr_Occurred()) { + if (PyErr_ExceptionMatches(PyExc_OverflowError)) { + // Replace it with the same ValueError as pure python protos instead of + // the default one. + PyErr_Clear(); + OutOfRangeError(arg); + } // Otherwise propagate existing error. return false; } - if (PyObject_RichCompareBool(min, arg, Py_LE) != 1 || - PyObject_RichCompareBool(max, arg, Py_GE) != 1) { -#endif - if (!PyErr_Occurred()) { - PyObject *s = PyObject_Str(arg); - if (s) { - PyErr_Format(PyExc_ValueError, - "Value out of range: %s", - PyString_AsString(s)); - Py_DECREF(s); - } - } + if GOOGLE_PREDICT_FALSE(!IsValidNumericCast<RangeType>(value)) { + OutOfRangeError(arg); return false; } + return true; +} + +template<class T> +bool CheckAndGetInteger(PyObject* arg, T* value) { + // The fast path. #if PY_MAJOR_VERSION < 3 - if (!is_long) { - *value = static_cast<T>(PyInt_AsLong(arg)); - } else // NOLINT + // For the typical case, offer a fast path. + if GOOGLE_PREDICT_TRUE(PyInt_Check(arg)) { + long int_result = PyInt_AsLong(arg); + if GOOGLE_PREDICT_TRUE(IsValidNumericCast<T>(int_result)) { + *value = static_cast<T>(int_result); + return true; + } else { + OutOfRangeError(arg); + return false; + } + } #endif - { - if (min == kPythonZero) { - *value = static_cast<T>(PyLong_AsUnsignedLongLong(arg)); + // This effectively defines an integer as "an object that can be cast as + // an integer and can be used as an ordinal number". + // This definition includes everything that implements numbers.Integral + // and shouldn't cast the net too wide. + if GOOGLE_PREDICT_FALSE(!PyIndex_Check(arg)) { + FormatTypeError(arg, "int, long"); + return false; + } + + // Now we have an integral number so we can safely use PyLong_ functions. + // We need to treat the signed and unsigned cases differently in case arg is + // holding a value above the maximum for signed longs. + if (std::numeric_limits<T>::min() == 0) { + // Unsigned case. + unsigned PY_LONG_LONG ulong_result; + if (PyLong_Check(arg)) { + ulong_result = PyLong_AsUnsignedLongLong(arg); + } else { + // Unlike PyLong_AsLongLong, PyLong_AsUnsignedLongLong is very + // picky about the exact type. + PyObject* casted = PyNumber_Long(arg); + if GOOGLE_PREDICT_FALSE(casted == NULL) { + // Propagate existing error. + return false; + } + ulong_result = PyLong_AsUnsignedLongLong(casted); + Py_DECREF(casted); + } + if (VerifyIntegerCastAndRange<T, unsigned PY_LONG_LONG>(arg, + ulong_result)) { + *value = static_cast<T>(ulong_result); + } else { + return false; + } + } else { + // Signed case. + PY_LONG_LONG long_result; + PyNumberMethods *nb; + if ((nb = arg->ob_type->tp_as_number) != NULL && nb->nb_int != NULL) { + // PyLong_AsLongLong requires it to be a long or to have an __int__() + // method. + long_result = PyLong_AsLongLong(arg); } else { - *value = static_cast<T>(PyLong_AsLongLong(arg)); + // Valid subclasses of numbers.Integral should have a __long__() method + // so fall back to that. + PyObject* casted = PyNumber_Long(arg); + if GOOGLE_PREDICT_FALSE(casted == NULL) { + // Propagate existing error. + return false; + } + long_result = PyLong_AsLongLong(casted); + Py_DECREF(casted); + } + if (VerifyIntegerCastAndRange<T, PY_LONG_LONG>(arg, long_result)) { + *value = static_cast<T>(long_result); + } else { + return false; } } + return true; } // These are referenced by repeated_scalar_container, and must // be explicitly instantiated. -template bool CheckAndGetInteger<int32>( - PyObject*, int32*, PyObject*, PyObject*); -template bool CheckAndGetInteger<int64>( - PyObject*, int64*, PyObject*, PyObject*); -template bool CheckAndGetInteger<uint32>( - PyObject*, uint32*, PyObject*, PyObject*); -template bool CheckAndGetInteger<uint64>( - PyObject*, uint64*, PyObject*, PyObject*); +template bool CheckAndGetInteger<int32>(PyObject*, int32*); +template bool CheckAndGetInteger<int64>(PyObject*, int64*); +template bool CheckAndGetInteger<uint32>(PyObject*, uint32*); +template bool CheckAndGetInteger<uint64>(PyObject*, uint64*); bool CheckAndGetDouble(PyObject* arg, double* value) { - if (!PyInt_Check(arg) && !PyLong_Check(arg) && - !PyFloat_Check(arg)) { + *value = PyFloat_AsDouble(arg); + if GOOGLE_PREDICT_FALSE(*value == -1 && PyErr_Occurred()) { FormatTypeError(arg, "int, long, float"); return false; } - *value = PyFloat_AsDouble(arg); return true; } @@ -630,11 +715,13 @@ bool CheckAndGetFloat(PyObject* arg, float* value) { } bool CheckAndGetBool(PyObject* arg, bool* value) { - if (!PyInt_Check(arg) && !PyBool_Check(arg) && !PyLong_Check(arg)) { + long long_value = PyInt_AsLong(arg); + if (long_value == -1 && PyErr_Occurred()) { FormatTypeError(arg, "int, long, bool"); return false; } - *value = static_cast<bool>(PyInt_AsLong(arg)); + *value = static_cast<bool>(long_value); + return true; } @@ -966,20 +1053,7 @@ int InternalDeleteRepeatedField( int min, max; length = reflection->FieldSize(*message, field_descriptor); - if (PyInt_Check(slice) || PyLong_Check(slice)) { - from = to = PyLong_AsLong(slice); - if (from < 0) { - from = to = length + from; - } - step = 1; - min = max = from; - - // Range check. - if (from < 0 || from >= length) { - PyErr_Format(PyExc_IndexError, "list assignment index out of range"); - return -1; - } - } else if (PySlice_Check(slice)) { + if (PySlice_Check(slice)) { from = to = step = slice_length = 0; PySlice_GetIndicesEx( #if PY_MAJOR_VERSION < 3 @@ -996,8 +1070,23 @@ int InternalDeleteRepeatedField( max = from; } } else { - PyErr_SetString(PyExc_TypeError, "list indices must be integers"); - return -1; + from = to = PyLong_AsLong(slice); + if (from == -1 && PyErr_Occurred()) { + PyErr_SetString(PyExc_TypeError, "list indices must be integers"); + return -1; + } + + if (from < 0) { + from = to = length + from; + } + step = 1; + min = max = from; + + // Range check. + if (from < 0 || from >= length) { + PyErr_Format(PyExc_IndexError, "list assignment index out of range"); + return -1; + } } Py_ssize_t i = from; @@ -1958,99 +2047,29 @@ static PyObject* ByteSize(CMessage* self, PyObject* args) { return PyLong_FromLong(self->message->ByteSize()); } -static PyObject* RegisterExtension(PyObject* cls, - PyObject* extension_handle) { +PyObject* RegisterExtension(PyObject* cls, PyObject* extension_handle) { const FieldDescriptor* descriptor = GetExtensionDescriptor(extension_handle); if (descriptor == NULL) { return NULL; } - - ScopedPyObjectPtr extensions_by_name( - PyObject_GetAttr(cls, k_extensions_by_name)); - if (extensions_by_name == NULL) { - PyErr_SetString(PyExc_TypeError, "no extensions_by_name on class"); + if (!PyObject_TypeCheck(cls, &CMessageClass_Type)) { + PyErr_Format(PyExc_TypeError, "Expected a message class, got %s", + cls->ob_type->tp_name); return NULL; } - ScopedPyObjectPtr full_name(PyObject_GetAttr(extension_handle, kfull_name)); - if (full_name == NULL) { + CMessageClass *message_class = reinterpret_cast<CMessageClass*>(cls); + if (message_class == NULL) { return NULL; } - // If the extension was already registered, check that it is the same. - PyObject* existing_extension = - PyDict_GetItem(extensions_by_name.get(), full_name.get()); - if (existing_extension != NULL) { - const FieldDescriptor* existing_extension_descriptor = - GetExtensionDescriptor(existing_extension); - if (existing_extension_descriptor != descriptor) { - PyErr_SetString(PyExc_ValueError, "Double registration of Extensions"); - return NULL; - } - // Nothing else to do. - Py_RETURN_NONE; - } - - if (PyDict_SetItem(extensions_by_name.get(), full_name.get(), - extension_handle) < 0) { - return NULL; - } - - // Also store a mapping from extension number to implementing class. - ScopedPyObjectPtr extensions_by_number( - PyObject_GetAttr(cls, k_extensions_by_number)); - if (extensions_by_number == NULL) { - PyErr_SetString(PyExc_TypeError, "no extensions_by_number on class"); - return NULL; - } - - ScopedPyObjectPtr number(PyObject_GetAttrString(extension_handle, "number")); - if (number == NULL) { - return NULL; - } - - // If the extension was already registered by number, check that it is the - // same. - existing_extension = PyDict_GetItem(extensions_by_number.get(), number.get()); - if (existing_extension != NULL) { - const FieldDescriptor* existing_extension_descriptor = - GetExtensionDescriptor(existing_extension); - if (existing_extension_descriptor != descriptor) { - const Descriptor* msg_desc = GetMessageDescriptor( - reinterpret_cast<PyTypeObject*>(cls)); - PyErr_Format( - PyExc_ValueError, - "Extensions \"%s\" and \"%s\" both try to extend message type " - "\"%s\" with field number %ld.", - existing_extension_descriptor->full_name().c_str(), - descriptor->full_name().c_str(), - msg_desc->full_name().c_str(), - PyInt_AsLong(number.get())); - return NULL; - } - // Nothing else to do. - Py_RETURN_NONE; - } - if (PyDict_SetItem(extensions_by_number.get(), number.get(), - extension_handle) < 0) { + const FieldDescriptor* existing_extension = + message_class->py_message_factory->pool->pool->FindExtensionByNumber( + descriptor->containing_type(), descriptor->number()); + if (existing_extension != NULL && existing_extension != descriptor) { + PyErr_SetString(PyExc_ValueError, "Double registration of Extensions"); return NULL; } - - // Check if it's a message set - if (descriptor->is_extension() && - descriptor->containing_type()->options().message_set_wire_format() && - descriptor->type() == FieldDescriptor::TYPE_MESSAGE && - descriptor->label() == FieldDescriptor::LABEL_OPTIONAL) { - ScopedPyObjectPtr message_name(PyString_FromStringAndSize( - descriptor->message_type()->full_name().c_str(), - descriptor->message_type()->full_name().size())); - if (message_name == NULL) { - return NULL; - } - PyDict_SetItem(extensions_by_name.get(), message_name.get(), - extension_handle); - } - Py_RETURN_NONE; } @@ -2087,7 +2106,7 @@ static PyObject* WhichOneof(CMessage* self, PyObject* arg) { static PyObject* GetExtensionDict(CMessage* self, void *closure); static PyObject* ListFields(CMessage* self) { - vector<const FieldDescriptor*> fields; + std::vector<const FieldDescriptor*> fields; self->message->GetReflection()->ListFields(*self->message, &fields); // Normally, the list will be exactly the size of the fields. @@ -2178,7 +2197,7 @@ static PyObject* DiscardUnknownFields(CMessage* self) { PyObject* FindInitializationErrors(CMessage* self) { Message* message = self->message; - vector<string> errors; + std::vector<string> errors; message->FindInitializationErrors(&errors); PyObject* error_list = PyList_New(errors.size()); @@ -2570,11 +2589,24 @@ static PyObject* GetExtensionDict(CMessage* self, void *closure) { return NULL; } +static PyObject* GetExtensionsByName(CMessage *self, void *closure) { + return message_meta::GetExtensionsByName( + reinterpret_cast<CMessageClass*>(Py_TYPE(self)), closure); +} + +static PyObject* GetExtensionsByNumber(CMessage *self, void *closure) { + return message_meta::GetExtensionsByNumber( + reinterpret_cast<CMessageClass*>(Py_TYPE(self)), closure); +} + static PyGetSetDef Getters[] = { {"Extensions", (getter)GetExtensionDict, NULL, "Extension dict"}, + {"_extensions_by_name", (getter)GetExtensionsByName, NULL}, + {"_extensions_by_number", (getter)GetExtensionsByNumber, NULL}, {NULL} }; + static PyMethodDef Methods[] = { { "__deepcopy__", (PyCFunction)DeepCopy, METH_VARARGS, "Makes a deep copy of the class." }, @@ -2835,19 +2867,7 @@ void InitGlobals() { // TODO(gps): Check all return values in this function for NULL and propagate // the error (MemoryError) on up to result in an import failure. These should // also be freed and reset to NULL during finalization. - kPythonZero = PyInt_FromLong(0); - kint32min_py = PyInt_FromLong(kint32min); - kint32max_py = PyInt_FromLong(kint32max); - kuint32max_py = PyLong_FromLongLong(kuint32max); - kint64min_py = PyLong_FromLongLong(kint64min); - kint64max_py = PyLong_FromLongLong(kint64max); - kuint64max_py = PyLong_FromUnsignedLongLong(kuint64max); - kDESCRIPTOR = PyString_FromString("DESCRIPTOR"); - k_cdescriptor = PyString_FromString("_cdescriptor"); - kfull_name = PyString_FromString("full_name"); - k_extensions_by_name = PyString_FromString("_extensions_by_name"); - k_extensions_by_number = PyString_FromString("_extensions_by_number"); PyObject *dummy_obj = PySet_New(NULL); kEmptyWeakref = PyWeakref_NewRef(dummy_obj, NULL); @@ -2887,25 +2907,6 @@ bool InitProto2MessageModule(PyObject *m) { // DESCRIPTOR is set on each protocol buffer message class elsewhere, but set // it here as well to document that subclasses need to set it. PyDict_SetItem(CMessage_Type.tp_dict, kDESCRIPTOR, Py_None); - // Subclasses with message extensions will override _extensions_by_name and - // _extensions_by_number with fresh mutable dictionaries in AddDescriptors. - // All other classes can share this same immutable mapping. - ScopedPyObjectPtr empty_dict(PyDict_New()); - if (empty_dict == NULL) { - return false; - } - ScopedPyObjectPtr immutable_dict(PyDictProxy_New(empty_dict.get())); - if (immutable_dict == NULL) { - return false; - } - if (PyDict_SetItem(CMessage_Type.tp_dict, - k_extensions_by_name, immutable_dict.get()) < 0) { - return false; - } - if (PyDict_SetItem(CMessage_Type.tp_dict, - k_extensions_by_number, immutable_dict.get()) < 0) { - return false; - } PyModule_AddObject(m, "Message", reinterpret_cast<PyObject*>(&CMessage_Type)); diff --git a/python/google/protobuf/pyext/message.h b/python/google/protobuf/pyext/message.h index 1550724c..ce80497e 100644 --- a/python/google/protobuf/pyext/message.h +++ b/python/google/protobuf/pyext/message.h @@ -117,6 +117,7 @@ typedef struct CMessage { PyObject* weakreflist; } CMessage; +extern PyTypeObject CMessageClass_Type; extern PyTypeObject CMessage_Type; @@ -235,6 +236,10 @@ int InitAttributes(CMessage* self, PyObject* args, PyObject* kwargs); PyObject* MergeFrom(CMessage* self, PyObject* arg); +// This method does not do anything beyond checking that no other extension +// has been registered with the same field number on this class. +PyObject* RegisterExtension(PyObject* cls, PyObject* extension_handle); + // Retrieves an attribute named 'name' from CMessage 'self'. Returns // the attribute value on success, or NULL on failure. // @@ -275,25 +280,25 @@ PyObject* SetAllowOversizeProtos(PyObject* m, PyObject* arg); #define GOOGLE_CHECK_GET_INT32(arg, value, err) \ int32 value; \ - if (!CheckAndGetInteger(arg, &value, kint32min_py, kint32max_py)) { \ + if (!CheckAndGetInteger(arg, &value)) { \ return err; \ } #define GOOGLE_CHECK_GET_INT64(arg, value, err) \ int64 value; \ - if (!CheckAndGetInteger(arg, &value, kint64min_py, kint64max_py)) { \ + if (!CheckAndGetInteger(arg, &value)) { \ return err; \ } #define GOOGLE_CHECK_GET_UINT32(arg, value, err) \ uint32 value; \ - if (!CheckAndGetInteger(arg, &value, kPythonZero, kuint32max_py)) { \ + if (!CheckAndGetInteger(arg, &value)) { \ return err; \ } #define GOOGLE_CHECK_GET_UINT64(arg, value, err) \ uint64 value; \ - if (!CheckAndGetInteger(arg, &value, kPythonZero, kuint64max_py)) { \ + if (!CheckAndGetInteger(arg, &value)) { \ return err; \ } @@ -316,20 +321,11 @@ PyObject* SetAllowOversizeProtos(PyObject* m, PyObject* arg); } -extern PyObject* kPythonZero; -extern PyObject* kint32min_py; -extern PyObject* kint32max_py; -extern PyObject* kuint32max_py; -extern PyObject* kint64min_py; -extern PyObject* kint64max_py; -extern PyObject* kuint64max_py; - #define FULL_MODULE_NAME "google.protobuf.pyext._message" void FormatTypeError(PyObject* arg, char* expected_types); template<class T> -bool CheckAndGetInteger( - PyObject* arg, T* value, PyObject* min, PyObject* max); +bool CheckAndGetInteger(PyObject* arg, T* value); bool CheckAndGetDouble(PyObject* arg, double* value); bool CheckAndGetFloat(PyObject* arg, float* value); bool CheckAndGetBool(PyObject* arg, bool* value); diff --git a/python/google/protobuf/pyext/message_factory.cc b/python/google/protobuf/pyext/message_factory.cc index 2ad89022..e0b45bf2 100644 --- a/python/google/protobuf/pyext/message_factory.cc +++ b/python/google/protobuf/pyext/message_factory.cc @@ -130,6 +130,72 @@ int RegisterMessageClass(PyMessageFactory* self, return 0; } +CMessageClass* GetOrCreateMessageClass(PyMessageFactory* self, + const Descriptor* descriptor) { + // This is the same implementation as MessageFactory.GetPrototype(). + ScopedPyObjectPtr py_descriptor( + PyMessageDescriptor_FromDescriptor(descriptor)); + if (py_descriptor == NULL) { + return NULL; + } + // Do not create a MessageClass that already exists. + hash_map<const Descriptor*, CMessageClass*>::iterator it = + self->classes_by_descriptor->find(descriptor); + if (it != self->classes_by_descriptor->end()) { + Py_INCREF(it->second); + return it->second; + } + // Create a new message class. + ScopedPyObjectPtr args(Py_BuildValue( + "s(){sOsOsO}", descriptor->name().c_str(), + "DESCRIPTOR", py_descriptor.get(), + "__module__", Py_None, + "message_factory", self)); + if (args == NULL) { + return NULL; + } + ScopedPyObjectPtr message_class(PyObject_CallObject( + reinterpret_cast<PyObject*>(&CMessageClass_Type), args.get())); + if (message_class == NULL) { + return NULL; + } + // Create messages class for the messages used by the fields, and registers + // all extensions for these messages during the recursion. + for (int field_idx = 0; field_idx < descriptor->field_count(); field_idx++) { + const Descriptor* sub_descriptor = + descriptor->field(field_idx)->message_type(); + // It is NULL if the field type is not a message. + if (sub_descriptor != NULL) { + CMessageClass* result = GetOrCreateMessageClass(self, sub_descriptor); + if (result == NULL) { + return NULL; + } + Py_DECREF(result); + } + } + + // Register extensions defined in this message. + for (int ext_idx = 0 ; ext_idx < descriptor->extension_count() ; ext_idx++) { + const FieldDescriptor* extension = descriptor->extension(ext_idx); + ScopedPyObjectPtr py_extended_class( + GetOrCreateMessageClass(self, extension->containing_type()) + ->AsPyObject()); + if (py_extended_class == NULL) { + return NULL; + } + ScopedPyObjectPtr py_extension(PyFieldDescriptor_FromDescriptor(extension)); + if (py_extension == NULL) { + return NULL; + } + ScopedPyObjectPtr result(cmessage::RegisterExtension( + py_extended_class.get(), py_extension.get())); + if (result == NULL) { + return NULL; + } + } + return reinterpret_cast<CMessageClass*>(message_class.release()); +} + // Retrieve the message class added to our database. CMessageClass* GetMessageClass(PyMessageFactory* self, const Descriptor* message_descriptor) { diff --git a/python/google/protobuf/pyext/message_factory.h b/python/google/protobuf/pyext/message_factory.h index 07cccbfb..36092f7e 100644 --- a/python/google/protobuf/pyext/message_factory.h +++ b/python/google/protobuf/pyext/message_factory.h @@ -82,14 +82,14 @@ PyMessageFactory* NewMessageFactory(PyTypeObject* type, PyDescriptorPool* pool); int RegisterMessageClass(PyMessageFactory* self, const Descriptor* message_descriptor, CMessageClass* message_class); - -// Retrieves the Python class registered with the given message descriptor. -// -// Returns a *borrowed* reference if found, otherwise returns NULL with an -// exception set. +// Retrieves the Python class registered with the given message descriptor, or +// fail with a TypeError. Returns a *borrowed* reference. CMessageClass* GetMessageClass(PyMessageFactory* self, const Descriptor* message_descriptor); - +// Retrieves the Python class registered with the given message descriptor. +// The class is created if not done yet. Returns a *new* reference. +CMessageClass* GetOrCreateMessageClass(PyMessageFactory* self, + const Descriptor* message_descriptor); } // namespace message_factory // Initialize objects used by this module. diff --git a/python/google/protobuf/pyext/safe_numerics.h b/python/google/protobuf/pyext/safe_numerics.h new file mode 100644 index 00000000..639ba2c8 --- /dev/null +++ b/python/google/protobuf/pyext/safe_numerics.h @@ -0,0 +1,164 @@ +// Protocol Buffers - Google's data interchange format +// Copyright 2008 Google Inc. All rights reserved. +// https://developers.google.com/protocol-buffers/ +// +// Redistribution and use in source and binary forms, with or without +// modification, are permitted provided that the following conditions are +// met: +// +// * Redistributions of source code must retain the above copyright +// notice, this list of conditions and the following disclaimer. +// * Redistributions in binary form must reproduce the above +// copyright notice, this list of conditions and the following disclaimer +// in the documentation and/or other materials provided with the +// distribution. +// * Neither the name of Google Inc. nor the names of its +// contributors may be used to endorse or promote products derived from +// this software without specific prior written permission. +// +// THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS +// "AS IS" AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT +// LIMITED TO, THE IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR +// A PARTICULAR PURPOSE ARE DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT +// OWNER OR CONTRIBUTORS BE LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, +// SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT +// LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, +// DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY +// THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT +// (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE +// OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. + +#ifndef GOOGLE_PROTOBUF_PYTHON_CPP_SAFE_NUMERICS_H__ +#define GOOGLE_PROTOBUF_PYTHON_CPP_SAFE_NUMERICS_H__ +// Copied from chromium with only changes to the namespace. + +#include <limits> + +#include <google/protobuf/stubs/logging.h> +#include <google/protobuf/stubs/common.h> + +namespace google { +namespace protobuf { +namespace python { + +template <bool SameSize, bool DestLarger, + bool DestIsSigned, bool SourceIsSigned> +struct IsValidNumericCastImpl; + +#define BASE_NUMERIC_CAST_CASE_SPECIALIZATION(A, B, C, D, Code) \ +template <> struct IsValidNumericCastImpl<A, B, C, D> { \ + template <class Source, class DestBounds> static inline bool Test( \ + Source source, DestBounds min, DestBounds max) { \ + return Code; \ + } \ +} + +#define BASE_NUMERIC_CAST_CASE_SAME_SIZE(DestSigned, SourceSigned, Code) \ + BASE_NUMERIC_CAST_CASE_SPECIALIZATION( \ + true, true, DestSigned, SourceSigned, Code); \ + BASE_NUMERIC_CAST_CASE_SPECIALIZATION( \ + true, false, DestSigned, SourceSigned, Code) + +#define BASE_NUMERIC_CAST_CASE_SOURCE_LARGER(DestSigned, SourceSigned, Code) \ + BASE_NUMERIC_CAST_CASE_SPECIALIZATION( \ + false, false, DestSigned, SourceSigned, Code); \ + +#define BASE_NUMERIC_CAST_CASE_DEST_LARGER(DestSigned, SourceSigned, Code) \ + BASE_NUMERIC_CAST_CASE_SPECIALIZATION( \ + false, true, DestSigned, SourceSigned, Code); \ + +// The three top level cases are: +// - Same size +// - Source larger +// - Dest larger +// And for each of those three cases, we handle the 4 different possibilities +// of signed and unsigned. This gives 12 cases to handle, which we enumerate +// below. +// +// The last argument in each of the macros is the actual comparison code. It +// has three arguments available, source (the value), and min/max which are +// the ranges of the destination. + + +// These are the cases where both types have the same size. + +// Both signed. +BASE_NUMERIC_CAST_CASE_SAME_SIZE(true, true, true); +// Both unsigned. +BASE_NUMERIC_CAST_CASE_SAME_SIZE(false, false, true); +// Dest unsigned, Source signed. +BASE_NUMERIC_CAST_CASE_SAME_SIZE(false, true, source >= 0); +// Dest signed, Source unsigned. +// This cast is OK because Dest's max must be less than Source's. +BASE_NUMERIC_CAST_CASE_SAME_SIZE(true, false, + source <= static_cast<Source>(max)); + + +// These are the cases where Source is larger. + +// Both unsigned. +BASE_NUMERIC_CAST_CASE_SOURCE_LARGER(false, false, source <= max); +// Both signed. +BASE_NUMERIC_CAST_CASE_SOURCE_LARGER(true, true, + source >= min && source <= max); +// Dest is unsigned, Source is signed. +BASE_NUMERIC_CAST_CASE_SOURCE_LARGER(false, true, + source >= 0 && source <= max); +// Dest is signed, Source is unsigned. +// This cast is OK because Dest's max must be less than Source's. +BASE_NUMERIC_CAST_CASE_SOURCE_LARGER(true, false, + source <= static_cast<Source>(max)); + + +// These are the cases where Dest is larger. + +// Both unsigned. +BASE_NUMERIC_CAST_CASE_DEST_LARGER(false, false, true); +// Both signed. +BASE_NUMERIC_CAST_CASE_DEST_LARGER(true, true, true); +// Dest is unsigned, Source is signed. +BASE_NUMERIC_CAST_CASE_DEST_LARGER(false, true, source >= 0); +// Dest is signed, Source is unsigned. +BASE_NUMERIC_CAST_CASE_DEST_LARGER(true, false, true); + +#undef BASE_NUMERIC_CAST_CASE_SPECIALIZATION +#undef BASE_NUMERIC_CAST_CASE_SAME_SIZE +#undef BASE_NUMERIC_CAST_CASE_SOURCE_LARGER +#undef BASE_NUMERIC_CAST_CASE_DEST_LARGER + + +// The main test for whether the conversion will under or overflow. +template <class Dest, class Source> +inline bool IsValidNumericCast(Source source) { + typedef std::numeric_limits<Source> SourceLimits; + typedef std::numeric_limits<Dest> DestLimits; + GOOGLE_COMPILE_ASSERT(SourceLimits::is_specialized, argument_must_be_numeric); + GOOGLE_COMPILE_ASSERT(SourceLimits::is_integer, argument_must_be_integral); + GOOGLE_COMPILE_ASSERT(DestLimits::is_specialized, result_must_be_numeric); + GOOGLE_COMPILE_ASSERT(DestLimits::is_integer, result_must_be_integral); + + return IsValidNumericCastImpl< + sizeof(Dest) == sizeof(Source), + (sizeof(Dest) > sizeof(Source)), + DestLimits::is_signed, + SourceLimits::is_signed>::Test( + source, + DestLimits::min(), + DestLimits::max()); +} + +// checked_numeric_cast<> is analogous to static_cast<> for numeric types, +// except that it CHECKs that the specified numeric conversion will not +// overflow or underflow. Floating point arguments are not currently allowed +// (this is COMPILE_ASSERTd), though this could be supported if necessary. +template <class Dest, class Source> +inline Dest checked_numeric_cast(Source source) { + GOOGLE_CHECK(IsValidNumericCast<Dest>(source)); + return static_cast<Dest>(source); +} + +} // namespace python +} // namespace protobuf + +} // namespace google +#endif // GOOGLE_PROTOBUF_PYTHON_CPP_SAFE_NUMERICS_H__ |