aboutsummaryrefslogtreecommitdiffhomepage
path: root/python
diff options
context:
space:
mode:
authorGravatar Adam Cozzette <acozzette@google.com>2016-11-17 16:48:38 -0800
committerGravatar Adam Cozzette <acozzette@google.com>2016-11-17 16:59:59 -0800
commit5a76e633ea9b5adb215e93fdc11e1c0c08b3fc74 (patch)
tree0276f81f8848a05d84cd7e287b43d665e30f04e3 /python
parente28286fa05d8327fd6c5aa70cfb3be558f0932b8 (diff)
Integrated internal changes from Google
Diffstat (limited to 'python')
-rw-r--r--python/google/protobuf/descriptor_pool.py99
-rwxr-xr-xpython/google/protobuf/internal/decoder.py6
-rw-r--r--python/google/protobuf/internal/descriptor_pool_test.py47
-rw-r--r--python/google/protobuf/internal/message_factory_test.py20
-rwxr-xr-xpython/google/protobuf/internal/python_message.py37
-rwxr-xr-xpython/google/protobuf/internal/reflection_test.py67
-rwxr-xr-xpython/google/protobuf/internal/test_util.py154
-rw-r--r--python/google/protobuf/internal/testing_refleaks.py2
-rwxr-xr-xpython/google/protobuf/internal/text_format_test.py30
-rwxr-xr-xpython/google/protobuf/internal/type_checkers.py15
-rw-r--r--python/google/protobuf/json_format.py4
-rw-r--r--python/google/protobuf/pyext/descriptor_pool.cc67
-rw-r--r--python/google/protobuf/pyext/extension_dict.cc77
-rw-r--r--python/google/protobuf/pyext/map_container.cc2
-rw-r--r--python/google/protobuf/pyext/message.cc409
-rw-r--r--python/google/protobuf/pyext/message.h24
-rw-r--r--python/google/protobuf/pyext/message_factory.cc66
-rw-r--r--python/google/protobuf/pyext/message_factory.h12
-rw-r--r--python/google/protobuf/pyext/safe_numerics.h164
19 files changed, 982 insertions, 320 deletions
diff --git a/python/google/protobuf/descriptor_pool.py b/python/google/protobuf/descriptor_pool.py
index 5f43ee5f..28b7e843 100644
--- a/python/google/protobuf/descriptor_pool.py
+++ b/python/google/protobuf/descriptor_pool.py
@@ -57,6 +57,8 @@ directly instead of this class.
__author__ = 'matthewtoia@google.com (Matt Toia)'
+import collections
+
from google.protobuf import descriptor
from google.protobuf import descriptor_database
from google.protobuf import text_encoding
@@ -88,6 +90,14 @@ def _OptionsOrNone(descriptor_proto):
return None
+def _IsMessageSetExtension(field):
+ return (field.is_extension and
+ field.containing_type.has_options and
+ field.containing_type.GetOptions().message_set_wire_format and
+ field.type == descriptor.FieldDescriptor.TYPE_MESSAGE and
+ field.label == descriptor.FieldDescriptor.LABEL_OPTIONAL)
+
+
class DescriptorPool(object):
"""A collection of protobufs dynamically constructed by descriptor protos."""
@@ -115,6 +125,12 @@ class DescriptorPool(object):
self._descriptors = {}
self._enum_descriptors = {}
self._file_descriptors = {}
+ self._toplevel_extensions = {}
+ # We store extensions in two two-level mappings: The first key is the
+ # descriptor of the message being extended, the second key is the extension
+ # full name or its tag number.
+ self._extensions_by_name = collections.defaultdict(dict)
+ self._extensions_by_number = collections.defaultdict(dict)
def Add(self, file_desc_proto):
"""Adds the FileDescriptorProto and its types to this pool.
@@ -170,6 +186,48 @@ class DescriptorPool(object):
self._enum_descriptors[enum_desc.full_name] = enum_desc
self.AddFileDescriptor(enum_desc.file)
+ def AddExtensionDescriptor(self, extension):
+ """Adds a FieldDescriptor describing an extension to the pool.
+
+ Args:
+ extension: A FieldDescriptor.
+
+ Raises:
+ AssertionError: when another extension with the same number extends the
+ same message.
+ TypeError: when the specified extension is not a
+ descriptor.FieldDescriptor.
+ """
+ if not (isinstance(extension, descriptor.FieldDescriptor) and
+ extension.is_extension):
+ raise TypeError('Expected an extension descriptor.')
+
+ if extension.extension_scope is None:
+ self._toplevel_extensions[extension.full_name] = extension
+
+ try:
+ existing_desc = self._extensions_by_number[
+ extension.containing_type][extension.number]
+ except KeyError:
+ pass
+ else:
+ if extension is not existing_desc:
+ raise AssertionError(
+ 'Extensions "%s" and "%s" both try to extend message type "%s" '
+ 'with field number %d.' %
+ (extension.full_name, existing_desc.full_name,
+ extension.containing_type.full_name, extension.number))
+
+ self._extensions_by_number[extension.containing_type][
+ extension.number] = extension
+ self._extensions_by_name[extension.containing_type][
+ extension.full_name] = extension
+
+ # Also register MessageSet extensions with the type name.
+ if _IsMessageSetExtension(extension):
+ self._extensions_by_name[extension.containing_type][
+ extension.message_type.full_name] = extension
+
def AddFileDescriptor(self, file_desc):
"""Adds a FileDescriptor to the pool, non-recursively.
@@ -302,6 +360,14 @@ class DescriptorPool(object):
A FieldDescriptor, describing the named extension.
"""
full_name = _NormalizeFullyQualifiedName(full_name)
+ try:
+ # The proto compiler does not give any link between the FileDescriptor
+ # and top-level extensions unless the FileDescriptorProto is added to
+ # the DescriptorDatabase, but this can impact memory usage.
+ # So we registered these extensions by name explicitly.
+ return self._toplevel_extensions[full_name]
+ except KeyError:
+ pass
message_name, _, extension_name = full_name.rpartition('.')
try:
# Most extensions are nested inside a message.
@@ -311,6 +377,39 @@ class DescriptorPool(object):
scope = self.FindFileContainingSymbol(full_name)
return scope.extensions_by_name[extension_name]
+ def FindExtensionByNumber(self, message_descriptor, number):
+ """Gets the extension of the specified message with the specified number.
+
+ Extensions have to be registered to this pool by calling
+ AddExtensionDescriptor.
+
+ Args:
+ message_descriptor: descriptor of the extended message.
+ number: integer, number of the extension field.
+
+ Returns:
+ A FieldDescriptor describing the extension.
+
+ Raise:
+ KeyError: when no extension with the given number is known for the
+ specified message.
+ """
+ return self._extensions_by_number[message_descriptor][number]
+
+ def FindAllExtensions(self, message_descriptor):
+ """Gets all the known extension of a given message.
+
+ Extensions have to be registered to this pool by calling
+ AddExtensionDescriptor.
+
+ Args:
+ message_descriptor: descriptor of the extended message.
+
+ Returns:
+ A list of FieldDescriptor describing the extensions.
+ """
+ return self._extensions_by_number[message_descriptor].values()
+
def _ConvertFileProtoToFileDescriptor(self, file_proto):
"""Creates a FileDescriptor from a proto or returns a cached copy.
diff --git a/python/google/protobuf/internal/decoder.py b/python/google/protobuf/internal/decoder.py
index 31869e45..c5f73dc1 100755
--- a/python/google/protobuf/internal/decoder.py
+++ b/python/google/protobuf/internal/decoder.py
@@ -642,10 +642,10 @@ def MessageDecoder(field_number, is_repeated, is_packed, key, new_default):
MESSAGE_SET_ITEM_TAG = encoder.TagBytes(1, wire_format.WIRETYPE_START_GROUP)
-def MessageSetItemDecoder(extensions_by_number):
+def MessageSetItemDecoder(descriptor):
"""Returns a decoder for a MessageSet item.
- The parameter is the _extensions_by_number map for the message class.
+ The parameter is the message Descriptor.
The message set message looks like this:
message MessageSet {
@@ -694,7 +694,7 @@ def MessageSetItemDecoder(extensions_by_number):
if message_start == -1:
raise _DecodeError('MessageSet item missing message.')
- extension = extensions_by_number.get(type_id)
+ extension = message.Extensions._FindExtensionByNumber(type_id)
if extension is not None:
value = field_dict.get(extension)
if value is None:
diff --git a/python/google/protobuf/internal/descriptor_pool_test.py b/python/google/protobuf/internal/descriptor_pool_test.py
index d4de2d81..cb6abe6c 100644
--- a/python/google/protobuf/internal/descriptor_pool_test.py
+++ b/python/google/protobuf/internal/descriptor_pool_test.py
@@ -254,6 +254,53 @@ class DescriptorPoolTest(unittest.TestCase):
with self.assertRaises(KeyError):
self.pool.FindFieldByName('Does not exist')
+ def testFindAllExtensions(self):
+ factory1_message = self.pool.FindMessageTypeByName(
+ 'google.protobuf.python.internal.Factory1Message')
+ factory2_message = self.pool.FindMessageTypeByName(
+ 'google.protobuf.python.internal.Factory2Message')
+ # An extension defined in a message.
+ one_more_field = factory2_message.extensions_by_name['one_more_field']
+ self.pool.AddExtensionDescriptor(one_more_field)
+ # An extension defined at file scope.
+ factory_test2 = self.pool.FindFileByName(
+ 'google/protobuf/internal/factory_test2.proto')
+ another_field = factory_test2.extensions_by_name['another_field']
+ self.pool.AddExtensionDescriptor(another_field)
+
+ extensions = self.pool.FindAllExtensions(factory1_message)
+ expected_extension_numbers = {one_more_field, another_field}
+ self.assertEqual(expected_extension_numbers, set(extensions))
+ # Verify that mutating the returned list does not affect the pool.
+ extensions.append('unexpected_element')
+ # Get the extensions again, the returned value does not contain the
+ # 'unexpected_element'.
+ extensions = self.pool.FindAllExtensions(factory1_message)
+ self.assertEqual(expected_extension_numbers, set(extensions))
+
+ def testFindExtensionByNumber(self):
+ factory1_message = self.pool.FindMessageTypeByName(
+ 'google.protobuf.python.internal.Factory1Message')
+ factory2_message = self.pool.FindMessageTypeByName(
+ 'google.protobuf.python.internal.Factory2Message')
+ # An extension defined in a message.
+ one_more_field = factory2_message.extensions_by_name['one_more_field']
+ self.pool.AddExtensionDescriptor(one_more_field)
+ # An extension defined at file scope.
+ factory_test2 = self.pool.FindFileByName(
+ 'google/protobuf/internal/factory_test2.proto')
+ another_field = factory_test2.extensions_by_name['another_field']
+ self.pool.AddExtensionDescriptor(another_field)
+
+ # An extension defined in a message.
+ extension = self.pool.FindExtensionByNumber(factory1_message, 1001)
+ self.assertEqual(extension.name, 'one_more_field')
+ # An extension defined at file scope.
+ extension = self.pool.FindExtensionByNumber(factory1_message, 1002)
+ self.assertEqual(extension.name, 'another_field')
+ with self.assertRaises(KeyError):
+ extension = self.pool.FindExtensionByNumber(factory1_message, 1234567)
+
def testExtensionsAreNotFields(self):
with self.assertRaises(KeyError):
self.pool.FindFieldByName('google.protobuf.python.internal.another_field')
diff --git a/python/google/protobuf/internal/message_factory_test.py b/python/google/protobuf/internal/message_factory_test.py
index 7bb7d1ac..4caa2443 100644
--- a/python/google/protobuf/internal/message_factory_test.py
+++ b/python/google/protobuf/internal/message_factory_test.py
@@ -114,18 +114,18 @@ class MessageFactoryTest(unittest.TestCase):
).issubset(set(messages.keys())))
self._ExerciseDynamicClass(
messages['google.protobuf.python.internal.Factory2Message'])
- self.assertTrue(
- set(['google.protobuf.python.internal.Factory2Message.one_more_field',
- 'google.protobuf.python.internal.another_field'],
- ).issubset(
- set(messages['google.protobuf.python.internal.Factory1Message']
- ._extensions_by_name.keys())))
factory_msg1 = messages['google.protobuf.python.internal.Factory1Message']
+ self.assertTrue(set(
+ ['google.protobuf.python.internal.Factory2Message.one_more_field',
+ 'google.protobuf.python.internal.another_field'],).issubset(set(
+ ext.full_name
+ for ext in factory_msg1.DESCRIPTOR.file.pool.FindAllExtensions(
+ factory_msg1.DESCRIPTOR))))
msg1 = messages['google.protobuf.python.internal.Factory1Message']()
- ext1 = factory_msg1._extensions_by_name[
- 'google.protobuf.python.internal.Factory2Message.one_more_field']
- ext2 = factory_msg1._extensions_by_name[
- 'google.protobuf.python.internal.another_field']
+ ext1 = msg1.Extensions._FindExtensionByName(
+ 'google.protobuf.python.internal.Factory2Message.one_more_field')
+ ext2 = msg1.Extensions._FindExtensionByName(
+ 'google.protobuf.python.internal.another_field')
msg1.Extensions[ext1] = 'test1'
msg1.Extensions[ext2] = 'test2'
self.assertEqual('test1', msg1.Extensions[ext1])
diff --git a/python/google/protobuf/internal/python_message.py b/python/google/protobuf/internal/python_message.py
index dc6565d4..c1bd1f9c 100755
--- a/python/google/protobuf/internal/python_message.py
+++ b/python/google/protobuf/internal/python_message.py
@@ -51,8 +51,8 @@ this file*.
__author__ = 'robinson@google.com (Will Robinson)'
from io import BytesIO
-import sys
import struct
+import sys
import weakref
import six
@@ -162,12 +162,10 @@ class GeneratedProtocolMessageType(type):
"""
descriptor = dictionary[GeneratedProtocolMessageType._DESCRIPTOR_KEY]
cls._decoders_by_tag = {}
- cls._extensions_by_name = {}
- cls._extensions_by_number = {}
if (descriptor.has_options and
descriptor.GetOptions().message_set_wire_format):
cls._decoders_by_tag[decoder.MESSAGE_SET_ITEM_TAG] = (
- decoder.MessageSetItemDecoder(cls._extensions_by_number), None)
+ decoder.MessageSetItemDecoder(descriptor), None)
# Attach stuff to each FieldDescriptor for quick lookup later on.
for field in descriptor.fields:
@@ -747,32 +745,21 @@ def _AddPropertiesForExtensions(descriptor, cls):
constant_name = extension_name.upper() + "_FIELD_NUMBER"
setattr(cls, constant_name, extension_field.number)
+ # TODO(amauryfa): Migrate all users of these attributes to functions like
+ # pool.FindExtensionByNumber(descriptor).
+ if descriptor.file is not None:
+ # TODO(amauryfa): Use cls.MESSAGE_FACTORY.pool when available.
+ pool = descriptor.file.pool
+ cls._extensions_by_number = pool._extensions_by_number[descriptor]
+ cls._extensions_by_name = pool._extensions_by_name[descriptor]
def _AddStaticMethods(cls):
# TODO(robinson): This probably needs to be thread-safe(?)
def RegisterExtension(extension_handle):
extension_handle.containing_type = cls.DESCRIPTOR
+ # TODO(amauryfa): Use cls.MESSAGE_FACTORY.pool when available.
+ cls.DESCRIPTOR.file.pool.AddExtensionDescriptor(extension_handle)
_AttachFieldHelpers(cls, extension_handle)
-
- # Try to insert our extension, failing if an extension with the same number
- # already exists.
- actual_handle = cls._extensions_by_number.setdefault(
- extension_handle.number, extension_handle)
- if actual_handle is not extension_handle:
- raise AssertionError(
- 'Extensions "%s" and "%s" both try to extend message type "%s" with '
- 'field number %d.' %
- (extension_handle.full_name, actual_handle.full_name,
- cls.DESCRIPTOR.full_name, extension_handle.number))
-
- cls._extensions_by_name[extension_handle.full_name] = extension_handle
-
- handle = extension_handle # avoid line wrapping
- if _IsMessageSetExtension(handle):
- # MessageSet extension. Also register under type name.
- cls._extensions_by_name[
- extension_handle.message_type.full_name] = extension_handle
-
cls.RegisterExtension = staticmethod(RegisterExtension)
def FromString(s):
@@ -1230,7 +1217,7 @@ def _AddMergeFromMethod(cls):
if not isinstance(msg, cls):
raise TypeError(
"Parameter to MergeFrom() must be instance of same class: "
- "expected %s got %s." % (cls.__name__, type(msg).__name__))
+ 'expected %s got %s.' % (cls.__name__, msg.__class__.__name__))
assert msg is not self
self._Modified()
diff --git a/python/google/protobuf/internal/reflection_test.py b/python/google/protobuf/internal/reflection_test.py
index dad79c37..0e881015 100755
--- a/python/google/protobuf/internal/reflection_test.py
+++ b/python/google/protobuf/internal/reflection_test.py
@@ -99,12 +99,12 @@ class _MiniDecoder(object):
return wire_format.UnpackTag(self.ReadVarint())
def ReadFloat(self):
- result = struct.unpack("<f", self._bytes[self._pos:self._pos+4])[0]
+ result = struct.unpack('<f', self._bytes[self._pos:self._pos+4])[0]
self._pos += 4
return result
def ReadDouble(self):
- result = struct.unpack("<d", self._bytes[self._pos:self._pos+8])[0]
+ result = struct.unpack('<d', self._bytes[self._pos:self._pos+8])[0]
self._pos += 8
return result
@@ -621,9 +621,15 @@ class ReflectionTest(BaseTestCase):
self.assertRaises(TypeError, setattr, proto, 'optional_string', 10)
self.assertRaises(TypeError, setattr, proto, 'optional_bytes', 10)
- def testIntegerTypes(self):
+ def assertIntegerTypes(self, integer_fn):
+ """Verifies setting of scalar integers.
+
+ Args:
+ integer_fn: A function to wrap the integers that will be assigned.
+ """
def TestGetAndDeserialize(field_name, value, expected_type):
proto = unittest_pb2.TestAllTypes()
+ value = integer_fn(value)
setattr(proto, field_name, value)
self.assertIsInstance(getattr(proto, field_name), expected_type)
proto2 = unittest_pb2.TestAllTypes()
@@ -635,7 +641,7 @@ class ReflectionTest(BaseTestCase):
TestGetAndDeserialize('optional_uint32', 1 << 30, int)
try:
integer_64 = long
- except NameError: # Python3
+ except NameError: # Python3
integer_64 = int
if struct.calcsize('L') == 4:
# Python only has signed ints, so 32-bit python can't fit an uint32
@@ -649,9 +655,33 @@ class ReflectionTest(BaseTestCase):
TestGetAndDeserialize('optional_uint64', 1 << 30, integer_64)
TestGetAndDeserialize('optional_uint64', 1 << 60, integer_64)
- def testSingleScalarBoundsChecking(self):
+ def testIntegerTypes(self):
+ self.assertIntegerTypes(lambda x: x)
+
+ def testNonStandardIntegerTypes(self):
+ self.assertIntegerTypes(test_util.NonStandardInteger)
+
+ def testIllegalValuesForIntegers(self):
+ pb = unittest_pb2.TestAllTypes()
+
+ # Strings are illegal, even when the represent an integer.
+ with self.assertRaises(TypeError):
+ pb.optional_uint64 = '2'
+
+ # The exact error should propagate with a poorly written custom integer.
+ with self.assertRaisesRegexp(RuntimeError, 'my_error'):
+ pb.optional_uint64 = test_util.NonStandardInteger(5, 'my_error')
+
+ def assetIntegerBoundsChecking(self, integer_fn):
+ """Verifies bounds checking for scalar integer fields.
+
+ Args:
+ integer_fn: A function to wrap the integers that will be assigned.
+ """
def TestMinAndMaxIntegers(field_name, expected_min, expected_max):
pb = unittest_pb2.TestAllTypes()
+ expected_min = integer_fn(expected_min)
+ expected_max = integer_fn(expected_max)
setattr(pb, field_name, expected_min)
self.assertEqual(expected_min, getattr(pb, field_name))
setattr(pb, field_name, expected_max)
@@ -663,11 +693,22 @@ class ReflectionTest(BaseTestCase):
TestMinAndMaxIntegers('optional_uint32', 0, 0xffffffff)
TestMinAndMaxIntegers('optional_int64', -(1 << 63), (1 << 63) - 1)
TestMinAndMaxIntegers('optional_uint64', 0, 0xffffffffffffffff)
+ # A bit of white-box testing since -1 is an int and not a long in C++ and
+ # so goes down a different path.
+ pb = unittest_pb2.TestAllTypes()
+ with self.assertRaises(ValueError):
+ pb.optional_uint64 = integer_fn(-(1 << 63))
pb = unittest_pb2.TestAllTypes()
- pb.optional_nested_enum = 1
+ pb.optional_nested_enum = integer_fn(1)
self.assertEqual(1, pb.optional_nested_enum)
+ def testSingleScalarBoundsChecking(self):
+ self.assetIntegerBoundsChecking(lambda x: x)
+
+ def testNonStandardSingleScalarBoundsChecking(self):
+ self.assetIntegerBoundsChecking(test_util.NonStandardInteger)
+
def testRepeatedScalarTypeSafety(self):
proto = unittest_pb2.TestAllTypes()
self.assertRaises(TypeError, proto.repeated_int32.append, 1.1)
@@ -1187,12 +1228,18 @@ class ReflectionTest(BaseTestCase):
self.assertTrue(not extendee_proto.HasExtension(extension))
def testRegisteredExtensions(self):
- self.assertTrue('protobuf_unittest.optional_int32_extension' in
- unittest_pb2.TestAllExtensions._extensions_by_name)
- self.assertTrue(1 in unittest_pb2.TestAllExtensions._extensions_by_number)
+ pool = unittest_pb2.DESCRIPTOR.pool
+ self.assertTrue(
+ pool.FindExtensionByNumber(
+ unittest_pb2.TestAllExtensions.DESCRIPTOR, 1))
+ self.assertIs(
+ pool.FindExtensionByName(
+ 'protobuf_unittest.optional_int32_extension').containing_type,
+ unittest_pb2.TestAllExtensions.DESCRIPTOR)
# Make sure extensions haven't been registered into types that shouldn't
# have any.
- self.assertEqual(0, len(unittest_pb2.TestAllTypes._extensions_by_name))
+ self.assertEqual(0, len(
+ pool.FindAllExtensions(unittest_pb2.TestAllTypes.DESCRIPTOR)))
# If message A directly contains message B, and
# a.HasField('b') is currently False, then mutating any
diff --git a/python/google/protobuf/internal/test_util.py b/python/google/protobuf/internal/test_util.py
index 2c805599..269d0e2d 100755
--- a/python/google/protobuf/internal/test_util.py
+++ b/python/google/protobuf/internal/test_util.py
@@ -36,8 +36,9 @@ This is intentionally modeled on C++ code in
__author__ = 'robinson@google.com (Will Robinson)'
+import numbers
+import operator
import os.path
-
import sys
from google.protobuf import unittest_import_pb2
@@ -694,3 +695,154 @@ def SetAllUnpackedFields(message):
message.unpacked_bool.extend([True, False])
message.unpacked_enum.extend([unittest_pb2.FOREIGN_BAR,
unittest_pb2.FOREIGN_BAZ])
+
+
+class NonStandardInteger(numbers.Integral):
+ """An integer object that does not subclass int.
+
+ This is used to verify that both C++ and regular proto systems can handle
+ integer others than int and long and that they handle them in predictable
+ ways.
+
+ NonStandardInteger is the minimal legal specification for a custom Integral.
+ As such, it does not support 0 < x < 5 and it is not hashable.
+
+ Note: This is added here instead of relying on numpy or a similar library with
+ custom integers to limit dependencies.
+ """
+
+ def __init__(self, val, error_string_on_conversion=None):
+ assert isinstance(val, numbers.Integral)
+ if isinstance(val, NonStandardInteger):
+ val = val.val
+ self.val = val
+ self.error_string_on_conversion = error_string_on_conversion
+
+ def __long__(self):
+ if self.error_string_on_conversion:
+ raise RuntimeError(self.error_string_on_conversion)
+ return long(self.val)
+
+ def __abs__(self):
+ return NonStandardInteger(operator.abs(self.val))
+
+ def __add__(self, y):
+ return NonStandardInteger(operator.add(self.val, y))
+
+ def __div__(self, y):
+ return NonStandardInteger(operator.div(self.val, y))
+
+ def __eq__(self, y):
+ return operator.eq(self.val, y)
+
+ def __floordiv__(self, y):
+ return NonStandardInteger(operator.floordiv(self.val, y))
+
+ def __truediv__(self, y):
+ return NonStandardInteger(operator.truediv(self.val, y))
+
+ def __invert__(self):
+ return NonStandardInteger(operator.invert(self.val))
+
+ def __mod__(self, y):
+ return NonStandardInteger(operator.mod(self.val, y))
+
+ def __mul__(self, y):
+ return NonStandardInteger(operator.mul(self.val, y))
+
+ def __neg__(self):
+ return NonStandardInteger(operator.neg(self.val))
+
+ def __pos__(self):
+ return NonStandardInteger(operator.pos(self.val))
+
+ def __pow__(self, y):
+ return NonStandardInteger(operator.pow(self.val, y))
+
+ def __trunc__(self):
+ return int(self.val)
+
+ def __radd__(self, y):
+ return NonStandardInteger(operator.add(y, self.val))
+
+ def __rdiv__(self, y):
+ return NonStandardInteger(operator.div(y, self.val))
+
+ def __rmod__(self, y):
+ return NonStandardInteger(operator.mod(y, self.val))
+
+ def __rmul__(self, y):
+ return NonStandardInteger(operator.mul(y, self.val))
+
+ def __rpow__(self, y):
+ return NonStandardInteger(operator.pow(y, self.val))
+
+ def __rfloordiv__(self, y):
+ return NonStandardInteger(operator.floordiv(y, self.val))
+
+ def __rtruediv__(self, y):
+ return NonStandardInteger(operator.truediv(y, self.val))
+
+ def __lshift__(self, y):
+ return NonStandardInteger(operator.lshift(self.val, y))
+
+ def __rshift__(self, y):
+ return NonStandardInteger(operator.rshift(self.val, y))
+
+ def __rlshift__(self, y):
+ return NonStandardInteger(operator.lshift(y, self.val))
+
+ def __rrshift__(self, y):
+ return NonStandardInteger(operator.rshift(y, self.val))
+
+ def __le__(self, y):
+ if isinstance(y, NonStandardInteger):
+ y = y.val
+ return operator.le(self.val, y)
+
+ def __lt__(self, y):
+ if isinstance(y, NonStandardInteger):
+ y = y.val
+ return operator.lt(self.val, y)
+
+ def __and__(self, y):
+ return NonStandardInteger(operator.and_(self.val, y))
+
+ def __or__(self, y):
+ return NonStandardInteger(operator.or_(self.val, y))
+
+ def __xor__(self, y):
+ return NonStandardInteger(operator.xor(self.val, y))
+
+ def __rand__(self, y):
+ return NonStandardInteger(operator.and_(y, self.val))
+
+ def __ror__(self, y):
+ return NonStandardInteger(operator.or_(y, self.val))
+
+ def __rxor__(self, y):
+ return NonStandardInteger(operator.xor(y, self.val))
+
+ def __bool__(self):
+ return self.val
+
+ def __nonzero__(self):
+ return self.val
+
+ def __ceil__(self):
+ return self
+
+ def __floor__(self):
+ return self
+
+ def __int__(self):
+ if self.error_string_on_conversion:
+ raise RuntimeError(self.error_string_on_conversion)
+ return int(self.val)
+
+ def __round__(self):
+ return self
+
+ def __repr__(self):
+ return 'NonStandardInteger(%s)' % self.val
+
diff --git a/python/google/protobuf/internal/testing_refleaks.py b/python/google/protobuf/internal/testing_refleaks.py
index c461a9f4..8ce06519 100644
--- a/python/google/protobuf/internal/testing_refleaks.py
+++ b/python/google/protobuf/internal/testing_refleaks.py
@@ -124,5 +124,3 @@ else:
def Same(func):
return func
return Same
-
-
diff --git a/python/google/protobuf/internal/text_format_test.py b/python/google/protobuf/internal/text_format_test.py
index ab481ab4..176cbd15 100755
--- a/python/google/protobuf/internal/text_format_test.py
+++ b/python/google/protobuf/internal/text_format_test.py
@@ -582,22 +582,20 @@ class OnlyWorksWithProto2RightNowTests(TextFormatBase):
% (letter,) for letter in string.ascii_uppercase))
self.CompareToGoldenText(text_format.MessageToString(message), golden)
- def testMapOrderSemantics(self):
- golden_lines = self.ReadGolden('map_test_data.txt')
- # The C++ implementation emits defaulted-value fields, while the Python
- # implementation does not. Adjusting for this is awkward, but it is
- # valuable to test against a common golden file.
- line_blacklist = (' key: 0\n', ' value: 0\n', ' key: false\n',
- ' value: false\n')
- golden_lines = [line for line in golden_lines if line not in line_blacklist]
-
- message = map_unittest_pb2.TestMap()
- text_format.ParseLines(golden_lines, message)
- candidate = text_format.MessageToString(message)
- # The Python implementation emits "1.0" for the double value that the C++
- # implementation emits as "1".
- candidate = candidate.replace('1.0', '1', 2)
- self.assertMultiLineEqual(candidate, ''.join(golden_lines))
+ # TODO(teboring): In c/137553523, not serializing default value for map entry
+ # message has been fixed. This test needs to be disabled in order to submit
+ # that cl. Add this back when c/137553523 has been submitted.
+ # def testMapOrderSemantics(self):
+ # golden_lines = self.ReadGolden('map_test_data.txt')
+
+ # message = map_unittest_pb2.TestMap()
+ # text_format.ParseLines(golden_lines, message)
+ # candidate = text_format.MessageToString(message)
+ # # The Python implementation emits "1.0" for the double value that the C++
+ # # implementation emits as "1".
+ # candidate = candidate.replace('1.0', '1', 2)
+ # candidate = candidate.replace('0.0', '0', 2)
+ # self.assertMultiLineEqual(candidate, ''.join(golden_lines))
# Tests of proto2-only features (MessageSet, extensions, etc.).
diff --git a/python/google/protobuf/internal/type_checkers.py b/python/google/protobuf/internal/type_checkers.py
index 1be3ad9a..4a76cd4e 100755
--- a/python/google/protobuf/internal/type_checkers.py
+++ b/python/google/protobuf/internal/type_checkers.py
@@ -45,6 +45,7 @@ TYPE_TO_DESERIALIZE_METHOD: A dictionary with field types and deserialization
__author__ = 'robinson@google.com (Will Robinson)'
+import numbers
import six
if six.PY3:
@@ -126,11 +127,11 @@ class IntValueChecker(object):
"""Checker used for integer fields. Performs type-check and range check."""
def CheckValue(self, proposed_value):
- if not isinstance(proposed_value, six.integer_types):
+ if not isinstance(proposed_value, numbers.Integral):
message = ('%.1024r has type %s, but expected one of: %s' %
(proposed_value, type(proposed_value), six.integer_types))
raise TypeError(message)
- if not self._MIN <= proposed_value <= self._MAX:
+ if not self._MIN <= int(proposed_value) <= self._MAX:
raise ValueError('Value out of range: %d' % proposed_value)
# We force 32-bit values to int and 64-bit values to long to make
# alternate implementations where the distinction is more significant
@@ -150,11 +151,11 @@ class EnumValueChecker(object):
self._enum_type = enum_type
def CheckValue(self, proposed_value):
- if not isinstance(proposed_value, six.integer_types):
+ if not isinstance(proposed_value, numbers.Integral):
message = ('%.1024r has type %s, but expected one of: %s' %
(proposed_value, type(proposed_value), six.integer_types))
raise TypeError(message)
- if proposed_value not in self._enum_type.values_by_number:
+ if int(proposed_value) not in self._enum_type.values_by_number:
raise ValueError('Unknown enum value: %d' % proposed_value)
return proposed_value
@@ -223,11 +224,11 @@ _VALUE_CHECKERS = {
_FieldDescriptor.CPPTYPE_UINT32: Uint32ValueChecker(),
_FieldDescriptor.CPPTYPE_UINT64: Uint64ValueChecker(),
_FieldDescriptor.CPPTYPE_DOUBLE: TypeCheckerWithDefault(
- 0.0, float, int, long),
+ 0.0, numbers.Real),
_FieldDescriptor.CPPTYPE_FLOAT: TypeCheckerWithDefault(
- 0.0, float, int, long),
+ 0.0, numbers.Real),
_FieldDescriptor.CPPTYPE_BOOL: TypeCheckerWithDefault(
- False, bool, int),
+ False, bool, numbers.Integral),
_FieldDescriptor.CPPTYPE_STRING: TypeCheckerWithDefault(b'', bytes),
}
diff --git a/python/google/protobuf/json_format.py b/python/google/protobuf/json_format.py
index c42371d0..d02cb091 100644
--- a/python/google/protobuf/json_format.py
+++ b/python/google/protobuf/json_format.py
@@ -43,9 +43,9 @@ Simple usage example:
__author__ = 'jieluo@google.com (Jie Luo)'
try:
- from collections import OrderedDict
+ from collections import OrderedDict
except ImportError:
- from ordereddict import OrderedDict #PY26
+ from ordereddict import OrderedDict #PY26
import base64
import json
import math
diff --git a/python/google/protobuf/pyext/descriptor_pool.cc b/python/google/protobuf/pyext/descriptor_pool.cc
index a42e5431..fa66bf9a 100644
--- a/python/google/protobuf/pyext/descriptor_pool.cc
+++ b/python/google/protobuf/pyext/descriptor_pool.cc
@@ -319,6 +319,51 @@ PyObject* FindFileContainingSymbol(PyDescriptorPool* self, PyObject* arg) {
return PyFileDescriptor_FromDescriptor(file_descriptor);
}
+PyObject* FindExtensionByNumber(PyDescriptorPool* self, PyObject* args) {
+ PyObject* message_descriptor;
+ int number;
+ if (!PyArg_ParseTuple(args, "Oi", &message_descriptor, &number)) {
+ return NULL;
+ }
+ const Descriptor* descriptor = PyMessageDescriptor_AsDescriptor(
+ message_descriptor);
+ if (descriptor == NULL) {
+ return NULL;
+ }
+
+ const FieldDescriptor* extension_descriptor =
+ self->pool->FindExtensionByNumber(descriptor, number);
+ if (extension_descriptor == NULL) {
+ PyErr_Format(PyExc_KeyError, "Couldn't find extension %d", number);
+ return NULL;
+ }
+
+ return PyFieldDescriptor_FromDescriptor(extension_descriptor);
+}
+
+PyObject* FindAllExtensions(PyDescriptorPool* self, PyObject* arg) {
+ const Descriptor* descriptor = PyMessageDescriptor_AsDescriptor(arg);
+ if (descriptor == NULL) {
+ return NULL;
+ }
+
+ std::vector<const FieldDescriptor*> extensions;
+ self->pool->FindAllExtensions(descriptor, &extensions);
+
+ ScopedPyObjectPtr result(PyList_New(extensions.size()));
+ if (result == NULL) {
+ return NULL;
+ }
+ for (int i = 0; i < extensions.size(); i++) {
+ PyObject* extension = PyFieldDescriptor_FromDescriptor(extensions[i]);
+ if (extension == NULL) {
+ return NULL;
+ }
+ PyList_SET_ITEM(result.get(), i, extension); // Steals the reference.
+ }
+ return result.release();
+}
+
// These functions should not exist -- the only valid way to create
// descriptors is to call Add() or AddSerializedFile().
// But these AddDescriptor() functions were created in Python and some people
@@ -376,6 +421,22 @@ PyObject* AddEnumDescriptor(PyDescriptorPool* self, PyObject* descriptor) {
Py_RETURN_NONE;
}
+PyObject* AddExtensionDescriptor(PyDescriptorPool* self, PyObject* descriptor) {
+ const FieldDescriptor* extension_descriptor =
+ PyFieldDescriptor_AsDescriptor(descriptor);
+ if (!extension_descriptor) {
+ return NULL;
+ }
+ if (extension_descriptor !=
+ self->pool->FindExtensionByName(extension_descriptor->full_name())) {
+ PyErr_Format(PyExc_ValueError,
+ "The extension descriptor %s does not belong to this pool",
+ extension_descriptor->full_name().c_str());
+ return NULL;
+ }
+ Py_RETURN_NONE;
+}
+
// The code below loads new Descriptors from a serialized FileDescriptorProto.
@@ -475,6 +536,8 @@ static PyMethodDef Methods[] = {
"No-op. Add() must have been called before." },
{ "AddEnumDescriptor", (PyCFunction)AddEnumDescriptor, METH_O,
"No-op. Add() must have been called before." },
+ { "AddExtensionDescriptor", (PyCFunction)AddExtensionDescriptor, METH_O,
+ "No-op. Add() must have been called before." },
{ "FindFileByName", (PyCFunction)FindFileByName, METH_O,
"Searches for a file descriptor by its .proto name." },
@@ -495,6 +558,10 @@ static PyMethodDef Methods[] = {
{ "FindFileContainingSymbol", (PyCFunction)FindFileContainingSymbol, METH_O,
"Gets the FileDescriptor containing the specified symbol." },
+ { "FindExtensionByNumber", (PyCFunction)FindExtensionByNumber, METH_VARARGS,
+ "Gets the extension descriptor for the given number." },
+ { "FindAllExtensions", (PyCFunction)FindAllExtensions, METH_O,
+ "Gets all known extensions of the given message descriptor." },
{NULL}
};
diff --git a/python/google/protobuf/pyext/extension_dict.cc b/python/google/protobuf/pyext/extension_dict.cc
index dbb7bca0..9423c1d8 100644
--- a/python/google/protobuf/pyext/extension_dict.cc
+++ b/python/google/protobuf/pyext/extension_dict.cc
@@ -38,6 +38,7 @@
#include <google/protobuf/descriptor.h>
#include <google/protobuf/dynamic_message.h>
#include <google/protobuf/message.h>
+#include <google/protobuf/descriptor.pb.h>
#include <google/protobuf/pyext/descriptor.h>
#include <google/protobuf/pyext/message.h>
#include <google/protobuf/pyext/message_factory.h>
@@ -46,6 +47,16 @@
#include <google/protobuf/pyext/scoped_pyobject_ptr.h>
#include <google/protobuf/stubs/shared_ptr.h>
+#if PY_MAJOR_VERSION >= 3
+ #if PY_VERSION_HEX < 0x03030000
+ #error "Python 3.0 - 3.2 are not supported."
+ #endif
+ #define PyString_AsStringAndSize(ob, charpp, sizep) \
+ (PyUnicode_Check(ob)? \
+ ((*(charpp) = PyUnicode_AsUTF8AndSize(ob, (sizep))) == NULL? -1: 0): \
+ PyBytes_AsStringAndSize(ob, (charpp), (sizep)))
+#endif
+
namespace google {
namespace protobuf {
namespace python {
@@ -90,6 +101,7 @@ PyObject* subscript(ExtensionDict* self, PyObject* key) {
if (descriptor->label() != FieldDescriptor::LABEL_REPEATED &&
descriptor->cpp_type() == FieldDescriptor::CPPTYPE_MESSAGE) {
+ // TODO(plabatut): consider building the class on the fly!
PyObject* sub_message = cmessage::InternalGetSubMessage(
self->parent, descriptor);
if (sub_message == NULL) {
@@ -101,7 +113,17 @@ PyObject* subscript(ExtensionDict* self, PyObject* key) {
if (descriptor->label() == FieldDescriptor::LABEL_REPEATED) {
if (descriptor->cpp_type() == FieldDescriptor::CPPTYPE_MESSAGE) {
- CMessageClass* message_class = message_factory::GetMessageClass(
+ // On the fly message class creation is needed to support the following
+ // situation:
+ // 1- add FileDescriptor to the pool that contains extensions of a message
+ // defined by another proto file. Do not create any message classes.
+ // 2- instantiate an extended message, and access the extension using
+ // the field descriptor.
+ // 3- the extension submessage fails to be returned, because no class has
+ // been created.
+ // It happens when deserializing text proto format, or when enumerating
+ // fields of a deserialized message.
+ CMessageClass* message_class = message_factory::GetOrCreateMessageClass(
cmessage::GetFactoryForMessage(self->parent),
descriptor->message_type());
if (message_class == NULL) {
@@ -154,34 +176,51 @@ int ass_subscript(ExtensionDict* self, PyObject* key, PyObject* value) {
return 0;
}
-PyObject* _FindExtensionByName(ExtensionDict* self, PyObject* name) {
- ScopedPyObjectPtr extensions_by_name(PyObject_GetAttrString(
- reinterpret_cast<PyObject*>(self->parent), "_extensions_by_name"));
- if (extensions_by_name == NULL) {
+PyObject* _FindExtensionByName(ExtensionDict* self, PyObject* arg) {
+ char* name;
+ Py_ssize_t name_size;
+ if (PyString_AsStringAndSize(arg, &name, &name_size) < 0) {
return NULL;
}
- PyObject* result = PyDict_GetItem(extensions_by_name.get(), name);
- if (result == NULL) {
+
+ PyDescriptorPool* pool = cmessage::GetFactoryForMessage(self->parent)->pool;
+ const FieldDescriptor* message_extension =
+ pool->pool->FindExtensionByName(string(name, name_size));
+ if (message_extension == NULL) {
+ // Is is the name of a message set extension?
+ const Descriptor* message_descriptor = pool->pool->FindMessageTypeByName(
+ string(name, name_size));
+ if (message_descriptor && message_descriptor->extension_count() > 0) {
+ const FieldDescriptor* extension = message_descriptor->extension(0);
+ if (extension->is_extension() &&
+ extension->containing_type()->options().message_set_wire_format() &&
+ extension->type() == FieldDescriptor::TYPE_MESSAGE &&
+ extension->label() == FieldDescriptor::LABEL_OPTIONAL) {
+ message_extension = extension;
+ }
+ }
+ }
+ if (message_extension == NULL) {
Py_RETURN_NONE;
- } else {
- Py_INCREF(result);
- return result;
}
+
+ return PyFieldDescriptor_FromDescriptor(message_extension);
}
-PyObject* _FindExtensionByNumber(ExtensionDict* self, PyObject* number) {
- ScopedPyObjectPtr extensions_by_number(PyObject_GetAttrString(
- reinterpret_cast<PyObject*>(self->parent), "_extensions_by_number"));
- if (extensions_by_number == NULL) {
+PyObject* _FindExtensionByNumber(ExtensionDict* self, PyObject* arg) {
+ int64 number = PyLong_AsLong(arg);
+ if (number == -1 && PyErr_Occurred()) {
return NULL;
}
- PyObject* result = PyDict_GetItem(extensions_by_number.get(), number);
- if (result == NULL) {
+
+ PyDescriptorPool* pool = cmessage::GetFactoryForMessage(self->parent)->pool;
+ const FieldDescriptor* message_extension = pool->pool->FindExtensionByNumber(
+ self->parent->message->GetDescriptor(), number);
+ if (message_extension == NULL) {
Py_RETURN_NONE;
- } else {
- Py_INCREF(result);
- return result;
}
+
+ return PyFieldDescriptor_FromDescriptor(message_extension);
}
ExtensionDict* NewExtensionDict(CMessage *parent) {
diff --git a/python/google/protobuf/pyext/map_container.cc b/python/google/protobuf/pyext/map_container.cc
index 318c2e7c..088ddf93 100644
--- a/python/google/protobuf/pyext/map_container.cc
+++ b/python/google/protobuf/pyext/map_container.cc
@@ -374,7 +374,7 @@ static int InitializeAndCopyToParentContainer(MapContainer* from,
// A somewhat roundabout way of copying just one field from old_message to
// new_message. This is the best we can do with what Reflection gives us.
Message* mutable_old = from->GetMutableMessage();
- vector<const FieldDescriptor*> fields;
+ std::vector<const FieldDescriptor*> fields;
fields.push_back(from->parent_field_descriptor);
// Move the map field into the new message.
diff --git a/python/google/protobuf/pyext/message.cc b/python/google/protobuf/pyext/message.cc
index 7ff99aea..5967a587 100644
--- a/python/google/protobuf/pyext/message.cc
+++ b/python/google/protobuf/pyext/message.cc
@@ -64,11 +64,11 @@
#include <google/protobuf/pyext/repeated_scalar_container.h>
#include <google/protobuf/pyext/map_container.h>
#include <google/protobuf/pyext/message_factory.h>
+#include <google/protobuf/pyext/safe_numerics.h>
#include <google/protobuf/pyext/scoped_pyobject_ptr.h>
#include <google/protobuf/stubs/strutil.h>
#if PY_MAJOR_VERSION >= 3
- #define PyInt_Check PyLong_Check
#define PyInt_AsLong PyLong_AsLong
#define PyInt_FromLong PyLong_FromLong
#define PyInt_FromSize_t PyLong_FromSize_t
@@ -92,8 +92,6 @@ namespace protobuf {
namespace python {
static PyObject* kDESCRIPTOR;
-static PyObject* k_extensions_by_name;
-static PyObject* k_extensions_by_number;
PyObject* EnumTypeWrapper_class;
static PyObject* PythonMessage_class;
static PyObject* kEmptyWeakref;
@@ -128,19 +126,6 @@ static bool AddFieldNumberToClass(
// Finalize the creation of the Message class.
static int AddDescriptors(PyObject* cls, const Descriptor* descriptor) {
- // If there are extension_ranges, the message is "extendable", and extension
- // classes will register themselves in this class.
- if (descriptor->extension_range_count() > 0) {
- ScopedPyObjectPtr by_name(PyDict_New());
- if (PyObject_SetAttr(cls, k_extensions_by_name, by_name.get()) < 0) {
- return -1;
- }
- ScopedPyObjectPtr by_number(PyDict_New());
- if (PyObject_SetAttr(cls, k_extensions_by_number, by_number.get()) < 0) {
- return -1;
- }
- }
-
// For each field set: cls.<field>_FIELD_NUMBER = <number>
for (int i = 0; i < descriptor->field_count(); ++i) {
if (!AddFieldNumberToClass(cls, descriptor->field(i))) {
@@ -357,6 +342,61 @@ static int InsertEmptyWeakref(PyTypeObject *base_type) {
#endif // PY_MAJOR_VERSION >= 3
}
+// The _extensions_by_name dictionary is built on every access.
+// TODO(amauryfa): Migrate all users to pool.FindAllExtensions()
+static PyObject* GetExtensionsByName(CMessageClass *self, void *closure) {
+ const PyDescriptorPool* pool = self->py_message_factory->pool;
+
+ std::vector<const FieldDescriptor*> extensions;
+ pool->pool->FindAllExtensions(self->message_descriptor, &extensions);
+
+ ScopedPyObjectPtr result(PyDict_New());
+ for (int i = 0; i < extensions.size(); i++) {
+ ScopedPyObjectPtr extension(
+ PyFieldDescriptor_FromDescriptor(extensions[i]));
+ if (extension == NULL) {
+ return NULL;
+ }
+ if (PyDict_SetItemString(result.get(), extensions[i]->full_name().c_str(),
+ extension.get()) < 0) {
+ return NULL;
+ }
+ }
+ return result.release();
+}
+
+// The _extensions_by_number dictionary is built on every access.
+// TODO(amauryfa): Migrate all users to pool.FindExtensionByNumber()
+static PyObject* GetExtensionsByNumber(CMessageClass *self, void *closure) {
+ const PyDescriptorPool* pool = self->py_message_factory->pool;
+
+ std::vector<const FieldDescriptor*> extensions;
+ pool->pool->FindAllExtensions(self->message_descriptor, &extensions);
+
+ ScopedPyObjectPtr result(PyDict_New());
+ for (int i = 0; i < extensions.size(); i++) {
+ ScopedPyObjectPtr extension(
+ PyFieldDescriptor_FromDescriptor(extensions[i]));
+ if (extension == NULL) {
+ return NULL;
+ }
+ ScopedPyObjectPtr number(PyInt_FromLong(extensions[i]->number()));
+ if (number == NULL) {
+ return NULL;
+ }
+ if (PyDict_SetItem(result.get(), number.get(), extension.get()) < 0) {
+ return NULL;
+ }
+ }
+ return result.release();
+}
+
+static PyGetSetDef Getters[] = {
+ {"_extensions_by_name", (getter)GetExtensionsByName, NULL},
+ {"_extensions_by_number", (getter)GetExtensionsByNumber, NULL},
+ {NULL}
+};
+
} // namespace message_meta
PyTypeObject CMessageClass_Type = {
@@ -389,7 +429,7 @@ PyTypeObject CMessageClass_Type = {
0, // tp_iternext
0, // tp_methods
0, // tp_members
- 0, // tp_getset
+ message_meta::Getters, // tp_getset
0, // tp_base
0, // tp_dict
0, // tp_descr_get
@@ -525,23 +565,10 @@ int ForEachCompositeField(CMessage* self, Visitor visitor) {
// ---------------------------------------------------------------------
-// Constants used for integer type range checking.
-PyObject* kPythonZero;
-PyObject* kint32min_py;
-PyObject* kint32max_py;
-PyObject* kuint32max_py;
-PyObject* kint64min_py;
-PyObject* kint64max_py;
-PyObject* kuint64max_py;
-
PyObject* EncodeError_class;
PyObject* DecodeError_class;
PyObject* PickleError_class;
-// Constant PyString values used for GetAttr/GetItem.
-static PyObject* k_cdescriptor;
-static PyObject* kfull_name;
-
/* Is 64bit */
void FormatTypeError(PyObject* arg, char* expected_types) {
PyObject* repr = PyObject_Repr(arg);
@@ -555,68 +582,126 @@ void FormatTypeError(PyObject* arg, char* expected_types) {
}
}
-template<class T>
-bool CheckAndGetInteger(
- PyObject* arg, T* value, PyObject* min, PyObject* max) {
- bool is_long = PyLong_Check(arg);
-#if PY_MAJOR_VERSION < 3
- if (!PyInt_Check(arg) && !is_long) {
- FormatTypeError(arg, "int, long");
- return false;
+void OutOfRangeError(PyObject* arg) {
+ PyObject *s = PyObject_Str(arg);
+ if (s) {
+ PyErr_Format(PyExc_ValueError,
+ "Value out of range: %s",
+ PyString_AsString(s));
+ Py_DECREF(s);
}
- if (PyObject_Compare(min, arg) > 0 || PyObject_Compare(max, arg) < 0) {
-#else
- if (!is_long) {
- FormatTypeError(arg, "int");
+}
+
+template<class RangeType, class ValueType>
+bool VerifyIntegerCastAndRange(PyObject* arg, ValueType value) {
+ if GOOGLE_PREDICT_FALSE(value == -1 && PyErr_Occurred()) {
+ if (PyErr_ExceptionMatches(PyExc_OverflowError)) {
+ // Replace it with the same ValueError as pure python protos instead of
+ // the default one.
+ PyErr_Clear();
+ OutOfRangeError(arg);
+ } // Otherwise propagate existing error.
return false;
}
- if (PyObject_RichCompareBool(min, arg, Py_LE) != 1 ||
- PyObject_RichCompareBool(max, arg, Py_GE) != 1) {
-#endif
- if (!PyErr_Occurred()) {
- PyObject *s = PyObject_Str(arg);
- if (s) {
- PyErr_Format(PyExc_ValueError,
- "Value out of range: %s",
- PyString_AsString(s));
- Py_DECREF(s);
- }
- }
+ if GOOGLE_PREDICT_FALSE(!IsValidNumericCast<RangeType>(value)) {
+ OutOfRangeError(arg);
return false;
}
+ return true;
+}
+
+template<class T>
+bool CheckAndGetInteger(PyObject* arg, T* value) {
+ // The fast path.
#if PY_MAJOR_VERSION < 3
- if (!is_long) {
- *value = static_cast<T>(PyInt_AsLong(arg));
- } else // NOLINT
+ // For the typical case, offer a fast path.
+ if GOOGLE_PREDICT_TRUE(PyInt_Check(arg)) {
+ long int_result = PyInt_AsLong(arg);
+ if GOOGLE_PREDICT_TRUE(IsValidNumericCast<T>(int_result)) {
+ *value = static_cast<T>(int_result);
+ return true;
+ } else {
+ OutOfRangeError(arg);
+ return false;
+ }
+ }
#endif
- {
- if (min == kPythonZero) {
- *value = static_cast<T>(PyLong_AsUnsignedLongLong(arg));
+ // This effectively defines an integer as "an object that can be cast as
+ // an integer and can be used as an ordinal number".
+ // This definition includes everything that implements numbers.Integral
+ // and shouldn't cast the net too wide.
+ if GOOGLE_PREDICT_FALSE(!PyIndex_Check(arg)) {
+ FormatTypeError(arg, "int, long");
+ return false;
+ }
+
+ // Now we have an integral number so we can safely use PyLong_ functions.
+ // We need to treat the signed and unsigned cases differently in case arg is
+ // holding a value above the maximum for signed longs.
+ if (std::numeric_limits<T>::min() == 0) {
+ // Unsigned case.
+ unsigned PY_LONG_LONG ulong_result;
+ if (PyLong_Check(arg)) {
+ ulong_result = PyLong_AsUnsignedLongLong(arg);
+ } else {
+ // Unlike PyLong_AsLongLong, PyLong_AsUnsignedLongLong is very
+ // picky about the exact type.
+ PyObject* casted = PyNumber_Long(arg);
+ if GOOGLE_PREDICT_FALSE(casted == NULL) {
+ // Propagate existing error.
+ return false;
+ }
+ ulong_result = PyLong_AsUnsignedLongLong(casted);
+ Py_DECREF(casted);
+ }
+ if (VerifyIntegerCastAndRange<T, unsigned PY_LONG_LONG>(arg,
+ ulong_result)) {
+ *value = static_cast<T>(ulong_result);
+ } else {
+ return false;
+ }
+ } else {
+ // Signed case.
+ PY_LONG_LONG long_result;
+ PyNumberMethods *nb;
+ if ((nb = arg->ob_type->tp_as_number) != NULL && nb->nb_int != NULL) {
+ // PyLong_AsLongLong requires it to be a long or to have an __int__()
+ // method.
+ long_result = PyLong_AsLongLong(arg);
} else {
- *value = static_cast<T>(PyLong_AsLongLong(arg));
+ // Valid subclasses of numbers.Integral should have a __long__() method
+ // so fall back to that.
+ PyObject* casted = PyNumber_Long(arg);
+ if GOOGLE_PREDICT_FALSE(casted == NULL) {
+ // Propagate existing error.
+ return false;
+ }
+ long_result = PyLong_AsLongLong(casted);
+ Py_DECREF(casted);
+ }
+ if (VerifyIntegerCastAndRange<T, PY_LONG_LONG>(arg, long_result)) {
+ *value = static_cast<T>(long_result);
+ } else {
+ return false;
}
}
+
return true;
}
// These are referenced by repeated_scalar_container, and must
// be explicitly instantiated.
-template bool CheckAndGetInteger<int32>(
- PyObject*, int32*, PyObject*, PyObject*);
-template bool CheckAndGetInteger<int64>(
- PyObject*, int64*, PyObject*, PyObject*);
-template bool CheckAndGetInteger<uint32>(
- PyObject*, uint32*, PyObject*, PyObject*);
-template bool CheckAndGetInteger<uint64>(
- PyObject*, uint64*, PyObject*, PyObject*);
+template bool CheckAndGetInteger<int32>(PyObject*, int32*);
+template bool CheckAndGetInteger<int64>(PyObject*, int64*);
+template bool CheckAndGetInteger<uint32>(PyObject*, uint32*);
+template bool CheckAndGetInteger<uint64>(PyObject*, uint64*);
bool CheckAndGetDouble(PyObject* arg, double* value) {
- if (!PyInt_Check(arg) && !PyLong_Check(arg) &&
- !PyFloat_Check(arg)) {
+ *value = PyFloat_AsDouble(arg);
+ if GOOGLE_PREDICT_FALSE(*value == -1 && PyErr_Occurred()) {
FormatTypeError(arg, "int, long, float");
return false;
}
- *value = PyFloat_AsDouble(arg);
return true;
}
@@ -630,11 +715,13 @@ bool CheckAndGetFloat(PyObject* arg, float* value) {
}
bool CheckAndGetBool(PyObject* arg, bool* value) {
- if (!PyInt_Check(arg) && !PyBool_Check(arg) && !PyLong_Check(arg)) {
+ long long_value = PyInt_AsLong(arg);
+ if (long_value == -1 && PyErr_Occurred()) {
FormatTypeError(arg, "int, long, bool");
return false;
}
- *value = static_cast<bool>(PyInt_AsLong(arg));
+ *value = static_cast<bool>(long_value);
+
return true;
}
@@ -966,20 +1053,7 @@ int InternalDeleteRepeatedField(
int min, max;
length = reflection->FieldSize(*message, field_descriptor);
- if (PyInt_Check(slice) || PyLong_Check(slice)) {
- from = to = PyLong_AsLong(slice);
- if (from < 0) {
- from = to = length + from;
- }
- step = 1;
- min = max = from;
-
- // Range check.
- if (from < 0 || from >= length) {
- PyErr_Format(PyExc_IndexError, "list assignment index out of range");
- return -1;
- }
- } else if (PySlice_Check(slice)) {
+ if (PySlice_Check(slice)) {
from = to = step = slice_length = 0;
PySlice_GetIndicesEx(
#if PY_MAJOR_VERSION < 3
@@ -996,8 +1070,23 @@ int InternalDeleteRepeatedField(
max = from;
}
} else {
- PyErr_SetString(PyExc_TypeError, "list indices must be integers");
- return -1;
+ from = to = PyLong_AsLong(slice);
+ if (from == -1 && PyErr_Occurred()) {
+ PyErr_SetString(PyExc_TypeError, "list indices must be integers");
+ return -1;
+ }
+
+ if (from < 0) {
+ from = to = length + from;
+ }
+ step = 1;
+ min = max = from;
+
+ // Range check.
+ if (from < 0 || from >= length) {
+ PyErr_Format(PyExc_IndexError, "list assignment index out of range");
+ return -1;
+ }
}
Py_ssize_t i = from;
@@ -1958,99 +2047,29 @@ static PyObject* ByteSize(CMessage* self, PyObject* args) {
return PyLong_FromLong(self->message->ByteSize());
}
-static PyObject* RegisterExtension(PyObject* cls,
- PyObject* extension_handle) {
+PyObject* RegisterExtension(PyObject* cls, PyObject* extension_handle) {
const FieldDescriptor* descriptor =
GetExtensionDescriptor(extension_handle);
if (descriptor == NULL) {
return NULL;
}
-
- ScopedPyObjectPtr extensions_by_name(
- PyObject_GetAttr(cls, k_extensions_by_name));
- if (extensions_by_name == NULL) {
- PyErr_SetString(PyExc_TypeError, "no extensions_by_name on class");
+ if (!PyObject_TypeCheck(cls, &CMessageClass_Type)) {
+ PyErr_Format(PyExc_TypeError, "Expected a message class, got %s",
+ cls->ob_type->tp_name);
return NULL;
}
- ScopedPyObjectPtr full_name(PyObject_GetAttr(extension_handle, kfull_name));
- if (full_name == NULL) {
+ CMessageClass *message_class = reinterpret_cast<CMessageClass*>(cls);
+ if (message_class == NULL) {
return NULL;
}
-
// If the extension was already registered, check that it is the same.
- PyObject* existing_extension =
- PyDict_GetItem(extensions_by_name.get(), full_name.get());
- if (existing_extension != NULL) {
- const FieldDescriptor* existing_extension_descriptor =
- GetExtensionDescriptor(existing_extension);
- if (existing_extension_descriptor != descriptor) {
- PyErr_SetString(PyExc_ValueError, "Double registration of Extensions");
- return NULL;
- }
- // Nothing else to do.
- Py_RETURN_NONE;
- }
-
- if (PyDict_SetItem(extensions_by_name.get(), full_name.get(),
- extension_handle) < 0) {
- return NULL;
- }
-
- // Also store a mapping from extension number to implementing class.
- ScopedPyObjectPtr extensions_by_number(
- PyObject_GetAttr(cls, k_extensions_by_number));
- if (extensions_by_number == NULL) {
- PyErr_SetString(PyExc_TypeError, "no extensions_by_number on class");
- return NULL;
- }
-
- ScopedPyObjectPtr number(PyObject_GetAttrString(extension_handle, "number"));
- if (number == NULL) {
- return NULL;
- }
-
- // If the extension was already registered by number, check that it is the
- // same.
- existing_extension = PyDict_GetItem(extensions_by_number.get(), number.get());
- if (existing_extension != NULL) {
- const FieldDescriptor* existing_extension_descriptor =
- GetExtensionDescriptor(existing_extension);
- if (existing_extension_descriptor != descriptor) {
- const Descriptor* msg_desc = GetMessageDescriptor(
- reinterpret_cast<PyTypeObject*>(cls));
- PyErr_Format(
- PyExc_ValueError,
- "Extensions \"%s\" and \"%s\" both try to extend message type "
- "\"%s\" with field number %ld.",
- existing_extension_descriptor->full_name().c_str(),
- descriptor->full_name().c_str(),
- msg_desc->full_name().c_str(),
- PyInt_AsLong(number.get()));
- return NULL;
- }
- // Nothing else to do.
- Py_RETURN_NONE;
- }
- if (PyDict_SetItem(extensions_by_number.get(), number.get(),
- extension_handle) < 0) {
+ const FieldDescriptor* existing_extension =
+ message_class->py_message_factory->pool->pool->FindExtensionByNumber(
+ descriptor->containing_type(), descriptor->number());
+ if (existing_extension != NULL && existing_extension != descriptor) {
+ PyErr_SetString(PyExc_ValueError, "Double registration of Extensions");
return NULL;
}
-
- // Check if it's a message set
- if (descriptor->is_extension() &&
- descriptor->containing_type()->options().message_set_wire_format() &&
- descriptor->type() == FieldDescriptor::TYPE_MESSAGE &&
- descriptor->label() == FieldDescriptor::LABEL_OPTIONAL) {
- ScopedPyObjectPtr message_name(PyString_FromStringAndSize(
- descriptor->message_type()->full_name().c_str(),
- descriptor->message_type()->full_name().size()));
- if (message_name == NULL) {
- return NULL;
- }
- PyDict_SetItem(extensions_by_name.get(), message_name.get(),
- extension_handle);
- }
-
Py_RETURN_NONE;
}
@@ -2087,7 +2106,7 @@ static PyObject* WhichOneof(CMessage* self, PyObject* arg) {
static PyObject* GetExtensionDict(CMessage* self, void *closure);
static PyObject* ListFields(CMessage* self) {
- vector<const FieldDescriptor*> fields;
+ std::vector<const FieldDescriptor*> fields;
self->message->GetReflection()->ListFields(*self->message, &fields);
// Normally, the list will be exactly the size of the fields.
@@ -2178,7 +2197,7 @@ static PyObject* DiscardUnknownFields(CMessage* self) {
PyObject* FindInitializationErrors(CMessage* self) {
Message* message = self->message;
- vector<string> errors;
+ std::vector<string> errors;
message->FindInitializationErrors(&errors);
PyObject* error_list = PyList_New(errors.size());
@@ -2570,11 +2589,24 @@ static PyObject* GetExtensionDict(CMessage* self, void *closure) {
return NULL;
}
+static PyObject* GetExtensionsByName(CMessage *self, void *closure) {
+ return message_meta::GetExtensionsByName(
+ reinterpret_cast<CMessageClass*>(Py_TYPE(self)), closure);
+}
+
+static PyObject* GetExtensionsByNumber(CMessage *self, void *closure) {
+ return message_meta::GetExtensionsByNumber(
+ reinterpret_cast<CMessageClass*>(Py_TYPE(self)), closure);
+}
+
static PyGetSetDef Getters[] = {
{"Extensions", (getter)GetExtensionDict, NULL, "Extension dict"},
+ {"_extensions_by_name", (getter)GetExtensionsByName, NULL},
+ {"_extensions_by_number", (getter)GetExtensionsByNumber, NULL},
{NULL}
};
+
static PyMethodDef Methods[] = {
{ "__deepcopy__", (PyCFunction)DeepCopy, METH_VARARGS,
"Makes a deep copy of the class." },
@@ -2835,19 +2867,7 @@ void InitGlobals() {
// TODO(gps): Check all return values in this function for NULL and propagate
// the error (MemoryError) on up to result in an import failure. These should
// also be freed and reset to NULL during finalization.
- kPythonZero = PyInt_FromLong(0);
- kint32min_py = PyInt_FromLong(kint32min);
- kint32max_py = PyInt_FromLong(kint32max);
- kuint32max_py = PyLong_FromLongLong(kuint32max);
- kint64min_py = PyLong_FromLongLong(kint64min);
- kint64max_py = PyLong_FromLongLong(kint64max);
- kuint64max_py = PyLong_FromUnsignedLongLong(kuint64max);
-
kDESCRIPTOR = PyString_FromString("DESCRIPTOR");
- k_cdescriptor = PyString_FromString("_cdescriptor");
- kfull_name = PyString_FromString("full_name");
- k_extensions_by_name = PyString_FromString("_extensions_by_name");
- k_extensions_by_number = PyString_FromString("_extensions_by_number");
PyObject *dummy_obj = PySet_New(NULL);
kEmptyWeakref = PyWeakref_NewRef(dummy_obj, NULL);
@@ -2887,25 +2907,6 @@ bool InitProto2MessageModule(PyObject *m) {
// DESCRIPTOR is set on each protocol buffer message class elsewhere, but set
// it here as well to document that subclasses need to set it.
PyDict_SetItem(CMessage_Type.tp_dict, kDESCRIPTOR, Py_None);
- // Subclasses with message extensions will override _extensions_by_name and
- // _extensions_by_number with fresh mutable dictionaries in AddDescriptors.
- // All other classes can share this same immutable mapping.
- ScopedPyObjectPtr empty_dict(PyDict_New());
- if (empty_dict == NULL) {
- return false;
- }
- ScopedPyObjectPtr immutable_dict(PyDictProxy_New(empty_dict.get()));
- if (immutable_dict == NULL) {
- return false;
- }
- if (PyDict_SetItem(CMessage_Type.tp_dict,
- k_extensions_by_name, immutable_dict.get()) < 0) {
- return false;
- }
- if (PyDict_SetItem(CMessage_Type.tp_dict,
- k_extensions_by_number, immutable_dict.get()) < 0) {
- return false;
- }
PyModule_AddObject(m, "Message", reinterpret_cast<PyObject*>(&CMessage_Type));
diff --git a/python/google/protobuf/pyext/message.h b/python/google/protobuf/pyext/message.h
index 1550724c..ce80497e 100644
--- a/python/google/protobuf/pyext/message.h
+++ b/python/google/protobuf/pyext/message.h
@@ -117,6 +117,7 @@ typedef struct CMessage {
PyObject* weakreflist;
} CMessage;
+extern PyTypeObject CMessageClass_Type;
extern PyTypeObject CMessage_Type;
@@ -235,6 +236,10 @@ int InitAttributes(CMessage* self, PyObject* args, PyObject* kwargs);
PyObject* MergeFrom(CMessage* self, PyObject* arg);
+// This method does not do anything beyond checking that no other extension
+// has been registered with the same field number on this class.
+PyObject* RegisterExtension(PyObject* cls, PyObject* extension_handle);
+
// Retrieves an attribute named 'name' from CMessage 'self'. Returns
// the attribute value on success, or NULL on failure.
//
@@ -275,25 +280,25 @@ PyObject* SetAllowOversizeProtos(PyObject* m, PyObject* arg);
#define GOOGLE_CHECK_GET_INT32(arg, value, err) \
int32 value; \
- if (!CheckAndGetInteger(arg, &value, kint32min_py, kint32max_py)) { \
+ if (!CheckAndGetInteger(arg, &value)) { \
return err; \
}
#define GOOGLE_CHECK_GET_INT64(arg, value, err) \
int64 value; \
- if (!CheckAndGetInteger(arg, &value, kint64min_py, kint64max_py)) { \
+ if (!CheckAndGetInteger(arg, &value)) { \
return err; \
}
#define GOOGLE_CHECK_GET_UINT32(arg, value, err) \
uint32 value; \
- if (!CheckAndGetInteger(arg, &value, kPythonZero, kuint32max_py)) { \
+ if (!CheckAndGetInteger(arg, &value)) { \
return err; \
}
#define GOOGLE_CHECK_GET_UINT64(arg, value, err) \
uint64 value; \
- if (!CheckAndGetInteger(arg, &value, kPythonZero, kuint64max_py)) { \
+ if (!CheckAndGetInteger(arg, &value)) { \
return err; \
}
@@ -316,20 +321,11 @@ PyObject* SetAllowOversizeProtos(PyObject* m, PyObject* arg);
}
-extern PyObject* kPythonZero;
-extern PyObject* kint32min_py;
-extern PyObject* kint32max_py;
-extern PyObject* kuint32max_py;
-extern PyObject* kint64min_py;
-extern PyObject* kint64max_py;
-extern PyObject* kuint64max_py;
-
#define FULL_MODULE_NAME "google.protobuf.pyext._message"
void FormatTypeError(PyObject* arg, char* expected_types);
template<class T>
-bool CheckAndGetInteger(
- PyObject* arg, T* value, PyObject* min, PyObject* max);
+bool CheckAndGetInteger(PyObject* arg, T* value);
bool CheckAndGetDouble(PyObject* arg, double* value);
bool CheckAndGetFloat(PyObject* arg, float* value);
bool CheckAndGetBool(PyObject* arg, bool* value);
diff --git a/python/google/protobuf/pyext/message_factory.cc b/python/google/protobuf/pyext/message_factory.cc
index 2ad89022..e0b45bf2 100644
--- a/python/google/protobuf/pyext/message_factory.cc
+++ b/python/google/protobuf/pyext/message_factory.cc
@@ -130,6 +130,72 @@ int RegisterMessageClass(PyMessageFactory* self,
return 0;
}
+CMessageClass* GetOrCreateMessageClass(PyMessageFactory* self,
+ const Descriptor* descriptor) {
+ // This is the same implementation as MessageFactory.GetPrototype().
+ ScopedPyObjectPtr py_descriptor(
+ PyMessageDescriptor_FromDescriptor(descriptor));
+ if (py_descriptor == NULL) {
+ return NULL;
+ }
+ // Do not create a MessageClass that already exists.
+ hash_map<const Descriptor*, CMessageClass*>::iterator it =
+ self->classes_by_descriptor->find(descriptor);
+ if (it != self->classes_by_descriptor->end()) {
+ Py_INCREF(it->second);
+ return it->second;
+ }
+ // Create a new message class.
+ ScopedPyObjectPtr args(Py_BuildValue(
+ "s(){sOsOsO}", descriptor->name().c_str(),
+ "DESCRIPTOR", py_descriptor.get(),
+ "__module__", Py_None,
+ "message_factory", self));
+ if (args == NULL) {
+ return NULL;
+ }
+ ScopedPyObjectPtr message_class(PyObject_CallObject(
+ reinterpret_cast<PyObject*>(&CMessageClass_Type), args.get()));
+ if (message_class == NULL) {
+ return NULL;
+ }
+ // Create messages class for the messages used by the fields, and registers
+ // all extensions for these messages during the recursion.
+ for (int field_idx = 0; field_idx < descriptor->field_count(); field_idx++) {
+ const Descriptor* sub_descriptor =
+ descriptor->field(field_idx)->message_type();
+ // It is NULL if the field type is not a message.
+ if (sub_descriptor != NULL) {
+ CMessageClass* result = GetOrCreateMessageClass(self, sub_descriptor);
+ if (result == NULL) {
+ return NULL;
+ }
+ Py_DECREF(result);
+ }
+ }
+
+ // Register extensions defined in this message.
+ for (int ext_idx = 0 ; ext_idx < descriptor->extension_count() ; ext_idx++) {
+ const FieldDescriptor* extension = descriptor->extension(ext_idx);
+ ScopedPyObjectPtr py_extended_class(
+ GetOrCreateMessageClass(self, extension->containing_type())
+ ->AsPyObject());
+ if (py_extended_class == NULL) {
+ return NULL;
+ }
+ ScopedPyObjectPtr py_extension(PyFieldDescriptor_FromDescriptor(extension));
+ if (py_extension == NULL) {
+ return NULL;
+ }
+ ScopedPyObjectPtr result(cmessage::RegisterExtension(
+ py_extended_class.get(), py_extension.get()));
+ if (result == NULL) {
+ return NULL;
+ }
+ }
+ return reinterpret_cast<CMessageClass*>(message_class.release());
+}
+
// Retrieve the message class added to our database.
CMessageClass* GetMessageClass(PyMessageFactory* self,
const Descriptor* message_descriptor) {
diff --git a/python/google/protobuf/pyext/message_factory.h b/python/google/protobuf/pyext/message_factory.h
index 07cccbfb..36092f7e 100644
--- a/python/google/protobuf/pyext/message_factory.h
+++ b/python/google/protobuf/pyext/message_factory.h
@@ -82,14 +82,14 @@ PyMessageFactory* NewMessageFactory(PyTypeObject* type, PyDescriptorPool* pool);
int RegisterMessageClass(PyMessageFactory* self,
const Descriptor* message_descriptor,
CMessageClass* message_class);
-
-// Retrieves the Python class registered with the given message descriptor.
-//
-// Returns a *borrowed* reference if found, otherwise returns NULL with an
-// exception set.
+// Retrieves the Python class registered with the given message descriptor, or
+// fail with a TypeError. Returns a *borrowed* reference.
CMessageClass* GetMessageClass(PyMessageFactory* self,
const Descriptor* message_descriptor);
-
+// Retrieves the Python class registered with the given message descriptor.
+// The class is created if not done yet. Returns a *new* reference.
+CMessageClass* GetOrCreateMessageClass(PyMessageFactory* self,
+ const Descriptor* message_descriptor);
} // namespace message_factory
// Initialize objects used by this module.
diff --git a/python/google/protobuf/pyext/safe_numerics.h b/python/google/protobuf/pyext/safe_numerics.h
new file mode 100644
index 00000000..639ba2c8
--- /dev/null
+++ b/python/google/protobuf/pyext/safe_numerics.h
@@ -0,0 +1,164 @@
+// Protocol Buffers - Google's data interchange format
+// Copyright 2008 Google Inc. All rights reserved.
+// https://developers.google.com/protocol-buffers/
+//
+// Redistribution and use in source and binary forms, with or without
+// modification, are permitted provided that the following conditions are
+// met:
+//
+// * Redistributions of source code must retain the above copyright
+// notice, this list of conditions and the following disclaimer.
+// * Redistributions in binary form must reproduce the above
+// copyright notice, this list of conditions and the following disclaimer
+// in the documentation and/or other materials provided with the
+// distribution.
+// * Neither the name of Google Inc. nor the names of its
+// contributors may be used to endorse or promote products derived from
+// this software without specific prior written permission.
+//
+// THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS
+// "AS IS" AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT
+// LIMITED TO, THE IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR
+// A PARTICULAR PURPOSE ARE DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT
+// OWNER OR CONTRIBUTORS BE LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL,
+// SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT
+// LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE,
+// DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY
+// THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT
+// (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
+// OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
+
+#ifndef GOOGLE_PROTOBUF_PYTHON_CPP_SAFE_NUMERICS_H__
+#define GOOGLE_PROTOBUF_PYTHON_CPP_SAFE_NUMERICS_H__
+// Copied from chromium with only changes to the namespace.
+
+#include <limits>
+
+#include <google/protobuf/stubs/logging.h>
+#include <google/protobuf/stubs/common.h>
+
+namespace google {
+namespace protobuf {
+namespace python {
+
+template <bool SameSize, bool DestLarger,
+ bool DestIsSigned, bool SourceIsSigned>
+struct IsValidNumericCastImpl;
+
+#define BASE_NUMERIC_CAST_CASE_SPECIALIZATION(A, B, C, D, Code) \
+template <> struct IsValidNumericCastImpl<A, B, C, D> { \
+ template <class Source, class DestBounds> static inline bool Test( \
+ Source source, DestBounds min, DestBounds max) { \
+ return Code; \
+ } \
+}
+
+#define BASE_NUMERIC_CAST_CASE_SAME_SIZE(DestSigned, SourceSigned, Code) \
+ BASE_NUMERIC_CAST_CASE_SPECIALIZATION( \
+ true, true, DestSigned, SourceSigned, Code); \
+ BASE_NUMERIC_CAST_CASE_SPECIALIZATION( \
+ true, false, DestSigned, SourceSigned, Code)
+
+#define BASE_NUMERIC_CAST_CASE_SOURCE_LARGER(DestSigned, SourceSigned, Code) \
+ BASE_NUMERIC_CAST_CASE_SPECIALIZATION( \
+ false, false, DestSigned, SourceSigned, Code); \
+
+#define BASE_NUMERIC_CAST_CASE_DEST_LARGER(DestSigned, SourceSigned, Code) \
+ BASE_NUMERIC_CAST_CASE_SPECIALIZATION( \
+ false, true, DestSigned, SourceSigned, Code); \
+
+// The three top level cases are:
+// - Same size
+// - Source larger
+// - Dest larger
+// And for each of those three cases, we handle the 4 different possibilities
+// of signed and unsigned. This gives 12 cases to handle, which we enumerate
+// below.
+//
+// The last argument in each of the macros is the actual comparison code. It
+// has three arguments available, source (the value), and min/max which are
+// the ranges of the destination.
+
+
+// These are the cases where both types have the same size.
+
+// Both signed.
+BASE_NUMERIC_CAST_CASE_SAME_SIZE(true, true, true);
+// Both unsigned.
+BASE_NUMERIC_CAST_CASE_SAME_SIZE(false, false, true);
+// Dest unsigned, Source signed.
+BASE_NUMERIC_CAST_CASE_SAME_SIZE(false, true, source >= 0);
+// Dest signed, Source unsigned.
+// This cast is OK because Dest's max must be less than Source's.
+BASE_NUMERIC_CAST_CASE_SAME_SIZE(true, false,
+ source <= static_cast<Source>(max));
+
+
+// These are the cases where Source is larger.
+
+// Both unsigned.
+BASE_NUMERIC_CAST_CASE_SOURCE_LARGER(false, false, source <= max);
+// Both signed.
+BASE_NUMERIC_CAST_CASE_SOURCE_LARGER(true, true,
+ source >= min && source <= max);
+// Dest is unsigned, Source is signed.
+BASE_NUMERIC_CAST_CASE_SOURCE_LARGER(false, true,
+ source >= 0 && source <= max);
+// Dest is signed, Source is unsigned.
+// This cast is OK because Dest's max must be less than Source's.
+BASE_NUMERIC_CAST_CASE_SOURCE_LARGER(true, false,
+ source <= static_cast<Source>(max));
+
+
+// These are the cases where Dest is larger.
+
+// Both unsigned.
+BASE_NUMERIC_CAST_CASE_DEST_LARGER(false, false, true);
+// Both signed.
+BASE_NUMERIC_CAST_CASE_DEST_LARGER(true, true, true);
+// Dest is unsigned, Source is signed.
+BASE_NUMERIC_CAST_CASE_DEST_LARGER(false, true, source >= 0);
+// Dest is signed, Source is unsigned.
+BASE_NUMERIC_CAST_CASE_DEST_LARGER(true, false, true);
+
+#undef BASE_NUMERIC_CAST_CASE_SPECIALIZATION
+#undef BASE_NUMERIC_CAST_CASE_SAME_SIZE
+#undef BASE_NUMERIC_CAST_CASE_SOURCE_LARGER
+#undef BASE_NUMERIC_CAST_CASE_DEST_LARGER
+
+
+// The main test for whether the conversion will under or overflow.
+template <class Dest, class Source>
+inline bool IsValidNumericCast(Source source) {
+ typedef std::numeric_limits<Source> SourceLimits;
+ typedef std::numeric_limits<Dest> DestLimits;
+ GOOGLE_COMPILE_ASSERT(SourceLimits::is_specialized, argument_must_be_numeric);
+ GOOGLE_COMPILE_ASSERT(SourceLimits::is_integer, argument_must_be_integral);
+ GOOGLE_COMPILE_ASSERT(DestLimits::is_specialized, result_must_be_numeric);
+ GOOGLE_COMPILE_ASSERT(DestLimits::is_integer, result_must_be_integral);
+
+ return IsValidNumericCastImpl<
+ sizeof(Dest) == sizeof(Source),
+ (sizeof(Dest) > sizeof(Source)),
+ DestLimits::is_signed,
+ SourceLimits::is_signed>::Test(
+ source,
+ DestLimits::min(),
+ DestLimits::max());
+}
+
+// checked_numeric_cast<> is analogous to static_cast<> for numeric types,
+// except that it CHECKs that the specified numeric conversion will not
+// overflow or underflow. Floating point arguments are not currently allowed
+// (this is COMPILE_ASSERTd), though this could be supported if necessary.
+template <class Dest, class Source>
+inline Dest checked_numeric_cast(Source source) {
+ GOOGLE_CHECK(IsValidNumericCast<Dest>(source));
+ return static_cast<Dest>(source);
+}
+
+} // namespace python
+} // namespace protobuf
+
+} // namespace google
+#endif // GOOGLE_PROTOBUF_PYTHON_CPP_SAFE_NUMERICS_H__