aboutsummaryrefslogtreecommitdiffhomepage
path: root/python/google/protobuf/internal
diff options
context:
space:
mode:
authorGravatar Adam Cozzette <acozzette@google.com>2016-11-17 16:48:38 -0800
committerGravatar Adam Cozzette <acozzette@google.com>2016-11-17 16:59:59 -0800
commit5a76e633ea9b5adb215e93fdc11e1c0c08b3fc74 (patch)
tree0276f81f8848a05d84cd7e287b43d665e30f04e3 /python/google/protobuf/internal
parente28286fa05d8327fd6c5aa70cfb3be558f0932b8 (diff)
Integrated internal changes from Google
Diffstat (limited to 'python/google/protobuf/internal')
-rwxr-xr-xpython/google/protobuf/internal/decoder.py6
-rw-r--r--python/google/protobuf/internal/descriptor_pool_test.py47
-rw-r--r--python/google/protobuf/internal/message_factory_test.py20
-rwxr-xr-xpython/google/protobuf/internal/python_message.py37
-rwxr-xr-xpython/google/protobuf/internal/reflection_test.py67
-rwxr-xr-xpython/google/protobuf/internal/test_util.py154
-rw-r--r--python/google/protobuf/internal/testing_refleaks.py2
-rwxr-xr-xpython/google/protobuf/internal/text_format_test.py30
-rwxr-xr-xpython/google/protobuf/internal/type_checkers.py15
9 files changed, 304 insertions, 74 deletions
diff --git a/python/google/protobuf/internal/decoder.py b/python/google/protobuf/internal/decoder.py
index 31869e45..c5f73dc1 100755
--- a/python/google/protobuf/internal/decoder.py
+++ b/python/google/protobuf/internal/decoder.py
@@ -642,10 +642,10 @@ def MessageDecoder(field_number, is_repeated, is_packed, key, new_default):
MESSAGE_SET_ITEM_TAG = encoder.TagBytes(1, wire_format.WIRETYPE_START_GROUP)
-def MessageSetItemDecoder(extensions_by_number):
+def MessageSetItemDecoder(descriptor):
"""Returns a decoder for a MessageSet item.
- The parameter is the _extensions_by_number map for the message class.
+ The parameter is the message Descriptor.
The message set message looks like this:
message MessageSet {
@@ -694,7 +694,7 @@ def MessageSetItemDecoder(extensions_by_number):
if message_start == -1:
raise _DecodeError('MessageSet item missing message.')
- extension = extensions_by_number.get(type_id)
+ extension = message.Extensions._FindExtensionByNumber(type_id)
if extension is not None:
value = field_dict.get(extension)
if value is None:
diff --git a/python/google/protobuf/internal/descriptor_pool_test.py b/python/google/protobuf/internal/descriptor_pool_test.py
index d4de2d81..cb6abe6c 100644
--- a/python/google/protobuf/internal/descriptor_pool_test.py
+++ b/python/google/protobuf/internal/descriptor_pool_test.py
@@ -254,6 +254,53 @@ class DescriptorPoolTest(unittest.TestCase):
with self.assertRaises(KeyError):
self.pool.FindFieldByName('Does not exist')
+ def testFindAllExtensions(self):
+ factory1_message = self.pool.FindMessageTypeByName(
+ 'google.protobuf.python.internal.Factory1Message')
+ factory2_message = self.pool.FindMessageTypeByName(
+ 'google.protobuf.python.internal.Factory2Message')
+ # An extension defined in a message.
+ one_more_field = factory2_message.extensions_by_name['one_more_field']
+ self.pool.AddExtensionDescriptor(one_more_field)
+ # An extension defined at file scope.
+ factory_test2 = self.pool.FindFileByName(
+ 'google/protobuf/internal/factory_test2.proto')
+ another_field = factory_test2.extensions_by_name['another_field']
+ self.pool.AddExtensionDescriptor(another_field)
+
+ extensions = self.pool.FindAllExtensions(factory1_message)
+ expected_extension_numbers = {one_more_field, another_field}
+ self.assertEqual(expected_extension_numbers, set(extensions))
+ # Verify that mutating the returned list does not affect the pool.
+ extensions.append('unexpected_element')
+ # Get the extensions again, the returned value does not contain the
+ # 'unexpected_element'.
+ extensions = self.pool.FindAllExtensions(factory1_message)
+ self.assertEqual(expected_extension_numbers, set(extensions))
+
+ def testFindExtensionByNumber(self):
+ factory1_message = self.pool.FindMessageTypeByName(
+ 'google.protobuf.python.internal.Factory1Message')
+ factory2_message = self.pool.FindMessageTypeByName(
+ 'google.protobuf.python.internal.Factory2Message')
+ # An extension defined in a message.
+ one_more_field = factory2_message.extensions_by_name['one_more_field']
+ self.pool.AddExtensionDescriptor(one_more_field)
+ # An extension defined at file scope.
+ factory_test2 = self.pool.FindFileByName(
+ 'google/protobuf/internal/factory_test2.proto')
+ another_field = factory_test2.extensions_by_name['another_field']
+ self.pool.AddExtensionDescriptor(another_field)
+
+ # An extension defined in a message.
+ extension = self.pool.FindExtensionByNumber(factory1_message, 1001)
+ self.assertEqual(extension.name, 'one_more_field')
+ # An extension defined at file scope.
+ extension = self.pool.FindExtensionByNumber(factory1_message, 1002)
+ self.assertEqual(extension.name, 'another_field')
+ with self.assertRaises(KeyError):
+ extension = self.pool.FindExtensionByNumber(factory1_message, 1234567)
+
def testExtensionsAreNotFields(self):
with self.assertRaises(KeyError):
self.pool.FindFieldByName('google.protobuf.python.internal.another_field')
diff --git a/python/google/protobuf/internal/message_factory_test.py b/python/google/protobuf/internal/message_factory_test.py
index 7bb7d1ac..4caa2443 100644
--- a/python/google/protobuf/internal/message_factory_test.py
+++ b/python/google/protobuf/internal/message_factory_test.py
@@ -114,18 +114,18 @@ class MessageFactoryTest(unittest.TestCase):
).issubset(set(messages.keys())))
self._ExerciseDynamicClass(
messages['google.protobuf.python.internal.Factory2Message'])
- self.assertTrue(
- set(['google.protobuf.python.internal.Factory2Message.one_more_field',
- 'google.protobuf.python.internal.another_field'],
- ).issubset(
- set(messages['google.protobuf.python.internal.Factory1Message']
- ._extensions_by_name.keys())))
factory_msg1 = messages['google.protobuf.python.internal.Factory1Message']
+ self.assertTrue(set(
+ ['google.protobuf.python.internal.Factory2Message.one_more_field',
+ 'google.protobuf.python.internal.another_field'],).issubset(set(
+ ext.full_name
+ for ext in factory_msg1.DESCRIPTOR.file.pool.FindAllExtensions(
+ factory_msg1.DESCRIPTOR))))
msg1 = messages['google.protobuf.python.internal.Factory1Message']()
- ext1 = factory_msg1._extensions_by_name[
- 'google.protobuf.python.internal.Factory2Message.one_more_field']
- ext2 = factory_msg1._extensions_by_name[
- 'google.protobuf.python.internal.another_field']
+ ext1 = msg1.Extensions._FindExtensionByName(
+ 'google.protobuf.python.internal.Factory2Message.one_more_field')
+ ext2 = msg1.Extensions._FindExtensionByName(
+ 'google.protobuf.python.internal.another_field')
msg1.Extensions[ext1] = 'test1'
msg1.Extensions[ext2] = 'test2'
self.assertEqual('test1', msg1.Extensions[ext1])
diff --git a/python/google/protobuf/internal/python_message.py b/python/google/protobuf/internal/python_message.py
index dc6565d4..c1bd1f9c 100755
--- a/python/google/protobuf/internal/python_message.py
+++ b/python/google/protobuf/internal/python_message.py
@@ -51,8 +51,8 @@ this file*.
__author__ = 'robinson@google.com (Will Robinson)'
from io import BytesIO
-import sys
import struct
+import sys
import weakref
import six
@@ -162,12 +162,10 @@ class GeneratedProtocolMessageType(type):
"""
descriptor = dictionary[GeneratedProtocolMessageType._DESCRIPTOR_KEY]
cls._decoders_by_tag = {}
- cls._extensions_by_name = {}
- cls._extensions_by_number = {}
if (descriptor.has_options and
descriptor.GetOptions().message_set_wire_format):
cls._decoders_by_tag[decoder.MESSAGE_SET_ITEM_TAG] = (
- decoder.MessageSetItemDecoder(cls._extensions_by_number), None)
+ decoder.MessageSetItemDecoder(descriptor), None)
# Attach stuff to each FieldDescriptor for quick lookup later on.
for field in descriptor.fields:
@@ -747,32 +745,21 @@ def _AddPropertiesForExtensions(descriptor, cls):
constant_name = extension_name.upper() + "_FIELD_NUMBER"
setattr(cls, constant_name, extension_field.number)
+ # TODO(amauryfa): Migrate all users of these attributes to functions like
+ # pool.FindExtensionByNumber(descriptor).
+ if descriptor.file is not None:
+ # TODO(amauryfa): Use cls.MESSAGE_FACTORY.pool when available.
+ pool = descriptor.file.pool
+ cls._extensions_by_number = pool._extensions_by_number[descriptor]
+ cls._extensions_by_name = pool._extensions_by_name[descriptor]
def _AddStaticMethods(cls):
# TODO(robinson): This probably needs to be thread-safe(?)
def RegisterExtension(extension_handle):
extension_handle.containing_type = cls.DESCRIPTOR
+ # TODO(amauryfa): Use cls.MESSAGE_FACTORY.pool when available.
+ cls.DESCRIPTOR.file.pool.AddExtensionDescriptor(extension_handle)
_AttachFieldHelpers(cls, extension_handle)
-
- # Try to insert our extension, failing if an extension with the same number
- # already exists.
- actual_handle = cls._extensions_by_number.setdefault(
- extension_handle.number, extension_handle)
- if actual_handle is not extension_handle:
- raise AssertionError(
- 'Extensions "%s" and "%s" both try to extend message type "%s" with '
- 'field number %d.' %
- (extension_handle.full_name, actual_handle.full_name,
- cls.DESCRIPTOR.full_name, extension_handle.number))
-
- cls._extensions_by_name[extension_handle.full_name] = extension_handle
-
- handle = extension_handle # avoid line wrapping
- if _IsMessageSetExtension(handle):
- # MessageSet extension. Also register under type name.
- cls._extensions_by_name[
- extension_handle.message_type.full_name] = extension_handle
-
cls.RegisterExtension = staticmethod(RegisterExtension)
def FromString(s):
@@ -1230,7 +1217,7 @@ def _AddMergeFromMethod(cls):
if not isinstance(msg, cls):
raise TypeError(
"Parameter to MergeFrom() must be instance of same class: "
- "expected %s got %s." % (cls.__name__, type(msg).__name__))
+ 'expected %s got %s.' % (cls.__name__, msg.__class__.__name__))
assert msg is not self
self._Modified()
diff --git a/python/google/protobuf/internal/reflection_test.py b/python/google/protobuf/internal/reflection_test.py
index dad79c37..0e881015 100755
--- a/python/google/protobuf/internal/reflection_test.py
+++ b/python/google/protobuf/internal/reflection_test.py
@@ -99,12 +99,12 @@ class _MiniDecoder(object):
return wire_format.UnpackTag(self.ReadVarint())
def ReadFloat(self):
- result = struct.unpack("<f", self._bytes[self._pos:self._pos+4])[0]
+ result = struct.unpack('<f', self._bytes[self._pos:self._pos+4])[0]
self._pos += 4
return result
def ReadDouble(self):
- result = struct.unpack("<d", self._bytes[self._pos:self._pos+8])[0]
+ result = struct.unpack('<d', self._bytes[self._pos:self._pos+8])[0]
self._pos += 8
return result
@@ -621,9 +621,15 @@ class ReflectionTest(BaseTestCase):
self.assertRaises(TypeError, setattr, proto, 'optional_string', 10)
self.assertRaises(TypeError, setattr, proto, 'optional_bytes', 10)
- def testIntegerTypes(self):
+ def assertIntegerTypes(self, integer_fn):
+ """Verifies setting of scalar integers.
+
+ Args:
+ integer_fn: A function to wrap the integers that will be assigned.
+ """
def TestGetAndDeserialize(field_name, value, expected_type):
proto = unittest_pb2.TestAllTypes()
+ value = integer_fn(value)
setattr(proto, field_name, value)
self.assertIsInstance(getattr(proto, field_name), expected_type)
proto2 = unittest_pb2.TestAllTypes()
@@ -635,7 +641,7 @@ class ReflectionTest(BaseTestCase):
TestGetAndDeserialize('optional_uint32', 1 << 30, int)
try:
integer_64 = long
- except NameError: # Python3
+ except NameError: # Python3
integer_64 = int
if struct.calcsize('L') == 4:
# Python only has signed ints, so 32-bit python can't fit an uint32
@@ -649,9 +655,33 @@ class ReflectionTest(BaseTestCase):
TestGetAndDeserialize('optional_uint64', 1 << 30, integer_64)
TestGetAndDeserialize('optional_uint64', 1 << 60, integer_64)
- def testSingleScalarBoundsChecking(self):
+ def testIntegerTypes(self):
+ self.assertIntegerTypes(lambda x: x)
+
+ def testNonStandardIntegerTypes(self):
+ self.assertIntegerTypes(test_util.NonStandardInteger)
+
+ def testIllegalValuesForIntegers(self):
+ pb = unittest_pb2.TestAllTypes()
+
+ # Strings are illegal, even when the represent an integer.
+ with self.assertRaises(TypeError):
+ pb.optional_uint64 = '2'
+
+ # The exact error should propagate with a poorly written custom integer.
+ with self.assertRaisesRegexp(RuntimeError, 'my_error'):
+ pb.optional_uint64 = test_util.NonStandardInteger(5, 'my_error')
+
+ def assetIntegerBoundsChecking(self, integer_fn):
+ """Verifies bounds checking for scalar integer fields.
+
+ Args:
+ integer_fn: A function to wrap the integers that will be assigned.
+ """
def TestMinAndMaxIntegers(field_name, expected_min, expected_max):
pb = unittest_pb2.TestAllTypes()
+ expected_min = integer_fn(expected_min)
+ expected_max = integer_fn(expected_max)
setattr(pb, field_name, expected_min)
self.assertEqual(expected_min, getattr(pb, field_name))
setattr(pb, field_name, expected_max)
@@ -663,11 +693,22 @@ class ReflectionTest(BaseTestCase):
TestMinAndMaxIntegers('optional_uint32', 0, 0xffffffff)
TestMinAndMaxIntegers('optional_int64', -(1 << 63), (1 << 63) - 1)
TestMinAndMaxIntegers('optional_uint64', 0, 0xffffffffffffffff)
+ # A bit of white-box testing since -1 is an int and not a long in C++ and
+ # so goes down a different path.
+ pb = unittest_pb2.TestAllTypes()
+ with self.assertRaises(ValueError):
+ pb.optional_uint64 = integer_fn(-(1 << 63))
pb = unittest_pb2.TestAllTypes()
- pb.optional_nested_enum = 1
+ pb.optional_nested_enum = integer_fn(1)
self.assertEqual(1, pb.optional_nested_enum)
+ def testSingleScalarBoundsChecking(self):
+ self.assetIntegerBoundsChecking(lambda x: x)
+
+ def testNonStandardSingleScalarBoundsChecking(self):
+ self.assetIntegerBoundsChecking(test_util.NonStandardInteger)
+
def testRepeatedScalarTypeSafety(self):
proto = unittest_pb2.TestAllTypes()
self.assertRaises(TypeError, proto.repeated_int32.append, 1.1)
@@ -1187,12 +1228,18 @@ class ReflectionTest(BaseTestCase):
self.assertTrue(not extendee_proto.HasExtension(extension))
def testRegisteredExtensions(self):
- self.assertTrue('protobuf_unittest.optional_int32_extension' in
- unittest_pb2.TestAllExtensions._extensions_by_name)
- self.assertTrue(1 in unittest_pb2.TestAllExtensions._extensions_by_number)
+ pool = unittest_pb2.DESCRIPTOR.pool
+ self.assertTrue(
+ pool.FindExtensionByNumber(
+ unittest_pb2.TestAllExtensions.DESCRIPTOR, 1))
+ self.assertIs(
+ pool.FindExtensionByName(
+ 'protobuf_unittest.optional_int32_extension').containing_type,
+ unittest_pb2.TestAllExtensions.DESCRIPTOR)
# Make sure extensions haven't been registered into types that shouldn't
# have any.
- self.assertEqual(0, len(unittest_pb2.TestAllTypes._extensions_by_name))
+ self.assertEqual(0, len(
+ pool.FindAllExtensions(unittest_pb2.TestAllTypes.DESCRIPTOR)))
# If message A directly contains message B, and
# a.HasField('b') is currently False, then mutating any
diff --git a/python/google/protobuf/internal/test_util.py b/python/google/protobuf/internal/test_util.py
index 2c805599..269d0e2d 100755
--- a/python/google/protobuf/internal/test_util.py
+++ b/python/google/protobuf/internal/test_util.py
@@ -36,8 +36,9 @@ This is intentionally modeled on C++ code in
__author__ = 'robinson@google.com (Will Robinson)'
+import numbers
+import operator
import os.path
-
import sys
from google.protobuf import unittest_import_pb2
@@ -694,3 +695,154 @@ def SetAllUnpackedFields(message):
message.unpacked_bool.extend([True, False])
message.unpacked_enum.extend([unittest_pb2.FOREIGN_BAR,
unittest_pb2.FOREIGN_BAZ])
+
+
+class NonStandardInteger(numbers.Integral):
+ """An integer object that does not subclass int.
+
+ This is used to verify that both C++ and regular proto systems can handle
+ integer others than int and long and that they handle them in predictable
+ ways.
+
+ NonStandardInteger is the minimal legal specification for a custom Integral.
+ As such, it does not support 0 < x < 5 and it is not hashable.
+
+ Note: This is added here instead of relying on numpy or a similar library with
+ custom integers to limit dependencies.
+ """
+
+ def __init__(self, val, error_string_on_conversion=None):
+ assert isinstance(val, numbers.Integral)
+ if isinstance(val, NonStandardInteger):
+ val = val.val
+ self.val = val
+ self.error_string_on_conversion = error_string_on_conversion
+
+ def __long__(self):
+ if self.error_string_on_conversion:
+ raise RuntimeError(self.error_string_on_conversion)
+ return long(self.val)
+
+ def __abs__(self):
+ return NonStandardInteger(operator.abs(self.val))
+
+ def __add__(self, y):
+ return NonStandardInteger(operator.add(self.val, y))
+
+ def __div__(self, y):
+ return NonStandardInteger(operator.div(self.val, y))
+
+ def __eq__(self, y):
+ return operator.eq(self.val, y)
+
+ def __floordiv__(self, y):
+ return NonStandardInteger(operator.floordiv(self.val, y))
+
+ def __truediv__(self, y):
+ return NonStandardInteger(operator.truediv(self.val, y))
+
+ def __invert__(self):
+ return NonStandardInteger(operator.invert(self.val))
+
+ def __mod__(self, y):
+ return NonStandardInteger(operator.mod(self.val, y))
+
+ def __mul__(self, y):
+ return NonStandardInteger(operator.mul(self.val, y))
+
+ def __neg__(self):
+ return NonStandardInteger(operator.neg(self.val))
+
+ def __pos__(self):
+ return NonStandardInteger(operator.pos(self.val))
+
+ def __pow__(self, y):
+ return NonStandardInteger(operator.pow(self.val, y))
+
+ def __trunc__(self):
+ return int(self.val)
+
+ def __radd__(self, y):
+ return NonStandardInteger(operator.add(y, self.val))
+
+ def __rdiv__(self, y):
+ return NonStandardInteger(operator.div(y, self.val))
+
+ def __rmod__(self, y):
+ return NonStandardInteger(operator.mod(y, self.val))
+
+ def __rmul__(self, y):
+ return NonStandardInteger(operator.mul(y, self.val))
+
+ def __rpow__(self, y):
+ return NonStandardInteger(operator.pow(y, self.val))
+
+ def __rfloordiv__(self, y):
+ return NonStandardInteger(operator.floordiv(y, self.val))
+
+ def __rtruediv__(self, y):
+ return NonStandardInteger(operator.truediv(y, self.val))
+
+ def __lshift__(self, y):
+ return NonStandardInteger(operator.lshift(self.val, y))
+
+ def __rshift__(self, y):
+ return NonStandardInteger(operator.rshift(self.val, y))
+
+ def __rlshift__(self, y):
+ return NonStandardInteger(operator.lshift(y, self.val))
+
+ def __rrshift__(self, y):
+ return NonStandardInteger(operator.rshift(y, self.val))
+
+ def __le__(self, y):
+ if isinstance(y, NonStandardInteger):
+ y = y.val
+ return operator.le(self.val, y)
+
+ def __lt__(self, y):
+ if isinstance(y, NonStandardInteger):
+ y = y.val
+ return operator.lt(self.val, y)
+
+ def __and__(self, y):
+ return NonStandardInteger(operator.and_(self.val, y))
+
+ def __or__(self, y):
+ return NonStandardInteger(operator.or_(self.val, y))
+
+ def __xor__(self, y):
+ return NonStandardInteger(operator.xor(self.val, y))
+
+ def __rand__(self, y):
+ return NonStandardInteger(operator.and_(y, self.val))
+
+ def __ror__(self, y):
+ return NonStandardInteger(operator.or_(y, self.val))
+
+ def __rxor__(self, y):
+ return NonStandardInteger(operator.xor(y, self.val))
+
+ def __bool__(self):
+ return self.val
+
+ def __nonzero__(self):
+ return self.val
+
+ def __ceil__(self):
+ return self
+
+ def __floor__(self):
+ return self
+
+ def __int__(self):
+ if self.error_string_on_conversion:
+ raise RuntimeError(self.error_string_on_conversion)
+ return int(self.val)
+
+ def __round__(self):
+ return self
+
+ def __repr__(self):
+ return 'NonStandardInteger(%s)' % self.val
+
diff --git a/python/google/protobuf/internal/testing_refleaks.py b/python/google/protobuf/internal/testing_refleaks.py
index c461a9f4..8ce06519 100644
--- a/python/google/protobuf/internal/testing_refleaks.py
+++ b/python/google/protobuf/internal/testing_refleaks.py
@@ -124,5 +124,3 @@ else:
def Same(func):
return func
return Same
-
-
diff --git a/python/google/protobuf/internal/text_format_test.py b/python/google/protobuf/internal/text_format_test.py
index ab481ab4..176cbd15 100755
--- a/python/google/protobuf/internal/text_format_test.py
+++ b/python/google/protobuf/internal/text_format_test.py
@@ -582,22 +582,20 @@ class OnlyWorksWithProto2RightNowTests(TextFormatBase):
% (letter,) for letter in string.ascii_uppercase))
self.CompareToGoldenText(text_format.MessageToString(message), golden)
- def testMapOrderSemantics(self):
- golden_lines = self.ReadGolden('map_test_data.txt')
- # The C++ implementation emits defaulted-value fields, while the Python
- # implementation does not. Adjusting for this is awkward, but it is
- # valuable to test against a common golden file.
- line_blacklist = (' key: 0\n', ' value: 0\n', ' key: false\n',
- ' value: false\n')
- golden_lines = [line for line in golden_lines if line not in line_blacklist]
-
- message = map_unittest_pb2.TestMap()
- text_format.ParseLines(golden_lines, message)
- candidate = text_format.MessageToString(message)
- # The Python implementation emits "1.0" for the double value that the C++
- # implementation emits as "1".
- candidate = candidate.replace('1.0', '1', 2)
- self.assertMultiLineEqual(candidate, ''.join(golden_lines))
+ # TODO(teboring): In c/137553523, not serializing default value for map entry
+ # message has been fixed. This test needs to be disabled in order to submit
+ # that cl. Add this back when c/137553523 has been submitted.
+ # def testMapOrderSemantics(self):
+ # golden_lines = self.ReadGolden('map_test_data.txt')
+
+ # message = map_unittest_pb2.TestMap()
+ # text_format.ParseLines(golden_lines, message)
+ # candidate = text_format.MessageToString(message)
+ # # The Python implementation emits "1.0" for the double value that the C++
+ # # implementation emits as "1".
+ # candidate = candidate.replace('1.0', '1', 2)
+ # candidate = candidate.replace('0.0', '0', 2)
+ # self.assertMultiLineEqual(candidate, ''.join(golden_lines))
# Tests of proto2-only features (MessageSet, extensions, etc.).
diff --git a/python/google/protobuf/internal/type_checkers.py b/python/google/protobuf/internal/type_checkers.py
index 1be3ad9a..4a76cd4e 100755
--- a/python/google/protobuf/internal/type_checkers.py
+++ b/python/google/protobuf/internal/type_checkers.py
@@ -45,6 +45,7 @@ TYPE_TO_DESERIALIZE_METHOD: A dictionary with field types and deserialization
__author__ = 'robinson@google.com (Will Robinson)'
+import numbers
import six
if six.PY3:
@@ -126,11 +127,11 @@ class IntValueChecker(object):
"""Checker used for integer fields. Performs type-check and range check."""
def CheckValue(self, proposed_value):
- if not isinstance(proposed_value, six.integer_types):
+ if not isinstance(proposed_value, numbers.Integral):
message = ('%.1024r has type %s, but expected one of: %s' %
(proposed_value, type(proposed_value), six.integer_types))
raise TypeError(message)
- if not self._MIN <= proposed_value <= self._MAX:
+ if not self._MIN <= int(proposed_value) <= self._MAX:
raise ValueError('Value out of range: %d' % proposed_value)
# We force 32-bit values to int and 64-bit values to long to make
# alternate implementations where the distinction is more significant
@@ -150,11 +151,11 @@ class EnumValueChecker(object):
self._enum_type = enum_type
def CheckValue(self, proposed_value):
- if not isinstance(proposed_value, six.integer_types):
+ if not isinstance(proposed_value, numbers.Integral):
message = ('%.1024r has type %s, but expected one of: %s' %
(proposed_value, type(proposed_value), six.integer_types))
raise TypeError(message)
- if proposed_value not in self._enum_type.values_by_number:
+ if int(proposed_value) not in self._enum_type.values_by_number:
raise ValueError('Unknown enum value: %d' % proposed_value)
return proposed_value
@@ -223,11 +224,11 @@ _VALUE_CHECKERS = {
_FieldDescriptor.CPPTYPE_UINT32: Uint32ValueChecker(),
_FieldDescriptor.CPPTYPE_UINT64: Uint64ValueChecker(),
_FieldDescriptor.CPPTYPE_DOUBLE: TypeCheckerWithDefault(
- 0.0, float, int, long),
+ 0.0, numbers.Real),
_FieldDescriptor.CPPTYPE_FLOAT: TypeCheckerWithDefault(
- 0.0, float, int, long),
+ 0.0, numbers.Real),
_FieldDescriptor.CPPTYPE_BOOL: TypeCheckerWithDefault(
- False, bool, int),
+ False, bool, numbers.Integral),
_FieldDescriptor.CPPTYPE_STRING: TypeCheckerWithDefault(b'', bytes),
}