aboutsummaryrefslogtreecommitdiffhomepage
path: root/python
diff options
context:
space:
mode:
authorGravatar Bo Yang <teboring@google.com>2016-09-19 13:45:07 -0700
committerGravatar Bo Yang <teboring@google.com>2016-10-10 11:23:36 -0700
commitcc8ca5b6a5478b40546d4206392eb1471454460d (patch)
treec0b45abfa16d7d373a6ea8f7fe50f1de00ab938e /python
parent337a028bb65ccca4dda768695950b5aba53ae2c9 (diff)
Integrate internal changes
Diffstat (limited to 'python')
-rwxr-xr-xpython/google/protobuf/descriptor.py49
-rw-r--r--python/google/protobuf/descriptor_pool.py138
-rw-r--r--python/google/protobuf/internal/any_test.proto14
-rwxr-xr-xpython/google/protobuf/internal/containers.py21
-rw-r--r--python/google/protobuf/internal/descriptor_pool_test.py5
-rwxr-xr-xpython/google/protobuf/internal/descriptor_test.py19
-rwxr-xr-xpython/google/protobuf/internal/generator_test.py3
-rw-r--r--python/google/protobuf/internal/json_format_test.py78
-rwxr-xr-xpython/google/protobuf/internal/message_test.py45
-rwxr-xr-xpython/google/protobuf/internal/python_message.py6
-rwxr-xr-xpython/google/protobuf/internal/reflection_test.py41
-rw-r--r--python/google/protobuf/internal/testing_refleaks.py124
-rwxr-xr-xpython/google/protobuf/internal/text_format_test.py25
-rwxr-xr-xpython/google/protobuf/internal/unknown_fields_test.py107
-rw-r--r--python/google/protobuf/internal/well_known_types.py75
-rw-r--r--python/google/protobuf/internal/well_known_types_test.py132
-rw-r--r--python/google/protobuf/json_format.py93
-rwxr-xr-xpython/google/protobuf/message.py9
-rw-r--r--python/google/protobuf/message_factory.py11
-rw-r--r--python/google/protobuf/pyext/descriptor.cc22
-rw-r--r--python/google/protobuf/pyext/descriptor_pool.cc59
-rw-r--r--python/google/protobuf/pyext/descriptor_pool.h35
-rw-r--r--python/google/protobuf/pyext/extension_dict.cc78
-rw-r--r--python/google/protobuf/pyext/extension_dict.h43
-rw-r--r--python/google/protobuf/pyext/map_container.cc95
-rw-r--r--python/google/protobuf/pyext/map_container.h12
-rw-r--r--python/google/protobuf/pyext/message.cc212
-rw-r--r--python/google/protobuf/pyext/message.h33
-rw-r--r--python/google/protobuf/pyext/message_factory.cc214
-rw-r--r--python/google/protobuf/pyext/message_factory.h103
-rw-r--r--python/google/protobuf/pyext/repeated_composite_container.cc2
-rw-r--r--python/google/protobuf/symbol_database.py3
-rwxr-xr-xpython/google/protobuf/text_format.py16
33 files changed, 1303 insertions, 619 deletions
diff --git a/python/google/protobuf/descriptor.py b/python/google/protobuf/descriptor.py
index 873af306..e1f2e3b7 100755
--- a/python/google/protobuf/descriptor.py
+++ b/python/google/protobuf/descriptor.py
@@ -171,13 +171,6 @@ class _NestedDescriptorBase(DescriptorBase):
self._serialized_start = serialized_start
self._serialized_end = serialized_end
- def GetTopLevelContainingType(self):
- """Returns the root if this is a nested type, or itself if its the root."""
- desc = self
- while desc.containing_type is not None:
- desc = desc.containing_type
- return desc
-
def CopyToProto(self, proto):
"""Copies this to the matching proto in descriptor_pb2.
@@ -497,7 +490,7 @@ class FieldDescriptor(DescriptorBase):
def __new__(cls, name, full_name, index, number, type, cpp_type, label,
default_value, message_type, enum_type, containing_type,
is_extension, extension_scope, options=None,
- has_default_value=True, containing_oneof=None):
+ has_default_value=True, containing_oneof=None, json_name=None):
_message.Message._CheckCalledFromGeneratedFile()
if is_extension:
return _message.default_pool.FindExtensionByName(full_name)
@@ -507,7 +500,7 @@ class FieldDescriptor(DescriptorBase):
def __init__(self, name, full_name, index, number, type, cpp_type, label,
default_value, message_type, enum_type, containing_type,
is_extension, extension_scope, options=None,
- has_default_value=True, containing_oneof=None):
+ has_default_value=True, containing_oneof=None, json_name=None):
"""The arguments are as described in the description of FieldDescriptor
attributes above.
@@ -519,6 +512,10 @@ class FieldDescriptor(DescriptorBase):
self.name = name
self.full_name = full_name
self._camelcase_name = None
+ if json_name is None:
+ self.json_name = _ToJsonName(name)
+ else:
+ self.json_name = json_name
self.index = index
self.number = number
self.type = type
@@ -894,6 +891,31 @@ def _ToCamelCase(name):
return ''.join(result)
+def _OptionsOrNone(descriptor_proto):
+ """Returns the value of the field `options`, or None if it is not set."""
+ if descriptor_proto.HasField('options'):
+ return descriptor_proto.options
+ else:
+ return None
+
+
+def _ToJsonName(name):
+ """Converts name to Json name and returns it."""
+ capitalize_next = False
+ result = []
+
+ for c in name:
+ if c == '_':
+ capitalize_next = True
+ elif capitalize_next:
+ result.append(c.upper())
+ capitalize_next = False
+ else:
+ result += c
+
+ return ''.join(result)
+
+
def MakeDescriptor(desc_proto, package='', build_file_if_cpp=True,
syntax=None):
"""Make a protobuf Descriptor given a DescriptorProto protobuf.
@@ -970,6 +992,10 @@ def MakeDescriptor(desc_proto, package='', build_file_if_cpp=True,
full_name = '.'.join(full_message_name + [field_proto.name])
enum_desc = None
nested_desc = None
+ if field_proto.json_name:
+ json_name = field_proto.json_name
+ else:
+ json_name = None
if field_proto.HasField('type_name'):
type_name = field_proto.type_name
full_type_name = '.'.join(full_message_name +
@@ -984,10 +1010,11 @@ def MakeDescriptor(desc_proto, package='', build_file_if_cpp=True,
field_proto.number, field_proto.type,
FieldDescriptor.ProtoTypeToCppProtoType(field_proto.type),
field_proto.label, None, nested_desc, enum_desc, None, False, None,
- options=field_proto.options, has_default_value=False)
+ options=_OptionsOrNone(field_proto), has_default_value=False,
+ json_name=json_name)
fields.append(field)
desc_name = '.'.join(full_message_name)
return Descriptor(desc_proto.name, desc_name, None, None, fields,
list(nested_types.values()), list(enum_types.values()), [],
- options=desc_proto.options)
+ options=_OptionsOrNone(desc_proto))
diff --git a/python/google/protobuf/descriptor_pool.py b/python/google/protobuf/descriptor_pool.py
index 5c055ab9..5f43ee5f 100644
--- a/python/google/protobuf/descriptor_pool.py
+++ b/python/google/protobuf/descriptor_pool.py
@@ -62,7 +62,7 @@ from google.protobuf import descriptor_database
from google.protobuf import text_encoding
-_USE_C_DESCRIPTORS = descriptor._USE_C_DESCRIPTORS
+_USE_C_DESCRIPTORS = descriptor._USE_C_DESCRIPTORS # pylint: disable=protected-access
def _NormalizeFullyQualifiedName(name):
@@ -80,6 +80,14 @@ def _NormalizeFullyQualifiedName(name):
return name.lstrip('.')
+def _OptionsOrNone(descriptor_proto):
+ """Returns the value of the field `options`, or None if it is not set."""
+ if descriptor_proto.HasField('options'):
+ return descriptor_proto.options
+ else:
+ return None
+
+
class DescriptorPool(object):
"""A collection of protobufs dynamically constructed by descriptor protos."""
@@ -326,78 +334,61 @@ class DescriptorPool(object):
name=file_proto.name,
package=file_proto.package,
syntax=file_proto.syntax,
- options=file_proto.options,
+ options=_OptionsOrNone(file_proto),
serialized_pb=file_proto.SerializeToString(),
dependencies=direct_deps,
public_dependencies=public_deps)
- if _USE_C_DESCRIPTORS:
- # When using C++ descriptors, all objects defined in the file were added
- # to the C++ database when the FileDescriptor was built above.
- # Just add them to this descriptor pool.
- def _AddMessageDescriptor(message_desc):
- self._descriptors[message_desc.full_name] = message_desc
- for nested in message_desc.nested_types:
- _AddMessageDescriptor(nested)
- for enum_type in message_desc.enum_types:
- _AddEnumDescriptor(enum_type)
- def _AddEnumDescriptor(enum_desc):
- self._enum_descriptors[enum_desc.full_name] = enum_desc
- for message_type in file_descriptor.message_types_by_name.values():
- _AddMessageDescriptor(message_type)
- for enum_type in file_descriptor.enum_types_by_name.values():
- _AddEnumDescriptor(enum_type)
+ scope = {}
+
+ # This loop extracts all the message and enum types from all the
+ # dependencies of the file_proto. This is necessary to create the
+ # scope of available message types when defining the passed in
+ # file proto.
+ for dependency in built_deps:
+ scope.update(self._ExtractSymbols(
+ dependency.message_types_by_name.values()))
+ scope.update((_PrefixWithDot(enum.full_name), enum)
+ for enum in dependency.enum_types_by_name.values())
+
+ for message_type in file_proto.message_type:
+ message_desc = self._ConvertMessageDescriptor(
+ message_type, file_proto.package, file_descriptor, scope,
+ file_proto.syntax)
+ file_descriptor.message_types_by_name[message_desc.name] = (
+ message_desc)
+
+ for enum_type in file_proto.enum_type:
+ file_descriptor.enum_types_by_name[enum_type.name] = (
+ self._ConvertEnumDescriptor(enum_type, file_proto.package,
+ file_descriptor, None, scope))
+
+ for index, extension_proto in enumerate(file_proto.extension):
+ extension_desc = self._MakeFieldDescriptor(
+ extension_proto, file_proto.package, index, is_extension=True)
+ extension_desc.containing_type = self._GetTypeFromScope(
+ file_descriptor.package, extension_proto.extendee, scope)
+ self._SetFieldType(extension_proto, extension_desc,
+ file_descriptor.package, scope)
+ file_descriptor.extensions_by_name[extension_desc.name] = (
+ extension_desc)
+
+ for desc_proto in file_proto.message_type:
+ self._SetAllFieldTypes(file_proto.package, desc_proto, scope)
+
+ if file_proto.package:
+ desc_proto_prefix = _PrefixWithDot(file_proto.package)
else:
- scope = {}
-
- # This loop extracts all the message and enum types from all the
- # dependencies of the file_proto. This is necessary to create the
- # scope of available message types when defining the passed in
- # file proto.
- for dependency in built_deps:
- scope.update(self._ExtractSymbols(
- dependency.message_types_by_name.values()))
- scope.update((_PrefixWithDot(enum.full_name), enum)
- for enum in dependency.enum_types_by_name.values())
-
- for message_type in file_proto.message_type:
- message_desc = self._ConvertMessageDescriptor(
- message_type, file_proto.package, file_descriptor, scope,
- file_proto.syntax)
- file_descriptor.message_types_by_name[message_desc.name] = (
- message_desc)
-
- for enum_type in file_proto.enum_type:
- file_descriptor.enum_types_by_name[enum_type.name] = (
- self._ConvertEnumDescriptor(enum_type, file_proto.package,
- file_descriptor, None, scope))
-
- for index, extension_proto in enumerate(file_proto.extension):
- extension_desc = self._MakeFieldDescriptor(
- extension_proto, file_proto.package, index, is_extension=True)
- extension_desc.containing_type = self._GetTypeFromScope(
- file_descriptor.package, extension_proto.extendee, scope)
- self._SetFieldType(extension_proto, extension_desc,
- file_descriptor.package, scope)
- file_descriptor.extensions_by_name[extension_desc.name] = (
- extension_desc)
-
- for desc_proto in file_proto.message_type:
- self._SetAllFieldTypes(file_proto.package, desc_proto, scope)
-
- if file_proto.package:
- desc_proto_prefix = _PrefixWithDot(file_proto.package)
- else:
- desc_proto_prefix = ''
+ desc_proto_prefix = ''
- for desc_proto in file_proto.message_type:
- desc = self._GetTypeFromScope(
- desc_proto_prefix, desc_proto.name, scope)
- file_descriptor.message_types_by_name[desc_proto.name] = desc
+ for desc_proto in file_proto.message_type:
+ desc = self._GetTypeFromScope(
+ desc_proto_prefix, desc_proto.name, scope)
+ file_descriptor.message_types_by_name[desc_proto.name] = desc
- for index, service_proto in enumerate(file_proto.service):
- file_descriptor.services_by_name[service_proto.name] = (
- self._MakeServiceDescriptor(service_proto, index, scope,
- file_proto.package, file_descriptor))
+ for index, service_proto in enumerate(file_proto.service):
+ file_descriptor.services_by_name[service_proto.name] = (
+ self._MakeServiceDescriptor(service_proto, index, scope,
+ file_proto.package, file_descriptor))
self.Add(file_proto)
self._file_descriptors[file_proto.name] = file_descriptor
@@ -413,6 +404,7 @@ class DescriptorPool(object):
package: The package the proto should be located in.
file_desc: The file containing this message.
scope: Dict mapping short and full symbols to message and enum types.
+ syntax: string indicating syntax of the file ("proto2" or "proto3")
Returns:
The added descriptor.
@@ -463,7 +455,7 @@ class DescriptorPool(object):
nested_types=nested,
enum_types=enums,
extensions=extensions,
- options=desc_proto.options,
+ options=_OptionsOrNone(desc_proto),
is_extendable=is_extendable,
extension_ranges=extension_ranges,
file=file_desc,
@@ -517,7 +509,7 @@ class DescriptorPool(object):
file=file_desc,
values=values,
containing_type=containing_type,
- options=enum_proto.options)
+ options=_OptionsOrNone(enum_proto))
scope['.%s' % enum_name] = desc
self._enum_descriptors[enum_name] = desc
return desc
@@ -562,7 +554,7 @@ class DescriptorPool(object):
default_value=None,
is_extension=is_extension,
extension_scope=None,
- options=field_proto.options)
+ options=_OptionsOrNone(field_proto))
def _SetAllFieldTypes(self, package, desc_proto, scope):
"""Sets all the descriptor's fields's types.
@@ -681,7 +673,7 @@ class DescriptorPool(object):
name=value_proto.name,
index=index,
number=value_proto.number,
- options=value_proto.options,
+ options=_OptionsOrNone(value_proto),
type=None)
def _MakeServiceDescriptor(self, service_proto, service_index, scope,
@@ -711,7 +703,7 @@ class DescriptorPool(object):
full_name=service_name,
index=service_index,
methods=methods,
- options=service_proto.options,
+ options=_OptionsOrNone(service_proto),
file=file_desc)
return desc
@@ -740,7 +732,7 @@ class DescriptorPool(object):
containing_service=None,
input_type=input_type,
output_type=output_type,
- options=method_proto.options)
+ options=_OptionsOrNone(method_proto))
def _ExtractSymbols(self, descriptors):
"""Pulls out all the symbols from descriptor protos.
diff --git a/python/google/protobuf/internal/any_test.proto b/python/google/protobuf/internal/any_test.proto
index cd641ca0..76a7ebd6 100644
--- a/python/google/protobuf/internal/any_test.proto
+++ b/python/google/protobuf/internal/any_test.proto
@@ -30,13 +30,21 @@
// Author: jieluo@google.com (Jie Luo)
-syntax = "proto3";
+syntax = "proto2";
package google.protobuf.internal;
import "google/protobuf/any.proto";
message TestAny {
- google.protobuf.Any value = 1;
- int32 int_value = 2;
+ optional google.protobuf.Any value = 1;
+ optional int32 int_value = 2;
+ extensions 10 to max;
+}
+
+message TestAnyExtension1 {
+ extend TestAny {
+ optional TestAnyExtension1 extension1 = 98418603;
+ }
+ optional int32 i = 15;
}
diff --git a/python/google/protobuf/internal/containers.py b/python/google/protobuf/internal/containers.py
index ce46d08c..de13018e 100755
--- a/python/google/protobuf/internal/containers.py
+++ b/python/google/protobuf/internal/containers.py
@@ -436,9 +436,11 @@ class ScalarMap(MutableMapping):
"""Simple, type-checked, dict-like container for holding repeated scalars."""
# Disallows assignment to other attributes.
- __slots__ = ['_key_checker', '_value_checker', '_values', '_message_listener']
+ __slots__ = ['_key_checker', '_value_checker', '_values', '_message_listener',
+ '_entry_descriptor']
- def __init__(self, message_listener, key_checker, value_checker):
+ def __init__(self, message_listener, key_checker, value_checker,
+ entry_descriptor):
"""
Args:
message_listener: A MessageListener implementation.
@@ -448,10 +450,12 @@ class ScalarMap(MutableMapping):
inserted into this container.
value_checker: A type_checkers.ValueChecker instance to run on values
inserted into this container.
+ entry_descriptor: The MessageDescriptor of a map entry: key and value.
"""
self._message_listener = message_listener
self._key_checker = key_checker
self._value_checker = value_checker
+ self._entry_descriptor = entry_descriptor
self._values = {}
def __getitem__(self, key):
@@ -513,6 +517,9 @@ class ScalarMap(MutableMapping):
self._values.clear()
self._message_listener.Modified()
+ def GetEntryClass(self):
+ return self._entry_descriptor._concrete_class
+
class MessageMap(MutableMapping):
@@ -520,9 +527,10 @@ class MessageMap(MutableMapping):
# Disallows assignment to other attributes.
__slots__ = ['_key_checker', '_values', '_message_listener',
- '_message_descriptor']
+ '_message_descriptor', '_entry_descriptor']
- def __init__(self, message_listener, message_descriptor, key_checker):
+ def __init__(self, message_listener, message_descriptor, key_checker,
+ entry_descriptor):
"""
Args:
message_listener: A MessageListener implementation.
@@ -532,10 +540,12 @@ class MessageMap(MutableMapping):
inserted into this container.
value_checker: A type_checkers.ValueChecker instance to run on values
inserted into this container.
+ entry_descriptor: The MessageDescriptor of a map entry: key and value.
"""
self._message_listener = message_listener
self._message_descriptor = message_descriptor
self._key_checker = key_checker
+ self._entry_descriptor = entry_descriptor
self._values = {}
def __getitem__(self, key):
@@ -613,3 +623,6 @@ class MessageMap(MutableMapping):
def clear(self):
self._values.clear()
self._message_listener.Modified()
+
+ def GetEntryClass(self):
+ return self._entry_descriptor._concrete_class
diff --git a/python/google/protobuf/internal/descriptor_pool_test.py b/python/google/protobuf/internal/descriptor_pool_test.py
index 3c8c7935..d4de2d81 100644
--- a/python/google/protobuf/internal/descriptor_pool_test.py
+++ b/python/google/protobuf/internal/descriptor_pool_test.py
@@ -119,6 +119,7 @@ class DescriptorPoolTest(unittest.TestCase):
self.assertEqual('google.protobuf.python.internal.Factory1Message',
msg1.full_name)
self.assertEqual(None, msg1.containing_type)
+ self.assertFalse(msg1.has_options)
nested_msg1 = msg1.nested_types[0]
self.assertEqual('NestedFactory1Message', nested_msg1.name)
@@ -202,6 +203,7 @@ class DescriptorPoolTest(unittest.TestCase):
self.assertIsInstance(enum1, descriptor.EnumDescriptor)
self.assertEqual(0, enum1.values_by_name['FACTORY_1_VALUE_0'].number)
self.assertEqual(1, enum1.values_by_name['FACTORY_1_VALUE_1'].number)
+ self.assertFalse(enum1.has_options)
nested_enum1 = self.pool.FindEnumTypeByName(
'google.protobuf.python.internal.Factory1Message.NestedFactory1Enum')
@@ -234,6 +236,8 @@ class DescriptorPoolTest(unittest.TestCase):
'google.protobuf.python.internal.Factory1Message.list_value')
self.assertEqual(field.name, 'list_value')
self.assertEqual(field.label, field.LABEL_REPEATED)
+ self.assertFalse(field.has_options)
+
with self.assertRaises(KeyError):
self.pool.FindFieldByName('Does not exist')
@@ -448,6 +452,7 @@ class EnumField(object):
test.assertTrue(field_desc.has_default_value)
test.assertEqual(enum_desc.values_by_name[self.default_value].number,
field_desc.default_value)
+ test.assertFalse(enum_desc.values_by_name[self.default_value].has_options)
test.assertEqual(msg_desc, field_desc.containing_type)
test.assertEqual(enum_desc, field_desc.enum_type)
diff --git a/python/google/protobuf/internal/descriptor_test.py b/python/google/protobuf/internal/descriptor_test.py
index 623198c8..1f148ab9 100755
--- a/python/google/protobuf/internal/descriptor_test.py
+++ b/python/google/protobuf/internal/descriptor_test.py
@@ -766,6 +766,8 @@ class MakeDescriptorTest(unittest.TestCase):
'Foo2.Sub.bar_field')
self.assertEqual(result.nested_types[0].fields[0].enum_type,
result.nested_types[0].enum_types[0])
+ self.assertFalse(result.has_options)
+ self.assertFalse(result.fields[0].has_options)
def testMakeDescriptorWithUnsignedIntField(self):
file_descriptor_proto = descriptor_pb2.FileDescriptorProto()
@@ -818,6 +820,23 @@ class MakeDescriptorTest(unittest.TestCase):
self.assertEqual(result.fields[index].camelcase_name,
camelcase_names[index])
+ def testJsonName(self):
+ descriptor_proto = descriptor_pb2.DescriptorProto()
+ descriptor_proto.name = 'TestJsonName'
+ names = ['field_name', 'fieldName', 'FieldName',
+ '_field_name', 'FIELD_NAME', 'json_name']
+ json_names = ['fieldName', 'fieldName', 'FieldName',
+ 'FieldName', 'FIELDNAME', '@type']
+ for index in range(len(names)):
+ field = descriptor_proto.field.add()
+ field.number = index + 1
+ field.name = names[index]
+ field.json_name = '@type'
+ result = descriptor.MakeDescriptor(descriptor_proto)
+ for index in range(len(json_names)):
+ self.assertEqual(result.fields[index].json_name,
+ json_names[index])
+
if __name__ == '__main__':
unittest.main()
diff --git a/python/google/protobuf/internal/generator_test.py b/python/google/protobuf/internal/generator_test.py
index 83ea5f50..7f13f9da 100755
--- a/python/google/protobuf/internal/generator_test.py
+++ b/python/google/protobuf/internal/generator_test.py
@@ -227,7 +227,8 @@ class GeneratorTest(unittest.TestCase):
[unittest_import_pb2.DESCRIPTOR])
self.assertEqual(unittest_import_pb2.DESCRIPTOR.dependencies,
[unittest_import_public_pb2.DESCRIPTOR])
-
+ self.assertEqual(unittest_import_pb2.DESCRIPTOR.public_dependencies,
+ [unittest_import_public_pb2.DESCRIPTOR])
def testNoGenericServices(self):
self.assertTrue(hasattr(unittest_no_generic_services_pb2, "TestMessage"))
self.assertTrue(hasattr(unittest_no_generic_services_pb2, "FOO"))
diff --git a/python/google/protobuf/internal/json_format_test.py b/python/google/protobuf/internal/json_format_test.py
index a5ee8ace..5ed65622 100644
--- a/python/google/protobuf/internal/json_format_test.py
+++ b/python/google/protobuf/internal/json_format_test.py
@@ -205,6 +205,15 @@ class JsonFormatTest(JsonFormatBase):
parsed_message = json_format_proto3_pb2.TestMessage()
self.CheckParseBack(message, parsed_message)
+ def testIntegersRepresentedAsFloat(self):
+ message = json_format_proto3_pb2.TestMessage()
+ json_format.Parse('{"int32Value": -2.147483648e9}', message)
+ self.assertEqual(message.int32_value, -2147483648)
+ json_format.Parse('{"int32Value": 1e5}', message)
+ self.assertEqual(message.int32_value, 100000)
+ json_format.Parse('{"int32Value": 1.0}', message)
+ self.assertEqual(message.int32_value, 1)
+
def testMapFields(self):
message = json_format_proto3_pb2.TestMap()
message.bool_map[True] = 1
@@ -428,6 +437,9 @@ class JsonFormatTest(JsonFormatBase):
' "value": "hello",'
' "repeatedValue": [11.1, false, null, null]'
'}'))
+ message.Clear()
+ json_format.Parse('{"value": null}', message)
+ self.assertEqual(message.value.WhichOneof('kind'), 'null_value')
def testListValueMessage(self):
message = json_format_proto3_pb2.TestListValue()
@@ -600,6 +612,11 @@ class JsonFormatTest(JsonFormatBase):
'}',
parsed_message)
self.assertEqual(message, parsed_message)
+ # Null and {} should have different behavior for sub message.
+ self.assertFalse(parsed_message.HasField('message_value'))
+ json_format.Parse('{"messageValue": {}}', parsed_message)
+ self.assertTrue(parsed_message.HasField('message_value'))
+ # Null is not allowed to be used as an element in repeated field.
self.assertRaisesRegexp(
json_format.ParseError,
'Failed to parse repeatedInt32Value field: '
@@ -621,15 +638,16 @@ class JsonFormatTest(JsonFormatBase):
self.CheckError('',
r'Failed to load JSON: (Expecting value)|(No JSON).')
- def testParseBadEnumValue(self):
- self.CheckError(
- '{"enumValue": 1}',
- 'Enum value must be a string literal with double quotes. '
- 'Type "proto3.EnumType" has no value named 1.')
+ def testParseEnumValue(self):
+ message = json_format_proto3_pb2.TestMessage()
+ text = '{"enumValue": 0}'
+ json_format.Parse(text, message)
+ text = '{"enumValue": 1}'
+ json_format.Parse(text, message)
self.CheckError(
'{"enumValue": "baz"}',
- 'Enum value must be a string literal with double quotes. '
- 'Type "proto3.EnumType" has no value named baz.')
+ 'Failed to parse enumValue field: Invalid enum value baz '
+ 'for enum type proto3.EnumType.')
def testParseBadIdentifer(self):
self.CheckError('{int32Value: 1}',
@@ -672,12 +690,12 @@ class JsonFormatTest(JsonFormatBase):
text = '{"int32Value": 0x12345}'
self.assertRaises(json_format.ParseError,
json_format.Parse, text, message)
+ self.CheckError('{"int32Value": 1.5}',
+ 'Failed to parse int32Value field: '
+ 'Couldn\'t parse integer: 1.5.')
self.CheckError('{"int32Value": 012345}',
(r'Failed to load JSON: Expecting \'?,\'? delimiter: '
r'line 1.'))
- self.CheckError('{"int32Value": 1.0}',
- 'Failed to parse int32Value field: '
- 'Couldn\'t parse integer: 1.0.')
self.CheckError('{"int32Value": " 1 "}',
'Failed to parse int32Value field: '
'Couldn\'t parse integer: " 1 ".')
@@ -687,9 +705,6 @@ class JsonFormatTest(JsonFormatBase):
self.CheckError('{"int32Value": 12345678901234567890}',
'Failed to parse int32Value field: Value out of range: '
'12345678901234567890.')
- self.CheckError('{"int32Value": 1e5}',
- 'Failed to parse int32Value field: '
- 'Couldn\'t parse integer: 100000.0.')
self.CheckError('{"uint32Value": -1}',
'Failed to parse uint32Value field: '
'Value out of range: -1.')
@@ -810,6 +825,43 @@ class JsonFormatTest(JsonFormatBase):
r'"value": 1234}')
json_format.Parse(text, message)
+ def testPreservingProtoFieldNames(self):
+ message = json_format_proto3_pb2.TestMessage()
+ message.int32_value = 12345
+ self.assertEqual('{\n "int32Value": 12345\n}',
+ json_format.MessageToJson(message))
+ self.assertEqual('{\n "int32_value": 12345\n}',
+ json_format.MessageToJson(message, False, True))
+
+ # Parsers accept both original proto field names and lowerCamelCase names.
+ message = json_format_proto3_pb2.TestMessage()
+ json_format.Parse('{"int32Value": 54321}', message)
+ self.assertEqual(54321, message.int32_value)
+ json_format.Parse('{"int32_value": 12345}', message)
+ self.assertEqual(12345, message.int32_value)
+
+ def testParseDict(self):
+ expected = 12345
+ js_dict = {'int32Value': expected}
+ message = json_format_proto3_pb2.TestMessage()
+ json_format.ParseDict(js_dict, message)
+ self.assertEqual(expected, message.int32_value)
+
+ def testMessageToDict(self):
+ message = json_format_proto3_pb2.TestMessage()
+ message.int32_value = 12345
+ expected = {'int32Value': 12345}
+ self.assertEqual(expected,
+ json_format.MessageToDict(message))
+
+ def testJsonName(self):
+ message = json_format_proto3_pb2.TestCustomJsonName()
+ message.value = 12345
+ self.assertEqual('{\n "@value": 12345\n}',
+ json_format.MessageToJson(message))
+ parsed_message = json_format_proto3_pb2.TestCustomJsonName()
+ self.CheckParseBack(message, parsed_message)
+
if __name__ == '__main__':
unittest.main()
diff --git a/python/google/protobuf/internal/message_test.py b/python/google/protobuf/internal/message_test.py
index 1e95adf9..9986c0d9 100755
--- a/python/google/protobuf/internal/message_test.py
+++ b/python/google/protobuf/internal/message_test.py
@@ -67,6 +67,7 @@ from google.protobuf import text_format
from google.protobuf.internal import api_implementation
from google.protobuf.internal import packed_field_test_pb2
from google.protobuf.internal import test_util
+from google.protobuf.internal import testing_refleaks
from google.protobuf import message
from google.protobuf.internal import _parameterized
@@ -88,10 +89,13 @@ def IsNegInf(val):
return isinf(val) and (val < 0)
-@_parameterized.Parameters(
- (unittest_pb2),
- (unittest_proto3_arena_pb2))
-class MessageTest(unittest.TestCase):
+BaseTestCase = testing_refleaks.BaseTestCase
+
+
+@_parameterized.NamedParameters(
+ ('_proto2', unittest_pb2),
+ ('_proto3', unittest_proto3_arena_pb2))
+class MessageTest(BaseTestCase):
def testBadUtf8String(self, message_module):
if api_implementation.Type() != 'python':
@@ -957,7 +961,7 @@ class MessageTest(unittest.TestCase):
# Class to test proto2-only features (required, extensions, etc.)
-class Proto2Test(unittest.TestCase):
+class Proto2Test(BaseTestCase):
def testFieldPresence(self):
message = unittest_pb2.TestAllTypes()
@@ -1113,6 +1117,7 @@ class Proto2Test(unittest.TestCase):
optional_bytes=b'x',
optionalgroup={'a': 400},
optional_nested_message={'bb': 500},
+ optional_foreign_message={},
optional_nested_enum='BAZ',
repeatedgroup=[{'a': 600},
{'a': 700}],
@@ -1125,8 +1130,12 @@ class Proto2Test(unittest.TestCase):
self.assertEqual(300.5, message.optional_float)
self.assertEqual(b'x', message.optional_bytes)
self.assertEqual(400, message.optionalgroup.a)
- self.assertIsInstance(message.optional_nested_message, unittest_pb2.TestAllTypes.NestedMessage)
+ self.assertIsInstance(message.optional_nested_message,
+ unittest_pb2.TestAllTypes.NestedMessage)
self.assertEqual(500, message.optional_nested_message.bb)
+ self.assertTrue(message.HasField('optional_foreign_message'))
+ self.assertEqual(message.optional_foreign_message,
+ unittest_pb2.ForeignMessage())
self.assertEqual(unittest_pb2.TestAllTypes.BAZ,
message.optional_nested_enum)
self.assertEqual(2, len(message.repeatedgroup))
@@ -1164,7 +1173,7 @@ class Proto2Test(unittest.TestCase):
# Class to test proto3-only features/behavior (updated field presence & enums)
-class Proto3Test(unittest.TestCase):
+class Proto3Test(BaseTestCase):
# Utility method for comparing equality with a map.
def assertMapIterEquals(self, map_iter, dict_value):
@@ -1720,7 +1729,7 @@ class Proto3Test(unittest.TestCase):
-class ValidTypeNamesTest(unittest.TestCase):
+class ValidTypeNamesTest(BaseTestCase):
def assertImportFromName(self, msg, base_name):
# Parse <type 'module.class_name'> to extra 'some.name' as a string.
@@ -1741,7 +1750,7 @@ class ValidTypeNamesTest(unittest.TestCase):
self.assertImportFromName(pb.repeated_int32, 'Scalar')
self.assertImportFromName(pb.repeated_nested_message, 'Composite')
-class PackedFieldTest(unittest.TestCase):
+class PackedFieldTest(BaseTestCase):
def setMessage(self, message):
message.repeated_int32.append(1)
@@ -1800,10 +1809,14 @@ class PackedFieldTest(unittest.TestCase):
@unittest.skipIf(api_implementation.Type() != 'cpp',
'explicit tests of the C++ implementation')
-class OversizeProtosTest(unittest.TestCase):
-
- def setUp(self):
- self.file_desc = """
+class OversizeProtosTest(BaseTestCase):
+
+ @classmethod
+ def setUpClass(cls):
+ # At the moment, reference cycles between DescriptorPool and Message classes
+ # are not detected and these objects are never freed.
+ # To avoid errors with ReferenceLeakChecker, we create the class only once.
+ file_desc = """
name: "f/f.msg2"
package: "f"
message_type {
@@ -1828,10 +1841,12 @@ class OversizeProtosTest(unittest.TestCase):
"""
pool = descriptor_pool.DescriptorPool()
desc = descriptor_pb2.FileDescriptorProto()
- text_format.Parse(self.file_desc, desc)
+ text_format.Parse(file_desc, desc)
pool.Add(desc)
- self.proto_cls = message_factory.MessageFactory(pool).GetPrototype(
+ cls.proto_cls = message_factory.MessageFactory(pool).GetPrototype(
pool.FindMessageTypeByName('f.msg2'))
+
+ def setUp(self):
self.p = self.proto_cls()
self.p.field.payload = 'c' * (1024 * 1024 * 64 + 1)
self.p_serialized = self.p.SerializeToString()
diff --git a/python/google/protobuf/internal/python_message.py b/python/google/protobuf/internal/python_message.py
index c0d0ad45..60b4baad 100755
--- a/python/google/protobuf/internal/python_message.py
+++ b/python/google/protobuf/internal/python_message.py
@@ -380,13 +380,15 @@ def _GetInitializeDefaultForMap(field):
if _IsMessageMapField(field):
def MakeMessageMapDefault(message):
return containers.MessageMap(
- message._listener_for_children, value_field.message_type, key_checker)
+ message._listener_for_children, value_field.message_type, key_checker,
+ field.message_type)
return MakeMessageMapDefault
else:
value_checker = type_checkers.GetTypeChecker(value_field)
def MakePrimitiveMapDefault(message):
return containers.ScalarMap(
- message._listener_for_children, key_checker, value_checker)
+ message._listener_for_children, key_checker, value_checker,
+ field.message_type)
return MakePrimitiveMapDefault
def _DefaultValueConstructorForField(field):
diff --git a/python/google/protobuf/internal/reflection_test.py b/python/google/protobuf/internal/reflection_test.py
index 6f3b818a..dad79c37 100755
--- a/python/google/protobuf/internal/reflection_test.py
+++ b/python/google/protobuf/internal/reflection_test.py
@@ -60,9 +60,13 @@ from google.protobuf.internal import more_messages_pb2
from google.protobuf.internal import message_set_extensions_pb2
from google.protobuf.internal import wire_format
from google.protobuf.internal import test_util
+from google.protobuf.internal import testing_refleaks
from google.protobuf.internal import decoder
+BaseTestCase = testing_refleaks.BaseTestCase
+
+
class _MiniDecoder(object):
"""Decodes a stream of values from a string.
@@ -108,7 +112,7 @@ class _MiniDecoder(object):
return self._pos == len(self._bytes)
-class ReflectionTest(unittest.TestCase):
+class ReflectionTest(BaseTestCase):
def assertListsEqual(self, values, others):
self.assertEqual(len(values), len(others))
@@ -1552,6 +1556,20 @@ class ReflectionTest(unittest.TestCase):
self.assertFalse(proto.HasField('optional_foreign_message'))
self.assertEqual(0, proto.optional_foreign_message.c)
+ def testDisconnectingInOneof(self):
+ m = unittest_pb2.TestOneof2() # This message has two messages in a oneof.
+ m.foo_message.qux_int = 5
+ sub_message = m.foo_message
+ # Accessing another message's field does not clear the first one
+ self.assertEqual(m.foo_lazy_message.qux_int, 0)
+ self.assertEqual(m.foo_message.qux_int, 5)
+ # But mutating another message in the oneof detaches the first one.
+ m.foo_lazy_message.qux_int = 6
+ self.assertEqual(m.foo_message.qux_int, 0)
+ # The reference we got above was detached and is still valid.
+ self.assertEqual(sub_message.qux_int, 5)
+ sub_message.qux_int = 7
+
def testOneOf(self):
proto = unittest_pb2.TestAllTypes()
proto.oneof_uint32 = 10
@@ -1810,7 +1828,7 @@ class ReflectionTest(unittest.TestCase):
# into separate TestCase classes.
-class TestAllTypesEqualityTest(unittest.TestCase):
+class TestAllTypesEqualityTest(BaseTestCase):
def setUp(self):
self.first_proto = unittest_pb2.TestAllTypes()
@@ -1826,7 +1844,7 @@ class TestAllTypesEqualityTest(unittest.TestCase):
self.assertEqual(self.first_proto, self.second_proto)
-class FullProtosEqualityTest(unittest.TestCase):
+class FullProtosEqualityTest(BaseTestCase):
"""Equality tests using completely-full protos as a starting point."""
@@ -1912,7 +1930,7 @@ class FullProtosEqualityTest(unittest.TestCase):
self.assertEqual(self.first_proto, self.second_proto)
-class ExtensionEqualityTest(unittest.TestCase):
+class ExtensionEqualityTest(BaseTestCase):
def testExtensionEquality(self):
first_proto = unittest_pb2.TestAllExtensions()
@@ -1945,7 +1963,7 @@ class ExtensionEqualityTest(unittest.TestCase):
self.assertEqual(first_proto, second_proto)
-class MutualRecursionEqualityTest(unittest.TestCase):
+class MutualRecursionEqualityTest(BaseTestCase):
def testEqualityWithMutualRecursion(self):
first_proto = unittest_pb2.TestMutualRecursionA()
@@ -1957,7 +1975,7 @@ class MutualRecursionEqualityTest(unittest.TestCase):
self.assertEqual(first_proto, second_proto)
-class ByteSizeTest(unittest.TestCase):
+class ByteSizeTest(BaseTestCase):
def setUp(self):
self.proto = unittest_pb2.TestAllTypes()
@@ -2253,7 +2271,7 @@ class ByteSizeTest(unittest.TestCase):
# * Handling of empty submessages (with and without "has"
# bits set).
-class SerializationTest(unittest.TestCase):
+class SerializationTest(BaseTestCase):
def testSerializeEmtpyMessage(self):
first_proto = unittest_pb2.TestAllTypes()
@@ -2814,7 +2832,7 @@ class SerializationTest(unittest.TestCase):
self.assertEqual(3, proto.repeated_int32[2])
-class OptionsTest(unittest.TestCase):
+class OptionsTest(BaseTestCase):
def testMessageOptions(self):
proto = message_set_extensions_pb2.TestMessageSet()
@@ -2841,7 +2859,7 @@ class OptionsTest(unittest.TestCase):
-class ClassAPITest(unittest.TestCase):
+class ClassAPITest(BaseTestCase):
@unittest.skipIf(
api_implementation.Type() == 'cpp' and api_implementation.Version() == 2,
@@ -2924,6 +2942,9 @@ class ClassAPITest(unittest.TestCase):
text_format.Merge(file_descriptor_str, file_descriptor)
return file_descriptor.SerializeToString()
+ @testing_refleaks.SkipReferenceLeakChecker('MakeDescriptor is not repeatable')
+ # This test can only run once; the second time, it raises errors about
+ # conflicting message descriptors.
def testParsingFlatClassWithExplicitClassDeclaration(self):
"""Test that the generated class can parse a flat message."""
# TODO(xiaofeng): This test fails with cpp implemetnation in the call
@@ -2948,6 +2969,7 @@ class ClassAPITest(unittest.TestCase):
text_format.Merge(msg_str, msg)
self.assertEqual(msg.flat, [0, 1, 2])
+ @testing_refleaks.SkipReferenceLeakChecker('MakeDescriptor is not repeatable')
def testParsingFlatClass(self):
"""Test that the generated class can parse a flat message."""
file_descriptor = descriptor_pb2.FileDescriptorProto()
@@ -2963,6 +2985,7 @@ class ClassAPITest(unittest.TestCase):
text_format.Merge(msg_str, msg)
self.assertEqual(msg.flat, [0, 1, 2])
+ @testing_refleaks.SkipReferenceLeakChecker('MakeDescriptor is not repeatable')
def testParsingNestedClass(self):
"""Test that the generated class can parse a nested message."""
file_descriptor = descriptor_pb2.FileDescriptorProto()
diff --git a/python/google/protobuf/internal/testing_refleaks.py b/python/google/protobuf/internal/testing_refleaks.py
new file mode 100644
index 00000000..b2787901
--- /dev/null
+++ b/python/google/protobuf/internal/testing_refleaks.py
@@ -0,0 +1,124 @@
+# Protocol Buffers - Google's data interchange format
+# Copyright 2008 Google Inc. All rights reserved.
+# https://developers.google.com/protocol-buffers/
+#
+# Redistribution and use in source and binary forms, with or without
+# modification, are permitted provided that the following conditions are
+# met:
+#
+# * Redistributions of source code must retain the above copyright
+# notice, this list of conditions and the following disclaimer.
+# * Redistributions in binary form must reproduce the above
+# copyright notice, this list of conditions and the following disclaimer
+# in the documentation and/or other materials provided with the
+# distribution.
+# * Neither the name of Google Inc. nor the names of its
+# contributors may be used to endorse or promote products derived from
+# this software without specific prior written permission.
+#
+# THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS
+# "AS IS" AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT
+# LIMITED TO, THE IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR
+# A PARTICULAR PURPOSE ARE DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT
+# OWNER OR CONTRIBUTORS BE LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL,
+# SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT
+# LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE,
+# DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY
+# THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT
+# (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
+# OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
+
+"""A subclass of unittest.TestCase which checks for reference leaks.
+
+To use:
+- Use testing_refleak.BaseTestCase instead of unittest.TestCase
+- Configure and compile Python with --with-pydebug
+
+If sys.gettotalrefcount() is not available (because Python was built without
+the Py_DEBUG option), then this module is a no-op and tests will run normally.
+"""
+
+import copy_reg
+import gc
+import sys
+
+try:
+ import unittest2 as unittest #PY26
+except ImportError:
+ import unittest
+
+
+class LocalTestResult(unittest.TestResult):
+ """A TestResult which forwards events to a parent object, except for Skips."""
+
+ def __init__(self, parent_result):
+ unittest.TestResult.__init__(self)
+ self.parent_result = parent_result
+
+ def addError(self, test, error):
+ self.parent_result.addError(test, error)
+
+ def addFailure(self, test, error):
+ self.parent_result.addFailure(test, error)
+
+ def addSkip(self, test, reason):
+ pass
+
+
+class ReferenceLeakCheckerTestCase(unittest.TestCase):
+ """A TestCase which runs tests multiple times, collecting reference counts."""
+
+ NB_RUNS = 3
+
+ def run(self, result=None):
+ # python_message.py registers all Message classes to some pickle global
+ # registry, which makes the classes immortal.
+ # We save a copy of this registry, and reset it before we could references.
+ self._saved_pickle_registry = copy_reg.dispatch_table.copy()
+
+ # Run the test twice, to warm up the instance attributes.
+ super(ReferenceLeakCheckerTestCase, self).run(result=result)
+ super(ReferenceLeakCheckerTestCase, self).run(result=result)
+
+ oldrefcount = 0
+ local_result = LocalTestResult(result)
+
+ refcount_deltas = []
+ for _ in range(self.NB_RUNS):
+ oldrefcount = self._getRefcounts()
+ super(ReferenceLeakCheckerTestCase, self).run(result=local_result)
+ newrefcount = self._getRefcounts()
+ refcount_deltas.append(newrefcount - oldrefcount)
+ print refcount_deltas, self
+
+ try:
+ self.assertEqual(refcount_deltas, [0] * self.NB_RUNS)
+ except Exception: # pylint: disable=broad-except
+ result.addError(self, sys.exc_info())
+
+ def _getRefcounts(self):
+ copy_reg.dispatch_table.clear()
+ copy_reg.dispatch_table.update(self._saved_pickle_registry)
+ # It is sometimes necessary to gc.collect() multiple times, to ensure
+ # that all objects can be collected.
+ gc.collect()
+ gc.collect()
+ gc.collect()
+ return sys.gettotalrefcount()
+
+
+if hasattr(sys, 'gettotalrefcount'):
+ BaseTestCase = ReferenceLeakCheckerTestCase
+ SkipReferenceLeakChecker = unittest.skip
+
+else:
+ # When PyDEBUG is not enabled, run the tests normally.
+ BaseTestCase = unittest.TestCase
+
+ def SkipReferenceLeakChecker(reason):
+ del reason # Don't skip, so don't need a reason.
+ def Same(func):
+ return func
+ return Same
+
+
diff --git a/python/google/protobuf/internal/text_format_test.py b/python/google/protobuf/internal/text_format_test.py
index 0e38e0e9..ab481ab4 100755
--- a/python/google/protobuf/internal/text_format_test.py
+++ b/python/google/protobuf/internal/text_format_test.py
@@ -52,6 +52,7 @@ from google.protobuf import unittest_mset_pb2
from google.protobuf import unittest_pb2
from google.protobuf import unittest_proto3_arena_pb2
from google.protobuf.internal import api_implementation
+from google.protobuf.internal import any_test_pb2 as test_extend_any
from google.protobuf.internal import test_util
from google.protobuf.internal import message_set_extensions_pb2
from google.protobuf import descriptor_pool
@@ -684,6 +685,21 @@ class Proto2Tests(TextFormatBase):
self.assertEqual(23, message.message_set.Extensions[ext1].i)
self.assertEqual('foo', message.message_set.Extensions[ext2].str)
+ def testExtensionInsideAnyMessage(self):
+ message = test_extend_any.TestAny()
+ text = ('value {\n'
+ ' [type.googleapis.com/google.protobuf.internal.TestAny] {\n'
+ ' [google.protobuf.internal.TestAnyExtension1.extension1] {\n'
+ ' i: 10\n'
+ ' }\n'
+ ' }\n'
+ '}\n')
+ text_format.Merge(text, message, descriptor_pool=descriptor_pool.Default())
+ self.CompareToGoldenText(
+ text_format.MessageToString(
+ message, descriptor_pool=descriptor_pool.Default()),
+ text)
+
def testParseMessageByFieldNumber(self):
message = unittest_pb2.TestAllTypes()
text = ('34: 1\n' 'repeated_uint64: 2\n')
@@ -1184,7 +1200,8 @@ class TokenizerTest(unittest.TestCase):
'ID7 : "aa\\"bb"\n\n\n\n ID8: {A:inf B:-inf C:true D:false}\n'
'ID9: 22 ID10: -111111111111111111 ID11: -22\n'
'ID12: 2222222222222222222 ID13: 1.23456f ID14: 1.2e+2f '
- 'false_bool: 0 true_BOOL:t \n true_bool1: 1 false_BOOL1:f ')
+ 'false_bool: 0 true_BOOL:t \n true_bool1: 1 false_BOOL1:f '
+ 'False_bool: False True_bool: True')
tokenizer = text_format.Tokenizer(text.splitlines())
methods = [(tokenizer.ConsumeIdentifier, 'identifier1'), ':',
(tokenizer.ConsumeString, 'string1'),
@@ -1228,7 +1245,11 @@ class TokenizerTest(unittest.TestCase):
(tokenizer.ConsumeIdentifier, 'true_bool1'), ':',
(tokenizer.ConsumeBool, True),
(tokenizer.ConsumeIdentifier, 'false_BOOL1'), ':',
- (tokenizer.ConsumeBool, False)]
+ (tokenizer.ConsumeBool, False),
+ (tokenizer.ConsumeIdentifier, 'False_bool'), ':',
+ (tokenizer.ConsumeBool, False),
+ (tokenizer.ConsumeIdentifier, 'True_bool'), ':',
+ (tokenizer.ConsumeBool, True)]
i = 0
while not tokenizer.AtEnd():
diff --git a/python/google/protobuf/internal/unknown_fields_test.py b/python/google/protobuf/internal/unknown_fields_test.py
index 84073f1c..d614eaa8 100755
--- a/python/google/protobuf/internal/unknown_fields_test.py
+++ b/python/google/protobuf/internal/unknown_fields_test.py
@@ -47,16 +47,20 @@ from google.protobuf.internal import encoder
from google.protobuf.internal import message_set_extensions_pb2
from google.protobuf.internal import missing_enum_values_pb2
from google.protobuf.internal import test_util
+from google.protobuf.internal import testing_refleaks
from google.protobuf.internal import type_checkers
+BaseTestCase = testing_refleaks.BaseTestCase
+
+
def SkipIfCppImplementation(func):
return unittest.skipIf(
api_implementation.Type() == 'cpp' and api_implementation.Version() == 2,
'C++ implementation does not expose unknown fields to Python')(func)
-class UnknownFieldsTest(unittest.TestCase):
+class UnknownFieldsTest(BaseTestCase):
def setUp(self):
self.descriptor = unittest_pb2.TestAllTypes.DESCRIPTOR
@@ -140,7 +144,7 @@ class UnknownFieldsTest(unittest.TestCase):
b'', message.repeated_nested_message[0].SerializeToString())
-class UnknownFieldsAccessorsTest(unittest.TestCase):
+class UnknownFieldsAccessorsTest(BaseTestCase):
def setUp(self):
self.descriptor = unittest_pb2.TestAllTypes.DESCRIPTOR
@@ -149,21 +153,18 @@ class UnknownFieldsAccessorsTest(unittest.TestCase):
self.all_fields_data = self.all_fields.SerializeToString()
self.empty_message = unittest_pb2.TestEmptyMessage()
self.empty_message.ParseFromString(self.all_fields_data)
- if api_implementation.Type() != 'cpp':
- # _unknown_fields is an implementation detail.
- self.unknown_fields = self.empty_message._unknown_fields
- # All the tests that use GetField() check an implementation detail of the
- # Python implementation, which stores unknown fields as serialized strings.
- # These tests are skipped by the C++ implementation: it's enough to check that
- # the message is correctly serialized.
+ # GetUnknownField() checks a detail of the Python implementation, which stores
+ # unknown fields as serialized strings. It cannot be used by the C++
+ # implementation: it's enough to check that the message is correctly
+ # serialized.
- def GetField(self, name):
+ def GetUnknownField(self, name):
field_descriptor = self.descriptor.fields_by_name[name]
wire_type = type_checkers.FIELD_TYPE_TO_WIRE_TYPE[field_descriptor.type]
field_tag = encoder.TagBytes(field_descriptor.number, wire_type)
result_dict = {}
- for tag_bytes, value in self.unknown_fields:
+ for tag_bytes, value in self.empty_message._unknown_fields:
if tag_bytes == field_tag:
decoder = unittest_pb2.TestAllTypes._decoders_by_tag[tag_bytes][0]
decoder(value, 0, len(value), self.all_fields, result_dict)
@@ -171,37 +172,37 @@ class UnknownFieldsAccessorsTest(unittest.TestCase):
@SkipIfCppImplementation
def testEnum(self):
- value = self.GetField('optional_nested_enum')
+ value = self.GetUnknownField('optional_nested_enum')
self.assertEqual(self.all_fields.optional_nested_enum, value)
@SkipIfCppImplementation
def testRepeatedEnum(self):
- value = self.GetField('repeated_nested_enum')
+ value = self.GetUnknownField('repeated_nested_enum')
self.assertEqual(self.all_fields.repeated_nested_enum, value)
@SkipIfCppImplementation
def testVarint(self):
- value = self.GetField('optional_int32')
+ value = self.GetUnknownField('optional_int32')
self.assertEqual(self.all_fields.optional_int32, value)
@SkipIfCppImplementation
def testFixed32(self):
- value = self.GetField('optional_fixed32')
+ value = self.GetUnknownField('optional_fixed32')
self.assertEqual(self.all_fields.optional_fixed32, value)
@SkipIfCppImplementation
def testFixed64(self):
- value = self.GetField('optional_fixed64')
+ value = self.GetUnknownField('optional_fixed64')
self.assertEqual(self.all_fields.optional_fixed64, value)
@SkipIfCppImplementation
def testLengthDelimited(self):
- value = self.GetField('optional_string')
+ value = self.GetUnknownField('optional_string')
self.assertEqual(self.all_fields.optional_string, value)
@SkipIfCppImplementation
def testGroup(self):
- value = self.GetField('optionalgroup')
+ value = self.GetUnknownField('optionalgroup')
self.assertEqual(self.all_fields.optionalgroup, value)
def testCopyFrom(self):
@@ -241,43 +242,41 @@ class UnknownFieldsAccessorsTest(unittest.TestCase):
self.assertEqual(message.SerializeToString(), self.all_fields_data)
-class UnknownEnumValuesTest(unittest.TestCase):
+class UnknownEnumValuesTest(BaseTestCase):
def setUp(self):
self.descriptor = missing_enum_values_pb2.TestEnumValues.DESCRIPTOR
self.message = missing_enum_values_pb2.TestEnumValues()
+ # TestEnumValues.ZERO = 0, but does not exist in the other NestedEnum.
self.message.optional_nested_enum = (
- missing_enum_values_pb2.TestEnumValues.ZERO)
+ missing_enum_values_pb2.TestEnumValues.ZERO)
self.message.repeated_nested_enum.extend([
- missing_enum_values_pb2.TestEnumValues.ZERO,
- missing_enum_values_pb2.TestEnumValues.ONE,
- ])
+ missing_enum_values_pb2.TestEnumValues.ZERO,
+ missing_enum_values_pb2.TestEnumValues.ONE,
+ ])
self.message.packed_nested_enum.extend([
- missing_enum_values_pb2.TestEnumValues.ZERO,
- missing_enum_values_pb2.TestEnumValues.ONE,
- ])
+ missing_enum_values_pb2.TestEnumValues.ZERO,
+ missing_enum_values_pb2.TestEnumValues.ONE,
+ ])
self.message_data = self.message.SerializeToString()
self.missing_message = missing_enum_values_pb2.TestMissingEnumValues()
self.missing_message.ParseFromString(self.message_data)
- if api_implementation.Type() != 'cpp':
- # _unknown_fields is an implementation detail.
- self.unknown_fields = self.missing_message._unknown_fields
- # All the tests that use GetField() check an implementation detail of the
- # Python implementation, which stores unknown fields as serialized strings.
- # These tests are skipped by the C++ implementation: it's enough to check that
- # the message is correctly serialized.
+ # GetUnknownField() checks a detail of the Python implementation, which stores
+ # unknown fields as serialized strings. It cannot be used by the C++
+ # implementation: it's enough to check that the message is correctly
+ # serialized.
- def GetField(self, name):
+ def GetUnknownField(self, name):
field_descriptor = self.descriptor.fields_by_name[name]
wire_type = type_checkers.FIELD_TYPE_TO_WIRE_TYPE[field_descriptor.type]
field_tag = encoder.TagBytes(field_descriptor.number, wire_type)
result_dict = {}
- for tag_bytes, value in self.unknown_fields:
+ for tag_bytes, value in self.missing_message._unknown_fields:
if tag_bytes == field_tag:
decoder = missing_enum_values_pb2.TestEnumValues._decoders_by_tag[
- tag_bytes][0]
+ tag_bytes][0]
decoder(value, 0, len(value), self.message, result_dict)
return result_dict[field_descriptor]
@@ -294,21 +293,39 @@ class UnknownEnumValuesTest(unittest.TestCase):
# default value.
self.assertEqual(missing.optional_nested_enum, 0)
- @SkipIfCppImplementation
def testUnknownEnumValue(self):
+ if api_implementation.Type() == 'cpp':
+ # The CPP implementation of protos (wrongly) allows unknown enum values
+ # for proto2.
+ self.assertTrue(self.missing_message.HasField('optional_nested_enum'))
+ self.assertEqual(self.message.optional_nested_enum,
+ self.missing_message.optional_nested_enum)
+ else:
+ # On the other hand, the Python implementation considers unknown values
+ # as unknown fields. This is the correct behavior.
+ self.assertFalse(self.missing_message.HasField('optional_nested_enum'))
+ value = self.GetUnknownField('optional_nested_enum')
+ self.assertEqual(self.message.optional_nested_enum, value)
+ self.missing_message.ClearField('optional_nested_enum')
self.assertFalse(self.missing_message.HasField('optional_nested_enum'))
- value = self.GetField('optional_nested_enum')
- self.assertEqual(self.message.optional_nested_enum, value)
- @SkipIfCppImplementation
def testUnknownRepeatedEnumValue(self):
- value = self.GetField('repeated_nested_enum')
- self.assertEqual(self.message.repeated_nested_enum, value)
+ if api_implementation.Type() == 'cpp':
+ # For repeated enums, both implementations agree.
+ self.assertEqual([], self.missing_message.repeated_nested_enum)
+ else:
+ self.assertEqual([], self.missing_message.repeated_nested_enum)
+ value = self.GetUnknownField('repeated_nested_enum')
+ self.assertEqual(self.message.repeated_nested_enum, value)
- @SkipIfCppImplementation
def testUnknownPackedEnumValue(self):
- value = self.GetField('packed_nested_enum')
- self.assertEqual(self.message.packed_nested_enum, value)
+ if api_implementation.Type() == 'cpp':
+ # For repeated enums, both implementations agree.
+ self.assertEqual([], self.missing_message.packed_nested_enum)
+ else:
+ self.assertEqual([], self.missing_message.packed_nested_enum)
+ value = self.GetUnknownField('packed_nested_enum')
+ self.assertEqual(self.message.packed_nested_enum, value)
def testRoundTrip(self):
new_message = missing_enum_values_pb2.TestEnumValues()
diff --git a/python/google/protobuf/internal/well_known_types.py b/python/google/protobuf/internal/well_known_types.py
index 7c5dffd0..d631abee 100644
--- a/python/google/protobuf/internal/well_known_types.py
+++ b/python/google/protobuf/internal/well_known_types.py
@@ -53,6 +53,7 @@ _NANOS_PER_MICROSECOND = 1000
_MILLIS_PER_SECOND = 1000
_MICROS_PER_SECOND = 1000000
_SECONDS_PER_DAY = 24 * 3600
+_DURATION_SECONDS_MAX = 315576000000
class Error(Exception):
@@ -247,6 +248,7 @@ class Duration(object):
represent the exact Duration value. For example: "1s", "1.010s",
"1.000000100s", "-3.100s"
"""
+ _CheckDurationValid(self.seconds, self.nanos)
if self.seconds < 0 or self.nanos < 0:
result = '-'
seconds = - self.seconds + int((0 - self.nanos) // 1e9)
@@ -286,14 +288,17 @@ class Duration(object):
try:
pos = value.find('.')
if pos == -1:
- self.seconds = int(value[:-1])
- self.nanos = 0
+ seconds = int(value[:-1])
+ nanos = 0
else:
- self.seconds = int(value[:pos])
+ seconds = int(value[:pos])
if value[0] == '-':
- self.nanos = int(round(float('-0{0}'.format(value[pos: -1])) *1e9))
+ nanos = int(round(float('-0{0}'.format(value[pos: -1])) *1e9))
else:
- self.nanos = int(round(float('0{0}'.format(value[pos: -1])) *1e9))
+ nanos = int(round(float('0{0}'.format(value[pos: -1])) *1e9))
+ _CheckDurationValid(seconds, nanos)
+ self.seconds = seconds
+ self.nanos = nanos
except ValueError:
raise ParseError(
'Couldn\'t parse duration: {0}.'.format(value))
@@ -359,6 +364,17 @@ class Duration(object):
self.nanos = nanos
+def _CheckDurationValid(seconds, nanos):
+ if seconds < -_DURATION_SECONDS_MAX or seconds > _DURATION_SECONDS_MAX:
+ raise Error(
+ 'Duration is not valid: Seconds {0} must be in range '
+ '[-315576000000, 315576000000].'.format(seconds))
+ if nanos <= -_NANOS_PER_SECOND or nanos >= _NANOS_PER_SECOND:
+ raise Error(
+ 'Duration is not valid: Nanos {0} must be in range '
+ '[-999999999, 999999999].'.format(nanos))
+
+
def _RoundTowardZero(value, divider):
"""Truncates the remainder part after division."""
# For some languanges, the sign of the remainder is implementation
@@ -379,13 +395,16 @@ class FieldMask(object):
def ToJsonString(self):
"""Converts FieldMask to string according to proto3 JSON spec."""
- return ','.join(self.paths)
+ camelcase_paths = []
+ for path in self.paths:
+ camelcase_paths.append(_SnakeCaseToCamelCase(path))
+ return ','.join(camelcase_paths)
def FromJsonString(self, value):
"""Converts string to FieldMask according to proto3 JSON spec."""
self.Clear()
for path in value.split(','):
- self.paths.append(path)
+ self.paths.append(_CamelCaseToSnakeCase(path))
def IsValidForDescriptor(self, message_descriptor):
"""Checks whether the FieldMask is valid for Message Descriptor."""
@@ -472,6 +491,48 @@ def _CheckFieldMaskMessage(message):
message_descriptor.full_name))
+def _SnakeCaseToCamelCase(path_name):
+ """Converts a path name from snake_case to camelCase."""
+ result = []
+ after_underscore = False
+ for c in path_name:
+ if c.isupper():
+ raise Error('Fail to print FieldMask to Json string: Path name '
+ '{0} must not contain uppercase letters.'.format(path_name))
+ if after_underscore:
+ if c.islower():
+ result.append(c.upper())
+ after_underscore = False
+ else:
+ raise Error('Fail to print FieldMask to Json string: The '
+ 'character after a "_" must be a lowercase letter '
+ 'in path name {0}.'.format(path_name))
+ elif c == '_':
+ after_underscore = True
+ else:
+ result += c
+
+ if after_underscore:
+ raise Error('Fail to print FieldMask to Json string: Trailing "_" '
+ 'in path name {0}.'.format(path_name))
+ return ''.join(result)
+
+
+def _CamelCaseToSnakeCase(path_name):
+ """Converts a field name from camelCase to snake_case."""
+ result = []
+ for c in path_name:
+ if c == '_':
+ raise ParseError('Fail to parse FieldMask: Path name '
+ '{0} must not contain "_"s.'.format(path_name))
+ if c.isupper():
+ result += '_'
+ result += c.lower()
+ else:
+ result += c
+ return ''.join(result)
+
+
class _FieldMaskTree(object):
"""Represents a FieldMask in a tree structure.
diff --git a/python/google/protobuf/internal/well_known_types_test.py b/python/google/protobuf/internal/well_known_types_test.py
index 2f32ac99..077f630f 100644
--- a/python/google/protobuf/internal/well_known_types_test.py
+++ b/python/google/protobuf/internal/well_known_types_test.py
@@ -303,6 +303,25 @@ class TimeUtilTest(TimeUtilTestBase):
well_known_types.ParseError,
'Couldn\'t parse duration: 1...2s.',
message.FromJsonString, '1...2s')
+ text = '-315576000001.000000000s'
+ self.assertRaisesRegexp(
+ well_known_types.Error,
+ r'Duration is not valid\: Seconds -315576000001 must be in range'
+ r' \[-315576000000\, 315576000000\].',
+ message.FromJsonString, text)
+ text = '315576000001.000000000s'
+ self.assertRaisesRegexp(
+ well_known_types.Error,
+ r'Duration is not valid\: Seconds 315576000001 must be in range'
+ r' \[-315576000000\, 315576000000\].',
+ message.FromJsonString, text)
+ message.seconds = -315576000001
+ message.nanos = 0
+ self.assertRaisesRegexp(
+ well_known_types.Error,
+ r'Duration is not valid\: Seconds -315576000001 must be in range'
+ r' \[-315576000000\, 315576000000\].',
+ message.ToJsonString)
class FieldMaskTest(unittest.TestCase):
@@ -322,6 +341,20 @@ class FieldMaskTest(unittest.TestCase):
mask.FromJsonString('foo,bar')
self.assertEqual(['foo', 'bar'], mask.paths)
+ # Test camel case
+ mask.Clear()
+ mask.paths.append('foo_bar')
+ self.assertEqual('fooBar', mask.ToJsonString())
+ mask.paths.append('bar_quz')
+ self.assertEqual('fooBar,barQuz', mask.ToJsonString())
+
+ mask.FromJsonString('')
+ self.assertEqual('', mask.ToJsonString())
+ mask.FromJsonString('fooBar')
+ self.assertEqual(['foo_bar'], mask.paths)
+ mask.FromJsonString('fooBar,barQuz')
+ self.assertEqual(['foo_bar', 'bar_quz'], mask.paths)
+
def testDescriptorToFieldMask(self):
mask = field_mask_pb2.FieldMask()
msg_descriptor = unittest_pb2.TestAllTypes.DESCRIPTOR
@@ -502,17 +535,68 @@ class FieldMaskTest(unittest.TestCase):
nested_src.payload.repeated_int32.append(1234)
nested_dst.payload.repeated_int32.append(5678)
# Repeated fields will be appended by default.
- mask.FromJsonString('payload.repeated_int32')
+ mask.FromJsonString('payload.repeatedInt32')
mask.MergeMessage(nested_src, nested_dst)
self.assertEqual(2, len(nested_dst.payload.repeated_int32))
self.assertEqual(5678, nested_dst.payload.repeated_int32[0])
self.assertEqual(1234, nested_dst.payload.repeated_int32[1])
# Change the behavior to replace repeated fields.
- mask.FromJsonString('payload.repeated_int32')
+ mask.FromJsonString('payload.repeatedInt32')
mask.MergeMessage(nested_src, nested_dst, False, True)
self.assertEqual(1, len(nested_dst.payload.repeated_int32))
self.assertEqual(1234, nested_dst.payload.repeated_int32[0])
+ def testSnakeCaseToCamelCase(self):
+ self.assertEqual('fooBar',
+ well_known_types._SnakeCaseToCamelCase('foo_bar'))
+ self.assertEqual('FooBar',
+ well_known_types._SnakeCaseToCamelCase('_foo_bar'))
+ self.assertEqual('foo3Bar',
+ well_known_types._SnakeCaseToCamelCase('foo3_bar'))
+
+ # No uppercase letter is allowed.
+ self.assertRaisesRegexp(
+ well_known_types.Error,
+ 'Fail to print FieldMask to Json string: Path name Foo must '
+ 'not contain uppercase letters.',
+ well_known_types._SnakeCaseToCamelCase,
+ 'Foo')
+ # Any character after a "_" must be a lowercase letter.
+ # 1. "_" cannot be followed by another "_".
+ # 2. "_" cannot be followed by a digit.
+ # 3. "_" cannot appear as the last character.
+ self.assertRaisesRegexp(
+ well_known_types.Error,
+ 'Fail to print FieldMask to Json string: The character after a '
+ '"_" must be a lowercase letter in path name foo__bar.',
+ well_known_types._SnakeCaseToCamelCase,
+ 'foo__bar')
+ self.assertRaisesRegexp(
+ well_known_types.Error,
+ 'Fail to print FieldMask to Json string: The character after a '
+ '"_" must be a lowercase letter in path name foo_3bar.',
+ well_known_types._SnakeCaseToCamelCase,
+ 'foo_3bar')
+ self.assertRaisesRegexp(
+ well_known_types.Error,
+ 'Fail to print FieldMask to Json string: Trailing "_" in path '
+ 'name foo_bar_.',
+ well_known_types._SnakeCaseToCamelCase,
+ 'foo_bar_')
+
+ def testCamelCaseToSnakeCase(self):
+ self.assertEqual('foo_bar',
+ well_known_types._CamelCaseToSnakeCase('fooBar'))
+ self.assertEqual('_foo_bar',
+ well_known_types._CamelCaseToSnakeCase('FooBar'))
+ self.assertEqual('foo3_bar',
+ well_known_types._CamelCaseToSnakeCase('foo3Bar'))
+ self.assertRaisesRegexp(
+ well_known_types.ParseError,
+ 'Fail to parse FieldMask: Path name foo_bar must not contain "_"s.',
+ well_known_types._CamelCaseToSnakeCase,
+ 'foo_bar')
+
class StructTest(unittest.TestCase):
@@ -529,52 +613,52 @@ class StructTest(unittest.TestCase):
struct_list.add_struct()['subkey2'] = 9
self.assertTrue(isinstance(struct, well_known_types.Struct))
- self.assertEquals(5, struct['key1'])
- self.assertEquals('abc', struct['key2'])
+ self.assertEqual(5, struct['key1'])
+ self.assertEqual('abc', struct['key2'])
self.assertIs(True, struct['key3'])
- self.assertEquals(11, struct['key4']['subkey'])
+ self.assertEqual(11, struct['key4']['subkey'])
inner_struct = struct_class()
inner_struct['subkey2'] = 9
- self.assertEquals([6, 'seven', True, False, None, inner_struct],
- list(struct['key5'].items()))
+ self.assertEqual([6, 'seven', True, False, None, inner_struct],
+ list(struct['key5'].items()))
serialized = struct.SerializeToString()
struct2 = struct_pb2.Struct()
struct2.ParseFromString(serialized)
- self.assertEquals(struct, struct2)
+ self.assertEqual(struct, struct2)
self.assertTrue(isinstance(struct2, well_known_types.Struct))
- self.assertEquals(5, struct2['key1'])
- self.assertEquals('abc', struct2['key2'])
+ self.assertEqual(5, struct2['key1'])
+ self.assertEqual('abc', struct2['key2'])
self.assertIs(True, struct2['key3'])
- self.assertEquals(11, struct2['key4']['subkey'])
- self.assertEquals([6, 'seven', True, False, None, inner_struct],
- list(struct2['key5'].items()))
+ self.assertEqual(11, struct2['key4']['subkey'])
+ self.assertEqual([6, 'seven', True, False, None, inner_struct],
+ list(struct2['key5'].items()))
struct_list = struct2['key5']
- self.assertEquals(6, struct_list[0])
- self.assertEquals('seven', struct_list[1])
- self.assertEquals(True, struct_list[2])
- self.assertEquals(False, struct_list[3])
- self.assertEquals(None, struct_list[4])
- self.assertEquals(inner_struct, struct_list[5])
+ self.assertEqual(6, struct_list[0])
+ self.assertEqual('seven', struct_list[1])
+ self.assertEqual(True, struct_list[2])
+ self.assertEqual(False, struct_list[3])
+ self.assertEqual(None, struct_list[4])
+ self.assertEqual(inner_struct, struct_list[5])
struct_list[1] = 7
- self.assertEquals(7, struct_list[1])
+ self.assertEqual(7, struct_list[1])
struct_list.add_list().extend([1, 'two', True, False, None])
- self.assertEquals([1, 'two', True, False, None],
- list(struct_list[6].items()))
+ self.assertEqual([1, 'two', True, False, None],
+ list(struct_list[6].items()))
text_serialized = str(struct)
struct3 = struct_pb2.Struct()
text_format.Merge(text_serialized, struct3)
- self.assertEquals(struct, struct3)
+ self.assertEqual(struct, struct3)
struct.get_or_create_struct('key3')['replace'] = 12
- self.assertEquals(12, struct['key3']['replace'])
+ self.assertEqual(12, struct['key3']['replace'])
class AnyTest(unittest.TestCase):
diff --git a/python/google/protobuf/json_format.py b/python/google/protobuf/json_format.py
index edc0cb50..c42371d0 100644
--- a/python/google/protobuf/json_format.py
+++ b/python/google/protobuf/json_format.py
@@ -86,7 +86,9 @@ class ParseError(Error):
"""Thrown in case of parsing error."""
-def MessageToJson(message, including_default_value_fields=False):
+def MessageToJson(message,
+ including_default_value_fields=False,
+ preserving_proto_field_name=False):
"""Converts protobuf message to JSON format.
Args:
@@ -95,14 +97,42 @@ def MessageToJson(message, including_default_value_fields=False):
repeated fields, and map fields will always be serialized. If
False, only serialize non-empty fields. Singular message fields
and oneof fields are not affected by this option.
+ preserving_proto_field_name: If True, use the original proto field
+ names as defined in the .proto file. If False, convert the field
+ names to lowerCamelCase.
Returns:
A string containing the JSON formatted protocol buffer message.
"""
- printer = _Printer(including_default_value_fields)
+ printer = _Printer(including_default_value_fields,
+ preserving_proto_field_name)
return printer.ToJsonString(message)
+def MessageToDict(message,
+ including_default_value_fields=False,
+ preserving_proto_field_name=False):
+ """Converts protobuf message to a JSON dictionary.
+
+ Args:
+ message: The protocol buffers message instance to serialize.
+ including_default_value_fields: If True, singular primitive fields,
+ repeated fields, and map fields will always be serialized. If
+ False, only serialize non-empty fields. Singular message fields
+ and oneof fields are not affected by this option.
+ preserving_proto_field_name: If True, use the original proto field
+ names as defined in the .proto file. If False, convert the field
+ names to lowerCamelCase.
+
+ Returns:
+ A dict representation of the JSON formatted protocol buffer message.
+ """
+ printer = _Printer(including_default_value_fields,
+ preserving_proto_field_name)
+ # pylint: disable=protected-access
+ return printer._MessageToJsonObject(message)
+
+
def _IsMapEntry(field):
return (field.type == descriptor.FieldDescriptor.TYPE_MESSAGE and
field.message_type.has_options and
@@ -113,8 +143,10 @@ class _Printer(object):
"""JSON format printer for protocol message."""
def __init__(self,
- including_default_value_fields=False):
+ including_default_value_fields=False,
+ preserving_proto_field_name=False):
self.including_default_value_fields = including_default_value_fields
+ self.preserving_proto_field_name = preserving_proto_field_name
def ToJsonString(self, message):
js = self._MessageToJsonObject(message)
@@ -137,7 +169,10 @@ class _Printer(object):
try:
for field, value in fields:
- name = field.camelcase_name
+ if self.preserving_proto_field_name:
+ name = field.name
+ else:
+ name = field.json_name
if _IsMapEntry(field):
# Convert a map field.
v_field = field.message_type.fields_by_name['value']
@@ -169,7 +204,10 @@ class _Printer(object):
field.cpp_type == descriptor.FieldDescriptor.CPPTYPE_MESSAGE) or
field.containing_oneof):
continue
- name = field.camelcase_name
+ if self.preserving_proto_field_name:
+ name = field.name
+ else:
+ name = field.json_name
if name in js:
# Skip the field which has been serailized already.
continue
@@ -328,8 +366,22 @@ def Parse(text, message, ignore_unknown_fields=False):
js = json.loads(text, object_pairs_hook=_DuplicateChecker)
except ValueError as e:
raise ParseError('Failed to load JSON: {0}.'.format(str(e)))
+ return ParseDict(js, message, ignore_unknown_fields)
+
+
+def ParseDict(js_dict, message, ignore_unknown_fields=False):
+ """Parses a JSON dictionary representation into a message.
+
+ Args:
+ js_dict: Dict representation of a JSON message.
+ message: A protocol buffer message to merge into.
+ ignore_unknown_fields: If True, do not raise errors for unknown fields.
+
+ Returns:
+ The same message passed as argument.
+ """
parser = _Parser(ignore_unknown_fields)
- parser.ConvertMessage(js, message)
+ parser.ConvertMessage(js_dict, message)
return message
@@ -374,9 +426,13 @@ class _Parser(object):
"""
names = []
message_descriptor = message.DESCRIPTOR
+ fields_by_json_name = dict((f.json_name, f)
+ for f in message_descriptor.fields)
for name in js:
try:
- field = message_descriptor.fields_by_camelcase_name.get(name, None)
+ field = fields_by_json_name.get(name, None)
+ if not field:
+ field = message_descriptor.fields_by_name.get(name, None)
if not field:
if self.ignore_unknown_fields:
continue
@@ -399,7 +455,12 @@ class _Parser(object):
value = js[name]
if value is None:
- message.ClearField(field.name)
+ if (field.cpp_type == descriptor.FieldDescriptor.CPPTYPE_MESSAGE
+ and field.message_type.full_name == 'google.protobuf.Value'):
+ sub_message = getattr(message, field.name)
+ sub_message.null_value = 0
+ else:
+ message.ClearField(field.name)
continue
# Parse field value.
@@ -431,6 +492,7 @@ class _Parser(object):
_ConvertScalarFieldValue(item, field))
elif field.cpp_type == descriptor.FieldDescriptor.CPPTYPE_MESSAGE:
sub_message = getattr(message, field.name)
+ sub_message.SetInParent()
self.ConvertMessage(value, sub_message)
else:
setattr(message, field.name, _ConvertScalarFieldValue(value, field))
@@ -574,10 +636,15 @@ def _ConvertScalarFieldValue(value, field, require_str=False):
# Convert an enum value.
enum_value = field.enum_type.values_by_name.get(value, None)
if enum_value is None:
- raise ParseError(
- 'Enum value must be a string literal with double quotes. '
- 'Type "{0}" has no value named {1}.'.format(
- field.enum_type.full_name, value))
+ try:
+ number = int(value)
+ enum_value = field.enum_type.values_by_number.get(number, None)
+ except ValueError:
+ raise ParseError('Invalid enum value {0} for enum type {1}.'.format(
+ value, field.enum_type.full_name))
+ if enum_value is None:
+ raise ParseError('Invalid enum value {0} for enum type {1}.'.format(
+ value, field.enum_type.full_name))
return enum_value.number
@@ -593,7 +660,7 @@ def _ConvertInteger(value):
Raises:
ParseError: If an integer couldn't be consumed.
"""
- if isinstance(value, float):
+ if isinstance(value, float) and not value.is_integer():
raise ParseError('Couldn\'t parse integer: {0}.'.format(value))
if isinstance(value, six.text_type) and value.find(' ') != -1:
diff --git a/python/google/protobuf/message.py b/python/google/protobuf/message.py
index 606f735f..aab250e4 100755
--- a/python/google/protobuf/message.py
+++ b/python/google/protobuf/message.py
@@ -225,10 +225,11 @@ class Message(object):
# """
def ListFields(self):
"""Returns a list of (FieldDescriptor, value) tuples for all
- fields in the message which are not empty. A singular field is non-empty
- if HasField() would return true, and a repeated field is non-empty if
- it contains at least one element. The fields are ordered by field
- number"""
+ fields in the message which are not empty. A message field is
+ non-empty if HasField() would return true. A singular primitive field
+ is non-empty if HasField() would return true in proto2 or it is non zero
+ in proto3. A repeated field is non-empty if it contains at least one
+ element. The fields are ordered by field number"""
raise NotImplementedError
def HasField(self, field_name):
diff --git a/python/google/protobuf/message_factory.py b/python/google/protobuf/message_factory.py
index 1b059d13..8ab1c513 100644
--- a/python/google/protobuf/message_factory.py
+++ b/python/google/protobuf/message_factory.py
@@ -103,13 +103,8 @@ class MessageFactory(object):
result = {}
for file_name in files:
file_desc = self.pool.FindFileByName(file_name)
- for name, msg in file_desc.message_types_by_name.items():
- if file_desc.package:
- full_name = '.'.join([file_desc.package, name])
- else:
- full_name = msg.name
- result[full_name] = self.GetPrototype(
- self.pool.FindMessageTypeByName(full_name))
+ for desc in file_desc.message_types_by_name.values():
+ result[desc.full_name] = self.GetPrototype(desc)
# While the extension FieldDescriptors are created by the descriptor pool,
# the python classes created in the factory need them to be registered
@@ -120,7 +115,7 @@ class MessageFactory(object):
# ignore the registration if the original was the same, or raise
# an error if they were different.
- for name, extension in file_desc.extensions_by_name.items():
+ for extension in file_desc.extensions_by_name.values():
if extension.containing_type.full_name not in self._classes:
self.GetPrototype(extension.containing_type)
extended_class = self._classes[extension.containing_type.full_name]
diff --git a/python/google/protobuf/pyext/descriptor.cc b/python/google/protobuf/pyext/descriptor.cc
index e6ef5ef5..924ae0b9 100644
--- a/python/google/protobuf/pyext/descriptor.cc
+++ b/python/google/protobuf/pyext/descriptor.cc
@@ -41,6 +41,7 @@
#include <google/protobuf/pyext/descriptor_containers.h>
#include <google/protobuf/pyext/descriptor_pool.h>
#include <google/protobuf/pyext/message.h>
+#include <google/protobuf/pyext/message_factory.h>
#include <google/protobuf/pyext/scoped_pyobject_ptr.h>
#if PY_MAJOR_VERSION >= 3
@@ -204,8 +205,9 @@ static PyObject* GetOrBuildOptions(const DescriptorClass *descriptor) {
// read-only instance.
const Message& options(descriptor->options());
const Descriptor *message_type = options.GetDescriptor();
- CMessageClass* message_class(
- cdescriptor_pool::GetMessageClass(pool, message_type));
+ PyMessageFactory* message_factory = pool->py_message_factory;
+ CMessageClass* message_class = message_factory::GetMessageClass(
+ message_factory, message_type);
if (message_class == NULL) {
// The Options message was not found in the current DescriptorPool.
// This means that the pool cannot contain any extensions to the Options
@@ -213,7 +215,9 @@ static PyObject* GetOrBuildOptions(const DescriptorClass *descriptor) {
// the chances of successfully parsing the options.
PyErr_Clear();
pool = GetDefaultDescriptorPool();
- message_class = cdescriptor_pool::GetMessageClass(pool, message_type);
+ message_factory = pool->py_message_factory;
+ message_class = message_factory::GetMessageClass(
+ message_factory, message_type);
}
if (message_class == NULL) {
PyErr_Format(PyExc_TypeError, "Could not retrieve class for Options: %s",
@@ -243,7 +247,7 @@ static PyObject* GetOrBuildOptions(const DescriptorClass *descriptor) {
options.SerializeToString(&serialized);
io::CodedInputStream input(
reinterpret_cast<const uint8*>(serialized.c_str()), serialized.size());
- input.SetExtensionRegistry(pool->pool, pool->message_factory);
+ input.SetExtensionRegistry(pool->pool, message_factory->message_factory);
bool success = cmsg->message->MergePartialFromCodedStream(&input);
if (!success) {
PyErr_Format(PyExc_ValueError, "Error parsing Options message");
@@ -439,8 +443,9 @@ static PyObject* GetConcreteClass(PyBaseDescriptor* self, void *closure) {
// which contains this descriptor.
// This might not be the one you expect! For example the returned object does
// not know about extensions defined in a custom pool.
- CMessageClass* concrete_class(cdescriptor_pool::GetMessageClass(
- GetDescriptorPool_FromPool(_GetDescriptor(self)->file()->pool()),
+ CMessageClass* concrete_class(message_factory::GetMessageClass(
+ GetDescriptorPool_FromPool(
+ _GetDescriptor(self)->file()->pool())->py_message_factory,
_GetDescriptor(self)));
Py_XINCREF(concrete_class);
return concrete_class->AsPyObject();
@@ -699,6 +704,10 @@ static PyObject* GetCamelcaseName(PyBaseDescriptor* self, void *closure) {
return PyString_FromCppString(_GetDescriptor(self)->camelcase_name());
}
+static PyObject* GetJsonName(PyBaseDescriptor* self, void *closure) {
+ return PyString_FromCppString(_GetDescriptor(self)->json_name());
+}
+
static PyObject* GetType(PyBaseDescriptor *self, void *closure) {
return PyInt_FromLong(_GetDescriptor(self)->type());
}
@@ -888,6 +897,7 @@ static PyGetSetDef Getters[] = {
{ "full_name", (getter)GetFullName, NULL, "Full name"},
{ "name", (getter)GetName, NULL, "Unqualified name"},
{ "camelcase_name", (getter)GetCamelcaseName, NULL, "Camelcase name"},
+ { "json_name", (getter)GetJsonName, NULL, "Json name"},
{ "type", (getter)GetType, NULL, "C++ Type"},
{ "cpp_type", (getter)GetCppType, NULL, "C++ Type"},
{ "label", (getter)GetLabel, NULL, "Label"},
diff --git a/python/google/protobuf/pyext/descriptor_pool.cc b/python/google/protobuf/pyext/descriptor_pool.cc
index cfd98690..a42e5431 100644
--- a/python/google/protobuf/pyext/descriptor_pool.cc
+++ b/python/google/protobuf/pyext/descriptor_pool.cc
@@ -33,11 +33,11 @@
#include <Python.h>
#include <google/protobuf/descriptor.pb.h>
-#include <google/protobuf/dynamic_message.h>
#include <google/protobuf/pyext/descriptor.h>
#include <google/protobuf/pyext/descriptor_database.h>
#include <google/protobuf/pyext/descriptor_pool.h>
#include <google/protobuf/pyext/message.h>
+#include <google/protobuf/pyext/message_factory.h>
#include <google/protobuf/pyext/scoped_pyobject_ptr.h>
#if PY_MAJOR_VERSION >= 3
@@ -73,18 +73,16 @@ static PyDescriptorPool* _CreateDescriptorPool() {
cpool->underlay = NULL;
cpool->database = NULL;
- DynamicMessageFactory* message_factory = new DynamicMessageFactory();
- // This option might be the default some day.
- message_factory->SetDelegateToGeneratedFactory(true);
- cpool->message_factory = message_factory;
-
- // TODO(amauryfa): Rewrite the SymbolDatabase in C so that it uses the same
- // storage.
- cpool->classes_by_descriptor =
- new PyDescriptorPool::ClassesByMessageMap();
cpool->descriptor_options =
new hash_map<const void*, PyObject *>();
+ cpool->py_message_factory = message_factory::NewMessageFactory(
+ &PyMessageFactory_Type, cpool);
+ if (cpool->py_message_factory == NULL) {
+ Py_DECREF(cpool);
+ return NULL;
+ }
+
return cpool;
}
@@ -151,20 +149,14 @@ static PyObject* New(PyTypeObject* type,
}
static void Dealloc(PyDescriptorPool* self) {
- typedef PyDescriptorPool::ClassesByMessageMap::iterator iterator;
descriptor_pool_map.erase(self->pool);
- for (iterator it = self->classes_by_descriptor->begin();
- it != self->classes_by_descriptor->end(); ++it) {
- Py_DECREF(it->second);
- }
- delete self->classes_by_descriptor;
+ Py_CLEAR(self->py_message_factory);
for (hash_map<const void*, PyObject*>::iterator it =
self->descriptor_options->begin();
it != self->descriptor_options->end(); ++it) {
Py_DECREF(it->second);
}
delete self->descriptor_options;
- delete self->message_factory;
delete self->database;
delete self->pool;
Py_TYPE(self)->tp_free(reinterpret_cast<PyObject*>(self));
@@ -188,35 +180,8 @@ PyObject* FindMessageByName(PyDescriptorPool* self, PyObject* arg) {
return PyMessageDescriptor_FromDescriptor(message_descriptor);
}
-// Add a message class to our database.
-int RegisterMessageClass(PyDescriptorPool* self,
- const Descriptor* message_descriptor,
- CMessageClass* message_class) {
- Py_INCREF(message_class);
- typedef PyDescriptorPool::ClassesByMessageMap::iterator iterator;
- std::pair<iterator, bool> ret = self->classes_by_descriptor->insert(
- std::make_pair(message_descriptor, message_class));
- if (!ret.second) {
- // Update case: DECREF the previous value.
- Py_DECREF(ret.first->second);
- ret.first->second = message_class;
- }
- return 0;
-}
-// Retrieve the message class added to our database.
-CMessageClass* GetMessageClass(PyDescriptorPool* self,
- const Descriptor* message_descriptor) {
- typedef PyDescriptorPool::ClassesByMessageMap::iterator iterator;
- iterator ret = self->classes_by_descriptor->find(message_descriptor);
- if (ret == self->classes_by_descriptor->end()) {
- PyErr_Format(PyExc_TypeError, "No message class registered for '%s'",
- message_descriptor->full_name().c_str());
- return NULL;
- } else {
- return ret->second;
- }
-}
+
PyObject* FindFileByName(PyDescriptorPool* self, PyObject* arg) {
Py_ssize_t name_size;
@@ -228,11 +193,9 @@ PyObject* FindFileByName(PyDescriptorPool* self, PyObject* arg) {
const FileDescriptor* file_descriptor =
self->pool->FindFileByName(string(name, name_size));
if (file_descriptor == NULL) {
- PyErr_Format(PyExc_KeyError, "Couldn't find file %.200s",
- name);
+ PyErr_Format(PyExc_KeyError, "Couldn't find file %.200s", name);
return NULL;
}
-
return PyFileDescriptor_FromDescriptor(file_descriptor);
}
diff --git a/python/google/protobuf/pyext/descriptor_pool.h b/python/google/protobuf/pyext/descriptor_pool.h
index 2a42c112..8de6c60b 100644
--- a/python/google/protobuf/pyext/descriptor_pool.h
+++ b/python/google/protobuf/pyext/descriptor_pool.h
@@ -38,10 +38,10 @@
namespace google {
namespace protobuf {
-class MessageFactory;
-
namespace python {
+class PyMessageFactory;
+
// The (meta) type of all Messages classes.
struct CMessageClass;
@@ -69,20 +69,10 @@ typedef struct PyDescriptorPool {
// This pointer is owned.
const DescriptorDatabase* database;
- // DynamicMessageFactory used to create C++ instances of messages.
- // This object cache the descriptors that were used, so the DescriptorPool
- // needs to get rid of it before it can delete itself.
- //
- // Note: A C++ MessageFactory is different from the Python MessageFactory.
- // The C++ one creates messages, when the Python one creates classes.
- MessageFactory* message_factory;
-
- // Make our own mapping to retrieve Python classes from C++ descriptors.
- //
- // Descriptor pointers stored here are owned by the DescriptorPool above.
- // Python references to classes are owned by this PyDescriptorPool.
- typedef hash_map<const Descriptor*, CMessageClass*> ClassesByMessageMap;
- ClassesByMessageMap* classes_by_descriptor;
+ // The preferred MessageFactory to be used by descriptors.
+ // TODO(amauryfa): Don't create the Factory from the DescriptorPool, but
+ // use the one passed while creating message classes. And remove this member.
+ PyMessageFactory* py_message_factory;
// Cache the options for any kind of descriptor.
// Descriptor pointers are owned by the DescriptorPool above.
@@ -100,19 +90,6 @@ namespace cdescriptor_pool {
const Descriptor* FindMessageTypeByName(PyDescriptorPool* self,
const string& name);
-// Registers a new Python class for the given message descriptor.
-// On error, returns -1 with a Python exception set.
-int RegisterMessageClass(PyDescriptorPool* self,
- const Descriptor* message_descriptor,
- CMessageClass* message_class);
-
-// Retrieves the Python class registered with the given message descriptor.
-//
-// Returns a *borrowed* reference if found, otherwise returns NULL with an
-// exception set.
-CMessageClass* GetMessageClass(PyDescriptorPool* self,
- const Descriptor* message_descriptor);
-
// The functions below are also exposed as methods of the DescriptorPool type.
// Looks up a message by name. Returns a PyMessageDescriptor corresponding to
diff --git a/python/google/protobuf/pyext/extension_dict.cc b/python/google/protobuf/pyext/extension_dict.cc
index 21bbb8c2..dbb7bca0 100644
--- a/python/google/protobuf/pyext/extension_dict.cc
+++ b/python/google/protobuf/pyext/extension_dict.cc
@@ -39,8 +39,8 @@
#include <google/protobuf/dynamic_message.h>
#include <google/protobuf/message.h>
#include <google/protobuf/pyext/descriptor.h>
-#include <google/protobuf/pyext/descriptor_pool.h>
#include <google/protobuf/pyext/message.h>
+#include <google/protobuf/pyext/message_factory.h>
#include <google/protobuf/pyext/repeated_composite_container.h>
#include <google/protobuf/pyext/repeated_scalar_container.h>
#include <google/protobuf/pyext/scoped_pyobject_ptr.h>
@@ -60,35 +60,6 @@ PyObject* len(ExtensionDict* self) {
#endif
}
-// TODO(tibell): Use VisitCompositeField.
-int ReleaseExtension(ExtensionDict* self,
- PyObject* extension,
- const FieldDescriptor* descriptor) {
- if (descriptor->label() == FieldDescriptor::LABEL_REPEATED) {
- if (descriptor->cpp_type() == FieldDescriptor::CPPTYPE_MESSAGE) {
- if (repeated_composite_container::Release(
- reinterpret_cast<RepeatedCompositeContainer*>(
- extension)) < 0) {
- return -1;
- }
- } else {
- if (repeated_scalar_container::Release(
- reinterpret_cast<RepeatedScalarContainer*>(
- extension)) < 0) {
- return -1;
- }
- }
- } else if (descriptor->cpp_type() == FieldDescriptor::CPPTYPE_MESSAGE) {
- if (cmessage::ReleaseSubMessage(
- self->parent, descriptor,
- reinterpret_cast<CMessage*>(extension)) < 0) {
- return -1;
- }
- }
-
- return 0;
-}
-
PyObject* subscript(ExtensionDict* self, PyObject* key) {
const FieldDescriptor* descriptor = cmessage::GetExtensionDescriptor(key);
if (descriptor == NULL) {
@@ -130,8 +101,8 @@ PyObject* subscript(ExtensionDict* self, PyObject* key) {
if (descriptor->label() == FieldDescriptor::LABEL_REPEATED) {
if (descriptor->cpp_type() == FieldDescriptor::CPPTYPE_MESSAGE) {
- CMessageClass* message_class = cdescriptor_pool::GetMessageClass(
- cmessage::GetDescriptorPoolForMessage(self->parent),
+ CMessageClass* message_class = message_factory::GetMessageClass(
+ cmessage::GetFactoryForMessage(self->parent),
descriptor->message_type());
if (message_class == NULL) {
return NULL;
@@ -183,47 +154,6 @@ int ass_subscript(ExtensionDict* self, PyObject* key, PyObject* value) {
return 0;
}
-PyObject* ClearExtension(ExtensionDict* self, PyObject* extension) {
- const FieldDescriptor* descriptor =
- cmessage::GetExtensionDescriptor(extension);
- if (descriptor == NULL) {
- return NULL;
- }
- PyObject* value = PyDict_GetItem(self->values, extension);
- if (self->parent) {
- if (value != NULL) {
- if (ReleaseExtension(self, value, descriptor) < 0) {
- return NULL;
- }
- }
- if (ScopedPyObjectPtr(cmessage::ClearFieldByDescriptor(
- self->parent, descriptor)) == NULL) {
- return NULL;
- }
- }
- if (PyDict_DelItem(self->values, extension) < 0) {
- PyErr_Clear();
- }
- Py_RETURN_NONE;
-}
-
-PyObject* HasExtension(ExtensionDict* self, PyObject* extension) {
- const FieldDescriptor* descriptor =
- cmessage::GetExtensionDescriptor(extension);
- if (descriptor == NULL) {
- return NULL;
- }
- if (self->parent) {
- return cmessage::HasFieldByDescriptor(self->parent, descriptor);
- } else {
- int exists = PyDict_Contains(self->values, extension);
- if (exists < 0) {
- return NULL;
- }
- return PyBool_FromLong(exists);
- }
-}
-
PyObject* _FindExtensionByName(ExtensionDict* self, PyObject* name) {
ScopedPyObjectPtr extensions_by_name(PyObject_GetAttrString(
reinterpret_cast<PyObject*>(self->parent), "_extensions_by_name"));
@@ -282,8 +212,6 @@ static PyMappingMethods MpMethods = {
#define EDMETHOD(name, args, doc) { #name, (PyCFunction)name, args, doc }
static PyMethodDef Methods[] = {
- EDMETHOD(ClearExtension, METH_O, "Clears an extension from the object."),
- EDMETHOD(HasExtension, METH_O, "Checks if the object has an extension."),
EDMETHOD(_FindExtensionByName, METH_O,
"Finds an extension by name."),
EDMETHOD(_FindExtensionByNumber, METH_O,
diff --git a/python/google/protobuf/pyext/extension_dict.h b/python/google/protobuf/pyext/extension_dict.h
index 2456eda1..65b87862 100644
--- a/python/google/protobuf/pyext/extension_dict.h
+++ b/python/google/protobuf/pyext/extension_dict.h
@@ -86,49 +86,6 @@ namespace extension_dict {
// Builds an Extensions dict for a specific message.
ExtensionDict* NewExtensionDict(CMessage *parent);
-// Gets the number of extension values in this ExtensionDict as a python object.
-//
-// Returns a new reference.
-PyObject* len(ExtensionDict* self);
-
-// Releases extensions referenced outside this dictionary to keep outside
-// references alive.
-//
-// Returns 0 on success, -1 on failure.
-int ReleaseExtension(ExtensionDict* self,
- PyObject* extension,
- const FieldDescriptor* descriptor);
-
-// Gets an extension from the dict for the given extension descriptor.
-//
-// Returns a new reference.
-PyObject* subscript(ExtensionDict* self, PyObject* key);
-
-// Assigns a value to an extension in the dict. Can only be used for singular
-// simple types.
-//
-// Returns 0 on success, -1 on failure.
-int ass_subscript(ExtensionDict* self, PyObject* key, PyObject* value);
-
-// Clears an extension from the dict. Will release the extension if there
-// is still an external reference left to it.
-//
-// Returns None on success.
-PyObject* ClearExtension(ExtensionDict* self,
- PyObject* extension);
-
-// Gets an extension from the dict given the extension name as opposed to
-// descriptor.
-//
-// Returns a new reference.
-PyObject* _FindExtensionByName(ExtensionDict* self, PyObject* name);
-
-// Gets an extension from the dict given the extension field number as
-// opposed to descriptor.
-//
-// Returns a new reference.
-PyObject* _FindExtensionByNumber(ExtensionDict* self, PyObject* number);
-
} // namespace extension_dict
} // namespace python
} // namespace protobuf
diff --git a/python/google/protobuf/pyext/map_container.cc b/python/google/protobuf/pyext/map_container.cc
index 0987b898..318c2e7c 100644
--- a/python/google/protobuf/pyext/map_container.cc
+++ b/python/google/protobuf/pyext/map_container.cc
@@ -42,7 +42,9 @@
#include <google/protobuf/map_field.h>
#include <google/protobuf/map.h>
#include <google/protobuf/message.h>
+#include <google/protobuf/pyext/message_factory.h>
#include <google/protobuf/pyext/message.h>
+#include <google/protobuf/pyext/repeated_composite_container.h>
#include <google/protobuf/pyext/scoped_pyobject_ptr.h>
#if PY_MAJOR_VERSION >= 3
@@ -328,6 +330,15 @@ PyObject* Clear(PyObject* _self) {
Py_RETURN_NONE;
}
+PyObject* GetEntryClass(PyObject* _self) {
+ MapContainer* self = GetMap(_self);
+ CMessageClass* message_class = message_factory::GetMessageClass(
+ cmessage::GetFactoryForMessage(self->parent),
+ self->parent_field_descriptor->message_type());
+ Py_XINCREF(message_class);
+ return reinterpret_cast<PyObject*>(message_class);
+}
+
PyObject* MapReflectionFriend::Contains(PyObject* _self, PyObject* key) {
MapContainer* self = GetMap(_self);
@@ -400,12 +411,7 @@ PyObject *NewScalarMapContainer(
return NULL;
}
-#if PY_MAJOR_VERSION >= 3
- ScopedPyObjectPtr obj(PyType_GenericAlloc(
- reinterpret_cast<PyTypeObject *>(ScalarMapContainer_Type), 0));
-#else
- ScopedPyObjectPtr obj(PyType_GenericAlloc(&ScalarMapContainer_Type, 0));
-#endif
+ ScopedPyObjectPtr obj(PyType_GenericAlloc(ScalarMapContainer_Type, 0));
if (obj.get() == NULL) {
return PyErr_Format(PyExc_RuntimeError,
"Could not allocate new container.");
@@ -527,6 +533,8 @@ static PyMethodDef ScalarMapMethods[] = {
"Removes all elements from the map." },
{ "get", ScalarMapGet, METH_VARARGS,
"Gets the value for the given key if present, or otherwise a default" },
+ { "GetEntryClass", (PyCFunction)GetEntryClass, METH_NOARGS,
+ "Return the class used to build Entries of (key, value) pairs." },
/*
{ "__deepcopy__", (PyCFunction)DeepCopy, METH_VARARGS,
"Makes a deep copy of the class." },
@@ -536,6 +544,7 @@ static PyMethodDef ScalarMapMethods[] = {
{NULL, NULL},
};
+PyTypeObject *ScalarMapContainer_Type;
#if PY_MAJOR_VERSION >= 3
static PyType_Slot ScalarMapContainer_Type_slots[] = {
{Py_tp_dealloc, (void *)ScalarMapDealloc},
@@ -554,7 +563,6 @@ static PyMethodDef ScalarMapMethods[] = {
Py_TPFLAGS_DEFAULT,
ScalarMapContainer_Type_slots
};
- PyObject *ScalarMapContainer_Type;
#else
static PyMappingMethods ScalarMapMappingMethods = {
MapReflectionFriend::Length, // mp_length
@@ -562,7 +570,7 @@ static PyMethodDef ScalarMapMethods[] = {
MapReflectionFriend::ScalarMapSetItem, // mp_ass_subscript
};
- PyTypeObject ScalarMapContainer_Type = {
+ PyTypeObject _ScalarMapContainer_Type = {
PyVarObject_HEAD_INIT(&PyType_Type, 0)
FULL_MODULE_NAME ".ScalarMapContainer", // tp_name
sizeof(MapContainer), // tp_basicsize
@@ -643,12 +651,7 @@ PyObject* NewMessageMapContainer(
return NULL;
}
-#if PY_MAJOR_VERSION >= 3
- PyObject* obj = PyType_GenericAlloc(
- reinterpret_cast<PyTypeObject *>(MessageMapContainer_Type), 0);
-#else
- PyObject* obj = PyType_GenericAlloc(&MessageMapContainer_Type, 0);
-#endif
+ PyObject* obj = PyType_GenericAlloc(MessageMapContainer_Type, 0);
if (obj == NULL) {
return PyErr_Format(PyExc_RuntimeError,
"Could not allocate new container.");
@@ -780,6 +783,8 @@ static PyMethodDef MessageMapMethods[] = {
"Gets the value for the given key if present, or otherwise a default" },
{ "get_or_create", MapReflectionFriend::MessageMapGetItem, METH_O,
"Alias for getitem, useful to make explicit that the map is mutated." },
+ { "GetEntryClass", (PyCFunction)GetEntryClass, METH_NOARGS,
+ "Return the class used to build Entries of (key, value) pairs." },
/*
{ "__deepcopy__", (PyCFunction)DeepCopy, METH_VARARGS,
"Makes a deep copy of the class." },
@@ -789,6 +794,7 @@ static PyMethodDef MessageMapMethods[] = {
{NULL, NULL},
};
+PyTypeObject *MessageMapContainer_Type;
#if PY_MAJOR_VERSION >= 3
static PyType_Slot MessageMapContainer_Type_slots[] = {
{Py_tp_dealloc, (void *)MessageMapDealloc},
@@ -807,8 +813,6 @@ static PyMethodDef MessageMapMethods[] = {
Py_TPFLAGS_DEFAULT,
MessageMapContainer_Type_slots
};
-
- PyObject *MessageMapContainer_Type;
#else
static PyMappingMethods MessageMapMappingMethods = {
MapReflectionFriend::Length, // mp_length
@@ -816,7 +820,7 @@ static PyMethodDef MessageMapMethods[] = {
MapReflectionFriend::MessageMapSetItem, // mp_ass_subscript
};
- PyTypeObject MessageMapContainer_Type = {
+ PyTypeObject _MessageMapContainer_Type = {
PyVarObject_HEAD_INIT(&PyType_Type, 0)
FULL_MODULE_NAME ".MessageMapContainer", // tp_name
sizeof(MessageMapContainer), // tp_basicsize
@@ -965,6 +969,63 @@ PyTypeObject MapIterator_Type = {
0, // tp_init
};
+bool InitMapContainers() {
+ // ScalarMapContainer_Type derives from our MutableMapping type.
+ ScopedPyObjectPtr containers(PyImport_ImportModule(
+ "google.protobuf.internal.containers"));
+ if (containers == NULL) {
+ return false;
+ }
+
+ ScopedPyObjectPtr mutable_mapping(
+ PyObject_GetAttrString(containers.get(), "MutableMapping"));
+ if (mutable_mapping == NULL) {
+ return false;
+ }
+
+ if (!PyObject_TypeCheck(mutable_mapping.get(), &PyType_Type)) {
+ return false;
+ }
+
+ Py_INCREF(mutable_mapping.get());
+#if PY_MAJOR_VERSION >= 3
+ PyObject* bases = PyTuple_New(1);
+ PyTuple_SET_ITEM(bases, 0, mutable_mapping.get());
+
+ ScalarMapContainer_Type = reinterpret_cast<PyTypeObject*>(
+ PyType_FromSpecWithBases(&ScalarMapContainer_Type_spec, bases));
+#else
+ _ScalarMapContainer_Type.tp_base =
+ reinterpret_cast<PyTypeObject*>(mutable_mapping.get());
+
+ if (PyType_Ready(&_ScalarMapContainer_Type) < 0) {
+ return false;
+ }
+
+ ScalarMapContainer_Type = &_ScalarMapContainer_Type;
+#endif
+
+ if (PyType_Ready(&MapIterator_Type) < 0) {
+ return false;
+ }
+
+#if PY_MAJOR_VERSION >= 3
+ MessageMapContainer_Type = reinterpret_cast<PyTypeObject*>(
+ PyType_FromSpecWithBases(&MessageMapContainer_Type_spec, bases));
+#else
+ Py_INCREF(mutable_mapping.get());
+ _MessageMapContainer_Type.tp_base =
+ reinterpret_cast<PyTypeObject*>(mutable_mapping.get());
+
+ if (PyType_Ready(&_MessageMapContainer_Type) < 0) {
+ return false;
+ }
+
+ MessageMapContainer_Type = &_MessageMapContainer_Type;
+#endif
+ return true;
+}
+
} // namespace python
} // namespace protobuf
} // namespace google
diff --git a/python/google/protobuf/pyext/map_container.h b/python/google/protobuf/pyext/map_container.h
index fbd6713f..615657b0 100644
--- a/python/google/protobuf/pyext/map_container.h
+++ b/python/google/protobuf/pyext/map_container.h
@@ -112,16 +112,10 @@ struct MessageMapContainer : public MapContainer {
PyObject* message_dict;
};
-#if PY_MAJOR_VERSION >= 3
- extern PyObject *MessageMapContainer_Type;
- extern PyType_Spec MessageMapContainer_Type_spec;
- extern PyObject *ScalarMapContainer_Type;
- extern PyType_Spec ScalarMapContainer_Type_spec;
-#else
- extern PyTypeObject MessageMapContainer_Type;
- extern PyTypeObject ScalarMapContainer_Type;
-#endif
+bool InitMapContainers();
+extern PyTypeObject* MessageMapContainer_Type;
+extern PyTypeObject* ScalarMapContainer_Type;
extern PyTypeObject MapIterator_Type; // Both map types use the same iterator.
// Builds a MapContainer object, from a parent message and a
diff --git a/python/google/protobuf/pyext/message.cc b/python/google/protobuf/pyext/message.cc
index 5535338d..1b325469 100644
--- a/python/google/protobuf/pyext/message.cc
+++ b/python/google/protobuf/pyext/message.cc
@@ -63,6 +63,7 @@
#include <google/protobuf/pyext/repeated_composite_container.h>
#include <google/protobuf/pyext/repeated_scalar_container.h>
#include <google/protobuf/pyext/map_container.h>
+#include <google/protobuf/pyext/message_factory.h>
#include <google/protobuf/pyext/scoped_pyobject_ptr.h>
#include <google/protobuf/stubs/strutil.h>
@@ -244,6 +245,12 @@ static PyObject* New(PyTypeObject* type,
return NULL;
}
+ // Messages have no __dict__
+ ScopedPyObjectPtr slots(PyTuple_New(0));
+ if (PyDict_SetItemString(dict, "__slots__", slots.get()) < 0) {
+ return NULL;
+ }
+
// Build the arguments to the base metaclass.
// We change the __bases__ classes.
ScopedPyObjectPtr new_args;
@@ -300,16 +307,19 @@ static PyObject* New(PyTypeObject* type,
newtype->message_descriptor = descriptor;
// TODO(amauryfa): Don't always use the canonical pool of the descriptor,
// use the MessageFactory optionally passed in the class dict.
- newtype->py_descriptor_pool = GetDescriptorPool_FromPool(
- descriptor->file()->pool());
- if (newtype->py_descriptor_pool == NULL) {
+ PyDescriptorPool* py_descriptor_pool =
+ GetDescriptorPool_FromPool(descriptor->file()->pool());
+ if (py_descriptor_pool == NULL) {
return NULL;
}
- Py_INCREF(newtype->py_descriptor_pool);
+ newtype->py_message_factory = py_descriptor_pool->py_message_factory;
+ Py_INCREF(newtype->py_message_factory);
- // Add the message to the DescriptorPool.
- if (cdescriptor_pool::RegisterMessageClass(newtype->py_descriptor_pool,
- descriptor, newtype) < 0) {
+ // Register the message in the MessageFactory.
+ // TODO(amauryfa): Move this call to MessageFactory.GetPrototype() when the
+ // MessageFactory is fully implemented in C++.
+ if (message_factory::RegisterMessageClass(newtype->py_message_factory,
+ descriptor, newtype) < 0) {
return NULL;
}
@@ -321,8 +331,8 @@ static PyObject* New(PyTypeObject* type,
}
static void Dealloc(CMessageClass *self) {
- Py_DECREF(self->py_message_descriptor);
- Py_DECREF(self->py_descriptor_pool);
+ Py_XDECREF(self->py_message_descriptor);
+ Py_XDECREF(self->py_message_factory);
Py_TYPE(self)->tp_free(reinterpret_cast<PyObject*>(self));
}
@@ -752,15 +762,9 @@ bool CheckFieldBelongsToMessage(const FieldDescriptor* field_descriptor,
namespace cmessage {
-PyDescriptorPool* GetDescriptorPoolForMessage(CMessage* message) {
- // No need to check the type: the type of instances of CMessage is always
- // an instance of CMessageClass. Let's prove it with a debug-only check.
+PyMessageFactory* GetFactoryForMessage(CMessage* message) {
GOOGLE_DCHECK(PyObject_TypeCheck(message, &CMessage_Type));
- return reinterpret_cast<CMessageClass*>(Py_TYPE(message))->py_descriptor_pool;
-}
-
-MessageFactory* GetFactoryForMessage(CMessage* message) {
- return GetDescriptorPoolForMessage(message)->message_factory;
+ return reinterpret_cast<CMessageClass*>(Py_TYPE(message))->py_message_factory;
}
static int MaybeReleaseOverlappingOneofField(
@@ -813,7 +817,8 @@ static Message* GetMutableMessage(
return NULL;
}
return reflection->MutableMessage(
- parent_message, parent_field, GetFactoryForMessage(parent));
+ parent_message, parent_field,
+ GetFactoryForMessage(parent)->message_factory);
}
struct FixupMessageReference : public ChildVisitor {
@@ -1172,6 +1177,8 @@ int InitAttributes(CMessage* self, PyObject* args, PyObject* kwargs) {
}
CMessage* cmessage = reinterpret_cast<CMessage*>(message.get());
if (PyDict_Check(value)) {
+ // Make the message exist even if the dict is empty.
+ AssureWritable(cmessage);
if (InitAttributes(cmessage, NULL, value) < 0) {
return -1;
}
@@ -1231,7 +1238,7 @@ static PyObject* New(PyTypeObject* cls,
if (message_descriptor == NULL) {
return NULL;
}
- const Message* default_message = type->py_descriptor_pool->message_factory
+ const Message* default_message = type->py_message_factory->message_factory
->GetPrototype(message_descriptor);
if (default_message == NULL) {
PyErr_SetString(PyExc_TypeError, message_descriptor->full_name().c_str());
@@ -1292,6 +1299,9 @@ struct ClearWeakReferences : public ChildVisitor {
};
static void Dealloc(CMessage* self) {
+ if (self->weakreflist) {
+ PyObject_ClearWeakRefs(reinterpret_cast<PyObject*>(self));
+ }
// Null out all weak references from children to this message.
GOOGLE_CHECK_EQ(0, ForEachCompositeField(self, ClearWeakReferences()));
if (self->extensions) {
@@ -1459,18 +1469,20 @@ PyObject* HasField(CMessage* self, PyObject* arg) {
}
PyObject* ClearExtension(CMessage* self, PyObject* extension) {
+ const FieldDescriptor* descriptor = GetExtensionDescriptor(extension);
+ if (descriptor == NULL) {
+ return NULL;
+ }
if (self->extensions != NULL) {
- return extension_dict::ClearExtension(self->extensions, extension);
- } else {
- const FieldDescriptor* descriptor = GetExtensionDescriptor(extension);
- if (descriptor == NULL) {
- return NULL;
- }
- if (ScopedPyObjectPtr(ClearFieldByDescriptor(self, descriptor)) == NULL) {
- return NULL;
+ PyObject* value = PyDict_GetItem(self->extensions->values, extension);
+ if (value != NULL) {
+ if (InternalReleaseFieldByDescriptor(self, descriptor, value) < 0) {
+ return NULL;
+ }
+ PyDict_DelItem(self->extensions->values, extension);
}
}
- Py_RETURN_NONE;
+ return ClearFieldByDescriptor(self, descriptor);
}
PyObject* HasExtension(CMessage* self, PyObject* extension) {
@@ -1556,7 +1568,7 @@ int SetOwner(CMessage* self, const shared_ptr<Message>& new_owner) {
Message* ReleaseMessage(CMessage* self,
const Descriptor* descriptor,
const FieldDescriptor* field_descriptor) {
- MessageFactory* message_factory = GetFactoryForMessage(self);
+ MessageFactory* message_factory = GetFactoryForMessage(self)->message_factory;
Message* released_message = self->message->GetReflection()->ReleaseMessage(
self->message, field_descriptor, message_factory);
// ReleaseMessage will return NULL which differs from
@@ -1624,12 +1636,19 @@ int InternalReleaseFieldByDescriptor(
PyObject* ClearFieldByDescriptor(
CMessage* self,
- const FieldDescriptor* descriptor) {
- if (!CheckFieldBelongsToMessage(descriptor, self->message)) {
+ const FieldDescriptor* field_descriptor) {
+ if (!CheckFieldBelongsToMessage(field_descriptor, self->message)) {
return NULL;
}
AssureWritable(self);
- self->message->GetReflection()->ClearField(self->message, descriptor);
+ Message* message = self->message;
+ message->GetReflection()->ClearField(message, field_descriptor);
+ if (field_descriptor->cpp_type() == FieldDescriptor::CPPTYPE_ENUM &&
+ !message->GetReflection()->SupportsUnknownEnumValues()) {
+ UnknownFieldSet* unknown_field_set =
+ message->GetReflection()->MutableUnknownFields(message);
+ unknown_field_set->DeleteByNumber(field_descriptor->number());
+ }
Py_RETURN_NONE;
}
@@ -1665,27 +1684,17 @@ PyObject* ClearField(CMessage* self, PyObject* arg) {
arg = arg_in_oneof.get();
}
- PyObject* composite_field = self->composite_fields ?
- PyDict_GetItem(self->composite_fields, arg) : NULL;
-
- // Only release the field if there's a possibility that there are
- // references to it.
- if (composite_field != NULL) {
- if (InternalReleaseFieldByDescriptor(self, field_descriptor,
- composite_field) < 0) {
- return NULL;
+ // Release the field if it exists in the dict of composite fields.
+ if (self->composite_fields) {
+ PyObject* value = PyDict_GetItem(self->composite_fields, arg);
+ if (value != NULL) {
+ if (InternalReleaseFieldByDescriptor(self, field_descriptor, value) < 0) {
+ return NULL;
+ }
+ PyDict_DelItem(self->composite_fields, arg);
}
- PyDict_DelItem(self->composite_fields, arg);
- }
- message->GetReflection()->ClearField(message, field_descriptor);
- if (field_descriptor->cpp_type() == FieldDescriptor::CPPTYPE_ENUM &&
- !message->GetReflection()->SupportsUnknownEnumValues()) {
- UnknownFieldSet* unknown_field_set =
- message->GetReflection()->MutableUnknownFields(message);
- unknown_field_set->DeleteByNumber(field_descriptor->number());
}
-
- Py_RETURN_NONE;
+ return ClearFieldByDescriptor(self, field_descriptor);
}
PyObject* Clear(CMessage* self) {
@@ -1927,8 +1936,8 @@ static PyObject* MergeFromString(CMessage* self, PyObject* arg) {
if (allow_oversize_protos) {
input.SetTotalBytesLimit(INT_MAX, INT_MAX);
}
- PyDescriptorPool* pool = GetDescriptorPoolForMessage(self);
- input.SetExtensionRegistry(pool->pool, pool->message_factory);
+ PyMessageFactory* factory = GetFactoryForMessage(self);
+ input.SetExtensionRegistry(factory->pool->pool, factory->message_factory);
bool success = self->message->MergePartialFromCodedStream(&input);
if (success) {
return PyInt_FromLong(input.CurrentPosition());
@@ -2108,8 +2117,8 @@ static PyObject* ListFields(CMessage* self) {
// is no message class and we cannot retrieve the value.
// TODO(amauryfa): consider building the class on the fly!
if (fields[i]->message_type() != NULL &&
- cdescriptor_pool::GetMessageClass(
- GetDescriptorPoolForMessage(self),
+ message_factory::GetMessageClass(
+ GetFactoryForMessage(self),
fields[i]->message_type()) == NULL) {
PyErr_Clear();
continue;
@@ -2306,12 +2315,12 @@ PyObject* InternalGetScalar(const Message* message,
PyObject* InternalGetSubMessage(
CMessage* self, const FieldDescriptor* field_descriptor) {
const Reflection* reflection = self->message->GetReflection();
- PyDescriptorPool* pool = GetDescriptorPoolForMessage(self);
+ PyMessageFactory* factory = GetFactoryForMessage(self);
const Message& sub_message = reflection->GetMessage(
- *self->message, field_descriptor, pool->message_factory);
+ *self->message, field_descriptor, factory->message_factory);
- CMessageClass* message_class = cdescriptor_pool::GetMessageClass(
- pool, field_descriptor->message_type());
+ CMessageClass* message_class = message_factory::GetMessageClass(
+ factory, field_descriptor->message_type());
if (message_class == NULL) {
return NULL;
}
@@ -2656,8 +2665,8 @@ PyObject* GetAttr(CMessage* self, PyObject* name) {
const Descriptor* entry_type = field_descriptor->message_type();
const FieldDescriptor* value_type = entry_type->FindFieldByName("value");
if (value_type->cpp_type() == FieldDescriptor::CPPTYPE_MESSAGE) {
- CMessageClass* value_class = cdescriptor_pool::GetMessageClass(
- GetDescriptorPoolForMessage(self), value_type->message_type());
+ CMessageClass* value_class = message_factory::GetMessageClass(
+ GetFactoryForMessage(self), value_type->message_type());
if (value_class == NULL) {
return NULL;
}
@@ -2679,8 +2688,8 @@ PyObject* GetAttr(CMessage* self, PyObject* name) {
if (field_descriptor->label() == FieldDescriptor::LABEL_REPEATED) {
PyObject* py_container = NULL;
if (field_descriptor->cpp_type() == FieldDescriptor::CPPTYPE_MESSAGE) {
- CMessageClass* message_class = cdescriptor_pool::GetMessageClass(
- GetDescriptorPoolForMessage(self), field_descriptor->message_type());
+ CMessageClass* message_class = message_factory::GetMessageClass(
+ GetFactoryForMessage(self), field_descriptor->message_type());
if (message_class == NULL) {
return NULL;
}
@@ -2775,7 +2784,7 @@ PyTypeObject CMessage_Type = {
0, // tp_traverse
0, // tp_clear
(richcmpfunc)cmessage::RichCompare, // tp_richcompare
- 0, // tp_weaklistoffset
+ offsetof(CMessage, weakreflist), // tp_weaklistoffset
0, // tp_iter
0, // tp_iternext
cmessage::Methods, // tp_methods
@@ -2863,6 +2872,11 @@ bool InitProto2MessageModule(PyObject *m) {
return false;
}
+ // Initialize types and globals in message_factory.cc
+ if (!InitMessageFactory()) {
+ return false;
+ }
+
// Initialize constants defined in this file.
InitGlobals();
@@ -2944,69 +2958,15 @@ bool InitProto2MessageModule(PyObject *m) {
}
// Initialize Map container types.
- {
- // ScalarMapContainer_Type derives from our MutableMapping type.
- ScopedPyObjectPtr containers(PyImport_ImportModule(
- "google.protobuf.internal.containers"));
- if (containers == NULL) {
- return false;
- }
-
- ScopedPyObjectPtr mutable_mapping(
- PyObject_GetAttrString(containers.get(), "MutableMapping"));
- if (mutable_mapping == NULL) {
- return false;
- }
-
- if (!PyObject_TypeCheck(mutable_mapping.get(), &PyType_Type)) {
- return false;
- }
-
- Py_INCREF(mutable_mapping.get());
-#if PY_MAJOR_VERSION >= 3
- PyObject* bases = PyTuple_New(1);
- PyTuple_SET_ITEM(bases, 0, mutable_mapping.get());
-
- ScalarMapContainer_Type =
- PyType_FromSpecWithBases(&ScalarMapContainer_Type_spec, bases);
- PyModule_AddObject(m, "ScalarMapContainer", ScalarMapContainer_Type);
-#else
- ScalarMapContainer_Type.tp_base =
- reinterpret_cast<PyTypeObject*>(mutable_mapping.get());
-
- if (PyType_Ready(&ScalarMapContainer_Type) < 0) {
- return false;
- }
-
- PyModule_AddObject(m, "ScalarMapContainer",
- reinterpret_cast<PyObject*>(&ScalarMapContainer_Type));
-#endif
-
- if (PyType_Ready(&MapIterator_Type) < 0) {
- return false;
- }
-
- PyModule_AddObject(m, "MapIterator",
- reinterpret_cast<PyObject*>(&MapIterator_Type));
-
-
-#if PY_MAJOR_VERSION >= 3
- MessageMapContainer_Type =
- PyType_FromSpecWithBases(&MessageMapContainer_Type_spec, bases);
- PyModule_AddObject(m, "MessageMapContainer", MessageMapContainer_Type);
-#else
- Py_INCREF(mutable_mapping.get());
- MessageMapContainer_Type.tp_base =
- reinterpret_cast<PyTypeObject*>(mutable_mapping.get());
-
- if (PyType_Ready(&MessageMapContainer_Type) < 0) {
- return false;
- }
-
- PyModule_AddObject(m, "MessageMapContainer",
- reinterpret_cast<PyObject*>(&MessageMapContainer_Type));
-#endif
+ if (!InitMapContainers()) {
+ return false;
}
+ PyModule_AddObject(m, "ScalarMapContainer",
+ reinterpret_cast<PyObject*>(ScalarMapContainer_Type));
+ PyModule_AddObject(m, "MessageMapContainer",
+ reinterpret_cast<PyObject*>(MessageMapContainer_Type));
+ PyModule_AddObject(m, "MapIterator",
+ reinterpret_cast<PyObject*>(&MapIterator_Type));
if (PyType_Ready(&ExtensionDict_Type) < 0) {
return false;
diff --git a/python/google/protobuf/pyext/message.h b/python/google/protobuf/pyext/message.h
index c44a2ae2..1550724c 100644
--- a/python/google/protobuf/pyext/message.h
+++ b/python/google/protobuf/pyext/message.h
@@ -62,7 +62,7 @@ using internal::shared_ptr;
namespace python {
struct ExtensionDict;
-struct PyDescriptorPool;
+struct PyMessageFactory;
typedef struct CMessage {
PyObject_HEAD;
@@ -112,6 +112,9 @@ typedef struct CMessage {
// Similar to composite_fields, acting as a cache, but also contains the
// required extension dict logic.
ExtensionDict* extensions;
+
+ // Implements the "weakref" protocol for this object.
+ PyObject* weakreflist;
} CMessage;
extern PyTypeObject CMessage_Type;
@@ -132,14 +135,11 @@ struct CMessageClass {
// Owned reference, used to keep the pointer above alive.
PyObject* py_message_descriptor;
- // The Python DescriptorPool used to create the class. It is needed to resolve
+ // The Python MessageFactory used to create the class. It is needed to resolve
// fields descriptors, including extensions fields; its C++ MessageFactory is
// used to instantiate submessages.
- // This can be different from DESCRIPTOR.file.pool, in the case of a custom
- // DescriptorPool which defines new extensions.
- // We own the reference, because it's important to keep the descriptors and
- // factory alive.
- PyDescriptorPool* py_descriptor_pool;
+ // We own the reference, because it's important to keep the factory alive.
+ PyMessageFactory* py_message_factory;
PyObject* AsPyObject() {
return reinterpret_cast<PyObject*>(this);
@@ -154,14 +154,6 @@ namespace cmessage {
// The caller must fill self->message, self->owner and eventually self->parent.
CMessage* NewEmptyMessage(CMessageClass* type);
-// Release a submessage from its proto tree, making it a new top-level messgae.
-// A new message will be created if this is a read-only default instance.
-//
-// Corresponds to reflection api method ReleaseMessage.
-int ReleaseSubMessage(CMessage* self,
- const FieldDescriptor* field_descriptor,
- CMessage* child_cmessage);
-
// Retrieves the C++ descriptor of a Python Extension descriptor.
// On error, return NULL with an exception set.
const FieldDescriptor* GetExtensionDescriptor(PyObject* extension);
@@ -262,14 +254,13 @@ int SetOwner(CMessage* self, const shared_ptr<Message>& new_owner);
int AssureWritable(CMessage* self);
-// Returns the "best" DescriptorPool for the given message.
-// This is often equivalent to message.DESCRIPTOR.pool, but not always, when
-// the message class was created from a MessageFactory using a custom pool which
-// uses the generated pool as an underlay.
+// Returns the message factory for the given message.
+// This is equivalent to message.MESSAGE_FACTORY
//
-// The returned pool is suitable for finding fields and building submessages,
+// The returned factory is suitable for finding fields and building submessages,
// even in the case of extensions.
-PyDescriptorPool* GetDescriptorPoolForMessage(CMessage* message);
+// Returns a *borrowed* reference, and never fails because we pass a CMessage.
+PyMessageFactory* GetFactoryForMessage(CMessage* message);
PyObject* SetAllowOversizeProtos(PyObject* m, PyObject* arg);
diff --git a/python/google/protobuf/pyext/message_factory.cc b/python/google/protobuf/pyext/message_factory.cc
new file mode 100644
index 00000000..2ad89022
--- /dev/null
+++ b/python/google/protobuf/pyext/message_factory.cc
@@ -0,0 +1,214 @@
+// Protocol Buffers - Google's data interchange format
+// Copyright 2008 Google Inc. All rights reserved.
+// https://developers.google.com/protocol-buffers/
+//
+// Redistribution and use in source and binary forms, with or without
+// modification, are permitted provided that the following conditions are
+// met:
+//
+// * Redistributions of source code must retain the above copyright
+// notice, this list of conditions and the following disclaimer.
+// * Redistributions in binary form must reproduce the above
+// copyright notice, this list of conditions and the following disclaimer
+// in the documentation and/or other materials provided with the
+// distribution.
+// * Neither the name of Google Inc. nor the names of its
+// contributors may be used to endorse or promote products derived from
+// this software without specific prior written permission.
+//
+// THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS
+// "AS IS" AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT
+// LIMITED TO, THE IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR
+// A PARTICULAR PURPOSE ARE DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT
+// OWNER OR CONTRIBUTORS BE LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL,
+// SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT
+// LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE,
+// DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY
+// THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT
+// (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
+// OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
+
+#include <Python.h>
+
+#include <google/protobuf/dynamic_message.h>
+#include <google/protobuf/pyext/descriptor.h>
+#include <google/protobuf/pyext/message.h>
+#include <google/protobuf/pyext/message_factory.h>
+#include <google/protobuf/pyext/scoped_pyobject_ptr.h>
+
+#if PY_MAJOR_VERSION >= 3
+ #if PY_VERSION_HEX < 0x03030000
+ #error "Python 3.0 - 3.2 are not supported."
+ #endif
+ #define PyString_AsStringAndSize(ob, charpp, sizep) \
+ (PyUnicode_Check(ob)? \
+ ((*(charpp) = PyUnicode_AsUTF8AndSize(ob, (sizep))) == NULL? -1: 0): \
+ PyBytes_AsStringAndSize(ob, (charpp), (sizep)))
+#endif
+
+namespace google {
+namespace protobuf {
+namespace python {
+
+namespace message_factory {
+
+PyMessageFactory* NewMessageFactory(PyTypeObject* type, PyDescriptorPool* pool) {
+ PyMessageFactory* factory = reinterpret_cast<PyMessageFactory*>(
+ PyType_GenericAlloc(type, 0));
+ if (factory == NULL) {
+ return NULL;
+ }
+
+ DynamicMessageFactory* message_factory = new DynamicMessageFactory();
+ // This option might be the default some day.
+ message_factory->SetDelegateToGeneratedFactory(true);
+ factory->message_factory = message_factory;
+
+ factory->pool = pool;
+ // TODO(amauryfa): When the MessageFactory is not created from the
+ // DescriptorPool this reference should be owned, not borrowed.
+ // Py_INCREF(pool);
+
+ factory->classes_by_descriptor = new PyMessageFactory::ClassesByMessageMap();
+
+ return factory;
+}
+
+PyObject* New(PyTypeObject* type, PyObject* args, PyObject* kwargs) {
+ static char* kwlist[] = {"pool", 0};
+ PyObject* pool = NULL;
+ if (!PyArg_ParseTupleAndKeywords(args, kwargs, "|O", kwlist, &pool)) {
+ return NULL;
+ }
+ ScopedPyObjectPtr owned_pool;
+ if (pool == NULL || pool == Py_None) {
+ owned_pool.reset(PyObject_CallFunction(
+ reinterpret_cast<PyObject*>(&PyDescriptorPool_Type), NULL));
+ if (owned_pool == NULL) {
+ return NULL;
+ }
+ pool = owned_pool.get();
+ } else {
+ if (!PyObject_TypeCheck(pool, &PyDescriptorPool_Type)) {
+ PyErr_Format(PyExc_TypeError, "Expected a DescriptorPool, got %s",
+ pool->ob_type->tp_name);
+ return NULL;
+ }
+ }
+
+ return reinterpret_cast<PyObject*>(
+ NewMessageFactory(type, reinterpret_cast<PyDescriptorPool*>(pool)));
+}
+
+static void Dealloc(PyMessageFactory* self) {
+ // TODO(amauryfa): When the MessageFactory is not created from the
+ // DescriptorPool this reference should be owned, not borrowed.
+ // Py_CLEAR(self->pool);
+ typedef PyMessageFactory::ClassesByMessageMap::iterator iterator;
+ for (iterator it = self->classes_by_descriptor->begin();
+ it != self->classes_by_descriptor->end(); ++it) {
+ Py_DECREF(it->second);
+ }
+ delete self->classes_by_descriptor;
+ delete self->message_factory;
+ Py_TYPE(self)->tp_free(reinterpret_cast<PyObject*>(self));
+}
+
+// Add a message class to our database.
+int RegisterMessageClass(PyMessageFactory* self,
+ const Descriptor* message_descriptor,
+ CMessageClass* message_class) {
+ Py_INCREF(message_class);
+ typedef PyMessageFactory::ClassesByMessageMap::iterator iterator;
+ std::pair<iterator, bool> ret = self->classes_by_descriptor->insert(
+ std::make_pair(message_descriptor, message_class));
+ if (!ret.second) {
+ // Update case: DECREF the previous value.
+ Py_DECREF(ret.first->second);
+ ret.first->second = message_class;
+ }
+ return 0;
+}
+
+// Retrieve the message class added to our database.
+CMessageClass* GetMessageClass(PyMessageFactory* self,
+ const Descriptor* message_descriptor) {
+ typedef PyMessageFactory::ClassesByMessageMap::iterator iterator;
+ iterator ret = self->classes_by_descriptor->find(message_descriptor);
+ if (ret == self->classes_by_descriptor->end()) {
+ PyErr_Format(PyExc_TypeError, "No message class registered for '%s'",
+ message_descriptor->full_name().c_str());
+ return NULL;
+ } else {
+ return ret->second;
+ }
+}
+
+static PyMethodDef Methods[] = {
+ {NULL}};
+
+static PyObject* GetPool(PyMessageFactory* self, void* closure) {
+ Py_INCREF(self->pool);
+ return reinterpret_cast<PyObject*>(self->pool);
+}
+
+static PyGetSetDef Getters[] = {
+ {"pool", (getter)GetPool, NULL, "DescriptorPool"},
+ {NULL}
+};
+
+} // namespace message_factory
+
+PyTypeObject PyMessageFactory_Type = {
+ PyVarObject_HEAD_INIT(&PyType_Type, 0) FULL_MODULE_NAME
+ ".MessageFactory", // tp_name
+ sizeof(PyMessageFactory), // tp_basicsize
+ 0, // tp_itemsize
+ (destructor)message_factory::Dealloc, // tp_dealloc
+ 0, // tp_print
+ 0, // tp_getattr
+ 0, // tp_setattr
+ 0, // tp_compare
+ 0, // tp_repr
+ 0, // tp_as_number
+ 0, // tp_as_sequence
+ 0, // tp_as_mapping
+ 0, // tp_hash
+ 0, // tp_call
+ 0, // tp_str
+ 0, // tp_getattro
+ 0, // tp_setattro
+ 0, // tp_as_buffer
+ Py_TPFLAGS_DEFAULT | Py_TPFLAGS_BASETYPE, // tp_flags
+ "A static Message Factory", // tp_doc
+ 0, // tp_traverse
+ 0, // tp_clear
+ 0, // tp_richcompare
+ 0, // tp_weaklistoffset
+ 0, // tp_iter
+ 0, // tp_iternext
+ message_factory::Methods, // tp_methods
+ 0, // tp_members
+ message_factory::Getters, // tp_getset
+ 0, // tp_base
+ 0, // tp_dict
+ 0, // tp_descr_get
+ 0, // tp_descr_set
+ 0, // tp_dictoffset
+ 0, // tp_init
+ 0, // tp_alloc
+ message_factory::New, // tp_new
+ PyObject_Del, // tp_free
+};
+
+bool InitMessageFactory() {
+ if (PyType_Ready(&PyMessageFactory_Type) < 0) {
+ return false;
+ }
+
+ return true;
+}
+
+} // namespace python
+} // namespace protobuf
+} // namespace google
diff --git a/python/google/protobuf/pyext/message_factory.h b/python/google/protobuf/pyext/message_factory.h
new file mode 100644
index 00000000..07cccbfb
--- /dev/null
+++ b/python/google/protobuf/pyext/message_factory.h
@@ -0,0 +1,103 @@
+// Protocol Buffers - Google's data interchange format
+// Copyright 2008 Google Inc. All rights reserved.
+// https://developers.google.com/protocol-buffers/
+//
+// Redistribution and use in source and binary forms, with or without
+// modification, are permitted provided that the following conditions are
+// met:
+//
+// * Redistributions of source code must retain the above copyright
+// notice, this list of conditions and the following disclaimer.
+// * Redistributions in binary form must reproduce the above
+// copyright notice, this list of conditions and the following disclaimer
+// in the documentation and/or other materials provided with the
+// distribution.
+// * Neither the name of Google Inc. nor the names of its
+// contributors may be used to endorse or promote products derived from
+// this software without specific prior written permission.
+//
+// THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS
+// "AS IS" AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT
+// LIMITED TO, THE IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR
+// A PARTICULAR PURPOSE ARE DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT
+// OWNER OR CONTRIBUTORS BE LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL,
+// SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT
+// LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE,
+// DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY
+// THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT
+// (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
+// OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
+
+#ifndef GOOGLE_PROTOBUF_PYTHON_CPP_MESSAGE_FACTORY_H__
+#define GOOGLE_PROTOBUF_PYTHON_CPP_MESSAGE_FACTORY_H__
+
+#include <Python.h>
+
+#include <google/protobuf/stubs/hash.h>
+#include <google/protobuf/descriptor.h>
+#include <google/protobuf/pyext/descriptor_pool.h>
+
+namespace google {
+namespace protobuf {
+class MessageFactory;
+
+namespace python {
+
+// The (meta) type of all Messages classes.
+struct CMessageClass;
+
+struct PyMessageFactory {
+ PyObject_HEAD
+
+ // DynamicMessageFactory used to create C++ instances of messages.
+ // This object cache the descriptors that were used, so the DescriptorPool
+ // needs to get rid of it before it can delete itself.
+ //
+ // Note: A C++ MessageFactory is different from the PyMessageFactory.
+ // The C++ one creates messages, when the Python one creates classes.
+ MessageFactory* message_factory;
+
+ // borrowed reference to a Python DescriptorPool.
+ // TODO(amauryfa): invert the dependency: the MessageFactory owns the
+ // DescriptorPool, not the opposite.
+ PyDescriptorPool* pool;
+
+ // Make our own mapping to retrieve Python classes from C++ descriptors.
+ //
+ // Descriptor pointers stored here are owned by the DescriptorPool above.
+ // Python references to classes are owned by this PyDescriptorPool.
+ typedef hash_map<const Descriptor*, CMessageClass*> ClassesByMessageMap;
+ ClassesByMessageMap* classes_by_descriptor;
+};
+
+extern PyTypeObject PyMessageFactory_Type;
+
+namespace message_factory {
+
+// Creates a new MessageFactory instance.
+PyMessageFactory* NewMessageFactory(PyTypeObject* type, PyDescriptorPool* pool);
+
+// Registers a new Python class for the given message descriptor.
+// On error, returns -1 with a Python exception set.
+int RegisterMessageClass(PyMessageFactory* self,
+ const Descriptor* message_descriptor,
+ CMessageClass* message_class);
+
+// Retrieves the Python class registered with the given message descriptor.
+//
+// Returns a *borrowed* reference if found, otherwise returns NULL with an
+// exception set.
+CMessageClass* GetMessageClass(PyMessageFactory* self,
+ const Descriptor* message_descriptor);
+
+} // namespace message_factory
+
+// Initialize objects used by this module.
+// On error, returns false with a Python exception set.
+bool InitMessageFactory();
+
+} // namespace python
+} // namespace protobuf
+
+} // namespace google
+#endif // GOOGLE_PROTOBUF_PYTHON_CPP_MESSAGE_FACTORY_H__
diff --git a/python/google/protobuf/pyext/repeated_composite_container.cc b/python/google/protobuf/pyext/repeated_composite_container.cc
index bb2f6db2..43a2bc12 100644
--- a/python/google/protobuf/pyext/repeated_composite_container.cc
+++ b/python/google/protobuf/pyext/repeated_composite_container.cc
@@ -364,7 +364,7 @@ static int SortPythonMessages(RepeatedCompositeContainer* self,
ScopedPyObjectPtr m(PyObject_GetAttrString(self->child_messages, "sort"));
if (m == NULL)
return -1;
- if (PyObject_Call(m.get(), args, kwds) == NULL)
+ if (ScopedPyObjectPtr(PyObject_Call(m.get(), args, kwds)) == NULL)
return -1;
if (self->message != NULL) {
ReorderAttached(self);
diff --git a/python/google/protobuf/symbol_database.py b/python/google/protobuf/symbol_database.py
index aa466abd..ecbef211 100644
--- a/python/google/protobuf/symbol_database.py
+++ b/python/google/protobuf/symbol_database.py
@@ -129,7 +129,8 @@ class SymbolDatabase(message_factory.MessageFactory):
Only messages already created and registered will be returned; (this is the
case for imported _pb2 modules)
- But unlike MessageFactory, this version also returns nested messages.
+ But unlike MessageFactory, this version also returns already defined nested
+ messages, but does not register any message extensions.
Args:
files: The file names to extract messages from.
diff --git a/python/google/protobuf/text_format.py b/python/google/protobuf/text_format.py
index 06b79d77..90f6ce42 100755
--- a/python/google/protobuf/text_format.py
+++ b/python/google/protobuf/text_format.py
@@ -228,13 +228,13 @@ def _BuildMessageFromTypeName(type_name, descriptor_pool):
wasn't found matching type_name.
"""
# pylint: disable=g-import-not-at-top
- from google.protobuf import message_factory
- factory = message_factory.MessageFactory(descriptor_pool)
+ from google.protobuf import symbol_database
+ database = symbol_database.Default()
try:
message_descriptor = descriptor_pool.FindMessageTypeByName(type_name)
except KeyError:
return None
- message_type = factory.GetPrototype(message_descriptor)
+ message_type = database.GetPrototype(message_descriptor)
return message_type()
@@ -317,8 +317,7 @@ class _Printer(object):
# of this file to work around.
#
# TODO(haberman): refactor and optimize if this becomes an issue.
- entry_submsg = field.message_type._concrete_class(key=key,
- value=value[key])
+ entry_submsg = value.GetEntryClass()(key=key, value=value[key])
self.PrintField(field, entry_submsg)
elif field.label == descriptor.FieldDescriptor.LABEL_REPEATED:
for element in value:
@@ -749,8 +748,7 @@ class _Parser(object):
if field.is_extension:
sub_message = message.Extensions[field].add()
elif is_map_entry:
- # pylint: disable=protected-access
- sub_message = field.message_type._concrete_class()
+ sub_message = getattr(message, field.name).GetEntryClass()()
else:
sub_message = getattr(message, field.name).add()
else:
@@ -1448,9 +1446,9 @@ def ParseBool(text):
Raises:
ValueError: If text is not a valid boolean.
"""
- if text in ('true', 't', '1'):
+ if text in ('true', 't', '1', 'True'):
return True
- elif text in ('false', 'f', '0'):
+ elif text in ('false', 'f', '0', 'False'):
return False
else:
raise ValueError('Expected "true" or "false".')